diff --git a/README.md b/README.md index 6821339..9f398cb 100644 --- a/README.md +++ b/README.md @@ -92,7 +92,9 @@ DF - good for side faces, but results in a lower resolution and details. Covers LIAE - can partially fix dissimilar face shapes, but results in a less recognizable face. SAE - no matter how similar faces, src face will be morphed onto dst face, which can make face absolutely unrecognizable. Model can collapse on some scenes. Easy to overlay final face because dst background is also predicted. -also quality of src faceset significantly affects the final face. +Quality of src faceset significantly affects the final face. + +Narrow src face is better fakeable than wide. This is why Cage is so popular in deepfakes. ### **Sort tool**: diff --git a/models/ModelBase.py b/models/ModelBase.py index d421d4d..4913b54 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -58,11 +58,12 @@ class ModelBase(object): if self.epoch == 0: print ("\nModel first run. Enter model options as default for each run.") - self.options['write_preview_history'] = input_bool("Write preview history? (y/n skip:n) : ", False) + self.options['write_preview_history'] = input_bool("Write preview history? (y/n ?:help skip:n) : ", False, help_message="Preview history will be writed to _history folder.") self.options['target_epoch'] = max(0, input_int("Target epoch (skip:unlimited) : ", 0)) - self.options['batch_size'] = max(0, input_int("Batch_size (skip:model choice) : ", 0)) - self.options['sort_by_yaw'] = input_bool("Feed faces to network sorted by yaw? (y/n skip:n) : ", False) - self.options['random_flip'] = input_bool("Flip faces randomly? (y/n skip:y) : ", True) + self.options['batch_size'] = max(0, input_int("Batch_size (?:help skip:model choice) : ", 0, help_message="Larger batch size is always better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually.")) + self.options['sort_by_yaw'] = input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:n) : ", False, help_message="NN will not learn src face directions that don't match dst face directions." ) + self.options['random_flip'] = input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.") + self.options['src_scale_mod'] = np.clip( input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30) #self.options['use_fp16'] = use_fp16 = input_bool("Use float16? (y/n skip:n) : ", False) else: self.options['write_preview_history'] = self.options.get('write_preview_history', False) @@ -70,6 +71,7 @@ class ModelBase(object): self.options['batch_size'] = self.options.get('batch_size', 0) self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False) self.options['random_flip'] = self.options.get('random_flip', True) + self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0) #self.options['use_fp16'] = use_fp16 = self.options['use_fp16'] if 'use_fp16' in self.options.keys() else False use_fp16 = False #currently models fails with fp16 @@ -105,6 +107,10 @@ class ModelBase(object): self.random_flip = self.options['random_flip'] if self.random_flip: self.options.pop('random_flip') + + self.src_scale_mod = self.options['src_scale_mod'] + if self.src_scale_mod == 0: + self.options.pop('src_scale_mod') self.write_preview_history = session_write_preview_history self.target_epoch = session_target_epoch diff --git a/models/Model_DF/Model.py b/models/Model_DF/Model.py index 708f8c7..5a5308e 100644 --- a/models/Model_DF/Model.py +++ b/models/Model_DF/Model.py @@ -40,7 +40,7 @@ class Model(ModelBase): self.set_training_data_generators ([ SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None, debug=self.is_debug(), batch_size=self.batch_size, - sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), + sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ), output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128], [f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128], [f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_M | f.FACE_MASK_FULL, 128] ] ), diff --git a/models/Model_H128/Model.py b/models/Model_H128/Model.py index 0e9cc63..6e12a90 100644 --- a/models/Model_H128/Model.py +++ b/models/Model_H128/Model.py @@ -46,7 +46,7 @@ class Model(ModelBase): self.set_training_data_generators ([ SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None, debug=self.is_debug(), batch_size=self.batch_size, - sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), + sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ), output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 128], [f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 128], [f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_M | f.FACE_MASK_FULL, 128] ] ), diff --git a/models/Model_H64/Model.py b/models/Model_H64/Model.py index 8653f49..9705127 100644 --- a/models/Model_H64/Model.py +++ b/models/Model_H64/Model.py @@ -47,7 +47,7 @@ class Model(ModelBase): self.set_training_data_generators ([ SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None, debug=self.is_debug(), batch_size=self.batch_size, - sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), + sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ), output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 64], [f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 64], [f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_M | f.FACE_MASK_FULL, 64] ] ), diff --git a/models/Model_LIAEF128/Model.py b/models/Model_LIAEF128/Model.py index 4326a29..a28da8d 100644 --- a/models/Model_LIAEF128/Model.py +++ b/models/Model_LIAEF128/Model.py @@ -47,7 +47,7 @@ class Model(ModelBase): SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None, debug=self.is_debug(), batch_size=self.batch_size, - sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), + sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ), output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128], [f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128], [f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_M | f.FACE_MASK_FULL, 128] ] ), diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 1f866a6..31c7f88 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -15,21 +15,32 @@ class SAEModel(ModelBase): decoderH5 = 'decoder.h5' decodermH5 = 'decoderm.h5' + decoder_srcH5 = 'decoder_src.h5' + decoder_srcmH5 = 'decoder_srcm.h5' + decoder_dstH5 = 'decoder_dst.h5' + decoder_dstmH5 = 'decoder_dstm.h5' + #override def onInitializeOptions(self, is_first_run, ask_for_session_options): default_resolution = 128 - default_ae_dims = 256 + default_archi = 'liae' + default_style_power = 100 default_face_type = 'f' if is_first_run: #first run - self.options['resolution'] = input_int("Resolution (valid: 64,128, skip:128) : ", default_resolution, [64,128]) - self.options['ae_dims'] = input_int("AutoEncoder dims (valid: 128,256,512 skip:256) : ", default_ae_dims, [128,256,512]) - self.options['face_type'] = input_str ("Half or Full face? (h/f, skip:f) : ", default_face_type, ['h','f']) - + self.options['resolution'] = input_int("Resolution (64,128, ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.") + self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:liae) : ", default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower() + self.options['style_power'] = np.clip ( input_int("Style power (1..100 ?:help skip:100) : ", default_style_power, help_message="How fast NN will learn dst style during generalization of src and dst faces."), 1, 100 ) + default_ae_dims = 256 if self.options['archi'] == 'liae' else 512 + self.options['ae_dims'] = input_int("AutoEncoder dims (128,256,512 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, [128,256,512], help_message="More dims are better, but requires more VRAM." ) + self.options['face_type'] = input_str ("Half or Full face? (h/f, ?:help skip:f) : ", default_face_type, ['h','f'], help_message="Half face has better resolution, but covers less area of cheeks.").lower() else: #not first run self.options['resolution'] = self.options.get('resolution', default_resolution) + self.options['archi'] = self.options.get('archi', default_archi) + self.options['style_power'] = self.options.get('style_power', default_style_power) + default_ae_dims = 256 if self.options['archi'] == 'liae' else 512 self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims) self.options['face_type'] = self.options.get('face_type', default_face_type) @@ -44,52 +55,84 @@ class SAEModel(ModelBase): bgr_shape = (resolution, resolution, 3) mask_shape = (resolution, resolution, 1) - self.encoder = modelify(SAEModel.EncFlow() ) (Input(bgr_shape)) - - enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ] - - self.inter_B = modelify(SAEModel.InterFlow(dims=ae_dims,lowest_dense_res=resolution // 16)) (enc_output_Inputs) - self.inter_AB = modelify(SAEModel.InterFlow(dims=ae_dims,lowest_dense_res=resolution // 16)) (enc_output_Inputs) - - inter_output_Inputs = [ Input( np.array(K.int_shape(x)[1:])*(1,1,2) ) for x in self.inter_B.outputs ] - - self.decoder = modelify(SAEModel.DecFlow (bgr_shape[2],dims=ae_dims*2)) (inter_output_Inputs) - self.decoderm = modelify(SAEModel.DecFlow (mask_shape[2],dims=ae_dims)) (inter_output_Inputs) - - if not self.is_first_run(): - self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5)) - self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5)) - self.inter_AB.load_weights (self.get_strpath_storage_for_file(self.inter_ABH5)) - self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5)) - self.decoderm.load_weights (self.get_strpath_storage_for_file(self.decodermH5)) - warped_src = Input(bgr_shape) target_src = Input(bgr_shape) target_srcm = Input(mask_shape) - warped_src_code = self.encoder (warped_src) - - warped_src_inter_AB_code = self.inter_AB (warped_src_code) - warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code]) - - pred_src_src = self.decoder(warped_src_inter_code) - pred_src_srcm = self.decoderm(warped_src_inter_code) - warped_dst = Input(bgr_shape) target_dst = Input(bgr_shape) target_dstm = Input(mask_shape) + + if self.options['archi'] == 'liae': + self.encoder = modelify(SAEModel.EncFlow() ) (Input(bgr_shape)) + + enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ] + + self.inter_B = modelify(SAEModel.InterFlow(dims=ae_dims,lowest_dense_res=resolution // 16)) (enc_output_Inputs) + self.inter_AB = modelify(SAEModel.InterFlow(dims=ae_dims,lowest_dense_res=resolution // 16)) (enc_output_Inputs) + + inter_output_Inputs = [ Input( np.array(K.int_shape(x)[1:])*(1,1,2) ) for x in self.inter_B.outputs ] + + self.decoder = modelify(SAEModel.DecFlow (bgr_shape[2],dims=ae_dims*2)) (inter_output_Inputs) + self.decoderm = modelify(SAEModel.DecFlow (mask_shape[2],dims=ae_dims)) (inter_output_Inputs) + + if not self.is_first_run(): + self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5)) + self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5)) + self.inter_AB.load_weights (self.get_strpath_storage_for_file(self.inter_ABH5)) + self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5)) + self.decoderm.load_weights (self.get_strpath_storage_for_file(self.decodermH5)) + + warped_src_code = self.encoder (warped_src) + + warped_src_inter_AB_code = self.inter_AB (warped_src_code) + warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code]) + + pred_src_src = self.decoder(warped_src_inter_code) + pred_src_srcm = self.decoderm(warped_src_inter_code) + + + warped_dst_code = self.encoder (warped_dst) + warped_dst_inter_B_code = self.inter_B (warped_dst_code) + warped_dst_inter_AB_code = self.inter_AB (warped_dst_code) + warped_dst_inter_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code]) + pred_dst_dst = self.decoder(warped_dst_inter_code) + pred_dst_dstm = self.decoderm(warped_dst_inter_code) + + warped_src_dst_inter_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code]) + pred_src_dst = self.decoder(warped_src_dst_inter_code) + pred_src_dstm = self.decoderm(warped_src_dst_inter_code) + else: + self.encoder = modelify(SAEModel.DFEncFlow(dims=ae_dims,lowest_dense_res=resolution // 16) ) (Input(bgr_shape)) + + dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ] + + self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],dims=ae_dims)) (dec_Inputs) + self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],dims=ae_dims)) (dec_Inputs) + + self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],dims=ae_dims//2)) (dec_Inputs) + self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],dims=ae_dims//2)) (dec_Inputs) + - warped_dst_code = self.encoder (warped_dst) - warped_dst_inter_B_code = self.inter_B (warped_dst_code) - warped_dst_inter_AB_code = self.inter_AB (warped_dst_code) - warped_dst_inter_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code]) - pred_dst_dst = self.decoder(warped_dst_inter_code) - pred_dst_dstm = self.decoderm(warped_dst_inter_code) - - warped_src_dst_inter_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code]) - pred_src_dst = self.decoder(warped_src_dst_inter_code) - pred_src_dstm = self.decoderm(warped_src_dst_inter_code) - + if not self.is_first_run(): + self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5)) + self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5)) + self.decoder_srcm.load_weights (self.get_strpath_storage_for_file(self.decoder_srcmH5)) + self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5)) + self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5)) + + warped_src_code = self.encoder (warped_src) + pred_src_src = self.decoder_src(warped_src_code) + pred_src_srcm = self.decoder_srcm(warped_src_code) + + warped_dst_code = self.encoder (warped_dst) + pred_dst_dst = self.decoder_dst(warped_dst_code) + pred_dst_dstm = self.decoder_dstm(warped_dst_code) + + pred_src_dst = self.decoder_src(warped_dst_code) + pred_src_dstm = self.decoder_srcm(warped_dst_code) + + target_srcm_blurred = tf_gaussian_blur(resolution // 32)(target_srcm) target_srcm_sigm = target_srcm_blurred / 2.0 + 0.5 target_srcm_anti_sigm = 1.0 - target_srcm_sigm @@ -105,48 +148,61 @@ class SAEModel(ModelBase): pred_dst_dst_sigm = pred_dst_dst+1 pred_src_dst_sigm = pred_src_dst+1 - pred_src_dstm_blurred = tf_gaussian_blur(resolution // 32)(pred_src_dstm) - pred_src_dstm_sigm = pred_src_dstm_blurred / 2.0 + 0.5 - pred_src_dstm_anti_sigm = 1.0 - pred_src_dstm_sigm - target_src_masked = target_src_sigm*target_srcm_sigm target_dst_masked = target_dst_sigm * target_dstm_sigm target_dst_anti_masked = target_dst_sigm * target_dstm_anti_sigm - target_dst_psd_masked = target_dst_sigm * pred_src_dstm_sigm - target_dst_psd_anti_masked = target_dst_sigm * pred_src_dstm_anti_sigm - pred_src_src_masked = pred_src_src_sigm * target_srcm_sigm pred_dst_dst_masked = pred_dst_dst_sigm * target_dstm_sigm psd_target_dst_masked = pred_src_dst_sigm * target_dstm_sigm psd_target_dst_anti_masked = pred_src_dst_sigm * target_dstm_anti_sigm - psd_psd_masked = pred_src_dst_sigm * pred_src_dstm_sigm - psd_psd_anti_masked = pred_src_dst_sigm * pred_src_dstm_anti_sigm + style_power = self.options['style_power'] / 100.0 + src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) ) - src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2)(psd_target_dst_masked, target_dst_masked) - src_loss += K.mean( 100*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked ))) - #src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2)(psd_psd_masked, target_dst_psd_masked) - #src_loss += K.mean( 100*K.square(tf_dssim(2.0)( psd_psd_anti_masked, target_dst_psd_anti_masked ))) + #src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*style_power)(psd_target_dst_masked, target_dst_masked) + src_loss += K.mean( (100*style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked ))) + if self.options['archi'] == 'liae': + src_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights + else: + src_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss], - Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_loss, self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights) ) - + Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_loss, src_train_weights) ) + + dst_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked, pred_dst_dst_masked )) ) + + if self.options['archi'] == 'liae': + dst_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights + else: + dst_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights self.dst_train = K.function ([warped_dst, target_dst, target_dstm],[dst_loss], - Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_loss, self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights) ) + Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_loss, dst_train_weights) ) - src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm)) + src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm)) + + if self.options['archi'] == 'liae': + src_mask_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights + else: + src_mask_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss], - Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_mask_loss, self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights) ) + Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_mask_loss, src_mask_train_weights ) ) - dst_mask_loss = K.mean(K.square(target_dstm-pred_dst_dstm)) + dst_mask_loss = K.mean(K.square(target_dstm-pred_dst_dstm)) + + if self.options['archi'] == 'liae': + dst_mask_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights + else: + dst_mask_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights + self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss], - Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_mask_loss, self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights) ) + Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_mask_loss, dst_mask_train_weights) ) self.AE_view = K.function ([warped_src, warped_dst],[pred_src_src, pred_src_srcm, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm]) self.AE_convert = K.function ([warped_dst],[pred_src_dst, pred_src_dstm]) @@ -155,11 +211,11 @@ class SAEModel(ModelBase): f = SampleProcessor.TypeFlags face_type = f.FACE_ALIGN_FULL if self.options['face_type'] == 'f' else f.FACE_ALIGN_HALF - + self.set_training_data_generators ([ SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None, debug=self.is_debug(), batch_size=self.batch_size, - sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True), + sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ), output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution], [f.TRANSFORMED | face_type | f.MODE_BGR, resolution], [f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution], @@ -174,14 +230,22 @@ class SAEModel(ModelBase): [f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution] ] ) ]) #override - def onSave(self): - self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)], - [self.inter_B, self.get_strpath_storage_for_file(self.inter_BH5)], - [self.inter_AB, self.get_strpath_storage_for_file(self.inter_ABH5)], - [self.decoder, self.get_strpath_storage_for_file(self.decoderH5)], - [self.decoderm, self.get_strpath_storage_for_file(self.decodermH5)], - ] ) - + def onSave(self): + if self.options['archi'] == 'liae': + self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)], + [self.inter_B, self.get_strpath_storage_for_file(self.inter_BH5)], + [self.inter_AB, self.get_strpath_storage_for_file(self.inter_ABH5)], + [self.decoder, self.get_strpath_storage_for_file(self.decoderH5)], + [self.decoderm, self.get_strpath_storage_for_file(self.decodermH5)], + ] ) + else: + self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)], + [self.decoder_src, self.get_strpath_storage_for_file(self.decoder_srcH5)], + [self.decoder_srcm, self.get_strpath_storage_for_file(self.decoder_srcmH5)], + [self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)], + [self.decoder_dstm, self.get_strpath_storage_for_file(self.decoder_dstmH5)], + ] ) + #override def onTrainOneEpoch(self, sample): warped_src, target_src, target_src_mask = sample[0] @@ -216,7 +280,7 @@ class SAEModel(ModelBase): st.append ( np.concatenate ( ( S[i], SS[i], #SM[i], D[i], DD[i], #DM[i], - SD[i], #SDM[i] + SD[i], SDM[i] ), axis=1) ) return [ ('SAE', np.concatenate ( st, axis=0 ) ) ] @@ -306,4 +370,54 @@ class SAEModel(ModelBase): return func + + @staticmethod + def DFEncFlow(dims=512, lowest_dense_res=8): + exec (nnlib.import_all(), locals(), globals()) + + def downscale (dim): + def func(x): + return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x)) + return func + + def upscale (dim): + def func(x): + return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x))) + return func + + def func(input): + x = input + + x = downscale(128)(x) + x = downscale(256)(x) + x = downscale(512)(x) + x = downscale(1024)(x) + + x = Dense(dims)(Flatten()(x)) + x = Dense(lowest_dense_res * lowest_dense_res * dims)(x) + x = Reshape((lowest_dense_res, lowest_dense_res, dims))(x) + x = upscale(dims)(x) + + return x + return func + + @staticmethod + def DFDecFlow(output_nc,dims,activation='tanh'): + exec (nnlib.import_all(), locals(), globals()) + + def upscale (dim): + def func(x): + return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x))) + return func + def func(input): + x = input[0] + x = upscale(dims)(x) + x = upscale(dims//2)(x) + x = upscale(dims//4)(x) + + x = Conv2D(output_nc, kernel_size=5, padding='same', activation=activation)(x) + return x + + return func + Model = SAEModel \ No newline at end of file diff --git a/utils/console_utils.py b/utils/console_utils.py index d5ea8ae..1af6002 100644 --- a/utils/console_utils.py +++ b/utils/console_utils.py @@ -1,26 +1,54 @@ -def input_int(s, default_value, valid_list=None): - try: - inp = input(s) - i = int(inp) - if (valid_list is not None) and (i not in valid_list): +def input_int(s, default_value, valid_list=None, help_message=None): + while True: + try: + inp = input(s) + if len(inp) == 0: + raise ValueError("") + + if help_message is not None and inp == '?': + print (help_message) + continue + + i = int(inp) + if (valid_list is not None) and (i not in valid_list): + return default_value + return i + except: + print (default_value) return default_value - return i - except: - return default_value -def input_bool(s, default_value): - try: - return bool ( {"y":True,"n":False,"1":True,"0":False}.get(input(s).lower(), default_value) ) - except: - return default_value - -def input_str(s, default_value, valid_list=None): - try: - inp = input(s) - if (valid_list is not None) and (inp.lower() not in valid_list): +def input_bool(s, default_value, help_message=None): + while True: + try: + inp = input(s) + if len(inp) == 0: + raise ValueError("") + + if help_message is not None and inp == '?': + print (help_message) + continue + + return bool ( {"y":True,"n":False,"1":True,"0":False}.get(inp.lower(), default_value) ) + except: + print ( "y" if default_value else "n" ) return default_value - return inp - except: - return default_value \ No newline at end of file + +def input_str(s, default_value, valid_list=None, help_message=None): + while True: + try: + inp = input(s) + if len(inp) == 0: + raise ValueError("") + + if help_message is not None and inp == '?': + print (help_message) + continue + + if (valid_list is not None) and (inp.lower() not in valid_list): + return default_value + return inp + except: + print (default_value) + return default_value \ No newline at end of file