diff --git a/main.py b/main.py index 2ba085f..cad3e92 100644 --- a/main.py +++ b/main.py @@ -127,6 +127,7 @@ if __name__ == "__main__": 'silent_start' : arguments.silent_start, 'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ], 'debug' : arguments.debug, + 'dump_ckpt' : arguments.dump_ckpt, } from mainscripts import Trainer Trainer.main(**kwargs) @@ -144,6 +145,7 @@ if __name__ == "__main__": p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.") p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.") p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.") + p.add_argument('--dump-ckpt', action="store_true", dest="dump_ckpt", default=False, help="Dump the model to ckpt format.") p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+') diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py index 7d73e2f..4afc218 100644 --- a/mainscripts/Trainer.py +++ b/mainscripts/Trainer.py @@ -27,6 +27,7 @@ def trainerThread (s2c, c2s, e, silent_start=False, execute_programs = None, debug=False, + dump_ckpt=False, **kwargs): while True: try: @@ -44,7 +45,7 @@ def trainerThread (s2c, c2s, e, saved_models_path.mkdir(exist_ok=True, parents=True) model = models.import_model(model_class_name)( - is_training=True, + is_training=not dump_ckpt, saved_models_path=saved_models_path, training_data_src_path=training_data_src_path, training_data_dst_path=training_data_dst_path, @@ -55,9 +56,13 @@ def trainerThread (s2c, c2s, e, force_gpu_idxs=force_gpu_idxs, cpu_only=cpu_only, silent_start=silent_start, - debug=debug, - ) + debug=debug) + if dump_ckpt: + e.set() + model.dump_ckpt() + break + is_reached_goal = model.is_reached_iter_goal() shared_state = { 'after_save' : False } diff --git 
a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 380a053..0ef99a6 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -204,6 +204,8 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... archi_type, archi_opts = archi_split elif len(archi_split) == 1: archi_type, archi_opts = archi_split[0], None + + self.archi_type = archi_type ae_dims = self.options['ae_dims'] e_dims = self.options['e_dims'] @@ -236,22 +238,22 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... optimizer_vars_on_cpu = models_opt_device=='/CPU:0' input_ch=3 - bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) + bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) mask_shape = nn.get4Dshape(resolution,resolution,1) self.model_filename_list = [] with tf.device ('/CPU:0'): #Place holders on CPU - self.warped_src = tf.placeholder (nn.floatx, bgr_shape) - self.warped_dst = tf.placeholder (nn.floatx, bgr_shape) + self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src') + self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst') - self.target_src = tf.placeholder (nn.floatx, bgr_shape) - self.target_dst = tf.placeholder (nn.floatx, bgr_shape) + self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src') + self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst') - self.target_srcm = tf.placeholder (nn.floatx, mask_shape) - self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape) - self.target_dstm = tf.placeholder (nn.floatx, mask_shape) - self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape) + self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm') + self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em') + self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm') + self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em') # Initializing model classes model_archi = 
nn.DeepFakeArchi(resolution, opts=archi_opts)
@@ -609,7 +611,10 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
 
                 if do_init:
                     model.init_weights()
-
+
+
+        ###############
+
         # initializing sample generators
         if self.is_training:
             training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
@@ -650,7 +655,69 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
 
             if self.pretrain_just_disabled:
                 self.update_sample_for_preview(force_new=True)
+
+    def dump_ckpt(self):
+        """Export an inference graph for this model as a TensorFlow checkpoint.
+
+        Creates an input placeholder named 'in_face' shaped
+        (None, resolution, resolution, 3), runs it through the networks selected
+        by self.archi_type, names the three outputs 'out_face_mask',
+        'out_celeb_face' and 'out_celeb_face_mask', and saves nn.tf_sess to the
+        model's '.ckpt' storage path.
+
+        Raises:
+            ValueError: if self.archi_type contains neither 'df' nor 'liae'.
+        """
+        tf = nn.tf
+
+        is_df = 'df' in self.archi_type
+        if not is_df and 'liae' not in self.archi_type:
+            # Reject unknown architectures up front instead of failing later
+            # on prediction tensors that were never assigned.
+            raise ValueError('dump_ckpt: unsupported archi_type %r' % (self.archi_type,))
+
+        # The exported interface is always channels-last. Only transpose when the
+        # networks themselves were built channels-first.
+        # NOTE(review): assumes nn.conv2d_ch_axis == 1 means NCHW - confirm in leras nn.
+        channels_first = nn.conv2d_ch_axis == 1
+
+        def to_channels_last(tensor):
+            return tf.transpose(tensor, (0,2,3,1)) if channels_first else tensor
+
+        # NOTE(review): the extent of this device scope was reconstructed from a
+        # whitespace-collapsed patch - confirm which ops must be pinned to the CPU.
+        with tf.device ('/CPU:0'):
+            warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
+            if channels_first:
+                warped_dst = tf.transpose(warped_dst, (0,3,1,2))
+
+            if is_df:
+                gpu_dst_code = self.inter(self.encoder(warped_dst))
+                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
+                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
+            else:
+                # liae: pred_src_dst is decoded from inter_AB duplicated on the
+                # channel axis, pred_dst_dstm from inter_B + inter_AB.
+                gpu_dst_code = self.encoder (warped_dst)
+                gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
+                gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
+                gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
+                gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
+
+                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
+                _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
+
+            gpu_pred_src_dst = to_channels_last(gpu_pred_src_dst)
+            gpu_pred_dst_dstm = to_channels_last(gpu_pred_dst_dstm)
+            gpu_pred_src_dstm = to_channels_last(gpu_pred_src_dstm)
+
+        saver = tf.train.Saver()
+        tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
+        tf.identity(gpu_pred_src_dst, name='out_celeb_face')
+        tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')
+
+        saver.save(nn.tf_sess, self.get_strpath_storage_for_file('.ckpt') )
 
     #override
     def get_model_filename_list(self):
         return self.model_filename_list