DeepFaceLab/models/Model_RecycleGAN/Model.py
iperov 7b70e7eec1 Added new model: U-net Face Morpher.
Removed AVATAR - a useless model that existed only for demo purposes.
Removed MIAEF128 - use UFM instead.
Removed LIAEF128YAW - use the "sort by yaw" model option on start for any model.
All models now ask for some options on start.
Session options (such as target epoch, batch_size, write_preview_history, etc.) can be overridden by a special command argument.
The converter now always asks for options and no longer supports defining options via the command line.
Fixed a bug where ConverterMasked always used the non-predicted mask.
SampleGenerator now always generates samples with a replicated border, excluding mask samples.
Refactorings.
2019-01-02 17:26:12 +04:00

from models import ModelBase
import numpy as np
import cv2
from mathlib import get_power_of_two
from nnlib import nnlib
from facelib import FaceType
from samples import *

class Model(ModelBase):

    # Weight-file names for the six sub-networks saved/loaded by this model.
    GAH5 = 'GA.h5'
    PAH5 = 'PA.h5'
    DAH5 = 'DA.h5'
    GBH5 = 'GB.h5'
    DBH5 = 'DB.h5'
    PBH5 = 'PB.h5'

    #override
    def onInitialize(self, batch_size=-1, **in_options):
        exec(nnlib.code_import_all, locals(), globals())
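        # The exec above injects nnlib's Keras/TF symbols (K, Input, Adam,
        # modelify, ResNet, UNetTemporalPredictor, NLayerDiscriminator, ...)
        # into this namespace, which is why they appear undefined to linters.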

        created_batch_size = self.get_batch_size()
        if self.epoch == 0:
            #first run
            try:
                created_resolution = int ( input ("Resolution (default:64, valid: 64,128,256) : ") )
            except:
                created_resolution = 64

            if created_resolution not in [64,128,256]:
                created_resolution = 64

            try:
                created_batch_size = int ( input ("Batch_size (minimum/default - 10) : ") )
            except:
                created_batch_size = 10
            created_batch_size = max(created_batch_size, 1)

            print ("Done. If training won't start, decrease resolution")

            self.options['created_resolution'] = created_resolution
            self.options['created_batch_size'] = created_batch_size
            self.created_vram_gb = self.device_config.gpu_total_vram_gb
        else:
            #not first run
            if 'created_batch_size' in self.options.keys():
                created_batch_size = self.options['created_batch_size']
            else:
                raise Exception("Continue training, but created_batch_size not found.")

            if 'created_resolution' in self.options.keys():
                created_resolution = self.options['created_resolution']
            else:
                raise Exception("Continue training, but created_resolution not found.")

        resolution = created_resolution
        bgr_shape = (resolution, resolution, 3)
        ngf = 64
        npf = 64
        ndf = 64
        lambda_A = 10
        lambda_B = 10

        self.set_batch_size(created_batch_size)

        use_batch_norm = False #created_batch_size > 1

        self.GA = modelify(ResNet (bgr_shape[2], use_batch_norm, n_blocks=6, ngf=ngf, use_dropout=True))(Input(bgr_shape))
        self.GB = modelify(ResNet (bgr_shape[2], use_batch_norm, n_blocks=6, ngf=ngf, use_dropout=True))(Input(bgr_shape))
        #self.GA = modelify(UNet (bgr_shape[2], use_batch_norm, num_downs=get_power_of_two(resolution)-1, ngf=ngf, use_dropout=True))(Input(bgr_shape))
        #self.GB = modelify(UNet (bgr_shape[2], use_batch_norm, num_downs=get_power_of_two(resolution)-1, ngf=ngf, use_dropout=True))(Input(bgr_shape))

        self.PA = modelify(UNetTemporalPredictor(bgr_shape[2], use_batch_norm, num_downs=get_power_of_two(resolution)-1, ngf=npf, use_dropout=True))([Input(bgr_shape), Input(bgr_shape)])
        self.PB = modelify(UNetTemporalPredictor(bgr_shape[2], use_batch_norm, num_downs=get_power_of_two(resolution)-1, ngf=npf, use_dropout=True))([Input(bgr_shape), Input(bgr_shape)])

        self.DA = modelify(NLayerDiscriminator(use_batch_norm, ndf=ndf, n_layers=3))(Input(bgr_shape))
        self.DB = modelify(NLayerDiscriminator(use_batch_norm, ndf=ndf, n_layers=3))(Input(bgr_shape))
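        # Network layout follows Recycle-GAN (Bansal et al., 2018): GA translates
        # domain A -> B and GB translates B -> A (6-block ResNet generators);
        # PA/PB are U-Net temporal predictors that estimate the third frame of a
        # triplet from the first two; DA/DB are N-layer patch discriminators.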

        if not self.is_first_run():
            self.GA.load_weights (self.get_strpath_storage_for_file(self.GAH5))
            self.DA.load_weights (self.get_strpath_storage_for_file(self.DAH5))
            self.PA.load_weights (self.get_strpath_storage_for_file(self.PAH5))
            self.GB.load_weights (self.get_strpath_storage_for_file(self.GBH5))
            self.DB.load_weights (self.get_strpath_storage_for_file(self.DBH5))
            self.PB.load_weights (self.get_strpath_storage_for_file(self.PBH5))

        real_A0 = Input(bgr_shape, name="real_A0")
        real_A1 = Input(bgr_shape, name="real_A1")
        real_A2 = Input(bgr_shape, name="real_A2")
        real_B0 = Input(bgr_shape, name="real_B0")
        real_B1 = Input(bgr_shape, name="real_B1")
        real_B2 = Input(bgr_shape, name="real_B2")

        DA_ones  = K.ones ( K.int_shape(self.DA.outputs[0])[1:] )
        DA_zeros = K.zeros( K.int_shape(self.DA.outputs[0])[1:] )
        DB_ones  = K.ones ( K.int_shape(self.DB.outputs[0])[1:] )
        DB_zeros = K.zeros( K.int_shape(self.DB.outputs[0])[1:] )
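        # LSGAN targets: each discriminator's output map is regressed toward
        # ones for real inputs and zeros for translated (fake) inputs.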

        def CycleLoss(t1, t2):
            return K.mean(K.square(t1 - t2))

        def RecurrentLOSS(t1, t2):
            return K.mean(K.square(t1 - t2))

        def RecycleLOSS(t1, t2):
            return K.mean(K.square(t1 - t2))
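        # All three helpers are the same MSE; they are kept separate only to
        # mirror the named terms of the Recycle-GAN objective.

        # Forward passes: *0/*1/*2 are three consecutive frames of one clip;
        # the "fake" tensors below are their cross-domain translations.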
        fake_B0 = self.GA(real_A0)
        fake_B1 = self.GA(real_A1)

        fake_A0 = self.GB(real_B0)
        fake_A1 = self.GB(real_B1)

        #rec_FB0 = self.GA(fake_A0)
        #rec_FB1 = self.GA(fake_A1)
        #rec_FA0 = self.GB(fake_B0)
        #rec_FA1 = self.GB(fake_B1)

        pred_A2 = self.PA ( [real_A0, real_A1] )
        pred_B2 = self.PB ( [real_B0, real_B1] )

        rec_A2 = self.GB ( self.PB ( [fake_B0, fake_B1] ) )
        rec_B2 = self.GA ( self.PA ( [fake_A0, fake_A1] ) )
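        # Generator objective: the LSGAN terms push D(fake) toward the "real"
        # target, RecurrentLOSS trains each predictor on real triplets, and
        # RecycleLOSS closes the temporal cycle A -> B -> predict -> back to A
        # (and symmetrically for B); the per-frame cycle-consistency terms are
        # left disabled (commented out above).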
        loss_G = K.mean(K.square(self.DB(fake_B0) - DB_ones)) + \
                 K.mean(K.square(self.DB(fake_B1) - DB_ones)) + \
                 K.mean(K.square(self.DA(fake_A0) - DA_ones)) + \
                 K.mean(K.square(self.DA(fake_A1) - DA_ones)) + \
                 lambda_A * ( #CycleLoss(rec_FA0, real_A0) + \
                              #CycleLoss(rec_FA1, real_A1) + \
                              RecurrentLOSS(pred_A2, real_A2) + \
                              RecycleLOSS(rec_A2, real_A2) ) + \
                 lambda_B * ( #CycleLoss(rec_FB0, real_B0) + \
                              #CycleLoss(rec_FB1, real_B1) + \
                              RecurrentLOSS(pred_B2, real_B2) + \
                              RecycleLOSS(rec_B2, real_B2) )

        weights_G = self.GA.trainable_weights + self.GB.trainable_weights + self.PA.trainable_weights + self.PB.trainable_weights

        self.G_train = K.function ([real_A0, real_A1, real_A2, real_B0, real_B1, real_B2], [loss_G],
                                   Adam(lr=2e-4, beta_1=0.5, beta_2=0.999).get_updates(loss_G, weights_G) )
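
        # Each discriminator trains on its own LSGAN objective, halved as in
        # CycleGAN, which slows the rate at which D learns relative to G.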
        ###########
        loss_D_A0 = ( K.mean(K.square( self.DA(real_A0) - DA_ones)) +
                      K.mean(K.square( self.DA(fake_A0) - DA_zeros)) ) * 0.5
        loss_D_A1 = ( K.mean(K.square( self.DA(real_A1) - DA_ones)) +
                      K.mean(K.square( self.DA(fake_A1) - DA_zeros)) ) * 0.5

        loss_D_A = loss_D_A0 + loss_D_A1

        self.DA_train = K.function ([real_A0, real_A1, real_A2, real_B0, real_B1, real_B2], [loss_D_A],
                                    Adam(lr=2e-4, beta_1=0.5, beta_2=0.999).get_updates(loss_D_A, self.DA.trainable_weights) )

        ############
        loss_D_B0 = ( K.mean(K.square( self.DB(real_B0) - DB_ones)) +
                      K.mean(K.square( self.DB(fake_B0) - DB_zeros)) ) * 0.5
        loss_D_B1 = ( K.mean(K.square( self.DB(real_B1) - DB_ones)) +
                      K.mean(K.square( self.DB(fake_B1) - DB_zeros)) ) * 0.5

        loss_D_B = loss_D_B0 + loss_D_B1

        self.DB_train = K.function ([real_A0, real_A1, real_A2, real_B0, real_B1, real_B2], [loss_D_B],
                                    Adam(lr=2e-4, beta_1=0.5, beta_2=0.999).get_updates(loss_D_B, self.DB.trainable_weights) )

        ############
        self.G_view = K.function([real_A0, real_A1, real_A2, real_B0, real_B1, real_B2],
                                 [fake_A0, fake_A1, pred_A2, rec_A2, fake_B0, fake_B1, pred_B2, rec_B2])

        self.G_convert = K.function([real_B0], [fake_A0])
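        # G_view renders every intermediate tensor for the preview grid below;
        # G_convert maps one dst-domain (B) image to the src domain via GB and
        # backs predictor_func at conversion time.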

        if self.is_training_mode:
            f = SampleProcessor.TypeFlags
            self.set_training_data_generators ([
                SampleGeneratorImageTemporal(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
                                             temporal_image_count=3,
                                             sample_process_options=SampleProcessor.Options(random_flip=False, normalize_tanh=True),
                                             output_sample_types=[ [f.SOURCE | f.MODE_BGR, resolution] ] ),

                SampleGeneratorImageTemporal(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                                             temporal_image_count=3,
                                             sample_process_options=SampleProcessor.Options(random_flip=False, normalize_tanh=True),
                                             output_sample_types=[ [f.SOURCE | f.MODE_BGR, resolution] ] ),
            ])

    #override
    def onSave(self):
        self.save_weights_safe( [[self.GA, self.get_strpath_storage_for_file(self.GAH5)],
                                 [self.GB, self.get_strpath_storage_for_file(self.GBH5)],
                                 [self.DA, self.get_strpath_storage_for_file(self.DAH5)],
                                 [self.DB, self.get_strpath_storage_for_file(self.DBH5)],
                                 [self.PA, self.get_strpath_storage_for_file(self.PAH5)],
                                 [self.PB, self.get_strpath_storage_for_file(self.PBH5)] ])
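
    # One training iteration updates the generators/predictors first, then each
    # discriminator, all on the same feed of three consecutive src frames plus
    # three consecutive dst frames.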
    #override
    def onTrainOneEpoch(self, sample):
        source_src_0, source_src_1, source_src_2 = sample[0]
        source_dst_0, source_dst_1, source_dst_2 = sample[1]

        feed = [source_src_0, source_src_1, source_src_2, source_dst_0, source_dst_1, source_dst_2]

        loss_G,  = self.G_train ( feed )
        loss_DA, = self.DA_train( feed )
        loss_DB, = self.DB_train( feed )

        #return ( ('G', loss_G), )
        return ( ('G', loss_G), ('DA', loss_DA), ('DB', loss_DB) )

    #override
    def onGetPreview(self, sample):
        test_A0 = sample[0][0]
        test_A1 = sample[0][1]
        test_A2 = sample[0][2]

        test_B0 = sample[1][0]
        test_B1 = sample[1][1]
        test_B2 = sample[1][2]

        G_view_result = self.G_view([test_A0, test_A1, test_A2, test_B0, test_B1, test_B2])

        fake_A0, fake_A1, pred_A2, rec_A2, fake_B0, fake_B1, pred_B2, rec_B2 = [ x[0] / 2 + 0.5 for x in G_view_result ]
        test_A0, test_A1, test_A2, test_B0, test_B1, test_B2 = [ x[0] / 2 + 0.5 for x in [test_A0, test_A1, test_A2, test_B0, test_B1, test_B2] ]

        r = np.concatenate ((np.concatenate ( (test_A0, test_A1, test_A2, pred_A2, fake_B0, fake_B1, rec_A2), axis=1),
                             np.concatenate ( (test_B0, test_B1, test_B2, pred_B2, fake_A0, fake_A1, rec_B2), axis=1)
                            ), axis=0)

        return [ ('RecycleGAN, A0-A1-A2-PA2-FB0-FB1-RA2, B0-B1-B2-PB2-FA0-FA1-RB2', r) ]
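
    # Training samples use normalize_tanh=True, i.e. pixels in [-1, 1]; the
    # converter hands predictor_func a [0, 1] BGR face, so it is mapped into
    # tanh range on the way in and back to [0, 1] on the way out.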
    def predictor_func (self, face):
        x = self.G_convert ( [ np.expand_dims(face * 2 - 1, 0) ] )[0]
        return x[0] / 2 + 0.5

    #override
    def get_converter(self, **in_options):
        from models import ConverterImage

        return ConverterImage(self.predictor_func,
                              predictor_input_size=self.options['created_resolution'],
                              output_size=self.options['created_resolution'],
                              **in_options)