mirror of https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
SAE: WARNING, RETRAIN IS REQUIRED!
Fixed model sizes from the previous update and avoided a bug in the ML framework (Keras) that forced the model to train on random noise.
Converter: added blur on the same keys as sharpness.
Added new model 'TrueFace'. This is a GAN model ported from https://github.com/NVlabs/FUNIT. The model produces near-zero morphing and high facial detail, but has a higher failure rate than the other models. Keep the src and dst facesets under the same lighting conditions.
This commit is contained in:
parent 201b762541
commit dc11ec32be
26 changed files with 1308 additions and 250 deletions
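
For orientation, a minimal sketch of how the FUNIT class added in this commit could be driven; the batch iterator, the [-1, 1] BGR normalization, and the faceset variable names are illustrative assumptions, not part of the commit:

    from nnlib import FUNIT  # assumes nnlib re-exports the class added below

    model = FUNIT(face_type_str='full_face', batch_size=4)  # num_classes defaults to 2: src and dst

    for xa, la, xb, lb in batches():  # hypothetical iterator: BGR images in [-1,1], int32 labels of shape (4,1)
        G_loss, D_loss = model.train(xa, la, xb, lb)

    # conversion: average the dst style code, then translate a src face
    s_dst = model.get_average_class_code([dst_images])[0]
    xr, = model.convert([src_image[None, ...], s_dst[None, ...]])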
343 nnlib/FUNIT.py Normal file

@@ -0,0 +1,343 @@
from pathlib import Path

import numpy as np

from interact import interact as io
from nnlib import nnlib

"""
My port of FUNIT: Few-Shot Unsupervised Image-to-Image Translation to pure Keras.
original repo: https://github.com/NVlabs/FUNIT/
"""
class FUNIT(object):
    VERSION = 1

    def __init__ (self, face_type_str,
                        batch_size,
                        encoder_nf=64,
                        encoder_downs=2,
                        encoder_res_blk=2,
                        class_downs=4,
                        class_nf=64,
                        class_latent=64,
                        mlp_nf=256,
                        mlp_blks=2,
                        dis_nf=64,
                        dis_res_blks=10,
                        num_classes=2,
                        subpixel_decoder=True,
                        initialize_weights=True,

                        load_weights_locally=False,
                        weights_file_root=None,

                        is_training=True
                        ):
        exec( nnlib.import_all(), locals(), globals() )

        self.batch_size = batch_size
        bgr_shape = (None, None, 3)
        label_shape = (1,)

        self.enc_content     = modelify ( FUNIT.ContentEncoderFlow(downs=encoder_downs, nf=encoder_nf, n_res_blks=encoder_res_blk) ) ( Input(bgr_shape) )
        self.enc_class_model = modelify ( FUNIT.ClassModelEncoderFlow(downs=class_downs, nf=class_nf, latent_dim=class_latent) ) ( Input(bgr_shape) )
        self.decoder         = modelify ( FUNIT.DecoderFlow(ups=encoder_downs, n_res_blks=encoder_res_blk, mlp_nf=mlp_nf, mlp_blks=mlp_blks, subpixel_decoder=subpixel_decoder ) ) \
                               ( [ Input(K.int_shape(self.enc_content.outputs[0])[1:], name="decoder_input_1"),
                                   Input(K.int_shape(self.enc_class_model.outputs[0])[1:], name="decoder_input_2")
                                 ] )

        self.dis = modelify ( FUNIT.DiscriminatorFlow(nf=dis_nf, n_res_blks=dis_res_blks, num_classes=num_classes) ) (Input(bgr_shape))
        self.G_opt = RMSprop(lr=0.0001, decay=0.0001, tf_cpu_mode=2 if 'tensorflow' in nnlib.active_DeviceConfig.backend else 0)
        self.D_opt = RMSprop(lr=0.0001, decay=0.0001, tf_cpu_mode=2 if 'tensorflow' in nnlib.active_DeviceConfig.backend else 0)

        xa = Input(bgr_shape, name="xa")
        la = Input(label_shape, dtype=np.int32, name="la")

        xb = Input(bgr_shape, name="xb")
        lb = Input(label_shape, dtype=np.int32, name="lb")

        s_xa_one = Input( (self.enc_class_model.outputs[0].shape[-1].value,), name="s_xa_input")

        c_xa = self.enc_content(xa)

        s_xa = self.enc_class_model(xa)
        s_xb = self.enc_class_model(xb)

        s_xa_mean = K.mean(s_xa, axis=0)

        xr     = self.decoder ([c_xa,s_xa])      # reconstruction: content of xa + style of xa
        xt     = self.decoder ([c_xa,s_xb])      # translation:    content of xa + style of xb
        xr_one = self.decoder ([c_xa,s_xa_one])  # conversion with an externally supplied style code

        d_xr, d_xr_feat = self.dis(xr)
        d_xt, d_xt_feat = self.dis(xt)

        d_xa, d_xa_feat = self.dis(xa)
        d_xb, d_xb_feat = self.dis(xb)
        def dis_gather(x,l):
            # for each sample in the batch, pick the discriminator output map
            # of that sample's class column
            tensors = []
            for i in range(self.batch_size):
                t = x[i:i+1,:,:, l[i,0]]
                tensors += [t]
            return tensors

        def dis_gather_batch_mean(x,l, func=None):
            x_shape = K.shape(x)
            b,h,w,c = x_shape[0],x_shape[1],x_shape[2],x_shape[3]
            b,h,w,c = [ K.cast(x, K.floatx()) for x in [b,h,w,c] ]

            tensors = dis_gather(x,l)
            if func is not None:
                tensors = [func(t) for t in tensors]

            return K.sum(tensors, axis=[1,2,3]) / (h*w)

        def dis_gather_mean(x,l, func=None, acc_func=None):
            x_shape = K.shape(x)
            b,h,w,c = x_shape[0],x_shape[1],x_shape[2],x_shape[3]
            b,h,w,c = [ K.cast(x, K.floatx()) for x in [b,h,w,c] ]

            tensors = dis_gather(x,l)

            if acc_func is not None:
                acc = []
                for t in tensors:
                    acc += [ K.sum( K.cast( acc_func(t), K.floatx() )) ]
                acc = K.cast( K.sum(acc), K.floatx() ) / (b*h*w)
            else:
                acc = None

            if func is not None:
                tensors = [func(t) for t in tensors]

            return K.sum(tensors) / (b*h*w), acc
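        # Shape sketch with hypothetical numbers: for batch_size=4 and a logit
        # map x of shape (4, 8, 8, num_classes), dis_gather picks four (1, 8, 8)
        # slices, one per sample, from the class column l[i,0]; dis_gather_mean
        # then reduces the selected logits over batch and space, while
        # dis_gather_batch_mean keeps the per-sample dimension.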
        d_xr_la, d_xr_la_acc = dis_gather_mean(d_xr, la, acc_func=lambda x: x >= 0)
        d_xt_lb, d_xt_lb_acc = dis_gather_mean(d_xt, lb, acc_func=lambda x: x >= 0)

        d_xb_lb = dis_gather_batch_mean(d_xb, lb)

        d_xb_lb_real, d_xb_lb_real_acc = dis_gather_mean(d_xb, lb, lambda x: K.relu(1.0-x), acc_func=lambda x: x >= 0)
        d_xt_lb_fake, d_xt_lb_fake_acc = dis_gather_mean(d_xt, lb, lambda x: K.relu(1.0+x), acc_func=lambda x: x < 0)

        G_c_rec = K.mean(K.abs(K.mean(d_xr_feat, axis=[1,2]) - K.mean(d_xa_feat, axis=[1,2]))) #* 1.0
        G_m_rec = K.mean(K.abs(K.mean(d_xt_feat, axis=[1,2]) - K.mean(d_xb_feat, axis=[1,2]))) #* 1.0
        G_x_rec = 0.1 * K.mean(K.abs(xr-xa))

        G_loss = (-d_xr_la-d_xt_lb)*0.5 + G_x_rec + G_c_rec + G_m_rec
        G_acc  = (d_xr_la_acc+d_xt_lb_acc)*0.5

        G_weights = self.enc_class_model.trainable_weights + self.enc_content.trainable_weights + self.decoder.trainable_weights
        ######

        D_real = d_xb_lb_real #1.0 *
        D_fake = d_xt_lb_fake #1.0 *

        # gradient penalty on the real-class logits w.r.t. the real input
        l_reg = 10 * K.sum( K.gradients( d_xb_lb, xb )[0] ** 2 ) # , axis=[1,2,3] / self.batch_size )

        D_loss = D_real + D_fake + l_reg
        D_acc  = (d_xb_lb_real_acc+d_xt_lb_fake_acc)*0.5

        D_weights = self.dis.trainable_weights

        self.G_train = K.function ([xa, la, xb, lb],[G_loss], self.G_opt.get_updates(G_loss, G_weights) )

        self.D_train = K.function ([xa, la, xb, lb],[D_loss], self.D_opt.get_updates(D_loss, D_weights) )
        self._get_average_class_code = K.function ([xa],[s_xa_mean]) # underscore keeps the compiled function from shadowing the method of the same name below

        self.G_convert = K.function ([xa,s_xa_one],[xr_one])
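        # The compiled functions take plain numpy arrays: [xa, la, xb, lb] are
        # two image batches (presumably in [-1, 1], matching the tanh decoder
        # output that xr is compared against) plus int32 labels of shape
        # (batch_size, 1); each call returns the scalar loss and applies one
        # optimizer step as a side effect.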
        if initialize_weights:
            # gather weights from layers for initialization
            weights_list = []
            for model, _ in self.get_model_filename_list():
                if type(model) == keras.models.Model:
                    for layer in model.layers:
                        if type(layer) == FUNITAdain:
                            weights_list += [ x for x in layer.weights if 'kernel' in x.name ]
                        elif type(layer) == keras.layers.Conv2D or type(layer) == keras.layers.Dense:
                            weights_list += [ layer.weights[0] ]

            initer = keras.initializers.he_normal()
            for w in weights_list:
                K.set_value( w, K.get_value(initer(K.int_shape(w))) )

        #if not self.is_first_run():
        #    self.load_weights_safe(self.get_model_filename_list())

        if load_weights_locally:
            pass
        #if weights_file_root is not None:
        #    weights_file_root = Path(weights_file_root)
        #else:
        #    weights_file_root = Path(__file__).parent
        #self.weights_path = weights_file_root / ('FUNIT_%s.h5' % (face_type_str) )
        #if load_weights:
        #    self.model.load_weights (str(self.weights_path))
    def get_model_filename_list(self):
        return [[self.enc_class_model, 'enc_class_model.h5'],
                [self.enc_content,     'enc_content.h5'],
                [self.decoder,         'decoder.h5'],
                [self.dis,             'dis.h5'],
                [self.G_opt,           'G_opt.h5'],
                [self.D_opt,           'D_opt.h5'],
               ]

    #def save_weights(self):
    #    self.model.save_weights (str(self.weights_path))

    def train(self, xa,la,xb,lb):
        D_loss, = self.D_train ([xa,la,xb,lb])
        G_loss, = self.G_train ([xa,la,xb,lb])
        return G_loss, D_loss

    def get_average_class_code(self, *args, **kwargs):
        return self._get_average_class_code(*args, **kwargs)

    def convert(self, *args, **kwargs):
        return self.G_convert(*args, **kwargs)
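    # Conversion sketch (hypothetical array names):
    #   s_dst = self.get_average_class_code([dst_images])[0]       # (latent,)
    #   xr, = self.convert([src_image[None,...], s_dst[None,...]])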
    @staticmethod
    def ContentEncoderFlow(downs=2, nf=64, n_res_blks=2):
        exec (nnlib.import_all(), locals(), globals())

        def ResBlock(dim):
            def func(input):
                x = input
                x = Conv2D(dim, 3, strides=1, padding='valid')(ZeroPadding2D(1)(x))
                x = InstanceNormalization()(x)
                x = ReLU()(x)
                x = Conv2D(dim, 3, strides=1, padding='valid')(ZeroPadding2D(1)(x))
                x = InstanceNormalization()(x)

                return Add()([x,input])
            return func

        def func(x):
            x = Conv2D (nf, kernel_size=7, strides=1, padding='valid')(ZeroPadding2D(3)(x))
            x = InstanceNormalization()(x)
            x = ReLU()(x)
            for i in range(downs):
                x = Conv2D (nf * 2**(i+1), kernel_size=4, strides=2, padding='valid')(ZeroPadding2D(1)(x))
                x = InstanceNormalization()(x)
                x = ReLU()(x)
            for i in range(n_res_blks):
                x = ResBlock( nf * 2**downs )(x)
            return x

        return func
    @staticmethod
    def ClassModelEncoderFlow(downs=4, nf=64, latent_dim=64):
        exec (nnlib.import_all(), locals(), globals())

        def func(x):
            x = Conv2D (nf, kernel_size=7, strides=1, padding='valid', activation='relu')(ZeroPadding2D(3)(x))
            for i in range(downs):
                x = Conv2D (nf * min ( 4, 2**(i+1) ), kernel_size=4, strides=2, padding='valid', activation='relu')(ZeroPadding2D(1)(x))
            x = GlobalAveragePooling2D()(x)
            x = Dense(latent_dim)(x)
            return x

        return func
    @staticmethod
    def DecoderFlow(ups, n_res_blks=2, mlp_nf=256, mlp_blks=2, subpixel_decoder=False ):
        exec (nnlib.import_all(), locals(), globals())

        def ResBlock(dim):
            def func(input):
                inp, mlp = input
                x = inp
                x = Conv2D(dim, 3, strides=1, padding='valid')(ZeroPadding2D(1)(x))
                x = FUNITAdain()([x,mlp])
                x = ReLU()(x)
                x = Conv2D(dim, 3, strides=1, padding='valid')(ZeroPadding2D(1)(x))
                x = FUNITAdain()([x,mlp])
                return Add()([x,inp])
            return func

        def func(inputs):
            x , class_code = inputs

            nf = x.shape[-1].value

            ### MLP block inside decoder
            mlp = class_code
            for i in range(mlp_blks):
                mlp = Dense(mlp_nf, activation='relu')(mlp)

            for i in range(n_res_blks):
                x = ResBlock(nf)( [x,mlp] )

            for i in range(ups):

                if subpixel_decoder:
                    x = Conv2D (4* (nf // 2**(i+1)), kernel_size=3, strides=1, padding='valid')(ZeroPadding2D(1)(x))
                    x = SubpixelUpscaler()(x)
                else:
                    x = UpSampling2D()(x)
                    x = Conv2D (nf // 2**(i+1), kernel_size=5, strides=1, padding='valid')(ZeroPadding2D(2)(x))

                x = InstanceNormalization()(x)
                x = ReLU()(x)

            rgb = Conv2D (3, kernel_size=7, strides=1, padding='valid', activation='tanh')(ZeroPadding2D(3)(x))
            return rgb

        return func
    @staticmethod
    def DiscriminatorFlow(nf, n_res_blks, num_classes ):
        exec (nnlib.import_all(), locals(), globals())

        n_layers = n_res_blks // 2

        def ActFirstResBlock(fout):
            def func(x):
                fin = K.int_shape(x)[-1]
                fhid = min(fin, fout)

                if fin != fout:
                    x_s = Conv2D (fout, kernel_size=1, strides=1, padding='valid', use_bias=False)(x)
                else:
                    x_s = x

                x = LeakyReLU(0.2)(x)
                x = Conv2D (fhid, kernel_size=3, strides=1, padding='valid')(ZeroPadding2D(1)(x))
                x = LeakyReLU(0.2)(x)
                x = Conv2D (fout, kernel_size=3, strides=1, padding='valid')(ZeroPadding2D(1)(x))
                return Add()([x_s, x])

            return func

        def func( x ):
            l_nf = nf
            x = Conv2D (l_nf, kernel_size=7, strides=1, padding='valid')(ZeroPadding2D(3)(x))
            for i in range(n_layers-1):
                l_nf_out = min( l_nf*2, 1024 )
                x = ActFirstResBlock(l_nf)(x)
                x = ActFirstResBlock(l_nf_out)(x)
                x = AveragePooling2D( pool_size=3, strides=2, padding='valid' )(ZeroPadding2D(1)(x))
                l_nf = min( l_nf*2, 1024 )

            l_nf_out = min( l_nf*2, 1024 )
            x = ActFirstResBlock(l_nf)(x)
            feat = x = ActFirstResBlock(l_nf_out)(x)

            x = LeakyReLU(0.2)(x)
            x = Conv2D (num_classes, kernel_size=1, strides=1, padding='valid')(x)

            return x, feat

        return func