diff --git a/models/Model_Quick96/Model.py b/models/Model_Quick96/Model.py
new file mode 100644
index 0000000..42e508c
--- /dev/null
+++ b/models/Model_Quick96/Model.py
@@ -0,0 +1,262 @@
+from functools import partial
+
+import numpy as np
+
+import mathlib
+from facelib import FaceType
+from interact import interact as io
+from models import ModelBase
+from nnlib import nnlib
+from samplelib import *
+
+
+class Quick96Model(ModelBase):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs,
+                         ask_enable_autobackup=False,
+                         ask_write_preview_history=False,
+                         ask_target_iter=False,
+                         ask_batch_size=False,
+                         ask_sort_by_yaw=False,
+                         ask_random_flip=False,
+                         ask_src_scale_mod=False)
+
+    #override
+    def onInitialize(self):
+        exec(nnlib.import_all(), locals(), globals())
+        self.set_vram_batch_requirements({1.5:2, 2:4})  #,3:4,4:8
+
+        resolution = self.resolution = 96
+
+        class CommonModel(object):
+            def downscale(self, dim, kernel_size=5, dilation_rate=1):
+                def func(x):
+                    return SubpixelDownscaler()(ELU()(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x)))
+                return func
+
+            def upscale(self, dim, size=(2,2)):
+                def func(x):
+                    return SubpixelUpscaler(size=size)(ELU()(Conv2D(dim * np.prod(size), kernel_size=3, strides=1, padding='same')(x)))
+                return func
+
+            def ResidualBlock(self, dim):
+                def func(inp):
+                    x = Conv2D(dim, kernel_size=3, padding='same')(inp)
+                    x = LeakyReLU(0.2)(x)
+                    x = Conv2D(dim, kernel_size=3, padding='same')(x)
+                    x = Add()([x, inp])
+                    x = LeakyReLU(0.2)(x)
+                    return x
+                return func
+
+        class QModel(CommonModel):
+            def __init__(self, resolution, ae_dims, e_dims, d_dims):
+                super().__init__()
+                bgr_shape = (resolution, resolution, 3)
+                mask_shape = (resolution, resolution, 1)
+                lowest_dense_res = resolution // 16
+
+                def enc_flow():
+                    def func(inp):
+                        # two parallel downscale branches: plain convs and dilated convs
+                        x = self.downscale(e_dims,   3, 1)(inp)
+                        x = self.downscale(e_dims*2, 3, 1)(x)
+                        x = self.downscale(e_dims*4, 3, 1)(x)
+                        x0 = self.downscale(e_dims*8, 3, 1)(x)
+
+                        x = self.downscale(e_dims,   3, 2)(inp)
+                        x = self.downscale(e_dims*2, 3, 2)(x)
+                        x = self.downscale(e_dims*4, 3, 2)(x)
+                        x1 = self.downscale(e_dims*8, 3, 2)(x)
+
+                        x = Concatenate()([x0, x1])
+
+                        x = DenseMaxout(ae_dims, kernel_initializer='orthogonal')(Flatten()(x))
+                        x = DenseMaxout(lowest_dense_res * lowest_dense_res * ae_dims, kernel_initializer='orthogonal')(x)
+                        x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x)
+
+                        x = self.ResidualBlock(ae_dims)(x)
+                        x = self.upscale(d_dims*8)(x)
+                        x = self.ResidualBlock(d_dims*8)(x)
+                        return x
+                    return func
+
+                def dec_flow():
+                    def func(inp):
+                        # bgr branch
+                        x = self.upscale(d_dims*4)(inp)
+                        x = self.ResidualBlock(d_dims*4)(x)
+                        x = self.upscale(d_dims*2)(x)
+                        x = self.ResidualBlock(d_dims*2)(x)
+                        x = self.upscale(d_dims)(x)
+                        x = self.ResidualBlock(d_dims)(x)
+
+                        # mask branch
+                        y = self.upscale(d_dims)(inp)
+                        y = self.upscale(d_dims//2)(y)
+                        y = self.upscale(d_dims//4)(y)
+
+                        return Conv2D(3, kernel_size=5, padding='same', activation='tanh')(x), \
+                               Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(y)
+                    return func
+
+                self.encoder = modelify(enc_flow())( Input(bgr_shape) )
+
+                sh = K.int_shape( self.encoder.outputs[0] )[1:]
+                self.decoder_src = modelify(dec_flow())( Input(sh) )
+                self.decoder_dst = modelify(dec_flow())( Input(sh) )
+
+                self.src_trainable_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
+                self.dst_trainable_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
+
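+                # src and dst share a single encoder but have separate decoders;
+                # the swap is produced by feeding the dst latent code to decoder_src
+                # (pred_src_dst below).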
+                self.warped_src, self.warped_dst = Input(bgr_shape), Input(bgr_shape)
+                self.target_src, self.target_dst = Input(bgr_shape), Input(bgr_shape)
+                self.target_srcm, self.target_dstm = Input(mask_shape), Input(mask_shape)
+
+                self.src_code = self.encoder(self.warped_src)
+                self.dst_code = self.encoder(self.warped_dst)
+
+                self.pred_src_src, self.pred_src_srcm = self.decoder_src(self.src_code)
+                self.pred_dst_dst, self.pred_dst_dstm = self.decoder_dst(self.dst_code)
+                self.pred_src_dst, self.pred_src_dstm = self.decoder_src(self.dst_code)
+
+            def get_model_filename_list(self, exclude_for_pretrain=False):
+                ar = []
+                if not exclude_for_pretrain:
+                    ar += [ [self.encoder, 'encoder.h5'] ]
+                ar += [ [self.decoder_src, 'decoder_src.h5'],
+                        [self.decoder_dst, 'decoder_dst.h5'] ]
+                return ar
+
+        self.model = QModel(resolution, 128, 64, 64)
+
+        loaded, not_loaded = [], self.model.get_model_filename_list()
+        if not self.is_first_run():
+            loaded, not_loaded = self.load_weights_safe(not_loaded)
+
+        CA_models = [ model for model, _ in not_loaded ]
+
+        # collect Conv2D kernels of models whose weights were not loaded from disk;
+        # they receive 'Convolution Aware' initialization on the first iteration
+        self.CA_conv_weights_list = []
+        for model in CA_models:
+            for layer in model.layers:
+                if type(layer) == keras.layers.Conv2D:
+                    self.CA_conv_weights_list += [layer.weights[0]]  # Conv2D kernel weights
+
+        if self.is_training_mode:
+            self.src_dst_opt = RMSprop(lr=2e-4)
+            self.src_dst_mask_opt = RMSprop(lr=2e-4)
+
+            target_src_masked = self.model.target_src * self.model.target_srcm
+            target_dst_masked = self.model.target_dst * self.model.target_dstm
+
+            pred_src_src_masked = self.model.pred_src_src * self.model.target_srcm
+            pred_dst_dst_masked = self.model.pred_dst_dst * self.model.target_dstm
+
+            # DSSIM + L2 on the masked face, plus L2 on the predicted mask;
+            # images are in [-1,1], hence the +1 shift and max_value=2.0 for dssim
+            src_loss = K.mean( 10*dssim(kernel_size=int(resolution/11.6), max_value=2.0)(target_src_masked+1, pred_src_src_masked+1) )
+            src_loss += K.mean( 10*K.square(target_src_masked - pred_src_src_masked) )
+            src_loss += K.mean( K.square(self.model.target_srcm - self.model.pred_src_srcm) )
+
+            dst_loss = K.mean( 10*dssim(kernel_size=int(resolution/11.6), max_value=2.0)(target_dst_masked+1, pred_dst_dst_masked+1) )
+            dst_loss += K.mean( 10*K.square(target_dst_masked - pred_dst_dst_masked) )
+            dst_loss += K.mean( K.square(self.model.target_dstm - self.model.pred_dst_dstm) )
+
+            self.src_train = K.function([self.model.warped_src, self.model.target_src, self.model.target_srcm], [src_loss], self.src_dst_opt.get_updates(src_loss, self.model.src_trainable_weights))
+            self.dst_train = K.function([self.model.warped_dst, self.model.target_dst, self.model.target_dstm], [dst_loss], self.src_dst_opt.get_updates(dst_loss, self.model.dst_trainable_weights))
+            self.AE_view = K.function([self.model.warped_src, self.model.warped_dst], [self.model.pred_src_src, self.model.pred_dst_dst, self.model.pred_dst_dstm, self.model.pred_src_dst, self.model.pred_src_dstm])
+        else:
+            self.AE_convert = K.function([self.model.warped_dst], [self.model.pred_src_dst, self.model.pred_dst_dstm, self.model.pred_src_dstm])
+
+        if self.is_training_mode:
+            t = SampleProcessor.Types
+
+            self.set_training_data_generators([
+                SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
+                        sample_process_options=SampleProcessor.Options(random_flip=False, scale_range=np.array([-0.05, 0.05]) + self.src_scale_mod / 100.0),
+                        output_sample_types=[ {'types': (t.IMG_WARPED_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_BGR), 'resolution': resolution, 'normalize_tanh': True},
+                                              {'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_BGR), 'resolution': resolution, 'normalize_tanh': True},
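+                                              # third output is the face mask (MODE_M); masks stay in [0,1], so no normalize_tanh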
+                                              {'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_M), 'resolution': resolution} ]
+                        ),
+
+                SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
+                        sample_process_options=SampleProcessor.Options(random_flip=False),
+                        output_sample_types=[ {'types': (t.IMG_WARPED_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_BGR), 'resolution': resolution, 'normalize_tanh': True},
+                                              {'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_BGR), 'resolution': resolution, 'normalize_tanh': True},
+                                              {'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_M), 'resolution': resolution} ])
+                ])
+            self.counter = 0
+
+    #override
+    def get_model_filename_list(self):
+        return self.model.get_model_filename_list()
+
+    #override
+    def onSave(self):
+        self.save_weights_safe( self.get_model_filename_list() )
+
+    #override
+    def on_success_train_one_iter(self):
+        # apply 'Convolution Aware' initialization to the kernels that were not
+        # loaded from disk, once, after the first successful iteration
+        if len(self.CA_conv_weights_list) != 0:
+            exec(nnlib.import_all(), locals(), globals())
+            CAInitializerMP(self.CA_conv_weights_list)
+            self.CA_conv_weights_list = []
+
+    #override
+    def onTrainOneIter(self, generators_samples, generators_list):
+        warped_src, target_src, target_srcm = generators_samples[0]
+        warped_dst, target_dst, target_dstm = generators_samples[1]
+
+        # every third iteration trains on warped samples, otherwise on the unwarped targets
+        self.counter += 1
+        if self.counter % 3 == 0:
+            src_loss, = self.src_train([warped_src, target_src, target_srcm])
+            dst_loss, = self.dst_train([warped_dst, target_dst, target_dstm])
+        else:
+            src_loss, = self.src_train([target_src, target_src, target_srcm])
+            dst_loss, = self.dst_train([target_dst, target_dst, target_dstm])
+
+        return ( ('src_loss', src_loss), ('dst_loss', dst_loss), )
+
+    #override
+    def onGetPreview(self, sample):
+        test_S   = sample[0][1][0:4]  # first 4 samples
+        test_S_m = sample[0][2][0:4]
+        test_D   = sample[1][1][0:4]
+        test_D_m = sample[1][2][0:4]
+
+        S, D, SS, DD, DDM, SD, SDM = [test_S, test_D] + self.AE_view([test_S, test_D])
+        S, D, SS, DD, SD = [ np.clip(x/2 + 0.5, 0.0, 1.0) for x in [S, D, SS, DD, SD] ]
+        DDM, SDM = [ np.clip( np.repeat(x, (3,), -1), 0, 1) for x in [DDM, SDM] ]
+
+        result = []
+        st = []
+        for i in range(len(test_S)):
+            ar = S[i], SS[i], D[i], DD[i], SD[i]
+            st.append( np.concatenate(ar, axis=1) )
+
+        result += [ ('Quick96', np.concatenate(st, axis=0)), ]
+
+        st_m = []
+        for i in range(len(test_S)):
+            ar = S[i]*test_S_m[i], SS[i], D[i]*test_D_m[i], DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i])
+            st_m.append( np.concatenate(ar, axis=1) )
+
+        result += [ ('Quick96 masked', np.concatenate(st_m, axis=0)), ]
+
+        return result
+
+    def predictor_func(self, face=None, dummy_predict=False):
+        if dummy_predict:
+            self.AE_convert([ np.zeros((1, self.resolution, self.resolution, 3), dtype=np.float32) ])
+        else:
+            face = face * 2 - 1  # [0,1] -> [-1,1], matching the normalize_tanh training inputs
+            bgr, mask_dst_dstm, mask_src_dstm = self.AE_convert([face[np.newaxis,...]])
+            bgr = bgr / 2 + 0.5  # back to [0,1]
+            mask = mask_dst_dstm[0] * mask_src_dstm[0]
+            return bgr[0], mask[...,0]
+
+    #override
+    def get_ConverterConfig(self):
+        import converters
+        return self.predictor_func, (self.resolution, self.resolution, 3), converters.ConverterConfigMasked(face_type=FaceType.FULL,
+                                     default_mode=1, clip_hborder_mask_per=0.0625)
+
+Model = Quick96Model
diff --git a/models/Model_Quick96/__init__.py b/models/Model_Quick96/__init__.py
new file mode 100644
index 0000000..0188f11
--- /dev/null
+++ b/models/Model_Quick96/__init__.py
@@ -0,0 +1 @@
+from .Model import Model