This commit is contained in:
Colombo 2020-01-21 21:05:29 +04:00
parent 56de1d9fa5
commit 9797a70fd3
5 changed files with 112 additions and 47 deletions

View file

@ -278,8 +278,7 @@ class SAEHDModel(ModelBase):
z = inp
if self.is_hd:
x, upx = self.res0(z)
x, upx = self.res0(z)
x = self.upscale0(x)
x = tf.nn.leaky_relu(x + upx, 0.2)
x, upx = self.res1(x)
@ -410,8 +409,8 @@ class SAEHDModel(ModelBase):
self.src_dst_opt = nn.TFRMSpropOptimizer(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
if 'df' in archi:
self.src_dst_all_trainable_weights = self.encoder.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
self.src_dst_trainable_weights = self.encoder.get_weights() + self.decoder_src.get_weights_ex(learn_mask) + self.decoder_dst.get_weights_ex(learn_mask)
self.src_dst_all_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights_ex(learn_mask) + self.decoder_dst.get_weights_ex(learn_mask)
elif 'liae' in archi:
self.src_dst_all_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()