removing trailing spaces

iperov 2019-03-19 23:53:27 +04:00
parent fa4e579b95
commit a3df04999c
61 changed files with 2110 additions and 2103 deletions

View file

@@ -7,25 +7,25 @@ class Converter(object):
    TYPE_FACE = 0                 #calls convert_face
    TYPE_IMAGE = 1                #calls convert_image without landmarks
    TYPE_IMAGE_WITH_LANDMARKS = 2 #calls convert_image with landmarks

    #overridable
    def __init__(self, predictor_func, type):
        self.predictor_func = predictor_func
        self.type = type

    #overridable
    def convert_face (self, img_bgr, img_face_landmarks, debug):
        #return float32 image
        #if debug, return tuple ( images of any size and channels, ...)
        return image

    #overridable
    def convert_image (self, img_bgr, img_landmarks, debug):
        #img_landmarks not None, if input image is png with embedded data
        #return float32 image
        #if debug, return tuple ( images of any size and channels, ...)
        return image

    #overridable
    def dummy_predict(self):
        #do dummy predict here
@@ -33,8 +33,8 @@ class Converter(object):
    def copy(self):
        return copy.copy(self)

    def copy_and_set_predictor(self, predictor_func):
        result = self.copy()
        result.predictor_func = predictor_func
        return result
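
The overridable contract above expects convert_face / convert_image to return a float32 image, or a tuple of debug images when debug is set. A minimal sketch of a conforming subclass; the class name and passthrough body are hypothetical, not part of this commit:

import numpy as np

class PassthroughConverter(Converter):      #hypothetical example
    def __init__(self, predictor_func):
        super().__init__(predictor_func, Converter.TYPE_FACE)

    #override
    def convert_face (self, img_bgr, img_face_landmarks, debug):
        out = self.predictor_func (img_bgr.astype(np.float32))
        if debug:
            return (img_bgr, out)   #debug mode may return images of any size and channels
        return out                  #float32 image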

View file

@@ -7,7 +7,7 @@ import numpy as np
from utils import image_utils

'''
predictor_func:
    input:  [predictor_input_size, predictor_input_size, BGR]
    output: [predictor_input_size, predictor_input_size, BGR]
'''
@@ -16,18 +16,18 @@ class ConverterImage(Converter):
    #override
    def __init__(self,  predictor_func,
                        predictor_input_size=0,
                        output_size=0):

        super().__init__(predictor_func, Converter.TYPE_IMAGE)

        self.predictor_input_size = predictor_input_size
        self.output_size = output_size

    #override
    def dummy_predict(self):
        self.predictor_func ( np.zeros ( (self.predictor_input_size, self.predictor_input_size, 3), dtype=np.float32) )

    #override
    def convert_image (self, img_bgr, img_landmarks, debug):
        img_size = img_bgr.shape[1], img_bgr.shape[0]
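
A minimal predictor_func sketch matching the docstring contract above; the identity body is a hypothetical stand-in for a real model's predict call:

import numpy as np

def identity_predictor(img):    #img: [predictor_input_size, predictor_input_size, BGR]
    assert img.shape[0] == img.shape[1] and img.shape[2] == 3
    return img                  #same shape out, per the contract

converter = ConverterImage (identity_predictor, predictor_input_size=128, output_size=128)
converter.dummy_predict()       #feeds a zero image through the predictor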

View file

@@ -4,36 +4,36 @@ import cv2
from pathlib import Path

class DLIBExtractor(object):
    def __init__(self, dlib):
        self.scale_to = 1850
        #3100 eats ~1.687GB VRAM on 2GB 730 desktop card, but >4GB on 6GB card,
        #but 3100 doesn't work on 2GB 850M notebook card, I can't understand this behaviour
        #1850 works on 2GB 850M notebook card, works faster than 3100, produces good result
        self.dlib = dlib

    def __enter__(self):
        self.dlib_cnn_face_detector = self.dlib.cnn_face_detection_model_v1( str(Path(__file__).parent / "mmod_human_face_detector.dat") )
        self.dlib_cnn_face_detector ( np.zeros ( (self.scale_to, self.scale_to, 3), dtype=np.uint8), 0 )
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        del self.dlib_cnn_face_detector
        return False #pass exception between __enter__ and __exit__ to outer level

    def extract_from_bgr (self, input_image):
        input_image = input_image[:,:,::-1].copy()
        (h, w, ch) = input_image.shape

        detected_faces = []
        input_scale = self.scale_to / (w if w > h else h)
        input_image = cv2.resize (input_image, ( int(w*input_scale), int(h*input_scale) ), interpolation=cv2.INTER_LINEAR)
        detected_faces = self.dlib_cnn_face_detector(input_image, 0)

        result = []
        for d_rect in detected_faces:
            if type(d_rect) == self.dlib.mmod_rectangle:
                d_rect = d_rect.rect
            left, top, right, bottom = d_rect.left(), d_rect.top(), d_rect.right(), d_rect.bottom()
            result.append ( (int(left/input_scale), int(top/input_scale), int(right/input_scale), int(bottom/input_scale)) )
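
A toy check of the rescaling arithmetic above, with assumed dimensions: the longest image side is scaled to scale_to before detection, and each detected rect is divided by input_scale to map back to source coordinates.

scale_to = 1850
w, h = 3700, 1850                                       #assumed input size
input_scale = scale_to / (w if w > h else h)            #0.5
left, top, right, bottom = 100, 40, 300, 240            #hypothetical detector rect
orig = tuple(int(v / input_scale) for v in (left, top, right, bottom))
assert orig == (200, 80, 600, 480)                      #back in original coordinates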

View file

@@ -8,16 +8,16 @@ from interact import interact as io

class FANSegmentator(object):
    def __init__ (self, resolution, face_type_str, load_weights=True, weights_file_root=None):
        exec( nnlib.import_all(), locals(), globals() )

        self.model = FANSegmentator.BuildModel(resolution, ngf=32)

        if weights_file_root:
            weights_file_root = Path(weights_file_root)
        else:
            weights_file_root = Path(__file__).parent

        self.weights_path = weights_file_root / ('FANSeg_%d_%s.h5' % (resolution, face_type_str) )

        if load_weights:
            self.model.load_weights (str(self.weights_path))
        else:
@@ -31,19 +31,19 @@ class FANSegmentator(object):
    def __enter__(self):
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        return False #pass exception between __enter__ and __exit__ to outer level

    def save_weights(self):
        self.model.save_weights (str(self.weights_path))

    def train_on_batch(self, inp, outp):
        return self.model.train_on_batch(inp, outp)

    def extract_from_bgr (self, input_image):
        return np.clip ( (self.model.predict(input_image) + 1) / 2.0, 0, 1.0 )

    @staticmethod
    def BuildModel ( resolution, ngf=64):
        exec( nnlib.import_all(), locals(), globals() )
@@ -53,7 +53,7 @@ class FANSegmentator(object):
        x = FANSegmentator.DecFlow(ngf=ngf)(x)
        model = Model(inp,x)
        return model

    @staticmethod
    def EncFlow(ngf=64, num_downs=4):
        exec( nnlib.import_all(), locals(), globals() )
@@ -65,19 +65,19 @@ class FANSegmentator(object):
        def downscale (dim):
            def func(x):
                return LeakyReLU(0.1)(XNormalization(Conv2D(dim, kernel_size=5, strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02))(x)))
            return func

        def func(input):
            x = input

            result = []
            for i in range(num_downs):
                x = downscale ( min(ngf*(2**i), ngf*8) )(x)
                result += [x]
            return result
        return func

    @staticmethod
    def DecFlow(output_nc=1, ngf=64, activation='tanh'):
        exec (nnlib.import_all(), locals(), globals())
@@ -85,23 +85,23 @@ class FANSegmentator(object):
        use_bias = True
        def XNormalization(x):
            return InstanceNormalization (axis=3, gamma_initializer=RandomNormal(1., 0.02))(x)

        def Conv2D (filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=use_bias, kernel_initializer=RandomNormal(0, 0.02), bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
            return keras.layers.Conv2D( filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )

        def upscale (dim):
            def func(x):
                return SubpixelUpscaler()( LeakyReLU(0.1)(XNormalization(Conv2D(dim, kernel_size=3, strides=1, padding='same', kernel_initializer=RandomNormal(0, 0.02))(x))))
            return func

        def func(input):
            input_len = len(input)
            x = input[input_len-1]
            for i in range(input_len-1, -1, -1):
                x = upscale( min(ngf* (2**i) *4, ngf*8 *4 ) )(x)
                if i != 0:
                    x = Concatenate(axis=3)([ input[i-1] , x])

            return Conv2D(output_nc, 3, 1, 'same', activation=activation)(x)
        return func
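
The extract_from_bgr mapping above converts the decoder's tanh output in [-1, 1] to a [0, 1] mask via (x + 1) / 2 with a clip. A standalone numeric check:

import numpy as np

tanh_out = np.array([-1.0, 0.0, 1.0], dtype=np.float32)   #model output range
mask = np.clip((tanh_out + 1) / 2.0, 0, 1.0)
assert np.allclose(mask, [0.0, 0.5, 1.0])                 #valid mask range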

View file

@@ -3,7 +3,7 @@ from enum import IntEnum

class FaceType(IntEnum):
    HALF = 0,
    FULL = 1,
    HEAD = 2,
    AVATAR = 3,    #centered nose only
    MARK_ONLY = 4, #no align at all, just embedded faceinfo
    QTY = 5
@@ -13,12 +13,12 @@ class FaceType(IntEnum):
        r = from_string_dict.get (s.lower())
        if r is None:
            raise Exception ('FaceType.fromString value error')
        return r

    @staticmethod
    def toString (face_type):
        return to_string_list[face_type]

from_string_dict = {'half_face': FaceType.HALF,
                    'full_face': FaceType.FULL,
                    'head' : FaceType.HEAD,
@@ -29,6 +29,5 @@ to_string_list = [ 'half_face',
                   'full_face',
                   'head',
                   'avatar',
                   'mark_only'
                 ]
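
A quick round-trip sketch of the mapping above; both directions rely on the module-level tables from_string_dict and to_string_list:

ft = FaceType.fromString('full_face')          #looks up from_string_dict
assert ft == FaceType.FULL
assert FaceType.toString(ft) == 'full_face'    #indexes to_string_list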

View file

@@ -10,38 +10,38 @@ class LandmarksExtractor(object):
    def __init__ (self, keras):
        self.keras = keras
        K = self.keras.backend

    def __enter__(self):
        keras_model_path = Path(__file__).parent / "2DFAN-4.h5"
        if not keras_model_path.exists():
            return None

        self.keras_model = self.keras.models.load_model (str(keras_model_path))

        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        del self.keras_model
        return False #pass exception between __enter__ and __exit__ to outer level

    def extract_from_bgr (self, input_image, rects, second_pass_extractor=None):
        input_image = input_image[:,:,::-1].copy()
        (h, w, ch) = input_image.shape

        landmarks = []
        for (left, top, right, bottom) in rects:
            try:
                center = np.array( [ (left + right) / 2.0, (top + bottom) / 2.0] )
                #center[1] -= (bottom - top) * 0.12
                scale = (right - left + bottom - top) / 195.0

                image = self.crop(input_image, center, scale).astype(np.float32)
                image = np.expand_dims(image, 0)

                predicted = self.keras_model.predict (image).transpose (0,3,1,2)

                pts_img = self.get_pts_from_predict ( predicted[-1], center, scale)
                pts_img = [ ( int(pt[0]), int(pt[1]) ) for pt in pts_img ]
                landmarks.append ( ( (left, top, right, bottom), pts_img ) )
            except Exception as e:
                landmarks.append ( ( (left, top, right, bottom), None ) )
@@ -52,26 +52,26 @@ class LandmarksExtractor(object):
                rect, lmrks = landmarks[i]
                if lmrks is None:
                    continue

                image_to_face_mat = LandmarksProcessor.get_transform_mat (lmrks, 256, FaceType.FULL)
                face_image = cv2.warpAffine(input_image, image_to_face_mat, (256, 256), cv2.INTER_CUBIC)

                rects2 = second_pass_extractor.extract_from_bgr(face_image)
                if len(rects2) != 1: #don't do the second pass if zero or more than one face is detected in the cropped image
                    continue

                rect2 = rects2[0]

                lmrks2 = self.extract_from_bgr (face_image, [rect2] )[0][1]
                source_lmrks2 = LandmarksProcessor.transform_points (lmrks2, image_to_face_mat, True)
                landmarks[i] = (rect, source_lmrks2)
            except:
                continue

        return landmarks

    def transform(self, point, center, scale, resolution):
        pt = np.array ( [point[0], point[1], 1.0] )
        h = 200.0 * scale
        m = np.eye(3)
        m[0,0] = resolution / h
@@ -80,11 +80,11 @@ class LandmarksExtractor(object):
        m[1,2] = resolution * ( -center[1] / h + 0.5 )
        m = np.linalg.inv(m)
        return np.matmul (m, pt)[0:2]

    def crop(self, image, center, scale, resolution=256.0):
        ul = self.transform([1, 1], center, scale, resolution).astype( np.int )
        br = self.transform([resolution, resolution], center, scale, resolution).astype( np.int )

        if image.ndim > 2:
            newDim = np.array([br[1] - ul[1], br[0] - ul[0], image.shape[2]], dtype=np.int32)
            newImg = np.zeros(newDim, dtype=np.uint8)
@@ -98,14 +98,14 @@ class LandmarksExtractor(object):
        oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
        oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
        newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1] ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]

        newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)), interpolation=cv2.INTER_LINEAR)
        return newImg

    def get_pts_from_predict(self, a, center, scale):
        b = a.reshape ( (a.shape[0], a.shape[1]*a.shape[2]) )
        c = b.argmax(1).reshape ( (a.shape[0], 1) ).repeat(2, axis=1).astype(np.float)
        c[:,0] %= a.shape[2]
        c[:,1] = np.apply_along_axis ( lambda x: np.floor(x / a.shape[2]), 0, c[:,1] )

        for i in range(a.shape[0]):
@@ -113,6 +113,6 @@ class LandmarksExtractor(object):
            if pX > 0 and pX < 63 and pY > 0 and pY < 63:
                diff = np.array ( [a[i,pY,pX+1]-a[i,pY,pX-1], a[i,pY+1,pX]-a[i,pY-1,pX]] )
                c[i] += np.sign(diff)*0.25

        c += 0.5

        return [ self.transform (c[i], center, scale, a.shape[2]) for i in range(a.shape[0]) ]
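
For intuition, a self-contained sketch of the heatmap decoding performed by get_pts_from_predict above (shapes assumed: 68 heatmaps of 64x64, matching the 2DFAN-4 output): the argmax over each flattened heatmap gives a (col, row) peak, which the method then nudges a quarter pixel toward the larger neighbour before mapping back to image coordinates.

import numpy as np

heatmaps = np.zeros((68, 64, 64), dtype=np.float32)   #one heatmap per landmark (assumed shape)
heatmaps[0, 30, 40] = 1.0                             #synthetic peak for landmark 0
flat = heatmaps.reshape(68, -1)
idx = flat.argmax(1)
col, row = idx % 64, idx // 64                        #same row/col split as c[:,0] and c[:,1]
assert (col[0], row[0]) == (40, 30)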

View file

@@ -36,7 +36,7 @@ landmarks_68_pt = { "mouth": (48,68),
                    "left_eye": (42, 48),
                    "nose": (27, 36), # missed one point
                    "jaw": (0, 17) }

landmarks_68_3D = np.array( [
    [-73.393523 , -29.801432 , 47.667532 ],
@@ -107,20 +107,20 @@ landmarks_68_3D = np.array( [
    [8.449166   , 30.596216  , -20.671489 ],
    [0.205322   , 31.408738  , -21.903670 ],
    [-7.198266  , 30.844876  , -20.328022 ] ], dtype=np.float32)

def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0):
    if not isinstance(image_landmarks, np.ndarray):
        image_landmarks = np.array (image_landmarks)

    if face_type == FaceType.AVATAR:
        centroid = np.mean (image_landmarks, axis=0)

        mat = umeyama(image_landmarks[17:], landmarks_2D, True)[0:2]
        a, c = mat[0,0], mat[1,0]
        scale = math.sqrt((a * a) + (c * c))

        padding = (output_size / 64) * 32

        mat = np.eye ( 2,3 )
        mat[0,2] = -centroid[0]
        mat[1,2] = -centroid[1]
@@ -135,15 +135,15 @@ def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0):
        padding = (output_size / 64) * 24
    else:
        raise ValueError ('wrong face_type: ', face_type)

    mat = umeyama(image_landmarks[17:], landmarks_2D, True)[0:2]
    mat = mat * (output_size - 2 * padding)
    mat[:,2] += padding
    mat *= (1 / scale)
    mat[:,2] += -output_size*( ( (1 / scale) - 1.0 ) / 2 )

    return mat

def transform_points(points, mat, invert=False):
    if invert:
        mat = cv2.invertAffineTransform (mat)
@@ -151,68 +151,68 @@ def transform_points(points, mat, invert=False):
    points = cv2.transform(points, mat, points.shape)
    points = np.squeeze(points)
    return points

def get_image_hull_mask (image_shape, image_landmarks):
    if len(image_landmarks) != 68:
        raise Exception('get_image_hull_mask works only with 68 landmarks')

    hull_mask = np.zeros(image_shape[0:2]+(1,),dtype=np.float32)

    cv2.fillConvexPoly( hull_mask, cv2.convexHull(
        np.concatenate ( (image_landmarks[0:9],
                          image_landmarks[17:18]))) , (1,) )

    cv2.fillConvexPoly( hull_mask, cv2.convexHull(
        np.concatenate ( (image_landmarks[8:17],
                          image_landmarks[26:27]))) , (1,) )

    cv2.fillConvexPoly( hull_mask, cv2.convexHull(
        np.concatenate ( (image_landmarks[17:20],
                          image_landmarks[8:9]))) , (1,) )

    cv2.fillConvexPoly( hull_mask, cv2.convexHull(
        np.concatenate ( (image_landmarks[24:27],
                          image_landmarks[8:9]))) , (1,) )

    cv2.fillConvexPoly( hull_mask, cv2.convexHull(
        np.concatenate ( (image_landmarks[19:25],
                          image_landmarks[8:9],
                          ))) , (1,) )

    cv2.fillConvexPoly( hull_mask, cv2.convexHull(
        np.concatenate ( (image_landmarks[17:22],
                          image_landmarks[27:28],
                          image_landmarks[31:36],
                          image_landmarks[8:9]
                          ))) , (1,) )

    cv2.fillConvexPoly( hull_mask, cv2.convexHull(
        np.concatenate ( (image_landmarks[22:27],
                          image_landmarks[27:28],
                          image_landmarks[31:36],
                          image_landmarks[8:9]
                          ))) , (1,) )

    #nose
    cv2.fillConvexPoly( hull_mask, cv2.convexHull(image_landmarks[27:36]), (1,) )

    return hull_mask

def get_image_eye_mask (image_shape, image_landmarks):
    if len(image_landmarks) != 68:
        raise Exception('get_image_eye_mask works only with 68 landmarks')

    hull_mask = np.zeros(image_shape[0:2]+(1,),dtype=np.float32)

    cv2.fillConvexPoly( hull_mask, cv2.convexHull( image_landmarks[36:42]), (1,) )
    cv2.fillConvexPoly( hull_mask, cv2.convexHull( image_landmarks[42:48]), (1,) )

    return hull_mask

def get_image_hull_mask_3D (image_shape, image_landmarks):
    result = get_image_hull_mask(image_shape, image_landmarks)

    return np.repeat ( result, (3,), -1 )

def blur_image_hull_mask (hull_mask):
@@ -224,7 +224,7 @@ def blur_image_hull_mask (hull_mask):
    leny = maxy - miny;
    masky = int(minx+(lenx//2))
    maskx = int(miny+(leny//2))
    lowest_len = min (lenx, leny)

    ero = int( lowest_len * 0.085 )
    blur = int( lowest_len * 0.10 )
@@ -233,10 +233,10 @@ def blur_image_hull_mask (hull_mask):
    hull_mask = np.expand_dims (hull_mask,-1)

    return hull_mask

def get_blurred_image_hull_mask(image_shape, image_landmarks):
    return blur_image_hull_mask ( get_image_hull_mask(image_shape, image_landmarks) )

mirror_idxs = [
    [0,16],
    [1,15],
@@ -246,23 +246,23 @@ mirror_idxs = [
    [5,11],
    [6,10],
    [7,9],

    [17,26],
    [18,25],
    [19,24],
    [20,23],
    [21,22],

    [36,45],
    [37,44],
    [38,43],
    [39,42],
    [40,47],
    [41,46],

    [31,35],
    [32,34],

    [50,52],
    [49,53],
    [48,54],
@@ -271,28 +271,28 @@ mirror_idxs = [
    [67,65],
    [60,64],
    [61,63] ]

def mirror_landmarks (landmarks, val):
    result = landmarks.copy()

    for idx in mirror_idxs:
        result [ idx ] = result [ idx[::-1] ]

    result[:,0] = val - result[:,0] - 1
    return result

def draw_landmarks (image, image_landmarks, color=(0,255,0), transparent_mask=False):
    if len(image_landmarks) != 68:
        raise Exception('get_image_eye_mask works only with 68 landmarks')

    jaw = image_landmarks[slice(*landmarks_68_pt["jaw"])]
    right_eyebrow = image_landmarks[slice(*landmarks_68_pt["right_eyebrow"])]
    left_eyebrow = image_landmarks[slice(*landmarks_68_pt["left_eyebrow"])]
    mouth = image_landmarks[slice(*landmarks_68_pt["mouth"])]
    right_eye = image_landmarks[slice(*landmarks_68_pt["right_eye"])]
    left_eye = image_landmarks[slice(*landmarks_68_pt["left_eye"])]
    nose = image_landmarks[slice(*landmarks_68_pt["nose"])]

    # open shapes
    cv2.polylines(image, tuple(np.array([v]) for v in ( right_eyebrow, jaw, left_eyebrow, np.concatenate((nose, [nose[-6]])) )),
                  False, color, lineType=cv2.LINE_AA)
@@ -303,9 +303,9 @@ def draw_landmarks (image, image_landmarks, color=(0,255,0), transparent_mask=False):
    for x, y in np.concatenate((right_eyebrow, left_eyebrow, mouth, right_eye, left_eye, nose), axis=0):
        cv2.circle(image, (x, y), 1, color, 1, lineType=cv2.LINE_AA)

    # jaw big circles
    for x, y in jaw:
        cv2.circle(image, (x, y), 2, color, lineType=cv2.LINE_AA)

    if transparent_mask:
        mask = get_image_hull_mask (image.shape, image_landmarks)
        image[...] = ( image * (1-mask) + image * mask / 2 )[...]
@@ -314,24 +314,24 @@ def draw_rect_landmarks (image, rect, image_landmarks, face_size, face_type, tra
    draw_landmarks(image, image_landmarks, color=landmarks_color, transparent_mask=transparent_mask)
    image_utils.draw_rect (image, rect, (255,0,0), 2 )

    image_to_face_mat = get_transform_mat (image_landmarks, face_size, face_type)
    points = transform_points ( [ (0,0), (0,face_size-1), (face_size-1, face_size-1), (face_size-1,0) ], image_to_face_mat, True)
    image_utils.draw_polygon (image, points, (0,0,255), 2)

def calc_face_pitch(landmarks):
    if not isinstance(landmarks, np.ndarray):
        landmarks = np.array (landmarks)
    t = ( (landmarks[6][1]-landmarks[8][1]) + (landmarks[10][1]-landmarks[8][1]) ) / 2.0
    b = landmarks[8][1]
    return float(b-t)

def calc_face_yaw(landmarks):
    if not isinstance(landmarks, np.ndarray):
        landmarks = np.array (landmarks)
    l = ( (landmarks[27][0]-landmarks[0][0]) + (landmarks[28][0]-landmarks[1][0]) + (landmarks[29][0]-landmarks[2][0]) ) / 3.0
    r = ( (landmarks[16][0]-landmarks[27][0]) + (landmarks[15][0]-landmarks[28][0]) + (landmarks[14][0]-landmarks[29][0]) ) / 3.0
    return float(r-l)

#returns pitch, yaw in [-1...+1]
def estimate_pitch_yaw(aligned_256px_landmarks):
    shape = (256,256)
@@ -347,8 +347,8 @@ def estimate_pitch_yaw(aligned_256px_landmarks):
        aligned_256px_landmarks.astype(np.float32),
        camera_matrix,
        np.zeros((4, 1)) )

    pitch, yaw, _ = mathlib.rotationMatrixToEulerAngles( cv2.Rodrigues(rotation_vector)[0] )
    pitch = np.clip ( pitch*1.25, -1.0, 1.0 )
    yaw = np.clip ( yaw*1.25, -1.0, 1.0 )
    return pitch, yaw
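
A hedged usage sketch of estimate_pitch_yaw above; the input array is a random stand-in for real landmarks aligned to a 256px face, and the asserted bounds follow from the np.clip calls:

import numpy as np

aligned = np.random.rand(68, 2).astype(np.float32) * 256   #stand-in aligned landmarks
pitch, yaw = estimate_pitch_yaw(aligned)                   #both clipped to [-1.0, +1.0]
assert -1.0 <= pitch <= 1.0 and -1.0 <= yaw <= 1.0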

View file

@@ -5,10 +5,10 @@ import cv2
from pathlib import Path
from nnlib import nnlib

class MTCExtractor(object):
    def __init__(self):
        self.scale_to = 1920

        self.min_face_size = self.scale_to * 0.042
        self.thresh1 = 0.7
        self.thresh2 = 0.85
@@ -26,12 +26,12 @@ class MTCExtractor(object):
        x = Conv2D (32, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv3")(x)
        x = PReLU (shared_axes=[1,2], name="PReLU3" )(x)
        prob = Conv2D (2, kernel_size=(1,1), strides=(1,1), padding='valid', name="conv41")(x)
        prob = Softmax()(prob)
        x = Conv2D (4, kernel_size=(1,1), strides=(1,1), padding='valid', name="conv42")(x)

        PNet_model = Model(PNet_Input, [x,prob] )
        PNet_model.load_weights ( (Path(__file__).parent / 'mtcnn_pnet.h5').__str__() )

        RNet_Input = Input ( (24, 24, 3) )
        x = RNet_Input
        x = Conv2D (28, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv1")(x)
@@ -39,18 +39,18 @@ class MTCExtractor(object):
        x = MaxPooling2D( pool_size=(3,3), strides=(2,2), padding='same' ) (x)
        x = Conv2D (48, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv2")(x)
        x = PReLU (shared_axes=[1,2], name="prelu2" )(x)
        x = MaxPooling2D( pool_size=(3,3), strides=(2,2), padding='valid' ) (x)
        x = Conv2D (64, kernel_size=(2,2), strides=(1,1), padding='valid', name="conv3")(x)
        x = PReLU (shared_axes=[1,2], name="prelu3" )(x)
        x = Lambda ( lambda x: K.reshape (x, (-1, np.prod(K.int_shape(x)[1:]),) ), output_shape=(np.prod(K.int_shape(x)[1:]),) ) (x)
        x = Dense (128, name='conv4')(x)
        x = PReLU (name="prelu4" )(x)
        prob = Dense (2, name='conv51')(x)
        prob = Softmax()(prob)
        x = Dense (4, name='conv52')(x)

        RNet_model = Model(RNet_Input, [x,prob] )
        RNet_model.load_weights ( (Path(__file__).parent / 'mtcnn_rnet.h5').__str__() )

        ONet_Input = Input ( (48, 48, 3) )
        x = ONet_Input
        x = Conv2D (32, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv1")(x)
@@ -58,20 +58,20 @@ class MTCExtractor(object):
        x = MaxPooling2D( pool_size=(3,3), strides=(2,2), padding='same' ) (x)
        x = Conv2D (64, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv2")(x)
        x = PReLU (shared_axes=[1,2], name="prelu2" )(x)
        x = MaxPooling2D( pool_size=(3,3), strides=(2,2), padding='valid' ) (x)
        x = Conv2D (64, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv3")(x)
        x = PReLU (shared_axes=[1,2], name="prelu3" )(x)
        x = MaxPooling2D( pool_size=(2,2), strides=(2,2), padding='same' ) (x)
        x = Conv2D (128, kernel_size=(2,2), strides=(1,1), padding='valid', name="conv4")(x)
        x = PReLU (shared_axes=[1,2], name="prelu4" )(x)
        x = Lambda ( lambda x: K.reshape (x, (-1, np.prod(K.int_shape(x)[1:]),) ), output_shape=(np.prod(K.int_shape(x)[1:]),) ) (x)
        x = Dense (256, name='conv5')(x)
        x = PReLU (name="prelu5" )(x)
        prob = Dense (2, name='conv61')(x)
        prob = Softmax()(prob)
        x1 = Dense (4, name='conv62')(x)
        x2 = Dense (10, name='conv63')(x)

        ONet_model = Model(ONet_Input, [x1,x2,prob] )
        ONet_model.load_weights ( (Path(__file__).parent / 'mtcnn_onet.h5').__str__() )

        self.pnet_fun = K.function ( PNet_model.inputs, PNet_model.outputs )
@@ -79,13 +79,13 @@ class MTCExtractor(object):
        self.onet_fun = K.function ( ONet_model.inputs, ONet_model.outputs )

    def __enter__(self):
        faces, pnts = detect_face ( np.zeros ( (self.scale_to, self.scale_to, 3)), self.min_face_size, self.pnet_fun, self.rnet_fun, self.onet_fun, [ self.thresh1, self.thresh2, self.thresh3 ], self.scale_factor )

        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        return False #pass exception between __enter__ and __exit__ to outer level

    def extract_from_bgr (self, input_image):
        input_image = input_image[:,:,::-1].copy()
        (h, w, ch) = input_image.shape
@@ -95,7 +95,7 @@ class MTCExtractor(object):
        detected_faces, pnts = detect_face ( input_image, self.min_face_size, self.pnet_fun, self.rnet_fun, self.onet_fun, [ self.thresh1, self.thresh2, self.thresh3 ], self.scale_factor )
        detected_faces = [ ( int(face[0]/input_scale), int(face[1]/input_scale), int(face[2]/input_scale), int(face[3]/input_scale)) for face in detected_faces ]

        return detected_faces

def detect_face(img, minsize, pnet, rnet, onet, threshold, factor):
@@ -132,9 +132,9 @@ def detect_face(img, minsize, pnet, rnet, onet, threshold, factor):
        out = pnet([img_y])
        out0 = np.transpose(out[0], (0,2,1,3))
        out1 = np.transpose(out[1], (0,2,1,3))

        boxes, _ = generateBoundingBox(out1[0,:,:,1].copy(), out0[0,:,:,:].copy(), scale, threshold[0])

        # inter-scale nms
        pick = nms(boxes.copy(), 0.5, 'Union')
        if boxes.size>0 and pick.size>0:
@@ -217,7 +217,7 @@ def detect_face(img, minsize, pnet, rnet, onet, threshold, factor):
            pick = nms(total_boxes.copy(), 0.7, 'Min')
            total_boxes = total_boxes[pick,:]
            points = points[:,pick]

    return total_boxes, points
@@ -235,7 +235,7 @@ def bbreg(boundingbox,reg):
    b4 = boundingbox[:,3]+reg[:,3]*h
    boundingbox[:,0:4] = np.transpose(np.vstack([b1, b2, b3, b4 ]))
    return boundingbox

def generateBoundingBox(imap, reg, scale, t):
    """Use heatmap to generate bounding boxes"""
    stride=2
@@ -261,7 +261,7 @@ def generateBoundingBox(imap, reg, scale, t):
    q2 = np.fix((stride*bb+cellsize-1+1)/scale)
    boundingbox = np.hstack([q1, q2, np.expand_dims(score,1), reg])
    return boundingbox, reg

# function pick = nms(boxes,threshold,type)
def nms(boxes, threshold, method):
    if boxes.size==0:
@@ -315,7 +315,7 @@ def pad(total_boxes, w, h):
    tmp = np.where(ex>w)
    edx.flat[tmp] = np.expand_dims(-ex[tmp]+w+tmpw[tmp],1)
    ex[tmp] = w

    tmp = np.where(ey>h)
    edy.flat[tmp] = np.expand_dims(-ey[tmp]+h+tmph[tmp],1)
    ey[tmp] = h
@@ -327,7 +327,7 @@ def pad(total_boxes, w, h):
    tmp = np.where(y<1)
    dy.flat[tmp] = np.expand_dims(2-y[tmp],1)
    y[tmp] = 1

    return dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph

# function [bboxA] = rerec(bboxA)
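
A hedged usage sketch of the MTCNN cascade above: __enter__ warms the P/R/O networks with a dummy detect_face call, and extract_from_bgr returns (left, top, right, bottom) boxes rescaled to the original image. The frame here is a stand-in array:

import numpy as np

frame = np.zeros((720, 1280, 3), dtype=np.uint8)   #stand-in BGR frame
with MTCExtractor() as extractor:
    rects = extractor.extract_from_bgr(frame)
for l, t, r, b in rects:
    print("face at", l, t, r, b)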

View file

@@ -3,35 +3,35 @@ from pathlib import Path
import cv2

from nnlib import nnlib

class S3FDExtractor(object):
    def __init__(self):
        exec( nnlib.import_all(), locals(), globals() )

        model_path = Path(__file__).parent / "S3FD.h5"
        if not model_path.exists():
            return None

        self.model = nnlib.keras.models.load_model ( str(model_path) )

    def __enter__(self):
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        return False #pass exception between __enter__ and __exit__ to outer level

    def extract_from_bgr (self, input_image):
        input_image = input_image[:,:,::-1].copy()
        (h, w, ch) = input_image.shape

        d = max(w, h)
        scale_to = 640 if d >= 1280 else d / 2
        scale_to = max(64, scale_to)

        input_scale = d / scale_to
        input_image = cv2.resize (input_image, ( int(w/input_scale), int(h/input_scale) ), interpolation=cv2.INTER_LINEAR)

        olist = self.model.predict( np.expand_dims(input_image,0) )

        detected_faces = []
        for ltrb in self.refine (olist):
            l,t,r,b = [ x*input_scale for x in ltrb]
@@ -42,7 +42,7 @@ class S3FDExtractor(object):
            detected_faces.append ( [int(x) for x in (l,t,r,b) ] )

        return detected_faces

    def refine(self, olist):
        bboxlist = []
        for i, ((ocls,), (oreg,)) in enumerate ( zip ( olist[::2], olist[1::2] ) ):
@@ -51,7 +51,7 @@ class S3FDExtractor(object):
            s_m4 = stride * 4

            for hindex, windex in zip(*np.where(ocls > 0.05)):
                score = ocls[hindex, windex]
                loc = oreg[hindex, windex, :]
                priors = np.array([windex * stride + s_d2, hindex * stride + s_d2, s_m4, s_m4])
                priors_2p = priors[2:]
@@ -61,15 +61,15 @@ class S3FDExtractor(object):
                box[2:] += box[:2]

                bboxlist.append([*box, score])

        bboxlist = np.array(bboxlist)
        if len(bboxlist) == 0:
            bboxlist = np.zeros((1, 5))

        bboxlist = bboxlist[self.refine_nms(bboxlist, 0.3), :]
        bboxlist = [ x[:-1].astype(np.int) for x in bboxlist if x[-1] >= 0.5]
        return bboxlist

    def refine_nms(self, dets, thresh):
        keep = list()
        if len(dets) == 0:
@@ -91,4 +91,4 @@ class S3FDExtractor(object):
        inds = np.where(ovr <= thresh)[0]
        order = order[inds + 1]
    return keep
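
A hedged usage sketch of S3FDExtractor above: frames whose longest side is at least 1280 px are downscaled to 640 px for inference, and refine()'s boxes are multiplied back by input_scale before being returned, so the caller works entirely in source coordinates. The frame is a stand-in:

import numpy as np

img = np.zeros((1080, 1920, 3), dtype=np.uint8)    #stand-in BGR frame, longest side 1920
extractor = S3FDExtractor()
with extractor:
    faces = extractor.extract_from_bgr(img)        #list of [l, t, r, b] ints at 1920x1080 scale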

View file

@@ -1,6 +1,6 @@
"""
Copyright (c) 2009-2010 Arizona Board of Regents. All Rights Reserved.
Contact: Lina Karam (karam@asu.edu) and Niranjan Narvekar (nnarveka@asu.edu)
Image, Video, and Usability (IVU) Lab, http://ivulab.asu.edu , Arizona State University
This copyright statement may not be removed from any file containing it or from modifications to these files.
This copyright notice must also be included in any file or product that is derived from the source files.
@@ -267,11 +267,11 @@ def get_block_contrast(block):
    # type: (numpy.ndarray) -> int
    return int(np.max(block) - np.min(block))

def estimate_sharpness(image):
    height, width = image.shape[:2]

    if image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    return compute(image)
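
A minimal usage sketch of estimate_sharpness above; color input is converted to grayscale before the metric is computed, and the file path here is illustrative only:

import cv2

img = cv2.imread("face.jpg")       #illustrative path, any BGR or grayscale image
score = estimate_sharpness(img)    #scalar sharpness estimate via compute()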

View file

@@ -1 +1 @@
from .interact import interact

View file

@@ -12,28 +12,28 @@ class Interact(object):
    EVENT_RBUTTONDOWN = 5
    EVENT_RBUTTONUP = 6
    EVENT_MOUSEWHEEL = 10

    def __init__(self):
        self.named_windows = {}
        self.capture_mouse_windows = {}
        self.capture_keys_windows = {}
        self.mouse_events = {}
        self.key_events = {}
        self.pg_bar = None

    def log_info(self, msg, end='\n'):
        print (msg, end=end)

    def log_err(self, msg, end='\n'):
        print (msg, end=end)

    def named_window(self, wnd_name):
        if wnd_name not in self.named_windows:
            #we will show window only on first show_image
            self.named_windows[wnd_name] = 0
        else: print("named_window: ", wnd_name, " already created.")

    def destroy_all_windows(self):
        if len( self.named_windows ) != 0:
            cv2.destroyAllWindows()
@@ -42,32 +42,32 @@ class Interact(object):
            self.capture_keys_windows = {}
            self.mouse_events = {}
            self.key_events = {}

    def show_image(self, wnd_name, img):
        if wnd_name in self.named_windows:
            if self.named_windows[wnd_name] == 0:
                self.named_windows[wnd_name] = 1
                cv2.namedWindow(wnd_name)
                if wnd_name in self.capture_mouse_windows:
                    self.capture_mouse(wnd_name)

            cv2.imshow (wnd_name, img)
        else: print("show_image: named_window ", wnd_name, " not found.")

    def capture_mouse(self, wnd_name):
        def onMouse(event, x, y, flags, param):
            (inst, wnd_name) = param
            if event == cv2.EVENT_LBUTTONDOWN: ev = Interact.EVENT_LBUTTONDOWN
            elif event == cv2.EVENT_LBUTTONUP: ev = Interact.EVENT_LBUTTONUP
            elif event == cv2.EVENT_RBUTTONDOWN: ev = Interact.EVENT_RBUTTONDOWN
            elif event == cv2.EVENT_RBUTTONUP: ev = Interact.EVENT_RBUTTONUP
            elif event == cv2.EVENT_MOUSEWHEEL: ev = Interact.EVENT_MOUSEWHEEL
            else: ev = 0
            inst.add_mouse_event (wnd_name, x, y, ev, flags)

        if wnd_name in self.named_windows:
            self.capture_mouse_windows[wnd_name] = True
            if self.named_windows[wnd_name] == 1:
                cv2.setMouseCallback(wnd_name, onMouse, (self, wnd_name) )
        else: print("capture_mouse: named_window ", wnd_name, " not found.")
@@ -95,66 +95,66 @@ class Interact(object):
            self.pg_bar.close()
            self.pg_bar = None
        else: print("progress_bar not set.")

    def progress_bar_generator(self, data, desc, leave=True):
        for x in tqdm( data, desc=desc, leave=leave, ascii=True ):
            yield x

    def process_messages(self, sleep_time=0):
        has_windows = False
        has_capture_keys = False

        if len(self.named_windows) != 0:
            has_windows = True

        if len(self.capture_keys_windows) != 0:
            has_capture_keys = True

        if has_windows or has_capture_keys:
            wait_key_time = max(1, int(sleep_time*1000) )
            key = cv2.waitKey(wait_key_time) & 0xFF
        else:
            if sleep_time != 0:
                time.sleep(sleep_time)

        if has_capture_keys and key != 255:
            for wnd_name in self.capture_keys_windows:
                self.add_key_event (wnd_name, key)

    def wait_any_key(self):
        cv2.waitKey(0)

    def add_mouse_event(self, wnd_name, x, y, ev, flags):
        if wnd_name not in self.mouse_events:
            self.mouse_events[wnd_name] = []
        self.mouse_events[wnd_name] += [ (x, y, ev, flags) ]

    def add_key_event(self, wnd_name, key):
        if wnd_name not in self.key_events:
            self.key_events[wnd_name] = []
        self.key_events[wnd_name] += [ (key,) ]

    def get_mouse_events(self, wnd_name):
        ar = self.mouse_events.get(wnd_name, [])
        self.mouse_events[wnd_name] = []
        return ar

    def get_key_events(self, wnd_name):
        ar = self.key_events.get(wnd_name, [])
        self.key_events[wnd_name] = []
        return ar

    def input_number(self, s, default_value, valid_list=None, help_message=None):
        while True:
            try:
                inp = input(s)
                if len(inp) == 0:
                    raise ValueError("")

                if help_message is not None and inp == '?':
                    print (help_message)
                    continue

                i = float(inp)
                if (valid_list is not None) and (i not in valid_list):
                    return default_value
@@ -162,18 +162,18 @@ class Interact(object):
            except:
                print (default_value)
                return default_value

    def input_int(self, s, default_value, valid_list=None, help_message=None):
        while True:
            try:
                inp = input(s)
                if len(inp) == 0:
                    raise ValueError("")

                if help_message is not None and inp == '?':
                    print (help_message)
                    continue

                i = int(inp)
                if (valid_list is not None) and (i not in valid_list):
                    return default_value
@@ -181,41 +181,41 @@ class Interact(object):
            except:
                print (default_value)
                return default_value

    def input_bool(self, s, default_value, help_message=None):
        while True:
            try:
                inp = input(s)
                if len(inp) == 0:
                    raise ValueError("")

                if help_message is not None and inp == '?':
                    print (help_message)
                    continue

                return bool ( {"y":True,"n":False,"1":True,"0":False}.get(inp.lower(), default_value) )
            except:
                print ( "y" if default_value else "n" )
                return default_value

    def input_str(self, s, default_value, valid_list=None, help_message=None):
        while True:
            try:
                inp = input(s)
                if len(inp) == 0:
                    raise ValueError("")

                if help_message is not None and inp == '?':
                    print (help_message)
                    continue

                if (valid_list is not None) and (inp.lower() not in valid_list):
                    return default_value

                return inp
            except:
                print (default_value)
                return default_value

    def input_process(self, stdin_fd, sq, str):
        sys.stdin = os.fdopen(stdin_fd)
        try:
@@ -223,7 +223,7 @@ class Interact(object):
            sq.put (True)
        except:
            sq.put (False)

    def input_in_time (self, str, max_time_sec):
        sq = multiprocessing.Queue()
        p = multiprocessing.Process(target=self.input_process, args=( sys.stdin.fileno(), sq, str))
@ -240,4 +240,4 @@ class Interact(object):
sys.stdin = os.fdopen( sys.stdin.fileno() ) sys.stdin = os.fdopen( sys.stdin.fileno() )
return inp return inp
interact = Interact() interact = Interact()
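Usage sketch (illustrative only, not part of this commit): the singleton above is the intended entry point for console prompts and progress reporting. The prompt strings and values here are hypothetical.

    from interact import interact as io

    batch_size = io.input_int ("Batch size (?:help) : ", 8, help_message="Number of samples per iteration.")
    use_gpu    = io.input_bool ("Use GPU? (y/n) : ", True)

    for _ in io.progress_bar_generator ( range(100), "Processing" ):
        pass #per-item work goes here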

View file

@@ -7,7 +7,7 @@ class SubprocessFunctionCaller(object):
            self.s2c = s2c
            self.c2s = c2s
            self.lock = lock

        def __call__(self, value):
            self.lock.acquire()
            self.c2s.put (value)
@@ -17,26 +17,26 @@ class SubprocessFunctionCaller(object):
                    self.lock.release()
                    return obj
                time.sleep(0.005)

    class HostProcessor(object):
        def __init__(self, s2c, c2s, func):
            self.s2c = s2c
            self.c2s = c2s
            self.func = func

        def process_messages(self):
            while not self.c2s.empty():
                obj = self.c2s.get()
                result = self.func (obj)
                self.s2c.put (result)

    @staticmethod
    def make_pair( func ):
        s2c = multiprocessing.Queue()
        c2s = multiprocessing.Queue()
        lock = multiprocessing.Lock()

        host_processor = SubprocessFunctionCaller.HostProcessor (s2c, c2s, func)
        cli_func = SubprocessFunctionCaller.CliFunction (s2c, c2s, lock)

        return host_processor, cli_func
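For orientation, a minimal sketch of the intended call pattern (the predictor function and values are assumed, not from the commit): the real function stays in the host process, the picklable cli_func is shipped to workers, and the host answers requests by pumping process_messages().

    def predictor(x):            #the real function, lives only in the host process
        return x * 2

    host_processor, cli_func = SubprocessFunctionCaller.make_pair (predictor)

    #in a worker process: blocks until the host answers
    #result = cli_func (21)     # -> 42, computed by the host

    #in the host main loop: service any pending worker requests
    host_processor.process_messages()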

View file

@@ -3,12 +3,12 @@ import multiprocessing
import time
import sys
from interact import interact as io

class Subprocessor(object):

    class SilenceException(Exception):
        pass

    class Cli(object):
        def __init__ ( self, client_dict ):
            self.s2c = multiprocessing.Queue()
@@ -16,41 +16,41 @@ class Subprocessor(object):
            self.p = multiprocessing.Process(target=self._subprocess_run, args=(client_dict,) )
            self.p.daemon = True
            self.p.start()

            self.state = None
            self.sent_time = None
            self.sent_data = None
            self.name = None
            self.host_dict = None

        def kill(self):
            self.p.terminate()
            self.p.join()

        #overridable optional
        def on_initialize(self, client_dict):
            #initialize your subprocess here using client_dict
            pass

        #overridable optional
        def on_finalize(self):
            #finalize your subprocess here
            pass

        #overridable
        def process_data(self, data):
            #process 'data' given from host and return result
            raise NotImplementedError

        #overridable optional
        def get_data_name (self, data):
            #return string identifier of your 'data'
            return "undefined"

        def log_info(self, msg): self.c2s.put ( {'op': 'log_info', 'msg':msg } )
        def log_err(self, msg): self.c2s.put ( {'op': 'log_err' , 'msg':msg } )
        def progress_bar_inc(self, c): self.c2s.put ( {'op': 'progress_bar_inc' , 'c':c } )

        def _subprocess_run(self, client_dict):
            data = None
            s2c, c2s = self.s2c, self.c2s
@@ -65,20 +65,20 @@ class Subprocessor(object):
                    if op == 'data':
                        data = msg['data']
                        result = self.process_data (data)
                        c2s.put ( {'op': 'success', 'data' : data, 'result' : result} )
                        data = None
                    elif op == 'close':
                        break

                    time.sleep(0.001)

                self.on_finalize()
                c2s.put ( {'op': 'finalized'} )
                return
            except Subprocessor.SilenceException as e:
                pass
            except Exception as e:
                if data is not None:
                    print ('Exception while processing data [%s]: %s' % (self.get_data_name(data), traceback.format_exc()) )
                else:
                    print ('Exception: %s' % (traceback.format_exc()) )
@@ -91,10 +91,10 @@ class Subprocessor(object):
            raise ValueError("SubprocessorCli_class must be subclass of Subprocessor.Cli")

        self.name = name
        self.SubprocessorCli_class = SubprocessorCli_class
        self.no_response_time_sec = no_response_time_sec

    #overridable
    def process_info_generator(self):
        #yield per process (name, host_dict, client_dict)
        raise NotImplementedError
@@ -103,42 +103,42 @@ class Subprocessor(object):
    def on_clients_initialized(self):
        #logic when all subprocesses are initialized and ready
        pass

    #overridable optional
    def on_clients_finalized(self):
        #logic when all subprocesses are finalized
        pass

    #overridable
    def get_data(self, host_dict):
        #return data for processing here
        raise NotImplementedError

    #overridable
    def on_data_return (self, host_dict, data):
        #you have to place returned 'data' back into your queue
        raise NotImplementedError

    #overridable
    def on_result (self, host_dict, data, result):
        #your logic for what to do with the 'result' of 'data'
        raise NotImplementedError

    #overridable
    def get_result(self):
        #return the result that will be returned by run()
        raise NotImplementedError

    #overridable
    def on_tick(self):
        #tick in main loop
        pass

    def run(self):
        self.clis = []

        #getting info about name of subprocesses, host and client dicts, and spawning them
        for name, host_dict, client_dict in self.process_info_generator():
            try:
                cli = self.SubprocessorCli_class(client_dict)
                cli.state = 1
@@ -146,21 +146,21 @@ class Subprocessor(object):
                cli.sent_data = None
                cli.name = name
                cli.host_dict = host_dict

                self.clis.append (cli)
            except:
                raise Exception ("Unable to start subprocess %s" % (name))

        if len(self.clis) == 0:
            raise Exception ("Unable to start Subprocessor '%s' " % (self.name))

        #waiting for subprocesses to report whether they initialized successfully
        while True:
            for cli in self.clis[:]:
                while not cli.c2s.empty():
                    obj = cli.c2s.get()
                    op = obj.get('op','')
                    if op == 'init_ok':
                        cli.state = 0
                    elif op == 'log_info':
@@ -172,16 +172,16 @@ class Subprocessor(object):
                        self.clis.remove(cli)
                        break

            if all ([cli.state == 0 for cli in self.clis]):
                break

            io.process_messages(0.005)

        if len(self.clis) == 0:
            raise Exception ( "Unable to start subprocesses." )

        #ok, some processes survived; initialize host logic
        self.on_clients_initialized()

        #main loop of data processing
        while True:
            for cli in self.clis[:]:
@@ -206,10 +206,10 @@ class Subprocessor(object):
                        io.log_err(obj['msg'])
                    elif op == 'progress_bar_inc':
                        io.progress_bar_inc(obj['c'])

            for cli in self.clis[:]:
                if cli.state == 0:
                    #free state of subprocess, get some data from get_data
                    data = self.get_data(cli.host_dict)
                    if data is not None:
                        #and send it to the subprocess
@@ -217,7 +217,7 @@ class Subprocessor(object):
                        cli.sent_time = time.time()
                        cli.sent_data = data
                        cli.state = 1

                elif cli.state == 1:
                    if self.no_response_time_sec != 0 and (time.time() - cli.sent_time) > self.no_response_time_sec:
                        #subprocess busy too long
@@ -225,39 +225,39 @@ class Subprocessor(object):
                        self.on_data_return (cli.host_dict, cli.sent_data )
                        cli.kill()
                        self.clis.remove(cli)

            if all ([cli.state == 0 for cli in self.clis]):
                #all subprocesses free and no more data available to process, ending loop
                break

            io.process_messages(0.005)
            self.on_tick()

        #gracefully terminating subprocesses
        for cli in self.clis[:]:
            cli.s2c.put ( {'op': 'close'} )
            cli.sent_time = time.time()

        while True:
            for cli in self.clis[:]:
                terminate_it = False
                while not cli.c2s.empty():
                    obj = cli.c2s.get()
                    obj_op = obj['op']
                    if obj_op == 'finalized':
                        terminate_it = True
                        break

                if self.no_response_time_sec != 0 and (time.time() - cli.sent_time) > self.no_response_time_sec:
                    terminate_it = True

                if terminate_it:
                    cli.state = 2
                    cli.kill()

            if all ([cli.state == 2 for cli in self.clis]):
                break

        #finalizing host logic and returning result
        self.on_clients_finalized()

        return self.get_result()
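To make the overridable contract above concrete, here is a minimal hypothetical subclass (the class name and the squaring task are illustrative, not from this commit); it assumes the same module context, with multiprocessing already imported, and the base __init__ signature visible in the diffs below.

    class SquareSubprocessor(Subprocessor):
        class Cli(Subprocessor.Cli):
            #override
            def process_data(self, data):
                return data * data                  #runs in the child process

        def __init__(self, numbers):
            super().__init__('Square', SquareSubprocessor.Cli, 60)
            self.input_data = list(numbers)
            self.results = []

        #override
        def process_info_generator(self):
            for i in range( multiprocessing.cpu_count() ):
                yield 'CPU%d' % (i), {}, {}         #(name, host_dict, client_dict)

        #override
        def get_data(self, host_dict):
            return self.input_data.pop(0) if len(self.input_data) > 0 else None

        #override
        def on_data_return(self, host_dict, data):
            self.input_data.insert(0, data)         #failed item goes back into the queue

        #override
        def on_result(self, host_dict, data, result):
            self.results.append(result)

        #override
        def get_result(self):
            return self.results

    #squares = SquareSubprocessor ( range(10) ).run()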

View file

@@ -1,2 +1,2 @@
from .SubprocessorBase import Subprocessor
from .SubprocessFunctionCaller import SubprocessFunctionCaller

main.py (136 changed lines)
View file

@@ -14,110 +14,110 @@ class fixPathAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, os.path.abspath(os.path.expanduser(values)))

if __name__ == "__main__":
    multiprocessing.set_start_method("spawn")

    os_utils.set_process_lowest_prio()
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers()

    def process_extract(arguments):
        from mainscripts import Extractor
        Extractor.main( arguments.input_dir,
                        arguments.output_dir,
                        arguments.debug_dir,
                        arguments.detector,
                        arguments.manual_fix,
                        arguments.manual_output_debug_fix,
                        arguments.manual_window_size,
                        face_type=arguments.face_type,
                        device_args={'cpu_only'  : arguments.cpu_only,
                                     'multi_gpu' : arguments.multi_gpu,
                                    }
                      )

    p = subparsers.add_parser( "extract", help="Extract faces from pictures.")
    p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
    p.add_argument('--output-dir', required=True, action=fixPathAction, dest="output_dir", help="Output directory. This is where the extracted files will be stored.")
    p.add_argument('--debug-dir', action=fixPathAction, dest="debug_dir", help="Writes debug images to this directory.")
    p.add_argument('--face-type', dest="face_type", choices=['half_face', 'full_face', 'head', 'avatar', 'mark_only'], default='full_face', help="Default 'full_face'. Don't change this option; currently all models use 'full_face'.")
    p.add_argument('--detector', dest="detector", choices=['dlib','mt','s3fd','manual'], default='dlib', help="Type of detector. Default 'dlib'. 'mt' (MTCNNv1) - faster, better, almost no jitter, perfect for gathering thousands of faces for the src-set. It is also good for the dst-set, but it can generate false faces in frames where the main face is not recognized! In that case, for the dst-set use either 'dlib' with '--manual-fix' or '--detector manual'. The manual detector is suitable only for the dst-set.")
    p.add_argument('--multi-gpu', action="store_true", dest="multi_gpu", default=False, help="Enables multi GPU.")
    p.add_argument('--manual-fix', action="store_true", dest="manual_fix", default=False, help="Enables manual extraction of only those frames where faces were not recognized.")
    p.add_argument('--manual-output-debug-fix', action="store_true", dest="manual_output_debug_fix", default=False, help="Performs manual re-extraction of input-dir frames which were deleted from the [output_dir]_debug directory.")
    p.add_argument('--manual-window-size', type=int, dest="manual_window_size", default=1368, help="Manual fix window size. Default: 1368.")
    p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Extract on CPU. Forces use of the MT extractor.")
    p.set_defaults (func=process_extract)

    def process_sort(arguments):
        from mainscripts import Sorter
        Sorter.main (input_path=arguments.input_dir, sort_by_method=arguments.sort_by_method)

    p = subparsers.add_parser( "sort", help="Sort faces in a directory.")
    p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
    p.add_argument('--by', required=True, dest="sort_by_method", choices=("blur", "face", "face-dissim", "face-yaw", "face-pitch", "hist", "hist-dissim", "brightness", "hue", "black", "origname", "oneface", "final", "final-no-blur", "test"), help="Method of sorting. 'origname' sorts by original filename to recover the original sequence." )
    p.set_defaults (func=process_sort)

    def process_util(arguments):
        from mainscripts import Util

        if arguments.convert_png_to_jpg:
            Util.convert_png_to_jpg_folder (input_path=arguments.input_dir)

        if arguments.add_landmarks_debug_images:
            Util.add_landmarks_debug_images (input_path=arguments.input_dir)

        if arguments.recover_original_aligned_filename:
            Util.recover_original_aligned_filename (input_path=arguments.input_dir)

    p = subparsers.add_parser( "util", help="Utilities.")
    p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
    p.add_argument('--convert-png-to-jpg', action="store_true", dest="convert_png_to_jpg", default=False, help="Convert DeepFaceLab PNG files to JPEG.")
    p.add_argument('--add-landmarks-debug-images', action="store_true", dest="add_landmarks_debug_images", default=False, help="Add landmarks debug images for aligned faces.")
    p.add_argument('--recover-original-aligned-filename', action="store_true", dest="recover_original_aligned_filename", default=False, help="Recover original aligned filename.")
    p.set_defaults (func=process_util)

    def process_train(arguments):
        args = {'training_data_src_dir' : arguments.training_data_src_dir,
                'training_data_dst_dir' : arguments.training_data_dst_dir,
                'model_path'            : arguments.model_dir,
                'model_name'            : arguments.model_name,
                'no_preview'            : arguments.no_preview,
                'debug'                 : arguments.debug,
                }
        device_args = {'cpu_only'      : arguments.cpu_only,
                       'force_gpu_idx' : arguments.force_gpu_idx,
                       }
        from mainscripts import Trainer
        Trainer.main(args, device_args)

    p = subparsers.add_parser( "train", help="Trainer")
    p.add_argument('--training-data-src-dir', required=True, action=fixPathAction, dest="training_data_src_dir", help="Dir of src-set.")
    p.add_argument('--training-data-dst-dir', required=True, action=fixPathAction, dest="training_data_dst_dir", help="Dir of dst-set.")
    p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Model dir.")
    p.add_argument('--model', required=True, dest="model_name", choices=Path_utils.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Type of model")
    p.add_argument('--no-preview', action="store_true", dest="no_preview", default=False, help="Disable preview window.")
    p.add_argument('--debug', action="store_true", dest="debug", default=False, help="Debug samples.")
    p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
    p.add_argument('--force-gpu-idx', type=int, dest="force_gpu_idx", default=-1, help="Force choosing this GPU idx.")
    p.set_defaults (func=process_train)

    def process_convert(arguments):
        args = {'input_dir'   : arguments.input_dir,
                'output_dir'  : arguments.output_dir,
                'aligned_dir' : arguments.aligned_dir,
                'model_dir'   : arguments.model_dir,
                'model_name'  : arguments.model_name,
                'debug'       : arguments.debug,
                }
        device_args = {'cpu_only'      : arguments.cpu_only,
                       'force_gpu_idx' : arguments.force_gpu_idx,
                       }
        from mainscripts import Converter
        Converter.main (args, device_args)

    p = subparsers.add_parser( "convert", help="Converter")
    p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
    p.add_argument('--output-dir', required=True, action=fixPathAction, dest="output_dir", help="Output directory. This is where the converted files will be stored.")
    p.add_argument('--aligned-dir', action=fixPathAction, dest="aligned_dir", help="Aligned directory. This is where the extracted dst faces are stored. Not used in the AVATAR model.")
@@ -127,10 +127,10 @@ if __name__ == "__main__":
    p.add_argument('--force-gpu-idx', type=int, dest="force_gpu_idx", default=-1, help="Force choosing this GPU idx.")
    p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Convert on CPU.")
    p.set_defaults(func=process_convert)

    videoed_parser = subparsers.add_parser( "videoed", help="Video processing.").add_subparsers()

    def process_videoed_extract_video(arguments):
        from mainscripts import VideoEd
        VideoEd.extract_video (arguments.input_file, arguments.output_dir, arguments.output_ext, arguments.fps)

    p = videoed_parser.add_parser( "extract-video", help="Extract images from video file.")
@@ -139,23 +139,23 @@ if __name__ == "__main__":
    p.add_argument('--output-ext', dest="output_ext", default='png', help="Image format (extension) of output files.")
    p.add_argument('--fps', type=int, dest="fps", default=None, help="How many frames of every second of the video will be extracted. 0 - full fps.")
    p.set_defaults(func=process_videoed_extract_video)

    def process_videoed_cut_video(arguments):
        from mainscripts import VideoEd
        VideoEd.cut_video (arguments.input_file,
                           arguments.from_time,
                           arguments.to_time,
                           arguments.audio_track_id,
                           arguments.bitrate)

    p = videoed_parser.add_parser( "cut-video", help="Cut video file.")
    p.add_argument('--input-file', required=True, action=fixPathAction, dest="input_file", help="Input file to be processed. Specify .*-extension to find the first file.")
    p.add_argument('--from-time', dest="from_time", default=None, help="From time, for example 00:00:00.000")
    p.add_argument('--to-time', dest="to_time", default=None, help="To time, for example 00:00:00.000")
    p.add_argument('--audio-track-id', type=int, dest="audio_track_id", default=None, help="Specify audio track id.")
    p.add_argument('--bitrate', type=int, dest="bitrate", default=None, help="Bitrate of output file in Megabits.")
    p.set_defaults(func=process_videoed_cut_video)

    def process_videoed_denoise_image_sequence(arguments):
        from mainscripts import VideoEd
        VideoEd.denoise_image_sequence (arguments.input_dir, arguments.ext, arguments.factor)

    p = videoed_parser.add_parser( "denoise-image-sequence", help="Denoise a sequence of images, keeping sharp edges. This allows you to make the final fake more believable: the neural network is not able to produce detailed skin texture, but it does make the edges quite clear. Therefore, if the whole frame is more `blurred`, the fake will seem more believable. This is especially true for film scenes, which are usually very clear.")
@@ -163,65 +163,65 @@ if __name__ == "__main__":
    p.add_argument('--ext', dest="ext", default='png', help="Image format (extension) of input files.")
    p.add_argument('--factor', type=int, dest="factor", default=None, help="Denoise factor (1-20).")
    p.set_defaults(func=process_videoed_denoise_image_sequence)

    def process_videoed_video_from_sequence(arguments):
        from mainscripts import VideoEd
        VideoEd.video_from_sequence (arguments.input_dir,
                                     arguments.output_file,
                                     arguments.reference_file,
                                     arguments.ext,
                                     arguments.fps,
                                     arguments.bitrate,
                                     arguments.lossless)

    p = videoed_parser.add_parser( "video-from-sequence", help="Make video from image sequence.")
    p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory containing the image sequence to be processed.")
    p.add_argument('--output-file', required=True, action=fixPathAction, dest="output_file", help="Output video file to be written.")
    p.add_argument('--reference-file', action=fixPathAction, dest="reference_file", help="Reference file used to determine the proper FPS and to transfer audio from it. Specify .*-extension to find the first file.")
    p.add_argument('--ext', dest="ext", default='png', help="Image format (extension) of input files.")
    p.add_argument('--fps', type=int, dest="fps", default=None, help="FPS of output file. Overwritten by reference-file.")
    p.add_argument('--bitrate', type=int, dest="bitrate", default=None, help="Bitrate of output file in Megabits.")
    p.add_argument('--lossless', action="store_true", dest="lossless", default=False, help="PNG codec.")
    p.set_defaults(func=process_videoed_video_from_sequence)

    def process_labelingtool(arguments):
        from mainscripts import LabelingTool
        LabelingTool.main (arguments.input_dir, arguments.output_dir)

    p = subparsers.add_parser( "labelingtool", help="Labeling tool.")
    p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.")
    p.add_argument('--output-dir', required=True, action=fixPathAction, dest="output_dir", help="Output directory. This is where the labeled faces will be stored.")
    p.set_defaults(func=process_labelingtool)

    def bad_args(arguments):
        parser.print_help()
        exit(0)
    parser.set_defaults(func=bad_args)

    arguments = parser.parse_args()
    #os.environ['force_plaidML'] = '1'
    arguments.func(arguments)

    print ("Done.")

    """
    Suppressing error with keras 2.2.4+ on python exit:

        Exception ignored in: <bound method BaseSession._Callable.__del__ of <tensorflow.python.client.session.BaseSession._Callable object at 0x000000001BDEA9B0>>
        Traceback (most recent call last):
          File "D:\DeepFaceLab\_internal\bin\lib\site-packages\tensorflow\python\client\session.py", line 1413, in __del__
        AttributeError: 'NoneType' object has no attribute 'raise_exception_on_not_ok_status'

    reproduce: https://github.com/keras-team/keras/issues/11751 ( still no solution )
    """
    outnull_file = open(os.devnull, 'w')
    os.dup2 ( outnull_file.fileno(), sys.stderr.fileno() )
    sys.stderr = outnull_file

    '''
    import code
    code.interact(local=dict(globals(), **locals()))
    '''
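A note on the dispatch pattern used throughout main.py: each subparser binds its handler through set_defaults(func=...), and a single arguments.func(arguments) call routes to it. A self-contained toy version (hypothetical subcommand, not from this commit):

    import argparse

    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers()

    def process_hello(arguments):
        print ("hello,", arguments.name)

    p = subparsers.add_parser("hello")
    p.add_argument('--name', default="world")
    p.set_defaults(func=process_hello)

    arguments = parser.parse_args(["hello", "--name", "DeepFaceLab"])
    arguments.func(arguments)    # -> hello, DeepFaceLab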

View file

@@ -18,39 +18,39 @@ from interact import interact as io

class ConvertSubprocessor(Subprocessor):
    class Cli(Subprocessor.Cli):

        #override
        def on_initialize(self, client_dict):
            io.log_info ('Running on %s.' % (client_dict['device_name']) )
            self.device_idx  = client_dict['device_idx']
            self.device_name = client_dict['device_name']
            self.converter   = client_dict['converter']
            self.output_path = Path(client_dict['output_dir']) if 'output_dir' in client_dict.keys() else None
            self.alignments  = client_dict['alignments']
            self.debug       = client_dict['debug']

            #transfer and set stdin in order to make code.interact work in the debug subprocess
            stdin_fd = client_dict['stdin_fd']
            if stdin_fd is not None:
                sys.stdin = os.fdopen(stdin_fd)

            from nnlib import nnlib
            #the model process ate all GPU mem,
            #so we cannot use the GPU for any TF operations in converter processes;
            #therefore forcing active_DeviceConfig to CPU only
            nnlib.active_DeviceConfig = nnlib.DeviceConfig (cpu_only=True)

            return None

        #override
        def process_data(self, data):
            filename_path = Path(data)
            files_processed = 1
            faces_processed = 0

            output_filename_path = self.output_path / (filename_path.stem + '.png')

            if self.converter.type == Converter.TYPE_FACE and filename_path.stem not in self.alignments.keys():
                if not self.debug:
                    self.log_info ( 'no faces found for %s, copying without faces' % (filename_path.name) )
                    shutil.copy ( str(filename_path), str(output_filename_path) )
@@ -72,12 +72,12 @@ class ConvertSubprocessor(Subprocessor):
                    dflimg = DFLJPG.load ( str(filename_path) )
                else:
                    dflimg = None

                if dflimg is not None:
                    image_landmarks = dflimg.get_landmarks()

                    image = self.converter.convert_image(image, image_landmarks, self.debug)

                    if self.debug:
                        raise NotImplementedError
                        #for img in image:
@@ -85,14 +85,14 @@ class ConvertSubprocessor(Subprocessor):
                        #    cv2.waitKey(0)
                    faces_processed = 1
                else:
                    self.log_err ("%s is not a dfl image file" % (filename_path.name) )

            elif self.converter.type == Converter.TYPE_FACE:
                faces = self.alignments[filename_path.stem]

                if self.debug:
                    debug_images = []

                for face_num, image_landmarks in enumerate(faces):
                    try:
                        if self.debug:
@@ -101,56 +101,56 @@ class ConvertSubprocessor(Subprocessor):
                        if self.debug:
                            debug_images += self.converter.convert_face(image, image_landmarks, self.debug)
                        else:
                            image = self.converter.convert_face(image, image_landmarks, self.debug)
                    except Exception as e:
                        e_str = traceback.format_exc()
                        if 'MemoryError' in e_str:
                            raise Subprocessor.SilenceException
                        else:
                            raise Exception( 'Error while converting face_num [%d] in file [%s]: %s' % (face_num, filename_path, e_str) )

                if self.debug:
                    return (1, debug_images)

                faces_processed = len(faces)

            if not self.debug:
                cv2_imwrite (str(output_filename_path), (image*255).astype(np.uint8) )

            return (0, files_processed, faces_processed)

        #overridable
        def get_data_name (self, data):
            #return string identifier of your data
            return data

    #override
    def __init__(self, converter, input_path_image_paths, output_path, alignments, debug = False):
        super().__init__('Converter', ConvertSubprocessor.Cli, 86400 if debug == True else 60)

        self.converter = converter
        self.host_processor, self.cli_func = SubprocessFunctionCaller.make_pair ( self.converter.predictor_func )
        self.process_converter = self.converter.copy_and_set_predictor(self.cli_func)

        self.input_data = self.input_path_image_paths = input_path_image_paths
        self.output_path = output_path
        self.alignments = alignments
        self.debug = debug

        self.files_processed = 0
        self.faces_processed = 0

    #override
    def process_info_generator(self):
        r = [0] if self.debug else range(multiprocessing.cpu_count())

        for i in r:
            yield 'CPU%d' % (i), {}, {'device_idx': i,
                                      'device_name': 'CPU%d' % (i),
                                      'converter' : self.process_converter,
                                      'output_dir' : str(self.output_path),
                                      'alignments' : self.alignments,
                                      'debug': self.debug,
                                      'stdin_fd': sys.stdin.fileno() if self.debug else None
@@ -160,25 +160,25 @@ class ConvertSubprocessor(Subprocessor):
    def on_clients_initialized(self):
        if self.debug:
            io.named_window ("Debug convert")

        io.progress_bar ("Converting", len (self.input_data) )

    #overridable optional
    def on_clients_finalized(self):
        io.progress_bar_close()

        if self.debug:
            io.destroy_all_windows()

    #override
    def get_data(self, host_dict):
        if len (self.input_data) > 0:
            return self.input_data.pop(0)

        return None

    #override
    def on_data_return (self, host_dict, data):
        self.input_data.insert(0, data)

    #override
    def on_result (self, host_dict, data, result):
@@ -190,25 +190,25 @@ class ConvertSubprocessor(Subprocessor):
                io.show_image ('Debug convert', (img*255).astype(np.uint8) )
                io.wait_any_key()

        io.progress_bar_inc(1)

    #override
    def on_tick(self):
        self.host_processor.process_messages()

    #override
    def get_result(self):
        return self.files_processed, self.faces_processed

def main (args, device_args):
    io.log_info ("Running converter.\r\n")

    aligned_dir = args.get('aligned_dir', None)

    try:
        input_path = Path(args['input_dir'])
        output_path = Path(args['output_dir'])
        model_path = Path(args['model_dir'])

        if not input_path.exists():
            io.log_err('Input directory not found. Please ensure it exists.')
            return
@@ -218,69 +218,69 @@ def main (args, device_args):
                    Path(filename).unlink()
        else:
            output_path.mkdir(parents=True, exist_ok=True)

        if not model_path.exists():
            io.log_err('Model directory not found. Please ensure it exists.')
            return

        import models
        model = models.import_model( args['model_name'] )(model_path, device_args=device_args)
        converter = model.get_converter()
        converter.dummy_predict()

        alignments = None

        if converter.type == Converter.TYPE_FACE:
            if aligned_dir is None:
                io.log_err('Aligned directory not found. Please ensure it exists.')
                return

            aligned_path = Path(aligned_dir)
            if not aligned_path.exists():
                io.log_err('Aligned directory not found. Please ensure it exists.')
                return

            alignments = {}

            aligned_path_image_paths = Path_utils.get_image_paths(aligned_path)
            for filepath in io.progress_bar_generator(aligned_path_image_paths, "Collecting alignments"):
                filepath = Path(filepath)

                if filepath.suffix == '.png':
                    dflimg = DFLPNG.load( str(filepath) )
                elif filepath.suffix == '.jpg':
                    dflimg = DFLJPG.load ( str(filepath) )
                else:
                    dflimg = None

                if dflimg is None:
                    io.log_err ("%s is not a dfl image file" % (filepath.name) )
                    continue

                source_filename_stem = Path( dflimg.get_source_filename() ).stem
                if source_filename_stem not in alignments.keys():
                    alignments[ source_filename_stem ] = []

                alignments[ source_filename_stem ].append (dflimg.get_source_landmarks())

        files_processed, faces_processed = ConvertSubprocessor (
                    converter              = converter,
                    input_path_image_paths = Path_utils.get_image_paths(input_path),
                    output_path            = output_path,
                    alignments             = alignments,
                    debug                  = args.get('debug',False)
                ).run()

        model.finalize()

    except Exception as e:
        print ( 'Error: %s' % (str(e)))
        traceback.print_exc()

'''
if model_name == 'AVATAR':
    output_path_image_paths = Path_utils.get_image_paths(output_path)

    last_ok_frame = -1
    for filename in output_path_image_paths:
        filename_path = Path(filename)
@@ -289,15 +289,15 @@ if model_name == 'AVATAR':
            frame = int(stem)
        except:
            raise Exception ('Aligned avatars must be created from indexed sequence files.')

        if frame-last_ok_frame > 1:
            start = last_ok_frame + 1
            end = frame - 1

            print ("Filling gaps: [%d...%d]" % (start, end) )
            for i in range (start, end+1):
                shutil.copy ( str(filename), str( output_path / ('%.5d%s' % (i, filename_path.suffix )) ) )

        last_ok_frame = frame
'''
#interpolate landmarks
@@ -306,28 +306,28 @@ if model_name == 'AVATAR':
#a = sorted(alignments.keys())
#a_len = len(a)
#
#box_pts = 3
#box = np.ones(box_pts)/box_pts
#for i in range( a_len ):
#    if i >= box_pts and i <= a_len-box_pts-1:
#        af0 = alignments[ a[i] ][0] ##first face
#        m0 = LandmarksProcessor.get_transform_mat (af0, 256, face_type=FaceType.FULL)
#
#        points = []
#
#        for j in range(-box_pts, box_pts+1):
#            af = alignments[ a[i+j] ][0] ##first face
#            m = LandmarksProcessor.get_transform_mat (af, 256, face_type=FaceType.FULL)
#            p = LandmarksProcessor.transform_points (af, m)
#            points.append (p)
#
#        points = np.array(points)
#        points_len = len(points)
#        t_points = np.transpose(points, [1,0,2])
#
#        p1 = np.array ( [ int(np.convolve(x[:,0], box, mode='same')[points_len//2]) for x in t_points ] )
#        p2 = np.array ( [ int(np.convolve(x[:,1], box, mode='same')[points_len//2]) for x in t_points ] )
#
#        new_points = np.concatenate( [np.expand_dims(p1,-1),np.expand_dims(p2,-1)], -1 )
#
#        alignments[ a[i] ][0] = LandmarksProcessor.transform_points (new_points, m0, True).astype(np.int32)

View file

@@ -18,9 +18,9 @@ from facelib import LandmarksProcessor
from nnlib import nnlib
from joblib import Subprocessor
from interact import interact as io

class ExtractSubprocessor(Subprocessor):
    class Cli(Subprocessor.Cli):

        #override
@@ -32,19 +32,19 @@ class ExtractSubprocessor(Subprocessor):
            self.face_type = client_dict['face_type']
            self.device_idx = client_dict['device_idx']
            self.cpu_only = client_dict['device_type'] == 'CPU'
            self.output_path = Path(client_dict['output_dir']) if 'output_dir' in client_dict.keys() else None
            self.debug_dir = client_dict['debug_dir']
            self.detector = client_dict['detector']

            self.cached_image = (None, None)

            self.e = None
            device_config = nnlib.DeviceConfig ( cpu_only=self.cpu_only, force_gpu_idx=self.device_idx, allow_growth=True)
            if self.type == 'rects':
                if self.detector is not None:
                    if self.detector == 'mt':
                        nnlib.import_all (device_config)
                        self.e = facelib.MTCExtractor()
                    elif self.detector == 'dlib':
                        nnlib.import_dlib (device_config)
                        self.e = facelib.DLIBExtractor(nnlib.dlib)
@@ -53,10 +53,10 @@ class ExtractSubprocessor(Subprocessor):
                        self.e = facelib.S3FDExtractor()
                    else:
                        raise ValueError ("Wrong detector type.")

                    if self.e is not None:
                        self.e.__enter__()

            elif self.type == 'landmarks':
                nnlib.import_all (device_config)
                self.e = facelib.LandmarksExtractor(nnlib.keras)
@@ -66,15 +66,15 @@ class ExtractSubprocessor(Subprocessor):
                    self.second_pass_e.__enter__()
                else:
                    self.second_pass_e = None

            elif self.type == 'final':
                pass

        #override
        def on_finalize(self):
            if self.e is not None:
                self.e.__exit__()

        #override
        def process_data(self, data):
            filename_path = Path( data[0] )
@@ -84,64 +84,64 @@ class ExtractSubprocessor(Subprocessor):
                image = self.cached_image[1] #cached image for manual extractor
            else:
                image = cv2_imread( filename_path_str )

                if image is None:
                    self.log_err ( 'Failed to extract %s, reason: cv2_imread() fail.' % ( str(filename_path) ) )
                    return None

                image_shape = image.shape
                if len(image_shape) == 2:
                    h, w = image.shape
                    ch = 1
                else:
                    h, w, ch = image.shape

                if ch == 1:
                    image = np.repeat ( image [:,:,np.newaxis], 3, -1 )
                elif ch == 4:
                    image = image[:,:,0:3]

                wm = w % 2
                hm = h % 2
                if wm + hm != 0: #fix odd image
                    image = image[0:h-hm,0:w-wm,:]
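                    # both sides are truncated to even lengths, presumably so later
                    # resize/encode steps that assume even dimensions don't break;
                    # e.g. a 1279x719 frame becomes 1278x718 here.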
                self.cached_image = ( filename_path_str, image )

            src_dflimg = None
            h, w, ch = image.shape
            if h == w:
                #extracting from already extracted jpg image?
                if filename_path.suffix == '.jpg':
                    src_dflimg = DFLJPG.load ( str(filename_path) )

            if self.type == 'rects':
                if min(w,h) < 128:
                    self.log_err ( 'Image is too small %s : [%d, %d]' % ( str(filename_path), w, h ) )
                    rects = []
                else:
                    rects = self.e.extract_from_bgr (image)

                return [str(filename_path), rects]

            elif self.type == 'landmarks':
                rects = data[1]
                if rects is None:
                    landmarks = None
                else:
                    landmarks = self.e.extract_from_bgr (image, rects, self.second_pass_e if src_dflimg is None else None)

                return [str(filename_path), landmarks]
            elif self.type == 'final':
                result = []
                faces = data[1]

                if self.debug_dir is not None:
                    debug_output_file = str( Path(self.debug_dir) / (filename_path.stem+'.jpg') )
                    debug_image = image.copy()

                if src_dflimg is not None and len(faces) != 1:
                    #if re-extracting from a dflimg and zero or more than one face was detected - don't process, just copy it
                    print("src_dflimg is not None and len(faces) != 1", str(filename_path) )
@@ -151,26 +151,26 @@ class ExtractSubprocessor(Subprocessor):
                    result.append (output_file)
                else:
                    face_idx = 0
                    for face in faces:
                        rect = np.array(face[0])
                        image_landmarks = face[1]
                        if image_landmarks is None:
                            continue
                        image_landmarks = np.array(image_landmarks)

                        if self.face_type == FaceType.MARK_ONLY:
                            face_image = image
                            face_image_landmarks = image_landmarks
                        else:
                            image_to_face_mat = LandmarksProcessor.get_transform_mat (image_landmarks, self.image_size, self.face_type)
                            face_image = cv2.warpAffine(image, image_to_face_mat, (self.image_size, self.image_size), cv2.INTER_LANCZOS4)
                            face_image_landmarks = LandmarksProcessor.transform_points (image_landmarks, image_to_face_mat)

                            landmarks_bbox = LandmarksProcessor.transform_points ( [ (0,0), (0,self.image_size-1), (self.image_size-1, self.image_size-1), (self.image_size-1,0) ], image_to_face_mat, True)

                            rect_area      = mathlib.polygon_area(np.array(rect[[0,2,2,0]]), np.array(rect[[1,1,3,3]]))
                            landmarks_area = mathlib.polygon_area(landmarks_bbox[:,0], landmarks_bbox[:,1] )

                            if landmarks_area > 4*rect_area: #drop faces whose umeyama-landmark area is more than 4x the detector-rect area
                                continue
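                            # polygon_area is presumably the shoelace formula,
                            #   area = 0.5 * |sum(x_i*y_{i+1} - x_{i+1}*y_i)|,
                            # so this compares the face square warped back into image space
                            # against the detector rect and drops implausibly large fits.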
@@ -192,24 +192,24 @@ class ExtractSubprocessor(Subprocessor):
                                                source_rect=rect,
                                                source_landmarks=image_landmarks.tolist(),
                                                image_to_face_mat=image_to_face_mat
                                                )

                        result.append (output_file)
                        face_idx += 1

                if self.debug_dir is not None:
                    cv2_imwrite(debug_output_file, debug_image, [int(cv2.IMWRITE_JPEG_QUALITY), 50] )

                return result
        #overridable
        def get_data_name (self, data):
            #return string identifier of your data
            return data[0]

    #override
    def __init__(self, input_data, type, image_size, face_type, debug_dir, multi_gpu=False, cpu_only=False, manual=False, manual_window_size=0, detector=None, output_path=None):
        self.input_data = input_data
        self.type = type
        self.image_size = image_size
@@ -218,8 +218,8 @@ class ExtractSubprocessor(Subprocessor):
        self.multi_gpu = multi_gpu
        self.cpu_only = cpu_only
        self.detector = detector
        self.output_path = output_path
        self.manual = manual
        self.manual_window_size = manual_window_size
        self.result = []
@@ -233,32 +233,32 @@ class ExtractSubprocessor(Subprocessor):
            io.named_window(self.wnd_name)
            io.capture_mouse(self.wnd_name)
            io.capture_keys(self.wnd_name)

            self.cache_original_image = (None, None)
            self.cache_image = (None, None)
            self.cache_text_lines_img = (None, None)
            self.hide_help = False

            self.landmarks = None
            self.x = 0
            self.y = 0
            self.rect_size = 100
            self.rect_locked = False
            self.extract_needed = True

        io.progress_bar (None, len (self.input_data))

    #override
    def on_clients_finalized(self):
        if self.manual == True:
            io.destroy_all_windows()

        io.progress_bar_close()

    def get_devices_for_type (self, type, multi_gpu, cpu_only):
        if 'cpu' in nnlib.device.backend:
            cpu_only = True

        if not cpu_only and (type == 'rects' or type == 'landmarks'):
            if type == 'rects' and (self.detector == 'mt') and nnlib.device.backend == "plaidML":
                cpu_only = True
@@ -269,11 +269,11 @@ class ExtractSubprocessor(Subprocessor):
                devices = [nnlib.device.getBestValidDeviceIdx()]

            if len(devices) == 0:
                devices = [0]

            for idx in devices:
                dev_name = nnlib.device.getDeviceName(idx)
                dev_vram = nnlib.device.getDeviceVRAMTotalGb(idx)

                if not self.manual and ( self.type == 'rects' and self.detector != 's3fd' ):
                    for i in range ( int (max (1, dev_vram / 2) ) ):
                        yield (idx, 'GPU', '%s #%d' % (dev_name,i) , dev_vram)
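                    # roughly one worker per 2Gb of VRAM on the same device index;
                    # e.g. an 11Gb card yields int(max(1, 11/2)) == 5 'rects' workers.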
@@ -286,21 +286,21 @@ class ExtractSubprocessor(Subprocessor):
            else:
                for i in range( min(8, multiprocessing.cpu_count() // 2) ):
                    yield (i, 'CPU', 'CPU%d' % (i), 0 )

        if type == 'final':
            for i in range( min(8, multiprocessing.cpu_count()) ):
                yield (i, 'CPU', 'CPU%d' % (i), 0 )

    #override
    def process_info_generator(self):
        base_dict = {'type' : self.type,
                     'image_size': self.image_size,
                     'face_type': self.face_type,
                     'debug_dir': self.debug_dir,
                     'output_dir': str(self.output_path),
                     'detector': self.detector}

        for (device_idx, device_type, device_name, device_total_vram_gb) in self.get_devices_for_type(self.type, self.multi_gpu, self.cpu_only):
            client_dict = base_dict.copy()
            client_dict['device_idx'] = device_idx
            client_dict['device_name'] = device_name
@@ -311,7 +311,7 @@ class ExtractSubprocessor(Subprocessor):
    def get_data(self, host_dict):
        if not self.manual:
            if len (self.input_data) > 0:
                return self.input_data.pop(0)
        else:
            need_remark_face = False
@@ -327,7 +327,7 @@ class ExtractSubprocessor(Subprocessor):
                    self.rect, self.landmarks = faces.pop()
                    faces.clear()
                    redraw_needed = True
                    self.rect_locked = True
                    self.rect_size = ( self.rect[2] - self.rect[0] ) / 2
                    self.x = ( self.rect[0] + self.rect[2] ) / 2
                    self.y = ( self.rect[1] + self.rect[3] ) / 2
@@ -338,19 +338,19 @@ class ExtractSubprocessor(Subprocessor):
                else:
                    self.original_image = cv2_imread( filename )
                    self.cache_original_image = (filename, self.original_image )

                (h,w,c) = self.original_image.shape
                self.view_scale = 1.0 if self.manual_window_size == 0 else self.manual_window_size / ( h * (16.0/9.0) )

                if self.cache_image[0] == (h,w,c) + (self.view_scale,filename):
                    self.image = self.cache_image[1]
                else:
                    self.image = cv2.resize (self.original_image, ( int(w*self.view_scale), int(h*self.view_scale) ), interpolation=cv2.INTER_LINEAR)
                    self.cache_image = ( (h,w,c) + (self.view_scale,filename), self.image )

                (h,w,c) = self.image.shape

                sh = (0,0, w, min(100, h) )
                if self.cache_text_lines_img[0] == sh:
                    self.text_lines_img = self.cache_text_lines_img[1]
                else:
@@ -362,30 +362,30 @@ class ExtractSubprocessor(Subprocessor):
                            '[,] [.]- prev frame, next frame. [Q] - skip remaining frames',
                            '[h] - hide this help'
                            ], (1, 1, 1) )*255).astype(np.uint8)

                    self.cache_text_lines_img = (sh, self.text_lines_img)

                while True:
                    io.process_messages(0.0001)

                    new_x = self.x
                    new_y = self.y
                    new_rect_size = self.rect_size

                    mouse_events = io.get_mouse_events(self.wnd_name)
                    for ev in mouse_events:
                        (x, y, ev, flags) = ev
                        if ev == io.EVENT_MOUSEWHEEL and not self.rect_locked:
                            mod = 1 if flags > 0 else -1
                            diff = 1 if new_rect_size <= 40 else np.clip(new_rect_size / 10, 1, 10)
                            new_rect_size = max (5, new_rect_size + diff*mod)
                        elif ev == io.EVENT_LBUTTONDOWN:
                            self.rect_locked = not self.rect_locked
                            self.extract_needed = True
                        elif not self.rect_locked:
                            new_x = np.clip (x, 0, w-1) / self.view_scale
                            new_y = np.clip (y, 0, h-1) / self.view_scale

                    key_events = io.get_key_events(self.wnd_name)
                    key, = key_events[-1] if len(key_events) > 0 else (0,)
@@ -393,48 +393,48 @@ class ExtractSubprocessor(Subprocessor):
                        #confirm frame
                        is_frame_done = True
                        faces.append ( [(self.rect), self.landmarks] )
                        break
                    elif key == ord(' '):
                        #confirm skip frame
                        is_frame_done = True
                        break
                    elif key == ord(',') and len(self.result) > 0:
                        #go prev frame
                        if self.rect_locked:
                            # Only save the face if the rect is still locked
                            faces.append ( [(self.rect), self.landmarks] )

                        self.input_data.insert(0, self.result.pop() )
                        io.progress_bar_inc(-1)
                        need_remark_face = True

                        break
                    elif key == ord('.'):
                        #go next frame
                        if self.rect_locked:
                            # Only save the face if the rect is still locked
                            faces.append ( [(self.rect), self.landmarks] )

                        need_remark_face = True
                        is_frame_done = True
                        break
                    elif key == ord('q'):
                        #skip remaining
                        if self.rect_locked:
                            faces.append ( [(self.rect), self.landmarks] )
                        while len(self.input_data) > 0:
                            self.result.append( self.input_data.pop(0) )
                            io.progress_bar_inc(1)

                        break

                    elif key == ord('h'):
                        self.hide_help = not self.hide_help
                        break

                    if self.x != new_x or \
                       self.y != new_y or \
                       self.rect_size != new_rect_size or \
@@ -443,33 +443,33 @@ class ExtractSubprocessor(Subprocessor):
                        self.x = new_x
                        self.y = new_y
                        self.rect_size = new_rect_size
                        self.rect = ( int(self.x-self.rect_size),
                                      int(self.y-self.rect_size),
                                      int(self.x+self.rect_size),
                                      int(self.y+self.rect_size) )
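                        # the rect is stored as (center - half, ..., center + half);
                        # e.g. x=200, y=150, rect_size=100 gives (100, 50, 300, 250).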
                    if redraw_needed:
                        redraw_needed = False
                        return [filename, None]
                    else:
                        return [filename, [self.rect]]

                else:
                    is_frame_done = True

                if is_frame_done:
                    self.result.append ( data )
                    self.input_data.pop(0)
                    io.progress_bar_inc(1)
                    self.extract_needed = True
                    self.rect_locked = False

        return None
    #override
    def on_data_return (self, host_dict, data):
        if not self.manual:
            self.input_data.insert(0, data)

    #override
    def on_result (self, host_dict, data, result):
@@ -477,33 +477,33 @@ class ExtractSubprocessor(Subprocessor):
                filename, landmarks = result
                if landmarks is not None:
                    self.landmarks = landmarks[0][1]

                (h,w,c) = self.image.shape

                if not self.hide_help:
                    image = cv2.addWeighted (self.image,1.0,self.text_lines_img,1.0,0)
                else:
                    image = self.image.copy()

                view_rect = (np.array(self.rect) * self.view_scale).astype(np.int).tolist()
                view_landmarks = (np.array(self.landmarks) * self.view_scale).astype(np.int).tolist()

                if self.rect_size <= 40:
                    scaled_rect_size = h // 3 if w > h else w // 3

                    p1 = (self.x - self.rect_size, self.y - self.rect_size)
                    p2 = (self.x + self.rect_size, self.y - self.rect_size)
                    p3 = (self.x - self.rect_size, self.y + self.rect_size)

                    wh = h if h < w else w
                    np1 = (w / 2 - wh / 4, h / 2 - wh / 4)
                    np2 = (w / 2 + wh / 4, h / 2 - wh / 4)
                    np3 = (w / 2 - wh / 4, h / 2 + wh / 4)

                    mat = cv2.getAffineTransform( np.float32([p1,p2,p3])*self.view_scale, np.float32([np1,np2,np3]) )
                    image = cv2.warpAffine(image, mat,(w,h) )
                    view_landmarks = LandmarksProcessor.transform_points (view_landmarks, mat)

                landmarks_color = (255,255,0) if self.rect_locked else (0,255,0)
                LandmarksProcessor.draw_rect_landmarks (image, view_rect, view_landmarks, self.image_size, self.face_type, landmarks_color=landmarks_color)
                self.extract_needed = False
@@ -513,10 +513,10 @@ class ExtractSubprocessor(Subprocessor):
        if self.type == 'rects':
            self.result.append ( result )
        elif self.type == 'landmarks':
            self.result.append ( result )
        elif self.type == 'final':
            self.result += result

        io.progress_bar_inc(1)

    #override
@@ -530,47 +530,47 @@ class DeletedFilesSearcherSubprocessor(Subprocessor):
        def on_initialize(self, client_dict):
            self.debug_paths_stems = client_dict['debug_paths_stems']
            return None

        #override
        def process_data(self, data):
            input_path_stem = Path(data[0]).stem
            return any ( [ input_path_stem == d_stem for d_stem in self.debug_paths_stems] )
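            # stems are compared, not full paths: frame '00050.png' is treated as still
            # present only if some '00050.*' remains among the _debug images.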
        #override
        def get_data_name (self, data):
            #return string identifier of your data
            return data[0]
    #override
    def __init__(self, input_paths, debug_paths ):
        self.input_paths = input_paths
        self.debug_paths_stems = [ Path(d).stem for d in debug_paths]
        self.result = []
        super().__init__('DeletedFilesSearcherSubprocessor', DeletedFilesSearcherSubprocessor.Cli, 60)

    #override
    def process_info_generator(self):
        for i in range(min(multiprocessing.cpu_count(), 8)):
            yield 'CPU%d' % (i), {}, {'debug_paths_stems' : self.debug_paths_stems}

    #override
    def on_clients_initialized(self):
        io.progress_bar ("Searching deleted files", len (self.input_paths))

    #override
    def on_clients_finalized(self):
        io.progress_bar_close()

    #override
    def get_data(self, host_dict):
        if len (self.input_paths) > 0:
            return [self.input_paths.pop(0)]

        return None

    #override
    def on_data_return (self, host_dict, data):
        self.input_paths.insert(0, data[0])

    #override
    def on_result (self, host_dict, data, result):
        if result == False:
@@ -591,40 +591,40 @@ def main(input_dir,
         image_size=256,
         face_type='full_face',
         device_args={}):

    input_path = Path(input_dir)
    output_path = Path(output_dir)
    face_type = FaceType.fromString(face_type)

    multi_gpu = device_args.get('multi_gpu', False)
    cpu_only = device_args.get('cpu_only', False)

    if not input_path.exists():
        raise ValueError('Input directory not found. Please ensure it exists.')

    if output_path.exists():
        if not manual_output_debug_fix and input_path != output_path:
            for filename in Path_utils.get_image_paths(output_path):
                Path(filename).unlink()
    else:
        output_path.mkdir(parents=True, exist_ok=True)

    if manual_output_debug_fix:
        if debug_dir is None:
            raise ValueError('debug-dir must be specified')
        detector = 'manual'
        io.log_info('Re-extracting frames that were deleted from the _debug directory.')

    input_path_image_paths = Path_utils.get_image_unique_filestem_paths(input_path, verbose_print_func=io.log_info)
    if debug_dir is not None:
        debug_output_path = Path(debug_dir)

        if manual_output_debug_fix:
            if not debug_output_path.exists():
                raise ValueError("%s not found " % ( str(debug_output_path) ))

            input_path_image_paths = DeletedFilesSearcherSubprocessor (input_path_image_paths, Path_utils.get_image_paths(debug_output_path) ).run()
            input_path_image_paths = sorted (input_path_image_paths)
        else:
            if debug_output_path.exists():
                for filename in Path_utils.get_image_paths(debug_output_path):
@@ -634,20 +634,20 @@ def main(input_dir,
    images_found = len(input_path_image_paths)
    faces_detected = 0
    if images_found != 0:
        if detector == 'manual':
            io.log_info ('Performing manual extract...')
            extracted_faces = ExtractSubprocessor ([ (filename,[]) for filename in input_path_image_paths ], 'landmarks', image_size, face_type, debug_dir, cpu_only=cpu_only, manual=True, manual_window_size=manual_window_size).run()
        else:
            io.log_info ('Performing 1st pass...')
            extracted_rects = ExtractSubprocessor ([ (x,) for x in input_path_image_paths ], 'rects', image_size, face_type, debug_dir, multi_gpu=multi_gpu, cpu_only=cpu_only, manual=False, detector=detector).run()

            io.log_info ('Performing 2nd pass...')
            extracted_faces = ExtractSubprocessor (extracted_rects, 'landmarks', image_size, face_type, debug_dir, multi_gpu=multi_gpu, cpu_only=cpu_only, manual=False).run()

            if manual_fix:
                io.log_info ('Performing manual fix...')

                if all ( np.array ( [ len(data[1]) > 0 for data in extracted_faces] ) == True ):
                    io.log_info ('All faces are detected, manual fix not needed.')
                else:
@@ -657,8 +657,8 @@ def main(input_dir,
        io.log_info ('Performing 3rd pass...')
        final_imgs_paths = ExtractSubprocessor (extracted_faces, 'final', image_size, face_type, debug_dir, multi_gpu=multi_gpu, cpu_only=cpu_only, manual=False, output_path=output_path).run()
        faces_detected = len(final_imgs_paths)

    io.log_info ('-------------------------')
    io.log_info ('Images found: %d' % (images_found) )
    io.log_info ('Faces detected: %d' % (faces_detected) )
    io.log_info ('-------------------------')

View file

@@ -17,18 +17,18 @@ from facelib import LandmarksProcessor
def main(input_dir, output_dir):
    input_path = Path(input_dir)
    output_path = Path(output_dir)

    if not input_path.exists():
        raise ValueError('Input directory not found. Please ensure it exists.')

    if not output_path.exists():
        output_path.mkdir(parents=True)

    wnd_name = "Labeling tool"
    io.named_window (wnd_name)
    io.capture_mouse(wnd_name)
    io.capture_keys(wnd_name)

    #for filename in io.progress_bar_generator (Path_utils.get_image_paths(input_path), desc="Labeling"):
    for filename in Path_utils.get_image_paths(input_path):
        filepath = Path(filename)
@@ -39,165 +39,165 @@ def main(input_dir, output_dir):
            dflimg = DFLJPG.load ( str(filepath) )
        else:
            dflimg = None

        if dflimg is None:
            io.log_err ("%s is not a dfl image file" % (filepath.name) )
            continue

        lmrks = dflimg.get_landmarks()
        lmrks_list = lmrks.tolist()
        orig_img = cv2_imread(str(filepath))
        h,w,c = orig_img.shape

        mask_orig = LandmarksProcessor.get_image_hull_mask( orig_img.shape, lmrks).astype(np.uint8)[:,:,0]
        ero_dil_rate = w // 8
        mask_ero = cv2.erode (mask_orig, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(ero_dil_rate,ero_dil_rate)), iterations = 1 )
        mask_dil = cv2.dilate(mask_orig, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(ero_dil_rate,ero_dil_rate)), iterations = 1 )

        #mask_bg = np.zeros(orig_img.shape[:2],np.uint8)
        mask_bg = 1-mask_dil
        mask_bgp = np.ones(orig_img.shape[:2],np.uint8) #default - all background possible
        mask_fg = np.zeros(orig_img.shape[:2],np.uint8)
        mask_fgp = np.zeros(orig_img.shape[:2],np.uint8)

        img = orig_img.copy()

        l_thick=2
        def draw_4_lines (masks_out, pts, thickness=1):
            fgp,fg,bg,bgp = masks_out
            h,w = fg.shape

            fgp_pts = []
            fg_pts = np.array([ pts[i:i+2] for i in range(len(pts)-1)])
            bg_pts = []
            bgp_pts = []

            for i in range(len(fg_pts)):
                a, b = line = fg_pts[i]

                ba = b-a
                v = ba / npl.norm(ba)

                ccpv = np.array([v[1],-v[0]])
                cpv = np.array([-v[1],v[0]])
                step = 1 / max(np.abs(cpv))

                fgp_pts.append ( np.clip (line + ccpv * step * thickness, 0, w-1 ).astype(np.int) )
                bg_pts.append ( np.clip (line + cpv * step * thickness, 0, w-1 ).astype(np.int) )
                bgp_pts.append ( np.clip (line + cpv * step * thickness * 2, 0, w-1 ).astype(np.int) )

            fgp_pts = np.array(fgp_pts)
            bg_pts = np.array(bg_pts)
            bgp_pts = np.array(bgp_pts)

            cv2.polylines(fgp, fgp_pts, False, (1,), thickness=thickness)
            cv2.polylines(fg, fg_pts, False, (1,), thickness=thickness)
            cv2.polylines(bg, bg_pts, False, (1,), thickness=thickness)
            cv2.polylines(bgp, bgp_pts, False, (1,), thickness=thickness)
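        # a sketch of the geometry above: for each landmark segment a-b, v is the unit
        # direction and cpv/ccpv are its two perpendicular normals; step scales the
        # offset so each shifted polyline lands about 'thickness' pixels to one side,
        # producing parallel probable-fg, fg, bg and probable-bg bands along the contour.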
        def draw_lines ( masks_steps, pts, thickness=1):
            lines = np.array([ pts[i:i+2] for i in range(len(pts)-1)])

            for mask, step in masks_steps:
                h,w = mask.shape

                mask_lines = []
                for i in range(len(lines)):
                    a, b = line = lines[i]
                    ba = b-a
                    ba_len = npl.norm(ba)
                    if ba_len != 0:
                        v = ba / ba_len
                        pv = np.array([-v[1],v[0]])
                        pv_inv_max = 1 / max(np.abs(pv))
                        mask_lines.append ( np.clip (line + pv * pv_inv_max * thickness * step, 0, w-1 ).astype(np.int) )
                    else:
                        mask_lines.append ( np.array(line, dtype=np.int) )
                cv2.polylines(mask, mask_lines, False, (1,), thickness=thickness)
        def draw_fill_convex( mask_out, pts, scale=1.0 ):
            hull = cv2.convexHull(np.array(pts))

            if scale !=1.0:
                pts_count = hull.shape[0]

                sum_x = np.sum(hull[:, 0, 0])
                sum_y = np.sum(hull[:, 0, 1])

                hull_center = np.array([sum_x/pts_count, sum_y/pts_count])

                hull = hull_center+(hull-hull_center)*scale
                hull = hull.astype(pts.dtype)
            cv2.fillConvexPoly( mask_out, hull, (1,) )
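        # scale != 1.0 shrinks or grows the hull about its centroid; e.g. the scale=0.5
        # call below keeps the nose hull's shape at half size, centered on the hull mean.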
        def get_gc_mask_bgr(gc_mask):
            h, w = gc_mask.shape
            bgr = np.zeros( (h,w,3), dtype=np.uint8 )

            bgr [ gc_mask == 0 ] = (0,0,0)
            bgr [ gc_mask == 1 ] = (255,255,255)
            bgr [ gc_mask == 2 ] = (0,0,255) #RED
            bgr [ gc_mask == 3 ] = (0,255,0) #GREEN
            return bgr

        def get_gc_mask_result(gc_mask):
            return np.where((gc_mask==1) + (gc_mask==3),1,0).astype(np.int)
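        # values 1 and 3 are grabCut's definite- and probable-foreground labels, so the
        # result mask is 1 wherever grabCut decided the pixel belongs to the face.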
        #convex inner of right chin to end of right eyebrow
        #draw_fill_convex ( mask_fgp, lmrks_list[8:17]+lmrks_list[26:27] )

        #convex inner of start right chin to right eyebrow
        #draw_fill_convex ( mask_fgp, lmrks_list[8:9]+lmrks_list[22:27] )

        #convex inner of nose
        draw_fill_convex ( mask_fgp, lmrks[27:36] )

        #convex inner of nose half
        draw_fill_convex ( mask_fg, lmrks[27:36], scale=0.5 )

        #left corner of mouth to left corner of nose
        #draw_lines ( [ (mask_fg,0), ], lmrks_list[49:50]+lmrks_list[32:33], l_thick)

        #convex inner: right corner of nose to centers of eyebrows
        #draw_fill_convex ( mask_fgp, lmrks_list[35:36]+lmrks_list[19:20]+lmrks_list[24:25])

        #right corner of mouth to right corner of nose
        #draw_lines ( [ (mask_fg,0), ], lmrks_list[54:55]+lmrks_list[35:36], l_thick)

        #left eye
        #draw_fill_convex ( mask_fg, lmrks_list[36:40] )
        #right eye
        #draw_fill_convex ( mask_fg, lmrks_list[42:48] )

        #right chin
        draw_lines ( [ (mask_bg,0), (mask_fg,-1), ], lmrks[8:17], l_thick)

        #left eyebrow center to right eyebrow center
        draw_lines ( [ (mask_bg,-1), (mask_fg,0), ], lmrks_list[19:20] + lmrks_list[24:25], l_thick)
        # #draw_lines ( [ (mask_bg,-1), (mask_fg,0), ], lmrks_list[24:25] + lmrks_list[19:17:-1], l_thick)

        #half right eyebrow to end of right chin
        draw_lines ( [ (mask_bg,-1), (mask_fg,0), ], lmrks_list[24:27] + lmrks_list[16:17], l_thick)

        #import code
        #code.interact(local=dict(globals(), **locals()))
        #compose mask layers
        gc_mask = np.zeros(orig_img.shape[:2],np.uint8)
        gc_mask [ mask_bgp==1 ] = 2
        gc_mask [ mask_fgp==1 ] = 3
        gc_mask [ mask_bg==1 ] = 0
        gc_mask [ mask_fg==1 ] = 1
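        # layer order matters here (later assignments override earlier ones), and the
        # values follow cv2.grabCut's convention: 0=GC_BGD, 1=GC_FGD, 2=GC_PR_BGD, 3=GC_PR_FGD.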
        gc_bgr_before = get_gc_mask_bgr (gc_mask)

        #io.show_image (wnd_name, gc_mask )

        ##points, hierarchy = cv2.findContours(original_mask,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
        ##gc_mask = ( (1-erode_mask)*2 + erode_mask )# * dilate_mask
        #gc_mask = (1-erode_mask)*2 + erode_mask
@@ -211,34 +211,34 @@ def main(input_dir, output_dir):
        #
        #
        cv2.grabCut(img,gc_mask,None,np.zeros((1,65),np.float64),np.zeros((1,65),np.float64),5, cv2.GC_INIT_WITH_MASK)
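        # cv2.grabCut(img, mask, rect, bgdModel, fgdModel, iterCount, mode): rect is None
        # because GC_INIT_WITH_MASK seeds from gc_mask instead; the two zeroed (1,65)
        # float64 arrays are the internal GMM models and 5 is the iteration count.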
        gc_bgr = get_gc_mask_bgr (gc_mask)
        gc_mask_result = get_gc_mask_result(gc_mask)
        gc_mask_result_1 = gc_mask_result[:,:,np.newaxis]

        #import code
        #code.interact(local=dict(globals(), **locals()))
        orig_img_gc_layers_masked = (0.5*orig_img + 0.5*gc_bgr).astype(np.uint8)
        orig_img_gc_before_layers_masked = (0.5*orig_img + 0.5*gc_bgr_before).astype(np.uint8)

        pink_bg = np.full ( orig_img.shape, (255,0,255), dtype=np.uint8 )

        orig_img_result = orig_img * gc_mask_result_1
        orig_img_result_pinked = orig_img_result + pink_bg * (1-gc_mask_result_1)

        #io.show_image (wnd_name, blended_img)

        ##gc_mask, bgdModel, fgdModel =
        #
        #mask2 = np.where((gc_mask==1) + (gc_mask==3),255,0).astype('uint8')[:,:,np.newaxis]
        #mask2 = np.repeat(mask2, (3,), -1)
        #
        ##mask2 = np.where(gc_mask!=0,255,0).astype('uint8')

        #blended_img = orig_img #-\
        #             #0.3 * np.full(original_img.shape, (50,50,50)) * (1-mask_0_27)[:,:,np.newaxis]
        #             #0.3 * np.full(original_img.shape, (50,50,50)) * (1-dilate_mask)[:,:,np.newaxis] +\
        #             #0.3 * np.full(original_img.shape, (50,50,50)) * (erode_mask)[:,:,np.newaxis]
        #blended_img = np.clip(blended_img, 0, 255).astype(np.uint8)
@@ -246,25 +246,25 @@ def main(input_dir, output_dir):
        ##code.interact(local=dict(globals(), **locals()))

        orig_img_lmrked = orig_img.copy()
        LandmarksProcessor.draw_landmarks(orig_img_lmrked, lmrks, transparent_mask=True)

        screen = np.concatenate ([orig_img_gc_before_layers_masked,
                                  orig_img_gc_layers_masked,
                                  orig_img,
                                  orig_img_lmrked,
                                  orig_img_result_pinked,
                                  orig_img_result,
                                  ], axis=1)

        io.show_image (wnd_name, screen.astype(np.uint8) )

        while True:
            io.process_messages()

            for (x,y,ev,flags) in io.get_mouse_events(wnd_name):
                pass
                #print (x,y,ev,flags)

            key_events = [ ev for ev, in io.get_key_events(wnd_name) ]
            for key in key_events:
                if key == ord('1'):
@@ -273,15 +273,15 @@ def main(input_dir, output_dir):
                    pass
                if key == ord('3'):
                    pass

            if ord(' ') in key_events:
                break

    import code
    code.interact(local=dict(globals(), **locals()))

    #original_mask = np.ones(original_img.shape[:2],np.uint8)*2
    #cv2.drawContours(original_mask, points, -1, (1,), 1)

View file

@@ -15,10 +15,10 @@ from joblib import Subprocessor
import multiprocessing
from interact import interact as io
from imagelib import estimate_sharpness

class BlurEstimatorSubprocessor(Subprocessor):
    class Cli(Subprocessor.Cli):

        #override
        def on_initialize(self, client_dict):
            self.log_info('Running on %s.' % (client_dict['device_name']) )
@@ -26,58 +26,58 @@ class BlurEstimatorSubprocessor(Subprocessor):
        #override
        def process_data(self, data):
            filepath = Path( data[0] )

            if filepath.suffix == '.png':
                dflimg = DFLPNG.load( str(filepath) )
            elif filepath.suffix == '.jpg':
                dflimg = DFLJPG.load ( str(filepath) )
            else:
                dflimg = None

            if dflimg is not None:
                image = cv2_imread( str(filepath) )
                return [ str(filepath), estimate_sharpness(image) ]
            else:
                self.log_err ("%s is not a dfl image file" % (filepath.name) )
                return [ str(filepath), 0 ]

        #override
        def get_data_name (self, data):
            #return string identifier of your data
            return data[0]
    #override
    def __init__(self, input_data ):
        self.input_data = input_data
        self.img_list = []
        self.trash_img_list = []
        super().__init__('BlurEstimator', BlurEstimatorSubprocessor.Cli, 60)

    #override
    def on_clients_initialized(self):
        io.progress_bar ("", len (self.input_data))

    #override
    def on_clients_finalized(self):
        io.progress_bar_close ()

    #override
    def process_info_generator(self):
        for i in range(0, multiprocessing.cpu_count() ):
            yield 'CPU%d' % (i), {}, {'device_idx': i,
                                      'device_name': 'CPU%d' % (i),
                                      }

    #override
    def get_data(self, host_dict):
        if len (self.input_data) > 0:
            return self.input_data.pop(0)

        return None

    #override
    def on_data_return (self, host_dict, data):
        self.input_data.insert(0, data)

    #override
    def on_result (self, host_dict, data, result):
@@ -85,20 +85,20 @@ class BlurEstimatorSubprocessor(Subprocessor):
            self.trash_img_list.append ( result )
        else:
            self.img_list.append ( result )

        io.progress_bar_inc(1)

    #override
    def get_result(self):
        return self.img_list, self.trash_img_list

def sort_by_blur(input_path):
    io.log_info ("Sorting by blur...")

    img_list = [ (filename,[]) for filename in Path_utils.get_image_paths(input_path) ]
    img_list, trash_img_list = BlurEstimatorSubprocessor (img_list).run()

    io.log_info ("Sorting...")
    img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
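    # itemgetter(1) sorts on the estimated sharpness score, and reverse=True puts the
    # sharpest faces first; e.g. [('a.jpg', 12.1), ('b.jpg', 3.4)] keeps that order.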
@@ -111,21 +111,21 @@ def sort_by_face(input_path):
    trash_img_list = []
    for filepath in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Loading"):
        filepath = Path(filepath)

        if filepath.suffix == '.png':
            dflimg = DFLPNG.load( str(filepath) )
        elif filepath.suffix == '.jpg':
            dflimg = DFLJPG.load ( str(filepath) )
        else:
            dflimg = None

        if dflimg is None:
            io.log_err ("%s is not a dfl image file" % (filepath.name) )
            trash_img_list.append ( [str(filepath)] )
            continue

        img_list.append( [str(filepath), dflimg.get_landmarks()] )

    img_list_len = len(img_list)
    for i in io.progress_bar_generator ( range(0, img_list_len-1), "Sorting"):
@@ -152,21 +152,21 @@ def sort_by_face_dissim(input_path):
    trash_img_list = []
    for filepath in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Loading"):
        filepath = Path(filepath)

        if filepath.suffix == '.png':
            dflimg = DFLPNG.load( str(filepath) )
        elif filepath.suffix == '.jpg':
            dflimg = DFLJPG.load ( str(filepath) )
        else:
            dflimg = None

        if dflimg is None:
            io.log_err ("%s is not a dfl image file" % (filepath.name) )
            trash_img_list.append ( [str(filepath)] )
            continue

        img_list.append( [str(filepath), dflimg.get_landmarks(), 0 ] )

    img_list_len = len(img_list)
    for i in io.progress_bar_generator( range(img_list_len-1), "Sorting"):
        score_total = 0
@ -183,79 +183,79 @@ def sort_by_face_dissim(input_path):
img_list = sorted(img_list, key=operator.itemgetter(2), reverse=True) img_list = sorted(img_list, key=operator.itemgetter(2), reverse=True)
return img_list, trash_img_list return img_list, trash_img_list
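The "Sorting" pass elided by the hunk above scores each image by its total landmark distance to every other image, so unusual poses and misaligned faces rise to the top. A minimal sketch of that score, assuming each entry carries a (68, 2) landmark array:

import numpy as np

def dissim_score(landmarks_i, all_landmarks):
    # summed L1 distance over all 68 points against every other image
    return sum(np.sum(np.absolute(landmarks_i - lm_j)) for lm_j in all_landmarks)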
def sort_by_face_yaw(input_path):
    io.log_info ("Sorting by face yaw...")
    img_list = []
    trash_img_list = []
    for filepath in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Loading"):
        filepath = Path(filepath)

        if filepath.suffix == '.png':
            dflimg = DFLPNG.load( str(filepath) )
        elif filepath.suffix == '.jpg':
            dflimg = DFLJPG.load ( str(filepath) )
        else:
            dflimg = None

        if dflimg is None:
            io.log_err ("%s is not a dfl image file" % (filepath.name) )
            trash_img_list.append ( [str(filepath)] )
            continue

        pitch, yaw = LandmarksProcessor.estimate_pitch_yaw ( dflimg.get_landmarks() )

        img_list.append( [str(filepath), yaw ] )

    io.log_info ("Sorting...")
    img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)

    return img_list, trash_img_list
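LandmarksProcessor.estimate_pitch_yaw is not shown in this diff. One standard way to recover head pose from 68 2D landmarks, and a plausible reading of what it does, is to fit a generic 3D face model with solvePnP; everything below (the 3D model, the camera matrix, reading angles straight off the Rodrigues vector) is an assumption for illustration only:

import cv2
import numpy as np

def estimate_pitch_yaw_sketch(landmarks_2d, landmarks_68_3d, img_size=256):
    # landmarks_68_3d: assumed (68, 3) generic face model, not defined here
    camera_matrix = np.array([[img_size, 0,        img_size / 2],
                              [0,        img_size, img_size / 2],
                              [0,        0,        1]], dtype=np.float32)
    _, rvec, _ = cv2.solvePnP(landmarks_68_3d.astype(np.float32),
                              landmarks_2d.astype(np.float32),
                              camera_matrix, np.zeros((4, 1)))
    pitch, yaw, _ = rvec.flatten()  # crude small-angle reading; sign conventions vary
    return pitch, yaw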
def sort_by_face_pitch(input_path):
    io.log_info ("Sorting by face pitch...")
    img_list = []
    trash_img_list = []
    for filepath in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Loading"):
        filepath = Path(filepath)

        if filepath.suffix == '.png':
            dflimg = DFLPNG.load( str(filepath) )
        elif filepath.suffix == '.jpg':
            dflimg = DFLJPG.load ( str(filepath) )
        else:
            dflimg = None

        if dflimg is None:
            io.log_err ("%s is not a dfl image file" % (filepath.name) )
            trash_img_list.append ( [str(filepath)] )
            continue

        pitch, yaw = LandmarksProcessor.estimate_pitch_yaw ( dflimg.get_landmarks() )

        img_list.append( [str(filepath), pitch ] )

    io.log_info ("Sorting...")
    img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)

    return img_list, trash_img_list
class HistSsimSubprocessor(Subprocessor):
    class Cli(Subprocessor.Cli):
        #override
        def on_initialize(self, client_dict):
            self.log_info ('Running on %s.' % (client_dict['device_name']) )

        #override
        def process_data(self, data):
            img_list = []
            for x in data:
                img = cv2_imread(x)
                img_list.append ([x, cv2.calcHist([img], [0], None, [256], [0, 256]),
                                     cv2.calcHist([img], [1], None, [256], [0, 256]),
                                     cv2.calcHist([img], [2], None, [256], [0, 256])
                                 ])

            img_list_len = len(img_list)
            for i in range(img_list_len-1):
                min_score = float("inf")
@@ -268,23 +268,23 @@ class HistSsimSubprocessor(Subprocessor):
                        min_score = score
                        j_min_score = j
                img_list[i+1], img_list[j_min_score] = img_list[j_min_score], img_list[i+1]

                self.progress_bar_inc(1)

            return img_list

        #override
        def get_data_name (self, data):
            return "Bunch of images"

    #override
    def __init__(self, img_list ):
        self.img_list = img_list
        self.img_list_len = len(img_list)

        slice_count = 20000
        sliced_count = self.img_list_len // slice_count

        if sliced_count > 12:
            sliced_count = 11.9
            slice_count = int(self.img_list_len / sliced_count)
@@ -294,10 +294,10 @@ class HistSsimSubprocessor(Subprocessor):
                               [ self.img_list[sliced_count*slice_count:] ]

        self.result = []
        super().__init__('HistSsim', HistSsimSubprocessor.Cli, 0)

    #override
    def process_info_generator(self):
        for i in range( len(self.img_chunks_list) ):
            yield 'CPU%d' % (i), {'i':i}, {'device_idx': i,
                                           'device_name': 'CPU%d' % (i)
@@ -306,21 +306,21 @@ class HistSsimSubprocessor(Subprocessor):
    def on_clients_initialized(self):
        io.progress_bar ("Sorting", len(self.img_list))
        io.progress_bar_inc(len(self.img_chunks_list))

    #override
    def on_clients_finalized(self):
        io.progress_bar_close()

    #override
    def get_data(self, host_dict):
        if len (self.img_chunks_list) > 0:
            return self.img_chunks_list.pop(0)
        return None

    #override
    def on_data_return (self, host_dict, data):
        raise Exception("Failed to process data. Decrease the number of images and try again.")

    #override
    def on_result (self, host_dict, data, result):
        self.result += result
@@ -329,10 +329,10 @@ class HistSsimSubprocessor(Subprocessor):
    #override
    def get_result(self):
        return self.result

def sort_by_hist(input_path):
    io.log_info ("Sorting by histogram similarity...")
    img_list = HistSsimSubprocessor(Path_utils.get_image_paths(input_path)).run()
    return img_list
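A note on the ordering above: HistSsim builds a greedy nearest-neighbour chain over per-channel BGR histograms, so visually similar frames end up adjacent. The comparison itself sits in the elided hunk; a minimal sketch of the likely score, assuming the same Bhattacharyya metric used elsewhere in this file (entry[1..3] are the B, G, R histograms built in process_data):

import cv2

def hist_distance(entry_a, entry_b):
    # 0 for identical channel distributions, larger as they diverge;
    # each chain step picks the remaining entry with the smallest score.
    return sum(cv2.compareHist(entry_a[k], entry_b[k], cv2.HISTCMP_BHATTACHARYYA)
               for k in (1, 2, 3))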
class HistDissimSubprocessor(Subprocessor):
@@ -344,7 +344,7 @@ class HistDissimSubprocessor(Subprocessor):
            self.img_list_len = len(self.img_list)

        #override
        def process_data(self, data):
            i = data[0]
            score_total = 0
            for j in range( 0, self.img_list_len):
@@ -358,40 +358,40 @@ class HistDissimSubprocessor(Subprocessor):
        def get_data_name (self, data):
            #return a string identifier for your data
            return self.img_list[data[0]][0]

    #override
    def __init__(self, img_list ):
        self.img_list = img_list
        self.img_list_range = [i for i in range(0, len(img_list) )]
        self.result = []
        super().__init__('HistDissim', HistDissimSubprocessor.Cli, 60)

    #override
    def on_clients_initialized(self):
        io.progress_bar ("Sorting", len (self.img_list) )

    #override
    def on_clients_finalized(self):
        io.progress_bar_close()

    #override
    def process_info_generator(self):
        for i in range(0, min(multiprocessing.cpu_count(), 8) ):
            yield 'CPU%d' % (i), {}, {'device_idx': i,
                                      'device_name': 'CPU%d' % (i),
                                      'img_list' : self.img_list
                                      }

    #override
    def get_data(self, host_dict):
        if len (self.img_list_range) > 0:
            return [self.img_list_range.pop(0)]
        return None

    #override
    def on_data_return (self, host_dict, data):
        self.img_list_range.insert(0, data[0])

    #override
    def on_result (self, host_dict, data, result):
        self.img_list[data[0]][2] = result
@@ -400,7 +400,7 @@ class HistDissimSubprocessor(Subprocessor):
    #override
    def get_result(self):
        return self.img_list

def sort_by_hist_dissim(input_path):
    io.log_info ("Sorting by histogram dissimilarity...")
@@ -408,19 +408,19 @@ def sort_by_hist_dissim(input_path):
    trash_img_list = []
    for filepath in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Loading"):
        filepath = Path(filepath)

        if filepath.suffix == '.png':
            dflimg = DFLPNG.load( str(filepath) )
        elif filepath.suffix == '.jpg':
            dflimg = DFLJPG.load ( str(filepath) )
        else:
            dflimg = None

        if dflimg is None:
            io.log_err ("%s is not a dfl image file" % (filepath.name) )
            trash_img_list.append ([str(filepath)])
            continue

        image = cv2_imread(str(filepath))
        face_mask = LandmarksProcessor.get_image_hull_mask (image.shape, dflimg.get_landmarks())
        image = (image*face_mask).astype(np.uint8)
@@ -428,26 +428,26 @@ def sort_by_hist_dissim(input_path):
        img_list.append ([str(filepath), cv2.calcHist([cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)], [0], None, [256], [0, 256]), 0 ])

    img_list = HistDissimSubprocessor(img_list).run()

    io.log_info ("Sorting...")
    img_list = sorted(img_list, key=operator.itemgetter(2), reverse=True)

    return img_list, trash_img_list
def sort_by_brightness(input_path):
    io.log_info ("Sorting by brightness...")
    img_list = [ [x, np.mean ( cv2.cvtColor(cv2_imread(x), cv2.COLOR_BGR2HSV)[...,2].flatten() )] for x in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Loading") ]
    io.log_info ("Sorting...")
    img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
    return img_list

def sort_by_hue(input_path):
    io.log_info ("Sorting by hue...")
    img_list = [ [x, np.mean ( cv2.cvtColor(cv2_imread(x), cv2.COLOR_BGR2HSV)[...,0].flatten() )] for x in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Loading") ]
    io.log_info ("Sorting...")
    img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
    return img_list
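Both sorts reduce an image to a single scalar, the mean of one HSV channel. In OpenCV, hue is channel 0 (range 0..179) and value/brightness is channel 2 (range 0..255):

import cv2
import numpy as np

img = cv2.imread("face.jpg")                # hypothetical sample file
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
mean_hue        = np.mean(hsv[..., 0])      # 0..179
mean_brightness = np.mean(hsv[..., 2])      # 0..255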
def sort_by_black(input_path):
    io.log_info ("Sorting by amount of black pixels...")
@@ -460,22 +460,22 @@ def sort_by_black(input_path):
    img_list = sorted(img_list, key=operator.itemgetter(1), reverse=False)

    return img_list

def sort_by_origname(input_path):
    io.log_info ("Sort by original filename...")

    img_list = []
    trash_img_list = []
    for filepath in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Loading"):
        filepath = Path(filepath)

        if filepath.suffix == '.png':
            dflimg = DFLPNG.load( str(filepath) )
        elif filepath.suffix == '.jpg':
            dflimg = DFLJPG.load( str(filepath) )
        else:
            dflimg = None

        if dflimg is None:
            io.log_err ("%s is not a dfl image file" % (filepath.name) )
            trash_img_list.append( [str(filepath)] )
@@ -486,7 +486,7 @@ def sort_by_origname(input_path):
    io.log_info ("Sorting...")
    img_list = sorted(img_list, key=operator.itemgetter(1))
    return img_list, trash_img_list

def sort_by_oneface_in_image(input_path):
    io.log_info ("Sort by one face in images...")
    image_paths = Path_utils.get_image_paths(input_path)
@@ -503,17 +503,17 @@ def sort_by_oneface_in_image(input_path):
        trash_img_list = [ (image_paths[x],) for x in idxs ]
        return img_list, trash_img_list
    return [], []
class FinalLoaderSubprocessor(Subprocessor):
    class Cli(Subprocessor.Cli):
        #override
        def on_initialize(self, client_dict):
            self.log_info ('Running on %s.' % (client_dict['device_name']) )
            self.include_by_blur = client_dict['include_by_blur']

        #override
        def process_data(self, data):
            filepath = Path(data[0])

            try:
                if filepath.suffix == '.png':
@@ -522,40 +522,40 @@ class FinalLoaderSubprocessor(Subprocessor):
                    dflimg = DFLJPG.load( str(filepath) )
                else:
                    dflimg = None

                if dflimg is None:
                    self.log_err("%s is not a dfl image file" % (filepath.name))
                    return [ 1, [str(filepath)] ]

                bgr = cv2_imread(str(filepath))
                if bgr is None:
                    raise Exception ("Unable to load %s" % (filepath.name) )

                gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
                sharpness = estimate_sharpness(gray) if self.include_by_blur else 0
                pitch, yaw = LandmarksProcessor.estimate_pitch_yaw ( dflimg.get_landmarks() )

                hist = cv2.calcHist([gray], [0], None, [256], [0, 256])
            except Exception as e:
                self.log_err (e)
                return [ 1, [str(filepath)] ]

            return [ 0, [str(filepath), sharpness, hist, yaw ] ]

        #override
        def get_data_name (self, data):
            #return a string identifier for your data
            return data[0]

    #override
    def __init__(self, img_list, include_by_blur ):
        self.img_list = img_list

        self.include_by_blur = include_by_blur
        self.result = []
        self.result_trash = []

        super().__init__('FinalLoader', FinalLoaderSubprocessor.Cli, 60)

    #override
    def on_clients_initialized(self):
@@ -564,9 +564,9 @@ class FinalLoaderSubprocessor(Subprocessor):
    #override
    def on_clients_finalized(self):
        io.progress_bar_close()

    #override
    def process_info_generator(self):
        for i in range(0, min(multiprocessing.cpu_count(), 8) ):
            yield 'CPU%d' % (i), {}, {'device_idx': i,
                                      'device_name': 'CPU%d' % (i),
@@ -575,15 +575,15 @@ class FinalLoaderSubprocessor(Subprocessor):
    #override
    def get_data(self, host_dict):
        if len (self.img_list) > 0:
            return [self.img_list.pop(0)]
        return None

    #override
    def on_data_return (self, host_dict, data):
        self.img_list.insert(0, data[0])

    #override
    def on_result (self, host_dict, data, result):
        if result[0] == 0:
@@ -599,7 +599,7 @@ class FinalLoaderSubprocessor(Subprocessor):
class FinalHistDissimSubprocessor(Subprocessor):
    class Cli(Subprocessor.Cli):
        #override
        def on_initialize(self, client_dict):
            self.log_info ('Running on %s.' % (client_dict['device_name']) )

        #override
@@ -611,25 +611,25 @@ class FinalHistDissimSubprocessor(Subprocessor):
                    if i == j:
                        continue
                    score_total += cv2.compareHist(img_list[i][2], img_list[j][2], cv2.HISTCMP_BHATTACHARYYA)
                img_list[i][3] = score_total

            img_list = sorted(img_list, key=operator.itemgetter(3), reverse=True)
            return idx, img_list

        #override
        def get_data_name (self, data):
            return "Bunch of images"

    #override
    def __init__(self, yaws_sample_list ):
        self.yaws_sample_list = yaws_sample_list
        self.yaws_sample_list_len = len(yaws_sample_list)

        self.yaws_sample_list_idxs = [ i for i in range(self.yaws_sample_list_len) if self.yaws_sample_list[i] is not None ]
        self.result = [ None for _ in range(self.yaws_sample_list_len) ]
        super().__init__('FinalHistDissimSubprocessor', FinalHistDissimSubprocessor.Cli)

    #override
    def process_info_generator(self):
        for i in range(min(multiprocessing.cpu_count(), 8) ):
            yield 'CPU%d' % (i), {'i':i}, {'device_idx': i,
                                           'device_name': 'CPU%d' % (i)
@@ -637,38 +637,38 @@ class FinalHistDissimSubprocessor(Subprocessor):
    #override
    def on_clients_initialized(self):
        io.progress_bar ("Sort by hist-dissim", self.yaws_sample_list_len)

    #override
    def on_clients_finalized(self):
        io.progress_bar_close()

    #override
    def get_data(self, host_dict):
        if len (self.yaws_sample_list_idxs) > 0:
            idx = self.yaws_sample_list_idxs.pop(0)
            return idx, self.yaws_sample_list[idx]

        return None

    #override
    def on_data_return (self, host_dict, data):
        self.yaws_sample_list_idxs.insert(0, data[0])

    #override
    def on_result (self, host_dict, data, result):
        idx, yaws_sample_list = data
        self.result[idx] = yaws_sample_list
        io.progress_bar_inc(1)

    #override
    def get_result(self):
        return self.result
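The per-image score computed above is the summed Bhattacharyya distance between one grayscale histogram and every other histogram in the same yaw bin, so the most atypical images in each bin sort first. For reference, OpenCV's HISTCMP_BHATTACHARYYA computes approximately

    d(H1, H2) = sqrt( 1 - sum_i sqrt(H1(i) * H2(i)) / sqrt(mean(H1) * mean(H2) * N^2) )

which is 0 for identical histograms and approaches 1 as they diverge.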
def sort_final(input_path, include_by_blur=True):
    io.log_info ("Performing final sort.")

    target_count = io.input_int ("Target number of images? (default:2000) : ", 2000)

    img_list, trash_img_list = FinalLoaderSubprocessor( Path_utils.get_image_paths(input_path), include_by_blur ).run()
    final_img_list = []
@@ -676,12 +676,12 @@ def sort_final(input_path, include_by_blur=True):
    imgs_per_grad = round (target_count / grads)

    grads_space = np.linspace (-1.0,1.0,grads)

    yaws_sample_list = [None]*grads
    for g in io.progress_bar_generator ( range(grads), "Sort by yaw"):
        yaw = grads_space[g]
        next_yaw = grads_space[g+1] if g < grads-1 else yaw

        yaw_samples = []
        for img in img_list:
            s_yaw = -img[3]
@@ -691,17 +691,17 @@ def sort_final(input_path, include_by_blur=True):
                yaw_samples += [ img ]
        if len(yaw_samples) > 0:
            yaws_sample_list[g] = yaw_samples

    total_lack = 0
    for g in io.progress_bar_generator ( range(grads), ""):
        img_list = yaws_sample_list[g]
        img_list_len = len(img_list) if img_list is not None else 0

        lack = imgs_per_grad - img_list_len
        total_lack += max(lack, 0)

    imgs_per_grad += total_lack // grads

    if include_by_blur:
        sharpened_imgs_per_grad = imgs_per_grad*10
        for g in io.progress_bar_generator ( range (grads), "Sort by blur"):
@@ -709,47 +709,47 @@ def sort_final(input_path, include_by_blur=True):
            if img_list is None:
                continue

            img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)

            if len(img_list) > sharpened_imgs_per_grad:
                trash_img_list += img_list[sharpened_imgs_per_grad:]
                img_list = img_list[0:sharpened_imgs_per_grad]

            yaws_sample_list[g] = img_list

    yaws_sample_list = FinalHistDissimSubprocessor(yaws_sample_list).run()

    for g in io.progress_bar_generator ( range (grads), "Fetching best"):
        img_list = yaws_sample_list[g]
        if img_list is None:
            continue

        final_img_list += img_list[0:imgs_per_grad]
        trash_img_list += img_list[imgs_per_grad:]

    return final_img_list, trash_img_list
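A quick worked example of the yaw binning above, with an assumed grads = 128 (the actual bin count is set in the elided hunk):

import numpy as np

grads = 128                                  # assumed bin count
target_count = 2000                          # the default prompted above
grads_space = np.linspace(-1.0, 1.0, grads)  # bin centres, ~0.0157 apart
imgs_per_grad = round(target_count / grads)  # ~16 images per yaw bin to start

Bins that cannot fill their quota accumulate the shortfall in total_lack, which is then spread evenly (total_lack // grads) so better-populated bins contribute extra images.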
def final_process(input_path, img_list, trash_img_list):
    if len(trash_img_list) != 0:
        parent_input_path = input_path.parent
        trash_path = parent_input_path / (input_path.stem + '_trash')
        trash_path.mkdir (exist_ok=True)

        io.log_info ("Trashing %d items to %s" % ( len(trash_img_list), str(trash_path) ) )

        for filename in Path_utils.get_image_paths(trash_path):
            Path(filename).unlink()

        for i in io.progress_bar_generator( range(len(trash_img_list)), "Moving trash", leave=False):
            src = Path (trash_img_list[i][0])
            dst = trash_path / src.name
            try:
                src.rename (dst)
            except:
                io.log_info ('failed to trash %s' % (src.name) )

        io.log_info ("")

    if len(img_list) != 0:
        for i in io.progress_bar_generator( [*range(len(img_list))], "Renaming", leave=False):
            src = Path (img_list[i][0])
@@ -758,24 +758,24 @@ def final_process(input_path, img_list, trash_img_list):
                src.rename (dst)
            except:
                io.log_info ('failed to rename %s' % (src.name) )

        for i in io.progress_bar_generator( [*range(len(img_list))], "Renaming"):
            src = Path (img_list[i][0])
            src = input_path / ('%.5d_%s' % (i, src.name))
            dst = input_path / ('%.5d%s' % (i, src.suffix))
            try:
                src.rename (dst)
            except:
                io.log_info ('failed to rename %s' % (src.name) )
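The renaming above deliberately runs in two passes: files first get a unique temporary name ('%.5d_%s' keeps the old name as a suffix), then their final '%.5d.ext' slot. A one-pass rename could overwrite a not-yet-processed file that already occupies a target name such as 00001.jpg. A minimal sketch of the same pattern:

from pathlib import Path

def safe_reindex(paths):
    # pass 1: unique temporary names, so no target collides with a source
    tmp_paths = []
    for i, p in enumerate(paths):
        t = p.with_name('%.5d_%s' % (i, p.name))
        p.rename(t)
        tmp_paths.append(t)
    # pass 2: final sequential names
    for i, t in enumerate(tmp_paths):
        t.rename(t.with_name('%.5d%s' % (i, t.suffix)))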
def main (input_path, sort_by_method):
    input_path = Path(input_path)
    sort_by_method = sort_by_method.lower()

    io.log_info ("Running sort tool.\r\n")

    img_list = []
    trash_img_list = []
    if sort_by_method == 'blur':            img_list, trash_img_list = sort_by_blur (input_path)
@@ -787,10 +787,10 @@ def main (input_path, sort_by_method):
    elif sort_by_method == 'hist-dissim':   img_list, trash_img_list = sort_by_hist_dissim (input_path)
    elif sort_by_method == 'brightness':    img_list = sort_by_brightness (input_path)
    elif sort_by_method == 'hue':           img_list = sort_by_hue (input_path)
    elif sort_by_method == 'black':         img_list = sort_by_black (input_path)
    elif sort_by_method == 'origname':      img_list, trash_img_list = sort_by_origname (input_path)
    elif sort_by_method == 'oneface':       img_list, trash_img_list = sort_by_oneface_in_image (input_path)
    elif sort_by_method == 'final':         img_list, trash_img_list = sort_final (input_path)
    elif sort_by_method == 'final-no-blur': img_list, trash_img_list = sort_final (input_path, include_by_blur=False)

    final_process (input_path, img_list, trash_img_list)

View file

@@ -7,39 +7,39 @@ import numpy as np
import itertools
from pathlib import Path
from utils import Path_utils
from utils import image_utils
import cv2
import models
from interact import interact as io

def trainerThread (s2c, c2s, args, device_args):
    while True:
        try:
            training_data_src_path = Path( args.get('training_data_src_dir', '') )
            training_data_dst_path = Path( args.get('training_data_dst_dir', '') )
            model_path = Path( args.get('model_path', '') )
            model_name = args.get('model_name', '')
            save_interval_min = 15
            debug = args.get('debug', '')

            if not training_data_src_path.exists():
                io.log_err('Training data src directory does not exist.')
                break

            if not training_data_dst_path.exists():
                io.log_err('Training data dst directory does not exist.')
                break

            if not model_path.exists():
                model_path.mkdir(exist_ok=True)

            model = models.import_model(model_name)(
                        model_path,
                        training_data_src_path=training_data_src_path,
                        training_data_dst_path=training_data_dst_path,
                        debug=debug,
                        device_args=device_args)

            is_reached_goal = model.is_reached_iter_goal()
            is_upd_save_time_after_train = False
            loss_string = ""
@@ -49,37 +49,37 @@ def trainerThread (s2c, c2s, args, device_args):
                model.save()
                io.log_info(loss_string)
                is_upd_save_time_after_train = True

            def send_preview():
                if not debug:
                    previews = model.get_previews()
                    c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } )
                else:
                    previews = [( 'debug, press update for new', model.debug_one_iter())]
                    c2s.put ( {'op':'show', 'previews': previews} )

            if model.is_first_run():
                model_save()

            if model.get_target_iter() != 0:
                if is_reached_goal:
                    io.log_info('Model already trained to target iteration. You can use preview.')
                else:
                    io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % ( model.get_target_iter() ) )
            else:
                io.log_info('Starting. Press "Enter" to stop training and save model.')

            last_save_time = time.time()

            for i in itertools.count(0,1):
                if not debug:
                    if not is_reached_goal:
                        loss_string = model.train_one_iter()
                        if is_upd_save_time_after_train:
                            #save resets plaidML programs, so update last_save_time only after plaidML has rebuilt them
                            last_save_time = time.time()

                        io.log_info (loss_string, end='\r')
                        if model.get_target_iter() != 0 and model.is_reached_iter_goal():
                            io.log_info ('Reached target iteration.')
@@ -91,77 +91,77 @@ def trainerThread (s2c, c2s, args, device_args):
                        last_save_time = time.time()
                        model_save()
                        send_preview()

                if i==0:
                    if is_reached_goal:
                        model.pass_one_iter()
                    send_preview()

                if debug:
                    time.sleep(0.005)

                while not s2c.empty():
                    input = s2c.get()
                    op = input['op']
                    if op == 'save':
                        model_save()
                    elif op == 'preview':
                        if is_reached_goal:
                            model.pass_one_iter()
                        send_preview()
                    elif op == 'close':
                        model_save()
                        i = -1
                        break

                if i == -1:
                    break

            model.finalize()

        except Exception as e:
            print ('Error: %s' % (str(e)))
            traceback.print_exc()
        break

    c2s.put ( {'op':'close'} )
def main(args, device_args):
    io.log_info ("Running trainer.\r\n")

    no_preview = args.get('no_preview', False)

    s2c = queue.Queue()
    c2s = queue.Queue()

    thread = threading.Thread(target=trainerThread, args=(s2c, c2s, args, device_args) )
    thread.start()

    if no_preview:
        while True:
            if not c2s.empty():
                input = c2s.get()
                op = input.get('op','')
                if op == 'close':
                    break
            io.process_messages(0.1)
    else:
        wnd_name = "Training preview"
        io.named_window(wnd_name)
        io.capture_keys(wnd_name)

        previews = None
        loss_history = None
        selected_preview = 0
        update_preview = False
        is_showing = False
        is_waiting_preview = False
        show_last_history_iters_count = 0
        iter = 0
        while True:
            if not c2s.empty():
                input = c2s.get()
                op = input['op']
@@ -177,7 +177,7 @@ def main(args, device_args):
                        (h, w, c) = preview_rgb.shape
                        max_h = max (max_h, h)
                        max_w = max (max_w, w)

                    max_size = 800
                    if max_h > max_size:
                        max_w = int( max_w / (max_h / max_size) )
@@ -194,49 +194,49 @@ def main(args, device_args):
                    update_preview = True
                elif op == 'close':
                    break

            if update_preview:
                update_preview = False

                selected_preview_name = previews[selected_preview][0]
                selected_preview_rgb = previews[selected_preview][1]
                (h,w,c) = selected_preview_rgb.shape

                # HEAD
                head_lines = [
                    '[s]:save [enter]:exit',
                    '[p]:update [space]:next preview [l]:change history range',
                    'Preview: "%s" [%d/%d]' % (selected_preview_name, selected_preview+1, len(previews) )
                    ]
                head_line_height = 15
                head_height = len(head_lines) * head_line_height
                head = np.ones ( (head_height,w,c) ) * 0.1

                for i in range(0, len(head_lines)):
                    t = i*head_line_height
                    b = (i+1)*head_line_height
                    head[t:b, 0:w] += image_utils.get_text_image ( (w,head_line_height,c), head_lines[i], color=[0.8]*c )

                final = head

                if loss_history is not None:
                    if show_last_history_iters_count == 0:
                        loss_history_to_show = loss_history
                    else:
                        loss_history_to_show = loss_history[-show_last_history_iters_count:]

                    lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iter, w, c)
                    final = np.concatenate ( [final, lh_img], axis=0 )

                final = np.concatenate ( [final, selected_preview_rgb], axis=0 )
                final = np.clip(final, 0, 1)

                io.show_image( wnd_name, (final*255).astype(np.uint8) )
                is_showing = True

            key_events = io.get_key_events(wnd_name)
            key, = key_events[-1] if len(key_events) > 0 else (0,)

            if key == ord('\n') or key == ord('\r'):
                s2c.put ( {'op': 'close'} )
            elif key == ord('s'):
@@ -253,14 +253,14 @@ def main(args, device_args):
                elif show_last_history_iters_count == 10000:
                    show_last_history_iters_count = 50000
                elif show_last_history_iters_count == 50000:
                    show_last_history_iters_count = 100000
                elif show_last_history_iters_count == 100000:
                    show_last_history_iters_count = 0
                update_preview = True
            elif key == ord(' '):
                selected_preview = (selected_preview + 1) % len(previews)
                update_preview = True

            io.process_messages(0.1)

        io.destroy_all_windows()
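The trainer runs on a worker thread and talks to this UI loop over two plain queues: s2c carries commands in ('save', 'preview', 'close'), c2s carries results back ('show' with previews, 'close' on shutdown). A stripped-down sketch of the same protocol:

import queue
import threading
import time

def worker(s2c, c2s):
    while True:
        if not s2c.empty() and s2c.get().get('op') == 'close':
            c2s.put({'op': 'close'})   # acknowledge shutdown to the UI side
            return
        c2s.put({'op': 'show', 'iter': int(time.time())})  # stand-in for a preview
        time.sleep(0.1)

s2c, c2s = queue.Queue(), queue.Queue()
threading.Thread(target=worker, args=(s2c, c2s), daemon=True).start()
s2c.put({'op': 'close'})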

View file

@@ -9,30 +9,30 @@ from interact import interact as io
def convert_png_to_jpg_file (filepath):
    filepath = Path(filepath)

    if filepath.suffix != '.png':
        return

    dflpng = DFLPNG.load (str(filepath) )
    if dflpng is None:
        io.log_err ("%s is not a dfl image file" % (filepath.name) )
        return

    dfl_dict = dflpng.getDFLDictData()

    img = cv2_imread (str(filepath))
    new_filepath = str(filepath.parent / (filepath.stem + '.jpg'))
    cv2_imwrite ( new_filepath, img, [int(cv2.IMWRITE_JPEG_QUALITY), 85])

    DFLJPG.embed_data( new_filepath,
                       face_type=dfl_dict.get('face_type', None),
                       landmarks=dfl_dict.get('landmarks', None),
                       source_filename=dfl_dict.get('source_filename', None),
                       source_rect=dfl_dict.get('source_rect', None),
                       source_landmarks=dfl_dict.get('source_landmarks', None) )

    filepath.unlink()
def convert_png_to_jpg_folder (input_path):
    input_path = Path(input_path)
@@ -41,73 +41,73 @@ def convert_png_to_jpg_folder (input_path):
    for filepath in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Converting"):
        filepath = Path(filepath)
        convert_png_to_jpg_file(filepath)

def add_landmarks_debug_images(input_path):
    io.log_info ("Adding landmarks debug images...")

    for filepath in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Processing"):
        filepath = Path(filepath)

        img = cv2_imread(str(filepath))

        if filepath.suffix == '.png':
            dflimg = DFLPNG.load( str(filepath) )
        elif filepath.suffix == '.jpg':
            dflimg = DFLJPG.load ( str(filepath) )
        else:
            dflimg = None

        if dflimg is None:
            io.log_err ("%s is not a dfl image file" % (filepath.name) )
            continue

        if img is not None:
            face_landmarks = dflimg.get_landmarks()
            LandmarksProcessor.draw_landmarks(img, face_landmarks, transparent_mask=True)

            output_file = '{}{}'.format( str(Path(str(input_path)) / filepath.stem), '_debug.jpg')
            cv2_imwrite(output_file, img, [int(cv2.IMWRITE_JPEG_QUALITY), 50] )
def recover_original_aligned_filename(input_path):
    io.log_info ("Recovering original aligned filename...")

    files = []
    for filepath in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Processing"):
        filepath = Path(filepath)

        if filepath.suffix == '.png':
            dflimg = DFLPNG.load( str(filepath) )
        elif filepath.suffix == '.jpg':
            dflimg = DFLJPG.load ( str(filepath) )
        else:
            dflimg = None

        if dflimg is None:
            io.log_err ("%s is not a dfl image file" % (filepath.name) )
            continue

        files += [ [filepath, None, dflimg.get_source_filename(), False] ]

    files_len = len(files)
    for i in io.progress_bar_generator( range(files_len), "Sorting" ):
        fp, _, sf, converted = files[i]

        if converted:
            continue

        sf_stem = Path(sf).stem

        files[i][1] = fp.parent / ( sf_stem + '_0' + fp.suffix )
        files[i][3] = True
        c = 1

        for j in range(i+1, files_len):
            fp_j, _, sf_j, converted_j = files[j]
            if converted_j:
                continue

            if sf_j == sf:
                files[j][1] = fp_j.parent / ( sf_stem + ('_%d' % (c)) + fp_j.suffix )
                files[j][3] = True
                c += 1
@@ -118,11 +118,11 @@ def recover_original_aligned_filename(input_path):
            fs.rename (dst)
        except:
            io.log_err ('failed to rename %s' % (fs.name) )

    for file in io.progress_bar_generator( files, "Renaming" ):
        fs, fd, _, _ = file
        fs = fs.parent / ( fs.stem + '_tmp' + fs.suffix )
        try:
            fs.rename (fd)
        except:
            io.log_err ('failed to rename %s' % (fs.name) )

View file

@@ -8,38 +8,38 @@ from interact import interact as io
def extract_video(input_file, output_dir, output_ext=None, fps=None):
    input_file_path = Path(input_file)
    output_path = Path(output_dir)

    if not output_path.exists():
        output_path.mkdir(exist_ok=True)

    if input_file_path.suffix == '.*':
        input_file_path = Path_utils.get_first_file_by_stem (input_file_path.parent, input_file_path.stem)
    else:
        if not input_file_path.exists():
            input_file_path = None

    if input_file_path is None:
        io.log_err("input_file not found.")
        return

    if output_ext is None:
        output_ext = io.input_str ("Output image format (extension)? ( default:png ) : ", "png")

    if fps is None:
        fps = io.input_int ("Enter FPS ( ?:help skip:fullfps ) : ", 0, help_message="How many frames of every second of the video will be extracted.")

    for filename in Path_utils.get_image_paths (output_path, ['.'+output_ext]):
        Path(filename).unlink()

    job = ffmpeg.input(str(input_file_path))

    kwargs = {}
    if fps != 0:
        kwargs.update ({'r':str(fps)})

    job = job.output( str (output_path / ('%5d.'+output_ext)), **kwargs )

    try:
        job = job.run()
    except:
@@ -50,18 +50,18 @@ def cut_video ( input_file, from_time=None, to_time=None, audio_track_id=None, b
    if input_file_path is None:
        io.log_err("input_file not found.")
        return

    output_file_path = input_file_path.parent / (input_file_path.stem + "_cut" + input_file_path.suffix)

    if from_time is None:
        from_time = io.input_str ("From time (skip: 00:00:00.000) : ", "00:00:00.000")

    if to_time is None:
        to_time = io.input_str ("To time (skip: 00:00:00.000) : ", "00:00:00.000")

    if audio_track_id is None:
        audio_track_id = io.input_int ("Specify audio track id. ( skip:0 ) : ", 0)

    if bitrate is None:
        bitrate = max (1, io.input_int ("Bitrate of output file in MB/s ? (default:25) : ", 25) )
@@ -69,64 +69,64 @@ def cut_video ( input_file, from_time=None, to_time=None, audio_track_id=None, b
              "b:v": "%dM" %(bitrate),
              "pix_fmt": "yuv420p",
             }

    job = ffmpeg.input(str(input_file_path), ss=from_time, to=to_time)

    job_v = job['v:0']
    job_a = job['a:' + str(audio_track_id) + '?' ]

    job = ffmpeg.output(job_v, job_a, str(output_file_path), **kwargs).overwrite_output()

    try:
        job = job.run()
    except:
        io.log_err ("ffmpeg fail, job commandline:" + str(job.compile()) )
def denoise_image_sequence( input_dir, ext=None, factor=None ):
    input_path = Path(input_dir)

    if not input_path.exists():
        io.log_err("input_dir not found.")
        return

    if ext is None:
        ext = io.input_str ("Input image format (extension)? ( default:png ) : ", "png")

    if factor is None:
        factor = np.clip ( io.input_int ("Denoise factor? (1-20 default:5) : ", 5), 1, 20 )

    job = ( ffmpeg
            .input(str ( input_path / ('%5d.'+ext) ) )
            .filter("hqdn3d", factor, factor, 5, 5)
            .output(str ( input_path / ('%5d.'+ext) ) )
          )

    try:
        job = job.run()
    except:
        io.log_err ("ffmpeg fail, job commandline:" + str(job.compile()) )
def video_from_sequence( input_dir, output_file, reference_file=None, ext=None, fps=None, bitrate=None, lossless=None ): def video_from_sequence( input_dir, output_file, reference_file=None, ext=None, fps=None, bitrate=None, lossless=None ):
input_path = Path(input_dir) input_path = Path(input_dir)
output_file_path = Path(output_file) output_file_path = Path(output_file)
reference_file_path = Path(reference_file) if reference_file is not None else None reference_file_path = Path(reference_file) if reference_file is not None else None
if not input_path.exists(): if not input_path.exists():
io.log_err("input_dir not found.") io.log_err("input_dir not found.")
return return
if not output_file_path.parent.exists(): if not output_file_path.parent.exists():
output_file_path.parent.mkdir(parents=True, exist_ok=True) output_file_path.parent.mkdir(parents=True, exist_ok=True)
return return
    out_ext = output_file_path.suffix

    if ext is None:
        ext = io.input_str ("Input image format (extension)? ( default:png ) : ", "png")

    if lossless is None:
        lossless = io.input_bool ("Use lossless codec ? ( default:no ) : ", False)

    video_id = None
    audio_id = None
    ref_in_a = None
@@ -136,7 +136,7 @@ def video_from_sequence( input_dir, output_file, reference_file=None, ext=None,
        else:
            if not reference_file_path.exists():
                reference_file_path = None

        if reference_file_path is None:
            io.log_err("reference_file not found.")
            return
@@ -149,32 +149,32 @@ def video_from_sequence( input_dir, output_file, reference_file=None, ext=None,
            if video_id is None and stream['codec_type'] == 'video':
                video_id = stream['index']
                fps = stream['r_frame_rate']

            if audio_id is None and stream['codec_type'] == 'audio':
                audio_id = stream['index']

        if audio_id is not None:
            #has audio track
            ref_in_a = ffmpeg.input (str(reference_file_path))[str(audio_id)]

    if fps is None:
        #if fps not specified and not overwritten by reference-file
        fps = max (1, io.input_int ("FPS ? (default:25) : ", 25) )

    if not lossless and bitrate is None:
        bitrate = max (1, io.input_int ("Bitrate of output file in MB/s ? (default:16) : ", 16) )

    i_in = ffmpeg.input(str (input_path / ('%5d.'+ext)), r=fps)

    output_args = [i_in]

    if ref_in_a is not None:
        output_args += [ref_in_a]

    output_args += [str (output_file_path)]

    output_kwargs = {}

    if lossless:
        output_kwargs.update ({"c:v": "png"
                              })
@@ -183,15 +183,14 @@ def video_from_sequence( input_dir, output_file, reference_file=None, ext=None,
                               "b:v": "%dM" %(bitrate),
                               "pix_fmt": "yuv420p",
                              })

    output_kwargs.update ({"c:a": "aac",
                           "b:a": "192k",
                           "ar" : "48000"
                          })

    job = ( ffmpeg.output(*output_args, **output_kwargs).overwrite_output() )

    try:
        job = job.run()
    except:
        io.log_err ("ffmpeg fail, job commandline:" + str(job.compile()) )
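A hedged usage sketch of video_from_sequence as defined above; all paths are illustrative. With a reference_file, fps is read from the reference's video stream and its audio track is muxed in; otherwise the function prompts for fps and bitrate:

video_from_sequence(input_dir="workspace/merged",
                    output_file="workspace/result.mp4",
                    reference_file="workspace/data_dst.mp4",  # hypothetical fps + audio donor
                    ext="png",
                    bitrate=16,       # MB/s, used only when lossless is False
                    lossless=False)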
View file
@@ -7,10 +7,10 @@ def get_power_of_two(x):
    while (1 << i) < x:
        i += 1

    return i

def rotationMatrixToEulerAngles(R) :
    sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0])

    singular = sy < 1e-6

    if not singular :
        x = math.atan2(R[2,1] , R[2,2])
        y = math.atan2(-R[2,0], sy)
@@ -18,8 +18,8 @@ def rotationMatrixToEulerAngles(R) :
    else :
        x = math.atan2(-R[1,2], R[1,1])
        y = math.atan2(-R[2,0], sy)
        z = 0

    return np.array([x, y, z])

def polygon_area(x,y):
    return 0.5*np.abs(np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1)))
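polygon_area is the shoelace formula; a quick self-contained sanity check on a unit square (hand-worked numbers, not from the source):

import numpy as np

x = np.array([0.0, 1.0, 1.0, 0.0])  # unit square, vertices in order
y = np.array([0.0, 0.0, 1.0, 1.0])
# dot(x, roll(y,1)) = 0 and dot(y, roll(x,1)) = 2, so area = 0.5*|0-2| = 1.0
area = 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
assert area == 1.0

As a spot check of rotationMatrixToEulerAngles, R = np.eye(3) is non-singular (sy = 1) and every atan2 pair evaluates to 0, so it returns np.array([0., 0., 0.]).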
View file
@@ -68,4 +68,4 @@ def umeyama(src, dst, estimate_scale):
    T[:dim, dim] = dst_mean - scale * np.dot(T[:dim, :dim], src_mean.T)
    T[:dim, :dim] *= scale

    return T
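Only the tail of umeyama is shown; it is the standard Umeyama (1991) similarity estimator, as in scikit-image. A hedged sketch of its contract, with hypothetical point sets:

import numpy as np

src = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
dst = src + np.array([2.0, 3.0])      # pure translation, so scale ~ 1
T = umeyama(src, dst, True)           # 3x3 homogeneous similarity transform
# Expected (up to float error):
# [[1. 0. 2.]
#  [0. 1. 3.]
#  [0. 0. 1.]]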
View file
@@ -23,11 +23,11 @@ class ModelBase(object):
    def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, debug = False, device_args = None,
                        ask_write_preview_history=True, ask_target_iter=True, ask_batch_size=True, ask_sort_by_yaw=True,
                        ask_random_flip=True, ask_src_scale_mod=True):

        device_args['force_gpu_idx'] = device_args.get('force_gpu_idx',-1)
        device_args['cpu_only'] = device_args.get('cpu_only',False)

        if device_args['force_gpu_idx'] == -1 and not device_args['cpu_only']:
            idxs_names_list = nnlib.device.getValidDevicesIdxsWithNamesList()
            if len(idxs_names_list) > 1:
                io.log_info ("You have multi GPUs in a system: ")
@@ -36,17 +36,17 @@ class ModelBase(object):
                device_args['force_gpu_idx'] = io.input_int("Which GPU idx to choose? ( skip: best GPU ) : ", -1, [ x[0] for x in idxs_names_list] )
        self.device_args = device_args

        self.device_config = nnlib.DeviceConfig(allow_growth=False, **self.device_args)

        io.log_info ("Loading model...")

        self.model_path = model_path
        self.model_data_path = Path( self.get_strpath_storage_for_file('data.dat') )

        self.training_data_src_path = training_data_src_path
        self.training_data_dst_path = training_data_dst_path

        self.src_images_paths = None
        self.dst_images_paths = None
        self.src_yaw_images_paths = None
@@ -60,10 +60,10 @@ class ModelBase(object):
        self.options = {}
        self.loss_history = []
        self.sample_for_preview = None

        model_data = {}
        if self.model_data_path.exists():
            model_data = pickle.loads ( self.model_data_path.read_bytes() )
            self.iter = max( model_data.get('iter',0), model_data.get('epoch',0) )
            if 'epoch' in self.options:
                self.options.pop('epoch')
@@ -73,101 +73,101 @@ class ModelBase(object):
            self.sample_for_preview = model_data['sample_for_preview'] if 'sample_for_preview' in model_data.keys() else None

        ask_override = self.is_training_mode and self.iter != 0 and io.input_in_time ("Press enter in 2 seconds to override model settings.", 2)

        yn_str = {True:'y',False:'n'}

        if self.iter == 0:
            io.log_info ("\nModel first run. Enter model options as default for each run.")

        if ask_write_preview_history and (self.iter == 0 or ask_override):
            default_write_preview_history = False if self.iter == 0 else self.options.get('write_preview_history',False)
            self.options['write_preview_history'] = io.input_bool("Write preview history? (y/n ?:help skip:%s) : " % (yn_str[default_write_preview_history]) , default_write_preview_history, help_message="Preview history will be written to the <ModelName>_history folder.")
        else:
            self.options['write_preview_history'] = self.options.get('write_preview_history', False)

        if ask_target_iter and (self.iter == 0 or ask_override):
            self.options['target_iter'] = max(0, io.input_int("Target iteration (skip:unlimited/default) : ", 0))
        else:
            self.options['target_iter'] = max(model_data.get('target_iter',0), self.options.get('target_epoch',0))
            if 'target_epoch' in self.options:
                self.options.pop('target_epoch')

        if ask_batch_size and (self.iter == 0 or ask_override):
            default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size',0)
            self.options['batch_size'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % (default_batch_size), default_batch_size, help_message="Larger batch size is always better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
        else:
            self.options['batch_size'] = self.options.get('batch_size', 0)

        if ask_sort_by_yaw and (self.iter == 0):
            self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:n) : ", False, help_message="NN will not learn src face directions that don't match dst face directions." )
        else:
            self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False)

        if ask_random_flip and (self.iter == 0):
            self.options['random_flip'] = io.input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="The predicted face will look more natural without this option, but the src faceset should then cover all the face directions the dst faceset has.")
        else:
            self.options['random_flip'] = self.options.get('random_flip', True)

        if ask_src_scale_mod and (self.iter == 0):
            self.options['src_scale_mod'] = np.clip( io.input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
        else:
            self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)

        self.write_preview_history = self.options['write_preview_history']
        if not self.options['write_preview_history']:
            self.options.pop('write_preview_history')

        self.target_iter = self.options['target_iter']
        if self.options['target_iter'] == 0:
            self.options.pop('target_iter')

        self.batch_size = self.options['batch_size']
        self.sort_by_yaw = self.options['sort_by_yaw']
        self.random_flip = self.options['random_flip']

        self.src_scale_mod = self.options['src_scale_mod']
        if self.src_scale_mod == 0:
            self.options.pop('src_scale_mod')

        self.onInitializeOptions(self.iter == 0, ask_override)

        nnlib.import_all(self.device_config)
        self.keras = nnlib.keras
        self.K = nnlib.keras.backend

        self.onInitialize()

        self.options['batch_size'] = self.batch_size

        if self.debug or self.batch_size == 0:
            self.batch_size = 1

        if self.is_training_mode:
            if self.write_preview_history:
                if self.device_args['force_gpu_idx'] == -1:
                    self.preview_history_path = self.model_path / ( '%s_history' % (self.get_model_name()) )
                else:
                    self.preview_history_path = self.model_path / ( '%d_%s_history' % (self.device_args['force_gpu_idx'], self.get_model_name()) )

                if not self.preview_history_path.exists():
                    self.preview_history_path.mkdir(exist_ok=True)
                else:
                    if self.iter == 0:
                        for filename in Path_utils.get_image_paths(self.preview_history_path):
                            Path(filename).unlink()

            if self.generator_list is None:
                raise ValueError("You didn't set_training_data_generators()")
            else:
                for i, generator in enumerate(self.generator_list):
                    if not isinstance(generator, SampleGeneratorBase):
                        raise ValueError('training data generator is not a subclass of SampleGeneratorBase')

            if (self.sample_for_preview is None) or (self.iter == 0):
                self.sample_for_preview = self.generate_next_sample()

        model_summary_text = []

        model_summary_text += ["===== Model summary ====="]
        model_summary_text += ["== Model name: " + self.get_model_name()]
        model_summary_text += ["=="]
@@ -179,41 +179,41 @@ class ModelBase(object):
        if self.device_config.multi_gpu:
            model_summary_text += ["== |== multi_gpu : True "]

        model_summary_text += ["== Running on:"]
        if self.device_config.cpu_only:
            model_summary_text += ["== |== [CPU]"]
        else:
            for idx in self.device_config.gpu_idxs:
                model_summary_text += ["== |== [%d : %s]" % (idx, nnlib.device.getDeviceName(idx))]

        if not self.device_config.cpu_only and self.device_config.gpu_vram_gb[0] == 2:
            model_summary_text += ["=="]
            model_summary_text += ["== WARNING: You are using a 2GB GPU. Result quality may be significantly decreased."]
            model_summary_text += ["== If training does not start, close all programs and try again."]
            model_summary_text += ["== Also you can disable Windows Aero Desktop to get extra free VRAM."]
            model_summary_text += ["=="]

        model_summary_text += ["========================="]
        model_summary_text = "\r\n".join (model_summary_text)
        self.model_summary_text = model_summary_text
        io.log_info(model_summary_text)

    #overridable
    def onInitializeOptions(self, is_first_run, ask_override):
        pass

    #overridable
    def onInitialize(self):
        '''
        initialize your keras models

        store and retrieve your model options in self.options['']

        check example
        '''
        pass

    #overridable
    def onSave(self):
        #save your keras models here
@@ -229,59 +229,59 @@ class ModelBase(object):
    #overridable
    def onGetPreview(self, sample):
        #you can return multiple previews
        #return [ ('preview_name',preview_rgb), ... ]
        return []

    #overridable if you want the model name to differ from the folder name
    def get_model_name(self):
        return Path(inspect.getmodule(self).__file__).parent.name.rsplit("_", 1)[1]

    #overridable
    def get_converter(self):
        raise NotImplementedError
        #return an existing converter, or your own derived from the base converter class
    def get_target_iter(self):
        return self.target_iter

    def is_reached_iter_goal(self):
        return self.target_iter != 0 and self.iter >= self.target_iter

    #multi gpu in keras actually is fake and doesn't work for training https://github.com/keras-team/keras/issues/11976
    #def to_multi_gpu_model_if_possible (self, models_list):
    #    if len(self.device_config.gpu_idxs) > 1:
    #        #make batch_size to divide on GPU count without remainder
    #        self.batch_size = int( self.batch_size / len(self.device_config.gpu_idxs) )
    #        if self.batch_size == 0:
    #            self.batch_size = 1
    #        self.batch_size *= len(self.device_config.gpu_idxs)
    #
    #        result = []
    #        for model in models_list:
    #            for i in range( len(model.output_names) ):
    #                model.output_names = 'output_%d' % (i)
    #            result += [ nnlib.keras.utils.multi_gpu_model( model, self.device_config.gpu_idxs ) ]
    #
    #        return result
    #    else:
    #        return models_list

    def get_previews(self):
        return self.onGetPreview ( self.last_sample )

    def get_static_preview(self):
        return self.onGetPreview (self.sample_for_preview)[0][1] #first preview, and bgr

    def save(self):
        Path( self.get_strpath_storage_for_file('summary.txt') ).write_text(self.model_summary_text)

        self.onSave()

        model_data = {
            'iter': self.iter,
            'options': self.options,
            'loss_history': self.loss_history,
            'sample_for_preview' : self.sample_for_preview
        }
        self.model_data_path.write_bytes( pickle.dumps(model_data) )

    def load_weights_safe(self, model_filename_list, optimizer_filename_list=[]):
@@ -289,17 +289,17 @@ class ModelBase(object):
            filename = self.get_strpath_storage_for_file(filename)
            if Path(filename).exists():
                model.load_weights(filename)

        if len(optimizer_filename_list) != 0:
            opt_filename = self.get_strpath_storage_for_file('opt.h5')
            if Path(opt_filename).exists():
                try:
                    with open(opt_filename, "rb") as f:
                        d = pickle.loads(f.read())

                    for x in optimizer_filename_list:
                        opt, filename = x
                        if filename in d:
                            weights = d[filename].get('weights', None)
                            if weights:
                                opt.set_weights(weights)
@@ -307,16 +307,16 @@ class ModelBase(object):
                except Exception as e:
                    print ("Unable to load ", opt_filename)

    def save_weights_safe(self, model_filename_list, optimizer_filename_list=[]):
        for model, filename in model_filename_list:
            filename = self.get_strpath_storage_for_file(filename)
            model.save_weights( filename + '.tmp' )

        rename_list = model_filename_list
        if len(optimizer_filename_list) != 0:
            opt_filename = self.get_strpath_storage_for_file('opt.h5')
            try:
                d = {}
                for opt, filename in optimizer_filename_list:
@@ -324,54 +324,54 @@ class ModelBase(object):
                    symbolic_weights = getattr(opt, 'weights')
                    if symbolic_weights:
                        fd['weights'] = self.K.batch_get_value(symbolic_weights)
                    d[filename] = fd

                with open(opt_filename+'.tmp', 'wb') as f:
                    f.write( pickle.dumps(d) )

                rename_list += [('', 'opt.h5')]
            except Exception as e:
                print ("Unable to save ", opt_filename)

        for _, filename in rename_list:
            filename = self.get_strpath_storage_for_file(filename)
            source_filename = Path(filename+'.tmp')
            if source_filename.exists():
                target_filename = Path(filename)
                if target_filename.exists():
                    target_filename.unlink()
                source_filename.rename ( str(target_filename) )
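save_weights_safe follows the classic write-then-rename pattern: every file is first written to a .tmp sibling, and the real names are swapped in only afterwards, so an interruption mid-save cannot leave a half-written model on disk. A minimal stand-alone sketch of the same idea (path and payload are hypothetical):

from pathlib import Path

def atomic_write_bytes(path, data):
    tmp = Path(str(path) + '.tmp')
    tmp.write_bytes(data)            # the slow, interruptible part
    target = Path(path)
    if target.exists():
        target.unlink()              # mirror the explicit unlink used above
    tmp.rename(target)               # near-instant swap into place

atomic_write_bytes('encoder.h5', b'...')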
    def debug_one_iter(self):
        images = []
        for generator in self.generator_list:
            for i,batch in enumerate(next(generator)):
                if len(batch.shape) == 4:
                    images.append( batch[0] )

        return image_utils.equalize_and_stack_square (images)

    def generate_next_sample(self):
        return [next(generator) for generator in self.generator_list]

    def train_one_iter(self):
        sample = self.generate_next_sample()
        iter_time = time.time()
        losses = self.onTrainOneIter(sample, self.generator_list)
        iter_time = time.time() - iter_time
        self.last_sample = sample

        self.loss_history.append ( [float(loss[1]) for loss in losses] )

        if self.write_preview_history:
            if self.iter % 10 == 0:
                preview = self.get_static_preview()
                preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
                img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
                cv2_imwrite ( str (self.preview_history_path / ('%.6d.jpg' %( self.iter) )), img )

        self.iter += 1

        time_str = time.strftime("[%H:%M:%S]")
@@ -383,40 +383,40 @@ class ModelBase(object):
            loss_string += " %s:%.3f" % (loss_name, loss_value)

        return loss_string

    def pass_one_iter(self):
        self.last_sample = self.generate_next_sample()

    def finalize(self):
        nnlib.finalize_all()

    def is_first_run(self):
        return self.iter == 0

    def is_debug(self):
        return self.debug

    def set_batch_size(self, batch_size):
        self.batch_size = batch_size

    def get_batch_size(self):
        return self.batch_size

    def get_iter(self):
        return self.iter

    def get_loss_history(self):
        return self.loss_history

    def set_training_data_generators (self, generator_list):
        self.generator_list = generator_list

    def get_training_data_generators (self):
        return self.generator_list

    def get_model_root_path(self):
        return self.model_path

    def get_strpath_storage_for_file(self, filename):
        if self.device_args['force_gpu_idx'] == -1:
            return str( self.model_path / ( self.get_model_name() + '_' + filename) )
@@ -424,65 +424,65 @@ class ModelBase(object):
            return str( self.model_path / ( str(self.device_args['force_gpu_idx']) + '_' + self.get_model_name() + '_' + filename) )

    def set_vram_batch_requirements (self, d):
        #example d = {2:2,3:4,4:8,5:16,6:32,7:32,8:32,9:48}
        keys = [x for x in d.keys()]

        if self.device_config.cpu_only:
            if self.batch_size == 0:
                self.batch_size = 2
        else:
            if self.batch_size == 0:
                for x in keys:
                    if self.device_config.gpu_vram_gb[0] <= x:
                        self.batch_size = d[x]
                        break

                if self.batch_size == 0:
                    self.batch_size = d[ keys[-1] ]

    @staticmethod
    def get_loss_history_preview(loss_history, iter, w, c):
        loss_history = np.array (loss_history.copy())

        lh_height = 100
        lh_img = np.ones ( (lh_height,w,c) ) * 0.1
        loss_count = len(loss_history[0])
        lh_len = len(loss_history)

        l_per_col = lh_len / w
        plist_max = [ [ max (0.0, loss_history[int(col*l_per_col)][p],
                             *[ loss_history[i_ab][p]
                                for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) )
                              ]
                            )
                        for p in range(loss_count)
                      ]
                      for col in range(w)
                    ]

        plist_min = [ [ min (plist_max[col][p], loss_history[int(col*l_per_col)][p],
                             *[ loss_history[i_ab][p]
                                for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) )
                              ]
                            )
                        for p in range(loss_count)
                      ]
                      for col in range(w)
                    ]

        plist_abs_max = np.mean(loss_history[ len(loss_history) // 5 : ]) * 2

        for col in range(0, w):
            for p in range(0,loss_count):
                point_color = [1.0]*c
                point_color[0:3] = colorsys.hsv_to_rgb ( p * (1.0/loss_count), 1.0, 1.0 )

                ph_max = int ( (plist_max[col][p] / plist_abs_max) * (lh_height-1) )
                ph_max = np.clip( ph_max, 0, lh_height-1 )

                ph_min = int ( (plist_min[col][p] / plist_abs_max) * (lh_height-1) )
                ph_min = np.clip( ph_min, 0, lh_height-1 )

                for ph in range(ph_min, ph_max+1):
                    lh_img[ (lh_height-ph-1), col ] = point_color
@@ -490,11 +490,11 @@ class ModelBase(object):
        lh_line_height = (lh_height-1)/lh_lines
        for i in range(0,lh_lines+1):
            lh_img[ int(i*lh_line_height), : ] = (0.8,)*c

        last_line_t = int((lh_lines-1)*lh_line_height)
        last_line_b = int(lh_lines*lh_line_height)

        lh_text = 'Iter: %d' % (iter) if iter != 0 else ''

        lh_img[last_line_t:last_line_b, 0:w] += image_utils.get_text_image ( (w,last_line_b-last_line_t,c), lh_text, color=[0.8]*c )
        return lh_img
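get_loss_history_preview bins the whole history into w pixel columns (l_per_col iterations per column), draws each loss term's min/max span in its own HSV hue, and scales the y-axis to twice the mean loss over the last 80% of history so early spikes do not flatten the plot. A hedged numeric sketch of that binning, on fake data:

import numpy as np

loss_history = np.random.rand(1000, 2) * 0.5   # fake: 1000 iters, 2 loss terms
w = 256
l_per_col = len(loss_history) / w              # ~3.9 iterations per column
lo, hi = int(10 * l_per_col), int(11 * l_per_col)
col_max = loss_history[lo:hi].max(axis=0)      # per-loss max inside column 10
plist_abs_max = np.mean(loss_history[len(loss_history) // 5:]) * 2
print(col_max / plist_abs_max)                 # normalized bar heights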
View file
@@ -9,22 +9,22 @@ from interact import interact as io
class Model(ModelBase):

    #override
    def onInitializeOptions(self, is_first_run, ask_override):
        if is_first_run or ask_override:
            def_pixel_loss = self.options.get('pixel_loss', False)
            self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", def_pixel_loss, help_message="The default DSSIM loss is good for learning the initial structure of faces. Switch to pixel loss after ~20k iters to enhance fine details and decrease face jitter.")
        else:
            self.options['pixel_loss'] = self.options.get('pixel_loss', False)

    #override
    def onInitialize(self):
        exec(nnlib.import_all(), locals(), globals())
        self.set_vram_batch_requirements( {4.5:4} )

        ae_input_layer = Input(shape=(128, 128, 3))
        mask_layer = Input(shape=(128, 128, 1)) #same as output

        self.encoder, self.decoder_src, self.decoder_dst = self.Build(ae_input_layer)

        if not self.is_first_run():
            weights_to_load = [ [self.encoder , 'encoder.h5'],
@@ -38,39 +38,39 @@ class Model(ModelBase):
        self.autoencoder_src.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMSEMaskLoss(mask_layer, is_mse=self.options['pixel_loss']), 'mse'] )
        self.autoencoder_dst.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMSEMaskLoss(mask_layer, is_mse=self.options['pixel_loss']), 'mse'] )

        if self.is_training_mode:
            f = SampleProcessor.TypeFlags
            self.set_training_data_generators ([
                    SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
                                        debug=self.is_debug(), batch_size=self.batch_size,
                                        sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
                                        output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
                                                              [f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
                                                              [f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_M | f.FACE_MASK_FULL, 128] ] ),

                    SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                                        sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                                        output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
                                                              [f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
                                                              [f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_M | f.FACE_MASK_FULL, 128] ] )
                ])

    #override
    def onSave(self):
        self.save_weights_safe( [[self.encoder, 'encoder.h5'],
                                 [self.decoder_src, 'decoder_src.h5'],
                                 [self.decoder_dst, 'decoder_dst.h5']] )

    #override
    def onTrainOneIter(self, sample, generators_list):
        warped_src, target_src, target_src_mask = sample[0]
        warped_dst, target_dst, target_dst_mask = sample[1]

        loss_src = self.autoencoder_src.train_on_batch( [warped_src, target_src_mask], [target_src, target_src_mask] )
        loss_dst = self.autoencoder_dst.train_on_batch( [warped_dst, target_dst_mask], [target_dst, target_dst_mask] )

        return ( ('loss_src', loss_src[0]), ('loss_dst', loss_dst[0]) )

    #override
    def onGetPreview(self, sample):
@@ -78,64 +78,64 @@ class Model(ModelBase):
        test_A_m = sample[0][2][0:4] #first 4 samples
        test_B   = sample[1][1][0:4]
        test_B_m = sample[1][2][0:4]

        AA, mAA = self.autoencoder_src.predict([test_A, test_A_m])
        AB, mAB = self.autoencoder_src.predict([test_B, test_B_m])
        BB, mBB = self.autoencoder_dst.predict([test_B, test_B_m])

        mAA = np.repeat ( mAA, (3,), -1)
        mAB = np.repeat ( mAB, (3,), -1)
        mBB = np.repeat ( mBB, (3,), -1)

        st = []
        for i in range(0, len(test_A)):
            st.append ( np.concatenate ( (
                test_A[i,:,:,0:3],
                AA[i],
                #mAA[i],
                test_B[i,:,:,0:3],
                BB[i],
                #mBB[i],
                AB[i],
                #mAB[i]
                ), axis=1) )

        return [ ('DF', np.concatenate ( st, axis=0 ) ) ]

    def predictor_func (self, face):
        face_128_bgr = face[...,0:3]
        face_128_mask = np.expand_dims(face[...,3],-1)

        x, mx = self.autoencoder_src.predict ( [ np.expand_dims(face_128_bgr,0), np.expand_dims(face_128_mask,0) ] )
        x, mx = x[0], mx[0]

        return np.concatenate ( (x,mx), -1 )

    #override
    def get_converter(self):
        from converters import ConverterMasked
        return ConverterMasked(self.predictor_func,
                               predictor_input_size=128,
                               output_size=128,
                               face_type=FaceType.FULL,
                               base_erode_mask_modifier=30,
                               base_blur_mask_modifier=0)

    def Build(self, input_layer):
        exec(nnlib.code_import_all, locals(), globals())

        def downscale (dim):
            def func(x):
                return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
            return func

        def upscale (dim):
            def func(x):
                return PixelShuffler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
            return func

        def Encoder(input_layer):
            x = input_layer
            x = downscale(128)(x)
            x = downscale(256)(x)
@@ -146,7 +146,7 @@ class Model(ModelBase):
            x = Dense(8 * 8 * 512)(x)
            x = Reshape((8, 8, 512))(x)
            x = upscale(512)(x)

            return Model(input_layer, x)

        def Decoder():
@@ -155,15 +155,15 @@ class Model(ModelBase):
            x = upscale(512)(x)
            x = upscale(256)(x)
            x = upscale(128)(x)

            y = input_  #mask decoder
            y = upscale(512)(y)
            y = upscale(256)(y)
            y = upscale(128)(y)

            x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
            y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(y)

            return Model(input_, [x,y])

        return Encoder(input_layer), Decoder(), Decoder()
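upscale(dim) above is sub-pixel upsampling: a 3x3 conv produces dim*4 channels and PixelShuffler rearranges that depth into space, doubling height and width. A stand-alone numpy sketch of one common depth-to-space layout (the exact channel ordering PixelShuffler uses may differ):

import numpy as np

def depth_to_space(x, r=2):
    h, w, c = x.shape
    c_out = c // (r * r)
    x = x.reshape(h, w, r, r, c_out)
    x = x.transpose(0, 2, 1, 3, 4)         # interleave the r x r blocks spatially
    return x.reshape(h * r, w * r, c_out)

x = np.zeros((8, 8, 512 * 4), np.float32)  # conv output inside upscale(512)
print(depth_to_space(x).shape)             # (16, 16, 512)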
View file
@@ -1 +1 @@
from .Model import Model
View file
@@ -10,13 +10,13 @@ from interact import interact as io
class Model(ModelBase):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs,
                         ask_write_preview_history=False,
                         ask_target_iter=False,
                         ask_sort_by_yaw=False,
                         ask_random_flip=False,
                         ask_src_scale_mod=False)

    #override
    def onInitialize(self):
        exec(nnlib.import_all(), locals(), globals())
@@ -24,33 +24,33 @@ class Model(ModelBase):
        self.resolution = 256
        self.face_type = FaceType.FULL

        self.fan_seg = FANSegmentator(self.resolution,
                                      FaceType.toString(self.face_type),
                                      load_weights=not self.is_first_run(),
                                      weights_file_root=self.get_model_root_path() )

        if self.is_training_mode:
            f = SampleProcessor.TypeFlags
            f_type = f.FACE_ALIGN_FULL

            self.set_training_data_generators ([
                    SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
                                        sample_process_options=SampleProcessor.Options(random_flip=True, normalize_tanh = True ),
                                        output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR_SHUFFLE, self.resolution],
                                                              [f.TRANSFORMED | f_type | f.MODE_M | f.FACE_MASK_FULL, self.resolution]
                                                            ]),

                    SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                                        sample_process_options=SampleProcessor.Options(random_flip=True, normalize_tanh = True ),
                                        output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR_SHUFFLE, self.resolution]
                                                            ])
                ])

    #override
    def onSave(self):
        self.fan_seg.save_weights()

    #override
    def onTrainOneIter(self, generators_samples, generators_list):
        target_src, target_src_mask = generators_samples[0]
@@ -58,20 +58,20 @@ class Model(ModelBase):
        loss = self.fan_seg.train_on_batch( [target_src], [target_src_mask] )

        return ( ('loss', loss), )

    #override
    def onGetPreview(self, sample):
        test_A = sample[0][0][0:4] #first 4 samples
        test_B = sample[1][0][0:4] #first 4 samples

        mAA = self.fan_seg.extract_from_bgr([test_A])
        mBB = self.fan_seg.extract_from_bgr([test_B])

        test_A, test_B = [ np.clip( (x + 1.0)/2.0, 0.0, 1.0) for x in [test_A, test_B] ]

        mAA = np.repeat ( mAA, (3,), -1)
        mBB = np.repeat ( mBB, (3,), -1)

        st = []
        for i in range(0, len(test_A)):
            st.append ( np.concatenate ( (
@@ -79,7 +79,7 @@ class Model(ModelBase):
                mAA[i],
                test_A[i,:,:,0:3]*mAA[i],
                ), axis=1) )

        st2 = []
        for i in range(0, len(test_B)):
            st2.append ( np.concatenate ( (
@@ -87,7 +87,7 @@ class Model(ModelBase):
                mBB[i],
                test_B[i,:,:,0:3]*mBB[i],
                ), axis=1) )

        return [ ('FANSegmentator', np.concatenate ( st, axis=0 ) ),
                 ('never seen', np.concatenate ( st2, axis=0 ) ),
               ]
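The preview code above undoes the normalize_tanh sampling option: samples arrive in [-1, 1], and (x + 1) / 2 maps them back to displayable [0, 1] before masking. A one-line check with hand-worked values:

import numpy as np

x = np.array([-1.0, 0.0, 1.0])               # tanh-normalized pixel values
print(np.clip((x + 1.0) / 2.0, 0.0, 1.0))    # [0.  0.5 1. ]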
View file
@@ -1 +1 @@
from .Model import Model
View file
@@ -9,7 +9,7 @@ from interact import interact as io
class Model(ModelBase):

    #override
    def onInitializeOptions(self, is_first_run, ask_override):
        if is_first_run:
            self.options['lighter_ae'] = io.input_bool ("Use lightweight autoencoder? (y/n, ?:help skip:n) : ", False, help_message="The lightweight autoencoder is faster and requires less VRAM, sacrificing overall quality. If your GPU has <= 4GB of VRAM, you should choose this option.")
        else:
@@ -17,18 +17,18 @@ class Model(ModelBase):
            if 'created_vram_gb' in self.options.keys():
                self.options.pop ('created_vram_gb')
            self.options['lighter_ae'] = self.options.get('lighter_ae', default_lighter_ae)

        if is_first_run or ask_override:
            def_pixel_loss = self.options.get('pixel_loss', False)
            self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", def_pixel_loss, help_message="The default DSSIM loss is good for learning the initial structure of faces. Switch to pixel loss after ~20k iters to enhance fine details and decrease face jitter.")
else: else:
self.options['pixel_loss'] = self.options.get('pixel_loss', False) self.options['pixel_loss'] = self.options.get('pixel_loss', False)
#override #override
def onInitialize(self): def onInitialize(self):
exec(nnlib.import_all(), locals(), globals()) exec(nnlib.import_all(), locals(), globals())
self.set_vram_batch_requirements( {2.5:4} ) self.set_vram_batch_requirements( {2.5:4} )
bgr_shape, mask_shape, self.encoder, self.decoder_src, self.decoder_dst = self.Build( self.options['lighter_ae'] ) bgr_shape, mask_shape, self.encoder, self.decoder_src, self.decoder_dst = self.Build( self.options['lighter_ae'] )
if not self.is_first_run(): if not self.is_first_run():
weights_to_load = [ [self.encoder , 'encoder.h5'], weights_to_load = [ [self.encoder , 'encoder.h5'],
@ -36,120 +36,120 @@ class Model(ModelBase):
[self.decoder_dst, 'decoder_dst.h5'] [self.decoder_dst, 'decoder_dst.h5']
] ]
self.load_weights_safe(weights_to_load) self.load_weights_safe(weights_to_load)
input_src_bgr = Input(bgr_shape) input_src_bgr = Input(bgr_shape)
input_src_mask = Input(mask_shape) input_src_mask = Input(mask_shape)
input_dst_bgr = Input(bgr_shape) input_dst_bgr = Input(bgr_shape)
input_dst_mask = Input(mask_shape) input_dst_mask = Input(mask_shape)
rec_src_bgr, rec_src_mask = self.decoder_src( self.encoder(input_src_bgr) ) rec_src_bgr, rec_src_mask = self.decoder_src( self.encoder(input_src_bgr) )
rec_dst_bgr, rec_dst_mask = self.decoder_dst( self.encoder(input_dst_bgr) ) rec_dst_bgr, rec_dst_mask = self.decoder_dst( self.encoder(input_dst_bgr) )
self.ae = Model([input_src_bgr,input_src_mask,input_dst_bgr,input_dst_mask], [rec_src_bgr, rec_src_mask, rec_dst_bgr, rec_dst_mask] ) self.ae = Model([input_src_bgr,input_src_mask,input_dst_bgr,input_dst_mask], [rec_src_bgr, rec_src_mask, rec_dst_bgr, rec_dst_mask] )
self.ae.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), self.ae.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999),
loss=[ DSSIMMSEMaskLoss(input_src_mask, is_mse=self.options['pixel_loss']), 'mae', DSSIMMSEMaskLoss(input_dst_mask, is_mse=self.options['pixel_loss']), 'mae' ] ) loss=[ DSSIMMSEMaskLoss(input_src_mask, is_mse=self.options['pixel_loss']), 'mae', DSSIMMSEMaskLoss(input_dst_mask, is_mse=self.options['pixel_loss']), 'mae' ] )
self.src_view = K.function([input_src_bgr],[rec_src_bgr, rec_src_mask]) self.src_view = K.function([input_src_bgr],[rec_src_bgr, rec_src_mask])
self.dst_view = K.function([input_dst_bgr],[rec_dst_bgr, rec_dst_mask]) self.dst_view = K.function([input_dst_bgr],[rec_dst_bgr, rec_dst_mask])
if self.is_training_mode: if self.is_training_mode:
f = SampleProcessor.TypeFlags f = SampleProcessor.TypeFlags
self.set_training_data_generators ([ self.set_training_data_generators ([
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None, SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
debug=self.is_debug(), batch_size=self.batch_size, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ), sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 128], output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 128],
[f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 128], [f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 128],
[f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_M | f.FACE_MASK_FULL, 128] ] ), [f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_M | f.FACE_MASK_FULL, 128] ] ),
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 128], output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 128],
[f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 128], [f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 128],
[f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_M | f.FACE_MASK_FULL, 128] ] ) [f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_M | f.FACE_MASK_FULL, 128] ] )
]) ])
#override #override
def onSave(self): def onSave(self):
self.save_weights_safe( [[self.encoder, 'encoder.h5'], self.save_weights_safe( [[self.encoder, 'encoder.h5'],
[self.decoder_src, 'decoder_src.h5'], [self.decoder_src, 'decoder_src.h5'],
[self.decoder_dst, 'decoder_dst.h5']] ) [self.decoder_dst, 'decoder_dst.h5']] )
#override #override
def onTrainOneIter(self, sample, generators_list): def onTrainOneIter(self, sample, generators_list):
warped_src, target_src, target_src_mask = sample[0] warped_src, target_src, target_src_mask = sample[0]
warped_dst, target_dst, target_dst_mask = sample[1] warped_dst, target_dst, target_dst_mask = sample[1]
total, loss_src_bgr, loss_src_mask, loss_dst_bgr, loss_dst_mask = self.ae.train_on_batch( [warped_src, target_src_mask, warped_dst, target_dst_mask], [target_src, target_src_mask, target_dst, target_dst_mask] ) total, loss_src_bgr, loss_src_mask, loss_dst_bgr, loss_dst_mask = self.ae.train_on_batch( [warped_src, target_src_mask, warped_dst, target_dst_mask], [target_src, target_src_mask, target_dst, target_dst_mask] )
return ( ('loss_src', loss_src_bgr), ('loss_dst', loss_dst_bgr) ) return ( ('loss_src', loss_src_bgr), ('loss_dst', loss_dst_bgr) )
    #override
    def onGetPreview(self, sample):
        test_A   = sample[0][1][0:4] #first 4 samples
        test_A_m = sample[0][2][0:4] #first 4 samples
        test_B   = sample[1][1][0:4]
        test_B_m = sample[1][2][0:4]

        AA, mAA = self.src_view([test_A])
        AB, mAB = self.src_view([test_B])
        BB, mBB = self.dst_view([test_B])

        mAA = np.repeat ( mAA, (3,), -1)
        mAB = np.repeat ( mAB, (3,), -1)
        mBB = np.repeat ( mBB, (3,), -1)

        st = []
        for i in range(0, len(test_A)):
            st.append ( np.concatenate ( (
                test_A[i,:,:,0:3],
                AA[i],
                #mAA[i],
                test_B[i,:,:,0:3],
                BB[i],
                #mBB[i],
                AB[i],
                #mAB[i]
                ), axis=1) )

        return [ ('H128', np.concatenate ( st, axis=0 ) ) ]
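Note: the preview grid above is plain NumPy. Single-channel masks are tiled to three channels with np.repeat so they can sit next to BGR images, faces are laid side by side along axis=1, and rows are stacked along axis=0. A minimal standalone sketch of the same idea (shapes illustrative only):

import numpy as np

faces = np.random.rand(4, 128, 128, 3).astype(np.float32)   # stand-in for test_A
masks = np.random.rand(4, 128, 128, 1).astype(np.float32)   # stand-in for mAA

masks3 = np.repeat(masks, 3, axis=-1)                        # 1 channel -> 3 for display
rows = [np.concatenate((faces[i], masks3[i]), axis=1) for i in range(len(faces))]
grid = np.concatenate(rows, axis=0)                          # final grid: (512, 256, 3)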
    def predictor_func (self, face):
        face_128_bgr = face[...,0:3]
        face_128_mask = np.expand_dims(face[...,3],-1)

        x, mx = self.src_view ( [ np.expand_dims(face_128_bgr,0) ] )
        x, mx = x[0], mx[0]

        return np.concatenate ( (x,mx), -1 )
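Note: the predictor contract is an H x W x 4 float32 array in (BGR plus mask in the last channel) and the same layout out. A sketch of the unpacking/repacking with the view function stubbed out (fake_view is hypothetical):

import numpy as np

def fake_view(batch_bgr_list):
    batch = batch_bgr_list[0]            # mimic the list-in calling style of K.function
    return batch, batch[..., :1]         # pretend reconstruction and mask

face = np.random.rand(128, 128, 4).astype(np.float32)
bgr, mask = face[..., 0:3], np.expand_dims(face[..., 3], -1)
x, mx = fake_view([np.expand_dims(bgr, 0)])
packed = np.concatenate((x[0], mx[0]), -1)   # back to (128, 128, 4)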
    #override
    def get_converter(self):
        from converters import ConverterMasked
        return ConverterMasked(self.predictor_func,
                               predictor_input_size=128,
                               output_size=128,
                               face_type=FaceType.HALF,
                               base_erode_mask_modifier=100,
                               base_blur_mask_modifier=100)
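Note: base_erode_mask_modifier and base_blur_mask_modifier control how much the converter shrinks and feathers the predicted mask before blending. Roughly, in OpenCV terms (the actual kernel sizes inside ConverterMasked are an assumption here):

import cv2
import numpy as np

mask = np.zeros((128, 128), np.float32)
cv2.circle(mask, (64, 64), 40, 1.0, -1)                           # toy face mask

erode_px, blur_px = 5, 11                                         # illustrative values
mask = cv2.erode(mask, np.ones((erode_px, erode_px), np.uint8))   # pull the seam inward
mask = cv2.GaussianBlur(mask, (blur_px, blur_px), 0)              # feather the edge for blending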
    def Build(self, lighter_ae):
        exec(nnlib.code_import_all, locals(), globals())

        bgr_shape = (128, 128, 3)
        mask_shape = (128, 128, 1)

        def downscale (dim):
            def func(x):
                return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
            return func

        def upscale (dim):
            def func(x):
                return PixelShuffler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
            return func
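Note: upscale() is subpixel upscaling: a stride-1 convolution emits dim*4 channels, and PixelShuffler rearranges each group of 4 channels into a 2x2 spatial block (depth-to-space). A NumPy sketch of the rearrangement, assuming NHWC and one common channel ordering:

import numpy as np

def depth_to_space(x, r=2):
    n, h, w, c = x.shape
    x = x.reshape(n, h, w, r, r, c // (r * r))
    x = x.transpose(0, 1, 3, 2, 4, 5)        # interleave the r*r sub-pixels spatially
    return x.reshape(n, h * r, w * r, c // (r * r))

x = np.random.rand(1, 8, 8, 256 * 4).astype(np.float32)
print(depth_to_space(x).shape)               # (1, 16, 16, 256)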
        def Encoder(input_shape):
            input_layer = Input(input_shape)
            x = input_layer
@@ -171,7 +171,7 @@ class Model(ModelBase):
            x = Dense(8 * 8 * 256)(x)
            x = Reshape((8, 8, 256))(x)
            x = upscale(256)(x)
            return Model(input_layer, x)

        def Decoder():
@@ -181,7 +181,7 @@ class Model(ModelBase):
                x = upscale(512)(x)
                x = upscale(256)(x)
                x = upscale(128)(x)

                y = input_ #mask decoder
                y = upscale(512)(y)
                y = upscale(256)(y)
@@ -192,16 +192,16 @@ class Model(ModelBase):
                x = upscale(256)(x)
                x = upscale(128)(x)
                x = upscale(64)(x)

                y = input_ #mask decoder
                y = upscale(256)(y)
                y = upscale(128)(y)
                y = upscale(64)(y)

            x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
            y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(y)

            return Model(input_, [x,y])

        return bgr_shape, mask_shape, Encoder(bgr_shape), Decoder(), Decoder()

View file

@@ -1 +1 @@
from .Model import Model

View file

@@ -9,7 +9,7 @@ from interact import interact as io
class Model(ModelBase):

    #override
    def onInitializeOptions(self, is_first_run, ask_override):
        if is_first_run:
            self.options['lighter_ae'] = io.input_bool ("Use lightweight autoencoder? (y/n, ?:help skip:n) : ", False, help_message="The lightweight autoencoder is faster and requires less VRAM, sacrificing overall quality. If your GPU has <= 4GB VRAM, you should choose this option.")
        else:
@@ -17,141 +17,141 @@ class Model(ModelBase):
        if 'created_vram_gb' in self.options.keys():
            self.options.pop ('created_vram_gb')
        self.options['lighter_ae'] = self.options.get('lighter_ae', default_lighter_ae)

        if is_first_run or ask_override:
            def_pixel_loss = self.options.get('pixel_loss', False)
            self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", def_pixel_loss, help_message="The default DSSIM loss is good for learning the initial structure of faces. Use pixel loss after 20k iters to enhance fine details and decrease face jitter.")
        else:
            self.options['pixel_loss'] = self.options.get('pixel_loss', False)
    #override
    def onInitialize(self):
        exec(nnlib.import_all(), locals(), globals())
        self.set_vram_batch_requirements( {1.5:4} )

        bgr_shape, mask_shape, self.encoder, self.decoder_src, self.decoder_dst = self.Build(self.options['lighter_ae'])

        if not self.is_first_run():
            weights_to_load = [ [self.encoder    , 'encoder.h5'],
                                [self.decoder_src, 'decoder_src.h5'],
                                [self.decoder_dst, 'decoder_dst.h5']
                              ]
            self.load_weights_safe(weights_to_load)

        input_src_bgr = Input(bgr_shape)
        input_src_mask = Input(mask_shape)
        input_dst_bgr = Input(bgr_shape)
        input_dst_mask = Input(mask_shape)

        rec_src_bgr, rec_src_mask = self.decoder_src( self.encoder(input_src_bgr) )
        rec_dst_bgr, rec_dst_mask = self.decoder_dst( self.encoder(input_dst_bgr) )

        self.ae = Model([input_src_bgr,input_src_mask,input_dst_bgr,input_dst_mask], [rec_src_bgr, rec_src_mask, rec_dst_bgr, rec_dst_mask] )

        self.ae.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[ DSSIMMSEMaskLoss(input_src_mask, is_mse=self.options['pixel_loss']), 'mae', DSSIMMSEMaskLoss(input_dst_mask, is_mse=self.options['pixel_loss']), 'mae' ] )
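Note: DSSIMMSEMaskLoss lives in nnlib and switches between DSSIM and MSE via is_mse, with the error weighted by the mask tensor it closes over (the Input above). A rough Keras-backend sketch of the masked-MSE half, as an assumption about the idea rather than the nnlib code:

import keras.backend as K

def masked_mse_loss(mask):
    # penalize reconstruction error only where the mask is on
    def loss(y_true, y_pred):
        return K.mean(K.square(y_true * mask - y_pred * mask))
    return loss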
        self.src_view = K.function([input_src_bgr],[rec_src_bgr, rec_src_mask])
        self.dst_view = K.function([input_dst_bgr],[rec_dst_bgr, rec_dst_mask])

        if self.is_training_mode:
            f = SampleProcessor.TypeFlags
            self.set_training_data_generators ([
                    SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
                                        debug=self.is_debug(), batch_size=self.batch_size,
                            sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
                            output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 64],
                                                  [f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 64],
                                                  [f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_M | f.FACE_MASK_FULL, 64] ] ),
                    SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                            sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                            output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 64],
                                                  [f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 64],
                                                  [f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_M | f.FACE_MASK_FULL, 64] ] )
                ])
    #override
    def onSave(self):
        self.save_weights_safe( [[self.encoder, 'encoder.h5'],
                                 [self.decoder_src, 'decoder_src.h5'],
                                 [self.decoder_dst, 'decoder_dst.h5']] )

    #override
    def onTrainOneIter(self, sample, generators_list):
        warped_src, target_src, target_src_full_mask = sample[0]
        warped_dst, target_dst, target_dst_full_mask = sample[1]

        total, loss_src_bgr, loss_src_mask, loss_dst_bgr, loss_dst_mask = self.ae.train_on_batch( [warped_src, target_src_full_mask, warped_dst, target_dst_full_mask], [target_src, target_src_full_mask, target_dst, target_dst_full_mask] )

        return ( ('loss_src', loss_src_bgr), ('loss_dst', loss_dst_bgr) )
    #override
    def onGetPreview(self, sample):
        test_A   = sample[0][1][0:4] #first 4 samples
        test_A_m = sample[0][2][0:4]
        test_B   = sample[1][1][0:4]
        test_B_m = sample[1][2][0:4]

        AA, mAA = self.src_view([test_A])
        AB, mAB = self.src_view([test_B])
        BB, mBB = self.dst_view([test_B])

        mAA = np.repeat ( mAA, (3,), -1)
        mAB = np.repeat ( mAB, (3,), -1)
        mBB = np.repeat ( mBB, (3,), -1)

        st = []
        for i in range(0, len(test_A)):
            st.append ( np.concatenate ( (
                test_A[i,:,:,0:3],
                AA[i],
                #mAA[i],
                test_B[i,:,:,0:3],
                BB[i],
                #mBB[i],
                AB[i],
                #mAB[i]
                ), axis=1) )

        return [ ('H64', np.concatenate ( st, axis=0 ) ) ]
    def predictor_func (self, face):
        face_64_bgr = face[...,0:3]
        face_64_mask = np.expand_dims(face[...,3],-1)

        x, mx = self.src_view ( [ np.expand_dims(face_64_bgr,0) ] )
        x, mx = x[0], mx[0]

        return np.concatenate ( (x,mx), -1 )
    #override
    def get_converter(self):
        from converters import ConverterMasked
        return ConverterMasked(self.predictor_func,
                               predictor_input_size=64,
                               output_size=64,
                               face_type=FaceType.HALF,
                               base_erode_mask_modifier=100,
                               base_blur_mask_modifier=100)
    def Build(self, lighter_ae):
        exec(nnlib.code_import_all, locals(), globals())

        bgr_shape = (64, 64, 3)
        mask_shape = (64, 64, 1)

        def downscale (dim):
            def func(x):
                return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
            return func

        def upscale (dim):
            def func(x):
                return PixelShuffler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
            return func

        def Encoder(input_shape):
            input_layer = Input(input_shape)
            x = input_layer
@@ -183,23 +183,23 @@ class Model(ModelBase):
                x = upscale(512)(x)
                x = upscale(256)(x)
                x = upscale(128)(x)
            else:
                input_ = Input(shape=(8, 8, 256))
                x = input_
                x = upscale(256)(x)
                x = upscale(128)(x)
                x = upscale(64)(x)

            y = input_ #mask decoder
            y = upscale(256)(y)
            y = upscale(128)(y)
            y = upscale(64)(y)

            x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
            y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(y)

            return Model(input_, [x,y])

        return bgr_shape, mask_shape, Encoder(bgr_shape), Decoder(), Decoder()

View file

@@ -1 +1 @@
from .Model import Model

View file

@@ -9,13 +9,13 @@ from interact import interact as io
class Model(ModelBase):

    #override
    def onInitializeOptions(self, is_first_run, ask_override):
        if is_first_run or ask_override:
            def_pixel_loss = self.options.get('pixel_loss', False)
            self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", def_pixel_loss, help_message="The default DSSIM loss is good for learning the initial structure of faces. Use pixel loss after 20k iters to enhance fine details and decrease face jitter.")
        else:
            self.options['pixel_loss'] = self.options.get('pixel_loss', False)
    #override
    def onInitialize(self):
        exec(nnlib.import_all(), locals(), globals())
@@ -25,7 +25,7 @@ class Model(ModelBase):
        mask_layer = Input(shape=(128, 128, 1)) #same as output

        self.encoder, self.decoder, self.inter_B, self.inter_AB = self.Build(ae_input_layer)

        if not self.is_first_run():
            weights_to_load = [ [self.encoder, 'encoder.h5'],
                                [self.decoder, 'decoder.h5'],
@@ -39,46 +39,46 @@ class Model(ModelBase):
        B = self.inter_B(code)

        self.autoencoder_src = Model([ae_input_layer,mask_layer], self.decoder(Concatenate()([AB, AB])) )
        self.autoencoder_dst = Model([ae_input_layer,mask_layer], self.decoder(Concatenate()([B, AB])) )

        self.autoencoder_src.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMSEMaskLoss(mask_layer, is_mse=self.options['pixel_loss']), 'mse'] )
        self.autoencoder_dst.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMSEMaskLoss(mask_layer, is_mse=self.options['pixel_loss']), 'mse'] )

        if self.is_training_mode:
            f = SampleProcessor.TypeFlags
            self.set_training_data_generators ([
                    SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
                                        debug=self.is_debug(), batch_size=self.batch_size,
                            sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
                            output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
                                                  [f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
                                                  [f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_M | f.FACE_MASK_FULL, 128] ] ),
                    SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                            sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                            output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
                                                  [f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
                                                  [f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_M | f.FACE_MASK_FULL, 128] ] )
                ])
    #override
    def onSave(self):
        self.save_weights_safe( [[self.encoder, 'encoder.h5'],
                                 [self.decoder, 'decoder.h5'],
                                 [self.inter_B, 'inter_B.h5'],
                                 [self.inter_AB, 'inter_AB.h5']] )

    #override
    def onTrainOneIter(self, sample, generators_list):
        warped_src, target_src, target_src_mask = sample[0]
        warped_dst, target_dst, target_dst_mask = sample[1]

        loss_src = self.autoencoder_src.train_on_batch( [warped_src, target_src_mask], [target_src, target_src_mask] )
        loss_dst = self.autoencoder_dst.train_on_batch( [warped_dst, target_dst_mask], [target_dst, target_dst_mask] )

        return ( ('loss_src', loss_src[0]), ('loss_dst', loss_dst[0]) )
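Note: each autoencoder here has two outputs, so Keras' train_on_batch returns [total_loss, bgr_loss, mask_loss]; loss_src[0] and loss_dst[0] pick out the totals. A self-contained toy showing that return shape:

import numpy as np
from keras.layers import Dense, Input
from keras.models import Model

inp = Input((8,))
m = Model(inp, [Dense(8)(inp), Dense(1)(inp)])       # two outputs, like the decoder's [x, y]
m.compile('adam', loss=['mse', 'mse'])
out = m.train_on_batch(np.zeros((2, 8)), [np.zeros((2, 8)), np.zeros((2, 1))])
print(out)                                           # [total, output_1_loss, output_2_loss]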
    #override
    def onGetPreview(self, sample):
@@ -86,63 +86,63 @@ class Model(ModelBase):
        test_A_m = sample[0][2][0:4] #first 4 samples
        test_B   = sample[1][1][0:4]
        test_B_m = sample[1][2][0:4]

        AA, mAA = self.autoencoder_src.predict([test_A, test_A_m])
        AB, mAB = self.autoencoder_src.predict([test_B, test_B_m])
        BB, mBB = self.autoencoder_dst.predict([test_B, test_B_m])

        mAA = np.repeat ( mAA, (3,), -1)
        mAB = np.repeat ( mAB, (3,), -1)
        mBB = np.repeat ( mBB, (3,), -1)

        st = []
        for i in range(0, len(test_A)):
            st.append ( np.concatenate ( (
                test_A[i,:,:,0:3],
                AA[i],
                #mAA[i],
                test_B[i,:,:,0:3],
                BB[i],
                #mBB[i],
                AB[i],
                #mAB[i]
                ), axis=1) )

        return [ ('LIAEF128', np.concatenate ( st, axis=0 ) ) ]
    def predictor_func (self, face):
        face_128_bgr = face[...,0:3]
        face_128_mask = np.expand_dims(face[...,3],-1)

        x, mx = self.autoencoder_src.predict ( [ np.expand_dims(face_128_bgr,0), np.expand_dims(face_128_mask,0) ] )
        x, mx = x[0], mx[0]

        return np.concatenate ( (x,mx), -1 )
    #override
    def get_converter(self):
        from converters import ConverterMasked
        return ConverterMasked(self.predictor_func,
                               predictor_input_size=128,
                               output_size=128,
                               face_type=FaceType.FULL,
                               base_erode_mask_modifier=30,
                               base_blur_mask_modifier=0)
    def Build(self, input_layer):
        exec(nnlib.code_import_all, locals(), globals())

        def downscale (dim):
            def func(x):
                return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
            return func

        def upscale (dim):
            def func(x):
                return PixelShuffler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
            return func

        def Encoder():
            x = input_layer
            x = downscale(128)(x)
@@ -161,20 +161,20 @@ class Model(ModelBase):
            x = upscale(512)(x)
            return Model(input_layer, x)

        def Decoder():
            input_ = Input(shape=(16, 16, 1024))
            x = input_
            x = upscale(512)(x)
            x = upscale(256)(x)
            x = upscale(128)(x)
            x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)

            y = input_ #mask decoder
            y = upscale(512)(y)
            y = upscale(256)(y)
            y = upscale(128)(y)
            y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid' )(y)

            return Model(input_, [x,y])

        return Encoder(), Decoder(), Intermediate(), Intermediate()

View file

@@ -1 +1 @@
from .Model import Model

View file

@@ -9,51 +9,51 @@ from interact import interact as io
#SAE - Styled AutoEncoder
class SAEModel(ModelBase):

    encoderH5 = 'encoder.h5'
    inter_BH5 = 'inter_B.h5'
    inter_ABH5 = 'inter_AB.h5'
    decoderH5 = 'decoder.h5'
    decodermH5 = 'decoderm.h5'
    decoder_srcH5 = 'decoder_src.h5'
    decoder_srcmH5 = 'decoder_srcm.h5'
    decoder_dstH5 = 'decoder_dst.h5'
    decoder_dstmH5 = 'decoder_dstm.h5'
    #override
    def onInitializeOptions(self, is_first_run, ask_override):
        yn_str = {True:'y',False:'n'}

        default_resolution = 128
        default_archi = 'df'
        default_face_type = 'f'

        if is_first_run:
            resolution = io.input_int("Resolution ( 64-256 ?:help skip:128) : ", default_resolution, help_message="Higher resolution requires more VRAM and training time. The value will be adjusted down to a multiple of 16.")
            resolution = np.clip (resolution, 64, 256)
            while np.modf(resolution / 16)[0] != 0.0:
                resolution -= 1
            self.options['resolution'] = resolution
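Note: the clip-then-decrement loop just rounds the requested resolution down to the nearest multiple of 16; plain integer arithmetic gives the same result:

import numpy as np

resolution = np.clip(137, 64, 256)
resolution = (resolution // 16) * 16     # equivalent to the while-loop above
print(resolution)                        # 128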
            self.options['face_type'] = io.input_str ("Half or Full face? (h/f, ?:help skip:f) : ", default_face_type, ['h','f'], help_message="Half face has better resolution, but covers a smaller area of the cheeks.").lower()
            self.options['learn_mask'] = io.input_bool ("Learn mask? (y/n, ?:help skip:y) : ", True, help_message="Learning the mask can help the model to recognize face directions. Learning without the mask reduces model size; in that case the converter is forced to use a 'not predicted mask' that is not as smooth as a predicted one. A model with style values can be trained without a mask and produce the same quality result.")
        else:
            self.options['resolution'] = self.options.get('resolution', default_resolution)
            self.options['face_type'] = self.options.get('face_type', default_face_type)
            self.options['learn_mask'] = self.options.get('learn_mask', True)
        if is_first_run and 'tensorflow' in self.device_config.backend:
            def_optimizer_mode = self.options.get('optimizer_mode', 1)
            self.options['optimizer_mode'] = io.input_int ("Optimizer mode? ( 1,2,3 ?:help skip:%d) : " % (def_optimizer_mode), def_optimizer_mode, help_message="1 - no changes. 2 - allows you to train a 2x bigger network by consuming system RAM. 3 - allows you to train a 3x bigger network by consuming a huge amount of RAM; slower, depending on CPU power.")
        else:
            self.options['optimizer_mode'] = self.options.get('optimizer_mode', 1)

        if is_first_run:
            self.options['archi'] = io.input_str ("AE architecture (df, liae, vg ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae','vg'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes. 'vg' - currently in testing.").lower()
        else:
            self.options['archi'] = self.options.get('archi', default_archi)

        default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
        default_ed_ch_dims = 42
        def_ca_weights = False
@@ -65,31 +65,31 @@ class SAEModel(ModelBase):
            self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims)
            self.options['ed_ch_dims'] = self.options.get('ed_ch_dims', default_ed_ch_dims)
            self.options['ca_weights'] = self.options.get('ca_weights', def_ca_weights)

        if is_first_run:
            self.options['lighter_encoder'] = io.input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="The lightweight encoder is 35% faster and requires less VRAM, at the cost of overall quality.")

            if self.options['archi'] != 'vg':
                self.options['multiscale_decoder'] = io.input_bool ("Use multiscale decoder? (y/n, ?:help skip:n) : ", False, help_message="The multiscale decoder helps to get better details.")
        else:
            self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)

            if self.options['archi'] != 'vg':
                self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', False)

        default_face_style_power = 0.0
        default_bg_style_power = 0.0
        if is_first_run or ask_override:
            def_pixel_loss = self.options.get('pixel_loss', False)
            self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: %s ) : " % (yn_str[def_pixel_loss]), def_pixel_loss, help_message="The default DSSIM loss is good for learning the initial structure of faces. Use pixel loss after 15-25k iters to enhance fine details and decrease face jitter.")

            default_face_style_power = default_face_style_power if is_first_run else self.options.get('face_style_power', default_face_style_power)
            self.options['face_style_power'] = np.clip ( io.input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_face_style_power), default_face_style_power,
                                                                         help_message="Learn to transfer face style details such as light and color conditions. Warning: enable it only after 10k iters, when the predicted face is clear enough to start learning style. Start from a value of 0.1 and check history changes."), 0.0, 100.0 )

            default_bg_style_power = default_bg_style_power if is_first_run else self.options.get('bg_style_power', default_bg_style_power)
            self.options['bg_style_power'] = np.clip ( io.input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_bg_style_power), default_bg_style_power,
                                                                       help_message="Learn to transfer the image around the face. This can make the face more like dst."), 0.0, 100.0 )
        else:
            self.options['pixel_loss'] = self.options.get('pixel_loss', False)
            self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
@@ -100,7 +100,7 @@ class SAEModel(ModelBase):
        exec(nnlib.import_all(), locals(), globals())
        SAEModel.initialize_nn_functions()
        self.set_vram_batch_requirements({1.5:4})

        resolution = self.options['resolution']
        ae_dims = self.options['ae_dims']
        ed_ch_dims = self.options['ed_ch_dims']
@@ -108,13 +108,13 @@ class SAEModel(ModelBase):
        mask_shape = (resolution, resolution, 1)

        self.ms_count = ms_count = 3 if (self.options['archi'] != 'vg' and self.options['multiscale_decoder']) else 1

        masked_training = True

        warped_src = Input(bgr_shape)
        target_src = Input(bgr_shape)
        target_srcm = Input(mask_shape)

        warped_dst = Input(bgr_shape)
        target_dst = Input(bgr_shape)
        target_dstm = Input(mask_shape)
@@ -124,27 +124,28 @@ class SAEModel(ModelBase):
        target_dst_ar  = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
        target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
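Note: for ms_count = 3 and resolution = 128 these comprehensions declare coarse-to-fine target placeholders. Evaluating the same shape arithmetic:

resolution, ms_count = 128, 3
bgr_shape, mask_shape = (resolution, resolution, 3), (resolution, resolution, 1)

bgr_pyramid  = [ (bgr_shape[0]  // (2**i),)*2 + (bgr_shape[-1],)  for i in range(ms_count-1, -1, -1) ]
mask_pyramid = [ (mask_shape[0] // (2**i),)*2 + (mask_shape[-1],) for i in range(ms_count-1, -1, -1) ]
print(bgr_pyramid)    # [(32, 32, 3), (64, 64, 3), (128, 128, 3)]
print(mask_pyramid)   # [(32, 32, 1), (64, 64, 1), (128, 128, 1)]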
        use_bn = True

        models_list = []
        weights_to_load = []
        if self.options['archi'] == 'liae':
            self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, self.options['lighter_encoder'], ed_ch_dims=ed_ch_dims, use_bn=use_bn) ) (Input(bgr_shape))

            enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]

            self.inter_B = modelify(SAEModel.LIAEInterFlow(resolution, ae_dims=ae_dims, use_bn=use_bn)) (enc_output_Inputs)
            self.inter_AB = modelify(SAEModel.LIAEInterFlow(resolution, ae_dims=ae_dims, use_bn=use_bn)) (enc_output_Inputs)

            inter_output_Inputs = [ Input( np.array(K.int_shape(x)[1:])*(1,1,2) ) for x in self.inter_B.outputs ]

            self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2, multiscale_count=self.ms_count, use_bn=use_bn )) (inter_output_Inputs)

            models_list += [self.encoder, self.inter_B, self.inter_AB, self.decoder]

            if self.options['learn_mask']:
                self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5), use_bn=use_bn )) (inter_output_Inputs)
                models_list += [self.decoderm]

            if not self.is_first_run():
                weights_to_load += [ [self.encoder , 'encoder.h5'],
                                     [self.inter_B , 'inter_B.h5'],
@@ -153,22 +154,22 @@ class SAEModel(ModelBase):
                                   ]
                if self.options['learn_mask']:
                    weights_to_load += [ [self.decoderm, 'decoderm.h5'] ]

            warped_src_code = self.encoder (warped_src)
            warped_src_inter_AB_code = self.inter_AB (warped_src_code)
            warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])

            warped_dst_code = self.encoder (warped_dst)
            warped_dst_inter_B_code = self.inter_B (warped_dst_code)
            warped_dst_inter_AB_code = self.inter_AB (warped_dst_code)
            warped_dst_inter_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code])

            warped_src_dst_inter_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code])

            pred_src_src = self.decoder(warped_src_inter_code)
            pred_dst_dst = self.decoder(warped_dst_inter_code)
            pred_src_dst = self.decoder(warped_src_dst_inter_code)

            if self.options['learn_mask']:
                pred_src_srcm = self.decoderm(warped_src_inter_code)
                pred_dst_dstm = self.decoderm(warped_dst_inter_code)
@@ -177,18 +178,18 @@ class SAEModel(ModelBase):
        elif self.options['archi'] == 'df':
            self.encoder = modelify(SAEModel.DFEncFlow(resolution, self.options['lighter_encoder'], ae_dims=ae_dims, ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))

            dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]

            self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2, multiscale_count=self.ms_count )) (dec_Inputs)
            self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2, multiscale_count=self.ms_count )) (dec_Inputs)

            models_list += [self.encoder, self.decoder_src, self.decoder_dst]

            if self.options['learn_mask']:
                self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (dec_Inputs)
                self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (dec_Inputs)
                models_list += [self.decoder_srcm, self.decoder_dstm]

            if not self.is_first_run():
                weights_to_load += [ [self.encoder , 'encoder.h5'],
                                     [self.decoder_src, 'decoder_src.h5'],
@@ -198,37 +199,37 @@ class SAEModel(ModelBase):
                    weights_to_load += [ [self.decoder_srcm, 'decoder_srcm.h5'],
                                         [self.decoder_dstm, 'decoder_dstm.h5'],
                                       ]

            warped_src_code = self.encoder (warped_src)
            warped_dst_code = self.encoder (warped_dst)
            pred_src_src = self.decoder_src(warped_src_code)
            pred_dst_dst = self.decoder_dst(warped_dst_code)
            pred_src_dst = self.decoder_src(warped_dst_code)

            if self.options['learn_mask']:
                pred_src_srcm = self.decoder_srcm(warped_src_code)
                pred_dst_dstm = self.decoder_dstm(warped_dst_code)
                pred_src_dstm = self.decoder_srcm(warped_dst_code)

        elif self.options['archi'] == 'vg':
            self.encoder = modelify(SAEModel.VGEncFlow(resolution, self.options['lighter_encoder'], ae_dims=ae_dims, ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))

            dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]

            self.decoder_src = modelify(SAEModel.VGDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2 )) (dec_Inputs)
            self.decoder_dst = modelify(SAEModel.VGDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2 )) (dec_Inputs)

            models_list += [self.encoder, self.decoder_src, self.decoder_dst]

            if self.options['learn_mask']:
                self.decoder_srcm = modelify(SAEModel.VGDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (dec_Inputs)
                self.decoder_dstm = modelify(SAEModel.VGDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (dec_Inputs)
                models_list += [self.decoder_srcm, self.decoder_dstm]

            if not self.is_first_run():
                weights_to_load += [ [self.encoder , 'encoder.h5'],
                                     [self.decoder_src, 'decoder_src.h5'],
@@ -238,7 +239,7 @@ class SAEModel(ModelBase):
                    weights_to_load += [ [self.decoder_srcm, 'decoder_srcm.h5'],
                                         [self.decoder_dstm, 'decoder_dstm.h5'],
                                       ]

            warped_src_code = self.encoder (warped_src)
            warped_dst_code = self.encoder (warped_dst)
            pred_src_src = self.decoder_src(warped_src_code)
@@ -250,30 +251,30 @@ class SAEModel(ModelBase):
                pred_src_srcm = self.decoder_srcm(warped_src_code)
                pred_dst_dstm = self.decoder_dstm(warped_dst_code)
                pred_src_dstm = self.decoder_srcm(warped_dst_code)

        if self.is_first_run() and self.options['ca_weights']:
            io.log_info ("Initializing CA weights...")
            conv_weights_list = []
            for model in models_list:
                for layer in model.layers:
                    if type(layer) == Conv2D:
                        conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
            CAInitializerMP ( conv_weights_list )

        pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]
        if self.options['learn_mask']:
            pred_src_srcm, pred_dst_dstm, pred_src_dstm = [ [x] if type(x) != list else x for x in [pred_src_srcm, pred_dst_dstm, pred_src_dstm] ]

        target_srcm_blurred_ar = [ gaussian_blur( max(1, K.int_shape(x)[1] // 32) )(x) for x in target_srcm_ar]
        target_srcm_sigm_ar = target_srcm_blurred_ar #[ x / 2.0 + 0.5 for x in target_srcm_blurred_ar]
        target_srcm_anti_sigm_ar = [ 1.0 - x for x in target_srcm_sigm_ar]

        target_dstm_blurred_ar = [ gaussian_blur( max(1, K.int_shape(x)[1] // 32) )(x) for x in target_dstm_ar]
        target_dstm_sigm_ar = target_dstm_blurred_ar #[ x / 2.0 + 0.5 for x in target_dstm_blurred_ar]
        target_dstm_anti_sigm_ar = [ 1.0 - x for x in target_dstm_sigm_ar]

        target_src_sigm_ar = target_src_ar #[ x + 1 for x in target_src_ar]
        target_dst_sigm_ar = target_dst_ar #[ x + 1 for x in target_dst_ar]
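Note: gaussian_blur softens the hard binary masks so the masked losses fade out at the face border instead of cutting off, and 1.0 - mask gives the complementary background weight used by the background style loss. The same idea in NumPy/OpenCV:

import cv2
import numpy as np

resolution = 128
mask = np.zeros((resolution, resolution), np.float32)
mask[32:96, 32:96] = 1.0                                  # hard face mask

radius = max(1, resolution // 32)                         # kernel scales with resolution, as above
blurred = cv2.GaussianBlur(mask, (radius * 2 + 1,) * 2, 0)
anti = 1.0 - blurred                                      # background weight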
@@ -284,32 +285,32 @@ class SAEModel(ModelBase):
        target_src_masked_ar = [ target_src_sigm_ar[i]*target_srcm_sigm_ar[i] for i in range(len(target_src_sigm_ar))]
        target_dst_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
        target_dst_anti_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]

        pred_src_src_masked_ar = [ pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] for i in range(len(pred_src_src_sigm_ar))]
        pred_dst_dst_masked_ar = [ pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] for i in range(len(pred_dst_dst_sigm_ar))]

        target_src_masked_ar_opt = target_src_masked_ar if masked_training else target_src_sigm_ar
        target_dst_masked_ar_opt = target_dst_masked_ar if masked_training else target_dst_sigm_ar

        pred_src_src_masked_ar_opt = pred_src_src_masked_ar if masked_training else pred_src_src_sigm_ar
        pred_dst_dst_masked_ar_opt = pred_dst_dst_masked_ar if masked_training else pred_dst_dst_sigm_ar

        psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
        psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]

        if self.is_training_mode:
            self.src_dst_opt      = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
            self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)

            if self.options['archi'] == 'liae':
                src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
                if self.options['learn_mask']:
                    src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
            else:
                src_dst_loss_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights + self.decoder_dst.trainable_weights
                if self.options['learn_mask']:
                    src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights

            if not self.options['pixel_loss']:
                src_loss_batch = sum([ ( 100*K.square( dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i] ) )) for i in range(len(target_src_masked_ar_opt)) ])
            else:
@@ -318,9 +319,9 @@ class SAEModel(ModelBase):
            src_loss = K.mean(src_loss_batch)

            face_style_power = self.options['face_style_power'] / 100.0
            if face_style_power != 0:
                src_loss += style_loss(gaussian_blur_radius=resolution//16, loss_weight=face_style_power, wnd_size=0)( psd_target_dst_masked_ar[-1], target_dst_masked_ar[-1] )

            bg_style_power = self.options['bg_style_power'] / 100.0
            if bg_style_power != 0:
@@ -334,32 +335,32 @@ class SAEModel(ModelBase):
                dst_loss_batch = sum([ ( 100*K.square(dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i] ) )) for i in range(len(target_dst_masked_ar_opt)) ])
            else:
                dst_loss_batch = sum([ K.mean ( 100*K.square( target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar_opt)) ])

            dst_loss = K.mean(dst_loss_batch)

            feed = [warped_src, warped_dst]
            feed += target_src_ar[::-1]
            feed += target_srcm_ar[::-1]
            feed += target_dst_ar[::-1]
            feed += target_dstm_ar[::-1]

            self.src_dst_train = K.function (feed,[src_loss,dst_loss], self.src_dst_opt.get_updates(src_loss+dst_loss, src_dst_loss_train_weights) )
if self.options['learn_mask']: if self.options['learn_mask']:
src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[-1]-pred_src_srcm[-1])) for i in range(len(target_srcm_ar)) ]) src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[-1]-pred_src_srcm[-1])) for i in range(len(target_srcm_ar)) ])
dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[-1]-pred_dst_dstm[-1])) for i in range(len(target_dstm_ar)) ]) dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[-1]-pred_dst_dstm[-1])) for i in range(len(target_dstm_ar)) ])
feed = [ warped_src, warped_dst] feed = [ warped_src, warped_dst]
feed += target_srcm_ar[::-1] feed += target_srcm_ar[::-1]
feed += target_dstm_ar[::-1] feed += target_dstm_ar[::-1]
self.src_dst_mask_train = K.function (feed,[src_mask_loss, dst_mask_loss], self.src_dst_mask_opt.get_updates(src_mask_loss+dst_mask_loss, src_dst_mask_loss_train_weights) ) self.src_dst_mask_train = K.function (feed,[src_mask_loss, dst_mask_loss], self.src_dst_mask_opt.get_updates(src_mask_loss+dst_mask_loss, src_dst_mask_loss_train_weights) )
if self.options['learn_mask']: if self.options['learn_mask']:
self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1], pred_src_dstm[-1]]) self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1], pred_src_dstm[-1]])
else: else:
self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1] ] ) self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1] ] )
self.load_weights_safe(weights_to_load)#, [ [self.src_dst_opt, 'src_dst_opt'], [self.src_dst_mask_opt, 'src_dst_mask_opt']]) self.load_weights_safe(weights_to_load)#, [ [self.src_dst_opt, 'src_dst_opt'], [self.src_dst_mask_opt, 'src_dst_mask_opt']])
else: else:
self.load_weights_safe(weights_to_load) self.load_weights_safe(weights_to_load)
@@ -367,30 +368,30 @@ class SAEModel(ModelBase):
            self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1], pred_src_dstm[-1] ])
        else:
            self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1] ])

        if self.is_training_mode:
            self.src_sample_losses = []
            self.dst_sample_losses = []

            f = SampleProcessor.TypeFlags
            face_type = f.FACE_ALIGN_FULL if self.options['face_type'] == 'f' else f.FACE_ALIGN_HALF

            output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution] ]
            output_sample_types += [ [f.TRANSFORMED | face_type | f.MODE_BGR, resolution // (2**i) ] for i in range(ms_count)]
            output_sample_types += [ [f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution // (2**i) ] for i in range(ms_count)]

            self.set_training_data_generators ([
                    SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
                                        debug=self.is_debug(), batch_size=self.batch_size,
                                        sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
                                        output_sample_types=output_sample_types ),
                    SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                                        sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
                                        output_sample_types=output_sample_types )
                ])
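Context note (not part of the commit): each generator batch therefore arrives as [warped input, ms_count targets from high to low resolution, ms_count full-face masks], which is exactly the layout the slicing in onTrainOneIter below relies on. A runnable sketch of that layout with hypothetical sizes:

import numpy as np

ms_count, resolution, batch = 2, 128, 4          # hypothetical values
samples = [np.zeros((batch, resolution, resolution, 3))]                                          # samples[0]: warped BGR
samples += [np.zeros((batch, resolution // 2**i, resolution // 2**i, 3)) for i in range(ms_count)]  # target BGR, high -> low res
samples += [np.zeros((batch, resolution // 2**i, resolution // 2**i, 1)) for i in range(ms_count)]  # full-face masks
targets_and_masks = samples[1:1 + ms_count * 2]              # slice fed to src_dst_train
masks_only = samples[1 + ms_count:1 + ms_count * 2]          # slice fed to src_dst_mask_train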
    #override
    def onSave(self):
        opt_ar = [ [self.src_dst_opt, 'src_dst_opt'],

@@ -413,10 +414,10 @@ class SAEModel(ModelBase):
            if self.options['learn_mask']:
                ar += [ [self.decoder_srcm, 'decoder_srcm.h5'],
                        [self.decoder_dstm, 'decoder_dstm.h5'] ]

        self.save_weights_safe(ar)

    #override
    def onTrainOneIter(self, generators_samples, generators_list):
        src_samples = generators_samples[0]

@@ -425,17 +426,17 @@ class SAEModel(ModelBase):
        feed = [src_samples[0], dst_samples[0] ] + \
               src_samples[1:1+self.ms_count*2] + \
               dst_samples[1:1+self.ms_count*2]

        src_loss, dst_loss, = self.src_dst_train (feed)

        if self.options['learn_mask']:
            feed = [ src_samples[0], dst_samples[0] ] + \
                   src_samples[1+self.ms_count:1+self.ms_count*2] + \
                   dst_samples[1+self.ms_count:1+self.ms_count*2]
            src_mask_loss, dst_mask_loss, = self.src_dst_mask_train (feed)

        return ( ('src_loss', src_loss), ('dst_loss', dst_loss) )
    #override
    def onGetPreview(self, sample):

@@ -454,33 +455,33 @@ class SAEModel(ModelBase):
        for i in range(0, len(test_A)):
            ar = S[i], SS[i], D[i], DD[i], SD[i]
            #if self.options['learn_mask']:
            #    ar += (SDM[i],)
            st.append ( np.concatenate ( ar, axis=1) )

        return [ ('SAE', np.concatenate (st, axis=0 )), ]

    def predictor_func (self, face):
        prd = [ x[0] for x in self.AE_convert ( [ face[np.newaxis,:,:,0:3] ] ) ]

        if not self.options['learn_mask']:
            prd += [ face[...,3:4] ]

        return np.concatenate ( prd, -1 )
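Context note (not part of the commit): predictor_func always hands the converter a BGR-plus-mask tensor; when the model did not learn a mask, the fourth channel of the input face is passed through instead. A minimal numpy sketch of that shape contract, with hypothetical sizes:

import numpy as np

face = np.zeros((128, 128, 4), dtype=np.float32)     # hypothetical input: BGR + mask channel
pred_bgr = np.zeros((128, 128, 3), dtype=np.float32) # stand-in for the network's BGR prediction
out = np.concatenate([pred_bgr, face[..., 3:4]], -1) # reuse the input mask channel
assert out.shape == (128, 128, 4)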
    #override
    def get_converter(self):
        base_erode_mask_modifier = 30 if self.options['face_type'] == 'f' else 100
        base_blur_mask_modifier = 0 if self.options['face_type'] == 'f' else 100

        default_erode_mask_modifier = 0
        default_blur_mask_modifier = 100 if (self.options['face_style_power'] or self.options['bg_style_power']) and \
                                            self.options['face_type'] == 'f' else 0

        face_type = FaceType.FULL if self.options['face_type'] == 'f' else FaceType.HALF

        from converters import ConverterMasked
        return ConverterMasked(self.predictor_func,
                               predictor_input_size=self.options['resolution'],
                               output_size=self.options['resolution'],
                               face_type=face_type,

@@ -490,30 +491,34 @@ class SAEModel(ModelBase):
                               default_erode_mask_modifier=default_erode_mask_modifier,
                               default_blur_mask_modifier=default_blur_mask_modifier,
                               clip_hborder_mask_per=0.0625 if self.options['face_type'] == 'f' else 0)
    @staticmethod
    def initialize_nn_functions():
        exec (nnlib.import_all(), locals(), globals())

+       def BatchNorm():
+           return BatchNormalization(axis=-1, gamma_initializer=RandomNormal(1., 0.02) )

        class ResidualBlock(object):
            def __init__(self, filters, kernel_size=3, padding='same', use_reflection_padding=False):
                self.filters = filters
                self.kernel_size = kernel_size
                self.padding = padding #if not use_reflection_padding else 'valid'
                self.use_reflection_padding = use_reflection_padding

            def __call__(self, inp):
                var_x = LeakyReLU(alpha=0.2)(inp)
                #if self.use_reflection_padding:
                #    #var_x = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(var_x)
                var_x = Conv2D(self.filters, kernel_size=self.kernel_size, padding=self.padding, kernel_initializer=RandomNormal(0, 0.02) )(var_x)
                var_x = LeakyReLU(alpha=0.2)(var_x)
                #if self.use_reflection_padding:
                #    #var_x = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(var_x)
                var_x = Conv2D(self.filters, kernel_size=self.kernel_size, padding=self.padding, kernel_initializer=RandomNormal(0, 0.02) )(var_x)
                var_x = Scale(gamma_init=keras.initializers.Constant(value=0.1))(var_x)
                var_x = Add()([var_x, inp])

@@ -521,99 +526,108 @@ class SAEModel(ModelBase):
                return var_x
        SAEModel.ResidualBlock = ResidualBlock
-       def downscale (dim):
+       def downscale (dim, use_bn=False):
            def func(x):
-               return LeakyReLU(0.1)(Conv2D(dim, kernel_size=5, strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02))(x))
+               if use_bn:
+                   return LeakyReLU(0.1)(BatchNorm()(Conv2D(dim, kernel_size=5, strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02), use_bias=False)(x)))
+               else:
+                   return LeakyReLU(0.1)(Conv2D(dim, kernel_size=5, strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02))(x))
            return func
        SAEModel.downscale = downscale

-       def downscale_sep (dim):
+       def downscale_sep (dim, use_bn=False):
            def func(x):
-               return LeakyReLU(0.1)(SeparableConv2D(dim, kernel_size=5, strides=2, padding='same', depthwise_initializer=RandomNormal(0, 0.02), pointwise_initializer=RandomNormal(0, 0.02) )(x))
+               if use_bn:
+                   return LeakyReLU(0.1)(BatchNorm()(SeparableConv2D(dim, kernel_size=5, strides=2, padding='same', depthwise_initializer=RandomNormal(0, 0.02), pointwise_initializer=RandomNormal(0, 0.02), use_bias=False )(x)))
+               else:
+                   return LeakyReLU(0.1)(SeparableConv2D(dim, kernel_size=5, strides=2, padding='same', depthwise_initializer=RandomNormal(0, 0.02), pointwise_initializer=RandomNormal(0, 0.02) )(x))
            return func
        SAEModel.downscale_sep = downscale_sep

-       def upscale (dim):
+       def upscale (dim, use_bn=False):
            def func(x):
-               return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding='same', kernel_initializer=RandomNormal(0, 0.02) )(x)))
+               if use_bn:
+                   return SubpixelUpscaler()(LeakyReLU(0.1)(BatchNorm()(Conv2D(dim * 4, kernel_size=3, strides=1, padding='same', kernel_initializer=RandomNormal(0,0.02), use_bias=False )(x))))
+               else:
+                   return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding='same', kernel_initializer=RandomNormal(0, 0.02) )(x)))
            return func
        SAEModel.upscale = upscale

        def to_bgr (output_nc):
            def func(x):
                return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid', kernel_initializer=RandomNormal(0, 0.02) )(x)
            return func
        SAEModel.to_bgr = to_bgr
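Context note (not part of the commit): these builders return closures, so encoder and decoder graphs are assembled by chaining calls. A minimal standalone sketch of the same pattern using plain Keras layers (names and sizes are hypothetical, not DFL's):

from keras.layers import Input, Conv2D, LeakyReLU

def downscale(dim):                      # same closure pattern as the factories above
    def func(x):
        return LeakyReLU(0.1)(Conv2D(dim, kernel_size=5, strides=2, padding='same')(x))
    return func

inp = Input((128, 128, 3))
x = downscale(126)(inp)    # 128 -> 64 (126 = 3 * ed_ch_dims with the default 42)
x = downscale(252)(x)      # 64 -> 32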
    @staticmethod
-   def LIAEEncFlow(resolution, light_enc, ed_ch_dims=42):
+   def LIAEEncFlow(resolution, light_enc, ed_ch_dims=42, use_bn=False):
        exec (nnlib.import_all(), locals(), globals())
        upscale = SAEModel.upscale
        downscale = SAEModel.downscale
        downscale_sep = SAEModel.downscale_sep

        def func(input):
            ed_dims = K.int_shape(input)[-1]*ed_ch_dims

            x = input
            x = downscale(ed_dims)(x)
            if not light_enc:
-               x = downscale(ed_dims*2)(x)
-               x = downscale(ed_dims*4)(x)
-               x = downscale(ed_dims*8)(x)
+               x = downscale(ed_dims*2, use_bn=use_bn)(x)
+               x = downscale(ed_dims*4, use_bn=use_bn)(x)
+               x = downscale(ed_dims*8, use_bn=use_bn)(x)
            else:
-               x = downscale_sep(ed_dims*2)(x)
-               x = downscale(ed_dims*4)(x)
-               x = downscale_sep(ed_dims*8)(x)
+               x = downscale_sep(ed_dims*2, use_bn=use_bn)(x)
+               x = downscale(ed_dims*4, use_bn=use_bn)(x)
+               x = downscale_sep(ed_dims*8, use_bn=use_bn)(x)

            x = Flatten()(x)
            return x
        return func
    @staticmethod
-   def LIAEInterFlow(resolution, ae_dims=256):
+   def LIAEInterFlow(resolution, ae_dims=256, use_bn=False):
        exec (nnlib.import_all(), locals(), globals())
        upscale = SAEModel.upscale
        lowest_dense_res=resolution // 16

        def func(input):
            x = input[0]
            x = Dense(ae_dims)(x)
            x = Dense(lowest_dense_res * lowest_dense_res * ae_dims*2)(x)
            x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims*2))(x)
-           x = upscale(ae_dims*2)(x)
+           x = upscale(ae_dims*2, use_bn=use_bn)(x)
            return x
        return func
    @staticmethod
-   def LIAEDecFlow(output_nc,ed_ch_dims=21, multiscale_count=1):
+   def LIAEDecFlow(output_nc,ed_ch_dims=21, multiscale_count=1, use_bn=False):
        exec (nnlib.import_all(), locals(), globals())
        upscale = SAEModel.upscale
        to_bgr = SAEModel.to_bgr
        ed_dims = output_nc * ed_ch_dims

        def func(input):
            x = input[0]

            outputs = []
-           x1 = upscale(ed_dims*8)( x )
+           x1 = upscale(ed_dims*8, use_bn=use_bn)( x )

            if multiscale_count >= 3:
                outputs += [ to_bgr(output_nc) ( x1 ) ]

-           x2 = upscale(ed_dims*4)( x1 )
+           x2 = upscale(ed_dims*4, use_bn=use_bn)( x1 )

            if multiscale_count >= 2:
                outputs += [ to_bgr(output_nc) ( x2 ) ]

-           x3 = upscale(ed_dims*2)( x2 )
+           x3 = upscale(ed_dims*2, use_bn=use_bn)( x2 )

            outputs += [ to_bgr(output_nc) ( x3 ) ]

            return outputs
        return func
    @staticmethod
    def DFEncFlow(resolution, light_enc, ae_dims=512, ed_ch_dims=42):
        exec (nnlib.import_all(), locals(), globals())

@@ -622,11 +636,11 @@ class SAEModel(ModelBase):
        downscale_sep = SAEModel.downscale_sep
        lowest_dense_res = resolution // 16

        def func(input):
            x = input

            ed_dims = K.int_shape(input)[-1]*ed_ch_dims

            x = downscale(ed_dims)(x)
            if not light_enc:
                x = downscale(ed_dims*2)(x)

@@ -636,15 +650,15 @@ class SAEModel(ModelBase):
                x = downscale_sep(ed_dims*2)(x)
                x = downscale_sep(ed_dims*4)(x)
                x = downscale_sep(ed_dims*8)(x)

            x = Dense(ae_dims)(Flatten()(x))
            x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
            x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x)
            x = upscale(ae_dims)(x)

            return x
        return func
    @staticmethod
    def DFDecFlow(output_nc, ed_ch_dims=21, multiscale_count=1):
        exec (nnlib.import_all(), locals(), globals())

@@ -652,29 +666,29 @@ class SAEModel(ModelBase):
        to_bgr = SAEModel.to_bgr
        ed_dims = output_nc * ed_ch_dims

        def func(input):
            x = input[0]

            outputs = []
            x1 = upscale(ed_dims*8)( x )

            if multiscale_count >= 3:
                outputs += [ to_bgr(output_nc) ( x1 ) ]

            x2 = upscale(ed_dims*4)( x1 )

            if multiscale_count >= 2:
                outputs += [ to_bgr(output_nc) ( x2 ) ]

            x3 = upscale(ed_dims*2)( x2 )

            outputs += [ to_bgr(output_nc) ( x3 ) ]

            return outputs
        return func
    @staticmethod
    def VGEncFlow(resolution, light_enc, ae_dims=512, ed_ch_dims=42):
        exec (nnlib.import_all(), locals(), globals())

@@ -683,78 +697,78 @@ class SAEModel(ModelBase):
        downscale_sep = SAEModel.downscale_sep
        ResidualBlock = SAEModel.ResidualBlock
        lowest_dense_res = resolution // 16

        def func(input):
            x = input
            ed_dims = K.int_shape(input)[-1]*ed_ch_dims
            while np.modf(ed_dims / 4)[0] != 0.0:
                ed_dims -= 1

            in_conv_filters = ed_dims# if resolution <= 128 else ed_dims + (resolution//128)*ed_ch_dims

            x = tmp_x = Conv2D (in_conv_filters, kernel_size=5, strides=2, padding='same') (x)

            for _ in range ( 8 if light_enc else 16 ):
                x = ResidualBlock(ed_dims)(x)

            x = Add()([x, tmp_x])

            x = downscale(ed_dims)(x)
            x = SubpixelUpscaler()(x)
            x = downscale(ed_dims)(x)
            x = SubpixelUpscaler()(x)
            x = downscale(ed_dims)(x)

            if light_enc:
                x = downscale_sep (ed_dims*2)(x)
            else:
                x = downscale (ed_dims*2)(x)

            x = downscale(ed_dims*4)(x)

            if light_enc:
                x = downscale_sep (ed_dims*8)(x)
            else:
                x = downscale (ed_dims*8)(x)

            x = Dense(ae_dims)(Flatten()(x))
            x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
            x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x)
            x = upscale(ae_dims)(x)
            return x
        return func
    @staticmethod
    def VGDecFlow(output_nc, ed_ch_dims=21, multiscale_count=1):
        exec (nnlib.import_all(), locals(), globals())
        upscale = SAEModel.upscale
        to_bgr = SAEModel.to_bgr
        ResidualBlock = SAEModel.ResidualBlock
        ed_dims = output_nc * ed_ch_dims

        def func(input):
            x = input[0]
            x = upscale( ed_dims*8 )(x)
            x = ResidualBlock( ed_dims*8 )(x)
            x = upscale( ed_dims*4 )(x)
            x = ResidualBlock( ed_dims*4 )(x)
            x = upscale( ed_dims*2 )(x)
            x = ResidualBlock( ed_dims*2 )(x)
            x = to_bgr(output_nc) (x)
            return x
        return func

Model = SAEModel
# 'worst' sample booster gives no good result, or I dont know how to filter worst samples properly.
#
##gathering array of sample_losses

@@ -769,7 +783,7 @@ Model = SAEModel
#    idxs = (x[:,0][ np.argwhere ( b [ b > (np.mean(b)+np.std(b)) ] )[:,0] ]).astype(np.uint)
#    generators_list[0].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
#    print ("src repeated %d" % (len(idxs)) )
#
#if len(self.dst_sample_losses) >= 128: #array is big enough
#    #fetching idxs which losses are bigger than average
#    x = np.array (self.dst_sample_losses)

@@ -777,4 +791,4 @@ Model = SAEModel
#    b = x[:,1]
#    idxs = (x[:,0][ np.argwhere ( b [ b > (np.mean(b)+np.std(b)) ] )[:,0] ]).astype(np.uint)
#    generators_list[1].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
#    print ("dst repeated %d" % (len(idxs)) )
View file

@@ -1 +1 @@
from .Model import Model
View file

@@ -2,4 +2,4 @@ from .ModelBase import ModelBase
def import_model(name):
    module = __import__('Model_'+name, globals(), locals(), [], 1)
    return getattr(module, 'Model')
View file

@@ -61,7 +61,7 @@ def _scale_filters(filters, variance):
def CAGenerateWeights ( shape, floatx, data_format, eps_std=0.05, seed=None ):
    if seed is not None:
        np.random.seed(seed)

    fan_in, fan_out = _compute_fans(shape, data_format)
    variance = 2 / fan_in

@@ -109,4 +109,4 @@ def CAGenerateWeights ( shape, floatx, data_format, eps_std=0.05, seed=None ):
    # Format of array is now: filters, stack, row, column
    init = np.array(init)
    init = _scale_filters(init, variance)
    return init.transpose(transpose_dimensions)
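Context note (not part of the commit): the `variance = 2 / fan_in` used here is the He-normal scaling commonly applied before ReLU-family activations. A minimal sketch with a hypothetical 3x3, 64-channel kernel:

import numpy as np

kh, kw, c_in, c_out = 3, 3, 64, 128            # hypothetical kernel shape
fan_in = kh * kw * c_in
w = np.random.randn(kh, kw, c_in, c_out) * np.sqrt(2.0 / fan_in)
print(w.var())                                 # approximately 2 / fan_in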
View file

@@ -1 +1 @@
from .nnlib import nnlib
View file

@@ -5,11 +5,11 @@ from .pynvml import *

#you can set DFL_TF_MIN_REQ_CAP manually for your build
#the reason why we cannot check tensorflow.version is it requires import tensorflow
tf_min_req_cap = int(os.environ.get("DFL_TF_MIN_REQ_CAP", 35))
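Context note (not part of the commit): the threshold is compute capability times ten, so the default 35 means CUDA compute capability 3.5. It can be raised before this module is imported:

import os
# Hypothetical override: require compute capability >= 5.0 (encoded as 50)
os.environ["DFL_TF_MIN_REQ_CAP"] = "50"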
class device:
    backend = None
    class Config():
        force_gpu_idx = -1
        multi_gpu = False
        force_gpu_idxs = None

@@ -22,36 +22,36 @@ class device:
        use_fp16 = False
        cpu_only = False
        backend = None
        def __init__ (self, force_gpu_idx = -1,
                            multi_gpu = False,
                            force_gpu_idxs = None,
                            choose_worst_gpu = False,
                            allow_growth = True,
                            use_fp16 = False,
                            cpu_only = False,
                            **in_options):

            self.backend = device.backend
            self.use_fp16 = use_fp16
            self.cpu_only = cpu_only

            if not self.cpu_only:
                self.cpu_only = (self.backend == "tensorflow-cpu")

            if not self.cpu_only:
                self.force_gpu_idx = force_gpu_idx
                self.multi_gpu = multi_gpu
                self.force_gpu_idxs = force_gpu_idxs
                self.choose_worst_gpu = choose_worst_gpu
                self.allow_growth = allow_growth

                self.gpu_idxs = []

                if force_gpu_idxs is not None:
                    for idx in force_gpu_idxs.split(','):
                        idx = int(idx)
                        if device.isValidDeviceIdx(idx):
                            self.gpu_idxs.append(idx)
                else:
                    gpu_idx = force_gpu_idx if (force_gpu_idx >= 0 and device.isValidDeviceIdx(force_gpu_idx)) else device.getBestValidDeviceIdx() if not choose_worst_gpu else device.getWorstValidDeviceIdx()
                    if gpu_idx != -1:

@@ -61,10 +61,10 @@ class device:
                            self.multi_gpu = False
                        else:
                            self.gpu_idxs = [gpu_idx]

                self.cpu_only = (len(self.gpu_idxs) == 0)

            if not self.cpu_only:
                self.gpu_names = []
                self.gpu_compute_caps = []

@@ -78,10 +78,10 @@ class device:
                self.gpu_names = ['CPU']
                self.gpu_compute_caps = [99]
                self.gpu_vram_gb = [0]

            if self.cpu_only:
                self.backend = "tensorflow-cpu"
    @staticmethod
    def getValidDeviceIdxsEnumerator():
        if device.backend == "plaidML":

@@ -94,8 +94,8 @@ class device:
                    yield gpu_idx
        elif device.backend == "tensorflow-generic":
            yield 0

    @staticmethod
    def getValidDevicesWithAtLeastTotalMemoryGB(totalmemsize_gb):
        result = []

@@ -111,9 +111,9 @@ class device:
                    result.append (i)
        elif device.backend == "tensorflow-generic":
            return [0]

        return result

    @staticmethod
    def getAllDevicesIdxsList():
        if device.backend == "plaidML":

@@ -121,8 +121,8 @@ class device:
        elif device.backend == "tensorflow":
            return [ *range(nvmlDeviceGetCount() ) ]
        elif device.backend == "tensorflow-generic":
            return [0]

    @staticmethod
    def getValidDevicesIdxsWithNamesList():
        if device.backend == "plaidML":

@@ -137,17 +137,17 @@ class device:
    @staticmethod
    def getDeviceVRAMTotalGb (idx):
        if device.backend == "plaidML":
            if idx < plaidML_devices_count:
                return plaidML_devices[idx]['globalMemSize'] / (1024*1024*1024)
        elif device.backend == "tensorflow":
            if idx < nvmlDeviceGetCount():
                memInfo = nvmlDeviceGetMemoryInfo( nvmlDeviceGetHandleByIndex(idx) )
                return round ( memInfo.total / (1024*1024*1024) )

            return 0
        elif device.backend == "tensorflow-generic":
            return 2

    @staticmethod
    def getBestValidDeviceIdx():
        if device.backend == "plaidML":

@@ -172,7 +172,7 @@ class device:
            return idx
        elif device.backend == "tensorflow-generic":
            return 0

    @staticmethod
    def getWorstValidDeviceIdx():
        if device.backend == "plaidML":

@@ -197,7 +197,7 @@ class device:
            return idx
        elif device.backend == "tensorflow-generic":
            return 0

    @staticmethod
    def isValidDeviceIdx(idx):
        if device.backend == "plaidML":

@@ -206,11 +206,11 @@ class device:
            return idx in [*device.getValidDeviceIdxsEnumerator()]
        elif device.backend == "tensorflow-generic":
            return (idx == 0)

    @staticmethod
    def getDeviceIdxsEqualModel(idx):
        if device.backend == "plaidML":
            result = []
            idx_name = plaidML_devices[idx]['description']
            for i in device.getValidDeviceIdxsEnumerator():
                if plaidML_devices[i]['description'] == idx_name:

@@ -218,7 +218,7 @@ class device:
            return result
        elif device.backend == "tensorflow":
            result = []
            idx_name = nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(idx)).decode()
            for i in device.getValidDeviceIdxsEnumerator():
                if nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)).decode() == idx_name:

@@ -226,60 +226,60 @@ class device:
            return result
        elif device.backend == "tensorflow-generic":
            return [0] if idx == 0 else []

    @staticmethod
    def getDeviceName (idx):
        if device.backend == "plaidML":
            if idx < plaidML_devices_count:
                return plaidML_devices[idx]['description']
        elif device.backend == "tensorflow":
            if idx < nvmlDeviceGetCount():
                return nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(idx)).decode()
        elif device.backend == "tensorflow-generic":
            if idx == 0:
                return "Generic GeForce GPU"

        return None

    @staticmethod
    def getDeviceID (idx):
        if device.backend == "plaidML":
            if idx < plaidML_devices_count:
                return plaidML_devices[idx]['id'].decode()

        return None

    @staticmethod
    def getDeviceComputeCapability(idx):
        result = 0
        if device.backend == "plaidML":
            return 99
        elif device.backend == "tensorflow":
            if idx < nvmlDeviceGetCount():
                result = nvmlDeviceGetCudaComputeCapability(nvmlDeviceGetHandleByIndex(idx))
        elif device.backend == "tensorflow-generic":
            return 99 if idx == 0 else 0

        return result[0] * 10 + result[1]
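Context note (not part of the commit): NVML returns the capability as a (major, minor) tuple, which is packed into a single integer for comparison with tf_min_req_cap:

# e.g. a device reporting (6, 1) packs to 6 * 10 + 1 == 61, which passes
# the default tf_min_req_cap of 35 (compute capability 3.5)
major, minor = 6, 1
assert major * 10 + minor == 61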
force_plaidML = os.environ.get("DFL_FORCE_PLAIDML", "0") == "1" #for OpenCL build , forcing using plaidML even if NVIDIA found
force_tf_cpu = os.environ.get("DFL_FORCE_TF_CPU", "0") == "1" #for OpenCL build , forcing using tf-cpu if plaidML failed

has_nvml = False
has_nvml_cap = False

#use DFL_FORCE_HAS_NVIDIA_DEVICE=1 if
#- your NVIDIA cannot be seen by OpenCL
#- CUDA build of DFL
has_nvidia_device = os.environ.get("DFL_FORCE_HAS_NVIDIA_DEVICE", "0") == "1"

plaidML_devices = []

# Using plaidML OpenCL backend to determine system devices and has_nvidia_device
try:
    os.environ['PLAIDML_EXPERIMENTAL'] = 'false' #this enables work plaidML without run 'plaidml-setup'
    import plaidml
    ctx = plaidml.Context()
    for d in plaidml.devices(ctx, return_all=True)[0]:
        details = json.loads(d.details)

@@ -288,13 +288,13 @@ try:
        if 'nvidia' in details['vendor'].lower():
            has_nvidia_device = True
        plaidML_devices += [ {'id':d.id,
                              'globalMemSize' : int(details['globalMemSize']),
                              'description' : d.description.decode()
                              }]
    ctx.shutdown()
except:
    pass

plaidML_devices_count = len(plaidML_devices)

#choosing backend

@@ -306,11 +306,11 @@ if device.backend is None and not force_tf_cpu:
        nvmlInit()
        has_nvml = True
        device.backend = "tensorflow" #set tensorflow backend in order to use device.*device() functions

        gpu_idxs = device.getAllDevicesIdxsList()
        gpu_caps = np.array ( [ device.getDeviceComputeCapability(gpu_idx) for gpu_idx in gpu_idxs ] )

        if len ( np.ndarray.flatten ( np.argwhere (gpu_caps >= tf_min_req_cap) ) ) == 0:
            if not force_plaidML:
                print ("No CUDA devices found with minimum required compute capability: %d.%d. Falling back to OpenCL mode." % (tf_min_req_cap // 10, tf_min_req_cap % 10) )
            device.backend = None

@@ -320,7 +320,7 @@ if device.backend is None and not force_tf_cpu:
    except:
        #if no NVSMI installed exception will occur
        device.backend = None
        has_nvml = False

if force_plaidML or (device.backend is None and not has_nvidia_device):
    #tensorflow backend was failed without has_nvidia_device , or forcing plaidML, trying to use plaidML backend

@@ -333,7 +333,7 @@ if force_plaidML or (device.backend is None and not has_nvidia_device):
if device.backend is None:
    if force_tf_cpu:
        device.backend = "tensorflow-cpu"
    elif not has_nvml:
        if has_nvidia_device:
            #some notebook systems have NVIDIA card without NVSMI in official drivers
            #in that case considering we have system with one capable GPU and let tensorflow to choose best GPU

@@ -348,4 +348,3 @@ if device.backend is None:
    else:
        #has NVSMI, no capable CUDA-devices, also plaidML was failed, then CPU only
        device.backend = "tensorflow-cpu"
View file

@@ -541,7 +541,6 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
        result = CAInitializerMPSubprocessor ( [ (i, K.int_shape(conv_weights)) for i, conv_weights in enumerate(conv_weights_list) ], K.floatx(), K.image_data_format() ).run()
        for idx, weights in result:
            K.set_value ( conv_weights_list[idx], weights )
    nnlib.CAInitializerMP = CAInitializerMP
View file

@@ -3,7 +3,7 @@
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#    * Redistributions of source code must retain the above copyright notice,
#      this list of conditions and the following disclaimer.
#    * Redistributions in binary form must reproduce the above copyright

@@ -18,11 +18,11 @@
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
# THE POSSIBILITY OF SUCH DAMAGE.
#####
@@ -35,7 +35,7 @@ import sys
import os
import threading
import string

## C Type mappings ##
## Enums
_nvmlEnableState_t = c_uint

@@ -155,9 +155,9 @@ NVML_FAN_FAILED = 1
_nvmlLedColor_t = c_uint
NVML_LED_COLOR_GREEN = 0
NVML_LED_COLOR_AMBER = 1

_nvmlGpuOperationMode_t = c_uint
NVML_GOM_ALL_ON = 0
NVML_GOM_COMPUTE = 1
NVML_GOM_LOW_DP = 2

@@ -173,7 +173,7 @@ NVML_RESTRICTED_API_COUNT = 2
_nvmlBridgeChipType_t = c_uint
NVML_BRIDGE_CHIP_PLX = 0
NVML_BRIDGE_CHIP_BRO4 = 1
NVML_MAX_PHYSICAL_BRIDGE = 128

_nvmlValueType_t = c_uint

@@ -317,7 +317,7 @@ def _nvmlGetFunctionPointer(name):
    if name in _nvmlGetFunctionPointer_cache:
        return _nvmlGetFunctionPointer_cache[name]

    libLoadLock.acquire()
    try:
        # ensure library was loaded

@@ -364,7 +364,7 @@ def nvmlFriendlyObjectToStruct(obj, model):
class struct_c_nvmlUnit_t(Structure):
    pass # opaque handle
c_nvmlUnit_t = POINTER(struct_c_nvmlUnit_t)

class _PrintableStructure(Structure):
    """
    Abstract class that produces nicer __str__ output than ctypes.Structure.

@@ -373,7 +373,7 @@ class _PrintableStructure(Structure):
    <class_name object at 0x7fdf82fef9e0>
    this class will print
    class_name(field_name: formatted_value, field_name: formatted_value)

    _fmt_ dictionary of <str _field_ name> -> <str format>
    e.g. class that has _field_ 'hex_value', c_uint could be formatted with
    _fmt_ = {"hex_value" : "%08X"}

@@ -397,7 +397,7 @@ class _PrintableStructure(Structure):
                fmt = self._fmt_["<default>"]
            result.append(("%s: " + fmt) % (key, value))

        return self.__class__.__name__ + "(" + string.join(result, ", ") + ")"

class c_nvmlUnitInfo_t(_PrintableStructure):
    _fields_ = [
        ('name', c_char * 96),

@@ -444,7 +444,7 @@ class nvmlPciInfo_t(_PrintableStructure):
        ('bus', c_uint),
        ('device', c_uint),
        ('pciDeviceId', c_uint),

        # Added in 2.285
        ('pciSubSystemId', c_uint),
        ('reserved0', c_uint),

@@ -503,7 +503,7 @@ class c_nvmlBridgeChipHierarchy_t(_PrintableStructure):
    _fields_ = [
        ('bridgeCount', c_uint),
        ('bridgeChipInfo', c_nvmlBridgeChipInfo_t * 128),
    ]

class c_nvmlEccErrorCounts_t(_PrintableStructure):
    _fields_ = [

@@ -582,7 +582,7 @@ nvmlClocksThrottleReasonAll = (
       nvmlClocksThrottleReasonSwPowerCap |
       nvmlClocksThrottleReasonHwSlowdown |
       nvmlClocksThrottleReasonUnknown
       )

class c_nvmlEventData_t(_PrintableStructure):
    _fields_ = [

@@ -606,31 +606,31 @@ class c_nvmlAccountingStats_t(_PrintableStructure):
## C function wrappers ##
def nvmlInit():
    _LoadNvmlLibrary()

    #
    # Initialize the library
    #
    fn = _nvmlGetFunctionPointer("nvmlInit_v2")
    ret = fn()
    _nvmlCheckReturn(ret)

    # Atomically update refcount
    global _nvmlLib_refcount
    libLoadLock.acquire()
    _nvmlLib_refcount += 1
    libLoadLock.release()

    return None

def _LoadNvmlLibrary():
    '''
    Load the library if it isn't loaded already
    '''
    global nvmlLib

    if (nvmlLib == None):
        # lock to ensure only one caller loads the library
        libLoadLock.acquire()

        try:
            # ensure the library still isn't loaded
            if (nvmlLib == None):

@@ -649,7 +649,7 @@ def _LoadNvmlLibrary():
        finally:
            # lock is always freed
            libLoadLock.release()

def nvmlShutdown():
    #
    # Leave the library loaded, but shutdown the interface

@@ -657,7 +657,7 @@ def nvmlShutdown():
    fn = _nvmlGetFunctionPointer("nvmlShutdown")
    ret = fn()
    _nvmlCheckReturn(ret)

    # Atomically update refcount
    global _nvmlLib_refcount
    libLoadLock.acquire()
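Context note (not part of the commit): the wrappers in this module follow the usual NVML lifecycle: initialize once, query through a device handle, then shut down. A minimal sketch using functions defined in this file:

nvmlInit()
try:
    handle = nvmlDeviceGetHandleByIndex(0)
    print(nvmlDeviceGetName(handle).decode())
    mem = nvmlDeviceGetMemoryInfo(handle)
    print("%.1f GiB total" % (mem.total / 1024**3))
finally:
    nvmlShutdown()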
@ -701,19 +701,19 @@ def nvmlSystemGetHicVersion():
c_count = c_uint(0) c_count = c_uint(0)
hics = None hics = None
fn = _nvmlGetFunctionPointer("nvmlSystemGetHicVersion") fn = _nvmlGetFunctionPointer("nvmlSystemGetHicVersion")
# get the count # get the count
ret = fn(byref(c_count), None) ret = fn(byref(c_count), None)
# this should only fail with insufficient size # this should only fail with insufficient size
if ((ret != NVML_SUCCESS) and if ((ret != NVML_SUCCESS) and
(ret != NVML_ERROR_INSUFFICIENT_SIZE)): (ret != NVML_ERROR_INSUFFICIENT_SIZE)):
raise NVMLError(ret) raise NVMLError(ret)
# if there are no hics # if there are no hics
if (c_count.value == 0): if (c_count.value == 0):
return [] return []
hic_array = c_nvmlHwbcEntry_t * c_count.value hic_array = c_nvmlHwbcEntry_t * c_count.value
hics = hic_array() hics = hic_array()
ret = fn(byref(c_count), hics) ret = fn(byref(c_count), hics)
@ -770,7 +770,7 @@ def nvmlUnitGetFanSpeedInfo(unit):
ret = fn(unit, byref(c_speeds)) ret = fn(unit, byref(c_speeds))
_nvmlCheckReturn(ret) _nvmlCheckReturn(ret)
return c_speeds return c_speeds
# added to API # added to API
def nvmlUnitGetDeviceCount(unit): def nvmlUnitGetDeviceCount(unit):
c_count = c_uint(0) c_count = c_uint(0)
@ -822,7 +822,7 @@ def nvmlDeviceGetHandleByUUID(uuid):
ret = fn(c_uuid, byref(device)) ret = fn(c_uuid, byref(device))
_nvmlCheckReturn(ret) _nvmlCheckReturn(ret)
return device return device
def nvmlDeviceGetHandleByPciBusId(pciBusId): def nvmlDeviceGetHandleByPciBusId(pciBusId):
c_busId = c_char_p(pciBusId) c_busId = c_char_p(pciBusId)
device = c_nvmlDevice_t() device = c_nvmlDevice_t()
@ -858,7 +858,7 @@ def nvmlDeviceGetBrand(handle):
ret = fn(handle, byref(c_type)) ret = fn(handle, byref(c_type))
_nvmlCheckReturn(ret) _nvmlCheckReturn(ret)
return c_type.value return c_type.value
def nvmlDeviceGetSerial(handle): def nvmlDeviceGetSerial(handle):
c_serial = create_string_buffer(NVML_DEVICE_SERIAL_BUFFER_SIZE) c_serial = create_string_buffer(NVML_DEVICE_SERIAL_BUFFER_SIZE)
fn = _nvmlGetFunctionPointer("nvmlDeviceGetSerial") fn = _nvmlGetFunctionPointer("nvmlDeviceGetSerial")
@ -892,14 +892,14 @@ def nvmlDeviceGetMinorNumber(handle):
ret = fn(handle, byref(c_minor_number)) ret = fn(handle, byref(c_minor_number))
_nvmlCheckReturn(ret) _nvmlCheckReturn(ret)
return c_minor_number.value return c_minor_number.value
def nvmlDeviceGetUUID(handle): def nvmlDeviceGetUUID(handle):
c_uuid = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE) c_uuid = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE)
fn = _nvmlGetFunctionPointer("nvmlDeviceGetUUID") fn = _nvmlGetFunctionPointer("nvmlDeviceGetUUID")
ret = fn(handle, c_uuid, c_uint(NVML_DEVICE_UUID_BUFFER_SIZE)) ret = fn(handle, c_uuid, c_uint(NVML_DEVICE_UUID_BUFFER_SIZE))
_nvmlCheckReturn(ret) _nvmlCheckReturn(ret)
return c_uuid.value return c_uuid.value
def nvmlDeviceGetInforomVersion(handle, infoRomObject): def nvmlDeviceGetInforomVersion(handle, infoRomObject):
c_version = create_string_buffer(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE) c_version = create_string_buffer(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE)
fn = _nvmlGetFunctionPointer("nvmlDeviceGetInforomVersion") fn = _nvmlGetFunctionPointer("nvmlDeviceGetInforomVersion")
@ -929,7 +929,7 @@ def nvmlDeviceValidateInforom(handle):
fn = _nvmlGetFunctionPointer("nvmlDeviceValidateInforom") fn = _nvmlGetFunctionPointer("nvmlDeviceValidateInforom")
ret = fn(handle) ret = fn(handle)
_nvmlCheckReturn(ret) _nvmlCheckReturn(ret)
return None return None
def nvmlDeviceGetDisplayMode(handle): def nvmlDeviceGetDisplayMode(handle):
c_mode = _nvmlEnableState_t() c_mode = _nvmlEnableState_t()
@ -937,29 +937,29 @@ def nvmlDeviceGetDisplayMode(handle):
ret = fn(handle, byref(c_mode)) ret = fn(handle, byref(c_mode))
_nvmlCheckReturn(ret) _nvmlCheckReturn(ret)
return c_mode.value return c_mode.value
def nvmlDeviceGetDisplayActive(handle): def nvmlDeviceGetDisplayActive(handle):
c_mode = _nvmlEnableState_t() c_mode = _nvmlEnableState_t()
fn = _nvmlGetFunctionPointer("nvmlDeviceGetDisplayActive") fn = _nvmlGetFunctionPointer("nvmlDeviceGetDisplayActive")
ret = fn(handle, byref(c_mode)) ret = fn(handle, byref(c_mode))
_nvmlCheckReturn(ret) _nvmlCheckReturn(ret)
return c_mode.value return c_mode.value
def nvmlDeviceGetPersistenceMode(handle): def nvmlDeviceGetPersistenceMode(handle):
c_state = _nvmlEnableState_t() c_state = _nvmlEnableState_t()
fn = _nvmlGetFunctionPointer("nvmlDeviceGetPersistenceMode") fn = _nvmlGetFunctionPointer("nvmlDeviceGetPersistenceMode")
ret = fn(handle, byref(c_state)) ret = fn(handle, byref(c_state))
_nvmlCheckReturn(ret) _nvmlCheckReturn(ret)
return c_state.value return c_state.value
def nvmlDeviceGetPciInfo(handle): def nvmlDeviceGetPciInfo(handle):
c_info = nvmlPciInfo_t() c_info = nvmlPciInfo_t()
fn = _nvmlGetFunctionPointer("nvmlDeviceGetPciInfo_v2") fn = _nvmlGetFunctionPointer("nvmlDeviceGetPciInfo_v2")
ret = fn(handle, byref(c_info)) ret = fn(handle, byref(c_info))
_nvmlCheckReturn(ret) _nvmlCheckReturn(ret)
return c_info return c_info
def nvmlDeviceGetClockInfo(handle, type): def nvmlDeviceGetClockInfo(handle, type):
c_clock = c_uint() c_clock = c_uint()
fn = _nvmlGetFunctionPointer("nvmlDeviceGetClockInfo") fn = _nvmlGetFunctionPointer("nvmlDeviceGetClockInfo")
@ -997,7 +997,7 @@ def nvmlDeviceGetSupportedMemoryClocks(handle):
c_count = c_uint(0) c_count = c_uint(0)
fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedMemoryClocks") fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedMemoryClocks")
ret = fn(handle, byref(c_count), None) ret = fn(handle, byref(c_count), None)
if (ret == NVML_SUCCESS): if (ret == NVML_SUCCESS):
# special case, no clocks # special case, no clocks
return [] return []
@ -1005,11 +1005,11 @@ def nvmlDeviceGetSupportedMemoryClocks(handle):
# typical case # typical case
clocks_array = c_uint * c_count.value clocks_array = c_uint * c_count.value
c_clocks = clocks_array() c_clocks = clocks_array()
# make the call again # make the call again
ret = fn(handle, byref(c_count), c_clocks) ret = fn(handle, byref(c_count), c_clocks)
_nvmlCheckReturn(ret) _nvmlCheckReturn(ret)
procs = [] procs = []
for i in range(c_count.value): for i in range(c_count.value):
procs.append(c_clocks[i]) procs.append(c_clocks[i])
@@ -1025,7 +1025,7 @@ def nvmlDeviceGetSupportedGraphicsClocks(handle, memoryClockMHz):
    c_count = c_uint(0)
    fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedGraphicsClocks")
    ret = fn(handle, c_uint(memoryClockMHz), byref(c_count), None)

    if (ret == NVML_SUCCESS):
        # special case, no clocks
        return []

@@ -1033,11 +1033,11 @@ def nvmlDeviceGetSupportedGraphicsClocks(handle, memoryClockMHz):
    # typical case
    clocks_array = c_uint * c_count.value
    c_clocks = clocks_array()

    # make the call again
    ret = fn(handle, c_uint(memoryClockMHz), byref(c_count), c_clocks)
    _nvmlCheckReturn(ret)

    procs = []
    for i in range(c_count.value):
        procs.append(c_clocks[i])

@@ -1053,7 +1053,7 @@ def nvmlDeviceGetFanSpeed(handle):
    ret = fn(handle, byref(c_speed))
    _nvmlCheckReturn(ret)
    return c_speed.value

def nvmlDeviceGetTemperature(handle, sensor):
    c_temp = c_uint()
    fn = _nvmlGetFunctionPointer("nvmlDeviceGetTemperature")

@@ -1075,7 +1075,7 @@ def nvmlDeviceGetPowerState(handle):
    ret = fn(handle, byref(c_pstate))
    _nvmlCheckReturn(ret)
    return c_pstate.value

def nvmlDeviceGetPerformanceState(handle):
    c_pstate = _nvmlPstates_t()
    fn = _nvmlGetFunctionPointer("nvmlDeviceGetPerformanceState")

@@ -1089,7 +1089,7 @@ def nvmlDeviceGetPowerManagementMode(handle):
    ret = fn(handle, byref(c_pcapMode))
    _nvmlCheckReturn(ret)
    return c_pcapMode.value

def nvmlDeviceGetPowerManagementLimit(handle):
    c_limit = c_uint()
    fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementLimit")

@@ -1113,7 +1113,7 @@ def nvmlDeviceGetPowerManagementDefaultLimit(handle):
    ret = fn(handle, byref(c_limit))
    _nvmlCheckReturn(ret)
    return c_limit.value

# Added in 331
def nvmlDeviceGetEnforcedPowerLimit(handle):

@@ -1146,7 +1146,7 @@ def nvmlDeviceGetCurrentGpuOperationMode(handle):
# Added in 4.304
def nvmlDeviceGetPendingGpuOperationMode(handle):
    return nvmlDeviceGetGpuOperationMode(handle)[1]

def nvmlDeviceGetMemoryInfo(handle):
    c_memory = c_nvmlMemory_t()
    fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryInfo")

@@ -1160,14 +1160,14 @@ def nvmlDeviceGetBAR1MemoryInfo(handle):
    ret = fn(handle, byref(c_bar1_memory))
    _nvmlCheckReturn(ret)
    return c_bar1_memory

def nvmlDeviceGetComputeMode(handle):
    c_mode = _nvmlComputeMode_t()
    fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeMode")
    ret = fn(handle, byref(c_mode))
    _nvmlCheckReturn(ret)
    return c_mode.value

def nvmlDeviceGetEccMode(handle):
    c_currState = _nvmlEnableState_t()
    c_pendingState = _nvmlEnableState_t()

@@ -1200,7 +1200,7 @@ def nvmlDeviceGetDetailedEccErrors(handle, errorType, counterType):
                                   _nvmlEccCounterType_t(counterType), byref(c_counts))
    _nvmlCheckReturn(ret)
    return c_counts

# Added in 4.304
def nvmlDeviceGetMemoryErrorCounter(handle, errorType, counterType, locationType):
    c_count = c_ulonglong()

@@ -1212,7 +1212,7 @@ def nvmlDeviceGetMemoryErrorCounter(handle, errorType, counterType, locationType
             byref(c_count))
    _nvmlCheckReturn(ret)
    return c_count.value

def nvmlDeviceGetUtilizationRates(handle):
    c_util = c_nvmlUtilization_t()
    fn = _nvmlGetFunctionPointer("nvmlDeviceGetUtilizationRates")

@@ -1273,7 +1273,7 @@ def nvmlDeviceGetComputeRunningProcesses(handle):
    c_count = c_uint(0)
    fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeRunningProcesses")
    ret = fn(handle, byref(c_count), None)

    if (ret == NVML_SUCCESS):
        # special case, no running processes
        return []

@@ -1283,11 +1283,11 @@ def nvmlDeviceGetComputeRunningProcesses(handle):
    c_count.value = c_count.value * 2 + 5
    proc_array = c_nvmlProcessInfo_t * c_count.value
    c_procs = proc_array()

    # make the call again
    ret = fn(handle, byref(c_count), c_procs)
    _nvmlCheckReturn(ret)

    procs = []
    for i in range(c_count.value):
        # use an alternative struct for this object

@@ -1317,11 +1317,11 @@ def nvmlDeviceGetGraphicsRunningProcesses(handle):
    c_count.value = c_count.value * 2 + 5
    proc_array = c_nvmlProcessInfo_t * c_count.value
    c_procs = proc_array()

    # make the call again
    ret = fn(handle, byref(c_count), c_procs)
    _nvmlCheckReturn(ret)

    procs = []
    for i in range(c_count.value):
        # use an alternative struct for this object

@@ -1351,19 +1351,19 @@ def nvmlUnitSetLedState(unit, color):
    ret = fn(unit, _nvmlLedColor_t(color))
    _nvmlCheckReturn(ret)
    return None

def nvmlDeviceSetPersistenceMode(handle, mode):
    fn = _nvmlGetFunctionPointer("nvmlDeviceSetPersistenceMode")
    ret = fn(handle, _nvmlEnableState_t(mode))
    _nvmlCheckReturn(ret)
    return None

def nvmlDeviceSetComputeMode(handle, mode):
    fn = _nvmlGetFunctionPointer("nvmlDeviceSetComputeMode")
    ret = fn(handle, _nvmlComputeMode_t(mode))
    _nvmlCheckReturn(ret)
    return None

def nvmlDeviceSetEccMode(handle, mode):
    fn = _nvmlGetFunctionPointer("nvmlDeviceSetEccMode")
    ret = fn(handle, _nvmlEnableState_t(mode))

@@ -1381,15 +1381,15 @@ def nvmlDeviceSetDriverModel(handle, model):
    ret = fn(handle, _nvmlDriverModel_t(model))
    _nvmlCheckReturn(ret)
    return None

def nvmlDeviceSetAutoBoostedClocksEnabled(handle, enabled):
    fn = _nvmlGetFunctionPointer("nvmlDeviceSetAutoBoostedClocksEnabled")
    ret = fn(handle, _nvmlEnableState_t(enabled))
    _nvmlCheckReturn(ret)
    return None

#Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks
def nvmlDeviceSetDefaultAutoBoostedClocksEnabled(handle, enabled, flags):
    fn = _nvmlGetFunctionPointer("nvmlDeviceSetDefaultAutoBoostedClocksEnabled")
    ret = fn(handle, _nvmlEnableState_t(enabled), c_uint(flags))
    _nvmlCheckReturn(ret)

@@ -1402,7 +1402,7 @@ def nvmlDeviceSetApplicationsClocks(handle, maxMemClockMHz, maxGraphicsClockMHz)
    ret = fn(handle, c_uint(maxMemClockMHz), c_uint(maxGraphicsClockMHz))
    _nvmlCheckReturn(ret)
    return None

# Added in 4.304
def nvmlDeviceResetApplicationsClocks(handle):
    fn = _nvmlGetFunctionPointer("nvmlDeviceResetApplicationsClocks")

@@ -1416,7 +1416,7 @@ def nvmlDeviceSetPowerManagementLimit(handle, limit):
    ret = fn(handle, c_uint(limit))
    _nvmlCheckReturn(ret)
    return None

# Added in 4.304
def nvmlDeviceSetGpuOperationMode(handle, mode):
    fn = _nvmlGetFunctionPointer("nvmlDeviceSetGpuOperationMode")

@@ -1534,7 +1534,7 @@ def nvmlDeviceGetAccountingMode(handle):
    ret = fn(handle, byref(c_mode))
    _nvmlCheckReturn(ret)
    return c_mode.value

def nvmlDeviceSetAccountingMode(handle, mode):
    fn = _nvmlGetFunctionPointer("nvmlDeviceSetAccountingMode")
    ret = fn(handle, _nvmlEnableState_t(mode))

@@ -1563,7 +1563,7 @@ def nvmlDeviceGetAccountingPids(handle):
    fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingPids")
    ret = fn(handle, byref(count), pids)
    _nvmlCheckReturn(ret)
    return map(int, pids[0:count.value])

def nvmlDeviceGetAccountingBufferSize(handle):
    bufferSize = c_uint()

@@ -1576,10 +1576,10 @@ def nvmlDeviceGetRetiredPages(device, sourceFilter):
    c_source = _nvmlPageRetirementCause_t(sourceFilter)
    c_count = c_uint(0)
    fn = _nvmlGetFunctionPointer("nvmlDeviceGetRetiredPages")

    # First call will get the size
    ret = fn(device, c_source, byref(c_count), None)

    # this should only fail with insufficient size
    if ((ret != NVML_SUCCESS) and
        (ret != NVML_ERROR_INSUFFICIENT_SIZE)):

@@ -1651,7 +1651,7 @@ def nvmlDeviceGetViolationStatus(device, perfPolicyType):
    ret = fn(device, c_perfPolicy_type, byref(c_violTime))
    _nvmlCheckReturn(ret)
    return c_violTime

def nvmlDeviceGetPcieThroughput(device, counter):
    c_util = c_uint()
    fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieThroughput")

@@ -1704,17 +1704,17 @@ def nvmlDeviceGetTopologyCommonAncestor(device1, device2):
def nvmlDeviceGetCudaComputeCapability(device):
    c_major = c_int()
    c_minor = c_int()

    try:
        fn = _nvmlGetFunctionPointer("nvmlDeviceGetCudaComputeCapability")
    except:
        return 9, 9

    # get the count
    ret = fn(device, byref(c_major), byref(c_minor))

    # this should only fail with insufficient size
    if (ret != NVML_SUCCESS):
        raise NVMLError(ret)

    return c_major.value, c_minor.value
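
Note: the bare except above appears to be this repo's local patch to pynvml: on drivers too old to export nvmlDeviceGetCudaComputeCapability it returns a sentinel (9, 9) instead of raising, presumably so downstream capability checks still pass (the "# get the count" comment is a leftover from the size-query pattern; this call reads major/minor directly). A caller sketch, with handle acquisition as in the earlier example:

    major, minor = nvmlDeviceGetCudaComputeCapability(handle)
    if (major, minor) == (9, 9):
        # sentinel: capability could not be queried on this driver
        pass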

View file

@@ -5,17 +5,17 @@ from utils.cv2_utils import *

class SampleType(IntEnum):
    IMAGE = 0 #raw image
    FACE_BEGIN = 1
    FACE = 1                      #aligned face unsorted
    FACE_YAW_SORTED = 2           #sorted by yaw
    FACE_YAW_SORTED_AS_TARGET = 3 #sorted by yaw; keeps only yaws that also exist in TARGET, with automatic mirroring
    FACE_WITH_CLOSE_TO_SELF = 4
    FACE_END = 4
    QTY = 5

class Sample(object):
    def __init__(self, sample_type=None, filename=None, face_type=None, shape=None, landmarks=None, pitch=None, yaw=None, mirror=None, close_target_list=None):
        self.sample_type = sample_type if sample_type is not None else SampleType.IMAGE
        self.filename = filename

@@ -26,19 +26,19 @@ class Sample(object):
        self.yaw = yaw
        self.mirror = mirror
        self.close_target_list = close_target_list

    def copy_and_set(self, sample_type=None, filename=None, face_type=None, shape=None, landmarks=None, pitch=None, yaw=None, mirror=None, close_target_list=None):
        return Sample(
            sample_type=sample_type if sample_type is not None else self.sample_type,
            filename=filename if filename is not None else self.filename,
            face_type=face_type if face_type is not None else self.face_type,
            shape=shape if shape is not None else self.shape,
            landmarks=landmarks if landmarks is not None else self.landmarks.copy(),
            pitch=pitch if pitch is not None else self.pitch,
            yaw=yaw if yaw is not None else self.yaw,
            mirror=mirror if mirror is not None else self.mirror,
            close_target_list=close_target_list if close_target_list is not None else self.close_target_list)

    def load_bgr(self):
        img = cv2_imread (self.filename).astype(np.float32) / 255.0
        if self.mirror:

@@ -48,4 +48,4 @@ class Sample(object):
    def get_random_close_target_sample(self):
        if self.close_target_list is None:
            return None
        return self.close_target_list[randint (0, len(self.close_target_list)-1)]
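
Note: copy_and_set is a copy-with-overrides constructor: any argument left as None falls back to the current sample's value, so derived samples (mirrored, re-typed) share all other fields. It unconditionally calls self.landmarks.copy(), so the source sample must already carry landmarks. A small sketch (the path and landmark array are placeholders):

    import numpy as np
    s = Sample(sample_type=SampleType.FACE, filename="faces/0001.jpg",
               landmarks=np.zeros((68, 2)), yaw=0.3)
    m = s.copy_and_set(mirror=True, yaw=-s.yaw)   # unset fields carry over
    assert m.filename == s.filename and m.yaw == -0.3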

View file

@@ -4,22 +4,21 @@ from pathlib import Path

You can implement your own SampleGenerator
'''
class SampleGeneratorBase(object):

    def __init__ (self, samples_path, debug, batch_size):
        if samples_path is None:
            raise Exception('samples_path is None')

        self.samples_path = Path(samples_path)
        self.debug = debug
        self.batch_size = 1 if self.debug else batch_size

    #overridable
    def __iter__(self):
        #implement your own iterator
        return self

    def __next__(self):
        #implement your own iterator
        return None
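
Note: the base class defines an infinite iterator protocol: trainers call next() on a generator once per step, and a subclass must never raise StopIteration. A minimal subclass sketch (the constant black batch is a placeholder payload, not a real generator from this repo):

    import numpy as np

    class ZeroSampleGenerator(SampleGeneratorBase):
        #hypothetical example: yields black 64x64 BGR batches forever
        def __next__(self):
            return [ np.zeros ( (self.batch_size, 64, 64, 3), dtype=np.float32 ) ]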

View file

@@ -12,9 +12,9 @@ from samples import SampleLoader
from samples import SampleGeneratorBase

'''
arg
output_sample_types = [
                        [SampleProcessor.TypeFlags, size, (optional)random_sub_size] ,
                        ...
                      ]
'''

@@ -26,7 +26,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
        self.add_sample_idx = add_sample_idx
        self.add_pitch = add_pitch
        self.add_yaw = add_yaw

        if sort_by_yaw_target_samples_path is not None:
            self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET
        elif sort_by_yaw:

@@ -34,9 +34,9 @@ class SampleGeneratorFace(SampleGeneratorBase):
        elif with_close_to_self:
            self.sample_type = SampleType.FACE_WITH_CLOSE_TO_SELF
        else:
            self.sample_type = SampleType.FACE

        self.samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)

        if self.debug:
            self.generators_count = 1

@@ -46,24 +46,24 @@ class SampleGeneratorFace(SampleGeneratorBase):
            self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, i ) for i in range(self.generators_count) ]

        self.generators_sq = [ multiprocessing.Queue() for _ in range(self.generators_count) ]
        self.generator_counter = -1

    def __iter__(self):
        return self

    def __next__(self):
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators) ]
        return next(generator)

    #forces to repeat these sample idxs as fast as possible
    #currently unused
    def repeat_sample_idxs(self, idxs): # [ idx, ... ]
        #send idxs list to all sub generators.
        for gen_sq in self.generators_sq:
            gen_sq.put (idxs)

    def batch_func(self, generator_id):
        gen_sq = self.generators_sq[generator_id]
        samples = self.samples

@@ -73,11 +73,11 @@ class SampleGeneratorFace(SampleGeneratorBase):
        if len(samples_idxs) == 0:
            raise ValueError('No training data provided.')

        if self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
            if all ( [ samples[idx] == None for idx in samples_idxs] ):
                raise ValueError('Not enough training data. Gather more faces!')

        if self.sample_type == SampleType.FACE or self.sample_type == SampleType.FACE_WITH_CLOSE_TO_SELF:
            shuffle_idxs = []
        elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:

@@ -89,25 +89,25 @@ class SampleGeneratorFace(SampleGeneratorBase):
                idxs = gen_sq.get()
                for idx in idxs:
                    if idx in samples_idxs:
                        repeat_samples_idxs.append(idx)

            batches = None
            for n_batch in range(self.batch_size):
                while True:
                    sample = None

                    if len(repeat_samples_idxs) > 0:
                        idx = repeat_samples_idxs.pop()
                        if self.sample_type == SampleType.FACE or self.sample_type == SampleType.FACE_WITH_CLOSE_TO_SELF:
                            sample = samples[idx]
                        elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
                            sample = samples[(idx >> 16) & 0xFFFF][idx & 0xFFFF]
                    else:
                        if self.sample_type == SampleType.FACE or self.sample_type == SampleType.FACE_WITH_CLOSE_TO_SELF:
                            if len(shuffle_idxs) == 0:
                                shuffle_idxs = samples_idxs.copy()
                                np.random.shuffle(shuffle_idxs)

                            idx = shuffle_idxs.pop()
                            sample = samples[ idx ]

@@ -120,18 +120,18 @@ class SampleGeneratorFace(SampleGeneratorBase):
                            if samples[idx] != None:
                                if len(shuffle_idxs_2D[idx]) == 0:
                                    shuffle_idxs_2D[idx] = random.sample( range(len(samples[idx])), len(samples[idx]) )

                                idx2 = shuffle_idxs_2D[idx].pop()
                                sample = samples[idx][idx2]

                                idx = (idx << 16) | (idx2 & 0xFFFF)

                    if sample is not None:
                        try:
                            x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug)
                        except:
                            raise Exception ("Exception occurred in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )

                        if type(x) != tuple and type(x) != list:
                            raise Exception('SampleProcessor.process returns NOT tuple/list')

@@ -144,23 +144,23 @@ class SampleGeneratorFace(SampleGeneratorBase):
                            batches += [ [] ]
                            i_pitch = len(batches)-1
                            if self.add_yaw:
                                batches += [ [] ]
                                i_yaw = len(batches)-1

                        for i in range(len(x)):
                            batches[i].append ( x[i] )

                        if self.add_sample_idx:
                            batches[i_sample_idx].append (idx)

                        if self.add_pitch or self.add_yaw:
                            pitch, yaw = LandmarksProcessor.estimate_pitch_yaw (sample.landmarks)

                        if self.add_pitch:
                            batches[i_pitch].append ([pitch])

                        if self.add_yaw:
                            batches[i_yaw].append ([yaw])

                        break

            yield [ np.array(batch) for batch in batches]
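
Note: for yaw-sorted sample sets a single integer index carries two coordinates: the yaw bin in the high 16 bits and the position inside that bin in the low 16 bits, which is exactly what (idx >> 16) & 0xFFFF and idx & 0xFFFF unpack above. A quick illustration:

    bin_idx, in_bin_idx = 37, 5
    packed = (bin_idx << 16) | (in_bin_idx & 0xFFFF)
    assert ((packed >> 16) & 0xFFFF, packed & 0xFFFF) == (37, 5)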

View file

@@ -11,36 +11,36 @@ from samples import SampleLoader
from samples import SampleGeneratorBase

'''
output_sample_types = [
                        [SampleProcessor.TypeFlags, size, (optional)random_sub_size] ,
                        ...
                      ]
'''
class SampleGeneratorImageTemporal(SampleGeneratorBase):
    def __init__ (self, samples_path, debug, batch_size, temporal_image_count, sample_process_options=SampleProcessor.Options(), output_sample_types=[], **kwargs):
        super().__init__(samples_path, debug, batch_size)

        self.temporal_image_count = temporal_image_count
        self.sample_process_options = sample_process_options
        self.output_sample_types = output_sample_types

        self.samples = SampleLoader.load (SampleType.IMAGE, self.samples_path)

        self.generator_samples = [ self.samples ]
        self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )] if self.debug else \
                          [iter_utils.SubprocessGenerator ( self.batch_func, 0 )]

        self.generator_counter = -1

    def __iter__(self):
        return self

    def __next__(self):
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators) ]
        return next(generator)

    def batch_func(self, generator_id):
        samples = self.generator_samples[generator_id]
        samples_len = len(samples)
        if samples_len == 0:

@@ -48,20 +48,20 @@ class SampleGeneratorImageTemporal(SampleGeneratorBase):
        if samples_len - self.temporal_image_count < 0:
            raise ValueError('Not enough samples to fit temporal line.')

        shuffle_idxs = []
        samples_sub_len = samples_len - self.temporal_image_count + 1

        while True:
            batches = None
            for n_batch in range(self.batch_size):
                if len(shuffle_idxs) == 0:
                    shuffle_idxs = random.sample( range(samples_sub_len), samples_sub_len )

                idx = shuffle_idxs.pop()

                temporal_samples = []
                for i in range( self.temporal_image_count ):

@@ -70,11 +70,11 @@ class SampleGeneratorImageTemporal(SampleGeneratorBase):
                        temporal_samples += SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug)
                    except:
                        raise Exception ("Exception occurred in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )

                if batches is None:
                    batches = [ [] for _ in range(len(temporal_samples)) ]

                for i in range(len(temporal_samples)):
                    batches[i].append ( temporal_samples[i] )

            yield [ np.array(batch) for batch in batches]
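
Note: with N images and windows of T consecutive frames, only the first N - T + 1 start positions yield a full window, which is exactly the samples_sub_len computed above. For example:

    samples_len, temporal_image_count = 100, 5
    samples_sub_len = samples_len - temporal_image_count + 1    # 96 valid window starts
    first = list(range(0, temporal_image_count))                # frames 0..4
    last  = list(range(samples_sub_len - 1, samples_len))       # frames 95..99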

View file

@@ -17,44 +17,44 @@ from interact import interact as io

class SampleLoader:
    cache = dict()

    @staticmethod
    def load(sample_type, samples_path, target_samples_path=None):
        cache = SampleLoader.cache

        if str(samples_path) not in cache.keys():
            cache[str(samples_path)] = [None]*SampleType.QTY

        datas = cache[str(samples_path)]

        if sample_type == SampleType.IMAGE:
            if datas[sample_type] is None:
                datas[sample_type] = [ Sample(filename=filename) for filename in io.progress_bar_generator( Path_utils.get_image_paths(samples_path), "Loading") ]
        elif sample_type == SampleType.FACE:
            if datas[sample_type] is None:
                datas[sample_type] = SampleLoader.upgradeToFaceSamples( [ Sample(filename=filename) for filename in Path_utils.get_image_paths(samples_path) ] )
        elif sample_type == SampleType.FACE_YAW_SORTED:
            if datas[sample_type] is None:
                datas[sample_type] = SampleLoader.upgradeToFaceYawSortedSamples( SampleLoader.load(SampleType.FACE, samples_path) )
        elif sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
            if datas[sample_type] is None:
                if target_samples_path is None:
                    raise Exception('target_samples_path is None for FACE_YAW_SORTED_AS_TARGET')
                datas[sample_type] = SampleLoader.upgradeToFaceYawSortedAsTargetSamples( SampleLoader.load(SampleType.FACE_YAW_SORTED, samples_path), SampleLoader.load(SampleType.FACE_YAW_SORTED, target_samples_path) )
        elif sample_type == SampleType.FACE_WITH_CLOSE_TO_SELF:
            if datas[sample_type] is None:
                datas[sample_type] = SampleLoader.upgradeToFaceCloseToSelfSamples( SampleLoader.load(SampleType.FACE, samples_path) )

        return datas[sample_type]

    @staticmethod
    def upgradeToFaceSamples ( samples ):
        sample_list = []

        for s in io.progress_bar_generator(samples, "Loading"):
            s_filename_path = Path(s.filename)
            try:

@@ -64,57 +64,57 @@ class SampleLoader:
                    dflimg = DFLJPG.load ( str(s_filename_path) )
                else:
                    dflimg = None

                if dflimg is None:
                    print ("%s is not a dfl image file required for training" % (s_filename_path.name) )
                    continue

                pitch, yaw = LandmarksProcessor.estimate_pitch_yaw ( dflimg.get_landmarks() )

                sample_list.append( s.copy_and_set(sample_type=SampleType.FACE,
                                                   face_type=FaceType.fromString (dflimg.get_face_type()),
                                                   shape=dflimg.get_shape(),
                                                   landmarks=dflimg.get_landmarks(),
                                                   pitch=pitch,
                                                   yaw=yaw) )
            except:
                print ("Unable to load %s , error: %s" % (str(s_filename_path), traceback.format_exc() ) )

        return sample_list

    @staticmethod
    def upgradeToFaceCloseToSelfSamples (samples):
        yaw_samples = SampleLoader.upgradeToFaceYawSortedSamples(samples)
        yaw_samples_len = len(yaw_samples)

        sample_list = []
        for i in io.progress_bar_generator( range(yaw_samples_len), "Sorting"):
            if yaw_samples[i] is not None:
                for s in yaw_samples[i]:
                    s_t = []

                    for n in range(2000):
                        yaw_idx = np.clip ( i-10 +np.random.randint(20), 0, yaw_samples_len-1 )
                        if yaw_samples[yaw_idx] is None:
                            continue

                        yaw_idx_samples_len = len(yaw_samples[yaw_idx])

                        yaw_idx_sample = yaw_samples[yaw_idx][ np.random.randint(yaw_idx_samples_len) ]
                        if s.filename == yaw_idx_sample.filename:
                            continue

                        s_t.append ( yaw_idx_sample )
                        if len(s_t) >= 50:
                            break

                    if len(s_t) == 0:
                        s_t = [s]

                    sample_list.append( s.copy_and_set(close_target_list = s_t) )

        return sample_list

    @staticmethod
    def upgradeToFaceYawSortedSamples( samples ):

@@ -123,50 +123,50 @@ class SampleLoader:
        diff_rot_per_grad = abs(highest_yaw-lowest_yaw) / gradations

        yaws_sample_list = [None]*gradations
        for i in io.progress_bar_generator(range(gradations), "Sorting"):
            yaw = lowest_yaw + i*diff_rot_per_grad
            next_yaw = lowest_yaw + (i+1)*diff_rot_per_grad

            yaw_samples = []
            for s in samples:
                s_yaw = s.yaw
                if (i == 0 and s_yaw < next_yaw) or \
                   (i < gradations-1 and s_yaw >= yaw and s_yaw < next_yaw) or \
                   (i == gradations-1 and s_yaw >= yaw):
                    yaw_samples.append ( s.copy_and_set(sample_type=SampleType.FACE_YAW_SORTED) )

            if len(yaw_samples) > 0:
                yaws_sample_list[i] = yaw_samples

        return yaws_sample_list

    @staticmethod
    def upgradeToFaceYawSortedAsTargetSamples (s, t):
        l = len(s)
        if l != len(t):
            raise Exception('upgradeToFaceYawSortedAsTargetSamples() s_len != t_len')
        b = l // 2

        s_idxs = np.argwhere ( np.array ( [ 1 if x != None else 0 for x in s] ) == 1 )[:,0]
        t_idxs = np.argwhere ( np.array ( [ 1 if x != None else 0 for x in t] ) == 1 )[:,0]

        new_s = [None]*l

        for t_idx in t_idxs:
            search_idxs = []
            for i in range(0,l):
                search_idxs += [t_idx - i, (l-t_idx-1) - i, t_idx + i, (l-t_idx-1) + i]

            for search_idx in search_idxs:
                if search_idx in s_idxs:
                    mirrored = ( t_idx != search_idx and ((t_idx < b and search_idx >= b) or (search_idx < b and t_idx >= b)) )
                    new_s[t_idx] = [ sample.copy_and_set(sample_type=SampleType.FACE_YAW_SORTED_AS_TARGET,
                                                         mirror=True,
                                                         yaw=-sample.yaw,
                                                         landmarks=LandmarksProcessor.mirror_landmarks (sample.landmarks, sample.shape[1] ))
                                     for sample in s[search_idx]
                                   ] if mirrored else s[search_idx]
                    break

        return new_s
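
Note: SampleLoader.cache is a process-wide dict keyed by str(samples_path), with one slot per SampleType, so repeated load() calls for the same folder and type reuse the already-parsed list instead of rescanning the directory. A sketch of the effect (the folder path is a placeholder):

    a = SampleLoader.load(SampleType.FACE, "workspace/data_src/aligned")
    b = SampleLoader.load(SampleType.FACE, "workspace/data_src/aligned")
    assert a is b   # second call is a cache hit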

View file

@@ -13,61 +13,61 @@ class SampleProcessor(object):
        WARPED_TRANSFORMED = 0x00000004,
        TRANSFORMED = 0x00000008,
        LANDMARKS_ARRAY = 0x00000010, #currently unused

        RANDOM_CLOSE = 0x00000020,
        MORPH_TO_RANDOM_CLOSE = 0x00000040,

        FACE_ALIGN_HALF = 0x00000100,
        FACE_ALIGN_FULL = 0x00000200,
        FACE_ALIGN_HEAD = 0x00000400,
        FACE_ALIGN_AVATAR = 0x00000800,

        FACE_MASK_FULL = 0x00001000,
        FACE_MASK_EYES = 0x00002000,

        MODE_BGR = 0x01000000, #BGR
        MODE_G = 0x02000000, #Grayscale
        MODE_GGG = 0x04000000, #3xGrayscale
        MODE_M = 0x08000000, #mask only
        MODE_BGR_SHUFFLE = 0x10000000, #BGR shuffle

    class Options(object):
        def __init__(self, random_flip = True, normalize_tanh = False, rotation_range=[-10,10], scale_range=[-0.05, 0.05], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05]):
            self.random_flip = random_flip
            self.normalize_tanh = normalize_tanh
            self.rotation_range = rotation_range
            self.scale_range = scale_range
            self.tx_range = tx_range
            self.ty_range = ty_range

    @staticmethod
    def process (sample, sample_process_options, output_sample_types, debug):
        sample_bgr = sample.load_bgr()
        h,w,c = sample_bgr.shape

        is_face_sample = sample.landmarks is not None

        if debug and is_face_sample:
            LandmarksProcessor.draw_landmarks (sample_bgr, sample.landmarks, (0, 1, 0))

        close_sample = sample.close_target_list[ np.random.randint(0, len(sample.close_target_list)) ] if sample.close_target_list is not None else None
        close_sample_bgr = close_sample.load_bgr() if close_sample is not None else None

        if debug and close_sample_bgr is not None:
            LandmarksProcessor.draw_landmarks (close_sample_bgr, close_sample.landmarks, (0, 1, 0))

        params = image_utils.gen_warp_params(sample_bgr, sample_process_options.random_flip, rotation_range=sample_process_options.rotation_range, scale_range=sample_process_options.scale_range, tx_range=sample_process_options.tx_range, ty_range=sample_process_options.ty_range )

        images = [[None]*3 for _ in range(30)]

        sample_rnd_seed = np.random.randint(0x80000000)

        outputs = []
        for sample_type in output_sample_types:
            f = sample_type[0]
            size = sample_type[1]
            random_sub_size = 0 if len (sample_type) < 3 else min( sample_type[2] , size)

            if f & SampleProcessor.TypeFlags.SOURCE != 0:
                img_type = 0
            elif f & SampleProcessor.TypeFlags.WARPED != 0:

@@ -77,53 +77,53 @@ class SampleProcessor(object):
            elif f & SampleProcessor.TypeFlags.TRANSFORMED != 0:
                img_type = 3
            elif f & SampleProcessor.TypeFlags.LANDMARKS_ARRAY != 0:
                img_type = 4
            else:
                raise ValueError ('expected SampleTypeFlags type')

            if f & SampleProcessor.TypeFlags.RANDOM_CLOSE != 0:
                img_type += 10
            elif f & SampleProcessor.TypeFlags.MORPH_TO_RANDOM_CLOSE != 0:
                img_type += 20

            face_mask_type = 0
            if f & SampleProcessor.TypeFlags.FACE_MASK_FULL != 0:
                face_mask_type = 1
            elif f & SampleProcessor.TypeFlags.FACE_MASK_EYES != 0:
                face_mask_type = 2

            target_face_type = -1
            if f & SampleProcessor.TypeFlags.FACE_ALIGN_HALF != 0:
                target_face_type = FaceType.HALF
            elif f & SampleProcessor.TypeFlags.FACE_ALIGN_FULL != 0:
                target_face_type = FaceType.FULL
            elif f & SampleProcessor.TypeFlags.FACE_ALIGN_HEAD != 0:
                target_face_type = FaceType.HEAD
            elif f & SampleProcessor.TypeFlags.FACE_ALIGN_AVATAR != 0:
                target_face_type = FaceType.AVATAR

            if img_type == 4:
                l = sample.landmarks
                l = np.concatenate ( [ np.expand_dims(l[:,0] / w,-1), np.expand_dims(l[:,1] / h,-1) ], -1 )
                l = np.clip(l, 0.0, 1.0)
                img = l
            else:
                if images[img_type][face_mask_type] is None:
                    if img_type >= 10 and img_type <= 19: #RANDOM_CLOSE
                        img_type -= 10
                        img = close_sample_bgr
                        cur_sample = close_sample

                    elif img_type >= 20 and img_type <= 29: #MORPH_TO_RANDOM_CLOSE
                        img_type -= 20
                        res = sample.shape[0]

                        s_landmarks = sample.landmarks.copy()
                        d_landmarks = close_sample.landmarks.copy()
                        idxs = list(range(len(s_landmarks)))

                        #remove landmarks near boundaries
                        for i in idxs[:]:
                            s_l = s_landmarks[i]
                            d_l = d_landmarks[i]
                            if s_l[0] < 5 or s_l[1] < 5 or s_l[0] >= res-5 or s_l[1] >= res-5 or \
                               d_l[0] < 5 or d_l[1] < 5 or d_l[0] >= res-5 or d_l[1] >= res-5:

@@ -139,39 +139,39 @@ class SampleProcessor(object):
                                diff_l = np.abs(s_l - s_l_2)
                                if np.sqrt(diff_l.dot(diff_l)) < 5:
                                    idxs.remove(i)
                                    break

                        s_landmarks = s_landmarks[idxs]
                        d_landmarks = d_landmarks[idxs]
                        s_landmarks = np.concatenate ( [s_landmarks, [ [0,0], [ res // 2, 0], [ res-1, 0], [0, res//2], [res-1, res//2] ,[0,res-1] ,[res//2, res-1] ,[res-1,res-1] ] ] )
                        d_landmarks = np.concatenate ( [d_landmarks, [ [0,0], [ res // 2, 0], [ res-1, 0], [0, res//2], [res-1, res//2] ,[0,res-1] ,[res//2, res-1] ,[res-1,res-1] ] ] )
                        img = image_utils.morph_by_points (sample_bgr, s_landmarks, d_landmarks)
                        cur_sample = close_sample
                    else:
                        img = sample_bgr
                        cur_sample = sample

                    if is_face_sample:
                        if face_mask_type == 1:
                            img = np.concatenate( (img, LandmarksProcessor.get_image_hull_mask (img.shape, cur_sample.landmarks) ), -1 )
                        elif face_mask_type == 2:
                            mask = LandmarksProcessor.get_image_eye_mask (img.shape, cur_sample.landmarks)
                            mask = np.expand_dims (cv2.blur (mask, ( w // 32, w // 32 ) ), -1)
                            mask[mask > 0.0] = 1.0
                            img = np.concatenate( (img, mask ), -1 )

                    images[img_type][face_mask_type] = image_utils.warp_by_params (params, img, (img_type==1 or img_type==2), (img_type==2 or img_type==3), img_type != 0, face_mask_type == 0)

                img = images[img_type][face_mask_type]

                if is_face_sample and target_face_type != -1:
                    if target_face_type > sample.face_type:
                        raise Exception ('sample %s type %s does not match model requirement %s. Consider extracting the necessary type of faces.' % (sample.filename, sample.face_type, target_face_type) )
                    img = cv2.warpAffine( img, LandmarksProcessor.get_transform_mat (sample.landmarks, size, target_face_type), (size,size), flags=cv2.INTER_CUBIC )
                else:
                    img = cv2.resize( img, (size,size), cv2.INTER_CUBIC )

                if random_sub_size != 0:
                    sub_size = size - random_sub_size
                    rnd_state = np.random.RandomState (sample_rnd_seed+random_sub_size)
                    start_x = rnd_state.randint(sub_size+1)
                    start_y = rnd_state.randint(sub_size+1)

@@ -195,7 +195,7 @@ class SampleProcessor(object):
                    img = img_mask
                else:
                    raise ValueError ('expected SampleTypeFlags mode')

            if not debug:
                if sample_process_options.normalize_tanh:
                    img = np.clip (img * 2.0 - 1.0, -1.0, 1.0)

@@ -213,6 +213,6 @@ class SampleProcessor(object):
                elif output.shape[2] == 4:
                    result += [output[...,0:3]*output[...,3:4],]

            return result
        else:
            return outputs
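
Note: each output_sample_types entry ORs together one source flag (SOURCE/WARPED/WARPED_TRANSFORMED/TRANSFORMED), optional FACE_ALIGN_* and FACE_MASK_* flags, and one MODE_* flag, plus the output size; process() then returns one array per entry. A sketch of a typical request (the 128px size is illustrative):

    TF = SampleProcessor.TypeFlags
    output_sample_types = [
        [TF.WARPED_TRANSFORMED | TF.FACE_ALIGN_FULL | TF.MODE_BGR, 128],             # warped input
        [TF.TRANSFORMED | TF.FACE_ALIGN_FULL | TF.MODE_BGR, 128],                    # target
        [TF.TRANSFORMED | TF.FACE_ALIGN_FULL | TF.FACE_MASK_FULL | TF.MODE_M, 128],  # target mask
    ]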

View file

@@ -4,4 +4,4 @@ from .SampleLoader import SampleLoader
from .SampleProcessor import SampleProcessor
from .SampleGeneratorBase import SampleGeneratorBase
from .SampleGeneratorFace import SampleGeneratorFace
from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal

View file

@ -11,7 +11,7 @@ class DFLJPG(object):
self.chunks = [] self.chunks = []
self.dfl_dict = None self.dfl_dict = None
self.shape = (0,0,0) self.shape = (0,0,0)
@staticmethod @staticmethod
def load_raw(filename): def load_raw(filename):
try: try:
@ -19,7 +19,7 @@ class DFLJPG(object):
data = f.read() data = f.read()
except: except:
raise FileNotFoundError(data) raise FileNotFoundError(data)
try: try:
inst = DFLJPG() inst = DFLJPG()
inst.data = data inst.data = data
@ -30,23 +30,23 @@ class DFLJPG(object):
while data_counter < inst_length: while data_counter < inst_length:
chunk_m_l, chunk_m_h = struct.unpack ("BB", data[data_counter:data_counter+2]) chunk_m_l, chunk_m_h = struct.unpack ("BB", data[data_counter:data_counter+2])
data_counter += 2 data_counter += 2
if chunk_m_l != 0xFF: if chunk_m_l != 0xFF:
raise ValueError("No Valid JPG info") raise ValueError("No Valid JPG info")
chunk_name = None chunk_name = None
chunk_size = None chunk_size = None
chunk_data = None chunk_data = None
chunk_ex_data = None chunk_ex_data = None
is_unk_chunk = False is_unk_chunk = False
if chunk_m_h & 0xF0 == 0xD0: if chunk_m_h & 0xF0 == 0xD0:
n = chunk_m_h & 0x0F n = chunk_m_h & 0x0F
if n >= 0 and n <= 7: if n >= 0 and n <= 7:
chunk_name = "RST%d" % (n) chunk_name = "RST%d" % (n)
chunk_size = 0 chunk_size = 0
elif n == 0x8: elif n == 0x8:
chunk_name = "SOI" chunk_name = "SOI"
chunk_size = 0 chunk_size = 0
if len(chunks) != 0: if len(chunks) != 0:
@ -54,73 +54,73 @@ class DFLJPG(object):
elif n == 0x9: elif n == 0x9:
chunk_name = "EOI" chunk_name = "EOI"
chunk_size = 0 chunk_size = 0
elif n == 0xA: elif n == 0xA:
chunk_name = "SOS" chunk_name = "SOS"
elif n == 0xB: elif n == 0xB:
chunk_name = "DQT" chunk_name = "DQT"
elif n == 0xD: elif n == 0xD:
chunk_name = "DRI" chunk_name = "DRI"
chunk_size = 2 chunk_size = 2
else: else:
is_unk_chunk = True is_unk_chunk = True
elif chunk_m_h & 0xF0 == 0xC0: elif chunk_m_h & 0xF0 == 0xC0:
n = chunk_m_h & 0x0F n = chunk_m_h & 0x0F
if n == 0: if n == 0:
chunk_name = "SOF0" chunk_name = "SOF0"
elif n == 2: elif n == 2:
chunk_name = "SOF2" chunk_name = "SOF2"
elif n == 4: elif n == 4:
chunk_name = "DHT" chunk_name = "DHT"
else: else:
is_unk_chunk = True is_unk_chunk = True
elif chunk_m_h & 0xF0 == 0xE0: elif chunk_m_h & 0xF0 == 0xE0:
n = chunk_m_h & 0x0F n = chunk_m_h & 0x0F
chunk_name = "APP%d" % (n) chunk_name = "APP%d" % (n)
else: else:
is_unk_chunk = True is_unk_chunk = True
if is_unk_chunk: if is_unk_chunk:
raise ValueError("Unknown chunk %X" % (chunk_m_h) ) raise ValueError("Unknown chunk %X" % (chunk_m_h) )
if chunk_size == None: #variable size if chunk_size == None: #variable size
chunk_size, = struct.unpack (">H", data[data_counter:data_counter+2]) chunk_size, = struct.unpack (">H", data[data_counter:data_counter+2])
chunk_size -= 2 chunk_size -= 2
data_counter += 2 data_counter += 2
if chunk_size > 0: if chunk_size > 0:
chunk_data = data[data_counter:data_counter+chunk_size] chunk_data = data[data_counter:data_counter+chunk_size]
data_counter += chunk_size data_counter += chunk_size
if chunk_name == "SOS": if chunk_name == "SOS":
c = data_counter c = data_counter
while c < inst_length and (data[c] != 0xFF or data[c+1] != 0xD9): while c < inst_length and (data[c] != 0xFF or data[c+1] != 0xD9):
c += 1 c += 1
chunk_ex_data = data[data_counter:c] chunk_ex_data = data[data_counter:c]
data_counter = c data_counter = c
chunks.append ({'name' : chunk_name, chunks.append ({'name' : chunk_name,
'm_h' : chunk_m_h, 'm_h' : chunk_m_h,
'data' : chunk_data, 'data' : chunk_data,
'ex_data' : chunk_ex_data, 'ex_data' : chunk_ex_data,
}) })
inst.chunks = chunks inst.chunks = chunks
return inst return inst
except Exception as e: except Exception as e:
raise Exception ("Corrupted JPG file: %s" % (str(e))) raise Exception ("Corrupted JPG file: %s" % (str(e)))
    @staticmethod
    def load(filename):
        try:
            inst = DFLJPG.load_raw (filename)
            inst.dfl_dict = None

            for chunk in inst.chunks:
                if chunk['name'] == 'APP0':
                    d, c = chunk['data'], 0
                    c, id, _ = struct_unpack (d, c, "=4sB")
                    if id == b"JFIF":
                        c, ver_major, ver_minor, units, Xdensity, Ydensity, Xthumbnail, Ythumbnail = struct_unpack (d, c, "=BBBHHBB")
                        #if units == 0:
@@ -131,22 +131,22 @@ class DFLJPG(object):
                    d, c = chunk['data'], 0
                    c, precision, height, width = struct_unpack (d, c, ">BHH")
                    inst.shape = (height, width, 3)

                elif chunk['name'] == 'APP15':
                    if type(chunk['data']) == bytes:
                        inst.dfl_dict = pickle.loads(chunk['data'])

            if (inst.dfl_dict is not None) and ('face_type' not in inst.dfl_dict.keys()):
                inst.dfl_dict['face_type'] = FaceType.toString (FaceType.FULL)

            if inst.dfl_dict is None:
                return None

            return inst
        except Exception as e:
            print (e)
            return None
    @staticmethod
    def embed_data(filename, face_type=None,
                             landmarks=None,
@@ -155,7 +155,7 @@ class DFLJPG(object):
                             source_landmarks=None,
                             image_to_face_mat=None
                   ):

        inst = DFLJPG.load_raw (filename)
        inst.setDFLDictData ({
                             'face_type': face_type,
@@ -165,41 +165,41 @@ class DFLJPG(object):
                             'source_landmarks': source_landmarks,
                             'image_to_face_mat': image_to_face_mat
                             })

        try:
            with open(filename, "wb") as f:
                f.write ( inst.dump() )
        except:
            raise Exception( 'cannot save %s' % (filename) )
    def dump(self):
        data = b""

        for chunk in self.chunks:
            data += struct.pack ("BB", 0xFF, chunk['m_h'] )

            chunk_data = chunk['data']
            if chunk_data is not None:
                data += struct.pack (">H", len(chunk_data)+2 )
                data += chunk_data

            chunk_ex_data = chunk['ex_data']
            if chunk_ex_data is not None:
                data += chunk_ex_data

        return data

    def get_shape(self):
        return self.shape

    def get_height(self):
        # Note: DFLJPG chunks are plain dicts, so this type check never
        # matches; it looks like a leftover from the PNG loader below and
        # always falls through to 0.
        for chunk in self.chunks:
            if type(chunk) == IHDR:
                return chunk.height
        return 0

    def getDFLDictData(self):
        return self.dfl_dict

    def setDFLDictData (self, dict_data=None):
        self.dfl_dict = dict_data
@@ -211,17 +211,17 @@ class DFLJPG(object):
        last_app_chunk = 0
        for i, chunk in enumerate (self.chunks):
            if chunk['m_h'] & 0xF0 == 0xE0:
                last_app_chunk = i

        dflchunk = {'name' : 'APP15',
                    'm_h' : 0xEF,
                    'data' : pickle.dumps(dict_data),
                    'ex_data' : None,
                    }
        self.chunks.insert (last_app_chunk+1, dflchunk)

    def get_face_type(self): return self.dfl_dict['face_type']
    def get_landmarks(self): return np.array ( self.dfl_dict['landmarks'] )
    def get_source_filename(self): return self.dfl_dict['source_filename']
    def get_source_rect(self): return self.dfl_dict['source_rect']
    def get_source_landmarks(self): return np.array ( self.dfl_dict['source_landmarks'] )
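# Editor's note — a hypothetical round trip through the accessors above; the
# path and arrays are placeholders, and only parameters visible in this diff
# are used:
DFLJPG.embed_data("aligned_face.jpg",
                  face_type="full_face",
                  landmarks=landmarks,               # e.g. a 68x2 landmark array
                  source_landmarks=source_landmarks,
                  image_to_face_mat=mat)

jpg = DFLJPG.load("aligned_face.jpg")   # returns None if no APP15 dict is embedded
if jpg is not None:
    print(jpg.get_shape(), jpg.get_face_type())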

View file

@@ -110,7 +110,7 @@ class Chunk(object):

    def __str__(self):
        return "<Chunk '{name}' length={length} crc={crc:08X}>".format(**self.__dict__)

class IHDR(Chunk):
    """IHDR Chunk
    width, height, bit_depth, color_type, compression_method,
@@ -189,24 +189,24 @@ class IEND(Chunk):

class DFLChunk(Chunk):
    def __init__(self, dict_data=None):
        super().__init__("fcWp")
        self.dict_data = dict_data

    def setDictData(self, dict_data):
        self.dict_data = dict_data

    def getDictData(self):
        return self.dict_data

    @classmethod
    def load(cls, data):
        inst = super().load(data)
        inst.dict_data = pickle.loads( inst.data )
        return inst

    def dump(self):
        self.data = pickle.dumps (self.dict_data)
        return super().dump()
chunk_map = {
    b"IHDR": IHDR,
    b"fcWp": DFLChunk,
@@ -219,7 +219,7 @@ class DFLPNG(object):
        self.length = 0
        self.chunks = []
        self.fcwp_dict = None

    @staticmethod
    def load_raw(filename):
        try:
@@ -227,11 +227,11 @@ class DFLPNG(object):
                data = f.read()
        except:
            raise FileNotFoundError(filename)

        inst = DFLPNG()
        inst.data = data
        inst.length = len(data)

        if data[0:8] != PNG_HEADER:
            msg = "No Valid PNG header"
            raise ValueError(msg)
@@ -244,26 +244,26 @@ class DFLPNG(object):
            chunk = chunk_map.get(chunk_name, Chunk).load(data[chunk_start:chunk_end])
            inst.chunks.append(chunk)
            chunk_start = chunk_end

        return inst
    @staticmethod
    def load(filename):
        try:
            inst = DFLPNG.load_raw (filename)
            inst.fcwp_dict = inst.getDFLDictData()

            if (inst.fcwp_dict is not None) and ('face_type' not in inst.fcwp_dict.keys()):
                inst.fcwp_dict['face_type'] = FaceType.toString (FaceType.FULL)

            if inst.fcwp_dict is None:
                return None

            return inst
        except Exception as e:
            print(e)
            return None
    @staticmethod
    def embed_data(filename, face_type=None,
                             landmarks=None,
@@ -271,7 +271,7 @@ class DFLPNG(object):
                             source_rect=None,
                             source_landmarks=None
                   ):

        inst = DFLPNG.load_raw (filename)
        inst.setDFLDictData ({
                             'face_type': face_type,
@@ -280,7 +280,7 @@ class DFLPNG(object):
                             'source_rect': source_rect,
                             'source_landmarks': source_landmarks
                             })

        try:
            with open(filename, "wb") as f:
                f.write ( inst.dump() )
@@ -292,7 +292,7 @@ class DFLPNG(object):
        for chunk in self.chunks:
            data += chunk.dump()
        return data

    def get_shape(self):
        for chunk in self.chunks:
            if type(chunk) == IHDR:
@@ -301,34 +301,34 @@ class DFLPNG(object):
                h = chunk.height
                return (h,w,c)
        return (0,0,0)

    def get_height(self):
        for chunk in self.chunks:
            if type(chunk) == IHDR:
                return chunk.height
        return 0

    def getDFLDictData(self):
        for chunk in self.chunks:
            if type(chunk) == DFLChunk:
                return chunk.getDictData()
        return None

    def setDFLDictData (self, dict_data=None):
        for chunk in self.chunks:
            if type(chunk) == DFLChunk:
                self.chunks.remove(chunk)
                break

        if dict_data is not None:
            chunk = DFLChunk(dict_data)
            self.chunks.insert(-1, chunk)

    def get_face_type(self): return self.fcwp_dict['face_type']
    def get_landmarks(self): return np.array ( self.fcwp_dict['landmarks'] )
    def get_source_filename(self): return self.fcwp_dict['source_filename']
    def get_source_rect(self): return self.fcwp_dict['source_rect']
    def get_source_landmarks(self): return np.array ( self.fcwp_dict['source_landmarks'] )

    def __str__(self):
        return "<PNG length={length} chunks={}>".format(len(self.chunks), **self.__dict__)

View file

@@ -5,8 +5,8 @@ image_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff"]

def get_image_paths(dir_path, image_extensions=image_extensions):
    dir_path = Path (dir_path)

    result = []
    if dir_path.exists():
        for x in list(scandir(str(dir_path))):
            if any([x.name.lower().endswith(ext) for ext in image_extensions]):
@@ -14,25 +14,25 @@ def get_image_paths(dir_path, image_extensions=image_extensions):
    return result

def get_image_unique_filestem_paths(dir_path, verbose_print_func=None):
    result = get_image_paths(dir_path)
    result_dup = set()

    for f in result[:]:
        f_stem = Path(f).stem
        if f_stem in result_dup:
            result.remove(f)
            if verbose_print_func is not None:
                verbose_print_func ("Duplicate filenames are not allowed, skipping: %s" % Path(f).name )
            continue
        result_dup.add(f_stem)

    return result
def get_all_dir_names_startswith (dir_path, startswith):
    dir_path = Path (dir_path)
    startswith = startswith.lower()

    result = []
    if dir_path.exists():
        for x in list(scandir(str(dir_path))):
            if x.name.lower().startswith(startswith):
@@ -42,7 +42,7 @@ def get_all_dir_names_startswith (dir_path, startswith):

def get_first_file_by_stem (dir_path, stem, exts=None):
    dir_path = Path (dir_path)
    stem = stem.lower()

    if dir_path.exists():
        for x in list(scandir(str(dir_path))):
            if not x.is_file():
@@ -50,5 +50,5 @@ def get_first_file_by_stem (dir_path, stem, exts=None):
            xp = Path(x.path)
            if xp.stem.lower() == stem and (exts is None or xp.suffix.lower() in exts):
                return xp

    return None

View file

@@ -11,7 +11,7 @@ def cv2_imread(filename, flags=cv2.IMREAD_UNCHANGED):
        return cv2.imdecode(numpyarray, flags)
    except:
        return None

def cv2_imwrite(filename, img, *args):
    ret, buf = cv2.imencode( Path(filename).suffix, img, *args)
    if ret:
@@ -19,4 +19,4 @@ def cv2_imwrite(filename, img, *args):
        with open(filename, "wb") as stream:
            stream.write( buf )
    except:
        pass
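# Editor's note — the point of going through imdecode/imencode here is that
# cv2.imread / cv2.imwrite choke on non-ASCII paths on Windows, while plain
# Python file I/O does not. The read half (its first lines are elided above)
# presumably looks like this sketch:
import cv2
import numpy as np

def cv2_imread_sketch(filename, flags=cv2.IMREAD_UNCHANGED):
    try:
        numpyarray = np.fromfile(str(filename), dtype=np.uint8)  # assumed loader
        return cv2.imdecode(numpyarray, flags)
    except Exception:
        return None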

View file

@@ -21,7 +21,7 @@ def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, so
        OpenCV image in BGR color space (the source image)
    target: NumPy array
        OpenCV image in BGR color space (the target image)
    clip: Should components of L*a*b* image be scaled by np.clip before
        converting back to BGR color space?
        If False then components will be min-max scaled appropriately.
        Clipping will keep target image brightness truer to the input.
@@ -32,7 +32,7 @@ def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, so
        aesthetically pleasing results.
        If False then L*a*b* components will be scaled using the reciprocal of
        the scaling factor proposed in the paper. This method seems to produce
        more consistently aesthetically pleasing results.

    Returns:
    -------
@@ -40,13 +40,13 @@ def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, so
        OpenCV image (w, h, 3) NumPy array (uint8)
    """

    # convert the images from the BGR to L*a*b* color space, being
    # sure to use the floating point data type (note: OpenCV
    # expects floats to be 32-bit, so use that instead of 64-bit)
    source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
    target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)

    # compute color statistics for the source and target images
    src_input = source if source_mask is None else source*source_mask
    tgt_input = target if target_mask is None else target*target_mask
@@ -86,7 +86,7 @@ def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, so
    # type
    transfer = cv2.merge([l, a, b])
    transfer = cv2.cvtColor(transfer.astype(np.uint8), cv2.COLOR_LAB2BGR)

    # return the color transferred image
    return transfer
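# Editor's note — the hunks elided above compute per-channel L*a*b* means and
# standard deviations; the core Reinhard update is, in sketch form (based on
# the commonly used color_transfer implementation, not verified against the
# elided lines):
def reinhard_channel(t, s_mean, s_std, t_mean, t_std, preserve_paper=False):
    x = t - t_mean                                   # recentre the target channel
    x = (t_std / s_std) * x if preserve_paper else (s_std / t_std) * x
    return x + s_mean                                # shift to the source mean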
@@ -127,7 +127,7 @@ def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5):
    matched_img[matched_img>1] = 1
    matched_img[matched_img<0] = 0
    return matched_img

def lab_image_stats(image):
    # compute the mean and standard deviation of each channel
    (l, a, b) = cv2.split(image)
@@ -137,7 +137,7 @@ def lab_image_stats(image):
    # return the color statistics
    return (lMean, lStd, aMean, aStd, bMean, bStd)

def _scale_array(arr, clip=True):
    if clip:
        return np.clip(arr, 0, 255)
@@ -145,12 +145,12 @@ def _scale_array(arr, clip=True):
    mn = arr.min()
    mx = arr.max()
    scale_range = (max([mn, 0]), min([mx, 255]))

    if mn < scale_range[0] or mx > scale_range[1]:
        return (scale_range[1] - scale_range[0]) * (arr - mn) / (mx - mn) + scale_range[0]

    return arr
def channel_hist_match(source, template, hist_match_threshold=255, mask=None):
    # Code borrowed from:
    # https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x
@@ -179,22 +179,22 @@ def channel_hist_match(source, template, hist_match_threshold=255, mask=None):
    t_quantiles = 255 * t_quantiles / t_quantiles[-1]
    interp_t_values = np.interp(s_quantiles, t_quantiles, t_values)
    return interp_t_values[bin_idx].reshape(oldshape)

def color_hist_match(src_im, tar_im, hist_match_threshold=255):
    h,w,c = src_im.shape
    # note: images here are BGR, so channel 0 is blue; the R/G/B variable
    # names below are nominal
    matched_R = channel_hist_match(src_im[:,:,0], tar_im[:,:,0], hist_match_threshold, None)
    matched_G = channel_hist_match(src_im[:,:,1], tar_im[:,:,1], hist_match_threshold, None)
    matched_B = channel_hist_match(src_im[:,:,2], tar_im[:,:,2], hist_match_threshold, None)

    to_stack = (matched_R, matched_G, matched_B)
    for i in range(3, c):
        to_stack += ( src_im[:,:,i],)

    matched = np.stack(to_stack, axis=-1).astype(src_im.dtype)
    return matched
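# Editor's note — channel_hist_match implements classic CDF matching (per the
# Stack Overflow link it cites); a compact sketch of the same idea without the
# threshold and mask handling:
import numpy as np

def hist_match_sketch(source, template):
    s_values, bin_idx, s_counts = np.unique(source.ravel(),
                                            return_inverse=True,
                                            return_counts=True)
    t_values, t_counts = np.unique(template.ravel(), return_counts=True)
    s_quantiles = np.cumsum(s_counts).astype(np.float64)
    s_quantiles /= s_quantiles[-1]                 # source CDF
    t_quantiles = np.cumsum(t_counts).astype(np.float64)
    t_quantiles /= t_quantiles[-1]                 # template CDF
    interp = np.interp(s_quantiles, t_quantiles, t_values)
    return interp[bin_idx].reshape(source.shape)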
pil_fonts = {}
def _get_pil_font (font, size):
@@ -204,65 +204,65 @@ def _get_pil_font (font, size):
        if font_str_id not in pil_fonts.keys():
            pil_fonts[font_str_id] = ImageFont.truetype(font + ".ttf", size=size, encoding="unic")
        pil_font = pil_fonts[font_str_id]
        return pil_font
    except:
        return ImageFont.load_default()

def get_text_image( shape, text, color=(1,1,1), border=0.2, font=None):
    try:
        size = shape[1]
        pil_font = _get_pil_font( localization.get_default_ttf_font_name() , size)
        text_width, text_height = pil_font.getsize(text)

        canvas = Image.new('RGB', shape[0:2], (0,0,0) )
        draw = ImageDraw.Draw(canvas)
        offset = ( 0, 0)
        draw.text(offset, text, font=pil_font, fill=tuple((np.array(color)*255).astype(np.int)) )

        result = np.asarray(canvas) / 255

        if shape[2] != 3:
            result = np.concatenate ( (result, np.ones ( (shape[1],) + (shape[0],) + (shape[2]-3,)) ), axis=2 )

        return result
    except:
        return np.zeros ( (shape[1], shape[0], shape[2]), dtype=np.float32 )
def draw_text( image, rect, text, color=(1,1,1), border=0.2, font=None):
    h,w,c = image.shape

    l,t,r,b = rect
    l = np.clip (l, 0, w-1)
    r = np.clip (r, 0, w-1)
    t = np.clip (t, 0, h-1)
    b = np.clip (b, 0, h-1)

    image[t:b, l:r] += get_text_image ( (r-l,b-t,c) , text, color, border, font )

def draw_text_lines (image, rect, text_lines, color=(1,1,1), border=0.2, font=None):
    text_lines_len = len(text_lines)
    if text_lines_len == 0:
        return

    l,t,r,b = rect
    h = b-t
    h_per_line = h // text_lines_len

    for i in range(0, text_lines_len):
        draw_text (image, (l, i*h_per_line, r, (i+1)*h_per_line), text_lines[i], color, border, font)

def get_draw_text_lines ( image, rect, text_lines, color=(1,1,1), border=0.2, font=None):
    image = np.zeros ( image.shape, dtype=np.float )
    draw_text_lines ( image, rect, text_lines, color, border, font)
    return image

def draw_polygon (image, points, color, thickness = 1):
    points_len = len(points)
    for i in range (0, points_len):
        p0 = tuple( points[i] )
        p1 = tuple( points[ (i+1) % points_len] )
        cv2.line (image, p0, p1, color, thickness=thickness)

def draw_rect(image, rect, color, thickness=1):
    l,t,r,b = rect
    draw_polygon (image, [ (l,t), (r,t), (r,b), (l,b ) ], color, thickness)
@@ -272,40 +272,40 @@ def rectContains(rect, point) :
def applyAffineTransform(src, srcTri, dstTri, size) :
    warpMat = cv2.getAffineTransform( np.float32(srcTri), np.float32(dstTri) )
    return cv2.warpAffine( src, warpMat, (size[0], size[1]), None, flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101 )

def morphTriangle(dst_img, src_img, st, dt) :
    (h,w,c) = dst_img.shape
    sr = np.array( cv2.boundingRect(np.float32(st)) )
    dr = np.array( cv2.boundingRect(np.float32(dt)) )
    sRect = st - sr[0:2]
    dRect = dt - dr[0:2]
    d_mask = np.zeros((dr[3], dr[2], c), dtype = np.float32)
    cv2.fillConvexPoly(d_mask, np.int32(dRect), (1.0,)*c, 8, 0)
    imgRect = src_img[sr[1]:sr[1] + sr[3], sr[0]:sr[0] + sr[2]]
    size = (dr[2], dr[3])
    warpImage1 = applyAffineTransform(imgRect, sRect, dRect, size)

    if c == 1:
        warpImage1 = np.expand_dims( warpImage1, -1 )

    dst_img[dr[1]:dr[1]+dr[3], dr[0]:dr[0]+dr[2]] = dst_img[dr[1]:dr[1]+dr[3], dr[0]:dr[0]+dr[2]]*(1-d_mask) + warpImage1 * d_mask

def morph_by_points (image, sp, dp):
    if sp.shape != dp.shape:
        raise ValueError ('morph_by_points() sp.shape != dp.shape')

    (h,w,c) = image.shape
    result_image = np.zeros(image.shape, dtype = image.dtype)

    for tri in Delaunay(dp).simplices:
        morphTriangle(result_image, image, sp[tri], dp[tri])

    return result_image
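# Editor's note — morph_by_points triangulates the destination landmarks
# (scipy's Delaunay), warps each source triangle affinely onto its destination
# triangle, and composites it through a convex-polygon mask. Hypothetical call,
# with sp/dp as matching (N,2) control-point arrays covering the image:
warped = morph_by_points(image, sp, dp)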
def equalize_and_stack_square (images, axis=1):
    max_c = max ([ 1 if len(image.shape) == 2 else image.shape[2] for image in images ] )

    target_wh = 99999
    for i,image in enumerate(images):
        if len(image.shape) == 2:
@@ -313,113 +313,112 @@ def equalize_and_stack_square (images, axis=1):
            c = 1
        else:
            h,w,c = image.shape

        if h < target_wh:
            target_wh = h

        if w < target_wh:
            target_wh = w

    for i,image in enumerate(images):
        if len(image.shape) == 2:
            h,w = image.shape
            c = 1
        else:
            h,w,c = image.shape

        if c < max_c:
            if c == 1:
                if len(image.shape) == 2:
                    image = np.expand_dims ( image, -1 )
                image = np.concatenate ( (image,)*max_c, -1 )
            elif c == 2: #GA
                image = np.expand_dims ( image[...,0], -1 )
                image = np.concatenate ( (image,)*max_c, -1 )
            else:
                image = np.concatenate ( (image, np.ones((h,w,max_c - c))), -1 )

        if h != target_wh or w != target_wh:
            image = cv2.resize ( image, (target_wh, target_wh) )
            h,w,c = image.shape

        images[i] = image

    return np.concatenate ( images, axis = 1 )
def bgr2hsv (img):
    return cv2.cvtColor(img, cv2.COLOR_BGR2HSV)

def hsv2bgr (img):
    return cv2.cvtColor(img, cv2.COLOR_HSV2BGR)

def bgra2hsva (img):
    return np.concatenate ( (cv2.cvtColor(img[...,0:3], cv2.COLOR_BGR2HSV ), np.expand_dims (img[...,3], -1)), -1 )

def bgra2hsva_list (imgs):
    return [ bgra2hsva(img) for img in imgs ]

def hsva2bgra (img):
    return np.concatenate ( (cv2.cvtColor(img[...,0:3], cv2.COLOR_HSV2BGR ), np.expand_dims (img[...,3], -1)), -1 )

def hsva2bgra_list (imgs):
    return [ hsva2bgra(img) for img in imgs ]
def gen_warp_params (source, flip, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05] ):
    h,w,c = source.shape
    if (h != w) or (w != 64 and w != 128 and w != 256 and w != 512 and w != 1024):
        raise ValueError ('TrainingDataGenerator accepts only square power of 2 images.')

    rotation = np.random.uniform( rotation_range[0], rotation_range[1] )
    scale = np.random.uniform(1 +scale_range[0], 1 +scale_range[1])
    tx = np.random.uniform( tx_range[0], tx_range[1] )
    ty = np.random.uniform( ty_range[0], ty_range[1] )

    #random warp by grid
    cell_size = [ w // (2**i) for i in range(1,4) ] [ np.random.randint(3) ]
    cell_count = w // cell_size + 1

    grid_points = np.linspace( 0, w, cell_count)
    mapx = np.broadcast_to(grid_points, (cell_count, cell_count)).copy()
    mapy = mapx.T

    mapx[1:-1,1:-1] = mapx[1:-1,1:-1] + random_utils.random_normal( size=(cell_count-2, cell_count-2) )*(cell_size*0.24)
    mapy[1:-1,1:-1] = mapy[1:-1,1:-1] + random_utils.random_normal( size=(cell_count-2, cell_count-2) )*(cell_size*0.24)

    half_cell_size = cell_size // 2

    mapx = cv2.resize(mapx, (w+cell_size,)*2 )[half_cell_size:-half_cell_size-1,half_cell_size:-half_cell_size-1].astype(np.float32)
    mapy = cv2.resize(mapy, (w+cell_size,)*2 )[half_cell_size:-half_cell_size-1,half_cell_size:-half_cell_size-1].astype(np.float32)

    #random transform
    random_transform_mat = cv2.getRotationMatrix2D((w // 2, w // 2), rotation, scale)
    random_transform_mat[:, 2] += (tx*w, ty*w)

    params = dict()
    params['mapx'] = mapx
    params['mapy'] = mapy
    params['rmat'] = random_transform_mat
    params['w'] = w
    params['flip'] = flip and np.random.randint(10) < 4

    return params

def warp_by_params (params, img, warp, transform, flip, is_border_replicate):
    if warp:
        img = cv2.remap(img, params['mapx'], params['mapy'], cv2.INTER_CUBIC )
    if transform:
        img = cv2.warpAffine( img, params['rmat'], (params['w'], params['w']), borderMode=(cv2.BORDER_REPLICATE if is_border_replicate else cv2.BORDER_CONSTANT), flags=cv2.INTER_CUBIC )
    if flip and params['flip']:
        img = img[:,::-1,:]
    return img
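# Editor's note — gen_warp_params jitters a coarse grid and upsamples it into
# per-pixel remap fields, so the same random warp/transform/flip can be applied
# consistently to an image and its mask. Hypothetical pairing (img/mask are
# placeholders; img must be square with a power-of-2 side):
params = gen_warp_params(img, flip=True)
img_aug  = warp_by_params(params, img,  warp=True, transform=True, flip=True, is_border_replicate=True)
mask_aug = warp_by_params(params, mask, warp=True, transform=True, flip=True, is_border_replicate=False)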
#n_colors = [0..256]
def reduce_colors (img_bgr, n_colors):
    img_rgb = (cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) * 255.0).astype(np.uint8)
    img_rgb_pil = Image.fromarray(img_rgb)
    img_rgb_pil_p = img_rgb_pil.convert('P', palette=Image.ADAPTIVE, colors=n_colors)

    img_rgb_p = img_rgb_pil_p.convert('RGB')
    img_bgr = cv2.cvtColor( np.array(img_rgb_p, dtype=np.float32) / 255.0, cv2.COLOR_RGB2BGR )

    return img_bgr

View file

@@ -5,7 +5,7 @@ import time

class ThisThreadGenerator(object):
    def __init__(self, generator_func, user_param=None):
        super().__init__()
        self.generator_func = generator_func
        self.user_param = user_param
@@ -13,30 +13,30 @@ class ThisThreadGenerator(object):
    def __iter__(self):
        return self

    def __next__(self):
        if not self.initialized:
            self.initialized = True
            self.generator_func = self.generator_func(self.user_param)

        return next(self.generator_func)

class SubprocessGenerator(object):
    def __init__(self, generator_func, user_param=None, prefetch=2):
        super().__init__()
        self.prefetch = prefetch
        self.generator_func = generator_func
        self.user_param = user_param
        self.sc_queue = multiprocessing.Queue()
        self.cs_queue = multiprocessing.Queue()
        self.p = None

    def process_func(self):
        self.generator_func = self.generator_func(self.user_param)
        while True:
            while self.prefetch > -1:
                try:
                    gen_data = next (self.generator_func)
                except StopIteration:
                    self.cs_queue.put (None)
                    return
@@ -47,17 +47,17 @@ class SubprocessGenerator(object):
    def __iter__(self):
        return self

    def __next__(self):
        if self.p is None:
            self.p = multiprocessing.Process(target=self.process_func, args=())
            self.p.daemon = True
            self.p.start()

        gen_data = self.cs_queue.get()
        if gen_data is None:
            self.p.terminate()
            self.p.join()
            raise StopIteration()
        self.sc_queue.put (1)
        return gen_data
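# Editor's note — a hypothetical use of SubprocessGenerator: generator_func
# receives user_param once and yields forever; the first next() lazily spawns
# the worker process, and the two queues implement the prefetch handshake.
def sample_generator(param):
    while True:
        yield build_sample(param)    # build_sample is a placeholder

gen = SubprocessGenerator(sample_generator, user_param=opts, prefetch=2)
first_batch = next(gen)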

View file

@@ -4,12 +4,12 @@ import sys
if sys.platform[0:3] == 'win':
    from ctypes import windll
    from ctypes import wintypes

def set_process_lowest_prio():
    try:
        if sys.platform[0:3] == 'win':
            GetCurrentProcess = windll.kernel32.GetCurrentProcess
            GetCurrentProcess.restype = wintypes.HANDLE
            SetPriorityClass = windll.kernel32.SetPriorityClass
            SetPriorityClass.argtypes = (wintypes.HANDLE, wintypes.DWORD)
            SetPriorityClass ( GetCurrentProcess(), 0x00000040 ) # 0x40 = IDLE_PRIORITY_CLASS
@@ -19,7 +19,7 @@ def set_process_lowest_prio():
            os.nice(20)
    except:
        print("Unable to set lowest process priority")

def set_process_dpi_aware():
    if sys.platform[0:3] == 'win':
        windll.user32.SetProcessDPIAware(True)

View file

@@ -3,12 +3,12 @@ import numpy as np
def random_normal( size=(1,), trunc_val = 2.5 ):
    count = np.array(size).prod()   # total number of samples to draw
    result = np.empty ( (count,) , dtype=np.float32)

    for i in range (count):
        while True:
            x = np.random.normal()
            if x >= -trunc_val and x <= trunc_val:
                break
        result[i] = (x / trunc_val)

    return result.reshape ( size )
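# Editor's note — random_normal rejection-samples a standard normal truncated
# to [-trunc_val, trunc_val], then divides by trunc_val so outputs land in
# [-1, 1]. A vectorized equivalent of the same idea:
import numpy as np

def random_normal_vectorized(size=(1,), trunc_val=2.5):
    out = np.random.normal(size=size)
    bad = np.abs(out) > trunc_val
    while bad.any():                         # redraw only the rejected entries
        out[bad] = np.random.normal(size=int(bad.sum()))
        bad = np.abs(out) > trunc_val
    return (out / trunc_val).astype(np.float32)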

View file

@@ -11,26 +11,26 @@ class suppress_stdout_stderr(object):
        self.old_stdout_fileno = os.dup ( sys.stdout.fileno() )
        self.old_stderr_fileno = os.dup ( sys.stderr.fileno() )

        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        os.dup2 ( self.outnull_file.fileno(), self.old_stdout_fileno_undup )
        os.dup2 ( self.errnull_file.fileno(), self.old_stderr_fileno_undup )

        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        os.dup2 ( self.old_stdout_fileno, self.old_stdout_fileno_undup )
        os.dup2 ( self.old_stderr_fileno, self.old_stderr_fileno_undup )

        os.close ( self.old_stdout_fileno )
        os.close ( self.old_stderr_fileno )

        self.outnull_file.close()
        self.errnull_file.close()
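# Editor's note — because the redirection happens at the file-descriptor level
# (os.dup2), this context manager also silences output written by native code,
# not just Python prints. Typical use, with noisy_native_init as a placeholder:
with suppress_stdout_stderr():
    noisy_native_init()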

View file

@@ -1,6 +1,5 @@
import struct

def struct_unpack(data, counter, fmt):
    fmt_size = struct.calcsize(fmt)
    return (counter+fmt_size,) + struct.unpack (fmt, data[counter:counter+fmt_size])
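# Editor's note — struct_unpack is a cursor-style helper: it returns the
# advanced offset first, followed by the decoded fields. Example with a
# hypothetical buffer:
data = b"DFL!\x02\x00\x01\x00\x02"
c = 0
c, tag, version = struct_unpack(data, c, "=4sB")   # tag=b"DFL!", version=2, c=5
c, width, height = struct_unpack(data, c, ">HH")   # width=1, height=2, c=9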