refactoring

This commit is contained in:
iperov 2019-02-21 20:06:04 +04:00
parent 97685ce0ae
commit a858732b1d
6 changed files with 48 additions and 55 deletions

View file

@ -8,11 +8,6 @@ from utils.console_utils import *
class Model(ModelBase):
encoderH5 = 'encoder.h5'
decoderH5 = 'decoder.h5'
inter_BH5 = 'inter_B.h5'
inter_ABH5 = 'inter_AB.h5'
#override
def onInitializeOptions(self, is_first_run, ask_override):
if is_first_run or ask_override:
@ -32,10 +27,12 @@ class Model(ModelBase):
self.encoder, self.decoder, self.inter_B, self.inter_AB = self.Build(ae_input_layer)
if not self.is_first_run():
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5))
self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5))
self.inter_AB.load_weights (self.get_strpath_storage_for_file(self.inter_ABH5))
weights_to_load = [ [self.encoder, 'encoder.h5'],
[self.decoder, 'decoder.h5'],
[self.inter_B, 'inter_B.h5'],
[self.inter_AB, 'inter_AB.h5']
]
self.load_weights_safe(weights_to_load)
code = self.encoder(ae_input_layer)
AB = self.inter_AB(code)
@ -66,11 +63,11 @@ class Model(ModelBase):
])
#override
def onSave(self):
self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
[self.decoder, self.get_strpath_storage_for_file(self.decoderH5)],
[self.inter_B, self.get_strpath_storage_for_file(self.inter_BH5)],
[self.inter_AB, self.get_strpath_storage_for_file(self.inter_ABH5)]] )
def onSave(self):
self.save_weights_safe( [[self.encoder, 'encoder.h5'],
[self.decoder, 'decoder.h5'],
[self.inter_B, 'inter_B.h5'],
[self.inter_AB, 'inter_AB.h5']] )
#override
def onTrainOneEpoch(self, sample, generators_list):