mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 05:22:06 -07:00
SAE: you have to restart training.
Added the multiscale decoder as an option. The mask is now trained at a single scale (not multiscale).
This commit is contained in:
parent
b87e6be614
commit
51a13c90d1
1 changed files with 49 additions and 25 deletions
|
@ -31,12 +31,14 @@ class SAEModel(ModelBase):
|
||||||
self.options['face_type'] = input_str ("Half or Full face? (h/f, ?:help skip:f) : ", default_face_type, ['h','f'], help_message="Half face has better resolution, but covers less area of cheeks.").lower()
|
self.options['face_type'] = input_str ("Half or Full face? (h/f, ?:help skip:f) : ", default_face_type, ['h','f'], help_message="Half face has better resolution, but covers less area of cheeks.").lower()
|
||||||
self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
|
self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
|
||||||
self.options['lighter_encoder'] = input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, sacrificing overall quality.")
|
self.options['lighter_encoder'] = input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, sacrificing overall quality.")
|
||||||
self.options['learn_mask'] = input_bool ("Learn mask? (y/n, ?:help skip:y) : ", True, help_message="Choose NO to reduce model size. In this case converter forced to use 'not predicted mask' that is not smooth as predicted. Styled SAE can learn without mask and produce same quality fake if you choose high blur value in converter.")
|
self.options['multiscale_decoder'] = input_bool ("Use multiscale decoder? (y/n, ?:help skip:n) : ", False, help_message="This option forces decoder to produce higher detailed image and make final face look more like dst.")
|
||||||
|
self.options['learn_mask'] = input_bool ("Learn mask? (y/n, ?:help skip:y) : ", True, help_message="Choose NO to reduce model size. In this case converter forced to use 'not predicted mask' that is not smooth as predicted. Styled SAE can learn without mask and produce same quality fake.")
|
||||||
else:
|
else:
|
||||||
self.options['resolution'] = self.options.get('resolution', default_resolution)
|
self.options['resolution'] = self.options.get('resolution', default_resolution)
|
||||||
self.options['face_type'] = self.options.get('face_type', default_face_type)
|
self.options['face_type'] = self.options.get('face_type', default_face_type)
|
||||||
self.options['archi'] = self.options.get('archi', default_archi)
|
self.options['archi'] = self.options.get('archi', default_archi)
|
||||||
self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)
|
self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)
|
||||||
|
self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', False)
|
||||||
self.options['learn_mask'] = self.options.get('learn_mask', True)
|
self.options['learn_mask'] = self.options.get('learn_mask', True)
|
||||||
|
|
||||||
default_face_style_power = 10.0
|
default_face_style_power = 10.0
|
||||||
|
@ -96,10 +98,10 @@ class SAEModel(ModelBase):
|
||||||
|
|
||||||
inter_output_Inputs = [ Input( np.array(K.int_shape(x)[1:])*(1,1,2) ) for x in self.inter_B.outputs ]
|
inter_output_Inputs = [ Input( np.array(K.int_shape(x)[1:])*(1,1,2) ) for x in self.inter_B.outputs ]
|
||||||
|
|
||||||
self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (inter_output_Inputs)
|
self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2, multiscale_decoder=self.options['multiscale_decoder'])) (inter_output_Inputs)
|
||||||
|
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (inter_output_Inputs)
|
self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5), multiscale_decoder=False )) (inter_output_Inputs)
|
||||||
|
|
||||||
if not self.is_first_run():
|
if not self.is_first_run():
|
||||||
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
|
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
|
||||||
|
@ -128,18 +130,20 @@ class SAEModel(ModelBase):
|
||||||
pred_src_srcm = self.decoderm(warped_src_inter_code)
|
pred_src_srcm = self.decoderm(warped_src_inter_code)
|
||||||
pred_dst_dstm = self.decoderm(warped_dst_inter_code)
|
pred_dst_dstm = self.decoderm(warped_dst_inter_code)
|
||||||
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
|
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.encoder = modelify(SAEModel.DFEncFlow(resolution, adapt_k_size, self.options['lighter_encoder'], ae_dims=ae_dims, ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
|
self.encoder = modelify(SAEModel.DFEncFlow(resolution, adapt_k_size, self.options['lighter_encoder'], ae_dims=ae_dims, ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
|
||||||
|
|
||||||
dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
||||||
|
|
||||||
self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs)
|
self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2, multiscale_decoder=self.options['multiscale_decoder'])) (dec_Inputs)
|
||||||
self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs)
|
self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2, multiscale_decoder=self.options['multiscale_decoder'])) (dec_Inputs)
|
||||||
|
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
|
self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5), multiscale_decoder=False)) (dec_Inputs)
|
||||||
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
|
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5), multiscale_decoder=False)) (dec_Inputs)
|
||||||
|
|
||||||
if not self.is_first_run():
|
if not self.is_first_run():
|
||||||
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
|
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
|
||||||
|
@ -160,7 +164,11 @@ class SAEModel(ModelBase):
|
||||||
pred_dst_dstm = self.decoder_dstm(warped_dst_code)
|
pred_dst_dstm = self.decoder_dstm(warped_dst_code)
|
||||||
pred_src_dstm = self.decoder_srcm(warped_dst_code)
|
pred_src_dstm = self.decoder_srcm(warped_dst_code)
|
||||||
|
|
||||||
|
pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]
|
||||||
|
|
||||||
|
if self.options['learn_mask']:
|
||||||
|
pred_src_srcm, pred_dst_dstm, pred_src_dstm = [ [x] if type(x) != list else x for x in [pred_src_srcm, pred_dst_dstm, pred_src_dstm] ]
|
||||||
|
|
||||||
ms_count = len(pred_src_src)
|
ms_count = len(pred_src_src)
|
||||||
|
|
||||||
target_src_ar = [ target_src if i == 0 else tf.image.resize_bicubic( target_src, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]
|
target_src_ar = [ target_src if i == 0 else tf.image.resize_bicubic( target_src, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]
|
||||||
|
@ -230,8 +238,10 @@ class SAEModel(ModelBase):
|
||||||
|
|
||||||
|
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[i]-pred_src_srcm[i])) for i in range(len(target_srcm_ar)) ])
|
#src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[i]-pred_src_srcm[i])) for i in range(len(target_srcm_ar)) ])
|
||||||
dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[i]-pred_dst_dstm[i])) for i in range(len(target_dstm_ar)) ])
|
#dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[i]-pred_dst_dstm[i])) for i in range(len(target_dstm_ar)) ])
|
||||||
|
src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[-1]-pred_src_srcm[-1])) for i in range(len(target_srcm_ar)) ])
|
||||||
|
dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[-1]-pred_dst_dstm[-1])) for i in range(len(target_dstm_ar)) ])
|
||||||
self.src_dst_mask_train = K.function ([warped_src, target_srcm, warped_dst, target_dstm],[src_mask_loss, dst_mask_loss], optimizer().get_updates(src_mask_loss+dst_mask_loss, src_dst_mask_loss_train_weights) )
|
self.src_dst_mask_train = K.function ([warped_src, target_srcm, warped_dst, target_dstm],[src_mask_loss, dst_mask_loss], optimizer().get_updates(src_mask_loss+dst_mask_loss, src_dst_mask_loss_train_weights) )
|
||||||
|
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
|
@ -444,7 +454,7 @@ class SAEModel(ModelBase):
|
||||||
return func
|
return func
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def LIAEDecFlow(output_nc,ed_ch_dims=21,activation='tanh'):
|
def LIAEDecFlow(output_nc,ed_ch_dims=21, multiscale_decoder=True):
|
||||||
exec (nnlib.import_all(), locals(), globals())
|
exec (nnlib.import_all(), locals(), globals())
|
||||||
ed_dims = output_nc * ed_ch_dims
|
ed_dims = output_nc * ed_ch_dims
|
||||||
|
|
||||||
|
@ -459,16 +469,23 @@ class SAEModel(ModelBase):
|
||||||
return func
|
return func
|
||||||
def func(input):
|
def func(input):
|
||||||
x = input[0]
|
x = input[0]
|
||||||
x1 = upscale(ed_dims*8)( x )
|
|
||||||
x1_bgr = to_bgr() ( x1 )
|
|
||||||
|
|
||||||
x2 = upscale(ed_dims*4)( x1 )
|
outputs = []
|
||||||
x2_bgr = to_bgr() ( x2 )
|
x1 = upscale(ed_dims*8)( x )
|
||||||
|
|
||||||
|
if multiscale_decoder:
|
||||||
|
outputs += [ to_bgr() ( x1 ) ]
|
||||||
|
|
||||||
|
x2 = upscale(ed_dims*4)( x1 )
|
||||||
|
|
||||||
|
if multiscale_decoder:
|
||||||
|
outputs += [ to_bgr() ( x2 ) ]
|
||||||
|
|
||||||
x3 = upscale(ed_dims*2)( x2 )
|
x3 = upscale(ed_dims*2)( x2 )
|
||||||
x3_bgr = to_bgr() ( x3 )
|
|
||||||
|
outputs += [ to_bgr() ( x3 ) ]
|
||||||
|
|
||||||
return [ x1_bgr, x2_bgr, x3_bgr ]
|
return outputs
|
||||||
return func
|
return func
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -520,10 +537,9 @@ class SAEModel(ModelBase):
|
||||||
return func
|
return func
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def DFDecFlow(output_nc, ed_ch_dims=21):
|
def DFDecFlow(output_nc, ed_ch_dims=21, multiscale_decoder=True):
|
||||||
exec (nnlib.import_all(), locals(), globals())
|
exec (nnlib.import_all(), locals(), globals())
|
||||||
ed_dims = output_nc * ed_ch_dims
|
ed_dims = output_nc * ed_ch_dims
|
||||||
|
|
||||||
|
|
||||||
def Conv2D (filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=RandomNormal(0, 0.02), bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
|
def Conv2D (filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=RandomNormal(0, 0.02), bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
|
||||||
return keras.layers.Conv2D( filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )
|
return keras.layers.Conv2D( filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )
|
||||||
|
@ -539,16 +555,24 @@ class SAEModel(ModelBase):
|
||||||
return func
|
return func
|
||||||
def func(input):
|
def func(input):
|
||||||
x = input[0]
|
x = input[0]
|
||||||
x1 = upscale(ed_dims*8)( x )
|
|
||||||
x1_bgr = to_bgr() ( x1 )
|
|
||||||
|
|
||||||
x2 = upscale(ed_dims*4)( x1 )
|
outputs = []
|
||||||
x2_bgr = to_bgr() ( x2 )
|
x1 = upscale(ed_dims*8)( x )
|
||||||
|
|
||||||
|
if multiscale_decoder:
|
||||||
|
outputs += [ to_bgr() ( x1 ) ]
|
||||||
|
|
||||||
|
x2 = upscale(ed_dims*4)( x1 )
|
||||||
|
|
||||||
|
if multiscale_decoder:
|
||||||
|
outputs += [ to_bgr() ( x2 ) ]
|
||||||
|
|
||||||
x3 = upscale(ed_dims*2)( x2 )
|
x3 = upscale(ed_dims*2)( x2 )
|
||||||
x3_bgr = to_bgr() ( x3 )
|
|
||||||
|
outputs += [ to_bgr() ( x3 ) ]
|
||||||
|
|
||||||
return [ x1_bgr, x2_bgr, x3_bgr ]
|
return outputs
|
||||||
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
Model = SAEModel
|
Model = SAEModel
|
Loading…
Add table
Add a link
Reference in a new issue