diff --git a/core/leras/archis/DeepFakeArchi.py b/core/leras/archis/DeepFakeArchi.py index 38b3f87..5dfd293 100644 --- a/core/leras/archis/DeepFakeArchi.py +++ b/core/leras/archis/DeepFakeArchi.py @@ -7,6 +7,10 @@ class DeepFakeArchi(nn.ArchiBase): mod None - default 'quick' + + opts '' + '' + 't' """ def __init__(self, resolution, use_fp16=False, mod=None, opts=None): super().__init__() @@ -16,7 +20,7 @@ class DeepFakeArchi(nn.ArchiBase): conv_dtype = tf.float16 if use_fp16 else tf.float32 - + if mod is None: class Downscale(nn.ModelBase): def __init__(self, in_ch, out_ch, kernel_size=5, *kwargs ): @@ -79,21 +83,44 @@ class DeepFakeArchi(nn.ArchiBase): self.in_ch = in_ch self.e_ch = e_ch super().__init__(**kwargs) - - def on_build(self): - self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4, kernel_size=5) + + def on_build(self): + if 't' in opts: + self.down1 = Downscale(self.in_ch, self.e_ch, kernel_size=5) + self.res1 = ResidualBlock(self.e_ch) + self.down2 = Downscale(self.e_ch, self.e_ch*2, kernel_size=5) + self.down3 = Downscale(self.e_ch*2, self.e_ch*4, kernel_size=5) + self.down4 = Downscale(self.e_ch*4, self.e_ch*8, kernel_size=5) + self.down5 = Downscale(self.e_ch*8, self.e_ch*8, kernel_size=5) + self.res5 = ResidualBlock(self.e_ch*8) + else: + self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4 if 't' not in opts else 5, kernel_size=5) def forward(self, x): if use_fp16: x = tf.cast(x, tf.float16) - x = nn.flatten(self.down1(x)) + + if 't' in opts: + x = self.down1(x) + x = self.res1(x) + x = self.down2(x) + x = self.down3(x) + x = self.down4(x) + x = self.down5(x) + x = self.res5(x) + else: + x = self.down1(x) + x = nn.flatten(x) + if 'u' in opts: + x = nn.pixel_norm(x, axes=-1) + if use_fp16: x = tf.cast(x, tf.float32) return x - + def get_out_res(self, res): - return res // (2**4) - + return res // ( (2**4) if 't' not in opts else (2**5) ) + def get_out_ch(self): return self.e_ch * 8 @@ -106,59 +133,83 @@ class DeepFakeArchi(nn.ArchiBase): def on_build(self): in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch - - if 'u' in opts: - self.dense_norm = nn.DenseNorm() - + self.dense1 = nn.Dense( in_ch, ae_ch ) self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch ) - self.upscale1 = Upscale(ae_out_ch, ae_out_ch) + if 't' not in opts: + self.upscale1 = Upscale(ae_out_ch, ae_out_ch) def forward(self, inp): x = inp - if 'u' in opts: - x = self.dense_norm(x) x = self.dense1(x) x = self.dense2(x) x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch) - + if use_fp16: x = tf.cast(x, tf.float16) - x = self.upscale1(x) + + if 't' not in opts: + x = self.upscale1(x) + return x def get_out_res(self): - return lowest_dense_res * 2 + return lowest_dense_res * 2 if 't' not in opts else lowest_dense_res def get_out_ch(self): return self.ae_out_ch class Decoder(nn.ModelBase): - def on_build(self, in_ch, d_ch, d_mask_ch): - self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3) - self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3) - self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3) + def on_build(self, in_ch, d_ch, d_mask_ch): + if 't' not in opts: + self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3) + self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3) + self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3) + self.res0 = ResidualBlock(d_ch*8, kernel_size=3) + self.res1 = ResidualBlock(d_ch*4, kernel_size=3) + self.res2 = ResidualBlock(d_ch*2, kernel_size=3) - self.res0 = ResidualBlock(d_ch*8, kernel_size=3) - self.res1 = ResidualBlock(d_ch*4, kernel_size=3) - self.res2 = ResidualBlock(d_ch*2, kernel_size=3) + self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) + self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3) + self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3) - self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype) + self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype) - self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) - self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3) - self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3) - self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) - - if 'd' in opts: - self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) - self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) - self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) - self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3) - self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) + if 'd' in opts: + self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3) + self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) + else: + self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) else: - self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) + self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3) + self.upscale1 = Upscale(d_ch*8, d_ch*8, kernel_size=3) + self.upscale2 = Upscale(d_ch*8, d_ch*4, kernel_size=3) + self.upscale3 = Upscale(d_ch*4, d_ch*2, kernel_size=3) + self.res0 = ResidualBlock(d_ch*8, kernel_size=3) + self.res1 = ResidualBlock(d_ch*8, kernel_size=3) + self.res2 = ResidualBlock(d_ch*4, kernel_size=3) + self.res3 = ResidualBlock(d_ch*2, kernel_size=3) + self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) + self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*8, kernel_size=3) + self.upscalem2 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3) + self.upscalem3 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3) + self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype) + + if 'd' in opts: + self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.upscalem4 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3) + self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) + else: + self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) + + + def forward(self, z): x = self.upscale0(z) x = self.res0(x) @@ -167,6 +218,10 @@ class DeepFakeArchi(nn.ArchiBase): x = self.upscale2(x) x = self.res2(x) + if 't' in opts: + x = self.upscale3(x) + x = self.res3(x) + if 'd' in opts: x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x), self.out_conv1(x), @@ -179,16 +234,23 @@ class DeepFakeArchi(nn.ArchiBase): m = self.upscalem0(z) m = self.upscalem1(m) m = self.upscalem2(m) - if 'd' in opts: + + if 't' in opts: m = self.upscalem3(m) + if 'd' in opts: + m = self.upscalem4(m) + else: + if 'd' in opts: + m = self.upscalem3(m) + m = tf.nn.sigmoid(self.out_convm(m)) - + if use_fp16: - x = tf.cast(x, tf.float32) + x = tf.cast(x, tf.float32) m = tf.cast(m, tf.float32) - + return x, m - + self.Encoder = Encoder self.Inter = Inter self.Decoder = Decoder diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 9569159..225b44d 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -35,9 +35,7 @@ class SAEHDModel(ModelBase): default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f') default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True) - archi = self.load_or_def_option('archi', 'liae-ud') - archi = {'dfuhd':'df-u','liaeuhd':'liae-u'}.get(archi, archi) #backward comp - default_archi = self.options['archi'] = archi + default_archi = self.options['archi'] = self.load_or_def_option('archi', 'liae-ud') default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256) default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64) @@ -47,7 +45,6 @@ class SAEHDModel(ModelBase): default_eyes_mouth_prio = self.options['eyes_mouth_prio'] = self.load_or_def_option('eyes_mouth_prio', False) default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False) default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False) - default_dst_denoise = self.options['dst_denoise'] = self.load_or_def_option('dst_denoise', False) default_adabelief = self.options['adabelief'] = self.load_or_def_option('adabelief', True) @@ -107,7 +104,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... if archi_opts is not None: if len(archi_opts) == 0: continue - if len([ 1 for opt in archi_opts if opt not in ['u','d'] ]) != 0: + if len([ 1 for opt in archi_opts if opt not in ['u','d','t'] ]) != 0: continue if 'd' in archi_opts: @@ -141,7 +138,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... self.options['eyes_mouth_prio'] = io.input_bool ("Eyes and mouth priority", default_eyes_mouth_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction. Also makes the detail of the teeth higher.') self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.') self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. The exact xseg mask in src and dst faceset is required.') - self.options['dst_denoise'] = io.input_bool ("Denoise DST faceset.", default_dst_denoise, help_message='Used in RTM(ReadyToMerge) training with RTM DST faceset. Removes high frequency noise keeping edges. Result is better face syncronization with any face. Can be enabled at any time.') default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0) default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8) @@ -233,7 +229,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... random_src_flip = self.random_src_flip if not self.pretrain else True random_dst_flip = self.random_dst_flip if not self.pretrain else True blur_out_mask = self.options['blur_out_mask'] - dst_denoise = self.options['dst_denoise'] + learn_dst_bg = False#True if self.pretrain: self.options_show_override['gan_power'] = 0.0 @@ -327,8 +323,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... elif 'liae' in archi_type: self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights() - - self.src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt') self.src_dst_opt.initialize_variables (self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu') self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] @@ -413,6 +407,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) + gpu_pred_dst_dst_no_code_grad, _ = self.decoder(tf.stop_gradient(gpu_dst_code)) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) gpu_pred_src_src_list.append(gpu_pred_src_src) @@ -425,25 +420,30 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) ) gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2 + gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) ) gpu_target_dstm_style_blur = gpu_target_dstm_blur #default style mask is 0.5 on boundary + gpu_target_dstm_style_anti_blur = 1.0 - gpu_target_dstm_style_blur gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2 + gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur - gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur + gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur gpu_target_dst_style_masked = gpu_target_dst*gpu_target_dstm_style_blur - gpu_target_dst_style_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_style_blur) + gpu_target_dst_style_anti_masked = gpu_target_dst*gpu_target_dstm_style_anti_blur + + gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur + gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur + gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur + gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur - gpu_target_src_anti_masked = gpu_target_src*(1.0-gpu_target_srcm_blur) gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst - gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src - gpu_pred_src_src_anti_masked = gpu_pred_src_src*(1.0-gpu_target_srcm_blur) gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst gpu_psd_target_dst_style_masked = gpu_pred_src_dst*gpu_target_dstm_style_blur - gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_style_blur) + gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*gpu_target_dstm_style_anti_blur if resolution < 256: gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) @@ -483,6 +483,9 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... gpu_G_loss = gpu_src_loss + gpu_dst_loss + if learn_dst_bg and masked_training and 'liae' in archi_type: + gpu_G_loss += tf.reduce_mean( tf.square(gpu_pred_dst_dst_no_code_grad*gpu_target_dstm_anti_blur-gpu_target_dst_anti_masked),axis=[1,2,3] ) + def DLoss(labels,logits): return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3]) @@ -526,14 +529,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + \ DLoss(gpu_pred_src_src_d2_ones, gpu_pred_src_src_d2)) - - if masked_training: # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src) gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] ) - gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ] + gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights )] + + # Average losses and gradients, and create optimizer update ops @@ -560,7 +563,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... # Initializing training and view functions def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \ warped_dst, target_dst, target_dstm, target_dstm_em, ): - s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op], + s, d = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src, self.target_src :target_src, self.target_srcm:target_srcm, @@ -569,7 +572,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... self.target_dst :target_dst, self.target_dstm:target_dstm, self.target_dstm_em:target_dstm_em, - }) + })[:2] return s, d self.src_dst_train = src_dst_train @@ -674,7 +677,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... sample_process_options=SampleProcessor.Options(random_flip=random_dst_flip), output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, - {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'denoise_filter' : dst_denoise, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, ], @@ -762,13 +764,13 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... bs = self.get_batch_size() ( (warped_src, target_src, target_srcm, target_srcm_em), \ - (warped_dst, target_dst, target_dst_train, target_dstm, target_dstm_em) ) = self.generate_next_samples() + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples() - src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst_train, target_dstm, target_dstm_em) + src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) for i in range(bs): self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i],) ) - self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dst_train[i], target_dstm[i], target_dstm_em[i],) ) + self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i],) ) if len(self.last_src_samples_loss) >= bs*16: src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True) @@ -779,11 +781,10 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] ) target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] ) - target_dst_train = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] ) - target_dstm = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] ) - target_dstm_em = np.stack( [ x[4] for x in dst_samples_loss[:bs] ] ) + target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] ) + target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] ) - src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst_train, target_dstm, target_dstm_em) + src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em) self.last_src_samples_loss = [] self.last_dst_samples_loss = [] @@ -791,14 +792,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... self.D_train (warped_src, warped_dst) if self.gan_power != 0: - self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst_train, target_dstm, target_dstm_em) + self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) #override def onGetPreview(self, samples, for_history=False): ( (warped_src, target_src, target_srcm, target_srcm_em), - (warped_dst, target_dst, target_dst_train, target_dstm, target_dstm_em) ) = samples + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ] DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ]