Maximum resolution is increased to 640.

The 'hd' archi is removed. It was an experimental archi created to remove subpixel shake, but 'lr_dropout' and disabling random warping do that better.

'uhd' is renamed to '-u'.
Existing 'dfuhd' and 'liaeuhd' models will be automatically renamed to 'df-u' and 'liae-u'.
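
The rename happens on model load via a simple backward-compat mapping (taken verbatim from the SAEHDModel diff below):

    archi = self.load_or_def_option('archi', 'df')
    archi = {'dfuhd':'df-u','liaeuhd':'liae-u'}.get(archi, archi) #backward comp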

Added a new experimental archi option (key -d) which doubles the resolution at the same computation cost.
This means the same config will run about 2x faster; for example, you can set 448 resolution and it will train at the cost of 224.
It is strongly recommended not to train it from scratch, but to start from pretrained models.
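
How '-d' works (see the DeepFakeArchi diff below): lowest_dense_res becomes resolution // 32, so the whole network runs at half the spatial size, and the decoder emits four half-resolution RGB outputs that are upsampled and combined with tiled 2x2 one-hot masks, so in effect each output fills one position of every 2x2 pixel block. A minimal NumPy sketch of that interleave (function and variable names here are illustrative, not from the code):

    import numpy as np

    def interleave_2x2(x0, x1, x2, x3):
        # x0..x3: (H, W, C) half-resolution decoder outputs -> (2H, 2W, C).
        # Equivalent to nearest-upsampling each xi and masking it with the
        # tiled z0..z3 checkerboard masks built in Decoder.forward below.
        h, w, c = x0.shape
        out = np.zeros((h * 2, w * 2, c), dtype=x0.dtype)
        out[0::2, 0::2] = x0  # top-left pixel of every 2x2 block
        out[0::2, 1::2] = x1  # top-right
        out[1::2, 0::2] = x2  # bottom-left
        out[1::2, 1::2] = x3  # bottom-right
        return out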

New archi naming:
'df' keeps the face more identity-preserved.
'liae' can fix overly different face shapes.
'-u' increases likeness of the face.
'-d' (experimental) doubles the resolution at the same computation cost.
Examples: df, liae, df-d, df-ud, liae-ud, ...
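
An archi string is parsed as '<type>[-<opts>]', where the type must be df or liae and the opts part is any non-empty combination of 'u' and 'd'. A standalone sketch of the validation the SAEHD options dialog now performs (mirrors the loop added in the SAEHDModel diff below):

    def parse_archi(archi: str):
        # 'df' -> ('df', None); 'liae-ud' -> ('liae', 'ud')
        archi_split = archi.split('-')
        if len(archi_split) == 2:
            archi_type, archi_opts = archi_split
        elif len(archi_split) == 1:
            archi_type, archi_opts = archi_split[0], None
        else:
            raise ValueError(f'bad archi: {archi}')
        if archi_type not in ['df', 'liae']:
            raise ValueError(f'bad archi type: {archi_type}')
        if archi_opts is not None:
            if len(archi_opts) == 0 or any(opt not in 'ud' for opt in archi_opts):
                raise ValueError(f'bad archi opts: {archi_opts}')
        return archi_type, archi_opts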

Improved GAN training (GAN_power option). Previously a discriminator was also trained for the dst model, but dst doesn't actually need it.
Instead, a second src GAN model with a 2x smaller patch size was added, so the overall quality for hi-res models should be higher.
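
In code terms (a simplified sketch of the SAEHDModel diff below; nn, tf, DLoss, resolution, input_ch and gan_power as defined there): both discriminators now look at the src reconstruction only, at two patch scales.

    D_src    = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_src")
    D_src_x2 = nn.PatchDiscriminator(patch_size=resolution//32, in_ch=input_ch, name="D_src_x2")

    def gan_terms(target_src, pred_src_src, gan_power):
        d, d2 = D_src(pred_src_src), D_src_x2(pred_src_src)
        t, t2 = D_src(target_src),   D_src_x2(target_src)
        # discriminators: real -> 1, fake -> 0, averaged over both patch scales
        D_loss = (DLoss(tf.ones_like(t),  t)  + DLoss(tf.zeros_like(d),  d) ) * 0.5 + \
                 (DLoss(tf.ones_like(t2), t2) + DLoss(tf.zeros_like(d2), d2)) * 0.5
        # generator: fool both discriminators
        G_loss = 0.5 * gan_power * (DLoss(tf.ones_like(d), d) + DLoss(tf.ones_like(d2), d2))
        return D_loss, G_loss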

Added option 'Uniform yaw distribution of samples (y/n)':
	Helps to fix blurry side faces caused by the small number of them in the faceset.
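
An illustrative sketch of the idea (not the actual generator code, which delegates to the new Index2DHost added in the samplelib diff below): bucket samples by yaw, then pick a bucket uniformly before picking a sample inside it, so rare side-face buckets are drawn as often as frontal ones.

    import numpy as np

    def make_yaw_bins(yaws, n_bins=64):
        # group sample indexes into equal-width yaw ranges, dropping empty ones
        edges = np.linspace(min(yaws), max(yaws), n_bins + 1)
        bins = [[] for _ in range(n_bins)]
        for idx, yaw in enumerate(yaws):
            b = min(int(np.searchsorted(edges, yaw, side='right')) - 1, n_bins - 1)
            bins[b].append(idx)
        return [b for b in bins if b]

    def draw_uniform_yaw(bins, rng=np.random):
        bucket = bins[rng.randint(len(bins))]  # every yaw range equally likely
        return bucket[rng.randint(len(bucket))]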

Quick96:
	Now based on the df-ud archi and 20% faster.

XSeg trainer:
	Improved sample generator.
	Now it randomly adds backgrounds from other samples,
	which reduces the chance of random mask noise in the area outside the face.
	Now you can specify 'batch_size' in the range 2-16.
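
The background augmentation amounts to alpha-compositing the masked face over another sample's frame. A rough sketch of the idea (illustrative, not the actual sample generator code):

    import numpy as np

    def composite_random_bg(face_bgr, face_mask, other_bgr):
        # face_bgr, other_bgr: (H, W, 3) float32 in [0,1]; face_mask: (H, W, 1)
        # outside the mask the net now sees an arbitrary background,
        # so it cannot learn to emit mask noise there
        return face_bgr * face_mask + other_bgr * (1.0 - face_mask)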

Reduced the size of samples with an applied XSeg mask, so the size of packed samples with an applied XSeg mask is also reduced.
Colombo 2020-06-19 09:45:55 +04:00
parent 9fd3a9ff8d
commit 0c2e1c3944
14 changed files with 513 additions and 572 deletions

View file

@@ -7,6 +7,7 @@ import types
 import colorama
 import cv2
+import numpy as np
 from tqdm import tqdm
 from core import stdex
@@ -255,7 +256,7 @@ class InteractBase(object):
             print(result)
         return result

-    def input_int(self, s, default_value, valid_list=None, add_info=None, show_default_value=True, help_message=None):
+    def input_int(self, s, default_value, valid_range=None, valid_list=None, add_info=None, show_default_value=True, help_message=None):
         if show_default_value:
             if len(s) != 0:
                 s = f"[{default_value}] {s}"
@@ -263,15 +264,21 @@ class InteractBase(object):
                 s = f"[{default_value}]"

         if add_info is not None or \
+           valid_range is not None or \
            help_message is not None:
             s += " ("

+        if valid_range is not None:
+            s += f" {valid_range[0]}-{valid_range[1]} "
+
         if add_info is not None:
             s += f" {add_info}"
         if help_message is not None:
             s += " ?:help"

         if add_info is not None or \
+           valid_range is not None or \
            help_message is not None:
             s += " )"
@@ -288,9 +295,12 @@ class InteractBase(object):
                     continue
                 i = int(inp)
+
+                if valid_range is not None:
+                    i = np.clip(i, valid_range[0], valid_range[1])
+
                 if (valid_list is not None) and (i not in valid_list):
-                    result = default_value
-                    break
+                    i = default_value
+
                 result = i
                 break
             except:

View file

@@ -1,54 +1,46 @@
 from core.leras import nn
 tf = nn.tf

 class DeepFakeArchi(nn.ArchiBase):
     """
     resolution

     mod     None - default
-            'uhd'
             'quick'
     """
-    def __init__(self, resolution, mod=None):
+    def __init__(self, resolution, mod=None, opts=None):
         super().__init__()

+        if opts is None:
+            opts = ''
+
         if mod is None:
             class Downscale(nn.ModelBase):
-                def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
+                def __init__(self, in_ch, out_ch, kernel_size=5, *kwargs ):
                     self.in_ch = in_ch
                     self.out_ch = out_ch
                     self.kernel_size = kernel_size
-                    self.dilations = dilations
-                    self.subpixel = subpixel
-                    self.use_activator = use_activator
                     super().__init__(*kwargs)

                 def on_build(self, *args, **kwargs ):
-                    self.conv1 = nn.Conv2D( self.in_ch,
-                                            self.out_ch // (4 if self.subpixel else 1),
-                                            kernel_size=self.kernel_size,
-                                            strides=1 if self.subpixel else 2,
-                                            padding='SAME', dilations=self.dilations)
+                    self.conv1 = nn.Conv2D( self.in_ch, self.out_ch, kernel_size=self.kernel_size, strides=2, padding='SAME')

                 def forward(self, x):
                     x = self.conv1(x)
-                    if self.subpixel:
-                        x = nn.space_to_depth(x, 2)
-                    if self.use_activator:
-                        x = tf.nn.leaky_relu(x, 0.1)
+                    x = tf.nn.leaky_relu(x, 0.1)
                     return x

                 def get_out_ch(self):
-                    return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch
+                    return self.out_ch

             class DownscaleBlock(nn.ModelBase):
-                def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
+                def on_build(self, in_ch, ch, n_downscales, kernel_size):
                     self.downs = []

                     last_ch = in_ch
                     for i in range(n_downscales):
                         cur_ch = ch*( min(2**i, 8) )
-                        self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
+                        self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size) )
                         last_ch = self.downs[-1].get_out_ch()

                 def forward(self, inp):
@@ -79,346 +71,33 @@ class DeepFakeArchi(nn.ArchiBase):
                     x = tf.nn.leaky_relu(inp + x, 0.2)
                     return x

-            class UpdownResidualBlock(nn.ModelBase):
-                def on_build(self, ch, inner_ch, kernel_size=3 ):
-                    self.up   = Upscale (ch, inner_ch, kernel_size=kernel_size)
-                    self.res  = ResidualBlock (inner_ch, kernel_size=kernel_size)
-                    self.down = Downscale (inner_ch, ch, kernel_size=kernel_size, use_activator=False)
-
-                def forward(self, inp):
-                    x = self.up(inp)
-                    x = upx = self.res(x)
-                    x = self.down(x)
-                    x = x + inp
-                    x = tf.nn.leaky_relu(x, 0.2)
-                    return x, upx
-
             class Encoder(nn.ModelBase):
-                def on_build(self, in_ch, e_ch, is_hd):
-                    self.is_hd=is_hd
-                    if self.is_hd:
-                        self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, dilations=1)
-                        self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, dilations=1)
-                        self.down3 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=5, dilations=2)
-                        self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2)
-                    else:
-                        self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False)
+                def on_build(self, in_ch, e_ch):
+                    self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5)

                 def forward(self, inp):
-                    if self.is_hd:
-                        x = tf.concat([ nn.flatten(self.down1(inp)),
-                                        nn.flatten(self.down2(inp)),
-                                        nn.flatten(self.down3(inp)),
-                                        nn.flatten(self.down4(inp)) ], -1 )
-                    else:
-                        x = nn.flatten(self.down1(inp))
-                    return x
+                    return nn.flatten(self.down1(inp))

-            lowest_dense_res = resolution // 16
+            lowest_dense_res = resolution // (32 if 'd' in opts else 16)

             class Inter(nn.ModelBase):
-                def __init__(self, in_ch, ae_ch, ae_out_ch, is_hd=False, **kwargs):
+                def __init__(self, in_ch, ae_ch, ae_out_ch, **kwargs):
                     self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch
                     super().__init__(**kwargs)

                 def on_build(self):
                     in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch

+                    if 'u' in opts:
+                        self.dense_norm = nn.DenseNorm()
+
                     self.dense1 = nn.Dense( in_ch, ae_ch )
                     self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
                     self.upscale1 = Upscale(ae_out_ch, ae_out_ch)

                 def forward(self, inp):
-                    x = self.dense1(inp)
-                    x = self.dense2(x)
-                    x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
-                    x = self.upscale1(x)
-                    return x
-
-                @staticmethod
-                def get_code_res():
-                    return lowest_dense_res
-
-                def get_out_ch(self):
-                    return self.ae_out_ch
-
-            class Decoder(nn.ModelBase):
-                def on_build(self, in_ch, d_ch, d_mask_ch, is_hd ):
-                    self.is_hd = is_hd
-
-                    self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
-                    self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
-                    self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
-
-                    if is_hd:
-                        self.res0 = UpdownResidualBlock(in_ch, d_ch*8, kernel_size=3)
-                        self.res1 = UpdownResidualBlock(d_ch*8, d_ch*4, kernel_size=3)
-                        self.res2 = UpdownResidualBlock(d_ch*4, d_ch*2, kernel_size=3)
-                        self.res3 = UpdownResidualBlock(d_ch*2, d_ch, kernel_size=3)
-                    else:
-                        self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
-                        self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
-                        self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
-
-                    self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
-
-                    self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
-                    self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
-                    self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
-                    self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
-
-                def forward(self, inp):
-                    z = inp
-
-                    if self.is_hd:
-                        x, upx = self.res0(z)
-                        x = self.upscale0(x)
-                        x = tf.nn.leaky_relu(x + upx, 0.2)
-
-                        x, upx = self.res1(x)
-                        x = self.upscale1(x)
-                        x = tf.nn.leaky_relu(x + upx, 0.2)
-
-                        x, upx = self.res2(x)
-                        x = self.upscale2(x)
-                        x = tf.nn.leaky_relu(x + upx, 0.2)
-
-                        x, upx = self.res3(x)
-                    else:
-                        x = self.upscale0(z)
-                        x = self.res0(x)
-                        x = self.upscale1(x)
-                        x = self.res1(x)
-                        x = self.upscale2(x)
-                        x = self.res2(x)
-
-                    m = self.upscalem0(z)
-                    m = self.upscalem1(m)
-                    m = self.upscalem2(m)
-
-                    return tf.nn.sigmoid(self.out_conv(x)), \
-                           tf.nn.sigmoid(self.out_convm(m))
-
-        elif mod == 'quick':
-            class Downscale(nn.ModelBase):
-                def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
-                    self.in_ch = in_ch
-                    self.out_ch = out_ch
-                    self.kernel_size = kernel_size
-                    self.dilations = dilations
-                    self.subpixel = subpixel
-                    self.use_activator = use_activator
-                    super().__init__(*kwargs)
-
-                def on_build(self, *args, **kwargs ):
-                    self.conv1 = nn.Conv2D( self.in_ch,
-                                            self.out_ch // (4 if self.subpixel else 1),
-                                            kernel_size=self.kernel_size,
-                                            strides=1 if self.subpixel else 2,
-                                            padding='SAME', dilations=self.dilations )
-
-                def forward(self, x):
-                    x = self.conv1(x)
-
-                    if self.subpixel:
-                        x = nn.space_to_depth(x, 2)
-
-                    if self.use_activator:
-                        x = nn.gelu(x)
-                    return x
-
-                def get_out_ch(self):
-                    return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch
-
-            class DownscaleBlock(nn.ModelBase):
-                def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
-                    self.downs = []
-
-                    last_ch = in_ch
-                    for i in range(n_downscales):
-                        cur_ch = ch*( min(2**i, 8) )
-                        self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
-                        last_ch = self.downs[-1].get_out_ch()
-
-                def forward(self, inp):
-                    x = inp
-                    for down in self.downs:
-                        x = down(x)
-                    return x
-
-            class Upscale(nn.ModelBase):
-                def on_build(self, in_ch, out_ch, kernel_size=3 ):
-                    self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
-
-                def forward(self, x):
-                    x = self.conv1(x)
-                    x = nn.gelu(x)
-                    x = nn.depth_to_space(x, 2)
-                    return x
-
-            class ResidualBlock(nn.ModelBase):
-                def on_build(self, ch, kernel_size=3 ):
-                    self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
-                    self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
-
-                def forward(self, inp):
-                    x = self.conv1(inp)
-                    x = nn.gelu(x)
-                    x = self.conv2(x)
-                    x = inp + x
-                    x = nn.gelu(x)
-                    return x
-
-            class Encoder(nn.ModelBase):
-                def on_build(self, in_ch, e_ch):
-                    self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5)
-
-                def forward(self, inp):
-                    return nn.flatten(self.down1(inp))
-
-            lowest_dense_res = resolution // 16
-
-            class Inter(nn.ModelBase):
-                def __init__(self, in_ch, ae_ch, ae_out_ch, d_ch, **kwargs):
-                    self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch = in_ch, ae_ch, ae_out_ch, d_ch
-                    super().__init__(**kwargs)
-
-                def on_build(self):
-                    in_ch, ae_ch, ae_out_ch, d_ch = self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch
-
-                    self.dense1 = nn.Dense( in_ch, ae_ch, kernel_initializer=tf.initializers.orthogonal )
-                    self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch, kernel_initializer=tf.initializers.orthogonal )
-                    self.upscale1 = Upscale(ae_out_ch, d_ch*8)
-                    self.res1 = ResidualBlock(d_ch*8)
-
-                def forward(self, inp):
-                    x = self.dense1(inp)
-                    x = self.dense2(x)
-                    x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
-                    x = self.upscale1(x)
-                    x = self.res1(x)
-                    return x
-
-                def get_out_ch(self):
-                    return self.ae_out_ch
-
-            class Decoder(nn.ModelBase):
-                def on_build(self, in_ch, d_ch):
-                    self.upscale1 = Upscale(in_ch, d_ch*4)
-                    self.res1     = ResidualBlock(d_ch*4)
-                    self.upscale2 = Upscale(d_ch*4, d_ch*2)
-                    self.res2     = ResidualBlock(d_ch*2)
-                    self.upscale3 = Upscale(d_ch*2, d_ch*1)
-                    self.res3     = ResidualBlock(d_ch*1)
-
-                    self.upscalem1 = Upscale(in_ch, d_ch)
-                    self.upscalem2 = Upscale(d_ch, d_ch//2)
-                    self.upscalem3 = Upscale(d_ch//2, d_ch//2)
-
-                    self.out_conv  = nn.Conv2D( d_ch*1, 3, kernel_size=1, padding='SAME')
-                    self.out_convm = nn.Conv2D( d_ch//2, 1, kernel_size=1, padding='SAME')
-
-                def forward(self, inp):
-                    z = inp
-                    x = self.upscale1 (z)
-                    x = self.res1     (x)
-                    x = self.upscale2 (x)
-                    x = self.res2     (x)
-                    x = self.upscale3 (x)
-                    x = self.res3     (x)
-
-                    y = self.upscalem1 (z)
-                    y = self.upscalem2 (y)
-                    y = self.upscalem3 (y)
-
-                    return tf.nn.sigmoid(self.out_conv(x)), \
-                           tf.nn.sigmoid(self.out_convm(y))
-
-        elif mod == 'uhd':
-            class Downscale(nn.ModelBase):
-                def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
-                    self.in_ch = in_ch
-                    self.out_ch = out_ch
-                    self.kernel_size = kernel_size
-                    self.dilations = dilations
-                    self.subpixel = subpixel
-                    self.use_activator = use_activator
-                    super().__init__(*kwargs)
-
-                def on_build(self, *args, **kwargs ):
-                    self.conv1 = nn.Conv2D( self.in_ch,
-                                            self.out_ch // (4 if self.subpixel else 1),
-                                            kernel_size=self.kernel_size,
-                                            strides=1 if self.subpixel else 2,
-                                            padding='SAME', dilations=self.dilations)
-
-                def forward(self, x):
-                    x = self.conv1(x)
-
-                    if self.subpixel:
-                        x = nn.space_to_depth(x, 2)
-
-                    if self.use_activator:
-                        x = tf.nn.leaky_relu(x, 0.1)
-                    return x
-
-                def get_out_ch(self):
-                    return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch
-
-            class DownscaleBlock(nn.ModelBase):
-                def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
-                    self.downs = []
-
-                    last_ch = in_ch
-                    for i in range(n_downscales):
-                        cur_ch = ch*( min(2**i, 8) )
-                        self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
-                        last_ch = self.downs[-1].get_out_ch()
-
-                def forward(self, inp):
-                    x = inp
-                    for down in self.downs:
-                        x = down(x)
-                    return x
-
-            class Upscale(nn.ModelBase):
-                def on_build(self, in_ch, out_ch, kernel_size=3 ):
-                    self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
-
-                def forward(self, x):
-                    x = self.conv1(x)
-                    x = tf.nn.leaky_relu(x, 0.1)
-                    x = nn.depth_to_space(x, 2)
-                    return x
-
-            class ResidualBlock(nn.ModelBase):
-                def on_build(self, ch, kernel_size=3 ):
-                    self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
-                    self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
-
-                def forward(self, inp):
-                    x = self.conv1(inp)
-                    x = tf.nn.leaky_relu(x, 0.2)
-                    x = self.conv2(x)
-                    x = tf.nn.leaky_relu(inp + x, 0.2)
-                    return x
-
-            class Encoder(nn.ModelBase):
-                def on_build(self, in_ch, e_ch, **kwargs):
-                    self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False)
-
-                def forward(self, inp):
-                    x = nn.flatten(self.down1(inp))
-                    return x
-
-            lowest_dense_res = resolution // 16
-
-            class Inter(nn.ModelBase):
-                def on_build(self, in_ch, ae_ch, ae_out_ch, **kwargs):
-                    self.ae_out_ch = ae_out_ch
-                    self.dense_norm = nn.DenseNorm()
-                    self.dense1 = nn.Dense( in_ch, ae_ch )
-                    self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
-                    self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
-
-                def forward(self, inp):
-                    x = self.dense_norm(inp)
+                    x = inp
+                    if 'u' in opts:
+                        x = self.dense_norm(x)
                     x = self.dense1(x)
                     x = self.dense2(x)
                     x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
@@ -433,8 +112,7 @@ class DeepFakeArchi(nn.ArchiBase):
                     return self.ae_out_ch

             class Decoder(nn.ModelBase):
-                def on_build(self, in_ch, d_ch, d_mask_ch, **kwargs ):
+                def on_build(self, in_ch, d_ch, d_mask_ch ):
                     self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
                     self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
                     self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
@@ -450,6 +128,15 @@ class DeepFakeArchi(nn.ArchiBase):
                     self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
-                    self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
+
+                    if 'd' in opts:
+                        self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME')
+                        self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME')
+                        self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME')
+                        self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
+                        self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME')
+                    else:
+                        self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')

                 def forward(self, inp):
                     z = inp
@@ -460,13 +147,53 @@ class DeepFakeArchi(nn.ArchiBase):
                     x = self.upscale2(x)
                     x = self.res2(x)

+                    if 'd' in opts:
+                        x0 = tf.nn.sigmoid(self.out_conv(x))
+                        x0 = nn.upsample2d(x0)
+                        x1 = tf.nn.sigmoid(self.out_conv1(x))
+                        x1 = nn.upsample2d(x1)
+                        x2 = tf.nn.sigmoid(self.out_conv2(x))
+                        x2 = nn.upsample2d(x2)
+                        x3 = tf.nn.sigmoid(self.out_conv3(x))
+                        x3 = nn.upsample2d(x3)
+
+                        if nn.data_format == "NHWC":
+                            tile_cfg = ( 1, resolution // 2, resolution //2, 1)
+                        else:
+                            tile_cfg = ( 1, 1, resolution // 2, resolution //2 )
+
+                        z0 = tf.concat ( ( tf.concat ( (  tf.ones ( (1,1,1,1) ), tf.zeros ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ),
+                                           tf.concat ( ( tf.zeros ( (1,1,1,1) ), tf.zeros ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ) ), axis=nn.conv2d_spatial_axes[0] )
+                        z0 = tf.tile ( z0, tile_cfg )
+
+                        z1 = tf.concat ( ( tf.concat ( ( tf.zeros ( (1,1,1,1) ),  tf.ones ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ),
+                                           tf.concat ( ( tf.zeros ( (1,1,1,1) ), tf.zeros ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ) ), axis=nn.conv2d_spatial_axes[0] )
+                        z1 = tf.tile ( z1, tile_cfg )
+
+                        z2 = tf.concat ( ( tf.concat ( ( tf.zeros ( (1,1,1,1) ), tf.zeros ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ),
+                                           tf.concat ( (  tf.ones ( (1,1,1,1) ), tf.zeros ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ) ), axis=nn.conv2d_spatial_axes[0] )
+                        z2 = tf.tile ( z2, tile_cfg )
+
+                        z3 = tf.concat ( ( tf.concat ( ( tf.zeros ( (1,1,1,1) ), tf.zeros ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ),
+                                           tf.concat ( ( tf.zeros ( (1,1,1,1) ),  tf.ones ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ) ), axis=nn.conv2d_spatial_axes[0] )
+                        z3 = tf.tile ( z3, tile_cfg )
+
+                        x = x0*z0 + x1*z1 + x2*z2 + x3*z3
+                    else:
+                        x = tf.nn.sigmoid(self.out_conv(x))
+
                     m = self.upscalem0(z)
                     m = self.upscalem1(m)
                     m = self.upscalem2(m)
+                    if 'd' in opts:
+                        m = self.upscalem3(m)
+
+                    m = tf.nn.sigmoid(self.out_convm(m))

-                    return tf.nn.sigmoid(self.out_conv(x)), \
-                           tf.nn.sigmoid(self.out_convm(m))
+                    return x, m

         self.Encoder = Encoder
         self.Inter = Inter
         self.Decoder = Decoder

View file

@@ -40,7 +40,15 @@ patch_discriminator_kernels = \
                     35 : (256, [ [3,2], [4,2], [4,2], [3,1] ]),
                     36 : (256, [ [4,2], [3,2], [4,2], [3,1] ]),
                     37 : (256, [ [3,2], [4,2], [4,2], [3,1] ]),
                     38 : (256, [ [4,2], [4,2], [4,2], [3,1] ]),
+                    39 : (256, [ [3,2], [3,2], [3,2], [4,1] ]),
+                    40 : (256, [ [4,2], [3,2], [3,2], [4,1] ]),
+                    41 : (256, [ [3,2], [4,2], [3,2], [4,1] ]),
+                    42 : (256, [ [4,2], [4,2], [3,2], [4,1] ]),
+                    43 : (256, [ [3,2], [4,2], [4,2], [4,1] ]),
+                    44 : (256, [ [4,2], [3,2], [4,2], [4,1] ]),
+                    45 : (256, [ [3,2], [4,2], [4,2], [4,1] ]),
+                    46 : (256, [ [4,2], [4,2], [4,2], [4,1] ]),
                   }

View file

@@ -120,21 +120,21 @@ nn.upsample2d = upsample2d
 def resize2d_bilinear(x, size=2):
     h = x.shape[nn.conv2d_spatial_axes[0]].value
     w = x.shape[nn.conv2d_spatial_axes[1]].value

     if nn.data_format == "NCHW":
         x = tf.transpose(x, (0,2,3,1))

     if size > 0:
         new_size = (h*size,w*size)
     else:
         new_size = (h//-size,w//-size)

     x = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.BILINEAR)

     if nn.data_format == "NCHW":
         x = tf.transpose(x, (0,3,1,2))

     return x
 nn.resize2d_bilinear = resize2d_bilinear

View file

@@ -5,96 +5,6 @@ import time
 import numpy as np

-class Index2DHost():
-    """
-    Provides random shuffled 2D indexes for multiprocesses
-    """
-    def __init__(self, indexes2D):
-        self.sq = multiprocessing.Queue()
-        self.cqs = []
-        self.clis = []
-        self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) )
-        self.thread.daemon = True
-        self.thread.start()
-
-    def host_thread(self, indexes2D):
-        indexes_counts_len = len(indexes2D)
-
-        idxs = [*range(indexes_counts_len)]
-        idxs_2D = [None]*indexes_counts_len
-        shuffle_idxs = []
-        shuffle_idxs_2D = [None]*indexes_counts_len
-
-        for i in range(indexes_counts_len):
-            idxs_2D[i] = indexes2D[i]
-            shuffle_idxs_2D[i] = []
-
-        sq = self.sq
-
-        while True:
-            while not sq.empty():
-                obj = sq.get()
-                cq_id, cmd = obj[0], obj[1]
-
-                if cmd == 0: #get_1D
-                    count = obj[2]
-
-                    result = []
-                    for i in range(count):
-                        if len(shuffle_idxs) == 0:
-                            shuffle_idxs = idxs.copy()
-                            np.random.shuffle(shuffle_idxs)
-                        result.append(shuffle_idxs.pop())
-                    self.cqs[cq_id].put (result)
-                elif cmd == 1: #get_2D
-                    targ_idxs,count = obj[2], obj[3]
-
-                    result = []
-                    for targ_idx in targ_idxs:
-                        sub_idxs = []
-                        for i in range(count):
-                            ar = shuffle_idxs_2D[targ_idx]
-                            if len(ar) == 0:
-                                ar = shuffle_idxs_2D[targ_idx] = idxs_2D[targ_idx].copy()
-                                np.random.shuffle(ar)
-                            sub_idxs.append(ar.pop())
-                        result.append (sub_idxs)
-                    self.cqs[cq_id].put (result)
-
-            time.sleep(0.005)
-
-    def create_cli(self):
-        cq = multiprocessing.Queue()
-        self.cqs.append ( cq )
-        cq_id = len(self.cqs)-1
-        return Index2DHost.Cli(self.sq, cq, cq_id)
-
-    # disable pickling
-    def __getstate__(self):
-        return dict()
-    def __setstate__(self, d):
-        self.__dict__.update(d)
-
-    class Cli():
-        def __init__(self, sq, cq, cq_id):
-            self.sq = sq
-            self.cq = cq
-            self.cq_id = cq_id
-
-        def get_1D(self, count):
-            self.sq.put ( (self.cq_id,0, count) )
-
-            while True:
-                if not self.cq.empty():
-                    return self.cq.get()
-                time.sleep(0.001)
-
-        def get_2D(self, idxs, count):
-            self.sq.put ( (self.cq_id,1,idxs,count) )
-
-            while True:
-                if not self.cq.empty():
-                    return self.cq.get()
-                time.sleep(0.001)
-
 class IndexHost():
     """
@@ -108,9 +18,9 @@ class IndexHost():
         self.thread.daemon = True
         self.thread.start()

     def host_thread(self, indexes_count, rnd_seed):
         rnd_state = np.random.RandomState(rnd_seed) if rnd_seed is not None else np.random
         idxs = [*range(indexes_count)]
         shuffle_idxs = []
         sq = self.sq
@@ -156,6 +66,95 @@ class IndexHost():
                     return self.cq.get()
                 time.sleep(0.001)

+class Index2DHost():
+    """
+    Provides random shuffled indexes for multiprocesses
+    """
+    def __init__(self, indexes2D):
+        self.sq = multiprocessing.Queue()
+        self.cqs = []
+        self.clis = []
+        self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) )
+        self.thread.daemon = True
+        self.thread.start()
+
+    def host_thread(self, indexes2D):
+        indexes2D_len = len(indexes2D)
+
+        idxs = [*range(indexes2D_len)]
+        idxs_2D = [None]*indexes2D_len
+        shuffle_idxs = []
+        shuffle_idxs_2D = [None]*indexes2D_len
+
+        for i in range(indexes2D_len):
+            idxs_2D[i] = [*range(len(indexes2D[i]))]
+            shuffle_idxs_2D[i] = []
+
+        #print(idxs)
+        #print(idxs_2D)
+        sq = self.sq
+
+        while True:
+            while not sq.empty():
+                obj = sq.get()
+                cq_id, count = obj[0], obj[1]
+
+                result = []
+                for i in range(count):
+                    if len(shuffle_idxs) == 0:
+                        shuffle_idxs = idxs.copy()
+                        np.random.shuffle(shuffle_idxs)
+
+                    idx_1D = shuffle_idxs.pop()
+
+                    #print(f'idx_1D = {idx_1D}, len(shuffle_idxs_2D[idx_1D])= {len(shuffle_idxs_2D[idx_1D])}')
+
+                    if len(shuffle_idxs_2D[idx_1D]) == 0:
+                        shuffle_idxs_2D[idx_1D] = idxs_2D[idx_1D].copy()
+                        #print(f'new shuffle_idxs_2d for {idx_1D} = { shuffle_idxs_2D[idx_1D] }')
+                        #print(f'len(shuffle_idxs_2D[idx_1D])= {len(shuffle_idxs_2D[idx_1D])}')
+                        np.random.shuffle( shuffle_idxs_2D[idx_1D] )
+
+                    idx_2D = shuffle_idxs_2D[idx_1D].pop()
+                    #print(f'len(shuffle_idxs_2D[idx_1D])= {len(shuffle_idxs_2D[idx_1D])}')
+
+                    #print(f'idx_2D = {idx_2D}')
+
+                    result.append( indexes2D[idx_1D][idx_2D])
+
+                self.cqs[cq_id].put (result)
+
+            time.sleep(0.001)
+
+    def create_cli(self):
+        cq = multiprocessing.Queue()
+        self.cqs.append ( cq )
+        cq_id = len(self.cqs)-1
+        return Index2DHost.Cli(self.sq, cq, cq_id)
+
+    # disable pickling
+    def __getstate__(self):
+        return dict()
+    def __setstate__(self, d):
+        self.__dict__.update(d)
+
+    class Cli():
+        def __init__(self, sq, cq, cq_id):
+            self.sq = sq
+            self.cq = cq
+            self.cq_id = cq_id
+
+        def multi_get(self, count):
+            self.sq.put ( (self.cq_id,count) )
+
+            while True:
+                if not self.cq.empty():
+                    return self.cq.get()
+                time.sleep(0.001)
+
 class ListHost():
     def __init__(self, list_):
         self.sq = multiprocessing.Queue()

View file

@@ -157,7 +157,7 @@ class ModelBase(object):
             else:
                 self.device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(suggest_best_multi_gpu=True)) \
                                      if not cpu_only else nn.DeviceConfig.CPU()

             nn.initialize(self.device_config)

             ####
@@ -296,9 +296,15 @@ class ModelBase(object):
         default_random_flip = self.load_or_def_option('random_flip', True)
         self.options['random_flip'] = io.input_bool("Flip faces randomly", default_random_flip, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")

-    def ask_batch_size(self, suggest_batch_size=None):
+    def ask_batch_size(self, suggest_batch_size=None, range=None):
         default_batch_size = self.load_or_def_option('batch_size', suggest_batch_size or self.batch_size)

-        self.options['batch_size'] = self.batch_size = max(0, io.input_int("Batch_size", default_batch_size, help_message="Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
+        batch_size = max(0, io.input_int("Batch_size", default_batch_size, valid_range=range, help_message="Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
+
+        if range is not None:
+            batch_size = np.clip(batch_size, range[0], range[1])
+
+        self.options['batch_size'] = self.batch_size = batch_size

     #overridable
View file

@@ -22,8 +22,9 @@ class QModel(ModelBase):
         resolution = self.resolution = 96
         self.face_type = FaceType.FULL
         ae_dims = 128
-        e_dims = 128
+        e_dims = 64
         d_dims = 64
+        d_mask_dims = 16

         self.pretrain = False
         self.pretrain_just_disabled = False
@@ -39,7 +40,7 @@ class QModel(ModelBase):
         self.model_filename_list = []

-        model_archi = nn.DeepFakeArchi(resolution, mod='quick')
+        model_archi = nn.DeepFakeArchi(resolution, opts='ud')

         with tf.device ('/CPU:0'):
             #Place holders on CPU
@@ -57,11 +58,11 @@ class QModel(ModelBase):
             self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
             encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))

-            self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, d_ch=d_dims, name='inter')
+            self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
             inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))

-            self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_src')
-            self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_dst')
+            self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src')
+            self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst')

             self.model_filename_list += [ [self.encoder, 'encoder.npy' ],
                                           [self.inter,   'inter.npy' ],
View file

@ -31,18 +31,23 @@ class SAEHDModel(ModelBase):
default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 128) default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 128)
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f') default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f')
default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True) default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
default_archi = self.options['archi'] = self.load_or_def_option('archi', 'df')
archi = self.load_or_def_option('archi', 'df')
archi = {'dfuhd':'df-u','liaeuhd':'liae-u'}.get(archi, archi) #backward comp
default_archi = self.options['archi'] = archi
default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256) default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64) default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None) default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None)
default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None) default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None)
default_masked_training = self.options['masked_training'] = self.load_or_def_option('masked_training', True) default_masked_training = self.options['masked_training'] = self.load_or_def_option('masked_training', True)
default_eyes_prio = self.options['eyes_prio'] = self.load_or_def_option('eyes_prio', False) default_eyes_prio = self.options['eyes_prio'] = self.load_or_def_option('eyes_prio', False)
default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False)
lr_dropout = self.load_or_def_option('lr_dropout', 'n') lr_dropout = self.load_or_def_option('lr_dropout', 'n')
lr_dropout = {True:'y', False:'n'}.get(lr_dropout, lr_dropout) #backward comp lr_dropout = {True:'y', False:'n'}.get(lr_dropout, lr_dropout) #backward comp
default_lr_dropout = self.options['lr_dropout'] = lr_dropout default_lr_dropout = self.options['lr_dropout'] = lr_dropout
default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True) default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True)
default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0) default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
default_true_face_power = self.options['true_face_power'] = self.load_or_def_option('true_face_power', 0.0) default_true_face_power = self.options['true_face_power'] = self.load_or_def_option('true_face_power', 0.0)
@ -61,20 +66,43 @@ class SAEHDModel(ModelBase):
self.ask_batch_size(suggest_batch_size) self.ask_batch_size(suggest_batch_size)
if self.is_first_run(): if self.is_first_run():
resolution = io.input_int("Resolution", default_resolution, add_info="64-512", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.") resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.")
resolution = np.clip ( (resolution // 16) * 16, 64, 512) resolution = np.clip ( (resolution // 16) * 16, 64, 640)
self.options['resolution'] = resolution self.options['resolution'] = resolution
self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face. 'Whole face' covers full area of face include forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower() self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face. 'Whole face' covers full area of face include forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower()
self.options['archi'] = io.input_str ("AE architecture", default_archi, ['df','liae','dfhd','liaehd','dfuhd','liaeuhd'], help_message=\
while True:
archi = io.input_str ("AE architecture", default_archi, help_message=\
""" """
'df' keeps faces more natural. 'df' keeps more identity-preserved face.
'liae' can fix overly different face shapes. 'liae' can fix overly different face shapes.
'hd' experimental versions. '-u' increased likeness of the face.
'uhd' increased likeness of the face. '-d' (experimental) doubling the resolution using the same computation cost.
Examples: df, liae, df-d, df-ud, liae-ud, ...
""").lower() """).lower()
default_d_dims = 48 if self.options['archi'] == 'dfhd' else 64 archi_split = archi.split('-')
default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', default_d_dims)
if len(archi_split) == 2:
archi_type, archi_opts = archi_split
elif len(archi_split) == 1:
archi_type, archi_opts = archi_split[0], None
else:
continue
if archi_type not in ['df', 'liae']:
continue
if archi_opts is not None:
if len(archi_opts) == 0:
continue
if len([ 1 for opt in archi_opts if opt not in ['u','d'] ]) != 0:
continue
break
self.options['archi'] = archi
default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64)
default_d_mask_dims = default_d_dims // 3 default_d_mask_dims = default_d_dims // 3
default_d_mask_dims += default_d_mask_dims % 2 default_d_mask_dims += default_d_mask_dims % 2
@ -86,7 +114,6 @@ class SAEHDModel(ModelBase):
e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
self.options['e_dims'] = e_dims + e_dims % 2 self.options['e_dims'] = e_dims + e_dims % 2
d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
self.options['d_dims'] = d_dims + d_dims % 2 self.options['d_dims'] = d_dims + d_dims % 2
@ -97,13 +124,14 @@ class SAEHDModel(ModelBase):
if self.options['face_type'] == 'wf' or self.options['face_type'] == 'head': if self.options['face_type'] == 'wf' or self.options['face_type'] == 'head':
self.options['masked_training'] = io.input_bool ("Masked training", default_masked_training, help_message="This option is available only for 'whole_face' type. Masked training clips training area to full_face mask, thus network will train the faces properly. When the face is trained enough, disable this option to train all area of the frame. Merge with 'raw-rgb' mode, then use Adobe After Effects to manually mask and compose whole face include forehead.") self.options['masked_training'] = io.input_bool ("Masked training", default_masked_training, help_message="This option is available only for 'whole_face' type. Masked training clips training area to full_face mask, thus network will train the faces properly. When the face is trained enough, disable this option to train all area of the frame. Merge with 'raw-rgb' mode, then use Adobe After Effects to manually mask and compose whole face include forehead.")
self.options['eyes_prio'] = io.input_bool ("Eyes priority", default_eyes_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction ( especially on HD architectures ) by forcing the neural network to train eyes with higher priority. before/after https://i.imgur.com/YQHOuSR.jpg ') self.options['eyes_prio'] = io.input_bool ("Eyes priority", default_eyes_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction ( especially on HD architectures ) by forcing the neural network to train eyes with higher priority. before/after https://i.imgur.com/YQHOuSR.jpg ')
self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')
if self.is_first_run() or ask_override: if self.is_first_run() or ask_override:
self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.") self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.")
self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations.\nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.") self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations.\nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.")
self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.") self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.")
self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 10.0", help_message="Train the network in Generative Adversarial manner. Accelerates the speed of training. Forces the neural network to learn small details of the face. You can enable/disable this option at any time. Typical value is 1.0"), 0.0, 10.0 ) self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 10.0", help_message="Train the network in Generative Adversarial manner. Accelerates the speed of training. Forces the neural network to learn small details of the face. You can enable/disable this option at any time. Typical value is 1.0"), 0.0, 10.0 )
@ -142,8 +170,14 @@ class SAEHDModel(ModelBase):
'head' : FaceType.HEAD}[ self.options['face_type'] ] 'head' : FaceType.HEAD}[ self.options['face_type'] ]
eyes_prio = self.options['eyes_prio'] eyes_prio = self.options['eyes_prio']
archi = self.options['archi']
is_hd = 'hd' in archi archi_split = self.options['archi'].split('-')
if len(archi_split) == 2:
archi_type, archi_opts = archi_split
elif len(archi_split) == 1:
archi_type, archi_opts = archi_split[0], None
ae_dims = self.options['ae_dims'] ae_dims = self.options['ae_dims']
e_dims = self.options['e_dims'] e_dims = self.options['e_dims']
d_dims = self.options['d_dims'] d_dims = self.options['d_dims']
@ -180,18 +214,18 @@ class SAEHDModel(ModelBase):
self.target_dstm_all = tf.placeholder (nn.floatx, mask_shape) self.target_dstm_all = tf.placeholder (nn.floatx, mask_shape)
# Initializing model classes # Initializing model classes
model_archi = nn.DeepFakeArchi(resolution, mod='uhd' if 'uhd' in archi else None) model_archi = nn.DeepFakeArchi(resolution, opts=archi_opts)
with tf.device (models_opt_device): with tf.device (models_opt_device):
if 'df' in archi: if 'df' in archi_type:
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder') self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape)) encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))
self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, is_hd=is_hd, name='inter') self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch))) inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_src') self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src')
self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_dst') self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst')
self.model_filename_list += [ [self.encoder, 'encoder.npy' ], self.model_filename_list += [ [self.encoder, 'encoder.npy' ],
[self.inter, 'inter.npy' ], [self.inter, 'inter.npy' ],
@ -203,17 +237,17 @@ class SAEHDModel(ModelBase):
self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=model_archi.Inter.get_code_res()*2, name='dis' ) self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=model_archi.Inter.get_code_res()*2, name='dis' )
self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ] self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ]
elif 'liae' in archi: elif 'liae' in archi_type:
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder') self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape)) encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))
self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_AB') self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_AB')
self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_B') self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_B')
inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.floatx, (None,encoder_out_ch))) inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.floatx, (None,encoder_out_ch))) inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
inters_out_ch = inter_AB_out_ch+inter_B_out_ch inters_out_ch = inter_AB_out_ch+inter_B_out_ch
self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder') self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder')
self.model_filename_list += [ [self.encoder, 'encoder.npy'], self.model_filename_list += [ [self.encoder, 'encoder.npy'],
[self.inter_AB, 'inter_AB.npy'], [self.inter_AB, 'inter_AB.npy'],
@ -223,22 +257,23 @@ class SAEHDModel(ModelBase):
if self.is_training: if self.is_training:
if gan_power != 0: if gan_power != 0:
self.D_src = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_src") self.D_src = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_src")
self.D_dst = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_dst") self.D_src_x2 = nn.PatchDiscriminator(patch_size=resolution//32, in_ch=input_ch, name="D_src_x2")
self.model_filename_list += [ [self.D_src, 'D_src.npy'] ] self.model_filename_list += [ [self.D_src, 'D_src.npy'] ]
self.model_filename_list += [ [self.D_dst, 'D_dst.npy'] ] self.model_filename_list += [ [self.D_src_x2, 'D_src_x2.npy'] ]
# Initialize optimizers # Initialize optimizers
lr=5e-5 lr=5e-5
lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] and not self.pretrain else 1.0 lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] and not self.pretrain else 1.0
clipnorm = 1.0 if self.options['clipgrad'] else 0.0 clipnorm = 1.0 if self.options['clipgrad'] else 0.0
self.src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] if 'df' in archi_type:
if 'df' in archi:
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights() self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
elif 'liae' in archi: elif 'liae' in archi_type:
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights() self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()
self.src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
self.src_dst_opt.initialize_variables (self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu') self.src_dst_opt.initialize_variables (self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')
self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
if self.options['true_face_power'] != 0: if self.options['true_face_power'] != 0:
self.D_code_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_code_opt') self.D_code_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_code_opt')
@ -247,7 +282,7 @@ class SAEHDModel(ModelBase):
if gan_power != 0: if gan_power != 0:
self.D_src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_src_dst_opt') self.D_src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_src_dst_opt')
self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights()+self.D_dst.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu') self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights()+self.D_src_x2.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')
self.model_filename_list += [ (self.D_src_dst_opt, 'D_src_dst_opt.npy') ] self.model_filename_list += [ (self.D_src_dst_opt, 'D_src_dst_opt.npy') ]
if self.is_training: if self.is_training:
@ -284,14 +319,14 @@ class SAEHDModel(ModelBase):
gpu_target_dstm_all = self.target_dstm_all[batch_slice,:,:,:] gpu_target_dstm_all = self.target_dstm_all[batch_slice,:,:,:]
# process model tensors # process model tensors
if 'df' in archi: if 'df' in archi_type:
gpu_src_code = self.inter(self.encoder(gpu_warped_src)) gpu_src_code = self.inter(self.encoder(gpu_warped_src))
gpu_dst_code = self.inter(self.encoder(gpu_warped_dst)) gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code) gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
elif 'liae' in archi: elif 'liae' in archi_type:
gpu_src_code = self.encoder (gpu_warped_src) gpu_src_code = self.encoder (gpu_warped_src)
gpu_src_inter_AB_code = self.inter_AB (gpu_src_code) gpu_src_inter_AB_code = self.inter_AB (gpu_src_code)
gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis ) gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis )
@ -338,7 +373,7 @@ class SAEHDModel(ModelBase):
gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
else: else:
gpu_src_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_src_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                        gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
                    gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])

                    if eyes_prio:
@@ -359,9 +394,9 @@ class SAEHDModel(ModelBase):
                        gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                    else:
                        gpu_dst_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                        gpu_dst_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1])
                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt - gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])

                    if eyes_prio:
                        gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_eyes - gpu_pred_dst_dst*gpu_target_dstm_eyes ), axis=[1,2,3])
@@ -396,20 +431,21 @@ class SAEHDModel(ModelBase):
                        gpu_pred_src_src_d_zeros = tf.zeros_like(gpu_pred_src_src_d)
                        gpu_target_src_d = self.D_src(gpu_target_src_masked_opt)
                        gpu_target_src_d_ones = tf.ones_like(gpu_target_src_d)

-                       gpu_pred_dst_dst_d = self.D_dst(gpu_pred_dst_dst_masked_opt)
-                       gpu_pred_dst_dst_d_ones = tf.ones_like (gpu_pred_dst_dst_d)
-                       gpu_pred_dst_dst_d_zeros = tf.zeros_like(gpu_pred_dst_dst_d)
-                       gpu_target_dst_d = self.D_dst(gpu_target_dst_masked_opt)
-                       gpu_target_dst_d_ones = tf.ones_like(gpu_target_dst_d)
+                       gpu_pred_src_src_x2_d = self.D_src_x2(gpu_pred_src_src_masked_opt)
+                       gpu_pred_src_src_x2_d_ones = tf.ones_like (gpu_pred_src_src_x2_d)
+                       gpu_pred_src_src_x2_d_zeros = tf.zeros_like(gpu_pred_src_src_x2_d)
+                       gpu_target_src_x2_d = self.D_src_x2(gpu_target_src_masked_opt)
+                       gpu_target_src_x2_d_ones = tf.ones_like(gpu_target_src_x2_d)

-                       gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones   , gpu_target_src_d) + \
-                                             DLoss(gpu_pred_src_src_d_zeros, gpu_pred_src_src_d) ) * 0.5 + \
-                                            (DLoss(gpu_target_dst_d_ones   , gpu_target_dst_d) + \
-                                             DLoss(gpu_pred_dst_dst_d_zeros, gpu_pred_dst_dst_d) ) * 0.5
+                       gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones      , gpu_target_src_d) + \
+                                             DLoss(gpu_pred_src_src_d_zeros   , gpu_pred_src_src_d) ) * 0.5 + \
+                                            (DLoss(gpu_target_src_x2_d_ones   , gpu_target_src_x2_d) + \
+                                             DLoss(gpu_pred_src_src_x2_d_zeros, gpu_pred_src_src_x2_d) ) * 0.5

-                       gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.D_src.get_weights()+self.D_dst.get_weights() ) ]
+                       gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.D_src.get_weights()+self.D_src_x2.get_weights() ) ]

-                       gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + DLoss(gpu_pred_dst_dst_d_ones, gpu_pred_dst_dst_d))
+                       gpu_G_loss += 0.5*gan_power*( DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + DLoss(gpu_pred_src_src_x2_d_ones, gpu_pred_src_src_x2_d))

                    gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ]
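The hunk above drops the dst discriminator and instead scores the src result with a second discriminator whose patch size is x2 smaller. A minimal NumPy sketch of the resulting two-scale loss, assuming DLoss is sigmoid cross-entropy over patch logits (all names below are illustrative, not the project's API):

    import numpy as np

    def dloss(labels, logits):
        # numerically stable sigmoid cross-entropy, averaged over all patch scores
        return np.mean(np.maximum(logits, 0) - logits*labels + np.log1p(np.exp(-np.abs(logits))))

    # toy patch-score maps from two discriminators with different receptive fields
    d_real,  d_fake  = np.random.randn(4, 8, 8, 1),   np.random.randn(4, 8, 8, 1)
    d2_real, d2_fake = np.random.randn(4, 16, 16, 1), np.random.randn(4, 16, 16, 1)

    gan_power = 0.1
    # discriminators: real -> 1, fake -> 0, averaged per scale
    D_loss = 0.5*(dloss(1.0, d_real)  + dloss(0.0, d_fake)) \
           + 0.5*(dloss(1.0, d2_real) + dloss(0.0, d2_fake))
    # generator: push fakes toward 1 at both scales
    G_gan_loss = 0.5*gan_power*(dloss(1.0, d_fake) + dloss(1.0, d2_fake))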
@@ -474,12 +510,12 @@ class SAEHDModel(ModelBase):
        else:
            # Initializing merge function
            with tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
-               if 'df' in archi:
+               if 'df' in archi_type:
                    gpu_dst_code = self.inter(self.encoder(self.warped_dst))
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                    _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

-               elif 'liae' in archi:
+               elif 'liae' in archi_type:
                    gpu_dst_code = self.encoder (self.warped_dst)
                    gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
                    gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
@@ -499,10 +535,10 @@ class SAEHDModel(ModelBase):
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
            if self.pretrain_just_disabled:
                do_init = False
-               if 'df' in archi:
+               if 'df' in archi_type:
                    if model == self.inter:
                        do_init = True
-               elif 'liae' in archi:
+               elif 'liae' in archi_type:
                    if model == self.inter_AB or model == self.inter_B:
                        do_init = True
            else:
@@ -534,6 +570,7 @@ class SAEHDModel(ModelBase):
                        {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                        {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                    ],
+                   uniform_yaw_distribution=self.options['uniform_yaw'],
                    generators_count=src_generators_count ),

                SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
@@ -542,6 +579,7 @@ class SAEHDModel(ModelBase):
                        {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                        {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                    ],
+                   uniform_yaw_distribution=self.options['uniform_yaw'],
                    generators_count=dst_generators_count )
                ])
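Both generators now receive the new 'uniform_yaw' model option. A hypothetical sketch of how the option could be asked earlier in on_initialize_options() — the exact prompt text and default in this commit may differ; io.input_bool is the existing interact helper:

    # hypothetical sketch of the option prompt
    self.options['uniform_yaw'] = io.input_bool(
        "Uniform yaw distribution of samples",
        self.options.get('uniform_yaw', False),
        help_message="Helps to fix blurry side faces due to small amount of them in the faceset.")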
@@ -563,6 +601,9 @@ class SAEHDModel(ModelBase):
    #override
    def onTrainOneIter(self):
+       if self.get_iter() == 0 and not self.pretrain and not self.pretrain_just_disabled:
+           io.log_info('You are training the model from scratch. It is strongly recommended to use a pretrained model to speed up the training and improve the quality.\n')
+
        bs = self.get_batch_size()

        ( (warped_src, target_src, target_srcm_all), \

View file

@@ -17,9 +17,7 @@ class XSegModel(ModelBase):
        super().__init__(*args, force_model_class_name='XSeg', **kwargs)

    #override
    def on_initialize_options(self):
-       self.set_batch_size(4)
        ask_override = self.ask_override()

        if not self.is_first_run() and ask_override:
@@ -30,7 +28,9 @@ class XSegModel(ModelBase):
        if self.is_first_run():
            self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Choose the same as your deepfake model.").lower()

+       if self.is_first_run() or ask_override:
+           self.ask_batch_size(4, range=[2,16])

    #override
    def on_initialize(self):
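The hard-coded batch size of 4 is replaced by a prompt: ask_batch_size(4, range=[2,16]) ultimately routes through io.input_int, which now accepts a valid_range and clamps out-of-range input. A small illustrative call (the prompt text here is made up):

    # illustrative only: values outside [2,16] are clamped to the range
    batch_size = io.input_int("Batch_size", 4, valid_range=[2,16],
                              help_message="Batch size for XSeg training.")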

View file

@@ -5,8 +5,8 @@ import cv2
import numpy as np

from core.cv2ex import *
-from DFLIMG import *
from facelib import LandmarksProcessor
+from core import imagelib
from core.imagelib import SegIEPolys

class SampleType(IntEnum):
@@ -28,6 +28,7 @@ class Sample(object):
                'landmarks',
                'seg_ie_polys',
                'xseg_mask',
+               'xseg_mask_compressed',
                'eyebrows_expand_mod',
                'source_filename',
                'person_name',
@@ -42,6 +43,7 @@ class Sample(object):
                       landmarks=None,
                       seg_ie_polys=None,
                       xseg_mask=None,
+                      xseg_mask_compressed=None,
                       eyebrows_expand_mod=None,
                       source_filename=None,
                       person_name=None,
@@ -60,6 +62,16 @@ class Sample(object):
        self.seg_ie_polys = SegIEPolys.load(seg_ie_polys)
        self.xseg_mask = xseg_mask
+       self.xseg_mask_compressed = xseg_mask_compressed
+
+       if self.xseg_mask_compressed is None and self.xseg_mask is not None:
+           xseg_mask = np.clip( imagelib.normalize_channels(xseg_mask, 1)*255, 0, 255 ).astype(np.uint8)
+           ret, xseg_mask_compressed = cv2.imencode('.png', xseg_mask)
+           if not ret:
+               raise Exception("Sample(): unable to generate xseg_mask_compressed")
+           self.xseg_mask_compressed = xseg_mask_compressed
+           self.xseg_mask = None
+
        self.eyebrows_expand_mod = eyebrows_expand_mod if eyebrows_expand_mod is not None else 1.0
        self.source_filename = source_filename
        self.person_name = person_name
@@ -67,6 +79,14 @@ class Sample(object):
        self._filename_offset_size = None

+   def get_xseg_mask(self):
+       if self.xseg_mask_compressed is not None:
+           xseg_mask = cv2.imdecode(self.xseg_mask_compressed, cv2.IMREAD_UNCHANGED)
+           if len(xseg_mask.shape) == 2:
+               xseg_mask = xseg_mask[...,None]
+           return xseg_mask.astype(np.float32) / 255.0
+       return self.xseg_mask
+
    def get_pitch_yaw_roll(self):
        if self.pitch_yaw_roll is None:
            self.pitch_yaw_roll = LandmarksProcessor.estimate_pitch_yaw_roll(self.landmarks, size=self.shape[1])
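The compressed mask is just a lossless PNG byte buffer, so the encode/decode pair above round-trips a 0/1 mask exactly while cutting its in-memory and on-disk size. A standalone sketch of the same scheme:

    import cv2
    import numpy as np

    mask = (np.random.rand(256, 256, 1) > 0.5).astype(np.float32)  # toy 0/1 mask
    u8 = np.clip(mask*255, 0, 255).astype(np.uint8)
    ret, buf = cv2.imencode('.png', u8)                            # lossless PNG bytes
    assert ret

    restored = cv2.imdecode(buf, cv2.IMREAD_UNCHANGED)
    if restored.ndim == 2:
        restored = restored[..., None]                             # cv2 drops the channel axis
    restored = restored.astype(np.float32) / 255.0
    assert np.allclose(mask, restored)                             # exact for 0/1 masks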
@@ -97,6 +117,7 @@ class Sample(object):
                'landmarks': self.landmarks.tolist(),
                'seg_ie_polys': self.seg_ie_polys.dump(),
                'xseg_mask' : self.xseg_mask,
+               'xseg_mask_compressed' : self.xseg_mask_compressed,
                'eyebrows_expand_mod': self.eyebrows_expand_mod,
                'source_filename': self.source_filename,
                'person_name': self.person_name

View file

@@ -6,11 +6,13 @@ import cv2
import numpy as np

from core import mplib
+from core.interact import interact as io
from core.joblib import SubprocessGenerator, ThisThreadGenerator
from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
                       SampleType)

'''
arg
output_sample_types = [
@@ -23,15 +25,15 @@ class SampleGeneratorFace(SampleGeneratorBase):
                        random_ct_samples_path=None,
                        sample_process_options=SampleProcessor.Options(),
                        output_sample_types=[],
-                       add_sample_idx=False,
+                       uniform_yaw_distribution=False,
                        generators_count=4,
                        raise_on_no_data=True,
                        **kwargs):

        super().__init__(debug, batch_size)
+       self.initialized = False
        self.sample_process_options = sample_process_options
        self.output_sample_types = output_sample_types
-       self.add_sample_idx = add_sample_idx

        if self.debug:
            self.generators_count = 1
@@ -40,15 +42,40 @@ class SampleGeneratorFace(SampleGeneratorBase):
        samples = SampleLoader.load (SampleType.FACE, samples_path)
        self.samples_len = len(samples)
-       self.initialized = False

        if self.samples_len == 0:
            if raise_on_no_data:
                raise ValueError('No training data provided.')
            else:
                return

-       index_host = mplib.IndexHost(self.samples_len)
+       if uniform_yaw_distribution:
+           samples_pyr = [ ( idx, sample.get_pitch_yaw_roll() ) for idx, sample in enumerate(samples) ]
+
+           grads = 128
+           # instead of math.pi / 2, using -1.2..+1.2, because the actual maximum yaw
+           # for 2DFAN landmarks is about -1.2..+1.2
+           grads_space = np.linspace (-1.2, 1.2, grads)
+
+           yaws_sample_list = [None]*grads
+           for g in io.progress_bar_generator ( range(grads), "Sort by yaw"):
+               yaw = grads_space[g]
+               next_yaw = grads_space[g+1] if g < grads-1 else yaw
+
+               yaw_samples = []
+               for idx, pyr in samples_pyr:
+                   s_yaw = -pyr[1]
+                   if (g == 0 and s_yaw < next_yaw) or \
+                      (g < grads-1 and s_yaw >= yaw and s_yaw < next_yaw) or \
+                      (g == grads-1 and s_yaw >= yaw):
+                       yaw_samples += [ idx ]
+               if len(yaw_samples) > 0:
+                   yaws_sample_list[g] = yaw_samples
+
+           yaws_sample_list = [ y for y in yaws_sample_list if y is not None ]
+
+           index_host = mplib.Index2DHost( yaws_sample_list )
+       else:
+           index_host = mplib.IndexHost(self.samples_len)
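The block above splits the yaw range [-1.2, 1.2] into 128 bins, keeps only non-empty bins, and lets Index2DHost pick a bin uniformly before picking a sample inside it, so rare side-facing yaws are drawn as often as frontal ones. The same idea as a standalone NumPy sketch, where np.digitize replaces the manual bin loop:

    import numpy as np

    yaws = np.random.randn(1000)*0.4             # toy yaw angles, mostly frontal
    grads_space = np.linspace(-1.2, 1.2, 128)
    bins = np.digitize(yaws, grads_space)        # bin index per sample
    bucketed = [np.where(bins == b)[0] for b in range(129)]
    bucketed = [b for b in bucketed if len(b) > 0]

    # uniform over non-empty bins, then uniform within the chosen bin
    chosen_bin = bucketed[np.random.randint(len(bucketed))]
    sample_idx = chosen_bin[np.random.randint(len(chosen_bin))]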
        if random_ct_samples_path is not None:
            ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path)
@@ -110,14 +137,8 @@ class SampleGeneratorFace(SampleGeneratorBase):
                if batches is None:
                    batches = [ [] for _ in range(len(x)) ]
-                   if self.add_sample_idx:
-                       batches += [ [] ]
-                       i_sample_idx = len(batches)-1

                for i in range(len(x)):
                    batches[i].append ( x[i] )
-               if self.add_sample_idx:
-                   batches[i_sample_idx].append (sample_idx)

            yield [ np.array(batch) for batch in batches]

View file

@@ -12,6 +12,98 @@ from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
                       SampleType)

+class Index2DHost():
+    """
+    Provides random shuffled 2D indexes for multiprocesses
+    """
+    def __init__(self, indexes2D):
+        self.sq = multiprocessing.Queue()
+        self.cqs = []
+        self.clis = []
+        self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) )
+        self.thread.daemon = True
+        self.thread.start()
+
+    def host_thread(self, indexes2D):
+        indexes_counts_len = len(indexes2D)
+
+        idxs = [*range(indexes_counts_len)]
+        idxs_2D = [None]*indexes_counts_len
+        shuffle_idxs = []
+        shuffle_idxs_2D = [None]*indexes_counts_len
+
+        for i in range(indexes_counts_len):
+            idxs_2D[i] = indexes2D[i]
+            shuffle_idxs_2D[i] = []
+
+        sq = self.sq
+
+        while True:
+            while not sq.empty():
+                obj = sq.get()
+                cq_id, cmd = obj[0], obj[1]
+
+                if cmd == 0: #get_1D
+                    count = obj[2]
+
+                    result = []
+                    for i in range(count):
+                        if len(shuffle_idxs) == 0:
+                            shuffle_idxs = idxs.copy()
+                            np.random.shuffle(shuffle_idxs)
+                        result.append(shuffle_idxs.pop())
+                    self.cqs[cq_id].put (result)
+                elif cmd == 1: #get_2D
+                    targ_idxs, count = obj[2], obj[3]
+
+                    result = []
+                    for targ_idx in targ_idxs:
+                        sub_idxs = []
+                        for i in range(count):
+                            ar = shuffle_idxs_2D[targ_idx]
+                            if len(ar) == 0:
+                                ar = shuffle_idxs_2D[targ_idx] = idxs_2D[targ_idx].copy()
+                                np.random.shuffle(ar)
+                            sub_idxs.append(ar.pop())
+                        result.append (sub_idxs)
+                    self.cqs[cq_id].put (result)
+
+            time.sleep(0.001)
+
+    def create_cli(self):
+        cq = multiprocessing.Queue()
+        self.cqs.append ( cq )
+        cq_id = len(self.cqs)-1
+        return Index2DHost.Cli(self.sq, cq, cq_id)
+
+    # disable pickling
+    def __getstate__(self):
+        return dict()
+    def __setstate__(self, d):
+        self.__dict__.update(d)
+
+    class Cli():
+        def __init__(self, sq, cq, cq_id):
+            self.sq = sq
+            self.cq = cq
+            self.cq_id = cq_id
+
+        def get_1D(self, count):
+            self.sq.put ( (self.cq_id, 0, count) )
+
+            while True:
+                if not self.cq.empty():
+                    return self.cq.get()
+                time.sleep(0.001)
+
+        def get_2D(self, idxs, count):
+            self.sq.put ( (self.cq_id, 1, idxs, count) )
+
+            while True:
+                if not self.cq.empty():
+                    return self.cq.get()
+                time.sleep(0.001)
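Usage pattern for the class above: the host lives in the main process, and each worker gets a Cli created before the worker is spawned (the host itself is made unpicklable, so only the queue endpoints travel across processes). A short sketch with toy data:

    # e.g. indexes2D = per-person lists of sample indexes
    host = Index2DHost([[0, 1, 2], [3, 4], [5, 6, 7, 8]])
    cli  = host.create_cli()

    person_ids = cli.get_1D(2)              # two shuffled outer indexes
    samples    = cli.get_2D(person_ids, 4)  # four shuffled sample indexes per chosen person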
'''
arg
output_sample_types = [
@@ -45,7 +137,7 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
        for i,sample in enumerate(samples):
            persons_name_idxs[sample.person_name].append (i)
        indexes2D = [ persons_name_idxs[person_name] for person_name in unique_person_names ]
-       index2d_host = mplib.Index2DHost(indexes2D)
+       index2d_host = Index2DHost(indexes2D)

        if self.debug:
            self.generators_count = 1

View file

@@ -25,7 +25,7 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
        samples = sum([ SampleLoader.load (SampleType.FACE, path) for path in paths ] )
        seg_sample_idxs = SegmentedSampleFilterSubprocessor(samples).run()

        seg_samples_len = len(seg_sample_idxs)
        if seg_samples_len == 0:
            raise Exception(f"No segmented faces found.")
@@ -63,9 +63,10 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
    def batch_func(self, param ):
        samples, seg_sample_idxs, resolution, face_type, data_format = param

        shuffle_idxs = []
+       bg_shuffle_idxs = []

        random_flip = True
        rotation_range=[-10,10]
        scale_range=[-0.05, 0.05]
@@ -76,6 +77,25 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
        motion_blur_chance, motion_blur_mb_max_size = 25, 5
        gaussian_blur_chance, gaussian_blur_kernel_max_size = 25, 5

+       def gen_img_mask(sample):
+           img = sample.load_bgr()
+           h,w,c = img.shape
+           mask = np.zeros ((h,w,1), dtype=np.float32)
+           sample.seg_ie_polys.overlay_mask(mask)
+
+           if face_type == sample.face_type:
+               if w != resolution:
+                   img  = cv2.resize( img,  (resolution, resolution), cv2.INTER_LANCZOS4 )
+                   mask = cv2.resize( mask, (resolution, resolution), cv2.INTER_LANCZOS4 )
+           else:
+               mat = LandmarksProcessor.get_transform_mat (sample.landmarks, resolution, face_type)
+               img  = cv2.warpAffine( img,  mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 )
+               mask = cv2.warpAffine( mask, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 )
+
+           if len(mask.shape) == 2:
+               mask = mask[...,None]
+           return img, mask
+
        bs = self.batch_size
        while True:
            batches = [ [], [] ]
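One detail gen_img_mask guards against: OpenCV drops the channel axis of single-channel images, so the mask must be restored to (h, w, 1) after resizing or warping. A two-line demonstration:

    import cv2, numpy as np
    m = np.zeros((64, 64, 1), np.float32)
    print(cv2.resize(m, (32, 32)).shape)   # (32, 32): the channel axis is gone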
@@ -86,30 +106,26 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
                if len(shuffle_idxs) == 0:
                    shuffle_idxs = seg_sample_idxs.copy()
                    np.random.shuffle(shuffle_idxs)
-               idx = shuffle_idxs.pop()
-
-               sample = samples[idx]
-               img = sample.load_bgr()
-               h,w,c = img.shape
-
-               mask = np.zeros ((h,w,1), dtype=np.float32)
-               sample.seg_ie_polys.overlay_mask(mask)
+               sample = samples[shuffle_idxs.pop()]
+               img, mask = gen_img_mask(sample)
+
+               if np.random.randint(2) == 0:
+                   if len(bg_shuffle_idxs) == 0:
+                       bg_shuffle_idxs = seg_sample_idxs.copy()
+                       np.random.shuffle(bg_shuffle_idxs)
+                   bg_sample = samples[bg_shuffle_idxs.pop()]
+                   bg_img, bg_mask = gen_img_mask(bg_sample)
+
+                   bg_wp   = imagelib.gen_warp_params(resolution, True, rotation_range=[-180,180], scale_range=[-0.10, 0.10], tx_range=[-0.10, 0.10], ty_range=[-0.10, 0.10] )
+                   bg_img  = imagelib.warp_by_params (bg_wp, bg_img,  can_warp=False, can_transform=True, can_flip=True, border_replicate=False)
+                   bg_mask = imagelib.warp_by_params (bg_wp, bg_mask, can_warp=False, can_transform=True, can_flip=True, border_replicate=False)
+
+                   c_mask = (1-bg_mask) * (1-mask)
+                   img = img*(1-c_mask) + bg_img * c_mask

                warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range )
-               if face_type == sample.face_type:
-                   if w != resolution:
-                       img  = cv2.resize( img,  (resolution, resolution), cv2.INTER_LANCZOS4 )
-                       mask = cv2.resize( mask, (resolution, resolution), cv2.INTER_LANCZOS4 )
-               else:
-                   mat = LandmarksProcessor.get_transform_mat (sample.landmarks, resolution, face_type)
-                   img  = cv2.warpAffine( img,  mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 )
-                   mask = cv2.warpAffine( mask, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 )
-
-               if len(mask.shape) == 2:
-                   mask = mask[...,None]

                img  = imagelib.warp_by_params (warp_params, img,  can_warp=True, can_transform=True, can_flip=True, border_replicate=False)
                mask = imagelib.warp_by_params (warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False)
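The compositing rule keeps every pixel that belongs to either face and replaces only the pixels that are background in both samples, which is what reduces random mask noise outside the face. A toy 1-D check of the math:

    import numpy as np

    mask    = np.array([1.0, 1.0, 0.0, 0.0])  # face=1 in the main sample
    bg_mask = np.array([0.0, 1.0, 0.0, 1.0])  # face=1 in the background sample
    img, bg_img = np.full(4, 0.2), np.full(4, 0.9)

    c_mask = (1 - bg_mask) * (1 - mask)       # 1 only where BOTH are background
    out = img*(1 - c_mask) + bg_img*c_mask    # face pixels kept, empty background replaced
    # out == [0.2, 0.2, 0.9, 0.2]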
@@ -120,13 +136,13 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
                if np.random.randint(2) == 0:
                    img = imagelib.apply_random_hsv_shift(img, mask=sd.random_circle_faded ([resolution,resolution]))
                else:
                    img = imagelib.apply_random_rgb_levels(img, mask=sd.random_circle_faded ([resolution,resolution]))

                img = imagelib.apply_random_motion_blur( img, motion_blur_chance, motion_blur_mb_max_size, mask=sd.random_circle_faded ([resolution,resolution]))
                img = imagelib.apply_random_gaussian_blur( img, gaussian_blur_chance, gaussian_blur_kernel_max_size, mask=sd.random_circle_faded ([resolution,resolution]))
                img = imagelib.apply_random_bilinear_resize( img, random_bilinear_resize_chance, random_bilinear_resize_max_size_per, mask=sd.random_circle_faded ([resolution,resolution]))

                if data_format == "NCHW":
                    img = np.transpose(img, (2,0,1) )
                    mask = np.transpose(mask, (2,0,1) )
@@ -140,12 +156,10 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
            yield [ np.array(batch) for batch in batches]

class SegmentedSampleFilterSubprocessor(Subprocessor):
    #override
    def __init__(self, samples ):
        self.samples = samples
        self.samples_len = len(self.samples)
        self.idxs = [*range(self.samples_len)]
@@ -164,7 +178,7 @@ class SegmentedSampleFilterSubprocessor(Subprocessor):
    #override
    def on_clients_finalized(self):
        io.progress_bar_close()

    #override
    def get_data(self, host_dict):
        if len (self.idxs) > 0:
@@ -184,11 +198,11 @@ class SegmentedSampleFilterSubprocessor(Subprocessor):
        io.progress_bar_inc(1)

    def get_result(self):
        return self.result

    class Cli(Subprocessor.Cli):
        #overridable optional
        def on_initialize(self, client_dict):
            self.samples = client_dict['samples']

        def process_data(self, idx):
            return idx, self.samples[idx].seg_ie_polys.get_pts_count() != 0

View file

@@ -56,15 +56,16 @@ class SampleProcessor(object):
        ct_sample_bgr = None
        h,w,c = sample_bgr.shape

        def get_full_face_mask():
-           if sample.xseg_mask is not None:
-               full_face_mask = sample.xseg_mask
-               if full_face_mask.shape[0] != h or full_face_mask.shape[1] != w:
-                   full_face_mask = cv2.resize(full_face_mask, (w,h), interpolation=cv2.INTER_CUBIC)
-               full_face_mask = imagelib.normalize_channels(full_face_mask, 1)
+           xseg_mask = sample.get_xseg_mask()
+           if xseg_mask is not None:
+               if xseg_mask.shape[0] != h or xseg_mask.shape[1] != w:
+                   xseg_mask = cv2.resize(xseg_mask, (w,h), interpolation=cv2.INTER_CUBIC)
+               xseg_mask = imagelib.normalize_channels(xseg_mask, 1)
+               return np.clip(xseg_mask, 0, 1)
            else:
                full_face_mask = LandmarksProcessor.get_image_hull_mask (sample_bgr.shape, sample_landmarks, eyebrows_expand_mod=sample.eyebrows_expand_mod )
                return np.clip(full_face_mask, 0, 1)

        def get_eyes_mask():
            eyes_mask = LandmarksProcessor.get_image_eye_mask (sample_bgr.shape, sample_landmarks)
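The np.clip after the resize matters here: INTER_CUBIC overshoots near hard mask edges, producing values slightly below 0 and above 1 that would corrupt a probability mask. A quick demonstration:

    import cv2
    import numpy as np

    m = np.zeros((8, 8), np.float32)
    m[:, 4:] = 1.0                       # hard step edge
    r = cv2.resize(m, (32, 32), interpolation=cv2.INTER_CUBIC)
    print(r.min(), r.max())              # slightly below 0.0 and above 1.0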