changed help message for pixel loss:

Pixel loss may help to enhance fine details and stabilize face color. Use it only if quality does not improve over time.

SAE:
previous SAE model will not work with this update.
Greatly decreased chance of model collapse.
Increased model accuracy.
Residual blocks now default and this option has been removed.
Improved 'learn mask'.
Added masked preview (switch by space key)

Converter:
fixed rct/lct in seamless mode
added mask mode (6) learned*FAN-prd*FAN-dst

added mask editor, its created for refining dataset for FANSeg model, and not for production, but you can spend your time and test it in regular fakes with face obstructions
This commit is contained in:
iperov 2019-04-04 10:22:53 +04:00
commit 5ac7e5d7f1
22 changed files with 715 additions and 387 deletions

View file

@ -1,287 +0,0 @@
import traceback
import os
import sys
import time
import numpy as np
import numpy.linalg as npl
import cv2
from pathlib import Path
from interact import interact as io
from utils.cv2_utils import *
from utils import Path_utils
from utils.DFLPNG import DFLPNG
from utils.DFLJPG import DFLJPG
from facelib import LandmarksProcessor
def main(input_dir, output_dir):
input_path = Path(input_dir)
output_path = Path(output_dir)
if not input_path.exists():
raise ValueError('Input directory not found. Please ensure it exists.')
if not output_path.exists():
output_path.mkdir(parents=True)
wnd_name = "Labeling tool"
io.named_window (wnd_name)
io.capture_mouse(wnd_name)
io.capture_keys(wnd_name)
#for filename in io.progress_bar_generator (Path_utils.get_image_paths(input_path), desc="Labeling"):
for filename in Path_utils.get_image_paths(input_path):
filepath = Path(filename)
if filepath.suffix == '.png':
dflimg = DFLPNG.load( str(filepath) )
elif filepath.suffix == '.jpg':
dflimg = DFLJPG.load ( str(filepath) )
else:
dflimg = None
if dflimg is None:
io.log_err ("%s is not a dfl image file" % (filepath.name) )
continue
lmrks = dflimg.get_landmarks()
lmrks_list = lmrks.tolist()
orig_img = cv2_imread(str(filepath))
h,w,c = orig_img.shape
mask_orig = LandmarksProcessor.get_image_hull_mask( orig_img.shape, lmrks).astype(np.uint8)[:,:,0]
ero_dil_rate = w // 8
mask_ero = cv2.erode (mask_orig, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(ero_dil_rate,ero_dil_rate)), iterations = 1 )
mask_dil = cv2.dilate(mask_orig, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(ero_dil_rate,ero_dil_rate)), iterations = 1 )
#mask_bg = np.zeros(orig_img.shape[:2],np.uint8)
mask_bg = 1-mask_dil
mask_bgp = np.ones(orig_img.shape[:2],np.uint8) #default - all background possible
mask_fg = np.zeros(orig_img.shape[:2],np.uint8)
mask_fgp = np.zeros(orig_img.shape[:2],np.uint8)
img = orig_img.copy()
l_thick=2
def draw_4_lines (masks_out, pts, thickness=1):
fgp,fg,bg,bgp = masks_out
h,w = fg.shape
fgp_pts = []
fg_pts = np.array([ pts[i:i+2] for i in range(len(pts)-1)])
bg_pts = []
bgp_pts = []
for i in range(len(fg_pts)):
a, b = line = fg_pts[i]
ba = b-a
v = ba / npl.norm(ba)
ccpv = np.array([v[1],-v[0]])
cpv = np.array([-v[1],v[0]])
step = 1 / max(np.abs(cpv))
fgp_pts.append ( np.clip (line + ccpv * step * thickness, 0, w-1 ).astype(np.int) )
bg_pts.append ( np.clip (line + cpv * step * thickness, 0, w-1 ).astype(np.int) )
bgp_pts.append ( np.clip (line + cpv * step * thickness * 2, 0, w-1 ).astype(np.int) )
fgp_pts = np.array(fgp_pts)
bg_pts = np.array(bg_pts)
bgp_pts = np.array(bgp_pts)
cv2.polylines(fgp, fgp_pts, False, (1,), thickness=thickness)
cv2.polylines(fg, fg_pts, False, (1,), thickness=thickness)
cv2.polylines(bg, bg_pts, False, (1,), thickness=thickness)
cv2.polylines(bgp, bgp_pts, False, (1,), thickness=thickness)
def draw_lines ( masks_steps, pts, thickness=1):
lines = np.array([ pts[i:i+2] for i in range(len(pts)-1)])
for mask, step in masks_steps:
h,w = mask.shape
mask_lines = []
for i in range(len(lines)):
a, b = line = lines[i]
ba = b-a
ba_len = npl.norm(ba)
if ba_len != 0:
v = ba / ba_len
pv = np.array([-v[1],v[0]])
pv_inv_max = 1 / max(np.abs(pv))
mask_lines.append ( np.clip (line + pv * pv_inv_max * thickness * step, 0, w-1 ).astype(np.int) )
else:
mask_lines.append ( np.array(line, dtype=np.int) )
cv2.polylines(mask, mask_lines, False, (1,), thickness=thickness)
def draw_fill_convex( mask_out, pts, scale=1.0 ):
hull = cv2.convexHull(np.array(pts))
if scale !=1.0:
pts_count = hull.shape[0]
sum_x = np.sum(hull[:, 0, 0])
sum_y = np.sum(hull[:, 0, 1])
hull_center = np.array([sum_x/pts_count, sum_y/pts_count])
hull = hull_center+(hull-hull_center)*scale
hull = hull.astype(pts.dtype)
cv2.fillConvexPoly( mask_out, hull, (1,) )
def get_gc_mask_bgr(gc_mask):
h, w = gc_mask.shape
bgr = np.zeros( (h,w,3), dtype=np.uint8 )
bgr [ gc_mask == 0 ] = (0,0,0)
bgr [ gc_mask == 1 ] = (255,255,255)
bgr [ gc_mask == 2 ] = (0,0,255) #RED
bgr [ gc_mask == 3 ] = (0,255,0) #GREEN
return bgr
def get_gc_mask_result(gc_mask):
return np.where((gc_mask==1) + (gc_mask==3),1,0).astype(np.int)
#convex inner of right chin to end of right eyebrow
#draw_fill_convex ( mask_fgp, lmrks_list[8:17]+lmrks_list[26:27] )
#convex inner of start right chin to right eyebrow
#draw_fill_convex ( mask_fgp, lmrks_list[8:9]+lmrks_list[22:27] )
#convex inner of nose
draw_fill_convex ( mask_fgp, lmrks[27:36] )
#convex inner of nose half
draw_fill_convex ( mask_fg, lmrks[27:36], scale=0.5 )
#left corner of mouth to left corner of nose
#draw_lines ( [ (mask_fg,0), ], lmrks_list[49:50]+lmrks_list[32:33], l_thick)
#convex inner: right corner of nose to centers of eyebrows
#draw_fill_convex ( mask_fgp, lmrks_list[35:36]+lmrks_list[19:20]+lmrks_list[24:25])
#right corner of mouth to right corner of nose
#draw_lines ( [ (mask_fg,0), ], lmrks_list[54:55]+lmrks_list[35:36], l_thick)
#left eye
#draw_fill_convex ( mask_fg, lmrks_list[36:40] )
#right eye
#draw_fill_convex ( mask_fg, lmrks_list[42:48] )
#right chin
draw_lines ( [ (mask_bg,0), (mask_fg,-1), ], lmrks[8:17], l_thick)
#left eyebrow center to right eyeprow center
draw_lines ( [ (mask_bg,-1), (mask_fg,0), ], lmrks_list[19:20] + lmrks_list[24:25], l_thick)
# #draw_lines ( [ (mask_bg,-1), (mask_fg,0), ], lmrks_list[24:25] + lmrks_list[19:17:-1], l_thick)
#half right eyebrow to end of right chin
draw_lines ( [ (mask_bg,-1), (mask_fg,0), ], lmrks_list[24:27] + lmrks_list[16:17], l_thick)
#import code
#code.interact(local=dict(globals(), **locals()))
#compose mask layers
gc_mask = np.zeros(orig_img.shape[:2],np.uint8)
gc_mask [ mask_bgp==1 ] = 2
gc_mask [ mask_fgp==1 ] = 3
gc_mask [ mask_bg==1 ] = 0
gc_mask [ mask_fg==1 ] = 1
gc_bgr_before = get_gc_mask_bgr (gc_mask)
#io.show_image (wnd_name, gc_mask )
##points, hierarcy = cv2.findContours(original_mask,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
##gc_mask = ( (1-erode_mask)*2 + erode_mask )# * dilate_mask
#gc_mask = (1-erode_mask)*2 + erode_mask
#cv2.addWeighted(
#gc_mask = mask_0_27 + (1-mask_0_27)*2
#
##import code
##code.interact(local=dict(globals(), **locals()))
#
#rect = (1,1,img.shape[1]-2,img.shape[0]-2)
#
#
cv2.grabCut(img,gc_mask,None,np.zeros((1,65),np.float64),np.zeros((1,65),np.float64),5, cv2.GC_INIT_WITH_MASK)
gc_bgr = get_gc_mask_bgr (gc_mask)
gc_mask_result = get_gc_mask_result(gc_mask)
gc_mask_result_1 = gc_mask_result[:,:,np.newaxis]
#import code
#code.interact(local=dict(globals(), **locals()))
orig_img_gc_layers_masked = (0.5*orig_img + 0.5*gc_bgr).astype(np.uint8)
orig_img_gc_before_layers_masked = (0.5*orig_img + 0.5*gc_bgr_before).astype(np.uint8)
pink_bg = np.full ( orig_img.shape, (255,0,255), dtype=np.uint8 )
orig_img_result = orig_img * gc_mask_result_1
orig_img_result_pinked = orig_img_result + pink_bg * (1-gc_mask_result_1)
#io.show_image (wnd_name, blended_img)
##gc_mask, bgdModel, fgdModel =
#
#mask2 = np.where((gc_mask==1) + (gc_mask==3),255,0).astype('uint8')[:,:,np.newaxis]
#mask2 = np.repeat(mask2, (3,), -1)
#
##mask2 = np.where(gc_mask!=0,255,0).astype('uint8')
#blended_img = orig_img #-\
# #0.3 * np.full(original_img.shape, (50,50,50)) * (1-mask_0_27)[:,:,np.newaxis]
# #0.3 * np.full(original_img.shape, (50,50,50)) * (1-dilate_mask)[:,:,np.newaxis] +\
# #0.3 * np.full(original_img.shape, (50,50,50)) * (erode_mask)[:,:,np.newaxis]
#blended_img = np.clip(blended_img, 0, 255).astype(np.uint8)
##import code
##code.interact(local=dict(globals(), **locals()))
orig_img_lmrked = orig_img.copy()
LandmarksProcessor.draw_landmarks(orig_img_lmrked, lmrks, transparent_mask=True)
screen = np.concatenate ([orig_img_gc_before_layers_masked,
orig_img_gc_layers_masked,
orig_img,
orig_img_lmrked,
orig_img_result_pinked,
orig_img_result,
], axis=1)
io.show_image (wnd_name, screen.astype(np.uint8) )
while True:
io.process_messages()
for (x,y,ev,flags) in io.get_mouse_events(wnd_name):
pass
#print (x,y,ev,flags)
key_events = [ ev for ev, in io.get_key_events(wnd_name) ]
for key in key_events:
if key == ord('1'):
pass
if key == ord('2'):
pass
if key == ord('3'):
pass
if ord(' ') in key_events:
break
import code
code.interact(local=dict(globals(), **locals()))
#original_mask = np.ones(original_img.shape[:2],np.uint8)*2
#cv2.drawContours(original_mask, points, -1, (1,), 1)

View file

@ -0,0 +1,376 @@
import os
import sys
import time
import traceback
from pathlib import Path
import cv2
import numpy as np
import numpy.linalg as npl
import imagelib
from facelib import LandmarksProcessor
from imagelib import IEPolys
from interact import interact as io
from utils import Path_utils
from utils.cv2_utils import *
from utils.DFLJPG import DFLJPG
from utils.DFLPNG import DFLPNG
class MaskEditor:
STATE_NONE=0
STATE_MASKING=1
def __init__(self, img, mask=None, ie_polys=None, get_status_lines_func=None):
self.img = imagelib.normalize_channels (img,3)
h, w, c = img.shape
ph, pw = h // 4, w // 4
if mask is not None:
self.mask = imagelib.normalize_channels (mask,3)
else:
self.mask = np.zeros ( (h,w,3) )
self.get_status_lines_func = get_status_lines_func
self.state_prop = self.STATE_NONE
self.w, self.h = w, h
self.pw, self.ph = pw, ph
self.pwh = np.array([self.pw, self.ph])
self.pwh2 = np.array([self.pw*2, self.ph*2])
self.sw, self.sh = w+pw*2, h+ph*2
if ie_polys is None:
ie_polys = IEPolys()
self.ie_polys = ie_polys
self.polys_mask = None
self.screen_status_block = None
self.screen_status_block_dirty = True
def set_state(self, state):
self.state = state
@property
def state(self):
return self.state_prop
@state.setter
def state(self, value):
self.state_prop = value
if value == self.STATE_MASKING:
self.ie_polys.dirty = True
def get_mask(self):
if self.ie_polys.switch_dirty():
self.screen_status_block_dirty = True
self.ie_mask = img = self.mask.copy()
self.ie_polys.overlay_mask(img)
return img
return self.ie_mask
def get_screen_overlay(self):
img = np.zeros ( (self.sh, self.sw, 3) )
if self.state == self.STATE_MASKING:
mouse_xy = self.mouse_xy.copy() + self.pwh
l = self.ie_polys.n_list()
if l.n > 0:
p = l.cur_point().copy() + self.pwh
color = (0,1,0) if l.type == 1 else (0,0,1)
cv2.line(img, tuple(p), tuple(mouse_xy), color )
return img
def undo_to_begin_point(self):
while not self.undo_point():
pass
def undo_point(self):
if self.state == self.STATE_NONE:
if self.ie_polys.n > 0:
self.state = self.STATE_MASKING
if self.state == self.STATE_MASKING:
if self.ie_polys.n_list().n_dec() == 0 and \
self.ie_polys.n_dec() == 0:
self.state = self.STATE_NONE
else:
return False
return True
def redo_to_end_point(self):
while not self.redo_point():
pass
def redo_point(self):
if self.state == self.STATE_NONE:
if self.ie_polys.n_max > 0:
self.state = self.STATE_MASKING
if self.ie_polys.n == 0:
self.ie_polys.n_inc()
if self.state == self.STATE_MASKING:
while True:
l = self.ie_polys.n_list()
if l.n_inc() == l.n_max:
if self.ie_polys.n == self.ie_polys.n_max:
break
self.ie_polys.n_inc()
else:
return False
return True
def combine_screens(self, screens):
screens_len = len(screens)
new_screens = []
for screen, padded_overlay in screens:
screen_img = np.zeros( (self.sh, self.sw, 3), dtype=np.float32 )
screen = imagelib.normalize_channels (screen, 3)
h,w,c = screen.shape
screen_img[self.ph:-self.ph, self.pw:-self.pw, :] = screen
if padded_overlay is not None:
screen_img = screen_img + padded_overlay
screen_img = np.clip(screen_img*255, 0, 255).astype(np.uint8)
new_screens.append(screen_img)
return np.concatenate (new_screens, axis=1)
def get_screen_status_block(self, w, c):
if self.screen_status_block_dirty:
self.screen_status_block_dirty = False
lines = [
'Polys current/max = %d/%d' % (self.ie_polys.n, self.ie_polys.n_max),
]
if self.get_status_lines_func is not None:
lines += self.get_status_lines_func()
lines_count = len(lines)
h_line = 21
h = lines_count * h_line
img = np.ones ( (h,w,c) ) * 0.1
for i in range(lines_count):
img[ i*h_line:(i+1)*h_line, 0:w] += \
imagelib.get_text_image ( (h_line,w,c), lines[i], color=[0.8]*c )
self.screen_status_block = np.clip(img*255, 0, 255).astype(np.uint8)
return self.screen_status_block
def set_screen_status_block_dirty(self):
self.screen_status_block_dirty = True
def make_screen(self):
screen_overlay = self.get_screen_overlay()
final_mask = self.get_mask()
masked_img = self.img*final_mask*0.5 + self.img*(1-final_mask)
pink = np.full ( (self.h, self.w, 3), (1,0,1) )
pink_masked_img = self.img*final_mask + pink*(1-final_mask)
screens = [ (self.img, None),
(masked_img, screen_overlay),
(pink_masked_img, screen_overlay),
]
screens = self.combine_screens(screens)
status_img = self.get_screen_status_block( screens.shape[1], screens.shape[2] )
result = np.concatenate ( [screens, status_img], axis=0 )
return result
def mask_finish(self, n_clip=True):
if self.state == self.STATE_MASKING:
if self.ie_polys.n_list().n <= 2:
self.ie_polys.n_dec()
self.state = self.STATE_NONE
if n_clip:
self.ie_polys.n_clip()
def set_mouse_pos(self,x,y):
mouse_x = x % (self.sw) - self.pw
mouse_y = y % (self.sh) - self.ph
self.mouse_xy = np.array( [mouse_x, mouse_y] )
self.mouse_x, self.mouse_y = self.mouse_xy
def mask_point(self, type):
if self.state == self.STATE_MASKING and \
self.ie_polys.n_list().type != type:
self.mask_finish()
elif self.state == self.STATE_NONE:
self.state = self.STATE_MASKING
self.ie_polys.add(type)
if self.state == self.STATE_MASKING:
self.ie_polys.n_list().add (self.mouse_x, self.mouse_y)
def get_ie_polys(self):
return self.ie_polys
def mask_editor_main(input_dir, confirmed_dir=None, skipped_dir=None):
input_path = Path(input_dir)
confirmed_path = Path(confirmed_dir)
skipped_path = Path(skipped_dir)
if not input_path.exists():
raise ValueError('Input directory not found. Please ensure it exists.')
if not confirmed_path.exists():
confirmed_path.mkdir(parents=True)
if not skipped_path.exists():
skipped_path.mkdir(parents=True)
wnd_name = "MaskEditor tool"
io.named_window (wnd_name)
io.capture_mouse(wnd_name)
io.capture_keys(wnd_name)
image_paths = [ Path(x) for x in Path_utils.get_image_paths(input_path)]
done_paths = []
image_paths_total = len(image_paths)
is_exit = False
while not is_exit:
if len(image_paths) > 0:
filepath = image_paths.pop(0)
else:
filepath = None
if filepath is not None:
if filepath.suffix == '.png':
dflimg = DFLPNG.load( str(filepath) )
elif filepath.suffix == '.jpg':
dflimg = DFLJPG.load ( str(filepath) )
else:
dflimg = None
if dflimg is None:
io.log_err ("%s is not a dfl image file" % (filepath.name) )
continue
lmrks = dflimg.get_landmarks()
ie_polys = dflimg.get_ie_polys()
img = cv2_imread(str(filepath)) / 255.0
mask = LandmarksProcessor.get_image_hull_mask( img.shape, lmrks)
else:
img = np.zeros ( (256,256,3) )
mask = np.ones ( (256,256,3) )
ie_polys = None
def get_status_lines_func():
return ['Progress: %d / %d . Current file: %s' % (len(done_paths), image_paths_total, str(filepath.name) if filepath is not None else "end" ),
'[Left mouse button] - mark include mask.',
'[Right mouse button] - mark exclude mask.',
'[Middle mouse button] - finish current poly.',
'[Mouse wheel] - undo/redo poly or point. [+ctrl] - undo to begin/redo to end',
'[q] - prev image. [w] - skip and move to %s. [e] - save and move to %s. ' % (skipped_path.name, confirmed_path.name),
'[z] - prev image. [x] - skip. [c] - save. ',
'[esc] - quit'
]
ed = MaskEditor(img, mask, ie_polys, get_status_lines_func)
next = False
while not next:
io.process_messages(0.005)
for (x,y,ev,flags) in io.get_mouse_events(wnd_name):
ed.set_mouse_pos(x, y)
if filepath is not None:
if ev == io.EVENT_LBUTTONDOWN:
ed.mask_point(1)
elif ev == io.EVENT_RBUTTONDOWN:
ed.mask_point(0)
elif ev == io.EVENT_MBUTTONDOWN:
ed.mask_finish()
elif ev == io.EVENT_MOUSEWHEEL:
if flags & 0x80000000 != 0:
if flags & 0x8 != 0:
ed.undo_to_begin_point()
else:
ed.undo_point()
else:
if flags & 0x8 != 0:
ed.redo_to_end_point()
else:
ed.redo_point()
key_events = [ ev for ev, in io.get_key_events(wnd_name) ]
for key in key_events:
if key == ord('q') or key == ord('z'):
if len(done_paths) > 0:
image_paths.insert(0, filepath)
filepath = done_paths.pop(-1)
if filepath.parent != input_path:
new_filename_path = input_path / filepath.name
filepath.rename ( new_filename_path )
image_paths.insert(0, new_filename_path)
else:
image_paths.insert(0, filepath)
next = True
break
elif filepath is not None and ( key == ord('e') or key == ord('c') ):
ed.mask_finish()
dflimg.embed_and_set (str(filepath), ie_polys=ed.get_ie_polys() )
if key == ord('e'):
new_filename_path = confirmed_path / filepath.name
filepath.rename(new_filename_path)
done_paths += [new_filename_path]
else:
done_paths += [filepath]
next = True
break
elif filepath is not None and ( key == ord('w') or key == ord('x') ):
if key == ord('w'):
new_filename_path = skipped_path / filepath.name
filepath.rename(new_filename_path)
done_paths += [new_filename_path]
else:
done_paths += [filepath]
next = True
break
elif key == 27: #esc
is_exit = True
next = True
break
screen = ed.make_screen()
io.show_image (wnd_name, screen )
io.process_messages(0.005)
io.destroy_all_windows()

View file

@ -253,7 +253,7 @@ def main(args, device_args):
for i in range(0, len(head_lines)):
t = i*head_line_height
b = (i+1)*head_line_height
head[t:b, 0:w] += imagelib.get_text_image ( (w,head_line_height,c) , head_lines[i], color=[0.8]*c )
head[t:b, 0:w] += imagelib.get_text_image ( (head_line_height,w,c) , head_lines[i], color=[0.8]*c )
final = head

View file

@ -27,6 +27,7 @@ def convert_png_to_jpg_file (filepath):
DFLJPG.embed_data( new_filepath,
face_type=dfl_dict.get('face_type', None),
landmarks=dfl_dict.get('landmarks', None),
ie_polys=dfl_dict.get('ie_polys', None),
source_filename=dfl_dict.get('source_filename', None),
source_rect=dfl_dict.get('source_rect', None),
source_landmarks=dfl_dict.get('source_landmarks', None) )
@ -63,7 +64,7 @@ def add_landmarks_debug_images(input_path):
if img is not None:
face_landmarks = dflimg.get_landmarks()
LandmarksProcessor.draw_landmarks(img, face_landmarks, transparent_mask=True)
LandmarksProcessor.draw_landmarks(img, face_landmarks, transparent_mask=True, ie_polys=dflimg.get_ie_polys() )
output_file = '{}{}'.format( str(Path(str(input_path)) / filepath.stem), '_debug.jpg')
cv2_imwrite(output_file, img, [int(cv2.IMWRITE_JPEG_QUALITY), 50] )