mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-11 15:47:01 -07:00
AMP: code refactoring, fix preview history
added dumpdflive command
This commit is contained in:
parent
6d89d7fa4c
commit
5783191849
9 changed files with 143 additions and 144 deletions
|
@ -659,11 +659,15 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
if self.pretrain_just_disabled:
|
||||
self.update_sample_for_preview(force_new=True)
|
||||
|
||||
def dump_ckpt(self):
|
||||
def dump_dflive (self):
|
||||
output_path=self.get_strpath_storage_for_file('model.dflive')
|
||||
|
||||
io.log_info(f'Dumping .dflive to {output_path}')
|
||||
|
||||
tf = nn.tf
|
||||
nn.set_data_format('NCHW')
|
||||
|
||||
|
||||
with tf.device ('/CPU:0'):
|
||||
with tf.device (nn.tf_default_device_name):
|
||||
warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
|
||||
warped_dst = tf.transpose(warped_dst, (0,3,1,2))
|
||||
|
||||
|
@ -687,15 +691,26 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1))
|
||||
gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1))
|
||||
|
||||
|
||||
saver = tf.train.Saver()
|
||||
tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
|
||||
tf.identity(gpu_pred_src_dst, name='out_celeb_face')
|
||||
tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')
|
||||
|
||||
saver.save(nn.tf_sess, self.get_strpath_storage_for_file('.ckpt') )
|
||||
|
||||
output_graph_def = tf.graph_util.convert_variables_to_constants(
|
||||
nn.tf_sess,
|
||||
tf.get_default_graph().as_graph_def(),
|
||||
['out_face_mask','out_celeb_face','out_celeb_face_mask']
|
||||
)
|
||||
|
||||
import tf2onnx
|
||||
with tf.device("/CPU:0"):
|
||||
model_proto, _ = tf2onnx.convert._convert_common(
|
||||
output_graph_def,
|
||||
name='SAEHD',
|
||||
input_names=['in_face:0'],
|
||||
output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'],
|
||||
opset=13,
|
||||
output_path=output_path)
|
||||
|
||||
#override
|
||||
def get_model_filename_list(self):
|
||||
return self.model_filename_list
|
||||
|
@ -751,7 +766,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
|
||||
|
||||
#override
|
||||
def onGetPreview(self, samples):
|
||||
def onGetPreview(self, samples, for_history=False):
|
||||
( (warped_src, target_src, target_srcm, target_srcm_em),
|
||||
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue