SAE: you have to restart training,

added multiscale decoder as option.
mask now training as not multiscaled
This commit is contained in:
iperov 2019-02-09 20:33:26 +04:00
parent b87e6be614
commit 51a13c90d1

View file

@ -31,12 +31,14 @@ class SAEModel(ModelBase):
self.options['face_type'] = input_str ("Half or Full face? (h/f, ?:help skip:f) : ", default_face_type, ['h','f'], help_message="Half face has better resolution, but covers less area of cheeks.").lower()
self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
self.options['lighter_encoder'] = input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, sacrificing overall quality.")
self.options['learn_mask'] = input_bool ("Learn mask? (y/n, ?:help skip:y) : ", True, help_message="Choose NO to reduce model size. In this case converter forced to use 'not predicted mask' that is not smooth as predicted. Styled SAE can learn without mask and produce same quality fake if you choose high blur value in converter.")
self.options['multiscale_decoder'] = input_bool ("Use multiscale decoder? (y/n, ?:help skip:n) : ", False, help_message="This option forces decoder to produce higher detailed image and make final face look more like dst.")
self.options['learn_mask'] = input_bool ("Learn mask? (y/n, ?:help skip:y) : ", True, help_message="Choose NO to reduce model size. In this case converter forced to use 'not predicted mask' that is not smooth as predicted. Styled SAE can learn without mask and produce same quality fake.")
else:
self.options['resolution'] = self.options.get('resolution', default_resolution)
self.options['face_type'] = self.options.get('face_type', default_face_type)
self.options['archi'] = self.options.get('archi', default_archi)
self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)
self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', False)
self.options['learn_mask'] = self.options.get('learn_mask', True)
default_face_style_power = 10.0
@ -96,10 +98,10 @@ class SAEModel(ModelBase):
inter_output_Inputs = [ Input( np.array(K.int_shape(x)[1:])*(1,1,2) ) for x in self.inter_B.outputs ]
self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (inter_output_Inputs)
self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2, multiscale_decoder=self.options['multiscale_decoder'])) (inter_output_Inputs)
if self.options['learn_mask']:
self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (inter_output_Inputs)
self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5), multiscale_decoder=False )) (inter_output_Inputs)
if not self.is_first_run():
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
@ -129,17 +131,19 @@ class SAEModel(ModelBase):
pred_dst_dstm = self.decoderm(warped_dst_inter_code)
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
else:
self.encoder = modelify(SAEModel.DFEncFlow(resolution, adapt_k_size, self.options['lighter_encoder'], ae_dims=ae_dims, ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs)
self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs)
self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2, multiscale_decoder=self.options['multiscale_decoder'])) (dec_Inputs)
self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2, multiscale_decoder=self.options['multiscale_decoder'])) (dec_Inputs)
if self.options['learn_mask']:
self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5), multiscale_decoder=False)) (dec_Inputs)
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5), multiscale_decoder=False)) (dec_Inputs)
if not self.is_first_run():
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
@ -160,6 +164,10 @@ class SAEModel(ModelBase):
pred_dst_dstm = self.decoder_dstm(warped_dst_code)
pred_src_dstm = self.decoder_srcm(warped_dst_code)
pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]
if self.options['learn_mask']:
pred_src_srcm, pred_dst_dstm, pred_src_dstm = [ [x] if type(x) != list else x for x in [pred_src_srcm, pred_dst_dstm, pred_src_dstm] ]
ms_count = len(pred_src_src)
@ -230,8 +238,10 @@ class SAEModel(ModelBase):
if self.options['learn_mask']:
src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[i]-pred_src_srcm[i])) for i in range(len(target_srcm_ar)) ])
dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[i]-pred_dst_dstm[i])) for i in range(len(target_dstm_ar)) ])
#src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[i]-pred_src_srcm[i])) for i in range(len(target_srcm_ar)) ])
#dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[i]-pred_dst_dstm[i])) for i in range(len(target_dstm_ar)) ])
src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[-1]-pred_src_srcm[-1])) for i in range(len(target_srcm_ar)) ])
dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[-1]-pred_dst_dstm[-1])) for i in range(len(target_dstm_ar)) ])
self.src_dst_mask_train = K.function ([warped_src, target_srcm, warped_dst, target_dstm],[src_mask_loss, dst_mask_loss], optimizer().get_updates(src_mask_loss+dst_mask_loss, src_dst_mask_loss_train_weights) )
if self.options['learn_mask']:
@ -444,7 +454,7 @@ class SAEModel(ModelBase):
return func
@staticmethod
def LIAEDecFlow(output_nc,ed_ch_dims=21,activation='tanh'):
def LIAEDecFlow(output_nc,ed_ch_dims=21, multiscale_decoder=True):
exec (nnlib.import_all(), locals(), globals())
ed_dims = output_nc * ed_ch_dims
@ -459,16 +469,23 @@ class SAEModel(ModelBase):
return func
def func(input):
x = input[0]
outputs = []
x1 = upscale(ed_dims*8)( x )
x1_bgr = to_bgr() ( x1 )
if multiscale_decoder:
outputs += [ to_bgr() ( x1 ) ]
x2 = upscale(ed_dims*4)( x1 )
x2_bgr = to_bgr() ( x2 )
if multiscale_decoder:
outputs += [ to_bgr() ( x2 ) ]
x3 = upscale(ed_dims*2)( x2 )
x3_bgr = to_bgr() ( x3 )
return [ x1_bgr, x2_bgr, x3_bgr ]
outputs += [ to_bgr() ( x3 ) ]
return outputs
return func
@staticmethod
@ -520,11 +537,10 @@ class SAEModel(ModelBase):
return func
@staticmethod
def DFDecFlow(output_nc, ed_ch_dims=21):
def DFDecFlow(output_nc, ed_ch_dims=21, multiscale_decoder=True):
exec (nnlib.import_all(), locals(), globals())
ed_dims = output_nc * ed_ch_dims
def Conv2D (filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=RandomNormal(0, 0.02), bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
return keras.layers.Conv2D( filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )
@ -539,16 +555,24 @@ class SAEModel(ModelBase):
return func
def func(input):
x = input[0]
outputs = []
x1 = upscale(ed_dims*8)( x )
x1_bgr = to_bgr() ( x1 )
if multiscale_decoder:
outputs += [ to_bgr() ( x1 ) ]
x2 = upscale(ed_dims*4)( x1 )
x2_bgr = to_bgr() ( x2 )
if multiscale_decoder:
outputs += [ to_bgr() ( x2 ) ]
x3 = upscale(ed_dims*2)( x2 )
x3_bgr = to_bgr() ( x3 )
return [ x1_bgr, x2_bgr, x3_bgr ]
outputs += [ to_bgr() ( x3 ) ]
return outputs
return func
Model = SAEModel