Mirror of https://github.com/iperov/DeepFaceLab.git (synced 2025-07-06 21:12:07 -07:00)
SAE: added option "Use CA weights":
Initialize the network with 'Convolution Aware' weights. This may help to achieve a higher-accuracy model, but it consumes extra time on the first run.
This commit is contained in:
parent 71ff0ce1a7
commit d6a45763a2
4 changed files with 346 additions and 18 deletions
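For orientation before the diff, here is a minimal, self-contained sketch (standalone Keras, not DeepFaceLab code) of the pattern the commit implements: build the network with its usual initializers, collect every Conv2D kernel, and overwrite the kernels in one batch. keras.initializers.Orthogonal stands in for the 'Convolution Aware' computation, which the commit itself delegates to nnlib's CAInitializerMP.

    # Minimal sketch; Orthogonal is only a stand-in initializer, not CA.
    from keras import backend as K
    from keras.layers import Input, Conv2D
    from keras.models import Model
    from keras.initializers import Orthogonal

    inp = Input((64, 64, 3))
    x = Conv2D(32, kernel_size=5, strides=2, padding='same')(inp)
    x = Conv2D(64, kernel_size=5, strides=2, padding='same')(x)
    model = Model(inp, x)

    # layer.weights[0] is the Conv2D kernel; biases are left untouched.
    kernels = [layer.weights[0] for layer in model.layers if type(layer) == Conv2D]

    # Compute replacement values, then push them to the graph in one call.
    init = Orthogonal()
    new_values = [K.eval(init(K.int_shape(k))) for k in kernels]
    K.batch_set_value(list(zip(kernels, new_values)))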
@@ -56,12 +56,15 @@ class SAEModel(ModelBase):
        default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
        default_ed_ch_dims = 42
        def_ca_weights = False
        if is_first_run:
            self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dims (32-1024 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, help_message="More dims are better, but requires more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )
            self.options['ed_ch_dims'] = np.clip ( io.input_int("Encoder/Decoder dims per channel (21-85 ?:help skip:%d) : " % (default_ed_ch_dims) , default_ed_ch_dims, help_message="More dims are better, but requires more VRAM. You can fine-tune model size to fit your GPU." ), 21, 85 )
            self.options['ca_weights'] = io.input_bool ("Use CA weights? (y/n, ?:help skip: %s ) : " % (yn_str[def_ca_weights]), def_ca_weights, help_message="Initialize network with 'Convolution Aware' weights. This may help to achieve a higher accuracy model, but consumes time at first run.")
        else:
            self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims)
            self.options['ed_ch_dims'] = self.options.get('ed_ch_dims', default_ed_ch_dims)
            self.options['ca_weights'] = self.options.get('ca_weights', def_ca_weights)

        if is_first_run:
            self.options['lighter_encoder'] = io.input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, but sacrificing overall quality.")

@@ -122,6 +125,7 @@ class SAEModel(ModelBase):
        target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]

        models_list = []
        weights_to_load = []
        if self.options['archi'] == 'liae':
            self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, self.options['lighter_encoder'], ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))

@@ -135,9 +139,12 @@ class SAEModel(ModelBase):
            self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2, multiscale_count=self.ms_count )) (inter_output_Inputs)

            models_list += [self.encoder, self.inter_B, self.inter_AB, self.decoder]

            if self.options['learn_mask']:
                self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (inter_output_Inputs)

                models_list += [self.decoderm]

            if not self.is_first_run():
                weights_to_load += [ [self.encoder , 'encoder.h5'],
                                     [self.inter_B , 'inter_B.h5'],

@@ -146,7 +153,7 @@ class SAEModel(ModelBase):
                                   ]
                if self.options['learn_mask']:
                    weights_to_load += [ [self.decoderm, 'decoderm.h5'] ]

            warped_src_code = self.encoder (warped_src)
            warped_src_inter_AB_code = self.inter_AB (warped_src_code)
            warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])

@@ -175,10 +182,13 @@ class SAEModel(ModelBase):
            self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2, multiscale_count=self.ms_count )) (dec_Inputs)
            self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2, multiscale_count=self.ms_count )) (dec_Inputs)

            models_list += [self.encoder, self.decoder_src, self.decoder_dst]

            if self.options['learn_mask']:
                self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (dec_Inputs)
                self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (dec_Inputs)

                models_list += [self.decoder_srcm, self.decoder_dstm]

            if not self.is_first_run():
                weights_to_load += [ [self.encoder , 'encoder.h5'],
                                     [self.decoder_src, 'decoder_src.h5'],

@@ -188,7 +198,11 @@ class SAEModel(ModelBase):
                    weights_to_load += [ [self.decoder_srcm, 'decoder_srcm.h5'],
                                         [self.decoder_dstm, 'decoder_dstm.h5'],
                                       ]

            warped_src_code = self.encoder (warped_src)
            warped_dst_code = self.encoder (warped_dst)
            pred_src_src = self.decoder_src(warped_src_code)

@@ -208,10 +222,13 @@ class SAEModel(ModelBase):
            self.decoder_src = modelify(SAEModel.VGDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2 )) (dec_Inputs)
            self.decoder_dst = modelify(SAEModel.VGDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2 )) (dec_Inputs)

            models_list += [self.encoder, self.decoder_src, self.decoder_dst]

            if self.options['learn_mask']:
                self.decoder_srcm = modelify(SAEModel.VGDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (dec_Inputs)
                self.decoder_dstm = modelify(SAEModel.VGDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (dec_Inputs)

                models_list += [self.decoder_srcm, self.decoder_dstm]

            if not self.is_first_run():
                weights_to_load += [ [self.encoder , 'encoder.h5'],
                                     [self.decoder_src, 'decoder_src.h5'],

@@ -233,7 +250,17 @@ class SAEModel(ModelBase):
                pred_src_srcm = self.decoder_srcm(warped_src_code)
                pred_dst_dstm = self.decoder_dstm(warped_dst_code)
                pred_src_dstm = self.decoder_srcm(warped_dst_code)

        if self.is_first_run() and self.options['ca_weights']:
            io.log_info ("Initializing CA weights...")
            conv_weights_list = []
            for model in models_list:
                for layer in model.layers:
                    if type(layer) == Conv2D:
                        conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
            CAInitializerMP ( conv_weights_list )

        pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]

        if self.options['learn_mask']:

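CAInitializerMP receives the collected Conv2D kernel tensors and computes their replacement values. Purely as an illustration of how such a bulk initializer could be organised (the real implementation lives in nnlib and may differ), here is a sketch in which worker processes see only kernel shapes and the results are written back with a single batch_set_value; the _compute_kernel body is a plain He-style Gaussian placeholder, not the actual Convolution Aware algorithm.

    import multiprocessing
    import numpy as np
    from keras import backend as K

    def _compute_kernel(shape):
        # Placeholder: He-style Gaussian, NOT the real Convolution Aware computation.
        fan_in = int(np.prod(shape[:-1]))
        return np.random.normal(0.0, np.sqrt(2.0 / fan_in), size=shape).astype(np.float32)

    def initialize_kernels(kernel_tensors):
        shapes = [K.int_shape(t) for t in kernel_tensors]
        with multiprocessing.Pool() as pool:              # parallelise per-kernel work
            values = pool.map(_compute_kernel, shapes)
        K.batch_set_value(list(zip(kernel_tensors, values)))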
@@ -468,9 +495,6 @@ class SAEModel(ModelBase):
    def initialize_nn_functions():
        exec (nnlib.import_all(), locals(), globals())

        def conv_initializer():
            return RandomNormal(0, 0.02)

        class ResidualBlock(object):
            def __init__(self, filters, kernel_size=3, padding='same', use_reflection_padding=False):
                self.filters = filters

@@ -484,13 +508,13 @@ class SAEModel(ModelBase):
                #if self.use_reflection_padding:
                #  #var_x = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(var_x)

                var_x = Conv2D(self.filters, kernel_size=self.kernel_size, padding=self.padding, kernel_initializer=conv_initializer() )(var_x)
                var_x = Conv2D(self.filters, kernel_size=self.kernel_size, padding=self.padding, kernel_initializer=RandomNormal(0, 0.02) )(var_x)
                var_x = LeakyReLU(alpha=0.2)(var_x)

                #if self.use_reflection_padding:
                #  #var_x = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(var_x)

                var_x = Conv2D(self.filters, kernel_size=self.kernel_size, padding=self.padding, kernel_initializer=conv_initializer() )(var_x)
                var_x = Conv2D(self.filters, kernel_size=self.kernel_size, padding=self.padding, kernel_initializer=RandomNormal(0, 0.02) )(var_x)
                var_x = Scale(gamma_init=keras.initializers.Constant(value=0.1))(var_x)
                var_x = Add()([var_x, inp])
                var_x = LeakyReLU(alpha=0.2)(var_x)

@@ -499,25 +523,25 @@ class SAEModel(ModelBase):

        def downscale (dim):
            def func(x):
                return LeakyReLU(0.1)(Conv2D(dim, kernel_size=5, strides=2, padding='same', kernel_initializer=conv_initializer())(x))
                return LeakyReLU(0.1)(Conv2D(dim, kernel_size=5, strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02))(x))
            return func
        SAEModel.downscale = downscale

        def downscale_sep (dim):
            def func(x):
                return LeakyReLU(0.1)(SeparableConv2D(dim, kernel_size=5, strides=2, padding='same', depthwise_initializer=conv_initializer(), pointwise_initializer=RandomNormal(0, 0.02) )(x))
                return LeakyReLU(0.1)(SeparableConv2D(dim, kernel_size=5, strides=2, padding='same', depthwise_initializer=RandomNormal(0, 0.02), pointwise_initializer=RandomNormal(0, 0.02) )(x))
            return func
        SAEModel.downscale_sep = downscale_sep

        def upscale (dim):
            def func(x):
                return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding='same', kernel_initializer=conv_initializer() )(x)))
                return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding='same', kernel_initializer=RandomNormal(0, 0.02) )(x)))
            return func
        SAEModel.upscale = upscale

        def to_bgr (output_nc):
            def func(x):
                return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid', kernel_initializer=conv_initializer() )(x)
                return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid', kernel_initializer=RandomNormal(0, 0.02) )(x)
            return func
        SAEModel.to_bgr = to_bgr

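The hunks above also replace calls to the local conv_initializer() helper with an inline RandomNormal(0, 0.02). The two forms configure a layer identically, as this small standalone Keras comparison (not DeepFaceLab code) shows:

    from keras.layers import Conv2D
    from keras.initializers import RandomNormal

    def conv_initializer():                 # the helper form
        return RandomNormal(0, 0.02)        # mean 0, stddev 0.02

    # Both layers end up with the same kernel initializer configuration.
    conv_a = Conv2D(64, kernel_size=5, padding='same', kernel_initializer=conv_initializer())
    conv_b = Conv2D(64, kernel_size=5, padding='same', kernel_initializer=RandomNormal(0, 0.02))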