mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
SAEHD: removed 'dst_denoise' option. Added -t arhi option.
This commit is contained in:
parent
01f1a084b4
commit
6e094d873d
2 changed files with 136 additions and 73 deletions
|
@ -7,6 +7,10 @@ class DeepFakeArchi(nn.ArchiBase):
|
|||
|
||||
mod None - default
|
||||
'quick'
|
||||
|
||||
opts ''
|
||||
''
|
||||
't'
|
||||
"""
|
||||
def __init__(self, resolution, use_fp16=False, mod=None, opts=None):
|
||||
super().__init__()
|
||||
|
@ -81,18 +85,41 @@ class DeepFakeArchi(nn.ArchiBase):
|
|||
super().__init__(**kwargs)
|
||||
|
||||
def on_build(self):
|
||||
self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4, kernel_size=5)
|
||||
if 't' in opts:
|
||||
self.down1 = Downscale(self.in_ch, self.e_ch, kernel_size=5)
|
||||
self.res1 = ResidualBlock(self.e_ch)
|
||||
self.down2 = Downscale(self.e_ch, self.e_ch*2, kernel_size=5)
|
||||
self.down3 = Downscale(self.e_ch*2, self.e_ch*4, kernel_size=5)
|
||||
self.down4 = Downscale(self.e_ch*4, self.e_ch*8, kernel_size=5)
|
||||
self.down5 = Downscale(self.e_ch*8, self.e_ch*8, kernel_size=5)
|
||||
self.res5 = ResidualBlock(self.e_ch*8)
|
||||
else:
|
||||
self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4 if 't' not in opts else 5, kernel_size=5)
|
||||
|
||||
def forward(self, x):
|
||||
if use_fp16:
|
||||
x = tf.cast(x, tf.float16)
|
||||
x = nn.flatten(self.down1(x))
|
||||
|
||||
if 't' in opts:
|
||||
x = self.down1(x)
|
||||
x = self.res1(x)
|
||||
x = self.down2(x)
|
||||
x = self.down3(x)
|
||||
x = self.down4(x)
|
||||
x = self.down5(x)
|
||||
x = self.res5(x)
|
||||
else:
|
||||
x = self.down1(x)
|
||||
x = nn.flatten(x)
|
||||
if 'u' in opts:
|
||||
x = nn.pixel_norm(x, axes=-1)
|
||||
|
||||
if use_fp16:
|
||||
x = tf.cast(x, tf.float32)
|
||||
return x
|
||||
|
||||
def get_out_res(self, res):
|
||||
return res // (2**4)
|
||||
return res // ( (2**4) if 't' not in opts else (2**5) )
|
||||
|
||||
def get_out_ch(self):
|
||||
return self.e_ch * 8
|
||||
|
@ -107,48 +134,46 @@ class DeepFakeArchi(nn.ArchiBase):
|
|||
def on_build(self):
|
||||
in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch
|
||||
|
||||
if 'u' in opts:
|
||||
self.dense_norm = nn.DenseNorm()
|
||||
|
||||
self.dense1 = nn.Dense( in_ch, ae_ch )
|
||||
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
|
||||
if 't' not in opts:
|
||||
self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
|
||||
|
||||
def forward(self, inp):
|
||||
x = inp
|
||||
if 'u' in opts:
|
||||
x = self.dense_norm(x)
|
||||
x = self.dense1(x)
|
||||
x = self.dense2(x)
|
||||
x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
|
||||
|
||||
if use_fp16:
|
||||
x = tf.cast(x, tf.float16)
|
||||
|
||||
if 't' not in opts:
|
||||
x = self.upscale1(x)
|
||||
|
||||
return x
|
||||
|
||||
def get_out_res(self):
|
||||
return lowest_dense_res * 2
|
||||
return lowest_dense_res * 2 if 't' not in opts else lowest_dense_res
|
||||
|
||||
def get_out_ch(self):
|
||||
return self.ae_out_ch
|
||||
|
||||
class Decoder(nn.ModelBase):
|
||||
def on_build(self, in_ch, d_ch, d_mask_ch):
|
||||
if 't' not in opts:
|
||||
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
|
||||
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
|
||||
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
|
||||
|
||||
self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
|
||||
self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
|
||||
self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
|
||||
|
||||
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
|
||||
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
|
||||
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
|
||||
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
|
||||
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
|
||||
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
|
||||
if 'd' in opts:
|
||||
self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
|
@ -158,6 +183,32 @@ class DeepFakeArchi(nn.ArchiBase):
|
|||
self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
else:
|
||||
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
else:
|
||||
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
|
||||
self.upscale1 = Upscale(d_ch*8, d_ch*8, kernel_size=3)
|
||||
self.upscale2 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
|
||||
self.upscale3 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
|
||||
self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
|
||||
self.res1 = ResidualBlock(d_ch*8, kernel_size=3)
|
||||
self.res2 = ResidualBlock(d_ch*4, kernel_size=3)
|
||||
self.res3 = ResidualBlock(d_ch*2, kernel_size=3)
|
||||
|
||||
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
|
||||
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*8, kernel_size=3)
|
||||
self.upscalem2 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
|
||||
self.upscalem3 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
|
||||
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
|
||||
if 'd' in opts:
|
||||
self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
self.upscalem4 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
|
||||
self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
else:
|
||||
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
|
||||
|
||||
|
||||
def forward(self, z):
|
||||
x = self.upscale0(z)
|
||||
|
@ -167,6 +218,10 @@ class DeepFakeArchi(nn.ArchiBase):
|
|||
x = self.upscale2(x)
|
||||
x = self.res2(x)
|
||||
|
||||
if 't' in opts:
|
||||
x = self.upscale3(x)
|
||||
x = self.res3(x)
|
||||
|
||||
if 'd' in opts:
|
||||
x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x),
|
||||
self.out_conv1(x),
|
||||
|
@ -179,8 +234,15 @@ class DeepFakeArchi(nn.ArchiBase):
|
|||
m = self.upscalem0(z)
|
||||
m = self.upscalem1(m)
|
||||
m = self.upscalem2(m)
|
||||
|
||||
if 't' in opts:
|
||||
m = self.upscalem3(m)
|
||||
if 'd' in opts:
|
||||
m = self.upscalem4(m)
|
||||
else:
|
||||
if 'd' in opts:
|
||||
m = self.upscalem3(m)
|
||||
|
||||
m = tf.nn.sigmoid(self.out_convm(m))
|
||||
|
||||
if use_fp16:
|
||||
|
|
|
@ -35,9 +35,7 @@ class SAEHDModel(ModelBase):
|
|||
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f')
|
||||
default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
|
||||
|
||||
archi = self.load_or_def_option('archi', 'liae-ud')
|
||||
archi = {'dfuhd':'df-u','liaeuhd':'liae-u'}.get(archi, archi) #backward comp
|
||||
default_archi = self.options['archi'] = archi
|
||||
default_archi = self.options['archi'] = self.load_or_def_option('archi', 'liae-ud')
|
||||
|
||||
default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
|
||||
default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
|
||||
|
@ -47,7 +45,6 @@ class SAEHDModel(ModelBase):
|
|||
default_eyes_mouth_prio = self.options['eyes_mouth_prio'] = self.load_or_def_option('eyes_mouth_prio', False)
|
||||
default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False)
|
||||
default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False)
|
||||
default_dst_denoise = self.options['dst_denoise'] = self.load_or_def_option('dst_denoise', False)
|
||||
|
||||
default_adabelief = self.options['adabelief'] = self.load_or_def_option('adabelief', True)
|
||||
|
||||
|
@ -107,7 +104,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
if archi_opts is not None:
|
||||
if len(archi_opts) == 0:
|
||||
continue
|
||||
if len([ 1 for opt in archi_opts if opt not in ['u','d'] ]) != 0:
|
||||
if len([ 1 for opt in archi_opts if opt not in ['u','d','t'] ]) != 0:
|
||||
continue
|
||||
|
||||
if 'd' in archi_opts:
|
||||
|
@ -141,7 +138,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
self.options['eyes_mouth_prio'] = io.input_bool ("Eyes and mouth priority", default_eyes_mouth_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction. Also makes the detail of the teeth higher.')
|
||||
self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')
|
||||
self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. The exact xseg mask in src and dst faceset is required.')
|
||||
self.options['dst_denoise'] = io.input_bool ("Denoise DST faceset.", default_dst_denoise, help_message='Used in RTM(ReadyToMerge) training with RTM DST faceset. Removes high frequency noise keeping edges. Result is better face syncronization with any face. Can be enabled at any time.')
|
||||
|
||||
default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
|
||||
default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8)
|
||||
|
@ -233,7 +229,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
random_src_flip = self.random_src_flip if not self.pretrain else True
|
||||
random_dst_flip = self.random_dst_flip if not self.pretrain else True
|
||||
blur_out_mask = self.options['blur_out_mask']
|
||||
dst_denoise = self.options['dst_denoise']
|
||||
learn_dst_bg = False#True
|
||||
|
||||
if self.pretrain:
|
||||
self.options_show_override['gan_power'] = 0.0
|
||||
|
@ -327,8 +323,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
elif 'liae' in archi_type:
|
||||
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()
|
||||
|
||||
|
||||
|
||||
self.src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
|
||||
self.src_dst_opt.initialize_variables (self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')
|
||||
self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
|
||||
|
@ -413,6 +407,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
|
||||
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
|
||||
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
|
||||
gpu_pred_dst_dst_no_code_grad, _ = self.decoder(tf.stop_gradient(gpu_dst_code))
|
||||
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
|
||||
|
||||
gpu_pred_src_src_list.append(gpu_pred_src_src)
|
||||
|
@ -425,25 +420,30 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
|
||||
gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
|
||||
gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2
|
||||
gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur
|
||||
|
||||
gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )
|
||||
gpu_target_dstm_style_blur = gpu_target_dstm_blur #default style mask is 0.5 on boundary
|
||||
gpu_target_dstm_style_anti_blur = 1.0 - gpu_target_dstm_style_blur
|
||||
gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2
|
||||
gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur
|
||||
|
||||
gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
|
||||
gpu_target_dst_style_masked = gpu_target_dst*gpu_target_dstm_style_blur
|
||||
gpu_target_dst_style_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_style_blur)
|
||||
gpu_target_dst_style_anti_masked = gpu_target_dst*gpu_target_dstm_style_anti_blur
|
||||
|
||||
gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur
|
||||
gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur
|
||||
gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur
|
||||
gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur
|
||||
|
||||
gpu_target_src_anti_masked = gpu_target_src*(1.0-gpu_target_srcm_blur)
|
||||
gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
|
||||
gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst
|
||||
|
||||
gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
|
||||
gpu_pred_src_src_anti_masked = gpu_pred_src_src*(1.0-gpu_target_srcm_blur)
|
||||
gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst
|
||||
|
||||
gpu_psd_target_dst_style_masked = gpu_pred_src_dst*gpu_target_dstm_style_blur
|
||||
gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_style_blur)
|
||||
gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*gpu_target_dstm_style_anti_blur
|
||||
|
||||
if resolution < 256:
|
||||
gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
|
||||
|
@ -483,6 +483,9 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
|
||||
gpu_G_loss = gpu_src_loss + gpu_dst_loss
|
||||
|
||||
if learn_dst_bg and masked_training and 'liae' in archi_type:
|
||||
gpu_G_loss += tf.reduce_mean( tf.square(gpu_pred_dst_dst_no_code_grad*gpu_target_dstm_anti_blur-gpu_target_dst_anti_masked),axis=[1,2,3] )
|
||||
|
||||
def DLoss(labels,logits):
|
||||
return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3])
|
||||
|
||||
|
@ -526,8 +529,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + \
|
||||
DLoss(gpu_pred_src_src_d2_ones, gpu_pred_src_src_d2))
|
||||
|
||||
|
||||
|
||||
if masked_training:
|
||||
# Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
|
||||
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
|
||||
|
@ -536,6 +537,8 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights )]
|
||||
|
||||
|
||||
|
||||
|
||||
# Average losses and gradients, and create optimizer update ops
|
||||
with tf.device(f'/CPU:0'):
|
||||
pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
|
||||
|
@ -560,7 +563,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
# Initializing training and view functions
|
||||
def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \
|
||||
warped_dst, target_dst, target_dstm, target_dstm_em, ):
|
||||
s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
|
||||
s, d = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
|
||||
feed_dict={self.warped_src :warped_src,
|
||||
self.target_src :target_src,
|
||||
self.target_srcm:target_srcm,
|
||||
|
@ -569,7 +572,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
self.target_dst :target_dst,
|
||||
self.target_dstm:target_dstm,
|
||||
self.target_dstm_em:target_dstm_em,
|
||||
})
|
||||
})[:2]
|
||||
return s, d
|
||||
self.src_dst_train = src_dst_train
|
||||
|
||||
|
@ -674,7 +677,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
sample_process_options=SampleProcessor.Options(random_flip=random_dst_flip),
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'denoise_filter' : dst_denoise, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
],
|
||||
|
@ -762,13 +764,13 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
bs = self.get_batch_size()
|
||||
|
||||
( (warped_src, target_src, target_srcm, target_srcm_em), \
|
||||
(warped_dst, target_dst, target_dst_train, target_dstm, target_dstm_em) ) = self.generate_next_samples()
|
||||
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()
|
||||
|
||||
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst_train, target_dstm, target_dstm_em)
|
||||
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
|
||||
|
||||
for i in range(bs):
|
||||
self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i],) )
|
||||
self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dst_train[i], target_dstm[i], target_dstm_em[i],) )
|
||||
self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i],) )
|
||||
|
||||
if len(self.last_src_samples_loss) >= bs*16:
|
||||
src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
|
||||
|
@ -779,11 +781,10 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
|
||||
|
||||
target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
|
||||
target_dst_train = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
|
||||
target_dstm = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
|
||||
target_dstm_em = np.stack( [ x[4] for x in dst_samples_loss[:bs] ] )
|
||||
target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
|
||||
target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
|
||||
|
||||
src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst_train, target_dstm, target_dstm_em)
|
||||
src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
|
||||
self.last_src_samples_loss = []
|
||||
self.last_dst_samples_loss = []
|
||||
|
||||
|
@ -791,14 +792,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
self.D_train (warped_src, warped_dst)
|
||||
|
||||
if self.gan_power != 0:
|
||||
self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst_train, target_dstm, target_dstm_em)
|
||||
self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
|
||||
|
||||
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
|
||||
|
||||
#override
|
||||
def onGetPreview(self, samples, for_history=False):
|
||||
( (warped_src, target_src, target_srcm, target_srcm_em),
|
||||
(warped_dst, target_dst, target_dst_train, target_dstm, target_dstm_em) ) = samples
|
||||
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples
|
||||
|
||||
S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
|
||||
DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue