mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 21:42:08 -07:00
AMP: some architecture changes that result in more stable training. New options: blur_out_mask and rtm_dst_denoise. The sample processor count is no longer limited to 8, so if you have an AMD processor with 16+ cores, increase the paging file size.
This commit is contained in:
parent
8e63666390
commit
91187ecb95
1 changed files with 79 additions and 43 deletions
|
@ -28,7 +28,9 @@ class AMPModel(ModelBase):
|
||||||
default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None)
|
default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None)
|
||||||
default_morph_factor = self.options['morph_factor'] = self.options.get('morph_factor', 0.5)
|
default_morph_factor = self.options['morph_factor'] = self.options.get('morph_factor', 0.5)
|
||||||
default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False)
|
default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False)
|
||||||
default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', 'n')
|
default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False)
|
||||||
|
default_dst_denoise = self.options['rtm_dst_denoise'] = self.load_or_def_option('rtm_dst_denoise', False)
|
||||||
|
default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', 'n')
|
||||||
default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True)
|
default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True)
|
||||||
default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none')
|
default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none')
|
||||||
default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False)
|
default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False)
|
||||||
|
@ -73,6 +75,9 @@ class AMPModel(ModelBase):
|
||||||
|
|
||||||
if self.is_first_run() or ask_override:
|
if self.is_first_run() or ask_override:
|
||||||
self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')
|
self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')
|
||||||
|
self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. The exact xseg mask in src and dst faceset is required.')
|
||||||
|
self.options['rtm_dst_denoise'] = io.input_bool ("Denoise RTM DST faceset.", default_dst_denoise, help_message='Used in RTM(ReadyToMerge) training with RTM DST faceset. Removes high frequency noise keeping edges. Result is better face syncronization with any face. Can be enabled at any time.')
|
||||||
|
|
||||||
self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.")
|
self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.")
|
||||||
|
|
||||||
default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
|
default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
|
||||||
|
@ -120,7 +125,10 @@ class AMPModel(ModelBase):
|
||||||
morph_factor = self.options['morph_factor']
|
morph_factor = self.options['morph_factor']
|
||||||
gan_power = self.gan_power = self.options['gan_power']
|
gan_power = self.gan_power = self.options['gan_power']
|
||||||
random_warp = self.options['random_warp']
|
random_warp = self.options['random_warp']
|
||||||
|
|
||||||
|
blur_out_mask = self.options['blur_out_mask']
|
||||||
|
rtm_dst_denoise = self.options['rtm_dst_denoise']
|
||||||
|
|
||||||
ct_mode = self.options['ct_mode']
|
ct_mode = self.options['ct_mode']
|
||||||
if ct_mode == 'none':
|
if ct_mode == 'none':
|
||||||
ct_mode = None
|
ct_mode = None
|
||||||
|
@ -128,7 +136,7 @@ class AMPModel(ModelBase):
|
||||||
use_fp16 = False
|
use_fp16 = False
|
||||||
if self.is_exporting:
|
if self.is_exporting:
|
||||||
use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')
|
use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')
|
||||||
|
|
||||||
conv_dtype = tf.float16 if use_fp16 else tf.float32
|
conv_dtype = tf.float16 if use_fp16 else tf.float32
|
||||||
|
|
||||||
class Downscale(nn.ModelBase):
|
class Downscale(nn.ModelBase):
|
||||||
|
@ -289,11 +297,11 @@ class AMPModel(ModelBase):
|
||||||
# Initialize optimizers
|
# Initialize optimizers
|
||||||
clipnorm = 1.0 if self.options['clipgrad'] else 0.0
|
clipnorm = 1.0 if self.options['clipgrad'] else 0.0
|
||||||
lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] else 1.0
|
lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] else 1.0
|
||||||
|
|
||||||
self.all_weights = self.encoder.get_weights() + self.decoder.get_weights()
|
self.G_weights = self.encoder.get_weights() + self.decoder.get_weights()
|
||||||
|
|
||||||
self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
|
self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
|
||||||
self.src_dst_opt.initialize_variables (self.all_weights, vars_on_cpu=optimizer_vars_on_cpu)
|
self.src_dst_opt.initialize_variables (self.G_weights, vars_on_cpu=optimizer_vars_on_cpu)
|
||||||
self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
|
self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
|
||||||
|
|
||||||
if gan_power != 0:
|
if gan_power != 0:
|
||||||
|
@ -320,7 +328,13 @@ class AMPModel(ModelBase):
|
||||||
gpu_src_losses = []
|
gpu_src_losses = []
|
||||||
gpu_dst_losses = []
|
gpu_dst_losses = []
|
||||||
gpu_G_loss_gradients = []
|
gpu_G_loss_gradients = []
|
||||||
gpu_GAN_loss_grads = []
|
gpu_GAN_loss_gradients = []
|
||||||
|
|
||||||
|
def DLossOnes(logits):
|
||||||
|
return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits), axis=[1,2,3])
|
||||||
|
|
||||||
|
def DLossZeros(logits):
|
||||||
|
return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1,2,3])
|
||||||
|
|
||||||
for gpu_id in range(gpu_count):
|
for gpu_id in range(gpu_count):
|
||||||
with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
|
with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
|
||||||
|
@ -342,39 +356,63 @@ class AMPModel(ModelBase):
|
||||||
|
|
||||||
gpu_src_inter_src_code, gpu_src_inter_dst_code = self.inter_src (gpu_src_code), self.inter_dst (gpu_src_code)
|
gpu_src_inter_src_code, gpu_src_inter_dst_code = self.inter_src (gpu_src_code), self.inter_dst (gpu_src_code)
|
||||||
gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code)
|
gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code)
|
||||||
|
|
||||||
inter_rnd_binomial = nn.random_binomial( [bs_per_gpu, gpu_src_inter_src_code.shape.as_list()[1], 1,1] , p=morph_factor)
|
inter_dims_bin = int(inter_dims*morph_factor)
|
||||||
|
inter_rnd_binomial = tf.stack([tf.concat([tf.tile(tf.constant([1], tf.float32), ( inter_dims_bin, )),
|
||||||
|
tf.tile(tf.constant([0], tf.float32), ( inter_dims-inter_dims_bin, ))], 0 ) for _ in range(bs_per_gpu)], 0)
|
||||||
|
|
||||||
|
inter_rnd_binomial = tf.stop_gradient(inter_rnd_binomial[...,None,None])
|
||||||
|
|
||||||
gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial)
|
gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial)
|
||||||
gpu_dst_code = gpu_dst_inter_dst_code
|
gpu_dst_code = gpu_dst_inter_dst_code
|
||||||
|
|
||||||
inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32)
|
inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32)
|
||||||
gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]),
|
gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]),
|
||||||
tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 )
|
tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 )
|
||||||
|
|
||||||
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
|
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
|
||||||
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
|
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
|
||||||
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
|
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
|
||||||
|
|
||||||
gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
|
gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
|
||||||
gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
|
gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
|
||||||
gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)
|
gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)
|
||||||
|
|
||||||
gpu_target_srcm_blur = tf.clip_by_value( nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) ), 0, 0.5) * 2
|
gpu_target_srcm_anti = 1-gpu_target_srcm
|
||||||
gpu_target_dstm_blur = tf.clip_by_value(nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) ), 0, 0.5) * 2
|
gpu_target_dstm_anti = 1-gpu_target_dstm
|
||||||
|
|
||||||
|
gpu_target_srcm_gblur = nn.gaussian_blur(gpu_target_srcm, resolution // 32)
|
||||||
|
gpu_target_dstm_gblur = nn.gaussian_blur(gpu_target_dstm, resolution // 32)
|
||||||
|
|
||||||
|
gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_gblur, 0, 0.5) * 2
|
||||||
|
gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_gblur, 0, 0.5) * 2
|
||||||
gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur
|
gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur
|
||||||
gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur
|
gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur
|
||||||
|
|
||||||
|
if blur_out_mask:
|
||||||
|
#gpu_target_src = gpu_target_src*gpu_target_srcm_blur + nn.gaussian_blur(gpu_target_src, resolution // 32)*gpu_target_srcm_anti_blur
|
||||||
|
#gpu_target_dst = gpu_target_dst*gpu_target_dstm_blur + nn.gaussian_blur(gpu_target_dst, resolution // 32)*gpu_target_dstm_anti_blur
|
||||||
|
bg_blur_div = 128
|
||||||
|
|
||||||
|
gpu_target_src = gpu_target_src*gpu_target_srcm + \
|
||||||
|
tf.math.divide_no_nan(nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, resolution / bg_blur_div),
|
||||||
|
(1-nn.gaussian_blur(gpu_target_srcm, resolution / bg_blur_div) ) ) * gpu_target_srcm_anti
|
||||||
|
|
||||||
|
gpu_target_dst = gpu_target_dst*gpu_target_dstm + \
|
||||||
|
tf.math.divide_no_nan(nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, resolution / bg_blur_div),
|
||||||
|
(1-nn.gaussian_blur(gpu_target_dstm, resolution / bg_blur_div)) ) * gpu_target_dstm_anti
|
||||||
|
|
||||||
|
|
||||||
gpu_target_src_masked = gpu_target_src*gpu_target_srcm_blur
|
gpu_target_src_masked = gpu_target_src*gpu_target_srcm_blur
|
||||||
gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
|
gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
|
||||||
gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur
|
gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur
|
||||||
gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur
|
gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur
|
||||||
|
|
||||||
gpu_pred_src_src_masked = gpu_pred_src_src*gpu_target_srcm_blur
|
gpu_pred_src_src_masked = gpu_pred_src_src*gpu_target_srcm_blur
|
||||||
gpu_pred_dst_dst_masked = gpu_pred_dst_dst*gpu_target_dstm_blur
|
gpu_pred_dst_dst_masked = gpu_pred_dst_dst*gpu_target_dstm_blur
|
||||||
gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur
|
gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur
|
||||||
gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur
|
gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur
|
||||||
|
|
||||||
# Structural loss
|
# Structural loss
|
||||||
gpu_src_loss = tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
|
gpu_src_loss = tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
|
||||||
gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
|
gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
|
||||||
|
@ -393,20 +431,14 @@ class AMPModel(ModelBase):
|
||||||
gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
|
gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
|
||||||
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )
|
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )
|
||||||
|
|
||||||
# dst-dst background weak loss
|
|
||||||
gpu_dst_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] )
|
|
||||||
gpu_dst_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked)
|
|
||||||
|
|
||||||
gpu_src_losses += [gpu_src_loss]
|
gpu_src_losses += [gpu_src_loss]
|
||||||
gpu_dst_losses += [gpu_dst_loss]
|
gpu_dst_losses += [gpu_dst_loss]
|
||||||
gpu_G_loss = gpu_src_loss + gpu_dst_loss
|
gpu_G_loss = gpu_src_loss + gpu_dst_loss
|
||||||
|
# dst-dst background weak loss
|
||||||
def DLossOnes(logits):
|
gpu_G_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] )
|
||||||
return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits), axis=[1,2,3])
|
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked)
|
||||||
|
|
||||||
def DLossZeros(logits):
|
|
||||||
return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1,2,3])
|
|
||||||
|
|
||||||
if gan_power != 0:
|
if gan_power != 0:
|
||||||
gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked)
|
gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked)
|
||||||
gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked)
|
gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked)
|
||||||
|
@ -419,17 +451,17 @@ class AMPModel(ModelBase):
|
||||||
DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2)
|
DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2)
|
||||||
) * (1.0 / 8)
|
) * (1.0 / 8)
|
||||||
|
|
||||||
gpu_GAN_loss_grads += [ nn.gradients (gpu_GAN_loss, self.GAN.get_weights() ) ]
|
gpu_GAN_loss_gradients += [ nn.gradients (gpu_GAN_loss, self.GAN.get_weights() ) ]
|
||||||
|
|
||||||
gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) + \
|
gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) + \
|
||||||
DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2)
|
DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2)
|
||||||
) * gan_power
|
) * gan_power
|
||||||
|
|
||||||
# Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
|
# Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
|
||||||
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
|
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
|
||||||
gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] )
|
gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] )
|
||||||
|
|
||||||
gpu_G_loss_gradients += [ nn.gradients ( gpu_G_loss, self.encoder.get_weights() + self.decoder.get_weights() ) ]
|
gpu_G_loss_gradients += [ nn.gradients ( gpu_G_loss, self.G_weights ) ]
|
||||||
|
|
||||||
# Average losses and gradients, and create optimizer update ops
|
# Average losses and gradients, and create optimizer update ops
|
||||||
with tf.device(f'/CPU:0'):
|
with tf.device(f'/CPU:0'):
|
||||||
|
@ -444,9 +476,9 @@ class AMPModel(ModelBase):
|
||||||
src_loss = tf.concat(gpu_src_losses, 0)
|
src_loss = tf.concat(gpu_src_losses, 0)
|
||||||
dst_loss = tf.concat(gpu_dst_losses, 0)
|
dst_loss = tf.concat(gpu_dst_losses, 0)
|
||||||
train_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gradients))
|
train_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gradients))
|
||||||
|
|
||||||
if gan_power != 0:
|
if gan_power != 0:
|
||||||
GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_grads) )
|
GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_gradients) )
|
||||||
|
|
||||||
# Initializing training and view functions
|
# Initializing training and view functions
|
||||||
def train(warped_src, target_src, target_srcm, target_srcm_em, \
|
def train(warped_src, target_src, target_srcm, target_srcm_em, \
|
||||||
|
@ -520,11 +552,13 @@ class AMPModel(ModelBase):
|
||||||
|
|
||||||
random_ct_samples_path=training_data_dst_path if ct_mode is not None else None #and not self.pretrain
|
random_ct_samples_path=training_data_dst_path if ct_mode is not None else None #and not self.pretrain
|
||||||
|
|
||||||
cpu_count = min(multiprocessing.cpu_count(), 8)
|
cpu_count = multiprocessing.cpu_count()
|
||||||
src_generators_count = cpu_count // 2
|
src_generators_count = cpu_count // 2
|
||||||
dst_generators_count = cpu_count // 2
|
dst_generators_count = cpu_count // 2
|
||||||
if ct_mode is not None:
|
if ct_mode is not None:
|
||||||
src_generators_count = int(src_generators_count * 1.5)
|
src_generators_count = int(src_generators_count * 1.5)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
self.set_training_data_generators ([
|
self.set_training_data_generators ([
|
||||||
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
|
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
|
||||||
|
@ -541,6 +575,7 @@ class AMPModel(ModelBase):
|
||||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_dst_flip),
|
sample_process_options=SampleProcessor.Options(random_flip=self.random_dst_flip),
|
||||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||||
|
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'denoise_filter' : rtm_dst_denoise, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||||
],
|
],
|
||||||
|
@ -616,13 +651,13 @@ class AMPModel(ModelBase):
|
||||||
bs = self.get_batch_size()
|
bs = self.get_batch_size()
|
||||||
|
|
||||||
( (warped_src, target_src, target_srcm, target_srcm_em), \
|
( (warped_src, target_src, target_srcm, target_srcm_em), \
|
||||||
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()
|
(warped_dst, target_dst, target_dst_train, target_dstm, target_dstm_em) ) = self.generate_next_samples()
|
||||||
|
|
||||||
src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
|
src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst_train, target_dstm, target_dstm_em)
|
||||||
|
|
||||||
for i in range(bs):
|
for i in range(bs):
|
||||||
self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i]) )
|
self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i]) )
|
||||||
self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i]) )
|
self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dst_train[i], target_dstm[i], target_dstm_em[i]) )
|
||||||
|
|
||||||
if len(self.last_src_samples_loss) >= bs*16:
|
if len(self.last_src_samples_loss) >= bs*16:
|
||||||
src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
|
src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
|
||||||
|
@ -633,22 +668,23 @@ class AMPModel(ModelBase):
|
||||||
target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
|
target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
|
||||||
|
|
||||||
target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
|
target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
|
||||||
target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
|
target_dst_train = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
|
||||||
target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
|
target_dstm = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
|
||||||
|
target_dstm_em = np.stack( [ x[4] for x in dst_samples_loss[:bs] ] )
|
||||||
|
|
||||||
src_loss, dst_loss = self.train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
|
src_loss, dst_loss = self.train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst_train, target_dstm, target_dstm_em)
|
||||||
self.last_src_samples_loss = []
|
self.last_src_samples_loss = []
|
||||||
self.last_dst_samples_loss = []
|
self.last_dst_samples_loss = []
|
||||||
|
|
||||||
if self.gan_power != 0:
|
if self.gan_power != 0:
|
||||||
self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
|
self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst_train, target_dstm, target_dstm_em)
|
||||||
|
|
||||||
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
|
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onGetPreview(self, samples, for_history=False):
|
def onGetPreview(self, samples, for_history=False):
|
||||||
( (warped_src, target_src, target_srcm, target_srcm_em),
|
( (warped_src, target_src, target_srcm, target_srcm_em),
|
||||||
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples
|
(warped_dst, target_dst, target_dst_train, target_dstm, target_dstm_em) ) = samples
|
||||||
|
|
||||||
S, D, SS, DD, DDM_000, _, _ = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst, 0.0) ) ]
|
S, D, SS, DD, DDM_000, _, _ = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst, 0.0) ) ]
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue