mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
dump_ckpt
This commit is contained in:
parent
3d0e18b0ad
commit
b333fcea4b
3 changed files with 62 additions and 13 deletions
|
@ -204,6 +204,8 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
archi_type, archi_opts = archi_split
|
||||
elif len(archi_split) == 1:
|
||||
archi_type, archi_opts = archi_split[0], None
|
||||
|
||||
self.archi_type = archi_type
|
||||
|
||||
ae_dims = self.options['ae_dims']
|
||||
e_dims = self.options['e_dims']
|
||||
|
@ -236,22 +238,22 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
optimizer_vars_on_cpu = models_opt_device=='/CPU:0'
|
||||
|
||||
input_ch=3
|
||||
bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
|
||||
bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
|
||||
mask_shape = nn.get4Dshape(resolution,resolution,1)
|
||||
self.model_filename_list = []
|
||||
|
||||
with tf.device ('/CPU:0'):
|
||||
#Place holders on CPU
|
||||
self.warped_src = tf.placeholder (nn.floatx, bgr_shape)
|
||||
self.warped_dst = tf.placeholder (nn.floatx, bgr_shape)
|
||||
self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src')
|
||||
self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst')
|
||||
|
||||
self.target_src = tf.placeholder (nn.floatx, bgr_shape)
|
||||
self.target_dst = tf.placeholder (nn.floatx, bgr_shape)
|
||||
self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src')
|
||||
self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst')
|
||||
|
||||
self.target_srcm = tf.placeholder (nn.floatx, mask_shape)
|
||||
self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape)
|
||||
self.target_dstm = tf.placeholder (nn.floatx, mask_shape)
|
||||
self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape)
|
||||
self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm')
|
||||
self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em')
|
||||
self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm')
|
||||
self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em')
|
||||
|
||||
# Initializing model classes
|
||||
model_archi = nn.DeepFakeArchi(resolution, opts=archi_opts)
|
||||
|
@ -609,7 +611,10 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
|
||||
if do_init:
|
||||
model.init_weights()
|
||||
|
||||
|
||||
|
||||
###############
|
||||
|
||||
# initializing sample generators
|
||||
if self.is_training:
|
||||
training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
|
||||
|
@ -650,7 +655,44 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
|
||||
if self.pretrain_just_disabled:
|
||||
self.update_sample_for_preview(force_new=True)
|
||||
|
||||
def dump_ckpt(self):
|
||||
tf = nn.tf
|
||||
|
||||
|
||||
with tf.device ('/CPU:0'):
|
||||
warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
|
||||
warped_dst = tf.transpose(warped_dst, (0,3,1,2))
|
||||
|
||||
|
||||
if 'df' in self.archi_type:
|
||||
gpu_dst_code = self.inter(self.encoder(warped_dst))
|
||||
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
|
||||
_, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
|
||||
|
||||
elif 'liae' in self.archi_type:
|
||||
gpu_dst_code = self.encoder (warped_dst)
|
||||
gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
|
||||
gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
|
||||
gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
|
||||
gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
|
||||
|
||||
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
|
||||
_, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
|
||||
|
||||
gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1))
|
||||
gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1))
|
||||
gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1))
|
||||
|
||||
|
||||
saver = tf.train.Saver()
|
||||
tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
|
||||
tf.identity(gpu_pred_src_dst, name='out_celeb_face')
|
||||
tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')
|
||||
|
||||
saver.save(nn.tf_sess, self.get_strpath_storage_for_file('.ckpt') )
|
||||
|
||||
|
||||
#override
|
||||
def get_model_filename_list(self):
|
||||
return self.model_filename_list
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue