mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
dump_ckpt
This commit is contained in:
parent
3d0e18b0ad
commit
b333fcea4b
3 changed files with 62 additions and 13 deletions
2
main.py
2
main.py
|
@ -127,6 +127,7 @@ if __name__ == "__main__":
|
||||||
'silent_start' : arguments.silent_start,
|
'silent_start' : arguments.silent_start,
|
||||||
'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ],
|
'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ],
|
||||||
'debug' : arguments.debug,
|
'debug' : arguments.debug,
|
||||||
|
'dump_ckpt' : arguments.dump_ckpt,
|
||||||
}
|
}
|
||||||
from mainscripts import Trainer
|
from mainscripts import Trainer
|
||||||
Trainer.main(**kwargs)
|
Trainer.main(**kwargs)
|
||||||
|
@ -144,6 +145,7 @@ if __name__ == "__main__":
|
||||||
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
|
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
|
||||||
p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
|
p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
|
||||||
p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.")
|
p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.")
|
||||||
|
p.add_argument('--dump-ckpt', action="store_true", dest="dump_ckpt", default=False, help="Dump the model to ckpt format.")
|
||||||
|
|
||||||
|
|
||||||
p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
|
p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
|
||||||
|
|
|
@ -27,6 +27,7 @@ def trainerThread (s2c, c2s, e,
|
||||||
silent_start=False,
|
silent_start=False,
|
||||||
execute_programs = None,
|
execute_programs = None,
|
||||||
debug=False,
|
debug=False,
|
||||||
|
dump_ckpt=False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
@ -44,7 +45,7 @@ def trainerThread (s2c, c2s, e,
|
||||||
saved_models_path.mkdir(exist_ok=True, parents=True)
|
saved_models_path.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
model = models.import_model(model_class_name)(
|
model = models.import_model(model_class_name)(
|
||||||
is_training=True,
|
is_training=not dump_ckpt,
|
||||||
saved_models_path=saved_models_path,
|
saved_models_path=saved_models_path,
|
||||||
training_data_src_path=training_data_src_path,
|
training_data_src_path=training_data_src_path,
|
||||||
training_data_dst_path=training_data_dst_path,
|
training_data_dst_path=training_data_dst_path,
|
||||||
|
@ -55,9 +56,13 @@ def trainerThread (s2c, c2s, e,
|
||||||
force_gpu_idxs=force_gpu_idxs,
|
force_gpu_idxs=force_gpu_idxs,
|
||||||
cpu_only=cpu_only,
|
cpu_only=cpu_only,
|
||||||
silent_start=silent_start,
|
silent_start=silent_start,
|
||||||
debug=debug,
|
debug=debug)
|
||||||
)
|
|
||||||
|
|
||||||
|
if dump_ckpt:
|
||||||
|
e.set()
|
||||||
|
model.dump_ckpt()
|
||||||
|
break
|
||||||
|
|
||||||
is_reached_goal = model.is_reached_iter_goal()
|
is_reached_goal = model.is_reached_iter_goal()
|
||||||
|
|
||||||
shared_state = { 'after_save' : False }
|
shared_state = { 'after_save' : False }
|
||||||
|
|
|
@ -204,6 +204,8 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
archi_type, archi_opts = archi_split
|
archi_type, archi_opts = archi_split
|
||||||
elif len(archi_split) == 1:
|
elif len(archi_split) == 1:
|
||||||
archi_type, archi_opts = archi_split[0], None
|
archi_type, archi_opts = archi_split[0], None
|
||||||
|
|
||||||
|
self.archi_type = archi_type
|
||||||
|
|
||||||
ae_dims = self.options['ae_dims']
|
ae_dims = self.options['ae_dims']
|
||||||
e_dims = self.options['e_dims']
|
e_dims = self.options['e_dims']
|
||||||
|
@ -236,22 +238,22 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
optimizer_vars_on_cpu = models_opt_device=='/CPU:0'
|
optimizer_vars_on_cpu = models_opt_device=='/CPU:0'
|
||||||
|
|
||||||
input_ch=3
|
input_ch=3
|
||||||
bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
|
bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
|
||||||
mask_shape = nn.get4Dshape(resolution,resolution,1)
|
mask_shape = nn.get4Dshape(resolution,resolution,1)
|
||||||
self.model_filename_list = []
|
self.model_filename_list = []
|
||||||
|
|
||||||
with tf.device ('/CPU:0'):
|
with tf.device ('/CPU:0'):
|
||||||
#Place holders on CPU
|
#Place holders on CPU
|
||||||
self.warped_src = tf.placeholder (nn.floatx, bgr_shape)
|
self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src')
|
||||||
self.warped_dst = tf.placeholder (nn.floatx, bgr_shape)
|
self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst')
|
||||||
|
|
||||||
self.target_src = tf.placeholder (nn.floatx, bgr_shape)
|
self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src')
|
||||||
self.target_dst = tf.placeholder (nn.floatx, bgr_shape)
|
self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst')
|
||||||
|
|
||||||
self.target_srcm = tf.placeholder (nn.floatx, mask_shape)
|
self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm')
|
||||||
self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape)
|
self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em')
|
||||||
self.target_dstm = tf.placeholder (nn.floatx, mask_shape)
|
self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm')
|
||||||
self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape)
|
self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em')
|
||||||
|
|
||||||
# Initializing model classes
|
# Initializing model classes
|
||||||
model_archi = nn.DeepFakeArchi(resolution, opts=archi_opts)
|
model_archi = nn.DeepFakeArchi(resolution, opts=archi_opts)
|
||||||
|
@ -609,7 +611,10 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
|
|
||||||
if do_init:
|
if do_init:
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
|
|
||||||
|
|
||||||
|
###############
|
||||||
|
|
||||||
# initializing sample generators
|
# initializing sample generators
|
||||||
if self.is_training:
|
if self.is_training:
|
||||||
training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
|
training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
|
||||||
|
@ -650,7 +655,44 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
|
|
||||||
if self.pretrain_just_disabled:
|
if self.pretrain_just_disabled:
|
||||||
self.update_sample_for_preview(force_new=True)
|
self.update_sample_for_preview(force_new=True)
|
||||||
|
|
||||||
|
def dump_ckpt(self):
|
||||||
|
tf = nn.tf
|
||||||
|
|
||||||
|
|
||||||
|
with tf.device ('/CPU:0'):
|
||||||
|
warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
|
||||||
|
warped_dst = tf.transpose(warped_dst, (0,3,1,2))
|
||||||
|
|
||||||
|
|
||||||
|
if 'df' in self.archi_type:
|
||||||
|
gpu_dst_code = self.inter(self.encoder(warped_dst))
|
||||||
|
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
|
||||||
|
_, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
|
||||||
|
|
||||||
|
elif 'liae' in self.archi_type:
|
||||||
|
gpu_dst_code = self.encoder (warped_dst)
|
||||||
|
gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
|
||||||
|
gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
|
||||||
|
gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
|
||||||
|
gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
|
||||||
|
|
||||||
|
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
|
||||||
|
_, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
|
||||||
|
|
||||||
|
gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1))
|
||||||
|
gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1))
|
||||||
|
gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1))
|
||||||
|
|
||||||
|
|
||||||
|
saver = tf.train.Saver()
|
||||||
|
tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
|
||||||
|
tf.identity(gpu_pred_src_dst, name='out_celeb_face')
|
||||||
|
tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')
|
||||||
|
|
||||||
|
saver.save(nn.tf_sess, self.get_strpath_storage_for_file('.ckpt') )
|
||||||
|
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def get_model_filename_list(self):
|
def get_model_filename_list(self):
|
||||||
return self.model_filename_list
|
return self.model_filename_list
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue