refactoring

This commit is contained in:
iperov 2019-02-21 20:06:04 +04:00
parent 97685ce0ae
commit a858732b1d
6 changed files with 48 additions and 55 deletions

View file

@ -8,10 +8,6 @@ from utils.console_utils import *
class Model(ModelBase):
encoderH5 = 'encoder.h5'
decoder_srcH5 = 'decoder_src.h5'
decoder_dstH5 = 'decoder_dst.h5'
#override
def onInitializeOptions(self, is_first_run, ask_override):
if is_first_run or ask_override:
@ -31,9 +27,11 @@ class Model(ModelBase):
self.encoder, self.decoder_src, self.decoder_dst = self.Build(ae_input_layer)
if not self.is_first_run():
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
weights_to_load = [ [self.encoder , 'encoder.h5'],
[self.decoder_src, 'decoder_src.h5'],
[self.decoder_dst, 'decoder_dst.h5']
]
self.load_weights_safe(weights_to_load)
self.autoencoder_src = Model([ae_input_layer,mask_layer], self.decoder_src(self.encoder(ae_input_layer)))
self.autoencoder_dst = Model([ae_input_layer,mask_layer], self.decoder_dst(self.encoder(ae_input_layer)))
@ -59,9 +57,9 @@ class Model(ModelBase):
])
#override
def onSave(self):
self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
[self.decoder_src, self.get_strpath_storage_for_file(self.decoder_srcH5)],
[self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)]] )
self.save_weights_safe( [[self.encoder, 'encoder.h5'],
[self.decoder_src, 'decoder_src.h5'],
[self.decoder_dst, 'decoder_dst.h5']] )
#override
def onTrainOneEpoch(self, sample, generators_list):