added saving model_summary.txt

This commit is contained in:
iperov 2019-02-21 19:55:44 +04:00
parent f0a20b46d3
commit 97685ce0ae
2 changed files with 52 additions and 37 deletions

View file

@ -106,6 +106,7 @@ class SAEModel(ModelBase):
target_dst_ar = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
weights_to_load = []
if self.options['archi'] == 'liae':
self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, self.options['lighter_encoder'], ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
@ -122,13 +123,14 @@ class SAEModel(ModelBase):
self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (inter_output_Inputs)
if not self.is_first_run():
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5))
self.inter_AB.load_weights (self.get_strpath_storage_for_file(self.inter_ABH5))
self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5))
weights_to_load += [ [self.encoder , 'encoder.h5'],
[self.inter_B , 'inter_B.h5'],
[self.inter_AB, 'inter_AB.h5'],
[self.decoder , 'decoder.h5'],
]
if self.options['learn_mask']:
self.decoderm.load_weights (self.get_strpath_storage_for_file(self.decodermH5))
weights_to_load += [ [self.decoderm, 'decoderm.h5'] ]
warped_src_code = self.encoder (warped_src)
warped_src_inter_AB_code = self.inter_AB (warped_src_code)
warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
@ -162,13 +164,15 @@ class SAEModel(ModelBase):
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (dec_Inputs)
if not self.is_first_run():
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
weights_to_load += [ [self.encoder , 'encoder.h5'],
[self.decoder_src, 'decoder_src.h5'],
[self.decoder_dst, 'decoder_dst.h5']
]
if self.options['learn_mask']:
self.decoder_srcm.load_weights (self.get_strpath_storage_for_file(self.decoder_srcmH5))
self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5))
weights_to_load += [ [self.decoder_srcm, 'decoder_srcm.h5'],
[self.decoder_dstm, 'decoder_dstm.h5'],
]
warped_src_code = self.encoder (warped_src)
warped_dst_code = self.encoder (warped_dst)
pred_src_src = self.decoder_src(warped_src_code)
@ -193,12 +197,14 @@ class SAEModel(ModelBase):
self.decoder_dstm = modelify(SAEModel.VGDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (dec_Inputs)
if not self.is_first_run():
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
weights_to_load += [ [self.encoder , 'encoder.h5'],
[self.decoder_src, 'decoder_src.h5'],
[self.decoder_dst, 'decoder_dst.h5']
]
if self.options['learn_mask']:
self.decoder_srcm.load_weights (self.get_strpath_storage_for_file(self.decoder_srcmH5))
self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5))
weights_to_load += [ [self.decoder_srcm, 'decoder_srcm.h5'],
[self.decoder_dstm, 'decoder_dstm.h5'],
]
warped_src_code = self.encoder (warped_src)
warped_dst_code = self.encoder (warped_dst)
@ -211,7 +217,9 @@ class SAEModel(ModelBase):
pred_src_srcm = self.decoder_srcm(warped_src_code)
pred_dst_dstm = self.decoder_dstm(warped_dst_code)
pred_src_dstm = self.decoder_srcm(warped_dst_code)
self.load_weights_safe(weights_to_load)
pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]
if self.options['learn_mask']: