mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 21:42:08 -07:00
fix avatar multigpu mode
This commit is contained in:
parent
adc1e701de
commit
58d9b4ae09
2 changed files with 37 additions and 29 deletions
|
@ -3,6 +3,7 @@ from models import TrainingDataType
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
from nnlib import tf_dssim
|
from nnlib import tf_dssim
|
||||||
|
from nnlib import DSSIMLossClass
|
||||||
from nnlib import conv
|
from nnlib import conv
|
||||||
from nnlib import upscale
|
from nnlib import upscale
|
||||||
|
|
||||||
|
@ -37,42 +38,35 @@ class Model(ModelBase):
|
||||||
self.encoder64, self.decoder64_src, self.decoder64_dst, self.encoder256, self.decoder256 = self.to_multi_gpu_model_if_possible ( [self.encoder64, self.decoder64_src, self.decoder64_dst, self.encoder256, self.decoder256] )
|
self.encoder64, self.decoder64_src, self.decoder64_dst, self.encoder256, self.decoder256 = self.to_multi_gpu_model_if_possible ( [self.encoder64, self.decoder64_src, self.decoder64_dst, self.encoder256, self.decoder256] )
|
||||||
|
|
||||||
input_A_warped64 = keras.layers.Input(img_shape64)
|
input_A_warped64 = keras.layers.Input(img_shape64)
|
||||||
input_A_target64 = keras.layers.Input(img_shape64)
|
|
||||||
input_B_warped64 = keras.layers.Input(img_shape64)
|
input_B_warped64 = keras.layers.Input(img_shape64)
|
||||||
input_B_target64 = keras.layers.Input(img_shape64)
|
A_rec64 = self.decoder64_src(self.encoder64(input_A_warped64))
|
||||||
|
B_rec64 = self.decoder64_dst(self.encoder64(input_B_warped64))
|
||||||
|
self.ae64 = self.keras.models.Model([input_A_warped64, input_B_warped64], [A_rec64, B_rec64] )
|
||||||
|
|
||||||
A_code64 = self.encoder64(input_A_warped64)
|
if self.is_training_mode:
|
||||||
B_code64 = self.encoder64(input_B_warped64)
|
self.ae64, = self.to_multi_gpu_model_if_possible ( [self.ae64,] )
|
||||||
|
|
||||||
A_rec64 = self.decoder64_src(A_code64)
|
self.ae64.compile(optimizer=self.keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999),
|
||||||
B_rec64 = self.decoder64_dst(B_code64)
|
loss=[DSSIMLossClass(self.tf)(), DSSIMLossClass(self.tf)()] )
|
||||||
|
|
||||||
A64_loss = tf_dssim(tf, input_A_target64, A_rec64)
|
|
||||||
B64_loss = tf_dssim(tf, input_B_target64, B_rec64)
|
|
||||||
total64_loss = A64_loss + B64_loss
|
|
||||||
|
|
||||||
self.ed64_train = K.function ([input_A_warped64, input_A_target64, input_B_warped64, input_B_target64],[K.mean(total64_loss)],
|
|
||||||
self.keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(total64_loss, self.encoder64.trainable_weights + self.decoder64_src.trainable_weights + self.decoder64_dst.trainable_weights)
|
|
||||||
)
|
|
||||||
self.A64_view = K.function ([input_A_warped64], [A_rec64])
|
self.A64_view = K.function ([input_A_warped64], [A_rec64])
|
||||||
self.B64_view = K.function ([input_B_warped64], [B_rec64])
|
self.B64_view = K.function ([input_B_warped64], [B_rec64])
|
||||||
|
|
||||||
input_A_warped64 = keras.layers.Input(img_shape64)
|
input_A_warped64 = keras.layers.Input(img_shape64)
|
||||||
input_A_target256 = keras.layers.Input(img_shape256)
|
input_A_target256 = keras.layers.Input(img_shape256)
|
||||||
A_code256 = self.encoder256(input_A_warped64)
|
A_rec256 = self.decoder256( self.encoder256(input_A_warped64) )
|
||||||
A_rec256 = self.decoder256(A_code256)
|
|
||||||
|
|
||||||
input_B_warped64 = keras.layers.Input(img_shape64)
|
input_B_warped64 = keras.layers.Input(img_shape64)
|
||||||
B_code64 = self.encoder64(input_B_warped64)
|
BA_rec64 = self.decoder64_src( self.encoder64(input_B_warped64) )
|
||||||
BA_rec64 = self.decoder64_src(B_code64)
|
BA_rec256 = self.decoder256( self.encoder256(BA_rec64) )
|
||||||
BA_code256 = self.encoder256(BA_rec64)
|
|
||||||
BA_rec256 = self.decoder256(BA_code256)
|
|
||||||
|
|
||||||
total256_loss = K.mean( tf_dssim(tf, input_A_target256, A_rec256) )
|
self.ae256 = self.keras.models.Model([input_A_warped64], [A_rec256] )
|
||||||
|
|
||||||
self.ed256_train = K.function ([input_A_warped64, input_A_target256],[total256_loss],
|
if self.is_training_mode:
|
||||||
self.keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(total256_loss, self.encoder256.trainable_weights + self.decoder256.trainable_weights)
|
self.ae256, = self.to_multi_gpu_model_if_possible ( [self.ae256,] )
|
||||||
)
|
|
||||||
|
self.ae256.compile(optimizer=self.keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999),
|
||||||
|
loss=[DSSIMLossClass(self.tf)()])
|
||||||
|
|
||||||
self.A256_view = K.function ([input_A_warped64], [A_rec256])
|
self.A256_view = K.function ([input_A_warped64], [A_rec256])
|
||||||
self.BA256_view = K.function ([input_B_warped64], [BA_rec256])
|
self.BA256_view = K.function ([input_B_warped64], [BA_rec256])
|
||||||
|
@ -108,10 +102,11 @@ class Model(ModelBase):
|
||||||
warped_src64, target_src64, target_src256, target_src_source64, target_src_source256 = sample[0]
|
warped_src64, target_src64, target_src256, target_src_source64, target_src_source256 = sample[0]
|
||||||
warped_dst64, target_dst64, target_dst_source64, target_dst_source256 = sample[1]
|
warped_dst64, target_dst64, target_dst_source64, target_dst_source256 = sample[1]
|
||||||
|
|
||||||
loss64, = self.ed64_train ([warped_src64, target_src64, warped_dst64, target_dst64])
|
loss64, loss_src64, loss_dst64 = self.ae64.train_on_batch ([warped_src64, warped_dst64], [target_src64, target_dst64])
|
||||||
loss256, = self.ed256_train ([warped_src64, target_src256])
|
|
||||||
|
|
||||||
return ( ('loss64', loss64), ('loss256', loss256), )
|
loss256 = self.ae256.train_on_batch ([warped_src64], [target_src256])
|
||||||
|
|
||||||
|
return ( ('loss64', loss64 ), ('loss256', loss256), )
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onGetPreview(self, sample):
|
def onGetPreview(self, sample):
|
||||||
|
|
|
@ -44,6 +44,19 @@ def DSSIMMaskLossClass(tf):
|
||||||
|
|
||||||
return DSSIMMaskLoss
|
return DSSIMMaskLoss
|
||||||
|
|
||||||
|
def DSSIMLossClass(tf):
    """Build a DSSIM loss class bound to the given TensorFlow module.

    The module is passed in rather than imported so the returned class
    closes over whichever ``tf`` object the caller supplies; it only needs
    to provide ``tf.image.ssim``.

    Returns the class itself, not an instance. Instantiate it to obtain a
    callable ``loss(y_true, y_pred)``.
    """

    class DSSIMLoss(object):
        """Callable structural-dissimilarity loss: ``(1 - SSIM) / 2``."""

        def __init__(self, is_tanh=False):
            # When true, both inputs are remapped with x / 2 + 0.5 before
            # SSIM is evaluated (presumably for tanh-range images in
            # [-1, 1] -- confirm against the model's output activation).
            self.is_tanh = is_tanh

        def __call__(self, y_true, y_pred):
            """Return ``(1 - ssim(y_true, y_pred)) / 2`` as computed by tf."""
            if self.is_tanh:
                y_true = y_true / 2 + 0.5
                y_pred = y_pred / 2 + 0.5

            # Third positional argument of tf.image.ssim is fixed at 1.0.
            similarity = tf.image.ssim(y_true, y_pred, 1.0)
            return (1.0 - similarity) / 2.0

    return DSSIMLoss
|
||||||
|
|
||||||
def MSEMaskLossClass(keras):
|
def MSEMaskLossClass(keras):
|
||||||
class MSEMaskLoss(object):
|
class MSEMaskLoss(object):
|
||||||
def __init__(self, mask_list, is_tanh=False):
|
def __init__(self, mask_list, is_tanh=False):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue