from functools import partial

import cv2
import numpy as np

from facelib import FaceType
from interact import interact as io
from mathlib import get_power_of_two
from models import ModelBase
from nnlib import nnlib
from samplelib import *

from facelib import PoseEstimator

class AVATARModel(ModelBase):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs,
                         ask_sort_by_yaw=False,
                         ask_random_flip=False,
                         ask_src_scale_mod=False)

    #override
    def onInitializeOptions(self, is_first_run, ask_override):
        if is_first_run:
            avatar_type = io.input_int("Avatar type ( 0:source, 1:head, 2:full_face ?:help skip:1) : ", 1, [0,1,2],
                                       help_message="Training target for the model. Source is direct untouched images. Full_face and head are unaligned faces centered on the nose.")
            avatar_type = {0:'source',
                           1:'head',
                           2:'full_face'}[avatar_type]

            self.options['avatar_type'] = avatar_type
        else:
            self.options['avatar_type'] = self.options.get('avatar_type', 'head')

        if is_first_run or ask_override:
            def_stage = self.options.get('stage', 1)
            self.options['stage'] = io.input_int("Stage (0, 1, 2 ?:help skip:%d) : " % def_stage, def_stage, [0,1,2], help_message="Train the first stage, then the second. Tune the batch size to the maximum possible for both stages.")
        else:
            self.options['stage'] = self.options.get('stage', 1)
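
        # Note: stage 1 trains only the 128px autoencoder (enc/decA64/decB64)
        # and its discriminator D; stage 2 trains only the 224px temporal
        # network C; stage 0 trains everything at once. See the generator
        # gating in onInitialize() and the branches in onTrainOneIter() below.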

    #override
    def onInitialize(self, batch_size=-1, **in_options):
        exec(nnlib.code_import_all, locals(), globals())
        self.set_vram_batch_requirements({6:4})
        AVATARModel.initialize_nn_functions()

        resolution = self.resolution = 224
        avatar_type = self.options['avatar_type']
        stage = self.stage = self.options['stage']
        df_res = self.df_res = 128
        df_bgr_shape = (df_res, df_res, 3)
        df_mask_shape = (df_res, df_res, 1)
        res_bgr_shape = (resolution, resolution, 3)
        res_bgr_t_shape = (resolution, resolution, 9)

        self.enc = modelify(AVATARModel.EncFlow())( [Input(df_bgr_shape),] )

        self.decA64 = modelify(AVATARModel.DecFlow()) ( [ Input(K.int_shape(self.enc.outputs[0])[1:]) ] )
        self.decB64 = modelify(AVATARModel.DecFlow()) ( [ Input(K.int_shape(self.enc.outputs[0])[1:]) ] )
        self.D = modelify(AVATARModel.Discriminator() ) (Input(df_bgr_shape))
        self.C = modelify(AVATARModel.ResNet (9, use_batch_norm=False, n_blocks=6, ngf=128, use_dropout=False))( Input(res_bgr_t_shape))
        #self.CD = modelify(AVATARModel.CDiscriminator() ) (Input(res_bgr_t_shape))

        if self.is_first_run():
            conv_weights_list = []
            for model, _ in self.get_model_filename_list():
                for layer in model.layers:
                    if type(layer) == keras.layers.Conv2D:
                        conv_weights_list += [layer.weights[0]] #Conv2D kernel weights
            CAInitializerMP ( conv_weights_list )
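            # On the first run, every Conv2D kernel gathered above is given a
            # convolution-aware initialization via nnlib's CAInitializerMP
            # (the MP suffix suggests it runs across multiple processes);
            # subsequent runs load saved weights instead (below).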

        if not self.is_first_run():
            self.load_weights_safe( self.get_model_filename_list() )

        def DLoss(labels, logits):
            return K.mean(K.binary_crossentropy(labels, logits))

        warped_A64 = Input(df_bgr_shape)
        real_A64 = Input(df_bgr_shape)
        real_A64m = Input(df_mask_shape)

        real_B64_t0 = Input(df_bgr_shape)
        real_B64_t1 = Input(df_bgr_shape)
        real_B64_t2 = Input(df_bgr_shape)

        real_A64_t0 = Input(df_bgr_shape)
        real_A64m_t0 = Input(df_mask_shape)
        real_A_t0 = Input(res_bgr_shape)
        real_A64_t1 = Input(df_bgr_shape)
        real_A64m_t1 = Input(df_mask_shape)
        real_A_t1 = Input(res_bgr_shape)
        real_A64_t2 = Input(df_bgr_shape)
        real_A64m_t2 = Input(df_mask_shape)
        real_A_t2 = Input(res_bgr_shape)

        warped_B64 = Input(df_bgr_shape)
        real_B64 = Input(df_bgr_shape)
        real_B64m = Input(df_mask_shape)

        warped_A_code = self.enc (warped_A64)
        warped_B_code = self.enc (warped_B64)

        rec_A64 = self.decA64(warped_A_code)
        rec_B64 = self.decB64(warped_B_code)
        rec_AB64 = self.decA64(warped_B_code)

        def Lambda_grey_mask (x, m):
            return Lambda (lambda x: x[0]*m + (1-m)*0.5, output_shape=K.int_shape(x)[1:3] + (3,)) ([x, m])

        def Lambda_gray_pad(x):
            a = np.ones((resolution, resolution, 3))*0.5
            pad = ( resolution - df_res ) // 2
            a[pad:-pad, pad:-pad, :] = 0

            return Lambda ( lambda x: K.spatial_2d_padding(x, padding=((pad, pad), (pad, pad)) ) + K.constant(a, dtype=K.floatx() ),
                            output_shape=(resolution, resolution, 3) ) (x)
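
        # Worked example of the padding above: with resolution=224 and
        # df_res=128, pad = (224 - 128) // 2 = 48, so a 128px tensor is
        # zero-padded to 224px and the constant `a` then lifts only the 48px
        # border to 0.5 mid-grey, leaving the centre values untouched.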

        def Lambda_concat ( x ):
            c = sum ( [ K.int_shape(l)[-1] for l in x ] )
            return Lambda ( lambda x: K.concatenate (x, axis=-1), output_shape=K.int_shape(x[0])[1:3] + (c,) ) (x)

        def Lambda_Cto3t(x):
            return Lambda ( lambda x: x[...,0:3], output_shape=K.int_shape(x)[1:3] + (3,) ) (x), \
                   Lambda ( lambda x: x[...,3:6], output_shape=K.int_shape(x)[1:3] + (3,) ) (x), \
                   Lambda ( lambda x: x[...,6:9], output_shape=K.int_shape(x)[1:3] + (3,) ) (x)

        real_A64_d = self.D( Lambda_grey_mask(real_A64, real_A64m) )

        real_A64_d_ones = K.ones_like(real_A64_d)
        fake_A64_d = self.D(rec_AB64)
        fake_A64_d_ones = K.ones_like(fake_A64_d)
        fake_A64_d_zeros = K.zeros_like(fake_A64_d)

        rec_AB_t0 = Lambda_gray_pad( self.decA64 (self.enc (real_B64_t0)) )
        rec_AB_t1 = Lambda_gray_pad( self.decA64 (self.enc (real_B64_t1)) )
        rec_AB_t2 = Lambda_gray_pad( self.decA64 (self.enc (real_B64_t2)) )

        C_in_A_t0 = Lambda_gray_pad( Lambda_grey_mask (real_A64_t0, real_A64m_t0) )
        C_in_A_t1 = Lambda_gray_pad( Lambda_grey_mask (real_A64_t1, real_A64m_t1) )
        C_in_A_t2 = Lambda_gray_pad( Lambda_grey_mask (real_A64_t2, real_A64m_t2) )

        rec_C_A_t0, rec_C_A_t1, rec_C_A_t2 = Lambda_Cto3t ( self.C ( Lambda_concat ( [C_in_A_t0, C_in_A_t1, C_in_A_t2]) ) )
        rec_C_AB_t0, rec_C_AB_t1, rec_C_AB_t2 = Lambda_Cto3t( self.C ( Lambda_concat ( [rec_AB_t0, rec_AB_t1, rec_AB_t2]) ) )

        #real_A_t012_d = self.CD ( K.concatenate ( [real_A_t0, real_A_t1, real_A_t2], axis=-1) )
        #real_A_t012_d_ones = K.ones_like(real_A_t012_d)
        #rec_C_AB_t012_d = self.CD ( K.concatenate ( [rec_C_AB_t0, rec_C_AB_t1, rec_C_AB_t2], axis=-1) )
        #rec_C_AB_t012_d_ones = K.ones_like(rec_C_AB_t012_d)
        #rec_C_AB_t012_d_zeros = K.zeros_like(rec_C_AB_t012_d)

        self.G64_view = K.function([warped_A64, warped_B64], [rec_A64, rec_B64, rec_AB64])
        self.G_view = K.function([real_A64_t0, real_A64m_t0, real_A64_t1, real_A64m_t1, real_A64_t2, real_A64m_t2, real_B64_t0, real_B64_t1, real_B64_t2], [rec_C_A_t0, rec_C_A_t1, rec_C_A_t2, rec_C_AB_t0, rec_C_AB_t1, rec_C_AB_t2])
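
        # Conversion path: each dst (B) frame goes through the shared encoder
        # and the A decoder to yield a 128px A-styled face, which is
        # grey-padded to 224px; C then refines the padded triplet. The same C
        # is supervised on real A triplets built from masked, padded inputs.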

        if self.is_training_mode:
            loss_AB64 = K.mean(10 * dssim(kernel_size=int(df_res/11.6), max_value=1.0)( rec_A64, real_A64*real_A64m + (1-real_A64m)*0.5) ) + \
                        K.mean(10 * dssim(kernel_size=int(df_res/11.6), max_value=1.0)( rec_B64, real_B64*real_B64m + (1-real_B64m)*0.5) ) + 0.1*DLoss(fake_A64_d_ones, fake_A64_d )

            weights_AB64 = self.enc.trainable_weights + self.decA64.trainable_weights + self.decB64.trainable_weights

            loss_C = K.mean( 10 * dssim(kernel_size=int(resolution/11.6), max_value=1.0)( real_A_t0, rec_C_A_t0 ) ) + \
                     K.mean( 10 * dssim(kernel_size=int(resolution/11.6), max_value=1.0)( real_A_t1, rec_C_A_t1 ) ) + \
                     K.mean( 10 * dssim(kernel_size=int(resolution/11.6), max_value=1.0)( real_A_t2, rec_C_A_t2 ) )
                     #+ 0.1*DLoss(rec_C_AB_t012_d_ones, rec_C_AB_t012_d )

            weights_C = self.C.trainable_weights

            loss_D = (DLoss(real_A64_d_ones, real_A64_d ) + \
                      DLoss(fake_A64_d_zeros, fake_A64_d ) ) * 0.5
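
            # Loss arithmetic: the DSSIM window is int(res/11.6), i.e. kernel
            # 11 at df_res=128 and 19 at resolution=224. Targets are blended
            # with 0.5 grey outside the mask so the background carries no
            # signal, and the 0.1-weighted DLoss term pushes decA64's B->A
            # output toward the discriminator's "real A" decision.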

            #loss_CD = ( DLoss(real_A_t012_d_ones, real_A_t012_d) + \
            #            DLoss(rec_C_AB_t012_d_zeros, rec_C_AB_t012_d) ) * 0.5
            #
            #weights_CD = self.CD.trainable_weights

            def opt(lr=5e-5):
                return Adam(lr=lr, beta_1=0.5, beta_2=0.999, tf_cpu_mode=2 if 'tensorflow' in self.device_config.backend else 0 )

            self.AB64_train = K.function ([warped_A64, real_A64, real_A64m, warped_B64, real_B64, real_B64m], [loss_AB64], opt().get_updates(loss_AB64, weights_AB64) )
            self.C_train = K.function ([real_A64_t0, real_A64m_t0, real_A_t0,
                                        real_A64_t1, real_A64m_t1, real_A_t1,
                                        real_A64_t2, real_A64m_t2, real_A_t2,
                                        real_B64_t0, real_B64_t1, real_B64_t2], [ loss_C ], opt().get_updates(loss_C, weights_C) )

            self.D_train = K.function ([warped_A64, real_A64, real_A64m, warped_B64, real_B64, real_B64m], [loss_D], opt().get_updates(loss_D, self.D.trainable_weights) )
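
            # Each K.function pairs a forward pass with its own Adam(lr=5e-5,
            # beta_1=0.5) update op, so the generator, C and the discriminator
            # are stepped independently from onTrainOneIter().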

            #self.CD_train = K.function ([real_A64_t0, real_A64m_t0, real_A_t0,
            #                             real_A64_t1, real_A64m_t1, real_A_t1,
            #                             real_A64_t2, real_A64m_t2, real_A_t2,
            #                             real_B64_t0, real_B64_t1, real_B64_t2 ], [ loss_CD ], opt().get_updates(loss_CD, weights_CD) )

            ###########
            t = SampleProcessor.Types

            training_target = {'source'    : t.NONE,
                               'full_face' : t.FACE_TYPE_FULL_NO_ALIGN,
                               'head'      : t.FACE_TYPE_HEAD_NO_ALIGN}[avatar_type]

            generators = [
                    SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
                                        sample_process_options=SampleProcessor.Options(random_flip=False),
                                        output_sample_types=[ {'types': (t.IMG_WARPED_TRANSFORMED, t.FACE_TYPE_FULL_NO_ALIGN, t.MODE_BGR), 'resolution':df_res},
                                                              {'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL_NO_ALIGN, t.MODE_BGR), 'resolution':df_res},
                                                              {'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL_NO_ALIGN, t.MODE_M), 'resolution':df_res}
                                                            ] ),
                    SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                                        sample_process_options=SampleProcessor.Options(random_flip=False),
                                        output_sample_types=[ {'types': (t.IMG_WARPED_TRANSFORMED, t.FACE_TYPE_FULL_NO_ALIGN, t.MODE_BGR), 'resolution':df_res},
                                                              {'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL_NO_ALIGN, t.MODE_BGR), 'resolution':df_res},
                                                              {'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL_NO_ALIGN, t.MODE_M), 'resolution':df_res}
                                                            ] ),

                    SampleGeneratorFaceTemporal(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
                                                temporal_image_count=3,
                                                sample_process_options=SampleProcessor.Options(random_flip=False),
                                                output_sample_types=[ {'types': (t.IMG_WARPED_TRANSFORMED, t.FACE_TYPE_FULL_NO_ALIGN, t.MODE_BGR), 'resolution':df_res},
                                                                      {'types': (t.IMG_WARPED_TRANSFORMED, t.FACE_TYPE_FULL_NO_ALIGN, t.MODE_M), 'resolution':df_res},
                                                                      {'types': (t.IMG_SOURCE, training_target, t.MODE_BGR), 'resolution':resolution},
                                                                    ] ),

                    SampleGeneratorFaceTemporal(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                                                temporal_image_count=3,
                                                sample_process_options=SampleProcessor.Options(random_flip=False),
                                                output_sample_types=[ {'types': (t.IMG_SOURCE, t.FACE_TYPE_FULL_NO_ALIGN, t.MODE_BGR), 'resolution':df_res},
                                                                      {'types': (t.IMG_SOURCE, t.NONE, t.MODE_BGR), 'resolution':resolution},
                                                                    ] ),
                ]
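
            # Generators 0/1 feed the 128px autoencoder stage (warped input,
            # target and mask for src and dst); generators 2/3 feed the
            # temporal stage with triplets of consecutive frames. Only the
            # generators needed for the chosen stage stay active below.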

            if self.stage == 1:
                generators[2].set_active(False)
                generators[3].set_active(False)
            elif self.stage == 2:
                generators[0].set_active(False)
                generators[1].set_active(False)

            self.set_training_data_generators (generators)
        else:
            self.G_convert = K.function([real_B64_t0, real_B64_t1, real_B64_t2], [rec_C_AB_t1])

    #override , return [ [model, filename], ... ] list
    def get_model_filename_list(self):
        return [ [self.enc, 'enc.h5'],
                 [self.decA64, 'decA64.h5'],
                 [self.decB64, 'decB64.h5'],
                 [self.C, 'C.h5'],
                 [self.D, 'D.h5'],
                 #[self.CD, 'CD.h5'],
               ]

    #override
    def onSave(self):
        self.save_weights_safe( self.get_model_filename_list() )

    #override
    def onTrainOneIter(self, generators_samples, generators_list):
        warped_src64, src64, src64m = generators_samples[0]
        warped_dst64, dst64, dst64m = generators_samples[1]

        real_A64_t0, real_A64m_t0, real_A_t0, real_A64_t1, real_A64m_t1, real_A_t1, real_A64_t2, real_A64m_t2, real_A_t2 = generators_samples[2]
        real_B64_t0, _, real_B64_t1, _, real_B64_t2, _ = generators_samples[3]

        if self.stage == 0 or self.stage == 1:
            loss,   = self.AB64_train ( [warped_src64, src64, src64m, warped_dst64, dst64, dst64m] )
            loss_D, = self.D_train ( [warped_src64, src64, src64m, warped_dst64, dst64, dst64m] )
            if self.stage != 0:
                loss_C = loss_CD = 0

        if self.stage == 0 or self.stage == 2:
            loss_C1, = self.C_train ( [real_A64_t0, real_A64m_t0, real_A_t0,
                                       real_A64_t1, real_A64m_t1, real_A_t1,
                                       real_A64_t2, real_A64m_t2, real_A_t2,
                                       real_B64_t0, real_B64_t1, real_B64_t2] )

            loss_C2, = self.C_train ( [real_A64_t2, real_A64m_t2, real_A_t2,
                                       real_A64_t1, real_A64m_t1, real_A_t1,
                                       real_A64_t0, real_A64m_t0, real_A_t0,
                                       real_B64_t0, real_B64_t1, real_B64_t2] )
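
            # C sees each A triplet twice, once in t0,t1,t2 order and once
            # reversed, as a cheap temporal augmentation; the two losses are
            # averaged below.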

            #loss_CD1, = self.CD_train ( [real_A64_t0, real_A64m_t0, real_A_t0,
            #                             real_A64_t1, real_A64m_t1, real_A_t1,
            #                             real_A64_t2, real_A64m_t2, real_A_t2,
            #                             real_B64_t0, real_B64_t1, real_B64_t2] )
            #
            #loss_CD2, = self.CD_train ( [real_A64_t2, real_A64m_t2, real_A_t2,
            #                             real_A64_t1, real_A64m_t1, real_A_t1,
            #                             real_A64_t0, real_A64m_t0, real_A_t0,
            #                             real_B64_t0, real_B64_t1, real_B64_t2] )

            loss_C = (loss_C1 + loss_C2) / 2
            #loss_CD = (loss_CD1 + loss_CD2) / 2
            if self.stage != 0:
                loss = loss_D = 0

        return ( ('loss', loss), ('D', loss_D), ('C', loss_C), ) #('CD', loss_CD)

    #override
    def onGetPreview(self, sample):
        test_A064w = sample[0][0][0:4]
        test_A064r = sample[0][1][0:4]
        test_A064m = sample[0][2][0:4]

        test_B064w = sample[1][0][0:4]
        test_B064r = sample[1][1][0:4]
        test_B064m = sample[1][2][0:4]

        t_src64_0  = sample[2][0][0:4]
        t_src64m_0 = sample[2][1][0:4]
        t_src_0    = sample[2][2][0:4]
        t_src64_1  = sample[2][3][0:4]
        t_src64m_1 = sample[2][4][0:4]
        t_src_1    = sample[2][5][0:4]
        t_src64_2  = sample[2][6][0:4]
        t_src64m_2 = sample[2][7][0:4]
        t_src_2    = sample[2][8][0:4]

        t_dst64_0 = sample[3][0][0:4]
        t_dst_0   = sample[3][1][0:4]
        t_dst64_1 = sample[3][2][0:4]
        t_dst_1   = sample[3][3][0:4]
        t_dst64_2 = sample[3][4][0:4]
        t_dst_2   = sample[3][5][0:4]

        G64_view_result = self.G64_view ([test_A064r, test_B064r])
        test_A064r, test_B064r, rec_A64, rec_B64, rec_AB64 = [ x[0] for x in ([test_A064r, test_B064r] + G64_view_result) ]

        sample64x4 = np.concatenate ([ np.concatenate ( [rec_B64, rec_A64], axis=1 ),
                                       np.concatenate ( [test_B064r, rec_AB64], axis=1) ], axis=0 )

        sample64x4 = cv2.resize (sample64x4, (self.resolution, self.resolution) )
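
        # The 2x2 grid of 128px tiles (rec_B | rec_A over real_B | rec_AB) is
        # resized up to 224px so it can sit alongside the temporal previews.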

        G_view_result = self.G_view([t_src64_0, t_src64m_0, t_src64_1, t_src64m_1, t_src64_2, t_src64m_2, t_dst64_0, t_dst64_1, t_dst64_2 ])

        t_dst_0, t_dst_1, t_dst_2, rec_C_A_t0, rec_C_A_t1, rec_C_A_t2, rec_C_AB_t0, rec_C_AB_t1, rec_C_AB_t2 = [ x[0] for x in ([t_dst_0, t_dst_1, t_dst_2] + G_view_result) ]

        c1 = np.concatenate ( (sample64x4, rec_C_A_t0, t_dst_0, rec_C_AB_t0 ), axis=1 )
        c2 = np.concatenate ( (sample64x4, rec_C_A_t1, t_dst_1, rec_C_AB_t1 ), axis=1 )
        c3 = np.concatenate ( (sample64x4, rec_C_A_t2, t_dst_2, rec_C_AB_t2 ), axis=1 )

        r = np.concatenate ( [c1, c2, c3], axis=0 )

        return [ ('AVATAR', r ) ]

    def predictor_func (self, prev_imgs=None, img=None, next_imgs=None, dummy_predict=False):
        if dummy_predict:
            z = np.zeros ( (1, self.df_res, self.df_res, 3), dtype=np.float32 )
            self.G_convert ([z, z, z])
        else:
            feed = [ prev_imgs[-1][np.newaxis,...], img[np.newaxis,...], next_imgs[0][np.newaxis,...] ]
            x = self.G_convert (feed)[0]
            return np.clip ( x[0], 0, 1)

    #override
    def get_ConverterConfig(self):
        import converters
        return self.predictor_func, (self.df_res, self.df_res, 3), converters.ConverterConfigFaceAvatar(temporal_face_count=1)
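
    # temporal_face_count=1 appears to request one neighbouring frame on each
    # side, matching predictor_func's (prev, current, next) feed to G_convert.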

    @staticmethod
    def NLayerDiscriminator(ndf=64, n_layers=3):
        exec (nnlib.import_all(), locals(), globals())

        #use_bias = True
        #def XNormalization(x):
        #    return InstanceNormalization (axis=-1)(x)
        use_bias = False
        def XNormalization(x):
            return BatchNormalization (axis=-1)(x)

        XConv2D = partial(Conv2D, use_bias=use_bias)

        def func(x):
            f = ndf

            x = XConv2D( f, 4, strides=2, padding='same', use_bias=True)(x)
            f = min( ndf*8, f*2 )
            x = LeakyReLU(0.2)(x)

            for i in range(n_layers):
                x = XConv2D( f, 4, strides=2, padding='same')(x)
                x = XNormalization(x)
                x = LeakyReLU(0.2)(x)
                f = min( ndf*8, f*2 )

            x = XConv2D( f, 4, strides=1, padding='same')(x)
            x = XNormalization(x)
            x = LeakyReLU(0.2)(x)

            return XConv2D( 1, 4, strides=1, padding='same', use_bias=True, activation='sigmoid')(x)
        return func
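
    # NLayerDiscriminator follows the familiar PatchGAN layout: a stack of
    # strided 4x4 convs, then a stride-1 conv and a sigmoid 1-channel map, so
    # each output pixel classifies one receptive-field patch as real or fake.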

    """
    @staticmethod
    def Discriminator(ndf=128):
        exec (nnlib.import_all(), locals(), globals())

        #use_bias = True
        #def XNormalization(x):
        #    return InstanceNormalization (axis=-1)(x)
        use_bias = False
        def XNormalization(x):
            return BatchNormalization (axis=-1)(x)

        XConv2D = partial(Conv2D, use_bias=use_bias)

        def func(input):
            b,h,w,c = K.int_shape(input)

            x = input

            x = XConv2D( ndf, 4, strides=2, padding='same', use_bias=True)(x)
            x = LeakyReLU(0.2)(x)

            x = XConv2D( ndf*2, 4, strides=2, padding='same')(x)
            x = XNormalization(x)
            x = LeakyReLU(0.2)(x)

            x = XConv2D( ndf*4, 4, strides=2, padding='same')(x)
            x = XNormalization(x)
            x = LeakyReLU(0.2)(x)

            x = XConv2D( ndf*8, 4, strides=2, padding='same')(x)
            x = XNormalization(x)
            x = LeakyReLU(0.2)(x)

            return XConv2D( 1, 4, strides=1, padding='same', use_bias=True, activation='sigmoid')(x)
        return func
    """

    @staticmethod
    def Discriminator(ndf=128):
        exec (nnlib.import_all(), locals(), globals())

        use_bias = True
        def XNormalization(x):
            return InstanceNormalization (axis=-1)(x)
        #use_bias = False
        #def XNormalization(x):
        #    return BatchNormalization (axis=-1)(x)

        XConv2D = partial(Conv2D, use_bias=use_bias)

        def func(input):
            b,h,w,c = K.int_shape(input)

            x = input

            x = XConv2D( ndf, 4, strides=2, padding='same', use_bias=True)(x)
            x = LeakyReLU(0.2)(x)

            x = XConv2D( ndf*2, 4, strides=2, padding='same')(x)
            x = XNormalization(x)
            x = LeakyReLU(0.2)(x)

            x = XConv2D( ndf*4, 4, strides=2, padding='same')(x)
            x = XNormalization(x)
            x = LeakyReLU(0.2)(x)

            x = XConv2D( ndf*8, 4, strides=2, padding='same')(x)
            x = XNormalization(x)
            x = LeakyReLU(0.2)(x)

            return XConv2D( 1, 4, strides=1, padding='same', use_bias=True, activation='sigmoid')(x)
        return func

    @staticmethod
    def CDiscriminator(ndf=256):
        exec (nnlib.import_all(), locals(), globals())

        use_bias = True
        def XNormalization(x):
            return InstanceNormalization (axis=-1)(x)
        #use_bias = False
        #def XNormalization(x):
        #    return BatchNormalization (axis=-1)(x)

        XConv2D = partial(Conv2D, use_bias=use_bias)

        def func(input):
            b,h,w,c = K.int_shape(input)

            x = input

            x = XConv2D( ndf, 4, strides=2, padding='same', use_bias=True)(x)
            x = LeakyReLU(0.2)(x)

            x = XConv2D( ndf*2, 4, strides=2, padding='same')(x)
            x = XNormalization(x)
            x = LeakyReLU(0.2)(x)

            x = XConv2D( ndf*4, 4, strides=2, padding='same')(x)
            x = XNormalization(x)
            x = LeakyReLU(0.2)(x)

            #x = XConv2D( ndf*8, 4, strides=2, padding='same')(x)
            #x = XNormalization(x)
            #x = LeakyReLU(0.2)(x)

            return XConv2D( 1, 4, strides=1, padding='same', use_bias=True, activation='sigmoid')(x)
        return func

    @staticmethod
    def EncFlow(padding='zero', **kwargs):
        exec (nnlib.import_all(), locals(), globals())

        use_bias = False
        def XNorm(x):
            return BatchNormalization (axis=-1)(x)
        XConv2D = partial(Conv2D, padding=padding, use_bias=use_bias)

        def downscale (dim):
            def func(x):
                return LeakyReLU(0.1)( Conv2D(dim, 5, strides=2, padding='same')(x))
            return func

        def upscale (dim):
            def func(x):
                return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
            return func

        def func(input):
            x, = input
            b,h,w,c = K.int_shape(x)

            dim_res = w // 16

            x = downscale(64)(x)
            x = downscale(128)(x)
            x = downscale(256)(x)
            x = downscale(512)(x)

            x = Dense(512)(Flatten()(x))
            x = Dense(dim_res * dim_res * 512)(x)
            x = Reshape((dim_res, dim_res, 512))(x)
            x = upscale(512)(x)
            return x

        return func
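
    # For the 128px input used here, EncFlow downscales 128->64->32->16->8,
    # flattens into a 512-d code, re-projects to 8*8*512 (dim_res = 128//16 = 8),
    # reshapes and upscales once, handing a 16x16x512 tensor to the decoders.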

    @staticmethod
    def DecFlow(output_nc=3, **kwargs):
        exec (nnlib.import_all(), locals(), globals())

        ResidualBlock = AVATARModel.ResidualBlock
        upscale = AVATARModel.upscale
        to_bgr = AVATARModel.to_bgr

        def func(input):
            x = input[0]

            x = upscale(512)(x)
            x = upscale(256)(x)
            x = upscale(128)(x)
            return to_bgr(output_nc) (x)

        return func
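
    # DecFlow mirrors the encoder tail: three subpixel upscales take the
    # 16x16x512 code back to 128x128 before the sigmoid BGR projection.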

    """
    @staticmethod
    def CNet(output_nc, use_batch_norm, ngf=64, n_blocks=6, use_dropout=False):
        exec (nnlib.import_all(), locals(), globals())

        if not use_batch_norm:
            use_bias = True
            def XNormalization(x):
                return InstanceNormalization (axis=-1)(x)
        else:
            use_bias = False
            def XNormalization(x):
                return BatchNormalization (axis=-1)(x)

        XConv2D = partial(Conv2D, padding='same', use_bias=use_bias)
        XConv2DTranspose = partial(Conv2DTranspose, padding='same', use_bias=use_bias)

        def ResnetBlock(dim, use_dropout=False):
            def func(input):
                x = input

                x = XConv2D(dim, 3, strides=1)(x)
                x = XNormalization(x)
                x = ReLU()(x)

                if use_dropout:
                    x = Dropout(0.5)(x)

                x = XConv2D(dim, 3, strides=1)(x)
                x = XNormalization(x)
                x = ReLU()(x)
                return Add()([x,input])
            return func

        def preprocess(target_res):
            def func(input):
                inp_shape = K.int_shape (input[0])
                t_len = len(input)
                total_ch = 0
                for i in range(t_len):
                    total_ch += K.int_shape (input[i])[-1]

                x = K.concatenate ( input, axis=-1)
                import code
                code.interact(local=dict(globals(), **locals()))

                x_shape = K.int_shape(x)[1:]

                pad = (target_res - x_shape[0]) // 2

                a = np.ones((target_res,target_res,3))*0.5
                a[pad:-pad, pad:-pad, :] = 0
                return K.spatial_2d_padding(x, padding=((pad, pad), (pad, pad)) ) + K.constant(a, dtype=K.floatx() )
            return func

        def func(input):
            inp_shape = K.int_shape (input[0])
            t_len = len(input)
            total_ch = 0
            for i in range(t_len):
                total_ch += K.int_shape (input[i])[-1]

            x = Lambda ( preprocess(128), output_shape=(inp_shape[1], inp_shape[2], total_ch) ) (input)

            x = ReLU()(XNormalization(XConv2D(ngf, 7, strides=1)(x)))

            x = ReLU()(XNormalization(XConv2D(ngf*2, 3, strides=2)(x)))
            x = ReLU()(XNormalization(XConv2D(ngf*4, 3, strides=2)(x)))

            for i in range(n_blocks):
                x = ResnetBlock(ngf*4, use_dropout=use_dropout)(x)

            x = ReLU()(XNormalization(XConv2DTranspose(ngf*2, 3, strides=2)(x)))
            x = ReLU()(XNormalization(XConv2DTranspose(ngf , 3, strides=2)(x)))

            x = XConv2D(output_nc, 7, strides=1, activation='sigmoid', use_bias=True)(x)

            return x

        return func
    """

    @staticmethod
    def ResNet(output_nc, use_batch_norm, ngf=64, n_blocks=6, use_dropout=False):
        exec (nnlib.import_all(), locals(), globals())

        if not use_batch_norm:
            use_bias = True
            def XNormalization(x):
                return InstanceNormalization (axis=-1)(x)
        else:
            use_bias = False
            def XNormalization(x):
                return BatchNormalization (axis=-1)(x)

        XConv2D = partial(Conv2D, padding='same', use_bias=use_bias)
        XConv2DTranspose = partial(Conv2DTranspose, padding='same', use_bias=use_bias)

        def func(input):

            def ResnetBlock(dim, use_dropout=False):
                def func(input):
                    x = input

                    x = XConv2D(dim, 3, strides=1)(x)
                    x = XNormalization(x)
                    x = ReLU()(x)

                    if use_dropout:
                        x = Dropout(0.5)(x)

                    x = XConv2D(dim, 3, strides=1)(x)
                    x = XNormalization(x)
                    x = ReLU()(x)
                    return Add()([x,input])
                return func

            x = input

            x = ReLU()(XNormalization(XConv2D(ngf, 7, strides=1)(x)))

            x = ReLU()(XNormalization(XConv2D(ngf*2, 3, strides=2)(x)))
            x = ReLU()(XNormalization(XConv2D(ngf*4, 3, strides=2)(x)))

            x = ReLU()(XNormalization(XConv2D(ngf*4, 3, strides=2)(x)))

            for i in range(n_blocks):
                x = ResnetBlock(ngf*4, use_dropout=use_dropout)(x)

            x = ReLU()(XNormalization(XConv2DTranspose(ngf*4, 3, strides=2)(x)))

            x = ReLU()(XNormalization(XConv2DTranspose(ngf*2, 3, strides=2)(x)))
            x = ReLU()(XNormalization(XConv2DTranspose(ngf , 3, strides=2)(x)))

            x = XConv2D(output_nc, 7, strides=1, activation='sigmoid', use_bias=True)(x)

            return x

        return func
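
    # C is built above as ResNet(9, ...): it takes the concatenated 224px
    # triplet (3 frames x 3 BGR channels) and emits 9 channels that
    # Lambda_Cto3t splits back into three refined frames. The layout resembles
    # the CycleGAN-style resnet generator, here with three downsamples and
    # three transposed-conv upsamples around the residual stack.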

    @staticmethod
    def initialize_nn_functions():
        exec (nnlib.import_all(), locals(), globals())

        class ResidualBlock(object):
            def __init__(self, filters, kernel_size=3, padding='zero', **kwargs):
                self.filters = filters
                self.kernel_size = kernel_size
                self.padding = padding

            def __call__(self, inp):
                x = inp
                x = Conv2D(self.filters, kernel_size=self.kernel_size, padding=self.padding)(x)
                x = LeakyReLU(0.2)(x)
                x = Conv2D(self.filters, kernel_size=self.kernel_size, padding=self.padding)(x)
                x = Add()([x, inp])
                x = LeakyReLU(0.2)(x)
                return x
        AVATARModel.ResidualBlock = ResidualBlock

        def downscale (dim, padding='zero', act='', **kwargs):
            def func(x):
                return LeakyReLU(0.2) (Conv2D(dim, kernel_size=5, strides=2, padding=padding)(x))
            return func
        AVATARModel.downscale = downscale

        def upscale (dim, padding='zero', norm='', act='', **kwargs):
            def func(x):
                return SubpixelUpscaler()( LeakyReLU(0.2)(Conv2D(dim * 4, kernel_size=3, strides=1, padding=padding)(x)))
            return func
        AVATARModel.upscale = upscale

        def to_bgr (output_nc, padding='zero', **kwargs):
            def func(x):
                return Conv2D(output_nc, kernel_size=5, padding=padding, activation='sigmoid')(x)
            return func
        AVATARModel.to_bgr = to_bgr

Model = AVATARModel