Compare commits
No commits in common. "master" and "DF.wf.288res.384.92.72.22" have entirely different histories.
master
...
DF.wf.288r
2
.vscode/launch.json
vendored
|
@ -12,7 +12,7 @@
|
|||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${env:DFL_ROOT}\\main.py",
|
||||
"python": "${env:PYTHONEXECUTABLE}",
|
||||
"pythonPath": "${env:PYTHONEXECUTABLE}",
|
||||
"cwd": "${env:WORKSPACE}",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["train",
|
||||
|
|
|
@ -6,7 +6,6 @@ import cv2
|
|||
import numpy as np
|
||||
|
||||
from core import imagelib
|
||||
from core.cv2ex import *
|
||||
from core.imagelib import SegIEPolys
|
||||
from core.interact import interact as io
|
||||
from core.structex import *
|
||||
|
@ -20,8 +19,7 @@ class DFLJPG(object):
|
|||
self.length = 0
|
||||
self.chunks = []
|
||||
self.dfl_dict = None
|
||||
self.shape = None
|
||||
self.img = None
|
||||
self.shape = (0,0,0)
|
||||
|
||||
@staticmethod
|
||||
def load_raw(filename, loader_func=None):
|
||||
|
@ -138,6 +136,8 @@ class DFLJPG(object):
|
|||
|
||||
if id == b"JFIF":
|
||||
c, ver_major, ver_minor, units, Xdensity, Ydensity, Xthumbnail, Ythumbnail = struct_unpack (d, c, "=BBBHHBB")
|
||||
#if units == 0:
|
||||
# inst.shape = (Ydensity, Xdensity, 3)
|
||||
else:
|
||||
raise Exception("Unknown jpeg ID: %s" % (id) )
|
||||
elif chunk['name'] == 'SOF0' or chunk['name'] == 'SOF2':
|
||||
|
@ -205,16 +205,7 @@ class DFLJPG(object):
|
|||
|
||||
return data
|
||||
|
||||
def get_img(self):
|
||||
if self.img is None:
|
||||
self.img = cv2_imread(self.filename)
|
||||
return self.img
|
||||
|
||||
def get_shape(self):
|
||||
if self.shape is None:
|
||||
img = self.get_img()
|
||||
if img is not None:
|
||||
self.shape = img.shape
|
||||
return self.shape
|
||||
|
||||
def get_height(self):
|
||||
|
@ -281,13 +272,6 @@ class DFLJPG(object):
|
|||
def has_xseg_mask(self):
|
||||
return self.dfl_dict.get('xseg_mask',None) is not None
|
||||
|
||||
def get_xseg_mask_compressed(self):
|
||||
mask_buf = self.dfl_dict.get('xseg_mask',None)
|
||||
if mask_buf is None:
|
||||
return None
|
||||
|
||||
return mask_buf
|
||||
|
||||
def get_xseg_mask(self):
|
||||
mask_buf = self.dfl_dict.get('xseg_mask',None)
|
||||
if mask_buf is None:
|
||||
|
@ -308,7 +292,7 @@ class DFLJPG(object):
|
|||
mask_a = imagelib.normalize_channels(mask_a, 1)
|
||||
img_data = np.clip( mask_a*255, 0, 255 ).astype(np.uint8)
|
||||
|
||||
data_max_len = 50000
|
||||
data_max_len = 4096
|
||||
|
||||
ret, buf = cv2.imencode('.png', img_data)
|
||||
|
||||
|
|
265
README.md
|
@ -1,237 +1,176 @@
|
|||
<table align="center" border="0">
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
<table align="center" border="0"><tr><td align="center" width="9999">
|
||||
|
||||
# DeepFaceLab
|
||||
### the leading software for creating deepfakes
|
||||
|
||||
<a href="https://arxiv.org/abs/2005.05535">
|
||||
|
||||
<img src="https://static.arxiv.org/static/browse/0.3.0/images/icons/favicon.ico" width=14></img>
|
||||
https://arxiv.org/abs/2005.05535</a>
|
||||
<img src="doc/DFL_welcome.png" align="center">
|
||||
|
||||
</td></tr>
|
||||
<tr><td colspan=2 align="center">
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
<p align="center">
|
||||
|
||||

|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
</p>
|
||||
|
||||
More than 95% of deepfake videos are created with DeepFaceLab.
|
||||
|
||||
DeepFaceLab is used by such popular youtube channels as
|
||||
|
||||
| [deeptomcruise](https://www.tiktok.com/@deeptomcruise)| [1facerussia](https://www.tiktok.com/@1facerussia)| [arnoldschwarzneggar](https://www.tiktok.com/@arnoldschwarzneggar)
|
||||
|---|---|---|
|
||||
| [Ctrl Shift Face](https://www.youtube.com/channel/UCKpH0CKltc73e4wh0_pgL3g)| [VFXChris Ume](https://www.youtube.com/channel/UCGf4OlX_aTt8DlrgiH3jN3g/videos)|
|
||||
|---|---|
|
||||
|
||||
| [mariahcareyathome?](https://www.tiktok.com/@mariahcareyathome?)| [diepnep](https://www.tiktok.com/@diepnep)| [mr__heisenberg](https://www.tiktok.com/@mr__heisenberg)| [deepcaprio](https://www.tiktok.com/@deepcaprio)
|
||||
| [Sham00k](https://www.youtube.com/channel/UCZXbWcv7fSZFTAZV4beckyw/videos)| [Collider videos](https://www.youtube.com/watch?v=A91P2qtPT54&list=PLayt6616lBclvOprvrC8qKGCO-mAhPRux)| [iFake](https://www.youtube.com/channel/UCC0lK2Zo2BMXX-k1Ks0r7dg/videos)| [NextFace](https://www.youtube.com/channel/UCFh3gL0a8BS21g-DHvXZEeQ/videos)|
|
||||
|---|---|---|---|
|
||||
|
||||
| [VFXChris Ume](https://www.youtube.com/channel/UCGf4OlX_aTt8DlrgiH3jN3g/videos)| [Sham00k](https://www.youtube.com/channel/UCZXbWcv7fSZFTAZV4beckyw/videos)|
|
||||
|---|---|
|
||||
|
||||
| [Collider videos](https://www.youtube.com/watch?v=A91P2qtPT54&list=PLayt6616lBclvOprvrC8qKGCO-mAhPRux)| [iFake](https://www.youtube.com/channel/UCC0lK2Zo2BMXX-k1Ks0r7dg/videos)| [NextFace](https://www.youtube.com/channel/UCFh3gL0a8BS21g-DHvXZEeQ/videos)|
|
||||
|---|---|---|
|
||||
|
||||
| [Futuring Machine](https://www.youtube.com/channel/UCC5BbFxqLQgfnWPhprmQLVg)| [RepresentUS](https://www.youtube.com/channel/UCRzgK52MmetD9aG8pDOID3g)| [Corridor Crew](https://www.youtube.com/c/corridorcrew/videos)|
|
||||
|---|---|---|
|
||||
|
||||
| [DeepFaker](https://www.youtube.com/channel/UCkHecfDTcSazNZSKPEhtPVQ)| [DeepFakes in movie](https://www.youtube.com/c/DeepFakesinmovie/videos)|
|
||||
|---|---|
|
||||
|
||||
| [DeepFakeCreator](https://www.youtube.com/channel/UCkNFhcYNLQ5hr6A6lZ56mKA)| [Jarkan](https://www.youtube.com/user/Jarkancio/videos)|
|
||||
|---|---|
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
# What can I do using DeepFaceLab?
|
||||
|
||||
</td></tr>
|
||||
<tr><td colspan=2 align="center">
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
## Replace the face
|
||||
|
||||
<img src="doc/replace_the_face.jpg" align="center">
|
||||
<img src="doc/replace_the_face.png" align="center">
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
## De-age the face
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td align="center" width="50%">
|
||||
|
||||
<img src="doc/deage_0_1.jpg" align="center">
|
||||
|
||||
</td>
|
||||
<td align="center" width="50%">
|
||||
|
||||
<img src="doc/deage_0_2.jpg" align="center">
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
 https://www.youtube.com/watch?v=Ddx5B-84ebo
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
## Replace the head
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td align="center" width="50%">
|
||||
<table align="center" border="0"><tr>
|
||||
|
||||
<img src="doc/head_replace_1_1.jpg" align="center">
|
||||
<td align="center" width="9999">
|
||||
|
||||
<img src="doc/head_replace_1.jpg" align="center">
|
||||
|
||||
</td>
|
||||
<td align="center" width="50%">
|
||||
|
||||
<img src="doc/head_replace_1_2.jpg" align="center">
|
||||
<td align="center" width="9999">
|
||||
|
||||
<img src="doc/head_replace_2.jpg" align="center">
|
||||
|
||||
</td>
|
||||
|
||||
</table>
|
||||
|
||||
 https://www.youtube.com/watch?v=xr5FHd0AdlQ
|
||||
|
||||
</td></tr>
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
## Change the lip movement of politicians*
|
||||
|
||||
 https://www.youtube.com/watch?v=RTjgkhMugVw
|
||||
<img src="doc/political_speech.jpg" align="center">
|
||||
|
||||
 https://www.youtube.com/watch?v=2Z1oA3GYPaY
|
||||
|
||||
\* also requires a skill in video editors such as *Adobe After Effects* or *Davinci Resolve*
|
||||
</td></tr>
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
# Deepfake native resolution progress
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
# Native resolution progress
|
||||
|
||||
</td></tr>
|
||||
<tr><td colspan=2 align="center">
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
<img src="doc/deepfake_progress.png" align="center">
|
||||
|
||||
</td></tr>
|
||||
<tr><td colspan=2 align="center">
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
<img src="doc/make_everything_ok.png" align="center">
|
||||
|
||||
Unfortunately, there is no "make everything ok" button in DeepFaceLab. You should spend time studying the workflow and growing your skills. A skill in programs such as *AfterEffects* or *Davinci Resolve* is also desirable.
|
||||
Unfortunately, there is no "make everything ok" button in DeepFaceLab. You should spend time studying the workflow and growing your skills. A skill in programs such as *AfterEffects* or *Davince Resolve* is also desirable.
|
||||
|
||||
</td></tr>
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
## Mini tutorial
|
||||
|
||||
<a href="https://www.youtube.com/watch?v=kOIMXt8KK8M">
|
||||
|
||||
<img src="doc/mini_tutorial.jpg" align="center">
|
||||
|
||||
</a>
|
||||
|
||||
</td></tr>
|
||||
<tr><td colspan=2 align="center">
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
## Releases
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td align="right">
|
||||
<a href="https://tinyurl.com/2p9cvt25">Windows (magnet link)</a>
|
||||
</td><td align="center">Last release. Use torrent client to download.</td></tr>
|
||||
|
||||
<tr><td align="right">
|
||||
<a href="https://mega.nz/folder/Po0nGQrA#dbbttiNWojCt8jzD4xYaPw">Windows (Mega.nz)</a>
|
||||
</td><td align="center">Contains new and prev releases.</td></tr>
|
||||
|
||||
<tr><td align="right">
|
||||
<a href="https://disk.yandex.ru/d/7i5XTKIKVg5UUg">Windows (yandex.ru)</a>
|
||||
</td><td align="center">Contains new and prev releases.</td></tr>
|
||||
|
||||
<tr><td align="right">
|
||||
<a href="https://github.com/nagadit/DeepFaceLab_Linux">Linux (github)</a>
|
||||
</td><td align="center">by @nagadit</td></tr>
|
||||
|
||||
<tr><td align="right">
|
||||
<a href="https://github.com/elemantalcode/dfl">CentOS Linux (github)</a>
|
||||
</td><td align="center">May be outdated. By @elemantalcode</td></tr>
|
||||
|
||||
</table>
|
||||
|
||||
<table align="center" border="0">
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
### Communication groups
|
||||
||||
|
||||
|---|---|---|
|
||||
|Windows|[github releases](https://github.com/iperov/DeepFaceLab/releases)|Direct download|
|
||||
||[Google drive](https://drive.google.com/open?id=1BCFK_L7lPNwMbEQ_kFPqPpDdFEOd_Dci)|if the download quota is exceeded, add the file to your own google drive and download from it|
|
||||
|Google Colab|[github](https://github.com/chervonij/DFL-Colab)|by @chervonij . You can train fakes for free using Google Colab.|
|
||||
|CentOS Linux|[github](https://github.com/elemantalcode/dfl)|by @elemantalcode|
|
||||
|Linux|[github](https://github.com/lbfs/DeepFaceLab_Linux)|by @lbfs |
|
||||
||||
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td align="right">
|
||||
<a href="https://discord.gg/rxa7h9M6rH">Discord</a>
|
||||
</td><td align="center">Official discord channel. English / Russian.</td></tr>
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
## Links
|
||||
|
||||
## Related works
|
||||
|
||||
||||
|
||||
|---|---|---|
|
||||
|Guides and tutorials|||
|
||||
||[DeepFaceLab guide](https://mrdeepfakes.com/forums/thread-guide-deepfacelab-2-0-explained-and-tutorials-recommended)|Main guide|
|
||||
||[Faceset creation guide](https://mrdeepfakes.com/forums/thread-guide-celebrity-faceset-dataset-creation-how-to-create-celebrity-facesets)|How to create the right faceset |
|
||||
||[Google Colab guide](https://mrdeepfakes.com/forums/thread-guide-deepfacelab-google-colab-tutorial)|Guide how to train the fake on Google Colab|
|
||||
||[Compositing](https://mrdeepfakes.com/forums/thread-deepfacelab-2-0-compositing-in-davinci-resolve-vegas-pro-and-after-effects)|To achieve the highest quality, compose deepfake manually in video editors such as Davince Resolve or Adobe AfterEffects|
|
||||
||[Discussion and suggestions](https://mrdeepfakes.com/forums/thread-deepfacelab-2-0-discussion-tips-suggestions)||
|
||||
||||
|
||||
|Supplementary material|||
|
||||
||[Ready to work facesets](https://mrdeepfakes.com/forums/forum-celebrity-facesets)|Celebrity facesets made by community|
|
||||
||[Pretrained models](https://mrdeepfakes.com/forums/forum-celebrity-facesets)|Use pretrained models made by community to speed up training|
|
||||
||||
|
||||
|Communication groups|||
|
||||
||[telegram (English / Russian)](https://t.me/DeepFaceLab_official)|Don't forget to hide your phone number|
|
||||
||[telegram (English only)](https://t.me/DeepFaceLab_official_en)|Don't forget to hide your phone number|
|
||||
||[Русское сообщество](https://mrdeepfakes.com/forums/forum-russian-community)||
|
||||
||[mrdeepfakes](https://mrdeepfakes.com/forums/)|the biggest NSFW English community|
|
||||
||[reddit r/GifFakes/](https://www.reddit.com/r/GifFakes/new/)|Post your deepfakes there !|
|
||||
||[reddit r/SFWdeepfakes/](https://www.reddit.com/r/SFWdeepfakes/new/)|Post your deepfakes there !|
|
||||
||QQ 951138799|中文 Chinese QQ group for ML/AI experts|
|
||||
||[deepfaker.xyz](https://www.deepfaker.xyz/)|中文学习站(非官方|
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td align="right">
|
||||
<a href="https://github.com/iperov/DeepFaceLive">DeepFaceLive</a>
|
||||
</td><td align="center">Real-time face swap for PC streaming or video calls</td></tr>
|
||||
|
||||
</td></tr>
|
||||
</table>
|
||||
|
||||
<table align="center" border="0">
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
## How I can help the project?
|
||||
|
||||
|||
|
||||
|---|---|
|
||||
|Donate|Sponsor deepfake research and DeepFaceLab development.|
|
||||
||[Donate via Paypal](https://www.paypal.com/cgi-bin/webscr?cmd=_donations&business=lepersorium@gmail.com&lc=US&no_note=0&item_name=Support+DeepFaceLab&cn=&curency_code=USD&bn=PP-DonationsBF:btn_donateCC_LG.gif:NonHosted)
|
||||
||[Donate via Yandex.Money](https://money.yandex.ru/to/41001142318065)|
|
||||
||bitcoin:bc1qkhh7h0gwwhxgg6h6gpllfgstkd645fefrd5s6z|
|
||||
|Alipay 捐款||
|
||||
|||
|
||||
|Last donations|20$ ( 飛星工作室 )|
|
||||
||100$ ( Peter S. )|
|
||||
||50$ ( John Lee )|
|
||||
|||
|
||||
|Collect facesets|You can collect faceset of any celebrity that can be used in DeepFaceLab and share it [in the community](https://mrdeepfakes.com/forums/forum-celebrity-facesets)|
|
||||
|||
|
||||
|Star this repo|Register github account and push "Star" button.|
|
||||
</td></tr>
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
### Star this repo
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
Register github account and push "Star" button.
|
||||
|
||||
</td></tr>
|
||||
|
||||
</table>
|
||||
|
||||
<table align="center" border="0">
|
||||
<tr><td colspan=2 align="center">
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
## Meme zone
|
||||
<p align="center">
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
</td></tr>
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
<tr><td align="center" width="50%">
|
||||
|
||||
<img src="doc/meme1.jpg" align="center">
|
||||
|
||||
</td>
|
||||
|
||||
<td align="center" width="50%">
|
||||
|
||||
<img src="doc/meme2.jpg" align="center">
|
||||
<sub>#deepfacelab #deepfakes #faceswap #face-swap #deep-learning #deeplearning #deep-neural-networks #deepface #deep-face-swap #fakeapp #fake-app #neural-networks #neural-nets #tensorflow #cuda #nvidia</sub>
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
<sub>#deepfacelab #faceswap #face-swap #deep-learning #deeplearning #deep-neural-networks #deepface #deep-face-swap #neural-networks #neural-nets #tensorflow #cuda #nvidia</sub>
|
||||
|
||||
</td></tr>
|
||||
|
||||
|
||||
|
||||
</table>
|
||||
|
|
|
@ -17,10 +17,6 @@ class QIconDB():
|
|||
QIconDB.poly_type_exclude = QIcon ( str(icon_path / 'poly_type_exclude.png') )
|
||||
QIconDB.left = QIcon ( str(icon_path / 'left.png') )
|
||||
QIconDB.right = QIcon ( str(icon_path / 'right.png') )
|
||||
QIconDB.trashcan = QIcon ( str(icon_path / 'trashcan.png') )
|
||||
QIconDB.pt_edit_mode = QIcon ( str(icon_path / 'pt_edit_mode.png') )
|
||||
QIconDB.view_lock_center = QIcon ( str(icon_path / 'view_lock_center.png') )
|
||||
QIconDB.view_baked = QIcon ( str(icon_path / 'view_baked.png') )
|
||||
QIconDB.view_xseg = QIcon ( str(icon_path / 'view_xseg.png') )
|
||||
QIconDB.view_xseg_overlay = QIcon ( str(icon_path / 'view_xseg_overlay.png') )
|
||||
|
|
@ -1,8 +0,0 @@
|
|||
from PyQt5.QtCore import *
|
||||
from PyQt5.QtGui import *
|
||||
from PyQt5.QtWidgets import *
|
||||
|
||||
class QImageDB():
|
||||
@staticmethod
|
||||
def initialize(image_path):
|
||||
QImageDB.intro = QImage ( str(image_path / 'intro.png') )
|
|
@ -35,11 +35,6 @@ class QStringDB():
|
|||
'zh' : '查看导入后的XSeg遮罩',
|
||||
}[lang]
|
||||
|
||||
QStringDB.btn_view_xseg_overlay_mask_tip = { 'en' : 'View trained XSeg mask overlay face',
|
||||
'ru' : 'Посмотреть тренированную XSeg маску поверх лица',
|
||||
'zh' : '查看导入后的XSeg遮罩于脸上方',
|
||||
}[lang]
|
||||
|
||||
QStringDB.btn_poly_type_include_tip = { 'en' : 'Poly include mode',
|
||||
'ru' : 'Режим полигонов - включение',
|
||||
'zh' : '包含选区模式',
|
||||
|
@ -65,17 +60,11 @@ class QStringDB():
|
|||
'zh' : '删除选区',
|
||||
}[lang]
|
||||
|
||||
QStringDB.btn_pt_edit_mode_tip = { 'en' : 'Add/delete point mode ( HOLD CTRL )',
|
||||
'ru' : 'Режим добавления/удаления точек ( удерживайте CTRL )',
|
||||
'zh' : '点加/删除模式 ( 按住CTRL )',
|
||||
QStringDB.btn_pt_edit_mode_tip = { 'en' : 'Edit point mode ( HOLD CTRL )',
|
||||
'ru' : 'Режим правки точек',
|
||||
'zh' : '编辑点模式 ( 按住CTRL )',
|
||||
}[lang]
|
||||
|
||||
QStringDB.btn_view_lock_center_tip = { 'en' : 'Lock cursor at the center ( HOLD SHIFT )',
|
||||
'ru' : 'Заблокировать курсор в центре ( удерживайте SHIFT )',
|
||||
'zh' : '将光标锁定在中心 ( 按住SHIFT )',
|
||||
}[lang]
|
||||
|
||||
|
||||
QStringDB.btn_prev_image_tip = { 'en' : 'Save and Prev image\nHold SHIFT : accelerate\nHold CTRL : skip non masked\n',
|
||||
'ru' : 'Сохранить и предыдущее изображение\nУдерживать SHIFT : ускорить\nУдерживать CTRL : пропустить неразмеченные\n',
|
||||
'zh' : '保存并转到上一张图片\n按住SHIFT : 加快\n按住CTRL : 跳过未标记的\n',
|
||||
|
@ -85,18 +74,4 @@ class QStringDB():
|
|||
'zh' : '保存并转到下一张图片\n按住SHIFT : 加快\n按住CTRL : 跳过未标记的\n',
|
||||
}[lang]
|
||||
|
||||
QStringDB.btn_delete_image_tip = { 'en' : 'Move to _trash and Next image\n',
|
||||
'ru' : 'Переместить в _trash и следующее изображение\n',
|
||||
'zh' : '移至_trash,转到下一张图片 ',
|
||||
}[lang]
|
||||
|
||||
QStringDB.loading_tip = {'en' : 'Loading',
|
||||
'ru' : 'Загрузка',
|
||||
'zh' : '正在载入',
|
||||
}[lang]
|
||||
|
||||
QStringDB.labeled_tip = {'en' : 'labeled',
|
||||
'ru' : 'размечено',
|
||||
'zh' : '标记的',
|
||||
}[lang]
|
||||
|
||||
|
|
|
@ -16,18 +16,18 @@ from PyQt5.QtCore import *
|
|||
from PyQt5.QtGui import *
|
||||
from PyQt5.QtWidgets import *
|
||||
|
||||
from core import imagelib, pathex
|
||||
from core import pathex
|
||||
from core.cv2ex import *
|
||||
from core import imagelib
|
||||
from core.imagelib import SegIEPoly, SegIEPolys, SegIEPolyType, sd
|
||||
from core.qtex import *
|
||||
from DFLIMG import *
|
||||
from localization import StringsDB, system_language
|
||||
from samplelib import PackedFaceset
|
||||
|
||||
from .QCursorDB import QCursorDB
|
||||
from .QIconDB import QIconDB
|
||||
from .QStringDB import QStringDB
|
||||
from .QImageDB import QImageDB
|
||||
|
||||
|
||||
class OpMode(IntEnum):
|
||||
NONE = 0
|
||||
|
@ -45,10 +45,6 @@ class DragType(IntEnum):
|
|||
IMAGE_LOOK = 1
|
||||
POLY_PT = 2
|
||||
|
||||
class ViewLock(IntEnum):
|
||||
NONE = 0
|
||||
CENTER = 1
|
||||
|
||||
class QUIConfig():
|
||||
@staticmethod
|
||||
def initialize(icon_size = 48, icon_spacer_size=16, preview_bar_icon_size=64):
|
||||
|
@ -70,24 +66,12 @@ class ImagePreviewSequenceBar(QFrame):
|
|||
|
||||
main_frame_l_cont_hl = QGridLayout()
|
||||
main_frame_l_cont_hl.setContentsMargins(0,0,0,0)
|
||||
#main_frame_l_cont_hl.setSpacing(0)
|
||||
|
||||
|
||||
|
||||
for i in range(len(self.image_containers)):
|
||||
q_label = self.image_containers[i]
|
||||
q_label.setScaledContents(True)
|
||||
if i == preview_images_count//2:
|
||||
q_label.setMinimumSize(icon_size+16, icon_size+16 )
|
||||
q_label.setMaximumSize(icon_size+16, icon_size+16 )
|
||||
else:
|
||||
q_label.setMinimumSize(icon_size, icon_size )
|
||||
q_label.setMaximumSize(icon_size, icon_size )
|
||||
opacity_effect = QGraphicsOpacityEffect()
|
||||
opacity_effect.setOpacity(0.5)
|
||||
q_label.setGraphicsEffect(opacity_effect)
|
||||
|
||||
q_label.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed)
|
||||
q_label.setMinimumSize(icon_size, icon_size )
|
||||
q_label.setSizePolicy (QSizePolicy.Ignored, QSizePolicy.Ignored)
|
||||
|
||||
main_frame_l_cont_hl.addWidget (q_label, 0, i)
|
||||
|
||||
|
@ -101,33 +85,39 @@ class ImagePreviewSequenceBar(QFrame):
|
|||
def get_preview_images_count(self):
|
||||
return self.preview_images_count
|
||||
|
||||
def update_images(self, prev_imgs=None, next_imgs=None):
|
||||
def update_images(self, prev_q_imgs=None, next_q_imgs=None):
|
||||
# Fix arrays
|
||||
if prev_imgs is None:
|
||||
prev_imgs = []
|
||||
if prev_q_imgs is None:
|
||||
prev_q_imgs = []
|
||||
prev_img_conts_len = len(self.prev_img_conts)
|
||||
prev_q_imgs_len = len(prev_imgs)
|
||||
prev_q_imgs_len = len(prev_q_imgs)
|
||||
if prev_q_imgs_len < prev_img_conts_len:
|
||||
for i in range ( prev_img_conts_len - prev_q_imgs_len ):
|
||||
prev_imgs.append(None)
|
||||
prev_q_imgs.append(None)
|
||||
elif prev_q_imgs_len > prev_img_conts_len:
|
||||
prev_imgs = prev_imgs[:prev_img_conts_len]
|
||||
prev_q_imgs = prev_q_imgs[:prev_img_conts_len]
|
||||
|
||||
if next_imgs is None:
|
||||
next_imgs = []
|
||||
if next_q_imgs is None:
|
||||
next_q_imgs = []
|
||||
next_img_conts_len = len(self.next_img_conts)
|
||||
next_q_imgs_len = len(next_imgs)
|
||||
next_q_imgs_len = len(next_q_imgs)
|
||||
if next_q_imgs_len < next_img_conts_len:
|
||||
for i in range ( next_img_conts_len - next_q_imgs_len ):
|
||||
next_imgs.append(None)
|
||||
next_q_imgs.append(None)
|
||||
elif next_q_imgs_len > next_img_conts_len:
|
||||
next_imgs = next_imgs[:next_img_conts_len]
|
||||
next_q_imgs = next_q_imgs[:next_img_conts_len]
|
||||
|
||||
for i,img in enumerate(prev_imgs):
|
||||
self.prev_img_conts[i].setPixmap( QPixmap.fromImage( QImage_from_np(img) ) if img is not None else self.black_q_pixmap )
|
||||
for i,q_img in enumerate(prev_q_imgs):
|
||||
if q_img is None:
|
||||
self.prev_img_conts[i].setPixmap( self.black_q_pixmap )
|
||||
else:
|
||||
self.prev_img_conts[i].setPixmap( QPixmap.fromImage(q_img) )
|
||||
|
||||
for i,img in enumerate(next_imgs):
|
||||
self.next_img_conts[i].setPixmap( QPixmap.fromImage( QImage_from_np(img) ) if img is not None else self.black_q_pixmap )
|
||||
for i,q_img in enumerate(next_q_imgs):
|
||||
if q_img is None:
|
||||
self.next_img_conts[i].setPixmap( self.black_q_pixmap )
|
||||
else:
|
||||
self.next_img_conts[i].setPixmap( QPixmap.fromImage(q_img) )
|
||||
|
||||
class ColorScheme():
|
||||
def __init__(self, unselected_color, selected_color, outline_color, outline_width, pt_outline_color, cross_cursor):
|
||||
|
@ -197,7 +187,6 @@ class QCanvasControlsLeftBar(QFrame):
|
|||
self.btn_pt_edit_mode_act = QActionEx( QIconDB.pt_edit_mode, QStringDB.btn_pt_edit_mode_tip, shortcut_in_tooltip=True, is_checkable=True)
|
||||
btn_pt_edit_mode.setDefaultAction(self.btn_pt_edit_mode_act)
|
||||
btn_pt_edit_mode.setIconSize(QUIConfig.icon_q_size)
|
||||
#==============================================
|
||||
|
||||
controls_bar_frame2_l = QVBoxLayout()
|
||||
controls_bar_frame2_l.addWidget ( btn_poly_type_include )
|
||||
|
@ -262,11 +251,6 @@ class QCanvasControlsRightBar(QFrame):
|
|||
btn_view_xseg_mask.setDefaultAction(self.btn_view_xseg_mask_act)
|
||||
btn_view_xseg_mask.setIconSize(QUIConfig.icon_q_size)
|
||||
|
||||
btn_view_xseg_overlay_mask = QToolButton()
|
||||
self.btn_view_xseg_overlay_mask_act = QActionEx( QIconDB.view_xseg_overlay, QStringDB.btn_view_xseg_overlay_mask_tip, shortcut='`', shortcut_in_tooltip=True, is_checkable=True)
|
||||
btn_view_xseg_overlay_mask.setDefaultAction(self.btn_view_xseg_overlay_mask_act)
|
||||
btn_view_xseg_overlay_mask.setIconSize(QUIConfig.icon_q_size)
|
||||
|
||||
self.btn_poly_color_act_grp = QActionGroup (self)
|
||||
self.btn_poly_color_act_grp.addAction(self.btn_poly_color_red_act)
|
||||
self.btn_poly_color_act_grp.addAction(self.btn_poly_color_green_act)
|
||||
|
@ -275,17 +259,6 @@ class QCanvasControlsRightBar(QFrame):
|
|||
self.btn_poly_color_act_grp.addAction(self.btn_view_xseg_mask_act)
|
||||
self.btn_poly_color_act_grp.setExclusive(True)
|
||||
#==============================================
|
||||
btn_view_lock_center = QToolButton()
|
||||
self.btn_view_lock_center_act = QActionEx( QIconDB.view_lock_center, QStringDB.btn_view_lock_center_tip, shortcut_in_tooltip=True, is_checkable=True)
|
||||
btn_view_lock_center.setDefaultAction(self.btn_view_lock_center_act)
|
||||
btn_view_lock_center.setIconSize(QUIConfig.icon_q_size)
|
||||
|
||||
controls_bar_frame2_l = QVBoxLayout()
|
||||
controls_bar_frame2_l.addWidget ( btn_view_xseg_overlay_mask )
|
||||
controls_bar_frame2 = QFrame()
|
||||
controls_bar_frame2.setFrameShape(QFrame.StyledPanel)
|
||||
controls_bar_frame2.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed)
|
||||
controls_bar_frame2.setLayout(controls_bar_frame2_l)
|
||||
|
||||
controls_bar_frame1_l = QVBoxLayout()
|
||||
controls_bar_frame1_l.addWidget ( btn_poly_color_red )
|
||||
|
@ -298,18 +271,9 @@ class QCanvasControlsRightBar(QFrame):
|
|||
controls_bar_frame1.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed)
|
||||
controls_bar_frame1.setLayout(controls_bar_frame1_l)
|
||||
|
||||
controls_bar_frame3_l = QVBoxLayout()
|
||||
controls_bar_frame3_l.addWidget ( btn_view_lock_center )
|
||||
controls_bar_frame3 = QFrame()
|
||||
controls_bar_frame3.setFrameShape(QFrame.StyledPanel)
|
||||
controls_bar_frame3.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed)
|
||||
controls_bar_frame3.setLayout(controls_bar_frame3_l)
|
||||
|
||||
controls_bar_l = QVBoxLayout()
|
||||
controls_bar_l.setContentsMargins(0,0,0,0)
|
||||
controls_bar_l.addWidget(controls_bar_frame2)
|
||||
controls_bar_l.addWidget(controls_bar_frame1)
|
||||
controls_bar_l.addWidget(controls_bar_frame3)
|
||||
|
||||
self.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Expanding )
|
||||
self.setLayout(controls_bar_l)
|
||||
|
@ -324,10 +288,8 @@ class QCanvasOperator(QWidget):
|
|||
self.cbar.btn_poly_color_red_act.triggered.connect ( lambda : self.set_color_scheme_id(0) )
|
||||
self.cbar.btn_poly_color_green_act.triggered.connect ( lambda : self.set_color_scheme_id(1) )
|
||||
self.cbar.btn_poly_color_blue_act.triggered.connect ( lambda : self.set_color_scheme_id(2) )
|
||||
self.cbar.btn_view_baked_mask_act.triggered.connect ( lambda : self.set_op_mode(OpMode.VIEW_BAKED) )
|
||||
self.cbar.btn_view_xseg_mask_act.triggered.connect ( lambda : self.set_op_mode(OpMode.VIEW_XSEG_MASK) )
|
||||
|
||||
self.cbar.btn_view_xseg_overlay_mask_act.toggled.connect ( lambda is_checked: self.update() )
|
||||
self.cbar.btn_view_baked_mask_act.toggled.connect ( lambda : self.set_op_mode(OpMode.VIEW_BAKED) )
|
||||
self.cbar.btn_view_xseg_mask_act.toggled.connect ( self.set_view_xseg_mask )
|
||||
|
||||
self.cbar.btn_poly_type_include_act.triggered.connect ( lambda : self.set_poly_include_type(SegIEPolyType.INCLUDE) )
|
||||
self.cbar.btn_poly_type_exclude_act.triggered.connect ( lambda : self.set_poly_include_type(SegIEPolyType.EXCLUDE) )
|
||||
|
@ -338,7 +300,6 @@ class QCanvasOperator(QWidget):
|
|||
self.cbar.btn_delete_poly_act.triggered.connect ( lambda : self.action_delete_poly() )
|
||||
|
||||
self.cbar.btn_pt_edit_mode_act.toggled.connect ( lambda is_checked: self.set_pt_edit_mode( PTEditMode.ADD_DEL if is_checked else PTEditMode.MOVE ) )
|
||||
self.cbar.btn_view_lock_center_act.toggled.connect ( lambda is_checked: self.set_view_lock( ViewLock.CENTER if is_checked else ViewLock.NONE ) )
|
||||
|
||||
self.mouse_in_widget = False
|
||||
|
||||
|
@ -349,22 +310,16 @@ class QCanvasOperator(QWidget):
|
|||
self.initialized = False
|
||||
self.last_state = None
|
||||
|
||||
def initialize(self, img, img_look_pt=None, view_scale=None, ie_polys=None, xseg_mask=None, canvas_config=None ):
|
||||
q_img = self.q_img = QImage_from_np(img)
|
||||
def initialize(self, q_img, img_look_pt=None, view_scale=None, ie_polys=None, xseg_mask=None, canvas_config=None ):
|
||||
self.q_img = q_img
|
||||
self.img_pixmap = QPixmap.fromImage(q_img)
|
||||
|
||||
self.xseg_mask_pixmap = None
|
||||
self.xseg_overlay_mask_pixmap = None
|
||||
if xseg_mask is not None:
|
||||
h,w,c = img.shape
|
||||
xseg_mask = cv2.resize(xseg_mask, (w,h), interpolation=cv2.INTER_CUBIC)
|
||||
xseg_mask = imagelib.normalize_channels(xseg_mask, 1)
|
||||
xseg_img = img.astype(np.float32)/255.0
|
||||
xseg_overlay_mask = xseg_img*(1-xseg_mask)*0.5 + xseg_img*xseg_mask
|
||||
xseg_overlay_mask = np.clip(xseg_overlay_mask*255, 0, 255).astype(np.uint8)
|
||||
xseg_mask = np.clip(xseg_mask*255, 0, 255).astype(np.uint8)
|
||||
w,h = QSize_to_np ( q_img.size() )
|
||||
xseg_mask = cv2.resize(xseg_mask, (w,h), cv2.INTER_CUBIC)
|
||||
xseg_mask = (imagelib.normalize_channels(xseg_mask, 1) * 255).astype(np.uint8)
|
||||
self.xseg_mask_pixmap = QPixmap.fromImage(QImage_from_np(xseg_mask))
|
||||
self.xseg_overlay_mask_pixmap = QPixmap.fromImage(QImage_from_np(xseg_overlay_mask))
|
||||
|
||||
self.img_size = QSize_to_np (self.img_pixmap.size())
|
||||
|
||||
|
@ -382,7 +337,6 @@ class QCanvasOperator(QWidget):
|
|||
# UI init
|
||||
self.set_cbar_disabled()
|
||||
self.cbar.btn_poly_color_act_grp.setDisabled(False)
|
||||
self.cbar.btn_view_xseg_overlay_mask_act.setDisabled(False)
|
||||
self.cbar.btn_poly_type_act_grp.setDisabled(False)
|
||||
|
||||
# Initial vars
|
||||
|
@ -390,14 +344,12 @@ class QCanvasOperator(QWidget):
|
|||
self.mouse_hull_poly = None
|
||||
self.mouse_wire_poly = None
|
||||
self.drag_type = DragType.NONE
|
||||
self.mouse_cli_pt = np.zeros((2,), np.float32 )
|
||||
|
||||
# Initial state
|
||||
self.set_op_mode(OpMode.NONE)
|
||||
self.set_color_scheme_id(1)
|
||||
self.set_poly_include_type(SegIEPolyType.INCLUDE)
|
||||
self.set_pt_edit_mode(PTEditMode.MOVE)
|
||||
self.set_view_lock(ViewLock.NONE)
|
||||
|
||||
# Apply last state
|
||||
if self.last_state is not None:
|
||||
|
@ -418,7 +370,8 @@ class QCanvasOperator(QWidget):
|
|||
self.set_op_mode(OpMode.EDIT_PTS)
|
||||
|
||||
self.last_state = sn(op_mode = self.op_mode if self.op_mode in [OpMode.VIEW_BAKED, OpMode.VIEW_XSEG_MASK] else None,
|
||||
color_scheme_id = self.color_scheme_id)
|
||||
color_scheme_id = self.color_scheme_id,
|
||||
)
|
||||
|
||||
self.img_pixmap = None
|
||||
self.update_cursor(is_finalize=True)
|
||||
|
@ -433,15 +386,13 @@ class QCanvasOperator(QWidget):
|
|||
# ====================================== GETTERS =====================================
|
||||
# ====================================================================================
|
||||
# ====================================================================================
|
||||
|
||||
def is_initialized(self):
|
||||
return self.initialized
|
||||
|
||||
def get_ie_polys(self):
|
||||
return self.ie_polys
|
||||
|
||||
def get_cli_center_pt(self):
|
||||
return np.round(QSize_to_np(self.size())/2.0)
|
||||
|
||||
def get_img_look_pt(self):
|
||||
img_look_pt = self.img_look_pt
|
||||
if img_look_pt is None:
|
||||
|
@ -503,10 +454,10 @@ class QCanvasOperator(QWidget):
|
|||
return None
|
||||
|
||||
def img_to_cli_pt(self, p):
|
||||
return (p - self.get_img_look_pt()) * self.get_view_scale() + self.get_cli_center_pt()# QSize_to_np(self.size())/2.0
|
||||
return (p - self.get_img_look_pt()) * self.get_view_scale() + QSize_to_np(self.size())/2.0
|
||||
|
||||
def cli_to_img_pt(self, p):
|
||||
return (p - self.get_cli_center_pt() ) / self.get_view_scale() + self.get_img_look_pt()
|
||||
return (p - QSize_to_np(self.size())/2.0 ) / self.get_view_scale() + self.get_img_look_pt()
|
||||
|
||||
def img_to_cli_rect(self, rect):
|
||||
tl = QPoint_to_np(rect.topLeft())
|
||||
|
@ -531,13 +482,9 @@ class QCanvasOperator(QWidget):
|
|||
elif self.op_mode == OpMode.DRAW_PTS:
|
||||
self.cbar.btn_undo_pt_act.setDisabled(True)
|
||||
self.cbar.btn_redo_pt_act.setDisabled(True)
|
||||
self.cbar.btn_view_lock_center_act.setDisabled(True)
|
||||
# Reset view_lock when exit from DRAW_PTS
|
||||
self.set_view_lock(ViewLock.NONE)
|
||||
# Remove unfinished poly
|
||||
if self.op_poly.get_pts_count() < 3:
|
||||
# Remove unfinished poly
|
||||
self.ie_polys.remove_poly(self.op_poly)
|
||||
|
||||
elif self.op_mode == OpMode.EDIT_PTS:
|
||||
self.cbar.btn_pt_edit_mode_act.setDisabled(True)
|
||||
self.cbar.btn_delete_poly_act.setDisabled(True)
|
||||
|
@ -556,7 +503,6 @@ class QCanvasOperator(QWidget):
|
|||
elif op_mode == OpMode.DRAW_PTS:
|
||||
self.cbar.btn_undo_pt_act.setDisabled(False)
|
||||
self.cbar.btn_redo_pt_act.setDisabled(False)
|
||||
self.cbar.btn_view_lock_center_act.setDisabled(False)
|
||||
elif op_mode == OpMode.EDIT_PTS:
|
||||
self.cbar.btn_pt_edit_mode_act.setDisabled(False)
|
||||
self.cbar.btn_delete_poly_act.setDisabled(False)
|
||||
|
@ -570,12 +516,11 @@ class QCanvasOperator(QWidget):
|
|||
self.img_baked_pixmap = QPixmap.fromImage(QImage_from_np(n))
|
||||
elif op_mode == OpMode.VIEW_XSEG_MASK:
|
||||
self.cbar.btn_view_xseg_mask_act.setChecked(True)
|
||||
|
||||
if op_mode in [OpMode.DRAW_PTS, OpMode.EDIT_PTS]:
|
||||
self.mouse_op_poly_pt_id = None
|
||||
self.mouse_op_poly_edge_id = None
|
||||
self.mouse_op_poly_edge_id_pt = None
|
||||
|
||||
#
|
||||
self.op_poly = op_poly
|
||||
if op_poly is not None:
|
||||
self.update_mouse_info()
|
||||
|
@ -588,32 +533,19 @@ class QCanvasOperator(QWidget):
|
|||
self.pt_edit_mode = pt_edit_mode
|
||||
self.update_cursor()
|
||||
self.update()
|
||||
|
||||
self.cbar.btn_pt_edit_mode_act.setChecked( self.pt_edit_mode == PTEditMode.ADD_DEL )
|
||||
|
||||
def set_view_lock(self, view_lock):
|
||||
if not hasattr(self, 'view_lock') or self.view_lock != view_lock:
|
||||
if hasattr(self, 'view_lock') and self.view_lock != view_lock:
|
||||
if view_lock == ViewLock.CENTER:
|
||||
self.img_look_pt = self.mouse_img_pt
|
||||
QCursor.setPos ( self.mapToGlobal( QPoint_from_np(self.img_to_cli_pt(self.img_look_pt)) ))
|
||||
|
||||
self.view_lock = view_lock
|
||||
self.update()
|
||||
self.cbar.btn_view_lock_center_act.setChecked( self.view_lock == ViewLock.CENTER )
|
||||
|
||||
def set_cbar_disabled(self):
|
||||
self.cbar.btn_delete_poly_act.setDisabled(True)
|
||||
self.cbar.btn_undo_pt_act.setDisabled(True)
|
||||
self.cbar.btn_redo_pt_act.setDisabled(True)
|
||||
self.cbar.btn_pt_edit_mode_act.setDisabled(True)
|
||||
self.cbar.btn_view_lock_center_act.setDisabled(True)
|
||||
self.cbar.btn_poly_color_act_grp.setDisabled(True)
|
||||
self.cbar.btn_view_xseg_overlay_mask_act.setDisabled(True)
|
||||
self.cbar.btn_poly_type_act_grp.setDisabled(True)
|
||||
|
||||
|
||||
def set_color_scheme_id(self, id):
|
||||
if self.op_mode == OpMode.VIEW_BAKED or self.op_mode == OpMode.VIEW_XSEG_MASK:
|
||||
if self.op_mode == OpMode.VIEW_BAKED:
|
||||
self.set_op_mode(OpMode.NONE)
|
||||
|
||||
if not hasattr(self, 'color_scheme_id') or self.color_scheme_id != id:
|
||||
|
@ -634,9 +566,29 @@ class QCanvasOperator(QWidget):
|
|||
self.op_mode in [OpMode.NONE, OpMode.EDIT_PTS] ):
|
||||
self.poly_include_type = poly_include_type
|
||||
self.update()
|
||||
|
||||
self.cbar.btn_poly_type_include_act.setChecked(self.poly_include_type == SegIEPolyType.INCLUDE)
|
||||
self.cbar.btn_poly_type_exclude_act.setChecked(self.poly_include_type == SegIEPolyType.EXCLUDE)
|
||||
|
||||
def set_view_xseg_mask(self, is_checked):
|
||||
if is_checked:
|
||||
self.set_op_mode(OpMode.VIEW_XSEG_MASK)
|
||||
|
||||
#n = QImage_to_np ( self.q_img ).astype(np.float32) / 255.0
|
||||
#h,w,c = n.shape
|
||||
|
||||
#mask = np.zeros( (h,w,1), dtype=np.float32 )
|
||||
#self.ie_polys.overlay_mask(mask)
|
||||
|
||||
#n = (mask*255).astype(np.uint8)
|
||||
|
||||
#self.img_baked_pixmap = QPixmap.fromImage(QImage_from_np(n))
|
||||
else:
|
||||
self.set_op_mode(OpMode.NONE)
|
||||
|
||||
self.cbar.btn_view_xseg_mask_act.setChecked(is_checked )
|
||||
|
||||
|
||||
# ====================================================================================
|
||||
# ====================================================================================
|
||||
# ====================================== METHODS =====================================
|
||||
|
@ -762,9 +714,7 @@ class QCanvasOperator(QWidget):
|
|||
return
|
||||
key = ev.key()
|
||||
key_mods = int(ev.modifiers())
|
||||
if self.op_mode == OpMode.DRAW_PTS:
|
||||
self.set_view_lock(ViewLock.CENTER if key_mods == Qt.ShiftModifier else ViewLock.NONE )
|
||||
elif self.op_mode == OpMode.EDIT_PTS:
|
||||
if self.op_mode == OpMode.EDIT_PTS:
|
||||
self.set_pt_edit_mode(PTEditMode.ADD_DEL if key_mods == Qt.ControlModifier else PTEditMode.MOVE )
|
||||
|
||||
def on_keyReleaseEvent(self, ev):
|
||||
|
@ -772,9 +722,7 @@ class QCanvasOperator(QWidget):
|
|||
return
|
||||
key = ev.key()
|
||||
key_mods = int(ev.modifiers())
|
||||
if self.op_mode == OpMode.DRAW_PTS:
|
||||
self.set_view_lock(ViewLock.CENTER if key_mods == Qt.ShiftModifier else ViewLock.NONE )
|
||||
elif self.op_mode == OpMode.EDIT_PTS:
|
||||
if self.op_mode == OpMode.EDIT_PTS:
|
||||
self.set_pt_edit_mode(PTEditMode.ADD_DEL if key_mods == Qt.ControlModifier else PTEditMode.MOVE )
|
||||
|
||||
def enterEvent(self, ev):
|
||||
|
@ -825,10 +773,10 @@ class QCanvasOperator(QWidget):
|
|||
if self.mouse_op_poly_pt_id is not None:
|
||||
# Click on point of op_poly
|
||||
if self.pt_edit_mode == PTEditMode.ADD_DEL:
|
||||
# in mode 'delete point'
|
||||
# with mode -> delete point
|
||||
self.op_poly.remove_pt(self.mouse_op_poly_pt_id)
|
||||
if self.op_poly.get_pts_count() < 3:
|
||||
# not enough points after delete -> remove poly
|
||||
# not enough points -> remove poly
|
||||
self.ie_polys.remove_poly (self.op_poly)
|
||||
self.set_op_mode(OpMode.NONE)
|
||||
self.update()
|
||||
|
@ -842,7 +790,7 @@ class QCanvasOperator(QWidget):
|
|||
elif self.mouse_op_poly_edge_id is not None:
|
||||
# Click on edge of op_poly
|
||||
if self.pt_edit_mode == PTEditMode.ADD_DEL:
|
||||
# in mode 'insert new point'
|
||||
# with mode -> insert new point
|
||||
edge_img_pt = self.cli_to_img_pt(self.mouse_op_poly_edge_id_pt)
|
||||
self.op_poly.insert_pt (self.mouse_op_poly_edge_id+1, edge_img_pt)
|
||||
self.update()
|
||||
|
@ -888,15 +836,8 @@ class QCanvasOperator(QWidget):
|
|||
if not self.initialized:
|
||||
return
|
||||
|
||||
prev_mouse_cli_pt = self.mouse_cli_pt
|
||||
self.update_mouse_info(QPoint_to_np(ev.pos()))
|
||||
|
||||
if self.view_lock == ViewLock.CENTER:
|
||||
if npla.norm(self.mouse_cli_pt - prev_mouse_cli_pt) >= 1:
|
||||
self.img_look_pt = self.mouse_img_pt
|
||||
QCursor.setPos ( self.mapToGlobal( QPoint_from_np(self.img_to_cli_pt(self.img_look_pt)) ))
|
||||
self.update()
|
||||
|
||||
if self.drag_type == DragType.IMAGE_LOOK:
|
||||
delta_pt = self.cli_to_img_pt(self.mouse_cli_pt) - self.cli_to_img_pt(self.drag_cli_pt)
|
||||
self.img_look_pt = self.drag_img_look_pt - delta_pt
|
||||
|
@ -905,7 +846,9 @@ class QCanvasOperator(QWidget):
|
|||
if self.op_mode == OpMode.DRAW_PTS:
|
||||
self.update()
|
||||
elif self.op_mode == OpMode.EDIT_PTS:
|
||||
|
||||
if self.drag_type == DragType.POLY_PT:
|
||||
|
||||
delta_pt = self.cli_to_img_pt(self.mouse_cli_pt) - self.cli_to_img_pt(self.drag_cli_pt)
|
||||
self.op_poly.set_point(self.drag_poly_pt_id, self.drag_poly_pt + delta_pt)
|
||||
self.update()
|
||||
|
@ -944,19 +887,20 @@ class QCanvasOperator(QWidget):
|
|||
qp.setRenderHint(QPainter.HighQualityAntialiasing)
|
||||
qp.setRenderHint(QPainter.SmoothPixmapTransform)
|
||||
|
||||
src_rect = QRect(0, 0, *self.img_size)
|
||||
dst_rect = self.img_to_cli_rect( src_rect )
|
||||
|
||||
if self.op_mode == OpMode.VIEW_BAKED:
|
||||
|
||||
src_rect = QRect(0, 0, *self.img_size)
|
||||
dst_rect = self.img_to_cli_rect( src_rect )
|
||||
qp.drawPixmap(dst_rect, self.img_baked_pixmap, src_rect)
|
||||
elif self.op_mode == OpMode.VIEW_XSEG_MASK:
|
||||
if self.xseg_mask_pixmap is not None:
|
||||
src_rect = QRect(0, 0, *self.img_size)
|
||||
dst_rect = self.img_to_cli_rect( src_rect )
|
||||
qp.drawPixmap(dst_rect, self.xseg_mask_pixmap, src_rect)
|
||||
else:
|
||||
if self.cbar.btn_view_xseg_overlay_mask_act.isChecked() and \
|
||||
self.xseg_overlay_mask_pixmap is not None:
|
||||
qp.drawPixmap(dst_rect, self.xseg_overlay_mask_pixmap, src_rect)
|
||||
elif self.img_pixmap is not None:
|
||||
if self.img_pixmap is not None:
|
||||
src_rect = QRect(0, 0, *self.img_size)
|
||||
dst_rect = self.img_to_cli_rect( src_rect )
|
||||
qp.drawPixmap(dst_rect, self.img_pixmap, src_rect)
|
||||
|
||||
polys = self.ie_polys.get_polys()
|
||||
|
@ -1051,9 +995,9 @@ class QCanvasOperator(QWidget):
|
|||
if op_mode == OpMode.NONE:
|
||||
if poly == self.mouse_wire_poly:
|
||||
qp.setBrush(color_scheme.poly_selected_brush)
|
||||
#else:
|
||||
# if poly == op_poly:
|
||||
# qp.setBrush(color_scheme.poly_selected_brush)
|
||||
else:
|
||||
if poly == op_poly:
|
||||
qp.setBrush(color_scheme.poly_selected_brush)
|
||||
|
||||
qp.drawPath(poly_line_path)
|
||||
|
||||
|
@ -1066,6 +1010,7 @@ class QCanvasOperator(QWidget):
|
|||
|
||||
qp.end()
|
||||
|
||||
|
||||
class QCanvas(QFrame):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -1078,16 +1023,17 @@ class QCanvas(QFrame):
|
|||
btn_poly_color_blue_act = self.canvas_control_right_bar.btn_poly_color_blue_act,
|
||||
btn_view_baked_mask_act = self.canvas_control_right_bar.btn_view_baked_mask_act,
|
||||
btn_view_xseg_mask_act = self.canvas_control_right_bar.btn_view_xseg_mask_act,
|
||||
btn_view_xseg_overlay_mask_act = self.canvas_control_right_bar.btn_view_xseg_overlay_mask_act,
|
||||
btn_poly_color_act_grp = self.canvas_control_right_bar.btn_poly_color_act_grp,
|
||||
btn_view_lock_center_act = self.canvas_control_right_bar.btn_view_lock_center_act,
|
||||
|
||||
btn_poly_type_include_act = self.canvas_control_left_bar.btn_poly_type_include_act,
|
||||
btn_poly_type_exclude_act = self.canvas_control_left_bar.btn_poly_type_exclude_act,
|
||||
btn_poly_type_act_grp = self.canvas_control_left_bar.btn_poly_type_act_grp,
|
||||
|
||||
btn_undo_pt_act = self.canvas_control_left_bar.btn_undo_pt_act,
|
||||
btn_redo_pt_act = self.canvas_control_left_bar.btn_redo_pt_act,
|
||||
|
||||
btn_delete_poly_act = self.canvas_control_left_bar.btn_delete_poly_act,
|
||||
|
||||
btn_pt_edit_mode_act = self.canvas_control_left_bar.btn_pt_edit_mode_act )
|
||||
|
||||
self.op = QCanvasOperator(cbar)
|
||||
|
@ -1122,7 +1068,7 @@ class LoaderQSubprocessor(QSubprocessor):
|
|||
if len (self.idxs) > 0:
|
||||
idx = self.idxs.pop(0)
|
||||
image_path = self.image_paths[idx]
|
||||
self.q_label.setText(f'{QStringDB.loading_tip}... {image_path.name}')
|
||||
self.q_label.setText(f'{image_path.name}')
|
||||
|
||||
return idx, image_path
|
||||
|
||||
|
@ -1155,22 +1101,18 @@ class LoaderQSubprocessor(QSubprocessor):
|
|||
return idx, True, ie_polys.has_polys()
|
||||
return idx, False, False
|
||||
|
||||
|
||||
class MainWindow(QXMainWindow):
|
||||
|
||||
def __init__(self, input_dirpath, cfg_root_path):
|
||||
self.loading_frame = None
|
||||
self.help_frame = None
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.input_dirpath = input_dirpath
|
||||
self.trash_dirpath = input_dirpath.parent / (input_dirpath.name + '_trash')
|
||||
self.cfg_root_path = cfg_root_path
|
||||
|
||||
self.cfg_path = cfg_root_path / 'MainWindow_cfg.dat'
|
||||
self.cfg_dict = pickle.loads(self.cfg_path.read_bytes()) if self.cfg_path.exists() else {}
|
||||
|
||||
self.cached_images = {}
|
||||
self.cached_QImages = {}
|
||||
self.cached_has_ie_polys = {}
|
||||
|
||||
self.initialize_ui()
|
||||
|
@ -1181,20 +1123,9 @@ class MainWindow(QXMainWindow):
|
|||
self.loading_frame.setFrameShape(QFrame.StyledPanel)
|
||||
self.loader_label = QLabel()
|
||||
self.loader_progress_bar = QProgressBar()
|
||||
|
||||
intro_image = QLabel()
|
||||
intro_image.setPixmap( QPixmap.fromImage(QImageDB.intro) )
|
||||
|
||||
intro_image_frame_l = QVBoxLayout()
|
||||
intro_image_frame_l.addWidget(intro_image, alignment=Qt.AlignCenter)
|
||||
intro_image_frame = QFrame()
|
||||
intro_image_frame.setSizePolicy (QSizePolicy.Expanding, QSizePolicy.Expanding)
|
||||
intro_image_frame.setLayout(intro_image_frame_l)
|
||||
|
||||
loading_frame_l = QVBoxLayout()
|
||||
loading_frame_l.addWidget (intro_image_frame)
|
||||
loading_frame_l.addWidget (self.loader_label)
|
||||
loading_frame_l.addWidget (self.loader_progress_bar)
|
||||
loading_frame_l.addWidget (self.loader_label, alignment=Qt.AlignBottom)
|
||||
loading_frame_l.addWidget (self.loader_progress_bar, alignment=Qt.AlignTop)
|
||||
self.loading_frame.setLayout(loading_frame_l)
|
||||
|
||||
self.loader_subprocessor = LoaderQSubprocessor( image_paths=pathex.get_image_paths(input_dirpath, return_Path_class=True),
|
||||
|
@ -1207,7 +1138,6 @@ class MainWindow(QXMainWindow):
|
|||
self.image_paths_done = []
|
||||
self.image_paths = image_paths
|
||||
self.image_paths_has_ie_polys = image_paths_has_ie_polys
|
||||
self.set_has_ie_polys_count ( len([ 1 for x in self.image_paths_has_ie_polys if self.image_paths_has_ie_polys[x] == True]) )
|
||||
self.loading_frame.hide()
|
||||
self.loading_frame = None
|
||||
|
||||
|
@ -1219,7 +1149,7 @@ class MainWindow(QXMainWindow):
|
|||
|
||||
|
||||
def update_cached_images (self, count=5):
|
||||
d = self.cached_images
|
||||
d = self.cached_QImages
|
||||
|
||||
for image_path in self.image_paths_done[:-count]+self.image_paths[count:]:
|
||||
if image_path in d:
|
||||
|
@ -1229,14 +1159,13 @@ class MainWindow(QXMainWindow):
|
|||
if image_path not in d:
|
||||
img = cv2_imread(image_path)
|
||||
if img is not None:
|
||||
d[image_path] = img
|
||||
d[image_path] = QImage_from_np(img)
|
||||
|
||||
def load_image(self, image_path):
|
||||
def load_QImage(self, image_path):
|
||||
try:
|
||||
img = self.cached_images.get(image_path, None)
|
||||
img = self.cached_QImages.get(image_path, None)
|
||||
if img is None:
|
||||
img = cv2_imread(image_path)
|
||||
self.cached_images[image_path] = img
|
||||
img = QImage_from_np(cv2_imread(image_path))
|
||||
if img is None:
|
||||
io.log_err(f'Unable to load {image_path}')
|
||||
except:
|
||||
|
@ -1246,10 +1175,10 @@ class MainWindow(QXMainWindow):
|
|||
|
||||
def update_preview_bar(self):
|
||||
count = self.image_bar.get_preview_images_count()
|
||||
d = self.cached_images
|
||||
prev_imgs = [ d.get(image_path, None) for image_path in self.image_paths_done[-1:-count:-1] ]
|
||||
next_imgs = [ d.get(image_path, None) for image_path in self.image_paths[:count] ]
|
||||
self.image_bar.update_images(prev_imgs, next_imgs)
|
||||
d = self.cached_QImages
|
||||
prev_q_imgs = [ d.get(image_path, None) for image_path in self.image_paths_done[-1:-count:-1] ]
|
||||
next_q_imgs = [ d.get(image_path, None) for image_path in self.image_paths[:count] ]
|
||||
self.image_bar.update_images(prev_q_imgs, next_q_imgs)
|
||||
|
||||
|
||||
def canvas_initialize(self, image_path, only_has_polys=False):
|
||||
|
@ -1262,13 +1191,13 @@ class MainWindow(QXMainWindow):
|
|||
|
||||
ie_polys = dflimg.get_seg_ie_polys()
|
||||
xseg_mask = dflimg.get_xseg_mask()
|
||||
img = self.load_image(image_path)
|
||||
if img is None:
|
||||
q_img = self.load_QImage(image_path)
|
||||
if q_img is None:
|
||||
return False
|
||||
|
||||
self.canvas.op.initialize ( img, ie_polys=ie_polys, xseg_mask=xseg_mask )
|
||||
self.canvas.op.initialize ( q_img, ie_polys=ie_polys, xseg_mask=xseg_mask )
|
||||
|
||||
self.filename_label.setText(f"{image_path.name}")
|
||||
self.filename_label.setText(str(image_path.name))
|
||||
|
||||
return True
|
||||
|
||||
|
@ -1281,19 +1210,11 @@ class MainWindow(QXMainWindow):
|
|||
new_ie_polys = self.canvas.op.get_ie_polys()
|
||||
|
||||
if not new_ie_polys.identical(ie_polys):
|
||||
prev_has_polys = self.image_paths_has_ie_polys[image_path]
|
||||
self.image_paths_has_ie_polys[image_path] = new_ie_polys.has_polys()
|
||||
new_has_polys = self.image_paths_has_ie_polys[image_path]
|
||||
|
||||
if not prev_has_polys and new_has_polys:
|
||||
self.set_has_ie_polys_count ( self.get_has_ie_polys_count() +1)
|
||||
elif prev_has_polys and not new_has_polys:
|
||||
self.set_has_ie_polys_count ( self.get_has_ie_polys_count() -1)
|
||||
|
||||
dflimg.set_seg_ie_polys( new_ie_polys )
|
||||
dflimg.save()
|
||||
|
||||
self.filename_label.setText(f"")
|
||||
self.filename_label.setText("")
|
||||
|
||||
def process_prev_image(self):
|
||||
key_mods = QApplication.keyboardModifiers()
|
||||
|
@ -1343,17 +1264,6 @@ class MainWindow(QXMainWindow):
|
|||
self.update_cached_images()
|
||||
self.update_preview_bar()
|
||||
|
||||
def trash_current_image(self):
|
||||
self.process_next_image()
|
||||
|
||||
img_path = self.image_paths_done.pop(-1)
|
||||
img_path = Path(img_path)
|
||||
self.trash_dirpath.mkdir(parents=True, exist_ok=True)
|
||||
img_path.rename( self.trash_dirpath / img_path.name )
|
||||
|
||||
self.update_cached_images()
|
||||
self.update_preview_bar()
|
||||
|
||||
def initialize_ui(self):
|
||||
|
||||
self.canvas = QCanvas()
|
||||
|
@ -1368,59 +1278,33 @@ class MainWindow(QXMainWindow):
|
|||
btn_next_image = QXIconButton(QIconDB.right, QStringDB.btn_next_image_tip, shortcut='D', click_func=self.process_next_image)
|
||||
btn_next_image.setIconSize(QUIConfig.preview_bar_icon_q_size)
|
||||
|
||||
btn_delete_image = QXIconButton(QIconDB.trashcan, QStringDB.btn_delete_image_tip, shortcut='X', click_func=self.trash_current_image)
|
||||
btn_delete_image.setIconSize(QUIConfig.preview_bar_icon_q_size)
|
||||
|
||||
pad_image = QWidget()
|
||||
pad_image.setFixedSize(QUIConfig.preview_bar_icon_q_size)
|
||||
|
||||
preview_image_bar_frame_l = QHBoxLayout()
|
||||
preview_image_bar_frame_l.setContentsMargins(0,0,0,0)
|
||||
preview_image_bar_frame_l.addWidget ( pad_image, alignment=Qt.AlignCenter)
|
||||
preview_image_bar_frame_l.addWidget ( btn_prev_image, alignment=Qt.AlignCenter)
|
||||
preview_image_bar_frame_l.addWidget ( image_bar)
|
||||
preview_image_bar_frame_l.addWidget ( btn_next_image, alignment=Qt.AlignCenter)
|
||||
#preview_image_bar_frame_l.addWidget ( btn_delete_image, alignment=Qt.AlignCenter)
|
||||
|
||||
preview_image_bar_frame = QFrame()
|
||||
preview_image_bar_frame.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Fixed )
|
||||
preview_image_bar_frame.setLayout(preview_image_bar_frame_l)
|
||||
|
||||
preview_image_bar_frame2_l = QHBoxLayout()
|
||||
preview_image_bar_frame2_l.setContentsMargins(0,0,0,0)
|
||||
preview_image_bar_frame2_l.addWidget ( btn_delete_image, alignment=Qt.AlignCenter)
|
||||
|
||||
preview_image_bar_frame2 = QFrame()
|
||||
preview_image_bar_frame2.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Fixed )
|
||||
preview_image_bar_frame2.setLayout(preview_image_bar_frame2_l)
|
||||
|
||||
preview_image_bar_l = QHBoxLayout()
|
||||
preview_image_bar_l.addWidget (preview_image_bar_frame, alignment=Qt.AlignCenter)
|
||||
preview_image_bar_l.addWidget (preview_image_bar_frame2)
|
||||
preview_image_bar_l.addWidget (preview_image_bar_frame)
|
||||
|
||||
preview_image_bar = QFrame()
|
||||
preview_image_bar.setFrameShape(QFrame.StyledPanel)
|
||||
preview_image_bar.setSizePolicy ( QSizePolicy.Expanding, QSizePolicy.Fixed )
|
||||
preview_image_bar.setLayout(preview_image_bar_l)
|
||||
|
||||
label_font = QFont('Courier New')
|
||||
self.filename_label = QLabel()
|
||||
self.filename_label.setFont(label_font)
|
||||
|
||||
self.has_ie_polys_count_label = QLabel()
|
||||
|
||||
status_frame_l = QHBoxLayout()
|
||||
status_frame_l.setContentsMargins(0,0,0,0)
|
||||
status_frame_l.addWidget ( QLabel(), alignment=Qt.AlignCenter)
|
||||
status_frame_l.addWidget (self.filename_label, alignment=Qt.AlignCenter)
|
||||
status_frame_l.addWidget (self.has_ie_polys_count_label, alignment=Qt.AlignCenter)
|
||||
status_frame = QFrame()
|
||||
status_frame.setLayout(status_frame_l)
|
||||
f = QFont('Courier New')
|
||||
self.filename_label.setFont(f)
|
||||
|
||||
main_canvas_l = QVBoxLayout()
|
||||
main_canvas_l.setContentsMargins(0,0,0,0)
|
||||
main_canvas_l.addWidget (self.canvas)
|
||||
main_canvas_l.addWidget (status_frame)
|
||||
main_canvas_l.addWidget (self.filename_label, alignment=Qt.AlignCenter)
|
||||
main_canvas_l.addWidget (preview_image_bar)
|
||||
|
||||
self.main_canvas_frame = QFrame()
|
||||
|
@ -1438,29 +1322,11 @@ class MainWindow(QXMainWindow):
|
|||
else:
|
||||
self.move( QPoint(0,0))
|
||||
|
||||
def get_has_ie_polys_count(self):
|
||||
return self.has_ie_polys_count
|
||||
|
||||
def set_has_ie_polys_count(self, c):
|
||||
self.has_ie_polys_count = c
|
||||
self.has_ie_polys_count_label.setText(f"{c} {QStringDB.labeled_tip}")
|
||||
|
||||
def resizeEvent(self, ev):
|
||||
if self.loading_frame is not None:
|
||||
self.loading_frame.resize( ev.size() )
|
||||
if self.help_frame is not None:
|
||||
self.help_frame.resize( ev.size() )
|
||||
|
||||
def start(input_dirpath):
|
||||
"""
|
||||
returns exit_code
|
||||
"""
|
||||
io.log_info("Running XSeg editor.")
|
||||
|
||||
if PackedFaceset.path_contains(input_dirpath):
|
||||
io.log_info (f'\n{input_dirpath} contains packed faceset! Unpack it first.\n')
|
||||
return 1
|
||||
|
||||
root_path = Path(__file__).parent
|
||||
cfg_root_path = Path(tempfile.gettempdir())
|
||||
|
||||
|
@ -1480,7 +1346,6 @@ def start(input_dirpath):
|
|||
|
||||
QIconDB.initialize( root_path / 'gfx' / 'icons' )
|
||||
QCursorDB.initialize( root_path / 'gfx' / 'cursors' )
|
||||
QImageDB.initialize( root_path / 'gfx' / 'images' )
|
||||
|
||||
app.setWindowIcon(QIconDB.app_icon)
|
||||
app.setPalette( QDarkPalette() )
|
||||
|
@ -1491,4 +1356,3 @@ def start(input_dirpath):
|
|||
win.raise_()
|
||||
|
||||
app.exec_()
|
||||
return 0
|
||||
|
|
Before Width: | Height: | Size: 3.2 KiB |
Before Width: | Height: | Size: 4 KiB |
Before Width: | Height: | Size: 12 KiB |
Before Width: | Height: | Size: 30 KiB |
|
@ -2,7 +2,6 @@ import cv2
|
|||
import numpy as np
|
||||
from pathlib import Path
|
||||
from core.interact import interact as io
|
||||
from core import imagelib
|
||||
import traceback
|
||||
|
||||
def cv2_imread(filename, flags=cv2.IMREAD_UNCHANGED, loader_func=None, verbose=True):
|
||||
|
@ -30,11 +29,3 @@ def cv2_imwrite(filename, img, *args):
|
|||
stream.write( buf )
|
||||
except:
|
||||
pass
|
||||
|
||||
def cv2_resize(x, *args, **kwargs):
|
||||
h,w,c = x.shape
|
||||
x = cv2.resize(x, *args, **kwargs)
|
||||
|
||||
x = imagelib.normalize_channels(x, c)
|
||||
return x
|
||||
|
|
@ -77,8 +77,6 @@ class SegIEPoly():
|
|||
self.pts = np.array(pts)
|
||||
self.n_max = self.n = len(pts)
|
||||
|
||||
def mult_points(self, val):
|
||||
self.pts *= val
|
||||
|
||||
|
||||
|
||||
|
@ -139,10 +137,6 @@ class SegIEPolys():
|
|||
def dump(self):
|
||||
return {'polys' : [ poly.dump() for poly in self.polys ] }
|
||||
|
||||
def mult_points(self, val):
|
||||
for poly in self.polys:
|
||||
poly.mult_points(val)
|
||||
|
||||
@staticmethod
|
||||
def load(data=None):
|
||||
ie_polys = SegIEPolys()
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from .estimate_sharpness import estimate_sharpness
|
||||
|
||||
from .equalize_and_stack_square import equalize_and_stack_square
|
||||
|
||||
from .text import get_text_image, get_draw_text_lines
|
||||
|
@ -12,21 +11,16 @@ from .warp import gen_warp_params, warp_by_params
|
|||
|
||||
from .reduce_colors import reduce_colors
|
||||
|
||||
from .color_transfer import color_transfer, color_transfer_mix, color_transfer_sot, color_transfer_mkl, color_transfer_idt, color_hist_match, reinhard_color_transfer, linear_color_transfer
|
||||
from .color_transfer import color_transfer, color_transfer_mix, color_transfer_sot, color_transfer_mkl, color_transfer_idt, color_hist_match, reinhard_color_transfer, linear_color_transfer, seamless_clone
|
||||
|
||||
from .common import random_crop, normalize_channels, cut_odd_image, overlay_alpha_image
|
||||
from .common import normalize_channels, cut_odd_image, overlay_alpha_image
|
||||
|
||||
from .SegIEPolys import *
|
||||
|
||||
from .blursharpen import LinearMotionBlur, blursharpen
|
||||
|
||||
from .filters import apply_random_rgb_levels, \
|
||||
apply_random_overlay_triangle, \
|
||||
apply_random_hsv_shift, \
|
||||
apply_random_sharpen, \
|
||||
apply_random_motion_blur, \
|
||||
apply_random_gaussian_blur, \
|
||||
apply_random_nearest_resize, \
|
||||
apply_random_bilinear_resize, \
|
||||
apply_random_jpeg_compress, \
|
||||
apply_random_relight
|
||||
apply_random_bilinear_resize
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import cv2
|
||||
import numexpr as ne
|
||||
import numpy as np
|
||||
import scipy as sp
|
||||
from numpy import linalg as npla
|
||||
|
||||
import scipy as sp
|
||||
import scipy.sparse
|
||||
from scipy.sparse.linalg import spsolve
|
||||
|
||||
def color_transfer_sot(src,trg, steps=10, batch_size=5, reg_sigmaXY=16.0, reg_sigmaV=5.0):
|
||||
"""
|
||||
|
@ -34,9 +35,8 @@ def color_transfer_sot(src,trg, steps=10, batch_size=5, reg_sigmaXY=16.0, reg_si
|
|||
h,w,c = src.shape
|
||||
new_src = src.copy()
|
||||
|
||||
advect = np.empty ( (h*w,c), dtype=src_dtype )
|
||||
for step in range (steps):
|
||||
advect.fill(0)
|
||||
advect = np.zeros ( (h*w,c), dtype=src_dtype )
|
||||
for batch in range (batch_size):
|
||||
dir = np.random.normal(size=c).astype(src_dtype)
|
||||
dir /= npla.norm(dir)
|
||||
|
@ -91,8 +91,6 @@ def color_transfer_mkl(x0, x1):
|
|||
return np.clip ( result.reshape ( (h,w,c) ).astype(x0.dtype), 0, 1)
|
||||
|
||||
def color_transfer_idt(i0, i1, bins=256, n_rot=20):
|
||||
import scipy.stats
|
||||
|
||||
relaxation = 1 / n_rot
|
||||
h,w,c = i0.shape
|
||||
h1,w1,c1 = i1.shape
|
||||
|
@ -135,58 +133,136 @@ def color_transfer_idt(i0, i1, bins=256, n_rot=20):
|
|||
|
||||
return np.clip ( d0.T.reshape ( (h,w,c) ).astype(i0.dtype) , 0, 1)
|
||||
|
||||
def reinhard_color_transfer(target : np.ndarray, source : np.ndarray, target_mask : np.ndarray = None, source_mask : np.ndarray = None, mask_cutoff=0.5) -> np.ndarray:
|
||||
"""
|
||||
Transfer color using rct method.
|
||||
def laplacian_matrix(n, m):
|
||||
mat_D = scipy.sparse.lil_matrix((m, m))
|
||||
mat_D.setdiag(-1, -1)
|
||||
mat_D.setdiag(4)
|
||||
mat_D.setdiag(-1, 1)
|
||||
mat_A = scipy.sparse.block_diag([mat_D] * n).tolil()
|
||||
mat_A.setdiag(-1, 1*m)
|
||||
mat_A.setdiag(-1, -1*m)
|
||||
return mat_A
|
||||
|
||||
target np.ndarray H W 3C (BGR) np.float32
|
||||
source np.ndarray H W 3C (BGR) np.float32
|
||||
def seamless_clone(source, target, mask):
|
||||
h, w,c = target.shape
|
||||
result = []
|
||||
|
||||
target_mask(None) np.ndarray H W 1C np.float32
|
||||
source_mask(None) np.ndarray H W 1C np.float32
|
||||
mat_A = laplacian_matrix(h, w)
|
||||
laplacian = mat_A.tocsc()
|
||||
|
||||
mask_cutoff(0.5) float
|
||||
mask[0,:] = 1
|
||||
mask[-1,:] = 1
|
||||
mask[:,0] = 1
|
||||
mask[:,-1] = 1
|
||||
q = np.argwhere(mask==0)
|
||||
|
||||
masks are used to limit the space where color statistics will be computed to adjust the target
|
||||
k = q[:,1]+q[:,0]*w
|
||||
mat_A[k, k] = 1
|
||||
mat_A[k, k + 1] = 0
|
||||
mat_A[k, k - 1] = 0
|
||||
mat_A[k, k + w] = 0
|
||||
mat_A[k, k - w] = 0
|
||||
|
||||
reference: Color Transfer between Images https://www.cs.tau.ac.il/~turkel/imagepapers/ColorTransfer.pdf
|
||||
"""
|
||||
source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB)
|
||||
target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB)
|
||||
mat_A = mat_A.tocsc()
|
||||
mask_flat = mask.flatten()
|
||||
for channel in range(c):
|
||||
|
||||
source_input = source
|
||||
if source_mask is not None:
|
||||
source_input = source_input.copy()
|
||||
source_input[source_mask[...,0] < mask_cutoff] = [0,0,0]
|
||||
source_flat = source[:, :, channel].flatten()
|
||||
target_flat = target[:, :, channel].flatten()
|
||||
|
||||
target_input = target
|
||||
if target_mask is not None:
|
||||
target_input = target_input.copy()
|
||||
target_input[target_mask[...,0] < mask_cutoff] = [0,0,0]
|
||||
mat_b = laplacian.dot(source_flat)*0.75
|
||||
mat_b[mask_flat==0] = target_flat[mask_flat==0]
|
||||
|
||||
target_l_mean, target_l_std, target_a_mean, target_a_std, target_b_mean, target_b_std, \
|
||||
= target_input[...,0].mean(), target_input[...,0].std(), target_input[...,1].mean(), target_input[...,1].std(), target_input[...,2].mean(), target_input[...,2].std()
|
||||
x = spsolve(mat_A, mat_b).reshape((h, w))
|
||||
result.append (x)
|
||||
|
||||
source_l_mean, source_l_std, source_a_mean, source_a_std, source_b_mean, source_b_std, \
|
||||
= source_input[...,0].mean(), source_input[...,0].std(), source_input[...,1].mean(), source_input[...,1].std(), source_input[...,2].mean(), source_input[...,2].std()
|
||||
|
||||
# not as in the paper: scale by the standard deviations using reciprocal of paper proposed factor
|
||||
target_l = target[...,0]
|
||||
target_l = ne.evaluate('(target_l - target_l_mean) * source_l_std / target_l_std + source_l_mean')
|
||||
return np.clip( np.dstack(result), 0, 1 )
|
||||
|
||||
target_a = target[...,1]
|
||||
target_a = ne.evaluate('(target_a - target_a_mean) * source_a_std / target_a_std + source_a_mean')
|
||||
def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, source_mask=None, target_mask=None):
|
||||
"""
|
||||
Transfers the color distribution from the source to the target
|
||||
image using the mean and standard deviations of the L*a*b*
|
||||
color space.
|
||||
|
||||
target_b = target[...,2]
|
||||
target_b = ne.evaluate('(target_b - target_b_mean) * source_b_std / target_b_std + source_b_mean')
|
||||
This implementation is (loosely) based on to the "Color Transfer
|
||||
between Images" paper by Reinhard et al., 2001.
|
||||
|
||||
np.clip(target_l, 0, 100, out=target_l)
|
||||
np.clip(target_a, -127, 127, out=target_a)
|
||||
np.clip(target_b, -127, 127, out=target_b)
|
||||
Parameters:
|
||||
-------
|
||||
source: NumPy array
|
||||
OpenCV image in BGR color space (the source image)
|
||||
target: NumPy array
|
||||
OpenCV image in BGR color space (the target image)
|
||||
clip: Should components of L*a*b* image be scaled by np.clip before
|
||||
converting back to BGR color space?
|
||||
If False then components will be min-max scaled appropriately.
|
||||
Clipping will keep target image brightness truer to the input.
|
||||
Scaling will adjust image brightness to avoid washed out portions
|
||||
in the resulting color transfer that can be caused by clipping.
|
||||
preserve_paper: Should color transfer strictly follow methodology
|
||||
layed out in original paper? The method does not always produce
|
||||
aesthetically pleasing results.
|
||||
If False then L*a*b* components will scaled using the reciprocal of
|
||||
the scaling factor proposed in the paper. This method seems to produce
|
||||
more consistently aesthetically pleasing results
|
||||
|
||||
return cv2.cvtColor(np.stack([target_l,target_a,target_b], -1), cv2.COLOR_LAB2BGR)
|
||||
Returns:
|
||||
-------
|
||||
transfer: NumPy array
|
||||
OpenCV image (w, h, 3) NumPy array (uint8)
|
||||
"""
|
||||
|
||||
|
||||
# convert the images from the RGB to L*ab* color space, being
|
||||
# sure to utilizing the floating point data type (note: OpenCV
|
||||
# expects floats to be 32-bit, so use that instead of 64-bit)
|
||||
source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
|
||||
target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
|
||||
|
||||
# compute color statistics for the source and target images
|
||||
src_input = source if source_mask is None else source*source_mask
|
||||
tgt_input = target if target_mask is None else target*target_mask
|
||||
(lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = lab_image_stats(src_input)
|
||||
(lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = lab_image_stats(tgt_input)
|
||||
|
||||
# subtract the means from the target image
|
||||
(l, a, b) = cv2.split(target)
|
||||
l -= lMeanTar
|
||||
a -= aMeanTar
|
||||
b -= bMeanTar
|
||||
|
||||
if preserve_paper:
|
||||
# scale by the standard deviations using paper proposed factor
|
||||
l = (lStdTar / lStdSrc) * l
|
||||
a = (aStdTar / aStdSrc) * a
|
||||
b = (bStdTar / bStdSrc) * b
|
||||
else:
|
||||
# scale by the standard deviations using reciprocal of paper proposed factor
|
||||
l = (lStdSrc / lStdTar) * l
|
||||
a = (aStdSrc / aStdTar) * a
|
||||
b = (bStdSrc / bStdTar) * b
|
||||
|
||||
# add in the source mean
|
||||
l += lMeanSrc
|
||||
a += aMeanSrc
|
||||
b += bMeanSrc
|
||||
|
||||
# clip/scale the pixel intensities to [0, 255] if they fall
|
||||
# outside this range
|
||||
l = _scale_array(l, clip=clip)
|
||||
a = _scale_array(a, clip=clip)
|
||||
b = _scale_array(b, clip=clip)
|
||||
|
||||
# merge the channels together and convert back to the RGB color
|
||||
# space, being sure to utilize the 8-bit unsigned integer data
|
||||
# type
|
||||
transfer = cv2.merge([l, a, b])
|
||||
transfer = cv2.cvtColor(transfer.astype(np.uint8), cv2.COLOR_LAB2BGR)
|
||||
|
||||
# return the color transferred image
|
||||
return transfer
|
||||
|
||||
def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5):
|
||||
'''
|
||||
Matches the colour distribution of the target image to that of the source image
|
||||
|
@ -323,7 +399,9 @@ def color_transfer(ct_mode, img_src, img_trg):
|
|||
if ct_mode == 'lct':
|
||||
out = linear_color_transfer (img_src, img_trg)
|
||||
elif ct_mode == 'rct':
|
||||
out = reinhard_color_transfer(img_src, img_trg)
|
||||
out = reinhard_color_transfer ( np.clip( img_src*255, 0, 255 ).astype(np.uint8),
|
||||
np.clip( img_trg*255, 0, 255 ).astype(np.uint8) )
|
||||
out = np.clip( out.astype(np.float32) / 255.0, 0.0, 1.0)
|
||||
elif ct_mode == 'mkl':
|
||||
out = color_transfer_mkl (img_src, img_trg)
|
||||
elif ct_mode == 'idt':
|
||||
|
|
|
@ -1,16 +1,5 @@
|
|||
import numpy as np
|
||||
|
||||
def random_crop(img, w, h):
|
||||
height, width = img.shape[:2]
|
||||
|
||||
h_rnd = height - h
|
||||
w_rnd = width - w
|
||||
|
||||
y = np.random.randint(0, h_rnd) if h_rnd > 0 else 0
|
||||
x = np.random.randint(0, w_rnd) if w_rnd > 0 else 0
|
||||
|
||||
return img[y:y+height, x:x+width]
|
||||
|
||||
def normalize_channels(img, target_channels):
|
||||
img_shape_len = len(img.shape)
|
||||
if img_shape_len == 2:
|
||||
|
|
|
@ -31,7 +31,9 @@ goods or services; loss of use, data, or profits; or business interruption) howe
|
|||
import numpy as np
|
||||
import cv2
|
||||
from math import atan2, pi
|
||||
|
||||
from scipy.ndimage import convolve
|
||||
from skimage.filters.edges import HSOBEL_WEIGHTS
|
||||
from skimage.feature import canny
|
||||
|
||||
def sobel(image):
|
||||
# type: (numpy.ndarray) -> numpy.ndarray
|
||||
|
@ -40,11 +42,10 @@ def sobel(image):
|
|||
|
||||
Inspired by the [Octave implementation](https://sourceforge.net/p/octave/image/ci/default/tree/inst/edge.m#l196).
|
||||
"""
|
||||
from skimage.filters.edges import HSOBEL_WEIGHTS
|
||||
|
||||
h1 = np.array(HSOBEL_WEIGHTS)
|
||||
h1 /= np.sum(abs(h1)) # normalize h1
|
||||
|
||||
from scipy.ndimage import convolve
|
||||
strength2 = np.square(convolve(image, h1.T))
|
||||
|
||||
# Note: https://sourceforge.net/p/octave/image/ci/default/tree/inst/edge.m#l59
|
||||
|
@ -102,7 +103,6 @@ def compute(image):
|
|||
# edge detection using canny and sobel canny edge detection is done to
|
||||
# classify the blocks as edge or non-edge blocks and sobel edge
|
||||
# detection is done for the purpose of edge width measurement.
|
||||
from skimage.feature import canny
|
||||
canny_edges = canny(image)
|
||||
sobel_edges = sobel(image)
|
||||
|
||||
|
@ -269,10 +269,9 @@ def get_block_contrast(block):
|
|||
|
||||
|
||||
def estimate_sharpness(image):
|
||||
height, width = image.shape[:2]
|
||||
|
||||
if image.ndim == 3:
|
||||
if image.shape[2] > 1:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
else:
|
||||
image = image[...,0]
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
return compute(image)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import numpy as np
|
||||
from .blursharpen import LinearMotionBlur, blursharpen
|
||||
from .blursharpen import LinearMotionBlur
|
||||
import cv2
|
||||
|
||||
def apply_random_rgb_levels(img, mask=None, rnd_state=None):
|
||||
|
@ -38,24 +38,6 @@ def apply_random_hsv_shift(img, mask=None, rnd_state=None):
|
|||
|
||||
return result
|
||||
|
||||
def apply_random_sharpen( img, chance, kernel_max_size, mask=None, rnd_state=None ):
|
||||
if rnd_state is None:
|
||||
rnd_state = np.random
|
||||
|
||||
sharp_rnd_kernel = rnd_state.randint(kernel_max_size)+1
|
||||
|
||||
result = img
|
||||
if rnd_state.randint(100) < np.clip(chance, 0, 100):
|
||||
if rnd_state.randint(2) == 0:
|
||||
result = blursharpen(result, 1, sharp_rnd_kernel, rnd_state.randint(10) )
|
||||
else:
|
||||
result = blursharpen(result, 2, sharp_rnd_kernel, rnd_state.randint(50) )
|
||||
|
||||
if mask is not None:
|
||||
result = img*(1-mask) + result*mask
|
||||
|
||||
return result
|
||||
|
||||
def apply_random_motion_blur( img, chance, mb_max_size, mask=None, rnd_state=None ):
|
||||
if rnd_state is None:
|
||||
rnd_state = np.random
|
||||
|
@ -84,7 +66,8 @@ def apply_random_gaussian_blur( img, chance, kernel_max_size, mask=None, rnd_sta
|
|||
|
||||
return result
|
||||
|
||||
def apply_random_resize( img, chance, max_size_per, interpolation=cv2.INTER_LINEAR, mask=None, rnd_state=None ):
|
||||
|
||||
def apply_random_bilinear_resize( img, chance, max_size_per, mask=None, rnd_state=None ):
|
||||
if rnd_state is None:
|
||||
rnd_state = np.random
|
||||
|
||||
|
@ -96,150 +79,9 @@ def apply_random_resize( img, chance, max_size_per, interpolation=cv2.INTER_LINE
|
|||
rw = w - int( trg * int(w*(max_size_per/100.0)) )
|
||||
rh = h - int( trg * int(h*(max_size_per/100.0)) )
|
||||
|
||||
result = cv2.resize (result, (rw,rh), interpolation=interpolation )
|
||||
result = cv2.resize (result, (w,h), interpolation=interpolation )
|
||||
result = cv2.resize (result, (rw,rh), cv2.INTER_LINEAR )
|
||||
result = cv2.resize (result, (w,h), cv2.INTER_LINEAR )
|
||||
if mask is not None:
|
||||
result = img*(1-mask) + result*mask
|
||||
|
||||
return result
|
||||
|
||||
def apply_random_nearest_resize( img, chance, max_size_per, mask=None, rnd_state=None ):
|
||||
return apply_random_resize( img, chance, max_size_per, interpolation=cv2.INTER_NEAREST, mask=mask, rnd_state=rnd_state )
|
||||
|
||||
def apply_random_bilinear_resize( img, chance, max_size_per, mask=None, rnd_state=None ):
|
||||
return apply_random_resize( img, chance, max_size_per, interpolation=cv2.INTER_LINEAR, mask=mask, rnd_state=rnd_state )
|
||||
|
||||
def apply_random_jpeg_compress( img, chance, mask=None, rnd_state=None ):
|
||||
if rnd_state is None:
|
||||
rnd_state = np.random
|
||||
|
||||
result = img
|
||||
if rnd_state.randint(100) < np.clip(chance, 0, 100):
|
||||
h,w,c = result.shape
|
||||
|
||||
quality = rnd_state.randint(10,101)
|
||||
|
||||
ret, result = cv2.imencode('.jpg', np.clip(img*255, 0,255).astype(np.uint8), [int(cv2.IMWRITE_JPEG_QUALITY), quality] )
|
||||
if ret == True:
|
||||
result = cv2.imdecode(result, flags=cv2.IMREAD_UNCHANGED)
|
||||
result = result.astype(np.float32) / 255.0
|
||||
if mask is not None:
|
||||
result = img*(1-mask) + result*mask
|
||||
|
||||
return result
|
||||
|
||||
def apply_random_overlay_triangle( img, max_alpha, mask=None, rnd_state=None ):
|
||||
if rnd_state is None:
|
||||
rnd_state = np.random
|
||||
|
||||
h,w,c = img.shape
|
||||
pt1 = [rnd_state.randint(w), rnd_state.randint(h) ]
|
||||
pt2 = [rnd_state.randint(w), rnd_state.randint(h) ]
|
||||
pt3 = [rnd_state.randint(w), rnd_state.randint(h) ]
|
||||
|
||||
alpha = rnd_state.uniform()*max_alpha
|
||||
|
||||
tri_mask = cv2.fillPoly( np.zeros_like(img), [ np.array([pt1,pt2,pt3], np.int32) ], (alpha,)*c )
|
||||
|
||||
if rnd_state.randint(2) == 0:
|
||||
result = np.clip(img+tri_mask, 0, 1)
|
||||
else:
|
||||
result = np.clip(img-tri_mask, 0, 1)
|
||||
|
||||
if mask is not None:
|
||||
result = img*(1-mask) + result*mask
|
||||
|
||||
return result
|
||||
|
||||
def _min_resize(x, m):
|
||||
if x.shape[0] < x.shape[1]:
|
||||
s0 = m
|
||||
s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
|
||||
else:
|
||||
s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
|
||||
s1 = m
|
||||
new_max = min(s1, s0)
|
||||
raw_max = min(x.shape[0], x.shape[1])
|
||||
return cv2.resize(x, (s1, s0), interpolation=cv2.INTER_LANCZOS4)
|
||||
|
||||
def _d_resize(x, d, fac=1.0):
|
||||
new_min = min(int(d[1] * fac), int(d[0] * fac))
|
||||
raw_min = min(x.shape[0], x.shape[1])
|
||||
if new_min < raw_min:
|
||||
interpolation = cv2.INTER_AREA
|
||||
else:
|
||||
interpolation = cv2.INTER_LANCZOS4
|
||||
y = cv2.resize(x, (int(d[1] * fac), int(d[0] * fac)), interpolation=interpolation)
|
||||
return y
|
||||
|
||||
def _get_image_gradient(dist):
|
||||
cols = cv2.filter2D(dist, cv2.CV_32F, np.array([[-1, 0, +1], [-2, 0, +2], [-1, 0, +1]]))
|
||||
rows = cv2.filter2D(dist, cv2.CV_32F, np.array([[-1, -2, -1], [0, 0, 0], [+1, +2, +1]]))
|
||||
return cols, rows
|
||||
|
||||
def _generate_lighting_effects(content):
|
||||
h512 = content
|
||||
h256 = cv2.pyrDown(h512)
|
||||
h128 = cv2.pyrDown(h256)
|
||||
h64 = cv2.pyrDown(h128)
|
||||
h32 = cv2.pyrDown(h64)
|
||||
h16 = cv2.pyrDown(h32)
|
||||
c512, r512 = _get_image_gradient(h512)
|
||||
c256, r256 = _get_image_gradient(h256)
|
||||
c128, r128 = _get_image_gradient(h128)
|
||||
c64, r64 = _get_image_gradient(h64)
|
||||
c32, r32 = _get_image_gradient(h32)
|
||||
c16, r16 = _get_image_gradient(h16)
|
||||
c = c16
|
||||
c = _d_resize(cv2.pyrUp(c), c32.shape) * 4.0 + c32
|
||||
c = _d_resize(cv2.pyrUp(c), c64.shape) * 4.0 + c64
|
||||
c = _d_resize(cv2.pyrUp(c), c128.shape) * 4.0 + c128
|
||||
c = _d_resize(cv2.pyrUp(c), c256.shape) * 4.0 + c256
|
||||
c = _d_resize(cv2.pyrUp(c), c512.shape) * 4.0 + c512
|
||||
r = r16
|
||||
r = _d_resize(cv2.pyrUp(r), r32.shape) * 4.0 + r32
|
||||
r = _d_resize(cv2.pyrUp(r), r64.shape) * 4.0 + r64
|
||||
r = _d_resize(cv2.pyrUp(r), r128.shape) * 4.0 + r128
|
||||
r = _d_resize(cv2.pyrUp(r), r256.shape) * 4.0 + r256
|
||||
r = _d_resize(cv2.pyrUp(r), r512.shape) * 4.0 + r512
|
||||
coarse_effect_cols = c
|
||||
coarse_effect_rows = r
|
||||
EPS = 1e-10
|
||||
|
||||
max_effect = np.max((coarse_effect_cols**2 + coarse_effect_rows**2)**0.5, axis=0, keepdims=True, ).max(1, keepdims=True)
|
||||
coarse_effect_cols = (coarse_effect_cols + EPS) / (max_effect + EPS)
|
||||
coarse_effect_rows = (coarse_effect_rows + EPS) / (max_effect + EPS)
|
||||
|
||||
return np.stack([ np.zeros_like(coarse_effect_rows), coarse_effect_rows, coarse_effect_cols], axis=-1)
|
||||
|
||||
def apply_random_relight(img, mask=None, rnd_state=None):
|
||||
if rnd_state is None:
|
||||
rnd_state = np.random
|
||||
|
||||
def_img = img
|
||||
|
||||
if rnd_state.randint(2) == 0:
|
||||
light_pos_y = 1.0 if rnd_state.randint(2) == 0 else -1.0
|
||||
light_pos_x = rnd_state.uniform()*2-1.0
|
||||
else:
|
||||
light_pos_y = rnd_state.uniform()*2-1.0
|
||||
light_pos_x = 1.0 if rnd_state.randint(2) == 0 else -1.0
|
||||
|
||||
light_source_height = 0.3*rnd_state.uniform()*0.7
|
||||
light_intensity = 1.0+rnd_state.uniform()
|
||||
ambient_intensity = 0.5
|
||||
|
||||
light_source_location = np.array([[[light_source_height, light_pos_y, light_pos_x ]]], dtype=np.float32)
|
||||
light_source_direction = light_source_location / np.sqrt(np.sum(np.square(light_source_location)))
|
||||
|
||||
lighting_effect = _generate_lighting_effects(img)
|
||||
lighting_effect = np.sum(lighting_effect * light_source_direction, axis=-1).clip(0, 1)
|
||||
lighting_effect = np.mean(lighting_effect, axis=-1, keepdims=True)
|
||||
|
||||
result = def_img * (ambient_intensity + lighting_effect * light_intensity) #light_source_color
|
||||
result = np.clip(result, 0, 1)
|
||||
|
||||
if mask is not None:
|
||||
result = def_img*(1-mask) + result*mask
|
||||
|
||||
return result
|
|
@ -1,2 +1,2 @@
|
|||
from .draw import circle_faded, random_circle_faded, bezier, random_bezier_split_faded, random_faded
|
||||
from .draw import *
|
||||
from .calc import *
|
|
@ -1,36 +1,23 @@
|
|||
"""
|
||||
Signed distance drawing functions using numpy.
|
||||
"""
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
from numpy import linalg as npla
|
||||
|
||||
|
||||
def vector2_dot(a,b):
|
||||
return a[...,0]*b[...,0]+a[...,1]*b[...,1]
|
||||
|
||||
def vector2_dot2(a):
|
||||
return a[...,0]*a[...,0]+a[...,1]*a[...,1]
|
||||
|
||||
def vector2_cross(a,b):
|
||||
return a[...,0]*b[...,1]-a[...,1]*b[...,0]
|
||||
|
||||
|
||||
def circle_faded( wh, center, fade_dists ):
|
||||
def circle_faded( hw, center, fade_dists ):
|
||||
"""
|
||||
returns drawn circle in [h,w,1] output range [0..1.0] float32
|
||||
|
||||
wh = [w,h] resolution
|
||||
center = [x,y] center of circle
|
||||
hw = [h,w] resolution
|
||||
center = [y,x] center of circle
|
||||
fade_dists = [fade_start, fade_end] fade values
|
||||
"""
|
||||
w,h = wh
|
||||
h,w = hw
|
||||
|
||||
pts = np.empty( (h,w,2), dtype=np.float32 )
|
||||
pts[...,0] = np.arange(w)[:,None]
|
||||
pts[...,1] = np.arange(h)[None,:]
|
||||
|
||||
pts[...,0] = np.arange(w)[:,None]
|
||||
pts = pts.reshape ( (h*w, -1) )
|
||||
|
||||
pts_dists = np.abs ( npla.norm(pts-center, axis=-1) )
|
||||
|
@ -44,157 +31,14 @@ def circle_faded( wh, center, fade_dists ):
|
|||
|
||||
return pts_dists.reshape ( (h,w,1) ).astype(np.float32)
|
||||
|
||||
|
||||
def bezier( wh, A, B, C ):
|
||||
"""
|
||||
returns drawn bezier in [h,w,1] output range float32,
|
||||
every pixel contains signed distance to bezier line
|
||||
|
||||
wh [w,h] resolution
|
||||
A,B,C points [x,y]
|
||||
"""
|
||||
|
||||
width,height = wh
|
||||
|
||||
A = np.float32(A)
|
||||
B = np.float32(B)
|
||||
C = np.float32(C)
|
||||
|
||||
|
||||
pos = np.empty( (height,width,2), dtype=np.float32 )
|
||||
pos[...,0] = np.arange(width)[:,None]
|
||||
pos[...,1] = np.arange(height)[None,:]
|
||||
|
||||
|
||||
a = B-A
|
||||
b = A - 2.0*B + C
|
||||
c = a * 2.0
|
||||
d = A - pos
|
||||
|
||||
b_dot = vector2_dot(b,b)
|
||||
if b_dot == 0.0:
|
||||
return np.zeros( (height,width), dtype=np.float32 )
|
||||
|
||||
kk = 1.0 / b_dot
|
||||
|
||||
kx = kk * vector2_dot(a,b)
|
||||
ky = kk * (2.0*vector2_dot(a,a)+vector2_dot(d,b))/3.0;
|
||||
kz = kk * vector2_dot(d,a);
|
||||
|
||||
res = 0.0;
|
||||
sgn = 0.0;
|
||||
|
||||
p = ky - kx*kx;
|
||||
|
||||
p3 = p*p*p;
|
||||
q = kx*(2.0*kx*kx - 3.0*ky) + kz;
|
||||
h = q*q + 4.0*p3;
|
||||
|
||||
hp_sel = h >= 0.0
|
||||
|
||||
hp_p = h[hp_sel]
|
||||
hp_p = np.sqrt(hp_p)
|
||||
|
||||
hp_x = ( np.stack( (hp_p,-hp_p), -1) -q[hp_sel,None] ) / 2.0
|
||||
hp_uv = np.sign(hp_x) * np.power( np.abs(hp_x), [1.0/3.0, 1.0/3.0] )
|
||||
hp_t = np.clip( hp_uv[...,0] + hp_uv[...,1] - kx, 0.0, 1.0 )
|
||||
|
||||
hp_t = hp_t[...,None]
|
||||
hp_q = d[hp_sel]+(c+b*hp_t)*hp_t
|
||||
hp_res = vector2_dot2(hp_q)
|
||||
hp_sgn = vector2_cross(c+2.0*b*hp_t,hp_q)
|
||||
|
||||
hl_sel = h < 0.0
|
||||
|
||||
hl_q = q[hl_sel]
|
||||
hl_p = p[hl_sel]
|
||||
hl_z = np.sqrt(-hl_p)
|
||||
hl_v = np.arccos( hl_q / (hl_p*hl_z*2.0)) / 3.0
|
||||
|
||||
hl_m = np.cos(hl_v)
|
||||
hl_n = np.sin(hl_v)*1.732050808;
|
||||
|
||||
hl_t = np.clip( np.stack( (hl_m+hl_m,-hl_n-hl_m,hl_n-hl_m), -1)*hl_z[...,None]-kx, 0.0, 1.0 );
|
||||
|
||||
hl_d = d[hl_sel]
|
||||
|
||||
hl_qx = hl_d+(c+b*hl_t[...,0:1])*hl_t[...,0:1]
|
||||
|
||||
hl_dx = vector2_dot2(hl_qx)
|
||||
hl_sx = vector2_cross(c+2.0*b*hl_t[...,0:1], hl_qx)
|
||||
|
||||
hl_qy = hl_d+(c+b*hl_t[...,1:2])*hl_t[...,1:2]
|
||||
hl_dy = vector2_dot2(hl_qy)
|
||||
hl_sy = vector2_cross(c+2.0*b*hl_t[...,1:2],hl_qy);
|
||||
|
||||
hl_dx_l_dy = hl_dx<hl_dy
|
||||
hl_dx_ge_dy = hl_dx>=hl_dy
|
||||
|
||||
hl_res = np.empty_like(hl_dx)
|
||||
hl_res[hl_dx_l_dy] = hl_dx[hl_dx_l_dy]
|
||||
hl_res[hl_dx_ge_dy] = hl_dy[hl_dx_ge_dy]
|
||||
|
||||
hl_sgn = np.empty_like(hl_sx)
|
||||
hl_sgn[hl_dx_l_dy] = hl_sx[hl_dx_l_dy]
|
||||
hl_sgn[hl_dx_ge_dy] = hl_sy[hl_dx_ge_dy]
|
||||
|
||||
res = np.empty( (height, width), np.float32 )
|
||||
res[hp_sel] = hp_res
|
||||
res[hl_sel] = hl_res
|
||||
|
||||
sgn = np.empty( (height, width), np.float32 )
|
||||
sgn[hp_sel] = hp_sgn
|
||||
sgn[hl_sel] = hl_sgn
|
||||
|
||||
sgn = np.sign(sgn)
|
||||
res = np.sqrt(res)*sgn
|
||||
|
||||
return res[...,None]
|
||||
|
||||
def random_faded(wh):
|
||||
"""
|
||||
apply one of them:
|
||||
random_circle_faded
|
||||
random_bezier_split_faded
|
||||
"""
|
||||
rnd = np.random.randint(2)
|
||||
if rnd == 0:
|
||||
return random_circle_faded(wh)
|
||||
elif rnd == 1:
|
||||
return random_bezier_split_faded(wh)
|
||||
|
||||
def random_circle_faded ( wh, rnd_state=None ):
|
||||
def random_circle_faded ( hw, rnd_state=None ):
|
||||
if rnd_state is None:
|
||||
rnd_state = np.random
|
||||
|
||||
w,h = wh
|
||||
wh_max = max(w,h)
|
||||
fade_start = rnd_state.randint(wh_max)
|
||||
fade_end = fade_start + rnd_state.randint(wh_max- fade_start)
|
||||
h,w = hw
|
||||
hw_max = max(h,w)
|
||||
fade_start = rnd_state.randint(hw_max)
|
||||
fade_end = fade_start + rnd_state.randint(hw_max- fade_start)
|
||||
|
||||
return circle_faded (wh, [ rnd_state.randint(h), rnd_state.randint(w) ],
|
||||
return circle_faded (hw, [ rnd_state.randint(h), rnd_state.randint(w) ],
|
||||
[fade_start, fade_end] )
|
||||
|
||||
def random_bezier_split_faded( wh ):
|
||||
width, height = wh
|
||||
|
||||
degA = np.random.randint(360)
|
||||
degB = np.random.randint(360)
|
||||
degC = np.random.randint(360)
|
||||
|
||||
deg_2_rad = math.pi / 180.0
|
||||
|
||||
center = np.float32([width / 2.0, height / 2.0])
|
||||
|
||||
radius = max(width, height)
|
||||
|
||||
A = center + radius*np.float32([ math.sin( degA * deg_2_rad), math.cos( degA * deg_2_rad) ] )
|
||||
B = center + np.random.randint(radius)*np.float32([ math.sin( degB * deg_2_rad), math.cos( degB * deg_2_rad) ] )
|
||||
C = center + radius*np.float32([ math.sin( degC * deg_2_rad), math.cos( degC * deg_2_rad) ] )
|
||||
|
||||
x = bezier( (width,height), A, B, C )
|
||||
|
||||
x = x / (1+np.random.randint(radius)) + 0.5
|
||||
|
||||
x = np.clip(x, 0, 1)
|
||||
return x
|
||||
|
|
|
@ -1,146 +1,32 @@
|
|||
import numpy as np
|
||||
import numpy.linalg as npla
|
||||
import cv2
|
||||
from core import randomex
|
||||
|
||||
def mls_rigid_deformation(vy, vx, src_pts, dst_pts, alpha=1.0, eps=1e-8):
|
||||
dst_pts = dst_pts[..., ::-1].astype(np.int16)
|
||||
src_pts = src_pts[..., ::-1].astype(np.int16)
|
||||
|
||||
src_pts, dst_pts = dst_pts, src_pts
|
||||
|
||||
grow = vx.shape[0]
|
||||
gcol = vx.shape[1]
|
||||
ctrls = src_pts.shape[0]
|
||||
|
||||
reshaped_p = src_pts.reshape(ctrls, 2, 1, 1)
|
||||
reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol)))
|
||||
|
||||
w = 1.0 / (np.sum((reshaped_p - reshaped_v).astype(np.float32) ** 2, axis=1) + eps) ** alpha
|
||||
w /= np.sum(w, axis=0, keepdims=True)
|
||||
|
||||
pstar = np.zeros((2, grow, gcol), np.float32)
|
||||
for i in range(ctrls):
|
||||
pstar += w[i] * reshaped_p[i]
|
||||
|
||||
vpstar = reshaped_v - pstar
|
||||
|
||||
reshaped_mul_right = np.concatenate((vpstar[:,None,...],
|
||||
np.concatenate((vpstar[1:2,None,...],-vpstar[0:1,None,...]), 0)
|
||||
), axis=1).transpose(2, 3, 0, 1)
|
||||
|
||||
reshaped_q = dst_pts.reshape((ctrls, 2, 1, 1))
|
||||
|
||||
qstar = np.zeros((2, grow, gcol), np.float32)
|
||||
for i in range(ctrls):
|
||||
qstar += w[i] * reshaped_q[i]
|
||||
|
||||
temp = np.zeros((grow, gcol, 2), np.float32)
|
||||
for i in range(ctrls):
|
||||
phat = reshaped_p[i] - pstar
|
||||
qhat = reshaped_q[i] - qstar
|
||||
|
||||
temp += np.matmul(qhat.reshape(1, 2, grow, gcol).transpose(2, 3, 0, 1),
|
||||
|
||||
np.matmul( ( w[None, i:i+1,...] *
|
||||
np.concatenate((phat.reshape(1, 2, grow, gcol),
|
||||
np.concatenate( (phat[None,1:2], -phat[None,0:1]), 1 )), 0)
|
||||
).transpose(2, 3, 0, 1), reshaped_mul_right
|
||||
)
|
||||
).reshape(grow, gcol, 2)
|
||||
|
||||
temp = temp.transpose(2, 0, 1)
|
||||
|
||||
normed_temp = np.linalg.norm(temp, axis=0, keepdims=True)
|
||||
normed_vpstar = np.linalg.norm(vpstar, axis=0, keepdims=True)
|
||||
nan_mask = normed_temp[0]==0
|
||||
|
||||
transformers = np.true_divide(temp, normed_temp, out=np.zeros_like(temp), where= ~nan_mask) * normed_vpstar + qstar
|
||||
nan_mask_flat = np.flatnonzero(nan_mask)
|
||||
nan_mask_anti_flat = np.flatnonzero(~nan_mask)
|
||||
|
||||
transformers[0][nan_mask] = np.interp(nan_mask_flat, nan_mask_anti_flat, transformers[0][~nan_mask])
|
||||
transformers[1][nan_mask] = np.interp(nan_mask_flat, nan_mask_anti_flat, transformers[1][~nan_mask])
|
||||
|
||||
return transformers
|
||||
|
||||
def gen_pts(W, H, rnd_state=None):
|
||||
|
||||
def gen_warp_params (w, flip, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_state=None ):
|
||||
if rnd_state is None:
|
||||
rnd_state = np.random
|
||||
|
||||
min_pts, max_pts = 4, 8
|
||||
n_pts = rnd_state.randint(min_pts, max_pts)
|
||||
|
||||
min_radius_per = 0.00
|
||||
max_radius_per = 0.10
|
||||
pts = []
|
||||
|
||||
for i in range(n_pts):
|
||||
while True:
|
||||
x, y = rnd_state.randint(W), rnd_state.randint(H)
|
||||
rad = min_radius_per + rnd_state.rand()*(max_radius_per-min_radius_per)
|
||||
|
||||
intersect = False
|
||||
for px,py,prad,_,_ in pts:
|
||||
|
||||
dist = npla.norm([x-px, y-py])
|
||||
if dist <= (rad+prad)*2:
|
||||
intersect = True
|
||||
break
|
||||
if intersect:
|
||||
continue
|
||||
|
||||
angle = rnd_state.rand()*(2*np.pi)
|
||||
x2 = int(x+np.cos(angle)*W*rad)
|
||||
y2 = int(y+np.sin(angle)*H*rad)
|
||||
|
||||
break
|
||||
pts.append( (x,y,rad, x2,y2) )
|
||||
|
||||
pts1 = np.array( [ [pt[0],pt[1]] for pt in pts ] )
|
||||
pts2 = np.array( [ [pt[-2],pt[-1]] for pt in pts ] )
|
||||
|
||||
return pts1, pts2
|
||||
|
||||
|
||||
def gen_warp_params (w, flip=False, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_state=None, warp_rnd_state=None ):
|
||||
if rnd_state is None:
|
||||
rnd_state = np.random
|
||||
if warp_rnd_state is None:
|
||||
warp_rnd_state = np.random
|
||||
rw = None
|
||||
if w < 64:
|
||||
rw = w
|
||||
w = 64
|
||||
|
||||
rotation = rnd_state.uniform( rotation_range[0], rotation_range[1] )
|
||||
scale = rnd_state.uniform( 1/(1-scale_range[0]) , 1+scale_range[1] )
|
||||
scale = rnd_state.uniform(1 +scale_range[0], 1 +scale_range[1])
|
||||
tx = rnd_state.uniform( tx_range[0], tx_range[1] )
|
||||
ty = rnd_state.uniform( ty_range[0], ty_range[1] )
|
||||
p_flip = flip and rnd_state.randint(10) < 4
|
||||
|
||||
#random warp V1
|
||||
cell_size = [ w // (2**i) for i in range(1,4) ] [ warp_rnd_state.randint(3) ]
|
||||
#random warp by grid
|
||||
cell_size = [ w // (2**i) for i in range(1,4) ] [ rnd_state.randint(3) ]
|
||||
cell_count = w // cell_size + 1
|
||||
|
||||
grid_points = np.linspace( 0, w, cell_count)
|
||||
mapx = np.broadcast_to(grid_points, (cell_count, cell_count)).copy()
|
||||
mapy = mapx.T
|
||||
mapx[1:-1,1:-1] = mapx[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2), rnd_state=warp_rnd_state )*(cell_size*0.24)
|
||||
mapy[1:-1,1:-1] = mapy[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2), rnd_state=warp_rnd_state )*(cell_size*0.24)
|
||||
|
||||
mapx[1:-1,1:-1] = mapx[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2) )*(cell_size*0.24)
|
||||
mapy[1:-1,1:-1] = mapy[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2) )*(cell_size*0.24)
|
||||
|
||||
half_cell_size = cell_size // 2
|
||||
|
||||
mapx = cv2.resize(mapx, (w+cell_size,)*2 )[half_cell_size:-half_cell_size,half_cell_size:-half_cell_size].astype(np.float32)
|
||||
mapy = cv2.resize(mapy, (w+cell_size,)*2 )[half_cell_size:-half_cell_size,half_cell_size:-half_cell_size].astype(np.float32)
|
||||
##############
|
||||
|
||||
# random warp V2
|
||||
# pts1, pts2 = gen_pts(w, w, rnd_state)
|
||||
# gridX = np.arange(w, dtype=np.int16)
|
||||
# gridY = np.arange(w, dtype=np.int16)
|
||||
# vy, vx = np.meshgrid(gridX, gridY)
|
||||
# drigid = mls_rigid_deformation(vy, vx, pts1, pts2)
|
||||
# mapy, mapx = drigid.astype(np.float32)
|
||||
################
|
||||
|
||||
#random transform
|
||||
random_transform_mat = cv2.getRotationMatrix2D((w // 2, w // 2), rotation, scale)
|
||||
|
@ -150,30 +36,16 @@ def gen_warp_params (w, flip=False, rotation_range=[-10,10], scale_range=[-0.5,
|
|||
params['mapx'] = mapx
|
||||
params['mapy'] = mapy
|
||||
params['rmat'] = random_transform_mat
|
||||
u_mat = random_transform_mat.copy()
|
||||
u_mat[:,2] /= w
|
||||
params['umat'] = u_mat
|
||||
params['w'] = w
|
||||
params['rw'] = rw
|
||||
params['flip'] = p_flip
|
||||
|
||||
return params
|
||||
|
||||
def warp_by_params (params, img, can_warp, can_transform, can_flip, border_replicate, cv2_inter=cv2.INTER_CUBIC):
|
||||
rw = params['rw']
|
||||
|
||||
if (can_warp or can_transform) and rw is not None:
|
||||
img = cv2.resize(img, (64,64), interpolation=cv2_inter)
|
||||
|
||||
if can_warp:
|
||||
img = cv2.remap(img, params['mapx'], params['mapy'], cv2_inter )
|
||||
if can_transform:
|
||||
img = cv2.warpAffine( img, params['rmat'], (params['w'], params['w']), borderMode=(cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT), flags=cv2_inter )
|
||||
|
||||
|
||||
if (can_warp or can_transform) and rw is not None:
|
||||
img = cv2.resize(img, (rw,rw), interpolation=cv2_inter)
|
||||
|
||||
if len(img.shape) == 2:
|
||||
img = img[...,None]
|
||||
if can_flip and params['flip']:
|
||||
|
|
|
@ -7,7 +7,6 @@ import types
|
|||
|
||||
import colorama
|
||||
import cv2
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from core import stdex
|
||||
|
@ -198,7 +197,7 @@ class InteractBase(object):
|
|||
def add_key_event(self, wnd_name, ord_key, ctrl_pressed, alt_pressed, shift_pressed):
|
||||
if wnd_name not in self.key_events:
|
||||
self.key_events[wnd_name] = []
|
||||
self.key_events[wnd_name] += [ (ord_key, chr(ord_key) if ord_key <= 255 else chr(0), ctrl_pressed, alt_pressed, shift_pressed) ]
|
||||
self.key_events[wnd_name] += [ (ord_key, chr(ord_key), ctrl_pressed, alt_pressed, shift_pressed) ]
|
||||
|
||||
def get_mouse_events(self, wnd_name):
|
||||
ar = self.mouse_events.get(wnd_name, [])
|
||||
|
@ -256,7 +255,7 @@ class InteractBase(object):
|
|||
print(result)
|
||||
return result
|
||||
|
||||
def input_int(self, s, default_value, valid_range=None, valid_list=None, add_info=None, show_default_value=True, help_message=None):
|
||||
def input_int(self, s, default_value, valid_list=None, add_info=None, show_default_value=True, help_message=None):
|
||||
if show_default_value:
|
||||
if len(s) != 0:
|
||||
s = f"[{default_value}] {s}"
|
||||
|
@ -264,21 +263,15 @@ class InteractBase(object):
|
|||
s = f"[{default_value}]"
|
||||
|
||||
if add_info is not None or \
|
||||
valid_range is not None or \
|
||||
help_message is not None:
|
||||
s += " ("
|
||||
|
||||
if valid_range is not None:
|
||||
s += f" {valid_range[0]}-{valid_range[1]}"
|
||||
|
||||
if add_info is not None:
|
||||
s += f" {add_info}"
|
||||
|
||||
if help_message is not None:
|
||||
s += " ?:help"
|
||||
|
||||
if add_info is not None or \
|
||||
valid_range is not None or \
|
||||
help_message is not None:
|
||||
s += " )"
|
||||
|
||||
|
@ -295,12 +288,9 @@ class InteractBase(object):
|
|||
continue
|
||||
|
||||
i = int(inp)
|
||||
if valid_range is not None:
|
||||
i = int(np.clip(i, valid_range[0], valid_range[1]))
|
||||
|
||||
if (valid_list is not None) and (i not in valid_list):
|
||||
i = default_value
|
||||
|
||||
result = default_value
|
||||
break
|
||||
result = i
|
||||
break
|
||||
except:
|
||||
|
@ -501,11 +491,10 @@ class InteractDesktop(InteractBase):
|
|||
|
||||
if has_windows or has_capture_keys:
|
||||
wait_key_time = max(1, int(sleep_time*1000) )
|
||||
ord_key = cv2.waitKeyEx(wait_key_time)
|
||||
|
||||
ord_key = cv2.waitKey(wait_key_time)
|
||||
shift_pressed = False
|
||||
if ord_key != -1:
|
||||
chr_key = chr(ord_key) if ord_key <= 255 else chr(0)
|
||||
chr_key = chr(ord_key)
|
||||
|
||||
if chr_key >= 'A' and chr_key <= 'Z':
|
||||
shift_pressed = True
|
||||
|
|
|
@ -81,8 +81,11 @@ class Subprocessor(object):
|
|||
except Subprocessor.SilenceException as e:
|
||||
c2s.put ( {'op': 'error', 'data' : data} )
|
||||
except Exception as e:
|
||||
err_msg = traceback.format_exc()
|
||||
c2s.put ( {'op': 'error', 'data' : data, 'err_msg' : err_msg} )
|
||||
c2s.put ( {'op': 'error', 'data' : data} )
|
||||
if data is not None:
|
||||
print ('Exception while process data [%s]: %s' % (self.get_data_name(data), traceback.format_exc()) )
|
||||
else:
|
||||
print ('Exception: %s' % (traceback.format_exc()) )
|
||||
|
||||
c2s.close()
|
||||
s2c.close()
|
||||
|
@ -156,24 +159,6 @@ class Subprocessor(object):
|
|||
|
||||
self.clis = []
|
||||
|
||||
def cli_init_dispatcher(cli):
|
||||
while not cli.c2s.empty():
|
||||
obj = cli.c2s.get()
|
||||
op = obj.get('op','')
|
||||
if op == 'init_ok':
|
||||
cli.state = 0
|
||||
elif op == 'log_info':
|
||||
io.log_info(obj['msg'])
|
||||
elif op == 'log_err':
|
||||
io.log_err(obj['msg'])
|
||||
elif op == 'error':
|
||||
err_msg = obj.get('err_msg', None)
|
||||
if err_msg is not None:
|
||||
io.log_info(f'Error while subprocess initialization: {err_msg}')
|
||||
cli.kill()
|
||||
self.clis.remove(cli)
|
||||
break
|
||||
|
||||
#getting info about name of subprocesses, host and client dicts, and spawning them
|
||||
for name, host_dict, client_dict in self.process_info_generator():
|
||||
try:
|
||||
|
@ -188,7 +173,19 @@ class Subprocessor(object):
|
|||
|
||||
if self.initialize_subprocesses_in_serial:
|
||||
while True:
|
||||
cli_init_dispatcher(cli)
|
||||
while not cli.c2s.empty():
|
||||
obj = cli.c2s.get()
|
||||
op = obj.get('op','')
|
||||
if op == 'init_ok':
|
||||
cli.state = 0
|
||||
elif op == 'log_info':
|
||||
io.log_info(obj['msg'])
|
||||
elif op == 'log_err':
|
||||
io.log_err(obj['msg'])
|
||||
elif op == 'error':
|
||||
cli.kill()
|
||||
self.clis.remove(cli)
|
||||
break
|
||||
if cli.state == 0:
|
||||
break
|
||||
io.process_messages(0.005)
|
||||
|
@ -201,7 +198,19 @@ class Subprocessor(object):
|
|||
#waiting subprocesses their success(or not) initialization
|
||||
while True:
|
||||
for cli in self.clis[:]:
|
||||
cli_init_dispatcher(cli)
|
||||
while not cli.c2s.empty():
|
||||
obj = cli.c2s.get()
|
||||
op = obj.get('op','')
|
||||
if op == 'init_ok':
|
||||
cli.state = 0
|
||||
elif op == 'log_info':
|
||||
io.log_info(obj['msg'])
|
||||
elif op == 'log_err':
|
||||
io.log_err(obj['msg'])
|
||||
elif op == 'error':
|
||||
cli.kill()
|
||||
self.clis.remove(cli)
|
||||
break
|
||||
if all ([cli.state == 0 for cli in self.clis]):
|
||||
break
|
||||
io.process_messages(0.005)
|
||||
|
@ -226,10 +235,6 @@ class Subprocessor(object):
|
|||
cli.state = 0
|
||||
elif op == 'error':
|
||||
#some error occured while process data, returning chunk to on_data_return
|
||||
err_msg = obj.get('err_msg', None)
|
||||
if err_msg is not None:
|
||||
io.log_info(f'Error while processing data: {err_msg}')
|
||||
|
||||
if 'data' in obj.keys():
|
||||
self.on_data_return (cli.host_dict, obj['data'] )
|
||||
#and killing process
|
||||
|
|
|
@ -6,55 +6,49 @@ class DeepFakeArchi(nn.ArchiBase):
|
|||
resolution
|
||||
|
||||
mod None - default
|
||||
'uhd'
|
||||
'quick'
|
||||
|
||||
opts ''
|
||||
''
|
||||
't'
|
||||
"""
|
||||
def __init__(self, resolution, use_fp16=False, mod=None, opts=None):
|
||||
def __init__(self, resolution, mod=None):
|
||||
super().__init__()
|
||||
|
||||
if opts is None:
|
||||
opts = ''
|
||||
|
||||
|
||||
conv_dtype = tf.float16 if use_fp16 else tf.float32
|
||||
|
||||
if 'c' in opts:
|
||||
def act(x, alpha=0.1):
|
||||
return x*tf.cos(x)
|
||||
else:
|
||||
def act(x, alpha=0.1):
|
||||
return tf.nn.leaky_relu(x, alpha)
|
||||
|
||||
if mod is None:
|
||||
class Downscale(nn.ModelBase):
|
||||
def __init__(self, in_ch, out_ch, kernel_size=5, *kwargs ):
|
||||
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
|
||||
self.in_ch = in_ch
|
||||
self.out_ch = out_ch
|
||||
self.kernel_size = kernel_size
|
||||
self.dilations = dilations
|
||||
self.subpixel = subpixel
|
||||
self.use_activator = use_activator
|
||||
super().__init__(*kwargs)
|
||||
|
||||
def on_build(self, *args, **kwargs ):
|
||||
self.conv1 = nn.Conv2D( self.in_ch, self.out_ch, kernel_size=self.kernel_size, strides=2, padding='SAME', dtype=conv_dtype)
|
||||
self.conv1 = nn.Conv2D( self.in_ch,
|
||||
self.out_ch // (4 if self.subpixel else 1),
|
||||
kernel_size=self.kernel_size,
|
||||
strides=1 if self.subpixel else 2,
|
||||
padding='SAME', dilations=self.dilations)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = act(x, 0.1)
|
||||
if self.subpixel:
|
||||
x = nn.space_to_depth(x, 2)
|
||||
if self.use_activator:
|
||||
x = tf.nn.leaky_relu(x, 0.1)
|
||||
return x
|
||||
|
||||
def get_out_ch(self):
|
||||
return self.out_ch
|
||||
return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch
|
||||
|
||||
class DownscaleBlock(nn.ModelBase):
|
||||
def on_build(self, in_ch, ch, n_downscales, kernel_size):
|
||||
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
|
||||
self.downs = []
|
||||
|
||||
last_ch = in_ch
|
||||
for i in range(n_downscales):
|
||||
cur_ch = ch*( min(2**i, 8) )
|
||||
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size))
|
||||
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
|
||||
last_ch = self.downs[-1].get_out_ch()
|
||||
|
||||
def forward(self, inp):
|
||||
|
@ -64,77 +58,66 @@ class DeepFakeArchi(nn.ArchiBase):
|
|||
return x
|
||||
|
||||
class Upscale(nn.ModelBase):
|
||||
def on_build(self, in_ch, out_ch, kernel_size=3):
|
||||
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
|
||||
def on_build(self, in_ch, out_ch, kernel_size=3 ):
|
||||
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = act(x, 0.1)
|
||||
x = tf.nn.leaky_relu(x, 0.1)
|
||||
x = nn.depth_to_space(x, 2)
|
||||
return x
|
||||
|
||||
class ResidualBlock(nn.ModelBase):
|
||||
def on_build(self, ch, kernel_size=3):
|
||||
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
|
||||
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
|
||||
def on_build(self, ch, kernel_size=3 ):
|
||||
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
|
||||
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
|
||||
|
||||
def forward(self, inp):
|
||||
x = self.conv1(inp)
|
||||
x = act(x, 0.2)
|
||||
x = tf.nn.leaky_relu(x, 0.2)
|
||||
x = self.conv2(x)
|
||||
x = act(inp + x, 0.2)
|
||||
x = tf.nn.leaky_relu(inp + x, 0.2)
|
||||
return x
|
||||
|
||||
class UpdownResidualBlock(nn.ModelBase):
|
||||
def on_build(self, ch, inner_ch, kernel_size=3 ):
|
||||
self.up = Upscale (ch, inner_ch, kernel_size=kernel_size)
|
||||
self.res = ResidualBlock (inner_ch, kernel_size=kernel_size)
|
||||
self.down = Downscale (inner_ch, ch, kernel_size=kernel_size, use_activator=False)
|
||||
|
||||
def forward(self, inp):
|
||||
x = self.up(inp)
|
||||
x = upx = self.res(x)
|
||||
x = self.down(x)
|
||||
x = x + inp
|
||||
x = tf.nn.leaky_relu(x, 0.2)
|
||||
return x, upx
|
||||
|
||||
class Encoder(nn.ModelBase):
|
||||
def __init__(self, in_ch, e_ch, **kwargs ):
|
||||
self.in_ch = in_ch
|
||||
self.e_ch = e_ch
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def on_build(self):
|
||||
if 't' in opts:
|
||||
self.down1 = Downscale(self.in_ch, self.e_ch, kernel_size=5)
|
||||
self.res1 = ResidualBlock(self.e_ch)
|
||||
self.down2 = Downscale(self.e_ch, self.e_ch*2, kernel_size=5)
|
||||
self.down3 = Downscale(self.e_ch*2, self.e_ch*4, kernel_size=5)
|
||||
self.down4 = Downscale(self.e_ch*4, self.e_ch*8, kernel_size=5)
|
||||
self.down5 = Downscale(self.e_ch*8, self.e_ch*8, kernel_size=5)
|
||||
self.res5 = ResidualBlock(self.e_ch*8)
|
||||
def on_build(self, in_ch, e_ch, is_hd):
|
||||
self.is_hd=is_hd
|
||||
if self.is_hd:
|
||||
self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, dilations=1)
|
||||
self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, dilations=1)
|
||||
self.down3 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=5, dilations=2)
|
||||
self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2)
|
||||
else:
|
||||
self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4 if 't' not in opts else 5, kernel_size=5)
|
||||
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False)
|
||||
|
||||
def forward(self, x):
|
||||
if use_fp16:
|
||||
x = tf.cast(x, tf.float16)
|
||||
|
||||
if 't' in opts:
|
||||
x = self.down1(x)
|
||||
x = self.res1(x)
|
||||
x = self.down2(x)
|
||||
x = self.down3(x)
|
||||
x = self.down4(x)
|
||||
x = self.down5(x)
|
||||
x = self.res5(x)
|
||||
def forward(self, inp):
|
||||
if self.is_hd:
|
||||
x = tf.concat([ nn.flatten(self.down1(inp)),
|
||||
nn.flatten(self.down2(inp)),
|
||||
nn.flatten(self.down3(inp)),
|
||||
nn.flatten(self.down4(inp)) ], -1 )
|
||||
else:
|
||||
x = self.down1(x)
|
||||
x = nn.flatten(x)
|
||||
if 'u' in opts:
|
||||
x = nn.pixel_norm(x, axes=-1)
|
||||
|
||||
if use_fp16:
|
||||
x = tf.cast(x, tf.float32)
|
||||
x = nn.flatten(self.down1(inp))
|
||||
return x
|
||||
|
||||
def get_out_res(self, res):
|
||||
return res // ( (2**4) if 't' not in opts else (2**5) )
|
||||
|
||||
def get_out_ch(self):
|
||||
return self.e_ch * 8
|
||||
|
||||
lowest_dense_res = resolution // (32 if 'd' in opts else 16)
|
||||
lowest_dense_res = resolution // 16
|
||||
|
||||
class Inter(nn.ModelBase):
|
||||
def __init__(self, in_ch, ae_ch, ae_out_ch, **kwargs):
|
||||
def __init__(self, in_ch, ae_ch, ae_out_ch, is_hd=False, **kwargs):
|
||||
self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
@ -143,81 +126,335 @@ class DeepFakeArchi(nn.ArchiBase):
|
|||
|
||||
self.dense1 = nn.Dense( in_ch, ae_ch )
|
||||
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
|
||||
if 't' not in opts:
|
||||
self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
|
||||
self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
|
||||
|
||||
def forward(self, inp):
|
||||
x = inp
|
||||
x = self.dense1(x)
|
||||
x = self.dense1(inp)
|
||||
x = self.dense2(x)
|
||||
x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
|
||||
|
||||
if use_fp16:
|
||||
x = tf.cast(x, tf.float16)
|
||||
|
||||
if 't' not in opts:
|
||||
x = self.upscale1(x)
|
||||
|
||||
x = self.upscale1(x)
|
||||
return x
|
||||
|
||||
def get_out_res(self):
|
||||
return lowest_dense_res * 2 if 't' not in opts else lowest_dense_res
|
||||
@staticmethod
|
||||
def get_code_res():
|
||||
return lowest_dense_res
|
||||
|
||||
def get_out_ch(self):
|
||||
return self.ae_out_ch
|
||||
|
||||
class Decoder(nn.ModelBase):
|
||||
def on_build(self, in_ch, d_ch, d_mask_ch):
|
||||
if 't' not in opts:
|
||||
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
|
||||
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
|
||||
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
|
||||
def on_build(self, in_ch, d_ch, d_mask_ch, is_hd ):
|
||||
self.is_hd = is_hd
|
||||
|
||||
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
|
||||
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
|
||||
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
|
||||
|
||||
if is_hd:
|
||||
self.res0 = UpdownResidualBlock(in_ch, d_ch*8, kernel_size=3)
|
||||
self.res1 = UpdownResidualBlock(d_ch*8, d_ch*4, kernel_size=3)
|
||||
self.res2 = UpdownResidualBlock(d_ch*4, d_ch*2, kernel_size=3)
|
||||
self.res3 = UpdownResidualBlock(d_ch*2, d_ch, kernel_size=3)
|
||||
else:
|
||||
self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
|
||||
self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
|
||||
self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
|
||||
|
||||
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
|
||||
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
|
||||
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
|
||||
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
|
||||
|
||||
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
|
||||
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
|
||||
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
|
||||
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
|
||||
|
||||
if 'd' in opts:
|
||||
self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
|
||||
self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
else:
|
||||
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
def forward(self, inp):
|
||||
z = inp
|
||||
|
||||
if self.is_hd:
|
||||
x, upx = self.res0(z)
|
||||
x = self.upscale0(x)
|
||||
x = tf.nn.leaky_relu(x + upx, 0.2)
|
||||
x, upx = self.res1(x)
|
||||
|
||||
x = self.upscale1(x)
|
||||
x = tf.nn.leaky_relu(x + upx, 0.2)
|
||||
x, upx = self.res2(x)
|
||||
|
||||
x = self.upscale2(x)
|
||||
x = tf.nn.leaky_relu(x + upx, 0.2)
|
||||
x, upx = self.res3(x)
|
||||
else:
|
||||
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
|
||||
self.upscale1 = Upscale(d_ch*8, d_ch*8, kernel_size=3)
|
||||
self.upscale2 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
|
||||
self.upscale3 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
|
||||
self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
|
||||
self.res1 = ResidualBlock(d_ch*8, kernel_size=3)
|
||||
self.res2 = ResidualBlock(d_ch*4, kernel_size=3)
|
||||
self.res3 = ResidualBlock(d_ch*2, kernel_size=3)
|
||||
x = self.upscale0(z)
|
||||
x = self.res0(x)
|
||||
x = self.upscale1(x)
|
||||
x = self.res1(x)
|
||||
x = self.upscale2(x)
|
||||
x = self.res2(x)
|
||||
|
||||
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
|
||||
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*8, kernel_size=3)
|
||||
self.upscalem2 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
|
||||
self.upscalem3 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
|
||||
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
m = self.upscalem0(z)
|
||||
m = self.upscalem1(m)
|
||||
m = self.upscalem2(m)
|
||||
|
||||
if 'd' in opts:
|
||||
self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
self.upscalem4 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
|
||||
self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
else:
|
||||
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
return tf.nn.sigmoid(self.out_conv(x)), \
|
||||
tf.nn.sigmoid(self.out_convm(m))
|
||||
|
||||
elif mod == 'quick':
|
||||
class Downscale(nn.ModelBase):
|
||||
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
|
||||
self.in_ch = in_ch
|
||||
self.out_ch = out_ch
|
||||
self.kernel_size = kernel_size
|
||||
self.dilations = dilations
|
||||
self.subpixel = subpixel
|
||||
self.use_activator = use_activator
|
||||
super().__init__(*kwargs)
|
||||
|
||||
def on_build(self, *args, **kwargs ):
|
||||
self.conv1 = nn.Conv2D( self.in_ch,
|
||||
self.out_ch // (4 if self.subpixel else 1),
|
||||
kernel_size=self.kernel_size,
|
||||
strides=1 if self.subpixel else 2,
|
||||
padding='SAME', dilations=self.dilations )
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
|
||||
if self.subpixel:
|
||||
x = nn.space_to_depth(x, 2)
|
||||
|
||||
if self.use_activator:
|
||||
x = nn.gelu(x)
|
||||
return x
|
||||
|
||||
def get_out_ch(self):
|
||||
return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch
|
||||
|
||||
class DownscaleBlock(nn.ModelBase):
|
||||
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
|
||||
self.downs = []
|
||||
|
||||
last_ch = in_ch
|
||||
for i in range(n_downscales):
|
||||
cur_ch = ch*( min(2**i, 8) )
|
||||
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
|
||||
last_ch = self.downs[-1].get_out_ch()
|
||||
|
||||
def forward(self, inp):
|
||||
x = inp
|
||||
for down in self.downs:
|
||||
x = down(x)
|
||||
return x
|
||||
|
||||
class Upscale(nn.ModelBase):
|
||||
def on_build(self, in_ch, out_ch, kernel_size=3 ):
|
||||
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = nn.gelu(x)
|
||||
x = nn.depth_to_space(x, 2)
|
||||
return x
|
||||
|
||||
class ResidualBlock(nn.ModelBase):
|
||||
def on_build(self, ch, kernel_size=3 ):
|
||||
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
|
||||
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
|
||||
|
||||
def forward(self, inp):
|
||||
x = self.conv1(inp)
|
||||
x = nn.gelu(x)
|
||||
x = self.conv2(x)
|
||||
x = inp + x
|
||||
x = nn.gelu(x)
|
||||
return x
|
||||
|
||||
class Encoder(nn.ModelBase):
|
||||
def on_build(self, in_ch, e_ch):
|
||||
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5)
|
||||
def forward(self, inp):
|
||||
return nn.flatten(self.down1(inp))
|
||||
|
||||
lowest_dense_res = resolution // 16
|
||||
|
||||
class Inter(nn.ModelBase):
|
||||
def __init__(self, in_ch, ae_ch, ae_out_ch, d_ch, **kwargs):
|
||||
self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch = in_ch, ae_ch, ae_out_ch, d_ch
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def on_build(self):
|
||||
in_ch, ae_ch, ae_out_ch, d_ch = self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch
|
||||
|
||||
self.dense1 = nn.Dense( in_ch, ae_ch, kernel_initializer=tf.initializers.orthogonal )
|
||||
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch, kernel_initializer=tf.initializers.orthogonal )
|
||||
self.upscale1 = Upscale(ae_out_ch, d_ch*8)
|
||||
self.res1 = ResidualBlock(d_ch*8)
|
||||
|
||||
def forward(self, inp):
|
||||
x = self.dense1(inp)
|
||||
x = self.dense2(x)
|
||||
x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
|
||||
x = self.upscale1(x)
|
||||
x = self.res1(x)
|
||||
return x
|
||||
|
||||
def get_out_ch(self):
|
||||
return self.ae_out_ch
|
||||
|
||||
class Decoder(nn.ModelBase):
|
||||
def on_build(self, in_ch, d_ch):
|
||||
self.upscale1 = Upscale(in_ch, d_ch*4)
|
||||
self.res1 = ResidualBlock(d_ch*4)
|
||||
self.upscale2 = Upscale(d_ch*4, d_ch*2)
|
||||
self.res2 = ResidualBlock(d_ch*2)
|
||||
self.upscale3 = Upscale(d_ch*2, d_ch*1)
|
||||
self.res3 = ResidualBlock(d_ch*1)
|
||||
|
||||
self.upscalem1 = Upscale(in_ch, d_ch)
|
||||
self.upscalem2 = Upscale(d_ch, d_ch//2)
|
||||
self.upscalem3 = Upscale(d_ch//2, d_ch//2)
|
||||
|
||||
self.out_conv = nn.Conv2D( d_ch*1, 3, kernel_size=1, padding='SAME')
|
||||
self.out_convm = nn.Conv2D( d_ch//2, 1, kernel_size=1, padding='SAME')
|
||||
|
||||
def forward(self, inp):
|
||||
z = inp
|
||||
x = self.upscale1 (z)
|
||||
x = self.res1 (x)
|
||||
x = self.upscale2 (x)
|
||||
x = self.res2 (x)
|
||||
x = self.upscale3 (x)
|
||||
x = self.res3 (x)
|
||||
|
||||
y = self.upscalem1 (z)
|
||||
y = self.upscalem2 (y)
|
||||
y = self.upscalem3 (y)
|
||||
|
||||
return tf.nn.sigmoid(self.out_conv(x)), \
|
||||
tf.nn.sigmoid(self.out_convm(y))
|
||||
elif mod == 'uhd':
|
||||
|
||||
class Downscale(nn.ModelBase):
|
||||
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
|
||||
self.in_ch = in_ch
|
||||
self.out_ch = out_ch
|
||||
self.kernel_size = kernel_size
|
||||
self.dilations = dilations
|
||||
self.subpixel = subpixel
|
||||
self.use_activator = use_activator
|
||||
super().__init__(*kwargs)
|
||||
|
||||
def on_build(self, *args, **kwargs ):
|
||||
self.conv1 = nn.Conv2D( self.in_ch,
|
||||
self.out_ch // (4 if self.subpixel else 1),
|
||||
kernel_size=self.kernel_size,
|
||||
strides=1 if self.subpixel else 2,
|
||||
padding='SAME', dilations=self.dilations)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
if self.subpixel:
|
||||
x = nn.space_to_depth(x, 2)
|
||||
if self.use_activator:
|
||||
x = tf.nn.leaky_relu(x, 0.1)
|
||||
return x
|
||||
|
||||
def get_out_ch(self):
|
||||
return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch
|
||||
|
||||
class DownscaleBlock(nn.ModelBase):
|
||||
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
|
||||
self.downs = []
|
||||
|
||||
last_ch = in_ch
|
||||
for i in range(n_downscales):
|
||||
cur_ch = ch*( min(2**i, 8) )
|
||||
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
|
||||
last_ch = self.downs[-1].get_out_ch()
|
||||
|
||||
def forward(self, inp):
|
||||
x = inp
|
||||
for down in self.downs:
|
||||
x = down(x)
|
||||
return x
|
||||
|
||||
class Upscale(nn.ModelBase):
|
||||
def on_build(self, in_ch, out_ch, kernel_size=3 ):
|
||||
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = tf.nn.leaky_relu(x, 0.1)
|
||||
x = nn.depth_to_space(x, 2)
|
||||
return x
|
||||
|
||||
class ResidualBlock(nn.ModelBase):
|
||||
def on_build(self, ch, kernel_size=3 ):
|
||||
self.conv1 = nn.Conv2D( ch, ch*2, kernel_size=kernel_size, padding='SAME')
|
||||
self.conv2 = nn.Conv2D( ch*2, ch, kernel_size=kernel_size, padding='SAME')
|
||||
self.scale_add = nn.ScaleAdd(ch)
|
||||
|
||||
def forward(self, inp):
|
||||
x = self.conv1(inp)
|
||||
x = tf.nn.leaky_relu(x, 0.2)
|
||||
x = self.conv2(x)
|
||||
x = tf.nn.leaky_relu(x, 0.2)
|
||||
x = self.scale_add([inp, x])
|
||||
return x
|
||||
|
||||
class Encoder(nn.ModelBase):
|
||||
def on_build(self, in_ch, e_ch, **kwargs):
|
||||
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False)
|
||||
|
||||
def forward(self, inp):
|
||||
x = nn.flatten(self.down1(inp))
|
||||
return x
|
||||
|
||||
lowest_dense_res = resolution // 16
|
||||
|
||||
class Inter(nn.ModelBase):
|
||||
def on_build(self, in_ch, ae_ch, ae_out_ch, **kwargs):
|
||||
self.ae_out_ch = ae_out_ch
|
||||
self.dense_norm = nn.DenseNorm()
|
||||
self.dense1 = nn.Dense( in_ch, ae_ch )
|
||||
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
|
||||
self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
|
||||
|
||||
def forward(self, inp):
|
||||
x = self.dense_norm(inp)
|
||||
x = self.dense1(x)
|
||||
x = self.dense2(x)
|
||||
x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
|
||||
x = self.upscale1(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def get_code_res():
|
||||
return lowest_dense_res
|
||||
|
||||
def get_out_ch(self):
|
||||
return self.ae_out_ch
|
||||
|
||||
class Decoder(nn.ModelBase):
|
||||
def on_build(self, in_ch, d_ch, d_mask_ch, **kwargs ):
|
||||
|
||||
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
|
||||
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
|
||||
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
|
||||
|
||||
self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
|
||||
self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
|
||||
self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
|
||||
|
||||
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
|
||||
|
||||
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
|
||||
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
|
||||
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
|
||||
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
|
||||
|
||||
def forward(self, inp):
|
||||
z = inp
|
||||
|
||||
def forward(self, z):
|
||||
x = self.upscale0(z)
|
||||
x = self.res0(x)
|
||||
x = self.upscale1(x)
|
||||
|
@ -225,38 +462,12 @@ class DeepFakeArchi(nn.ArchiBase):
|
|||
x = self.upscale2(x)
|
||||
x = self.res2(x)
|
||||
|
||||
if 't' in opts:
|
||||
x = self.upscale3(x)
|
||||
x = self.res3(x)
|
||||
|
||||
if 'd' in opts:
|
||||
x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x),
|
||||
self.out_conv1(x),
|
||||
self.out_conv2(x),
|
||||
self.out_conv3(x)), nn.conv2d_ch_axis), 2) )
|
||||
else:
|
||||
x = tf.nn.sigmoid(self.out_conv(x))
|
||||
|
||||
|
||||
m = self.upscalem0(z)
|
||||
m = self.upscalem1(m)
|
||||
m = self.upscalem2(m)
|
||||
|
||||
if 't' in opts:
|
||||
m = self.upscalem3(m)
|
||||
if 'd' in opts:
|
||||
m = self.upscalem4(m)
|
||||
else:
|
||||
if 'd' in opts:
|
||||
m = self.upscalem3(m)
|
||||
|
||||
m = tf.nn.sigmoid(self.out_convm(m))
|
||||
|
||||
if use_fp16:
|
||||
x = tf.cast(x, tf.float32)
|
||||
m = tf.cast(m, tf.float32)
|
||||
|
||||
return x, m
|
||||
return tf.nn.sigmoid(self.out_conv(x)), \
|
||||
tf.nn.sigmoid(self.out_convm(m))
|
||||
|
||||
self.Encoder = Encoder
|
||||
self.Inter = Inter
|
||||
|
|
|
@ -1,19 +1,12 @@
|
|||
import sys
|
||||
import ctypes
|
||||
import os
|
||||
import multiprocessing
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from core.interact import interact as io
|
||||
|
||||
|
||||
class Device(object):
|
||||
def __init__(self, index, tf_dev_type, name, total_mem, free_mem):
|
||||
def __init__(self, index, name, total_mem, free_mem, cc=0):
|
||||
self.index = index
|
||||
self.tf_dev_type = tf_dev_type
|
||||
self.name = name
|
||||
|
||||
self.cc = cc
|
||||
self.total_mem = total_mem
|
||||
self.total_mem_gb = total_mem / 1024**3
|
||||
self.free_mem = free_mem
|
||||
|
@ -89,135 +82,10 @@ class Devices(object):
|
|||
result.append (device)
|
||||
return Devices(result)
|
||||
|
||||
@staticmethod
|
||||
def _get_tf_devices_proc(q : multiprocessing.Queue):
|
||||
|
||||
if sys.platform[0:3] == 'win':
|
||||
compute_cache_path = Path(os.environ['APPDATA']) / 'NVIDIA' / ('ComputeCache_ALL')
|
||||
os.environ['CUDA_CACHE_PATH'] = str(compute_cache_path)
|
||||
if not compute_cache_path.exists():
|
||||
io.log_info("Caching GPU kernels...")
|
||||
compute_cache_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
import tensorflow
|
||||
|
||||
tf_version = tensorflow.version.VERSION
|
||||
#if tf_version is None:
|
||||
# tf_version = tensorflow.version.GIT_VERSION
|
||||
if tf_version[0] == 'v':
|
||||
tf_version = tf_version[1:]
|
||||
if tf_version[0] == '2':
|
||||
tf = tensorflow.compat.v1
|
||||
else:
|
||||
tf = tensorflow
|
||||
|
||||
import logging
|
||||
# Disable tensorflow warnings
|
||||
tf_logger = logging.getLogger('tensorflow')
|
||||
tf_logger.setLevel(logging.ERROR)
|
||||
|
||||
from tensorflow.python.client import device_lib
|
||||
|
||||
devices = []
|
||||
|
||||
physical_devices = device_lib.list_local_devices()
|
||||
physical_devices_f = {}
|
||||
for dev in physical_devices:
|
||||
dev_type = dev.device_type
|
||||
dev_tf_name = dev.name
|
||||
dev_tf_name = dev_tf_name[ dev_tf_name.index(dev_type) : ]
|
||||
|
||||
dev_idx = int(dev_tf_name.split(':')[-1])
|
||||
|
||||
if dev_type in ['GPU','DML']:
|
||||
dev_name = dev_tf_name
|
||||
|
||||
dev_desc = dev.physical_device_desc
|
||||
if len(dev_desc) != 0:
|
||||
if dev_desc[0] == '{':
|
||||
dev_desc_json = json.loads(dev_desc)
|
||||
dev_desc_json_name = dev_desc_json.get('name',None)
|
||||
if dev_desc_json_name is not None:
|
||||
dev_name = dev_desc_json_name
|
||||
else:
|
||||
for param, value in ( v.split(':') for v in dev_desc.split(',') ):
|
||||
param = param.strip()
|
||||
value = value.strip()
|
||||
if param == 'name':
|
||||
dev_name = value
|
||||
break
|
||||
|
||||
physical_devices_f[dev_idx] = (dev_type, dev_name, dev.memory_limit)
|
||||
|
||||
q.put(physical_devices_f)
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def initialize_main_env():
|
||||
if int(os.environ.get("NN_DEVICES_INITIALIZED", 0)) != 0:
|
||||
return
|
||||
|
||||
if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
|
||||
os.environ.pop('CUDA_VISIBLE_DEVICES')
|
||||
|
||||
os.environ['TF_DIRECTML_KERNEL_CACHE_SIZE'] = '2500'
|
||||
os.environ['CUDA_CACHE_MAXSIZE'] = '2147483647'
|
||||
os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2'
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # tf log errors only
|
||||
|
||||
q = multiprocessing.Queue()
|
||||
p = multiprocessing.Process(target=Devices._get_tf_devices_proc, args=(q,), daemon=True)
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
visible_devices = q.get()
|
||||
|
||||
os.environ['NN_DEVICES_INITIALIZED'] = '1'
|
||||
os.environ['NN_DEVICES_COUNT'] = str(len(visible_devices))
|
||||
|
||||
for i in visible_devices:
|
||||
dev_type, name, total_mem = visible_devices[i]
|
||||
|
||||
os.environ[f'NN_DEVICE_{i}_TF_DEV_TYPE'] = dev_type
|
||||
os.environ[f'NN_DEVICE_{i}_NAME'] = name
|
||||
os.environ[f'NN_DEVICE_{i}_TOTAL_MEM'] = str(total_mem)
|
||||
os.environ[f'NN_DEVICE_{i}_FREE_MEM'] = str(total_mem)
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def getDevices():
|
||||
if Devices.all_devices is None:
|
||||
if int(os.environ.get("NN_DEVICES_INITIALIZED", 0)) != 1:
|
||||
raise Exception("nn devices are not initialized. Run initialize_main_env() in main process.")
|
||||
devices = []
|
||||
for i in range ( int(os.environ['NN_DEVICES_COUNT']) ):
|
||||
devices.append ( Device(index=i,
|
||||
tf_dev_type=os.environ[f'NN_DEVICE_{i}_TF_DEV_TYPE'],
|
||||
name=os.environ[f'NN_DEVICE_{i}_NAME'],
|
||||
total_mem=int(os.environ[f'NN_DEVICE_{i}_TOTAL_MEM']),
|
||||
free_mem=int(os.environ[f'NN_DEVICE_{i}_FREE_MEM']), )
|
||||
)
|
||||
Devices.all_devices = Devices(devices)
|
||||
|
||||
return Devices.all_devices
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# {'name' : name.split(b'\0', 1)[0].decode(),
|
||||
# 'total_mem' : totalMem.value
|
||||
# }
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
return
|
||||
|
||||
|
||||
|
||||
os.environ['NN_DEVICES_COUNT'] = '0'
|
||||
|
||||
min_cc = int(os.environ.get("TF_MIN_REQ_CAP", 35))
|
||||
libnames = ('libcuda.so', 'libcuda.dylib', 'nvcuda.dll')
|
||||
|
@ -270,4 +138,70 @@ class Devices(object):
|
|||
os.environ[f'NN_DEVICE_{i}_TOTAL_MEM'] = str(device['total_mem'])
|
||||
os.environ[f'NN_DEVICE_{i}_FREE_MEM'] = str(device['free_mem'])
|
||||
os.environ[f'NN_DEVICE_{i}_CC'] = str(device['cc'])
|
||||
|
||||
@staticmethod
|
||||
def getDevices():
|
||||
if Devices.all_devices is None:
|
||||
if int(os.environ.get("NN_DEVICES_INITIALIZED", 0)) != 1:
|
||||
raise Exception("nn devices are not initialized. Run initialize_main_env() in main process.")
|
||||
devices = []
|
||||
for i in range ( int(os.environ['NN_DEVICES_COUNT']) ):
|
||||
devices.append ( Device(index=i,
|
||||
name=os.environ[f'NN_DEVICE_{i}_NAME'],
|
||||
total_mem=int(os.environ[f'NN_DEVICE_{i}_TOTAL_MEM']),
|
||||
free_mem=int(os.environ[f'NN_DEVICE_{i}_FREE_MEM']),
|
||||
cc=int(os.environ[f'NN_DEVICE_{i}_CC']) ))
|
||||
Devices.all_devices = Devices(devices)
|
||||
|
||||
return Devices.all_devices
|
||||
|
||||
"""
|
||||
if Devices.all_devices is None:
|
||||
min_cc = int(os.environ.get("TF_MIN_REQ_CAP", 35))
|
||||
|
||||
libnames = ('libcuda.so', 'libcuda.dylib', 'nvcuda.dll')
|
||||
for libname in libnames:
|
||||
try:
|
||||
cuda = ctypes.CDLL(libname)
|
||||
except:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
return Devices([])
|
||||
|
||||
nGpus = ctypes.c_int()
|
||||
name = b' ' * 200
|
||||
cc_major = ctypes.c_int()
|
||||
cc_minor = ctypes.c_int()
|
||||
freeMem = ctypes.c_size_t()
|
||||
totalMem = ctypes.c_size_t()
|
||||
|
||||
result = ctypes.c_int()
|
||||
device = ctypes.c_int()
|
||||
context = ctypes.c_void_p()
|
||||
error_str = ctypes.c_char_p()
|
||||
|
||||
devices = []
|
||||
|
||||
if cuda.cuInit(0) == 0 and \
|
||||
cuda.cuDeviceGetCount(ctypes.byref(nGpus)) == 0:
|
||||
for i in range(nGpus.value):
|
||||
if cuda.cuDeviceGet(ctypes.byref(device), i) != 0 or \
|
||||
cuda.cuDeviceGetName(ctypes.c_char_p(name), len(name), device) != 0 or \
|
||||
cuda.cuDeviceComputeCapability(ctypes.byref(cc_major), ctypes.byref(cc_minor), device) != 0:
|
||||
continue
|
||||
|
||||
if cuda.cuCtxCreate_v2(ctypes.byref(context), 0, device) == 0:
|
||||
if cuda.cuMemGetInfo_v2(ctypes.byref(freeMem), ctypes.byref(totalMem)) == 0:
|
||||
cc = cc_major.value * 10 + cc_minor.value
|
||||
if cc >= min_cc:
|
||||
devices.append ( Device(index=i,
|
||||
name=name.split(b'\0', 1)[0].decode(),
|
||||
total_mem=totalMem.value,
|
||||
free_mem=freeMem.value,
|
||||
cc=cc) )
|
||||
cuda.cuCtxDetach(context)
|
||||
Devices.all_devices = Devices(devices)
|
||||
return Devices.all_devices
|
||||
"""
|
|
@ -23,13 +23,28 @@ class Conv2D(nn.LayerBase):
|
|||
if padding == "SAME":
|
||||
padding = ( (kernel_size - 1) * dilations + 1 ) // 2
|
||||
elif padding == "VALID":
|
||||
padding = None
|
||||
padding = 0
|
||||
else:
|
||||
raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs")
|
||||
|
||||
if isinstance(padding, int):
|
||||
if padding != 0:
|
||||
if nn.data_format == "NHWC":
|
||||
padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
|
||||
else:
|
||||
padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
|
||||
else:
|
||||
padding = None
|
||||
|
||||
if nn.data_format == "NHWC":
|
||||
strides = [1,strides,strides,1]
|
||||
else:
|
||||
padding = int(padding)
|
||||
|
||||
strides = [1,1,strides,strides]
|
||||
|
||||
if nn.data_format == "NHWC":
|
||||
dilations = [1,dilations,dilations,1]
|
||||
else:
|
||||
dilations = [1,1,dilations,dilations]
|
||||
|
||||
self.in_ch = in_ch
|
||||
self.out_ch = out_ch
|
||||
|
@ -55,8 +70,8 @@ class Conv2D(nn.LayerBase):
|
|||
if kernel_initializer is None:
|
||||
kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)
|
||||
|
||||
#if kernel_initializer is None:
|
||||
# kernel_initializer = nn.initializers.ca()
|
||||
if kernel_initializer is None:
|
||||
kernel_initializer = nn.initializers.ca()
|
||||
|
||||
self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.in_ch,self.out_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )
|
||||
|
||||
|
@ -78,27 +93,10 @@ class Conv2D(nn.LayerBase):
|
|||
if self.use_wscale:
|
||||
weight = weight * self.wscale
|
||||
|
||||
padding = self.padding
|
||||
if padding is not None:
|
||||
if nn.data_format == "NHWC":
|
||||
padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
|
||||
else:
|
||||
padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
|
||||
x = tf.pad (x, padding, mode='CONSTANT')
|
||||
if self.padding is not None:
|
||||
x = tf.pad (x, self.padding, mode='CONSTANT')
|
||||
|
||||
strides = self.strides
|
||||
if nn.data_format == "NHWC":
|
||||
strides = [1,strides,strides,1]
|
||||
else:
|
||||
strides = [1,1,strides,strides]
|
||||
|
||||
dilations = self.dilations
|
||||
if nn.data_format == "NHWC":
|
||||
dilations = [1,dilations,dilations,1]
|
||||
else:
|
||||
dilations = [1,1,dilations,dilations]
|
||||
|
||||
x = tf.nn.conv2d(x, weight, strides, 'VALID', dilations=dilations, data_format=nn.data_format)
|
||||
x = tf.nn.conv2d(x, weight, self.strides, 'VALID', dilations=self.dilations, data_format=nn.data_format)
|
||||
if self.use_bias:
|
||||
if nn.data_format == "NHWC":
|
||||
bias = tf.reshape (self.bias, (1,1,1,self.out_ch) )
|
||||
|
|
|
@ -38,8 +38,8 @@ class Conv2DTranspose(nn.LayerBase):
|
|||
if kernel_initializer is None:
|
||||
kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)
|
||||
|
||||
#if kernel_initializer is None:
|
||||
# kernel_initializer = nn.initializers.ca()
|
||||
if kernel_initializer is None:
|
||||
kernel_initializer = nn.initializers.ca()
|
||||
self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.out_ch,self.in_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )
|
||||
|
||||
if self.use_bias:
|
||||
|
|
|
@ -1,110 +0,0 @@
|
|||
import numpy as np
|
||||
from core.leras import nn
|
||||
tf = nn.tf
|
||||
|
||||
class DepthwiseConv2D(nn.LayerBase):
|
||||
"""
|
||||
default kernel_initializer - CA
|
||||
use_wscale bool enables equalized learning rate, if kernel_initializer is None, it will be forced to random_normal
|
||||
"""
|
||||
def __init__(self, in_ch, kernel_size, strides=1, padding='SAME', depth_multiplier=1, dilations=1, use_bias=True, use_wscale=False, kernel_initializer=None, bias_initializer=None, trainable=True, dtype=None, **kwargs ):
|
||||
if not isinstance(strides, int):
|
||||
raise ValueError ("strides must be an int type")
|
||||
if not isinstance(dilations, int):
|
||||
raise ValueError ("dilations must be an int type")
|
||||
kernel_size = int(kernel_size)
|
||||
|
||||
if dtype is None:
|
||||
dtype = nn.floatx
|
||||
|
||||
if isinstance(padding, str):
|
||||
if padding == "SAME":
|
||||
padding = ( (kernel_size - 1) * dilations + 1 ) // 2
|
||||
elif padding == "VALID":
|
||||
padding = 0
|
||||
else:
|
||||
raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs")
|
||||
|
||||
if isinstance(padding, int):
|
||||
if padding != 0:
|
||||
if nn.data_format == "NHWC":
|
||||
padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
|
||||
else:
|
||||
padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
|
||||
else:
|
||||
padding = None
|
||||
|
||||
if nn.data_format == "NHWC":
|
||||
strides = [1,strides,strides,1]
|
||||
else:
|
||||
strides = [1,1,strides,strides]
|
||||
|
||||
if nn.data_format == "NHWC":
|
||||
dilations = [1,dilations,dilations,1]
|
||||
else:
|
||||
dilations = [1,1,dilations,dilations]
|
||||
|
||||
self.in_ch = in_ch
|
||||
self.depth_multiplier = depth_multiplier
|
||||
self.kernel_size = kernel_size
|
||||
self.strides = strides
|
||||
self.padding = padding
|
||||
self.dilations = dilations
|
||||
self.use_bias = use_bias
|
||||
self.use_wscale = use_wscale
|
||||
self.kernel_initializer = kernel_initializer
|
||||
self.bias_initializer = bias_initializer
|
||||
self.trainable = trainable
|
||||
self.dtype = dtype
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def build_weights(self):
|
||||
kernel_initializer = self.kernel_initializer
|
||||
if self.use_wscale:
|
||||
gain = 1.0 if self.kernel_size == 1 else np.sqrt(2)
|
||||
fan_in = self.kernel_size*self.kernel_size*self.in_ch
|
||||
he_std = gain / np.sqrt(fan_in)
|
||||
self.wscale = tf.constant(he_std, dtype=self.dtype )
|
||||
if kernel_initializer is None:
|
||||
kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)
|
||||
|
||||
#if kernel_initializer is None:
|
||||
# kernel_initializer = nn.initializers.ca()
|
||||
|
||||
self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.in_ch,self.depth_multiplier), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )
|
||||
|
||||
if self.use_bias:
|
||||
bias_initializer = self.bias_initializer
|
||||
if bias_initializer is None:
|
||||
bias_initializer = tf.initializers.zeros(dtype=self.dtype)
|
||||
|
||||
self.bias = tf.get_variable("bias", (self.in_ch*self.depth_multiplier,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable )
|
||||
|
||||
def get_weights(self):
|
||||
weights = [self.weight]
|
||||
if self.use_bias:
|
||||
weights += [self.bias]
|
||||
return weights
|
||||
|
||||
def forward(self, x):
|
||||
weight = self.weight
|
||||
if self.use_wscale:
|
||||
weight = weight * self.wscale
|
||||
|
||||
if self.padding is not None:
|
||||
x = tf.pad (x, self.padding, mode='CONSTANT')
|
||||
|
||||
x = tf.nn.depthwise_conv2d(x, weight, self.strides, 'VALID', data_format=nn.data_format)
|
||||
if self.use_bias:
|
||||
if nn.data_format == "NHWC":
|
||||
bias = tf.reshape (self.bias, (1,1,1,self.in_ch*self.depth_multiplier) )
|
||||
else:
|
||||
bias = tf.reshape (self.bias, (1,self.in_ch*self.depth_multiplier,1,1) )
|
||||
x = tf.add(x, bias)
|
||||
return x
|
||||
|
||||
def __str__(self):
|
||||
r = f"{self.__class__.__name__} : in_ch:{self.in_ch} depth_multiplier:{self.depth_multiplier} "
|
||||
return r
|
||||
|
||||
nn.DepthwiseConv2D = DepthwiseConv2D
|
|
@ -46,9 +46,7 @@ class Saveable():
|
|||
raise Exception("name must be defined.")
|
||||
|
||||
name = self.name
|
||||
|
||||
for w in weights:
|
||||
w_val = nn.tf_sess.run (w).copy()
|
||||
for w, w_val in zip(weights, nn.tf_sess.run (weights)):
|
||||
w_name_split = w.name.split('/', 1)
|
||||
if name != w_name_split[0]:
|
||||
raise Exception("weight first name != Saveable.name")
|
||||
|
@ -78,27 +76,24 @@ class Saveable():
|
|||
if self.name is None:
|
||||
raise Exception("name must be defined.")
|
||||
|
||||
try:
|
||||
tuples = []
|
||||
for w in weights:
|
||||
w_name_split = w.name.split('/')
|
||||
if self.name != w_name_split[0]:
|
||||
raise Exception("weight first name != Saveable.name")
|
||||
tuples = []
|
||||
for w in weights:
|
||||
w_name_split = w.name.split('/')
|
||||
if self.name != w_name_split[0]:
|
||||
raise Exception("weight first name != Saveable.name")
|
||||
|
||||
sub_w_name = "/".join(w_name_split[1:])
|
||||
sub_w_name = "/".join(w_name_split[1:])
|
||||
|
||||
w_val = d.get(sub_w_name, None)
|
||||
w_val = d.get(sub_w_name, None)
|
||||
|
||||
if w_val is None:
|
||||
#io.log_err(f"Weight {w.name} was not loaded from file {filename}")
|
||||
tuples.append ( (w, w.initializer) )
|
||||
else:
|
||||
w_val = np.reshape( w_val, w.shape.as_list() )
|
||||
tuples.append ( (w, w_val) )
|
||||
if w_val is None:
|
||||
#io.log_err(f"Weight {w.name} was not loaded from file {filename}")
|
||||
tuples.append ( (w, w.initializer) )
|
||||
else:
|
||||
w_val = np.reshape( w_val, w.shape.as_list() )
|
||||
tuples.append ( (w, w_val) )
|
||||
|
||||
nn.batch_set_value(tuples)
|
||||
except:
|
||||
return False
|
||||
nn.batch_set_value(tuples)
|
||||
|
||||
return True
|
||||
|
||||
|
|
|
@ -1,104 +0,0 @@
|
|||
import numpy as np
|
||||
from core.leras import nn
|
||||
tf = nn.tf
|
||||
|
||||
class TanhPolar(nn.LayerBase):
|
||||
"""
|
||||
RoI Tanh-polar Transformer Network for Face Parsing in the Wild
|
||||
https://github.com/hhj1897/roi_tanh_warping
|
||||
"""
|
||||
|
||||
def __init__(self, width, height, angular_offset_deg=270, **kwargs):
|
||||
self.width = width
|
||||
self.height = height
|
||||
|
||||
warp_gridx, warp_gridy = TanhPolar._get_tanh_polar_warp_grids(width,height,angular_offset_deg=angular_offset_deg)
|
||||
restore_gridx, restore_gridy = TanhPolar._get_tanh_polar_restore_grids(width,height,angular_offset_deg=angular_offset_deg)
|
||||
|
||||
self.warp_gridx_t = tf.constant(warp_gridx[None, ...])
|
||||
self.warp_gridy_t = tf.constant(warp_gridy[None, ...])
|
||||
self.restore_gridx_t = tf.constant(restore_gridx[None, ...])
|
||||
self.restore_gridy_t = tf.constant(restore_gridy[None, ...])
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def warp(self, inp_t):
|
||||
batch_t = tf.shape(inp_t)[0]
|
||||
warp_gridx_t = tf.tile(self.warp_gridx_t, (batch_t,1,1) )
|
||||
warp_gridy_t = tf.tile(self.warp_gridy_t, (batch_t,1,1) )
|
||||
|
||||
if nn.data_format == "NCHW":
|
||||
inp_t = tf.transpose(inp_t,(0,2,3,1))
|
||||
|
||||
out_t = nn.bilinear_sampler(inp_t, warp_gridx_t, warp_gridy_t)
|
||||
|
||||
if nn.data_format == "NCHW":
|
||||
out_t = tf.transpose(out_t,(0,3,1,2))
|
||||
|
||||
return out_t
|
||||
|
||||
def restore(self, inp_t):
|
||||
batch_t = tf.shape(inp_t)[0]
|
||||
restore_gridx_t = tf.tile(self.restore_gridx_t, (batch_t,1,1) )
|
||||
restore_gridy_t = tf.tile(self.restore_gridy_t, (batch_t,1,1) )
|
||||
|
||||
if nn.data_format == "NCHW":
|
||||
inp_t = tf.transpose(inp_t,(0,2,3,1))
|
||||
|
||||
inp_t = tf.pad(inp_t, [(0,0), (1, 1), (1, 0), (0, 0)], "SYMMETRIC")
|
||||
|
||||
out_t = nn.bilinear_sampler(inp_t, restore_gridx_t, restore_gridy_t)
|
||||
|
||||
if nn.data_format == "NCHW":
|
||||
out_t = tf.transpose(out_t,(0,3,1,2))
|
||||
|
||||
return out_t
|
||||
|
||||
@staticmethod
|
||||
def _get_tanh_polar_warp_grids(W,H,angular_offset_deg):
|
||||
angular_offset_pi = angular_offset_deg * np.pi / 180.0
|
||||
|
||||
roi_center = np.array([ W//2, H//2], np.float32 )
|
||||
roi_radii = np.array([W, H], np.float32 ) / np.pi ** 0.5
|
||||
cos_offset, sin_offset = np.cos(angular_offset_pi), np.sin(angular_offset_pi)
|
||||
normalised_dest_indices = np.stack(np.meshgrid(np.arange(0.0, 1.0, 1.0 / W),np.arange(0.0, 2.0 * np.pi, 2.0 * np.pi / H)), axis=-1)
|
||||
radii = normalised_dest_indices[..., 0]
|
||||
orientation_x = np.cos(normalised_dest_indices[..., 1])
|
||||
orientation_y = np.sin(normalised_dest_indices[..., 1])
|
||||
|
||||
src_radii = np.arctanh(radii) * (roi_radii[0] * roi_radii[1] / np.sqrt(roi_radii[1] ** 2 * orientation_x ** 2 + roi_radii[0] ** 2 * orientation_y ** 2))
|
||||
src_x_indices = src_radii * orientation_x
|
||||
src_y_indices = src_radii * orientation_y
|
||||
src_x_indices, src_y_indices = (roi_center[0] + cos_offset * src_x_indices - sin_offset * src_y_indices,
|
||||
roi_center[1] + cos_offset * src_y_indices + sin_offset * src_x_indices)
|
||||
|
||||
return src_x_indices.astype(np.float32), src_y_indices.astype(np.float32)
|
||||
|
||||
@staticmethod
|
||||
def _get_tanh_polar_restore_grids(W,H,angular_offset_deg):
|
||||
angular_offset_pi = angular_offset_deg * np.pi / 180.0
|
||||
|
||||
roi_center = np.array([ W//2, H//2], np.float32 )
|
||||
roi_radii = np.array([W, H], np.float32 ) / np.pi ** 0.5
|
||||
cos_offset, sin_offset = np.cos(angular_offset_pi), np.sin(angular_offset_pi)
|
||||
|
||||
dest_indices = np.stack(np.meshgrid(np.arange(W), np.arange(H)), axis=-1).astype(float)
|
||||
normalised_dest_indices = np.matmul(dest_indices - roi_center, np.array([[cos_offset, -sin_offset],
|
||||
[sin_offset, cos_offset]]))
|
||||
radii = np.linalg.norm(normalised_dest_indices, axis=-1)
|
||||
normalised_dest_indices[..., 0] /= np.clip(radii, 1e-9, None)
|
||||
normalised_dest_indices[..., 1] /= np.clip(radii, 1e-9, None)
|
||||
radii *= np.sqrt(roi_radii[1] ** 2 * normalised_dest_indices[..., 0] ** 2 +
|
||||
roi_radii[0] ** 2 * normalised_dest_indices[..., 1] ** 2) / roi_radii[0] / roi_radii[1]
|
||||
|
||||
src_radii = np.tanh(radii)
|
||||
|
||||
|
||||
src_x_indices = src_radii * W + 1.0
|
||||
src_y_indices = np.mod((np.arctan2(normalised_dest_indices[..., 1], normalised_dest_indices[..., 0]) /
|
||||
2.0 / np.pi) * H, H) + 1.0
|
||||
|
||||
return src_x_indices.astype(np.float32), src_y_indices.astype(np.float32)
|
||||
|
||||
|
||||
nn.TanhPolar = TanhPolar
|
|
@ -3,16 +3,12 @@ from .LayerBase import *
|
|||
|
||||
from .Conv2D import *
|
||||
from .Conv2DTranspose import *
|
||||
from .DepthwiseConv2D import *
|
||||
from .Dense import *
|
||||
from .BlurPool import *
|
||||
|
||||
from .BatchNorm2D import *
|
||||
from .InstanceNorm2D import *
|
||||
from .FRNorm2D import *
|
||||
|
||||
from .TLU import *
|
||||
from .ScaleAdd import *
|
||||
from .DenseNorm import *
|
||||
from .AdaIN import *
|
||||
from .TanhPolar import *
|
|
@ -18,10 +18,6 @@ class ModelBase(nn.Saveable):
|
|||
if isinstance (layer, list):
|
||||
for i,sublayer in enumerate(layer):
|
||||
self._build_sub(sublayer, f"{name}_{i}")
|
||||
elif isinstance (layer, dict):
|
||||
for subname in layer.keys():
|
||||
sublayer = layer[subname]
|
||||
self._build_sub(sublayer, f"{name}_{subname}")
|
||||
elif isinstance (layer, nn.LayerBase) or \
|
||||
isinstance (layer, ModelBase):
|
||||
|
||||
|
@ -116,32 +112,41 @@ class ModelBase(nn.Saveable):
|
|||
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
# def compute_output_shape(self, shapes):
|
||||
# if not self.built:
|
||||
# self.build()
|
||||
def compute_output_shape(self, shapes):
|
||||
if not self.built:
|
||||
self.build()
|
||||
|
||||
# not_list = False
|
||||
# if not isinstance(shapes, list):
|
||||
# not_list = True
|
||||
# shapes = [shapes]
|
||||
not_list = False
|
||||
if not isinstance(shapes, list):
|
||||
not_list = True
|
||||
shapes = [shapes]
|
||||
|
||||
# with tf.device('/CPU:0'):
|
||||
# # CPU tensors will not impact any performance, only slightly RAM "leakage"
|
||||
# phs = []
|
||||
# for dtype,sh in shapes:
|
||||
# phs += [ tf.placeholder(dtype, sh) ]
|
||||
with tf.device('/CPU:0'):
|
||||
# CPU tensors will not impact any performance, only slightly RAM "leakage"
|
||||
phs = []
|
||||
for dtype,sh in shapes:
|
||||
phs += [ tf.placeholder(dtype, sh) ]
|
||||
|
||||
# result = self.__call__(phs[0] if not_list else phs)
|
||||
result = self.__call__(phs[0] if not_list else phs)
|
||||
|
||||
# if not isinstance(result, list):
|
||||
# result = [result]
|
||||
if not isinstance(result, list):
|
||||
result = [result]
|
||||
|
||||
# result_shapes = []
|
||||
result_shapes = []
|
||||
|
||||
# for t in result:
|
||||
# result_shapes += [ t.shape.as_list() ]
|
||||
for t in result:
|
||||
result_shapes += [ t.shape.as_list() ]
|
||||
|
||||
# return result_shapes[0] if not_list else result_shapes
|
||||
return result_shapes[0] if not_list else result_shapes
|
||||
|
||||
def compute_output_channels(self, shapes):
|
||||
shape = self.compute_output_shape(shapes)
|
||||
shape_len = len(shape)
|
||||
|
||||
if shape_len == 4:
|
||||
if nn.data_format == "NCHW":
|
||||
return shape[1]
|
||||
return shape[-1]
|
||||
|
||||
def build_for_run(self, shapes_list):
|
||||
if not isinstance(shapes_list, list):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import numpy as np
|
||||
from core.leras import nn
|
||||
tf = nn.tf
|
||||
|
||||
|
||||
patch_discriminator_kernels = \
|
||||
{ 1 : (512, [ [1,1] ]),
|
||||
2 : (512, [ [2,1] ]),
|
||||
|
@ -41,14 +41,6 @@ patch_discriminator_kernels = \
|
|||
36 : (256, [ [4,2], [3,2], [4,2], [3,1] ]),
|
||||
37 : (256, [ [3,2], [4,2], [4,2], [3,1] ]),
|
||||
38 : (256, [ [4,2], [4,2], [4,2], [3,1] ]),
|
||||
39 : (256, [ [3,2], [3,2], [3,2], [4,1] ]),
|
||||
40 : (256, [ [4,2], [3,2], [3,2], [4,1] ]),
|
||||
41 : (256, [ [3,2], [4,2], [3,2], [4,1] ]),
|
||||
42 : (256, [ [4,2], [4,2], [3,2], [4,1] ]),
|
||||
43 : (256, [ [3,2], [4,2], [4,2], [4,1] ]),
|
||||
44 : (256, [ [4,2], [3,2], [4,2], [4,1] ]),
|
||||
45 : (256, [ [3,2], [4,2], [4,2], [4,1] ]),
|
||||
46 : (256, [ [4,2], [4,2], [4,2], [4,1] ]),
|
||||
}
|
||||
|
||||
|
||||
|
@ -75,120 +67,3 @@ class PatchDiscriminator(nn.ModelBase):
|
|||
return self.out_conv(x)
|
||||
|
||||
nn.PatchDiscriminator = PatchDiscriminator
|
||||
|
||||
class UNetPatchDiscriminator(nn.ModelBase):
|
||||
"""
|
||||
Inspired by https://arxiv.org/abs/2002.12655 "A U-Net Based Discriminator for Generative Adversarial Networks"
|
||||
"""
|
||||
def calc_receptive_field_size(self, layers):
|
||||
"""
|
||||
result the same as https://fomoro.com/research/article/receptive-field-calculatorindex.html
|
||||
"""
|
||||
rf = 0
|
||||
ts = 1
|
||||
for i, (k, s) in enumerate(layers):
|
||||
if i == 0:
|
||||
rf = k
|
||||
else:
|
||||
rf += (k-1)*ts
|
||||
ts *= s
|
||||
return rf
|
||||
|
||||
def find_archi(self, target_patch_size, max_layers=9):
|
||||
"""
|
||||
Find the best configuration of layers using only 3x3 convs for target patch size
|
||||
"""
|
||||
s = {}
|
||||
for layers_count in range(1,max_layers+1):
|
||||
val = 1 << (layers_count-1)
|
||||
while True:
|
||||
val -= 1
|
||||
|
||||
layers = []
|
||||
sum_st = 0
|
||||
layers.append ( [3, 2])
|
||||
sum_st += 2
|
||||
for i in range(layers_count-1):
|
||||
st = 1 + (1 if val & (1 << i) !=0 else 0 )
|
||||
layers.append ( [3, st ])
|
||||
sum_st += st
|
||||
|
||||
rf = self.calc_receptive_field_size(layers)
|
||||
|
||||
s_rf = s.get(rf, None)
|
||||
if s_rf is None:
|
||||
s[rf] = (layers_count, sum_st, layers)
|
||||
else:
|
||||
if layers_count < s_rf[0] or \
|
||||
( layers_count == s_rf[0] and sum_st > s_rf[1] ):
|
||||
s[rf] = (layers_count, sum_st, layers)
|
||||
|
||||
if val == 0:
|
||||
break
|
||||
|
||||
x = sorted(list(s.keys()))
|
||||
q=x[np.abs(np.array(x)-target_patch_size).argmin()]
|
||||
return s[q][2]
|
||||
|
||||
def on_build(self, patch_size, in_ch, base_ch = 16, use_fp16 = False):
|
||||
self.use_fp16 = use_fp16
|
||||
conv_dtype = tf.float16 if use_fp16 else tf.float32
|
||||
|
||||
class ResidualBlock(nn.ModelBase):
|
||||
def on_build(self, ch, kernel_size=3 ):
|
||||
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
|
||||
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
|
||||
|
||||
def forward(self, inp):
|
||||
x = self.conv1(inp)
|
||||
x = tf.nn.leaky_relu(x, 0.2)
|
||||
x = self.conv2(x)
|
||||
x = tf.nn.leaky_relu(inp + x, 0.2)
|
||||
return x
|
||||
|
||||
prev_ch = in_ch
|
||||
self.convs = []
|
||||
self.upconvs = []
|
||||
layers = self.find_archi(patch_size)
|
||||
|
||||
level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) }
|
||||
|
||||
self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID', dtype=conv_dtype)
|
||||
|
||||
for i, (kernel_size, strides) in enumerate(layers):
|
||||
self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )
|
||||
|
||||
self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )
|
||||
|
||||
self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
|
||||
|
||||
self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
|
||||
self.center_conv = nn.Conv2D( level_chs[len(layers)-1], level_chs[len(layers)-1], kernel_size=1, padding='VALID', dtype=conv_dtype)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_fp16:
|
||||
x = tf.cast(x, tf.float16)
|
||||
|
||||
x = tf.nn.leaky_relu( self.in_conv(x), 0.2 )
|
||||
|
||||
encs = []
|
||||
for conv in self.convs:
|
||||
encs.insert(0, x)
|
||||
x = tf.nn.leaky_relu( conv(x), 0.2 )
|
||||
|
||||
center_out, x = self.center_out(x), tf.nn.leaky_relu( self.center_conv(x), 0.2 )
|
||||
|
||||
for i, (upconv, enc) in enumerate(zip(self.upconvs, encs)):
|
||||
x = tf.nn.leaky_relu( upconv(x), 0.2 )
|
||||
x = tf.concat( [enc, x], axis=nn.conv2d_ch_axis)
|
||||
|
||||
x = self.out_conv(x)
|
||||
|
||||
if self.use_fp16:
|
||||
center_out = tf.cast(center_out, tf.float32)
|
||||
x = tf.cast(x, tf.float32)
|
||||
|
||||
return center_out, x
|
||||
|
||||
nn.UNetPatchDiscriminator = UNetPatchDiscriminator
|
||||
|
|
|
@ -29,11 +29,10 @@ class XSeg(nn.ModelBase):
|
|||
x = self.tlu(x)
|
||||
return x
|
||||
|
||||
self.base_ch = base_ch
|
||||
|
||||
self.conv01 = ConvBlock(in_ch, base_ch)
|
||||
self.conv02 = ConvBlock(base_ch, base_ch)
|
||||
self.bp0 = nn.BlurPool (filt_size=4)
|
||||
self.bp0 = nn.BlurPool (filt_size=3)
|
||||
|
||||
|
||||
self.conv11 = ConvBlock(base_ch, base_ch*2)
|
||||
self.conv12 = ConvBlock(base_ch*2, base_ch*2)
|
||||
|
@ -41,30 +40,19 @@ class XSeg(nn.ModelBase):
|
|||
|
||||
self.conv21 = ConvBlock(base_ch*2, base_ch*4)
|
||||
self.conv22 = ConvBlock(base_ch*4, base_ch*4)
|
||||
self.bp2 = nn.BlurPool (filt_size=2)
|
||||
self.conv23 = ConvBlock(base_ch*4, base_ch*4)
|
||||
self.bp2 = nn.BlurPool (filt_size=3)
|
||||
|
||||
|
||||
self.conv31 = ConvBlock(base_ch*4, base_ch*8)
|
||||
self.conv32 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv33 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.bp3 = nn.BlurPool (filt_size=2)
|
||||
self.bp3 = nn.BlurPool (filt_size=3)
|
||||
|
||||
self.conv41 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv42 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv43 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.bp4 = nn.BlurPool (filt_size=2)
|
||||
|
||||
self.conv51 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv52 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv53 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.bp5 = nn.BlurPool (filt_size=2)
|
||||
|
||||
self.dense1 = nn.Dense ( 4*4* base_ch*8, 512)
|
||||
self.dense2 = nn.Dense ( 512, 4*4* base_ch*8)
|
||||
|
||||
self.up5 = UpConvBlock (base_ch*8, base_ch*4)
|
||||
self.uconv53 = ConvBlock(base_ch*12, base_ch*8)
|
||||
self.uconv52 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.uconv51 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.bp4 = nn.BlurPool (filt_size=3)
|
||||
|
||||
self.up4 = UpConvBlock (base_ch*8, base_ch*4)
|
||||
self.uconv43 = ConvBlock(base_ch*12, base_ch*8)
|
||||
|
@ -77,7 +65,8 @@ class XSeg(nn.ModelBase):
|
|||
self.uconv31 = ConvBlock(base_ch*8, base_ch*8)
|
||||
|
||||
self.up2 = UpConvBlock (base_ch*8, base_ch*4)
|
||||
self.uconv22 = ConvBlock(base_ch*8, base_ch*4)
|
||||
self.uconv23 = ConvBlock(base_ch*8, base_ch*4)
|
||||
self.uconv22 = ConvBlock(base_ch*4, base_ch*4)
|
||||
self.uconv21 = ConvBlock(base_ch*4, base_ch*4)
|
||||
|
||||
self.up1 = UpConvBlock (base_ch*4, base_ch*2)
|
||||
|
@ -89,8 +78,9 @@ class XSeg(nn.ModelBase):
|
|||
self.uconv01 = ConvBlock(base_ch, base_ch)
|
||||
self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME')
|
||||
|
||||
self.conv_center = ConvBlock(base_ch*8, base_ch*8)
|
||||
|
||||
def forward(self, inp, pretrain=False):
|
||||
def forward(self, inp):
|
||||
x = inp
|
||||
|
||||
x = self.conv01(x)
|
||||
|
@ -102,7 +92,8 @@ class XSeg(nn.ModelBase):
|
|||
x = self.bp1(x)
|
||||
|
||||
x = self.conv21(x)
|
||||
x = x2 = self.conv22(x)
|
||||
x = self.conv22(x)
|
||||
x = x2 = self.conv23(x)
|
||||
x = self.bp2(x)
|
||||
|
||||
x = self.conv31(x)
|
||||
|
@ -115,52 +106,28 @@ class XSeg(nn.ModelBase):
|
|||
x = x4 = self.conv43(x)
|
||||
x = self.bp4(x)
|
||||
|
||||
x = self.conv51(x)
|
||||
x = self.conv52(x)
|
||||
x = x5 = self.conv53(x)
|
||||
x = self.bp5(x)
|
||||
|
||||
x = nn.flatten(x)
|
||||
x = self.dense1(x)
|
||||
x = self.dense2(x)
|
||||
x = nn.reshape_4D (x, 4, 4, self.base_ch*8 )
|
||||
|
||||
x = self.up5(x)
|
||||
if pretrain:
|
||||
x5 = tf.zeros_like(x5)
|
||||
x = self.uconv53(tf.concat([x,x5],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv52(x)
|
||||
x = self.uconv51(x)
|
||||
x = self.conv_center(x)
|
||||
|
||||
x = self.up4(x)
|
||||
if pretrain:
|
||||
x4 = tf.zeros_like(x4)
|
||||
x = self.uconv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv42(x)
|
||||
x = self.uconv41(x)
|
||||
|
||||
x = self.up3(x)
|
||||
if pretrain:
|
||||
x3 = tf.zeros_like(x3)
|
||||
x = self.uconv33(tf.concat([x,x3],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv32(x)
|
||||
x = self.uconv31(x)
|
||||
|
||||
x = self.up2(x)
|
||||
if pretrain:
|
||||
x2 = tf.zeros_like(x2)
|
||||
x = self.uconv22(tf.concat([x,x2],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv23(tf.concat([x,x2],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv22(x)
|
||||
x = self.uconv21(x)
|
||||
|
||||
x = self.up1(x)
|
||||
if pretrain:
|
||||
x1 = tf.zeros_like(x1)
|
||||
x = self.uconv12(tf.concat([x,x1],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv11(x)
|
||||
|
||||
x = self.up0(x)
|
||||
if pretrain:
|
||||
x0 = tf.zeros_like(x0)
|
||||
x = self.uconv02(tf.concat([x,x0],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv01(x)
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ class nn():
|
|||
tf = None
|
||||
tf_sess = None
|
||||
tf_sess_config = None
|
||||
tf_default_device_name = None
|
||||
tf_default_device = None
|
||||
|
||||
data_format = None
|
||||
conv2d_ch_axis = None
|
||||
|
@ -51,6 +51,9 @@ class nn():
|
|||
|
||||
# Manipulate environment variables before import tensorflow
|
||||
|
||||
if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
|
||||
os.environ.pop('CUDA_VISIBLE_DEVICES')
|
||||
|
||||
first_run = False
|
||||
if len(device_config.devices) != 0:
|
||||
if sys.platform[0:3] == 'win':
|
||||
|
@ -65,32 +68,21 @@ class nn():
|
|||
compute_cache_path = Path(os.environ['APPDATA']) / 'NVIDIA' / ('ComputeCache' + devices_str)
|
||||
if not compute_cache_path.exists():
|
||||
first_run = True
|
||||
compute_cache_path.mkdir(parents=True, exist_ok=True)
|
||||
os.environ['CUDA_CACHE_PATH'] = str(compute_cache_path)
|
||||
|
||||
os.environ['CUDA_CACHE_MAXSIZE'] = '536870912' #512Mb (32mb default)
|
||||
os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2'
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # tf log errors only
|
||||
|
||||
if first_run:
|
||||
io.log_info("Caching GPU kernels...")
|
||||
|
||||
import tensorflow
|
||||
|
||||
tf_version = tensorflow.version.VERSION
|
||||
#if tf_version is None:
|
||||
# tf_version = tensorflow.version.GIT_VERSION
|
||||
if tf_version[0] == 'v':
|
||||
tf_version = tf_version[1:]
|
||||
if tf_version[0] == '2':
|
||||
tf = tensorflow.compat.v1
|
||||
else:
|
||||
tf = tensorflow
|
||||
import tensorflow as tf
|
||||
nn.tf = tf
|
||||
|
||||
import logging
|
||||
# Disable tensorflow warnings
|
||||
tf_logger = logging.getLogger('tensorflow')
|
||||
tf_logger.setLevel(logging.ERROR)
|
||||
|
||||
if tf_version[0] == '2':
|
||||
tf.disable_v2_behavior()
|
||||
nn.tf = tf
|
||||
logging.getLogger('tensorflow').setLevel(logging.ERROR)
|
||||
|
||||
# Initialize framework
|
||||
import core.leras.ops
|
||||
|
@ -102,11 +94,10 @@ class nn():
|
|||
|
||||
# Configure tensorflow session-config
|
||||
if len(device_config.devices) == 0:
|
||||
nn.tf_default_device = "/CPU:0"
|
||||
config = tf.ConfigProto(device_count={'GPU': 0})
|
||||
nn.tf_default_device_name = '/CPU:0'
|
||||
else:
|
||||
nn.tf_default_device_name = f'/{device_config.devices[0].tf_dev_type}:0'
|
||||
|
||||
nn.tf_default_device = "/GPU:0"
|
||||
config = tf.ConfigProto()
|
||||
config.gpu_options.visible_device_list = ','.join([str(device.index) for device in device_config.devices])
|
||||
|
||||
|
@ -197,6 +188,14 @@ class nn():
|
|||
nn.tf_sess.close()
|
||||
nn.tf_sess = None
|
||||
|
||||
@staticmethod
|
||||
def get_current_device():
|
||||
# Undocumented access to last tf.device(...)
|
||||
objs = nn.tf.get_default_graph()._device_function_stack.peek_objs()
|
||||
if len(objs) != 0:
|
||||
return objs[0].display_name
|
||||
return nn.tf_default_device
|
||||
|
||||
@staticmethod
|
||||
def ask_choose_device_idxs(choose_only_one=False, allow_cpu=True, suggest_best_multi_gpu=False, suggest_all_gpu=False):
|
||||
devices = Devices.getDevices()
|
||||
|
|
|
@ -108,15 +108,10 @@ nn.gelu = gelu
|
|||
|
||||
def upsample2d(x, size=2):
|
||||
if nn.data_format == "NCHW":
|
||||
x = tf.transpose(x, (0,2,3,1))
|
||||
x = tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )
|
||||
x = tf.transpose(x, (0,3,1,2))
|
||||
|
||||
|
||||
# b,c,h,w = x.shape.as_list()
|
||||
# x = tf.reshape (x, (-1,c,h,1,w,1) )
|
||||
# x = tf.tile(x, (1,1,1,size,1,size) )
|
||||
# x = tf.reshape (x, (-1,c,h*size,w*size) )
|
||||
b,c,h,w = x.shape.as_list()
|
||||
x = tf.reshape (x, (-1,c,h,1,w,1) )
|
||||
x = tf.tile(x, (1,1,1,size,1,size) )
|
||||
x = tf.reshape (x, (-1,c,h*size,w*size) )
|
||||
return x
|
||||
else:
|
||||
return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )
|
||||
|
@ -142,39 +137,8 @@ def resize2d_bilinear(x, size=2):
|
|||
return x
|
||||
nn.resize2d_bilinear = resize2d_bilinear
|
||||
|
||||
def resize2d_nearest(x, size=2):
|
||||
if size in [-1,0,1]:
|
||||
return x
|
||||
|
||||
|
||||
if size > 0:
|
||||
raise Exception("")
|
||||
else:
|
||||
if nn.data_format == "NCHW":
|
||||
x = x[:,:,::-size,::-size]
|
||||
else:
|
||||
x = x[:,::-size,::-size,:]
|
||||
return x
|
||||
|
||||
h = x.shape[nn.conv2d_spatial_axes[0]].value
|
||||
w = x.shape[nn.conv2d_spatial_axes[1]].value
|
||||
|
||||
if nn.data_format == "NCHW":
|
||||
x = tf.transpose(x, (0,2,3,1))
|
||||
|
||||
if size > 0:
|
||||
new_size = (h*size,w*size)
|
||||
else:
|
||||
new_size = (h//-size,w//-size)
|
||||
|
||||
x = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
|
||||
|
||||
if nn.data_format == "NCHW":
|
||||
x = tf.transpose(x, (0,3,1,2))
|
||||
|
||||
return x
|
||||
nn.resize2d_nearest = resize2d_nearest
|
||||
|
||||
def flatten(x):
|
||||
if nn.data_format == "NHWC":
|
||||
# match NCHW version in order to switch data_format without problems
|
||||
|
@ -209,7 +173,7 @@ def random_binomial(shape, p=0.0, dtype=None, seed=None):
|
|||
seed = np.random.randint(10e6)
|
||||
return array_ops.where(
|
||||
random_ops.random_uniform(shape, dtype=tf.float16, seed=seed) < p,
|
||||
array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype))
|
||||
array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype))
|
||||
nn.random_binomial = random_binomial
|
||||
|
||||
def gaussian_blur(input, radius=2.0):
|
||||
|
@ -217,9 +181,7 @@ def gaussian_blur(input, radius=2.0):
|
|||
return np.exp(-(float(x) - float(mu)) ** 2 / (2 * sigma ** 2))
|
||||
|
||||
def make_kernel(sigma):
|
||||
kernel_size = max(3, int(2 * 2 * sigma))
|
||||
if kernel_size % 2 == 0:
|
||||
kernel_size += 1
|
||||
kernel_size = max(3, int(2 * 2 * sigma + 1))
|
||||
mean = np.floor(0.5 * kernel_size)
|
||||
kernel_1d = np.array([gaussian(x, mean, sigma) for x in range(kernel_size)])
|
||||
np_kernel = np.outer(kernel_1d, kernel_1d).astype(np.float32)
|
||||
|
@ -276,8 +238,6 @@ def dssim(img1,img2, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03
|
|||
img1 = tf.cast(img1, tf.float32)
|
||||
img2 = tf.cast(img2, tf.float32)
|
||||
|
||||
filter_size = max(1, filter_size)
|
||||
|
||||
kernel = np.arange(0, filter_size, dtype=np.float32)
|
||||
kernel -= (filter_size - 1 ) / 2.0
|
||||
kernel = kernel**2
|
||||
|
@ -340,17 +300,7 @@ def depth_to_space(x, size):
|
|||
x = tf.reshape(x, (-1, oh, ow, oc, ))
|
||||
return x
|
||||
else:
|
||||
cfg = nn.getCurrentDeviceConfig()
|
||||
if not cfg.cpu_only:
|
||||
return tf.depth_to_space(x, size, data_format=nn.data_format)
|
||||
b,c,h,w = x.shape.as_list()
|
||||
oh, ow = h * size, w * size
|
||||
oc = c // (size * size)
|
||||
|
||||
x = tf.reshape(x, (-1, size, size, oc, h, w, ) )
|
||||
x = tf.transpose(x, (0, 3, 4, 1, 5, 2))
|
||||
x = tf.reshape(x, (-1, oc, oh, ow))
|
||||
return x
|
||||
return tf.depth_to_space(x, size, data_format=nn.data_format)
|
||||
nn.depth_to_space = depth_to_space
|
||||
|
||||
def rgb_to_lab(srgb):
|
||||
|
@ -383,23 +333,6 @@ def rgb_to_lab(srgb):
|
|||
return tf.reshape(lab_pixels, tf.shape(srgb))
|
||||
nn.rgb_to_lab = rgb_to_lab
|
||||
|
||||
def total_variation_mse(images):
|
||||
"""
|
||||
Same as generic total_variation, but MSE diff instead of MAE
|
||||
"""
|
||||
pixel_dif1 = images[:, 1:, :, :] - images[:, :-1, :, :]
|
||||
pixel_dif2 = images[:, :, 1:, :] - images[:, :, :-1, :]
|
||||
|
||||
tot_var = ( tf.reduce_sum(tf.square(pixel_dif1), axis=[1,2,3]) +
|
||||
tf.reduce_sum(tf.square(pixel_dif2), axis=[1,2,3]) )
|
||||
return tot_var
|
||||
nn.total_variation_mse = total_variation_mse
|
||||
|
||||
|
||||
def pixel_norm(x, axes):
|
||||
return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=axes, keepdims=True) + 1e-06)
|
||||
nn.pixel_norm = pixel_norm
|
||||
|
||||
"""
|
||||
def tf_suppress_lower_mean(t, eps=0.00001):
|
||||
if t.shape.ndims != 1:
|
||||
|
@ -410,69 +343,3 @@ def tf_suppress_lower_mean(t, eps=0.00001):
|
|||
q = q * (t/eps)
|
||||
return q
|
||||
"""
|
||||
|
||||
|
||||
|
||||
def _get_pixel_value(img, x, y):
|
||||
shape = tf.shape(x)
|
||||
batch_size = shape[0]
|
||||
height = shape[1]
|
||||
width = shape[2]
|
||||
|
||||
batch_idx = tf.range(0, batch_size)
|
||||
batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
|
||||
b = tf.tile(batch_idx, (1, height, width))
|
||||
|
||||
indices = tf.stack([b, y, x], 3)
|
||||
|
||||
return tf.gather_nd(img, indices)
|
||||
|
||||
def bilinear_sampler(img, x, y):
|
||||
H = tf.shape(img)[1]
|
||||
W = tf.shape(img)[2]
|
||||
H_MAX = tf.cast(H - 1, tf.int32)
|
||||
W_MAX = tf.cast(W - 1, tf.int32)
|
||||
|
||||
# grab 4 nearest corner points for each (x_i, y_i)
|
||||
x0 = tf.cast(tf.floor(x), tf.int32)
|
||||
x1 = x0 + 1
|
||||
y0 = tf.cast(tf.floor(y), tf.int32)
|
||||
y1 = y0 + 1
|
||||
|
||||
# clip to range [0, H-1/W-1] to not violate img boundaries
|
||||
x0 = tf.clip_by_value(x0, 0, W_MAX)
|
||||
x1 = tf.clip_by_value(x1, 0, W_MAX)
|
||||
y0 = tf.clip_by_value(y0, 0, H_MAX)
|
||||
y1 = tf.clip_by_value(y1, 0, H_MAX)
|
||||
|
||||
# get pixel value at corner coords
|
||||
Ia = _get_pixel_value(img, x0, y0)
|
||||
Ib = _get_pixel_value(img, x0, y1)
|
||||
Ic = _get_pixel_value(img, x1, y0)
|
||||
Id = _get_pixel_value(img, x1, y1)
|
||||
|
||||
# recast as float for delta calculation
|
||||
x0 = tf.cast(x0, tf.float32)
|
||||
x1 = tf.cast(x1, tf.float32)
|
||||
y0 = tf.cast(y0, tf.float32)
|
||||
y1 = tf.cast(y1, tf.float32)
|
||||
|
||||
# calculate deltas
|
||||
wa = (x1-x) * (y1-y)
|
||||
wb = (x1-x) * (y-y0)
|
||||
wc = (x-x0) * (y1-y)
|
||||
wd = (x-x0) * (y-y0)
|
||||
|
||||
# add dimension for addition
|
||||
wa = tf.expand_dims(wa, axis=3)
|
||||
wb = tf.expand_dims(wb, axis=3)
|
||||
wc = tf.expand_dims(wc, axis=3)
|
||||
wd = tf.expand_dims(wd, axis=3)
|
||||
|
||||
# compute output
|
||||
out = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
|
||||
|
||||
return out
|
||||
|
||||
nn.bilinear_sampler = bilinear_sampler
|
||||
|
||||
|
|
|
@ -1,81 +0,0 @@
|
|||
import numpy as np
|
||||
from core.leras import nn
|
||||
from tensorflow.python.ops import control_flow_ops, state_ops
|
||||
|
||||
tf = nn.tf
|
||||
|
||||
class AdaBelief(nn.OptimizerBase):
|
||||
def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, lr_dropout=1.0, lr_cos=0, clipnorm=0.0, name=None, **kwargs):
|
||||
super().__init__(name=name)
|
||||
|
||||
if name is None:
|
||||
raise ValueError('name must be defined.')
|
||||
|
||||
self.lr = lr
|
||||
self.beta_1 = beta_1
|
||||
self.beta_2 = beta_2
|
||||
self.lr_dropout = lr_dropout
|
||||
self.lr_cos = lr_cos
|
||||
self.clipnorm = clipnorm
|
||||
|
||||
with tf.device('/CPU:0') :
|
||||
with tf.variable_scope(self.name):
|
||||
self.iterations = tf.Variable(0, dtype=tf.int64, name='iters')
|
||||
|
||||
self.ms_dict = {}
|
||||
self.vs_dict = {}
|
||||
self.lr_rnds_dict = {}
|
||||
|
||||
def get_weights(self):
|
||||
return [self.iterations] + list(self.ms_dict.values()) + list(self.vs_dict.values())
|
||||
|
||||
def initialize_variables(self, trainable_weights, vars_on_cpu=True, lr_dropout_on_cpu=False):
|
||||
# Initialize here all trainable variables used in training
|
||||
e = tf.device('/CPU:0') if vars_on_cpu else None
|
||||
if e: e.__enter__()
|
||||
with tf.variable_scope(self.name):
|
||||
ms = { v.name : tf.get_variable ( f'ms_{v.name}'.replace(':','_'), v.shape, dtype=v.dtype, initializer=tf.initializers.constant(0.0), trainable=False) for v in trainable_weights }
|
||||
vs = { v.name : tf.get_variable ( f'vs_{v.name}'.replace(':','_'), v.shape, dtype=v.dtype, initializer=tf.initializers.constant(0.0), trainable=False) for v in trainable_weights }
|
||||
self.ms_dict.update (ms)
|
||||
self.vs_dict.update (vs)
|
||||
|
||||
if self.lr_dropout != 1.0:
|
||||
e = tf.device('/CPU:0') if lr_dropout_on_cpu else None
|
||||
if e: e.__enter__()
|
||||
lr_rnds = [ nn.random_binomial( v.shape, p=self.lr_dropout, dtype=v.dtype) for v in trainable_weights ]
|
||||
if e: e.__exit__(None, None, None)
|
||||
self.lr_rnds_dict.update ( { v.name : rnd for v,rnd in zip(trainable_weights,lr_rnds) } )
|
||||
if e: e.__exit__(None, None, None)
|
||||
|
||||
def get_update_op(self, grads_vars):
|
||||
updates = []
|
||||
|
||||
if self.clipnorm > 0.0:
|
||||
norm = tf.sqrt( sum([tf.reduce_sum(tf.square(tf.cast(g, tf.float32))) for g,v in grads_vars]))
|
||||
updates += [ state_ops.assign_add( self.iterations, 1) ]
|
||||
for i, (g,v) in enumerate(grads_vars):
|
||||
if self.clipnorm > 0.0:
|
||||
g = self.tf_clip_norm(g, self.clipnorm, tf.cast(norm, g.dtype) )
|
||||
|
||||
ms = self.ms_dict[ v.name ]
|
||||
vs = self.vs_dict[ v.name ]
|
||||
|
||||
m_t = self.beta_1*ms + (1.0-self.beta_1) * g
|
||||
v_t = self.beta_2*vs + (1.0-self.beta_2) * tf.square(g-m_t)
|
||||
|
||||
lr = tf.constant(self.lr, g.dtype)
|
||||
if self.lr_cos != 0:
|
||||
lr *= (tf.cos( tf.cast(self.iterations, g.dtype) * (2*3.1415926535/ float(self.lr_cos) ) ) + 1.0) / 2.0
|
||||
|
||||
v_diff = - lr * m_t / (tf.sqrt(v_t) + np.finfo( g.dtype.as_numpy_dtype ).resolution )
|
||||
if self.lr_dropout != 1.0:
|
||||
lr_rnd = self.lr_rnds_dict[v.name]
|
||||
v_diff *= lr_rnd
|
||||
new_v = v + v_diff
|
||||
|
||||
updates.append (state_ops.assign(ms, m_t))
|
||||
updates.append (state_ops.assign(vs, v_t))
|
||||
updates.append (state_ops.assign(v, new_v))
|
||||
|
||||
return control_flow_ops.group ( *updates, name=self.name+'_updates')
|
||||
nn.AdaBelief = AdaBelief
|
|
@ -1,33 +1,31 @@
|
|||
import numpy as np
|
||||
from tensorflow.python.ops import control_flow_ops, state_ops
|
||||
from core.leras import nn
|
||||
tf = nn.tf
|
||||
|
||||
class RMSprop(nn.OptimizerBase):
|
||||
def __init__(self, lr=0.001, rho=0.9, lr_dropout=1.0, lr_cos=0, clipnorm=0.0, name=None, **kwargs):
|
||||
def __init__(self, lr=0.001, rho=0.9, lr_dropout=1.0, epsilon=1e-7, clipnorm=0.0, name=None):
|
||||
super().__init__(name=name)
|
||||
|
||||
if name is None:
|
||||
raise ValueError('name must be defined.')
|
||||
|
||||
self.lr_dropout = lr_dropout
|
||||
self.lr_cos = lr_cos
|
||||
self.lr = lr
|
||||
self.rho = rho
|
||||
self.clipnorm = clipnorm
|
||||
|
||||
with tf.device('/CPU:0') :
|
||||
with tf.variable_scope(self.name):
|
||||
|
||||
self.lr = tf.Variable (lr, name="lr")
|
||||
self.rho = tf.Variable (rho, name="rho")
|
||||
self.epsilon = tf.Variable (epsilon, name="epsilon")
|
||||
self.iterations = tf.Variable(0, dtype=tf.int64, name='iters')
|
||||
|
||||
self.accumulators_dict = {}
|
||||
self.lr_rnds_dict = {}
|
||||
|
||||
def get_weights(self):
|
||||
return [self.iterations] + list(self.accumulators_dict.values())
|
||||
return [self.lr, self.rho, self.epsilon, self.iterations] + list(self.accumulators_dict.values())
|
||||
|
||||
def initialize_variables(self, trainable_weights, vars_on_cpu=True, lr_dropout_on_cpu=False):
|
||||
def initialize_variables(self, trainable_weights, vars_on_cpu=True):
|
||||
# Initialize here all trainable variables used in training
|
||||
e = tf.device('/CPU:0') if vars_on_cpu else None
|
||||
if e: e.__enter__()
|
||||
|
@ -36,10 +34,7 @@ class RMSprop(nn.OptimizerBase):
|
|||
self.accumulators_dict.update ( accumulators)
|
||||
|
||||
if self.lr_dropout != 1.0:
|
||||
e = tf.device('/CPU:0') if lr_dropout_on_cpu else None
|
||||
if e: e.__enter__()
|
||||
lr_rnds = [ nn.random_binomial( v.shape, p=self.lr_dropout, dtype=v.dtype) for v in trainable_weights ]
|
||||
if e: e.__exit__(None, None, None)
|
||||
self.lr_rnds_dict.update ( { v.name : rnd for v,rnd in zip(trainable_weights,lr_rnds) } )
|
||||
if e: e.__exit__(None, None, None)
|
||||
|
||||
|
@ -47,21 +42,21 @@ class RMSprop(nn.OptimizerBase):
|
|||
updates = []
|
||||
|
||||
if self.clipnorm > 0.0:
|
||||
norm = tf.sqrt( sum([tf.reduce_sum(tf.square(tf.cast(g, tf.float32))) for g,v in grads_vars]))
|
||||
norm = tf.sqrt( sum([tf.reduce_sum(tf.square(g)) for g,v in grads_vars]))
|
||||
updates += [ state_ops.assign_add( self.iterations, 1) ]
|
||||
for i, (g,v) in enumerate(grads_vars):
|
||||
if self.clipnorm > 0.0:
|
||||
g = self.tf_clip_norm(g, self.clipnorm, tf.cast(norm, g.dtype) )
|
||||
g = self.tf_clip_norm(g, self.clipnorm, norm)
|
||||
|
||||
a = self.accumulators_dict[ v.name ]
|
||||
|
||||
new_a = self.rho * a + (1. - self.rho) * tf.square(g)
|
||||
rho = tf.cast(self.rho, a.dtype)
|
||||
new_a = rho * a + (1. - rho) * tf.square(g)
|
||||
|
||||
lr = tf.constant(self.lr, g.dtype)
|
||||
if self.lr_cos != 0:
|
||||
lr *= (tf.cos( tf.cast(self.iterations, g.dtype) * (2*3.1415926535/ float(self.lr_cos) ) ) + 1.0) / 2.0
|
||||
lr = tf.cast(self.lr, a.dtype)
|
||||
epsilon = tf.cast(self.epsilon, a.dtype)
|
||||
|
||||
v_diff = - lr * g / (tf.sqrt(new_a) + np.finfo( g.dtype.as_numpy_dtype ).resolution )
|
||||
v_diff = - lr * g / (tf.sqrt(new_a) + epsilon)
|
||||
if self.lr_dropout != 1.0:
|
||||
lr_rnd = self.lr_rnds_dict[v.name]
|
||||
v_diff *= lr_rnd
|
||||
|
|
|
@ -1,3 +1,2 @@
|
|||
from .OptimizerBase import *
|
||||
from .RMSprop import *
|
||||
from .AdaBelief import *
|
|
@ -1,12 +1,7 @@
|
|||
import math
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import numpy.linalg as npla
|
||||
|
||||
import math
|
||||
from .umeyama import umeyama
|
||||
|
||||
|
||||
def get_power_of_two(x):
|
||||
i = 0
|
||||
while (1 << i) < x:
|
||||
|
@ -28,70 +23,3 @@ def rotationMatrixToEulerAngles(R) :
|
|||
|
||||
def polygon_area(x,y):
|
||||
return 0.5*np.abs(np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1)))
|
||||
|
||||
def rotate_point(origin, point, deg):
|
||||
"""
|
||||
Rotate a point counterclockwise by a given angle around a given origin.
|
||||
|
||||
The angle should be given in radians.
|
||||
"""
|
||||
ox, oy = origin
|
||||
px, py = point
|
||||
|
||||
rad = deg * math.pi / 180.0
|
||||
qx = ox + math.cos(rad) * (px - ox) - math.sin(rad) * (py - oy)
|
||||
qy = oy + math.sin(rad) * (px - ox) + math.cos(rad) * (py - oy)
|
||||
return np.float32([qx, qy])
|
||||
|
||||
def transform_points(points, mat, invert=False):
|
||||
if invert:
|
||||
mat = cv2.invertAffineTransform (mat)
|
||||
points = np.expand_dims(points, axis=1)
|
||||
points = cv2.transform(points, mat, points.shape)
|
||||
points = np.squeeze(points)
|
||||
return points
|
||||
|
||||
|
||||
def transform_mat(mat, res, tx, ty, rotation, scale):
|
||||
"""
|
||||
transform mat in local space of res
|
||||
scale -> translate -> rotate
|
||||
|
||||
tx,ty float
|
||||
rotation int degrees
|
||||
scale float
|
||||
"""
|
||||
|
||||
|
||||
lt, rt, lb, ct = transform_points ( np.float32([(0,0),(res,0),(0,res),(res / 2, res/2) ]),mat, True)
|
||||
|
||||
hor_v = (rt-lt).astype(np.float32)
|
||||
hor_size = npla.norm(hor_v)
|
||||
hor_v /= hor_size
|
||||
|
||||
ver_v = (lb-lt).astype(np.float32)
|
||||
ver_size = npla.norm(ver_v)
|
||||
ver_v /= ver_size
|
||||
|
||||
bt_diag_vec = (rt-ct).astype(np.float32)
|
||||
half_diag_len = npla.norm(bt_diag_vec)
|
||||
bt_diag_vec /= half_diag_len
|
||||
|
||||
tb_diag_vec = np.float32( [ -bt_diag_vec[1], bt_diag_vec[0] ] )
|
||||
|
||||
rt = ct + bt_diag_vec*half_diag_len*scale
|
||||
lb = ct - bt_diag_vec*half_diag_len*scale
|
||||
lt = ct - tb_diag_vec*half_diag_len*scale
|
||||
|
||||
rt[0] += tx*hor_size
|
||||
lb[0] += tx*hor_size
|
||||
lt[0] += tx*hor_size
|
||||
rt[1] += ty*ver_size
|
||||
lb[1] += ty*ver_size
|
||||
lt[1] += ty*ver_size
|
||||
|
||||
rt = rotate_point(ct, rt, rotation)
|
||||
lb = rotate_point(ct, lb, rotation)
|
||||
lt = rotate_point(ct, lt, rotation)
|
||||
|
||||
return cv2.getAffineTransform( np.float32([lt, rt, lb]), np.float32([ [0,0], [res,0], [0,res] ]) )
|
||||
|
|
|
@ -60,11 +60,9 @@ class MPSharedList():
|
|||
break
|
||||
key -= self.obj_counts[i]
|
||||
|
||||
sh_b = memoryview(sh_b).cast('B')
|
||||
offset_start, offset_end = struct.unpack('<QQ', bytes(sh_b[ table_offset + key*8 : table_offset + (key+2)*8]) )
|
||||
|
||||
offset_start, offset_end = struct.unpack('<QQ', sh_b[ table_offset + key*8 : table_offset + (key+2)*8].tobytes() )
|
||||
|
||||
return pickle.loads( sh_b[ data_offset + offset_start : data_offset + offset_end ].tobytes() )
|
||||
return pickle.loads( bytes(sh_b[ data_offset + offset_start : data_offset + offset_end ]) )
|
||||
|
||||
def __iter__(self):
|
||||
for i in range(self.__len__()):
|
||||
|
@ -86,8 +84,7 @@ class MPSharedList():
|
|||
data_size = sum([len(x) for x in obj_pickled_ar])
|
||||
|
||||
sh_b = multiprocessing.RawArray('B', table_size + data_size)
|
||||
#sh_b[0:8] = struct.pack('<Q', obj_count)
|
||||
sh_b_view = memoryview(sh_b).cast('B')
|
||||
sh_b[0:8] = struct.pack('<Q', obj_count)
|
||||
|
||||
offset = 0
|
||||
|
||||
|
@ -100,12 +97,51 @@ class MPSharedList():
|
|||
offset += len(obj_pickled_ar[i])
|
||||
offsets.append(offset)
|
||||
|
||||
sh_b_view[table_offset:table_offset+table_size] = struct.pack( '<'+'Q'*len(offsets), *offsets )
|
||||
sh_b[table_offset:table_offset+table_size] = struct.pack( '<'+'Q'*len(offsets), *offsets )
|
||||
|
||||
for i, obj_pickled in enumerate(obj_pickled_ar):
|
||||
offset = data_offset+offsets[i]
|
||||
sh_b_view[offset:offset+len(obj_pickled)] = obj_pickled_ar[i]
|
||||
ArrayFillerSubprocessor(sh_b, [ (data_offset+offsets[i], obj_pickled_ar[i] ) for i in range(obj_count) ] ).run()
|
||||
|
||||
return obj_count, table_offset, data_offset, sh_b
|
||||
return 0, 0, 0, None
|
||||
|
||||
|
||||
|
||||
class ArrayFillerSubprocessor(Subprocessor):
|
||||
"""
|
||||
Much faster to fill shared memory via subprocesses rather than direct whole bytes fill.
|
||||
"""
|
||||
#override
|
||||
def __init__(self, sh_b, data_list ):
|
||||
self.sh_b = sh_b
|
||||
self.data_list = data_list
|
||||
super().__init__('ArrayFillerSubprocessor', ArrayFillerSubprocessor.Cli, 60)
|
||||
|
||||
#override
|
||||
def process_info_generator(self):
|
||||
for i in range(min(multiprocessing.cpu_count(), 8)):
|
||||
yield 'CPU%d' % (i), {}, {'sh_b':self.sh_b}
|
||||
|
||||
#override
|
||||
def get_data(self, host_dict):
|
||||
if len(self.data_list) > 0:
|
||||
return self.data_list.pop(0)
|
||||
|
||||
return None
|
||||
|
||||
#override
|
||||
def on_data_return (self, host_dict, data):
|
||||
self.data_list.insert(0, data)
|
||||
|
||||
#override
|
||||
def on_result (self, host_dict, data, result):
|
||||
pass
|
||||
|
||||
class Cli(Subprocessor.Cli):
|
||||
#overridable optional
|
||||
def on_initialize(self, client_dict):
|
||||
self.sh_b = client_dict['sh_b']
|
||||
|
||||
def process_data(self, data):
|
||||
offset, b = data
|
||||
self.sh_b[offset:offset+len(b)]=b
|
||||
return 0
|
||||
|
|
|
@ -5,6 +5,96 @@ import time
|
|||
|
||||
import numpy as np
|
||||
|
||||
class Index2DHost():
|
||||
"""
|
||||
Provides random shuffled 2D indexes for multiprocesses
|
||||
"""
|
||||
def __init__(self, indexes2D):
|
||||
self.sq = multiprocessing.Queue()
|
||||
self.cqs = []
|
||||
self.clis = []
|
||||
self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) )
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def host_thread(self, indexes2D):
|
||||
indexes_counts_len = len(indexes2D)
|
||||
|
||||
idxs = [*range(indexes_counts_len)]
|
||||
idxs_2D = [None]*indexes_counts_len
|
||||
shuffle_idxs = []
|
||||
shuffle_idxs_2D = [None]*indexes_counts_len
|
||||
for i in range(indexes_counts_len):
|
||||
idxs_2D[i] = indexes2D[i]
|
||||
shuffle_idxs_2D[i] = []
|
||||
|
||||
sq = self.sq
|
||||
|
||||
while True:
|
||||
while not sq.empty():
|
||||
obj = sq.get()
|
||||
cq_id, cmd = obj[0], obj[1]
|
||||
|
||||
if cmd == 0: #get_1D
|
||||
count = obj[2]
|
||||
|
||||
result = []
|
||||
for i in range(count):
|
||||
if len(shuffle_idxs) == 0:
|
||||
shuffle_idxs = idxs.copy()
|
||||
np.random.shuffle(shuffle_idxs)
|
||||
result.append(shuffle_idxs.pop())
|
||||
self.cqs[cq_id].put (result)
|
||||
elif cmd == 1: #get_2D
|
||||
targ_idxs,count = obj[2], obj[3]
|
||||
result = []
|
||||
|
||||
for targ_idx in targ_idxs:
|
||||
sub_idxs = []
|
||||
for i in range(count):
|
||||
ar = shuffle_idxs_2D[targ_idx]
|
||||
if len(ar) == 0:
|
||||
ar = shuffle_idxs_2D[targ_idx] = idxs_2D[targ_idx].copy()
|
||||
np.random.shuffle(ar)
|
||||
sub_idxs.append(ar.pop())
|
||||
result.append (sub_idxs)
|
||||
self.cqs[cq_id].put (result)
|
||||
|
||||
time.sleep(0.005)
|
||||
|
||||
def create_cli(self):
|
||||
cq = multiprocessing.Queue()
|
||||
self.cqs.append ( cq )
|
||||
cq_id = len(self.cqs)-1
|
||||
return Index2DHost.Cli(self.sq, cq, cq_id)
|
||||
|
||||
# disable pickling
|
||||
def __getstate__(self):
|
||||
return dict()
|
||||
def __setstate__(self, d):
|
||||
self.__dict__.update(d)
|
||||
|
||||
class Cli():
|
||||
def __init__(self, sq, cq, cq_id):
|
||||
self.sq = sq
|
||||
self.cq = cq
|
||||
self.cq_id = cq_id
|
||||
|
||||
def get_1D(self, count):
|
||||
self.sq.put ( (self.cq_id,0, count) )
|
||||
|
||||
while True:
|
||||
if not self.cq.empty():
|
||||
return self.cq.get()
|
||||
time.sleep(0.001)
|
||||
|
||||
def get_2D(self, idxs, count):
|
||||
self.sq.put ( (self.cq_id,1,idxs,count) )
|
||||
|
||||
while True:
|
||||
if not self.cq.empty():
|
||||
return self.cq.get()
|
||||
time.sleep(0.001)
|
||||
|
||||
class IndexHost():
|
||||
"""
|
||||
|
@ -66,95 +156,6 @@ class IndexHost():
|
|||
return self.cq.get()
|
||||
time.sleep(0.001)
|
||||
|
||||
class Index2DHost():
|
||||
"""
|
||||
Provides random shuffled indexes for multiprocesses
|
||||
"""
|
||||
def __init__(self, indexes2D):
|
||||
self.sq = multiprocessing.Queue()
|
||||
self.cqs = []
|
||||
self.clis = []
|
||||
self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) )
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def host_thread(self, indexes2D):
|
||||
indexes2D_len = len(indexes2D)
|
||||
|
||||
idxs = [*range(indexes2D_len)]
|
||||
idxs_2D = [None]*indexes2D_len
|
||||
shuffle_idxs = []
|
||||
shuffle_idxs_2D = [None]*indexes2D_len
|
||||
for i in range(indexes2D_len):
|
||||
idxs_2D[i] = [*range(len(indexes2D[i]))]
|
||||
shuffle_idxs_2D[i] = []
|
||||
|
||||
#print(idxs)
|
||||
#print(idxs_2D)
|
||||
sq = self.sq
|
||||
|
||||
while True:
|
||||
while not sq.empty():
|
||||
obj = sq.get()
|
||||
cq_id, count = obj[0], obj[1]
|
||||
|
||||
result = []
|
||||
for i in range(count):
|
||||
if len(shuffle_idxs) == 0:
|
||||
shuffle_idxs = idxs.copy()
|
||||
np.random.shuffle(shuffle_idxs)
|
||||
|
||||
idx_1D = shuffle_idxs.pop()
|
||||
|
||||
#print(f'idx_1D = {idx_1D}, len(shuffle_idxs_2D[idx_1D])= {len(shuffle_idxs_2D[idx_1D])}')
|
||||
|
||||
if len(shuffle_idxs_2D[idx_1D]) == 0:
|
||||
shuffle_idxs_2D[idx_1D] = idxs_2D[idx_1D].copy()
|
||||
#print(f'new shuffle_idxs_2d for {idx_1D} = { shuffle_idxs_2D[idx_1D] }')
|
||||
|
||||
#print(f'len(shuffle_idxs_2D[idx_1D])= {len(shuffle_idxs_2D[idx_1D])}')
|
||||
|
||||
np.random.shuffle( shuffle_idxs_2D[idx_1D] )
|
||||
|
||||
idx_2D = shuffle_idxs_2D[idx_1D].pop()
|
||||
|
||||
#print(f'len(shuffle_idxs_2D[idx_1D])= {len(shuffle_idxs_2D[idx_1D])}')
|
||||
|
||||
#print(f'idx_2D = {idx_2D}')
|
||||
|
||||
|
||||
result.append( indexes2D[idx_1D][idx_2D])
|
||||
|
||||
self.cqs[cq_id].put (result)
|
||||
|
||||
time.sleep(0.001)
|
||||
|
||||
def create_cli(self):
|
||||
cq = multiprocessing.Queue()
|
||||
self.cqs.append ( cq )
|
||||
cq_id = len(self.cqs)-1
|
||||
return Index2DHost.Cli(self.sq, cq, cq_id)
|
||||
|
||||
# disable pickling
|
||||
def __getstate__(self):
|
||||
return dict()
|
||||
def __setstate__(self, d):
|
||||
self.__dict__.update(d)
|
||||
|
||||
class Cli():
|
||||
def __init__(self, sq, cq, cq_id):
|
||||
self.sq = sq
|
||||
self.cq = cq
|
||||
self.cq_id = cq_id
|
||||
|
||||
def multi_get(self, count):
|
||||
self.sq.put ( (self.cq_id,count) )
|
||||
|
||||
while True:
|
||||
if not self.cq.empty():
|
||||
return self.cq.get()
|
||||
time.sleep(0.001)
|
||||
|
||||
class ListHost():
|
||||
def __init__(self, list_):
|
||||
self.sq = multiprocessing.Queue()
|
||||
|
|
|
@ -38,8 +38,8 @@ def QImage_from_np(img):
|
|||
|
||||
return QImage(img.data, w, h, c*w, fmt )
|
||||
|
||||
def QImage_to_np(q_img, fmt=QImage.Format_BGR888):
|
||||
q_img = q_img.convertToFormat(fmt)
|
||||
def QImage_to_np(q_img):
|
||||
q_img = q_img.convertToFormat(QImage.Format_BGR888)
|
||||
|
||||
width = q_img.width()
|
||||
height = q_img.height()
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
import numpy as np
|
||||
|
||||
def random_normal( size=(1,), trunc_val = 2.5, rnd_state=None ):
|
||||
if rnd_state is None:
|
||||
rnd_state = np.random
|
||||
def random_normal( size=(1,), trunc_val = 2.5 ):
|
||||
len = np.array(size).prod()
|
||||
result = np.empty ( (len,) , dtype=np.float32)
|
||||
|
||||
for i in range (len):
|
||||
while True:
|
||||
x = rnd_state.normal()
|
||||
x = np.random.normal()
|
||||
if x >= -trunc_val and x <= trunc_val:
|
||||
break
|
||||
result[i] = (x / trunc_val)
|
||||
|
|
BIN
doc/Alipay_donation.jpg
Normal file
After Width: | Height: | Size: 63 KiB |
Before Width: | Height: | Size: 482 KiB After Width: | Height: | Size: 544 KiB |
Before Width: | Height: | Size: 74 KiB |
Before Width: | Height: | Size: 68 KiB |
Before Width: | Height: | Size: 1 MiB After Width: | Height: | Size: 313 KiB |
Before Width: | Height: | Size: 71 KiB After Width: | Height: | Size: 71 KiB |
Before Width: | Height: | Size: 122 KiB |
Before Width: | Height: | Size: 123 KiB |
Before Width: | Height: | Size: 67 KiB After Width: | Height: | Size: 67 KiB |
Before Width: | Height: | Size: 98 KiB |
Before Width: | Height: | Size: 97 KiB |
Before Width: | Height: | Size: 25 KiB |
BIN
doc/meme2.jpg
Before Width: | Height: | Size: 208 KiB After Width: | Height: | Size: 178 KiB |
BIN
doc/meme3.jpg
Before Width: | Height: | Size: 310 KiB |
Before Width: | Height: | Size: 273 KiB |
BIN
doc/political_speech.jpg
Normal file
After Width: | Height: | Size: 548 KiB |
Before Width: | Height: | Size: 247 KiB |
Before Width: | Height: | Size: 349 KiB |
Before Width: | Height: | Size: 662 KiB |
Before Width: | Height: | Size: 378 KiB |
BIN
doc/replace_the_face.png
Normal file
After Width: | Height: | Size: 1,004 KiB |
Before Width: | Height: | Size: 268 B |
|
@ -28,13 +28,13 @@ class FANExtractor(object):
|
|||
self.out_planes = out_planes
|
||||
|
||||
self.bn1 = nn.BatchNorm2D(in_planes)
|
||||
self.conv1 = nn.Conv2D (in_planes, out_planes//2, kernel_size=3, strides=1, padding='SAME', use_bias=False )
|
||||
self.conv1 = nn.Conv2D (in_planes, out_planes/2, kernel_size=3, strides=1, padding='SAME', use_bias=False )
|
||||
|
||||
self.bn2 = nn.BatchNorm2D(out_planes//2)
|
||||
self.conv2 = nn.Conv2D (out_planes//2, out_planes//4, kernel_size=3, strides=1, padding='SAME', use_bias=False )
|
||||
self.conv2 = nn.Conv2D (out_planes/2, out_planes/4, kernel_size=3, strides=1, padding='SAME', use_bias=False )
|
||||
|
||||
self.bn3 = nn.BatchNorm2D(out_planes//4)
|
||||
self.conv3 = nn.Conv2D (out_planes//4, out_planes//4, kernel_size=3, strides=1, padding='SAME', use_bias=False )
|
||||
self.conv3 = nn.Conv2D (out_planes/4, out_planes/4, kernel_size=3, strides=1, padding='SAME', use_bias=False )
|
||||
|
||||
if self.in_planes != self.out_planes:
|
||||
self.down_bn1 = nn.BatchNorm2D(in_planes)
|
||||
|
|
|
@ -161,11 +161,11 @@ class FaceEnhancer(object):
|
|||
if not model_path.exists():
|
||||
raise Exception("Unable to load FaceEnhancer.npy")
|
||||
|
||||
with tf.device ('/CPU:0' if place_model_on_cpu else nn.tf_default_device_name):
|
||||
with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'):
|
||||
self.model = FaceEnhancer()
|
||||
self.model.load_weights (model_path)
|
||||
|
||||
with tf.device ('/CPU:0' if run_on_cpu else nn.tf_default_device_name):
|
||||
with tf.device ('/CPU:0' if run_on_cpu else '/GPU:0'):
|
||||
self.model.build_for_run ([ (tf.float32, nn.get4Dshape (192,192,3) ),
|
||||
(tf.float32, (None,1,) ),
|
||||
(tf.float32, (None,1,) ),
|
||||
|
@ -248,7 +248,7 @@ class FaceEnhancer(object):
|
|||
final_img = final_img [t_padding*up_res:(h-b_padding)*up_res, l_padding*up_res:(w-r_padding)*up_res,:]
|
||||
|
||||
if preserve_size:
|
||||
final_img = cv2.resize (final_img, (iw,ih), interpolation=cv2.INTER_LANCZOS4)
|
||||
final_img = cv2.resize (final_img, (iw,ih), cv2.INTER_LANCZOS4)
|
||||
|
||||
if not is_tanh:
|
||||
final_img = np.clip( final_img/2+0.5, 0, 1 )
|
||||
|
@ -278,7 +278,7 @@ class FaceEnhancer(object):
|
|||
preupscale_rate = 1.0 / ( max(h,w) / patch_size )
|
||||
|
||||
if preupscale_rate != 1.0:
|
||||
inp_img = cv2.resize (inp_img, ( int(w*preupscale_rate), int(h*preupscale_rate) ), interpolation=cv2.INTER_LANCZOS4)
|
||||
inp_img = cv2.resize (inp_img, ( int(w*preupscale_rate), int(h*preupscale_rate) ), cv2.INTER_LANCZOS4)
|
||||
h,w,c = inp_img.shape
|
||||
|
||||
i_max = w-patch_size+1
|
||||
|
@ -310,10 +310,10 @@ class FaceEnhancer(object):
|
|||
final_img /= final_img_div
|
||||
|
||||
if preserve_size:
|
||||
final_img = cv2.resize (final_img, (w,h), interpolation=cv2.INTER_LANCZOS4)
|
||||
final_img = cv2.resize (final_img, (w,h), cv2.INTER_LANCZOS4)
|
||||
else:
|
||||
if preupscale_rate != 1.0:
|
||||
final_img = cv2.resize (final_img, (tw,th), interpolation=cv2.INTER_LANCZOS4)
|
||||
final_img = cv2.resize (final_img, (tw,th), cv2.INTER_LANCZOS4)
|
||||
|
||||
if not is_tanh:
|
||||
final_img = np.clip( final_img/2+0.5, 0, 1 )
|
||||
|
|
|
@ -302,6 +302,8 @@ def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0):
|
|||
g_c += vec*vec_len*0.07
|
||||
|
||||
elif face_type == FaceType.HEAD:
|
||||
mat = umeyama( np.concatenate ( [ image_landmarks[17:49] , image_landmarks[54:55] ] ) , landmarks_2D_new, True)[0:2]
|
||||
|
||||
# assuming image_landmarks are 3D_Landmarks extracted for HEAD,
|
||||
# adjust horizontal offset according to estimated yaw
|
||||
yaw = estimate_averaged_yaw(transform_points (image_landmarks, mat, False))
|
||||
|
@ -431,27 +433,6 @@ def get_image_eye_mask (image_shape, image_landmarks):
|
|||
|
||||
return hull_mask
|
||||
|
||||
def get_image_mouth_mask (image_shape, image_landmarks):
|
||||
if len(image_landmarks) != 68:
|
||||
raise Exception('get_image_eye_mask works only with 68 landmarks')
|
||||
|
||||
h,w,c = image_shape
|
||||
|
||||
hull_mask = np.zeros( (h,w,1),dtype=np.float32)
|
||||
|
||||
image_landmarks = image_landmarks.astype(np.int)
|
||||
|
||||
cv2.fillConvexPoly( hull_mask, cv2.convexHull( image_landmarks[60:]), (1,) )
|
||||
|
||||
dilate = h // 32
|
||||
hull_mask = cv2.dilate(hull_mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(dilate,dilate)), iterations = 1 )
|
||||
|
||||
blur = h // 16
|
||||
blur = blur + (1-blur % 2)
|
||||
hull_mask = cv2.GaussianBlur(hull_mask, (blur, blur) , 0)
|
||||
hull_mask = hull_mask[...,None]
|
||||
|
||||
return hull_mask
|
||||
|
||||
def alpha_to_color (img_alpha, color):
|
||||
if len(img_alpha.shape) == 2:
|
||||
|
|
|
@ -30,37 +30,35 @@ class XSegNet(object):
|
|||
nn.initialize(data_format=data_format)
|
||||
tf = nn.tf
|
||||
|
||||
model_name = f'{name}_{resolution}'
|
||||
self.model_filename_list = []
|
||||
|
||||
with tf.device ('/CPU:0'):
|
||||
#Place holders on CPU
|
||||
self.input_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,3) )
|
||||
self.target_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,1) )
|
||||
|
||||
# Initializing model classes
|
||||
with tf.device ('/CPU:0' if place_model_on_cpu else nn.tf_default_device_name):
|
||||
with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'):
|
||||
self.model = nn.XSeg(3, 32, 1, name=name)
|
||||
self.model_weights = self.model.get_weights()
|
||||
if training:
|
||||
if optimizer is None:
|
||||
raise ValueError("Optimizer should be provided for training mode.")
|
||||
self.opt = optimizer
|
||||
self.opt.initialize_variables (self.model_weights, vars_on_cpu=place_model_on_cpu)
|
||||
self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]
|
||||
|
||||
model_name = f'{name}_{resolution}'
|
||||
|
||||
self.model_filename_list += [ [self.model, f'{model_name}.npy'] ]
|
||||
self.model_filename_list = [ [self.model, f'{model_name}.npy'] ]
|
||||
|
||||
if not training:
|
||||
with tf.device ('/CPU:0' if run_on_cpu else nn.tf_default_device_name):
|
||||
if training:
|
||||
if optimizer is None:
|
||||
raise ValueError("Optimizer should be provided for training mode.")
|
||||
|
||||
self.opt = optimizer
|
||||
self.opt.initialize_variables (self.model_weights, vars_on_cpu=place_model_on_cpu)
|
||||
self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]
|
||||
else:
|
||||
with tf.device ('/CPU:0' if run_on_cpu else '/GPU:0'):
|
||||
_, pred = self.model(self.input_t)
|
||||
|
||||
def net_run(input_np):
|
||||
return nn.tf_sess.run ( [pred], feed_dict={self.input_t :input_np})[0]
|
||||
self.net_run = net_run
|
||||
|
||||
self.initialized = True
|
||||
# Loading/initializing all models/optimizers weights
|
||||
for model, filename in self.model_filename_list:
|
||||
do_init = not load_weights
|
||||
|
@ -68,12 +66,8 @@ class XSegNet(object):
|
|||
if not do_init:
|
||||
model_file_path = self.weights_file_root / filename
|
||||
do_init = not model.load_weights( model_file_path )
|
||||
if do_init:
|
||||
if raise_on_no_model_files:
|
||||
raise Exception(f'{model_file_path} does not exists.')
|
||||
if not training:
|
||||
self.initialized = False
|
||||
break
|
||||
if do_init and raise_on_no_model_files:
|
||||
raise Exception(f'{model_file_path} does not exists.')
|
||||
|
||||
if do_init:
|
||||
model.init_weights()
|
||||
|
@ -81,8 +75,8 @@ class XSegNet(object):
|
|||
def get_resolution(self):
|
||||
return self.resolution
|
||||
|
||||
def flow(self, x, pretrain=False):
|
||||
return self.model(x, pretrain=pretrain)
|
||||
def flow(self, x):
|
||||
return self.model(x)
|
||||
|
||||
def get_weights(self):
|
||||
return self.model_weights
|
||||
|
@ -92,9 +86,6 @@ class XSegNet(object):
|
|||
model.save_weights( self.weights_file_root / filename )
|
||||
|
||||
def extract (self, input_image):
|
||||
if not self.initialized:
|
||||
return 0.5*np.ones ( (self.resolution, self.resolution, 1), nn.floatx.as_numpy_dtype )
|
||||
|
||||
input_shape_len = len(input_image.shape)
|
||||
if input_shape_len == 3:
|
||||
input_image = input_image[None,...]
|
||||
|
|
47
main.py
|
@ -22,8 +22,6 @@ if __name__ == "__main__":
|
|||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
setattr(namespace, self.dest, os.path.abspath(os.path.expanduser(values)))
|
||||
|
||||
exit_code = 0
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers()
|
||||
|
||||
|
@ -38,9 +36,6 @@ if __name__ == "__main__":
|
|||
manual_output_debug_fix = arguments.manual_output_debug_fix,
|
||||
manual_window_size = arguments.manual_window_size,
|
||||
face_type = arguments.face_type,
|
||||
max_faces_from_image = arguments.max_faces_from_image,
|
||||
image_size = arguments.image_size,
|
||||
jpeg_quality = arguments.jpeg_quality,
|
||||
cpu_only = arguments.cpu_only,
|
||||
force_gpu_idxs = [ int(x) for x in arguments.force_gpu_idxs.split(',') ] if arguments.force_gpu_idxs is not None else None,
|
||||
)
|
||||
|
@ -52,9 +47,6 @@ if __name__ == "__main__":
|
|||
p.add_argument('--output-debug', action="store_true", dest="output_debug", default=None, help="Writes debug images to <output-dir>_debug\ directory.")
|
||||
p.add_argument('--no-output-debug', action="store_false", dest="output_debug", default=None, help="Don't writes debug images to <output-dir>_debug\ directory.")
|
||||
p.add_argument('--face-type', dest="face_type", choices=['half_face', 'full_face', 'whole_face', 'head', 'mark_only'], default=None)
|
||||
p.add_argument('--max-faces-from-image', type=int, dest="max_faces_from_image", default=None, help="Max faces from image.")
|
||||
p.add_argument('--image-size', type=int, dest="image_size", default=None, help="Output image size.")
|
||||
p.add_argument('--jpeg-quality', type=int, dest="jpeg_quality", default=None, help="Jpeg quality.")
|
||||
p.add_argument('--manual-fix', action="store_true", dest="manual_fix", default=False, help="Enables manual extract only frames where faces were not recognized.")
|
||||
p.add_argument('--manual-output-debug-fix', action="store_true", dest="manual_output_debug_fix", default=False, help="Performs manual reextract input-dir frames which were deleted from [output_dir]_debug\ dir.")
|
||||
p.add_argument('--manual-window-size', type=int, dest="manual_window_size", default=1368, help="Manual fix window size. Default: 1368.")
|
||||
|
@ -70,7 +62,7 @@ if __name__ == "__main__":
|
|||
|
||||
p = subparsers.add_parser( "sort", help="Sort faces in a directory.")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
|
||||
p.add_argument('--by', dest="sort_by_method", default=None, choices=("blur", "motion-blur", "face-yaw", "face-pitch", "face-source-rect-size", "hist", "hist-dissim", "brightness", "hue", "black", "origname", "oneface", "final-by-blur", "final-by-size", "absdiff"), help="Method of sorting. 'origname' sort by original filename to recover original sequence." )
|
||||
p.add_argument('--by', dest="sort_by_method", default=None, choices=("blur", "face-yaw", "face-pitch", "face-source-rect-size", "hist", "hist-dissim", "brightness", "hue", "black", "origname", "oneface", "final", "final-faster", "absdiff"), help="Method of sorting. 'origname' sort by original filename to recover original sequence." )
|
||||
p.set_defaults (func=process_sort)
|
||||
|
||||
def process_util(arguments):
|
||||
|
@ -99,10 +91,6 @@ if __name__ == "__main__":
|
|||
from samplelib import PackedFaceset
|
||||
PackedFaceset.unpack( Path(arguments.input_dir) )
|
||||
|
||||
if arguments.export_faceset_mask:
|
||||
io.log_info ("Exporting faceset mask..\r\n")
|
||||
Util.export_faceset_mask( Path(arguments.input_dir) )
|
||||
|
||||
p = subparsers.add_parser( "util", help="Utilities.")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
|
||||
p.add_argument('--add-landmarks-debug-images', action="store_true", dest="add_landmarks_debug_images", default=False, help="Add landmarks debug image for aligned faces.")
|
||||
|
@ -111,7 +99,6 @@ if __name__ == "__main__":
|
|||
p.add_argument('--restore-faceset-metadata', action="store_true", dest="restore_faceset_metadata", default=False, help="Restore faceset metadata to file. Image filenames must be the same as used with save.")
|
||||
p.add_argument('--pack-faceset', action="store_true", dest="pack_faceset", default=False, help="")
|
||||
p.add_argument('--unpack-faceset', action="store_true", dest="unpack_faceset", default=False, help="")
|
||||
p.add_argument('--export-faceset-mask', action="store_true", dest="export_faceset_mask", default=False, help="")
|
||||
|
||||
p.set_defaults (func=process_util)
|
||||
|
||||
|
@ -150,19 +137,10 @@ if __name__ == "__main__":
|
|||
p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
|
||||
p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.")
|
||||
|
||||
|
||||
p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
|
||||
p.set_defaults (func=process_train)
|
||||
|
||||
def process_exportdfm(arguments):
|
||||
osex.set_process_lowest_prio()
|
||||
from mainscripts import ExportDFM
|
||||
ExportDFM.main(model_class_name = arguments.model_name, saved_models_path = Path(arguments.model_dir))
|
||||
|
||||
p = subparsers.add_parser( "exportdfm", help="Export model to use in DeepFaceLive.")
|
||||
p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Saved models dir.")
|
||||
p.add_argument('--model', required=True, dest="model_name", choices=pathex.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Model class name.")
|
||||
p.set_defaults (func=process_exportdfm)
|
||||
|
||||
def process_merge(arguments):
|
||||
osex.set_process_lowest_prio()
|
||||
from mainscripts import Merger
|
||||
|
@ -267,20 +245,10 @@ if __name__ == "__main__":
|
|||
|
||||
p.set_defaults(func=process_faceset_enhancer)
|
||||
|
||||
|
||||
p = facesettool_parser.add_parser ("resize", help="Resize DFL faceset.")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.")
|
||||
|
||||
def process_faceset_resizer(arguments):
|
||||
osex.set_process_lowest_prio()
|
||||
from mainscripts import FacesetResizer
|
||||
FacesetResizer.process_folder ( Path(arguments.input_dir) )
|
||||
p.set_defaults(func=process_faceset_resizer)
|
||||
|
||||
def process_dev_test(arguments):
|
||||
osex.set_process_lowest_prio()
|
||||
from mainscripts import dev_misc
|
||||
dev_misc.dev_gen_mask_files( arguments.input_dir )
|
||||
dev_misc.dev_test( arguments.input_dir )
|
||||
|
||||
p = subparsers.add_parser( "dev_test", help="")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
|
||||
|
@ -294,9 +262,7 @@ if __name__ == "__main__":
|
|||
def process_xsegeditor(arguments):
|
||||
osex.set_process_lowest_prio()
|
||||
from XSegEditor import XSegEditor
|
||||
global exit_code
|
||||
exit_code = XSegEditor.start (Path(arguments.input_dir))
|
||||
|
||||
XSegEditor.start (Path(arguments.input_dir))
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
|
||||
|
||||
p.set_defaults (func=process_xsegeditor)
|
||||
|
@ -347,10 +313,7 @@ if __name__ == "__main__":
|
|||
arguments = parser.parse_args()
|
||||
arguments.func(arguments)
|
||||
|
||||
if exit_code == 0:
|
||||
print ("Done.")
|
||||
|
||||
exit(exit_code)
|
||||
print ("Done.")
|
||||
|
||||
'''
|
||||
import code
|
||||
|
|
|
@ -1,22 +0,0 @@
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import numpy as np
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from core import pathex
|
||||
from core import imagelib
|
||||
import cv2
|
||||
import models
|
||||
from core.interact import interact as io
|
||||
|
||||
|
||||
def main(model_class_name, saved_models_path):
|
||||
model = models.import_model(model_class_name)(
|
||||
is_exporting=True,
|
||||
saved_models_path=saved_models_path,
|
||||
cpu_only=True)
|
||||
model.export_dfm ()
|
|
@ -10,7 +10,6 @@ from pathlib import Path
|
|||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from numpy import linalg as npla
|
||||
|
||||
import facelib
|
||||
from core import imagelib
|
||||
|
@ -44,7 +43,6 @@ class ExtractSubprocessor(Subprocessor):
|
|||
def on_initialize(self, client_dict):
|
||||
self.type = client_dict['type']
|
||||
self.image_size = client_dict['image_size']
|
||||
self.jpeg_quality = client_dict['jpeg_quality']
|
||||
self.face_type = client_dict['face_type']
|
||||
self.max_faces_from_image = client_dict['max_faces_from_image']
|
||||
self.device_idx = client_dict['device_idx']
|
||||
|
@ -97,6 +95,9 @@ class ExtractSubprocessor(Subprocessor):
|
|||
|
||||
h, w, c = image.shape
|
||||
|
||||
dflimg = DFLIMG.load (filepath)
|
||||
extract_from_dflimg = (h == w and (dflimg is not None and dflimg.has_data()) )
|
||||
|
||||
if 'rects' in self.type or self.type == 'all':
|
||||
data = ExtractSubprocessor.Cli.rects_stage (data=data,
|
||||
image=image,
|
||||
|
@ -107,6 +108,7 @@ class ExtractSubprocessor(Subprocessor):
|
|||
if 'landmarks' in self.type or self.type == 'all':
|
||||
data = ExtractSubprocessor.Cli.landmarks_stage (data=data,
|
||||
image=image,
|
||||
extract_from_dflimg=extract_from_dflimg,
|
||||
landmarks_extractor=self.landmarks_extractor,
|
||||
rects_extractor=self.rects_extractor,
|
||||
)
|
||||
|
@ -116,7 +118,7 @@ class ExtractSubprocessor(Subprocessor):
|
|||
image=image,
|
||||
face_type=self.face_type,
|
||||
image_size=self.image_size,
|
||||
jpeg_quality=self.jpeg_quality,
|
||||
extract_from_dflimg=extract_from_dflimg,
|
||||
output_debug_path=self.output_debug_path,
|
||||
final_output_path=self.final_output_path,
|
||||
)
|
||||
|
@ -146,9 +148,7 @@ class ExtractSubprocessor(Subprocessor):
|
|||
if len(rects) != 0:
|
||||
data.rects_rotation = rot
|
||||
break
|
||||
if max_faces_from_image is not None and \
|
||||
max_faces_from_image > 0 and \
|
||||
len(data.rects) > 0:
|
||||
if max_faces_from_image != 0 and len(data.rects) > 1:
|
||||
data.rects = data.rects[0:max_faces_from_image]
|
||||
return data
|
||||
|
||||
|
@ -156,6 +156,7 @@ class ExtractSubprocessor(Subprocessor):
|
|||
@staticmethod
|
||||
def landmarks_stage(data,
|
||||
image,
|
||||
extract_from_dflimg,
|
||||
landmarks_extractor,
|
||||
rects_extractor,
|
||||
):
|
||||
|
@ -170,7 +171,7 @@ class ExtractSubprocessor(Subprocessor):
|
|||
elif data.rects_rotation == 270:
|
||||
rotated_image = image.swapaxes( 0,1 )[::-1,:,:]
|
||||
|
||||
data.landmarks = landmarks_extractor.extract (rotated_image, data.rects, rects_extractor if (data.landmarks_accurate) else None, is_bgr=True)
|
||||
data.landmarks = landmarks_extractor.extract (rotated_image, data.rects, rects_extractor if (not extract_from_dflimg and data.landmarks_accurate) else None, is_bgr=True)
|
||||
if data.rects_rotation != 0:
|
||||
for i, (rect, lmrks) in enumerate(zip(data.rects, data.landmarks)):
|
||||
new_rect, new_lmrks = rect, lmrks
|
||||
|
@ -200,7 +201,7 @@ class ExtractSubprocessor(Subprocessor):
|
|||
image,
|
||||
face_type,
|
||||
image_size,
|
||||
jpeg_quality,
|
||||
extract_from_dflimg = False,
|
||||
output_debug_path=None,
|
||||
final_output_path=None,
|
||||
):
|
||||
|
@ -212,53 +213,72 @@ class ExtractSubprocessor(Subprocessor):
|
|||
if output_debug_path is not None:
|
||||
debug_image = image.copy()
|
||||
|
||||
face_idx = 0
|
||||
for rect, image_landmarks in zip( rects, landmarks ):
|
||||
if image_landmarks is None:
|
||||
continue
|
||||
if extract_from_dflimg and len(rects) != 1:
|
||||
#if re-extracting from dflimg and more than 1 or zero faces detected - dont process and just copy it
|
||||
print("extract_from_dflimg and len(rects) != 1", filepath )
|
||||
output_filepath = final_output_path / filepath.name
|
||||
if filepath != str(output_file):
|
||||
shutil.copy ( str(filepath), str(output_filepath) )
|
||||
data.final_output_files.append (output_filepath)
|
||||
else:
|
||||
face_idx = 0
|
||||
for rect, image_landmarks in zip( rects, landmarks ):
|
||||
|
||||
rect = np.array(rect)
|
||||
if extract_from_dflimg and face_idx > 1:
|
||||
#cannot extract more than 1 face from dflimg
|
||||
break
|
||||
|
||||
if face_type == FaceType.MARK_ONLY:
|
||||
image_to_face_mat = None
|
||||
face_image = image
|
||||
face_image_landmarks = image_landmarks
|
||||
else:
|
||||
image_to_face_mat = LandmarksProcessor.get_transform_mat (image_landmarks, image_size, face_type)
|
||||
|
||||
face_image = cv2.warpAffine(image, image_to_face_mat, (image_size, image_size), cv2.INTER_LANCZOS4)
|
||||
face_image_landmarks = LandmarksProcessor.transform_points (image_landmarks, image_to_face_mat)
|
||||
|
||||
landmarks_bbox = LandmarksProcessor.transform_points ( [ (0,0), (0,image_size-1), (image_size-1, image_size-1), (image_size-1,0) ], image_to_face_mat, True)
|
||||
|
||||
rect_area = mathlib.polygon_area(np.array(rect[[0,2,2,0]]).astype(np.float32), np.array(rect[[1,1,3,3]]).astype(np.float32))
|
||||
landmarks_area = mathlib.polygon_area(landmarks_bbox[:,0].astype(np.float32), landmarks_bbox[:,1].astype(np.float32) )
|
||||
|
||||
if not data.manual and face_type <= FaceType.FULL_NO_ALIGN and landmarks_area > 4*rect_area: #get rid of faces which umeyama-landmark-area > 4*detector-rect-area
|
||||
if image_landmarks is None:
|
||||
continue
|
||||
|
||||
if output_debug_path is not None:
|
||||
LandmarksProcessor.draw_rect_landmarks (debug_image, rect, image_landmarks, face_type, image_size, transparent_mask=True)
|
||||
rect = np.array(rect)
|
||||
|
||||
output_path = final_output_path
|
||||
if data.force_output_path is not None:
|
||||
output_path = data.force_output_path
|
||||
if face_type == FaceType.MARK_ONLY:
|
||||
image_to_face_mat = None
|
||||
face_image = image
|
||||
face_image_landmarks = image_landmarks
|
||||
else:
|
||||
image_to_face_mat = LandmarksProcessor.get_transform_mat (image_landmarks, image_size, face_type)
|
||||
|
||||
output_filepath = output_path / f"{filepath.stem}_{face_idx}.jpg"
|
||||
cv2_imwrite(output_filepath, face_image, [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality ] )
|
||||
face_image = cv2.warpAffine(image, image_to_face_mat, (image_size, image_size), cv2.INTER_LANCZOS4)
|
||||
face_image_landmarks = LandmarksProcessor.transform_points (image_landmarks, image_to_face_mat)
|
||||
|
||||
dflimg = DFLJPG.load(output_filepath)
|
||||
dflimg.set_face_type(FaceType.toString(face_type))
|
||||
dflimg.set_landmarks(face_image_landmarks.tolist())
|
||||
dflimg.set_source_filename(filepath.name)
|
||||
dflimg.set_source_rect(rect)
|
||||
dflimg.set_source_landmarks(image_landmarks.tolist())
|
||||
dflimg.set_image_to_face_mat(image_to_face_mat)
|
||||
dflimg.save()
|
||||
landmarks_bbox = LandmarksProcessor.transform_points ( [ (0,0), (0,image_size-1), (image_size-1, image_size-1), (image_size-1,0) ], image_to_face_mat, True)
|
||||
|
||||
data.final_output_files.append (output_filepath)
|
||||
face_idx += 1
|
||||
data.faces_detected = face_idx
|
||||
rect_area = mathlib.polygon_area(np.array(rect[[0,2,2,0]]).astype(np.float32), np.array(rect[[1,1,3,3]]).astype(np.float32))
|
||||
landmarks_area = mathlib.polygon_area(landmarks_bbox[:,0].astype(np.float32), landmarks_bbox[:,1].astype(np.float32) )
|
||||
|
||||
if not data.manual and face_type <= FaceType.FULL_NO_ALIGN and landmarks_area > 4*rect_area: #get rid of faces which umeyama-landmark-area > 4*detector-rect-area
|
||||
continue
|
||||
|
||||
if output_debug_path is not None:
|
||||
LandmarksProcessor.draw_rect_landmarks (debug_image, rect, image_landmarks, face_type, image_size, transparent_mask=True)
|
||||
|
||||
output_path = final_output_path
|
||||
if data.force_output_path is not None:
|
||||
output_path = data.force_output_path
|
||||
|
||||
if extract_from_dflimg and filepath.suffix == '.jpg':
|
||||
#if extracting from dflimg and jpg copy it in order not to lose quality
|
||||
output_filepath = output_path / filepath.name
|
||||
if filepath != output_filepath:
|
||||
shutil.copy ( str(filepath), str(output_filepath) )
|
||||
else:
|
||||
output_filepath = output_path / f"{filepath.stem}_{face_idx}.jpg"
|
||||
cv2_imwrite(output_filepath, face_image, [int(cv2.IMWRITE_JPEG_QUALITY), 90] )
|
||||
|
||||
dflimg = DFLJPG.load(output_filepath)
|
||||
dflimg.set_face_type(FaceType.toString(face_type))
|
||||
dflimg.set_landmarks(face_image_landmarks.tolist())
|
||||
dflimg.set_source_filename(filepath.name)
|
||||
dflimg.set_source_rect(rect)
|
||||
dflimg.set_source_landmarks(image_landmarks.tolist())
|
||||
dflimg.set_image_to_face_mat(image_to_face_mat)
|
||||
dflimg.save()
|
||||
|
||||
data.final_output_files.append (output_filepath)
|
||||
face_idx += 1
|
||||
data.faces_detected = face_idx
|
||||
|
||||
if output_debug_path is not None:
|
||||
cv2_imwrite( output_debug_path / (filepath.stem+'.jpg'), debug_image, [int(cv2.IMWRITE_JPEG_QUALITY), 50] )
|
||||
|
@ -304,7 +324,7 @@ class ExtractSubprocessor(Subprocessor):
|
|||
elif type == 'final':
|
||||
return [ (i, 'CPU', 'CPU%d' % (i), 0 ) for i in (range(min(8, multiprocessing.cpu_count())) if not DEBUG else [0]) ]
|
||||
|
||||
def __init__(self, input_data, type, image_size=None, jpeg_quality=None, face_type=None, output_debug_path=None, manual_window_size=0, max_faces_from_image=0, final_output_path=None, device_config=None):
|
||||
def __init__(self, input_data, type, image_size=None, face_type=None, output_debug_path=None, manual_window_size=0, max_faces_from_image=0, final_output_path=None, device_config=None):
|
||||
if type == 'landmarks-manual':
|
||||
for x in input_data:
|
||||
x.manual = True
|
||||
|
@ -313,7 +333,6 @@ class ExtractSubprocessor(Subprocessor):
|
|||
|
||||
self.type = type
|
||||
self.image_size = image_size
|
||||
self.jpeg_quality = jpeg_quality
|
||||
self.face_type = face_type
|
||||
self.output_debug_path = output_debug_path
|
||||
self.final_output_path = final_output_path
|
||||
|
@ -339,7 +358,6 @@ class ExtractSubprocessor(Subprocessor):
|
|||
self.cache_text_lines_img = (None, None)
|
||||
self.hide_help = False
|
||||
self.landmarks_accurate = True
|
||||
self.force_landmarks = False
|
||||
|
||||
self.landmarks = None
|
||||
self.x = 0
|
||||
|
@ -348,9 +366,6 @@ class ExtractSubprocessor(Subprocessor):
|
|||
self.rect_locked = False
|
||||
self.extract_needed = True
|
||||
|
||||
self.image = None
|
||||
self.image_filepath = None
|
||||
|
||||
io.progress_bar (None, len (self.input_data))
|
||||
|
||||
#override
|
||||
|
@ -364,7 +379,6 @@ class ExtractSubprocessor(Subprocessor):
|
|||
def process_info_generator(self):
|
||||
base_dict = {'type' : self.type,
|
||||
'image_size': self.image_size,
|
||||
'jpeg_quality' : self.jpeg_quality,
|
||||
'face_type': self.face_type,
|
||||
'max_faces_from_image':self.max_faces_from_image,
|
||||
'output_debug_path': self.output_debug_path,
|
||||
|
@ -383,13 +397,26 @@ class ExtractSubprocessor(Subprocessor):
|
|||
def get_data(self, host_dict):
|
||||
if self.type == 'landmarks-manual':
|
||||
need_remark_face = False
|
||||
redraw_needed = False
|
||||
while len (self.input_data) > 0:
|
||||
data = self.input_data[0]
|
||||
filepath, data_rects, data_landmarks = data.filepath, data.rects, data.landmarks
|
||||
is_frame_done = False
|
||||
|
||||
if self.image_filepath != filepath:
|
||||
self.image_filepath = filepath
|
||||
if need_remark_face: # need remark image from input data that already has a marked face?
|
||||
need_remark_face = False
|
||||
if len(data_rects) != 0: # If there was already a face then lock the rectangle to it until the mouse is clicked
|
||||
self.rect = data_rects.pop()
|
||||
self.landmarks = data_landmarks.pop()
|
||||
data_rects.clear()
|
||||
data_landmarks.clear()
|
||||
redraw_needed = True
|
||||
self.rect_locked = True
|
||||
self.rect_size = ( self.rect[2] - self.rect[0] ) / 2
|
||||
self.x = ( self.rect[0] + self.rect[2] ) / 2
|
||||
self.y = ( self.rect[1] + self.rect[3] ) / 2
|
||||
|
||||
if len(data_rects) == 0:
|
||||
if self.cache_original_image[0] == filepath:
|
||||
self.original_image = self.cache_original_image[1]
|
||||
else:
|
||||
|
@ -413,8 +440,8 @@ class ExtractSubprocessor(Subprocessor):
|
|||
self.text_lines_img = self.cache_text_lines_img[1]
|
||||
else:
|
||||
self.text_lines_img = (imagelib.get_draw_text_lines ( self.image, sh,
|
||||
[ '[L Mouse click] - lock/unlock selection. [Mouse wheel] - change rect',
|
||||
'[R Mouse Click] - manual face rectangle',
|
||||
[ '[Mouse click] - lock/unlock selection',
|
||||
'[Mouse wheel] - change rect',
|
||||
'[Enter] / [Space] - confirm / skip frame',
|
||||
'[,] [.]- prev frame, next frame. [Q] - skip remaining frames',
|
||||
'[a] - accuracy on/off (more fps)',
|
||||
|
@ -423,29 +450,11 @@ class ExtractSubprocessor(Subprocessor):
|
|||
|
||||
self.cache_text_lines_img = (sh, self.text_lines_img)
|
||||
|
||||
if need_remark_face: # need remark image from input data that already has a marked face?
|
||||
need_remark_face = False
|
||||
if len(data_rects) != 0: # If there was already a face then lock the rectangle to it until the mouse is clicked
|
||||
self.rect = data_rects.pop()
|
||||
self.landmarks = data_landmarks.pop()
|
||||
data_rects.clear()
|
||||
data_landmarks.clear()
|
||||
|
||||
self.rect_locked = True
|
||||
self.rect_size = ( self.rect[2] - self.rect[0] ) / 2
|
||||
self.x = ( self.rect[0] + self.rect[2] ) / 2
|
||||
self.y = ( self.rect[1] + self.rect[3] ) / 2
|
||||
self.redraw()
|
||||
|
||||
if len(data_rects) == 0:
|
||||
(h,w,c) = self.image.shape
|
||||
while True:
|
||||
io.process_messages(0.0001)
|
||||
|
||||
if not self.force_landmarks:
|
||||
new_x = self.x
|
||||
new_y = self.y
|
||||
|
||||
new_x = self.x
|
||||
new_y = self.y
|
||||
new_rect_size = self.rect_size
|
||||
|
||||
mouse_events = io.get_mouse_events(self.wnd_name)
|
||||
|
@ -456,19 +465,8 @@ class ExtractSubprocessor(Subprocessor):
|
|||
diff = 1 if new_rect_size <= 40 else np.clip(new_rect_size / 10, 1, 10)
|
||||
new_rect_size = max (5, new_rect_size + diff*mod)
|
||||
elif ev == io.EVENT_LBUTTONDOWN:
|
||||
if self.force_landmarks:
|
||||
self.x = new_x
|
||||
self.y = new_y
|
||||
self.force_landmarks = False
|
||||
self.rect_locked = True
|
||||
self.redraw()
|
||||
else:
|
||||
self.rect_locked = not self.rect_locked
|
||||
self.extract_needed = True
|
||||
elif ev == io.EVENT_RBUTTONDOWN:
|
||||
self.force_landmarks = not self.force_landmarks
|
||||
if self.force_landmarks:
|
||||
self.rect_locked = False
|
||||
self.rect_locked = not self.rect_locked
|
||||
self.extract_needed = True
|
||||
elif not self.rect_locked:
|
||||
new_x = np.clip (x, 0, w-1) / self.view_scale
|
||||
new_y = np.clip (y, 0, h-1) / self.view_scale
|
||||
|
@ -534,35 +532,11 @@ class ExtractSubprocessor(Subprocessor):
|
|||
self.landmarks_accurate = not self.landmarks_accurate
|
||||
break
|
||||
|
||||
if self.force_landmarks:
|
||||
pt2 = np.float32([new_x, new_y])
|
||||
pt1 = np.float32([self.x, self.y])
|
||||
|
||||
pt_vec_len = npla.norm(pt2-pt1)
|
||||
pt_vec = pt2-pt1
|
||||
if pt_vec_len != 0:
|
||||
pt_vec /= pt_vec_len
|
||||
|
||||
self.rect_size = pt_vec_len
|
||||
self.rect = ( int(self.x-self.rect_size),
|
||||
int(self.y-self.rect_size),
|
||||
int(self.x+self.rect_size),
|
||||
int(self.y+self.rect_size) )
|
||||
|
||||
if pt_vec_len > 0:
|
||||
lmrks = np.concatenate ( (np.zeros ((17,2), np.float32), LandmarksProcessor.landmarks_2D), axis=0 )
|
||||
lmrks -= lmrks[30:31,:]
|
||||
mat = cv2.getRotationMatrix2D( (0, 0), -np.arctan2( pt_vec[1], pt_vec[0] )*180/math.pi , pt_vec_len)
|
||||
mat[:, 2] += (self.x, self.y)
|
||||
self.landmarks = LandmarksProcessor.transform_points(lmrks, mat )
|
||||
|
||||
|
||||
self.redraw()
|
||||
|
||||
elif self.x != new_x or \
|
||||
if self.x != new_x or \
|
||||
self.y != new_y or \
|
||||
self.rect_size != new_rect_size or \
|
||||
self.extract_needed:
|
||||
self.extract_needed or \
|
||||
redraw_needed:
|
||||
self.x = new_x
|
||||
self.y = new_y
|
||||
self.rect_size = new_rect_size
|
||||
|
@ -571,7 +545,11 @@ class ExtractSubprocessor(Subprocessor):
|
|||
int(self.x+self.rect_size),
|
||||
int(self.y+self.rect_size) )
|
||||
|
||||
return ExtractSubprocessor.Data (filepath, rects=[self.rect], landmarks_accurate=self.landmarks_accurate)
|
||||
if redraw_needed:
|
||||
redraw_needed = False
|
||||
return ExtractSubprocessor.Data (filepath, landmarks_accurate=self.landmarks_accurate)
|
||||
else:
|
||||
return ExtractSubprocessor.Data (filepath, rects=[self.rect], landmarks_accurate=self.landmarks_accurate)
|
||||
|
||||
else:
|
||||
is_frame_done = True
|
||||
|
@ -593,40 +571,6 @@ class ExtractSubprocessor(Subprocessor):
|
|||
if not self.type != 'landmarks-manual':
|
||||
self.input_data.insert(0, data)
|
||||
|
||||
def redraw(self):
|
||||
(h,w,c) = self.image.shape
|
||||
|
||||
if not self.hide_help:
|
||||
image = cv2.addWeighted (self.image,1.0,self.text_lines_img,1.0,0)
|
||||
else:
|
||||
image = self.image.copy()
|
||||
|
||||
view_rect = (np.array(self.rect) * self.view_scale).astype(np.int).tolist()
|
||||
view_landmarks = (np.array(self.landmarks) * self.view_scale).astype(np.int).tolist()
|
||||
|
||||
if self.rect_size <= 40:
|
||||
scaled_rect_size = h // 3 if w > h else w // 3
|
||||
|
||||
p1 = (self.x - self.rect_size, self.y - self.rect_size)
|
||||
p2 = (self.x + self.rect_size, self.y - self.rect_size)
|
||||
p3 = (self.x - self.rect_size, self.y + self.rect_size)
|
||||
|
||||
wh = h if h < w else w
|
||||
np1 = (w / 2 - wh / 4, h / 2 - wh / 4)
|
||||
np2 = (w / 2 + wh / 4, h / 2 - wh / 4)
|
||||
np3 = (w / 2 - wh / 4, h / 2 + wh / 4)
|
||||
|
||||
mat = cv2.getAffineTransform( np.float32([p1,p2,p3])*self.view_scale, np.float32([np1,np2,np3]) )
|
||||
image = cv2.warpAffine(image, mat,(w,h) )
|
||||
view_landmarks = LandmarksProcessor.transform_points (view_landmarks, mat)
|
||||
|
||||
landmarks_color = (255,255,0) if self.rect_locked else (0,255,0)
|
||||
LandmarksProcessor.draw_rect_landmarks (image, view_rect, view_landmarks, self.face_type, self.image_size, landmarks_color=landmarks_color)
|
||||
self.extract_needed = False
|
||||
|
||||
io.show_image (self.wnd_name, image)
|
||||
|
||||
|
||||
#override
|
||||
def on_result (self, host_dict, data, result):
|
||||
if self.type == 'landmarks-manual':
|
||||
|
@ -635,7 +579,37 @@ class ExtractSubprocessor(Subprocessor):
|
|||
if len(landmarks) != 0 and landmarks[0] is not None:
|
||||
self.landmarks = landmarks[0]
|
||||
|
||||
self.redraw()
|
||||
(h,w,c) = self.image.shape
|
||||
|
||||
if not self.hide_help:
|
||||
image = cv2.addWeighted (self.image,1.0,self.text_lines_img,1.0,0)
|
||||
else:
|
||||
image = self.image.copy()
|
||||
|
||||
view_rect = (np.array(self.rect) * self.view_scale).astype(np.int).tolist()
|
||||
view_landmarks = (np.array(self.landmarks) * self.view_scale).astype(np.int).tolist()
|
||||
|
||||
if self.rect_size <= 40:
|
||||
scaled_rect_size = h // 3 if w > h else w // 3
|
||||
|
||||
p1 = (self.x - self.rect_size, self.y - self.rect_size)
|
||||
p2 = (self.x + self.rect_size, self.y - self.rect_size)
|
||||
p3 = (self.x - self.rect_size, self.y + self.rect_size)
|
||||
|
||||
wh = h if h < w else w
|
||||
np1 = (w / 2 - wh / 4, h / 2 - wh / 4)
|
||||
np2 = (w / 2 + wh / 4, h / 2 - wh / 4)
|
||||
np3 = (w / 2 - wh / 4, h / 2 + wh / 4)
|
||||
|
||||
mat = cv2.getAffineTransform( np.float32([p1,p2,p3])*self.view_scale, np.float32([np1,np2,np3]) )
|
||||
image = cv2.warpAffine(image, mat,(w,h) )
|
||||
view_landmarks = LandmarksProcessor.transform_points (view_landmarks, mat)
|
||||
|
||||
landmarks_color = (255,255,0) if self.rect_locked else (0,255,0)
|
||||
LandmarksProcessor.draw_rect_landmarks (image, view_rect, view_landmarks, self.face_type, self.image_size, landmarks_color=landmarks_color)
|
||||
self.extract_needed = False
|
||||
|
||||
io.show_image (self.wnd_name, image)
|
||||
else:
|
||||
self.result.append ( result )
|
||||
io.progress_bar_inc(1)
|
||||
|
@ -712,9 +686,7 @@ def main(detector=None,
|
|||
manual_output_debug_fix=False,
|
||||
manual_window_size=1368,
|
||||
face_type='full_face',
|
||||
max_faces_from_image=None,
|
||||
image_size=None,
|
||||
jpeg_quality=None,
|
||||
max_faces_from_image=0,
|
||||
cpu_only = False,
|
||||
force_gpu_idxs = None,
|
||||
):
|
||||
|
@ -723,57 +695,24 @@ def main(detector=None,
|
|||
io.log_err ('Input directory not found. Please ensure it exists.')
|
||||
return
|
||||
|
||||
if not output_path.exists():
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if face_type is not None:
|
||||
face_type = FaceType.fromString(face_type)
|
||||
|
||||
if face_type is None:
|
||||
if manual_output_debug_fix:
|
||||
if manual_output_debug_fix and output_path.exists():
|
||||
files = pathex.get_image_paths(output_path)
|
||||
if len(files) != 0:
|
||||
dflimg = DFLIMG.load(Path(files[0]))
|
||||
if dflimg is not None and dflimg.has_data():
|
||||
face_type = FaceType.fromString ( dflimg.get_face_type() )
|
||||
|
||||
input_image_paths = pathex.get_image_unique_filestem_paths(input_path, verbose_print_func=io.log_info)
|
||||
output_images_paths = pathex.get_image_paths(output_path)
|
||||
output_debug_path = output_path.parent / (output_path.name + '_debug')
|
||||
|
||||
continue_extraction = False
|
||||
if not manual_output_debug_fix and len(output_images_paths) > 0:
|
||||
if len(output_images_paths) > 128:
|
||||
continue_extraction = io.input_bool ("Continue extraction?", True, help_message="Extraction can be continued, but you must specify the same options again.")
|
||||
|
||||
if len(output_images_paths) > 128 and continue_extraction:
|
||||
try:
|
||||
input_image_paths = input_image_paths[ [ Path(x).stem for x in input_image_paths ].index ( Path(output_images_paths[-128]).stem.split('_')[0] ) : ]
|
||||
except:
|
||||
io.log_err("Error in fetching the last index. Extraction cannot be continued.")
|
||||
return
|
||||
elif input_path != output_path:
|
||||
io.input(f"\n WARNING !!! \n {output_path} contains files! \n They will be deleted. \n Press enter to continue.\n")
|
||||
for filename in output_images_paths:
|
||||
Path(filename).unlink()
|
||||
|
||||
device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(choose_only_one=detector=='manual', suggest_all_gpu=True) ) \
|
||||
if not cpu_only else nn.DeviceConfig.CPU()
|
||||
|
||||
if face_type is None:
|
||||
face_type = io.input_str ("Face type", 'wf', ['f','wf','head'], help_message="Full face / whole face / head. 'Whole face' covers full area of face include forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower()
|
||||
face_type = {'f' : FaceType.FULL,
|
||||
'wf' : FaceType.WHOLE_FACE,
|
||||
'head' : FaceType.HEAD}[face_type]
|
||||
|
||||
if max_faces_from_image is None:
|
||||
max_faces_from_image = io.input_int(f"Max number of faces from image", 0, help_message="If you extract a src faceset that has frames with a large number of faces, it is advisable to set max faces to 3 to speed up extraction. 0 - unlimited")
|
||||
|
||||
if image_size is None:
|
||||
image_size = io.input_int(f"Image size", 512 if face_type < FaceType.HEAD else 768, valid_range=[256,2048], help_message="Output image size. The higher image size, the worse face-enhancer works. Use higher than 512 value only if the source image is sharp enough and the face does not need to be enhanced.")
|
||||
|
||||
if jpeg_quality is None:
|
||||
jpeg_quality = io.input_int(f"Jpeg quality", 90, valid_range=[1,100], help_message="Jpeg quality. The higher jpeg quality the larger the output file size.")
|
||||
image_size = 512 if face_type < FaceType.HEAD else 768
|
||||
|
||||
if detector is None:
|
||||
io.log_info ("Choose detector type.")
|
||||
|
@ -781,12 +720,25 @@ def main(detector=None,
|
|||
io.log_info ("[1] manual")
|
||||
detector = {0:'s3fd', 1:'manual'}[ io.input_int("", 0, [0,1]) ]
|
||||
|
||||
device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(choose_only_one=detector=='manual', suggest_all_gpu=True) ) \
|
||||
if not cpu_only else nn.DeviceConfig.CPU()
|
||||
|
||||
output_debug_path = output_path.parent / (output_path.name + '_debug')
|
||||
|
||||
if output_debug is None:
|
||||
output_debug = io.input_bool (f"Write debug images to {output_debug_path.name}?", False)
|
||||
|
||||
if output_debug:
|
||||
output_debug_path.mkdir(parents=True, exist_ok=True)
|
||||
if output_path.exists():
|
||||
if not manual_output_debug_fix and input_path != output_path:
|
||||
output_images_paths = pathex.get_image_paths(output_path)
|
||||
if len(output_images_paths) > 0:
|
||||
io.input(f"\n WARNING !!! \n {output_path} contains files! \n They will be deleted. \n Press enter to continue.\n")
|
||||
for filename in output_images_paths:
|
||||
Path(filename).unlink()
|
||||
else:
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
input_path_image_paths = pathex.get_image_unique_filestem_paths(input_path, verbose_print_func=io.log_info)
|
||||
|
||||
if manual_output_debug_fix:
|
||||
if not output_debug_path.exists():
|
||||
|
@ -796,30 +748,31 @@ def main(detector=None,
|
|||
detector = 'manual'
|
||||
io.log_info('Performing re-extract frames which were deleted from _debug directory.')
|
||||
|
||||
input_image_paths = DeletedFilesSearcherSubprocessor (input_image_paths, pathex.get_image_paths(output_debug_path) ).run()
|
||||
input_image_paths = sorted (input_image_paths)
|
||||
io.log_info('Found %d images.' % (len(input_image_paths)))
|
||||
input_path_image_paths = DeletedFilesSearcherSubprocessor (input_path_image_paths, pathex.get_image_paths(output_debug_path) ).run()
|
||||
input_path_image_paths = sorted (input_path_image_paths)
|
||||
io.log_info('Found %d images.' % (len(input_path_image_paths)))
|
||||
else:
|
||||
if not continue_extraction and output_debug_path.exists():
|
||||
if output_debug_path.exists():
|
||||
for filename in pathex.get_image_paths(output_debug_path):
|
||||
Path(filename).unlink()
|
||||
else:
|
||||
output_debug_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
images_found = len(input_image_paths)
|
||||
images_found = len(input_path_image_paths)
|
||||
faces_detected = 0
|
||||
if images_found != 0:
|
||||
if detector == 'manual':
|
||||
io.log_info ('Performing manual extract...')
|
||||
data = ExtractSubprocessor ([ ExtractSubprocessor.Data(Path(filename)) for filename in input_image_paths ], 'landmarks-manual', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None, manual_window_size=manual_window_size, device_config=device_config).run()
|
||||
data = ExtractSubprocessor ([ ExtractSubprocessor.Data(Path(filename)) for filename in input_path_image_paths ], 'landmarks-manual', image_size, face_type, output_debug_path if output_debug else None, manual_window_size=manual_window_size, device_config=device_config).run()
|
||||
|
||||
io.log_info ('Performing 3rd pass...')
|
||||
data = ExtractSubprocessor (data, 'final', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None, final_output_path=output_path, device_config=device_config).run()
|
||||
data = ExtractSubprocessor (data, 'final', image_size, face_type, output_debug_path if output_debug else None, final_output_path=output_path, device_config=device_config).run()
|
||||
|
||||
else:
|
||||
io.log_info ('Extracting faces...')
|
||||
data = ExtractSubprocessor ([ ExtractSubprocessor.Data(Path(filename)) for filename in input_image_paths ],
|
||||
data = ExtractSubprocessor ([ ExtractSubprocessor.Data(Path(filename)) for filename in input_path_image_paths ],
|
||||
'all',
|
||||
image_size,
|
||||
jpeg_quality,
|
||||
face_type,
|
||||
output_debug_path if output_debug else None,
|
||||
max_faces_from_image=max_faces_from_image,
|
||||
|
@ -834,8 +787,8 @@ def main(detector=None,
|
|||
else:
|
||||
fix_data = [ ExtractSubprocessor.Data(d.filepath) for d in data if d.faces_detected == 0 ]
|
||||
io.log_info ('Performing manual fix for %d images...' % (len(fix_data)) )
|
||||
fix_data = ExtractSubprocessor (fix_data, 'landmarks-manual', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None, manual_window_size=manual_window_size, device_config=device_config).run()
|
||||
fix_data = ExtractSubprocessor (fix_data, 'final', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None, final_output_path=output_path, device_config=device_config).run()
|
||||
fix_data = ExtractSubprocessor (fix_data, 'landmarks-manual', image_size, face_type, output_debug_path if output_debug else None, manual_window_size=manual_window_size, device_config=device_config).run()
|
||||
fix_data = ExtractSubprocessor (fix_data, 'final', image_size, face_type, output_debug_path if output_debug else None, final_output_path=output_path, device_config=device_config).run()
|
||||
faces_detected += sum([d.faces_detected for d in fix_data])
|
||||
|
||||
|
||||
|
|
|
@ -1,209 +0,0 @@
|
|||
import multiprocessing
|
||||
import shutil
|
||||
|
||||
import cv2
|
||||
from core import pathex
|
||||
from core.cv2ex import *
|
||||
from core.interact import interact as io
|
||||
from core.joblib import Subprocessor
|
||||
from DFLIMG import *
|
||||
from facelib import FaceType, LandmarksProcessor
|
||||
|
||||
|
||||
class FacesetResizerSubprocessor(Subprocessor):
|
||||
|
||||
#override
|
||||
def __init__(self, image_paths, output_dirpath, image_size, face_type=None):
|
||||
self.image_paths = image_paths
|
||||
self.output_dirpath = output_dirpath
|
||||
self.image_size = image_size
|
||||
self.face_type = face_type
|
||||
self.result = []
|
||||
|
||||
super().__init__('FacesetResizer', FacesetResizerSubprocessor.Cli, 600)
|
||||
|
||||
#override
|
||||
def on_clients_initialized(self):
|
||||
io.progress_bar (None, len (self.image_paths))
|
||||
|
||||
#override
|
||||
def on_clients_finalized(self):
|
||||
io.progress_bar_close()
|
||||
|
||||
#override
|
||||
def process_info_generator(self):
|
||||
base_dict = {'output_dirpath':self.output_dirpath, 'image_size':self.image_size, 'face_type':self.face_type}
|
||||
|
||||
for device_idx in range( min(8, multiprocessing.cpu_count()) ):
|
||||
client_dict = base_dict.copy()
|
||||
device_name = f'CPU #{device_idx}'
|
||||
client_dict['device_name'] = device_name
|
||||
yield device_name, {}, client_dict
|
||||
|
||||
#override
|
||||
def get_data(self, host_dict):
|
||||
if len (self.image_paths) > 0:
|
||||
return self.image_paths.pop(0)
|
||||
|
||||
#override
|
||||
def on_data_return (self, host_dict, data):
|
||||
self.image_paths.insert(0, data)
|
||||
|
||||
#override
|
||||
def on_result (self, host_dict, data, result):
|
||||
io.progress_bar_inc(1)
|
||||
if result[0] == 1:
|
||||
self.result +=[ (result[1], result[2]) ]
|
||||
|
||||
#override
|
||||
def get_result(self):
|
||||
return self.result
|
||||
|
||||
class Cli(Subprocessor.Cli):
|
||||
|
||||
#override
|
||||
def on_initialize(self, client_dict):
|
||||
self.output_dirpath = client_dict['output_dirpath']
|
||||
self.image_size = client_dict['image_size']
|
||||
self.face_type = client_dict['face_type']
|
||||
self.log_info (f"Running on { client_dict['device_name'] }")
|
||||
|
||||
#override
|
||||
def process_data(self, filepath):
|
||||
try:
|
||||
dflimg = DFLIMG.load (filepath)
|
||||
if dflimg is None or not dflimg.has_data():
|
||||
self.log_err (f"{filepath.name} is not a dfl image file")
|
||||
else:
|
||||
img = cv2_imread(filepath)
|
||||
h,w = img.shape[:2]
|
||||
if h != w:
|
||||
raise Exception(f'w != h in {filepath}')
|
||||
|
||||
image_size = self.image_size
|
||||
face_type = self.face_type
|
||||
output_filepath = self.output_dirpath / filepath.name
|
||||
|
||||
if face_type is not None:
|
||||
lmrks = dflimg.get_landmarks()
|
||||
mat = LandmarksProcessor.get_transform_mat(lmrks, image_size, face_type)
|
||||
|
||||
img = cv2.warpAffine(img, mat, (image_size, image_size), flags=cv2.INTER_LANCZOS4 )
|
||||
img = np.clip(img, 0, 255).astype(np.uint8)
|
||||
|
||||
cv2_imwrite ( str(output_filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] )
|
||||
|
||||
dfl_dict = dflimg.get_dict()
|
||||
dflimg = DFLIMG.load (output_filepath)
|
||||
dflimg.set_dict(dfl_dict)
|
||||
|
||||
xseg_mask = dflimg.get_xseg_mask()
|
||||
if xseg_mask is not None:
|
||||
xseg_res = 256
|
||||
|
||||
xseg_lmrks = lmrks.copy()
|
||||
xseg_lmrks *= (xseg_res / w)
|
||||
xseg_mat = LandmarksProcessor.get_transform_mat(xseg_lmrks, xseg_res, face_type)
|
||||
|
||||
xseg_mask = cv2.warpAffine(xseg_mask, xseg_mat, (xseg_res, xseg_res), flags=cv2.INTER_LANCZOS4 )
|
||||
xseg_mask[xseg_mask < 0.5] = 0
|
||||
xseg_mask[xseg_mask >= 0.5] = 1
|
||||
|
||||
dflimg.set_xseg_mask(xseg_mask)
|
||||
|
||||
seg_ie_polys = dflimg.get_seg_ie_polys()
|
||||
|
||||
for poly in seg_ie_polys.get_polys():
|
||||
poly_pts = poly.get_pts()
|
||||
poly_pts = LandmarksProcessor.transform_points(poly_pts, mat)
|
||||
poly.set_points(poly_pts)
|
||||
|
||||
dflimg.set_seg_ie_polys(seg_ie_polys)
|
||||
|
||||
lmrks = LandmarksProcessor.transform_points(lmrks, mat)
|
||||
dflimg.set_landmarks(lmrks)
|
||||
|
||||
image_to_face_mat = dflimg.get_image_to_face_mat()
|
||||
if image_to_face_mat is not None:
|
||||
image_to_face_mat = LandmarksProcessor.get_transform_mat ( dflimg.get_source_landmarks(), image_size, face_type )
|
||||
dflimg.set_image_to_face_mat(image_to_face_mat)
|
||||
dflimg.set_face_type( FaceType.toString(face_type) )
|
||||
dflimg.save()
|
||||
|
||||
else:
|
||||
dfl_dict = dflimg.get_dict()
|
||||
|
||||
scale = w / image_size
|
||||
|
||||
img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_LANCZOS4)
|
||||
|
||||
cv2_imwrite ( str(output_filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] )
|
||||
|
||||
dflimg = DFLIMG.load (output_filepath)
|
||||
dflimg.set_dict(dfl_dict)
|
||||
|
||||
lmrks = dflimg.get_landmarks()
|
||||
lmrks /= scale
|
||||
dflimg.set_landmarks(lmrks)
|
||||
|
||||
seg_ie_polys = dflimg.get_seg_ie_polys()
|
||||
seg_ie_polys.mult_points( 1.0 / scale)
|
||||
dflimg.set_seg_ie_polys(seg_ie_polys)
|
||||
|
||||
image_to_face_mat = dflimg.get_image_to_face_mat()
|
||||
|
||||
if image_to_face_mat is not None:
|
||||
face_type = FaceType.fromString ( dflimg.get_face_type() )
|
||||
image_to_face_mat = LandmarksProcessor.get_transform_mat ( dflimg.get_source_landmarks(), image_size, face_type )
|
||||
dflimg.set_image_to_face_mat(image_to_face_mat)
|
||||
dflimg.save()
|
||||
|
||||
return (1, filepath, output_filepath)
|
||||
except:
|
||||
self.log_err (f"Exception occured while processing file {filepath}. Error: {traceback.format_exc()}")
|
||||
|
||||
return (0, filepath, None)
|
||||
|
||||
def process_folder ( dirpath):
|
||||
|
||||
image_size = io.input_int(f"New image size", 512, valid_range=[128,2048])
|
||||
|
||||
face_type = io.input_str ("Change face type", 'same', ['h','mf','f','wf','head','same']).lower()
|
||||
if face_type == 'same':
|
||||
face_type = None
|
||||
else:
|
||||
face_type = {'h' : FaceType.HALF,
|
||||
'mf' : FaceType.MID_FULL,
|
||||
'f' : FaceType.FULL,
|
||||
'wf' : FaceType.WHOLE_FACE,
|
||||
'head' : FaceType.HEAD}[face_type]
|
||||
|
||||
|
||||
output_dirpath = dirpath.parent / (dirpath.name + '_resized')
|
||||
output_dirpath.mkdir (exist_ok=True, parents=True)
|
||||
|
||||
dirpath_parts = '/'.join( dirpath.parts[-2:])
|
||||
output_dirpath_parts = '/'.join( output_dirpath.parts[-2:] )
|
||||
io.log_info (f"Resizing faceset in {dirpath_parts}")
|
||||
io.log_info ( f"Processing to {output_dirpath_parts}")
|
||||
|
||||
output_images_paths = pathex.get_image_paths(output_dirpath)
|
||||
if len(output_images_paths) > 0:
|
||||
for filename in output_images_paths:
|
||||
Path(filename).unlink()
|
||||
|
||||
image_paths = [Path(x) for x in pathex.get_image_paths( dirpath )]
|
||||
result = FacesetResizerSubprocessor ( image_paths, output_dirpath, image_size, face_type).run()
|
||||
|
||||
is_merge = io.input_bool (f"\r\nMerge {output_dirpath_parts} to {dirpath_parts} ?", True)
|
||||
if is_merge:
|
||||
io.log_info (f"Copying processed files to {dirpath_parts}")
|
||||
|
||||
for (filepath, output_filepath) in result:
|
||||
try:
|
||||
shutil.copy (output_filepath, filepath)
|
||||
except:
|
||||
pass
|
||||
|
||||
io.log_info (f"Removing {output_dirpath_parts}")
|
||||
shutil.rmtree(output_dirpath)
|
|
@ -1,5 +1,4 @@
|
|||
import math
|
||||
import multiprocessing
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -14,8 +13,7 @@ from core.joblib import MPClassFuncOnDemand, MPFunc
|
|||
from core.leras import nn
|
||||
from DFLIMG import DFLIMG
|
||||
from facelib import FaceEnhancer, FaceType, LandmarksProcessor, XSegNet
|
||||
from merger import FrameInfo, InteractiveMergerSubprocessor, MergerConfig
|
||||
|
||||
from merger import FrameInfo, MergerConfig, InteractiveMergerSubprocessor
|
||||
|
||||
def main (model_class_name=None,
|
||||
saved_models_path=None,
|
||||
|
@ -49,7 +47,6 @@ def main (model_class_name=None,
|
|||
model = models.import_model(model_class_name)(is_training=False,
|
||||
saved_models_path=saved_models_path,
|
||||
force_gpu_idxs=force_gpu_idxs,
|
||||
force_model_name=force_model_name,
|
||||
cpu_only=cpu_only)
|
||||
|
||||
predictor_func, predictor_input_shape, cfg = model.get_MergerConfig()
|
||||
|
@ -74,9 +71,6 @@ def main (model_class_name=None,
|
|||
if not is_interactive:
|
||||
cfg.ask_settings()
|
||||
|
||||
subprocess_count = io.input_int("Number of workers?", max(8, multiprocessing.cpu_count()),
|
||||
valid_range=[1, multiprocessing.cpu_count()], help_message="Specify the number of threads to process. A low value may affect performance. A high value may result in memory error. The value may not be greater than CPU cores." )
|
||||
|
||||
input_path_image_paths = pathex.get_image_paths(input_path)
|
||||
|
||||
if cfg.type == MergerConfig.TYPE_MASKED:
|
||||
|
@ -205,8 +199,7 @@ def main (model_class_name=None,
|
|||
frames_root_path = input_path,
|
||||
output_path = output_path,
|
||||
output_mask_path = output_mask_path,
|
||||
model_iter = model.get_iter(),
|
||||
subprocess_count = subprocess_count,
|
||||
model_iter = model.get_iter()
|
||||
).run()
|
||||
|
||||
model.finalize()
|
||||
|
|
|
@ -23,9 +23,6 @@ from facelib import LandmarksProcessor
|
|||
|
||||
class BlurEstimatorSubprocessor(Subprocessor):
|
||||
class Cli(Subprocessor.Cli):
|
||||
def on_initialize(self, client_dict):
|
||||
self.estimate_motion_blur = client_dict['estimate_motion_blur']
|
||||
|
||||
#override
|
||||
def process_data(self, data):
|
||||
filepath = Path( data[0] )
|
||||
|
@ -36,17 +33,7 @@ class BlurEstimatorSubprocessor(Subprocessor):
|
|||
return [ str(filepath), 0 ]
|
||||
else:
|
||||
image = cv2_imread( str(filepath) )
|
||||
|
||||
face_mask = LandmarksProcessor.get_image_hull_mask (image.shape, dflimg.get_landmarks())
|
||||
image = (image*face_mask).astype(np.uint8)
|
||||
|
||||
|
||||
if self.estimate_motion_blur:
|
||||
value = cv2.Laplacian(image, cv2.CV_64F, ksize=11).var()
|
||||
else:
|
||||
value = estimate_sharpness(image)
|
||||
|
||||
return [ str(filepath), value ]
|
||||
return [ str(filepath), estimate_sharpness(image) ]
|
||||
|
||||
|
||||
#override
|
||||
|
@ -55,9 +42,8 @@ class BlurEstimatorSubprocessor(Subprocessor):
|
|||
return data[0]
|
||||
|
||||
#override
|
||||
def __init__(self, input_data, estimate_motion_blur=False ):
|
||||
def __init__(self, input_data ):
|
||||
self.input_data = input_data
|
||||
self.estimate_motion_blur = estimate_motion_blur
|
||||
self.img_list = []
|
||||
self.trash_img_list = []
|
||||
super().__init__('BlurEstimator', BlurEstimatorSubprocessor.Cli, 60)
|
||||
|
@ -76,7 +62,7 @@ class BlurEstimatorSubprocessor(Subprocessor):
|
|||
io.log_info(f'Running on {cpu_count} CPUs')
|
||||
|
||||
for i in range(cpu_count):
|
||||
yield 'CPU%d' % (i), {}, {'estimate_motion_blur':self.estimate_motion_blur}
|
||||
yield 'CPU%d' % (i), {}, {}
|
||||
|
||||
#override
|
||||
def get_data(self, host_dict):
|
||||
|
@ -114,17 +100,6 @@ def sort_by_blur(input_path):
|
|||
|
||||
return img_list, trash_img_list
|
||||
|
||||
def sort_by_motion_blur(input_path):
|
||||
io.log_info ("Sorting by motion blur...")
|
||||
|
||||
img_list = [ (filename,[]) for filename in pathex.get_image_paths(input_path) ]
|
||||
img_list, trash_img_list = BlurEstimatorSubprocessor (img_list, estimate_motion_blur=True).run()
|
||||
|
||||
io.log_info ("Sorting...")
|
||||
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
|
||||
|
||||
return img_list, trash_img_list
|
||||
|
||||
def sort_by_face_yaw(input_path):
|
||||
io.log_info ("Sorting by face yaw...")
|
||||
img_list = []
|
||||
|
@ -468,12 +443,12 @@ class FinalLoaderSubprocessor(Subprocessor):
|
|||
raise Exception ("Unable to load %s" % (filepath.name) )
|
||||
|
||||
gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
if self.faster:
|
||||
source_rect = dflimg.get_source_rect()
|
||||
sharpness = mathlib.polygon_area(np.array(source_rect[[0,2,2,0]]).astype(np.float32), np.array(source_rect[[1,1,3,3]]).astype(np.float32))
|
||||
else:
|
||||
face_mask = LandmarksProcessor.get_image_hull_mask (gray.shape, dflimg.get_landmarks())
|
||||
sharpness = estimate_sharpness( (gray[...,None]*face_mask).astype(np.uint8) )
|
||||
sharpness = estimate_sharpness(gray)
|
||||
|
||||
pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks(), size=dflimg.get_shape()[1] )
|
||||
|
||||
|
@ -897,7 +872,6 @@ def final_process(input_path, img_list, trash_img_list):
|
|||
|
||||
sort_func_methods = {
|
||||
'blur': ("blur", sort_by_blur),
|
||||
'motion-blur': ("motion_blur", sort_by_motion_blur),
|
||||
'face-yaw': ("face yaw direction", sort_by_face_yaw),
|
||||
'face-pitch': ("face pitch direction", sort_by_face_pitch),
|
||||
'face-source-rect-size' : ("face rect size in source image", sort_by_face_source_rect_size),
|
||||
|
@ -925,7 +899,7 @@ def main (input_path, sort_by_method=None):
|
|||
io.log_info(f"[{i}] {desc}")
|
||||
|
||||
io.log_info("")
|
||||
id = io.input_int("", 5, valid_list=[*range(len(key_list))] )
|
||||
id = io.input_int("", 4, valid_list=[*range(len(key_list))] )
|
||||
|
||||
sort_by_method = key_list[id]
|
||||
else:
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import os
|
||||
import sys
|
||||
import sys
|
||||
import traceback
|
||||
import queue
|
||||
import threading
|
||||
|
@ -32,7 +31,7 @@ def trainerThread (s2c, c2s, e,
|
|||
try:
|
||||
start_time = time.time()
|
||||
|
||||
save_interval_min = 25
|
||||
save_interval_min = 15
|
||||
|
||||
if not training_data_src_path.exists():
|
||||
training_data_src_path.mkdir(exist_ok=True, parents=True)
|
||||
|
@ -55,7 +54,8 @@ def trainerThread (s2c, c2s, e,
|
|||
force_gpu_idxs=force_gpu_idxs,
|
||||
cpu_only=cpu_only,
|
||||
silent_start=silent_start,
|
||||
debug=debug)
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
is_reached_goal = model.is_reached_iter_goal()
|
||||
|
||||
|
@ -120,12 +120,6 @@ def trainerThread (s2c, c2s, e,
|
|||
io.log_info("Trying to do the first iteration. If an error occurs, reduce the model parameters.")
|
||||
io.log_info("")
|
||||
|
||||
if sys.platform[0:3] == 'win':
|
||||
io.log_info("!!!")
|
||||
io.log_info("Windows 10 users IMPORTANT notice. You should set this setting in order to work correctly.")
|
||||
io.log_info("https://i.imgur.com/B7cmDCB.jpg")
|
||||
io.log_info("!!!")
|
||||
|
||||
iter, iter_time = model.train_one_iter()
|
||||
|
||||
loss_history = model.get_loss_history()
|
||||
|
@ -164,12 +158,8 @@ def trainerThread (s2c, c2s, e,
|
|||
is_reached_goal = True
|
||||
io.log_info ('You can use preview now.')
|
||||
|
||||
need_save = False
|
||||
while time.time() - last_save_time >= save_interval_min*60:
|
||||
if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
|
||||
last_save_time += save_interval_min*60
|
||||
need_save = True
|
||||
|
||||
if not is_reached_goal and need_save:
|
||||
model_save()
|
||||
send_preview()
|
||||
|
||||
|
|
|
@ -64,7 +64,7 @@ def restore_faceset_metadata_folder(input_path):
|
|||
|
||||
img = cv2_imread (filepath)
|
||||
if img.shape != shape:
|
||||
img = cv2.resize (img, (shape[1], shape[0]), interpolation=cv2.INTER_LANCZOS4 )
|
||||
img = cv2.resize (img, (shape[1], shape[0]), cv2.INTER_LANCZOS4 )
|
||||
|
||||
cv2_imwrite (str(filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] )
|
||||
|
||||
|
@ -159,32 +159,3 @@ def recover_original_aligned_filename(input_path):
|
|||
fs.rename (fd)
|
||||
except:
|
||||
io.log_err ('fail to rename %s' % (fs.name) )
|
||||
|
||||
def export_faceset_mask(input_dir):
|
||||
for filename in io.progress_bar_generator(pathex.get_image_paths (input_dir), "Processing"):
|
||||
filepath = Path(filename)
|
||||
|
||||
if '_mask' in filepath.stem:
|
||||
continue
|
||||
|
||||
mask_filepath = filepath.parent / (filepath.stem+'_mask'+filepath.suffix)
|
||||
|
||||
dflimg = DFLJPG.load(filepath)
|
||||
|
||||
H,W,C = dflimg.shape
|
||||
|
||||
seg_ie_polys = dflimg.get_seg_ie_polys()
|
||||
|
||||
if dflimg.has_xseg_mask():
|
||||
mask = dflimg.get_xseg_mask()
|
||||
mask[mask < 0.5] = 0.0
|
||||
mask[mask >= 0.5] = 1.0
|
||||
elif seg_ie_polys.has_polys():
|
||||
mask = np.zeros ((H,W,1), dtype=np.float32)
|
||||
seg_ie_polys.overlay_mask(mask)
|
||||
else:
|
||||
raise Exception(f'no mask in file {filepath}')
|
||||
|
||||
|
||||
cv2_imwrite(mask_filepath, (mask*255).astype(np.uint8), [int(cv2.IMWRITE_JPEG_QUALITY), 100] )
|
||||
|
|
@ -10,8 +10,8 @@ from core.cv2ex import *
|
|||
from core.interact import interact as io
|
||||
from core.leras import nn
|
||||
from DFLIMG import *
|
||||
from facelib import XSegNet, LandmarksProcessor, FaceType
|
||||
import pickle
|
||||
from facelib import XSegNet
|
||||
|
||||
|
||||
def apply_xseg(input_path, model_path):
|
||||
if not input_path.exists():
|
||||
|
@ -20,42 +20,17 @@ def apply_xseg(input_path, model_path):
|
|||
if not model_path.exists():
|
||||
raise ValueError(f'{model_path} not found. Please ensure it exists.')
|
||||
|
||||
face_type = None
|
||||
|
||||
model_dat = model_path / 'XSeg_data.dat'
|
||||
if model_dat.exists():
|
||||
dat = pickle.loads( model_dat.read_bytes() )
|
||||
dat_options = dat.get('options', None)
|
||||
if dat_options is not None:
|
||||
face_type = dat_options.get('face_type', None)
|
||||
|
||||
|
||||
|
||||
if face_type is None:
|
||||
face_type = io.input_str ("XSeg model face type", 'same', ['h','mf','f','wf','head','same'], help_message="Specify face type of trained XSeg model. For example if XSeg model trained as WF, but faceset is HEAD, specify WF to apply xseg only on WF part of HEAD. Default is 'same'").lower()
|
||||
if face_type == 'same':
|
||||
face_type = None
|
||||
|
||||
if face_type is not None:
|
||||
face_type = {'h' : FaceType.HALF,
|
||||
'mf' : FaceType.MID_FULL,
|
||||
'f' : FaceType.FULL,
|
||||
'wf' : FaceType.WHOLE_FACE,
|
||||
'head' : FaceType.HEAD}[face_type]
|
||||
|
||||
io.log_info(f'Applying trained XSeg model to {input_path.name}/ folder.')
|
||||
|
||||
device_config = nn.DeviceConfig.ask_choose_device(choose_only_one=True)
|
||||
nn.initialize(device_config)
|
||||
|
||||
|
||||
|
||||
xseg = XSegNet(name='XSeg',
|
||||
load_weights=True,
|
||||
weights_file_root=model_path,
|
||||
data_format=nn.data_format,
|
||||
raise_on_no_model_files=True)
|
||||
xseg_res = xseg.get_resolution()
|
||||
res = xseg.get_resolution()
|
||||
|
||||
images_paths = pathex.get_image_paths(input_path, return_Path_class=True)
|
||||
|
||||
|
@ -67,36 +42,15 @@ def apply_xseg(input_path, model_path):
|
|||
|
||||
img = cv2_imread(filepath).astype(np.float32) / 255.0
|
||||
h,w,c = img.shape
|
||||
|
||||
img_face_type = FaceType.fromString( dflimg.get_face_type() )
|
||||
if face_type is not None and img_face_type != face_type:
|
||||
lmrks = dflimg.get_source_landmarks()
|
||||
|
||||
fmat = LandmarksProcessor.get_transform_mat(lmrks, w, face_type)
|
||||
imat = LandmarksProcessor.get_transform_mat(lmrks, w, img_face_type)
|
||||
|
||||
g_p = LandmarksProcessor.transform_points (np.float32([(0,0),(w,0),(0,w) ]), fmat, True)
|
||||
g_p2 = LandmarksProcessor.transform_points (g_p, imat)
|
||||
|
||||
mat = cv2.getAffineTransform( g_p2, np.float32([(0,0),(w,0),(0,w) ]) )
|
||||
|
||||
img = cv2.warpAffine(img, mat, (w, w), cv2.INTER_LANCZOS4)
|
||||
img = cv2.resize(img, (xseg_res, xseg_res), interpolation=cv2.INTER_LANCZOS4)
|
||||
else:
|
||||
if w != xseg_res:
|
||||
img = cv2.resize( img, (xseg_res,xseg_res), interpolation=cv2.INTER_LANCZOS4 )
|
||||
|
||||
if len(img.shape) == 2:
|
||||
img = img[...,None]
|
||||
if w != res:
|
||||
img = cv2.resize( img, (res,res), interpolation=cv2.INTER_CUBIC )
|
||||
if len(img.shape) == 2:
|
||||
img = img[...,None]
|
||||
|
||||
mask = xseg.extract(img)
|
||||
|
||||
if face_type is not None and img_face_type != face_type:
|
||||
mask = cv2.resize(mask, (w, w), interpolation=cv2.INTER_LANCZOS4)
|
||||
mask = cv2.warpAffine( mask, mat, (w,w), np.zeros( (h,w,c), dtype=np.float), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4)
|
||||
mask = cv2.resize(mask, (xseg_res, xseg_res), interpolation=cv2.INTER_LANCZOS4)
|
||||
mask[mask < 0.5]=0
|
||||
mask[mask >= 0.5]=1
|
||||
|
||||
dflimg.set_xseg_mask(mask)
|
||||
dflimg.save()
|
||||
|
||||
|
@ -113,8 +67,7 @@ def fetch_xseg(input_path):
|
|||
|
||||
images_paths = pathex.get_image_paths(input_path, return_Path_class=True)
|
||||
|
||||
|
||||
files_copied = []
|
||||
files_copied = 0
|
||||
for filepath in io.progress_bar_generator(images_paths, "Processing"):
|
||||
dflimg = DFLIMG.load(filepath)
|
||||
if dflimg is None or not dflimg.has_data():
|
||||
|
@ -124,16 +77,10 @@ def fetch_xseg(input_path):
|
|||
ie_polys = dflimg.get_seg_ie_polys()
|
||||
|
||||
if ie_polys.has_polys():
|
||||
files_copied.append(filepath)
|
||||
files_copied += 1
|
||||
shutil.copy ( str(filepath), str(output_path / filepath.name) )
|
||||
|
||||
io.log_info(f'Files copied: {len(files_copied)}')
|
||||
|
||||
is_delete = io.input_bool (f"\r\nDelete original files?", True)
|
||||
if is_delete:
|
||||
for filepath in files_copied:
|
||||
Path(filepath).unlink()
|
||||
|
||||
io.log_info(f'Files copied: {files_copied}')
|
||||
|
||||
def remove_xseg(input_path):
|
||||
if not input_path.exists():
|
||||
|
|
|
@ -13,6 +13,7 @@ from core.joblib import Subprocessor
|
|||
from core.leras import nn
|
||||
from DFLIMG import *
|
||||
from facelib import FaceType, LandmarksProcessor
|
||||
|
||||
from . import Extractor, Sorter
|
||||
from .Extractor import ExtractSubprocessor
|
||||
|
||||
|
@ -356,114 +357,25 @@ def extract_umd_csv(input_file_csv,
|
|||
io.log_info ('Faces detected: %d' % (faces_detected) )
|
||||
io.log_info ('-------------------------')
|
||||
|
||||
|
||||
|
||||
def dev_test1(input_dir):
|
||||
# LaPa dataset
|
||||
|
||||
image_size = 1024
|
||||
face_type = FaceType.HEAD
|
||||
|
||||
input_path = Path(input_dir)
|
||||
images_path = input_path / 'images'
|
||||
if not images_path.exists:
|
||||
raise ValueError('LaPa dataset: images folder not found.')
|
||||
labels_path = input_path / 'labels'
|
||||
if not labels_path.exists:
|
||||
raise ValueError('LaPa dataset: labels folder not found.')
|
||||
landmarks_path = input_path / 'landmarks'
|
||||
if not landmarks_path.exists:
|
||||
raise ValueError('LaPa dataset: landmarks folder not found.')
|
||||
|
||||
output_path = input_path / 'out'
|
||||
if output_path.exists():
|
||||
output_images_paths = pathex.get_image_paths(output_path)
|
||||
if len(output_images_paths) != 0:
|
||||
io.input(f"\n WARNING !!! \n {output_path} contains files! \n They will be deleted. \n Press enter to continue.\n")
|
||||
for filename in output_images_paths:
|
||||
Path(filename).unlink()
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
dir_names = pathex.get_all_dir_names(input_path)
|
||||
|
||||
data = []
|
||||
for dir_name in io.progress_bar_generator(dir_names, desc="Processing"):
|
||||
|
||||
img_paths = pathex.get_image_paths (images_path)
|
||||
for filename in img_paths:
|
||||
filepath = Path(filename)
|
||||
|
||||
landmark_filepath = landmarks_path / (filepath.stem + '.txt')
|
||||
if not landmark_filepath.exists():
|
||||
raise ValueError(f'no landmarks for {filepath}')
|
||||
|
||||
#img = cv2_imread(filepath)
|
||||
|
||||
lm = landmark_filepath.read_text()
|
||||
lm = lm.split('\n')
|
||||
if int(lm[0]) != 106:
|
||||
raise ValueError(f'wrong landmarks format in {landmark_filepath}')
|
||||
|
||||
lmrks = []
|
||||
for i in range(106):
|
||||
x,y = lm[i+1].split(' ')
|
||||
x,y = float(x), float(y)
|
||||
lmrks.append ( (x,y) )
|
||||
|
||||
lmrks = np.array(lmrks)
|
||||
|
||||
l,t = np.min(lmrks, 0)
|
||||
r,b = np.max(lmrks, 0)
|
||||
|
||||
l,t,r,b = ( int(x) for x in (l,t,r,b) )
|
||||
|
||||
#for x, y in lmrks:
|
||||
# x,y = int(x), int(y)
|
||||
# cv2.circle(img, (x, y), 1, (0,255,0) , 1, lineType=cv2.LINE_AA)
|
||||
|
||||
#imagelib.draw_rect(img, (l,t,r,b), (0,255,0) )
|
||||
|
||||
|
||||
data += [ ExtractSubprocessor.Data(filepath=filepath, rects=[ (l,t,r,b) ]) ]
|
||||
|
||||
#cv2.imshow("", img)
|
||||
#cv2.waitKey(0)
|
||||
|
||||
if len(data) > 0:
|
||||
device_config = nn.DeviceConfig.BestGPU()
|
||||
|
||||
io.log_info ("Performing 2nd pass...")
|
||||
data = ExtractSubprocessor (data, 'landmarks', image_size, 95, face_type, device_config=device_config).run()
|
||||
io.log_info ("Performing 3rd pass...")
|
||||
data = ExtractSubprocessor (data, 'final', image_size, 95, face_type, final_output_path=output_path, device_config=device_config).run()
|
||||
|
||||
|
||||
for filename in pathex.get_image_paths (output_path):
|
||||
img_paths = pathex.get_image_paths (input_path / dir_name)
|
||||
for filename in img_paths:
|
||||
filepath = Path(filename)
|
||||
|
||||
dflimg = DFLIMG.x (filepath)
|
||||
if dflimg is None:
|
||||
raise ValueError
|
||||
|
||||
dflimg = DFLJPG.load(filepath)
|
||||
|
||||
src_filename = dflimg.get_source_filename()
|
||||
image_to_face_mat = dflimg.get_image_to_face_mat()
|
||||
|
||||
label_filepath = labels_path / ( Path(src_filename).stem + '.png')
|
||||
if not label_filepath.exists():
|
||||
raise ValueError(f'{label_filepath} does not exist')
|
||||
|
||||
mask = cv2_imread(label_filepath)
|
||||
#mask[mask == 10] = 0 # remove hair
|
||||
mask[mask > 0] = 1
|
||||
mask = cv2.warpAffine(mask, image_to_face_mat, (image_size, image_size), cv2.INTER_LINEAR)
|
||||
mask = cv2.blur(mask, (3,3) )
|
||||
|
||||
#cv2.imshow("", (mask*255).astype(np.uint8) )
|
||||
#cv2.waitKey(0)
|
||||
|
||||
dflimg.set_xseg_mask(mask)
|
||||
dflimg.save()
|
||||
|
||||
|
||||
import code
|
||||
code.interact(local=dict(globals(), **locals()))
|
||||
#dflimg.x(filename, person_name=dir_name)
|
||||
|
||||
#import code
|
||||
#code.interact(local=dict(globals(), **locals()))
|
||||
|
||||
def dev_resave_pngs(input_dir):
|
||||
input_path = Path(input_dir)
|
||||
|
@ -499,96 +411,3 @@ def dev_segmented_trash(input_dir):
|
|||
except:
|
||||
io.log_info ('fail to trashing %s' % (src.name) )
|
||||
|
||||
|
||||
|
||||
def dev_test(input_dir):
|
||||
"""
|
||||
extract FaceSynthetics dataset https://github.com/microsoft/FaceSynthetics
|
||||
|
||||
BACKGROUND = 0
|
||||
SKIN = 1
|
||||
NOSE = 2
|
||||
RIGHT_EYE = 3
|
||||
LEFT_EYE = 4
|
||||
RIGHT_BROW = 5
|
||||
LEFT_BROW = 6
|
||||
RIGHT_EAR = 7
|
||||
LEFT_EAR = 8
|
||||
MOUTH_INTERIOR = 9
|
||||
TOP_LIP = 10
|
||||
BOTTOM_LIP = 11
|
||||
NECK = 12
|
||||
HAIR = 13
|
||||
BEARD = 14
|
||||
CLOTHING = 15
|
||||
GLASSES = 16
|
||||
HEADWEAR = 17
|
||||
FACEWEAR = 18
|
||||
IGNORE = 255
|
||||
"""
|
||||
|
||||
|
||||
image_size = 1024
|
||||
face_type = FaceType.WHOLE_FACE
|
||||
|
||||
input_path = Path(input_dir)
|
||||
|
||||
|
||||
|
||||
output_path = input_path.parent / f'{input_path.name}_out'
|
||||
if output_path.exists():
|
||||
output_images_paths = pathex.get_image_paths(output_path)
|
||||
if len(output_images_paths) != 0:
|
||||
io.input(f"\n WARNING !!! \n {output_path} contains files! \n They will be deleted. \n Press enter to continue.\n")
|
||||
for filename in output_images_paths:
|
||||
Path(filename).unlink()
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
data = []
|
||||
|
||||
for filepath in io.progress_bar_generator(pathex.get_paths(input_path), "Processing"):
|
||||
if filepath.suffix == '.txt':
|
||||
|
||||
image_filepath = filepath.parent / f'{filepath.name.split("_")[0]}.png'
|
||||
if not image_filepath.exists():
|
||||
print(f'{image_filepath} does not exist, skipping')
|
||||
|
||||
lmrks = []
|
||||
for lmrk_line in filepath.read_text().split('\n'):
|
||||
if len(lmrk_line) == 0:
|
||||
continue
|
||||
|
||||
x, y = lmrk_line.split(' ')
|
||||
x, y = float(x), float(y)
|
||||
|
||||
lmrks.append( (x,y) )
|
||||
|
||||
lmrks = np.array(lmrks[:68], np.float32)
|
||||
rect = LandmarksProcessor.get_rect_from_landmarks(lmrks)
|
||||
data += [ ExtractSubprocessor.Data(filepath=image_filepath, rects=[rect], landmarks=[ lmrks ] ) ]
|
||||
|
||||
if len(data) > 0:
|
||||
io.log_info ("Performing 3rd pass...")
|
||||
data = ExtractSubprocessor (data, 'final', image_size, 95, face_type, final_output_path=output_path, device_config=nn.DeviceConfig.CPU()).run()
|
||||
|
||||
for filename in io.progress_bar_generator(pathex.get_image_paths (output_path), "Processing"):
|
||||
filepath = Path(filename)
|
||||
|
||||
dflimg = DFLJPG.load(filepath)
|
||||
|
||||
src_filename = dflimg.get_source_filename()
|
||||
image_to_face_mat = dflimg.get_image_to_face_mat()
|
||||
|
||||
seg_filepath = input_path / ( Path(src_filename).stem + '_seg.png')
|
||||
if not seg_filepath.exists():
|
||||
raise ValueError(f'{seg_filepath} does not exist')
|
||||
|
||||
seg = cv2_imread(seg_filepath)
|
||||
seg_inds = np.isin(seg, [1,2,3,4,5,6,9,10,11])
|
||||
seg[~seg_inds] = 0
|
||||
seg[seg_inds] = 1
|
||||
seg = seg.astype(np.float32)
|
||||
seg = cv2.warpAffine(seg, image_to_face_mat, (image_size, image_size), cv2.INTER_LANCZOS4)
|
||||
dflimg.set_xseg_mask(seg)
|
||||
dflimg.save()
|
||||
|
|
@ -140,7 +140,7 @@ class InteractiveMergerSubprocessor(Subprocessor):
|
|||
|
||||
|
||||
#override
|
||||
def __init__(self, is_interactive, merger_session_filepath, predictor_func, predictor_input_shape, face_enhancer_func, xseg_256_extract_func, merger_config, frames, frames_root_path, output_path, output_mask_path, model_iter, subprocess_count=4):
|
||||
def __init__(self, is_interactive, merger_session_filepath, predictor_func, predictor_input_shape, face_enhancer_func, xseg_256_extract_func, merger_config, frames, frames_root_path, output_path, output_mask_path, model_iter):
|
||||
if len (frames) == 0:
|
||||
raise ValueError ("len (frames) == 0")
|
||||
|
||||
|
@ -161,7 +161,7 @@ class InteractiveMergerSubprocessor(Subprocessor):
|
|||
self.output_mask_path = output_mask_path
|
||||
self.model_iter = model_iter
|
||||
|
||||
self.prefetch_frame_count = self.process_count = subprocess_count
|
||||
self.prefetch_frame_count = self.process_count = multiprocessing.cpu_count()
|
||||
|
||||
session_data = None
|
||||
if self.is_interactive and self.merger_session_filepath.exists():
|
||||
|
@ -393,7 +393,6 @@ class InteractiveMergerSubprocessor(Subprocessor):
|
|||
# unable to read? recompute then
|
||||
cur_frame.is_done = False
|
||||
else:
|
||||
image = imagelib.normalize_channels(image, 3)
|
||||
image_mask = imagelib.normalize_channels(image_mask, 1)
|
||||
cur_frame.image = np.concatenate([image, image_mask], -1)
|
||||
|
||||
|
|
|
@ -1,25 +1,26 @@
|
|||
import sys
|
||||
import traceback
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from core import imagelib
|
||||
from core.cv2ex import *
|
||||
from core.interact import interact as io
|
||||
from facelib import FaceType, LandmarksProcessor
|
||||
from core.interact import interact as io
|
||||
from core.cv2ex import *
|
||||
|
||||
is_windows = sys.platform[0:3] == 'win'
|
||||
xseg_input_size = 256
|
||||
|
||||
def MergeMaskedFace (predictor_func, predictor_input_shape,
|
||||
face_enhancer_func,
|
||||
xseg_256_extract_func,
|
||||
cfg, frame_info, img_bgr_uint8, img_bgr, img_face_landmarks):
|
||||
|
||||
img_size = img_bgr.shape[1], img_bgr.shape[0]
|
||||
img_face_mask_a = LandmarksProcessor.get_image_hull_mask (img_bgr.shape, img_face_landmarks)
|
||||
|
||||
|
||||
out_img = img_bgr.copy()
|
||||
out_merging_mask_a = None
|
||||
|
||||
input_size = predictor_input_shape[0]
|
||||
mask_subres_size = input_size*4
|
||||
output_size = input_size
|
||||
|
@ -54,49 +55,45 @@ def MergeMaskedFace (predictor_func, predictor_input_shape,
|
|||
prd_face_bgr = np.clip(prd_face_bgr, 0, 1)
|
||||
|
||||
if cfg.super_resolution_power != 0:
|
||||
prd_face_mask_a_0 = cv2.resize (prd_face_mask_a_0, (output_size, output_size), interpolation=cv2.INTER_CUBIC)
|
||||
prd_face_dst_mask_a_0 = cv2.resize (prd_face_dst_mask_a_0, (output_size, output_size), interpolation=cv2.INTER_CUBIC)
|
||||
prd_face_mask_a_0 = cv2.resize (prd_face_mask_a_0, (output_size, output_size), cv2.INTER_CUBIC)
|
||||
prd_face_dst_mask_a_0 = cv2.resize (prd_face_dst_mask_a_0, (output_size, output_size), cv2.INTER_CUBIC)
|
||||
|
||||
if cfg.mask_mode == 0: #full
|
||||
wrk_face_mask_a_0 = np.ones_like(dst_face_mask_a_0)
|
||||
elif cfg.mask_mode == 1: #dst
|
||||
wrk_face_mask_a_0 = cv2.resize (dst_face_mask_a_0, (output_size,output_size), interpolation=cv2.INTER_CUBIC)
|
||||
if cfg.mask_mode == 1: #dst
|
||||
wrk_face_mask_a_0 = cv2.resize (dst_face_mask_a_0, (output_size,output_size), cv2.INTER_CUBIC)
|
||||
elif cfg.mask_mode == 2: #learned-prd
|
||||
wrk_face_mask_a_0 = prd_face_mask_a_0
|
||||
elif cfg.mask_mode == 3: #learned-dst
|
||||
wrk_face_mask_a_0 = prd_face_dst_mask_a_0
|
||||
elif cfg.mask_mode == 4: #learned-prd*learned-dst
|
||||
wrk_face_mask_a_0 = prd_face_mask_a_0*prd_face_dst_mask_a_0
|
||||
elif cfg.mask_mode == 5: #learned-prd+learned-dst
|
||||
wrk_face_mask_a_0 = np.clip( prd_face_mask_a_0+prd_face_dst_mask_a_0, 0, 1)
|
||||
elif cfg.mask_mode >= 6 and cfg.mask_mode <= 9: #XSeg modes
|
||||
if cfg.mask_mode == 6 or cfg.mask_mode == 8 or cfg.mask_mode == 9:
|
||||
elif cfg.mask_mode >= 5 and cfg.mask_mode <= 8: #XSeg modes
|
||||
if cfg.mask_mode == 5 or cfg.mask_mode == 7 or cfg.mask_mode == 8:
|
||||
# obtain XSeg-prd
|
||||
prd_face_xseg_bgr = cv2.resize (prd_face_bgr, (xseg_input_size,)*2, interpolation=cv2.INTER_CUBIC)
|
||||
prd_face_xseg_bgr = cv2.resize (prd_face_bgr, (xseg_input_size,)*2, cv2.INTER_CUBIC)
|
||||
prd_face_xseg_mask = xseg_256_extract_func(prd_face_xseg_bgr)
|
||||
X_prd_face_mask_a_0 = cv2.resize ( prd_face_xseg_mask, (output_size, output_size), interpolation=cv2.INTER_CUBIC)
|
||||
X_prd_face_mask_a_0 = cv2.resize ( prd_face_xseg_mask, (output_size, output_size), cv2.INTER_CUBIC)
|
||||
|
||||
if cfg.mask_mode >= 7 and cfg.mask_mode <= 9:
|
||||
if cfg.mask_mode >= 6 and cfg.mask_mode <= 8:
|
||||
# obtain XSeg-dst
|
||||
xseg_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, xseg_input_size, face_type=cfg.face_type)
|
||||
dst_face_xseg_bgr = cv2.warpAffine(img_bgr, xseg_mat, (xseg_input_size,)*2, flags=cv2.INTER_CUBIC )
|
||||
dst_face_xseg_mask = xseg_256_extract_func(dst_face_xseg_bgr)
|
||||
X_dst_face_mask_a_0 = cv2.resize (dst_face_xseg_mask, (output_size,output_size), interpolation=cv2.INTER_CUBIC)
|
||||
X_dst_face_mask_a_0 = cv2.resize (dst_face_xseg_mask, (output_size,output_size), cv2.INTER_CUBIC)
|
||||
|
||||
if cfg.mask_mode == 6: #'XSeg-prd'
|
||||
if cfg.mask_mode == 5: #'XSeg-prd'
|
||||
wrk_face_mask_a_0 = X_prd_face_mask_a_0
|
||||
elif cfg.mask_mode == 7: #'XSeg-dst'
|
||||
elif cfg.mask_mode == 6: #'XSeg-dst'
|
||||
wrk_face_mask_a_0 = X_dst_face_mask_a_0
|
||||
elif cfg.mask_mode == 8: #'XSeg-prd*XSeg-dst'
|
||||
elif cfg.mask_mode == 7: #'XSeg-prd*XSeg-dst'
|
||||
wrk_face_mask_a_0 = X_prd_face_mask_a_0 * X_dst_face_mask_a_0
|
||||
elif cfg.mask_mode == 9: #learned-prd*learned-dst*XSeg-prd*XSeg-dst
|
||||
elif cfg.mask_mode == 8: #learned-prd*learned-dst*XSeg-prd*XSeg-dst
|
||||
wrk_face_mask_a_0 = prd_face_mask_a_0 * prd_face_dst_mask_a_0 * X_prd_face_mask_a_0 * X_dst_face_mask_a_0
|
||||
|
||||
wrk_face_mask_a_0[ wrk_face_mask_a_0 < (1.0/255.0) ] = 0.0 # get rid of noise
|
||||
|
||||
# resize to mask_subres_size
|
||||
if wrk_face_mask_a_0.shape[0] != mask_subres_size:
|
||||
wrk_face_mask_a_0 = cv2.resize (wrk_face_mask_a_0, (mask_subres_size, mask_subres_size), interpolation=cv2.INTER_CUBIC)
|
||||
wrk_face_mask_a_0 = cv2.resize (wrk_face_mask_a_0, (mask_subres_size, mask_subres_size), cv2.INTER_CUBIC)
|
||||
|
||||
# process mask in local predicted space
|
||||
if 'raw' not in cfg.mode:
|
||||
|
@ -130,187 +127,184 @@ def MergeMaskedFace (predictor_func, predictor_input_shape,
|
|||
|
||||
img_face_mask_a = cv2.warpAffine( wrk_face_mask_a_0, face_mask_output_mat, img_size, np.zeros(img_bgr.shape[0:2], dtype=np.float32), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC )[...,None]
|
||||
img_face_mask_a = np.clip (img_face_mask_a, 0.0, 1.0)
|
||||
|
||||
img_face_mask_a [ img_face_mask_a < (1.0/255.0) ] = 0.0 # get rid of noise
|
||||
|
||||
if wrk_face_mask_a_0.shape[0] != output_size:
|
||||
wrk_face_mask_a_0 = cv2.resize (wrk_face_mask_a_0, (output_size,output_size), interpolation=cv2.INTER_CUBIC)
|
||||
wrk_face_mask_a_0 = cv2.resize (wrk_face_mask_a_0, (output_size,output_size), cv2.INTER_CUBIC)
|
||||
|
||||
wrk_face_mask_a = wrk_face_mask_a_0[...,None]
|
||||
wrk_face_mask_area_a = wrk_face_mask_a.copy()
|
||||
wrk_face_mask_area_a[wrk_face_mask_area_a>0] = 1.0
|
||||
|
||||
out_img = None
|
||||
out_merging_mask_a = None
|
||||
if cfg.mode == 'original':
|
||||
return img_bgr, img_face_mask_a
|
||||
|
||||
elif 'raw' in cfg.mode:
|
||||
if cfg.mode == 'raw-rgb':
|
||||
out_img_face = cv2.warpAffine( prd_face_bgr, face_output_mat, img_size, np.empty_like(img_bgr), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC)
|
||||
out_img_face_mask = cv2.warpAffine( np.ones_like(prd_face_bgr), face_output_mat, img_size, np.empty_like(img_bgr), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC)
|
||||
out_img = img_bgr*(1-out_img_face_mask) + out_img_face*out_img_face_mask
|
||||
out_img = cv2.warpAffine( prd_face_bgr, face_output_mat, img_size, out_img, cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC, cv2.BORDER_TRANSPARENT )
|
||||
out_merging_mask_a = img_face_mask_a
|
||||
|
||||
elif cfg.mode == 'raw-predict':
|
||||
out_img = prd_face_bgr
|
||||
out_merging_mask_a = wrk_face_mask_a
|
||||
else:
|
||||
raise ValueError(f"undefined raw type {cfg.mode}")
|
||||
|
||||
out_img = np.clip (out_img, 0.0, 1.0 )
|
||||
else:
|
||||
#averaging [lenx, leny, maskx, masky] by grayscale gradients of upscaled mask
|
||||
ar = []
|
||||
for i in range(1, 10):
|
||||
maxregion = np.argwhere( img_face_mask_a > i / 10.0 )
|
||||
if maxregion.size != 0:
|
||||
miny,minx = maxregion.min(axis=0)[:2]
|
||||
maxy,maxx = maxregion.max(axis=0)[:2]
|
||||
lenx = maxx - minx
|
||||
leny = maxy - miny
|
||||
if min(lenx,leny) >= 4:
|
||||
ar += [ [ lenx, leny] ]
|
||||
|
||||
# Process if the mask meets minimum size
|
||||
maxregion = np.argwhere( img_face_mask_a >= 0.1 )
|
||||
if maxregion.size != 0:
|
||||
miny,minx = maxregion.min(axis=0)[:2]
|
||||
maxy,maxx = maxregion.max(axis=0)[:2]
|
||||
lenx = maxx - minx
|
||||
leny = maxy - miny
|
||||
if min(lenx,leny) >= 4:
|
||||
wrk_face_mask_area_a = wrk_face_mask_a.copy()
|
||||
wrk_face_mask_area_a[wrk_face_mask_area_a>0] = 1.0
|
||||
if len(ar) > 0:
|
||||
|
||||
if 'seamless' not in cfg.mode and cfg.color_transfer_mode != 0:
|
||||
if cfg.color_transfer_mode == 1: #rct
|
||||
prd_face_bgr = imagelib.reinhard_color_transfer (prd_face_bgr, dst_face_bgr, target_mask=wrk_face_mask_area_a, source_mask=wrk_face_mask_area_a)
|
||||
elif cfg.color_transfer_mode == 2: #lct
|
||||
prd_face_bgr = imagelib.linear_color_transfer (prd_face_bgr, dst_face_bgr)
|
||||
elif cfg.color_transfer_mode == 3: #mkl
|
||||
prd_face_bgr = imagelib.color_transfer_mkl (prd_face_bgr, dst_face_bgr)
|
||||
elif cfg.color_transfer_mode == 4: #mkl-m
|
||||
prd_face_bgr = imagelib.color_transfer_mkl (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
|
||||
elif cfg.color_transfer_mode == 5: #idt
|
||||
prd_face_bgr = imagelib.color_transfer_idt (prd_face_bgr, dst_face_bgr)
|
||||
elif cfg.color_transfer_mode == 6: #idt-m
|
||||
prd_face_bgr = imagelib.color_transfer_idt (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
|
||||
elif cfg.color_transfer_mode == 7: #sot-m
|
||||
prd_face_bgr = imagelib.color_transfer_sot (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a, steps=10, batch_size=30)
|
||||
prd_face_bgr = np.clip (prd_face_bgr, 0.0, 1.0)
|
||||
elif cfg.color_transfer_mode == 8: #mix-m
|
||||
prd_face_bgr = imagelib.color_transfer_mix (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
|
||||
if 'seamless' not in cfg.mode and cfg.color_transfer_mode != 0:
|
||||
if cfg.color_transfer_mode == 1: #rct
|
||||
prd_face_bgr = imagelib.reinhard_color_transfer ( np.clip( prd_face_bgr*wrk_face_mask_area_a*255, 0, 255).astype(np.uint8),
|
||||
np.clip( dst_face_bgr*wrk_face_mask_area_a*255, 0, 255).astype(np.uint8), )
|
||||
|
||||
if cfg.mode == 'hist-match':
|
||||
hist_mask_a = np.ones ( prd_face_bgr.shape[:2] + (1,) , dtype=np.float32)
|
||||
prd_face_bgr = np.clip( prd_face_bgr.astype(np.float32) / 255.0, 0.0, 1.0)
|
||||
elif cfg.color_transfer_mode == 2: #lct
|
||||
prd_face_bgr = imagelib.linear_color_transfer (prd_face_bgr, dst_face_bgr)
|
||||
elif cfg.color_transfer_mode == 3: #mkl
|
||||
prd_face_bgr = imagelib.color_transfer_mkl (prd_face_bgr, dst_face_bgr)
|
||||
elif cfg.color_transfer_mode == 4: #mkl-m
|
||||
prd_face_bgr = imagelib.color_transfer_mkl (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
|
||||
elif cfg.color_transfer_mode == 5: #idt
|
||||
prd_face_bgr = imagelib.color_transfer_idt (prd_face_bgr, dst_face_bgr)
|
||||
elif cfg.color_transfer_mode == 6: #idt-m
|
||||
prd_face_bgr = imagelib.color_transfer_idt (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
|
||||
elif cfg.color_transfer_mode == 7: #sot-m
|
||||
prd_face_bgr = imagelib.color_transfer_sot (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
|
||||
prd_face_bgr = np.clip (prd_face_bgr, 0.0, 1.0)
|
||||
elif cfg.color_transfer_mode == 8: #mix-m
|
||||
prd_face_bgr = imagelib.color_transfer_mix (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
|
||||
|
||||
if cfg.masked_hist_match:
|
||||
hist_mask_a *= wrk_face_mask_area_a
|
||||
if cfg.mode == 'hist-match':
|
||||
hist_mask_a = np.ones ( prd_face_bgr.shape[:2] + (1,) , dtype=np.float32)
|
||||
|
||||
white = (1.0-hist_mask_a)* np.ones ( prd_face_bgr.shape[:2] + (1,) , dtype=np.float32)
|
||||
if cfg.masked_hist_match:
|
||||
hist_mask_a *= wrk_face_mask_area_a
|
||||
|
||||
hist_match_1 = prd_face_bgr*hist_mask_a + white
|
||||
hist_match_1[ hist_match_1 > 1.0 ] = 1.0
|
||||
white = (1.0-hist_mask_a)* np.ones ( prd_face_bgr.shape[:2] + (1,) , dtype=np.float32)
|
||||
|
||||
hist_match_2 = dst_face_bgr*hist_mask_a + white
|
||||
hist_match_2[ hist_match_1 > 1.0 ] = 1.0
|
||||
hist_match_1 = prd_face_bgr*hist_mask_a + white
|
||||
hist_match_1[ hist_match_1 > 1.0 ] = 1.0
|
||||
|
||||
prd_face_bgr = imagelib.color_hist_match(hist_match_1, hist_match_2, cfg.hist_match_threshold ).astype(dtype=np.float32)
|
||||
hist_match_2 = dst_face_bgr*hist_mask_a + white
|
||||
hist_match_2[ hist_match_1 > 1.0 ] = 1.0
|
||||
|
||||
if 'seamless' in cfg.mode:
|
||||
#mask used for cv2.seamlessClone
|
||||
img_face_seamless_mask_a = None
|
||||
for i in range(1,10):
|
||||
a = img_face_mask_a > i / 10.0
|
||||
if len(np.argwhere(a)) == 0:
|
||||
continue
|
||||
img_face_seamless_mask_a = img_face_mask_a.copy()
|
||||
img_face_seamless_mask_a[a] = 1.0
|
||||
img_face_seamless_mask_a[img_face_seamless_mask_a <= i / 10.0] = 0.0
|
||||
break
|
||||
prd_face_bgr = imagelib.color_hist_match(hist_match_1, hist_match_2, cfg.hist_match_threshold ).astype(dtype=np.float32)
|
||||
|
||||
out_img = cv2.warpAffine( prd_face_bgr, face_output_mat, img_size, np.empty_like(img_bgr), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC )
|
||||
out_img = np.clip(out_img, 0.0, 1.0)
|
||||
if 'seamless' in cfg.mode:
|
||||
#mask used for cv2.seamlessClone
|
||||
img_face_seamless_mask_a = None
|
||||
for i in range(1,10):
|
||||
a = img_face_mask_a > i / 10.0
|
||||
if len(np.argwhere(a)) == 0:
|
||||
continue
|
||||
img_face_seamless_mask_a = img_face_mask_a.copy()
|
||||
img_face_seamless_mask_a[a] = 1.0
|
||||
img_face_seamless_mask_a[img_face_seamless_mask_a <= i / 10.0] = 0.0
|
||||
break
|
||||
|
||||
if 'seamless' in cfg.mode:
|
||||
try:
|
||||
#calc same bounding rect and center point as in cv2.seamlessClone to prevent jittering (not flickering)
|
||||
l,t,w,h = cv2.boundingRect( (img_face_seamless_mask_a*255).astype(np.uint8) )
|
||||
s_maskx, s_masky = int(l+w/2), int(t+h/2)
|
||||
out_img = cv2.seamlessClone( (out_img*255).astype(np.uint8), img_bgr_uint8, (img_face_seamless_mask_a*255).astype(np.uint8), (s_maskx,s_masky) , cv2.NORMAL_CLONE )
|
||||
out_img = out_img.astype(dtype=np.float32) / 255.0
|
||||
except Exception as e:
|
||||
#seamlessClone may fail in some cases
|
||||
e_str = traceback.format_exc()
|
||||
out_img = cv2.warpAffine( prd_face_bgr, face_output_mat, img_size, out_img, cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC, cv2.BORDER_TRANSPARENT )
|
||||
|
||||
if 'MemoryError' in e_str:
|
||||
raise Exception("Seamless fail: " + e_str) #reraise MemoryError in order to reprocess this data by other processes
|
||||
else:
|
||||
print ("Seamless fail: " + e_str)
|
||||
out_img = np.clip(out_img, 0.0, 1.0)
|
||||
|
||||
cfg_mp = cfg.motion_blur_power / 100.0
|
||||
if 'seamless' in cfg.mode:
|
||||
try:
|
||||
#calc same bounding rect and center point as in cv2.seamlessClone to prevent jittering (not flickering)
|
||||
l,t,w,h = cv2.boundingRect( (img_face_seamless_mask_a*255).astype(np.uint8) )
|
||||
s_maskx, s_masky = int(l+w/2), int(t+h/2)
|
||||
out_img = cv2.seamlessClone( (out_img*255).astype(np.uint8), img_bgr_uint8, (img_face_seamless_mask_a*255).astype(np.uint8), (s_maskx,s_masky) , cv2.NORMAL_CLONE )
|
||||
out_img = out_img.astype(dtype=np.float32) / 255.0
|
||||
except Exception as e:
|
||||
#seamlessClone may fail in some cases
|
||||
e_str = traceback.format_exc()
|
||||
|
||||
out_img = img_bgr*(1-img_face_mask_a) + (out_img*img_face_mask_a)
|
||||
|
||||
if ('seamless' in cfg.mode and cfg.color_transfer_mode != 0) or \
|
||||
cfg.mode == 'seamless-hist-match' or \
|
||||
cfg_mp != 0 or \
|
||||
cfg.blursharpen_amount != 0 or \
|
||||
cfg.image_denoise_power != 0 or \
|
||||
cfg.bicubic_degrade_power != 0:
|
||||
|
||||
out_face_bgr = cv2.warpAffine( out_img, face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC )
|
||||
|
||||
if 'seamless' in cfg.mode and cfg.color_transfer_mode != 0:
|
||||
if cfg.color_transfer_mode == 1:
|
||||
out_face_bgr = imagelib.reinhard_color_transfer (out_face_bgr, dst_face_bgr, target_mask=wrk_face_mask_area_a, source_mask=wrk_face_mask_area_a)
|
||||
elif cfg.color_transfer_mode == 2: #lct
|
||||
out_face_bgr = imagelib.linear_color_transfer (out_face_bgr, dst_face_bgr)
|
||||
elif cfg.color_transfer_mode == 3: #mkl
|
||||
out_face_bgr = imagelib.color_transfer_mkl (out_face_bgr, dst_face_bgr)
|
||||
elif cfg.color_transfer_mode == 4: #mkl-m
|
||||
out_face_bgr = imagelib.color_transfer_mkl (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
|
||||
elif cfg.color_transfer_mode == 5: #idt
|
||||
out_face_bgr = imagelib.color_transfer_idt (out_face_bgr, dst_face_bgr)
|
||||
elif cfg.color_transfer_mode == 6: #idt-m
|
||||
out_face_bgr = imagelib.color_transfer_idt (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
|
||||
elif cfg.color_transfer_mode == 7: #sot-m
|
||||
out_face_bgr = imagelib.color_transfer_sot (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a, steps=10, batch_size=30)
|
||||
out_face_bgr = np.clip (out_face_bgr, 0.0, 1.0)
|
||||
elif cfg.color_transfer_mode == 8: #mix-m
|
||||
out_face_bgr = imagelib.color_transfer_mix (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
|
||||
|
||||
if cfg.mode == 'seamless-hist-match':
|
||||
out_face_bgr = imagelib.color_hist_match(out_face_bgr, dst_face_bgr, cfg.hist_match_threshold)
|
||||
|
||||
if cfg_mp != 0:
|
||||
k_size = int(frame_info.motion_power*cfg_mp)
|
||||
if k_size >= 1:
|
||||
k_size = np.clip (k_size+1, 2, 50)
|
||||
if cfg.super_resolution_power != 0:
|
||||
k_size *= 2
|
||||
out_face_bgr = imagelib.LinearMotionBlur (out_face_bgr, k_size , frame_info.motion_deg)
|
||||
|
||||
if cfg.blursharpen_amount != 0:
|
||||
out_face_bgr = imagelib.blursharpen ( out_face_bgr, cfg.sharpen_mode, 3, cfg.blursharpen_amount)
|
||||
|
||||
if cfg.image_denoise_power != 0:
|
||||
n = cfg.image_denoise_power
|
||||
while n > 0:
|
||||
img_bgr_denoised = cv2.medianBlur(img_bgr, 5)
|
||||
if int(n / 100) != 0:
|
||||
img_bgr = img_bgr_denoised
|
||||
else:
|
||||
pass_power = (n % 100) / 100.0
|
||||
img_bgr = img_bgr*(1.0-pass_power)+img_bgr_denoised*pass_power
|
||||
n = max(n-10,0)
|
||||
|
||||
if cfg.bicubic_degrade_power != 0:
|
||||
p = 1.0 - cfg.bicubic_degrade_power / 101.0
|
||||
img_bgr_downscaled = cv2.resize (img_bgr, ( int(img_size[0]*p), int(img_size[1]*p ) ), interpolation=cv2.INTER_CUBIC)
|
||||
img_bgr = cv2.resize (img_bgr_downscaled, img_size, interpolation=cv2.INTER_CUBIC)
|
||||
|
||||
new_out = cv2.warpAffine( out_face_bgr, face_mat, img_size, np.empty_like(img_bgr), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC )
|
||||
|
||||
out_img = np.clip( img_bgr*(1-img_face_mask_a) + (new_out*img_face_mask_a) , 0, 1.0 )
|
||||
|
||||
if cfg.color_degrade_power != 0:
|
||||
out_img_reduced = imagelib.reduce_colors(out_img, 256)
|
||||
if cfg.color_degrade_power == 100:
|
||||
out_img = out_img_reduced
|
||||
if 'MemoryError' in e_str:
|
||||
raise Exception("Seamless fail: " + e_str) #reraise MemoryError in order to reprocess this data by other processes
|
||||
else:
|
||||
alpha = cfg.color_degrade_power / 100.0
|
||||
out_img = (out_img*(1.0-alpha) + out_img_reduced*alpha)
|
||||
out_merging_mask_a = img_face_mask_a
|
||||
print ("Seamless fail: " + e_str)
|
||||
|
||||
if out_img is None:
|
||||
out_img = img_bgr.copy()
|
||||
|
||||
out_img = img_bgr*(1-img_face_mask_a) + (out_img*img_face_mask_a)
|
||||
|
||||
out_face_bgr = cv2.warpAffine( out_img, face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC )
|
||||
|
||||
if 'seamless' in cfg.mode and cfg.color_transfer_mode != 0:
|
||||
if cfg.color_transfer_mode == 1:
|
||||
out_face_bgr = imagelib.reinhard_color_transfer ( np.clip(out_face_bgr*wrk_face_mask_area_a*255, 0, 255).astype(np.uint8),
|
||||
np.clip(dst_face_bgr*wrk_face_mask_area_a*255, 0, 255).astype(np.uint8) )
|
||||
out_face_bgr = np.clip( out_face_bgr.astype(np.float32) / 255.0, 0.0, 1.0)
|
||||
elif cfg.color_transfer_mode == 2: #lct
|
||||
out_face_bgr = imagelib.linear_color_transfer (out_face_bgr, dst_face_bgr)
|
||||
elif cfg.color_transfer_mode == 3: #mkl
|
||||
out_face_bgr = imagelib.color_transfer_mkl (out_face_bgr, dst_face_bgr)
|
||||
elif cfg.color_transfer_mode == 4: #mkl-m
|
||||
out_face_bgr = imagelib.color_transfer_mkl (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
|
||||
elif cfg.color_transfer_mode == 5: #idt
|
||||
out_face_bgr = imagelib.color_transfer_idt (out_face_bgr, dst_face_bgr)
|
||||
elif cfg.color_transfer_mode == 6: #idt-m
|
||||
out_face_bgr = imagelib.color_transfer_idt (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
|
||||
elif cfg.color_transfer_mode == 7: #sot-m
|
||||
out_face_bgr = imagelib.color_transfer_sot (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
|
||||
out_face_bgr = np.clip (out_face_bgr, 0.0, 1.0)
|
||||
elif cfg.color_transfer_mode == 8: #mix-m
|
||||
out_face_bgr = imagelib.color_transfer_mix (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a)
|
||||
|
||||
if cfg.mode == 'seamless-hist-match':
|
||||
out_face_bgr = imagelib.color_hist_match(out_face_bgr, dst_face_bgr, cfg.hist_match_threshold)
|
||||
|
||||
cfg_mp = cfg.motion_blur_power / 100.0
|
||||
if cfg_mp != 0:
|
||||
k_size = int(frame_info.motion_power*cfg_mp)
|
||||
if k_size >= 1:
|
||||
k_size = np.clip (k_size+1, 2, 50)
|
||||
if cfg.super_resolution_power != 0:
|
||||
k_size *= 2
|
||||
out_face_bgr = imagelib.LinearMotionBlur (out_face_bgr, k_size , frame_info.motion_deg)
|
||||
|
||||
if cfg.blursharpen_amount != 0:
|
||||
out_face_bgr = imagelib.blursharpen ( out_face_bgr, cfg.sharpen_mode, 3, cfg.blursharpen_amount)
|
||||
|
||||
|
||||
if cfg.image_denoise_power != 0:
|
||||
n = cfg.image_denoise_power
|
||||
while n > 0:
|
||||
img_bgr_denoised = cv2.medianBlur(img_bgr, 5)
|
||||
if int(n / 100) != 0:
|
||||
img_bgr = img_bgr_denoised
|
||||
else:
|
||||
pass_power = (n % 100) / 100.0
|
||||
img_bgr = img_bgr*(1.0-pass_power)+img_bgr_denoised*pass_power
|
||||
n = max(n-10,0)
|
||||
|
||||
if cfg.bicubic_degrade_power != 0:
|
||||
p = 1.0 - cfg.bicubic_degrade_power / 101.0
|
||||
img_bgr_downscaled = cv2.resize (img_bgr, ( int(img_size[0]*p), int(img_size[1]*p ) ), cv2.INTER_CUBIC)
|
||||
img_bgr = cv2.resize (img_bgr_downscaled, img_size, cv2.INTER_CUBIC)
|
||||
|
||||
new_out = cv2.warpAffine( out_face_bgr, face_mat, img_size, img_bgr.copy(), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC, cv2.BORDER_TRANSPARENT )
|
||||
out_img = np.clip( img_bgr*(1-img_face_mask_a) + (new_out*img_face_mask_a) , 0, 1.0 )
|
||||
|
||||
if cfg.color_degrade_power != 0:
|
||||
out_img_reduced = imagelib.reduce_colors(out_img, 256)
|
||||
if cfg.color_degrade_power == 100:
|
||||
out_img = out_img_reduced
|
||||
else:
|
||||
alpha = cfg.color_degrade_power / 100.0
|
||||
out_img = (out_img*(1.0-alpha) + out_img_reduced*alpha)
|
||||
|
||||
out_merging_mask_a = img_face_mask_a
|
||||
|
||||
return out_img, out_merging_mask_a
|
||||
|
||||
|
|
|
@ -81,16 +81,14 @@ mode_dict = {0:'original',
|
|||
|
||||
mode_str_dict = { mode_dict[key] : key for key in mode_dict.keys() }
|
||||
|
||||
mask_mode_dict = {0:'full',
|
||||
1:'dst',
|
||||
mask_mode_dict = {1:'dst',
|
||||
2:'learned-prd',
|
||||
3:'learned-dst',
|
||||
4:'learned-prd*learned-dst',
|
||||
5:'learned-prd+learned-dst',
|
||||
6:'XSeg-prd',
|
||||
7:'XSeg-dst',
|
||||
8:'XSeg-prd*XSeg-dst',
|
||||
9:'learned-prd*learned-dst*XSeg-prd*XSeg-dst'
|
||||
5:'XSeg-prd',
|
||||
6:'XSeg-dst',
|
||||
7:'XSeg-prd*XSeg-dst',
|
||||
8:'learned-prd*learned-dst*XSeg-prd*XSeg-dst'
|
||||
}
|
||||
|
||||
|
||||
|
|
Before Width: | Height: | Size: 260 KiB After Width: | Height: | Size: 310 KiB |
|
@ -1,7 +1,6 @@
|
|||
import colorsys
|
||||
import inspect
|
||||
import json
|
||||
import multiprocessing
|
||||
import operator
|
||||
import os
|
||||
import pickle
|
||||
|
@ -13,16 +12,16 @@ from pathlib import Path
|
|||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from core import imagelib, pathex
|
||||
from core.cv2ex import *
|
||||
from core import imagelib
|
||||
from core.interact import interact as io
|
||||
from core.leras import nn
|
||||
from samplelib import SampleGeneratorBase
|
||||
from core import pathex
|
||||
from core.cv2ex import *
|
||||
|
||||
|
||||
class ModelBase(object):
|
||||
def __init__(self, is_training=False,
|
||||
is_exporting=False,
|
||||
saved_models_path=None,
|
||||
training_data_src_path=None,
|
||||
training_data_dst_path=None,
|
||||
|
@ -37,7 +36,6 @@ class ModelBase(object):
|
|||
silent_start=False,
|
||||
**kwargs):
|
||||
self.is_training = is_training
|
||||
self.is_exporting = is_exporting
|
||||
self.saved_models_path = saved_models_path
|
||||
self.training_data_src_path = training_data_src_path
|
||||
self.training_data_dst_path = training_data_dst_path
|
||||
|
@ -134,7 +132,6 @@ class ModelBase(object):
|
|||
|
||||
self.iter = 0
|
||||
self.options = {}
|
||||
self.options_show_override = {}
|
||||
self.loss_history = []
|
||||
self.sample_for_preview = None
|
||||
self.choosed_gpu_indexes = None
|
||||
|
@ -187,13 +184,10 @@ class ModelBase(object):
|
|||
self.write_preview_history = self.options.get('write_preview_history', False)
|
||||
self.target_iter = self.options.get('target_iter',0)
|
||||
self.random_flip = self.options.get('random_flip',True)
|
||||
self.random_src_flip = self.options.get('random_src_flip', False)
|
||||
self.random_dst_flip = self.options.get('random_dst_flip', True)
|
||||
|
||||
self.on_initialize()
|
||||
self.options['batch_size'] = self.batch_size
|
||||
|
||||
self.preview_history_writer = None
|
||||
if self.is_training:
|
||||
self.preview_history_path = self.saved_models_path / ( f'{self.get_model_name()}_history' )
|
||||
self.autobackups_path = self.saved_models_path / ( f'{self.get_model_name()}_autobackups' )
|
||||
|
@ -226,17 +220,15 @@ class ModelBase(object):
|
|||
def update_sample_for_preview(self, choose_preview_history=False, force_new=False):
|
||||
if self.sample_for_preview is None or choose_preview_history or force_new:
|
||||
if choose_preview_history and io.is_support_windows():
|
||||
wnd_name = "[p] - next. [space] - switch preview type. [enter] - confirm."
|
||||
io.log_info (f"Choose image for the preview history. {wnd_name}")
|
||||
io.log_info ("Choose image for the preview history. [p] - next. [enter] - confirm.")
|
||||
wnd_name = "[p] - next. [enter] - confirm."
|
||||
io.named_window(wnd_name)
|
||||
io.capture_keys(wnd_name)
|
||||
choosed = False
|
||||
preview_id_counter = 0
|
||||
while not choosed:
|
||||
self.sample_for_preview = self.generate_next_samples()
|
||||
previews = self.get_history_previews()
|
||||
|
||||
io.show_image( wnd_name, ( previews[preview_id_counter % len(previews) ][1] *255).astype(np.uint8) )
|
||||
preview = self.get_static_preview()
|
||||
io.show_image( wnd_name, (preview*255).astype(np.uint8) )
|
||||
|
||||
while True:
|
||||
key_events = io.get_key_events(wnd_name)
|
||||
|
@ -244,9 +236,6 @@ class ModelBase(object):
|
|||
if key == ord('\n') or key == ord('\r'):
|
||||
choosed = True
|
||||
break
|
||||
elif key == ord(' '):
|
||||
preview_id_counter += 1
|
||||
break
|
||||
elif key == ord('p'):
|
||||
break
|
||||
|
||||
|
@ -260,7 +249,7 @@ class ModelBase(object):
|
|||
self.sample_for_preview = self.generate_next_samples()
|
||||
|
||||
try:
|
||||
self.get_history_previews()
|
||||
self.get_static_preview()
|
||||
except:
|
||||
self.sample_for_preview = self.generate_next_samples()
|
||||
|
||||
|
@ -302,23 +291,9 @@ class ModelBase(object):
|
|||
default_random_flip = self.load_or_def_option('random_flip', True)
|
||||
self.options['random_flip'] = io.input_bool("Flip faces randomly", default_random_flip, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")
|
||||
|
||||
def ask_random_src_flip(self):
|
||||
default_random_src_flip = self.load_or_def_option('random_src_flip', False)
|
||||
self.options['random_src_flip'] = io.input_bool("Flip SRC faces randomly", default_random_src_flip, help_message="Random horizontal flip SRC faceset. Covers more angles, but the face may look less naturally.")
|
||||
|
||||
def ask_random_dst_flip(self):
|
||||
default_random_dst_flip = self.load_or_def_option('random_dst_flip', True)
|
||||
self.options['random_dst_flip'] = io.input_bool("Flip DST faces randomly", default_random_dst_flip, help_message="Random horizontal flip DST faceset. Makes generalization of src->dst better, if src random flip is not enabled.")
|
||||
|
||||
def ask_batch_size(self, suggest_batch_size=None, range=None):
|
||||
def ask_batch_size(self, suggest_batch_size=None):
|
||||
default_batch_size = self.load_or_def_option('batch_size', suggest_batch_size or self.batch_size)
|
||||
|
||||
batch_size = max(0, io.input_int("Batch_size", default_batch_size, valid_range=range, help_message="Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
|
||||
|
||||
if range is not None:
|
||||
batch_size = np.clip(batch_size, range[0], range[1])
|
||||
|
||||
self.options['batch_size'] = self.batch_size = batch_size
|
||||
self.options['batch_size'] = self.batch_size = max(0, io.input_int("Batch_size", default_batch_size, help_message="Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
|
||||
|
||||
|
||||
#overridable
|
||||
|
@ -349,7 +324,7 @@ class ModelBase(object):
|
|||
return ( ('loss_src', 0), ('loss_dst', 0) )
|
||||
|
||||
#overridable
|
||||
def onGetPreview(self, sample, for_history=False):
|
||||
def onGetPreview(self, sample):
|
||||
#you can return multiple previews
|
||||
#return [ ('preview_name',preview_rgb), ... ]
|
||||
return []
|
||||
|
@ -379,13 +354,8 @@ class ModelBase(object):
|
|||
def get_previews(self):
|
||||
return self.onGetPreview ( self.last_sample )
|
||||
|
||||
def get_history_previews(self):
|
||||
return self.onGetPreview (self.sample_for_preview, for_history=True)
|
||||
|
||||
def get_preview_history_writer(self):
|
||||
if self.preview_history_writer is None:
|
||||
self.preview_history_writer = PreviewHistoryWriter()
|
||||
return self.preview_history_writer
|
||||
def get_static_preview(self):
|
||||
return self.onGetPreview (self.sample_for_preview)[0][1] #first preview, and bgr
|
||||
|
||||
def save(self):
|
||||
Path( self.get_summary_path() ).write_text( self.get_summary_text() )
|
||||
|
@ -442,8 +412,10 @@ class ModelBase(object):
|
|||
name, bgr = previews[i]
|
||||
plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ]
|
||||
|
||||
if len(plist) != 0:
|
||||
self.get_preview_history_writer().post(plist, self.loss_history, self.iter)
|
||||
for preview, filepath in plist:
|
||||
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
|
||||
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
|
||||
cv2_imwrite (filepath, img )
|
||||
|
||||
def debug_one_iter(self):
|
||||
images = []
|
||||
|
@ -464,10 +436,6 @@ class ModelBase(object):
|
|||
self.last_sample = sample
|
||||
return sample
|
||||
|
||||
#overridable
|
||||
def should_save_preview_history(self):
|
||||
return (not io.is_colab() and self.iter % 10 == 0) or (io.is_colab() and self.iter % 100 == 0)
|
||||
|
||||
def train_one_iter(self):
|
||||
|
||||
iter_time = time.time()
|
||||
|
@ -476,7 +444,8 @@ class ModelBase(object):
|
|||
|
||||
self.loss_history.append ( [float(loss[1]) for loss in losses] )
|
||||
|
||||
if self.should_save_preview_history():
|
||||
if (not io.is_colab() and self.iter % 10 == 0) or \
|
||||
(io.is_colab() and self.iter % 100 == 0):
|
||||
plist = []
|
||||
|
||||
if io.is_colab():
|
||||
|
@ -486,16 +455,12 @@ class ModelBase(object):
|
|||
plist += [ (bgr, self.get_strpath_storage_for_file('preview_%s.jpg' % (name) ) ) ]
|
||||
|
||||
if self.write_preview_history:
|
||||
previews = self.get_history_previews()
|
||||
for i in range(len(previews)):
|
||||
name, bgr = previews[i]
|
||||
path = self.preview_history_path / name
|
||||
plist += [ ( bgr, str ( path / ( f'{self.iter:07d}.jpg') ) ) ]
|
||||
if not io.is_colab():
|
||||
plist += [ ( bgr, str ( path / ( '_last.jpg' ) )) ]
|
||||
plist += [ (self.get_static_preview(), str (self.preview_history_path / ('%.6d.jpg' % (self.iter))) ) ]
|
||||
|
||||
if len(plist) != 0:
|
||||
self.get_preview_history_writer().post(plist, self.loss_history, self.iter)
|
||||
for preview, filepath in plist:
|
||||
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
|
||||
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
|
||||
cv2_imwrite (filepath, img )
|
||||
|
||||
self.iter += 1
|
||||
|
||||
|
@ -545,13 +510,10 @@ class ModelBase(object):
|
|||
return self.get_strpath_storage_for_file('summary.txt')
|
||||
|
||||
def get_summary_text(self):
|
||||
visible_options = self.options.copy()
|
||||
visible_options.update(self.options_show_override)
|
||||
|
||||
###Generate text summary of model hyperparameters
|
||||
#Find the longest key name and value string. Used as column widths.
|
||||
width_name = max([len(k) for k in visible_options.keys()] + [17]) + 1 # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration"
|
||||
width_value = max([len(str(x)) for x in visible_options.values()] + [len(str(self.get_iter())), len(self.get_model_name())]) + 1 # Single space buffer to right edge
|
||||
width_name = max([len(k) for k in self.options.keys()] + [17]) + 1 # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration"
|
||||
width_value = max([len(str(x)) for x in self.options.values()] + [len(str(self.get_iter())), len(self.get_model_name())]) + 1 # Single space buffer to right edge
|
||||
if len(self.device_config.devices) != 0: #Check length of GPU names
|
||||
width_value = max([len(device.name)+1 for device in self.device_config.devices] + [width_value])
|
||||
width_total = width_name + width_value + 2 #Plus 2 for ": "
|
||||
|
@ -566,8 +528,8 @@ class ModelBase(object):
|
|||
|
||||
summary_text += [f'=={" Model Options ":-^{width_total}}=='] # Model options
|
||||
summary_text += [f'=={" "*width_total}==']
|
||||
for key in visible_options.keys():
|
||||
summary_text += [f'=={key: >{width_name}}: {str(visible_options[key]): <{width_value}}=='] # visible_options key/value pairs
|
||||
for key in self.options.keys():
|
||||
summary_text += [f'=={key: >{width_name}}: {str(self.options[key]): <{width_value}}=='] # self.options key/value pairs
|
||||
summary_text += [f'=={" "*width_total}==']
|
||||
|
||||
summary_text += [f'=={" Running On ":-^{width_total}}=='] # Training hardware info
|
||||
|
@ -645,41 +607,3 @@ class ModelBase(object):
|
|||
|
||||
lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image ( (last_line_b-last_line_t,w,c), lh_text, color=[0.8]*c )
|
||||
return lh_img
|
||||
|
||||
class PreviewHistoryWriter():
|
||||
def __init__(self):
|
||||
self.sq = multiprocessing.Queue()
|
||||
self.p = multiprocessing.Process(target=self.process, args=( self.sq, ))
|
||||
self.p.daemon = True
|
||||
self.p.start()
|
||||
|
||||
def process(self, sq):
|
||||
while True:
|
||||
while not sq.empty():
|
||||
plist, loss_history, iter = sq.get()
|
||||
|
||||
preview_lh_cache = {}
|
||||
for preview, filepath in plist:
|
||||
filepath = Path(filepath)
|
||||
i = (preview.shape[1], preview.shape[2])
|
||||
|
||||
preview_lh = preview_lh_cache.get(i, None)
|
||||
if preview_lh is None:
|
||||
preview_lh = ModelBase.get_loss_history_preview(loss_history, iter, preview.shape[1], preview.shape[2])
|
||||
preview_lh_cache[i] = preview_lh
|
||||
|
||||
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
|
||||
|
||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
cv2_imwrite (filepath, img )
|
||||
|
||||
time.sleep(0.01)
|
||||
|
||||
def post(self, plist, loss_history, iter):
|
||||
self.sq.put ( (plist, loss_history, iter) )
|
||||
|
||||
# disable pickling
|
||||
def __getstate__(self):
|
||||
return dict()
|
||||
def __setstate__(self, d):
|
||||
self.__dict__.update(d)
|
||||
|
|
|
@ -1,725 +0,0 @@
|
|||
import multiprocessing
|
||||
import operator
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
|
||||
from core import mathlib
|
||||
from core.interact import interact as io
|
||||
from core.leras import nn
|
||||
from facelib import FaceType
|
||||
from models import ModelBase
|
||||
from samplelib import *
|
||||
from core.cv2ex import *
|
||||
|
||||
class AMPModel(ModelBase):
|
||||
|
||||
#override
|
||||
def on_initialize_options(self):
|
||||
default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 224)
|
||||
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf')
|
||||
default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
|
||||
|
||||
default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
|
||||
default_inter_dims = self.options['inter_dims'] = self.load_or_def_option('inter_dims', 1024)
|
||||
|
||||
default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
|
||||
default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None)
|
||||
default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None)
|
||||
default_morph_factor = self.options['morph_factor'] = self.options.get('morph_factor', 0.5)
|
||||
default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False)
|
||||
default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False)
|
||||
default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', 'n')
|
||||
default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True)
|
||||
default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none')
|
||||
default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False)
|
||||
|
||||
ask_override = self.ask_override()
|
||||
if self.is_first_run() or ask_override:
|
||||
self.ask_autobackup_hour()
|
||||
self.ask_write_preview_history()
|
||||
self.ask_target_iter()
|
||||
self.ask_random_src_flip()
|
||||
self.ask_random_dst_flip()
|
||||
self.ask_batch_size(8)
|
||||
|
||||
if self.is_first_run():
|
||||
resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 32 .")
|
||||
resolution = np.clip ( (resolution // 32) * 32, 64, 640)
|
||||
self.options['resolution'] = resolution
|
||||
self.options['face_type'] = io.input_str ("Face type", default_face_type, ['f','wf','head'], help_message="whole face / head").lower()
|
||||
|
||||
|
||||
default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64)
|
||||
|
||||
default_d_mask_dims = default_d_dims // 3
|
||||
default_d_mask_dims += default_d_mask_dims % 2
|
||||
default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims)
|
||||
|
||||
if self.is_first_run():
|
||||
self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )
|
||||
self.options['inter_dims'] = np.clip ( io.input_int("Inter dimensions", default_inter_dims, add_info="32-2048", help_message="Should be equal or more than AutoEncoder dimensions. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 2048 )
|
||||
|
||||
e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
|
||||
self.options['e_dims'] = e_dims + e_dims % 2
|
||||
|
||||
d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
|
||||
self.options['d_dims'] = d_dims + d_dims % 2
|
||||
|
||||
d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 )
|
||||
self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2
|
||||
|
||||
morph_factor = np.clip ( io.input_number ("Morph factor.", default_morph_factor, add_info="0.1 .. 0.5", help_message="Typical fine value is 0.5"), 0.1, 0.5 )
|
||||
self.options['morph_factor'] = morph_factor
|
||||
|
||||
if self.is_first_run() or ask_override:
|
||||
self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')
|
||||
self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. The exact xseg mask in src and dst faceset is required.')
|
||||
self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.")
|
||||
|
||||
default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
|
||||
default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8)
|
||||
default_gan_dims = self.options['gan_dims'] = self.load_or_def_option('gan_dims', 16)
|
||||
|
||||
if self.is_first_run() or ask_override:
|
||||
self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.")
|
||||
|
||||
self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.")
|
||||
|
||||
self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 5.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with random_warp(off), and don't disable. The higher the value, the higher the chances of artifacts. Typical fine value is 0.1"), 0.0, 5.0 )
|
||||
|
||||
if self.options['gan_power'] != 0.0:
|
||||
gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher patch size, the higher the quality, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is resolution / 8." ), 3, 640 )
|
||||
self.options['gan_patch_size'] = gan_patch_size
|
||||
|
||||
gan_dims = np.clip ( io.input_int("GAN dimensions", default_gan_dims, add_info="4-512", help_message="The dimensions of the GAN network. The higher dimensions, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is 16." ), 4, 512 )
|
||||
self.options['gan_dims'] = gan_dims
|
||||
|
||||
self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. If src faceset is deverse enough, then lct mode is fine in most cases.")
|
||||
self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
|
||||
|
||||
self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims'])
|
||||
|
||||
#override
|
||||
def on_initialize(self):
|
||||
device_config = nn.getCurrentDeviceConfig()
|
||||
devices = device_config.devices
|
||||
self.model_data_format = "NCHW"
|
||||
nn.initialize(data_format=self.model_data_format)
|
||||
tf = nn.tf
|
||||
|
||||
input_ch=3
|
||||
resolution = self.resolution = self.options['resolution']
|
||||
e_dims = self.options['e_dims']
|
||||
ae_dims = self.options['ae_dims']
|
||||
inter_dims = self.inter_dims = self.options['inter_dims']
|
||||
inter_res = self.inter_res = resolution // 32
|
||||
d_dims = self.options['d_dims']
|
||||
d_mask_dims = self.options['d_mask_dims']
|
||||
face_type = self.face_type = {'f' : FaceType.FULL,
|
||||
'wf' : FaceType.WHOLE_FACE,
|
||||
'head' : FaceType.HEAD}[ self.options['face_type'] ]
|
||||
morph_factor = self.options['morph_factor']
|
||||
gan_power = self.gan_power = self.options['gan_power']
|
||||
random_warp = self.options['random_warp']
|
||||
|
||||
blur_out_mask = self.options['blur_out_mask']
|
||||
|
||||
ct_mode = self.options['ct_mode']
|
||||
if ct_mode == 'none':
|
||||
ct_mode = None
|
||||
|
||||
use_fp16 = False
|
||||
if self.is_exporting:
|
||||
use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')
|
||||
|
||||
conv_dtype = tf.float16 if use_fp16 else tf.float32
|
||||
|
||||
class Downscale(nn.ModelBase):
|
||||
def on_build(self, in_ch, out_ch, kernel_size=5 ):
|
||||
self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=conv_dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return tf.nn.leaky_relu(self.conv1(x), 0.1)
|
||||
|
||||
class Upscale(nn.ModelBase):
|
||||
def on_build(self, in_ch, out_ch, kernel_size=3 ):
|
||||
self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = nn.depth_to_space(tf.nn.leaky_relu(self.conv1(x), 0.1), 2)
|
||||
return x
|
||||
|
||||
class ResidualBlock(nn.ModelBase):
|
||||
def on_build(self, ch, kernel_size=3 ):
|
||||
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
|
||||
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
|
||||
|
||||
def forward(self, inp):
|
||||
x = self.conv1(inp)
|
||||
x = tf.nn.leaky_relu(x, 0.2)
|
||||
x = self.conv2(x)
|
||||
x = tf.nn.leaky_relu(inp+x, 0.2)
|
||||
return x
|
||||
|
||||
class Encoder(nn.ModelBase):
|
||||
def on_build(self):
|
||||
self.down1 = Downscale(input_ch, e_dims, kernel_size=5)
|
||||
self.res1 = ResidualBlock(e_dims)
|
||||
self.down2 = Downscale(e_dims, e_dims*2, kernel_size=5)
|
||||
self.down3 = Downscale(e_dims*2, e_dims*4, kernel_size=5)
|
||||
self.down4 = Downscale(e_dims*4, e_dims*8, kernel_size=5)
|
||||
self.down5 = Downscale(e_dims*8, e_dims*8, kernel_size=5)
|
||||
self.res5 = ResidualBlock(e_dims*8)
|
||||
self.dense1 = nn.Dense( (( resolution//(2**5) )**2) * e_dims*8, ae_dims )
|
||||
|
||||
def forward(self, x):
|
||||
if use_fp16:
|
||||
x = tf.cast(x, tf.float16)
|
||||
x = self.down1(x)
|
||||
x = self.res1(x)
|
||||
x = self.down2(x)
|
||||
x = self.down3(x)
|
||||
x = self.down4(x)
|
||||
x = self.down5(x)
|
||||
x = self.res5(x)
|
||||
if use_fp16:
|
||||
x = tf.cast(x, tf.float32)
|
||||
x = nn.pixel_norm(nn.flatten(x), axes=-1)
|
||||
x = self.dense1(x)
|
||||
return x
|
||||
|
||||
|
||||
class Inter(nn.ModelBase):
|
||||
def on_build(self):
|
||||
self.dense2 = nn.Dense(ae_dims, inter_res * inter_res * inter_dims)
|
||||
|
||||
def forward(self, inp):
|
||||
x = inp
|
||||
x = self.dense2(x)
|
||||
x = nn.reshape_4D (x, inter_res, inter_res, inter_dims)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder(nn.ModelBase):
|
||||
def on_build(self ):
|
||||
self.upscale0 = Upscale(inter_dims, d_dims*8, kernel_size=3)
|
||||
self.upscale1 = Upscale(d_dims*8, d_dims*8, kernel_size=3)
|
||||
self.upscale2 = Upscale(d_dims*8, d_dims*4, kernel_size=3)
|
||||
self.upscale3 = Upscale(d_dims*4, d_dims*2, kernel_size=3)
|
||||
|
||||
self.res0 = ResidualBlock(d_dims*8, kernel_size=3)
|
||||
self.res1 = ResidualBlock(d_dims*8, kernel_size=3)
|
||||
self.res2 = ResidualBlock(d_dims*4, kernel_size=3)
|
||||
self.res3 = ResidualBlock(d_dims*2, kernel_size=3)
|
||||
|
||||
self.upscalem0 = Upscale(inter_dims, d_mask_dims*8, kernel_size=3)
|
||||
self.upscalem1 = Upscale(d_mask_dims*8, d_mask_dims*8, kernel_size=3)
|
||||
self.upscalem2 = Upscale(d_mask_dims*8, d_mask_dims*4, kernel_size=3)
|
||||
self.upscalem3 = Upscale(d_mask_dims*4, d_mask_dims*2, kernel_size=3)
|
||||
self.upscalem4 = Upscale(d_mask_dims*2, d_mask_dims*1, kernel_size=3)
|
||||
self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
|
||||
self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
|
||||
def forward(self, z):
|
||||
if use_fp16:
|
||||
z = tf.cast(z, tf.float16)
|
||||
|
||||
x = self.upscale0(z)
|
||||
x = self.res0(x)
|
||||
x = self.upscale1(x)
|
||||
x = self.res1(x)
|
||||
x = self.upscale2(x)
|
||||
x = self.res2(x)
|
||||
x = self.upscale3(x)
|
||||
x = self.res3(x)
|
||||
|
||||
x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x),
|
||||
self.out_conv1(x),
|
||||
self.out_conv2(x),
|
||||
self.out_conv3(x)), nn.conv2d_ch_axis), 2) )
|
||||
m = self.upscalem0(z)
|
||||
m = self.upscalem1(m)
|
||||
m = self.upscalem2(m)
|
||||
m = self.upscalem3(m)
|
||||
m = self.upscalem4(m)
|
||||
m = tf.nn.sigmoid(self.out_convm(m))
|
||||
|
||||
if use_fp16:
|
||||
x = tf.cast(x, tf.float32)
|
||||
m = tf.cast(m, tf.float32)
|
||||
return x, m
|
||||
|
||||
models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
|
||||
models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0'
|
||||
optimizer_vars_on_cpu = models_opt_device=='/CPU:0'
|
||||
|
||||
bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
|
||||
mask_shape = nn.get4Dshape(resolution,resolution,1)
|
||||
self.model_filename_list = []
|
||||
|
||||
with tf.device ('/CPU:0'):
|
||||
#Place holders on CPU
|
||||
self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src')
|
||||
self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst')
|
||||
|
||||
self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src')
|
||||
self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst')
|
||||
|
||||
self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm')
|
||||
self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em')
|
||||
self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm')
|
||||
self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em')
|
||||
|
||||
self.morph_value_t = tf.placeholder (nn.floatx, (1,), name='morph_value_t')
|
||||
|
||||
# Initializing model classes
|
||||
with tf.device (models_opt_device):
|
||||
self.encoder = Encoder(name='encoder')
|
||||
self.inter_src = Inter(name='inter_src')
|
||||
self.inter_dst = Inter(name='inter_dst')
|
||||
self.decoder = Decoder(name='decoder')
|
||||
|
||||
self.model_filename_list += [ [self.encoder, 'encoder.npy'],
|
||||
[self.inter_src, 'inter_src.npy'],
|
||||
[self.inter_dst , 'inter_dst.npy'],
|
||||
[self.decoder , 'decoder.npy'] ]
|
||||
|
||||
if self.is_training:
|
||||
# Initialize optimizers
|
||||
clipnorm = 1.0 if self.options['clipgrad'] else 0.0
|
||||
if self.options['lr_dropout'] in ['y','cpu']:
|
||||
lr_cos = 500
|
||||
lr_dropout = 0.3
|
||||
else:
|
||||
lr_cos = 0
|
||||
lr_dropout = 1.0
|
||||
self.G_weights = self.encoder.get_weights() + self.decoder.get_weights()
|
||||
|
||||
self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='src_dst_opt')
|
||||
self.src_dst_opt.initialize_variables (self.G_weights, vars_on_cpu=optimizer_vars_on_cpu)
|
||||
self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
|
||||
|
||||
if gan_power != 0:
|
||||
self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN")
|
||||
self.GAN_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='GAN_opt')
|
||||
self.GAN_opt.initialize_variables ( self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
|
||||
self.model_filename_list += [ [self.GAN, 'GAN.npy'],
|
||||
[self.GAN_opt, 'GAN_opt.npy'] ]
|
||||
|
||||
if self.is_training:
|
||||
# Adjust batch size for multiple GPU
|
||||
gpu_count = max(1, len(devices) )
|
||||
bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
|
||||
self.set_batch_size( gpu_count*bs_per_gpu)
|
||||
|
||||
# Compute losses per GPU
|
||||
gpu_pred_src_src_list = []
|
||||
gpu_pred_dst_dst_list = []
|
||||
gpu_pred_src_dst_list = []
|
||||
gpu_pred_src_srcm_list = []
|
||||
gpu_pred_dst_dstm_list = []
|
||||
gpu_pred_src_dstm_list = []
|
||||
|
||||
gpu_src_losses = []
|
||||
gpu_dst_losses = []
|
||||
gpu_G_loss_gradients = []
|
||||
gpu_GAN_loss_gradients = []
|
||||
|
||||
def DLossOnes(logits):
|
||||
return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits), axis=[1,2,3])
|
||||
|
||||
def DLossZeros(logits):
|
||||
return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1,2,3])
|
||||
|
||||
for gpu_id in range(gpu_count):
|
||||
with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
|
||||
with tf.device(f'/CPU:0'):
|
||||
# slice on CPU, otherwise all batch data will be transfered to GPU first
|
||||
batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
|
||||
gpu_warped_src = self.warped_src [batch_slice,:,:,:]
|
||||
gpu_warped_dst = self.warped_dst [batch_slice,:,:,:]
|
||||
gpu_target_src = self.target_src [batch_slice,:,:,:]
|
||||
gpu_target_dst = self.target_dst [batch_slice,:,:,:]
|
||||
gpu_target_srcm = self.target_srcm[batch_slice,:,:,:]
|
||||
gpu_target_srcm_em = self.target_srcm_em[batch_slice,:,:,:]
|
||||
gpu_target_dstm = self.target_dstm[batch_slice,:,:,:]
|
||||
gpu_target_dstm_em = self.target_dstm_em[batch_slice,:,:,:]
|
||||
|
||||
# process model tensors
|
||||
gpu_src_code = self.encoder (gpu_warped_src)
|
||||
gpu_dst_code = self.encoder (gpu_warped_dst)
|
||||
|
||||
gpu_src_inter_src_code, gpu_src_inter_dst_code = self.inter_src (gpu_src_code), self.inter_dst (gpu_src_code)
|
||||
gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code)
|
||||
|
||||
inter_dims_bin = int(inter_dims*morph_factor)
|
||||
with tf.device(f'/CPU:0'):
|
||||
inter_rnd_binomial = tf.stack([tf.random.shuffle(tf.concat([tf.tile(tf.constant([1], tf.float32), ( inter_dims_bin, )),
|
||||
tf.tile(tf.constant([0], tf.float32), ( inter_dims-inter_dims_bin, ))], 0 )) for _ in range(bs_per_gpu)], 0)
|
||||
|
||||
inter_rnd_binomial = tf.stop_gradient(inter_rnd_binomial[...,None,None])
|
||||
|
||||
gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial)
|
||||
gpu_dst_code = gpu_dst_inter_dst_code
|
||||
|
||||
inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32)
|
||||
gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]),
|
||||
tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 )
|
||||
|
||||
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
|
||||
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
|
||||
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
|
||||
|
||||
gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
|
||||
gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
|
||||
gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)
|
||||
|
||||
gpu_target_srcm_anti = 1-gpu_target_srcm
|
||||
gpu_target_dstm_anti = 1-gpu_target_dstm
|
||||
|
||||
gpu_target_srcm_gblur = nn.gaussian_blur(gpu_target_srcm, resolution // 32)
|
||||
gpu_target_dstm_gblur = nn.gaussian_blur(gpu_target_dstm, resolution // 32)
|
||||
|
||||
gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_gblur, 0, 0.5) * 2
|
||||
gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_gblur, 0, 0.5) * 2
|
||||
gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur
|
||||
gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur
|
||||
|
||||
if blur_out_mask:
|
||||
sigma = resolution / 128
|
||||
|
||||
x = nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, sigma)
|
||||
y = 1-nn.gaussian_blur(gpu_target_srcm, sigma)
|
||||
y = tf.where(tf.equal(y, 0), tf.ones_like(y), y)
|
||||
gpu_target_src = gpu_target_src*gpu_target_srcm + (x/y)*gpu_target_srcm_anti
|
||||
|
||||
x = nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, sigma)
|
||||
y = 1-nn.gaussian_blur(gpu_target_dstm, sigma)
|
||||
y = tf.where(tf.equal(y, 0), tf.ones_like(y), y)
|
||||
gpu_target_dst = gpu_target_dst*gpu_target_dstm + (x/y)*gpu_target_dstm_anti
|
||||
|
||||
gpu_target_src_masked = gpu_target_src*gpu_target_srcm_blur
|
||||
gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
|
||||
gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur
|
||||
gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur
|
||||
|
||||
gpu_pred_src_src_masked = gpu_pred_src_src*gpu_target_srcm_blur
|
||||
gpu_pred_dst_dst_masked = gpu_pred_dst_dst*gpu_target_dstm_blur
|
||||
gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur
|
||||
gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur
|
||||
|
||||
# Structural loss
|
||||
gpu_src_loss = tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
|
||||
gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
|
||||
gpu_dst_loss = tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
|
||||
gpu_dst_loss += tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1])
|
||||
|
||||
# Pixel loss
|
||||
gpu_src_loss += tf.reduce_mean (10*tf.square(gpu_target_src_masked-gpu_pred_src_src_masked), axis=[1,2,3])
|
||||
gpu_dst_loss += tf.reduce_mean (10*tf.square(gpu_target_dst_masked-gpu_pred_dst_dst_masked), axis=[1,2,3])
|
||||
|
||||
# Eyes+mouth prio loss
|
||||
gpu_src_loss += tf.reduce_mean (300*tf.abs (gpu_target_src*gpu_target_srcm_em-gpu_pred_src_src*gpu_target_srcm_em), axis=[1,2,3])
|
||||
gpu_dst_loss += tf.reduce_mean (300*tf.abs (gpu_target_dst*gpu_target_dstm_em-gpu_pred_dst_dst*gpu_target_dstm_em), axis=[1,2,3])
|
||||
|
||||
# Mask loss
|
||||
gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
|
||||
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )
|
||||
|
||||
gpu_src_losses += [gpu_src_loss]
|
||||
gpu_dst_losses += [gpu_dst_loss]
|
||||
gpu_G_loss = gpu_src_loss + gpu_dst_loss
|
||||
# dst-dst background weak loss
|
||||
gpu_G_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] )
|
||||
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked)
|
||||
|
||||
|
||||
if gan_power != 0:
|
||||
gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked)
|
||||
gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked)
|
||||
gpu_target_src_d, gpu_target_src_d2 = self.GAN(gpu_target_src_masked)
|
||||
gpu_target_dst_d, gpu_target_dst_d2 = self.GAN(gpu_target_dst_masked)
|
||||
|
||||
gpu_GAN_loss = (DLossOnes (gpu_target_src_d) + DLossOnes (gpu_target_src_d2) + \
|
||||
DLossZeros(gpu_pred_src_src_d) + DLossZeros(gpu_pred_src_src_d2) + \
|
||||
DLossOnes (gpu_target_dst_d) + DLossOnes (gpu_target_dst_d2) + \
|
||||
DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2)
|
||||
) * (1.0 / 8)
|
||||
|
||||
gpu_GAN_loss_gradients += [ nn.gradients (gpu_GAN_loss, self.GAN.get_weights() ) ]
|
||||
|
||||
gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) + \
|
||||
DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2)
|
||||
) * gan_power
|
||||
|
||||
# Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
|
||||
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
|
||||
gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] )
|
||||
|
||||
gpu_G_loss_gradients += [ nn.gradients ( gpu_G_loss, self.G_weights ) ]
|
||||
|
||||
# Average losses and gradients, and create optimizer update ops
|
||||
with tf.device(f'/CPU:0'):
|
||||
pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
|
||||
pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
|
||||
pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
|
||||
pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
|
||||
pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
|
||||
pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)
|
||||
|
||||
with tf.device (models_opt_device):
|
||||
src_loss = tf.concat(gpu_src_losses, 0)
|
||||
dst_loss = tf.concat(gpu_dst_losses, 0)
|
||||
train_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gradients))
|
||||
|
||||
if gan_power != 0:
|
||||
GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_gradients) )
|
||||
|
||||
# Initializing training and view functions
|
||||
def train(warped_src, target_src, target_srcm, target_srcm_em, \
|
||||
warped_dst, target_dst, target_dstm, target_dstm_em, ):
|
||||
s, d, _ = nn.tf_sess.run ([src_loss, dst_loss, train_op],
|
||||
feed_dict={self.warped_src :warped_src,
|
||||
self.target_src :target_src,
|
||||
self.target_srcm:target_srcm,
|
||||
self.target_srcm_em:target_srcm_em,
|
||||
self.warped_dst :warped_dst,
|
||||
self.target_dst :target_dst,
|
||||
self.target_dstm:target_dstm,
|
||||
self.target_dstm_em:target_dstm_em,
|
||||
})
|
||||
return s, d
|
||||
self.train = train
|
||||
|
||||
if gan_power != 0:
|
||||
def GAN_train(warped_src, target_src, target_srcm, target_srcm_em, \
|
||||
warped_dst, target_dst, target_dstm, target_dstm_em, ):
|
||||
nn.tf_sess.run ([GAN_train_op], feed_dict={self.warped_src :warped_src,
|
||||
self.target_src :target_src,
|
||||
self.target_srcm:target_srcm,
|
||||
self.target_srcm_em:target_srcm_em,
|
||||
self.warped_dst :warped_dst,
|
||||
self.target_dst :target_dst,
|
||||
self.target_dstm:target_dstm,
|
||||
self.target_dstm_em:target_dstm_em})
|
||||
self.GAN_train = GAN_train
|
||||
|
||||
def AE_view(warped_src, warped_dst, morph_value):
|
||||
return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
|
||||
feed_dict={self.warped_src:warped_src, self.warped_dst:warped_dst, self.morph_value_t:[morph_value] })
|
||||
|
||||
self.AE_view = AE_view
|
||||
else:
|
||||
#Initializing merge function
|
||||
with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'):
|
||||
gpu_dst_code = self.encoder (self.warped_dst)
|
||||
gpu_dst_inter_src_code = self.inter_src (gpu_dst_code)
|
||||
gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code)
|
||||
|
||||
inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32)
|
||||
gpu_src_dst_code = tf.concat( ( tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]),
|
||||
tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 )
|
||||
|
||||
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
|
||||
_, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code)
|
||||
|
||||
def AE_merge(warped_dst, morph_value):
|
||||
return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst, self.morph_value_t:[morph_value] })
|
||||
|
||||
self.AE_merge = AE_merge
|
||||
|
||||
# Loading/initializing all models/optimizers weights
|
||||
for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
|
||||
do_init = self.is_first_run()
|
||||
if self.is_training and gan_power != 0 and model == self.GAN:
|
||||
if self.gan_model_changed:
|
||||
do_init = True
|
||||
if not do_init:
|
||||
do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
|
||||
if do_init:
|
||||
model.init_weights()
|
||||
###############
|
||||
|
||||
# initializing sample generators
|
||||
if self.is_training:
|
||||
training_data_src_path = self.training_data_src_path #if not self.pretrain else self.get_pretraining_data_path()
|
||||
training_data_dst_path = self.training_data_dst_path #if not self.pretrain else self.get_pretraining_data_path()
|
||||
|
||||
random_ct_samples_path=training_data_dst_path if ct_mode is not None else None #and not self.pretrain
|
||||
|
||||
cpu_count = multiprocessing.cpu_count()
|
||||
src_generators_count = cpu_count // 2
|
||||
dst_generators_count = cpu_count // 2
|
||||
if ct_mode is not None:
|
||||
src_generators_count = int(src_generators_count * 1.5)
|
||||
|
||||
|
||||
|
||||
self.set_training_data_generators ([
|
||||
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
|
||||
sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=self.random_src_flip),
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
],
|
||||
uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain,
|
||||
generators_count=src_generators_count ),
|
||||
|
||||
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
|
||||
sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=self.random_dst_flip),
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
],
|
||||
uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain,
|
||||
generators_count=dst_generators_count )
|
||||
])
|
||||
|
||||
def export_dfm (self):
|
||||
output_path=self.get_strpath_storage_for_file('model.dfm')
|
||||
|
||||
io.log_info(f'Dumping .dfm to {output_path}')
|
||||
|
||||
tf = nn.tf
|
||||
with tf.device (nn.tf_default_device_name):
|
||||
warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
|
||||
warped_dst = tf.transpose(warped_dst, (0,3,1,2))
|
||||
morph_value = tf.placeholder (nn.floatx, (1,), name='morph_value')
|
||||
|
||||
gpu_dst_code = self.encoder (warped_dst)
|
||||
gpu_dst_inter_src_code = self.inter_src ( gpu_dst_code)
|
||||
gpu_dst_inter_dst_code = self.inter_dst ( gpu_dst_code)
|
||||
|
||||
inter_dims_slice = tf.cast(self.inter_dims*morph_value[0], tf.int32)
|
||||
gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , self.inter_res, self.inter_res]),
|
||||
tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,self.inter_dims-inter_dims_slice, self.inter_res,self.inter_res]) ), 1 )
|
||||
|
||||
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
|
||||
_, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code)
|
||||
|
||||
gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1))
|
||||
gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1))
|
||||
gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1))
|
||||
|
||||
tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
|
||||
tf.identity(gpu_pred_src_dst, name='out_celeb_face')
|
||||
tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')
|
||||
|
||||
output_graph_def = tf.graph_util.convert_variables_to_constants(
|
||||
nn.tf_sess,
|
||||
tf.get_default_graph().as_graph_def(),
|
||||
['out_face_mask','out_celeb_face','out_celeb_face_mask']
|
||||
)
|
||||
|
||||
import tf2onnx
|
||||
with tf.device("/CPU:0"):
|
||||
model_proto, _ = tf2onnx.convert._convert_common(
|
||||
output_graph_def,
|
||||
name='AMP',
|
||||
input_names=['in_face:0','morph_value:0'],
|
||||
output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'],
|
||||
opset=12,
|
||||
output_path=output_path)
|
||||
|
||||
#override
|
||||
def get_model_filename_list(self):
|
||||
return self.model_filename_list
|
||||
|
||||
#override
|
||||
def onSave(self):
|
||||
for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False):
|
||||
model.save_weights ( self.get_strpath_storage_for_file(filename) )
|
||||
|
||||
#override
|
||||
def should_save_preview_history(self):
|
||||
return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \
|
||||
(io.is_colab() and self.iter % 100 == 0)
|
||||
|
||||
#override
|
||||
def onTrainOneIter(self):
|
||||
bs = self.get_batch_size()
|
||||
|
||||
( (warped_src, target_src, target_srcm, target_srcm_em), \
|
||||
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()
|
||||
|
||||
src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
|
||||
|
||||
if self.gan_power != 0:
|
||||
self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
|
||||
|
||||
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
|
||||
|
||||
#override
|
||||
def onGetPreview(self, samples, for_history=False):
|
||||
( (warped_src, target_src, target_srcm, target_srcm_em),
|
||||
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples
|
||||
|
||||
S, D, SS, DD, DDM_000, _, _ = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst, 0.0) ) ]
|
||||
|
||||
_, _, DDM_025, SD_025, SDM_025 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.25) ]
|
||||
_, _, DDM_050, SD_050, SDM_050 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.50) ]
|
||||
_, _, DDM_065, SD_065, SDM_065 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.65) ]
|
||||
_, _, DDM_075, SD_075, SDM_075 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.75) ]
|
||||
_, _, DDM_100, SD_100, SDM_100 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 1.00) ]
|
||||
|
||||
(DDM_000,
|
||||
DDM_025, SDM_025,
|
||||
DDM_050, SDM_050,
|
||||
DDM_065, SDM_065,
|
||||
DDM_075, SDM_075,
|
||||
DDM_100, SDM_100) = [ np.repeat (x, (3,), -1) for x in (DDM_000,
|
||||
DDM_025, SDM_025,
|
||||
DDM_050, SDM_050,
|
||||
DDM_065, SDM_065,
|
||||
DDM_075, SDM_075,
|
||||
DDM_100, SDM_100) ]
|
||||
|
||||
target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )]
|
||||
|
||||
n_samples = min(4, self.get_batch_size(), 800 // self.resolution )
|
||||
|
||||
result = []
|
||||
|
||||
i = np.random.randint(n_samples) if not for_history else 0
|
||||
|
||||
st = [ np.concatenate ((S[i], D[i], DD[i]*DDM_000[i]), axis=1) ]
|
||||
st += [ np.concatenate ((SS[i], DD[i], SD_100[i] ), axis=1) ]
|
||||
|
||||
result += [ ('AMP morph 1.0', np.concatenate (st, axis=0 )), ]
|
||||
|
||||
st = [ np.concatenate ((DD[i], SD_025[i], SD_050[i]), axis=1) ]
|
||||
st += [ np.concatenate ((SD_065[i], SD_075[i], SD_100[i]), axis=1) ]
|
||||
result += [ ('AMP morph list', np.concatenate (st, axis=0 )), ]
|
||||
|
||||
st = [ np.concatenate ((DD[i], SD_025[i]*DDM_025[i]*SDM_025[i], SD_050[i]*DDM_050[i]*SDM_050[i]), axis=1) ]
|
||||
st += [ np.concatenate ((SD_065[i]*DDM_065[i]*SDM_065[i], SD_075[i]*DDM_075[i]*SDM_075[i], SD_100[i]*DDM_100[i]*SDM_100[i]), axis=1) ]
|
||||
result += [ ('AMP morph list masked', np.concatenate (st, axis=0 )), ]
|
||||
|
||||
return result
|
||||
|
||||
def predictor_func (self, face, morph_value):
|
||||
face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC")
|
||||
|
||||
bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face, morph_value) ]
|
||||
|
||||
return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0]
|
||||
|
||||
#override
|
||||
def get_MergerConfig(self):
|
||||
morph_factor = np.clip ( io.input_number ("Morph factor", 1.0, add_info="0.0 .. 1.0"), 0.0, 1.0 )
|
||||
|
||||
def predictor_morph(face):
|
||||
return self.predictor_func(face, morph_factor)
|
||||
|
||||
|
||||
import merger
|
||||
return predictor_morph, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay')
|
||||
|
||||
Model = AMPModel
|
|
@ -1 +0,0 @@
|
|||
from .Model import Model
|
|
@ -22,16 +22,15 @@ class QModel(ModelBase):
|
|||
resolution = self.resolution = 96
|
||||
self.face_type = FaceType.FULL
|
||||
ae_dims = 128
|
||||
e_dims = 64
|
||||
e_dims = 128
|
||||
d_dims = 64
|
||||
d_mask_dims = 16
|
||||
self.pretrain = False
|
||||
self.pretrain_just_disabled = False
|
||||
|
||||
masked_training = True
|
||||
|
||||
models_opt_on_gpu = len(devices) >= 1 and all([dev.total_mem_gb >= 4 for dev in devices])
|
||||
models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0'
|
||||
models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
|
||||
optimizer_vars_on_cpu = models_opt_device=='/CPU:0'
|
||||
|
||||
input_ch = 3
|
||||
|
@ -40,7 +39,7 @@ class QModel(ModelBase):
|
|||
|
||||
self.model_filename_list = []
|
||||
|
||||
model_archi = nn.DeepFakeArchi(resolution, opts='ud')
|
||||
model_archi = nn.DeepFakeArchi(resolution, mod='quick')
|
||||
|
||||
with tf.device ('/CPU:0'):
|
||||
#Place holders on CPU
|
||||
|
@ -56,13 +55,13 @@ class QModel(ModelBase):
|
|||
# Initializing model classes
|
||||
with tf.device (models_opt_device):
|
||||
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
|
||||
encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2
|
||||
encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))
|
||||
|
||||
self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
|
||||
inter_out_ch = self.inter.get_out_ch()
|
||||
self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, d_ch=d_dims, name='inter')
|
||||
inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
|
||||
|
||||
self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src')
|
||||
self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst')
|
||||
self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_src')
|
||||
self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_dst')
|
||||
|
||||
self.model_filename_list += [ [self.encoder, 'encoder.npy' ],
|
||||
[self.inter, 'inter.npy' ],
|
||||
|
@ -96,7 +95,7 @@ class QModel(ModelBase):
|
|||
gpu_src_dst_loss_gvs = []
|
||||
|
||||
for gpu_id in range(gpu_count):
|
||||
with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
|
||||
with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
|
||||
batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
|
||||
with tf.device(f'/CPU:0'):
|
||||
# slice on CPU, otherwise all batch data will be transfered to GPU first
|
||||
|
@ -190,7 +189,7 @@ class QModel(ModelBase):
|
|||
self.AE_view = AE_view
|
||||
else:
|
||||
# Initializing merge function
|
||||
with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'):
|
||||
with tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
|
||||
gpu_dst_code = self.inter(self.encoder(self.warped_dst))
|
||||
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
|
||||
_, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
|
||||
|
@ -278,7 +277,7 @@ class QModel(ModelBase):
|
|||
return ( ('src_loss', src_loss), ('dst_loss', dst_loss), )
|
||||
|
||||
#override
|
||||
def onGetPreview(self, samples, for_history=False):
|
||||
def onGetPreview(self, samples):
|
||||
( (warped_src, target_src, target_srcm),
|
||||
(warped_dst, target_dst, target_dstm) ) = samples
|
||||
|
||||
|
|
|
@ -27,33 +27,20 @@ class SAEHDModel(ModelBase):
|
|||
suggest_batch_size = 4
|
||||
|
||||
yn_str = {True:'y',False:'n'}
|
||||
min_res = 64
|
||||
max_res = 640
|
||||
|
||||
#default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False)
|
||||
default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 128)
|
||||
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f')
|
||||
default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
|
||||
|
||||
default_archi = self.options['archi'] = self.load_or_def_option('archi', 'liae-ud')
|
||||
|
||||
default_archi = self.options['archi'] = self.load_or_def_option('archi', 'df')
|
||||
default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
|
||||
default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
|
||||
default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None)
|
||||
default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None)
|
||||
default_masked_training = self.options['masked_training'] = self.load_or_def_option('masked_training', True)
|
||||
default_eyes_mouth_prio = self.options['eyes_mouth_prio'] = self.load_or_def_option('eyes_mouth_prio', False)
|
||||
default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False)
|
||||
default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False)
|
||||
|
||||
default_adabelief = self.options['adabelief'] = self.load_or_def_option('adabelief', True)
|
||||
|
||||
lr_dropout = self.load_or_def_option('lr_dropout', 'n')
|
||||
lr_dropout = {True:'y', False:'n'}.get(lr_dropout, lr_dropout) #backward comp
|
||||
default_lr_dropout = self.options['lr_dropout'] = lr_dropout
|
||||
|
||||
default_eyes_prio = self.options['eyes_prio'] = self.load_or_def_option('eyes_prio', False)
|
||||
default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', False)
|
||||
default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True)
|
||||
default_random_hsv_power = self.options['random_hsv_power'] = self.load_or_def_option('random_hsv_power', 0.0)
|
||||
default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
|
||||
default_true_face_power = self.options['true_face_power'] = self.load_or_def_option('true_face_power', 0.0)
|
||||
default_face_style_power = self.options['face_style_power'] = self.load_or_def_option('face_style_power', 0.0)
|
||||
default_bg_style_power = self.options['bg_style_power'] = self.load_or_def_option('bg_style_power', 0.0)
|
||||
|
@ -66,55 +53,18 @@ class SAEHDModel(ModelBase):
|
|||
self.ask_autobackup_hour()
|
||||
self.ask_write_preview_history()
|
||||
self.ask_target_iter()
|
||||
self.ask_random_src_flip()
|
||||
self.ask_random_dst_flip()
|
||||
self.ask_random_flip()
|
||||
self.ask_batch_size(suggest_batch_size)
|
||||
#self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters.')
|
||||
|
||||
if self.is_first_run():
|
||||
resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16 and 32 for -d archi.")
|
||||
resolution = np.clip ( (resolution // 16) * 16, min_res, max_res)
|
||||
resolution = io.input_int("Resolution", default_resolution, add_info="64-512", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.")
|
||||
resolution = np.clip ( (resolution // 16) * 16, 64, 512)
|
||||
self.options['resolution'] = resolution
|
||||
|
||||
|
||||
|
||||
self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face. 'Whole face' covers full area of face include forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower()
|
||||
self.options['archi'] = io.input_str ("AE architecture", default_archi, ['df','liae','dfhd','liaehd','dfuhd','liaeuhd'], help_message="'df' keeps faces more natural.\n'liae' can fix overly different face shapes.\n'hd' are experimental versions.").lower()
|
||||
|
||||
while True:
|
||||
archi = io.input_str ("AE architecture", default_archi, help_message=\
|
||||
"""
|
||||
'df' keeps more identity-preserved face.
|
||||
'liae' can fix overly different face shapes.
|
||||
'-u' increased likeness of the face.
|
||||
'-d' (experimental) doubling the resolution using the same computation cost.
|
||||
Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||
""").lower()
|
||||
|
||||
archi_split = archi.split('-')
|
||||
|
||||
if len(archi_split) == 2:
|
||||
archi_type, archi_opts = archi_split
|
||||
elif len(archi_split) == 1:
|
||||
archi_type, archi_opts = archi_split[0], None
|
||||
else:
|
||||
continue
|
||||
|
||||
if archi_type not in ['df', 'liae']:
|
||||
continue
|
||||
|
||||
if archi_opts is not None:
|
||||
if len(archi_opts) == 0:
|
||||
continue
|
||||
if len([ 1 for opt in archi_opts if opt not in ['u','d','t','c'] ]) != 0:
|
||||
continue
|
||||
|
||||
if 'd' in archi_opts:
|
||||
self.options['resolution'] = np.clip ( (self.options['resolution'] // 32) * 32, min_res, max_res)
|
||||
|
||||
break
|
||||
self.options['archi'] = archi
|
||||
|
||||
default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64)
|
||||
default_d_dims = 48 if self.options['archi'] == 'dfhd' else 64
|
||||
default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', default_d_dims)
|
||||
|
||||
default_d_mask_dims = default_d_dims // 3
|
||||
default_d_mask_dims += default_d_mask_dims % 2
|
||||
|
@ -126,6 +76,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
|
||||
self.options['e_dims'] = e_dims + e_dims % 2
|
||||
|
||||
|
||||
d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
|
||||
self.options['d_dims'] = d_dims + d_dims % 2
|
||||
|
||||
|
@ -134,35 +85,17 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
|
||||
if self.is_first_run() or ask_override:
|
||||
if self.options['face_type'] == 'wf' or self.options['face_type'] == 'head':
|
||||
self.options['masked_training'] = io.input_bool ("Masked training", default_masked_training, help_message="This option is available only for 'whole_face' or 'head' type. Masked training clips training area to full_face mask or XSeg mask, thus network will train the faces properly.")
|
||||
self.options['masked_training'] = io.input_bool ("Masked training", default_masked_training, help_message="This option is available only for 'whole_face' type. Masked training clips training area to full_face mask, thus network will train the faces properly. When the face is trained enough, disable this option to train all area of the frame. Merge with 'raw-rgb' mode, then use Adobe After Effects to manually mask and compose whole face include forehead.")
|
||||
|
||||
self.options['eyes_mouth_prio'] = io.input_bool ("Eyes and mouth priority", default_eyes_mouth_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction. Also makes the detail of the teeth higher.')
|
||||
self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')
|
||||
self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. The exact xseg mask in src and dst faceset is required.')
|
||||
|
||||
default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
|
||||
default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8)
|
||||
default_gan_dims = self.options['gan_dims'] = self.load_or_def_option('gan_dims', 16)
|
||||
self.options['eyes_prio'] = io.input_bool ("Eyes priority", default_eyes_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction ( especially on HD architectures ) by forcing the neural network to train eyes with higher priority. before/after https://i.imgur.com/YQHOuSR.jpg ')
|
||||
|
||||
if self.is_first_run() or ask_override:
|
||||
self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.")
|
||||
|
||||
self.options['adabelief'] = io.input_bool ("Use AdaBelief optimizer?", default_adabelief, help_message="Use AdaBelief optimizer. It requires more VRAM, but the accuracy and the generalization of the model is higher.")
|
||||
|
||||
self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.")
|
||||
|
||||
self.options['lr_dropout'] = io.input_bool ("Use learning rate dropout", default_lr_dropout, help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations.")
|
||||
self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.")
|
||||
|
||||
self.options['random_hsv_power'] = np.clip ( io.input_number ("Random hue/saturation/light intensity", default_random_hsv_power, add_info="0.0 .. 0.3", help_message="Random hue/saturation/light intensity applied to the src face set only at the input of the neural network. Stabilizes color perturbations during face swapping. Reduces the quality of the color transfer by selecting the closest one in the src faceset. Thus the src faceset must be diverse enough. Typical fine value is 0.05"), 0.0, 0.3 )
|
||||
|
||||
self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 5.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with lr_dropout(on) and random_warp(off), and don't disable. The higher the value, the higher the chances of artifacts. Typical fine value is 0.1"), 0.0, 5.0 )
|
||||
|
||||
if self.options['gan_power'] != 0.0:
|
||||
gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher patch size, the higher the quality, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is resolution / 8." ), 3, 640 )
|
||||
self.options['gan_patch_size'] = gan_patch_size
|
||||
|
||||
gan_dims = np.clip ( io.input_int("GAN dimensions", default_gan_dims, add_info="4-512", help_message="The dimensions of the GAN network. The higher dimensions, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is 16." ), 4, 512 )
|
||||
self.options['gan_dims'] = gan_dims
|
||||
self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 10.0", help_message="Train the network in Generative Adversarial manner. Accelerates the speed of training. Forces the neural network to learn small details of the face. You can enable/disable this option at any time. Typical value is 1.0"), 0.0, 10.0 )
|
||||
|
||||
if 'df' in self.options['archi']:
|
||||
self.options['true_face_power'] = np.clip ( io.input_number ("'True face' power.", default_true_face_power, add_info="0.0000 .. 1.0", help_message="Experimental option. Discriminates result face to be more like src face. Higher value - stronger discrimination. Typical value is 0.01 . Comparison - https://i.imgur.com/czScS9q.png"), 0.0, 1.0 )
|
||||
|
@ -175,13 +108,11 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best.")
|
||||
self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
|
||||
|
||||
self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly. Forces random_warp=N, random_flips=Y, gan_power=0.0, lr_dropout=N, styles=0.0, uniform_yaw=Y")
|
||||
self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly.")
|
||||
|
||||
if self.options['pretrain'] and self.get_pretraining_data_path() is None:
|
||||
raise Exception("pretraining_data_path is not defined")
|
||||
|
||||
self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims'])
|
||||
|
||||
self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
|
||||
|
||||
#override
|
||||
|
@ -199,20 +130,9 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
'wf' : FaceType.WHOLE_FACE,
|
||||
'head' : FaceType.HEAD}[ self.options['face_type'] ]
|
||||
|
||||
if 'eyes_prio' in self.options:
|
||||
self.options.pop('eyes_prio')
|
||||
|
||||
eyes_mouth_prio = self.options['eyes_mouth_prio']
|
||||
|
||||
archi_split = self.options['archi'].split('-')
|
||||
|
||||
if len(archi_split) == 2:
|
||||
archi_type, archi_opts = archi_split
|
||||
elif len(archi_split) == 1:
|
||||
archi_type, archi_opts = archi_split[0], None
|
||||
|
||||
self.archi_type = archi_type
|
||||
|
||||
eyes_prio = self.options['eyes_prio']
|
||||
archi = self.options['archi']
|
||||
is_hd = 'hd' in archi
|
||||
ae_dims = self.options['ae_dims']
|
||||
e_dims = self.options['e_dims']
|
||||
d_dims = self.options['d_dims']
|
||||
|
@ -221,69 +141,46 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
if self.pretrain_just_disabled:
|
||||
self.set_iter(0)
|
||||
|
||||
adabelief = self.options['adabelief']
|
||||
|
||||
use_fp16 = False
|
||||
if self.is_exporting:
|
||||
use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')
|
||||
|
||||
self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power']
|
||||
random_warp = False if self.pretrain else self.options['random_warp']
|
||||
random_src_flip = self.random_src_flip if not self.pretrain else True
|
||||
random_dst_flip = self.random_dst_flip if not self.pretrain else True
|
||||
random_hsv_power = self.options['random_hsv_power'] if not self.pretrain else 0.0
|
||||
blur_out_mask = self.options['blur_out_mask']
|
||||
|
||||
if self.pretrain:
|
||||
self.options_show_override['lr_dropout'] = 'n'
|
||||
self.options_show_override['random_warp'] = False
|
||||
self.options_show_override['gan_power'] = 0.0
|
||||
self.options_show_override['random_hsv_power'] = 0.0
|
||||
self.options_show_override['face_style_power'] = 0.0
|
||||
self.options_show_override['bg_style_power'] = 0.0
|
||||
self.options_show_override['uniform_yaw'] = True
|
||||
self.gan_power = gan_power = self.options['gan_power'] if not self.pretrain else 0.0
|
||||
|
||||
masked_training = self.options['masked_training']
|
||||
ct_mode = self.options['ct_mode']
|
||||
if ct_mode == 'none':
|
||||
ct_mode = None
|
||||
|
||||
|
||||
models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
|
||||
models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0'
|
||||
models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
|
||||
optimizer_vars_on_cpu = models_opt_device=='/CPU:0'
|
||||
|
||||
input_ch=3
|
||||
bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
|
||||
bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
|
||||
mask_shape = nn.get4Dshape(resolution,resolution,1)
|
||||
self.model_filename_list = []
|
||||
|
||||
with tf.device ('/CPU:0'):
|
||||
#Place holders on CPU
|
||||
self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src')
|
||||
self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst')
|
||||
self.warped_src = tf.placeholder (nn.floatx, bgr_shape)
|
||||
self.warped_dst = tf.placeholder (nn.floatx, bgr_shape)
|
||||
|
||||
self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src')
|
||||
self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst')
|
||||
self.target_src = tf.placeholder (nn.floatx, bgr_shape)
|
||||
self.target_dst = tf.placeholder (nn.floatx, bgr_shape)
|
||||
|
||||
self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm')
|
||||
self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em')
|
||||
self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm')
|
||||
self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em')
|
||||
self.target_srcm_all = tf.placeholder (nn.floatx, mask_shape)
|
||||
self.target_dstm_all = tf.placeholder (nn.floatx, mask_shape)
|
||||
|
||||
# Initializing model classes
|
||||
model_archi = nn.DeepFakeArchi(resolution, use_fp16=use_fp16, opts=archi_opts)
|
||||
model_archi = nn.DeepFakeArchi(resolution, mod='uhd' if 'uhd' in archi else None)
|
||||
|
||||
with tf.device (models_opt_device):
|
||||
if 'df' in archi_type:
|
||||
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
|
||||
encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2
|
||||
if 'df' in archi:
|
||||
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder')
|
||||
encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))
|
||||
|
||||
self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
|
||||
inter_out_ch = self.inter.get_out_ch()
|
||||
self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, is_hd=is_hd, name='inter')
|
||||
inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
|
||||
|
||||
self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src')
|
||||
self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst')
|
||||
self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_src')
|
||||
self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_dst')
|
||||
|
||||
self.model_filename_list += [ [self.encoder, 'encoder.npy' ],
|
||||
[self.inter, 'inter.npy' ],
|
||||
|
@ -292,19 +189,20 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
|
||||
if self.is_training:
|
||||
if self.options['true_face_power'] != 0:
|
||||
self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=self.inter.get_out_res(), name='dis' )
|
||||
self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=model_archi.Inter.get_code_res()*2, name='dis' )
|
||||
self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ]
|
||||
|
||||
elif 'liae' in archi_type:
|
||||
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
|
||||
encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2
|
||||
elif 'liae' in archi:
|
||||
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder')
|
||||
encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))
|
||||
|
||||
self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_AB')
|
||||
self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_B')
|
||||
self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_AB')
|
||||
self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_B')
|
||||
|
||||
inter_out_ch = self.inter_AB.get_out_ch()
|
||||
inters_out_ch = inter_out_ch*2
|
||||
self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder')
|
||||
inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
|
||||
inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
|
||||
inters_out_ch = inter_AB_out_ch+inter_B_out_ch
|
||||
self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder')
|
||||
|
||||
self.model_filename_list += [ [self.encoder, 'encoder.npy'],
|
||||
[self.inter_AB, 'inter_AB.npy'],
|
||||
|
@ -313,43 +211,33 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
|
||||
if self.is_training:
|
||||
if gan_power != 0:
|
||||
self.D_src = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="D_src")
|
||||
self.model_filename_list += [ [self.D_src, 'GAN.npy'] ]
|
||||
self.D_src = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_src")
|
||||
self.D_dst = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_dst")
|
||||
self.model_filename_list += [ [self.D_src, 'D_src.npy'] ]
|
||||
self.model_filename_list += [ [self.D_dst, 'D_dst.npy'] ]
|
||||
|
||||
# Initialize optimizers
|
||||
lr=5e-5
|
||||
if self.options['lr_dropout'] in ['y','cpu'] and not self.pretrain:
|
||||
lr_cos = 500
|
||||
lr_dropout = 0.3
|
||||
else:
|
||||
lr_cos = 0
|
||||
lr_dropout = 1.0
|
||||
OptimizerClass = nn.AdaBelief if adabelief else nn.RMSprop
|
||||
lr_dropout = 0.3 if self.options['lr_dropout'] and not self.pretrain else 1.0
|
||||
clipnorm = 1.0 if self.options['clipgrad'] else 0.0
|
||||
|
||||
if 'df' in archi_type:
|
||||
self.src_dst_saveable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
|
||||
self.src_dst_trainable_weights = self.src_dst_saveable_weights
|
||||
elif 'liae' in archi_type:
|
||||
self.src_dst_saveable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()
|
||||
if random_warp:
|
||||
self.src_dst_trainable_weights = self.src_dst_saveable_weights
|
||||
else:
|
||||
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()
|
||||
|
||||
self.src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='src_dst_opt')
|
||||
self.src_dst_opt.initialize_variables (self.src_dst_saveable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')
|
||||
self.src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
|
||||
self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
|
||||
if 'df' in archi:
|
||||
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
|
||||
elif 'liae' in archi:
|
||||
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()
|
||||
|
||||
self.src_dst_opt.initialize_variables (self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu)
|
||||
|
||||
if self.options['true_face_power'] != 0:
|
||||
self.D_code_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='D_code_opt')
|
||||
self.D_code_opt.initialize_variables ( self.code_discriminator.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')
|
||||
self.D_code_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_code_opt')
|
||||
self.D_code_opt.initialize_variables ( self.code_discriminator.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
|
||||
self.model_filename_list += [ (self.D_code_opt, 'D_code_opt.npy') ]
|
||||
|
||||
if gan_power != 0:
|
||||
self.D_src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='GAN_opt')
|
||||
self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')#+self.D_src_x2.get_weights()
|
||||
self.model_filename_list += [ (self.D_src_dst_opt, 'GAN_opt.npy') ]
|
||||
self.D_src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_src_dst_opt')
|
||||
self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights()+self.D_dst.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
|
||||
self.model_filename_list += [ (self.D_src_dst_opt, 'D_src_dst_opt.npy') ]
|
||||
|
||||
if self.is_training:
|
||||
# Adjust batch size for multiple GPU
|
||||
|
@ -357,6 +245,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
|
||||
self.set_batch_size( gpu_count*bs_per_gpu)
|
||||
|
||||
|
||||
# Compute losses per GPU
|
||||
gpu_pred_src_src_list = []
|
||||
gpu_pred_dst_dst_list = []
|
||||
|
@ -370,9 +259,9 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
gpu_G_loss_gvs = []
|
||||
gpu_D_code_loss_gvs = []
|
||||
gpu_D_src_dst_loss_gvs = []
|
||||
|
||||
for gpu_id in range(gpu_count):
|
||||
with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
|
||||
with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
|
||||
|
||||
with tf.device(f'/CPU:0'):
|
||||
# slice on CPU, otherwise all batch data will be transfered to GPU first
|
||||
batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
|
||||
|
@ -380,38 +269,18 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
gpu_warped_dst = self.warped_dst [batch_slice,:,:,:]
|
||||
gpu_target_src = self.target_src [batch_slice,:,:,:]
|
||||
gpu_target_dst = self.target_dst [batch_slice,:,:,:]
|
||||
gpu_target_srcm = self.target_srcm[batch_slice,:,:,:]
|
||||
gpu_target_srcm_em = self.target_srcm_em[batch_slice,:,:,:]
|
||||
gpu_target_dstm = self.target_dstm[batch_slice,:,:,:]
|
||||
gpu_target_dstm_em = self.target_dstm_em[batch_slice,:,:,:]
|
||||
|
||||
gpu_target_srcm_anti = 1-gpu_target_srcm
|
||||
gpu_target_dstm_anti = 1-gpu_target_dstm
|
||||
|
||||
if blur_out_mask:
|
||||
sigma = resolution / 128
|
||||
|
||||
x = nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, sigma)
|
||||
y = 1-nn.gaussian_blur(gpu_target_srcm, sigma)
|
||||
y = tf.where(tf.equal(y, 0), tf.ones_like(y), y)
|
||||
gpu_target_src = gpu_target_src*gpu_target_srcm + (x/y)*gpu_target_srcm_anti
|
||||
|
||||
x = nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, sigma)
|
||||
y = 1-nn.gaussian_blur(gpu_target_dstm, sigma)
|
||||
y = tf.where(tf.equal(y, 0), tf.ones_like(y), y)
|
||||
gpu_target_dst = gpu_target_dst*gpu_target_dstm + (x/y)*gpu_target_dstm_anti
|
||||
|
||||
gpu_target_srcm_all = self.target_srcm_all[batch_slice,:,:,:]
|
||||
gpu_target_dstm_all = self.target_dstm_all[batch_slice,:,:,:]
|
||||
|
||||
# process model tensors
|
||||
if 'df' in archi_type:
|
||||
if 'df' in archi:
|
||||
gpu_src_code = self.inter(self.encoder(gpu_warped_src))
|
||||
gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
|
||||
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
|
||||
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
|
||||
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
|
||||
gpu_pred_src_dst_no_code_grad, _ = self.decoder_src(tf.stop_gradient(gpu_dst_code))
|
||||
|
||||
elif 'liae' in archi_type:
|
||||
elif 'liae' in archi:
|
||||
gpu_src_code = self.encoder (gpu_warped_src)
|
||||
gpu_src_inter_AB_code = self.inter_AB (gpu_src_code)
|
||||
gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis )
|
||||
|
@ -424,7 +293,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
|
||||
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
|
||||
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
|
||||
gpu_pred_src_dst_no_code_grad, _ = self.decoder(tf.stop_gradient(gpu_src_dst_code))
|
||||
|
||||
gpu_pred_src_src_list.append(gpu_pred_src_src)
|
||||
gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
|
||||
|
@ -434,60 +302,49 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
|
||||
gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)
|
||||
|
||||
# unpack masks from one combined mask
|
||||
gpu_target_srcm = tf.clip_by_value (gpu_target_srcm_all, 0, 1)
|
||||
gpu_target_dstm = tf.clip_by_value (gpu_target_dstm_all, 0, 1)
|
||||
gpu_target_srcm_eyes = tf.clip_by_value (gpu_target_srcm_all-1, 0, 1)
|
||||
gpu_target_dstm_eyes = tf.clip_by_value (gpu_target_dstm_all-1, 0, 1)
|
||||
|
||||
gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
|
||||
gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2
|
||||
gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur
|
||||
|
||||
gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )
|
||||
gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2
|
||||
|
||||
gpu_style_mask_blur = nn.gaussian_blur(gpu_pred_src_dstm*gpu_pred_dst_dstm, max(1, resolution // 32) )
|
||||
gpu_style_mask_blur = tf.stop_gradient(tf.clip_by_value(gpu_target_srcm_blur, 0, 1.0))
|
||||
gpu_style_mask_anti_blur = 1.0 - gpu_style_mask_blur
|
||||
|
||||
gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
|
||||
|
||||
gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur
|
||||
gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur
|
||||
gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
|
||||
gpu_target_dst_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_blur)
|
||||
|
||||
gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
|
||||
gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst
|
||||
|
||||
gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
|
||||
gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst
|
||||
|
||||
if resolution < 256:
|
||||
gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
|
||||
else:
|
||||
gpu_src_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
|
||||
gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
|
||||
gpu_psd_target_dst_masked = gpu_pred_src_dst*gpu_target_dstm_blur
|
||||
gpu_psd_target_dst_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_blur)
|
||||
|
||||
gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
|
||||
gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])
|
||||
|
||||
if eyes_mouth_prio:
|
||||
gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_em - gpu_pred_src_src*gpu_target_srcm_em ), axis=[1,2,3])
|
||||
if eyes_prio:
|
||||
gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_eyes - gpu_pred_src_src*gpu_target_srcm_eyes ), axis=[1,2,3])
|
||||
|
||||
gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
|
||||
|
||||
face_style_power = self.options['face_style_power'] / 100.0
|
||||
if face_style_power != 0 and not self.pretrain:
|
||||
gpu_src_loss += nn.style_loss(gpu_pred_src_dst_no_code_grad*tf.stop_gradient(gpu_pred_src_dstm), tf.stop_gradient(gpu_pred_dst_dst*gpu_pred_dst_dstm), gaussian_blur_radius=resolution//8, loss_weight=10000*face_style_power)
|
||||
gpu_src_loss += nn.style_loss(gpu_psd_target_dst_masked, gpu_target_dst_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power)
|
||||
|
||||
bg_style_power = self.options['bg_style_power'] / 100.0
|
||||
if bg_style_power != 0 and not self.pretrain:
|
||||
gpu_target_dst_style_anti_masked = gpu_target_dst*gpu_style_mask_anti_blur
|
||||
gpu_psd_style_anti_masked = gpu_pred_src_dst*gpu_style_mask_anti_blur
|
||||
gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim(gpu_psd_target_dst_anti_masked, gpu_target_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
|
||||
gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square( gpu_psd_target_dst_anti_masked - gpu_target_dst_anti_masked), axis=[1,2,3] )
|
||||
|
||||
gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim( gpu_psd_style_anti_masked, gpu_target_dst_style_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
|
||||
gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square(gpu_psd_style_anti_masked - gpu_target_dst_style_anti_masked), axis=[1,2,3] )
|
||||
|
||||
if resolution < 256:
|
||||
gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
|
||||
else:
|
||||
gpu_dst_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
|
||||
gpu_dst_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1])
|
||||
gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
|
||||
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])
|
||||
|
||||
if eyes_mouth_prio:
|
||||
gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_em - gpu_pred_dst_dst*gpu_target_dstm_em ), axis=[1,2,3])
|
||||
if eyes_prio:
|
||||
gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_eyes - gpu_pred_dst_dst*gpu_target_dstm_eyes ), axis=[1,2,3])
|
||||
|
||||
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )
|
||||
|
||||
|
@ -508,49 +365,38 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
|
||||
gpu_G_loss += self.options['true_face_power']*DLoss(gpu_src_code_d_ones, gpu_src_code_d)
|
||||
|
||||
gpu_D_code_loss = (DLoss(gpu_dst_code_d_ones , gpu_dst_code_d) + \
|
||||
gpu_D_code_loss = (DLoss(gpu_src_code_d_ones , gpu_dst_code_d) + \
|
||||
DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5
|
||||
|
||||
gpu_D_code_loss_gvs += [ nn.gradients (gpu_D_code_loss, self.code_discriminator.get_weights() ) ]
|
||||
|
||||
if gan_power != 0:
|
||||
gpu_pred_src_src_d, \
|
||||
gpu_pred_src_src_d2 = self.D_src(gpu_pred_src_src_masked_opt)
|
||||
|
||||
gpu_pred_src_src_d = self.D_src(gpu_pred_src_src_masked_opt)
|
||||
gpu_pred_src_src_d_ones = tf.ones_like (gpu_pred_src_src_d)
|
||||
gpu_pred_src_src_d_zeros = tf.zeros_like(gpu_pred_src_src_d)
|
||||
|
||||
gpu_pred_src_src_d2_ones = tf.ones_like (gpu_pred_src_src_d2)
|
||||
gpu_pred_src_src_d2_zeros = tf.zeros_like(gpu_pred_src_src_d2)
|
||||
|
||||
gpu_target_src_d, \
|
||||
gpu_target_src_d2 = self.D_src(gpu_target_src_masked_opt)
|
||||
|
||||
gpu_target_src_d = self.D_src(gpu_target_src_masked_opt)
|
||||
gpu_target_src_d_ones = tf.ones_like(gpu_target_src_d)
|
||||
gpu_target_src_d2_ones = tf.ones_like(gpu_target_src_d2)
|
||||
gpu_pred_dst_dst_d = self.D_dst(gpu_pred_dst_dst_masked_opt)
|
||||
gpu_pred_dst_dst_d_ones = tf.ones_like (gpu_pred_dst_dst_d)
|
||||
gpu_pred_dst_dst_d_zeros = tf.zeros_like(gpu_pred_dst_dst_d)
|
||||
gpu_target_dst_d = self.D_dst(gpu_target_dst_masked_opt)
|
||||
gpu_target_dst_d_ones = tf.ones_like(gpu_target_dst_d)
|
||||
|
||||
gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones , gpu_target_src_d) + \
|
||||
DLoss(gpu_pred_src_src_d_zeros , gpu_pred_src_src_d) ) * 0.5 + \
|
||||
(DLoss(gpu_target_src_d2_ones , gpu_target_src_d2) + \
|
||||
DLoss(gpu_pred_src_src_d2_zeros , gpu_pred_src_src_d2) ) * 0.5
|
||||
gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones , gpu_target_src_d) + \
|
||||
DLoss(gpu_pred_src_src_d_zeros, gpu_pred_src_src_d) ) * 0.5 + \
|
||||
(DLoss(gpu_target_dst_d_ones , gpu_target_dst_d) + \
|
||||
DLoss(gpu_pred_dst_dst_d_zeros, gpu_pred_dst_dst_d) ) * 0.5
|
||||
|
||||
gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.D_src.get_weights() ) ]#+self.D_src_x2.get_weights()
|
||||
gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.D_src.get_weights()+self.D_dst.get_weights() ) ]
|
||||
|
||||
gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + \
|
||||
DLoss(gpu_pred_src_src_d2_ones, gpu_pred_src_src_d2))
|
||||
|
||||
if masked_training:
|
||||
# Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
|
||||
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
|
||||
gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] )
|
||||
|
||||
gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights )]
|
||||
gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + DLoss(gpu_pred_dst_dst_d_ones, gpu_pred_dst_dst_d))
|
||||
|
||||
|
||||
gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ]
|
||||
|
||||
|
||||
# Average losses and gradients, and create optimizer update ops
|
||||
with tf.device(f'/CPU:0'):
|
||||
with tf.device (models_opt_device):
|
||||
pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
|
||||
pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
|
||||
pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
|
||||
|
@ -558,7 +404,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
|
||||
pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)
|
||||
|
||||
with tf.device (models_opt_device):
|
||||
src_loss = tf.concat(gpu_src_losses, 0)
|
||||
dst_loss = tf.concat(gpu_dst_losses, 0)
|
||||
src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs))
|
||||
|
@ -571,18 +416,16 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
|
||||
|
||||
# Initializing training and view functions
|
||||
def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \
|
||||
warped_dst, target_dst, target_dstm, target_dstm_em, ):
|
||||
s, d = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
|
||||
def src_dst_train(warped_src, target_src, target_srcm_all, \
|
||||
warped_dst, target_dst, target_dstm_all):
|
||||
s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
|
||||
feed_dict={self.warped_src :warped_src,
|
||||
self.target_src :target_src,
|
||||
self.target_srcm:target_srcm,
|
||||
self.target_srcm_em:target_srcm_em,
|
||||
self.target_srcm_all:target_srcm_all,
|
||||
self.warped_dst :warped_dst,
|
||||
self.target_dst :target_dst,
|
||||
self.target_dstm:target_dstm,
|
||||
self.target_dstm_em:target_dstm_em,
|
||||
})[:2]
|
||||
self.target_dstm_all:target_dstm_all,
|
||||
})
|
||||
return s, d
|
||||
self.src_dst_train = src_dst_train
|
||||
|
||||
|
@ -592,16 +435,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
self.D_train = D_train
|
||||
|
||||
if gan_power != 0:
|
||||
def D_src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \
|
||||
warped_dst, target_dst, target_dstm, target_dstm_em, ):
|
||||
def D_src_dst_train(warped_src, target_src, target_srcm_all, \
|
||||
warped_dst, target_dst, target_dstm_all):
|
||||
nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src,
|
||||
self.target_src :target_src,
|
||||
self.target_srcm:target_srcm,
|
||||
self.target_srcm_em:target_srcm_em,
|
||||
self.target_srcm_all:target_srcm_all,
|
||||
self.warped_dst :warped_dst,
|
||||
self.target_dst :target_dst,
|
||||
self.target_dstm:target_dstm,
|
||||
self.target_dstm_em:target_dstm_em})
|
||||
self.target_dstm_all:target_dstm_all})
|
||||
self.D_src_dst_train = D_src_dst_train
|
||||
|
||||
|
||||
|
@ -612,13 +453,13 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
self.AE_view = AE_view
|
||||
else:
|
||||
# Initializing merge function
|
||||
with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'):
|
||||
if 'df' in archi_type:
|
||||
with tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
|
||||
if 'df' in archi:
|
||||
gpu_dst_code = self.inter(self.encoder(self.warped_dst))
|
||||
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
|
||||
_, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
|
||||
|
||||
elif 'liae' in archi_type:
|
||||
elif 'liae' in archi:
|
||||
gpu_dst_code = self.encoder (self.warped_dst)
|
||||
gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
|
||||
gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
|
||||
|
@ -638,17 +479,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
|
||||
if self.pretrain_just_disabled:
|
||||
do_init = False
|
||||
if 'df' in archi_type:
|
||||
if 'df' in archi:
|
||||
if model == self.inter:
|
||||
do_init = True
|
||||
elif 'liae' in archi_type:
|
||||
elif 'liae' in archi:
|
||||
if model == self.inter_AB or model == self.inter_B:
|
||||
do_init = True
|
||||
else:
|
||||
do_init = self.is_first_run()
|
||||
if self.is_training and gan_power != 0 and model == self.D_src:
|
||||
if self.gan_model_changed:
|
||||
do_init = True
|
||||
|
||||
if not do_init:
|
||||
do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
|
||||
|
@ -656,9 +494,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
if do_init:
|
||||
model.init_weights()
|
||||
|
||||
|
||||
###############
|
||||
|
||||
# initializing sample generators
|
||||
if self.is_training:
|
||||
training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
|
||||
|
@ -666,7 +501,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
|
||||
random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None
|
||||
|
||||
cpu_count = multiprocessing.cpu_count()
|
||||
cpu_count = min(multiprocessing.cpu_count(), 8)
|
||||
src_generators_count = cpu_count // 2
|
||||
dst_generators_count = cpu_count // 2
|
||||
if ct_mode is not None:
|
||||
|
@ -674,81 +509,28 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
|
||||
self.set_training_data_generators ([
|
||||
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
|
||||
sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=random_src_flip),
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'random_hsv_shift_amount' : random_hsv_power, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
],
|
||||
uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain,
|
||||
generators_count=src_generators_count ),
|
||||
|
||||
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
|
||||
sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=random_dst_flip),
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
],
|
||||
uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain,
|
||||
generators_count=dst_generators_count )
|
||||
])
|
||||
|
||||
self.last_src_samples_loss = []
|
||||
self.last_dst_samples_loss = []
|
||||
|
||||
if self.pretrain_just_disabled:
|
||||
self.update_sample_for_preview(force_new=True)
|
||||
|
||||
def export_dfm (self):
|
||||
output_path=self.get_strpath_storage_for_file('model.dfm')
|
||||
|
||||
io.log_info(f'Dumping .dfm to {output_path}')
|
||||
|
||||
tf = nn.tf
|
||||
nn.set_data_format('NCHW')
|
||||
|
||||
with tf.device (nn.tf_default_device_name):
|
||||
warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
|
||||
warped_dst = tf.transpose(warped_dst, (0,3,1,2))
|
||||
|
||||
|
||||
if 'df' in self.archi_type:
|
||||
gpu_dst_code = self.inter(self.encoder(warped_dst))
|
||||
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
|
||||
_, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
|
||||
|
||||
elif 'liae' in self.archi_type:
|
||||
gpu_dst_code = self.encoder (warped_dst)
|
||||
gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
|
||||
gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
|
||||
gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
|
||||
gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
|
||||
|
||||
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
|
||||
_, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
|
||||
|
||||
gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1))
|
||||
gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1))
|
||||
gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1))
|
||||
|
||||
tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
|
||||
tf.identity(gpu_pred_src_dst, name='out_celeb_face')
|
||||
tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')
|
||||
|
||||
output_graph_def = tf.graph_util.convert_variables_to_constants(
|
||||
nn.tf_sess,
|
||||
tf.get_default_graph().as_graph_def(),
|
||||
['out_face_mask','out_celeb_face','out_celeb_face_mask']
|
||||
)
|
||||
|
||||
import tf2onnx
|
||||
with tf.device("/CPU:0"):
|
||||
model_proto, _ = tf2onnx.convert._convert_common(
|
||||
output_graph_def,
|
||||
name='SAEHD',
|
||||
input_names=['in_face:0'],
|
||||
output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'],
|
||||
opset=12,
|
||||
output_path=output_path)
|
||||
|
||||
#override
|
||||
def get_model_filename_list(self):
|
||||
return self.model_filename_list
|
||||
|
@ -758,38 +540,54 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False):
|
||||
model.save_weights ( self.get_strpath_storage_for_file(filename) )
|
||||
|
||||
#override
|
||||
def should_save_preview_history(self):
|
||||
return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \
|
||||
(io.is_colab() and self.iter % 100 == 0)
|
||||
|
||||
#override
|
||||
def onTrainOneIter(self):
|
||||
if self.get_iter() == 0 and not self.pretrain and not self.pretrain_just_disabled:
|
||||
io.log_info('You are training the model from scratch. It is strongly recommended to use a pretrained model to speed up the training and improve the quality.\n')
|
||||
bs = self.get_batch_size()
|
||||
|
||||
( (warped_src, target_src, target_srcm, target_srcm_em), \
|
||||
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()
|
||||
( (warped_src, target_src, target_srcm_all), \
|
||||
(warped_dst, target_dst, target_dstm_all) ) = self.generate_next_samples()
|
||||
|
||||
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
|
||||
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm_all, warped_dst, target_dst, target_dstm_all)
|
||||
|
||||
for i in range(bs):
|
||||
self.last_src_samples_loss.append ( (target_src[i], target_srcm_all[i], src_loss[i] ) )
|
||||
self.last_dst_samples_loss.append ( (target_dst[i], target_dstm_all[i], dst_loss[i] ) )
|
||||
|
||||
if len(self.last_src_samples_loss) >= bs*16:
|
||||
src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(2), reverse=True)
|
||||
dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(2), reverse=True)
|
||||
|
||||
target_src = np.stack( [ x[0] for x in src_samples_loss[:bs] ] )
|
||||
target_srcm_all = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
|
||||
|
||||
target_dst = np.stack( [ x[0] for x in dst_samples_loss[:bs] ] )
|
||||
target_dstm_all = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
|
||||
|
||||
src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm_all, target_dst, target_dst, target_dstm_all)
|
||||
self.last_src_samples_loss = []
|
||||
self.last_dst_samples_loss = []
|
||||
|
||||
if self.options['true_face_power'] != 0 and not self.pretrain:
|
||||
self.D_train (warped_src, warped_dst)
|
||||
|
||||
if self.gan_power != 0:
|
||||
self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
|
||||
self.D_src_dst_train (warped_src, target_src, target_srcm_all, warped_dst, target_dst, target_dstm_all)
|
||||
|
||||
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
|
||||
|
||||
#override
|
||||
def onGetPreview(self, samples, for_history=False):
|
||||
( (warped_src, target_src, target_srcm, target_srcm_em),
|
||||
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples
|
||||
def onGetPreview(self, samples):
|
||||
( (warped_src, target_src, target_srcm_all,),
|
||||
(warped_dst, target_dst, target_dstm_all,) ) = samples
|
||||
|
||||
S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
|
||||
DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ]
|
||||
|
||||
target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )]
|
||||
target_srcm_all, target_dstm_all = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm_all, target_dstm_all] )]
|
||||
|
||||
target_srcm = np.clip(target_srcm_all, 0, 1)
|
||||
target_dstm = np.clip(target_dstm_all, 0, 1)
|
||||
|
||||
n_samples = min(4, self.get_batch_size(), 800 // self.resolution )
|
||||
|
||||
|
|
|
@ -18,6 +18,8 @@ class XSegModel(ModelBase):
|
|||
|
||||
#override
|
||||
def on_initialize_options(self):
|
||||
self.set_batch_size(4)
|
||||
|
||||
ask_override = self.ask_override()
|
||||
|
||||
if not self.is_first_run() and ask_override:
|
||||
|
@ -25,24 +27,15 @@ class XSegModel(ModelBase):
|
|||
self.set_iter(0)
|
||||
|
||||
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf')
|
||||
default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False)
|
||||
|
||||
if self.is_first_run():
|
||||
self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Choose the same as your deepfake model.").lower()
|
||||
|
||||
if self.is_first_run() or ask_override:
|
||||
self.ask_batch_size(4, range=[2,16])
|
||||
self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain)
|
||||
|
||||
if not self.is_exporting and (self.options['pretrain'] and self.get_pretraining_data_path() is None):
|
||||
raise Exception("pretraining_data_path is not defined")
|
||||
|
||||
self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
|
||||
|
||||
#override
|
||||
def on_initialize(self):
|
||||
device_config = nn.getCurrentDeviceConfig()
|
||||
self.model_data_format = "NCHW" if self.is_exporting or (len(device_config.devices) != 0 and not self.is_debug()) else "NHWC"
|
||||
self.model_data_format = "NCHW" if len(device_config.devices) != 0 and not self.is_debug() else "NHWC"
|
||||
nn.initialize(data_format=self.model_data_format)
|
||||
tf = nn.tf
|
||||
|
||||
|
@ -58,9 +51,8 @@ class XSegModel(ModelBase):
|
|||
'wf' : FaceType.WHOLE_FACE,
|
||||
'head' : FaceType.HEAD}[ self.options['face_type'] ]
|
||||
|
||||
|
||||
place_model_on_cpu = len(devices) == 0
|
||||
models_opt_device = '/CPU:0' if place_model_on_cpu else nn.tf_default_device_name
|
||||
models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0'
|
||||
|
||||
bgr_shape = nn.get4Dshape(resolution,resolution,3)
|
||||
mask_shape = nn.get4Dshape(resolution,resolution,1)
|
||||
|
@ -75,16 +67,13 @@ class XSegModel(ModelBase):
|
|||
optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3, name='opt'),
|
||||
data_format=nn.data_format)
|
||||
|
||||
self.pretrain = self.options['pretrain']
|
||||
if self.pretrain_just_disabled:
|
||||
self.set_iter(0)
|
||||
|
||||
if self.is_training:
|
||||
# Adjust batch size for multiple GPU
|
||||
gpu_count = max(1, len(devices) )
|
||||
bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
|
||||
self.set_batch_size( gpu_count*bs_per_gpu)
|
||||
|
||||
|
||||
# Compute losses per GPU
|
||||
gpu_pred_list = []
|
||||
|
||||
|
@ -92,7 +81,8 @@ class XSegModel(ModelBase):
|
|||
gpu_loss_gvs = []
|
||||
|
||||
for gpu_id in range(gpu_count):
|
||||
with tf.device(f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
|
||||
with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
|
||||
|
||||
with tf.device(f'/CPU:0'):
|
||||
# slice on CPU, otherwise all batch data will be transfered to GPU first
|
||||
batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
|
||||
|
@ -100,41 +90,27 @@ class XSegModel(ModelBase):
|
|||
gpu_target_t = self.model.target_t [batch_slice,:,:,:]
|
||||
|
||||
# process model tensors
|
||||
gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t, pretrain=self.pretrain)
|
||||
gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t)
|
||||
gpu_pred_list.append(gpu_pred_t)
|
||||
|
||||
|
||||
if self.pretrain:
|
||||
# Structural loss
|
||||
gpu_loss = tf.reduce_mean (5*nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
|
||||
gpu_loss += tf.reduce_mean (5*nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
|
||||
# Pixel loss
|
||||
gpu_loss += tf.reduce_mean (10*tf.square(gpu_target_t-gpu_pred_t), axis=[1,2,3])
|
||||
else:
|
||||
gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3])
|
||||
|
||||
gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3])
|
||||
gpu_losses += [gpu_loss]
|
||||
|
||||
gpu_loss_gvs += [ nn.gradients ( gpu_loss, self.model.get_weights() ) ]
|
||||
|
||||
|
||||
# Average losses and gradients, and create optimizer update ops
|
||||
#with tf.device(f'/CPU:0'): # Temporary fix. Unknown bug with training freeze starts from 2.4.0, but 2.3.1 was ok
|
||||
with tf.device (models_opt_device):
|
||||
pred = tf.concat(gpu_pred_list, 0)
|
||||
loss = tf.concat(gpu_losses, 0)
|
||||
pred = nn.concat(gpu_pred_list, 0)
|
||||
loss = tf.reduce_mean(gpu_losses)
|
||||
|
||||
loss_gv_op = self.model.opt.get_update_op (nn.average_gv_list (gpu_loss_gvs))
|
||||
|
||||
|
||||
# Initializing training and view functions
|
||||
if self.pretrain:
|
||||
def train(input_np, target_np):
|
||||
l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np})
|
||||
return l
|
||||
else:
|
||||
def train(input_np, target_np):
|
||||
l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np })
|
||||
return l
|
||||
def train(input_np, target_np):
|
||||
l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np })
|
||||
return l
|
||||
self.train = train
|
||||
|
||||
def view(input_np):
|
||||
|
@ -147,38 +123,29 @@ class XSegModel(ModelBase):
|
|||
src_generators_count = cpu_count // 2
|
||||
dst_generators_count = cpu_count // 2
|
||||
|
||||
if self.pretrain:
|
||||
pretrain_gen = SampleGeneratorFace(self.get_pretraining_data_path(), debug=self.is_debug(), batch_size=self.get_batch_size(),
|
||||
sample_process_options=SampleProcessor.Options(random_flip=True),
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
],
|
||||
uniform_yaw_distribution=False,
|
||||
generators_count=cpu_count )
|
||||
self.set_training_data_generators ([pretrain_gen])
|
||||
else:
|
||||
srcdst_generator = SampleGeneratorFaceXSeg([self.training_data_src_path, self.training_data_dst_path],
|
||||
debug=self.is_debug(),
|
||||
batch_size=self.get_batch_size(),
|
||||
resolution=resolution,
|
||||
face_type=self.face_type,
|
||||
generators_count=src_dst_generators_count,
|
||||
data_format=nn.data_format)
|
||||
|
||||
src_generator = SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
|
||||
sample_process_options=SampleProcessor.Options(random_flip=False),
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
],
|
||||
generators_count=src_generators_count,
|
||||
raise_on_no_data=False )
|
||||
dst_generator = SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
|
||||
sample_process_options=SampleProcessor.Options(random_flip=False),
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
],
|
||||
generators_count=dst_generators_count,
|
||||
raise_on_no_data=False )
|
||||
srcdst_generator = SampleGeneratorFaceXSeg([self.training_data_src_path, self.training_data_dst_path],
|
||||
debug=self.is_debug(),
|
||||
batch_size=self.get_batch_size(),
|
||||
resolution=resolution,
|
||||
face_type=self.face_type,
|
||||
generators_count=src_dst_generators_count,
|
||||
data_format=nn.data_format)
|
||||
|
||||
self.set_training_data_generators ([srcdst_generator, src_generator, dst_generator])
|
||||
src_generator = SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
|
||||
sample_process_options=SampleProcessor.Options(random_flip=False),
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
],
|
||||
generators_count=src_generators_count,
|
||||
raise_on_no_data=False )
|
||||
dst_generator = SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
|
||||
sample_process_options=SampleProcessor.Options(random_flip=False),
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
],
|
||||
generators_count=dst_generators_count,
|
||||
raise_on_no_data=False )
|
||||
|
||||
self.set_training_data_generators ([srcdst_generator, src_generator, dst_generator])
|
||||
|
||||
#override
|
||||
def get_model_filename_list(self):
|
||||
|
@ -190,21 +157,19 @@ class XSegModel(ModelBase):
|
|||
|
||||
#override
|
||||
def onTrainOneIter(self):
|
||||
image_np, target_np = self.generate_next_samples()[0]
|
||||
loss = self.train (image_np, target_np)
|
||||
|
||||
return ( ('loss', np.mean(loss) ), )
|
||||
|
||||
image_np, mask_np = self.generate_next_samples()[0]
|
||||
loss = self.train (image_np, mask_np)
|
||||
|
||||
return ( ('loss', loss ), )
|
||||
|
||||
#override
|
||||
def onGetPreview(self, samples, for_history=False):
|
||||
def onGetPreview(self, samples):
|
||||
n_samples = min(4, self.get_batch_size(), 800 // self.resolution )
|
||||
|
||||
if self.pretrain:
|
||||
srcdst_samples, = samples
|
||||
image_np, mask_np = srcdst_samples
|
||||
else:
|
||||
srcdst_samples, src_samples, dst_samples = samples
|
||||
image_np, mask_np = srcdst_samples
|
||||
srcdst_samples, src_samples, dst_samples = samples
|
||||
image_np, mask_np = srcdst_samples
|
||||
|
||||
I, M, IM, = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([image_np,mask_np] + self.view (image_np) ) ]
|
||||
M, IM, = [ np.repeat (x, (3,), -1) for x in [M, IM] ]
|
||||
|
@ -214,14 +179,11 @@ class XSegModel(ModelBase):
|
|||
result = []
|
||||
st = []
|
||||
for i in range(n_samples):
|
||||
if self.pretrain:
|
||||
ar = I[i], IM[i]
|
||||
else:
|
||||
ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i]+0.5*I[i]*(1-IM[i]) + 0.5*green_bg*(1-IM[i])
|
||||
ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i]+0.5*I[i]*(1-IM[i]) + 0.5*green_bg*(1-IM[i])
|
||||
st.append ( np.concatenate ( ar, axis=1) )
|
||||
result += [ ('XSeg training faces', np.concatenate (st, axis=0 )), ]
|
||||
|
||||
if not self.pretrain and len(src_samples) != 0:
|
||||
if len(src_samples) != 0:
|
||||
src_np, = src_samples
|
||||
|
||||
|
||||
|
@ -235,7 +197,7 @@ class XSegModel(ModelBase):
|
|||
|
||||
result += [ ('XSeg src faces', np.concatenate (st, axis=0 )), ]
|
||||
|
||||
if not self.pretrain and len(dst_samples) != 0:
|
||||
if len(dst_samples) != 0:
|
||||
dst_np, = dst_samples
|
||||
|
||||
|
||||
|
@ -251,33 +213,4 @@ class XSegModel(ModelBase):
|
|||
|
||||
return result
|
||||
|
||||
def export_dfm (self):
|
||||
output_path = self.get_strpath_storage_for_file(f'model.onnx')
|
||||
io.log_info(f'Dumping .onnx to {output_path}')
|
||||
tf = nn.tf
|
||||
|
||||
with tf.device (nn.tf_default_device_name):
|
||||
input_t = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
|
||||
input_t = tf.transpose(input_t, (0,3,1,2))
|
||||
_, pred_t = self.model.flow(input_t)
|
||||
pred_t = tf.transpose(pred_t, (0,2,3,1))
|
||||
|
||||
tf.identity(pred_t, name='out_mask')
|
||||
|
||||
output_graph_def = tf.graph_util.convert_variables_to_constants(
|
||||
nn.tf_sess,
|
||||
tf.get_default_graph().as_graph_def(),
|
||||
['out_mask']
|
||||
)
|
||||
|
||||
import tf2onnx
|
||||
with tf.device("/CPU:0"):
|
||||
model_proto, _ = tf2onnx.convert._convert_common(
|
||||
output_graph_def,
|
||||
name='XSeg',
|
||||
input_names=['in_face:0'],
|
||||
output_names=['out_mask:0'],
|
||||
opset=13,
|
||||
output_path=output_path)
|
||||
|
||||
Model = XSegModel
|
50
project.code-workspace
Normal file
|
@ -0,0 +1,50 @@
|
|||
{
|
||||
"folders": [
|
||||
{
|
||||
"path": "."
|
||||
}
|
||||
],
|
||||
"settings": {
|
||||
"workbench.colorTheme": "Visual Studio Light",
|
||||
"diffEditor.ignoreTrimWhitespace": true,
|
||||
"workbench.sideBar.location": "right",
|
||||
"breadcrumbs.enabled": false,
|
||||
"editor.renderWhitespace": "none",
|
||||
"editor.minimap.enabled": false,
|
||||
"workbench.activityBar.visible": true,
|
||||
"window.menuBarVisibility": "default",
|
||||
"editor.fastScrollSensitivity": 10,
|
||||
"editor.mouseWheelScrollSensitivity": 2,
|
||||
"window.zoomLevel": 0,
|
||||
"extensions.ignoreRecommendations": true,
|
||||
|
||||
"python.linting.pylintEnabled": false,
|
||||
"python.linting.enabled": false,
|
||||
"python.linting.pylamaEnabled": false,
|
||||
"python.linting.pydocstyleEnabled": false,
|
||||
"python.pythonPath": "${env:PYTHON_EXECUTABLE}",
|
||||
"workbench.editor.tabCloseButton": "off",
|
||||
"workbench.editor.tabSizing": "shrink",
|
||||
"workbench.editor.highlightModifiedTabs": true,
|
||||
"editor.mouseWheelScrollSensitivity": 3,
|
||||
"editor.folding": false,
|
||||
"editor.glyphMargin": false,
|
||||
"files.exclude": {
|
||||
"**/__pycache__": true,
|
||||
"**/.github": true,
|
||||
"**/.vscode": true,
|
||||
"**/*.dat": true,
|
||||
"**/*.h5": true,
|
||||
"**/*.npy": true
|
||||
},
|
||||
"editor.quickSuggestions": {
|
||||
"other": false,
|
||||
"comments": false,
|
||||
"strings": false
|
||||
},
|
||||
"editor.trimAutoWhitespace": false,
|
||||
"python.linting.pylintArgs": [
|
||||
"--disable=import-error"
|
||||
]
|
||||
}
|
||||
}
|
|
@ -1,11 +1,9 @@
|
|||
tqdm
|
||||
numpy==1.19.3
|
||||
numexpr
|
||||
h5py==2.10.0
|
||||
numpy==1.17.0
|
||||
h5py==2.9.0
|
||||
opencv-python==4.1.0.25
|
||||
ffmpeg-python==0.1.17
|
||||
scikit-image==0.14.2
|
||||
scipy==1.4.1
|
||||
colorama
|
||||
tensorflow-gpu==2.4.0
|
||||
tf2onnx==1.9.3
|
||||
tensorflow-gpu==1.13.2
|
|
@ -1,12 +1,11 @@
|
|||
tqdm
|
||||
numpy==1.19.3
|
||||
numexpr
|
||||
h5py==2.10.0
|
||||
numpy==1.17.0
|
||||
h5py==2.9.0
|
||||
opencv-python==4.1.0.25
|
||||
ffmpeg-python==0.1.17
|
||||
scikit-image==0.14.2
|
||||
scipy==1.4.1
|
||||
colorama
|
||||
tensorflow-gpu==2.4.0
|
||||
labelme==4.2.9
|
||||
tensorflow-gpu==1.13.2
|
||||
pyqt5
|
||||
tf2onnx==1.9.3
|
|
@ -85,17 +85,16 @@ class PackedFaceset():
|
|||
of.seek(0,2)
|
||||
of.close()
|
||||
|
||||
if io.input_bool(f"Delete original files?", True):
|
||||
for filename in io.progress_bar_generator(image_paths, "Deleting files"):
|
||||
Path(filename).unlink()
|
||||
for filename in io.progress_bar_generator(image_paths, "Deleting files"):
|
||||
Path(filename).unlink()
|
||||
|
||||
if as_person_faceset:
|
||||
for dir_name in io.progress_bar_generator(dir_names, "Deleting dirs"):
|
||||
dir_path = samples_path / dir_name
|
||||
try:
|
||||
shutil.rmtree(dir_path)
|
||||
except:
|
||||
io.log_info (f"unable to remove: {dir_path} ")
|
||||
if as_person_faceset:
|
||||
for dir_name in io.progress_bar_generator(dir_names, "Deleting dirs"):
|
||||
dir_path = samples_path / dir_name
|
||||
try:
|
||||
shutil.rmtree(dir_path)
|
||||
except:
|
||||
io.log_info (f"unable to remove: {dir_path} ")
|
||||
|
||||
@staticmethod
|
||||
def unpack(samples_path):
|
||||
|
@ -121,11 +120,6 @@ class PackedFaceset():
|
|||
|
||||
samples_dat_path.unlink()
|
||||
|
||||
@staticmethod
|
||||
def path_contains(samples_path):
|
||||
samples_dat_path = samples_path / packed_faceset_filename
|
||||
return samples_dat_path.exists()
|
||||
|
||||
@staticmethod
|
||||
def load(samples_path):
|
||||
samples_dat_path = samples_path / packed_faceset_filename
|
||||
|
|
|
@ -5,8 +5,8 @@ import cv2
|
|||
import numpy as np
|
||||
|
||||
from core.cv2ex import *
|
||||
from DFLIMG import *
|
||||
from facelib import LandmarksProcessor
|
||||
from core import imagelib
|
||||
from core.imagelib import SegIEPolys
|
||||
|
||||
class SampleType(IntEnum):
|
||||
|
@ -28,7 +28,6 @@ class Sample(object):
|
|||
'landmarks',
|
||||
'seg_ie_polys',
|
||||
'xseg_mask',
|
||||
'xseg_mask_compressed',
|
||||
'eyebrows_expand_mod',
|
||||
'source_filename',
|
||||
'person_name',
|
||||
|
@ -43,7 +42,6 @@ class Sample(object):
|
|||
landmarks=None,
|
||||
seg_ie_polys=None,
|
||||
xseg_mask=None,
|
||||
xseg_mask_compressed=None,
|
||||
eyebrows_expand_mod=None,
|
||||
source_filename=None,
|
||||
person_name=None,
|
||||
|
@ -62,16 +60,6 @@ class Sample(object):
|
|||
self.seg_ie_polys = SegIEPolys.load(seg_ie_polys)
|
||||
|
||||
self.xseg_mask = xseg_mask
|
||||
self.xseg_mask_compressed = xseg_mask_compressed
|
||||
|
||||
if self.xseg_mask_compressed is None and self.xseg_mask is not None:
|
||||
xseg_mask = np.clip( imagelib.normalize_channels(xseg_mask, 1)*255, 0, 255 ).astype(np.uint8)
|
||||
ret, xseg_mask_compressed = cv2.imencode('.png', xseg_mask)
|
||||
if not ret:
|
||||
raise Exception("Sample(): unable to generate xseg_mask_compressed")
|
||||
self.xseg_mask_compressed = xseg_mask_compressed
|
||||
self.xseg_mask = None
|
||||
|
||||
self.eyebrows_expand_mod = eyebrows_expand_mod if eyebrows_expand_mod is not None else 1.0
|
||||
self.source_filename = source_filename
|
||||
self.person_name = person_name
|
||||
|
@ -79,17 +67,6 @@ class Sample(object):
|
|||
|
||||
self._filename_offset_size = None
|
||||
|
||||
def has_xseg_mask(self):
|
||||
return self.xseg_mask is not None or self.xseg_mask_compressed is not None
|
||||
|
||||
def get_xseg_mask(self):
|
||||
if self.xseg_mask_compressed is not None:
|
||||
xseg_mask = cv2.imdecode(self.xseg_mask_compressed, cv2.IMREAD_UNCHANGED)
|
||||
if len(xseg_mask.shape) == 2:
|
||||
xseg_mask = xseg_mask[...,None]
|
||||
return xseg_mask.astype(np.float32) / 255.0
|
||||
return self.xseg_mask
|
||||
|
||||
def get_pitch_yaw_roll(self):
|
||||
if self.pitch_yaw_roll is None:
|
||||
self.pitch_yaw_roll = LandmarksProcessor.estimate_pitch_yaw_roll(self.landmarks, size=self.shape[1])
|
||||
|
@ -120,7 +97,6 @@ class Sample(object):
|
|||
'landmarks': self.landmarks.tolist(),
|
||||
'seg_ie_polys': self.seg_ie_polys.dump(),
|
||||
'xseg_mask' : self.xseg_mask,
|
||||
'xseg_mask_compressed' : self.xseg_mask_compressed,
|
||||
'eyebrows_expand_mod': self.eyebrows_expand_mod,
|
||||
'source_filename': self.source_filename,
|
||||
'person_name': self.person_name
|
||||
|
|
|
@ -6,13 +6,11 @@ import cv2
|
|||
import numpy as np
|
||||
|
||||
from core import mplib
|
||||
from core.interact import interact as io
|
||||
from core.joblib import SubprocessGenerator, ThisThreadGenerator
|
||||
from facelib import LandmarksProcessor
|
||||
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
|
||||
SampleType)
|
||||
|
||||
|
||||
'''
|
||||
arg
|
||||
output_sample_types = [
|
||||
|
@ -25,15 +23,15 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
random_ct_samples_path=None,
|
||||
sample_process_options=SampleProcessor.Options(),
|
||||
output_sample_types=[],
|
||||
uniform_yaw_distribution=False,
|
||||
add_sample_idx=False,
|
||||
generators_count=4,
|
||||
raise_on_no_data=True,
|
||||
**kwargs):
|
||||
|
||||
super().__init__(debug, batch_size)
|
||||
self.initialized = False
|
||||
self.sample_process_options = sample_process_options
|
||||
self.output_sample_types = output_sample_types
|
||||
self.add_sample_idx = add_sample_idx
|
||||
|
||||
if self.debug:
|
||||
self.generators_count = 1
|
||||
|
@ -43,39 +41,14 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
samples = SampleLoader.load (SampleType.FACE, samples_path)
|
||||
self.samples_len = len(samples)
|
||||
|
||||
self.initialized = False
|
||||
if self.samples_len == 0:
|
||||
if raise_on_no_data:
|
||||
raise ValueError('No training data provided.')
|
||||
else:
|
||||
return
|
||||
|
||||
if uniform_yaw_distribution:
|
||||
samples_pyr = [ ( idx, sample.get_pitch_yaw_roll() ) for idx, sample in enumerate(samples) ]
|
||||
|
||||
grads = 128
|
||||
#instead of math.pi / 2, using -1.2,+1.2 because actually maximum yaw for 2DFAN landmarks are -1.2+1.2
|
||||
grads_space = np.linspace (-1.2, 1.2,grads)
|
||||
|
||||
yaws_sample_list = [None]*grads
|
||||
for g in io.progress_bar_generator ( range(grads), "Sort by yaw"):
|
||||
yaw = grads_space[g]
|
||||
next_yaw = grads_space[g+1] if g < grads-1 else yaw
|
||||
|
||||
yaw_samples = []
|
||||
for idx, pyr in samples_pyr:
|
||||
s_yaw = -pyr[1]
|
||||
if (g == 0 and s_yaw < next_yaw) or \
|
||||
(g < grads-1 and s_yaw >= yaw and s_yaw < next_yaw) or \
|
||||
(g == grads-1 and s_yaw >= yaw):
|
||||
yaw_samples += [ idx ]
|
||||
if len(yaw_samples) > 0:
|
||||
yaws_sample_list[g] = yaw_samples
|
||||
|
||||
yaws_sample_list = [ y for y in yaws_sample_list if y is not None ]
|
||||
|
||||
index_host = mplib.Index2DHost( yaws_sample_list )
|
||||
else:
|
||||
index_host = mplib.IndexHost(self.samples_len)
|
||||
index_host = mplib.IndexHost(self.samples_len)
|
||||
|
||||
if random_ct_samples_path is not None:
|
||||
ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path)
|
||||
|
@ -137,8 +110,14 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
|
||||
if batches is None:
|
||||
batches = [ [] for _ in range(len(x)) ]
|
||||
if self.add_sample_idx:
|
||||
batches += [ [] ]
|
||||
i_sample_idx = len(batches)-1
|
||||
|
||||
for i in range(len(x)):
|
||||
batches[i].append ( x[i] )
|
||||
|
||||
if self.add_sample_idx:
|
||||
batches[i_sample_idx].append (sample_idx)
|
||||
|
||||
yield [ np.array(batch) for batch in batches]
|
||||
|
|
|
@ -12,98 +12,6 @@ from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
|
|||
SampleType)
|
||||
|
||||
|
||||
|
||||
class Index2DHost():
|
||||
"""
|
||||
Provides random shuffled 2D indexes for multiprocesses
|
||||
"""
|
||||
def __init__(self, indexes2D):
|
||||
self.sq = multiprocessing.Queue()
|
||||
self.cqs = []
|
||||
self.clis = []
|
||||
self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) )
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def host_thread(self, indexes2D):
|
||||
indexes_counts_len = len(indexes2D)
|
||||
|
||||
idxs = [*range(indexes_counts_len)]
|
||||
idxs_2D = [None]*indexes_counts_len
|
||||
shuffle_idxs = []
|
||||
shuffle_idxs_2D = [None]*indexes_counts_len
|
||||
for i in range(indexes_counts_len):
|
||||
idxs_2D[i] = indexes2D[i]
|
||||
shuffle_idxs_2D[i] = []
|
||||
|
||||
sq = self.sq
|
||||
|
||||
while True:
|
||||
while not sq.empty():
|
||||
obj = sq.get()
|
||||
cq_id, cmd = obj[0], obj[1]
|
||||
|
||||
if cmd == 0: #get_1D
|
||||
count = obj[2]
|
||||
|
||||
result = []
|
||||
for i in range(count):
|
||||
if len(shuffle_idxs) == 0:
|
||||
shuffle_idxs = idxs.copy()
|
||||
np.random.shuffle(shuffle_idxs)
|
||||
result.append(shuffle_idxs.pop())
|
||||
self.cqs[cq_id].put (result)
|
||||
elif cmd == 1: #get_2D
|
||||
targ_idxs,count = obj[2], obj[3]
|
||||
result = []
|
||||
|
||||
for targ_idx in targ_idxs:
|
||||
sub_idxs = []
|
||||
for i in range(count):
|
||||
ar = shuffle_idxs_2D[targ_idx]
|
||||
if len(ar) == 0:
|
||||
ar = shuffle_idxs_2D[targ_idx] = idxs_2D[targ_idx].copy()
|
||||
np.random.shuffle(ar)
|
||||
sub_idxs.append(ar.pop())
|
||||
result.append (sub_idxs)
|
||||
self.cqs[cq_id].put (result)
|
||||
|
||||
time.sleep(0.001)
|
||||
|
||||
def create_cli(self):
|
||||
cq = multiprocessing.Queue()
|
||||
self.cqs.append ( cq )
|
||||
cq_id = len(self.cqs)-1
|
||||
return Index2DHost.Cli(self.sq, cq, cq_id)
|
||||
|
||||
# disable pickling
|
||||
def __getstate__(self):
|
||||
return dict()
|
||||
def __setstate__(self, d):
|
||||
self.__dict__.update(d)
|
||||
|
||||
class Cli():
|
||||
def __init__(self, sq, cq, cq_id):
|
||||
self.sq = sq
|
||||
self.cq = cq
|
||||
self.cq_id = cq_id
|
||||
|
||||
def get_1D(self, count):
|
||||
self.sq.put ( (self.cq_id,0, count) )
|
||||
|
||||
while True:
|
||||
if not self.cq.empty():
|
||||
return self.cq.get()
|
||||
time.sleep(0.001)
|
||||
|
||||
def get_2D(self, idxs, count):
|
||||
self.sq.put ( (self.cq_id,1,idxs,count) )
|
||||
|
||||
while True:
|
||||
if not self.cq.empty():
|
||||
return self.cq.get()
|
||||
time.sleep(0.001)
|
||||
|
||||
'''
|
||||
arg
|
||||
output_sample_types = [
|
||||
|
@ -137,7 +45,7 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
|
|||
for i,sample in enumerate(samples):
|
||||
persons_name_idxs[sample.person_name].append (i)
|
||||
indexes2D = [ persons_name_idxs[person_name] for person_name in unique_person_names ]
|
||||
index2d_host = Index2DHost(indexes2D)
|
||||
index2d_host = mplib.Index2DHost(indexes2D)
|
||||
|
||||
if self.debug:
|
||||
self.generators_count = 1
|
||||
|
|