fix AVATARModel

This commit is contained in:
iperov 2019-08-24 20:33:29 +04:00
parent 5968ac21f6
commit e3505d9b8c

View file

@ -34,13 +34,10 @@ class AVATARModel(ModelBase):
self.options['avatar_type'] = self.options.get('avatar_type', 'source') self.options['avatar_type'] = self.options.get('avatar_type', 'source')
if is_first_run or ask_override: if is_first_run or ask_override:
def_stage = self.options.get('stage', 0) def_stage = self.options.get('stage', 1)
self.options['stage'] = io.input_int("Stage (0, 1, 2 ?:help skip:%d) : " % def_stage, def_stage, [0,1,2], help_message="Train first stage, then second. Tune batch size to maximum possible for both stages.") self.options['stage'] = io.input_int("Stage (0, 1, 2 ?:help skip:%d) : " % def_stage, def_stage, [0,1,2], help_message="Train first stage, then second. Tune batch size to maximum possible for both stages.")
else: else:
self.options['stage'] = self.options.get('stage', 0) self.options['stage'] = self.options.get('stage', 1)
#override #override
def onInitialize(self, batch_size=-1, **in_options): def onInitialize(self, batch_size=-1, **in_options):
@ -67,7 +64,7 @@ class AVATARModel(ModelBase):
if self.is_first_run(): if self.is_first_run():
conv_weights_list = [] conv_weights_list = []
for model in self.get_model_filename_list(): for model, _ in self.get_model_filename_list():
for layer in model.layers: for layer in model.layers:
if type(layer) == keras.layers.Conv2D: if type(layer) == keras.layers.Conv2D:
conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights