mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 05:22:06 -07:00
SAE: removed lightweight encoder
This commit is contained in:
parent
addd87efa8
commit
aa58d9e563
1 changed files with 10 additions and 46 deletions
|
@ -73,12 +73,8 @@ class SAEModel(ModelBase):
|
|||
self.options['remove_gray_border'] = self.options.get('remove_gray_border', False)
|
||||
|
||||
if is_first_run:
|
||||
self.options['lighter_encoder'] = io.input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, but sacrificing overall quality.")
|
||||
|
||||
self.options['multiscale_decoder'] = io.input_bool ("Use multiscale decoder? (y/n, ?:help skip:n) : ", False, help_message="Multiscale decoder helps to get better details.")
|
||||
else:
|
||||
self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)
|
||||
|
||||
self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', False)
|
||||
|
||||
default_face_style_power = 0.0
|
||||
|
@ -133,10 +129,9 @@ class SAEModel(ModelBase):
|
|||
padding = 'reflect' if self.options['remove_gray_border'] else 'zero'
|
||||
common_flow_kwargs = { 'padding': padding }
|
||||
|
||||
models_list = []
|
||||
weights_to_load = []
|
||||
if self.options['archi'] == 'liae':
|
||||
self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, self.options['lighter_encoder'], ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
|
||||
self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
|
||||
|
||||
enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
||||
|
||||
|
@ -147,11 +142,8 @@ class SAEModel(ModelBase):
|
|||
|
||||
self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs)) (inter_output_Inputs)
|
||||
|
||||
models_list += [self.encoder, self.inter_B, self.inter_AB, self.decoder]
|
||||
|
||||
if self.options['learn_mask']:
|
||||
self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs)) (inter_output_Inputs)
|
||||
models_list += [self.decoderm]
|
||||
|
||||
if not self.is_first_run():
|
||||
weights_to_load += [ [self.encoder , 'encoder.h5'],
|
||||
|
@ -183,19 +175,16 @@ class SAEModel(ModelBase):
|
|||
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
|
||||
|
||||
elif self.options['archi'] == 'df':
|
||||
self.encoder = modelify(SAEModel.DFEncFlow(resolution, self.options['lighter_encoder'], ae_dims=ae_dims, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
|
||||
self.encoder = modelify(SAEModel.DFEncFlow(resolution, ae_dims=ae_dims, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
|
||||
|
||||
dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
||||
|
||||
self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
|
||||
self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
|
||||
|
||||
models_list += [self.encoder, self.decoder_src, self.decoder_dst]
|
||||
|
||||
if self.options['learn_mask']:
|
||||
self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
|
||||
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
|
||||
models_list += [self.decoder_srcm, self.decoder_dstm]
|
||||
|
||||
if not self.is_first_run():
|
||||
weights_to_load += [ [self.encoder , 'encoder.h5'],
|
||||
|
@ -491,19 +480,6 @@ class SAEModel(ModelBase):
|
|||
return func
|
||||
SAEModel.downscale_pre = downscale_pre
|
||||
|
||||
def downscale_sep (dim, padding='zero'):
|
||||
def func(x):
|
||||
return LeakyReLU(0.1)(SeparableConv2D(dim, kernel_size=5, strides=2, padding=padding)(x))
|
||||
return func
|
||||
SAEModel.downscale_sep = downscale_sep
|
||||
|
||||
def downscale_sep_pre (**base_kwargs):
|
||||
def func(*args, **kwargs):
|
||||
kwargs.update(base_kwargs)
|
||||
return downscale_sep(*args, **kwargs)
|
||||
return func
|
||||
SAEModel.downscale_sep_pre = downscale_sep_pre
|
||||
|
||||
def upscale (dim, padding='zero'):
|
||||
def func(x):
|
||||
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding=padding)(x)))
|
||||
|
@ -531,25 +507,19 @@ class SAEModel(ModelBase):
|
|||
SAEModel.to_bgr_pre = to_bgr_pre
|
||||
|
||||
@staticmethod
|
||||
def LIAEEncFlow(resolution, light_enc, ch_dims, padding='zero', **kwargs):
|
||||
def LIAEEncFlow(resolution, ch_dims, padding='zero', **kwargs):
|
||||
exec (nnlib.import_all(), locals(), globals())
|
||||
upscale = SAEModel.upscale_pre(padding=padding)
|
||||
downscale = SAEModel.downscale_pre(padding=padding)
|
||||
downscale_sep = SAEModel.downscale_sep_pre(padding=padding)
|
||||
|
||||
def func(input):
|
||||
dims = K.int_shape(input)[-1]*ch_dims
|
||||
|
||||
x = input
|
||||
x = downscale(dims)(x)
|
||||
if not light_enc:
|
||||
x = downscale(dims*2)(x)
|
||||
x = downscale(dims*4)(x)
|
||||
x = downscale(dims*8)(x)
|
||||
else:
|
||||
x = downscale_sep(dims*2)(x)
|
||||
x = downscale(dims*4)(x)
|
||||
x = downscale_sep(dims*8)(x)
|
||||
|
||||
x = Flatten()(x)
|
||||
return x
|
||||
|
@ -612,11 +582,10 @@ class SAEModel(ModelBase):
|
|||
return func
|
||||
|
||||
@staticmethod
|
||||
def DFEncFlow(resolution, light_enc, ae_dims, ch_dims, padding='zero', **kwargs):
|
||||
def DFEncFlow(resolution, ae_dims, ch_dims, padding='zero', **kwargs):
|
||||
exec (nnlib.import_all(), locals(), globals())
|
||||
upscale = SAEModel.upscale_pre(padding=padding)
|
||||
downscale = SAEModel.downscale_pre(padding=padding)
|
||||
downscale_sep = SAEModel.downscale_sep_pre(padding=padding)
|
||||
lowest_dense_res = resolution // 16
|
||||
|
||||
def func(input):
|
||||
|
@ -625,14 +594,9 @@ class SAEModel(ModelBase):
|
|||
dims = K.int_shape(input)[-1]*ch_dims
|
||||
|
||||
x = downscale(dims)(x)
|
||||
if not light_enc:
|
||||
x = downscale(dims*2)(x)
|
||||
x = downscale(dims*4)(x)
|
||||
x = downscale(dims*8)(x)
|
||||
else:
|
||||
x = downscale_sep(dims*2)(x)
|
||||
x = downscale(dims*4)(x)
|
||||
x = downscale_sep(dims*8)(x)
|
||||
|
||||
x = Dense(ae_dims)(Flatten()(x))
|
||||
x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue