mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-19 13:09:56 -07:00
add src-src merge and use stored aligned img mat for merging
This commit is contained in:
parent
075e8aaf82
commit
5e4e0a4d23
9 changed files with 200 additions and 61 deletions
4
main.py
4
main.py
|
@ -164,7 +164,8 @@ if __name__ == "__main__":
|
|||
output_mask_path = Path(arguments.output_mask_dir),
|
||||
aligned_path = Path(arguments.aligned_dir) if arguments.aligned_dir is not None else None,
|
||||
force_gpu_idxs = arguments.force_gpu_idxs,
|
||||
cpu_only = arguments.cpu_only)
|
||||
cpu_only = arguments.cpu_only,
|
||||
src_src = arguments.src_src)
|
||||
|
||||
p = subparsers.add_parser( "merge", help="Merger")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
|
||||
|
@ -176,6 +177,7 @@ if __name__ == "__main__":
|
|||
p.add_argument('--force-model-name', dest="force_model_name", default=None, help="Forcing to choose model name from model/ folder.")
|
||||
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Merge on CPU.")
|
||||
p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
|
||||
p.add_argument('--src-src', action="store_true", dest="src_src", default=False, help="Enables special src-src predicted output.")
|
||||
p.set_defaults(func=process_merge)
|
||||
|
||||
videoed_parser = subparsers.add_parser( "videoed", help="Video processing.").add_subparsers()
|
||||
|
|
|
@ -26,7 +26,8 @@ def main (model_class_name=None,
|
|||
output_mask_path=None,
|
||||
aligned_path=None,
|
||||
force_gpu_idxs=None,
|
||||
cpu_only=None):
|
||||
cpu_only=None,
|
||||
src_src=False):
|
||||
io.log_info ("Running merger.\r\n")
|
||||
|
||||
try:
|
||||
|
@ -47,6 +48,7 @@ def main (model_class_name=None,
|
|||
# Initialize model
|
||||
import models
|
||||
model = models.import_model(model_class_name)(is_training=False,
|
||||
src_src=src_src,
|
||||
saved_models_path=saved_models_path,
|
||||
force_gpu_idxs=force_gpu_idxs,
|
||||
cpu_only=cpu_only)
|
||||
|
@ -68,7 +70,10 @@ def main (model_class_name=None,
|
|||
place_model_on_cpu=True,
|
||||
run_on_cpu=run_on_cpu)
|
||||
|
||||
is_interactive = io.input_bool ("Use interactive merger?", True) if not io.is_colab() else False
|
||||
if src_src:
|
||||
is_interactive = False
|
||||
else:
|
||||
is_interactive = io.input_bool ("Use interactive merger?", True) if not io.is_colab() else False
|
||||
|
||||
if not is_interactive:
|
||||
cfg.ask_settings()
|
||||
|
@ -110,7 +115,7 @@ def main (model_class_name=None,
|
|||
io.log_err (f"{filepath.name} is not a dfl image file")
|
||||
continue
|
||||
|
||||
source_filename = dflimg.get_source_filename()
|
||||
source_filename = filepath.name if src_src else dflimg.get_source_filename()
|
||||
if source_filename is None:
|
||||
continue
|
||||
|
||||
|
@ -121,7 +126,7 @@ def main (model_class_name=None,
|
|||
alignments[ source_filename_stem ] = []
|
||||
|
||||
alignments_ar = alignments[ source_filename_stem ]
|
||||
alignments_ar.append ( (dflimg.get_source_landmarks(), filepath, source_filepath ) )
|
||||
alignments_ar.append ( (dflimg.get_source_landmarks(), filepath, source_filepath, dflimg.get_image_to_face_mat(), dflimg.get_shape()[0]) )
|
||||
|
||||
if len(alignments_ar) > 1:
|
||||
multiple_faces_detected = True
|
||||
|
@ -137,23 +142,32 @@ def main (model_class_name=None,
|
|||
for _, filepath, source_filepath in a_ar:
|
||||
io.log_info (f"alignment {filepath.name} refers to {source_filepath.name} ")
|
||||
io.log_info ("")
|
||||
|
||||
alignments[a_key] = [ a[0] for a in a_ar]
|
||||
alignments[a_key] = [ (a[0], a[3], a[4]) for a in a_ar]
|
||||
|
||||
if multiple_faces_detected:
|
||||
io.log_info ("It is strongly recommended to process the faces separatelly.")
|
||||
io.log_info ("Use 'recover original filename' to determine the exact duplicates.")
|
||||
io.log_info ("")
|
||||
|
||||
frames = [ InteractiveMergerSubprocessor.Frame( frame_info=FrameInfo(filepath=Path(p),
|
||||
landmarks_list=alignments.get(Path(p).stem, None)
|
||||
)
|
||||
)
|
||||
for p in input_path_image_paths ]
|
||||
frames = []
|
||||
for p in input_path_image_paths:
|
||||
alignment = alignments.get(Path(p).stem, None)
|
||||
landmarks_list = None
|
||||
image_to_face_mat = None
|
||||
aligned_size = None
|
||||
if alignment is not None:
|
||||
landmarks_list, image_to_face_mat, aligned_size = alignment[0]
|
||||
landmarks_list = [landmarks_list]
|
||||
frame_info = FrameInfo(filepath=Path(p), landmarks_list=landmarks_list, image_to_face_mat=image_to_face_mat, aligned_size=aligned_size)
|
||||
frame = InteractiveMergerSubprocessor.Frame(frame_info=frame_info)
|
||||
frames.append(frame)
|
||||
|
||||
if multiple_faces_detected:
|
||||
io.log_info ("Warning: multiple faces detected. Motion blur will not be used.")
|
||||
io.log_info ("")
|
||||
elif src_src:
|
||||
io.log_info ("SRC-SRC mode configured, skipping motion blur calculation...")
|
||||
io.log_info ("")
|
||||
else:
|
||||
s = 256
|
||||
local_pts = [ (s//2-1, s//2-1), (s//2-1,0) ] #center+up
|
||||
|
@ -206,6 +220,7 @@ def main (model_class_name=None,
|
|||
output_mask_path = output_mask_path,
|
||||
model_iter = model.get_iter(),
|
||||
subprocess_count = subprocess_count,
|
||||
src_src = src_src,
|
||||
).run()
|
||||
|
||||
model.finalize()
|
||||
|
|
|
@ -69,6 +69,7 @@ def trainerThread (s2c, c2s, e,
|
|||
start_time = time.time()
|
||||
|
||||
save_interval_min = 15
|
||||
tensorboard_preview_interval_min = 5
|
||||
|
||||
if not training_data_src_path.exists():
|
||||
training_data_src_path.mkdir(exist_ok=True, parents=True)
|
||||
|
@ -96,20 +97,14 @@ def trainerThread (s2c, c2s, e,
|
|||
|
||||
is_reached_goal = model.is_reached_iter_goal()
|
||||
|
||||
train_summary_writer = None
|
||||
if tensorboard_dir is not None:
|
||||
try:
|
||||
import tensorboardX
|
||||
if not os.path.exists(tensorboard_dir):
|
||||
os.makedirs(tensorboard_dir)
|
||||
summary_writer_folder = os.path.join(tensorboard_dir, model.model_name)
|
||||
train_summary_writer = tensorboardX.SummaryWriter(summary_writer_folder)
|
||||
if start_tensorboard:
|
||||
tb_tool = TensorBoardTool(tensorboard_dir)
|
||||
tb_tool.run()
|
||||
except:
|
||||
print("Error importing tensorboardX, please ensure it is installed (pip install tensorboardX)")
|
||||
print("Continuing training without tensorboard logging...")
|
||||
c2s.put({
|
||||
'op': 'tb',
|
||||
'action': 'init',
|
||||
'model_name': model.model_name,
|
||||
'tensorboard_dir': tensorboard_dir,
|
||||
'start_tensorboard': start_tensorboard
|
||||
})
|
||||
|
||||
shared_state = { 'after_save' : False }
|
||||
loss_string = ""
|
||||
|
@ -122,12 +117,29 @@ def trainerThread (s2c, c2s, e,
|
|||
|
||||
def model_backup():
|
||||
if not debug and not is_reached_goal:
|
||||
model.create_backup()
|
||||
model.create_backup()
|
||||
|
||||
def log_step(step, step_time, src_loss, dst_loss):
|
||||
c2s.put({
|
||||
'op': 'tb',
|
||||
'action': 'step',
|
||||
'step': step,
|
||||
'step_time': step_time,
|
||||
'src_loss': src_loss,
|
||||
'dst_loss': dst_loss
|
||||
})
|
||||
|
||||
def log_previews(step, previews, static_previews):
|
||||
c2s.put({
|
||||
'op': 'tb',
|
||||
'action': 'preview',
|
||||
'step': step,
|
||||
'previews': previews,
|
||||
'static_previews': static_previews
|
||||
})
|
||||
|
||||
def send_preview():
|
||||
if not debug:
|
||||
if train_summary_writer is not None:
|
||||
log_tensorboard_model_previews(iter, model, train_summary_writer)
|
||||
previews = model.get_previews()
|
||||
c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } )
|
||||
else:
|
||||
|
@ -144,6 +156,7 @@ def trainerThread (s2c, c2s, e,
|
|||
io.log_info('Starting. Press "Enter" to stop training and save model.')
|
||||
|
||||
last_save_time = time.time()
|
||||
last_preview_time = time.time()
|
||||
|
||||
execute_programs = [ [x[0], x[1], time.time() ] for x in execute_programs ]
|
||||
|
||||
|
@ -203,13 +216,8 @@ def trainerThread (s2c, c2s, e,
|
|||
else:
|
||||
io.log_info (loss_string, end='\r')
|
||||
|
||||
if train_summary_writer is not None:
|
||||
# report iteration time summary
|
||||
train_summary_writer.add_scalar('iteration time', iter_time, iter)
|
||||
# report loss summary
|
||||
src_loss, dst_loss = loss_history[-1]
|
||||
train_summary_writer.add_scalar('loss/src', src_loss, iter)
|
||||
train_summary_writer.add_scalar('loss/dst', dst_loss, iter)
|
||||
loss_entry = loss_history[-1]
|
||||
log_step(iter, iter_time, loss_entry[0], loss_entry[1] if len(loss_entry) > 1 else None)
|
||||
|
||||
if model.get_iter() == 1:
|
||||
model_save()
|
||||
|
@ -220,6 +228,12 @@ def trainerThread (s2c, c2s, e,
|
|||
is_reached_goal = True
|
||||
io.log_info ('You can use preview now.')
|
||||
|
||||
if not is_reached_goal and (time.time() - last_preview_time) >= tensorboard_preview_interval_min*60:
|
||||
last_preview_time += tensorboard_preview_interval_min*60
|
||||
previews = model.get_previews()
|
||||
static_previews = model.get_static_previews()
|
||||
log_previews(iter, previews, static_previews)
|
||||
|
||||
if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
|
||||
last_save_time += save_interval_min*60
|
||||
model_save()
|
||||
|
@ -262,7 +276,54 @@ def trainerThread (s2c, c2s, e,
|
|||
break
|
||||
c2s.put ( {'op':'close'} )
|
||||
|
||||
_train_summary_writer = None
|
||||
def init_writer(model_name, tensorboard_dir, start_tensorboard):
|
||||
import tensorboardX
|
||||
global _train_summary_writer
|
||||
|
||||
if not os.path.exists(tensorboard_dir):
|
||||
os.makedirs(tensorboard_dir)
|
||||
summary_writer_folder = os.path.join(tensorboard_dir, model_name)
|
||||
_train_summary_writer = tensorboardX.SummaryWriter(summary_writer_folder)
|
||||
|
||||
if start_tensorboard:
|
||||
tb_tool = TensorBoardTool(tensorboard_dir)
|
||||
tb_tool.run()
|
||||
|
||||
return _train_summary_writer
|
||||
|
||||
def get_writer():
|
||||
global _train_summary_writer
|
||||
return _train_summary_writer
|
||||
|
||||
def handle_tensorboard_op(input):
|
||||
train_summary_writer = get_writer()
|
||||
action = input['action']
|
||||
if action == 'init':
|
||||
model_name = input['model_name']
|
||||
tensorboard_dir = input['tensorboard_dir']
|
||||
start_tensorboard = input['start_tensorboard']
|
||||
train_summary_writer = init_writer(model_name, tensorboard_dir, start_tensorboard)
|
||||
if train_summary_writer is not None:
|
||||
if action == 'step':
|
||||
step = input['step']
|
||||
step_time = input['step_time']
|
||||
src_loss = input['src_loss']
|
||||
dst_loss = input['dst_loss']
|
||||
# report iteration time summary
|
||||
train_summary_writer.add_scalar('iteration time', step_time, step)
|
||||
# report loss summary
|
||||
train_summary_writer.add_scalar('loss/src', src_loss, step)
|
||||
if dst_loss is not None:
|
||||
train_summary_writer.add_scalar('loss/dst', dst_loss, step)
|
||||
elif action == 'preview':
|
||||
step = input['step']
|
||||
previews = input['previews']
|
||||
static_previews = input['static_previews']
|
||||
if previews is not None:
|
||||
log_tensorboard_previews(step, previews, 'preview', train_summary_writer)
|
||||
if static_previews is not None:
|
||||
log_tensorboard_previews(step, static_previews, 'static_preview', train_summary_writer)
|
||||
|
||||
def main(**kwargs):
|
||||
io.log_info ("Running trainer.\r\n")
|
||||
|
@ -283,7 +344,9 @@ def main(**kwargs):
|
|||
if not c2s.empty():
|
||||
input = c2s.get()
|
||||
op = input.get('op','')
|
||||
if op == 'close':
|
||||
if op == 'tb':
|
||||
handle_tensorboard_op(input)
|
||||
elif op == 'close':
|
||||
break
|
||||
try:
|
||||
io.process_messages(0.1)
|
||||
|
@ -333,6 +396,8 @@ def main(**kwargs):
|
|||
previews.append ( (preview_name, cv2.resize(preview_rgb, (max_w, max_h))) )
|
||||
selected_preview = selected_preview % len(previews)
|
||||
update_preview = True
|
||||
elif op == 'tb':
|
||||
handle_tensorboard_op(input)
|
||||
elif op == 'close':
|
||||
break
|
||||
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
from pathlib import Path
|
||||
|
||||
class FrameInfo(object):
|
||||
def __init__(self, filepath=None, landmarks_list=None):
|
||||
def __init__(self, filepath=None, landmarks_list=None, image_to_face_mat=None, aligned_size=512):
|
||||
self.filepath = filepath
|
||||
self.landmarks_list = landmarks_list or []
|
||||
self.motion_deg = 0
|
||||
self.motion_power = 0
|
||||
self.motion_power = 0
|
||||
self.image_to_face_mat = image_to_face_mat
|
||||
self.aligned_size = aligned_size
|
|
@ -84,15 +84,14 @@ class InteractiveMergerSubprocessor(Subprocessor):
|
|||
filepath = frame_info.filepath
|
||||
|
||||
if len(frame_info.landmarks_list) == 0:
|
||||
|
||||
if cfg.mode == 'raw-predict':
|
||||
self.log_info (f'no faces found for {filepath.name}, copying without faces')
|
||||
if cfg.mode == 'raw-predict':
|
||||
h,w,c = self.predictor_input_shape
|
||||
img_bgr = np.zeros( (h,w,3), dtype=np.uint8)
|
||||
img_mask = np.zeros( (h,w,1), dtype=np.uint8)
|
||||
else:
|
||||
self.log_info (f'no faces found for {filepath.name}, copying without faces')
|
||||
img_mask = np.zeros( (h,w,1), dtype=np.uint8)
|
||||
else:
|
||||
img_bgr = cv2_imread(filepath)
|
||||
imagelib.normalize_channels(img_bgr, 3)
|
||||
imagelib.normalize_channels(img_bgr, 3)
|
||||
h,w,c = img_bgr.shape
|
||||
img_mask = np.zeros( (h,w,1), dtype=img_bgr.dtype)
|
||||
|
||||
|
@ -140,7 +139,7 @@ class InteractiveMergerSubprocessor(Subprocessor):
|
|||
|
||||
|
||||
#override
|
||||
def __init__(self, is_interactive, merger_session_filepath, predictor_func, predictor_input_shape, face_enhancer_func, xseg_256_extract_func, merger_config, frames, frames_root_path, output_path, output_mask_path, model_iter, subprocess_count=4):
|
||||
def __init__(self, is_interactive, merger_session_filepath, predictor_func, predictor_input_shape, face_enhancer_func, xseg_256_extract_func, merger_config, frames, frames_root_path, output_path, output_mask_path, model_iter, subprocess_count=4, src_src=False):
|
||||
if len (frames) == 0:
|
||||
raise ValueError ("len (frames) == 0")
|
||||
|
||||
|
@ -163,6 +162,8 @@ class InteractiveMergerSubprocessor(Subprocessor):
|
|||
|
||||
self.prefetch_frame_count = self.process_count = subprocess_count
|
||||
|
||||
self.src_src = src_src
|
||||
|
||||
session_data = None
|
||||
if self.is_interactive and self.merger_session_filepath.exists():
|
||||
io.input_skip_pending()
|
||||
|
|
|
@ -12,6 +12,17 @@ from facelib import FaceType, LandmarksProcessor
|
|||
is_windows = sys.platform[0:3] == 'win'
|
||||
xseg_input_size = 256
|
||||
|
||||
def concat_matrix(first, second):
|
||||
mul1 = np.vstack([*first, [0, 0, 1]])
|
||||
mul2 = np.vstack([*second, [0, 0, 1]])
|
||||
mul_r = np.matmul(mul1, mul2)
|
||||
return np.delete(mul_r, (2), axis=0);
|
||||
|
||||
def get_identity_affine_mat():
|
||||
pt1 = np.float32([ [0, 0], [1, 0], [1, 1] ])
|
||||
pt2 = np.float32([ [0, 0], [1, 0], [1, 1] ])
|
||||
return cv2.getAffineTransform(pt1, pt2)
|
||||
|
||||
def MergeMaskedFace (predictor_func, predictor_input_shape,
|
||||
face_enhancer_func,
|
||||
xseg_256_extract_func,
|
||||
|
@ -26,15 +37,28 @@ def MergeMaskedFace (predictor_func, predictor_input_shape,
|
|||
if cfg.super_resolution_power != 0:
|
||||
output_size *= 4
|
||||
|
||||
face_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, output_size, face_type=cfg.face_type)
|
||||
face_output_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, output_size, face_type=cfg.face_type, scale= 1.0 + 0.01*cfg.output_face_scale)
|
||||
existing_mat = get_identity_affine_mat() if cfg.src_src else frame_info.image_to_face_mat
|
||||
aligned_size = frame_info.aligned_size
|
||||
|
||||
if existing_mat is not None:
|
||||
face_scale_mat = cv2.getRotationMatrix2D((0,0), 0, output_size/aligned_size)
|
||||
face_mat = concat_matrix(face_scale_mat, existing_mat)
|
||||
output_scale_mat = cv2.getRotationMatrix2D((output_size/2, output_size/2), 0, 1.0 + 0.01 * cfg.output_face_scale)
|
||||
face_output_mat = concat_matrix(output_scale_mat, face_mat)
|
||||
else:
|
||||
face_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, output_size, face_type=cfg.face_type)
|
||||
face_output_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, output_size, face_type=cfg.face_type, scale= 1.0 + 0.01*cfg.output_face_scale)
|
||||
|
||||
if mask_subres_size == output_size:
|
||||
face_mask_output_mat = face_output_mat
|
||||
else:
|
||||
face_mask_output_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, mask_subres_size, face_type=cfg.face_type, scale= 1.0 + 0.01*cfg.output_face_scale)
|
||||
if existing_mat is not None:
|
||||
mask_output_scale_mat = cv2.getRotationMatrix2D((0, 0), 0, mask_subres_size/output_size)
|
||||
face_mask_output_mat = concat_matrix(mask_output_scale_mat, face_mat)
|
||||
else:
|
||||
face_mask_output_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, mask_subres_size, face_type=cfg.face_type, scale= 1.0 + 0.01*cfg.output_face_scale)
|
||||
|
||||
dst_face_bgr = cv2.warpAffine( img_bgr , face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC )
|
||||
dst_face_bgr = cv2.warpAffine( img_bgr, face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC )
|
||||
dst_face_bgr = np.clip(dst_face_bgr, 0, 1)
|
||||
|
||||
dst_face_mask_a_0 = cv2.warpAffine( img_face_mask_a, face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC )
|
||||
|
@ -76,7 +100,11 @@ def MergeMaskedFace (predictor_func, predictor_input_shape,
|
|||
|
||||
if cfg.mask_mode >= 7 and cfg.mask_mode <= 9:
|
||||
# obtain XSeg-dst
|
||||
xseg_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, xseg_input_size, face_type=cfg.face_type)
|
||||
if existing_mat is not None:
|
||||
xseg_scale_mat = cv2.getRotationMatrix2D((0,0), 0, xseg_input_size/aligned_size)
|
||||
xseg_mat = concat_matrix(xseg_scale_mat, existing_mat)
|
||||
else:
|
||||
xseg_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, xseg_input_size, face_type=cfg.face_type)
|
||||
dst_face_xseg_bgr = cv2.warpAffine(img_bgr, xseg_mat, (xseg_input_size,)*2, flags=cv2.INTER_CUBIC )
|
||||
dst_face_xseg_mask = xseg_256_extract_func(dst_face_xseg_bgr)
|
||||
X_dst_face_mask_a_0 = cv2.resize (dst_face_xseg_mask, (output_size,output_size), interpolation=cv2.INTER_CUBIC)
|
||||
|
|
|
@ -113,6 +113,7 @@ class MergerConfigMasked(MergerConfig):
|
|||
image_denoise_power = 0,
|
||||
bicubic_degrade_power = 0,
|
||||
color_degrade_power = 0,
|
||||
src_src = False,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
|
@ -122,6 +123,8 @@ class MergerConfigMasked(MergerConfig):
|
|||
if self.face_type not in [FaceType.HALF, FaceType.MID_FULL, FaceType.FULL, FaceType.WHOLE_FACE, FaceType.HEAD ]:
|
||||
raise ValueError("MergerConfigMasked does not support this type of face.")
|
||||
|
||||
self.src_src = src_src
|
||||
|
||||
self.default_mode = default_mode
|
||||
|
||||
#default changeable params
|
||||
|
@ -188,13 +191,13 @@ class MergerConfigMasked(MergerConfig):
|
|||
self.bicubic_degrade_power = np.clip ( self.bicubic_degrade_power+diff, 0, 100)
|
||||
|
||||
def ask_settings(self):
|
||||
s = """Choose mode: \n"""
|
||||
for key in mode_dict.keys():
|
||||
s += f"""({key}) {mode_dict[key]}\n"""
|
||||
io.log_info(s)
|
||||
mode = io.input_int ("", mode_str_dict.get(self.default_mode, 1) )
|
||||
|
||||
self.mode = mode_dict.get (mode, self.default_mode )
|
||||
if 'raw-predict' not in self.mode:
|
||||
s = """Choose mode: \n"""
|
||||
for key in mode_dict.keys():
|
||||
s += f"""({key}) {mode_dict[key]}\n"""
|
||||
io.log_info(s)
|
||||
mode = io.input_int ("", mode_str_dict.get(self.default_mode, 1) )
|
||||
self.mode = mode_dict.get (mode, self.default_mode)
|
||||
|
||||
if 'raw' not in self.mode:
|
||||
if self.mode == 'hist-match':
|
||||
|
|
|
@ -22,6 +22,7 @@ from samplelib import SampleGeneratorBase
|
|||
|
||||
class ModelBase(object):
|
||||
def __init__(self, is_training=False,
|
||||
src_src=False,
|
||||
saved_models_path=None,
|
||||
training_data_src_path=None,
|
||||
training_data_dst_path=None,
|
||||
|
@ -36,6 +37,7 @@ class ModelBase(object):
|
|||
silent_start=False,
|
||||
**kwargs):
|
||||
self.is_training = is_training
|
||||
self.src_src = src_src
|
||||
self.saved_models_path = saved_models_path
|
||||
self.training_data_src_path = training_data_src_path
|
||||
self.training_data_dst_path = training_data_dst_path
|
||||
|
|
|
@ -601,20 +601,29 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
_, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
|
||||
|
||||
elif 'liae' in archi_type:
|
||||
gpu_src_code = self.encoder (self.warped_src)
|
||||
gpu_src_inter_AB_code = self.inter_AB (gpu_src_code)
|
||||
gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis )
|
||||
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
|
||||
|
||||
gpu_dst_code = self.encoder (self.warped_dst)
|
||||
gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
|
||||
gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
|
||||
gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
|
||||
gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
|
||||
|
||||
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
|
||||
gpu_pred_src_dst, gpu_pred_src_dst_dstm = self.decoder(gpu_src_dst_code)
|
||||
_, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
|
||||
|
||||
|
||||
def AE_merge( warped_dst):
|
||||
return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst})
|
||||
def AE_merge(warped_dst):
|
||||
return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dst_dstm], feed_dict={self.warped_dst:warped_dst})
|
||||
|
||||
def AE_src(warped_src):
|
||||
return nn.tf_sess.run( [gpu_pred_src_src, gpu_pred_src_srcm], feed_dict={self.warped_src:warped_src})
|
||||
|
||||
self.AE_merge = AE_merge
|
||||
self.AE_src = AE_src
|
||||
|
||||
# Loading/initializing all models/optimizers weights
|
||||
for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
|
||||
|
@ -813,9 +822,21 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
|
||||
return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0]
|
||||
|
||||
def src_predictor_func (self, face=None):
|
||||
face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC")
|
||||
bgr, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_src (face) ]
|
||||
return bgr[0], mask_src_dstm[0][...,0], mask_src_dstm[0][...,0]
|
||||
|
||||
#override
|
||||
def get_MergerConfig(self):
|
||||
import merger
|
||||
return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay')
|
||||
if self.src_src:
|
||||
merger_config = merger.MergerConfigMasked(face_type=self.face_type,
|
||||
default_mode='raw-predict',
|
||||
mode='raw-predict',
|
||||
src_src=True)
|
||||
return self.src_predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger_config
|
||||
else:
|
||||
return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay')
|
||||
|
||||
Model = SAEHDModel
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue