diff --git a/mainscripts/FacesetRelighter.py b/mainscripts/FacesetRelighter.py index 69fc5cf..227e2b4 100644 --- a/mainscripts/FacesetRelighter.py +++ b/mainscripts/FacesetRelighter.py @@ -1,6 +1,7 @@ import traceback from pathlib import Path +import imagelib from interact import interact as io from nnlib import DeepPortraitRelighting from utils import Path_utils @@ -8,20 +9,195 @@ from utils.cv2_utils import * from utils.DFLJPG import DFLJPG from utils.DFLPNG import DFLPNG +class RelightEditor: + def __init__(self, image_paths, dpr, lighten): + self.image_paths = image_paths + self.dpr = dpr + self.lighten = lighten + + self.current_img_path = None + self.current_img = None + self.current_img_shape = None + self.pick_new_face() + + self.alt_azi_ar = [ [0,0] ] + self.alt_azi_cur = 0 + + self.mouse_x = self.mouse_y = 9999 + self.screen_status_block = None + self.screen_status_block_dirty = True + self.screen_changed = True + + def pick_new_face(self): + self.current_img_path = self.image_paths[ np.random.randint(len(self.image_paths)) ] + self.current_img = cv2_imread (str(self.current_img_path)) + self.current_img_shape = self.current_img.shape + self.set_screen_changed() + + def set_screen_changed(self): + self.screen_changed = True + + def switch_screen_changed(self): + result = self.screen_changed + self.screen_changed = False + return result + + def make_screen(self): + alt,azi=self.alt_azi_ar[self.alt_azi_cur] + + img = self.dpr.relight (self.current_img, alt, azi, self.lighten) + + h,w,c = img.shape + + lines = ['Pick light directions for whole faceset.', + '[q]-new test face', + '[w][e]-navigate', + '[r]-new [t]-delete [enter]-process', + ''] + + for i, (alt,azi) in enumerate(self.alt_azi_ar): + s = '>:' if self.alt_azi_cur == i else ' :' + s += f'alt=[{ int(alt):03}] azi=[{ int(azi):03}]' + lines += [ s ] + + lines_count = len(lines) + h_line = 16 + + sh = lines_count * h_line + sw = 400 + sc = c + status_img = np.ones ( (sh,sw,sc) ) * 0.1 + + for i in range(lines_count): + status_img[ i*h_line:(i+1)*h_line, 0:sw] += \ + imagelib.get_text_image ( (h_line,sw,c), lines[i], color=[0.8]*c ) + + status_img = np.clip(status_img*255, 0, 255).astype(np.uint8) + + #combine screens + if sh > h: + img = np.concatenate ([img, np.zeros( (sh-h,w,c), dtype=img.dtype ) ], axis=0) + elif h > sh: + status_img = np.concatenate ([status_img, np.zeros( (h-sh,sw,sc), dtype=img.dtype ) ], axis=0) + + img = np.concatenate ([img, status_img], axis=1) + + return img + + def run(self): + wnd_name = "Relighter" + io.named_window(wnd_name) + io.capture_keys(wnd_name) + io.capture_mouse(wnd_name) + + zoom_factor = 1.0 + + is_angle_editing = False + + is_exit = False + while not is_exit: + io.process_messages(0.0001) + + mouse_events = io.get_mouse_events(wnd_name) + for ev in mouse_events: + (x, y, ev, flags) = ev + if ev == io.EVENT_LBUTTONDOWN: + is_angle_editing = True + + if ev == io.EVENT_LBUTTONUP: + is_angle_editing = False + + if is_angle_editing: + h,w,c = self.current_img_shape + + self.alt_azi_ar[self.alt_azi_cur] = \ + [np.clip ( ( 0.5-y/w )*2.0, -1, 1)*90, \ + np.clip ( (x / h - 0.5)*2.0, -1, 1)*90 ] + + self.set_screen_changed() + + key_events = io.get_key_events(wnd_name) + key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False) + + if key != 0: + if chr_key == 'q': + self.pick_new_face() + elif chr_key == 'w': + self.alt_azi_cur = np.clip (self.alt_azi_cur-1, 0, len(self.alt_azi_ar)-1) + self.set_screen_changed() + elif chr_key == 'e': + self.alt_azi_cur = np.clip (self.alt_azi_cur+1, 0, len(self.alt_azi_ar)-1) + self.set_screen_changed() + elif chr_key == 'r': + #add direction + self.alt_azi_ar += [ [0,0] ] + self.alt_azi_cur +=1 + self.set_screen_changed() + elif chr_key == 't': + if len(self.alt_azi_ar) > 1: + self.alt_azi_ar.pop(self.alt_azi_cur) + self.alt_azi_cur = np.clip (self.alt_azi_cur, 0, len(self.alt_azi_ar)-1) + self.set_screen_changed() + elif key == 27 or chr_key == '\r' or chr_key == '\n': #esc + is_exit = True + + if self.switch_screen_changed(): + screen = self.make_screen() + if zoom_factor != 1.0: + h,w,c = screen.shape + screen = cv2.resize ( screen, ( int(w*zoom_factor), int(h*zoom_factor) ) ) + io.show_image (wnd_name, screen ) + + io.destroy_window(wnd_name) + + return self.alt_azi_ar def relight(input_dir, lighten=None, random_one=None): if lighten is None: - lighten = io.input_bool ("Lighten the faces? ( y/n default:n ) : ", False) + lighten = io.input_bool ("Lighten the faces? ( y/n default:n ?:help ) : ", False, help_message="Lighten the faces instead of shadow. May produce artifacts." ) - if random_one is None: - random_one = io.input_bool ("Relight the faces only with one random direction? ( y/n default:y ) : ", True) + if io.is_colab(): + io.log_info("In colab version you cannot choose light directions manually.") + manual = False + else: + manual = io.input_bool ("Choose light directions manually? ( y/n default:y ) : ", True) - input_path = Path(input_dir) + if not manual: + if random_one is None: + random_one = io.input_bool ("Relight the faces only with one random direction? ( y/n default:y ?:help) : ", True, help_message="Otherwise faceset will be relighted with predefined 7 light directions.") - image_paths = [Path(x) for x in Path_utils.get_image_paths(input_path)] + image_paths = [Path(x) for x in Path_utils.get_image_paths(input_dir)] + filtered_image_paths = [] + for filepath in io.progress_bar_generator(image_paths, "Collecting fileinfo"): + try: + if filepath.suffix == '.png': + dflimg = DFLPNG.load( str(filepath) ) + elif filepath.suffix == '.jpg': + dflimg = DFLJPG.load ( str(filepath) ) + else: + dflimg = None + + if dflimg is None: + io.log_err ("%s is not a dfl image file" % (filepath.name) ) + else: + if not dflimg.get_relighted(): + filtered_image_paths += [filepath] + except: + io.log_err (f"Exception occured while processing file {filepath.name}. Error: {traceback.format_exc()}") + image_paths = filtered_image_paths + + if len(image_paths) == 0: + io.log_info("No files to process.") + return dpr = DeepPortraitRelighting() + if manual: + alt_azi_ar = RelightEditor(image_paths, dpr, lighten).run() + else: + if not random_one: + alt_azi_ar = [(60,0), (60,60), (0,60), (-60,60), (-60,0), (-60,-60), (0,-60), (60,-60)] + for filepath in io.progress_bar_generator(image_paths, "Relighting"): try: if filepath.suffix == '.png': @@ -36,24 +212,30 @@ def relight(input_dir, lighten=None, random_one=None): continue else: if dflimg.get_relighted(): - io.log_info (f"Skipping already relighted face [{filepath.name}]") continue img = cv2_imread (str(filepath)) if random_one: - relighted_imgs = dpr.relight_random(img,lighten=lighten) + alt = np.random.randint(-90,91) + azi = np.random.randint(-90,91) + relighted_imgs = [dpr.relight(img,alt=alt,azi=azi,lighten=lighten)] else: - relighted_imgs = dpr.relight_all(img,lighten=lighten) + relighted_imgs = [dpr.relight(img,alt=alt,azi=azi,lighten=lighten) for (alt,azi) in alt_azi_ar ] + i = 0 for i,relighted_img in enumerate(relighted_imgs): im_flags = [] if filepath.suffix == '.jpg': im_flags += [int(cv2.IMWRITE_JPEG_QUALITY), 100] - relighted_filename = filepath.parent / (filepath.stem+f'_relighted_{i}'+filepath.suffix) + while True: + relighted_filepath = filepath.parent / (filepath.stem+f'_relighted_{i}'+filepath.suffix) + if not relighted_filepath.exists(): + break + i += 1 - cv2_imwrite (relighted_filename, relighted_img ) - dflimg.embed_and_set (relighted_filename, source_filename="_", relighted=True ) + cv2_imwrite (relighted_filepath, relighted_img ) + dflimg.embed_and_set (relighted_filepath, source_filename="_", relighted=True ) except: io.log_err (f"Exception occured while processing file {filepath.name}. Error: {traceback.format_exc()}") diff --git a/nnlib/DeepPortraitRelighting.py b/nnlib/DeepPortraitRelighting.py index 9ee3ab8..c8d520a 100644 --- a/nnlib/DeepPortraitRelighting.py +++ b/nnlib/DeepPortraitRelighting.py @@ -1,33 +1,55 @@ +import math from pathlib import Path -import numpy as np + import cv2 +import numpy as np +import numpy.linalg as npla + class DeepPortraitRelighting(object): def __init__(self): from nnlib import nnlib - nnlib.import_torch() - + nnlib.import_torch() self.torch = nnlib.torch - self.torch_device = nnlib.torch_device - + self.torch_device = nnlib.torch_device self.model = DeepPortraitRelighting.build_model(self.torch, self.torch_device) + + def SH_basis(self, alt, azi): + alt = alt * math.pi / 180.0 + azi = azi * math.pi / 180.0 - self.shs = [ - [1.084125496282453138e+00,-4.642676300617166185e-01,2.837846795150648915e-02,6.765292733937575687e-01,-3.594067725393816914e-01,4.790996460111427574e-02,-2.280054643781863066e-01,-8.125983081159608712e-02,2.881082012687687932e-01], - [1.084125496282453138e+00,-4.642676300617170626e-01,5.466255701105990905e-01,3.996219229512094628e-01,-2.615439760463462715e-01,-2.511241554473071513e-01,6.495694866016435420e-02,3.510322039081858470e-01,1.189662732386344152e-01], - [1.084125496282453138e+00,-4.642676300617179508e-01,6.532524688468428486e-01,-1.782088862752457814e-01,3.326676893441832261e-02,-3.610566644446819295e-01,3.647561777790956361e-01,-7.496419691318900735e-02,-5.412289239602386531e-02], - [1.084125496282453138e+00,-4.642676300617186724e-01,2.679669346194941126e-01,-6.218447693376460972e-01,3.030269583891490037e-01,-1.991061409014726058e-01,-6.162944418511027977e-02,-3.176699976873690878e-01,1.920509612235956343e-01], - [1.084125496282453138e+00,-4.642676300617186724e-01,-3.191031669056417219e-01,-5.972188577671910803e-01,3.446016675533919993e-01,1.127753677656503223e-01,-1.716692196540034188e-01,2.163406460637767315e-01,2.555824552121269688e-01], - [1.084125496282453138e+00,-4.642676300617178398e-01,-6.658820752324799974e-01,-1.228749652534838893e-01,1.266842924569576145e-01,3.397347243069742673e-01,3.036887095295650041e-01,2.213893524577207617e-01,-1.886557316342868038e-02], - [1.084125496282453138e+00,-4.642676300617169516e-01,-5.112381993903207800e-01,4.439962822886048266e-01,-1.866289387481862572e-01,3.108669041197227867e-01,2.021743042675238355e-01,-3.148681770175290051e-01,3.974379604123656762e-02] - ] + x = math.cos(alt)*math.sin(azi) + y = -math.cos(alt)*math.cos(azi) + z = math.sin(alt) + + normal = np.array([x,y,z]) + + norm_X = normal[0] + norm_Y = normal[1] + norm_Z = normal[2] + + sh_basis = np.zeros((9)) + att= np.pi*np.array([1, 2.0/3.0, 1/4.0]) + sh_basis[0] = 0.5/np.sqrt(np.pi)*att[0] + + sh_basis[1] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Y*att[1] + sh_basis[2] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Z*att[1] + sh_basis[3] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_X*att[1] + + sh_basis[4] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_X*att[2] + sh_basis[5] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_Z*att[2] + sh_basis[6] = np.sqrt(5)/4/np.sqrt(np.pi)*(3*norm_Z**2-1)*att[2] + sh_basis[7] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_X*norm_Z*att[2] + sh_basis[8] = np.sqrt(15)/4/np.sqrt(np.pi)*(norm_X**2-norm_Y**2)*att[2] + return sh_basis #n = [0..8] - def relight(self, img, n, lighten=False): + def relight(self, img, alt, azi, lighten=False): torch = self.torch - - sh = (np.array (self.shs[np.clip(n, 0,8)]).reshape( (1,9,1,1) )*0.7).astype(np.float32) + + sh = self.SH_basis (alt, azi) + sh = (sh.reshape( (1,9,1,1) ) ).astype(np.float32) sh = torch.autograd.Variable(torch.from_numpy(sh).to(self.torch_device)) row, col, _ = img.shape @@ -54,13 +76,7 @@ class DeepPortraitRelighting(object): result = cv2.cvtColor(Lab, cv2.COLOR_LAB2BGR) result = cv2.resize(result, (col, row)) return result - - def relight_all(self, img, lighten=False): - return [ self.relight(img, n, lighten=lighten) for n in range( len(self.shs) ) ] - - def relight_random(self, img, lighten=False): - return [ self.relight(img, np.random.randint(len(self.shs)), lighten=lighten ) ] - + @staticmethod def build_model(torch, torch_device): nn = torch.nn @@ -220,4 +236,4 @@ class DeepPortraitRelighting(object): model.load_state_dict(t_dict) model.to( torch_device ) model.train(False) - return model \ No newline at end of file + return model diff --git a/nnlib/nnlib.py b/nnlib/nnlib.py index 5841bb9..c9f5059 100644 --- a/nnlib/nnlib.py +++ b/nnlib/nnlib.py @@ -144,6 +144,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator if 'CUDA_VISIBLE_DEVICES' in os.environ.keys(): os.environ.pop('CUDA_VISIBLE_DEVICES') + io.log_info ("Using PyTorch backend.") import torch nnlib.torch = torch