From 50759b346236df2e0253ddc967b16a17daef1021 Mon Sep 17 00:00:00 2001 From: Vincent Riemer Date: Thu, 12 Nov 2020 15:20:53 -0800 Subject: [PATCH 1/3] remove special dos characters --- localization/localization.py | 2 +- mainscripts/Extractor.py | 2 +- mainscripts/Merger.py | 2 +- mainscripts/Sorter.py | 2 +- mainscripts/Trainer.py | 2 +- mainscripts/Util.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/localization/localization.py b/localization/localization.py index 3df7bbd..63c5d56 100644 --- a/localization/localization.py +++ b/localization/localization.py @@ -1,4 +1,4 @@ -import sys +import sys import locale system_locale = locale.getdefaultlocale()[0] diff --git a/mainscripts/Extractor.py b/mainscripts/Extractor.py index 1d5c80f..e1a1fa7 100644 --- a/mainscripts/Extractor.py +++ b/mainscripts/Extractor.py @@ -1,4 +1,4 @@ -import traceback +import traceback import math import multiprocessing import operator diff --git a/mainscripts/Merger.py b/mainscripts/Merger.py index fba37f1..57d09fb 100644 --- a/mainscripts/Merger.py +++ b/mainscripts/Merger.py @@ -1,4 +1,4 @@ -import math +import math import multiprocessing import traceback from pathlib import Path diff --git a/mainscripts/Sorter.py b/mainscripts/Sorter.py index 9b09e43..aa4151e 100644 --- a/mainscripts/Sorter.py +++ b/mainscripts/Sorter.py @@ -1,4 +1,4 @@ -import math +import math import multiprocessing import operator import os diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py index 07d1307..7facd29 100644 --- a/mainscripts/Trainer.py +++ b/mainscripts/Trainer.py @@ -1,4 +1,4 @@ -import sys +import sys import traceback import queue import threading diff --git a/mainscripts/Util.py b/mainscripts/Util.py index 4a51e53..513759c 100644 --- a/mainscripts/Util.py +++ b/mainscripts/Util.py @@ -1,4 +1,4 @@ -import pickle +import pickle from pathlib import Path import cv2 From 075e8aaf829f4775c9aa43f4fa2fb0b4bd63d859 Mon Sep 17 00:00:00 2001 From: Vincent Riemer Date: Mon, 23 Nov 2020 01:25:32 -0800 Subject: [PATCH 2/3] Add optional tensorboard logging --- main.py | 4 +++ mainscripts/Trainer.py | 62 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/main.py b/main.py index fb64f96..8e22b89 100644 --- a/main.py +++ b/main.py @@ -127,6 +127,8 @@ if __name__ == "__main__": 'silent_start' : arguments.silent_start, 'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ], 'debug' : arguments.debug, + 'tensorboard_dir' : arguments.tensorboard_dir, + 'start_tensorboard' : arguments.start_tensorboard } from mainscripts import Trainer Trainer.main(**kwargs) @@ -144,6 +146,8 @@ if __name__ == "__main__": p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.") p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.") p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. 
Automatically chooses Best GPU and last used model.") + p.add_argument('--tensorboard-logdir', action=fixPathAction, dest="tensorboard_dir", help="Directory of the tensorboard output files") + p.add_argument('--start-tensorboard', action="store_true", dest="start_tensorboard", default=False, help="Automatically start the tensorboard server preconfigured to the tensorboard-logdir") p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+') diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py index 7facd29..00b030d 100644 --- a/mainscripts/Trainer.py +++ b/mainscripts/Trainer.py @@ -11,6 +11,41 @@ from core import imagelib import cv2 import models from core.interact import interact as io +import logging +import datetime +import os + +# adapted from https://stackoverflow.com/a/52295534 +class TensorBoardTool: + def __init__(self, dir_path): + self.dir_path = dir_path + def run(self): + from tensorboard import default + from tensorboard import program + # remove http messages + log = logging.getLogger('werkzeug').setLevel(logging.ERROR) + # Start tensorboard server + tb = program.TensorBoard(default.get_plugins()) + tb.configure(argv=[None, '--logdir', self.dir_path, '--port', '6006', '--bind_all']) + url = tb.launch() + print('Launched TensorBoard at {}'.format(url)) + +def process_img_for_tensorboard(input_img): + # convert format from bgr to rgb + img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB) + # adjust axis to put channel count at the beginning + img = np.moveaxis(img, -1, 0) + return img + +def log_tensorboard_previews(iter, previews, folder_name, train_summary_writer): + for preview in previews: + (preview_name, preview_bgr) = preview + preview_rgb = process_img_for_tensorboard(preview_bgr) + train_summary_writer.add_image('{}/{}'.format(folder_name, preview_name), preview_rgb, iter) + +def log_tensorboard_model_previews(iter, model, train_summary_writer): + log_tensorboard_previews(iter, model.get_previews(), 'preview', train_summary_writer) + log_tensorboard_previews(iter, model.get_static_previews(), 'static_preview', train_summary_writer) def trainerThread (s2c, c2s, e, model_class_name = None, @@ -26,6 +61,8 @@ def trainerThread (s2c, c2s, e, silent_start=False, execute_programs = None, debug=False, + tensorboard_dir=None, + start_tensorboard=False, **kwargs): while True: try: @@ -59,6 +96,21 @@ def trainerThread (s2c, c2s, e, is_reached_goal = model.is_reached_iter_goal() + train_summary_writer = None + if tensorboard_dir is not None: + try: + import tensorboardX + if not os.path.exists(tensorboard_dir): + os.makedirs(tensorboard_dir) + summary_writer_folder = os.path.join(tensorboard_dir, model.model_name) + train_summary_writer = tensorboardX.SummaryWriter(summary_writer_folder) + if start_tensorboard: + tb_tool = TensorBoardTool(tensorboard_dir) + tb_tool.run() + except: + print("Error importing tensorboardX, please ensure it is installed (pip install tensorboardX)") + print("Continuing training without tensorboard logging...") + shared_state = { 'after_save' : False } loss_string = "" save_iter = model.get_iter() @@ -74,6 +126,8 @@ def trainerThread (s2c, c2s, e, def send_preview(): if not debug: + if train_summary_writer is not None: + log_tensorboard_model_previews(iter, model, train_summary_writer) previews = model.get_previews() c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } ) else: @@ -149,6 +203,14 @@ def trainerThread (s2c, c2s, e, else: 
io.log_info (loss_string, end='\r') + if train_summary_writer is not None: + # report iteration time summary + train_summary_writer.add_scalar('iteration time', iter_time, iter) + # report loss summary + src_loss, dst_loss = loss_history[-1] + train_summary_writer.add_scalar('loss/src', src_loss, iter) + train_summary_writer.add_scalar('loss/dst', dst_loss, iter) + if model.get_iter() == 1: model_save() From 5e4e0a4d23887941d14ed463f9da53d776917976 Mon Sep 17 00:00:00 2001 From: Vincent Riemer Date: Thu, 10 Dec 2020 09:38:33 -0800 Subject: [PATCH 3/3] add src-src merge and use stored aligned img mat for merging --- main.py | 4 +- mainscripts/Merger.py | 37 +++++--- mainscripts/Trainer.py | 113 +++++++++++++++++++----- merger/FrameInfo.py | 6 +- merger/InteractiveMergerSubprocessor.py | 15 ++-- merger/MergeMasked.py | 38 ++++++-- merger/MergerConfig.py | 17 ++-- models/ModelBase.py | 2 + models/Model_SAEHD/Model.py | 29 +++++- 9 files changed, 200 insertions(+), 61 deletions(-) diff --git a/main.py b/main.py index 8e22b89..3d07640 100644 --- a/main.py +++ b/main.py @@ -164,7 +164,8 @@ if __name__ == "__main__": output_mask_path = Path(arguments.output_mask_dir), aligned_path = Path(arguments.aligned_dir) if arguments.aligned_dir is not None else None, force_gpu_idxs = arguments.force_gpu_idxs, - cpu_only = arguments.cpu_only) + cpu_only = arguments.cpu_only, + src_src = arguments.src_src) p = subparsers.add_parser( "merge", help="Merger") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.") @@ -176,6 +177,7 @@ if __name__ == "__main__": p.add_argument('--force-model-name', dest="force_model_name", default=None, help="Forcing to choose model name from model/ folder.") p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Merge on CPU.") p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.") + p.add_argument('--src-src', action="store_true", dest="src_src", default=False, help="Enables special src-src predicted output.") p.set_defaults(func=process_merge) videoed_parser = subparsers.add_parser( "videoed", help="Video processing.").add_subparsers() diff --git a/mainscripts/Merger.py b/mainscripts/Merger.py index 57d09fb..cf5c5e6 100644 --- a/mainscripts/Merger.py +++ b/mainscripts/Merger.py @@ -26,7 +26,8 @@ def main (model_class_name=None, output_mask_path=None, aligned_path=None, force_gpu_idxs=None, - cpu_only=None): + cpu_only=None, + src_src=False): io.log_info ("Running merger.\r\n") try: @@ -47,6 +48,7 @@ def main (model_class_name=None, # Initialize model import models model = models.import_model(model_class_name)(is_training=False, + src_src=src_src, saved_models_path=saved_models_path, force_gpu_idxs=force_gpu_idxs, cpu_only=cpu_only) @@ -68,7 +70,10 @@ def main (model_class_name=None, place_model_on_cpu=True, run_on_cpu=run_on_cpu) - is_interactive = io.input_bool ("Use interactive merger?", True) if not io.is_colab() else False + if src_src: + is_interactive = False + else: + is_interactive = io.input_bool ("Use interactive merger?", True) if not io.is_colab() else False if not is_interactive: cfg.ask_settings() @@ -110,7 +115,7 @@ def main (model_class_name=None, io.log_err (f"{filepath.name} is not a dfl image file") continue - source_filename = dflimg.get_source_filename() + source_filename = filepath.name if src_src else dflimg.get_source_filename() if 
source_filename is None: continue @@ -121,7 +126,7 @@ def main (model_class_name=None, alignments[ source_filename_stem ] = [] alignments_ar = alignments[ source_filename_stem ] - alignments_ar.append ( (dflimg.get_source_landmarks(), filepath, source_filepath ) ) + alignments_ar.append ( (dflimg.get_source_landmarks(), filepath, source_filepath, dflimg.get_image_to_face_mat(), dflimg.get_shape()[0]) ) if len(alignments_ar) > 1: multiple_faces_detected = True @@ -137,23 +142,32 @@ def main (model_class_name=None, for _, filepath, source_filepath in a_ar: io.log_info (f"alignment {filepath.name} refers to {source_filepath.name} ") io.log_info ("") - - alignments[a_key] = [ a[0] for a in a_ar] + alignments[a_key] = [ (a[0], a[3], a[4]) for a in a_ar] if multiple_faces_detected: io.log_info ("It is strongly recommended to process the faces separatelly.") io.log_info ("Use 'recover original filename' to determine the exact duplicates.") io.log_info ("") - frames = [ InteractiveMergerSubprocessor.Frame( frame_info=FrameInfo(filepath=Path(p), - landmarks_list=alignments.get(Path(p).stem, None) - ) - ) - for p in input_path_image_paths ] + frames = [] + for p in input_path_image_paths: + alignment = alignments.get(Path(p).stem, None) + landmarks_list = None + image_to_face_mat = None + aligned_size = None + if alignment is not None: + landmarks_list, image_to_face_mat, aligned_size = alignment[0] + landmarks_list = [landmarks_list] + frame_info = FrameInfo(filepath=Path(p), landmarks_list=landmarks_list, image_to_face_mat=image_to_face_mat, aligned_size=aligned_size) + frame = InteractiveMergerSubprocessor.Frame(frame_info=frame_info) + frames.append(frame) if multiple_faces_detected: io.log_info ("Warning: multiple faces detected. Motion blur will not be used.") io.log_info ("") + elif src_src: + io.log_info ("SRC-SRC mode configured, skipping motion blur calculation...") + io.log_info ("") else: s = 256 local_pts = [ (s//2-1, s//2-1), (s//2-1,0) ] #center+up @@ -206,6 +220,7 @@ def main (model_class_name=None, output_mask_path = output_mask_path, model_iter = model.get_iter(), subprocess_count = subprocess_count, + src_src = src_src, ).run() model.finalize() diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py index 00b030d..cf0cb43 100644 --- a/mainscripts/Trainer.py +++ b/mainscripts/Trainer.py @@ -69,6 +69,7 @@ def trainerThread (s2c, c2s, e, start_time = time.time() save_interval_min = 15 + tensorboard_preview_interval_min = 5 if not training_data_src_path.exists(): training_data_src_path.mkdir(exist_ok=True, parents=True) @@ -96,20 +97,14 @@ def trainerThread (s2c, c2s, e, is_reached_goal = model.is_reached_iter_goal() - train_summary_writer = None if tensorboard_dir is not None: - try: - import tensorboardX - if not os.path.exists(tensorboard_dir): - os.makedirs(tensorboard_dir) - summary_writer_folder = os.path.join(tensorboard_dir, model.model_name) - train_summary_writer = tensorboardX.SummaryWriter(summary_writer_folder) - if start_tensorboard: - tb_tool = TensorBoardTool(tensorboard_dir) - tb_tool.run() - except: - print("Error importing tensorboardX, please ensure it is installed (pip install tensorboardX)") - print("Continuing training without tensorboard logging...") + c2s.put({ + 'op': 'tb', + 'action': 'init', + 'model_name': model.model_name, + 'tensorboard_dir': tensorboard_dir, + 'start_tensorboard': start_tensorboard + }) shared_state = { 'after_save' : False } loss_string = "" @@ -122,12 +117,29 @@ def trainerThread (s2c, c2s, e, def model_backup(): if not debug and 
not is_reached_goal: - model.create_backup() + model.create_backup() + + def log_step(step, step_time, src_loss, dst_loss): + c2s.put({ + 'op': 'tb', + 'action': 'step', + 'step': step, + 'step_time': step_time, + 'src_loss': src_loss, + 'dst_loss': dst_loss + }) + + def log_previews(step, previews, static_previews): + c2s.put({ + 'op': 'tb', + 'action': 'preview', + 'step': step, + 'previews': previews, + 'static_previews': static_previews + }) def send_preview(): if not debug: - if train_summary_writer is not None: - log_tensorboard_model_previews(iter, model, train_summary_writer) previews = model.get_previews() c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } ) else: @@ -144,6 +156,7 @@ def trainerThread (s2c, c2s, e, io.log_info('Starting. Press "Enter" to stop training and save model.') last_save_time = time.time() + last_preview_time = time.time() execute_programs = [ [x[0], x[1], time.time() ] for x in execute_programs ] @@ -203,13 +216,8 @@ def trainerThread (s2c, c2s, e, else: io.log_info (loss_string, end='\r') - if train_summary_writer is not None: - # report iteration time summary - train_summary_writer.add_scalar('iteration time', iter_time, iter) - # report loss summary - src_loss, dst_loss = loss_history[-1] - train_summary_writer.add_scalar('loss/src', src_loss, iter) - train_summary_writer.add_scalar('loss/dst', dst_loss, iter) + loss_entry = loss_history[-1] + log_step(iter, iter_time, loss_entry[0], loss_entry[1] if len(loss_entry) > 1 else None) if model.get_iter() == 1: model_save() @@ -220,6 +228,12 @@ def trainerThread (s2c, c2s, e, is_reached_goal = True io.log_info ('You can use preview now.') + if not is_reached_goal and (time.time() - last_preview_time) >= tensorboard_preview_interval_min*60: + last_preview_time += tensorboard_preview_interval_min*60 + previews = model.get_previews() + static_previews = model.get_static_previews() + log_previews(iter, previews, static_previews) + if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60: last_save_time += save_interval_min*60 model_save() @@ -262,7 +276,54 @@ def trainerThread (s2c, c2s, e, break c2s.put ( {'op':'close'} ) +_train_summary_writer = None +def init_writer(model_name, tensorboard_dir, start_tensorboard): + import tensorboardX + global _train_summary_writer + if not os.path.exists(tensorboard_dir): + os.makedirs(tensorboard_dir) + summary_writer_folder = os.path.join(tensorboard_dir, model_name) + _train_summary_writer = tensorboardX.SummaryWriter(summary_writer_folder) + + if start_tensorboard: + tb_tool = TensorBoardTool(tensorboard_dir) + tb_tool.run() + + return _train_summary_writer + +def get_writer(): + global _train_summary_writer + return _train_summary_writer + +def handle_tensorboard_op(input): + train_summary_writer = get_writer() + action = input['action'] + if action == 'init': + model_name = input['model_name'] + tensorboard_dir = input['tensorboard_dir'] + start_tensorboard = input['start_tensorboard'] + train_summary_writer = init_writer(model_name, tensorboard_dir, start_tensorboard) + if train_summary_writer is not None: + if action == 'step': + step = input['step'] + step_time = input['step_time'] + src_loss = input['src_loss'] + dst_loss = input['dst_loss'] + # report iteration time summary + train_summary_writer.add_scalar('iteration time', step_time, step) + # report loss summary + train_summary_writer.add_scalar('loss/src', src_loss, step) + if dst_loss is not None: + 
train_summary_writer.add_scalar('loss/dst', dst_loss, step) + elif action == 'preview': + step = input['step'] + previews = input['previews'] + static_previews = input['static_previews'] + if previews is not None: + log_tensorboard_previews(step, previews, 'preview', train_summary_writer) + if static_previews is not None: + log_tensorboard_previews(step, static_previews, 'static_preview', train_summary_writer) def main(**kwargs): io.log_info ("Running trainer.\r\n") @@ -283,7 +344,9 @@ def main(**kwargs): if not c2s.empty(): input = c2s.get() op = input.get('op','') - if op == 'close': + if op == 'tb': + handle_tensorboard_op(input) + elif op == 'close': break try: io.process_messages(0.1) @@ -333,6 +396,8 @@ def main(**kwargs): previews.append ( (preview_name, cv2.resize(preview_rgb, (max_w, max_h))) ) selected_preview = selected_preview % len(previews) update_preview = True + elif op == 'tb': + handle_tensorboard_op(input) elif op == 'close': break diff --git a/merger/FrameInfo.py b/merger/FrameInfo.py index 1b8ebb0..9eeaf44 100644 --- a/merger/FrameInfo.py +++ b/merger/FrameInfo.py @@ -1,8 +1,10 @@ from pathlib import Path class FrameInfo(object): - def __init__(self, filepath=None, landmarks_list=None): + def __init__(self, filepath=None, landmarks_list=None, image_to_face_mat=None, aligned_size=512): self.filepath = filepath self.landmarks_list = landmarks_list or [] self.motion_deg = 0 - self.motion_power = 0 \ No newline at end of file + self.motion_power = 0 + self.image_to_face_mat = image_to_face_mat + self.aligned_size = aligned_size \ No newline at end of file diff --git a/merger/InteractiveMergerSubprocessor.py b/merger/InteractiveMergerSubprocessor.py index 58db0c1..f509971 100644 --- a/merger/InteractiveMergerSubprocessor.py +++ b/merger/InteractiveMergerSubprocessor.py @@ -84,15 +84,14 @@ class InteractiveMergerSubprocessor(Subprocessor): filepath = frame_info.filepath if len(frame_info.landmarks_list) == 0: - - if cfg.mode == 'raw-predict': + self.log_info (f'no faces found for {filepath.name}, copying without faces') + if cfg.mode == 'raw-predict': h,w,c = self.predictor_input_shape img_bgr = np.zeros( (h,w,3), dtype=np.uint8) - img_mask = np.zeros( (h,w,1), dtype=np.uint8) - else: - self.log_info (f'no faces found for {filepath.name}, copying without faces') + img_mask = np.zeros( (h,w,1), dtype=np.uint8) + else: img_bgr = cv2_imread(filepath) - imagelib.normalize_channels(img_bgr, 3) + imagelib.normalize_channels(img_bgr, 3) h,w,c = img_bgr.shape img_mask = np.zeros( (h,w,1), dtype=img_bgr.dtype) @@ -140,7 +139,7 @@ class InteractiveMergerSubprocessor(Subprocessor): #override - def __init__(self, is_interactive, merger_session_filepath, predictor_func, predictor_input_shape, face_enhancer_func, xseg_256_extract_func, merger_config, frames, frames_root_path, output_path, output_mask_path, model_iter, subprocess_count=4): + def __init__(self, is_interactive, merger_session_filepath, predictor_func, predictor_input_shape, face_enhancer_func, xseg_256_extract_func, merger_config, frames, frames_root_path, output_path, output_mask_path, model_iter, subprocess_count=4, src_src=False): if len (frames) == 0: raise ValueError ("len (frames) == 0") @@ -163,6 +162,8 @@ class InteractiveMergerSubprocessor(Subprocessor): self.prefetch_frame_count = self.process_count = subprocess_count + self.src_src = src_src + session_data = None if self.is_interactive and self.merger_session_filepath.exists(): io.input_skip_pending() diff --git a/merger/MergeMasked.py b/merger/MergeMasked.py 
index 1100ab2..64de368 100644 --- a/merger/MergeMasked.py +++ b/merger/MergeMasked.py @@ -12,6 +12,17 @@ from facelib import FaceType, LandmarksProcessor is_windows = sys.platform[0:3] == 'win' xseg_input_size = 256 +def concat_matrix(first, second): + mul1 = np.vstack([*first, [0, 0, 1]]) + mul2 = np.vstack([*second, [0, 0, 1]]) + mul_r = np.matmul(mul1, mul2) + return np.delete(mul_r, (2), axis=0); + +def get_identity_affine_mat(): + pt1 = np.float32([ [0, 0], [1, 0], [1, 1] ]) + pt2 = np.float32([ [0, 0], [1, 0], [1, 1] ]) + return cv2.getAffineTransform(pt1, pt2) + def MergeMaskedFace (predictor_func, predictor_input_shape, face_enhancer_func, xseg_256_extract_func, @@ -26,15 +37,28 @@ def MergeMaskedFace (predictor_func, predictor_input_shape, if cfg.super_resolution_power != 0: output_size *= 4 - face_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, output_size, face_type=cfg.face_type) - face_output_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, output_size, face_type=cfg.face_type, scale= 1.0 + 0.01*cfg.output_face_scale) + existing_mat = get_identity_affine_mat() if cfg.src_src else frame_info.image_to_face_mat + aligned_size = frame_info.aligned_size + + if existing_mat is not None: + face_scale_mat = cv2.getRotationMatrix2D((0,0), 0, output_size/aligned_size) + face_mat = concat_matrix(face_scale_mat, existing_mat) + output_scale_mat = cv2.getRotationMatrix2D((output_size/2, output_size/2), 0, 1.0 + 0.01 * cfg.output_face_scale) + face_output_mat = concat_matrix(output_scale_mat, face_mat) + else: + face_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, output_size, face_type=cfg.face_type) + face_output_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, output_size, face_type=cfg.face_type, scale= 1.0 + 0.01*cfg.output_face_scale) if mask_subres_size == output_size: face_mask_output_mat = face_output_mat else: - face_mask_output_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, mask_subres_size, face_type=cfg.face_type, scale= 1.0 + 0.01*cfg.output_face_scale) + if existing_mat is not None: + mask_output_scale_mat = cv2.getRotationMatrix2D((0, 0), 0, mask_subres_size/output_size) + face_mask_output_mat = concat_matrix(mask_output_scale_mat, face_mat) + else: + face_mask_output_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, mask_subres_size, face_type=cfg.face_type, scale= 1.0 + 0.01*cfg.output_face_scale) - dst_face_bgr = cv2.warpAffine( img_bgr , face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC ) + dst_face_bgr = cv2.warpAffine( img_bgr, face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC ) dst_face_bgr = np.clip(dst_face_bgr, 0, 1) dst_face_mask_a_0 = cv2.warpAffine( img_face_mask_a, face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC ) @@ -76,7 +100,11 @@ def MergeMaskedFace (predictor_func, predictor_input_shape, if cfg.mask_mode >= 7 and cfg.mask_mode <= 9: # obtain XSeg-dst - xseg_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, xseg_input_size, face_type=cfg.face_type) + if existing_mat is not None: + xseg_scale_mat = cv2.getRotationMatrix2D((0,0), 0, xseg_input_size/aligned_size) + xseg_mat = concat_matrix(xseg_scale_mat, existing_mat) + else: + xseg_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, xseg_input_size, face_type=cfg.face_type) dst_face_xseg_bgr = cv2.warpAffine(img_bgr, xseg_mat, (xseg_input_size,)*2, flags=cv2.INTER_CUBIC ) dst_face_xseg_mask = xseg_256_extract_func(dst_face_xseg_bgr) X_dst_face_mask_a_0 = 
cv2.resize (dst_face_xseg_mask, (output_size,output_size), interpolation=cv2.INTER_CUBIC) diff --git a/merger/MergerConfig.py b/merger/MergerConfig.py index 432bdf1..3f0ccae 100644 --- a/merger/MergerConfig.py +++ b/merger/MergerConfig.py @@ -113,6 +113,7 @@ class MergerConfigMasked(MergerConfig): image_denoise_power = 0, bicubic_degrade_power = 0, color_degrade_power = 0, + src_src = False, **kwargs ): @@ -122,6 +123,8 @@ class MergerConfigMasked(MergerConfig): if self.face_type not in [FaceType.HALF, FaceType.MID_FULL, FaceType.FULL, FaceType.WHOLE_FACE, FaceType.HEAD ]: raise ValueError("MergerConfigMasked does not support this type of face.") + self.src_src = src_src + self.default_mode = default_mode #default changeable params @@ -188,13 +191,13 @@ class MergerConfigMasked(MergerConfig): self.bicubic_degrade_power = np.clip ( self.bicubic_degrade_power+diff, 0, 100) def ask_settings(self): - s = """Choose mode: \n""" - for key in mode_dict.keys(): - s += f"""({key}) {mode_dict[key]}\n""" - io.log_info(s) - mode = io.input_int ("", mode_str_dict.get(self.default_mode, 1) ) - - self.mode = mode_dict.get (mode, self.default_mode ) + if 'raw-predict' not in self.mode: + s = """Choose mode: \n""" + for key in mode_dict.keys(): + s += f"""({key}) {mode_dict[key]}\n""" + io.log_info(s) + mode = io.input_int ("", mode_str_dict.get(self.default_mode, 1) ) + self.mode = mode_dict.get (mode, self.default_mode) if 'raw' not in self.mode: if self.mode == 'hist-match': diff --git a/models/ModelBase.py b/models/ModelBase.py index 3cb88c5..63e12ff 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -22,6 +22,7 @@ from samplelib import SampleGeneratorBase class ModelBase(object): def __init__(self, is_training=False, + src_src=False, saved_models_path=None, training_data_src_path=None, training_data_dst_path=None, @@ -36,6 +37,7 @@ class ModelBase(object): silent_start=False, **kwargs): self.is_training = is_training + self.src_src = src_src self.saved_models_path = saved_models_path self.training_data_src_path = training_data_src_path self.training_data_dst_path = training_data_dst_path diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index d34e7dd..932f0bb 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -601,20 +601,29 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... 
_, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) elif 'liae' in archi_type: + gpu_src_code = self.encoder (self.warped_src) + gpu_src_inter_AB_code = self.inter_AB (gpu_src_code) + gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis ) + gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) + gpu_dst_code = self.encoder (self.warped_dst) gpu_dst_inter_B_code = self.inter_B (gpu_dst_code) gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code) gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) - gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + gpu_pred_src_dst, gpu_pred_src_dst_dstm = self.decoder(gpu_src_dst_code) _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) - def AE_merge( warped_dst): - return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst}) + def AE_merge(warped_dst): + return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dst_dstm], feed_dict={self.warped_dst:warped_dst}) + + def AE_src(warped_src): + return nn.tf_sess.run( [gpu_pred_src_src, gpu_pred_src_srcm], feed_dict={self.warped_src:warped_src}) self.AE_merge = AE_merge + self.AE_src = AE_src # Loading/initializing all models/optimizers weights for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): @@ -813,9 +822,21 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0] + def src_predictor_func (self, face=None): + face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC") + bgr, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_src (face) ] + return bgr[0], mask_src_dstm[0][...,0], mask_src_dstm[0][...,0] + #override def get_MergerConfig(self): import merger - return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay') + if self.src_src: + merger_config = merger.MergerConfigMasked(face_type=self.face_type, + default_mode='raw-predict', + mode='raw-predict', + src_src=True) + return self.src_predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger_config + else: + return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay') Model = SAEHDModel
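
A note on the TensorBoard logging added in PATCH 2/3: training is started with the new flags, e.g. "python main.py train ... --tensorboard-logdir <logdir> --start-tensorboard", and the trainer then logs iteration time, per-side loss and preview images through tensorboardX. The following is a minimal, self-contained sketch of that logging, assuming tensorboardX is installed (pip install tensorboardX); the "tb_logs"/"my_model" path and all numeric values are illustrative, not taken from the patch.

    import os
    import numpy as np
    import tensorboardX

    log_dir = os.path.join("tb_logs", "my_model")    # mirrors tensorboard_dir/model_name
    os.makedirs(log_dir, exist_ok=True)
    writer = tensorboardX.SummaryWriter(log_dir)

    for step in range(1, 4):
        iter_time = 0.25                              # dummy seconds-per-iteration
        src_loss, dst_loss = 0.8 / step, 0.9 / step   # dummy losses
        writer.add_scalar('iteration time', iter_time, step)
        writer.add_scalar('loss/src', src_loss, step)
        writer.add_scalar('loss/dst', dst_loss, step)

    # Previews are BGR HxWxC float images; tensorboardX expects RGB CxHxW,
    # which is what process_img_for_tensorboard() produces in the patch.
    preview_bgr = np.random.rand(128, 128, 3).astype(np.float32)
    preview_rgb = np.moveaxis(preview_bgr[..., ::-1], -1, 0)
    writer.add_image('preview/example', preview_rgb, global_step=3)
    writer.close()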
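
The TensorBoardTool helper in the same patch (adapted from stackoverflow.com/a/52295534) starts a TensorBoard server inside the training process when --start-tensorboard is given. A sketch of the same idea with the public tensorboard.program API follows; depending on the installed TensorBoard version the plugin list can be passed to the constructor as in the patch or omitted as here, and port 6006 is simply TensorBoard's default.

    import logging
    from tensorboard import program

    # Silence the per-request HTTP log lines from the embedded server.
    logging.getLogger('werkzeug').setLevel(logging.ERROR)

    tb = program.TensorBoard()
    tb.configure(argv=[None, '--logdir', 'tb_logs', '--port', '6006', '--bind_all'])
    url = tb.launch()
    print('Launched TensorBoard at {}'.format(url))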
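
PATCH 3/3 then moves every SummaryWriter call out of trainerThread: the thread only enqueues small {'op': 'tb', ...} dictionaries on the existing c2s queue, and handle_tensorboard_op() on the receiving side owns the writer. Below is a stripped-down sketch of that producer/consumer pattern, with a plain queue.Queue and a worker thread standing in for the real trainer thread; the message keys follow the patch, the writer path is illustrative.

    import queue
    import threading
    import tensorboardX

    c2s = queue.Queue()

    def trainer_thread():
        # Producer side: emit one 'step' message per iteration, then close.
        for step in range(1, 4):
            c2s.put({'op': 'tb', 'action': 'step', 'step': step, 'step_time': 0.25,
                     'src_loss': 0.5 / step, 'dst_loss': 0.6 / step})
        c2s.put({'op': 'close'})

    writer = tensorboardX.SummaryWriter('tb_logs/queue_demo')   # illustrative path

    t = threading.Thread(target=trainer_thread)
    t.start()
    while True:
        msg = c2s.get()
        if msg.get('op') == 'close':
            break
        if msg.get('op') == 'tb' and msg['action'] == 'step':
            writer.add_scalar('iteration time', msg['step_time'], msg['step'])
            writer.add_scalar('loss/src', msg['src_loss'], msg['step'])
            if msg['dst_loss'] is not None:
                writer.add_scalar('loss/dst', msg['dst_loss'], msg['step'])
    t.join()
    writer.close()

The 'init' and 'preview' actions are dispatched the same way, which keeps tensorboardX and the writer entirely out of the training loop.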
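
On the merging side, the MergeMasked.py changes compose 2x3 OpenCV affine matrices instead of recomputing warps from landmarks: the image-to-face matrix stored in the aligned image's metadata is rescaled from the stored aligned_size to the model's output_size and used directly. A small numerical sketch follows; concat_matrix here is a condensed but equivalent form of the patch's helper, and the stored matrix and sizes are dummies (in the merger they come from dflimg.get_image_to_face_mat() and dflimg.get_shape()[0]).

    import numpy as np
    import cv2

    def concat_matrix(first, second):
        # Lift both 2x3 affine mats to 3x3 homogeneous form, multiply, drop the last row.
        return np.matmul(np.vstack([first, [0, 0, 1]]),
                         np.vstack([second, [0, 0, 1]]))[:2]

    aligned_size = 512                        # resolution the aligned face was saved at
    output_size = 256                         # model resolution used while merging
    image_to_face_mat = np.float64([[ 0.4, 0.1, -80.0],    # dummy stored warp
                                    [-0.1, 0.4, -40.0]])

    # Rescale the stored warp so the face lands in an output_size x output_size crop,
    # mirroring how face_mat is built in MergeMaskedFace when existing_mat is present.
    face_scale_mat = cv2.getRotationMatrix2D((0, 0), 0, output_size / aligned_size)
    face_mat = concat_matrix(face_scale_mat, image_to_face_mat)

    # Composing then warping once is equivalent to warping twice in sequence.
    pts = np.float64([[0, 0], [400, 300], [1200, 700]]).reshape(-1, 1, 2)
    two_steps = cv2.transform(cv2.transform(pts, image_to_face_mat), face_scale_mat)
    one_step = cv2.transform(pts, face_mat)
    assert np.allclose(two_steps, one_step)

    frame = np.zeros((720, 1280, 3), np.float32)            # stand-in source frame
    dst_face = cv2.warpAffine(frame, face_mat, (output_size, output_size),
                              flags=cv2.INTER_CUBIC)
    print(dst_face.shape)                                   # -> (256, 256, 3)

For --src-src merges the stored matrix is replaced by get_identity_affine_mat(), so the already-aligned input is only rescaled rather than re-warped before being fed to the model's src_predictor_func.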