mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
SAEHD: removed 'dst_denoise' option. Added -t archi option.
This commit is contained in:
parent
01f1a084b4
commit
6e094d873d
2 changed files with 136 additions and 73 deletions
|
@ -7,6 +7,10 @@ class DeepFakeArchi(nn.ArchiBase):
|
||||||
|
|
||||||
mod None - default
|
mod None - default
|
||||||
'quick'
|
'quick'
|
||||||
|
|
||||||
|
opts ''
|
||||||
|
''
|
||||||
|
't'
|
||||||
"""
|
"""
|
||||||
def __init__(self, resolution, use_fp16=False, mod=None, opts=None):
|
def __init__(self, resolution, use_fp16=False, mod=None, opts=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -81,18 +85,41 @@ class DeepFakeArchi(nn.ArchiBase):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def on_build(self):
|
def on_build(self):
|
||||||
self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4, kernel_size=5)
|
if 't' in opts:
|
||||||
|
self.down1 = Downscale(self.in_ch, self.e_ch, kernel_size=5)
|
||||||
|
self.res1 = ResidualBlock(self.e_ch)
|
||||||
|
self.down2 = Downscale(self.e_ch, self.e_ch*2, kernel_size=5)
|
||||||
|
self.down3 = Downscale(self.e_ch*2, self.e_ch*4, kernel_size=5)
|
||||||
|
self.down4 = Downscale(self.e_ch*4, self.e_ch*8, kernel_size=5)
|
||||||
|
self.down5 = Downscale(self.e_ch*8, self.e_ch*8, kernel_size=5)
|
||||||
|
self.res5 = ResidualBlock(self.e_ch*8)
|
||||||
|
else:
|
||||||
|
self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4 if 't' not in opts else 5, kernel_size=5)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if use_fp16:
|
if use_fp16:
|
||||||
x = tf.cast(x, tf.float16)
|
x = tf.cast(x, tf.float16)
|
||||||
x = nn.flatten(self.down1(x))
|
|
||||||
|
if 't' in opts:
|
||||||
|
x = self.down1(x)
|
||||||
|
x = self.res1(x)
|
||||||
|
x = self.down2(x)
|
||||||
|
x = self.down3(x)
|
||||||
|
x = self.down4(x)
|
||||||
|
x = self.down5(x)
|
||||||
|
x = self.res5(x)
|
||||||
|
else:
|
||||||
|
x = self.down1(x)
|
||||||
|
x = nn.flatten(x)
|
||||||
|
if 'u' in opts:
|
||||||
|
x = nn.pixel_norm(x, axes=-1)
|
||||||
|
|
||||||
if use_fp16:
|
if use_fp16:
|
||||||
x = tf.cast(x, tf.float32)
|
x = tf.cast(x, tf.float32)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def get_out_res(self, res):
|
def get_out_res(self, res):
|
||||||
return res // (2**4)
|
return res // ( (2**4) if 't' not in opts else (2**5) )
|
||||||
|
|
||||||
def get_out_ch(self):
|
def get_out_ch(self):
|
||||||
return self.e_ch * 8
|
return self.e_ch * 8
|
||||||
|
@ -107,48 +134,46 @@ class DeepFakeArchi(nn.ArchiBase):
|
||||||
def on_build(self):
|
def on_build(self):
|
||||||
in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch
|
in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch
|
||||||
|
|
||||||
if 'u' in opts:
|
|
||||||
self.dense_norm = nn.DenseNorm()
|
|
||||||
|
|
||||||
self.dense1 = nn.Dense( in_ch, ae_ch )
|
self.dense1 = nn.Dense( in_ch, ae_ch )
|
||||||
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
|
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
|
||||||
|
if 't' not in opts:
|
||||||
self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
|
self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
|
||||||
|
|
||||||
def forward(self, inp):
|
def forward(self, inp):
|
||||||
x = inp
|
x = inp
|
||||||
if 'u' in opts:
|
|
||||||
x = self.dense_norm(x)
|
|
||||||
x = self.dense1(x)
|
x = self.dense1(x)
|
||||||
x = self.dense2(x)
|
x = self.dense2(x)
|
||||||
x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
|
x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
|
||||||
|
|
||||||
if use_fp16:
|
if use_fp16:
|
||||||
x = tf.cast(x, tf.float16)
|
x = tf.cast(x, tf.float16)
|
||||||
|
|
||||||
|
if 't' not in opts:
|
||||||
x = self.upscale1(x)
|
x = self.upscale1(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def get_out_res(self):
|
def get_out_res(self):
|
||||||
return lowest_dense_res * 2
|
return lowest_dense_res * 2 if 't' not in opts else lowest_dense_res
|
||||||
|
|
||||||
def get_out_ch(self):
|
def get_out_ch(self):
|
||||||
return self.ae_out_ch
|
return self.ae_out_ch
|
||||||
|
|
||||||
class Decoder(nn.ModelBase):
|
class Decoder(nn.ModelBase):
|
||||||
def on_build(self, in_ch, d_ch, d_mask_ch):
|
def on_build(self, in_ch, d_ch, d_mask_ch):
|
||||||
|
if 't' not in opts:
|
||||||
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
|
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
|
||||||
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
|
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
|
||||||
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
|
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
|
||||||
|
|
||||||
self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
|
self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
|
||||||
self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
|
self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
|
||||||
self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
|
self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
|
||||||
|
|
||||||
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
|
||||||
|
|
||||||
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
|
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
|
||||||
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
|
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
|
||||||
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
|
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
|
||||||
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
|
||||||
|
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||||
|
|
||||||
if 'd' in opts:
|
if 'd' in opts:
|
||||||
self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||||
|
@ -158,6 +183,32 @@ class DeepFakeArchi(nn.ArchiBase):
|
||||||
self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||||
else:
|
else:
|
||||||
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||||
|
else:
|
||||||
|
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
|
||||||
|
self.upscale1 = Upscale(d_ch*8, d_ch*8, kernel_size=3)
|
||||||
|
self.upscale2 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
|
||||||
|
self.upscale3 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
|
||||||
|
self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
|
||||||
|
self.res1 = ResidualBlock(d_ch*8, kernel_size=3)
|
||||||
|
self.res2 = ResidualBlock(d_ch*4, kernel_size=3)
|
||||||
|
self.res3 = ResidualBlock(d_ch*2, kernel_size=3)
|
||||||
|
|
||||||
|
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
|
||||||
|
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*8, kernel_size=3)
|
||||||
|
self.upscalem2 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
|
||||||
|
self.upscalem3 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
|
||||||
|
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||||
|
|
||||||
|
if 'd' in opts:
|
||||||
|
self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||||
|
self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||||
|
self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||||
|
self.upscalem4 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
|
||||||
|
self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||||
|
else:
|
||||||
|
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, z):
|
def forward(self, z):
|
||||||
x = self.upscale0(z)
|
x = self.upscale0(z)
|
||||||
|
@ -167,6 +218,10 @@ class DeepFakeArchi(nn.ArchiBase):
|
||||||
x = self.upscale2(x)
|
x = self.upscale2(x)
|
||||||
x = self.res2(x)
|
x = self.res2(x)
|
||||||
|
|
||||||
|
if 't' in opts:
|
||||||
|
x = self.upscale3(x)
|
||||||
|
x = self.res3(x)
|
||||||
|
|
||||||
if 'd' in opts:
|
if 'd' in opts:
|
||||||
x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x),
|
x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x),
|
||||||
self.out_conv1(x),
|
self.out_conv1(x),
|
||||||
|
@ -179,8 +234,15 @@ class DeepFakeArchi(nn.ArchiBase):
|
||||||
m = self.upscalem0(z)
|
m = self.upscalem0(z)
|
||||||
m = self.upscalem1(m)
|
m = self.upscalem1(m)
|
||||||
m = self.upscalem2(m)
|
m = self.upscalem2(m)
|
||||||
|
|
||||||
|
if 't' in opts:
|
||||||
|
m = self.upscalem3(m)
|
||||||
|
if 'd' in opts:
|
||||||
|
m = self.upscalem4(m)
|
||||||
|
else:
|
||||||
if 'd' in opts:
|
if 'd' in opts:
|
||||||
m = self.upscalem3(m)
|
m = self.upscalem3(m)
|
||||||
|
|
||||||
m = tf.nn.sigmoid(self.out_convm(m))
|
m = tf.nn.sigmoid(self.out_convm(m))
|
||||||
|
|
||||||
if use_fp16:
|
if use_fp16:
|
||||||
|
|
|
@ -35,9 +35,7 @@ class SAEHDModel(ModelBase):
|
||||||
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f')
|
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f')
|
||||||
default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
|
default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
|
||||||
|
|
||||||
archi = self.load_or_def_option('archi', 'liae-ud')
|
default_archi = self.options['archi'] = self.load_or_def_option('archi', 'liae-ud')
|
||||||
archi = {'dfuhd':'df-u','liaeuhd':'liae-u'}.get(archi, archi) #backward comp
|
|
||||||
default_archi = self.options['archi'] = archi
|
|
||||||
|
|
||||||
default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
|
default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
|
||||||
default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
|
default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
|
||||||
|
@ -47,7 +45,6 @@ class SAEHDModel(ModelBase):
|
||||||
default_eyes_mouth_prio = self.options['eyes_mouth_prio'] = self.load_or_def_option('eyes_mouth_prio', False)
|
default_eyes_mouth_prio = self.options['eyes_mouth_prio'] = self.load_or_def_option('eyes_mouth_prio', False)
|
||||||
default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False)
|
default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False)
|
||||||
default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False)
|
default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False)
|
||||||
default_dst_denoise = self.options['dst_denoise'] = self.load_or_def_option('dst_denoise', False)
|
|
||||||
|
|
||||||
default_adabelief = self.options['adabelief'] = self.load_or_def_option('adabelief', True)
|
default_adabelief = self.options['adabelief'] = self.load_or_def_option('adabelief', True)
|
||||||
|
|
||||||
|
@ -107,7 +104,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
if archi_opts is not None:
|
if archi_opts is not None:
|
||||||
if len(archi_opts) == 0:
|
if len(archi_opts) == 0:
|
||||||
continue
|
continue
|
||||||
if len([ 1 for opt in archi_opts if opt not in ['u','d'] ]) != 0:
|
if len([ 1 for opt in archi_opts if opt not in ['u','d','t'] ]) != 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if 'd' in archi_opts:
|
if 'd' in archi_opts:
|
||||||
|
@ -141,7 +138,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
self.options['eyes_mouth_prio'] = io.input_bool ("Eyes and mouth priority", default_eyes_mouth_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction. Also makes the detail of the teeth higher.')
|
self.options['eyes_mouth_prio'] = io.input_bool ("Eyes and mouth priority", default_eyes_mouth_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction. Also makes the detail of the teeth higher.')
|
||||||
self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')
|
self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')
|
||||||
self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. The exact xseg mask in src and dst faceset is required.')
|
self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. The exact xseg mask in src and dst faceset is required.')
|
||||||
self.options['dst_denoise'] = io.input_bool ("Denoise DST faceset.", default_dst_denoise, help_message='Used in RTM(ReadyToMerge) training with RTM DST faceset. Removes high frequency noise keeping edges. Result is better face syncronization with any face. Can be enabled at any time.')
|
|
||||||
|
|
||||||
default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
|
default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
|
||||||
default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8)
|
default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8)
|
||||||
|
@ -233,7 +229,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
random_src_flip = self.random_src_flip if not self.pretrain else True
|
random_src_flip = self.random_src_flip if not self.pretrain else True
|
||||||
random_dst_flip = self.random_dst_flip if not self.pretrain else True
|
random_dst_flip = self.random_dst_flip if not self.pretrain else True
|
||||||
blur_out_mask = self.options['blur_out_mask']
|
blur_out_mask = self.options['blur_out_mask']
|
||||||
dst_denoise = self.options['dst_denoise']
|
learn_dst_bg = False#True
|
||||||
|
|
||||||
if self.pretrain:
|
if self.pretrain:
|
||||||
self.options_show_override['gan_power'] = 0.0
|
self.options_show_override['gan_power'] = 0.0
|
||||||
|
@ -327,8 +323,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
elif 'liae' in archi_type:
|
elif 'liae' in archi_type:
|
||||||
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()
|
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
self.src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
|
self.src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
|
||||||
self.src_dst_opt.initialize_variables (self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')
|
self.src_dst_opt.initialize_variables (self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')
|
||||||
self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
|
self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
|
||||||
|
@ -413,6 +407,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
|
|
||||||
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
|
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
|
||||||
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
|
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
|
||||||
|
gpu_pred_dst_dst_no_code_grad, _ = self.decoder(tf.stop_gradient(gpu_dst_code))
|
||||||
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
|
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
|
||||||
|
|
||||||
gpu_pred_src_src_list.append(gpu_pred_src_src)
|
gpu_pred_src_src_list.append(gpu_pred_src_src)
|
||||||
|
@ -425,25 +420,30 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
|
|
||||||
gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
|
gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
|
||||||
gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2
|
gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2
|
||||||
|
gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur
|
||||||
|
|
||||||
gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )
|
gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )
|
||||||
gpu_target_dstm_style_blur = gpu_target_dstm_blur #default style mask is 0.5 on boundary
|
gpu_target_dstm_style_blur = gpu_target_dstm_blur #default style mask is 0.5 on boundary
|
||||||
|
gpu_target_dstm_style_anti_blur = 1.0 - gpu_target_dstm_style_blur
|
||||||
gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2
|
gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2
|
||||||
|
gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur
|
||||||
|
|
||||||
gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
|
gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
|
||||||
gpu_target_dst_style_masked = gpu_target_dst*gpu_target_dstm_style_blur
|
gpu_target_dst_style_masked = gpu_target_dst*gpu_target_dstm_style_blur
|
||||||
gpu_target_dst_style_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_style_blur)
|
gpu_target_dst_style_anti_masked = gpu_target_dst*gpu_target_dstm_style_anti_blur
|
||||||
|
|
||||||
|
gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur
|
||||||
|
gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur
|
||||||
|
gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur
|
||||||
|
gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur
|
||||||
|
|
||||||
gpu_target_src_anti_masked = gpu_target_src*(1.0-gpu_target_srcm_blur)
|
|
||||||
gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
|
gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
|
||||||
gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst
|
gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst
|
||||||
|
|
||||||
gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
|
gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
|
||||||
gpu_pred_src_src_anti_masked = gpu_pred_src_src*(1.0-gpu_target_srcm_blur)
|
|
||||||
gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst
|
gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst
|
||||||
|
|
||||||
gpu_psd_target_dst_style_masked = gpu_pred_src_dst*gpu_target_dstm_style_blur
|
gpu_psd_target_dst_style_masked = gpu_pred_src_dst*gpu_target_dstm_style_blur
|
||||||
gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_style_blur)
|
gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*gpu_target_dstm_style_anti_blur
|
||||||
|
|
||||||
if resolution < 256:
|
if resolution < 256:
|
||||||
gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
|
gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
|
||||||
|
@ -483,6 +483,9 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
|
|
||||||
gpu_G_loss = gpu_src_loss + gpu_dst_loss
|
gpu_G_loss = gpu_src_loss + gpu_dst_loss
|
||||||
|
|
||||||
|
if learn_dst_bg and masked_training and 'liae' in archi_type:
|
||||||
|
gpu_G_loss += tf.reduce_mean( tf.square(gpu_pred_dst_dst_no_code_grad*gpu_target_dstm_anti_blur-gpu_target_dst_anti_masked),axis=[1,2,3] )
|
||||||
|
|
||||||
def DLoss(labels,logits):
|
def DLoss(labels,logits):
|
||||||
return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3])
|
return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3])
|
||||||
|
|
||||||
|
@ -526,8 +529,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + \
|
gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + \
|
||||||
DLoss(gpu_pred_src_src_d2_ones, gpu_pred_src_src_d2))
|
DLoss(gpu_pred_src_src_d2_ones, gpu_pred_src_src_d2))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if masked_training:
|
if masked_training:
|
||||||
# Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
|
# Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
|
||||||
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
|
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
|
||||||
|
@ -536,6 +537,8 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights )]
|
gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights )]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Average losses and gradients, and create optimizer update ops
|
# Average losses and gradients, and create optimizer update ops
|
||||||
with tf.device(f'/CPU:0'):
|
with tf.device(f'/CPU:0'):
|
||||||
pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
|
pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
|
||||||
|
@ -560,7 +563,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
# Initializing training and view functions
|
# Initializing training and view functions
|
||||||
def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \
|
def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \
|
||||||
warped_dst, target_dst, target_dstm, target_dstm_em, ):
|
warped_dst, target_dst, target_dstm, target_dstm_em, ):
|
||||||
s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
|
s, d = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
|
||||||
feed_dict={self.warped_src :warped_src,
|
feed_dict={self.warped_src :warped_src,
|
||||||
self.target_src :target_src,
|
self.target_src :target_src,
|
||||||
self.target_srcm:target_srcm,
|
self.target_srcm:target_srcm,
|
||||||
|
@ -569,7 +572,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
self.target_dst :target_dst,
|
self.target_dst :target_dst,
|
||||||
self.target_dstm:target_dstm,
|
self.target_dstm:target_dstm,
|
||||||
self.target_dstm_em:target_dstm_em,
|
self.target_dstm_em:target_dstm_em,
|
||||||
})
|
})[:2]
|
||||||
return s, d
|
return s, d
|
||||||
self.src_dst_train = src_dst_train
|
self.src_dst_train = src_dst_train
|
||||||
|
|
||||||
|
@ -674,7 +677,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
sample_process_options=SampleProcessor.Options(random_flip=random_dst_flip),
|
sample_process_options=SampleProcessor.Options(random_flip=random_dst_flip),
|
||||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'denoise_filter' : dst_denoise, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
|
||||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||||
],
|
],
|
||||||
|
@ -762,13 +764,13 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
bs = self.get_batch_size()
|
bs = self.get_batch_size()
|
||||||
|
|
||||||
( (warped_src, target_src, target_srcm, target_srcm_em), \
|
( (warped_src, target_src, target_srcm, target_srcm_em), \
|
||||||
(warped_dst, target_dst, target_dst_train, target_dstm, target_dstm_em) ) = self.generate_next_samples()
|
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()
|
||||||
|
|
||||||
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst_train, target_dstm, target_dstm_em)
|
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
|
||||||
|
|
||||||
for i in range(bs):
|
for i in range(bs):
|
||||||
self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i],) )
|
self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i],) )
|
||||||
self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dst_train[i], target_dstm[i], target_dstm_em[i],) )
|
self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i],) )
|
||||||
|
|
||||||
if len(self.last_src_samples_loss) >= bs*16:
|
if len(self.last_src_samples_loss) >= bs*16:
|
||||||
src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
|
src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
|
||||||
|
@ -779,11 +781,10 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
|
target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
|
||||||
|
|
||||||
target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
|
target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
|
||||||
target_dst_train = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
|
target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
|
||||||
target_dstm = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
|
target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
|
||||||
target_dstm_em = np.stack( [ x[4] for x in dst_samples_loss[:bs] ] )
|
|
||||||
|
|
||||||
src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst_train, target_dstm, target_dstm_em)
|
src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
|
||||||
self.last_src_samples_loss = []
|
self.last_src_samples_loss = []
|
||||||
self.last_dst_samples_loss = []
|
self.last_dst_samples_loss = []
|
||||||
|
|
||||||
|
@ -791,14 +792,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
self.D_train (warped_src, warped_dst)
|
self.D_train (warped_src, warped_dst)
|
||||||
|
|
||||||
if self.gan_power != 0:
|
if self.gan_power != 0:
|
||||||
self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst_train, target_dstm, target_dstm_em)
|
self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
|
||||||
|
|
||||||
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
|
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onGetPreview(self, samples, for_history=False):
|
def onGetPreview(self, samples, for_history=False):
|
||||||
( (warped_src, target_src, target_srcm, target_srcm_em),
|
( (warped_src, target_src, target_srcm, target_srcm_em),
|
||||||
(warped_dst, target_dst, target_dst_train, target_dstm, target_dstm_em) ) = samples
|
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples
|
||||||
|
|
||||||
S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
|
S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
|
||||||
DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ]
|
DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue