diff --git a/README.md b/README.md index 184ea43..6b218bf 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,9 @@ LIAEF128 Cage video: [![Watch the video](https://img.youtube.com/vi/mRsexePEVco/0.jpg)](https://www.youtube.com/watch?v=mRsexePEVco) -- **SAE (2GB+)** - Styled AutoEncoder - new superior model based on style loss. Morphing/stylizing done directly by neural network. Face obstructions also reconstructed without any masks. Converter mode 'overlay' should be used. Model has several options on start for fine tuning to fit your GPU. +- **SAE ( minimum 2GB+, recommended 11GB+ )** - Styled AutoEncoder - new superior model based on style loss. Morphing/stylizing done directly by neural network. Face obstructions also reconstructed without any masks. Converter mode 'overlay' should be used. Model has several options on start for fine tuning to fit your GPU. + +SAE actually contains all other models. Just set style powers to 0.0 to get default models. ![](https://github.com/iperov/DeepFaceLab/blob/master/doc/SAE_Cage_0.jpg) @@ -96,7 +98,6 @@ SAE model Putin-Navalny video: https://www.youtube.com/watch?v=Jj7b3mqx-Mw ![](https://github.com/iperov/DeepFaceLab/blob/master/doc/DeepFaceLab_convertor_overview.png) - ### **Tips and tricks**: unfortunately deepfaking is time/eletricpower consuming topic and has a lot of nuances. @@ -119,13 +120,17 @@ Narrow src face is better fakeable than wide. This is why Cage is so popular in SAE tips: +- SAE actually contains all other models. Just set style power options to 0.0 to get default models. + - if src faceset has number of faces more than dst faceset, model can be not converged. In this case try 'Feed faces to network sorted by yaw' option. - if src face wider than dst, model can be not converged. In this case try to decrease 'Src face scale modifier' to -5. - architecture 'df' make predicted face looking more like src, but if model not converges try default 'liae'. -- most scenes converge fine with batch size = 8. In this case better to increase 'encoder/decoder dims per channel' to get more sharp result. +- if you have a lot of VRAM, you can choose between batch size that affects quality of generalization and enc/dec dims that affects image quality. + +- if style speed is too fast, you will get the artifacts before the face becomes sharp. ### **Sort tool**: diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py index d79818e..f902c23 100644 --- a/mainscripts/Trainer.py +++ b/mainscripts/Trainer.py @@ -163,10 +163,10 @@ def previewThread (input_queue, output_queue): if update_preview: update_preview = False - (h,w,c) = previews[0][1].shape - + selected_preview_name = previews[selected_preview][0] selected_preview_rgb = previews[selected_preview][1] + (h,w,c) = selected_preview_rgb.shape # HEAD head_text_color = [0.8]*c diff --git a/models/Model_DF/Model.py b/models/Model_DF/Model.py index 71f2d50..f918fb4 100644 --- a/models/Model_DF/Model.py +++ b/models/Model_DF/Model.py @@ -14,7 +14,7 @@ class Model(ModelBase): #override def onInitialize(self, **in_options): exec(nnlib.import_all(), locals(), globals()) - self.set_vram_batch_requirements( {4.5:16,5:16,6:16,7:16,8:24,9:24,10:32,11:32,12:32,13:48} ) + self.set_vram_batch_requirements( {4.5:4,5:6,6:8,7:16,8:24,9:24,10:32,11:32,12:32,13:48} ) ae_input_layer = Input(shape=(128, 128, 3)) mask_layer = Input(shape=(128, 128, 1)) #same as output diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index c26e539..6e52e3e 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -28,21 +28,21 @@ class SAEModel(ModelBase): if is_first_run: self.options['resolution'] = input_int("Resolution (64,128 ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.") - self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:liae) : ", default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower() + self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower() self.options['lighter_encoder'] = input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, sacrificing overall quality.") else: self.options['resolution'] = self.options.get('resolution', default_resolution) self.options['archi'] = self.options.get('archi', default_archi) self.options['lighter_encoder'] = self.options.get('lighter_encoder', False) - default_face_style_power = 10.0 + default_face_style_power = 2.0 if is_first_run or ask_override: default_face_style_power = default_face_style_power if is_first_run else self.options.get('face_style_power', default_face_style_power) self.options['face_style_power'] = np.clip ( input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % (default_face_style_power), default_face_style_power, help_message="How fast NN will learn dst face style during generalization of src and dst faces."), 0.0, 100.0 ) else: self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power) - default_bg_style_power = 10.0 + default_bg_style_power = 2.0 if is_first_run or ask_override: default_bg_style_power = default_bg_style_power if is_first_run else self.options.get('bg_style_power', default_bg_style_power) self.options['bg_style_power'] = np.clip ( input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % (default_bg_style_power), default_bg_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces."), 0.0, 100.0 ) @@ -102,8 +102,7 @@ class SAEModel(ModelBase): self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (inter_output_Inputs) self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (inter_output_Inputs) - - + if not self.is_first_run(): self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5)) self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5)) @@ -132,16 +131,14 @@ class SAEModel(ModelBase): pred_src_dstm = self.decoderm(warped_src_dst_inter_code) else: self.encoder = modelify(SAEModel.DFEncFlow(resolution, adapt_k_size, self.options['lighter_encoder'], ae_dims=ae_dims, ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape)) - + dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ] self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs) self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs) - self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs) self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs) - - + if not self.is_first_run(): self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5)) self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5)) @@ -150,114 +147,106 @@ class SAEModel(ModelBase): self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5)) warped_src_code = self.encoder (warped_src) - pred_src_src = self.decoder_src(warped_src_code) + + pred_src_src = self.decoder_src(warped_src_code) pred_src_srcm = self.decoder_srcm(warped_src_code) - + warped_dst_code = self.encoder (warped_dst) - pred_dst_dst = self.decoder_dst(warped_dst_code) + + pred_dst_dst = self.decoder_dst(warped_dst_code) pred_dst_dstm = self.decoder_dstm(warped_dst_code) pred_src_dst = self.decoder_src(warped_dst_code) pred_src_dstm = self.decoder_srcm(warped_dst_code) - - - target_srcm_blurred = tf_gaussian_blur(resolution // 32)(target_srcm) - target_srcm_sigm = target_srcm_blurred / 2.0 + 0.5 - target_srcm_anti_sigm = 1.0 - target_srcm_sigm - target_dstm_blurred = tf_gaussian_blur(resolution // 32)(target_dstm) - target_dstm_sigm = target_dstm_blurred / 2.0 + 0.5 - target_dstm_anti_sigm = 1.0 - target_dstm_sigm + ms_count = len(pred_src_src) - target_src_sigm = target_src+1 - target_dst_sigm = target_dst+1 - - pred_src_src_sigm = pred_src_src+1 - pred_dst_dst_sigm = pred_dst_dst+1 - pred_src_dst_sigm = pred_src_dst+1 + target_src_ar = [ target_src if i == 0 else tf.image.resize_bicubic( target_src, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)] + target_srcm_ar = [ target_srcm if i == 0 else tf.image.resize_bicubic( target_srcm, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)] + target_dst_ar = [ target_dst if i == 0 else tf.image.resize_bicubic( target_dst, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)] + target_dstm_ar = [ target_dstm if i == 0 else tf.image.resize_bicubic( target_dstm, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)] - target_src_masked = target_src_sigm*target_srcm_sigm + target_srcm_blurred_ar = [ tf_gaussian_blur( max(1, x.get_shape().as_list()[1] // 32) )(x) for x in target_srcm_ar] + target_srcm_sigm_ar = [ x / 2.0 + 0.5 for x in target_srcm_blurred_ar] + target_srcm_anti_sigm_ar = [ 1.0 - x for x in target_srcm_sigm_ar] + + target_dstm_blurred_ar = [ tf_gaussian_blur( max(1, x.get_shape().as_list()[1] // 32) )(x) for x in target_dstm_ar] + target_dstm_sigm_ar = [ x / 2.0 + 0.5 for x in target_dstm_blurred_ar] + target_dstm_anti_sigm_ar = [ 1.0 - x for x in target_dstm_sigm_ar] - target_dst_masked = target_dst_sigm * target_dstm_sigm - target_dst_anti_masked = target_dst_sigm * target_dstm_anti_sigm - - pred_src_src_masked = pred_src_src_sigm * target_srcm_sigm - pred_dst_dst_masked = pred_dst_dst_sigm * target_dstm_sigm - - psd_target_dst_masked = pred_src_dst_sigm * target_dstm_sigm - psd_target_dst_anti_masked = pred_src_dst_sigm * target_dstm_anti_sigm + target_src_sigm_ar = [ x + 1 for x in target_src_ar] + target_dst_sigm_ar = [ x + 1 for x in target_dst_ar] + + pred_src_src_sigm_ar = [ x + 1 for x in pred_src_src] + pred_dst_dst_sigm_ar = [ x + 1 for x in pred_dst_dst] + pred_src_dst_sigm_ar = [ x + 1 for x in pred_src_dst] + + target_src_masked_ar = [ target_src_sigm_ar[i]*target_srcm_sigm_ar[i] for i in range(len(target_src_sigm_ar))] + target_dst_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(target_dst_sigm_ar))] + target_dst_anti_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(target_dst_sigm_ar))] - src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) ) - - def optimizer(): - return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999) - - if self.options['face_style_power'] != 0: - face_style_power = self.options['face_style_power'] / 100.0 - src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)(psd_target_dst_masked, target_dst_masked) - - if self.options['bg_style_power'] != 0: - bg_style_power = self.options['bg_style_power'] / 100.0 - src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked ))) - - if self.options['archi'] == 'liae': - src_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights - else: - src_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights - self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss], - optimizer().get_updates(src_loss, src_train_weights) ) - - dst_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked, pred_dst_dst_masked )) ) - - if self.options['archi'] == 'liae': - dst_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights - else: - dst_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights - self.dst_train = K.function ([warped_dst, target_dst, target_dstm],[dst_loss], - optimizer().get_updates(dst_loss, dst_train_weights) ) - - src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm)) - - if self.options['archi'] == 'liae': - src_mask_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights - else: - src_mask_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights - - self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss], - optimizer().get_updates(src_mask_loss, src_mask_train_weights ) ) - - dst_mask_loss = K.mean(K.square(target_dstm-pred_dst_dstm)) - - if self.options['archi'] == 'liae': - dst_mask_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights - else: - dst_mask_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights - - self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss], - optimizer().get_updates(dst_mask_loss, dst_mask_train_weights) ) - - self.AE_view = K.function ([warped_src, warped_dst],[pred_src_src, pred_src_srcm, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm]) - self.AE_convert = K.function ([warped_dst],[pred_src_dst, pred_src_dstm]) + psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))] + psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))] if self.is_training_mode: - f = SampleProcessor.TypeFlags + def optimizer(): + return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999) - face_type = f.FACE_ALIGN_FULL if self.options['face_type'] == 'f' else f.FACE_ALIGN_HALF + + if self.options['archi'] == 'liae': + src_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights + src_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights + dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights + dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights + else: + src_loss_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights + src_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + dst_loss_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights + dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights + + src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ]) + + if self.options['face_style_power'] != 0: + face_style_power = self.options['face_style_power'] / 100.0 + src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)( psd_target_dst_masked_ar[-1], target_dst_masked_ar[-1] ) + + if self.options['bg_style_power'] != 0: + bg_style_power = self.options['bg_style_power'] / 100.0 + src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] ))) + self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss], optimizer().get_updates(src_loss, src_loss_train_weights) ) + + src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[i]-pred_src_srcm[i])) for i in range(len(target_srcm_ar)) ]) + self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss], optimizer().get_updates(src_mask_loss, src_mask_loss_train_weights) ) + + dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ]) + self.dst_train = K.function ([warped_dst, target_dst, target_dstm ],[dst_loss], optimizer().get_updates(dst_loss, dst_loss_train_weights) ) + + dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[i]-pred_dst_dstm[i])) for i in range(len(target_dstm_ar)) ]) + self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss], optimizer().get_updates(dst_mask_loss, dst_mask_loss_train_weights) ) + + self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_src_srcm[-1], pred_dst_dst[-1], pred_dst_dstm[-1], pred_src_dst[-1], pred_src_dstm[-1]] ) + else: + self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1], pred_src_dstm[-1] ]) + + if self.is_training_mode: + f = SampleProcessor.TypeFlags + face_type = f.FACE_ALIGN_FULL if self.options['face_type'] == 'f' else f.FACE_ALIGN_HALF self.set_training_data_generators ([ SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None, debug=self.is_debug(), batch_size=self.batch_size, sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ), - output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution], + output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution], [f.TRANSFORMED | face_type | f.MODE_BGR, resolution], - [f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution] + [f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution] ] ), SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True), - output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution], + output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution], [f.TRANSFORMED | face_type | f.MODE_BGR, resolution], - [f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution] ] ) + [f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution] + ] ) ]) #override def onSave(self): @@ -279,7 +268,7 @@ class SAEModel(ModelBase): #override def onTrainOneEpoch(self, sample): warped_src, target_src, target_src_mask = sample[0] - warped_dst, target_dst, target_dst_mask = sample[1] + warped_dst, target_dst, target_dst_mask = sample[1] src_loss, = self.src_train ([warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask]) dst_loss, = self.dst_train ([warped_dst, target_dst, target_dst_mask]) @@ -296,24 +285,19 @@ class SAEModel(ModelBase): test_A_m = sample[0][2][0:4] #first 4 samples test_B = sample[1][1][0:4] test_B_m = sample[1][2][0:4] - - S = test_A - D = test_B - - SS, SM, DD, DM, SD, SDM = self.AE_view ([test_A, test_B]) - S, D, SS, SM, DD, DM, SD, SDM = [ x / 2 + 0.5 for x in [S, D, SS, SM, DD, DM, SD, SDM] ] - SM, DM, SDM = [ np.repeat (x, (3,), -1) for x in [SM, DM, SDM] ] - - st = [] + S, D, SS, SM, DD, DM, SD, SDM, = [ x / 2 + 0.5 for x in ([test_A,test_B] + self.AE_view ([test_A, test_B]) ) ] + #SM, DM, SDM = [ np.repeat (x, (3,), -1) for x in [SM, DM, SDM] ] + + st_x3 = [] for i in range(0, len(test_A)): - st.append ( np.concatenate ( ( + st_x3.append ( np.concatenate ( ( S[i], SS[i], #SM[i], D[i], DD[i], #DM[i], SD[i], #SDM[i] ), axis=1) ) - - return [ ('SAE', np.concatenate ( st, axis=0 ) ) ] + + return [ ('SAE', np.concatenate (st_x3, axis=0 )), ] def predictor_func (self, face): face = face * 2.0 - 1.0 @@ -371,7 +355,7 @@ class SAEModel(ModelBase): x = downscale(ed_dims*8)(x) else: x = downscale_sep(ed_dims*2)(x) - x = downscale_sep(ed_dims*4)(x) + x = downscale(ed_dims*4)(x) x = downscale_sep(ed_dims*8)(x) x = Flatten()(x) @@ -400,26 +384,31 @@ class SAEModel(ModelBase): @staticmethod def LIAEDecFlow(output_nc,ed_ch_dims=21,activation='tanh'): exec (nnlib.import_all(), locals(), globals()) + ed_dims = output_nc * ed_ch_dims def upscale (dim): def func(x): return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x))) - return func - - def func(input): - ed_dims = output_nc * ed_ch_dims + return func + def to_bgr (): + def func(x): + return Conv2D(output_nc, kernel_size=5, padding='same', activation='tanh')(x) + return func + def func(input): x = input[0] - x = upscale(ed_dims*8)(x) - x = upscale(ed_dims*4)(x) - x = upscale(ed_dims*2)(x) - - x = Conv2D(output_nc, kernel_size=5, padding='same', activation=activation)(x) - return x + x1 = upscale(ed_dims*8)( x ) + x1_bgr = to_bgr() ( x1 ) - return func - + x2 = upscale(ed_dims*4)( x1 ) + x2_bgr = to_bgr() ( x2 ) + + x3 = upscale(ed_dims*2)( x2 ) + x3_bgr = to_bgr() ( x3 ) + return [ x1_bgr, x2_bgr, x3_bgr ] + return func + @staticmethod def DFEncFlow(resolution, adapt_k_size, light_enc, ae_dims=512, ed_ch_dims=42): exec (nnlib.import_all(), locals(), globals()) @@ -466,23 +455,31 @@ class SAEModel(ModelBase): return func @staticmethod - def DFDecFlow(output_nc, ed_ch_dims=21, activation='tanh'): + def DFDecFlow(output_nc, ed_ch_dims=21): exec (nnlib.import_all(), locals(), globals()) ed_dims = output_nc * ed_ch_dims def upscale (dim): def func(x): return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x))) - return func - def func(input): - x = input[0] - x = upscale(ed_dims*8)(x) - x = upscale(ed_dims*4)(x) - x = upscale(ed_dims*2)(x) - - x = Conv2D(output_nc, kernel_size=5, padding='same', activation=activation)(x) - return x + return func + def to_bgr (): + def func(x): + return Conv2D(output_nc, kernel_size=5, padding='same', activation='tanh')(x) + return func + def func(input): + x = input[0] + x1 = upscale(ed_dims*8)( x ) + x1_bgr = to_bgr() ( x1 ) + + x2 = upscale(ed_dims*4)( x1 ) + x2_bgr = to_bgr() ( x2 ) + + x3 = upscale(ed_dims*2)( x2 ) + x3_bgr = to_bgr() ( x3 ) + + return [ x1_bgr, x2_bgr, x3_bgr ] return func Model = SAEModel \ No newline at end of file diff --git a/nnlib/nnlib.py b/nnlib/nnlib.py index 2d465e5..f5b2aaf 100644 --- a/nnlib/nnlib.py +++ b/nnlib/nnlib.py @@ -75,6 +75,7 @@ Conv2D = keras.layers.Conv2D Conv2DTranspose = keras.layers.Conv2DTranspose SeparableConv2D = keras.layers.SeparableConv2D MaxPooling2D = keras.layers.MaxPooling2D +UpSampling2D = keras.layers.UpSampling2D BatchNormalization = keras.layers.BatchNormalization LeakyReLU = keras.layers.LeakyReLU