Upgraded to TF version 1.13.2

Removed the wait at first launch for most graphics cards.

Increased speed of training by 10-20%, but you have to retrain all models from scratch.

SAEHD:

added option 'use float16'
	Experimental option. Reduces the model size by half.
	Increases the speed of training.
	Decreases the accuracy of the model.
	The model may collapse or fail to train.
	The model may not learn the mask at large resolutions.
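
A minimal sketch of what such a toggle amounts to, assuming a hypothetical use_float16 flag wired to the nn.tf_floatx / nn.np_floatx globals introduced in this commit:

    import tensorflow as tf

    def pick_floatx(use_float16):
        # All layers read this dtype when building their weights,
        # so float16 halves the stored model size.
        tf_floatx = tf.float16 if use_float16 else tf.float32
        np_floatx = tf_floatx.as_numpy_dtype
        return tf_floatx, np_floatx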

The true_face_training option is replaced by
"True face power" (0.0000 .. 1.0).
Experimental option. Discriminates the result face to be more like the src face. The higher the value, the stronger the discrimination.
Comparison - https://i.imgur.com/czScS9q.png
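
A hypothetical sketch of how a 0..1 power value can scale a discriminator term in the generator loss (names are illustrative, not the repository's actual code):

    import tensorflow as tf

    def generator_loss(recon_loss, d_score_on_result, true_face_power):
        # LSGAN-style term: reward the generator when the discriminator
        # scores the result face as "src"; the power scales its influence.
        gan_term = tf.reduce_mean(tf.squared_difference(d_score_on_result, 1.0))
        return recon_loss + true_face_power * gan_term
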
Author: Colombo
Date: 2020-01-25 21:58:19 +04:00
Commit: 76ca79216e (parent: a3dfcb91b9)
49 changed files with 1320 additions and 1297 deletions


@@ -4,7 +4,7 @@ from .DFLJPG import DFLJPG
 from .DFLPNG import DFLPNG

 class DFLIMG():
     @staticmethod
     def load(filepath, loader_func=None):
         if filepath.suffix == '.png':


@@ -197,7 +197,7 @@ class DFLJPG(object):
         else:
             io.log_err("Unable to encode fanseg_mask for %s" % (filename) )
             fanseg_mask = None

         if ie_polys is not None:
             if not isinstance(ie_polys, list):
                 ie_polys = ie_polys.dump()


@@ -287,7 +287,7 @@ class DFLPNG(object):
             f.write ( inst.dump() )
         except:
             raise Exception( 'cannot save %s' % (filename) )

     @staticmethod
     def embed_data(filename, face_type=None,
                          landmarks=None,

@@ -312,11 +312,11 @@ class DFLPNG(object):
         else:
             io.log_err("Unable to encode fanseg_mask for %s" % (filename) )
             fanseg_mask = None

         if ie_polys is not None:
             if not isinstance(ie_polys, list):
                 ie_polys = ie_polys.dump()

         DFLPNG.embed_dfldict (filename, {'face_type': face_type,
                                          'landmarks': landmarks,
                                          'ie_polys' : ie_polys,

@@ -351,7 +351,7 @@ class DFLPNG(object):
         if fanseg_mask is None: fanseg_mask = self.get_fanseg_mask()
         if eyebrows_expand_mod is None: eyebrows_expand_mod = self.get_eyebrows_expand_mod()
         if relighted is None: relighted = self.get_relighted()

         DFLPNG.embed_data (filename, face_type=face_type,
                                      landmarks=landmarks,
                                      ie_polys=ie_polys,

@@ -368,7 +368,7 @@ class DFLPNG(object):
     def remove_fanseg_mask(self):
         self.dfl_dict['fanseg_mask'] = None

     def remove_source_filename(self):
         self.dfl_dict['source_filename'] = None


@@ -54,7 +54,7 @@ class IEPolys:
         self.n = max(0, self.n-1)
         self.dirty = True
         return self.n

     def n_inc(self):
         self.n = min(len(self.list), self.n+1)
         self.dirty = True


@@ -9,7 +9,7 @@ from scipy.sparse.linalg import spsolve
 def color_transfer_sot(src,trg, steps=10, batch_size=5, reg_sigmaXY=16.0, reg_sigmaV=5.0):
     """
     Color Transform via Sliced Optimal Transfer
     ported by @iperov from https://github.com/dcoeurjo/OTColorTransfer

     src         - any float range any channel image
     dst         - any float range any channel image, same shape as src

@@ -17,7 +17,7 @@ def color_transfer_sot(src,trg, steps=10, batch_size=5, reg_sigmaXY=16.0, reg_si
     batch_size  - solver batch size
     reg_sigmaXY - apply regularization and sigmaXY of filter, otherwise set to 0.0
     reg_sigmaV  - sigmaV of filter

     return value - clip it manually
     """
     if not np.issubdtype(src.dtype, np.floating):
@@ -27,11 +27,11 @@ def color_transfer_sot(src,trg, steps=10, batch_size=5, reg_sigmaXY=16.0, reg_si
     if len(src.shape) != 3:
         raise ValueError("src shape must have rank 3 (h,w,c)")

     if src.shape != trg.shape:
         raise ValueError("src and trg shapes must be equal")

     src_dtype = src.dtype
     h,w,c = src.shape
     new_src = src.copy()
@@ -59,63 +59,63 @@ def color_transfer_sot(src,trg, steps=10, batch_size=5, reg_sigmaXY=16.0, reg_si
                 src_diff_filt = src_diff_filt[...,None]
             new_src = src + src_diff_filt

     return new_src

 def color_transfer_mkl(x0, x1):
     eps = np.finfo(float).eps

     h,w,c = x0.shape
     h1,w1,c1 = x1.shape

     x0 = x0.reshape ( (h*w,c) )
     x1 = x1.reshape ( (h1*w1,c1) )

     a = np.cov(x0.T)
     b = np.cov(x1.T)

     Da2, Ua = np.linalg.eig(a)
     Da = np.diag(np.sqrt(Da2.clip(eps, None)))

     C = np.dot(np.dot(np.dot(np.dot(Da, Ua.T), b), Ua), Da)

     Dc2, Uc = np.linalg.eig(C)
     Dc = np.diag(np.sqrt(Dc2.clip(eps, None)))

     Da_inv = np.diag(1./(np.diag(Da)))

     t = np.dot(np.dot(np.dot(np.dot(np.dot(np.dot(Ua, Da_inv), Uc), Dc), Uc.T), Da_inv), Ua.T)

     mx0 = np.mean(x0, axis=0)
     mx1 = np.mean(x1, axis=0)

     result = np.dot(x0-mx0, t) + mx1
     return np.clip ( result.reshape ( (h,w,c) ).astype(x0.dtype), 0, 1)
 def color_transfer_idt(i0, i1, bins=256, n_rot=20):
     relaxation = 1 / n_rot
     h,w,c = i0.shape
     h1,w1,c1 = i1.shape

     i0 = i0.reshape ( (h*w,c) )
     i1 = i1.reshape ( (h1*w1,c1) )

     n_dims = c

     d0 = i0.T
     d1 = i1.T

     for i in range(n_rot):
         r = sp.stats.special_ortho_group.rvs(n_dims).astype(np.float32)

         d0r = np.dot(r, d0)
         d1r = np.dot(r, d1)
         d_r = np.empty_like(d0)

         for j in range(n_dims):
             lo = min(d0r[j].min(), d1r[j].min())
             hi = max(d0r[j].max(), d1r[j].max())

             p0r, edges = np.histogram(d0r[j], bins=bins, range=[lo, hi])
             p1r, _     = np.histogram(d1r[j], bins=bins, range=[lo, hi])

@@ -124,11 +124,11 @@ def color_transfer_idt(i0, i1, bins=256, n_rot=20):
             cp1r = p1r.cumsum().astype(np.float32)
             cp1r /= cp1r[-1]

             f = np.interp(cp0r, cp1r, edges[1:])

             d_r[j] = np.interp(d0r[j], edges[1:], f, left=0, right=bins)

         d0 = relaxation * np.linalg.solve(r, (d_r - d0r)) + d0

     return np.clip ( d0.T.reshape ( (h,w,c) ).astype(i0.dtype) , 0, 1)
@@ -137,16 +137,16 @@ def laplacian_matrix(n, m):
     mat_D = scipy.sparse.lil_matrix((m, m))
     mat_D.setdiag(-1, -1)
     mat_D.setdiag(4)
     mat_D.setdiag(-1, 1)

     mat_A = scipy.sparse.block_diag([mat_D] * n).tolil()

     mat_A.setdiag(-1, 1*m)
     mat_A.setdiag(-1, -1*m)

     return mat_A

 def seamless_clone(source, target, mask):
     h, w,c = target.shape
     result = []

     mat_A = laplacian_matrix(h, w)
     laplacian = mat_A.tocsc()

@@ -155,7 +155,7 @@ def seamless_clone(source, target, mask):
     mask[:,0] = 1
     mask[:,-1] = 1
     q = np.argwhere(mask==0)

     k = q[:,1]+q[:,0]*w
     mat_A[k, k] = 1
     mat_A[k, k + 1] = 0

@@ -163,22 +163,22 @@ def seamless_clone(source, target, mask):
     mat_A[k, k + w] = 0
     mat_A[k, k - w] = 0

     mat_A = mat_A.tocsc()
     mask_flat = mask.flatten()
     for channel in range(c):

         source_flat = source[:, :, channel].flatten()
         target_flat = target[:, :, channel].flatten()

         mat_b = laplacian.dot(source_flat)*0.75
         mat_b[mask_flat==0] = target_flat[mask_flat==0]

         x = spsolve(mat_A, mat_b).reshape((h, w))
         result.append (x)

     return np.clip( np.dstack(result), 0, 1 )
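
seamless_clone solves one sparse linear system per channel (mat_A x = mat_b), with mask==0 pixels pinned to the target values as boundary conditions. A minimal usage sketch, assuming float BGR images in [0,1]:

    import numpy as np

    h, w = 128, 128
    source = np.random.rand(h, w, 3).astype(np.float32)
    target = np.random.rand(h, w, 3).astype(np.float32)
    mask = np.zeros((h, w), np.float32)
    mask[32:96, 32:96] = 1                          # region to clone from source

    blended = seamless_clone(source, target, mask)  # float result in [0,1]
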
 def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, source_mask=None, target_mask=None):
     """
     Transfers the color distribution from the source to the target
@@ -368,26 +368,26 @@ def color_hist_match(src_im, tar_im, hist_match_threshold=255):
 def color_transfer_mix(img_src,img_trg):
     img_src = (img_src*255.0).astype(np.uint8)
     img_trg = (img_trg*255.0).astype(np.uint8)

     img_src_lab = cv2.cvtColor(img_src, cv2.COLOR_BGR2LAB)
     img_trg_lab = cv2.cvtColor(img_trg, cv2.COLOR_BGR2LAB)

     rct_light = np.clip ( linear_color_transfer(img_src_lab[...,0:1].astype(np.float32)/255.0,
                                                 img_trg_lab[...,0:1].astype(np.float32)/255.0 )[...,0]*255.0,
                           0, 255).astype(np.uint8)

     img_src_lab[...,0] = (np.ones_like (rct_light)*100).astype(np.uint8)
     img_src_lab = cv2.cvtColor(img_src_lab, cv2.COLOR_LAB2BGR)

     img_trg_lab[...,0] = (np.ones_like (rct_light)*100).astype(np.uint8)
     img_trg_lab = cv2.cvtColor(img_trg_lab, cv2.COLOR_LAB2BGR)

     img_rct = color_transfer_sot( img_src_lab.astype(np.float32), img_trg_lab.astype(np.float32) )
     img_rct = np.clip(img_rct, 0, 255).astype(np.uint8)

     img_rct = cv2.cvtColor(img_rct, cv2.COLOR_BGR2LAB)
     img_rct[...,0] = rct_light
     img_rct = cv2.cvtColor(img_rct, cv2.COLOR_LAB2BGR)

     return (img_rct / 255.0).astype(np.float32)
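
The transfer functions above operate on float BGR images; a minimal usage sketch (file names are illustrative):

    import cv2
    import numpy as np

    src = cv2.imread("src.jpg").astype(np.float32) / 255.0   # image to recolor
    trg = cv2.imread("trg.jpg").astype(np.float32) / 255.0   # color reference

    out = color_transfer_sot(src, trg)   # requires src.shape == trg.shape
    out = np.clip(out, 0, 1)             # per its docstring: clip manually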


@@ -13,24 +13,24 @@ def normalize_channels(img, target_channels):
     if c == 0 and target_channels > 0:
         img = img[...,np.newaxis]
         c = 1

     if c == 1 and target_channels > 1:
         img = np.repeat (img, target_channels, -1)
         c = target_channels

     if c > target_channels:
         img = img[...,0:target_channels]
         c = target_channels

     return img

 def cut_odd_image(img):
     h, w, c = img.shape
     wm, hm = w % 2, h % 2
     if wm + hm != 0:
         img = img[0:h-hm,0:w-wm,:]
     return img

 def overlay_alpha_image(img_target, img_source, xy_offset=(0,0) ):
     (h,w,c) = img_source.shape
     if c != 4:


@@ -16,7 +16,7 @@ def _get_pil_font (font, size):
 def get_text_image( shape, text, color=(1,1,1), border=0.2, font=None):
     h,w,c = shape
     try:
         pil_font = _get_pil_font( localization.get_default_ttf_font_name() , h-2)

         canvas = Image.new('RGB', (w,h) , (0,0,0) )

@@ -25,7 +25,7 @@ def get_text_image( shape, text, color=(1,1,1), border=0.2, font=None):
         draw.text(offset, text, font=pil_font, fill=tuple((np.array(color)*255).astype(np.int)) )

         result = np.asarray(canvas) / 255

         if c > 3:
             result = np.concatenate ( (result, np.ones ((h,w,c-3)) ), axis=-1 )
         elif c < 3:


@@ -6,7 +6,7 @@ def gen_warp_params (source, flip, rotation_range=[-10,10], scale_range=[-0.5, 0
     h,w,c = source.shape
     if (h != w):
         raise ValueError ('gen_warp_params accepts only square images.')

     if rnd_seed != None:
         rnd_state = np.random.RandomState (rnd_seed)
     else:

@@ -15,9 +15,9 @@ def gen_warp_params (source, flip, rotation_range=[-10,10], scale_range=[-0.5, 0
     rotation = rnd_state.uniform( rotation_range[0], rotation_range[1] )
     scale = rnd_state.uniform(1 +scale_range[0], 1 +scale_range[1])
     tx = rnd_state.uniform( tx_range[0], tx_range[1] )
     ty = rnd_state.uniform( ty_range[0], ty_range[1] )
     p_flip = flip and rnd_state.randint(10) < 4

     #random warp by grid
     cell_size = [ w // (2**i) for i in range(1,4) ] [ rnd_state.randint(3) ]
     cell_count = w // cell_size + 1
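
For example, with a 256 px image the random-warp cell size is drawn uniformly from [128, 64, 32]:

    import numpy as np

    w = 256
    cell_size = [ w // (2**i) for i in range(1,4) ][np.random.randint(3)]  # 128, 64 or 32
    cell_count = w // cell_size + 1                                        # 3, 5 or 9 grid nodes per axis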


@@ -189,29 +189,29 @@ class InteractBase(object):
         ar = self.key_events.get(wnd_name, [])
         self.key_events[wnd_name] = []
         return ar

     def input(self, s):
         return input(s)

     def input_number(self, s, default_value, valid_list=None, show_default_value=True, add_info=None, help_message=None):
         if show_default_value and default_value is not None:
             s = f"[{default_value}] {s}"

         if add_info is not None or \
            help_message is not None:
             s += " ("

         if add_info is not None:
             s += f" {add_info}"

         if help_message is not None:
             s += " ?:help"

         if add_info is not None or \
            help_message is not None:
             s += " )"

         s += " : "

         while True:
             try:
                 inp = input(s)
@@ -232,32 +232,32 @@ class InteractBase(object):
             except:
                 result = default_value
                 break

         print(result)
         return result

     def input_int(self, s, default_value, valid_list=None, add_info=None, show_default_value=True, help_message=None):
         if show_default_value:
             if len(s) != 0:
                 s = f"[{default_value}] {s}"
             else:
                 s = f"[{default_value}]"

         if add_info is not None or \
            help_message is not None:
             s += " ("

         if add_info is not None:
             s += f" {add_info}"

         if help_message is not None:
             s += " ?:help"

         if add_info is not None or \
            help_message is not None:
             s += " )"

         s += " : "

         while True:
             try:
                 inp = input(s)
@@ -280,13 +280,13 @@ class InteractBase(object):
         print (result)
         return result

     def input_bool(self, s, default_value, help_message=None):
         s = f"[{yn_str[default_value]}] {s} ( y/n"

         if help_message is not None:
             s += " ?:help"
         s += " ) : "

         while True:
             try:
                 inp = input(s)
@@ -305,46 +305,46 @@ class InteractBase(object):
     def input_str(self, s, default_value=None, valid_list=None, show_default_value=True, help_message=None):
         if show_default_value and default_value is not None:
             s = f"[{default_value}] {s}"

         if valid_list is not None or \
            help_message is not None:
             s += " ("

         if valid_list is not None:
             s += " " + "/".join(valid_list)

         if help_message is not None:
             s += " ?:help"

         if valid_list is not None or \
            help_message is not None:
             s += " )"

         s += " : "

         while True:
             try:
                 inp = input(s)

                 if len(inp) == 0:
                     if default_value is None:
                         print("")
                         return None
                     result = default_value
                     break

                 if help_message is not None and inp == '?':
                     print(help_message)
                     continue

                 if valid_list is not None:
                     if inp.lower() in valid_list:
                         result = inp.lower()
                         break
                     if inp in valid_list:
                         result = inp
                         break
                     continue

                 result = inp

@@ -352,10 +352,10 @@ class InteractBase(object):
             except:
                 result = default_value
                 break

         print(result)
         return result

     def input_process(self, stdin_fd, sq, str):
         sys.stdin = os.fdopen(stdin_fd)
         try:
@@ -389,8 +389,8 @@ class InteractBase(object):
                 sys.stdin.read()
             except:
                 pass

     def input_skip_pending(self):
         if is_colab:
             # currently it does not work on Colab
             return

@@ -401,7 +401,7 @@ class InteractBase(object):
         p.daemon = True
         p.start()
         time.sleep(0.5)
         p.terminate()
         sys.stdin = os.fdopen( sys.stdin.fileno() )

@@ -409,11 +409,11 @@ class InteractDesktop(InteractBase):
     def __init__(self):
         colorama.init()
         super().__init__()

     def color_red(self):
         pass

     def is_support_windows(self):
         return True

@@ -469,7 +469,7 @@ class InteractDesktop(InteractBase):
                     shift_pressed = False
                     if ord_key != -1:
                         chr_key = chr(ord_key)

                         if chr_key >= 'A' and chr_key <= 'Z':
                             shift_pressed = True
                             ord_key += 32


@@ -12,7 +12,7 @@ class SubprocessGenerator(object):
         self.p = None
         if start_now:
             self._start()

     def _start(self):
         if self.p == None:
             user_param = self.user_param


@@ -16,7 +16,7 @@ class Subprocessor(object):
             c2s = multiprocessing.Queue()
             self.p = multiprocessing.Process(target=self._subprocess_run, args=(client_dict,s2c,c2s) )
             self.s2c = s2c
             self.c2s = c2s
             self.p.daemon = True
             self.p.start()

@@ -88,13 +88,13 @@ class Subprocessor(object):
                 print ('Exception: %s' % (traceback.format_exc()) )
                 c2s.put ( {'op': 'error', 'data' : data} )

         # disable pickling
         def __getstate__(self):
             return dict()
         def __setstate__(self, d):
             self.__dict__.update(d)

     #overridable
     def __init__(self, name, SubprocessorCli_class, no_response_time_sec = 0, io_loop_sleep_time=0.005, initialize_subprocesses_in_serial=False):
         if not issubclass(SubprocessorCli_class, Subprocessor.Cli):
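
The "disable pickling" pair above is the standard trick for objects captured when multiprocessing spawns workers: __getstate__ returns an empty dict, so none of the instance state crosses the process boundary. A standalone illustration:

    import pickle

    class Holder:
        def __init__(self):
            self.big = bytearray(10**6)   # expensive state
        def __getstate__(self):
            return dict()                 # serialize nothing
        def __setstate__(self, d):
            self.__dict__.update(d)

    restored = pickle.loads(pickle.dumps(Holder()))
    assert not hasattr(restored, "big")   # heavy state was dropped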


@@ -1,7 +1,7 @@
 import sys
 import ctypes
 import os

 class Device(object):
     def __init__(self, index, name, total_mem, free_mem, cc=0):
         self.index = index

@@ -11,25 +11,25 @@ class Device(object):
         self.total_mem_gb = total_mem / 1024**3
         self.free_mem = free_mem
         self.free_mem_gb = free_mem / 1024**3

     def __str__(self):
         return f"[{self.index}]:[{self.name}][{self.free_mem_gb:.3}/{self.total_mem_gb :.3}]"

 class Devices(object):
     all_devices = None

     def __init__(self, devices):
         self.devices = devices

     def __len__(self):
         return len(self.devices)

     def __getitem__(self, key):
         result = self.devices[key]
         if isinstance(key, slice):
             return Devices(result)
         return result

     def __iter__(self):
         for device in self.devices:
             yield device
@@ -59,14 +59,14 @@ class Devices(object):
             if device.index == idx:
                 return device
         return None

     def get_devices_from_index_list(self, idx_list):
         result = []
         for device in self.devices:
             if device.index in idx_list:
                 result += [device]
         return Devices(result)

     def get_equal_devices(self, device):
         device_name = device.name
         result = []

@@ -74,7 +74,7 @@ class Devices(object):
             if device.name == device_name:
                 result.append (device)
         return Devices(result)

     def get_devices_at_least_mem(self, totalmemsize_gb):
         result = []
         for device in self.devices:

@@ -84,7 +84,7 @@ class Devices(object):
     @staticmethod
     def initialize_main_env():
         min_cc = int(os.environ.get("TF_MIN_REQ_CAP", 35))

         libnames = ('libcuda.so', 'libcuda.dylib', 'nvcuda.dll')
         for libname in libnames:
             try:
@@ -122,40 +122,40 @@ class Devices(object):
             if cuda.cuMemGetInfo_v2(ctypes.byref(freeMem), ctypes.byref(totalMem)) == 0:
                 cc = cc_major.value * 10 + cc_minor.value
                 if cc >= min_cc:
                     devices.append ( {'name'      : name.split(b'\0', 1)[0].decode(),
                                       'total_mem' : totalMem.value,
                                       'free_mem'  : freeMem.value,
                                       'cc'        : cc
                                      })
             cuda.cuCtxDetach(context)

         os.environ['NN_DEVICES_INITIALIZED'] = '1'
         os.environ['NN_DEVICES_COUNT'] = str(len(devices))
         for i, device in enumerate(devices):
             os.environ[f'NN_DEVICE_{i}_NAME'] = device['name']
             os.environ[f'NN_DEVICE_{i}_TOTAL_MEM'] = str(device['total_mem'])
             os.environ[f'NN_DEVICE_{i}_FREE_MEM'] = str(device['free_mem'])
             os.environ[f'NN_DEVICE_{i}_CC'] = str(device['cc'])

     @staticmethod
     def getDevices():
         if Devices.all_devices is None:
             if int(os.environ.get("NN_DEVICES_INITIALIZED", 0)) != 1:
                 raise Exception("nn devices are not initialized. Run initialize_main_env() in main process.")
             devices = []
             for i in range ( int(os.environ['NN_DEVICES_COUNT']) ):
                 devices.append ( Device(index=i,
                                         name=os.environ[f'NN_DEVICE_{i}_NAME'],
                                         total_mem=int(os.environ[f'NN_DEVICE_{i}_TOTAL_MEM']),
                                         free_mem=int(os.environ[f'NN_DEVICE_{i}_FREE_MEM']),
                                         cc=int(os.environ[f'NN_DEVICE_{i}_CC']) ))
             Devices.all_devices = Devices(devices)

         return Devices.all_devices

 """
 if Devices.all_devices is None:
     min_cc = int(os.environ.get("TF_MIN_REQ_CAP", 35))

     libnames = ('libcuda.so', 'libcuda.dylib', 'nvcuda.dll')
     for libname in libnames:

@@ -195,7 +195,7 @@ if Devices.all_devices is None:
             cc = cc_major.value * 10 + cc_minor.value
             if cc >= min_cc:
                 devices.append ( Device(index=i,
                                         name=name.split(b'\0', 1)[0].decode(),
                                         total_mem=totalMem.value,
                                         free_mem=freeMem.value,
                                         cc=cc) )
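
The env-var handoff above gives this flow: the main process probes CUDA once and exports NN_DEVICE_* variables, and any later process rebuilds the device list without touching the driver. A usage sketch:

    Devices.initialize_main_env()    # main process: probe CUDA, export NN_DEVICE_* vars
    devices = Devices.getDevices()   # any process: rebuild Device objects from env
    for d in devices:
        print(d)                     # e.g. "[0]:[GeForce GTX 1080][7.12/8.0]"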


@@ -11,17 +11,14 @@ def initialize_initializers(nn):
     class initializers():
         class ca (init_ops.Initializer):
-            def __init__(self, dtype=None):
-                pass
-
             def __call__(self, shape, dtype=None, partition_info=None):
-                return tf.zeros( shape, name="_cai_")
+                return tf.zeros( shape, dtype=dtype, name="_cai_")

             @staticmethod
             def generate_batch( data_list, eps_std=0.05 ):
                 # list of (shape, np.dtype)
                 return CAInitializerSubprocessor (data_list).run()

     nn.initializers = initializers

 class CAInitializerSubprocessor(Subprocessor):

@@ -62,7 +59,7 @@ class CAInitializerSubprocessor(Subprocessor):
         x = x * np.sqrt( (2/fan_in) / np.var(x) )
         x = np.transpose( x, (2, 3, 1, 0) )
         return x.astype(dtype)

     class Cli(Subprocessor.Cli):
         #override
         def process_data(self, data):
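
The rescale near the end of CAInitializerSubprocessor leaves each generated kernel with Var(x) = 2/fan_in, i.e. He-style variance. A quick numpy check:

    import numpy as np

    fan_in = 3 * 3 * 64                        # kernel_size**2 * in_ch
    x = np.random.randn(64, 64, 3, 3)          # candidate kernel, any variance
    x = x * np.sqrt( (2/fan_in) / np.var(x) )
    assert np.isclose(np.var(x), 2/fan_in)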


@@ -8,7 +8,7 @@ import numpy as np
 def initialize_layers(nn):
     tf = nn.tf

     class Saveable():
         def __init__(self, name=None):
             self.name = name

@@ -65,6 +65,8 @@ def initialize_layers(nn):
                         sub_w_name = "/".join(w_name_split[1:])

                         w_val = d.get(sub_w_name, None)
+                        w_val = np.reshape( w_val, w.shape.as_list() )
+
                         if w_val is None:
                             io.log_err(f"Weight {w.name} was not loaded from file {filename}")
                             tuples.append ( (w, w.initializer) )
@@ -77,8 +79,8 @@ def initialize_layers(nn):
         def init_weights(self):
             ops = []

             ca_tuples_w = []
             ca_tuples = []

             for w in self.get_weights():
                 initializer = w.initializer

@@ -92,12 +94,12 @@ def initialize_layers(nn):
             if len(ops) != 0:
                 nn.tf_sess.run (ops)

             if len(ca_tuples) != 0:
                 nn.tf_batch_set_value( [*zip(ca_tuples_w, nn.initializers.ca.generate_batch (ca_tuples))] )

     nn.Saveable = Saveable

     class LayerBase():
         def __init__(self, name=None, **kwargs):
             self.name = name

@@ -124,7 +126,7 @@ def initialize_layers(nn):
             nn.tf_batch_set_value (tuples)

     nn.LayerBase = LayerBase

     class ModelBase(Saveable):
         def __init__(self, *args, name=None, **kwargs):
             super().__init__(name=name)
@@ -157,33 +159,33 @@ def initialize_layers(nn):
         def build(self):
             with tf.variable_scope(self.name):
                 current_vars = []
                 generator = None
                 while True:
                     if generator is None:
                         generator = self.on_build(*self.args, **self.kwargs)
                         if not isinstance(generator, types.GeneratorType):
                             generator = None

                     if generator is not None:
                         try:
                             next(generator)
                         except StopIteration:
                             generator = None

                     v = vars(self)
                     new_vars = self.xor_list (current_vars, list(v.keys()) )

                     for name in new_vars:
                         self._build_sub(v[name],name)

                     current_vars += new_vars

                     if generator is None:
                         break

             self.built = True

         #override

@@ -211,9 +213,9 @@ def initialize_layers(nn):
         def on_build(self, *args, **kwargs):
             """
             init model layers here

             return 'yield' if build is not finished
                             therefore dependency models will be initialized
             """
             pass

@@ -227,16 +229,16 @@ def initialize_layers(nn):
                 self.build()

             return self.forward(*args, **kwargs)

         def compute_output_shape(self, shapes):
             if not self.built:
                 self.build()

             not_list = False
             if not isinstance(shapes, list):
                 not_list = True
                 shapes = [shapes]

             with tf.device('/CPU:0'):
                 # CPU tensors will not impact any performance, only slightly RAM "leakage"
                 phs = []
@@ -244,24 +246,33 @@ def initialize_layers(nn):
                     phs += [ tf.placeholder(dtype, sh) ]

                 result = self.__call__(phs[0] if not_list else phs)

                 if not isinstance(result, list):
                     result = [result]

                 result_shapes = []

                 for t in result:
                     result_shapes += [ t.shape.as_list() ]

                 return result_shapes[0] if not_list else result_shapes

+        def compute_output_channels(self, shapes):
+            shape = self.compute_output_shape(shapes)
+            shape_len = len(shape)
+
+            if shape_len == 4:
+                if nn.data_format == "NCHW":
+                    return shape[1]
+            return shape[-1]
+
         def build_for_run(self, shapes_list):
             if not isinstance(shapes_list, list):
                 raise ValueError("shapes_list must be a list.")

             self.run_placeholders = []
             for dtype,sh in shapes_list:
-                self.run_placeholders.append ( tf.placeholder(dtype, (None,)+sh) )
+                self.run_placeholders.append ( tf.placeholder(dtype, sh) )

             self.run_output = self.__call__(self.run_placeholders)

@@ -279,7 +290,7 @@ def initialize_layers(nn):
             return nn.tf_sess.run ( self.run_output, feed_dict=feed_dict)

     nn.ModelBase = ModelBase

     class Conv2D(LayerBase):
         """
         use_wscale      bool    enables equalized learning rate, kernel_initializer will be forced to random_normal
@@ -292,6 +303,9 @@ def initialize_layers(nn):
             if not isinstance(dilations, int):
                 raise ValueError ("dilations must be an int type")

+            if dtype is None:
+                dtype = nn.tf_floatx
+
             if isinstance(padding, str):
                 if padding == "SAME":
                     padding = ( (kernel_size - 1) * dilations + 1 ) // 2
@@ -302,37 +316,48 @@ def initialize_layers(nn):
             if isinstance(padding, int):
                 if padding != 0:
-                    padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
+                    if nn.data_format == "NHWC":
+                        padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
+                    else:
+                        padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
                 else:
                     padding = None

+            if nn.data_format == "NHWC":
+                strides = [1,strides,strides,1]
+            else:
+                strides = [1,1,strides,strides]
+
+            if nn.data_format == "NHWC":
+                dilations = [1,dilations,dilations,1]
+            else:
+                dilations = [1,1,dilations,dilations]
+
             self.in_ch = in_ch
             self.out_ch = out_ch
             self.kernel_size = kernel_size
-            self.strides = [1,strides,strides,1]
+            self.strides = strides
             self.padding = padding
-            self.dilations = [1,dilations,dilations,1]
+            self.dilations = dilations
             self.use_bias = use_bias
             self.use_wscale = use_wscale
-            self.kernel_initializer = None if use_wscale else kernel_initializer
+            self.kernel_initializer = kernel_initializer
             self.bias_initializer = bias_initializer
             self.trainable = trainable
-            if dtype is None:
-                dtype = nn.tf_floatx
             self.dtype = dtype
             super().__init__(**kwargs)

         def build_weights(self):
             kernel_initializer = self.kernel_initializer
+            if self.use_wscale:
+                gain = 1.0 if self.kernel_size == 1 else np.sqrt(2)
+                fan_in = self.kernel_size*self.kernel_size*self.in_ch
+                he_std = gain / np.sqrt(fan_in) # He init
+                self.wscale = tf.constant(he_std, dtype=self.dtype )
+                kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)
+
             if kernel_initializer is None:
-                if self.use_wscale:
-                    gain = 1.0 if self.kernel_size == 1 else np.sqrt(2)
-                    fan_in = self.kernel_size*self.kernel_size*self.in_ch
-                    he_std = gain / np.sqrt(fan_in) # He init
-                    self.wscale = tf.constant(he_std, dtype=self.dtype )
-                    kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)
-                else:
-                    kernel_initializer = tf.initializers.glorot_uniform(dtype=self.dtype)
+                kernel_initializer = tf.initializers.glorot_uniform(dtype=self.dtype)

             self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.in_ch,self.out_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )

@@ -341,7 +366,7 @@ def initialize_layers(nn):
             if bias_initializer is None:
                 bias_initializer = tf.initializers.zeros(dtype=self.dtype)
-            self.bias = tf.get_variable("bias", (1,1,1,self.out_ch), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable )
+            self.bias = tf.get_variable("bias", (self.out_ch,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable )

         def get_weights(self):
             weights = [self.weight]
@@ -357,9 +382,13 @@ def initialize_layers(nn):
             if self.padding is not None:
                 x = tf.pad (x, self.padding, mode='CONSTANT')

-            x = tf.nn.conv2d(x, weight, self.strides, 'VALID', dilations=self.dilations)
+            x = tf.nn.conv2d(x, weight, self.strides, 'VALID', dilations=self.dilations, data_format=nn.data_format)
+
             if self.use_bias:
-                x = x + self.bias
+                if nn.data_format == "NHWC":
+                    bias = tf.reshape (self.bias, (1,1,1,self.out_ch) )
+                else:
+                    bias = tf.reshape (self.bias, (1,self.out_ch,1,1) )
+                x = tf.add(x, bias)
             return x

         def __str__(self):
@@ -367,7 +396,7 @@ def initialize_layers(nn):
             return r
     nn.Conv2D = Conv2D
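
The net effect of the new layout handling, as a small self-check mirroring the __init__ logic above:

    def conv_strides(strides, data_format):
        if data_format == "NHWC":
            return [1, strides, strides, 1]
        return [1, 1, strides, strides]

    assert conv_strides(2, "NHWC") == [1, 2, 2, 1]   # bias broadcast as (1,1,1,out_ch)
    assert conv_strides(2, "NCHW") == [1, 1, 2, 2]   # bias broadcast as (1,out_ch,1,1)
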
     class Conv2DTranspose(LayerBase):
         """
         use_wscale      enables weight scale (equalized learning rate)

@@ -376,6 +405,10 @@ def initialize_layers(nn):
         def __init__(self, in_ch, out_ch, kernel_size, strides=2, padding='SAME', use_bias=True, use_wscale=False, kernel_initializer=None, bias_initializer=None, trainable=True, dtype=None, **kwargs ):
             if not isinstance(strides, int):
                 raise ValueError ("strides must be an int type")

+            if dtype is None:
+                dtype = nn.tf_floatx
+
             self.in_ch = in_ch
             self.out_ch = out_ch
             self.kernel_size = kernel_size
@@ -383,33 +416,30 @@ def initialize_layers(nn):
             self.padding = padding
             self.use_bias = use_bias
             self.use_wscale = use_wscale
-            self.kernel_initializer = None if use_wscale else kernel_initializer
+            self.kernel_initializer = kernel_initializer
             self.bias_initializer = bias_initializer
             self.trainable = trainable
-            if dtype is None:
-                dtype = nn.tf_floatx
             self.dtype = dtype
             super().__init__(**kwargs)

         def build_weights(self):
             kernel_initializer = self.kernel_initializer
+            if self.use_wscale:
+                gain = 1.0 if self.kernel_size == 1 else np.sqrt(2)
+                fan_in = self.kernel_size*self.kernel_size*self.in_ch
+                he_std = gain / np.sqrt(fan_in) # He init
+                self.wscale = tf.constant(he_std, dtype=self.dtype )
+                kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)
+
             if kernel_initializer is None:
-                if self.use_wscale:
-                    gain = 1.0 if self.kernel_size == 1 else np.sqrt(2)
-                    fan_in = self.kernel_size*self.kernel_size*self.in_ch
-                    he_std = gain / np.sqrt(fan_in) # He init
-                    self.wscale = tf.constant(he_std, dtype=self.dtype )
-                    kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)
-                else:
-                    kernel_initializer = tf.initializers.glorot_uniform(dtype=self.dtype)
+                kernel_initializer = tf.initializers.glorot_uniform(dtype=self.dtype)

             self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.out_ch,self.in_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )

             if self.use_bias:
                 bias_initializer = self.bias_initializer
                 if bias_initializer is None:
                     bias_initializer = tf.initializers.zeros(dtype=self.dtype)
-                self.bias = tf.get_variable("bias", (1,1,1,self.out_ch), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable )
+                self.bias = tf.get_variable("bias", (self.out_ch,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable )

         def get_weights(self):
             weights = [self.weight]
@@ -420,21 +450,34 @@ def initialize_layers(nn):
         def __call__(self, x):
             shape = x.shape

-            h,w,c = shape[1], shape[2], shape[3]
-            output_shape = tf.stack ( (tf.shape(x)[0],
-                                       self.deconv_length(w, self.strides, self.kernel_size, self.padding),
-                                       self.deconv_length(h, self.strides, self.kernel_size, self.padding),
-                                       self.out_ch) )
+            if nn.data_format == "NHWC":
+                h,w,c = shape[1], shape[2], shape[3]
+                output_shape = tf.stack ( (tf.shape(x)[0],
+                                           self.deconv_length(w, self.strides, self.kernel_size, self.padding),
+                                           self.deconv_length(h, self.strides, self.kernel_size, self.padding),
+                                           self.out_ch) )
+                strides = [1,self.strides,self.strides,1]
+            else:
+                c,h,w = shape[1], shape[2], shape[3]
+                output_shape = tf.stack ( (tf.shape(x)[0],
+                                           self.out_ch,
+                                           self.deconv_length(w, self.strides, self.kernel_size, self.padding),
+                                           self.deconv_length(h, self.strides, self.kernel_size, self.padding),
+                                           ) )
+                strides = [1,1,self.strides,self.strides]

             weight = self.weight
             if self.use_wscale:
                 weight = weight * self.wscale

-            x = tf.nn.conv2d_transpose(x, weight, output_shape, [1,self.strides,self.strides,1], padding=self.padding)
+            x = tf.nn.conv2d_transpose(x, weight, output_shape, strides, padding=self.padding, data_format=nn.data_format)
+
             if self.use_bias:
-                x = x + self.bias
+                if nn.data_format == "NHWC":
+                    bias = tf.reshape (self.bias, (1,1,1,self.out_ch) )
+                else:
+                    bias = tf.reshape (self.bias, (1,self.out_ch,1,1) )
+                x = tf.add(x, bias)
             return x

         def __str__(self):
@@ -454,15 +497,18 @@ def initialize_layers(nn):
             dim_size = dim_size * stride_size

             return dim_size
     nn.Conv2DTranspose = Conv2DTranspose

     class BlurPool(LayerBase):
         def __init__(self, filt_size=3, stride=2, **kwargs ):
             self.strides = [1,stride,stride,1]
             self.filt_size = filt_size
-            self.padding = [ [0,0],
-                             [ int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)) ],
-                             [ int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)) ],
-                             [0,0] ]
+            pad = [ int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)) ]
+
+            if nn.data_format == "NHWC":
+                self.padding = [ [0,0], pad, pad, [0,0] ]
+            else:
+                self.padding = [ [0,0], [0,0], pad, pad ]

             if(self.filt_size==1):
                 a = np.array([1.,])
             elif(self.filt_size==2):
@@ -493,16 +539,16 @@ def initialize_layers(nn):
             x = tf.nn.depthwise_conv2d(x, k, self.strides, 'VALID')
             return x
     nn.BlurPool = BlurPool
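
The BlurPool filters are rows of Pascal's triangle expanded into a separable 2-D kernel; for example, assuming filt_size=3 uses a=[1,2,1] (the listing above is cut off after filt_size==2):

    import numpy as np

    a = np.array([1., 2., 1.])    # binomial row for filt_size=3
    k = a[:,None] * a[None,:]     # 3x3 separable blur kernel
    k = k / np.sum(k)             # normalize (sum was 16)
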
     class Dense(LayerBase):
         def __init__(self, in_ch, out_ch, use_bias=True, use_wscale=False, maxout_ch=0, kernel_initializer=None, bias_initializer=None, trainable=True, dtype=None, **kwargs ):
             """
             use_wscale          enables weight scale (equalized learning rate)
                                 kernel_initializer will be forced to random_normal

             maxout_ch           https://link.springer.com/article/10.1186/s40537-019-0233-0
                                 typical 2-4 if you want to enable DenseMaxout behaviour
             """
             self.in_ch = in_ch
             self.out_ch = out_ch
             self.use_bias = use_bias

@@ -512,7 +558,8 @@ def initialize_layers(nn):
             self.bias_initializer = bias_initializer
             self.trainable = trainable
             if dtype is None:
-                dtype = tf.float32
+                dtype = nn.tf_floatx
+
             self.dtype = dtype
             super().__init__(**kwargs)
@@ -521,25 +568,26 @@ def initialize_layers(nn):
             if self.maxout_ch > 1:
                 weight_shape = (self.in_ch,self.out_ch*self.maxout_ch)
             else:
                 weight_shape = (self.in_ch,self.out_ch)

             kernel_initializer = self.kernel_initializer
+            if self.use_wscale:
+                gain = 1.0
+                fan_in = np.prod( weight_shape[:-1] )
+                he_std = gain / np.sqrt(fan_in) # He init
+                self.wscale = tf.constant(he_std, dtype=self.dtype )
+                kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)
+
             if kernel_initializer is None:
-                if self.use_wscale:
-                    gain = 1.0
-                    fan_in = np.prod( weight_shape[:-1] )
-                    he_std = gain / np.sqrt(fan_in) # He init
-                    self.wscale = tf.constant(he_std, dtype=self.dtype )
-                    kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)
-                else:
-                    kernel_initializer = tf.initializers.glorot_uniform(dtype=self.dtype)
+                kernel_initializer = tf.initializers.glorot_uniform(dtype=self.dtype)

             self.weight = tf.get_variable("weight", weight_shape, dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )

             if self.use_bias:
                 bias_initializer = self.bias_initializer
                 if bias_initializer is None:
                     bias_initializer = tf.initializers.zeros(dtype=self.dtype)
-                self.bias = tf.get_variable("bias", (1,self.out_ch), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable )
+                self.bias = tf.get_variable("bias", (self.out_ch,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable )

         def get_weights(self):
             weights = [self.weight]
@@ -553,46 +601,53 @@ def initialize_layers(nn):
                 weight = weight * self.wscale

             x = tf.matmul(x, weight)

             if self.maxout_ch > 1:
                 x = tf.reshape (x, (-1, self.out_ch, self.maxout_ch) )
                 x = tf.reduce_max(x, axis=-1)

             if self.use_bias:
-                x = x + self.bias
+                x = tf.add(x, tf.reshape(self.bias, (1,self.out_ch) ) )

             return x
     nn.Dense = Dense
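
A shape walk-through of the DenseMaxout path above, with in_ch=64, out_ch=32, maxout_ch=4:

    import numpy as np

    batch, in_ch, out_ch, maxout_ch = 8, 64, 32, 4
    x = np.random.randn(batch, in_ch)
    w = np.random.randn(in_ch, out_ch * maxout_ch)
    y = (x @ w).reshape(batch, out_ch, maxout_ch).max(axis=-1)
    assert y.shape == (batch, out_ch)   # max over 4 linear pieces per unit
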
     class BatchNorm2D(LayerBase):
         """
         currently not for training
         """
         def __init__(self, dim, eps=1e-05, momentum=0.1, dtype=None, **kwargs):
             self.dim = dim
             self.eps = eps
             self.momentum = momentum
             if dtype is None:
                 dtype = nn.tf_floatx
             self.dtype = dtype
-            self.shape = (1,1,1,dim)

             super().__init__(**kwargs)

         def build_weights(self):
-            self.weight       = tf.get_variable("weight",       self.shape, dtype=self.dtype, initializer=tf.initializers.ones() )
-            self.bias         = tf.get_variable("bias",         self.shape, dtype=self.dtype, initializer=tf.initializers.zeros() )
-            self.running_mean = tf.get_variable("running_mean", self.shape, dtype=self.dtype, initializer=tf.initializers.zeros(), trainable=False )
-            self.running_var  = tf.get_variable("running_var",  self.shape, dtype=self.dtype, initializer=tf.initializers.zeros(), trainable=False )
+            self.weight       = tf.get_variable("weight",       (self.dim,), dtype=self.dtype, initializer=tf.initializers.ones() )
+            self.bias         = tf.get_variable("bias",         (self.dim,), dtype=self.dtype, initializer=tf.initializers.zeros() )
+            self.running_mean = tf.get_variable("running_mean", (self.dim,), dtype=self.dtype, initializer=tf.initializers.zeros(), trainable=False )
+            self.running_var  = tf.get_variable("running_var",  (self.dim,), dtype=self.dtype, initializer=tf.initializers.zeros(), trainable=False )

         def get_weights(self):
             return [self.weight, self.bias, self.running_mean, self.running_var]

         def __call__(self, x):
-            x = (x - self.running_mean) / tf.sqrt( self.running_var + self.eps )
-            x *= self.weight
-            x += self.bias
+            if nn.data_format == "NHWC":
+                shape = (1,1,1,self.dim)
+            else:
+                shape = (1,self.dim,1,1)
+
+            weight       = tf.reshape ( self.weight      , shape )
+            bias         = tf.reshape ( self.bias        , shape )
+            running_mean = tf.reshape ( self.running_mean, shape )
+            running_var  = tf.reshape ( self.running_var , shape )
+
+            x = (x - running_mean) / tf.sqrt( running_var + self.eps )
+            x *= weight
+            x += bias
             return x
     nn.BatchNorm2D = BatchNorm2D


@@ -1,51 +1,67 @@
 """
 Leras.

 like lighter keras.
 This is my lightweight neural network library written from scratch
 based on pure tensorflow without keras.

 Provides:
 + full freedom of tensorflow operations without keras model's restrictions
 + easy model operations like in PyTorch, but in graph mode (no eager execution)
 + convenient and understandable logic

 Reasons why we cannot import tensorflow or any tensorflow.sub modules right here:
 1) change env variables based on DeviceConfig before import tensorflow
 2) multiprocesses will import tensorflow every spawn
+
+NCHW speed up training for 10-20%.
 """

 import os
 import sys
 from pathlib import Path
+import numpy as np
 from core.interact import interact as io
 from .device import Devices

 class nn():
     current_DeviceConfig = None

     tf = None
     tf_sess = None
     tf_sess_config = None
+    tf_default_device = None
+
+    data_format = None
+    conv2d_ch_axis = None
+    conv2d_spatial_axes = None
+
+    tf_floatx = None
+    np_floatx = None

     # Tensor ops
     tf_get_value = None
     tf_batch_set_value = None
     tf_gradients = None
     tf_average_gv_list = None
     tf_average_tensor_list = None
-    tf_dot = None
+    tf_concat = None
     tf_gelu = None
     tf_upsample2d = None
     tf_upsample2d_bilinear = None
     tf_flatten = None
+    tf_reshape_4D = None
     tf_random_binomial = None
     tf_gaussian_blur = None
     tf_style_loss = None
+    tf_channel_histogram = None
+    tf_histogram = None
     tf_dssim = None
+    tf_space_to_depth = None
+    tf_depth_to_space = None

     # Layers
     Saveable = None
     LayerBase = None
@@ -55,16 +71,17 @@ class nn():
    BlurPool = None
    Dense = None
    BatchNorm2D = None

    # Initializers
    initializers = None

    # Optimizers
    TFBaseOptimizer = None
    TFRMSpropOptimizer = None

    @staticmethod
-    def initialize(device_config=None):
+    def initialize(device_config=None, floatx="float32", data_format="NHWC"):
        if nn.tf is None:
            if device_config is None:
                device_config = nn.getCurrentDeviceConfig()
@@ -74,11 +91,8 @@ class nn():
            if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
                os.environ.pop('CUDA_VISIBLE_DEVICES')

-            os.environ['CUDA_CACHE_MAXSIZE'] = '536870912' #512Mb (32mb default)
            first_run = False
-            if not device_config.cpu_only:
+            if len(device_config.devices) != 0:
                if sys.platform[0:3] == 'win':
                    if all( [ x.name == device_config.devices[0].name for x in device_config.devices ] ):
                        devices_str = "_" + device_config.devices[0].name.replace(' ','_')
@@ -86,27 +100,33 @@ class nn():
                        devices_str = ""
                        for device in device_config.devices:
                            devices_str += "_" + device.name.replace(' ','_')

                    compute_cache_path = Path(os.environ['APPDATA']) / 'NVIDIA' / ('ComputeCache' + devices_str)
                    if not compute_cache_path.exists():
                        first_run = True
                    os.environ['CUDA_CACHE_PATH'] = str(compute_cache_path)

+            os.environ['CUDA_CACHE_MAXSIZE'] = '536870912' #512Mb (32mb default)
            os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2'
            os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # tf log errors only

            import warnings
            warnings.simplefilter(action='ignore', category=FutureWarning)

            if first_run:
                io.log_info("Caching GPU kernels...")

            import tensorflow as tf
+            import logging
+            logging.getLogger('tensorflow').setLevel(logging.ERROR)
            nn.tf = tf

-            if device_config.cpu_only:
+            if len(device_config.devices) == 0:
+                nn.tf_default_device = "/CPU:0"
                config = tf.ConfigProto(device_count={'GPU': 0})
            else:
+                nn.tf_default_device = "/GPU:0"
                config = tf.ConfigProto()
                config.gpu_options.visible_device_list = ','.join([str(device.index) for device in device_config.devices])
@@ -114,26 +134,81 @@ class nn():
            config.gpu_options.allow_growth = True
            nn.tf_sess_config = config

-            nn.tf_floatx = nn.tf.float32 #nn.tf.float16 if device_config.use_fp16 else nn.tf.float32
-            nn.np_floatx = nn.tf_floatx.as_numpy_dtype

            from .tensor_ops import initialize_tensor_ops
            from .layers import initialize_layers
            from .initializers import initialize_initializers
            from .optimizers import initialize_optimizers
            initialize_tensor_ops(nn)
            initialize_layers(nn)
            initialize_initializers(nn)
            initialize_optimizers(nn)

        if nn.tf_sess is None:
            nn.tf_sess = tf.Session(config=nn.tf_sess_config)

+        if floatx == "float32":
+            floatx = nn.tf.float32
+        elif floatx == "float16":
+            floatx = nn.tf.float16
+        else:
+            raise ValueError(f"unsupported floatx {floatx}")
+        nn.set_floatx(floatx)
+        nn.set_data_format(data_format)
    @staticmethod
    def initialize_main_env():
        Devices.initialize_main_env()

+    @staticmethod
+    def set_floatx(tf_dtype):
+        """
+        set default float type for all layers when dtype is None for them
+        """
+        nn.tf_floatx = tf_dtype
+        nn.np_floatx = tf_dtype.as_numpy_dtype
+
+    @staticmethod
+    def set_data_format(data_format):
+        if data_format != "NHWC" and data_format != "NCHW":
+            raise ValueError(f"unsupported data_format {data_format}")
+        nn.data_format = data_format
+        if data_format == "NHWC":
+            nn.conv2d_ch_axis = 3
+            nn.conv2d_spatial_axes = [1,2]
+        elif data_format == "NCHW":
+            nn.conv2d_ch_axis = 1
+            nn.conv2d_spatial_axes = [2,3]
+
+    @staticmethod
+    def get4Dshape ( w, h, c, data_format=None ):
+        """
+        returns 4D shape based on current data_format
+        """
+        if data_format is None:
+            data_format = nn.data_format
+        if data_format == "NHWC":
+            return (None,h,w,c)
+        else:
+            return (None,c,h,w)
+
+    @staticmethod
+    def to_data_format( x, to_data_format, from_data_format=None):
+        if from_data_format is None:
+            from_data_format = nn.data_format
+        if to_data_format == from_data_format:
+            return x
+        if to_data_format == "NHWC":
+            return np.transpose(x, (0,2,3,1) )
+        elif to_data_format == "NCHW":
+            return np.transpose(x, (0,3,1,2) )
+        else:
+            raise ValueError(f"unsupported to_data_format {to_data_format}")
    @staticmethod
    def getCurrentDeviceConfig():
        if nn.current_DeviceConfig is None:

@@ -151,27 +226,34 @@ class nn():
            nn.tf.reset_default_graph()
            nn.tf_sess.close()
            nn.tf_sess = nn.tf.Session(config=nn.tf_sess_config)

    @staticmethod
    def tf_close_session():
        if nn.tf_sess is not None:
            nn.tf.reset_default_graph()
            nn.tf_sess.close()
            nn.tf_sess = None

+    @staticmethod
+    def tf_get_current_device():
+        # Undocumented access to last tf.device(...)
+        objs = nn.tf.get_default_graph()._device_function_stack.peek_objs()
+        if len(objs) != 0:
+            return objs[0].display_name
+        return nn.tf_default_device

    @staticmethod
    def ask_choose_device_idxs(choose_only_one=False, allow_cpu=True, suggest_best_multi_gpu=False, suggest_all_gpu=False, return_device_config=False):
        devices = Devices.getDevices()
        if len(devices) == 0:
            return []

        all_devices_indexes = [device.index for device in devices]

        if choose_only_one:
            suggest_best_multi_gpu = False
            suggest_all_gpu = False

        if suggest_all_gpu:
            best_device_indexes = all_devices_indexes
        elif suggest_best_multi_gpu:

@@ -179,84 +261,84 @@ class nn():
        else:
            best_device_indexes = [ devices.get_best_device().index ]
        best_device_indexes = ",".join([str(x) for x in best_device_indexes])

        io.log_info ("")
        if choose_only_one:
            io.log_info ("Choose one GPU idx.")
        else:
            io.log_info ("Choose one or several GPU idxs (separated by comma).")
        io.log_info ("")

        if allow_cpu:
            io.log_info ("[CPU] : CPU")
        for device in devices:
            io.log_info (f"  [{device.index}] : {device.name}")

        io.log_info ("")

        while True:
            try:
                if choose_only_one:
                    choosed_idxs = io.input_str("Which GPU index to choose?", best_device_indexes)
                else:
                    choosed_idxs = io.input_str("Which GPU indexes to choose?", best_device_indexes)

                if allow_cpu and choosed_idxs.lower() == "cpu":
                    choosed_idxs = []
                    break

                choosed_idxs = [ int(x) for x in choosed_idxs.split(',') ]

                if choose_only_one:
                    if len(choosed_idxs) == 1:
                        break
                else:
                    if all( [idx in all_devices_indexes for idx in choosed_idxs] ):
                        break
            except:
                pass
        io.log_info ("")

        if return_device_config:
            return nn.DeviceConfig.GPUIndexes(choosed_idxs)
        else:
            return choosed_idxs

    class DeviceConfig():
        def __init__ (self, devices=None):
            devices = devices or []

            if not isinstance(devices, Devices):
                devices = Devices(devices)
            self.devices = devices
            self.cpu_only = len(devices) == 0

        @staticmethod
        def BestGPU():
            devices = Devices.getDevices()
            if len(devices) == 0:
                return nn.DeviceConfig.CPU()

            return nn.DeviceConfig([devices.get_best_device()])

        @staticmethod
        def WorstGPU():
            devices = Devices.getDevices()
            if len(devices) == 0:
                return nn.DeviceConfig.CPU()

            return nn.DeviceConfig([devices.get_worst_device()])

        @staticmethod
        def GPUIndexes(indexes):
            if len(indexes) != 0:
                devices = Devices.getDevices().get_devices_from_index_list(indexes)
            else:
                devices = []

            return nn.DeviceConfig(devices)

        @staticmethod
        def CPU():
            return nn.DeviceConfig([])
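Taken together, the new initialize signature and the layout helpers mean callers pick a layout once and then stay layout-agnostic. A usage sketch of that flow (hypothetical shapes; this only runs inside a DeepFaceLab checkout, since it imports the library being patched here):

    from core.leras import nn

    # build the graph in NCHW (the faster layout on CUDA), keep data as NHWC numpy
    nn.initialize(floatx="float32", data_format="NCHW")

    bgr_shape = nn.get4Dshape(128, 128, 3)   # -> (None, 3, 128, 128) under NCHW

    # imgs_nhwc: a (batch, 128, 128, 3) numpy array from the sample loader;
    # transpose it into whatever layout the graph was built with:
    # imgs = nn.to_data_format(imgs_nhwc, nn.data_format, "NHWC")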

View file

@@ -73,7 +73,7 @@ def initialize_optimizers(nn):
        e = tf.device('/CPU:0') if vars_on_cpu else None
        if e: e.__enter__()
        with tf.variable_scope(self.name):
-            accumulators = [ tf.get_variable ( f'acc_{i+self.accumulator_counter}', v.shape, initializer=tf.initializers.constant(0.0), trainable=False)
+            accumulators = [ tf.get_variable ( f'acc_{i+self.accumulator_counter}', v.shape, dtype=v.dtype, initializer=tf.initializers.constant(0.0), trainable=False)
                             for (i, v ) in enumerate(trainable_weights) ]

            self.accumulators_dict.update ( { v.name : acc for v,acc in zip(trainable_weights,accumulators) } )

@@ -81,13 +81,13 @@ def initialize_optimizers(nn):
            self.accumulator_counter += len(trainable_weights)

            if self.lr_dropout != 1.0:
-                lr_rnds = [ nn.tf_random_binomial( v.shape, p=self.lr_dropout) for v in trainable_weights ]
+                lr_rnds = [ nn.tf_random_binomial( v.shape, p=self.lr_dropout, dtype=v.dtype) for v in trainable_weights ]
                self.lr_rnds_dict.update ( { v.name : rnd for v,rnd in zip(trainable_weights,lr_rnds) } )
        if e: e.__exit__(None, None, None)

    def get_update_op(self, grads_vars):
        updates = []
-        lr = self.lr
        if self.clipnorm > 0.0:
            norm = tf.sqrt( sum([tf.reduce_sum(tf.square(g)) for g,v in grads_vars]))
        updates += [ state_ops.assign_add( self.iterations, 1) ]

@@ -96,8 +96,14 @@ def initialize_optimizers(nn):
                g = self.tf_clip_norm(g, self.clipnorm, norm)
            a = self.accumulators_dict[v.name]

-            new_a = self.rho * a + (1. - self.rho) * tf.square(g)
-            v_diff = - lr * g / (tf.sqrt(new_a) + self.epsilon)
+            rho = tf.cast(self.rho, a.dtype)
+            new_a = rho * a + (1. - rho) * tf.square(g)
+
+            lr = tf.cast(self.lr, a.dtype)
+            epsilon = tf.cast(self.epsilon, a.dtype)
+            v_diff = - lr * g / (tf.sqrt(new_a) + epsilon)
            if self.lr_dropout != 1.0:
                lr_rnd = self.lr_rnds_dict[v.name]
                v_diff *= lr_rnd
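The casts above exist so the RMSprop update also works when variables (and thus accumulators) are float16: every scalar hyperparameter is brought to the accumulator's dtype before it enters the graph. A numpy sketch of the update itself, including the lr_dropout mask (illustrative names, not the library API):

    import numpy as np

    def rmsprop_step(w, g, acc, lr=0.001, rho=0.9, eps=1e-8, lr_dropout=0.3, rng=np.random):
        # running average of squared gradients
        acc = rho * acc + (1.0 - rho) * np.square(g)
        v_diff = -lr * g / (np.sqrt(acc) + eps)
        if lr_dropout != 1.0:
            # zero a random subset of per-weight updates; the rest proceed as usual
            v_diff *= rng.binomial(1, lr_dropout, size=w.shape).astype(w.dtype)
        return w + v_diff, acc

    w = np.zeros(4, np.float32); acc = np.zeros_like(w)
    w, acc = rmsprop_step(w, np.ones_like(w), acc)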

View file

@@ -2,14 +2,14 @@ import numpy as np
def initialize_tensor_ops(nn):
    tf = nn.tf
    from tensorflow.python.ops import array_ops, random_ops, math_ops, sparse_ops, gradients
    from tensorflow.python.framework import sparse_tensor

    def tf_get_value(tensor):
        return nn.tf_sess.run (tensor)
    nn.tf_get_value = tf_get_value

    def tf_batch_set_value(tuples):
        if len(tuples) != 0:
            with nn.tf.device('/CPU:0'):

@@ -28,8 +28,8 @@ def initialize_tensor_ops(nn):
            nn.tf_sess.run(assign_ops, feed_dict=feed_dict)
    nn.tf_batch_set_value = tf_batch_set_value

    def tf_gradients ( loss, vars ):
        grads = gradients.gradients(loss, vars, colocate_gradients_with_ops=True )
        gv = [*zip(grads,vars)]

@@ -38,8 +38,11 @@ def initialize_tensor_ops(nn):
                raise Exception("No gradient for variable {v.name}")
        return gv
    nn.tf_gradients = tf_gradients

    def tf_average_gv_list(grad_var_list, tf_device_string=None):
+        if len(grad_var_list) == 1:
+            return grad_var_list[0]
        e = tf.device(tf_device_string) if tf_device_string is not None else None
        if e is not None: e.__enter__()
        result = []

@@ -56,71 +59,65 @@ def initialize_tensor_ops(nn):
        if e is not None: e.__exit__(None,None,None)
        return result
    nn.tf_average_gv_list = tf_average_gv_list

    def tf_average_tensor_list(tensors_list, tf_device_string=None):
+        if len(tensors_list) == 1:
+            return tensors_list[0]
        e = tf.device(tf_device_string) if tf_device_string is not None else None
        if e is not None: e.__enter__()
        result = tf.reduce_mean(tf.concat ([tf.expand_dims(t, 0) for t in tensors_list], 0), 0)
        if e is not None: e.__exit__(None,None,None)
        return result
    nn.tf_average_tensor_list = tf_average_tensor_list
-    def tf_dot(x, y):
-        if x.shape.ndims > 2 or y.shape.ndims > 2:
-            x_shape = []
-            for i, s in zip( x.shape.as_list(), array_ops.unstack(array_ops.shape(x))):
-                if i is not None:
-                    x_shape.append(i)
-                else:
-                    x_shape.append(s)
-            x_shape = tuple(x_shape)
-            y_shape = []
-            for i, s in zip( y.shape.as_list(), array_ops.unstack(array_ops.shape(y))):
-                if i is not None:
-                    y_shape.append(i)
-                else:
-                    y_shape.append(s)
-            y_shape = tuple(y_shape)
-            y_permute_dim = list(range(y.shape.ndims))
-            y_permute_dim = [y_permute_dim.pop(-2)] + y_permute_dim
-            xt = array_ops.reshape(x, [-1, x_shape[-1]])
-            yt = array_ops.reshape(array_ops.transpose(y, perm=y_permute_dim), [y_shape[-2], -1])
-            import code
-            code.interact(local=dict(globals(), **locals()))
-            return array_ops.reshape(math_ops.matmul(xt, yt), x_shape[:-1] + y_shape[:-2] + y_shape[-1:])
-        if isinstance(x, sparse_tensor.SparseTensor):
-            out = sparse_ops.sparse_tensor_dense_matmul(x, y)
-        else:
-            out = math_ops.matmul(x, y)
-        return out
-    nn.tf_dot = tf_dot
+    def tf_concat (tensors_list, axis):
+        """
+        Better version.
+        """
+        if len(tensors_list) == 1:
+            return tensors_list[0]
+        return tf.concat(tensors_list, axis)
+    nn.tf_concat = tf_concat
    def tf_gelu(x):
        cdf = 0.5 * (1.0 + tf.nn.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
        return x * cdf
    nn.tf_gelu = tf_gelu

    def tf_upsample2d(x, size=2):
-        return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )
+        if nn.data_format == "NCHW":
+            b,c,h,w = x.shape.as_list()
+            x = tf.reshape (x, (-1,c,h,1,w,1) )
+            x = tf.tile(x, (1,1,1,size,1,size) )
+            x = tf.reshape (x, (-1,c,h*size,w*size) )
+            return x
+        else:
+            return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )
    nn.tf_upsample2d = tf_upsample2d
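The NCHW branch avoids tf.image (which only accepts NHWC) by doing nearest-neighbor upsampling with pure reshapes and tiles. The same trick in numpy, to make the shape gymnastics concrete (a sketch, not the library code):

    import numpy as np

    def upsample2d_nchw(x, size=2):
        # insert singleton axes after h and w, tile them, then fold them back in
        n, c, h, w = x.shape
        x = x.reshape(n, c, h, 1, w, 1)
        x = np.tile(x, (1, 1, 1, size, 1, size))
        return x.reshape(n, c, h * size, w * size)

    x = np.arange(4, dtype=np.float32).reshape(1, 1, 2, 2)
    y = upsample2d_nchw(x)
    assert y.shape == (1, 1, 4, 4)
    assert (y[0, 0, :2, :2] == x[0, 0, 0, 0]).all()  # each pixel became a 2x2 block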
    def tf_upsample2d_bilinear(x, size=2):
        return tf.image.resize_images(x, (x.shape[1]*size, x.shape[2]*size) )
    nn.tf_upsample2d_bilinear = tf_upsample2d_bilinear

-    def tf_flatten(x, dynamic_dims=False):
-        """
-        dynamic_dims allows to flatten without knowing size on input dims
-        """
-        if dynamic_dims:
-            sh = tf.shape(x)
-            return tf.reshape (x, (sh[0], tf.reduce_prod(sh[1:]) ) )
-        else:
-            return tf.reshape (x, (-1, np.prod(x.shape[1:])) )
+    def tf_flatten(x):
+        if nn.data_format == "NHWC":
+            # match NCHW version in order to switch data_format without problems
+            x = tf.transpose(x, (0,3,1,2) )
+        return tf.reshape (x, (-1, np.prod(x.shape[1:])) )
    nn.tf_flatten = tf_flatten

+    def tf_reshape_4D(x, w,h,c):
+        if nn.data_format == "NHWC":
+            # match NCHW version in order to switch data_format without problems
+            x = tf.reshape (x, (-1,c,h,w))
+            x = tf.transpose(x, (0,2,3,1) )
+            return x
+        else:
+            return tf.reshape (x, (-1,c,h,w))
+    nn.tf_reshape_4D = tf_reshape_4D
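tf_flatten and tf_reshape_4D are deliberately symmetric: NHWC tensors are transposed to NCHW order before flattening, so a Dense bottleneck sees the same element order under either layout and saved weights stay portable when data_format changes. A numpy check of that round trip (a sketch of the same logic):

    import numpy as np

    def flatten(x, data_format):
        if data_format == "NHWC":
            x = x.transpose(0, 3, 1, 2)      # normalize to NCHW order first
        return x.reshape(x.shape[0], -1)

    def reshape_4d(x, w, h, c, data_format):
        x = x.reshape(-1, c, h, w)
        return x.transpose(0, 2, 3, 1) if data_format == "NHWC" else x

    x_nhwc = np.random.rand(2, 4, 4, 3).astype(np.float32)
    x_nchw = x_nhwc.transpose(0, 3, 1, 2)

    # both layouts flatten to identical vectors...
    assert np.allclose(flatten(x_nhwc, "NHWC"), flatten(x_nchw, "NCHW"))
    # ...and reshape back to the layout they came from
    assert np.allclose(reshape_4d(flatten(x_nhwc, "NHWC"), 4, 4, 3, "NHWC"), x_nhwc)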
    def tf_random_binomial(shape, p=0.0, dtype=None, seed=None):
        if dtype is None:
            dtype=tf.float32

@@ -131,7 +128,7 @@ def initialize_tensor_ops(nn):
                        random_ops.random_uniform(shape, dtype=tf.float16, seed=seed) < p,
                        array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype))
    nn.tf_random_binomial = tf_random_binomial

    def tf_gaussian_blur(input, radius=2.0):
        def gaussian(x, mu, sigma):
            return np.exp(-(float(x) - float(mu)) ** 2 / (2 * sigma ** 2))

@@ -142,41 +139,42 @@ def initialize_tensor_ops(nn):
            kernel_1d = np.array([gaussian(x, mean, sigma) for x in range(kernel_size)])
            np_kernel = np.outer(kernel_1d, kernel_1d).astype(np.float32)
            kernel = np_kernel / np.sum(np_kernel)
-            return kernel
+            return kernel, kernel_size

-        gauss_kernel = make_kernel(radius)
-        gauss_kernel = gauss_kernel[:, :,np.newaxis, np.newaxis]
-        kernel_size = gauss_kernel.shape[0]
+        gauss_kernel, kernel_size = make_kernel(radius)
+        padding = kernel_size//2
+        if padding != 0:
+            if nn.data_format == "NHWC":
+                padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
+            else:
+                padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
+        else:
+            padding = None

-        inputs = [ input[:,:,:,i:i+1] for i in range( input.shape[-1] ) ]
+        gauss_kernel = gauss_kernel[:,:,None,None]

        outputs = []
-        for i in range(len(inputs)):
-            x = inputs[i]
-            if kernel_size != 0:
-                padding = kernel_size//2
-                x = tf.pad (x, [ [0,0], [padding,padding], [padding,padding], [0,0] ] )
-            outputs += [ tf.nn.conv2d(x, tf.constant(gauss_kernel, dtype=nn.tf_floatx ) , strides=[1,1,1,1], padding="VALID") ]
+        for i in range(input.shape[nn.conv2d_ch_axis]):
+            x = input[:,:,:,i:i+1] if nn.data_format == "NHWC" \
+                else input[:,i:i+1,:,:]
+            if padding is not None:
+                x = tf.pad (x, padding)
+            outputs += [ tf.nn.conv2d(x, tf.constant(gauss_kernel, dtype=input.dtype ), strides=[1,1,1,1], padding="VALID", data_format=nn.data_format) ]

-        return tf.concat (outputs, axis=-1)
+        return tf.concat (outputs, axis=nn.conv2d_ch_axis)
    nn.tf_gaussian_blur = tf_gaussian_blur
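The blur applies one small 2D Gaussian kernel to each channel independently, padding first so spatial size is preserved; data_format only decides which axis is sliced and padded. Building that normalized kernel is plain numpy; this sketch follows the same construction, with the kernel-size rule assumed since make_kernel's full body is not shown in this hunk:

    import numpy as np

    def make_gaussian_kernel(radius=2.0):
        # sample a 1D gaussian, take the outer product, normalize so the
        # blur preserves overall brightness
        sigma = radius
        kernel_size = max(3, int(2 * round(radius)) + 1)   # assumed size rule
        mean = (kernel_size - 1) / 2.0
        k1d = np.exp(-((np.arange(kernel_size) - mean) ** 2) / (2 * sigma ** 2))
        k2d = np.outer(k1d, k1d).astype(np.float32)
        return k2d / k2d.sum(), kernel_size

    kernel, size = make_gaussian_kernel(2.0)
    assert abs(kernel.sum() - 1.0) < 1e-6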
    def tf_style_loss(target, style, gaussian_blur_radius=0.0, loss_weight=1.0, step_size=1):
        def sd(content, style, loss_weight):
-            content_nc = content.shape[-1]
-            style_nc = style.shape[-1]
+            content_nc = content.shape[ nn.conv2d_ch_axis ]
+            style_nc = style.shape[nn.conv2d_ch_axis]
            if content_nc != style_nc:
                raise Exception("style_loss() content_nc != style_nc")

-            axes = [1,2]
-            c_mean, c_var = tf.nn.moments(content, axes=axes, keep_dims=True)
-            s_mean, s_var = tf.nn.moments(style, axes=axes, keep_dims=True)
+            c_mean, c_var = tf.nn.moments(content, axes=nn.conv2d_spatial_axes, keep_dims=True)
+            s_mean, s_var = tf.nn.moments(style, axes=nn.conv2d_spatial_axes, keep_dims=True)
            c_std, s_std = tf.sqrt(c_var + 1e-5), tf.sqrt(s_var + 1e-5)

            mean_loss = tf.reduce_sum(tf.square(c_mean-s_mean), axis=[1,2,3])
            std_loss = tf.reduce_sum(tf.square(c_std-s_std), axis=[1,2,3])

            return (mean_loss + std_loss) * ( loss_weight / content_nc.value )

        if gaussian_blur_radius > 0.0:

@@ -186,47 +184,30 @@ def initialize_tensor_ops(nn):
        return sd( target, style, loss_weight=loss_weight )
    nn.tf_style_loss = tf_style_loss

-    def tf_channel_histogram (input, bins, data_range):
-        range_min, range_max = data_range
-        bin_range = (range_max-range_min) / (bins-1)
-        reduce_axes = [*range(input.shape.ndims)][1:]
-        x = input
-        x += bin_range/2
-        output = []
-        for i in range(bins-1, -1, -1):
-            y = x - (i*bin_range)
-            ones_mask = tf.sign( tf.nn.relu(y) )
-            x = x * (1.0 - ones_mask)
-            output.append ( tf.expand_dims(tf.reduce_sum (ones_mask, axis=reduce_axes ), -1) )
-        return tf.concat(output[::-1],-1)
-    nn.tf_channel_histogram = tf_channel_histogram

-    def tf_histogram(input, bins=256, data_range=(0,1.0)):
-        return tf.concat ( [tf.expand_dims( tf_channel_histogram( input[...,i], bins=bins, data_range=data_range ), -1 ) for i in range(input.shape[-1])], -1 )
-    nn.tf_histogram = tf_histogram
    def tf_dssim(img1,img2, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
+        if img1.dtype != img2.dtype:
+            raise ValueError("img1.dtype != img2.dtype")

-        ch = img2.shape[-1]
-        def _fspecial_gauss(size, sigma):
-            #Function to mimic the 'fspecial' gaussian MATLAB function.
-            coords = np.arange(0, size, dtype=nn.np_floatx)
-            coords -= (size - 1 ) / 2.0
-            g = coords**2
-            g *= ( -0.5 / (sigma**2) )
-            g = np.reshape (g, (1,-1)) + np.reshape(g, (-1,1) )
-            g = tf.constant ( np.reshape (g, (1,-1)), dtype=nn.tf_floatx )
-            g = tf.nn.softmax(g)
-            g = tf.reshape (g, (size, size, 1, 1))
-            g = tf.tile (g, (1,1,ch,1))
-            return g
-        kernel = _fspecial_gauss(filter_size,filter_sigma)
+        not_float32 = img1.dtype != tf.float32
+        if not_float32:
+            img_dtype = img1.dtype
+            img1 = tf.cast(img1, tf.float32)
+            img2 = tf.cast(img2, tf.float32)
+
+        kernel = np.arange(0, filter_size, dtype=np.float32)
+        kernel -= (filter_size - 1 ) / 2.0
+        kernel = kernel**2
+        kernel *= ( -0.5 / (filter_sigma**2) )
+        kernel = np.reshape (kernel, (1,-1)) + np.reshape(kernel, (-1,1) )
+        kernel = tf.constant ( np.reshape (kernel, (1,-1)), dtype=tf.float32 )
+        kernel = tf.nn.softmax(kernel)
+        kernel = tf.reshape (kernel, (filter_size, filter_size, 1, 1))
+        kernel = tf.tile (kernel, (1,1, img1.shape[ nn.conv2d_ch_axis ] ,1))

        def reducer(x):
-            return tf.nn.depthwise_conv2d(x, kernel, strides=[1,1,1,1], padding='VALID')
+            return tf.nn.depthwise_conv2d(x, kernel, strides=[1,1,1,1], padding='VALID', data_format=nn.data_format)

        c1 = (k1 * max_val) ** 2
        c2 = (k2 * max_val) ** 2

@@ -242,10 +223,44 @@ def initialize_tensor_ops(nn):
        c2 *= 1.0 #compensation factor
        cs = (num1 - num0 + c2) / (den1 - den0 + c2)

-        ssim_val = tf.reduce_mean(luminance * cs, axis=(-3, -2) )
-        return(1.0 - ssim_val ) / 2.0
+        ssim_val = tf.reduce_mean(luminance * cs, axis=nn.conv2d_spatial_axes )
+        dssim = (1.0 - ssim_val ) / 2.0
+
+        if not_float32:
+            dssim = tf.cast(dssim, img_dtype)
+        return dssim
    nn.tf_dssim = tf_dssim
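tf_dssim now computes internally in float32 even when the model runs in float16: inputs are up-cast, the SSIM math runs at full precision, and only the final loss is cast back. That cast pattern in isolation, as a numpy sketch (of the dtype handling only, not the SSIM math):

    import numpy as np

    def loss_in_float32(img1, img2, loss_fn):
        # run a numerically sensitive loss at float32, return in the input dtype
        not_float32 = img1.dtype != np.float32
        if not_float32:
            img_dtype = img1.dtype
            img1 = img1.astype(np.float32)
            img2 = img2.astype(np.float32)
        loss = loss_fn(img1, img2)
        return loss.astype(img_dtype) if not_float32 else loss

    a = np.random.rand(4, 4).astype(np.float16)
    b = np.random.rand(4, 4).astype(np.float16)
    out = loss_in_float32(a, b, lambda x, y: np.mean((x - y) ** 2, keepdims=True))
    assert out.dtype == np.float16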
+    def tf_space_to_depth(x, size):
+        if nn.data_format == "NHWC":
+            # match NCHW version in order to switch data_format without problems
+            b,h,w,c = x.shape.as_list()
+            oh, ow = h // size, w // size
+            x = tf.reshape(x, (-1, size, oh, size, ow, c))
+            x = tf.transpose(x, (0, 2, 4, 1, 3, 5))
+            x = tf.reshape(x, (-1, oh, ow, size* size* c ))
+            return x
+        else:
+            return tf.space_to_depth(x, size, data_format=nn.data_format)
+    nn.tf_space_to_depth = tf_space_to_depth

+    def tf_depth_to_space(x, size):
+        if nn.data_format == "NHWC":
+            # match NCHW version in order to switch data_format without problems
+            b,h,w,c = x.shape.as_list()
+            oh, ow = h * size, w * size
+            oc = c // (size * size)
+            x = tf.reshape(x, (-1, h, w, size, size, oc, ) )
+            x = tf.transpose(x, (0, 1, 3, 2, 4, 5))
+            x = tf.reshape(x, (-1, oh, ow, oc, ))
+            return x
+        else:
+            return tf.depth_to_space(x, size, data_format=nn.data_format)
+    nn.tf_depth_to_space = tf_depth_to_space
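Both helpers are pixel-shuffle primitives built from a reshape, a transpose, and another reshape. As a shape sanity check, here is the textbook NHWC ordering (the one tf.space_to_depth itself uses) round-tripping in numpy; note that the hand-rolled NHWC branch above intentionally uses a different channel order, chosen to match the native NCHW op so weights survive a data_format switch:

    import numpy as np

    def space_to_depth(x, size):
        # textbook NHWC block ordering, as in tf.space_to_depth
        n, h, w, c = x.shape
        oh, ow = h // size, w // size
        x = x.reshape(n, oh, size, ow, size, c)
        x = x.transpose(0, 1, 3, 2, 4, 5)
        return x.reshape(n, oh, ow, size * size * c)

    def depth_to_space(x, size):
        n, h, w, c = x.shape
        oc = c // (size * size)
        x = x.reshape(n, h, w, size, size, oc)
        x = x.transpose(0, 1, 3, 2, 4, 5)
        return x.reshape(n, h * size, w * size, oc)

    x = np.random.rand(1, 4, 4, 3).astype(np.float32)
    assert np.allclose(depth_to_space(space_to_depth(x, 2), 2), x)  # exact inverse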
    def tf_rgb_to_lab(srgb):
        srgb_pixels = tf.reshape(srgb, [-1, 3])
        linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)

@@ -275,14 +290,14 @@ def initialize_tensor_ops(nn):
            lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])
        return tf.reshape(lab_pixels, tf.shape(srgb))
    nn.tf_rgb_to_lab = tf_rgb_to_lab

    def tf_suppress_lower_mean(t, eps=0.00001):
        if t.shape.ndims != 1:
            raise ValueError("tf_suppress_lower_mean: t rank must be 1")
        t_mean_eps = tf.reduce_mean(t) - eps
        q = tf.clip_by_value(t, t_mean_eps, tf.reduce_max(t) )
        q = tf.clip_by_value(q-t_mean_eps, 0, eps)
        q = q * (t/eps)
        return q

"""
class GeLU(KL.Layer):

View file

@@ -20,18 +20,18 @@ def scantree(path):
            yield from scantree(entry.path)  # see below for Python 2.x
        else:
            yield entry

def get_image_paths(dir_path, image_extensions=image_extensions, subdirs=False):
    dir_path = Path (dir_path)

    result = []
    if dir_path.exists():
        if subdirs:
            gen = scantree(str(dir_path))
        else:
            gen = scandir(str(dir_path))
        for x in list(gen):
            if any([x.name.lower().endswith(ext) for ext in image_extensions]):
                result.append(x.path)

@@ -51,7 +51,7 @@ def get_image_unique_filestem_paths(dir_path, verbose_print_func=None):
            result_dup.add(f_stem)

    return sorted(result)

def get_file_paths(dir_path):
    dir_path = Path (dir_path)

@@ -59,7 +59,7 @@ def get_file_paths(dir_path):
        return [ Path(x) for x in sorted([ x.path for x in list(scandir(str(dir_path))) if x.is_file() ]) ]
    else:
        return []

def get_all_dir_names (dir_path):
    dir_path = Path (dir_path)

@@ -67,7 +67,7 @@ def get_all_dir_names (dir_path):
        return sorted([ x.name for x in list(scandir(str(dir_path))) if x.is_dir() ])
    else:
        return []

def get_all_dir_names_startswith (dir_path, startswith):
    dir_path = Path (dir_path)
    startswith = startswith.lower()

@@ -98,7 +98,7 @@ def move_all_files (src_dir_path, dst_dir_path):
    for p in paths:
        p = Path(p)
        p.rename ( Path(dst_dir_path) / p.name )

def delete_all_files (dir_path):
    paths = get_file_paths(dir_path)
    for p in paths:

View file

@@ -11,4 +11,4 @@ def random_normal( size=(1,), trunc_val = 2.5 ):
            break
        result[i] = (x / trunc_val)
    return result.reshape ( size )

View file

@@ -18,7 +18,7 @@ class FANExtractor(object):
        if not model_path.exists():
            raise Exception("Unable to load FANExtractor model")

-        nn.initialize()
+        nn.initialize(data_format="NHWC")
        tf = nn.tf

        class ConvBlock(nn.ModelBase):

@@ -29,10 +29,10 @@ class FANExtractor(object):
                self.bn1 = nn.BatchNorm2D(in_planes)
                self.conv1 = nn.Conv2D (in_planes, out_planes/2, kernel_size=3, strides=1, padding='SAME', use_bias=False )

-                self.bn2 = nn.BatchNorm2D(out_planes/2)
+                self.bn2 = nn.BatchNorm2D(out_planes//2)
                self.conv2 = nn.Conv2D (out_planes/2, out_planes/4, kernel_size=3, strides=1, padding='SAME', use_bias=False )

-                self.bn3 = nn.BatchNorm2D(out_planes/4)
+                self.bn3 = nn.BatchNorm2D(out_planes//4)
                self.conv3 = nn.Conv2D (out_planes/4, out_planes/4, kernel_size=3, strides=1, padding='SAME', use_bias=False )

                if self.in_planes != self.out_planes:

@@ -55,6 +55,7 @@ class FANExtractor(object):
                x = self.bn3(x)
                x = tf.nn.relu(x)
                x = out3 = self.conv3(x)

                x = tf.concat ([out1, out2, out3], axis=-1)

                if self.in_planes != self.out_planes:

@@ -148,7 +149,9 @@ class FANExtractor(object):
                    if i < 4 - 1:
                        ll = self.bl[i](ll)
                        previous = previous + ll + self.al[i](tmp_out)
-                return outputs[-1]
+                x = outputs[-1]
+                x = tf.transpose(x, (0,3,1,2) )
+                return x

        e = None
        if place_model_on_cpu:

@@ -159,7 +162,7 @@ class FANExtractor(object):
            self.model.load_weights(str(model_path))
        if e is not None: e.__exit__(None,None,None)

-        self.model.build_for_run ([ ( tf.float32, (256,256,3) ) ])
+        self.model.build_for_run ([ ( tf.float32, (None,256,256,3) ) ])

    def extract (self, input_image, rects, second_pass_extractor=None, is_bgr=True, multi_sample=False):
        if len(rects) == 0:

@@ -197,7 +200,7 @@ class FANExtractor(object):
            predicted = []
            for i in range( len(images) ):
-                predicted += [ self.model.run ( [ images[i][None,...] ] ).transpose (0,3,1,2)[0] ]
+                predicted += [ self.model.run ( [ images[i][None,...] ] )[0] ]

            predicted = np.stack(predicted)

View file

@@ -11,7 +11,7 @@ class FaceEnhancer(object):
    x4 face enhancer
    """
    def __init__(self, place_model_on_cpu=False):
-        nn.initialize()
+        nn.initialize(data_format="NHWC")
        tf = nn.tf

        class FaceEnhancer (nn.ModelBase):

@@ -167,9 +167,9 @@ class FaceEnhancer(object):
            self.model.load_weights (model_path)
        if e is not None: e.__exit__(None,None,None)

-        self.model.build_for_run ([ (tf.float32, (192,192,3) ),
-                                    (tf.float32, (1,) ),
-                                    (tf.float32, (1,) ),
+        self.model.build_for_run ([ (tf.float32, nn.get4Dshape (192,192,3) ),
+                                    (tf.float32, (None,1,) ),
+                                    (tf.float32, (None,1,) ),
                                  ])

@@ -185,14 +185,14 @@ class FaceEnhancer(object):
        ih,iw,ic = inp_img.shape
        h,w,c = ih,iw,ic
        th,tw = h*up_res, w*up_res

        t_padding = 0
        b_padding = 0
        l_padding = 0
        r_padding = 0

        if h < patch_size:
            t_padding = (patch_size-h)//2
            b_padding = (patch_size-h) - t_padding

@@ -200,24 +200,24 @@ class FaceEnhancer(object):
        if w < patch_size:
            l_padding = (patch_size-w)//2
            r_padding = (patch_size-w) - l_padding

        if t_padding != 0:
            inp_img = np.concatenate ([ np.zeros ( (t_padding,w,c), dtype=np.float32 ), inp_img ], axis=0 )
            h,w,c = inp_img.shape

        if b_padding != 0:
            inp_img = np.concatenate ([ inp_img, np.zeros ( (b_padding,w,c), dtype=np.float32 ) ], axis=0 )
            h,w,c = inp_img.shape

        if l_padding != 0:
            inp_img = np.concatenate ([ np.zeros ( (h,l_padding,c), dtype=np.float32 ), inp_img ], axis=1 )
            h,w,c = inp_img.shape

        if r_padding != 0:
            inp_img = np.concatenate ([ inp_img, np.zeros ( (h,r_padding,c), dtype=np.float32 ) ], axis=1 )
            h,w,c = inp_img.shape

        i_max = w-patch_size+1
        j_max = h-patch_size+1

@@ -248,7 +248,7 @@ class FaceEnhancer(object):
        if t_padding+b_padding+l_padding+r_padding != 0:
            final_img = final_img [t_padding*up_res:(h-b_padding)*up_res, l_padding*up_res:(w-r_padding)*up_res,:]

        if preserve_size:
            final_img = cv2.resize (final_img, (iw,ih), cv2.INTER_LANCZOS4)

@@ -271,15 +271,15 @@ class FaceEnhancer(object):
        patch_size_half = patch_size // 2

        h,w,c = inp_img.shape
        th,tw = h*up_res, w*up_res

        preupscale_rate = 1.0

        if h < patch_size or w < patch_size:
            preupscale_rate = 1.0 / ( max(h,w) / patch_size )

        if preupscale_rate != 1.0:
            inp_img = cv2.resize (inp_img, ( int(w*preupscale_rate), int(h*preupscale_rate) ), cv2.INTER_LANCZOS4)
            h,w,c = inp_img.shape

@@ -314,7 +314,7 @@ class FaceEnhancer(object):
        if preserve_size:
            final_img = cv2.resize (final_img, (w,h), cv2.INTER_LANCZOS4)
        else:
            if preupscale_rate != 1.0:
                final_img = cv2.resize (final_img, (tw,th), cv2.INTER_LANCZOS4)

        if not is_tanh:

View file

@@ -8,7 +8,7 @@ class FaceType(IntEnum):
    FULL_NO_ALIGN = 3
    HEAD = 4
    HEAD_NO_ALIGN = 5

    MARK_ONLY = 10, #no align at all, just embedded faceinfo

    @staticmethod

View file

@@ -263,29 +263,29 @@ def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0, full_
        tb_diag_vec /= npla.norm(tb_diag_vec)
        bt_diag_vec = (l_p[1]-l_p[3]).astype(np.float32)
        bt_diag_vec /= npla.norm(bt_diag_vec)

        mod = (1.0 / scale)* ( npla.norm(l_p[0]-l_p[2])*(padding*np.sqrt(2.0) + 0.5) )

        if not remove_align:
            l_t = np.array( [ np.round( l_c - tb_diag_vec*mod ),
                              np.round( l_c + bt_diag_vec*mod ),
                              np.round( l_c + tb_diag_vec*mod ) ] )
        else:
            l_t = np.array( [ np.round( l_c - tb_diag_vec*mod ),
                              np.round( l_c + bt_diag_vec*mod ),
                              np.round( l_c + tb_diag_vec*mod ),
                              np.round( l_c - bt_diag_vec*mod ),
                            ] )

            area = mathlib.polygon_area(l_t[:,0], l_t[:,1] )
            side = np.float32(math.sqrt(area) / 2)
            l_t = np.array( [ np.round( l_c + [-side,-side] ),
                              np.round( l_c + [ side,-side] ),
                              np.round( l_c + [ side, side] ) ] )

        pts2 = np.float32(( (0,0),(output_size,0),(output_size,output_size) ))
        mat = cv2.getAffineTransform(l_t,pts2)

    #if remove_align:
    #    bbox = transform_points ( [ (0,0), (0,output_size), (output_size, output_size), (output_size,0) ], mat, True)

@@ -301,24 +301,24 @@ def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0, full_
    return mat

    #if full_face_align_top and (face_type == FaceType.FULL or face_type == FaceType.FULL_NO_ALIGN):
    #    #lmrks2 = expand_eyebrows(image_landmarks)
    #    #lmrks2_ = transform_points( [ lmrks2[19], lmrks2[24] ], mat, False )
    #    #y_diff = np.float32( (0,np.min(lmrks2_[:,1])) )
    #    #y_diff = transform_points( [ np.float32( (0,0) ), y_diff], mat, True)
    #    #y_diff = y_diff[1]-y_diff[0]
    #
    #    x_diff = np.float32((0,0))
    #
    #    lmrks2_ = transform_points( [ image_landmarks[0], image_landmarks[16] ], mat, False )
    #    if lmrks2_[0,0] < 0:
    #        x_diff = lmrks2_[0,0]
    #        x_diff = transform_points( [ np.float32( (0,0) ), np.float32((x_diff,0)) ], mat, True)
    #        x_diff = x_diff[1]-x_diff[0]
    #    elif lmrks2_[1,0] >= output_size:
    #        x_diff = lmrks2_[1,0]-(output_size-1)
    #        x_diff = transform_points( [ np.float32( (0,0) ), np.float32((x_diff,0)) ], mat, True)
    #        x_diff = x_diff[1]-x_diff[0]
    #
    #    mat = cv2.getAffineTransform( l_t+y_diff+x_diff ,pts2)

def expand_eyebrows(lmrks, eyebrows_expand_mod=1.0):
    if len(lmrks) != 68:

@@ -687,5 +687,5 @@ def estimate_pitch_yaw_roll(aligned_256px_landmarks):
    pitch = np.clip ( pitch, -math.pi, math.pi )
    yaw   = np.clip ( yaw , -math.pi, math.pi )
    roll  = np.clip ( roll, -math.pi, math.pi )
    return -pitch, yaw, roll

View file

@@ -8,9 +8,9 @@ from core.leras import nn

class S3FDExtractor(object):
    def __init__(self, place_model_on_cpu=False):
-        nn.initialize()
+        nn.initialize(data_format="NHWC")
        tf = nn.tf

        model_path = Path(__file__).parent / "S3FD.npy"
        if not model_path.exists():
            raise Exception("Unable to load S3FD.npy")

@@ -19,143 +19,143 @@ class S3FDExtractor(object):
            def __init__(self, n_channels, **kwargs):
                self.n_channels = n_channels
                super().__init__(**kwargs)

            def build_weights(self):
                self.weight = tf.get_variable ("weight", (1, 1, 1, self.n_channels), dtype=nn.tf_floatx, initializer=tf.initializers.ones )

            def get_weights(self):
                return [self.weight]

            def __call__(self, inputs):
                x = inputs
                x = x / (tf.sqrt( tf.reduce_sum( tf.pow(x, 2), axis=-1, keepdims=True ) ) + 1e-10) * self.weight
                return x

        class S3FD(nn.ModelBase):
            def __init__(self):
                super().__init__(name='S3FD')

            def on_build(self):
                self.minus = tf.constant([104,117,123], dtype=nn.tf_floatx )
                self.conv1_1 = nn.Conv2D(3, 64, kernel_size=3, strides=1, padding='SAME')
                self.conv1_2 = nn.Conv2D(64, 64, kernel_size=3, strides=1, padding='SAME')

                self.conv2_1 = nn.Conv2D(64, 128, kernel_size=3, strides=1, padding='SAME')
                self.conv2_2 = nn.Conv2D(128, 128, kernel_size=3, strides=1, padding='SAME')

                self.conv3_1 = nn.Conv2D(128, 256, kernel_size=3, strides=1, padding='SAME')
                self.conv3_2 = nn.Conv2D(256, 256, kernel_size=3, strides=1, padding='SAME')
                self.conv3_3 = nn.Conv2D(256, 256, kernel_size=3, strides=1, padding='SAME')

                self.conv4_1 = nn.Conv2D(256, 512, kernel_size=3, strides=1, padding='SAME')
                self.conv4_2 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME')
                self.conv4_3 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME')

                self.conv5_1 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME')
                self.conv5_2 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME')
                self.conv5_3 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME')

                self.fc6 = nn.Conv2D(512, 1024, kernel_size=3, strides=1, padding=3)
                self.fc7 = nn.Conv2D(1024, 1024, kernel_size=1, strides=1, padding='SAME')

                self.conv6_1 = nn.Conv2D(1024, 256, kernel_size=1, strides=1, padding='SAME')
                self.conv6_2 = nn.Conv2D(256, 512, kernel_size=3, strides=2, padding='SAME')

                self.conv7_1 = nn.Conv2D(512, 128, kernel_size=1, strides=1, padding='SAME')
                self.conv7_2 = nn.Conv2D(128, 256, kernel_size=3, strides=2, padding='SAME')

                self.conv3_3_norm = L2Norm(256)
                self.conv4_3_norm = L2Norm(512)
                self.conv5_3_norm = L2Norm(512)

                self.conv3_3_norm_mbox_conf = nn.Conv2D(256, 4, kernel_size=3, strides=1, padding='SAME')
                self.conv3_3_norm_mbox_loc = nn.Conv2D(256, 4, kernel_size=3, strides=1, padding='SAME')

                self.conv4_3_norm_mbox_conf = nn.Conv2D(512, 2, kernel_size=3, strides=1, padding='SAME')
                self.conv4_3_norm_mbox_loc = nn.Conv2D(512, 4, kernel_size=3, strides=1, padding='SAME')

                self.conv5_3_norm_mbox_conf = nn.Conv2D(512, 2, kernel_size=3, strides=1, padding='SAME')
                self.conv5_3_norm_mbox_loc = nn.Conv2D(512, 4, kernel_size=3, strides=1, padding='SAME')

                self.fc7_mbox_conf = nn.Conv2D(1024, 2, kernel_size=3, strides=1, padding='SAME')
                self.fc7_mbox_loc = nn.Conv2D(1024, 4, kernel_size=3, strides=1, padding='SAME')

                self.conv6_2_mbox_conf = nn.Conv2D(512, 2, kernel_size=3, strides=1, padding='SAME')
                self.conv6_2_mbox_loc = nn.Conv2D(512, 4, kernel_size=3, strides=1, padding='SAME')

                self.conv7_2_mbox_conf = nn.Conv2D(256, 2, kernel_size=3, strides=1, padding='SAME')
                self.conv7_2_mbox_loc = nn.Conv2D(256, 4, kernel_size=3, strides=1, padding='SAME')

            def forward(self, inp):
                x, = inp
                x = x - self.minus
                x = tf.nn.relu(self.conv1_1(x))
                x = tf.nn.relu(self.conv1_2(x))
                x = tf.nn.max_pool(x, [1,2,2,1], [1,2,2,1], "VALID")

                x = tf.nn.relu(self.conv2_1(x))
                x = tf.nn.relu(self.conv2_2(x))
                x = tf.nn.max_pool(x, [1,2,2,1], [1,2,2,1], "VALID")

                x = tf.nn.relu(self.conv3_1(x))
                x = tf.nn.relu(self.conv3_2(x))
                x = tf.nn.relu(self.conv3_3(x))
                f3_3 = x
                x = tf.nn.max_pool(x, [1,2,2,1], [1,2,2,1], "VALID")

                x = tf.nn.relu(self.conv4_1(x))
                x = tf.nn.relu(self.conv4_2(x))
                x = tf.nn.relu(self.conv4_3(x))
                f4_3 = x
                x = tf.nn.max_pool(x, [1,2,2,1], [1,2,2,1], "VALID")

                x = tf.nn.relu(self.conv5_1(x))
                x = tf.nn.relu(self.conv5_2(x))
                x = tf.nn.relu(self.conv5_3(x))
                f5_3 = x
                x = tf.nn.max_pool(x, [1,2,2,1], [1,2,2,1], "VALID")

                x = tf.nn.relu(self.fc6(x))
                x = tf.nn.relu(self.fc7(x))
                ffc7 = x

                x = tf.nn.relu(self.conv6_1(x))
                x = tf.nn.relu(self.conv6_2(x))
                f6_2 = x

                x = tf.nn.relu(self.conv7_1(x))
                x = tf.nn.relu(self.conv7_2(x))
                f7_2 = x

                f3_3 = self.conv3_3_norm(f3_3)
                f4_3 = self.conv4_3_norm(f4_3)
                f5_3 = self.conv5_3_norm(f5_3)

                cls1 = self.conv3_3_norm_mbox_conf(f3_3)
                reg1 = self.conv3_3_norm_mbox_loc(f3_3)

                cls2 = tf.nn.softmax(self.conv4_3_norm_mbox_conf(f4_3))
                reg2 = self.conv4_3_norm_mbox_loc(f4_3)

                cls3 = tf.nn.softmax(self.conv5_3_norm_mbox_conf(f5_3))
                reg3 = self.conv5_3_norm_mbox_loc(f5_3)

                cls4 = tf.nn.softmax(self.fc7_mbox_conf(ffc7))
                reg4 = self.fc7_mbox_loc(ffc7)

                cls5 = tf.nn.softmax(self.conv6_2_mbox_conf(f6_2))
                reg5 = self.conv6_2_mbox_loc(f6_2)

                cls6 = tf.nn.softmax(self.conv7_2_mbox_conf(f7_2))
                reg6 = self.conv7_2_mbox_loc(f7_2)

                # max-out background label
                bmax = tf.maximum(tf.maximum(cls1[:,:,:,0:1], cls1[:,:,:,1:2]), cls1[:,:,:,2:3])
                cls1 = tf.concat ([bmax, cls1[:,:,:,3:4] ], axis=-1)
                cls1 = tf.nn.softmax(cls1)

                return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]

        e = None

@@ -165,10 +165,10 @@ class S3FDExtractor(object):
        if e is not None: e.__enter__()
        self.model = S3FD()
        self.model.load_weights (model_path)
        if e is not None: e.__exit__(None,None,None)

-        self.model.build_for_run ([ ( tf.float32, (None,None,3) ) ])
+        self.model.build_for_run ([ ( tf.float32, nn.get4Dshape (None,None,3) ) ])

    def __enter__(self):
        return self

@@ -205,7 +205,7 @@ class S3FDExtractor(object):
        detected_faces = [ [(l,t,r,b), (r-l)*(b-t) ]  for (l,t,r,b) in detected_faces ]
        detected_faces = sorted(detected_faces, key=operator.itemgetter(1), reverse=True )
        detected_faces = [ x[0] for x in detected_faces]

        if is_remove_intersects:
            for i in range( len(detected_faces)-1, 0, -1):
                l1,t1,r1,b1 = detected_faces[i]

@@ -214,8 +214,8 @@ class S3FDExtractor(object):
                dx = min(r0, r1) - max(l0, l1)
                dy = min(b0, b1) - max(t0, t1)
                if (dx>=0) and (dy>=0):
                    detected_faces.pop(i)

        return detected_faces

    def refine(self, olist):

View file

@@ -20,117 +20,117 @@ TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentat
class TernausNet(object):
    VERSION = 1
    def __init__ (self, name, resolution, face_type_str, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False):
-       nn.initialize()
+       nn.initialize(data_format="NHWC")
        tf = nn.tf

        class Ternaus(nn.ModelBase):
            def on_build(self, in_ch, ch):
                self.features_0 = nn.Conv2D (in_ch, ch, kernel_size=3, padding='SAME')
                self.blurpool_0 = nn.BlurPool (filt_size=3)

                self.features_3 = nn.Conv2D (ch, ch*2, kernel_size=3, padding='SAME')
                self.blurpool_3 = nn.BlurPool (filt_size=3)

                self.features_6 = nn.Conv2D (ch*2, ch*4, kernel_size=3, padding='SAME')
                self.features_8 = nn.Conv2D (ch*4, ch*4, kernel_size=3, padding='SAME')
                self.blurpool_8 = nn.BlurPool (filt_size=3)

                self.features_11 = nn.Conv2D (ch*4, ch*8, kernel_size=3, padding='SAME')
                self.features_13 = nn.Conv2D (ch*8, ch*8, kernel_size=3, padding='SAME')
                self.blurpool_13 = nn.BlurPool (filt_size=3)

                self.features_16 = nn.Conv2D (ch*8, ch*8, kernel_size=3, padding='SAME')
                self.features_18 = nn.Conv2D (ch*8, ch*8, kernel_size=3, padding='SAME')
                self.blurpool_18 = nn.BlurPool (filt_size=3)

                self.conv_center = nn.Conv2D (ch*8, ch*8, kernel_size=3, padding='SAME')

                self.conv1_up = nn.Conv2DTranspose (ch*8, ch*4, kernel_size=3, padding='SAME')
                self.conv1 = nn.Conv2D (ch*12, ch*8, kernel_size=3, padding='SAME')

                self.conv2_up = nn.Conv2DTranspose (ch*8, ch*4, kernel_size=3, padding='SAME')
                self.conv2 = nn.Conv2D (ch*12, ch*8, kernel_size=3, padding='SAME')

                self.conv3_up = nn.Conv2DTranspose (ch*8, ch*2, kernel_size=3, padding='SAME')
                self.conv3 = nn.Conv2D (ch*6, ch*4, kernel_size=3, padding='SAME')

                self.conv4_up = nn.Conv2DTranspose (ch*4, ch, kernel_size=3, padding='SAME')
                self.conv4 = nn.Conv2D (ch*3, ch*2, kernel_size=3, padding='SAME')

                self.conv5_up = nn.Conv2DTranspose (ch*2, ch//2, kernel_size=3, padding='SAME')
                self.conv5 = nn.Conv2D (ch//2+ch, ch, kernel_size=3, padding='SAME')

                self.out_conv = nn.Conv2D (ch, 1, kernel_size=3, padding='SAME')

            def forward(self, inp):
                x, = inp
                x = x0 = tf.nn.relu(self.features_0(x))
                x = self.blurpool_0(x)

                x = x1 = tf.nn.relu(self.features_3(x))
                x = self.blurpool_3(x)

                x = tf.nn.relu(self.features_6(x))
                x = x2 = tf.nn.relu(self.features_8(x))
                x = self.blurpool_8(x)

                x = tf.nn.relu(self.features_11(x))
                x = x3 = tf.nn.relu(self.features_13(x))
                x = self.blurpool_13(x)

                x = tf.nn.relu(self.features_16(x))
                x = x4 = tf.nn.relu(self.features_18(x))
                x = self.blurpool_18(x)

                x = self.conv_center(x)

                x = tf.nn.relu(self.conv1_up(x))
                x = tf.concat( [x,x4], -1)
                x = tf.nn.relu(self.conv1(x))

                x = tf.nn.relu(self.conv2_up(x))
                x = tf.concat( [x,x3], -1)
                x = tf.nn.relu(self.conv2(x))

                x = tf.nn.relu(self.conv3_up(x))
                x = tf.concat( [x,x2], -1)
                x = tf.nn.relu(self.conv3(x))

                x = tf.nn.relu(self.conv4_up(x))
                x = tf.concat( [x,x1], -1)
                x = tf.nn.relu(self.conv4(x))

                x = tf.nn.relu(self.conv5_up(x))
                x = tf.concat( [x,x0], -1)
                x = tf.nn.relu(self.conv5(x))

                x = tf.nn.sigmoid(self.out_conv(x))
                return x
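One detail worth noting in the decoder above: each nn.Conv2D after a concat takes (upsampled channels + skip channels) as input, which is where the ch*12, ch*6, ch*3 and ch//2+ch input widths come from. A quick check of that bookkeeping (pure Python, values taken from the on_build definitions above):

    ch = 64
    decoder = [
        # (upsampled_ch, skip_ch, conv_in_ch declared in on_build)
        (ch*4,  ch*8, ch*12),    # conv1_up + x4 -> conv1
        (ch*4,  ch*8, ch*12),    # conv2_up + x3 -> conv2
        (ch*2,  ch*4, ch*6),     # conv3_up + x2 -> conv3
        (ch,    ch*2, ch*3),     # conv4_up + x1 -> conv4
        (ch//2, ch,   ch//2+ch), # conv5_up + x0 -> conv5
    ]
    assert all(up + skip == declared for up, skip, declared in decoder)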
        if weights_file_root is not None:
            weights_file_root = Path(weights_file_root)
        else:
            weights_file_root = Path(__file__).parent
        self.weights_path = weights_file_root / ('%s_%d_%s.npy' % (name, resolution, face_type_str) )

        e = tf.device('/CPU:0') if place_model_on_cpu else None
        if e is not None: e.__enter__()
        self.net = Ternaus(3, 64, name='Ternaus')
        if load_weights:
            self.net.load_weights (self.weights_path)
        else:
            self.net.init_weights()
        if e is not None: e.__exit__(None,None,None)

-       self.net.build_for_run ( [(tf.float32, (resolution,resolution,3))] )
+       self.net.build_for_run ( [(tf.float32, nn.get4Dshape (resolution,resolution,3) )] )

        if training:
            raise Exception("training not supported yet")
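The recurring change in this commit replaces bare (h, w, c) tuples with nn.get4Dshape(...), which prepends the batch axis and orders the remaining axes to match the data_format chosen at initialization (NHWC vs NCHW). The real helper in core.leras is not shown in this diff; the sketch below is a plausible reconstruction of its contract, with data_format as an explicit parameter for illustration:

    def get4Dshape(h, w, c, data_format="NHWC"):
        # Prepend a None batch axis; axis order depends on the data format.
        if data_format == "NHWC":
            return (None, h, w, c)
        return (None, c, h, w)  # NCHW

    get4Dshape(256, 256, 3)                       # (None, 256, 256, 3)
    get4Dshape(256, 256, 3, data_format="NCHW")   # (None, 3, 256, 256)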
""" """
if training: if training:
try: try:
@ -149,9 +149,9 @@ class TernausNet(object):
if 'CA.' in layer.name: if 'CA.' in layer.name:
conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
CAInitializerMP ( conv_weights_list ) CAInitializerMP ( conv_weights_list )
""" """
""" """
if training: if training:
inp_t = Input ( (resolution, resolution, 3) ) inp_t = Input ( (resolution, resolution, 3) )
@ -195,124 +195,3 @@ class TernausNet(object):
result = result[0] result = result[0]
return result return result
"""
self.weights_path = weights_file_root / ('%s_%d_%s.h5' % (name, resolution, face_type_str) )
self.net.build()
self.net.features_0.set_weights ( self.model.get_layer('features.0').get_weights() )
self.net.features_3.set_weights ( self.model.get_layer('features.3').get_weights() )
self.net.features_6.set_weights ( self.model.get_layer('features.6').get_weights() )
self.net.features_8.set_weights ( self.model.get_layer('features.8').get_weights() )
self.net.features_11.set_weights ( self.model.get_layer('features.11').get_weights() )
self.net.features_13.set_weights ( self.model.get_layer('features.13').get_weights() )
self.net.features_16.set_weights ( self.model.get_layer('features.16').get_weights() )
self.net.features_18.set_weights ( self.model.get_layer('features.18').get_weights() )
self.net.conv_center.set_weights ( self.model.get_layer('CA.1').get_weights() )
self.net.conv1_up.set_weights ( self.model.get_layer('CA.2').get_weights() )
self.net.conv1.set_weights ( self.model.get_layer('CA.3').get_weights() )
self.net.conv2_up.set_weights ( self.model.get_layer('CA.4').get_weights() )
self.net.conv2.set_weights ( self.model.get_layer('CA.5').get_weights() )
self.net.conv3_up.set_weights ( self.model.get_layer('CA.6').get_weights() )
self.net.conv3.set_weights ( self.model.get_layer('CA.7').get_weights() )
self.net.conv4_up.set_weights ( self.model.get_layer('CA.8').get_weights() )
self.net.conv4.set_weights ( self.model.get_layer('CA.9').get_weights() )
self.net.conv5_up.set_weights ( self.model.get_layer('CA.10').get_weights() )
self.net.conv5.set_weights ( self.model.get_layer('CA.11').get_weights() )
self.net.out_conv.set_weights ( self.model.get_layer('CA.12').get_weights() )
self.net.build_for_run ( [ (tf.float32, (resolution,resolution,3)) ])
self.net.save_weights (self.weights_path2)
def extract (self, input_image):
input_shape_len = len(input_image.shape)
if input_shape_len == 3:
input_image = input_image[np.newaxis,...]
result = np.clip ( self.model.predict( [input_image] ), 0, 1.0 )
result[result < 0.1] = 0 #get rid of noise
if input_shape_len == 3:
result = result[0]
return result
@staticmethod
def BuildModel ( resolution, ngf=64):
exec( nn.initialize(), locals(), globals() )
inp = Input ( (resolution,resolution,3) )
x = inp
x = TernausNet.Flow(ngf=ngf)(x)
model = Model(inp,x)
return model
@staticmethod
def Flow(ngf=64):
exec( nn.initialize(), locals(), globals() )
def func(input):
x = input
x0 = x = Conv2D(ngf, kernel_size=3, strides=1, padding='same', activation='relu', name='features.0')(x)
x = BlurPool(filt_size=3)(x)
x1 = x = Conv2D(ngf*2, kernel_size=3, strides=1, padding='same', activation='relu', name='features.3')(x)
x = BlurPool(filt_size=3)(x)
x = Conv2D(ngf*4, kernel_size=3, strides=1, padding='same', activation='relu', name='features.6')(x)
x2 = x = Conv2D(ngf*4, kernel_size=3, strides=1, padding='same', activation='relu', name='features.8')(x)
x = BlurPool(filt_size=3)(x)
x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.11')(x)
x3 = x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.13')(x)
x = BlurPool(filt_size=3)(x)
x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.16')(x)
x4 = x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.18')(x)
x = BlurPool(filt_size=3)(x)
x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', name='CA.1')(x)
x = Conv2DTranspose (ngf*4, 3, strides=2, padding='same', activation='relu', name='CA.2') (x)
x = Concatenate(axis=3)([ x, x4])
x = Conv2D (ngf*8, 3, strides=1, padding='same', activation='relu', name='CA.3') (x)
x = Conv2DTranspose (ngf*4, 3, strides=2, padding='same', activation='relu', name='CA.4') (x)
x = Concatenate(axis=3)([ x, x3])
x = Conv2D (ngf*8, 3, strides=1, padding='same', activation='relu', name='CA.5') (x)
x = Conv2DTranspose (ngf*2, 3, strides=2, padding='same', activation='relu', name='CA.6') (x)
x = Concatenate(axis=3)([ x, x2])
x = Conv2D (ngf*4, 3, strides=1, padding='same', activation='relu', name='CA.7') (x)
x = Conv2DTranspose (ngf, 3, strides=2, padding='same', activation='relu', name='CA.8') (x)
x = Concatenate(axis=3)([ x, x1])
x = Conv2D (ngf*2, 3, strides=1, padding='same', activation='relu', name='CA.9') (x)
x = Conv2DTranspose (ngf // 2, 3, strides=2, padding='same', activation='relu', name='CA.10') (x)
x = Concatenate(axis=3)([ x, x0])
x = Conv2D (ngf, 3, strides=1, padding='same', activation='relu', name='CA.11') (x)
return Conv2D(1, 3, strides=1, padding='same', activation='sigmoid', name='CA.12')(x)
return func
"""

main.py
View file

@@ -1,16 +1,16 @@
if __name__ == "__main__":
    # Fix for linux
    import multiprocessing
    multiprocessing.set_start_method("spawn")

    from core.leras import nn
    nn.initialize_main_env()

    import os
    import sys
    import time
    import argparse

    from core import pathex
    from core import osex
    from pathlib import Path
@@ -22,7 +22,7 @@ if __name__ == "__main__":
    class fixPathAction(argparse.Action):
        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, os.path.abspath(os.path.expanduser(values)))
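fixPathAction normalizes every user-supplied path at parse time (tilde expansion plus absolutization), so downstream code can treat all --*-dir arguments as absolute paths. The same pattern in isolation, with illustrative argument names:

    import argparse, os

    class FixPathAction(argparse.Action):
        def __call__(self, parser, namespace, values, option_string=None):
            # '~/data' -> '/home/user/data'; 'rel/dir' -> '<cwd>/rel/dir'
            setattr(namespace, self.dest, os.path.abspath(os.path.expanduser(values)))

    p = argparse.ArgumentParser()
    p.add_argument('--input-dir', action=FixPathAction, dest='input_dir')
    print(p.parse_args(['--input-dir', '~/workspace']).input_dir)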
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers()
@@ -32,7 +32,7 @@ if __name__ == "__main__":
        Extractor.main( detector                = arguments.detector,
                        input_path              = Path(arguments.input_dir),
                        output_path             = Path(arguments.output_dir),
                        output_debug            = arguments.output_debug,
                        manual_fix              = arguments.manual_fix,
                        manual_output_debug_fix = arguments.manual_output_debug_fix,
                        manual_window_size      = arguments.manual_window_size,
@@ -53,7 +53,7 @@ if __name__ == "__main__":
    p.add_argument('--manual-window-size', type=int, dest="manual_window_size", default=1368, help="Manual fix window size. Default: 1368.")
    p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Extract on CPU..")
    p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")

    p.set_defaults (func=process_extract)

    def process_dev_extract_vggface2_dataset(arguments):
@@ -104,7 +104,7 @@ if __name__ == "__main__":
    p = subparsers.add_parser( "dev_test", help="")
    p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
    p.set_defaults (func=process_dev_test)

    def process_sort(arguments):
        osex.set_process_lowest_prio()
        from mainscripts import Sorter
@@ -133,14 +133,14 @@ if __name__ == "__main__":
        if arguments.remove_ie_polys:
            Util.remove_ie_polys_folder (input_path=arguments.input_dir)

        if arguments.save_faceset_metadata:
            Util.save_faceset_metadata_folder (input_path=arguments.input_dir)

        if arguments.restore_faceset_metadata:
            Util.restore_faceset_metadata_folder (input_path=arguments.input_dir)

        if arguments.pack_faceset:
            io.log_info ("Performing faceset packing...\r\n")
            from samplelib import PackedFaceset
            PackedFaceset.pack( Path(arguments.input_dir) )
@@ -149,7 +149,7 @@ if __name__ == "__main__":
            io.log_info ("Performing faceset unpacking...\r\n")
            from samplelib import PackedFaceset
            PackedFaceset.unpack( Path(arguments.input_dir) )

    p = subparsers.add_parser( "util", help="Utilities.")
    p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
    p.add_argument('--convert-png-to-jpg', action="store_true", dest="convert_png_to_jpg", default=False, help="Convert DeepFaceLAB PNG files to JPEG.")
@@ -166,7 +166,7 @@ if __name__ == "__main__":
    def process_train(arguments):
        osex.set_process_lowest_prio()

        kwargs = {'model_class_name'  : arguments.model_name,
                  'saved_models_path' : Path(arguments.model_dir),
@@ -179,7 +179,7 @@ if __name__ == "__main__":
                  'force_gpu_idxs'    : arguments.force_gpu_idxs,
                  'cpu_only'          : arguments.cpu_only,
                  'execute_programs'  : [ [int(x[0]), x[1] ] for x in arguments.execute_program ],
                  'debug'             : arguments.debug,
                  }
        from mainscripts import Trainer
        Trainer.main(**kwargs)
@@ -188,12 +188,12 @@ if __name__ == "__main__":
    p.add_argument('--training-data-src-dir', required=True, action=fixPathAction, dest="training_data_src_dir", help="Dir of extracted SRC faceset.")
    p.add_argument('--training-data-dst-dir', required=True, action=fixPathAction, dest="training_data_dst_dir", help="Dir of extracted DST faceset.")
    p.add_argument('--pretraining-data-dir', action=fixPathAction, dest="pretraining_data_dir", default=None, help="Optional dir of extracted faceset that will be used in pretraining mode.")
    p.add_argument('--pretrained-model-dir', action=fixPathAction, dest="pretrained_model_dir", default=None, help="Optional dir of pretrain model files. (Currently only for Quick96).")
    p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Saved models dir.")
    p.add_argument('--model', required=True, dest="model_name", choices=pathex.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Model class name.")
    p.add_argument('--debug', action="store_true", dest="debug", default=False, help="Debug samples.")
    p.add_argument('--no-preview', action="store_true", dest="no_preview", default=False, help="Disable preview window.")
    p.add_argument('--force-model-name', dest="force_model_name", default=None, help="Forcing to choose model name from model/ folder.")
    p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
    p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
    p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
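--execute-program collects repeatable (time, code) pairs, e.g. --execute-program -3600 "print('hourly hook')" on the command line, and process_train coerces the first token to an int via [int(x[0]), x[1]]. Judging from the trainerThread logic later in this diff, a positive time runs the snippet once after that many seconds, while a negative time re-runs it every |time| seconds. A sketch of the parse step (values are illustrative):

    execute_program = [['-3600', "print('hourly hook')"], ['120', "io.log_info('two minutes in')"]]
    execute_programs = [ [int(x[0]), x[1]] for x in execute_program ]
    # -> [[-3600, "print('hourly hook')"], [120, "io.log_info('two minutes in')"]]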
@@ -221,7 +221,7 @@ if __name__ == "__main__":
    p.add_argument('--aligned-dir', action=fixPathAction, dest="aligned_dir", default=None, help="Aligned directory. This is where the extracted dst faces are stored.")
    p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Model dir.")
    p.add_argument('--model', required=True, dest="model_name", choices=pathex.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Model class name.")
    p.add_argument('--force-model-name', dest="force_model_name", default=None, help="Forcing to choose model name from model/ folder.")
    p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Merge on CPU.")
    p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
    p.set_defaults(func=process_merge)
@@ -304,18 +304,18 @@ if __name__ == "__main__":
    def process_faceset_enhancer(arguments):
        osex.set_process_lowest_prio()
        from mainscripts import FacesetEnhancer
        FacesetEnhancer.process_folder ( Path(arguments.input_dir),
                                         cpu_only=arguments.cpu_only,
                                         force_gpu_idxs=arguments.force_gpu_idxs
                                       )

    p = facesettool_parser.add_parser ("enhance", help="Enhance details in DFL faceset.")
    p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.")
    p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Process on CPU.")
    p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
    p.set_defaults(func=process_faceset_enhancer)

    """
    def process_relight_faceset(arguments):
        osex.set_process_lowest_prio()
@@ -326,7 +326,7 @@ if __name__ == "__main__":
        osex.set_process_lowest_prio()
        from mainscripts import FacesetRelighter
        FacesetRelighter.delete_relighted (arguments.input_dir)

    p = facesettool_parser.add_parser ("relight", help="Synthesize new faces from existing ones by relighting them. With the relighted faces, the neural network will better reproduce face shadows.")
    p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.")
    p.add_argument('--lighten', action="store_true", dest="lighten", default=None, help="Lighten the faces.")

View file

@@ -47,7 +47,7 @@ class ExtractSubprocessor(Subprocessor):
            self.max_faces_from_image = client_dict['max_faces_from_image']
            self.device_idx = client_dict['device_idx']
            self.cpu_only = client_dict['device_type'] == 'CPU'
            self.final_output_path = client_dict['final_output_path']
            self.output_debug_path = client_dict['output_debug_path']

            #transfer and set stdin so code.interact works in the debug subprocess
@@ -64,9 +64,9 @@ class ExtractSubprocessor(Subprocessor):
            if self.type == 'all' or 'rects' in self.type or 'landmarks' in self.type:
                nn.initialize (device_config)

            self.log_info (f"Running on {client_dict['device_name'] }")

            if self.type == 'all' or self.type == 'rects-s3fd' or 'landmarks' in self.type:
                self.rects_extractor = facelib.S3FDExtractor(place_model_on_cpu=place_model_on_cpu)
@@ -79,8 +79,8 @@ class ExtractSubprocessor(Subprocessor):
        def process_data(self, data):
            if 'landmarks' in self.type and len(data.rects) == 0:
                return data

            filepath = data.filepath
            cached_filepath, image = self.cached_image
            if cached_filepath != filepath:
                image = cv2_imread( filepath )
@@ -93,7 +93,7 @@ class ExtractSubprocessor(Subprocessor):
            h, w, c = image.shape
            extract_from_dflimg = (h == w and DFLIMG.load (filepath) is not None)

            if 'rects' in self.type or self.type == 'all':
                data = ExtractSubprocessor.Cli.rects_stage (data=data,
                                                            image=image,
@@ -119,7 +119,7 @@ class ExtractSubprocessor(Subprocessor):
                                                              final_output_path=self.final_output_path,
                                                            )
            return data

        @staticmethod
        def rects_stage(data,
                        image,
@@ -157,7 +157,7 @@ class ExtractSubprocessor(Subprocessor):
                        rects_extractor,
                        ):
            h, w, ch = image.shape
            if data.rects_rotation == 0:
                rotated_image = image
            elif data.rects_rotation == 90:
@@ -323,7 +323,7 @@ class ExtractSubprocessor(Subprocessor):
        self.manual_window_size = manual_window_size
        self.max_faces_from_image = max_faces_from_image
        self.result = []

        self.devices = ExtractSubprocessor.get_devices_for_config(self.type, device_config)

        super().__init__('Extractor', ExtractSubprocessor.Cli,
@@ -731,21 +731,21 @@ def main(detector=None,
        if detector == 'manual':
            io.log_info ('Performing manual extract...')
            data = ExtractSubprocessor ([ ExtractSubprocessor.Data(Path(filename)) for filename in input_path_image_paths ], 'landmarks-manual', image_size, face_type, output_debug_path if output_debug else None, manual_window_size=manual_window_size, device_config=device_config).run()

            io.log_info ('Performing 3rd pass...')
            data = ExtractSubprocessor (data, 'final', image_size, face_type, output_debug_path if output_debug else None, final_output_path=output_path, device_config=device_config).run()
        else:
            io.log_info ('Extracting faces...')
            data = ExtractSubprocessor ([ ExtractSubprocessor.Data(Path(filename)) for filename in input_path_image_paths ],
                                        'all',
                                        image_size,
                                        face_type,
                                        output_debug_path if output_debug else None,
                                        max_faces_from_image=max_faces_from_image,
                                        final_output_path=output_path,
                                        device_config=device_config).run()

        faces_detected += sum([d.faces_detected for d in data])

        if manual_fix:

View file

@@ -10,7 +10,7 @@ from core.cv2ex import *
class FacesetEnhancerSubprocessor(Subprocessor):

    #override
    def __init__(self, image_paths, output_dirpath, device_config):
        self.image_paths = image_paths
@@ -18,17 +18,17 @@ class FacesetEnhancerSubprocessor(Subprocessor):
        self.result = []
        self.nn_initialize_mp_lock = multiprocessing.Lock()
        self.devices = FacesetEnhancerSubprocessor.get_devices_for_config(device_config)

        super().__init__('FacesetEnhancer', FacesetEnhancerSubprocessor.Cli, 600)

    #override
    def on_clients_initialized(self):
        io.progress_bar (None, len (self.image_paths))

    #override
    def on_clients_finalized(self):
        io.progress_bar_close()

    #override
    def process_info_generator(self):
        base_dict = {'output_dirpath':self.output_dirpath,
@@ -42,34 +42,34 @@ class FacesetEnhancerSubprocessor(Subprocessor):
            yield client_dict['device_name'], {}, client_dict

    #override
    def get_data(self, host_dict):
        if len (self.image_paths) > 0:
            return self.image_paths.pop(0)

    #override
    def on_data_return (self, host_dict, data):
        self.image_paths.insert(0, data)

    #override
    def on_result (self, host_dict, data, result):
        io.progress_bar_inc(1)
        if result[0] == 1:
            self.result +=[ (result[1], result[2]) ]

    #override
    def get_result(self):
        return self.result

    @staticmethod
    def get_devices_for_config (device_config):
        devices = device_config.devices
        cpu_only = len(devices) == 0

        if not cpu_only:
            return [ (device.index, 'GPU', device.name, device.total_mem_gb) for device in devices ]
        else:
            return [ (i, 'CPU', 'CPU%d' % (i), 0 ) for i in range( min(8, multiprocessing.cpu_count() // 2) ) ]
    class Cli(Subprocessor.Cli):

        #override
@@ -85,14 +85,14 @@ class FacesetEnhancerSubprocessor(Subprocessor):
            else:
                device_config = nn.DeviceConfig.GPUIndexes ([device_idx])
                device_vram = device_config.devices[0].total_mem_gb

            nn.initialize (device_config)

            intro_str = 'Running on %s.' % (client_dict['device_name'])

            self.log_info (intro_str)
            from facelib import FaceEnhancer
            self.fe = FaceEnhancer( place_model_on_cpu=(device_vram<=2) )

        #override
@@ -103,28 +103,28 @@ class FacesetEnhancerSubprocessor(Subprocessor):
                    self.log_err ("%s is not a dfl image file" % (filepath.name) )
                else:
                    img = cv2_imread(filepath).astype(np.float32) / 255.0

                    img = self.fe.enhance(img)

                    img = np.clip (img*255, 0, 255).astype(np.uint8)

                    output_filepath = self.output_dirpath / filepath.name

                    cv2_imwrite ( str(output_filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] )

                    dflimg.embed_and_set ( str(output_filepath) )
                    return (1, filepath, output_filepath)
            except:
                self.log_err (f"Exception occurred while processing file {filepath}. Error: {traceback.format_exc()}")

            return (0, filepath, None)
def process_folder ( dirpath, cpu_only=False, force_gpu_idxs=None ):
    device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(suggest_all_gpu=True) ) \
                    if not cpu_only else nn.DeviceConfig.CPU()

    output_dirpath = dirpath.parent / (dirpath.name + '_enhanced')
    output_dirpath.mkdir (exist_ok=True, parents=True)

    dirpath_parts = '/'.join( dirpath.parts[-2:])
    output_dirpath_parts = '/'.join( output_dirpath.parts[-2:] )
    io.log_info (f"Enhancing faceset in {dirpath_parts}")
@@ -134,19 +134,19 @@ def process_folder ( dirpath, cpu_only=False, force_gpu_idxs=None ):
    if len(output_images_paths) > 0:
        for filename in output_images_paths:
            Path(filename).unlink()

    image_paths = [Path(x) for x in pathex.get_image_paths( dirpath )]
    result = FacesetEnhancerSubprocessor ( image_paths, output_dirpath, device_config=device_config).run()

    is_merge = io.input_bool (f"\r\nMerge {output_dirpath_parts} to {dirpath_parts} ?", True)
    if is_merge:
        io.log_info (f"Copying processed files to {dirpath_parts}")

        for (filepath, output_filepath) in result:
            try:
                shutil.copy (output_filepath, filepath)
            except:
                pass

        io.log_info (f"Removing {output_dirpath_parts}")
        shutil.rmtree(output_dirpath)

View file

@@ -319,14 +319,14 @@ class MaskEditor:
    def get_ie_polys(self):
        return self.ie_polys

    def set_ie_polys(self, saved_ie_polys):
        self.state = self.STATE_NONE
        self.ie_polys = saved_ie_polys
        self.redo_to_end_point()
        self.mask_finish()

def mask_editor_main(input_dir, confirmed_dir=None, skipped_dir=None, no_default_mask=False):
    input_path = Path(input_dir)
@@ -341,7 +341,7 @@ def mask_editor_main(input_dir, confirmed_dir=None, skipped_dir=None, no_default
    if not skipped_path.exists():
        skipped_path.mkdir(parents=True)

    if not no_default_mask:
        eyebrows_expand_mod = np.clip ( io.input_int ("Default eyebrows expand modifier?", 100, add_info="0..400"), 0, 400 ) / 100.0
    else:
@@ -368,7 +368,7 @@ def mask_editor_main(input_dir, confirmed_dir=None, skipped_dir=None, no_default
    do_save_count = 0
    do_skip_move_count = 0
    do_skip_count = 0

    def jobs_count():
        return do_prev_count + do_save_move_count + do_save_count + do_skip_move_count + do_skip_count

View file

@@ -237,7 +237,7 @@ class MergeSubprocessor(Subprocessor):
            try:
                with open( str(self.merger_session_filepath), "rb") as f:
                    session_data = pickle.loads(f.read())

            except Exception as e:
                pass
@@ -282,8 +282,8 @@ class MergeSubprocessor(Subprocessor):
                self.frames_done_idxs = s_frames_done_idxs

                rewind_to_begin = len(self.frames_idxs) == 0 # all frames are done?

                if self.model_iter != s_model_iter:
                    # model was more trained, recompute all frames
                    rewind_to_begin = True
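The resume logic above treats the merger session file as a cache keyed on the model iteration: if the model has trained further since the session was saved, every previously merged frame is stale and merging rewinds to the beginning. The same invalidation rule in miniature (names illustrative):

    def must_rewind(saved_model_iter, current_model_iter, remaining_frame_idxs):
        # all frames already merged? start over from the first frame
        rewind = len(remaining_frame_idxs) == 0
        # model trained further since the session was saved -> results are stale
        if current_model_iter != saved_model_iter:
            rewind = True
        return rewind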
                for frame in self.frames:
@@ -461,15 +461,15 @@ class MergeSubprocessor(Subprocessor):
                if key == 27: #esc
                    self.is_interactive_quitting = True
                elif self.screen_manager.get_current() is self.main_screen:

                    if self.merger_config.type == MergerConfig.TYPE_MASKED and chr_key in self.masked_keys:
                        self.process_remain_frames = False

                        if cur_frame is not None:
                            cfg = cur_frame.cfg
                            prev_cfg = cfg.copy()

                            if cfg.type == MergerConfig.TYPE_MASKED:
                                self.masked_keys_funcs[chr_key](cfg, shift_pressed)

                            if prev_cfg != cfg:
@@ -485,7 +485,7 @@ class MergeSubprocessor(Subprocessor):
                        if chr_key == ',':
                            if shift_pressed:
                                go_first_frame = True

                        elif chr_key == 'm':
                            if not shift_pressed:
                                go_prev_frame_overriding_cfg = True
@@ -499,7 +499,7 @@ class MergeSubprocessor(Subprocessor):
                        if chr_key == '.':
                            if shift_pressed:
                                self.process_remain_frames = not self.process_remain_frames

                        elif chr_key == '/':
                            if not shift_pressed:
                                go_next_frame_overriding_cfg = True
@@ -566,7 +566,7 @@ class MergeSubprocessor(Subprocessor):
                        frame.cfg = cur_frame.cfg.copy()
                    else:
                        frame.cfg = f[ self.frames_idxs[i-1] ].cfg.copy()

                    frame.is_done = False #initiate solve again
                    frame.is_shown = False
@@ -775,7 +775,7 @@ def main (model_class_name=None,
        io.log_info ("No frames to merge in input_dir.")
    else:
        MergeSubprocessor (
                    is_interactive          = is_interactive,
                    merger_session_filepath = merger_session_filepath,
                    predictor_func          = predictor_func,
                    predictor_input_shape   = predictor_input_shape,

View file

@@ -717,7 +717,7 @@ def sort_by_absdiff(input_path):
    from core.leras import nn

    device_config = nn.ask_choose_device_idxs(choose_only_one=True, return_device_config=True)
-   nn.initialize( device_config=device_config )
+   nn.initialize( device_config=device_config, data_format="NHWC" )
    tf = nn.tf

    image_paths = pathex.get_image_paths(input_path)

View file

@@ -12,19 +12,19 @@ import cv2
import models
from core.interact import interact as io

def trainerThread (s2c, c2s, e,
                   model_class_name = None,
                   saved_models_path = None,
                   training_data_src_path = None,
                   training_data_dst_path = None,
                   pretraining_data_path = None,
                   pretrained_model_path = None,
                   no_preview=False,
                   force_model_name=None,
                   force_gpu_idxs=None,
                   cpu_only=None,
                   execute_programs = None,
                   debug=False,
                   **kwargs):
    while True:
        try:
@@ -98,11 +98,11 @@ def trainerThread (s2c, c2s, e,
                    exec_prog = False
                    if prog_time > 0 and (cur_time - start_time) >= prog_time:
                        x[0] = 0
                        exec_prog = True
                    elif prog_time < 0 and (cur_time - last_time) >= -prog_time:
                        x[2] = cur_time
                        exec_prog = True

                    if exec_prog:
                        try:
                            exec(prog)
@@ -110,12 +110,12 @@ def trainerThread (s2c, c2s, e,
                            print("Unable to execute program: %s" % (prog) )

                if not is_reached_goal:
                    if model.get_iter() == 0:
                        io.log_info("")
                        io.log_info("Trying to do the first iteration. If an error occurs, reduce the model parameters.")
                        io.log_info("")

                    iter, iter_time = model.train_one_iter()

                    loss_history = model.get_loss_history()
@@ -127,8 +127,8 @@ def trainerThread (s2c, c2s, e,
                    if shared_state['after_save']:
                        shared_state['after_save'] = False
                        last_save_time = time.time()

                        mean_loss = np.mean ( [ np.array(loss_history[i]) for i in range(save_iter, iter) ], axis=0)

                        for loss_value in mean_loss:
@@ -145,10 +145,10 @@ def trainerThread (s2c, c2s, e,
                        io.log_info ('\r' + loss_string, end='')
                    else:
                        io.log_info (loss_string, end='\r')

                    if model.get_iter() == 1:
                        model_save()

                    if model.get_target_iter() != 0 and model.is_reached_iter_goal():
                        io.log_info ('Reached target iteration.')
                        model_save()

View file

@@ -15,34 +15,34 @@ def save_faceset_metadata_folder(input_path):
    input_path = Path(input_path)

    metadata_filepath = input_path / 'meta.dat'

    io.log_info (f"Saving metadata to {str(metadata_filepath)}\r\n")

    d = {}
    for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Processing"):
        filepath = Path(filepath)
        dflimg = DFLIMG.load (filepath)

        dfl_dict = dflimg.getDFLDictData()
        d[filepath.name] = ( dflimg.get_shape(), dfl_dict )

    try:
        with open(metadata_filepath, "wb") as f:
            f.write ( pickle.dumps(d) )
    except:
        raise Exception( 'cannot save %s' % (metadata_filepath) )

    io.log_info("Now you can edit images.")
    io.log_info("!!! Keep same filenames in the folder.")
    io.log_info("You can change size of images, restoring process will downscale back to original size.")
    io.log_info("After that, use restore metadata.")
def restore_faceset_metadata_folder(input_path):
    input_path = Path(input_path)

    metadata_filepath = input_path / 'meta.dat'
    io.log_info (f"Restoring metadata from {str(metadata_filepath)}.\r\n")

    if not metadata_filepath.exists():
        io.log_err(f"Unable to find {str(metadata_filepath)}.")
@@ -54,27 +54,27 @@ def restore_faceset_metadata_folder(input_path):
    for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Processing"):
        filepath = Path(filepath)

        shape, dfl_dict = d.get(filepath.name, None)

        img = cv2_imread (str(filepath))
        if img.shape != shape:
            img = cv2.resize (img, (shape[1], shape[0]), cv2.INTER_LANCZOS4 )

            if filepath.suffix == '.png':
                cv2_imwrite (str(filepath), img)
            elif filepath.suffix == '.jpg':
                cv2_imwrite (str(filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] )

        if filepath.suffix == '.png':
            DFLPNG.embed_dfldict( str(filepath), dfl_dict )
        elif filepath.suffix == '.jpg':
            DFLJPG.embed_dfldict( str(filepath), dfl_dict )
        else:
            continue

    metadata_filepath.unlink()

def remove_ie_polys_file (filepath):
    filepath = Path(filepath)
@@ -95,7 +95,7 @@ def remove_ie_polys_folder(input_path):
    for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Removing"):
        filepath = Path(filepath)
        remove_ie_polys_file(filepath)

def remove_fanseg_file (filepath):
    filepath = Path(filepath)

View file

@@ -101,7 +101,7 @@ def denoise_image_sequence( input_dir, ext=None, factor=None ):
    kwargs = {}
    if ext == 'jpg':
        kwargs.update ({'q:v':'2'})

    job = ( ffmpeg
            .input(str ( input_path / ('%5d.'+ext) ) )
            .filter("hqdn3d", factor, factor, 5,5)
@@ -174,7 +174,7 @@ def video_from_sequence( input_dir, output_file, reference_file=None, ext=None,
    input_image_paths = pathex.get_image_paths(input_path)

    i_in = ffmpeg.input('pipe:', format='image2pipe', r=fps)

    output_args = [i_in]

    if ref_in_a is not None:
@@ -200,14 +200,14 @@ def video_from_sequence( input_dir, output_file, reference_file=None, ext=None,
    job = ( ffmpeg.output(*output_args, **output_kwargs).overwrite_output() )

    try:
        job_run = job.run_async(pipe_stdin=True)

        for image_path in input_image_paths:
            with open (image_path, "rb") as f:
                image_bytes = f.read()
                job_run.stdin.write (image_bytes)

        job_run.stdin.close()
        job_run.wait()
    except:
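video_from_sequence streams the still frames to ffmpeg over stdin ('pipe:' with the image2pipe demuxer) instead of pointing ffmpeg at a filename pattern, which tolerates non-contiguous frame numbering. A minimal standalone sketch using the same ffmpeg-python calls (paths, fps and output options are illustrative):

    import ffmpeg
    from pathlib import Path

    frames = sorted(Path("frames").glob("*.jpg"))
    stream = ffmpeg.input('pipe:', format='image2pipe', r=25)   # read encoded images from stdin
    job = ffmpeg.output(stream, 'out.mp4', pix_fmt='yuv420p').overwrite_output()

    proc = job.run_async(pipe_stdin=True)
    for frame in frames:
        proc.stdin.write(frame.read_bytes())                    # push each image as-is
    proc.stdin.close()
    proc.wait()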

View file

@@ -23,26 +23,26 @@ def extract_vggface2_dataset(input_dir, device_args={} ):
    input_path = Path(input_dir)
    if not input_path.exists():
        raise ValueError('Input directory not found. Please ensure it exists.')

    bb_csv = input_path / 'loose_bb_train.csv'
    if not bb_csv.exists():
        raise ValueError('loose_bb_train.csv not found. Please ensure it exists.')

    bb_lines = bb_csv.read_text().split('\n')
    bb_lines.pop(0)

    bb_dict = {}
    for line in bb_lines:
        name, l, t, w, h = line.split(',')
        name = name[1:-1]
        l, t, w, h = [ int(x) for x in (l, t, w, h) ]
        bb_dict[name] = (l,t,w,h)
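The loop above assumes the VGGFace2 loose_bb_train.csv layout: a header row (popped first), then one line per image of the form "identity/image",left,top,width,height with the name in quotes. Parsing one such line in isolation (sample values invented):

    line = '"n000002/0001_01",101,29,110,180'
    name, l, t, w, h = line.split(',')
    name = name[1:-1]                        # strip the surrounding quotes
    l, t, w, h = (int(x) for x in (l, t, w, h))
    assert (name, l, t, w, h) == ('n000002/0001_01', 101, 29, 110, 180)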
    output_path = input_path.parent / (input_path.name + '_out')

    dir_names = pathex.get_all_dir_names(input_path)

    if not output_path.exists():
        output_path.mkdir(parents=True, exist_ok=True)
@@ -50,15 +50,15 @@ def extract_vggface2_dataset(input_dir, device_args={} ):
    for dir_name in io.progress_bar_generator(dir_names, "Collecting"):
        cur_input_path = input_path / dir_name
        cur_output_path = output_path / dir_name

        if not cur_output_path.exists():
            cur_output_path.mkdir(parents=True, exist_ok=True)

        input_path_image_paths = pathex.get_image_paths(cur_input_path)

        for filename in input_path_image_paths:
            filename_path = Path(filename)

            name = filename_path.parent.name + '/' + filename_path.stem
            if name not in bb_dict:
                continue
@@ -66,29 +66,29 @@ def extract_vggface2_dataset(input_dir, device_args={} ):
            l,t,w,h = bb_dict[name]
            if min(w,h) < 128:
                continue

            data += [ ExtractSubprocessor.Data(filename=filename,rects=[ (l,t,l+w,t+h) ], landmarks_accurate=False, force_output_path=cur_output_path ) ]

    face_type = FaceType.fromString('full_face')

    io.log_info ('Performing 2nd pass...')
    data = ExtractSubprocessor (data, 'landmarks', 256, face_type, debug_dir=None, multi_gpu=multi_gpu, cpu_only=cpu_only, manual=False).run()

    io.log_info ('Performing 3rd pass...')
    ExtractSubprocessor (data, 'final', 256, face_type, debug_dir=None, multi_gpu=multi_gpu, cpu_only=cpu_only, manual=False, final_output_path=None).run()

    """
    import code
    code.interact(local=dict(globals(), **locals()))

    data_len = len(data)
    i = 0
    while i < data_len-1:
        i_name = Path(data[i].filename).parent.name

        sub_data = []

        for j in range (i, data_len):
            j_name = Path(data[j].filename).parent.name
            if i_name == j_name:
@@ -96,33 +96,33 @@ def extract_vggface2_dataset(input_dir, device_args={} ):
            else:
                break
        i = j

        cur_output_path = output_path / i_name

        io.log_info (f"Processing: {str(cur_output_path)}, {i}/{data_len} ")

        if not cur_output_path.exists():
            cur_output_path.mkdir(parents=True, exist_ok=True)

    for dir_name in dir_names:
        cur_input_path = input_path / dir_name
        cur_output_path = output_path / dir_name

        input_path_image_paths = pathex.get_image_paths(cur_input_path)
        l = len(input_path_image_paths)
        #if l < 250 or l > 350:
        #    continue

        io.log_info (f"Processing: {str(cur_input_path)} ")

        if not cur_output_path.exists():
            cur_output_path.mkdir(parents=True, exist_ok=True)
@@ -130,41 +130,41 @@ def extract_vggface2_dataset(input_dir, device_args={} ):
        data = []
        for filename in input_path_image_paths:
            filename_path = Path(filename)

            name = filename_path.parent.name + '/' + filename_path.stem
            if name not in bb_dict:
                continue

            bb = bb_dict[name]
            l,t,w,h = bb
            if min(w,h) < 128:
                continue

            data += [ ExtractSubprocessor.Data(filename=filename,rects=[ (l,t,l+w,t+h) ], landmarks_accurate=False ) ]

        io.log_info ('Performing 2nd pass...')
        data = ExtractSubprocessor (data, 'landmarks', 256, face_type, debug_dir=None, multi_gpu=False, cpu_only=False, manual=False).run()

        io.log_info ('Performing 3rd pass...')
        data = ExtractSubprocessor (data, 'final', 256, face_type, debug_dir=None, multi_gpu=False, cpu_only=False, manual=False, final_output_path=cur_output_path).run()

        io.log_info (f"Sorting: {str(cur_output_path)} ")
        Sorter.main (input_path=str(cur_output_path), sort_by_method='hist')

        import code
        code.interact(local=dict(globals(), **locals()))

        #try:
        #    io.log_info (f"Removing: {str(cur_input_path)} ")
        #    shutil.rmtree(cur_input_path)
        #except:
        #    io.log_info (f"unable to remove: {str(cur_input_path)} ")
def extract_vggface2_dataset(input_dir, device_args={} ):
    multi_gpu = device_args.get('multi_gpu', False)
@@ -173,27 +173,27 @@ def extract_vggface2_dataset(input_dir, device_args={} ):
    input_path = Path(input_dir)
    if not input_path.exists():
        raise ValueError('Input directory not found. Please ensure it exists.')
    output_path = input_path.parent / (input_path.name + '_out')
    dir_names = pathex.get_all_dir_names(input_path)
    if not output_path.exists():
        output_path.mkdir(parents=True, exist_ok=True)
    for dir_name in dir_names:
        cur_input_path = input_path / dir_name
        cur_output_path = output_path / dir_name
        l = len(pathex.get_image_paths(cur_input_path))
        if l < 250 or l > 350:
            continue
        io.log_info (f"Processing: {str(cur_input_path)} ")
        if not cur_output_path.exists():
            cur_output_path.mkdir(parents=True, exist_ok=True)
@@ -204,17 +204,17 @@ def extract_vggface2_dataset(input_dir, device_args={} ):
                           face_type='full_face',
                           max_faces_from_image=1,
                           device_args=device_args )
        io.log_info (f"Sorting: {str(cur_input_path)} ")
        Sorter.main (input_path=str(cur_output_path), sort_by_method='hist')
        try:
            io.log_info (f"Removing: {str(cur_input_path)} ")
            shutil.rmtree(cur_input_path)
        except:
            io.log_info (f"unable to remove: {str(cur_input_path)} ")
    """
class CelebAMASKHQSubprocessor(Subprocessor):
    class Cli(Subprocessor.Cli):
@@ -228,31 +228,31 @@ class CelebAMASKHQSubprocessor(Subprocessor):
            filename = data[0]
            dflimg = DFLIMG.load(Path(filename))
            image_to_face_mat = dflimg.get_image_to_face_mat()
            src_filename = dflimg.get_source_filename()
            img = cv2_imread(filename)
            h,w,c = img.shape
            fanseg_mask = LandmarksProcessor.get_image_hull_mask(img.shape, dflimg.get_landmarks() )
            idx_name = '%.5d' % int(src_filename.split('.')[0])
            idx_files = [ x for x in self.masks_files_paths if idx_name in x ]
            skin_files = [ x for x in idx_files if 'skin' in x ]
            eye_glass_files = [ x for x in idx_files if 'eye_g' in x ]
            for files, is_invert in [ (skin_files,False),
                                      (eye_glass_files,True) ]:
                if len(files) > 0:
                    mask = cv2_imread(files[0])
                    mask = mask[...,0]
                    mask[mask == 255] = 1
                    mask = mask.astype(np.float32)
                    mask = cv2.resize(mask, (1024,1024) )
                    mask = cv2.warpAffine(mask, image_to_face_mat, (w, h), cv2.INTER_LANCZOS4)
                    if not is_invert:
                        fanseg_mask *= mask[...,None]
                    else:
@@ -270,7 +270,7 @@ class CelebAMASKHQSubprocessor(Subprocessor):
    def __init__(self, image_paths, masks_files_paths ):
        self.image_paths = image_paths
        self.masks_files_paths = masks_files_paths
        self.result = []
        super().__init__('CelebAMASKHQSubprocessor', CelebAMASKHQSubprocessor.Cli, 60)
@@ -304,23 +304,23 @@ class CelebAMASKHQSubprocessor(Subprocessor):
    #override
    def get_result(self):
        return self.result
#unused in end user workflow
def apply_celebamaskhq(input_dir ):
    input_path = Path(input_dir)
    img_path = input_path / 'aligned'
    mask_path = input_path / 'mask'
    if not img_path.exists():
        raise ValueError(f'{str(img_path)} directory not found. Please ensure it exists.')
    CelebAMASKHQSubprocessor(pathex.get_image_paths(img_path),
                             pathex.get_image_paths(mask_path, subdirs=True) ).run()
    return
    paths_to_extract = []
    for filename in io.progress_bar_generator(pathex.get_image_paths(img_path), desc="Processing"):
        filepath = Path(filename)
@@ -328,44 +328,44 @@ def apply_celebamaskhq(input_dir ):
        if dflimg is not None:
            paths_to_extract.append (filepath)
        image_to_face_mat = dflimg.get_image_to_face_mat()
        src_filename = dflimg.get_source_filename()
        #img = cv2_imread(filename)
        h,w,c = dflimg.get_shape()
        fanseg_mask = LandmarksProcessor.get_image_hull_mask( (h,w,c), dflimg.get_landmarks() )
        idx_name = '%.5d' % int(src_filename.split('.')[0])
        idx_files = [ x for x in masks_files if idx_name in x ]
        skin_files = [ x for x in idx_files if 'skin' in x ]
        eye_glass_files = [ x for x in idx_files if 'eye_g' in x ]
        for files, is_invert in [ (skin_files,False),
                                  (eye_glass_files,True) ]:
            if len(files) > 0:
                mask = cv2_imread(files[0])
                mask = mask[...,0]
                mask[mask == 255] = 1
                mask = mask.astype(np.float32)
                mask = cv2.resize(mask, (1024,1024) )
                mask = cv2.warpAffine(mask, image_to_face_mat, (w, h), cv2.INTER_LANCZOS4)
                if not is_invert:
                    fanseg_mask *= mask[...,None]
                else:
                    fanseg_mask *= (1-mask[...,None])
        #cv2.imshow("", (fanseg_mask*255).astype(np.uint8) )
        #cv2.waitKey(0)
        dflimg.embed_and_set (filename, fanseg_mask=fanseg_mask)
        #import code
        #code.interact(local=dict(globals(), **locals()))
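The loop above intersects the landmarks hull with CelebAMask-HQ's 'skin' component and carves out the 'eye_g' (eyeglasses) component. As a standalone numpy sketch of that combination rule (function name and signature are hypothetical):

import numpy as np

def combine_fanseg_mask(hull, skin=None, eye_glasses=None):
    # hull: (h,w,1) float32 mask from landmarks; component masks: (h,w) arrays in {0,1}
    fanseg = hull.copy()
    if skin is not None:
        fanseg *= skin[..., None]               # keep only pixels the skin mask confirms
    if eye_glasses is not None:
        fanseg *= (1 - eye_glasses[..., None])  # cut glasses out of the face mask
    return np.clip(fanseg, 0.0, 1.0)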
@@ -375,43 +375,43 @@ def apply_celebamaskhq(input_dir ):
def extract_fanseg(input_dir, device_args={} ):
    multi_gpu = device_args.get('multi_gpu', False)
    cpu_only = device_args.get('cpu_only', False)
    input_path = Path(input_dir)
    if not input_path.exists():
        raise ValueError('Input directory not found. Please ensure it exists.')
    paths_to_extract = []
    for filename in pathex.get_image_paths(input_path) :
        filepath = Path(filename)
        dflimg = DFLIMG.load ( filepath )
        if dflimg is not None:
            paths_to_extract.append (filepath)
    paths_to_extract_len = len(paths_to_extract)
    if paths_to_extract_len > 0:
        io.log_info ("Performing extract fanseg for %d files..." % (paths_to_extract_len) )
        data = ExtractSubprocessor ([ ExtractSubprocessor.Data(filename) for filename in paths_to_extract ], 'fanseg', multi_gpu=multi_gpu, cpu_only=cpu_only).run()
#unused in end user workflow
def extract_umd_csv(input_file_csv,
                    image_size=256,
                    face_type='full_face',
                    device_args={} ):
    #extract faces from umdfaces.io dataset csv file with pitch,yaw,roll info.
    multi_gpu = device_args.get('multi_gpu', False)
    cpu_only = device_args.get('cpu_only', False)
    face_type = FaceType.fromString(face_type)
    input_file_csv_path = Path(input_file_csv)
    if not input_file_csv_path.exists():
        raise ValueError('input_file_csv not found. Please ensure it exists.')
    input_file_csv_root_path = input_file_csv_path.parent
    output_path = input_file_csv_path.parent / ('aligned_' + input_file_csv_path.name)
    io.log_info("Output dir is %s." % (str(output_path)) )
    if output_path.exists():
        output_images_paths = pathex.get_image_paths(output_path)
        if len(output_images_paths) > 0:
@@ -420,15 +420,15 @@ def extract_umd_csv(input_file_csv,
            Path(filename).unlink()
    else:
        output_path.mkdir(parents=True, exist_ok=True)
    try:
        with open( str(input_file_csv_path), 'r') as f:
            csv_file = f.read()
    except Exception as e:
        io.log_err("Unable to open or read file " + str(input_file_csv_path) + ": " + str(e) )
        return
    strings = csv_file.split('\n')
    keys = strings[0].split(',')
    keys_len = len(keys)
    csv_data = []
@@ -437,29 +437,29 @@ def extract_umd_csv(input_file_csv,
        if keys_len != len(values):
            io.log_err("Wrong string in csv file, skipping.")
            continue
        csv_data += [ { keys[n] : values[n] for n in range(keys_len) } ]
    data = []
    for d in csv_data:
        filename = input_file_csv_root_path / d['FILE']
        x,y,w,h = float(d['FACE_X']), float(d['FACE_Y']), float(d['FACE_WIDTH']), float(d['FACE_HEIGHT'])
        data += [ ExtractSubprocessor.Data(filename=filename, rects=[ [x,y,x+w,y+h] ]) ]
    images_found = len(data)
    faces_detected = 0
    if len(data) > 0:
        io.log_info ("Performing 2nd pass from csv file...")
        data = ExtractSubprocessor (data, 'landmarks', multi_gpu=multi_gpu, cpu_only=cpu_only).run()
        io.log_info ('Performing 3rd pass...')
        data = ExtractSubprocessor (data, 'final', image_size, face_type, None, multi_gpu=multi_gpu, cpu_only=cpu_only, manual=False, final_output_path=output_path).run()
        faces_detected += sum([d.faces_detected for d in data])
    io.log_info ('-------------------------')
    io.log_info ('Images found: %d' % (images_found) )
    io.log_info ('Faces detected: %d' % (faces_detected) )
@@ -467,22 +467,21 @@ def extract_umd_csv(input_file_csv,
def dev_test(input_dir):
    input_path = Path(input_dir)
    dir_names = pathex.get_all_dir_names(input_path)
    for dir_name in io.progress_bar_generator(dir_names, desc="Processing"):
        img_paths = pathex.get_image_paths (input_path / dir_name)
        for filename in img_paths:
            filepath = Path(filename)
            dflimg = DFLIMG.load (filepath)
            if dflimg is None:
                raise ValueError
            dflimg.embed_and_set(filename, person_name=dir_name)
    #import code
    #code.interact(local=dict(globals(), **locals()))
View file
@@ -7,13 +7,13 @@ from core.cv2ex import *
def process_frame_info(frame_info, inp_sh):
    img_uint8 = cv2_imread (frame_info.filename)
    img_uint8 = imagelib.normalize_channels (img_uint8, 3)
    img = img_uint8.astype(np.float32) / 255.0
    img_mat = LandmarksProcessor.get_transform_mat (frame_info.landmarks_list[0], inp_sh[0], face_type=FaceType.FULL_NO_ALIGN)
    img = cv2.warpAffine( img, img_mat, inp_sh[0:2], borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC )
    return img
def MergeFaceAvatar (predictor_func, predictor_input_shape, cfg, prev_temporal_frame_infos, frame_info, next_temporal_frame_infos):
    inp_sh = predictor_input_shape
@@ -28,14 +28,14 @@ def MergeFaceAvatar (predictor_func, predictor_input_shape, cfg, prev_temporal_f
    if cfg.super_resolution_mode != 0:
        prd_f = cfg.superres_func(cfg.super_resolution_mode, prd_f)
    if cfg.sharpen_mode != 0 and cfg.sharpen_amount != 0:
        prd_f = cfg.sharpen_func ( prd_f, cfg.sharpen_mode, 3, cfg.sharpen_amount)
    out_img = np.clip(prd_f, 0.0, 1.0)
    if cfg.add_source_image:
        out_img = np.concatenate ( [cv2.resize ( img, (prd_f.shape[1], prd_f.shape[0]) ),
                                    out_img], axis=1 )
    return (out_img*255).astype(np.uint8)
View file
@@ -29,7 +29,7 @@ def MergeMaskedFace (predictor_func, predictor_input_shape, cfg, frame_info, img
    dst_face_bgr = cv2.warpAffine( img_bgr , face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC )
    dst_face_bgr = np.clip(dst_face_bgr, 0, 1)
    dst_face_mask_a_0 = cv2.warpAffine( img_face_mask_a, face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC )
    dst_face_mask_a_0 = np.clip(dst_face_mask_a_0, 0, 1)
@@ -50,7 +50,7 @@ def MergeMaskedFace (predictor_func, predictor_input_shape, cfg, frame_info, img
    if cfg.super_resolution_mode:
        prd_face_bgr = cfg.superres_func(cfg.super_resolution_mode, prd_face_bgr)
        prd_face_bgr = np.clip(prd_face_bgr, 0, 1)
    if predictor_masked:
        prd_face_mask_a_0 = cv2.resize (prd_face_mask_a_0, (output_size, output_size), cv2.INTER_CUBIC)
    else:
@@ -192,12 +192,12 @@ def MergeMaskedFace (predictor_func, predictor_input_shape, cfg, frame_info, img
        prd_face_bgr = imagelib.color_transfer_idt (prd_face_bgr, dst_face_bgr)
    elif cfg.color_transfer_mode == 6: #idt-m
        prd_face_bgr = imagelib.color_transfer_idt (prd_face_bgr*prd_face_mask_a, dst_face_bgr*prd_face_mask_a)
    elif cfg.color_transfer_mode == 7: #sot-m
        prd_face_bgr = imagelib.color_transfer_sot (prd_face_bgr*prd_face_mask_a, dst_face_bgr*prd_face_mask_a)
        prd_face_bgr = np.clip (prd_face_bgr, 0.0, 1.0)
    elif cfg.color_transfer_mode == 8: #mix-m
        prd_face_bgr = imagelib.color_transfer_mix (prd_face_bgr*prd_face_mask_a, dst_face_bgr*prd_face_mask_a)
    if cfg.mode == 'hist-match-bw':
        prd_face_bgr = cv2.cvtColor(prd_face_bgr, cv2.COLOR_BGR2GRAY)
        prd_face_bgr = np.repeat( np.expand_dims (prd_face_bgr, -1), (3,), -1 )
@@ -236,7 +236,7 @@ def MergeMaskedFace (predictor_func, predictor_input_shape, cfg, frame_info, img
                break
    out_img = cv2.warpAffine( prd_face_bgr, face_output_mat, img_size, out_img, cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC, cv2.BORDER_TRANSPARENT )
    out_img = np.clip(out_img, 0.0, 1.0)
    if 'seamless' in cfg.mode:
@@ -254,8 +254,8 @@ def MergeMaskedFace (predictor_func, predictor_input_shape, cfg, frame_info, img
                raise Exception("Seamless fail: " + e_str) #reraise MemoryError in order to reprocess this data by other processes
            else:
                print ("Seamless fail: " + e_str)
    out_img = img_bgr*(1-img_face_mask_aaa) + (out_img*img_face_mask_aaa)
    out_face_bgr = cv2.warpAffine( out_img, face_mat, (output_size, output_size) )
@@ -279,12 +279,12 @@ def MergeMaskedFace (predictor_func, predictor_input_shape, cfg, frame_info, img
            out_face_bgr = imagelib.color_transfer_idt (out_face_bgr, dst_face_bgr)
        elif cfg.color_transfer_mode == 6: #idt-m
            out_face_bgr = imagelib.color_transfer_idt (out_face_bgr*prd_face_mask_a, dst_face_bgr*prd_face_mask_a)
        elif cfg.color_transfer_mode == 7: #sot-m
            out_face_bgr = imagelib.color_transfer_sot (out_face_bgr*prd_face_mask_a, dst_face_bgr*prd_face_mask_a)
            out_face_bgr = np.clip (out_face_bgr, 0.0, 1.0)
        elif cfg.color_transfer_mode == 8: #mix-m
            out_face_bgr = imagelib.color_transfer_mix (out_face_bgr*prd_face_mask_a, dst_face_bgr*prd_face_mask_a)
        if cfg.mode == 'seamless-hist-match':
            out_face_bgr = imagelib.color_hist_match(out_face_bgr, dst_face_bgr, cfg.hist_match_threshold)
@@ -327,7 +327,7 @@ def MergeMaskedFace (predictor_func, predictor_input_shape, cfg, frame_info, img
            else:
                alpha = cfg.color_degrade_power / 100.0
                out_img = (out_img*(1.0-alpha) + out_img_reduced*alpha)
    out_merging_mask = img_face_mask_aaa
    return out_img, out_merging_mask[...,0:1]
@@ -353,10 +353,10 @@ def MergeMasked (predictor_func, predictor_input_shape, cfg, frame_info):
            final_img = img
            final_mask = merging_mask
        else:
            final_img = final_img*(1-merging_mask) + img*merging_mask
            final_mask = np.clip (final_mask + merging_mask, 0, 1 )
    if cfg.export_mask_alpha:
        final_img = np.concatenate ( [final_img, final_mask], -1)
    return (final_img*255).astype(np.uint8)
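For reference, the accumulation above is plain alpha compositing: each merged face is blended over the running result by its mask, and the masks are unioned. A minimal numpy sketch, assuming float32 images in [0,1]:

import numpy as np

def merge_faces(faces):
    # faces: list of (image, mask) pairs, HxWx3 / HxWx1 float32 arrays in [0,1]
    final_img = final_mask = None
    for img, mask in faces:
        if final_img is None:
            final_img, final_mask = img, mask
        else:
            final_img = final_img*(1-mask) + img*mask      # blend over the previous result
            final_mask = np.clip(final_mask + mask, 0, 1)  # union of coverage
    return final_img, final_mask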
View file
@@ -43,7 +43,7 @@ class MergerConfig(object):
    def ask_settings(self):
        s = """Choose sharpen mode: \n"""
        for key in self.sharpen_dict.keys():
            s += f"""({key}) {self.sharpen_dict[key]}\n"""
        io.log_info(s)
        self.sharpen_mode = io.input_int ("", 0, valid_list=self.sharpen_dict.keys(), help_message="Enhance details by applying sharpen filter.")
View file
@@ -68,7 +68,7 @@ class ModelBase(object):
                s = f"[{i}] : {model_name} "
                if i == 0:
                    s += "- latest"
                io.log_info (s)
            inp = io.input_str(f"", "0", show_default_value=False )
            model_idx = -1
@@ -81,27 +81,27 @@ class ModelBase(object):
            if len(inp) == 1:
                is_rename = inp[0] == 'r'
                is_delete = inp[0] == 'd'
                if is_rename or is_delete:
                    if len(saved_models_names) != 0:
                        if is_rename:
                            name = io.input_str(f"Enter the name of the model you want to rename")
                        elif is_delete:
                            name = io.input_str(f"Enter the name of the model you want to delete")
                        if name in saved_models_names:
                            if is_rename:
                                new_model_name = io.input_str(f"Enter new name of the model")
                            for filepath in pathex.get_file_paths(saved_models_path):
                                filepath_name = filepath.name
                                model_filename, remain_filename = filepath_name.split('_', 1)
                                if model_filename == name:
                                    if is_rename:
                                        new_filepath = filepath.parent / ( new_model_name + '_' + remain_filename )
                                        filepath.rename (new_filepath)
                                    elif is_delete:
@@ -159,7 +159,7 @@ class ModelBase(object):
        #####
        io.input_skip_pending()
        self.on_initialize_options()
        if self.is_first_run():
            # save as default options only for first run model initialize
@@ -172,7 +172,7 @@ class ModelBase(object):
        self.on_initialize()
        self.options['batch_size'] = self.batch_size
        if self.is_training:
            self.preview_history_path = self.saved_models_path / ( f'{self.get_model_name()}_history' )
            self.autobackups_path = self.saved_models_path / ( f'{self.get_model_name()}_autobackups' )
@@ -326,7 +326,7 @@ class ModelBase(object):
    def get_pretraining_data_path(self):
        return self.pretraining_data_path
    def get_target_iter(self):
        return self.target_iter
@@ -479,7 +479,7 @@ class ModelBase(object):
        #Find the longest key name and value string. Used as column widths.
        width_name = max([len(k) for k in self.options.keys()] + [17]) + 1 # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration"
        width_value = max([len(str(x)) for x in self.options.values()] + [len(str(self.get_iter())), len(self.get_model_name())]) + 1 # Single space buffer to right edge
-       if not self.device_config.cpu_only: #Check length of GPU names
+       if len(self.device_config.devices) != 0: #Check length of GPU names
            width_value = max([len(device.name)+1 for device in self.device_config.devices] + [width_value])
        width_total = width_name + width_value + 2 #Plus 2 for ": "
@@ -499,7 +499,7 @@ class ModelBase(object):
        summary_text += [f'=={" Running On ":-^{width_total}}=='] # Training hardware info
        summary_text += [f'=={" "*width_total}==']
-       if self.device_config.cpu_only:
+       if len(self.device_config.devices) == 0:
            summary_text += [f'=={"Using device": >{width_name}}: {"CPU": <{width_value}}=='] # cpu_only
        else:
            for device in self.device_config.devices:
View file
@@ -13,11 +13,13 @@ from samplelib import *
class QModel(ModelBase):
    #override
    def on_initialize(self):
-       nn.initialize()
+       device_config = nn.getCurrentDeviceConfig()
+       self.model_data_format = "NCHW" if len(device_config.devices) != 0 else "NHWC"
+       nn.initialize(data_format=self.model_data_format)
        tf = nn.tf
-       conv_kernel_initializer = nn.initializers.ca
+       conv_kernel_initializer = nn.initializers.ca()
        class Downscale(nn.ModelBase):
            def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
                self.in_ch = in_ch
@@ -39,7 +41,7 @@ class QModel(ModelBase):
                x = self.conv1(x)
                if self.subpixel:
-                   x = tf.nn.space_to_depth(x, 2)
+                   x = nn.tf_space_to_depth(x, 2)
                if self.use_activator:
                    x = nn.tf_gelu(x)
@@ -63,7 +65,7 @@ class QModel(ModelBase):
                for down in self.downs:
                    x = down(x)
                return x
        class Upscale(nn.ModelBase):
            def on_build(self, in_ch, out_ch, kernel_size=3 ):
                self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
@@ -71,9 +73,9 @@ class QModel(ModelBase):
            def forward(self, x):
                x = self.conv1(x)
                x = nn.tf_gelu(x)
-               x = tf.nn.depth_to_space(x, 2)
+               x = nn.tf_depth_to_space(x, 2)
                return x
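The subpixel upscale relies on depth_to_space: each group of 4 channels is rearranged into a 2x2 spatial block, doubling height and width. A numpy sketch of the NHWC case (nn.tf_depth_to_space presumably wraps the TF op so the new NCHW path works too):

import numpy as np

def depth_to_space_nhwc(x, block=2):
    n, h, w, c = x.shape
    oc = c // (block * block)
    x = x.reshape(n, h, w, block, block, oc)
    x = x.transpose(0, 1, 3, 2, 4, 5)          # interleave block rows/cols with h/w
    return x.reshape(n, h*block, w*block, oc)  # (n, 2h, 2w, c//4) for block=2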
        class ResidualBlock(nn.ModelBase):
            def on_build(self, ch, kernel_size=3 ):
                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
@@ -109,7 +111,7 @@ class QModel(ModelBase):
            def forward(self, inp):
                x = self.dense1(inp)
                x = self.dense2(x)
-               x = tf.reshape (x, (-1, lowest_dense_res, lowest_dense_res, self.ae_out_ch))
+               x = nn.tf_reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
                x = self.upscale1(x)
                x = self.res1(x)
                return x
@@ -118,11 +120,11 @@ class QModel(ModelBase):
                return self.ae_out_ch
        class Decoder(nn.ModelBase):
            def on_build(self, in_ch, d_ch):
                self.upscale1 = Upscale(in_ch, d_ch*4)
                self.res1 = ResidualBlock(d_ch*4)
                self.upscale2 = Upscale(d_ch*4, d_ch*2)
                self.res2 = ResidualBlock(d_ch*2)
                self.upscale3 = Upscale(d_ch*2, d_ch*1)
                self.res3 = ResidualBlock(d_ch*1)
@@ -134,8 +136,8 @@ class QModel(ModelBase):
                self.out_convm = nn.Conv2D( d_ch//2, 1, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
            def forward(self, inp):
                z = inp
                x = self.upscale1 (z)
                x = self.res1 (x)
                x = self.upscale2 (x)
                x = self.res2 (x)
@@ -158,7 +160,7 @@ class QModel(ModelBase):
        d_dims = 64
        self.pretrain = False
        self.pretrain_just_disabled = False
        masked_training = True
        models_opt_on_gpu = len(devices) == 1 and devices[0].total_mem_gb >= 4
@@ -167,8 +169,8 @@ class QModel(ModelBase):
        input_nc = 3
        output_nc = 3
-       bgr_shape = (resolution, resolution, output_nc)
-       mask_shape = (resolution, resolution, 1)
+       bgr_shape = nn.get4Dshape(resolution,resolution,input_nc)
+       mask_shape = nn.get4Dshape(resolution,resolution,1)
        lowest_dense_res = resolution // 16
        self.model_filename_list = []
@@ -176,22 +178,22 @@ class QModel(ModelBase):
        with tf.device ('/CPU:0'):
            #Place holders on CPU
-           self.warped_src = tf.placeholder (tf.float32, (None,)+bgr_shape)
-           self.warped_dst = tf.placeholder (tf.float32, (None,)+bgr_shape)
-           self.target_src = tf.placeholder (tf.float32, (None,)+bgr_shape)
-           self.target_dst = tf.placeholder (tf.float32, (None,)+bgr_shape)
-           self.target_srcm = tf.placeholder (tf.float32, (None,)+mask_shape)
-           self.target_dstm = tf.placeholder (tf.float32, (None,)+mask_shape)
+           self.warped_src = tf.placeholder (nn.tf_floatx, bgr_shape)
+           self.warped_dst = tf.placeholder (nn.tf_floatx, bgr_shape)
+           self.target_src = tf.placeholder (nn.tf_floatx, bgr_shape)
+           self.target_dst = tf.placeholder (nn.tf_floatx, bgr_shape)
+           self.target_srcm = tf.placeholder (nn.tf_floatx, mask_shape)
+           self.target_dstm = tf.placeholder (nn.tf_floatx, mask_shape)
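A sketch of what the nn.get4Dshape / nn.tf_floatx pair presumably amounts to: the placeholder shape follows the chosen data format and the dtype follows the new use_float16 option. The helper below is hypothetical, for illustration only:

import tensorflow as tf

def make_placeholder(resolution, ch, use_float16=False, data_format="NCHW"):
    floatx = tf.float16 if use_float16 else tf.float32  # assumption: what use_float16 toggles
    if data_format == "NCHW":
        shape = (None, ch, resolution, resolution)
    else:
        shape = (None, resolution, resolution, ch)
    return tf.placeholder(floatx, shape)

warped_src  = make_placeholder(96, 3)  # Quick96-style BGR input batch
target_srcm = make_placeholder(96, 1)  # matching mask batch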
        # Initializing model classes
        with tf.device (models_opt_device):
            self.encoder = Encoder(in_ch=input_nc, e_ch=e_dims, name='encoder')
-           encoder_out_ch = self.encoder.compute_output_shape ( (tf.float32, (None,resolution,resolution,input_nc)))[-1]
+           encoder_out_ch = self.encoder.compute_output_channels ( (nn.tf_floatx, bgr_shape))
            self.inter = Inter (in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims, d_ch=d_dims, name='inter')
-           inter_out_ch = self.inter.compute_output_shape ( (tf.float32, (None,encoder_out_ch)))[-1]
+           inter_out_ch = self.inter.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch)))
            self.decoder_src = Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_src')
            self.decoder_dst = Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_dst')
@@ -203,7 +205,7 @@ class QModel(ModelBase):
        if self.is_training:
            self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
            # Initialize optimizers
            self.src_dst_opt = nn.TFRMSpropOptimizer(lr=2e-4, lr_dropout=0.3, name='src_dst_opt')
            self.src_dst_opt.initialize_variables(self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu )
@@ -222,7 +224,7 @@ class QModel(ModelBase):
            gpu_pred_src_srcm_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []
            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_src_dst_loss_gvs = []
@@ -239,7 +241,7 @@ class QModel(ModelBase):
                        gpu_target_srcm = self.target_srcm[batch_slice,:,:,:]
                        gpu_target_dstm = self.target_dstm[batch_slice,:,:,:]
                    # process model tensors
                    gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                    gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
@@ -249,11 +251,11 @@ class QModel(ModelBase):
                    gpu_pred_src_src_list.append(gpu_pred_src_src)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst)
                    gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                    gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)
                    gpu_target_srcm_blur = nn.tf_gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
                    gpu_target_dstm_blur = nn.tf_gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )
@@ -271,11 +273,11 @@ class QModel(ModelBase):
                    gpu_src_loss = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_srcmasked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_srcmasked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])
-                   gpu_src_loss += tf.reduce_mean ( tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
+                   gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
                    gpu_dst_loss = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])
-                   gpu_dst_loss += tf.reduce_mean ( tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )
+                   gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )
                    gpu_src_losses += [gpu_src_loss]
                    gpu_dst_losses += [gpu_dst_loss]
@@ -286,29 +288,16 @@ class QModel(ModelBase):
            # Average losses and gradients, and create optimizer update ops
            with tf.device (models_opt_device):
-               if gpu_count == 1:
-                   pred_src_src = gpu_pred_src_src_list[0]
-                   pred_dst_dst = gpu_pred_dst_dst_list[0]
-                   pred_src_dst = gpu_pred_src_dst_list[0]
-                   pred_src_srcm = gpu_pred_src_srcm_list[0]
-                   pred_dst_dstm = gpu_pred_dst_dstm_list[0]
-                   pred_src_dstm = gpu_pred_src_dstm_list[0]
-                   src_loss = gpu_src_losses[0]
-                   dst_loss = gpu_dst_losses[0]
-                   src_dst_loss_gv = gpu_src_dst_loss_gvs[0]
-               else:
-                   pred_src_src = tf.concat(gpu_pred_src_src_list, 0)
-                   pred_dst_dst = tf.concat(gpu_pred_dst_dst_list, 0)
-                   pred_src_dst = tf.concat(gpu_pred_src_dst_list, 0)
-                   pred_src_srcm = tf.concat(gpu_pred_src_srcm_list, 0)
-                   pred_dst_dstm = tf.concat(gpu_pred_dst_dstm_list, 0)
-                   pred_src_dstm = tf.concat(gpu_pred_src_dstm_list, 0)
-                   src_loss = nn.tf_average_tensor_list(gpu_src_losses)
-                   dst_loss = nn.tf_average_tensor_list(gpu_dst_losses)
-                   src_dst_loss_gv = nn.tf_average_gv_list (gpu_src_dst_loss_gvs)
+               pred_src_src = nn.tf_concat(gpu_pred_src_src_list, 0)
+               pred_dst_dst = nn.tf_concat(gpu_pred_dst_dst_list, 0)
+               pred_src_dst = nn.tf_concat(gpu_pred_src_dst_list, 0)
+               pred_src_srcm = nn.tf_concat(gpu_pred_src_srcm_list, 0)
+               pred_dst_dstm = nn.tf_concat(gpu_pred_dst_dstm_list, 0)
+               pred_src_dstm = nn.tf_concat(gpu_pred_src_dstm_list, 0)
+               src_loss = nn.tf_average_tensor_list(gpu_src_losses)
+               dst_loss = nn.tf_average_tensor_list(gpu_dst_losses)
+               src_dst_loss_gv = nn.tf_average_gv_list (gpu_src_dst_loss_gvs)
                src_dst_loss_gv_op = self.src_dst_opt.get_update_op (src_dst_loss_gv)
            # Initializing training and view functions
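With the per-GPU branches removed above, tower outputs are always concatenated and tower gradients always averaged; with a single GPU the average is a no-op. nn.tf_average_gv_list presumably reduces to something like this sketch:

import tensorflow as tf

def average_gv_list(tower_gvs):
    # tower_gvs: one [(grad, var), ...] list per GPU, variables in the same order
    averaged = []
    for gvs in zip(*tower_gvs):  # group the same variable across towers
        grads = [g for g, _ in gvs]
        averaged.append((tf.add_n(grads) / len(grads), gvs[0][1]))
    return averaged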
@@ -341,17 +330,15 @@ class QModel(ModelBase):
                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
            def AE_merge( warped_dst):
                return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst})
            self.AE_merge = AE_merge
        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
            do_init = self.is_first_run()
            if self.pretrain_just_disabled:
                if model == self.inter:
                    do_init = True
@@ -359,16 +346,15 @@ class QModel(ModelBase):
            if not do_init:
                do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
            if do_init and self.pretrained_model_path is not None:
                pretrained_filepath = self.pretrained_model_path / filename
                if pretrained_filepath.exists():
                    do_init = not model.load_weights(pretrained_filepath)
            if do_init:
                model.init_weights()
        # initializing sample generators
        if self.is_training:
            t = SampleProcessor.Types
            face_type = t.FACE_TYPE_FULL
@@ -384,19 +370,19 @@ class QModel(ModelBase):
            self.set_training_data_generators ([
                    SampleGeneratorFace(training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=True if self.pretrain else False),
-                       output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR), 'resolution':resolution, },
-                                               {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'resolution': resolution, },
-                                               {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution } ],
+                       output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution':resolution, },
+                                               {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution, },
+                                               {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'data_format':nn.data_format, 'resolution': resolution } ],
                        generators_count=src_generators_count ),
                    SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=True if self.pretrain else False),
-                       output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR), 'resolution':resolution},
-                                               {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'resolution': resolution},
-                                               {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution} ],
+                       output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution':resolution},
+                                               {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution},
+                                               {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'data_format':nn.data_format, 'resolution': resolution} ],
                        generators_count=dst_generators_count )
                             ])
        self.last_samples = None
    #override
@@ -408,22 +394,21 @@ class QModel(ModelBase):
        for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False):
            model.save_weights ( self.get_strpath_storage_for_file(filename) )
    #override
    def onTrainOneIter(self):
        if self.get_iter() % 3 == 0 and self.last_samples is not None:
            ( (warped_src, target_src, target_srcm), \
              (warped_dst, target_dst, target_dstm) ) = self.last_samples
            src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm,
                                                     target_dst, target_dst, target_dstm)
        else:
            samples = self.last_samples = self.generate_next_samples()
            ( (warped_src, target_src, target_srcm), \
              (warped_dst, target_dst, target_dstm) ) = samples
            src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm,
                                                     warped_dst, target_dst, target_dstm)
        return ( ('src_loss', src_loss), ('dst_loss', dst_loss), )
    #override
@@ -435,9 +420,11 @@ class QModel(ModelBase):
                             [ [sample[0:n_samples] for sample in sample_list ]
                               for sample_list in samples ]
-       S, D, SS, DD, DDM, SD, SDM = [ np.clip(x, 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
+       S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
        DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ]
+       target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )]
        result = []
        st = []
        for i in range(n_samples):
@@ -456,8 +443,10 @@ class QModel(ModelBase):
        return result
    def predictor_func (self, face=None):
-       bgr, mask_dst_dstm, mask_src_dstm = self.AE_merge (face[np.newaxis,...])
+       face = face[None,...]
+       face = nn.to_data_format(face, self.model_data_format, "NHWC")
+       bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x, "NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face) ]
        mask = mask_dst_dstm[0] * mask_src_dstm[0]
        return bgr[0], mask[...,0]
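nn.to_data_format(x, to_fmt, from_fmt) above moves batches between NHWC (sample generators, previews, merger I/O) and the model's format, now NCHW on GPU builds. A numpy sketch, assuming it is a plain transpose:

import numpy as np

def to_data_format(x, to_fmt, from_fmt):
    if to_fmt == from_fmt:
        return x
    if from_fmt == "NHWC" and to_fmt == "NCHW":
        return np.transpose(x, (0, 3, 1, 2))  # channels move next to batch
    if from_fmt == "NCHW" and to_fmt == "NHWC":
        return np.transpose(x, (0, 2, 3, 1))  # channels move last
    raise ValueError(f"unsupported conversion: {from_fmt} -> {to_fmt}")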
View file
@ -15,25 +15,17 @@ class SAEHDModel(ModelBase):
#override #override
def on_initialize_options(self): def on_initialize_options(self):
device_config = nn.getCurrentDeviceConfig() device_config = nn.getCurrentDeviceConfig()
lowest_vram = 2 lowest_vram = 2
if len(device_config.devices) != 0: if len(device_config.devices) != 0:
lowest_vram = device_config.devices.get_worst_device().total_mem_gb lowest_vram = device_config.devices.get_worst_device().total_mem_gb
if lowest_vram >= 4: if lowest_vram >= 4:
suggest_batch_size = 8 suggest_batch_size = 8
else: else:
suggest_batch_size = 4 suggest_batch_size = 4
yn_str = {True:'y',False:'n'}
ask_override = self.ask_override()
if self.is_first_run() or ask_override: yn_str = {True:'y',False:'n'}
self.ask_enable_autobackup()
self.ask_write_preview_history()
self.ask_target_iter()
self.ask_random_flip()
self.ask_batch_size(suggest_batch_size)
default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 128) default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 128)
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f') default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f')
@ -42,52 +34,63 @@ class SAEHDModel(ModelBase):
default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256) default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64) default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64) default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64)
default_d_mask_dims = default_d_dims // 3 default_d_mask_dims = default_d_dims // 3
default_d_mask_dims += default_d_mask_dims % 2 default_d_mask_dims += default_d_mask_dims % 2
default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims) default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims)
default_use_float16 = self.options['use_float16'] = self.load_or_def_option('use_float16', False)
default_learn_mask = self.options['learn_mask'] = self.load_or_def_option('learn_mask', True) default_learn_mask = self.options['learn_mask'] = self.load_or_def_option('learn_mask', True)
default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', False) default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', False)
default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True) default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True)
default_true_face_training = self.options['true_face_training'] = self.load_or_def_option('true_face_training', False) default_true_face_power = self.options['true_face_power'] = self.load_or_def_option('true_face_power', 0.0)
default_face_style_power = self.options['face_style_power'] = self.load_or_def_option('face_style_power', 0.0) default_face_style_power = self.options['face_style_power'] = self.load_or_def_option('face_style_power', 0.0)
default_bg_style_power = self.options['bg_style_power'] = self.load_or_def_option('bg_style_power', 0.0) default_bg_style_power = self.options['bg_style_power'] = self.load_or_def_option('bg_style_power', 0.0)
default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none') default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none')
default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False) default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False)
default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False) default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False)
ask_override = self.ask_override()
if self.is_first_run() or ask_override:
self.ask_enable_autobackup()
self.ask_write_preview_history()
self.ask_target_iter()
self.ask_random_flip()
self.ask_batch_size(suggest_batch_size)
if self.is_first_run(): if self.is_first_run():
resolution = io.input_int("Resolution", default_resolution, add_info="64-256", help_message="Higher resolution requires more VRAM and time to train. The value will be adjusted to a multiple of 16.") resolution = io.input_int("Resolution", default_resolution, add_info="64-256", help_message="Higher resolution requires more VRAM and time to train. The value will be adjusted to a multiple of 16.")
resolution = np.clip ( (resolution // 16) * 16, 64, 256) resolution = np.clip ( (resolution // 16) * 16, 64, 256)
self.options['resolution'] = resolution self.options['resolution'] = resolution
self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f'], help_message="Half / mid face / full face. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face.").lower() self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f'], help_message="Half / mid face / full face. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face.").lower()
if (self.is_first_run() or ask_override) and len(device_config.devices) == 1:
self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place them on CPU to free up extra VRAM and set bigger dimensions.")
if self.is_first_run():
self.options['archi'] = io.input_str ("AE architecture", default_archi, ['dfhd','liaehd','df','liae'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes. 'hd' is a heavyweight version for the best quality.").lower() #-s version is slower, but has a decreased chance to collapse. self.options['archi'] = io.input_str ("AE architecture", default_archi, ['dfhd','liaehd','df','liae'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes. 'hd' is a heavyweight version for the best quality.").lower() #-s version is slower, but has a decreased chance to collapse.
self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will be packed into AE dims. If the amount of AE dims is not enough, then, for example, closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune the model size to fit your GPU." ), 32, 1024 ) self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will be packed into AE dims. If the amount of AE dims is not enough, then, for example, closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune the model size to fit your GPU." ), 32, 1024 )
e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve a sharper result, but require more VRAM. You can fine-tune the model size to fit your GPU." ), 16, 256 ) e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve a sharper result, but require more VRAM. You can fine-tune the model size to fit your GPU." ), 16, 256 )
self.options['e_dims'] = e_dims + e_dims % 2 self.options['e_dims'] = e_dims + e_dims % 2
d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve a sharper result, but require more VRAM. You can fine-tune the model size to fit your GPU." ), 16, 256 ) d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve a sharper result, but require more VRAM. You can fine-tune the model size to fit your GPU." ), 16, 256 )
self.options['d_dims'] = d_dims + d_dims % 2 self.options['d_dims'] = d_dims + d_dims % 2
d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 ) d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 )
self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2 self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2
if self.is_first_run() or ask_override: if self.is_first_run() or ask_override:
self.options['learn_mask'] = io.input_bool ("Learn mask", default_learn_mask, help_message="Learning the mask can help the model to recognize face directions. Learning without the mask can reduce model size; in that case the merger is forced to use a 'not predicted mask' that is not as smooth as a predicted one.") self.options['learn_mask'] = io.input_bool ("Learn mask", default_learn_mask, help_message="Learning the mask can help the model to recognize face directions. Learning without the mask can reduce model size; in that case the merger is forced to use a 'not predicted mask' that is not as smooth as a predicted one.")
if self.is_first_run() or ask_override:
if len(device_config.devices) == 1:
self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place them on CPU to free up extra VRAM and set bigger dimensions.")
self.options['use_float16'] = io.input_bool ("Use float16", default_use_float16, help_message="Experimental option. Reduces the model size by half. Increases the speed of training. Decreases the accuracy of the model. The model may collapse or not train. The model may not learn the mask at large resolutions.")
self.options['lr_dropout'] = io.input_bool ("Use learning rate dropout", default_lr_dropout, help_message="When the face is trained enough, you can enable this option to get extra sharpness in fewer iterations.") self.options['lr_dropout'] = io.input_bool ("Use learning rate dropout", default_lr_dropout, help_message="When the face is trained enough, you can enable this option to get extra sharpness in fewer iterations.")
self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness in fewer iterations.") self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness in fewer iterations.")
if 'df' in self.options['archi']: if 'df' in self.options['archi']:
self.options['true_face_training'] = io.input_bool ("Enable 'true face' training", default_true_face_training, help_message="The result face will be more like src and will get extra sharpness. Enable it for last 10-20k iterations before conversion.") self.options['true_face_power'] = np.clip ( io.input_number (" 'True face' power.", default_true_face_power, add_info="0.0000 .. 1.0", help_message="Experimental option. Discriminates the result face to be more like the src face. Higher value - stronger discrimination. Comparison - https://i.imgur.com/czScS9q.png"), 0.0, 1.0 )
else: else:
self.options['true_face_training'] = False self.options['true_face_power'] = 0.0
self.options['face_style_power'] = np.clip ( io.input_number("Face style power", default_face_style_power, add_info="0.0..100.0", help_message="Learn to transfer face style details such as light and color conditions. Warning: Enable it only after 10k iters, when the predicted face is clear enough to start learning style. Start from a value of 0.1 and check history changes. Enabling this option increases the chance of model collapse."), 0.0, 100.0 ) self.options['face_style_power'] = np.clip ( io.input_number("Face style power", default_face_style_power, add_info="0.0..100.0", help_message="Learn to transfer face style details such as light and color conditions. Warning: Enable it only after 10k iters, when the predicted face is clear enough to start learning style. Start from a value of 0.1 and check history changes. Enabling this option increases the chance of model collapse."), 0.0, 100.0 )
self.options['bg_style_power'] = np.clip ( io.input_number("Background style power", default_bg_style_power, add_info="0.0..100.0", help_message="Learn to transfer background around face. This can make face more like dst. Enabling this option increases the chance of model collapse."), 0.0, 100.0 ) self.options['bg_style_power'] = np.clip ( io.input_number("Background style power", default_bg_style_power, add_info="0.0..100.0", help_message="Learn to transfer background around face. This can make face more like dst. Enabling this option increases the chance of model collapse."), 0.0, 100.0 )
@ -96,20 +99,24 @@ class SAEHDModel(ModelBase):
self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with a large amount of various faces. After that, the model can be used to train fakes more quickly.") self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with a large amount of various faces. After that, the model can be used to train fakes more quickly.")
if self.options['pretrain'] and self.get_pretraining_data_path() is None: if self.options['pretrain'] and self.get_pretraining_data_path() is None:
raise Exception("pretraining_data_path is not defined") raise Exception("pretraining_data_path is not defined")
self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False) self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
if self.pretrain_just_disabled: if self.pretrain_just_disabled:
self.set_iter(1) self.set_iter(1)
#override #override
def on_initialize(self): def on_initialize(self):
nn.initialize() device_config = nn.getCurrentDeviceConfig()
self.model_data_format = "NCHW" if len(device_config.devices) != 0 else "NHWC"
nn.initialize(floatx="float16" if self.options['use_float16'] else "float32",
data_format=self.model_data_format)
tf = nn.tf tf = nn.tf
conv_kernel_initializer = nn.initializers.ca conv_kernel_initializer = nn.initializers.ca()
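The backend is now configured once, before any layer is built: 'use_float16' switches every weight and activation to half precision, and the layout is NCHW whenever a GPU is present (cuDNN's preferred format), falling back to NHWC on CPU, where many TF kernels only support channels-last. A hedged sketch of that dispatch (the real nn.initialize also sets up the session and helper ops):

def pick_backend_config(devices, use_float16):
    floatx = "float16" if use_float16 else "float32"
    data_format = "NCHW" if len(devices) != 0 else "NHWC"
    return floatx, data_format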
class Downscale(nn.ModelBase): class Downscale(nn.ModelBase):
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ): def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
self.in_ch = in_ch self.in_ch = in_ch
@ -120,19 +127,19 @@ class SAEHDModel(ModelBase):
self.use_activator = use_activator self.use_activator = use_activator
super().__init__(*kwargs) super().__init__(*kwargs)
def on_build(self, *args, **kwargs ): def on_build(self, *args, **kwargs ):
self.conv1 = nn.Conv2D( self.in_ch, self.conv1 = nn.Conv2D( self.in_ch,
self.out_ch // (4 if self.subpixel else 1), self.out_ch // (4 if self.subpixel else 1),
kernel_size=self.kernel_size, kernel_size=self.kernel_size,
strides=1 if self.subpixel else 2, strides=1 if self.subpixel else 2,
padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer ) padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
if self.subpixel: if self.subpixel:
x = tf.nn.space_to_depth(x, 2) x = nn.tf_space_to_depth(x, 2)
if self.use_activator: if self.use_activator:
x = tf.nn.leaky_relu(x, 0.1) x = tf.nn.leaky_relu(x, 0.1)
return x return x
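The switch from tf.nn.space_to_depth to nn.tf_space_to_depth keeps the subpixel downscale working in both layouts. The trick itself: the conv emits only out_ch // 4 channels at full resolution, then space_to_depth folds every 2x2 spatial block into the channel axis, halving H and W while multiplying channels by 4, so the block still produces out_ch channels, just at half the resolution. Shapes for out_ch = 64 in NHWC:

# conv1          : (1, 128, 128, 3)  -> (1, 128, 128, 16)   # out_ch // 4 filters
# space_to_depth : (1, 128, 128, 16) -> (1, 64, 64, 64)     # 2x2 blocks folded into C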
@ -143,19 +150,19 @@ class SAEHDModel(ModelBase):
class DownscaleBlock(nn.ModelBase): class DownscaleBlock(nn.ModelBase):
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True): def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
self.downs = [] self.downs = []
last_ch = in_ch last_ch = in_ch
for i in range(n_downscales): for i in range(n_downscales):
cur_ch = ch*( min(2**i, 8) ) cur_ch = ch*( min(2**i, 8) )
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) ) self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
last_ch = self.downs[-1].get_out_ch() last_ch = self.downs[-1].get_out_ch()
def forward(self, inp): def forward(self, inp):
x = inp x = inp
for down in self.downs: for down in self.downs:
x = down(x) x = down(x)
return x return x
class Upscale(nn.ModelBase): class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3 ): def on_build(self, in_ch, out_ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer) self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
@ -163,7 +170,7 @@ class SAEHDModel(ModelBase):
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
x = tf.nn.leaky_relu(x, 0.1) x = tf.nn.leaky_relu(x, 0.1)
x = tf.nn.depth_to_space(x, 2) x = nn.tf_depth_to_space(x, 2)
return x return x
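Upscale is the mirror operation (subpixel convolution): the conv produces out_ch*4 channels and depth_to_space unfolds them into a 2x larger spatial grid with out_ch channels:

# conv1          : (1, 64, 64, in_ch)    -> (1, 64, 64, out_ch*4)
# depth_to_space : (1, 64, 64, out_ch*4) -> (1, 128, 128, out_ch)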
class ResidualBlock(nn.ModelBase): class ResidualBlock(nn.ModelBase):
@ -192,9 +199,9 @@ class SAEHDModel(ModelBase):
x = tf.nn.leaky_relu(x, 0.2) x = tf.nn.leaky_relu(x, 0.2)
return x, upx return x, upx
class Encoder(nn.ModelBase): class Encoder(nn.ModelBase):
def on_build(self, in_ch, e_ch, is_hd): def on_build(self, in_ch, e_ch, is_hd):
self.is_hd=is_hd self.is_hd=is_hd
if self.is_hd: if self.is_hd:
self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, dilations=1) self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, dilations=1)
self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, dilations=1) self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, dilations=1)
@ -202,7 +209,7 @@ class SAEHDModel(ModelBase):
self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2) self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2)
else: else:
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False) self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False)
def forward(self, inp): def forward(self, inp):
if self.is_hd: if self.is_hd:
x = tf.concat([ nn.tf_flatten(self.down1(inp)), x = tf.concat([ nn.tf_flatten(self.down1(inp)),
@ -211,85 +218,84 @@ class SAEHDModel(ModelBase):
nn.tf_flatten(self.down4(inp)) ], -1 ) nn.tf_flatten(self.down4(inp)) ], -1 )
else: else:
x = nn.tf_flatten(self.down1(inp)) x = nn.tf_flatten(self.down1(inp))
return x return x
class Inter(nn.ModelBase): class Inter(nn.ModelBase):
def __init__(self, in_ch, lowest_dense_res, ae_ch, ae_out_ch, **kwargs): def __init__(self, in_ch, lowest_dense_res, ae_ch, ae_out_ch, **kwargs):
self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch = in_ch, lowest_dense_res, ae_ch, ae_out_ch self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch = in_ch, lowest_dense_res, ae_ch, ae_out_ch
super().__init__(**kwargs) super().__init__(**kwargs)
def on_build(self): def on_build(self):
in_ch, lowest_dense_res, ae_ch, ae_out_ch = self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch in_ch, lowest_dense_res, ae_ch, ae_out_ch = self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch
self.dense1 = nn.Dense( in_ch, ae_ch, kernel_initializer=tf.initializers.orthogonal ) self.dense1 = nn.Dense( in_ch, ae_ch )
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch, kernel_initializer=tf.initializers.orthogonal ) self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
self.upscale1 = Upscale(ae_out_ch, ae_out_ch) self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
def forward(self, inp): def forward(self, inp):
x = self.dense1(inp) x = self.dense1(inp)
x = self.dense2(x) x = self.dense2(x)
x = tf.reshape (x, (-1, lowest_dense_res, lowest_dense_res, self.ae_out_ch)) x = nn.tf_reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
x = self.upscale1(x) x = self.upscale1(x)
return x return x
def get_out_ch(self): def get_out_ch(self):
return self.ae_out_ch return self.ae_out_ch
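Since the dense layers emit a flat vector, Inter must rebuild a 4D tensor before the first Upscale, and the target shape now depends on the layout; that is presumably what nn.tf_reshape_4D hides. A sketch under that assumption (the real helper may also fix up element ordering):

def tf_reshape_4D(x, h, w, c, data_format="NCHW"):
    if data_format == "NCHW":
        return tf.reshape(x, (-1, c, h, w))
    return tf.reshape(x, (-1, h, w, c))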
class Decoder(nn.ModelBase): class Decoder(nn.ModelBase):
def on_build(self, in_ch, d_ch, d_mask_ch, is_hd ): def on_build(self, in_ch, d_ch, d_mask_ch, is_hd ):
self.is_hd = is_hd self.is_hd = is_hd
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3) self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3) self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3) self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
if is_hd: if is_hd:
self.res0 = UpdownResidualBlock(in_ch, d_ch*8, kernel_size=3) self.res0 = UpdownResidualBlock(in_ch, d_ch*8, kernel_size=3)
self.res1 = UpdownResidualBlock(d_ch*8, d_ch*4, kernel_size=3) self.res1 = UpdownResidualBlock(d_ch*8, d_ch*4, kernel_size=3)
self.res2 = UpdownResidualBlock(d_ch*4, d_ch*2, kernel_size=3) self.res2 = UpdownResidualBlock(d_ch*4, d_ch*2, kernel_size=3)
self.res3 = UpdownResidualBlock(d_ch*2, d_ch, kernel_size=3) self.res3 = UpdownResidualBlock(d_ch*2, d_ch, kernel_size=3)
else: else:
self.res0 = ResidualBlock(d_ch*8, kernel_size=3) self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
self.res1 = ResidualBlock(d_ch*4, kernel_size=3) self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
self.res2 = ResidualBlock(d_ch*2, kernel_size=3) self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer) self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3) self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3) self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer) self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
def get_weights_ex(self, include_mask): def get_weights_ex(self, include_mask):
# Call internal get_weights in order to initialize inner logic # Call internal get_weights in order to initialize inner logic
self.get_weights() self.get_weights()
weights = self.upscale0.get_weights() + self.upscale1.get_weights() + self.upscale2.get_weights() \ weights = self.upscale0.get_weights() + self.upscale1.get_weights() + self.upscale2.get_weights() \
+ self.res0.get_weights() + self.res1.get_weights() + self.res2.get_weights() + self.out_conv.get_weights() + self.res0.get_weights() + self.res1.get_weights() + self.res2.get_weights() + self.out_conv.get_weights()
if include_mask: if include_mask:
weights += self.upscalem0.get_weights() + self.upscalem1.get_weights() + self.upscalem2.get_weights() \ weights += self.upscalem0.get_weights() + self.upscalem1.get_weights() + self.upscalem2.get_weights() \
+ self.out_convm.get_weights() + self.out_convm.get_weights()
return weights return weights
def forward(self, inp): def forward(self, inp):
z = inp z = inp
if self.is_hd: if self.is_hd:
x, upx = self.res0(z) x, upx = self.res0(z)
x = self.upscale0(x) x = self.upscale0(x)
x = tf.nn.leaky_relu(x + upx, 0.2) x = tf.nn.leaky_relu(x + upx, 0.2)
x, upx = self.res1(x) x, upx = self.res1(x)
x = self.upscale1(x) x = self.upscale1(x)
x = tf.nn.leaky_relu(x + upx, 0.2) x = tf.nn.leaky_relu(x + upx, 0.2)
x, upx = self.res2(x) x, upx = self.res2(x)
x = self.upscale2(x) x = self.upscale2(x)
x = tf.nn.leaky_relu(x + upx, 0.2) x = tf.nn.leaky_relu(x + upx, 0.2)
x, upx = self.res3(x) x, upx = self.res3(x)
else: else:
x = self.upscale0(z) x = self.upscale0(z)
x = self.res0(x) x = self.res0(x)
@ -301,13 +307,13 @@ class SAEHDModel(ModelBase):
m = self.upscalem0(z) m = self.upscalem0(z)
m = self.upscalem1(m) m = self.upscalem1(m)
m = self.upscalem2(m) m = self.upscalem2(m)
return tf.nn.sigmoid(self.out_conv(x)), \ return tf.nn.sigmoid(self.out_conv(x)), \
tf.nn.sigmoid(self.out_convm(m)) tf.nn.sigmoid(self.out_convm(m))
class CodeDiscriminator(nn.ModelBase): class CodeDiscriminator(nn.ModelBase):
def on_build(self, in_ch, code_res, ch=256): def on_build(self, in_ch, code_res, ch=256):
n_downscales = 2 + code_res // 8 n_downscales = 1 + code_res // 8
self.convs = [] self.convs = []
prev_ch = in_ch prev_ch = in_ch
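The discriminator depth is derived from the code resolution, and with code_res = lowest_dense_res*2 = resolution // 8 the new 1 + code_res // 8 gives:

# resolution 128 -> lowest_dense_res 8  -> code_res 16 -> n_downscales = 1 + 2 = 3
# resolution 256 -> lowest_dense_res 16 -> code_res 32 -> n_downscales = 1 + 4 = 5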
@ -329,12 +335,12 @@ class SAEHDModel(ModelBase):
resolution = self.options['resolution'] resolution = self.options['resolution']
learn_mask = self.options['learn_mask'] learn_mask = self.options['learn_mask']
archi = self.options['archi'] archi = self.options['archi']
ae_dims = self.options['ae_dims'] ae_dims = self.options['ae_dims']
e_dims = self.options['e_dims'] e_dims = self.options['e_dims']
d_dims = self.options['d_dims'] d_dims = self.options['d_dims']
d_mask_dims = self.options['d_mask_dims'] d_mask_dims = self.options['d_mask_dims']
self.pretrain = self.options['pretrain'] self.pretrain = self.options['pretrain']
masked_training = True masked_training = True
models_opt_on_gpu = False if len(devices) != 1 else self.options['models_opt_on_gpu'] models_opt_on_gpu = False if len(devices) != 1 else self.options['models_opt_on_gpu']
@ -343,8 +349,8 @@ class SAEHDModel(ModelBase):
input_nc = 3 input_nc = 3
output_nc = 3 output_nc = 3
bgr_shape = (resolution, resolution, output_nc) bgr_shape = nn.get4Dshape(resolution,resolution,input_nc)
mask_shape = (resolution, resolution, 1) mask_shape = nn.get4Dshape(resolution,resolution,1)
lowest_dense_res = resolution // 16 lowest_dense_res = resolution // 16
self.model_filename_list = [] self.model_filename_list = []
@ -352,24 +358,24 @@ class SAEHDModel(ModelBase):
with tf.device ('/CPU:0'): with tf.device ('/CPU:0'):
#Place holders on CPU #Place holders on CPU
self.warped_src = tf.placeholder (tf.float32, (None,)+bgr_shape) self.warped_src = tf.placeholder (nn.tf_floatx, bgr_shape)
self.warped_dst = tf.placeholder (tf.float32, (None,)+bgr_shape) self.warped_dst = tf.placeholder (nn.tf_floatx, bgr_shape)
self.target_src = tf.placeholder (tf.float32, (None,)+bgr_shape) self.target_src = tf.placeholder (nn.tf_floatx, bgr_shape)
self.target_dst = tf.placeholder (tf.float32, (None,)+bgr_shape) self.target_dst = tf.placeholder (nn.tf_floatx, bgr_shape)
self.target_srcm = tf.placeholder (tf.float32, (None,)+mask_shape) self.target_srcm = tf.placeholder (nn.tf_floatx, mask_shape)
self.target_dstm = tf.placeholder (tf.float32, (None,)+mask_shape) self.target_dstm = tf.placeholder (nn.tf_floatx, mask_shape)
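Placeholders now follow the global configuration instead of hard-coded float32 NHWC: nn.tf_floatx is presumably the active TF dtype and nn.get4Dshape a layout-aware batch shape. A sketch:

def get4Dshape(h, w, c, data_format="NCHW"):
    # Batch dim stays None so any batch size can be fed.
    return (None, c, h, w) if data_format == "NCHW" else (None, h, w, c)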
# Initializing model classes # Initializing model classes
with tf.device (models_opt_device): with tf.device (models_opt_device):
if 'df' in archi: if 'df' in archi:
self.encoder = Encoder(in_ch=input_nc, e_ch=e_dims, is_hd='hd' in archi, name='encoder') self.encoder = Encoder(in_ch=input_nc, e_ch=e_dims, is_hd='hd' in archi, name='encoder')
encoder_out_ch = self.encoder.compute_output_shape ( (tf.float32, (None,resolution,resolution,input_nc)))[-1] encoder_out_ch = self.encoder.compute_output_channels ( (nn.tf_floatx, bgr_shape))
self.inter = Inter (in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter') self.inter = Inter (in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
inter_out_ch = self.inter.compute_output_shape ( (tf.float32, (None,encoder_out_ch)))[-1] inter_out_ch = self.inter.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch)))
self.decoder_src = Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd='hd' in archi, name='decoder_src') self.decoder_src = Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd='hd' in archi, name='decoder_src')
self.decoder_dst = Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd='hd' in archi, name='decoder_dst') self.decoder_dst = Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd='hd' in archi, name='decoder_dst')
@ -379,23 +385,22 @@ class SAEHDModel(ModelBase):
[self.decoder_dst, 'decoder_dst.npy'] ] [self.decoder_dst, 'decoder_dst.npy'] ]
if self.is_training: if self.is_training:
if self.options['true_face_training']: if self.options['true_face_power'] != 0:
self.dis = CodeDiscriminator(ae_dims, code_res=lowest_dense_res*2, name='dis' ) self.dis = CodeDiscriminator(ae_dims, code_res=lowest_dense_res*2, name='dis' )
self.model_filename_list += [ [self.dis, 'dis.npy'] ] self.model_filename_list += [ [self.dis, 'dis.npy'] ]
elif 'liae' in archi: elif 'liae' in archi:
self.encoder = Encoder(in_ch=input_nc, e_ch=e_dims, is_hd='hd' in archi, name='encoder') self.encoder = Encoder(in_ch=input_nc, e_ch=e_dims, is_hd='hd' in archi, name='encoder')
encoder_out_ch = self.encoder.compute_output_shape ( (tf.float32, (None,resolution,resolution,input_nc)))[-1] encoder_out_ch = self.encoder.compute_output_channels ( (nn.tf_floatx, bgr_shape))
self.inter_AB = Inter(in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_AB') self.inter_AB = Inter(in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_AB')
self.inter_B = Inter(in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_B') self.inter_B = Inter(in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_B')
inter_AB_out_ch = self.inter_AB.compute_output_shape ( (tf.float32, (None,encoder_out_ch)))[-1] inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch)))
inter_B_out_ch = self.inter_B.compute_output_shape ( (tf.float32, (None,encoder_out_ch)))[-1] inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch)))
inters_out_ch = inter_AB_out_ch+inter_B_out_ch inters_out_ch = inter_AB_out_ch+inter_B_out_ch
self.decoder = Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd='hd' in archi, name='decoder') self.decoder = Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd='hd' in archi, name='decoder')
self.model_filename_list += [ [self.encoder, 'encoder.npy'], self.model_filename_list += [ [self.encoder, 'encoder.npy'],
[self.inter_AB, 'inter_AB.npy'], [self.inter_AB, 'inter_AB.npy'],
[self.inter_B , 'inter_B.npy'], [self.inter_B , 'inter_B.npy'],
@ -417,8 +422,8 @@ class SAEHDModel(ModelBase):
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights_ex(learn_mask) self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights_ex(learn_mask)
self.src_dst_opt.initialize_variables (self.src_dst_all_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu) self.src_dst_opt.initialize_variables (self.src_dst_all_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu)
if self.options['true_face_training']: if self.options['true_face_power'] != 0:
self.D_opt = nn.TFRMSpropOptimizer(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_opt') self.D_opt = nn.TFRMSpropOptimizer(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_opt')
self.D_opt.initialize_variables ( self.dis.get_weights(), vars_on_cpu=optimizer_vars_on_cpu) self.D_opt.initialize_variables ( self.dis.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
self.model_filename_list += [ (self.D_opt, 'D_opt.npy') ] self.model_filename_list += [ (self.D_opt, 'D_opt.npy') ]
@ -429,7 +434,7 @@ class SAEHDModel(ModelBase):
bs_per_gpu = max(1, self.get_batch_size() // gpu_count) bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
self.set_batch_size( gpu_count*bs_per_gpu) self.set_batch_size( gpu_count*bs_per_gpu)
# Compute losses per GPU # Compute losses per GPU
gpu_pred_src_src_list = [] gpu_pred_src_src_list = []
gpu_pred_dst_dst_list = [] gpu_pred_dst_dst_list = []
@ -462,29 +467,29 @@ class SAEHDModel(ModelBase):
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code) gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
elif 'liae' in archi: elif 'liae' in archi:
gpu_src_code = self.encoder (gpu_warped_src) gpu_src_code = self.encoder (gpu_warped_src)
gpu_src_inter_AB_code = self.inter_AB (gpu_src_code) gpu_src_inter_AB_code = self.inter_AB (gpu_src_code)
gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code],-1) gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis )
gpu_dst_code = self.encoder (gpu_warped_dst) gpu_dst_code = self.encoder (gpu_warped_dst)
gpu_dst_inter_B_code = self.inter_B (gpu_dst_code) gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code) gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code],-1) gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis )
gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code],-1) gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis )
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
gpu_pred_src_src_list.append(gpu_pred_src_src) gpu_pred_src_src_list.append(gpu_pred_src_src)
gpu_pred_dst_dst_list.append(gpu_pred_dst_dst) gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
gpu_pred_src_dst_list.append(gpu_pred_src_dst) gpu_pred_src_dst_list.append(gpu_pred_src_dst)
gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)
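Every channel concat above now uses nn.conv2d_ch_axis instead of a hard-coded -1, since the channel dimension sits at index 1 in NCHW but index 3 in NHWC. Presumably set once during nn.initialize, along the lines of:

conv2d_ch_axis = 1 if data_format == "NCHW" else 3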
gpu_target_srcm_blur = nn.tf_gaussian_blur(gpu_target_srcm, max(1, resolution // 32) ) gpu_target_srcm_blur = nn.tf_gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
gpu_target_dstm_blur = nn.tf_gaussian_blur(gpu_target_dstm, max(1, resolution // 32) ) gpu_target_dstm_blur = nn.tf_gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )
@ -503,28 +508,28 @@ class SAEHDModel(ModelBase):
gpu_src_loss = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_srcmasked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_src_loss = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_srcmasked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_srcmasked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3]) gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_srcmasked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])
if learn_mask: if learn_mask:
gpu_src_loss += tf.reduce_mean ( tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] ) gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
face_style_power = self.options['face_style_power'] / 100.0 face_style_power = self.options['face_style_power'] / 100.0
if face_style_power != 0 and not self.pretrain: if face_style_power != 0 and not self.pretrain:
gpu_src_loss += nn.tf_style_loss(gpu_psd_target_dst_masked, gpu_target_dst_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power) gpu_src_loss += nn.tf_style_loss(gpu_psd_target_dst_masked, gpu_target_dst_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power)
bg_style_power = self.options['bg_style_power'] / 100.0 bg_style_power = self.options['bg_style_power'] / 100.0
if bg_style_power != 0 and not self.pretrain: if bg_style_power != 0 and not self.pretrain:
gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.tf_dssim(gpu_psd_target_dst_anti_masked, gpu_target_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.tf_dssim(gpu_psd_target_dst_anti_masked, gpu_target_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square( gpu_psd_target_dst_anti_masked - gpu_target_dst_anti_masked), axis=[1,2,3] ) gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square( gpu_psd_target_dst_anti_masked - gpu_target_dst_anti_masked), axis=[1,2,3] )
gpu_dst_loss = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) gpu_dst_loss = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3]) gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])
if learn_mask: if learn_mask:
gpu_dst_loss += tf.reduce_mean ( tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )
gpu_src_losses += [gpu_src_loss] gpu_src_losses += [gpu_src_loss]
gpu_dst_losses += [gpu_dst_loss] gpu_dst_losses += [gpu_dst_loss]
gpu_src_dst_loss = gpu_src_loss + gpu_dst_loss gpu_src_dst_loss = gpu_src_loss + gpu_dst_loss
if self.options['true_face_training']: if self.options['true_face_power'] != 0:
def DLoss(labels,logits): def DLoss(labels,logits):
return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3]) return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3])
@ -533,8 +538,8 @@ class SAEHDModel(ModelBase):
gpu_src_code_d_zeros = tf.zeros_like(gpu_src_code_d) gpu_src_code_d_zeros = tf.zeros_like(gpu_src_code_d)
gpu_dst_code_d = self.dis( gpu_dst_code ) gpu_dst_code_d = self.dis( gpu_dst_code )
gpu_dst_code_d_ones = tf.ones_like(gpu_dst_code_d) gpu_dst_code_d_ones = tf.ones_like(gpu_dst_code_d)
gpu_src_dst_loss += 0.01*DLoss(gpu_src_code_d_ones, gpu_src_code_d) gpu_src_dst_loss += self.options['true_face_power']*DLoss(gpu_src_code_d_ones, gpu_src_code_d)
gpu_D_loss = (DLoss(gpu_src_code_d_ones , gpu_dst_code_d) + \ gpu_D_loss = (DLoss(gpu_src_code_d_ones , gpu_dst_code_d) + \
DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5 DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5
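'True face' is a small GAN on the latent codes: per the loss above, the discriminator is trained toward ones on dst codes and zeros on src codes, while the autoencoder gets an extra term, weighted by the user-set power, that pushes D(src_code) toward ones, making the two code distributions harder to tell apart. In outline:

# autoencoder side (replaces the old fixed 0.01 weight):
#   G_loss += true_face_power * BCE(ones, D(src_code))
# discriminator side:
#   D_loss  = 0.5 * ( BCE(ones, D(dst_code)) + BCE(zeros, D(src_code)) )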
@ -546,35 +551,20 @@ class SAEHDModel(ModelBase):
# Average losses and gradients, and create optimizer update ops # Average losses and gradients, and create optimizer update ops
with tf.device (models_opt_device): with tf.device (models_opt_device):
if gpu_count == 1: pred_src_src = nn.tf_concat(gpu_pred_src_src_list, 0)
pred_src_src = gpu_pred_src_src_list[0] pred_dst_dst = nn.tf_concat(gpu_pred_dst_dst_list, 0)
pred_dst_dst = gpu_pred_dst_dst_list[0] pred_src_dst = nn.tf_concat(gpu_pred_src_dst_list, 0)
pred_src_dst = gpu_pred_src_dst_list[0] pred_src_srcm = nn.tf_concat(gpu_pred_src_srcm_list, 0)
pred_src_srcm = gpu_pred_src_srcm_list[0] pred_dst_dstm = nn.tf_concat(gpu_pred_dst_dstm_list, 0)
pred_dst_dstm = gpu_pred_dst_dstm_list[0] pred_src_dstm = nn.tf_concat(gpu_pred_src_dstm_list, 0)
pred_src_dstm = gpu_pred_src_dstm_list[0]
src_loss = gpu_src_losses[0]
dst_loss = gpu_dst_losses[0]
src_dst_loss_gv = gpu_src_dst_loss_gvs[0]
else:
pred_src_src = tf.concat(gpu_pred_src_src_list, 0)
pred_dst_dst = tf.concat(gpu_pred_dst_dst_list, 0)
pred_src_dst = tf.concat(gpu_pred_src_dst_list, 0)
pred_src_srcm = tf.concat(gpu_pred_src_srcm_list, 0)
pred_dst_dstm = tf.concat(gpu_pred_dst_dstm_list, 0)
pred_src_dstm = tf.concat(gpu_pred_src_dstm_list, 0)
src_loss = nn.tf_average_tensor_list(gpu_src_losses)
dst_loss = nn.tf_average_tensor_list(gpu_dst_losses)
src_dst_loss_gv = nn.tf_average_gv_list (gpu_src_dst_loss_gvs)
if self.options['true_face_training']: src_loss = nn.tf_average_tensor_list(gpu_src_losses)
D_loss_gv = nn.tf_average_gv_list(gpu_D_loss_gvs) dst_loss = nn.tf_average_tensor_list(gpu_dst_losses)
src_dst_loss_gv = nn.tf_average_gv_list (gpu_src_dst_loss_gvs)
src_dst_loss_gv_op = self.src_dst_opt.get_update_op (src_dst_loss_gv ) src_dst_loss_gv_op = self.src_dst_opt.get_update_op (src_dst_loss_gv )
if self.options['true_face_training']: if self.options['true_face_power'] != 0:
D_loss_gv = nn.tf_average_gv_list(gpu_D_loss_gvs)
D_loss_gv_op = self.D_opt.get_update_op (D_loss_gv ) D_loss_gv_op = self.D_opt.get_update_op (D_loss_gv )
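The single-GPU special case is gone; nn.tf_concat, nn.tf_average_tensor_list and nn.tf_average_gv_list presumably short-circuit on one-element lists, so a single code path serves any GPU count. Sketch of that behavior:

def tf_average_tensor_list(tensors):
    # Averaging a single tower is a no-op.
    return tensors[0] if len(tensors) == 1 else tf.add_n(tensors) / len(tensors)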
@ -594,7 +584,7 @@ class SAEHDModel(ModelBase):
return s, d return s, d
self.src_dst_train = src_dst_train self.src_dst_train = src_dst_train
if self.options['true_face_training']: if self.options['true_face_power'] != 0:
def D_train(warped_src, warped_dst): def D_train(warped_src, warped_dst):
nn.tf_sess.run ([D_loss_gv_op], feed_dict={self.warped_src: warped_src, self.warped_dst: warped_dst}) nn.tf_sess.run ([D_loss_gv_op], feed_dict={self.warped_src: warped_src, self.warped_dst: warped_dst})
self.D_train = D_train self.D_train = D_train
@ -611,23 +601,23 @@ class SAEHDModel(ModelBase):
self.warped_dst:warped_dst}) self.warped_dst:warped_dst})
self.AE_view = AE_view self.AE_view = AE_view
else: else:
# Initializing merge function # Initializing merge function
with tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'): with tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
if 'df' in archi: if 'df' in archi:
gpu_dst_code = self.inter(self.encoder(self.warped_dst)) gpu_dst_code = self.inter(self.encoder(self.warped_dst))
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
_, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
elif 'liae' in archi: elif 'liae' in archi:
gpu_dst_code = self.encoder (self.warped_dst) gpu_dst_code = self.encoder (self.warped_dst)
gpu_dst_inter_B_code = self.inter_B (gpu_dst_code) gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code) gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code],-1) gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code],-1) gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
_, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
if learn_mask: if learn_mask:
def AE_merge( warped_dst): def AE_merge( warped_dst):
return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst}) return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst})
@ -640,7 +630,7 @@ class SAEHDModel(ModelBase):
# Loading/initializing all models/optimizers weights # Loading/initializing all models/optimizers weights
for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
do_init = self.is_first_run() do_init = self.is_first_run()
if self.pretrain_just_disabled: if self.pretrain_just_disabled:
if 'df' in archi: if 'df' in archi:
if model == self.inter: if model == self.inter:
@ -648,15 +638,15 @@ class SAEHDModel(ModelBase):
elif 'liae' in archi: elif 'liae' in archi:
if model == self.inter_AB: if model == self.inter_AB:
do_init = True do_init = True
if not do_init: if not do_init:
do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) ) do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
if do_init: if do_init:
model.init_weights() model.init_weights()
# initializing sample generators # initializing sample generators
if self.is_training: if self.is_training:
t = SampleProcessor.Types t = SampleProcessor.Types
if self.options['face_type'] == 'h': if self.options['face_type'] == 'h':
@ -670,29 +660,29 @@ class SAEHDModel(ModelBase):
training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path() training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()
random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' and not self.pretrain else None random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' and not self.pretrain else None
t_img_warped = t.IMG_WARPED_TRANSFORMED if self.options['random_warp'] else t.IMG_TRANSFORMED t_img_warped = t.IMG_WARPED_TRANSFORMED if self.options['random_warp'] else t.IMG_TRANSFORMED
cpu_count = multiprocessing.cpu_count() cpu_count = multiprocessing.cpu_count()
src_generators_count = cpu_count // 2 src_generators_count = cpu_count // 2
if self.options['ct_mode'] != 'none': if self.options['ct_mode'] != 'none':
src_generators_count = int(src_generators_count * 1.5) src_generators_count = int(src_generators_count * 1.5)
dst_generators_count = cpu_count - src_generators_count dst_generators_count = cpu_count - src_generators_count
self.set_training_data_generators ([ self.set_training_data_generators ([
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR), 'resolution':resolution, 'ct_mode': self.options['ct_mode'] }, output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution, 'ct_mode': self.options['ct_mode'] },
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'resolution': resolution, 'ct_mode': self.options['ct_mode'] }, {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution, 'ct_mode': self.options['ct_mode'] },
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution } ], {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'data_format':nn.data_format, 'resolution': resolution } ],
generators_count=src_generators_count ), generators_count=src_generators_count ),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR), 'resolution':resolution}, output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution},
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'resolution': resolution}, {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution},
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution} ], {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'data_format':nn.data_format, 'resolution': resolution} ],
generators_count=dst_generators_count ) generators_count=dst_generators_count )
]) ])
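Each output spec is now tagged with nn.data_format, presumably so the worker processes deliver batches already in the network's layout and the per-iteration transpose moves off the training thread. One spec is just a dict:

# {'types'       : (t.IMG_TRANSFORMED, face_type, t.MODE_M),  # transform, crop, mask mode
#  'data_format' : nn.data_format,                            # 'NCHW' on GPU, 'NHWC' on CPU
#  'resolution'  : resolution}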
@ -710,10 +700,10 @@ class SAEHDModel(ModelBase):
def onTrainOneIter(self): def onTrainOneIter(self):
( (warped_src, target_src, target_srcm), \ ( (warped_src, target_src, target_srcm), \
(warped_dst, target_dst, target_dstm) ) = self.generate_next_samples() (warped_dst, target_dst, target_dstm) ) = self.generate_next_samples()
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm) src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm)
if self.options['true_face_training'] and not self.pretrain: if self.options['true_face_power'] != 0 and not self.pretrain:
self.D_train (warped_src, warped_dst) self.D_train (warped_src, warped_dst)
return ( ('src_loss', src_loss), ('dst_loss', dst_loss), ) return ( ('src_loss', src_loss), ('dst_loss', dst_loss), )
@ -728,10 +718,12 @@ class SAEHDModel(ModelBase):
for sample_list in samples ] for sample_list in samples ]
if self.options['learn_mask']: if self.options['learn_mask']:
S, D, SS, DD, DDM, SD, SDM = [ np.clip(x, 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ] S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ] DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ]
else: else:
S, D, SS, DD, SD, = [ np.clip(x, 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ] S, D, SS, DD, SD, = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format) , 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )]
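Previews still composite in NHWC, so model outputs are transposed back and the single-channel masks tiled to three channels before stacking next to the BGR images:

import numpy as np
m  = np.zeros((4, 128, 128, 1), np.float32)   # NHWC mask batch
m3 = np.repeat(m, 3, axis=-1)                 # (4, 128, 128, 3), displayable as BGR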
result = [] result = []
st = [] st = []
@ -753,12 +745,16 @@ class SAEHDModel(ModelBase):
return result return result
def predictor_func (self, face=None): def predictor_func (self, face=None):
face = face[None,...]
face = nn.to_data_format(face, self.model_data_format, "NHWC")
if self.options['learn_mask']: if self.options['learn_mask']:
bgr, mask_dst_dstm, mask_src_dstm = self.AE_merge (face[np.newaxis,...]) bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face) ]
mask = mask_dst_dstm[0] * mask_src_dstm[0] mask = mask_dst_dstm[0] * mask_src_dstm[0]
return bgr[0], mask[...,0] return bgr[0], mask[...,0]
else: else:
bgr, = self.AE_merge (face[np.newaxis,...]) bgr, = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face) ]
return bgr[0] return bgr[0]
#override #override

View file

@ -6,4 +6,4 @@ ffmpeg-python==0.1.17
scikit-image==0.14.2 scikit-image==0.14.2
scipy==1.4.1 scipy==1.4.1
colorama colorama
tensorflow-gpu==1.13.1 tensorflow-gpu==1.13.2

View file

@ -6,4 +6,4 @@ ffmpeg-python==0.1.17
scikit-image==0.14.2 scikit-image==0.14.2
scipy==1.4.1 scipy==1.4.1
colorama colorama
tensorflow-gpu==1.12.0 tensorflow-gpu==1.13.2

View file

@ -136,7 +136,7 @@ class PackedFaceset():
samples_configs = pickle.loads ( f.read(sizeof_samples_bytes) ) samples_configs = pickle.loads ( f.read(sizeof_samples_bytes) )
samples = [] samples = []
for sample_config in samples_configs: for sample_config in samples_configs:
sample_config = pickle.loads(pickle.dumps (sample_config)) sample_config = pickle.loads(pickle.dumps (sample_config))
samples.append ( Sample (**sample_config) ) samples.append ( Sample (**sample_config) )
offsets = [ struct.unpack("Q", f.read(8) )[0] for _ in range(len(samples)+1) ] offsets = [ struct.unpack("Q", f.read(8) )[0] for _ in range(len(samples)+1) ]
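The pickle round-trip added above is presumably a cheap deep copy, so each Sample owns its config dict instead of aliasing the one unpickled from the pack; for plain dicts it is often faster than the equivalent

import copy
sample_config = copy.deepcopy(sample_config)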

View file

@ -31,7 +31,7 @@ class Sample(object):
'source_filename', 'source_filename',
'person_name', 'person_name',
'pitch_yaw_roll', 'pitch_yaw_roll',
'_filename_offset_size', '_filename_offset_size',
] ]
def __init__(self, sample_type=None, def __init__(self, sample_type=None,
@ -39,10 +39,10 @@ class Sample(object):
face_type=None, face_type=None,
shape=None, shape=None,
landmarks=None, landmarks=None,
ie_polys=None, ie_polys=None,
eyebrows_expand_mod=None, eyebrows_expand_mod=None,
source_filename=None, source_filename=None,
person_name=None, person_name=None,
pitch_yaw_roll=None, pitch_yaw_roll=None,
**kwargs): **kwargs):
@ -55,15 +55,15 @@ class Sample(object):
self.eyebrows_expand_mod = eyebrows_expand_mod self.eyebrows_expand_mod = eyebrows_expand_mod
self.source_filename = source_filename self.source_filename = source_filename
self.person_name = person_name self.person_name = person_name
self.pitch_yaw_roll = pitch_yaw_roll self.pitch_yaw_roll = pitch_yaw_roll
self._filename_offset_size = None self._filename_offset_size = None
def get_pitch_yaw_roll(self): def get_pitch_yaw_roll(self):
if self.pitch_yaw_roll is None: if self.pitch_yaw_roll is None:
self.pitch_yaw_roll = LandmarksProcessor.estimate_pitch_yaw_roll(self.landmarks) self.pitch_yaw_roll = LandmarksProcessor.estimate_pitch_yaw_roll(self.landmarks)
return self.pitch_yaw_roll return self.pitch_yaw_roll
def set_filename_offset_size(self, filename, offset, size): def set_filename_offset_size(self, filename, offset, size):
self._filename_offset_size = (filename, offset, size) self._filename_offset_size = (filename, offset, size)

View file

@ -14,11 +14,11 @@ from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
class SampleGeneratorFaceTemporal(SampleGeneratorBase): class SampleGeneratorFaceTemporal(SampleGeneratorBase):
def __init__ (self, samples_path, debug, batch_size, def __init__ (self, samples_path, debug, batch_size,
temporal_image_count=3, temporal_image_count=3,
sample_process_options=SampleProcessor.Options(), sample_process_options=SampleProcessor.Options(),
output_sample_types=[], output_sample_types=[],
generators_count=2, generators_count=2,
**kwargs): **kwargs):
super().__init__(samples_path, debug, batch_size) super().__init__(samples_path, debug, batch_size)
@ -35,11 +35,11 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
samples_len = len(samples) samples_len = len(samples)
if samples_len == 0: if samples_len == 0:
raise ValueError('No training data provided.') raise ValueError('No training data provided.')
mult_max = 1 mult_max = 1
l = samples_len - ( (self.temporal_image_count)*mult_max - (mult_max-1) ) l = samples_len - ( (self.temporal_image_count)*mult_max - (mult_max-1) )
index_host = mplib.IndexHost(l+1) index_host = mplib.IndexHost(l+1)
pickled_samples = pickle.dumps(samples, 4) pickled_samples = pickle.dumps(samples, 4)
if self.debug: if self.debug:
self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(),) )] self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(),) )]
@ -64,9 +64,9 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
while True: while True:
batches = None batches = None
indexes = index_host.multi_get(bs) indexes = index_host.multi_get(bs)
for n_batch in range(self.batch_size): for n_batch in range(self.batch_size):
idx = indexes[n_batch] idx = indexes[n_batch]

View file

@ -46,7 +46,7 @@ class SampleGeneratorImageTemporal(SampleGeneratorBase):
mult_max = 4 mult_max = 4
samples_sub_len = samples_len - ( (self.temporal_image_count)*mult_max - (mult_max-1) ) samples_sub_len = samples_len - ( (self.temporal_image_count)*mult_max - (mult_max-1) )
if samples_sub_len <= 0: if samples_sub_len <= 0:
raise ValueError('Not enough samples to fit temporal line.') raise ValueError('Not enough samples to fit temporal line.')

View file

@@ -15,10 +15,6 @@ from .Sample import Sample, SampleType
class SampleHost:
    samples_cache = dict()
    @staticmethod
    def get_person_id_max_count(samples_path):
@@ -47,7 +43,7 @@ class SampleHost:
if sample_type == SampleType.IMAGE:
    if samples[sample_type] is None:
        samples[sample_type] = [ Sample(filename=filename) for filename in io.progress_bar_generator( pathex.get_image_paths(samples_path), "Loading") ]
elif sample_type == SampleType.FACE:
    if samples[sample_type] is None:
        try:
@@ -61,12 +57,12 @@ class SampleHost:
if result is None:
    result = SampleHost.load_face_samples( pathex.get_image_paths(samples_path) )
samples[sample_type] = result
elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
    result = SampleHost.load (SampleType.FACE, samples_path)
    result = SampleHost.upgradeToFaceTemporalSortedSamples(result)
    samples[sample_type] = result
return samples[sample_type]
@staticmethod
@@ -92,17 +88,17 @@ class SampleHost:
source_filename=source_filename,
))
return sample_list
"""
@staticmethod
def load_face_samples ( image_paths):
    sample_list = []
    for filename in io.progress_bar_generator (image_paths, desc="Loading"):
        dflimg = DFLIMG.load (Path(filename))
        if dflimg is None:
            io.log_err (f"{filename} is not a dfl image file.")
        else:
            sample_list.append( Sample(filename=filename,
                                       sample_type=SampleType.FACE,
                                       face_type=FaceType.fromString ( dflimg.get_face_type() ),
@@ -114,15 +110,15 @@ class SampleHost:
                                       ))
    return sample_list
"""
@staticmethod
def upgradeToFaceTemporalSortedSamples( samples ):
    new_s = [ (s, s.source_filename) for s in samples]
    new_s = sorted(new_s, key=operator.itemgetter(1))
    return [ s[0] for s in new_s]
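upgradeToFaceTemporalSortedSamples simply orders samples by their original frame filename, so neighbouring samples correspond to neighbouring video frames. A self-contained sketch of the same idea (the Sample stand-in and filenames are hypothetical):

import operator

class FakeSample:                                    # stand-in for Sample, illustration only
    def __init__(self, source_filename):
        self.source_filename = source_filename

samples = [FakeSample('frame_0002.png'), FakeSample('frame_0001.png'), FakeSample('frame_0003.png')]
new_s = sorted( ((s, s.source_filename) for s in samples), key=operator.itemgetter(1) )
print( [ s[0].source_filename for s in new_s ] )     # ['frame_0001.png', 'frame_0002.png', 'frame_0003.png']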
class FaceSamplesLoaderSubprocessor(Subprocessor):
    #override
    def __init__(self, image_paths ):

View file

@@ -37,7 +37,7 @@ opts:
'resolution' : N
'motion_blur' : (chance_int, range) - chance 0..100 to apply to face (not mask), and max_size of motion blur
'ct_mode' :
'normalize_tanh' : bool
"""
@@ -94,11 +94,11 @@ class SampleProcessor(object):
@staticmethod
def process (samples, sample_process_options, output_sample_types, debug, ct_sample=None):
    SPTF = SampleProcessor.Types
    sample_rnd_seed = np.random.randint(0x80000000)
    outputs = []
    for sample in samples:
        sample_bgr = sample.load_bgr()
        ct_sample_bgr = None
        ct_sample_mask = None
@@ -123,9 +123,11 @@ class SampleProcessor(object):
normalize_vgg = opts.get('normalize_vgg', False)
motion_blur = opts.get('motion_blur', None)
gaussian_blur = opts.get('gaussian_blur', None)
ct_mode = opts.get('ct_mode', 'None')
normalize_tanh = opts.get('normalize_tanh', False)
data_format = opts.get('data_format', 'NHWC')
img_type = SPTF.NONE
target_face_type = SPTF.NONE
@@ -149,7 +151,7 @@ class SampleProcessor(object):
img = l
elif img_type == SPTF.IMG_PITCH_YAW_ROLL or img_type == SPTF.IMG_PITCH_YAW_ROLL_SIGMOID:
    pitch_yaw_roll = sample.get_pitch_yaw_roll()
    if params['flip']:
        yaw = -yaw
@@ -174,7 +176,7 @@ class SampleProcessor(object):
if len(mask.shape) == 2:
    mask = mask[...,np.newaxis]
return img, mask
img = sample_bgr
@@ -202,7 +204,7 @@ class SampleProcessor(object):
if gaussian_blur is not None:
    chance, kernel_max_size = gaussian_blur
    chance = np.clip(chance, 0, 100)
    if np.random.randint(100) < chance:
        img = cv2.GaussianBlur(img, ( np.random.randint( kernel_max_size )*2+1 ,) *2 , 0)
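One detail worth noting in the gaussian_blur branch: `np.random.randint(kernel_max_size)*2+1` always yields an odd kernel size, which cv2.GaussianBlur requires. The same construction in isolation (the kernel_max_size value is hypothetical):

import numpy as np
import cv2

kernel_max_size = 3                              # hypothetical opts value
k = np.random.randint(kernel_max_size)*2 + 1     # always odd: 1, 3 or 5
img = np.zeros( (64, 64, 3), np.float32 )
img = cv2.GaussianBlur(img, (k, k), 0)           # odd ksize keeps the kernel centered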
@@ -221,7 +223,7 @@ class SampleProcessor(object):
img = cv2.resize( img, (resolution,resolution), cv2.INTER_CUBIC )
else:
    img, mask = do_transform (img, mask)
    mat = LandmarksProcessor.get_transform_mat (sample.landmarks, resolution, target_ft)
    img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=(cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT), flags=cv2.INTER_CUBIC )
    mask = cv2.warpAffine( mask, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_CUBIC )
@@ -256,7 +258,7 @@ class SampleProcessor(object):
    img_bgr = imagelib.reinhard_color_transfer ( np.clip( (img_bgr*255).astype(np.uint8), 0, 255),
                                                 np.clip( (ct_sample_bgr_resized*255).astype(np.uint8), 0, 255) )
    img_bgr = np.clip( img_bgr.astype(np.float32) / 255.0, 0.0, 1.0)
elif ct_mode == 'mkl':
    img_bgr = imagelib.color_transfer_mkl (img_bgr, ct_sample_bgr_resized)
elif ct_mode == 'idt':
    img_bgr = imagelib.color_transfer_idt (img_bgr, ct_sample_bgr_resized)
@@ -271,21 +273,21 @@ class SampleProcessor(object):
img_bgr[:,:,0] -= 103.939
img_bgr[:,:,1] -= 116.779
img_bgr[:,:,2] -= 123.68
if mode_type == SPTF.MODE_BGR:
    img = img_bgr
elif mode_type == SPTF.MODE_BGR_SHUFFLE:
    rnd_state = np.random.RandomState (sample_rnd_seed)
    img = np.take (img_bgr, rnd_state.permutation(img_bgr.shape[-1]), axis=-1)
elif mode_type == SPTF.MODE_BGR_RANDOM_HSV_SHIFT:
    rnd_state = np.random.RandomState (sample_rnd_seed)
    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
    h, s, v = cv2.split(hsv)
    h = (h + rnd_state.randint(360) ) % 360
    s = np.clip ( s + rnd_state.random()-0.5, 0, 1 )
    v = np.clip ( v + rnd_state.random()-0.5, 0, 1 )
    hsv = cv2.merge([h, s, v])
    img = np.clip( cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) , 0, 1 )
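Both MODE_BGR_SHUFFLE and MODE_BGR_RANDOM_HSV_SHIFT rebuild a RandomState from the same sample_rnd_seed drawn once at the top of process(), which appears intended to make every output type of a sample receive the identical color augmentation. A minimal sketch of that property:

import numpy as np

sample_rnd_seed = 12345                          # hypothetical; drawn once per process() call
perm_a = np.random.RandomState(sample_rnd_seed).permutation(3)
perm_b = np.random.RandomState(sample_rnd_seed).permutation(3)
assert (perm_a == perm_b).all()                  # same seed -> same channel shuffle every time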
elif mode_type == SPTF.MODE_G:
    img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)[...,None]
@@ -300,9 +302,13 @@ class SampleProcessor(object):
else:
    img = np.clip (img, 0.0, 1.0)
if data_format == "NCHW":
    img = np.transpose(img, (2,0,1) )
outputs_sample.append ( img )
outputs += [outputs_sample]
return outputs
"""