fix DFLJPG,

SAE: added "rare sample booster"
SAE: pixel loss option replaced with a smooth transition from DSSIM to pixel loss over the first 15k epochs by default
iperov 2019-02-09 18:53:37 +04:00
parent f93b4713a9
commit 4d37fd62cd
11 changed files with 174 additions and 101 deletions


@@ -52,13 +52,7 @@ class SAEModel(ModelBase):
             self.options['bg_style_power'] = np.clip ( input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_bg_style_power), default_bg_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces. If the style is learned well enough, set this value to 0.1-0.3 to prevent artifacts from appearing."), 0.0, 100.0 )
         else:
             self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)

-        if is_first_run or ask_override:
-            default_pixel_loss = False if is_first_run else self.options.get('pixel_loss', False)
-            self.options['pixel_loss'] = input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", default_pixel_loss, help_message="The default DSSIM loss is good for learning the initial structure of faces. Use pixel loss after 30-40k epochs to enhance fine details.")
-        else:
-            self.options['pixel_loss'] = self.options.get('pixel_loss', False)

         default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
         default_ed_ch_dims = 42
         if is_first_run:
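The "use pixel loss?" prompt is removed above because the choice is no longer manual: the trainer now blends the two objectives itself, weighting DSSIM by (1 - alpha) and pixel (MSE) loss by alpha, where alpha ramps linearly from 0 to 1 over the first 15k epochs. A minimal numpy sketch of that schedule (function names here are illustrative, not from the codebase):

    import numpy as np

    def dssim_pixel_alpha(epoch, ramp_epochs=15000):
        # 0.0 -> pure DSSIM, 1.0 -> pure pixel (MSE) loss
        return np.clip(epoch / ramp_epochs, 0.0, 1.0)

    def blended_loss(dssim_loss, pixel_loss, epoch):
        a = dssim_pixel_alpha(epoch)
        return dssim_loss * (1.0 - a) + pixel_loss * a

    print(blended_loss(0.8, 0.2, 7500))  # halfway through the ramp: 0.5

DSSIM converges on coarse facial structure quickly, while plain pixel loss recovers fine detail late in training, so the ramp gets both without user tuning.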
@@ -83,6 +77,7 @@ class SAEModel(ModelBase):
         bgr_shape = (resolution, resolution, 3)
         mask_shape = (resolution, resolution, 1)

+        dssim_pixel_alpha = Input( (1,) )
         warped_src = Input(bgr_shape)
         target_src = Input(bgr_shape)
         target_srcm = Input(mask_shape)
@@ -199,6 +194,7 @@ class SAEModel(ModelBase):
         def optimizer():
            return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)

+        dssim_pixel_alpha_value = dssim_pixel_alpha[0][0]

         if self.options['archi'] == 'liae':
             src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
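Feeding the blend factor through a Keras Input, rather than baking a constant into the graph, lets the same compiled train function receive a fresh alpha every epoch; `dssim_pixel_alpha[0][0]` simply pulls the first element of the (batch, 1) placeholder out as a scalar tensor. A toy sketch of the mechanism (TF1-era Keras assumed, names illustrative):

    import numpy as np
    import keras.backend as K
    from keras.layers import Input

    alpha_inp = Input((1,))      # (batch, 1) placeholder, fed at call time
    alpha = alpha_inp[0][0]      # scalar tensor: first element of the batch
    x = Input((1,))
    y = x * (1.0 - alpha)        # the graph uses alpha symbolically
    f = K.function([alpha_inp, x], [y])

    print(f([np.array([[0.25]], np.float32), np.array([[4.0]], np.float32)])[0])  # [[3.]]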
@@ -208,29 +204,29 @@ class SAEModel(ModelBase):
             src_dst_loss_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights + self.decoder_dst.trainable_weights
             if self.options['learn_mask']:
                 src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights

-        if self.options['pixel_loss']:
-            src_loss = sum([ K.mean( 100*K.square( target_src_masked_ar[i] - pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] )) for i in range(len(target_src_masked_ar)) ])
-        else:
-            src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
+        src_dssim_loss_batch = sum([ ( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
+        src_pixel_loss_batch = sum([ tf_reduce_mean ( 100*K.square( target_src_masked_ar[i] - pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar)) ])
+        src_loss_batch = src_dssim_loss_batch*(1.0-dssim_pixel_alpha_value) + src_pixel_loss_batch*dssim_pixel_alpha_value
+        src_loss = K.mean(src_loss_batch)
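Compared with the removed branches, the replacement computes both terms unconditionally and, crucially, keeps one loss value per sample: tf_reduce_mean with axis=[1,2,3] averages over height, width and channels but leaves the batch axis intact, and tf_dssim is likewise per-sample. Those per-sample vectors feed the rare sample booster below. The same reduction in plain numpy, as a sketch:

    import numpy as np

    target = np.random.rand(4, 64, 64, 3).astype(np.float32)
    pred   = np.random.rand(4, 64, 64, 3).astype(np.float32)

    # reduce over H, W, C but keep the batch axis -> one loss per sample
    per_sample = np.mean(100.0 * np.square(target - pred), axis=(1, 2, 3))
    print(per_sample.shape)  # (4,)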
         if self.options['face_style_power'] != 0:
             face_style_power = self.options['face_style_power'] / 100.0
             src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)( psd_target_dst_masked_ar[-1], target_dst_masked_ar[-1] )

         if self.options['bg_style_power'] != 0:
-            bg_style_power = self.options['bg_style_power'] / 100.0
-            if self.options['pixel_loss']:
-                src_loss += K.mean( (100*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
-            else:
-                src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
-
-        if self.options['pixel_loss']:
-            dst_loss = sum([ K.mean( 100*K.square( target_dst_masked_ar[i] - pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] )) for i in range(len(target_dst_masked_ar)) ])
-        else:
-            dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
-
-        self.src_dst_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss,dst_loss], optimizer().get_updates(src_loss+dst_loss, src_dst_loss_train_weights) )
+            bg_style_power = self.options['bg_style_power'] / 100.0
+            bg_dssim_loss = K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
+            bg_pixel_loss = K.mean( (100*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
+            src_loss += bg_dssim_loss*(1.0-dssim_pixel_alpha_value) + bg_pixel_loss*dssim_pixel_alpha_value
+
+        dst_dssim_loss_batch = sum([ ( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
+        dst_pixel_loss_batch = sum([ tf_reduce_mean ( 100*K.square( target_dst_masked_ar[i] - pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar)) ])
+        dst_loss_batch = dst_dssim_loss_batch*(1.0-dssim_pixel_alpha_value) + dst_pixel_loss_batch*dssim_pixel_alpha_value
+        dst_loss = K.mean(dst_loss_batch)
+
+        self.src_dst_train = K.function ([dssim_pixel_alpha, warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss,dst_loss,src_loss_batch,dst_loss_batch], optimizer().get_updates(src_loss+dst_loss, src_dst_loss_train_weights) )

         if self.options['learn_mask']:
@@ -250,6 +246,9 @@ class SAEModel(ModelBase):
         self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1] ])

         if self.is_training_mode:
+            self.src_sample_losses = []
+            self.dst_sample_losses = []
+
             f = SampleProcessor.TypeFlags
             face_type = f.FACE_ALIGN_FULL if self.options['face_type'] == 'f' else f.FACE_ALIGN_HALF
             self.set_training_data_generators ([
@@ -259,14 +258,14 @@ class SAEModel(ModelBase):
                     output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution],
                                           [f.TRANSFORMED | face_type | f.MODE_BGR, resolution],
                                           [f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution]
-                                        ] ),
+                                        ], add_sample_idx=True ),

                 SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                     sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True),
                     output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution],
                                           [f.TRANSFORMED | face_type | f.MODE_BGR, resolution],
                                           [f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution]
-                                        ] )
+                                        ], add_sample_idx=True )
             ])
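With add_sample_idx=True each generator is expected to append the index of every sample it drew to the batch it yields, and to expose repeat_sample_idxs() so the trainer can re-queue hard samples. A hypothetical minimal generator honoring that contract (ToyGenerator is illustrative; the real SampleGeneratorFace does much more):

    import numpy as np

    class ToyGenerator:
        def __init__(self, samples, batch_size):
            self.samples = samples
            self.batch_size = batch_size
            self.repeat_queue = []

        def repeat_sample_idxs(self, idxs):
            # trainer asks for these samples to be shown again soon
            self.repeat_queue += list(idxs)

        def next_batch(self):
            idxs = [self.repeat_queue.pop(0) if self.repeat_queue
                    else np.random.randint(len(self.samples))
                    for _ in range(self.batch_size)]
            batch = np.stack([self.samples[i] for i in idxs])
            return batch, np.array(idxs)  # data plus the idxs that produced it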
     #override
     def onSave(self):
@@ -289,13 +288,39 @@ class SAEModel(ModelBase):
         self.save_weights_safe(ar)

     #override
-    def onTrainOneEpoch(self, sample):
-        warped_src, target_src, target_src_mask = sample[0]
-        warped_dst, target_dst, target_dst_mask = sample[1]
+    def onTrainOneEpoch(self, generators_samples, generators_list):
+        warped_src, target_src, target_src_mask, src_sample_idxs = generators_samples[0]
+        warped_dst, target_dst, target_dst_mask, dst_sample_idxs = generators_samples[1]

-        src_loss, dst_loss = self.src_dst_train ([warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask])
+        dssim_pixel_alpha = np.clip ( self.epoch / 15000.0, 0.0, 1.0 ) #smooth transition between DSSIM and MSE over 15k epochs
+        dssim_pixel_alpha = np.repeat( dssim_pixel_alpha, (self.batch_size,) )
+        dssim_pixel_alpha = np.expand_dims(dssim_pixel_alpha,-1)
+
+        src_loss, dst_loss, src_sample_losses, dst_sample_losses = self.src_dst_train ([dssim_pixel_alpha, warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask])
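The scalar alpha is tiled to shape (batch_size, 1) purely to satisfy the Input((1,)) placeholder's batch dimension; only element [0][0] is ever read inside the graph. A shape walk-through:

    import numpy as np

    batch_size, epoch = 4, 7500
    alpha = np.clip(epoch / 15000.0, 0.0, 1.0)  # scalar, 0.5 here
    alpha = np.repeat(alpha, (batch_size,))     # -> shape (4,)
    alpha = np.expand_dims(alpha, -1)           # -> shape (4, 1)
    print(alpha[:2])                            # [[0.5] [0.5]]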
+        #gather (sample_idx, loss) pairs for the rare sample booster
+        self.src_sample_losses += [[src_sample_idxs[i], src_sample_losses[i]] for i in range(self.batch_size) ]
+        self.dst_sample_losses += [[dst_sample_idxs[i], dst_sample_losses[i]] for i in range(self.batch_size) ]
+
+        if len(self.src_sample_losses) >= 48: #array is big enough
+            #fetch idxs whose losses are bigger than average
+            x = np.array (self.src_sample_losses)
+            self.src_sample_losses = []
+            b = x[:,1]
+            idxs = ( x[:,0][ np.argwhere ( b > np.mean(b) )[:,0] ] ).astype(np.uint)
+            generators_list[0].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
+
+        if len(self.dst_sample_losses) >= 48: #array is big enough
+            #fetch idxs whose losses are bigger than average
+            x = np.array (self.dst_sample_losses)
+            self.dst_sample_losses = []
+            b = x[:,1]
+            idxs = ( x[:,0][ np.argwhere ( b > np.mean(b) )[:,0] ] ).astype(np.uint)
+            generators_list[1].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
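This is the "rare sample booster" from the commit message: once 48 or more (idx, loss) pairs have accumulated, samples whose loss is above the window mean are handed back to the generator for another pass, so faces the network finds hard are seen more often. The selection logic in isolation, as a sketch:

    import numpy as np

    def harder_than_average(sample_losses):
        # sample_losses: list of [sample_idx, loss] pairs, as gathered above
        x = np.array(sample_losses)
        idxs, losses = x[:, 0], x[:, 1]
        return idxs[losses > losses.mean()].astype(np.uint)

    pairs = [[10, 0.2], [11, 0.9], [12, 0.4], [13, 1.1]]
    print(harder_than_average(pairs))  # [11 13] -- the two hardest samples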
         if self.options['learn_mask']:
             src_mask_loss, dst_mask_loss = self.src_dst_mask_train ([warped_src, target_src_mask, warped_dst, target_dst_mask])
@@ -453,6 +478,9 @@ class SAEModel(ModelBase):
         strides = resolution // 32 if adapt_k_size else 2
         lowest_dense_res = resolution // 16

+        def Conv2D (filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=RandomNormal(0, 0.02), bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
+            return keras.layers.Conv2D( filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )
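The local Conv2D shadows keras.layers.Conv2D only to change one default: kernels start from RandomNormal(0, 0.02), the DCGAN-style initialization, instead of Keras's glorot_uniform, so every conv below picks it up without repeating the argument. A compact equivalent of the wrapper, as a sketch:

    from keras.initializers import RandomNormal
    import keras.layers

    def Conv2D(filters, kernel_size, **kwargs):
        # same call surface, different default weight init
        kwargs.setdefault('kernel_initializer', RandomNormal(0, 0.02))
        return keras.layers.Conv2D(filters, kernel_size, **kwargs)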
         def downscale (dim):
             def func(x):
                 return LeakyReLU(0.1)(Conv2D(dim, k_size, strides=strides, padding='same')(x))
@@ -496,6 +524,10 @@ class SAEModel(ModelBase):
         exec (nnlib.import_all(), locals(), globals())
         ed_dims = output_nc * ed_ch_dims

+        def Conv2D (filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=RandomNormal(0, 0.02), bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
+            return keras.layers.Conv2D( filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )

         def upscale (dim):
             def func(x):
                 return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
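upscale doubles spatial resolution by producing dim*4 channels with an ordinary conv and letting SubpixelUpscaler rearrange them, i.e. a pixel shuffle. The core rearrangement is TensorFlow's depth_to_space; a standalone sketch (TF1-style session API assumed):

    import numpy as np
    import tensorflow as tf

    # pixel shuffle: (H, W, 4*C) -> (2H, 2W, C)
    x = tf.placeholder(tf.float32, (None, 8, 8, 64))
    y = tf.depth_to_space(x, block_size=2)  # -> (None, 16, 16, 16)

    with tf.Session() as sess:
        out = sess.run(y, {x: np.zeros((1, 8, 8, 64), np.float32)})
        print(out.shape)  # (1, 16, 16, 16)

Subpixel upscaling is commonly chosen over strided transposed convolutions because it tends to avoid their checkerboard artifacts.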