mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 05:22:06 -07:00
SAE: removed lightweight encoder
This commit is contained in:
parent
addd87efa8
commit
aa58d9e563
1 changed files with 10 additions and 46 deletions
|
@ -73,12 +73,8 @@ class SAEModel(ModelBase):
|
||||||
self.options['remove_gray_border'] = self.options.get('remove_gray_border', False)
|
self.options['remove_gray_border'] = self.options.get('remove_gray_border', False)
|
||||||
|
|
||||||
if is_first_run:
|
if is_first_run:
|
||||||
self.options['lighter_encoder'] = io.input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, but sacrificing overall quality.")
|
|
||||||
|
|
||||||
self.options['multiscale_decoder'] = io.input_bool ("Use multiscale decoder? (y/n, ?:help skip:n) : ", False, help_message="Multiscale decoder helps to get better details.")
|
self.options['multiscale_decoder'] = io.input_bool ("Use multiscale decoder? (y/n, ?:help skip:n) : ", False, help_message="Multiscale decoder helps to get better details.")
|
||||||
else:
|
else:
|
||||||
self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)
|
|
||||||
|
|
||||||
self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', False)
|
self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', False)
|
||||||
|
|
||||||
default_face_style_power = 0.0
|
default_face_style_power = 0.0
|
||||||
|
@ -133,10 +129,9 @@ class SAEModel(ModelBase):
|
||||||
padding = 'reflect' if self.options['remove_gray_border'] else 'zero'
|
padding = 'reflect' if self.options['remove_gray_border'] else 'zero'
|
||||||
common_flow_kwargs = { 'padding': padding }
|
common_flow_kwargs = { 'padding': padding }
|
||||||
|
|
||||||
models_list = []
|
|
||||||
weights_to_load = []
|
weights_to_load = []
|
||||||
if self.options['archi'] == 'liae':
|
if self.options['archi'] == 'liae':
|
||||||
self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, self.options['lighter_encoder'], ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
|
self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
|
||||||
|
|
||||||
enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
||||||
|
|
||||||
|
@ -147,11 +142,8 @@ class SAEModel(ModelBase):
|
||||||
|
|
||||||
self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs)) (inter_output_Inputs)
|
self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs)) (inter_output_Inputs)
|
||||||
|
|
||||||
models_list += [self.encoder, self.inter_B, self.inter_AB, self.decoder]
|
|
||||||
|
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs)) (inter_output_Inputs)
|
self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs)) (inter_output_Inputs)
|
||||||
models_list += [self.decoderm]
|
|
||||||
|
|
||||||
if not self.is_first_run():
|
if not self.is_first_run():
|
||||||
weights_to_load += [ [self.encoder , 'encoder.h5'],
|
weights_to_load += [ [self.encoder , 'encoder.h5'],
|
||||||
|
@ -183,19 +175,16 @@ class SAEModel(ModelBase):
|
||||||
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
|
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
|
||||||
|
|
||||||
elif self.options['archi'] == 'df':
|
elif self.options['archi'] == 'df':
|
||||||
self.encoder = modelify(SAEModel.DFEncFlow(resolution, self.options['lighter_encoder'], ae_dims=ae_dims, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
|
self.encoder = modelify(SAEModel.DFEncFlow(resolution, ae_dims=ae_dims, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
|
||||||
|
|
||||||
dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
||||||
|
|
||||||
self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
|
self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
|
||||||
self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
|
self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
|
||||||
|
|
||||||
models_list += [self.encoder, self.decoder_src, self.decoder_dst]
|
|
||||||
|
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
|
self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
|
||||||
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
|
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
|
||||||
models_list += [self.decoder_srcm, self.decoder_dstm]
|
|
||||||
|
|
||||||
if not self.is_first_run():
|
if not self.is_first_run():
|
||||||
weights_to_load += [ [self.encoder , 'encoder.h5'],
|
weights_to_load += [ [self.encoder , 'encoder.h5'],
|
||||||
|
@ -491,19 +480,6 @@ class SAEModel(ModelBase):
|
||||||
return func
|
return func
|
||||||
SAEModel.downscale_pre = downscale_pre
|
SAEModel.downscale_pre = downscale_pre
|
||||||
|
|
||||||
def downscale_sep (dim, padding='zero'):
|
|
||||||
def func(x):
|
|
||||||
return LeakyReLU(0.1)(SeparableConv2D(dim, kernel_size=5, strides=2, padding=padding)(x))
|
|
||||||
return func
|
|
||||||
SAEModel.downscale_sep = downscale_sep
|
|
||||||
|
|
||||||
def downscale_sep_pre (**base_kwargs):
|
|
||||||
def func(*args, **kwargs):
|
|
||||||
kwargs.update(base_kwargs)
|
|
||||||
return downscale_sep(*args, **kwargs)
|
|
||||||
return func
|
|
||||||
SAEModel.downscale_sep_pre = downscale_sep_pre
|
|
||||||
|
|
||||||
def upscale (dim, padding='zero'):
|
def upscale (dim, padding='zero'):
|
||||||
def func(x):
|
def func(x):
|
||||||
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding=padding)(x)))
|
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding=padding)(x)))
|
||||||
|
@ -531,25 +507,19 @@ class SAEModel(ModelBase):
|
||||||
SAEModel.to_bgr_pre = to_bgr_pre
|
SAEModel.to_bgr_pre = to_bgr_pre
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def LIAEEncFlow(resolution, light_enc, ch_dims, padding='zero', **kwargs):
|
def LIAEEncFlow(resolution, ch_dims, padding='zero', **kwargs):
|
||||||
exec (nnlib.import_all(), locals(), globals())
|
exec (nnlib.import_all(), locals(), globals())
|
||||||
upscale = SAEModel.upscale_pre(padding=padding)
|
upscale = SAEModel.upscale_pre(padding=padding)
|
||||||
downscale = SAEModel.downscale_pre(padding=padding)
|
downscale = SAEModel.downscale_pre(padding=padding)
|
||||||
downscale_sep = SAEModel.downscale_sep_pre(padding=padding)
|
|
||||||
|
|
||||||
def func(input):
|
def func(input):
|
||||||
dims = K.int_shape(input)[-1]*ch_dims
|
dims = K.int_shape(input)[-1]*ch_dims
|
||||||
|
|
||||||
x = input
|
x = input
|
||||||
x = downscale(dims)(x)
|
x = downscale(dims)(x)
|
||||||
if not light_enc:
|
x = downscale(dims*2)(x)
|
||||||
x = downscale(dims*2)(x)
|
x = downscale(dims*4)(x)
|
||||||
x = downscale(dims*4)(x)
|
x = downscale(dims*8)(x)
|
||||||
x = downscale(dims*8)(x)
|
|
||||||
else:
|
|
||||||
x = downscale_sep(dims*2)(x)
|
|
||||||
x = downscale(dims*4)(x)
|
|
||||||
x = downscale_sep(dims*8)(x)
|
|
||||||
|
|
||||||
x = Flatten()(x)
|
x = Flatten()(x)
|
||||||
return x
|
return x
|
||||||
|
@ -612,11 +582,10 @@ class SAEModel(ModelBase):
|
||||||
return func
|
return func
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def DFEncFlow(resolution, light_enc, ae_dims, ch_dims, padding='zero', **kwargs):
|
def DFEncFlow(resolution, ae_dims, ch_dims, padding='zero', **kwargs):
|
||||||
exec (nnlib.import_all(), locals(), globals())
|
exec (nnlib.import_all(), locals(), globals())
|
||||||
upscale = SAEModel.upscale_pre(padding=padding)
|
upscale = SAEModel.upscale_pre(padding=padding)
|
||||||
downscale = SAEModel.downscale_pre(padding=padding)
|
downscale = SAEModel.downscale_pre(padding=padding)
|
||||||
downscale_sep = SAEModel.downscale_sep_pre(padding=padding)
|
|
||||||
lowest_dense_res = resolution // 16
|
lowest_dense_res = resolution // 16
|
||||||
|
|
||||||
def func(input):
|
def func(input):
|
||||||
|
@ -625,14 +594,9 @@ class SAEModel(ModelBase):
|
||||||
dims = K.int_shape(input)[-1]*ch_dims
|
dims = K.int_shape(input)[-1]*ch_dims
|
||||||
|
|
||||||
x = downscale(dims)(x)
|
x = downscale(dims)(x)
|
||||||
if not light_enc:
|
x = downscale(dims*2)(x)
|
||||||
x = downscale(dims*2)(x)
|
x = downscale(dims*4)(x)
|
||||||
x = downscale(dims*4)(x)
|
x = downscale(dims*8)(x)
|
||||||
x = downscale(dims*8)(x)
|
|
||||||
else:
|
|
||||||
x = downscale_sep(dims*2)(x)
|
|
||||||
x = downscale(dims*4)(x)
|
|
||||||
x = downscale_sep(dims*8)(x)
|
|
||||||
|
|
||||||
x = Dense(ae_dims)(Flatten()(x))
|
x = Dense(ae_dims)(Flatten()(x))
|
||||||
x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
|
x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue