From aa58d9e563aea7367a4df0168ed3655eba6149bc Mon Sep 17 00:00:00 2001 From: iperov Date: Mon, 25 Mar 2019 10:05:30 +0400 Subject: [PATCH] SAE: removed lightweight encoder --- models/Model_SAE/Model.py | 56 +++++++-------------------------------- 1 file changed, 10 insertions(+), 46 deletions(-) diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index e7b704a..0005760 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -73,12 +73,8 @@ class SAEModel(ModelBase): self.options['remove_gray_border'] = self.options.get('remove_gray_border', False) if is_first_run: - self.options['lighter_encoder'] = io.input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, but sacrificing overall quality.") - self.options['multiscale_decoder'] = io.input_bool ("Use multiscale decoder? (y/n, ?:help skip:n) : ", False, help_message="Multiscale decoder helps to get better details.") else: - self.options['lighter_encoder'] = self.options.get('lighter_encoder', False) - self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', False) default_face_style_power = 0.0 @@ -133,10 +129,9 @@ class SAEModel(ModelBase): padding = 'reflect' if self.options['remove_gray_border'] else 'zero' common_flow_kwargs = { 'padding': padding } - models_list = [] weights_to_load = [] if self.options['archi'] == 'liae': - self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, self.options['lighter_encoder'], ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape)) + self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape)) enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ] @@ -147,11 +142,8 @@ class SAEModel(ModelBase): self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs)) (inter_output_Inputs) - models_list += [self.encoder, self.inter_B, self.inter_AB, self.decoder] - if self.options['learn_mask']: self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs)) (inter_output_Inputs) - models_list += [self.decoderm] if not self.is_first_run(): weights_to_load += [ [self.encoder , 'encoder.h5'], @@ -183,19 +175,16 @@ class SAEModel(ModelBase): pred_src_dstm = self.decoderm(warped_src_dst_inter_code) elif self.options['archi'] == 'df': - self.encoder = modelify(SAEModel.DFEncFlow(resolution, self.options['lighter_encoder'], ae_dims=ae_dims, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape)) + self.encoder = modelify(SAEModel.DFEncFlow(resolution, ae_dims=ae_dims, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape)) dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ] self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs) self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs) - models_list += [self.encoder, self.decoder_src, self.decoder_dst] - if self.options['learn_mask']: self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs) self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs) - models_list += [self.decoder_srcm, self.decoder_dstm] if not self.is_first_run(): weights_to_load += [ [self.encoder , 'encoder.h5'], @@ -491,19 +480,6 @@ class SAEModel(ModelBase): return func SAEModel.downscale_pre = downscale_pre - def downscale_sep (dim, padding='zero'): - def func(x): - return LeakyReLU(0.1)(SeparableConv2D(dim, kernel_size=5, strides=2, padding=padding)(x)) - return func - SAEModel.downscale_sep = downscale_sep - - def downscale_sep_pre (**base_kwargs): - def func(*args, **kwargs): - kwargs.update(base_kwargs) - return downscale_sep(*args, **kwargs) - return func - SAEModel.downscale_sep_pre = downscale_sep_pre - def upscale (dim, padding='zero'): def func(x): return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding=padding)(x))) @@ -531,25 +507,19 @@ class SAEModel(ModelBase): SAEModel.to_bgr_pre = to_bgr_pre @staticmethod - def LIAEEncFlow(resolution, light_enc, ch_dims, padding='zero', **kwargs): + def LIAEEncFlow(resolution, ch_dims, padding='zero', **kwargs): exec (nnlib.import_all(), locals(), globals()) upscale = SAEModel.upscale_pre(padding=padding) downscale = SAEModel.downscale_pre(padding=padding) - downscale_sep = SAEModel.downscale_sep_pre(padding=padding) def func(input): dims = K.int_shape(input)[-1]*ch_dims x = input x = downscale(dims)(x) - if not light_enc: - x = downscale(dims*2)(x) - x = downscale(dims*4)(x) - x = downscale(dims*8)(x) - else: - x = downscale_sep(dims*2)(x) - x = downscale(dims*4)(x) - x = downscale_sep(dims*8)(x) + x = downscale(dims*2)(x) + x = downscale(dims*4)(x) + x = downscale(dims*8)(x) x = Flatten()(x) return x @@ -612,11 +582,10 @@ class SAEModel(ModelBase): return func @staticmethod - def DFEncFlow(resolution, light_enc, ae_dims, ch_dims, padding='zero', **kwargs): + def DFEncFlow(resolution, ae_dims, ch_dims, padding='zero', **kwargs): exec (nnlib.import_all(), locals(), globals()) upscale = SAEModel.upscale_pre(padding=padding) downscale = SAEModel.downscale_pre(padding=padding) - downscale_sep = SAEModel.downscale_sep_pre(padding=padding) lowest_dense_res = resolution // 16 def func(input): @@ -625,14 +594,9 @@ class SAEModel(ModelBase): dims = K.int_shape(input)[-1]*ch_dims x = downscale(dims)(x) - if not light_enc: - x = downscale(dims*2)(x) - x = downscale(dims*4)(x) - x = downscale(dims*8)(x) - else: - x = downscale_sep(dims*2)(x) - x = downscale(dims*4)(x) - x = downscale_sep(dims*8)(x) + x = downscale(dims*2)(x) + x = downscale(dims*4)(x) + x = downscale(dims*8)(x) x = Dense(ae_dims)(Flatten()(x)) x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)