diff --git a/models/Model_AMP/Model.py b/models/Model_AMP/Model.py index 2ac5588..44e047d 100644 --- a/models/Model_AMP/Model.py +++ b/models/Model_AMP/Model.py @@ -122,9 +122,9 @@ class AMPModel(ModelBase): morph_factor = self.options['morph_factor'] gan_power = self.gan_power = self.options['gan_power'] random_warp = self.options['random_warp'] - + blur_out_mask = self.options['blur_out_mask'] - + ct_mode = self.options['ct_mode'] if ct_mode == 'none': ct_mode = None @@ -294,7 +294,10 @@ class AMPModel(ModelBase): clipnorm = 1.0 if self.options['clipgrad'] else 0.0 lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] else 1.0 - self.G_weights = self.encoder.get_weights() + self.inter_src.get_weights() + self.inter_dst.get_weights() + self.decoder.get_weights() + self.G_weights = self.encoder.get_weights() + self.decoder.get_weights() + + if random_warp: + self.G_weights += self.inter_src.get_weights() + self.inter_dst.get_weights() self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt') self.src_dst_opt.initialize_variables (self.G_weights, vars_on_cpu=optimizer_vars_on_cpu) @@ -352,64 +355,64 @@ class AMPModel(ModelBase): gpu_src_inter_src_code, gpu_src_inter_dst_code = self.inter_src (gpu_src_code), self.inter_dst (gpu_src_code) gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code) - - inter_dims_bin = int(inter_dims*morph_factor) - with tf.device(f'/CPU:0'): + + inter_dims_bin = int(inter_dims*morph_factor) + with tf.device(f'/CPU:0'): inter_rnd_binomial = tf.stack([tf.random.shuffle(tf.concat([tf.tile(tf.constant([1], tf.float32), ( inter_dims_bin, )), tf.tile(tf.constant([0], tf.float32), ( inter_dims-inter_dims_bin, ))], 0 )) for _ in range(bs_per_gpu)], 0) - + inter_rnd_binomial = tf.stop_gradient(inter_rnd_binomial[...,None,None]) - + gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial) gpu_dst_code = gpu_dst_inter_dst_code inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32) gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]), tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 ) - + gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) - + gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) gpu_target_srcm_anti = 1-gpu_target_srcm gpu_target_dstm_anti = 1-gpu_target_dstm - + gpu_target_srcm_gblur = nn.gaussian_blur(gpu_target_srcm, resolution // 32) gpu_target_dstm_gblur = nn.gaussian_blur(gpu_target_dstm, resolution // 32) - + gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_gblur, 0, 0.5) * 2 gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_gblur, 0, 0.5) * 2 gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur - + if blur_out_mask: #gpu_target_src = gpu_target_src*gpu_target_srcm_blur + nn.gaussian_blur(gpu_target_src, resolution // 32)*gpu_target_srcm_anti_blur #gpu_target_dst = gpu_target_dst*gpu_target_dstm_blur + nn.gaussian_blur(gpu_target_dst, resolution // 32)*gpu_target_dstm_anti_blur bg_blur_div = 128 - + gpu_target_src = gpu_target_src*gpu_target_srcm + \ - tf.math.divide_no_nan(nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, resolution / bg_blur_div), + tf.math.divide_no_nan(nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, resolution / bg_blur_div), (1-nn.gaussian_blur(gpu_target_srcm, resolution / bg_blur_div) ) ) * gpu_target_srcm_anti gpu_target_dst = gpu_target_dst*gpu_target_dstm + \ tf.math.divide_no_nan(nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, resolution / bg_blur_div), (1-nn.gaussian_blur(gpu_target_dstm, resolution / bg_blur_div)) ) * gpu_target_dstm_anti - + gpu_target_src_masked = gpu_target_src*gpu_target_srcm_blur gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur - + gpu_pred_src_src_masked = gpu_pred_src_src*gpu_target_srcm_blur gpu_pred_dst_dst_masked = gpu_pred_dst_dst*gpu_target_dstm_blur gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur - + # Structural loss gpu_src_loss = tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) @@ -434,8 +437,8 @@ class AMPModel(ModelBase): # dst-dst background weak loss gpu_G_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] ) gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked) - - + + if gan_power != 0: gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked) gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked) @@ -453,7 +456,7 @@ class AMPModel(ModelBase): gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) + \ DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2) ) * gan_power - + # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src) gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] ) @@ -473,7 +476,7 @@ class AMPModel(ModelBase): src_loss = tf.concat(gpu_src_losses, 0) dst_loss = tf.concat(gpu_dst_losses, 0) train_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gradients)) - + if gan_power != 0: GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_gradients) ) @@ -554,8 +557,8 @@ class AMPModel(ModelBase): dst_generators_count = cpu_count // 2 if ct_mode is not None: src_generators_count = int(src_generators_count * 1.5) - - + + self.set_training_data_generators ([ SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),