mirror of https://github.com/iperov/DeepFaceLab.git
synced 2025-08-21 22:13:20 -07:00

Merge branch 'master' into pr/5290
commit 22a140c51d

8 changed files with 297 additions and 35 deletions
@@ -194,7 +194,7 @@ Unfortunately, there is no "make everything ok" button in DeepFaceLab. You shoul
 </td></tr>
 
 <tr><td align="right">
-<a href="https://tinyurl.com/y8lntghz">Windows (magnet link)</a>
+<a href="https://tinyurl.com/d8wpuayx">Windows (magnet link)</a>
 </td><td align="center">Last release. Use torrent client to download.</td></tr>
 
 <tr><td align="right">
@@ -77,6 +77,8 @@ class SegIEPoly():
         self.pts = np.array(pts)
         self.n_max = self.n = len(pts)
 
+    def mult_points(self, val):
+        self.pts *= val
 

@@ -137,6 +139,10 @@ class SegIEPolys():
     def dump(self):
         return {'polys' : [ poly.dump() for poly in self.polys ] }
 
+    def mult_points(self, val):
+        for poly in self.polys:
+            poly.mult_points(val)
+
     @staticmethod
     def load(data=None):
         ie_polys = SegIEPolys()
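The pair of `mult_points` methods added above gives callers a single call to rescale every segmentation polygon by one scalar. A minimal usage sketch (hedged: `dflimg` is assumed to be an already-loaded DFL image, mirroring how FacesetResizer.py below drives this API):

    # Sketch: shrink seg polygons to match an image resized from 1024px to 512px.
    scale = 1024 / 512                        # old width / new width
    seg_ie_polys = dflimg.get_seg_ie_polys()
    seg_ie_polys.mult_points(1.0 / scale)     # every poly's pts array scales in place
    dflimg.set_seg_ie_polys(seg_ie_polys)
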
main.py (12 lines changed)
@@ -127,6 +127,7 @@ if __name__ == "__main__":
             'silent_start' : arguments.silent_start,
             'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ],
             'debug' : arguments.debug,
+            'dump_ckpt' : arguments.dump_ckpt,
         }
         from mainscripts import Trainer
         Trainer.main(**kwargs)

@@ -144,6 +145,7 @@ if __name__ == "__main__":
     p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
     p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
     p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.")
+    p.add_argument('--dump-ckpt', action="store_true", dest="dump_ckpt", default=False, help="Dump the model to ckpt format.")
 
     p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')

@@ -253,6 +255,16 @@ if __name__ == "__main__":
     p.set_defaults(func=process_faceset_enhancer)
 
+    p = facesettool_parser.add_parser ("resize", help="Resize DFL faceset.")
+    p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.")
+
+    def process_faceset_resizer(arguments):
+        osex.set_process_lowest_prio()
+        from mainscripts import FacesetResizer
+        FacesetResizer.process_folder ( Path(arguments.input_dir) )
+    p.set_defaults(func=process_faceset_resizer)
+
     def process_dev_test(arguments):
         osex.set_process_lowest_prio()
         from mainscripts import dev_misc
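Given the parser wiring above, the two new entry points would be invoked roughly like this (a sketch: paths are placeholders, and the `train` line abbreviates the other flags the existing train parser requires):

    python main.py train --model-dir <model_dir> --model SAEHD ... --dump-ckpt
    python main.py facesettool resize --input-dir <aligned_dir>
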
mainscripts/FacesetResizer.py (new file, 147 lines)
@@ -0,0 +1,147 @@
import multiprocessing
import shutil
import traceback          # used by the error handler in process_data
from pathlib import Path  # used by process_folder

from core import pathex
from core.cv2ex import *
from core.interact import interact as io
from core.joblib import Subprocessor
from DFLIMG import *
from facelib import LandmarksProcessor, FaceType


class FacesetResizerSubprocessor(Subprocessor):

    #override
    def __init__(self, image_paths, output_dirpath, image_size):
        self.image_paths = image_paths
        self.output_dirpath = output_dirpath
        self.image_size = image_size
        self.result = []

        super().__init__('FacesetResizer', FacesetResizerSubprocessor.Cli, 600)

    #override
    def on_clients_initialized(self):
        io.progress_bar (None, len (self.image_paths))

    #override
    def on_clients_finalized(self):
        io.progress_bar_close()

    #override
    def process_info_generator(self):
        base_dict = {'output_dirpath':self.output_dirpath, 'image_size':self.image_size}

        # spawn at most 8 CPU workers
        for device_idx in range( min(8, multiprocessing.cpu_count()) ):
            client_dict = base_dict.copy()
            device_name = f'CPU #{device_idx}'
            client_dict['device_name'] = device_name
            yield device_name, {}, client_dict

    #override
    def get_data(self, host_dict):
        if len (self.image_paths) > 0:
            return self.image_paths.pop(0)

    #override
    def on_data_return (self, host_dict, data):
        self.image_paths.insert(0, data)

    #override
    def on_result (self, host_dict, data, result):
        io.progress_bar_inc(1)
        if result[0] == 1:
            self.result += [ (result[1], result[2]) ]

    #override
    def get_result(self):
        return self.result

    class Cli(Subprocessor.Cli):

        #override
        def on_initialize(self, client_dict):
            self.output_dirpath = client_dict['output_dirpath']
            self.image_size = client_dict['image_size']
            self.log_info (f"Running on { client_dict['device_name'] }")

        #override
        def process_data(self, filepath):
            try:
                dflimg = DFLIMG.load (filepath)
                if dflimg is None or not dflimg.has_data():
                    self.log_err (f"{filepath.name} is not a dfl image file")
                else:
                    dfl_dict = dflimg.get_dict()

                    img = cv2_imread(filepath)
                    h,w = img.shape[:2]
                    if h != w:
                        raise Exception(f'w != h in {filepath}')

                    image_size = self.image_size
                    scale = w / image_size

                    img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_LANCZOS4)

                    output_filepath = self.output_dirpath / filepath.name
                    cv2_imwrite ( str(output_filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] )

                    # re-open the resized copy and carry the DFL metadata over
                    dflimg = DFLIMG.load (output_filepath)
                    dflimg.set_dict(dfl_dict)

                    # landmarks scale with the image
                    lmrks = dflimg.get_landmarks()
                    lmrks /= scale
                    dflimg.set_landmarks(lmrks)

                    # XSeg polygons scale the same way
                    seg_ie_polys = dflimg.get_seg_ie_polys()
                    seg_ie_polys.mult_points( 1.0 / scale )
                    dflimg.set_seg_ie_polys(seg_ie_polys)

                    # rebuild the image-to-face matrix for the new resolution
                    mat = dflimg.get_image_to_face_mat()
                    if mat is not None:
                        face_type = FaceType.fromString ( dflimg.get_face_type() )
                        mat = LandmarksProcessor.get_transform_mat ( dflimg.get_source_landmarks(), image_size, face_type )
                        dflimg.set_image_to_face_mat(mat)
                    dflimg.save()

                    return (1, filepath, output_filepath)
            except:
                self.log_err (f"Exception occurred while processing file {filepath}. Error: {traceback.format_exc()}")

            return (0, filepath, None)


def process_folder ( dirpath ):
    image_size = io.input_int("New image size", 512, valid_range=[256,2048])

    output_dirpath = dirpath.parent / (dirpath.name + '_resized')
    output_dirpath.mkdir (exist_ok=True, parents=True)

    dirpath_parts = '/'.join( dirpath.parts[-2:] )
    output_dirpath_parts = '/'.join( output_dirpath.parts[-2:] )
    io.log_info (f"Resizing faceset in {dirpath_parts}")
    io.log_info (f"Processing to {output_dirpath_parts}")

    # start from an empty output directory
    output_images_paths = pathex.get_image_paths(output_dirpath)
    if len(output_images_paths) > 0:
        for filename in output_images_paths:
            Path(filename).unlink()

    image_paths = [Path(x) for x in pathex.get_image_paths( dirpath )]
    result = FacesetResizerSubprocessor ( image_paths, output_dirpath, image_size).run()

    is_merge = io.input_bool (f"\r\nMerge {output_dirpath_parts} to {dirpath_parts} ?", True)
    if is_merge:
        io.log_info (f"Copying processed files to {dirpath_parts}")

        for (filepath, output_filepath) in result:
            try:
                shutil.copy (output_filepath, filepath)
            except:
                pass

        io.log_info (f"Removing {output_dirpath_parts}")
        shutil.rmtree(output_dirpath)
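The geometry in process_data reduces to one uniform factor: the crop stays square, so scale = w / image_size maps old pixel coordinates onto the resized image, and both landmarks and seg polygons are divided by it. A small self-contained check (sample values invented):

    import numpy as np

    w, image_size = 1024, 512
    scale = w / image_size                          # 2.0
    lmrks = np.float32([[512, 512], [256, 768]])
    assert np.allclose(lmrks / scale, [[256, 256], [128, 384]])  # landmarks track the resize
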
@@ -27,6 +27,7 @@ def trainerThread (s2c, c2s, e,
                    silent_start=False,
                    execute_programs = None,
                    debug=False,
+                   dump_ckpt=False,
                    **kwargs):
     while True:
         try:

@@ -44,7 +45,7 @@ def trainerThread (s2c, c2s, e,
         saved_models_path.mkdir(exist_ok=True, parents=True)
 
         model = models.import_model(model_class_name)(
-                        is_training=True,
+                        is_training=not dump_ckpt,
                         saved_models_path=saved_models_path,
                         training_data_src_path=training_data_src_path,
                         training_data_dst_path=training_data_dst_path,

@@ -55,8 +56,12 @@ def trainerThread (s2c, c2s, e,
                         force_gpu_idxs=force_gpu_idxs,
                         cpu_only=cpu_only,
                         silent_start=silent_start,
-                        debug=debug,
-                        )
+                        debug=debug)
+
+        if dump_ckpt:
+            e.set()
+            model.dump_ckpt()
+            break
 
         is_reached_goal = model.is_reached_iter_goal()
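The dump path threads through all three files changed so far: `--dump-ckpt` in main.py becomes `kwargs['dump_ckpt']`, trainerThread builds the model with `is_training=not dump_ckpt`, and then exits before the training loop ever starts. A condensed sketch of that control flow (`build_model` and `training_loop` are hypothetical stand-ins for the surrounding code):

    def trainer_thread(dump_ckpt=False, **kwargs):
        model = build_model(is_training=not dump_ckpt, **kwargs)
        if dump_ckpt:
            model.dump_ckpt()      # export the graph and stop; no iterations run
            return
        training_loop(model)       # the usual path
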
@@ -10,7 +10,7 @@ from core.cv2ex import *
 from core.interact import interact as io
 from core.leras import nn
 from DFLIMG import *
-from facelib import XSegNet
+from facelib import XSegNet, LandmarksProcessor, FaceType
 
 
 def apply_xseg(input_path, model_path):

@@ -20,6 +20,15 @@ def apply_xseg(input_path, model_path):
     if not model_path.exists():
         raise ValueError(f'{model_path} not found. Please ensure it exists.')
 
+    face_type = io.input_str ("XSeg model face type", 'same', ['h','mf','f','wf','head','same'], help_message="Specify the face type the XSeg model was trained on. For example, if the XSeg model was trained as WF but the faceset is HEAD, specify WF to apply XSeg only to the WF part of HEAD. Default is 'same'.").lower()
+    if face_type == 'same':
+        face_type = None
+    else:
+        face_type = {'h'    : FaceType.HALF,
+                     'mf'   : FaceType.MID_FULL,
+                     'f'    : FaceType.FULL,
+                     'wf'   : FaceType.WHOLE_FACE,
+                     'head' : FaceType.HEAD}[face_type]
     io.log_info(f'Applying trained XSeg model to {input_path.name}/ folder.')
 
     device_config = nn.DeviceConfig.ask_choose_device(choose_only_one=True)

@@ -30,7 +39,7 @@ def apply_xseg(input_path, model_path):
                        weights_file_root=model_path,
                        data_format=nn.data_format,
                        raise_on_no_model_files=True)
-    res = xseg.get_resolution()
+    xseg_res = xseg.get_resolution()
 
     images_paths = pathex.get_image_paths(input_path, return_Path_class=True)

@@ -42,15 +51,36 @@ def apply_xseg(input_path, model_path):
 
         img = cv2_imread(filepath).astype(np.float32) / 255.0
         h,w,c = img.shape
-        if w != res:
-            img = cv2.resize( img, (res,res), interpolation=cv2.INTER_CUBIC )
+
+        img_face_type = FaceType.fromString( dflimg.get_face_type() )
+        if face_type is not None and img_face_type != face_type:
+            # faceset crop differs from the XSeg model's face type:
+            # re-crop to the model's framing before segmenting
+            lmrks = dflimg.get_source_landmarks()
+
+            fmat = LandmarksProcessor.get_transform_mat(lmrks, w, face_type)
+            imat = LandmarksProcessor.get_transform_mat(lmrks, w, img_face_type)
+
+            g_p = LandmarksProcessor.transform_points (np.float32([(0,0),(w,0),(0,w)]), fmat, True)
+            g_p2 = LandmarksProcessor.transform_points (g_p, imat)
+
+            mat = cv2.getAffineTransform( g_p2, np.float32([(0,0),(w,0),(0,w)]) )
+
+            img = cv2.warpAffine(img, mat, (w, w), flags=cv2.INTER_LANCZOS4)
+            img = cv2.resize(img, (xseg_res, xseg_res), interpolation=cv2.INTER_LANCZOS4)
+        else:
+            if w != xseg_res:
+                img = cv2.resize( img, (xseg_res,xseg_res), interpolation=cv2.INTER_LANCZOS4 )
 
         if len(img.shape) == 2:
             img = img[...,None]
 
         mask = xseg.extract(img)
+
+        if face_type is not None and img_face_type != face_type:
+            # map the mask back into the faceset's own crop
+            mask = cv2.resize(mask, (w, w), interpolation=cv2.INTER_LANCZOS4)
+            mask = cv2.warpAffine( mask, mat, (w,w), np.zeros( (h,w,c), dtype=np.float32), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4)
+            mask = cv2.resize(mask, (xseg_res, xseg_res), interpolation=cv2.INTER_LANCZOS4)
         mask[mask < 0.5]=0
         mask[mask >= 0.5]=1
 
         dflimg.set_xseg_mask(mask)
         dflimg.save()
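The remap above composes two crops through image space: `fmat` maps the full image into the XSeg model's face-type crop, `imat` maps it into the faceset's own crop. Sending three corner points of the model's crop back to image coordinates (`transform_points(..., fmat, True)`) and then into the faceset's crop (via `imat`) yields matching point triples, and `cv2.getAffineTransform` fits the single affine between the two crops; the same `mat` is later reused with `cv2.WARP_INVERSE_MAP` to bring the predicted mask back. A hedged standalone miniature of the point-pair idea (identity transforms stand in for fmat/imat):

    import cv2
    import numpy as np

    w = 256
    corners = np.float32([(0, 0), (w, 0), (0, w)])   # three reference points, as above
    shifted = corners + np.float32([10, 20])          # pretend crop B is offset from crop A
    mat = cv2.getAffineTransform(shifted, corners)    # warps crop B framing onto crop A
    assert np.allclose(mat[:, 2], [-10, -20])         # pure translation recovered
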
@@ -67,7 +97,8 @@ def fetch_xseg(input_path):
 
     images_paths = pathex.get_image_paths(input_path, return_Path_class=True)
 
-    files_copied = 0
+    files_copied = []
     for filepath in io.progress_bar_generator(images_paths, "Processing"):
         dflimg = DFLIMG.load(filepath)
         if dflimg is None or not dflimg.has_data():

@@ -77,10 +108,16 @@ def fetch_xseg(input_path):
         ie_polys = dflimg.get_seg_ie_polys()
 
         if ie_polys.has_polys():
-            files_copied += 1
+            files_copied.append(filepath)
             shutil.copy ( str(filepath), str(output_path / filepath.name) )
 
-    io.log_info(f'Files copied: {files_copied}')
+    io.log_info(f'Files copied: {len(files_copied)}')
+
+    is_delete = io.input_bool (f"\r\nDelete original files?", True)
+    if is_delete:
+        for filepath in files_copied:
+            Path(filepath).unlink()
 
 
 def remove_xseg(input_path):
     if not input_path.exists():
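Switching files_copied from a counter to a list is what enables the new optional cleanup pass: the very paths that were copied can then be unlinked. In miniature (names generic; `wanted` stands in for ie_polys.has_polys()):

    import shutil
    from pathlib import Path

    def fetch(paths, dst_dir, wanted, delete_originals=False):
        copied = []
        for p in paths:                      # p: pathlib.Path
            if wanted(p):
                copied.append(p)
                shutil.copy(str(p), str(dst_dir / p.name))
        print(f'Files copied: {len(copied)}')
        if delete_originals:
            for p in copied:
                p.unlink()                   # only files that were actually copied
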
@@ -185,6 +185,8 @@ class ModelBase(object):
             self.write_preview_history = self.options.get('write_preview_history', False)
             self.target_iter = self.options.get('target_iter',0)
             self.random_flip = self.options.get('random_flip',True)
+            self.random_src_flip = self.options.get('random_src_flip', False)
+            self.random_dst_flip = self.options.get('random_dst_flip', True)
 
         self.on_initialize()
         self.options['batch_size'] = self.batch_size

@@ -298,6 +300,14 @@ class ModelBase(object):
         default_random_flip = self.load_or_def_option('random_flip', True)
         self.options['random_flip'] = io.input_bool("Flip faces randomly", default_random_flip, help_message="The predicted face will look more natural without this option, but the src faceset should then cover all the face directions the dst faceset does.")
 
+    def ask_random_src_flip(self):
+        default_random_src_flip = self.load_or_def_option('random_src_flip', False)
+        self.options['random_src_flip'] = io.input_bool("Flip SRC faces randomly", default_random_src_flip, help_message="Random horizontal flip of the SRC faceset. Covers more angles, but the face may look less natural.")
+
+    def ask_random_dst_flip(self):
+        default_random_dst_flip = self.load_or_def_option('random_dst_flip', True)
+        self.options['random_dst_flip'] = io.input_bool("Flip DST faces randomly", default_random_dst_flip, help_message="Random horizontal flip of the DST faceset. Makes generalization of src->dst better, if src random flip is not enabled.")
+
     def ask_batch_size(self, suggest_batch_size=None, range=None):
         default_batch_size = self.load_or_def_option('batch_size', suggest_batch_size or self.batch_size)
@@ -65,7 +65,8 @@ class SAEHDModel(ModelBase):
         self.ask_autobackup_hour()
         self.ask_write_preview_history()
         self.ask_target_iter()
-        self.ask_random_flip()
+        self.ask_random_src_flip()
+        self.ask_random_dst_flip()
         self.ask_batch_size(suggest_batch_size)
 
         if self.is_first_run():

@@ -169,7 +170,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
         self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best.")
         self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces the chance of model collapse, sacrificing training speed.")
 
-        self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with a large amount of various faces. After that, the model can be used to train fakes more quickly.")
+        self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with a large amount of various faces. After that, the model can be used to train fakes more quickly. Forces random_warp=N, random_flips=Y, gan_power=0.0, lr_dropout=N, styles=0.0, uniform_yaw=Y")
 
         if self.options['pretrain'] and self.get_pretraining_data_path() is None:
             raise Exception("pretraining_data_path is not defined")

@@ -205,6 +206,8 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
         elif len(archi_split) == 1:
             archi_type, archi_opts = archi_split[0], None
 
+        self.archi_type = archi_type
+
         ae_dims = self.options['ae_dims']
         e_dims = self.options['e_dims']
         d_dims = self.options['d_dims']

@@ -217,6 +220,8 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
 
         self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power']
         random_warp = False if self.pretrain else self.options['random_warp']
+        random_src_flip = self.random_src_flip if not self.pretrain else True
+        random_dst_flip = self.random_dst_flip if not self.pretrain else True
 
         if self.pretrain:
             self.options_show_override['gan_power'] = 0.0
@@ -236,22 +241,22 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
         optimizer_vars_on_cpu = models_opt_device=='/CPU:0'
 
         input_ch=3
-        bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
+        bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
         mask_shape = nn.get4Dshape(resolution,resolution,1)
         self.model_filename_list = []
 
         with tf.device ('/CPU:0'):
             #Place holders on CPU
-            self.warped_src = tf.placeholder (nn.floatx, bgr_shape)
-            self.warped_dst = tf.placeholder (nn.floatx, bgr_shape)
+            self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src')
+            self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst')
 
-            self.target_src = tf.placeholder (nn.floatx, bgr_shape)
-            self.target_dst = tf.placeholder (nn.floatx, bgr_shape)
+            self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src')
+            self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst')
 
-            self.target_srcm = tf.placeholder (nn.floatx, mask_shape)
-            self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape)
-            self.target_dstm = tf.placeholder (nn.floatx, mask_shape)
-            self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape)
+            self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm')
+            self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em')
+            self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm')
+            self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em')
 
         # Initializing model classes
         model_archi = nn.DeepFakeArchi(resolution, opts=archi_opts)
@@ -610,6 +615,9 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
             if do_init:
                 model.init_weights()
 
         ###############
 
         # initializing sample generators
         if self.is_training:
             training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()

@@ -625,7 +633,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
 
         self.set_training_data_generators ([
                 SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
-                    sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
+                    sample_process_options=SampleProcessor.Options(random_flip=random_src_flip),
                     output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                             {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                             {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},

@@ -635,7 +643,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                     generators_count=src_generators_count ),
 
                 SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
-                    sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
+                    sample_process_options=SampleProcessor.Options(random_flip=random_dst_flip),
                     output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                             {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                             {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
@@ -651,6 +659,43 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
             if self.pretrain_just_disabled:
                 self.update_sample_for_preview(force_new=True)
 
+    def dump_ckpt(self):
+        tf = nn.tf
+
+        with tf.device ('/CPU:0'):
+            warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
+            warped_dst = tf.transpose(warped_dst, (0,3,1,2))
+
+            if 'df' in self.archi_type:
+                gpu_dst_code = self.inter(self.encoder(warped_dst))
+                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
+                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
+
+            elif 'liae' in self.archi_type:
+                gpu_dst_code = self.encoder (warped_dst)
+                gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
+                gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
+                gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
+                gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
+
+                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
+                _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
+
+            gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1))
+            gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1))
+            gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1))
+
+        saver = tf.train.Saver()
+        tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
+        tf.identity(gpu_pred_src_dst, name='out_celeb_face')
+        tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')
+
+        saver.save(nn.tf_sess, self.get_strpath_storage_for_file('.ckpt') )
+
     #override
     def get_model_filename_list(self):
         return self.model_filename_list
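For consumers of the dumped checkpoint: the exported graph takes 'in_face' (NHWC) and exposes 'out_face_mask', 'out_celeb_face' and 'out_celeb_face_mask'. A hedged TF1-style loading sketch; only the tensor names come from dump_ckpt() above, while the file name, the 256px resolution and the 0..1 float normalization are assumptions:

    import numpy as np
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()

    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('SAEHD.ckpt.meta')   # path assumed
        saver.restore(sess, 'SAEHD.ckpt')
        g = tf.get_default_graph()
        face = np.zeros((1, 256, 256, 3), np.float32)           # dummy frame at model res
        swapped, mask = sess.run(
            [g.get_tensor_by_name('out_celeb_face:0'),
             g.get_tensor_by_name('out_celeb_face_mask:0')],
            feed_dict={g.get_tensor_by_name('in_face:0'): face})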