added XSeg model.

With the XSeg model you can train your own mask segmenter for dst (and src) faces,
which will be used in the merger for whole_face.

Instead of using a pretrained model (which does not exist),
you control which part of faces should be masked.

The workflow is not easy, but at the moment it is the best solution
for obtaining the best quality of whole_face deepfakes with minimum effort
and without rotoscoping in After Effects.

new scripts:
	XSeg) data_dst edit.bat
	XSeg) data_dst merge.bat
	XSeg) data_dst split.bat
	XSeg) data_src edit.bat
	XSeg) data_src merge.bat
	XSeg) data_src split.bat
	XSeg) train.bat

Usage:
	unpack dst faceset if packed

	run XSeg) data_dst split.bat
		this script extracts (previously saved) .json data from the jpg faces for use in the label tool.

	run XSeg) data_dst edit.bat
		new tool 'labelme' is used

		use polygon (CTRL-N) to mask the face
			name polygon "1" (one symbol) as include polygon
			name polygon "0" (one symbol) as exclude polygon

			'exclude polygons' will be applied after all 'include polygons'

		Hot keys:
		ctrl-N			create polygon
		ctrl-J			edit polygon
		A/D 			navigate between frames
		ctrl + mousewheel 	image zoom
		mousewheel		vertical scroll
		alt+mousewheel		horizontal scroll

		repeat for 10/50/100 faces,
			you don't need to mask every frame of dst,
			only frames where the face differs significantly,
			for example:
				closed eyes
				changed head direction
				changed light
			the more varied the faces you mask, the better the quality you will get

			Start masking from the upper left area and follow the clockwise direction.
			Keep the same logic of masking for all frames, for example:
				the same approximated jaw line of the side faces, where the jaw is not visible
				the same hair line
			Mask the obstructions using polygon with name "0".

	run XSeg) data_dst merge.bat
		this script merges .json data of polygons into jpg faces,
		therefore faceset can be sorted or packed as usual.

	run XSeg) train.bat
		train the model

		Check the faces of 'XSeg dst faces' preview.

		if some faces have wrong or glitchy mask, then repeat steps:
			split
			run edit
			find these glitchy faces and mask them
			merge
			train further or restart training from scratch

Restarting the training of the XSeg model is only possible by deleting all 'model\XSeg_*' files.

If you want to get the mask of the predicted face in merger,
you should repeat the same steps for src faceset.

New mask modes available in merger for whole_face:

XSeg-prd	  - XSeg mask of predicted face	 -> faces from src faceset should be labeled
XSeg-dst	  - XSeg mask of dst face        -> faces from dst faceset should be labeled
XSeg-prd*XSeg-dst - the smallest area of both

If the workspace\model folder contains a trained XSeg model, the merger will use it;
otherwise you will get a transparent mask when using the XSeg-* modes.

Some screenshots:
label tool: https://i.imgur.com/aY6QGw1.jpg
trainer   : https://i.imgur.com/NM1Kn3s.jpg
merger    : https://i.imgur.com/glUzFQ8.jpg

example of the fake using 13 segmented dst faces
          : https://i.imgur.com/wmvyizU.gifv
This commit is contained in:
Colombo 2020-03-15 15:12:44 +04:00
parent 2be940092b
commit 45582d129d
27 changed files with 577 additions and 711 deletions

View file

@ -178,6 +178,7 @@ class DFLJPG(object):
def embed_data(filename, face_type=None, def embed_data(filename, face_type=None,
landmarks=None, landmarks=None,
ie_polys=None, ie_polys=None,
seg_ie_polys=None,
source_filename=None, source_filename=None,
source_rect=None, source_rect=None,
source_landmarks=None, source_landmarks=None,
@ -203,9 +204,14 @@ class DFLJPG(object):
if not isinstance(ie_polys, list): if not isinstance(ie_polys, list):
ie_polys = ie_polys.dump() ie_polys = ie_polys.dump()
if seg_ie_polys is not None:
if not isinstance(seg_ie_polys, list):
seg_ie_polys = seg_ie_polys.dump()
DFLJPG.embed_dfldict (filename, {'face_type': face_type, DFLJPG.embed_dfldict (filename, {'face_type': face_type,
'landmarks': landmarks, 'landmarks': landmarks,
'ie_polys' : ie_polys, 'ie_polys' : ie_polys,
'seg_ie_polys' : seg_ie_polys,
'source_filename': source_filename, 'source_filename': source_filename,
'source_rect': source_rect, 'source_rect': source_rect,
'source_landmarks': source_landmarks, 'source_landmarks': source_landmarks,
@ -218,6 +224,7 @@ class DFLJPG(object):
def embed_and_set(self, filename, face_type=None, def embed_and_set(self, filename, face_type=None,
landmarks=None, landmarks=None,
ie_polys=None, ie_polys=None,
seg_ie_polys=None,
source_filename=None, source_filename=None,
source_rect=None, source_rect=None,
source_landmarks=None, source_landmarks=None,
@ -230,6 +237,7 @@ class DFLJPG(object):
if face_type is None: face_type = self.get_face_type() if face_type is None: face_type = self.get_face_type()
if landmarks is None: landmarks = self.get_landmarks() if landmarks is None: landmarks = self.get_landmarks()
if ie_polys is None: ie_polys = self.get_ie_polys() if ie_polys is None: ie_polys = self.get_ie_polys()
if seg_ie_polys is None: seg_ie_polys = self.get_seg_ie_polys()
if source_filename is None: source_filename = self.get_source_filename() if source_filename is None: source_filename = self.get_source_filename()
if source_rect is None: source_rect = self.get_source_rect() if source_rect is None: source_rect = self.get_source_rect()
if source_landmarks is None: source_landmarks = self.get_source_landmarks() if source_landmarks is None: source_landmarks = self.get_source_landmarks()
@ -240,6 +248,7 @@ class DFLJPG(object):
DFLJPG.embed_data (filename, face_type=face_type, DFLJPG.embed_data (filename, face_type=face_type,
landmarks=landmarks, landmarks=landmarks,
ie_polys=ie_polys, ie_polys=ie_polys,
seg_ie_polys=seg_ie_polys,
source_filename=source_filename, source_filename=source_filename,
source_rect=source_rect, source_rect=source_rect,
source_landmarks=source_landmarks, source_landmarks=source_landmarks,
@ -251,6 +260,9 @@ class DFLJPG(object):
def remove_ie_polys(self): def remove_ie_polys(self):
self.dfl_dict['ie_polys'] = None self.dfl_dict['ie_polys'] = None
def remove_seg_ie_polys(self):
    """Clear the stored segmentation polygons by setting 'seg_ie_polys' to None."""
    metadata = self.dfl_dict
    metadata['seg_ie_polys'] = None
def remove_fanseg_mask(self): def remove_fanseg_mask(self):
self.dfl_dict['fanseg_mask'] = None self.dfl_dict['fanseg_mask'] = None
@ -308,6 +320,7 @@ class DFLJPG(object):
def get_face_type(self): return self.dfl_dict['face_type'] def get_face_type(self): return self.dfl_dict['face_type']
def get_landmarks(self): return np.array ( self.dfl_dict['landmarks'] ) def get_landmarks(self): return np.array ( self.dfl_dict['landmarks'] )
def get_ie_polys(self): return self.dfl_dict.get('ie_polys',None) def get_ie_polys(self): return self.dfl_dict.get('ie_polys',None)
def get_seg_ie_polys(self):
    """Return the stored 'seg_ie_polys' entry, or None when it is absent."""
    metadata = self.dfl_dict
    return metadata.get('seg_ie_polys')
def get_source_filename(self): return self.dfl_dict['source_filename'] def get_source_filename(self): return self.dfl_dict['source_filename']
def get_source_rect(self): return self.dfl_dict['source_rect'] def get_source_rect(self): return self.dfl_dict['source_rect']
def get_source_landmarks(self): return np.array ( self.dfl_dict['source_landmarks'] ) def get_source_landmarks(self): return np.array ( self.dfl_dict['source_landmarks'] )

View file

@ -89,6 +89,9 @@ class IEPolys:
if poly.n > 0: if poly.n > 0:
cv2.fillPoly(mask, [poly.points_to_n()], white if poly.type == 1 else black ) cv2.fillPoly(mask, [poly.points_to_n()], white if poly.type == 1 else black )
def get_total_points(self):
    """Return the summed point count (`.n`) of the first `self.n` polygons."""
    polys = self.list
    return sum(polys[idx].n for idx in range(self.n))
def dump(self): def dump(self):
result = [] result = []
for n in range(self.n): for n in range(self.n):

View file

@ -19,4 +19,8 @@ from .IEPolys import IEPolys
from .blursharpen import LinearMotionBlur, blursharpen from .blursharpen import LinearMotionBlur, blursharpen
from .filters import apply_random_hsv_shift, apply_random_motion_blur, apply_random_gaussian_blur, apply_random_bilinear_resize from .filters import apply_random_rgb_levels, \
apply_random_hsv_shift, \
apply_random_motion_blur, \
apply_random_gaussian_blur, \
apply_random_bilinear_resize

View file

@ -2,6 +2,27 @@ import numpy as np
from .blursharpen import LinearMotionBlur from .blursharpen import LinearMotionBlur
import cv2 import cv2
def apply_random_rgb_levels(img, mask=None, rnd_state=None):
    """Apply a random per-channel levels adjustment to `img`.

    Fifteen scalars are drawn from `rnd_state.rand` (or `np.random.rand` when
    `rnd_state` is None) to build input black/white points, a gamma and
    output black/white points, three values each. The image is normalised to
    the input range, gamma-corrected, remapped to the output range and
    clipped to [0, 1].

    img       -- image array; the last axis is broadcast against 3 values.
                 NOTE(review): presumably float data in [0, 1] -- confirm.
    mask      -- optional blend weight; where given, the result is
                 img*(1-mask) + adjusted*mask.
    rnd_state -- optional object exposing `rand()`.
    """
    source = np.random if rnd_state is None else rnd_state
    draw = source.rand

    def low_points():
        # Three draws, each scaled by 0.25.
        return np.array([draw()*0.25 for _ in range(3)], dtype=np.float32)

    def high_points():
        # Three draws, each mapped to 1.0 minus a quarter of the draw.
        return np.array([1.0-draw()*0.25 for _ in range(3)], dtype=np.float32)

    # The draw order below must stay as is so a seeded rnd_state reproduces
    # the same adjustment.
    in_black = low_points()
    in_white = high_points()
    in_gamma = np.array([0.5+draw() for _ in range(3)], dtype=np.float32)
    out_black = low_points()
    out_white = high_points()

    leveled = np.clip( (img - in_black) / (in_white - in_black), 0, 1 )
    leveled = ( leveled ** (1/in_gamma) ) * (out_white - out_black) + out_black
    leveled = np.clip(leveled, 0, 1)

    if mask is None:
        return leveled
    return img*(1-mask) + leveled*mask
def apply_random_hsv_shift(img, mask=None, rnd_state=None): def apply_random_hsv_shift(img, mask=None, rnd_state=None):
if rnd_state is None: if rnd_state is None:
rnd_state = np.random rnd_state = np.random

View file

@ -22,7 +22,11 @@ def circle_faded( hw, center, fade_dists ):
pts_dists = np.abs ( npla.norm(pts-center, axis=-1) ) pts_dists = np.abs ( npla.norm(pts-center, axis=-1) )
if fade_dists[1] == 0:
fade_dists[1] = 1
pts_dists = ( pts_dists - fade_dists[0] ) / fade_dists[1] pts_dists = ( pts_dists - fade_dists[0] ) / fade_dists[1]
pts_dists = np.clip( 1-pts_dists, 0, 1) pts_dists = np.clip( 1-pts_dists, 0, 1)
return pts_dists.reshape ( (h,w,1) ).astype(np.float32) return pts_dists.reshape ( (h,w,1) ).astype(np.float32)

View file

@ -1,151 +0,0 @@
from core.leras import nn
tf = nn.tf
class DFLSegnetArchi(nn.ArchiBase):
def __init__(self):
super().__init__()
class ConvBlock(nn.ModelBase):
def on_build(self, in_ch, out_ch):
self.conv = nn.Conv2D (in_ch, out_ch, kernel_size=3, padding='SAME')
self.frn = nn.FRNorm2D(out_ch)
self.tlu = nn.TLU(out_ch)
def forward(self, x):
x = self.conv(x)
x = self.frn(x)
x = self.tlu(x)
return x
class UpConvBlock(nn.ModelBase):
def on_build(self, in_ch, out_ch):
self.conv = nn.Conv2DTranspose (in_ch, out_ch, kernel_size=3, padding='SAME')
self.frn = nn.FRNorm2D(out_ch)
self.tlu = nn.TLU(out_ch)
def forward(self, x):
x = self.conv(x)
x = self.frn(x)
x = self.tlu(x)
return x
class Encoder(nn.ModelBase):
def on_build(self, in_ch, base_ch):
self.conv01 = ConvBlock(in_ch, base_ch)
self.conv02 = ConvBlock(base_ch, base_ch)
self.bp0 = nn.BlurPool (filt_size=3)
self.conv11 = ConvBlock(base_ch, base_ch*2)
self.conv12 = ConvBlock(base_ch*2, base_ch*2)
self.bp1 = nn.BlurPool (filt_size=3)
self.conv21 = ConvBlock(base_ch*2, base_ch*4)
self.conv22 = ConvBlock(base_ch*4, base_ch*4)
self.conv23 = ConvBlock(base_ch*4, base_ch*4)
self.bp2 = nn.BlurPool (filt_size=3)
self.conv31 = ConvBlock(base_ch*4, base_ch*8)
self.conv32 = ConvBlock(base_ch*8, base_ch*8)
self.conv33 = ConvBlock(base_ch*8, base_ch*8)
self.bp3 = nn.BlurPool (filt_size=3)
self.conv41 = ConvBlock(base_ch*8, base_ch*8)
self.conv42 = ConvBlock(base_ch*8, base_ch*8)
self.conv43 = ConvBlock(base_ch*8, base_ch*8)
self.bp4 = nn.BlurPool (filt_size=3)
self.conv_center = ConvBlock(base_ch*8, base_ch*8)
def forward(self, inp):
x = inp
x = self.conv01(x)
x = x0 = self.conv02(x)
x = self.bp0(x)
x = self.conv11(x)
x = x1 = self.conv12(x)
x = self.bp1(x)
x = self.conv21(x)
x = self.conv22(x)
x = x2 = self.conv23(x)
x = self.bp2(x)
x = self.conv31(x)
x = self.conv32(x)
x = x3 = self.conv33(x)
x = self.bp3(x)
x = self.conv41(x)
x = self.conv42(x)
x = x4 = self.conv43(x)
x = self.bp4(x)
x = self.conv_center(x)
return x0,x1,x2,x3,x4, x
class Decoder(nn.ModelBase):
def on_build(self, base_ch, out_ch):
self.up4 = UpConvBlock (base_ch*8, base_ch*4)
self.conv43 = ConvBlock(base_ch*12, base_ch*8)
self.conv42 = ConvBlock(base_ch*8, base_ch*8)
self.conv41 = ConvBlock(base_ch*8, base_ch*8)
self.up3 = UpConvBlock (base_ch*8, base_ch*4)
self.conv33 = ConvBlock(base_ch*12, base_ch*8)
self.conv32 = ConvBlock(base_ch*8, base_ch*8)
self.conv31 = ConvBlock(base_ch*8, base_ch*8)
self.up2 = UpConvBlock (base_ch*8, base_ch*4)
self.conv23 = ConvBlock(base_ch*8, base_ch*4)
self.conv22 = ConvBlock(base_ch*4, base_ch*4)
self.conv21 = ConvBlock(base_ch*4, base_ch*4)
self.up1 = UpConvBlock (base_ch*4, base_ch*2)
self.conv12 = ConvBlock(base_ch*4, base_ch*2)
self.conv11 = ConvBlock(base_ch*2, base_ch*2)
self.up0 = UpConvBlock (base_ch*2, base_ch)
self.conv02 = ConvBlock(base_ch*2, base_ch)
self.conv01 = ConvBlock(base_ch, base_ch)
self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME')
def forward(self, inp):
x0,x1,x2,x3,x4,x = inp
x = self.up4(x)
x = self.conv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis))
x = self.conv42(x)
x = self.conv41(x)
x = self.up3(x)
x = self.conv33(tf.concat([x,x3],axis=nn.conv2d_ch_axis))
x = self.conv32(x)
x = self.conv31(x)
x = self.up2(x)
x = self.conv23(tf.concat([x,x2],axis=nn.conv2d_ch_axis))
x = self.conv22(x)
x = self.conv21(x)
x = self.up1(x)
x = self.conv12(tf.concat([x,x1],axis=nn.conv2d_ch_axis))
x = self.conv11(x)
x = self.up0(x)
x = self.conv02(tf.concat([x,x0],axis=nn.conv2d_ch_axis))
x = self.conv01(x)
logits = self.out_conv(x)
return logits, tf.nn.sigmoid(logits)
self.Encoder = Encoder
self.Decoder = Decoder
nn.DFLSegnetArchi = DFLSegnetArchi

View file

@ -1,3 +1,2 @@
from .ArchiBase import * from .ArchiBase import *
from .DeepFakeArchi import * from .DeepFakeArchi import *
from .DFLSegnet import *

137
core/leras/models/XSeg.py Normal file
View file

@ -0,0 +1,137 @@
from core.leras import nn
tf = nn.tf
class XSeg(nn.ModelBase):
    """Encoder-decoder segmentation network with skip connections.

    Five encoder stages (each followed by a BlurPool layer) feed a centre
    block; five decoder stages each apply a transposed-conv block,
    concatenate the matching encoder output along `nn.conv2d_ch_axis` and
    refine it with conv blocks. `forward` returns `(logits, sigmoid(logits))`.
    """

    def on_build (self, in_ch, base_ch, out_ch):
        """Create all layers.

        in_ch   -- input channels of the first conv block.
        base_ch -- base channel width; deeper stages use multiples of it.
        out_ch  -- output channels of the final convolution.
        """

        class ConvBlock(nn.ModelBase):
            """3x3 Conv2D followed by FRNorm2D and TLU."""

            def on_build(self, in_ch, out_ch):
                self.conv = nn.Conv2D (in_ch, out_ch, kernel_size=3, padding='SAME')
                self.frn = nn.FRNorm2D(out_ch)
                self.tlu = nn.TLU(out_ch)

            def forward(self, x):
                return self.tlu(self.frn(self.conv(x)))

        class UpConvBlock(nn.ModelBase):
            """3x3 Conv2DTranspose followed by FRNorm2D and TLU."""

            def on_build(self, in_ch, out_ch):
                self.conv = nn.Conv2DTranspose (in_ch, out_ch, kernel_size=3, padding='SAME')
                self.frn = nn.FRNorm2D(out_ch)
                self.tlu = nn.TLU(out_ch)

            def forward(self, x):
                return self.tlu(self.frn(self.conv(x)))

        # Channel widths used throughout the network.
        w1 = base_ch
        w2 = base_ch*2
        w4 = base_ch*4
        w8 = base_ch*8
        w12 = base_ch*12

        # NOTE: attribute names and assignment order are kept exactly as in
        # the original, in case layer naming/weight ordering depends on them.

        # --- encoder ---
        self.conv01 = ConvBlock(in_ch, w1)
        self.conv02 = ConvBlock(w1, w1)
        self.bp0 = nn.BlurPool (filt_size=3)

        self.conv11 = ConvBlock(w1, w2)
        self.conv12 = ConvBlock(w2, w2)
        self.bp1 = nn.BlurPool (filt_size=3)

        self.conv21 = ConvBlock(w2, w4)
        self.conv22 = ConvBlock(w4, w4)
        self.conv23 = ConvBlock(w4, w4)
        self.bp2 = nn.BlurPool (filt_size=3)

        self.conv31 = ConvBlock(w4, w8)
        self.conv32 = ConvBlock(w8, w8)
        self.conv33 = ConvBlock(w8, w8)
        self.bp3 = nn.BlurPool (filt_size=3)

        self.conv41 = ConvBlock(w8, w8)
        self.conv42 = ConvBlock(w8, w8)
        self.conv43 = ConvBlock(w8, w8)
        self.bp4 = nn.BlurPool (filt_size=3)

        # --- decoder ---
        self.up4 = UpConvBlock (w8, w4)
        self.uconv43 = ConvBlock(w12, w8)
        self.uconv42 = ConvBlock(w8, w8)
        self.uconv41 = ConvBlock(w8, w8)

        self.up3 = UpConvBlock (w8, w4)
        self.uconv33 = ConvBlock(w12, w8)
        self.uconv32 = ConvBlock(w8, w8)
        self.uconv31 = ConvBlock(w8, w8)

        self.up2 = UpConvBlock (w8, w4)
        self.uconv23 = ConvBlock(w8, w4)
        self.uconv22 = ConvBlock(w4, w4)
        self.uconv21 = ConvBlock(w4, w4)

        self.up1 = UpConvBlock (w4, w2)
        self.uconv12 = ConvBlock(w4, w2)
        self.uconv11 = ConvBlock(w2, w2)

        self.up0 = UpConvBlock (w2, w1)
        self.uconv02 = ConvBlock(w2, w1)
        self.uconv01 = ConvBlock(w1, w1)
        self.out_conv = nn.Conv2D (w1, out_ch, kernel_size=3, padding='SAME')

        # Centre block sits between encoder and decoder in forward().
        self.conv_center = ConvBlock(w8, w8)

    def forward(self, inp):
        """Run the network on `inp`; return `(logits, tf.nn.sigmoid(logits))`."""
        ch_axis = nn.conv2d_ch_axis

        # Encoder: keep each stage's output for the decoder's skip connections.
        skip0 = self.conv02(self.conv01(inp))
        x = self.bp0(skip0)

        skip1 = self.conv12(self.conv11(x))
        x = self.bp1(skip1)

        skip2 = self.conv23(self.conv22(self.conv21(x)))
        x = self.bp2(skip2)

        skip3 = self.conv33(self.conv32(self.conv31(x)))
        x = self.bp3(skip3)

        skip4 = self.conv43(self.conv42(self.conv41(x)))
        x = self.bp4(skip4)

        x = self.conv_center(x)

        # Decoder: upsample, concatenate the matching skip, then refine.
        x = tf.concat([self.up4(x), skip4], axis=ch_axis)
        x = self.uconv41(self.uconv42(self.uconv43(x)))

        x = tf.concat([self.up3(x), skip3], axis=ch_axis)
        x = self.uconv31(self.uconv32(self.uconv33(x)))

        x = tf.concat([self.up2(x), skip2], axis=ch_axis)
        x = self.uconv21(self.uconv22(self.uconv23(x)))

        x = tf.concat([self.up1(x), skip1], axis=ch_axis)
        x = self.uconv11(self.uconv12(x))

        x = tf.concat([self.up0(x), skip0], axis=ch_axis)
        x = self.uconv01(self.uconv02(x))

        logits = self.out_conv(x)
        return logits, tf.nn.sigmoid(logits)
nn.XSeg = XSeg

View file

@ -2,3 +2,4 @@ from .ModelBase import *
from .PatchDiscriminator import * from .PatchDiscriminator import *
from .CodeDiscriminator import * from .CodeDiscriminator import *
from .Ternaus import * from .Ternaus import *
from .XSeg import *

View file

@ -10,7 +10,7 @@ from core.interact import interact as io
from core.leras import nn from core.leras import nn
class DFLSegNet(object): class XSegNet(object):
VERSION = 1 VERSION = 1
def __init__ (self, name, def __init__ (self, name,
@ -34,28 +34,24 @@ class DFLSegNet(object):
self.target_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,1) ) self.target_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,1) )
# Initializing model classes # Initializing model classes
archi = nn.DFLSegnetArchi()
with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'): with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'):
self.enc = archi.Encoder(3, 64, name='Encoder') self.model = nn.XSeg(3, 32, 1, name=name)
self.dec = archi.Decoder(64, 1, name='Decoder') self.model_weights = self.model.get_weights()
self.enc_dec_weights = self.enc.get_weights()+self.dec.get_weights()
model_name = f'{name}_{resolution}' model_name = f'{name}_{resolution}'
self.model_filename_list = [ [self.enc, f'{model_name}_enc.npy'], self.model_filename_list = [ [self.model, f'{model_name}.npy'] ]
[self.dec, f'{model_name}_dec.npy'],
]
if training: if training:
if optimizer is None: if optimizer is None:
raise ValueError("Optimizer should be provided for training mode.") raise ValueError("Optimizer should be provided for training mode.")
self.opt = optimizer self.opt = optimizer
self.opt.initialize_variables (self.enc_dec_weights, vars_on_cpu=place_model_on_cpu) self.opt.initialize_variables (self.model_weights, vars_on_cpu=place_model_on_cpu)
self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ] self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]
else: else:
with tf.device ('/CPU:0' if run_on_cpu else '/GPU:0'): with tf.device ('/CPU:0' if run_on_cpu else '/GPU:0'):
_, pred = self.dec(self.enc(self.input_t)) _, pred = self.model(self.input_t)
def net_run(input_np): def net_run(input_np):
return nn.tf_sess.run ( [pred], feed_dict={self.input_t :input_np})[0] return nn.tf_sess.run ( [pred], feed_dict={self.input_t :input_np})[0]
@ -72,10 +68,10 @@ class DFLSegNet(object):
model.init_weights() model.init_weights()
def flow(self, x): def flow(self, x):
return self.dec(self.enc(x)) return self.model(x)
def get_weights(self): def get_weights(self):
return self.enc_dec_weights return self.model_weights
def save_weights(self): def save_weights(self):
for model, filename in io.progress_bar_generator(self.model_filename_list, "Saving", leave=False): for model, filename in io.progress_bar_generator(self.model_filename_list, "Saving", leave=False):

View file

@ -3,4 +3,4 @@ from .S3FDExtractor import S3FDExtractor
from .FANExtractor import FANExtractor from .FANExtractor import FANExtractor
from .FaceEnhancer import FaceEnhancer from .FaceEnhancer import FaceEnhancer
from .TernausNet import TernausNet from .TernausNet import TernausNet
from .DFLSegNet import DFLSegNet from .XSegNet import XSegNet

121
main.py
View file

@ -55,86 +55,6 @@ if __name__ == "__main__":
p.set_defaults (func=process_extract) p.set_defaults (func=process_extract)
def process_dev_extract_vggface2_dataset(arguments):
osex.set_process_lowest_prio()
from mainscripts import dev_misc
dev_misc.extract_vggface2_dataset( arguments.input_dir,
device_args={'cpu_only' : arguments.cpu_only,
'multi_gpu' : arguments.multi_gpu,
}
)
p = subparsers.add_parser( "dev_extract_vggface2_dataset", help="")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
p.add_argument('--multi-gpu', action="store_true", dest="multi_gpu", default=False, help="Enables multi GPU.")
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Extract on CPU.")
p.set_defaults (func=process_dev_extract_vggface2_dataset)
def process_dev_extract_umd_csv(arguments):
osex.set_process_lowest_prio()
from mainscripts import dev_misc
dev_misc.extract_umd_csv( arguments.input_csv_file,
device_args={'cpu_only' : arguments.cpu_only,
'multi_gpu' : arguments.multi_gpu,
}
)
p = subparsers.add_parser( "dev_extract_umd_csv", help="")
p.add_argument('--input-csv-file', required=True, action=fixPathAction, dest="input_csv_file", help="input_csv_file")
p.add_argument('--multi-gpu', action="store_true", dest="multi_gpu", default=False, help="Enables multi GPU.")
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Extract on CPU.")
p.set_defaults (func=process_dev_extract_umd_csv)
def process_dev_apply_celebamaskhq(arguments):
osex.set_process_lowest_prio()
from mainscripts import dev_misc
dev_misc.apply_celebamaskhq( arguments.input_dir )
p = subparsers.add_parser( "dev_apply_celebamaskhq", help="")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
p.set_defaults (func=process_dev_apply_celebamaskhq)
def process_dev_test(arguments):
osex.set_process_lowest_prio()
from mainscripts import dev_misc
dev_misc.dev_test( arguments.input_dir )
p = subparsers.add_parser( "dev_test", help="")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
p.set_defaults (func=process_dev_test)
def process_dev_segmented_extract(arguments):
osex.set_process_lowest_prio()
from mainscripts import dev_misc
dev_misc.dev_segmented_extract(arguments.input_dir, arguments.output_dir)
p = subparsers.add_parser( "dev_segmented_extract", help="")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
p.add_argument('--output-dir', required=True, action=fixPathAction, dest="output_dir")
p.set_defaults (func=process_dev_segmented_extract)
def process_dev_segmented_trash(arguments):
osex.set_process_lowest_prio()
from mainscripts import dev_misc
dev_misc.dev_segmented_trash(arguments.input_dir)
p = subparsers.add_parser( "dev_segmented_trash", help="")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
p.set_defaults (func=process_dev_segmented_trash)
def process_dev_resave_pngs(arguments):
osex.set_process_lowest_prio()
from mainscripts import dev_misc
dev_misc.dev_resave_pngs(arguments.input_dir)
p = subparsers.add_parser( "dev_resave_pngs", help="")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
p.set_defaults (func=process_dev_resave_pngs)
def process_sort(arguments): def process_sort(arguments):
osex.set_process_lowest_prio() osex.set_process_lowest_prio()
from mainscripts import Sorter from mainscripts import Sorter
@ -342,27 +262,36 @@ if __name__ == "__main__":
p.set_defaults(func=process_faceset_enhancer) p.set_defaults(func=process_faceset_enhancer)
""" def process_dev_test(arguments):
def process_relight_faceset(arguments):
osex.set_process_lowest_prio() osex.set_process_lowest_prio()
from mainscripts import FacesetRelighter from mainscripts import dev_misc
FacesetRelighter.relight (arguments.input_dir, arguments.lighten, arguments.random_one) dev_misc.dev_test( arguments.input_dir )
def process_delete_relighted(arguments): p = subparsers.add_parser( "dev_test", help="")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
p.set_defaults (func=process_dev_test)
# ========== XSeg util
xseg_parser = subparsers.add_parser( "xseg", help="XSeg utils.").add_subparsers()
def process_xseg_merge(arguments):
osex.set_process_lowest_prio() osex.set_process_lowest_prio()
from mainscripts import FacesetRelighter from mainscripts import XSegUtil
FacesetRelighter.delete_relighted (arguments.input_dir) XSegUtil.merge(arguments.input_dir)
p = xseg_parser.add_parser( "merge", help="")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
p = facesettool_parser.add_parser ("relight", help="Synthesize new faces from existing ones by relighting them. With the relighted faces neural network will better reproduce face shadows.") p.set_defaults (func=process_xseg_merge)
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.")
p.add_argument('--lighten', action="store_true", dest="lighten", default=None, help="Lighten the faces.")
p.add_argument('--random-one', action="store_true", dest="random_one", default=None, help="Relight the faces only with one random direction, otherwise relight with all directions.")
p.set_defaults(func=process_relight_faceset)
p = facesettool_parser.add_parser ("delete_relighted", help="Delete relighted faces.") def process_xseg_split(arguments):
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.") osex.set_process_lowest_prio()
p.set_defaults(func=process_delete_relighted) from mainscripts import XSegUtil
""" XSegUtil.split(arguments.input_dir)
p = xseg_parser.add_parser( "split", help="")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
p.set_defaults (func=process_xseg_split)
def bad_args(arguments): def bad_args(arguments):
parser.print_help() parser.print_help()

View file

@ -12,7 +12,7 @@ from core.interact import interact as io
from core.joblib import MPClassFuncOnDemand, MPFunc from core.joblib import MPClassFuncOnDemand, MPFunc
from core.leras import nn from core.leras import nn
from DFLIMG import DFLIMG from DFLIMG import DFLIMG
from facelib import FaceEnhancer, FaceType, LandmarksProcessor, TernausNet, DFLSegNet from facelib import FaceEnhancer, FaceType, LandmarksProcessor, TernausNet, XSegNet
from merger import FrameInfo, MergerConfig, InteractiveMergerSubprocessor from merger import FrameInfo, MergerConfig, InteractiveMergerSubprocessor
def main (model_class_name=None, def main (model_class_name=None,
@ -61,9 +61,10 @@ def main (model_class_name=None,
place_model_on_cpu=True, place_model_on_cpu=True,
run_on_cpu=run_on_cpu) run_on_cpu=run_on_cpu)
skinseg_256_extract_func = MPClassFuncOnDemand(DFLSegNet, 'extract', xseg_256_extract_func = MPClassFuncOnDemand(XSegNet, 'extract',
name='SkinSeg', name='XSeg',
resolution=256, resolution=256,
weights_file_root=saved_models_path,
place_model_on_cpu=True, place_model_on_cpu=True,
run_on_cpu=run_on_cpu) run_on_cpu=run_on_cpu)
@ -199,7 +200,7 @@ def main (model_class_name=None,
predictor_input_shape = predictor_input_shape, predictor_input_shape = predictor_input_shape,
face_enhancer_func = face_enhancer_func, face_enhancer_func = face_enhancer_func,
fanseg_full_face_256_extract_func = fanseg_full_face_256_extract_func, fanseg_full_face_256_extract_func = fanseg_full_face_256_extract_func,
skinseg_256_extract_func = skinseg_256_extract_func, xseg_256_extract_func = xseg_256_extract_func,
merger_config = cfg, merger_config = cfg,
frames = frames, frames = frames,
frames_root_path = input_path, frames_root_path = input_path,

View file

@ -67,7 +67,7 @@ class InteractiveMergerSubprocessor(Subprocessor):
self.predictor_input_shape = client_dict['predictor_input_shape'] self.predictor_input_shape = client_dict['predictor_input_shape']
self.face_enhancer_func = client_dict['face_enhancer_func'] self.face_enhancer_func = client_dict['face_enhancer_func']
self.fanseg_full_face_256_extract_func = client_dict['fanseg_full_face_256_extract_func'] self.fanseg_full_face_256_extract_func = client_dict['fanseg_full_face_256_extract_func']
self.skinseg_256_extract_func = client_dict['skinseg_256_extract_func'] self.xseg_256_extract_func = client_dict['xseg_256_extract_func']
#transfer and set stdin in order to work code.interact in debug subprocess #transfer and set stdin in order to work code.interact in debug subprocess
@ -104,7 +104,7 @@ class InteractiveMergerSubprocessor(Subprocessor):
final_img = MergeMasked (self.predictor_func, self.predictor_input_shape, final_img = MergeMasked (self.predictor_func, self.predictor_input_shape,
face_enhancer_func=self.face_enhancer_func, face_enhancer_func=self.face_enhancer_func,
fanseg_full_face_256_extract_func=self.fanseg_full_face_256_extract_func, fanseg_full_face_256_extract_func=self.fanseg_full_face_256_extract_func,
skinseg_256_extract_func=self.skinseg_256_extract_func, xseg_256_extract_func=self.xseg_256_extract_func,
cfg=cfg, cfg=cfg,
frame_info=frame_info) frame_info=frame_info)
except Exception as e: except Exception as e:
@ -137,7 +137,7 @@ class InteractiveMergerSubprocessor(Subprocessor):
#override #override
def __init__(self, is_interactive, merger_session_filepath, predictor_func, predictor_input_shape, face_enhancer_func, fanseg_full_face_256_extract_func, skinseg_256_extract_func, merger_config, frames, frames_root_path, output_path, output_mask_path, model_iter): def __init__(self, is_interactive, merger_session_filepath, predictor_func, predictor_input_shape, face_enhancer_func, fanseg_full_face_256_extract_func, xseg_256_extract_func, merger_config, frames, frames_root_path, output_path, output_mask_path, model_iter):
if len (frames) == 0: if len (frames) == 0:
raise ValueError ("len (frames) == 0") raise ValueError ("len (frames) == 0")
@ -152,7 +152,7 @@ class InteractiveMergerSubprocessor(Subprocessor):
self.face_enhancer_func = face_enhancer_func self.face_enhancer_func = face_enhancer_func
self.fanseg_full_face_256_extract_func = fanseg_full_face_256_extract_func self.fanseg_full_face_256_extract_func = fanseg_full_face_256_extract_func
self.skinseg_256_extract_func = skinseg_256_extract_func self.xseg_256_extract_func = xseg_256_extract_func
self.frames_root_path = frames_root_path self.frames_root_path = frames_root_path
self.output_path = output_path self.output_path = output_path
@ -274,7 +274,7 @@ class InteractiveMergerSubprocessor(Subprocessor):
'predictor_input_shape' : self.predictor_input_shape, 'predictor_input_shape' : self.predictor_input_shape,
'face_enhancer_func': self.face_enhancer_func, 'face_enhancer_func': self.face_enhancer_func,
'fanseg_full_face_256_extract_func' : self.fanseg_full_face_256_extract_func, 'fanseg_full_face_256_extract_func' : self.fanseg_full_face_256_extract_func,
'skinseg_256_extract_func' : self.skinseg_256_extract_func, 'xseg_256_extract_func' : self.xseg_256_extract_func,
'stdin_fd': sys.stdin.fileno() if MERGER_DEBUG else None 'stdin_fd': sys.stdin.fileno() if MERGER_DEBUG else None
} }

View file

@ -9,12 +9,12 @@ from core.interact import interact as io
from core.cv2ex import * from core.cv2ex import *
fanseg_input_size = 256 fanseg_input_size = 256
skinseg_input_size = 256 xseg_input_size = 256
def MergeMaskedFace (predictor_func, predictor_input_shape, def MergeMaskedFace (predictor_func, predictor_input_shape,
face_enhancer_func, face_enhancer_func,
fanseg_full_face_256_extract_func, fanseg_full_face_256_extract_func,
skinseg_256_extract_func, xseg_256_extract_func,
cfg, frame_info, img_bgr_uint8, img_bgr, img_face_landmarks): cfg, frame_info, img_bgr_uint8, img_bgr, img_face_landmarks):
img_size = img_bgr.shape[1], img_bgr.shape[0] img_size = img_bgr.shape[1], img_bgr.shape[0]
img_face_mask_a = LandmarksProcessor.get_image_hull_mask (img_bgr.shape, img_face_landmarks) img_face_mask_a = LandmarksProcessor.get_image_hull_mask (img_bgr.shape, img_face_landmarks)
@ -111,23 +111,23 @@ def MergeMaskedFace (predictor_func, predictor_input_shape,
elif cfg.mask_mode >= 8 and cfg.mask_mode <= 11: elif cfg.mask_mode >= 8 and cfg.mask_mode <= 11:
if cfg.mask_mode == 8 or cfg.mask_mode == 10 or cfg.mask_mode == 11: if cfg.mask_mode == 8 or cfg.mask_mode == 10 or cfg.mask_mode == 11:
prd_face_skinseg_bgr = cv2.resize (prd_face_bgr, (skinseg_input_size,)*2 ) prd_face_xseg_bgr = cv2.resize (prd_face_bgr, (xseg_input_size,)*2, cv2.INTER_CUBIC)
prd_face_skinseg_mask = skinseg_256_extract_func(prd_face_skinseg_bgr) prd_face_xseg_mask = xseg_256_extract_func(prd_face_xseg_bgr)
X_prd_face_mask_a_0 = cv2.resize ( prd_face_skinseg_mask, (output_size, output_size), cv2.INTER_CUBIC) X_prd_face_mask_a_0 = cv2.resize ( prd_face_xseg_mask, (output_size, output_size), cv2.INTER_CUBIC)
if cfg.mask_mode >= 9 and cfg.mask_mode <= 11: if cfg.mask_mode >= 9 and cfg.mask_mode <= 11:
whole_face_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, skinseg_input_size, face_type=FaceType.WHOLE_FACE) whole_face_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, xseg_input_size, face_type=FaceType.WHOLE_FACE)
dst_face_skinseg_bgr = cv2.warpAffine(img_bgr, whole_face_mat, (skinseg_input_size,)*2, flags=cv2.INTER_CUBIC ) dst_face_xseg_bgr = cv2.warpAffine(img_bgr, whole_face_mat, (xseg_input_size,)*2, flags=cv2.INTER_CUBIC )
dst_face_skinseg_mask = skinseg_256_extract_func(dst_face_skinseg_bgr) dst_face_xseg_mask = xseg_256_extract_func(dst_face_xseg_bgr)
X_dst_face_mask_a_0 = cv2.resize (dst_face_skinseg_mask, (output_size,output_size), cv2.INTER_CUBIC) X_dst_face_mask_a_0 = cv2.resize (dst_face_xseg_mask, (output_size,output_size), cv2.INTER_CUBIC)
if cfg.mask_mode == 8: #'SkinSeg-prd', if cfg.mask_mode == 8: #'XSeg-prd',
prd_face_mask_a_0 = X_prd_face_mask_a_0 prd_face_mask_a_0 = X_prd_face_mask_a_0
elif cfg.mask_mode == 9: #'SkinSeg-dst', elif cfg.mask_mode == 9: #'XSeg-dst',
prd_face_mask_a_0 = X_dst_face_mask_a_0 prd_face_mask_a_0 = X_dst_face_mask_a_0
elif cfg.mask_mode == 10: #'SkinSeg-prd*SkinSeg-dst', elif cfg.mask_mode == 10: #'XSeg-prd*XSeg-dst',
prd_face_mask_a_0 = X_prd_face_mask_a_0 * X_dst_face_mask_a_0 prd_face_mask_a_0 = X_prd_face_mask_a_0 * X_dst_face_mask_a_0
elif cfg.mask_mode == 11: #learned*SkinSeg-prd*SkinSeg-dst' elif cfg.mask_mode == 11: #learned*XSeg-prd*XSeg-dst'
prd_face_mask_a_0 = prd_face_mask_a_0 * X_prd_face_mask_a_0 * X_dst_face_mask_a_0 prd_face_mask_a_0 = prd_face_mask_a_0 * X_prd_face_mask_a_0 * X_dst_face_mask_a_0
prd_face_mask_a_0[ prd_face_mask_a_0 < (1.0/255.0) ] = 0.0 # get rid of noise prd_face_mask_a_0[ prd_face_mask_a_0 < (1.0/255.0) ] = 0.0 # get rid of noise
@ -347,7 +347,7 @@ def MergeMasked (predictor_func,
predictor_input_shape, predictor_input_shape,
face_enhancer_func, face_enhancer_func,
fanseg_full_face_256_extract_func, fanseg_full_face_256_extract_func,
skinseg_256_extract_func, xseg_256_extract_func,
cfg, cfg,
frame_info): frame_info):
img_bgr_uint8 = cv2_imread(frame_info.filepath) img_bgr_uint8 = cv2_imread(frame_info.filepath)
@ -356,7 +356,7 @@ def MergeMasked (predictor_func,
outs = [] outs = []
for face_num, img_landmarks in enumerate( frame_info.landmarks_list ): for face_num, img_landmarks in enumerate( frame_info.landmarks_list ):
out_img, out_img_merging_mask = MergeMaskedFace (predictor_func, predictor_input_shape, face_enhancer_func, fanseg_full_face_256_extract_func, skinseg_256_extract_func, cfg, frame_info, img_bgr_uint8, img_bgr, img_landmarks) out_img, out_img_merging_mask = MergeMaskedFace (predictor_func, predictor_input_shape, face_enhancer_func, fanseg_full_face_256_extract_func, xseg_256_extract_func, cfg, frame_info, img_bgr_uint8, img_bgr, img_landmarks)
outs += [ (out_img, out_img_merging_mask) ] outs += [ (out_img, out_img_merging_mask) ]
#Combining multiple face outputs #Combining multiple face outputs

View file

@ -83,6 +83,7 @@ mode_str_dict = {}
for key in mode_dict.keys(): for key in mode_dict.keys():
mode_str_dict[ mode_dict[key] ] = key mode_str_dict[ mode_dict[key] ] = key
"""
whole_face_mask_mode_dict = {1:'learned', whole_face_mask_mode_dict = {1:'learned',
2:'dst', 2:'dst',
3:'FAN-prd', 3:'FAN-prd',
@ -91,11 +92,14 @@ whole_face_mask_mode_dict = {1:'learned',
6:'learned*FAN-prd*FAN-dst' 6:'learned*FAN-prd*FAN-dst'
} }
""" """
8:'SkinSeg-prd', whole_face_mask_mode_dict = {1:'learned',
9:'SkinSeg-dst', 2:'dst',
10:'SkinSeg-prd*SkinSeg-dst', 8:'XSeg-prd',
11:'learned*SkinSeg-prd*SkinSeg-dst' 9:'XSeg-dst',
""" 10:'XSeg-prd*XSeg-dst',
11:'learned*XSeg-prd*XSeg-dst'
}
full_face_mask_mode_dict = {1:'learned', full_face_mask_mode_dict = {1:'learned',
2:'dst', 2:'dst',
3:'FAN-prd', 3:'FAN-prd',

View file

@ -32,6 +32,7 @@ class ModelBase(object):
force_gpu_idxs=None, force_gpu_idxs=None,
cpu_only=False, cpu_only=False,
debug=False, debug=False,
force_model_class_name=None,
**kwargs): **kwargs):
self.is_training = is_training self.is_training = is_training
self.saved_models_path = saved_models_path self.saved_models_path = saved_models_path
@ -44,6 +45,7 @@ class ModelBase(object):
self.model_class_name = model_class_name = Path(inspect.getmodule(self).__file__).parent.name.rsplit("_", 1)[1] self.model_class_name = model_class_name = Path(inspect.getmodule(self).__file__).parent.name.rsplit("_", 1)[1]
if force_model_class_name is None:
if force_model_name is not None: if force_model_name is not None:
self.model_name = force_model_name self.model_name = force_model_name
else: else:
@ -117,7 +119,10 @@ class ModelBase(object):
self.model_name = self.model_name.replace('_', ' ') self.model_name = self.model_name.replace('_', ' ')
break break
self.model_name = self.model_name + '_' + self.model_class_name self.model_name = self.model_name + '_' + self.model_class_name
else:
self.model_name = force_model_class_name
self.iter = 0 self.iter = 0
self.options = {} self.options = {}

View file

@ -13,6 +13,9 @@ from samplelib import *
class FANSegModel(ModelBase): class FANSegModel(ModelBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, force_model_class_name='FANSeg', **kwargs)
#override #override
def on_initialize_options(self): def on_initialize_options(self):
device_config = nn.getCurrentDeviceConfig() device_config = nn.getCurrentDeviceConfig()
@ -48,7 +51,7 @@ class FANSegModel(ModelBase):
mask_shape = nn.get4Dshape(resolution,resolution,1) mask_shape = nn.get4Dshape(resolution,resolution,1)
# Initializing model classes # Initializing model classes
self.model = TernausNet(f'{self.model_name}_FANSeg_{FaceType.toString(self.face_type)}', self.model = TernausNet(f'FANSeg_{FaceType.toString(self.face_type)}',
resolution, resolution,
load_weights=not self.is_first_run(), load_weights=not self.is_first_run(),
weights_file_root=self.get_model_root_path(), weights_file_root=self.get_model_root_path(),
@ -117,14 +120,14 @@ class FANSegModel(ModelBase):
src_generator = SampleGeneratorFace(training_data_src_path, random_ct_samples_path=training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), src_generator = SampleGeneratorFace(training_data_src_path, random_ct_samples_path=training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=True), sample_process_options=SampleProcessor.Options(random_flip=True),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'ct_mode':'lct', 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'motion_blur':(25, 5), 'gaussian_blur':(25,5), 'data_format':nn.data_format, 'resolution': resolution}, output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'ct_mode':'lct', 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'random_motion_blur':(25, 5), 'random_gaussian_blur':(25,5), 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
], ],
generators_count=src_generators_count ) generators_count=src_generators_count )
dst_generator = SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), dst_generator = SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=True), sample_process_options=SampleProcessor.Options(random_flip=True),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'motion_blur':(25, 5), 'gaussian_blur':(25,5), 'data_format':nn.data_format, 'resolution': resolution}, output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
], ],
generators_count=dst_generators_count, generators_count=dst_generators_count,
raise_on_no_data=False ) raise_on_no_data=False )

View file

@ -7,28 +7,18 @@ import numpy as np
from core import mathlib from core import mathlib
from core.interact import interact as io from core.interact import interact as io
from core.leras import nn from core.leras import nn
from facelib import FaceType, TernausNet, DFLSegNet from facelib import FaceType, TernausNet, XSegNet
from models import ModelBase from models import ModelBase
from samplelib import * from samplelib import *
class SkinSegModel(ModelBase): class XSegModel(ModelBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, force_model_class_name='XSeg', **kwargs)
#override #override
def on_initialize_options(self): def on_initialize_options(self):
device_config = nn.getCurrentDeviceConfig() self.set_batch_size(4)
yn_str = {True:'y',False:'n'}
ask_override = self.ask_override()
if self.is_first_run() or ask_override:
self.ask_autobackup_hour()
self.ask_write_preview_history()
self.ask_target_iter()
self.ask_batch_size(8)
default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', False)
if self.is_first_run() or ask_override:
self.options['lr_dropout'] = io.input_bool ("Use learning rate dropout", default_lr_dropout, help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations.")
#override #override
def on_initialize(self): def on_initialize(self):
@ -43,20 +33,20 @@ class SkinSegModel(ModelBase):
self.resolution = resolution = 256 self.resolution = resolution = 256
self.face_type = FaceType.WHOLE_FACE self.face_type = FaceType.WHOLE_FACE
place_model_on_cpu = True #len(devices) == 0 place_model_on_cpu = len(devices) == 0
models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0' models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0'
bgr_shape = nn.get4Dshape(resolution,resolution,3) bgr_shape = nn.get4Dshape(resolution,resolution,3)
mask_shape = nn.get4Dshape(resolution,resolution,1) mask_shape = nn.get4Dshape(resolution,resolution,1)
# Initializing model classes # Initializing model classes
self.model = DFLSegNet(name=f'{self.model_name}_SkinSeg', self.model = XSegNet(name=f'XSeg',
resolution=resolution, resolution=resolution,
load_weights=not self.is_first_run(), load_weights=not self.is_first_run(),
weights_file_root=self.get_model_root_path(), weights_file_root=self.get_model_root_path(),
training=True, training=True,
place_model_on_cpu=place_model_on_cpu, place_model_on_cpu=place_model_on_cpu,
optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3 if self.options['lr_dropout'] else 1.0, name='opt'), optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3, name='opt'),
data_format=nn.data_format) data_format=nn.data_format)
if self.is_training: if self.is_training:
@ -111,38 +101,33 @@ class SkinSegModel(ModelBase):
# initializing sample generators # initializing sample generators
cpu_count = min(multiprocessing.cpu_count(), 8) cpu_count = min(multiprocessing.cpu_count(), 8)
src_dst_generators_count = cpu_count // 2
src_generators_count = cpu_count // 2 src_generators_count = cpu_count // 2
dst_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2
src_generators_count = int(src_generators_count * 1.5)
"""
src_generator = SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), srcdst_generator = SampleGeneratorFaceXSeg([self.training_data_src_path, self.training_data_dst_path],
sample_process_options=SampleProcessor.Options(random_flip=True),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR_RANDOM_HSV_SHIFT, 'border_replicate':False, 'face_type':self.face_type, 'motion_blur':(25, 5), 'gaussian_blur':(25,5), 'random_bilinear_resize':(25,75), 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.NONE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
],
generators_count=src_generators_count )
"""
src_generator = SampleGeneratorFaceSkinSegDataset(self.training_data_src_path,
debug=self.is_debug(), debug=self.is_debug(),
batch_size=self.get_batch_size(), batch_size=self.get_batch_size(),
resolution=resolution, resolution=resolution,
face_type=self.face_type, face_type=self.face_type,
generators_count=src_generators_count, generators_count=src_dst_generators_count,
data_format=nn.data_format) data_format=nn.data_format)
src_generator = SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=False),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
],
generators_count=src_generators_count,
raise_on_no_data=False )
dst_generator = SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), dst_generator = SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=True), sample_process_options=SampleProcessor.Options(random_flip=False),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'motion_blur':(25, 5), 'gaussian_blur':(25,5), 'random_bilinear_resize':(25,75), 'data_format':nn.data_format, 'resolution': resolution}, output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
], ],
generators_count=dst_generators_count, generators_count=dst_generators_count,
raise_on_no_data=False ) raise_on_no_data=False )
self.set_training_data_generators ([srcdst_generator, src_generator, dst_generator])
if not dst_generator.is_initialized():
io.log_info(f"\nTo view the model on unseen faces, place any image faces in {self.training_data_dst_path}.\n")
self.set_training_data_generators ([src_generator, dst_generator])
#override #override
def get_model_filename_list(self): def get_model_filename_list(self):
@ -154,6 +139,8 @@ class SkinSegModel(ModelBase):
#override #override
def onTrainOneIter(self): def onTrainOneIter(self):
image_np, mask_np = self.generate_next_samples()[0] image_np, mask_np = self.generate_next_samples()[0]
loss = self.train (image_np, mask_np) loss = self.train (image_np, mask_np)
@ -163,8 +150,8 @@ class SkinSegModel(ModelBase):
def onGetPreview(self, samples): def onGetPreview(self, samples):
n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) n_samples = min(4, self.get_batch_size(), 800 // self.resolution )
src_samples, dst_samples = samples srcdst_samples, src_samples, dst_samples = samples
image_np, mask_np = src_samples image_np, mask_np = srcdst_samples
I, M, IM, = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([image_np,mask_np] + self.view (image_np) ) ] I, M, IM, = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([image_np,mask_np] + self.view (image_np) ) ]
M, IM, = [ np.repeat (x, (3,), -1) for x in [M, IM] ] M, IM, = [ np.repeat (x, (3,), -1) for x in [M, IM] ]
@ -174,9 +161,23 @@ class SkinSegModel(ModelBase):
result = [] result = []
st = [] st = []
for i in range(n_samples): for i in range(n_samples):
ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i] + green_bg*(1-IM[i]) ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i]+0.5*I[i]*(1-IM[i]) + 0.5*green_bg*(1-IM[i])
st.append ( np.concatenate ( ar, axis=1) ) st.append ( np.concatenate ( ar, axis=1) )
result += [ ('SkinSeg training faces', np.concatenate (st, axis=0 )), ] result += [ ('XSeg training faces', np.concatenate (st, axis=0 )), ]
if len(src_samples) != 0:
src_np, = src_samples
D, DM, = [ np.clip(nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([src_np] + self.view (src_np) ) ]
DM, = [ np.repeat (x, (3,), -1) for x in [DM] ]
st = []
for i in range(n_samples):
ar = D[i], DM[i], D[i]*DM[i] + 0.5*D[i]*(1-DM[i]) + 0.5*green_bg*(1-DM[i])
st.append ( np.concatenate ( ar, axis=1) )
result += [ ('XSeg src faces', np.concatenate (st, axis=0 )), ]
if len(dst_samples) != 0: if len(dst_samples) != 0:
dst_np, = dst_samples dst_np, = dst_samples
@ -187,11 +188,11 @@ class SkinSegModel(ModelBase):
st = [] st = []
for i in range(n_samples): for i in range(n_samples):
ar = D[i], DM[i], D[i]*DM[i]+ green_bg*(1-DM[i]) ar = D[i], DM[i], D[i]*DM[i] + 0.5*D[i]*(1-DM[i]) + 0.5*green_bg*(1-DM[i])
st.append ( np.concatenate ( ar, axis=1) ) st.append ( np.concatenate ( ar, axis=1) )
result += [ ('SkinSeg unseen faces', np.concatenate (st, axis=0 )), ] result += [ ('XSeg dst faces', np.concatenate (st, axis=0 )), ]
return result return result
Model = SkinSegModel Model = XSegModel

View file

@ -6,4 +6,5 @@ ffmpeg-python==0.1.17
scikit-image==0.14.2 scikit-image==0.14.2
scipy==1.4.1 scipy==1.4.1
colorama colorama
labelme==4.2.9
tensorflow-gpu==1.13.2 tensorflow-gpu==1.13.2

View file

@ -27,6 +27,7 @@ class Sample(object):
'shape', 'shape',
'landmarks', 'landmarks',
'ie_polys', 'ie_polys',
'seg_ie_polys',
'eyebrows_expand_mod', 'eyebrows_expand_mod',
'source_filename', 'source_filename',
'person_name', 'person_name',
@ -40,6 +41,7 @@ class Sample(object):
shape=None, shape=None,
landmarks=None, landmarks=None,
ie_polys=None, ie_polys=None,
seg_ie_polys=None,
eyebrows_expand_mod=None, eyebrows_expand_mod=None,
source_filename=None, source_filename=None,
person_name=None, person_name=None,
@ -52,6 +54,7 @@ class Sample(object):
self.shape = shape self.shape = shape
self.landmarks = np.array(landmarks) if landmarks is not None else None self.landmarks = np.array(landmarks) if landmarks is not None else None
self.ie_polys = IEPolys.load(ie_polys) self.ie_polys = IEPolys.load(ie_polys)
self.seg_ie_polys = IEPolys.load(seg_ie_polys)
self.eyebrows_expand_mod = eyebrows_expand_mod self.eyebrows_expand_mod = eyebrows_expand_mod
self.source_filename = source_filename self.source_filename = source_filename
self.person_name = person_name self.person_name = person_name
@ -88,6 +91,7 @@ class Sample(object):
'shape': self.shape, 'shape': self.shape,
'landmarks': self.landmarks.tolist(), 'landmarks': self.landmarks.tolist(),
'ie_polys': self.ie_polys.dump(), 'ie_polys': self.ie_polys.dump(),
'seg_ie_polys': self.seg_ie_polys.dump(),
'eyebrows_expand_mod': self.eyebrows_expand_mod, 'eyebrows_expand_mod': self.eyebrows_expand_mod,
'source_filename': self.source_filename, 'source_filename': self.source_filename,
'person_name': self.person_name 'person_name': self.person_name

View file

@ -1,260 +0,0 @@
import multiprocessing
import pickle
import time
import traceback
from enum import IntEnum
import cv2
import numpy as np
from core import imagelib, mplib, pathex
from core.imagelib import sd
from core.cv2ex import *
from core.interact import interact as io
from core.joblib import SubprocessGenerator, ThisThreadGenerator
from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleType)
class MaskType(IntEnum):
    """Integer ids (0..18) of the named mask labels.

    Members are plain ints. The previous definition ended most lines with a
    trailing comma, which made each value a 1-tuple that only resolved to an
    int because ``IntEnum`` unpacks tuple values into ``int(*args)``.
    """
    none = 0
    cloth = 1
    ear_r = 2
    eye_g = 3
    hair = 4
    hat = 5
    l_brow = 6
    l_ear = 7
    l_eye = 8
    l_lip = 9
    mouth = 10
    neck = 11
    neck_l = 12
    nose = 13
    r_brow = 14
    r_ear = 15
    r_eye = 16
    skin = 17
    u_lip = 18


# Label id (plain int) -> label name. Derived from the enum so the two can
# never drift apart; every name equals the enum member name.
MaskType_to_name = {int(mask_type): mask_type.name for mask_type in MaskType}

# Label name -> label id (plain int), the inverse of MaskType_to_name.
MaskType_from_name = {name: type_id for type_id, name in MaskType_to_name.items()}
class SampleGeneratorFaceSkinSegDataset(SampleGeneratorBase):
    """Endless generator of augmented (image, mask) batches.

    Faces are loaded from ``root_path / 'aligned'`` and each mask is drawn
    from the sample's ``ie_polys``. PNG files found under
    ``root_path / 'obstructions'`` that have four channels are randomly pasted
    over the face and removed from the mask.
    """

    def __init__ (self, root_path, debug=False, batch_size=1, resolution=256, face_type=None,
                        generators_count=4, data_format="NHWC",
                        **kwargs):
        """Load the samples and start the batch generators.

        Args:
            root_path: directory holding ``aligned`` (required) and
                ``obstructions`` (searched for ``.png`` files). It is combined
                with ``/``, so presumably a ``pathlib.Path`` - confirm with callers.
            debug: when true a single in-thread generator is used instead of
                subprocess generators.
            batch_size: number of samples per yielded batch.
            resolution: side length of the square output images and masks.
            face_type: target face type; samples of a different type are
                re-aligned from their landmarks.
            generators_count: requested number of subprocess generators
                (at least one is always created).
            data_format: ``"NCHW"`` transposes outputs to channel-first; any
                other value leaves them channel-last.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``root_path / 'aligned'`` does not exist.
        """
        super().__init__(debug, batch_size)
        self.initialized = False

        aligned_path = root_path / 'aligned'
        if not aligned_path.exists():
            raise ValueError(f'Unable to find {aligned_path}')

        obstructions_path = root_path / 'obstructions'
        obstructions_images_paths = pathex.get_image_paths(obstructions_path, image_extensions=['.png'], subdirs=True)

        samples = SampleLoader.load (SampleType.FACE, aligned_path, subdirs=True)
        self.samples_len = len(samples)

        # Samples are pickled once and unpickled inside every generator.
        pickled_samples = pickle.dumps(samples, 4)
        generator_args = (pickled_samples, obstructions_images_paths, resolution, face_type, data_format)

        if self.debug:
            self.generators_count = 1
            self.generators = [ThisThreadGenerator ( self.batch_func, generator_args )]
        else:
            self.generators_count = max(1, generators_count)
            self.generators = [SubprocessGenerator ( self.batch_func, generator_args, start_now=False )
                               for _ in range(self.generators_count) ]
            SubprocessGenerator.start_in_parallel( self.generators )

        self.generator_counter = -1
        self.initialized = True

    #overridable
    def is_initialized(self):
        """Return True once ``__init__`` has run to completion."""
        return self.initialized

    def __iter__(self):
        """The generator is its own iterator."""
        return self

    def __next__(self):
        """Return the next batch, taking the generators in round-robin order."""
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators) ]
        return next(generator)

    def batch_func(self, param ):
        """Yield ``[images, masks]`` batches forever.

        Args:
            param: tuple ``(pickled_samples, obstructions_images_paths,
                resolution, face_type, data_format)`` as built by ``__init__``.

        Yields:
            A list of two ``np.ndarray`` objects: the stacked images and the
            stacked masks, each holding ``self.batch_size`` items. Masks are
            binarized at 0.5; images are clipped to [0, 1] before the colour
            and blur augmentations are applied.

        A sample that raises during loading or augmentation is logged and
        skipped; the batch is then filled from the following samples.
        """
        pickled_samples, obstructions_images_paths, resolution, face_type, data_format = param

        # NOTE(review): pickle.loads is only safe here because the payload is
        # produced by __init__ of this same class.
        samples = pickle.loads(pickled_samples)

        obstructions_images_paths_len = len(obstructions_images_paths)
        shuffle_o_idxs = []
        o_idxs = [*range(obstructions_images_paths_len)]

        shuffle_idxs = []
        idxs = [*range(len(samples))]

        # Geometric augmentation of the face, forwarded to imagelib.gen_warp_params.
        random_flip = True
        rotation_range=[-10,10]
        scale_range=[-0.05, 0.05]
        tx_range=[-0.05, 0.05]
        ty_range=[-0.05, 0.05]

        # Placement of the pasted obstruction image.
        o_random_flip = True
        o_rotation_range=[-180,180]
        o_scale_range=[-0.5, 0.05]
        o_tx_range=[-0.5, 0.5]
        o_ty_range=[-0.5, 0.5]

        # (chance, max size) pairs forwarded unchanged to the imagelib.apply_random_* helpers.
        random_bilinear_resize_chance, random_bilinear_resize_max_size_per = 25,75
        motion_blur_chance, motion_blur_mb_max_size = 25, 5
        gaussian_blur_chance, gaussian_blur_kernel_max_size = 25, 5

        bs = self.batch_size
        while True:
            batches = [ [], [] ]

            n_batch = 0
            while n_batch < bs:
                try:
                    # Refill and reshuffle the index pool once every sample has been used.
                    if len(shuffle_idxs) == 0:
                        shuffle_idxs = idxs.copy()
                        np.random.shuffle(shuffle_idxs)

                    idx = shuffle_idxs.pop()
                    sample = samples[idx]

                    img = sample.load_bgr()
                    h, w, _ = img.shape

                    mask = np.zeros ((h,w,1), dtype=np.float32)
                    sample.ie_polys.overlay_mask(mask)

                    warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range )

                    if face_type == sample.face_type:
                        if w != resolution:
                            # interpolation must be passed by keyword: the third
                            # positional parameter of cv2.resize is dst.
                            img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4 )
                            mask = cv2.resize( mask, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4 )
                    else:
                        mat = LandmarksProcessor.get_transform_mat (sample.landmarks, resolution, face_type)
                        img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 )
                        mask = cv2.warpAffine( mask, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 )

                    # Restore the channel axis if it is missing after the resize/warp.
                    if len(mask.shape) == 2:
                        mask = mask[...,None]

                    if obstructions_images_paths_len != 0:
                        # apply obstruction
                        if len(shuffle_o_idxs) == 0:
                            shuffle_o_idxs = o_idxs.copy()
                            np.random.shuffle(shuffle_o_idxs)

                        o_idx = shuffle_o_idxs.pop()
                        o_img = cv2_imread (obstructions_images_paths[o_idx]).astype(np.float32) / 255.0
                        oh,ow,oc = o_img.shape

                        # Only images with a fourth (alpha) channel can be pasted.
                        if oc == 4:
                            ohw = max(oh,ow)
                            scale = resolution / ohw

                            # Rotate about the obstruction centre, move the centre to the
                            # origin, scale, then shift to a random spot around the middle
                            # of the output image.
                            mat = cv2.getRotationMatrix2D( (ow/2,oh/2),
                                                           np.random.uniform( o_rotation_range[0], o_rotation_range[1] ),
                                                           1.0 )
                            mat += np.float32( [[0,0, -ow/2 ],
                                                [0,0, -oh/2 ]])
                            mat *= scale * np.random.uniform(1 +o_scale_range[0], 1 +o_scale_range[1])
                            mat += np.float32( [[0, 0, resolution/2 + resolution*np.random.uniform( o_tx_range[0], o_tx_range[1] ) ],
                                                [0, 0, resolution/2 + resolution*np.random.uniform( o_ty_range[0], o_ty_range[1] ) ] ])

                            o_img = cv2.warpAffine( o_img, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 )

                            if o_random_flip and np.random.randint(10) < 4:
                                o_img = o_img[:,::-1,...]

                            # Binarize the alpha channel, then erode and blur it for the blend.
                            o_mask = o_img[...,3:4]
                            o_mask[o_mask>0] = 1.0
                            o_mask = cv2.erode (o_mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(5,5)), iterations = 1 )
                            o_mask = cv2.GaussianBlur(o_mask, (5, 5) , 0)[...,None]

                            img = img*(1-o_mask) + o_img[...,0:3]*o_mask

                            # Weakly covered pixels stay part of the face mask.
                            o_mask[o_mask<0.5] = 0.0
                            mask *= (1-o_mask)

                    # The same warp parameters are used for image and mask so they stay aligned.
                    img = imagelib.warp_by_params (warp_params, img, can_warp=True, can_transform=True, can_flip=True, border_replicate=False)
                    mask = imagelib.warp_by_params (warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False)

                    img = np.clip(img.astype(np.float32), 0, 1)
                    mask[mask < 0.5] = 0.0
                    mask[mask >= 0.5] = 1.0
                    mask = np.clip(mask, 0, 1)

                    img = imagelib.apply_random_hsv_shift(img, mask=sd.random_circle_faded ([resolution,resolution]))
                    img = imagelib.apply_random_motion_blur( img, motion_blur_chance, motion_blur_mb_max_size, mask=sd.random_circle_faded ([resolution,resolution]))
                    img = imagelib.apply_random_gaussian_blur( img, gaussian_blur_chance, gaussian_blur_kernel_max_size, mask=sd.random_circle_faded ([resolution,resolution]))
                    img = imagelib.apply_random_bilinear_resize( img, random_bilinear_resize_chance, random_bilinear_resize_max_size_per, mask=sd.random_circle_faded ([resolution,resolution]))

                    if data_format == "NCHW":
                        img = np.transpose(img, (2,0,1) )
                        mask = np.transpose(mask, (2,0,1) )

                    batches[0].append ( img )
                    batches[1].append ( mask )

                    n_batch += 1
                except Exception:
                    # Best effort: log the failure and try the next sample.
                    # KeyboardInterrupt/SystemExit are no longer swallowed, so
                    # the worker can be stopped.
                    io.log_err ( traceback.format_exc() )

            yield [ np.array(batch) for batch in batches]

View file

@ -0,0 +1,148 @@
import multiprocessing
import pickle
import time
import traceback
from enum import IntEnum
import cv2
import numpy as np
from core import imagelib, mplib, pathex
from core.imagelib import sd
from core.cv2ex import *
from core.interact import interact as io
from core.joblib import SubprocessGenerator, ThisThreadGenerator
from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleType)
class SampleGeneratorFaceXSeg(SampleGeneratorBase):
def __init__ (self, paths, debug=False, batch_size=1, resolution=256, face_type=None,
              generators_count=4, data_format="NHWC",
              **kwargs):
    """Collect the segmented faces from ``paths`` and start the batch generators.

    Samples are loaded from every entry of ``paths``; only those whose
    ``seg_ie_polys`` holds at least one point are kept. In debug mode one
    in-thread generator is created, otherwise ``max(1, generators_count)``
    subprocess generators are created and started together.

    Raises:
        Exception: if none of the loaded samples has segmentation polygons.
    """
    super().__init__(debug, batch_size)
    self.initialized = False

    loaded = []
    for faces_dir in paths:
        loaded.extend( SampleLoader.load (SampleType.FACE, faces_dir) )

    seg_samples = [ face for face in loaded if face.seg_ie_polys.get_total_points() != 0 ]
    seg_samples_len = len(seg_samples)
    if seg_samples_len == 0:
        raise Exception(f"No segmented faces found.")
    io.log_info(f"Using {seg_samples_len} segmented samples.")

    # One argument tuple shared by every generator; samples travel pickled.
    worker_args = (pickle.dumps(seg_samples, 4), resolution, face_type, data_format)

    if self.debug:
        self.generators_count = 1
        self.generators = [ ThisThreadGenerator ( self.batch_func, worker_args ) ]
    else:
        self.generators_count = max(1, generators_count)
        self.generators = [ SubprocessGenerator ( self.batch_func, worker_args, start_now=False )
                            for _ in range(self.generators_count) ]
        SubprocessGenerator.start_in_parallel( self.generators )

    self.generator_counter = -1
    self.initialized = True
#overridable
def is_initialized(self):
    """Return the ``initialized`` flag, which ``__init__`` sets to True as its last step."""
    ready = self.initialized
    return ready
def __iter__(self):
    """Iterator protocol: the generator object iterates over itself."""
    return self
def __next__(self):
    """Return the next batch, cycling through ``self.generators`` in round-robin order."""
    self.generator_counter += 1
    slot = self.generator_counter % len(self.generators)
    return next(self.generators[slot])
def batch_func(self, param ):
pickled_samples, resolution, face_type, data_format = param
samples = pickle.loads(pickled_samples)
shuffle_idxs = []
idxs = [*range(len(samples))]
random_flip = True
rotation_range=[-10,10]
scale_range=[-0.05, 0.05]
tx_range=[-0.05, 0.05]
ty_range=[-0.05, 0.05]
random_bilinear_resize_chance, random_bilinear_resize_max_size_per = 25,75
motion_blur_chance, motion_blur_mb_max_size = 25, 5
gaussian_blur_chance, gaussian_blur_kernel_max_size = 25, 5
bs = self.batch_size
while True:
batches = [ [], [] ]
n_batch = 0
while n_batch < bs:
try:
if len(shuffle_idxs) == 0:
shuffle_idxs = idxs.copy()
np.random.shuffle(shuffle_idxs)
idx = shuffle_idxs.pop()
sample = samples[idx]
img = sample.load_bgr()
h,w,c = img.shape
mask = np.zeros ((h,w,1), dtype=np.float32)
sample.seg_ie_polys.overlay_mask(mask)
warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range )
if face_type == sample.face_type:
if w != resolution:
img = cv2.resize( img, (resolution, resolution), cv2.INTER_LANCZOS4 )
mask = cv2.resize( mask, (resolution, resolution), cv2.INTER_LANCZOS4 )
else:
mat = LandmarksProcessor.get_transform_mat (sample.landmarks, resolution, face_type)
img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 )
mask = cv2.warpAffine( mask, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 )
if len(mask.shape) == 2:
mask = mask[...,None]
img = imagelib.warp_by_params (warp_params, img, can_warp=True, can_transform=True, can_flip=True, border_replicate=False)
mask = imagelib.warp_by_params (warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False)
img = np.clip(img.astype(np.float32), 0, 1)
mask[mask < 0.5] = 0.0
mask[mask >= 0.5] = 1.0
mask = np.clip(mask, 0, 1)
if np.random.randint(2) == 0:
img = imagelib.apply_random_hsv_shift(img, mask=sd.random_circle_faded ([resolution,resolution]))
else:
img = imagelib.apply_random_rgb_levels(img, mask=sd.random_circle_faded ([resolution,resolution]))
img = imagelib.apply_random_motion_blur( img, motion_blur_chance, motion_blur_mb_max_size, mask=sd.random_circle_faded ([resolution,resolution]))
img = imagelib.apply_random_gaussian_blur( img, gaussian_blur_chance, gaussian_blur_kernel_max_size, mask=sd.random_circle_faded ([resolution,resolution]))
img = imagelib.apply_random_bilinear_resize( img, random_bilinear_resize_chance, random_bilinear_resize_max_size_per, mask=sd.random_circle_faded ([resolution,resolution]))
if data_format == "NCHW":
img = np.transpose(img, (2,0,1) )
mask = np.transpose(mask, (2,0,1) )
batches[0].append ( img )
batches[1].append ( mask )
n_batch += 1
except:
io.log_err ( traceback.format_exc() )
yield [ np.array(batch) for batch in batches]

View file

@ -75,6 +75,7 @@ class SampleLoader:
shape, shape,
landmarks, landmarks,
ie_polys, ie_polys,
seg_ie_polys,
eyebrows_expand_mod, eyebrows_expand_mod,
source_filename, source_filename,
) in result: ) in result:
@ -84,6 +85,7 @@ class SampleLoader:
shape=shape, shape=shape,
landmarks=landmarks, landmarks=landmarks,
ie_polys=ie_polys, ie_polys=ie_polys,
seg_ie_polys=seg_ie_polys,
eyebrows_expand_mod=eyebrows_expand_mod, eyebrows_expand_mod=eyebrows_expand_mod,
source_filename=source_filename, source_filename=source_filename,
)) ))
@ -177,6 +179,7 @@ class FaceSamplesLoaderSubprocessor(Subprocessor):
dflimg.get_shape(), dflimg.get_shape(),
dflimg.get_landmarks(), dflimg.get_landmarks(),
dflimg.get_ie_polys(), dflimg.get_ie_polys(),
dflimg.get_seg_ie_polys(),
dflimg.get_eyebrows_expand_mod(), dflimg.get_eyebrows_expand_mod(),
dflimg.get_source_filename() ) dflimg.get_source_filename() )

View file

@ -9,5 +9,5 @@ from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal
from .SampleGeneratorImage import SampleGeneratorImage from .SampleGeneratorImage import SampleGeneratorImage
from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal
from .SampleGeneratorFaceCelebAMaskHQ import SampleGeneratorFaceCelebAMaskHQ from .SampleGeneratorFaceCelebAMaskHQ import SampleGeneratorFaceCelebAMaskHQ
from .SampleGeneratorFaceSkinSegDataset import SampleGeneratorFaceSkinSegDataset from .SampleGeneratorFaceXSeg import SampleGeneratorFaceXSeg
from .PackedFaceset import PackedFaceset from .PackedFaceset import PackedFaceset