Compare commits

..

No commits in common. "master" and "DF.wf.288res.384.92.72.22" have entirely different histories.

103 changed files with 2251 additions and 5263 deletions

2
.vscode/launch.json vendored
View file

@ -12,7 +12,7 @@
"type": "python", "type": "python",
"request": "launch", "request": "launch",
"program": "${env:DFL_ROOT}\\main.py", "program": "${env:DFL_ROOT}\\main.py",
"python": "${env:PYTHONEXECUTABLE}", "pythonPath": "${env:PYTHONEXECUTABLE}",
"cwd": "${env:WORKSPACE}", "cwd": "${env:WORKSPACE}",
"console": "integratedTerminal", "console": "integratedTerminal",
"args": ["train", "args": ["train",

View file

@ -6,7 +6,6 @@ import cv2
import numpy as np import numpy as np
from core import imagelib from core import imagelib
from core.cv2ex import *
from core.imagelib import SegIEPolys from core.imagelib import SegIEPolys
from core.interact import interact as io from core.interact import interact as io
from core.structex import * from core.structex import *
@ -20,8 +19,7 @@ class DFLJPG(object):
self.length = 0 self.length = 0
self.chunks = [] self.chunks = []
self.dfl_dict = None self.dfl_dict = None
self.shape = None self.shape = (0,0,0)
self.img = None
@staticmethod @staticmethod
def load_raw(filename, loader_func=None): def load_raw(filename, loader_func=None):
@ -138,6 +136,8 @@ class DFLJPG(object):
if id == b"JFIF": if id == b"JFIF":
c, ver_major, ver_minor, units, Xdensity, Ydensity, Xthumbnail, Ythumbnail = struct_unpack (d, c, "=BBBHHBB") c, ver_major, ver_minor, units, Xdensity, Ydensity, Xthumbnail, Ythumbnail = struct_unpack (d, c, "=BBBHHBB")
#if units == 0:
# inst.shape = (Ydensity, Xdensity, 3)
else: else:
raise Exception("Unknown jpeg ID: %s" % (id) ) raise Exception("Unknown jpeg ID: %s" % (id) )
elif chunk['name'] == 'SOF0' or chunk['name'] == 'SOF2': elif chunk['name'] == 'SOF0' or chunk['name'] == 'SOF2':
@ -205,16 +205,7 @@ class DFLJPG(object):
return data return data
def get_img(self):
if self.img is None:
self.img = cv2_imread(self.filename)
return self.img
def get_shape(self): def get_shape(self):
if self.shape is None:
img = self.get_img()
if img is not None:
self.shape = img.shape
return self.shape return self.shape
def get_height(self): def get_height(self):
@ -281,13 +272,6 @@ class DFLJPG(object):
def has_xseg_mask(self): def has_xseg_mask(self):
return self.dfl_dict.get('xseg_mask',None) is not None return self.dfl_dict.get('xseg_mask',None) is not None
def get_xseg_mask_compressed(self):
mask_buf = self.dfl_dict.get('xseg_mask',None)
if mask_buf is None:
return None
return mask_buf
def get_xseg_mask(self): def get_xseg_mask(self):
mask_buf = self.dfl_dict.get('xseg_mask',None) mask_buf = self.dfl_dict.get('xseg_mask',None)
if mask_buf is None: if mask_buf is None:
@ -308,7 +292,7 @@ class DFLJPG(object):
mask_a = imagelib.normalize_channels(mask_a, 1) mask_a = imagelib.normalize_channels(mask_a, 1)
img_data = np.clip( mask_a*255, 0, 255 ).astype(np.uint8) img_data = np.clip( mask_a*255, 0, 255 ).astype(np.uint8)
data_max_len = 50000 data_max_len = 4096
ret, buf = cv2.imencode('.png', img_data) ret, buf = cv2.imencode('.png', img_data)

265
README.md
View file

@ -1,237 +1,176 @@
<table align="center" border="0"> <table align="center" border="0"><tr><td align="center" width="9999">
<tr><td colspan=2 align="center">
# DeepFaceLab # DeepFaceLab
### the leading software for creating deepfakes
<a href="https://arxiv.org/abs/2005.05535"> <img src="doc/DFL_welcome.png" align="center">
<img src="https://static.arxiv.org/static/browse/0.3.0/images/icons/favicon.ico" width=14></img>
https://arxiv.org/abs/2005.05535</a>
</td></tr> </td></tr>
<tr><td colspan=2 align="center"> <tr><td align="center" width="9999">
<p align="center"> <p align="center">
![](doc/logo_tensorflow.png)
![](doc/logo_cuda.png) ![](doc/logo_cuda.png)
![](doc/logo_directx.png) ![](doc/logo_tensorflow.png)
![](doc/logo_python.png)
</p> </p>
More than 95% of deepfake videos are created with DeepFaceLab.
DeepFaceLab is used by such popular youtube channels as DeepFaceLab is used by such popular youtube channels as
|![](doc/tiktok_icon.png) [deeptomcruise](https://www.tiktok.com/@deeptomcruise)|![](doc/tiktok_icon.png) [1facerussia](https://www.tiktok.com/@1facerussia)|![](doc/tiktok_icon.png) [arnoldschwarzneggar](https://www.tiktok.com/@arnoldschwarzneggar) |![](doc/youtube_icon.png) [Ctrl Shift Face](https://www.youtube.com/channel/UCKpH0CKltc73e4wh0_pgL3g)|![](doc/youtube_icon.png) [VFXChris Ume](https://www.youtube.com/channel/UCGf4OlX_aTt8DlrgiH3jN3g/videos)|
|---|---|---| |---|---|
|![](doc/tiktok_icon.png) [mariahcareyathome?](https://www.tiktok.com/@mariahcareyathome?)|![](doc/tiktok_icon.png) [diepnep](https://www.tiktok.com/@diepnep)|![](doc/tiktok_icon.png) [mr__heisenberg](https://www.tiktok.com/@mr__heisenberg)|![](doc/tiktok_icon.png) [deepcaprio](https://www.tiktok.com/@deepcaprio) |![](doc/youtube_icon.png) [Sham00k](https://www.youtube.com/channel/UCZXbWcv7fSZFTAZV4beckyw/videos)|![](doc/youtube_icon.png) [Collider videos](https://www.youtube.com/watch?v=A91P2qtPT54&list=PLayt6616lBclvOprvrC8qKGCO-mAhPRux)|![](doc/youtube_icon.png) [iFake](https://www.youtube.com/channel/UCC0lK2Zo2BMXX-k1Ks0r7dg/videos)|![](doc/youtube_icon.png) [NextFace](https://www.youtube.com/channel/UCFh3gL0a8BS21g-DHvXZEeQ/videos)|
|---|---|---|---| |---|---|---|---|
|![](doc/youtube_icon.png) [VFXChris Ume](https://www.youtube.com/channel/UCGf4OlX_aTt8DlrgiH3jN3g/videos)|![](doc/youtube_icon.png) [Sham00k](https://www.youtube.com/channel/UCZXbWcv7fSZFTAZV4beckyw/videos)|
|---|---|
|![](doc/youtube_icon.png) [Collider videos](https://www.youtube.com/watch?v=A91P2qtPT54&list=PLayt6616lBclvOprvrC8qKGCO-mAhPRux)|![](doc/youtube_icon.png) [iFake](https://www.youtube.com/channel/UCC0lK2Zo2BMXX-k1Ks0r7dg/videos)|![](doc/youtube_icon.png) [NextFace](https://www.youtube.com/channel/UCFh3gL0a8BS21g-DHvXZEeQ/videos)|
|---|---|---|
|![](doc/youtube_icon.png) [Futuring Machine](https://www.youtube.com/channel/UCC5BbFxqLQgfnWPhprmQLVg)|![](doc/youtube_icon.png) [RepresentUS](https://www.youtube.com/channel/UCRzgK52MmetD9aG8pDOID3g)|![](doc/youtube_icon.png) [Corridor Crew](https://www.youtube.com/c/corridorcrew/videos)|
|---|---|---|
|![](doc/youtube_icon.png) [DeepFaker](https://www.youtube.com/channel/UCkHecfDTcSazNZSKPEhtPVQ)|![](doc/youtube_icon.png) [DeepFakes in movie](https://www.youtube.com/c/DeepFakesinmovie/videos)|
|---|---|
|![](doc/youtube_icon.png) [DeepFakeCreator](https://www.youtube.com/channel/UCkNFhcYNLQ5hr6A6lZ56mKA)|![](doc/youtube_icon.png) [Jarkan](https://www.youtube.com/user/Jarkancio/videos)|
|---|---|
</td></tr> </td></tr>
<tr><td align="center" width="9999">
<tr><td colspan=2 align="center">
# What can I do using DeepFaceLab? # What can I do using DeepFaceLab?
</td></tr> </td></tr>
<tr><td colspan=2 align="center"> <tr><td align="center" width="9999">
## Replace the face ## Replace the face
<img src="doc/replace_the_face.jpg" align="center"> <img src="doc/replace_the_face.png" align="center">
</td></tr> </td></tr>
<tr><td align="center" width="9999">
<tr><td colspan=2 align="center">
## De-age the face
</td></tr>
<tr><td align="center" width="50%">
<img src="doc/deage_0_1.jpg" align="center">
</td>
<td align="center" width="50%">
<img src="doc/deage_0_2.jpg" align="center">
</td></tr>
<tr><td colspan=2 align="center">
![](doc/youtube_icon.png) https://www.youtube.com/watch?v=Ddx5B-84ebo
</td></tr>
<tr><td colspan=2 align="center">
## Replace the head ## Replace the head
</td></tr>
<tr><td align="center" width="50%"> <table align="center" border="0"><tr>
<img src="doc/head_replace_1_1.jpg" align="center"> <td align="center" width="9999">
<img src="doc/head_replace_1.jpg" align="center">
</td> </td>
<td align="center" width="50%">
<img src="doc/head_replace_1_2.jpg" align="center"> <td align="center" width="9999">
<img src="doc/head_replace_2.jpg" align="center">
</td>
</table>
![](doc/youtube_icon.png) https://www.youtube.com/watch?v=xr5FHd0AdlQ
</td></tr> </td></tr>
<tr><td align="center" width="9999">
<tr><td colspan=2 align="center"> ## Change the lip movement of politicians*
![](doc/youtube_icon.png) https://www.youtube.com/watch?v=RTjgkhMugVw <img src="doc/political_speech.jpg" align="center">
![](doc/youtube_icon.png) https://www.youtube.com/watch?v=2Z1oA3GYPaY
\* also requires a skill in video editors such as *Adobe After Effects* or *Davinci Resolve*
</td></tr>
<tr><td align="center" width="9999">
# Deepfake native resolution progress
</td></tr> </td></tr>
<tr><td align="center" width="9999">
<tr><td colspan=2 align="center">
# Native resolution progress
</td></tr>
<tr><td colspan=2 align="center">
<img src="doc/deepfake_progress.png" align="center"> <img src="doc/deepfake_progress.png" align="center">
</td></tr> </td></tr>
<tr><td colspan=2 align="center"> <tr><td align="center" width="9999">
<img src="doc/make_everything_ok.png" align="center"> <img src="doc/make_everything_ok.png" align="center">
Unfortunately, there is no "make everything ok" button in DeepFaceLab. You should spend time studying the workflow and growing your skills. A skill in programs such as *AfterEffects* or *Davinci Resolve* is also desirable. Unfortunately, there is no "make everything ok" button in DeepFaceLab. You should spend time studying the workflow and growing your skills. A skill in programs such as *AfterEffects* or *Davince Resolve* is also desirable.
</td></tr> </td></tr>
<tr><td colspan=2 align="center"> <tr><td align="center" width="9999">
## Mini tutorial
<a href="https://www.youtube.com/watch?v=kOIMXt8KK8M">
<img src="doc/mini_tutorial.jpg" align="center">
</a>
</td></tr>
<tr><td colspan=2 align="center">
## Releases ## Releases
</td></tr> ||||
|---|---|---|
<tr><td align="right"> |Windows|[github releases](https://github.com/iperov/DeepFaceLab/releases)|Direct download|
<a href="https://tinyurl.com/2p9cvt25">Windows (magnet link)</a> ||[Google drive](https://drive.google.com/open?id=1BCFK_L7lPNwMbEQ_kFPqPpDdFEOd_Dci)|if the download quota is exceeded, add the file to your own google drive and download from it|
</td><td align="center">Last release. Use torrent client to download.</td></tr> |Google Colab|[github](https://github.com/chervonij/DFL-Colab)|by @chervonij . You can train fakes for free using Google Colab.|
|CentOS Linux|[github](https://github.com/elemantalcode/dfl)|by @elemantalcode|
<tr><td align="right"> |Linux|[github](https://github.com/lbfs/DeepFaceLab_Linux)|by @lbfs |
<a href="https://mega.nz/folder/Po0nGQrA#dbbttiNWojCt8jzD4xYaPw">Windows (Mega.nz)</a> ||||
</td><td align="center">Contains new and prev releases.</td></tr>
<tr><td align="right">
<a href="https://disk.yandex.ru/d/7i5XTKIKVg5UUg">Windows (yandex.ru)</a>
</td><td align="center">Contains new and prev releases.</td></tr>
<tr><td align="right">
<a href="https://github.com/nagadit/DeepFaceLab_Linux">Linux (github)</a>
</td><td align="center">by @nagadit</td></tr>
<tr><td align="right">
<a href="https://github.com/elemantalcode/dfl">CentOS Linux (github)</a>
</td><td align="center">May be outdated. By @elemantalcode</td></tr>
</table>
<table align="center" border="0">
<tr><td colspan=2 align="center">
### Communication groups
</td></tr> </td></tr>
<tr><td align="right"> <tr><td align="center" width="9999">
<a href="https://discord.gg/rxa7h9M6rH">Discord</a>
</td><td align="center">Official discord channel. English / Russian.</td></tr>
<tr><td colspan=2 align="center"> ## Links
## Related works
||||
|---|---|---|
|Guides and tutorials|||
||[DeepFaceLab guide](https://mrdeepfakes.com/forums/thread-guide-deepfacelab-2-0-explained-and-tutorials-recommended)|Main guide|
||[Faceset creation guide](https://mrdeepfakes.com/forums/thread-guide-celebrity-faceset-dataset-creation-how-to-create-celebrity-facesets)|How to create the right faceset |
||[Google Colab guide](https://mrdeepfakes.com/forums/thread-guide-deepfacelab-google-colab-tutorial)|Guide how to train the fake on Google Colab|
||[Compositing](https://mrdeepfakes.com/forums/thread-deepfacelab-2-0-compositing-in-davinci-resolve-vegas-pro-and-after-effects)|To achieve the highest quality, compose deepfake manually in video editors such as Davince Resolve or Adobe AfterEffects|
||[Discussion and suggestions](https://mrdeepfakes.com/forums/thread-deepfacelab-2-0-discussion-tips-suggestions)||
||||
|Supplementary material|||
||[Ready to work facesets](https://mrdeepfakes.com/forums/forum-celebrity-facesets)|Celebrity facesets made by community|
||[Pretrained models](https://mrdeepfakes.com/forums/forum-celebrity-facesets)|Use pretrained models made by community to speed up training|
||||
|Communication groups|||
||[telegram (English / Russian)](https://t.me/DeepFaceLab_official)|Don't forget to hide your phone number|
||[telegram (English only)](https://t.me/DeepFaceLab_official_en)|Don't forget to hide your phone number|
||[Русское сообщество](https://mrdeepfakes.com/forums/forum-russian-community)||
||[mrdeepfakes](https://mrdeepfakes.com/forums/)|the biggest NSFW English community|
||[reddit r/GifFakes/](https://www.reddit.com/r/GifFakes/new/)|Post your deepfakes there !|
||[reddit r/SFWdeepfakes/](https://www.reddit.com/r/SFWdeepfakes/new/)|Post your deepfakes there !|
||QQ 951138799|中文 Chinese QQ group for ML/AI experts|
||[deepfaker.xyz](https://www.deepfaker.xyz/)|中文学习站(非官方|
</td></tr> </td></tr>
<tr><td align="right"> <tr><td align="center" width="9999">
<a href="https://github.com/iperov/DeepFaceLive">DeepFaceLive</a>
</td><td align="center">Real-time face swap for PC streaming or video calls</td></tr>
</td></tr>
</table>
<table align="center" border="0">
<tr><td colspan=2 align="center">
## How I can help the project? ## How I can help the project?
|||
|---|---|
|Donate|Sponsor deepfake research and DeepFaceLab development.|
||[Donate via Paypal](https://www.paypal.com/cgi-bin/webscr?cmd=_donations&business=lepersorium@gmail.com&lc=US&no_note=0&item_name=Support+DeepFaceLab&cn=&curency_code=USD&bn=PP-DonationsBF:btn_donateCC_LG.gif:NonHosted)
||[Donate via Yandex.Money](https://money.yandex.ru/to/41001142318065)|
||bitcoin:bc1qkhh7h0gwwhxgg6h6gpllfgstkd645fefrd5s6z|
|Alipay 捐款|![](doc/Alipay_donation.jpg)|
|||
|Last donations|20$ ( 飛星工作室 )|
||100$ ( Peter S. )|
||50$ ( John Lee )|
|||
|Collect facesets|You can collect faceset of any celebrity that can be used in DeepFaceLab and share it [in the community](https://mrdeepfakes.com/forums/forum-celebrity-facesets)|
|||
|Star this repo|Register github account and push "Star" button.|
</td></tr> </td></tr>
<tr><td colspan=2 align="center"> <tr><td align="center" width="9999">
### Star this repo
</td></tr>
<tr><td colspan=2 align="center">
Register github account and push "Star" button.
</td></tr>
</table>
<table align="center" border="0">
<tr><td colspan=2 align="center">
## Meme zone ## Meme zone
<p align="center">
![](doc/meme1.jpg)
![](doc/meme2.jpg)
</p>
</td></tr> </td></tr>
<tr><td align="center" width="9999">
<tr><td align="center" width="50%"> <sub>#deepfacelab #deepfakes #faceswap #face-swap #deep-learning #deeplearning #deep-neural-networks #deepface #deep-face-swap #fakeapp #fake-app #neural-networks #neural-nets #tensorflow #cuda #nvidia</sub>
<img src="doc/meme1.jpg" align="center">
</td>
<td align="center" width="50%">
<img src="doc/meme2.jpg" align="center">
</td></tr> </td></tr>
<tr><td colspan=2 align="center">
<sub>#deepfacelab #faceswap #face-swap #deep-learning #deeplearning #deep-neural-networks #deepface #deep-face-swap #neural-networks #neural-nets #tensorflow #cuda #nvidia</sub>
</td></tr>
</table> </table>

View file

@ -17,10 +17,6 @@ class QIconDB():
QIconDB.poly_type_exclude = QIcon ( str(icon_path / 'poly_type_exclude.png') ) QIconDB.poly_type_exclude = QIcon ( str(icon_path / 'poly_type_exclude.png') )
QIconDB.left = QIcon ( str(icon_path / 'left.png') ) QIconDB.left = QIcon ( str(icon_path / 'left.png') )
QIconDB.right = QIcon ( str(icon_path / 'right.png') ) QIconDB.right = QIcon ( str(icon_path / 'right.png') )
QIconDB.trashcan = QIcon ( str(icon_path / 'trashcan.png') )
QIconDB.pt_edit_mode = QIcon ( str(icon_path / 'pt_edit_mode.png') ) QIconDB.pt_edit_mode = QIcon ( str(icon_path / 'pt_edit_mode.png') )
QIconDB.view_lock_center = QIcon ( str(icon_path / 'view_lock_center.png') )
QIconDB.view_baked = QIcon ( str(icon_path / 'view_baked.png') ) QIconDB.view_baked = QIcon ( str(icon_path / 'view_baked.png') )
QIconDB.view_xseg = QIcon ( str(icon_path / 'view_xseg.png') ) QIconDB.view_xseg = QIcon ( str(icon_path / 'view_xseg.png') )
QIconDB.view_xseg_overlay = QIcon ( str(icon_path / 'view_xseg_overlay.png') )

View file

@ -1,8 +0,0 @@
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *
class QImageDB():
@staticmethod
def initialize(image_path):
QImageDB.intro = QImage ( str(image_path / 'intro.png') )

View file

@ -35,11 +35,6 @@ class QStringDB():
'zh' : '查看导入后的XSeg遮罩', 'zh' : '查看导入后的XSeg遮罩',
}[lang] }[lang]
QStringDB.btn_view_xseg_overlay_mask_tip = { 'en' : 'View trained XSeg mask overlay face',
'ru' : 'Посмотреть тренированную XSeg маску поверх лица',
'zh' : '查看导入后的XSeg遮罩于脸上方',
}[lang]
QStringDB.btn_poly_type_include_tip = { 'en' : 'Poly include mode', QStringDB.btn_poly_type_include_tip = { 'en' : 'Poly include mode',
'ru' : 'Режим полигонов - включение', 'ru' : 'Режим полигонов - включение',
'zh' : '包含选区模式', 'zh' : '包含选区模式',
@ -65,17 +60,11 @@ class QStringDB():
'zh' : '删除选区', 'zh' : '删除选区',
}[lang] }[lang]
QStringDB.btn_pt_edit_mode_tip = { 'en' : 'Add/delete point mode ( HOLD CTRL )', QStringDB.btn_pt_edit_mode_tip = { 'en' : 'Edit point mode ( HOLD CTRL )',
'ru' : 'Режим добавления/удаления точек ( удерживайте CTRL )', 'ru' : 'Режим правки точек',
'zh' : '加/删除模式 ( 按住CTRL )', 'zh' : '编辑点模式 ( 按住CTRL )',
}[lang] }[lang]
QStringDB.btn_view_lock_center_tip = { 'en' : 'Lock cursor at the center ( HOLD SHIFT )',
'ru' : 'Заблокировать курсор в центре ( удерживайте SHIFT )',
'zh' : '将光标锁定在中心 ( 按住SHIFT )',
}[lang]
QStringDB.btn_prev_image_tip = { 'en' : 'Save and Prev image\nHold SHIFT : accelerate\nHold CTRL : skip non masked\n', QStringDB.btn_prev_image_tip = { 'en' : 'Save and Prev image\nHold SHIFT : accelerate\nHold CTRL : skip non masked\n',
'ru' : 'Сохранить и предыдущее изображение\nУдерживать SHIFT : ускорить\nУдерживать CTRL : пропустить неразмеченные\n', 'ru' : 'Сохранить и предыдущее изображение\nУдерживать SHIFT : ускорить\nУдерживать CTRL : пропустить неразмеченные\n',
'zh' : '保存并转到上一张图片\n按住SHIFT : 加快\n按住CTRL : 跳过未标记的\n', 'zh' : '保存并转到上一张图片\n按住SHIFT : 加快\n按住CTRL : 跳过未标记的\n',
@ -85,18 +74,4 @@ class QStringDB():
'zh' : '保存并转到下一张图片\n按住SHIFT : 加快\n按住CTRL : 跳过未标记的\n', 'zh' : '保存并转到下一张图片\n按住SHIFT : 加快\n按住CTRL : 跳过未标记的\n',
}[lang] }[lang]
QStringDB.btn_delete_image_tip = { 'en' : 'Move to _trash and Next image\n',
'ru' : 'Переместить в _trash и следующее изображение\n',
'zh' : '移至_trash转到下一张图片 ',
}[lang]
QStringDB.loading_tip = {'en' : 'Loading',
'ru' : 'Загрузка',
'zh' : '正在载入',
}[lang]
QStringDB.labeled_tip = {'en' : 'labeled',
'ru' : 'размечено',
'zh' : '标记的',
}[lang]

View file

@ -16,18 +16,18 @@ from PyQt5.QtCore import *
from PyQt5.QtGui import * from PyQt5.QtGui import *
from PyQt5.QtWidgets import * from PyQt5.QtWidgets import *
from core import imagelib, pathex from core import pathex
from core.cv2ex import * from core.cv2ex import *
from core import imagelib
from core.imagelib import SegIEPoly, SegIEPolys, SegIEPolyType, sd from core.imagelib import SegIEPoly, SegIEPolys, SegIEPolyType, sd
from core.qtex import * from core.qtex import *
from DFLIMG import * from DFLIMG import *
from localization import StringsDB, system_language from localization import StringsDB, system_language
from samplelib import PackedFaceset
from .QCursorDB import QCursorDB from .QCursorDB import QCursorDB
from .QIconDB import QIconDB from .QIconDB import QIconDB
from .QStringDB import QStringDB from .QStringDB import QStringDB
from .QImageDB import QImageDB
class OpMode(IntEnum): class OpMode(IntEnum):
NONE = 0 NONE = 0
@ -45,10 +45,6 @@ class DragType(IntEnum):
IMAGE_LOOK = 1 IMAGE_LOOK = 1
POLY_PT = 2 POLY_PT = 2
class ViewLock(IntEnum):
NONE = 0
CENTER = 1
class QUIConfig(): class QUIConfig():
@staticmethod @staticmethod
def initialize(icon_size = 48, icon_spacer_size=16, preview_bar_icon_size=64): def initialize(icon_size = 48, icon_spacer_size=16, preview_bar_icon_size=64):
@ -70,24 +66,12 @@ class ImagePreviewSequenceBar(QFrame):
main_frame_l_cont_hl = QGridLayout() main_frame_l_cont_hl = QGridLayout()
main_frame_l_cont_hl.setContentsMargins(0,0,0,0) main_frame_l_cont_hl.setContentsMargins(0,0,0,0)
#main_frame_l_cont_hl.setSpacing(0)
for i in range(len(self.image_containers)): for i in range(len(self.image_containers)):
q_label = self.image_containers[i] q_label = self.image_containers[i]
q_label.setScaledContents(True) q_label.setScaledContents(True)
if i == preview_images_count//2: q_label.setMinimumSize(icon_size, icon_size )
q_label.setMinimumSize(icon_size+16, icon_size+16 ) q_label.setSizePolicy (QSizePolicy.Ignored, QSizePolicy.Ignored)
q_label.setMaximumSize(icon_size+16, icon_size+16 )
else:
q_label.setMinimumSize(icon_size, icon_size )
q_label.setMaximumSize(icon_size, icon_size )
opacity_effect = QGraphicsOpacityEffect()
opacity_effect.setOpacity(0.5)
q_label.setGraphicsEffect(opacity_effect)
q_label.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed)
main_frame_l_cont_hl.addWidget (q_label, 0, i) main_frame_l_cont_hl.addWidget (q_label, 0, i)
@ -101,33 +85,39 @@ class ImagePreviewSequenceBar(QFrame):
def get_preview_images_count(self): def get_preview_images_count(self):
return self.preview_images_count return self.preview_images_count
def update_images(self, prev_imgs=None, next_imgs=None): def update_images(self, prev_q_imgs=None, next_q_imgs=None):
# Fix arrays # Fix arrays
if prev_imgs is None: if prev_q_imgs is None:
prev_imgs = [] prev_q_imgs = []
prev_img_conts_len = len(self.prev_img_conts) prev_img_conts_len = len(self.prev_img_conts)
prev_q_imgs_len = len(prev_imgs) prev_q_imgs_len = len(prev_q_imgs)
if prev_q_imgs_len < prev_img_conts_len: if prev_q_imgs_len < prev_img_conts_len:
for i in range ( prev_img_conts_len - prev_q_imgs_len ): for i in range ( prev_img_conts_len - prev_q_imgs_len ):
prev_imgs.append(None) prev_q_imgs.append(None)
elif prev_q_imgs_len > prev_img_conts_len: elif prev_q_imgs_len > prev_img_conts_len:
prev_imgs = prev_imgs[:prev_img_conts_len] prev_q_imgs = prev_q_imgs[:prev_img_conts_len]
if next_imgs is None: if next_q_imgs is None:
next_imgs = [] next_q_imgs = []
next_img_conts_len = len(self.next_img_conts) next_img_conts_len = len(self.next_img_conts)
next_q_imgs_len = len(next_imgs) next_q_imgs_len = len(next_q_imgs)
if next_q_imgs_len < next_img_conts_len: if next_q_imgs_len < next_img_conts_len:
for i in range ( next_img_conts_len - next_q_imgs_len ): for i in range ( next_img_conts_len - next_q_imgs_len ):
next_imgs.append(None) next_q_imgs.append(None)
elif next_q_imgs_len > next_img_conts_len: elif next_q_imgs_len > next_img_conts_len:
next_imgs = next_imgs[:next_img_conts_len] next_q_imgs = next_q_imgs[:next_img_conts_len]
for i,img in enumerate(prev_imgs): for i,q_img in enumerate(prev_q_imgs):
self.prev_img_conts[i].setPixmap( QPixmap.fromImage( QImage_from_np(img) ) if img is not None else self.black_q_pixmap ) if q_img is None:
self.prev_img_conts[i].setPixmap( self.black_q_pixmap )
else:
self.prev_img_conts[i].setPixmap( QPixmap.fromImage(q_img) )
for i,img in enumerate(next_imgs): for i,q_img in enumerate(next_q_imgs):
self.next_img_conts[i].setPixmap( QPixmap.fromImage( QImage_from_np(img) ) if img is not None else self.black_q_pixmap ) if q_img is None:
self.next_img_conts[i].setPixmap( self.black_q_pixmap )
else:
self.next_img_conts[i].setPixmap( QPixmap.fromImage(q_img) )
class ColorScheme(): class ColorScheme():
def __init__(self, unselected_color, selected_color, outline_color, outline_width, pt_outline_color, cross_cursor): def __init__(self, unselected_color, selected_color, outline_color, outline_width, pt_outline_color, cross_cursor):
@ -197,7 +187,6 @@ class QCanvasControlsLeftBar(QFrame):
self.btn_pt_edit_mode_act = QActionEx( QIconDB.pt_edit_mode, QStringDB.btn_pt_edit_mode_tip, shortcut_in_tooltip=True, is_checkable=True) self.btn_pt_edit_mode_act = QActionEx( QIconDB.pt_edit_mode, QStringDB.btn_pt_edit_mode_tip, shortcut_in_tooltip=True, is_checkable=True)
btn_pt_edit_mode.setDefaultAction(self.btn_pt_edit_mode_act) btn_pt_edit_mode.setDefaultAction(self.btn_pt_edit_mode_act)
btn_pt_edit_mode.setIconSize(QUIConfig.icon_q_size) btn_pt_edit_mode.setIconSize(QUIConfig.icon_q_size)
#==============================================
controls_bar_frame2_l = QVBoxLayout() controls_bar_frame2_l = QVBoxLayout()
controls_bar_frame2_l.addWidget ( btn_poly_type_include ) controls_bar_frame2_l.addWidget ( btn_poly_type_include )
@ -262,11 +251,6 @@ class QCanvasControlsRightBar(QFrame):
btn_view_xseg_mask.setDefaultAction(self.btn_view_xseg_mask_act) btn_view_xseg_mask.setDefaultAction(self.btn_view_xseg_mask_act)
btn_view_xseg_mask.setIconSize(QUIConfig.icon_q_size) btn_view_xseg_mask.setIconSize(QUIConfig.icon_q_size)
btn_view_xseg_overlay_mask = QToolButton()
self.btn_view_xseg_overlay_mask_act = QActionEx( QIconDB.view_xseg_overlay, QStringDB.btn_view_xseg_overlay_mask_tip, shortcut='`', shortcut_in_tooltip=True, is_checkable=True)
btn_view_xseg_overlay_mask.setDefaultAction(self.btn_view_xseg_overlay_mask_act)
btn_view_xseg_overlay_mask.setIconSize(QUIConfig.icon_q_size)
self.btn_poly_color_act_grp = QActionGroup (self) self.btn_poly_color_act_grp = QActionGroup (self)
self.btn_poly_color_act_grp.addAction(self.btn_poly_color_red_act) self.btn_poly_color_act_grp.addAction(self.btn_poly_color_red_act)
self.btn_poly_color_act_grp.addAction(self.btn_poly_color_green_act) self.btn_poly_color_act_grp.addAction(self.btn_poly_color_green_act)
@ -275,17 +259,6 @@ class QCanvasControlsRightBar(QFrame):
self.btn_poly_color_act_grp.addAction(self.btn_view_xseg_mask_act) self.btn_poly_color_act_grp.addAction(self.btn_view_xseg_mask_act)
self.btn_poly_color_act_grp.setExclusive(True) self.btn_poly_color_act_grp.setExclusive(True)
#============================================== #==============================================
btn_view_lock_center = QToolButton()
self.btn_view_lock_center_act = QActionEx( QIconDB.view_lock_center, QStringDB.btn_view_lock_center_tip, shortcut_in_tooltip=True, is_checkable=True)
btn_view_lock_center.setDefaultAction(self.btn_view_lock_center_act)
btn_view_lock_center.setIconSize(QUIConfig.icon_q_size)
controls_bar_frame2_l = QVBoxLayout()
controls_bar_frame2_l.addWidget ( btn_view_xseg_overlay_mask )
controls_bar_frame2 = QFrame()
controls_bar_frame2.setFrameShape(QFrame.StyledPanel)
controls_bar_frame2.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed)
controls_bar_frame2.setLayout(controls_bar_frame2_l)
controls_bar_frame1_l = QVBoxLayout() controls_bar_frame1_l = QVBoxLayout()
controls_bar_frame1_l.addWidget ( btn_poly_color_red ) controls_bar_frame1_l.addWidget ( btn_poly_color_red )
@ -298,18 +271,9 @@ class QCanvasControlsRightBar(QFrame):
controls_bar_frame1.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) controls_bar_frame1.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed)
controls_bar_frame1.setLayout(controls_bar_frame1_l) controls_bar_frame1.setLayout(controls_bar_frame1_l)
controls_bar_frame3_l = QVBoxLayout()
controls_bar_frame3_l.addWidget ( btn_view_lock_center )
controls_bar_frame3 = QFrame()
controls_bar_frame3.setFrameShape(QFrame.StyledPanel)
controls_bar_frame3.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed)
controls_bar_frame3.setLayout(controls_bar_frame3_l)
controls_bar_l = QVBoxLayout() controls_bar_l = QVBoxLayout()
controls_bar_l.setContentsMargins(0,0,0,0) controls_bar_l.setContentsMargins(0,0,0,0)
controls_bar_l.addWidget(controls_bar_frame2)
controls_bar_l.addWidget(controls_bar_frame1) controls_bar_l.addWidget(controls_bar_frame1)
controls_bar_l.addWidget(controls_bar_frame3)
self.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Expanding ) self.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Expanding )
self.setLayout(controls_bar_l) self.setLayout(controls_bar_l)
@ -324,10 +288,8 @@ class QCanvasOperator(QWidget):
self.cbar.btn_poly_color_red_act.triggered.connect ( lambda : self.set_color_scheme_id(0) ) self.cbar.btn_poly_color_red_act.triggered.connect ( lambda : self.set_color_scheme_id(0) )
self.cbar.btn_poly_color_green_act.triggered.connect ( lambda : self.set_color_scheme_id(1) ) self.cbar.btn_poly_color_green_act.triggered.connect ( lambda : self.set_color_scheme_id(1) )
self.cbar.btn_poly_color_blue_act.triggered.connect ( lambda : self.set_color_scheme_id(2) ) self.cbar.btn_poly_color_blue_act.triggered.connect ( lambda : self.set_color_scheme_id(2) )
self.cbar.btn_view_baked_mask_act.triggered.connect ( lambda : self.set_op_mode(OpMode.VIEW_BAKED) ) self.cbar.btn_view_baked_mask_act.toggled.connect ( lambda : self.set_op_mode(OpMode.VIEW_BAKED) )
self.cbar.btn_view_xseg_mask_act.triggered.connect ( lambda : self.set_op_mode(OpMode.VIEW_XSEG_MASK) ) self.cbar.btn_view_xseg_mask_act.toggled.connect ( self.set_view_xseg_mask )
self.cbar.btn_view_xseg_overlay_mask_act.toggled.connect ( lambda is_checked: self.update() )
self.cbar.btn_poly_type_include_act.triggered.connect ( lambda : self.set_poly_include_type(SegIEPolyType.INCLUDE) ) self.cbar.btn_poly_type_include_act.triggered.connect ( lambda : self.set_poly_include_type(SegIEPolyType.INCLUDE) )
self.cbar.btn_poly_type_exclude_act.triggered.connect ( lambda : self.set_poly_include_type(SegIEPolyType.EXCLUDE) ) self.cbar.btn_poly_type_exclude_act.triggered.connect ( lambda : self.set_poly_include_type(SegIEPolyType.EXCLUDE) )
@ -338,7 +300,6 @@ class QCanvasOperator(QWidget):
self.cbar.btn_delete_poly_act.triggered.connect ( lambda : self.action_delete_poly() ) self.cbar.btn_delete_poly_act.triggered.connect ( lambda : self.action_delete_poly() )
self.cbar.btn_pt_edit_mode_act.toggled.connect ( lambda is_checked: self.set_pt_edit_mode( PTEditMode.ADD_DEL if is_checked else PTEditMode.MOVE ) ) self.cbar.btn_pt_edit_mode_act.toggled.connect ( lambda is_checked: self.set_pt_edit_mode( PTEditMode.ADD_DEL if is_checked else PTEditMode.MOVE ) )
self.cbar.btn_view_lock_center_act.toggled.connect ( lambda is_checked: self.set_view_lock( ViewLock.CENTER if is_checked else ViewLock.NONE ) )
self.mouse_in_widget = False self.mouse_in_widget = False
@ -349,22 +310,16 @@ class QCanvasOperator(QWidget):
self.initialized = False self.initialized = False
self.last_state = None self.last_state = None
def initialize(self, img, img_look_pt=None, view_scale=None, ie_polys=None, xseg_mask=None, canvas_config=None ): def initialize(self, q_img, img_look_pt=None, view_scale=None, ie_polys=None, xseg_mask=None, canvas_config=None ):
q_img = self.q_img = QImage_from_np(img) self.q_img = q_img
self.img_pixmap = QPixmap.fromImage(q_img) self.img_pixmap = QPixmap.fromImage(q_img)
self.xseg_mask_pixmap = None self.xseg_mask_pixmap = None
self.xseg_overlay_mask_pixmap = None
if xseg_mask is not None: if xseg_mask is not None:
h,w,c = img.shape w,h = QSize_to_np ( q_img.size() )
xseg_mask = cv2.resize(xseg_mask, (w,h), interpolation=cv2.INTER_CUBIC) xseg_mask = cv2.resize(xseg_mask, (w,h), cv2.INTER_CUBIC)
xseg_mask = imagelib.normalize_channels(xseg_mask, 1) xseg_mask = (imagelib.normalize_channels(xseg_mask, 1) * 255).astype(np.uint8)
xseg_img = img.astype(np.float32)/255.0
xseg_overlay_mask = xseg_img*(1-xseg_mask)*0.5 + xseg_img*xseg_mask
xseg_overlay_mask = np.clip(xseg_overlay_mask*255, 0, 255).astype(np.uint8)
xseg_mask = np.clip(xseg_mask*255, 0, 255).astype(np.uint8)
self.xseg_mask_pixmap = QPixmap.fromImage(QImage_from_np(xseg_mask)) self.xseg_mask_pixmap = QPixmap.fromImage(QImage_from_np(xseg_mask))
self.xseg_overlay_mask_pixmap = QPixmap.fromImage(QImage_from_np(xseg_overlay_mask))
self.img_size = QSize_to_np (self.img_pixmap.size()) self.img_size = QSize_to_np (self.img_pixmap.size())
@ -382,7 +337,6 @@ class QCanvasOperator(QWidget):
# UI init # UI init
self.set_cbar_disabled() self.set_cbar_disabled()
self.cbar.btn_poly_color_act_grp.setDisabled(False) self.cbar.btn_poly_color_act_grp.setDisabled(False)
self.cbar.btn_view_xseg_overlay_mask_act.setDisabled(False)
self.cbar.btn_poly_type_act_grp.setDisabled(False) self.cbar.btn_poly_type_act_grp.setDisabled(False)
# Initial vars # Initial vars
@ -390,14 +344,12 @@ class QCanvasOperator(QWidget):
self.mouse_hull_poly = None self.mouse_hull_poly = None
self.mouse_wire_poly = None self.mouse_wire_poly = None
self.drag_type = DragType.NONE self.drag_type = DragType.NONE
self.mouse_cli_pt = np.zeros((2,), np.float32 )
# Initial state # Initial state
self.set_op_mode(OpMode.NONE) self.set_op_mode(OpMode.NONE)
self.set_color_scheme_id(1) self.set_color_scheme_id(1)
self.set_poly_include_type(SegIEPolyType.INCLUDE) self.set_poly_include_type(SegIEPolyType.INCLUDE)
self.set_pt_edit_mode(PTEditMode.MOVE) self.set_pt_edit_mode(PTEditMode.MOVE)
self.set_view_lock(ViewLock.NONE)
# Apply last state # Apply last state
if self.last_state is not None: if self.last_state is not None:
@ -418,7 +370,8 @@ class QCanvasOperator(QWidget):
self.set_op_mode(OpMode.EDIT_PTS) self.set_op_mode(OpMode.EDIT_PTS)
self.last_state = sn(op_mode = self.op_mode if self.op_mode in [OpMode.VIEW_BAKED, OpMode.VIEW_XSEG_MASK] else None, self.last_state = sn(op_mode = self.op_mode if self.op_mode in [OpMode.VIEW_BAKED, OpMode.VIEW_XSEG_MASK] else None,
color_scheme_id = self.color_scheme_id) color_scheme_id = self.color_scheme_id,
)
self.img_pixmap = None self.img_pixmap = None
self.update_cursor(is_finalize=True) self.update_cursor(is_finalize=True)
@ -433,15 +386,13 @@ class QCanvasOperator(QWidget):
# ====================================== GETTERS ===================================== # ====================================== GETTERS =====================================
# ==================================================================================== # ====================================================================================
# ==================================================================================== # ====================================================================================
def is_initialized(self): def is_initialized(self):
return self.initialized return self.initialized
def get_ie_polys(self): def get_ie_polys(self):
return self.ie_polys return self.ie_polys
def get_cli_center_pt(self):
return np.round(QSize_to_np(self.size())/2.0)
def get_img_look_pt(self): def get_img_look_pt(self):
img_look_pt = self.img_look_pt img_look_pt = self.img_look_pt
if img_look_pt is None: if img_look_pt is None:
@ -503,10 +454,10 @@ class QCanvasOperator(QWidget):
return None return None
def img_to_cli_pt(self, p): def img_to_cli_pt(self, p):
return (p - self.get_img_look_pt()) * self.get_view_scale() + self.get_cli_center_pt()# QSize_to_np(self.size())/2.0 return (p - self.get_img_look_pt()) * self.get_view_scale() + QSize_to_np(self.size())/2.0
def cli_to_img_pt(self, p): def cli_to_img_pt(self, p):
return (p - self.get_cli_center_pt() ) / self.get_view_scale() + self.get_img_look_pt() return (p - QSize_to_np(self.size())/2.0 ) / self.get_view_scale() + self.get_img_look_pt()
def img_to_cli_rect(self, rect): def img_to_cli_rect(self, rect):
tl = QPoint_to_np(rect.topLeft()) tl = QPoint_to_np(rect.topLeft())
@ -531,13 +482,9 @@ class QCanvasOperator(QWidget):
elif self.op_mode == OpMode.DRAW_PTS: elif self.op_mode == OpMode.DRAW_PTS:
self.cbar.btn_undo_pt_act.setDisabled(True) self.cbar.btn_undo_pt_act.setDisabled(True)
self.cbar.btn_redo_pt_act.setDisabled(True) self.cbar.btn_redo_pt_act.setDisabled(True)
self.cbar.btn_view_lock_center_act.setDisabled(True)
# Reset view_lock when exit from DRAW_PTS
self.set_view_lock(ViewLock.NONE)
# Remove unfinished poly
if self.op_poly.get_pts_count() < 3: if self.op_poly.get_pts_count() < 3:
# Remove unfinished poly
self.ie_polys.remove_poly(self.op_poly) self.ie_polys.remove_poly(self.op_poly)
elif self.op_mode == OpMode.EDIT_PTS: elif self.op_mode == OpMode.EDIT_PTS:
self.cbar.btn_pt_edit_mode_act.setDisabled(True) self.cbar.btn_pt_edit_mode_act.setDisabled(True)
self.cbar.btn_delete_poly_act.setDisabled(True) self.cbar.btn_delete_poly_act.setDisabled(True)
@ -556,7 +503,6 @@ class QCanvasOperator(QWidget):
elif op_mode == OpMode.DRAW_PTS: elif op_mode == OpMode.DRAW_PTS:
self.cbar.btn_undo_pt_act.setDisabled(False) self.cbar.btn_undo_pt_act.setDisabled(False)
self.cbar.btn_redo_pt_act.setDisabled(False) self.cbar.btn_redo_pt_act.setDisabled(False)
self.cbar.btn_view_lock_center_act.setDisabled(False)
elif op_mode == OpMode.EDIT_PTS: elif op_mode == OpMode.EDIT_PTS:
self.cbar.btn_pt_edit_mode_act.setDisabled(False) self.cbar.btn_pt_edit_mode_act.setDisabled(False)
self.cbar.btn_delete_poly_act.setDisabled(False) self.cbar.btn_delete_poly_act.setDisabled(False)
@ -570,12 +516,11 @@ class QCanvasOperator(QWidget):
self.img_baked_pixmap = QPixmap.fromImage(QImage_from_np(n)) self.img_baked_pixmap = QPixmap.fromImage(QImage_from_np(n))
elif op_mode == OpMode.VIEW_XSEG_MASK: elif op_mode == OpMode.VIEW_XSEG_MASK:
self.cbar.btn_view_xseg_mask_act.setChecked(True) self.cbar.btn_view_xseg_mask_act.setChecked(True)
if op_mode in [OpMode.DRAW_PTS, OpMode.EDIT_PTS]: if op_mode in [OpMode.DRAW_PTS, OpMode.EDIT_PTS]:
self.mouse_op_poly_pt_id = None self.mouse_op_poly_pt_id = None
self.mouse_op_poly_edge_id = None self.mouse_op_poly_edge_id = None
self.mouse_op_poly_edge_id_pt = None self.mouse_op_poly_edge_id_pt = None
#
self.op_poly = op_poly self.op_poly = op_poly
if op_poly is not None: if op_poly is not None:
self.update_mouse_info() self.update_mouse_info()
@ -588,32 +533,19 @@ class QCanvasOperator(QWidget):
self.pt_edit_mode = pt_edit_mode self.pt_edit_mode = pt_edit_mode
self.update_cursor() self.update_cursor()
self.update() self.update()
self.cbar.btn_pt_edit_mode_act.setChecked( self.pt_edit_mode == PTEditMode.ADD_DEL ) self.cbar.btn_pt_edit_mode_act.setChecked( self.pt_edit_mode == PTEditMode.ADD_DEL )
def set_view_lock(self, view_lock):
if not hasattr(self, 'view_lock') or self.view_lock != view_lock:
if hasattr(self, 'view_lock') and self.view_lock != view_lock:
if view_lock == ViewLock.CENTER:
self.img_look_pt = self.mouse_img_pt
QCursor.setPos ( self.mapToGlobal( QPoint_from_np(self.img_to_cli_pt(self.img_look_pt)) ))
self.view_lock = view_lock
self.update()
self.cbar.btn_view_lock_center_act.setChecked( self.view_lock == ViewLock.CENTER )
def set_cbar_disabled(self): def set_cbar_disabled(self):
self.cbar.btn_delete_poly_act.setDisabled(True) self.cbar.btn_delete_poly_act.setDisabled(True)
self.cbar.btn_undo_pt_act.setDisabled(True) self.cbar.btn_undo_pt_act.setDisabled(True)
self.cbar.btn_redo_pt_act.setDisabled(True) self.cbar.btn_redo_pt_act.setDisabled(True)
self.cbar.btn_pt_edit_mode_act.setDisabled(True) self.cbar.btn_pt_edit_mode_act.setDisabled(True)
self.cbar.btn_view_lock_center_act.setDisabled(True)
self.cbar.btn_poly_color_act_grp.setDisabled(True) self.cbar.btn_poly_color_act_grp.setDisabled(True)
self.cbar.btn_view_xseg_overlay_mask_act.setDisabled(True)
self.cbar.btn_poly_type_act_grp.setDisabled(True) self.cbar.btn_poly_type_act_grp.setDisabled(True)
def set_color_scheme_id(self, id): def set_color_scheme_id(self, id):
if self.op_mode == OpMode.VIEW_BAKED or self.op_mode == OpMode.VIEW_XSEG_MASK: if self.op_mode == OpMode.VIEW_BAKED:
self.set_op_mode(OpMode.NONE) self.set_op_mode(OpMode.NONE)
if not hasattr(self, 'color_scheme_id') or self.color_scheme_id != id: if not hasattr(self, 'color_scheme_id') or self.color_scheme_id != id:
@ -634,9 +566,29 @@ class QCanvasOperator(QWidget):
self.op_mode in [OpMode.NONE, OpMode.EDIT_PTS] ): self.op_mode in [OpMode.NONE, OpMode.EDIT_PTS] ):
self.poly_include_type = poly_include_type self.poly_include_type = poly_include_type
self.update() self.update()
self.cbar.btn_poly_type_include_act.setChecked(self.poly_include_type == SegIEPolyType.INCLUDE) self.cbar.btn_poly_type_include_act.setChecked(self.poly_include_type == SegIEPolyType.INCLUDE)
self.cbar.btn_poly_type_exclude_act.setChecked(self.poly_include_type == SegIEPolyType.EXCLUDE) self.cbar.btn_poly_type_exclude_act.setChecked(self.poly_include_type == SegIEPolyType.EXCLUDE)
def set_view_xseg_mask(self, is_checked):
if is_checked:
self.set_op_mode(OpMode.VIEW_XSEG_MASK)
#n = QImage_to_np ( self.q_img ).astype(np.float32) / 255.0
#h,w,c = n.shape
#mask = np.zeros( (h,w,1), dtype=np.float32 )
#self.ie_polys.overlay_mask(mask)
#n = (mask*255).astype(np.uint8)
#self.img_baked_pixmap = QPixmap.fromImage(QImage_from_np(n))
else:
self.set_op_mode(OpMode.NONE)
self.cbar.btn_view_xseg_mask_act.setChecked(is_checked )
# ==================================================================================== # ====================================================================================
# ==================================================================================== # ====================================================================================
# ====================================== METHODS ===================================== # ====================================== METHODS =====================================
@ -762,9 +714,7 @@ class QCanvasOperator(QWidget):
return return
key = ev.key() key = ev.key()
key_mods = int(ev.modifiers()) key_mods = int(ev.modifiers())
if self.op_mode == OpMode.DRAW_PTS: if self.op_mode == OpMode.EDIT_PTS:
self.set_view_lock(ViewLock.CENTER if key_mods == Qt.ShiftModifier else ViewLock.NONE )
elif self.op_mode == OpMode.EDIT_PTS:
self.set_pt_edit_mode(PTEditMode.ADD_DEL if key_mods == Qt.ControlModifier else PTEditMode.MOVE ) self.set_pt_edit_mode(PTEditMode.ADD_DEL if key_mods == Qt.ControlModifier else PTEditMode.MOVE )
def on_keyReleaseEvent(self, ev): def on_keyReleaseEvent(self, ev):
@ -772,9 +722,7 @@ class QCanvasOperator(QWidget):
return return
key = ev.key() key = ev.key()
key_mods = int(ev.modifiers()) key_mods = int(ev.modifiers())
if self.op_mode == OpMode.DRAW_PTS: if self.op_mode == OpMode.EDIT_PTS:
self.set_view_lock(ViewLock.CENTER if key_mods == Qt.ShiftModifier else ViewLock.NONE )
elif self.op_mode == OpMode.EDIT_PTS:
self.set_pt_edit_mode(PTEditMode.ADD_DEL if key_mods == Qt.ControlModifier else PTEditMode.MOVE ) self.set_pt_edit_mode(PTEditMode.ADD_DEL if key_mods == Qt.ControlModifier else PTEditMode.MOVE )
def enterEvent(self, ev): def enterEvent(self, ev):
@ -825,10 +773,10 @@ class QCanvasOperator(QWidget):
if self.mouse_op_poly_pt_id is not None: if self.mouse_op_poly_pt_id is not None:
# Click on point of op_poly # Click on point of op_poly
if self.pt_edit_mode == PTEditMode.ADD_DEL: if self.pt_edit_mode == PTEditMode.ADD_DEL:
# in mode 'delete point' # with mode -> delete point
self.op_poly.remove_pt(self.mouse_op_poly_pt_id) self.op_poly.remove_pt(self.mouse_op_poly_pt_id)
if self.op_poly.get_pts_count() < 3: if self.op_poly.get_pts_count() < 3:
# not enough points after delete -> remove poly # not enough points -> remove poly
self.ie_polys.remove_poly (self.op_poly) self.ie_polys.remove_poly (self.op_poly)
self.set_op_mode(OpMode.NONE) self.set_op_mode(OpMode.NONE)
self.update() self.update()
@ -842,7 +790,7 @@ class QCanvasOperator(QWidget):
elif self.mouse_op_poly_edge_id is not None: elif self.mouse_op_poly_edge_id is not None:
# Click on edge of op_poly # Click on edge of op_poly
if self.pt_edit_mode == PTEditMode.ADD_DEL: if self.pt_edit_mode == PTEditMode.ADD_DEL:
# in mode 'insert new point' # with mode -> insert new point
edge_img_pt = self.cli_to_img_pt(self.mouse_op_poly_edge_id_pt) edge_img_pt = self.cli_to_img_pt(self.mouse_op_poly_edge_id_pt)
self.op_poly.insert_pt (self.mouse_op_poly_edge_id+1, edge_img_pt) self.op_poly.insert_pt (self.mouse_op_poly_edge_id+1, edge_img_pt)
self.update() self.update()
@ -888,15 +836,8 @@ class QCanvasOperator(QWidget):
if not self.initialized: if not self.initialized:
return return
prev_mouse_cli_pt = self.mouse_cli_pt
self.update_mouse_info(QPoint_to_np(ev.pos())) self.update_mouse_info(QPoint_to_np(ev.pos()))
if self.view_lock == ViewLock.CENTER:
if npla.norm(self.mouse_cli_pt - prev_mouse_cli_pt) >= 1:
self.img_look_pt = self.mouse_img_pt
QCursor.setPos ( self.mapToGlobal( QPoint_from_np(self.img_to_cli_pt(self.img_look_pt)) ))
self.update()
if self.drag_type == DragType.IMAGE_LOOK: if self.drag_type == DragType.IMAGE_LOOK:
delta_pt = self.cli_to_img_pt(self.mouse_cli_pt) - self.cli_to_img_pt(self.drag_cli_pt) delta_pt = self.cli_to_img_pt(self.mouse_cli_pt) - self.cli_to_img_pt(self.drag_cli_pt)
self.img_look_pt = self.drag_img_look_pt - delta_pt self.img_look_pt = self.drag_img_look_pt - delta_pt
@ -905,7 +846,9 @@ class QCanvasOperator(QWidget):
if self.op_mode == OpMode.DRAW_PTS: if self.op_mode == OpMode.DRAW_PTS:
self.update() self.update()
elif self.op_mode == OpMode.EDIT_PTS: elif self.op_mode == OpMode.EDIT_PTS:
if self.drag_type == DragType.POLY_PT: if self.drag_type == DragType.POLY_PT:
delta_pt = self.cli_to_img_pt(self.mouse_cli_pt) - self.cli_to_img_pt(self.drag_cli_pt) delta_pt = self.cli_to_img_pt(self.mouse_cli_pt) - self.cli_to_img_pt(self.drag_cli_pt)
self.op_poly.set_point(self.drag_poly_pt_id, self.drag_poly_pt + delta_pt) self.op_poly.set_point(self.drag_poly_pt_id, self.drag_poly_pt + delta_pt)
self.update() self.update()
@ -944,19 +887,20 @@ class QCanvasOperator(QWidget):
qp.setRenderHint(QPainter.HighQualityAntialiasing) qp.setRenderHint(QPainter.HighQualityAntialiasing)
qp.setRenderHint(QPainter.SmoothPixmapTransform) qp.setRenderHint(QPainter.SmoothPixmapTransform)
src_rect = QRect(0, 0, *self.img_size)
dst_rect = self.img_to_cli_rect( src_rect )
if self.op_mode == OpMode.VIEW_BAKED: if self.op_mode == OpMode.VIEW_BAKED:
src_rect = QRect(0, 0, *self.img_size)
dst_rect = self.img_to_cli_rect( src_rect )
qp.drawPixmap(dst_rect, self.img_baked_pixmap, src_rect) qp.drawPixmap(dst_rect, self.img_baked_pixmap, src_rect)
elif self.op_mode == OpMode.VIEW_XSEG_MASK: elif self.op_mode == OpMode.VIEW_XSEG_MASK:
if self.xseg_mask_pixmap is not None: if self.xseg_mask_pixmap is not None:
src_rect = QRect(0, 0, *self.img_size)
dst_rect = self.img_to_cli_rect( src_rect )
qp.drawPixmap(dst_rect, self.xseg_mask_pixmap, src_rect) qp.drawPixmap(dst_rect, self.xseg_mask_pixmap, src_rect)
else: else:
if self.cbar.btn_view_xseg_overlay_mask_act.isChecked() and \ if self.img_pixmap is not None:
self.xseg_overlay_mask_pixmap is not None: src_rect = QRect(0, 0, *self.img_size)
qp.drawPixmap(dst_rect, self.xseg_overlay_mask_pixmap, src_rect) dst_rect = self.img_to_cli_rect( src_rect )
elif self.img_pixmap is not None:
qp.drawPixmap(dst_rect, self.img_pixmap, src_rect) qp.drawPixmap(dst_rect, self.img_pixmap, src_rect)
polys = self.ie_polys.get_polys() polys = self.ie_polys.get_polys()
@ -1051,9 +995,9 @@ class QCanvasOperator(QWidget):
if op_mode == OpMode.NONE: if op_mode == OpMode.NONE:
if poly == self.mouse_wire_poly: if poly == self.mouse_wire_poly:
qp.setBrush(color_scheme.poly_selected_brush) qp.setBrush(color_scheme.poly_selected_brush)
#else: else:
# if poly == op_poly: if poly == op_poly:
# qp.setBrush(color_scheme.poly_selected_brush) qp.setBrush(color_scheme.poly_selected_brush)
qp.drawPath(poly_line_path) qp.drawPath(poly_line_path)
@ -1066,6 +1010,7 @@ class QCanvasOperator(QWidget):
qp.end() qp.end()
class QCanvas(QFrame): class QCanvas(QFrame):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -1078,16 +1023,17 @@ class QCanvas(QFrame):
btn_poly_color_blue_act = self.canvas_control_right_bar.btn_poly_color_blue_act, btn_poly_color_blue_act = self.canvas_control_right_bar.btn_poly_color_blue_act,
btn_view_baked_mask_act = self.canvas_control_right_bar.btn_view_baked_mask_act, btn_view_baked_mask_act = self.canvas_control_right_bar.btn_view_baked_mask_act,
btn_view_xseg_mask_act = self.canvas_control_right_bar.btn_view_xseg_mask_act, btn_view_xseg_mask_act = self.canvas_control_right_bar.btn_view_xseg_mask_act,
btn_view_xseg_overlay_mask_act = self.canvas_control_right_bar.btn_view_xseg_overlay_mask_act,
btn_poly_color_act_grp = self.canvas_control_right_bar.btn_poly_color_act_grp, btn_poly_color_act_grp = self.canvas_control_right_bar.btn_poly_color_act_grp,
btn_view_lock_center_act = self.canvas_control_right_bar.btn_view_lock_center_act,
btn_poly_type_include_act = self.canvas_control_left_bar.btn_poly_type_include_act, btn_poly_type_include_act = self.canvas_control_left_bar.btn_poly_type_include_act,
btn_poly_type_exclude_act = self.canvas_control_left_bar.btn_poly_type_exclude_act, btn_poly_type_exclude_act = self.canvas_control_left_bar.btn_poly_type_exclude_act,
btn_poly_type_act_grp = self.canvas_control_left_bar.btn_poly_type_act_grp, btn_poly_type_act_grp = self.canvas_control_left_bar.btn_poly_type_act_grp,
btn_undo_pt_act = self.canvas_control_left_bar.btn_undo_pt_act, btn_undo_pt_act = self.canvas_control_left_bar.btn_undo_pt_act,
btn_redo_pt_act = self.canvas_control_left_bar.btn_redo_pt_act, btn_redo_pt_act = self.canvas_control_left_bar.btn_redo_pt_act,
btn_delete_poly_act = self.canvas_control_left_bar.btn_delete_poly_act, btn_delete_poly_act = self.canvas_control_left_bar.btn_delete_poly_act,
btn_pt_edit_mode_act = self.canvas_control_left_bar.btn_pt_edit_mode_act ) btn_pt_edit_mode_act = self.canvas_control_left_bar.btn_pt_edit_mode_act )
self.op = QCanvasOperator(cbar) self.op = QCanvasOperator(cbar)
@ -1122,7 +1068,7 @@ class LoaderQSubprocessor(QSubprocessor):
if len (self.idxs) > 0: if len (self.idxs) > 0:
idx = self.idxs.pop(0) idx = self.idxs.pop(0)
image_path = self.image_paths[idx] image_path = self.image_paths[idx]
self.q_label.setText(f'{QStringDB.loading_tip}... {image_path.name}') self.q_label.setText(f'{image_path.name}')
return idx, image_path return idx, image_path
@ -1155,22 +1101,18 @@ class LoaderQSubprocessor(QSubprocessor):
return idx, True, ie_polys.has_polys() return idx, True, ie_polys.has_polys()
return idx, False, False return idx, False, False
class MainWindow(QXMainWindow): class MainWindow(QXMainWindow):
def __init__(self, input_dirpath, cfg_root_path): def __init__(self, input_dirpath, cfg_root_path):
self.loading_frame = None
self.help_frame = None
super().__init__() super().__init__()
self.input_dirpath = input_dirpath self.input_dirpath = input_dirpath
self.trash_dirpath = input_dirpath.parent / (input_dirpath.name + '_trash')
self.cfg_root_path = cfg_root_path self.cfg_root_path = cfg_root_path
self.cfg_path = cfg_root_path / 'MainWindow_cfg.dat' self.cfg_path = cfg_root_path / 'MainWindow_cfg.dat'
self.cfg_dict = pickle.loads(self.cfg_path.read_bytes()) if self.cfg_path.exists() else {} self.cfg_dict = pickle.loads(self.cfg_path.read_bytes()) if self.cfg_path.exists() else {}
self.cached_images = {} self.cached_QImages = {}
self.cached_has_ie_polys = {} self.cached_has_ie_polys = {}
self.initialize_ui() self.initialize_ui()
@ -1181,20 +1123,9 @@ class MainWindow(QXMainWindow):
self.loading_frame.setFrameShape(QFrame.StyledPanel) self.loading_frame.setFrameShape(QFrame.StyledPanel)
self.loader_label = QLabel() self.loader_label = QLabel()
self.loader_progress_bar = QProgressBar() self.loader_progress_bar = QProgressBar()
intro_image = QLabel()
intro_image.setPixmap( QPixmap.fromImage(QImageDB.intro) )
intro_image_frame_l = QVBoxLayout()
intro_image_frame_l.addWidget(intro_image, alignment=Qt.AlignCenter)
intro_image_frame = QFrame()
intro_image_frame.setSizePolicy (QSizePolicy.Expanding, QSizePolicy.Expanding)
intro_image_frame.setLayout(intro_image_frame_l)
loading_frame_l = QVBoxLayout() loading_frame_l = QVBoxLayout()
loading_frame_l.addWidget (intro_image_frame) loading_frame_l.addWidget (self.loader_label, alignment=Qt.AlignBottom)
loading_frame_l.addWidget (self.loader_label) loading_frame_l.addWidget (self.loader_progress_bar, alignment=Qt.AlignTop)
loading_frame_l.addWidget (self.loader_progress_bar)
self.loading_frame.setLayout(loading_frame_l) self.loading_frame.setLayout(loading_frame_l)
self.loader_subprocessor = LoaderQSubprocessor( image_paths=pathex.get_image_paths(input_dirpath, return_Path_class=True), self.loader_subprocessor = LoaderQSubprocessor( image_paths=pathex.get_image_paths(input_dirpath, return_Path_class=True),
@ -1207,7 +1138,6 @@ class MainWindow(QXMainWindow):
self.image_paths_done = [] self.image_paths_done = []
self.image_paths = image_paths self.image_paths = image_paths
self.image_paths_has_ie_polys = image_paths_has_ie_polys self.image_paths_has_ie_polys = image_paths_has_ie_polys
self.set_has_ie_polys_count ( len([ 1 for x in self.image_paths_has_ie_polys if self.image_paths_has_ie_polys[x] == True]) )
self.loading_frame.hide() self.loading_frame.hide()
self.loading_frame = None self.loading_frame = None
@ -1219,7 +1149,7 @@ class MainWindow(QXMainWindow):
def update_cached_images (self, count=5): def update_cached_images (self, count=5):
d = self.cached_images d = self.cached_QImages
for image_path in self.image_paths_done[:-count]+self.image_paths[count:]: for image_path in self.image_paths_done[:-count]+self.image_paths[count:]:
if image_path in d: if image_path in d:
@ -1229,14 +1159,13 @@ class MainWindow(QXMainWindow):
if image_path not in d: if image_path not in d:
img = cv2_imread(image_path) img = cv2_imread(image_path)
if img is not None: if img is not None:
d[image_path] = img d[image_path] = QImage_from_np(img)
def load_image(self, image_path): def load_QImage(self, image_path):
try: try:
img = self.cached_images.get(image_path, None) img = self.cached_QImages.get(image_path, None)
if img is None: if img is None:
img = cv2_imread(image_path) img = QImage_from_np(cv2_imread(image_path))
self.cached_images[image_path] = img
if img is None: if img is None:
io.log_err(f'Unable to load {image_path}') io.log_err(f'Unable to load {image_path}')
except: except:
@ -1246,10 +1175,10 @@ class MainWindow(QXMainWindow):
def update_preview_bar(self): def update_preview_bar(self):
count = self.image_bar.get_preview_images_count() count = self.image_bar.get_preview_images_count()
d = self.cached_images d = self.cached_QImages
prev_imgs = [ d.get(image_path, None) for image_path in self.image_paths_done[-1:-count:-1] ] prev_q_imgs = [ d.get(image_path, None) for image_path in self.image_paths_done[-1:-count:-1] ]
next_imgs = [ d.get(image_path, None) for image_path in self.image_paths[:count] ] next_q_imgs = [ d.get(image_path, None) for image_path in self.image_paths[:count] ]
self.image_bar.update_images(prev_imgs, next_imgs) self.image_bar.update_images(prev_q_imgs, next_q_imgs)
def canvas_initialize(self, image_path, only_has_polys=False): def canvas_initialize(self, image_path, only_has_polys=False):
@ -1262,13 +1191,13 @@ class MainWindow(QXMainWindow):
ie_polys = dflimg.get_seg_ie_polys() ie_polys = dflimg.get_seg_ie_polys()
xseg_mask = dflimg.get_xseg_mask() xseg_mask = dflimg.get_xseg_mask()
img = self.load_image(image_path) q_img = self.load_QImage(image_path)
if img is None: if q_img is None:
return False return False
self.canvas.op.initialize ( img, ie_polys=ie_polys, xseg_mask=xseg_mask ) self.canvas.op.initialize ( q_img, ie_polys=ie_polys, xseg_mask=xseg_mask )
self.filename_label.setText(f"{image_path.name}") self.filename_label.setText(str(image_path.name))
return True return True
@ -1281,19 +1210,11 @@ class MainWindow(QXMainWindow):
new_ie_polys = self.canvas.op.get_ie_polys() new_ie_polys = self.canvas.op.get_ie_polys()
if not new_ie_polys.identical(ie_polys): if not new_ie_polys.identical(ie_polys):
prev_has_polys = self.image_paths_has_ie_polys[image_path]
self.image_paths_has_ie_polys[image_path] = new_ie_polys.has_polys() self.image_paths_has_ie_polys[image_path] = new_ie_polys.has_polys()
new_has_polys = self.image_paths_has_ie_polys[image_path]
if not prev_has_polys and new_has_polys:
self.set_has_ie_polys_count ( self.get_has_ie_polys_count() +1)
elif prev_has_polys and not new_has_polys:
self.set_has_ie_polys_count ( self.get_has_ie_polys_count() -1)
dflimg.set_seg_ie_polys( new_ie_polys ) dflimg.set_seg_ie_polys( new_ie_polys )
dflimg.save() dflimg.save()
self.filename_label.setText(f"") self.filename_label.setText("")
def process_prev_image(self): def process_prev_image(self):
key_mods = QApplication.keyboardModifiers() key_mods = QApplication.keyboardModifiers()
@ -1343,17 +1264,6 @@ class MainWindow(QXMainWindow):
self.update_cached_images() self.update_cached_images()
self.update_preview_bar() self.update_preview_bar()
def trash_current_image(self):
self.process_next_image()
img_path = self.image_paths_done.pop(-1)
img_path = Path(img_path)
self.trash_dirpath.mkdir(parents=True, exist_ok=True)
img_path.rename( self.trash_dirpath / img_path.name )
self.update_cached_images()
self.update_preview_bar()
def initialize_ui(self): def initialize_ui(self):
self.canvas = QCanvas() self.canvas = QCanvas()
@ -1368,59 +1278,33 @@ class MainWindow(QXMainWindow):
btn_next_image = QXIconButton(QIconDB.right, QStringDB.btn_next_image_tip, shortcut='D', click_func=self.process_next_image) btn_next_image = QXIconButton(QIconDB.right, QStringDB.btn_next_image_tip, shortcut='D', click_func=self.process_next_image)
btn_next_image.setIconSize(QUIConfig.preview_bar_icon_q_size) btn_next_image.setIconSize(QUIConfig.preview_bar_icon_q_size)
btn_delete_image = QXIconButton(QIconDB.trashcan, QStringDB.btn_delete_image_tip, shortcut='X', click_func=self.trash_current_image)
btn_delete_image.setIconSize(QUIConfig.preview_bar_icon_q_size)
pad_image = QWidget()
pad_image.setFixedSize(QUIConfig.preview_bar_icon_q_size)
preview_image_bar_frame_l = QHBoxLayout() preview_image_bar_frame_l = QHBoxLayout()
preview_image_bar_frame_l.setContentsMargins(0,0,0,0) preview_image_bar_frame_l.setContentsMargins(0,0,0,0)
preview_image_bar_frame_l.addWidget ( pad_image, alignment=Qt.AlignCenter)
preview_image_bar_frame_l.addWidget ( btn_prev_image, alignment=Qt.AlignCenter) preview_image_bar_frame_l.addWidget ( btn_prev_image, alignment=Qt.AlignCenter)
preview_image_bar_frame_l.addWidget ( image_bar) preview_image_bar_frame_l.addWidget ( image_bar)
preview_image_bar_frame_l.addWidget ( btn_next_image, alignment=Qt.AlignCenter) preview_image_bar_frame_l.addWidget ( btn_next_image, alignment=Qt.AlignCenter)
#preview_image_bar_frame_l.addWidget ( btn_delete_image, alignment=Qt.AlignCenter)
preview_image_bar_frame = QFrame() preview_image_bar_frame = QFrame()
preview_image_bar_frame.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Fixed ) preview_image_bar_frame.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Fixed )
preview_image_bar_frame.setLayout(preview_image_bar_frame_l) preview_image_bar_frame.setLayout(preview_image_bar_frame_l)
preview_image_bar_frame2_l = QHBoxLayout()
preview_image_bar_frame2_l.setContentsMargins(0,0,0,0)
preview_image_bar_frame2_l.addWidget ( btn_delete_image, alignment=Qt.AlignCenter)
preview_image_bar_frame2 = QFrame()
preview_image_bar_frame2.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Fixed )
preview_image_bar_frame2.setLayout(preview_image_bar_frame2_l)
preview_image_bar_l = QHBoxLayout() preview_image_bar_l = QHBoxLayout()
preview_image_bar_l.addWidget (preview_image_bar_frame, alignment=Qt.AlignCenter) preview_image_bar_l.addWidget (preview_image_bar_frame)
preview_image_bar_l.addWidget (preview_image_bar_frame2)
preview_image_bar = QFrame() preview_image_bar = QFrame()
preview_image_bar.setFrameShape(QFrame.StyledPanel) preview_image_bar.setFrameShape(QFrame.StyledPanel)
preview_image_bar.setSizePolicy ( QSizePolicy.Expanding, QSizePolicy.Fixed ) preview_image_bar.setSizePolicy ( QSizePolicy.Expanding, QSizePolicy.Fixed )
preview_image_bar.setLayout(preview_image_bar_l) preview_image_bar.setLayout(preview_image_bar_l)
label_font = QFont('Courier New')
self.filename_label = QLabel() self.filename_label = QLabel()
self.filename_label.setFont(label_font) f = QFont('Courier New')
self.filename_label.setFont(f)
self.has_ie_polys_count_label = QLabel()
status_frame_l = QHBoxLayout()
status_frame_l.setContentsMargins(0,0,0,0)
status_frame_l.addWidget ( QLabel(), alignment=Qt.AlignCenter)
status_frame_l.addWidget (self.filename_label, alignment=Qt.AlignCenter)
status_frame_l.addWidget (self.has_ie_polys_count_label, alignment=Qt.AlignCenter)
status_frame = QFrame()
status_frame.setLayout(status_frame_l)
main_canvas_l = QVBoxLayout() main_canvas_l = QVBoxLayout()
main_canvas_l.setContentsMargins(0,0,0,0) main_canvas_l.setContentsMargins(0,0,0,0)
main_canvas_l.addWidget (self.canvas) main_canvas_l.addWidget (self.canvas)
main_canvas_l.addWidget (status_frame) main_canvas_l.addWidget (self.filename_label, alignment=Qt.AlignCenter)
main_canvas_l.addWidget (preview_image_bar) main_canvas_l.addWidget (preview_image_bar)
self.main_canvas_frame = QFrame() self.main_canvas_frame = QFrame()
@ -1438,29 +1322,11 @@ class MainWindow(QXMainWindow):
else: else:
self.move( QPoint(0,0)) self.move( QPoint(0,0))
def get_has_ie_polys_count(self):
return self.has_ie_polys_count
def set_has_ie_polys_count(self, c):
self.has_ie_polys_count = c
self.has_ie_polys_count_label.setText(f"{c} {QStringDB.labeled_tip}")
def resizeEvent(self, ev): def resizeEvent(self, ev):
if self.loading_frame is not None: if self.loading_frame is not None:
self.loading_frame.resize( ev.size() ) self.loading_frame.resize( ev.size() )
if self.help_frame is not None:
self.help_frame.resize( ev.size() )
def start(input_dirpath): def start(input_dirpath):
"""
returns exit_code
"""
io.log_info("Running XSeg editor.")
if PackedFaceset.path_contains(input_dirpath):
io.log_info (f'\n{input_dirpath} contains packed faceset! Unpack it first.\n')
return 1
root_path = Path(__file__).parent root_path = Path(__file__).parent
cfg_root_path = Path(tempfile.gettempdir()) cfg_root_path = Path(tempfile.gettempdir())
@ -1480,7 +1346,6 @@ def start(input_dirpath):
QIconDB.initialize( root_path / 'gfx' / 'icons' ) QIconDB.initialize( root_path / 'gfx' / 'icons' )
QCursorDB.initialize( root_path / 'gfx' / 'cursors' ) QCursorDB.initialize( root_path / 'gfx' / 'cursors' )
QImageDB.initialize( root_path / 'gfx' / 'images' )
app.setWindowIcon(QIconDB.app_icon) app.setWindowIcon(QIconDB.app_icon)
app.setPalette( QDarkPalette() ) app.setPalette( QDarkPalette() )
@ -1491,4 +1356,3 @@ def start(input_dirpath):
win.raise_() win.raise_()
app.exec_() app.exec_()
return 0

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 30 KiB

View file

@ -2,7 +2,6 @@ import cv2
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
from core.interact import interact as io from core.interact import interact as io
from core import imagelib
import traceback import traceback
def cv2_imread(filename, flags=cv2.IMREAD_UNCHANGED, loader_func=None, verbose=True): def cv2_imread(filename, flags=cv2.IMREAD_UNCHANGED, loader_func=None, verbose=True):
@ -30,11 +29,3 @@ def cv2_imwrite(filename, img, *args):
stream.write( buf ) stream.write( buf )
except: except:
pass pass
def cv2_resize(x, *args, **kwargs):
h,w,c = x.shape
x = cv2.resize(x, *args, **kwargs)
x = imagelib.normalize_channels(x, c)
return x

View file

@ -77,8 +77,6 @@ class SegIEPoly():
self.pts = np.array(pts) self.pts = np.array(pts)
self.n_max = self.n = len(pts) self.n_max = self.n = len(pts)
def mult_points(self, val):
self.pts *= val
@ -139,10 +137,6 @@ class SegIEPolys():
def dump(self): def dump(self):
return {'polys' : [ poly.dump() for poly in self.polys ] } return {'polys' : [ poly.dump() for poly in self.polys ] }
def mult_points(self, val):
for poly in self.polys:
poly.mult_points(val)
@staticmethod @staticmethod
def load(data=None): def load(data=None):
ie_polys = SegIEPolys() ie_polys = SegIEPolys()

View file

@ -1,5 +1,4 @@
from .estimate_sharpness import estimate_sharpness from .estimate_sharpness import estimate_sharpness
from .equalize_and_stack_square import equalize_and_stack_square from .equalize_and_stack_square import equalize_and_stack_square
from .text import get_text_image, get_draw_text_lines from .text import get_text_image, get_draw_text_lines
@ -12,21 +11,16 @@ from .warp import gen_warp_params, warp_by_params
from .reduce_colors import reduce_colors from .reduce_colors import reduce_colors
from .color_transfer import color_transfer, color_transfer_mix, color_transfer_sot, color_transfer_mkl, color_transfer_idt, color_hist_match, reinhard_color_transfer, linear_color_transfer from .color_transfer import color_transfer, color_transfer_mix, color_transfer_sot, color_transfer_mkl, color_transfer_idt, color_hist_match, reinhard_color_transfer, linear_color_transfer, seamless_clone
from .common import random_crop, normalize_channels, cut_odd_image, overlay_alpha_image from .common import normalize_channels, cut_odd_image, overlay_alpha_image
from .SegIEPolys import * from .SegIEPolys import *
from .blursharpen import LinearMotionBlur, blursharpen from .blursharpen import LinearMotionBlur, blursharpen
from .filters import apply_random_rgb_levels, \ from .filters import apply_random_rgb_levels, \
apply_random_overlay_triangle, \
apply_random_hsv_shift, \ apply_random_hsv_shift, \
apply_random_sharpen, \
apply_random_motion_blur, \ apply_random_motion_blur, \
apply_random_gaussian_blur, \ apply_random_gaussian_blur, \
apply_random_nearest_resize, \ apply_random_bilinear_resize
apply_random_bilinear_resize, \
apply_random_jpeg_compress, \
apply_random_relight

View file

@ -1,9 +1,10 @@
import cv2 import cv2
import numexpr as ne
import numpy as np import numpy as np
import scipy as sp
from numpy import linalg as npla from numpy import linalg as npla
import scipy as sp
import scipy.sparse
from scipy.sparse.linalg import spsolve
def color_transfer_sot(src,trg, steps=10, batch_size=5, reg_sigmaXY=16.0, reg_sigmaV=5.0): def color_transfer_sot(src,trg, steps=10, batch_size=5, reg_sigmaXY=16.0, reg_sigmaV=5.0):
""" """
@ -34,9 +35,8 @@ def color_transfer_sot(src,trg, steps=10, batch_size=5, reg_sigmaXY=16.0, reg_si
h,w,c = src.shape h,w,c = src.shape
new_src = src.copy() new_src = src.copy()
advect = np.empty ( (h*w,c), dtype=src_dtype )
for step in range (steps): for step in range (steps):
advect.fill(0) advect = np.zeros ( (h*w,c), dtype=src_dtype )
for batch in range (batch_size): for batch in range (batch_size):
dir = np.random.normal(size=c).astype(src_dtype) dir = np.random.normal(size=c).astype(src_dtype)
dir /= npla.norm(dir) dir /= npla.norm(dir)
@ -91,8 +91,6 @@ def color_transfer_mkl(x0, x1):
return np.clip ( result.reshape ( (h,w,c) ).astype(x0.dtype), 0, 1) return np.clip ( result.reshape ( (h,w,c) ).astype(x0.dtype), 0, 1)
def color_transfer_idt(i0, i1, bins=256, n_rot=20): def color_transfer_idt(i0, i1, bins=256, n_rot=20):
import scipy.stats
relaxation = 1 / n_rot relaxation = 1 / n_rot
h,w,c = i0.shape h,w,c = i0.shape
h1,w1,c1 = i1.shape h1,w1,c1 = i1.shape
@ -135,58 +133,136 @@ def color_transfer_idt(i0, i1, bins=256, n_rot=20):
return np.clip ( d0.T.reshape ( (h,w,c) ).astype(i0.dtype) , 0, 1) return np.clip ( d0.T.reshape ( (h,w,c) ).astype(i0.dtype) , 0, 1)
def reinhard_color_transfer(target : np.ndarray, source : np.ndarray, target_mask : np.ndarray = None, source_mask : np.ndarray = None, mask_cutoff=0.5) -> np.ndarray: def laplacian_matrix(n, m):
""" mat_D = scipy.sparse.lil_matrix((m, m))
Transfer color using rct method. mat_D.setdiag(-1, -1)
mat_D.setdiag(4)
mat_D.setdiag(-1, 1)
mat_A = scipy.sparse.block_diag([mat_D] * n).tolil()
mat_A.setdiag(-1, 1*m)
mat_A.setdiag(-1, -1*m)
return mat_A
target np.ndarray H W 3C (BGR) np.float32 def seamless_clone(source, target, mask):
source np.ndarray H W 3C (BGR) np.float32 h, w,c = target.shape
result = []
target_mask(None) np.ndarray H W 1C np.float32 mat_A = laplacian_matrix(h, w)
source_mask(None) np.ndarray H W 1C np.float32 laplacian = mat_A.tocsc()
mask_cutoff(0.5) float mask[0,:] = 1
mask[-1,:] = 1
mask[:,0] = 1
mask[:,-1] = 1
q = np.argwhere(mask==0)
masks are used to limit the space where color statistics will be computed to adjust the target k = q[:,1]+q[:,0]*w
mat_A[k, k] = 1
mat_A[k, k + 1] = 0
mat_A[k, k - 1] = 0
mat_A[k, k + w] = 0
mat_A[k, k - w] = 0
reference: Color Transfer between Images https://www.cs.tau.ac.il/~turkel/imagepapers/ColorTransfer.pdf mat_A = mat_A.tocsc()
""" mask_flat = mask.flatten()
source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB) for channel in range(c):
target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB)
source_input = source source_flat = source[:, :, channel].flatten()
if source_mask is not None: target_flat = target[:, :, channel].flatten()
source_input = source_input.copy()
source_input[source_mask[...,0] < mask_cutoff] = [0,0,0]
target_input = target mat_b = laplacian.dot(source_flat)*0.75
if target_mask is not None: mat_b[mask_flat==0] = target_flat[mask_flat==0]
target_input = target_input.copy()
target_input[target_mask[...,0] < mask_cutoff] = [0,0,0]
target_l_mean, target_l_std, target_a_mean, target_a_std, target_b_mean, target_b_std, \ x = spsolve(mat_A, mat_b).reshape((h, w))
= target_input[...,0].mean(), target_input[...,0].std(), target_input[...,1].mean(), target_input[...,1].std(), target_input[...,2].mean(), target_input[...,2].std() result.append (x)
source_l_mean, source_l_std, source_a_mean, source_a_std, source_b_mean, source_b_std, \
= source_input[...,0].mean(), source_input[...,0].std(), source_input[...,1].mean(), source_input[...,1].std(), source_input[...,2].mean(), source_input[...,2].std()
# not as in the paper: scale by the standard deviations using reciprocal of paper proposed factor return np.clip( np.dstack(result), 0, 1 )
target_l = target[...,0]
target_l = ne.evaluate('(target_l - target_l_mean) * source_l_std / target_l_std + source_l_mean')
target_a = target[...,1] def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, source_mask=None, target_mask=None):
target_a = ne.evaluate('(target_a - target_a_mean) * source_a_std / target_a_std + source_a_mean') """
Transfers the color distribution from the source to the target
image using the mean and standard deviations of the L*a*b*
color space.
target_b = target[...,2] This implementation is (loosely) based on to the "Color Transfer
target_b = ne.evaluate('(target_b - target_b_mean) * source_b_std / target_b_std + source_b_mean') between Images" paper by Reinhard et al., 2001.
np.clip(target_l, 0, 100, out=target_l) Parameters:
np.clip(target_a, -127, 127, out=target_a) -------
np.clip(target_b, -127, 127, out=target_b) source: NumPy array
OpenCV image in BGR color space (the source image)
target: NumPy array
OpenCV image in BGR color space (the target image)
clip: Should components of L*a*b* image be scaled by np.clip before
converting back to BGR color space?
If False then components will be min-max scaled appropriately.
Clipping will keep target image brightness truer to the input.
Scaling will adjust image brightness to avoid washed out portions
in the resulting color transfer that can be caused by clipping.
preserve_paper: Should color transfer strictly follow methodology
layed out in original paper? The method does not always produce
aesthetically pleasing results.
If False then L*a*b* components will scaled using the reciprocal of
the scaling factor proposed in the paper. This method seems to produce
more consistently aesthetically pleasing results
return cv2.cvtColor(np.stack([target_l,target_a,target_b], -1), cv2.COLOR_LAB2BGR) Returns:
-------
transfer: NumPy array
OpenCV image (w, h, 3) NumPy array (uint8)
"""
# convert the images from the RGB to L*ab* color space, being
# sure to utilizing the floating point data type (note: OpenCV
# expects floats to be 32-bit, so use that instead of 64-bit)
source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
# compute color statistics for the source and target images
src_input = source if source_mask is None else source*source_mask
tgt_input = target if target_mask is None else target*target_mask
(lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = lab_image_stats(src_input)
(lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = lab_image_stats(tgt_input)
# subtract the means from the target image
(l, a, b) = cv2.split(target)
l -= lMeanTar
a -= aMeanTar
b -= bMeanTar
if preserve_paper:
# scale by the standard deviations using paper proposed factor
l = (lStdTar / lStdSrc) * l
a = (aStdTar / aStdSrc) * a
b = (bStdTar / bStdSrc) * b
else:
# scale by the standard deviations using reciprocal of paper proposed factor
l = (lStdSrc / lStdTar) * l
a = (aStdSrc / aStdTar) * a
b = (bStdSrc / bStdTar) * b
# add in the source mean
l += lMeanSrc
a += aMeanSrc
b += bMeanSrc
# clip/scale the pixel intensities to [0, 255] if they fall
# outside this range
l = _scale_array(l, clip=clip)
a = _scale_array(a, clip=clip)
b = _scale_array(b, clip=clip)
# merge the channels together and convert back to the RGB color
# space, being sure to utilize the 8-bit unsigned integer data
# type
transfer = cv2.merge([l, a, b])
transfer = cv2.cvtColor(transfer.astype(np.uint8), cv2.COLOR_LAB2BGR)
# return the color transferred image
return transfer
def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5): def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5):
''' '''
Matches the colour distribution of the target image to that of the source image Matches the colour distribution of the target image to that of the source image
@ -323,7 +399,9 @@ def color_transfer(ct_mode, img_src, img_trg):
if ct_mode == 'lct': if ct_mode == 'lct':
out = linear_color_transfer (img_src, img_trg) out = linear_color_transfer (img_src, img_trg)
elif ct_mode == 'rct': elif ct_mode == 'rct':
out = reinhard_color_transfer(img_src, img_trg) out = reinhard_color_transfer ( np.clip( img_src*255, 0, 255 ).astype(np.uint8),
np.clip( img_trg*255, 0, 255 ).astype(np.uint8) )
out = np.clip( out.astype(np.float32) / 255.0, 0.0, 1.0)
elif ct_mode == 'mkl': elif ct_mode == 'mkl':
out = color_transfer_mkl (img_src, img_trg) out = color_transfer_mkl (img_src, img_trg)
elif ct_mode == 'idt': elif ct_mode == 'idt':

View file

@ -1,16 +1,5 @@
import numpy as np import numpy as np
def random_crop(img, w, h):
height, width = img.shape[:2]
h_rnd = height - h
w_rnd = width - w
y = np.random.randint(0, h_rnd) if h_rnd > 0 else 0
x = np.random.randint(0, w_rnd) if w_rnd > 0 else 0
return img[y:y+height, x:x+width]
def normalize_channels(img, target_channels): def normalize_channels(img, target_channels):
img_shape_len = len(img.shape) img_shape_len = len(img.shape)
if img_shape_len == 2: if img_shape_len == 2:

View file

@ -31,7 +31,9 @@ goods or services; loss of use, data, or profits; or business interruption) howe
import numpy as np import numpy as np
import cv2 import cv2
from math import atan2, pi from math import atan2, pi
from scipy.ndimage import convolve
from skimage.filters.edges import HSOBEL_WEIGHTS
from skimage.feature import canny
def sobel(image): def sobel(image):
# type: (numpy.ndarray) -> numpy.ndarray # type: (numpy.ndarray) -> numpy.ndarray
@ -40,11 +42,10 @@ def sobel(image):
Inspired by the [Octave implementation](https://sourceforge.net/p/octave/image/ci/default/tree/inst/edge.m#l196). Inspired by the [Octave implementation](https://sourceforge.net/p/octave/image/ci/default/tree/inst/edge.m#l196).
""" """
from skimage.filters.edges import HSOBEL_WEIGHTS
h1 = np.array(HSOBEL_WEIGHTS) h1 = np.array(HSOBEL_WEIGHTS)
h1 /= np.sum(abs(h1)) # normalize h1 h1 /= np.sum(abs(h1)) # normalize h1
from scipy.ndimage import convolve
strength2 = np.square(convolve(image, h1.T)) strength2 = np.square(convolve(image, h1.T))
# Note: https://sourceforge.net/p/octave/image/ci/default/tree/inst/edge.m#l59 # Note: https://sourceforge.net/p/octave/image/ci/default/tree/inst/edge.m#l59
@ -102,7 +103,6 @@ def compute(image):
# edge detection using canny and sobel canny edge detection is done to # edge detection using canny and sobel canny edge detection is done to
# classify the blocks as edge or non-edge blocks and sobel edge # classify the blocks as edge or non-edge blocks and sobel edge
# detection is done for the purpose of edge width measurement. # detection is done for the purpose of edge width measurement.
from skimage.feature import canny
canny_edges = canny(image) canny_edges = canny(image)
sobel_edges = sobel(image) sobel_edges = sobel(image)
@ -269,10 +269,9 @@ def get_block_contrast(block):
def estimate_sharpness(image): def estimate_sharpness(image):
height, width = image.shape[:2]
if image.ndim == 3: if image.ndim == 3:
if image.shape[2] > 1: image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
else:
image = image[...,0]
return compute(image) return compute(image)

View file

@ -1,5 +1,5 @@
import numpy as np import numpy as np
from .blursharpen import LinearMotionBlur, blursharpen from .blursharpen import LinearMotionBlur
import cv2 import cv2
def apply_random_rgb_levels(img, mask=None, rnd_state=None): def apply_random_rgb_levels(img, mask=None, rnd_state=None):
@ -38,24 +38,6 @@ def apply_random_hsv_shift(img, mask=None, rnd_state=None):
return result return result
def apply_random_sharpen( img, chance, kernel_max_size, mask=None, rnd_state=None ):
if rnd_state is None:
rnd_state = np.random
sharp_rnd_kernel = rnd_state.randint(kernel_max_size)+1
result = img
if rnd_state.randint(100) < np.clip(chance, 0, 100):
if rnd_state.randint(2) == 0:
result = blursharpen(result, 1, sharp_rnd_kernel, rnd_state.randint(10) )
else:
result = blursharpen(result, 2, sharp_rnd_kernel, rnd_state.randint(50) )
if mask is not None:
result = img*(1-mask) + result*mask
return result
def apply_random_motion_blur( img, chance, mb_max_size, mask=None, rnd_state=None ): def apply_random_motion_blur( img, chance, mb_max_size, mask=None, rnd_state=None ):
if rnd_state is None: if rnd_state is None:
rnd_state = np.random rnd_state = np.random
@ -84,7 +66,8 @@ def apply_random_gaussian_blur( img, chance, kernel_max_size, mask=None, rnd_sta
return result return result
def apply_random_resize( img, chance, max_size_per, interpolation=cv2.INTER_LINEAR, mask=None, rnd_state=None ):
def apply_random_bilinear_resize( img, chance, max_size_per, mask=None, rnd_state=None ):
if rnd_state is None: if rnd_state is None:
rnd_state = np.random rnd_state = np.random
@ -96,150 +79,9 @@ def apply_random_resize( img, chance, max_size_per, interpolation=cv2.INTER_LINE
rw = w - int( trg * int(w*(max_size_per/100.0)) ) rw = w - int( trg * int(w*(max_size_per/100.0)) )
rh = h - int( trg * int(h*(max_size_per/100.0)) ) rh = h - int( trg * int(h*(max_size_per/100.0)) )
result = cv2.resize (result, (rw,rh), interpolation=interpolation ) result = cv2.resize (result, (rw,rh), cv2.INTER_LINEAR )
result = cv2.resize (result, (w,h), interpolation=interpolation ) result = cv2.resize (result, (w,h), cv2.INTER_LINEAR )
if mask is not None: if mask is not None:
result = img*(1-mask) + result*mask result = img*(1-mask) + result*mask
return result return result
def apply_random_nearest_resize( img, chance, max_size_per, mask=None, rnd_state=None ):
return apply_random_resize( img, chance, max_size_per, interpolation=cv2.INTER_NEAREST, mask=mask, rnd_state=rnd_state )
def apply_random_bilinear_resize( img, chance, max_size_per, mask=None, rnd_state=None ):
return apply_random_resize( img, chance, max_size_per, interpolation=cv2.INTER_LINEAR, mask=mask, rnd_state=rnd_state )
def apply_random_jpeg_compress( img, chance, mask=None, rnd_state=None ):
if rnd_state is None:
rnd_state = np.random
result = img
if rnd_state.randint(100) < np.clip(chance, 0, 100):
h,w,c = result.shape
quality = rnd_state.randint(10,101)
ret, result = cv2.imencode('.jpg', np.clip(img*255, 0,255).astype(np.uint8), [int(cv2.IMWRITE_JPEG_QUALITY), quality] )
if ret == True:
result = cv2.imdecode(result, flags=cv2.IMREAD_UNCHANGED)
result = result.astype(np.float32) / 255.0
if mask is not None:
result = img*(1-mask) + result*mask
return result
def apply_random_overlay_triangle( img, max_alpha, mask=None, rnd_state=None ):
if rnd_state is None:
rnd_state = np.random
h,w,c = img.shape
pt1 = [rnd_state.randint(w), rnd_state.randint(h) ]
pt2 = [rnd_state.randint(w), rnd_state.randint(h) ]
pt3 = [rnd_state.randint(w), rnd_state.randint(h) ]
alpha = rnd_state.uniform()*max_alpha
tri_mask = cv2.fillPoly( np.zeros_like(img), [ np.array([pt1,pt2,pt3], np.int32) ], (alpha,)*c )
if rnd_state.randint(2) == 0:
result = np.clip(img+tri_mask, 0, 1)
else:
result = np.clip(img-tri_mask, 0, 1)
if mask is not None:
result = img*(1-mask) + result*mask
return result
def _min_resize(x, m):
if x.shape[0] < x.shape[1]:
s0 = m
s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
else:
s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
s1 = m
new_max = min(s1, s0)
raw_max = min(x.shape[0], x.shape[1])
return cv2.resize(x, (s1, s0), interpolation=cv2.INTER_LANCZOS4)
def _d_resize(x, d, fac=1.0):
new_min = min(int(d[1] * fac), int(d[0] * fac))
raw_min = min(x.shape[0], x.shape[1])
if new_min < raw_min:
interpolation = cv2.INTER_AREA
else:
interpolation = cv2.INTER_LANCZOS4
y = cv2.resize(x, (int(d[1] * fac), int(d[0] * fac)), interpolation=interpolation)
return y
def _get_image_gradient(dist):
cols = cv2.filter2D(dist, cv2.CV_32F, np.array([[-1, 0, +1], [-2, 0, +2], [-1, 0, +1]]))
rows = cv2.filter2D(dist, cv2.CV_32F, np.array([[-1, -2, -1], [0, 0, 0], [+1, +2, +1]]))
return cols, rows
def _generate_lighting_effects(content):
h512 = content
h256 = cv2.pyrDown(h512)
h128 = cv2.pyrDown(h256)
h64 = cv2.pyrDown(h128)
h32 = cv2.pyrDown(h64)
h16 = cv2.pyrDown(h32)
c512, r512 = _get_image_gradient(h512)
c256, r256 = _get_image_gradient(h256)
c128, r128 = _get_image_gradient(h128)
c64, r64 = _get_image_gradient(h64)
c32, r32 = _get_image_gradient(h32)
c16, r16 = _get_image_gradient(h16)
c = c16
c = _d_resize(cv2.pyrUp(c), c32.shape) * 4.0 + c32
c = _d_resize(cv2.pyrUp(c), c64.shape) * 4.0 + c64
c = _d_resize(cv2.pyrUp(c), c128.shape) * 4.0 + c128
c = _d_resize(cv2.pyrUp(c), c256.shape) * 4.0 + c256
c = _d_resize(cv2.pyrUp(c), c512.shape) * 4.0 + c512
r = r16
r = _d_resize(cv2.pyrUp(r), r32.shape) * 4.0 + r32
r = _d_resize(cv2.pyrUp(r), r64.shape) * 4.0 + r64
r = _d_resize(cv2.pyrUp(r), r128.shape) * 4.0 + r128
r = _d_resize(cv2.pyrUp(r), r256.shape) * 4.0 + r256
r = _d_resize(cv2.pyrUp(r), r512.shape) * 4.0 + r512
coarse_effect_cols = c
coarse_effect_rows = r
EPS = 1e-10
max_effect = np.max((coarse_effect_cols**2 + coarse_effect_rows**2)**0.5, axis=0, keepdims=True, ).max(1, keepdims=True)
coarse_effect_cols = (coarse_effect_cols + EPS) / (max_effect + EPS)
coarse_effect_rows = (coarse_effect_rows + EPS) / (max_effect + EPS)
return np.stack([ np.zeros_like(coarse_effect_rows), coarse_effect_rows, coarse_effect_cols], axis=-1)
def apply_random_relight(img, mask=None, rnd_state=None):
if rnd_state is None:
rnd_state = np.random
def_img = img
if rnd_state.randint(2) == 0:
light_pos_y = 1.0 if rnd_state.randint(2) == 0 else -1.0
light_pos_x = rnd_state.uniform()*2-1.0
else:
light_pos_y = rnd_state.uniform()*2-1.0
light_pos_x = 1.0 if rnd_state.randint(2) == 0 else -1.0
light_source_height = 0.3*rnd_state.uniform()*0.7
light_intensity = 1.0+rnd_state.uniform()
ambient_intensity = 0.5
light_source_location = np.array([[[light_source_height, light_pos_y, light_pos_x ]]], dtype=np.float32)
light_source_direction = light_source_location / np.sqrt(np.sum(np.square(light_source_location)))
lighting_effect = _generate_lighting_effects(img)
lighting_effect = np.sum(lighting_effect * light_source_direction, axis=-1).clip(0, 1)
lighting_effect = np.mean(lighting_effect, axis=-1, keepdims=True)
result = def_img * (ambient_intensity + lighting_effect * light_intensity) #light_source_color
result = np.clip(result, 0, 1)
if mask is not None:
result = def_img*(1-mask) + result*mask
return result

View file

@ -1,2 +1,2 @@
from .draw import circle_faded, random_circle_faded, bezier, random_bezier_split_faded, random_faded from .draw import *
from .calc import * from .calc import *

View file

@ -1,36 +1,23 @@
""" """
Signed distance drawing functions using numpy. Signed distance drawing functions using numpy.
""" """
import math
import numpy as np import numpy as np
from numpy import linalg as npla from numpy import linalg as npla
def circle_faded( hw, center, fade_dists ):
def vector2_dot(a,b):
return a[...,0]*b[...,0]+a[...,1]*b[...,1]
def vector2_dot2(a):
return a[...,0]*a[...,0]+a[...,1]*a[...,1]
def vector2_cross(a,b):
return a[...,0]*b[...,1]-a[...,1]*b[...,0]
def circle_faded( wh, center, fade_dists ):
""" """
returns drawn circle in [h,w,1] output range [0..1.0] float32 returns drawn circle in [h,w,1] output range [0..1.0] float32
wh = [w,h] resolution hw = [h,w] resolution
center = [x,y] center of circle center = [y,x] center of circle
fade_dists = [fade_start, fade_end] fade values fade_dists = [fade_start, fade_end] fade values
""" """
w,h = wh h,w = hw
pts = np.empty( (h,w,2), dtype=np.float32 ) pts = np.empty( (h,w,2), dtype=np.float32 )
pts[...,0] = np.arange(w)[:,None]
pts[...,1] = np.arange(h)[None,:] pts[...,1] = np.arange(h)[None,:]
pts[...,0] = np.arange(w)[:,None]
pts = pts.reshape ( (h*w, -1) ) pts = pts.reshape ( (h*w, -1) )
pts_dists = np.abs ( npla.norm(pts-center, axis=-1) ) pts_dists = np.abs ( npla.norm(pts-center, axis=-1) )
@ -44,157 +31,14 @@ def circle_faded( wh, center, fade_dists ):
return pts_dists.reshape ( (h,w,1) ).astype(np.float32) return pts_dists.reshape ( (h,w,1) ).astype(np.float32)
def random_circle_faded ( hw, rnd_state=None ):
def bezier( wh, A, B, C ):
"""
returns drawn bezier in [h,w,1] output range float32,
every pixel contains signed distance to bezier line
wh [w,h] resolution
A,B,C points [x,y]
"""
width,height = wh
A = np.float32(A)
B = np.float32(B)
C = np.float32(C)
pos = np.empty( (height,width,2), dtype=np.float32 )
pos[...,0] = np.arange(width)[:,None]
pos[...,1] = np.arange(height)[None,:]
a = B-A
b = A - 2.0*B + C
c = a * 2.0
d = A - pos
b_dot = vector2_dot(b,b)
if b_dot == 0.0:
return np.zeros( (height,width), dtype=np.float32 )
kk = 1.0 / b_dot
kx = kk * vector2_dot(a,b)
ky = kk * (2.0*vector2_dot(a,a)+vector2_dot(d,b))/3.0;
kz = kk * vector2_dot(d,a);
res = 0.0;
sgn = 0.0;
p = ky - kx*kx;
p3 = p*p*p;
q = kx*(2.0*kx*kx - 3.0*ky) + kz;
h = q*q + 4.0*p3;
hp_sel = h >= 0.0
hp_p = h[hp_sel]
hp_p = np.sqrt(hp_p)
hp_x = ( np.stack( (hp_p,-hp_p), -1) -q[hp_sel,None] ) / 2.0
hp_uv = np.sign(hp_x) * np.power( np.abs(hp_x), [1.0/3.0, 1.0/3.0] )
hp_t = np.clip( hp_uv[...,0] + hp_uv[...,1] - kx, 0.0, 1.0 )
hp_t = hp_t[...,None]
hp_q = d[hp_sel]+(c+b*hp_t)*hp_t
hp_res = vector2_dot2(hp_q)
hp_sgn = vector2_cross(c+2.0*b*hp_t,hp_q)
hl_sel = h < 0.0
hl_q = q[hl_sel]
hl_p = p[hl_sel]
hl_z = np.sqrt(-hl_p)
hl_v = np.arccos( hl_q / (hl_p*hl_z*2.0)) / 3.0
hl_m = np.cos(hl_v)
hl_n = np.sin(hl_v)*1.732050808;
hl_t = np.clip( np.stack( (hl_m+hl_m,-hl_n-hl_m,hl_n-hl_m), -1)*hl_z[...,None]-kx, 0.0, 1.0 );
hl_d = d[hl_sel]
hl_qx = hl_d+(c+b*hl_t[...,0:1])*hl_t[...,0:1]
hl_dx = vector2_dot2(hl_qx)
hl_sx = vector2_cross(c+2.0*b*hl_t[...,0:1], hl_qx)
hl_qy = hl_d+(c+b*hl_t[...,1:2])*hl_t[...,1:2]
hl_dy = vector2_dot2(hl_qy)
hl_sy = vector2_cross(c+2.0*b*hl_t[...,1:2],hl_qy);
hl_dx_l_dy = hl_dx<hl_dy
hl_dx_ge_dy = hl_dx>=hl_dy
hl_res = np.empty_like(hl_dx)
hl_res[hl_dx_l_dy] = hl_dx[hl_dx_l_dy]
hl_res[hl_dx_ge_dy] = hl_dy[hl_dx_ge_dy]
hl_sgn = np.empty_like(hl_sx)
hl_sgn[hl_dx_l_dy] = hl_sx[hl_dx_l_dy]
hl_sgn[hl_dx_ge_dy] = hl_sy[hl_dx_ge_dy]
res = np.empty( (height, width), np.float32 )
res[hp_sel] = hp_res
res[hl_sel] = hl_res
sgn = np.empty( (height, width), np.float32 )
sgn[hp_sel] = hp_sgn
sgn[hl_sel] = hl_sgn
sgn = np.sign(sgn)
res = np.sqrt(res)*sgn
return res[...,None]
def random_faded(wh):
"""
apply one of them:
random_circle_faded
random_bezier_split_faded
"""
rnd = np.random.randint(2)
if rnd == 0:
return random_circle_faded(wh)
elif rnd == 1:
return random_bezier_split_faded(wh)
def random_circle_faded ( wh, rnd_state=None ):
if rnd_state is None: if rnd_state is None:
rnd_state = np.random rnd_state = np.random
w,h = wh h,w = hw
wh_max = max(w,h) hw_max = max(h,w)
fade_start = rnd_state.randint(wh_max) fade_start = rnd_state.randint(hw_max)
fade_end = fade_start + rnd_state.randint(wh_max- fade_start) fade_end = fade_start + rnd_state.randint(hw_max- fade_start)
return circle_faded (wh, [ rnd_state.randint(h), rnd_state.randint(w) ], return circle_faded (hw, [ rnd_state.randint(h), rnd_state.randint(w) ],
[fade_start, fade_end] ) [fade_start, fade_end] )
def random_bezier_split_faded( wh ):
width, height = wh
degA = np.random.randint(360)
degB = np.random.randint(360)
degC = np.random.randint(360)
deg_2_rad = math.pi / 180.0
center = np.float32([width / 2.0, height / 2.0])
radius = max(width, height)
A = center + radius*np.float32([ math.sin( degA * deg_2_rad), math.cos( degA * deg_2_rad) ] )
B = center + np.random.randint(radius)*np.float32([ math.sin( degB * deg_2_rad), math.cos( degB * deg_2_rad) ] )
C = center + radius*np.float32([ math.sin( degC * deg_2_rad), math.cos( degC * deg_2_rad) ] )
x = bezier( (width,height), A, B, C )
x = x / (1+np.random.randint(radius)) + 0.5
x = np.clip(x, 0, 1)
return x

View file

@ -1,146 +1,32 @@
import numpy as np import numpy as np
import numpy.linalg as npla
import cv2 import cv2
from core import randomex from core import randomex
def mls_rigid_deformation(vy, vx, src_pts, dst_pts, alpha=1.0, eps=1e-8): def gen_warp_params (w, flip, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_state=None ):
dst_pts = dst_pts[..., ::-1].astype(np.int16)
src_pts = src_pts[..., ::-1].astype(np.int16)
src_pts, dst_pts = dst_pts, src_pts
grow = vx.shape[0]
gcol = vx.shape[1]
ctrls = src_pts.shape[0]
reshaped_p = src_pts.reshape(ctrls, 2, 1, 1)
reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol)))
w = 1.0 / (np.sum((reshaped_p - reshaped_v).astype(np.float32) ** 2, axis=1) + eps) ** alpha
w /= np.sum(w, axis=0, keepdims=True)
pstar = np.zeros((2, grow, gcol), np.float32)
for i in range(ctrls):
pstar += w[i] * reshaped_p[i]
vpstar = reshaped_v - pstar
reshaped_mul_right = np.concatenate((vpstar[:,None,...],
np.concatenate((vpstar[1:2,None,...],-vpstar[0:1,None,...]), 0)
), axis=1).transpose(2, 3, 0, 1)
reshaped_q = dst_pts.reshape((ctrls, 2, 1, 1))
qstar = np.zeros((2, grow, gcol), np.float32)
for i in range(ctrls):
qstar += w[i] * reshaped_q[i]
temp = np.zeros((grow, gcol, 2), np.float32)
for i in range(ctrls):
phat = reshaped_p[i] - pstar
qhat = reshaped_q[i] - qstar
temp += np.matmul(qhat.reshape(1, 2, grow, gcol).transpose(2, 3, 0, 1),
np.matmul( ( w[None, i:i+1,...] *
np.concatenate((phat.reshape(1, 2, grow, gcol),
np.concatenate( (phat[None,1:2], -phat[None,0:1]), 1 )), 0)
).transpose(2, 3, 0, 1), reshaped_mul_right
)
).reshape(grow, gcol, 2)
temp = temp.transpose(2, 0, 1)
normed_temp = np.linalg.norm(temp, axis=0, keepdims=True)
normed_vpstar = np.linalg.norm(vpstar, axis=0, keepdims=True)
nan_mask = normed_temp[0]==0
transformers = np.true_divide(temp, normed_temp, out=np.zeros_like(temp), where= ~nan_mask) * normed_vpstar + qstar
nan_mask_flat = np.flatnonzero(nan_mask)
nan_mask_anti_flat = np.flatnonzero(~nan_mask)
transformers[0][nan_mask] = np.interp(nan_mask_flat, nan_mask_anti_flat, transformers[0][~nan_mask])
transformers[1][nan_mask] = np.interp(nan_mask_flat, nan_mask_anti_flat, transformers[1][~nan_mask])
return transformers
def gen_pts(W, H, rnd_state=None):
if rnd_state is None: if rnd_state is None:
rnd_state = np.random rnd_state = np.random
min_pts, max_pts = 4, 8
n_pts = rnd_state.randint(min_pts, max_pts)
min_radius_per = 0.00
max_radius_per = 0.10
pts = []
for i in range(n_pts):
while True:
x, y = rnd_state.randint(W), rnd_state.randint(H)
rad = min_radius_per + rnd_state.rand()*(max_radius_per-min_radius_per)
intersect = False
for px,py,prad,_,_ in pts:
dist = npla.norm([x-px, y-py])
if dist <= (rad+prad)*2:
intersect = True
break
if intersect:
continue
angle = rnd_state.rand()*(2*np.pi)
x2 = int(x+np.cos(angle)*W*rad)
y2 = int(y+np.sin(angle)*H*rad)
break
pts.append( (x,y,rad, x2,y2) )
pts1 = np.array( [ [pt[0],pt[1]] for pt in pts ] )
pts2 = np.array( [ [pt[-2],pt[-1]] for pt in pts ] )
return pts1, pts2
def gen_warp_params (w, flip=False, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_state=None, warp_rnd_state=None ):
if rnd_state is None:
rnd_state = np.random
if warp_rnd_state is None:
warp_rnd_state = np.random
rw = None
if w < 64:
rw = w
w = 64
rotation = rnd_state.uniform( rotation_range[0], rotation_range[1] ) rotation = rnd_state.uniform( rotation_range[0], rotation_range[1] )
scale = rnd_state.uniform( 1/(1-scale_range[0]) , 1+scale_range[1] ) scale = rnd_state.uniform(1 +scale_range[0], 1 +scale_range[1])
tx = rnd_state.uniform( tx_range[0], tx_range[1] ) tx = rnd_state.uniform( tx_range[0], tx_range[1] )
ty = rnd_state.uniform( ty_range[0], ty_range[1] ) ty = rnd_state.uniform( ty_range[0], ty_range[1] )
p_flip = flip and rnd_state.randint(10) < 4 p_flip = flip and rnd_state.randint(10) < 4
#random warp V1 #random warp by grid
cell_size = [ w // (2**i) for i in range(1,4) ] [ warp_rnd_state.randint(3) ] cell_size = [ w // (2**i) for i in range(1,4) ] [ rnd_state.randint(3) ]
cell_count = w // cell_size + 1 cell_count = w // cell_size + 1
grid_points = np.linspace( 0, w, cell_count) grid_points = np.linspace( 0, w, cell_count)
mapx = np.broadcast_to(grid_points, (cell_count, cell_count)).copy() mapx = np.broadcast_to(grid_points, (cell_count, cell_count)).copy()
mapy = mapx.T mapy = mapx.T
mapx[1:-1,1:-1] = mapx[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2), rnd_state=warp_rnd_state )*(cell_size*0.24)
mapy[1:-1,1:-1] = mapy[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2), rnd_state=warp_rnd_state )*(cell_size*0.24) mapx[1:-1,1:-1] = mapx[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2) )*(cell_size*0.24)
mapy[1:-1,1:-1] = mapy[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2) )*(cell_size*0.24)
half_cell_size = cell_size // 2 half_cell_size = cell_size // 2
mapx = cv2.resize(mapx, (w+cell_size,)*2 )[half_cell_size:-half_cell_size,half_cell_size:-half_cell_size].astype(np.float32) mapx = cv2.resize(mapx, (w+cell_size,)*2 )[half_cell_size:-half_cell_size,half_cell_size:-half_cell_size].astype(np.float32)
mapy = cv2.resize(mapy, (w+cell_size,)*2 )[half_cell_size:-half_cell_size,half_cell_size:-half_cell_size].astype(np.float32) mapy = cv2.resize(mapy, (w+cell_size,)*2 )[half_cell_size:-half_cell_size,half_cell_size:-half_cell_size].astype(np.float32)
##############
# random warp V2
# pts1, pts2 = gen_pts(w, w, rnd_state)
# gridX = np.arange(w, dtype=np.int16)
# gridY = np.arange(w, dtype=np.int16)
# vy, vx = np.meshgrid(gridX, gridY)
# drigid = mls_rigid_deformation(vy, vx, pts1, pts2)
# mapy, mapx = drigid.astype(np.float32)
################
#random transform #random transform
random_transform_mat = cv2.getRotationMatrix2D((w // 2, w // 2), rotation, scale) random_transform_mat = cv2.getRotationMatrix2D((w // 2, w // 2), rotation, scale)
@ -150,30 +36,16 @@ def gen_warp_params (w, flip=False, rotation_range=[-10,10], scale_range=[-0.5,
params['mapx'] = mapx params['mapx'] = mapx
params['mapy'] = mapy params['mapy'] = mapy
params['rmat'] = random_transform_mat params['rmat'] = random_transform_mat
u_mat = random_transform_mat.copy()
u_mat[:,2] /= w
params['umat'] = u_mat
params['w'] = w params['w'] = w
params['rw'] = rw
params['flip'] = p_flip params['flip'] = p_flip
return params return params
def warp_by_params (params, img, can_warp, can_transform, can_flip, border_replicate, cv2_inter=cv2.INTER_CUBIC): def warp_by_params (params, img, can_warp, can_transform, can_flip, border_replicate, cv2_inter=cv2.INTER_CUBIC):
rw = params['rw']
if (can_warp or can_transform) and rw is not None:
img = cv2.resize(img, (64,64), interpolation=cv2_inter)
if can_warp: if can_warp:
img = cv2.remap(img, params['mapx'], params['mapy'], cv2_inter ) img = cv2.remap(img, params['mapx'], params['mapy'], cv2_inter )
if can_transform: if can_transform:
img = cv2.warpAffine( img, params['rmat'], (params['w'], params['w']), borderMode=(cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT), flags=cv2_inter ) img = cv2.warpAffine( img, params['rmat'], (params['w'], params['w']), borderMode=(cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT), flags=cv2_inter )
if (can_warp or can_transform) and rw is not None:
img = cv2.resize(img, (rw,rw), interpolation=cv2_inter)
if len(img.shape) == 2: if len(img.shape) == 2:
img = img[...,None] img = img[...,None]
if can_flip and params['flip']: if can_flip and params['flip']:

View file

@ -7,7 +7,6 @@ import types
import colorama import colorama
import cv2 import cv2
import numpy as np
from tqdm import tqdm from tqdm import tqdm
from core import stdex from core import stdex
@ -198,7 +197,7 @@ class InteractBase(object):
def add_key_event(self, wnd_name, ord_key, ctrl_pressed, alt_pressed, shift_pressed): def add_key_event(self, wnd_name, ord_key, ctrl_pressed, alt_pressed, shift_pressed):
if wnd_name not in self.key_events: if wnd_name not in self.key_events:
self.key_events[wnd_name] = [] self.key_events[wnd_name] = []
self.key_events[wnd_name] += [ (ord_key, chr(ord_key) if ord_key <= 255 else chr(0), ctrl_pressed, alt_pressed, shift_pressed) ] self.key_events[wnd_name] += [ (ord_key, chr(ord_key), ctrl_pressed, alt_pressed, shift_pressed) ]
def get_mouse_events(self, wnd_name): def get_mouse_events(self, wnd_name):
ar = self.mouse_events.get(wnd_name, []) ar = self.mouse_events.get(wnd_name, [])
@ -256,7 +255,7 @@ class InteractBase(object):
print(result) print(result)
return result return result
def input_int(self, s, default_value, valid_range=None, valid_list=None, add_info=None, show_default_value=True, help_message=None): def input_int(self, s, default_value, valid_list=None, add_info=None, show_default_value=True, help_message=None):
if show_default_value: if show_default_value:
if len(s) != 0: if len(s) != 0:
s = f"[{default_value}] {s}" s = f"[{default_value}] {s}"
@ -264,21 +263,15 @@ class InteractBase(object):
s = f"[{default_value}]" s = f"[{default_value}]"
if add_info is not None or \ if add_info is not None or \
valid_range is not None or \
help_message is not None: help_message is not None:
s += " (" s += " ("
if valid_range is not None:
s += f" {valid_range[0]}-{valid_range[1]}"
if add_info is not None: if add_info is not None:
s += f" {add_info}" s += f" {add_info}"
if help_message is not None: if help_message is not None:
s += " ?:help" s += " ?:help"
if add_info is not None or \ if add_info is not None or \
valid_range is not None or \
help_message is not None: help_message is not None:
s += " )" s += " )"
@ -295,12 +288,9 @@ class InteractBase(object):
continue continue
i = int(inp) i = int(inp)
if valid_range is not None:
i = int(np.clip(i, valid_range[0], valid_range[1]))
if (valid_list is not None) and (i not in valid_list): if (valid_list is not None) and (i not in valid_list):
i = default_value result = default_value
break
result = i result = i
break break
except: except:
@ -501,11 +491,10 @@ class InteractDesktop(InteractBase):
if has_windows or has_capture_keys: if has_windows or has_capture_keys:
wait_key_time = max(1, int(sleep_time*1000) ) wait_key_time = max(1, int(sleep_time*1000) )
ord_key = cv2.waitKeyEx(wait_key_time) ord_key = cv2.waitKey(wait_key_time)
shift_pressed = False shift_pressed = False
if ord_key != -1: if ord_key != -1:
chr_key = chr(ord_key) if ord_key <= 255 else chr(0) chr_key = chr(ord_key)
if chr_key >= 'A' and chr_key <= 'Z': if chr_key >= 'A' and chr_key <= 'Z':
shift_pressed = True shift_pressed = True

View file

@ -81,8 +81,11 @@ class Subprocessor(object):
except Subprocessor.SilenceException as e: except Subprocessor.SilenceException as e:
c2s.put ( {'op': 'error', 'data' : data} ) c2s.put ( {'op': 'error', 'data' : data} )
except Exception as e: except Exception as e:
err_msg = traceback.format_exc() c2s.put ( {'op': 'error', 'data' : data} )
c2s.put ( {'op': 'error', 'data' : data, 'err_msg' : err_msg} ) if data is not None:
print ('Exception while process data [%s]: %s' % (self.get_data_name(data), traceback.format_exc()) )
else:
print ('Exception: %s' % (traceback.format_exc()) )
c2s.close() c2s.close()
s2c.close() s2c.close()
@ -156,24 +159,6 @@ class Subprocessor(object):
self.clis = [] self.clis = []
def cli_init_dispatcher(cli):
while not cli.c2s.empty():
obj = cli.c2s.get()
op = obj.get('op','')
if op == 'init_ok':
cli.state = 0
elif op == 'log_info':
io.log_info(obj['msg'])
elif op == 'log_err':
io.log_err(obj['msg'])
elif op == 'error':
err_msg = obj.get('err_msg', None)
if err_msg is not None:
io.log_info(f'Error while subprocess initialization: {err_msg}')
cli.kill()
self.clis.remove(cli)
break
#getting info about name of subprocesses, host and client dicts, and spawning them #getting info about name of subprocesses, host and client dicts, and spawning them
for name, host_dict, client_dict in self.process_info_generator(): for name, host_dict, client_dict in self.process_info_generator():
try: try:
@ -188,7 +173,19 @@ class Subprocessor(object):
if self.initialize_subprocesses_in_serial: if self.initialize_subprocesses_in_serial:
while True: while True:
cli_init_dispatcher(cli) while not cli.c2s.empty():
obj = cli.c2s.get()
op = obj.get('op','')
if op == 'init_ok':
cli.state = 0
elif op == 'log_info':
io.log_info(obj['msg'])
elif op == 'log_err':
io.log_err(obj['msg'])
elif op == 'error':
cli.kill()
self.clis.remove(cli)
break
if cli.state == 0: if cli.state == 0:
break break
io.process_messages(0.005) io.process_messages(0.005)
@ -201,7 +198,19 @@ class Subprocessor(object):
#waiting subprocesses their success(or not) initialization #waiting subprocesses their success(or not) initialization
while True: while True:
for cli in self.clis[:]: for cli in self.clis[:]:
cli_init_dispatcher(cli) while not cli.c2s.empty():
obj = cli.c2s.get()
op = obj.get('op','')
if op == 'init_ok':
cli.state = 0
elif op == 'log_info':
io.log_info(obj['msg'])
elif op == 'log_err':
io.log_err(obj['msg'])
elif op == 'error':
cli.kill()
self.clis.remove(cli)
break
if all ([cli.state == 0 for cli in self.clis]): if all ([cli.state == 0 for cli in self.clis]):
break break
io.process_messages(0.005) io.process_messages(0.005)
@ -226,10 +235,6 @@ class Subprocessor(object):
cli.state = 0 cli.state = 0
elif op == 'error': elif op == 'error':
#some error occured while process data, returning chunk to on_data_return #some error occured while process data, returning chunk to on_data_return
err_msg = obj.get('err_msg', None)
if err_msg is not None:
io.log_info(f'Error while processing data: {err_msg}')
if 'data' in obj.keys(): if 'data' in obj.keys():
self.on_data_return (cli.host_dict, obj['data'] ) self.on_data_return (cli.host_dict, obj['data'] )
#and killing process #and killing process

View file

@ -6,55 +6,49 @@ class DeepFakeArchi(nn.ArchiBase):
resolution resolution
mod None - default mod None - default
'uhd'
'quick' 'quick'
opts ''
''
't'
""" """
def __init__(self, resolution, use_fp16=False, mod=None, opts=None): def __init__(self, resolution, mod=None):
super().__init__() super().__init__()
if opts is None:
opts = ''
conv_dtype = tf.float16 if use_fp16 else tf.float32
if 'c' in opts:
def act(x, alpha=0.1):
return x*tf.cos(x)
else:
def act(x, alpha=0.1):
return tf.nn.leaky_relu(x, alpha)
if mod is None: if mod is None:
class Downscale(nn.ModelBase): class Downscale(nn.ModelBase):
def __init__(self, in_ch, out_ch, kernel_size=5, *kwargs ): def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
self.in_ch = in_ch self.in_ch = in_ch
self.out_ch = out_ch self.out_ch = out_ch
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.dilations = dilations
self.subpixel = subpixel
self.use_activator = use_activator
super().__init__(*kwargs) super().__init__(*kwargs)
def on_build(self, *args, **kwargs ): def on_build(self, *args, **kwargs ):
self.conv1 = nn.Conv2D( self.in_ch, self.out_ch, kernel_size=self.kernel_size, strides=2, padding='SAME', dtype=conv_dtype) self.conv1 = nn.Conv2D( self.in_ch,
self.out_ch // (4 if self.subpixel else 1),
kernel_size=self.kernel_size,
strides=1 if self.subpixel else 2,
padding='SAME', dilations=self.dilations)
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
x = act(x, 0.1) if self.subpixel:
x = nn.space_to_depth(x, 2)
if self.use_activator:
x = tf.nn.leaky_relu(x, 0.1)
return x return x
def get_out_ch(self): def get_out_ch(self):
return self.out_ch return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch
class DownscaleBlock(nn.ModelBase): class DownscaleBlock(nn.ModelBase):
def on_build(self, in_ch, ch, n_downscales, kernel_size): def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
self.downs = [] self.downs = []
last_ch = in_ch last_ch = in_ch
for i in range(n_downscales): for i in range(n_downscales):
cur_ch = ch*( min(2**i, 8) ) cur_ch = ch*( min(2**i, 8) )
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size)) self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
last_ch = self.downs[-1].get_out_ch() last_ch = self.downs[-1].get_out_ch()
def forward(self, inp): def forward(self, inp):
@ -64,77 +58,66 @@ class DeepFakeArchi(nn.ArchiBase):
return x return x
class Upscale(nn.ModelBase): class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3): def on_build(self, in_ch, out_ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
x = act(x, 0.1) x = tf.nn.leaky_relu(x, 0.1)
x = nn.depth_to_space(x, 2) x = nn.depth_to_space(x, 2)
return x return x
class ResidualBlock(nn.ModelBase): class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3): def on_build(self, ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
def forward(self, inp): def forward(self, inp):
x = self.conv1(inp) x = self.conv1(inp)
x = act(x, 0.2) x = tf.nn.leaky_relu(x, 0.2)
x = self.conv2(x) x = self.conv2(x)
x = act(inp + x, 0.2) x = tf.nn.leaky_relu(inp + x, 0.2)
return x return x
class UpdownResidualBlock(nn.ModelBase):
def on_build(self, ch, inner_ch, kernel_size=3 ):
self.up = Upscale (ch, inner_ch, kernel_size=kernel_size)
self.res = ResidualBlock (inner_ch, kernel_size=kernel_size)
self.down = Downscale (inner_ch, ch, kernel_size=kernel_size, use_activator=False)
def forward(self, inp):
x = self.up(inp)
x = upx = self.res(x)
x = self.down(x)
x = x + inp
x = tf.nn.leaky_relu(x, 0.2)
return x, upx
class Encoder(nn.ModelBase): class Encoder(nn.ModelBase):
def __init__(self, in_ch, e_ch, **kwargs ): def on_build(self, in_ch, e_ch, is_hd):
self.in_ch = in_ch self.is_hd=is_hd
self.e_ch = e_ch if self.is_hd:
super().__init__(**kwargs) self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, dilations=1)
self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, dilations=1)
def on_build(self): self.down3 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=5, dilations=2)
if 't' in opts: self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2)
self.down1 = Downscale(self.in_ch, self.e_ch, kernel_size=5)
self.res1 = ResidualBlock(self.e_ch)
self.down2 = Downscale(self.e_ch, self.e_ch*2, kernel_size=5)
self.down3 = Downscale(self.e_ch*2, self.e_ch*4, kernel_size=5)
self.down4 = Downscale(self.e_ch*4, self.e_ch*8, kernel_size=5)
self.down5 = Downscale(self.e_ch*8, self.e_ch*8, kernel_size=5)
self.res5 = ResidualBlock(self.e_ch*8)
else: else:
self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4 if 't' not in opts else 5, kernel_size=5) self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False)
def forward(self, x): def forward(self, inp):
if use_fp16: if self.is_hd:
x = tf.cast(x, tf.float16) x = tf.concat([ nn.flatten(self.down1(inp)),
nn.flatten(self.down2(inp)),
if 't' in opts: nn.flatten(self.down3(inp)),
x = self.down1(x) nn.flatten(self.down4(inp)) ], -1 )
x = self.res1(x)
x = self.down2(x)
x = self.down3(x)
x = self.down4(x)
x = self.down5(x)
x = self.res5(x)
else: else:
x = self.down1(x) x = nn.flatten(self.down1(inp))
x = nn.flatten(x)
if 'u' in opts:
x = nn.pixel_norm(x, axes=-1)
if use_fp16:
x = tf.cast(x, tf.float32)
return x return x
def get_out_res(self, res): lowest_dense_res = resolution // 16
return res // ( (2**4) if 't' not in opts else (2**5) )
def get_out_ch(self):
return self.e_ch * 8
lowest_dense_res = resolution // (32 if 'd' in opts else 16)
class Inter(nn.ModelBase): class Inter(nn.ModelBase):
def __init__(self, in_ch, ae_ch, ae_out_ch, **kwargs): def __init__(self, in_ch, ae_ch, ae_out_ch, is_hd=False, **kwargs):
self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch
super().__init__(**kwargs) super().__init__(**kwargs)
@ -143,81 +126,335 @@ class DeepFakeArchi(nn.ArchiBase):
self.dense1 = nn.Dense( in_ch, ae_ch ) self.dense1 = nn.Dense( in_ch, ae_ch )
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch ) self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
if 't' not in opts: self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
def forward(self, inp): def forward(self, inp):
x = inp x = self.dense1(inp)
x = self.dense1(x)
x = self.dense2(x) x = self.dense2(x)
x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch) x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
x = self.upscale1(x)
if use_fp16:
x = tf.cast(x, tf.float16)
if 't' not in opts:
x = self.upscale1(x)
return x return x
def get_out_res(self): @staticmethod
return lowest_dense_res * 2 if 't' not in opts else lowest_dense_res def get_code_res():
return lowest_dense_res
def get_out_ch(self): def get_out_ch(self):
return self.ae_out_ch return self.ae_out_ch
class Decoder(nn.ModelBase): class Decoder(nn.ModelBase):
def on_build(self, in_ch, d_ch, d_mask_ch): def on_build(self, in_ch, d_ch, d_mask_ch, is_hd ):
if 't' not in opts: self.is_hd = is_hd
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3) self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3) self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
if is_hd:
self.res0 = UpdownResidualBlock(in_ch, d_ch*8, kernel_size=3)
self.res1 = UpdownResidualBlock(d_ch*8, d_ch*4, kernel_size=3)
self.res2 = UpdownResidualBlock(d_ch*4, d_ch*2, kernel_size=3)
self.res3 = UpdownResidualBlock(d_ch*2, d_ch, kernel_size=3)
else:
self.res0 = ResidualBlock(d_ch*8, kernel_size=3) self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
self.res1 = ResidualBlock(d_ch*4, kernel_size=3) self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
self.res2 = ResidualBlock(d_ch*2, kernel_size=3) self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype) self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
if 'd' in opts: def forward(self, inp):
self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) z = inp
self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) if self.is_hd:
self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3) x, upx = self.res0(z)
self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) x = self.upscale0(x)
else: x = tf.nn.leaky_relu(x + upx, 0.2)
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) x, upx = self.res1(x)
x = self.upscale1(x)
x = tf.nn.leaky_relu(x + upx, 0.2)
x, upx = self.res2(x)
x = self.upscale2(x)
x = tf.nn.leaky_relu(x + upx, 0.2)
x, upx = self.res3(x)
else: else:
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3) x = self.upscale0(z)
self.upscale1 = Upscale(d_ch*8, d_ch*8, kernel_size=3) x = self.res0(x)
self.upscale2 = Upscale(d_ch*8, d_ch*4, kernel_size=3) x = self.upscale1(x)
self.upscale3 = Upscale(d_ch*4, d_ch*2, kernel_size=3) x = self.res1(x)
self.res0 = ResidualBlock(d_ch*8, kernel_size=3) x = self.upscale2(x)
self.res1 = ResidualBlock(d_ch*8, kernel_size=3) x = self.res2(x)
self.res2 = ResidualBlock(d_ch*4, kernel_size=3)
self.res3 = ResidualBlock(d_ch*2, kernel_size=3)
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) m = self.upscalem0(z)
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*8, kernel_size=3) m = self.upscalem1(m)
self.upscalem2 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3) m = self.upscalem2(m)
self.upscalem3 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
if 'd' in opts: return tf.nn.sigmoid(self.out_conv(x)), \
self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) tf.nn.sigmoid(self.out_convm(m))
self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.upscalem4 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
else:
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
elif mod == 'quick':
class Downscale(nn.ModelBase):
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
self.in_ch = in_ch
self.out_ch = out_ch
self.kernel_size = kernel_size
self.dilations = dilations
self.subpixel = subpixel
self.use_activator = use_activator
super().__init__(*kwargs)
def on_build(self, *args, **kwargs ):
self.conv1 = nn.Conv2D( self.in_ch,
self.out_ch // (4 if self.subpixel else 1),
kernel_size=self.kernel_size,
strides=1 if self.subpixel else 2,
padding='SAME', dilations=self.dilations )
def forward(self, x):
x = self.conv1(x)
if self.subpixel:
x = nn.space_to_depth(x, 2)
if self.use_activator:
x = nn.gelu(x)
return x
def get_out_ch(self):
return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch
class DownscaleBlock(nn.ModelBase):
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
self.downs = []
last_ch = in_ch
for i in range(n_downscales):
cur_ch = ch*( min(2**i, 8) )
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
last_ch = self.downs[-1].get_out_ch()
def forward(self, inp):
x = inp
for down in self.downs:
x = down(x)
return x
class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
def forward(self, x):
x = self.conv1(x)
x = nn.gelu(x)
x = nn.depth_to_space(x, 2)
return x
class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
def forward(self, inp):
x = self.conv1(inp)
x = nn.gelu(x)
x = self.conv2(x)
x = inp + x
x = nn.gelu(x)
return x
class Encoder(nn.ModelBase):
def on_build(self, in_ch, e_ch):
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5)
def forward(self, inp):
return nn.flatten(self.down1(inp))
lowest_dense_res = resolution // 16
class Inter(nn.ModelBase):
def __init__(self, in_ch, ae_ch, ae_out_ch, d_ch, **kwargs):
self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch = in_ch, ae_ch, ae_out_ch, d_ch
super().__init__(**kwargs)
def on_build(self):
in_ch, ae_ch, ae_out_ch, d_ch = self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch
self.dense1 = nn.Dense( in_ch, ae_ch, kernel_initializer=tf.initializers.orthogonal )
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch, kernel_initializer=tf.initializers.orthogonal )
self.upscale1 = Upscale(ae_out_ch, d_ch*8)
self.res1 = ResidualBlock(d_ch*8)
def forward(self, inp):
x = self.dense1(inp)
x = self.dense2(x)
x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
x = self.upscale1(x)
x = self.res1(x)
return x
def get_out_ch(self):
return self.ae_out_ch
class Decoder(nn.ModelBase):
def on_build(self, in_ch, d_ch):
self.upscale1 = Upscale(in_ch, d_ch*4)
self.res1 = ResidualBlock(d_ch*4)
self.upscale2 = Upscale(d_ch*4, d_ch*2)
self.res2 = ResidualBlock(d_ch*2)
self.upscale3 = Upscale(d_ch*2, d_ch*1)
self.res3 = ResidualBlock(d_ch*1)
self.upscalem1 = Upscale(in_ch, d_ch)
self.upscalem2 = Upscale(d_ch, d_ch//2)
self.upscalem3 = Upscale(d_ch//2, d_ch//2)
self.out_conv = nn.Conv2D( d_ch*1, 3, kernel_size=1, padding='SAME')
self.out_convm = nn.Conv2D( d_ch//2, 1, kernel_size=1, padding='SAME')
def forward(self, inp):
z = inp
x = self.upscale1 (z)
x = self.res1 (x)
x = self.upscale2 (x)
x = self.res2 (x)
x = self.upscale3 (x)
x = self.res3 (x)
y = self.upscalem1 (z)
y = self.upscalem2 (y)
y = self.upscalem3 (y)
return tf.nn.sigmoid(self.out_conv(x)), \
tf.nn.sigmoid(self.out_convm(y))
elif mod == 'uhd':
class Downscale(nn.ModelBase):
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
self.in_ch = in_ch
self.out_ch = out_ch
self.kernel_size = kernel_size
self.dilations = dilations
self.subpixel = subpixel
self.use_activator = use_activator
super().__init__(*kwargs)
def on_build(self, *args, **kwargs ):
self.conv1 = nn.Conv2D( self.in_ch,
self.out_ch // (4 if self.subpixel else 1),
kernel_size=self.kernel_size,
strides=1 if self.subpixel else 2,
padding='SAME', dilations=self.dilations)
def forward(self, x):
x = self.conv1(x)
if self.subpixel:
x = nn.space_to_depth(x, 2)
if self.use_activator:
x = tf.nn.leaky_relu(x, 0.1)
return x
def get_out_ch(self):
return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch
class DownscaleBlock(nn.ModelBase):
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
self.downs = []
last_ch = in_ch
for i in range(n_downscales):
cur_ch = ch*( min(2**i, 8) )
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
last_ch = self.downs[-1].get_out_ch()
def forward(self, inp):
x = inp
for down in self.downs:
x = down(x)
return x
class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
def forward(self, x):
x = self.conv1(x)
x = tf.nn.leaky_relu(x, 0.1)
x = nn.depth_to_space(x, 2)
return x
class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( ch, ch*2, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.Conv2D( ch*2, ch, kernel_size=kernel_size, padding='SAME')
self.scale_add = nn.ScaleAdd(ch)
def forward(self, inp):
x = self.conv1(inp)
x = tf.nn.leaky_relu(x, 0.2)
x = self.conv2(x)
x = tf.nn.leaky_relu(x, 0.2)
x = self.scale_add([inp, x])
return x
class Encoder(nn.ModelBase):
def on_build(self, in_ch, e_ch, **kwargs):
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False)
def forward(self, inp):
x = nn.flatten(self.down1(inp))
return x
lowest_dense_res = resolution // 16
class Inter(nn.ModelBase):
def on_build(self, in_ch, ae_ch, ae_out_ch, **kwargs):
self.ae_out_ch = ae_out_ch
self.dense_norm = nn.DenseNorm()
self.dense1 = nn.Dense( in_ch, ae_ch )
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
def forward(self, inp):
x = self.dense_norm(inp)
x = self.dense1(x)
x = self.dense2(x)
x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
x = self.upscale1(x)
return x
@staticmethod
def get_code_res():
return lowest_dense_res
def get_out_ch(self):
return self.ae_out_ch
class Decoder(nn.ModelBase):
def on_build(self, in_ch, d_ch, d_mask_ch, **kwargs ):
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
def forward(self, inp):
z = inp
def forward(self, z):
x = self.upscale0(z) x = self.upscale0(z)
x = self.res0(x) x = self.res0(x)
x = self.upscale1(x) x = self.upscale1(x)
@ -225,38 +462,12 @@ class DeepFakeArchi(nn.ArchiBase):
x = self.upscale2(x) x = self.upscale2(x)
x = self.res2(x) x = self.res2(x)
if 't' in opts:
x = self.upscale3(x)
x = self.res3(x)
if 'd' in opts:
x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x),
self.out_conv1(x),
self.out_conv2(x),
self.out_conv3(x)), nn.conv2d_ch_axis), 2) )
else:
x = tf.nn.sigmoid(self.out_conv(x))
m = self.upscalem0(z) m = self.upscalem0(z)
m = self.upscalem1(m) m = self.upscalem1(m)
m = self.upscalem2(m) m = self.upscalem2(m)
if 't' in opts: return tf.nn.sigmoid(self.out_conv(x)), \
m = self.upscalem3(m) tf.nn.sigmoid(self.out_convm(m))
if 'd' in opts:
m = self.upscalem4(m)
else:
if 'd' in opts:
m = self.upscalem3(m)
m = tf.nn.sigmoid(self.out_convm(m))
if use_fp16:
x = tf.cast(x, tf.float32)
m = tf.cast(m, tf.float32)
return x, m
self.Encoder = Encoder self.Encoder = Encoder
self.Inter = Inter self.Inter = Inter

View file

@ -1,19 +1,12 @@
import sys import sys
import ctypes import ctypes
import os import os
import multiprocessing
import json
import time
from pathlib import Path
from core.interact import interact as io
class Device(object): class Device(object):
def __init__(self, index, tf_dev_type, name, total_mem, free_mem): def __init__(self, index, name, total_mem, free_mem, cc=0):
self.index = index self.index = index
self.tf_dev_type = tf_dev_type
self.name = name self.name = name
self.cc = cc
self.total_mem = total_mem self.total_mem = total_mem
self.total_mem_gb = total_mem / 1024**3 self.total_mem_gb = total_mem / 1024**3
self.free_mem = free_mem self.free_mem = free_mem
@ -89,135 +82,10 @@ class Devices(object):
result.append (device) result.append (device)
return Devices(result) return Devices(result)
@staticmethod
def _get_tf_devices_proc(q : multiprocessing.Queue):
if sys.platform[0:3] == 'win':
compute_cache_path = Path(os.environ['APPDATA']) / 'NVIDIA' / ('ComputeCache_ALL')
os.environ['CUDA_CACHE_PATH'] = str(compute_cache_path)
if not compute_cache_path.exists():
io.log_info("Caching GPU kernels...")
compute_cache_path.mkdir(parents=True, exist_ok=True)
import tensorflow
tf_version = tensorflow.version.VERSION
#if tf_version is None:
# tf_version = tensorflow.version.GIT_VERSION
if tf_version[0] == 'v':
tf_version = tf_version[1:]
if tf_version[0] == '2':
tf = tensorflow.compat.v1
else:
tf = tensorflow
import logging
# Disable tensorflow warnings
tf_logger = logging.getLogger('tensorflow')
tf_logger.setLevel(logging.ERROR)
from tensorflow.python.client import device_lib
devices = []
physical_devices = device_lib.list_local_devices()
physical_devices_f = {}
for dev in physical_devices:
dev_type = dev.device_type
dev_tf_name = dev.name
dev_tf_name = dev_tf_name[ dev_tf_name.index(dev_type) : ]
dev_idx = int(dev_tf_name.split(':')[-1])
if dev_type in ['GPU','DML']:
dev_name = dev_tf_name
dev_desc = dev.physical_device_desc
if len(dev_desc) != 0:
if dev_desc[0] == '{':
dev_desc_json = json.loads(dev_desc)
dev_desc_json_name = dev_desc_json.get('name',None)
if dev_desc_json_name is not None:
dev_name = dev_desc_json_name
else:
for param, value in ( v.split(':') for v in dev_desc.split(',') ):
param = param.strip()
value = value.strip()
if param == 'name':
dev_name = value
break
physical_devices_f[dev_idx] = (dev_type, dev_name, dev.memory_limit)
q.put(physical_devices_f)
time.sleep(0.1)
@staticmethod @staticmethod
def initialize_main_env(): def initialize_main_env():
if int(os.environ.get("NN_DEVICES_INITIALIZED", 0)) != 0:
return
if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
os.environ.pop('CUDA_VISIBLE_DEVICES')
os.environ['TF_DIRECTML_KERNEL_CACHE_SIZE'] = '2500'
os.environ['CUDA_CACHE_MAXSIZE'] = '2147483647'
os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # tf log errors only
q = multiprocessing.Queue()
p = multiprocessing.Process(target=Devices._get_tf_devices_proc, args=(q,), daemon=True)
p.start()
p.join()
visible_devices = q.get()
os.environ['NN_DEVICES_INITIALIZED'] = '1' os.environ['NN_DEVICES_INITIALIZED'] = '1'
os.environ['NN_DEVICES_COUNT'] = str(len(visible_devices)) os.environ['NN_DEVICES_COUNT'] = '0'
for i in visible_devices:
dev_type, name, total_mem = visible_devices[i]
os.environ[f'NN_DEVICE_{i}_TF_DEV_TYPE'] = dev_type
os.environ[f'NN_DEVICE_{i}_NAME'] = name
os.environ[f'NN_DEVICE_{i}_TOTAL_MEM'] = str(total_mem)
os.environ[f'NN_DEVICE_{i}_FREE_MEM'] = str(total_mem)
@staticmethod
def getDevices():
if Devices.all_devices is None:
if int(os.environ.get("NN_DEVICES_INITIALIZED", 0)) != 1:
raise Exception("nn devices are not initialized. Run initialize_main_env() in main process.")
devices = []
for i in range ( int(os.environ['NN_DEVICES_COUNT']) ):
devices.append ( Device(index=i,
tf_dev_type=os.environ[f'NN_DEVICE_{i}_TF_DEV_TYPE'],
name=os.environ[f'NN_DEVICE_{i}_NAME'],
total_mem=int(os.environ[f'NN_DEVICE_{i}_TOTAL_MEM']),
free_mem=int(os.environ[f'NN_DEVICE_{i}_FREE_MEM']), )
)
Devices.all_devices = Devices(devices)
return Devices.all_devices
"""
# {'name' : name.split(b'\0', 1)[0].decode(),
# 'total_mem' : totalMem.value
# }
return
min_cc = int(os.environ.get("TF_MIN_REQ_CAP", 35)) min_cc = int(os.environ.get("TF_MIN_REQ_CAP", 35))
libnames = ('libcuda.so', 'libcuda.dylib', 'nvcuda.dll') libnames = ('libcuda.so', 'libcuda.dylib', 'nvcuda.dll')
@ -270,4 +138,70 @@ class Devices(object):
os.environ[f'NN_DEVICE_{i}_TOTAL_MEM'] = str(device['total_mem']) os.environ[f'NN_DEVICE_{i}_TOTAL_MEM'] = str(device['total_mem'])
os.environ[f'NN_DEVICE_{i}_FREE_MEM'] = str(device['free_mem']) os.environ[f'NN_DEVICE_{i}_FREE_MEM'] = str(device['free_mem'])
os.environ[f'NN_DEVICE_{i}_CC'] = str(device['cc']) os.environ[f'NN_DEVICE_{i}_CC'] = str(device['cc'])
@staticmethod
def getDevices():
if Devices.all_devices is None:
if int(os.environ.get("NN_DEVICES_INITIALIZED", 0)) != 1:
raise Exception("nn devices are not initialized. Run initialize_main_env() in main process.")
devices = []
for i in range ( int(os.environ['NN_DEVICES_COUNT']) ):
devices.append ( Device(index=i,
name=os.environ[f'NN_DEVICE_{i}_NAME'],
total_mem=int(os.environ[f'NN_DEVICE_{i}_TOTAL_MEM']),
free_mem=int(os.environ[f'NN_DEVICE_{i}_FREE_MEM']),
cc=int(os.environ[f'NN_DEVICE_{i}_CC']) ))
Devices.all_devices = Devices(devices)
return Devices.all_devices
"""
if Devices.all_devices is None:
min_cc = int(os.environ.get("TF_MIN_REQ_CAP", 35))
libnames = ('libcuda.so', 'libcuda.dylib', 'nvcuda.dll')
for libname in libnames:
try:
cuda = ctypes.CDLL(libname)
except:
continue
else:
break
else:
return Devices([])
nGpus = ctypes.c_int()
name = b' ' * 200
cc_major = ctypes.c_int()
cc_minor = ctypes.c_int()
freeMem = ctypes.c_size_t()
totalMem = ctypes.c_size_t()
result = ctypes.c_int()
device = ctypes.c_int()
context = ctypes.c_void_p()
error_str = ctypes.c_char_p()
devices = []
if cuda.cuInit(0) == 0 and \
cuda.cuDeviceGetCount(ctypes.byref(nGpus)) == 0:
for i in range(nGpus.value):
if cuda.cuDeviceGet(ctypes.byref(device), i) != 0 or \
cuda.cuDeviceGetName(ctypes.c_char_p(name), len(name), device) != 0 or \
cuda.cuDeviceComputeCapability(ctypes.byref(cc_major), ctypes.byref(cc_minor), device) != 0:
continue
if cuda.cuCtxCreate_v2(ctypes.byref(context), 0, device) == 0:
if cuda.cuMemGetInfo_v2(ctypes.byref(freeMem), ctypes.byref(totalMem)) == 0:
cc = cc_major.value * 10 + cc_minor.value
if cc >= min_cc:
devices.append ( Device(index=i,
name=name.split(b'\0', 1)[0].decode(),
total_mem=totalMem.value,
free_mem=freeMem.value,
cc=cc) )
cuda.cuCtxDetach(context)
Devices.all_devices = Devices(devices)
return Devices.all_devices
""" """

View file

@ -23,13 +23,28 @@ class Conv2D(nn.LayerBase):
if padding == "SAME": if padding == "SAME":
padding = ( (kernel_size - 1) * dilations + 1 ) // 2 padding = ( (kernel_size - 1) * dilations + 1 ) // 2
elif padding == "VALID": elif padding == "VALID":
padding = None padding = 0
else: else:
raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs") raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs")
if isinstance(padding, int):
if padding != 0:
if nn.data_format == "NHWC":
padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
else:
padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
else:
padding = None
if nn.data_format == "NHWC":
strides = [1,strides,strides,1]
else: else:
padding = int(padding) strides = [1,1,strides,strides]
if nn.data_format == "NHWC":
dilations = [1,dilations,dilations,1]
else:
dilations = [1,1,dilations,dilations]
self.in_ch = in_ch self.in_ch = in_ch
self.out_ch = out_ch self.out_ch = out_ch
@ -55,8 +70,8 @@ class Conv2D(nn.LayerBase):
if kernel_initializer is None: if kernel_initializer is None:
kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype) kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)
#if kernel_initializer is None: if kernel_initializer is None:
# kernel_initializer = nn.initializers.ca() kernel_initializer = nn.initializers.ca()
self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.in_ch,self.out_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable ) self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.in_ch,self.out_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )
@ -78,27 +93,10 @@ class Conv2D(nn.LayerBase):
if self.use_wscale: if self.use_wscale:
weight = weight * self.wscale weight = weight * self.wscale
padding = self.padding if self.padding is not None:
if padding is not None: x = tf.pad (x, self.padding, mode='CONSTANT')
if nn.data_format == "NHWC":
padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
else:
padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
x = tf.pad (x, padding, mode='CONSTANT')
strides = self.strides x = tf.nn.conv2d(x, weight, self.strides, 'VALID', dilations=self.dilations, data_format=nn.data_format)
if nn.data_format == "NHWC":
strides = [1,strides,strides,1]
else:
strides = [1,1,strides,strides]
dilations = self.dilations
if nn.data_format == "NHWC":
dilations = [1,dilations,dilations,1]
else:
dilations = [1,1,dilations,dilations]
x = tf.nn.conv2d(x, weight, strides, 'VALID', dilations=dilations, data_format=nn.data_format)
if self.use_bias: if self.use_bias:
if nn.data_format == "NHWC": if nn.data_format == "NHWC":
bias = tf.reshape (self.bias, (1,1,1,self.out_ch) ) bias = tf.reshape (self.bias, (1,1,1,self.out_ch) )

View file

@ -38,8 +38,8 @@ class Conv2DTranspose(nn.LayerBase):
if kernel_initializer is None: if kernel_initializer is None:
kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype) kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)
#if kernel_initializer is None: if kernel_initializer is None:
# kernel_initializer = nn.initializers.ca() kernel_initializer = nn.initializers.ca()
self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.out_ch,self.in_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable ) self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.out_ch,self.in_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )
if self.use_bias: if self.use_bias:

View file

@ -1,110 +0,0 @@
import numpy as np
from core.leras import nn
tf = nn.tf
class DepthwiseConv2D(nn.LayerBase):
"""
default kernel_initializer - CA
use_wscale bool enables equalized learning rate, if kernel_initializer is None, it will be forced to random_normal
"""
def __init__(self, in_ch, kernel_size, strides=1, padding='SAME', depth_multiplier=1, dilations=1, use_bias=True, use_wscale=False, kernel_initializer=None, bias_initializer=None, trainable=True, dtype=None, **kwargs ):
if not isinstance(strides, int):
raise ValueError ("strides must be an int type")
if not isinstance(dilations, int):
raise ValueError ("dilations must be an int type")
kernel_size = int(kernel_size)
if dtype is None:
dtype = nn.floatx
if isinstance(padding, str):
if padding == "SAME":
padding = ( (kernel_size - 1) * dilations + 1 ) // 2
elif padding == "VALID":
padding = 0
else:
raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs")
if isinstance(padding, int):
if padding != 0:
if nn.data_format == "NHWC":
padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
else:
padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
else:
padding = None
if nn.data_format == "NHWC":
strides = [1,strides,strides,1]
else:
strides = [1,1,strides,strides]
if nn.data_format == "NHWC":
dilations = [1,dilations,dilations,1]
else:
dilations = [1,1,dilations,dilations]
self.in_ch = in_ch
self.depth_multiplier = depth_multiplier
self.kernel_size = kernel_size
self.strides = strides
self.padding = padding
self.dilations = dilations
self.use_bias = use_bias
self.use_wscale = use_wscale
self.kernel_initializer = kernel_initializer
self.bias_initializer = bias_initializer
self.trainable = trainable
self.dtype = dtype
super().__init__(**kwargs)
def build_weights(self):
kernel_initializer = self.kernel_initializer
if self.use_wscale:
gain = 1.0 if self.kernel_size == 1 else np.sqrt(2)
fan_in = self.kernel_size*self.kernel_size*self.in_ch
he_std = gain / np.sqrt(fan_in)
self.wscale = tf.constant(he_std, dtype=self.dtype )
if kernel_initializer is None:
kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)
#if kernel_initializer is None:
# kernel_initializer = nn.initializers.ca()
self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.in_ch,self.depth_multiplier), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )
if self.use_bias:
bias_initializer = self.bias_initializer
if bias_initializer is None:
bias_initializer = tf.initializers.zeros(dtype=self.dtype)
self.bias = tf.get_variable("bias", (self.in_ch*self.depth_multiplier,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable )
def get_weights(self):
weights = [self.weight]
if self.use_bias:
weights += [self.bias]
return weights
def forward(self, x):
weight = self.weight
if self.use_wscale:
weight = weight * self.wscale
if self.padding is not None:
x = tf.pad (x, self.padding, mode='CONSTANT')
x = tf.nn.depthwise_conv2d(x, weight, self.strides, 'VALID', data_format=nn.data_format)
if self.use_bias:
if nn.data_format == "NHWC":
bias = tf.reshape (self.bias, (1,1,1,self.in_ch*self.depth_multiplier) )
else:
bias = tf.reshape (self.bias, (1,self.in_ch*self.depth_multiplier,1,1) )
x = tf.add(x, bias)
return x
def __str__(self):
r = f"{self.__class__.__name__} : in_ch:{self.in_ch} depth_multiplier:{self.depth_multiplier} "
return r
nn.DepthwiseConv2D = DepthwiseConv2D

View file

@ -46,9 +46,7 @@ class Saveable():
raise Exception("name must be defined.") raise Exception("name must be defined.")
name = self.name name = self.name
for w, w_val in zip(weights, nn.tf_sess.run (weights)):
for w in weights:
w_val = nn.tf_sess.run (w).copy()
w_name_split = w.name.split('/', 1) w_name_split = w.name.split('/', 1)
if name != w_name_split[0]: if name != w_name_split[0]:
raise Exception("weight first name != Saveable.name") raise Exception("weight first name != Saveable.name")
@ -78,27 +76,24 @@ class Saveable():
if self.name is None: if self.name is None:
raise Exception("name must be defined.") raise Exception("name must be defined.")
try: tuples = []
tuples = [] for w in weights:
for w in weights: w_name_split = w.name.split('/')
w_name_split = w.name.split('/') if self.name != w_name_split[0]:
if self.name != w_name_split[0]: raise Exception("weight first name != Saveable.name")
raise Exception("weight first name != Saveable.name")
sub_w_name = "/".join(w_name_split[1:]) sub_w_name = "/".join(w_name_split[1:])
w_val = d.get(sub_w_name, None) w_val = d.get(sub_w_name, None)
if w_val is None: if w_val is None:
#io.log_err(f"Weight {w.name} was not loaded from file {filename}") #io.log_err(f"Weight {w.name} was not loaded from file {filename}")
tuples.append ( (w, w.initializer) ) tuples.append ( (w, w.initializer) )
else: else:
w_val = np.reshape( w_val, w.shape.as_list() ) w_val = np.reshape( w_val, w.shape.as_list() )
tuples.append ( (w, w_val) ) tuples.append ( (w, w_val) )
nn.batch_set_value(tuples) nn.batch_set_value(tuples)
except:
return False
return True return True

View file

@ -1,104 +0,0 @@
import numpy as np
from core.leras import nn
tf = nn.tf
class TanhPolar(nn.LayerBase):
"""
RoI Tanh-polar Transformer Network for Face Parsing in the Wild
https://github.com/hhj1897/roi_tanh_warping
"""
def __init__(self, width, height, angular_offset_deg=270, **kwargs):
self.width = width
self.height = height
warp_gridx, warp_gridy = TanhPolar._get_tanh_polar_warp_grids(width,height,angular_offset_deg=angular_offset_deg)
restore_gridx, restore_gridy = TanhPolar._get_tanh_polar_restore_grids(width,height,angular_offset_deg=angular_offset_deg)
self.warp_gridx_t = tf.constant(warp_gridx[None, ...])
self.warp_gridy_t = tf.constant(warp_gridy[None, ...])
self.restore_gridx_t = tf.constant(restore_gridx[None, ...])
self.restore_gridy_t = tf.constant(restore_gridy[None, ...])
super().__init__(**kwargs)
def warp(self, inp_t):
batch_t = tf.shape(inp_t)[0]
warp_gridx_t = tf.tile(self.warp_gridx_t, (batch_t,1,1) )
warp_gridy_t = tf.tile(self.warp_gridy_t, (batch_t,1,1) )
if nn.data_format == "NCHW":
inp_t = tf.transpose(inp_t,(0,2,3,1))
out_t = nn.bilinear_sampler(inp_t, warp_gridx_t, warp_gridy_t)
if nn.data_format == "NCHW":
out_t = tf.transpose(out_t,(0,3,1,2))
return out_t
def restore(self, inp_t):
batch_t = tf.shape(inp_t)[0]
restore_gridx_t = tf.tile(self.restore_gridx_t, (batch_t,1,1) )
restore_gridy_t = tf.tile(self.restore_gridy_t, (batch_t,1,1) )
if nn.data_format == "NCHW":
inp_t = tf.transpose(inp_t,(0,2,3,1))
inp_t = tf.pad(inp_t, [(0,0), (1, 1), (1, 0), (0, 0)], "SYMMETRIC")
out_t = nn.bilinear_sampler(inp_t, restore_gridx_t, restore_gridy_t)
if nn.data_format == "NCHW":
out_t = tf.transpose(out_t,(0,3,1,2))
return out_t
@staticmethod
def _get_tanh_polar_warp_grids(W,H,angular_offset_deg):
angular_offset_pi = angular_offset_deg * np.pi / 180.0
roi_center = np.array([ W//2, H//2], np.float32 )
roi_radii = np.array([W, H], np.float32 ) / np.pi ** 0.5
cos_offset, sin_offset = np.cos(angular_offset_pi), np.sin(angular_offset_pi)
normalised_dest_indices = np.stack(np.meshgrid(np.arange(0.0, 1.0, 1.0 / W),np.arange(0.0, 2.0 * np.pi, 2.0 * np.pi / H)), axis=-1)
radii = normalised_dest_indices[..., 0]
orientation_x = np.cos(normalised_dest_indices[..., 1])
orientation_y = np.sin(normalised_dest_indices[..., 1])
src_radii = np.arctanh(radii) * (roi_radii[0] * roi_radii[1] / np.sqrt(roi_radii[1] ** 2 * orientation_x ** 2 + roi_radii[0] ** 2 * orientation_y ** 2))
src_x_indices = src_radii * orientation_x
src_y_indices = src_radii * orientation_y
src_x_indices, src_y_indices = (roi_center[0] + cos_offset * src_x_indices - sin_offset * src_y_indices,
roi_center[1] + cos_offset * src_y_indices + sin_offset * src_x_indices)
return src_x_indices.astype(np.float32), src_y_indices.astype(np.float32)
@staticmethod
def _get_tanh_polar_restore_grids(W,H,angular_offset_deg):
angular_offset_pi = angular_offset_deg * np.pi / 180.0
roi_center = np.array([ W//2, H//2], np.float32 )
roi_radii = np.array([W, H], np.float32 ) / np.pi ** 0.5
cos_offset, sin_offset = np.cos(angular_offset_pi), np.sin(angular_offset_pi)
dest_indices = np.stack(np.meshgrid(np.arange(W), np.arange(H)), axis=-1).astype(float)
normalised_dest_indices = np.matmul(dest_indices - roi_center, np.array([[cos_offset, -sin_offset],
[sin_offset, cos_offset]]))
radii = np.linalg.norm(normalised_dest_indices, axis=-1)
normalised_dest_indices[..., 0] /= np.clip(radii, 1e-9, None)
normalised_dest_indices[..., 1] /= np.clip(radii, 1e-9, None)
radii *= np.sqrt(roi_radii[1] ** 2 * normalised_dest_indices[..., 0] ** 2 +
roi_radii[0] ** 2 * normalised_dest_indices[..., 1] ** 2) / roi_radii[0] / roi_radii[1]
src_radii = np.tanh(radii)
src_x_indices = src_radii * W + 1.0
src_y_indices = np.mod((np.arctan2(normalised_dest_indices[..., 1], normalised_dest_indices[..., 0]) /
2.0 / np.pi) * H, H) + 1.0
return src_x_indices.astype(np.float32), src_y_indices.astype(np.float32)
nn.TanhPolar = TanhPolar

View file

@ -3,16 +3,12 @@ from .LayerBase import *
from .Conv2D import * from .Conv2D import *
from .Conv2DTranspose import * from .Conv2DTranspose import *
from .DepthwiseConv2D import *
from .Dense import * from .Dense import *
from .BlurPool import * from .BlurPool import *
from .BatchNorm2D import * from .BatchNorm2D import *
from .InstanceNorm2D import *
from .FRNorm2D import * from .FRNorm2D import *
from .TLU import * from .TLU import *
from .ScaleAdd import * from .ScaleAdd import *
from .DenseNorm import * from .DenseNorm import *
from .AdaIN import *
from .TanhPolar import *

View file

@ -18,10 +18,6 @@ class ModelBase(nn.Saveable):
if isinstance (layer, list): if isinstance (layer, list):
for i,sublayer in enumerate(layer): for i,sublayer in enumerate(layer):
self._build_sub(sublayer, f"{name}_{i}") self._build_sub(sublayer, f"{name}_{i}")
elif isinstance (layer, dict):
for subname in layer.keys():
sublayer = layer[subname]
self._build_sub(sublayer, f"{name}_{subname}")
elif isinstance (layer, nn.LayerBase) or \ elif isinstance (layer, nn.LayerBase) or \
isinstance (layer, ModelBase): isinstance (layer, ModelBase):
@ -116,32 +112,41 @@ class ModelBase(nn.Saveable):
return self.forward(*args, **kwargs) return self.forward(*args, **kwargs)
# def compute_output_shape(self, shapes): def compute_output_shape(self, shapes):
# if not self.built: if not self.built:
# self.build() self.build()
# not_list = False not_list = False
# if not isinstance(shapes, list): if not isinstance(shapes, list):
# not_list = True not_list = True
# shapes = [shapes] shapes = [shapes]
# with tf.device('/CPU:0'): with tf.device('/CPU:0'):
# # CPU tensors will not impact any performance, only slightly RAM "leakage" # CPU tensors will not impact any performance, only slightly RAM "leakage"
# phs = [] phs = []
# for dtype,sh in shapes: for dtype,sh in shapes:
# phs += [ tf.placeholder(dtype, sh) ] phs += [ tf.placeholder(dtype, sh) ]
# result = self.__call__(phs[0] if not_list else phs) result = self.__call__(phs[0] if not_list else phs)
# if not isinstance(result, list): if not isinstance(result, list):
# result = [result] result = [result]
# result_shapes = [] result_shapes = []
# for t in result: for t in result:
# result_shapes += [ t.shape.as_list() ] result_shapes += [ t.shape.as_list() ]
# return result_shapes[0] if not_list else result_shapes return result_shapes[0] if not_list else result_shapes
def compute_output_channels(self, shapes):
shape = self.compute_output_shape(shapes)
shape_len = len(shape)
if shape_len == 4:
if nn.data_format == "NCHW":
return shape[1]
return shape[-1]
def build_for_run(self, shapes_list): def build_for_run(self, shapes_list):
if not isinstance(shapes_list, list): if not isinstance(shapes_list, list):

View file

@ -1,7 +1,7 @@
import numpy as np
from core.leras import nn from core.leras import nn
tf = nn.tf tf = nn.tf
patch_discriminator_kernels = \ patch_discriminator_kernels = \
{ 1 : (512, [ [1,1] ]), { 1 : (512, [ [1,1] ]),
2 : (512, [ [2,1] ]), 2 : (512, [ [2,1] ]),
@ -41,14 +41,6 @@ patch_discriminator_kernels = \
36 : (256, [ [4,2], [3,2], [4,2], [3,1] ]), 36 : (256, [ [4,2], [3,2], [4,2], [3,1] ]),
37 : (256, [ [3,2], [4,2], [4,2], [3,1] ]), 37 : (256, [ [3,2], [4,2], [4,2], [3,1] ]),
38 : (256, [ [4,2], [4,2], [4,2], [3,1] ]), 38 : (256, [ [4,2], [4,2], [4,2], [3,1] ]),
39 : (256, [ [3,2], [3,2], [3,2], [4,1] ]),
40 : (256, [ [4,2], [3,2], [3,2], [4,1] ]),
41 : (256, [ [3,2], [4,2], [3,2], [4,1] ]),
42 : (256, [ [4,2], [4,2], [3,2], [4,1] ]),
43 : (256, [ [3,2], [4,2], [4,2], [4,1] ]),
44 : (256, [ [4,2], [3,2], [4,2], [4,1] ]),
45 : (256, [ [3,2], [4,2], [4,2], [4,1] ]),
46 : (256, [ [4,2], [4,2], [4,2], [4,1] ]),
} }
@ -75,120 +67,3 @@ class PatchDiscriminator(nn.ModelBase):
return self.out_conv(x) return self.out_conv(x)
nn.PatchDiscriminator = PatchDiscriminator nn.PatchDiscriminator = PatchDiscriminator
class UNetPatchDiscriminator(nn.ModelBase):
"""
Inspired by https://arxiv.org/abs/2002.12655 "A U-Net Based Discriminator for Generative Adversarial Networks"
"""
def calc_receptive_field_size(self, layers):
"""
result the same as https://fomoro.com/research/article/receptive-field-calculatorindex.html
"""
rf = 0
ts = 1
for i, (k, s) in enumerate(layers):
if i == 0:
rf = k
else:
rf += (k-1)*ts
ts *= s
return rf
def find_archi(self, target_patch_size, max_layers=9):
"""
Find the best configuration of layers using only 3x3 convs for target patch size
"""
s = {}
for layers_count in range(1,max_layers+1):
val = 1 << (layers_count-1)
while True:
val -= 1
layers = []
sum_st = 0
layers.append ( [3, 2])
sum_st += 2
for i in range(layers_count-1):
st = 1 + (1 if val & (1 << i) !=0 else 0 )
layers.append ( [3, st ])
sum_st += st
rf = self.calc_receptive_field_size(layers)
s_rf = s.get(rf, None)
if s_rf is None:
s[rf] = (layers_count, sum_st, layers)
else:
if layers_count < s_rf[0] or \
( layers_count == s_rf[0] and sum_st > s_rf[1] ):
s[rf] = (layers_count, sum_st, layers)
if val == 0:
break
x = sorted(list(s.keys()))
q=x[np.abs(np.array(x)-target_patch_size).argmin()]
return s[q][2]
def on_build(self, patch_size, in_ch, base_ch = 16, use_fp16 = False):
self.use_fp16 = use_fp16
conv_dtype = tf.float16 if use_fp16 else tf.float32
class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
def forward(self, inp):
x = self.conv1(inp)
x = tf.nn.leaky_relu(x, 0.2)
x = self.conv2(x)
x = tf.nn.leaky_relu(inp + x, 0.2)
return x
prev_ch = in_ch
self.convs = []
self.upconvs = []
layers = self.find_archi(patch_size)
level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) }
self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID', dtype=conv_dtype)
for i, (kernel_size, strides) in enumerate(layers):
self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )
self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )
self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
self.center_conv = nn.Conv2D( level_chs[len(layers)-1], level_chs[len(layers)-1], kernel_size=1, padding='VALID', dtype=conv_dtype)
def forward(self, x):
if self.use_fp16:
x = tf.cast(x, tf.float16)
x = tf.nn.leaky_relu( self.in_conv(x), 0.2 )
encs = []
for conv in self.convs:
encs.insert(0, x)
x = tf.nn.leaky_relu( conv(x), 0.2 )
center_out, x = self.center_out(x), tf.nn.leaky_relu( self.center_conv(x), 0.2 )
for i, (upconv, enc) in enumerate(zip(self.upconvs, encs)):
x = tf.nn.leaky_relu( upconv(x), 0.2 )
x = tf.concat( [enc, x], axis=nn.conv2d_ch_axis)
x = self.out_conv(x)
if self.use_fp16:
center_out = tf.cast(center_out, tf.float32)
x = tf.cast(x, tf.float32)
return center_out, x
nn.UNetPatchDiscriminator = UNetPatchDiscriminator

View file

@ -29,11 +29,10 @@ class XSeg(nn.ModelBase):
x = self.tlu(x) x = self.tlu(x)
return x return x
self.base_ch = base_ch
self.conv01 = ConvBlock(in_ch, base_ch) self.conv01 = ConvBlock(in_ch, base_ch)
self.conv02 = ConvBlock(base_ch, base_ch) self.conv02 = ConvBlock(base_ch, base_ch)
self.bp0 = nn.BlurPool (filt_size=4) self.bp0 = nn.BlurPool (filt_size=3)
self.conv11 = ConvBlock(base_ch, base_ch*2) self.conv11 = ConvBlock(base_ch, base_ch*2)
self.conv12 = ConvBlock(base_ch*2, base_ch*2) self.conv12 = ConvBlock(base_ch*2, base_ch*2)
@ -41,30 +40,19 @@ class XSeg(nn.ModelBase):
self.conv21 = ConvBlock(base_ch*2, base_ch*4) self.conv21 = ConvBlock(base_ch*2, base_ch*4)
self.conv22 = ConvBlock(base_ch*4, base_ch*4) self.conv22 = ConvBlock(base_ch*4, base_ch*4)
self.bp2 = nn.BlurPool (filt_size=2) self.conv23 = ConvBlock(base_ch*4, base_ch*4)
self.bp2 = nn.BlurPool (filt_size=3)
self.conv31 = ConvBlock(base_ch*4, base_ch*8) self.conv31 = ConvBlock(base_ch*4, base_ch*8)
self.conv32 = ConvBlock(base_ch*8, base_ch*8) self.conv32 = ConvBlock(base_ch*8, base_ch*8)
self.conv33 = ConvBlock(base_ch*8, base_ch*8) self.conv33 = ConvBlock(base_ch*8, base_ch*8)
self.bp3 = nn.BlurPool (filt_size=2) self.bp3 = nn.BlurPool (filt_size=3)
self.conv41 = ConvBlock(base_ch*8, base_ch*8) self.conv41 = ConvBlock(base_ch*8, base_ch*8)
self.conv42 = ConvBlock(base_ch*8, base_ch*8) self.conv42 = ConvBlock(base_ch*8, base_ch*8)
self.conv43 = ConvBlock(base_ch*8, base_ch*8) self.conv43 = ConvBlock(base_ch*8, base_ch*8)
self.bp4 = nn.BlurPool (filt_size=2) self.bp4 = nn.BlurPool (filt_size=3)
self.conv51 = ConvBlock(base_ch*8, base_ch*8)
self.conv52 = ConvBlock(base_ch*8, base_ch*8)
self.conv53 = ConvBlock(base_ch*8, base_ch*8)
self.bp5 = nn.BlurPool (filt_size=2)
self.dense1 = nn.Dense ( 4*4* base_ch*8, 512)
self.dense2 = nn.Dense ( 512, 4*4* base_ch*8)
self.up5 = UpConvBlock (base_ch*8, base_ch*4)
self.uconv53 = ConvBlock(base_ch*12, base_ch*8)
self.uconv52 = ConvBlock(base_ch*8, base_ch*8)
self.uconv51 = ConvBlock(base_ch*8, base_ch*8)
self.up4 = UpConvBlock (base_ch*8, base_ch*4) self.up4 = UpConvBlock (base_ch*8, base_ch*4)
self.uconv43 = ConvBlock(base_ch*12, base_ch*8) self.uconv43 = ConvBlock(base_ch*12, base_ch*8)
@ -77,7 +65,8 @@ class XSeg(nn.ModelBase):
self.uconv31 = ConvBlock(base_ch*8, base_ch*8) self.uconv31 = ConvBlock(base_ch*8, base_ch*8)
self.up2 = UpConvBlock (base_ch*8, base_ch*4) self.up2 = UpConvBlock (base_ch*8, base_ch*4)
self.uconv22 = ConvBlock(base_ch*8, base_ch*4) self.uconv23 = ConvBlock(base_ch*8, base_ch*4)
self.uconv22 = ConvBlock(base_ch*4, base_ch*4)
self.uconv21 = ConvBlock(base_ch*4, base_ch*4) self.uconv21 = ConvBlock(base_ch*4, base_ch*4)
self.up1 = UpConvBlock (base_ch*4, base_ch*2) self.up1 = UpConvBlock (base_ch*4, base_ch*2)
@ -89,8 +78,9 @@ class XSeg(nn.ModelBase):
self.uconv01 = ConvBlock(base_ch, base_ch) self.uconv01 = ConvBlock(base_ch, base_ch)
self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME') self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME')
self.conv_center = ConvBlock(base_ch*8, base_ch*8)
def forward(self, inp, pretrain=False): def forward(self, inp):
x = inp x = inp
x = self.conv01(x) x = self.conv01(x)
@ -102,7 +92,8 @@ class XSeg(nn.ModelBase):
x = self.bp1(x) x = self.bp1(x)
x = self.conv21(x) x = self.conv21(x)
x = x2 = self.conv22(x) x = self.conv22(x)
x = x2 = self.conv23(x)
x = self.bp2(x) x = self.bp2(x)
x = self.conv31(x) x = self.conv31(x)
@ -115,52 +106,28 @@ class XSeg(nn.ModelBase):
x = x4 = self.conv43(x) x = x4 = self.conv43(x)
x = self.bp4(x) x = self.bp4(x)
x = self.conv51(x) x = self.conv_center(x)
x = self.conv52(x)
x = x5 = self.conv53(x)
x = self.bp5(x)
x = nn.flatten(x)
x = self.dense1(x)
x = self.dense2(x)
x = nn.reshape_4D (x, 4, 4, self.base_ch*8 )
x = self.up5(x)
if pretrain:
x5 = tf.zeros_like(x5)
x = self.uconv53(tf.concat([x,x5],axis=nn.conv2d_ch_axis))
x = self.uconv52(x)
x = self.uconv51(x)
x = self.up4(x) x = self.up4(x)
if pretrain:
x4 = tf.zeros_like(x4)
x = self.uconv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis)) x = self.uconv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis))
x = self.uconv42(x) x = self.uconv42(x)
x = self.uconv41(x) x = self.uconv41(x)
x = self.up3(x) x = self.up3(x)
if pretrain:
x3 = tf.zeros_like(x3)
x = self.uconv33(tf.concat([x,x3],axis=nn.conv2d_ch_axis)) x = self.uconv33(tf.concat([x,x3],axis=nn.conv2d_ch_axis))
x = self.uconv32(x) x = self.uconv32(x)
x = self.uconv31(x) x = self.uconv31(x)
x = self.up2(x) x = self.up2(x)
if pretrain: x = self.uconv23(tf.concat([x,x2],axis=nn.conv2d_ch_axis))
x2 = tf.zeros_like(x2) x = self.uconv22(x)
x = self.uconv22(tf.concat([x,x2],axis=nn.conv2d_ch_axis))
x = self.uconv21(x) x = self.uconv21(x)
x = self.up1(x) x = self.up1(x)
if pretrain:
x1 = tf.zeros_like(x1)
x = self.uconv12(tf.concat([x,x1],axis=nn.conv2d_ch_axis)) x = self.uconv12(tf.concat([x,x1],axis=nn.conv2d_ch_axis))
x = self.uconv11(x) x = self.uconv11(x)
x = self.up0(x) x = self.up0(x)
if pretrain:
x0 = tf.zeros_like(x0)
x = self.uconv02(tf.concat([x,x0],axis=nn.conv2d_ch_axis)) x = self.uconv02(tf.concat([x,x0],axis=nn.conv2d_ch_axis))
x = self.uconv01(x) x = self.uconv01(x)

View file

@ -33,7 +33,7 @@ class nn():
tf = None tf = None
tf_sess = None tf_sess = None
tf_sess_config = None tf_sess_config = None
tf_default_device_name = None tf_default_device = None
data_format = None data_format = None
conv2d_ch_axis = None conv2d_ch_axis = None
@ -51,6 +51,9 @@ class nn():
# Manipulate environment variables before import tensorflow # Manipulate environment variables before import tensorflow
if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
os.environ.pop('CUDA_VISIBLE_DEVICES')
first_run = False first_run = False
if len(device_config.devices) != 0: if len(device_config.devices) != 0:
if sys.platform[0:3] == 'win': if sys.platform[0:3] == 'win':
@ -65,32 +68,21 @@ class nn():
compute_cache_path = Path(os.environ['APPDATA']) / 'NVIDIA' / ('ComputeCache' + devices_str) compute_cache_path = Path(os.environ['APPDATA']) / 'NVIDIA' / ('ComputeCache' + devices_str)
if not compute_cache_path.exists(): if not compute_cache_path.exists():
first_run = True first_run = True
compute_cache_path.mkdir(parents=True, exist_ok=True)
os.environ['CUDA_CACHE_PATH'] = str(compute_cache_path) os.environ['CUDA_CACHE_PATH'] = str(compute_cache_path)
os.environ['CUDA_CACHE_MAXSIZE'] = '536870912' #512Mb (32mb default)
os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # tf log errors only
if first_run: if first_run:
io.log_info("Caching GPU kernels...") io.log_info("Caching GPU kernels...")
import tensorflow import tensorflow as tf
nn.tf = tf
tf_version = tensorflow.version.VERSION
#if tf_version is None:
# tf_version = tensorflow.version.GIT_VERSION
if tf_version[0] == 'v':
tf_version = tf_version[1:]
if tf_version[0] == '2':
tf = tensorflow.compat.v1
else:
tf = tensorflow
import logging import logging
# Disable tensorflow warnings # Disable tensorflow warnings
tf_logger = logging.getLogger('tensorflow') logging.getLogger('tensorflow').setLevel(logging.ERROR)
tf_logger.setLevel(logging.ERROR)
if tf_version[0] == '2':
tf.disable_v2_behavior()
nn.tf = tf
# Initialize framework # Initialize framework
import core.leras.ops import core.leras.ops
@ -102,11 +94,10 @@ class nn():
# Configure tensorflow session-config # Configure tensorflow session-config
if len(device_config.devices) == 0: if len(device_config.devices) == 0:
nn.tf_default_device = "/CPU:0"
config = tf.ConfigProto(device_count={'GPU': 0}) config = tf.ConfigProto(device_count={'GPU': 0})
nn.tf_default_device_name = '/CPU:0'
else: else:
nn.tf_default_device_name = f'/{device_config.devices[0].tf_dev_type}:0' nn.tf_default_device = "/GPU:0"
config = tf.ConfigProto() config = tf.ConfigProto()
config.gpu_options.visible_device_list = ','.join([str(device.index) for device in device_config.devices]) config.gpu_options.visible_device_list = ','.join([str(device.index) for device in device_config.devices])
@ -197,6 +188,14 @@ class nn():
nn.tf_sess.close() nn.tf_sess.close()
nn.tf_sess = None nn.tf_sess = None
@staticmethod
def get_current_device():
# Undocumented access to last tf.device(...)
objs = nn.tf.get_default_graph()._device_function_stack.peek_objs()
if len(objs) != 0:
return objs[0].display_name
return nn.tf_default_device
@staticmethod @staticmethod
def ask_choose_device_idxs(choose_only_one=False, allow_cpu=True, suggest_best_multi_gpu=False, suggest_all_gpu=False): def ask_choose_device_idxs(choose_only_one=False, allow_cpu=True, suggest_best_multi_gpu=False, suggest_all_gpu=False):
devices = Devices.getDevices() devices = Devices.getDevices()

View file

@ -108,15 +108,10 @@ nn.gelu = gelu
def upsample2d(x, size=2): def upsample2d(x, size=2):
if nn.data_format == "NCHW": if nn.data_format == "NCHW":
x = tf.transpose(x, (0,2,3,1)) b,c,h,w = x.shape.as_list()
x = tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) ) x = tf.reshape (x, (-1,c,h,1,w,1) )
x = tf.transpose(x, (0,3,1,2)) x = tf.tile(x, (1,1,1,size,1,size) )
x = tf.reshape (x, (-1,c,h*size,w*size) )
# b,c,h,w = x.shape.as_list()
# x = tf.reshape (x, (-1,c,h,1,w,1) )
# x = tf.tile(x, (1,1,1,size,1,size) )
# x = tf.reshape (x, (-1,c,h*size,w*size) )
return x return x
else: else:
return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) ) return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )
@ -142,39 +137,8 @@ def resize2d_bilinear(x, size=2):
return x return x
nn.resize2d_bilinear = resize2d_bilinear nn.resize2d_bilinear = resize2d_bilinear
def resize2d_nearest(x, size=2):
if size in [-1,0,1]:
return x
if size > 0:
raise Exception("")
else:
if nn.data_format == "NCHW":
x = x[:,:,::-size,::-size]
else:
x = x[:,::-size,::-size,:]
return x
h = x.shape[nn.conv2d_spatial_axes[0]].value
w = x.shape[nn.conv2d_spatial_axes[1]].value
if nn.data_format == "NCHW":
x = tf.transpose(x, (0,2,3,1))
if size > 0:
new_size = (h*size,w*size)
else:
new_size = (h//-size,w//-size)
x = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
if nn.data_format == "NCHW":
x = tf.transpose(x, (0,3,1,2))
return x
nn.resize2d_nearest = resize2d_nearest
def flatten(x): def flatten(x):
if nn.data_format == "NHWC": if nn.data_format == "NHWC":
# match NCHW version in order to switch data_format without problems # match NCHW version in order to switch data_format without problems
@ -209,7 +173,7 @@ def random_binomial(shape, p=0.0, dtype=None, seed=None):
seed = np.random.randint(10e6) seed = np.random.randint(10e6)
return array_ops.where( return array_ops.where(
random_ops.random_uniform(shape, dtype=tf.float16, seed=seed) < p, random_ops.random_uniform(shape, dtype=tf.float16, seed=seed) < p,
array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype)) array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype))
nn.random_binomial = random_binomial nn.random_binomial = random_binomial
def gaussian_blur(input, radius=2.0): def gaussian_blur(input, radius=2.0):
@ -217,9 +181,7 @@ def gaussian_blur(input, radius=2.0):
return np.exp(-(float(x) - float(mu)) ** 2 / (2 * sigma ** 2)) return np.exp(-(float(x) - float(mu)) ** 2 / (2 * sigma ** 2))
def make_kernel(sigma): def make_kernel(sigma):
kernel_size = max(3, int(2 * 2 * sigma)) kernel_size = max(3, int(2 * 2 * sigma + 1))
if kernel_size % 2 == 0:
kernel_size += 1
mean = np.floor(0.5 * kernel_size) mean = np.floor(0.5 * kernel_size)
kernel_1d = np.array([gaussian(x, mean, sigma) for x in range(kernel_size)]) kernel_1d = np.array([gaussian(x, mean, sigma) for x in range(kernel_size)])
np_kernel = np.outer(kernel_1d, kernel_1d).astype(np.float32) np_kernel = np.outer(kernel_1d, kernel_1d).astype(np.float32)
@ -276,8 +238,6 @@ def dssim(img1,img2, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03
img1 = tf.cast(img1, tf.float32) img1 = tf.cast(img1, tf.float32)
img2 = tf.cast(img2, tf.float32) img2 = tf.cast(img2, tf.float32)
filter_size = max(1, filter_size)
kernel = np.arange(0, filter_size, dtype=np.float32) kernel = np.arange(0, filter_size, dtype=np.float32)
kernel -= (filter_size - 1 ) / 2.0 kernel -= (filter_size - 1 ) / 2.0
kernel = kernel**2 kernel = kernel**2
@ -340,17 +300,7 @@ def depth_to_space(x, size):
x = tf.reshape(x, (-1, oh, ow, oc, )) x = tf.reshape(x, (-1, oh, ow, oc, ))
return x return x
else: else:
cfg = nn.getCurrentDeviceConfig() return tf.depth_to_space(x, size, data_format=nn.data_format)
if not cfg.cpu_only:
return tf.depth_to_space(x, size, data_format=nn.data_format)
b,c,h,w = x.shape.as_list()
oh, ow = h * size, w * size
oc = c // (size * size)
x = tf.reshape(x, (-1, size, size, oc, h, w, ) )
x = tf.transpose(x, (0, 3, 4, 1, 5, 2))
x = tf.reshape(x, (-1, oc, oh, ow))
return x
nn.depth_to_space = depth_to_space nn.depth_to_space = depth_to_space
def rgb_to_lab(srgb): def rgb_to_lab(srgb):
@ -383,23 +333,6 @@ def rgb_to_lab(srgb):
return tf.reshape(lab_pixels, tf.shape(srgb)) return tf.reshape(lab_pixels, tf.shape(srgb))
nn.rgb_to_lab = rgb_to_lab nn.rgb_to_lab = rgb_to_lab
def total_variation_mse(images):
"""
Same as generic total_variation, but MSE diff instead of MAE
"""
pixel_dif1 = images[:, 1:, :, :] - images[:, :-1, :, :]
pixel_dif2 = images[:, :, 1:, :] - images[:, :, :-1, :]
tot_var = ( tf.reduce_sum(tf.square(pixel_dif1), axis=[1,2,3]) +
tf.reduce_sum(tf.square(pixel_dif2), axis=[1,2,3]) )
return tot_var
nn.total_variation_mse = total_variation_mse
def pixel_norm(x, axes):
return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=axes, keepdims=True) + 1e-06)
nn.pixel_norm = pixel_norm
""" """
def tf_suppress_lower_mean(t, eps=0.00001): def tf_suppress_lower_mean(t, eps=0.00001):
if t.shape.ndims != 1: if t.shape.ndims != 1:
@ -410,69 +343,3 @@ def tf_suppress_lower_mean(t, eps=0.00001):
q = q * (t/eps) q = q * (t/eps)
return q return q
""" """
def _get_pixel_value(img, x, y):
shape = tf.shape(x)
batch_size = shape[0]
height = shape[1]
width = shape[2]
batch_idx = tf.range(0, batch_size)
batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
b = tf.tile(batch_idx, (1, height, width))
indices = tf.stack([b, y, x], 3)
return tf.gather_nd(img, indices)
def bilinear_sampler(img, x, y):
H = tf.shape(img)[1]
W = tf.shape(img)[2]
H_MAX = tf.cast(H - 1, tf.int32)
W_MAX = tf.cast(W - 1, tf.int32)
# grab 4 nearest corner points for each (x_i, y_i)
x0 = tf.cast(tf.floor(x), tf.int32)
x1 = x0 + 1
y0 = tf.cast(tf.floor(y), tf.int32)
y1 = y0 + 1
# clip to range [0, H-1/W-1] to not violate img boundaries
x0 = tf.clip_by_value(x0, 0, W_MAX)
x1 = tf.clip_by_value(x1, 0, W_MAX)
y0 = tf.clip_by_value(y0, 0, H_MAX)
y1 = tf.clip_by_value(y1, 0, H_MAX)
# get pixel value at corner coords
Ia = _get_pixel_value(img, x0, y0)
Ib = _get_pixel_value(img, x0, y1)
Ic = _get_pixel_value(img, x1, y0)
Id = _get_pixel_value(img, x1, y1)
# recast as float for delta calculation
x0 = tf.cast(x0, tf.float32)
x1 = tf.cast(x1, tf.float32)
y0 = tf.cast(y0, tf.float32)
y1 = tf.cast(y1, tf.float32)
# calculate deltas
wa = (x1-x) * (y1-y)
wb = (x1-x) * (y-y0)
wc = (x-x0) * (y1-y)
wd = (x-x0) * (y-y0)
# add dimension for addition
wa = tf.expand_dims(wa, axis=3)
wb = tf.expand_dims(wb, axis=3)
wc = tf.expand_dims(wc, axis=3)
wd = tf.expand_dims(wd, axis=3)
# compute output
out = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
return out
nn.bilinear_sampler = bilinear_sampler

View file

@ -1,81 +0,0 @@
import numpy as np
from core.leras import nn
from tensorflow.python.ops import control_flow_ops, state_ops
tf = nn.tf
class AdaBelief(nn.OptimizerBase):
def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, lr_dropout=1.0, lr_cos=0, clipnorm=0.0, name=None, **kwargs):
super().__init__(name=name)
if name is None:
raise ValueError('name must be defined.')
self.lr = lr
self.beta_1 = beta_1
self.beta_2 = beta_2
self.lr_dropout = lr_dropout
self.lr_cos = lr_cos
self.clipnorm = clipnorm
with tf.device('/CPU:0') :
with tf.variable_scope(self.name):
self.iterations = tf.Variable(0, dtype=tf.int64, name='iters')
self.ms_dict = {}
self.vs_dict = {}
self.lr_rnds_dict = {}
def get_weights(self):
return [self.iterations] + list(self.ms_dict.values()) + list(self.vs_dict.values())
def initialize_variables(self, trainable_weights, vars_on_cpu=True, lr_dropout_on_cpu=False):
# Initialize here all trainable variables used in training
e = tf.device('/CPU:0') if vars_on_cpu else None
if e: e.__enter__()
with tf.variable_scope(self.name):
ms = { v.name : tf.get_variable ( f'ms_{v.name}'.replace(':','_'), v.shape, dtype=v.dtype, initializer=tf.initializers.constant(0.0), trainable=False) for v in trainable_weights }
vs = { v.name : tf.get_variable ( f'vs_{v.name}'.replace(':','_'), v.shape, dtype=v.dtype, initializer=tf.initializers.constant(0.0), trainable=False) for v in trainable_weights }
self.ms_dict.update (ms)
self.vs_dict.update (vs)
if self.lr_dropout != 1.0:
e = tf.device('/CPU:0') if lr_dropout_on_cpu else None
if e: e.__enter__()
lr_rnds = [ nn.random_binomial( v.shape, p=self.lr_dropout, dtype=v.dtype) for v in trainable_weights ]
if e: e.__exit__(None, None, None)
self.lr_rnds_dict.update ( { v.name : rnd for v,rnd in zip(trainable_weights,lr_rnds) } )
if e: e.__exit__(None, None, None)
def get_update_op(self, grads_vars):
updates = []
if self.clipnorm > 0.0:
norm = tf.sqrt( sum([tf.reduce_sum(tf.square(tf.cast(g, tf.float32))) for g,v in grads_vars]))
updates += [ state_ops.assign_add( self.iterations, 1) ]
for i, (g,v) in enumerate(grads_vars):
if self.clipnorm > 0.0:
g = self.tf_clip_norm(g, self.clipnorm, tf.cast(norm, g.dtype) )
ms = self.ms_dict[ v.name ]
vs = self.vs_dict[ v.name ]
m_t = self.beta_1*ms + (1.0-self.beta_1) * g
v_t = self.beta_2*vs + (1.0-self.beta_2) * tf.square(g-m_t)
lr = tf.constant(self.lr, g.dtype)
if self.lr_cos != 0:
lr *= (tf.cos( tf.cast(self.iterations, g.dtype) * (2*3.1415926535/ float(self.lr_cos) ) ) + 1.0) / 2.0
v_diff = - lr * m_t / (tf.sqrt(v_t) + np.finfo( g.dtype.as_numpy_dtype ).resolution )
if self.lr_dropout != 1.0:
lr_rnd = self.lr_rnds_dict[v.name]
v_diff *= lr_rnd
new_v = v + v_diff
updates.append (state_ops.assign(ms, m_t))
updates.append (state_ops.assign(vs, v_t))
updates.append (state_ops.assign(v, new_v))
return control_flow_ops.group ( *updates, name=self.name+'_updates')
nn.AdaBelief = AdaBelief

View file

@ -1,33 +1,31 @@
import numpy as np
from tensorflow.python.ops import control_flow_ops, state_ops from tensorflow.python.ops import control_flow_ops, state_ops
from core.leras import nn from core.leras import nn
tf = nn.tf tf = nn.tf
class RMSprop(nn.OptimizerBase): class RMSprop(nn.OptimizerBase):
def __init__(self, lr=0.001, rho=0.9, lr_dropout=1.0, lr_cos=0, clipnorm=0.0, name=None, **kwargs): def __init__(self, lr=0.001, rho=0.9, lr_dropout=1.0, epsilon=1e-7, clipnorm=0.0, name=None):
super().__init__(name=name) super().__init__(name=name)
if name is None: if name is None:
raise ValueError('name must be defined.') raise ValueError('name must be defined.')
self.lr_dropout = lr_dropout self.lr_dropout = lr_dropout
self.lr_cos = lr_cos
self.lr = lr
self.rho = rho
self.clipnorm = clipnorm self.clipnorm = clipnorm
with tf.device('/CPU:0') : with tf.device('/CPU:0') :
with tf.variable_scope(self.name): with tf.variable_scope(self.name):
self.lr = tf.Variable (lr, name="lr")
self.rho = tf.Variable (rho, name="rho")
self.epsilon = tf.Variable (epsilon, name="epsilon")
self.iterations = tf.Variable(0, dtype=tf.int64, name='iters') self.iterations = tf.Variable(0, dtype=tf.int64, name='iters')
self.accumulators_dict = {} self.accumulators_dict = {}
self.lr_rnds_dict = {} self.lr_rnds_dict = {}
def get_weights(self): def get_weights(self):
return [self.iterations] + list(self.accumulators_dict.values()) return [self.lr, self.rho, self.epsilon, self.iterations] + list(self.accumulators_dict.values())
def initialize_variables(self, trainable_weights, vars_on_cpu=True, lr_dropout_on_cpu=False): def initialize_variables(self, trainable_weights, vars_on_cpu=True):
# Initialize here all trainable variables used in training # Initialize here all trainable variables used in training
e = tf.device('/CPU:0') if vars_on_cpu else None e = tf.device('/CPU:0') if vars_on_cpu else None
if e: e.__enter__() if e: e.__enter__()
@ -36,10 +34,7 @@ class RMSprop(nn.OptimizerBase):
self.accumulators_dict.update ( accumulators) self.accumulators_dict.update ( accumulators)
if self.lr_dropout != 1.0: if self.lr_dropout != 1.0:
e = tf.device('/CPU:0') if lr_dropout_on_cpu else None
if e: e.__enter__()
lr_rnds = [ nn.random_binomial( v.shape, p=self.lr_dropout, dtype=v.dtype) for v in trainable_weights ] lr_rnds = [ nn.random_binomial( v.shape, p=self.lr_dropout, dtype=v.dtype) for v in trainable_weights ]
if e: e.__exit__(None, None, None)
self.lr_rnds_dict.update ( { v.name : rnd for v,rnd in zip(trainable_weights,lr_rnds) } ) self.lr_rnds_dict.update ( { v.name : rnd for v,rnd in zip(trainable_weights,lr_rnds) } )
if e: e.__exit__(None, None, None) if e: e.__exit__(None, None, None)
@ -47,21 +42,21 @@ class RMSprop(nn.OptimizerBase):
updates = [] updates = []
if self.clipnorm > 0.0: if self.clipnorm > 0.0:
norm = tf.sqrt( sum([tf.reduce_sum(tf.square(tf.cast(g, tf.float32))) for g,v in grads_vars])) norm = tf.sqrt( sum([tf.reduce_sum(tf.square(g)) for g,v in grads_vars]))
updates += [ state_ops.assign_add( self.iterations, 1) ] updates += [ state_ops.assign_add( self.iterations, 1) ]
for i, (g,v) in enumerate(grads_vars): for i, (g,v) in enumerate(grads_vars):
if self.clipnorm > 0.0: if self.clipnorm > 0.0:
g = self.tf_clip_norm(g, self.clipnorm, tf.cast(norm, g.dtype) ) g = self.tf_clip_norm(g, self.clipnorm, norm)
a = self.accumulators_dict[ v.name ] a = self.accumulators_dict[ v.name ]
new_a = self.rho * a + (1. - self.rho) * tf.square(g) rho = tf.cast(self.rho, a.dtype)
new_a = rho * a + (1. - rho) * tf.square(g)
lr = tf.constant(self.lr, g.dtype) lr = tf.cast(self.lr, a.dtype)
if self.lr_cos != 0: epsilon = tf.cast(self.epsilon, a.dtype)
lr *= (tf.cos( tf.cast(self.iterations, g.dtype) * (2*3.1415926535/ float(self.lr_cos) ) ) + 1.0) / 2.0
v_diff = - lr * g / (tf.sqrt(new_a) + np.finfo( g.dtype.as_numpy_dtype ).resolution ) v_diff = - lr * g / (tf.sqrt(new_a) + epsilon)
if self.lr_dropout != 1.0: if self.lr_dropout != 1.0:
lr_rnd = self.lr_rnds_dict[v.name] lr_rnd = self.lr_rnds_dict[v.name]
v_diff *= lr_rnd v_diff *= lr_rnd

View file

@ -1,3 +1,2 @@
from .OptimizerBase import * from .OptimizerBase import *
from .RMSprop import * from .RMSprop import *
from .AdaBelief import *

View file

@ -1,12 +1,7 @@
import math
import cv2
import numpy as np import numpy as np
import numpy.linalg as npla import math
from .umeyama import umeyama from .umeyama import umeyama
def get_power_of_two(x): def get_power_of_two(x):
i = 0 i = 0
while (1 << i) < x: while (1 << i) < x:
@ -28,70 +23,3 @@ def rotationMatrixToEulerAngles(R) :
def polygon_area(x,y): def polygon_area(x,y):
return 0.5*np.abs(np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1))) return 0.5*np.abs(np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1)))
def rotate_point(origin, point, deg):
"""
Rotate a point counterclockwise by a given angle around a given origin.
The angle should be given in radians.
"""
ox, oy = origin
px, py = point
rad = deg * math.pi / 180.0
qx = ox + math.cos(rad) * (px - ox) - math.sin(rad) * (py - oy)
qy = oy + math.sin(rad) * (px - ox) + math.cos(rad) * (py - oy)
return np.float32([qx, qy])
def transform_points(points, mat, invert=False):
if invert:
mat = cv2.invertAffineTransform (mat)
points = np.expand_dims(points, axis=1)
points = cv2.transform(points, mat, points.shape)
points = np.squeeze(points)
return points
def transform_mat(mat, res, tx, ty, rotation, scale):
"""
transform mat in local space of res
scale -> translate -> rotate
tx,ty float
rotation int degrees
scale float
"""
lt, rt, lb, ct = transform_points ( np.float32([(0,0),(res,0),(0,res),(res / 2, res/2) ]),mat, True)
hor_v = (rt-lt).astype(np.float32)
hor_size = npla.norm(hor_v)
hor_v /= hor_size
ver_v = (lb-lt).astype(np.float32)
ver_size = npla.norm(ver_v)
ver_v /= ver_size
bt_diag_vec = (rt-ct).astype(np.float32)
half_diag_len = npla.norm(bt_diag_vec)
bt_diag_vec /= half_diag_len
tb_diag_vec = np.float32( [ -bt_diag_vec[1], bt_diag_vec[0] ] )
rt = ct + bt_diag_vec*half_diag_len*scale
lb = ct - bt_diag_vec*half_diag_len*scale
lt = ct - tb_diag_vec*half_diag_len*scale
rt[0] += tx*hor_size
lb[0] += tx*hor_size
lt[0] += tx*hor_size
rt[1] += ty*ver_size
lb[1] += ty*ver_size
lt[1] += ty*ver_size
rt = rotate_point(ct, rt, rotation)
lb = rotate_point(ct, lb, rotation)
lt = rotate_point(ct, lt, rotation)
return cv2.getAffineTransform( np.float32([lt, rt, lb]), np.float32([ [0,0], [res,0], [0,res] ]) )

View file

@ -60,11 +60,9 @@ class MPSharedList():
break break
key -= self.obj_counts[i] key -= self.obj_counts[i]
sh_b = memoryview(sh_b).cast('B') offset_start, offset_end = struct.unpack('<QQ', bytes(sh_b[ table_offset + key*8 : table_offset + (key+2)*8]) )
offset_start, offset_end = struct.unpack('<QQ', sh_b[ table_offset + key*8 : table_offset + (key+2)*8].tobytes() ) return pickle.loads( bytes(sh_b[ data_offset + offset_start : data_offset + offset_end ]) )
return pickle.loads( sh_b[ data_offset + offset_start : data_offset + offset_end ].tobytes() )
def __iter__(self): def __iter__(self):
for i in range(self.__len__()): for i in range(self.__len__()):
@ -86,8 +84,7 @@ class MPSharedList():
data_size = sum([len(x) for x in obj_pickled_ar]) data_size = sum([len(x) for x in obj_pickled_ar])
sh_b = multiprocessing.RawArray('B', table_size + data_size) sh_b = multiprocessing.RawArray('B', table_size + data_size)
#sh_b[0:8] = struct.pack('<Q', obj_count) sh_b[0:8] = struct.pack('<Q', obj_count)
sh_b_view = memoryview(sh_b).cast('B')
offset = 0 offset = 0
@ -100,12 +97,51 @@ class MPSharedList():
offset += len(obj_pickled_ar[i]) offset += len(obj_pickled_ar[i])
offsets.append(offset) offsets.append(offset)
sh_b_view[table_offset:table_offset+table_size] = struct.pack( '<'+'Q'*len(offsets), *offsets ) sh_b[table_offset:table_offset+table_size] = struct.pack( '<'+'Q'*len(offsets), *offsets )
for i, obj_pickled in enumerate(obj_pickled_ar): ArrayFillerSubprocessor(sh_b, [ (data_offset+offsets[i], obj_pickled_ar[i] ) for i in range(obj_count) ] ).run()
offset = data_offset+offsets[i]
sh_b_view[offset:offset+len(obj_pickled)] = obj_pickled_ar[i]
return obj_count, table_offset, data_offset, sh_b return obj_count, table_offset, data_offset, sh_b
return 0, 0, 0, None return 0, 0, 0, None
class ArrayFillerSubprocessor(Subprocessor):
"""
Much faster to fill shared memory via subprocesses rather than direct whole bytes fill.
"""
#override
def __init__(self, sh_b, data_list ):
self.sh_b = sh_b
self.data_list = data_list
super().__init__('ArrayFillerSubprocessor', ArrayFillerSubprocessor.Cli, 60)
#override
def process_info_generator(self):
for i in range(min(multiprocessing.cpu_count(), 8)):
yield 'CPU%d' % (i), {}, {'sh_b':self.sh_b}
#override
def get_data(self, host_dict):
if len(self.data_list) > 0:
return self.data_list.pop(0)
return None
#override
def on_data_return (self, host_dict, data):
self.data_list.insert(0, data)
#override
def on_result (self, host_dict, data, result):
pass
class Cli(Subprocessor.Cli):
#overridable optional
def on_initialize(self, client_dict):
self.sh_b = client_dict['sh_b']
def process_data(self, data):
offset, b = data
self.sh_b[offset:offset+len(b)]=b
return 0

View file

@ -5,6 +5,96 @@ import time
import numpy as np import numpy as np
class Index2DHost():
"""
Provides random shuffled 2D indexes for multiprocesses
"""
def __init__(self, indexes2D):
self.sq = multiprocessing.Queue()
self.cqs = []
self.clis = []
self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) )
self.thread.daemon = True
self.thread.start()
def host_thread(self, indexes2D):
indexes_counts_len = len(indexes2D)
idxs = [*range(indexes_counts_len)]
idxs_2D = [None]*indexes_counts_len
shuffle_idxs = []
shuffle_idxs_2D = [None]*indexes_counts_len
for i in range(indexes_counts_len):
idxs_2D[i] = indexes2D[i]
shuffle_idxs_2D[i] = []
sq = self.sq
while True:
while not sq.empty():
obj = sq.get()
cq_id, cmd = obj[0], obj[1]
if cmd == 0: #get_1D
count = obj[2]
result = []
for i in range(count):
if len(shuffle_idxs) == 0:
shuffle_idxs = idxs.copy()
np.random.shuffle(shuffle_idxs)
result.append(shuffle_idxs.pop())
self.cqs[cq_id].put (result)
elif cmd == 1: #get_2D
targ_idxs,count = obj[2], obj[3]
result = []
for targ_idx in targ_idxs:
sub_idxs = []
for i in range(count):
ar = shuffle_idxs_2D[targ_idx]
if len(ar) == 0:
ar = shuffle_idxs_2D[targ_idx] = idxs_2D[targ_idx].copy()
np.random.shuffle(ar)
sub_idxs.append(ar.pop())
result.append (sub_idxs)
self.cqs[cq_id].put (result)
time.sleep(0.005)
def create_cli(self):
cq = multiprocessing.Queue()
self.cqs.append ( cq )
cq_id = len(self.cqs)-1
return Index2DHost.Cli(self.sq, cq, cq_id)
# disable pickling
def __getstate__(self):
return dict()
def __setstate__(self, d):
self.__dict__.update(d)
class Cli():
def __init__(self, sq, cq, cq_id):
self.sq = sq
self.cq = cq
self.cq_id = cq_id
def get_1D(self, count):
self.sq.put ( (self.cq_id,0, count) )
while True:
if not self.cq.empty():
return self.cq.get()
time.sleep(0.001)
def get_2D(self, idxs, count):
self.sq.put ( (self.cq_id,1,idxs,count) )
while True:
if not self.cq.empty():
return self.cq.get()
time.sleep(0.001)
class IndexHost(): class IndexHost():
""" """
@ -66,95 +156,6 @@ class IndexHost():
return self.cq.get() return self.cq.get()
time.sleep(0.001) time.sleep(0.001)
class Index2DHost():
"""
Provides random shuffled indexes for multiprocesses
"""
def __init__(self, indexes2D):
self.sq = multiprocessing.Queue()
self.cqs = []
self.clis = []
self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) )
self.thread.daemon = True
self.thread.start()
def host_thread(self, indexes2D):
indexes2D_len = len(indexes2D)
idxs = [*range(indexes2D_len)]
idxs_2D = [None]*indexes2D_len
shuffle_idxs = []
shuffle_idxs_2D = [None]*indexes2D_len
for i in range(indexes2D_len):
idxs_2D[i] = [*range(len(indexes2D[i]))]
shuffle_idxs_2D[i] = []
#print(idxs)
#print(idxs_2D)
sq = self.sq
while True:
while not sq.empty():
obj = sq.get()
cq_id, count = obj[0], obj[1]
result = []
for i in range(count):
if len(shuffle_idxs) == 0:
shuffle_idxs = idxs.copy()
np.random.shuffle(shuffle_idxs)
idx_1D = shuffle_idxs.pop()
#print(f'idx_1D = {idx_1D}, len(shuffle_idxs_2D[idx_1D])= {len(shuffle_idxs_2D[idx_1D])}')
if len(shuffle_idxs_2D[idx_1D]) == 0:
shuffle_idxs_2D[idx_1D] = idxs_2D[idx_1D].copy()
#print(f'new shuffle_idxs_2d for {idx_1D} = { shuffle_idxs_2D[idx_1D] }')
#print(f'len(shuffle_idxs_2D[idx_1D])= {len(shuffle_idxs_2D[idx_1D])}')
np.random.shuffle( shuffle_idxs_2D[idx_1D] )
idx_2D = shuffle_idxs_2D[idx_1D].pop()
#print(f'len(shuffle_idxs_2D[idx_1D])= {len(shuffle_idxs_2D[idx_1D])}')
#print(f'idx_2D = {idx_2D}')
result.append( indexes2D[idx_1D][idx_2D])
self.cqs[cq_id].put (result)
time.sleep(0.001)
def create_cli(self):
cq = multiprocessing.Queue()
self.cqs.append ( cq )
cq_id = len(self.cqs)-1
return Index2DHost.Cli(self.sq, cq, cq_id)
# disable pickling
def __getstate__(self):
return dict()
def __setstate__(self, d):
self.__dict__.update(d)
class Cli():
def __init__(self, sq, cq, cq_id):
self.sq = sq
self.cq = cq
self.cq_id = cq_id
def multi_get(self, count):
self.sq.put ( (self.cq_id,count) )
while True:
if not self.cq.empty():
return self.cq.get()
time.sleep(0.001)
class ListHost(): class ListHost():
def __init__(self, list_): def __init__(self, list_):
self.sq = multiprocessing.Queue() self.sq = multiprocessing.Queue()

View file

@ -38,8 +38,8 @@ def QImage_from_np(img):
return QImage(img.data, w, h, c*w, fmt ) return QImage(img.data, w, h, c*w, fmt )
def QImage_to_np(q_img, fmt=QImage.Format_BGR888): def QImage_to_np(q_img):
q_img = q_img.convertToFormat(fmt) q_img = q_img.convertToFormat(QImage.Format_BGR888)
width = q_img.width() width = q_img.width()
height = q_img.height() height = q_img.height()

View file

@ -1,14 +1,12 @@
import numpy as np import numpy as np
def random_normal( size=(1,), trunc_val = 2.5, rnd_state=None ): def random_normal( size=(1,), trunc_val = 2.5 ):
if rnd_state is None:
rnd_state = np.random
len = np.array(size).prod() len = np.array(size).prod()
result = np.empty ( (len,) , dtype=np.float32) result = np.empty ( (len,) , dtype=np.float32)
for i in range (len): for i in range (len):
while True: while True:
x = rnd_state.normal() x = np.random.normal()
if x >= -trunc_val and x <= trunc_val: if x >= -trunc_val and x <= trunc_val:
break break
result[i] = (x / trunc_val) result[i] = (x / trunc_val)

BIN
doc/Alipay_donation.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 63 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 482 KiB

After

Width:  |  Height:  |  Size: 544 KiB

Before After
Before After

Binary file not shown.

Before

Width:  |  Height:  |  Size: 74 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 68 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1 MiB

After

Width:  |  Height:  |  Size: 313 KiB

Before After
Before After

Binary file not shown.

View file

Before

Width:  |  Height:  |  Size: 71 KiB

After

Width:  |  Height:  |  Size: 71 KiB

Before After
Before After

Binary file not shown.

Before

Width:  |  Height:  |  Size: 122 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

View file

Before

Width:  |  Height:  |  Size: 67 KiB

After

Width:  |  Height:  |  Size: 67 KiB

Before After
Before After

Binary file not shown.

Before

Width:  |  Height:  |  Size: 98 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 97 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 25 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 208 KiB

After

Width:  |  Height:  |  Size: 178 KiB

Before After
Before After

Binary file not shown.

Before

Width:  |  Height:  |  Size: 310 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 273 KiB

BIN
doc/political_speech.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 548 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 247 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 349 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 662 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 378 KiB

BIN
doc/replace_the_face.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1,004 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 268 B

View file

@ -28,13 +28,13 @@ class FANExtractor(object):
self.out_planes = out_planes self.out_planes = out_planes
self.bn1 = nn.BatchNorm2D(in_planes) self.bn1 = nn.BatchNorm2D(in_planes)
self.conv1 = nn.Conv2D (in_planes, out_planes//2, kernel_size=3, strides=1, padding='SAME', use_bias=False ) self.conv1 = nn.Conv2D (in_planes, out_planes/2, kernel_size=3, strides=1, padding='SAME', use_bias=False )
self.bn2 = nn.BatchNorm2D(out_planes//2) self.bn2 = nn.BatchNorm2D(out_planes//2)
self.conv2 = nn.Conv2D (out_planes//2, out_planes//4, kernel_size=3, strides=1, padding='SAME', use_bias=False ) self.conv2 = nn.Conv2D (out_planes/2, out_planes/4, kernel_size=3, strides=1, padding='SAME', use_bias=False )
self.bn3 = nn.BatchNorm2D(out_planes//4) self.bn3 = nn.BatchNorm2D(out_planes//4)
self.conv3 = nn.Conv2D (out_planes//4, out_planes//4, kernel_size=3, strides=1, padding='SAME', use_bias=False ) self.conv3 = nn.Conv2D (out_planes/4, out_planes/4, kernel_size=3, strides=1, padding='SAME', use_bias=False )
if self.in_planes != self.out_planes: if self.in_planes != self.out_planes:
self.down_bn1 = nn.BatchNorm2D(in_planes) self.down_bn1 = nn.BatchNorm2D(in_planes)

View file

@ -161,11 +161,11 @@ class FaceEnhancer(object):
if not model_path.exists(): if not model_path.exists():
raise Exception("Unable to load FaceEnhancer.npy") raise Exception("Unable to load FaceEnhancer.npy")
with tf.device ('/CPU:0' if place_model_on_cpu else nn.tf_default_device_name): with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'):
self.model = FaceEnhancer() self.model = FaceEnhancer()
self.model.load_weights (model_path) self.model.load_weights (model_path)
with tf.device ('/CPU:0' if run_on_cpu else nn.tf_default_device_name): with tf.device ('/CPU:0' if run_on_cpu else '/GPU:0'):
self.model.build_for_run ([ (tf.float32, nn.get4Dshape (192,192,3) ), self.model.build_for_run ([ (tf.float32, nn.get4Dshape (192,192,3) ),
(tf.float32, (None,1,) ), (tf.float32, (None,1,) ),
(tf.float32, (None,1,) ), (tf.float32, (None,1,) ),
@ -248,7 +248,7 @@ class FaceEnhancer(object):
final_img = final_img [t_padding*up_res:(h-b_padding)*up_res, l_padding*up_res:(w-r_padding)*up_res,:] final_img = final_img [t_padding*up_res:(h-b_padding)*up_res, l_padding*up_res:(w-r_padding)*up_res,:]
if preserve_size: if preserve_size:
final_img = cv2.resize (final_img, (iw,ih), interpolation=cv2.INTER_LANCZOS4) final_img = cv2.resize (final_img, (iw,ih), cv2.INTER_LANCZOS4)
if not is_tanh: if not is_tanh:
final_img = np.clip( final_img/2+0.5, 0, 1 ) final_img = np.clip( final_img/2+0.5, 0, 1 )
@ -278,7 +278,7 @@ class FaceEnhancer(object):
preupscale_rate = 1.0 / ( max(h,w) / patch_size ) preupscale_rate = 1.0 / ( max(h,w) / patch_size )
if preupscale_rate != 1.0: if preupscale_rate != 1.0:
inp_img = cv2.resize (inp_img, ( int(w*preupscale_rate), int(h*preupscale_rate) ), interpolation=cv2.INTER_LANCZOS4) inp_img = cv2.resize (inp_img, ( int(w*preupscale_rate), int(h*preupscale_rate) ), cv2.INTER_LANCZOS4)
h,w,c = inp_img.shape h,w,c = inp_img.shape
i_max = w-patch_size+1 i_max = w-patch_size+1
@ -310,10 +310,10 @@ class FaceEnhancer(object):
final_img /= final_img_div final_img /= final_img_div
if preserve_size: if preserve_size:
final_img = cv2.resize (final_img, (w,h), interpolation=cv2.INTER_LANCZOS4) final_img = cv2.resize (final_img, (w,h), cv2.INTER_LANCZOS4)
else: else:
if preupscale_rate != 1.0: if preupscale_rate != 1.0:
final_img = cv2.resize (final_img, (tw,th), interpolation=cv2.INTER_LANCZOS4) final_img = cv2.resize (final_img, (tw,th), cv2.INTER_LANCZOS4)
if not is_tanh: if not is_tanh:
final_img = np.clip( final_img/2+0.5, 0, 1 ) final_img = np.clip( final_img/2+0.5, 0, 1 )

View file

@ -302,6 +302,8 @@ def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0):
g_c += vec*vec_len*0.07 g_c += vec*vec_len*0.07
elif face_type == FaceType.HEAD: elif face_type == FaceType.HEAD:
mat = umeyama( np.concatenate ( [ image_landmarks[17:49] , image_landmarks[54:55] ] ) , landmarks_2D_new, True)[0:2]
# assuming image_landmarks are 3D_Landmarks extracted for HEAD, # assuming image_landmarks are 3D_Landmarks extracted for HEAD,
# adjust horizontal offset according to estimated yaw # adjust horizontal offset according to estimated yaw
yaw = estimate_averaged_yaw(transform_points (image_landmarks, mat, False)) yaw = estimate_averaged_yaw(transform_points (image_landmarks, mat, False))
@ -431,27 +433,6 @@ def get_image_eye_mask (image_shape, image_landmarks):
return hull_mask return hull_mask
def get_image_mouth_mask (image_shape, image_landmarks):
if len(image_landmarks) != 68:
raise Exception('get_image_eye_mask works only with 68 landmarks')
h,w,c = image_shape
hull_mask = np.zeros( (h,w,1),dtype=np.float32)
image_landmarks = image_landmarks.astype(np.int)
cv2.fillConvexPoly( hull_mask, cv2.convexHull( image_landmarks[60:]), (1,) )
dilate = h // 32
hull_mask = cv2.dilate(hull_mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(dilate,dilate)), iterations = 1 )
blur = h // 16
blur = blur + (1-blur % 2)
hull_mask = cv2.GaussianBlur(hull_mask, (blur, blur) , 0)
hull_mask = hull_mask[...,None]
return hull_mask
def alpha_to_color (img_alpha, color): def alpha_to_color (img_alpha, color):
if len(img_alpha.shape) == 2: if len(img_alpha.shape) == 2:

View file

@ -30,37 +30,35 @@ class XSegNet(object):
nn.initialize(data_format=data_format) nn.initialize(data_format=data_format)
tf = nn.tf tf = nn.tf
model_name = f'{name}_{resolution}'
self.model_filename_list = []
with tf.device ('/CPU:0'): with tf.device ('/CPU:0'):
#Place holders on CPU #Place holders on CPU
self.input_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,3) ) self.input_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,3) )
self.target_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,1) ) self.target_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,1) )
# Initializing model classes # Initializing model classes
with tf.device ('/CPU:0' if place_model_on_cpu else nn.tf_default_device_name): with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'):
self.model = nn.XSeg(3, 32, 1, name=name) self.model = nn.XSeg(3, 32, 1, name=name)
self.model_weights = self.model.get_weights() self.model_weights = self.model.get_weights()
if training:
if optimizer is None:
raise ValueError("Optimizer should be provided for training mode.")
self.opt = optimizer
self.opt.initialize_variables (self.model_weights, vars_on_cpu=place_model_on_cpu)
self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]
model_name = f'{name}_{resolution}'
self.model_filename_list += [ [self.model, f'{model_name}.npy'] ] self.model_filename_list = [ [self.model, f'{model_name}.npy'] ]
if not training: if training:
with tf.device ('/CPU:0' if run_on_cpu else nn.tf_default_device_name): if optimizer is None:
raise ValueError("Optimizer should be provided for training mode.")
self.opt = optimizer
self.opt.initialize_variables (self.model_weights, vars_on_cpu=place_model_on_cpu)
self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]
else:
with tf.device ('/CPU:0' if run_on_cpu else '/GPU:0'):
_, pred = self.model(self.input_t) _, pred = self.model(self.input_t)
def net_run(input_np): def net_run(input_np):
return nn.tf_sess.run ( [pred], feed_dict={self.input_t :input_np})[0] return nn.tf_sess.run ( [pred], feed_dict={self.input_t :input_np})[0]
self.net_run = net_run self.net_run = net_run
self.initialized = True
# Loading/initializing all models/optimizers weights # Loading/initializing all models/optimizers weights
for model, filename in self.model_filename_list: for model, filename in self.model_filename_list:
do_init = not load_weights do_init = not load_weights
@ -68,12 +66,8 @@ class XSegNet(object):
if not do_init: if not do_init:
model_file_path = self.weights_file_root / filename model_file_path = self.weights_file_root / filename
do_init = not model.load_weights( model_file_path ) do_init = not model.load_weights( model_file_path )
if do_init: if do_init and raise_on_no_model_files:
if raise_on_no_model_files: raise Exception(f'{model_file_path} does not exists.')
raise Exception(f'{model_file_path} does not exists.')
if not training:
self.initialized = False
break
if do_init: if do_init:
model.init_weights() model.init_weights()
@ -81,8 +75,8 @@ class XSegNet(object):
def get_resolution(self): def get_resolution(self):
return self.resolution return self.resolution
def flow(self, x, pretrain=False): def flow(self, x):
return self.model(x, pretrain=pretrain) return self.model(x)
def get_weights(self): def get_weights(self):
return self.model_weights return self.model_weights
@ -92,9 +86,6 @@ class XSegNet(object):
model.save_weights( self.weights_file_root / filename ) model.save_weights( self.weights_file_root / filename )
def extract (self, input_image): def extract (self, input_image):
if not self.initialized:
return 0.5*np.ones ( (self.resolution, self.resolution, 1), nn.floatx.as_numpy_dtype )
input_shape_len = len(input_image.shape) input_shape_len = len(input_image.shape)
if input_shape_len == 3: if input_shape_len == 3:
input_image = input_image[None,...] input_image = input_image[None,...]

47
main.py
View file

@ -22,8 +22,6 @@ if __name__ == "__main__":
def __call__(self, parser, namespace, values, option_string=None): def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, os.path.abspath(os.path.expanduser(values))) setattr(namespace, self.dest, os.path.abspath(os.path.expanduser(values)))
exit_code = 0
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers() subparsers = parser.add_subparsers()
@ -38,9 +36,6 @@ if __name__ == "__main__":
manual_output_debug_fix = arguments.manual_output_debug_fix, manual_output_debug_fix = arguments.manual_output_debug_fix,
manual_window_size = arguments.manual_window_size, manual_window_size = arguments.manual_window_size,
face_type = arguments.face_type, face_type = arguments.face_type,
max_faces_from_image = arguments.max_faces_from_image,
image_size = arguments.image_size,
jpeg_quality = arguments.jpeg_quality,
cpu_only = arguments.cpu_only, cpu_only = arguments.cpu_only,
force_gpu_idxs = [ int(x) for x in arguments.force_gpu_idxs.split(',') ] if arguments.force_gpu_idxs is not None else None, force_gpu_idxs = [ int(x) for x in arguments.force_gpu_idxs.split(',') ] if arguments.force_gpu_idxs is not None else None,
) )
@ -52,9 +47,6 @@ if __name__ == "__main__":
p.add_argument('--output-debug', action="store_true", dest="output_debug", default=None, help="Writes debug images to <output-dir>_debug\ directory.") p.add_argument('--output-debug', action="store_true", dest="output_debug", default=None, help="Writes debug images to <output-dir>_debug\ directory.")
p.add_argument('--no-output-debug', action="store_false", dest="output_debug", default=None, help="Don't writes debug images to <output-dir>_debug\ directory.") p.add_argument('--no-output-debug', action="store_false", dest="output_debug", default=None, help="Don't writes debug images to <output-dir>_debug\ directory.")
p.add_argument('--face-type', dest="face_type", choices=['half_face', 'full_face', 'whole_face', 'head', 'mark_only'], default=None) p.add_argument('--face-type', dest="face_type", choices=['half_face', 'full_face', 'whole_face', 'head', 'mark_only'], default=None)
p.add_argument('--max-faces-from-image', type=int, dest="max_faces_from_image", default=None, help="Max faces from image.")
p.add_argument('--image-size', type=int, dest="image_size", default=None, help="Output image size.")
p.add_argument('--jpeg-quality', type=int, dest="jpeg_quality", default=None, help="Jpeg quality.")
p.add_argument('--manual-fix', action="store_true", dest="manual_fix", default=False, help="Enables manual extract only frames where faces were not recognized.") p.add_argument('--manual-fix', action="store_true", dest="manual_fix", default=False, help="Enables manual extract only frames where faces were not recognized.")
p.add_argument('--manual-output-debug-fix', action="store_true", dest="manual_output_debug_fix", default=False, help="Performs manual reextract input-dir frames which were deleted from [output_dir]_debug\ dir.") p.add_argument('--manual-output-debug-fix', action="store_true", dest="manual_output_debug_fix", default=False, help="Performs manual reextract input-dir frames which were deleted from [output_dir]_debug\ dir.")
p.add_argument('--manual-window-size', type=int, dest="manual_window_size", default=1368, help="Manual fix window size. Default: 1368.") p.add_argument('--manual-window-size', type=int, dest="manual_window_size", default=1368, help="Manual fix window size. Default: 1368.")
@ -70,7 +62,7 @@ if __name__ == "__main__":
p = subparsers.add_parser( "sort", help="Sort faces in a directory.") p = subparsers.add_parser( "sort", help="Sort faces in a directory.")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
p.add_argument('--by', dest="sort_by_method", default=None, choices=("blur", "motion-blur", "face-yaw", "face-pitch", "face-source-rect-size", "hist", "hist-dissim", "brightness", "hue", "black", "origname", "oneface", "final-by-blur", "final-by-size", "absdiff"), help="Method of sorting. 'origname' sort by original filename to recover original sequence." ) p.add_argument('--by', dest="sort_by_method", default=None, choices=("blur", "face-yaw", "face-pitch", "face-source-rect-size", "hist", "hist-dissim", "brightness", "hue", "black", "origname", "oneface", "final", "final-faster", "absdiff"), help="Method of sorting. 'origname' sort by original filename to recover original sequence." )
p.set_defaults (func=process_sort) p.set_defaults (func=process_sort)
def process_util(arguments): def process_util(arguments):
@ -99,10 +91,6 @@ if __name__ == "__main__":
from samplelib import PackedFaceset from samplelib import PackedFaceset
PackedFaceset.unpack( Path(arguments.input_dir) ) PackedFaceset.unpack( Path(arguments.input_dir) )
if arguments.export_faceset_mask:
io.log_info ("Exporting faceset mask..\r\n")
Util.export_faceset_mask( Path(arguments.input_dir) )
p = subparsers.add_parser( "util", help="Utilities.") p = subparsers.add_parser( "util", help="Utilities.")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
p.add_argument('--add-landmarks-debug-images', action="store_true", dest="add_landmarks_debug_images", default=False, help="Add landmarks debug image for aligned faces.") p.add_argument('--add-landmarks-debug-images', action="store_true", dest="add_landmarks_debug_images", default=False, help="Add landmarks debug image for aligned faces.")
@ -111,7 +99,6 @@ if __name__ == "__main__":
p.add_argument('--restore-faceset-metadata', action="store_true", dest="restore_faceset_metadata", default=False, help="Restore faceset metadata to file. Image filenames must be the same as used with save.") p.add_argument('--restore-faceset-metadata', action="store_true", dest="restore_faceset_metadata", default=False, help="Restore faceset metadata to file. Image filenames must be the same as used with save.")
p.add_argument('--pack-faceset', action="store_true", dest="pack_faceset", default=False, help="") p.add_argument('--pack-faceset', action="store_true", dest="pack_faceset", default=False, help="")
p.add_argument('--unpack-faceset', action="store_true", dest="unpack_faceset", default=False, help="") p.add_argument('--unpack-faceset', action="store_true", dest="unpack_faceset", default=False, help="")
p.add_argument('--export-faceset-mask', action="store_true", dest="export_faceset_mask", default=False, help="")
p.set_defaults (func=process_util) p.set_defaults (func=process_util)
@ -150,19 +137,10 @@ if __name__ == "__main__":
p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.") p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.") p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.")
p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+') p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
p.set_defaults (func=process_train) p.set_defaults (func=process_train)
def process_exportdfm(arguments):
osex.set_process_lowest_prio()
from mainscripts import ExportDFM
ExportDFM.main(model_class_name = arguments.model_name, saved_models_path = Path(arguments.model_dir))
p = subparsers.add_parser( "exportdfm", help="Export model to use in DeepFaceLive.")
p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Saved models dir.")
p.add_argument('--model', required=True, dest="model_name", choices=pathex.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Model class name.")
p.set_defaults (func=process_exportdfm)
def process_merge(arguments): def process_merge(arguments):
osex.set_process_lowest_prio() osex.set_process_lowest_prio()
from mainscripts import Merger from mainscripts import Merger
@ -267,20 +245,10 @@ if __name__ == "__main__":
p.set_defaults(func=process_faceset_enhancer) p.set_defaults(func=process_faceset_enhancer)
p = facesettool_parser.add_parser ("resize", help="Resize DFL faceset.")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.")
def process_faceset_resizer(arguments):
osex.set_process_lowest_prio()
from mainscripts import FacesetResizer
FacesetResizer.process_folder ( Path(arguments.input_dir) )
p.set_defaults(func=process_faceset_resizer)
def process_dev_test(arguments): def process_dev_test(arguments):
osex.set_process_lowest_prio() osex.set_process_lowest_prio()
from mainscripts import dev_misc from mainscripts import dev_misc
dev_misc.dev_gen_mask_files( arguments.input_dir ) dev_misc.dev_test( arguments.input_dir )
p = subparsers.add_parser( "dev_test", help="") p = subparsers.add_parser( "dev_test", help="")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
@ -294,9 +262,7 @@ if __name__ == "__main__":
def process_xsegeditor(arguments): def process_xsegeditor(arguments):
osex.set_process_lowest_prio() osex.set_process_lowest_prio()
from XSegEditor import XSegEditor from XSegEditor import XSegEditor
global exit_code XSegEditor.start (Path(arguments.input_dir))
exit_code = XSegEditor.start (Path(arguments.input_dir))
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
p.set_defaults (func=process_xsegeditor) p.set_defaults (func=process_xsegeditor)
@ -347,10 +313,7 @@ if __name__ == "__main__":
arguments = parser.parse_args() arguments = parser.parse_args()
arguments.func(arguments) arguments.func(arguments)
if exit_code == 0: print ("Done.")
print ("Done.")
exit(exit_code)
''' '''
import code import code

View file

@ -1,22 +0,0 @@
import os
import sys
import traceback
import queue
import threading
import time
import numpy as np
import itertools
from pathlib import Path
from core import pathex
from core import imagelib
import cv2
import models
from core.interact import interact as io
def main(model_class_name, saved_models_path):
model = models.import_model(model_class_name)(
is_exporting=True,
saved_models_path=saved_models_path,
cpu_only=True)
model.export_dfm ()

View file

@ -10,7 +10,6 @@ from pathlib import Path
import cv2 import cv2
import numpy as np import numpy as np
from numpy import linalg as npla
import facelib import facelib
from core import imagelib from core import imagelib
@ -44,7 +43,6 @@ class ExtractSubprocessor(Subprocessor):
def on_initialize(self, client_dict): def on_initialize(self, client_dict):
self.type = client_dict['type'] self.type = client_dict['type']
self.image_size = client_dict['image_size'] self.image_size = client_dict['image_size']
self.jpeg_quality = client_dict['jpeg_quality']
self.face_type = client_dict['face_type'] self.face_type = client_dict['face_type']
self.max_faces_from_image = client_dict['max_faces_from_image'] self.max_faces_from_image = client_dict['max_faces_from_image']
self.device_idx = client_dict['device_idx'] self.device_idx = client_dict['device_idx']
@ -97,6 +95,9 @@ class ExtractSubprocessor(Subprocessor):
h, w, c = image.shape h, w, c = image.shape
dflimg = DFLIMG.load (filepath)
extract_from_dflimg = (h == w and (dflimg is not None and dflimg.has_data()) )
if 'rects' in self.type or self.type == 'all': if 'rects' in self.type or self.type == 'all':
data = ExtractSubprocessor.Cli.rects_stage (data=data, data = ExtractSubprocessor.Cli.rects_stage (data=data,
image=image, image=image,
@ -107,6 +108,7 @@ class ExtractSubprocessor(Subprocessor):
if 'landmarks' in self.type or self.type == 'all': if 'landmarks' in self.type or self.type == 'all':
data = ExtractSubprocessor.Cli.landmarks_stage (data=data, data = ExtractSubprocessor.Cli.landmarks_stage (data=data,
image=image, image=image,
extract_from_dflimg=extract_from_dflimg,
landmarks_extractor=self.landmarks_extractor, landmarks_extractor=self.landmarks_extractor,
rects_extractor=self.rects_extractor, rects_extractor=self.rects_extractor,
) )
@ -116,7 +118,7 @@ class ExtractSubprocessor(Subprocessor):
image=image, image=image,
face_type=self.face_type, face_type=self.face_type,
image_size=self.image_size, image_size=self.image_size,
jpeg_quality=self.jpeg_quality, extract_from_dflimg=extract_from_dflimg,
output_debug_path=self.output_debug_path, output_debug_path=self.output_debug_path,
final_output_path=self.final_output_path, final_output_path=self.final_output_path,
) )
@ -146,9 +148,7 @@ class ExtractSubprocessor(Subprocessor):
if len(rects) != 0: if len(rects) != 0:
data.rects_rotation = rot data.rects_rotation = rot
break break
if max_faces_from_image is not None and \ if max_faces_from_image != 0 and len(data.rects) > 1:
max_faces_from_image > 0 and \
len(data.rects) > 0:
data.rects = data.rects[0:max_faces_from_image] data.rects = data.rects[0:max_faces_from_image]
return data return data
@ -156,6 +156,7 @@ class ExtractSubprocessor(Subprocessor):
@staticmethod @staticmethod
def landmarks_stage(data, def landmarks_stage(data,
image, image,
extract_from_dflimg,
landmarks_extractor, landmarks_extractor,
rects_extractor, rects_extractor,
): ):
@ -170,7 +171,7 @@ class ExtractSubprocessor(Subprocessor):
elif data.rects_rotation == 270: elif data.rects_rotation == 270:
rotated_image = image.swapaxes( 0,1 )[::-1,:,:] rotated_image = image.swapaxes( 0,1 )[::-1,:,:]
data.landmarks = landmarks_extractor.extract (rotated_image, data.rects, rects_extractor if (data.landmarks_accurate) else None, is_bgr=True) data.landmarks = landmarks_extractor.extract (rotated_image, data.rects, rects_extractor if (not extract_from_dflimg and data.landmarks_accurate) else None, is_bgr=True)
if data.rects_rotation != 0: if data.rects_rotation != 0:
for i, (rect, lmrks) in enumerate(zip(data.rects, data.landmarks)): for i, (rect, lmrks) in enumerate(zip(data.rects, data.landmarks)):
new_rect, new_lmrks = rect, lmrks new_rect, new_lmrks = rect, lmrks
@ -200,7 +201,7 @@ class ExtractSubprocessor(Subprocessor):
image, image,
face_type, face_type,
image_size, image_size,
jpeg_quality, extract_from_dflimg = False,
output_debug_path=None, output_debug_path=None,
final_output_path=None, final_output_path=None,
): ):
@ -212,53 +213,72 @@ class ExtractSubprocessor(Subprocessor):
if output_debug_path is not None: if output_debug_path is not None:
debug_image = image.copy() debug_image = image.copy()
face_idx = 0 if extract_from_dflimg and len(rects) != 1:
for rect, image_landmarks in zip( rects, landmarks ): #if re-extracting from dflimg and more than 1 or zero faces detected - dont process and just copy it
if image_landmarks is None: print("extract_from_dflimg and len(rects) != 1", filepath )
continue output_filepath = final_output_path / filepath.name
if filepath != str(output_file):
shutil.copy ( str(filepath), str(output_filepath) )
data.final_output_files.append (output_filepath)
else:
face_idx = 0
for rect, image_landmarks in zip( rects, landmarks ):
rect = np.array(rect) if extract_from_dflimg and face_idx > 1:
#cannot extract more than 1 face from dflimg
break
if face_type == FaceType.MARK_ONLY: if image_landmarks is None:
image_to_face_mat = None
face_image = image
face_image_landmarks = image_landmarks
else:
image_to_face_mat = LandmarksProcessor.get_transform_mat (image_landmarks, image_size, face_type)
face_image = cv2.warpAffine(image, image_to_face_mat, (image_size, image_size), cv2.INTER_LANCZOS4)
face_image_landmarks = LandmarksProcessor.transform_points (image_landmarks, image_to_face_mat)
landmarks_bbox = LandmarksProcessor.transform_points ( [ (0,0), (0,image_size-1), (image_size-1, image_size-1), (image_size-1,0) ], image_to_face_mat, True)
rect_area = mathlib.polygon_area(np.array(rect[[0,2,2,0]]).astype(np.float32), np.array(rect[[1,1,3,3]]).astype(np.float32))
landmarks_area = mathlib.polygon_area(landmarks_bbox[:,0].astype(np.float32), landmarks_bbox[:,1].astype(np.float32) )
if not data.manual and face_type <= FaceType.FULL_NO_ALIGN and landmarks_area > 4*rect_area: #get rid of faces which umeyama-landmark-area > 4*detector-rect-area
continue continue
if output_debug_path is not None: rect = np.array(rect)
LandmarksProcessor.draw_rect_landmarks (debug_image, rect, image_landmarks, face_type, image_size, transparent_mask=True)
output_path = final_output_path if face_type == FaceType.MARK_ONLY:
if data.force_output_path is not None: image_to_face_mat = None
output_path = data.force_output_path face_image = image
face_image_landmarks = image_landmarks
else:
image_to_face_mat = LandmarksProcessor.get_transform_mat (image_landmarks, image_size, face_type)
output_filepath = output_path / f"{filepath.stem}_{face_idx}.jpg" face_image = cv2.warpAffine(image, image_to_face_mat, (image_size, image_size), cv2.INTER_LANCZOS4)
cv2_imwrite(output_filepath, face_image, [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality ] ) face_image_landmarks = LandmarksProcessor.transform_points (image_landmarks, image_to_face_mat)
dflimg = DFLJPG.load(output_filepath) landmarks_bbox = LandmarksProcessor.transform_points ( [ (0,0), (0,image_size-1), (image_size-1, image_size-1), (image_size-1,0) ], image_to_face_mat, True)
dflimg.set_face_type(FaceType.toString(face_type))
dflimg.set_landmarks(face_image_landmarks.tolist())
dflimg.set_source_filename(filepath.name)
dflimg.set_source_rect(rect)
dflimg.set_source_landmarks(image_landmarks.tolist())
dflimg.set_image_to_face_mat(image_to_face_mat)
dflimg.save()
data.final_output_files.append (output_filepath) rect_area = mathlib.polygon_area(np.array(rect[[0,2,2,0]]).astype(np.float32), np.array(rect[[1,1,3,3]]).astype(np.float32))
face_idx += 1 landmarks_area = mathlib.polygon_area(landmarks_bbox[:,0].astype(np.float32), landmarks_bbox[:,1].astype(np.float32) )
data.faces_detected = face_idx
if not data.manual and face_type <= FaceType.FULL_NO_ALIGN and landmarks_area > 4*rect_area: #get rid of faces which umeyama-landmark-area > 4*detector-rect-area
continue
if output_debug_path is not None:
LandmarksProcessor.draw_rect_landmarks (debug_image, rect, image_landmarks, face_type, image_size, transparent_mask=True)
output_path = final_output_path
if data.force_output_path is not None:
output_path = data.force_output_path
if extract_from_dflimg and filepath.suffix == '.jpg':
#if extracting from dflimg and jpg copy it in order not to lose quality
output_filepath = output_path / filepath.name
if filepath != output_filepath:
shutil.copy ( str(filepath), str(output_filepath) )
else:
output_filepath = output_path / f"{filepath.stem}_{face_idx}.jpg"
cv2_imwrite(output_filepath, face_image, [int(cv2.IMWRITE_JPEG_QUALITY), 90] )
dflimg = DFLJPG.load(output_filepath)
dflimg.set_face_type(FaceType.toString(face_type))
dflimg.set_landmarks(face_image_landmarks.tolist())
dflimg.set_source_filename(filepath.name)
dflimg.set_source_rect(rect)
dflimg.set_source_landmarks(image_landmarks.tolist())
dflimg.set_image_to_face_mat(image_to_face_mat)
dflimg.save()
data.final_output_files.append (output_filepath)
face_idx += 1
data.faces_detected = face_idx
if output_debug_path is not None: if output_debug_path is not None:
cv2_imwrite( output_debug_path / (filepath.stem+'.jpg'), debug_image, [int(cv2.IMWRITE_JPEG_QUALITY), 50] ) cv2_imwrite( output_debug_path / (filepath.stem+'.jpg'), debug_image, [int(cv2.IMWRITE_JPEG_QUALITY), 50] )
@ -304,7 +324,7 @@ class ExtractSubprocessor(Subprocessor):
elif type == 'final': elif type == 'final':
return [ (i, 'CPU', 'CPU%d' % (i), 0 ) for i in (range(min(8, multiprocessing.cpu_count())) if not DEBUG else [0]) ] return [ (i, 'CPU', 'CPU%d' % (i), 0 ) for i in (range(min(8, multiprocessing.cpu_count())) if not DEBUG else [0]) ]
def __init__(self, input_data, type, image_size=None, jpeg_quality=None, face_type=None, output_debug_path=None, manual_window_size=0, max_faces_from_image=0, final_output_path=None, device_config=None): def __init__(self, input_data, type, image_size=None, face_type=None, output_debug_path=None, manual_window_size=0, max_faces_from_image=0, final_output_path=None, device_config=None):
if type == 'landmarks-manual': if type == 'landmarks-manual':
for x in input_data: for x in input_data:
x.manual = True x.manual = True
@ -313,7 +333,6 @@ class ExtractSubprocessor(Subprocessor):
self.type = type self.type = type
self.image_size = image_size self.image_size = image_size
self.jpeg_quality = jpeg_quality
self.face_type = face_type self.face_type = face_type
self.output_debug_path = output_debug_path self.output_debug_path = output_debug_path
self.final_output_path = final_output_path self.final_output_path = final_output_path
@ -339,7 +358,6 @@ class ExtractSubprocessor(Subprocessor):
self.cache_text_lines_img = (None, None) self.cache_text_lines_img = (None, None)
self.hide_help = False self.hide_help = False
self.landmarks_accurate = True self.landmarks_accurate = True
self.force_landmarks = False
self.landmarks = None self.landmarks = None
self.x = 0 self.x = 0
@ -348,9 +366,6 @@ class ExtractSubprocessor(Subprocessor):
self.rect_locked = False self.rect_locked = False
self.extract_needed = True self.extract_needed = True
self.image = None
self.image_filepath = None
io.progress_bar (None, len (self.input_data)) io.progress_bar (None, len (self.input_data))
#override #override
@ -364,7 +379,6 @@ class ExtractSubprocessor(Subprocessor):
def process_info_generator(self): def process_info_generator(self):
base_dict = {'type' : self.type, base_dict = {'type' : self.type,
'image_size': self.image_size, 'image_size': self.image_size,
'jpeg_quality' : self.jpeg_quality,
'face_type': self.face_type, 'face_type': self.face_type,
'max_faces_from_image':self.max_faces_from_image, 'max_faces_from_image':self.max_faces_from_image,
'output_debug_path': self.output_debug_path, 'output_debug_path': self.output_debug_path,
@ -383,13 +397,26 @@ class ExtractSubprocessor(Subprocessor):
def get_data(self, host_dict): def get_data(self, host_dict):
if self.type == 'landmarks-manual': if self.type == 'landmarks-manual':
need_remark_face = False need_remark_face = False
redraw_needed = False
while len (self.input_data) > 0: while len (self.input_data) > 0:
data = self.input_data[0] data = self.input_data[0]
filepath, data_rects, data_landmarks = data.filepath, data.rects, data.landmarks filepath, data_rects, data_landmarks = data.filepath, data.rects, data.landmarks
is_frame_done = False is_frame_done = False
if self.image_filepath != filepath: if need_remark_face: # need remark image from input data that already has a marked face?
self.image_filepath = filepath need_remark_face = False
if len(data_rects) != 0: # If there was already a face then lock the rectangle to it until the mouse is clicked
self.rect = data_rects.pop()
self.landmarks = data_landmarks.pop()
data_rects.clear()
data_landmarks.clear()
redraw_needed = True
self.rect_locked = True
self.rect_size = ( self.rect[2] - self.rect[0] ) / 2
self.x = ( self.rect[0] + self.rect[2] ) / 2
self.y = ( self.rect[1] + self.rect[3] ) / 2
if len(data_rects) == 0:
if self.cache_original_image[0] == filepath: if self.cache_original_image[0] == filepath:
self.original_image = self.cache_original_image[1] self.original_image = self.cache_original_image[1]
else: else:
@ -413,8 +440,8 @@ class ExtractSubprocessor(Subprocessor):
self.text_lines_img = self.cache_text_lines_img[1] self.text_lines_img = self.cache_text_lines_img[1]
else: else:
self.text_lines_img = (imagelib.get_draw_text_lines ( self.image, sh, self.text_lines_img = (imagelib.get_draw_text_lines ( self.image, sh,
[ '[L Mouse click] - lock/unlock selection. [Mouse wheel] - change rect', [ '[Mouse click] - lock/unlock selection',
'[R Mouse Click] - manual face rectangle', '[Mouse wheel] - change rect',
'[Enter] / [Space] - confirm / skip frame', '[Enter] / [Space] - confirm / skip frame',
'[,] [.]- prev frame, next frame. [Q] - skip remaining frames', '[,] [.]- prev frame, next frame. [Q] - skip remaining frames',
'[a] - accuracy on/off (more fps)', '[a] - accuracy on/off (more fps)',
@ -423,29 +450,11 @@ class ExtractSubprocessor(Subprocessor):
self.cache_text_lines_img = (sh, self.text_lines_img) self.cache_text_lines_img = (sh, self.text_lines_img)
if need_remark_face: # need remark image from input data that already has a marked face?
need_remark_face = False
if len(data_rects) != 0: # If there was already a face then lock the rectangle to it until the mouse is clicked
self.rect = data_rects.pop()
self.landmarks = data_landmarks.pop()
data_rects.clear()
data_landmarks.clear()
self.rect_locked = True
self.rect_size = ( self.rect[2] - self.rect[0] ) / 2
self.x = ( self.rect[0] + self.rect[2] ) / 2
self.y = ( self.rect[1] + self.rect[3] ) / 2
self.redraw()
if len(data_rects) == 0:
(h,w,c) = self.image.shape
while True: while True:
io.process_messages(0.0001) io.process_messages(0.0001)
if not self.force_landmarks: new_x = self.x
new_x = self.x new_y = self.y
new_y = self.y
new_rect_size = self.rect_size new_rect_size = self.rect_size
mouse_events = io.get_mouse_events(self.wnd_name) mouse_events = io.get_mouse_events(self.wnd_name)
@ -456,19 +465,8 @@ class ExtractSubprocessor(Subprocessor):
diff = 1 if new_rect_size <= 40 else np.clip(new_rect_size / 10, 1, 10) diff = 1 if new_rect_size <= 40 else np.clip(new_rect_size / 10, 1, 10)
new_rect_size = max (5, new_rect_size + diff*mod) new_rect_size = max (5, new_rect_size + diff*mod)
elif ev == io.EVENT_LBUTTONDOWN: elif ev == io.EVENT_LBUTTONDOWN:
if self.force_landmarks: self.rect_locked = not self.rect_locked
self.x = new_x self.extract_needed = True
self.y = new_y
self.force_landmarks = False
self.rect_locked = True
self.redraw()
else:
self.rect_locked = not self.rect_locked
self.extract_needed = True
elif ev == io.EVENT_RBUTTONDOWN:
self.force_landmarks = not self.force_landmarks
if self.force_landmarks:
self.rect_locked = False
elif not self.rect_locked: elif not self.rect_locked:
new_x = np.clip (x, 0, w-1) / self.view_scale new_x = np.clip (x, 0, w-1) / self.view_scale
new_y = np.clip (y, 0, h-1) / self.view_scale new_y = np.clip (y, 0, h-1) / self.view_scale
@ -534,35 +532,11 @@ class ExtractSubprocessor(Subprocessor):
self.landmarks_accurate = not self.landmarks_accurate self.landmarks_accurate = not self.landmarks_accurate
break break
if self.force_landmarks: if self.x != new_x or \
pt2 = np.float32([new_x, new_y])
pt1 = np.float32([self.x, self.y])
pt_vec_len = npla.norm(pt2-pt1)
pt_vec = pt2-pt1
if pt_vec_len != 0:
pt_vec /= pt_vec_len
self.rect_size = pt_vec_len
self.rect = ( int(self.x-self.rect_size),
int(self.y-self.rect_size),
int(self.x+self.rect_size),
int(self.y+self.rect_size) )
if pt_vec_len > 0:
lmrks = np.concatenate ( (np.zeros ((17,2), np.float32), LandmarksProcessor.landmarks_2D), axis=0 )
lmrks -= lmrks[30:31,:]
mat = cv2.getRotationMatrix2D( (0, 0), -np.arctan2( pt_vec[1], pt_vec[0] )*180/math.pi , pt_vec_len)
mat[:, 2] += (self.x, self.y)
self.landmarks = LandmarksProcessor.transform_points(lmrks, mat )
self.redraw()
elif self.x != new_x or \
self.y != new_y or \ self.y != new_y or \
self.rect_size != new_rect_size or \ self.rect_size != new_rect_size or \
self.extract_needed: self.extract_needed or \
redraw_needed:
self.x = new_x self.x = new_x
self.y = new_y self.y = new_y
self.rect_size = new_rect_size self.rect_size = new_rect_size
@ -571,7 +545,11 @@ class ExtractSubprocessor(Subprocessor):
int(self.x+self.rect_size), int(self.x+self.rect_size),
int(self.y+self.rect_size) ) int(self.y+self.rect_size) )
return ExtractSubprocessor.Data (filepath, rects=[self.rect], landmarks_accurate=self.landmarks_accurate) if redraw_needed:
redraw_needed = False
return ExtractSubprocessor.Data (filepath, landmarks_accurate=self.landmarks_accurate)
else:
return ExtractSubprocessor.Data (filepath, rects=[self.rect], landmarks_accurate=self.landmarks_accurate)
else: else:
is_frame_done = True is_frame_done = True
@ -593,40 +571,6 @@ class ExtractSubprocessor(Subprocessor):
if not self.type != 'landmarks-manual': if not self.type != 'landmarks-manual':
self.input_data.insert(0, data) self.input_data.insert(0, data)
def redraw(self):
(h,w,c) = self.image.shape
if not self.hide_help:
image = cv2.addWeighted (self.image,1.0,self.text_lines_img,1.0,0)
else:
image = self.image.copy()
view_rect = (np.array(self.rect) * self.view_scale).astype(np.int).tolist()
view_landmarks = (np.array(self.landmarks) * self.view_scale).astype(np.int).tolist()
if self.rect_size <= 40:
scaled_rect_size = h // 3 if w > h else w // 3
p1 = (self.x - self.rect_size, self.y - self.rect_size)
p2 = (self.x + self.rect_size, self.y - self.rect_size)
p3 = (self.x - self.rect_size, self.y + self.rect_size)
wh = h if h < w else w
np1 = (w / 2 - wh / 4, h / 2 - wh / 4)
np2 = (w / 2 + wh / 4, h / 2 - wh / 4)
np3 = (w / 2 - wh / 4, h / 2 + wh / 4)
mat = cv2.getAffineTransform( np.float32([p1,p2,p3])*self.view_scale, np.float32([np1,np2,np3]) )
image = cv2.warpAffine(image, mat,(w,h) )
view_landmarks = LandmarksProcessor.transform_points (view_landmarks, mat)
landmarks_color = (255,255,0) if self.rect_locked else (0,255,0)
LandmarksProcessor.draw_rect_landmarks (image, view_rect, view_landmarks, self.face_type, self.image_size, landmarks_color=landmarks_color)
self.extract_needed = False
io.show_image (self.wnd_name, image)
#override #override
def on_result (self, host_dict, data, result): def on_result (self, host_dict, data, result):
if self.type == 'landmarks-manual': if self.type == 'landmarks-manual':
@ -635,7 +579,37 @@ class ExtractSubprocessor(Subprocessor):
if len(landmarks) != 0 and landmarks[0] is not None: if len(landmarks) != 0 and landmarks[0] is not None:
self.landmarks = landmarks[0] self.landmarks = landmarks[0]
self.redraw() (h,w,c) = self.image.shape
if not self.hide_help:
image = cv2.addWeighted (self.image,1.0,self.text_lines_img,1.0,0)
else:
image = self.image.copy()
view_rect = (np.array(self.rect) * self.view_scale).astype(np.int).tolist()
view_landmarks = (np.array(self.landmarks) * self.view_scale).astype(np.int).tolist()
if self.rect_size <= 40:
scaled_rect_size = h // 3 if w > h else w // 3
p1 = (self.x - self.rect_size, self.y - self.rect_size)
p2 = (self.x + self.rect_size, self.y - self.rect_size)
p3 = (self.x - self.rect_size, self.y + self.rect_size)
wh = h if h < w else w
np1 = (w / 2 - wh / 4, h / 2 - wh / 4)
np2 = (w / 2 + wh / 4, h / 2 - wh / 4)
np3 = (w / 2 - wh / 4, h / 2 + wh / 4)
mat = cv2.getAffineTransform( np.float32([p1,p2,p3])*self.view_scale, np.float32([np1,np2,np3]) )
image = cv2.warpAffine(image, mat,(w,h) )
view_landmarks = LandmarksProcessor.transform_points (view_landmarks, mat)
landmarks_color = (255,255,0) if self.rect_locked else (0,255,0)
LandmarksProcessor.draw_rect_landmarks (image, view_rect, view_landmarks, self.face_type, self.image_size, landmarks_color=landmarks_color)
self.extract_needed = False
io.show_image (self.wnd_name, image)
else: else:
self.result.append ( result ) self.result.append ( result )
io.progress_bar_inc(1) io.progress_bar_inc(1)
@ -712,9 +686,7 @@ def main(detector=None,
manual_output_debug_fix=False, manual_output_debug_fix=False,
manual_window_size=1368, manual_window_size=1368,
face_type='full_face', face_type='full_face',
max_faces_from_image=None, max_faces_from_image=0,
image_size=None,
jpeg_quality=None,
cpu_only = False, cpu_only = False,
force_gpu_idxs = None, force_gpu_idxs = None,
): ):
@ -723,57 +695,24 @@ def main(detector=None,
io.log_err ('Input directory not found. Please ensure it exists.') io.log_err ('Input directory not found. Please ensure it exists.')
return return
if not output_path.exists():
output_path.mkdir(parents=True, exist_ok=True)
if face_type is not None: if face_type is not None:
face_type = FaceType.fromString(face_type) face_type = FaceType.fromString(face_type)
if face_type is None: if face_type is None:
if manual_output_debug_fix: if manual_output_debug_fix and output_path.exists():
files = pathex.get_image_paths(output_path) files = pathex.get_image_paths(output_path)
if len(files) != 0: if len(files) != 0:
dflimg = DFLIMG.load(Path(files[0])) dflimg = DFLIMG.load(Path(files[0]))
if dflimg is not None and dflimg.has_data(): if dflimg is not None and dflimg.has_data():
face_type = FaceType.fromString ( dflimg.get_face_type() ) face_type = FaceType.fromString ( dflimg.get_face_type() )
input_image_paths = pathex.get_image_unique_filestem_paths(input_path, verbose_print_func=io.log_info)
output_images_paths = pathex.get_image_paths(output_path)
output_debug_path = output_path.parent / (output_path.name + '_debug')
continue_extraction = False
if not manual_output_debug_fix and len(output_images_paths) > 0:
if len(output_images_paths) > 128:
continue_extraction = io.input_bool ("Continue extraction?", True, help_message="Extraction can be continued, but you must specify the same options again.")
if len(output_images_paths) > 128 and continue_extraction:
try:
input_image_paths = input_image_paths[ [ Path(x).stem for x in input_image_paths ].index ( Path(output_images_paths[-128]).stem.split('_')[0] ) : ]
except:
io.log_err("Error in fetching the last index. Extraction cannot be continued.")
return
elif input_path != output_path:
io.input(f"\n WARNING !!! \n {output_path} contains files! \n They will be deleted. \n Press enter to continue.\n")
for filename in output_images_paths:
Path(filename).unlink()
device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(choose_only_one=detector=='manual', suggest_all_gpu=True) ) \
if not cpu_only else nn.DeviceConfig.CPU()
if face_type is None: if face_type is None:
face_type = io.input_str ("Face type", 'wf', ['f','wf','head'], help_message="Full face / whole face / head. 'Whole face' covers full area of face include forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower() face_type = io.input_str ("Face type", 'wf', ['f','wf','head'], help_message="Full face / whole face / head. 'Whole face' covers full area of face include forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower()
face_type = {'f' : FaceType.FULL, face_type = {'f' : FaceType.FULL,
'wf' : FaceType.WHOLE_FACE, 'wf' : FaceType.WHOLE_FACE,
'head' : FaceType.HEAD}[face_type] 'head' : FaceType.HEAD}[face_type]
if max_faces_from_image is None: image_size = 512 if face_type < FaceType.HEAD else 768
max_faces_from_image = io.input_int(f"Max number of faces from image", 0, help_message="If you extract a src faceset that has frames with a large number of faces, it is advisable to set max faces to 3 to speed up extraction. 0 - unlimited")
if image_size is None:
image_size = io.input_int(f"Image size", 512 if face_type < FaceType.HEAD else 768, valid_range=[256,2048], help_message="Output image size. The higher image size, the worse face-enhancer works. Use higher than 512 value only if the source image is sharp enough and the face does not need to be enhanced.")
if jpeg_quality is None:
jpeg_quality = io.input_int(f"Jpeg quality", 90, valid_range=[1,100], help_message="Jpeg quality. The higher jpeg quality the larger the output file size.")
if detector is None: if detector is None:
io.log_info ("Choose detector type.") io.log_info ("Choose detector type.")
@ -781,12 +720,25 @@ def main(detector=None,
io.log_info ("[1] manual") io.log_info ("[1] manual")
detector = {0:'s3fd', 1:'manual'}[ io.input_int("", 0, [0,1]) ] detector = {0:'s3fd', 1:'manual'}[ io.input_int("", 0, [0,1]) ]
device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(choose_only_one=detector=='manual', suggest_all_gpu=True) ) \
if not cpu_only else nn.DeviceConfig.CPU()
output_debug_path = output_path.parent / (output_path.name + '_debug')
if output_debug is None: if output_debug is None:
output_debug = io.input_bool (f"Write debug images to {output_debug_path.name}?", False) output_debug = io.input_bool (f"Write debug images to {output_debug_path.name}?", False)
if output_debug: if output_path.exists():
output_debug_path.mkdir(parents=True, exist_ok=True) if not manual_output_debug_fix and input_path != output_path:
output_images_paths = pathex.get_image_paths(output_path)
if len(output_images_paths) > 0:
io.input(f"\n WARNING !!! \n {output_path} contains files! \n They will be deleted. \n Press enter to continue.\n")
for filename in output_images_paths:
Path(filename).unlink()
else:
output_path.mkdir(parents=True, exist_ok=True)
input_path_image_paths = pathex.get_image_unique_filestem_paths(input_path, verbose_print_func=io.log_info)
if manual_output_debug_fix: if manual_output_debug_fix:
if not output_debug_path.exists(): if not output_debug_path.exists():
@ -796,30 +748,31 @@ def main(detector=None,
detector = 'manual' detector = 'manual'
io.log_info('Performing re-extract frames which were deleted from _debug directory.') io.log_info('Performing re-extract frames which were deleted from _debug directory.')
input_image_paths = DeletedFilesSearcherSubprocessor (input_image_paths, pathex.get_image_paths(output_debug_path) ).run() input_path_image_paths = DeletedFilesSearcherSubprocessor (input_path_image_paths, pathex.get_image_paths(output_debug_path) ).run()
input_image_paths = sorted (input_image_paths) input_path_image_paths = sorted (input_path_image_paths)
io.log_info('Found %d images.' % (len(input_image_paths))) io.log_info('Found %d images.' % (len(input_path_image_paths)))
else: else:
if not continue_extraction and output_debug_path.exists(): if output_debug_path.exists():
for filename in pathex.get_image_paths(output_debug_path): for filename in pathex.get_image_paths(output_debug_path):
Path(filename).unlink() Path(filename).unlink()
else:
output_debug_path.mkdir(parents=True, exist_ok=True)
images_found = len(input_image_paths) images_found = len(input_path_image_paths)
faces_detected = 0 faces_detected = 0
if images_found != 0: if images_found != 0:
if detector == 'manual': if detector == 'manual':
io.log_info ('Performing manual extract...') io.log_info ('Performing manual extract...')
data = ExtractSubprocessor ([ ExtractSubprocessor.Data(Path(filename)) for filename in input_image_paths ], 'landmarks-manual', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None, manual_window_size=manual_window_size, device_config=device_config).run() data = ExtractSubprocessor ([ ExtractSubprocessor.Data(Path(filename)) for filename in input_path_image_paths ], 'landmarks-manual', image_size, face_type, output_debug_path if output_debug else None, manual_window_size=manual_window_size, device_config=device_config).run()
io.log_info ('Performing 3rd pass...') io.log_info ('Performing 3rd pass...')
data = ExtractSubprocessor (data, 'final', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None, final_output_path=output_path, device_config=device_config).run() data = ExtractSubprocessor (data, 'final', image_size, face_type, output_debug_path if output_debug else None, final_output_path=output_path, device_config=device_config).run()
else: else:
io.log_info ('Extracting faces...') io.log_info ('Extracting faces...')
data = ExtractSubprocessor ([ ExtractSubprocessor.Data(Path(filename)) for filename in input_image_paths ], data = ExtractSubprocessor ([ ExtractSubprocessor.Data(Path(filename)) for filename in input_path_image_paths ],
'all', 'all',
image_size, image_size,
jpeg_quality,
face_type, face_type,
output_debug_path if output_debug else None, output_debug_path if output_debug else None,
max_faces_from_image=max_faces_from_image, max_faces_from_image=max_faces_from_image,
@ -834,8 +787,8 @@ def main(detector=None,
else: else:
fix_data = [ ExtractSubprocessor.Data(d.filepath) for d in data if d.faces_detected == 0 ] fix_data = [ ExtractSubprocessor.Data(d.filepath) for d in data if d.faces_detected == 0 ]
io.log_info ('Performing manual fix for %d images...' % (len(fix_data)) ) io.log_info ('Performing manual fix for %d images...' % (len(fix_data)) )
fix_data = ExtractSubprocessor (fix_data, 'landmarks-manual', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None, manual_window_size=manual_window_size, device_config=device_config).run() fix_data = ExtractSubprocessor (fix_data, 'landmarks-manual', image_size, face_type, output_debug_path if output_debug else None, manual_window_size=manual_window_size, device_config=device_config).run()
fix_data = ExtractSubprocessor (fix_data, 'final', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None, final_output_path=output_path, device_config=device_config).run() fix_data = ExtractSubprocessor (fix_data, 'final', image_size, face_type, output_debug_path if output_debug else None, final_output_path=output_path, device_config=device_config).run()
faces_detected += sum([d.faces_detected for d in fix_data]) faces_detected += sum([d.faces_detected for d in fix_data])

View file

@ -1,209 +0,0 @@
import multiprocessing
import shutil
import cv2
from core import pathex
from core.cv2ex import *
from core.interact import interact as io
from core.joblib import Subprocessor
from DFLIMG import *
from facelib import FaceType, LandmarksProcessor
class FacesetResizerSubprocessor(Subprocessor):
#override
def __init__(self, image_paths, output_dirpath, image_size, face_type=None):
self.image_paths = image_paths
self.output_dirpath = output_dirpath
self.image_size = image_size
self.face_type = face_type
self.result = []
super().__init__('FacesetResizer', FacesetResizerSubprocessor.Cli, 600)
#override
def on_clients_initialized(self):
io.progress_bar (None, len (self.image_paths))
#override
def on_clients_finalized(self):
io.progress_bar_close()
#override
def process_info_generator(self):
base_dict = {'output_dirpath':self.output_dirpath, 'image_size':self.image_size, 'face_type':self.face_type}
for device_idx in range( min(8, multiprocessing.cpu_count()) ):
client_dict = base_dict.copy()
device_name = f'CPU #{device_idx}'
client_dict['device_name'] = device_name
yield device_name, {}, client_dict
#override
def get_data(self, host_dict):
if len (self.image_paths) > 0:
return self.image_paths.pop(0)
#override
def on_data_return (self, host_dict, data):
self.image_paths.insert(0, data)
#override
def on_result (self, host_dict, data, result):
io.progress_bar_inc(1)
if result[0] == 1:
self.result +=[ (result[1], result[2]) ]
#override
def get_result(self):
return self.result
class Cli(Subprocessor.Cli):
#override
def on_initialize(self, client_dict):
self.output_dirpath = client_dict['output_dirpath']
self.image_size = client_dict['image_size']
self.face_type = client_dict['face_type']
self.log_info (f"Running on { client_dict['device_name'] }")
#override
def process_data(self, filepath):
try:
dflimg = DFLIMG.load (filepath)
if dflimg is None or not dflimg.has_data():
self.log_err (f"{filepath.name} is not a dfl image file")
else:
img = cv2_imread(filepath)
h,w = img.shape[:2]
if h != w:
raise Exception(f'w != h in {filepath}')
image_size = self.image_size
face_type = self.face_type
output_filepath = self.output_dirpath / filepath.name
if face_type is not None:
lmrks = dflimg.get_landmarks()
mat = LandmarksProcessor.get_transform_mat(lmrks, image_size, face_type)
img = cv2.warpAffine(img, mat, (image_size, image_size), flags=cv2.INTER_LANCZOS4 )
img = np.clip(img, 0, 255).astype(np.uint8)
cv2_imwrite ( str(output_filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] )
dfl_dict = dflimg.get_dict()
dflimg = DFLIMG.load (output_filepath)
dflimg.set_dict(dfl_dict)
xseg_mask = dflimg.get_xseg_mask()
if xseg_mask is not None:
xseg_res = 256
xseg_lmrks = lmrks.copy()
xseg_lmrks *= (xseg_res / w)
xseg_mat = LandmarksProcessor.get_transform_mat(xseg_lmrks, xseg_res, face_type)
xseg_mask = cv2.warpAffine(xseg_mask, xseg_mat, (xseg_res, xseg_res), flags=cv2.INTER_LANCZOS4 )
xseg_mask[xseg_mask < 0.5] = 0
xseg_mask[xseg_mask >= 0.5] = 1
dflimg.set_xseg_mask(xseg_mask)
seg_ie_polys = dflimg.get_seg_ie_polys()
for poly in seg_ie_polys.get_polys():
poly_pts = poly.get_pts()
poly_pts = LandmarksProcessor.transform_points(poly_pts, mat)
poly.set_points(poly_pts)
dflimg.set_seg_ie_polys(seg_ie_polys)
lmrks = LandmarksProcessor.transform_points(lmrks, mat)
dflimg.set_landmarks(lmrks)
image_to_face_mat = dflimg.get_image_to_face_mat()
if image_to_face_mat is not None:
image_to_face_mat = LandmarksProcessor.get_transform_mat ( dflimg.get_source_landmarks(), image_size, face_type )
dflimg.set_image_to_face_mat(image_to_face_mat)
dflimg.set_face_type( FaceType.toString(face_type) )
dflimg.save()
else:
dfl_dict = dflimg.get_dict()
scale = w / image_size
img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_LANCZOS4)
cv2_imwrite ( str(output_filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] )
dflimg = DFLIMG.load (output_filepath)
dflimg.set_dict(dfl_dict)
lmrks = dflimg.get_landmarks()
lmrks /= scale
dflimg.set_landmarks(lmrks)
seg_ie_polys = dflimg.get_seg_ie_polys()
seg_ie_polys.mult_points( 1.0 / scale)
dflimg.set_seg_ie_polys(seg_ie_polys)
image_to_face_mat = dflimg.get_image_to_face_mat()
if image_to_face_mat is not None:
face_type = FaceType.fromString ( dflimg.get_face_type() )
image_to_face_mat = LandmarksProcessor.get_transform_mat ( dflimg.get_source_landmarks(), image_size, face_type )
dflimg.set_image_to_face_mat(image_to_face_mat)
dflimg.save()
return (1, filepath, output_filepath)
except:
self.log_err (f"Exception occured while processing file {filepath}. Error: {traceback.format_exc()}")
return (0, filepath, None)
def process_folder ( dirpath):
image_size = io.input_int(f"New image size", 512, valid_range=[128,2048])
face_type = io.input_str ("Change face type", 'same', ['h','mf','f','wf','head','same']).lower()
if face_type == 'same':
face_type = None
else:
face_type = {'h' : FaceType.HALF,
'mf' : FaceType.MID_FULL,
'f' : FaceType.FULL,
'wf' : FaceType.WHOLE_FACE,
'head' : FaceType.HEAD}[face_type]
output_dirpath = dirpath.parent / (dirpath.name + '_resized')
output_dirpath.mkdir (exist_ok=True, parents=True)
dirpath_parts = '/'.join( dirpath.parts[-2:])
output_dirpath_parts = '/'.join( output_dirpath.parts[-2:] )
io.log_info (f"Resizing faceset in {dirpath_parts}")
io.log_info ( f"Processing to {output_dirpath_parts}")
output_images_paths = pathex.get_image_paths(output_dirpath)
if len(output_images_paths) > 0:
for filename in output_images_paths:
Path(filename).unlink()
image_paths = [Path(x) for x in pathex.get_image_paths( dirpath )]
result = FacesetResizerSubprocessor ( image_paths, output_dirpath, image_size, face_type).run()
is_merge = io.input_bool (f"\r\nMerge {output_dirpath_parts} to {dirpath_parts} ?", True)
if is_merge:
io.log_info (f"Copying processed files to {dirpath_parts}")
for (filepath, output_filepath) in result:
try:
shutil.copy (output_filepath, filepath)
except:
pass
io.log_info (f"Removing {output_dirpath_parts}")
shutil.rmtree(output_dirpath)

View file

@ -1,5 +1,4 @@
import math import math
import multiprocessing
import traceback import traceback
from pathlib import Path from pathlib import Path
@ -14,8 +13,7 @@ from core.joblib import MPClassFuncOnDemand, MPFunc
from core.leras import nn from core.leras import nn
from DFLIMG import DFLIMG from DFLIMG import DFLIMG
from facelib import FaceEnhancer, FaceType, LandmarksProcessor, XSegNet from facelib import FaceEnhancer, FaceType, LandmarksProcessor, XSegNet
from merger import FrameInfo, InteractiveMergerSubprocessor, MergerConfig from merger import FrameInfo, MergerConfig, InteractiveMergerSubprocessor
def main (model_class_name=None, def main (model_class_name=None,
saved_models_path=None, saved_models_path=None,
@ -49,7 +47,6 @@ def main (model_class_name=None,
model = models.import_model(model_class_name)(is_training=False, model = models.import_model(model_class_name)(is_training=False,
saved_models_path=saved_models_path, saved_models_path=saved_models_path,
force_gpu_idxs=force_gpu_idxs, force_gpu_idxs=force_gpu_idxs,
force_model_name=force_model_name,
cpu_only=cpu_only) cpu_only=cpu_only)
predictor_func, predictor_input_shape, cfg = model.get_MergerConfig() predictor_func, predictor_input_shape, cfg = model.get_MergerConfig()
@ -74,9 +71,6 @@ def main (model_class_name=None,
if not is_interactive: if not is_interactive:
cfg.ask_settings() cfg.ask_settings()
subprocess_count = io.input_int("Number of workers?", max(8, multiprocessing.cpu_count()),
valid_range=[1, multiprocessing.cpu_count()], help_message="Specify the number of threads to process. A low value may affect performance. A high value may result in memory error. The value may not be greater than CPU cores." )
input_path_image_paths = pathex.get_image_paths(input_path) input_path_image_paths = pathex.get_image_paths(input_path)
if cfg.type == MergerConfig.TYPE_MASKED: if cfg.type == MergerConfig.TYPE_MASKED:
@ -205,8 +199,7 @@ def main (model_class_name=None,
frames_root_path = input_path, frames_root_path = input_path,
output_path = output_path, output_path = output_path,
output_mask_path = output_mask_path, output_mask_path = output_mask_path,
model_iter = model.get_iter(), model_iter = model.get_iter()
subprocess_count = subprocess_count,
).run() ).run()
model.finalize() model.finalize()

View file

@ -23,9 +23,6 @@ from facelib import LandmarksProcessor
class BlurEstimatorSubprocessor(Subprocessor): class BlurEstimatorSubprocessor(Subprocessor):
class Cli(Subprocessor.Cli): class Cli(Subprocessor.Cli):
def on_initialize(self, client_dict):
self.estimate_motion_blur = client_dict['estimate_motion_blur']
#override #override
def process_data(self, data): def process_data(self, data):
filepath = Path( data[0] ) filepath = Path( data[0] )
@ -36,17 +33,7 @@ class BlurEstimatorSubprocessor(Subprocessor):
return [ str(filepath), 0 ] return [ str(filepath), 0 ]
else: else:
image = cv2_imread( str(filepath) ) image = cv2_imread( str(filepath) )
return [ str(filepath), estimate_sharpness(image) ]
face_mask = LandmarksProcessor.get_image_hull_mask (image.shape, dflimg.get_landmarks())
image = (image*face_mask).astype(np.uint8)
if self.estimate_motion_blur:
value = cv2.Laplacian(image, cv2.CV_64F, ksize=11).var()
else:
value = estimate_sharpness(image)
return [ str(filepath), value ]
#override #override
@ -55,9 +42,8 @@ class BlurEstimatorSubprocessor(Subprocessor):
return data[0] return data[0]
#override #override
def __init__(self, input_data, estimate_motion_blur=False ): def __init__(self, input_data ):
self.input_data = input_data self.input_data = input_data
self.estimate_motion_blur = estimate_motion_blur
self.img_list = [] self.img_list = []
self.trash_img_list = [] self.trash_img_list = []
super().__init__('BlurEstimator', BlurEstimatorSubprocessor.Cli, 60) super().__init__('BlurEstimator', BlurEstimatorSubprocessor.Cli, 60)
@ -76,7 +62,7 @@ class BlurEstimatorSubprocessor(Subprocessor):
io.log_info(f'Running on {cpu_count} CPUs') io.log_info(f'Running on {cpu_count} CPUs')
for i in range(cpu_count): for i in range(cpu_count):
yield 'CPU%d' % (i), {}, {'estimate_motion_blur':self.estimate_motion_blur} yield 'CPU%d' % (i), {}, {}
#override #override
def get_data(self, host_dict): def get_data(self, host_dict):
@ -114,17 +100,6 @@ def sort_by_blur(input_path):
return img_list, trash_img_list return img_list, trash_img_list
def sort_by_motion_blur(input_path):
io.log_info ("Sorting by motion blur...")
img_list = [ (filename,[]) for filename in pathex.get_image_paths(input_path) ]
img_list, trash_img_list = BlurEstimatorSubprocessor (img_list, estimate_motion_blur=True).run()
io.log_info ("Sorting...")
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
return img_list, trash_img_list
def sort_by_face_yaw(input_path): def sort_by_face_yaw(input_path):
io.log_info ("Sorting by face yaw...") io.log_info ("Sorting by face yaw...")
img_list = [] img_list = []
@ -468,12 +443,12 @@ class FinalLoaderSubprocessor(Subprocessor):
raise Exception ("Unable to load %s" % (filepath.name) ) raise Exception ("Unable to load %s" % (filepath.name) )
gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
if self.faster: if self.faster:
source_rect = dflimg.get_source_rect() source_rect = dflimg.get_source_rect()
sharpness = mathlib.polygon_area(np.array(source_rect[[0,2,2,0]]).astype(np.float32), np.array(source_rect[[1,1,3,3]]).astype(np.float32)) sharpness = mathlib.polygon_area(np.array(source_rect[[0,2,2,0]]).astype(np.float32), np.array(source_rect[[1,1,3,3]]).astype(np.float32))
else: else:
face_mask = LandmarksProcessor.get_image_hull_mask (gray.shape, dflimg.get_landmarks()) sharpness = estimate_sharpness(gray)
sharpness = estimate_sharpness( (gray[...,None]*face_mask).astype(np.uint8) )
pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks(), size=dflimg.get_shape()[1] ) pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks(), size=dflimg.get_shape()[1] )
@ -897,7 +872,6 @@ def final_process(input_path, img_list, trash_img_list):
sort_func_methods = { sort_func_methods = {
'blur': ("blur", sort_by_blur), 'blur': ("blur", sort_by_blur),
'motion-blur': ("motion_blur", sort_by_motion_blur),
'face-yaw': ("face yaw direction", sort_by_face_yaw), 'face-yaw': ("face yaw direction", sort_by_face_yaw),
'face-pitch': ("face pitch direction", sort_by_face_pitch), 'face-pitch': ("face pitch direction", sort_by_face_pitch),
'face-source-rect-size' : ("face rect size in source image", sort_by_face_source_rect_size), 'face-source-rect-size' : ("face rect size in source image", sort_by_face_source_rect_size),
@ -925,7 +899,7 @@ def main (input_path, sort_by_method=None):
io.log_info(f"[{i}] {desc}") io.log_info(f"[{i}] {desc}")
io.log_info("") io.log_info("")
id = io.input_int("", 5, valid_list=[*range(len(key_list))] ) id = io.input_int("", 4, valid_list=[*range(len(key_list))] )
sort_by_method = key_list[id] sort_by_method = key_list[id]
else: else:

View file

@ -1,5 +1,4 @@
import os import sys
import sys
import traceback import traceback
import queue import queue
import threading import threading
@ -32,7 +31,7 @@ def trainerThread (s2c, c2s, e,
try: try:
start_time = time.time() start_time = time.time()
save_interval_min = 25 save_interval_min = 15
if not training_data_src_path.exists(): if not training_data_src_path.exists():
training_data_src_path.mkdir(exist_ok=True, parents=True) training_data_src_path.mkdir(exist_ok=True, parents=True)
@ -55,7 +54,8 @@ def trainerThread (s2c, c2s, e,
force_gpu_idxs=force_gpu_idxs, force_gpu_idxs=force_gpu_idxs,
cpu_only=cpu_only, cpu_only=cpu_only,
silent_start=silent_start, silent_start=silent_start,
debug=debug) debug=debug,
)
is_reached_goal = model.is_reached_iter_goal() is_reached_goal = model.is_reached_iter_goal()
@ -120,12 +120,6 @@ def trainerThread (s2c, c2s, e,
io.log_info("Trying to do the first iteration. If an error occurs, reduce the model parameters.") io.log_info("Trying to do the first iteration. If an error occurs, reduce the model parameters.")
io.log_info("") io.log_info("")
if sys.platform[0:3] == 'win':
io.log_info("!!!")
io.log_info("Windows 10 users IMPORTANT notice. You should set this setting in order to work correctly.")
io.log_info("https://i.imgur.com/B7cmDCB.jpg")
io.log_info("!!!")
iter, iter_time = model.train_one_iter() iter, iter_time = model.train_one_iter()
loss_history = model.get_loss_history() loss_history = model.get_loss_history()
@ -164,12 +158,8 @@ def trainerThread (s2c, c2s, e,
is_reached_goal = True is_reached_goal = True
io.log_info ('You can use preview now.') io.log_info ('You can use preview now.')
need_save = False if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
while time.time() - last_save_time >= save_interval_min*60:
last_save_time += save_interval_min*60 last_save_time += save_interval_min*60
need_save = True
if not is_reached_goal and need_save:
model_save() model_save()
send_preview() send_preview()

View file

@ -64,7 +64,7 @@ def restore_faceset_metadata_folder(input_path):
img = cv2_imread (filepath) img = cv2_imread (filepath)
if img.shape != shape: if img.shape != shape:
img = cv2.resize (img, (shape[1], shape[0]), interpolation=cv2.INTER_LANCZOS4 ) img = cv2.resize (img, (shape[1], shape[0]), cv2.INTER_LANCZOS4 )
cv2_imwrite (str(filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] ) cv2_imwrite (str(filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] )
@ -159,32 +159,3 @@ def recover_original_aligned_filename(input_path):
fs.rename (fd) fs.rename (fd)
except: except:
io.log_err ('fail to rename %s' % (fs.name) ) io.log_err ('fail to rename %s' % (fs.name) )
def export_faceset_mask(input_dir):
for filename in io.progress_bar_generator(pathex.get_image_paths (input_dir), "Processing"):
filepath = Path(filename)
if '_mask' in filepath.stem:
continue
mask_filepath = filepath.parent / (filepath.stem+'_mask'+filepath.suffix)
dflimg = DFLJPG.load(filepath)
H,W,C = dflimg.shape
seg_ie_polys = dflimg.get_seg_ie_polys()
if dflimg.has_xseg_mask():
mask = dflimg.get_xseg_mask()
mask[mask < 0.5] = 0.0
mask[mask >= 0.5] = 1.0
elif seg_ie_polys.has_polys():
mask = np.zeros ((H,W,1), dtype=np.float32)
seg_ie_polys.overlay_mask(mask)
else:
raise Exception(f'no mask in file {filepath}')
cv2_imwrite(mask_filepath, (mask*255).astype(np.uint8), [int(cv2.IMWRITE_JPEG_QUALITY), 100] )

View file

@ -10,8 +10,8 @@ from core.cv2ex import *
from core.interact import interact as io from core.interact import interact as io
from core.leras import nn from core.leras import nn
from DFLIMG import * from DFLIMG import *
from facelib import XSegNet, LandmarksProcessor, FaceType from facelib import XSegNet
import pickle
def apply_xseg(input_path, model_path): def apply_xseg(input_path, model_path):
if not input_path.exists(): if not input_path.exists():
@ -20,42 +20,17 @@ def apply_xseg(input_path, model_path):
if not model_path.exists(): if not model_path.exists():
raise ValueError(f'{model_path} not found. Please ensure it exists.') raise ValueError(f'{model_path} not found. Please ensure it exists.')
face_type = None
model_dat = model_path / 'XSeg_data.dat'
if model_dat.exists():
dat = pickle.loads( model_dat.read_bytes() )
dat_options = dat.get('options', None)
if dat_options is not None:
face_type = dat_options.get('face_type', None)
if face_type is None:
face_type = io.input_str ("XSeg model face type", 'same', ['h','mf','f','wf','head','same'], help_message="Specify face type of trained XSeg model. For example if XSeg model trained as WF, but faceset is HEAD, specify WF to apply xseg only on WF part of HEAD. Default is 'same'").lower()
if face_type == 'same':
face_type = None
if face_type is not None:
face_type = {'h' : FaceType.HALF,
'mf' : FaceType.MID_FULL,
'f' : FaceType.FULL,
'wf' : FaceType.WHOLE_FACE,
'head' : FaceType.HEAD}[face_type]
io.log_info(f'Applying trained XSeg model to {input_path.name}/ folder.') io.log_info(f'Applying trained XSeg model to {input_path.name}/ folder.')
device_config = nn.DeviceConfig.ask_choose_device(choose_only_one=True) device_config = nn.DeviceConfig.ask_choose_device(choose_only_one=True)
nn.initialize(device_config) nn.initialize(device_config)
xseg = XSegNet(name='XSeg', xseg = XSegNet(name='XSeg',
load_weights=True, load_weights=True,
weights_file_root=model_path, weights_file_root=model_path,
data_format=nn.data_format, data_format=nn.data_format,
raise_on_no_model_files=True) raise_on_no_model_files=True)
xseg_res = xseg.get_resolution() res = xseg.get_resolution()
images_paths = pathex.get_image_paths(input_path, return_Path_class=True) images_paths = pathex.get_image_paths(input_path, return_Path_class=True)
@ -67,36 +42,15 @@ def apply_xseg(input_path, model_path):
img = cv2_imread(filepath).astype(np.float32) / 255.0 img = cv2_imread(filepath).astype(np.float32) / 255.0
h,w,c = img.shape h,w,c = img.shape
if w != res:
img_face_type = FaceType.fromString( dflimg.get_face_type() ) img = cv2.resize( img, (res,res), interpolation=cv2.INTER_CUBIC )
if face_type is not None and img_face_type != face_type: if len(img.shape) == 2:
lmrks = dflimg.get_source_landmarks() img = img[...,None]
fmat = LandmarksProcessor.get_transform_mat(lmrks, w, face_type)
imat = LandmarksProcessor.get_transform_mat(lmrks, w, img_face_type)
g_p = LandmarksProcessor.transform_points (np.float32([(0,0),(w,0),(0,w) ]), fmat, True)
g_p2 = LandmarksProcessor.transform_points (g_p, imat)
mat = cv2.getAffineTransform( g_p2, np.float32([(0,0),(w,0),(0,w) ]) )
img = cv2.warpAffine(img, mat, (w, w), cv2.INTER_LANCZOS4)
img = cv2.resize(img, (xseg_res, xseg_res), interpolation=cv2.INTER_LANCZOS4)
else:
if w != xseg_res:
img = cv2.resize( img, (xseg_res,xseg_res), interpolation=cv2.INTER_LANCZOS4 )
if len(img.shape) == 2:
img = img[...,None]
mask = xseg.extract(img) mask = xseg.extract(img)
if face_type is not None and img_face_type != face_type:
mask = cv2.resize(mask, (w, w), interpolation=cv2.INTER_LANCZOS4)
mask = cv2.warpAffine( mask, mat, (w,w), np.zeros( (h,w,c), dtype=np.float), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4)
mask = cv2.resize(mask, (xseg_res, xseg_res), interpolation=cv2.INTER_LANCZOS4)
mask[mask < 0.5]=0 mask[mask < 0.5]=0
mask[mask >= 0.5]=1 mask[mask >= 0.5]=1
dflimg.set_xseg_mask(mask) dflimg.set_xseg_mask(mask)
dflimg.save() dflimg.save()
@ -113,8 +67,7 @@ def fetch_xseg(input_path):
images_paths = pathex.get_image_paths(input_path, return_Path_class=True) images_paths = pathex.get_image_paths(input_path, return_Path_class=True)
files_copied = 0
files_copied = []
for filepath in io.progress_bar_generator(images_paths, "Processing"): for filepath in io.progress_bar_generator(images_paths, "Processing"):
dflimg = DFLIMG.load(filepath) dflimg = DFLIMG.load(filepath)
if dflimg is None or not dflimg.has_data(): if dflimg is None or not dflimg.has_data():
@ -124,16 +77,10 @@ def fetch_xseg(input_path):
ie_polys = dflimg.get_seg_ie_polys() ie_polys = dflimg.get_seg_ie_polys()
if ie_polys.has_polys(): if ie_polys.has_polys():
files_copied.append(filepath) files_copied += 1
shutil.copy ( str(filepath), str(output_path / filepath.name) ) shutil.copy ( str(filepath), str(output_path / filepath.name) )
io.log_info(f'Files copied: {len(files_copied)}') io.log_info(f'Files copied: {files_copied}')
is_delete = io.input_bool (f"\r\nDelete original files?", True)
if is_delete:
for filepath in files_copied:
Path(filepath).unlink()
def remove_xseg(input_path): def remove_xseg(input_path):
if not input_path.exists(): if not input_path.exists():

View file

@ -13,6 +13,7 @@ from core.joblib import Subprocessor
from core.leras import nn from core.leras import nn
from DFLIMG import * from DFLIMG import *
from facelib import FaceType, LandmarksProcessor from facelib import FaceType, LandmarksProcessor
from . import Extractor, Sorter from . import Extractor, Sorter
from .Extractor import ExtractSubprocessor from .Extractor import ExtractSubprocessor
@ -356,114 +357,25 @@ def extract_umd_csv(input_file_csv,
io.log_info ('Faces detected: %d' % (faces_detected) ) io.log_info ('Faces detected: %d' % (faces_detected) )
io.log_info ('-------------------------') io.log_info ('-------------------------')
def dev_test1(input_dir): def dev_test1(input_dir):
# LaPa dataset
image_size = 1024
face_type = FaceType.HEAD
input_path = Path(input_dir) input_path = Path(input_dir)
images_path = input_path / 'images'
if not images_path.exists:
raise ValueError('LaPa dataset: images folder not found.')
labels_path = input_path / 'labels'
if not labels_path.exists:
raise ValueError('LaPa dataset: labels folder not found.')
landmarks_path = input_path / 'landmarks'
if not landmarks_path.exists:
raise ValueError('LaPa dataset: landmarks folder not found.')
output_path = input_path / 'out' dir_names = pathex.get_all_dir_names(input_path)
if output_path.exists():
output_images_paths = pathex.get_image_paths(output_path)
if len(output_images_paths) != 0:
io.input(f"\n WARNING !!! \n {output_path} contains files! \n They will be deleted. \n Press enter to continue.\n")
for filename in output_images_paths:
Path(filename).unlink()
output_path.mkdir(parents=True, exist_ok=True)
data = [] for dir_name in io.progress_bar_generator(dir_names, desc="Processing"):
img_paths = pathex.get_image_paths (images_path) img_paths = pathex.get_image_paths (input_path / dir_name)
for filename in img_paths: for filename in img_paths:
filepath = Path(filename)
landmark_filepath = landmarks_path / (filepath.stem + '.txt')
if not landmark_filepath.exists():
raise ValueError(f'no landmarks for {filepath}')
#img = cv2_imread(filepath)
lm = landmark_filepath.read_text()
lm = lm.split('\n')
if int(lm[0]) != 106:
raise ValueError(f'wrong landmarks format in {landmark_filepath}')
lmrks = []
for i in range(106):
x,y = lm[i+1].split(' ')
x,y = float(x), float(y)
lmrks.append ( (x,y) )
lmrks = np.array(lmrks)
l,t = np.min(lmrks, 0)
r,b = np.max(lmrks, 0)
l,t,r,b = ( int(x) for x in (l,t,r,b) )
#for x, y in lmrks:
# x,y = int(x), int(y)
# cv2.circle(img, (x, y), 1, (0,255,0) , 1, lineType=cv2.LINE_AA)
#imagelib.draw_rect(img, (l,t,r,b), (0,255,0) )
data += [ ExtractSubprocessor.Data(filepath=filepath, rects=[ (l,t,r,b) ]) ]
#cv2.imshow("", img)
#cv2.waitKey(0)
if len(data) > 0:
device_config = nn.DeviceConfig.BestGPU()
io.log_info ("Performing 2nd pass...")
data = ExtractSubprocessor (data, 'landmarks', image_size, 95, face_type, device_config=device_config).run()
io.log_info ("Performing 3rd pass...")
data = ExtractSubprocessor (data, 'final', image_size, 95, face_type, final_output_path=output_path, device_config=device_config).run()
for filename in pathex.get_image_paths (output_path):
filepath = Path(filename) filepath = Path(filename)
dflimg = DFLIMG.x (filepath)
if dflimg is None:
raise ValueError
dflimg = DFLJPG.load(filepath) #dflimg.x(filename, person_name=dir_name)
src_filename = dflimg.get_source_filename()
image_to_face_mat = dflimg.get_image_to_face_mat()
label_filepath = labels_path / ( Path(src_filename).stem + '.png')
if not label_filepath.exists():
raise ValueError(f'{label_filepath} does not exist')
mask = cv2_imread(label_filepath)
#mask[mask == 10] = 0 # remove hair
mask[mask > 0] = 1
mask = cv2.warpAffine(mask, image_to_face_mat, (image_size, image_size), cv2.INTER_LINEAR)
mask = cv2.blur(mask, (3,3) )
#cv2.imshow("", (mask*255).astype(np.uint8) )
#cv2.waitKey(0)
dflimg.set_xseg_mask(mask)
dflimg.save()
import code
code.interact(local=dict(globals(), **locals()))
#import code
#code.interact(local=dict(globals(), **locals()))
def dev_resave_pngs(input_dir): def dev_resave_pngs(input_dir):
input_path = Path(input_dir) input_path = Path(input_dir)
@ -499,96 +411,3 @@ def dev_segmented_trash(input_dir):
except: except:
io.log_info ('fail to trashing %s' % (src.name) ) io.log_info ('fail to trashing %s' % (src.name) )
def dev_test(input_dir):
"""
extract FaceSynthetics dataset https://github.com/microsoft/FaceSynthetics
BACKGROUND = 0
SKIN = 1
NOSE = 2
RIGHT_EYE = 3
LEFT_EYE = 4
RIGHT_BROW = 5
LEFT_BROW = 6
RIGHT_EAR = 7
LEFT_EAR = 8
MOUTH_INTERIOR = 9
TOP_LIP = 10
BOTTOM_LIP = 11
NECK = 12
HAIR = 13
BEARD = 14
CLOTHING = 15
GLASSES = 16
HEADWEAR = 17
FACEWEAR = 18
IGNORE = 255
"""
image_size = 1024
face_type = FaceType.WHOLE_FACE
input_path = Path(input_dir)
output_path = input_path.parent / f'{input_path.name}_out'
if output_path.exists():
output_images_paths = pathex.get_image_paths(output_path)
if len(output_images_paths) != 0:
io.input(f"\n WARNING !!! \n {output_path} contains files! \n They will be deleted. \n Press enter to continue.\n")
for filename in output_images_paths:
Path(filename).unlink()
output_path.mkdir(parents=True, exist_ok=True)
data = []
for filepath in io.progress_bar_generator(pathex.get_paths(input_path), "Processing"):
if filepath.suffix == '.txt':
image_filepath = filepath.parent / f'{filepath.name.split("_")[0]}.png'
if not image_filepath.exists():
print(f'{image_filepath} does not exist, skipping')
lmrks = []
for lmrk_line in filepath.read_text().split('\n'):
if len(lmrk_line) == 0:
continue
x, y = lmrk_line.split(' ')
x, y = float(x), float(y)
lmrks.append( (x,y) )
lmrks = np.array(lmrks[:68], np.float32)
rect = LandmarksProcessor.get_rect_from_landmarks(lmrks)
data += [ ExtractSubprocessor.Data(filepath=image_filepath, rects=[rect], landmarks=[ lmrks ] ) ]
if len(data) > 0:
io.log_info ("Performing 3rd pass...")
data = ExtractSubprocessor (data, 'final', image_size, 95, face_type, final_output_path=output_path, device_config=nn.DeviceConfig.CPU()).run()
for filename in io.progress_bar_generator(pathex.get_image_paths (output_path), "Processing"):
filepath = Path(filename)
dflimg = DFLJPG.load(filepath)
src_filename = dflimg.get_source_filename()
image_to_face_mat = dflimg.get_image_to_face_mat()
seg_filepath = input_path / ( Path(src_filename).stem + '_seg.png')
if not seg_filepath.exists():
raise ValueError(f'{seg_filepath} does not exist')
seg = cv2_imread(seg_filepath)
seg_inds = np.isin(seg, [1,2,3,4,5,6,9,10,11])
seg[~seg_inds] = 0
seg[seg_inds] = 1
seg = seg.astype(np.float32)
seg = cv2.warpAffine(seg, image_to_face_mat, (image_size, image_size), cv2.INTER_LANCZOS4)
dflimg.set_xseg_mask(seg)
dflimg.save()

View file

@ -140,7 +140,7 @@ class InteractiveMergerSubprocessor(Subprocessor):
#override #override
def __init__(self, is_interactive, merger_session_filepath, predictor_func, predictor_input_shape, face_enhancer_func, xseg_256_extract_func, merger_config, frames, frames_root_path, output_path, output_mask_path, model_iter, subprocess_count=4): def __init__(self, is_interactive, merger_session_filepath, predictor_func, predictor_input_shape, face_enhancer_func, xseg_256_extract_func, merger_config, frames, frames_root_path, output_path, output_mask_path, model_iter):
if len (frames) == 0: if len (frames) == 0:
raise ValueError ("len (frames) == 0") raise ValueError ("len (frames) == 0")
@ -161,7 +161,7 @@ class InteractiveMergerSubprocessor(Subprocessor):
self.output_mask_path = output_mask_path self.output_mask_path = output_mask_path
self.model_iter = model_iter self.model_iter = model_iter
self.prefetch_frame_count = self.process_count = subprocess_count self.prefetch_frame_count = self.process_count = multiprocessing.cpu_count()
session_data = None session_data = None
if self.is_interactive and self.merger_session_filepath.exists(): if self.is_interactive and self.merger_session_filepath.exists():
@ -393,7 +393,6 @@ class InteractiveMergerSubprocessor(Subprocessor):
# unable to read? recompute then # unable to read? recompute then
cur_frame.is_done = False cur_frame.is_done = False
else: else:
image = imagelib.normalize_channels(image, 3)
image_mask = imagelib.normalize_channels(image_mask, 1) image_mask = imagelib.normalize_channels(image_mask, 1)
cur_frame.image = np.concatenate([image, image_mask], -1) cur_frame.image = np.concatenate([image, image_mask], -1)

View file

@ -1,25 +1,26 @@
import sys
import traceback import traceback
import cv2 import cv2
import numpy as np import numpy as np
from core import imagelib from core import imagelib
from core.cv2ex import *
from core.interact import interact as io
from facelib import FaceType, LandmarksProcessor from facelib import FaceType, LandmarksProcessor
from core.interact import interact as io
from core.cv2ex import *
is_windows = sys.platform[0:3] == 'win'
xseg_input_size = 256 xseg_input_size = 256
def MergeMaskedFace (predictor_func, predictor_input_shape, def MergeMaskedFace (predictor_func, predictor_input_shape,
face_enhancer_func, face_enhancer_func,
xseg_256_extract_func, xseg_256_extract_func,
cfg, frame_info, img_bgr_uint8, img_bgr, img_face_landmarks): cfg, frame_info, img_bgr_uint8, img_bgr, img_face_landmarks):
img_size = img_bgr.shape[1], img_bgr.shape[0] img_size = img_bgr.shape[1], img_bgr.shape[0]
img_face_mask_a = LandmarksProcessor.get_image_hull_mask (img_bgr.shape, img_face_landmarks) img_face_mask_a = LandmarksProcessor.get_image_hull_mask (img_bgr.shape, img_face_landmarks)
out_img = img_bgr.copy()
out_merging_mask_a = None
input_size = predictor_input_shape[0] input_size = predictor_input_shape[0]
mask_subres_size = input_size*4 mask_subres_size = input_size*4
output_size = input_size output_size = input_size
@ -54,49 +55,45 @@ def MergeMaskedFace (predictor_func, predictor_input_shape,
prd_face_bgr = np.clip(prd_face_bgr, 0, 1) prd_face_bgr = np.clip(prd_face_bgr, 0, 1)
if cfg.super_resolution_power != 0: if cfg.super_resolution_power != 0:
prd_face_mask_a_0 = cv2.resize (prd_face_mask_a_0, (output_size, output_size), interpolation=cv2.INTER_CUBIC) prd_face_mask_a_0 = cv2.resize (prd_face_mask_a_0, (output_size, output_size), cv2.INTER_CUBIC)
prd_face_dst_mask_a_0 = cv2.resize (prd_face_dst_mask_a_0, (output_size, output_size), interpolation=cv2.INTER_CUBIC) prd_face_dst_mask_a_0 = cv2.resize (prd_face_dst_mask_a_0, (output_size, output_size), cv2.INTER_CUBIC)
if cfg.mask_mode == 0: #full if cfg.mask_mode == 1: #dst
wrk_face_mask_a_0 = np.ones_like(dst_face_mask_a_0) wrk_face_mask_a_0 = cv2.resize (dst_face_mask_a_0, (output_size,output_size), cv2.INTER_CUBIC)
elif cfg.mask_mode == 1: #dst
wrk_face_mask_a_0 = cv2.resize (dst_face_mask_a_0, (output_size,output_size), interpolation=cv2.INTER_CUBIC)
elif cfg.mask_mode == 2: #learned-prd elif cfg.mask_mode == 2: #learned-prd
wrk_face_mask_a_0 = prd_face_mask_a_0 wrk_face_mask_a_0 = prd_face_mask_a_0
elif cfg.mask_mode == 3: #learned-dst elif cfg.mask_mode == 3: #learned-dst
wrk_face_mask_a_0 = prd_face_dst_mask_a_0 wrk_face_mask_a_0 = prd_face_dst_mask_a_0
elif cfg.mask_mode == 4: #learned-prd*learned-dst elif cfg.mask_mode == 4: #learned-prd*learned-dst
wrk_face_mask_a_0 = prd_face_mask_a_0*prd_face_dst_mask_a_0 wrk_face_mask_a_0 = prd_face_mask_a_0*prd_face_dst_mask_a_0
elif cfg.mask_mode == 5: #learned-prd+learned-dst elif cfg.mask_mode >= 5 and cfg.mask_mode <= 8: #XSeg modes
wrk_face_mask_a_0 = np.clip( prd_face_mask_a_0+prd_face_dst_mask_a_0, 0, 1) if cfg.mask_mode == 5 or cfg.mask_mode == 7 or cfg.mask_mode == 8:
elif cfg.mask_mode >= 6 and cfg.mask_mode <= 9: #XSeg modes
if cfg.mask_mode == 6 or cfg.mask_mode == 8 or cfg.mask_mode == 9:
# obtain XSeg-prd # obtain XSeg-prd
prd_face_xseg_bgr = cv2.resize (prd_face_bgr, (xseg_input_size,)*2, interpolation=cv2.INTER_CUBIC) prd_face_xseg_bgr = cv2.resize (prd_face_bgr, (xseg_input_size,)*2, cv2.INTER_CUBIC)
prd_face_xseg_mask = xseg_256_extract_func(prd_face_xseg_bgr) prd_face_xseg_mask = xseg_256_extract_func(prd_face_xseg_bgr)
X_prd_face_mask_a_0 = cv2.resize ( prd_face_xseg_mask, (output_size, output_size), interpolation=cv2.INTER_CUBIC) X_prd_face_mask_a_0 = cv2.resize ( prd_face_xseg_mask, (output_size, output_size), cv2.INTER_CUBIC)
if cfg.mask_mode >= 7 and cfg.mask_mode <= 9: if cfg.mask_mode >= 6 and cfg.mask_mode <= 8:
# obtain XSeg-dst # obtain XSeg-dst
xseg_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, xseg_input_size, face_type=cfg.face_type) xseg_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, xseg_input_size, face_type=cfg.face_type)
dst_face_xseg_bgr = cv2.warpAffine(img_bgr, xseg_mat, (xseg_input_size,)*2, flags=cv2.INTER_CUBIC ) dst_face_xseg_bgr = cv2.warpAffine(img_bgr, xseg_mat, (xseg_input_size,)*2, flags=cv2.INTER_CUBIC )
dst_face_xseg_mask = xseg_256_extract_func(dst_face_xseg_bgr) dst_face_xseg_mask = xseg_256_extract_func(dst_face_xseg_bgr)
X_dst_face_mask_a_0 = cv2.resize (dst_face_xseg_mask, (output_size,output_size), interpolation=cv2.INTER_CUBIC) X_dst_face_mask_a_0 = cv2.resize (dst_face_xseg_mask, (output_size,output_size), cv2.INTER_CUBIC)
if cfg.mask_mode == 6: #'XSeg-prd' if cfg.mask_mode == 5: #'XSeg-prd'
wrk_face_mask_a_0 = X_prd_face_mask_a_0 wrk_face_mask_a_0 = X_prd_face_mask_a_0
elif cfg.mask_mode == 7: #'XSeg-dst' elif cfg.mask_mode == 6: #'XSeg-dst'
wrk_face_mask_a_0 = X_dst_face_mask_a_0 wrk_face_mask_a_0 = X_dst_face_mask_a_0
elif cfg.mask_mode == 8: #'XSeg-prd*XSeg-dst' elif cfg.mask_mode == 7: #'XSeg-prd*XSeg-dst'
wrk_face_mask_a_0 = X_prd_face_mask_a_0 * X_dst_face_mask_a_0 wrk_face_mask_a_0 = X_prd_face_mask_a_0 * X_dst_face_mask_a_0
elif cfg.mask_mode == 9: #learned-prd*learned-dst*XSeg-prd*XSeg-dst elif cfg.mask_mode == 8: #learned-prd*learned-dst*XSeg-prd*XSeg-dst
wrk_face_mask_a_0 = prd_face_mask_a_0 * prd_face_dst_mask_a_0 * X_prd_face_mask_a_0 * X_dst_face_mask_a_0 wrk_face_mask_a_0 = prd_face_mask_a_0 * prd_face_dst_mask_a_0 * X_prd_face_mask_a_0 * X_dst_face_mask_a_0
wrk_face_mask_a_0[ wrk_face_mask_a_0 < (1.0/255.0) ] = 0.0 # get rid of noise wrk_face_mask_a_0[ wrk_face_mask_a_0 < (1.0/255.0) ] = 0.0 # get rid of noise
# resize to mask_subres_size # resize to mask_subres_size
if wrk_face_mask_a_0.shape[0] != mask_subres_size: if wrk_face_mask_a_0.shape[0] != mask_subres_size:
wrk_face_mask_a_0 = cv2.resize (wrk_face_mask_a_0, (mask_subres_size, mask_subres_size), interpolation=cv2.INTER_CUBIC) wrk_face_mask_a_0 = cv2.resize (wrk_face_mask_a_0, (mask_subres_size, mask_subres_size), cv2.INTER_CUBIC)
# process mask in local predicted space # process mask in local predicted space
if 'raw' not in cfg.mode: if 'raw' not in cfg.mode:
@ -130,187 +127,184 @@ def MergeMaskedFace (predictor_func, predictor_input_shape,
img_face_mask_a = cv2.warpAffine( wrk_face_mask_a_0, face_mask_output_mat, img_size, np.zeros(img_bgr.shape[0:2], dtype=np.float32), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC )[...,None] img_face_mask_a = cv2.warpAffine( wrk_face_mask_a_0, face_mask_output_mat, img_size, np.zeros(img_bgr.shape[0:2], dtype=np.float32), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC )[...,None]
img_face_mask_a = np.clip (img_face_mask_a, 0.0, 1.0) img_face_mask_a = np.clip (img_face_mask_a, 0.0, 1.0)
img_face_mask_a [ img_face_mask_a < (1.0/255.0) ] = 0.0 # get rid of noise img_face_mask_a [ img_face_mask_a < (1.0/255.0) ] = 0.0 # get rid of noise
if wrk_face_mask_a_0.shape[0] != output_size: if wrk_face_mask_a_0.shape[0] != output_size:
wrk_face_mask_a_0 = cv2.resize (wrk_face_mask_a_0, (output_size,output_size), interpolation=cv2.INTER_CUBIC) wrk_face_mask_a_0 = cv2.resize (wrk_face_mask_a_0, (output_size,output_size), cv2.INTER_CUBIC)
wrk_face_mask_a = wrk_face_mask_a_0[...,None] wrk_face_mask_a = wrk_face_mask_a_0[...,None]
wrk_face_mask_area_a = wrk_face_mask_a.copy()
wrk_face_mask_area_a[wrk_face_mask_area_a>0] = 1.0
out_img = None
out_merging_mask_a = None
if cfg.mode == 'original': if cfg.mode == 'original':
return img_bgr, img_face_mask_a return img_bgr, img_face_mask_a
elif 'raw' in cfg.mode: elif 'raw' in cfg.mode:
if cfg.mode == 'raw-rgb': if cfg.mode == 'raw-rgb':
out_img_face = cv2.warpAffine( prd_face_bgr, face_output_mat, img_size, np.empty_like(img_bgr), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC) out_img = cv2.warpAffine( prd_face_bgr, face_output_mat, img_size, out_img, cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC, cv2.BORDER_TRANSPARENT )
out_img_face_mask = cv2.warpAffine( np.ones_like(prd_face_bgr), face_output_mat, img_size, np.empty_like(img_bgr), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC)
out_img = img_bgr*(1-out_img_face_mask) + out_img_face*out_img_face_mask
out_merging_mask_a = img_face_mask_a out_merging_mask_a = img_face_mask_a
elif cfg.mode == 'raw-predict': elif cfg.mode == 'raw-predict':
out_img = prd_face_bgr out_img = prd_face_bgr
out_merging_mask_a = wrk_face_mask_a out_merging_mask_a = wrk_face_mask_a
else:
raise ValueError(f"undefined raw type {cfg.mode}")
out_img = np.clip (out_img, 0.0, 1.0 ) out_img = np.clip (out_img, 0.0, 1.0 )
else: else:
#averaging [lenx, leny, maskx, masky] by grayscale gradients of upscaled mask
ar = []
for i in range(1, 10):
maxregion = np.argwhere( img_face_mask_a > i / 10.0 )
if maxregion.size != 0:
miny,minx = maxregion.min(axis=0)[:2]
maxy,maxx = maxregion.max(axis=0)[:2]
lenx = maxx - minx
leny = maxy - miny
if min(lenx,leny) >= 4:
ar += [ [ lenx, leny] ]
# Process if the mask meets minimum size if len(ar) > 0:
maxregion = np.argwhere( img_face_mask_a >= 0.1 )
if maxregion.size != 0:
miny,minx = maxregion.min(axis=0)[:2]
maxy,maxx = maxregion.max(axis=0)[:2]
lenx = maxx - minx
leny = maxy - miny
if min(lenx,leny) >= 4:
wrk_face_mask_area_a = wrk_face_mask_a.copy()
wrk_face_mask_area_a[wrk_face_mask_area_a>0] = 1.0
if 'seamless' not in cfg.mode and cfg.color_transfer_mode != 0: if 'seamless' not in cfg.mode and cfg.color_transfer_mode != 0:
if cfg.color_transfer_mode == 1: #rct if cfg.color_transfer_mode == 1: #rct
prd_face_bgr = imagelib.reinhard_color_transfer (prd_face_bgr, dst_face_bgr, target_mask=wrk_face_mask_area_a, source_mask=wrk_face_mask_area_a) prd_face_bgr = imagelib.reinhard_color_transfer ( np.clip( prd_face_bgr*wrk_face_mask_area_a*255, 0, 255).astype(np.uint8),
elif cfg.color_transfer_mode == 2: #lct np.clip( dst_face_bgr*wrk_face_mask_area_a*255, 0, 255).astype(np.uint8), )
prd_face_bgr = imagelib.linear_color_transfer (prd_face_bgr, dst_face_bgr)
elif cfg.color_transfer_mode == 3: #mkl
prd_face_bgr = imagelib.color_transfer_mkl (prd_face_bgr, dst_face_bgr)
elif cfg.color_transfer_mode == 4: #mkl-m
prd_face_bgr = imagelib.color_transfer_mkl (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
elif cfg.color_transfer_mode == 5: #idt
prd_face_bgr = imagelib.color_transfer_idt (prd_face_bgr, dst_face_bgr)
elif cfg.color_transfer_mode == 6: #idt-m
prd_face_bgr = imagelib.color_transfer_idt (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
elif cfg.color_transfer_mode == 7: #sot-m
prd_face_bgr = imagelib.color_transfer_sot (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a, steps=10, batch_size=30)
prd_face_bgr = np.clip (prd_face_bgr, 0.0, 1.0)
elif cfg.color_transfer_mode == 8: #mix-m
prd_face_bgr = imagelib.color_transfer_mix (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
if cfg.mode == 'hist-match': prd_face_bgr = np.clip( prd_face_bgr.astype(np.float32) / 255.0, 0.0, 1.0)
hist_mask_a = np.ones ( prd_face_bgr.shape[:2] + (1,) , dtype=np.float32) elif cfg.color_transfer_mode == 2: #lct
prd_face_bgr = imagelib.linear_color_transfer (prd_face_bgr, dst_face_bgr)
elif cfg.color_transfer_mode == 3: #mkl
prd_face_bgr = imagelib.color_transfer_mkl (prd_face_bgr, dst_face_bgr)
elif cfg.color_transfer_mode == 4: #mkl-m
prd_face_bgr = imagelib.color_transfer_mkl (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
elif cfg.color_transfer_mode == 5: #idt
prd_face_bgr = imagelib.color_transfer_idt (prd_face_bgr, dst_face_bgr)
elif cfg.color_transfer_mode == 6: #idt-m
prd_face_bgr = imagelib.color_transfer_idt (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
elif cfg.color_transfer_mode == 7: #sot-m
prd_face_bgr = imagelib.color_transfer_sot (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
prd_face_bgr = np.clip (prd_face_bgr, 0.0, 1.0)
elif cfg.color_transfer_mode == 8: #mix-m
prd_face_bgr = imagelib.color_transfer_mix (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
if cfg.masked_hist_match: if cfg.mode == 'hist-match':
hist_mask_a *= wrk_face_mask_area_a hist_mask_a = np.ones ( prd_face_bgr.shape[:2] + (1,) , dtype=np.float32)
white = (1.0-hist_mask_a)* np.ones ( prd_face_bgr.shape[:2] + (1,) , dtype=np.float32) if cfg.masked_hist_match:
hist_mask_a *= wrk_face_mask_area_a
hist_match_1 = prd_face_bgr*hist_mask_a + white white = (1.0-hist_mask_a)* np.ones ( prd_face_bgr.shape[:2] + (1,) , dtype=np.float32)
hist_match_1[ hist_match_1 > 1.0 ] = 1.0
hist_match_2 = dst_face_bgr*hist_mask_a + white hist_match_1 = prd_face_bgr*hist_mask_a + white
hist_match_2[ hist_match_1 > 1.0 ] = 1.0 hist_match_1[ hist_match_1 > 1.0 ] = 1.0
prd_face_bgr = imagelib.color_hist_match(hist_match_1, hist_match_2, cfg.hist_match_threshold ).astype(dtype=np.float32) hist_match_2 = dst_face_bgr*hist_mask_a + white
hist_match_2[ hist_match_1 > 1.0 ] = 1.0
if 'seamless' in cfg.mode: prd_face_bgr = imagelib.color_hist_match(hist_match_1, hist_match_2, cfg.hist_match_threshold ).astype(dtype=np.float32)
#mask used for cv2.seamlessClone
img_face_seamless_mask_a = None
for i in range(1,10):
a = img_face_mask_a > i / 10.0
if len(np.argwhere(a)) == 0:
continue
img_face_seamless_mask_a = img_face_mask_a.copy()
img_face_seamless_mask_a[a] = 1.0
img_face_seamless_mask_a[img_face_seamless_mask_a <= i / 10.0] = 0.0
break
out_img = cv2.warpAffine( prd_face_bgr, face_output_mat, img_size, np.empty_like(img_bgr), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC ) if 'seamless' in cfg.mode:
out_img = np.clip(out_img, 0.0, 1.0) #mask used for cv2.seamlessClone
img_face_seamless_mask_a = None
for i in range(1,10):
a = img_face_mask_a > i / 10.0
if len(np.argwhere(a)) == 0:
continue
img_face_seamless_mask_a = img_face_mask_a.copy()
img_face_seamless_mask_a[a] = 1.0
img_face_seamless_mask_a[img_face_seamless_mask_a <= i / 10.0] = 0.0
break
if 'seamless' in cfg.mode: out_img = cv2.warpAffine( prd_face_bgr, face_output_mat, img_size, out_img, cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC, cv2.BORDER_TRANSPARENT )
try:
#calc same bounding rect and center point as in cv2.seamlessClone to prevent jittering (not flickering)
l,t,w,h = cv2.boundingRect( (img_face_seamless_mask_a*255).astype(np.uint8) )
s_maskx, s_masky = int(l+w/2), int(t+h/2)
out_img = cv2.seamlessClone( (out_img*255).astype(np.uint8), img_bgr_uint8, (img_face_seamless_mask_a*255).astype(np.uint8), (s_maskx,s_masky) , cv2.NORMAL_CLONE )
out_img = out_img.astype(dtype=np.float32) / 255.0
except Exception as e:
#seamlessClone may fail in some cases
e_str = traceback.format_exc()
if 'MemoryError' in e_str: out_img = np.clip(out_img, 0.0, 1.0)
raise Exception("Seamless fail: " + e_str) #reraise MemoryError in order to reprocess this data by other processes
else:
print ("Seamless fail: " + e_str)
cfg_mp = cfg.motion_blur_power / 100.0 if 'seamless' in cfg.mode:
try:
#calc same bounding rect and center point as in cv2.seamlessClone to prevent jittering (not flickering)
l,t,w,h = cv2.boundingRect( (img_face_seamless_mask_a*255).astype(np.uint8) )
s_maskx, s_masky = int(l+w/2), int(t+h/2)
out_img = cv2.seamlessClone( (out_img*255).astype(np.uint8), img_bgr_uint8, (img_face_seamless_mask_a*255).astype(np.uint8), (s_maskx,s_masky) , cv2.NORMAL_CLONE )
out_img = out_img.astype(dtype=np.float32) / 255.0
except Exception as e:
#seamlessClone may fail in some cases
e_str = traceback.format_exc()
out_img = img_bgr*(1-img_face_mask_a) + (out_img*img_face_mask_a) if 'MemoryError' in e_str:
raise Exception("Seamless fail: " + e_str) #reraise MemoryError in order to reprocess this data by other processes
if ('seamless' in cfg.mode and cfg.color_transfer_mode != 0) or \
cfg.mode == 'seamless-hist-match' or \
cfg_mp != 0 or \
cfg.blursharpen_amount != 0 or \
cfg.image_denoise_power != 0 or \
cfg.bicubic_degrade_power != 0:
out_face_bgr = cv2.warpAffine( out_img, face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC )
if 'seamless' in cfg.mode and cfg.color_transfer_mode != 0:
if cfg.color_transfer_mode == 1:
out_face_bgr = imagelib.reinhard_color_transfer (out_face_bgr, dst_face_bgr, target_mask=wrk_face_mask_area_a, source_mask=wrk_face_mask_area_a)
elif cfg.color_transfer_mode == 2: #lct
out_face_bgr = imagelib.linear_color_transfer (out_face_bgr, dst_face_bgr)
elif cfg.color_transfer_mode == 3: #mkl
out_face_bgr = imagelib.color_transfer_mkl (out_face_bgr, dst_face_bgr)
elif cfg.color_transfer_mode == 4: #mkl-m
out_face_bgr = imagelib.color_transfer_mkl (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
elif cfg.color_transfer_mode == 5: #idt
out_face_bgr = imagelib.color_transfer_idt (out_face_bgr, dst_face_bgr)
elif cfg.color_transfer_mode == 6: #idt-m
out_face_bgr = imagelib.color_transfer_idt (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
elif cfg.color_transfer_mode == 7: #sot-m
out_face_bgr = imagelib.color_transfer_sot (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a, steps=10, batch_size=30)
out_face_bgr = np.clip (out_face_bgr, 0.0, 1.0)
elif cfg.color_transfer_mode == 8: #mix-m
out_face_bgr = imagelib.color_transfer_mix (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
if cfg.mode == 'seamless-hist-match':
out_face_bgr = imagelib.color_hist_match(out_face_bgr, dst_face_bgr, cfg.hist_match_threshold)
if cfg_mp != 0:
k_size = int(frame_info.motion_power*cfg_mp)
if k_size >= 1:
k_size = np.clip (k_size+1, 2, 50)
if cfg.super_resolution_power != 0:
k_size *= 2
out_face_bgr = imagelib.LinearMotionBlur (out_face_bgr, k_size , frame_info.motion_deg)
if cfg.blursharpen_amount != 0:
out_face_bgr = imagelib.blursharpen ( out_face_bgr, cfg.sharpen_mode, 3, cfg.blursharpen_amount)
if cfg.image_denoise_power != 0:
n = cfg.image_denoise_power
while n > 0:
img_bgr_denoised = cv2.medianBlur(img_bgr, 5)
if int(n / 100) != 0:
img_bgr = img_bgr_denoised
else:
pass_power = (n % 100) / 100.0
img_bgr = img_bgr*(1.0-pass_power)+img_bgr_denoised*pass_power
n = max(n-10,0)
if cfg.bicubic_degrade_power != 0:
p = 1.0 - cfg.bicubic_degrade_power / 101.0
img_bgr_downscaled = cv2.resize (img_bgr, ( int(img_size[0]*p), int(img_size[1]*p ) ), interpolation=cv2.INTER_CUBIC)
img_bgr = cv2.resize (img_bgr_downscaled, img_size, interpolation=cv2.INTER_CUBIC)
new_out = cv2.warpAffine( out_face_bgr, face_mat, img_size, np.empty_like(img_bgr), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC )
out_img = np.clip( img_bgr*(1-img_face_mask_a) + (new_out*img_face_mask_a) , 0, 1.0 )
if cfg.color_degrade_power != 0:
out_img_reduced = imagelib.reduce_colors(out_img, 256)
if cfg.color_degrade_power == 100:
out_img = out_img_reduced
else: else:
alpha = cfg.color_degrade_power / 100.0 print ("Seamless fail: " + e_str)
out_img = (out_img*(1.0-alpha) + out_img_reduced*alpha)
out_merging_mask_a = img_face_mask_a
if out_img is None:
out_img = img_bgr.copy() out_img = img_bgr*(1-img_face_mask_a) + (out_img*img_face_mask_a)
out_face_bgr = cv2.warpAffine( out_img, face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC )
if 'seamless' in cfg.mode and cfg.color_transfer_mode != 0:
if cfg.color_transfer_mode == 1:
out_face_bgr = imagelib.reinhard_color_transfer ( np.clip(out_face_bgr*wrk_face_mask_area_a*255, 0, 255).astype(np.uint8),
np.clip(dst_face_bgr*wrk_face_mask_area_a*255, 0, 255).astype(np.uint8) )
out_face_bgr = np.clip( out_face_bgr.astype(np.float32) / 255.0, 0.0, 1.0)
elif cfg.color_transfer_mode == 2: #lct
out_face_bgr = imagelib.linear_color_transfer (out_face_bgr, dst_face_bgr)
elif cfg.color_transfer_mode == 3: #mkl
out_face_bgr = imagelib.color_transfer_mkl (out_face_bgr, dst_face_bgr)
elif cfg.color_transfer_mode == 4: #mkl-m
out_face_bgr = imagelib.color_transfer_mkl (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
elif cfg.color_transfer_mode == 5: #idt
out_face_bgr = imagelib.color_transfer_idt (out_face_bgr, dst_face_bgr)
elif cfg.color_transfer_mode == 6: #idt-m
out_face_bgr = imagelib.color_transfer_idt (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
elif cfg.color_transfer_mode == 7: #sot-m
out_face_bgr = imagelib.color_transfer_sot (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
out_face_bgr = np.clip (out_face_bgr, 0.0, 1.0)
elif cfg.color_transfer_mode == 8: #mix-m
out_face_bgr = imagelib.color_transfer_mix (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
if cfg.mode == 'seamless-hist-match':
out_face_bgr = imagelib.color_hist_match(out_face_bgr, dst_face_bgr, cfg.hist_match_threshold)
cfg_mp = cfg.motion_blur_power / 100.0
if cfg_mp != 0:
k_size = int(frame_info.motion_power*cfg_mp)
if k_size >= 1:
k_size = np.clip (k_size+1, 2, 50)
if cfg.super_resolution_power != 0:
k_size *= 2
out_face_bgr = imagelib.LinearMotionBlur (out_face_bgr, k_size , frame_info.motion_deg)
if cfg.blursharpen_amount != 0:
out_face_bgr = imagelib.blursharpen ( out_face_bgr, cfg.sharpen_mode, 3, cfg.blursharpen_amount)
if cfg.image_denoise_power != 0:
n = cfg.image_denoise_power
while n > 0:
img_bgr_denoised = cv2.medianBlur(img_bgr, 5)
if int(n / 100) != 0:
img_bgr = img_bgr_denoised
else:
pass_power = (n % 100) / 100.0
img_bgr = img_bgr*(1.0-pass_power)+img_bgr_denoised*pass_power
n = max(n-10,0)
if cfg.bicubic_degrade_power != 0:
p = 1.0 - cfg.bicubic_degrade_power / 101.0
img_bgr_downscaled = cv2.resize (img_bgr, ( int(img_size[0]*p), int(img_size[1]*p ) ), cv2.INTER_CUBIC)
img_bgr = cv2.resize (img_bgr_downscaled, img_size, cv2.INTER_CUBIC)
new_out = cv2.warpAffine( out_face_bgr, face_mat, img_size, img_bgr.copy(), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC, cv2.BORDER_TRANSPARENT )
out_img = np.clip( img_bgr*(1-img_face_mask_a) + (new_out*img_face_mask_a) , 0, 1.0 )
if cfg.color_degrade_power != 0:
out_img_reduced = imagelib.reduce_colors(out_img, 256)
if cfg.color_degrade_power == 100:
out_img = out_img_reduced
else:
alpha = cfg.color_degrade_power / 100.0
out_img = (out_img*(1.0-alpha) + out_img_reduced*alpha)
out_merging_mask_a = img_face_mask_a
return out_img, out_merging_mask_a return out_img, out_merging_mask_a

View file

@ -81,16 +81,14 @@ mode_dict = {0:'original',
mode_str_dict = { mode_dict[key] : key for key in mode_dict.keys() } mode_str_dict = { mode_dict[key] : key for key in mode_dict.keys() }
mask_mode_dict = {0:'full', mask_mode_dict = {1:'dst',
1:'dst',
2:'learned-prd', 2:'learned-prd',
3:'learned-dst', 3:'learned-dst',
4:'learned-prd*learned-dst', 4:'learned-prd*learned-dst',
5:'learned-prd+learned-dst', 5:'XSeg-prd',
6:'XSeg-prd', 6:'XSeg-dst',
7:'XSeg-dst', 7:'XSeg-prd*XSeg-dst',
8:'XSeg-prd*XSeg-dst', 8:'learned-prd*learned-dst*XSeg-prd*XSeg-dst'
9:'learned-prd*learned-dst*XSeg-prd*XSeg-dst'
} }

Binary file not shown.

Before

Width:  |  Height:  |  Size: 260 KiB

After

Width:  |  Height:  |  Size: 310 KiB

Before After
Before After

View file

@ -1,7 +1,6 @@
import colorsys import colorsys
import inspect import inspect
import json import json
import multiprocessing
import operator import operator
import os import os
import pickle import pickle
@ -13,16 +12,16 @@ from pathlib import Path
import cv2 import cv2
import numpy as np import numpy as np
from core import imagelib, pathex from core import imagelib
from core.cv2ex import *
from core.interact import interact as io from core.interact import interact as io
from core.leras import nn from core.leras import nn
from samplelib import SampleGeneratorBase from samplelib import SampleGeneratorBase
from core import pathex
from core.cv2ex import *
class ModelBase(object): class ModelBase(object):
def __init__(self, is_training=False, def __init__(self, is_training=False,
is_exporting=False,
saved_models_path=None, saved_models_path=None,
training_data_src_path=None, training_data_src_path=None,
training_data_dst_path=None, training_data_dst_path=None,
@ -37,7 +36,6 @@ class ModelBase(object):
silent_start=False, silent_start=False,
**kwargs): **kwargs):
self.is_training = is_training self.is_training = is_training
self.is_exporting = is_exporting
self.saved_models_path = saved_models_path self.saved_models_path = saved_models_path
self.training_data_src_path = training_data_src_path self.training_data_src_path = training_data_src_path
self.training_data_dst_path = training_data_dst_path self.training_data_dst_path = training_data_dst_path
@ -134,7 +132,6 @@ class ModelBase(object):
self.iter = 0 self.iter = 0
self.options = {} self.options = {}
self.options_show_override = {}
self.loss_history = [] self.loss_history = []
self.sample_for_preview = None self.sample_for_preview = None
self.choosed_gpu_indexes = None self.choosed_gpu_indexes = None
@ -187,13 +184,10 @@ class ModelBase(object):
self.write_preview_history = self.options.get('write_preview_history', False) self.write_preview_history = self.options.get('write_preview_history', False)
self.target_iter = self.options.get('target_iter',0) self.target_iter = self.options.get('target_iter',0)
self.random_flip = self.options.get('random_flip',True) self.random_flip = self.options.get('random_flip',True)
self.random_src_flip = self.options.get('random_src_flip', False)
self.random_dst_flip = self.options.get('random_dst_flip', True)
self.on_initialize() self.on_initialize()
self.options['batch_size'] = self.batch_size self.options['batch_size'] = self.batch_size
self.preview_history_writer = None
if self.is_training: if self.is_training:
self.preview_history_path = self.saved_models_path / ( f'{self.get_model_name()}_history' ) self.preview_history_path = self.saved_models_path / ( f'{self.get_model_name()}_history' )
self.autobackups_path = self.saved_models_path / ( f'{self.get_model_name()}_autobackups' ) self.autobackups_path = self.saved_models_path / ( f'{self.get_model_name()}_autobackups' )
@ -226,17 +220,15 @@ class ModelBase(object):
def update_sample_for_preview(self, choose_preview_history=False, force_new=False): def update_sample_for_preview(self, choose_preview_history=False, force_new=False):
if self.sample_for_preview is None or choose_preview_history or force_new: if self.sample_for_preview is None or choose_preview_history or force_new:
if choose_preview_history and io.is_support_windows(): if choose_preview_history and io.is_support_windows():
wnd_name = "[p] - next. [space] - switch preview type. [enter] - confirm." io.log_info ("Choose image for the preview history. [p] - next. [enter] - confirm.")
io.log_info (f"Choose image for the preview history. {wnd_name}") wnd_name = "[p] - next. [enter] - confirm."
io.named_window(wnd_name) io.named_window(wnd_name)
io.capture_keys(wnd_name) io.capture_keys(wnd_name)
choosed = False choosed = False
preview_id_counter = 0
while not choosed: while not choosed:
self.sample_for_preview = self.generate_next_samples() self.sample_for_preview = self.generate_next_samples()
previews = self.get_history_previews() preview = self.get_static_preview()
io.show_image( wnd_name, (preview*255).astype(np.uint8) )
io.show_image( wnd_name, ( previews[preview_id_counter % len(previews) ][1] *255).astype(np.uint8) )
while True: while True:
key_events = io.get_key_events(wnd_name) key_events = io.get_key_events(wnd_name)
@ -244,9 +236,6 @@ class ModelBase(object):
if key == ord('\n') or key == ord('\r'): if key == ord('\n') or key == ord('\r'):
choosed = True choosed = True
break break
elif key == ord(' '):
preview_id_counter += 1
break
elif key == ord('p'): elif key == ord('p'):
break break
@ -260,7 +249,7 @@ class ModelBase(object):
self.sample_for_preview = self.generate_next_samples() self.sample_for_preview = self.generate_next_samples()
try: try:
self.get_history_previews() self.get_static_preview()
except: except:
self.sample_for_preview = self.generate_next_samples() self.sample_for_preview = self.generate_next_samples()
@ -302,23 +291,9 @@ class ModelBase(object):
default_random_flip = self.load_or_def_option('random_flip', True) default_random_flip = self.load_or_def_option('random_flip', True)
self.options['random_flip'] = io.input_bool("Flip faces randomly", default_random_flip, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.") self.options['random_flip'] = io.input_bool("Flip faces randomly", default_random_flip, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")
def ask_random_src_flip(self): def ask_batch_size(self, suggest_batch_size=None):
default_random_src_flip = self.load_or_def_option('random_src_flip', False)
self.options['random_src_flip'] = io.input_bool("Flip SRC faces randomly", default_random_src_flip, help_message="Random horizontal flip SRC faceset. Covers more angles, but the face may look less naturally.")
def ask_random_dst_flip(self):
default_random_dst_flip = self.load_or_def_option('random_dst_flip', True)
self.options['random_dst_flip'] = io.input_bool("Flip DST faces randomly", default_random_dst_flip, help_message="Random horizontal flip DST faceset. Makes generalization of src->dst better, if src random flip is not enabled.")
def ask_batch_size(self, suggest_batch_size=None, range=None):
default_batch_size = self.load_or_def_option('batch_size', suggest_batch_size or self.batch_size) default_batch_size = self.load_or_def_option('batch_size', suggest_batch_size or self.batch_size)
self.options['batch_size'] = self.batch_size = max(0, io.input_int("Batch_size", default_batch_size, help_message="Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
batch_size = max(0, io.input_int("Batch_size", default_batch_size, valid_range=range, help_message="Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
if range is not None:
batch_size = np.clip(batch_size, range[0], range[1])
self.options['batch_size'] = self.batch_size = batch_size
#overridable #overridable
@ -349,7 +324,7 @@ class ModelBase(object):
return ( ('loss_src', 0), ('loss_dst', 0) ) return ( ('loss_src', 0), ('loss_dst', 0) )
#overridable #overridable
def onGetPreview(self, sample, for_history=False): def onGetPreview(self, sample):
#you can return multiple previews #you can return multiple previews
#return [ ('preview_name',preview_rgb), ... ] #return [ ('preview_name',preview_rgb), ... ]
return [] return []
@ -379,13 +354,8 @@ class ModelBase(object):
def get_previews(self): def get_previews(self):
return self.onGetPreview ( self.last_sample ) return self.onGetPreview ( self.last_sample )
def get_history_previews(self): def get_static_preview(self):
return self.onGetPreview (self.sample_for_preview, for_history=True) return self.onGetPreview (self.sample_for_preview)[0][1] #first preview, and bgr
def get_preview_history_writer(self):
if self.preview_history_writer is None:
self.preview_history_writer = PreviewHistoryWriter()
return self.preview_history_writer
def save(self): def save(self):
Path( self.get_summary_path() ).write_text( self.get_summary_text() ) Path( self.get_summary_path() ).write_text( self.get_summary_text() )
@ -442,8 +412,10 @@ class ModelBase(object):
name, bgr = previews[i] name, bgr = previews[i]
plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ] plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ]
if len(plist) != 0: for preview, filepath in plist:
self.get_preview_history_writer().post(plist, self.loss_history, self.iter) preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
cv2_imwrite (filepath, img )
def debug_one_iter(self): def debug_one_iter(self):
images = [] images = []
@ -464,10 +436,6 @@ class ModelBase(object):
self.last_sample = sample self.last_sample = sample
return sample return sample
#overridable
def should_save_preview_history(self):
return (not io.is_colab() and self.iter % 10 == 0) or (io.is_colab() and self.iter % 100 == 0)
def train_one_iter(self): def train_one_iter(self):
iter_time = time.time() iter_time = time.time()
@ -476,7 +444,8 @@ class ModelBase(object):
self.loss_history.append ( [float(loss[1]) for loss in losses] ) self.loss_history.append ( [float(loss[1]) for loss in losses] )
if self.should_save_preview_history(): if (not io.is_colab() and self.iter % 10 == 0) or \
(io.is_colab() and self.iter % 100 == 0):
plist = [] plist = []
if io.is_colab(): if io.is_colab():
@ -486,16 +455,12 @@ class ModelBase(object):
plist += [ (bgr, self.get_strpath_storage_for_file('preview_%s.jpg' % (name) ) ) ] plist += [ (bgr, self.get_strpath_storage_for_file('preview_%s.jpg' % (name) ) ) ]
if self.write_preview_history: if self.write_preview_history:
previews = self.get_history_previews() plist += [ (self.get_static_preview(), str (self.preview_history_path / ('%.6d.jpg' % (self.iter))) ) ]
for i in range(len(previews)):
name, bgr = previews[i]
path = self.preview_history_path / name
plist += [ ( bgr, str ( path / ( f'{self.iter:07d}.jpg') ) ) ]
if not io.is_colab():
plist += [ ( bgr, str ( path / ( '_last.jpg' ) )) ]
if len(plist) != 0: for preview, filepath in plist:
self.get_preview_history_writer().post(plist, self.loss_history, self.iter) preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
cv2_imwrite (filepath, img )
self.iter += 1 self.iter += 1
@ -545,13 +510,10 @@ class ModelBase(object):
return self.get_strpath_storage_for_file('summary.txt') return self.get_strpath_storage_for_file('summary.txt')
def get_summary_text(self): def get_summary_text(self):
visible_options = self.options.copy()
visible_options.update(self.options_show_override)
###Generate text summary of model hyperparameters ###Generate text summary of model hyperparameters
#Find the longest key name and value string. Used as column widths. #Find the longest key name and value string. Used as column widths.
width_name = max([len(k) for k in visible_options.keys()] + [17]) + 1 # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration" width_name = max([len(k) for k in self.options.keys()] + [17]) + 1 # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration"
width_value = max([len(str(x)) for x in visible_options.values()] + [len(str(self.get_iter())), len(self.get_model_name())]) + 1 # Single space buffer to right edge width_value = max([len(str(x)) for x in self.options.values()] + [len(str(self.get_iter())), len(self.get_model_name())]) + 1 # Single space buffer to right edge
if len(self.device_config.devices) != 0: #Check length of GPU names if len(self.device_config.devices) != 0: #Check length of GPU names
width_value = max([len(device.name)+1 for device in self.device_config.devices] + [width_value]) width_value = max([len(device.name)+1 for device in self.device_config.devices] + [width_value])
width_total = width_name + width_value + 2 #Plus 2 for ": " width_total = width_name + width_value + 2 #Plus 2 for ": "
@ -566,8 +528,8 @@ class ModelBase(object):
summary_text += [f'=={" Model Options ":-^{width_total}}=='] # Model options summary_text += [f'=={" Model Options ":-^{width_total}}=='] # Model options
summary_text += [f'=={" "*width_total}=='] summary_text += [f'=={" "*width_total}==']
for key in visible_options.keys(): for key in self.options.keys():
summary_text += [f'=={key: >{width_name}}: {str(visible_options[key]): <{width_value}}=='] # visible_options key/value pairs summary_text += [f'=={key: >{width_name}}: {str(self.options[key]): <{width_value}}=='] # self.options key/value pairs
summary_text += [f'=={" "*width_total}=='] summary_text += [f'=={" "*width_total}==']
summary_text += [f'=={" Running On ":-^{width_total}}=='] # Training hardware info summary_text += [f'=={" Running On ":-^{width_total}}=='] # Training hardware info
@ -645,41 +607,3 @@ class ModelBase(object):
lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image ( (last_line_b-last_line_t,w,c), lh_text, color=[0.8]*c ) lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image ( (last_line_b-last_line_t,w,c), lh_text, color=[0.8]*c )
return lh_img return lh_img
class PreviewHistoryWriter():
def __init__(self):
self.sq = multiprocessing.Queue()
self.p = multiprocessing.Process(target=self.process, args=( self.sq, ))
self.p.daemon = True
self.p.start()
def process(self, sq):
while True:
while not sq.empty():
plist, loss_history, iter = sq.get()
preview_lh_cache = {}
for preview, filepath in plist:
filepath = Path(filepath)
i = (preview.shape[1], preview.shape[2])
preview_lh = preview_lh_cache.get(i, None)
if preview_lh is None:
preview_lh = ModelBase.get_loss_history_preview(loss_history, iter, preview.shape[1], preview.shape[2])
preview_lh_cache[i] = preview_lh
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
filepath.parent.mkdir(parents=True, exist_ok=True)
cv2_imwrite (filepath, img )
time.sleep(0.01)
def post(self, plist, loss_history, iter):
self.sq.put ( (plist, loss_history, iter) )
# disable pickling
def __getstate__(self):
return dict()
def __setstate__(self, d):
self.__dict__.update(d)

View file

@ -1,725 +0,0 @@
import multiprocessing
import operator
from functools import partial
import numpy as np
from core import mathlib
from core.interact import interact as io
from core.leras import nn
from facelib import FaceType
from models import ModelBase
from samplelib import *
from core.cv2ex import *
class AMPModel(ModelBase):
#override
def on_initialize_options(self):
default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 224)
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf')
default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
default_inter_dims = self.options['inter_dims'] = self.load_or_def_option('inter_dims', 1024)
default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None)
default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None)
default_morph_factor = self.options['morph_factor'] = self.options.get('morph_factor', 0.5)
default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False)
default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False)
default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', 'n')
default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True)
default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none')
default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False)
ask_override = self.ask_override()
if self.is_first_run() or ask_override:
self.ask_autobackup_hour()
self.ask_write_preview_history()
self.ask_target_iter()
self.ask_random_src_flip()
self.ask_random_dst_flip()
self.ask_batch_size(8)
if self.is_first_run():
resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 32 .")
resolution = np.clip ( (resolution // 32) * 32, 64, 640)
self.options['resolution'] = resolution
self.options['face_type'] = io.input_str ("Face type", default_face_type, ['f','wf','head'], help_message="whole face / head").lower()
default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64)
default_d_mask_dims = default_d_dims // 3
default_d_mask_dims += default_d_mask_dims % 2
default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims)
if self.is_first_run():
self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )
self.options['inter_dims'] = np.clip ( io.input_int("Inter dimensions", default_inter_dims, add_info="32-2048", help_message="Should be equal or more than AutoEncoder dimensions. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 2048 )
e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
self.options['e_dims'] = e_dims + e_dims % 2
d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
self.options['d_dims'] = d_dims + d_dims % 2
d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 )
self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2
morph_factor = np.clip ( io.input_number ("Morph factor.", default_morph_factor, add_info="0.1 .. 0.5", help_message="Typical fine value is 0.5"), 0.1, 0.5 )
self.options['morph_factor'] = morph_factor
if self.is_first_run() or ask_override:
self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')
self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. The exact xseg mask in src and dst faceset is required.')
self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.")
default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8)
default_gan_dims = self.options['gan_dims'] = self.load_or_def_option('gan_dims', 16)
if self.is_first_run() or ask_override:
self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.")
self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.")
self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 5.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with random_warp(off), and don't disable. The higher the value, the higher the chances of artifacts. Typical fine value is 0.1"), 0.0, 5.0 )
if self.options['gan_power'] != 0.0:
gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher patch size, the higher the quality, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is resolution / 8." ), 3, 640 )
self.options['gan_patch_size'] = gan_patch_size
gan_dims = np.clip ( io.input_int("GAN dimensions", default_gan_dims, add_info="4-512", help_message="The dimensions of the GAN network. The higher dimensions, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is 16." ), 4, 512 )
self.options['gan_dims'] = gan_dims
self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. If src faceset is deverse enough, then lct mode is fine in most cases.")
self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims'])
#override
def on_initialize(self):
device_config = nn.getCurrentDeviceConfig()
devices = device_config.devices
self.model_data_format = "NCHW"
nn.initialize(data_format=self.model_data_format)
tf = nn.tf
input_ch=3
resolution = self.resolution = self.options['resolution']
e_dims = self.options['e_dims']
ae_dims = self.options['ae_dims']
inter_dims = self.inter_dims = self.options['inter_dims']
inter_res = self.inter_res = resolution // 32
d_dims = self.options['d_dims']
d_mask_dims = self.options['d_mask_dims']
face_type = self.face_type = {'f' : FaceType.FULL,
'wf' : FaceType.WHOLE_FACE,
'head' : FaceType.HEAD}[ self.options['face_type'] ]
morph_factor = self.options['morph_factor']
gan_power = self.gan_power = self.options['gan_power']
random_warp = self.options['random_warp']
blur_out_mask = self.options['blur_out_mask']
ct_mode = self.options['ct_mode']
if ct_mode == 'none':
ct_mode = None
use_fp16 = False
if self.is_exporting:
use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')
conv_dtype = tf.float16 if use_fp16 else tf.float32
class Downscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=5 ):
self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=conv_dtype)
def forward(self, x):
return tf.nn.leaky_relu(self.conv1(x), 0.1)
class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3 ):
self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
def forward(self, x):
x = nn.depth_to_space(tf.nn.leaky_relu(self.conv1(x), 0.1), 2)
return x
class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
def forward(self, inp):
x = self.conv1(inp)
x = tf.nn.leaky_relu(x, 0.2)
x = self.conv2(x)
x = tf.nn.leaky_relu(inp+x, 0.2)
return x
class Encoder(nn.ModelBase):
def on_build(self):
self.down1 = Downscale(input_ch, e_dims, kernel_size=5)
self.res1 = ResidualBlock(e_dims)
self.down2 = Downscale(e_dims, e_dims*2, kernel_size=5)
self.down3 = Downscale(e_dims*2, e_dims*4, kernel_size=5)
self.down4 = Downscale(e_dims*4, e_dims*8, kernel_size=5)
self.down5 = Downscale(e_dims*8, e_dims*8, kernel_size=5)
self.res5 = ResidualBlock(e_dims*8)
self.dense1 = nn.Dense( (( resolution//(2**5) )**2) * e_dims*8, ae_dims )
def forward(self, x):
if use_fp16:
x = tf.cast(x, tf.float16)
x = self.down1(x)
x = self.res1(x)
x = self.down2(x)
x = self.down3(x)
x = self.down4(x)
x = self.down5(x)
x = self.res5(x)
if use_fp16:
x = tf.cast(x, tf.float32)
x = nn.pixel_norm(nn.flatten(x), axes=-1)
x = self.dense1(x)
return x
class Inter(nn.ModelBase):
def on_build(self):
self.dense2 = nn.Dense(ae_dims, inter_res * inter_res * inter_dims)
def forward(self, inp):
x = inp
x = self.dense2(x)
x = nn.reshape_4D (x, inter_res, inter_res, inter_dims)
return x
class Decoder(nn.ModelBase):
def on_build(self ):
self.upscale0 = Upscale(inter_dims, d_dims*8, kernel_size=3)
self.upscale1 = Upscale(d_dims*8, d_dims*8, kernel_size=3)
self.upscale2 = Upscale(d_dims*8, d_dims*4, kernel_size=3)
self.upscale3 = Upscale(d_dims*4, d_dims*2, kernel_size=3)
self.res0 = ResidualBlock(d_dims*8, kernel_size=3)
self.res1 = ResidualBlock(d_dims*8, kernel_size=3)
self.res2 = ResidualBlock(d_dims*4, kernel_size=3)
self.res3 = ResidualBlock(d_dims*2, kernel_size=3)
self.upscalem0 = Upscale(inter_dims, d_mask_dims*8, kernel_size=3)
self.upscalem1 = Upscale(d_mask_dims*8, d_mask_dims*8, kernel_size=3)
self.upscalem2 = Upscale(d_mask_dims*8, d_mask_dims*4, kernel_size=3)
self.upscalem3 = Upscale(d_mask_dims*4, d_mask_dims*2, kernel_size=3)
self.upscalem4 = Upscale(d_mask_dims*2, d_mask_dims*1, kernel_size=3)
self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
def forward(self, z):
if use_fp16:
z = tf.cast(z, tf.float16)
x = self.upscale0(z)
x = self.res0(x)
x = self.upscale1(x)
x = self.res1(x)
x = self.upscale2(x)
x = self.res2(x)
x = self.upscale3(x)
x = self.res3(x)
x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x),
self.out_conv1(x),
self.out_conv2(x),
self.out_conv3(x)), nn.conv2d_ch_axis), 2) )
m = self.upscalem0(z)
m = self.upscalem1(m)
m = self.upscalem2(m)
m = self.upscalem3(m)
m = self.upscalem4(m)
m = tf.nn.sigmoid(self.out_convm(m))
if use_fp16:
x = tf.cast(x, tf.float32)
m = tf.cast(m, tf.float32)
return x, m
models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0'
optimizer_vars_on_cpu = models_opt_device=='/CPU:0'
bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
mask_shape = nn.get4Dshape(resolution,resolution,1)
self.model_filename_list = []
with tf.device ('/CPU:0'):
#Place holders on CPU
self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src')
self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst')
self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src')
self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst')
self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm')
self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em')
self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm')
self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em')
self.morph_value_t = tf.placeholder (nn.floatx, (1,), name='morph_value_t')
# Initializing model classes
with tf.device (models_opt_device):
self.encoder = Encoder(name='encoder')
self.inter_src = Inter(name='inter_src')
self.inter_dst = Inter(name='inter_dst')
self.decoder = Decoder(name='decoder')
self.model_filename_list += [ [self.encoder, 'encoder.npy'],
[self.inter_src, 'inter_src.npy'],
[self.inter_dst , 'inter_dst.npy'],
[self.decoder , 'decoder.npy'] ]
if self.is_training:
# Initialize optimizers
clipnorm = 1.0 if self.options['clipgrad'] else 0.0
if self.options['lr_dropout'] in ['y','cpu']:
lr_cos = 500
lr_dropout = 0.3
else:
lr_cos = 0
lr_dropout = 1.0
self.G_weights = self.encoder.get_weights() + self.decoder.get_weights()
self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='src_dst_opt')
self.src_dst_opt.initialize_variables (self.G_weights, vars_on_cpu=optimizer_vars_on_cpu)
self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
if gan_power != 0:
self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN")
self.GAN_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='GAN_opt')
self.GAN_opt.initialize_variables ( self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
self.model_filename_list += [ [self.GAN, 'GAN.npy'],
[self.GAN_opt, 'GAN_opt.npy'] ]
if self.is_training:
# Adjust batch size for multiple GPU
gpu_count = max(1, len(devices) )
bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
self.set_batch_size( gpu_count*bs_per_gpu)
# Compute losses per GPU
gpu_pred_src_src_list = []
gpu_pred_dst_dst_list = []
gpu_pred_src_dst_list = []
gpu_pred_src_srcm_list = []
gpu_pred_dst_dstm_list = []
gpu_pred_src_dstm_list = []
gpu_src_losses = []
gpu_dst_losses = []
gpu_G_loss_gradients = []
gpu_GAN_loss_gradients = []
def DLossOnes(logits):
return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits), axis=[1,2,3])
def DLossZeros(logits):
return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1,2,3])
for gpu_id in range(gpu_count):
with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
with tf.device(f'/CPU:0'):
# slice on CPU, otherwise all batch data will be transfered to GPU first
batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
gpu_warped_src = self.warped_src [batch_slice,:,:,:]
gpu_warped_dst = self.warped_dst [batch_slice,:,:,:]
gpu_target_src = self.target_src [batch_slice,:,:,:]
gpu_target_dst = self.target_dst [batch_slice,:,:,:]
gpu_target_srcm = self.target_srcm[batch_slice,:,:,:]
gpu_target_srcm_em = self.target_srcm_em[batch_slice,:,:,:]
gpu_target_dstm = self.target_dstm[batch_slice,:,:,:]
gpu_target_dstm_em = self.target_dstm_em[batch_slice,:,:,:]
# process model tensors
gpu_src_code = self.encoder (gpu_warped_src)
gpu_dst_code = self.encoder (gpu_warped_dst)
gpu_src_inter_src_code, gpu_src_inter_dst_code = self.inter_src (gpu_src_code), self.inter_dst (gpu_src_code)
gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code)
inter_dims_bin = int(inter_dims*morph_factor)
with tf.device(f'/CPU:0'):
inter_rnd_binomial = tf.stack([tf.random.shuffle(tf.concat([tf.tile(tf.constant([1], tf.float32), ( inter_dims_bin, )),
tf.tile(tf.constant([0], tf.float32), ( inter_dims-inter_dims_bin, ))], 0 )) for _ in range(bs_per_gpu)], 0)
inter_rnd_binomial = tf.stop_gradient(inter_rnd_binomial[...,None,None])
gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial)
gpu_dst_code = gpu_dst_inter_dst_code
inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32)
gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]),
tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 )
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)
gpu_target_srcm_anti = 1-gpu_target_srcm
gpu_target_dstm_anti = 1-gpu_target_dstm
gpu_target_srcm_gblur = nn.gaussian_blur(gpu_target_srcm, resolution // 32)
gpu_target_dstm_gblur = nn.gaussian_blur(gpu_target_dstm, resolution // 32)
gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_gblur, 0, 0.5) * 2
gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_gblur, 0, 0.5) * 2
gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur
gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur
if blur_out_mask:
sigma = resolution / 128
x = nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, sigma)
y = 1-nn.gaussian_blur(gpu_target_srcm, sigma)
y = tf.where(tf.equal(y, 0), tf.ones_like(y), y)
gpu_target_src = gpu_target_src*gpu_target_srcm + (x/y)*gpu_target_srcm_anti
x = nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, sigma)
y = 1-nn.gaussian_blur(gpu_target_dstm, sigma)
y = tf.where(tf.equal(y, 0), tf.ones_like(y), y)
gpu_target_dst = gpu_target_dst*gpu_target_dstm + (x/y)*gpu_target_dstm_anti
gpu_target_src_masked = gpu_target_src*gpu_target_srcm_blur
gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur
gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur
gpu_pred_src_src_masked = gpu_pred_src_src*gpu_target_srcm_blur
gpu_pred_dst_dst_masked = gpu_pred_dst_dst*gpu_target_dstm_blur
gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur
gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur
# Structural loss
gpu_src_loss = tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
gpu_dst_loss = tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
gpu_dst_loss += tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1])
# Pixel loss
gpu_src_loss += tf.reduce_mean (10*tf.square(gpu_target_src_masked-gpu_pred_src_src_masked), axis=[1,2,3])
gpu_dst_loss += tf.reduce_mean (10*tf.square(gpu_target_dst_masked-gpu_pred_dst_dst_masked), axis=[1,2,3])
# Eyes+mouth prio loss
gpu_src_loss += tf.reduce_mean (300*tf.abs (gpu_target_src*gpu_target_srcm_em-gpu_pred_src_src*gpu_target_srcm_em), axis=[1,2,3])
gpu_dst_loss += tf.reduce_mean (300*tf.abs (gpu_target_dst*gpu_target_dstm_em-gpu_pred_dst_dst*gpu_target_dstm_em), axis=[1,2,3])
# Mask loss
gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )
gpu_src_losses += [gpu_src_loss]
gpu_dst_losses += [gpu_dst_loss]
gpu_G_loss = gpu_src_loss + gpu_dst_loss
# dst-dst background weak loss
gpu_G_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] )
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked)
if gan_power != 0:
gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked)
gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked)
gpu_target_src_d, gpu_target_src_d2 = self.GAN(gpu_target_src_masked)
gpu_target_dst_d, gpu_target_dst_d2 = self.GAN(gpu_target_dst_masked)
gpu_GAN_loss = (DLossOnes (gpu_target_src_d) + DLossOnes (gpu_target_src_d2) + \
DLossZeros(gpu_pred_src_src_d) + DLossZeros(gpu_pred_src_src_d2) + \
DLossOnes (gpu_target_dst_d) + DLossOnes (gpu_target_dst_d2) + \
DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2)
) * (1.0 / 8)
gpu_GAN_loss_gradients += [ nn.gradients (gpu_GAN_loss, self.GAN.get_weights() ) ]
gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) + \
DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2)
) * gan_power
# Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] )
gpu_G_loss_gradients += [ nn.gradients ( gpu_G_loss, self.G_weights ) ]
# Average losses and gradients, and create optimizer update ops
with tf.device(f'/CPU:0'):
pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)
with tf.device (models_opt_device):
src_loss = tf.concat(gpu_src_losses, 0)
dst_loss = tf.concat(gpu_dst_losses, 0)
train_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gradients))
if gan_power != 0:
GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_gradients) )
# Initializing training and view functions
def train(warped_src, target_src, target_srcm, target_srcm_em, \
warped_dst, target_dst, target_dstm, target_dstm_em, ):
s, d, _ = nn.tf_sess.run ([src_loss, dst_loss, train_op],
feed_dict={self.warped_src :warped_src,
self.target_src :target_src,
self.target_srcm:target_srcm,
self.target_srcm_em:target_srcm_em,
self.warped_dst :warped_dst,
self.target_dst :target_dst,
self.target_dstm:target_dstm,
self.target_dstm_em:target_dstm_em,
})
return s, d
self.train = train
if gan_power != 0:
def GAN_train(warped_src, target_src, target_srcm, target_srcm_em, \
warped_dst, target_dst, target_dstm, target_dstm_em, ):
nn.tf_sess.run ([GAN_train_op], feed_dict={self.warped_src :warped_src,
self.target_src :target_src,
self.target_srcm:target_srcm,
self.target_srcm_em:target_srcm_em,
self.warped_dst :warped_dst,
self.target_dst :target_dst,
self.target_dstm:target_dstm,
self.target_dstm_em:target_dstm_em})
self.GAN_train = GAN_train
def AE_view(warped_src, warped_dst, morph_value):
return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
feed_dict={self.warped_src:warped_src, self.warped_dst:warped_dst, self.morph_value_t:[morph_value] })
self.AE_view = AE_view
else:
#Initializing merge function
with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'):
gpu_dst_code = self.encoder (self.warped_dst)
gpu_dst_inter_src_code = self.inter_src (gpu_dst_code)
gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code)
inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32)
gpu_src_dst_code = tf.concat( ( tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]),
tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 )
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
_, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code)
def AE_merge(warped_dst, morph_value):
return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst, self.morph_value_t:[morph_value] })
self.AE_merge = AE_merge
# Loading/initializing all models/optimizers weights
for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
do_init = self.is_first_run()
if self.is_training and gan_power != 0 and model == self.GAN:
if self.gan_model_changed:
do_init = True
if not do_init:
do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
if do_init:
model.init_weights()
###############
# initializing sample generators
if self.is_training:
training_data_src_path = self.training_data_src_path #if not self.pretrain else self.get_pretraining_data_path()
training_data_dst_path = self.training_data_dst_path #if not self.pretrain else self.get_pretraining_data_path()
random_ct_samples_path=training_data_dst_path if ct_mode is not None else None #and not self.pretrain
cpu_count = multiprocessing.cpu_count()
src_generators_count = cpu_count // 2
dst_generators_count = cpu_count // 2
if ct_mode is not None:
src_generators_count = int(src_generators_count * 1.5)
self.set_training_data_generators ([
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=self.random_src_flip),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
],
uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain,
generators_count=src_generators_count ),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=self.random_dst_flip),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
],
uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain,
generators_count=dst_generators_count )
])
def export_dfm (self):
output_path=self.get_strpath_storage_for_file('model.dfm')
io.log_info(f'Dumping .dfm to {output_path}')
tf = nn.tf
with tf.device (nn.tf_default_device_name):
warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
warped_dst = tf.transpose(warped_dst, (0,3,1,2))
morph_value = tf.placeholder (nn.floatx, (1,), name='morph_value')
gpu_dst_code = self.encoder (warped_dst)
gpu_dst_inter_src_code = self.inter_src ( gpu_dst_code)
gpu_dst_inter_dst_code = self.inter_dst ( gpu_dst_code)
inter_dims_slice = tf.cast(self.inter_dims*morph_value[0], tf.int32)
gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , self.inter_res, self.inter_res]),
tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,self.inter_dims-inter_dims_slice, self.inter_res,self.inter_res]) ), 1 )
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
_, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code)
gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1))
gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1))
gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1))
tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
tf.identity(gpu_pred_src_dst, name='out_celeb_face')
tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')
output_graph_def = tf.graph_util.convert_variables_to_constants(
nn.tf_sess,
tf.get_default_graph().as_graph_def(),
['out_face_mask','out_celeb_face','out_celeb_face_mask']
)
import tf2onnx
with tf.device("/CPU:0"):
model_proto, _ = tf2onnx.convert._convert_common(
output_graph_def,
name='AMP',
input_names=['in_face:0','morph_value:0'],
output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'],
opset=12,
output_path=output_path)
#override
def get_model_filename_list(self):
return self.model_filename_list
#override
def onSave(self):
for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False):
model.save_weights ( self.get_strpath_storage_for_file(filename) )
#override
def should_save_preview_history(self):
return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \
(io.is_colab() and self.iter % 100 == 0)
#override
def onTrainOneIter(self):
bs = self.get_batch_size()
( (warped_src, target_src, target_srcm, target_srcm_em), \
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()
src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
if self.gan_power != 0:
self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
#override
def onGetPreview(self, samples, for_history=False):
( (warped_src, target_src, target_srcm, target_srcm_em),
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples
S, D, SS, DD, DDM_000, _, _ = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst, 0.0) ) ]
_, _, DDM_025, SD_025, SDM_025 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.25) ]
_, _, DDM_050, SD_050, SDM_050 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.50) ]
_, _, DDM_065, SD_065, SDM_065 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.65) ]
_, _, DDM_075, SD_075, SDM_075 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.75) ]
_, _, DDM_100, SD_100, SDM_100 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 1.00) ]
(DDM_000,
DDM_025, SDM_025,
DDM_050, SDM_050,
DDM_065, SDM_065,
DDM_075, SDM_075,
DDM_100, SDM_100) = [ np.repeat (x, (3,), -1) for x in (DDM_000,
DDM_025, SDM_025,
DDM_050, SDM_050,
DDM_065, SDM_065,
DDM_075, SDM_075,
DDM_100, SDM_100) ]
target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )]
n_samples = min(4, self.get_batch_size(), 800 // self.resolution )
result = []
i = np.random.randint(n_samples) if not for_history else 0
st = [ np.concatenate ((S[i], D[i], DD[i]*DDM_000[i]), axis=1) ]
st += [ np.concatenate ((SS[i], DD[i], SD_100[i] ), axis=1) ]
result += [ ('AMP morph 1.0', np.concatenate (st, axis=0 )), ]
st = [ np.concatenate ((DD[i], SD_025[i], SD_050[i]), axis=1) ]
st += [ np.concatenate ((SD_065[i], SD_075[i], SD_100[i]), axis=1) ]
result += [ ('AMP morph list', np.concatenate (st, axis=0 )), ]
st = [ np.concatenate ((DD[i], SD_025[i]*DDM_025[i]*SDM_025[i], SD_050[i]*DDM_050[i]*SDM_050[i]), axis=1) ]
st += [ np.concatenate ((SD_065[i]*DDM_065[i]*SDM_065[i], SD_075[i]*DDM_075[i]*SDM_075[i], SD_100[i]*DDM_100[i]*SDM_100[i]), axis=1) ]
result += [ ('AMP morph list masked', np.concatenate (st, axis=0 )), ]
return result
def predictor_func (self, face, morph_value):
face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC")
bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face, morph_value) ]
return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0]
#override
def get_MergerConfig(self):
morph_factor = np.clip ( io.input_number ("Morph factor", 1.0, add_info="0.0 .. 1.0"), 0.0, 1.0 )
def predictor_morph(face):
return self.predictor_func(face, morph_factor)
import merger
return predictor_morph, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay')
Model = AMPModel

View file

@ -1 +0,0 @@
from .Model import Model

View file

@ -22,16 +22,15 @@ class QModel(ModelBase):
resolution = self.resolution = 96 resolution = self.resolution = 96
self.face_type = FaceType.FULL self.face_type = FaceType.FULL
ae_dims = 128 ae_dims = 128
e_dims = 64 e_dims = 128
d_dims = 64 d_dims = 64
d_mask_dims = 16
self.pretrain = False self.pretrain = False
self.pretrain_just_disabled = False self.pretrain_just_disabled = False
masked_training = True masked_training = True
models_opt_on_gpu = len(devices) >= 1 and all([dev.total_mem_gb >= 4 for dev in devices]) models_opt_on_gpu = len(devices) >= 1 and all([dev.total_mem_gb >= 4 for dev in devices])
models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0' models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
optimizer_vars_on_cpu = models_opt_device=='/CPU:0' optimizer_vars_on_cpu = models_opt_device=='/CPU:0'
input_ch = 3 input_ch = 3
@ -40,7 +39,7 @@ class QModel(ModelBase):
self.model_filename_list = [] self.model_filename_list = []
model_archi = nn.DeepFakeArchi(resolution, opts='ud') model_archi = nn.DeepFakeArchi(resolution, mod='quick')
with tf.device ('/CPU:0'): with tf.device ('/CPU:0'):
#Place holders on CPU #Place holders on CPU
@ -56,13 +55,13 @@ class QModel(ModelBase):
# Initializing model classes # Initializing model classes
with tf.device (models_opt_device): with tf.device (models_opt_device):
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder') self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2 encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))
self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter') self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, d_ch=d_dims, name='inter')
inter_out_ch = self.inter.get_out_ch() inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src') self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_src')
self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst') self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_dst')
self.model_filename_list += [ [self.encoder, 'encoder.npy' ], self.model_filename_list += [ [self.encoder, 'encoder.npy' ],
[self.inter, 'inter.npy' ], [self.inter, 'inter.npy' ],
@ -96,7 +95,7 @@ class QModel(ModelBase):
gpu_src_dst_loss_gvs = [] gpu_src_dst_loss_gvs = []
for gpu_id in range(gpu_count): for gpu_id in range(gpu_count):
with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
with tf.device(f'/CPU:0'): with tf.device(f'/CPU:0'):
# slice on CPU, otherwise all batch data will be transfered to GPU first # slice on CPU, otherwise all batch data will be transfered to GPU first
@ -190,7 +189,7 @@ class QModel(ModelBase):
self.AE_view = AE_view self.AE_view = AE_view
else: else:
# Initializing merge function # Initializing merge function
with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'): with tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
gpu_dst_code = self.inter(self.encoder(self.warped_dst)) gpu_dst_code = self.inter(self.encoder(self.warped_dst))
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
_, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
@ -278,7 +277,7 @@ class QModel(ModelBase):
return ( ('src_loss', src_loss), ('dst_loss', dst_loss), ) return ( ('src_loss', src_loss), ('dst_loss', dst_loss), )
#override #override
def onGetPreview(self, samples, for_history=False): def onGetPreview(self, samples):
( (warped_src, target_src, target_srcm), ( (warped_src, target_src, target_srcm),
(warped_dst, target_dst, target_dstm) ) = samples (warped_dst, target_dst, target_dstm) ) = samples

View file

@ -27,33 +27,20 @@ class SAEHDModel(ModelBase):
suggest_batch_size = 4 suggest_batch_size = 4
yn_str = {True:'y',False:'n'} yn_str = {True:'y',False:'n'}
min_res = 64
max_res = 640
#default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False)
default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 128) default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 128)
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f') default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f')
default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True) default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
default_archi = self.options['archi'] = self.load_or_def_option('archi', 'df')
default_archi = self.options['archi'] = self.load_or_def_option('archi', 'liae-ud')
default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256) default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64) default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None) default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None)
default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None) default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None)
default_masked_training = self.options['masked_training'] = self.load_or_def_option('masked_training', True) default_masked_training = self.options['masked_training'] = self.load_or_def_option('masked_training', True)
default_eyes_mouth_prio = self.options['eyes_mouth_prio'] = self.load_or_def_option('eyes_mouth_prio', False) default_eyes_prio = self.options['eyes_prio'] = self.load_or_def_option('eyes_prio', False)
default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False) default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', False)
default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False)
default_adabelief = self.options['adabelief'] = self.load_or_def_option('adabelief', True)
lr_dropout = self.load_or_def_option('lr_dropout', 'n')
lr_dropout = {True:'y', False:'n'}.get(lr_dropout, lr_dropout) #backward comp
default_lr_dropout = self.options['lr_dropout'] = lr_dropout
default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True) default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True)
default_random_hsv_power = self.options['random_hsv_power'] = self.load_or_def_option('random_hsv_power', 0.0) default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
default_true_face_power = self.options['true_face_power'] = self.load_or_def_option('true_face_power', 0.0) default_true_face_power = self.options['true_face_power'] = self.load_or_def_option('true_face_power', 0.0)
default_face_style_power = self.options['face_style_power'] = self.load_or_def_option('face_style_power', 0.0) default_face_style_power = self.options['face_style_power'] = self.load_or_def_option('face_style_power', 0.0)
default_bg_style_power = self.options['bg_style_power'] = self.load_or_def_option('bg_style_power', 0.0) default_bg_style_power = self.options['bg_style_power'] = self.load_or_def_option('bg_style_power', 0.0)
@ -66,55 +53,18 @@ class SAEHDModel(ModelBase):
self.ask_autobackup_hour() self.ask_autobackup_hour()
self.ask_write_preview_history() self.ask_write_preview_history()
self.ask_target_iter() self.ask_target_iter()
self.ask_random_src_flip() self.ask_random_flip()
self.ask_random_dst_flip()
self.ask_batch_size(suggest_batch_size) self.ask_batch_size(suggest_batch_size)
#self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters.')
if self.is_first_run(): if self.is_first_run():
resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16 and 32 for -d archi.") resolution = io.input_int("Resolution", default_resolution, add_info="64-512", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.")
resolution = np.clip ( (resolution // 16) * 16, min_res, max_res) resolution = np.clip ( (resolution // 16) * 16, 64, 512)
self.options['resolution'] = resolution self.options['resolution'] = resolution
self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face. 'Whole face' covers full area of face include forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower() self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face. 'Whole face' covers full area of face include forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower()
self.options['archi'] = io.input_str ("AE architecture", default_archi, ['df','liae','dfhd','liaehd','dfuhd','liaeuhd'], help_message="'df' keeps faces more natural.\n'liae' can fix overly different face shapes.\n'hd' are experimental versions.").lower()
while True: default_d_dims = 48 if self.options['archi'] == 'dfhd' else 64
archi = io.input_str ("AE architecture", default_archi, help_message=\ default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', default_d_dims)
"""
'df' keeps more identity-preserved face.
'liae' can fix overly different face shapes.
'-u' increased likeness of the face.
'-d' (experimental) doubling the resolution using the same computation cost.
Examples: df, liae, df-d, df-ud, liae-ud, ...
""").lower()
archi_split = archi.split('-')
if len(archi_split) == 2:
archi_type, archi_opts = archi_split
elif len(archi_split) == 1:
archi_type, archi_opts = archi_split[0], None
else:
continue
if archi_type not in ['df', 'liae']:
continue
if archi_opts is not None:
if len(archi_opts) == 0:
continue
if len([ 1 for opt in archi_opts if opt not in ['u','d','t','c'] ]) != 0:
continue
if 'd' in archi_opts:
self.options['resolution'] = np.clip ( (self.options['resolution'] // 32) * 32, min_res, max_res)
break
self.options['archi'] = archi
default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64)
default_d_mask_dims = default_d_dims // 3 default_d_mask_dims = default_d_dims // 3
default_d_mask_dims += default_d_mask_dims % 2 default_d_mask_dims += default_d_mask_dims % 2
@ -126,6 +76,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
self.options['e_dims'] = e_dims + e_dims % 2 self.options['e_dims'] = e_dims + e_dims % 2
d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
self.options['d_dims'] = d_dims + d_dims % 2 self.options['d_dims'] = d_dims + d_dims % 2
@ -134,35 +85,17 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
if self.is_first_run() or ask_override: if self.is_first_run() or ask_override:
if self.options['face_type'] == 'wf' or self.options['face_type'] == 'head': if self.options['face_type'] == 'wf' or self.options['face_type'] == 'head':
self.options['masked_training'] = io.input_bool ("Masked training", default_masked_training, help_message="This option is available only for 'whole_face' or 'head' type. Masked training clips training area to full_face mask or XSeg mask, thus network will train the faces properly.") self.options['masked_training'] = io.input_bool ("Masked training", default_masked_training, help_message="This option is available only for 'whole_face' type. Masked training clips training area to full_face mask, thus network will train the faces properly. When the face is trained enough, disable this option to train all area of the frame. Merge with 'raw-rgb' mode, then use Adobe After Effects to manually mask and compose whole face include forehead.")
self.options['eyes_mouth_prio'] = io.input_bool ("Eyes and mouth priority", default_eyes_mouth_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction. Also makes the detail of the teeth higher.') self.options['eyes_prio'] = io.input_bool ("Eyes priority", default_eyes_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction ( especially on HD architectures ) by forcing the neural network to train eyes with higher priority. before/after https://i.imgur.com/YQHOuSR.jpg ')
self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')
self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. The exact xseg mask in src and dst faceset is required.')
default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8)
default_gan_dims = self.options['gan_dims'] = self.load_or_def_option('gan_dims', 16)
if self.is_first_run() or ask_override: if self.is_first_run() or ask_override:
self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.") self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.")
self.options['adabelief'] = io.input_bool ("Use AdaBelief optimizer?", default_adabelief, help_message="Use AdaBelief optimizer. It requires more VRAM, but the accuracy and the generalization of the model is higher.") self.options['lr_dropout'] = io.input_bool ("Use learning rate dropout", default_lr_dropout, help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations.")
self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.")
self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.") self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.")
self.options['random_hsv_power'] = np.clip ( io.input_number ("Random hue/saturation/light intensity", default_random_hsv_power, add_info="0.0 .. 0.3", help_message="Random hue/saturation/light intensity applied to the src face set only at the input of the neural network. Stabilizes color perturbations during face swapping. Reduces the quality of the color transfer by selecting the closest one in the src faceset. Thus the src faceset must be diverse enough. Typical fine value is 0.05"), 0.0, 0.3 ) self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 10.0", help_message="Train the network in Generative Adversarial manner. Accelerates the speed of training. Forces the neural network to learn small details of the face. You can enable/disable this option at any time. Typical value is 1.0"), 0.0, 10.0 )
self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 5.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with lr_dropout(on) and random_warp(off), and don't disable. The higher the value, the higher the chances of artifacts. Typical fine value is 0.1"), 0.0, 5.0 )
if self.options['gan_power'] != 0.0:
gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher patch size, the higher the quality, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is resolution / 8." ), 3, 640 )
self.options['gan_patch_size'] = gan_patch_size
gan_dims = np.clip ( io.input_int("GAN dimensions", default_gan_dims, add_info="4-512", help_message="The dimensions of the GAN network. The higher dimensions, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is 16." ), 4, 512 )
self.options['gan_dims'] = gan_dims
if 'df' in self.options['archi']: if 'df' in self.options['archi']:
self.options['true_face_power'] = np.clip ( io.input_number ("'True face' power.", default_true_face_power, add_info="0.0000 .. 1.0", help_message="Experimental option. Discriminates result face to be more like src face. Higher value - stronger discrimination. Typical value is 0.01 . Comparison - https://i.imgur.com/czScS9q.png"), 0.0, 1.0 ) self.options['true_face_power'] = np.clip ( io.input_number ("'True face' power.", default_true_face_power, add_info="0.0000 .. 1.0", help_message="Experimental option. Discriminates result face to be more like src face. Higher value - stronger discrimination. Typical value is 0.01 . Comparison - https://i.imgur.com/czScS9q.png"), 0.0, 1.0 )
@ -175,13 +108,11 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best.") self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best.")
self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.") self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly. Forces random_warp=N, random_flips=Y, gan_power=0.0, lr_dropout=N, styles=0.0, uniform_yaw=Y") self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly.")
if self.options['pretrain'] and self.get_pretraining_data_path() is None: if self.options['pretrain'] and self.get_pretraining_data_path() is None:
raise Exception("pretraining_data_path is not defined") raise Exception("pretraining_data_path is not defined")
self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims'])
self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False) self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
#override #override
@ -199,20 +130,9 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
'wf' : FaceType.WHOLE_FACE, 'wf' : FaceType.WHOLE_FACE,
'head' : FaceType.HEAD}[ self.options['face_type'] ] 'head' : FaceType.HEAD}[ self.options['face_type'] ]
if 'eyes_prio' in self.options: eyes_prio = self.options['eyes_prio']
self.options.pop('eyes_prio') archi = self.options['archi']
is_hd = 'hd' in archi
eyes_mouth_prio = self.options['eyes_mouth_prio']
archi_split = self.options['archi'].split('-')
if len(archi_split) == 2:
archi_type, archi_opts = archi_split
elif len(archi_split) == 1:
archi_type, archi_opts = archi_split[0], None
self.archi_type = archi_type
ae_dims = self.options['ae_dims'] ae_dims = self.options['ae_dims']
e_dims = self.options['e_dims'] e_dims = self.options['e_dims']
d_dims = self.options['d_dims'] d_dims = self.options['d_dims']
@ -221,69 +141,46 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
if self.pretrain_just_disabled: if self.pretrain_just_disabled:
self.set_iter(0) self.set_iter(0)
adabelief = self.options['adabelief'] self.gan_power = gan_power = self.options['gan_power'] if not self.pretrain else 0.0
use_fp16 = False
if self.is_exporting:
use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')
self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power']
random_warp = False if self.pretrain else self.options['random_warp']
random_src_flip = self.random_src_flip if not self.pretrain else True
random_dst_flip = self.random_dst_flip if not self.pretrain else True
random_hsv_power = self.options['random_hsv_power'] if not self.pretrain else 0.0
blur_out_mask = self.options['blur_out_mask']
if self.pretrain:
self.options_show_override['lr_dropout'] = 'n'
self.options_show_override['random_warp'] = False
self.options_show_override['gan_power'] = 0.0
self.options_show_override['random_hsv_power'] = 0.0
self.options_show_override['face_style_power'] = 0.0
self.options_show_override['bg_style_power'] = 0.0
self.options_show_override['uniform_yaw'] = True
masked_training = self.options['masked_training'] masked_training = self.options['masked_training']
ct_mode = self.options['ct_mode'] ct_mode = self.options['ct_mode']
if ct_mode == 'none': if ct_mode == 'none':
ct_mode = None ct_mode = None
models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu'] models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0' models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
optimizer_vars_on_cpu = models_opt_device=='/CPU:0' optimizer_vars_on_cpu = models_opt_device=='/CPU:0'
input_ch=3 input_ch=3
bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
mask_shape = nn.get4Dshape(resolution,resolution,1) mask_shape = nn.get4Dshape(resolution,resolution,1)
self.model_filename_list = [] self.model_filename_list = []
with tf.device ('/CPU:0'): with tf.device ('/CPU:0'):
#Place holders on CPU #Place holders on CPU
self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src') self.warped_src = tf.placeholder (nn.floatx, bgr_shape)
self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst') self.warped_dst = tf.placeholder (nn.floatx, bgr_shape)
self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src') self.target_src = tf.placeholder (nn.floatx, bgr_shape)
self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst') self.target_dst = tf.placeholder (nn.floatx, bgr_shape)
self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm') self.target_srcm_all = tf.placeholder (nn.floatx, mask_shape)
self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em') self.target_dstm_all = tf.placeholder (nn.floatx, mask_shape)
self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm')
self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em')
# Initializing model classes # Initializing model classes
model_archi = nn.DeepFakeArchi(resolution, use_fp16=use_fp16, opts=archi_opts) model_archi = nn.DeepFakeArchi(resolution, mod='uhd' if 'uhd' in archi else None)
with tf.device (models_opt_device): with tf.device (models_opt_device):
if 'df' in archi_type: if 'df' in archi:
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder') self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder')
encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2 encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))
self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter') self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, is_hd=is_hd, name='inter')
inter_out_ch = self.inter.get_out_ch() inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src') self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_src')
self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst') self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_dst')
self.model_filename_list += [ [self.encoder, 'encoder.npy' ], self.model_filename_list += [ [self.encoder, 'encoder.npy' ],
[self.inter, 'inter.npy' ], [self.inter, 'inter.npy' ],
@ -292,19 +189,20 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
if self.is_training: if self.is_training:
if self.options['true_face_power'] != 0: if self.options['true_face_power'] != 0:
self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=self.inter.get_out_res(), name='dis' ) self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=model_archi.Inter.get_code_res()*2, name='dis' )
self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ] self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ]
elif 'liae' in archi_type: elif 'liae' in archi:
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder') self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder')
encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2 encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))
self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_AB') self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_AB')
self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_B') self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_B')
inter_out_ch = self.inter_AB.get_out_ch() inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
inters_out_ch = inter_out_ch*2 inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder') inters_out_ch = inter_AB_out_ch+inter_B_out_ch
self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder')
self.model_filename_list += [ [self.encoder, 'encoder.npy'], self.model_filename_list += [ [self.encoder, 'encoder.npy'],
[self.inter_AB, 'inter_AB.npy'], [self.inter_AB, 'inter_AB.npy'],
@ -313,43 +211,33 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
if self.is_training: if self.is_training:
if gan_power != 0: if gan_power != 0:
self.D_src = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="D_src") self.D_src = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_src")
self.model_filename_list += [ [self.D_src, 'GAN.npy'] ] self.D_dst = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_dst")
self.model_filename_list += [ [self.D_src, 'D_src.npy'] ]
self.model_filename_list += [ [self.D_dst, 'D_dst.npy'] ]
# Initialize optimizers # Initialize optimizers
lr=5e-5 lr=5e-5
if self.options['lr_dropout'] in ['y','cpu'] and not self.pretrain: lr_dropout = 0.3 if self.options['lr_dropout'] and not self.pretrain else 1.0
lr_cos = 500
lr_dropout = 0.3
else:
lr_cos = 0
lr_dropout = 1.0
OptimizerClass = nn.AdaBelief if adabelief else nn.RMSprop
clipnorm = 1.0 if self.options['clipgrad'] else 0.0 clipnorm = 1.0 if self.options['clipgrad'] else 0.0
self.src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
if 'df' in archi_type:
self.src_dst_saveable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
self.src_dst_trainable_weights = self.src_dst_saveable_weights
elif 'liae' in archi_type:
self.src_dst_saveable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()
if random_warp:
self.src_dst_trainable_weights = self.src_dst_saveable_weights
else:
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()
self.src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='src_dst_opt')
self.src_dst_opt.initialize_variables (self.src_dst_saveable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')
self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
if 'df' in archi:
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
elif 'liae' in archi:
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()
self.src_dst_opt.initialize_variables (self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu)
if self.options['true_face_power'] != 0: if self.options['true_face_power'] != 0:
self.D_code_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='D_code_opt') self.D_code_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_code_opt')
self.D_code_opt.initialize_variables ( self.code_discriminator.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu') self.D_code_opt.initialize_variables ( self.code_discriminator.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
self.model_filename_list += [ (self.D_code_opt, 'D_code_opt.npy') ] self.model_filename_list += [ (self.D_code_opt, 'D_code_opt.npy') ]
if gan_power != 0: if gan_power != 0:
self.D_src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='GAN_opt') self.D_src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_src_dst_opt')
self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')#+self.D_src_x2.get_weights() self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights()+self.D_dst.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
self.model_filename_list += [ (self.D_src_dst_opt, 'GAN_opt.npy') ] self.model_filename_list += [ (self.D_src_dst_opt, 'D_src_dst_opt.npy') ]
if self.is_training: if self.is_training:
# Adjust batch size for multiple GPU # Adjust batch size for multiple GPU
@ -357,6 +245,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
bs_per_gpu = max(1, self.get_batch_size() // gpu_count) bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
self.set_batch_size( gpu_count*bs_per_gpu) self.set_batch_size( gpu_count*bs_per_gpu)
# Compute losses per GPU # Compute losses per GPU
gpu_pred_src_src_list = [] gpu_pred_src_src_list = []
gpu_pred_dst_dst_list = [] gpu_pred_dst_dst_list = []
@ -370,9 +259,9 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
gpu_G_loss_gvs = [] gpu_G_loss_gvs = []
gpu_D_code_loss_gvs = [] gpu_D_code_loss_gvs = []
gpu_D_src_dst_loss_gvs = [] gpu_D_src_dst_loss_gvs = []
for gpu_id in range(gpu_count): for gpu_id in range(gpu_count):
with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
with tf.device(f'/CPU:0'): with tf.device(f'/CPU:0'):
# slice on CPU, otherwise all batch data will be transfered to GPU first # slice on CPU, otherwise all batch data will be transfered to GPU first
batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
@ -380,38 +269,18 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
gpu_warped_dst = self.warped_dst [batch_slice,:,:,:] gpu_warped_dst = self.warped_dst [batch_slice,:,:,:]
gpu_target_src = self.target_src [batch_slice,:,:,:] gpu_target_src = self.target_src [batch_slice,:,:,:]
gpu_target_dst = self.target_dst [batch_slice,:,:,:] gpu_target_dst = self.target_dst [batch_slice,:,:,:]
gpu_target_srcm = self.target_srcm[batch_slice,:,:,:] gpu_target_srcm_all = self.target_srcm_all[batch_slice,:,:,:]
gpu_target_srcm_em = self.target_srcm_em[batch_slice,:,:,:] gpu_target_dstm_all = self.target_dstm_all[batch_slice,:,:,:]
gpu_target_dstm = self.target_dstm[batch_slice,:,:,:]
gpu_target_dstm_em = self.target_dstm_em[batch_slice,:,:,:]
gpu_target_srcm_anti = 1-gpu_target_srcm
gpu_target_dstm_anti = 1-gpu_target_dstm
if blur_out_mask:
sigma = resolution / 128
x = nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, sigma)
y = 1-nn.gaussian_blur(gpu_target_srcm, sigma)
y = tf.where(tf.equal(y, 0), tf.ones_like(y), y)
gpu_target_src = gpu_target_src*gpu_target_srcm + (x/y)*gpu_target_srcm_anti
x = nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, sigma)
y = 1-nn.gaussian_blur(gpu_target_dstm, sigma)
y = tf.where(tf.equal(y, 0), tf.ones_like(y), y)
gpu_target_dst = gpu_target_dst*gpu_target_dstm + (x/y)*gpu_target_dstm_anti
# process model tensors # process model tensors
if 'df' in archi_type: if 'df' in archi:
gpu_src_code = self.inter(self.encoder(gpu_warped_src)) gpu_src_code = self.inter(self.encoder(gpu_warped_src))
gpu_dst_code = self.inter(self.encoder(gpu_warped_dst)) gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code) gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
gpu_pred_src_dst_no_code_grad, _ = self.decoder_src(tf.stop_gradient(gpu_dst_code))
elif 'liae' in archi_type: elif 'liae' in archi:
gpu_src_code = self.encoder (gpu_warped_src) gpu_src_code = self.encoder (gpu_warped_src)
gpu_src_inter_AB_code = self.inter_AB (gpu_src_code) gpu_src_inter_AB_code = self.inter_AB (gpu_src_code)
gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis ) gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis )
@ -424,7 +293,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
gpu_pred_src_dst_no_code_grad, _ = self.decoder(tf.stop_gradient(gpu_src_dst_code))
gpu_pred_src_src_list.append(gpu_pred_src_src) gpu_pred_src_src_list.append(gpu_pred_src_src)
gpu_pred_dst_dst_list.append(gpu_pred_dst_dst) gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
@ -434,60 +302,49 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)
# unpack masks from one combined mask
gpu_target_srcm = tf.clip_by_value (gpu_target_srcm_all, 0, 1)
gpu_target_dstm = tf.clip_by_value (gpu_target_dstm_all, 0, 1)
gpu_target_srcm_eyes = tf.clip_by_value (gpu_target_srcm_all-1, 0, 1)
gpu_target_dstm_eyes = tf.clip_by_value (gpu_target_dstm_all-1, 0, 1)
gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) ) gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2
gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur
gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) ) gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )
gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2
gpu_style_mask_blur = nn.gaussian_blur(gpu_pred_src_dstm*gpu_pred_dst_dstm, max(1, resolution // 32) ) gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
gpu_style_mask_blur = tf.stop_gradient(tf.clip_by_value(gpu_target_srcm_blur, 0, 1.0)) gpu_target_dst_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_blur)
gpu_style_mask_anti_blur = 1.0 - gpu_style_mask_blur
gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur
gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur
gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst
gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst
if resolution < 256: gpu_psd_target_dst_masked = gpu_pred_src_dst*gpu_target_dstm_blur
gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_psd_target_dst_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_blur)
else:
gpu_src_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3]) gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])
if eyes_mouth_prio: if eyes_prio:
gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_em - gpu_pred_src_src*gpu_target_srcm_em ), axis=[1,2,3]) gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_eyes - gpu_pred_src_src*gpu_target_srcm_eyes ), axis=[1,2,3])
gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] ) gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
face_style_power = self.options['face_style_power'] / 100.0 face_style_power = self.options['face_style_power'] / 100.0
if face_style_power != 0 and not self.pretrain: if face_style_power != 0 and not self.pretrain:
gpu_src_loss += nn.style_loss(gpu_pred_src_dst_no_code_grad*tf.stop_gradient(gpu_pred_src_dstm), tf.stop_gradient(gpu_pred_dst_dst*gpu_pred_dst_dstm), gaussian_blur_radius=resolution//8, loss_weight=10000*face_style_power) gpu_src_loss += nn.style_loss(gpu_psd_target_dst_masked, gpu_target_dst_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power)
bg_style_power = self.options['bg_style_power'] / 100.0 bg_style_power = self.options['bg_style_power'] / 100.0
if bg_style_power != 0 and not self.pretrain: if bg_style_power != 0 and not self.pretrain:
gpu_target_dst_style_anti_masked = gpu_target_dst*gpu_style_mask_anti_blur gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim(gpu_psd_target_dst_anti_masked, gpu_target_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_psd_style_anti_masked = gpu_pred_src_dst*gpu_style_mask_anti_blur gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square( gpu_psd_target_dst_anti_masked - gpu_target_dst_anti_masked), axis=[1,2,3] )
gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim( gpu_psd_style_anti_masked, gpu_target_dst_style_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square(gpu_psd_style_anti_masked - gpu_target_dst_style_anti_masked), axis=[1,2,3] )
if resolution < 256:
gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
else:
gpu_dst_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
gpu_dst_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1])
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3]) gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])
if eyes_mouth_prio: if eyes_prio:
gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_em - gpu_pred_dst_dst*gpu_target_dstm_em ), axis=[1,2,3]) gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_eyes - gpu_pred_dst_dst*gpu_target_dstm_eyes ), axis=[1,2,3])
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )
@ -508,49 +365,38 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
gpu_G_loss += self.options['true_face_power']*DLoss(gpu_src_code_d_ones, gpu_src_code_d) gpu_G_loss += self.options['true_face_power']*DLoss(gpu_src_code_d_ones, gpu_src_code_d)
gpu_D_code_loss = (DLoss(gpu_dst_code_d_ones , gpu_dst_code_d) + \ gpu_D_code_loss = (DLoss(gpu_src_code_d_ones , gpu_dst_code_d) + \
DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5 DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5
gpu_D_code_loss_gvs += [ nn.gradients (gpu_D_code_loss, self.code_discriminator.get_weights() ) ] gpu_D_code_loss_gvs += [ nn.gradients (gpu_D_code_loss, self.code_discriminator.get_weights() ) ]
if gan_power != 0: if gan_power != 0:
gpu_pred_src_src_d, \ gpu_pred_src_src_d = self.D_src(gpu_pred_src_src_masked_opt)
gpu_pred_src_src_d2 = self.D_src(gpu_pred_src_src_masked_opt)
gpu_pred_src_src_d_ones = tf.ones_like (gpu_pred_src_src_d) gpu_pred_src_src_d_ones = tf.ones_like (gpu_pred_src_src_d)
gpu_pred_src_src_d_zeros = tf.zeros_like(gpu_pred_src_src_d) gpu_pred_src_src_d_zeros = tf.zeros_like(gpu_pred_src_src_d)
gpu_target_src_d = self.D_src(gpu_target_src_masked_opt)
gpu_pred_src_src_d2_ones = tf.ones_like (gpu_pred_src_src_d2)
gpu_pred_src_src_d2_zeros = tf.zeros_like(gpu_pred_src_src_d2)
gpu_target_src_d, \
gpu_target_src_d2 = self.D_src(gpu_target_src_masked_opt)
gpu_target_src_d_ones = tf.ones_like(gpu_target_src_d) gpu_target_src_d_ones = tf.ones_like(gpu_target_src_d)
gpu_target_src_d2_ones = tf.ones_like(gpu_target_src_d2) gpu_pred_dst_dst_d = self.D_dst(gpu_pred_dst_dst_masked_opt)
gpu_pred_dst_dst_d_ones = tf.ones_like (gpu_pred_dst_dst_d)
gpu_pred_dst_dst_d_zeros = tf.zeros_like(gpu_pred_dst_dst_d)
gpu_target_dst_d = self.D_dst(gpu_target_dst_masked_opt)
gpu_target_dst_d_ones = tf.ones_like(gpu_target_dst_d)
gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones , gpu_target_src_d) + \ gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones , gpu_target_src_d) + \
DLoss(gpu_pred_src_src_d_zeros , gpu_pred_src_src_d) ) * 0.5 + \ DLoss(gpu_pred_src_src_d_zeros, gpu_pred_src_src_d) ) * 0.5 + \
(DLoss(gpu_target_src_d2_ones , gpu_target_src_d2) + \ (DLoss(gpu_target_dst_d_ones , gpu_target_dst_d) + \
DLoss(gpu_pred_src_src_d2_zeros , gpu_pred_src_src_d2) ) * 0.5 DLoss(gpu_pred_dst_dst_d_zeros, gpu_pred_dst_dst_d) ) * 0.5
gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.D_src.get_weights() ) ]#+self.D_src_x2.get_weights() gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.D_src.get_weights()+self.D_dst.get_weights() ) ]
gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + \ gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + DLoss(gpu_pred_dst_dst_d_ones, gpu_pred_dst_dst_d))
DLoss(gpu_pred_src_src_d2_ones, gpu_pred_src_src_d2))
if masked_training:
# Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] )
gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights )]
gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ]
# Average losses and gradients, and create optimizer update ops # Average losses and gradients, and create optimizer update ops
with tf.device(f'/CPU:0'): with tf.device (models_opt_device):
pred_src_src = nn.concat(gpu_pred_src_src_list, 0) pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0) pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0) pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
@ -558,7 +404,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0) pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0) pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)
with tf.device (models_opt_device):
src_loss = tf.concat(gpu_src_losses, 0) src_loss = tf.concat(gpu_src_losses, 0)
dst_loss = tf.concat(gpu_dst_losses, 0) dst_loss = tf.concat(gpu_dst_losses, 0)
src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs)) src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs))
@ -571,18 +416,16 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
# Initializing training and view functions # Initializing training and view functions
def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \ def src_dst_train(warped_src, target_src, target_srcm_all, \
warped_dst, target_dst, target_dstm, target_dstm_em, ): warped_dst, target_dst, target_dstm_all):
s, d = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op], s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
feed_dict={self.warped_src :warped_src, feed_dict={self.warped_src :warped_src,
self.target_src :target_src, self.target_src :target_src,
self.target_srcm:target_srcm, self.target_srcm_all:target_srcm_all,
self.target_srcm_em:target_srcm_em,
self.warped_dst :warped_dst, self.warped_dst :warped_dst,
self.target_dst :target_dst, self.target_dst :target_dst,
self.target_dstm:target_dstm, self.target_dstm_all:target_dstm_all,
self.target_dstm_em:target_dstm_em, })
})[:2]
return s, d return s, d
self.src_dst_train = src_dst_train self.src_dst_train = src_dst_train
@ -592,16 +435,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
self.D_train = D_train self.D_train = D_train
if gan_power != 0: if gan_power != 0:
def D_src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \ def D_src_dst_train(warped_src, target_src, target_srcm_all, \
warped_dst, target_dst, target_dstm, target_dstm_em, ): warped_dst, target_dst, target_dstm_all):
nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src, nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src,
self.target_src :target_src, self.target_src :target_src,
self.target_srcm:target_srcm, self.target_srcm_all:target_srcm_all,
self.target_srcm_em:target_srcm_em,
self.warped_dst :warped_dst, self.warped_dst :warped_dst,
self.target_dst :target_dst, self.target_dst :target_dst,
self.target_dstm:target_dstm, self.target_dstm_all:target_dstm_all})
self.target_dstm_em:target_dstm_em})
self.D_src_dst_train = D_src_dst_train self.D_src_dst_train = D_src_dst_train
@ -612,13 +453,13 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
self.AE_view = AE_view self.AE_view = AE_view
else: else:
# Initializing merge function # Initializing merge function
with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'): with tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
if 'df' in archi_type: if 'df' in archi:
gpu_dst_code = self.inter(self.encoder(self.warped_dst)) gpu_dst_code = self.inter(self.encoder(self.warped_dst))
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
_, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
elif 'liae' in archi_type: elif 'liae' in archi:
gpu_dst_code = self.encoder (self.warped_dst) gpu_dst_code = self.encoder (self.warped_dst)
gpu_dst_inter_B_code = self.inter_B (gpu_dst_code) gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code) gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
@ -638,17 +479,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
if self.pretrain_just_disabled: if self.pretrain_just_disabled:
do_init = False do_init = False
if 'df' in archi_type: if 'df' in archi:
if model == self.inter: if model == self.inter:
do_init = True do_init = True
elif 'liae' in archi_type: elif 'liae' in archi:
if model == self.inter_AB or model == self.inter_B: if model == self.inter_AB or model == self.inter_B:
do_init = True do_init = True
else: else:
do_init = self.is_first_run() do_init = self.is_first_run()
if self.is_training and gan_power != 0 and model == self.D_src:
if self.gan_model_changed:
do_init = True
if not do_init: if not do_init:
do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) ) do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
@ -656,9 +494,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
if do_init: if do_init:
model.init_weights() model.init_weights()
###############
# initializing sample generators # initializing sample generators
if self.is_training: if self.is_training:
training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path() training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
@ -666,7 +501,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None
cpu_count = multiprocessing.cpu_count() cpu_count = min(multiprocessing.cpu_count(), 8)
src_generators_count = cpu_count // 2 src_generators_count = cpu_count // 2
dst_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2
if ct_mode is not None: if ct_mode is not None:
@ -674,81 +509,28 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
self.set_training_data_generators ([ self.set_training_data_generators ([
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=random_src_flip), sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'random_hsv_shift_amount' : random_hsv_power, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
], ],
uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain,
generators_count=src_generators_count ), generators_count=src_generators_count ),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=random_dst_flip), sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
], ],
uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain,
generators_count=dst_generators_count ) generators_count=dst_generators_count )
]) ])
self.last_src_samples_loss = []
self.last_dst_samples_loss = []
if self.pretrain_just_disabled: if self.pretrain_just_disabled:
self.update_sample_for_preview(force_new=True) self.update_sample_for_preview(force_new=True)
def export_dfm (self):
output_path=self.get_strpath_storage_for_file('model.dfm')
io.log_info(f'Dumping .dfm to {output_path}')
tf = nn.tf
nn.set_data_format('NCHW')
with tf.device (nn.tf_default_device_name):
warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
warped_dst = tf.transpose(warped_dst, (0,3,1,2))
if 'df' in self.archi_type:
gpu_dst_code = self.inter(self.encoder(warped_dst))
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
_, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
elif 'liae' in self.archi_type:
gpu_dst_code = self.encoder (warped_dst)
gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
_, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1))
gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1))
gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1))
tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
tf.identity(gpu_pred_src_dst, name='out_celeb_face')
tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')
output_graph_def = tf.graph_util.convert_variables_to_constants(
nn.tf_sess,
tf.get_default_graph().as_graph_def(),
['out_face_mask','out_celeb_face','out_celeb_face_mask']
)
import tf2onnx
with tf.device("/CPU:0"):
model_proto, _ = tf2onnx.convert._convert_common(
output_graph_def,
name='SAEHD',
input_names=['in_face:0'],
output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'],
opset=12,
output_path=output_path)
#override #override
def get_model_filename_list(self): def get_model_filename_list(self):
return self.model_filename_list return self.model_filename_list
@ -758,38 +540,54 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False): for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False):
model.save_weights ( self.get_strpath_storage_for_file(filename) ) model.save_weights ( self.get_strpath_storage_for_file(filename) )
#override
def should_save_preview_history(self):
return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \
(io.is_colab() and self.iter % 100 == 0)
#override #override
def onTrainOneIter(self): def onTrainOneIter(self):
if self.get_iter() == 0 and not self.pretrain and not self.pretrain_just_disabled: bs = self.get_batch_size()
io.log_info('You are training the model from scratch. It is strongly recommended to use a pretrained model to speed up the training and improve the quality.\n')
( (warped_src, target_src, target_srcm, target_srcm_em), \ ( (warped_src, target_src, target_srcm_all), \
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples() (warped_dst, target_dst, target_dstm_all) ) = self.generate_next_samples()
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm_all, warped_dst, target_dst, target_dstm_all)
for i in range(bs):
self.last_src_samples_loss.append ( (target_src[i], target_srcm_all[i], src_loss[i] ) )
self.last_dst_samples_loss.append ( (target_dst[i], target_dstm_all[i], dst_loss[i] ) )
if len(self.last_src_samples_loss) >= bs*16:
src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(2), reverse=True)
dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(2), reverse=True)
target_src = np.stack( [ x[0] for x in src_samples_loss[:bs] ] )
target_srcm_all = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
target_dst = np.stack( [ x[0] for x in dst_samples_loss[:bs] ] )
target_dstm_all = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm_all, target_dst, target_dst, target_dstm_all)
self.last_src_samples_loss = []
self.last_dst_samples_loss = []
if self.options['true_face_power'] != 0 and not self.pretrain: if self.options['true_face_power'] != 0 and not self.pretrain:
self.D_train (warped_src, warped_dst) self.D_train (warped_src, warped_dst)
if self.gan_power != 0: if self.gan_power != 0:
self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) self.D_src_dst_train (warped_src, target_src, target_srcm_all, warped_dst, target_dst, target_dstm_all)
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
#override #override
def onGetPreview(self, samples, for_history=False): def onGetPreview(self, samples):
( (warped_src, target_src, target_srcm, target_srcm_em), ( (warped_src, target_src, target_srcm_all,),
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples (warped_dst, target_dst, target_dstm_all,) ) = samples
S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ] S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ] DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ]
target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )] target_srcm_all, target_dstm_all = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm_all, target_dstm_all] )]
target_srcm = np.clip(target_srcm_all, 0, 1)
target_dstm = np.clip(target_dstm_all, 0, 1)
n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) n_samples = min(4, self.get_batch_size(), 800 // self.resolution )

View file

@ -18,6 +18,8 @@ class XSegModel(ModelBase):
#override #override
def on_initialize_options(self): def on_initialize_options(self):
self.set_batch_size(4)
ask_override = self.ask_override() ask_override = self.ask_override()
if not self.is_first_run() and ask_override: if not self.is_first_run() and ask_override:
@ -25,24 +27,15 @@ class XSegModel(ModelBase):
self.set_iter(0) self.set_iter(0)
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf') default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf')
default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False)
if self.is_first_run(): if self.is_first_run():
self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Choose the same as your deepfake model.").lower() self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Choose the same as your deepfake model.").lower()
if self.is_first_run() or ask_override:
self.ask_batch_size(4, range=[2,16])
self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain)
if not self.is_exporting and (self.options['pretrain'] and self.get_pretraining_data_path() is None):
raise Exception("pretraining_data_path is not defined")
self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
#override #override
def on_initialize(self): def on_initialize(self):
device_config = nn.getCurrentDeviceConfig() device_config = nn.getCurrentDeviceConfig()
self.model_data_format = "NCHW" if self.is_exporting or (len(device_config.devices) != 0 and not self.is_debug()) else "NHWC" self.model_data_format = "NCHW" if len(device_config.devices) != 0 and not self.is_debug() else "NHWC"
nn.initialize(data_format=self.model_data_format) nn.initialize(data_format=self.model_data_format)
tf = nn.tf tf = nn.tf
@ -58,9 +51,8 @@ class XSegModel(ModelBase):
'wf' : FaceType.WHOLE_FACE, 'wf' : FaceType.WHOLE_FACE,
'head' : FaceType.HEAD}[ self.options['face_type'] ] 'head' : FaceType.HEAD}[ self.options['face_type'] ]
place_model_on_cpu = len(devices) == 0 place_model_on_cpu = len(devices) == 0
models_opt_device = '/CPU:0' if place_model_on_cpu else nn.tf_default_device_name models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0'
bgr_shape = nn.get4Dshape(resolution,resolution,3) bgr_shape = nn.get4Dshape(resolution,resolution,3)
mask_shape = nn.get4Dshape(resolution,resolution,1) mask_shape = nn.get4Dshape(resolution,resolution,1)
@ -75,16 +67,13 @@ class XSegModel(ModelBase):
optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3, name='opt'), optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3, name='opt'),
data_format=nn.data_format) data_format=nn.data_format)
self.pretrain = self.options['pretrain']
if self.pretrain_just_disabled:
self.set_iter(0)
if self.is_training: if self.is_training:
# Adjust batch size for multiple GPU # Adjust batch size for multiple GPU
gpu_count = max(1, len(devices) ) gpu_count = max(1, len(devices) )
bs_per_gpu = max(1, self.get_batch_size() // gpu_count) bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
self.set_batch_size( gpu_count*bs_per_gpu) self.set_batch_size( gpu_count*bs_per_gpu)
# Compute losses per GPU # Compute losses per GPU
gpu_pred_list = [] gpu_pred_list = []
@ -92,7 +81,8 @@ class XSegModel(ModelBase):
gpu_loss_gvs = [] gpu_loss_gvs = []
for gpu_id in range(gpu_count): for gpu_id in range(gpu_count):
with tf.device(f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
with tf.device(f'/CPU:0'): with tf.device(f'/CPU:0'):
# slice on CPU, otherwise all batch data will be transfered to GPU first # slice on CPU, otherwise all batch data will be transfered to GPU first
batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
@ -100,41 +90,27 @@ class XSegModel(ModelBase):
gpu_target_t = self.model.target_t [batch_slice,:,:,:] gpu_target_t = self.model.target_t [batch_slice,:,:,:]
# process model tensors # process model tensors
gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t, pretrain=self.pretrain) gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t)
gpu_pred_list.append(gpu_pred_t) gpu_pred_list.append(gpu_pred_t)
gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3])
if self.pretrain:
# Structural loss
gpu_loss = tf.reduce_mean (5*nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_loss += tf.reduce_mean (5*nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
# Pixel loss
gpu_loss += tf.reduce_mean (10*tf.square(gpu_target_t-gpu_pred_t), axis=[1,2,3])
else:
gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3])
gpu_losses += [gpu_loss] gpu_losses += [gpu_loss]
gpu_loss_gvs += [ nn.gradients ( gpu_loss, self.model.get_weights() ) ] gpu_loss_gvs += [ nn.gradients ( gpu_loss, self.model.get_weights() ) ]
# Average losses and gradients, and create optimizer update ops # Average losses and gradients, and create optimizer update ops
#with tf.device(f'/CPU:0'): # Temporary fix. Unknown bug with training freeze starts from 2.4.0, but 2.3.1 was ok
with tf.device (models_opt_device): with tf.device (models_opt_device):
pred = tf.concat(gpu_pred_list, 0) pred = nn.concat(gpu_pred_list, 0)
loss = tf.concat(gpu_losses, 0) loss = tf.reduce_mean(gpu_losses)
loss_gv_op = self.model.opt.get_update_op (nn.average_gv_list (gpu_loss_gvs)) loss_gv_op = self.model.opt.get_update_op (nn.average_gv_list (gpu_loss_gvs))
# Initializing training and view functions # Initializing training and view functions
if self.pretrain: def train(input_np, target_np):
def train(input_np, target_np): l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np })
l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np}) return l
return l
else:
def train(input_np, target_np):
l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np })
return l
self.train = train self.train = train
def view(input_np): def view(input_np):
@ -147,38 +123,29 @@ class XSegModel(ModelBase):
src_generators_count = cpu_count // 2 src_generators_count = cpu_count // 2
dst_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2
if self.pretrain:
pretrain_gen = SampleGeneratorFace(self.get_pretraining_data_path(), debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=True),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
],
uniform_yaw_distribution=False,
generators_count=cpu_count )
self.set_training_data_generators ([pretrain_gen])
else:
srcdst_generator = SampleGeneratorFaceXSeg([self.training_data_src_path, self.training_data_dst_path],
debug=self.is_debug(),
batch_size=self.get_batch_size(),
resolution=resolution,
face_type=self.face_type,
generators_count=src_dst_generators_count,
data_format=nn.data_format)
src_generator = SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), srcdst_generator = SampleGeneratorFaceXSeg([self.training_data_src_path, self.training_data_dst_path],
sample_process_options=SampleProcessor.Options(random_flip=False), debug=self.is_debug(),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, batch_size=self.get_batch_size(),
], resolution=resolution,
generators_count=src_generators_count, face_type=self.face_type,
raise_on_no_data=False ) generators_count=src_dst_generators_count,
dst_generator = SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), data_format=nn.data_format)
sample_process_options=SampleProcessor.Options(random_flip=False),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
],
generators_count=dst_generators_count,
raise_on_no_data=False )
self.set_training_data_generators ([srcdst_generator, src_generator, dst_generator]) src_generator = SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=False),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
],
generators_count=src_generators_count,
raise_on_no_data=False )
dst_generator = SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=False),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
],
generators_count=dst_generators_count,
raise_on_no_data=False )
self.set_training_data_generators ([srcdst_generator, src_generator, dst_generator])
#override #override
def get_model_filename_list(self): def get_model_filename_list(self):
@ -190,21 +157,19 @@ class XSegModel(ModelBase):
#override #override
def onTrainOneIter(self): def onTrainOneIter(self):
image_np, target_np = self.generate_next_samples()[0]
loss = self.train (image_np, target_np)
return ( ('loss', np.mean(loss) ), )
image_np, mask_np = self.generate_next_samples()[0]
loss = self.train (image_np, mask_np)
return ( ('loss', loss ), )
#override #override
def onGetPreview(self, samples, for_history=False): def onGetPreview(self, samples):
n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) n_samples = min(4, self.get_batch_size(), 800 // self.resolution )
if self.pretrain: srcdst_samples, src_samples, dst_samples = samples
srcdst_samples, = samples image_np, mask_np = srcdst_samples
image_np, mask_np = srcdst_samples
else:
srcdst_samples, src_samples, dst_samples = samples
image_np, mask_np = srcdst_samples
I, M, IM, = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([image_np,mask_np] + self.view (image_np) ) ] I, M, IM, = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([image_np,mask_np] + self.view (image_np) ) ]
M, IM, = [ np.repeat (x, (3,), -1) for x in [M, IM] ] M, IM, = [ np.repeat (x, (3,), -1) for x in [M, IM] ]
@ -214,14 +179,11 @@ class XSegModel(ModelBase):
result = [] result = []
st = [] st = []
for i in range(n_samples): for i in range(n_samples):
if self.pretrain: ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i]+0.5*I[i]*(1-IM[i]) + 0.5*green_bg*(1-IM[i])
ar = I[i], IM[i]
else:
ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i]+0.5*I[i]*(1-IM[i]) + 0.5*green_bg*(1-IM[i])
st.append ( np.concatenate ( ar, axis=1) ) st.append ( np.concatenate ( ar, axis=1) )
result += [ ('XSeg training faces', np.concatenate (st, axis=0 )), ] result += [ ('XSeg training faces', np.concatenate (st, axis=0 )), ]
if not self.pretrain and len(src_samples) != 0: if len(src_samples) != 0:
src_np, = src_samples src_np, = src_samples
@ -235,7 +197,7 @@ class XSegModel(ModelBase):
result += [ ('XSeg src faces', np.concatenate (st, axis=0 )), ] result += [ ('XSeg src faces', np.concatenate (st, axis=0 )), ]
if not self.pretrain and len(dst_samples) != 0: if len(dst_samples) != 0:
dst_np, = dst_samples dst_np, = dst_samples
@ -251,33 +213,4 @@ class XSegModel(ModelBase):
return result return result
def export_dfm (self):
output_path = self.get_strpath_storage_for_file(f'model.onnx')
io.log_info(f'Dumping .onnx to {output_path}')
tf = nn.tf
with tf.device (nn.tf_default_device_name):
input_t = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
input_t = tf.transpose(input_t, (0,3,1,2))
_, pred_t = self.model.flow(input_t)
pred_t = tf.transpose(pred_t, (0,2,3,1))
tf.identity(pred_t, name='out_mask')
output_graph_def = tf.graph_util.convert_variables_to_constants(
nn.tf_sess,
tf.get_default_graph().as_graph_def(),
['out_mask']
)
import tf2onnx
with tf.device("/CPU:0"):
model_proto, _ = tf2onnx.convert._convert_common(
output_graph_def,
name='XSeg',
input_names=['in_face:0'],
output_names=['out_mask:0'],
opset=13,
output_path=output_path)
Model = XSegModel Model = XSegModel

50
project.code-workspace Normal file
View file

@ -0,0 +1,50 @@
{
"folders": [
{
"path": "."
}
],
"settings": {
"workbench.colorTheme": "Visual Studio Light",
"diffEditor.ignoreTrimWhitespace": true,
"workbench.sideBar.location": "right",
"breadcrumbs.enabled": false,
"editor.renderWhitespace": "none",
"editor.minimap.enabled": false,
"workbench.activityBar.visible": true,
"window.menuBarVisibility": "default",
"editor.fastScrollSensitivity": 10,
"editor.mouseWheelScrollSensitivity": 2,
"window.zoomLevel": 0,
"extensions.ignoreRecommendations": true,
"python.linting.pylintEnabled": false,
"python.linting.enabled": false,
"python.linting.pylamaEnabled": false,
"python.linting.pydocstyleEnabled": false,
"python.pythonPath": "${env:PYTHON_EXECUTABLE}",
"workbench.editor.tabCloseButton": "off",
"workbench.editor.tabSizing": "shrink",
"workbench.editor.highlightModifiedTabs": true,
"editor.mouseWheelScrollSensitivity": 3,
"editor.folding": false,
"editor.glyphMargin": false,
"files.exclude": {
"**/__pycache__": true,
"**/.github": true,
"**/.vscode": true,
"**/*.dat": true,
"**/*.h5": true,
"**/*.npy": true
},
"editor.quickSuggestions": {
"other": false,
"comments": false,
"strings": false
},
"editor.trimAutoWhitespace": false,
"python.linting.pylintArgs": [
"--disable=import-error"
]
}
}

View file

@ -1,11 +1,9 @@
tqdm tqdm
numpy==1.19.3 numpy==1.17.0
numexpr h5py==2.9.0
h5py==2.10.0
opencv-python==4.1.0.25 opencv-python==4.1.0.25
ffmpeg-python==0.1.17 ffmpeg-python==0.1.17
scikit-image==0.14.2 scikit-image==0.14.2
scipy==1.4.1 scipy==1.4.1
colorama colorama
tensorflow-gpu==2.4.0 tensorflow-gpu==1.13.2
tf2onnx==1.9.3

View file

@ -1,12 +1,11 @@
tqdm tqdm
numpy==1.19.3 numpy==1.17.0
numexpr h5py==2.9.0
h5py==2.10.0
opencv-python==4.1.0.25 opencv-python==4.1.0.25
ffmpeg-python==0.1.17 ffmpeg-python==0.1.17
scikit-image==0.14.2 scikit-image==0.14.2
scipy==1.4.1 scipy==1.4.1
colorama colorama
tensorflow-gpu==2.4.0 labelme==4.2.9
tensorflow-gpu==1.13.2
pyqt5 pyqt5
tf2onnx==1.9.3

View file

@ -85,17 +85,16 @@ class PackedFaceset():
of.seek(0,2) of.seek(0,2)
of.close() of.close()
if io.input_bool(f"Delete original files?", True): for filename in io.progress_bar_generator(image_paths, "Deleting files"):
for filename in io.progress_bar_generator(image_paths, "Deleting files"): Path(filename).unlink()
Path(filename).unlink()
if as_person_faceset: if as_person_faceset:
for dir_name in io.progress_bar_generator(dir_names, "Deleting dirs"): for dir_name in io.progress_bar_generator(dir_names, "Deleting dirs"):
dir_path = samples_path / dir_name dir_path = samples_path / dir_name
try: try:
shutil.rmtree(dir_path) shutil.rmtree(dir_path)
except: except:
io.log_info (f"unable to remove: {dir_path} ") io.log_info (f"unable to remove: {dir_path} ")
@staticmethod @staticmethod
def unpack(samples_path): def unpack(samples_path):
@ -121,11 +120,6 @@ class PackedFaceset():
samples_dat_path.unlink() samples_dat_path.unlink()
@staticmethod
def path_contains(samples_path):
samples_dat_path = samples_path / packed_faceset_filename
return samples_dat_path.exists()
@staticmethod @staticmethod
def load(samples_path): def load(samples_path):
samples_dat_path = samples_path / packed_faceset_filename samples_dat_path = samples_path / packed_faceset_filename

View file

@ -5,8 +5,8 @@ import cv2
import numpy as np import numpy as np
from core.cv2ex import * from core.cv2ex import *
from DFLIMG import *
from facelib import LandmarksProcessor from facelib import LandmarksProcessor
from core import imagelib
from core.imagelib import SegIEPolys from core.imagelib import SegIEPolys
class SampleType(IntEnum): class SampleType(IntEnum):
@ -28,7 +28,6 @@ class Sample(object):
'landmarks', 'landmarks',
'seg_ie_polys', 'seg_ie_polys',
'xseg_mask', 'xseg_mask',
'xseg_mask_compressed',
'eyebrows_expand_mod', 'eyebrows_expand_mod',
'source_filename', 'source_filename',
'person_name', 'person_name',
@ -43,7 +42,6 @@ class Sample(object):
landmarks=None, landmarks=None,
seg_ie_polys=None, seg_ie_polys=None,
xseg_mask=None, xseg_mask=None,
xseg_mask_compressed=None,
eyebrows_expand_mod=None, eyebrows_expand_mod=None,
source_filename=None, source_filename=None,
person_name=None, person_name=None,
@ -62,16 +60,6 @@ class Sample(object):
self.seg_ie_polys = SegIEPolys.load(seg_ie_polys) self.seg_ie_polys = SegIEPolys.load(seg_ie_polys)
self.xseg_mask = xseg_mask self.xseg_mask = xseg_mask
self.xseg_mask_compressed = xseg_mask_compressed
if self.xseg_mask_compressed is None and self.xseg_mask is not None:
xseg_mask = np.clip( imagelib.normalize_channels(xseg_mask, 1)*255, 0, 255 ).astype(np.uint8)
ret, xseg_mask_compressed = cv2.imencode('.png', xseg_mask)
if not ret:
raise Exception("Sample(): unable to generate xseg_mask_compressed")
self.xseg_mask_compressed = xseg_mask_compressed
self.xseg_mask = None
self.eyebrows_expand_mod = eyebrows_expand_mod if eyebrows_expand_mod is not None else 1.0 self.eyebrows_expand_mod = eyebrows_expand_mod if eyebrows_expand_mod is not None else 1.0
self.source_filename = source_filename self.source_filename = source_filename
self.person_name = person_name self.person_name = person_name
@ -79,17 +67,6 @@ class Sample(object):
self._filename_offset_size = None self._filename_offset_size = None
def has_xseg_mask(self):
return self.xseg_mask is not None or self.xseg_mask_compressed is not None
def get_xseg_mask(self):
if self.xseg_mask_compressed is not None:
xseg_mask = cv2.imdecode(self.xseg_mask_compressed, cv2.IMREAD_UNCHANGED)
if len(xseg_mask.shape) == 2:
xseg_mask = xseg_mask[...,None]
return xseg_mask.astype(np.float32) / 255.0
return self.xseg_mask
def get_pitch_yaw_roll(self): def get_pitch_yaw_roll(self):
if self.pitch_yaw_roll is None: if self.pitch_yaw_roll is None:
self.pitch_yaw_roll = LandmarksProcessor.estimate_pitch_yaw_roll(self.landmarks, size=self.shape[1]) self.pitch_yaw_roll = LandmarksProcessor.estimate_pitch_yaw_roll(self.landmarks, size=self.shape[1])
@ -120,7 +97,6 @@ class Sample(object):
'landmarks': self.landmarks.tolist(), 'landmarks': self.landmarks.tolist(),
'seg_ie_polys': self.seg_ie_polys.dump(), 'seg_ie_polys': self.seg_ie_polys.dump(),
'xseg_mask' : self.xseg_mask, 'xseg_mask' : self.xseg_mask,
'xseg_mask_compressed' : self.xseg_mask_compressed,
'eyebrows_expand_mod': self.eyebrows_expand_mod, 'eyebrows_expand_mod': self.eyebrows_expand_mod,
'source_filename': self.source_filename, 'source_filename': self.source_filename,
'person_name': self.person_name 'person_name': self.person_name

View file

@ -6,13 +6,11 @@ import cv2
import numpy as np import numpy as np
from core import mplib from core import mplib
from core.interact import interact as io
from core.joblib import SubprocessGenerator, ThisThreadGenerator from core.joblib import SubprocessGenerator, ThisThreadGenerator
from facelib import LandmarksProcessor from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
SampleType) SampleType)
''' '''
arg arg
output_sample_types = [ output_sample_types = [
@ -25,15 +23,15 @@ class SampleGeneratorFace(SampleGeneratorBase):
random_ct_samples_path=None, random_ct_samples_path=None,
sample_process_options=SampleProcessor.Options(), sample_process_options=SampleProcessor.Options(),
output_sample_types=[], output_sample_types=[],
uniform_yaw_distribution=False, add_sample_idx=False,
generators_count=4, generators_count=4,
raise_on_no_data=True, raise_on_no_data=True,
**kwargs): **kwargs):
super().__init__(debug, batch_size) super().__init__(debug, batch_size)
self.initialized = False
self.sample_process_options = sample_process_options self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types self.output_sample_types = output_sample_types
self.add_sample_idx = add_sample_idx
if self.debug: if self.debug:
self.generators_count = 1 self.generators_count = 1
@ -43,39 +41,14 @@ class SampleGeneratorFace(SampleGeneratorBase):
samples = SampleLoader.load (SampleType.FACE, samples_path) samples = SampleLoader.load (SampleType.FACE, samples_path)
self.samples_len = len(samples) self.samples_len = len(samples)
self.initialized = False
if self.samples_len == 0: if self.samples_len == 0:
if raise_on_no_data: if raise_on_no_data:
raise ValueError('No training data provided.') raise ValueError('No training data provided.')
else: else:
return return
if uniform_yaw_distribution: index_host = mplib.IndexHost(self.samples_len)
samples_pyr = [ ( idx, sample.get_pitch_yaw_roll() ) for idx, sample in enumerate(samples) ]
grads = 128
#instead of math.pi / 2, using -1.2,+1.2 because actually maximum yaw for 2DFAN landmarks are -1.2+1.2
grads_space = np.linspace (-1.2, 1.2,grads)
yaws_sample_list = [None]*grads
for g in io.progress_bar_generator ( range(grads), "Sort by yaw"):
yaw = grads_space[g]
next_yaw = grads_space[g+1] if g < grads-1 else yaw
yaw_samples = []
for idx, pyr in samples_pyr:
s_yaw = -pyr[1]
if (g == 0 and s_yaw < next_yaw) or \
(g < grads-1 and s_yaw >= yaw and s_yaw < next_yaw) or \
(g == grads-1 and s_yaw >= yaw):
yaw_samples += [ idx ]
if len(yaw_samples) > 0:
yaws_sample_list[g] = yaw_samples
yaws_sample_list = [ y for y in yaws_sample_list if y is not None ]
index_host = mplib.Index2DHost( yaws_sample_list )
else:
index_host = mplib.IndexHost(self.samples_len)
if random_ct_samples_path is not None: if random_ct_samples_path is not None:
ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path)
@ -137,8 +110,14 @@ class SampleGeneratorFace(SampleGeneratorBase):
if batches is None: if batches is None:
batches = [ [] for _ in range(len(x)) ] batches = [ [] for _ in range(len(x)) ]
if self.add_sample_idx:
batches += [ [] ]
i_sample_idx = len(batches)-1
for i in range(len(x)): for i in range(len(x)):
batches[i].append ( x[i] ) batches[i].append ( x[i] )
if self.add_sample_idx:
batches[i_sample_idx].append (sample_idx)
yield [ np.array(batch) for batch in batches] yield [ np.array(batch) for batch in batches]

View file

@ -12,98 +12,6 @@ from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
SampleType) SampleType)
class Index2DHost():
"""
Provides random shuffled 2D indexes for multiprocesses
"""
def __init__(self, indexes2D):
self.sq = multiprocessing.Queue()
self.cqs = []
self.clis = []
self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) )
self.thread.daemon = True
self.thread.start()
def host_thread(self, indexes2D):
indexes_counts_len = len(indexes2D)
idxs = [*range(indexes_counts_len)]
idxs_2D = [None]*indexes_counts_len
shuffle_idxs = []
shuffle_idxs_2D = [None]*indexes_counts_len
for i in range(indexes_counts_len):
idxs_2D[i] = indexes2D[i]
shuffle_idxs_2D[i] = []
sq = self.sq
while True:
while not sq.empty():
obj = sq.get()
cq_id, cmd = obj[0], obj[1]
if cmd == 0: #get_1D
count = obj[2]
result = []
for i in range(count):
if len(shuffle_idxs) == 0:
shuffle_idxs = idxs.copy()
np.random.shuffle(shuffle_idxs)
result.append(shuffle_idxs.pop())
self.cqs[cq_id].put (result)
elif cmd == 1: #get_2D
targ_idxs,count = obj[2], obj[3]
result = []
for targ_idx in targ_idxs:
sub_idxs = []
for i in range(count):
ar = shuffle_idxs_2D[targ_idx]
if len(ar) == 0:
ar = shuffle_idxs_2D[targ_idx] = idxs_2D[targ_idx].copy()
np.random.shuffle(ar)
sub_idxs.append(ar.pop())
result.append (sub_idxs)
self.cqs[cq_id].put (result)
time.sleep(0.001)
def create_cli(self):
cq = multiprocessing.Queue()
self.cqs.append ( cq )
cq_id = len(self.cqs)-1
return Index2DHost.Cli(self.sq, cq, cq_id)
# disable pickling
def __getstate__(self):
return dict()
def __setstate__(self, d):
self.__dict__.update(d)
class Cli():
def __init__(self, sq, cq, cq_id):
self.sq = sq
self.cq = cq
self.cq_id = cq_id
def get_1D(self, count):
self.sq.put ( (self.cq_id,0, count) )
while True:
if not self.cq.empty():
return self.cq.get()
time.sleep(0.001)
def get_2D(self, idxs, count):
self.sq.put ( (self.cq_id,1,idxs,count) )
while True:
if not self.cq.empty():
return self.cq.get()
time.sleep(0.001)
''' '''
arg arg
output_sample_types = [ output_sample_types = [
@ -137,7 +45,7 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
for i,sample in enumerate(samples): for i,sample in enumerate(samples):
persons_name_idxs[sample.person_name].append (i) persons_name_idxs[sample.person_name].append (i)
indexes2D = [ persons_name_idxs[person_name] for person_name in unique_person_names ] indexes2D = [ persons_name_idxs[person_name] for person_name in unique_person_names ]
index2d_host = Index2DHost(indexes2D) index2d_host = mplib.Index2DHost(indexes2D)
if self.debug: if self.debug:
self.generators_count = 1 self.generators_count = 1

Some files were not shown because too many files have changed in this diff Show more