FacesetRelighter fixes and improvements:

now you have 3 ways:
1) define light directions manually (not for google colab)
   watch demo https://youtu.be/79xz7yEO5Jw
2) relight faceset with one random direction
3) relight faceset with predefined 8 directions
This commit is contained in:
Colombo 2019-11-11 19:56:36 +04:00
parent fe58459f36
commit 05153d9ba5
3 changed files with 235 additions and 36 deletions

View file

@ -1,6 +1,7 @@
import traceback import traceback
from pathlib import Path from pathlib import Path
import imagelib
from interact import interact as io from interact import interact as io
from nnlib import DeepPortraitRelighting from nnlib import DeepPortraitRelighting
from utils import Path_utils from utils import Path_utils
@ -8,20 +9,195 @@ from utils.cv2_utils import *
from utils.DFLJPG import DFLJPG from utils.DFLJPG import DFLJPG
from utils.DFLPNG import DFLPNG from utils.DFLPNG import DFLPNG
class RelightEditor:
def __init__(self, image_paths, dpr, lighten):
self.image_paths = image_paths
self.dpr = dpr
self.lighten = lighten
self.current_img_path = None
self.current_img = None
self.current_img_shape = None
self.pick_new_face()
self.alt_azi_ar = [ [0,0] ]
self.alt_azi_cur = 0
self.mouse_x = self.mouse_y = 9999
self.screen_status_block = None
self.screen_status_block_dirty = True
self.screen_changed = True
def pick_new_face(self):
self.current_img_path = self.image_paths[ np.random.randint(len(self.image_paths)) ]
self.current_img = cv2_imread (str(self.current_img_path))
self.current_img_shape = self.current_img.shape
self.set_screen_changed()
def set_screen_changed(self):
self.screen_changed = True
def switch_screen_changed(self):
result = self.screen_changed
self.screen_changed = False
return result
def make_screen(self):
alt,azi=self.alt_azi_ar[self.alt_azi_cur]
img = self.dpr.relight (self.current_img, alt, azi, self.lighten)
h,w,c = img.shape
lines = ['Pick light directions for whole faceset.',
'[q]-new test face',
'[w][e]-navigate',
'[r]-new [t]-delete [enter]-process',
'']
for i, (alt,azi) in enumerate(self.alt_azi_ar):
s = '>:' if self.alt_azi_cur == i else ' :'
s += f'alt=[{ int(alt):03}] azi=[{ int(azi):03}]'
lines += [ s ]
lines_count = len(lines)
h_line = 16
sh = lines_count * h_line
sw = 400
sc = c
status_img = np.ones ( (sh,sw,sc) ) * 0.1
for i in range(lines_count):
status_img[ i*h_line:(i+1)*h_line, 0:sw] += \
imagelib.get_text_image ( (h_line,sw,c), lines[i], color=[0.8]*c )
status_img = np.clip(status_img*255, 0, 255).astype(np.uint8)
#combine screens
if sh > h:
img = np.concatenate ([img, np.zeros( (sh-h,w,c), dtype=img.dtype ) ], axis=0)
elif h > sh:
status_img = np.concatenate ([status_img, np.zeros( (h-sh,sw,sc), dtype=img.dtype ) ], axis=0)
img = np.concatenate ([img, status_img], axis=1)
return img
def run(self):
wnd_name = "Relighter"
io.named_window(wnd_name)
io.capture_keys(wnd_name)
io.capture_mouse(wnd_name)
zoom_factor = 1.0
is_angle_editing = False
is_exit = False
while not is_exit:
io.process_messages(0.0001)
mouse_events = io.get_mouse_events(wnd_name)
for ev in mouse_events:
(x, y, ev, flags) = ev
if ev == io.EVENT_LBUTTONDOWN:
is_angle_editing = True
if ev == io.EVENT_LBUTTONUP:
is_angle_editing = False
if is_angle_editing:
h,w,c = self.current_img_shape
self.alt_azi_ar[self.alt_azi_cur] = \
[np.clip ( ( 0.5-y/w )*2.0, -1, 1)*90, \
np.clip ( (x / h - 0.5)*2.0, -1, 1)*90 ]
self.set_screen_changed()
key_events = io.get_key_events(wnd_name)
key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False)
if key != 0:
if chr_key == 'q':
self.pick_new_face()
elif chr_key == 'w':
self.alt_azi_cur = np.clip (self.alt_azi_cur-1, 0, len(self.alt_azi_ar)-1)
self.set_screen_changed()
elif chr_key == 'e':
self.alt_azi_cur = np.clip (self.alt_azi_cur+1, 0, len(self.alt_azi_ar)-1)
self.set_screen_changed()
elif chr_key == 'r':
#add direction
self.alt_azi_ar += [ [0,0] ]
self.alt_azi_cur +=1
self.set_screen_changed()
elif chr_key == 't':
if len(self.alt_azi_ar) > 1:
self.alt_azi_ar.pop(self.alt_azi_cur)
self.alt_azi_cur = np.clip (self.alt_azi_cur, 0, len(self.alt_azi_ar)-1)
self.set_screen_changed()
elif key == 27 or chr_key == '\r' or chr_key == '\n': #esc
is_exit = True
if self.switch_screen_changed():
screen = self.make_screen()
if zoom_factor != 1.0:
h,w,c = screen.shape
screen = cv2.resize ( screen, ( int(w*zoom_factor), int(h*zoom_factor) ) )
io.show_image (wnd_name, screen )
io.destroy_window(wnd_name)
return self.alt_azi_ar
def relight(input_dir, lighten=None, random_one=None): def relight(input_dir, lighten=None, random_one=None):
if lighten is None: if lighten is None:
lighten = io.input_bool ("Lighten the faces? ( y/n default:n ) : ", False) lighten = io.input_bool ("Lighten the faces? ( y/n default:n ?:help ) : ", False, help_message="Lighten the faces instead of shadow. May produce artifacts." )
if io.is_colab():
io.log_info("In colab version you cannot choose light directions manually.")
manual = False
else:
manual = io.input_bool ("Choose light directions manually? ( y/n default:y ) : ", True)
if not manual:
if random_one is None: if random_one is None:
random_one = io.input_bool ("Relight the faces only with one random direction? ( y/n default:y ) : ", True) random_one = io.input_bool ("Relight the faces only with one random direction? ( y/n default:y ?:help) : ", True, help_message="Otherwise faceset will be relighted with predefined 7 light directions.")
input_path = Path(input_dir) image_paths = [Path(x) for x in Path_utils.get_image_paths(input_dir)]
filtered_image_paths = []
for filepath in io.progress_bar_generator(image_paths, "Collecting fileinfo"):
try:
if filepath.suffix == '.png':
dflimg = DFLPNG.load( str(filepath) )
elif filepath.suffix == '.jpg':
dflimg = DFLJPG.load ( str(filepath) )
else:
dflimg = None
image_paths = [Path(x) for x in Path_utils.get_image_paths(input_path)] if dflimg is None:
io.log_err ("%s is not a dfl image file" % (filepath.name) )
else:
if not dflimg.get_relighted():
filtered_image_paths += [filepath]
except:
io.log_err (f"Exception occured while processing file {filepath.name}. Error: {traceback.format_exc()}")
image_paths = filtered_image_paths
if len(image_paths) == 0:
io.log_info("No files to process.")
return
dpr = DeepPortraitRelighting() dpr = DeepPortraitRelighting()
if manual:
alt_azi_ar = RelightEditor(image_paths, dpr, lighten).run()
else:
if not random_one:
alt_azi_ar = [(60,0), (60,60), (0,60), (-60,60), (-60,0), (-60,-60), (0,-60), (60,-60)]
for filepath in io.progress_bar_generator(image_paths, "Relighting"): for filepath in io.progress_bar_generator(image_paths, "Relighting"):
try: try:
if filepath.suffix == '.png': if filepath.suffix == '.png':
@ -36,24 +212,30 @@ def relight(input_dir, lighten=None, random_one=None):
continue continue
else: else:
if dflimg.get_relighted(): if dflimg.get_relighted():
io.log_info (f"Skipping already relighted face [{filepath.name}]")
continue continue
img = cv2_imread (str(filepath)) img = cv2_imread (str(filepath))
if random_one: if random_one:
relighted_imgs = dpr.relight_random(img,lighten=lighten) alt = np.random.randint(-90,91)
azi = np.random.randint(-90,91)
relighted_imgs = [dpr.relight(img,alt=alt,azi=azi,lighten=lighten)]
else: else:
relighted_imgs = dpr.relight_all(img,lighten=lighten) relighted_imgs = [dpr.relight(img,alt=alt,azi=azi,lighten=lighten) for (alt,azi) in alt_azi_ar ]
i = 0
for i,relighted_img in enumerate(relighted_imgs): for i,relighted_img in enumerate(relighted_imgs):
im_flags = [] im_flags = []
if filepath.suffix == '.jpg': if filepath.suffix == '.jpg':
im_flags += [int(cv2.IMWRITE_JPEG_QUALITY), 100] im_flags += [int(cv2.IMWRITE_JPEG_QUALITY), 100]
relighted_filename = filepath.parent / (filepath.stem+f'_relighted_{i}'+filepath.suffix) while True:
relighted_filepath = filepath.parent / (filepath.stem+f'_relighted_{i}'+filepath.suffix)
if not relighted_filepath.exists():
break
i += 1
cv2_imwrite (relighted_filename, relighted_img ) cv2_imwrite (relighted_filepath, relighted_img )
dflimg.embed_and_set (relighted_filename, source_filename="_", relighted=True ) dflimg.embed_and_set (relighted_filepath, source_filename="_", relighted=True )
except: except:
io.log_err (f"Exception occured while processing file {filepath.name}. Error: {traceback.format_exc()}") io.log_err (f"Exception occured while processing file {filepath.name}. Error: {traceback.format_exc()}")

View file

@ -1,33 +1,55 @@
import math
from pathlib import Path from pathlib import Path
import numpy as np
import cv2 import cv2
import numpy as np
import numpy.linalg as npla
class DeepPortraitRelighting(object): class DeepPortraitRelighting(object):
def __init__(self): def __init__(self):
from nnlib import nnlib from nnlib import nnlib
nnlib.import_torch() nnlib.import_torch()
self.torch = nnlib.torch self.torch = nnlib.torch
self.torch_device = nnlib.torch_device self.torch_device = nnlib.torch_device
self.model = DeepPortraitRelighting.build_model(self.torch, self.torch_device) self.model = DeepPortraitRelighting.build_model(self.torch, self.torch_device)
self.shs = [ def SH_basis(self, alt, azi):
[1.084125496282453138e+00,-4.642676300617166185e-01,2.837846795150648915e-02,6.765292733937575687e-01,-3.594067725393816914e-01,4.790996460111427574e-02,-2.280054643781863066e-01,-8.125983081159608712e-02,2.881082012687687932e-01], alt = alt * math.pi / 180.0
[1.084125496282453138e+00,-4.642676300617170626e-01,5.466255701105990905e-01,3.996219229512094628e-01,-2.615439760463462715e-01,-2.511241554473071513e-01,6.495694866016435420e-02,3.510322039081858470e-01,1.189662732386344152e-01], azi = azi * math.pi / 180.0
[1.084125496282453138e+00,-4.642676300617179508e-01,6.532524688468428486e-01,-1.782088862752457814e-01,3.326676893441832261e-02,-3.610566644446819295e-01,3.647561777790956361e-01,-7.496419691318900735e-02,-5.412289239602386531e-02],
[1.084125496282453138e+00,-4.642676300617186724e-01,2.679669346194941126e-01,-6.218447693376460972e-01,3.030269583891490037e-01,-1.991061409014726058e-01,-6.162944418511027977e-02,-3.176699976873690878e-01,1.920509612235956343e-01], x = math.cos(alt)*math.sin(azi)
[1.084125496282453138e+00,-4.642676300617186724e-01,-3.191031669056417219e-01,-5.972188577671910803e-01,3.446016675533919993e-01,1.127753677656503223e-01,-1.716692196540034188e-01,2.163406460637767315e-01,2.555824552121269688e-01], y = -math.cos(alt)*math.cos(azi)
[1.084125496282453138e+00,-4.642676300617178398e-01,-6.658820752324799974e-01,-1.228749652534838893e-01,1.266842924569576145e-01,3.397347243069742673e-01,3.036887095295650041e-01,2.213893524577207617e-01,-1.886557316342868038e-02], z = math.sin(alt)
[1.084125496282453138e+00,-4.642676300617169516e-01,-5.112381993903207800e-01,4.439962822886048266e-01,-1.866289387481862572e-01,3.108669041197227867e-01,2.021743042675238355e-01,-3.148681770175290051e-01,3.974379604123656762e-02]
] normal = np.array([x,y,z])
norm_X = normal[0]
norm_Y = normal[1]
norm_Z = normal[2]
sh_basis = np.zeros((9))
att= np.pi*np.array([1, 2.0/3.0, 1/4.0])
sh_basis[0] = 0.5/np.sqrt(np.pi)*att[0]
sh_basis[1] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Y*att[1]
sh_basis[2] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Z*att[1]
sh_basis[3] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_X*att[1]
sh_basis[4] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_X*att[2]
sh_basis[5] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_Z*att[2]
sh_basis[6] = np.sqrt(5)/4/np.sqrt(np.pi)*(3*norm_Z**2-1)*att[2]
sh_basis[7] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_X*norm_Z*att[2]
sh_basis[8] = np.sqrt(15)/4/np.sqrt(np.pi)*(norm_X**2-norm_Y**2)*att[2]
return sh_basis
#n = [0..8] #n = [0..8]
def relight(self, img, n, lighten=False): def relight(self, img, alt, azi, lighten=False):
torch = self.torch torch = self.torch
sh = (np.array (self.shs[np.clip(n, 0,8)]).reshape( (1,9,1,1) )*0.7).astype(np.float32) sh = self.SH_basis (alt, azi)
sh = (sh.reshape( (1,9,1,1) ) ).astype(np.float32)
sh = torch.autograd.Variable(torch.from_numpy(sh).to(self.torch_device)) sh = torch.autograd.Variable(torch.from_numpy(sh).to(self.torch_device))
row, col, _ = img.shape row, col, _ = img.shape
@ -55,12 +77,6 @@ class DeepPortraitRelighting(object):
result = cv2.resize(result, (col, row)) result = cv2.resize(result, (col, row))
return result return result
def relight_all(self, img, lighten=False):
return [ self.relight(img, n, lighten=lighten) for n in range( len(self.shs) ) ]
def relight_random(self, img, lighten=False):
return [ self.relight(img, np.random.randint(len(self.shs)), lighten=lighten ) ]
@staticmethod @staticmethod
def build_model(torch, torch_device): def build_model(torch, torch_device):
nn = torch.nn nn = torch.nn

View file

@ -144,6 +144,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
if 'CUDA_VISIBLE_DEVICES' in os.environ.keys(): if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
os.environ.pop('CUDA_VISIBLE_DEVICES') os.environ.pop('CUDA_VISIBLE_DEVICES')
io.log_info ("Using PyTorch backend.")
import torch import torch
nnlib.torch = torch nnlib.torch = torch