refactoring

This commit is contained in:
iperov 2019-02-21 20:06:04 +04:00
parent 97685ce0ae
commit a858732b1d
6 changed files with 48 additions and 55 deletions

View file

@ -349,21 +349,21 @@ class SAEModel(ModelBase):
#override
def onSave(self):
if self.options['archi'] == 'liae':
ar = [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
[self.inter_B, self.get_strpath_storage_for_file(self.inter_BH5)],
[self.inter_AB, self.get_strpath_storage_for_file(self.inter_ABH5)],
[self.decoder, self.get_strpath_storage_for_file(self.decoderH5)]
ar = [[self.encoder, 'encoder.h5'],
[self.inter_B, 'inter_B.h5'],
[self.inter_AB, 'inter_AB.h5'],
[self.decoder, 'decoder.h5']
]
if self.options['learn_mask']:
ar += [ [self.decoderm, self.get_strpath_storage_for_file(self.decodermH5)] ]
else:
ar = [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
[self.decoder_src, self.get_strpath_storage_for_file(self.decoder_srcH5)],
[self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)]
ar += [ [self.decoderm, 'decoderm.h5'] ]
elif self.options['archi'] == 'df' or self.options['archi'] == 'vg':
ar = [[self.encoder, 'encoder.h5'],
[self.decoder_src, 'decoder_src.h5'],
[self.decoder_dst, 'decoder_dst.h5']
]
if self.options['learn_mask']:
ar += [ [self.decoder_srcm, self.get_strpath_storage_for_file(self.decoder_srcmH5)],
[self.decoder_dstm, self.get_strpath_storage_for_file(self.decoder_dstmH5)] ]
ar += [ [self.decoder_srcm, 'decoder_srcm.h5'],
[self.decoder_dstm, 'decoder_dstm.h5'] ]
self.save_weights_safe(ar)