diff --git a/models/ModelBase.py b/models/ModelBase.py index fcaddb0..42aef51 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -150,34 +150,39 @@ class ModelBase(object): if (self.sample_for_preview is None) or (self.epoch == 0): self.sample_for_preview = self.generate_next_sample() - print ("===== Model summary =====") - print ("== Model name: " + self.get_model_name()) - print ("==") - print ("== Current epoch: " + str(self.epoch) ) - print ("==") - print ("== Model options:") + model_summary_text = [] + + model_summary_text += ["===== Model summary ====="] + model_summary_text += ["== Model name: " + self.get_model_name()] + model_summary_text += ["=="] + model_summary_text += ["== Current epoch: " + str(self.epoch)] + model_summary_text += ["=="] + model_summary_text += ["== Model options:"] for key in self.options.keys(): - print ("== |== %s : %s" % (key, self.options[key]) ) + model_summary_text += ["== |== %s : %s" % (key, self.options[key])] if self.device_config.multi_gpu: - print ("== |== multi_gpu : True ") + model_summary_text += ["== |== multi_gpu : True "] - print ("== Running on:") + model_summary_text += ["== Running on:"] if self.device_config.cpu_only: - print ("== |== [CPU]") + model_summary_text += ["== |== [CPU]"] else: for idx in self.device_config.gpu_idxs: - print ("== |== [%d : %s]" % (idx, nnlib.device.getDeviceName(idx)) ) + model_summary_text += ["== |== [%d : %s]" % (idx, nnlib.device.getDeviceName(idx))] if not self.device_config.cpu_only and self.device_config.gpu_vram_gb[0] == 2: - print ("==") - print ("== WARNING: You are using 2GB GPU. Result quality may be significantly decreased.") - print ("== If training does not start, close all programs and try again.") - print ("== Also you can disable Windows Aero Desktop to get extra free VRAM.") - print ("==") + model_summary_text += ["=="] + model_summary_text += ["== WARNING: You are using 2GB GPU. Result quality may be significantly decreased."] + model_summary_text += ["== If training does not start, close all programs and try again."] + model_summary_text += ["== Also you can disable Windows Aero Desktop to get extra free VRAM."] + model_summary_text += ["=="] - print ("=========================") - + model_summary_text += ["========================="] + model_summary_text = "\r\n".join (model_summary_text) + self.model_summary_text = model_summary_text + print(model_summary_text) + #overridable def onInitializeOptions(self, is_first_run, ask_override): pass @@ -258,7 +263,8 @@ class ModelBase(object): if self.supress_std_once: supressor = std_utils.suppress_stdout_stderr() supressor.__enter__() - + + Path( self.get_strpath_storage_for_file('summary.txt') ).write_text(self.model_summary_text) self.onSave() if self.supress_std_once: @@ -274,6 +280,7 @@ class ModelBase(object): def load_weights_safe(self, model_filename_list): for model, filename in model_filename_list: + filename = self.get_strpath_storage_for_file(filename) if Path(filename).exists(): model.load_weights(filename) diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 8e34cb9..71eb335 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -106,6 +106,7 @@ class SAEModel(ModelBase): target_dst_ar = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)] target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)] + weights_to_load = [] if self.options['archi'] == 'liae': self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, self.options['lighter_encoder'], ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape)) @@ -122,13 +123,14 @@ class SAEModel(ModelBase): self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (inter_output_Inputs) if not self.is_first_run(): - self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5)) - self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5)) - self.inter_AB.load_weights (self.get_strpath_storage_for_file(self.inter_ABH5)) - self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5)) + weights_to_load += [ [self.encoder , 'encoder.h5'], + [self.inter_B , 'inter_B.h5'], + [self.inter_AB, 'inter_AB.h5'], + [self.decoder , 'decoder.h5'], + ] if self.options['learn_mask']: - self.decoderm.load_weights (self.get_strpath_storage_for_file(self.decodermH5)) - + weights_to_load += [ [self.decoderm, 'decoderm.h5'] ] + warped_src_code = self.encoder (warped_src) warped_src_inter_AB_code = self.inter_AB (warped_src_code) warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code]) @@ -162,13 +164,15 @@ class SAEModel(ModelBase): self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (dec_Inputs) if not self.is_first_run(): - self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5)) - self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5)) - self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5)) + weights_to_load += [ [self.encoder , 'encoder.h5'], + [self.decoder_src, 'decoder_src.h5'], + [self.decoder_dst, 'decoder_dst.h5'] + ] if self.options['learn_mask']: - self.decoder_srcm.load_weights (self.get_strpath_storage_for_file(self.decoder_srcmH5)) - self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5)) - + weights_to_load += [ [self.decoder_srcm, 'decoder_srcm.h5'], + [self.decoder_dstm, 'decoder_dstm.h5'], + ] + warped_src_code = self.encoder (warped_src) warped_dst_code = self.encoder (warped_dst) pred_src_src = self.decoder_src(warped_src_code) @@ -193,12 +197,14 @@ class SAEModel(ModelBase): self.decoder_dstm = modelify(SAEModel.VGDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (dec_Inputs) if not self.is_first_run(): - self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5)) - self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5)) - self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5)) + weights_to_load += [ [self.encoder , 'encoder.h5'], + [self.decoder_src, 'decoder_src.h5'], + [self.decoder_dst, 'decoder_dst.h5'] + ] if self.options['learn_mask']: - self.decoder_srcm.load_weights (self.get_strpath_storage_for_file(self.decoder_srcmH5)) - self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5)) + weights_to_load += [ [self.decoder_srcm, 'decoder_srcm.h5'], + [self.decoder_dstm, 'decoder_dstm.h5'], + ] warped_src_code = self.encoder (warped_src) warped_dst_code = self.encoder (warped_dst) @@ -211,7 +217,9 @@ class SAEModel(ModelBase): pred_src_srcm = self.decoder_srcm(warped_src_code) pred_dst_dstm = self.decoder_dstm(warped_dst_code) pred_src_dstm = self.decoder_srcm(warped_dst_code) - + + self.load_weights_safe(weights_to_load) + pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ] if self.options['learn_mask']: