From a858732b1d17f6333dc51a31d0caefe71a6a58f5 Mon Sep 17 00:00:00 2001 From: iperov Date: Thu, 21 Feb 2019 20:06:04 +0400 Subject: [PATCH] refactoring --- models/ModelBase.py | 2 ++ models/Model_DF/Model.py | 18 ++++++++---------- models/Model_H128/Model.py | 18 ++++++++---------- models/Model_H64/Model.py | 18 ++++++++---------- models/Model_LIAEF128/Model.py | 25 +++++++++++-------------- models/Model_SAE/Model.py | 22 +++++++++++----------- 6 files changed, 48 insertions(+), 55 deletions(-) diff --git a/models/ModelBase.py b/models/ModelBase.py index 42aef51..aff68a4 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -286,9 +286,11 @@ class ModelBase(object): def save_weights_safe(self, model_filename_list): for model, filename in model_filename_list: + filename = self.get_strpath_storage_for_file(filename) model.save_weights( filename + '.tmp' ) for model, filename in model_filename_list: + filename = self.get_strpath_storage_for_file(filename) source_filename = Path(filename+'.tmp') target_filename = Path(filename) if target_filename.exists(): diff --git a/models/Model_DF/Model.py b/models/Model_DF/Model.py index 531b3b0..0f52e4b 100644 --- a/models/Model_DF/Model.py +++ b/models/Model_DF/Model.py @@ -8,10 +8,6 @@ from utils.console_utils import * class Model(ModelBase): - encoderH5 = 'encoder.h5' - decoder_srcH5 = 'decoder_src.h5' - decoder_dstH5 = 'decoder_dst.h5' - #override def onInitializeOptions(self, is_first_run, ask_override): if is_first_run or ask_override: @@ -31,9 +27,11 @@ class Model(ModelBase): self.encoder, self.decoder_src, self.decoder_dst = self.Build(ae_input_layer) if not self.is_first_run(): - self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5)) - self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5)) - self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5)) + weights_to_load = [ [self.encoder , 'encoder.h5'], + [self.decoder_src, 'decoder_src.h5'], + [self.decoder_dst, 'decoder_dst.h5'] + ] + self.load_weights_safe(weights_to_load) self.autoencoder_src = Model([ae_input_layer,mask_layer], self.decoder_src(self.encoder(ae_input_layer))) self.autoencoder_dst = Model([ae_input_layer,mask_layer], self.decoder_dst(self.encoder(ae_input_layer))) @@ -59,9 +57,9 @@ class Model(ModelBase): ]) #override def onSave(self): - self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)], - [self.decoder_src, self.get_strpath_storage_for_file(self.decoder_srcH5)], - [self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)]] ) + self.save_weights_safe( [[self.encoder, 'encoder.h5'], + [self.decoder_src, 'decoder_src.h5'], + [self.decoder_dst, 'decoder_dst.h5']] ) #override def onTrainOneEpoch(self, sample, generators_list): diff --git a/models/Model_H128/Model.py b/models/Model_H128/Model.py index 28e0016..4e7ec78 100644 --- a/models/Model_H128/Model.py +++ b/models/Model_H128/Model.py @@ -8,10 +8,6 @@ from utils.console_utils import * class Model(ModelBase): - encoderH5 = 'encoder.h5' - decoder_srcH5 = 'decoder_src.h5' - decoder_dstH5 = 'decoder_dst.h5' - #override def onInitializeOptions(self, is_first_run, ask_override): if is_first_run: @@ -35,9 +31,11 @@ class Model(ModelBase): bgr_shape, mask_shape, self.encoder, self.decoder_src, self.decoder_dst = self.Build( self.options['lighter_ae'] ) if not self.is_first_run(): - self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5)) - self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5)) - self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5)) + weights_to_load = [ [self.encoder , 'encoder.h5'], + [self.decoder_src, 'decoder_src.h5'], + [self.decoder_dst, 'decoder_dst.h5'] + ] + self.load_weights_safe(weights_to_load) input_src_bgr = Input(bgr_shape) input_src_mask = Input(mask_shape) @@ -74,9 +72,9 @@ class Model(ModelBase): #override def onSave(self): - self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)], - [self.decoder_src, self.get_strpath_storage_for_file(self.decoder_srcH5)], - [self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)]]) + self.save_weights_safe( [[self.encoder, 'encoder.h5'], + [self.decoder_src, 'decoder_src.h5'], + [self.decoder_dst, 'decoder_dst.h5']] ) #override def onTrainOneEpoch(self, sample, generators_list): diff --git a/models/Model_H64/Model.py b/models/Model_H64/Model.py index 39c7d25..2d04526 100644 --- a/models/Model_H64/Model.py +++ b/models/Model_H64/Model.py @@ -8,10 +8,6 @@ from utils.console_utils import * class Model(ModelBase): - encoderH5 = 'encoder.h5' - decoder_srcH5 = 'decoder_src.h5' - decoder_dstH5 = 'decoder_dst.h5' - #override def onInitializeOptions(self, is_first_run, ask_override): if is_first_run: @@ -37,9 +33,11 @@ class Model(ModelBase): bgr_shape, mask_shape, self.encoder, self.decoder_src, self.decoder_dst = self.Build(self.options['lighter_ae']) if not self.is_first_run(): - self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5)) - self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5)) - self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5)) + weights_to_load = [ [self.encoder , 'encoder.h5'], + [self.decoder_src, 'decoder_src.h5'], + [self.decoder_dst, 'decoder_dst.h5'] + ] + self.load_weights_safe(weights_to_load) input_src_bgr = Input(bgr_shape) input_src_mask = Input(mask_shape) @@ -75,9 +73,9 @@ class Model(ModelBase): #override def onSave(self): - self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)], - [self.decoder_src, self.get_strpath_storage_for_file(self.decoder_srcH5)], - [self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)]] ) + self.save_weights_safe( [[self.encoder, 'encoder.h5'], + [self.decoder_src, 'decoder_src.h5'], + [self.decoder_dst, 'decoder_dst.h5']] ) #override def onTrainOneEpoch(self, sample, generators_list): diff --git a/models/Model_LIAEF128/Model.py b/models/Model_LIAEF128/Model.py index e5674ea..b85c6ad 100644 --- a/models/Model_LIAEF128/Model.py +++ b/models/Model_LIAEF128/Model.py @@ -8,11 +8,6 @@ from utils.console_utils import * class Model(ModelBase): - encoderH5 = 'encoder.h5' - decoderH5 = 'decoder.h5' - inter_BH5 = 'inter_B.h5' - inter_ABH5 = 'inter_AB.h5' - #override def onInitializeOptions(self, is_first_run, ask_override): if is_first_run or ask_override: @@ -32,10 +27,12 @@ class Model(ModelBase): self.encoder, self.decoder, self.inter_B, self.inter_AB = self.Build(ae_input_layer) if not self.is_first_run(): - self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5)) - self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5)) - self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5)) - self.inter_AB.load_weights (self.get_strpath_storage_for_file(self.inter_ABH5)) + weights_to_load = [ [self.encoder, 'encoder.h5'], + [self.decoder, 'decoder.h5'], + [self.inter_B, 'inter_B.h5'], + [self.inter_AB, 'inter_AB.h5'] + ] + self.load_weights_safe(weights_to_load) code = self.encoder(ae_input_layer) AB = self.inter_AB(code) @@ -66,11 +63,11 @@ class Model(ModelBase): ]) #override - def onSave(self): - self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)], - [self.decoder, self.get_strpath_storage_for_file(self.decoderH5)], - [self.inter_B, self.get_strpath_storage_for_file(self.inter_BH5)], - [self.inter_AB, self.get_strpath_storage_for_file(self.inter_ABH5)]] ) + def onSave(self): + self.save_weights_safe( [[self.encoder, 'encoder.h5'], + [self.decoder, 'decoder.h5'], + [self.inter_B, 'inter_B.h5'], + [self.inter_AB, 'inter_AB.h5']] ) #override def onTrainOneEpoch(self, sample, generators_list): diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 71eb335..88a4f56 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -349,21 +349,21 @@ class SAEModel(ModelBase): #override def onSave(self): if self.options['archi'] == 'liae': - ar = [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)], - [self.inter_B, self.get_strpath_storage_for_file(self.inter_BH5)], - [self.inter_AB, self.get_strpath_storage_for_file(self.inter_ABH5)], - [self.decoder, self.get_strpath_storage_for_file(self.decoderH5)] + ar = [[self.encoder, 'encoder.h5'], + [self.inter_B, 'inter_B.h5'], + [self.inter_AB, 'inter_AB.h5'], + [self.decoder, 'decoder.h5'] ] if self.options['learn_mask']: - ar += [ [self.decoderm, self.get_strpath_storage_for_file(self.decodermH5)] ] - else: - ar = [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)], - [self.decoder_src, self.get_strpath_storage_for_file(self.decoder_srcH5)], - [self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)] + ar += [ [self.decoderm, 'decoderm.h5'] ] + elif self.options['archi'] == 'df' or self.options['archi'] == 'vg': + ar = [[self.encoder, 'encoder.h5'], + [self.decoder_src, 'decoder_src.h5'], + [self.decoder_dst, 'decoder_dst.h5'] ] if self.options['learn_mask']: - ar += [ [self.decoder_srcm, self.get_strpath_storage_for_file(self.decoder_srcmH5)], - [self.decoder_dstm, self.get_strpath_storage_for_file(self.decoder_dstmH5)] ] + ar += [ [self.decoder_srcm, 'decoder_srcm.h5'], + [self.decoder_dstm, 'decoder_dstm.h5'] ] self.save_weights_safe(ar)