remove all trailing spaces

camjac251 committed 2021-06-20 03:57:57 -05:00
commit 5c27ef8883
76 changed files with 713 additions and 717 deletions

.vscode/launch.json

@@ -3,9 +3,9 @@
        // Hover to view descriptions of existing attributes.
        // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
        "version": "0.2.0",
        "configurations": [
            {
                "name": "DFL train TEST",
                "subProcess": true,
                "justMyCode": true,
@@ -14,8 +14,8 @@
                "program": "${env:DFL_ROOT}\\main.py",
                "pythonPath": "${env:PYTHONEXECUTABLE}",
                "cwd": "${env:WORKSPACE}",
                "console": "integratedTerminal",
                "args": ["train",
                    "--training-data-src-dir", "${env:WORKSPACE}\\data_src\\aligned",
                    "--training-data-dst-dir", "${env:WORKSPACE}\\data_dst\\aligned",
                    "--model-dir", "${env:WORKSPACE}\\model",

@@ -287,7 +287,7 @@ class DFLJPG(object):
            return None
        return mask_buf

    def get_xseg_mask(self):
        mask_buf = self.dfl_dict.get('xseg_mask',None)
        if mask_buf is None:

@@ -1,7 +1,7 @@
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *

class QCursorDB():
    @staticmethod
    def initialize(cursor_path):

@@ -23,4 +23,3 @@ class QIconDB():
        QIconDB.view_baked = QIcon ( str(icon_path / 'view_baked.png') )
        QIconDB.view_xseg = QIcon ( str(icon_path / 'view_xseg.png') )
        QIconDB.view_xseg_overlay = QIcon ( str(icon_path / 'view_xseg_overlay.png') )

@@ -2,101 +2,101 @@ from localization import system_language

class QStringDB():
    @staticmethod
    def initialize():
        lang = system_language

        if lang not in ['en','ru','zh']:
            lang = 'en'

        QStringDB.btn_poly_color_red_tip = { 'en' : 'Poly color scheme red',
                                             'ru' : 'Красная цветовая схема полигонов',
                                             'zh' : '选区配色方案红色',
                                           }[lang]

        QStringDB.btn_poly_color_green_tip = { 'en' : 'Poly color scheme green',
                                               'ru' : 'Зелёная цветовая схема полигонов',
                                               'zh' : '选区配色方案绿色',
                                             }[lang]

        QStringDB.btn_poly_color_blue_tip = { 'en' : 'Poly color scheme blue',
                                              'ru' : 'Синяя цветовая схема полигонов',
                                              'zh' : '选区配色方案蓝色',
                                            }[lang]

        QStringDB.btn_view_baked_mask_tip = { 'en' : 'View baked mask',
                                              'ru' : 'Посмотреть запечёную маску',
                                              'zh' : '查看遮罩通道',
                                            }[lang]

        QStringDB.btn_view_xseg_mask_tip = { 'en' : 'View trained XSeg mask',
                                             'ru' : 'Посмотреть тренированную XSeg маску',
                                             'zh' : '查看导入后的XSeg遮罩',
                                           }[lang]

        QStringDB.btn_view_xseg_overlay_mask_tip = { 'en' : 'View trained XSeg mask overlay face',
                                                     'ru' : 'Посмотреть тренированную XSeg маску поверх лица',
                                                     'zh' : '查看导入后的XSeg遮罩于脸上方',
                                                   }[lang]

        QStringDB.btn_poly_type_include_tip = { 'en' : 'Poly include mode',
                                                'ru' : 'Режим полигонов - включение',
                                                'zh' : '包含选区模式',
                                              }[lang]

        QStringDB.btn_poly_type_exclude_tip = { 'en' : 'Poly exclude mode',
                                                'ru' : 'Режим полигонов - исключение',
                                                'zh' : '排除选区模式',
                                              }[lang]

        QStringDB.btn_undo_pt_tip = { 'en' : 'Undo point',
                                      'ru' : 'Отменить точку',
                                      'zh' : '撤消点',
                                    }[lang]

        QStringDB.btn_redo_pt_tip = { 'en' : 'Redo point',
                                      'ru' : 'Повторить точку',
                                      'zh' : '重做点',
                                    }[lang]

        QStringDB.btn_delete_poly_tip = { 'en' : 'Delete poly',
                                          'ru' : 'Удалить полигон',
                                          'zh' : '删除选区',
                                        }[lang]

        QStringDB.btn_pt_edit_mode_tip = { 'en' : 'Add/delete point mode ( HOLD CTRL )',
                                           'ru' : 'Режим добавления/удаления точек ( удерживайте CTRL )',
                                           'zh' : '点加/删除模式 ( 按住CTRL )',
                                         }[lang]

        QStringDB.btn_view_lock_center_tip = { 'en' : 'Lock cursor at the center ( HOLD SHIFT )',
                                               'ru' : 'Заблокировать курсор в центре ( удерживайте SHIFT )',
                                               'zh' : '将光标锁定在中心 ( 按住SHIFT )',
                                             }[lang]

        QStringDB.btn_prev_image_tip = { 'en' : 'Save and Prev image\nHold SHIFT : accelerate\nHold CTRL : skip non masked\n',
                                         'ru' : 'Сохранить и предыдущее изображение\nУдерживать SHIFT : ускорить\nУдерживать CTRL : пропустить неразмеченные\n',
                                         'zh' : '保存并转到上一张图片\n按住SHIFT : 加快\n按住CTRL : 跳过未标记的\n',
                                       }[lang]

        QStringDB.btn_next_image_tip = { 'en' : 'Save and Next image\nHold SHIFT : accelerate\nHold CTRL : skip non masked\n',
                                         'ru' : 'Сохранить и следующее изображение\nУдерживать SHIFT : ускорить\nУдерживать CTRL : пропустить неразмеченные\n',
                                         'zh' : '保存并转到下一张图片\n按住SHIFT : 加快\n按住CTRL : 跳过未标记的\n',
                                       }[lang]

        QStringDB.btn_delete_image_tip = { 'en' : 'Move to _trash and Next image\n',
                                           'ru' : 'Переместить в _trash и следующее изображение\n',
                                           'zh' : '移至_trash转到下一张图片 ',
                                         }[lang]

        QStringDB.loading_tip = {'en' : 'Loading',
                                 'ru' : 'Загрузка',
                                 'zh' : '正在载入',
                                }[lang]

        QStringDB.labeled_tip = {'en' : 'labeled',
                                 'ru' : 'размечено',
                                 'zh' : '标记的',
                                }[lang]

@@ -286,7 +286,7 @@ class QCanvasControlsRightBar(QFrame):
        controls_bar_frame2.setFrameShape(QFrame.StyledPanel)
        controls_bar_frame2.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed)
        controls_bar_frame2.setLayout(controls_bar_frame2_l)

        controls_bar_frame1_l = QVBoxLayout()
        controls_bar_frame1_l.addWidget ( btn_poly_color_red )
        controls_bar_frame1_l.addWidget ( btn_poly_color_green )
@@ -297,7 +297,7 @@ class QCanvasControlsRightBar(QFrame):
        controls_bar_frame1.setFrameShape(QFrame.StyledPanel)
        controls_bar_frame1.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed)
        controls_bar_frame1.setLayout(controls_bar_frame1_l)

        controls_bar_frame3_l = QVBoxLayout()
        controls_bar_frame3_l.addWidget ( btn_view_lock_center )
        controls_bar_frame3 = QFrame()
@@ -310,7 +310,7 @@ class QCanvasControlsRightBar(QFrame):
        controls_bar_l.addWidget(controls_bar_frame2)
        controls_bar_l.addWidget(controls_bar_frame1)
        controls_bar_l.addWidget(controls_bar_frame3)

        self.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Expanding )
        self.setLayout(controls_bar_l)
@@ -1342,18 +1342,18 @@ class MainWindow(QXMainWindow):
        self.update_cached_images()
        self.update_preview_bar()

    def trash_current_image(self):
        self.process_next_image()

        img_path = self.image_paths_done.pop(-1)
        img_path = Path(img_path)

        self.trash_dirpath.mkdir(parents=True, exist_ok=True)
        img_path.rename( self.trash_dirpath / img_path.name )

        self.update_cached_images()
        self.update_preview_bar()

    def initialize_ui(self):
        self.canvas = QCanvas()
@@ -1370,10 +1370,10 @@ class MainWindow(QXMainWindow):
        btn_delete_image = QXIconButton(QIconDB.trashcan, QStringDB.btn_delete_image_tip, shortcut='X', click_func=self.trash_current_image)
        btn_delete_image.setIconSize(QUIConfig.preview_bar_icon_q_size)

        pad_image = QWidget()
        pad_image.setFixedSize(QUIConfig.preview_bar_icon_q_size)

        preview_image_bar_frame_l = QHBoxLayout()
        preview_image_bar_frame_l.setContentsMargins(0,0,0,0)
        preview_image_bar_frame_l.addWidget ( pad_image, alignment=Qt.AlignCenter)
@@ -1393,11 +1393,11 @@ class MainWindow(QXMainWindow):
        preview_image_bar_frame2 = QFrame()
        preview_image_bar_frame2.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Fixed )
        preview_image_bar_frame2.setLayout(preview_image_bar_frame2_l)

        preview_image_bar_l = QHBoxLayout()
        preview_image_bar_l.addWidget (preview_image_bar_frame, alignment=Qt.AlignCenter)
        preview_image_bar_l.addWidget (preview_image_bar_frame2)

        preview_image_bar = QFrame()
        preview_image_bar.setFrameShape(QFrame.StyledPanel)
        preview_image_bar.setSizePolicy ( QSizePolicy.Expanding, QSizePolicy.Fixed )

@@ -4,6 +4,6 @@ plugins:
relative_links:
  enabled: true
  collections: true

include:
  - README.md

@@ -2,7 +2,7 @@ import cv2
import numpy as np
from pathlib import Path
from core.interact import interact as io
from core import imagelib
import traceback

def cv2_imread(filename, flags=cv2.IMREAD_UNCHANGED, loader_func=None, verbose=True):
@@ -34,7 +34,6 @@ def cv2_imwrite(filename, img, *args):
def cv2_resize(x, *args, **kwargs):
    h,w,c = x.shape
    x = cv2.resize(x, *args, **kwargs)

    x = imagelib.normalize_channels(x, c)
    return x

@@ -6,13 +6,13 @@ from enum import IntEnum

class SegIEPolyType(IntEnum):
    EXCLUDE = 0
    INCLUDE = 1

class SegIEPoly():
    def __init__(self, type=None, pts=None, **kwargs):
        self.type = type

        if pts is None:
            pts = np.empty( (0,2), dtype=np.float32 )
        else:
@@ -24,12 +24,12 @@ class SegIEPoly():
        return {'type': int(self.type),
                'pts' : self.get_pts(),
               }

    def identical(self, b):
        if self.n != b.n:
            return False
        return (self.pts[0:self.n] == b.pts[0:b.n]).all()

    def get_type(self):
        return self.type
@@ -54,33 +54,33 @@ class SegIEPoly():
            raise ValueError("insert_pt out of range")
        self.pts = np.concatenate( (self.pts[0:n], pt[None,...].astype(np.float32), self.pts[n:]), axis=0)
        self.n_max = self.n = self.n+1

    def remove_pt(self, n):
        if n < 0 or n >= self.n:
            raise ValueError("remove_pt out of range")
        self.pts = np.concatenate( (self.pts[0:n], self.pts[n+1:]), axis=0)
        self.n_max = self.n = self.n-1

    def get_last_point(self):
        return self.pts[self.n-1].copy()

    def get_pts(self):
        return self.pts[0:self.n].copy()

    def get_pts_count(self):
        return self.n

    def set_point(self, id, pt):
        self.pts[id] = pt

    def set_points(self, pts):
        self.pts = np.array(pts)
        self.n_max = self.n = len(pts)

    def mult_points(self, val):
        self.pts *= val

class SegIEPolys():
    def __init__(self):
@@ -91,10 +91,10 @@ class SegIEPolys():
        o_polys_len = len(b.polys)
        if polys_len != o_polys_len:
            return False

        return all ([ a_poly.identical(b_poly) for a_poly, b_poly in zip(self.polys, b.polys) ])

    def add_poly(self, ie_poly_type):
        poly = SegIEPoly(ie_poly_type)
        self.polys.append (poly)
        return poly
@@ -105,22 +105,22 @@ class SegIEPolys():
    def has_polys(self):
        return len(self.polys) != 0

    def get_poly(self, id):
        return self.polys[id]

    def get_polys(self):
        return self.polys

    def get_pts_count(self):
        return sum([poly.get_pts_count() for poly in self.polys])

    def sort(self):
        poly_by_type = { SegIEPolyType.EXCLUDE : [], SegIEPolyType.INCLUDE : [] }

        for poly in self.polys:
            poly_by_type[poly.type].append(poly)

        self.polys = poly_by_type[SegIEPolyType.INCLUDE] + poly_by_type[SegIEPolyType.EXCLUDE]

    def __iter__(self):
@@ -138,11 +138,11 @@ class SegIEPolys():
    def dump(self):
        return {'polys' : [ poly.dump() for poly in self.polys ] }

    def mult_points(self, val):
        for poly in self.polys:
            poly.mult_points(val)

    @staticmethod
    def load(data=None):
        ie_polys = SegIEPolys()
@@ -150,9 +150,9 @@ class SegIEPolys():
        if isinstance(data, list):
            # Backward comp
            ie_polys.polys = [ SegIEPoly(type=type, pts=pts) for (type, pts) in data ]
        elif isinstance(data, dict):
            ie_polys.polys = [ SegIEPoly(**poly_cfg) for poly_cfg in data['polys'] ]

        ie_polys.sort()

        return ie_polys
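
A minimal round-trip sketch of the serialization API above (not part of the commit; it assumes the truncated method in the -91,10 hunk is SegIEPolys.identical):

    polys = SegIEPolys()
    poly = polys.add_poly(SegIEPolyType.INCLUDE)
    poly.set_points([[10, 10], [100, 10], [100, 100]])

    data = polys.dump()               # {'polys': [{'type': 1, 'pts': <ndarray>}]}
    restored = SegIEPolys.load(data)  # load() also accepts the legacy list format
    assert polys.identical(restored)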

@@ -7,7 +7,7 @@ def LinearMotionBlur(image, size, angle):
    k = cv2.warpAffine(k, cv2.getRotationMatrix2D( (size / 2 -0.5 , size / 2 -0.5 ) , angle, 1.0), (size, size) )
    k = k * ( 1.0 / np.sum(k) )
    return cv2.filter2D(image, -1, k)

def blursharpen (img, sharpen_mode=0, kernel_size=3, amount=100):
    if kernel_size % 2 == 0:
        kernel_size += 1

@@ -90,7 +90,7 @@ def color_transfer_mkl(x0, x1):

def color_transfer_idt(i0, i1, bins=256, n_rot=20):
    import scipy.stats

    relaxation = 1 / n_rot
    h,w,c = i0.shape
    h1,w1,c1 = i1.shape

@@ -2,15 +2,15 @@ import numpy as np

def random_crop(img, w, h):
    height, width = img.shape[:2]

    h_rnd = height - h
    w_rnd = width - w

    y = np.random.randint(0, h_rnd) if h_rnd > 0 else 0
    x = np.random.randint(0, w_rnd) if w_rnd > 0 else 0

    return img[y:y+h, x:x+w]  # fixed: the original sliced y:y+height, x:x+width, which never produced a w-by-h crop

def normalize_channels(img, target_channels):
    img_shape_len = len(img.shape)
    if img_shape_len == 2:
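
For reference, a usage sketch of random_crop with the corrected slice (hypothetical input, not from the commit):

    import numpy as np

    img = np.zeros((480, 640, 3), dtype=np.uint8)  # hypothetical 640x480 BGR frame
    patch = random_crop(img, 128, 128)             # random 128x128 window -> shape (128, 128, 3)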

@@ -43,7 +43,7 @@ def sobel(image):
    from skimage.filters.edges import HSOBEL_WEIGHTS
    h1 = np.array(HSOBEL_WEIGHTS)
    h1 /= np.sum(abs(h1))  # normalize h1

    from scipy.ndimage import convolve
    strength2 = np.square(convolve(image, h1.T))
@@ -274,5 +274,5 @@ def estimate_sharpness(image):
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        image = image[...,0]

    return compute(image)

@@ -50,7 +50,7 @@ def apply_random_sharpen( img, chance, kernel_max_size, mask=None, rnd_state=Non
            result = blursharpen(result, 1, sharp_rnd_kernel, rnd_state.randint(10) )
        else:
            result = blursharpen(result, 2, sharp_rnd_kernel, rnd_state.randint(50) )

    if mask is not None:
        result = img*(1-mask) + result*mask
@@ -127,7 +127,7 @@ def apply_random_jpeg_compress( img, chance, mask=None, rnd_state=None ):
        result = img*(1-mask) + result*mask

    return result

def apply_random_overlay_triangle( img, max_alpha, mask=None, rnd_state=None ):
    if rnd_state is None:
        rnd_state = np.random
@@ -136,21 +136,21 @@ def apply_random_overlay_triangle( img, max_alpha, mask=None, rnd_state=None ):
    pt1 = [rnd_state.randint(w), rnd_state.randint(h) ]
    pt2 = [rnd_state.randint(w), rnd_state.randint(h) ]
    pt3 = [rnd_state.randint(w), rnd_state.randint(h) ]

    alpha = rnd_state.uniform()*max_alpha

    tri_mask = cv2.fillPoly( np.zeros_like(img), [ np.array([pt1,pt2,pt3], np.int32) ], (alpha,)*c )

    if rnd_state.randint(2) == 0:
        result = np.clip(img+tri_mask, 0, 1)
    else:
        result = np.clip(img-tri_mask, 0, 1)

    if mask is not None:
        result = img*(1-mask) + result*mask

    return result

def _min_resize(x, m):
    if x.shape[0] < x.shape[1]:
        s0 = m
@@ -161,7 +161,7 @@ def _min_resize(x, m):
    new_max = min(s1, s0)
    raw_max = min(x.shape[0], x.shape[1])
    return cv2.resize(x, (s1, s0), interpolation=cv2.INTER_LANCZOS4)

def _d_resize(x, d, fac=1.0):
    new_min = min(int(d[1] * fac), int(d[0] * fac))
    raw_min = min(x.shape[0], x.shape[1])
@@ -171,7 +171,7 @@ def _d_resize(x, d, fac=1.0):
        interpolation = cv2.INTER_LANCZOS4
    y = cv2.resize(x, (int(d[1] * fac), int(d[0] * fac)), interpolation=interpolation)
    return y

def _get_image_gradient(dist):
    cols = cv2.filter2D(dist, cv2.CV_32F, np.array([[-1, 0, +1], [-2, 0, +2], [-1, 0, +1]]))
    rows = cv2.filter2D(dist, cv2.CV_32F, np.array([[-1, -2, -1], [0, 0, 0], [+1, +2, +1]]))
@@ -211,24 +211,24 @@ def _generate_lighting_effects(content):
    coarse_effect_rows = (coarse_effect_rows + EPS) / (max_effect + EPS)

    return np.stack([ np.zeros_like(coarse_effect_rows), coarse_effect_rows, coarse_effect_cols], axis=-1)

def apply_random_relight(img, mask=None, rnd_state=None):
    if rnd_state is None:
        rnd_state = np.random

    def_img = img

    if rnd_state.randint(2) == 0:
        light_pos_y = 1.0 if rnd_state.randint(2) == 0 else -1.0
        light_pos_x = rnd_state.uniform()*2-1.0
    else:
        light_pos_y = rnd_state.uniform()*2-1.0
        light_pos_x = 1.0 if rnd_state.randint(2) == 0 else -1.0

    light_source_height = 0.3*rnd_state.uniform()*0.7
    light_intensity = 1.0+rnd_state.uniform()
    ambient_intensity = 0.5

    light_source_location = np.array([[[light_source_height, light_pos_y, light_pos_x ]]], dtype=np.float32)
    light_source_direction = light_source_location / np.sqrt(np.sum(np.square(light_source_location)))
@@ -238,8 +238,8 @@ def apply_random_relight(img, mask=None, rnd_state=None):
    result = def_img * (ambient_intensity + lighting_effect * light_intensity) #light_source_color
    result = np.clip(result, 0, 1)

    if mask is not None:
        result = def_img*(1-mask) + result*mask

    return result

@@ -14,12 +14,12 @@ def dist_to_edges(pts, pt, is_closed=False):
    pa = pt-a
    ba = b-a

    div = np.einsum('ij,ij->i', ba, ba)
    div[div==0]=1
    h = np.clip( np.einsum('ij,ij->i', pa, ba) / div, 0, 1 )

    x = npla.norm ( pa - ba*h[...,None], axis=1 )

    return x, a+ba*h[...,None]
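
A usage sketch of dist_to_edges (not from the commit; it assumes edges are built from consecutive vertex pairs, with a wrap-around edge when is_closed=True):

    import numpy as np

    square = np.float32([[0,0], [10,0], [10,10], [0,10]])  # polygon vertices
    dists, proj = dist_to_edges(square, np.float32([5, -3]), is_closed=True)
    nearest = proj[int(np.argmin(dists))]  # closest point on the nearest edge -> [5, 0]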

@@ -9,10 +9,10 @@ from numpy import linalg as npla

def vector2_dot(a,b):
    return a[...,0]*b[...,0]+a[...,1]*b[...,1]

def vector2_dot2(a):
    return a[...,0]*a[...,0]+a[...,1]*a[...,1]

def vector2_cross(a,b):
    return a[...,0]*b[...,1]-a[...,1]*b[...,0]
@@ -20,51 +20,51 @@ def vector2_cross(a,b):

def circle_faded( wh, center, fade_dists ):
    """
    returns drawn circle in [h,w,1] output range [0..1.0] float32

    wh         = [w,h]                   resolution
    center     = [x,y]                   center of circle
    fade_dists = [fade_start, fade_end]  fade values
    """
    w,h = wh

    pts = np.empty( (h,w,2), dtype=np.float32 )
    pts[...,0] = np.arange(w)[:,None]
    pts[...,1] = np.arange(h)[None,:]

    pts = pts.reshape ( (h*w, -1) )

    pts_dists = np.abs ( npla.norm(pts-center, axis=-1) )

    if fade_dists[1] == 0:
        fade_dists[1] = 1

    pts_dists = ( pts_dists - fade_dists[0] ) / fade_dists[1]

    pts_dists = np.clip( 1-pts_dists, 0, 1)

    return pts_dists.reshape ( (h,w,1) ).astype(np.float32)

def bezier( wh, A, B, C ):
    """
    returns drawn bezier in [h,w,1] output range float32,
    every pixel contains signed distance to bezier line

    wh      [w,h]        resolution
    A,B,C   points [x,y]
    """
    width,height = wh

    A = np.float32(A)
    B = np.float32(B)
    C = np.float32(C)

    pos = np.empty( (height,width,2), dtype=np.float32 )
    pos[...,0] = np.arange(width)[:,None]
    pos[...,1] = np.arange(height)[None,:]

    a = B-A
    b = A - 2.0*B + C
@@ -74,7 +74,7 @@ def bezier( wh, A, B, C ):
    b_dot = vector2_dot(b,b)
    if b_dot == 0.0:
        return np.zeros( (height,width), dtype=np.float32 )

    kk = 1.0 / b_dot

    kx = kk * vector2_dot(a,b)
@@ -85,72 +85,72 @@ def bezier( wh, A, B, C ):
    sgn = 0.0;
    p = ky - kx*kx;
    p3 = p*p*p;
    q = kx*(2.0*kx*kx - 3.0*ky) + kz;
    h = q*q + 4.0*p3;

    hp_sel = h >= 0.0

    hp_p = h[hp_sel]
    hp_p = np.sqrt(hp_p)

    hp_x = ( np.stack( (hp_p,-hp_p), -1) -q[hp_sel,None] ) / 2.0
    hp_uv = np.sign(hp_x) * np.power( np.abs(hp_x), [1.0/3.0, 1.0/3.0] )
    hp_t = np.clip( hp_uv[...,0] + hp_uv[...,1] - kx, 0.0, 1.0 )

    hp_t = hp_t[...,None]
    hp_q = d[hp_sel]+(c+b*hp_t)*hp_t
    hp_res = vector2_dot2(hp_q)
    hp_sgn = vector2_cross(c+2.0*b*hp_t,hp_q)

    hl_sel = h < 0.0

    hl_q = q[hl_sel]
    hl_p = p[hl_sel]
    hl_z = np.sqrt(-hl_p)
    hl_v = np.arccos( hl_q / (hl_p*hl_z*2.0)) / 3.0

    hl_m = np.cos(hl_v)
    hl_n = np.sin(hl_v)*1.732050808;

    hl_t = np.clip( np.stack( (hl_m+hl_m,-hl_n-hl_m,hl_n-hl_m), -1)*hl_z[...,None]-kx, 0.0, 1.0 );

    hl_d = d[hl_sel]

    hl_qx = hl_d+(c+b*hl_t[...,0:1])*hl_t[...,0:1]
    hl_dx = vector2_dot2(hl_qx)
    hl_sx = vector2_cross(c+2.0*b*hl_t[...,0:1], hl_qx)

    hl_qy = hl_d+(c+b*hl_t[...,1:2])*hl_t[...,1:2]
    hl_dy = vector2_dot2(hl_qy)
    hl_sy = vector2_cross(c+2.0*b*hl_t[...,1:2],hl_qy);

    hl_dx_l_dy = hl_dx<hl_dy
    hl_dx_ge_dy = hl_dx>=hl_dy

    hl_res = np.empty_like(hl_dx)
    hl_res[hl_dx_l_dy] = hl_dx[hl_dx_l_dy]
    hl_res[hl_dx_ge_dy] = hl_dy[hl_dx_ge_dy]

    hl_sgn = np.empty_like(hl_sx)
    hl_sgn[hl_dx_l_dy] = hl_sx[hl_dx_l_dy]
    hl_sgn[hl_dx_ge_dy] = hl_sy[hl_dx_ge_dy]

    res = np.empty( (height, width), np.float32 )
    res[hp_sel] = hp_res
    res[hl_sel] = hl_res

    sgn = np.empty( (height, width), np.float32 )
    sgn[hp_sel] = hp_sgn
    sgn[hl_sel] = hl_sgn

    sgn = np.sign(sgn)
    res = np.sqrt(res)*sgn

    return res[...,None]

def random_faded(wh):
    """
    apply one of them:
@@ -162,39 +162,39 @@ def random_faded(wh):
        return random_circle_faded(wh)
    elif rnd == 1:
        return random_bezier_split_faded(wh)

def random_circle_faded ( wh, rnd_state=None ):
    if rnd_state is None:
        rnd_state = np.random

    w,h = wh
    wh_max = max(w,h)
    fade_start = rnd_state.randint(wh_max)
    fade_end = fade_start + rnd_state.randint(wh_max- fade_start)

    return circle_faded (wh, [ rnd_state.randint(h), rnd_state.randint(w) ],
                             [fade_start, fade_end] )

def random_bezier_split_faded( wh ):
    width, height = wh

    degA = np.random.randint(360)
    degB = np.random.randint(360)
    degC = np.random.randint(360)

    deg_2_rad = math.pi / 180.0

    center = np.float32([width / 2.0, height / 2.0])

    radius = max(width, height)

    A = center + radius*np.float32([ math.sin( degA * deg_2_rad), math.cos( degA * deg_2_rad) ] )
    B = center + np.random.randint(radius)*np.float32([ math.sin( degB * deg_2_rad), math.cos( degB * deg_2_rad) ] )
    C = center + radius*np.float32([ math.sin( degC * deg_2_rad), math.cos( degC * deg_2_rad) ] )

    x = bezier( (width,height), A, B, C )

    x = x / (1+np.random.randint(radius)) + 0.5

    x = np.clip(x, 0, 1)
    return x
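
These helpers synthesize soft masks for training augmentation; a minimal sketch of calling the two random generators above (not from the commit):

    mask = random_circle_faded([256, 256])         # float32 (256,256,1) in [0,1], random center and fade band
    split = random_bezier_split_faded((256, 256))  # float32 in [0,1]; ~0.5 along the random curve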

@@ -7,10 +7,10 @@ def gen_warp_params (w, flip=False, rotation_range=[-10,10], scale_range=[-0.5,
        rnd_state = np.random

    rw = None
    if w < 64:
        rw = w
        w = 64

    rotation = rnd_state.uniform( rotation_range[0], rotation_range[1] )
    scale = rnd_state.uniform(1 +scale_range[0], 1 +scale_range[1])
    tx = rnd_state.uniform( tx_range[0], tx_range[1] )
@@ -52,19 +52,19 @@ def gen_warp_params (w, flip=False, rotation_range=[-10,10], scale_range=[-0.5,

def warp_by_params (params, img, can_warp, can_transform, can_flip, border_replicate, cv2_inter=cv2.INTER_CUBIC):
    rw = params['rw']

    if (can_warp or can_transform) and rw is not None:
        img = cv2.resize(img, (64,64), interpolation=cv2_inter)

    if can_warp:
        img = cv2.remap(img, params['mapx'], params['mapy'], cv2_inter )
    if can_transform:
        img = cv2.warpAffine( img, params['rmat'], (params['w'], params['w']), borderMode=(cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT), flags=cv2_inter )

    if (can_warp or can_transform) and rw is not None:
        img = cv2.resize(img, (rw,rw), interpolation=cv2_inter)

    if len(img.shape) == 2:
        img = img[...,None]
    if can_flip and params['flip']:
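
A minimal sketch of driving the warp pipeline above (not from the commit; keyword names come from the visible signatures, and the truncated defaults are assumed usable):

    import numpy as np

    img = np.random.rand(256, 256, 3).astype(np.float32)  # hypothetical float image in [0,1]
    params = gen_warp_params(256, flip=True)
    warped = warp_by_params(params, img,
                            can_warp=True, can_transform=True,
                            can_flip=True, border_replicate=True)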

@@ -502,7 +502,7 @@ class InteractDesktop(InteractBase):
        if has_windows or has_capture_keys:
            wait_key_time = max(1, int(sleep_time*1000) )
            ord_key = cv2.waitKeyEx(wait_key_time)

            shift_pressed = False
            if ord_key != -1:
                chr_key = chr(ord_key) if ord_key <= 255 else chr(0)

@@ -6,16 +6,16 @@ class MPClassFuncOnDemand():
        self.class_handle = class_handle
        self.class_func_name = class_func_name
        self.class_kwargs = class_kwargs

        self.class_func = None

        self.s2c = multiprocessing.Queue()
        self.c2s = multiprocessing.Queue()
        self.lock = multiprocessing.Lock()

        io.add_process_messages_callback(self.io_callback)

    def io_callback(self):
        while not self.c2s.empty():
            func_args, func_kwargs = self.c2s.get()
            if self.class_func is None:

@@ -4,14 +4,14 @@ from core.interact import interact as io

class MPFunc():
    def __init__(self, func):
        self.func = func

        self.s2c = multiprocessing.Queue()
        self.c2s = multiprocessing.Queue()
        self.lock = multiprocessing.Lock()

        io.add_process_messages_callback(self.io_callback)

    def io_callback(self):
        while not self.c2s.empty():
            func_args, func_kwargs = self.c2s.get()
            self.s2c.put ( self.func (*func_args, **func_kwargs) )

@@ -5,11 +5,11 @@ import time

class SubprocessGenerator(object):

    @staticmethod
    def launch_thread(generator):
        generator._start()

    @staticmethod
    def start_in_parallel( generator_list ):
        """
@@ -22,7 +22,7 @@ class SubprocessGenerator(object):
        while not all ([generator._is_started() for generator in generator_list]):
            time.sleep(0.005)

    def __init__(self, generator_func, user_param=None, prefetch=2, start_now=True):
        super().__init__()
        self.prefetch = prefetch
@@ -42,10 +42,10 @@ class SubprocessGenerator(object):
            p.daemon = True
            p.start()

            self.p = p

    def _is_started(self):
        return self.p is not None

    def process_func(self, user_param):
        self.generator_func = self.generator_func(user_param)
        while True:

@@ -229,9 +229,9 @@ class Subprocessor(object):
                    err_msg = obj.get('err_msg', None)
                    if err_msg is not None:
                        io.log_info(f'Error while processing data: {err_msg}')

                    if 'data' in obj.keys():
                        self.on_data_return (cli.host_dict, obj['data'] )
                    #and killing process
                    cli.kill()
                    self.clis.remove(cli)

@@ -1,17 +1,17 @@
from core.leras import nn

class ArchiBase():

    def __init__(self, *args, name=None, **kwargs):
        self.name=name

    #overridable
    def flow(self, *args, **kwargs):
        raise Exception("this archi does not support flow. Use model classes directly.")

    #overridable
    def get_weights(self):
        pass

nn.ArchiBase = ArchiBase

@@ -76,16 +76,16 @@ class DeepFakeArchi(nn.ArchiBase):
                self.in_ch = in_ch
                self.e_ch = e_ch
                super().__init__(**kwargs)

            def on_build(self):
                self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4, kernel_size=5)

            def forward(self, inp):
                return nn.flatten(self.down1(inp))

            def get_out_res(self, res):
                return res // (2**4)

            def get_out_ch(self):
                return self.e_ch * 8
@@ -203,7 +203,7 @@ class DeepFakeArchi(nn.ArchiBase):
                m = tf.nn.sigmoid(self.out_convm(m))
                return x, m

        self.Encoder = Encoder
        self.Inter = Inter
        self.Decoder = Decoder

@@ -13,7 +13,7 @@ class Device(object):
        self.index = index
        self.tf_dev_type = tf_dev_type
        self.name = name

        self.total_mem = total_mem
        self.total_mem_gb = total_mem / 1024**3
        self.free_mem = free_mem
@@ -91,16 +91,16 @@ class Devices(object):
    @staticmethod
    def _get_tf_devices_proc(q : multiprocessing.Queue):

        if sys.platform[0:3] == 'win':
            compute_cache_path = Path(os.environ['APPDATA']) / 'NVIDIA' / ('ComputeCache_ALL')
            os.environ['CUDA_CACHE_PATH'] = str(compute_cache_path)
            if not compute_cache_path.exists():
                io.log_info("Caching GPU kernels...")
                compute_cache_path.mkdir(parents=True, exist_ok=True)

        import tensorflow

        tf_version = tensorflow.version.VERSION
        #if tf_version is None:
        #    tf_version = tensorflow.version.GIT_VERSION
@@ -110,7 +110,7 @@ class Devices(object):
            tf = tensorflow.compat.v1
        else:
            tf = tensorflow

        import logging
        # Disable tensorflow warnings
        tf_logger = logging.getLogger('tensorflow')
@@ -119,19 +119,19 @@ class Devices(object):
        from tensorflow.python.client import device_lib
        devices = []
        physical_devices = device_lib.list_local_devices()
        physical_devices_f = {}
        for dev in physical_devices:
            dev_type = dev.device_type
            dev_tf_name = dev.name
            dev_tf_name = dev_tf_name[ dev_tf_name.index(dev_type) : ]

            dev_idx = int(dev_tf_name.split(':')[-1])

            if dev_type in ['GPU','DML']:
                dev_name = dev_tf_name

                dev_desc = dev.physical_device_desc
                if len(dev_desc) != 0:
                    if dev_desc[0] == '{':
@@ -146,35 +146,35 @@ class Devices(object):
                        if param == 'name':
                            dev_name = value
                            break

                physical_devices_f[dev_idx] = (dev_type, dev_name, dev.memory_limit)

        q.put(physical_devices_f)
        time.sleep(0.1)

    @staticmethod
    def initialize_main_env():
        if int(os.environ.get("NN_DEVICES_INITIALIZED", 0)) != 0:
            return

        if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
            os.environ.pop('CUDA_VISIBLE_DEVICES')

        os.environ['CUDA_CACHE_MAXSIZE'] = '2147483647'
        os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2'
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # tf log errors only

        q = multiprocessing.Queue()
        p = multiprocessing.Process(target=Devices._get_tf_devices_proc, args=(q,), daemon=True)
        p.start()
        p.join()

        visible_devices = q.get()

        os.environ['NN_DEVICES_INITIALIZED'] = '1'
        os.environ['NN_DEVICES_COUNT'] = str(len(visible_devices))

        for i in visible_devices:
            dev_type, name, total_mem = visible_devices[i]
@@ -182,8 +182,8 @@ class Devices(object):
            os.environ[f'NN_DEVICE_{i}_NAME'] = name
            os.environ[f'NN_DEVICE_{i}_TOTAL_MEM'] = str(total_mem)
            os.environ[f'NN_DEVICE_{i}_FREE_MEM'] = str(total_mem)

    @staticmethod
    def getDevices():
@@ -204,20 +204,20 @@ class Devices(object):
"""
        # {'name'      : name.split(b'\0', 1)[0].decode(),
        #  'total_mem' : totalMem.value
        # }
        return

min_cc = int(os.environ.get("TF_MIN_REQ_CAP", 35))
libnames = ('libcuda.so', 'libcuda.dylib', 'nvcuda.dll')
for libname in libnames:
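
For context, initialize_main_env is meant to run once in the main process before TensorFlow is imported anywhere, so device info is probed in a throwaway subprocess and published via NN_DEVICE_* environment variables; a minimal sketch (not from the commit, import path assumed):

    import os
    from core.leras.device import Devices  # assumed module path

    Devices.initialize_main_env()  # no-op on repeat calls (NN_DEVICES_INITIALIZED guard)
    print(os.environ.get('NN_DEVICES_COUNT'))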

@@ -28,8 +28,8 @@ class Conv2D(nn.LayerBase):
                raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs")
        else:
            padding = int(padding)

        self.in_ch = in_ch
        self.out_ch = out_ch
@@ -85,7 +85,7 @@ class Conv2D(nn.LayerBase):
            else:
                padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
            x = tf.pad (x, padding, mode='CONSTANT')

        strides = self.strides
        if nn.data_format == "NHWC":
            strides = [1,strides,strides,1]
@@ -97,7 +97,7 @@ class Conv2D(nn.LayerBase):
            dilations = [1,dilations,dilations,1]
        else:
            dilations = [1,1,dilations,dilations]

        x = tf.nn.conv2d(x, weight, strides, 'VALID', dilations=dilations, data_format=nn.data_format)
        if self.use_bias:
            if nn.data_format == "NHWC":

@@ -3,7 +3,7 @@ tf = nn.tf
class DenseNorm(nn.LayerBase):
    def __init__(self, dense=False, eps=1e-06, dtype=None, **kwargs):
        self.dense = dense

        if dtype is None:
            dtype = nn.floatx

        self.eps = tf.constant(eps, dtype=dtype, name="epsilon")
@@ -12,5 +12,5 @@ class DenseNorm(nn.LayerBase):
    def __call__(self, x):
        return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=-1, keepdims=True) + self.eps)

nn.DenseNorm = DenseNorm

@@ -5,11 +5,11 @@ class LayerBase(nn.Saveable):
    #override
    def build_weights(self):
        pass

    #override
    def forward(self, *args, **kwargs):
        pass

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

@@ -97,10 +97,10 @@ class Saveable():
                nn.batch_set_value(tuples)
        except:
            return False

        return True

    def init_weights(self):
        nn.init_weights(self.get_weights())

nn.Saveable = Saveable

@@ -2,7 +2,7 @@ from core.leras import nn
tf = nn.tf

class CodeDiscriminator(nn.ModelBase):
    def on_build(self, in_ch, code_res, ch=256, conv_kernel_initializer=None):
        n_downscales = 1 + code_res // 8

        self.convs = []
@@ -18,5 +18,5 @@ class CodeDiscriminator(nn.ModelBase):
        for conv in self.convs:
            x = tf.nn.leaky_relu( conv(x), 0.1 )

        return self.out_conv(x)

nn.CodeDiscriminator = CodeDiscriminator

@@ -111,7 +111,7 @@ class UNetPatchDiscriminator(nn.ModelBase):
                for i in range(layers_count-1):
                    st = 1 + (1 if val & (1 << i) !=0 else 0 )
                    layers.append ( [3, st ])
                    sum_st += st

                rf = self.calc_receptive_field_size(layers)
@@ -131,7 +131,7 @@ class UNetPatchDiscriminator(nn.ModelBase):
        return s[q][2]

    def on_build(self, patch_size, in_ch, base_ch = 16):

        class ResidualBlock(nn.ModelBase):
            def on_build(self, ch, kernel_size=3 ):
                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
@@ -152,7 +152,7 @@ class UNetPatchDiscriminator(nn.ModelBase):
        self.upres1 = []
        self.upres2 = []
        layers = self.find_archi(patch_size)

        level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) }

        self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID')
@@ -162,12 +162,12 @@ class UNetPatchDiscriminator(nn.ModelBase):
            self.res1.append ( ResidualBlock(level_chs[i]) )
            self.res2.append ( ResidualBlock(level_chs[i]) )

            self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME') )

            self.upres1.insert (0, ResidualBlock(level_chs[i-1]*2) )
            self.upres2.insert (0, ResidualBlock(level_chs[i-1]*2) )

        self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID')

        self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID')
@@ -183,7 +183,7 @@ class UNetPatchDiscriminator(nn.ModelBase):
            x = tf.nn.leaky_relu( conv(x), 0.2 )
            x = res1(x)
            x = res2(x)

        center_out, x = self.center_out(x), tf.nn.leaky_relu( self.center_conv(x), 0.2 )

        for i, (upconv, enc, upres1, upres2 ) in enumerate(zip(self.upconvs, encs, self.upres1, self.upres2)):

@@ -2,16 +2,16 @@ from core.leras import nn
tf = nn.tf

class XSeg(nn.ModelBase):

    def on_build (self, in_ch, base_ch, out_ch):

        class ConvBlock(nn.ModelBase):
            def on_build(self, in_ch, out_ch):
                self.conv = nn.Conv2D (in_ch, out_ch, kernel_size=3, padding='SAME')
                self.frn = nn.FRNorm2D(out_ch)
                self.tlu = nn.TLU(out_ch)

            def forward(self, x):
                x = self.conv(x)
                x = self.frn(x)
                x = self.tlu(x)
@@ -28,7 +28,7 @@ class XSeg(nn.ModelBase):
                x = self.frn(x)
                x = self.tlu(x)
                return x

        self.base_ch = base_ch

        self.conv01 = ConvBlock(in_ch, base_ch)
@@ -52,20 +52,20 @@ class XSeg(nn.ModelBase):
        self.conv42 = ConvBlock(base_ch*8, base_ch*8)
        self.conv43 = ConvBlock(base_ch*8, base_ch*8)
        self.bp4 = nn.BlurPool (filt_size=2)

        self.conv51 = ConvBlock(base_ch*8, base_ch*8)
        self.conv52 = ConvBlock(base_ch*8, base_ch*8)
        self.conv53 = ConvBlock(base_ch*8, base_ch*8)
        self.bp5 = nn.BlurPool (filt_size=2)

        self.dense1 = nn.Dense ( 4*4* base_ch*8, 512)
        self.dense2 = nn.Dense ( 512, 4*4* base_ch*8)

        self.up5 = UpConvBlock (base_ch*8, base_ch*4)
        self.uconv53 = ConvBlock(base_ch*12, base_ch*8)
        self.uconv52 = ConvBlock(base_ch*8, base_ch*8)
        self.uconv51 = ConvBlock(base_ch*8, base_ch*8)

        self.up4 = UpConvBlock (base_ch*8, base_ch*4)
        self.uconv43 = ConvBlock(base_ch*12, base_ch*8)
        self.uconv42 = ConvBlock(base_ch*8, base_ch*8)
@@ -88,8 +88,8 @@ class XSeg(nn.ModelBase):
        self.uconv02 = ConvBlock(base_ch*2, base_ch)
        self.uconv01 = ConvBlock(base_ch, base_ch)
        self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME')

    def forward(self, inp):
        x = inp
@@ -119,17 +119,17 @@ class XSeg(nn.ModelBase):
        x = self.conv52(x)
        x = x5 = self.conv53(x)
        x = self.bp5(x)

        x = nn.flatten(x)
        x = self.dense1(x)
        x = self.dense2(x)
        x = nn.reshape_4D (x, 4, 4, self.base_ch*8 )

        x = self.up5(x)
        x = self.uconv53(tf.concat([x,x5],axis=nn.conv2d_ch_axis))
        x = self.uconv52(x)
        x = self.uconv51(x)

        x = self.up4(x)
        x = self.uconv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis))
        x = self.uconv42(x)

View file

@@ -34,13 +34,13 @@ class nn():
    tf_sess = None
    tf_sess_config = None
    tf_default_device_name = None
    data_format = None
    conv2d_ch_axis = None
    conv2d_spatial_axes = None
    floatx = None

    @staticmethod
    def initialize(device_config=None, floatx="float32", data_format="NHWC"):
@@ -67,7 +67,7 @@ class nn():
                first_run = True
            compute_cache_path.mkdir(parents=True, exist_ok=True)
            os.environ['CUDA_CACHE_PATH'] = str(compute_cache_path)

            if first_run:
                io.log_info("Caching GPU kernels...")
@@ -87,7 +87,7 @@ class nn():
            # Disable tensorflow warnings
            tf_logger = logging.getLogger('tensorflow')
            tf_logger.setLevel(logging.ERROR)

            if tf_version[0] == '2':
                tf.disable_v2_behavior()
            nn.tf = tf
@@ -99,21 +99,21 @@ class nn():
            import core.leras.optimizers
            import core.leras.models
            import core.leras.archis

        # Configure tensorflow session-config
        if len(device_config.devices) == 0:
            config = tf.ConfigProto(device_count={'GPU': 0})
            nn.tf_default_device_name = '/CPU:0'
        else:
            nn.tf_default_device_name = f'/{device_config.devices[0].tf_dev_type}:0'
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = ','.join([str(device.index) for device in device_config.devices])

        config.gpu_options.force_gpu_compatible = True
        config.gpu_options.allow_growth = True
        nn.tf_sess_config = config

        if nn.tf_sess is None:
            nn.tf_sess = tf.Session(config=nn.tf_sess_config)
@@ -260,7 +260,7 @@ class nn():
        @staticmethod
        def ask_choose_device(*args, **kwargs):
            return nn.DeviceConfig.GPUIndexes( nn.ask_choose_device_idxs(*args,**kwargs) )

        def __init__ (self, devices=None):
            devices = devices or []

View file

@@ -382,7 +382,7 @@ def total_variation_mse(images):
    """
    pixel_dif1 = images[:, 1:, :, :] - images[:, :-1, :, :]
    pixel_dif2 = images[:, :, 1:, :] - images[:, :, :-1, :]

    tot_var = ( tf.reduce_sum(tf.square(pixel_dif1), axis=[1,2,3]) +
                tf.reduce_sum(tf.square(pixel_dif2), axis=[1,2,3]) )
    return tot_var
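
For intuition, the anisotropic total-variation term above can be checked outside the graph; a minimal NumPy sketch of the same sums (shapes are hypothetical, NHWC layout as in the TF code):

    import numpy as np

    images = np.arange(2*4*4*3, dtype=np.float32).reshape(2, 4, 4, 3)
    pixel_dif1 = images[:, 1:, :, :] - images[:, :-1, :, :]   # vertical neighbor diffs
    pixel_dif2 = images[:, :, 1:, :] - images[:, :, :-1, :]   # horizontal neighbor diffs
    tot_var = (np.square(pixel_dif1).sum(axis=(1, 2, 3)) +
               np.square(pixel_dif2).sum(axis=(1, 2, 3)))
    print(tot_var.shape)   # (2,): one scalar per batch item, as axis=[1,2,3] implies
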
@@ -392,7 +392,7 @@ nn.total_variation_mse = total_variation_mse
def pixel_norm(x, axes):
    return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=axes, keepdims=True) + 1e-06)
nn.pixel_norm = pixel_norm

"""
def tf_suppress_lower_mean(t, eps=0.00001):
    if t.shape.ndims != 1:
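
pixel_norm above rescales each position to roughly unit RMS over the chosen axes, x * rsqrt(mean(x^2) + eps). A NumPy check of the same formula, normalizing over the channel axis (shapes assumed for illustration):

    import numpy as np

    x = np.random.rand(1, 4, 4, 8).astype(np.float32)
    y = x / np.sqrt(np.square(x).mean(axis=-1, keepdims=True) + 1e-06)
    print(np.square(y).mean(axis=-1))   # ~1.0 at every spatial position
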

View file

@@ -37,12 +37,12 @@ class AdaBelief(nn.OptimizerBase):
        vs = { v.name : tf.get_variable ( f'vs_{v.name}'.replace(':','_'), v.shape, dtype=v.dtype, initializer=tf.initializers.constant(0.0), trainable=False) for v in trainable_weights }
        self.ms_dict.update (ms)
        self.vs_dict.update (vs)

        if self.lr_dropout != 1.0:
            e = tf.device('/CPU:0') if lr_dropout_on_cpu else None
            if e: e.__enter__()
            lr_rnds = [ nn.random_binomial( v.shape, p=self.lr_dropout, dtype=v.dtype) for v in trainable_weights ]
            if e: e.__exit__(None, None, None)
            self.lr_rnds_dict.update ( { v.name : rnd for v,rnd in zip(trainable_weights,lr_rnds) } )
            if e: e.__exit__(None, None, None)
@@ -58,7 +58,7 @@ class AdaBelief(nn.OptimizerBase):
            ms = self.ms_dict[ v.name ]
            vs = self.vs_dict[ v.name ]
            m_t = self.beta_1*ms + (1.0-self.beta_1) * g
            v_t = self.beta_2*vs + (1.0-self.beta_2) * tf.square(g-m_t)
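
These two lines are the AdaBelief-specific part: m_t is the usual gradient EMA, but the second moment tracks the squared deviation of the gradient from that EMA, tf.square(g - m_t), rather than Adam's raw tf.square(g). A one-step NumPy sketch with made-up values:

    import numpy as np

    beta_1, beta_2, lr, eps = 0.9, 0.999, 0.001, 1e-16
    g = np.float32(0.5)                      # current gradient
    ms = vs = np.float32(0.0)                # accumulators from earlier steps

    m_t = beta_1*ms + (1.0-beta_1)*g
    v_t = beta_2*vs + (1.0-beta_2)*np.square(g - m_t)    # "belief" in the gradient
    print(m_t, v_t, lr * m_t / (np.sqrt(v_t) + eps))     # the resulting step size
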

View file

@@ -13,12 +13,12 @@ class RMSprop(nn.OptimizerBase):
        self.lr = lr
        self.rho = rho
        self.epsilon = epsilon
        self.clipnorm = clipnorm

        with tf.device('/CPU:0') :
            with tf.variable_scope(self.name):
                self.iterations = tf.Variable(0, dtype=tf.int64, name='iters')
                self.accumulators_dict = {}
@@ -37,9 +37,9 @@ class RMSprop(nn.OptimizerBase):
        if self.lr_dropout != 1.0:
            e = tf.device('/CPU:0') if lr_dropout_on_cpu else None
            if e: e.__enter__()
            lr_rnds = [ nn.random_binomial( v.shape, p=self.lr_dropout, dtype=v.dtype) for v in trainable_weights ]
            if e: e.__exit__(None, None, None)
            self.lr_rnds_dict.update ( { v.name : rnd for v,rnd in zip(trainable_weights,lr_rnds) } )
            if e: e.__exit__(None, None, None)
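
In both optimizers, lr_dropout builds a Bernoulli mask per trainable tensor and multiplies it into the per-weight update, so each weight's update is dropped with probability 1 - lr_dropout. The masking idea in plain NumPy (values illustrative only):

    import numpy as np

    rng = np.random.default_rng(0)
    lr_dropout = 0.7                                   # keep probability
    update = np.full((4, 4), 0.01, dtype=np.float32)
    mask = rng.binomial(1, lr_dropout, size=update.shape)
    print(update * mask)                               # ~30% of entries zeroed
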

View file

@@ -42,7 +42,7 @@ def rotate_point(origin, point, deg):
    qx = ox + math.cos(rad) * (px - ox) - math.sin(rad) * (py - oy)
    qy = oy + math.sin(rad) * (px - ox) + math.cos(rad) * (py - oy)
    return np.float32([qx, qy])

def transform_points(points, mat, invert=False):
    if invert:
        mat = cv2.invertAffineTransform (mat)
@@ -51,47 +51,47 @@ def transform_points(points, mat, invert=False):
    points = np.squeeze(points)
    return points

def transform_mat(mat, res, tx, ty, rotation, scale):
    """
    transform mat in local space of res
    scale -> translate -> rotate

    tx,ty float
    rotation int degrees
    scale float
    """
    lt, rt, lb, ct = transform_points ( np.float32([(0,0),(res,0),(0,res),(res / 2, res/2) ]),mat, True)

    hor_v = (rt-lt).astype(np.float32)
    hor_size = npla.norm(hor_v)
    hor_v /= hor_size

    ver_v = (lb-lt).astype(np.float32)
    ver_size = npla.norm(ver_v)
    ver_v /= ver_size

    bt_diag_vec = (rt-ct).astype(np.float32)
    half_diag_len = npla.norm(bt_diag_vec)
    bt_diag_vec /= half_diag_len

    tb_diag_vec = np.float32( [ -bt_diag_vec[1], bt_diag_vec[0] ] )

    rt = ct + bt_diag_vec*half_diag_len*scale
    lb = ct - bt_diag_vec*half_diag_len*scale
    lt = ct - tb_diag_vec*half_diag_len*scale

    rt[0] += tx*hor_size
    lb[0] += tx*hor_size
    lt[0] += tx*hor_size

    rt[1] += ty*ver_size
    lb[1] += ty*ver_size
    lt[1] += ty*ver_size

    rt = rotate_point(ct, rt, rotation)
    lb = rotate_point(ct, lb, rotation)
    lt = rotate_point(ct, lt, rotation)

    return cv2.getAffineTransform( np.float32([lt, rt, lb]), np.float32([ [0,0], [res,0], [0,res] ]) )
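
rotate_point is a plain 2-D rotation about an arbitrary origin, and transform_mat composes scale, translate, and rotate in the local res-sized space before mapping back with cv2.getAffineTransform. A quick sanity check of the rotation formula, restated standalone (the deg-to-rad conversion happens just above the shown hunk):

    import math
    import numpy as np

    def rotate_point(origin, point, deg):
        ox, oy = origin
        px, py = point
        rad = deg * math.pi / 180.0
        qx = ox + math.cos(rad) * (px - ox) - math.sin(rad) * (py - oy)
        qy = oy + math.sin(rad) * (px - ox) + math.cos(rad) * (py - oy)
        return np.float32([qx, qy])

    print(rotate_point((0, 0), (1, 0), 90))   # -> [~0., 1.]: +90 deg maps +x onto +y
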

View file

@@ -105,23 +105,23 @@ class Index2DHost():
                np.random.shuffle(shuffle_idxs)

            idx_1D = shuffle_idxs.pop()

            #print(f'idx_1D = {idx_1D}, len(shuffle_idxs_2D[idx_1D])= {len(shuffle_idxs_2D[idx_1D])}')
            if len(shuffle_idxs_2D[idx_1D]) == 0:
                shuffle_idxs_2D[idx_1D] = idxs_2D[idx_1D].copy()
                #print(f'new shuffle_idxs_2d for {idx_1D} = { shuffle_idxs_2D[idx_1D] }')
                #print(f'len(shuffle_idxs_2D[idx_1D])= {len(shuffle_idxs_2D[idx_1D])}')
                np.random.shuffle( shuffle_idxs_2D[idx_1D] )

            idx_2D = shuffle_idxs_2D[idx_1D].pop()
            #print(f'len(shuffle_idxs_2D[idx_1D])= {len(shuffle_idxs_2D[idx_1D])}')
            #print(f'idx_2D = {idx_2D}')

            result.append( indexes2D[idx_1D][idx_2D])
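
The pop-and-refill pattern above gives balanced two-level sampling: outer indices are exhausted once per shuffled pass before any repeats, and the same holds for the inner indices of each outer entry. A trimmed-down sketch of the same loop (data is hypothetical):

    import numpy as np

    indexes2D = [[10, 11], [20, 21, 22]]            # e.g. samples grouped per person
    shuffle_idxs = []
    shuffle_idxs_2D = [[] for _ in indexes2D]

    result = []
    for _ in range(6):
        if len(shuffle_idxs) == 0:
            shuffle_idxs = list(range(len(indexes2D)))
            np.random.shuffle(shuffle_idxs)
        i = shuffle_idxs.pop()
        if len(shuffle_idxs_2D[i]) == 0:
            shuffle_idxs_2D[i] = list(range(len(indexes2D[i])))
            np.random.shuffle(shuffle_idxs_2D[i])
        j = shuffle_idxs_2D[i].pop()
        result.append(indexes2D[i][j])
    print(result)   # each outer group is revisited only after the other has been drawn
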

View file

@@ -32,6 +32,5 @@ def get_screen_size():
        pass
    elif 'linux' in sys.platform:
        pass

    return (1366, 768)

View file

@@ -59,7 +59,7 @@ def get_paths(dir_path):
        return [ Path(x) for x in sorted([ x.path for x in list(scandir(str(dir_path))) ]) ]
    else:
        return []

def get_file_paths(dir_path):
    dir_path = Path (dir_path)

View file

@@ -13,7 +13,7 @@ from .qtex import *
class QSubprocessor(object):
    """
    """
    class Cli(object):
@@ -149,11 +149,11 @@ class QSubprocessor(object):
        #ok some processes survived, initialize host logic
        self.on_clients_initialized()

        self.q_timer = QTimer()
        self.q_timer.timeout.connect(self.tick)
        self.q_timer.start(5)

    #overridable
    def process_info_generator(self):
        #yield per process (name, host_dict, client_dict)
@@ -259,4 +259,4 @@ class QSubprocessor(object):
        self.q_timer.stop()
        self.q_timer = None
        self.on_clients_finalized()

View file

@@ -12,10 +12,10 @@ class QXIconButton(QPushButton):
    currently works only with one-key shortcut
    """
    def __init__(self, icon,
                       tooltip=None,
                       shortcut=None,
                       click_func=None,
                       first_repeat_delay=300,
                       repeat_delay=20,
                 ):
@@ -23,28 +23,28 @@ class QXIconButton(QPushButton):
        super().__init__(icon, "")
        self.setIcon(icon)

        if shortcut is not None:
            tooltip = f"{tooltip} ( {StringsDB['S_HOT_KEY'] }: {shortcut} )"

        self.setToolTip(tooltip)

        self.seq = QKeySequence(shortcut) if shortcut is not None else None

        QXMainWindow.inst.add_keyPressEvent_listener ( self.on_keyPressEvent )
        QXMainWindow.inst.add_keyReleaseEvent_listener ( self.on_keyReleaseEvent )

        self.click_func = click_func
        self.first_repeat_delay = first_repeat_delay
        self.repeat_delay = repeat_delay
        self.repeat_timer = None
        self.op_device = None

        self.pressed.connect( lambda : self.action(is_pressed=True) )
        self.released.connect( lambda : self.action(is_pressed=False) )

    def action(self, is_pressed=None, op_device=None):
        if self.click_func is None:
            return
@@ -64,12 +64,12 @@ class QXIconButton(QPushButton):
            self.click_func()
            if self.repeat_timer is not None:
                self.repeat_timer.setInterval(self.repeat_delay)

    def on_keyPressEvent(self, ev):
        key = ev.nativeVirtualKey()
        if ev.isAutoRepeat():
            return

        if self.seq is not None:
            if key == self.seq[0]:
                self.action(is_pressed=True)

View file

@@ -8,27 +8,27 @@ class QXMainWindow(QWidget):
    """
    inst = None
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if QXMainWindow.inst is not None:
            raise Exception("QXMainWindow can only be one.")
        QXMainWindow.inst = self

        self.keyPressEvent_listeners = []
        self.keyReleaseEvent_listeners = []
        self.setFocusPolicy(Qt.WheelFocus)

    def add_keyPressEvent_listener(self, func):
        self.keyPressEvent_listeners.append (func)

    def add_keyReleaseEvent_listener(self, func):
        self.keyReleaseEvent_listeners.append (func)

    def keyPressEvent(self, ev):
        super().keyPressEvent(ev)
        for func in self.keyPressEvent_listeners:
            func(ev)

    def keyReleaseEvent(self, ev):
        super().keyReleaseEvent(ev)
        for func in self.keyReleaseEvent_listeners:
            func(ev)

View file

@@ -5,27 +5,27 @@ from PyQt5.QtWidgets import *
from localization import StringsDB
from .QXMainWindow import *

class QActionEx(QAction):
    def __init__(self, icon, text, shortcut=None, trigger_func=None, shortcut_in_tooltip=False, is_checkable=False, is_auto_repeat=False ):
        super().__init__(icon, text)
        if shortcut is not None:
            self.setShortcut(shortcut)
            if shortcut_in_tooltip:
                self.setToolTip( f"{text} ( {StringsDB['S_HOT_KEY'] }: {shortcut} )")

        if trigger_func is not None:
            self.triggered.connect(trigger_func)
        if is_checkable:
            self.setCheckable(True)
        self.setAutoRepeat(is_auto_repeat)

def QImage_from_np(img):
    if img.dtype != np.uint8:
        raise ValueError("img should be in np.uint8 format")

    h,w,c = img.shape
    if c == 1:
        fmt = QImage.Format_Grayscale8
@@ -34,33 +34,33 @@ def QImage_from_np(img):
    elif c == 4:
        fmt = QImage.Format_ARGB32
    else:
        raise ValueError("unsupported channel count")

    return QImage(img.data, w, h, c*w, fmt )

def QImage_to_np(q_img, fmt=QImage.Format_BGR888):
    q_img = q_img.convertToFormat(fmt)

    width = q_img.width()
    height = q_img.height()

    b = q_img.constBits()
    b.setsize(height * width * 3)
    arr = np.frombuffer(b, np.uint8).reshape((height, width, 3))
    return arr#[::-1]

def QPixmap_from_np(img):
    return QPixmap.fromImage(QImage_from_np(img))

def QPoint_from_np(n):
    return QPoint(*n.astype(np.int))

def QPoint_to_np(q):
    return np.int32( [q.x(), q.y()] )

def QSize_to_np(q):
    return np.int32( [q.width(), q.height()] )
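
A short round-trip sketch for the converters above (array values and sizes are made up; widths whose 3-byte rows are already 4-byte aligned round-trip exactly):

    import numpy as np

    # assumes QImage_from_np / QImage_to_np from this module are in scope
    img = np.zeros((48, 64, 3), dtype=np.uint8)
    img[:, :32] = (255, 0, 0)                      # left half blue in BGR order

    back = QImage_to_np(QImage_from_np(img))
    print(back.shape, np.array_equal(back, img))   # (48, 64, 3) True
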
class QDarkPalette(QPalette):
    def __init__(self):
        super().__init__()
@@ -71,7 +71,7 @@ class QDarkPalette(QPalette):
        self.setColor(QPalette.AlternateBase, QColor(53, 53, 53))
        self.setColor(QPalette.ToolTipBase, text_color )
        self.setColor(QPalette.ToolTipText, text_color )
        self.setColor(QPalette.Text, text_color )
        self.setColor(QPalette.Button, QColor(53, 53, 53))
        self.setColor(QPalette.ButtonText, Qt.white)
        self.setColor(QPalette.BrightText, Qt.red)

View file

@@ -14,7 +14,7 @@ ported from https://github.com/1adrianb/face-alignment
"""

class FANExtractor(object):
    def __init__ (self, landmarks_3D=False, place_model_on_cpu=False):
        model_path = Path(__file__).parent / ( "2DFAN.npy" if not landmarks_3D else "3DFAN.npy")
        if not model_path.exists():
            raise Exception("Unable to load FANExtractor model")

View file

@@ -164,7 +164,7 @@ class FaceEnhancer(object):
        with tf.device ('/CPU:0' if place_model_on_cpu else nn.tf_default_device_name):
            self.model = FaceEnhancer()
            self.model.load_weights (model_path)

        with tf.device ('/CPU:0' if run_on_cpu else nn.tf_default_device_name):
            self.model.build_for_run ([ (tf.float32, nn.get4Dshape (192,192,3) ),
                                        (tf.float32, (None,1,) ),

View file

@@ -30,8 +30,8 @@ to_string_dict = { FaceType.HALF : 'half_face',
                   FaceType.WHOLE_FACE : 'whole_face',
                   FaceType.HEAD : 'head',
                   FaceType.HEAD_NO_ALIGN : 'head_no_align',
                   FaceType.MARK_ONLY :'mark_only',
                 }

from_string_dict = { to_string_dict[x] : x for x in to_string_dict.keys() }

View file

@@ -201,7 +201,7 @@ landmarks_68_3D = np.array( [
    [8.444722 , 25.326198 , -21.025520 ], #33
    [24.474473 , 28.323008 , -5.712776 ], #34
    [8.449166 , 30.596216 , -20.671489 ], #35
    [0.205322 , 31.408738 , -21.903670 ], #36
    [-7.198266 , 30.844876 , -20.328022 ] #37
], dtype=np.float32)
@@ -303,18 +303,18 @@ def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0):
    elif face_type == FaceType.HEAD:
        mat = umeyama( np.concatenate ( [ image_landmarks[17:49] , image_landmarks[54:55] ] ) , landmarks_2D_new, True)[0:2]

        # assuming image_landmarks are 3D_Landmarks extracted for HEAD,
        # adjust horizontal offset according to estimated yaw
        yaw = estimate_averaged_yaw(transform_points (image_landmarks, mat, False))

        hvec = (g_p[0]-g_p[1]).astype(np.float32)
        hvec_len = npla.norm(hvec)
        hvec /= hvec_len

        yaw *= np.abs(math.tanh(yaw*2)) # Damp near zero

        g_c -= hvec * (yaw * hvec_len / 2.0)

        # adjust vertical offset for HEAD, 50% below
        vvec = (g_p[0]-g_p[3]).astype(np.float32)
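
The damping line a few rows up, yaw *= np.abs(math.tanh(yaw*2)), keeps large yaws nearly intact while shrinking the horizontal correction toward zero for near-frontal faces; a few sample values:

    import numpy as np

    for yaw in (0.05, 0.2, 0.5, 1.0):
        print(yaw, '->', yaw * np.abs(np.tanh(yaw * 2)))
    # 0.05 -> ~0.005, 0.2 -> ~0.076, 0.5 -> ~0.381, 1.0 -> ~0.964
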
@@ -454,7 +454,7 @@ def get_image_mouth_mask (image_shape, image_landmarks):
    hull_mask = hull_mask[...,None]
    return hull_mask

def alpha_to_color (img_alpha, color):
    if len(img_alpha.shape) == 2:
        img_alpha = img_alpha[...,None]
@@ -741,10 +741,10 @@ def estimate_averaged_yaw(landmarks):
    # Works much better than solvePnP if landmarks from "3DFAN"
    if not isinstance(landmarks, np.ndarray):
        landmarks = np.array (landmarks)
    l = ( (landmarks[27][0]-landmarks[0][0]) + (landmarks[28][0]-landmarks[1][0]) + (landmarks[29][0]-landmarks[2][0]) ) / 3.0
    r = ( (landmarks[16][0]-landmarks[27][0]) + (landmarks[15][0]-landmarks[28][0]) + (landmarks[14][0]-landmarks[29][0]) ) / 3.0
    return float(r-l)
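
l and r above average the horizontal spans from the nose-bridge points (27-29) to the left (0-2) and right (14-16) face contour, so a symmetric frontal face yields r - l near zero while a turned head pushes the difference away from it. A toy check with synthetic landmarks (only the indices used are filled in):

    import numpy as np

    landmarks = np.zeros((68, 2), dtype=np.float32)
    landmarks[[0, 1, 2], 0] = -30      # left contour x
    landmarks[[14, 15, 16], 0] = 30    # right contour x
    landmarks[[27, 28, 29], 0] = 0     # nose bridge x

    l = (landmarks[[27, 28, 29], 0] - landmarks[[0, 1, 2], 0]).mean()
    r = (landmarks[[14, 15, 16], 0] - landmarks[[27, 28, 29], 0]).mean()
    print(float(r - l))                # 0.0 for this symmetric face
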
def estimate_pitch_yaw_roll(aligned_landmarks, size=256):
    """
    returns pitch,yaw,roll [-pi/2...+pi/2]
@@ -764,7 +764,7 @@ def estimate_pitch_yaw_roll(aligned_landmarks, size=256):
                                np.zeros((4, 1)) )
    pitch, yaw, roll = mathlib.rotationMatrixToEulerAngles( cv2.Rodrigues(rotation_vector)[0] )

    half_pi = math.pi / 2.0
    pitch = np.clip ( pitch, -half_pi, half_pi )
    yaw = np.clip ( yaw , -half_pi, half_pi )

View file

@@ -13,26 +13,26 @@ from core.leras import nn
class XSegNet(object):
    VERSION = 1

    def __init__ (self, name,
                        resolution=256,
                        load_weights=True,
                        weights_file_root=None,
                        training=False,
                        place_model_on_cpu=False,
                        run_on_cpu=False,
                        optimizer=None,
                        data_format="NHWC",
                        raise_on_no_model_files=False):
        self.resolution = resolution
        self.weights_file_root = Path(weights_file_root) if weights_file_root is not None else Path(__file__).parent

        nn.initialize(data_format=data_format)
        tf = nn.tf

        model_name = f'{name}_{resolution}'
        self.model_filename_list = []

        with tf.device ('/CPU:0'):
            #Place holders on CPU
            self.input_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,3) )
@@ -44,12 +44,12 @@ class XSegNet(object):
        self.model_weights = self.model.get_weights()
        if training:
            if optimizer is None:
                raise ValueError("Optimizer should be provided for training mode.")
            self.opt = optimizer
            self.opt.initialize_variables (self.model_weights, vars_on_cpu=place_model_on_cpu)
            self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]

        self.model_filename_list += [ [self.model, f'{model_name}.npy'] ]

        if not training:
@@ -77,10 +77,10 @@ class XSegNet(object):
                if do_init:
                    model.init_weights()

    def get_resolution(self):
        return self.resolution

    def flow(self, x):
        return self.model(x)
@@ -94,8 +94,8 @@ class XSegNet(object):
    def extract (self, input_image):
        if not self.initialized:
            return 0.5*np.ones ( (self.resolution, self.resolution, 1), nn.floatx.as_numpy_dtype )

        input_shape_len = len(input_image.shape)
        if input_shape_len == 3:
            input_image = input_image[None,...]
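
A usage sketch for the class above, mirroring how apply_xseg constructs it later in this commit (the weights directory is hypothetical; extract takes a single HWC float image and adds the batch axis itself):

    import numpy as np
    from core.leras import nn
    from facelib import XSegNet

    device_config = nn.DeviceConfig.ask_choose_device(choose_only_one=True)
    nn.initialize(device_config)

    xseg = XSegNet(name='XSeg', load_weights=True,
                   weights_file_root='workspace/model',   # hypothetical path
                   data_format=nn.data_format, raise_on_no_model_files=True)
    res = xseg.get_resolution()
    mask = xseg.extract(np.random.rand(res, res, 3).astype(np.float32))   # (res, res, 1)
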

View file

@@ -36,7 +36,6 @@ SID_HOT_KEY = 1
if system_language == 'en':
    StringsDB = {'S_HOT_KEY' : 'hot key'}
elif system_language == 'ru':
    StringsDB = {'S_HOT_KEY' : 'горячая клавиша'}
elif system_language == 'zh':
    StringsDB = {'S_HOT_KEY' : '热键'}

42 main.py
View file

@@ -23,7 +23,7 @@ if __name__ == "__main__":
            setattr(namespace, self.dest, os.path.abspath(os.path.expanduser(values)))

    exit_code = 0

    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers()
@@ -52,9 +52,9 @@ if __name__ == "__main__":
    p.add_argument('--output-debug', action="store_true", dest="output_debug", default=None, help="Writes debug images to <output-dir>_debug\ directory.")
    p.add_argument('--no-output-debug', action="store_false", dest="output_debug", default=None, help="Don't writes debug images to <output-dir>_debug\ directory.")
    p.add_argument('--face-type', dest="face_type", choices=['half_face', 'full_face', 'whole_face', 'head', 'mark_only'], default=None)
    p.add_argument('--max-faces-from-image', type=int, dest="max_faces_from_image", default=None, help="Max faces from image.")
    p.add_argument('--image-size', type=int, dest="image_size", default=None, help="Output image size.")
    p.add_argument('--jpeg-quality', type=int, dest="jpeg_quality", default=None, help="Jpeg quality.")
    p.add_argument('--manual-fix', action="store_true", dest="manual_fix", default=False, help="Enables manual extract only frames where faces were not recognized.")
    p.add_argument('--manual-output-debug-fix', action="store_true", dest="manual_output_debug_fix", default=False, help="Performs manual reextract input-dir frames which were deleted from [output_dir]_debug\ dir.")
    p.add_argument('--manual-window-size', type=int, dest="manual_window_size", default=1368, help="Manual fix window size. Default: 1368.")
@@ -146,8 +146,8 @@ if __name__ == "__main__":
    p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
    p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.")
    p.add_argument('--dump-ckpt', action="store_true", dest="dump_ckpt", default=False, help="Dump the model to ckpt format.")

    p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
    p.set_defaults (func=process_train)
@@ -254,8 +254,8 @@ if __name__ == "__main__":
    p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
    p.set_defaults(func=process_faceset_enhancer)

    p = facesettool_parser.add_parser ("resize", help="Resize DFL faceset.")
    p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.")
@@ -264,7 +264,7 @@ if __name__ == "__main__":
        from mainscripts import FacesetResizer
        FacesetResizer.process_folder ( Path(arguments.input_dir) )
    p.set_defaults(func=process_faceset_resizer)

    def process_dev_test(arguments):
        osex.set_process_lowest_prio()
        from mainscripts import dev_misc
@@ -273,10 +273,10 @@ if __name__ == "__main__":
    p = subparsers.add_parser( "dev_test", help="")
    p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
    p.set_defaults (func=process_dev_test)

    # ========== XSeg
    xseg_parser = subparsers.add_parser( "xseg", help="XSeg tools.").add_subparsers()

    p = xseg_parser.add_parser( "editor", help="XSeg editor.")

    def process_xsegeditor(arguments):
@@ -284,11 +284,11 @@ if __name__ == "__main__":
        from XSegEditor import XSegEditor
        global exit_code
        exit_code = XSegEditor.start (Path(arguments.input_dir))

    p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
    p.set_defaults (func=process_xsegeditor)

    p = xseg_parser.add_parser( "apply", help="Apply trained XSeg model to the extracted faces.")

    def process_xsegapply(arguments):
@@ -298,8 +298,8 @@ if __name__ == "__main__":
    p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
    p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir")
    p.set_defaults (func=process_xsegapply)

    p = xseg_parser.add_parser( "remove", help="Remove applied XSeg masks from the extracted faces.")

    def process_xsegremove(arguments):
        osex.set_process_lowest_prio()
@@ -307,8 +307,8 @@ if __name__ == "__main__":
        XSegUtil.remove_xseg (Path(arguments.input_dir) )
    p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
    p.set_defaults (func=process_xsegremove)

    p = xseg_parser.add_parser( "remove_labels", help="Remove XSeg labels from the extracted faces.")

    def process_xsegremovelabels(arguments):
        osex.set_process_lowest_prio()
@@ -316,8 +316,8 @@ if __name__ == "__main__":
        XSegUtil.remove_xseg_labels (Path(arguments.input_dir) )
    p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
    p.set_defaults (func=process_xsegremovelabels)

    p = xseg_parser.add_parser( "fetch", help="Copies faces containing XSeg polygons in <input_dir>_xseg dir.")

    def process_xsegfetch(arguments):
@@ -326,7 +326,7 @@ if __name__ == "__main__":
        XSegUtil.fetch_xseg (Path(arguments.input_dir) )
    p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
    p.set_defaults (func=process_xsegfetch)

    def bad_args(arguments):
        parser.print_help()
        exit(0)
@@ -337,9 +337,9 @@ if __name__ == "__main__":
    if exit_code == 0:
        print ("Done.")

    exit(exit_code)

'''
import code
code.interact(local=dict(globals(), **locals()))

View file

@@ -79,79 +79,79 @@ class FacesetResizerSubprocessor(Subprocessor):
            h,w = img.shape[:2]
            if h != w:
                raise Exception(f'w != h in {filepath}')

            image_size = self.image_size
            face_type = self.face_type
            output_filepath = self.output_dirpath / filepath.name

            if face_type is not None:
                lmrks = dflimg.get_landmarks()
                mat = LandmarksProcessor.get_transform_mat(lmrks, image_size, face_type)

                img = cv2.warpAffine(img, mat, (image_size, image_size), flags=cv2.INTER_LANCZOS4 )
                img = np.clip(img, 0, 255).astype(np.uint8)

                cv2_imwrite ( str(output_filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] )

                dfl_dict = dflimg.get_dict()
                dflimg = DFLIMG.load (output_filepath)
                dflimg.set_dict(dfl_dict)

                xseg_mask = dflimg.get_xseg_mask()
                if xseg_mask is not None:
                    xseg_res = 256

                    xseg_lmrks = lmrks.copy()
                    xseg_lmrks *= (xseg_res / w)
                    xseg_mat = LandmarksProcessor.get_transform_mat(xseg_lmrks, xseg_res, face_type)

                    xseg_mask = cv2.warpAffine(xseg_mask, xseg_mat, (xseg_res, xseg_res), flags=cv2.INTER_LANCZOS4 )
                    xseg_mask[xseg_mask < 0.5] = 0
                    xseg_mask[xseg_mask >= 0.5] = 1

                    dflimg.set_xseg_mask(xseg_mask)

                seg_ie_polys = dflimg.get_seg_ie_polys()

                for poly in seg_ie_polys.get_polys():
                    poly_pts = poly.get_pts()
                    poly_pts = LandmarksProcessor.transform_points(poly_pts, mat)
                    poly.set_points(poly_pts)

                dflimg.set_seg_ie_polys(seg_ie_polys)

                lmrks = LandmarksProcessor.transform_points(lmrks, mat)
                dflimg.set_landmarks(lmrks)

                image_to_face_mat = dflimg.get_image_to_face_mat()
                if image_to_face_mat is not None:
                    image_to_face_mat = LandmarksProcessor.get_transform_mat ( dflimg.get_source_landmarks(), image_size, face_type )
                    dflimg.set_image_to_face_mat(image_to_face_mat)
                dflimg.set_face_type( FaceType.toString(face_type) )
                dflimg.save()

            else:
                dfl_dict = dflimg.get_dict()

                scale = w / image_size

                img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_LANCZOS4)

                cv2_imwrite ( str(output_filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] )

                dflimg = DFLIMG.load (output_filepath)
                dflimg.set_dict(dfl_dict)

                lmrks = dflimg.get_landmarks()
                lmrks /= scale
                dflimg.set_landmarks(lmrks)

                seg_ie_polys = dflimg.get_seg_ie_polys()
                seg_ie_polys.mult_points( 1.0 / scale)
                dflimg.set_seg_ie_polys(seg_ie_polys)

                image_to_face_mat = dflimg.get_image_to_face_mat()

                if image_to_face_mat is not None:
                    face_type = FaceType.fromString ( dflimg.get_face_type() )
                    image_to_face_mat = LandmarksProcessor.get_transform_mat ( dflimg.get_source_landmarks(), image_size, face_type )
@@ -165,9 +165,9 @@ class FacesetResizerSubprocessor(Subprocessor):
        return (0, filepath, None)

def process_folder ( dirpath):
    image_size = io.input_int(f"New image size", 512, valid_range=[256,2048])

    face_type = io.input_str ("Change face type", 'same', ['h','mf','f','wf','head','same']).lower()
    if face_type == 'same':
        face_type = None
@@ -177,7 +177,7 @@ def process_folder ( dirpath):
                 'f' : FaceType.FULL,
                 'wf' : FaceType.WHOLE_FACE,
                 'head' : FaceType.HEAD}[face_type]

    output_dirpath = dirpath.parent / (dirpath.name + '_resized')
    output_dirpath.mkdir (exist_ok=True, parents=True)

View file

@@ -72,8 +72,8 @@ def main (model_class_name=None,
        if not is_interactive:
            cfg.ask_settings()

        subprocess_count = io.input_int("Number of workers?", max(8, multiprocessing.cpu_count()),
                                        valid_range=[1, multiprocessing.cpu_count()], help_message="Specify the number of threads to process. A low value may affect performance. A high value may result in memory error. The value may not be greater than CPU cores." )

        input_path_image_paths = pathex.get_image_paths(input_path)

View file

@@ -25,7 +25,7 @@ class BlurEstimatorSubprocessor(Subprocessor):
    class Cli(Subprocessor.Cli):
        def on_initialize(self, client_dict):
            self.estimate_motion_blur = client_dict['estimate_motion_blur']

        #override
        def process_data(self, data):
            filepath = Path( data[0] )
@@ -36,16 +36,16 @@ class BlurEstimatorSubprocessor(Subprocessor):
                return [ str(filepath), 0 ]
            else:
                image = cv2_imread( str(filepath) )

                face_mask = LandmarksProcessor.get_image_hull_mask (image.shape, dflimg.get_landmarks())
                image = (image*face_mask).astype(np.uint8)

                if self.estimate_motion_blur:
                    value = cv2.Laplacian(image, cv2.CV_64F, ksize=11).var()
                else:
                    value = estimate_sharpness(image)

                return [ str(filepath), value ]
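
The motion-blur branch uses the classic variance-of-Laplacian focus measure: blur removes high-frequency edges, so the Laplacian response flattens and its variance drops. Standalone illustration (the file path is hypothetical):

    import cv2

    img = cv2.imread('face.jpg', cv2.IMREAD_GRAYSCALE)   # hypothetical input
    blurred = cv2.GaussianBlur(img, (15, 15), 0)

    for name, im in (('sharp', img), ('blurred', blurred)):
        print(name, cv2.Laplacian(im, cv2.CV_64F, ksize=11).var())
    # the blurred copy scores markedly lower
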
@@ -113,7 +113,7 @@ def sort_by_blur(input_path):
    img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)

    return img_list, trash_img_list

def sort_by_motion_blur(input_path):
    io.log_info ("Sorting by motion blur...")
@@ -124,7 +124,7 @@ def sort_by_motion_blur(input_path):
    img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)

    return img_list, trash_img_list

def sort_by_face_yaw(input_path):
    io.log_info ("Sorting by face yaw...")
    img_list = []
@@ -472,7 +472,7 @@ class FinalLoaderSubprocessor(Subprocessor):
                    source_rect = dflimg.get_source_rect()
                    sharpness = mathlib.polygon_area(np.array(source_rect[[0,2,2,0]]).astype(np.float32), np.array(source_rect[[1,1,3,3]]).astype(np.float32))
                else:
                    face_mask = LandmarksProcessor.get_image_hull_mask (gray.shape, dflimg.get_landmarks())
                    sharpness = estimate_sharpness( (gray[...,None]*face_mask).astype(np.uint8) )

                pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks(), size=dflimg.get_shape()[1] )

View file

@@ -43,10 +43,10 @@ def trainerThread (s2c, c2s, e,
            if not saved_models_path.exists():
                saved_models_path.mkdir(exist_ok=True, parents=True)

            if dump_ckpt:
                cpu_only=True

            model = models.import_model(model_class_name)(
                        is_training=not dump_ckpt,
                        saved_models_path=saved_models_path,
@@ -65,7 +65,7 @@ def trainerThread (s2c, c2s, e,
                e.set()
                model.dump_ckpt()
                break

            is_reached_goal = model.is_reached_iter_goal()

            shared_state = { 'after_save' : False }
@@ -76,10 +76,10 @@ def trainerThread (s2c, c2s, e,
                    io.log_info ("Saving....", end='\r')
                    model.save()
                    shared_state['after_save'] = True

            def model_backup():
                if not debug and not is_reached_goal:
                    model.create_backup()

            def send_preview():
                if not debug:
@@ -128,7 +128,7 @@ def trainerThread (s2c, c2s, e,
                    io.log_info("")
                    io.log_info("Trying to do the first iteration. If an error occurs, reduce the model parameters.")
                    io.log_info("")

                    if sys.platform[0:3] == 'win':
                        io.log_info("!!!")
                        io.log_info("Windows 10 users IMPORTANT notice. You should set this setting in order to work correctly.")
@@ -146,7 +146,7 @@ def trainerThread (s2c, c2s, e,
                        if shared_state['after_save']:
                            shared_state['after_save'] = False

                            mean_loss = np.mean ( loss_history[save_iter:iter], axis=0)

                            for loss_value in mean_loss:

View file

@@ -24,7 +24,7 @@ def save_faceset_metadata_folder(input_path):
        if dflimg is None or not dflimg.has_data():
            io.log_info(f"{filepath} is not a dfl image file")
            continue

        dfl_dict = dflimg.get_dict()
        d[filepath.name] = ( dflimg.get_shape(), dfl_dict )
@@ -59,7 +59,7 @@ def restore_faceset_metadata_folder(input_path):
        if saved_data is None:
            io.log_info(f"No saved metadata for {filepath}")
            continue

        shape, dfl_dict = saved_data

        img = cv2_imread (filepath)
@@ -90,19 +90,19 @@ def add_landmarks_debug_images(input_path):
        if dflimg is None or not dflimg.has_data():
            io.log_err (f"{filepath.name} is not a dfl image file")
            continue

        if img is not None:
            face_landmarks = dflimg.get_landmarks()
            face_type = FaceType.fromString ( dflimg.get_face_type() )

            if face_type == FaceType.MARK_ONLY:
                rect = dflimg.get_source_rect()
                LandmarksProcessor.draw_rect_landmarks(img, rect, face_landmarks, FaceType.FULL )
            else:
                LandmarksProcessor.draw_landmarks(img, face_landmarks, transparent_mask=True )

            output_file = '{}{}'.format( str(Path(str(input_path)) / filepath.stem), '_debug.jpg')
            cv2_imwrite(output_file, img, [int(cv2.IMWRITE_JPEG_QUALITY), 50] )

View file

@@ -19,159 +19,159 @@ def apply_xseg(input_path, model_path):
    if not model_path.exists():
        raise ValueError(f'{model_path} not found. Please ensure it exists.')

    face_type = None

    model_dat = model_path / 'XSeg_data.dat'
    if model_dat.exists():
        dat = pickle.loads( model_dat.read_bytes() )
        dat_options = dat.get('options', None)
        if dat_options is not None:
            face_type = dat_options.get('face_type', None)

    if face_type is None:
        face_type = io.input_str ("XSeg model face type", 'same', ['h','mf','f','wf','head','same'], help_message="Specify face type of trained XSeg model. For example if XSeg model trained as WF, but faceset is HEAD, specify WF to apply xseg only on WF part of HEAD. Default is 'same'").lower()
        if face_type == 'same':
            face_type = None

    if face_type is not None:
        face_type = {'h'    : FaceType.HALF,
                     'mf'   : FaceType.MID_FULL,
                     'f'    : FaceType.FULL,
                     'wf'   : FaceType.WHOLE_FACE,
                     'head' : FaceType.HEAD}[face_type]

    io.log_info(f'Applying trained XSeg model to {input_path.name}/ folder.')

    device_config = nn.DeviceConfig.ask_choose_device(choose_only_one=True)
    nn.initialize(device_config)

    xseg = XSegNet(name='XSeg',
                   load_weights=True,
                   weights_file_root=model_path,
                   data_format=nn.data_format,
                   raise_on_no_model_files=True)
    xseg_res = xseg.get_resolution()

    images_paths = pathex.get_image_paths(input_path, return_Path_class=True)

    for filepath in io.progress_bar_generator(images_paths, "Processing"):
        dflimg = DFLIMG.load(filepath)
        if dflimg is None or not dflimg.has_data():
            io.log_info(f'{filepath} is not a DFLIMG')
            continue

        img = cv2_imread(filepath).astype(np.float32) / 255.0
        h,w,c = img.shape

        img_face_type = FaceType.fromString( dflimg.get_face_type() )
        if face_type is not None and img_face_type != face_type:
            lmrks = dflimg.get_source_landmarks()

            fmat = LandmarksProcessor.get_transform_mat(lmrks, w, face_type)
            imat = LandmarksProcessor.get_transform_mat(lmrks, w, img_face_type)

            g_p = LandmarksProcessor.transform_points (np.float32([(0,0),(w,0),(0,w) ]), fmat, True)
            g_p2 = LandmarksProcessor.transform_points (g_p, imat)

            mat = cv2.getAffineTransform( g_p2, np.float32([(0,0),(w,0),(0,w) ]) )

            img = cv2.warpAffine(img, mat, (w, w), cv2.INTER_LANCZOS4)
            img = cv2.resize(img, (xseg_res, xseg_res), interpolation=cv2.INTER_LANCZOS4)
        else:
            if w != xseg_res:
                img = cv2.resize( img, (xseg_res,xseg_res), interpolation=cv2.INTER_LANCZOS4 )

        if len(img.shape) == 2:
            img = img[...,None]

        mask = xseg.extract(img)

        if face_type is not None and img_face_type != face_type:
            mask = cv2.resize(mask, (w, w), interpolation=cv2.INTER_LANCZOS4)
            mask = cv2.warpAffine( mask, mat, (w,w), np.zeros( (h,w,c), dtype=np.float), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4)
            mask = cv2.resize(mask, (xseg_res, xseg_res), interpolation=cv2.INTER_LANCZOS4)
        mask[mask < 0.5]=0
        mask[mask >= 0.5]=1

        dflimg.set_xseg_mask(mask)
        dflimg.save()
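
Note: a minimal sketch of the mask post-processing in apply_xseg above. The soft prediction from the network is binarized at 0.5 before being stored; the mask values here are made up.

import numpy as np

mask = np.array([[0.1, 0.6], [0.49, 0.9]], dtype=np.float32)  # stand-in for xseg.extract() output
mask[mask <  0.5] = 0  # below the threshold becomes background
mask[mask >= 0.5] = 1  # at or above the threshold becomes face
assert mask.tolist() == [[0.0, 1.0], [0.0, 1.0]]
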
def fetch_xseg(input_path):
    if not input_path.exists():
        raise ValueError(f'{input_path} not found. Please ensure it exists.')

    output_path = input_path.parent / (input_path.name + '_xseg')
    output_path.mkdir(exist_ok=True, parents=True)

    io.log_info(f'Copying faces containing XSeg polygons to {output_path.name}/ folder.')

    images_paths = pathex.get_image_paths(input_path, return_Path_class=True)

    files_copied = []
    for filepath in io.progress_bar_generator(images_paths, "Processing"):
        dflimg = DFLIMG.load(filepath)
        if dflimg is None or not dflimg.has_data():
            io.log_info(f'{filepath} is not a DFLIMG')
            continue

        ie_polys = dflimg.get_seg_ie_polys()
        if ie_polys.has_polys():
            files_copied.append(filepath)
            shutil.copy ( str(filepath), str(output_path / filepath.name) )

    io.log_info(f'Files copied: {len(files_copied)}')

    is_delete = io.input_bool (f"\r\nDelete original files?", True)
    if is_delete:
        for filepath in files_copied:
            Path(filepath).unlink()
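
Note: the same filter-and-copy pattern reduced to plain Python, with a generic predicate standing in for dflimg.get_seg_ie_polys().has_polys(). The helper name and glob pattern are illustrative, not DFL API.

import shutil
from pathlib import Path

def fetch_matching(src_dir: Path, dst_dir: Path, predicate):
    """Copy every file satisfying predicate into dst_dir; return the copied paths."""
    dst_dir.mkdir(exist_ok=True, parents=True)
    copied = []
    for filepath in sorted(src_dir.glob('*.jpg')):
        if predicate(filepath):  # e.g. "this face has labeled XSeg polygons"
            shutil.copy(str(filepath), str(dst_dir / filepath.name))
            copied.append(filepath)
    return copied
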
def remove_xseg(input_path):
    if not input_path.exists():
        raise ValueError(f'{input_path} not found. Please ensure it exists.')

    io.log_info(f'Processing folder {input_path}')
    io.log_info('!!! WARNING : APPLIED XSEG MASKS WILL BE REMOVED FROM THE FRAMES !!!')
    io.log_info('!!! WARNING : APPLIED XSEG MASKS WILL BE REMOVED FROM THE FRAMES !!!')
    io.log_info('!!! WARNING : APPLIED XSEG MASKS WILL BE REMOVED FROM THE FRAMES !!!')
    io.input_str('Press enter to continue.')

    images_paths = pathex.get_image_paths(input_path, return_Path_class=True)

    files_processed = 0
    for filepath in io.progress_bar_generator(images_paths, "Processing"):
        dflimg = DFLIMG.load(filepath)
        if dflimg is None or not dflimg.has_data():
            io.log_info(f'{filepath} is not a DFLIMG')
            continue

        if dflimg.has_xseg_mask():
            dflimg.set_xseg_mask(None)
            dflimg.save()
            files_processed += 1
    io.log_info(f'Files processed: {files_processed}')

def remove_xseg_labels(input_path):
    if not input_path.exists():
        raise ValueError(f'{input_path} not found. Please ensure it exists.')

    io.log_info(f'Processing folder {input_path}')
    io.log_info('!!! WARNING : LABELED XSEG POLYGONS WILL BE REMOVED FROM THE FRAMES !!!')
    io.log_info('!!! WARNING : LABELED XSEG POLYGONS WILL BE REMOVED FROM THE FRAMES !!!')
    io.log_info('!!! WARNING : LABELED XSEG POLYGONS WILL BE REMOVED FROM THE FRAMES !!!')
    io.input_str('Press enter to continue.')

    images_paths = pathex.get_image_paths(input_path, return_Path_class=True)

    files_processed = 0
    for filepath in io.progress_bar_generator(images_paths, "Processing"):
        dflimg = DFLIMG.load(filepath)
@@ -181,7 +181,7 @@ def remove_xseg_labels(input_path):
        if dflimg.has_seg_ie_polys():
            dflimg.set_seg_ie_polys(None)
            dflimg.save()
            files_processed += 1

    io.log_info(f'Files processed: {files_processed}')

View file

@@ -274,7 +274,7 @@ def dev_test_68(input_dir ):
        img = cv2_imread(filepath)
        img = imagelib.normalize_channels(img, 3)
        cv2_imwrite(output_filepath, img, [int(cv2.IMWRITE_JPEG_QUALITY), 95] )

        raise Exception("unimplemented")
        #DFLJPG.x(output_filepath, face_type=FaceType.toString(FaceType.MARK_ONLY),
        #         landmarks=lmrks,
@ -358,25 +358,25 @@ def extract_umd_csv(input_file_csv,
io.log_info ('-------------------------') io.log_info ('-------------------------')
def dev_test(input_dir): def dev_test(input_dir):
# LaPa dataset # LaPa dataset
image_size = 1024 image_size = 1024
face_type = FaceType.HEAD face_type = FaceType.HEAD
input_path = Path(input_dir) input_path = Path(input_dir)
images_path = input_path / 'images' images_path = input_path / 'images'
if not images_path.exists: if not images_path.exists:
raise ValueError('LaPa dataset: images folder not found.') raise ValueError('LaPa dataset: images folder not found.')
labels_path = input_path / 'labels' labels_path = input_path / 'labels'
if not labels_path.exists: if not labels_path.exists:
raise ValueError('LaPa dataset: labels folder not found.') raise ValueError('LaPa dataset: labels folder not found.')
landmarks_path = input_path / 'landmarks' landmarks_path = input_path / 'landmarks'
if not landmarks_path.exists: if not landmarks_path.exists:
raise ValueError('LaPa dataset: landmarks folder not found.') raise ValueError('LaPa dataset: landmarks folder not found.')
output_path = input_path / 'out' output_path = input_path / 'out'
if output_path.exists(): if output_path.exists():
output_images_paths = pathex.get_image_paths(output_path) output_images_paths = pathex.get_image_paths(output_path)
if len(output_images_paths) != 0: if len(output_images_paths) != 0:
@@ -384,9 +384,9 @@ def dev_test(input_dir):
            for filename in output_images_paths:
                Path(filename).unlink()
    output_path.mkdir(parents=True, exist_ok=True)

    data = []

    img_paths = pathex.get_image_paths (images_path)
    for filename in img_paths:
        filepath = Path(filename)
@@ -394,42 +394,42 @@ def dev_test(input_dir):
        landmark_filepath = landmarks_path / (filepath.stem + '.txt')
        if not landmark_filepath.exists():
            raise ValueError(f'no landmarks for {filepath}')

        #img = cv2_imread(filepath)

        lm = landmark_filepath.read_text()
        lm = lm.split('\n')
        if int(lm[0]) != 106:
            raise ValueError(f'wrong landmarks format in {landmark_filepath}')

        lmrks = []
        for i in range(106):
            x,y = lm[i+1].split(' ')
            x,y = float(x), float(y)
            lmrks.append ( (x,y) )

        lmrks = np.array(lmrks)

        l,t = np.min(lmrks, 0)
        r,b = np.max(lmrks, 0)
        l,t,r,b = ( int(x) for x in (l,t,r,b) )

        #for x, y in lmrks:
        #    x,y = int(x), int(y)
        #    cv2.circle(img, (x, y), 1, (0,255,0) , 1, lineType=cv2.LINE_AA)
        #imagelib.draw_rect(img, (l,t,r,b), (0,255,0) )

        data += [ ExtractSubprocessor.Data(filepath=filepath, rects=[ (l,t,r,b) ]) ]

        #cv2.imshow("", img)
        #cv2.waitKey(0)

    if len(data) > 0:
        device_config = nn.DeviceConfig.BestGPU()

        io.log_info ("Performing 2nd pass...")
        data = ExtractSubprocessor (data, 'landmarks', image_size, 95, face_type, device_config=device_config).run()
        io.log_info ("Performing 3rd pass...")
@@ -438,33 +438,33 @@ def dev_test(input_dir):
    for filename in pathex.get_image_paths (output_path):
        filepath = Path(filename)
        dflimg = DFLJPG.load(filepath)

        src_filename = dflimg.get_source_filename()
        image_to_face_mat = dflimg.get_image_to_face_mat()

        label_filepath = labels_path / ( Path(src_filename).stem + '.png')
        if not label_filepath.exists():
            raise ValueError(f'{label_filepath} does not exist')

        mask = cv2_imread(label_filepath)
        #mask[mask == 10] = 0 # remove hair
        mask[mask > 0] = 1
        mask = cv2.warpAffine(mask, image_to_face_mat, (image_size, image_size), cv2.INTER_LINEAR)
        mask = cv2.blur(mask, (3,3) )

        #cv2.imshow("", (mask*255).astype(np.uint8) )
        #cv2.waitKey(0)

        dflimg.set_xseg_mask(mask)
        dflimg.save()

    import code
    code.interact(local=dict(globals(), **locals()))

def dev_resave_pngs(input_dir):
    input_path = Path(input_dir)
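
Note: a minimal sketch of mapping a full-frame label mask into the aligned-face crop, as the loop above does with image_to_face_mat. The affine matrix and sizes here are placeholders.

import cv2
import numpy as np

image_size = 256
mask = np.zeros((512, 512), dtype=np.float32)
mask[100:300, 150:350] = 1  # fake binary label region
image_to_face_mat = np.float32([[0.5, 0, 0], [0, 0.5, 0]])  # stand-in 2x3 affine matrix
face_mask = cv2.warpAffine(mask, image_to_face_mat, (image_size, image_size),
                           flags=cv2.INTER_LINEAR)
face_mask = cv2.blur(face_mask, (3, 3))  # soften hard label edges, as above
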

View file

@@ -84,18 +84,18 @@ class InteractiveMergerSubprocessor(Subprocessor):
            filepath = frame_info.filepath

            if len(frame_info.landmarks_list) == 0:

                if cfg.mode == 'raw-predict':
                    h,w,c = self.predictor_input_shape
                    img_bgr = np.zeros( (h,w,3), dtype=np.uint8)
                    img_mask = np.zeros( (h,w,1), dtype=np.uint8)
                else:
                    self.log_info (f'no faces found for {filepath.name}, copying without faces')
                    img_bgr = cv2_imread(filepath)
                    imagelib.normalize_channels(img_bgr, 3)
                    h,w,c = img_bgr.shape
                    img_mask = np.zeros( (h,w,1), dtype=img_bgr.dtype)

                cv2_imwrite (pf.output_filepath, img_bgr)
                cv2_imwrite (pf.output_mask_filepath, img_mask)

View file

@@ -316,7 +316,7 @@ def MergeMaskedFace (predictor_func, predictor_input_shape,
        if out_img is None:
            out_img = img_bgr.copy()

    return out_img, out_merging_mask_a

View file

@@ -43,10 +43,10 @@ class Screen(object):
    def toggle_show_checker_board(self):
        self.show_checker_board = not self.show_checker_board
        self.force_update = True

    def get_image(self):
        return self.image

    def set_image(self, img):
        if img is not self.image:
            self.force_update = True

View file

@@ -187,7 +187,7 @@ class ModelBase(object):
                self.random_flip = self.options.get('random_flip',True)
                self.random_src_flip = self.options.get('random_src_flip', False)
                self.random_dst_flip = self.options.get('random_dst_flip', True)

        self.on_initialize()

        self.options['batch_size'] = self.batch_size
@@ -299,7 +299,7 @@ class ModelBase(object):
    def ask_random_flip(self):
        default_random_flip = self.load_or_def_option('random_flip', True)
        self.options['random_flip'] = io.input_bool("Flip faces randomly", default_random_flip, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")

    def ask_random_src_flip(self):
        default_random_src_flip = self.load_or_def_option('random_src_flip', False)
        self.options['random_src_flip'] = io.input_bool("Flip SRC faces randomly", default_random_src_flip, help_message="Random horizontal flip SRC faceset. Covers more angles, but the face may look less naturally.")
@@ -545,7 +545,7 @@ class ModelBase(object):
    def get_summary_text(self):
        visible_options = self.options.copy()
        visible_options.update(self.options_show_override)

        ### Generate text summary of model hyperparameters
        # Find the longest key name and value string. Used as column widths.
        width_name = max([len(k) for k in visible_options.keys()] + [17]) + 1 # Single space buffer to left edge. Minimum of 17, the length of the longest static string used ("Current iteration")
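
Note: a small self-contained sketch of the column-width computation above, used to align the options summary. The option names here are invented.

visible_options = {'resolution': 256, 'face_type': 'wf', 'batch_size': 8}
width_name = max([len(k) for k in visible_options.keys()] + [17]) + 1  # min 17 for "Current iteration"
width_value = max(len(str(v)) for v in visible_options.values()) + 1
for k, v in visible_options.items():
    print(f'{k.rjust(width_name)} : {str(v).ljust(width_value)}')
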

View file

@@ -87,7 +87,7 @@ class AMPModel(ModelBase):
        d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 )
        self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2

        morph_factor = np.clip ( io.input_number ("Morph factor.", default_morph_factor, add_info="0.1 .. 0.5", help_message="The smaller the value, the more src-like facial expressions will appear. The larger the value, the less space there is to train a large dst faceset in the neural network. Typical fine value is 0.33"), 0.1, 0.5 )
        self.options['morph_factor'] = morph_factor
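
Note: a quick illustration of the clip-then-round-to-even step above. d_mask_dims is clipped to [16, 256], then bumped up to the next even number; the inputs are illustrative.

import numpy as np

for requested in (15, 33, 64, 300):
    clipped = int(np.clip(requested, 16, 256))
    print(requested, '->', clipped + clipped % 2)  # 15->16, 33->34, 64->64, 300->256
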
@@ -121,9 +121,9 @@ class AMPModel(ModelBase):
            self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best.")
            self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")

            self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly. Forces random_warp=N, random_flips=Y, gan_power=0.0, lr_dropout=N, uniform_yaw=Y")

        self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims'])
        self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
@@ -285,22 +285,22 @@ class AMPModel(ModelBase):
        d_dims = self.options['d_dims']
        d_mask_dims = self.options['d_mask_dims']
        morph_factor = self.options['morph_factor']
        pretrain = self.pretrain = self.options['pretrain']

        if self.pretrain_just_disabled:
            self.set_iter(0)

        self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power']
        random_warp = False if self.pretrain else self.options['random_warp']
        random_src_flip = self.random_src_flip if not self.pretrain else True
        random_dst_flip = self.random_dst_flip if not self.pretrain else True

        if self.pretrain:
            self.options_show_override['gan_power'] = 0.0
            self.options_show_override['random_warp'] = False
            self.options_show_override['lr_dropout'] = 'n'
            self.options_show_override['uniform_yaw'] = True

        masked_training = self.options['masked_training']
        ct_mode = self.options['ct_mode']
        if ct_mode == 'none':
@@ -351,7 +351,7 @@ class AMPModel(ModelBase):
            # Initialize optimizers
            lr = 5e-5
            lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] and not self.pretrain else 1.0

            clipnorm = 1.0 if self.options['clipgrad'] else 0.0

            self.all_weights = self.encoder.get_weights() + self.inter_src.get_weights() + self.inter_dst.get_weights() + self.decoder.get_weights()
@@ -407,7 +407,7 @@ class AMPModel(ModelBase):
                    # process model tensors
                    gpu_src_code = self.encoder (gpu_warped_src)
                    gpu_dst_code = self.encoder (gpu_warped_dst)

                    if pretrain:
                        gpu_src_inter_src_code = self.inter_src (gpu_src_code)
                        gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code)
@@ -454,7 +454,7 @@ class AMPModel(ModelBase):
                    gpu_pred_src_src_anti_masked = gpu_pred_src_src*(1.0-gpu_target_srcm_blur)
                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst
                    gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*(1.0-gpu_target_dstm_blur)

                    if resolution < 256:
                        gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                    else:
@@ -481,12 +481,12 @@ class AMPModel(ModelBase):
                        gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
                    else:
                        gpu_src_loss = gpu_dst_loss

                    gpu_src_losses += [gpu_src_loss]

                    if pretrain:
                        gpu_G_loss = gpu_dst_loss
                    else:
                        gpu_G_loss = gpu_src_loss + gpu_dst_loss

                    def DLossOnes(logits):
@@ -605,7 +605,7 @@ class AMPModel(ModelBase):
            if self.is_training and gan_power != 0 and model == self.GAN:
                if self.gan_model_changed:
                    do_init = True

            if not do_init:
                do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )

            if do_init:
@@ -654,7 +654,7 @@ class AMPModel(ModelBase):
            self.last_dst_samples_loss = []

        if self.pretrain_just_disabled:
            self.update_sample_for_preview(force_new=True)

    def dump_ckpt(self):
        tf = nn.tf
@@ -681,13 +681,13 @@ class AMPModel(ModelBase):
            tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
            tf.identity(gpu_pred_src_dst, name='out_celeb_face')
            tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')

            output_graph_def = tf.graph_util.convert_variables_to_constants(
                nn.tf_sess,
                tf.get_default_graph().as_graph_def(),
                ['out_face_mask','out_celeb_face','out_celeb_face_mask']
            )

            pb_filepath = self.get_strpath_storage_for_file('.pb')
            with tf.gfile.GFile(pb_filepath, "wb") as f:
                f.write(output_graph_def.SerializeToString())
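
Note: a hedged sketch of the TF1-style graph freeze used above, where named output tensors are baked into a .pb via convert_variables_to_constants. This uses tf.compat.v1 with invented tensor names and shapes; it is not DFL's exact export code.

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

with tf1.Session() as sess:
    x = tf1.placeholder(tf.float32, (None, 4), name='in_face')
    w = tf1.get_variable('w', shape=(4, 2))
    tf1.identity(tf1.matmul(x, w), name='out_face')  # name the tensor so it can be exported
    sess.run(tf1.global_variables_initializer())
    graph_def = tf1.graph_util.convert_variables_to_constants(
        sess, tf1.get_default_graph().as_graph_def(), ['out_face'])
    with tf1.gfile.GFile('model.pb', 'wb') as f:
        f.write(graph_def.SerializeToString())
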
@@ -791,7 +791,7 @@ class AMPModel(ModelBase):
    def predictor_func (self, face, morph_value):
        face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC")

        bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x, "NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face, morph_value) ]

        return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0]
@@ -802,9 +802,9 @@ class AMPModel(ModelBase):
        def predictor_morph(face):
            return self.predictor_func(face, morph_factor)

        import merger
        return predictor_morph, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay')

Model = AMPModel

View file

@@ -39,7 +39,7 @@ class QModel(ModelBase):
        mask_shape = nn.get4Dshape(resolution,resolution,1)

        self.model_filename_list = []

        model_archi = nn.DeepFakeArchi(resolution, opts='ud')

        with tf.device ('/CPU:0'):
@@ -94,7 +94,7 @@ class QModel(ModelBase):
            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_src_dst_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
                    batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
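
Note: a tiny illustration of the per-GPU batch split above, where each GPU takes a contiguous slice of the global batch. The sizes are illustrative.

gpu_count, bs_per_gpu = 2, 4
batch = list(range(gpu_count * bs_per_gpu))  # global batch of 8 samples
for gpu_id in range(gpu_count):
    batch_slice = slice(gpu_id * bs_per_gpu, (gpu_id + 1) * bs_per_gpu)
    print(gpu_id, batch[batch_slice])  # 0 -> [0, 1, 2, 3], 1 -> [4, 5, 6, 7]
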

View file

@@ -73,9 +73,9 @@ class SAEHDModel(ModelBase):
            resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16 and 32 for -d archi.")
            resolution = np.clip ( (resolution // 16) * 16, min_res, max_res)
            self.options['resolution'] = resolution

            self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face. 'Whole face' covers full area of face include forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower()

        while True:
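
Note: the snap-to-multiple-of-16 above in isolation; the bounds match the prompt text and the inputs are made up.

import numpy as np

min_res, max_res = 64, 640
for requested in (100, 250, 700):
    print(requested, '->', int(np.clip((requested // 16) * 16, min_res, max_res)))
    # 100 -> 96, 250 -> 240, 700 -> 640
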
@@ -136,11 +136,11 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
        self.options['eyes_mouth_prio'] = io.input_bool ("Eyes and mouth priority", default_eyes_mouth_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction. Also makes the detail of the teeth higher.')
        self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')

        default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
        default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8)
        default_gan_dims = self.options['gan_dims'] = self.load_or_def_option('gan_dims', 16)

        if self.is_first_run() or ask_override:
            self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.")
@@ -151,14 +151,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
            self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.")

            self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 1.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with lr_dropout(on) and random_warp(off), and don't disable. The higher the value, the higher the chances of artifacts. Typical fine value is 0.1"), 0.0, 1.0 )

            if self.options['gan_power'] != 0.0:
                gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher patch size, the higher the quality, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is resolution / 8." ), 3, 640 )
                self.options['gan_patch_size'] = gan_patch_size

                gan_dims = np.clip ( io.input_int("GAN dimensions", default_gan_dims, add_info="4-64", help_message="The dimensions of the GAN network. The higher dimensions, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is 16." ), 4, 64 )
                self.options['gan_dims'] = gan_dims

            if 'df' in self.options['archi']:
                self.options['true_face_power'] = np.clip ( io.input_number ("'True face' power.", default_true_face_power, add_info="0.0000 .. 1.0", help_message="Experimental option. Discriminates result face to be more like src face. Higher value - stronger discrimination. Typical value is 0.01 . Comparison - https://i.imgur.com/czScS9q.png"), 0.0, 1.0 )
            else:
@@ -174,7 +174,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
        if self.options['pretrain'] and self.get_pretraining_data_path() is None:
            raise Exception("pretraining_data_path is not defined")

        self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims'])
        self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
@@ -196,7 +196,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
        if 'eyes_prio' in self.options:
            self.options.pop('eyes_prio')

        eyes_mouth_prio = self.options['eyes_mouth_prio']

        archi_split = self.options['archi'].split('-')
@@ -205,7 +205,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
            archi_type, archi_opts = archi_split
        elif len(archi_split) == 1:
            archi_type, archi_opts = archi_split[0], None

        self.archi_type = archi_type

        ae_dims = self.options['ae_dims']
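
Note: the archi-string parsing above, in isolation. A string like "df-ud" splits into an architecture type plus optional flags; the examples are illustrative.

for archi in ('df', 'liae-ud', 'df-d'):
    parts = archi.split('-')
    archi_type, archi_opts = (parts[0], None) if len(parts) == 1 else parts
    print(archi, '->', archi_type, archi_opts)  # df -> df None, liae-ud -> liae ud, df-d -> df d
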
@@ -222,7 +222,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
        random_warp = False if self.pretrain else self.options['random_warp']
        random_src_flip = self.random_src_flip if not self.pretrain else True
        random_dst_flip = self.random_dst_flip if not self.pretrain else True

        if self.pretrain:
            self.options_show_override['gan_power'] = 0.0
            self.options_show_override['random_warp'] = False
@@ -235,8 +235,8 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
        ct_mode = self.options['ct_mode']
        if ct_mode == 'none':
            ct_mode = None

        models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
        models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device=='/CPU:0'
@@ -350,7 +350,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
            gpu_G_loss_gvs = []
            gpu_D_code_loss_gvs = []
            gpu_D_src_dst_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
                    with tf.device(f'/CPU:0'):
@@ -402,7 +402,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                        gpu_target_dstm_style_blur = gpu_target_dstm_blur # default style mask is 0.5 on boundary
                        gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2

                        gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
                        gpu_target_dst_style_masked = gpu_target_dst*gpu_target_dstm_style_blur
                        gpu_target_dst_style_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_style_blur)
@@ -497,14 +497,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                            gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + \
                                                     DLoss(gpu_pred_src_src_d2_ones, gpu_pred_src_src_d2))

                            if masked_training:
                                # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
                                gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
                                gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] )

                    gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ]
@@ -614,10 +614,10 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
            if do_init:
                model.init_weights()

        ###############
        # initializing sample generators
        if self.is_training:
            training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
@@ -658,16 +658,16 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
        if self.pretrain_just_disabled:
            self.update_sample_for_preview(force_new=True)

    def dump_ckpt(self):
        tf = nn.tf

        with tf.device ('/CPU:0'):
            warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
            warped_dst = tf.transpose(warped_dst, (0,3,1,2))

            if 'df' in self.archi_type:
                gpu_dst_code = self.inter(self.encoder(warped_dst))
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
@@ -682,20 +682,20 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
                _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)

            gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1))
            gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1))
            gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1))

        saver = tf.train.Saver()
        tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
        tf.identity(gpu_pred_src_dst, name='out_celeb_face')
        tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')

        saver.save(nn.tf_sess, self.get_strpath_storage_for_file('.ckpt') )

    #override
    def get_model_filename_list(self):
        return self.model_filename_list
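
Note: a small sketch of the NHWC/NCHW transposes used in dump_ckpt above; the shapes are illustrative.

import numpy as np

nhwc = np.zeros((1, 256, 256, 3), dtype=np.float32)  # batch, height, width, channels
nchw = nhwc.transpose(0, 3, 1, 2)  # -> (1, 3, 256, 256), the model's layout
back = nchw.transpose(0, 2, 3, 1)  # -> (1, 256, 256, 3), the image layout
print(nchw.shape, back.shape)
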

View file

@@ -95,7 +95,7 @@ class XSegModel(ModelBase):
                    gpu_pred_list.append(gpu_pred_t)

                    gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3])
                    gpu_losses += [gpu_loss]

                    gpu_loss_gvs += [ nn.gradients ( gpu_loss, self.model.get_weights() ) ]

View file

@@ -84,7 +84,7 @@ class PackedFaceset():
            of.write ( struct.pack("Q", offset) )
        of.seek(0,2)
        of.close()

        if io.input_bool(f"Delete original files?", True):
            for filename in io.progress_bar_generator(image_paths, "Deleting files"):
                Path(filename).unlink()
@@ -125,7 +125,7 @@ class PackedFaceset():
    def path_contains(samples_path):
        samples_dat_path = samples_path / packed_faceset_filename
        return samples_dat_path.exists()

    @staticmethod
    def load(samples_path):
        samples_dat_path = samples_path / packed_faceset_filename
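
Note: a hedged sketch of the fixed-width offset index written above, where each sample's byte offset is packed as an unsigned 64-bit integer ("Q"). The offsets are made up and this is not the full .dat layout.

import struct

offsets = [0, 4096, 12288]
blob = b''.join(struct.pack("Q", off) for off in offsets)  # 8 bytes per offset
assert struct.unpack(f"{len(offsets)}Q", blob) == (0, 4096, 12288)
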

View file

@@ -55,23 +55,23 @@ class Sample(object):
        self.face_type = face_type
        self.shape = shape
        self.landmarks = np.array(landmarks) if landmarks is not None else None

        if isinstance(seg_ie_polys, SegIEPolys):
            self.seg_ie_polys = seg_ie_polys
        else:
            self.seg_ie_polys = SegIEPolys.load(seg_ie_polys)

        self.xseg_mask = xseg_mask
        self.xseg_mask_compressed = xseg_mask_compressed

        if self.xseg_mask_compressed is None and self.xseg_mask is not None:
            xseg_mask = np.clip( imagelib.normalize_channels(xseg_mask, 1)*255, 0, 255 ).astype(np.uint8)
            ret, xseg_mask_compressed = cv2.imencode('.png', xseg_mask)
            if not ret:
                raise Exception("Sample(): unable to generate xseg_mask_compressed")

            self.xseg_mask_compressed = xseg_mask_compressed
            self.xseg_mask = None

        self.eyebrows_expand_mod = eyebrows_expand_mod if eyebrows_expand_mod is not None else 1.0
        self.source_filename = source_filename
        self.person_name = person_name
@@ -81,7 +81,7 @@ class Sample(object):
    def has_xseg_mask(self):
        return self.xseg_mask is not None or self.xseg_mask_compressed is not None

    def get_xseg_mask(self):
        if self.xseg_mask_compressed is not None:
            xseg_mask = cv2.imdecode(self.xseg_mask_compressed, cv2.IMREAD_UNCHANGED)
@@ -89,7 +89,7 @@ class Sample(object):
            xseg_mask = xseg_mask[...,None]
            return xseg_mask.astype(np.float32) / 255.0
        return self.xseg_mask

    def get_pitch_yaw_roll(self):
        if self.pitch_yaw_roll is None:
            self.pitch_yaw_roll = LandmarksProcessor.estimate_pitch_yaw_roll(self.landmarks, size=self.shape[1])
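
Note: a self-contained sketch of the mask compression round-trip above. A float mask in [0,1] is stored as lossless PNG bytes and decoded back on demand; the shapes are illustrative.

import cv2
import numpy as np

mask = np.zeros((64, 64), dtype=np.float32)
mask[16:48, 16:48] = 1.0
u8 = np.clip(mask * 255, 0, 255).astype(np.uint8)  # float [0,1] -> uint8 [0,255]
ok, compressed = cv2.imencode('.png', u8)          # lossless PNG bytes
assert ok
decoded = cv2.imdecode(compressed, cv2.IMREAD_UNCHANGED)
restored = decoded.astype(np.float32) / 255.0      # back to float [0,1]
assert np.allclose(restored, mask)
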

View file

@@ -29,7 +29,7 @@ class SampleGeneratorBase(object):
    def __next__(self):
        #implement your own iterator
        return None

    #overridable
    def is_initialized(self):
        return True

View file

@@ -27,14 +27,14 @@ class SampleGeneratorFace(SampleGeneratorBase):
                        output_sample_types=[],
                        uniform_yaw_distribution=False,
                        generators_count=4,
                        raise_on_no_data=True,
                        **kwargs):

        super().__init__(debug, batch_size)
        self.initialized = False
        self.sample_process_options = sample_process_options
        self.output_sample_types = output_sample_types

        if self.debug:
            self.generators_count = 1
        else:
@@ -42,16 +42,16 @@ class SampleGeneratorFace(SampleGeneratorBase):
        samples = SampleLoader.load (SampleType.FACE, samples_path)
        self.samples_len = len(samples)

        if self.samples_len == 0:
            if raise_on_no_data:
                raise ValueError('No training data provided.')
            else:
                return

        if uniform_yaw_distribution:
            samples_pyr = [ ( idx, sample.get_pitch_yaw_roll() ) for idx, sample in enumerate(samples) ]

            grads = 128
            # instead of math.pi / 2, use -1.2..+1.2, because the actual maximum yaw for 2DFAN landmarks is about -1.2..+1.2
            grads_space = np.linspace (-1.2, 1.2,grads)
@@ -70,9 +70,9 @@ class SampleGeneratorFace(SampleGeneratorBase):
                        yaw_samples += [ idx ]
                if len(yaw_samples) > 0:
                    yaws_sample_list[g] = yaw_samples

            yaws_sample_list = [ y for y in yaws_sample_list if y is not None ]

            index_host = mplib.Index2DHost( yaws_sample_list )
        else:
            index_host = mplib.IndexHost(self.samples_len)
@@ -89,31 +89,31 @@ class SampleGeneratorFace(SampleGeneratorBase):
        else:
            self.generators = [SubprocessGenerator ( self.batch_func, (samples, index_host.create_cli(), ct_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \
                               for i in range(self.generators_count) ]

            SubprocessGenerator.start_in_parallel( self.generators )

        self.generator_counter = -1
        self.initialized = True

    #overridable
    def is_initialized(self):
        return self.initialized

    def __iter__(self):
        return self

    def __next__(self):
        if not self.initialized:
            return []

        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators) ]
        return next(generator)

    def batch_func(self, param ):
        samples, index_host, ct_samples, ct_index_host = param
        bs = self.batch_size
        while True:
            batches = None
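
Note: the round-robin dispatch in __next__ above, reduced to plain iterators: each call pulls from the next generator in turn. The generators here are stand-ins for SubprocessGenerator.

import itertools

generators = [itertools.count(0), itertools.count(100)]
generator_counter = -1
for _ in range(4):
    generator_counter += 1
    gen = generators[generator_counter % len(generators)]
    print(next(gen))  # 0, 100, 1, 101
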

View file

@@ -198,22 +198,22 @@ class SampleGeneratorFaceCelebAMaskHQ(SampleGeneratorBase):
            if hat_path.exists():
                hat = cv2_imread(hat_path)[...,0:1].astype(np.float32) / 255.0
                mask *= (1-hat)
            #if neck_path is not None:
            #    neck_path = masks_path / neck_path
            #    if neck_path.exists():
            #        neck = cv2_imread(neck_path)[...,0:1].astype(np.float32) / 255.0
            #        mask = np.clip(mask+neck, 0, 1)
            warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range )
            img = cv2.resize( img, (resolution,resolution), cv2.INTER_LANCZOS4 )
            h, s, v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
            h = ( h + np.random.randint(360) ) % 360
            s = np.clip ( s + np.random.random()-0.5, 0, 1 )
            v = np.clip ( v + np.random.random()/2-0.25, 0, 1 )
            img = np.clip( cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR) , 0, 1 )
            if motion_blur is not None:
                chance, mb_max_size = motion_blur
                chance = np.clip(chance, 0, 100)
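Aside: the split/merge block above jitters hue, saturation, and value independently before converting back to BGR. A self-contained sketch of the same augmentation (the function name and jitter ranges are illustrative):

import numpy as np
import cv2

def random_hsv_jitter(img, rnd=np.random):
    # img is float32 BGR in [0,1]; OpenCV's float HSV uses H in [0,360).
    h, s, v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    h = (h + rnd.randint(360)) % 360                 # random hue rotation
    s = np.clip(s + rnd.random() - 0.5, 0, 1)        # +/-0.5 saturation shift
    v = np.clip(v + rnd.random() / 2 - 0.25, 0, 1)   # +/-0.25 value shift
    return np.clip(cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR), 0, 1)

img = np.random.rand(64, 64, 3).astype(np.float32)
out = random_hsv_jitter(img)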
@@ -226,7 +226,7 @@ class SampleGeneratorFaceCelebAMaskHQ(SampleGeneratorBase):
                    img = imagelib.LinearMotionBlur (img, mblur_rnd_kernel, mblur_rnd_deg )
            img = imagelib.warp_by_params (warp_params, img, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)
            if gaussian_blur is not None:
                chance, kernel_max_size = gaussian_blur
                chance = np.clip(chance, 0, 100)
@@ -236,16 +236,16 @@ class SampleGeneratorFaceCelebAMaskHQ(SampleGeneratorBase):
                if gblur_rnd_chance < chance:
                    img = cv2.GaussianBlur(img, (gblur_rnd_kernel,) *2 , 0)
            if random_bilinear_resize is not None:
                chance, max_size_per = random_bilinear_resize
                chance = np.clip(chance, 0, 100)
                pick_chance = np.random.randint(100)
                resize_to = resolution - int( np.random.rand()* int(resolution*(max_size_per/100.0)) )
                img = cv2.resize (img, (resize_to,resize_to), cv2.INTER_LINEAR )
                img = cv2.resize (img, (resolution,resolution), cv2.INTER_LINEAR )
            mask = cv2.resize( mask, (resolution,resolution), cv2.INTER_LANCZOS4 )[...,None]
            mask = imagelib.warp_by_params (warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)
            mask[mask < 0.5] = 0.0
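Aside: the random bilinear resize above degrades the image by shrinking it and scaling it back up, simulating lower-resolution source material. A minimal sketch (parameter names are illustrative); note the sketch passes interpolation as an explicit keyword, since cv2.resize's third positional argument is dst rather than the interpolation flag:

import numpy as np
import cv2

def random_downscale_upscale(img, max_size_per=25, rnd=np.random):
    # Shrink by a random fraction of up to max_size_per percent,
    # then restore the original size; detail lost in between stays lost.
    res = img.shape[0]
    resize_to = res - int(rnd.rand() * int(res * (max_size_per / 100.0)))
    small = cv2.resize(img, (resize_to, resize_to), interpolation=cv2.INTER_LINEAR)
    return cv2.resize(small, (res, res), interpolation=cv2.INTER_LINEAR)

img = np.random.rand(128, 128, 3).astype(np.float32)
degraded = random_downscale_upscale(img)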
@@ -255,10 +255,10 @@ class SampleGeneratorFaceCelebAMaskHQ(SampleGeneratorBase):
            if data_format == "NCHW":
                img = np.transpose(img, (2,0,1) )
                mask = np.transpose(mask, (2,0,1) )
            if batches is None:
                batches = [ [], [] ]
            batches[0].append ( img )
            batches[1].append ( mask )
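Aside: batches above is a list of per-output lists, one for images and one for masks, appended in lockstep so index i of each list belongs to the same sample. A small sketch of the pattern with toy arrays:

import numpy as np

batches = None
for img, mask in [(np.zeros((4, 4, 3)), np.zeros((4, 4, 1)))] * 2:
    if batches is None:
        batches = [[], []]          # one list per network output
    batches[0].append(img)
    batches[1].append(mask)

# Each list stacks into one array per output, e.g. for feeding a model.
imgs, masks = (np.array(b) for b in batches)
print(imgs.shape, masks.shape)  # (2, 4, 4, 3) (2, 4, 4, 1)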

View file

@@ -103,7 +103,7 @@ class Index2DHost():
            if not self.cq.empty():
                return self.cq.get()
            time.sleep(0.001)
'''
arg
output_sample_types = [

View file

@@ -151,7 +151,7 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
        mask[mask < 0.5] = 0.0
        mask[mask >= 0.5] = 1.0
        mask = np.clip(mask, 0, 1)
        if np.random.randint(2) == 0:
            # random face flare
            krn = np.random.randint( resolution//4, resolution )
@@ -168,13 +168,13 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
            img = imagelib.apply_random_hsv_shift(img, mask=sd.random_circle_faded ([resolution,resolution]))
        else:
            img = imagelib.apply_random_rgb_levels(img, mask=sd.random_circle_faded ([resolution,resolution]))
        if np.random.randint(2) == 0:
            img = imagelib.apply_random_sharpen( img, sharpen_chance, sharpen_kernel_max_size, mask=sd.random_circle_faded ([resolution,resolution]))
        else:
            img = imagelib.apply_random_motion_blur( img, motion_blur_chance, motion_blur_mb_max_size, mask=sd.random_circle_faded ([resolution,resolution]))
            img = imagelib.apply_random_gaussian_blur( img, gaussian_blur_chance, gaussian_blur_kernel_max_size, mask=sd.random_circle_faded ([resolution,resolution]))
        if np.random.randint(2) == 0:
            img = imagelib.apply_random_nearest_resize( img, random_bilinear_resize_chance, random_bilinear_resize_max_size_per, mask=sd.random_circle_faded ([resolution,resolution]))
        else:

View file

@@ -16,18 +16,18 @@ class SampleGeneratorImage(SampleGeneratorBase):
        self.output_sample_types = output_sample_types
        samples = SampleLoader.load (SampleType.IMAGE, samples_path)
        if len(samples) == 0:
            if raise_on_no_data:
                raise ValueError('No training data provided.')
            return
        self.generators = [ThisThreadGenerator ( self.batch_func, samples )] if self.debug else \
                          [SubprocessGenerator ( self.batch_func, samples )]
        self.generator_counter = -1
        self.initialized = True
    def __iter__(self):
        return self
@@ -38,7 +38,7 @@ class SampleGeneratorImage(SampleGeneratorBase):
    def batch_func(self, samples):
        samples_len = len(samples)
        idxs = [ *range(samples_len) ]
        shuffle_idxs = []
@@ -54,7 +54,7 @@ class SampleGeneratorImage(SampleGeneratorBase):
                idx = shuffle_idxs.pop()
                sample = samples[idx]
                x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug)
                if batches is None:
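Aside: the `x, = SampleProcessor.process([sample], ...)` line above relies on single-element tuple unpacking: process returns one result per input sample, and the trailing comma both extracts it and asserts there is exactly one. A tiny example of the idiom:

def process(samples):
    return [s * 2 for s in samples]  # one output per input sample

x, = process([21])      # x == 42
# x, = process([1, 2])  # would raise ValueError: too many values to unpack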

View file

@@ -84,7 +84,7 @@ class SampleLoader:
                      xseg_mask_compressed,
                      eyebrows_expand_mod,
                      source_filename ) = data
                    sample_list.append( Sample(filename=filename,
                                               sample_type=SampleType.FACE,
                                               face_type=FaceType.fromString (face_type),

View file

@@ -56,51 +56,51 @@ class SampleProcessor(object):
        sample_landmarks = sample.landmarks
        ct_sample_bgr = None
        h,w,c = sample_bgr.shape
        def get_full_face_mask():
            xseg_mask = sample.get_xseg_mask()
            if xseg_mask is not None:
                if xseg_mask.shape[0] != h or xseg_mask.shape[1] != w:
                    xseg_mask = cv2.resize(xseg_mask, (w,h), interpolation=cv2.INTER_CUBIC)
                    xseg_mask = imagelib.normalize_channels(xseg_mask, 1)
                return np.clip(xseg_mask, 0, 1)
            else:
                full_face_mask = LandmarksProcessor.get_image_hull_mask (sample_bgr.shape, sample_landmarks, eyebrows_expand_mod=sample.eyebrows_expand_mod )
                return np.clip(full_face_mask, 0, 1)
        def get_eyes_mask():
            eyes_mask = LandmarksProcessor.get_image_eye_mask (sample_bgr.shape, sample_landmarks)
            return np.clip(eyes_mask, 0, 1)
        def get_eyes_mouth_mask():
            eyes_mask = LandmarksProcessor.get_image_eye_mask (sample_bgr.shape, sample_landmarks)
            mouth_mask = LandmarksProcessor.get_image_mouth_mask (sample_bgr.shape, sample_landmarks)
            mask = eyes_mask + mouth_mask
            return np.clip(mask, 0, 1)
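Aside: get_full_face_mask above prefers a stored XSeg mask and falls back to a landmark hull, while get_eyes_mouth_mask unions two region masks and clips the sum. A sketch of that union-and-clip idiom on toy masks (the helper name here is illustrative):

import numpy as np

def union_masks(*masks):
    # Per-pixel union of soft masks: summing can exceed 1.0 where
    # regions overlap, so clip back into [0, 1].
    return np.clip(np.sum(masks, axis=0), 0, 1)

eyes = np.zeros((8, 8, 1), np.float32);  eyes[2:4, 2:6] = 1.0
mouth = np.zeros((8, 8, 1), np.float32); mouth[5:7, 3:5] = 1.0
combined = union_masks(eyes, mouth)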
        is_face_sample = sample_landmarks is not None
        if debug and is_face_sample:
            LandmarksProcessor.draw_landmarks (sample_bgr, sample_landmarks, (0, 1, 0))
        params_per_resolution = {}
        warp_rnd_state = np.random.RandomState (sample_rnd_seed-1)
        for opts in output_sample_types:
            resolution = opts.get('resolution', None)
            if resolution is None:
                continue
            params_per_resolution[resolution] = imagelib.gen_warp_params(resolution,
                                                                         sample_process_options.random_flip,
                                                                         rotation_range=sample_process_options.rotation_range,
                                                                         scale_range=sample_process_options.scale_range,
                                                                         tx_range=sample_process_options.tx_range,
                                                                         ty_range=sample_process_options.ty_range,
                                                                         rnd_state=warp_rnd_state)
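Aside: the loop above precomputes one warp-parameter set per requested resolution, all drawn from the same RandomState, so every output of a given sample sees a consistent random flip/rotation/scale. A sketch of the keying idea with stand-in parameters (fake_warp_params is a placeholder; imagelib.gen_warp_params itself is not reproduced):

import numpy as np

def fake_warp_params(resolution, rnd):
    # Stand-in for imagelib.gen_warp_params: one consistent draw
    # of augmentation parameters for this resolution.
    return {'resolution': resolution,
            'rotation': rnd.uniform(-10, 10),
            'flip': bool(rnd.randint(2))}

output_sample_types = [{'resolution': 128}, {'resolution': 256}, {}]
warp_rnd_state = np.random.RandomState(1234 - 1)

params_per_resolution = {}
for opts in output_sample_types:
    resolution = opts.get('resolution', None)
    if resolution is None:
        continue  # entries without a resolution get no warp params
    params_per_resolution[resolution] = fake_warp_params(resolution, warp_rnd_state)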
        outputs_sample = []
        for opts in output_sample_types:
            sample_type = opts.get('sample_type', SPST.NONE)
            channel_type = opts.get('channel_type', SPCT.NONE)
            resolution = opts.get('resolution', 0)
            nearest_resize_to = opts.get('nearest_resize_to', None)
            warp = opts.get('warp', False)
@@ -114,36 +114,36 @@ class SampleProcessor(object):
            normalize_tanh = opts.get('normalize_tanh', False)
            ct_mode = opts.get('ct_mode', None)
            data_format = opts.get('data_format', 'NHWC')
            if sample_type == SPST.FACE_MASK or sample_type == SPST.IMAGE:
                border_replicate = False
            elif sample_type == SPST.FACE_IMAGE:
                border_replicate = True
            border_replicate = opts.get('border_replicate', border_replicate)
            borderMode = cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT
            if sample_type == SPST.FACE_IMAGE or sample_type == SPST.FACE_MASK:
                if not is_face_sample:
                    raise ValueError("face_samples should be provided for sample_type FACE_*")
            if sample_type == SPST.FACE_IMAGE or sample_type == SPST.FACE_MASK:
                face_type = opts.get('face_type', None)
                face_mask_type = opts.get('face_mask_type', SPFMT.NONE)
                if face_type is None:
                    raise ValueError("face_type must be defined for face samples")
                if sample_type == SPST.FACE_MASK:
                    if face_mask_type == SPFMT.FULL_FACE:
                        img = get_full_face_mask()
                    elif face_mask_type == SPFMT.EYES:
                        img = get_eyes_mask()
                    elif face_mask_type == SPFMT.EYES_MOUTH:
                        mask = get_full_face_mask().copy()
                        mask[mask != 0.0] = 1.0
                        img = get_eyes_mouth_mask()*mask
                    else:
                        img = np.zeros ( sample_bgr.shape[0:2]+(1,), dtype=np.float32)
@@ -151,35 +151,35 @@ class SampleProcessor(object):
                    if sample_face_type == FaceType.MARK_ONLY:
                        mat = LandmarksProcessor.get_transform_mat (sample_landmarks, warp_resolution, face_type)
                        img = cv2.warpAffine( img, mat, (warp_resolution, warp_resolution), flags=cv2.INTER_LINEAR )
                        img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate, cv2_inter=cv2.INTER_LINEAR)
                        img = cv2.resize( img, (resolution,resolution), interpolation=cv2.INTER_LINEAR )
                    else:
                        if face_type != sample_face_type:
                            mat = LandmarksProcessor.get_transform_mat (sample_landmarks, resolution, face_type)
                            img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=borderMode, flags=cv2.INTER_LINEAR )
                        else:
                            if w != resolution:
                                img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_LINEAR )
                        img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate, cv2_inter=cv2.INTER_LINEAR)
                    if face_mask_type == SPFMT.EYES_MOUTH:
                        div = img.max()
                        if div != 0.0:
                            img = img / div # normalize to 1.0 after warp
                    if len(img.shape) == 2:
                        img = img[...,None]
                    if channel_type == SPCT.G:
                        out_sample = img.astype(np.float32)
                    else:
                        raise ValueError("only channel_type.G supported for the mask")
                elif sample_type == SPST.FACE_IMAGE:
                    img = sample_bgr
                    if random_rgb_levels:
                        random_mask = sd.random_circle_faded ([w,w], rnd_state=np.random.RandomState (sample_rnd_seed) ) if random_circle_mask else None
                        img = imagelib.apply_random_rgb_levels(img, mask=random_mask, rnd_state=np.random.RandomState (sample_rnd_seed) )
@@ -188,42 +188,42 @@ class SampleProcessor(object):
                        random_mask = sd.random_circle_faded ([w,w], rnd_state=np.random.RandomState (sample_rnd_seed+1) ) if random_circle_mask else None
                        img = imagelib.apply_random_hsv_shift(img, mask=random_mask, rnd_state=np.random.RandomState (sample_rnd_seed+1) )
                    if face_type != sample_face_type:
                        mat = LandmarksProcessor.get_transform_mat (sample_landmarks, resolution, face_type)
                        img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=borderMode, flags=cv2.INTER_CUBIC )
                    else:
                        if w != resolution:
                            img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_CUBIC )
                    # Apply random color transfer
                    if ct_mode is not None and ct_sample is not None:
                        if ct_sample_bgr is None:
                            ct_sample_bgr = ct_sample.load_bgr()
                        img = imagelib.color_transfer (ct_mode, img, cv2.resize( ct_sample_bgr, (resolution,resolution), interpolation=cv2.INTER_LINEAR ) )
                    img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate)
                    img = np.clip(img.astype(np.float32), 0, 1)
                    if motion_blur is not None:
                        random_mask = sd.random_circle_faded ([resolution,resolution], rnd_state=np.random.RandomState (sample_rnd_seed+2)) if random_circle_mask else None
                        img = imagelib.apply_random_motion_blur(img, *motion_blur, mask=random_mask, rnd_state=np.random.RandomState (sample_rnd_seed+2) )
                    if gaussian_blur is not None:
                        random_mask = sd.random_circle_faded ([resolution,resolution], rnd_state=np.random.RandomState (sample_rnd_seed+3)) if random_circle_mask else None
                        img = imagelib.apply_random_gaussian_blur(img, *gaussian_blur, mask=random_mask, rnd_state=np.random.RandomState (sample_rnd_seed+3) )
                    if random_bilinear_resize is not None:
                        random_mask = sd.random_circle_faded ([resolution,resolution], rnd_state=np.random.RandomState (sample_rnd_seed+4)) if random_circle_mask else None
                        img = imagelib.apply_random_bilinear_resize(img, *random_bilinear_resize, mask=random_mask, rnd_state=np.random.RandomState (sample_rnd_seed+4) )
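Aside: each augmentation above gets its own RandomState derived from sample_rnd_seed plus a distinct offset, so the whole chain is reproducible per sample while the individual effects stay decorrelated. A sketch of the seeding scheme (the augmentation bodies are placeholders, not the imagelib implementations):

import numpy as np

def add_noise(img, rnd):
    return np.clip(img + rnd.normal(0, 0.02, img.shape), 0, 1)

def scale_brightness(img, rnd):
    return np.clip(img * rnd.uniform(0.8, 1.2), 0, 1)

def apply_chain(img, sample_rnd_seed):
    # Offset k gives augmentation k its own reproducible stream;
    # re-running with the same seed replays identical augmentations.
    for k, aug in enumerate([add_noise, scale_brightness], start=2):
        img = aug(img, rnd=np.random.RandomState(sample_rnd_seed + k))
    return img

img = np.random.rand(32, 32, 3)
a = apply_chain(img, sample_rnd_seed=7)
b = apply_chain(img, sample_rnd_seed=7)
assert np.array_equal(a, b)  # deterministic per sample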
                    # Transform from BGR to desired channel_type
                    if channel_type == SPCT.BGR:
                        out_sample = img
@@ -235,22 +235,22 @@ class SampleProcessor(object):
                    # Final transformations
                    if nearest_resize_to is not None:
                        out_sample = cv2_resize(out_sample, (nearest_resize_to,nearest_resize_to), interpolation=cv2.INTER_NEAREST)
                    if not debug:
                        if normalize_tanh:
                            out_sample = np.clip (out_sample * 2.0 - 1.0, -1.0, 1.0)
                        if data_format == "NCHW":
                            out_sample = np.transpose(out_sample, (2,0,1) )
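Aside: the final transformations above map [0,1] pixels to [-1,1] when the network expects tanh-range inputs, and reorder HWC to CHW for NCHW data formats. Both in a few lines:

import numpy as np

out_sample = np.random.rand(64, 64, 3).astype(np.float32)  # HWC in [0,1]

normalized = np.clip(out_sample * 2.0 - 1.0, -1.0, 1.0)  # [0,1] -> [-1,1]
nchw = np.transpose(normalized, (2, 0, 1))               # HWC -> CHW
print(nchw.shape)  # (3, 64, 64)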
                elif sample_type == SPST.IMAGE:
                    img = sample_bgr
                    img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=True)
                    img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_CUBIC )
                    out_sample = img
                    if data_format == "NCHW":
                        out_sample = np.transpose(out_sample, (2,0,1) )
                elif sample_type == SPST.LANDMARKS_ARRAY:
                    l = sample_landmarks
                    l = np.concatenate ( [ np.expand_dims(l[:,0] / w,-1), np.expand_dims(l[:,1] / h,-1) ], -1 )
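Aside: the LANDMARKS_ARRAY branch above rescales pixel coordinates into [0,1] by dividing x by the image width and y by the height, then concatenates them back into an (N,2) array. A tiny worked sketch:

import numpy as np

w, h = 256, 192
l = np.array([[128.0, 96.0], [0.0, 191.0]])  # (N,2) landmarks in pixels

l_norm = np.concatenate([np.expand_dims(l[:, 0] / w, -1),
                         np.expand_dims(l[:, 1] / h, -1)], -1)
print(l_norm)  # [[0.5, 0.5], [0.0, ~0.995]]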
@@ -283,9 +283,9 @@ class SampleProcessor(object):
                        img = LandmarksProcessor.get_face_struct_mask (sample_bgr.shape, sample_landmarks, eyebrows_expand_mod=sample.eyebrows_expand_mod )
                    else:
                        img = LandmarksProcessor.get_face_struct_mask (sample_bgr.shape, sample_landmarks)
                close_sample = sample.close_target_list[ np.random.randint(0, len(sample.close_target_list)) ] if sample.close_target_list is not None else None
                close_sample_bgr = close_sample.load_bgr() if close_sample is not None else None