From 201b762541dbfe5130a9441bf302ab6b58acaf50 Mon Sep 17 00:00:00 2001 From: Colombo Date: Sat, 14 Sep 2019 16:39:47 +0400 Subject: [PATCH] SAE: fix --- models/Model_SAE/Model.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 0b20b92..ccd7e58 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -114,7 +114,6 @@ class SAEModel(ModelBase): if not self.pretrain: self.options.pop('pretrain') - d_residual_blocks = True bgr_shape = (resolution, resolution, 3) mask_shape = (resolution, resolution, 1) @@ -151,7 +150,7 @@ class SAEModel(ModelBase): return x return func - def dec_flow(output_nc, d_ch_dims): + def dec_flow(output_nc, d_ch_dims, add_residual_blocks=True): def ResidualBlock(dim): def func(inp): x = Conv2D(dim, kernel_size=3, padding='same')(inp) @@ -165,16 +164,22 @@ class SAEModel(ModelBase): def func(x): dims = output_nc * d_ch_dims x = upscale(dims*8)(x) - x = ResidualBlock(dims*8)(x) - x = ResidualBlock(dims*8)(x) + + if add_residual_blocks: + x = ResidualBlock(dims*8)(x) + x = ResidualBlock(dims*8)(x) x = upscale(dims*4)(x) - x = ResidualBlock(dims*4)(x) - x = ResidualBlock(dims*4)(x) + + if add_residual_blocks: + x = ResidualBlock(dims*4)(x) + x = ResidualBlock(dims*4)(x) x = upscale(dims*2)(x) - x = ResidualBlock(dims*2)(x) - x = ResidualBlock(dims*2)(x) + + if add_residual_blocks: + x = ResidualBlock(dims*2)(x) + x = ResidualBlock(dims*2)(x) return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid')(x) return func @@ -186,8 +191,8 @@ class SAEModel(ModelBase): self.decoder_dst = modelify(dec_flow(output_nc, d_ch_dims)) ( Input(sh) ) if learn_mask: - self.decoder_srcm = modelify(dec_flow(1, d_ch_dims)) ( Input(sh) ) - self.decoder_dstm = modelify(dec_flow(1, d_ch_dims)) ( Input(sh) ) + self.decoder_srcm = modelify(dec_flow(1, d_ch_dims, add_residual_blocks=False)) ( Input(sh) ) + self.decoder_dstm = modelify(dec_flow(1, d_ch_dims, add_residual_blocks=False)) ( Input(sh) ) self.src_dst_trainable_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights + self.decoder_dst.trainable_weights @@ -402,7 +407,6 @@ class SAEModel(ModelBase): if self.is_training_mode: self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1) self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1) - self.sr_opt = Adam(lr=5e-5, beta_1=0.9, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1) if not self.options['pixel_loss']: src_loss = K.mean ( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_opt, pred_src_src_masked_opt) )