First Cioscos commit

-Now the code is at the same point as iperov's
-SAEHD can optionally use fp16 (unstable)
-Other loss functions and background power are not available yet
-Some bug fixes
Cioscos 2021-09-29 01:41:54 +02:00
commit 0fe22be204
20 changed files with 909 additions and 589 deletions

@@ -1,7 +1,137 @@
 import numpy as np
+import numpy.linalg as npla
 import cv2
 from core import randomex
+def mls_rigid_deformation(vy, vx, p, q, alpha=1.0, eps=1e-8):
+    """ Rigid deformation
+
+    Parameters
+    ----------
+    vx, vy: ndarray
+        coordinate grid, generated by np.meshgrid(gridX, gridY)
+    p: ndarray
+        an array with size [n, 2], original control points
+    q: ndarray
+        an array with size [n, 2], final control points
+    alpha: float
+        parameter used by weights
+    eps: float
+        epsilon
+
+    Return
+    ------
+    Deformed coordinate grid with shape [2, grow, gcol] (per-pixel source coordinates).
+    """
+    # Change (x, y) to (row, col)
+    q = np.ascontiguousarray(q[:, [1, 0]].astype(np.int16))
+    p = np.ascontiguousarray(p[:, [1, 0]].astype(np.int16))
+
+    # Exchange p and q and hence we transform destination pixels to the corresponding source pixels.
+    p, q = q, p
+
+    grow = vx.shape[0]  # grid rows
+    gcol = vx.shape[1]  # grid cols
+    ctrls = p.shape[0]  # control points
+
+    # Compute
+    reshaped_p = p.reshape(ctrls, 2, 1, 1)                                           # [ctrls, 2, 1, 1]
+    reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol)))   # [2, grow, gcol]
+
+    w = 1.0 / (np.sum((reshaped_p - reshaped_v).astype(np.float32) ** 2, axis=1) + eps) ** alpha   # [ctrls, grow, gcol]
+    w /= np.sum(w, axis=0, keepdims=True)                                            # [ctrls, grow, gcol]
+
+    pstar = np.zeros((2, grow, gcol), np.float32)
+    for i in range(ctrls):
+        pstar += w[i] * reshaped_p[i]                                                # [2, grow, gcol]
+
+    vpstar = reshaped_v - pstar                                                      # [2, grow, gcol]
+    reshaped_vpstar = vpstar.reshape(2, 1, grow, gcol)                               # [2, 1, grow, gcol]
+    neg_vpstar_verti = vpstar[[1, 0], ...]                                           # [2, grow, gcol]
+    neg_vpstar_verti[1, ...] = -neg_vpstar_verti[1, ...]
+    reshaped_neg_vpstar_verti = neg_vpstar_verti.reshape(2, 1, grow, gcol)           # [2, 1, grow, gcol]
+    mul_right = np.concatenate((reshaped_vpstar, reshaped_neg_vpstar_verti), axis=1) # [2, 2, grow, gcol]
+    reshaped_mul_right = mul_right.reshape(2, 2, grow, gcol)                         # [2, 2, grow, gcol]
+
+    # Calculate q
+    reshaped_q = q.reshape((ctrls, 2, 1, 1))                                         # [ctrls, 2, 1, 1]
+    qstar = np.zeros((2, grow, gcol), np.float32)
+    for i in range(ctrls):
+        qstar += w[i] * reshaped_q[i]                                                # [2, grow, gcol]
+
+    temp = np.zeros((grow, gcol, 2), np.float32)
+    for i in range(ctrls):
+        phat = reshaped_p[i] - pstar                                                 # [2, grow, gcol]
+        reshaped_phat = phat.reshape(1, 2, grow, gcol)                               # [1, 2, grow, gcol]
+        reshaped_w = w[i].reshape(1, 1, grow, gcol)                                  # [1, 1, grow, gcol]
+        neg_phat_verti = phat[[1, 0]]                                                # [2, grow, gcol]
+        neg_phat_verti[1] = -neg_phat_verti[1]
+        reshaped_neg_phat_verti = neg_phat_verti.reshape(1, 2, grow, gcol)           # [1, 2, grow, gcol]
+        mul_left = np.concatenate((reshaped_phat, reshaped_neg_phat_verti), axis=0)  # [2, 2, grow, gcol]
+
+        A = np.matmul((reshaped_w * mul_left).transpose(2, 3, 0, 1),
+                      reshaped_mul_right.transpose(2, 3, 0, 1))                      # [grow, gcol, 2, 2]
+
+        qhat = reshaped_q[i] - qstar                                                 # [2, grow, gcol]
+        reshaped_qhat = qhat.reshape(1, 2, grow, gcol).transpose(2, 3, 0, 1)         # [grow, gcol, 1, 2]
+
+        # Get final image transformer -- 3-D array
+        temp += np.matmul(reshaped_qhat, A).reshape(grow, gcol, 2)                   # [grow, gcol, 2]
+
+    temp = temp.transpose(2, 0, 1)                                                   # [2, grow, gcol]
+    normed_temp = np.linalg.norm(temp, axis=0, keepdims=True)                        # [1, grow, gcol]
+    normed_vpstar = np.linalg.norm(vpstar, axis=0, keepdims=True)                    # [1, grow, gcol]
+    nan_mask = normed_temp[0] == 0
+
+    transformers = np.true_divide(temp, normed_temp, out=np.zeros_like(temp), where=~nan_mask) * normed_vpstar + qstar
+
+    # fix nan values
+    nan_mask_flat = np.flatnonzero(nan_mask)
+    nan_mask_anti_flat = np.flatnonzero(~nan_mask)
+    transformers[0][nan_mask] = np.interp(nan_mask_flat, nan_mask_anti_flat, transformers[0][~nan_mask])
+    transformers[1][nan_mask] = np.interp(nan_mask_flat, nan_mask_anti_flat, transformers[1][~nan_mask])
+
+    return transformers
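
A minimal sketch of driving this deformation end to end, mirroring the commented-out "random warp V2" path in gen_warp_params below (`img` is a hypothetical square BGR array, not part of the commit):

    import cv2
    import numpy as np

    w = img.shape[0]                                # assumes a square image
    pts1, pts2 = gen_pts(w, w)                      # random control point pairs (defined below)
    gridX = np.arange(w, dtype=np.int16)
    gridY = np.arange(w, dtype=np.int16)
    vy, vx = np.meshgrid(gridX, gridY)

    drigid = mls_rigid_deformation(vy, vx, pts1, pts2)
    mapy, mapx = drigid.astype(np.float32)          # per-pixel source coordinates
    warped = cv2.remap(img, mapx, mapy, cv2.INTER_LINEAR)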
+def gen_pts(W, H, rnd_state=None):
+    if rnd_state is None:
+        rnd_state = np.random
+
+    min_pts, max_pts = 4, 16
+    n_pts = rnd_state.randint(min_pts, max_pts)
+
+    min_radius_per = 0.00
+    max_radius_per = 0.10
+
+    pts = []
+    for i in range(n_pts):
+        while True:
+            x, y = rnd_state.randint(W), rnd_state.randint(H)
+            rad = min_radius_per + rnd_state.rand()*(max_radius_per-min_radius_per)
+
+            intersect = False
+            for px,py,prad,_,_ in pts:
+                dist = npla.norm([x-px, y-py])
+                if dist <= (rad+prad)*2:
+                    intersect = True
+                    break
+            if intersect:
+                continue
+
+            angle = rnd_state.rand()*(2*np.pi)
+            x2 = int(x+np.cos(angle)*W*rad)
+            y2 = int(y+np.sin(angle)*H*rad)
+            break
+
+        pts.append( (x,y,rad, x2,y2) )
+
+    pts1 = np.array( [ [pt[0],pt[1]] for pt in pts ] )
+    pts2 = np.array( [ [pt[-2],pt[-1]] for pt in pts ] )
+    return pts1, pts2
 def gen_warp_params (w, flip=False, rotation_range=[-2,2], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_state=None ):
     if rnd_state is None:
         rnd_state = np.random
@@ -17,22 +147,29 @@ def gen_warp_params (w, flip=False, rotation_range=[-2,2], scale_range=[-0.5, 0.
     ty = rnd_state.uniform( ty_range[0], ty_range[1] )
     p_flip = flip and rnd_state.randint(10) < 4

-    #random warp by grid
+    #random warp V1
     cell_size = [ w // (2**i) for i in range(1,4) ] [ rnd_state.randint(3) ]
     cell_count = w // cell_size + 1

     grid_points = np.linspace( 0, w, cell_count)
     mapx = np.broadcast_to(grid_points, (cell_count, cell_count)).copy()
     mapy = mapx.T

     mapx[1:-1,1:-1] = mapx[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2) )*(cell_size*0.24)
     mapy[1:-1,1:-1] = mapy[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2) )*(cell_size*0.24)

     half_cell_size = cell_size // 2

     mapx = cv2.resize(mapx, (w+cell_size,)*2 )[half_cell_size:-half_cell_size,half_cell_size:-half_cell_size].astype(np.float32)
     mapy = cv2.resize(mapy, (w+cell_size,)*2 )[half_cell_size:-half_cell_size,half_cell_size:-half_cell_size].astype(np.float32)
+
+    ##############
+    # random warp V2
+    # pts1, pts2 = gen_pts(w, w, rnd_state)
+    # gridX = np.arange(w, dtype=np.int16)
+    # gridY = np.arange(w, dtype=np.int16)
+    # vy, vx = np.meshgrid(gridX, gridY)
+    # drigid = mls_rigid_deformation(vy, vx, pts1, pts2)
+    # mapy, mapx = drigid.astype(np.float32)
+    ################

     #random transform
     random_transform_mat = cv2.getRotationMatrix2D((w // 2, w // 2), rotation, scale)
     random_transform_mat[:, 2] += (tx*w, ty*w)

@@ -7,13 +7,20 @@ class DeepFakeArchi(nn.ArchiBase):
     mod     None - default
             'quick'
+    opts    ''
+            ''
+            't'
     """
-    def __init__(self, resolution, mod=None, opts=None):
+    def __init__(self, resolution, use_fp16=False, mod=None, opts=None):
         super().__init__()

         if opts is None:
             opts = ''

+        conv_dtype = tf.float16 if use_fp16 else tf.float32
+
         if mod is None:
             class Downscale(nn.ModelBase):
                 def __init__(self, in_ch, out_ch, kernel_size=5, *kwargs ):
@@ -23,7 +30,7 @@ class DeepFakeArchi(nn.ArchiBase):
                     super().__init__(*kwargs)

                 def on_build(self, *args, **kwargs ):
-                    self.conv1 = nn.Conv2D( self.in_ch, self.out_ch, kernel_size=self.kernel_size, strides=2, padding='SAME')
+                    self.conv1 = nn.Conv2D( self.in_ch, self.out_ch, kernel_size=self.kernel_size, strides=2, padding='SAME', dtype=conv_dtype)

                 def forward(self, x):
                     x = self.conv1(x)
@@ -40,7 +47,7 @@ class DeepFakeArchi(nn.ArchiBase):
                     last_ch = in_ch
                     for i in range(n_downscales):
                         cur_ch = ch*( min(2**i, 8)  )
-                        self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size) )
+                        self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size))
                        last_ch = self.downs[-1].get_out_ch()

                 def forward(self, inp):
@@ -50,8 +57,8 @@ class DeepFakeArchi(nn.ArchiBase):
                     return x

             class Upscale(nn.ModelBase):
-                def on_build(self, in_ch, out_ch, kernel_size=3 ):
-                    self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
+                def on_build(self, in_ch, out_ch, kernel_size=3):
+                    self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)

                 def forward(self, x):
                     x = self.conv1(x)
@@ -60,9 +67,9 @@ class DeepFakeArchi(nn.ArchiBase):
                     return x

             class ResidualBlock(nn.ModelBase):
-                def on_build(self, ch, kernel_size=3 ):
-                    self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
-                    self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
+                def on_build(self, ch, kernel_size=3):
+                    self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
+                    self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)

                 def forward(self, inp):
                     x = self.conv1(inp)
@@ -76,16 +83,44 @@ class DeepFakeArchi(nn.ArchiBase):
                     self.in_ch = in_ch
                     self.e_ch = e_ch
                     super().__init__(**kwargs)

                 def on_build(self):
-                    self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4, kernel_size=5)
+                    if 't' in opts:
+                        self.down1 = Downscale(self.in_ch, self.e_ch, kernel_size=5)
+                        self.res1 = ResidualBlock(self.e_ch)
+                        self.down2 = Downscale(self.e_ch, self.e_ch*2, kernel_size=5)
+                        self.down3 = Downscale(self.e_ch*2, self.e_ch*4, kernel_size=5)
+                        self.down4 = Downscale(self.e_ch*4, self.e_ch*8, kernel_size=5)
+                        self.down5 = Downscale(self.e_ch*8, self.e_ch*8, kernel_size=5)
+                        self.res5 = ResidualBlock(self.e_ch*8)
+                    else:
+                        self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4 if 't' not in opts else 5, kernel_size=5)

-                def forward(self, inp):
-                    return nn.flatten(self.down1(inp))
+                def forward(self, x):
+                    if use_fp16:
+                        x = tf.cast(x, tf.float16)
+
+                    if 't' in opts:
+                        x = self.down1(x)
+                        x = self.res1(x)
+                        x = self.down2(x)
+                        x = self.down3(x)
+                        x = self.down4(x)
+                        x = self.down5(x)
+                        x = self.res5(x)
+                    else:
+                        x = self.down1(x)
+                    x = nn.flatten(x)
+                    if 'u' in opts:
+                        x = nn.pixel_norm(x, axes=-1)
+
+                    if use_fp16:
+                        x = tf.cast(x, tf.float32)
+                    return x

                 def get_out_res(self, res):
-                    return res // (2**4)
+                    return res // ( (2**4) if 't' not in opts else (2**5) )

                 def get_out_ch(self):
                     return self.e_ch * 8
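
The casts in the Encoder's forward are the core of this commit's fp16 support: weights and convolutions run in half precision while the network's inputs and outputs stay float32, so losses and optimizer state are unaffected. A standalone sketch of the pattern, with illustrative names and the TF1-style API leras builds on (not the leras classes themselves):

    import tensorflow.compat.v1 as tf

    use_fp16 = True
    conv_dtype = tf.float16 if use_fp16 else tf.float32

    def fp16_block(x, weight):                       # weight is created with dtype=conv_dtype
        if use_fp16:
            x = tf.cast(x, tf.float16)               # heavy conv math in half precision
        x = tf.nn.conv2d(x, weight, strides=[1,1,1,1], padding='SAME')
        if use_fp16:
            x = tf.cast(x, tf.float32)               # back to fp32 before losses/optimizer
        return x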
@@ -98,58 +133,84 @@ class DeepFakeArchi(nn.ArchiBase):
                 def on_build(self):
                     in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch
-                    if 'u' in opts:
-                        self.dense_norm = nn.DenseNorm()
-
                     self.dense1 = nn.Dense( in_ch, ae_ch )
                     self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
-                    self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
+                    if 't' not in opts:
+                        self.upscale1 = Upscale(ae_out_ch, ae_out_ch)

                 def forward(self, inp):
                     x = inp
-                    if 'u' in opts:
-                        x = self.dense_norm(x)
                     x = self.dense1(x)
                     x = self.dense2(x)
                     x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
-                    x = self.upscale1(x)
+
+                    if use_fp16:
+                        x = tf.cast(x, tf.float16)
+
+                    if 't' not in opts:
+                        x = self.upscale1(x)
                     return x

                 def get_out_res(self):
-                    return lowest_dense_res * 2
+                    return lowest_dense_res * 2 if 't' not in opts else lowest_dense_res

                 def get_out_ch(self):
                     return self.ae_out_ch

             class Decoder(nn.ModelBase):
-                def on_build(self, in_ch, d_ch, d_mask_ch ):
-                    self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
-                    self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
-                    self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
-
-                    self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
-                    self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
-                    self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
-
-                    self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
-
-                    self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
-                    self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
-                    self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
-
-                    if 'd' in opts:
-                        self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME')
-                        self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME')
-                        self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME')
-                        self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
-                        self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME')
-                    else:
-                        self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
-
-                def forward(self, inp):
-                    z = inp
+                def on_build(self, in_ch, d_ch, d_mask_ch):
+                    if 't' not in opts:
+                        self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
+                        self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
+                        self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
+
+                        self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
+                        self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
+                        self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
+
+                        self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
+                        self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
+                        self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
+
+                        self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
+
+                        if 'd' in opts:
+                            self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
+                            self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
+                            self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
+                            self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
+                            self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
+                        else:
+                            self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
+                    else:
+                        self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
+                        self.upscale1 = Upscale(d_ch*8, d_ch*8, kernel_size=3)
+                        self.upscale2 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
+                        self.upscale3 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
+
+                        self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
+                        self.res1 = ResidualBlock(d_ch*8, kernel_size=3)
+                        self.res2 = ResidualBlock(d_ch*4, kernel_size=3)
+                        self.res3 = ResidualBlock(d_ch*2, kernel_size=3)
+
+                        self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
+                        self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*8, kernel_size=3)
+                        self.upscalem2 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
+                        self.upscalem3 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
+
+                        self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
+
+                        if 'd' in opts:
+                            self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
+                            self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
+                            self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
+                            self.upscalem4 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
+                            self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
+                        else:
+                            self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
+
+                def forward(self, z):
                     x = self.upscale0(z)
                     x = self.res0(x)
                     x = self.upscale1(x)
@@ -157,40 +218,15 @@ class DeepFakeArchi(nn.ArchiBase):
                     x = self.upscale2(x)
                     x = self.res2(x)

+                    if 't' in opts:
+                        x = self.upscale3(x)
+                        x = self.res3(x)
+
                     if 'd' in opts:
-                        x0 = tf.nn.sigmoid(self.out_conv(x))
-                        x0 = nn.upsample2d(x0)
-                        x1 = tf.nn.sigmoid(self.out_conv1(x))
-                        x1 = nn.upsample2d(x1)
-                        x2 = tf.nn.sigmoid(self.out_conv2(x))
-                        x2 = nn.upsample2d(x2)
-                        x3 = tf.nn.sigmoid(self.out_conv3(x))
-                        x3 = nn.upsample2d(x3)
-
-                        if nn.data_format == "NHWC":
-                            tile_cfg = ( 1, resolution // 2, resolution //2, 1)
-                        else:
-                            tile_cfg = ( 1, 1, resolution // 2, resolution //2 )
-
-                        z0 = tf.concat ( ( tf.concat ( (  tf.ones ( (1,1,1,1) ), tf.zeros ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ),
-                                           tf.concat ( ( tf.zeros ( (1,1,1,1) ), tf.zeros ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ) ), axis=nn.conv2d_spatial_axes[0] )
-                        z0 = tf.tile ( z0, tile_cfg )
-
-                        z1 = tf.concat ( ( tf.concat ( ( tf.zeros ( (1,1,1,1) ), tf.ones ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ),
-                                           tf.concat ( ( tf.zeros ( (1,1,1,1) ), tf.zeros ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ) ), axis=nn.conv2d_spatial_axes[0] )
-                        z1 = tf.tile ( z1, tile_cfg )
-
-                        z2 = tf.concat ( ( tf.concat ( ( tf.zeros ( (1,1,1,1) ), tf.zeros ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ),
-                                           tf.concat ( (  tf.ones ( (1,1,1,1) ), tf.zeros ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ) ), axis=nn.conv2d_spatial_axes[0] )
-                        z2 = tf.tile ( z2, tile_cfg )
-
-                        z3 = tf.concat ( ( tf.concat ( ( tf.zeros ( (1,1,1,1) ), tf.zeros ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ),
-                                           tf.concat ( ( tf.zeros ( (1,1,1,1) ), tf.ones ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ) ), axis=nn.conv2d_spatial_axes[0] )
-                        z3 = tf.tile ( z3, tile_cfg )
-
-                        x = x0*z0 + x1*z1 + x2*z2 + x3*z3
+                        x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x),
+                                                                         self.out_conv1(x),
+                                                                         self.out_conv2(x),
+                                                                         self.out_conv3(x)), nn.conv2d_ch_axis), 2) )
                     else:
                         x = tf.nn.sigmoid(self.out_conv(x))
@@ -198,12 +234,23 @@ class DeepFakeArchi(nn.ArchiBase):
                     m = self.upscalem0(z)
                     m = self.upscalem1(m)
                     m = self.upscalem2(m)
-                    if 'd' in opts:
-                        m = self.upscalem3(m)
+                    if 't' in opts:
+                        m = self.upscalem3(m)
+                        if 'd' in opts:
+                            m = self.upscalem4(m)
+                    else:
+                        if 'd' in opts:
+                            m = self.upscalem3(m)
                     m = tf.nn.sigmoid(self.out_convm(m))

+                    if use_fp16:
+                        x = tf.cast(x, tf.float32)
+                        m = tf.cast(m, tf.float32)
+
                     return x, m

         self.Encoder = Encoder
         self.Inter = Inter
         self.Decoder = Decoder

@@ -55,8 +55,8 @@ class Conv2D(nn.LayerBase):
         if kernel_initializer is None:
             kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)
-        if kernel_initializer is None:
-            kernel_initializer = nn.initializers.ca()
+        #if kernel_initializer is None:
+        #    kernel_initializer = nn.initializers.ca()
         self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.in_ch,self.out_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )

@@ -38,8 +38,9 @@ class Conv2DTranspose(nn.LayerBase):
         if kernel_initializer is None:
             kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)
-        if kernel_initializer is None:
-            kernel_initializer = nn.initializers.ca()
+        #if kernel_initializer is None:
+        #    kernel_initializer = nn.initializers.ca()
         self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.out_ch,self.in_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )
         if self.use_bias:

@@ -68,8 +68,8 @@ class DepthwiseConv2D(nn.LayerBase):
         if kernel_initializer is None:
             kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)
-        if kernel_initializer is None:
-            kernel_initializer = nn.initializers.ca()
+        #if kernel_initializer is None:
+        #    kernel_initializer = nn.initializers.ca()
         self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.in_ch,self.depth_multiplier), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )

@@ -46,7 +46,9 @@ class Saveable():
                     raise Exception("name must be defined.")

                 name = self.name
-                for w, w_val in zip(weights, nn.tf_sess.run (weights)):
+                for w in weights:
+                    w_val = nn.tf_sess.run (w).copy()
                     w_name_split = w.name.split('/', 1)
                     if name != w_name_split[0]:
                         raise Exception("weight first name != Saveable.name")
@@ -97,10 +99,10 @@ class Saveable():
                     nn.batch_set_value(tuples)
                 except:
                     return False

             return True

         def init_weights(self):
             nn.init_weights(self.get_weights())
 nn.Saveable = Saveable

@@ -212,7 +212,9 @@ def gaussian_blur(input, radius=2.0):
         return np.exp(-(float(x) - float(mu)) ** 2 / (2 * sigma ** 2))

     def make_kernel(sigma):
-        kernel_size = max(3, int(2 * 2 * sigma + 1))
+        kernel_size = max(3, int(2 * 2 * sigma))
+        if kernel_size % 2 == 0:
+            kernel_size += 1
         mean = np.floor(0.5 * kernel_size)
         kernel_1d = np.array([gaussian(x, mean, sigma) for x in range(kernel_size)])
         np_kernel = np.outer(kernel_1d, kernel_1d).astype(np.float32)
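
The rewritten make_kernel forces an odd kernel size; the old formula could produce an even one, which puts the Gaussian peak off-center. A quick illustrative check (not part of the diff):

    for sigma in (0.75, 1.0, 1.25, 2.0):
        old = max(3, int(2 * 2 * sigma + 1))   # e.g. sigma=0.75 -> 4 (even, off-center)
        new = max(3, int(2 * 2 * sigma))
        if new % 2 == 0:
            new += 1                           # e.g. sigma=1.25 -> 5 (odd)
        print(sigma, old, new)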
@@ -237,19 +239,6 @@ def gaussian_blur(input, radius=2.0):
     return x
 nn.gaussian_blur = gaussian_blur

-def get_gaussian_weights(batch_size, in_ch, resolution, num_scale=5, sigma=(0.5, 1., 2., 4., 8.)):
-    w = np.empty((num_scale, batch_size, in_ch, resolution, resolution))
-    for i in range(num_scale):
-        gaussian = np.exp(-1.*np.arange(-(resolution/2-0.5), resolution/2+0.5)**2/(2*sigma[i]**2))
-        gaussian = np.outer(gaussian, gaussian.reshape((resolution, 1)))    # extend to 2D
-        gaussian = gaussian/np.sum(gaussian)                                # normalization
-        gaussian = np.reshape(gaussian, (1, 1, resolution, resolution))     # reshape to 3D
-        gaussian = np.tile(gaussian, (batch_size, in_ch, 1, 1))
-        w[i, :, :, :, :] = gaussian
-    return w
-nn.get_gaussian_weights = get_gaussian_weights
-
 def style_loss(target, style, gaussian_blur_radius=0.0, loss_weight=1.0, step_size=1):
     def sd(content, style, loss_weight):
         content_nc = content.shape[ nn.conv2d_ch_axis ]
@@ -395,7 +384,7 @@ def total_variation_mse(images):
     """
     pixel_dif1 = images[:, 1:, :, :] - images[:, :-1, :, :]
     pixel_dif2 = images[:, :, 1:, :] - images[:, :, :-1, :]

     tot_var = ( tf.reduce_sum(tf.square(pixel_dif1), axis=[1,2,3]) +
                 tf.reduce_sum(tf.square(pixel_dif2), axis=[1,2,3]) )
     return tot_var
@@ -416,3 +405,68 @@ def tf_suppress_lower_mean(t, eps=0.00001):
         q = q * (t/eps)
         return q
     """
+def _get_pixel_value(img, x, y):
+    shape = tf.shape(x)
+    batch_size = shape[0]
+    height = shape[1]
+    width = shape[2]
+
+    batch_idx = tf.range(0, batch_size)
+    batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
+    b = tf.tile(batch_idx, (1, height, width))
+
+    indices = tf.stack([b, y, x], 3)
+
+    return tf.gather_nd(img, indices)
+
+def bilinear_sampler(img, x, y):
+    H = tf.shape(img)[1]
+    W = tf.shape(img)[2]
+    H_MAX = tf.cast(H - 1, tf.int32)
+    W_MAX = tf.cast(W - 1, tf.int32)
+
+    # grab 4 nearest corner points for each (x_i, y_i)
+    x0 = tf.cast(tf.floor(x), tf.int32)
+    x1 = x0 + 1
+    y0 = tf.cast(tf.floor(y), tf.int32)
+    y1 = y0 + 1
+
+    # clip to range [0, H-1/W-1] to not violate img boundaries
+    x0 = tf.clip_by_value(x0, 0, W_MAX)
+    x1 = tf.clip_by_value(x1, 0, W_MAX)
+    y0 = tf.clip_by_value(y0, 0, H_MAX)
+    y1 = tf.clip_by_value(y1, 0, H_MAX)
+
+    # get pixel value at corner coords
+    Ia = _get_pixel_value(img, x0, y0)
+    Ib = _get_pixel_value(img, x0, y1)
+    Ic = _get_pixel_value(img, x1, y0)
+    Id = _get_pixel_value(img, x1, y1)
+
+    # recast as float for delta calculation
+    x0 = tf.cast(x0, tf.float32)
+    x1 = tf.cast(x1, tf.float32)
+    y0 = tf.cast(y0, tf.float32)
+    y1 = tf.cast(y1, tf.float32)
+
+    # calculate deltas
+    wa = (x1-x) * (y1-y)
+    wb = (x1-x) * (y-y0)
+    wc = (x-x0) * (y1-y)
+    wd = (x-x0) * (y-y0)
+
+    # add dimension for addition
+    wa = tf.expand_dims(wa, axis=3)
+    wb = tf.expand_dims(wb, axis=3)
+    wc = tf.expand_dims(wc, axis=3)
+    wd = tf.expand_dims(wd, axis=3)
+
+    # compute output
+    out = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
+
+    return out
+nn.bilinear_sampler = bilinear_sampler
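
A hedged sanity check for the new sampler (TF1-style session assumed): sampling at integer coordinates reproduces the image exactly, except along the bottom/right border, where x1/y1 clip onto x0/y0 and the four weights vanish:

    import numpy as np
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    img_np = np.random.rand(1, 8, 8, 3).astype(np.float32)        # NHWC
    ys, xs = np.meshgrid(np.arange(8, dtype=np.float32),
                         np.arange(8, dtype=np.float32), indexing='ij')
    out = bilinear_sampler(tf.constant(img_np),
                           tf.constant(xs[None]), tf.constant(ys[None]))
    with tf.Session() as sess:
        res = sess.run(out)
    # exact away from the clipped border
    assert np.allclose(res[:, :-1, :-1], img_np[:, :-1, :-1], atol=1e-6)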

@@ -50,11 +50,11 @@ class AdaBelief(nn.OptimizerBase):
         updates = []

         if self.clipnorm > 0.0:
-            norm = tf.sqrt( sum([tf.reduce_sum(tf.square(g)) for g,v in grads_vars]))
+            norm = tf.sqrt( sum([tf.reduce_sum(tf.square(tf.cast(g, tf.float32))) for g,v in grads_vars]))
         updates += [ state_ops.assign_add( self.iterations, 1) ]
         for i, (g,v) in enumerate(grads_vars):
             if self.clipnorm > 0.0:
-                g = self.tf_clip_norm(g, self.clipnorm, norm)
+                g = self.tf_clip_norm(g, self.clipnorm, tf.cast(norm, g.dtype) )

             ms = self.ms_dict[ v.name ]
             vs = self.vs_dict[ v.name ]


@@ -47,11 +47,11 @@ class RMSprop(nn.OptimizerBase):
         updates = []

         if self.clipnorm > 0.0:
-            norm = tf.sqrt( sum([tf.reduce_sum(tf.square(g)) for g,v in grads_vars]))
+            norm = tf.sqrt( sum([tf.reduce_sum(tf.square(tf.cast(g, tf.float32))) for g,v in grads_vars]))
         updates += [ state_ops.assign_add( self.iterations, 1) ]
         for i, (g,v) in enumerate(grads_vars):
             if self.clipnorm > 0.0:
-                g = self.tf_clip_norm(g, self.clipnorm, norm)
+                g = self.tf_clip_norm(g, self.clipnorm, tf.cast(norm, g.dtype) )

             a = self.accumulators_dict[ v.name ]
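
Both optimizers now accumulate the global norm in fp32 (squaring fp16 values above roughly 256 already overflows half precision) and cast back to each gradient's dtype only for the clip itself. The same idea as a standalone sketch, with illustrative names rather than the leras API:

    import tensorflow.compat.v1 as tf

    def clip_grads_fp32_norm(grads_vars, clipnorm):
        # accumulate the squared sum in fp32 so fp16 gradients don't overflow
        norm32 = tf.sqrt(sum(tf.reduce_sum(tf.square(tf.cast(g, tf.float32)))
                             for g, v in grads_vars))
        clipped = []
        for g, v in grads_vars:
            norm = tf.maximum(tf.cast(norm32, g.dtype), tf.cast(1e-8, g.dtype))
            scale = tf.minimum(tf.cast(1.0, g.dtype), tf.cast(clipnorm, g.dtype) / norm)
            clipped.append((g * scale, v))
        return clipped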

main.py
@@ -153,6 +153,16 @@ if __name__ == "__main__":
     p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
     p.set_defaults (func=process_train)

+    def process_exportdfm(arguments):
+        osex.set_process_lowest_prio()
+        from mainscripts import ExportDFM
+        ExportDFM.main(model_class_name = arguments.model_name, saved_models_path = Path(arguments.model_dir))
+
+    p = subparsers.add_parser( "exportdfm", help="Export model to use in DeepFaceLive.")
+    p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Saved models dir.")
+    p.add_argument('--model', required=True, dest="model_name", choices=pathex.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Model class name.")
+    p.set_defaults (func=process_exportdfm)
+
     def process_merge(arguments):
         osex.set_process_lowest_prio()
         from mainscripts import Merger
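
Presumably invoked like the other subcommands (paths illustrative):

    python main.py exportdfm --model-dir workspace/model --model SAEHD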

@@ -166,7 +166,7 @@ class FacesetResizerSubprocessor(Subprocessor):
 def process_folder ( dirpath):
-    image_size = io.input_int(f"New image size", 512, valid_range=[256,2048])
+    image_size = io.input_int(f"New image size", 512, valid_range=[128,2048])
     face_type = io.input_str ("Change face type", 'same', ['h','mf','f','wf','head','same']).lower()
     if face_type == 'same':

@@ -49,6 +49,7 @@ def main (model_class_name=None,
     model = models.import_model(model_class_name)(is_training=False,
                                                   saved_models_path=saved_models_path,
                                                   force_gpu_idxs=force_gpu_idxs,
+                                                  force_model_name=force_model_name,
                                                   cpu_only=cpu_only)
     predictor_func, predictor_input_shape, cfg = model.get_MergerConfig()

@@ -36,7 +36,7 @@ def trainerThread (s2c, c2s, e,
     try:
         start_time = time.time()

-        save_interval_min = 15
+        save_interval_min = 25

         if not training_data_src_path.exists():
             training_data_src_path.mkdir(exist_ok=True, parents=True)

@@ -23,6 +23,7 @@ from samplelib import SampleGeneratorBase
 class ModelBase(object):
     def __init__(self, is_training=False,
+                       is_exporting=False,
                        saved_models_path=None,
                        training_data_src_path=None,
                        training_data_dst_path=None,
@@ -37,6 +38,7 @@ class ModelBase(object):
                        silent_start=False,
                        **kwargs):
         self.is_training = is_training
+        self.is_exporting = is_exporting
         self.saved_models_path = saved_models_path
         self.training_data_src_path = training_data_src_path
         self.training_data_dst_path = training_data_dst_path
@@ -234,7 +236,7 @@ class ModelBase(object):
                     preview_id_counter = 0
                     while not choosed:
                         self.sample_for_preview = self.generate_next_samples()
-                        previews = self.get_static_previews()
+                        previews = self.get_history_previews()

                         io.show_image( wnd_name, ( previews[preview_id_counter % len(previews) ][1] *255).astype(np.uint8) )
@@ -260,7 +262,7 @@ class ModelBase(object):
                 self.sample_for_preview = self.generate_next_samples()

             try:
-                self.get_static_previews()
+                self.get_history_previews()
             except:
                 self.sample_for_preview = self.generate_next_samples()
@@ -357,7 +359,7 @@ class ModelBase(object):
         return ( ('loss_src', 0), ('loss_dst', 0) )

     #overridable
-    def onGetPreview(self, sample):
+    def onGetPreview(self, sample, for_history=False):
         #you can return multiple previews
         #return [ ('preview_name',preview_rgb), ... ]
         return []
@@ -390,6 +392,9 @@ class ModelBase(object):
     def get_static_previews(self):
         return self.onGetPreview (self.sample_for_preview)

+    def get_history_previews(self):
+        return self.onGetPreview (self.sample_for_preview, for_history=True)
+
     def get_preview_history_writer(self):
         if self.preview_history_writer is None:
             self.preview_history_writer = PreviewHistoryWriter()

File diff suppressed because it is too large.
@@ -278,7 +278,7 @@ class QModel(ModelBase):
         return ( ('src_loss', src_loss), ('dst_loss', dst_loss), )

     #override
-    def onGetPreview(self, samples):
+    def onGetPreview(self, samples, for_history=False):
         ( (warped_src, target_src, target_srcm),
           (warped_dst, target_dst, target_dstm) ) = samples

@@ -30,13 +30,12 @@ class SAEHDModel(ModelBase):
         min_res = 64
         max_res = 640

+        default_usefp16           = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False)
         default_resolution        = self.options['resolution'] = self.load_or_def_option('resolution', 128)
         default_face_type         = self.options['face_type'] = self.load_or_def_option('face_type', 'f')
         default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)

-        archi = self.load_or_def_option('archi', 'liae-ud')
-        archi = {'dfuhd':'df-u','liaeuhd':'liae-u'}.get(archi, archi) #backward comp
-        default_archi = self.options['archi'] = archi
+        default_archi             = self.options['archi'] = self.load_or_def_option('archi', 'liae-ud')

         default_ae_dims           = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
         default_e_dims            = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
@@ -46,6 +45,7 @@ class SAEHDModel(ModelBase):
         default_eyes_prio         = self.options['eyes_prio'] = self.load_or_def_option('eyes_prio', False)
         default_mouth_prio        = self.options['mouth_prio'] = self.load_or_def_option('mouth_prio', False)
         default_uniform_yaw       = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False)
+        default_blur_out_mask     = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False)

         default_adabelief         = self.options['adabelief'] = self.load_or_def_option('adabelief', True)
@@ -80,6 +80,7 @@ class SAEHDModel(ModelBase):
             self.ask_random_src_flip()
             self.ask_random_dst_flip()
             self.ask_batch_size(suggest_batch_size)
+            self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters.')

         if self.is_first_run():
             resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16 and 32 for -d archi.")
@@ -112,7 +113,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                 if archi_opts is not None:
                     if len(archi_opts) == 0:
                         continue
-                    if len([ 1 for opt in archi_opts if opt not in ['u','d'] ]) != 0:
+                    if len([ 1 for opt in archi_opts if opt not in ['u','d','t'] ]) != 0:
                         continue

                     if 'd' in archi_opts:
@@ -147,6 +148,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
             self.options['mouth_prio'] = io.input_bool ("Mouth priority", default_mouth_prio, help_message='Helps to fix mouth problems during training by forcing the neural network to train mouth with higher priority similar to eyes ')
             self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')
+            self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. The exact xseg mask in src and dst faceset is required.')

         default_gan_version       = self.options['gan_version'] = self.load_or_def_option('gan_version', 2)
         default_gan_power         = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
@@ -250,11 +252,17 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                 self.set_iter(0)

         adabelief = self.options['adabelief']

+        use_fp16 = False
+        if self.is_exporting:
+            use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')
+
         self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power']
         random_warp = False if self.pretrain else self.options['random_warp']
         random_src_flip = self.random_src_flip if not self.pretrain else True
         random_dst_flip = self.random_dst_flip if not self.pretrain else True
+        blur_out_mask = self.options['blur_out_mask']
+        learn_dst_bg = False #True

         if self.pretrain:
             self.options_show_override['gan_power'] = 0.0
@@ -293,7 +301,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
             self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em')

         # Initializing model classes
-        model_archi = nn.DeepFakeArchi(resolution, opts=archi_opts)
+        model_archi = nn.DeepFakeArchi(resolution, use_fp16=use_fp16, opts=archi_opts)

         with tf.device (models_opt_device):
             if 'df' in archi_type:
@@ -407,6 +415,22 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                         gpu_target_dstm_all = self.target_dstm[batch_slice,:,:,:]
                         gpu_target_dstm_em = self.target_dstm_em[batch_slice,:,:,:]

+                        gpu_target_srcm_anti = 1-gpu_target_srcm_all
+                        gpu_target_dstm_anti = 1-gpu_target_dstm_all
+
+                        if blur_out_mask:
+                            sigma = resolution / 128
+
+                            x = nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, sigma)
+                            y = 1-nn.gaussian_blur(gpu_target_srcm_all, sigma)
+                            y = tf.where(tf.equal(y, 0), tf.ones_like(y), y)
+                            gpu_target_src = gpu_target_src*gpu_target_srcm_all + (x/y)*gpu_target_srcm_anti
+
+                            x = nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, sigma)
+                            y = 1-nn.gaussian_blur(gpu_target_dstm_all, sigma)
+                            y = tf.where(tf.equal(y, 0), tf.ones_like(y), y)
+                            gpu_target_dst = gpu_target_dst*gpu_target_dstm_all + (x/y)*gpu_target_dstm_anti
+
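The blur_out_mask arithmetic: blur the background, then divide by the blurred background coverage so face pixels never bleed into the fill, and paste the untouched face back on top. The same math in NumPy, as an illustrative helper that is not part of the commit (`img` in [0,1], `mask` HxWx1 and binary):

    import numpy as np
    import cv2

    def blur_out_mask_np(img, mask, sigma):
        anti = 1.0 - mask                                   # 1 outside the face
        x = cv2.GaussianBlur(img * anti, (0, 0), sigma)     # blurred background, zeros under the face
        y = 1.0 - cv2.GaussianBlur(mask, (0, 0), sigma)     # how much background each blurred pixel saw
        y = np.where(y == 0, 1.0, y)                        # avoid division by zero deep inside the face
        if y.ndim == 2:
            y = y[..., None]                                # cv2 drops the channel dim for 1-ch input
        return img * mask + (x / y) * anti                  # exact face, smoothed background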
                         # process model tensors
                         if 'df' in archi_type:
                             gpu_src_code = self.inter(self.encoder(gpu_warped_src))
@@ -414,6 +438,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                             gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
                             gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
                             gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
+                            gpu_pred_src_dst_no_code_grad, _ = self.decoder_src(tf.stop_gradient(gpu_dst_code))

                         elif 'liae' in archi_type:
                             gpu_src_code = self.encoder (gpu_warped_src)
@@ -427,7 +452,9 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                             gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
                             gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
+                            gpu_pred_dst_dst_no_code_grad, _ = self.decoder(tf.stop_gradient(gpu_dst_code))
                             gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
+                            gpu_pred_src_dst_no_code_grad, _ = self.decoder(tf.stop_gradient(gpu_src_dst_code))

                         gpu_pred_src_src_list.append(gpu_pred_src_src)
                         gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
@@ -449,25 +476,31 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                         gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
                         gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2
+                        gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur

                         gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )
                         gpu_target_dstm_style_blur = gpu_target_dstm_blur #default style mask is 0.5 on boundary
+                        gpu_target_dstm_style_anti_blur = 1.0 - gpu_target_dstm_style_blur
                         gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2
+                        gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur

                         gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
                         gpu_target_dst_style_masked = gpu_target_dst*gpu_target_dstm_style_blur
-                        gpu_target_dst_style_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_style_blur)
+                        gpu_target_dst_style_anti_masked = gpu_target_dst*gpu_target_dstm_style_anti_blur
+
+                        gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur
+                        gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur
+                        gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur
+                        gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur

-                        gpu_target_src_anti_masked = gpu_target_src*(1.0-gpu_target_srcm_blur)
                         gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
                         gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                         gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
-                        gpu_pred_src_src_anti_masked = gpu_pred_src_src*(1.0-gpu_target_srcm_blur)
                         gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                         gpu_psd_target_dst_style_masked = gpu_pred_src_dst*gpu_target_dstm_style_blur
-                        gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_style_blur)
+                        gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*gpu_target_dstm_style_anti_blur

                         if self.options['loss_function'] == 'MS-SSIM':
                             gpu_src_loss = 10 * nn.MsSsim(bs_per_gpu, input_ch, resolution)(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0)
@@ -512,7 +545,8 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                             face_style_power = self.options['face_style_power'] / 100.0
                             if face_style_power != 0 and not self.pretrain:
-                                gpu_src_loss += nn.style_loss(gpu_psd_target_dst_style_masked, gpu_target_dst_style_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power)
+                                gpu_src_loss += nn.style_loss(gpu_pred_src_dst_no_code_grad*tf.stop_gradient(gpu_pred_src_dstm), tf.stop_gradient(gpu_pred_dst_dst*gpu_pred_dst_dstm), gaussian_blur_radius=resolution//8, loss_weight=10000*face_style_power)
+                                #gpu_src_loss += nn.style_loss(gpu_psd_target_dst_style_masked, gpu_target_dst_style_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power)

                             bg_style_power = self.options['bg_style_power'] / 100.0
                             if bg_style_power != 0 and not self.pretrain:
@@ -532,7 +566,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                             gpu_dst_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1])
                             gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])
-
                         if eyes_prio or mouth_prio:
                             if eyes_prio and mouth_prio:
                                 gpu_target_part_mask = gpu_target_dstm_eye_mouth
@@ -566,6 +599,9 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                         gpu_G_loss = gpu_src_loss + gpu_dst_loss

+                        if learn_dst_bg and masked_training and 'liae' in archi_type:
+                            gpu_G_loss += tf.reduce_mean( tf.square(gpu_pred_dst_dst_no_code_grad*gpu_target_dstm_anti_blur-gpu_target_dst_anti_masked), axis=[1,2,3] )
+
                         def DLoss(labels,logits):
                             return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3])
@@ -750,7 +786,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
             random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None

-            cpu_count = min(multiprocessing.cpu_count(), 8)
+            cpu_count = multiprocessing.cpu_count()
             src_generators_count = cpu_count // 2
             dst_generators_count = cpu_count // 2
             if ct_mode is not None:
@@ -801,16 +837,20 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
         if self.pretrain_just_disabled:
             self.update_sample_for_preview(force_new=True)

-    def dump_ckpt(self):
+    def export_dfm (self):
+        output_path = self.get_strpath_storage_for_file('model.dfm')
+        io.log_info(f'Dumping .dfm to {output_path}')
+
         tf = nn.tf
+        nn.set_data_format('NCHW')

-        with tf.device ('/CPU:0'):
+        with tf.device (nn.tf_default_device_name):
             warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
             warped_dst = tf.transpose(warped_dst, (0,3,1,2))

             if 'df' in self.archi_type:
                 gpu_dst_code = self.inter(self.encoder(warped_dst))
                 gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
@@ -825,20 +865,31 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                 gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
                 _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)

             gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1))
             gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1))
             gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1))

-        saver = tf.train.Saver()
         tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
         tf.identity(gpu_pred_src_dst, name='out_celeb_face')
         tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')

-        saver.save(nn.tf_sess, self.get_strpath_storage_for_file('.ckpt') )
+        output_graph_def = tf.graph_util.convert_variables_to_constants(
+            nn.tf_sess,
+            tf.get_default_graph().as_graph_def(),
+            ['out_face_mask','out_celeb_face','out_celeb_face_mask']
+        )
+
+        import tf2onnx
+        with tf.device("/CPU:0"):
+            model_proto, _ = tf2onnx.convert._convert_common(
+                output_graph_def,
+                name='SAEHD',
+                input_names=['in_face:0'],
+                output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'],
+                opset=9,
+                output_path=output_path)
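
The exported .dfm is a frozen ONNX graph, so it can presumably be smoke-tested outside DeepFaceLive; a hedged sketch with onnxruntime (file name and resolution illustrative, output order per the tf.identity calls above):

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession('SAEHD_model.dfm')
    inp = sess.get_inputs()[0]                       # 'in_face:0' or similar, NHWC
    face = np.zeros((1, 256, 256, 3), np.float32)    # must match the model resolution
    outs = sess.run(None, {inp.name: face})          # face mask, celeb face, celeb face mask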
     #override
     def get_model_filename_list(self):
         return self.model_filename_list
@@ -892,7 +943,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
             self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)

         return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )

     #override
     def onGetPreview(self, samples):
         ( (warped_src, target_src, target_srcm, target_srcm_em),

@@ -25,17 +25,24 @@ class XSegModel(ModelBase):
             self.set_iter(0)

         default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf')
+        default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False)

         if self.is_first_run():
             self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Choose the same as your deepfake model.").lower()

         if self.is_first_run() or ask_override:
             self.ask_batch_size(4, range=[2,16])
+            self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain)
+
+        if not self.is_exporting and (self.options['pretrain'] and self.get_pretraining_data_path() is None):
+            raise Exception("pretraining_data_path is not defined")
+
+        self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)

     #override
     def on_initialize(self):
         device_config = nn.getCurrentDeviceConfig()
-        self.model_data_format = "NCHW" if len(device_config.devices) != 0 and not self.is_debug() else "NHWC"
+        self.model_data_format = "NCHW" if self.is_exporting or (len(device_config.devices) != 0 and not self.is_debug()) else "NHWC"
         nn.initialize(data_format=self.model_data_format)
         tf = nn.tf
@@ -50,7 +57,8 @@ class XSegModel(ModelBase):
                           'f'    : FaceType.FULL,
                           'wf'   : FaceType.WHOLE_FACE,
                           'head' : FaceType.HEAD}[ self.options['face_type'] ]

         place_model_on_cpu = len(devices) == 0
         models_opt_device = '/CPU:0' if place_model_on_cpu else nn.tf_default_device_name
@@ -66,14 +74,17 @@ class XSegModel(ModelBase):
                                        place_model_on_cpu=place_model_on_cpu,
                                        optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3, name='opt'),
                                        data_format=nn.data_format)

+        self.pretrain = self.options['pretrain']
+        if self.pretrain_just_disabled:
+            self.set_iter(0)
+
         if self.is_training:
             # Adjust batch size for multiple GPU
             gpu_count = max(1, len(devices) )
             bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
             self.set_batch_size( gpu_count*bs_per_gpu)

             # Compute losses per GPU
             gpu_pred_list = []
@@ -81,8 +92,6 @@ class XSegModel(ModelBase):
             gpu_loss_gvs = []

             for gpu_id in range(gpu_count):
                 with tf.device(f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
                     with tf.device(f'/CPU:0'):
                         # slice on CPU, otherwise all batch data will be transfered to GPU first
@@ -91,10 +100,18 @@ class XSegModel(ModelBase):
                         gpu_target_t = self.model.target_t [batch_slice,:,:,:]

                     # process model tensors
-                    gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t)
+                    gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t, pretrain=self.pretrain)
                     gpu_pred_list.append(gpu_pred_t)

-                    gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3])
+                    if self.pretrain:
+                        # Structural loss
+                        gpu_loss  = tf.reduce_mean (5*nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
+                        gpu_loss += tf.reduce_mean (5*nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
+                        # Pixel loss
+                        gpu_loss += tf.reduce_mean (10*tf.square(gpu_target_t-gpu_pred_t), axis=[1,2,3])
+                    else:
+                        gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3])

                     gpu_losses += [gpu_loss]
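In pretraining mode the per-pixel mask loss is replaced by an autoencoder-style reconstruction loss: two DSSIM terms at different window sizes plus a scaled MSE term. A minimal standalone sketch, assuming stock TensorFlow's tf.image.ssim in place of DFL's nn.dssim helper (dssim = (1 - ssim) / 2) and NHWC tensors in [0, 1]:

    import tensorflow as tf

    def pretrain_loss(target, pred, resolution=256):
        # Two structural (DSSIM) terms with different SSIM window sizes
        d1 = (1.0 - tf.image.ssim(target, pred, max_val=1.0, filter_size=int(resolution/11.6))) / 2.0
        d2 = (1.0 - tf.image.ssim(target, pred, max_val=1.0, filter_size=int(resolution/23.2))) / 2.0
        loss = 5.0*d1 + 5.0*d2                     # per-image, shape [batch]
        # Scaled pixel (MSE) term
        loss += tf.reduce_mean(10.0*tf.square(target - pred), axis=[1, 2, 3])
        return loss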
@@ -110,9 +127,14 @@ class XSegModel(ModelBase):
             # Initializing training and view functions
-            def train(input_np, target_np):
-                l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np })
-                return l
+            if self.pretrain:
+                def train(input_np, target_np):
+                    l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np})
+                    return l
+            else:
+                def train(input_np, target_np):
+                    l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np })
+                    return l

             self.train = train
            def view(input_np):
@@ -124,30 +146,39 @@ class XSegModel(ModelBase):
             src_dst_generators_count = cpu_count // 2
             src_generators_count = cpu_count // 2
             dst_generators_count = cpu_count // 2

-            srcdst_generator = SampleGeneratorFaceXSeg([self.training_data_src_path, self.training_data_dst_path],
-                                debug=self.is_debug(),
-                                batch_size=self.get_batch_size(),
-                                resolution=resolution,
-                                face_type=self.face_type,
-                                generators_count=src_dst_generators_count,
-                                data_format=nn.data_format)
-
-            src_generator = SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
-                                sample_process_options=SampleProcessor.Options(random_flip=False),
-                                output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
-                                                      ],
-                                generators_count=src_generators_count,
-                                raise_on_no_data=False )
-            dst_generator = SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
-                                sample_process_options=SampleProcessor.Options(random_flip=False),
-                                output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
-                                                      ],
-                                generators_count=dst_generators_count,
-                                raise_on_no_data=False )
-
-            self.set_training_data_generators ([srcdst_generator, src_generator, dst_generator])
+            if self.pretrain:
+                pretrain_gen = SampleGeneratorFace(self.get_pretraining_data_path(), debug=self.is_debug(), batch_size=self.get_batch_size(),
+                                    sample_process_options=SampleProcessor.Options(random_flip=True),
+                                    output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
+                                                            {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
+                                                          ],
+                                    uniform_yaw_distribution=False,
+                                    generators_count=cpu_count )
+                self.set_training_data_generators ([pretrain_gen])
+            else:
+                srcdst_generator = SampleGeneratorFaceXSeg([self.training_data_src_path, self.training_data_dst_path],
+                                    debug=self.is_debug(),
+                                    batch_size=self.get_batch_size(),
+                                    resolution=resolution,
+                                    face_type=self.face_type,
+                                    generators_count=src_dst_generators_count,
+                                    data_format=nn.data_format)
+
+                src_generator = SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
+                                    sample_process_options=SampleProcessor.Options(random_flip=False),
+                                    output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
+                                                          ],
+                                    generators_count=src_generators_count,
+                                    raise_on_no_data=False )
+                dst_generator = SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
+                                    sample_process_options=SampleProcessor.Options(random_flip=False),
+                                    output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
+                                                          ],
+                                    generators_count=dst_generators_count,
+                                    raise_on_no_data=False )
+
+                self.set_training_data_generators ([srcdst_generator, src_generator, dst_generator])
     #override
     def get_model_filename_list(self):

@@ -159,16 +190,21 @@ class XSegModel(ModelBase):
     #override
     def onTrainOneIter(self):
-        image_np, mask_np = self.generate_next_samples()[0]
-        loss = self.train (image_np, mask_np)
+        image_np, target_np = self.generate_next_samples()[0]
+        loss = self.train (image_np, target_np)

         return ( ('loss', np.mean(loss) ), )
     #override
-    def onGetPreview(self, samples):
+    def onGetPreview(self, samples, for_history=False):
         n_samples = min(4, self.get_batch_size(), 800 // self.resolution )

-        srcdst_samples, src_samples, dst_samples = samples
-        image_np, mask_np = srcdst_samples
+        if self.pretrain:
+            srcdst_samples, = samples
+            image_np, mask_np = srcdst_samples
+        else:
+            srcdst_samples, src_samples, dst_samples = samples
+            image_np, mask_np = srcdst_samples

         I, M, IM, = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([image_np,mask_np] + self.view (image_np) ) ]
         M, IM, = [ np.repeat (x, (3,), -1) for x in [M, IM] ]
@@ -178,11 +214,14 @@ class XSegModel(ModelBase):
         result = []
         st = []
         for i in range(n_samples):
-            ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i]+0.5*I[i]*(1-IM[i]) + 0.5*green_bg*(1-IM[i])
+            if self.pretrain:
+                ar = I[i], IM[i]
+            else:
+                ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i]+0.5*I[i]*(1-IM[i]) + 0.5*green_bg*(1-IM[i])
             st.append ( np.concatenate ( ar, axis=1) )
         result += [ ('XSeg training faces', np.concatenate (st, axis=0 )), ]

-        if len(src_samples) != 0:
+        if not self.pretrain and len(src_samples) != 0:
             src_np, = src_samples

@@ -196,7 +235,7 @@ class XSegModel(ModelBase):
             result += [ ('XSeg src faces', np.concatenate (st, axis=0 )), ]

-        if len(dst_samples) != 0:
+        if not self.pretrain and len(dst_samples) != 0:
             dst_np, = dst_samples
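The preview composition in the hunk above (I*M + 0.5*I*(1-M) + 0.5*green_bg*(1-M)) keeps the face at full brightness inside the mask and blends a half-intensity face with a half-green background outside it. A standalone NumPy sketch with hypothetical shapes:

    import numpy as np

    I = np.random.rand(128, 128, 3).astype(np.float32)            # face image, values in 0..1
    M = (np.random.rand(128, 128, 1) > 0.5).astype(np.float32)    # binary mask
    green_bg = np.tile(np.float32([0, 1, 0]), (128, 128, 1))      # solid green, same shape as I
    overlay = I*M + 0.5*I*(1 - M) + 0.5*green_bg*(1 - M)          # dimmed face + green outside mask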
@@ -211,5 +250,34 @@ class XSegModel(ModelBase):
             result += [ ('XSeg dst faces', np.concatenate (st, axis=0 )), ]

         return result

+    def export_dfm (self):
+        output_path = self.get_strpath_storage_for_file(f'model.onnx')
+        io.log_info(f'Dumping .onnx to {output_path}')
+
+        tf = nn.tf
+        with tf.device (nn.tf_default_device_name):
+            input_t = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
+            input_t = tf.transpose(input_t, (0,3,1,2))
+            _, pred_t = self.model.flow(input_t)
+            pred_t = tf.transpose(pred_t, (0,2,3,1))
+
+        tf.identity(pred_t, name='out_mask')
+
+        output_graph_def = tf.graph_util.convert_variables_to_constants(
+            nn.tf_sess,
+            tf.get_default_graph().as_graph_def(),
+            ['out_mask']
+        )
+
+        import tf2onnx
+        with tf.device("/CPU:0"):
+            model_proto, _ = tf2onnx.convert._convert_common(
+                output_graph_def,
+                name='XSeg',
+                input_names=['in_face:0'],
+                output_names=['out_mask:0'],
+                opset=13,
+                output_path=output_path)

 Model = XSegModel
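A quick way to sanity-check the exported graph is to run it through onnxruntime. A minimal sketch, assuming a 256px model and the tensor names used by export_dfm above ('in_face:0' / 'out_mask:0'):

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])
    face = np.random.rand(1, 256, 256, 3).astype(np.float32)    # NHWC input in 0..1
    mask, = sess.run(['out_mask:0'], {'in_face:0': face})
    print(mask.shape)                                           # expected (1, 256, 256, 1) for a 256px model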


@@ -1,6 +1,6 @@
 tqdm
 numpy==1.19.3
-h5py==2.9.0
+h5py==2.10.0
 opencv-python==4.1.0.25
 ffmpeg-python==0.1.17
 scikit-image==0.14.2


@@ -96,13 +96,14 @@ class SampleProcessor(object):
                 resolution = opts.get('resolution', None)
                 if resolution is None:
                     continue
-                params_per_resolution[resolution] = imagelib.gen_warp_params(resolution,
-                                                                             sample_process_options.random_flip,
-                                                                             rotation_range=sample_process_options.rotation_range,
-                                                                             scale_range=sample_process_options.scale_range,
-                                                                             tx_range=sample_process_options.tx_range,
-                                                                             ty_range=sample_process_options.ty_range,
-                                                                             rnd_state=warp_rnd_state)
+                if resolution not in params_per_resolution:
+                    params_per_resolution[resolution] = imagelib.gen_warp_params(resolution,
+                                                                                 sample_process_options.random_flip,
+                                                                                 rotation_range=sample_process_options.rotation_range,
+                                                                                 scale_range=sample_process_options.scale_range,
+                                                                                 tx_range=sample_process_options.tx_range,
+                                                                                 ty_range=sample_process_options.ty_range,
+                                                                                 rnd_state=warp_rnd_state)
             outputs_sample = []
             for opts in output_sample_types:
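The new guard generates warp parameters only once per resolution, so every output type that shares a resolution is warped consistently within a sample. In isolation the pattern is plain dict memoization (make_params stands in for imagelib.gen_warp_params):

    params_per_resolution = {}

    def get_params(resolution, make_params):
        if resolution not in params_per_resolution:        # first request for this resolution
            params_per_resolution[resolution] = make_params(resolution)
        return params_per_resolution[resolution]           # later requests reuse the same params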
@@ -118,6 +119,7 @@ class SampleProcessor(object):
                 random_jpeg = opts.get('random_jpeg', False)
                 motion_blur = opts.get('motion_blur', None)
                 gaussian_blur = opts.get('gaussian_blur', None)
+                denoise_filter = opts.get('denoise_filter', False)
                 random_bilinear_resize = opts.get('random_bilinear_resize', None)
                 random_rgb_levels = opts.get('random_rgb_levels', False)
                 random_hsv_shift = opts.get('random_hsv_shift', False)
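Like the surrounding options, the new flag is requested per output type. A hypothetical entry (every key other than 'denoise_filter' already existed before this commit):

    output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                             'channel_type': SampleProcessor.ChannelType.BGR,
                             'resolution': 256,
                             'denoise_filter': True} ]   # enable the new bilateral denoise pass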
@@ -166,6 +168,7 @@ class SampleProcessor(object):
                     img = np.zeros ( sample_bgr.shape[0:2]+(1,), dtype=np.float32)

                 if sample_face_type == FaceType.MARK_ONLY:
+                    raise NotImplementedError()
                     mat = LandmarksProcessor.get_transform_mat (sample_landmarks, warp_resolution, face_type)
                     img = cv2.warpAffine( img, mat, (warp_resolution, warp_resolution), flags=cv2.INTER_LINEAR )
@@ -286,7 +289,9 @@ class SampleProcessor(object):
                         random_mask = sd.random_circle_faded ([resolution,resolution], rnd_state=np.random.RandomState (sample_rnd_seed+4)) if random_circle_mask else None
                         img = imagelib.apply_random_bilinear_resize(img, *random_bilinear_resize, mask=random_mask,rnd_state=np.random.RandomState (sample_rnd_seed+4) )

+                    if denoise_filter:
+                        d_size = ( (max(*img.shape[:2]) // 128) + 1 )*2 +1
+                        img = cv2.bilateralFilter( np.clip(img*255, 0,255).astype(np.uint8), d_size, 80, 80).astype(np.float32) / 255.0

                     # Transform from BGR to desired channel_type
                     if channel_type == SPCT.BGR:
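The bilateral kernel diameter scales with the sample size: for a 256px image, d_size = ((256 // 128) + 1)*2 + 1 = 7, and for 512px it grows to 11. Since cv2.bilateralFilter is applied to an 8-bit image here, the float sample is round-tripped through uint8. A standalone sketch with a hypothetical input:

    import cv2
    import numpy as np

    img = np.random.rand(256, 256, 3).astype(np.float32)       # float BGR sample in 0..1
    d_size = ((max(*img.shape[:2]) // 128) + 1) * 2 + 1        # -> 7 for 256px
    img8 = np.clip(img * 255, 0, 255).astype(np.uint8)         # to 8-bit for OpenCV
    denoised = cv2.bilateralFilter(img8, d_size, 80, 80).astype(np.float32) / 255.0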