SAE: removed lightweight encoder

This commit is contained in:
iperov 2019-03-25 10:05:30 +04:00
parent addd87efa8
commit aa58d9e563

View file

@ -73,12 +73,8 @@ class SAEModel(ModelBase):
self.options['remove_gray_border'] = self.options.get('remove_gray_border', False) self.options['remove_gray_border'] = self.options.get('remove_gray_border', False)
if is_first_run: if is_first_run:
self.options['lighter_encoder'] = io.input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, but sacrificing overall quality.")
self.options['multiscale_decoder'] = io.input_bool ("Use multiscale decoder? (y/n, ?:help skip:n) : ", False, help_message="Multiscale decoder helps to get better details.") self.options['multiscale_decoder'] = io.input_bool ("Use multiscale decoder? (y/n, ?:help skip:n) : ", False, help_message="Multiscale decoder helps to get better details.")
else: else:
self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)
self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', False) self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', False)
default_face_style_power = 0.0 default_face_style_power = 0.0
@ -133,10 +129,9 @@ class SAEModel(ModelBase):
padding = 'reflect' if self.options['remove_gray_border'] else 'zero' padding = 'reflect' if self.options['remove_gray_border'] else 'zero'
common_flow_kwargs = { 'padding': padding } common_flow_kwargs = { 'padding': padding }
models_list = []
weights_to_load = [] weights_to_load = []
if self.options['archi'] == 'liae': if self.options['archi'] == 'liae':
self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, self.options['lighter_encoder'], ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape)) self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ] enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
@ -147,11 +142,8 @@ class SAEModel(ModelBase):
self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs)) (inter_output_Inputs) self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs)) (inter_output_Inputs)
models_list += [self.encoder, self.inter_B, self.inter_AB, self.decoder]
if self.options['learn_mask']: if self.options['learn_mask']:
self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs)) (inter_output_Inputs) self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs)) (inter_output_Inputs)
models_list += [self.decoderm]
if not self.is_first_run(): if not self.is_first_run():
weights_to_load += [ [self.encoder , 'encoder.h5'], weights_to_load += [ [self.encoder , 'encoder.h5'],
@ -183,19 +175,16 @@ class SAEModel(ModelBase):
pred_src_dstm = self.decoderm(warped_src_dst_inter_code) pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
elif self.options['archi'] == 'df': elif self.options['archi'] == 'df':
self.encoder = modelify(SAEModel.DFEncFlow(resolution, self.options['lighter_encoder'], ae_dims=ae_dims, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape)) self.encoder = modelify(SAEModel.DFEncFlow(resolution, ae_dims=ae_dims, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ] dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs) self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs) self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
models_list += [self.encoder, self.decoder_src, self.decoder_dst]
if self.options['learn_mask']: if self.options['learn_mask']:
self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs) self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs) self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
models_list += [self.decoder_srcm, self.decoder_dstm]
if not self.is_first_run(): if not self.is_first_run():
weights_to_load += [ [self.encoder , 'encoder.h5'], weights_to_load += [ [self.encoder , 'encoder.h5'],
@ -491,19 +480,6 @@ class SAEModel(ModelBase):
return func return func
SAEModel.downscale_pre = downscale_pre SAEModel.downscale_pre = downscale_pre
def downscale_sep (dim, padding='zero'):
def func(x):
return LeakyReLU(0.1)(SeparableConv2D(dim, kernel_size=5, strides=2, padding=padding)(x))
return func
SAEModel.downscale_sep = downscale_sep
def downscale_sep_pre (**base_kwargs):
def func(*args, **kwargs):
kwargs.update(base_kwargs)
return downscale_sep(*args, **kwargs)
return func
SAEModel.downscale_sep_pre = downscale_sep_pre
def upscale (dim, padding='zero'): def upscale (dim, padding='zero'):
def func(x): def func(x):
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding=padding)(x))) return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding=padding)(x)))
@ -531,25 +507,19 @@ class SAEModel(ModelBase):
SAEModel.to_bgr_pre = to_bgr_pre SAEModel.to_bgr_pre = to_bgr_pre
@staticmethod @staticmethod
def LIAEEncFlow(resolution, light_enc, ch_dims, padding='zero', **kwargs): def LIAEEncFlow(resolution, ch_dims, padding='zero', **kwargs):
exec (nnlib.import_all(), locals(), globals()) exec (nnlib.import_all(), locals(), globals())
upscale = SAEModel.upscale_pre(padding=padding) upscale = SAEModel.upscale_pre(padding=padding)
downscale = SAEModel.downscale_pre(padding=padding) downscale = SAEModel.downscale_pre(padding=padding)
downscale_sep = SAEModel.downscale_sep_pre(padding=padding)
def func(input): def func(input):
dims = K.int_shape(input)[-1]*ch_dims dims = K.int_shape(input)[-1]*ch_dims
x = input x = input
x = downscale(dims)(x) x = downscale(dims)(x)
if not light_enc:
x = downscale(dims*2)(x) x = downscale(dims*2)(x)
x = downscale(dims*4)(x) x = downscale(dims*4)(x)
x = downscale(dims*8)(x) x = downscale(dims*8)(x)
else:
x = downscale_sep(dims*2)(x)
x = downscale(dims*4)(x)
x = downscale_sep(dims*8)(x)
x = Flatten()(x) x = Flatten()(x)
return x return x
@ -612,11 +582,10 @@ class SAEModel(ModelBase):
return func return func
@staticmethod @staticmethod
def DFEncFlow(resolution, light_enc, ae_dims, ch_dims, padding='zero', **kwargs): def DFEncFlow(resolution, ae_dims, ch_dims, padding='zero', **kwargs):
exec (nnlib.import_all(), locals(), globals()) exec (nnlib.import_all(), locals(), globals())
upscale = SAEModel.upscale_pre(padding=padding) upscale = SAEModel.upscale_pre(padding=padding)
downscale = SAEModel.downscale_pre(padding=padding) downscale = SAEModel.downscale_pre(padding=padding)
downscale_sep = SAEModel.downscale_sep_pre(padding=padding)
lowest_dense_res = resolution // 16 lowest_dense_res = resolution // 16
def func(input): def func(input):
@ -625,14 +594,9 @@ class SAEModel(ModelBase):
dims = K.int_shape(input)[-1]*ch_dims dims = K.int_shape(input)[-1]*ch_dims
x = downscale(dims)(x) x = downscale(dims)(x)
if not light_enc:
x = downscale(dims*2)(x) x = downscale(dims*2)(x)
x = downscale(dims*4)(x) x = downscale(dims*4)(x)
x = downscale(dims*8)(x) x = downscale(dims*8)(x)
else:
x = downscale_sep(dims*2)(x)
x = downscale(dims*4)(x)
x = downscale_sep(dims*8)(x)
x = Dense(ae_dims)(Flatten()(x)) x = Dense(ae_dims)(Flatten()(x))
x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x) x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)