diff --git a/models/Model_AVATAR/Model.py b/models/Model_AVATAR/Model.py index bd4e55b..6843095 100644 --- a/models/Model_AVATAR/Model.py +++ b/models/Model_AVATAR/Model.py @@ -3,6 +3,7 @@ from models import TrainingDataType import numpy as np import cv2 from nnlib import tf_dssim +from nnlib import DSSIMLossClass from nnlib import conv from nnlib import upscale @@ -37,42 +38,35 @@ class Model(ModelBase): self.encoder64, self.decoder64_src, self.decoder64_dst, self.encoder256, self.decoder256 = self.to_multi_gpu_model_if_possible ( [self.encoder64, self.decoder64_src, self.decoder64_dst, self.encoder256, self.decoder256] ) input_A_warped64 = keras.layers.Input(img_shape64) - input_A_target64 = keras.layers.Input(img_shape64) input_B_warped64 = keras.layers.Input(img_shape64) - input_B_target64 = keras.layers.Input(img_shape64) + A_rec64 = self.decoder64_src(self.encoder64(input_A_warped64)) + B_rec64 = self.decoder64_dst(self.encoder64(input_B_warped64)) + self.ae64 = self.keras.models.Model([input_A_warped64, input_B_warped64], [A_rec64, B_rec64] ) - A_code64 = self.encoder64(input_A_warped64) - B_code64 = self.encoder64(input_B_warped64) - - A_rec64 = self.decoder64_src(A_code64) - B_rec64 = self.decoder64_dst(B_code64) - - A64_loss = tf_dssim(tf, input_A_target64, A_rec64) - B64_loss = tf_dssim(tf, input_B_target64, B_rec64) - total64_loss = A64_loss + B64_loss + if self.is_training_mode: + self.ae64, = self.to_multi_gpu_model_if_possible ( [self.ae64,] ) - self.ed64_train = K.function ([input_A_warped64, input_A_target64, input_B_warped64, input_B_target64],[K.mean(total64_loss)], - self.keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(total64_loss, self.encoder64.trainable_weights + self.decoder64_src.trainable_weights + self.decoder64_dst.trainable_weights) - ) + self.ae64.compile(optimizer=self.keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), + loss=[DSSIMLossClass(self.tf)(), DSSIMLossClass(self.tf)()] ) + self.A64_view = K.function ([input_A_warped64], [A_rec64]) self.B64_view = K.function ([input_B_warped64], [B_rec64]) input_A_warped64 = keras.layers.Input(img_shape64) input_A_target256 = keras.layers.Input(img_shape256) - A_code256 = self.encoder256(input_A_warped64) - A_rec256 = self.decoder256(A_code256) + A_rec256 = self.decoder256( self.encoder256(input_A_warped64) ) input_B_warped64 = keras.layers.Input(img_shape64) - B_code64 = self.encoder64(input_B_warped64) - BA_rec64 = self.decoder64_src(B_code64) - BA_code256 = self.encoder256(BA_rec64) - BA_rec256 = self.decoder256(BA_code256) - - total256_loss = K.mean( tf_dssim(tf, input_A_target256, A_rec256) ) - - self.ed256_train = K.function ([input_A_warped64, input_A_target256],[total256_loss], - self.keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(total256_loss, self.encoder256.trainable_weights + self.decoder256.trainable_weights) - ) + BA_rec64 = self.decoder64_src( self.encoder64(input_B_warped64) ) + BA_rec256 = self.decoder256( self.encoder256(BA_rec64) ) + + self.ae256 = self.keras.models.Model([input_A_warped64], [A_rec256] ) + + if self.is_training_mode: + self.ae256, = self.to_multi_gpu_model_if_possible ( [self.ae256,] ) + + self.ae256.compile(optimizer=self.keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), + loss=[DSSIMLossClass(self.tf)()]) self.A256_view = K.function ([input_A_warped64], [A_rec256]) self.BA256_view = K.function ([input_B_warped64], [BA_rec256]) @@ -108,10 +102,11 @@ class Model(ModelBase): warped_src64, target_src64, target_src256, target_src_source64, target_src_source256 = sample[0] warped_dst64, target_dst64, target_dst_source64, target_dst_source256 = sample[1] - loss64, = self.ed64_train ([warped_src64, target_src64, warped_dst64, target_dst64]) - loss256, = self.ed256_train ([warped_src64, target_src256]) + loss64, loss_src64, loss_dst64 = self.ae64.train_on_batch ([warped_src64, warped_dst64], [target_src64, target_dst64]) - return ( ('loss64', loss64), ('loss256', loss256), ) + loss256 = self.ae256.train_on_batch ([warped_src64], [target_src256]) + + return ( ('loss64', loss64 ), ('loss256', loss256), ) #override def onGetPreview(self, sample): diff --git a/nnlib/__init__.py b/nnlib/__init__.py index 69fe140..b543776 100644 --- a/nnlib/__init__.py +++ b/nnlib/__init__.py @@ -43,6 +43,19 @@ def DSSIMMaskLossClass(tf): return total_loss return DSSIMMaskLoss + +def DSSIMLossClass(tf): + class DSSIMLoss(object): + def __init__(self, is_tanh=False): + self.is_tanh = is_tanh + + def __call__(self,y_true, y_pred): + if not self.is_tanh: + return (1.0 - tf.image.ssim (y_true, y_pred, 1.0)) / 2.0 + else: + return (1.0 - tf.image.ssim ((y_true/2+0.5), (y_pred/2+0.5), 1.0)) / 2.0 + + return DSSIMLoss def MSEMaskLossClass(keras): class MSEMaskLoss(object):