from functools import partial

import cv2
import numpy as np

from facelib import FaceType, PoseEstimator
from interact import interact as io
from mathlib import get_power_of_two
from models import ModelBase
from nnlib import nnlib
from samplelib import *


class AVATARModel(ModelBase):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs,
                         ask_sort_by_yaw=False,
                         ask_random_flip=False,
                         ask_src_scale_mod=False)

    #override
    def onInitializeOptions(self, is_first_run, ask_override):
        if is_first_run:
            avatar_type = io.input_int("Avatar type ( 0:source, 1:head, 2:full_face ?:help skip:1) : ", 1, [0,1,2],
                                       help_message="Training target for the model. 'source' trains on the untouched source images; 'full_face' and 'head' are nose-centered, unaligned faces.")
            avatar_type = {0:'source',
                           1:'head',
                           2:'full_face'}[avatar_type]
            self.options['avatar_type'] = avatar_type
        else:
            self.options['avatar_type'] = self.options.get('avatar_type', 'head')

        if is_first_run or ask_override:
            def_stage = self.options.get('stage', 1)
            self.options['stage'] = io.input_int("Stage (0, 1, 2 ?:help skip:%d) : " % def_stage, def_stage, [0,1,2],
                                                 help_message="Train the first stage, then the second. Tune the batch size to the maximum possible for both stages.")
        else:
            self.options['stage'] = self.options.get('stage', 1)

    #override
    def onInitialize(self, batch_size=-1, **in_options):
        exec(nnlib.code_import_all, locals(), globals())
        self.set_vram_batch_requirements({6:4})

        AVATARModel.initialize_nn_functions()

        resolution = self.resolution = 224
        avatar_type = self.options['avatar_type']
        stage = self.stage = self.options['stage']
        df_res = self.df_res = 128
        df_bgr_shape = (df_res, df_res, 3)
        df_mask_shape = (df_res, df_res, 1)
        res_bgr_shape = (resolution, resolution, 3)
        res_bgr_t_shape = (resolution, resolution, 9)

        self.enc = modelify(AVATARModel.EncFlow())( [Input(df_bgr_shape),] )

        self.decA64 = modelify(AVATARModel.DecFlow()) ( [ Input(K.int_shape(self.enc.outputs[0])[1:]) ] )
        self.decB64 = modelify(AVATARModel.DecFlow()) ( [ Input(K.int_shape(self.enc.outputs[0])[1:]) ] )
        self.D = modelify(AVATARModel.Discriminator() ) (Input(df_bgr_shape))
        self.C = modelify(AVATARModel.ResNet (9, use_batch_norm=False, n_blocks=6, ngf=128, use_dropout=False))( Input(res_bgr_t_shape))
        #self.CD = modelify(AVATARModel.CDiscriminator() ) (Input(res_bgr_t_shape))

        if self.is_first_run():
            conv_weights_list = []
            for model, _ in self.get_model_filename_list():
                for layer in model.layers:
                    if type(layer) == keras.layers.Conv2D:
                        conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
            CAInitializerMP ( conv_weights_list )

        if not self.is_first_run():
            self.load_weights_safe( self.get_model_filename_list() )

        def DLoss(labels, logits):
            return K.mean(K.binary_crossentropy(labels, logits))

        warped_A64 = Input(df_bgr_shape)
        real_A64 = Input(df_bgr_shape)
        real_A64m = Input(df_mask_shape)

        real_B64_t0 = Input(df_bgr_shape)
        real_B64_t1 = Input(df_bgr_shape)
        real_B64_t2 = Input(df_bgr_shape)

        real_A64_t0 = Input(df_bgr_shape)
        real_A64m_t0 = Input(df_mask_shape)
        real_A_t0 = Input(res_bgr_shape)
        real_A64_t1 = Input(df_bgr_shape)
        real_A64m_t1 = Input(df_mask_shape)
        real_A_t1 = Input(res_bgr_shape)
        real_A64_t2 = Input(df_bgr_shape)
        real_A64m_t2 = Input(df_mask_shape)
        real_A_t2 = Input(res_bgr_shape)

        warped_B64 = Input(df_bgr_shape)
        real_B64 = Input(df_bgr_shape)
        real_B64m = Input(df_mask_shape)

        warped_A_code = self.enc (warped_A64)
        warped_B_code = self.enc (warped_B64)

        rec_A64 = self.decA64(warped_A_code)
        rec_B64 = self.decB64(warped_B_code)
        rec_AB64 = self.decA64(warped_B_code)
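        # Graph-building helpers below (descriptions inferred from the code; the
        # original carries no comments):
        #   Lambda_grey_mask : keeps the masked face area of x, fills the rest with 0.5 grey
        #   Lambda_gray_pad  : pads a df_res image onto the resolution canvas with a 0.5 grey border
        #   Lambda_concat    : channel-wise concatenation with an explicit output_shape
        #   Lambda_Cto3t     : splits the 9-channel output of C into three BGR frames (t0, t1, t2)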
        def Lambda_grey_mask (x, m):
            return Lambda (lambda x: x[0]*m + (1-m)*0.5, output_shape=K.int_shape(x)[1:3] + (3,)) ([x, m])

        def Lambda_gray_pad(x):
            a = np.ones((resolution, resolution, 3))*0.5
            pad = ( resolution - df_res ) // 2
            a[pad:-pad, pad:-pad, :] = 0

            return Lambda ( lambda x: K.spatial_2d_padding(x, padding=((pad, pad), (pad, pad)) ) + K.constant(a, dtype=K.floatx() ),
                            output_shape=(resolution, resolution, 3) ) (x)

        def Lambda_concat ( x ):
            c = sum ( [ K.int_shape(l)[-1] for l in x ] )
            return Lambda ( lambda x: K.concatenate (x, axis=-1), output_shape=K.int_shape(x[0])[1:3] + (c,) ) (x)

        def Lambda_Cto3t(x):
            return Lambda ( lambda x: x[...,0:3], output_shape=K.int_shape(x)[1:3] + (3,) ) (x), \
                   Lambda ( lambda x: x[...,3:6], output_shape=K.int_shape(x)[1:3] + (3,) ) (x), \
                   Lambda ( lambda x: x[...,6:9], output_shape=K.int_shape(x)[1:3] + (3,) ) (x)

        real_A64_d = self.D( Lambda_grey_mask(real_A64, real_A64m) )
        real_A64_d_ones = K.ones_like(real_A64_d)

        fake_A64_d = self.D(rec_AB64)
        fake_A64_d_ones = K.ones_like(fake_A64_d)
        fake_A64_d_zeros = K.zeros_like(fake_A64_d)

        rec_AB_t0 = Lambda_gray_pad( self.decA64 (self.enc (real_B64_t0)) )
        rec_AB_t1 = Lambda_gray_pad( self.decA64 (self.enc (real_B64_t1)) )
        rec_AB_t2 = Lambda_gray_pad( self.decA64 (self.enc (real_B64_t2)) )

        C_in_A_t0 = Lambda_gray_pad( Lambda_grey_mask (real_A64_t0, real_A64m_t0) )
        C_in_A_t1 = Lambda_gray_pad( Lambda_grey_mask (real_A64_t1, real_A64m_t1) )
        C_in_A_t2 = Lambda_gray_pad( Lambda_grey_mask (real_A64_t2, real_A64m_t2) )

        rec_C_A_t0, rec_C_A_t1, rec_C_A_t2 = Lambda_Cto3t ( self.C ( Lambda_concat ( [C_in_A_t0, C_in_A_t1, C_in_A_t2]) ) )
        rec_C_AB_t0, rec_C_AB_t1, rec_C_AB_t2 = Lambda_Cto3t( self.C ( Lambda_concat ( [rec_AB_t0, rec_AB_t1, rec_AB_t2]) ) )

        #real_A_t012_d = self.CD ( K.concatenate ( [real_A_t0, real_A_t1, real_A_t2], axis=-1) )
        #real_A_t012_d_ones = K.ones_like(real_A_t012_d)

        #rec_C_AB_t012_d = self.CD ( K.concatenate ( [rec_C_AB_t0, rec_C_AB_t1, rec_C_AB_t2], axis=-1) )
        #rec_C_AB_t012_d_ones = K.ones_like(rec_C_AB_t012_d)
        #rec_C_AB_t012_d_zeros = K.zeros_like(rec_C_AB_t012_d)

        self.G64_view = K.function([warped_A64, warped_B64],[rec_A64, rec_B64, rec_AB64])
        self.G_view = K.function([real_A64_t0, real_A64m_t0, real_A64_t1, real_A64m_t1, real_A64_t2, real_A64m_t2,
                                  real_B64_t0, real_B64_t1, real_B64_t2],
                                 [rec_C_A_t0, rec_C_A_t1, rec_C_A_t2, rec_C_AB_t0, rec_C_AB_t1, rec_C_AB_t2])

        if self.is_training_mode:
            loss_AB64 = K.mean(10 * dssim(kernel_size=int(df_res/11.6), max_value=1.0)( rec_A64, real_A64*real_A64m + (1-real_A64m)*0.5) ) + \
                        K.mean(10 * dssim(kernel_size=int(df_res/11.6), max_value=1.0)( rec_B64, real_B64*real_B64m + (1-real_B64m)*0.5) ) + \
                        0.1 * DLoss(fake_A64_d_ones, fake_A64_d )

            weights_AB64 = self.enc.trainable_weights + self.decA64.trainable_weights + self.decB64.trainable_weights

            loss_C = K.mean( 10 * dssim(kernel_size=int(resolution/11.6), max_value=1.0)( real_A_t0, rec_C_A_t0 ) ) + \
                     K.mean( 10 * dssim(kernel_size=int(resolution/11.6), max_value=1.0)( real_A_t1, rec_C_A_t1 ) ) + \
                     K.mean( 10 * dssim(kernel_size=int(resolution/11.6), max_value=1.0)( real_A_t2, rec_C_A_t2 ) ) #0.1*DLoss(rec_C_AB_t012_d_ones, rec_C_AB_t012_d )

            weights_C = self.C.trainable_weights

            loss_D = ( DLoss(real_A64_d_ones, real_A64_d ) + \
                       DLoss(fake_A64_d_zeros, fake_A64_d ) ) * 0.5

            #loss_CD = ( DLoss(real_A_t012_d_ones, real_A_t012_d) + \
            #            DLoss(rec_C_AB_t012_d_zeros, rec_C_AB_t012_d) ) * 0.5
            #
            #weights_CD = self.CD.trainable_weights
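            # Optimizer note (not from the original comments): beta_1=0.5 is the usual
            # GAN-style Adam setting, and tf_cpu_mode is an nnlib-specific Adam option
            # that appears to host optimizer state on the CPU under the TensorFlow
            # backend, trading some speed for VRAM.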
            def opt(lr=5e-5):
                return Adam(lr=lr, beta_1=0.5, beta_2=0.999, tf_cpu_mode=2 if 'tensorflow' in self.device_config.backend else 0 )

            self.AB64_train = K.function ([warped_A64, real_A64, real_A64m, warped_B64, real_B64, real_B64m],
                                          [loss_AB64], opt().get_updates(loss_AB64, weights_AB64) )

            self.C_train = K.function ([real_A64_t0, real_A64m_t0, real_A_t0,
                                        real_A64_t1, real_A64m_t1, real_A_t1,
                                        real_A64_t2, real_A64m_t2, real_A_t2,
                                        real_B64_t0, real_B64_t1, real_B64_t2],
                                       [ loss_C ], opt().get_updates(loss_C, weights_C) )

            self.D_train = K.function ([warped_A64, real_A64, real_A64m, warped_B64, real_B64, real_B64m],
                                       [loss_D], opt().get_updates(loss_D, self.D.trainable_weights) )

            #self.CD_train = K.function ([real_A64_t0, real_A64m_t0, real_A_t0,
            #                             real_A64_t1, real_A64m_t1, real_A_t1,
            #                             real_A64_t2, real_A64m_t2, real_A_t2,
            #                             real_B64_t0, real_B64_t1, real_B64_t2 ],
            #                            [ loss_CD ], opt().get_updates(loss_CD, weights_CD) )

            ###########
            t = SampleProcessor.Types
            training_target = {'source'    : t.NONE,
                               'full_face' : t.FACE_TYPE_FULL_NO_ALIGN,
                               'head'      : t.FACE_TYPE_HEAD_NO_ALIGN}[avatar_type]

            generators = [
                SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
                                    sample_process_options=SampleProcessor.Options(random_flip=False),
                                    output_sample_types=[ {'types': (t.IMG_WARPED_TRANSFORMED, t.FACE_TYPE_FULL_NO_ALIGN, t.MODE_BGR), 'resolution':df_res},
                                                          {'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL_NO_ALIGN, t.MODE_BGR), 'resolution':df_res},
                                                          {'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL_NO_ALIGN, t.MODE_M), 'resolution':df_res}
                                                        ] ),

                SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                                    sample_process_options=SampleProcessor.Options(random_flip=False),
                                    output_sample_types=[ {'types': (t.IMG_WARPED_TRANSFORMED, t.FACE_TYPE_FULL_NO_ALIGN, t.MODE_BGR), 'resolution':df_res},
                                                          {'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL_NO_ALIGN, t.MODE_BGR), 'resolution':df_res},
                                                          {'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL_NO_ALIGN, t.MODE_M), 'resolution':df_res}
                                                        ] ),

                SampleGeneratorFaceTemporal(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
                                            temporal_image_count=3,
                                            sample_process_options=SampleProcessor.Options(random_flip=False),
                                            output_sample_types=[ {'types': (t.IMG_WARPED_TRANSFORMED, t.FACE_TYPE_FULL_NO_ALIGN, t.MODE_BGR), 'resolution':df_res},
                                                                  {'types': (t.IMG_WARPED_TRANSFORMED, t.FACE_TYPE_FULL_NO_ALIGN, t.MODE_M), 'resolution':df_res},
                                                                  {'types': (t.IMG_SOURCE, training_target, t.MODE_BGR), 'resolution':resolution},
                                                                ] ),

                SampleGeneratorFaceTemporal(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                                            temporal_image_count=3,
                                            sample_process_options=SampleProcessor.Options(random_flip=False),
                                            output_sample_types=[ {'types': (t.IMG_SOURCE, t.FACE_TYPE_FULL_NO_ALIGN, t.MODE_BGR), 'resolution':df_res},
                                                                  {'types': (t.IMG_SOURCE, t.NONE, t.MODE_BGR), 'resolution':resolution},
                                                                ] ),
                ]

            if self.stage == 1:
                generators[2].set_active(False)
                generators[3].set_active(False)
            elif self.stage == 2:
                generators[0].set_active(False)
                generators[1].set_active(False)

            self.set_training_data_generators (generators)
        else:
            self.G_convert = K.function([real_B64_t0, real_B64_t1, real_B64_t2],[rec_C_AB_t1])
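    # Stage semantics, inferred from the generator wiring above and onTrainOneIter below:
    #   stage 0 - train everything: the 128px autoencoder (enc/decA64/decB64), D, and C
    #   stage 1 - train only the 128px autoencoder and D (temporal generators disabled)
    #   stage 2 - train only the temporal refiner C (face generators disabled)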
    #override, return [ [model, filename], ... ] list
    def get_model_filename_list(self):
        return [ [self.enc, 'enc.h5'],
                 [self.decA64, 'decA64.h5'],
                 [self.decB64, 'decB64.h5'],
                 [self.C, 'C.h5'],
                 [self.D, 'D.h5'],
                 #[self.CD, 'CD.h5'],
               ]

    #override
    def onSave(self):
        self.save_weights_safe( self.get_model_filename_list() )

    #override
    def onTrainOneIter(self, generators_samples, generators_list):
        warped_src64, src64, src64m = generators_samples[0]
        warped_dst64, dst64, dst64m = generators_samples[1]

        real_A64_t0, real_A64m_t0, real_A_t0, real_A64_t1, real_A64m_t1, real_A_t1, real_A64_t2, real_A64m_t2, real_A_t2 = generators_samples[2]
        real_B64_t0, _, real_B64_t1, _, real_B64_t2, _ = generators_samples[3]

        if self.stage == 0 or self.stage == 1:
            loss,   = self.AB64_train ( [warped_src64, src64, src64m, warped_dst64, dst64, dst64m] )
            loss_D, = self.D_train ( [warped_src64, src64, src64m, warped_dst64, dst64, dst64m] )
            if self.stage != 0:
                loss_C = loss_CD = 0

        if self.stage == 0 or self.stage == 2:
            loss_C1, = self.C_train ( [real_A64_t0, real_A64m_t0, real_A_t0,
                                       real_A64_t1, real_A64m_t1, real_A_t1,
                                       real_A64_t2, real_A64m_t2, real_A_t2,
                                       real_B64_t0, real_B64_t1, real_B64_t2] )

            loss_C2, = self.C_train ( [real_A64_t2, real_A64m_t2, real_A_t2,
                                       real_A64_t1, real_A64m_t1, real_A_t1,
                                       real_A64_t0, real_A64m_t0, real_A_t0,
                                       real_B64_t0, real_B64_t1, real_B64_t2] )

            #loss_CD1, = self.CD_train ( [real_A64_t0, real_A64m_t0, real_A_t0,
            #                             real_A64_t1, real_A64m_t1, real_A_t1,
            #                             real_A64_t2, real_A64m_t2, real_A_t2,
            #                             real_B64_t0, real_B64_t1, real_B64_t2] )
            #
            #loss_CD2, = self.CD_train ( [real_A64_t2, real_A64m_t2, real_A_t2,
            #                             real_A64_t1, real_A64m_t1, real_A_t1,
            #                             real_A64_t0, real_A64m_t0, real_A_t0,
            #                             real_B64_t0, real_B64_t1, real_B64_t2] )

            loss_C = (loss_C1 + loss_C2) / 2
            #loss_CD = (loss_CD1 + loss_CD2) / 2
            if self.stage != 0:
                loss = loss_D = 0

        return ( ('loss', loss), ('D', loss_D), ('C', loss_C), ) #('CD', loss_CD) )

    #override
    def onGetPreview(self, sample):
        test_A064w = sample[0][0][0:4]
        test_A064r = sample[0][1][0:4]
        test_A064m = sample[0][2][0:4]

        test_B064w = sample[1][0][0:4]
        test_B064r = sample[1][1][0:4]
        test_B064m = sample[1][2][0:4]

        t_src64_0  = sample[2][0][0:4]
        t_src64m_0 = sample[2][1][0:4]
        t_src_0    = sample[2][2][0:4]
        t_src64_1  = sample[2][3][0:4]
        t_src64m_1 = sample[2][4][0:4]
        t_src_1    = sample[2][5][0:4]
        t_src64_2  = sample[2][6][0:4]
        t_src64m_2 = sample[2][7][0:4]
        t_src_2    = sample[2][8][0:4]

        t_dst64_0 = sample[3][0][0:4]
        t_dst_0   = sample[3][1][0:4]
        t_dst64_1 = sample[3][2][0:4]
        t_dst_1   = sample[3][3][0:4]
        t_dst64_2 = sample[3][4][0:4]
        t_dst_2   = sample[3][5][0:4]

        G64_view_result = self.G64_view ([test_A064r, test_B064r])

        test_A064r, test_B064r, rec_A64, rec_B64, rec_AB64 = [ x[0] for x in ([test_A064r, test_B064r] + G64_view_result) ]

        sample64x4 = np.concatenate ([ np.concatenate ( [rec_B64, rec_A64], axis=1 ),
                                       np.concatenate ( [test_B064r, rec_AB64], axis=1) ], axis=0 )

        sample64x4 = cv2.resize (sample64x4, (self.resolution, self.resolution) )

        G_view_result = self.G_view([t_src64_0, t_src64m_0, t_src64_1, t_src64m_1, t_src64_2, t_src64m_2,
                                     t_dst64_0, t_dst64_1, t_dst64_2 ])

        t_dst_0, t_dst_1, t_dst_2, rec_C_A_t0, rec_C_A_t1, rec_C_A_t2, rec_C_AB_t0, rec_C_AB_t1, rec_C_AB_t2 = [ x[0] for x in ([t_dst_0, t_dst_1, t_dst_2] + G_view_result) ]

        c1 = np.concatenate ( (sample64x4, rec_C_A_t0, t_dst_0, rec_C_AB_t0 ), axis=1 )
        c2 = np.concatenate ( (sample64x4, rec_C_A_t1, t_dst_1, rec_C_AB_t1 ), axis=1 )
        c3 = np.concatenate ( (sample64x4, rec_C_A_t2, t_dst_2, rec_C_AB_t2 ), axis=1 )

        r = np.concatenate ( [c1, c2, c3], axis=0 )

        return [ ('AVATAR', r ) ]
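    # Conversion runs C on a sliding window of three consecutive frames and keeps only
    # the middle prediction (rec_C_AB_t1). A minimal sketch of the expected call,
    # assuming f0/f1/f2 are float32 BGR frames in [0,1] at df_res (128px):
    #
    #   out = model.predictor_func(prev_imgs=[f0], img=f1, next_imgs=[f2])
    #   # -> (resolution, resolution, 3) image clipped to [0,1]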
    def predictor_func (self, prev_imgs=None, img=None, next_imgs=None, dummy_predict=False):
        if dummy_predict:
            z = np.zeros ( (1, self.df_res, self.df_res, 3), dtype=np.float32 )
            self.G_convert ([z, z, z])
        else:
            feed = [ prev_imgs[-1][np.newaxis,...], img[np.newaxis,...], next_imgs[0][np.newaxis,...] ]
            x = self.G_convert (feed)[0]
            return np.clip ( x[0], 0, 1)

    #override
    def get_ConverterConfig(self):
        import converters
        return self.predictor_func, (self.df_res, self.df_res, 3), converters.ConverterConfigFaceAvatar(temporal_face_count=1)

    @staticmethod
    def NLayerDiscriminator(ndf=64, n_layers=3):
        exec (nnlib.import_all(), locals(), globals())

        #use_bias = True
        #def XNormalization(x):
        #    return InstanceNormalization (axis=-1)(x)
        use_bias = False
        def XNormalization(x):
            return BatchNormalization (axis=-1)(x)

        XConv2D = partial(Conv2D, use_bias=use_bias)

        def func(x):
            f = ndf

            x = XConv2D( f, 4, strides=2, padding='same', use_bias=True)(x)
            f = min( ndf*8, f*2 )
            x = LeakyReLU(0.2)(x)

            for i in range(n_layers):
                x = XConv2D( f, 4, strides=2, padding='same')(x)
                x = XNormalization(x)
                x = LeakyReLU(0.2)(x)
                f = min( ndf*8, f*2 )

            x = XConv2D( f, 4, strides=1, padding='same')(x)
            x = XNormalization(x)
            x = LeakyReLU(0.2)(x)

            return XConv2D( 1, 4, strides=1, padding='same', use_bias=True, activation='sigmoid')(x)
        return func

    """
    @staticmethod
    def Discriminator(ndf=128):
        exec (nnlib.import_all(), locals(), globals())

        #use_bias = True
        #def XNormalization(x):
        #    return InstanceNormalization (axis=-1)(x)
        use_bias = False
        def XNormalization(x):
            return BatchNormalization (axis=-1)(x)

        XConv2D = partial(Conv2D, use_bias=use_bias)

        def func(input):
            b,h,w,c = K.int_shape(input)

            x = input

            x = XConv2D( ndf, 4, strides=2, padding='same', use_bias=True)(x)
            x = LeakyReLU(0.2)(x)

            x = XConv2D( ndf*2, 4, strides=2, padding='same')(x)
            x = XNormalization(x)
            x = LeakyReLU(0.2)(x)

            x = XConv2D( ndf*4, 4, strides=2, padding='same')(x)
            x = XNormalization(x)
            x = LeakyReLU(0.2)(x)

            x = XConv2D( ndf*8, 4, strides=2, padding='same')(x)
            x = XNormalization(x)
            x = LeakyReLU(0.2)(x)

            return XConv2D( 1, 4, strides=1, padding='same', use_bias=True, activation='sigmoid')(x)
        return func
    """

    @staticmethod
    def Discriminator(ndf=128):
        exec (nnlib.import_all(), locals(), globals())

        use_bias = True
        def XNormalization(x):
            return InstanceNormalization (axis=-1)(x)
        #use_bias = False
        #def XNormalization(x):
        #    return BatchNormalization (axis=-1)(x)

        XConv2D = partial(Conv2D, use_bias=use_bias)

        def func(input):
            b,h,w,c = K.int_shape(input)

            x = input

            x = XConv2D( ndf, 4, strides=2, padding='same', use_bias=True)(x)
            x = LeakyReLU(0.2)(x)

            x = XConv2D( ndf*2, 4, strides=2, padding='same')(x)
            x = XNormalization(x)
            x = LeakyReLU(0.2)(x)

            x = XConv2D( ndf*4, 4, strides=2, padding='same')(x)
            x = XNormalization(x)
            x = LeakyReLU(0.2)(x)

            x = XConv2D( ndf*8, 4, strides=2, padding='same')(x)
            x = XNormalization(x)
            x = LeakyReLU(0.2)(x)

            return XConv2D( 1, 4, strides=1, padding='same', use_bias=True, activation='sigmoid')(x)
        return func
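    # CDiscriminator is currently unused: self.CD and its losses are commented out in
    # onInitialize, so this definition is kept only for experimentation.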
    @staticmethod
    def CDiscriminator(ndf=256):
        exec (nnlib.import_all(), locals(), globals())

        use_bias = True
        def XNormalization(x):
            return InstanceNormalization (axis=-1)(x)
        #use_bias = False
        #def XNormalization(x):
        #    return BatchNormalization (axis=-1)(x)

        XConv2D = partial(Conv2D, use_bias=use_bias)

        def func(input):
            b,h,w,c = K.int_shape(input)

            x = input

            x = XConv2D( ndf, 4, strides=2, padding='same', use_bias=True)(x)
            x = LeakyReLU(0.2)(x)

            x = XConv2D( ndf*2, 4, strides=2, padding='same')(x)
            x = XNormalization(x)
            x = LeakyReLU(0.2)(x)

            x = XConv2D( ndf*4, 4, strides=2, padding='same')(x)
            x = XNormalization(x)
            x = LeakyReLU(0.2)(x)

            #x = XConv2D( ndf*8, 4, strides=2, padding='same')(x)
            #x = XNormalization(x)
            #x = LeakyReLU(0.2)(x)

            return XConv2D( 1, 4, strides=1, padding='same', use_bias=True, activation='sigmoid')(x)
        return func

    @staticmethod
    def EncFlow(padding='zero', **kwargs):
        exec (nnlib.import_all(), locals(), globals())

        use_bias = False
        def XNorm(x):
            return BatchNormalization (axis=-1)(x)
        XConv2D = partial(Conv2D, padding=padding, use_bias=use_bias)

        def downscale (dim):
            def func(x):
                return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
            return func

        def upscale (dim):
            def func(x):
                return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
            return func

        def func(input):
            x, = input
            b,h,w,c = K.int_shape(x)

            dim_res = w // 16

            x = downscale(64)(x)
            x = downscale(128)(x)
            x = downscale(256)(x)
            x = downscale(512)(x)

            x = Dense(512)(Flatten()(x))
            x = Dense(dim_res * dim_res * 512)(x)
            x = Reshape((dim_res, dim_res, 512))(x)
            x = upscale(512)(x)
            return x

        return func

    @staticmethod
    def DecFlow(output_nc=3, **kwargs):
        exec (nnlib.import_all(), locals(), globals())

        ResidualBlock = AVATARModel.ResidualBlock
        upscale = AVATARModel.upscale
        to_bgr = AVATARModel.to_bgr

        def func(input):
            x = input[0]

            x = upscale(512)(x)
            x = upscale(256)(x)
            x = upscale(128)(x)
            return to_bgr(output_nc)(x)

        return func

    """
    @staticmethod
    def CNet(output_nc, use_batch_norm, ngf=64, n_blocks=6, use_dropout=False):
        exec (nnlib.import_all(), locals(), globals())

        if not use_batch_norm:
            use_bias = True
            def XNormalization(x):
                return InstanceNormalization (axis=-1)(x)
        else:
            use_bias = False
            def XNormalization(x):
                return BatchNormalization (axis=-1)(x)

        XConv2D = partial(Conv2D, padding='same', use_bias=use_bias)
        XConv2DTranspose = partial(Conv2DTranspose, padding='same', use_bias=use_bias)

        def ResnetBlock(dim, use_dropout=False):
            def func(input):
                x = input

                x = XConv2D(dim, 3, strides=1)(x)
                x = XNormalization(x)
                x = ReLU()(x)

                if use_dropout:
                    x = Dropout(0.5)(x)

                x = XConv2D(dim, 3, strides=1)(x)
                x = XNormalization(x)
                x = ReLU()(x)
                return Add()([x, input])
            return func

        def preprocess(target_res):
            def func(input):
                inp_shape = K.int_shape (input[0])
                t_len = len(input)
                total_ch = 0
                for i in range(t_len):
                    total_ch += K.int_shape (input[i])[-1]

                x = K.concatenate ( input, axis=-1)

                import code
                code.interact(local=dict(globals(), **locals()))

                x_shape = K.int_shape(x)[1:]
                pad = (target_res - x_shape[0]) // 2

                a = np.ones((target_res, target_res, 3))*0.5
                a[pad:-pad, pad:-pad, :] = 0
                return K.spatial_2d_padding(x, padding=((pad, pad), (pad, pad)) ) + K.constant(a, dtype=K.floatx() )
            return func

        def func(input):
            inp_shape = K.int_shape (input[0])
            t_len = len(input)
            total_ch = 0
            for i in range(t_len):
                total_ch += K.int_shape (input[i])[-1]

            x = Lambda ( preprocess(128), output_shape=(inp_shape[1], inp_shape[2], total_ch) ) (input)

            x = ReLU()(XNormalization(XConv2D(ngf, 7, strides=1)(x)))
            x = ReLU()(XNormalization(XConv2D(ngf*2, 3, strides=2)(x)))
            x = ReLU()(XNormalization(XConv2D(ngf*4, 3, strides=2)(x)))

            for i in range(n_blocks):
                x = ResnetBlock(ngf*4, use_dropout=use_dropout)(x)

            x = ReLU()(XNormalization(XConv2DTranspose(ngf*2, 3, strides=2)(x)))
            x = ReLU()(XNormalization(XConv2DTranspose(ngf  , 3, strides=2)(x)))
            x = XConv2D(output_nc, 7, strides=1, activation='sigmoid', use_bias=True)(x)
            return x

        return func
    """
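    # ResNet below is a resnet-style image-to-image generator: a 7x7 stem, three stride-2
    # downsamplings, n_blocks residual blocks, three stride-2 transposed-conv upsamplings,
    # and a sigmoid head. With the arguments used in onInitialize (output_nc=9, ngf=128,
    # n_blocks=6, use_batch_norm=False, hence InstanceNormalization) it maps a 224x224x9
    # frame triplet to a 224x224x9 output.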
    @staticmethod
    def ResNet(output_nc, use_batch_norm, ngf=64, n_blocks=6, use_dropout=False):
        exec (nnlib.import_all(), locals(), globals())

        if not use_batch_norm:
            use_bias = True
            def XNormalization(x):
                return InstanceNormalization (axis=-1)(x)
        else:
            use_bias = False
            def XNormalization(x):
                return BatchNormalization (axis=-1)(x)

        XConv2D = partial(Conv2D, padding='same', use_bias=use_bias)
        XConv2DTranspose = partial(Conv2DTranspose, padding='same', use_bias=use_bias)

        def func(input):
            def ResnetBlock(dim, use_dropout=False):
                def func(input):
                    x = input

                    x = XConv2D(dim, 3, strides=1)(x)
                    x = XNormalization(x)
                    x = ReLU()(x)

                    if use_dropout:
                        x = Dropout(0.5)(x)

                    x = XConv2D(dim, 3, strides=1)(x)
                    x = XNormalization(x)
                    x = ReLU()(x)
                    return Add()([x, input])
                return func

            x = input

            x = ReLU()(XNormalization(XConv2D(ngf, 7, strides=1)(x)))
            x = ReLU()(XNormalization(XConv2D(ngf*2, 3, strides=2)(x)))
            x = ReLU()(XNormalization(XConv2D(ngf*4, 3, strides=2)(x)))
            x = ReLU()(XNormalization(XConv2D(ngf*4, 3, strides=2)(x)))

            for i in range(n_blocks):
                x = ResnetBlock(ngf*4, use_dropout=use_dropout)(x)

            x = ReLU()(XNormalization(XConv2DTranspose(ngf*4, 3, strides=2)(x)))
            x = ReLU()(XNormalization(XConv2DTranspose(ngf*2, 3, strides=2)(x)))
            x = ReLU()(XNormalization(XConv2DTranspose(ngf  , 3, strides=2)(x)))
            x = XConv2D(output_nc, 7, strides=1, activation='sigmoid', use_bias=True)(x)
            return x

        return func

    @staticmethod
    def initialize_nn_functions():
        exec (nnlib.import_all(), locals(), globals())

        class ResidualBlock(object):
            def __init__(self, filters, kernel_size=3, padding='zero', **kwargs):
                self.filters = filters
                self.kernel_size = kernel_size
                self.padding = padding

            def __call__(self, inp):
                x = inp
                x = Conv2D(self.filters, kernel_size=self.kernel_size, padding=self.padding)(x)
                x = LeakyReLU(0.2)(x)
                x = Conv2D(self.filters, kernel_size=self.kernel_size, padding=self.padding)(x)
                x = Add()([x, inp])
                x = LeakyReLU(0.2)(x)
                return x
        AVATARModel.ResidualBlock = ResidualBlock

        def downscale (dim, padding='zero', act='', **kwargs):
            def func(x):
                return LeakyReLU(0.2)(Conv2D(dim, kernel_size=5, strides=2, padding=padding)(x))
            return func
        AVATARModel.downscale = downscale

        def upscale (dim, padding='zero', norm='', act='', **kwargs):
            def func(x):
                return SubpixelUpscaler()(LeakyReLU(0.2)(Conv2D(dim * 4, kernel_size=3, strides=1, padding=padding)(x)))
            return func
        AVATARModel.upscale = upscale

        def to_bgr (output_nc, padding='zero', **kwargs):
            def func(x):
                return Conv2D(output_nc, kernel_size=5, padding=padding, activation='sigmoid')(x)
            return func
        AVATARModel.to_bgr = to_bgr


Model = AVATARModel