from models import ModelBase
import numpy as np
import cv2
from mathlib import get_power_of_two
from nnlib import nnlib
from facelib import FaceType
from samples import *

class Model(ModelBase):
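    # Recycle-GAN-style video model: generators GA (A->B) and GB (B->A),
    # temporal predictors PA / PB, and discriminators DA / DB, trained on
    # triplets of consecutive frames from each domain.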
    GAH5 = 'GA.h5'
    PAH5 = 'PA.h5'
    DAH5 = 'DA.h5'
    GBH5 = 'GB.h5'
    DBH5 = 'DB.h5'
    PBH5 = 'PB.h5'

    #override
    def onInitialize(self, batch_size=-1, **in_options):
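        # nnlib.code_import_all injects the backend symbols used below
        # (K, Input, Adam, modelify, ResNet, UNet, UNetTemporalPredictor,
        # NLayerDiscriminator, ...) into this scope.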
        exec(nnlib.code_import_all, locals(), globals())

        created_batch_size = self.get_batch_size()
        if self.epoch == 0:
            #first run
            print ("\nModel first run. Enter options.")

            try:
                created_resolution = int ( input ("Resolution (default: 64, valid: 64, 128, 256) : ") )
            except Exception:
                created_resolution = 64

            if created_resolution not in [64, 128, 256]:
                created_resolution = 64

            try:
                created_batch_size = int ( input ("Batch size (default: 16) : ") )
            except Exception:
                created_batch_size = 16
            created_batch_size = max(created_batch_size, 1)

            print ("Done. If training fails to start, decrease the resolution.")

            self.options['created_resolution'] = created_resolution
            self.options['created_batch_size'] = created_batch_size
            self.created_vram_gb = self.device_config.gpu_total_vram_gb
        else:
            #not first run
            if 'created_batch_size' in self.options:
                created_batch_size = self.options['created_batch_size']
            else:
                raise Exception("Continuing training, but created_batch_size was not found in the model options.")

            if 'created_resolution' in self.options:
                created_resolution = self.options['created_resolution']
            else:
                raise Exception("Continuing training, but created_resolution was not found in the model options.")

        resolution = created_resolution
        bgr_shape = (resolution, resolution, 3)
        ngf = 64
        npf = 64
        ndf = 64
        lambda_A = 10
        lambda_B = 10

        self.set_batch_size(created_batch_size)

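        # Batch normalization is only meaningful with more than one sample per batch.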
        use_batch_norm = created_batch_size > 1
        self.GA = modelify(ResNet (bgr_shape[2], use_batch_norm, n_blocks=6, ngf=ngf, use_dropout=False))(Input(bgr_shape))
        self.GB = modelify(ResNet (bgr_shape[2], use_batch_norm, n_blocks=6, ngf=ngf, use_dropout=False))(Input(bgr_shape))
        #self.GA = modelify(UNet (bgr_shape[2], use_batch_norm, num_downs=get_power_of_two(resolution)-1, ngf=ngf, use_dropout=True))(Input(bgr_shape))
        #self.GB = modelify(UNet (bgr_shape[2], use_batch_norm, num_downs=get_power_of_two(resolution)-1, ngf=ngf, use_dropout=True))(Input(bgr_shape))

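        # PA / PB are temporal predictors: given frames t and t+1 of a domain,
        # they predict frame t+2.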
        self.PA = modelify(UNetTemporalPredictor(bgr_shape[2], use_batch_norm, num_downs=get_power_of_two(resolution)-1, ngf=npf, use_dropout=True))([Input(bgr_shape), Input(bgr_shape)])
        self.PB = modelify(UNetTemporalPredictor(bgr_shape[2], use_batch_norm, num_downs=get_power_of_two(resolution)-1, ngf=npf, use_dropout=True))([Input(bgr_shape), Input(bgr_shape)])

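        # Per-domain discriminators (NLayerDiscriminator is a PatchGAN-style critic).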
        self.DA = modelify(NLayerDiscriminator(use_batch_norm, ndf=ndf, n_layers=3))(Input(bgr_shape))
        self.DB = modelify(NLayerDiscriminator(use_batch_norm, ndf=ndf, n_layers=3))(Input(bgr_shape))

        if not self.is_first_run():
            self.GA.load_weights (self.get_strpath_storage_for_file(self.GAH5))
            self.DA.load_weights (self.get_strpath_storage_for_file(self.DAH5))
            self.PA.load_weights (self.get_strpath_storage_for_file(self.PAH5))
            self.GB.load_weights (self.get_strpath_storage_for_file(self.GBH5))
            self.DB.load_weights (self.get_strpath_storage_for_file(self.DBH5))
            self.PB.load_weights (self.get_strpath_storage_for_file(self.PBH5))

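        # Placeholders for three consecutive frames from each domain.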
        real_A0 = Input(bgr_shape, name="real_A0")
        real_A1 = Input(bgr_shape, name="real_A1")
        real_A2 = Input(bgr_shape, name="real_A2")

        real_B0 = Input(bgr_shape, name="real_B0")
        real_B1 = Input(bgr_shape, name="real_B1")
        real_B2 = Input(bgr_shape, name="real_B2")

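        # LSGAN targets: the discriminators regress toward 1 for real and 0 for fake.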
        DA_ones  = K.ones ( K.int_shape(self.DA.outputs[0])[1:] )
        DA_zeros = K.zeros( K.int_shape(self.DA.outputs[0])[1:] )
        DB_ones  = K.ones ( K.int_shape(self.DB.outputs[0])[1:] )
        DB_zeros = K.zeros( K.int_shape(self.DB.outputs[0])[1:] )

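        # All three temporal losses are plain MSE; separate helpers keep the
        # generator objective below readable.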
        def CycleLoss(t1, t2):
            return K.mean(K.square(t1 - t2))

        def RecurrentLoss(t1, t2):
            return K.mean(K.square(t1 - t2))

        def RecycleLoss(t1, t2):
            return K.mean(K.square(t1 - t2))

        fake_B0 = self.GA(real_A0)
        fake_B1 = self.GA(real_A1)

        fake_A0 = self.GB(real_B0)
        fake_A1 = self.GB(real_B1)

        #rec_FB0 = self.GA(fake_A0)
        #rec_FB1 = self.GA(fake_A1)

        #rec_FA0 = self.GB(fake_B0)
        #rec_FA1 = self.GB(fake_B1)

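        # Recurrent path: predict the third real frame from the first two.
        # Recycle path: predict the third fake frame, then map it back to the
        # source domain for comparison with the real third frame.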
        pred_A2 = self.PA([real_A0, real_A1])
        pred_B2 = self.PB([real_B0, real_B1])
        rec_A2 = self.GB( self.PB([fake_B0, fake_B1]) )
        rec_B2 = self.GA( self.PA([fake_A0, fake_A1]) )

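        # Generator objective: LSGAN adversarial terms on both domains plus the
        # lambda-weighted recurrent and recycle losses.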
        loss_G = K.mean(K.square(self.DB(fake_B0) - DB_ones)) + \
                 K.mean(K.square(self.DB(fake_B1) - DB_ones)) + \
                 K.mean(K.square(self.DA(fake_A0) - DA_ones)) + \
                 K.mean(K.square(self.DA(fake_A1) - DA_ones)) + \
                 lambda_A * ( #CycleLoss(rec_FA0, real_A0) +
                              #CycleLoss(rec_FA1, real_A1) +
                              RecurrentLoss(pred_A2, real_A2) +
                              RecycleLoss(rec_A2, real_A2) ) + \
                 lambda_B * ( #CycleLoss(rec_FB0, real_B0) +
                              #CycleLoss(rec_FB1, real_B1) +
                              RecurrentLoss(pred_B2, real_B2) +
                              RecycleLoss(rec_B2, real_B2) )

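        # Generators and temporal predictors are updated jointly by one Adam optimizer.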
        weights_G = self.GA.trainable_weights + self.GB.trainable_weights + self.PA.trainable_weights + self.PB.trainable_weights

        self.G_train = K.function ([real_A0, real_A1, real_A2, real_B0, real_B1, real_B2], [loss_G],
                                   Adam(lr=2e-4, beta_1=0.5, beta_2=0.999).get_updates(loss_G, weights_G) )

        ###########

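        # Discriminator A: LSGAN real/fake loss on both time steps; the 0.5
        # factor slows D relative to G, as in the CycleGAN training recipe.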
        loss_D_A0 = ( K.mean(K.square( self.DA(real_A0) - DA_ones)) +
                      K.mean(K.square( self.DA(fake_A0) - DA_zeros)) ) * 0.5

        loss_D_A1 = ( K.mean(K.square( self.DA(real_A1) - DA_ones)) +
                      K.mean(K.square( self.DA(fake_A1) - DA_zeros)) ) * 0.5

        loss_D_A = loss_D_A0 + loss_D_A1

        self.DA_train = K.function ([real_A0, real_A1, real_A2, real_B0, real_B1, real_B2], [loss_D_A],
                                    Adam(lr=2e-4, beta_1=0.5, beta_2=0.999).get_updates(loss_D_A, self.DA.trainable_weights) )

        ############

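        # Discriminator B mirrors discriminator A.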
        loss_D_B0 = ( K.mean(K.square( self.DB(real_B0) - DB_ones)) +
                      K.mean(K.square( self.DB(fake_B0) - DB_zeros)) ) * 0.5

        loss_D_B1 = ( K.mean(K.square( self.DB(real_B1) - DB_ones)) +
                      K.mean(K.square( self.DB(fake_B1) - DB_zeros)) ) * 0.5

        loss_D_B = loss_D_B0 + loss_D_B1

        self.DB_train = K.function ([real_A0, real_A1, real_A2, real_B0, real_B1, real_B2], [loss_D_B],
                                    Adam(lr=2e-4, beta_1=0.5, beta_2=0.999).get_updates(loss_D_B, self.DB.trainable_weights) )

        ############

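        # G_view renders preview tensors; G_convert maps a single B-domain frame
        # to domain A and backs predictor_func below.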
        self.G_view = K.function([real_A0, real_A1, real_A2, real_B0, real_B1, real_B2], [fake_A0, fake_A1, pred_A2, rec_A2, fake_B0, fake_B1, pred_B2, rec_B2])
        self.G_convert = K.function([real_B0], [fake_A0])

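        # Each sample generator yields triplets of consecutive BGR frames,
        # normalized to [-1, 1] (normalize_tanh=True).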
        if self.is_training_mode:
            f = SampleProcessor.TypeFlags
            self.set_training_data_generators ([
                    SampleGeneratorImageTemporal(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
                                                 temporal_image_count=3,
                                                 sample_process_options=SampleProcessor.Options(random_flip=False, normalize_tanh=True),
                                                 output_sample_types=[ [f.SOURCE | f.MODE_BGR, resolution] ] ),

                    SampleGeneratorImageTemporal(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                                                 temporal_image_count=3,
                                                 sample_process_options=SampleProcessor.Options(random_flip=False, normalize_tanh=True),
                                                 output_sample_types=[ [f.SOURCE | f.MODE_BGR, resolution] ] ),
                ])

    #override
    def onSave(self):
        self.save_weights_safe( [[self.GA, self.get_strpath_storage_for_file(self.GAH5)],
                                 [self.GB, self.get_strpath_storage_for_file(self.GBH5)],
                                 [self.DA, self.get_strpath_storage_for_file(self.DAH5)],
                                 [self.DB, self.get_strpath_storage_for_file(self.DBH5)],
                                 [self.PA, self.get_strpath_storage_for_file(self.PAH5)],
                                 [self.PB, self.get_strpath_storage_for_file(self.PBH5)]] )

    #override
    def onTrainOneEpoch(self, sample):
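        # One training step: update the generator stack, then each
        # discriminator, on the same batch of frame triplets.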
        source_src_0, source_src_1, source_src_2 = sample[0]
        source_dst_0, source_dst_1, source_dst_2 = sample[1]

        feed = [source_src_0, source_src_1, source_src_2, source_dst_0, source_dst_1, source_dst_2]

        loss_G,  = self.G_train ( feed )
        loss_DA, = self.DA_train( feed )
        loss_DB, = self.DB_train( feed )

        return ( ('G', loss_G), ('DA', loss_DA), ('DB', loss_DB) )

    #override
    def onGetPreview(self, sample):
        test_A0 = sample[0][0]
        test_A1 = sample[0][1]
        test_A2 = sample[0][2]

        test_B0 = sample[1][0]
        test_B1 = sample[1][1]
        test_B2 = sample[1][2]

        G_view_result = self.G_view([test_A0, test_A1, test_A2, test_B0, test_B1, test_B2])

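        # Map [-1, 1] network outputs (and inputs) back to [0, 1] for display.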
        fake_A0, fake_A1, pred_A2, rec_A2, fake_B0, fake_B1, pred_B2, rec_B2 = [ x[0] / 2 + 0.5 for x in G_view_result ]
        test_A0, test_A1, test_A2, test_B0, test_B1, test_B2 = [ x[0] / 2 + 0.5 for x in [test_A0, test_A1, test_A2, test_B0, test_B1, test_B2] ]

        r = np.concatenate ((np.concatenate ( (test_A0, test_A1, test_A2, pred_A2, fake_B0, fake_B1, rec_A2), axis=1),
                             np.concatenate ( (test_B0, test_B1, test_B2, pred_B2, fake_A0, fake_A1, rec_B2), axis=1)
                            ), axis=0)

        return [ ('RecycleGAN, A0-A1-A2-PA2-FB0-FB1-RA2, B0-B1-B2-PB2-FA0-FA1-RB2', r) ]

    def predictor_func (self, face):
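        # Expects a single [0, 1] BGR face: rescale to [-1, 1], run the B->A
        # generator, and map the result back to [0, 1].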
        x = self.G_convert ( [ np.expand_dims(face * 2 - 1, 0) ] )[0]
        return x[0] / 2 + 0.5

    #override
    def get_converter(self, **in_options):
        from models import ConverterImage

        return ConverterImage(self.predictor_func,
                              predictor_input_size=self.options['created_resolution'],
                              output_size=self.options['created_resolution'],
                              **in_options)