This commit is contained in:
Colombo 2019-09-14 16:39:47 +04:00
parent b6b92bded0
commit 201b762541

View file

@ -114,7 +114,6 @@ class SAEModel(ModelBase):
if not self.pretrain: if not self.pretrain:
self.options.pop('pretrain') self.options.pop('pretrain')
d_residual_blocks = True
bgr_shape = (resolution, resolution, 3) bgr_shape = (resolution, resolution, 3)
mask_shape = (resolution, resolution, 1) mask_shape = (resolution, resolution, 1)
@ -151,7 +150,7 @@ class SAEModel(ModelBase):
return x return x
return func return func
def dec_flow(output_nc, d_ch_dims): def dec_flow(output_nc, d_ch_dims, add_residual_blocks=True):
def ResidualBlock(dim): def ResidualBlock(dim):
def func(inp): def func(inp):
x = Conv2D(dim, kernel_size=3, padding='same')(inp) x = Conv2D(dim, kernel_size=3, padding='same')(inp)
@ -165,14 +164,20 @@ class SAEModel(ModelBase):
def func(x): def func(x):
dims = output_nc * d_ch_dims dims = output_nc * d_ch_dims
x = upscale(dims*8)(x) x = upscale(dims*8)(x)
if add_residual_blocks:
x = ResidualBlock(dims*8)(x) x = ResidualBlock(dims*8)(x)
x = ResidualBlock(dims*8)(x) x = ResidualBlock(dims*8)(x)
x = upscale(dims*4)(x) x = upscale(dims*4)(x)
if add_residual_blocks:
x = ResidualBlock(dims*4)(x) x = ResidualBlock(dims*4)(x)
x = ResidualBlock(dims*4)(x) x = ResidualBlock(dims*4)(x)
x = upscale(dims*2)(x) x = upscale(dims*2)(x)
if add_residual_blocks:
x = ResidualBlock(dims*2)(x) x = ResidualBlock(dims*2)(x)
x = ResidualBlock(dims*2)(x) x = ResidualBlock(dims*2)(x)
@ -186,8 +191,8 @@ class SAEModel(ModelBase):
self.decoder_dst = modelify(dec_flow(output_nc, d_ch_dims)) ( Input(sh) ) self.decoder_dst = modelify(dec_flow(output_nc, d_ch_dims)) ( Input(sh) )
if learn_mask: if learn_mask:
self.decoder_srcm = modelify(dec_flow(1, d_ch_dims)) ( Input(sh) ) self.decoder_srcm = modelify(dec_flow(1, d_ch_dims, add_residual_blocks=False)) ( Input(sh) )
self.decoder_dstm = modelify(dec_flow(1, d_ch_dims)) ( Input(sh) ) self.decoder_dstm = modelify(dec_flow(1, d_ch_dims, add_residual_blocks=False)) ( Input(sh) )
self.src_dst_trainable_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights + self.decoder_dst.trainable_weights self.src_dst_trainable_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights + self.decoder_dst.trainable_weights
@ -402,7 +407,6 @@ class SAEModel(ModelBase):
if self.is_training_mode: if self.is_training_mode:
self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1) self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1) self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
self.sr_opt = Adam(lr=5e-5, beta_1=0.9, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
if not self.options['pixel_loss']: if not self.options['pixel_loss']:
src_loss = K.mean ( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_opt, pred_src_src_masked_opt) ) src_loss = K.mean ( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_opt, pred_src_src_masked_opt) )