diff --git a/.vscode/launch.json b/.vscode/launch.json index f8857c1..6dc974d 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -12,7 +12,7 @@ "type": "python", "request": "launch", "program": "${env:DFL_ROOT}\\main.py", - "pythonPath": "${env:PYTHONEXECUTABLE}", + "python": "${env:PYTHONEXECUTABLE}", "cwd": "${env:WORKSPACE}", "console": "integratedTerminal", "args": ["train", diff --git a/DFLIMG/DFLJPG.py b/DFLIMG/DFLJPG.py index babdeee..614cebc 100644 --- a/DFLIMG/DFLJPG.py +++ b/DFLIMG/DFLJPG.py @@ -1,21 +1,27 @@ import pickle import struct +import traceback import cv2 import numpy as np +from core import imagelib +from core.cv2ex import * +from core.imagelib import SegIEPolys from core.interact import interact as io from core.structex import * from facelib import FaceType class DFLJPG(object): - def __init__(self): + def __init__(self, filename): + self.filename = filename self.data = b"" self.length = 0 self.chunks = [] self.dfl_dict = None - self.shape = (0,0,0) + self.shape = None + self.img = None @staticmethod def load_raw(filename, loader_func=None): @@ -29,7 +35,7 @@ class DFLJPG(object): raise FileNotFoundError(filename) try: - inst = DFLJPG() + inst = DFLJPG(filename) inst.data = data inst.length = len(data) inst_length = inst.length @@ -123,7 +129,7 @@ class DFLJPG(object): def load(filename, loader_func=None): try: inst = DFLJPG.load_raw (filename, loader_func=loader_func) - inst.dfl_dict = None + inst.dfl_dict = {} for chunk in inst.chunks: if chunk['name'] == 'APP0': @@ -132,8 +138,6 @@ class DFLJPG(object): if id == b"JFIF": c, ver_major, ver_minor, units, Xdensity, Ydensity, Xthumbnail, Ythumbnail = struct_unpack (d, c, "=BBBHHBB") - #if units == 0: - # inst.shape = (Ydensity, Xdensity, 3) else: raise Exception("Unknown jpeg ID: %s" % (id) ) elif chunk['name'] == 'SOF0' or chunk['name'] == 'SOF2': @@ -145,160 +149,30 @@ class DFLJPG(object): if type(chunk['data']) == bytes: inst.dfl_dict = pickle.loads(chunk['data']) - if (inst.dfl_dict is not None): - if 'face_type' not in inst.dfl_dict: - inst.dfl_dict['face_type'] = FaceType.toString (FaceType.FULL) - - if 'fanseg_mask' in inst.dfl_dict: - fanseg_mask = inst.dfl_dict['fanseg_mask'] - if fanseg_mask is not None: - numpyarray = np.asarray( inst.dfl_dict['fanseg_mask'], dtype=np.uint8) - inst.dfl_dict['fanseg_mask'] = cv2.imdecode(numpyarray, cv2.IMREAD_UNCHANGED) - - if inst.dfl_dict == None: - return None - return inst except Exception as e: - print (e) + io.log_err (f'Exception occured while DFLJPG.load : {traceback.format_exc()}') return None - @staticmethod - def embed_dfldict(filename, dfl_dict): - inst = DFLJPG.load_raw (filename) - inst.setDFLDictData (dfl_dict) + def has_data(self): + return len(self.dfl_dict.keys()) != 0 + def save(self): try: - with open(filename, "wb") as f: - f.write ( inst.dump() ) + with open(self.filename, "wb") as f: + f.write ( self.dump() ) except: - raise Exception( 'cannot save %s' % (filename) ) - - @staticmethod - def embed_data(filename, face_type=None, - landmarks=None, - ie_polys=None, - seg_ie_polys=None, - source_filename=None, - source_rect=None, - source_landmarks=None, - image_to_face_mat=None, - fanseg_mask=None, - eyebrows_expand_mod=None, - relighted=None, - **kwargs - ): - - if fanseg_mask is not None: - fanseg_mask = np.clip ( (fanseg_mask*255).astype(np.uint8), 0, 255 ) - - ret, buf = cv2.imencode( '.jpg', fanseg_mask, [int(cv2.IMWRITE_JPEG_QUALITY), 85] ) - - if ret and len(buf) < 60000: - fanseg_mask = buf - else: - io.log_err("Unable to encode 
fanseg_mask for %s" % (filename) ) - fanseg_mask = None - - if ie_polys is not None: - if not isinstance(ie_polys, list): - ie_polys = ie_polys.dump() - - if seg_ie_polys is not None: - if not isinstance(seg_ie_polys, list): - seg_ie_polys = seg_ie_polys.dump() - - DFLJPG.embed_dfldict (filename, {'face_type': face_type, - 'landmarks': landmarks, - 'ie_polys' : ie_polys, - 'seg_ie_polys' : seg_ie_polys, - 'source_filename': source_filename, - 'source_rect': source_rect, - 'source_landmarks': source_landmarks, - 'image_to_face_mat': image_to_face_mat, - 'fanseg_mask' : fanseg_mask, - 'eyebrows_expand_mod' : eyebrows_expand_mod, - 'relighted' : relighted - }) - - def embed_and_set(self, filename, face_type=None, - landmarks=None, - ie_polys=None, - seg_ie_polys=None, - source_filename=None, - source_rect=None, - source_landmarks=None, - image_to_face_mat=None, - fanseg_mask=None, - eyebrows_expand_mod=None, - relighted=None, - **kwargs - ): - if face_type is None: face_type = self.get_face_type() - if landmarks is None: landmarks = self.get_landmarks() - if ie_polys is None: ie_polys = self.get_ie_polys() - if seg_ie_polys is None: seg_ie_polys = self.get_seg_ie_polys() - if source_filename is None: source_filename = self.get_source_filename() - if source_rect is None: source_rect = self.get_source_rect() - if source_landmarks is None: source_landmarks = self.get_source_landmarks() - if image_to_face_mat is None: image_to_face_mat = self.get_image_to_face_mat() - if fanseg_mask is None: fanseg_mask = self.get_fanseg_mask() - if eyebrows_expand_mod is None: eyebrows_expand_mod = self.get_eyebrows_expand_mod() - if relighted is None: relighted = self.get_relighted() - DFLJPG.embed_data (filename, face_type=face_type, - landmarks=landmarks, - ie_polys=ie_polys, - seg_ie_polys=seg_ie_polys, - source_filename=source_filename, - source_rect=source_rect, - source_landmarks=source_landmarks, - image_to_face_mat=image_to_face_mat, - fanseg_mask=fanseg_mask, - eyebrows_expand_mod=eyebrows_expand_mod, - relighted=relighted) - - def remove_ie_polys(self): - self.dfl_dict['ie_polys'] = None - - def remove_seg_ie_polys(self): - self.dfl_dict['seg_ie_polys'] = None - - def remove_fanseg_mask(self): - self.dfl_dict['fanseg_mask'] = None - - def remove_source_filename(self): - self.dfl_dict['source_filename'] = None + raise Exception( f'cannot save {self.filename}' ) def dump(self): data = b"" - for chunk in self.chunks: - data += struct.pack ("BB", 0xFF, chunk['m_h'] ) - chunk_data = chunk['data'] - if chunk_data is not None: - data += struct.pack (">H", len(chunk_data)+2 ) - data += chunk_data + dict_data = self.dfl_dict - chunk_ex_data = chunk['ex_data'] - if chunk_ex_data is not None: - data += chunk_ex_data - - return data - - def get_shape(self): - return self.shape - - def get_height(self): - for chunk in self.chunks: - if type(chunk) == IHDR: - return chunk.height - return 0 - - def getDFLDictData(self): - return self.dfl_dict - - def setDFLDictData (self, dict_data=None): - self.dfl_dict = dict_data + # Remove None keys + for key in list(dict_data.keys()): + if dict_data[key] is None: + dict_data.pop(key) for chunk in self.chunks: if chunk['name'] == 'APP15': @@ -317,24 +191,134 @@ class DFLJPG(object): } self.chunks.insert (last_app_chunk+1, dflchunk) - def get_face_type(self): return self.dfl_dict['face_type'] - def get_landmarks(self): return np.array ( self.dfl_dict['landmarks'] ) - def get_ie_polys(self): return self.dfl_dict.get('ie_polys',None) - def get_seg_ie_polys(self): return 
self.dfl_dict.get('seg_ie_polys',None) - def get_source_filename(self): return self.dfl_dict['source_filename'] - def get_source_rect(self): return self.dfl_dict['source_rect'] - def get_source_landmarks(self): return np.array ( self.dfl_dict['source_landmarks'] ) + + for chunk in self.chunks: + data += struct.pack ("BB", 0xFF, chunk['m_h'] ) + chunk_data = chunk['data'] + if chunk_data is not None: + data += struct.pack (">H", len(chunk_data)+2 ) + data += chunk_data + + chunk_ex_data = chunk['ex_data'] + if chunk_ex_data is not None: + data += chunk_ex_data + + return data + + def get_img(self): + if self.img is None: + self.img = cv2_imread(self.filename) + return self.img + + def get_shape(self): + if self.shape is None: + img = self.get_img() + if img is not None: + self.shape = img.shape + return self.shape + + def get_height(self): + for chunk in self.chunks: + if type(chunk) == IHDR: + return chunk.height + return 0 + + def get_dict(self): + return self.dfl_dict + + def set_dict (self, dict_data=None): + self.dfl_dict = dict_data + + def get_face_type(self): return self.dfl_dict.get('face_type', FaceType.toString (FaceType.FULL) ) + def set_face_type(self, face_type): self.dfl_dict['face_type'] = face_type + + def get_landmarks(self): return np.array ( self.dfl_dict['landmarks'] ) + def set_landmarks(self, landmarks): self.dfl_dict['landmarks'] = landmarks + + def get_eyebrows_expand_mod(self): return self.dfl_dict.get ('eyebrows_expand_mod', 1.0) + def set_eyebrows_expand_mod(self, eyebrows_expand_mod): self.dfl_dict['eyebrows_expand_mod'] = eyebrows_expand_mod + + def get_source_filename(self): return self.dfl_dict.get ('source_filename', None) + def set_source_filename(self, source_filename): self.dfl_dict['source_filename'] = source_filename + + def get_source_rect(self): return self.dfl_dict.get ('source_rect', None) + def set_source_rect(self, source_rect): self.dfl_dict['source_rect'] = source_rect + + def get_source_landmarks(self): return np.array ( self.dfl_dict.get('source_landmarks', None) ) + def set_source_landmarks(self, source_landmarks): self.dfl_dict['source_landmarks'] = source_landmarks + def get_image_to_face_mat(self): mat = self.dfl_dict.get ('image_to_face_mat', None) if mat is not None: return np.array (mat) return None - def get_fanseg_mask(self): - fanseg_mask = self.dfl_dict.get ('fanseg_mask', None) - if fanseg_mask is not None: - return np.clip ( np.array (fanseg_mask) / 255.0, 0.0, 1.0 )[...,np.newaxis] - return None - def get_eyebrows_expand_mod(self): - return self.dfl_dict.get ('eyebrows_expand_mod', None) - def get_relighted(self): - return self.dfl_dict.get ('relighted', False) + def set_image_to_face_mat(self, image_to_face_mat): self.dfl_dict['image_to_face_mat'] = image_to_face_mat + + def has_seg_ie_polys(self): + return self.dfl_dict.get('seg_ie_polys',None) is not None + + def get_seg_ie_polys(self): + d = self.dfl_dict.get('seg_ie_polys',None) + if d is not None: + d = SegIEPolys.load(d) + else: + d = SegIEPolys() + + return d + + def set_seg_ie_polys(self, seg_ie_polys): + if seg_ie_polys is not None: + if not isinstance(seg_ie_polys, SegIEPolys): + raise ValueError('seg_ie_polys should be instance of SegIEPolys') + + if seg_ie_polys.has_polys(): + seg_ie_polys = seg_ie_polys.dump() + else: + seg_ie_polys = None + + self.dfl_dict['seg_ie_polys'] = seg_ie_polys + + def has_xseg_mask(self): + return self.dfl_dict.get('xseg_mask',None) is not None + + def get_xseg_mask_compressed(self): + mask_buf = self.dfl_dict.get('xseg_mask',None) + if 
mask_buf is None: + return None + + return mask_buf + + def get_xseg_mask(self): + mask_buf = self.dfl_dict.get('xseg_mask',None) + if mask_buf is None: + return None + + img = cv2.imdecode(mask_buf, cv2.IMREAD_UNCHANGED) + if len(img.shape) == 2: + img = img[...,None] + + return img.astype(np.float32) / 255.0 + + + def set_xseg_mask(self, mask_a): + if mask_a is None: + self.dfl_dict['xseg_mask'] = None + return + + mask_a = imagelib.normalize_channels(mask_a, 1) + img_data = np.clip( mask_a*255, 0, 255 ).astype(np.uint8) + + data_max_len = 50000 + + ret, buf = cv2.imencode('.png', img_data) + + if not ret or len(buf) > data_max_len: + for jpeg_quality in range(100,-1,-1): + ret, buf = cv2.imencode( '.jpg', img_data, [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality] ) + if ret and len(buf) <= data_max_len: + break + + if not ret: + raise Exception("set_xseg_mask: unable to generate image data for set_xseg_mask") + + self.dfl_dict['xseg_mask'] = buf diff --git a/README.md b/README.md index 62f23df..27b590a 100644 --- a/README.md +++ b/README.md @@ -1,122 +1,237 @@ -
+ +# DeepFaceLab + + + + +https://arxiv.org/abs/2005.05535 + +

-![](doc/logo_cuda.png) ![](doc/logo_tensorflow.png) -![](doc/logo_python.png) +![](doc/logo_cuda.png) +![](doc/logo_directx.png) +

- -# DeepFaceLab -### the leading software for creating deepfakes - -
- -More than 95% of deepfake videos are created with DeepFaceLab. DeepFaceLab is used by such popular youtube channels as -|![](doc/youtube_icon.png) [Ctrl Shift Face](https://www.youtube.com/channel/UCKpH0CKltc73e4wh0_pgL3g)|![](doc/youtube_icon.png) [VFXChris Ume](https://www.youtube.com/channel/UCGf4OlX_aTt8DlrgiH3jN3g/videos)| -|---|---| - -|![](doc/youtube_icon.png) [Sham00k](https://www.youtube.com/channel/UCZXbWcv7fSZFTAZV4beckyw/videos)|![](doc/youtube_icon.png) [Collider videos](https://www.youtube.com/watch?v=A91P2qtPT54&list=PLayt6616lBclvOprvrC8qKGCO-mAhPRux)|![](doc/youtube_icon.png) [iFake](https://www.youtube.com/channel/UCC0lK2Zo2BMXX-k1Ks0r7dg/videos)| +|![](doc/tiktok_icon.png) [deeptomcruise](https://www.tiktok.com/@deeptomcruise)|![](doc/tiktok_icon.png) [1facerussia](https://www.tiktok.com/@1facerussia)|![](doc/tiktok_icon.png) [arnoldschwarzneggar](https://www.tiktok.com/@arnoldschwarzneggar) |---|---|---| -
-# Quality progress +|![](doc/tiktok_icon.png) [mariahcareyathome?](https://www.tiktok.com/@mariahcareyathome?)|![](doc/tiktok_icon.png) [diepnep](https://www.tiktok.com/@diepnep)|![](doc/tiktok_icon.png) [mr__heisenberg](https://www.tiktok.com/@mr__heisenberg)|![](doc/tiktok_icon.png) [deepcaprio](https://www.tiktok.com/@deepcaprio) +|---|---|---|---| -
- -|2018|![](doc/progress_2018.png) | +|![](doc/youtube_icon.png) [VFXChris Ume](https://www.youtube.com/channel/UCGf4OlX_aTt8DlrgiH3jN3g/videos)|![](doc/youtube_icon.png) [Sham00k](https://www.youtube.com/channel/UCZXbWcv7fSZFTAZV4beckyw/videos)| |---|---| -|2020|![](doc/progress_2020.png)| +|![](doc/youtube_icon.png) [Collider videos](https://www.youtube.com/watch?v=A91P2qtPT54&list=PLayt6616lBclvOprvrC8qKGCO-mAhPRux)|![](doc/youtube_icon.png) [iFake](https://www.youtube.com/channel/UCC0lK2Zo2BMXX-k1Ks0r7dg/videos)|![](doc/youtube_icon.png) [NextFace](https://www.youtube.com/channel/UCFh3gL0a8BS21g-DHvXZEeQ/videos)| +|---|---|---| + +|![](doc/youtube_icon.png) [Futuring Machine](https://www.youtube.com/channel/UCC5BbFxqLQgfnWPhprmQLVg)|![](doc/youtube_icon.png) [RepresentUS](https://www.youtube.com/channel/UCRzgK52MmetD9aG8pDOID3g)|![](doc/youtube_icon.png) [Corridor Crew](https://www.youtube.com/c/corridorcrew/videos)| +|---|---|---| + +|![](doc/youtube_icon.png) [DeepFaker](https://www.youtube.com/channel/UCkHecfDTcSazNZSKPEhtPVQ)|![](doc/youtube_icon.png) [DeepFakes in movie](https://www.youtube.com/c/DeepFakesinmovie/videos)| +|---|---| + +|![](doc/youtube_icon.png) [DeepFakeCreator](https://www.youtube.com/channel/UCkNFhcYNLQ5hr6A6lZ56mKA)|![](doc/youtube_icon.png) [Jarkan](https://www.youtube.com/user/Jarkancio/videos)| |---|---|
+ +
+ +# What can I do using DeepFaceLab? + +
+ +## Replace the face + + + +
+ +## De-age the face + +
+ +![](doc/youtube_icon.png) https://www.youtube.com/watch?v=Ddx5B-84ebo + +
+ +## Replace the head + +
+ +![](doc/youtube_icon.png) https://www.youtube.com/watch?v=RTjgkhMugVw + +
+ +# Native resolution progress + +
+ + + +
+ +Unfortunately, there is no "make everything ok" button in DeepFaceLab. You should spend time studying the workflow and growing your skills. Some skill in programs such as *After Effects* or *DaVinci Resolve* is also desirable. + +
+ +## Mini tutorial + +
## Releases -|||| -|---|---|---| -|Windows|[github releases](https://github.com/iperov/DeepFaceLab/releases)|Direct download| -||[Google drive](https://drive.google.com/open?id=1BCFK_L7lPNwMbEQ_kFPqPpDdFEOd_Dci)|if the download quota is exceeded, add the file to your own google drive and download from it| -|Google Colab|[github](https://github.com/chervonij/DFL-Colab)|by @chervonij . You can train fakes for free using Google Colab.| -|CentOS Linux|[github](https://github.com/elemantalcode/dfl)|by @elemantalcode| -|Linux|[github](https://github.com/lbfs/DeepFaceLab_Linux)|by @lbfs | -|||| +
+Windows (magnet link) +Latest release. Use a torrent client to download.
+Windows (Mega.nz) +Contains new and prev releases.
+Windows (yandex.ru) +Contains new and prev releases.
+Linux (github) +by @nagadit
+CentOS Linux (github) +May be outdated. By @elemantalcode
+ + + + - -## Links + - - - - +
+ +### Communication groups
+
+Discord +Official discord channel. English / Russian.
- -|||| -|---|---|---| -|Guides and tutorials||| -||[DeepFaceLab guide](https://mrdeepfakes.com/forums/thread-guide-deepfacelab-2-0-explained-and-tutorials-recommended)|Main guide| -||[Faceset creation guide](https://mrdeepfakes.com/forums/thread-guide-celebrity-faceset-dataset-creation-how-to-create-celebrity-facesets)|How to create the right faceset | -||[Google Colab guide](https://mrdeepfakes.com/forums/thread-guide-deepfacelab-google-colab-tutorial)|Guide how to train the fake on Google Colab| -||[Compositing](https://mrdeepfakes.com/forums/thread-deepfacelab-2-0-compositing-in-davinci-resolve-vegas-pro-and-after-effects)|To achieve the highest quality, compose deepfake manually in video editors such as Davince Resolve or Adobe AfterEffects| -||[Discussion and suggestions](https://mrdeepfakes.com/forums/thread-deepfacelab-2-0-discussion-tips-suggestions)|| -|||| -|Supplementary material||| -||[Ready to work facesets](https://mrdeepfakes.com/forums/forum-celebrity-facesets)|Celebrity facesets made by community| -||[Pretrained models](https://mrdeepfakes.com/forums/forum-celebrity-facesets)|Use pretrained models made by community to speed up training| -|||| -|Communication groups||| -||[telegram (English / Russian)](https://t.me/DeepFaceLab_official)|Don't forget to hide your phone number| -||[telegram (English only)](https://t.me/DeepFaceLab_official_en)|Don't forget to hide your phone number| -||[mrdeepfakes](https://mrdeepfakes.com/forums/)|the biggest NSFW English community| -||QQ 951138799| 中文 Chinese QQ group for ML/AI experts|| -||[deepfaker.xyz](https://www.deepfaker.xyz)|中文 Chinese guys are localizing DeepFaceLab| -||[reddit r/GifFakes/](https://www.reddit.com/r/GifFakes/new/)|Post your deepfakes there !| -||[reddit r/SFWdeepfakes/](https://www.reddit.com/r/SFWdeepfakes/new/)|Post your deepfakes there !| +## Related works
- -## How I can help the project? - -|||| -|---|---|---| -|Donate|Sponsor deepfake research and DeepFaceLab development.|| -||[Donate via Paypal](https://www.paypal.com/cgi-bin/webscr?cmd=_donations&business=lepersorium@gmail.com&lc=US&no_note=0&item_name=Support+DeepFaceLab&cn=&curency_code=USD&bn=PP-DonationsBF:btn_donateCC_LG.gif:NonHosted) -||[Donate via Yandex.Money](https://money.yandex.ru/to/41001142318065)|| -||bitcoin:31mPd6DxPCzbpCMZk4k1koWAbErSyqkAXr|| -|||| -|Last donations|10$ ( 14 march 2020 Amien Phillips ) -||20$ ( 12 march 2020 Maria D. ) -||200$ ( 12 march 2020 VFXChris Ume ) -||300$ ( 12 march 2020 Thiago O. ) -||50$ ( 8 march 2020 blanuk ) -|||| -|Collect facesets|You can collect faceset of any celebrity that can be used in DeepFaceLab and share it [in the community](https://mrdeepfakes.com/forums/forum-celebrity-facesets)| -|||| -|Star this repo|Register github account and push "Star" button. -
- -## Meme zone -

- -![](doc/DeepFaceLab_is_working.png) - -

- -
- -#deepfacelab #deepfakes #faceswap #face-swap #deep-learning #deeplearning #deep-neural-networks #deepface #deep-face-swap #fakeapp #fake-app #neural-networks #neural-nets #tensorflow #cuda #nvidia +
+DeepFaceLive +Real-time face swap for PC streaming or video calls
+ +## How can I help the project? + +
+ +### Star this repo + +
+ +Register a GitHub account and push the "Star" button. + +
+ +## Meme zone + +
+ +#deepfacelab #faceswap #face-swap #deep-learning #deeplearning #deep-neural-networks #deepface #deep-face-swap #neural-networks #neural-nets #tensorflow #cuda #nvidia + +
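Before the XSegEditor sources below, a brief usage sketch of the reworked DFLJPG API from the DFLIMG/DFLJPG.py changes above (load / has_data / getters / set_xseg_mask / save). The aligned-face path is hypothetical and error handling is omitted; this illustrates the new interface rather than reproducing code from the patch.

```python
import numpy as np

from DFLIMG.DFLJPG import DFLJPG

# Hypothetical aligned-face sample previously written by the extractor.
dflimg = DFLJPG.load('aligned/00001_0.jpg')

if dflimg is not None and dflimg.has_data():
    landmarks = dflimg.get_landmarks()   # landmarks stored by the extractor
    face_type = dflimg.get_face_type()   # falls back to the FULL face type if absent

    # Attach a single-channel float mask in [0..1]; set_xseg_mask()
    # compresses it to PNG/JPEG before storing it in the DFL dict.
    h, w, _ = dflimg.get_shape()
    mask = np.zeros((h, w, 1), dtype=np.float32)
    dflimg.set_xseg_mask(mask)

    dflimg.save()   # re-embeds the dict into the APP15 chunk and rewrites the file
```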
diff --git a/XSegEditor/QCursorDB.py b/XSegEditor/QCursorDB.py new file mode 100644 index 0000000..0909cba --- /dev/null +++ b/XSegEditor/QCursorDB.py @@ -0,0 +1,10 @@ +from PyQt5.QtCore import * +from PyQt5.QtGui import * +from PyQt5.QtWidgets import * + +class QCursorDB(): + @staticmethod + def initialize(cursor_path): + QCursorDB.cross_red = QCursor ( QPixmap ( str(cursor_path / 'cross_red.png') ) ) + QCursorDB.cross_green = QCursor ( QPixmap ( str(cursor_path / 'cross_green.png') ) ) + QCursorDB.cross_blue = QCursor ( QPixmap ( str(cursor_path / 'cross_blue.png') ) ) diff --git a/XSegEditor/QIconDB.py b/XSegEditor/QIconDB.py new file mode 100644 index 0000000..1fd9e3e --- /dev/null +++ b/XSegEditor/QIconDB.py @@ -0,0 +1,26 @@ +from PyQt5.QtCore import * +from PyQt5.QtGui import * +from PyQt5.QtWidgets import * + + +class QIconDB(): + @staticmethod + def initialize(icon_path): + QIconDB.app_icon = QIcon ( str(icon_path / 'app_icon.png') ) + QIconDB.delete_poly = QIcon ( str(icon_path / 'delete_poly.png') ) + QIconDB.undo_pt = QIcon ( str(icon_path / 'undo_pt.png') ) + QIconDB.redo_pt = QIcon ( str(icon_path / 'redo_pt.png') ) + QIconDB.poly_color_red = QIcon ( str(icon_path / 'poly_color_red.png') ) + QIconDB.poly_color_green = QIcon ( str(icon_path / 'poly_color_green.png') ) + QIconDB.poly_color_blue = QIcon ( str(icon_path / 'poly_color_blue.png') ) + QIconDB.poly_type_include = QIcon ( str(icon_path / 'poly_type_include.png') ) + QIconDB.poly_type_exclude = QIcon ( str(icon_path / 'poly_type_exclude.png') ) + QIconDB.left = QIcon ( str(icon_path / 'left.png') ) + QIconDB.right = QIcon ( str(icon_path / 'right.png') ) + QIconDB.trashcan = QIcon ( str(icon_path / 'trashcan.png') ) + QIconDB.pt_edit_mode = QIcon ( str(icon_path / 'pt_edit_mode.png') ) + QIconDB.view_lock_center = QIcon ( str(icon_path / 'view_lock_center.png') ) + QIconDB.view_baked = QIcon ( str(icon_path / 'view_baked.png') ) + QIconDB.view_xseg = QIcon ( str(icon_path / 'view_xseg.png') ) + QIconDB.view_xseg_overlay = QIcon ( str(icon_path / 'view_xseg_overlay.png') ) + \ No newline at end of file diff --git a/XSegEditor/QImageDB.py b/XSegEditor/QImageDB.py new file mode 100644 index 0000000..45cad78 --- /dev/null +++ b/XSegEditor/QImageDB.py @@ -0,0 +1,8 @@ +from PyQt5.QtCore import * +from PyQt5.QtGui import * +from PyQt5.QtWidgets import * + +class QImageDB(): + @staticmethod + def initialize(image_path): + QImageDB.intro = QImage ( str(image_path / 'intro.png') ) diff --git a/XSegEditor/QStringDB.py b/XSegEditor/QStringDB.py new file mode 100644 index 0000000..b9100d2 --- /dev/null +++ b/XSegEditor/QStringDB.py @@ -0,0 +1,102 @@ +from localization import system_language + + +class QStringDB(): + + @staticmethod + def initialize(): + lang = system_language + + if lang not in ['en','ru','zh']: + lang = 'en' + + QStringDB.btn_poly_color_red_tip = { 'en' : 'Poly color scheme red', + 'ru' : 'Красная цветовая схема полигонов', + 'zh' : '选区配色方案红色', + }[lang] + + QStringDB.btn_poly_color_green_tip = { 'en' : 'Poly color scheme green', + 'ru' : 'Зелёная цветовая схема полигонов', + 'zh' : '选区配色方案绿色', + }[lang] + + QStringDB.btn_poly_color_blue_tip = { 'en' : 'Poly color scheme blue', + 'ru' : 'Синяя цветовая схема полигонов', + 'zh' : '选区配色方案蓝色', + }[lang] + + QStringDB.btn_view_baked_mask_tip = { 'en' : 'View baked mask', + 'ru' : 'Посмотреть запечёную маску', + 'zh' : '查看遮罩通道', + }[lang] + + QStringDB.btn_view_xseg_mask_tip = { 'en' : 'View trained XSeg mask', + 'ru' : 'Посмотреть тренированную XSeg маску', + 'zh' : 
'查看导入后的XSeg遮罩', + }[lang] + + QStringDB.btn_view_xseg_overlay_mask_tip = { 'en' : 'View trained XSeg mask overlay face', + 'ru' : 'Посмотреть тренированную XSeg маску поверх лица', + 'zh' : '查看导入后的XSeg遮罩于脸上方', + }[lang] + + QStringDB.btn_poly_type_include_tip = { 'en' : 'Poly include mode', + 'ru' : 'Режим полигонов - включение', + 'zh' : '包含选区模式', + }[lang] + + QStringDB.btn_poly_type_exclude_tip = { 'en' : 'Poly exclude mode', + 'ru' : 'Режим полигонов - исключение', + 'zh' : '排除选区模式', + }[lang] + + QStringDB.btn_undo_pt_tip = { 'en' : 'Undo point', + 'ru' : 'Отменить точку', + 'zh' : '撤消点', + }[lang] + + QStringDB.btn_redo_pt_tip = { 'en' : 'Redo point', + 'ru' : 'Повторить точку', + 'zh' : '重做点', + }[lang] + + QStringDB.btn_delete_poly_tip = { 'en' : 'Delete poly', + 'ru' : 'Удалить полигон', + 'zh' : '删除选区', + }[lang] + + QStringDB.btn_pt_edit_mode_tip = { 'en' : 'Add/delete point mode ( HOLD CTRL )', + 'ru' : 'Режим добавления/удаления точек ( удерживайте CTRL )', + 'zh' : '点加/删除模式 ( 按住CTRL )', + }[lang] + + QStringDB.btn_view_lock_center_tip = { 'en' : 'Lock cursor at the center ( HOLD SHIFT )', + 'ru' : 'Заблокировать курсор в центре ( удерживайте SHIFT )', + 'zh' : '将光标锁定在中心 ( 按住SHIFT )', + }[lang] + + + QStringDB.btn_prev_image_tip = { 'en' : 'Save and Prev image\nHold SHIFT : accelerate\nHold CTRL : skip non masked\n', + 'ru' : 'Сохранить и предыдущее изображение\nУдерживать SHIFT : ускорить\nУдерживать CTRL : пропустить неразмеченные\n', + 'zh' : '保存并转到上一张图片\n按住SHIFT : 加快\n按住CTRL : 跳过未标记的\n', + }[lang] + QStringDB.btn_next_image_tip = { 'en' : 'Save and Next image\nHold SHIFT : accelerate\nHold CTRL : skip non masked\n', + 'ru' : 'Сохранить и следующее изображение\nУдерживать SHIFT : ускорить\nУдерживать CTRL : пропустить неразмеченные\n', + 'zh' : '保存并转到下一张图片\n按住SHIFT : 加快\n按住CTRL : 跳过未标记的\n', + }[lang] + + QStringDB.btn_delete_image_tip = { 'en' : 'Move to _trash and Next image\n', + 'ru' : 'Переместить в _trash и следующее изображение\n', + 'zh' : '移至_trash,转到下一张图片 ', + }[lang] + + QStringDB.loading_tip = {'en' : 'Loading', + 'ru' : 'Загрузка', + 'zh' : '正在载入', + }[lang] + + QStringDB.labeled_tip = {'en' : 'labeled', + 'ru' : 'размечено', + 'zh' : '标记的', + }[lang] + diff --git a/XSegEditor/XSegEditor.py b/XSegEditor/XSegEditor.py new file mode 100644 index 0000000..affc9f6 --- /dev/null +++ b/XSegEditor/XSegEditor.py @@ -0,0 +1,1494 @@ +import json +import multiprocessing +import os +import pickle +import sys +import tempfile +import time +import traceback +from enum import IntEnum +from types import SimpleNamespace as sn + +import cv2 +import numpy as np +import numpy.linalg as npla +from PyQt5.QtCore import * +from PyQt5.QtGui import * +from PyQt5.QtWidgets import * + +from core import imagelib, pathex +from core.cv2ex import * +from core.imagelib import SegIEPoly, SegIEPolys, SegIEPolyType, sd +from core.qtex import * +from DFLIMG import * +from localization import StringsDB, system_language +from samplelib import PackedFaceset + +from .QCursorDB import QCursorDB +from .QIconDB import QIconDB +from .QStringDB import QStringDB +from .QImageDB import QImageDB + +class OpMode(IntEnum): + NONE = 0 + DRAW_PTS = 1 + EDIT_PTS = 2 + VIEW_BAKED = 3 + VIEW_XSEG_MASK = 4 + +class PTEditMode(IntEnum): + MOVE = 0 + ADD_DEL = 1 + +class DragType(IntEnum): + NONE = 0 + IMAGE_LOOK = 1 + POLY_PT = 2 + +class ViewLock(IntEnum): + NONE = 0 + CENTER = 1 + +class QUIConfig(): + @staticmethod + def initialize(icon_size = 48, icon_spacer_size=16, preview_bar_icon_size=64): + 
QUIConfig.icon_q_size = QSize(icon_size, icon_size) + QUIConfig.icon_spacer_q_size = QSize(icon_spacer_size, icon_spacer_size) + QUIConfig.preview_bar_icon_q_size = QSize(preview_bar_icon_size, preview_bar_icon_size) + +class ImagePreviewSequenceBar(QFrame): + def __init__(self, preview_images_count, icon_size): + super().__init__() + self.preview_images_count = preview_images_count = max(1, preview_images_count + (preview_images_count % 2 -1) ) + + self.icon_size = icon_size + + black_q_img = QImage(np.zeros( (icon_size,icon_size,3) ).data, icon_size, icon_size, 3*icon_size, QImage.Format_RGB888) + self.black_q_pixmap = QPixmap.fromImage(black_q_img) + + self.image_containers = [ QLabel() for i in range(preview_images_count)] + + main_frame_l_cont_hl = QGridLayout() + main_frame_l_cont_hl.setContentsMargins(0,0,0,0) + #main_frame_l_cont_hl.setSpacing(0) + + + + for i in range(len(self.image_containers)): + q_label = self.image_containers[i] + q_label.setScaledContents(True) + if i == preview_images_count//2: + q_label.setMinimumSize(icon_size+16, icon_size+16 ) + q_label.setMaximumSize(icon_size+16, icon_size+16 ) + else: + q_label.setMinimumSize(icon_size, icon_size ) + q_label.setMaximumSize(icon_size, icon_size ) + opacity_effect = QGraphicsOpacityEffect() + opacity_effect.setOpacity(0.5) + q_label.setGraphicsEffect(opacity_effect) + + q_label.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + + main_frame_l_cont_hl.addWidget (q_label, 0, i) + + self.setLayout(main_frame_l_cont_hl) + + self.prev_img_conts = self.image_containers[(preview_images_count//2) -1::-1] + self.next_img_conts = self.image_containers[preview_images_count//2:] + + self.update_images() + + def get_preview_images_count(self): + return self.preview_images_count + + def update_images(self, prev_imgs=None, next_imgs=None): + # Fix arrays + if prev_imgs is None: + prev_imgs = [] + prev_img_conts_len = len(self.prev_img_conts) + prev_q_imgs_len = len(prev_imgs) + if prev_q_imgs_len < prev_img_conts_len: + for i in range ( prev_img_conts_len - prev_q_imgs_len ): + prev_imgs.append(None) + elif prev_q_imgs_len > prev_img_conts_len: + prev_imgs = prev_imgs[:prev_img_conts_len] + + if next_imgs is None: + next_imgs = [] + next_img_conts_len = len(self.next_img_conts) + next_q_imgs_len = len(next_imgs) + if next_q_imgs_len < next_img_conts_len: + for i in range ( next_img_conts_len - next_q_imgs_len ): + next_imgs.append(None) + elif next_q_imgs_len > next_img_conts_len: + next_imgs = next_imgs[:next_img_conts_len] + + for i,img in enumerate(prev_imgs): + self.prev_img_conts[i].setPixmap( QPixmap.fromImage( QImage_from_np(img) ) if img is not None else self.black_q_pixmap ) + + for i,img in enumerate(next_imgs): + self.next_img_conts[i].setPixmap( QPixmap.fromImage( QImage_from_np(img) ) if img is not None else self.black_q_pixmap ) + +class ColorScheme(): + def __init__(self, unselected_color, selected_color, outline_color, outline_width, pt_outline_color, cross_cursor): + self.poly_unselected_brush = QBrush(unselected_color) + self.poly_selected_brush = QBrush(selected_color) + + self.poly_outline_solid_pen = QPen(outline_color, outline_width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin) + self.poly_outline_dot_pen = QPen(outline_color, outline_width, Qt.DotLine, Qt.RoundCap, Qt.RoundJoin) + + self.pt_outline_pen = QPen(pt_outline_color) + self.cross_cursor = cross_cursor + +class CanvasConfig(): + + def __init__(self, + pt_radius=4, + pt_select_radius=8, + color_schemes=None, + **kwargs): + self.pt_radius = pt_radius + 
self.pt_select_radius = pt_select_radius + + if color_schemes is None: + color_schemes = [ + ColorScheme( QColor(192,0,0,alpha=0), QColor(192,0,0,alpha=72), QColor(192,0,0), 2, QColor(255,255,255), QCursorDB.cross_red ), + ColorScheme( QColor(0,192,0,alpha=0), QColor(0,192,0,alpha=72), QColor(0,192,0), 2, QColor(255,255,255), QCursorDB.cross_green ), + ColorScheme( QColor(0,0,192,alpha=0), QColor(0,0,192,alpha=72), QColor(0,0,192), 2, QColor(255,255,255), QCursorDB.cross_blue ), + ] + self.color_schemes = color_schemes + +class QCanvasControlsLeftBar(QFrame): + + def __init__(self): + super().__init__() + #============================================== + btn_poly_type_include = QToolButton() + self.btn_poly_type_include_act = QActionEx( QIconDB.poly_type_include, QStringDB.btn_poly_type_include_tip, shortcut='Q', shortcut_in_tooltip=True, is_checkable=True) + btn_poly_type_include.setDefaultAction(self.btn_poly_type_include_act) + btn_poly_type_include.setIconSize(QUIConfig.icon_q_size) + + btn_poly_type_exclude = QToolButton() + self.btn_poly_type_exclude_act = QActionEx( QIconDB.poly_type_exclude, QStringDB.btn_poly_type_exclude_tip, shortcut='W', shortcut_in_tooltip=True, is_checkable=True) + btn_poly_type_exclude.setDefaultAction(self.btn_poly_type_exclude_act) + btn_poly_type_exclude.setIconSize(QUIConfig.icon_q_size) + + self.btn_poly_type_act_grp = QActionGroup (self) + self.btn_poly_type_act_grp.addAction(self.btn_poly_type_include_act) + self.btn_poly_type_act_grp.addAction(self.btn_poly_type_exclude_act) + self.btn_poly_type_act_grp.setExclusive(True) + #============================================== + btn_undo_pt = QToolButton() + self.btn_undo_pt_act = QActionEx( QIconDB.undo_pt, QStringDB.btn_undo_pt_tip, shortcut='Ctrl+Z', shortcut_in_tooltip=True, is_auto_repeat=True) + btn_undo_pt.setDefaultAction(self.btn_undo_pt_act) + btn_undo_pt.setIconSize(QUIConfig.icon_q_size) + + btn_redo_pt = QToolButton() + self.btn_redo_pt_act = QActionEx( QIconDB.redo_pt, QStringDB.btn_redo_pt_tip, shortcut='Ctrl+Shift+Z', shortcut_in_tooltip=True, is_auto_repeat=True) + btn_redo_pt.setDefaultAction(self.btn_redo_pt_act) + btn_redo_pt.setIconSize(QUIConfig.icon_q_size) + + btn_delete_poly = QToolButton() + self.btn_delete_poly_act = QActionEx( QIconDB.delete_poly, QStringDB.btn_delete_poly_tip, shortcut='Delete', shortcut_in_tooltip=True) + btn_delete_poly.setDefaultAction(self.btn_delete_poly_act) + btn_delete_poly.setIconSize(QUIConfig.icon_q_size) + #============================================== + btn_pt_edit_mode = QToolButton() + self.btn_pt_edit_mode_act = QActionEx( QIconDB.pt_edit_mode, QStringDB.btn_pt_edit_mode_tip, shortcut_in_tooltip=True, is_checkable=True) + btn_pt_edit_mode.setDefaultAction(self.btn_pt_edit_mode_act) + btn_pt_edit_mode.setIconSize(QUIConfig.icon_q_size) + #============================================== + + controls_bar_frame2_l = QVBoxLayout() + controls_bar_frame2_l.addWidget ( btn_poly_type_include ) + controls_bar_frame2_l.addWidget ( btn_poly_type_exclude ) + controls_bar_frame2 = QFrame() + controls_bar_frame2.setFrameShape(QFrame.StyledPanel) + controls_bar_frame2.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + controls_bar_frame2.setLayout(controls_bar_frame2_l) + + controls_bar_frame3_l = QVBoxLayout() + controls_bar_frame3_l.addWidget ( btn_undo_pt ) + controls_bar_frame3_l.addWidget ( btn_redo_pt ) + controls_bar_frame3_l.addWidget ( btn_delete_poly ) + controls_bar_frame3 = QFrame() + controls_bar_frame3.setFrameShape(QFrame.StyledPanel) + 
controls_bar_frame3.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + controls_bar_frame3.setLayout(controls_bar_frame3_l) + + controls_bar_frame4_l = QVBoxLayout() + controls_bar_frame4_l.addWidget ( btn_pt_edit_mode ) + controls_bar_frame4 = QFrame() + controls_bar_frame4.setFrameShape(QFrame.StyledPanel) + controls_bar_frame4.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + controls_bar_frame4.setLayout(controls_bar_frame4_l) + + controls_bar_l = QVBoxLayout() + controls_bar_l.setContentsMargins(0,0,0,0) + controls_bar_l.addWidget(controls_bar_frame2) + controls_bar_l.addWidget(controls_bar_frame3) + controls_bar_l.addWidget(controls_bar_frame4) + + self.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Expanding ) + self.setLayout(controls_bar_l) + +class QCanvasControlsRightBar(QFrame): + + def __init__(self): + super().__init__() + #============================================== + btn_poly_color_red = QToolButton() + self.btn_poly_color_red_act = QActionEx( QIconDB.poly_color_red, QStringDB.btn_poly_color_red_tip, shortcut='1', shortcut_in_tooltip=True, is_checkable=True) + btn_poly_color_red.setDefaultAction(self.btn_poly_color_red_act) + btn_poly_color_red.setIconSize(QUIConfig.icon_q_size) + + btn_poly_color_green = QToolButton() + self.btn_poly_color_green_act = QActionEx( QIconDB.poly_color_green, QStringDB.btn_poly_color_green_tip, shortcut='2', shortcut_in_tooltip=True, is_checkable=True) + btn_poly_color_green.setDefaultAction(self.btn_poly_color_green_act) + btn_poly_color_green.setIconSize(QUIConfig.icon_q_size) + + btn_poly_color_blue = QToolButton() + self.btn_poly_color_blue_act = QActionEx( QIconDB.poly_color_blue, QStringDB.btn_poly_color_blue_tip, shortcut='3', shortcut_in_tooltip=True, is_checkable=True) + btn_poly_color_blue.setDefaultAction(self.btn_poly_color_blue_act) + btn_poly_color_blue.setIconSize(QUIConfig.icon_q_size) + + btn_view_baked_mask = QToolButton() + self.btn_view_baked_mask_act = QActionEx( QIconDB.view_baked, QStringDB.btn_view_baked_mask_tip, shortcut='4', shortcut_in_tooltip=True, is_checkable=True) + btn_view_baked_mask.setDefaultAction(self.btn_view_baked_mask_act) + btn_view_baked_mask.setIconSize(QUIConfig.icon_q_size) + + btn_view_xseg_mask = QToolButton() + self.btn_view_xseg_mask_act = QActionEx( QIconDB.view_xseg, QStringDB.btn_view_xseg_mask_tip, shortcut='5', shortcut_in_tooltip=True, is_checkable=True) + btn_view_xseg_mask.setDefaultAction(self.btn_view_xseg_mask_act) + btn_view_xseg_mask.setIconSize(QUIConfig.icon_q_size) + + btn_view_xseg_overlay_mask = QToolButton() + self.btn_view_xseg_overlay_mask_act = QActionEx( QIconDB.view_xseg_overlay, QStringDB.btn_view_xseg_overlay_mask_tip, shortcut='`', shortcut_in_tooltip=True, is_checkable=True) + btn_view_xseg_overlay_mask.setDefaultAction(self.btn_view_xseg_overlay_mask_act) + btn_view_xseg_overlay_mask.setIconSize(QUIConfig.icon_q_size) + + self.btn_poly_color_act_grp = QActionGroup (self) + self.btn_poly_color_act_grp.addAction(self.btn_poly_color_red_act) + self.btn_poly_color_act_grp.addAction(self.btn_poly_color_green_act) + self.btn_poly_color_act_grp.addAction(self.btn_poly_color_blue_act) + self.btn_poly_color_act_grp.addAction(self.btn_view_baked_mask_act) + self.btn_poly_color_act_grp.addAction(self.btn_view_xseg_mask_act) + self.btn_poly_color_act_grp.setExclusive(True) + #============================================== + btn_view_lock_center = QToolButton() + self.btn_view_lock_center_act = QActionEx( QIconDB.view_lock_center, QStringDB.btn_view_lock_center_tip, 
shortcut_in_tooltip=True, is_checkable=True) + btn_view_lock_center.setDefaultAction(self.btn_view_lock_center_act) + btn_view_lock_center.setIconSize(QUIConfig.icon_q_size) + + controls_bar_frame2_l = QVBoxLayout() + controls_bar_frame2_l.addWidget ( btn_view_xseg_overlay_mask ) + controls_bar_frame2 = QFrame() + controls_bar_frame2.setFrameShape(QFrame.StyledPanel) + controls_bar_frame2.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + controls_bar_frame2.setLayout(controls_bar_frame2_l) + + controls_bar_frame1_l = QVBoxLayout() + controls_bar_frame1_l.addWidget ( btn_poly_color_red ) + controls_bar_frame1_l.addWidget ( btn_poly_color_green ) + controls_bar_frame1_l.addWidget ( btn_poly_color_blue ) + controls_bar_frame1_l.addWidget ( btn_view_baked_mask ) + controls_bar_frame1_l.addWidget ( btn_view_xseg_mask ) + controls_bar_frame1 = QFrame() + controls_bar_frame1.setFrameShape(QFrame.StyledPanel) + controls_bar_frame1.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + controls_bar_frame1.setLayout(controls_bar_frame1_l) + + controls_bar_frame3_l = QVBoxLayout() + controls_bar_frame3_l.addWidget ( btn_view_lock_center ) + controls_bar_frame3 = QFrame() + controls_bar_frame3.setFrameShape(QFrame.StyledPanel) + controls_bar_frame3.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + controls_bar_frame3.setLayout(controls_bar_frame3_l) + + controls_bar_l = QVBoxLayout() + controls_bar_l.setContentsMargins(0,0,0,0) + controls_bar_l.addWidget(controls_bar_frame2) + controls_bar_l.addWidget(controls_bar_frame1) + controls_bar_l.addWidget(controls_bar_frame3) + + self.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Expanding ) + self.setLayout(controls_bar_l) + +class QCanvasOperator(QWidget): + def __init__(self, cbar): + super().__init__() + self.cbar = cbar + + self.set_cbar_disabled() + + self.cbar.btn_poly_color_red_act.triggered.connect ( lambda : self.set_color_scheme_id(0) ) + self.cbar.btn_poly_color_green_act.triggered.connect ( lambda : self.set_color_scheme_id(1) ) + self.cbar.btn_poly_color_blue_act.triggered.connect ( lambda : self.set_color_scheme_id(2) ) + self.cbar.btn_view_baked_mask_act.triggered.connect ( lambda : self.set_op_mode(OpMode.VIEW_BAKED) ) + self.cbar.btn_view_xseg_mask_act.triggered.connect ( lambda : self.set_op_mode(OpMode.VIEW_XSEG_MASK) ) + + self.cbar.btn_view_xseg_overlay_mask_act.toggled.connect ( lambda is_checked: self.update() ) + + self.cbar.btn_poly_type_include_act.triggered.connect ( lambda : self.set_poly_include_type(SegIEPolyType.INCLUDE) ) + self.cbar.btn_poly_type_exclude_act.triggered.connect ( lambda : self.set_poly_include_type(SegIEPolyType.EXCLUDE) ) + + self.cbar.btn_undo_pt_act.triggered.connect ( lambda : self.action_undo_pt() ) + self.cbar.btn_redo_pt_act.triggered.connect ( lambda : self.action_redo_pt() ) + + self.cbar.btn_delete_poly_act.triggered.connect ( lambda : self.action_delete_poly() ) + + self.cbar.btn_pt_edit_mode_act.toggled.connect ( lambda is_checked: self.set_pt_edit_mode( PTEditMode.ADD_DEL if is_checked else PTEditMode.MOVE ) ) + self.cbar.btn_view_lock_center_act.toggled.connect ( lambda is_checked: self.set_view_lock( ViewLock.CENTER if is_checked else ViewLock.NONE ) ) + + self.mouse_in_widget = False + + QXMainWindow.inst.add_keyPressEvent_listener ( self.on_keyPressEvent ) + QXMainWindow.inst.add_keyReleaseEvent_listener ( self.on_keyReleaseEvent ) + + self.qp = QPainter() + self.initialized = False + self.last_state = None + + def initialize(self, img, img_look_pt=None, view_scale=None, 
ie_polys=None, xseg_mask=None, canvas_config=None ): + q_img = self.q_img = QImage_from_np(img) + self.img_pixmap = QPixmap.fromImage(q_img) + + self.xseg_mask_pixmap = None + self.xseg_overlay_mask_pixmap = None + if xseg_mask is not None: + h,w,c = img.shape + xseg_mask = cv2.resize(xseg_mask, (w,h), interpolation=cv2.INTER_CUBIC) + xseg_mask = imagelib.normalize_channels(xseg_mask, 1) + xseg_img = img.astype(np.float32)/255.0 + xseg_overlay_mask = xseg_img*(1-xseg_mask)*0.5 + xseg_img*xseg_mask + xseg_overlay_mask = np.clip(xseg_overlay_mask*255, 0, 255).astype(np.uint8) + xseg_mask = np.clip(xseg_mask*255, 0, 255).astype(np.uint8) + self.xseg_mask_pixmap = QPixmap.fromImage(QImage_from_np(xseg_mask)) + self.xseg_overlay_mask_pixmap = QPixmap.fromImage(QImage_from_np(xseg_overlay_mask)) + + self.img_size = QSize_to_np (self.img_pixmap.size()) + + self.img_look_pt = img_look_pt + self.view_scale = view_scale + + if ie_polys is None: + ie_polys = SegIEPolys() + self.ie_polys = ie_polys + + if canvas_config is None: + canvas_config = CanvasConfig() + self.canvas_config = canvas_config + + # UI init + self.set_cbar_disabled() + self.cbar.btn_poly_color_act_grp.setDisabled(False) + self.cbar.btn_view_xseg_overlay_mask_act.setDisabled(False) + self.cbar.btn_poly_type_act_grp.setDisabled(False) + + # Initial vars + self.current_cursor = None + self.mouse_hull_poly = None + self.mouse_wire_poly = None + self.drag_type = DragType.NONE + self.mouse_cli_pt = np.zeros((2,), np.float32 ) + + # Initial state + self.set_op_mode(OpMode.NONE) + self.set_color_scheme_id(1) + self.set_poly_include_type(SegIEPolyType.INCLUDE) + self.set_pt_edit_mode(PTEditMode.MOVE) + self.set_view_lock(ViewLock.NONE) + + # Apply last state + if self.last_state is not None: + self.set_color_scheme_id(self.last_state.color_scheme_id) + if self.last_state.op_mode is not None: + self.set_op_mode(self.last_state.op_mode) + + self.initialized = True + + self.setMouseTracking(True) + self.update_cursor() + self.update() + + + def finalize(self): + if self.initialized: + if self.op_mode == OpMode.DRAW_PTS: + self.set_op_mode(OpMode.EDIT_PTS) + + self.last_state = sn(op_mode = self.op_mode if self.op_mode in [OpMode.VIEW_BAKED, OpMode.VIEW_XSEG_MASK] else None, + color_scheme_id = self.color_scheme_id) + + self.img_pixmap = None + self.update_cursor(is_finalize=True) + self.setMouseTracking(False) + self.setFocusPolicy(Qt.NoFocus) + self.set_cbar_disabled() + self.initialized = False + self.update() + + # ==================================================================================== + # ==================================================================================== + # ====================================== GETTERS ===================================== + # ==================================================================================== + # ==================================================================================== + def is_initialized(self): + return self.initialized + + def get_ie_polys(self): + return self.ie_polys + + def get_cli_center_pt(self): + return np.round(QSize_to_np(self.size())/2.0) + + def get_img_look_pt(self): + img_look_pt = self.img_look_pt + if img_look_pt is None: + img_look_pt = self.img_size / 2 + return img_look_pt + + def get_view_scale(self): + view_scale = self.view_scale + if view_scale is None: + # Calc as scale to fit + min_cli_size = np.min(QSize_to_np(self.size())) + max_img_size = np.max(self.img_size) + view_scale = min_cli_size / max_img_size + + return view_scale + + def 
get_current_color_scheme(self): + return self.canvas_config.color_schemes[self.color_scheme_id] + + def get_poly_pt_id_under_pt(self, poly, cli_pt): + w = np.argwhere ( npla.norm ( cli_pt - self.img_to_cli_pt( poly.get_pts() ), axis=1 ) <= self.canvas_config.pt_select_radius ) + return None if len(w) == 0 else w[-1][0] + + def get_poly_edge_id_pt_under_pt(self, poly, cli_pt): + cli_pts = self.img_to_cli_pt(poly.get_pts()) + if len(cli_pts) >= 3: + edge_dists, projs = sd.dist_to_edges(cli_pts, cli_pt, is_closed=True) + edge_id = np.argmin(edge_dists) + dist = edge_dists[edge_id] + pt = projs[edge_id] + if dist <= self.canvas_config.pt_select_radius: + return edge_id, pt + return None, None + + def get_poly_by_pt_near_wire(self, cli_pt): + pt_select_radius = self.canvas_config.pt_select_radius + + for poly in reversed(self.ie_polys.get_polys()): + pts = poly.get_pts() + if len(pts) >= 3: + cli_pts = self.img_to_cli_pt(pts) + + edge_dists, _ = sd.dist_to_edges(cli_pts, cli_pt, is_closed=True) + + if np.min(edge_dists) <= pt_select_radius or \ + any( npla.norm ( cli_pt - cli_pts, axis=1 ) <= pt_select_radius ): + return poly + return None + + def get_poly_by_pt_in_hull(self, cli_pos): + img_pos = self.cli_to_img_pt(cli_pos) + + for poly in reversed(self.ie_polys.get_polys()): + pts = poly.get_pts() + if len(pts) >= 3: + if cv2.pointPolygonTest( pts, tuple(img_pos), False) >= 0: + return poly + + return None + + def img_to_cli_pt(self, p): + return (p - self.get_img_look_pt()) * self.get_view_scale() + self.get_cli_center_pt()# QSize_to_np(self.size())/2.0 + + def cli_to_img_pt(self, p): + return (p - self.get_cli_center_pt() ) / self.get_view_scale() + self.get_img_look_pt() + + def img_to_cli_rect(self, rect): + tl = QPoint_to_np(rect.topLeft()) + xy = self.img_to_cli_pt(tl) + xy2 = self.img_to_cli_pt(tl + QSize_to_np(rect.size()) ) - xy + return QRect ( *xy.astype(np.int), *xy2.astype(np.int) ) + + # ==================================================================================== + # ==================================================================================== + # ====================================== SETTERS ===================================== + # ==================================================================================== + # ==================================================================================== + def set_op_mode(self, op_mode, op_poly=None): + if not hasattr(self,'op_mode'): + self.op_mode = None + self.op_poly = None + + if self.op_mode != op_mode: + # Finalize prev mode + if self.op_mode == OpMode.NONE: + self.cbar.btn_poly_type_act_grp.setDisabled(True) + elif self.op_mode == OpMode.DRAW_PTS: + self.cbar.btn_undo_pt_act.setDisabled(True) + self.cbar.btn_redo_pt_act.setDisabled(True) + self.cbar.btn_view_lock_center_act.setDisabled(True) + # Reset view_lock when exit from DRAW_PTS + self.set_view_lock(ViewLock.NONE) + # Remove unfinished poly + if self.op_poly.get_pts_count() < 3: + self.ie_polys.remove_poly(self.op_poly) + + elif self.op_mode == OpMode.EDIT_PTS: + self.cbar.btn_pt_edit_mode_act.setDisabled(True) + self.cbar.btn_delete_poly_act.setDisabled(True) + # Reset pt_edit_move when exit from EDIT_PTS + self.set_pt_edit_mode(PTEditMode.MOVE) + elif self.op_mode == OpMode.VIEW_BAKED: + self.cbar.btn_view_baked_mask_act.setChecked(False) + elif self.op_mode == OpMode.VIEW_XSEG_MASK: + self.cbar.btn_view_xseg_mask_act.setChecked(False) + + self.op_mode = op_mode + + # Initialize new mode + if op_mode == OpMode.NONE: + 
self.cbar.btn_poly_type_act_grp.setDisabled(False) + elif op_mode == OpMode.DRAW_PTS: + self.cbar.btn_undo_pt_act.setDisabled(False) + self.cbar.btn_redo_pt_act.setDisabled(False) + self.cbar.btn_view_lock_center_act.setDisabled(False) + elif op_mode == OpMode.EDIT_PTS: + self.cbar.btn_pt_edit_mode_act.setDisabled(False) + self.cbar.btn_delete_poly_act.setDisabled(False) + elif op_mode == OpMode.VIEW_BAKED: + self.cbar.btn_view_baked_mask_act.setChecked(True ) + n = QImage_to_np ( self.q_img ).astype(np.float32) / 255.0 + h,w,c = n.shape + mask = np.zeros( (h,w,1), dtype=np.float32 ) + self.ie_polys.overlay_mask(mask) + n = (mask*255).astype(np.uint8) + self.img_baked_pixmap = QPixmap.fromImage(QImage_from_np(n)) + elif op_mode == OpMode.VIEW_XSEG_MASK: + self.cbar.btn_view_xseg_mask_act.setChecked(True) + + if op_mode in [OpMode.DRAW_PTS, OpMode.EDIT_PTS]: + self.mouse_op_poly_pt_id = None + self.mouse_op_poly_edge_id = None + self.mouse_op_poly_edge_id_pt = None + + self.op_poly = op_poly + if op_poly is not None: + self.update_mouse_info() + + self.update_cursor() + self.update() + + def set_pt_edit_mode(self, pt_edit_mode): + if not hasattr(self, 'pt_edit_mode') or self.pt_edit_mode != pt_edit_mode: + self.pt_edit_mode = pt_edit_mode + self.update_cursor() + self.update() + self.cbar.btn_pt_edit_mode_act.setChecked( self.pt_edit_mode == PTEditMode.ADD_DEL ) + + def set_view_lock(self, view_lock): + if not hasattr(self, 'view_lock') or self.view_lock != view_lock: + if hasattr(self, 'view_lock') and self.view_lock != view_lock: + if view_lock == ViewLock.CENTER: + self.img_look_pt = self.mouse_img_pt + QCursor.setPos ( self.mapToGlobal( QPoint_from_np(self.img_to_cli_pt(self.img_look_pt)) )) + + self.view_lock = view_lock + self.update() + self.cbar.btn_view_lock_center_act.setChecked( self.view_lock == ViewLock.CENTER ) + + def set_cbar_disabled(self): + self.cbar.btn_delete_poly_act.setDisabled(True) + self.cbar.btn_undo_pt_act.setDisabled(True) + self.cbar.btn_redo_pt_act.setDisabled(True) + self.cbar.btn_pt_edit_mode_act.setDisabled(True) + self.cbar.btn_view_lock_center_act.setDisabled(True) + self.cbar.btn_poly_color_act_grp.setDisabled(True) + self.cbar.btn_view_xseg_overlay_mask_act.setDisabled(True) + self.cbar.btn_poly_type_act_grp.setDisabled(True) + + + def set_color_scheme_id(self, id): + if self.op_mode == OpMode.VIEW_BAKED or self.op_mode == OpMode.VIEW_XSEG_MASK: + self.set_op_mode(OpMode.NONE) + + if not hasattr(self, 'color_scheme_id') or self.color_scheme_id != id: + self.color_scheme_id = id + self.update_cursor() + self.update() + + if self.color_scheme_id == 0: + self.cbar.btn_poly_color_red_act.setChecked( True ) + elif self.color_scheme_id == 1: + self.cbar.btn_poly_color_green_act.setChecked( True ) + elif self.color_scheme_id == 2: + self.cbar.btn_poly_color_blue_act.setChecked( True ) + + def set_poly_include_type(self, poly_include_type): + if not hasattr(self, 'poly_include_type' ) or \ + ( self.poly_include_type != poly_include_type and \ + self.op_mode in [OpMode.NONE, OpMode.EDIT_PTS] ): + self.poly_include_type = poly_include_type + self.update() + self.cbar.btn_poly_type_include_act.setChecked(self.poly_include_type == SegIEPolyType.INCLUDE) + self.cbar.btn_poly_type_exclude_act.setChecked(self.poly_include_type == SegIEPolyType.EXCLUDE) + + # ==================================================================================== + # ==================================================================================== + # 
====================================== METHODS ===================================== + # ==================================================================================== + # ==================================================================================== + + def update_cursor(self, is_finalize=False): + if not self.initialized: + return + + if not self.mouse_in_widget or is_finalize: + if self.current_cursor is not None: + QApplication.restoreOverrideCursor() + self.current_cursor = None + else: + color_cc = self.get_current_color_scheme().cross_cursor + nc = Qt.ArrowCursor + + if self.drag_type == DragType.IMAGE_LOOK: + nc = Qt.ClosedHandCursor + else: + + if self.op_mode == OpMode.NONE: + nc = color_cc + if self.mouse_wire_poly is not None: + nc = Qt.PointingHandCursor + + elif self.op_mode == OpMode.DRAW_PTS: + nc = color_cc + elif self.op_mode == OpMode.EDIT_PTS: + nc = Qt.ArrowCursor + + if self.mouse_op_poly_pt_id is not None: + nc = Qt.PointingHandCursor + + if self.pt_edit_mode == PTEditMode.ADD_DEL: + + if self.mouse_op_poly_edge_id is not None and \ + self.mouse_op_poly_pt_id is None: + nc = color_cc + if self.current_cursor != nc: + if self.current_cursor is None: + QApplication.setOverrideCursor(nc) + else: + QApplication.changeOverrideCursor(nc) + self.current_cursor = nc + + def update_mouse_info(self, mouse_cli_pt=None): + """ + Update selected polys/edges/points by given mouse position + """ + if mouse_cli_pt is not None: + self.mouse_cli_pt = mouse_cli_pt.astype(np.float32) + + self.mouse_img_pt = self.cli_to_img_pt(self.mouse_cli_pt) + + new_mouse_hull_poly = self.get_poly_by_pt_in_hull(self.mouse_cli_pt) + + if self.mouse_hull_poly != new_mouse_hull_poly: + self.mouse_hull_poly = new_mouse_hull_poly + self.update_cursor() + self.update() + + new_mouse_wire_poly = self.get_poly_by_pt_near_wire(self.mouse_cli_pt) + + if self.mouse_wire_poly != new_mouse_wire_poly: + self.mouse_wire_poly = new_mouse_wire_poly + self.update_cursor() + self.update() + + if self.op_mode in [OpMode.DRAW_PTS, OpMode.EDIT_PTS]: + new_mouse_op_poly_pt_id = self.get_poly_pt_id_under_pt (self.op_poly, self.mouse_cli_pt) + if self.mouse_op_poly_pt_id != new_mouse_op_poly_pt_id: + self.mouse_op_poly_pt_id = new_mouse_op_poly_pt_id + self.update_cursor() + self.update() + + new_mouse_op_poly_edge_id,\ + new_mouse_op_poly_edge_id_pt = self.get_poly_edge_id_pt_under_pt (self.op_poly, self.mouse_cli_pt) + if self.mouse_op_poly_edge_id != new_mouse_op_poly_edge_id: + self.mouse_op_poly_edge_id = new_mouse_op_poly_edge_id + self.update_cursor() + self.update() + + if (self.mouse_op_poly_edge_id_pt.__class__ != new_mouse_op_poly_edge_id_pt.__class__) or \ + (isinstance(self.mouse_op_poly_edge_id_pt, np.ndarray) and \ + all(self.mouse_op_poly_edge_id_pt != new_mouse_op_poly_edge_id_pt)): + + self.mouse_op_poly_edge_id_pt = new_mouse_op_poly_edge_id_pt + self.update_cursor() + self.update() + + + def action_undo_pt(self): + if self.drag_type == DragType.NONE: + if self.op_mode == OpMode.DRAW_PTS: + if self.op_poly.undo() == 0: + self.ie_polys.remove_poly (self.op_poly) + self.set_op_mode(OpMode.NONE) + self.update() + + def action_redo_pt(self): + if self.drag_type == DragType.NONE: + if self.op_mode == OpMode.DRAW_PTS: + self.op_poly.redo() + self.update() + + def action_delete_poly(self): + if self.op_mode == OpMode.EDIT_PTS and \ + self.drag_type == DragType.NONE and \ + self.pt_edit_mode == PTEditMode.MOVE: + # Delete current poly + self.ie_polys.remove_poly (self.op_poly) + 
self.set_op_mode(OpMode.NONE) + + # ==================================================================================== + # ==================================================================================== + # ================================== OVERRIDE QT METHODS ============================= + # ==================================================================================== + # ==================================================================================== + def on_keyPressEvent(self, ev): + if not self.initialized: + return + key = ev.key() + key_mods = int(ev.modifiers()) + if self.op_mode == OpMode.DRAW_PTS: + self.set_view_lock(ViewLock.CENTER if key_mods == Qt.ShiftModifier else ViewLock.NONE ) + elif self.op_mode == OpMode.EDIT_PTS: + self.set_pt_edit_mode(PTEditMode.ADD_DEL if key_mods == Qt.ControlModifier else PTEditMode.MOVE ) + + def on_keyReleaseEvent(self, ev): + if not self.initialized: + return + key = ev.key() + key_mods = int(ev.modifiers()) + if self.op_mode == OpMode.DRAW_PTS: + self.set_view_lock(ViewLock.CENTER if key_mods == Qt.ShiftModifier else ViewLock.NONE ) + elif self.op_mode == OpMode.EDIT_PTS: + self.set_pt_edit_mode(PTEditMode.ADD_DEL if key_mods == Qt.ControlModifier else PTEditMode.MOVE ) + + def enterEvent(self, ev): + super().enterEvent(ev) + self.mouse_in_widget = True + self.update_cursor() + + def leaveEvent(self, ev): + super().leaveEvent(ev) + self.mouse_in_widget = False + self.update_cursor() + + def mousePressEvent(self, ev): + super().mousePressEvent(ev) + if not self.initialized: + return + + self.update_mouse_info(QPoint_to_np(ev.pos())) + + btn = ev.button() + + if btn == Qt.LeftButton: + if self.op_mode == OpMode.NONE: + # Clicking in NO OPERATION mode + if self.mouse_wire_poly is not None: + # Click on wire on any poly -> switch to EDIT_MODE + self.set_op_mode(OpMode.EDIT_PTS, op_poly=self.mouse_wire_poly) + else: + # Click on empty space -> create new poly with one point + new_poly = self.ie_polys.add_poly(self.poly_include_type) + self.ie_polys.sort() + new_poly.add_pt(*self.mouse_img_pt) + self.set_op_mode(OpMode.DRAW_PTS, op_poly=new_poly ) + + elif self.op_mode == OpMode.DRAW_PTS: + # Clicking in DRAW_PTS mode + if len(self.op_poly.get_pts()) >= 3 and self.mouse_op_poly_pt_id == 0: + # Click on first point -> close poly and switch to edit mode + self.set_op_mode(OpMode.EDIT_PTS, op_poly=self.op_poly) + else: + # Click on empty space -> add point to current poly + self.op_poly.add_pt(*self.mouse_img_pt) + self.update() + + elif self.op_mode == OpMode.EDIT_PTS: + # Clicking in EDIT_PTS mode + + if self.mouse_op_poly_pt_id is not None: + # Click on point of op_poly + if self.pt_edit_mode == PTEditMode.ADD_DEL: + # in mode 'delete point' + self.op_poly.remove_pt(self.mouse_op_poly_pt_id) + if self.op_poly.get_pts_count() < 3: + # not enough points after delete -> remove poly + self.ie_polys.remove_poly (self.op_poly) + self.set_op_mode(OpMode.NONE) + self.update() + + elif self.drag_type == DragType.NONE: + # otherwise -> start drag + self.drag_type = DragType.POLY_PT + self.drag_cli_pt = self.mouse_cli_pt + self.drag_poly_pt_id = self.mouse_op_poly_pt_id + self.drag_poly_pt = self.op_poly.get_pts()[ self.drag_poly_pt_id ] + elif self.mouse_op_poly_edge_id is not None: + # Click on edge of op_poly + if self.pt_edit_mode == PTEditMode.ADD_DEL: + # in mode 'insert new point' + edge_img_pt = self.cli_to_img_pt(self.mouse_op_poly_edge_id_pt) + self.op_poly.insert_pt (self.mouse_op_poly_edge_id+1, edge_img_pt) + self.update() + 
else: + # Otherwise do nothing + pass + else: + # other cases -> unselect poly + self.set_op_mode(OpMode.NONE) + + elif btn == Qt.MiddleButton: + if self.drag_type == DragType.NONE: + # Start image drag + self.drag_type = DragType.IMAGE_LOOK + self.drag_cli_pt = self.mouse_cli_pt + self.drag_img_look_pt = self.get_img_look_pt() + self.update_cursor() + + + def mouseReleaseEvent(self, ev): + super().mouseReleaseEvent(ev) + if not self.initialized: + return + + self.update_mouse_info(QPoint_to_np(ev.pos())) + + btn = ev.button() + + if btn == Qt.LeftButton: + if self.op_mode == OpMode.EDIT_PTS: + if self.drag_type == DragType.POLY_PT: + self.drag_type = DragType.NONE + self.update() + + elif btn == Qt.MiddleButton: + if self.drag_type == DragType.IMAGE_LOOK: + self.drag_type = DragType.NONE + self.update_cursor() + self.update() + + def mouseMoveEvent(self, ev): + super().mouseMoveEvent(ev) + if not self.initialized: + return + + prev_mouse_cli_pt = self.mouse_cli_pt + self.update_mouse_info(QPoint_to_np(ev.pos())) + + if self.view_lock == ViewLock.CENTER: + if npla.norm(self.mouse_cli_pt - prev_mouse_cli_pt) >= 1: + self.img_look_pt = self.mouse_img_pt + QCursor.setPos ( self.mapToGlobal( QPoint_from_np(self.img_to_cli_pt(self.img_look_pt)) )) + self.update() + + if self.drag_type == DragType.IMAGE_LOOK: + delta_pt = self.cli_to_img_pt(self.mouse_cli_pt) - self.cli_to_img_pt(self.drag_cli_pt) + self.img_look_pt = self.drag_img_look_pt - delta_pt + self.update() + + if self.op_mode == OpMode.DRAW_PTS: + self.update() + elif self.op_mode == OpMode.EDIT_PTS: + if self.drag_type == DragType.POLY_PT: + delta_pt = self.cli_to_img_pt(self.mouse_cli_pt) - self.cli_to_img_pt(self.drag_cli_pt) + self.op_poly.set_point(self.drag_poly_pt_id, self.drag_poly_pt + delta_pt) + self.update() + + def wheelEvent(self, ev): + super().wheelEvent(ev) + + if not self.initialized: + return + + mods = int(ev.modifiers()) + delta = ev.angleDelta() + + cli_pt = QPoint_to_np(ev.pos()) + + if self.drag_type == DragType.NONE: + sign = np.sign( delta.y() ) + prev_img_pos = self.cli_to_img_pt (cli_pt) + delta_scale = sign*0.2 + sign * self.get_view_scale() / 10.0 + self.view_scale = np.clip(self.get_view_scale() + delta_scale, 1.0, 20.0) + new_img_pos = self.cli_to_img_pt (cli_pt) + if sign > 0: + self.img_look_pt = self.get_img_look_pt() + (prev_img_pos-new_img_pos)#*1.5 + else: + QCursor.setPos ( self.mapToGlobal(QPoint_from_np(self.img_to_cli_pt(prev_img_pos))) ) + self.update() + + def paintEvent(self, event): + super().paintEvent(event) + if not self.initialized: + return + + qp = self.qp + qp.begin(self) + qp.setRenderHint(QPainter.Antialiasing) + qp.setRenderHint(QPainter.HighQualityAntialiasing) + qp.setRenderHint(QPainter.SmoothPixmapTransform) + + src_rect = QRect(0, 0, *self.img_size) + dst_rect = self.img_to_cli_rect( src_rect ) + + if self.op_mode == OpMode.VIEW_BAKED: + qp.drawPixmap(dst_rect, self.img_baked_pixmap, src_rect) + elif self.op_mode == OpMode.VIEW_XSEG_MASK: + if self.xseg_mask_pixmap is not None: + qp.drawPixmap(dst_rect, self.xseg_mask_pixmap, src_rect) + else: + if self.cbar.btn_view_xseg_overlay_mask_act.isChecked() and \ + self.xseg_overlay_mask_pixmap is not None: + qp.drawPixmap(dst_rect, self.xseg_overlay_mask_pixmap, src_rect) + elif self.img_pixmap is not None: + qp.drawPixmap(dst_rect, self.img_pixmap, src_rect) + + polys = self.ie_polys.get_polys() + polys_len = len(polys) + + color_scheme = self.get_current_color_scheme() + + pt_rad = self.canvas_config.pt_radius + pt_rad_x2 = 
pt_rad*2 + + pt_select_radius = self.canvas_config.pt_select_radius + + op_mode = self.op_mode + op_poly = self.op_poly + + for i,poly in enumerate(polys): + + selected_pt_path = QPainterPath() + poly_line_path = QPainterPath() + pts_line_path = QPainterPath() + + pt_remove_cli_pt = None + poly_pts = poly.get_pts() + for pt_id, img_pt in enumerate(poly_pts): + cli_pt = self.img_to_cli_pt(img_pt) + q_cli_pt = QPoint_from_np(cli_pt) + + if pt_id == 0: + poly_line_path.moveTo(q_cli_pt) + else: + poly_line_path.lineTo(q_cli_pt) + + + if poly == op_poly: + if self.op_mode == OpMode.DRAW_PTS or \ + (self.op_mode == OpMode.EDIT_PTS and \ + (self.pt_edit_mode == PTEditMode.MOVE) or \ + (self.pt_edit_mode == PTEditMode.ADD_DEL and self.mouse_op_poly_pt_id == pt_id) \ + ): + pts_line_path.moveTo( QPoint_from_np(cli_pt + np.float32([0,-pt_rad])) ) + pts_line_path.lineTo( QPoint_from_np(cli_pt + np.float32([0,pt_rad])) ) + pts_line_path.moveTo( QPoint_from_np(cli_pt + np.float32([-pt_rad,0])) ) + pts_line_path.lineTo( QPoint_from_np(cli_pt + np.float32([pt_rad,0])) ) + + if (self.op_mode == OpMode.EDIT_PTS and \ + self.pt_edit_mode == PTEditMode.ADD_DEL and \ + self.mouse_op_poly_pt_id == pt_id): + pt_remove_cli_pt = cli_pt + + if self.op_mode == OpMode.DRAW_PTS and \ + len(op_poly.get_pts()) >= 3 and pt_id == 0 and self.mouse_op_poly_pt_id == pt_id: + # Circle around poly point + selected_pt_path.addEllipse(q_cli_pt, pt_rad_x2, pt_rad_x2) + + + if poly == op_poly: + if op_mode == OpMode.DRAW_PTS: + # Line from last point to mouse + poly_line_path.lineTo( QPoint_from_np(self.mouse_cli_pt) ) + + if self.mouse_op_poly_pt_id is not None: + pass + + if self.mouse_op_poly_edge_id_pt is not None: + if self.pt_edit_mode == PTEditMode.ADD_DEL and self.mouse_op_poly_pt_id is None: + # Ready to insert point on edge + m_cli_pt = self.mouse_op_poly_edge_id_pt + pts_line_path.moveTo( QPoint_from_np(m_cli_pt + np.float32([0,-pt_rad])) ) + pts_line_path.lineTo( QPoint_from_np(m_cli_pt + np.float32([0,pt_rad])) ) + pts_line_path.moveTo( QPoint_from_np(m_cli_pt + np.float32([-pt_rad,0])) ) + pts_line_path.lineTo( QPoint_from_np(m_cli_pt + np.float32([pt_rad,0])) ) + + if len(poly_pts) >= 2: + # Closing poly line + poly_line_path.lineTo( QPoint_from_np(self.img_to_cli_pt(poly_pts[0])) ) + + # Draw calls + qp.setPen(color_scheme.pt_outline_pen) + qp.setBrush(QBrush()) + qp.drawPath(selected_pt_path) + + qp.setPen(color_scheme.poly_outline_solid_pen) + qp.setBrush(QBrush()) + qp.drawPath(pts_line_path) + + if poly.get_type() == SegIEPolyType.INCLUDE: + qp.setPen(color_scheme.poly_outline_solid_pen) + else: + qp.setPen(color_scheme.poly_outline_dot_pen) + + qp.setBrush(color_scheme.poly_unselected_brush) + if op_mode == OpMode.NONE: + if poly == self.mouse_wire_poly: + qp.setBrush(color_scheme.poly_selected_brush) + #else: + # if poly == op_poly: + # qp.setBrush(color_scheme.poly_selected_brush) + + qp.drawPath(poly_line_path) + + if pt_remove_cli_pt is not None: + qp.setPen(color_scheme.poly_outline_solid_pen) + qp.setBrush(QBrush()) + + qp.drawLine( *(pt_remove_cli_pt + np.float32([-pt_rad_x2,-pt_rad_x2])), *(pt_remove_cli_pt + np.float32([pt_rad_x2,pt_rad_x2])) ) + qp.drawLine( *(pt_remove_cli_pt + np.float32([-pt_rad_x2,pt_rad_x2])), *(pt_remove_cli_pt + np.float32([pt_rad_x2,-pt_rad_x2])) ) + + qp.end() + +class QCanvas(QFrame): + def __init__(self): + super().__init__() + + self.canvas_control_left_bar = QCanvasControlsLeftBar() + self.canvas_control_right_bar = QCanvasControlsRightBar() + + cbar = sn( 
btn_poly_color_red_act = self.canvas_control_right_bar.btn_poly_color_red_act, + btn_poly_color_green_act = self.canvas_control_right_bar.btn_poly_color_green_act, + btn_poly_color_blue_act = self.canvas_control_right_bar.btn_poly_color_blue_act, + btn_view_baked_mask_act = self.canvas_control_right_bar.btn_view_baked_mask_act, + btn_view_xseg_mask_act = self.canvas_control_right_bar.btn_view_xseg_mask_act, + btn_view_xseg_overlay_mask_act = self.canvas_control_right_bar.btn_view_xseg_overlay_mask_act, + btn_poly_color_act_grp = self.canvas_control_right_bar.btn_poly_color_act_grp, + btn_view_lock_center_act = self.canvas_control_right_bar.btn_view_lock_center_act, + + btn_poly_type_include_act = self.canvas_control_left_bar.btn_poly_type_include_act, + btn_poly_type_exclude_act = self.canvas_control_left_bar.btn_poly_type_exclude_act, + btn_poly_type_act_grp = self.canvas_control_left_bar.btn_poly_type_act_grp, + btn_undo_pt_act = self.canvas_control_left_bar.btn_undo_pt_act, + btn_redo_pt_act = self.canvas_control_left_bar.btn_redo_pt_act, + btn_delete_poly_act = self.canvas_control_left_bar.btn_delete_poly_act, + btn_pt_edit_mode_act = self.canvas_control_left_bar.btn_pt_edit_mode_act ) + + self.op = QCanvasOperator(cbar) + self.l = QHBoxLayout() + self.l.setContentsMargins(0,0,0,0) + self.l.addWidget(self.canvas_control_left_bar) + self.l.addWidget(self.op) + self.l.addWidget(self.canvas_control_right_bar) + self.setLayout(self.l) + +class LoaderQSubprocessor(QSubprocessor): + def __init__(self, image_paths, q_label, q_progressbar, on_finish_func ): + + self.image_paths = image_paths + self.image_paths_len = len(image_paths) + self.idxs = [*range(self.image_paths_len)] + + self.filtered_image_paths = self.image_paths.copy() + + self.image_paths_has_ie_polys = { image_path : False for image_path in self.image_paths } + + self.q_label = q_label + self.q_progressbar = q_progressbar + self.q_progressbar.setRange(0, self.image_paths_len) + self.q_progressbar.setValue(0) + self.q_progressbar.update() + self.on_finish_func = on_finish_func + self.done_count = 0 + super().__init__('LoaderQSubprocessor', LoaderQSubprocessor.Cli, 60) + + def get_data(self, host_dict): + if len (self.idxs) > 0: + idx = self.idxs.pop(0) + image_path = self.image_paths[idx] + self.q_label.setText(f'{QStringDB.loading_tip}... 
{image_path.name}') + + return idx, image_path + + return None + + def on_clients_finalized(self): + self.on_finish_func([x for x in self.filtered_image_paths if x is not None], self.image_paths_has_ie_polys) + + def on_data_return (self, host_dict, data): + self.idxs.insert(0, data[0]) + + def on_result (self, host_dict, data, result): + idx, has_dflimg, has_ie_polys = result + + if not has_dflimg: + self.filtered_image_paths[idx] = None + self.image_paths_has_ie_polys[self.image_paths[idx]] = has_ie_polys + + self.done_count += 1 + if self.q_progressbar is not None: + self.q_progressbar.setValue(self.done_count) + + class Cli(QSubprocessor.Cli): + def process_data(self, data): + idx, filename = data + dflimg = DFLIMG.load(filename) + if dflimg is not None and dflimg.has_data(): + ie_polys = dflimg.get_seg_ie_polys() + + return idx, True, ie_polys.has_polys() + return idx, False, False + +class MainWindow(QXMainWindow): + + def __init__(self, input_dirpath, cfg_root_path): + self.loading_frame = None + self.help_frame = None + + super().__init__() + + self.input_dirpath = input_dirpath + self.trash_dirpath = input_dirpath.parent / (input_dirpath.name + '_trash') + self.cfg_root_path = cfg_root_path + + self.cfg_path = cfg_root_path / 'MainWindow_cfg.dat' + self.cfg_dict = pickle.loads(self.cfg_path.read_bytes()) if self.cfg_path.exists() else {} + + self.cached_images = {} + self.cached_has_ie_polys = {} + + self.initialize_ui() + + # Loader + self.loading_frame = QFrame(self.main_canvas_frame) + self.loading_frame.setAutoFillBackground(True) + self.loading_frame.setFrameShape(QFrame.StyledPanel) + self.loader_label = QLabel() + self.loader_progress_bar = QProgressBar() + + intro_image = QLabel() + intro_image.setPixmap( QPixmap.fromImage(QImageDB.intro) ) + + intro_image_frame_l = QVBoxLayout() + intro_image_frame_l.addWidget(intro_image, alignment=Qt.AlignCenter) + intro_image_frame = QFrame() + intro_image_frame.setSizePolicy (QSizePolicy.Expanding, QSizePolicy.Expanding) + intro_image_frame.setLayout(intro_image_frame_l) + + loading_frame_l = QVBoxLayout() + loading_frame_l.addWidget (intro_image_frame) + loading_frame_l.addWidget (self.loader_label) + loading_frame_l.addWidget (self.loader_progress_bar) + self.loading_frame.setLayout(loading_frame_l) + + self.loader_subprocessor = LoaderQSubprocessor( image_paths=pathex.get_image_paths(input_dirpath, return_Path_class=True), + q_label=self.loader_label, + q_progressbar=self.loader_progress_bar, + on_finish_func=self.on_loader_finish ) + + + def on_loader_finish(self, image_paths, image_paths_has_ie_polys): + self.image_paths_done = [] + self.image_paths = image_paths + self.image_paths_has_ie_polys = image_paths_has_ie_polys + self.set_has_ie_polys_count ( len([ 1 for x in self.image_paths_has_ie_polys if self.image_paths_has_ie_polys[x] == True]) ) + self.loading_frame.hide() + self.loading_frame = None + + self.process_next_image(first_initialization=True) + + def closeEvent(self, ev): + self.cfg_dict['geometry'] = self.saveGeometry().data() + self.cfg_path.write_bytes( pickle.dumps(self.cfg_dict) ) + + + def update_cached_images (self, count=5): + d = self.cached_images + + for image_path in self.image_paths_done[:-count]+self.image_paths[count:]: + if image_path in d: + del d[image_path] + + for image_path in self.image_paths[:count]+self.image_paths_done[-count:]: + if image_path not in d: + img = cv2_imread(image_path) + if img is not None: + d[image_path] = img + + def load_image(self, image_path): + try: + img = 
self.cached_images.get(image_path, None) + if img is None: + img = cv2_imread(image_path) + self.cached_images[image_path] = img + if img is None: + io.log_err(f'Unable to load {image_path}') + except: + img = None + + return img + + def update_preview_bar(self): + count = self.image_bar.get_preview_images_count() + d = self.cached_images + prev_imgs = [ d.get(image_path, None) for image_path in self.image_paths_done[-1:-count:-1] ] + next_imgs = [ d.get(image_path, None) for image_path in self.image_paths[:count] ] + self.image_bar.update_images(prev_imgs, next_imgs) + + + def canvas_initialize(self, image_path, only_has_polys=False): + if only_has_polys and not self.image_paths_has_ie_polys[image_path]: + return False + + dflimg = DFLIMG.load(image_path) + if not dflimg or not dflimg.has_data(): + return False + + ie_polys = dflimg.get_seg_ie_polys() + xseg_mask = dflimg.get_xseg_mask() + img = self.load_image(image_path) + if img is None: + return False + + self.canvas.op.initialize ( img, ie_polys=ie_polys, xseg_mask=xseg_mask ) + + self.filename_label.setText(f"{image_path.name}") + + return True + + def canvas_finalize(self, image_path): + self.canvas.op.finalize() + + if image_path.exists(): + dflimg = DFLIMG.load(image_path) + ie_polys = dflimg.get_seg_ie_polys() + new_ie_polys = self.canvas.op.get_ie_polys() + + if not new_ie_polys.identical(ie_polys): + prev_has_polys = self.image_paths_has_ie_polys[image_path] + self.image_paths_has_ie_polys[image_path] = new_ie_polys.has_polys() + new_has_polys = self.image_paths_has_ie_polys[image_path] + + if not prev_has_polys and new_has_polys: + self.set_has_ie_polys_count ( self.get_has_ie_polys_count() +1) + elif prev_has_polys and not new_has_polys: + self.set_has_ie_polys_count ( self.get_has_ie_polys_count() -1) + + dflimg.set_seg_ie_polys( new_ie_polys ) + dflimg.save() + + self.filename_label.setText(f"") + + def process_prev_image(self): + key_mods = QApplication.keyboardModifiers() + step = 5 if key_mods == Qt.ShiftModifier else 1 + only_has_polys = key_mods == Qt.ControlModifier + + if self.canvas.op.is_initialized(): + self.canvas_finalize(self.image_paths[0]) + + while True: + for _ in range(step): + if len(self.image_paths_done) != 0: + self.image_paths.insert (0, self.image_paths_done.pop(-1)) + else: + break + if len(self.image_paths) == 0: + break + + ret = self.canvas_initialize(self.image_paths[0], len(self.image_paths_done) != 0 and only_has_polys) + + if ret or len(self.image_paths_done) == 0: + break + + self.update_cached_images() + self.update_preview_bar() + + def process_next_image(self, first_initialization=False): + key_mods = QApplication.keyboardModifiers() + + step = 0 if first_initialization else 5 if key_mods == Qt.ShiftModifier else 1 + only_has_polys = False if first_initialization else key_mods == Qt.ControlModifier + + if self.canvas.op.is_initialized(): + self.canvas_finalize(self.image_paths[0]) + + while True: + for _ in range(step): + if len(self.image_paths) != 0: + self.image_paths_done.append(self.image_paths.pop(0)) + else: + break + if len(self.image_paths) == 0: + break + if self.canvas_initialize(self.image_paths[0], only_has_polys): + break + + self.update_cached_images() + self.update_preview_bar() + + def trash_current_image(self): + self.process_next_image() + + img_path = self.image_paths_done.pop(-1) + img_path = Path(img_path) + self.trash_dirpath.mkdir(parents=True, exist_ok=True) + img_path.rename( self.trash_dirpath / img_path.name ) + + self.update_cached_images() + 
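# --- Editor's note (descriptive comments, not part of the patch) ----------------
# Navigation keeps two lists: image_paths (frames still ahead) and
# image_paths_done (frames already visited); process_next_image() and
# process_prev_image() move the current path between them.  update_cached_images()
# keeps only a small window of decoded images (roughly `count` ahead and behind
# the current position) in self.cached_images, so memory stays bounded while
# update_preview_bar() can still fill the preview strip from cache.
# ---------------------------------------------------------------------------------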
self.update_preview_bar() + + def initialize_ui(self): + + self.canvas = QCanvas() + + image_bar = self.image_bar = ImagePreviewSequenceBar(preview_images_count=9, icon_size=QUIConfig.preview_bar_icon_q_size.width()) + image_bar.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Fixed ) + + + btn_prev_image = QXIconButton(QIconDB.left, QStringDB.btn_prev_image_tip, shortcut='A', click_func=self.process_prev_image) + btn_prev_image.setIconSize(QUIConfig.preview_bar_icon_q_size) + + btn_next_image = QXIconButton(QIconDB.right, QStringDB.btn_next_image_tip, shortcut='D', click_func=self.process_next_image) + btn_next_image.setIconSize(QUIConfig.preview_bar_icon_q_size) + + btn_delete_image = QXIconButton(QIconDB.trashcan, QStringDB.btn_delete_image_tip, shortcut='X', click_func=self.trash_current_image) + btn_delete_image.setIconSize(QUIConfig.preview_bar_icon_q_size) + + pad_image = QWidget() + pad_image.setFixedSize(QUIConfig.preview_bar_icon_q_size) + + preview_image_bar_frame_l = QHBoxLayout() + preview_image_bar_frame_l.setContentsMargins(0,0,0,0) + preview_image_bar_frame_l.addWidget ( pad_image, alignment=Qt.AlignCenter) + preview_image_bar_frame_l.addWidget ( btn_prev_image, alignment=Qt.AlignCenter) + preview_image_bar_frame_l.addWidget ( image_bar) + preview_image_bar_frame_l.addWidget ( btn_next_image, alignment=Qt.AlignCenter) + #preview_image_bar_frame_l.addWidget ( btn_delete_image, alignment=Qt.AlignCenter) + + preview_image_bar_frame = QFrame() + preview_image_bar_frame.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Fixed ) + preview_image_bar_frame.setLayout(preview_image_bar_frame_l) + + preview_image_bar_frame2_l = QHBoxLayout() + preview_image_bar_frame2_l.setContentsMargins(0,0,0,0) + preview_image_bar_frame2_l.addWidget ( btn_delete_image, alignment=Qt.AlignCenter) + + preview_image_bar_frame2 = QFrame() + preview_image_bar_frame2.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Fixed ) + preview_image_bar_frame2.setLayout(preview_image_bar_frame2_l) + + preview_image_bar_l = QHBoxLayout() + preview_image_bar_l.addWidget (preview_image_bar_frame, alignment=Qt.AlignCenter) + preview_image_bar_l.addWidget (preview_image_bar_frame2) + + preview_image_bar = QFrame() + preview_image_bar.setFrameShape(QFrame.StyledPanel) + preview_image_bar.setSizePolicy ( QSizePolicy.Expanding, QSizePolicy.Fixed ) + preview_image_bar.setLayout(preview_image_bar_l) + + label_font = QFont('Courier New') + self.filename_label = QLabel() + self.filename_label.setFont(label_font) + + self.has_ie_polys_count_label = QLabel() + + status_frame_l = QHBoxLayout() + status_frame_l.setContentsMargins(0,0,0,0) + status_frame_l.addWidget ( QLabel(), alignment=Qt.AlignCenter) + status_frame_l.addWidget (self.filename_label, alignment=Qt.AlignCenter) + status_frame_l.addWidget (self.has_ie_polys_count_label, alignment=Qt.AlignCenter) + status_frame = QFrame() + status_frame.setLayout(status_frame_l) + + main_canvas_l = QVBoxLayout() + main_canvas_l.setContentsMargins(0,0,0,0) + main_canvas_l.addWidget (self.canvas) + main_canvas_l.addWidget (status_frame) + main_canvas_l.addWidget (preview_image_bar) + + self.main_canvas_frame = QFrame() + self.main_canvas_frame.setLayout(main_canvas_l) + + self.main_l = QHBoxLayout() + self.main_l.setContentsMargins(0,0,0,0) + self.main_l.addWidget (self.main_canvas_frame) + + self.setLayout(self.main_l) + + geometry = self.cfg_dict.get('geometry', None) + if geometry is not None: + self.restoreGeometry(geometry) + else: + self.move( QPoint(0,0)) + + def 
get_has_ie_polys_count(self): + return self.has_ie_polys_count + + def set_has_ie_polys_count(self, c): + self.has_ie_polys_count = c + self.has_ie_polys_count_label.setText(f"{c} {QStringDB.labeled_tip}") + + def resizeEvent(self, ev): + if self.loading_frame is not None: + self.loading_frame.resize( ev.size() ) + if self.help_frame is not None: + self.help_frame.resize( ev.size() ) + +def start(input_dirpath): + """ + returns exit_code + """ + io.log_info("Running XSeg editor.") + + if PackedFaceset.path_contains(input_dirpath): + io.log_info (f'\n{input_dirpath} contains packed faceset! Unpack it first.\n') + return 1 + + root_path = Path(__file__).parent + cfg_root_path = Path(tempfile.gettempdir()) + + QApplication.setAttribute(Qt.AA_EnableHighDpiScaling, True) + QApplication.setAttribute(Qt.AA_UseHighDpiPixmaps, True) + + app = QApplication([]) + app.setApplicationName("XSegEditor") + app.setStyle('Fusion') + + QFontDatabase.addApplicationFont( str(root_path / 'gfx' / 'fonts' / 'NotoSans-Medium.ttf') ) + + app.setFont( QFont('NotoSans')) + + QUIConfig.initialize() + QStringDB.initialize() + + QIconDB.initialize( root_path / 'gfx' / 'icons' ) + QCursorDB.initialize( root_path / 'gfx' / 'cursors' ) + QImageDB.initialize( root_path / 'gfx' / 'images' ) + + app.setWindowIcon(QIconDB.app_icon) + app.setPalette( QDarkPalette() ) + + win = MainWindow( input_dirpath=input_dirpath, cfg_root_path=cfg_root_path) + + win.show() + win.raise_() + + app.exec_() + return 0 diff --git a/XSegEditor/gfx/cursors/cross_blue.png b/XSegEditor/gfx/cursors/cross_blue.png new file mode 100644 index 0000000..8915219 Binary files /dev/null and b/XSegEditor/gfx/cursors/cross_blue.png differ diff --git a/XSegEditor/gfx/cursors/cross_green.png b/XSegEditor/gfx/cursors/cross_green.png new file mode 100644 index 0000000..3ce16f0 Binary files /dev/null and b/XSegEditor/gfx/cursors/cross_green.png differ diff --git a/XSegEditor/gfx/cursors/cross_red.png b/XSegEditor/gfx/cursors/cross_red.png new file mode 100644 index 0000000..bb851ac Binary files /dev/null and b/XSegEditor/gfx/cursors/cross_red.png differ diff --git a/XSegEditor/gfx/fonts/NotoSans-Medium.ttf b/XSegEditor/gfx/fonts/NotoSans-Medium.ttf new file mode 100644 index 0000000..25050f7 Binary files /dev/null and b/XSegEditor/gfx/fonts/NotoSans-Medium.ttf differ diff --git a/XSegEditor/gfx/icons/app_icon.png b/XSegEditor/gfx/icons/app_icon.png new file mode 100644 index 0000000..16bc03e Binary files /dev/null and b/XSegEditor/gfx/icons/app_icon.png differ diff --git a/XSegEditor/gfx/icons/delete_poly.png b/XSegEditor/gfx/icons/delete_poly.png new file mode 100644 index 0000000..afd57d1 Binary files /dev/null and b/XSegEditor/gfx/icons/delete_poly.png differ diff --git a/XSegEditor/gfx/icons/down.png b/XSegEditor/gfx/icons/down.png new file mode 100644 index 0000000..873b719 Binary files /dev/null and b/XSegEditor/gfx/icons/down.png differ diff --git a/XSegEditor/gfx/icons/left.png b/XSegEditor/gfx/icons/left.png new file mode 100644 index 0000000..2118be6 Binary files /dev/null and b/XSegEditor/gfx/icons/left.png differ diff --git a/XSegEditor/gfx/icons/poly_color.psd b/XSegEditor/gfx/icons/poly_color.psd new file mode 100644 index 0000000..9a94957 Binary files /dev/null and b/XSegEditor/gfx/icons/poly_color.psd differ diff --git a/XSegEditor/gfx/icons/poly_color_blue.png b/XSegEditor/gfx/icons/poly_color_blue.png new file mode 100644 index 0000000..80b5222 Binary files /dev/null and b/XSegEditor/gfx/icons/poly_color_blue.png differ diff --git 
a/XSegEditor/gfx/icons/poly_color_green.png b/XSegEditor/gfx/icons/poly_color_green.png new file mode 100644 index 0000000..2db1fbb Binary files /dev/null and b/XSegEditor/gfx/icons/poly_color_green.png differ diff --git a/XSegEditor/gfx/icons/poly_color_red.png b/XSegEditor/gfx/icons/poly_color_red.png new file mode 100644 index 0000000..d04efff Binary files /dev/null and b/XSegEditor/gfx/icons/poly_color_red.png differ diff --git a/XSegEditor/gfx/icons/poly_type_exclude.png b/XSegEditor/gfx/icons/poly_type_exclude.png new file mode 100644 index 0000000..8e36bc3 Binary files /dev/null and b/XSegEditor/gfx/icons/poly_type_exclude.png differ diff --git a/XSegEditor/gfx/icons/poly_type_include.png b/XSegEditor/gfx/icons/poly_type_include.png new file mode 100644 index 0000000..5f16c15 Binary files /dev/null and b/XSegEditor/gfx/icons/poly_type_include.png differ diff --git a/XSegEditor/gfx/icons/poly_type_source.psd b/XSegEditor/gfx/icons/poly_type_source.psd new file mode 100644 index 0000000..50943d0 Binary files /dev/null and b/XSegEditor/gfx/icons/poly_type_source.psd differ diff --git a/XSegEditor/gfx/icons/pt_edit_mode.png b/XSegEditor/gfx/icons/pt_edit_mode.png new file mode 100644 index 0000000..d385fc2 Binary files /dev/null and b/XSegEditor/gfx/icons/pt_edit_mode.png differ diff --git a/XSegEditor/gfx/icons/pt_edit_mode_source.psd b/XSegEditor/gfx/icons/pt_edit_mode_source.psd new file mode 100644 index 0000000..f73e310 Binary files /dev/null and b/XSegEditor/gfx/icons/pt_edit_mode_source.psd differ diff --git a/XSegEditor/gfx/icons/redo_pt.png b/XSegEditor/gfx/icons/redo_pt.png new file mode 100644 index 0000000..aa73329 Binary files /dev/null and b/XSegEditor/gfx/icons/redo_pt.png differ diff --git a/XSegEditor/gfx/icons/redo_pt_source.psd b/XSegEditor/gfx/icons/redo_pt_source.psd new file mode 100644 index 0000000..2771f77 Binary files /dev/null and b/XSegEditor/gfx/icons/redo_pt_source.psd differ diff --git a/XSegEditor/gfx/icons/right.png b/XSegEditor/gfx/icons/right.png new file mode 100644 index 0000000..b4ef220 Binary files /dev/null and b/XSegEditor/gfx/icons/right.png differ diff --git a/XSegEditor/gfx/icons/trashcan.png b/XSegEditor/gfx/icons/trashcan.png new file mode 100644 index 0000000..a31285b Binary files /dev/null and b/XSegEditor/gfx/icons/trashcan.png differ diff --git a/XSegEditor/gfx/icons/undo_pt.png b/XSegEditor/gfx/icons/undo_pt.png new file mode 100644 index 0000000..7a4464c Binary files /dev/null and b/XSegEditor/gfx/icons/undo_pt.png differ diff --git a/XSegEditor/gfx/icons/undo_pt_source.psd b/XSegEditor/gfx/icons/undo_pt_source.psd new file mode 100644 index 0000000..98b9d1a Binary files /dev/null and b/XSegEditor/gfx/icons/undo_pt_source.psd differ diff --git a/XSegEditor/gfx/icons/up.png b/XSegEditor/gfx/icons/up.png new file mode 100644 index 0000000..f3368b6 Binary files /dev/null and b/XSegEditor/gfx/icons/up.png differ diff --git a/XSegEditor/gfx/icons/view_baked.png b/XSegEditor/gfx/icons/view_baked.png new file mode 100644 index 0000000..3e32142 Binary files /dev/null and b/XSegEditor/gfx/icons/view_baked.png differ diff --git a/XSegEditor/gfx/icons/view_lock_center.png b/XSegEditor/gfx/icons/view_lock_center.png new file mode 100644 index 0000000..2a10408 Binary files /dev/null and b/XSegEditor/gfx/icons/view_lock_center.png differ diff --git a/XSegEditor/gfx/icons/view_xseg.png b/XSegEditor/gfx/icons/view_xseg.png new file mode 100644 index 0000000..7328d2c Binary files /dev/null and b/XSegEditor/gfx/icons/view_xseg.png differ diff --git 
a/XSegEditor/gfx/icons/view_xseg_overlay.png b/XSegEditor/gfx/icons/view_xseg_overlay.png new file mode 100644 index 0000000..d188285 Binary files /dev/null and b/XSegEditor/gfx/icons/view_xseg_overlay.png differ diff --git a/XSegEditor/gfx/images/intro.png b/XSegEditor/gfx/images/intro.png new file mode 100644 index 0000000..7f4d43f Binary files /dev/null and b/XSegEditor/gfx/images/intro.png differ diff --git a/XSegEditor/gfx/images/intro_source.psd b/XSegEditor/gfx/images/intro_source.psd new file mode 100644 index 0000000..bb1cc90 Binary files /dev/null and b/XSegEditor/gfx/images/intro_source.psd differ diff --git a/_config.yml b/_config.yml index c419263..9751715 100644 --- a/_config.yml +++ b/_config.yml @@ -1 +1,9 @@ -theme: jekyll-theme-cayman \ No newline at end of file +theme: jekyll-theme-cayman +plugins: + - jekyll-relative-links +relative_links: + enabled: true + collections: true + +include: + - README.md \ No newline at end of file diff --git a/core/cv2ex.py b/core/cv2ex.py index 17c8095..aa5d73c 100644 --- a/core/cv2ex.py +++ b/core/cv2ex.py @@ -2,6 +2,7 @@ import cv2 import numpy as np from pathlib import Path from core.interact import interact as io +from core import imagelib import traceback def cv2_imread(filename, flags=cv2.IMREAD_UNCHANGED, loader_func=None, verbose=True): @@ -29,3 +30,11 @@ def cv2_imwrite(filename, img, *args): stream.write( buf ) except: pass + +def cv2_resize(x, *args, **kwargs): + h,w,c = x.shape + x = cv2.resize(x, *args, **kwargs) + + x = imagelib.normalize_channels(x, c) + return x + \ No newline at end of file diff --git a/core/imagelib/IEPolys.py b/core/imagelib/IEPolys.py deleted file mode 100644 index aee7333..0000000 --- a/core/imagelib/IEPolys.py +++ /dev/null @@ -1,109 +0,0 @@ -import numpy as np -import cv2 - -class IEPolysPoints: - def __init__(self, IEPolys_parent, type): - self.parent = IEPolys_parent - self.type = type - self.points = np.empty( (0,2), dtype=np.int32 ) - self.n_max = self.n = 0 - - def add(self,x,y): - self.points = np.append(self.points[0:self.n], [ (x,y) ], axis=0) - self.n_max = self.n = self.n + 1 - self.parent.dirty = True - - def n_dec(self): - self.n = max(0, self.n-1) - self.parent.dirty = True - return self.n - - def n_inc(self): - self.n = min(len(self.points), self.n+1) - self.parent.dirty = True - return self.n - - def n_clip(self): - self.points = self.points[0:self.n] - self.n_max = self.n - - def cur_point(self): - return self.points[self.n-1] - - def points_to_n(self): - return self.points[0:self.n] - - def set_points(self, points): - self.points = np.array(points) - self.n_max = self.n = len(points) - self.parent.dirty = True - -class IEPolys: - def __init__(self): - self.list = [] - self.n_max = self.n = 0 - self.dirty = True - - def add(self, type): - self.list = self.list[0:self.n] - l = IEPolysPoints(self, type) - self.list.append ( l ) - self.n_max = self.n = self.n + 1 - self.dirty = True - return l - - def n_dec(self): - self.n = max(0, self.n-1) - self.dirty = True - return self.n - - def n_inc(self): - self.n = min(len(self.list), self.n+1) - self.dirty = True - return self.n - - def n_list(self): - return self.list[self.n-1] - - def n_clip(self): - self.list = self.list[0:self.n] - self.n_max = self.n - if self.n > 0: - self.list[-1].n_clip() - - def __iter__(self): - for n in range(self.n): - yield self.list[n] - - def switch_dirty(self): - d = self.dirty - self.dirty = False - return d - - def overlay_mask(self, mask): - h,w,c = mask.shape - white = (1,)*c - black = (0,)*c - for n in 
range(self.n): - poly = self.list[n] - if poly.n > 0: - cv2.fillPoly(mask, [poly.points_to_n()], white if poly.type == 1 else black ) - - def get_total_points(self): - return sum([self.list[n].n for n in range(self.n)]) - - def dump(self): - result = [] - for n in range(self.n): - l = self.list[n] - result += [ (l.type, l.points_to_n().tolist() ) ] - return result - - @staticmethod - def load(ie_polys=None): - obj = IEPolys() - if ie_polys is not None and isinstance(ie_polys, list): - for (type, points) in ie_polys: - obj.add(type) - obj.n_list().set_points(points) - return obj \ No newline at end of file diff --git a/core/imagelib/SegIEPolys.py b/core/imagelib/SegIEPolys.py new file mode 100644 index 0000000..1a4c3d2 --- /dev/null +++ b/core/imagelib/SegIEPolys.py @@ -0,0 +1,158 @@ +import numpy as np +import cv2 +from enum import IntEnum + + +class SegIEPolyType(IntEnum): + EXCLUDE = 0 + INCLUDE = 1 + + + +class SegIEPoly(): + def __init__(self, type=None, pts=None, **kwargs): + self.type = type + + if pts is None: + pts = np.empty( (0,2), dtype=np.float32 ) + else: + pts = np.float32(pts) + self.pts = pts + self.n_max = self.n = len(pts) + + def dump(self): + return {'type': int(self.type), + 'pts' : self.get_pts(), + } + + def identical(self, b): + if self.n != b.n: + return False + return (self.pts[0:self.n] == b.pts[0:b.n]).all() + + def get_type(self): + return self.type + + def add_pt(self, x, y): + self.pts = np.append(self.pts[0:self.n], [ ( float(x), float(y) ) ], axis=0).astype(np.float32) + self.n_max = self.n = self.n + 1 + + def undo(self): + self.n = max(0, self.n-1) + return self.n + + def redo(self): + self.n = min(len(self.pts), self.n+1) + return self.n + + def redo_clip(self): + self.pts = self.pts[0:self.n] + self.n_max = self.n + + def insert_pt(self, n, pt): + if n < 0 or n > self.n: + raise ValueError("insert_pt out of range") + self.pts = np.concatenate( (self.pts[0:n], pt[None,...].astype(np.float32), self.pts[n:]), axis=0) + self.n_max = self.n = self.n+1 + + def remove_pt(self, n): + if n < 0 or n >= self.n: + raise ValueError("remove_pt out of range") + self.pts = np.concatenate( (self.pts[0:n], self.pts[n+1:]), axis=0) + self.n_max = self.n = self.n-1 + + def get_last_point(self): + return self.pts[self.n-1].copy() + + def get_pts(self): + return self.pts[0:self.n].copy() + + def get_pts_count(self): + return self.n + + def set_point(self, id, pt): + self.pts[id] = pt + + def set_points(self, pts): + self.pts = np.array(pts) + self.n_max = self.n = len(pts) + + def mult_points(self, val): + self.pts *= val + + + +class SegIEPolys(): + def __init__(self): + self.polys = [] + + def identical(self, b): + polys_len = len(self.polys) + o_polys_len = len(b.polys) + if polys_len != o_polys_len: + return False + + return all ([ a_poly.identical(b_poly) for a_poly, b_poly in zip(self.polys, b.polys) ]) + + def add_poly(self, ie_poly_type): + poly = SegIEPoly(ie_poly_type) + self.polys.append (poly) + return poly + + def remove_poly(self, poly): + if poly in self.polys: + self.polys.remove(poly) + + def has_polys(self): + return len(self.polys) != 0 + + def get_poly(self, id): + return self.polys[id] + + def get_polys(self): + return self.polys + + def get_pts_count(self): + return sum([poly.get_pts_count() for poly in self.polys]) + + def sort(self): + poly_by_type = { SegIEPolyType.EXCLUDE : [], SegIEPolyType.INCLUDE : [] } + + for poly in self.polys: + poly_by_type[poly.type].append(poly) + + self.polys = poly_by_type[SegIEPolyType.INCLUDE] + 
poly_by_type[SegIEPolyType.EXCLUDE] + + def __iter__(self): + for poly in self.polys: + yield poly + + def overlay_mask(self, mask): + h,w,c = mask.shape + white = (1,)*c + black = (0,)*c + for poly in self.polys: + pts = poly.get_pts().astype(np.int32) + if len(pts) != 0: + cv2.fillPoly(mask, [pts], white if poly.type == SegIEPolyType.INCLUDE else black ) + + def dump(self): + return {'polys' : [ poly.dump() for poly in self.polys ] } + + def mult_points(self, val): + for poly in self.polys: + poly.mult_points(val) + + @staticmethod + def load(data=None): + ie_polys = SegIEPolys() + if data is not None: + if isinstance(data, list): + # Backward comp + ie_polys.polys = [ SegIEPoly(type=type, pts=pts) for (type, pts) in data ] + elif isinstance(data, dict): + ie_polys.polys = [ SegIEPoly(**poly_cfg) for poly_cfg in data['polys'] ] + + ie_polys.sort() + + return ie_polys \ No newline at end of file diff --git a/core/imagelib/__init__.py b/core/imagelib/__init__.py index fec43f2..11234a5 100644 --- a/core/imagelib/__init__.py +++ b/core/imagelib/__init__.py @@ -1,4 +1,5 @@ from .estimate_sharpness import estimate_sharpness + from .equalize_and_stack_square import equalize_and_stack_square from .text import get_text_image, get_draw_text_lines @@ -11,16 +12,21 @@ from .warp import gen_warp_params, warp_by_params from .reduce_colors import reduce_colors -from .color_transfer import color_transfer, color_transfer_mix, color_transfer_sot, color_transfer_mkl, color_transfer_idt, color_hist_match, reinhard_color_transfer, linear_color_transfer, seamless_clone +from .color_transfer import color_transfer, color_transfer_mix, color_transfer_sot, color_transfer_mkl, color_transfer_idt, color_hist_match, reinhard_color_transfer, linear_color_transfer -from .common import normalize_channels, cut_odd_image, overlay_alpha_image +from .common import random_crop, normalize_channels, cut_odd_image, overlay_alpha_image -from .IEPolys import IEPolys +from .SegIEPolys import * from .blursharpen import LinearMotionBlur, blursharpen from .filters import apply_random_rgb_levels, \ + apply_random_overlay_triangle, \ apply_random_hsv_shift, \ + apply_random_sharpen, \ apply_random_motion_blur, \ apply_random_gaussian_blur, \ - apply_random_bilinear_resize + apply_random_nearest_resize, \ + apply_random_bilinear_resize, \ + apply_random_jpeg_compress, \ + apply_random_relight diff --git a/core/imagelib/color_transfer.py b/core/imagelib/color_transfer.py index 22f6876..d269de2 100644 --- a/core/imagelib/color_transfer.py +++ b/core/imagelib/color_transfer.py @@ -1,10 +1,9 @@ import cv2 +import numexpr as ne import numpy as np +import scipy as sp from numpy import linalg as npla -import scipy as sp -import scipy.sparse -from scipy.sparse.linalg import spsolve def color_transfer_sot(src,trg, steps=10, batch_size=5, reg_sigmaXY=16.0, reg_sigmaV=5.0): """ @@ -35,8 +34,9 @@ def color_transfer_sot(src,trg, steps=10, batch_size=5, reg_sigmaXY=16.0, reg_si h,w,c = src.shape new_src = src.copy() + advect = np.empty ( (h*w,c), dtype=src_dtype ) for step in range (steps): - advect = np.zeros ( (h*w,c), dtype=src_dtype ) + advect.fill(0) for batch in range (batch_size): dir = np.random.normal(size=c).astype(src_dtype) dir /= npla.norm(dir) @@ -91,6 +91,8 @@ def color_transfer_mkl(x0, x1): return np.clip ( result.reshape ( (h,w,c) ).astype(x0.dtype), 0, 1) def color_transfer_idt(i0, i1, bins=256, n_rot=20): + import scipy.stats + relaxation = 1 / n_rot h,w,c = i0.shape h1,w1,c1 = i1.shape @@ -133,135 +135,57 @@ def 
color_transfer_idt(i0, i1, bins=256, n_rot=20): return np.clip ( d0.T.reshape ( (h,w,c) ).astype(i0.dtype) , 0, 1) -def laplacian_matrix(n, m): - mat_D = scipy.sparse.lil_matrix((m, m)) - mat_D.setdiag(-1, -1) - mat_D.setdiag(4) - mat_D.setdiag(-1, 1) - mat_A = scipy.sparse.block_diag([mat_D] * n).tolil() - mat_A.setdiag(-1, 1*m) - mat_A.setdiag(-1, -1*m) - return mat_A +def reinhard_color_transfer(target : np.ndarray, source : np.ndarray, target_mask : np.ndarray = None, source_mask : np.ndarray = None, mask_cutoff=0.5) -> np.ndarray: + """ + Transfer color using rct method. -def seamless_clone(source, target, mask): - h, w,c = target.shape - result = [] + target np.ndarray H W 3C (BGR) np.float32 + source np.ndarray H W 3C (BGR) np.float32 - mat_A = laplacian_matrix(h, w) - laplacian = mat_A.tocsc() + target_mask(None) np.ndarray H W 1C np.float32 + source_mask(None) np.ndarray H W 1C np.float32 + + mask_cutoff(0.5) float - mask[0,:] = 1 - mask[-1,:] = 1 - mask[:,0] = 1 - mask[:,-1] = 1 - q = np.argwhere(mask==0) + masks are used to limit the space where color statistics will be computed to adjust the target - k = q[:,1]+q[:,0]*w - mat_A[k, k] = 1 - mat_A[k, k + 1] = 0 - mat_A[k, k - 1] = 0 - mat_A[k, k + w] = 0 - mat_A[k, k - w] = 0 + reference: Color Transfer between Images https://www.cs.tau.ac.il/~turkel/imagepapers/ColorTransfer.pdf + """ + source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB) + target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB) - mat_A = mat_A.tocsc() - mask_flat = mask.flatten() - for channel in range(c): + source_input = source + if source_mask is not None: + source_input = source_input.copy() + source_input[source_mask[...,0] < mask_cutoff] = [0,0,0] + + target_input = target + if target_mask is not None: + target_input = target_input.copy() + target_input[target_mask[...,0] < mask_cutoff] = [0,0,0] - source_flat = source[:, :, channel].flatten() - target_flat = target[:, :, channel].flatten() + target_l_mean, target_l_std, target_a_mean, target_a_std, target_b_mean, target_b_std, \ + = target_input[...,0].mean(), target_input[...,0].std(), target_input[...,1].mean(), target_input[...,1].std(), target_input[...,2].mean(), target_input[...,2].std() + + source_l_mean, source_l_std, source_a_mean, source_a_std, source_b_mean, source_b_std, \ + = source_input[...,0].mean(), source_input[...,0].std(), source_input[...,1].mean(), source_input[...,1].std(), source_input[...,2].mean(), source_input[...,2].std() + + # not as in the paper: scale by the standard deviations using reciprocal of paper proposed factor + target_l = target[...,0] + target_l = ne.evaluate('(target_l - target_l_mean) * source_l_std / target_l_std + source_l_mean') - mat_b = laplacian.dot(source_flat)*0.75 - mat_b[mask_flat==0] = target_flat[mask_flat==0] + target_a = target[...,1] + target_a = ne.evaluate('(target_a - target_a_mean) * source_a_std / target_a_std + source_a_mean') + + target_b = target[...,2] + target_b = ne.evaluate('(target_b - target_b_mean) * source_b_std / target_b_std + source_b_mean') - x = spsolve(mat_A, mat_b).reshape((h, w)) - result.append (x) + np.clip(target_l, 0, 100, out=target_l) + np.clip(target_a, -127, 127, out=target_a) + np.clip(target_b, -127, 127, out=target_b) + return cv2.cvtColor(np.stack([target_l,target_a,target_b], -1), cv2.COLOR_LAB2BGR) - return np.clip( np.dstack(result), 0, 1 ) - -def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, source_mask=None, target_mask=None): - """ - Transfers the color distribution from the source to the 
target - image using the mean and standard deviations of the L*a*b* - color space. - - This implementation is (loosely) based on to the "Color Transfer - between Images" paper by Reinhard et al., 2001. - - Parameters: - ------- - source: NumPy array - OpenCV image in BGR color space (the source image) - target: NumPy array - OpenCV image in BGR color space (the target image) - clip: Should components of L*a*b* image be scaled by np.clip before - converting back to BGR color space? - If False then components will be min-max scaled appropriately. - Clipping will keep target image brightness truer to the input. - Scaling will adjust image brightness to avoid washed out portions - in the resulting color transfer that can be caused by clipping. - preserve_paper: Should color transfer strictly follow methodology - layed out in original paper? The method does not always produce - aesthetically pleasing results. - If False then L*a*b* components will scaled using the reciprocal of - the scaling factor proposed in the paper. This method seems to produce - more consistently aesthetically pleasing results - - Returns: - ------- - transfer: NumPy array - OpenCV image (w, h, 3) NumPy array (uint8) - """ - - - # convert the images from the RGB to L*ab* color space, being - # sure to utilizing the floating point data type (note: OpenCV - # expects floats to be 32-bit, so use that instead of 64-bit) - source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32) - target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32) - - # compute color statistics for the source and target images - src_input = source if source_mask is None else source*source_mask - tgt_input = target if target_mask is None else target*target_mask - (lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = lab_image_stats(src_input) - (lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = lab_image_stats(tgt_input) - - # subtract the means from the target image - (l, a, b) = cv2.split(target) - l -= lMeanTar - a -= aMeanTar - b -= bMeanTar - - if preserve_paper: - # scale by the standard deviations using paper proposed factor - l = (lStdTar / lStdSrc) * l - a = (aStdTar / aStdSrc) * a - b = (bStdTar / bStdSrc) * b - else: - # scale by the standard deviations using reciprocal of paper proposed factor - l = (lStdSrc / lStdTar) * l - a = (aStdSrc / aStdTar) * a - b = (bStdSrc / bStdTar) * b - - # add in the source mean - l += lMeanSrc - a += aMeanSrc - b += bMeanSrc - - # clip/scale the pixel intensities to [0, 255] if they fall - # outside this range - l = _scale_array(l, clip=clip) - a = _scale_array(a, clip=clip) - b = _scale_array(b, clip=clip) - - # merge the channels together and convert back to the RGB color - # space, being sure to utilize the 8-bit unsigned integer data - # type - transfer = cv2.merge([l, a, b]) - transfer = cv2.cvtColor(transfer.astype(np.uint8), cv2.COLOR_LAB2BGR) - - # return the color transferred image - return transfer def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5): ''' @@ -399,9 +323,7 @@ def color_transfer(ct_mode, img_src, img_trg): if ct_mode == 'lct': out = linear_color_transfer (img_src, img_trg) elif ct_mode == 'rct': - out = reinhard_color_transfer ( np.clip( img_src*255, 0, 255 ).astype(np.uint8), - np.clip( img_trg*255, 0, 255 ).astype(np.uint8) ) - out = np.clip( out.astype(np.float32) / 255.0, 0.0, 1.0) + out = reinhard_color_transfer(img_src, img_trg) elif ct_mode == 'mkl': out = color_transfer_mkl (img_src, img_trg) elif ct_mode == 'idt': @@ 
-411,4 +333,4 @@ def color_transfer(ct_mode, img_src, img_trg): out = np.clip( out, 0.0, 1.0) else: raise ValueError(f"unknown ct_mode {ct_mode}") - return out \ No newline at end of file + return out diff --git a/core/imagelib/common.py b/core/imagelib/common.py index 6566819..4219d7d 100644 --- a/core/imagelib/common.py +++ b/core/imagelib/common.py @@ -1,5 +1,16 @@ import numpy as np +def random_crop(img, w, h): + height, width = img.shape[:2] + + h_rnd = height - h + w_rnd = width - w + + y = np.random.randint(0, h_rnd) if h_rnd > 0 else 0 + x = np.random.randint(0, w_rnd) if w_rnd > 0 else 0 + + return img[y:y+height, x:x+width] + def normalize_channels(img, target_channels): img_shape_len = len(img.shape) if img_shape_len == 2: diff --git a/core/imagelib/estimate_sharpness.py b/core/imagelib/estimate_sharpness.py index 01ef0b7..e4b3e2d 100644 --- a/core/imagelib/estimate_sharpness.py +++ b/core/imagelib/estimate_sharpness.py @@ -31,9 +31,7 @@ goods or services; loss of use, data, or profits; or business interruption) howe import numpy as np import cv2 from math import atan2, pi -from scipy.ndimage import convolve -from skimage.filters.edges import HSOBEL_WEIGHTS -from skimage.feature import canny + def sobel(image): # type: (numpy.ndarray) -> numpy.ndarray @@ -42,10 +40,11 @@ def sobel(image): Inspired by the [Octave implementation](https://sourceforge.net/p/octave/image/ci/default/tree/inst/edge.m#l196). """ - + from skimage.filters.edges import HSOBEL_WEIGHTS h1 = np.array(HSOBEL_WEIGHTS) h1 /= np.sum(abs(h1)) # normalize h1 - + + from scipy.ndimage import convolve strength2 = np.square(convolve(image, h1.T)) # Note: https://sourceforge.net/p/octave/image/ci/default/tree/inst/edge.m#l59 @@ -103,6 +102,7 @@ def compute(image): # edge detection using canny and sobel canny edge detection is done to # classify the blocks as edge or non-edge blocks and sobel edge # detection is done for the purpose of edge width measurement. 
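# --- Editor's note (not part of the patch) ---------------------------------------
# This hunk, like the sobel() hunk above, moves the scipy / scikit-image imports
# from module level into the functions that use them, presumably so that importing
# core.imagelib no longer pays their start-up cost unless sharpness estimation is
# actually run.  The deferred `from skimage.feature import canny` that follows is
# therefore executed only when compute() is called.
# ---------------------------------------------------------------------------------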
+ from skimage.feature import canny canny_edges = canny(image) sobel_edges = sobel(image) @@ -269,9 +269,10 @@ def get_block_contrast(block): def estimate_sharpness(image): - height, width = image.shape[:2] - if image.ndim == 3: - image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) - + if image.shape[2] > 1: + image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + else: + image = image[...,0] + return compute(image) diff --git a/core/imagelib/filters.py b/core/imagelib/filters.py index 2a59069..6b69576 100644 --- a/core/imagelib/filters.py +++ b/core/imagelib/filters.py @@ -1,47 +1,65 @@ import numpy as np -from .blursharpen import LinearMotionBlur +from .blursharpen import LinearMotionBlur, blursharpen import cv2 def apply_random_rgb_levels(img, mask=None, rnd_state=None): if rnd_state is None: rnd_state = np.random np_rnd = rnd_state.rand - + inBlack = np.array([np_rnd()*0.25 , np_rnd()*0.25 , np_rnd()*0.25], dtype=np.float32) inWhite = np.array([1.0-np_rnd()*0.25, 1.0-np_rnd()*0.25, 1.0-np_rnd()*0.25], dtype=np.float32) inGamma = np.array([0.5+np_rnd(), 0.5+np_rnd(), 0.5+np_rnd()], dtype=np.float32) - + outBlack = np.array([np_rnd()*0.25 , np_rnd()*0.25 , np_rnd()*0.25], dtype=np.float32) outWhite = np.array([1.0-np_rnd()*0.25, 1.0-np_rnd()*0.25, 1.0-np_rnd()*0.25], dtype=np.float32) result = np.clip( (img - inBlack) / (inWhite - inBlack), 0, 1 ) result = ( result ** (1/inGamma) ) * (outWhite - outBlack) + outBlack result = np.clip(result, 0, 1) - + if mask is not None: result = img*(1-mask) + result*mask - + return result - + def apply_random_hsv_shift(img, mask=None, rnd_state=None): if rnd_state is None: rnd_state = np.random - + h, s, v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV)) h = ( h + rnd_state.randint(360) ) % 360 s = np.clip ( s + rnd_state.random()-0.5, 0, 1 ) - v = np.clip ( v + rnd_state.random()/2-0.25, 0, 1 ) - + v = np.clip ( v + rnd_state.random()-0.5, 0, 1 ) + result = np.clip( cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR) , 0, 1 ) if mask is not None: result = img*(1-mask) + result*mask - + return result - + +def apply_random_sharpen( img, chance, kernel_max_size, mask=None, rnd_state=None ): + if rnd_state is None: + rnd_state = np.random + + sharp_rnd_kernel = rnd_state.randint(kernel_max_size)+1 + + result = img + if rnd_state.randint(100) < np.clip(chance, 0, 100): + if rnd_state.randint(2) == 0: + result = blursharpen(result, 1, sharp_rnd_kernel, rnd_state.randint(10) ) + else: + result = blursharpen(result, 2, sharp_rnd_kernel, rnd_state.randint(50) ) + + if mask is not None: + result = img*(1-mask) + result*mask + + return result + def apply_random_motion_blur( img, chance, mb_max_size, mask=None, rnd_state=None ): if rnd_state is None: rnd_state = np.random - + mblur_rnd_kernel = rnd_state.randint(mb_max_size)+1 mblur_rnd_deg = rnd_state.randint(360) @@ -50,38 +68,178 @@ def apply_random_motion_blur( img, chance, mb_max_size, mask=None, rnd_state=Non result = LinearMotionBlur (result, mblur_rnd_kernel, mblur_rnd_deg ) if mask is not None: result = img*(1-mask) + result*mask - + return result - + def apply_random_gaussian_blur( img, chance, kernel_max_size, mask=None, rnd_state=None ): if rnd_state is None: rnd_state = np.random - + result = img if rnd_state.randint(100) < np.clip(chance, 0, 100): gblur_rnd_kernel = rnd_state.randint(kernel_max_size)*2+1 result = cv2.GaussianBlur(result, (gblur_rnd_kernel,)*2 , 0) if mask is not None: result = img*(1-mask) + result*mask - + return result - - -def apply_random_bilinear_resize( img, chance, max_size_per, 
mask=None, rnd_state=None ): + +def apply_random_resize( img, chance, max_size_per, interpolation=cv2.INTER_LINEAR, mask=None, rnd_state=None ): if rnd_state is None: rnd_state = np.random result = img if rnd_state.randint(100) < np.clip(chance, 0, 100): h,w,c = result.shape - + trg = rnd_state.rand() - rw = w - int( trg * int(w*(max_size_per/100.0)) ) - rh = h - int( trg * int(h*(max_size_per/100.0)) ) - - result = cv2.resize (result, (rw,rh), cv2.INTER_LINEAR ) - result = cv2.resize (result, (w,h), cv2.INTER_LINEAR ) + rw = w - int( trg * int(w*(max_size_per/100.0)) ) + rh = h - int( trg * int(h*(max_size_per/100.0)) ) + + result = cv2.resize (result, (rw,rh), interpolation=interpolation ) + result = cv2.resize (result, (w,h), interpolation=interpolation ) if mask is not None: result = img*(1-mask) + result*mask - + + return result + +def apply_random_nearest_resize( img, chance, max_size_per, mask=None, rnd_state=None ): + return apply_random_resize( img, chance, max_size_per, interpolation=cv2.INTER_NEAREST, mask=mask, rnd_state=rnd_state ) + +def apply_random_bilinear_resize( img, chance, max_size_per, mask=None, rnd_state=None ): + return apply_random_resize( img, chance, max_size_per, interpolation=cv2.INTER_LINEAR, mask=mask, rnd_state=rnd_state ) + +def apply_random_jpeg_compress( img, chance, mask=None, rnd_state=None ): + if rnd_state is None: + rnd_state = np.random + + result = img + if rnd_state.randint(100) < np.clip(chance, 0, 100): + h,w,c = result.shape + + quality = rnd_state.randint(10,101) + + ret, result = cv2.imencode('.jpg', np.clip(img*255, 0,255).astype(np.uint8), [int(cv2.IMWRITE_JPEG_QUALITY), quality] ) + if ret == True: + result = cv2.imdecode(result, flags=cv2.IMREAD_UNCHANGED) + result = result.astype(np.float32) / 255.0 + if mask is not None: + result = img*(1-mask) + result*mask + + return result + +def apply_random_overlay_triangle( img, max_alpha, mask=None, rnd_state=None ): + if rnd_state is None: + rnd_state = np.random + + h,w,c = img.shape + pt1 = [rnd_state.randint(w), rnd_state.randint(h) ] + pt2 = [rnd_state.randint(w), rnd_state.randint(h) ] + pt3 = [rnd_state.randint(w), rnd_state.randint(h) ] + + alpha = rnd_state.uniform()*max_alpha + + tri_mask = cv2.fillPoly( np.zeros_like(img), [ np.array([pt1,pt2,pt3], np.int32) ], (alpha,)*c ) + + if rnd_state.randint(2) == 0: + result = np.clip(img+tri_mask, 0, 1) + else: + result = np.clip(img-tri_mask, 0, 1) + + if mask is not None: + result = img*(1-mask) + result*mask + + return result + +def _min_resize(x, m): + if x.shape[0] < x.shape[1]: + s0 = m + s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1])) + else: + s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0])) + s1 = m + new_max = min(s1, s0) + raw_max = min(x.shape[0], x.shape[1]) + return cv2.resize(x, (s1, s0), interpolation=cv2.INTER_LANCZOS4) + +def _d_resize(x, d, fac=1.0): + new_min = min(int(d[1] * fac), int(d[0] * fac)) + raw_min = min(x.shape[0], x.shape[1]) + if new_min < raw_min: + interpolation = cv2.INTER_AREA + else: + interpolation = cv2.INTER_LANCZOS4 + y = cv2.resize(x, (int(d[1] * fac), int(d[0] * fac)), interpolation=interpolation) + return y + +def _get_image_gradient(dist): + cols = cv2.filter2D(dist, cv2.CV_32F, np.array([[-1, 0, +1], [-2, 0, +2], [-1, 0, +1]])) + rows = cv2.filter2D(dist, cv2.CV_32F, np.array([[-1, -2, -1], [0, 0, 0], [+1, +2, +1]])) + return cols, rows + +def _generate_lighting_effects(content): + h512 = content + h256 = cv2.pyrDown(h512) + h128 = cv2.pyrDown(h256) + h64 = cv2.pyrDown(h128) 
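# --- Editor's note (descriptive comments, not part of the patch) ----------------
# _generate_lighting_effects builds a 6-level Gaussian pyramid of the input
# (variables named h512 ... h16), takes Sobel-style column/row gradients at every
# level, then recombines the levels coarse-to-fine with cv2.pyrUp, weighting each
# upsampled level by 4.0.  The normalized result is stacked into a pseudo normal
# map that apply_random_relight dots with a random light direction to fake a
# re-lighting of the face image.
# ---------------------------------------------------------------------------------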
+ h32 = cv2.pyrDown(h64) + h16 = cv2.pyrDown(h32) + c512, r512 = _get_image_gradient(h512) + c256, r256 = _get_image_gradient(h256) + c128, r128 = _get_image_gradient(h128) + c64, r64 = _get_image_gradient(h64) + c32, r32 = _get_image_gradient(h32) + c16, r16 = _get_image_gradient(h16) + c = c16 + c = _d_resize(cv2.pyrUp(c), c32.shape) * 4.0 + c32 + c = _d_resize(cv2.pyrUp(c), c64.shape) * 4.0 + c64 + c = _d_resize(cv2.pyrUp(c), c128.shape) * 4.0 + c128 + c = _d_resize(cv2.pyrUp(c), c256.shape) * 4.0 + c256 + c = _d_resize(cv2.pyrUp(c), c512.shape) * 4.0 + c512 + r = r16 + r = _d_resize(cv2.pyrUp(r), r32.shape) * 4.0 + r32 + r = _d_resize(cv2.pyrUp(r), r64.shape) * 4.0 + r64 + r = _d_resize(cv2.pyrUp(r), r128.shape) * 4.0 + r128 + r = _d_resize(cv2.pyrUp(r), r256.shape) * 4.0 + r256 + r = _d_resize(cv2.pyrUp(r), r512.shape) * 4.0 + r512 + coarse_effect_cols = c + coarse_effect_rows = r + EPS = 1e-10 + + max_effect = np.max((coarse_effect_cols**2 + coarse_effect_rows**2)**0.5, axis=0, keepdims=True, ).max(1, keepdims=True) + coarse_effect_cols = (coarse_effect_cols + EPS) / (max_effect + EPS) + coarse_effect_rows = (coarse_effect_rows + EPS) / (max_effect + EPS) + + return np.stack([ np.zeros_like(coarse_effect_rows), coarse_effect_rows, coarse_effect_cols], axis=-1) + +def apply_random_relight(img, mask=None, rnd_state=None): + if rnd_state is None: + rnd_state = np.random + + def_img = img + + if rnd_state.randint(2) == 0: + light_pos_y = 1.0 if rnd_state.randint(2) == 0 else -1.0 + light_pos_x = rnd_state.uniform()*2-1.0 + else: + light_pos_y = rnd_state.uniform()*2-1.0 + light_pos_x = 1.0 if rnd_state.randint(2) == 0 else -1.0 + + light_source_height = 0.3*rnd_state.uniform()*0.7 + light_intensity = 1.0+rnd_state.uniform() + ambient_intensity = 0.5 + + light_source_location = np.array([[[light_source_height, light_pos_y, light_pos_x ]]], dtype=np.float32) + light_source_direction = light_source_location / np.sqrt(np.sum(np.square(light_source_location))) + + lighting_effect = _generate_lighting_effects(img) + lighting_effect = np.sum(lighting_effect * light_source_direction, axis=-1).clip(0, 1) + lighting_effect = np.mean(lighting_effect, axis=-1, keepdims=True) + + result = def_img * (ambient_intensity + lighting_effect * light_intensity) #light_source_color + result = np.clip(result, 0, 1) + + if mask is not None: + result = def_img*(1-mask) + result*mask + return result \ No newline at end of file diff --git a/core/imagelib/sd/__init__.py b/core/imagelib/sd/__init__.py index 7c50477..1cddc19 100644 --- a/core/imagelib/sd/__init__.py +++ b/core/imagelib/sd/__init__.py @@ -1 +1,2 @@ -from .draw import * \ No newline at end of file +from .draw import circle_faded, random_circle_faded, bezier, random_bezier_split_faded, random_faded +from .calc import * \ No newline at end of file diff --git a/core/imagelib/sd/calc.py b/core/imagelib/sd/calc.py new file mode 100644 index 0000000..2304e66 --- /dev/null +++ b/core/imagelib/sd/calc.py @@ -0,0 +1,25 @@ +import numpy as np +import numpy.linalg as npla + +def dist_to_edges(pts, pt, is_closed=False): + """ + returns array of dist from pt to edge and projection pt to edges + """ + if is_closed: + a = pts + b = np.concatenate( (pts[1:,:], pts[0:1,:]), axis=0 ) + else: + a = pts[:-1,:] + b = pts[1:,:] + + pa = pt-a + ba = b-a + + div = np.einsum('ij,ij->i', ba, ba) + div[div==0]=1 + h = np.clip( np.einsum('ij,ij->i', pa, ba) / div, 0, 1 ) + + x = npla.norm ( pa - ba*h[...,None], axis=1 ) + + return x, a+ba*h[...,None] + diff --git 
a/core/imagelib/sd/draw.py b/core/imagelib/sd/draw.py index 77e9a46..711ad33 100644 --- a/core/imagelib/sd/draw.py +++ b/core/imagelib/sd/draw.py @@ -1,23 +1,36 @@ """ Signed distance drawing functions using numpy. """ +import math import numpy as np from numpy import linalg as npla -def circle_faded( hw, center, fade_dists ): + +def vector2_dot(a,b): + return a[...,0]*b[...,0]+a[...,1]*b[...,1] + +def vector2_dot2(a): + return a[...,0]*a[...,0]+a[...,1]*a[...,1] + +def vector2_cross(a,b): + return a[...,0]*b[...,1]-a[...,1]*b[...,0] + + +def circle_faded( wh, center, fade_dists ): """ returns drawn circle in [h,w,1] output range [0..1.0] float32 - hw = [h,w] resolution - center = [y,x] center of circle + wh = [w,h] resolution + center = [x,y] center of circle fade_dists = [fade_start, fade_end] fade values """ - h,w = hw + w,h = wh pts = np.empty( (h,w,2), dtype=np.float32 ) - pts[...,1] = np.arange(h)[None,:] pts[...,0] = np.arange(w)[:,None] + pts[...,1] = np.arange(h)[None,:] + pts = pts.reshape ( (h*w, -1) ) pts_dists = np.abs ( npla.norm(pts-center, axis=-1) ) @@ -30,15 +43,158 @@ def circle_faded( hw, center, fade_dists ): pts_dists = np.clip( 1-pts_dists, 0, 1) return pts_dists.reshape ( (h,w,1) ).astype(np.float32) + + +def bezier( wh, A, B, C ): + """ + returns drawn bezier in [h,w,1] output range float32, + every pixel contains signed distance to bezier line + + wh [w,h] resolution + A,B,C points [x,y] + """ -def random_circle_faded ( hw, rnd_state=None ): + width,height = wh + + A = np.float32(A) + B = np.float32(B) + C = np.float32(C) + + + pos = np.empty( (height,width,2), dtype=np.float32 ) + pos[...,0] = np.arange(width)[:,None] + pos[...,1] = np.arange(height)[None,:] + + + a = B-A + b = A - 2.0*B + C + c = a * 2.0 + d = A - pos + + b_dot = vector2_dot(b,b) + if b_dot == 0.0: + return np.zeros( (height,width), dtype=np.float32 ) + + kk = 1.0 / b_dot + + kx = kk * vector2_dot(a,b) + ky = kk * (2.0*vector2_dot(a,a)+vector2_dot(d,b))/3.0; + kz = kk * vector2_dot(d,a); + + res = 0.0; + sgn = 0.0; + + p = ky - kx*kx; + + p3 = p*p*p; + q = kx*(2.0*kx*kx - 3.0*ky) + kz; + h = q*q + 4.0*p3; + + hp_sel = h >= 0.0 + + hp_p = h[hp_sel] + hp_p = np.sqrt(hp_p) + + hp_x = ( np.stack( (hp_p,-hp_p), -1) -q[hp_sel,None] ) / 2.0 + hp_uv = np.sign(hp_x) * np.power( np.abs(hp_x), [1.0/3.0, 1.0/3.0] ) + hp_t = np.clip( hp_uv[...,0] + hp_uv[...,1] - kx, 0.0, 1.0 ) + + hp_t = hp_t[...,None] + hp_q = d[hp_sel]+(c+b*hp_t)*hp_t + hp_res = vector2_dot2(hp_q) + hp_sgn = vector2_cross(c+2.0*b*hp_t,hp_q) + + hl_sel = h < 0.0 + + hl_q = q[hl_sel] + hl_p = p[hl_sel] + hl_z = np.sqrt(-hl_p) + hl_v = np.arccos( hl_q / (hl_p*hl_z*2.0)) / 3.0 + + hl_m = np.cos(hl_v) + hl_n = np.sin(hl_v)*1.732050808; + + hl_t = np.clip( np.stack( (hl_m+hl_m,-hl_n-hl_m,hl_n-hl_m), -1)*hl_z[...,None]-kx, 0.0, 1.0 ); + + hl_d = d[hl_sel] + + hl_qx = hl_d+(c+b*hl_t[...,0:1])*hl_t[...,0:1] + + hl_dx = vector2_dot2(hl_qx) + hl_sx = vector2_cross(c+2.0*b*hl_t[...,0:1], hl_qx) + + hl_qy = hl_d+(c+b*hl_t[...,1:2])*hl_t[...,1:2] + hl_dy = vector2_dot2(hl_qy) + hl_sy = vector2_cross(c+2.0*b*hl_t[...,1:2],hl_qy); + + hl_dx_l_dy = hl_dx=hl_dy + + hl_res = np.empty_like(hl_dx) + hl_res[hl_dx_l_dy] = hl_dx[hl_dx_l_dy] + hl_res[hl_dx_ge_dy] = hl_dy[hl_dx_ge_dy] + + hl_sgn = np.empty_like(hl_sx) + hl_sgn[hl_dx_l_dy] = hl_sx[hl_dx_l_dy] + hl_sgn[hl_dx_ge_dy] = hl_sy[hl_dx_ge_dy] + + res = np.empty( (height, width), np.float32 ) + res[hp_sel] = hp_res + res[hl_sel] = hl_res + + sgn = np.empty( (height, width), np.float32 ) + sgn[hp_sel] = 
hp_sgn + sgn[hl_sel] = hl_sgn + + sgn = np.sign(sgn) + res = np.sqrt(res)*sgn + + return res[...,None] + +def random_faded(wh): + """ + apply one of them: + random_circle_faded + random_bezier_split_faded + """ + rnd = np.random.randint(2) + if rnd == 0: + return random_circle_faded(wh) + elif rnd == 1: + return random_bezier_split_faded(wh) + +def random_circle_faded ( wh, rnd_state=None ): if rnd_state is None: rnd_state = np.random - h,w = hw - hw_max = max(h,w) - fade_start = rnd_state.randint(hw_max) - fade_end = fade_start + rnd_state.randint(hw_max- fade_start) + w,h = wh + wh_max = max(w,h) + fade_start = rnd_state.randint(wh_max) + fade_end = fade_start + rnd_state.randint(wh_max- fade_start) - return circle_faded (hw, [ rnd_state.randint(h), rnd_state.randint(w) ], - [fade_start, fade_end] ) \ No newline at end of file + return circle_faded (wh, [ rnd_state.randint(h), rnd_state.randint(w) ], + [fade_start, fade_end] ) + +def random_bezier_split_faded( wh ): + width, height = wh + + degA = np.random.randint(360) + degB = np.random.randint(360) + degC = np.random.randint(360) + + deg_2_rad = math.pi / 180.0 + + center = np.float32([width / 2.0, height / 2.0]) + + radius = max(width, height) + + A = center + radius*np.float32([ math.sin( degA * deg_2_rad), math.cos( degA * deg_2_rad) ] ) + B = center + np.random.randint(radius)*np.float32([ math.sin( degB * deg_2_rad), math.cos( degB * deg_2_rad) ] ) + C = center + radius*np.float32([ math.sin( degC * deg_2_rad), math.cos( degC * deg_2_rad) ] ) + + x = bezier( (width,height), A, B, C ) + + x = x / (1+np.random.randint(radius)) + 0.5 + + x = np.clip(x, 0, 1) + return x diff --git a/core/imagelib/warp.py b/core/imagelib/warp.py index 50c6376..2c429d0 100644 --- a/core/imagelib/warp.py +++ b/core/imagelib/warp.py @@ -1,33 +1,147 @@ import numpy as np +import numpy.linalg as npla import cv2 from core import randomex -def gen_warp_params (w, flip, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_state=None ): +def mls_rigid_deformation(vy, vx, src_pts, dst_pts, alpha=1.0, eps=1e-8): + dst_pts = dst_pts[..., ::-1].astype(np.int16) + src_pts = src_pts[..., ::-1].astype(np.int16) + + src_pts, dst_pts = dst_pts, src_pts + + grow = vx.shape[0] + gcol = vx.shape[1] + ctrls = src_pts.shape[0] + + reshaped_p = src_pts.reshape(ctrls, 2, 1, 1) + reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol))) + + w = 1.0 / (np.sum((reshaped_p - reshaped_v).astype(np.float32) ** 2, axis=1) + eps) ** alpha + w /= np.sum(w, axis=0, keepdims=True) + + pstar = np.zeros((2, grow, gcol), np.float32) + for i in range(ctrls): + pstar += w[i] * reshaped_p[i] + + vpstar = reshaped_v - pstar + + reshaped_mul_right = np.concatenate((vpstar[:,None,...], + np.concatenate((vpstar[1:2,None,...],-vpstar[0:1,None,...]), 0) + ), axis=1).transpose(2, 3, 0, 1) + + reshaped_q = dst_pts.reshape((ctrls, 2, 1, 1)) + + qstar = np.zeros((2, grow, gcol), np.float32) + for i in range(ctrls): + qstar += w[i] * reshaped_q[i] + + temp = np.zeros((grow, gcol, 2), np.float32) + for i in range(ctrls): + phat = reshaped_p[i] - pstar + qhat = reshaped_q[i] - qstar + + temp += np.matmul(qhat.reshape(1, 2, grow, gcol).transpose(2, 3, 0, 1), + + np.matmul( ( w[None, i:i+1,...] 
* + np.concatenate((phat.reshape(1, 2, grow, gcol), + np.concatenate( (phat[None,1:2], -phat[None,0:1]), 1 )), 0) + ).transpose(2, 3, 0, 1), reshaped_mul_right + ) + ).reshape(grow, gcol, 2) + + temp = temp.transpose(2, 0, 1) + + normed_temp = np.linalg.norm(temp, axis=0, keepdims=True) + normed_vpstar = np.linalg.norm(vpstar, axis=0, keepdims=True) + nan_mask = normed_temp[0]==0 + + transformers = np.true_divide(temp, normed_temp, out=np.zeros_like(temp), where= ~nan_mask) * normed_vpstar + qstar + nan_mask_flat = np.flatnonzero(nan_mask) + nan_mask_anti_flat = np.flatnonzero(~nan_mask) + + transformers[0][nan_mask] = np.interp(nan_mask_flat, nan_mask_anti_flat, transformers[0][~nan_mask]) + transformers[1][nan_mask] = np.interp(nan_mask_flat, nan_mask_anti_flat, transformers[1][~nan_mask]) + + return transformers + +def gen_pts(W, H, rnd_state=None): + if rnd_state is None: rnd_state = np.random - + + min_pts, max_pts = 4, 8 + n_pts = rnd_state.randint(min_pts, max_pts) + + min_radius_per = 0.00 + max_radius_per = 0.10 + pts = [] + + for i in range(n_pts): + while True: + x, y = rnd_state.randint(W), rnd_state.randint(H) + rad = min_radius_per + rnd_state.rand()*(max_radius_per-min_radius_per) + + intersect = False + for px,py,prad,_,_ in pts: + + dist = npla.norm([x-px, y-py]) + if dist <= (rad+prad)*2: + intersect = True + break + if intersect: + continue + + angle = rnd_state.rand()*(2*np.pi) + x2 = int(x+np.cos(angle)*W*rad) + y2 = int(y+np.sin(angle)*H*rad) + + break + pts.append( (x,y,rad, x2,y2) ) + + pts1 = np.array( [ [pt[0],pt[1]] for pt in pts ] ) + pts2 = np.array( [ [pt[-2],pt[-1]] for pt in pts ] ) + + return pts1, pts2 + + +def gen_warp_params (w, flip=False, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_state=None, warp_rnd_state=None ): + if rnd_state is None: + rnd_state = np.random + if warp_rnd_state is None: + warp_rnd_state = np.random + rw = None + if w < 64: + rw = w + w = 64 + rotation = rnd_state.uniform( rotation_range[0], rotation_range[1] ) - scale = rnd_state.uniform(1 +scale_range[0], 1 +scale_range[1]) + scale = rnd_state.uniform( 1/(1-scale_range[0]) , 1+scale_range[1] ) tx = rnd_state.uniform( tx_range[0], tx_range[1] ) ty = rnd_state.uniform( ty_range[0], ty_range[1] ) p_flip = flip and rnd_state.randint(10) < 4 - #random warp by grid - cell_size = [ w // (2**i) for i in range(1,4) ] [ rnd_state.randint(3) ] + #random warp V1 + cell_size = [ w // (2**i) for i in range(1,4) ] [ warp_rnd_state.randint(3) ] cell_count = w // cell_size + 1 - grid_points = np.linspace( 0, w, cell_count) mapx = np.broadcast_to(grid_points, (cell_count, cell_count)).copy() mapy = mapx.T - - mapx[1:-1,1:-1] = mapx[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2) )*(cell_size*0.24) - mapy[1:-1,1:-1] = mapy[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2) )*(cell_size*0.24) - + mapx[1:-1,1:-1] = mapx[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2), rnd_state=warp_rnd_state )*(cell_size*0.24) + mapy[1:-1,1:-1] = mapy[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2), rnd_state=warp_rnd_state )*(cell_size*0.24) half_cell_size = cell_size // 2 - mapx = cv2.resize(mapx, (w+cell_size,)*2 )[half_cell_size:-half_cell_size,half_cell_size:-half_cell_size].astype(np.float32) mapy = cv2.resize(mapy, (w+cell_size,)*2 )[half_cell_size:-half_cell_size,half_cell_size:-half_cell_size].astype(np.float32) - + ############## + + # random warp V2 + # pts1, 
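# Illustrative sketch of the "random warp V1" grid deformation used by gen_warp_params;
# this standalone helper is not part of the patch and uses plain np.random noise instead
# of randomex.random_normal. Interior nodes of a coarse identity grid are jittered, the
# node coordinates are upsampled to per-pixel maps, and the image is resampled with
# cv2.remap.
import cv2
import numpy as np

def random_grid_warp(img, cell_size=16, strength=0.24, rnd_state=np.random):
    h, w = img.shape[:2]
    assert h == w, "sketch assumes a square image, as gen_warp_params does"

    cell_count = w // cell_size + 1
    grid = np.linspace(0, w, cell_count, dtype=np.float32)

    mapx = np.broadcast_to(grid, (cell_count, cell_count)).copy()   # identity x-coords
    mapy = mapx.T.copy()                                            # identity y-coords

    # displace interior grid nodes only, so the image border stays fixed
    mapx[1:-1, 1:-1] += rnd_state.randn(cell_count - 2, cell_count - 2) * cell_size * strength
    mapy[1:-1, 1:-1] += rnd_state.randn(cell_count - 2, cell_count - 2) * cell_size * strength

    half = cell_size // 2
    mapx = cv2.resize(mapx, (w + cell_size,) * 2)[half:-half, half:-half]
    mapy = cv2.resize(mapy, (w + cell_size,) * 2)[half:-half, half:-half]
    return cv2.remap(img, mapx, mapy, cv2.INTER_CUBIC)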
pts2 = gen_pts(w, w, rnd_state) + # gridX = np.arange(w, dtype=np.int16) + # gridY = np.arange(w, dtype=np.int16) + # vy, vx = np.meshgrid(gridX, gridY) + # drigid = mls_rigid_deformation(vy, vx, pts1, pts2) + # mapy, mapx = drigid.astype(np.float32) + ################ + #random transform random_transform_mat = cv2.getRotationMatrix2D((w // 2, w // 2), rotation, scale) random_transform_mat[:, 2] += (tx*w, ty*w) @@ -36,16 +150,30 @@ def gen_warp_params (w, flip, rotation_range=[-10,10], scale_range=[-0.5, 0.5], params['mapx'] = mapx params['mapy'] = mapy params['rmat'] = random_transform_mat + u_mat = random_transform_mat.copy() + u_mat[:,2] /= w + params['umat'] = u_mat params['w'] = w + params['rw'] = rw params['flip'] = p_flip return params def warp_by_params (params, img, can_warp, can_transform, can_flip, border_replicate, cv2_inter=cv2.INTER_CUBIC): + rw = params['rw'] + + if (can_warp or can_transform) and rw is not None: + img = cv2.resize(img, (64,64), interpolation=cv2_inter) + if can_warp: img = cv2.remap(img, params['mapx'], params['mapy'], cv2_inter ) if can_transform: img = cv2.warpAffine( img, params['rmat'], (params['w'], params['w']), borderMode=(cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT), flags=cv2_inter ) + + + if (can_warp or can_transform) and rw is not None: + img = cv2.resize(img, (rw,rw), interpolation=cv2_inter) + if len(img.shape) == 2: img = img[...,None] if can_flip and params['flip']: diff --git a/core/interact/interact.py b/core/interact/interact.py index 5a1577e..1a8214a 100644 --- a/core/interact/interact.py +++ b/core/interact/interact.py @@ -7,6 +7,7 @@ import types import colorama import cv2 +import numpy as np from tqdm import tqdm from core import stdex @@ -197,7 +198,7 @@ class InteractBase(object): def add_key_event(self, wnd_name, ord_key, ctrl_pressed, alt_pressed, shift_pressed): if wnd_name not in self.key_events: self.key_events[wnd_name] = [] - self.key_events[wnd_name] += [ (ord_key, chr(ord_key), ctrl_pressed, alt_pressed, shift_pressed) ] + self.key_events[wnd_name] += [ (ord_key, chr(ord_key) if ord_key <= 255 else chr(0), ctrl_pressed, alt_pressed, shift_pressed) ] def get_mouse_events(self, wnd_name): ar = self.mouse_events.get(wnd_name, []) @@ -255,7 +256,7 @@ class InteractBase(object): print(result) return result - def input_int(self, s, default_value, valid_list=None, add_info=None, show_default_value=True, help_message=None): + def input_int(self, s, default_value, valid_range=None, valid_list=None, add_info=None, show_default_value=True, help_message=None): if show_default_value: if len(s) != 0: s = f"[{default_value}] {s}" @@ -263,15 +264,21 @@ class InteractBase(object): s = f"[{default_value}]" if add_info is not None or \ + valid_range is not None or \ help_message is not None: s += " (" + if valid_range is not None: + s += f" {valid_range[0]}-{valid_range[1]}" + if add_info is not None: s += f" {add_info}" + if help_message is not None: s += " ?:help" if add_info is not None or \ + valid_range is not None or \ help_message is not None: s += " )" @@ -288,9 +295,12 @@ class InteractBase(object): continue i = int(inp) + if valid_range is not None: + i = int(np.clip(i, valid_range[0], valid_range[1])) + if (valid_list is not None) and (i not in valid_list): - result = default_value - break + i = default_value + result = i break except: @@ -427,6 +437,7 @@ class InteractBase(object): p.start() time.sleep(0.5) p.terminate() + p.join() sys.stdin = os.fdopen( sys.stdin.fileno() ) @@ -490,10 +501,11 @@ class 
InteractDesktop(InteractBase): if has_windows or has_capture_keys: wait_key_time = max(1, int(sleep_time*1000) ) - ord_key = cv2.waitKey(wait_key_time) + ord_key = cv2.waitKeyEx(wait_key_time) + shift_pressed = False if ord_key != -1: - chr_key = chr(ord_key) + chr_key = chr(ord_key) if ord_key <= 255 else chr(0) if chr_key >= 'A' and chr_key <= 'Z': shift_pressed = True diff --git a/core/joblib/SubprocessorBase.py b/core/joblib/SubprocessorBase.py index 181c8cf..17e7056 100644 --- a/core/joblib/SubprocessorBase.py +++ b/core/joblib/SubprocessorBase.py @@ -81,11 +81,8 @@ class Subprocessor(object): except Subprocessor.SilenceException as e: c2s.put ( {'op': 'error', 'data' : data} ) except Exception as e: - c2s.put ( {'op': 'error', 'data' : data} ) - if data is not None: - print ('Exception while process data [%s]: %s' % (self.get_data_name(data), traceback.format_exc()) ) - else: - print ('Exception: %s' % (traceback.format_exc()) ) + err_msg = traceback.format_exc() + c2s.put ( {'op': 'error', 'data' : data, 'err_msg' : err_msg} ) c2s.close() s2c.close() @@ -159,6 +156,24 @@ class Subprocessor(object): self.clis = [] + def cli_init_dispatcher(cli): + while not cli.c2s.empty(): + obj = cli.c2s.get() + op = obj.get('op','') + if op == 'init_ok': + cli.state = 0 + elif op == 'log_info': + io.log_info(obj['msg']) + elif op == 'log_err': + io.log_err(obj['msg']) + elif op == 'error': + err_msg = obj.get('err_msg', None) + if err_msg is not None: + io.log_info(f'Error while subprocess initialization: {err_msg}') + cli.kill() + self.clis.remove(cli) + break + #getting info about name of subprocesses, host and client dicts, and spawning them for name, host_dict, client_dict in self.process_info_generator(): try: @@ -173,19 +188,7 @@ class Subprocessor(object): if self.initialize_subprocesses_in_serial: while True: - while not cli.c2s.empty(): - obj = cli.c2s.get() - op = obj.get('op','') - if op == 'init_ok': - cli.state = 0 - elif op == 'log_info': - io.log_info(obj['msg']) - elif op == 'log_err': - io.log_err(obj['msg']) - elif op == 'error': - cli.kill() - self.clis.remove(cli) - break + cli_init_dispatcher(cli) if cli.state == 0: break io.process_messages(0.005) @@ -198,19 +201,7 @@ class Subprocessor(object): #waiting subprocesses their success(or not) initialization while True: for cli in self.clis[:]: - while not cli.c2s.empty(): - obj = cli.c2s.get() - op = obj.get('op','') - if op == 'init_ok': - cli.state = 0 - elif op == 'log_info': - io.log_info(obj['msg']) - elif op == 'log_err': - io.log_err(obj['msg']) - elif op == 'error': - cli.kill() - self.clis.remove(cli) - break + cli_init_dispatcher(cli) if all ([cli.state == 0 for cli in self.clis]): break io.process_messages(0.005) @@ -235,8 +226,12 @@ class Subprocessor(object): cli.state = 0 elif op == 'error': #some error occured while process data, returning chunk to on_data_return + err_msg = obj.get('err_msg', None) + if err_msg is not None: + io.log_info(f'Error while processing data: {err_msg}') + if 'data' in obj.keys(): - self.on_data_return (cli.host_dict, obj['data'] ) + self.on_data_return (cli.host_dict, obj['data'] ) #and killing process cli.kill() self.clis.remove(cli) diff --git a/core/leras/archis/DeepFakeArchi.py b/core/leras/archis/DeepFakeArchi.py index b26495f..93ff13c 100644 --- a/core/leras/archis/DeepFakeArchi.py +++ b/core/leras/archis/DeepFakeArchi.py @@ -1,54 +1,60 @@ from core.leras import nn tf = nn.tf -class DeepFakeArchi(nn.ArchiBase): +class DeepFakeArchi(nn.ArchiBase): """ resolution - + mod None - 
default - 'chervonij' 'quick' + + opts '' + '' + 't' """ - def __init__(self, resolution, mod=None): + def __init__(self, resolution, use_fp16=False, mod=None, opts=None): super().__init__() + + if opts is None: + opts = '' + + + conv_dtype = tf.float16 if use_fp16 else tf.float32 + if 'c' in opts: + def act(x, alpha=0.1): + return x*tf.cos(x) + else: + def act(x, alpha=0.1): + return tf.nn.leaky_relu(x, alpha) + if mod is None: class Downscale(nn.ModelBase): - def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ): + def __init__(self, in_ch, out_ch, kernel_size=5, *kwargs ): self.in_ch = in_ch self.out_ch = out_ch self.kernel_size = kernel_size - self.dilations = dilations - self.subpixel = subpixel - self.use_activator = use_activator super().__init__(*kwargs) def on_build(self, *args, **kwargs ): - self.conv1 = nn.Conv2D( self.in_ch, - self.out_ch // (4 if self.subpixel else 1), - kernel_size=self.kernel_size, - strides=1 if self.subpixel else 2, - padding='SAME', dilations=self.dilations) + self.conv1 = nn.Conv2D( self.in_ch, self.out_ch, kernel_size=self.kernel_size, strides=2, padding='SAME', dtype=conv_dtype) def forward(self, x): x = self.conv1(x) - if self.subpixel: - x = nn.space_to_depth(x, 2) - if self.use_activator: - x = tf.nn.leaky_relu(x, 0.1) + x = act(x, 0.1) return x def get_out_ch(self): - return (self.out_ch // 4) * 4 + return self.out_ch class DownscaleBlock(nn.ModelBase): - def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True): + def on_build(self, in_ch, ch, n_downscales, kernel_size): self.downs = [] last_ch = in_ch for i in range(n_downscales): cur_ch = ch*( min(2**i, 8) ) - self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) ) + self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size)) last_ch = self.downs[-1].get_out_ch() def forward(self, inp): @@ -58,66 +64,77 @@ class DeepFakeArchi(nn.ArchiBase): return x class Upscale(nn.ModelBase): - def on_build(self, in_ch, out_ch, kernel_size=3 ): - self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME') + def on_build(self, in_ch, out_ch, kernel_size=3): + self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) def forward(self, x): x = self.conv1(x) - x = tf.nn.leaky_relu(x, 0.1) + x = act(x, 0.1) x = nn.depth_to_space(x, 2) return x class ResidualBlock(nn.ModelBase): - def on_build(self, ch, kernel_size=3 ): - self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME') - self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME') + def on_build(self, ch, kernel_size=3): + self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) def forward(self, inp): x = self.conv1(inp) - x = tf.nn.leaky_relu(x, 0.2) + x = act(x, 0.2) x = self.conv2(x) - x = tf.nn.leaky_relu(inp + x, 0.2) + x = act(inp + x, 0.2) return x - class UpdownResidualBlock(nn.ModelBase): - def on_build(self, ch, inner_ch, kernel_size=3 ): - self.up = Upscale (ch, inner_ch, kernel_size=kernel_size) - self.res = ResidualBlock (inner_ch, kernel_size=kernel_size) - self.down = Downscale (inner_ch, ch, kernel_size=kernel_size, use_activator=False) - - def forward(self, inp): - x = self.up(inp) - x = upx = self.res(x) - x = self.down(x) - x = x + inp - x = tf.nn.leaky_relu(x, 0.2) - return 
x, upx - class Encoder(nn.ModelBase): - def on_build(self, in_ch, e_ch, is_hd): - self.is_hd=is_hd - if self.is_hd: - self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, dilations=1) - self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, dilations=1) - self.down3 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=5, dilations=2) - self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2) - else: - self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False) + def __init__(self, in_ch, e_ch, **kwargs ): + self.in_ch = in_ch + self.e_ch = e_ch + super().__init__(**kwargs) - def forward(self, inp): - if self.is_hd: - x = tf.concat([ nn.flatten(self.down1(inp)), - nn.flatten(self.down2(inp)), - nn.flatten(self.down3(inp)), - nn.flatten(self.down4(inp)) ], -1 ) + def on_build(self): + if 't' in opts: + self.down1 = Downscale(self.in_ch, self.e_ch, kernel_size=5) + self.res1 = ResidualBlock(self.e_ch) + self.down2 = Downscale(self.e_ch, self.e_ch*2, kernel_size=5) + self.down3 = Downscale(self.e_ch*2, self.e_ch*4, kernel_size=5) + self.down4 = Downscale(self.e_ch*4, self.e_ch*8, kernel_size=5) + self.down5 = Downscale(self.e_ch*8, self.e_ch*8, kernel_size=5) + self.res5 = ResidualBlock(self.e_ch*8) else: - x = nn.flatten(self.down1(inp)) + self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4 if 't' not in opts else 5, kernel_size=5) + + def forward(self, x): + if use_fp16: + x = tf.cast(x, tf.float16) + + if 't' in opts: + x = self.down1(x) + x = self.res1(x) + x = self.down2(x) + x = self.down3(x) + x = self.down4(x) + x = self.down5(x) + x = self.res5(x) + else: + x = self.down1(x) + x = nn.flatten(x) + if 'u' in opts: + x = nn.pixel_norm(x, axes=-1) + + if use_fp16: + x = tf.cast(x, tf.float32) return x - - lowest_dense_res = resolution // 16 + + def get_out_res(self, res): + return res // ( (2**4) if 't' not in opts else (2**5) ) + + def get_out_ch(self): + return self.e_ch * 8 + + lowest_dense_res = resolution // (32 if 'd' in opts else 16) class Inter(nn.ModelBase): - def __init__(self, in_ch, ae_ch, ae_out_ch, is_hd=False, **kwargs): + def __init__(self, in_ch, ae_ch, ae_out_ch, **kwargs): self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch super().__init__(**kwargs) @@ -126,362 +143,120 @@ class DeepFakeArchi(nn.ArchiBase): self.dense1 = nn.Dense( in_ch, ae_ch ) self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch ) - self.upscale1 = Upscale(ae_out_ch, ae_out_ch) + if 't' not in opts: + self.upscale1 = Upscale(ae_out_ch, ae_out_ch) def forward(self, inp): - x = self.dense1(inp) + x = inp + x = self.dense1(x) x = self.dense2(x) x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch) - x = self.upscale1(x) + + if use_fp16: + x = tf.cast(x, tf.float16) + + if 't' not in opts: + x = self.upscale1(x) + return x - - @staticmethod - def get_code_res(): - return lowest_dense_res - + + def get_out_res(self): + return lowest_dense_res * 2 if 't' not in opts else lowest_dense_res + def get_out_ch(self): return self.ae_out_ch class Decoder(nn.ModelBase): - def on_build(self, in_ch, d_ch, d_mask_ch, is_hd ): - self.is_hd = is_hd - - self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3) - self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3) - self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3) - - if is_hd: - self.res0 = UpdownResidualBlock(in_ch, d_ch*8, kernel_size=3) - self.res1 = 
UpdownResidualBlock(d_ch*8, d_ch*4, kernel_size=3) - self.res2 = UpdownResidualBlock(d_ch*4, d_ch*2, kernel_size=3) - self.res3 = UpdownResidualBlock(d_ch*2, d_ch, kernel_size=3) - else: + def on_build(self, in_ch, d_ch, d_mask_ch): + if 't' not in opts: + self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3) + self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3) + self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3) self.res0 = ResidualBlock(d_ch*8, kernel_size=3) self.res1 = ResidualBlock(d_ch*4, kernel_size=3) self.res2 = ResidualBlock(d_ch*2, kernel_size=3) - self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME') + self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) + self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3) + self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3) - self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) - self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3) - self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3) - self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME') + self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype) - def forward(self, inp): - z = inp - - if self.is_hd: - x, upx = self.res0(z) - x = self.upscale0(x) - x = tf.nn.leaky_relu(x + upx, 0.2) - x, upx = self.res1(x) - - x = self.upscale1(x) - x = tf.nn.leaky_relu(x + upx, 0.2) - x, upx = self.res2(x) - - x = self.upscale2(x) - x = tf.nn.leaky_relu(x + upx, 0.2) - x, upx = self.res3(x) + if 'd' in opts: + self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3) + self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) + else: + self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) else: - x = self.upscale0(z) - x = self.res0(x) - x = self.upscale1(x) - x = self.res1(x) - x = self.upscale2(x) - x = self.res2(x) + self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3) + self.upscale1 = Upscale(d_ch*8, d_ch*8, kernel_size=3) + self.upscale2 = Upscale(d_ch*8, d_ch*4, kernel_size=3) + self.upscale3 = Upscale(d_ch*4, d_ch*2, kernel_size=3) + self.res0 = ResidualBlock(d_ch*8, kernel_size=3) + self.res1 = ResidualBlock(d_ch*8, kernel_size=3) + self.res2 = ResidualBlock(d_ch*4, kernel_size=3) + self.res3 = ResidualBlock(d_ch*2, kernel_size=3) - m = self.upscalem0(z) - m = self.upscalem1(m) - m = self.upscalem2(m) + self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) + self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*8, kernel_size=3) + self.upscalem2 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3) + self.upscalem3 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3) + self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype) - return tf.nn.sigmoid(self.out_conv(x)), \ - tf.nn.sigmoid(self.out_convm(m)) - - elif mod == 'chervonij': - class Downscale(nn.ModelBase): - def __init__(self, in_ch, kernel_size=3, dilations=1, *kwargs ): - self.in_ch = in_ch - self.kernel_size = kernel_size - self.dilations = dilations - super().__init__(*kwargs) - - def on_build(self, *args, **kwargs ): - self.conv_base1 = nn.Conv2D( self.in_ch, self.in_ch//2, kernel_size=1, strides=1, padding='SAME', 
dilations=self.dilations) - self.conv_l1 = nn.Conv2D( self.in_ch//2, self.in_ch//2, kernel_size=self.kernel_size, strides=1, padding='SAME', dilations=self.dilations) - self.conv_l2 = nn.Conv2D( self.in_ch//2, self.in_ch//2, kernel_size=self.kernel_size, strides=2, padding='SAME', dilations=self.dilations) - - self.conv_base2 = nn.Conv2D( self.in_ch, self.in_ch//2, kernel_size=1, strides=1, padding='SAME', dilations=self.dilations) - self.conv_r1 = nn.Conv2D( self.in_ch//2, self.in_ch//2, kernel_size=self.kernel_size, strides=2, padding='SAME', dilations=self.dilations) - - self.pool_size = [1,1,2,2] if nn.data_format == 'NCHW' else [1,2,2,1] - def forward(self, x): - - x_l = self.conv_base1(x) - x_l = self.conv_l1(x_l) - x_l = self.conv_l2(x_l) - - x_r = self.conv_base2(x) - x_r = self.conv_r1(x_r) - - x_pool = tf.nn.max_pool(x, ksize=self.pool_size, strides=self.pool_size, padding='SAME', data_format=nn.data_format) - - x = tf.concat([x_l, x_r, x_pool], axis=nn.conv2d_ch_axis) - x = tf.nn.leaky_relu(x, 0.1) - return x - - class Upscale(nn.ModelBase): - def on_build(self, in_ch, out_ch, kernel_size=3 ): - self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, padding='SAME') - self.conv2 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME') - self.conv3 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME') - self.conv4 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME') - - def forward(self, x): - x0 = self.conv1(x) - x1 = self.conv2(x0) - x2 = self.conv3(x1) - x3 = self.conv4(x2) - x = tf.concat([x0, x1, x2, x3], axis=nn.conv2d_ch_axis) - x = tf.nn.leaky_relu(x, 0.1) - x = nn.depth_to_space(x, 2) - return x - - class ResidualBlock(nn.ModelBase): - def on_build(self, ch, kernel_size=3 ): - self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME') - self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME') - self.norm = nn.FRNorm2D(ch) - - def forward(self, inp): - x = self.conv1(inp) - x = tf.nn.leaky_relu(x, 0.2) - x = self.conv2(x) - x = self.norm(inp + x) - x = tf.nn.leaky_relu(x, 0.2) - return x - - class Encoder(nn.ModelBase): - def on_build(self, in_ch, e_ch, **kwargs): - self.conv0 = nn.Conv2D(in_ch, e_ch, kernel_size=3, padding='SAME') - - self.down0 = Downscale(e_ch) - self.down1 = Downscale(e_ch*2) - self.down2 = Downscale(e_ch*4) - self.down3 = Downscale(e_ch*8) - self.down4 = Downscale(e_ch*16) - - def forward(self, inp): - x = self.conv0(inp) - x = self.down0(x) - x = self.down1(x) - x = self.down2(x) - x = self.down3(x) - x = self.down4(x) - x = nn.flatten(x) - return x - - lowest_dense_res = resolution // 32 - - class Inter(nn.ModelBase): - def __init__(self, in_ch, ae_ch, ae_out_ch, **kwargs): - self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch - super().__init__(**kwargs) - - def on_build(self, **kwargs): - in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch - - self.dense_l = nn.Dense( in_ch, ae_ch//2, kernel_initializer=tf.initializers.orthogonal) - self.dense_r = nn.Dense( in_ch, ae_ch//2, kernel_initializer=tf.initializers.orthogonal)#maxout_ch=4, - self.dense = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * (ae_out_ch//2), kernel_initializer=tf.initializers.orthogonal) - self.upscale1 = Upscale(ae_out_ch//2, ae_out_ch//2) - - def forward(self, inp): - x0 = self.dense_l(inp) - x1 = self.dense_r(inp) - x = tf.concat([x0, x1], axis=-1) - x = self.dense(x) - x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch//2) - x = 
self.upscale1(x) - - return x - - def get_out_ch(self): - return self.ae_out_ch//2 - - class Decoder(nn.ModelBase): - def on_build(self, in_ch, d_ch, d_mask_ch, **kwargs): - - self.upscale0 = Upscale(in_ch, d_ch*8) - self.upscale1 = Upscale(d_ch*8, d_ch*4) - self.upscale2 = Upscale(d_ch*4, d_ch*2) - self.upscale3 = Upscale(d_ch*2, d_ch) - - self.res0 = ResidualBlock(d_ch*8) - self.res1 = ResidualBlock(d_ch*4) - self.res2 = ResidualBlock(d_ch*2) - self.res3 = ResidualBlock(d_ch) - - self.out_conv = nn.Conv2D( d_ch, 3, kernel_size=1, padding='SAME') - - self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) - self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3) - self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3) - self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch, kernel_size=3) - self.out_convm = nn.Conv2D( d_mask_ch, 1, kernel_size=1, padding='SAME') - - def forward(self, inp): - z = inp + if 'd' in opts: + self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.upscalem4 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3) + self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) + else: + self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) + + + def forward(self, z): x = self.upscale0(z) x = self.res0(x) x = self.upscale1(x) x = self.res1(x) x = self.upscale2(x) x = self.res2(x) - x = self.upscale3(x) - x = self.res3(x) + + if 't' in opts: + x = self.upscale3(x) + x = self.res3(x) + + if 'd' in opts: + x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x), + self.out_conv1(x), + self.out_conv2(x), + self.out_conv3(x)), nn.conv2d_ch_axis), 2) ) + else: + x = tf.nn.sigmoid(self.out_conv(x)) + m = self.upscalem0(z) m = self.upscalem1(m) m = self.upscalem2(m) - m = self.upscalem3(m) - return tf.nn.sigmoid(self.out_conv(x)), \ - tf.nn.sigmoid(self.out_convm(m)) - elif mod == 'quick': - class Downscale(nn.ModelBase): - def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ): - self.in_ch = in_ch - self.out_ch = out_ch - self.kernel_size = kernel_size - self.dilations = dilations - self.subpixel = subpixel - self.use_activator = use_activator - super().__init__(*kwargs) + if 't' in opts: + m = self.upscalem3(m) + if 'd' in opts: + m = self.upscalem4(m) + else: + if 'd' in opts: + m = self.upscalem3(m) - def on_build(self, *args, **kwargs ): - self.conv1 = nn.Conv2D( self.in_ch, - self.out_ch // (4 if self.subpixel else 1), - kernel_size=self.kernel_size, - strides=1 if self.subpixel else 2, - padding='SAME', dilations=self.dilations ) + m = tf.nn.sigmoid(self.out_convm(m)) - def forward(self, x): - x = self.conv1(x) + if use_fp16: + x = tf.cast(x, tf.float32) + m = tf.cast(m, tf.float32) - if self.subpixel: - x = nn.space_to_depth(x, 2) - - if self.use_activator: - x = nn.gelu(x) - return x - - def get_out_ch(self): - return (self.out_ch // 4) * 4 - - class DownscaleBlock(nn.ModelBase): - def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True): - self.downs = [] - - last_ch = in_ch - for i in range(n_downscales): - cur_ch = ch*( min(2**i, 8) ) - self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) ) - last_ch = 
self.downs[-1].get_out_ch() - - def forward(self, inp): - x = inp - for down in self.downs: - x = down(x) - return x - - class Upscale(nn.ModelBase): - def on_build(self, in_ch, out_ch, kernel_size=3 ): - self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME') - - def forward(self, x): - x = self.conv1(x) - x = nn.gelu(x) - x = nn.depth_to_space(x, 2) - return x - - class ResidualBlock(nn.ModelBase): - def on_build(self, ch, kernel_size=3 ): - self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME') - self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME') - - def forward(self, inp): - x = self.conv1(inp) - x = nn.gelu(x) - x = self.conv2(x) - x = inp + x - x = nn.gelu(x) - return x - - class Encoder(nn.ModelBase): - def on_build(self, in_ch, e_ch): - self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5) - def forward(self, inp): - return nn.flatten(self.down1(inp)) - - lowest_dense_res = resolution // 16 - - class Inter(nn.ModelBase): - def __init__(self, in_ch, ae_ch, ae_out_ch, d_ch, **kwargs): - self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch = in_ch, ae_ch, ae_out_ch, d_ch - super().__init__(**kwargs) - - def on_build(self): - in_ch, ae_ch, ae_out_ch, d_ch = self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch - - self.dense1 = nn.Dense( in_ch, ae_ch, kernel_initializer=tf.initializers.orthogonal ) - self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch, kernel_initializer=tf.initializers.orthogonal ) - self.upscale1 = Upscale(ae_out_ch, d_ch*8) - self.res1 = ResidualBlock(d_ch*8) - - def forward(self, inp): - x = self.dense1(inp) - x = self.dense2(x) - x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch) - x = self.upscale1(x) - x = self.res1(x) - return x - - def get_out_ch(self): - return self.ae_out_ch - - class Decoder(nn.ModelBase): - def on_build(self, in_ch, d_ch): - self.upscale1 = Upscale(in_ch, d_ch*4) - self.res1 = ResidualBlock(d_ch*4) - self.upscale2 = Upscale(d_ch*4, d_ch*2) - self.res2 = ResidualBlock(d_ch*2) - self.upscale3 = Upscale(d_ch*2, d_ch*1) - self.res3 = ResidualBlock(d_ch*1) - - self.upscalem1 = Upscale(in_ch, d_ch) - self.upscalem2 = Upscale(d_ch, d_ch//2) - self.upscalem3 = Upscale(d_ch//2, d_ch//2) - - self.out_conv = nn.Conv2D( d_ch*1, 3, kernel_size=1, padding='SAME') - self.out_convm = nn.Conv2D( d_ch//2, 1, kernel_size=1, padding='SAME') - - def forward(self, inp): - z = inp - x = self.upscale1 (z) - x = self.res1 (x) - x = self.upscale2 (x) - x = self.res2 (x) - x = self.upscale3 (x) - x = self.res3 (x) - - y = self.upscalem1 (z) - y = self.upscalem2 (y) - y = self.upscalem3 (y) - - return tf.nn.sigmoid(self.out_conv(x)), \ - tf.nn.sigmoid(self.out_convm(y)) + return x, m self.Encoder = Encoder self.Inter = Inter diff --git a/core/leras/device.py b/core/leras/device.py index 46fbd12..a2ba371 100644 --- a/core/leras/device.py +++ b/core/leras/device.py @@ -1,12 +1,19 @@ import sys import ctypes import os +import multiprocessing +import json +import time +from pathlib import Path +from core.interact import interact as io + class Device(object): - def __init__(self, index, name, total_mem, free_mem, cc=0): + def __init__(self, index, tf_dev_type, name, total_mem, free_mem): self.index = index + self.tf_dev_type = tf_dev_type self.name = name - self.cc = cc + self.total_mem = total_mem self.total_mem_gb = total_mem / 1024**3 self.free_mem = free_mem @@ -82,8 +89,136 @@ class Devices(object): result.append (device) return 
Devices(result) + @staticmethod + def _get_tf_devices_proc(q : multiprocessing.Queue): + + if sys.platform[0:3] == 'win': + compute_cache_path = Path(os.environ['APPDATA']) / 'NVIDIA' / ('ComputeCache_ALL') + os.environ['CUDA_CACHE_PATH'] = str(compute_cache_path) + if not compute_cache_path.exists(): + io.log_info("Caching GPU kernels...") + compute_cache_path.mkdir(parents=True, exist_ok=True) + + import tensorflow + + tf_version = tensorflow.version.VERSION + #if tf_version is None: + # tf_version = tensorflow.version.GIT_VERSION + if tf_version[0] == 'v': + tf_version = tf_version[1:] + if tf_version[0] == '2': + tf = tensorflow.compat.v1 + else: + tf = tensorflow + + import logging + # Disable tensorflow warnings + tf_logger = logging.getLogger('tensorflow') + tf_logger.setLevel(logging.ERROR) + + from tensorflow.python.client import device_lib + + devices = [] + + physical_devices = device_lib.list_local_devices() + physical_devices_f = {} + for dev in physical_devices: + dev_type = dev.device_type + dev_tf_name = dev.name + dev_tf_name = dev_tf_name[ dev_tf_name.index(dev_type) : ] + + dev_idx = int(dev_tf_name.split(':')[-1]) + + if dev_type in ['GPU','DML']: + dev_name = dev_tf_name + + dev_desc = dev.physical_device_desc + if len(dev_desc) != 0: + if dev_desc[0] == '{': + dev_desc_json = json.loads(dev_desc) + dev_desc_json_name = dev_desc_json.get('name',None) + if dev_desc_json_name is not None: + dev_name = dev_desc_json_name + else: + for param, value in ( v.split(':') for v in dev_desc.split(',') ): + param = param.strip() + value = value.strip() + if param == 'name': + dev_name = value + break + + physical_devices_f[dev_idx] = (dev_type, dev_name, dev.memory_limit) + + q.put(physical_devices_f) + time.sleep(0.1) + + @staticmethod def initialize_main_env(): + if int(os.environ.get("NN_DEVICES_INITIALIZED", 0)) != 0: + return + + if 'CUDA_VISIBLE_DEVICES' in os.environ.keys(): + os.environ.pop('CUDA_VISIBLE_DEVICES') + + os.environ['TF_DIRECTML_KERNEL_CACHE_SIZE'] = '2500' + os.environ['CUDA_​CACHE_​MAXSIZE'] = '2147483647' + os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2' + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # tf log errors only + + q = multiprocessing.Queue() + p = multiprocessing.Process(target=Devices._get_tf_devices_proc, args=(q,), daemon=True) + p.start() + p.join() + + visible_devices = q.get() + + os.environ['NN_DEVICES_INITIALIZED'] = '1' + os.environ['NN_DEVICES_COUNT'] = str(len(visible_devices)) + + for i in visible_devices: + dev_type, name, total_mem = visible_devices[i] + + os.environ[f'NN_DEVICE_{i}_TF_DEV_TYPE'] = dev_type + os.environ[f'NN_DEVICE_{i}_NAME'] = name + os.environ[f'NN_DEVICE_{i}_TOTAL_MEM'] = str(total_mem) + os.environ[f'NN_DEVICE_{i}_FREE_MEM'] = str(total_mem) + + + + @staticmethod + def getDevices(): + if Devices.all_devices is None: + if int(os.environ.get("NN_DEVICES_INITIALIZED", 0)) != 1: + raise Exception("nn devices are not initialized. 
Run initialize_main_env() in main process.") + devices = [] + for i in range ( int(os.environ['NN_DEVICES_COUNT']) ): + devices.append ( Device(index=i, + tf_dev_type=os.environ[f'NN_DEVICE_{i}_TF_DEV_TYPE'], + name=os.environ[f'NN_DEVICE_{i}_NAME'], + total_mem=int(os.environ[f'NN_DEVICE_{i}_TOTAL_MEM']), + free_mem=int(os.environ[f'NN_DEVICE_{i}_FREE_MEM']), ) + ) + Devices.all_devices = Devices(devices) + + return Devices.all_devices + +""" + + + # {'name' : name.split(b'\0', 1)[0].decode(), + # 'total_mem' : totalMem.value + # } + + + + + + return + + + + min_cc = int(os.environ.get("TF_MIN_REQ_CAP", 35)) libnames = ('libcuda.so', 'libcuda.dylib', 'nvcuda.dll') for libname in libnames: @@ -129,77 +264,10 @@ class Devices(object): }) cuda.cuCtxDetach(context) - os.environ['NN_DEVICES_INITIALIZED'] = '1' os.environ['NN_DEVICES_COUNT'] = str(len(devices)) for i, device in enumerate(devices): os.environ[f'NN_DEVICE_{i}_NAME'] = device['name'] os.environ[f'NN_DEVICE_{i}_TOTAL_MEM'] = str(device['total_mem']) os.environ[f'NN_DEVICE_{i}_FREE_MEM'] = str(device['free_mem']) os.environ[f'NN_DEVICE_{i}_CC'] = str(device['cc']) - - @staticmethod - def getDevices(): - if Devices.all_devices is None: - if int(os.environ.get("NN_DEVICES_INITIALIZED", 0)) != 1: - raise Exception("nn devices are not initialized. Run initialize_main_env() in main process.") - devices = [] - for i in range ( int(os.environ['NN_DEVICES_COUNT']) ): - devices.append ( Device(index=i, - name=os.environ[f'NN_DEVICE_{i}_NAME'], - total_mem=int(os.environ[f'NN_DEVICE_{i}_TOTAL_MEM']), - free_mem=int(os.environ[f'NN_DEVICE_{i}_FREE_MEM']), - cc=int(os.environ[f'NN_DEVICE_{i}_CC']) )) - Devices.all_devices = Devices(devices) - - return Devices.all_devices - -""" -if Devices.all_devices is None: - min_cc = int(os.environ.get("TF_MIN_REQ_CAP", 35)) - - libnames = ('libcuda.so', 'libcuda.dylib', 'nvcuda.dll') - for libname in libnames: - try: - cuda = ctypes.CDLL(libname) - except: - continue - else: - break - else: - return Devices([]) - - nGpus = ctypes.c_int() - name = b' ' * 200 - cc_major = ctypes.c_int() - cc_minor = ctypes.c_int() - freeMem = ctypes.c_size_t() - totalMem = ctypes.c_size_t() - - result = ctypes.c_int() - device = ctypes.c_int() - context = ctypes.c_void_p() - error_str = ctypes.c_char_p() - - devices = [] - - if cuda.cuInit(0) == 0 and \ - cuda.cuDeviceGetCount(ctypes.byref(nGpus)) == 0: - for i in range(nGpus.value): - if cuda.cuDeviceGet(ctypes.byref(device), i) != 0 or \ - cuda.cuDeviceGetName(ctypes.c_char_p(name), len(name), device) != 0 or \ - cuda.cuDeviceComputeCapability(ctypes.byref(cc_major), ctypes.byref(cc_minor), device) != 0: - continue - - if cuda.cuCtxCreate_v2(ctypes.byref(context), 0, device) == 0: - if cuda.cuMemGetInfo_v2(ctypes.byref(freeMem), ctypes.byref(totalMem)) == 0: - cc = cc_major.value * 10 + cc_minor.value - if cc >= min_cc: - devices.append ( Device(index=i, - name=name.split(b'\0', 1)[0].decode(), - total_mem=totalMem.value, - free_mem=freeMem.value, - cc=cc) ) - cuda.cuCtxDetach(context) - Devices.all_devices = Devices(devices) - return Devices.all_devices """ \ No newline at end of file diff --git a/core/leras/layers/Conv2D.py b/core/leras/layers/Conv2D.py index ae37c50..a5febf0 100644 --- a/core/leras/layers/Conv2D.py +++ b/core/leras/layers/Conv2D.py @@ -23,28 +23,13 @@ class Conv2D(nn.LayerBase): if padding == "SAME": padding = ( (kernel_size - 1) * dilations + 1 ) // 2 elif padding == "VALID": - padding = 0 + padding = None else: raise ValueError ("Wrong padding type. 
Should be VALID SAME or INT or 4x INTs") - - if isinstance(padding, int): - if padding != 0: - if nn.data_format == "NHWC": - padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ] - else: - padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ] - else: - padding = None - - if nn.data_format == "NHWC": - strides = [1,strides,strides,1] else: - strides = [1,1,strides,strides] - - if nn.data_format == "NHWC": - dilations = [1,dilations,dilations,1] - else: - dilations = [1,1,dilations,dilations] + padding = int(padding) + + self.in_ch = in_ch self.out_ch = out_ch @@ -70,8 +55,8 @@ class Conv2D(nn.LayerBase): if kernel_initializer is None: kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype) - if kernel_initializer is None: - kernel_initializer = nn.initializers.ca() + #if kernel_initializer is None: + # kernel_initializer = nn.initializers.ca() self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.in_ch,self.out_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable ) @@ -93,10 +78,27 @@ class Conv2D(nn.LayerBase): if self.use_wscale: weight = weight * self.wscale - if self.padding is not None: - x = tf.pad (x, self.padding, mode='CONSTANT') + padding = self.padding + if padding is not None: + if nn.data_format == "NHWC": + padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ] + else: + padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ] + x = tf.pad (x, padding, mode='CONSTANT') + + strides = self.strides + if nn.data_format == "NHWC": + strides = [1,strides,strides,1] + else: + strides = [1,1,strides,strides] - x = tf.nn.conv2d(x, weight, self.strides, 'VALID', dilations=self.dilations, data_format=nn.data_format) + dilations = self.dilations + if nn.data_format == "NHWC": + dilations = [1,dilations,dilations,1] + else: + dilations = [1,1,dilations,dilations] + + x = tf.nn.conv2d(x, weight, strides, 'VALID', dilations=dilations, data_format=nn.data_format) if self.use_bias: if nn.data_format == "NHWC": bias = tf.reshape (self.bias, (1,1,1,self.out_ch) ) diff --git a/core/leras/layers/Conv2DTranspose.py b/core/leras/layers/Conv2DTranspose.py index 937d624..a2e97dc 100644 --- a/core/leras/layers/Conv2DTranspose.py +++ b/core/leras/layers/Conv2DTranspose.py @@ -38,8 +38,8 @@ class Conv2DTranspose(nn.LayerBase): if kernel_initializer is None: kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype) - if kernel_initializer is None: - kernel_initializer = nn.initializers.ca() + #if kernel_initializer is None: + # kernel_initializer = nn.initializers.ca() self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.out_ch,self.in_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable ) if self.use_bias: diff --git a/core/leras/layers/DenseNorm.py b/core/leras/layers/DenseNorm.py new file mode 100644 index 0000000..594bf57 --- /dev/null +++ b/core/leras/layers/DenseNorm.py @@ -0,0 +1,16 @@ +from core.leras import nn +tf = nn.tf + +class DenseNorm(nn.LayerBase): + def __init__(self, dense=False, eps=1e-06, dtype=None, **kwargs): + self.dense = dense + if dtype is None: + dtype = nn.floatx + self.eps = tf.constant(eps, dtype=dtype, name="epsilon") + + super().__init__(**kwargs) + + def __call__(self, x): + return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=-1, keepdims=True) + self.eps) + +nn.DenseNorm = DenseNorm \ No newline at end of file diff --git a/core/leras/layers/DepthwiseConv2D.py 
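# Illustrative NumPy equivalent of the DenseNorm layer above; not part of the patch.
# DenseNorm is a "pixel norm": each feature vector along the last axis is rescaled to
# roughly unit RMS, matching x * tf.rsqrt(mean(x^2, axis=-1) + eps).
import numpy as np

def pixel_norm(x, eps=1e-6):
    return x / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)

# pixel_norm(np.array([[3.0, 4.0]])) -> approx [[0.8485, 1.1314]]  (feature RMS ~ 1)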
b/core/leras/layers/DepthwiseConv2D.py new file mode 100644 index 0000000..2916f01 --- /dev/null +++ b/core/leras/layers/DepthwiseConv2D.py @@ -0,0 +1,110 @@ +import numpy as np +from core.leras import nn +tf = nn.tf + +class DepthwiseConv2D(nn.LayerBase): + """ + default kernel_initializer - CA + use_wscale bool enables equalized learning rate, if kernel_initializer is None, it will be forced to random_normal + """ + def __init__(self, in_ch, kernel_size, strides=1, padding='SAME', depth_multiplier=1, dilations=1, use_bias=True, use_wscale=False, kernel_initializer=None, bias_initializer=None, trainable=True, dtype=None, **kwargs ): + if not isinstance(strides, int): + raise ValueError ("strides must be an int type") + if not isinstance(dilations, int): + raise ValueError ("dilations must be an int type") + kernel_size = int(kernel_size) + + if dtype is None: + dtype = nn.floatx + + if isinstance(padding, str): + if padding == "SAME": + padding = ( (kernel_size - 1) * dilations + 1 ) // 2 + elif padding == "VALID": + padding = 0 + else: + raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs") + + if isinstance(padding, int): + if padding != 0: + if nn.data_format == "NHWC": + padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ] + else: + padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ] + else: + padding = None + + if nn.data_format == "NHWC": + strides = [1,strides,strides,1] + else: + strides = [1,1,strides,strides] + + if nn.data_format == "NHWC": + dilations = [1,dilations,dilations,1] + else: + dilations = [1,1,dilations,dilations] + + self.in_ch = in_ch + self.depth_multiplier = depth_multiplier + self.kernel_size = kernel_size + self.strides = strides + self.padding = padding + self.dilations = dilations + self.use_bias = use_bias + self.use_wscale = use_wscale + self.kernel_initializer = kernel_initializer + self.bias_initializer = bias_initializer + self.trainable = trainable + self.dtype = dtype + super().__init__(**kwargs) + + def build_weights(self): + kernel_initializer = self.kernel_initializer + if self.use_wscale: + gain = 1.0 if self.kernel_size == 1 else np.sqrt(2) + fan_in = self.kernel_size*self.kernel_size*self.in_ch + he_std = gain / np.sqrt(fan_in) + self.wscale = tf.constant(he_std, dtype=self.dtype ) + if kernel_initializer is None: + kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype) + + #if kernel_initializer is None: + # kernel_initializer = nn.initializers.ca() + + self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.in_ch,self.depth_multiplier), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable ) + + if self.use_bias: + bias_initializer = self.bias_initializer + if bias_initializer is None: + bias_initializer = tf.initializers.zeros(dtype=self.dtype) + + self.bias = tf.get_variable("bias", (self.in_ch*self.depth_multiplier,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable ) + + def get_weights(self): + weights = [self.weight] + if self.use_bias: + weights += [self.bias] + return weights + + def forward(self, x): + weight = self.weight + if self.use_wscale: + weight = weight * self.wscale + + if self.padding is not None: + x = tf.pad (x, self.padding, mode='CONSTANT') + + x = tf.nn.depthwise_conv2d(x, weight, self.strides, 'VALID', data_format=nn.data_format) + if self.use_bias: + if nn.data_format == "NHWC": + bias = tf.reshape (self.bias, (1,1,1,self.in_ch*self.depth_multiplier) ) + else: + bias = 
tf.reshape (self.bias, (1,self.in_ch*self.depth_multiplier,1,1) ) + x = tf.add(x, bias) + return x + + def __str__(self): + r = f"{self.__class__.__name__} : in_ch:{self.in_ch} depth_multiplier:{self.depth_multiplier} " + return r + +nn.DepthwiseConv2D = DepthwiseConv2D \ No newline at end of file diff --git a/core/leras/layers/Saveable.py b/core/leras/layers/Saveable.py index 13eff3b..e72f594 100644 --- a/core/leras/layers/Saveable.py +++ b/core/leras/layers/Saveable.py @@ -46,7 +46,9 @@ class Saveable(): raise Exception("name must be defined.") name = self.name - for w, w_val in zip(weights, nn.tf_sess.run (weights)): + + for w in weights: + w_val = nn.tf_sess.run (w).copy() w_name_split = w.name.split('/', 1) if name != w_name_split[0]: raise Exception("weight first name != Saveable.name") @@ -76,28 +78,31 @@ class Saveable(): if self.name is None: raise Exception("name must be defined.") - tuples = [] - for w in weights: - w_name_split = w.name.split('/') - if self.name != w_name_split[0]: - raise Exception("weight first name != Saveable.name") + try: + tuples = [] + for w in weights: + w_name_split = w.name.split('/') + if self.name != w_name_split[0]: + raise Exception("weight first name != Saveable.name") - sub_w_name = "/".join(w_name_split[1:]) + sub_w_name = "/".join(w_name_split[1:]) - w_val = d.get(sub_w_name, None) + w_val = d.get(sub_w_name, None) - if w_val is None: - #io.log_err(f"Weight {w.name} was not loaded from file {filename}") - tuples.append ( (w, w.initializer) ) - else: - w_val = np.reshape( w_val, w.shape.as_list() ) - tuples.append ( (w, w_val) ) + if w_val is None: + #io.log_err(f"Weight {w.name} was not loaded from file {filename}") + tuples.append ( (w, w.initializer) ) + else: + w_val = np.reshape( w_val, w.shape.as_list() ) + tuples.append ( (w, w_val) ) - nn.batch_set_value(tuples) + nn.batch_set_value(tuples) + except: + return False return True def init_weights(self): nn.init_weights(self.get_weights()) - + nn.Saveable = Saveable diff --git a/core/leras/layers/ScaleAdd.py b/core/leras/layers/ScaleAdd.py new file mode 100644 index 0000000..06188b8 --- /dev/null +++ b/core/leras/layers/ScaleAdd.py @@ -0,0 +1,31 @@ +from core.leras import nn +tf = nn.tf + +class ScaleAdd(nn.LayerBase): + def __init__(self, ch, dtype=None, **kwargs): + if dtype is None: + dtype = nn.floatx + self.dtype = dtype + self.ch = ch + + super().__init__(**kwargs) + + def build_weights(self): + self.weight = tf.get_variable("weight",(self.ch,), dtype=self.dtype, initializer=tf.initializers.zeros() ) + + def get_weights(self): + return [self.weight] + + def forward(self, inputs): + if nn.data_format == "NHWC": + shape = (1,1,1,self.ch) + else: + shape = (1,self.ch,1,1) + + weight = tf.reshape ( self.weight, shape ) + + x0, x1 = inputs + x = x0 + x1*weight + + return x +nn.ScaleAdd = ScaleAdd \ No newline at end of file diff --git a/core/leras/layers/TanhPolar.py b/core/leras/layers/TanhPolar.py new file mode 100644 index 0000000..8955f32 --- /dev/null +++ b/core/leras/layers/TanhPolar.py @@ -0,0 +1,104 @@ +import numpy as np +from core.leras import nn +tf = nn.tf + +class TanhPolar(nn.LayerBase): + """ + RoI Tanh-polar Transformer Network for Face Parsing in the Wild + https://github.com/hhj1897/roi_tanh_warping + """ + + def __init__(self, width, height, angular_offset_deg=270, **kwargs): + self.width = width + self.height = height + + warp_gridx, warp_gridy = TanhPolar._get_tanh_polar_warp_grids(width,height,angular_offset_deg=angular_offset_deg) + restore_gridx, restore_gridy = 
TanhPolar._get_tanh_polar_restore_grids(width,height,angular_offset_deg=angular_offset_deg) + + self.warp_gridx_t = tf.constant(warp_gridx[None, ...]) + self.warp_gridy_t = tf.constant(warp_gridy[None, ...]) + self.restore_gridx_t = tf.constant(restore_gridx[None, ...]) + self.restore_gridy_t = tf.constant(restore_gridy[None, ...]) + + super().__init__(**kwargs) + + def warp(self, inp_t): + batch_t = tf.shape(inp_t)[0] + warp_gridx_t = tf.tile(self.warp_gridx_t, (batch_t,1,1) ) + warp_gridy_t = tf.tile(self.warp_gridy_t, (batch_t,1,1) ) + + if nn.data_format == "NCHW": + inp_t = tf.transpose(inp_t,(0,2,3,1)) + + out_t = nn.bilinear_sampler(inp_t, warp_gridx_t, warp_gridy_t) + + if nn.data_format == "NCHW": + out_t = tf.transpose(out_t,(0,3,1,2)) + + return out_t + + def restore(self, inp_t): + batch_t = tf.shape(inp_t)[0] + restore_gridx_t = tf.tile(self.restore_gridx_t, (batch_t,1,1) ) + restore_gridy_t = tf.tile(self.restore_gridy_t, (batch_t,1,1) ) + + if nn.data_format == "NCHW": + inp_t = tf.transpose(inp_t,(0,2,3,1)) + + inp_t = tf.pad(inp_t, [(0,0), (1, 1), (1, 0), (0, 0)], "SYMMETRIC") + + out_t = nn.bilinear_sampler(inp_t, restore_gridx_t, restore_gridy_t) + + if nn.data_format == "NCHW": + out_t = tf.transpose(out_t,(0,3,1,2)) + + return out_t + + @staticmethod + def _get_tanh_polar_warp_grids(W,H,angular_offset_deg): + angular_offset_pi = angular_offset_deg * np.pi / 180.0 + + roi_center = np.array([ W//2, H//2], np.float32 ) + roi_radii = np.array([W, H], np.float32 ) / np.pi ** 0.5 + cos_offset, sin_offset = np.cos(angular_offset_pi), np.sin(angular_offset_pi) + normalised_dest_indices = np.stack(np.meshgrid(np.arange(0.0, 1.0, 1.0 / W),np.arange(0.0, 2.0 * np.pi, 2.0 * np.pi / H)), axis=-1) + radii = normalised_dest_indices[..., 0] + orientation_x = np.cos(normalised_dest_indices[..., 1]) + orientation_y = np.sin(normalised_dest_indices[..., 1]) + + src_radii = np.arctanh(radii) * (roi_radii[0] * roi_radii[1] / np.sqrt(roi_radii[1] ** 2 * orientation_x ** 2 + roi_radii[0] ** 2 * orientation_y ** 2)) + src_x_indices = src_radii * orientation_x + src_y_indices = src_radii * orientation_y + src_x_indices, src_y_indices = (roi_center[0] + cos_offset * src_x_indices - sin_offset * src_y_indices, + roi_center[1] + cos_offset * src_y_indices + sin_offset * src_x_indices) + + return src_x_indices.astype(np.float32), src_y_indices.astype(np.float32) + + @staticmethod + def _get_tanh_polar_restore_grids(W,H,angular_offset_deg): + angular_offset_pi = angular_offset_deg * np.pi / 180.0 + + roi_center = np.array([ W//2, H//2], np.float32 ) + roi_radii = np.array([W, H], np.float32 ) / np.pi ** 0.5 + cos_offset, sin_offset = np.cos(angular_offset_pi), np.sin(angular_offset_pi) + + dest_indices = np.stack(np.meshgrid(np.arange(W), np.arange(H)), axis=-1).astype(float) + normalised_dest_indices = np.matmul(dest_indices - roi_center, np.array([[cos_offset, -sin_offset], + [sin_offset, cos_offset]])) + radii = np.linalg.norm(normalised_dest_indices, axis=-1) + normalised_dest_indices[..., 0] /= np.clip(radii, 1e-9, None) + normalised_dest_indices[..., 1] /= np.clip(radii, 1e-9, None) + radii *= np.sqrt(roi_radii[1] ** 2 * normalised_dest_indices[..., 0] ** 2 + + roi_radii[0] ** 2 * normalised_dest_indices[..., 1] ** 2) / roi_radii[0] / roi_radii[1] + + src_radii = np.tanh(radii) + + + src_x_indices = src_radii * W + 1.0 + src_y_indices = np.mod((np.arctan2(normalised_dest_indices[..., 1], normalised_dest_indices[..., 0]) / + 2.0 / np.pi) * H, H) + 1.0 + + return 
src_x_indices.astype(np.float32), src_y_indices.astype(np.float32) + + +nn.TanhPolar = TanhPolar \ No newline at end of file diff --git a/core/leras/layers/__init__.py b/core/leras/layers/__init__.py index 8b35ffe..4accaf6 100644 --- a/core/leras/layers/__init__.py +++ b/core/leras/layers/__init__.py @@ -3,10 +3,16 @@ from .LayerBase import * from .Conv2D import * from .Conv2DTranspose import * +from .DepthwiseConv2D import * from .Dense import * from .BlurPool import * from .BatchNorm2D import * +from .InstanceNorm2D import * from .FRNorm2D import * -from .TLU import * \ No newline at end of file +from .TLU import * +from .ScaleAdd import * +from .DenseNorm import * +from .AdaIN import * +from .TanhPolar import * \ No newline at end of file diff --git a/core/leras/models/ModelBase.py b/core/leras/models/ModelBase.py index d96e03f..cc558a4 100644 --- a/core/leras/models/ModelBase.py +++ b/core/leras/models/ModelBase.py @@ -18,6 +18,10 @@ class ModelBase(nn.Saveable): if isinstance (layer, list): for i,sublayer in enumerate(layer): self._build_sub(sublayer, f"{name}_{i}") + elif isinstance (layer, dict): + for subname in layer.keys(): + sublayer = layer[subname] + self._build_sub(sublayer, f"{name}_{subname}") elif isinstance (layer, nn.LayerBase) or \ isinstance (layer, ModelBase): @@ -32,7 +36,7 @@ class ModelBase(nn.Saveable): self.layers.append (layer) self.layers_by_name[layer.name] = layer - + def xor_list(self, lst1, lst2): return [value for value in lst1+lst2 if (value not in lst1) or (value not in lst2) ] @@ -79,7 +83,7 @@ class ModelBase(nn.Saveable): def get_layer_by_name(self, name): return self.layers_by_name.get(name, None) - + def get_layers(self): if not self.built: self.build() @@ -112,41 +116,32 @@ class ModelBase(nn.Saveable): return self.forward(*args, **kwargs) - def compute_output_shape(self, shapes): - if not self.built: - self.build() + # def compute_output_shape(self, shapes): + # if not self.built: + # self.build() - not_list = False - if not isinstance(shapes, list): - not_list = True - shapes = [shapes] + # not_list = False + # if not isinstance(shapes, list): + # not_list = True + # shapes = [shapes] - with tf.device('/CPU:0'): - # CPU tensors will not impact any performance, only slightly RAM "leakage" - phs = [] - for dtype,sh in shapes: - phs += [ tf.placeholder(dtype, sh) ] + # with tf.device('/CPU:0'): + # # CPU tensors will not impact any performance, only slightly RAM "leakage" + # phs = [] + # for dtype,sh in shapes: + # phs += [ tf.placeholder(dtype, sh) ] - result = self.__call__(phs[0] if not_list else phs) + # result = self.__call__(phs[0] if not_list else phs) - if not isinstance(result, list): - result = [result] + # if not isinstance(result, list): + # result = [result] - result_shapes = [] + # result_shapes = [] - for t in result: - result_shapes += [ t.shape.as_list() ] + # for t in result: + # result_shapes += [ t.shape.as_list() ] - return result_shapes[0] if not_list else result_shapes - - def compute_output_channels(self, shapes): - shape = self.compute_output_shape(shapes) - shape_len = len(shape) - - if shape_len == 4: - if nn.data_format == "NCHW": - return shape[1] - return shape[-1] + # return result_shapes[0] if not_list else result_shapes def build_for_run(self, shapes_list): if not isinstance(shapes_list, list): diff --git a/core/leras/models/PatchDiscriminator.py b/core/leras/models/PatchDiscriminator.py index d2bd44a..9b94e9f 100644 --- a/core/leras/models/PatchDiscriminator.py +++ b/core/leras/models/PatchDiscriminator.py @@ -1,7 
+1,7 @@ +import numpy as np from core.leras import nn tf = nn.tf - patch_discriminator_kernels = \ { 1 : (512, [ [1,1] ]), 2 : (512, [ [2,1] ]), @@ -12,7 +12,7 @@ patch_discriminator_kernels = \ 7 : (512, [ [3,2], [3,2] ]), 8 : (512, [ [4,2], [3,2] ]), 9 : (512, [ [3,2], [4,2] ]), - 10 : (512, [ [4,2], [4,2] ]), + 10 : (512, [ [4,2], [4,2] ]), 11 : (512, [ [3,2], [3,2], [2,1] ]), 12 : (512, [ [4,2], [3,2], [2,1] ]), 13 : (512, [ [3,2], [4,2], [2,1] ]), @@ -20,42 +20,50 @@ patch_discriminator_kernels = \ 15 : (512, [ [3,2], [3,2], [3,1] ]), 16 : (512, [ [4,2], [3,2], [3,1] ]), 17 : (512, [ [3,2], [4,2], [3,1] ]), - 18 : (512, [ [4,2], [4,2], [3,1] ]), + 18 : (512, [ [4,2], [4,2], [3,1] ]), 19 : (512, [ [3,2], [3,2], [4,1] ]), 20 : (512, [ [4,2], [3,2], [4,1] ]), 21 : (512, [ [3,2], [4,2], [4,1] ]), - 22 : (512, [ [4,2], [4,2], [4,1] ]), + 22 : (512, [ [4,2], [4,2], [4,1] ]), 23 : (256, [ [3,2], [3,2], [3,2], [2,1] ]), 24 : (256, [ [4,2], [3,2], [3,2], [2,1] ]), 25 : (256, [ [3,2], [4,2], [3,2], [2,1] ]), - 26 : (256, [ [4,2], [4,2], [3,2], [2,1] ]), - 27 : (256, [ [3,2], [4,2], [4,2], [2,1] ]), + 26 : (256, [ [4,2], [4,2], [3,2], [2,1] ]), + 27 : (256, [ [3,2], [4,2], [4,2], [2,1] ]), 28 : (256, [ [4,2], [3,2], [4,2], [2,1] ]), 29 : (256, [ [3,2], [4,2], [4,2], [2,1] ]), - 30 : (256, [ [4,2], [4,2], [4,2], [2,1] ]), + 30 : (256, [ [4,2], [4,2], [4,2], [2,1] ]), 31 : (256, [ [3,2], [3,2], [3,2], [3,1] ]), 32 : (256, [ [4,2], [3,2], [3,2], [3,1] ]), 33 : (256, [ [3,2], [4,2], [3,2], [3,1] ]), - 34 : (256, [ [4,2], [4,2], [3,2], [3,1] ]), - 35 : (256, [ [3,2], [4,2], [4,2], [3,1] ]), + 34 : (256, [ [4,2], [4,2], [3,2], [3,1] ]), + 35 : (256, [ [3,2], [4,2], [4,2], [3,1] ]), 36 : (256, [ [4,2], [3,2], [4,2], [3,1] ]), 37 : (256, [ [3,2], [4,2], [4,2], [3,1] ]), 38 : (256, [ [4,2], [4,2], [4,2], [3,1] ]), + 39 : (256, [ [3,2], [3,2], [3,2], [4,1] ]), + 40 : (256, [ [4,2], [3,2], [3,2], [4,1] ]), + 41 : (256, [ [3,2], [4,2], [3,2], [4,1] ]), + 42 : (256, [ [4,2], [4,2], [3,2], [4,1] ]), + 43 : (256, [ [3,2], [4,2], [4,2], [4,1] ]), + 44 : (256, [ [4,2], [3,2], [4,2], [4,1] ]), + 45 : (256, [ [3,2], [4,2], [4,2], [4,1] ]), + 46 : (256, [ [4,2], [4,2], [4,2], [4,1] ]), } class PatchDiscriminator(nn.ModelBase): - def on_build(self, patch_size, in_ch, base_ch=None, conv_kernel_initializer=None): + def on_build(self, patch_size, in_ch, base_ch=None, conv_kernel_initializer=None): suggested_base_ch, kernels_strides = patch_discriminator_kernels[patch_size] - + if base_ch is None: base_ch = suggested_base_ch - + prev_ch = in_ch self.convs = [] - for i, (kernel_size, strides) in enumerate(kernels_strides): + for i, (kernel_size, strides) in enumerate(kernels_strides): cur_ch = base_ch * min( (2**i), 8 ) - + self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=conv_kernel_initializer) ) prev_ch = cur_ch @@ -66,4 +74,121 @@ class PatchDiscriminator(nn.ModelBase): x = tf.nn.leaky_relu( conv(x), 0.1 ) return self.out_conv(x) -nn.PatchDiscriminator = PatchDiscriminator \ No newline at end of file +nn.PatchDiscriminator = PatchDiscriminator + +class UNetPatchDiscriminator(nn.ModelBase): + """ + Inspired by https://arxiv.org/abs/2002.12655 "A U-Net Based Discriminator for Generative Adversarial Networks" + """ + def calc_receptive_field_size(self, layers): + """ + result the same as https://fomoro.com/research/article/receptive-field-calculatorindex.html + """ + rf = 0 + ts = 1 + for i, (k, s) in enumerate(layers): + if i == 0: + rf = k + 
else: + rf += (k-1)*ts + ts *= s + return rf + + def find_archi(self, target_patch_size, max_layers=9): + """ + Find the best configuration of layers using only 3x3 convs for target patch size + """ + s = {} + for layers_count in range(1,max_layers+1): + val = 1 << (layers_count-1) + while True: + val -= 1 + + layers = [] + sum_st = 0 + layers.append ( [3, 2]) + sum_st += 2 + for i in range(layers_count-1): + st = 1 + (1 if val & (1 << i) !=0 else 0 ) + layers.append ( [3, st ]) + sum_st += st + + rf = self.calc_receptive_field_size(layers) + + s_rf = s.get(rf, None) + if s_rf is None: + s[rf] = (layers_count, sum_st, layers) + else: + if layers_count < s_rf[0] or \ + ( layers_count == s_rf[0] and sum_st > s_rf[1] ): + s[rf] = (layers_count, sum_st, layers) + + if val == 0: + break + + x = sorted(list(s.keys())) + q=x[np.abs(np.array(x)-target_patch_size).argmin()] + return s[q][2] + + def on_build(self, patch_size, in_ch, base_ch = 16, use_fp16 = False): + self.use_fp16 = use_fp16 + conv_dtype = tf.float16 if use_fp16 else tf.float32 + + class ResidualBlock(nn.ModelBase): + def on_build(self, ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + + def forward(self, inp): + x = self.conv1(inp) + x = tf.nn.leaky_relu(x, 0.2) + x = self.conv2(x) + x = tf.nn.leaky_relu(inp + x, 0.2) + return x + + prev_ch = in_ch + self.convs = [] + self.upconvs = [] + layers = self.find_archi(patch_size) + + level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) } + + self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID', dtype=conv_dtype) + + for i, (kernel_size, strides) in enumerate(layers): + self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) ) + + self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) ) + + self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype) + + self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID', dtype=conv_dtype) + self.center_conv = nn.Conv2D( level_chs[len(layers)-1], level_chs[len(layers)-1], kernel_size=1, padding='VALID', dtype=conv_dtype) + + + def forward(self, x): + if self.use_fp16: + x = tf.cast(x, tf.float16) + + x = tf.nn.leaky_relu( self.in_conv(x), 0.2 ) + + encs = [] + for conv in self.convs: + encs.insert(0, x) + x = tf.nn.leaky_relu( conv(x), 0.2 ) + + center_out, x = self.center_out(x), tf.nn.leaky_relu( self.center_conv(x), 0.2 ) + + for i, (upconv, enc) in enumerate(zip(self.upconvs, encs)): + x = tf.nn.leaky_relu( upconv(x), 0.2 ) + x = tf.concat( [enc, x], axis=nn.conv2d_ch_axis) + + x = self.out_conv(x) + + if self.use_fp16: + center_out = tf.cast(center_out, tf.float32) + x = tf.cast(x, tf.float32) + + return center_out, x + +nn.UNetPatchDiscriminator = UNetPatchDiscriminator diff --git a/core/leras/models/Ternaus.py b/core/leras/models/Ternaus.py deleted file mode 100644 index ad5ffc3..0000000 --- a/core/leras/models/Ternaus.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -using https://github.com/ternaus/TernausNet -TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation -""" - -from core.leras import nn -tf = nn.tf - -class 
Ternaus(nn.ModelBase): - def on_build(self, in_ch, base_ch): - - self.features_0 = nn.Conv2D (in_ch, base_ch, kernel_size=3, padding='SAME') - self.features_3 = nn.Conv2D (base_ch, base_ch*2, kernel_size=3, padding='SAME') - self.features_6 = nn.Conv2D (base_ch*2, base_ch*4, kernel_size=3, padding='SAME') - self.features_8 = nn.Conv2D (base_ch*4, base_ch*4, kernel_size=3, padding='SAME') - self.features_11 = nn.Conv2D (base_ch*4, base_ch*8, kernel_size=3, padding='SAME') - self.features_13 = nn.Conv2D (base_ch*8, base_ch*8, kernel_size=3, padding='SAME') - self.features_16 = nn.Conv2D (base_ch*8, base_ch*8, kernel_size=3, padding='SAME') - self.features_18 = nn.Conv2D (base_ch*8, base_ch*8, kernel_size=3, padding='SAME') - - self.blurpool_0 = nn.BlurPool (filt_size=3) - self.blurpool_3 = nn.BlurPool (filt_size=3) - self.blurpool_8 = nn.BlurPool (filt_size=3) - self.blurpool_13 = nn.BlurPool (filt_size=3) - self.blurpool_18 = nn.BlurPool (filt_size=3) - - self.conv_center = nn.Conv2D (base_ch*8, base_ch*8, kernel_size=3, padding='SAME') - - self.conv1_up = nn.Conv2DTranspose (base_ch*8, base_ch*4, kernel_size=3, padding='SAME') - self.conv1 = nn.Conv2D (base_ch*12, base_ch*8, kernel_size=3, padding='SAME') - - self.conv2_up = nn.Conv2DTranspose (base_ch*8, base_ch*4, kernel_size=3, padding='SAME') - self.conv2 = nn.Conv2D (base_ch*12, base_ch*8, kernel_size=3, padding='SAME') - - self.conv3_up = nn.Conv2DTranspose (base_ch*8, base_ch*2, kernel_size=3, padding='SAME') - self.conv3 = nn.Conv2D (base_ch*6, base_ch*4, kernel_size=3, padding='SAME') - - self.conv4_up = nn.Conv2DTranspose (base_ch*4, base_ch, kernel_size=3, padding='SAME') - self.conv4 = nn.Conv2D (base_ch*3, base_ch*2, kernel_size=3, padding='SAME') - - self.conv5_up = nn.Conv2DTranspose (base_ch*2, base_ch//2, kernel_size=3, padding='SAME') - self.conv5 = nn.Conv2D (base_ch//2+base_ch, base_ch, kernel_size=3, padding='SAME') - - self.out_conv = nn.Conv2D (base_ch, 1, kernel_size=3, padding='SAME') - - def forward(self, inp): - x, = inp - - x = x0 = tf.nn.relu(self.features_0(x)) - x = self.blurpool_0(x) - - x = x1 = tf.nn.relu(self.features_3(x)) - x = self.blurpool_3(x) - - x = tf.nn.relu(self.features_6(x)) - x = x2 = tf.nn.relu(self.features_8(x)) - x = self.blurpool_8(x) - - x = tf.nn.relu(self.features_11(x)) - x = x3 = tf.nn.relu(self.features_13(x)) - x = self.blurpool_13(x) - - x = tf.nn.relu(self.features_16(x)) - x = x4 = tf.nn.relu(self.features_18(x)) - x = self.blurpool_18(x) - - x = self.conv_center(x) - - x = tf.nn.relu(self.conv1_up(x)) - x = tf.concat( [x,x4], nn.conv2d_ch_axis) - x = tf.nn.relu(self.conv1(x)) - - x = tf.nn.relu(self.conv2_up(x)) - x = tf.concat( [x,x3], nn.conv2d_ch_axis) - x = tf.nn.relu(self.conv2(x)) - - x = tf.nn.relu(self.conv3_up(x)) - x = tf.concat( [x,x2], nn.conv2d_ch_axis) - x = tf.nn.relu(self.conv3(x)) - - x = tf.nn.relu(self.conv4_up(x)) - x = tf.concat( [x,x1], nn.conv2d_ch_axis) - x = tf.nn.relu(self.conv4(x)) - - x = tf.nn.relu(self.conv5_up(x)) - x = tf.concat( [x,x0], nn.conv2d_ch_axis) - x = tf.nn.relu(self.conv5(x)) - - logits = self.out_conv(x) - return logits, tf.nn.sigmoid(logits) - -nn.Ternaus = Ternaus \ No newline at end of file diff --git a/core/leras/models/XSeg.py b/core/leras/models/XSeg.py index 0ba19a6..f59eb8c 100644 --- a/core/leras/models/XSeg.py +++ b/core/leras/models/XSeg.py @@ -28,11 +28,12 @@ class XSeg(nn.ModelBase): x = self.frn(x) x = self.tlu(x) return x + + self.base_ch = base_ch self.conv01 = ConvBlock(in_ch, base_ch) self.conv02 = 
ConvBlock(base_ch, base_ch) - self.bp0 = nn.BlurPool (filt_size=3) - + self.bp0 = nn.BlurPool (filt_size=4) self.conv11 = ConvBlock(base_ch, base_ch*2) self.conv12 = ConvBlock(base_ch*2, base_ch*2) @@ -40,19 +41,30 @@ class XSeg(nn.ModelBase): self.conv21 = ConvBlock(base_ch*2, base_ch*4) self.conv22 = ConvBlock(base_ch*4, base_ch*4) - self.conv23 = ConvBlock(base_ch*4, base_ch*4) - self.bp2 = nn.BlurPool (filt_size=3) - + self.bp2 = nn.BlurPool (filt_size=2) self.conv31 = ConvBlock(base_ch*4, base_ch*8) self.conv32 = ConvBlock(base_ch*8, base_ch*8) self.conv33 = ConvBlock(base_ch*8, base_ch*8) - self.bp3 = nn.BlurPool (filt_size=3) + self.bp3 = nn.BlurPool (filt_size=2) self.conv41 = ConvBlock(base_ch*8, base_ch*8) self.conv42 = ConvBlock(base_ch*8, base_ch*8) self.conv43 = ConvBlock(base_ch*8, base_ch*8) - self.bp4 = nn.BlurPool (filt_size=3) + self.bp4 = nn.BlurPool (filt_size=2) + + self.conv51 = ConvBlock(base_ch*8, base_ch*8) + self.conv52 = ConvBlock(base_ch*8, base_ch*8) + self.conv53 = ConvBlock(base_ch*8, base_ch*8) + self.bp5 = nn.BlurPool (filt_size=2) + + self.dense1 = nn.Dense ( 4*4* base_ch*8, 512) + self.dense2 = nn.Dense ( 512, 4*4* base_ch*8) + + self.up5 = UpConvBlock (base_ch*8, base_ch*4) + self.uconv53 = ConvBlock(base_ch*12, base_ch*8) + self.uconv52 = ConvBlock(base_ch*8, base_ch*8) + self.uconv51 = ConvBlock(base_ch*8, base_ch*8) self.up4 = UpConvBlock (base_ch*8, base_ch*4) self.uconv43 = ConvBlock(base_ch*12, base_ch*8) @@ -65,8 +77,7 @@ class XSeg(nn.ModelBase): self.uconv31 = ConvBlock(base_ch*8, base_ch*8) self.up2 = UpConvBlock (base_ch*8, base_ch*4) - self.uconv23 = ConvBlock(base_ch*8, base_ch*4) - self.uconv22 = ConvBlock(base_ch*4, base_ch*4) + self.uconv22 = ConvBlock(base_ch*8, base_ch*4) self.uconv21 = ConvBlock(base_ch*4, base_ch*4) self.up1 = UpConvBlock (base_ch*4, base_ch*2) @@ -77,10 +88,9 @@ class XSeg(nn.ModelBase): self.uconv02 = ConvBlock(base_ch*2, base_ch) self.uconv01 = ConvBlock(base_ch, base_ch) self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME') + - self.conv_center = ConvBlock(base_ch*8, base_ch*8) - - def forward(self, inp): + def forward(self, inp, pretrain=False): x = inp x = self.conv01(x) @@ -92,8 +102,7 @@ class XSeg(nn.ModelBase): x = self.bp1(x) x = self.conv21(x) - x = self.conv22(x) - x = x2 = self.conv23(x) + x = x2 = self.conv22(x) x = self.bp2(x) x = self.conv31(x) @@ -106,28 +115,52 @@ class XSeg(nn.ModelBase): x = x4 = self.conv43(x) x = self.bp4(x) - x = self.conv_center(x) - + x = self.conv51(x) + x = self.conv52(x) + x = x5 = self.conv53(x) + x = self.bp5(x) + + x = nn.flatten(x) + x = self.dense1(x) + x = self.dense2(x) + x = nn.reshape_4D (x, 4, 4, self.base_ch*8 ) + + x = self.up5(x) + if pretrain: + x5 = tf.zeros_like(x5) + x = self.uconv53(tf.concat([x,x5],axis=nn.conv2d_ch_axis)) + x = self.uconv52(x) + x = self.uconv51(x) + x = self.up4(x) + if pretrain: + x4 = tf.zeros_like(x4) x = self.uconv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis)) x = self.uconv42(x) x = self.uconv41(x) x = self.up3(x) + if pretrain: + x3 = tf.zeros_like(x3) x = self.uconv33(tf.concat([x,x3],axis=nn.conv2d_ch_axis)) x = self.uconv32(x) x = self.uconv31(x) x = self.up2(x) - x = self.uconv23(tf.concat([x,x2],axis=nn.conv2d_ch_axis)) - x = self.uconv22(x) + if pretrain: + x2 = tf.zeros_like(x2) + x = self.uconv22(tf.concat([x,x2],axis=nn.conv2d_ch_axis)) x = self.uconv21(x) x = self.up1(x) + if pretrain: + x1 = tf.zeros_like(x1) x = self.uconv12(tf.concat([x,x1],axis=nn.conv2d_ch_axis)) x = self.uconv11(x) x = 
self.up0(x) + if pretrain: + x0 = tf.zeros_like(x0) x = self.uconv02(tf.concat([x,x0],axis=nn.conv2d_ch_axis)) x = self.uconv01(x) diff --git a/core/leras/models/__init__.py b/core/leras/models/__init__.py index 9db94fa..2f7e545 100644 --- a/core/leras/models/__init__.py +++ b/core/leras/models/__init__.py @@ -1,5 +1,4 @@ from .ModelBase import * from .PatchDiscriminator import * from .CodeDiscriminator import * -from .Ternaus import * from .XSeg import * \ No newline at end of file diff --git a/core/leras/nn.py b/core/leras/nn.py index 6504698..f392aaf 100644 --- a/core/leras/nn.py +++ b/core/leras/nn.py @@ -33,8 +33,8 @@ class nn(): tf = None tf_sess = None tf_sess_config = None - tf_default_device = None - + tf_default_device_name = None + data_format = None conv2d_ch_axis = None conv2d_spatial_axes = None @@ -50,9 +50,6 @@ class nn(): nn.setCurrentDeviceConfig(device_config) # Manipulate environment variables before import tensorflow - - if 'CUDA_VISIBLE_DEVICES' in os.environ.keys(): - os.environ.pop('CUDA_VISIBLE_DEVICES') first_run = False if len(device_config.devices) != 0: @@ -68,21 +65,32 @@ class nn(): compute_cache_path = Path(os.environ['APPDATA']) / 'NVIDIA' / ('ComputeCache' + devices_str) if not compute_cache_path.exists(): first_run = True + compute_cache_path.mkdir(parents=True, exist_ok=True) os.environ['CUDA_CACHE_PATH'] = str(compute_cache_path) - - os.environ['CUDA_​CACHE_​MAXSIZE'] = '536870912' #512Mb (32mb default) - os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2' - os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # tf log errors only - + if first_run: io.log_info("Caching GPU kernels...") - import tensorflow as tf - nn.tf = tf - + import tensorflow + + tf_version = tensorflow.version.VERSION + #if tf_version is None: + # tf_version = tensorflow.version.GIT_VERSION + if tf_version[0] == 'v': + tf_version = tf_version[1:] + if tf_version[0] == '2': + tf = tensorflow.compat.v1 + else: + tf = tensorflow + import logging # Disable tensorflow warnings - logging.getLogger('tensorflow').setLevel(logging.ERROR) + tf_logger = logging.getLogger('tensorflow') + tf_logger.setLevel(logging.ERROR) + + if tf_version[0] == '2': + tf.disable_v2_behavior() + nn.tf = tf # Initialize framework import core.leras.ops @@ -94,13 +102,14 @@ class nn(): # Configure tensorflow session-config if len(device_config.devices) == 0: - nn.tf_default_device = "/CPU:0" config = tf.ConfigProto(device_count={'GPU': 0}) + nn.tf_default_device_name = '/CPU:0' else: - nn.tf_default_device = "/GPU:0" + nn.tf_default_device_name = f'/{device_config.devices[0].tf_dev_type}:0' + config = tf.ConfigProto() config.gpu_options.visible_device_list = ','.join([str(device.index) for device in device_config.devices]) - + config.gpu_options.force_gpu_compatible = True config.gpu_options.allow_growth = True nn.tf_sess_config = config @@ -188,14 +197,6 @@ class nn(): nn.tf_sess.close() nn.tf_sess = None - @staticmethod - def get_current_device(): - # Undocumented access to last tf.device(...) 
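Illustrative sketch, not part of this patch: the removed get_current_device() helper is superseded by the precomputed nn.tf_default_device_name string, which the initialization code above sets to '/CPU:0' for CPU-only runs or to '/<tf_dev_type>:0' for the first configured device. Call sites pin ops to it directly, as FaceEnhancer does later in this diff; a minimal usage sketch, assuming the leras framework has already been initialized:

    from core.leras import nn
    tf = nn.tf

    # after initialization, nn.tf_default_device_name is e.g. '/GPU:0' or '/CPU:0'
    with tf.device(nn.tf_default_device_name):
        x = tf.placeholder(tf.float32, (None, 256, 256, 3))  # ops built here are pinned to that device
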
- objs = nn.tf.get_default_graph()._device_function_stack.peek_objs() - if len(objs) != 0: - return objs[0].display_name - return nn.tf_default_device - @staticmethod def ask_choose_device_idxs(choose_only_one=False, allow_cpu=True, suggest_best_multi_gpu=False, suggest_all_gpu=False): devices = Devices.getDevices() diff --git a/core/leras/ops/__init__.py b/core/leras/ops/__init__.py index 4931488..bd690da 100644 --- a/core/leras/ops/__init__.py +++ b/core/leras/ops/__init__.py @@ -56,7 +56,7 @@ def tf_gradients ( loss, vars ): gv = [*zip(grads,vars)] for g,v in gv: if g is None: - raise Exception(f"No gradient for variable {v.name}") + raise Exception(f"Variable {v.name} is declared as trainable, but no tensors flow through it.") return gv nn.gradients = tf_gradients @@ -108,10 +108,15 @@ nn.gelu = gelu def upsample2d(x, size=2): if nn.data_format == "NCHW": - b,c,h,w = x.shape.as_list() - x = tf.reshape (x, (-1,c,h,1,w,1) ) - x = tf.tile(x, (1,1,1,size,1,size) ) - x = tf.reshape (x, (-1,c,h*size,w*size) ) + x = tf.transpose(x, (0,2,3,1)) + x = tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) ) + x = tf.transpose(x, (0,3,1,2)) + + + # b,c,h,w = x.shape.as_list() + # x = tf.reshape (x, (-1,c,h,1,w,1) ) + # x = tf.tile(x, (1,1,1,size,1,size) ) + # x = tf.reshape (x, (-1,c,h*size,w*size) ) return x else: return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) ) @@ -120,25 +125,56 @@ nn.upsample2d = upsample2d def resize2d_bilinear(x, size=2): h = x.shape[nn.conv2d_spatial_axes[0]].value w = x.shape[nn.conv2d_spatial_axes[1]].value - + if nn.data_format == "NCHW": x = tf.transpose(x, (0,2,3,1)) - + if size > 0: new_size = (h*size,w*size) else: new_size = (h//-size,w//-size) x = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.BILINEAR) - + if nn.data_format == "NCHW": - x = tf.transpose(x, (0,3,1,2)) - - return x + x = tf.transpose(x, (0,3,1,2)) + + return x nn.resize2d_bilinear = resize2d_bilinear +def resize2d_nearest(x, size=2): + if size in [-1,0,1]: + return x + if size > 0: + raise Exception("") + else: + if nn.data_format == "NCHW": + x = x[:,:,::-size,::-size] + else: + x = x[:,::-size,::-size,:] + return x + + h = x.shape[nn.conv2d_spatial_axes[0]].value + w = x.shape[nn.conv2d_spatial_axes[1]].value + + if nn.data_format == "NCHW": + x = tf.transpose(x, (0,2,3,1)) + + if size > 0: + new_size = (h*size,w*size) + else: + new_size = (h//-size,w//-size) + + x = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) + + if nn.data_format == "NCHW": + x = tf.transpose(x, (0,3,1,2)) + + return x +nn.resize2d_nearest = resize2d_nearest + def flatten(x): if nn.data_format == "NHWC": # match NCHW version in order to switch data_format without problems @@ -173,7 +209,7 @@ def random_binomial(shape, p=0.0, dtype=None, seed=None): seed = np.random.randint(10e6) return array_ops.where( random_ops.random_uniform(shape, dtype=tf.float16, seed=seed) < p, - array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype)) + array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype)) nn.random_binomial = random_binomial def gaussian_blur(input, radius=2.0): @@ -181,7 +217,9 @@ def gaussian_blur(input, radius=2.0): return np.exp(-(float(x) - float(mu)) ** 2 / (2 * sigma ** 2)) def make_kernel(sigma): - kernel_size = max(3, int(2 * 2 * sigma + 1)) + kernel_size = max(3, int(2 * 2 * sigma)) + if kernel_size % 2 == 0: + kernel_size += 1 mean = np.floor(0.5 * kernel_size) kernel_1d = np.array([gaussian(x, mean, 
sigma) for x in range(kernel_size)]) np_kernel = np.outer(kernel_1d, kernel_1d).astype(np.float32) @@ -238,6 +276,8 @@ def dssim(img1,img2, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03 img1 = tf.cast(img1, tf.float32) img2 = tf.cast(img2, tf.float32) + filter_size = max(1, filter_size) + kernel = np.arange(0, filter_size, dtype=np.float32) kernel -= (filter_size - 1 ) / 2.0 kernel = kernel**2 @@ -300,7 +340,17 @@ def depth_to_space(x, size): x = tf.reshape(x, (-1, oh, ow, oc, )) return x else: - return tf.depth_to_space(x, size, data_format=nn.data_format) + cfg = nn.getCurrentDeviceConfig() + if not cfg.cpu_only: + return tf.depth_to_space(x, size, data_format=nn.data_format) + b,c,h,w = x.shape.as_list() + oh, ow = h * size, w * size + oc = c // (size * size) + + x = tf.reshape(x, (-1, size, size, oc, h, w, ) ) + x = tf.transpose(x, (0, 3, 4, 1, 5, 2)) + x = tf.reshape(x, (-1, oc, oh, ow)) + return x nn.depth_to_space = depth_to_space def rgb_to_lab(srgb): @@ -333,6 +383,23 @@ def rgb_to_lab(srgb): return tf.reshape(lab_pixels, tf.shape(srgb)) nn.rgb_to_lab = rgb_to_lab +def total_variation_mse(images): + """ + Same as generic total_variation, but MSE diff instead of MAE + """ + pixel_dif1 = images[:, 1:, :, :] - images[:, :-1, :, :] + pixel_dif2 = images[:, :, 1:, :] - images[:, :, :-1, :] + + tot_var = ( tf.reduce_sum(tf.square(pixel_dif1), axis=[1,2,3]) + + tf.reduce_sum(tf.square(pixel_dif2), axis=[1,2,3]) ) + return tot_var +nn.total_variation_mse = total_variation_mse + + +def pixel_norm(x, axes): + return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=axes, keepdims=True) + 1e-06) +nn.pixel_norm = pixel_norm + """ def tf_suppress_lower_mean(t, eps=0.00001): if t.shape.ndims != 1: @@ -342,4 +409,70 @@ def tf_suppress_lower_mean(t, eps=0.00001): q = tf.clip_by_value(q-t_mean_eps, 0, eps) q = q * (t/eps) return q -""" \ No newline at end of file +""" + + + +def _get_pixel_value(img, x, y): + shape = tf.shape(x) + batch_size = shape[0] + height = shape[1] + width = shape[2] + + batch_idx = tf.range(0, batch_size) + batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1)) + b = tf.tile(batch_idx, (1, height, width)) + + indices = tf.stack([b, y, x], 3) + + return tf.gather_nd(img, indices) + +def bilinear_sampler(img, x, y): + H = tf.shape(img)[1] + W = tf.shape(img)[2] + H_MAX = tf.cast(H - 1, tf.int32) + W_MAX = tf.cast(W - 1, tf.int32) + + # grab 4 nearest corner points for each (x_i, y_i) + x0 = tf.cast(tf.floor(x), tf.int32) + x1 = x0 + 1 + y0 = tf.cast(tf.floor(y), tf.int32) + y1 = y0 + 1 + + # clip to range [0, H-1/W-1] to not violate img boundaries + x0 = tf.clip_by_value(x0, 0, W_MAX) + x1 = tf.clip_by_value(x1, 0, W_MAX) + y0 = tf.clip_by_value(y0, 0, H_MAX) + y1 = tf.clip_by_value(y1, 0, H_MAX) + + # get pixel value at corner coords + Ia = _get_pixel_value(img, x0, y0) + Ib = _get_pixel_value(img, x0, y1) + Ic = _get_pixel_value(img, x1, y0) + Id = _get_pixel_value(img, x1, y1) + + # recast as float for delta calculation + x0 = tf.cast(x0, tf.float32) + x1 = tf.cast(x1, tf.float32) + y0 = tf.cast(y0, tf.float32) + y1 = tf.cast(y1, tf.float32) + + # calculate deltas + wa = (x1-x) * (y1-y) + wb = (x1-x) * (y-y0) + wc = (x-x0) * (y1-y) + wd = (x-x0) * (y-y0) + + # add dimension for addition + wa = tf.expand_dims(wa, axis=3) + wb = tf.expand_dims(wb, axis=3) + wc = tf.expand_dims(wc, axis=3) + wd = tf.expand_dims(wd, axis=3) + + # compute output + out = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id]) + + return out + +nn.bilinear_sampler = bilinear_sampler + diff --git 
a/core/leras/optimizers/AdaBelief.py b/core/leras/optimizers/AdaBelief.py new file mode 100644 index 0000000..da6e1a2 --- /dev/null +++ b/core/leras/optimizers/AdaBelief.py @@ -0,0 +1,81 @@ +import numpy as np +from core.leras import nn +from tensorflow.python.ops import control_flow_ops, state_ops + +tf = nn.tf + +class AdaBelief(nn.OptimizerBase): + def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, lr_dropout=1.0, lr_cos=0, clipnorm=0.0, name=None, **kwargs): + super().__init__(name=name) + + if name is None: + raise ValueError('name must be defined.') + + self.lr = lr + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.lr_dropout = lr_dropout + self.lr_cos = lr_cos + self.clipnorm = clipnorm + + with tf.device('/CPU:0') : + with tf.variable_scope(self.name): + self.iterations = tf.Variable(0, dtype=tf.int64, name='iters') + + self.ms_dict = {} + self.vs_dict = {} + self.lr_rnds_dict = {} + + def get_weights(self): + return [self.iterations] + list(self.ms_dict.values()) + list(self.vs_dict.values()) + + def initialize_variables(self, trainable_weights, vars_on_cpu=True, lr_dropout_on_cpu=False): + # Initialize here all trainable variables used in training + e = tf.device('/CPU:0') if vars_on_cpu else None + if e: e.__enter__() + with tf.variable_scope(self.name): + ms = { v.name : tf.get_variable ( f'ms_{v.name}'.replace(':','_'), v.shape, dtype=v.dtype, initializer=tf.initializers.constant(0.0), trainable=False) for v in trainable_weights } + vs = { v.name : tf.get_variable ( f'vs_{v.name}'.replace(':','_'), v.shape, dtype=v.dtype, initializer=tf.initializers.constant(0.0), trainable=False) for v in trainable_weights } + self.ms_dict.update (ms) + self.vs_dict.update (vs) + + if self.lr_dropout != 1.0: + e = tf.device('/CPU:0') if lr_dropout_on_cpu else None + if e: e.__enter__() + lr_rnds = [ nn.random_binomial( v.shape, p=self.lr_dropout, dtype=v.dtype) for v in trainable_weights ] + if e: e.__exit__(None, None, None) + self.lr_rnds_dict.update ( { v.name : rnd for v,rnd in zip(trainable_weights,lr_rnds) } ) + if e: e.__exit__(None, None, None) + + def get_update_op(self, grads_vars): + updates = [] + + if self.clipnorm > 0.0: + norm = tf.sqrt( sum([tf.reduce_sum(tf.square(tf.cast(g, tf.float32))) for g,v in grads_vars])) + updates += [ state_ops.assign_add( self.iterations, 1) ] + for i, (g,v) in enumerate(grads_vars): + if self.clipnorm > 0.0: + g = self.tf_clip_norm(g, self.clipnorm, tf.cast(norm, g.dtype) ) + + ms = self.ms_dict[ v.name ] + vs = self.vs_dict[ v.name ] + + m_t = self.beta_1*ms + (1.0-self.beta_1) * g + v_t = self.beta_2*vs + (1.0-self.beta_2) * tf.square(g-m_t) + + lr = tf.constant(self.lr, g.dtype) + if self.lr_cos != 0: + lr *= (tf.cos( tf.cast(self.iterations, g.dtype) * (2*3.1415926535/ float(self.lr_cos) ) ) + 1.0) / 2.0 + + v_diff = - lr * m_t / (tf.sqrt(v_t) + np.finfo( g.dtype.as_numpy_dtype ).resolution ) + if self.lr_dropout != 1.0: + lr_rnd = self.lr_rnds_dict[v.name] + v_diff *= lr_rnd + new_v = v + v_diff + + updates.append (state_ops.assign(ms, m_t)) + updates.append (state_ops.assign(vs, v_t)) + updates.append (state_ops.assign(v, new_v)) + + return control_flow_ops.group ( *updates, name=self.name+'_updates') +nn.AdaBelief = AdaBelief diff --git a/core/leras/optimizers/RMSprop.py b/core/leras/optimizers/RMSprop.py index edd4c38..0b20fbf 100644 --- a/core/leras/optimizers/RMSprop.py +++ b/core/leras/optimizers/RMSprop.py @@ -1,31 +1,33 @@ +import numpy as np from tensorflow.python.ops import control_flow_ops, state_ops from core.leras import 
nn tf = nn.tf class RMSprop(nn.OptimizerBase): - def __init__(self, lr=0.001, rho=0.9, lr_dropout=1.0, epsilon=1e-7, clipnorm=0.0, name=None): + def __init__(self, lr=0.001, rho=0.9, lr_dropout=1.0, lr_cos=0, clipnorm=0.0, name=None, **kwargs): super().__init__(name=name) if name is None: raise ValueError('name must be defined.') self.lr_dropout = lr_dropout + self.lr_cos = lr_cos + self.lr = lr + self.rho = rho self.clipnorm = clipnorm with tf.device('/CPU:0') : with tf.variable_scope(self.name): - self.lr = tf.Variable (lr, name="lr") - self.rho = tf.Variable (rho, name="rho") - self.epsilon = tf.Variable (epsilon, name="epsilon") + self.iterations = tf.Variable(0, dtype=tf.int64, name='iters') self.accumulators_dict = {} self.lr_rnds_dict = {} def get_weights(self): - return [self.lr, self.rho, self.epsilon, self.iterations] + list(self.accumulators_dict.values()) + return [self.iterations] + list(self.accumulators_dict.values()) - def initialize_variables(self, trainable_weights, vars_on_cpu=True): + def initialize_variables(self, trainable_weights, vars_on_cpu=True, lr_dropout_on_cpu=False): # Initialize here all trainable variables used in training e = tf.device('/CPU:0') if vars_on_cpu else None if e: e.__enter__() @@ -34,7 +36,10 @@ class RMSprop(nn.OptimizerBase): self.accumulators_dict.update ( accumulators) if self.lr_dropout != 1.0: + e = tf.device('/CPU:0') if lr_dropout_on_cpu else None + if e: e.__enter__() lr_rnds = [ nn.random_binomial( v.shape, p=self.lr_dropout, dtype=v.dtype) for v in trainable_weights ] + if e: e.__exit__(None, None, None) self.lr_rnds_dict.update ( { v.name : rnd for v,rnd in zip(trainable_weights,lr_rnds) } ) if e: e.__exit__(None, None, None) @@ -42,21 +47,21 @@ class RMSprop(nn.OptimizerBase): updates = [] if self.clipnorm > 0.0: - norm = tf.sqrt( sum([tf.reduce_sum(tf.square(g)) for g,v in grads_vars])) + norm = tf.sqrt( sum([tf.reduce_sum(tf.square(tf.cast(g, tf.float32))) for g,v in grads_vars])) updates += [ state_ops.assign_add( self.iterations, 1) ] for i, (g,v) in enumerate(grads_vars): if self.clipnorm > 0.0: - g = self.tf_clip_norm(g, self.clipnorm, norm) + g = self.tf_clip_norm(g, self.clipnorm, tf.cast(norm, g.dtype) ) a = self.accumulators_dict[ v.name ] - rho = tf.cast(self.rho, a.dtype) - new_a = rho * a + (1. - rho) * tf.square(g) + new_a = self.rho * a + (1. 
- self.rho) * tf.square(g) - lr = tf.cast(self.lr, a.dtype) - epsilon = tf.cast(self.epsilon, a.dtype) + lr = tf.constant(self.lr, g.dtype) + if self.lr_cos != 0: + lr *= (tf.cos( tf.cast(self.iterations, g.dtype) * (2*3.1415926535/ float(self.lr_cos) ) ) + 1.0) / 2.0 - v_diff = - lr * g / (tf.sqrt(new_a) + epsilon) + v_diff = - lr * g / (tf.sqrt(new_a) + np.finfo( g.dtype.as_numpy_dtype ).resolution ) if self.lr_dropout != 1.0: lr_rnd = self.lr_rnds_dict[v.name] v_diff *= lr_rnd diff --git a/core/leras/optimizers/__init__.py b/core/leras/optimizers/__init__.py index aec36af..4f8a7e4 100644 --- a/core/leras/optimizers/__init__.py +++ b/core/leras/optimizers/__init__.py @@ -1,2 +1,3 @@ from .OptimizerBase import * -from .RMSprop import * \ No newline at end of file +from .RMSprop import * +from .AdaBelief import * \ No newline at end of file diff --git a/core/mathlib/__init__.py b/core/mathlib/__init__.py index a11e725..7e5fa13 100644 --- a/core/mathlib/__init__.py +++ b/core/mathlib/__init__.py @@ -1,7 +1,12 @@ -import numpy as np import math + +import cv2 +import numpy as np +import numpy.linalg as npla + from .umeyama import umeyama + def get_power_of_two(x): i = 0 while (1 << i) < x: @@ -23,3 +28,70 @@ def rotationMatrixToEulerAngles(R) : def polygon_area(x,y): return 0.5*np.abs(np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1))) + +def rotate_point(origin, point, deg): + """ + Rotate a point counterclockwise by a given angle around a given origin. + + The angle should be given in radians. + """ + ox, oy = origin + px, py = point + + rad = deg * math.pi / 180.0 + qx = ox + math.cos(rad) * (px - ox) - math.sin(rad) * (py - oy) + qy = oy + math.sin(rad) * (px - ox) + math.cos(rad) * (py - oy) + return np.float32([qx, qy]) + +def transform_points(points, mat, invert=False): + if invert: + mat = cv2.invertAffineTransform (mat) + points = np.expand_dims(points, axis=1) + points = cv2.transform(points, mat, points.shape) + points = np.squeeze(points) + return points + + +def transform_mat(mat, res, tx, ty, rotation, scale): + """ + transform mat in local space of res + scale -> translate -> rotate + + tx,ty float + rotation int degrees + scale float + """ + + + lt, rt, lb, ct = transform_points ( np.float32([(0,0),(res,0),(0,res),(res / 2, res/2) ]),mat, True) + + hor_v = (rt-lt).astype(np.float32) + hor_size = npla.norm(hor_v) + hor_v /= hor_size + + ver_v = (lb-lt).astype(np.float32) + ver_size = npla.norm(ver_v) + ver_v /= ver_size + + bt_diag_vec = (rt-ct).astype(np.float32) + half_diag_len = npla.norm(bt_diag_vec) + bt_diag_vec /= half_diag_len + + tb_diag_vec = np.float32( [ -bt_diag_vec[1], bt_diag_vec[0] ] ) + + rt = ct + bt_diag_vec*half_diag_len*scale + lb = ct - bt_diag_vec*half_diag_len*scale + lt = ct - tb_diag_vec*half_diag_len*scale + + rt[0] += tx*hor_size + lb[0] += tx*hor_size + lt[0] += tx*hor_size + rt[1] += ty*ver_size + lb[1] += ty*ver_size + lt[1] += ty*ver_size + + rt = rotate_point(ct, rt, rotation) + lb = rotate_point(ct, lb, rotation) + lt = rotate_point(ct, lt, rotation) + + return cv2.getAffineTransform( np.float32([lt, rt, lb]), np.float32([ [0,0], [res,0], [0,res] ]) ) diff --git a/core/mplib/MPSharedList.py b/core/mplib/MPSharedList.py new file mode 100644 index 0000000..874c56a --- /dev/null +++ b/core/mplib/MPSharedList.py @@ -0,0 +1,111 @@ +import multiprocessing +import pickle +import struct +from core.joblib import Subprocessor + +class MPSharedList(): + """ + Provides read-only pickled list of constant objects via shared memory aka 'multiprocessing.Array' + 
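(each constructed list is pickled once into a single shared byte buffer plus an offset table, and worker processes deserialize entries on demand instead of receiving their own copy). A minimal usage sketch, illustrative only; src_list, part_a and part_b are hypothetical names, not part of this patch:

    from core.mplib.MPSharedList import MPSharedList

    shared = MPSharedList(src_list)   # bake the picklable items into a multiprocessing.Array
    item = shared[5]                  # any process can index it; entries are deserialized on access
    combined = part_a + part_b        # '+' (and sum()) joins the underlying buffers
    print( len(combined) )

Because every worker reads from the same shared buffer, nothing is re-pickled per subprocess.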
Thus no 4GB limit for subprocesses. + + supports list concat via + or sum() + """ + + def __init__(self, obj_list): + if obj_list is None: + self.obj_counts = None + self.table_offsets = None + self.data_offsets = None + self.sh_bs = None + else: + obj_count, table_offset, data_offset, sh_b = MPSharedList.bake_data(obj_list) + + self.obj_counts = [obj_count] + self.table_offsets = [table_offset] + self.data_offsets = [data_offset] + self.sh_bs = [sh_b] + + def __add__(self, o): + if isinstance(o, MPSharedList): + m = MPSharedList(None) + m.obj_counts = self.obj_counts + o.obj_counts + m.table_offsets = self.table_offsets + o.table_offsets + m.data_offsets = self.data_offsets + o.data_offsets + m.sh_bs = self.sh_bs + o.sh_bs + return m + elif isinstance(o, int): + return self + else: + raise ValueError(f"MPSharedList object of class {o.__class__} is not supported for __add__ operator.") + + def __radd__(self, o): + return self+o + + def __len__(self): + return sum(self.obj_counts) + + def __getitem__(self, key): + obj_count = sum(self.obj_counts) + if key < 0: + key = obj_count+key + if key < 0 or key >= obj_count: + raise ValueError("out of range") + + for i in range(len(self.obj_counts)): + + if key < self.obj_counts[i]: + table_offset = self.table_offsets[i] + data_offset = self.data_offsets[i] + sh_b = self.sh_bs[i] + break + key -= self.obj_counts[i] + + sh_b = memoryview(sh_b).cast('B') + + offset_start, offset_end = struct.unpack(' self.no_response_time_sec: + #subprocess busy too long + io.log_info ( '%s doesnt response, terminating it.' % (cli.name) ) + self.on_data_return (cli.host_dict, cli.sent_data ) + cli.kill() + self.clis.remove(cli) + + for cli in self.clis[:]: + if cli.state == 0: + #free state of subprocess, get some data from get_data + data = self.get_data(cli.host_dict) + if data is not None: + #and send it to subprocess + cli.s2c.put ( {'op': 'data', 'data' : data} ) + cli.sent_time = time.time() + cli.sent_data = data + cli.state = 1 + + if all ([cli.state == 0 for cli in self.clis]): + #gracefully terminating subprocesses + for cli in self.clis[:]: + cli.s2c.put ( {'op': 'close'} ) + cli.sent_time = time.time() + + while True: + for cli in self.clis[:]: + terminate_it = False + while not cli.c2s.empty(): + obj = cli.c2s.get() + obj_op = obj['op'] + if obj_op == 'finalized': + terminate_it = True + break + + if (time.time() - cli.sent_time) > 30: + terminate_it = True + + if terminate_it: + cli.state = 2 + cli.kill() + + if all ([cli.state == 2 for cli in self.clis]): + break + + #finalizing host logic + self.q_timer.stop() + self.q_timer = None + self.on_clients_finalized() + diff --git a/core/qtex/QXIconButton.py b/core/qtex/QXIconButton.py new file mode 100644 index 0000000..235d149 --- /dev/null +++ b/core/qtex/QXIconButton.py @@ -0,0 +1,83 @@ +from PyQt5.QtCore import * +from PyQt5.QtGui import * +from PyQt5.QtWidgets import * + +from localization import StringsDB +from .QXMainWindow import * + +class QXIconButton(QPushButton): + """ + Custom Icon button that works through keyEvent system, without shortcut of QAction + works only with QXMainWindow as global window class + currently works only with one-key shortcut + """ + + def __init__(self, icon, + tooltip=None, + shortcut=None, + click_func=None, + first_repeat_delay=300, + repeat_delay=20, + ): + + super().__init__(icon, "") + + self.setIcon(icon) + + if shortcut is not None: + tooltip = f"{tooltip} ( {StringsDB['S_HOT_KEY'] }: {shortcut} )" + + self.setToolTip(tooltip) + + + self.seq = 
QKeySequence(shortcut) if shortcut is not None else None + + QXMainWindow.inst.add_keyPressEvent_listener ( self.on_keyPressEvent ) + QXMainWindow.inst.add_keyReleaseEvent_listener ( self.on_keyReleaseEvent ) + + self.click_func = click_func + self.first_repeat_delay = first_repeat_delay + self.repeat_delay = repeat_delay + self.repeat_timer = None + + self.op_device = None + + self.pressed.connect( lambda : self.action(is_pressed=True) ) + self.released.connect( lambda : self.action(is_pressed=False) ) + + def action(self, is_pressed=None, op_device=None): + if self.click_func is None: + return + + if is_pressed is not None: + if is_pressed: + if self.repeat_timer is None: + self.click_func() + self.repeat_timer = QTimer() + self.repeat_timer.timeout.connect(self.action) + self.repeat_timer.start(self.first_repeat_delay) + else: + if self.repeat_timer is not None: + self.repeat_timer.stop() + self.repeat_timer = None + else: + self.click_func() + if self.repeat_timer is not None: + self.repeat_timer.setInterval(self.repeat_delay) + + def on_keyPressEvent(self, ev): + key = ev.nativeVirtualKey() + if ev.isAutoRepeat(): + return + + if self.seq is not None: + if key == self.seq[0]: + self.action(is_pressed=True) + + def on_keyReleaseEvent(self, ev): + key = ev.nativeVirtualKey() + if ev.isAutoRepeat(): + return + if self.seq is not None: + if key == self.seq[0]: + self.action(is_pressed=False) diff --git a/core/qtex/QXMainWindow.py b/core/qtex/QXMainWindow.py new file mode 100644 index 0000000..a50e597 --- /dev/null +++ b/core/qtex/QXMainWindow.py @@ -0,0 +1,34 @@ +from PyQt5.QtCore import * +from PyQt5.QtGui import * +from PyQt5.QtWidgets import * + +class QXMainWindow(QWidget): + """ + Custom mainwindow class that provides global single instance and event listeners + """ + inst = None + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if QXMainWindow.inst is not None: + raise Exception("QXMainWindow can only be one.") + QXMainWindow.inst = self + + self.keyPressEvent_listeners = [] + self.keyReleaseEvent_listeners = [] + self.setFocusPolicy(Qt.WheelFocus) + + def add_keyPressEvent_listener(self, func): + self.keyPressEvent_listeners.append (func) + + def add_keyReleaseEvent_listener(self, func): + self.keyReleaseEvent_listeners.append (func) + + def keyPressEvent(self, ev): + super().keyPressEvent(ev) + for func in self.keyPressEvent_listeners: + func(ev) + + def keyReleaseEvent(self, ev): + super().keyReleaseEvent(ev) + for func in self.keyReleaseEvent_listeners: + func(ev) \ No newline at end of file diff --git a/core/qtex/__init__.py b/core/qtex/__init__.py new file mode 100644 index 0000000..2cb44b5 --- /dev/null +++ b/core/qtex/__init__.py @@ -0,0 +1,3 @@ +from .qtex import * +from .QSubprocessor import * +from .QXIconButton import * \ No newline at end of file diff --git a/core/qtex/qtex.py b/core/qtex/qtex.py new file mode 100644 index 0000000..d15e41d --- /dev/null +++ b/core/qtex/qtex.py @@ -0,0 +1,80 @@ +import numpy as np +from PyQt5.QtCore import * +from PyQt5.QtGui import * +from PyQt5.QtWidgets import * +from localization import StringsDB + +from .QXMainWindow import * + + +class QActionEx(QAction): + def __init__(self, icon, text, shortcut=None, trigger_func=None, shortcut_in_tooltip=False, is_checkable=False, is_auto_repeat=False ): + super().__init__(icon, text) + if shortcut is not None: + self.setShortcut(shortcut) + if shortcut_in_tooltip: + + self.setToolTip( f"{text} ( {StringsDB['S_HOT_KEY'] }: {shortcut} )") + + if trigger_func is not 
None: + self.triggered.connect(trigger_func) + if is_checkable: + self.setCheckable(True) + self.setAutoRepeat(is_auto_repeat) + +def QImage_from_np(img): + if img.dtype != np.uint8: + raise ValueError("img should be in np.uint8 format") + + h,w,c = img.shape + if c == 1: + fmt = QImage.Format_Grayscale8 + elif c == 3: + fmt = QImage.Format_BGR888 + elif c == 4: + fmt = QImage.Format_ARGB32 + else: + raise ValueError("unsupported channel count") + + return QImage(img.data, w, h, c*w, fmt ) + +def QImage_to_np(q_img, fmt=QImage.Format_BGR888): + q_img = q_img.convertToFormat(fmt) + + width = q_img.width() + height = q_img.height() + + b = q_img.constBits() + b.setsize(height * width * 3) + arr = np.frombuffer(b, np.uint8).reshape((height, width, 3)) + return arr#[::-1] + +def QPixmap_from_np(img): + return QPixmap.fromImage(QImage_from_np(img)) + +def QPoint_from_np(n): + return QPoint(*n.astype(np.int)) + +def QPoint_to_np(q): + return np.int32( [q.x(), q.y()] ) + +def QSize_to_np(q): + return np.int32( [q.width(), q.height()] ) + +class QDarkPalette(QPalette): + def __init__(self): + super().__init__() + text_color = QColor(200,200,200) + self.setColor(QPalette.Window, QColor(53, 53, 53)) + self.setColor(QPalette.WindowText, text_color ) + self.setColor(QPalette.Base, QColor(25, 25, 25)) + self.setColor(QPalette.AlternateBase, QColor(53, 53, 53)) + self.setColor(QPalette.ToolTipBase, text_color ) + self.setColor(QPalette.ToolTipText, text_color ) + self.setColor(QPalette.Text, text_color ) + self.setColor(QPalette.Button, QColor(53, 53, 53)) + self.setColor(QPalette.ButtonText, Qt.white) + self.setColor(QPalette.BrightText, Qt.red) + self.setColor(QPalette.Link, QColor(42, 130, 218)) + self.setColor(QPalette.Highlight, QColor(42, 130, 218)) + self.setColor(QPalette.HighlightedText, Qt.black) \ No newline at end of file diff --git a/core/randomex.py b/core/randomex.py index 9c8dc63..ecd18e2 100644 --- a/core/randomex.py +++ b/core/randomex.py @@ -1,12 +1,14 @@ import numpy as np -def random_normal( size=(1,), trunc_val = 2.5 ): +def random_normal( size=(1,), trunc_val = 2.5, rnd_state=None ): + if rnd_state is None: + rnd_state = np.random len = np.array(size).prod() result = np.empty ( (len,) , dtype=np.float32) for i in range (len): while True: - x = np.random.normal() + x = rnd_state.normal() if x >= -trunc_val and x <= trunc_val: break result[i] = (x / trunc_val) diff --git a/doc/DFL_welcome.jpg b/doc/DFL_welcome.jpg deleted file mode 100644 index 4f362b5..0000000 Binary files a/doc/DFL_welcome.jpg and /dev/null differ diff --git a/doc/DFL_welcome.png b/doc/DFL_welcome.png new file mode 100644 index 0000000..2e4e138 Binary files /dev/null and b/doc/DFL_welcome.png differ diff --git a/doc/DeepFaceLab_is_working.png b/doc/DeepFaceLab_is_working.png deleted file mode 100644 index 4d86d36..0000000 Binary files a/doc/DeepFaceLab_is_working.png and /dev/null differ diff --git a/doc/deage_0_1.jpg b/doc/deage_0_1.jpg new file mode 100644 index 0000000..51e057e Binary files /dev/null and b/doc/deage_0_1.jpg differ diff --git a/doc/deage_0_2.jpg b/doc/deage_0_2.jpg new file mode 100644 index 0000000..996cacd Binary files /dev/null and b/doc/deage_0_2.jpg differ diff --git a/doc/deepfake_progress.png b/doc/deepfake_progress.png new file mode 100644 index 0000000..32d8c7e Binary files /dev/null and b/doc/deepfake_progress.png differ diff --git a/doc/deepfake_progress_source.psd b/doc/deepfake_progress_source.psd new file mode 100644 index 0000000..8954923 Binary files /dev/null and 
b/doc/deepfake_progress_source.psd differ diff --git a/doc/head_replace_0_1.jpg b/doc/head_replace_0_1.jpg new file mode 100644 index 0000000..9125d91 Binary files /dev/null and b/doc/head_replace_0_1.jpg differ diff --git a/doc/head_replace_0_2.jpg b/doc/head_replace_0_2.jpg new file mode 100644 index 0000000..14c23ec Binary files /dev/null and b/doc/head_replace_0_2.jpg differ diff --git a/doc/head_replace_1_1.jpg b/doc/head_replace_1_1.jpg new file mode 100644 index 0000000..464bf50 Binary files /dev/null and b/doc/head_replace_1_1.jpg differ diff --git a/doc/head_replace_1_2.jpg b/doc/head_replace_1_2.jpg new file mode 100644 index 0000000..14c845a Binary files /dev/null and b/doc/head_replace_1_2.jpg differ diff --git a/doc/head_replace_2_1.jpg b/doc/head_replace_2_1.jpg new file mode 100644 index 0000000..a447893 Binary files /dev/null and b/doc/head_replace_2_1.jpg differ diff --git a/doc/head_replace_2_2.jpg b/doc/head_replace_2_2.jpg new file mode 100644 index 0000000..dc4eaf3 Binary files /dev/null and b/doc/head_replace_2_2.jpg differ diff --git a/doc/logo_directx.png b/doc/logo_directx.png new file mode 100644 index 0000000..f9fb10a Binary files /dev/null and b/doc/logo_directx.png differ diff --git a/doc/make_everything_ok.png b/doc/make_everything_ok.png new file mode 100644 index 0000000..9a90c0d Binary files /dev/null and b/doc/make_everything_ok.png differ diff --git a/doc/meme1.jpg b/doc/meme1.jpg new file mode 100644 index 0000000..819d36d Binary files /dev/null and b/doc/meme1.jpg differ diff --git a/doc/meme2.jpg b/doc/meme2.jpg new file mode 100644 index 0000000..9899c85 Binary files /dev/null and b/doc/meme2.jpg differ diff --git a/doc/meme3.jpg b/doc/meme3.jpg new file mode 100644 index 0000000..3ee794a Binary files /dev/null and b/doc/meme3.jpg differ diff --git a/doc/mini_tutorial.jpg b/doc/mini_tutorial.jpg new file mode 100644 index 0000000..2243fd9 Binary files /dev/null and b/doc/mini_tutorial.jpg differ diff --git a/doc/political_speech1.jpg b/doc/political_speech1.jpg new file mode 100644 index 0000000..33ae2ab Binary files /dev/null and b/doc/political_speech1.jpg differ diff --git a/doc/political_speech2.jpg b/doc/political_speech2.jpg new file mode 100644 index 0000000..f170ee2 Binary files /dev/null and b/doc/political_speech2.jpg differ diff --git a/doc/political_speech3.jpg b/doc/political_speech3.jpg new file mode 100644 index 0000000..7da3a64 Binary files /dev/null and b/doc/political_speech3.jpg differ diff --git a/doc/progress_2018.png b/doc/progress_2018.png deleted file mode 100644 index 89be8e2..0000000 Binary files a/doc/progress_2018.png and /dev/null differ diff --git a/doc/progress_2020.png b/doc/progress_2020.png deleted file mode 100644 index 43218d0..0000000 Binary files a/doc/progress_2020.png and /dev/null differ diff --git a/doc/replace_the_face.jpg b/doc/replace_the_face.jpg new file mode 100644 index 0000000..55501d0 Binary files /dev/null and b/doc/replace_the_face.jpg differ diff --git a/doc/tiktok_icon.png b/doc/tiktok_icon.png new file mode 100644 index 0000000..63d3e7e Binary files /dev/null and b/doc/tiktok_icon.png differ diff --git a/facelib/FAN.npy b/facelib/2DFAN.npy similarity index 100% rename from facelib/FAN.npy rename to facelib/2DFAN.npy diff --git a/facelib/FANSeg_full_face_256.npy b/facelib/3DFAN.npy similarity index 83% rename from facelib/FANSeg_full_face_256.npy rename to facelib/3DFAN.npy index 53a6664..b96bcd2 100644 Binary files a/facelib/FANSeg_full_face_256.npy and b/facelib/3DFAN.npy differ diff --git 
a/facelib/FANExtractor.py b/facelib/FANExtractor.py index 3e6c9ad..e71f393 100644 --- a/facelib/FANExtractor.py +++ b/facelib/FANExtractor.py @@ -13,8 +13,9 @@ from core.leras import nn ported from https://github.com/1adrianb/face-alignment """ class FANExtractor(object): - def __init__ (self, place_model_on_cpu=False): - model_path = Path(__file__).parent / "FAN.npy" + def __init__ (self, landmarks_3D=False, place_model_on_cpu=False): + + model_path = Path(__file__).parent / ( "2DFAN.npy" if not landmarks_3D else "3DFAN.npy") if not model_path.exists(): raise Exception("Unable to load FANExtractor model") @@ -27,13 +28,13 @@ class FANExtractor(object): self.out_planes = out_planes self.bn1 = nn.BatchNorm2D(in_planes) - self.conv1 = nn.Conv2D (in_planes, out_planes/2, kernel_size=3, strides=1, padding='SAME', use_bias=False ) + self.conv1 = nn.Conv2D (in_planes, out_planes//2, kernel_size=3, strides=1, padding='SAME', use_bias=False ) self.bn2 = nn.BatchNorm2D(out_planes//2) - self.conv2 = nn.Conv2D (out_planes/2, out_planes/4, kernel_size=3, strides=1, padding='SAME', use_bias=False ) + self.conv2 = nn.Conv2D (out_planes//2, out_planes//4, kernel_size=3, strides=1, padding='SAME', use_bias=False ) self.bn3 = nn.BatchNorm2D(out_planes//4) - self.conv3 = nn.Conv2D (out_planes/4, out_planes/4, kernel_size=3, strides=1, padding='SAME', use_bias=False ) + self.conv3 = nn.Conv2D (out_planes//4, out_planes//4, kernel_size=3, strides=1, padding='SAME', use_bias=False ) if self.in_planes != self.out_planes: self.down_bn1 = nn.BatchNorm2D(in_planes) diff --git a/facelib/FaceEnhancer.py b/facelib/FaceEnhancer.py index 48e21f6..1dc0dd9 100644 --- a/facelib/FaceEnhancer.py +++ b/facelib/FaceEnhancer.py @@ -161,11 +161,11 @@ class FaceEnhancer(object): if not model_path.exists(): raise Exception("Unable to load FaceEnhancer.npy") - with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'): + with tf.device ('/CPU:0' if place_model_on_cpu else nn.tf_default_device_name): self.model = FaceEnhancer() self.model.load_weights (model_path) - with tf.device ('/CPU:0' if run_on_cpu else '/GPU:0'): + with tf.device ('/CPU:0' if run_on_cpu else nn.tf_default_device_name): self.model.build_for_run ([ (tf.float32, nn.get4Dshape (192,192,3) ), (tf.float32, (None,1,) ), (tf.float32, (None,1,) ), @@ -248,7 +248,7 @@ class FaceEnhancer(object): final_img = final_img [t_padding*up_res:(h-b_padding)*up_res, l_padding*up_res:(w-r_padding)*up_res,:] if preserve_size: - final_img = cv2.resize (final_img, (iw,ih), cv2.INTER_LANCZOS4) + final_img = cv2.resize (final_img, (iw,ih), interpolation=cv2.INTER_LANCZOS4) if not is_tanh: final_img = np.clip( final_img/2+0.5, 0, 1 ) @@ -278,7 +278,7 @@ class FaceEnhancer(object): preupscale_rate = 1.0 / ( max(h,w) / patch_size ) if preupscale_rate != 1.0: - inp_img = cv2.resize (inp_img, ( int(w*preupscale_rate), int(h*preupscale_rate) ), cv2.INTER_LANCZOS4) + inp_img = cv2.resize (inp_img, ( int(w*preupscale_rate), int(h*preupscale_rate) ), interpolation=cv2.INTER_LANCZOS4) h,w,c = inp_img.shape i_max = w-patch_size+1 @@ -310,10 +310,10 @@ class FaceEnhancer(object): final_img /= final_img_div if preserve_size: - final_img = cv2.resize (final_img, (w,h), cv2.INTER_LANCZOS4) + final_img = cv2.resize (final_img, (w,h), interpolation=cv2.INTER_LANCZOS4) else: if preupscale_rate != 1.0: - final_img = cv2.resize (final_img, (tw,th), cv2.INTER_LANCZOS4) + final_img = cv2.resize (final_img, (tw,th), interpolation=cv2.INTER_LANCZOS4) if not is_tanh: final_img = np.clip( final_img/2+0.5, 
0, 1 ) diff --git a/facelib/FaceType.py b/facelib/FaceType.py index 816375d..745cff3 100644 --- a/facelib/FaceType.py +++ b/facelib/FaceType.py @@ -7,10 +7,10 @@ class FaceType(IntEnum): FULL = 2 FULL_NO_ALIGN = 3 WHOLE_FACE = 4 - HEAD = 5 - HEAD_NO_ALIGN = 6 + HEAD = 10 + HEAD_NO_ALIGN = 20 - MARK_ONLY = 10, #no align at all, just embedded faceinfo + MARK_ONLY = 100, #no align at all, just embedded faceinfo @staticmethod def fromString (s): @@ -23,21 +23,15 @@ class FaceType(IntEnum): def toString (face_type): return to_string_dict[face_type] -from_string_dict = {'half_face': FaceType.HALF, - 'midfull_face': FaceType.MID_FULL, - 'full_face': FaceType.FULL, - 'whole_face': FaceType.WHOLE_FACE, - 'head' : FaceType.HEAD, - 'mark_only' : FaceType.MARK_ONLY, - 'full_face_no_align' : FaceType.FULL_NO_ALIGN, - 'head_no_align' : FaceType.HEAD_NO_ALIGN, - } to_string_dict = { FaceType.HALF : 'half_face', FaceType.MID_FULL : 'midfull_face', FaceType.FULL : 'full_face', + FaceType.FULL_NO_ALIGN : 'full_face_no_align', FaceType.WHOLE_FACE : 'whole_face', FaceType.HEAD : 'head', - FaceType.MARK_ONLY :'mark_only', - FaceType.FULL_NO_ALIGN : 'full_face_no_align', - FaceType.HEAD_NO_ALIGN : 'head_no_align' + FaceType.HEAD_NO_ALIGN : 'head_no_align', + + FaceType.MARK_ONLY :'mark_only', } + +from_string_dict = { to_string_dict[x] : x for x in to_string_dict.keys() } \ No newline at end of file diff --git a/facelib/LandmarksProcessor.py b/facelib/LandmarksProcessor.py index f4c1e82..8e5d51b 100644 --- a/facelib/LandmarksProcessor.py +++ b/facelib/LandmarksProcessor.py @@ -9,7 +9,6 @@ import numpy.linalg as npla from core import imagelib from core import mathlib from facelib import FaceType -from core.imagelib import IEPolys from core.mathlib.umeyama import umeyama landmarks_2D = np.array([ @@ -103,6 +102,29 @@ landmarks_2D_new = np.array([ [ 0.726104, 0.780233 ], #54 ], dtype=np.float32) +mouth_center_landmarks_2D = np.array([ + [-4.4202591e-07, 4.4916576e-01], #48 + [ 1.8399176e-01, 3.7537053e-01], #49 + [ 3.7018123e-01, 3.3719531e-01], #50 + [ 5.0000089e-01, 3.6938059e-01], #51 + [ 6.2981832e-01, 3.3719531e-01], #52 + [ 8.1600773e-01, 3.7537053e-01], #53 + [ 1.0000000e+00, 4.4916576e-01], #54 + [ 8.2213330e-01, 6.2836081e-01], #55 + [ 6.4110327e-01, 7.0757812e-01], #56 + [ 5.0000089e-01, 7.2259867e-01], #57 + [ 3.5889623e-01, 7.0757812e-01], #58 + [ 1.7786618e-01, 6.2836081e-01], #59 + [ 7.6765373e-02, 4.5882553e-01], #60 + [ 3.6856663e-01, 4.4601500e-01], #61 + [ 5.0000089e-01, 4.5999300e-01], #62 + [ 6.3143289e-01, 4.4601500e-01], #63 + [ 9.2323411e-01, 4.5882553e-01], #64 + [ 6.3399029e-01, 5.4228687e-01], #65 + [ 5.0000089e-01, 5.5843467e-01], #66 + [ 3.6601129e-01, 5.4228687e-01] #67 +], dtype=np.float32) + # 68 point landmark definitions landmarks_68_pt = { "mouth": (48,68), "right_eyebrow": (17, 22), @@ -112,76 +134,76 @@ landmarks_68_pt = { "mouth": (48,68), "nose": (27, 36), # missed one point "jaw": (0, 17) } - landmarks_68_3D = np.array( [ -[-73.393523 , -29.801432 , 47.667532 ], -[-72.775014 , -10.949766 , 45.909403 ], -[-70.533638 , 7.929818 , 44.842580 ], -[-66.850058 , 26.074280 , 43.141114 ], -[-59.790187 , 42.564390 , 38.635298 ], -[-48.368973 , 56.481080 , 30.750622 ], -[-34.121101 , 67.246992 , 18.456453 ], -[-17.875411 , 75.056892 , 3.609035 ], -[0.098749 , 77.061286 , -0.881698 ], -[17.477031 , 74.758448 , 5.181201 ], -[32.648966 , 66.929021 , 19.176563 ], -[46.372358 , 56.311389 , 30.770570 ], -[57.343480 , 42.419126 , 37.628629 ], -[64.388482 , 25.455880 , 40.886309 ], -[68.212038 
, 6.990805 , 42.281449 ], -[70.486405 , -11.666193 , 44.142567 ], -[71.375822 , -30.365191 , 47.140426 ], -[-61.119406 , -49.361602 , 14.254422 ], -[-51.287588 , -58.769795 , 7.268147 ], -[-37.804800 , -61.996155 , 0.442051 ], -[-24.022754 , -61.033399 , -6.606501 ], -[-11.635713 , -56.686759 , -11.967398 ], -[12.056636 , -57.391033 , -12.051204 ], -[25.106256 , -61.902186 , -7.315098 ], -[38.338588 , -62.777713 , -1.022953 ], -[51.191007 , -59.302347 , 5.349435 ], -[60.053851 , -50.190255 , 11.615746 ], -[0.653940 , -42.193790 , -13.380835 ], -[0.804809 , -30.993721 , -21.150853 ], -[0.992204 , -19.944596 , -29.284036 ], -[1.226783 , -8.414541 , -36.948060 ], -[-14.772472 , 2.598255 , -20.132003 ], -[-7.180239 , 4.751589 , -23.536684 ], -[0.555920 , 6.562900 , -25.944448 ], -[8.272499 , 4.661005 , -23.695741 ], -[15.214351 , 2.643046 , -20.858157 ], -[-46.047290 , -37.471411 , 7.037989 ], -[-37.674688 , -42.730510 , 3.021217 ], -[-27.883856 , -42.711517 , 1.353629 ], -[-19.648268 , -36.754742 , -0.111088 ], -[-28.272965 , -35.134493 , -0.147273 ], -[-38.082418 , -34.919043 , 1.476612 ], -[19.265868 , -37.032306 , -0.665746 ], -[27.894191 , -43.342445 , 0.247660 ], -[37.437529 , -43.110822 , 1.696435 ], -[45.170805 , -38.086515 , 4.894163 ], -[38.196454 , -35.532024 , 0.282961 ], -[28.764989 , -35.484289 , -1.172675 ], -[-28.916267 , 28.612716 , -2.240310 ], -[-17.533194 , 22.172187 , -15.934335 ], -[-6.684590 , 19.029051 , -22.611355 ], -[0.381001 , 20.721118 , -23.748437 ], -[8.375443 , 19.035460 , -22.721995 ], -[18.876618 , 22.394109 , -15.610679 ], -[28.794412 , 28.079924 , -3.217393 ], -[19.057574 , 36.298248 , -14.987997 ], -[8.956375 , 39.634575 , -22.554245 ], -[0.381549 , 40.395647 , -23.591626 ], -[-7.428895 , 39.836405 , -22.406106 ], -[-18.160634 , 36.677899 , -15.121907 ], -[-24.377490 , 28.677771 , -4.785684 ], -[-6.897633 , 25.475976 , -20.893742 ], -[0.340663 , 26.014269 , -22.220479 ], -[8.444722 , 25.326198 , -21.025520 ], -[24.474473 , 28.323008 , -5.712776 ], -[8.449166 , 30.596216 , -20.671489 ], -[0.205322 , 31.408738 , -21.903670 ], -[-7.198266 , 30.844876 , -20.328022 ] ], dtype=np.float32) +[-73.393523 , -29.801432 , 47.667532 ], #00 +[-72.775014 , -10.949766 , 45.909403 ], #01 +[-70.533638 , 7.929818 , 44.842580 ], #02 +[-66.850058 , 26.074280 , 43.141114 ], #03 +[-59.790187 , 42.564390 , 38.635298 ], #04 +[-48.368973 , 56.481080 , 30.750622 ], #05 +[-34.121101 , 67.246992 , 18.456453 ], #06 +[-17.875411 , 75.056892 , 3.609035 ], #07 +[0.098749 , 77.061286 , -0.881698 ], #08 +[17.477031 , 74.758448 , 5.181201 ], #09 +[32.648966 , 66.929021 , 19.176563 ], #10 +[46.372358 , 56.311389 , 30.770570 ], #11 +[57.343480 , 42.419126 , 37.628629 ], #12 +[64.388482 , 25.455880 , 40.886309 ], #13 +[68.212038 , 6.990805 , 42.281449 ], #14 +[70.486405 , -11.666193 , 44.142567 ], #15 +[71.375822 , -30.365191 , 47.140426 ], #16 +[-61.119406 , -49.361602 , 14.254422 ], #17 +[-51.287588 , -58.769795 , 7.268147 ], #18 +[-37.804800 , -61.996155 , 0.442051 ], #19 +[-24.022754 , -61.033399 , -6.606501 ], #20 +[-11.635713 , -56.686759 , -11.967398 ], #21 +[12.056636 , -57.391033 , -12.051204 ], #22 +[25.106256 , -61.902186 , -7.315098 ], #23 +[38.338588 , -62.777713 , -1.022953 ], #24 +[51.191007 , -59.302347 , 5.349435 ], #25 +[60.053851 , -50.190255 , 11.615746 ], #26 +[0.653940 , -42.193790 , -13.380835 ], #27 +[0.804809 , -30.993721 , -21.150853 ], #28 +[0.992204 , -19.944596 , -29.284036 ], #29 +[1.226783 , -8.414541 , -36.948060 ], #00 +[-14.772472 , 2.598255 , -20.132003 ], #01 
+[-7.180239 , 4.751589 , -23.536684 ], #02 +[0.555920 , 6.562900 , -25.944448 ], #03 +[8.272499 , 4.661005 , -23.695741 ], #04 +[15.214351 , 2.643046 , -20.858157 ], #05 +[-46.047290 , -37.471411 , 7.037989 ], #06 +[-37.674688 , -42.730510 , 3.021217 ], #07 +[-27.883856 , -42.711517 , 1.353629 ], #08 +[-19.648268 , -36.754742 , -0.111088 ], #09 +[-28.272965 , -35.134493 , -0.147273 ], #10 +[-38.082418 , -34.919043 , 1.476612 ], #11 +[19.265868 , -37.032306 , -0.665746 ], #12 +[27.894191 , -43.342445 , 0.247660 ], #13 +[37.437529 , -43.110822 , 1.696435 ], #14 +[45.170805 , -38.086515 , 4.894163 ], #15 +[38.196454 , -35.532024 , 0.282961 ], #16 +[28.764989 , -35.484289 , -1.172675 ], #17 +[-28.916267 , 28.612716 , -2.240310 ], #18 +[-17.533194 , 22.172187 , -15.934335 ], #19 +[-6.684590 , 19.029051 , -22.611355 ], #20 +[0.381001 , 20.721118 , -23.748437 ], #21 +[8.375443 , 19.035460 , -22.721995 ], #22 +[18.876618 , 22.394109 , -15.610679 ], #23 +[28.794412 , 28.079924 , -3.217393 ], #24 +[19.057574 , 36.298248 , -14.987997 ], #25 +[8.956375 , 39.634575 , -22.554245 ], #26 +[0.381549 , 40.395647 , -23.591626 ], #27 +[-7.428895 , 39.836405 , -22.406106 ], #28 +[-18.160634 , 36.677899 , -15.121907 ], #29 +[-24.377490 , 28.677771 , -4.785684 ], #30 +[-6.897633 , 25.475976 , -20.893742 ], #31 +[0.340663 , 26.014269 , -22.220479 ], #32 +[8.444722 , 25.326198 , -21.025520 ], #33 +[24.474473 , 28.323008 , -5.712776 ], #34 +[8.449166 , 30.596216 , -20.671489 ], #35 +[0.205322 , 31.408738 , -21.903670 ], #36 +[-7.198266 , 30.844876 , -20.328022 ] #37 +], dtype=np.float32) FaceType_to_padding_remove_align = { FaceType.HALF: (0.0, False), @@ -189,8 +211,8 @@ FaceType_to_padding_remove_align = { FaceType.FULL: (0.2109375, False), FaceType.FULL_NO_ALIGN: (0.2109375, True), FaceType.WHOLE_FACE: (0.40, False), - FaceType.HEAD: (1.0, False), - FaceType.HEAD_NO_ALIGN: (1.0, True), + FaceType.HEAD: (0.70, False), + FaceType.HEAD_NO_ALIGN: (0.70, True), } def convert_98_to_68(lmrks): @@ -254,9 +276,10 @@ def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0): if not isinstance(image_landmarks, np.ndarray): image_landmarks = np.array (image_landmarks) + # estimate landmarks transform from global space to local aligned space with bounds [0..1] mat = umeyama( np.concatenate ( [ image_landmarks[17:49] , image_landmarks[54:55] ] ) , landmarks_2D_new, True)[0:2] - + # get corner points in global space g_p = transform_points ( np.float32([(0,0),(1,0),(1,1),(0,1),(0.5,0.5) ]) , mat, True) g_c = g_p[4] @@ -270,16 +293,34 @@ def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0): # calc modifier of diagonal vectors for scale and padding value padding, remove_align = FaceType_to_padding_remove_align.get(face_type, 0.0) mod = (1.0 / scale)* ( npla.norm(g_p[0]-g_p[2])*(padding*np.sqrt(2.0) + 0.5) ) - + if face_type == FaceType.WHOLE_FACE: - # adjust center for WHOLE_FACE, 7% below in order to cover more forehead + # adjust vertical offset for WHOLE_FACE, 7% below in order to cover more forehead vec = (g_p[0]-g_p[3]).astype(np.float32) vec_len = npla.norm(vec) vec /= vec_len - g_c += vec*vec_len*0.07 - - # calc 3 points in global space to estimate 2d affine transform + + elif face_type == FaceType.HEAD: + # assuming image_landmarks are 3D_Landmarks extracted for HEAD, + # adjust horizontal offset according to estimated yaw + yaw = estimate_averaged_yaw(transform_points (image_landmarks, mat, False)) + + hvec = (g_p[0]-g_p[1]).astype(np.float32) + hvec_len = npla.norm(hvec) + hvec /= 
hvec_len + + yaw *= np.abs(math.tanh(yaw*2)) # Damp near zero + + g_c -= hvec * (yaw * hvec_len / 2.0) + + # adjust vertical offset for HEAD, 50% below + vvec = (g_p[0]-g_p[3]).astype(np.float32) + vvec_len = npla.norm(vvec) + vvec /= vvec_len + g_c += vvec*vvec_len*0.50 + + # calc 3 points in global space to estimate 2d affine transform if not remove_align: l_t = np.array( [ g_c - tb_diag_vec*mod, g_c + bt_diag_vec*mod, @@ -294,10 +335,10 @@ def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0): # get area of face square in global space area = mathlib.polygon_area(l_t[:,0], l_t[:,1] ) - + # calc side of square side = np.float32(math.sqrt(area) / 2) - + # calc 3 points with unrotated square l_t = np.array( [ g_c + [-side,-side], g_c + [ side,-side], @@ -307,14 +348,14 @@ def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0): pts2 = np.float32(( (0,0),(output_size,0),(output_size,output_size) )) mat = cv2.getAffineTransform(l_t,pts2) return mat - + def get_rect_from_landmarks(image_landmarks): mat = get_transform_mat(image_landmarks, 256, FaceType.FULL_NO_ALIGN) - + g_p = transform_points ( np.float32([(0,0),(255,255) ]) , mat, True) - + (l,t,r,b) = g_p[0][0], g_p[0][1], g_p[1][0], g_p[1][1] - + return (l,t,r,b) def expand_eyebrows(lmrks, eyebrows_expand_mod=1.0): @@ -346,7 +387,7 @@ def expand_eyebrows(lmrks, eyebrows_expand_mod=1.0): -def get_image_hull_mask (image_shape, image_landmarks, eyebrows_expand_mod=1.0, ie_polys=None ): +def get_image_hull_mask (image_shape, image_landmarks, eyebrows_expand_mod=1.0 ): hull_mask = np.zeros(image_shape[0:2]+(1,),dtype=np.float32) lmrks = expand_eyebrows(image_landmarks, eyebrows_expand_mod) @@ -365,19 +406,16 @@ def get_image_hull_mask (image_shape, image_landmarks, eyebrows_expand_mod=1.0, merged = np.concatenate(item) cv2.fillConvexPoly(hull_mask, cv2.convexHull(merged), (1,) ) - if ie_polys is not None: - ie_polys.overlay_mask(hull_mask) - return hull_mask - + def get_image_eye_mask (image_shape, image_landmarks): if len(image_landmarks) != 68: raise Exception('get_image_eye_mask works only with 68 landmarks') - + h,w,c = image_shape hull_mask = np.zeros( (h,w,1),dtype=np.float32) - + image_landmarks = image_landmarks.astype(np.int) cv2.fillConvexPoly( hull_mask, cv2.convexHull( image_landmarks[36:42]), (1,) ) @@ -385,7 +423,7 @@ def get_image_eye_mask (image_shape, image_landmarks): dilate = h // 32 hull_mask = cv2.dilate(hull_mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(dilate,dilate)), iterations = 1 ) - + blur = h // 16 blur = blur + (1-blur % 2) hull_mask = cv2.GaussianBlur(hull_mask, (blur, blur) , 0) @@ -393,7 +431,28 @@ def get_image_eye_mask (image_shape, image_landmarks): return hull_mask +def get_image_mouth_mask (image_shape, image_landmarks): + if len(image_landmarks) != 68: + raise Exception('get_image_eye_mask works only with 68 landmarks') + h,w,c = image_shape + + hull_mask = np.zeros( (h,w,1),dtype=np.float32) + + image_landmarks = image_landmarks.astype(np.int) + + cv2.fillConvexPoly( hull_mask, cv2.convexHull( image_landmarks[60:]), (1,) ) + + dilate = h // 32 + hull_mask = cv2.dilate(hull_mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(dilate,dilate)), iterations = 1 ) + + blur = h // 16 + blur = blur + (1-blur % 2) + hull_mask = cv2.GaussianBlur(hull_mask, (blur, blur) , 0) + hull_mask = hull_mask[...,None] + + return hull_mask + def alpha_to_color (img_alpha, color): if len(img_alpha.shape) == 2: img_alpha = img_alpha[...,None] @@ -619,13 +678,13 @@ def mirror_landmarks 
(landmarks, val): result[:,0] = val - result[:,0] - 1 return result -def get_face_struct_mask (image_shape, image_landmarks, eyebrows_expand_mod=1.0, ie_polys=None, color=(1,) ): +def get_face_struct_mask (image_shape, image_landmarks, eyebrows_expand_mod=1.0, color=(1,) ): mask = np.zeros(image_shape[0:2]+( len(color),),dtype=np.float32) lmrks = expand_eyebrows(image_landmarks, eyebrows_expand_mod) - draw_landmarks (mask, image_landmarks, color=color, draw_circles=False, thickness=2, ie_polys=ie_polys) + draw_landmarks (mask, image_landmarks, color=color, draw_circles=False, thickness=2) return mask - -def draw_landmarks (image, image_landmarks, color=(0,255,0), draw_circles=True, thickness=1, transparent_mask=False, ie_polys=None): + +def draw_landmarks (image, image_landmarks, color=(0,255,0), draw_circles=True, thickness=1, transparent_mask=False): if len(image_landmarks) != 68: raise Exception('get_image_eye_mask works only with 68 landmarks') @@ -645,7 +704,7 @@ def draw_landmarks (image, image_landmarks, color=(0,255,0), draw_circles=True, # closed shapes cv2.polylines(image, tuple(np.array([v]) for v in (right_eye, left_eye, mouth)), True, color, thickness=thickness, lineType=cv2.LINE_AA) - + if draw_circles: # the rest of the cicles for x, y in np.concatenate((right_eyebrow, left_eyebrow, mouth, right_eye, left_eye, nose), axis=0): @@ -655,11 +714,11 @@ def draw_landmarks (image, image_landmarks, color=(0,255,0), draw_circles=True, cv2.circle(image, (x, y), 2, color, lineType=cv2.LINE_AA) if transparent_mask: - mask = get_image_hull_mask (image.shape, image_landmarks, ie_polys=ie_polys) + mask = get_image_hull_mask (image.shape, image_landmarks) image[...] = ( image * (1-mask) + image * mask / 2 )[...] -def draw_rect_landmarks (image, rect, image_landmarks, face_type, face_size=256, transparent_mask=False, ie_polys=None, landmarks_color=(0,255,0)): - draw_landmarks(image, image_landmarks, color=landmarks_color, transparent_mask=transparent_mask, ie_polys=ie_polys) +def draw_rect_landmarks (image, rect, image_landmarks, face_type, face_size=256, transparent_mask=False, landmarks_color=(0,255,0)): + draw_landmarks(image, image_landmarks, color=landmarks_color, transparent_mask=transparent_mask) imagelib.draw_rect (image, rect, (255,0,0), 2 ) image_to_face_mat = get_transform_mat (image_landmarks, face_size, face_type) @@ -668,17 +727,25 @@ def draw_rect_landmarks (image, rect, image_landmarks, face_type, face_size=256, points = transform_points ( [ ( int(face_size*0.05), 0), ( int(face_size*0.1), int(face_size*0.1) ), ( 0, int(face_size*0.1) ) ], image_to_face_mat, True) imagelib.draw_polygon (image, points, (0,0,255), 2) - + def calc_face_pitch(landmarks): if not isinstance(landmarks, np.ndarray): landmarks = np.array (landmarks) t = ( (landmarks[6][1]-landmarks[8][1]) + (landmarks[10][1]-landmarks[8][1]) ) / 2.0 b = landmarks[8][1] return float(b-t) + +def estimate_averaged_yaw(landmarks): + # Works much better than solvePnP if landmarks from "3DFAN" + if not isinstance(landmarks, np.ndarray): + landmarks = np.array (landmarks) + l = ( (landmarks[27][0]-landmarks[0][0]) + (landmarks[28][0]-landmarks[1][0]) + (landmarks[29][0]-landmarks[2][0]) ) / 3.0 + r = ( (landmarks[16][0]-landmarks[27][0]) + (landmarks[15][0]-landmarks[28][0]) + (landmarks[14][0]-landmarks[29][0]) ) / 3.0 + return float(r-l) def estimate_pitch_yaw_roll(aligned_landmarks, size=256): """ - returns pitch,yaw,roll [-pi...+pi] + returns pitch,yaw,roll [-pi/2...+pi/2] """ shape = (size,size) focal_length = shape[1] 
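# --- Editor's note (illustrative sketch only, not part of the patch) -----------
# The hunks above introduce estimate_averaged_yaw() and use it inside
# get_transform_mat() to shift the HEAD crop horizontally before the vertical
# 50% offset is applied. The minimal sketch below restates that heuristic in
# isolation so it can be read or tested on its own; it assumes a 68-point
# landmark array and needs only numpy. Function and variable names here are
# local to this sketch, not identifiers from the repository.
import numpy as np

def averaged_yaw_sketch(landmarks):
    """Rough yaw estimate: compare the horizontal distance from the nose
    bridge (points 27-29) to one jaw contour (0-2) against the distance to
    the opposite contour (14-16). Symmetric, near-frontal landmarks give a
    value close to 0."""
    lm = np.asarray(landmarks, dtype=np.float32)
    l = ((lm[27, 0] - lm[0, 0]) + (lm[28, 0] - lm[1, 0]) + (lm[29, 0] - lm[2, 0])) / 3.0
    r = ((lm[16, 0] - lm[27, 0]) + (lm[15, 0] - lm[28, 0]) + (lm[14, 0] - lm[29, 0])) / 3.0
    return float(r - l)

# Hypothetical usage on dummy landmarks, mirroring the damping applied in the
# HEAD branch of get_transform_mat() above: the tanh factor suppresses the
# horizontal shift for near-frontal faces.
yaw = averaged_yaw_sketch(np.random.rand(68, 2).astype(np.float32))
yaw *= np.abs(np.tanh(yaw * 2.0))
# --------------------------------------------------------------------------------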
@@ -688,19 +755,21 @@ def estimate_pitch_yaw_roll(aligned_landmarks, size=256): [0, focal_length, camera_center[1]], [0, 0, 1]], dtype=np.float32) - (_, rotation_vector, translation_vector) = cv2.solvePnP( - landmarks_68_3D, - aligned_landmarks.astype(np.float32), + (_, rotation_vector, _) = cv2.solvePnP( + np.concatenate( (landmarks_68_3D[:27], landmarks_68_3D[30:36]) , axis=0) , + np.concatenate( (aligned_landmarks[:27], aligned_landmarks[30:36]) , axis=0).astype(np.float32), camera_matrix, np.zeros((4, 1)) ) pitch, yaw, roll = mathlib.rotationMatrixToEulerAngles( cv2.Rodrigues(rotation_vector)[0] ) - pitch = np.clip ( pitch, -math.pi, math.pi ) - yaw = np.clip ( yaw , -math.pi, math.pi ) - roll = np.clip ( roll, -math.pi, math.pi ) + + half_pi = math.pi / 2.0 + pitch = np.clip ( pitch, -half_pi, half_pi ) + yaw = np.clip ( yaw , -half_pi, half_pi ) + roll = np.clip ( roll, -half_pi, half_pi ) return -pitch, yaw, roll - + #if remove_align: # bbox = transform_points ( [ (0,0), (0,output_size), (output_size, output_size), (output_size,0) ], mat, True) # #import code @@ -734,48 +803,48 @@ def estimate_pitch_yaw_roll(aligned_landmarks, size=256): """ -def get_averaged_transform_mat (img_landmarks, - img_landmarks_prev, - img_landmarks_next, - average_frame_count, +def get_averaged_transform_mat (img_landmarks, + img_landmarks_prev, + img_landmarks_next, + average_frame_count, average_center_frame_count, output_size, face_type, scale=1.0): - + l_c_list = [] tb_diag_vec_list = [] bt_diag_vec_list = [] mod_list = [] - + count = max(average_frame_count,average_center_frame_count) - for i in range ( -count, count+1, 1 ): + for i in range ( -count, count+1, 1 ): if i < 0: lmrks = img_landmarks_prev[i] if -i < len(img_landmarks_prev) else None elif i > 0: lmrks = img_landmarks_next[i] if i < len(img_landmarks_next) else None else: lmrks = img_landmarks - + if lmrks is None: continue - + l_c, tb_diag_vec, bt_diag_vec, mod = get_transform_mat_data (lmrks, face_type, scale=scale) - + if i >= -average_frame_count and i <= average_frame_count: tb_diag_vec_list.append(tb_diag_vec) bt_diag_vec_list.append(bt_diag_vec) mod_list.append(mod) - + if i >= -average_center_frame_count and i <= average_center_frame_count: l_c_list.append(l_c) - + tb_diag_vec = np.mean( np.array(tb_diag_vec_list), axis=0 ) bt_diag_vec = np.mean( np.array(bt_diag_vec_list), axis=0 ) - mod = np.mean( np.array(mod_list), axis=0 ) + mod = np.mean( np.array(mod_list), axis=0 ) l_c = np.mean( np.array(l_c_list), axis=0 ) return get_transform_mat_by_data (l_c, tb_diag_vec, bt_diag_vec, mod, output_size, face_type) - - + + def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0): if not isinstance(image_landmarks, np.ndarray): image_landmarks = np.array (image_landmarks) @@ -785,7 +854,7 @@ def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0): # estimate landmarks transform from global space to local aligned space with bounds [0..1] mat = umeyama( np.concatenate ( [ image_landmarks[17:49] , image_landmarks[54:55] ] ) , landmarks_2D_new, True)[0:2] - + # get corner points in global space l_p = transform_points ( np.float32([(0,0),(1,0),(1,1),(0,1),(0.5,0.5)]) , mat, True) l_c = l_p[4] @@ -799,7 +868,7 @@ def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0): # calc modifier of diagonal vectors for scale and padding value mod = (1.0 / scale)* ( npla.norm(l_p[0]-l_p[2])*(padding*np.sqrt(2.0) + 0.5) ) - # calc 3 points in global space to estimate 2d affine transform + # calc 3 
points in global space to estimate 2d affine transform if not remove_align: l_t = np.array( [ np.round( l_c - tb_diag_vec*mod ), np.round( l_c + bt_diag_vec*mod ), @@ -814,10 +883,10 @@ def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0): # get area of face square in global space area = mathlib.polygon_area(l_t[:,0], l_t[:,1] ) - + # calc side of square side = np.float32(math.sqrt(area) / 2) - + # calc 3 points with unrotated square l_t = np.array( [ np.round( l_c + [-side,-side] ), np.round( l_c + [ side,-side] ), @@ -826,6 +895,6 @@ def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0): # calc affine transform from 3 global space points to 3 local space points size of 'output_size' pts2 = np.float32(( (0,0),(output_size,0),(output_size,output_size) )) mat = cv2.getAffineTransform(l_t,pts2) - + return mat """ \ No newline at end of file diff --git a/facelib/TernausNet.py b/facelib/TernausNet.py deleted file mode 100644 index d955c95..0000000 --- a/facelib/TernausNet.py +++ /dev/null @@ -1,139 +0,0 @@ -import os -import pickle -from functools import partial -from pathlib import Path - -import cv2 -import numpy as np - -from core.interact import interact as io -from core.leras import nn - -class TernausNet(object): - VERSION = 1 - - def __init__ (self, name, resolution, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False, run_on_cpu=False, optimizer=None, data_format="NHWC"): - nn.initialize(data_format=data_format) - tf = nn.tf - - if weights_file_root is not None: - weights_file_root = Path(weights_file_root) - else: - weights_file_root = Path(__file__).parent - self.weights_file_root = weights_file_root - - with tf.device ('/CPU:0'): - #Place holders on CPU - self.input_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,3) ) - self.target_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,1) ) - - # Initializing model classes - with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'): - self.net = nn.Ternaus(3, 64, name='Ternaus') - self.net_weights = self.net.get_weights() - - model_name = f'{name}_{resolution}' - - self.model_filename_list = [ [self.net, f'{model_name}.npy'] ] - - if training: - if optimizer is None: - raise ValueError("Optimizer should be provided for traning mode.") - - self.opt = optimizer - self.opt.initialize_variables (self.net_weights, vars_on_cpu=place_model_on_cpu) - self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ] - else: - with tf.device ('/CPU:0' if run_on_cpu else '/GPU:0'): - _, pred = self.net([self.input_t]) - - def net_run(input_np): - return nn.tf_sess.run ( [pred], feed_dict={self.input_t :input_np})[0] - self.net_run = net_run - - # Loading/initializing all models/optimizers weights - for model, filename in self.model_filename_list: - do_init = not load_weights - - if not do_init: - do_init = not model.load_weights( self.weights_file_root / filename ) - - if do_init: - model.init_weights() - if model == self.net: - try: - with open( Path(__file__).parent / 'vgg11_enc_weights.npy', 'rb' ) as f: - d = pickle.loads (f.read()) - - for i in [0,3,6,8,11,13,16,18]: - model.get_layer_by_name ('features_%d' % i).set_weights ( d['features.%d' % i] ) - except: - io.log_err("Unable to load VGG11 pretrained weights from vgg11_enc_weights.npy") - - def save_weights(self): - for model, filename in io.progress_bar_generator(self.model_filename_list, "Saving", leave=False): - model.save_weights( self.weights_file_root / filename ) - - 
def extract (self, input_image): - input_shape_len = len(input_image.shape) - if input_shape_len == 3: - input_image = input_image[None,...] - - result = np.clip ( self.net_run(input_image), 0, 1.0 ) - result[result < 0.1] = 0 #get rid of noise - - if input_shape_len == 3: - result = result[0] - - return result - -""" -if load_weights: - self.net.load_weights (self.weights_path) -else: - self.net.init_weights() - -if load_weights: - self.opt.load_weights (self.opt_path) -else: - self.opt.init_weights() -""" -""" -if training: - try: - with open( Path(__file__).parent / 'vgg11_enc_weights.npy', 'rb' ) as f: - d = pickle.loads (f.read()) - - for i in [0,3,6,8,11,13,16,18]: - s = 'features.%d' % i - - self.model.get_layer (s).set_weights ( d[s] ) - except: - io.log_err("Unable to load VGG11 pretrained weights from vgg11_enc_weights.npy") - - conv_weights_list = [] - for layer in self.model.layers: - if 'CA.' in layer.name: - conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights - CAInitializerMP ( conv_weights_list ) -""" - - - -""" -if training: - inp_t = Input ( (resolution, resolution, 3) ) - real_t = Input ( (resolution, resolution, 1) ) - out_t = self.model(inp_t) - - loss = K.mean(10*K.binary_crossentropy(real_t,out_t) ) - - out_t_diff1 = out_t[:, 1:, :, :] - out_t[:, :-1, :, :] - out_t_diff2 = out_t[:, :, 1:, :] - out_t[:, :, :-1, :] - - total_var_loss = K.mean( 0.1*K.abs(out_t_diff1), axis=[1, 2, 3] ) + K.mean( 0.1*K.abs(out_t_diff2), axis=[1, 2, 3] ) - - opt = Adam(lr=0.0001, beta_1=0.5, beta_2=0.999, tf_cpu_mode=2) - - self.train_func = K.function ( [inp_t, real_t], [K.mean(loss)], opt.get_updates( [loss], self.model.trainable_weights) ) -""" diff --git a/facelib/XSegNet.py b/facelib/XSegNet.py index 35f2cef..ff2bd08 100644 --- a/facelib/XSegNet.py +++ b/facelib/XSegNet.py @@ -14,61 +14,75 @@ class XSegNet(object): VERSION = 1 def __init__ (self, name, - resolution, + resolution=256, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False, run_on_cpu=False, optimizer=None, - data_format="NHWC"): - + data_format="NHWC", + raise_on_no_model_files=False): + + self.resolution = resolution + self.weights_file_root = Path(weights_file_root) if weights_file_root is not None else Path(__file__).parent + nn.initialize(data_format=data_format) tf = nn.tf - - self.weights_file_root = Path(weights_file_root) if weights_file_root is not None else Path(__file__).parent - + + model_name = f'{name}_{resolution}' + self.model_filename_list = [] + with tf.device ('/CPU:0'): #Place holders on CPU self.input_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,3) ) self.target_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,1) ) # Initializing model classes - with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'): + with tf.device ('/CPU:0' if place_model_on_cpu else nn.tf_default_device_name): self.model = nn.XSeg(3, 32, 1, name=name) self.model_weights = self.model.get_weights() + if training: + if optimizer is None: + raise ValueError("Optimizer should be provided for training mode.") + self.opt = optimizer + self.opt.initialize_variables (self.model_weights, vars_on_cpu=place_model_on_cpu) + self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ] + + + self.model_filename_list += [ [self.model, f'{model_name}.npy'] ] - model_name = f'{name}_{resolution}' - - self.model_filename_list = [ [self.model, f'{model_name}.npy'] ] - - if training: - if optimizer is None: - raise ValueError("Optimizer should be provided 
for training mode.") - - self.opt = optimizer - self.opt.initialize_variables (self.model_weights, vars_on_cpu=place_model_on_cpu) - self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ] - else: - with tf.device ('/CPU:0' if run_on_cpu else '/GPU:0'): + if not training: + with tf.device ('/CPU:0' if run_on_cpu else nn.tf_default_device_name): _, pred = self.model(self.input_t) def net_run(input_np): return nn.tf_sess.run ( [pred], feed_dict={self.input_t :input_np})[0] self.net_run = net_run + self.initialized = True # Loading/initializing all models/optimizers weights for model, filename in self.model_filename_list: do_init = not load_weights if not do_init: - do_init = not model.load_weights( self.weights_file_root / filename ) + model_file_path = self.weights_file_root / filename + do_init = not model.load_weights( model_file_path ) + if do_init: + if raise_on_no_model_files: + raise Exception(f'{model_file_path} does not exists.') + if not training: + self.initialized = False + break if do_init: model.init_weights() - - def flow(self, x): - return self.model(x) + + def get_resolution(self): + return self.resolution + + def flow(self, x, pretrain=False): + return self.model(x, pretrain=pretrain) def get_weights(self): return self.model_weights @@ -78,7 +92,10 @@ class XSegNet(object): model.save_weights( self.weights_file_root / filename ) def extract (self, input_image): - input_shape_len = len(input_image.shape) + if not self.initialized: + return 0.5*np.ones ( (self.resolution, self.resolution, 1), nn.floatx.as_numpy_dtype ) + + input_shape_len = len(input_image.shape) if input_shape_len == 3: input_image = input_image[None,...] diff --git a/facelib/__init__.py b/facelib/__init__.py index e900019..e46ca51 100644 --- a/facelib/__init__.py +++ b/facelib/__init__.py @@ -2,5 +2,4 @@ from .FaceType import FaceType from .S3FDExtractor import S3FDExtractor from .FANExtractor import FANExtractor from .FaceEnhancer import FaceEnhancer -from .TernausNet import TernausNet from .XSegNet import XSegNet \ No newline at end of file diff --git a/facelib/vgg11_enc_weights.npy b/facelib/vgg11_enc_weights.npy deleted file mode 100644 index ea9df4e..0000000 Binary files a/facelib/vgg11_enc_weights.npy and /dev/null differ diff --git a/localization/__init__.py b/localization/__init__.py index f3bcf09..ccd8c6e 100644 --- a/localization/__init__.py +++ b/localization/__init__.py @@ -1,2 +1,2 @@ -from .localization import get_default_ttf_font_name +from .localization import StringsDB, system_language, get_default_ttf_font_name diff --git a/localization/localization.py b/localization/localization.py index a603285..3df7bbd 100644 --- a/localization/localization.py +++ b/localization/localization.py @@ -4,23 +4,25 @@ import locale system_locale = locale.getdefaultlocale()[0] # system_locale may be nil system_language = system_locale[0:2] if system_locale is not None else "en" +if system_language not in ['en','ru','zh']: + system_language = 'en' windows_font_name_map = { 'en' : 'cour', 'ru' : 'cour', - 'zn' : 'simsun_01' + 'zh' : 'simsun_01' } darwin_font_name_map = { 'en' : 'cour', 'ru' : 'cour', - 'zn' : 'Apple LiSung Light' + 'zh' : 'Apple LiSung Light' } linux_font_name_map = { 'en' : 'cour', 'ru' : 'cour', - 'zn' : 'cour' + 'zh' : 'cour' } def get_default_ttf_font_name(): @@ -28,3 +30,13 @@ def get_default_ttf_font_name(): if platform[0:3] == 'win': return windows_font_name_map.get(system_language, 'cour') elif platform == 'darwin': return darwin_font_name_map.get(system_language, 'cour') 
else: return linux_font_name_map.get(system_language, 'cour') + +SID_HOT_KEY = 1 + +if system_language == 'en': + StringsDB = {'S_HOT_KEY' : 'hot key'} +elif system_language == 'ru': + StringsDB = {'S_HOT_KEY' : 'горячая клавиша'} +elif system_language == 'zh': + StringsDB = {'S_HOT_KEY' : '热键'} + \ No newline at end of file diff --git a/main.py b/main.py index 1143b16..b821910 100644 --- a/main.py +++ b/main.py @@ -22,6 +22,8 @@ if __name__ == "__main__": def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, os.path.abspath(os.path.expanduser(values))) + exit_code = 0 + parser = argparse.ArgumentParser() subparsers = parser.add_subparsers() @@ -36,6 +38,9 @@ if __name__ == "__main__": manual_output_debug_fix = arguments.manual_output_debug_fix, manual_window_size = arguments.manual_window_size, face_type = arguments.face_type, + max_faces_from_image = arguments.max_faces_from_image, + image_size = arguments.image_size, + jpeg_quality = arguments.jpeg_quality, cpu_only = arguments.cpu_only, force_gpu_idxs = [ int(x) for x in arguments.force_gpu_idxs.split(',') ] if arguments.force_gpu_idxs is not None else None, ) @@ -46,7 +51,10 @@ if __name__ == "__main__": p.add_argument('--output-dir', required=True, action=fixPathAction, dest="output_dir", help="Output directory. This is where the extracted files will be stored.") p.add_argument('--output-debug', action="store_true", dest="output_debug", default=None, help="Writes debug images to _debug\ directory.") p.add_argument('--no-output-debug', action="store_false", dest="output_debug", default=None, help="Don't writes debug images to _debug\ directory.") - p.add_argument('--face-type', dest="face_type", choices=['half_face', 'full_face', 'whole_face', 'head', 'full_face_no_align', 'mark_only'], default='full_face', help="Default 'full_face'. Don't change this option, currently all models uses 'full_face'") + p.add_argument('--face-type', dest="face_type", choices=['half_face', 'full_face', 'whole_face', 'head', 'mark_only'], default=None) + p.add_argument('--max-faces-from-image', type=int, dest="max_faces_from_image", default=None, help="Max faces from image.") + p.add_argument('--image-size', type=int, dest="image_size", default=None, help="Output image size.") + p.add_argument('--jpeg-quality', type=int, dest="jpeg_quality", default=None, help="Jpeg quality.") p.add_argument('--manual-fix', action="store_true", dest="manual_fix", default=False, help="Enables manual extract only frames where faces were not recognized.") p.add_argument('--manual-output-debug-fix', action="store_true", dest="manual_output_debug_fix", default=False, help="Performs manual reextract input-dir frames which were deleted from [output_dir]_debug\ dir.") p.add_argument('--manual-window-size', type=int, dest="manual_window_size", default=1368, help="Manual fix window size. Default: 1368.") @@ -62,7 +70,7 @@ if __name__ == "__main__": p = subparsers.add_parser( "sort", help="Sort faces in a directory.") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.") - p.add_argument('--by', dest="sort_by_method", default=None, choices=("blur", "face-yaw", "face-pitch", "face-source-rect-size", "hist", "hist-dissim", "brightness", "hue", "black", "origname", "oneface", "final", "final-faster", "absdiff"), help="Method of sorting. 'origname' sort by original filename to recover original sequence." 
) + p.add_argument('--by', dest="sort_by_method", default=None, choices=("blur", "motion-blur", "face-yaw", "face-pitch", "face-source-rect-size", "hist", "hist-dissim", "brightness", "hue", "black", "origname", "oneface", "final-by-blur", "final-by-size", "absdiff"), help="Method of sorting. 'origname' sort by original filename to recover original sequence." ) p.set_defaults (func=process_sort) def process_util(arguments): @@ -75,12 +83,6 @@ if __name__ == "__main__": if arguments.recover_original_aligned_filename: Util.recover_original_aligned_filename (input_path=arguments.input_dir) - #if arguments.remove_fanseg: - # Util.remove_fanseg_folder (input_path=arguments.input_dir) - - if arguments.remove_ie_polys: - Util.remove_ie_polys_folder (input_path=arguments.input_dir) - if arguments.save_faceset_metadata: Util.save_faceset_metadata_folder (input_path=arguments.input_dir) @@ -96,17 +98,20 @@ if __name__ == "__main__": io.log_info ("Performing faceset unpacking...\r\n") from samplelib import PackedFaceset PackedFaceset.unpack( Path(arguments.input_dir) ) + + if arguments.export_faceset_mask: + io.log_info ("Exporting faceset mask..\r\n") + Util.export_faceset_mask( Path(arguments.input_dir) ) p = subparsers.add_parser( "util", help="Utilities.") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.") p.add_argument('--add-landmarks-debug-images', action="store_true", dest="add_landmarks_debug_images", default=False, help="Add landmarks debug image for aligned faces.") p.add_argument('--recover-original-aligned-filename', action="store_true", dest="recover_original_aligned_filename", default=False, help="Recover original aligned filename.") - #p.add_argument('--remove-fanseg', action="store_true", dest="remove_fanseg", default=False, help="Remove fanseg mask from aligned faces.") - p.add_argument('--remove-ie-polys', action="store_true", dest="remove_ie_polys", default=False, help="Remove ie_polys from aligned faces.") p.add_argument('--save-faceset-metadata', action="store_true", dest="save_faceset_metadata", default=False, help="Save faceset metadata to file.") p.add_argument('--restore-faceset-metadata', action="store_true", dest="restore_faceset_metadata", default=False, help="Restore faceset metadata to file. 
Image filenames must be the same as used with save.") p.add_argument('--pack-faceset', action="store_true", dest="pack_faceset", default=False, help="") p.add_argument('--unpack-faceset', action="store_true", dest="unpack_faceset", default=False, help="") + p.add_argument('--export-faceset-mask', action="store_true", dest="export_faceset_mask", default=False, help="") p.set_defaults (func=process_util) @@ -124,6 +129,7 @@ if __name__ == "__main__": 'force_model_name' : arguments.force_model_name, 'force_gpu_idxs' : [ int(x) for x in arguments.force_gpu_idxs.split(',') ] if arguments.force_gpu_idxs is not None else None, 'cpu_only' : arguments.cpu_only, + 'silent_start' : arguments.silent_start, 'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ], 'debug' : arguments.debug, } @@ -142,8 +148,20 @@ if __name__ == "__main__": p.add_argument('--force-model-name', dest="force_model_name", default=None, help="Forcing to choose model name from model/ folder.") p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.") p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.") + p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.") + p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+') p.set_defaults (func=process_train) + + def process_exportdfm(arguments): + osex.set_process_lowest_prio() + from mainscripts import ExportDFM + ExportDFM.main(model_class_name = arguments.model_name, saved_models_path = Path(arguments.model_dir)) + + p = subparsers.add_parser( "exportdfm", help="Export model to use in DeepFaceLive.") + p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Saved models dir.") + p.add_argument('--model', required=True, dest="model_name", choices=pathex.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Model class name.") + p.set_defaults (func=process_exportdfm) def process_merge(arguments): osex.set_process_lowest_prio() @@ -232,19 +250,6 @@ if __name__ == "__main__": p.set_defaults(func=process_videoed_video_from_sequence) - def process_labelingtool_edit_mask(arguments): - from mainscripts import MaskEditorTool - MaskEditorTool.mask_editor_main (arguments.input_dir, arguments.confirmed_dir, arguments.skipped_dir, no_default_mask=arguments.no_default_mask) - - labeling_parser = subparsers.add_parser( "labelingtool", help="Labeling tool.").add_subparsers() - p = labeling_parser.add_parser ( "edit_mask", help="") - p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.") - p.add_argument('--confirmed-dir', required=True, action=fixPathAction, dest="confirmed_dir", help="This is where the labeled faces will be stored.") - p.add_argument('--skipped-dir', required=True, action=fixPathAction, dest="skipped_dir", help="This is where the labeled faces will be stored.") - p.add_argument('--no-default-mask', action="store_true", dest="no_default_mask", default=False, help="Don't use default mask.") - - p.set_defaults(func=process_labelingtool_edit_mask) - facesettool_parser = subparsers.add_parser( "facesettool", help="Faceset tools.").add_subparsers() def process_faceset_enhancer(arguments): @@ -262,37 +267,78 @@ if __name__ == "__main__": 
p.set_defaults(func=process_faceset_enhancer) + + p = facesettool_parser.add_parser ("resize", help="Resize DFL faceset.") + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.") + + def process_faceset_resizer(arguments): + osex.set_process_lowest_prio() + from mainscripts import FacesetResizer + FacesetResizer.process_folder ( Path(arguments.input_dir) ) + p.set_defaults(func=process_faceset_resizer) + def process_dev_test(arguments): osex.set_process_lowest_prio() from mainscripts import dev_misc - dev_misc.dev_test( arguments.input_dir ) + dev_misc.dev_gen_mask_files( arguments.input_dir ) p = subparsers.add_parser( "dev_test", help="") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") p.set_defaults (func=process_dev_test) - # ========== XSeg util - xseg_parser = subparsers.add_parser( "xseg", help="XSeg utils.").add_subparsers() - - def process_xseg_merge(arguments): - osex.set_process_lowest_prio() - from mainscripts import XSegUtil - XSegUtil.merge(arguments.input_dir) - p = xseg_parser.add_parser( "merge", help="") - p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") - - p.set_defaults (func=process_xseg_merge) + # ========== XSeg + xseg_parser = subparsers.add_parser( "xseg", help="XSeg tools.").add_subparsers() - def process_xseg_split(arguments): - osex.set_process_lowest_prio() - from mainscripts import XSegUtil - XSegUtil.split(arguments.input_dir) + p = xseg_parser.add_parser( "editor", help="XSeg editor.") - p = xseg_parser.add_parser( "split", help="") + def process_xsegeditor(arguments): + osex.set_process_lowest_prio() + from XSegEditor import XSegEditor + global exit_code + exit_code = XSegEditor.start (Path(arguments.input_dir)) + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") - p.set_defaults (func=process_xseg_split) + p.set_defaults (func=process_xsegeditor) + + p = xseg_parser.add_parser( "apply", help="Apply trained XSeg model to the extracted faces.") + def process_xsegapply(arguments): + osex.set_process_lowest_prio() + from mainscripts import XSegUtil + XSegUtil.apply_xseg (Path(arguments.input_dir), Path(arguments.model_dir)) + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") + p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir") + p.set_defaults (func=process_xsegapply) + + + p = xseg_parser.add_parser( "remove", help="Remove applied XSeg masks from the extracted faces.") + def process_xsegremove(arguments): + osex.set_process_lowest_prio() + from mainscripts import XSegUtil + XSegUtil.remove_xseg (Path(arguments.input_dir) ) + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") + p.set_defaults (func=process_xsegremove) + + + p = xseg_parser.add_parser( "remove_labels", help="Remove XSeg labels from the extracted faces.") + def process_xsegremovelabels(arguments): + osex.set_process_lowest_prio() + from mainscripts import XSegUtil + XSegUtil.remove_xseg_labels (Path(arguments.input_dir) ) + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") + p.set_defaults (func=process_xsegremovelabels) + + + p = xseg_parser.add_parser( "fetch", help="Copies faces containing XSeg polygons in _xseg dir.") + + def process_xsegfetch(arguments): + osex.set_process_lowest_prio() + from mainscripts import XSegUtil + XSegUtil.fetch_xseg (Path(arguments.input_dir) ) + 
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") + p.set_defaults (func=process_xsegfetch) + def bad_args(arguments): parser.print_help() exit(0) @@ -301,7 +347,10 @@ if __name__ == "__main__": arguments = parser.parse_args() arguments.func(arguments) - print ("Done.") + if exit_code == 0: + print ("Done.") + + exit(exit_code) ''' import code diff --git a/mainscripts/ExportDFM.py b/mainscripts/ExportDFM.py new file mode 100644 index 0000000..cf7d64e --- /dev/null +++ b/mainscripts/ExportDFM.py @@ -0,0 +1,22 @@ +import os +import sys +import traceback +import queue +import threading +import time +import numpy as np +import itertools +from pathlib import Path +from core import pathex +from core import imagelib +import cv2 +import models +from core.interact import interact as io + + +def main(model_class_name, saved_models_path): + model = models.import_model(model_class_name)( + is_exporting=True, + saved_models_path=saved_models_path, + cpu_only=True) + model.export_dfm () diff --git a/mainscripts/Extractor.py b/mainscripts/Extractor.py index c2855ac..365804f 100644 --- a/mainscripts/Extractor.py +++ b/mainscripts/Extractor.py @@ -10,11 +10,12 @@ from pathlib import Path import cv2 import numpy as np +from numpy import linalg as npla import facelib from core import imagelib from core import mathlib -from facelib import FaceType, LandmarksProcessor, TernausNet +from facelib import FaceType, LandmarksProcessor from core.interact import interact as io from core.joblib import Subprocessor from core.leras import nn @@ -43,6 +44,7 @@ class ExtractSubprocessor(Subprocessor): def on_initialize(self, client_dict): self.type = client_dict['type'] self.image_size = client_dict['image_size'] + self.jpeg_quality = client_dict['jpeg_quality'] self.face_type = client_dict['face_type'] self.max_faces_from_image = client_dict['max_faces_from_image'] self.device_idx = client_dict['device_idx'] @@ -71,7 +73,9 @@ class ExtractSubprocessor(Subprocessor): self.rects_extractor = facelib.S3FDExtractor(place_model_on_cpu=place_model_on_cpu) if self.type == 'all' or 'landmarks' in self.type: - self.landmarks_extractor = facelib.FANExtractor(place_model_on_cpu=place_model_on_cpu) + # for head type, extract "3D landmarks" + self.landmarks_extractor = facelib.FANExtractor(landmarks_3D=self.face_type >= FaceType.HEAD, + place_model_on_cpu=place_model_on_cpu) self.cached_image = (None, None) @@ -92,7 +96,6 @@ class ExtractSubprocessor(Subprocessor): self.cached_image = ( filepath, image ) h, w, c = image.shape - extract_from_dflimg = (h == w and DFLIMG.load (filepath) is not None) if 'rects' in self.type or self.type == 'all': data = ExtractSubprocessor.Cli.rects_stage (data=data, @@ -104,7 +107,6 @@ class ExtractSubprocessor(Subprocessor): if 'landmarks' in self.type or self.type == 'all': data = ExtractSubprocessor.Cli.landmarks_stage (data=data, image=image, - extract_from_dflimg=extract_from_dflimg, landmarks_extractor=self.landmarks_extractor, rects_extractor=self.rects_extractor, ) @@ -114,7 +116,7 @@ class ExtractSubprocessor(Subprocessor): image=image, face_type=self.face_type, image_size=self.image_size, - extract_from_dflimg=extract_from_dflimg, + jpeg_quality=self.jpeg_quality, output_debug_path=self.output_debug_path, final_output_path=self.final_output_path, ) @@ -144,7 +146,9 @@ class ExtractSubprocessor(Subprocessor): if len(rects) != 0: data.rects_rotation = rot break - if max_faces_from_image != 0 and len(data.rects) > 1: + if max_faces_from_image is not None and \ + 
max_faces_from_image > 0 and \ + len(data.rects) > 0: data.rects = data.rects[0:max_faces_from_image] return data @@ -152,7 +156,6 @@ class ExtractSubprocessor(Subprocessor): @staticmethod def landmarks_stage(data, image, - extract_from_dflimg, landmarks_extractor, rects_extractor, ): @@ -167,7 +170,7 @@ class ExtractSubprocessor(Subprocessor): elif data.rects_rotation == 270: rotated_image = image.swapaxes( 0,1 )[::-1,:,:] - data.landmarks = landmarks_extractor.extract (rotated_image, data.rects, rects_extractor if (not extract_from_dflimg and data.landmarks_accurate) else None, is_bgr=True) + data.landmarks = landmarks_extractor.extract (rotated_image, data.rects, rects_extractor if (data.landmarks_accurate) else None, is_bgr=True) if data.rects_rotation != 0: for i, (rect, lmrks) in enumerate(zip(data.rects, data.landmarks)): new_rect, new_lmrks = rect, lmrks @@ -197,7 +200,7 @@ class ExtractSubprocessor(Subprocessor): image, face_type, image_size, - extract_from_dflimg = False, + jpeg_quality, output_debug_path=None, final_output_path=None, ): @@ -209,71 +212,53 @@ class ExtractSubprocessor(Subprocessor): if output_debug_path is not None: debug_image = image.copy() - if extract_from_dflimg and len(rects) != 1: - #if re-extracting from dflimg and more than 1 or zero faces detected - dont process and just copy it - print("extract_from_dflimg and len(rects) != 1", filepath ) - output_filepath = final_output_path / filepath.name - if filepath != str(output_file): - shutil.copy ( str(filepath), str(output_filepath) ) - data.final_output_files.append (output_filepath) - else: - face_idx = 0 - for rect, image_landmarks in zip( rects, landmarks ): + face_idx = 0 + for rect, image_landmarks in zip( rects, landmarks ): + if image_landmarks is None: + continue - if extract_from_dflimg and face_idx > 1: - #cannot extract more than 1 face from dflimg - break + rect = np.array(rect) - if image_landmarks is None: + if face_type == FaceType.MARK_ONLY: + image_to_face_mat = None + face_image = image + face_image_landmarks = image_landmarks + else: + image_to_face_mat = LandmarksProcessor.get_transform_mat (image_landmarks, image_size, face_type) + + face_image = cv2.warpAffine(image, image_to_face_mat, (image_size, image_size), cv2.INTER_LANCZOS4) + face_image_landmarks = LandmarksProcessor.transform_points (image_landmarks, image_to_face_mat) + + landmarks_bbox = LandmarksProcessor.transform_points ( [ (0,0), (0,image_size-1), (image_size-1, image_size-1), (image_size-1,0) ], image_to_face_mat, True) + + rect_area = mathlib.polygon_area(np.array(rect[[0,2,2,0]]).astype(np.float32), np.array(rect[[1,1,3,3]]).astype(np.float32)) + landmarks_area = mathlib.polygon_area(landmarks_bbox[:,0].astype(np.float32), landmarks_bbox[:,1].astype(np.float32) ) + + if not data.manual and face_type <= FaceType.FULL_NO_ALIGN and landmarks_area > 4*rect_area: #get rid of faces which umeyama-landmark-area > 4*detector-rect-area continue - rect = np.array(rect) + if output_debug_path is not None: + LandmarksProcessor.draw_rect_landmarks (debug_image, rect, image_landmarks, face_type, image_size, transparent_mask=True) - if face_type == FaceType.MARK_ONLY: - image_to_face_mat = None - face_image = image - face_image_landmarks = image_landmarks - else: - image_to_face_mat = LandmarksProcessor.get_transform_mat (image_landmarks, image_size, face_type) + output_path = final_output_path + if data.force_output_path is not None: + output_path = data.force_output_path - face_image = cv2.warpAffine(image, image_to_face_mat, 
(image_size, image_size), cv2.INTER_LANCZOS4) - face_image_landmarks = LandmarksProcessor.transform_points (image_landmarks, image_to_face_mat) + output_filepath = output_path / f"{filepath.stem}_{face_idx}.jpg" + cv2_imwrite(output_filepath, face_image, [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality ] ) - landmarks_bbox = LandmarksProcessor.transform_points ( [ (0,0), (0,image_size-1), (image_size-1, image_size-1), (image_size-1,0) ], image_to_face_mat, True) + dflimg = DFLJPG.load(output_filepath) + dflimg.set_face_type(FaceType.toString(face_type)) + dflimg.set_landmarks(face_image_landmarks.tolist()) + dflimg.set_source_filename(filepath.name) + dflimg.set_source_rect(rect) + dflimg.set_source_landmarks(image_landmarks.tolist()) + dflimg.set_image_to_face_mat(image_to_face_mat) + dflimg.save() - rect_area = mathlib.polygon_area(np.array(rect[[0,2,2,0]]).astype(np.float32), np.array(rect[[1,1,3,3]]).astype(np.float32)) - landmarks_area = mathlib.polygon_area(landmarks_bbox[:,0].astype(np.float32), landmarks_bbox[:,1].astype(np.float32) ) - - if not data.manual and face_type <= FaceType.FULL_NO_ALIGN and landmarks_area > 4*rect_area: #get rid of faces which umeyama-landmark-area > 4*detector-rect-area - continue - - if output_debug_path is not None: - LandmarksProcessor.draw_rect_landmarks (debug_image, rect, image_landmarks, face_type, image_size, transparent_mask=True) - - output_path = final_output_path - if data.force_output_path is not None: - output_path = data.force_output_path - - if extract_from_dflimg and filepath.suffix == '.jpg': - #if extracting from dflimg and jpg copy it in order not to lose quality - output_filepath = output_path / filepath.name - if filepath != output_filepath: - shutil.copy ( str(filepath), str(output_filepath) ) - else: - output_filepath = output_path / f"{filepath.stem}_{face_idx}.jpg" - cv2_imwrite(output_filepath, face_image, [int(cv2.IMWRITE_JPEG_QUALITY), 90] ) - - DFLJPG.embed_data(output_filepath, face_type=FaceType.toString(face_type), - landmarks=face_image_landmarks.tolist(), - source_filename=filepath.name, - source_rect=rect, - source_landmarks=image_landmarks.tolist(), - image_to_face_mat=image_to_face_mat - ) - - data.final_output_files.append (output_filepath) - face_idx += 1 - data.faces_detected = face_idx + data.final_output_files.append (output_filepath) + face_idx += 1 + data.faces_detected = face_idx if output_debug_path is not None: cv2_imwrite( output_debug_path / (filepath.stem+'.jpg'), debug_image, [int(cv2.IMWRITE_JPEG_QUALITY), 50] ) @@ -297,16 +282,16 @@ class ExtractSubprocessor(Subprocessor): if not cpu_only: if type == 'landmarks-manual': devices = [devices.get_best_device()] - + result = [] - + for device in devices: count = 1 - + if count == 1: result += [ (device.index, 'GPU', device.name, device.total_mem_gb) ] else: - for i in range(count): + for i in range(count): result += [ (device.index, 'GPU', f"{device.name} #{i}", device.total_mem_gb) ] return result @@ -319,7 +304,7 @@ class ExtractSubprocessor(Subprocessor): elif type == 'final': return [ (i, 'CPU', 'CPU%d' % (i), 0 ) for i in (range(min(8, multiprocessing.cpu_count())) if not DEBUG else [0]) ] - def __init__(self, input_data, type, image_size=None, face_type=None, output_debug_path=None, manual_window_size=0, max_faces_from_image=0, final_output_path=None, device_config=None): + def __init__(self, input_data, type, image_size=None, jpeg_quality=None, face_type=None, output_debug_path=None, manual_window_size=0, max_faces_from_image=0, final_output_path=None, 
device_config=None): if type == 'landmarks-manual': for x in input_data: x.manual = True @@ -328,6 +313,7 @@ class ExtractSubprocessor(Subprocessor): self.type = type self.image_size = image_size + self.jpeg_quality = jpeg_quality self.face_type = face_type self.output_debug_path = output_debug_path self.final_output_path = final_output_path @@ -353,6 +339,7 @@ class ExtractSubprocessor(Subprocessor): self.cache_text_lines_img = (None, None) self.hide_help = False self.landmarks_accurate = True + self.force_landmarks = False self.landmarks = None self.x = 0 @@ -361,6 +348,9 @@ class ExtractSubprocessor(Subprocessor): self.rect_locked = False self.extract_needed = True + self.image = None + self.image_filepath = None + io.progress_bar (None, len (self.input_data)) #override @@ -374,6 +364,7 @@ class ExtractSubprocessor(Subprocessor): def process_info_generator(self): base_dict = {'type' : self.type, 'image_size': self.image_size, + 'jpeg_quality' : self.jpeg_quality, 'face_type': self.face_type, 'max_faces_from_image':self.max_faces_from_image, 'output_debug_path': self.output_debug_path, @@ -392,26 +383,13 @@ class ExtractSubprocessor(Subprocessor): def get_data(self, host_dict): if self.type == 'landmarks-manual': need_remark_face = False - redraw_needed = False while len (self.input_data) > 0: data = self.input_data[0] filepath, data_rects, data_landmarks = data.filepath, data.rects, data.landmarks is_frame_done = False - if need_remark_face: # need remark image from input data that already has a marked face? - need_remark_face = False - if len(data_rects) != 0: # If there was already a face then lock the rectangle to it until the mouse is clicked - self.rect = data_rects.pop() - self.landmarks = data_landmarks.pop() - data_rects.clear() - data_landmarks.clear() - redraw_needed = True - self.rect_locked = True - self.rect_size = ( self.rect[2] - self.rect[0] ) / 2 - self.x = ( self.rect[0] + self.rect[2] ) / 2 - self.y = ( self.rect[1] + self.rect[3] ) / 2 - - if len(data_rects) == 0: + if self.image_filepath != filepath: + self.image_filepath = filepath if self.cache_original_image[0] == filepath: self.original_image = self.cache_original_image[1] else: @@ -435,8 +413,8 @@ class ExtractSubprocessor(Subprocessor): self.text_lines_img = self.cache_text_lines_img[1] else: self.text_lines_img = (imagelib.get_draw_text_lines ( self.image, sh, - [ '[Mouse click] - lock/unlock selection', - '[Mouse wheel] - change rect', + [ '[L Mouse click] - lock/unlock selection. [Mouse wheel] - change rect', + '[R Mouse Click] - manual face rectangle', '[Enter] / [Space] - confirm / skip frame', '[,] [.]- prev frame, next frame. [Q] - skip remaining frames', '[a] - accuracy on/off (more fps)', @@ -445,11 +423,29 @@ class ExtractSubprocessor(Subprocessor): self.cache_text_lines_img = (sh, self.text_lines_img) + if need_remark_face: # need remark image from input data that already has a marked face? 
+ need_remark_face = False + if len(data_rects) != 0: # If there was already a face then lock the rectangle to it until the mouse is clicked + self.rect = data_rects.pop() + self.landmarks = data_landmarks.pop() + data_rects.clear() + data_landmarks.clear() + + self.rect_locked = True + self.rect_size = ( self.rect[2] - self.rect[0] ) / 2 + self.x = ( self.rect[0] + self.rect[2] ) / 2 + self.y = ( self.rect[1] + self.rect[3] ) / 2 + self.redraw() + + if len(data_rects) == 0: + (h,w,c) = self.image.shape while True: io.process_messages(0.0001) - new_x = self.x - new_y = self.y + if not self.force_landmarks: + new_x = self.x + new_y = self.y + new_rect_size = self.rect_size mouse_events = io.get_mouse_events(self.wnd_name) @@ -460,8 +456,19 @@ class ExtractSubprocessor(Subprocessor): diff = 1 if new_rect_size <= 40 else np.clip(new_rect_size / 10, 1, 10) new_rect_size = max (5, new_rect_size + diff*mod) elif ev == io.EVENT_LBUTTONDOWN: - self.rect_locked = not self.rect_locked - self.extract_needed = True + if self.force_landmarks: + self.x = new_x + self.y = new_y + self.force_landmarks = False + self.rect_locked = True + self.redraw() + else: + self.rect_locked = not self.rect_locked + self.extract_needed = True + elif ev == io.EVENT_RBUTTONDOWN: + self.force_landmarks = not self.force_landmarks + if self.force_landmarks: + self.rect_locked = False elif not self.rect_locked: new_x = np.clip (x, 0, w-1) / self.view_scale new_y = np.clip (y, 0, h-1) / self.view_scale @@ -527,11 +534,35 @@ class ExtractSubprocessor(Subprocessor): self.landmarks_accurate = not self.landmarks_accurate break - if self.x != new_x or \ + if self.force_landmarks: + pt2 = np.float32([new_x, new_y]) + pt1 = np.float32([self.x, self.y]) + + pt_vec_len = npla.norm(pt2-pt1) + pt_vec = pt2-pt1 + if pt_vec_len != 0: + pt_vec /= pt_vec_len + + self.rect_size = pt_vec_len + self.rect = ( int(self.x-self.rect_size), + int(self.y-self.rect_size), + int(self.x+self.rect_size), + int(self.y+self.rect_size) ) + + if pt_vec_len > 0: + lmrks = np.concatenate ( (np.zeros ((17,2), np.float32), LandmarksProcessor.landmarks_2D), axis=0 ) + lmrks -= lmrks[30:31,:] + mat = cv2.getRotationMatrix2D( (0, 0), -np.arctan2( pt_vec[1], pt_vec[0] )*180/math.pi , pt_vec_len) + mat[:, 2] += (self.x, self.y) + self.landmarks = LandmarksProcessor.transform_points(lmrks, mat ) + + + self.redraw() + + elif self.x != new_x or \ self.y != new_y or \ self.rect_size != new_rect_size or \ - self.extract_needed or \ - redraw_needed: + self.extract_needed: self.x = new_x self.y = new_y self.rect_size = new_rect_size @@ -540,11 +571,7 @@ class ExtractSubprocessor(Subprocessor): int(self.x+self.rect_size), int(self.y+self.rect_size) ) - if redraw_needed: - redraw_needed = False - return ExtractSubprocessor.Data (filepath, landmarks_accurate=self.landmarks_accurate) - else: - return ExtractSubprocessor.Data (filepath, rects=[self.rect], landmarks_accurate=self.landmarks_accurate) + return ExtractSubprocessor.Data (filepath, rects=[self.rect], landmarks_accurate=self.landmarks_accurate) else: is_frame_done = True @@ -566,6 +593,40 @@ class ExtractSubprocessor(Subprocessor): if not self.type != 'landmarks-manual': self.input_data.insert(0, data) + def redraw(self): + (h,w,c) = self.image.shape + + if not self.hide_help: + image = cv2.addWeighted (self.image,1.0,self.text_lines_img,1.0,0) + else: + image = self.image.copy() + + view_rect = (np.array(self.rect) * self.view_scale).astype(np.int).tolist() + view_landmarks = (np.array(self.landmarks) * 
self.view_scale).astype(np.int).tolist() + + if self.rect_size <= 40: + scaled_rect_size = h // 3 if w > h else w // 3 + + p1 = (self.x - self.rect_size, self.y - self.rect_size) + p2 = (self.x + self.rect_size, self.y - self.rect_size) + p3 = (self.x - self.rect_size, self.y + self.rect_size) + + wh = h if h < w else w + np1 = (w / 2 - wh / 4, h / 2 - wh / 4) + np2 = (w / 2 + wh / 4, h / 2 - wh / 4) + np3 = (w / 2 - wh / 4, h / 2 + wh / 4) + + mat = cv2.getAffineTransform( np.float32([p1,p2,p3])*self.view_scale, np.float32([np1,np2,np3]) ) + image = cv2.warpAffine(image, mat,(w,h) ) + view_landmarks = LandmarksProcessor.transform_points (view_landmarks, mat) + + landmarks_color = (255,255,0) if self.rect_locked else (0,255,0) + LandmarksProcessor.draw_rect_landmarks (image, view_rect, view_landmarks, self.face_type, self.image_size, landmarks_color=landmarks_color) + self.extract_needed = False + + io.show_image (self.wnd_name, image) + + #override def on_result (self, host_dict, data, result): if self.type == 'landmarks-manual': @@ -574,37 +635,7 @@ class ExtractSubprocessor(Subprocessor): if len(landmarks) != 0 and landmarks[0] is not None: self.landmarks = landmarks[0] - (h,w,c) = self.image.shape - - if not self.hide_help: - image = cv2.addWeighted (self.image,1.0,self.text_lines_img,1.0,0) - else: - image = self.image.copy() - - view_rect = (np.array(self.rect) * self.view_scale).astype(np.int).tolist() - view_landmarks = (np.array(self.landmarks) * self.view_scale).astype(np.int).tolist() - - if self.rect_size <= 40: - scaled_rect_size = h // 3 if w > h else w // 3 - - p1 = (self.x - self.rect_size, self.y - self.rect_size) - p2 = (self.x + self.rect_size, self.y - self.rect_size) - p3 = (self.x - self.rect_size, self.y + self.rect_size) - - wh = h if h < w else w - np1 = (w / 2 - wh / 4, h / 2 - wh / 4) - np2 = (w / 2 + wh / 4, h / 2 - wh / 4) - np3 = (w / 2 - wh / 4, h / 2 + wh / 4) - - mat = cv2.getAffineTransform( np.float32([p1,p2,p3])*self.view_scale, np.float32([np1,np2,np3]) ) - image = cv2.warpAffine(image, mat,(w,h) ) - view_landmarks = LandmarksProcessor.transform_points (view_landmarks, mat) - - landmarks_color = (255,255,0) if self.rect_locked else (0,255,0) - LandmarksProcessor.draw_rect_landmarks (image, view_rect, view_landmarks, self.face_type, self.image_size, landmarks_color=landmarks_color) - self.extract_needed = False - - io.show_image (self.wnd_name, image) + self.redraw() else: self.result.append ( result ) io.progress_bar_inc(1) @@ -681,43 +712,81 @@ def main(detector=None, manual_output_debug_fix=False, manual_window_size=1368, face_type='full_face', - max_faces_from_image=0, + max_faces_from_image=None, + image_size=None, + jpeg_quality=None, cpu_only = False, force_gpu_idxs = None, ): - face_type = FaceType.fromString(face_type) - image_size = 512 - if not input_path.exists(): io.log_err ('Input directory not found. 
Please ensure it exists.') return + if not output_path.exists(): + output_path.mkdir(parents=True, exist_ok=True) + + if face_type is not None: + face_type = FaceType.fromString(face_type) + + if face_type is None: + if manual_output_debug_fix: + files = pathex.get_image_paths(output_path) + if len(files) != 0: + dflimg = DFLIMG.load(Path(files[0])) + if dflimg is not None and dflimg.has_data(): + face_type = FaceType.fromString ( dflimg.get_face_type() ) + + input_image_paths = pathex.get_image_unique_filestem_paths(input_path, verbose_print_func=io.log_info) + output_images_paths = pathex.get_image_paths(output_path) + output_debug_path = output_path.parent / (output_path.name + '_debug') + + continue_extraction = False + if not manual_output_debug_fix and len(output_images_paths) > 0: + if len(output_images_paths) > 128: + continue_extraction = io.input_bool ("Continue extraction?", True, help_message="Extraction can be continued, but you must specify the same options again.") + + if len(output_images_paths) > 128 and continue_extraction: + try: + input_image_paths = input_image_paths[ [ Path(x).stem for x in input_image_paths ].index ( Path(output_images_paths[-128]).stem.split('_')[0] ) : ] + except: + io.log_err("Error fetching the last index. Extraction cannot be continued.") + return + elif input_path != output_path: + io.input(f"\n WARNING !!! \n {output_path} contains files! \n They will be deleted. \n Press enter to continue.\n") + for filename in output_images_paths: + Path(filename).unlink() + + device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(choose_only_one=detector=='manual', suggest_all_gpu=True) ) \ + if not cpu_only else nn.DeviceConfig.CPU() + + if face_type is None: + face_type = io.input_str ("Face type", 'wf', ['f','wf','head'], help_message="Full face / whole face / head. 'Whole face' covers the full area of the face including the forehead. 'head' covers the full head, but requires XSeg for the src and dst facesets.").lower() + face_type = {'f' : FaceType.FULL, + 'wf' : FaceType.WHOLE_FACE, + 'head' : FaceType.HEAD}[face_type] + + if max_faces_from_image is None: + max_faces_from_image = io.input_int(f"Max number of faces from image", 0, help_message="If you extract a src faceset that has frames with a large number of faces, it is advisable to set max faces to 3 to speed up extraction. 0 - unlimited") + + if image_size is None: + image_size = io.input_int(f"Image size", 512 if face_type < FaceType.HEAD else 768, valid_range=[256,2048], help_message="Output image size. The higher the image size, the worse the face enhancer works. Use a value higher than 512 only if the source image is sharp enough and the face does not need to be enhanced.") + + if jpeg_quality is None: + jpeg_quality = io.input_int(f"Jpeg quality", 90, valid_range=[1,100], help_message="Jpeg quality. The higher the jpeg quality, the larger the output file size.")
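# A minimal sketch of what the "Continue extraction?" branch above does, assuming
# extracted faces are named "<source_frame_stem>_<face_index>.jpg" (the naming that
# the .stem.split('_')[0] call relies on): find the source frame of the 128th-newest
# output face and restart the input list from that frame, so the tail of the faceset
# is re-extracted with the same options. Names below are illustrative only.
from pathlib import Path

def rewind_for_continuation(input_image_paths, output_images_paths, overlap=128):
    # stem of the frame that produced the overlap-th newest extracted face
    last_frame_stem = Path(output_images_paths[-overlap]).stem.split('_')[0]
    # position of that frame in the ordered input frame list
    idx = [Path(p).stem for p in input_image_paths].index(last_frame_stem)
    # keep that frame and everything after it for re-extraction
    return input_image_paths[idx:]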
+ if detector is None: io.log_info ("Choose detector type.") io.log_info ("[0] S3FD") io.log_info ("[1] manual") detector = {0:'s3fd', 1:'manual'}[ io.input_int("", 0, [0,1]) ] - device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(choose_only_one=detector=='manual', suggest_all_gpu=True) ) \ - if not cpu_only else nn.DeviceConfig.CPU() - - output_debug_path = output_path.parent / (output_path.name + '_debug') if output_debug is None: output_debug = io.input_bool (f"Write debug images to {output_debug_path.name}?", False) - if output_path.exists(): - if not manual_output_debug_fix and input_path != output_path: - output_images_paths = pathex.get_image_paths(output_path) - if len(output_images_paths) > 0: - io.input(f"\n WARNING !!! \n {output_path} contains files! \n They will be deleted. \n Press enter to continue.\n") - for filename in output_images_paths: - Path(filename).unlink() - else: - output_path.mkdir(parents=True, exist_ok=True) - - input_path_image_paths = pathex.get_image_unique_filestem_paths(input_path, verbose_print_func=io.log_info) + if output_debug: + output_debug_path.mkdir(parents=True, exist_ok=True) if manual_output_debug_fix: if not output_debug_path.exists(): @@ -727,31 +796,30 @@ def main(detector=None, detector = 'manual' io.log_info('Performing re-extract frames which were deleted from _debug directory.') - input_path_image_paths = DeletedFilesSearcherSubprocessor (input_path_image_paths, pathex.get_image_paths(output_debug_path) ).run() - input_path_image_paths = sorted (input_path_image_paths) - io.log_info('Found %d images.' % (len(input_path_image_paths))) + input_image_paths = DeletedFilesSearcherSubprocessor (input_image_paths, pathex.get_image_paths(output_debug_path) ).run() + input_image_paths = sorted (input_image_paths) + io.log_info('Found %d images.' 
% (len(input_image_paths))) else: - if output_debug_path.exists(): + if not continue_extraction and output_debug_path.exists(): for filename in pathex.get_image_paths(output_debug_path): Path(filename).unlink() - else: - output_debug_path.mkdir(parents=True, exist_ok=True) - images_found = len(input_path_image_paths) + images_found = len(input_image_paths) faces_detected = 0 if images_found != 0: if detector == 'manual': io.log_info ('Performing manual extract...') - data = ExtractSubprocessor ([ ExtractSubprocessor.Data(Path(filename)) for filename in input_path_image_paths ], 'landmarks-manual', image_size, face_type, output_debug_path if output_debug else None, manual_window_size=manual_window_size, device_config=device_config).run() + data = ExtractSubprocessor ([ ExtractSubprocessor.Data(Path(filename)) for filename in input_image_paths ], 'landmarks-manual', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None, manual_window_size=manual_window_size, device_config=device_config).run() io.log_info ('Performing 3rd pass...') - data = ExtractSubprocessor (data, 'final', image_size, face_type, output_debug_path if output_debug else None, final_output_path=output_path, device_config=device_config).run() + data = ExtractSubprocessor (data, 'final', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None, final_output_path=output_path, device_config=device_config).run() else: io.log_info ('Extracting faces...') - data = ExtractSubprocessor ([ ExtractSubprocessor.Data(Path(filename)) for filename in input_path_image_paths ], + data = ExtractSubprocessor ([ ExtractSubprocessor.Data(Path(filename)) for filename in input_image_paths ], 'all', image_size, + jpeg_quality, face_type, output_debug_path if output_debug else None, max_faces_from_image=max_faces_from_image, @@ -766,8 +834,8 @@ def main(detector=None, else: fix_data = [ ExtractSubprocessor.Data(d.filepath) for d in data if d.faces_detected == 0 ] io.log_info ('Performing manual fix for %d images...' 
% (len(fix_data)) ) - fix_data = ExtractSubprocessor (fix_data, 'landmarks-manual', image_size, face_type, output_debug_path if output_debug else None, manual_window_size=manual_window_size, device_config=device_config).run() - fix_data = ExtractSubprocessor (fix_data, 'final', image_size, face_type, output_debug_path if output_debug else None, final_output_path=output_path, device_config=device_config).run() + fix_data = ExtractSubprocessor (fix_data, 'landmarks-manual', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None, manual_window_size=manual_window_size, device_config=device_config).run() + fix_data = ExtractSubprocessor (fix_data, 'final', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None, final_output_path=output_path, device_config=device_config).run() faces_detected += sum([d.faces_detected for d in fix_data]) diff --git a/mainscripts/FacesetEnhancer.py b/mainscripts/FacesetEnhancer.py index 5b5433f..3de9cea 100644 --- a/mainscripts/FacesetEnhancer.py +++ b/mainscripts/FacesetEnhancer.py @@ -99,19 +99,23 @@ class FacesetEnhancerSubprocessor(Subprocessor): def process_data(self, filepath): try: dflimg = DFLIMG.load (filepath) - if dflimg is None: - self.log_err ("%s is not a dfl image file" % (filepath.name) ) + if dflimg is None or not dflimg.has_data(): + self.log_err (f"{filepath.name} is not a dfl image file") else: + dfl_dict = dflimg.get_dict() + img = cv2_imread(filepath).astype(np.float32) / 255.0 - img = self.fe.enhance(img) - img = np.clip (img*255, 0, 255).astype(np.uint8) output_filepath = self.output_dirpath / filepath.name cv2_imwrite ( str(output_filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] ) - dflimg.embed_and_set ( str(output_filepath) ) + + dflimg = DFLIMG.load (output_filepath) + dflimg.set_dict(dfl_dict) + dflimg.save() + return (1, filepath, output_filepath) except: self.log_err (f"Exception occured while processing file {filepath}. 
Error: {traceback.format_exc()}") diff --git a/mainscripts/FacesetResizer.py b/mainscripts/FacesetResizer.py new file mode 100644 index 0000000..4bcd1b8 --- /dev/null +++ b/mainscripts/FacesetResizer.py @@ -0,0 +1,209 @@ +import multiprocessing +import shutil + +import cv2 +from core import pathex +from core.cv2ex import * +from core.interact import interact as io +from core.joblib import Subprocessor +from DFLIMG import * +from facelib import FaceType, LandmarksProcessor + + +class FacesetResizerSubprocessor(Subprocessor): + + #override + def __init__(self, image_paths, output_dirpath, image_size, face_type=None): + self.image_paths = image_paths + self.output_dirpath = output_dirpath + self.image_size = image_size + self.face_type = face_type + self.result = [] + + super().__init__('FacesetResizer', FacesetResizerSubprocessor.Cli, 600) + + #override + def on_clients_initialized(self): + io.progress_bar (None, len (self.image_paths)) + + #override + def on_clients_finalized(self): + io.progress_bar_close() + + #override + def process_info_generator(self): + base_dict = {'output_dirpath':self.output_dirpath, 'image_size':self.image_size, 'face_type':self.face_type} + + for device_idx in range( min(8, multiprocessing.cpu_count()) ): + client_dict = base_dict.copy() + device_name = f'CPU #{device_idx}' + client_dict['device_name'] = device_name + yield device_name, {}, client_dict + + #override + def get_data(self, host_dict): + if len (self.image_paths) > 0: + return self.image_paths.pop(0) + + #override + def on_data_return (self, host_dict, data): + self.image_paths.insert(0, data) + + #override + def on_result (self, host_dict, data, result): + io.progress_bar_inc(1) + if result[0] == 1: + self.result +=[ (result[1], result[2]) ] + + #override + def get_result(self): + return self.result + + class Cli(Subprocessor.Cli): + + #override + def on_initialize(self, client_dict): + self.output_dirpath = client_dict['output_dirpath'] + self.image_size = client_dict['image_size'] + self.face_type = client_dict['face_type'] + self.log_info (f"Running on { client_dict['device_name'] }") + + #override + def process_data(self, filepath): + try: + dflimg = DFLIMG.load (filepath) + if dflimg is None or not dflimg.has_data(): + self.log_err (f"{filepath.name} is not a dfl image file") + else: + img = cv2_imread(filepath) + h,w = img.shape[:2] + if h != w: + raise Exception(f'w != h in {filepath}') + + image_size = self.image_size + face_type = self.face_type + output_filepath = self.output_dirpath / filepath.name + + if face_type is not None: + lmrks = dflimg.get_landmarks() + mat = LandmarksProcessor.get_transform_mat(lmrks, image_size, face_type) + + img = cv2.warpAffine(img, mat, (image_size, image_size), flags=cv2.INTER_LANCZOS4 ) + img = np.clip(img, 0, 255).astype(np.uint8) + + cv2_imwrite ( str(output_filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] ) + + dfl_dict = dflimg.get_dict() + dflimg = DFLIMG.load (output_filepath) + dflimg.set_dict(dfl_dict) + + xseg_mask = dflimg.get_xseg_mask() + if xseg_mask is not None: + xseg_res = 256 + + xseg_lmrks = lmrks.copy() + xseg_lmrks *= (xseg_res / w) + xseg_mat = LandmarksProcessor.get_transform_mat(xseg_lmrks, xseg_res, face_type) + + xseg_mask = cv2.warpAffine(xseg_mask, xseg_mat, (xseg_res, xseg_res), flags=cv2.INTER_LANCZOS4 ) + xseg_mask[xseg_mask < 0.5] = 0 + xseg_mask[xseg_mask >= 0.5] = 1 + + dflimg.set_xseg_mask(xseg_mask) + + seg_ie_polys = dflimg.get_seg_ie_polys() + + for poly in seg_ie_polys.get_polys(): + poly_pts = poly.get_pts() + 
poly_pts = LandmarksProcessor.transform_points(poly_pts, mat) + poly.set_points(poly_pts) + + dflimg.set_seg_ie_polys(seg_ie_polys) + + lmrks = LandmarksProcessor.transform_points(lmrks, mat) + dflimg.set_landmarks(lmrks) + + image_to_face_mat = dflimg.get_image_to_face_mat() + if image_to_face_mat is not None: + image_to_face_mat = LandmarksProcessor.get_transform_mat ( dflimg.get_source_landmarks(), image_size, face_type ) + dflimg.set_image_to_face_mat(image_to_face_mat) + dflimg.set_face_type( FaceType.toString(face_type) ) + dflimg.save() + + else: + dfl_dict = dflimg.get_dict() + + scale = w / image_size + + img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_LANCZOS4) + + cv2_imwrite ( str(output_filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] ) + + dflimg = DFLIMG.load (output_filepath) + dflimg.set_dict(dfl_dict) + + lmrks = dflimg.get_landmarks() + lmrks /= scale + dflimg.set_landmarks(lmrks) + + seg_ie_polys = dflimg.get_seg_ie_polys() + seg_ie_polys.mult_points( 1.0 / scale) + dflimg.set_seg_ie_polys(seg_ie_polys) + + image_to_face_mat = dflimg.get_image_to_face_mat() + + if image_to_face_mat is not None: + face_type = FaceType.fromString ( dflimg.get_face_type() ) + image_to_face_mat = LandmarksProcessor.get_transform_mat ( dflimg.get_source_landmarks(), image_size, face_type ) + dflimg.set_image_to_face_mat(image_to_face_mat) + dflimg.save() + + return (1, filepath, output_filepath) + except: + self.log_err (f"Exception occured while processing file {filepath}. Error: {traceback.format_exc()}") + + return (0, filepath, None) + +def process_folder ( dirpath): + + image_size = io.input_int(f"New image size", 512, valid_range=[128,2048]) + + face_type = io.input_str ("Change face type", 'same', ['h','mf','f','wf','head','same']).lower() + if face_type == 'same': + face_type = None + else: + face_type = {'h' : FaceType.HALF, + 'mf' : FaceType.MID_FULL, + 'f' : FaceType.FULL, + 'wf' : FaceType.WHOLE_FACE, + 'head' : FaceType.HEAD}[face_type] + + + output_dirpath = dirpath.parent / (dirpath.name + '_resized') + output_dirpath.mkdir (exist_ok=True, parents=True) + + dirpath_parts = '/'.join( dirpath.parts[-2:]) + output_dirpath_parts = '/'.join( output_dirpath.parts[-2:] ) + io.log_info (f"Resizing faceset in {dirpath_parts}") + io.log_info ( f"Processing to {output_dirpath_parts}") + + output_images_paths = pathex.get_image_paths(output_dirpath) + if len(output_images_paths) > 0: + for filename in output_images_paths: + Path(filename).unlink() + + image_paths = [Path(x) for x in pathex.get_image_paths( dirpath )] + result = FacesetResizerSubprocessor ( image_paths, output_dirpath, image_size, face_type).run() + + is_merge = io.input_bool (f"\r\nMerge {output_dirpath_parts} to {dirpath_parts} ?", True) + if is_merge: + io.log_info (f"Copying processed files to {dirpath_parts}") + + for (filepath, output_filepath) in result: + try: + shutil.copy (output_filepath, filepath) + except: + pass + + io.log_info (f"Removing {output_dirpath_parts}") + shutil.rmtree(output_dirpath) diff --git a/mainscripts/MaskEditorTool.py b/mainscripts/MaskEditorTool.py deleted file mode 100644 index 69aaacb..0000000 --- a/mainscripts/MaskEditorTool.py +++ /dev/null @@ -1,570 +0,0 @@ -import os -import sys -import time -import traceback -from pathlib import Path - -import cv2 -import numpy as np -import numpy.linalg as npl - -from core import imagelib -from DFLIMG import * -from facelib import LandmarksProcessor -from core.imagelib import IEPolys -from core.interact import interact 
as io -from core import pathex -from core.cv2ex import * - - -class MaskEditor: - STATE_NONE=0 - STATE_MASKING=1 - - def __init__(self, img, prev_images, next_images, mask=None, ie_polys=None, get_status_lines_func=None): - self.img = imagelib.normalize_channels (img,3) - h, w, c = img.shape - - if h != w and w != 256: - #to support any square res, scale img,mask and ie_polys to 256, then scale ie_polys back on .get_ie_polys() - raise Exception ("MaskEditor does not support image size != 256x256") - - ph, pw = h // 4, w // 4 #pad wh - - self.prev_images = prev_images - self.next_images = next_images - - if mask is not None: - self.mask = imagelib.normalize_channels (mask,3) - else: - self.mask = np.zeros ( (h,w,3) ) - self.get_status_lines_func = get_status_lines_func - - self.state_prop = self.STATE_NONE - - self.w, self.h = w, h - self.pw, self.ph = pw, ph - self.pwh = np.array([self.pw, self.ph]) - self.pwh2 = np.array([self.pw*2, self.ph*2]) - self.sw, self.sh = w+pw*2, h+ph*2 - self.prwh = 64 #preview wh - - if ie_polys is None: - ie_polys = IEPolys() - self.ie_polys = ie_polys - - self.polys_mask = None - self.preview_images = None - - self.mouse_x = self.mouse_y = 9999 - self.screen_status_block = None - self.screen_status_block_dirty = True - self.screen_changed = True - - def set_state(self, state): - self.state = state - - @property - def state(self): - return self.state_prop - - @state.setter - def state(self, value): - self.state_prop = value - if value == self.STATE_MASKING: - self.ie_polys.dirty = True - - def get_mask(self): - if self.ie_polys.switch_dirty(): - self.screen_status_block_dirty = True - self.ie_mask = img = self.mask.copy() - - self.ie_polys.overlay_mask(img) - - return img - return self.ie_mask - - def get_screen_overlay(self): - img = np.zeros ( (self.sh, self.sw, 3) ) - - if self.state == self.STATE_MASKING: - mouse_xy = self.mouse_xy.copy() + self.pwh - l = self.ie_polys.n_list() - if l.n > 0: - p = l.cur_point().copy() + self.pwh - color = (0,1,0) if l.type == 1 else (0,0,1) - cv2.line(img, tuple(p), tuple(mouse_xy), color ) - - return img - - def undo_to_begin_point(self): - while not self.undo_point(): - pass - - def undo_point(self): - self.screen_changed = True - if self.state == self.STATE_NONE: - if self.ie_polys.n > 0: - self.state = self.STATE_MASKING - - if self.state == self.STATE_MASKING: - if self.ie_polys.n_list().n_dec() == 0 and \ - self.ie_polys.n_dec() == 0: - self.state = self.STATE_NONE - else: - return False - - return True - - def redo_to_end_point(self): - while not self.redo_point(): - pass - - def redo_point(self): - self.screen_changed = True - if self.state == self.STATE_NONE: - if self.ie_polys.n_max > 0: - self.state = self.STATE_MASKING - if self.ie_polys.n == 0: - self.ie_polys.n_inc() - - if self.state == self.STATE_MASKING: - while True: - l = self.ie_polys.n_list() - if l.n_inc() == l.n_max: - if self.ie_polys.n == self.ie_polys.n_max: - break - self.ie_polys.n_inc() - else: - return False - - return True - - def combine_screens(self, screens): - - screens_len = len(screens) - - new_screens = [] - for screen, padded_overlay in screens: - screen_img = np.zeros( (self.sh, self.sw, 3), dtype=np.float32 ) - - screen = imagelib.normalize_channels (screen, 3) - h,w,c = screen.shape - - screen_img[self.ph:-self.ph, self.pw:-self.pw, :] = screen - - if padded_overlay is not None: - screen_img = screen_img + padded_overlay - - screen_img = np.clip(screen_img*255, 0, 255).astype(np.uint8) - new_screens.append(screen_img) - - return 
np.concatenate (new_screens, axis=1) - - def get_screen_status_block(self, w, c): - if self.screen_status_block_dirty: - self.screen_status_block_dirty = False - lines = [ - 'Polys current/max = %d/%d' % (self.ie_polys.n, self.ie_polys.n_max), - ] - if self.get_status_lines_func is not None: - lines += self.get_status_lines_func() - - lines_count = len(lines) - - - h_line = 21 - h = lines_count * h_line - img = np.ones ( (h,w,c) ) * 0.1 - - for i in range(lines_count): - img[ i*h_line:(i+1)*h_line, 0:w] += \ - imagelib.get_text_image ( (h_line,w,c), lines[i], color=[0.8]*c ) - - self.screen_status_block = np.clip(img*255, 0, 255).astype(np.uint8) - - return self.screen_status_block - - def set_screen_status_block_dirty(self): - self.screen_status_block_dirty = True - - def set_screen_changed(self): - self.screen_changed = True - - def switch_screen_changed(self): - result = self.screen_changed - self.screen_changed = False - return result - - def make_screen(self): - screen_overlay = self.get_screen_overlay() - final_mask = self.get_mask() - - masked_img = self.img*final_mask*0.5 + self.img*(1-final_mask) - - pink = np.full ( (self.h, self.w, 3), (1,0,1) ) - pink_masked_img = self.img*final_mask + pink*(1-final_mask) - - - - - screens = [ (self.img, screen_overlay), - (masked_img, screen_overlay), - (pink_masked_img, screen_overlay), - ] - screens = self.combine_screens(screens) - - if self.preview_images is None: - sh,sw,sc = screens.shape - - prh, prw = self.prwh, self.prwh - - total_w = sum ([ img.shape[1] for (t,img) in self.prev_images ]) + \ - sum ([ img.shape[1] for (t,img) in self.next_images ]) - - total_images_len = len(self.prev_images) + len(self.next_images) - - max_hor_images_count = sw // prw - max_side_images_count = (max_hor_images_count - 1) // 2 - - prev_images = self.prev_images[-max_side_images_count:] - next_images = self.next_images[:max_side_images_count] - - border = 2 - - max_wh_bordered = (prw-border*2, prh-border*2) - - prev_images = [ (t, cv2.resize( imagelib.normalize_channels(img, 3), max_wh_bordered )) for t,img in prev_images ] - next_images = [ (t, cv2.resize( imagelib.normalize_channels(img, 3), max_wh_bordered )) for t,img in next_images ] - - for images in [prev_images, next_images]: - for i, (t, img) in enumerate(images): - new_img = np.zeros ( (prh,prw, sc) ) - new_img[border:-border,border:-border] = img - - if t == 2: - cv2.line (new_img, ( prw//2, int(prh//1.5) ), (int(prw/1.5), prh ) , (0,1,0), thickness=2 ) - cv2.line (new_img, ( int(prw/1.5), prh ), ( prw, prh // 2 ) , (0,1,0), thickness=2 ) - elif t == 1: - cv2.line (new_img, ( prw//2, prh//2 ), ( prw, prh ) , (0,0,1), thickness=2 ) - cv2.line (new_img, ( prw//2, prh ), ( prw, prh // 2 ) , (0,0,1), thickness=2 ) - - images[i] = new_img - - - preview_images = [] - if len(prev_images) > 0: - preview_images += [ np.concatenate (prev_images, axis=1) ] - - img = np.full ( (prh,prw, sc), (0,0,1), dtype=np.float ) - img[border:-border,border:-border] = cv2.resize( self.img, max_wh_bordered ) - - preview_images += [ img ] - - if len(next_images) > 0: - preview_images += [ np.concatenate (next_images, axis=1) ] - - preview_images = np.concatenate ( preview_images, axis=1 ) - - left_pad = sw // 2 - len(prev_images) * prw - prw // 2 - right_pad = sw // 2 - len(next_images) * prw - prw // 2 - - preview_images = np.concatenate ([np.zeros ( (preview_images.shape[0], left_pad, preview_images.shape[2]) ), - preview_images, - np.zeros ( (preview_images.shape[0], right_pad, preview_images.shape[2]) ) - ], 
axis=1) - self.preview_images = np.clip(preview_images * 255, 0, 255 ).astype(np.uint8) - - status_img = self.get_screen_status_block( screens.shape[1], screens.shape[2] ) - - result = np.concatenate ( [self.preview_images, screens, status_img], axis=0 ) - - return result - - def mask_finish(self, n_clip=True): - if self.state == self.STATE_MASKING: - self.screen_changed = True - if self.ie_polys.n_list().n <= 2: - self.ie_polys.n_dec() - self.state = self.STATE_NONE - if n_clip: - self.ie_polys.n_clip() - - def set_mouse_pos(self,x,y): - if self.preview_images is not None: - y -= self.preview_images.shape[0] - - mouse_x = x % (self.sw) - self.pw - mouse_y = y % (self.sh) - self.ph - - - - if mouse_x != self.mouse_x or mouse_y != self.mouse_y: - self.mouse_xy = np.array( [mouse_x, mouse_y] ) - self.mouse_x, self.mouse_y = self.mouse_xy - self.screen_changed = True - - def mask_point(self, type): - self.screen_changed = True - if self.state == self.STATE_MASKING and \ - self.ie_polys.n_list().type != type: - self.mask_finish() - - elif self.state == self.STATE_NONE: - self.state = self.STATE_MASKING - self.ie_polys.add(type) - - if self.state == self.STATE_MASKING: - self.ie_polys.n_list().add (self.mouse_x, self.mouse_y) - - def get_ie_polys(self): - return self.ie_polys - - def set_ie_polys(self, saved_ie_polys): - self.state = self.STATE_NONE - self.ie_polys = saved_ie_polys - self.redo_to_end_point() - self.mask_finish() - - -def mask_editor_main(input_dir, confirmed_dir=None, skipped_dir=None, no_default_mask=False): - input_path = Path(input_dir) - - confirmed_path = Path(confirmed_dir) - skipped_path = Path(skipped_dir) - - if not input_path.exists(): - raise ValueError('Input directory not found. Please ensure it exists.') - - if not confirmed_path.exists(): - confirmed_path.mkdir(parents=True) - - if not skipped_path.exists(): - skipped_path.mkdir(parents=True) - - if not no_default_mask: - eyebrows_expand_mod = np.clip ( io.input_int ("Default eyebrows expand modifier?", 100, add_info="0..400"), 0, 400 ) / 100.0 - else: - eyebrows_expand_mod = None - - wnd_name = "MaskEditor tool" - io.named_window (wnd_name) - io.capture_mouse(wnd_name) - io.capture_keys(wnd_name) - - cached_images = {} - - image_paths = [ Path(x) for x in pathex.get_image_paths(input_path)] - done_paths = [] - done_images_types = {} - image_paths_total = len(image_paths) - saved_ie_polys = IEPolys() - zoom_factor = 1.0 - preview_images_count = 9 - target_wh = 256 - - do_prev_count = 0 - do_save_move_count = 0 - do_save_count = 0 - do_skip_move_count = 0 - do_skip_count = 0 - - def jobs_count(): - return do_prev_count + do_save_move_count + do_save_count + do_skip_move_count + do_skip_count - - is_exit = False - while not is_exit: - - if len(image_paths) > 0: - filepath = image_paths.pop(0) - else: - filepath = None - - next_image_paths = image_paths[0:preview_images_count] - next_image_paths_names = [ path.name for path in next_image_paths ] - prev_image_paths = done_paths[-preview_images_count:] - prev_image_paths_names = [ path.name for path in prev_image_paths ] - - for key in list( cached_images.keys() ): - if key not in prev_image_paths_names and \ - key not in next_image_paths_names: - cached_images.pop(key) - - for paths in [prev_image_paths, next_image_paths]: - for path in paths: - if path.name not in cached_images: - cached_images[path.name] = cv2_imread(str(path)) / 255.0 - - if filepath is not None: - dflimg = DFLIMG.load (filepath) - - if dflimg is None: - io.log_err ("%s is not a dfl image file" % 
(filepath.name) ) - continue - else: - lmrks = dflimg.get_landmarks() - ie_polys = IEPolys.load(dflimg.get_ie_polys()) - fanseg_mask = dflimg.get_fanseg_mask() - - if filepath.name in cached_images: - img = cached_images[filepath.name] - else: - img = cached_images[filepath.name] = cv2_imread(str(filepath)) / 255.0 - - if fanseg_mask is not None: - mask = fanseg_mask - else: - if no_default_mask: - mask = np.zeros ( (target_wh,target_wh,3) ) - else: - mask = LandmarksProcessor.get_image_hull_mask( img.shape, lmrks, eyebrows_expand_mod=eyebrows_expand_mod) - else: - img = np.zeros ( (target_wh,target_wh,3) ) - mask = np.ones ( (target_wh,target_wh,3) ) - ie_polys = None - - def get_status_lines_func(): - return ['Progress: %d / %d . Current file: %s' % (len(done_paths), image_paths_total, str(filepath.name) if filepath is not None else "end" ), - '[Left mouse button] - mark include mask.', - '[Right mouse button] - mark exclude mask.', - '[Middle mouse button] - finish current poly.', - '[Mouse wheel] - undo/redo poly or point. [+ctrl] - undo to begin/redo to end', - '[r] - applies edits made to last saved image.', - '[q] - prev image. [w] - skip and move to %s. [e] - save and move to %s. ' % (skipped_path.name, confirmed_path.name), - '[z] - prev image. [x] - skip. [c] - save. ', - 'hold [shift] - speed up the frame counter by 10.', - '[-/+] - window zoom [esc] - quit', - ] - - try: - ed = MaskEditor(img, - [ (done_images_types[name], cached_images[name]) for name in prev_image_paths_names ], - [ (0, cached_images[name]) for name in next_image_paths_names ], - mask, ie_polys, get_status_lines_func) - except Exception as e: - print(e) - continue - - next = False - while not next: - io.process_messages(0.005) - - if jobs_count() == 0: - for (x,y,ev,flags) in io.get_mouse_events(wnd_name): - x, y = int (x / zoom_factor), int(y / zoom_factor) - ed.set_mouse_pos(x, y) - if filepath is not None: - if ev == io.EVENT_LBUTTONDOWN: - ed.mask_point(1) - elif ev == io.EVENT_RBUTTONDOWN: - ed.mask_point(0) - elif ev == io.EVENT_MBUTTONDOWN: - ed.mask_finish() - elif ev == io.EVENT_MOUSEWHEEL: - if flags & 0x80000000 != 0: - if flags & 0x8 != 0: - ed.undo_to_begin_point() - else: - ed.undo_point() - else: - if flags & 0x8 != 0: - ed.redo_to_end_point() - else: - ed.redo_point() - - for key, chr_key, ctrl_pressed, alt_pressed, shift_pressed in io.get_key_events(wnd_name): - if chr_key == 'q' or chr_key == 'z': - do_prev_count = 1 if not shift_pressed else 10 - elif chr_key == '-': - zoom_factor = np.clip (zoom_factor-0.1, 0.1, 4.0) - ed.set_screen_changed() - elif chr_key == '+': - zoom_factor = np.clip (zoom_factor+0.1, 0.1, 4.0) - ed.set_screen_changed() - elif key == 27: #esc - is_exit = True - next = True - break - elif filepath is not None: - if chr_key == 'e': - saved_ie_polys = ed.ie_polys - do_save_move_count = 1 if not shift_pressed else 10 - elif chr_key == 'c': - saved_ie_polys = ed.ie_polys - do_save_count = 1 if not shift_pressed else 10 - elif chr_key == 'w': - do_skip_move_count = 1 if not shift_pressed else 10 - elif chr_key == 'x': - do_skip_count = 1 if not shift_pressed else 10 - elif chr_key == 'r' and saved_ie_polys != None: - ed.set_ie_polys(saved_ie_polys) - - if do_prev_count > 0: - do_prev_count -= 1 - if len(done_paths) > 0: - if filepath is not None: - image_paths.insert(0, filepath) - - filepath = done_paths.pop(-1) - done_images_types[filepath.name] = 0 - - if filepath.parent != input_path: - new_filename_path = input_path / filepath.name - filepath.rename ( new_filename_path 
) - image_paths.insert(0, new_filename_path) - else: - image_paths.insert(0, filepath) - - next = True - elif filepath is not None: - if do_save_move_count > 0: - do_save_move_count -= 1 - - ed.mask_finish() - dflimg.embed_and_set (str(filepath), ie_polys=ed.get_ie_polys(), eyebrows_expand_mod=eyebrows_expand_mod ) - - done_paths += [ confirmed_path / filepath.name ] - done_images_types[filepath.name] = 2 - filepath.rename(done_paths[-1]) - - next = True - elif do_save_count > 0: - do_save_count -= 1 - - ed.mask_finish() - dflimg.embed_and_set (str(filepath), ie_polys=ed.get_ie_polys(), eyebrows_expand_mod=eyebrows_expand_mod ) - - done_paths += [ filepath ] - done_images_types[filepath.name] = 2 - - next = True - elif do_skip_move_count > 0: - do_skip_move_count -= 1 - - done_paths += [ skipped_path / filepath.name ] - done_images_types[filepath.name] = 1 - filepath.rename(done_paths[-1]) - - next = True - elif do_skip_count > 0: - do_skip_count -= 1 - - done_paths += [ filepath ] - done_images_types[filepath.name] = 1 - - next = True - else: - do_save_move_count = do_save_count = do_skip_move_count = do_skip_count = 0 - - if jobs_count() == 0: - if ed.switch_screen_changed(): - screen = ed.make_screen() - if zoom_factor != 1.0: - h,w,c = screen.shape - screen = cv2.resize ( screen, ( int(w*zoom_factor), int(h*zoom_factor) ) ) - io.show_image (wnd_name, screen ) - - - io.process_messages(0.005) - - io.destroy_all_windows() diff --git a/mainscripts/Merger.py b/mainscripts/Merger.py index c7d8e0b..0703dc1 100644 --- a/mainscripts/Merger.py +++ b/mainscripts/Merger.py @@ -1,4 +1,5 @@ import math +import multiprocessing import traceback from pathlib import Path @@ -12,8 +13,9 @@ from core.interact import interact as io from core.joblib import MPClassFuncOnDemand, MPFunc from core.leras import nn from DFLIMG import DFLIMG -from facelib import FaceEnhancer, FaceType, LandmarksProcessor, TernausNet, XSegNet -from merger import FrameInfo, MergerConfig, InteractiveMergerSubprocessor +from facelib import FaceEnhancer, FaceType, LandmarksProcessor, XSegNet +from merger import FrameInfo, InteractiveMergerSubprocessor, MergerConfig + def main (model_class_name=None, saved_models_path=None, @@ -47,25 +49,20 @@ def main (model_class_name=None, model = models.import_model(model_class_name)(is_training=False, saved_models_path=saved_models_path, force_gpu_idxs=force_gpu_idxs, + force_model_name=force_model_name, cpu_only=cpu_only) predictor_func, predictor_input_shape, cfg = model.get_MergerConfig() # Preparing MP functions - predictor_func = MPFunc(predictor_func) - - run_on_cpu = len(nn.getCurrentDeviceConfig().devices) == 0 - fanseg_full_face_256_extract_func = MPClassFuncOnDemand(TernausNet, 'extract', - name=f'FANSeg_{FaceType.toString(FaceType.FULL)}', - resolution=256, - place_model_on_cpu=True, - run_on_cpu=run_on_cpu) + predictor_func = MPFunc(predictor_func) + run_on_cpu = len(nn.getCurrentDeviceConfig().devices) == 0 xseg_256_extract_func = MPClassFuncOnDemand(XSegNet, 'extract', name='XSeg', resolution=256, weights_file_root=saved_models_path, - place_model_on_cpu=True, + place_model_on_cpu=True, run_on_cpu=run_on_cpu) face_enhancer_func = MPClassFuncOnDemand(FaceEnhancer, 'enhance', @@ -76,6 +73,9 @@ def main (model_class_name=None, if not is_interactive: cfg.ask_settings() + + subprocess_count = io.input_int("Number of workers?", max(8, multiprocessing.cpu_count()), + valid_range=[1, multiprocessing.cpu_count()], help_message="Specify the number of threads to process. 
A low value may affect performance. A high value may result in memory error. The value may not be greater than CPU cores." ) input_path_image_paths = pathex.get_image_paths(input_path) @@ -101,14 +101,14 @@ def main (model_class_name=None, def generator(): for filepath in io.progress_bar_generator( pathex.get_image_paths(aligned_path), "Collecting alignments"): filepath = Path(filepath) - yield filepath, DFLIMG.load(filepath) + yield filepath, DFLIMG.load(filepath) alignments = {} multiple_faces_detected = False for filepath, dflimg in generator(): - if dflimg is None: - io.log_err ("%s is not a dfl image file" % (filepath.name) ) + if dflimg is None or not dflimg.has_data(): + io.log_err (f"{filepath.name} is not a dfl image file") continue source_filename = dflimg.get_source_filename() @@ -192,21 +192,21 @@ def main (model_class_name=None, else: if False: pass - else: + else: InteractiveMergerSubprocessor ( is_interactive = is_interactive, merger_session_filepath = model.get_strpath_storage_for_file('merger_session.dat'), predictor_func = predictor_func, predictor_input_shape = predictor_input_shape, face_enhancer_func = face_enhancer_func, - fanseg_full_face_256_extract_func = fanseg_full_face_256_extract_func, xseg_256_extract_func = xseg_256_extract_func, merger_config = cfg, frames = frames, frames_root_path = input_path, output_path = output_path, output_mask_path = output_mask_path, - model_iter = model.get_iter() + model_iter = model.get_iter(), + subprocess_count = subprocess_count, ).run() model.finalize() @@ -221,7 +221,7 @@ filesdata = [] for filepath in io.progress_bar_generator(input_path_image_paths, "Collecting info"): filepath = Path(filepath) - dflimg = DFLIMG.load(filepath) + dflimg = DFLIMG.x(filepath) if dflimg is None: io.log_err ("%s is not a dfl image file" % (filepath.name) ) continue diff --git a/mainscripts/Sorter.py b/mainscripts/Sorter.py index aef1b69..39eec5e 100644 --- a/mainscripts/Sorter.py +++ b/mainscripts/Sorter.py @@ -6,7 +6,6 @@ import sys import tempfile from functools import cmp_to_key from pathlib import Path -from shutil import copyfile import cv2 import numpy as np @@ -24,17 +23,31 @@ from facelib import LandmarksProcessor class BlurEstimatorSubprocessor(Subprocessor): class Cli(Subprocessor.Cli): + def on_initialize(self, client_dict): + self.estimate_motion_blur = client_dict['estimate_motion_blur'] + #override def process_data(self, data): filepath = Path( data[0] ) dflimg = DFLIMG.load (filepath) - if dflimg is not None: - image = cv2_imread( str(filepath) ) - return [ str(filepath), estimate_sharpness(image) ] - else: - self.log_err ("%s is not a dfl image file" % (filepath.name) ) + if dflimg is None or not dflimg.has_data(): + self.log_err (f"{filepath.name} is not a dfl image file") return [ str(filepath), 0 ] + else: + image = cv2_imread( str(filepath) ) + + face_mask = LandmarksProcessor.get_image_hull_mask (image.shape, dflimg.get_landmarks()) + image = (image*face_mask).astype(np.uint8) + + + if self.estimate_motion_blur: + value = cv2.Laplacian(image, cv2.CV_64F, ksize=11).var() + else: + value = estimate_sharpness(image) + + return [ str(filepath), value ] + #override def get_data_name (self, data): @@ -42,8 +55,9 @@ class BlurEstimatorSubprocessor(Subprocessor): return data[0] #override - def __init__(self, input_data ): + def __init__(self, input_data, estimate_motion_blur=False ): self.input_data = input_data + self.estimate_motion_blur = estimate_motion_blur self.img_list = [] self.trash_img_list = [] 
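# The 'motion-blur' sort introduced above scores each aligned face by the variance
# of a large-kernel Laplacian taken over the face region only (the image is first
# multiplied by its landmarks hull mask); low variance means few sharp edges, i.e.
# a blurrier face. A minimal standalone sketch of that score, with illustrative
# names only:
import cv2
import numpy as np

def motion_blur_score(bgr_image, face_hull_mask):
    # restrict the measurement to the face area, as BlurEstimatorSubprocessor does
    masked = (bgr_image * face_hull_mask).astype(np.uint8)
    # variance of an 11x11 Laplacian response: the lower the value, the more blur
    return cv2.Laplacian(masked, cv2.CV_64F, ksize=11).var()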
super().__init__('BlurEstimator', BlurEstimatorSubprocessor.Cli, 60) @@ -62,7 +76,7 @@ class BlurEstimatorSubprocessor(Subprocessor): io.log_info(f'Running on {cpu_count} CPUs') for i in range(cpu_count): - yield 'CPU%d' % (i), {}, {} + yield 'CPU%d' % (i), {}, {'estimate_motion_blur':self.estimate_motion_blur} #override def get_data(self, host_dict): @@ -99,7 +113,18 @@ def sort_by_blur(input_path): img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True) return img_list, trash_img_list + +def sort_by_motion_blur(input_path): + io.log_info ("Sorting by motion blur...") + img_list = [ (filename,[]) for filename in pathex.get_image_paths(input_path) ] + img_list, trash_img_list = BlurEstimatorSubprocessor (img_list, estimate_motion_blur=True).run() + + io.log_info ("Sorting...") + img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True) + + return img_list, trash_img_list + def sort_by_face_yaw(input_path): io.log_info ("Sorting by face yaw...") img_list = [] @@ -109,8 +134,8 @@ def sort_by_face_yaw(input_path): dflimg = DFLIMG.load (filepath) - if dflimg is None: - io.log_err ("%s is not a dfl image file" % (filepath.name) ) + if dflimg is None or not dflimg.has_data(): + io.log_err (f"{filepath.name} is not a dfl image file") trash_img_list.append ( [str(filepath)] ) continue @@ -132,8 +157,8 @@ def sort_by_face_pitch(input_path): dflimg = DFLIMG.load (filepath) - if dflimg is None: - io.log_err ("%s is not a dfl image file" % (filepath.name) ) + if dflimg is None or not dflimg.has_data(): + io.log_err (f"{filepath.name} is not a dfl image file") trash_img_list.append ( [str(filepath)] ) continue @@ -145,7 +170,7 @@ def sort_by_face_pitch(input_path): img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True) return img_list, trash_img_list - + def sort_by_face_source_rect_size(input_path): io.log_info ("Sorting by face rect size...") img_list = [] @@ -155,22 +180,22 @@ def sort_by_face_source_rect_size(input_path): dflimg = DFLIMG.load (filepath) - if dflimg is None: - io.log_err ("%s is not a dfl image file" % (filepath.name) ) + if dflimg is None or not dflimg.has_data(): + io.log_err (f"{filepath.name} is not a dfl image file") trash_img_list.append ( [str(filepath)] ) continue source_rect = dflimg.get_source_rect() rect_area = mathlib.polygon_area(np.array(source_rect[[0,2,2,0]]).astype(np.float32), np.array(source_rect[[1,1,3,3]]).astype(np.float32)) - + img_list.append( [str(filepath), rect_area ] ) io.log_info ("Sorting...") img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True) - return img_list, trash_img_list - - + return img_list, trash_img_list + + class HistSsimSubprocessor(Subprocessor): class Cli(Subprocessor.Cli): @@ -341,7 +366,7 @@ def sort_by_hist_dissim(input_path): image = cv2_imread(str(filepath)) - if dflimg is not None: + if dflimg is not None and dflimg.has_data(): face_mask = LandmarksProcessor.get_image_hull_mask (image.shape, dflimg.get_landmarks()) image = (image*face_mask).astype(np.uint8) @@ -391,8 +416,8 @@ def sort_by_origname(input_path): dflimg = DFLIMG.load (filepath) - if dflimg is None: - io.log_err ("%s is not a dfl image file" % (filepath.name) ) + if dflimg is None or not dflimg.has_data(): + io.log_err (f"{filepath.name} is not a dfl image file") trash_img_list.append( [str(filepath)] ) continue @@ -434,8 +459,8 @@ class FinalLoaderSubprocessor(Subprocessor): try: dflimg = DFLIMG.load (filepath) - if dflimg is None: - self.log_err("%s is not a dfl image file" % (filepath.name)) + if dflimg is 
None or not dflimg.has_data(): + self.log_err (f"{filepath.name} is not a dfl image file") return [ 1, [str(filepath)] ] bgr = cv2_imread(str(filepath)) @@ -443,13 +468,13 @@ class FinalLoaderSubprocessor(Subprocessor): raise Exception ("Unable to load %s" % (filepath.name) ) gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) - if self.faster: source_rect = dflimg.get_source_rect() sharpness = mathlib.polygon_area(np.array(source_rect[[0,2,2,0]]).astype(np.float32), np.array(source_rect[[1,1,3,3]]).astype(np.float32)) else: - sharpness = estimate_sharpness(gray) - + face_mask = LandmarksProcessor.get_image_hull_mask (gray.shape, dflimg.get_landmarks()) + sharpness = estimate_sharpness( (gray[...,None]*face_mask).astype(np.uint8) ) + pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks(), size=dflimg.get_shape()[1] ) hist = cv2.calcHist([gray], [0], None, [256], [0, 256]) @@ -585,12 +610,12 @@ class FinalHistDissimSubprocessor(Subprocessor): def get_result(self): return self.result -def sort_best_faster(input_path): +def sort_best_faster(input_path): return sort_best(input_path, faster=True) - + def sort_best(input_path, faster=False): target_count = io.input_int ("Target number of faces?", 2000) - + io.log_info ("Performing sort by best faces.") if faster: io.log_info("Using faster algorithm. Faces will be sorted by source-rect-area instead of blur.") @@ -629,7 +654,7 @@ def sort_best(input_path, faster=False): imgs_per_grad += total_lack // grads - + sharpned_imgs_per_grad = imgs_per_grad*10 for g in io.progress_bar_generator ( range (grads), "Sort by blur"): img_list = yaws_sample_list[g] @@ -769,7 +794,7 @@ def sort_by_absdiff(input_path): outputs_full = [] outputs_remain = [] - + for i in range(batch_size): diff_t = tf.reduce_sum( tf.abs(i_t-j_t[i]), axis=[1,2,3] ) outputs_full.append(diff_t) @@ -872,6 +897,7 @@ def final_process(input_path, img_list, trash_img_list): sort_func_methods = { 'blur': ("blur", sort_by_blur), + 'motion-blur': ("motion_blur", sort_by_motion_blur), 'face-yaw': ("face yaw direction", sort_by_face_yaw), 'face-pitch': ("face pitch direction", sort_by_face_pitch), 'face-source-rect-size' : ("face rect size in source image", sort_by_face_source_rect_size), @@ -899,7 +925,7 @@ def main (input_path, sort_by_method=None): io.log_info(f"[{i}] {desc}") io.log_info("") - id = io.input_int("", 4, valid_list=[*range(len(key_list))] ) + id = io.input_int("", 5, valid_list=[*range(len(key_list))] ) sort_by_method = key_list[id] else: diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py index d0d3eb9..df74ca3 100644 --- a/mainscripts/Trainer.py +++ b/mainscripts/Trainer.py @@ -1,4 +1,5 @@ -import sys +import os +import sys import traceback import queue import threading @@ -23,6 +24,7 @@ def trainerThread (s2c, c2s, e, force_model_name=None, force_gpu_idxs=None, cpu_only=None, + silent_start=False, execute_programs = None, debug=False, **kwargs): @@ -30,7 +32,7 @@ def trainerThread (s2c, c2s, e, try: start_time = time.time() - save_interval_min = 15 + save_interval_min = 25 if not training_data_src_path.exists(): training_data_src_path.mkdir(exist_ok=True, parents=True) @@ -40,7 +42,7 @@ def trainerThread (s2c, c2s, e, if not saved_models_path.exists(): saved_models_path.mkdir(exist_ok=True, parents=True) - + model = models.import_model(model_class_name)( is_training=True, saved_models_path=saved_models_path, @@ -52,8 +54,8 @@ def trainerThread (s2c, c2s, e, force_model_name=force_model_name, force_gpu_idxs=force_gpu_idxs, cpu_only=cpu_only, 
- debug=debug, - ) + silent_start=silent_start, + debug=debug) is_reached_goal = model.is_reached_iter_goal() @@ -117,6 +119,12 @@ def trainerThread (s2c, c2s, e, io.log_info("") io.log_info("Trying to do the first iteration. If an error occurs, reduce the model parameters.") io.log_info("") + + if sys.platform[0:3] == 'win': + io.log_info("!!!") + io.log_info("Windows 10 users IMPORTANT notice. You should set this setting in order to work correctly.") + io.log_info("https://i.imgur.com/B7cmDCB.jpg") + io.log_info("!!!") iter, iter_time = model.train_one_iter() @@ -155,9 +163,13 @@ def trainerThread (s2c, c2s, e, model_save() is_reached_goal = True io.log_info ('You can use preview now.') - - if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60: + + need_save = False + while time.time() - last_save_time >= save_interval_min*60: last_save_time += save_interval_min*60 + need_save = True + + if not is_reached_goal and need_save: model_save() send_preview() diff --git a/mainscripts/Util.py b/mainscripts/Util.py index 46e154e..66e751d 100644 --- a/mainscripts/Util.py +++ b/mainscripts/Util.py @@ -5,7 +5,6 @@ import cv2 from DFLIMG import * from facelib import LandmarksProcessor, FaceType -from core.imagelib import IEPolys from core.interact import interact as io from core import pathex from core.cv2ex import * @@ -22,8 +21,11 @@ def save_faceset_metadata_folder(input_path): for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Processing"): filepath = Path(filepath) dflimg = DFLIMG.load (filepath) - - dfl_dict = dflimg.getDFLDictData() + if dflimg is None or not dflimg.has_data(): + io.log_info(f"{filepath} is not a dfl image file") + continue + + dfl_dict = dflimg.get_dict() d[filepath.name] = ( dflimg.get_shape(), dfl_dict ) try: @@ -52,70 +54,29 @@ def restore_faceset_metadata_folder(input_path): except: raise FileNotFoundError(filename) - for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Processing"): - filepath = Path(filepath) + for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path, image_extensions=['.jpg'], return_Path_class=True), "Processing"): + saved_data = d.get(filepath.name, None) + if saved_data is None: + io.log_info(f"No saved metadata for {filepath}") + continue + + shape, dfl_dict = saved_data - shape, dfl_dict = d.get(filepath.name, None) - - img = cv2_imread (str(filepath)) + img = cv2_imread (filepath) if img.shape != shape: - img = cv2.resize (img, (shape[1], shape[0]), cv2.INTER_LANCZOS4 ) + img = cv2.resize (img, (shape[1], shape[0]), interpolation=cv2.INTER_LANCZOS4 ) - if filepath.suffix == '.png': - cv2_imwrite (str(filepath), img) - elif filepath.suffix == '.jpg': - cv2_imwrite (str(filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] ) + cv2_imwrite (str(filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] ) if filepath.suffix == '.jpg': - DFLJPG.embed_dfldict( str(filepath), dfl_dict ) + dflimg = DFLJPG.load(filepath) + dflimg.set_dict(dfl_dict) + dflimg.save() else: continue metadata_filepath.unlink() -def remove_ie_polys_file (filepath): - filepath = Path(filepath) - - dflimg = DFLIMG.load (filepath) - if dflimg is None: - io.log_err ("%s is not a dfl image file" % (filepath.name) ) - return - - dflimg.remove_ie_polys() - dflimg.embed_and_set( str(filepath) ) - - -def remove_ie_polys_folder(input_path): - input_path = Path(input_path) - - io.log_info ("Removing ie_polys...\r\n") - - for filepath in io.progress_bar_generator( 
pathex.get_image_paths(input_path), "Removing"): - filepath = Path(filepath) - remove_ie_polys_file(filepath) - -def remove_fanseg_file (filepath): - filepath = Path(filepath) - - dflimg = DFLIMG.load (filepath) - - if dflimg is None: - io.log_err ("%s is not a dfl image file" % (filepath.name) ) - return - - dflimg.remove_fanseg_mask() - dflimg.embed_and_set( str(filepath) ) - - -def remove_fanseg_folder(input_path): - input_path = Path(input_path) - - io.log_info ("Removing fanseg mask...\r\n") - - for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Removing"): - filepath = Path(filepath) - remove_fanseg_file(filepath) - def add_landmarks_debug_images(input_path): io.log_info ("Adding landmarks debug images...") @@ -126,8 +87,8 @@ def add_landmarks_debug_images(input_path): dflimg = DFLIMG.load (filepath) - if dflimg is None: - io.log_err ("%s is not a dfl image file" % (filepath.name) ) + if dflimg is None or not dflimg.has_data(): + io.log_err (f"{filepath.name} is not a dfl image file") continue if img is not None: @@ -138,7 +99,7 @@ def add_landmarks_debug_images(input_path): rect = dflimg.get_source_rect() LandmarksProcessor.draw_rect_landmarks(img, rect, face_landmarks, FaceType.FULL ) else: - LandmarksProcessor.draw_landmarks(img, face_landmarks, transparent_mask=True, ie_polys=IEPolys.load(dflimg.get_ie_polys()) ) + LandmarksProcessor.draw_landmarks(img, face_landmarks, transparent_mask=True ) @@ -154,8 +115,8 @@ def recover_original_aligned_filename(input_path): dflimg = DFLIMG.load (filepath) - if dflimg is None: - io.log_err ("%s is not a dfl image file" % (filepath.name) ) + if dflimg is None or not dflimg.has_data(): + io.log_err (f"{filepath.name} is not a dfl image file") continue files += [ [filepath, None, dflimg.get_source_filename(), False] ] @@ -199,41 +160,31 @@ def recover_original_aligned_filename(input_path): except: io.log_err ('fail to rename %s' % (fs.name) ) - -""" -def convert_png_to_jpg_file (filepath): - filepath = Path(filepath) - - if filepath.suffix != '.png': - return - - dflpng = DFLPNG.load (str(filepath) ) - if dflpng is None: - io.log_err ("%s is not a dfl png image file" % (filepath.name) ) - return - - dfl_dict = dflpng.getDFLDictData() - - img = cv2_imread (str(filepath)) - new_filepath = str(filepath.parent / (filepath.stem + '.jpg')) - cv2_imwrite ( new_filepath, img, [int(cv2.IMWRITE_JPEG_QUALITY), 100]) - - DFLJPG.embed_data( new_filepath, - face_type=dfl_dict.get('face_type', None), - landmarks=dfl_dict.get('landmarks', None), - ie_polys=dfl_dict.get('ie_polys', None), - source_filename=dfl_dict.get('source_filename', None), - source_rect=dfl_dict.get('source_rect', None), - source_landmarks=dfl_dict.get('source_landmarks', None) ) - - filepath.unlink() - -def convert_png_to_jpg_folder (input_path): - input_path = Path(input_path) - - io.log_info ("Converting PNG to JPG...\r\n") - - for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Converting"): - filepath = Path(filepath) - convert_png_to_jpg_file(filepath) -""" \ No newline at end of file +def export_faceset_mask(input_dir): + for filename in io.progress_bar_generator(pathex.get_image_paths (input_dir), "Processing"): + filepath = Path(filename) + + if '_mask' in filepath.stem: + continue + + mask_filepath = filepath.parent / (filepath.stem+'_mask'+filepath.suffix) + + dflimg = DFLJPG.load(filepath) + + H,W,C = dflimg.shape + + seg_ie_polys = dflimg.get_seg_ie_polys() + + if dflimg.has_xseg_mask(): + mask = dflimg.get_xseg_mask() + 
mask[mask < 0.5] = 0.0 + mask[mask >= 0.5] = 1.0 + elif seg_ie_polys.has_polys(): + mask = np.zeros ((H,W,1), dtype=np.float32) + seg_ie_polys.overlay_mask(mask) + else: + raise Exception(f'no mask in file {filepath}') + + + cv2_imwrite(mask_filepath, (mask*255).astype(np.uint8), [int(cv2.IMWRITE_JPEG_QUALITY), 100] ) + \ No newline at end of file diff --git a/mainscripts/XSegUtil.py b/mainscripts/XSegUtil.py index 45c64b1..c75a14a 100644 --- a/mainscripts/XSegUtil.py +++ b/mainscripts/XSegUtil.py @@ -1,108 +1,187 @@ -import traceback import json +import shutil +import traceback from pathlib import Path + import numpy as np from core import pathex -from core.imagelib import IEPolys +from core.cv2ex import * from core.interact import interact as io +from core.leras import nn from DFLIMG import * +from facelib import XSegNet, LandmarksProcessor, FaceType +import pickle - -def merge(input_dir): - input_path = Path(input_dir) +def apply_xseg(input_path, model_path): if not input_path.exists(): - raise ValueError('input_dir not found. Please ensure it exists.') + raise ValueError(f'{input_path} not found. Please ensure it exists.') + if not model_path.exists(): + raise ValueError(f'{model_path} not found. Please ensure it exists.') + + face_type = None + + model_dat = model_path / 'XSeg_data.dat' + if model_dat.exists(): + dat = pickle.loads( model_dat.read_bytes() ) + dat_options = dat.get('options', None) + if dat_options is not None: + face_type = dat_options.get('face_type', None) + + + + if face_type is None: + face_type = io.input_str ("XSeg model face type", 'same', ['h','mf','f','wf','head','same'], help_message="Specify face type of trained XSeg model. For example if XSeg model trained as WF, but faceset is HEAD, specify WF to apply xseg only on WF part of HEAD. 
Default is 'same'").lower() + if face_type == 'same': + face_type = None + + if face_type is not None: + face_type = {'h' : FaceType.HALF, + 'mf' : FaceType.MID_FULL, + 'f' : FaceType.FULL, + 'wf' : FaceType.WHOLE_FACE, + 'head' : FaceType.HEAD}[face_type] + + io.log_info(f'Applying trained XSeg model to {input_path.name}/ folder.') + + device_config = nn.DeviceConfig.ask_choose_device(choose_only_one=True) + nn.initialize(device_config) + + + + xseg = XSegNet(name='XSeg', + load_weights=True, + weights_file_root=model_path, + data_format=nn.data_format, + raise_on_no_model_files=True) + xseg_res = xseg.get_resolution() + images_paths = pathex.get_image_paths(input_path, return_Path_class=True) - - images_processed = 0 + for filepath in io.progress_bar_generator(images_paths, "Processing"): - json_filepath = filepath.parent / (filepath.stem+'.json') - if json_filepath.exists(): - dflimg = DFLIMG.load(filepath) - if dflimg is not None: - try: - json_dict = json.loads(json_filepath.read_text()) - - seg_ie_polys = IEPolys() - total_points = 0 - - #include polys first - for shape in json_dict['shapes']: - if shape['shape_type'] == 'polygon' and \ - shape['label'] != '0': - seg_ie_poly = seg_ie_polys.add(1) - - for x,y in shape['points']: - seg_ie_poly.add( int(x), int(y) ) - total_points += 1 - - #exclude polys - for shape in json_dict['shapes']: - if shape['shape_type'] == 'polygon' and \ - shape['label'] == '0': - seg_ie_poly = seg_ie_polys.add(0) - - for x,y in shape['points']: - seg_ie_poly.add( int(x), int(y) ) - total_points += 1 - - if total_points == 0: - io.log_info(f"No points found in {json_filepath}, skipping.") - continue - - dflimg.embed_and_set (filepath, seg_ie_polys=seg_ie_polys) - - json_filepath.unlink() - - images_processed += 1 - except: - io.log_err(f"err {filepath}, {traceback.format_exc()}") - return - - io.log_info(f"Images processed: {images_processed}") - -def split(input_dir ): - input_path = Path(input_dir) - if not input_path.exists(): - raise ValueError('input_dir not found. 
Please ensure it exists.') - - images_paths = pathex.get_image_paths(input_path, return_Path_class=True) - - images_processed = 0 - for filepath in io.progress_bar_generator(images_paths, "Processing"): - json_filepath = filepath.parent / (filepath.stem+'.json') - - dflimg = DFLIMG.load(filepath) - if dflimg is not None: - try: - seg_ie_polys = dflimg.get_seg_ie_polys() - if seg_ie_polys is not None: - json_dict = {} - json_dict['version'] = "4.2.9" - json_dict['flags'] = {} - json_dict['shapes'] = [] - json_dict['imagePath'] = filepath.name - json_dict['imageData'] = None + if dflimg is None or not dflimg.has_data(): + io.log_info(f'{filepath} is not a DFLIMG') + continue + + img = cv2_imread(filepath).astype(np.float32) / 255.0 + h,w,c = img.shape + + img_face_type = FaceType.fromString( dflimg.get_face_type() ) + if face_type is not None and img_face_type != face_type: + lmrks = dflimg.get_source_landmarks() + + fmat = LandmarksProcessor.get_transform_mat(lmrks, w, face_type) + imat = LandmarksProcessor.get_transform_mat(lmrks, w, img_face_type) + + g_p = LandmarksProcessor.transform_points (np.float32([(0,0),(w,0),(0,w) ]), fmat, True) + g_p2 = LandmarksProcessor.transform_points (g_p, imat) + + mat = cv2.getAffineTransform( g_p2, np.float32([(0,0),(w,0),(0,w) ]) ) + + img = cv2.warpAffine(img, mat, (w, w), cv2.INTER_LANCZOS4) + img = cv2.resize(img, (xseg_res, xseg_res), interpolation=cv2.INTER_LANCZOS4) + else: + if w != xseg_res: + img = cv2.resize( img, (xseg_res,xseg_res), interpolation=cv2.INTER_LANCZOS4 ) - for poly_type, points_list in seg_ie_polys: - shape_dict = {} - shape_dict['label'] = str(poly_type) - shape_dict['points'] = points_list - shape_dict['group_id'] = None - shape_dict['shape_type'] = 'polygon' - shape_dict['flags'] = {} - json_dict['shapes'].append( shape_dict ) + if len(img.shape) == 2: + img = img[...,None] + + mask = xseg.extract(img) + + if face_type is not None and img_face_type != face_type: + mask = cv2.resize(mask, (w, w), interpolation=cv2.INTER_LANCZOS4) + mask = cv2.warpAffine( mask, mat, (w,w), np.zeros( (h,w,c), dtype=np.float), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4) + mask = cv2.resize(mask, (xseg_res, xseg_res), interpolation=cv2.INTER_LANCZOS4) + mask[mask < 0.5]=0 + mask[mask >= 0.5]=1 + dflimg.set_xseg_mask(mask) + dflimg.save() - json_filepath.write_text( json.dumps (json_dict,indent=4) ) - dflimg.remove_seg_ie_polys() - dflimg.embed_and_set (filepath) - images_processed += 1 - except: - io.log_err(f"err {filepath}, {traceback.format_exc()}") - return + +def fetch_xseg(input_path): + if not input_path.exists(): + raise ValueError(f'{input_path} not found. 
Please ensure it exists.') + + output_path = input_path.parent / (input_path.name + '_xseg') + output_path.mkdir(exist_ok=True, parents=True) + + io.log_info(f'Copying faces containing XSeg polygons to {output_path.name}/ folder.') + + images_paths = pathex.get_image_paths(input_path, return_Path_class=True) + + + files_copied = [] + for filepath in io.progress_bar_generator(images_paths, "Processing"): + dflimg = DFLIMG.load(filepath) + if dflimg is None or not dflimg.has_data(): + io.log_info(f'{filepath} is not a DFLIMG') + continue + + ie_polys = dflimg.get_seg_ie_polys() - io.log_info(f"Images processed: {images_processed}") \ No newline at end of file + if ie_polys.has_polys(): + files_copied.append(filepath) + shutil.copy ( str(filepath), str(output_path / filepath.name) ) + + io.log_info(f'Files copied: {len(files_copied)}') + + is_delete = io.input_bool (f"\r\nDelete original files?", True) + if is_delete: + for filepath in files_copied: + Path(filepath).unlink() + + +def remove_xseg(input_path): + if not input_path.exists(): + raise ValueError(f'{input_path} not found. Please ensure it exists.') + + io.log_info(f'Processing folder {input_path}') + io.log_info('!!! WARNING : APPLIED XSEG MASKS WILL BE REMOVED FROM THE FRAMES !!!') + io.log_info('!!! WARNING : APPLIED XSEG MASKS WILL BE REMOVED FROM THE FRAMES !!!') + io.log_info('!!! WARNING : APPLIED XSEG MASKS WILL BE REMOVED FROM THE FRAMES !!!') + io.input_str('Press enter to continue.') + + images_paths = pathex.get_image_paths(input_path, return_Path_class=True) + + files_processed = 0 + for filepath in io.progress_bar_generator(images_paths, "Processing"): + dflimg = DFLIMG.load(filepath) + if dflimg is None or not dflimg.has_data(): + io.log_info(f'{filepath} is not a DFLIMG') + continue + + if dflimg.has_xseg_mask(): + dflimg.set_xseg_mask(None) + dflimg.save() + files_processed += 1 + io.log_info(f'Files processed: {files_processed}') + +def remove_xseg_labels(input_path): + if not input_path.exists(): + raise ValueError(f'{input_path} not found. Please ensure it exists.') + + io.log_info(f'Processing folder {input_path}') + io.log_info('!!! WARNING : LABELED XSEG POLYGONS WILL BE REMOVED FROM THE FRAMES !!!') + io.log_info('!!! WARNING : LABELED XSEG POLYGONS WILL BE REMOVED FROM THE FRAMES !!!') + io.log_info('!!! WARNING : LABELED XSEG POLYGONS WILL BE REMOVED FROM THE FRAMES !!!') + io.input_str('Press enter to continue.') + + images_paths = pathex.get_image_paths(input_path, return_Path_class=True) + + files_processed = 0 + for filepath in io.progress_bar_generator(images_paths, "Processing"): + dflimg = DFLIMG.load(filepath) + if dflimg is None or not dflimg.has_data(): + io.log_info(f'{filepath} is not a DFLIMG') + continue + + if dflimg.has_seg_ie_polys(): + dflimg.set_seg_ie_polys(None) + dflimg.save() + files_processed += 1 + + io.log_info(f'Files processed: {files_processed}') \ No newline at end of file diff --git a/mainscripts/dev_misc.py b/mainscripts/dev_misc.py index 2883a79..93a4359 100644 --- a/mainscripts/dev_misc.py +++ b/mainscripts/dev_misc.py @@ -8,13 +8,11 @@ import numpy as np from core import imagelib, pathex from core.cv2ex import * -from core.imagelib import IEPolys from core.interact import interact as io from core.joblib import Subprocessor from core.leras import nn from DFLIMG import * from facelib import FaceType, LandmarksProcessor - from . 
import Extractor, Sorter from .Extractor import ExtractSubprocessor @@ -219,182 +217,6 @@ def extract_vggface2_dataset(input_dir, device_args={} ): """ -class CelebAMASKHQSubprocessor(Subprocessor): - class Cli(Subprocessor.Cli): - #override - def on_initialize(self, client_dict): - self.masks_files_paths = client_dict['masks_files_paths'] - return None - - #override - def process_data(self, data): - filename = data[0] - - dflimg = DFLIMG.load(Path(filename)) - - image_to_face_mat = dflimg.get_image_to_face_mat() - src_filename = dflimg.get_source_filename() - - img = cv2_imread(filename) - h,w,c = img.shape - - fanseg_mask = LandmarksProcessor.get_image_hull_mask(img.shape, dflimg.get_landmarks() ) - - idx_name = '%.5d' % int(src_filename.split('.')[0]) - idx_files = [ x for x in self.masks_files_paths if idx_name in x ] - - skin_files = [ x for x in idx_files if 'skin' in x ] - eye_glass_files = [ x for x in idx_files if 'eye_g' in x ] - - for files, is_invert in [ (skin_files,False), - (eye_glass_files,True) ]: - if len(files) > 0: - mask = cv2_imread(files[0]) - mask = mask[...,0] - mask[mask == 255] = 1 - mask = mask.astype(np.float32) - mask = cv2.resize(mask, (1024,1024) ) - mask = cv2.warpAffine(mask, image_to_face_mat, (w, h), cv2.INTER_LANCZOS4) - - if not is_invert: - fanseg_mask *= mask[...,None] - else: - fanseg_mask *= (1-mask[...,None]) - - dflimg.embed_and_set (filename, fanseg_mask=fanseg_mask) - return 1 - - #override - def get_data_name (self, data): - #return string identificator of your data - return data[0] - - #override - def __init__(self, image_paths, masks_files_paths ): - self.image_paths = image_paths - self.masks_files_paths = masks_files_paths - - self.result = [] - super().__init__('CelebAMASKHQSubprocessor', CelebAMASKHQSubprocessor.Cli, 60) - - #override - def process_info_generator(self): - for i in range(min(multiprocessing.cpu_count(), 8)): - yield 'CPU%d' % (i), {}, {'masks_files_paths' : self.masks_files_paths } - - #override - def on_clients_initialized(self): - io.progress_bar ("Processing", len (self.image_paths)) - - #override - def on_clients_finalized(self): - io.progress_bar_close() - - #override - def get_data(self, host_dict): - if len (self.image_paths) > 0: - return [self.image_paths.pop(0)] - return None - - #override - def on_data_return (self, host_dict, data): - self.image_paths.insert(0, data[0]) - - #override - def on_result (self, host_dict, data, result): - io.progress_bar_inc(1) - - #override - def get_result(self): - return self.result - -#unused in end user workflow -def apply_celebamaskhq(input_dir ): - - input_path = Path(input_dir) - - img_path = input_path / 'aligned' - mask_path = input_path / 'mask' - - if not img_path.exists(): - raise ValueError(f'{str(img_path)} directory not found. 
Please ensure it exists.') - - CelebAMASKHQSubprocessor(pathex.get_image_paths(img_path), - pathex.get_image_paths(mask_path, subdirs=True) ).run() - - return - - paths_to_extract = [] - for filename in io.progress_bar_generator(pathex.get_image_paths(img_path), desc="Processing"): - filepath = Path(filename) - dflimg = DFLIMG.load(filepath) - - if dflimg is not None: - paths_to_extract.append (filepath) - - image_to_face_mat = dflimg.get_image_to_face_mat() - src_filename = dflimg.get_source_filename() - - #img = cv2_imread(filename) - h,w,c = dflimg.get_shape() - - fanseg_mask = LandmarksProcessor.get_image_hull_mask( (h,w,c), dflimg.get_landmarks() ) - - idx_name = '%.5d' % int(src_filename.split('.')[0]) - idx_files = [ x for x in masks_files if idx_name in x ] - - skin_files = [ x for x in idx_files if 'skin' in x ] - eye_glass_files = [ x for x in idx_files if 'eye_g' in x ] - - for files, is_invert in [ (skin_files,False), - (eye_glass_files,True) ]: - - if len(files) > 0: - mask = cv2_imread(files[0]) - mask = mask[...,0] - mask[mask == 255] = 1 - mask = mask.astype(np.float32) - mask = cv2.resize(mask, (1024,1024) ) - mask = cv2.warpAffine(mask, image_to_face_mat, (w, h), cv2.INTER_LANCZOS4) - - if not is_invert: - fanseg_mask *= mask[...,None] - else: - fanseg_mask *= (1-mask[...,None]) - - #cv2.imshow("", (fanseg_mask*255).astype(np.uint8) ) - #cv2.waitKey(0) - - - dflimg.embed_and_set (filename, fanseg_mask=fanseg_mask) - - - #import code - #code.interact(local=dict(globals(), **locals())) - - - -#unused in end user workflow -def extract_fanseg(input_dir, device_args={} ): - multi_gpu = device_args.get('multi_gpu', False) - cpu_only = device_args.get('cpu_only', False) - - input_path = Path(input_dir) - if not input_path.exists(): - raise ValueError('Input directory not found. Please ensure it exists.') - - paths_to_extract = [] - for filename in pathex.get_image_paths(input_path) : - filepath = Path(filename) - dflimg = DFLIMG.load ( filepath ) - if dflimg is not None: - paths_to_extract.append (filepath) - - paths_to_extract_len = len(paths_to_extract) - if paths_to_extract_len > 0: - io.log_info ("Performing extract fanseg for %d files..." 
% (paths_to_extract_len) ) - data = ExtractSubprocessor ([ ExtractSubprocessor.Data(filename) for filename in paths_to_extract ], 'fanseg', multi_gpu=multi_gpu, cpu_only=cpu_only).run() - #unused in end user workflow def dev_test_68(input_dir ): # process 68 landmarks dataset with .pts files @@ -451,13 +273,14 @@ def dev_test_68(input_dir ): img = cv2_imread(filepath) img = imagelib.normalize_channels(img, 3) cv2_imwrite(output_filepath, img, [int(cv2.IMWRITE_JPEG_QUALITY), 95] ) - - DFLJPG.embed_data(output_filepath, face_type=FaceType.toString(FaceType.MARK_ONLY), - landmarks=lmrks, - source_filename=filepath.name, - source_rect=rect, - source_landmarks=lmrks - ) + + raise Exception("unimplemented") + #DFLJPG.x(output_filepath, face_type=FaceType.toString(FaceType.MARK_ONLY), + # landmarks=lmrks, + # source_filename=filepath.name, + # source_rect=rect, + # source_landmarks=lmrks + # ) io.log_info("Done.") @@ -533,25 +356,114 @@ def extract_umd_csv(input_file_csv, io.log_info ('Faces detected: %d' % (faces_detected) ) io.log_info ('-------------------------') + + def dev_test1(input_dir): + # LaPa dataset + + image_size = 1024 + face_type = FaceType.HEAD + input_path = Path(input_dir) + images_path = input_path / 'images' + if not images_path.exists: + raise ValueError('LaPa dataset: images folder not found.') + labels_path = input_path / 'labels' + if not labels_path.exists: + raise ValueError('LaPa dataset: labels folder not found.') + landmarks_path = input_path / 'landmarks' + if not landmarks_path.exists: + raise ValueError('LaPa dataset: landmarks folder not found.') + + output_path = input_path / 'out' + if output_path.exists(): + output_images_paths = pathex.get_image_paths(output_path) + if len(output_images_paths) != 0: + io.input(f"\n WARNING !!! \n {output_path} contains files! \n They will be deleted. 
\n Press enter to continue.\n") + for filename in output_images_paths: + Path(filename).unlink() + output_path.mkdir(parents=True, exist_ok=True) + + data = [] + + img_paths = pathex.get_image_paths (images_path) + for filename in img_paths: + filepath = Path(filename) - dir_names = pathex.get_all_dir_names(input_path) + landmark_filepath = landmarks_path / (filepath.stem + '.txt') + if not landmark_filepath.exists(): + raise ValueError(f'no landmarks for {filepath}') + + #img = cv2_imread(filepath) + + lm = landmark_filepath.read_text() + lm = lm.split('\n') + if int(lm[0]) != 106: + raise ValueError(f'wrong landmarks format in {landmark_filepath}') + + lmrks = [] + for i in range(106): + x,y = lm[i+1].split(' ') + x,y = float(x), float(y) + lmrks.append ( (x,y) ) + + lmrks = np.array(lmrks) + + l,t = np.min(lmrks, 0) + r,b = np.max(lmrks, 0) + + l,t,r,b = ( int(x) for x in (l,t,r,b) ) + + #for x, y in lmrks: + # x,y = int(x), int(y) + # cv2.circle(img, (x, y), 1, (0,255,0) , 1, lineType=cv2.LINE_AA) + + #imagelib.draw_rect(img, (l,t,r,b), (0,255,0) ) + + + data += [ ExtractSubprocessor.Data(filepath=filepath, rects=[ (l,t,r,b) ]) ] - for dir_name in io.progress_bar_generator(dir_names, desc="Processing"): + #cv2.imshow("", img) + #cv2.waitKey(0) + + if len(data) > 0: + device_config = nn.DeviceConfig.BestGPU() + + io.log_info ("Performing 2nd pass...") + data = ExtractSubprocessor (data, 'landmarks', image_size, 95, face_type, device_config=device_config).run() + io.log_info ("Performing 3rd pass...") + data = ExtractSubprocessor (data, 'final', image_size, 95, face_type, final_output_path=output_path, device_config=device_config).run() - img_paths = pathex.get_image_paths (input_path / dir_name) - for filename in img_paths: + + for filename in pathex.get_image_paths (output_path): filepath = Path(filename) + + + dflimg = DFLJPG.load(filepath) + + src_filename = dflimg.get_source_filename() + image_to_face_mat = dflimg.get_image_to_face_mat() - dflimg = DFLIMG.load (filepath) - if dflimg is None: - raise ValueError - - dflimg.embed_and_set(filename, person_name=dir_name) - - #import code - #code.interact(local=dict(globals(), **locals())) + label_filepath = labels_path / ( Path(src_filename).stem + '.png') + if not label_filepath.exists(): + raise ValueError(f'{label_filepath} does not exist') + + mask = cv2_imread(label_filepath) + #mask[mask == 10] = 0 # remove hair + mask[mask > 0] = 1 + mask = cv2.warpAffine(mask, image_to_face_mat, (image_size, image_size), cv2.INTER_LINEAR) + mask = cv2.blur(mask, (3,3) ) + + #cv2.imshow("", (mask*255).astype(np.uint8) ) + #cv2.waitKey(0) + + dflimg.set_xseg_mask(mask) + dflimg.save() + + + import code + code.interact(local=dict(globals(), **locals())) + def dev_resave_pngs(input_dir): input_path = Path(input_dir) @@ -587,123 +499,96 @@ def dev_segmented_trash(input_dir): except: io.log_info ('fail to trashing %s' % (src.name) ) + + +def dev_test(input_dir): + """ + extract FaceSynthetics dataset https://github.com/microsoft/FaceSynthetics + + BACKGROUND = 0 + SKIN = 1 + NOSE = 2 + RIGHT_EYE = 3 + LEFT_EYE = 4 + RIGHT_BROW = 5 + LEFT_BROW = 6 + RIGHT_EAR = 7 + LEFT_EAR = 8 + MOUTH_INTERIOR = 9 + TOP_LIP = 10 + BOTTOM_LIP = 11 + NECK = 12 + HAIR = 13 + BEARD = 14 + CLOTHING = 15 + GLASSES = 16 + HEADWEAR = 17 + FACEWEAR = 18 + IGNORE = 255 + """ + + + image_size = 1024 + face_type = FaceType.WHOLE_FACE -def dev_segmented_extract(input_dir, output_dir ): - # extract and merge .json labelme files within the faces - - device_config = 
nn.DeviceConfig.GPUIndexes( nn.ask_choose_device_idxs(suggest_all_gpu=True) ) - input_path = Path(input_dir) - if not input_path.exists(): - raise ValueError('input_dir not found. Please ensure it exists.') - - output_path = Path(output_dir) - io.log_info("Performing extract segmented faces.") - io.log_info(f'Output dir is {output_path}') - + + + + output_path = input_path.parent / f'{input_path.name}_out' if output_path.exists(): - output_images_paths = pathex.get_image_paths(output_path, subdirs=True) - if len(output_images_paths) > 0: - io.input_bool("WARNING !!! \n %s contains files! \n They will be deleted. \n Press enter to continue." % (str(output_path)), False ) + output_images_paths = pathex.get_image_paths(output_path) + if len(output_images_paths) != 0: + io.input(f"\n WARNING !!! \n {output_path} contains files! \n They will be deleted. \n Press enter to continue.\n") for filename in output_images_paths: Path(filename).unlink() - shutil.rmtree(str(output_path)) - else: - output_path.mkdir(parents=True, exist_ok=True) - - images_paths = pathex.get_image_paths(input_path, subdirs=True, return_Path_class=True) - - extract_data = [] - images_jsons = {} - images_processed = 0 - - - for filepath in io.progress_bar_generator(images_paths, "Processing"): - json_filepath = filepath.parent / (filepath.stem+'.json') - - if json_filepath.exists(): - try: - json_dict = json.loads(json_filepath.read_text()) - images_jsons[filepath] = json_dict - - total_points = [ [x,y] for shape in json_dict['shapes'] for x,y in shape['points'] ] - total_points = np.array(total_points) - - if len(total_points) == 0: - io.log_info(f"No points found in {json_filepath}, skipping.") + output_path.mkdir(parents=True, exist_ok=True) + + data = [] + + for filepath in io.progress_bar_generator(pathex.get_paths(input_path), "Processing"): + if filepath.suffix == '.txt': + + image_filepath = filepath.parent / f'{filepath.name.split("_")[0]}.png' + if not image_filepath.exists(): + print(f'{image_filepath} does not exist, skipping') + + lmrks = [] + for lmrk_line in filepath.read_text().split('\n'): + if len(lmrk_line) == 0: continue + + x, y = lmrk_line.split(' ') + x, y = float(x), float(y) + + lmrks.append( (x,y) ) + + lmrks = np.array(lmrks[:68], np.float32) + rect = LandmarksProcessor.get_rect_from_landmarks(lmrks) + data += [ ExtractSubprocessor.Data(filepath=image_filepath, rects=[rect], landmarks=[ lmrks ] ) ] - l,r = int(total_points[:,0].min()), int(total_points[:,0].max()) - t,b = int(total_points[:,1].min()), int(total_points[:,1].max()) + if len(data) > 0: + io.log_info ("Performing 3rd pass...") + data = ExtractSubprocessor (data, 'final', image_size, 95, face_type, final_output_path=output_path, device_config=nn.DeviceConfig.CPU()).run() - force_output_path=output_path / filepath.relative_to(input_path).parent - force_output_path.mkdir(exist_ok=True, parents=True) - - extract_data.append ( ExtractSubprocessor.Data(filepath, - rects=[ [l,t,r,b] ], - force_output_path=force_output_path ) ) - images_processed += 1 - except: - io.log_err(f"err {filepath}, {traceback.format_exc()}") - return - else: - io.log_info(f"No .json file for {filepath.relative_to(input_path)}, skipping.") - continue - - image_size = 1024 - face_type = FaceType.HEAD - extract_data = ExtractSubprocessor (extract_data, 'landmarks', image_size, face_type, device_config=device_config).run() - extract_data = ExtractSubprocessor (extract_data, 'final', image_size, face_type, device_config=device_config).run() - - for data in extract_data: - 
filepath = data.force_output_path / (data.filepath.stem+'_0.jpg') - - dflimg = DFLIMG.load(filepath) - image_to_face_mat = dflimg.get_image_to_face_mat() - - json_dict = images_jsons[data.filepath] - - ie_polys = IEPolys() - for shape in json_dict['shapes']: - ie_poly = ie_polys.add(1) - - points = np.array( [ [x,y] for x,y in shape['points'] ] ) - points = LandmarksProcessor.transform_points(points, image_to_face_mat) - - for x,y in points: - ie_poly.add( int(x), int(y) ) - - dflimg.embed_and_set (filepath, ie_polys=ie_polys) - - io.log_info(f"Images found: {len(images_paths)}") - io.log_info(f"Images processed: {images_processed}") - - - -""" -#mark only -for data in extract_data: - filepath = data.filepath - output_filepath = output_path / (filepath.stem+'.jpg') - - img = cv2_imread(filepath) - img = imagelib.normalize_channels(img, 3) - cv2_imwrite(output_filepath, img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] ) - - json_dict = images_jsons[filepath] - - ie_polys = IEPolys() - for shape in json_dict['shapes']: - ie_poly = ie_polys.add(1) - for x,y in shape['points']: - ie_poly.add( int(x), int(y) ) - - - DFLJPG.embed_data(output_filepath, face_type=FaceType.toString(FaceType.MARK_ONLY), - landmarks=data.landmarks[0], - ie_polys=ie_polys, - source_filename=filepath.name, - source_rect=data.rects[0], - source_landmarks=data.landmarks[0] - ) -""" \ No newline at end of file + for filename in io.progress_bar_generator(pathex.get_image_paths (output_path), "Processing"): + filepath = Path(filename) + + dflimg = DFLJPG.load(filepath) + + src_filename = dflimg.get_source_filename() + image_to_face_mat = dflimg.get_image_to_face_mat() + + seg_filepath = input_path / ( Path(src_filename).stem + '_seg.png') + if not seg_filepath.exists(): + raise ValueError(f'{seg_filepath} does not exist') + + seg = cv2_imread(seg_filepath) + seg_inds = np.isin(seg, [1,2,3,4,5,6,9,10,11]) + seg[~seg_inds] = 0 + seg[seg_inds] = 1 + seg = seg.astype(np.float32) + seg = cv2.warpAffine(seg, image_to_face_mat, (image_size, image_size), cv2.INTER_LANCZOS4) + dflimg.set_xseg_mask(seg) + dflimg.save() + \ No newline at end of file diff --git a/merger/InteractiveMergerSubprocessor.py b/merger/InteractiveMergerSubprocessor.py index bf92045..58db0c1 100644 --- a/merger/InteractiveMergerSubprocessor.py +++ b/merger/InteractiveMergerSubprocessor.py @@ -66,7 +66,6 @@ class InteractiveMergerSubprocessor(Subprocessor): self.predictor_func = client_dict['predictor_func'] self.predictor_input_shape = client_dict['predictor_input_shape'] self.face_enhancer_func = client_dict['face_enhancer_func'] - self.fanseg_full_face_256_extract_func = client_dict['fanseg_full_face_256_extract_func'] self.xseg_256_extract_func = client_dict['xseg_256_extract_func'] @@ -85,14 +84,19 @@ class InteractiveMergerSubprocessor(Subprocessor): filepath = frame_info.filepath if len(frame_info.landmarks_list) == 0: - self.log_info (f'no faces found for {filepath.name}, copying without faces') - - img_bgr = cv2_imread(filepath) - imagelib.normalize_channels(img_bgr, 3) + + if cfg.mode == 'raw-predict': + h,w,c = self.predictor_input_shape + img_bgr = np.zeros( (h,w,3), dtype=np.uint8) + img_mask = np.zeros( (h,w,1), dtype=np.uint8) + else: + self.log_info (f'no faces found for {filepath.name}, copying without faces') + img_bgr = cv2_imread(filepath) + imagelib.normalize_channels(img_bgr, 3) + h,w,c = img_bgr.shape + img_mask = np.zeros( (h,w,1), dtype=img_bgr.dtype) + cv2_imwrite (pf.output_filepath, img_bgr) - h,w,c = img_bgr.shape - - img_mask = np.zeros( 
(h,w,1), dtype=img_bgr.dtype) cv2_imwrite (pf.output_mask_filepath, img_mask) if pf.need_return_image: @@ -103,7 +107,6 @@ class InteractiveMergerSubprocessor(Subprocessor): try: final_img = MergeMasked (self.predictor_func, self.predictor_input_shape, face_enhancer_func=self.face_enhancer_func, - fanseg_full_face_256_extract_func=self.fanseg_full_face_256_extract_func, xseg_256_extract_func=self.xseg_256_extract_func, cfg=cfg, frame_info=frame_info) @@ -137,7 +140,7 @@ class InteractiveMergerSubprocessor(Subprocessor): #override - def __init__(self, is_interactive, merger_session_filepath, predictor_func, predictor_input_shape, face_enhancer_func, fanseg_full_face_256_extract_func, xseg_256_extract_func, merger_config, frames, frames_root_path, output_path, output_mask_path, model_iter): + def __init__(self, is_interactive, merger_session_filepath, predictor_func, predictor_input_shape, face_enhancer_func, xseg_256_extract_func, merger_config, frames, frames_root_path, output_path, output_mask_path, model_iter, subprocess_count=4): if len (frames) == 0: raise ValueError ("len (frames) == 0") @@ -151,7 +154,6 @@ class InteractiveMergerSubprocessor(Subprocessor): self.predictor_input_shape = predictor_input_shape self.face_enhancer_func = face_enhancer_func - self.fanseg_full_face_256_extract_func = fanseg_full_face_256_extract_func self.xseg_256_extract_func = xseg_256_extract_func self.frames_root_path = frames_root_path @@ -159,7 +161,7 @@ class InteractiveMergerSubprocessor(Subprocessor): self.output_mask_path = output_mask_path self.model_iter = model_iter - self.prefetch_frame_count = self.process_count = multiprocessing.cpu_count() + self.prefetch_frame_count = self.process_count = subprocess_count session_data = None if self.is_interactive and self.merger_session_filepath.exists(): @@ -273,7 +275,6 @@ class InteractiveMergerSubprocessor(Subprocessor): 'predictor_func': self.predictor_func, 'predictor_input_shape' : self.predictor_input_shape, 'face_enhancer_func': self.face_enhancer_func, - 'fanseg_full_face_256_extract_func' : self.fanseg_full_face_256_extract_func, 'xseg_256_extract_func' : self.xseg_256_extract_func, 'stdin_fd': sys.stdin.fileno() if MERGER_DEBUG else None } @@ -304,6 +305,7 @@ class InteractiveMergerSubprocessor(Subprocessor): '3' : lambda cfg,shift_pressed: cfg.set_mode(3), '4' : lambda cfg,shift_pressed: cfg.set_mode(4), '5' : lambda cfg,shift_pressed: cfg.set_mode(5), + '6' : lambda cfg,shift_pressed: cfg.set_mode(6), 'q' : lambda cfg,shift_pressed: cfg.add_hist_match_threshold(1 if not shift_pressed else 5), 'a' : lambda cfg,shift_pressed: cfg.add_hist_match_threshold(-1 if not shift_pressed else -5), 'w' : lambda cfg,shift_pressed: cfg.add_erode_mask_modifier(1 if not shift_pressed else 5), @@ -391,6 +393,7 @@ class InteractiveMergerSubprocessor(Subprocessor): # unable to read? 
recompute then cur_frame.is_done = False else: + image = imagelib.normalize_channels(image, 3) image_mask = imagelib.normalize_channels(image_mask, 1) cur_frame.image = np.concatenate([image, image_mask], -1) diff --git a/merger/MergeMasked.py b/merger/MergeMasked.py index c3ee34c..0a5c633 100644 --- a/merger/MergeMasked.py +++ b/merger/MergeMasked.py @@ -1,30 +1,25 @@ +import sys import traceback import cv2 import numpy as np from core import imagelib -from facelib import FaceType, LandmarksProcessor -from core.interact import interact as io from core.cv2ex import * +from core.interact import interact as io +from facelib import FaceType, LandmarksProcessor -fanseg_input_size = 256 +is_windows = sys.platform[0:3] == 'win' xseg_input_size = 256 -def MergeMaskedFace (predictor_func, predictor_input_shape, +def MergeMaskedFace (predictor_func, predictor_input_shape, face_enhancer_func, - fanseg_full_face_256_extract_func, xseg_256_extract_func, cfg, frame_info, img_bgr_uint8, img_bgr, img_face_landmarks): + img_size = img_bgr.shape[1], img_bgr.shape[0] img_face_mask_a = LandmarksProcessor.get_image_hull_mask (img_bgr.shape, img_face_landmarks) - if cfg.mode == 'original': - return img_bgr, img_face_mask_a - - out_img = img_bgr.copy() - out_merging_mask_a = None - input_size = predictor_input_shape[0] mask_subres_size = input_size*4 output_size = input_size @@ -48,16 +43,9 @@ def MergeMaskedFace (predictor_func, predictor_input_shape, predictor_input_bgr = cv2.resize (dst_face_bgr, (input_size,input_size) ) predicted = predictor_func (predictor_input_bgr) - if isinstance(predicted, tuple): - #merger return bgr,mask - prd_face_bgr = np.clip (predicted[0], 0, 1.0) - prd_face_mask_a_0 = np.clip (predicted[1], 0, 1.0) - predictor_masked = True - else: - #merger return bgr only, using dst mask - prd_face_bgr = np.clip (predicted, 0, 1.0 ) - prd_face_mask_a_0 = cv2.resize (dst_face_mask_a_0, (input_size,input_size) ) - predictor_masked = False + prd_face_bgr = np.clip (predicted[0], 0, 1.0) + prd_face_mask_a_0 = np.clip (predicted[1], 0, 1.0) + prd_face_dst_mask_a_0 = np.clip (predicted[2], 0, 1.0) if cfg.super_resolution_power != 0: prd_face_bgr_enhanced = face_enhancer_func(prd_face_bgr, is_tanh=True, preserve_size=False) @@ -66,289 +54,272 @@ def MergeMaskedFace (predictor_func, predictor_input_shape, prd_face_bgr = np.clip(prd_face_bgr, 0, 1) if cfg.super_resolution_power != 0: - if predictor_masked: - prd_face_mask_a_0 = cv2.resize (prd_face_mask_a_0, (output_size, output_size), cv2.INTER_CUBIC) - else: - prd_face_mask_a_0 = cv2.resize (dst_face_mask_a_0, (output_size, output_size), cv2.INTER_CUBIC) + prd_face_mask_a_0 = cv2.resize (prd_face_mask_a_0, (output_size, output_size), interpolation=cv2.INTER_CUBIC) + prd_face_dst_mask_a_0 = cv2.resize (prd_face_dst_mask_a_0, (output_size, output_size), interpolation=cv2.INTER_CUBIC) - if cfg.mask_mode == 2: #dst - prd_face_mask_a_0 = cv2.resize (dst_face_mask_a_0, (output_size,output_size), cv2.INTER_CUBIC) - elif cfg.mask_mode >= 3 and cfg.mask_mode <= 7: - - if cfg.mask_mode == 3 or cfg.mask_mode == 5 or cfg.mask_mode == 6: - prd_face_fanseg_bgr = cv2.resize (prd_face_bgr, (fanseg_input_size,)*2 ) - prd_face_fanseg_mask = fanseg_full_face_256_extract_func(prd_face_fanseg_bgr) - FAN_prd_face_mask_a_0 = cv2.resize ( prd_face_fanseg_mask, (output_size, output_size), cv2.INTER_CUBIC) - - if cfg.mask_mode >= 4 and cfg.mask_mode <= 7: - - full_face_fanseg_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, fanseg_input_size, 
face_type=FaceType.FULL) - dst_face_fanseg_bgr = cv2.warpAffine(img_bgr, full_face_fanseg_mat, (fanseg_input_size,)*2, flags=cv2.INTER_CUBIC ) - dst_face_fanseg_mask = fanseg_full_face_256_extract_func(dst_face_fanseg_bgr ) - - if cfg.face_type == FaceType.FULL: - FAN_dst_face_mask_a_0 = cv2.resize (dst_face_fanseg_mask, (output_size,output_size), cv2.INTER_CUBIC) - else: - face_fanseg_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, fanseg_input_size, face_type=cfg.face_type) - - fanseg_rect_corner_pts = np.array ( [ [0,0], [fanseg_input_size-1,0], [0,fanseg_input_size-1] ], dtype=np.float32 ) - a = LandmarksProcessor.transform_points (fanseg_rect_corner_pts, face_fanseg_mat, invert=True ) - b = LandmarksProcessor.transform_points (a, full_face_fanseg_mat ) - m = cv2.getAffineTransform(b, fanseg_rect_corner_pts) - FAN_dst_face_mask_a_0 = cv2.warpAffine(dst_face_fanseg_mask, m, (fanseg_input_size,)*2, flags=cv2.INTER_CUBIC ) - FAN_dst_face_mask_a_0 = cv2.resize (FAN_dst_face_mask_a_0, (output_size,output_size), cv2.INTER_CUBIC) - - if cfg.mask_mode == 3: #FAN-prd - prd_face_mask_a_0 = FAN_prd_face_mask_a_0 - elif cfg.mask_mode == 4: #FAN-dst - prd_face_mask_a_0 = FAN_dst_face_mask_a_0 - elif cfg.mask_mode == 5: - prd_face_mask_a_0 = FAN_prd_face_mask_a_0 * FAN_dst_face_mask_a_0 - elif cfg.mask_mode == 6: - prd_face_mask_a_0 = prd_face_mask_a_0 * FAN_prd_face_mask_a_0 * FAN_dst_face_mask_a_0 - elif cfg.mask_mode == 7: - prd_face_mask_a_0 = prd_face_mask_a_0 * FAN_dst_face_mask_a_0 - - elif cfg.mask_mode >= 8 and cfg.mask_mode <= 11: - if cfg.mask_mode == 8 or cfg.mask_mode == 10 or cfg.mask_mode == 11: - prd_face_xseg_bgr = cv2.resize (prd_face_bgr, (xseg_input_size,)*2, cv2.INTER_CUBIC) + if cfg.mask_mode == 0: #full + wrk_face_mask_a_0 = np.ones_like(dst_face_mask_a_0) + elif cfg.mask_mode == 1: #dst + wrk_face_mask_a_0 = cv2.resize (dst_face_mask_a_0, (output_size,output_size), interpolation=cv2.INTER_CUBIC) + elif cfg.mask_mode == 2: #learned-prd + wrk_face_mask_a_0 = prd_face_mask_a_0 + elif cfg.mask_mode == 3: #learned-dst + wrk_face_mask_a_0 = prd_face_dst_mask_a_0 + elif cfg.mask_mode == 4: #learned-prd*learned-dst + wrk_face_mask_a_0 = prd_face_mask_a_0*prd_face_dst_mask_a_0 + elif cfg.mask_mode == 5: #learned-prd+learned-dst + wrk_face_mask_a_0 = np.clip( prd_face_mask_a_0+prd_face_dst_mask_a_0, 0, 1) + elif cfg.mask_mode >= 6 and cfg.mask_mode <= 9: #XSeg modes + if cfg.mask_mode == 6 or cfg.mask_mode == 8 or cfg.mask_mode == 9: + # obtain XSeg-prd + prd_face_xseg_bgr = cv2.resize (prd_face_bgr, (xseg_input_size,)*2, interpolation=cv2.INTER_CUBIC) prd_face_xseg_mask = xseg_256_extract_func(prd_face_xseg_bgr) - X_prd_face_mask_a_0 = cv2.resize ( prd_face_xseg_mask, (output_size, output_size), cv2.INTER_CUBIC) + X_prd_face_mask_a_0 = cv2.resize ( prd_face_xseg_mask, (output_size, output_size), interpolation=cv2.INTER_CUBIC) - if cfg.mask_mode >= 9 and cfg.mask_mode <= 11: - whole_face_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, xseg_input_size, face_type=FaceType.WHOLE_FACE) - dst_face_xseg_bgr = cv2.warpAffine(img_bgr, whole_face_mat, (xseg_input_size,)*2, flags=cv2.INTER_CUBIC ) + if cfg.mask_mode >= 7 and cfg.mask_mode <= 9: + # obtain XSeg-dst + xseg_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, xseg_input_size, face_type=cfg.face_type) + dst_face_xseg_bgr = cv2.warpAffine(img_bgr, xseg_mat, (xseg_input_size,)*2, flags=cv2.INTER_CUBIC ) dst_face_xseg_mask = xseg_256_extract_func(dst_face_xseg_bgr) - X_dst_face_mask_a_0 = 
cv2.resize (dst_face_xseg_mask, (output_size,output_size), cv2.INTER_CUBIC) + X_dst_face_mask_a_0 = cv2.resize (dst_face_xseg_mask, (output_size,output_size), interpolation=cv2.INTER_CUBIC) - if cfg.mask_mode == 8: #'XSeg-prd', - prd_face_mask_a_0 = X_prd_face_mask_a_0 - elif cfg.mask_mode == 9: #'XSeg-dst', - prd_face_mask_a_0 = X_dst_face_mask_a_0 - elif cfg.mask_mode == 10: #'XSeg-prd*XSeg-dst', - prd_face_mask_a_0 = X_prd_face_mask_a_0 * X_dst_face_mask_a_0 - elif cfg.mask_mode == 11: #learned*XSeg-prd*XSeg-dst' - prd_face_mask_a_0 = prd_face_mask_a_0 * X_prd_face_mask_a_0 * X_dst_face_mask_a_0 - - prd_face_mask_a_0[ prd_face_mask_a_0 < (1.0/255.0) ] = 0.0 # get rid of noise + if cfg.mask_mode == 6: #'XSeg-prd' + wrk_face_mask_a_0 = X_prd_face_mask_a_0 + elif cfg.mask_mode == 7: #'XSeg-dst' + wrk_face_mask_a_0 = X_dst_face_mask_a_0 + elif cfg.mask_mode == 8: #'XSeg-prd*XSeg-dst' + wrk_face_mask_a_0 = X_prd_face_mask_a_0 * X_dst_face_mask_a_0 + elif cfg.mask_mode == 9: #learned-prd*learned-dst*XSeg-prd*XSeg-dst + wrk_face_mask_a_0 = prd_face_mask_a_0 * prd_face_dst_mask_a_0 * X_prd_face_mask_a_0 * X_dst_face_mask_a_0 + + wrk_face_mask_a_0[ wrk_face_mask_a_0 < (1.0/255.0) ] = 0.0 # get rid of noise # resize to mask_subres_size - if prd_face_mask_a_0.shape[0] != mask_subres_size: - prd_face_mask_a_0 = cv2.resize (prd_face_mask_a_0, (mask_subres_size, mask_subres_size), cv2.INTER_CUBIC) + if wrk_face_mask_a_0.shape[0] != mask_subres_size: + wrk_face_mask_a_0 = cv2.resize (wrk_face_mask_a_0, (mask_subres_size, mask_subres_size), interpolation=cv2.INTER_CUBIC) # process mask in local predicted space if 'raw' not in cfg.mode: # add zero pad - prd_face_mask_a_0 = np.pad (prd_face_mask_a_0, input_size) + wrk_face_mask_a_0 = np.pad (wrk_face_mask_a_0, input_size) ero = cfg.erode_mask_modifier blur = cfg.blur_mask_modifier if ero > 0: - prd_face_mask_a_0 = cv2.erode(prd_face_mask_a_0, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(ero,ero)), iterations = 1 ) + wrk_face_mask_a_0 = cv2.erode(wrk_face_mask_a_0, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(ero,ero)), iterations = 1 ) elif ero < 0: - prd_face_mask_a_0 = cv2.dilate(prd_face_mask_a_0, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(-ero,-ero)), iterations = 1 ) + wrk_face_mask_a_0 = cv2.dilate(wrk_face_mask_a_0, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(-ero,-ero)), iterations = 1 ) # clip eroded/dilated mask in actual predict area # pad with half blur size in order to accuratelly fade to zero at the boundary clip_size = input_size + blur // 2 - prd_face_mask_a_0[:clip_size,:] = 0 - prd_face_mask_a_0[-clip_size:,:] = 0 - prd_face_mask_a_0[:,:clip_size] = 0 - prd_face_mask_a_0[:,-clip_size:] = 0 + wrk_face_mask_a_0[:clip_size,:] = 0 + wrk_face_mask_a_0[-clip_size:,:] = 0 + wrk_face_mask_a_0[:,:clip_size] = 0 + wrk_face_mask_a_0[:,-clip_size:] = 0 if blur > 0: blur = blur + (1-blur % 2) - prd_face_mask_a_0 = cv2.GaussianBlur(prd_face_mask_a_0, (blur, blur) , 0) + wrk_face_mask_a_0 = cv2.GaussianBlur(wrk_face_mask_a_0, (blur, blur) , 0) - prd_face_mask_a_0 = prd_face_mask_a_0[input_size:-input_size,input_size:-input_size] + wrk_face_mask_a_0 = wrk_face_mask_a_0[input_size:-input_size,input_size:-input_size] - prd_face_mask_a_0 = np.clip(prd_face_mask_a_0, 0, 1) + wrk_face_mask_a_0 = np.clip(wrk_face_mask_a_0, 0, 1) - img_face_mask_a = cv2.warpAffine( prd_face_mask_a_0, face_mask_output_mat, img_size, np.zeros(img_bgr.shape[0:2], dtype=np.float32), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC )[...,None] + img_face_mask_a = cv2.warpAffine( 
wrk_face_mask_a_0, face_mask_output_mat, img_size, np.zeros(img_bgr.shape[0:2], dtype=np.float32), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC )[...,None] img_face_mask_a = np.clip (img_face_mask_a, 0.0, 1.0) - img_face_mask_a [ img_face_mask_a < (1.0/255.0) ] = 0.0 # get rid of noise - if prd_face_mask_a_0.shape[0] != output_size: - prd_face_mask_a_0 = cv2.resize (prd_face_mask_a_0, (output_size,output_size), cv2.INTER_CUBIC) + if wrk_face_mask_a_0.shape[0] != output_size: + wrk_face_mask_a_0 = cv2.resize (wrk_face_mask_a_0, (output_size,output_size), interpolation=cv2.INTER_CUBIC) - prd_face_mask_a = prd_face_mask_a_0[...,None] - prd_face_mask_area_a = prd_face_mask_a.copy() - prd_face_mask_area_a[prd_face_mask_area_a>0] = 1.0 + wrk_face_mask_a = wrk_face_mask_a_0[...,None] - if 'raw' in cfg.mode: + out_img = None + out_merging_mask_a = None + if cfg.mode == 'original': + return img_bgr, img_face_mask_a + + elif 'raw' in cfg.mode: if cfg.mode == 'raw-rgb': - out_img = cv2.warpAffine( prd_face_bgr, face_output_mat, img_size, out_img, cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC, cv2.BORDER_TRANSPARENT ) + out_img_face = cv2.warpAffine( prd_face_bgr, face_output_mat, img_size, np.empty_like(img_bgr), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC) + out_img_face_mask = cv2.warpAffine( np.ones_like(prd_face_bgr), face_output_mat, img_size, np.empty_like(img_bgr), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC) + out_img = img_bgr*(1-out_img_face_mask) + out_img_face*out_img_face_mask out_merging_mask_a = img_face_mask_a + elif cfg.mode == 'raw-predict': + out_img = prd_face_bgr + out_merging_mask_a = wrk_face_mask_a + else: + raise ValueError(f"undefined raw type {cfg.mode}") out_img = np.clip (out_img, 0.0, 1.0 ) else: - #averaging [lenx, leny, maskx, masky] by grayscale gradients of upscaled mask - ar = [] - for i in range(1, 10): - maxregion = np.argwhere( img_face_mask_a > i / 10.0 ) - if maxregion.size != 0: - miny,minx = maxregion.min(axis=0)[:2] - maxy,maxx = maxregion.max(axis=0)[:2] - lenx = maxx - minx - leny = maxy - miny - if min(lenx,leny) >= 4: - ar += [ [ lenx, leny] ] - if len(ar) > 0: + # Process if the mask meets minimum size + maxregion = np.argwhere( img_face_mask_a >= 0.1 ) + if maxregion.size != 0: + miny,minx = maxregion.min(axis=0)[:2] + maxy,maxx = maxregion.max(axis=0)[:2] + lenx = maxx - minx + leny = maxy - miny + if min(lenx,leny) >= 4: + wrk_face_mask_area_a = wrk_face_mask_a.copy() + wrk_face_mask_area_a[wrk_face_mask_area_a>0] = 1.0 - if 'seamless' not in cfg.mode and cfg.color_transfer_mode != 0: - if cfg.color_transfer_mode == 1: #rct - prd_face_bgr = imagelib.reinhard_color_transfer ( np.clip( prd_face_bgr*prd_face_mask_area_a*255, 0, 255).astype(np.uint8), - np.clip( dst_face_bgr*prd_face_mask_area_a*255, 0, 255).astype(np.uint8), ) - - prd_face_bgr = np.clip( prd_face_bgr.astype(np.float32) / 255.0, 0.0, 1.0) - elif cfg.color_transfer_mode == 2: #lct - prd_face_bgr = imagelib.linear_color_transfer (prd_face_bgr, dst_face_bgr) - elif cfg.color_transfer_mode == 3: #mkl - prd_face_bgr = imagelib.color_transfer_mkl (prd_face_bgr, dst_face_bgr) - elif cfg.color_transfer_mode == 4: #mkl-m - prd_face_bgr = imagelib.color_transfer_mkl (prd_face_bgr*prd_face_mask_area_a, dst_face_bgr*prd_face_mask_area_a) - elif cfg.color_transfer_mode == 5: #idt - prd_face_bgr = imagelib.color_transfer_idt (prd_face_bgr, dst_face_bgr) - elif cfg.color_transfer_mode == 6: #idt-m - prd_face_bgr = imagelib.color_transfer_idt (prd_face_bgr*prd_face_mask_area_a, dst_face_bgr*prd_face_mask_area_a) - elif 
cfg.color_transfer_mode == 7: #sot-m - prd_face_bgr = imagelib.color_transfer_sot (prd_face_bgr*prd_face_mask_area_a, dst_face_bgr*prd_face_mask_area_a) - prd_face_bgr = np.clip (prd_face_bgr, 0.0, 1.0) - elif cfg.color_transfer_mode == 8: #mix-m - prd_face_bgr = imagelib.color_transfer_mix (prd_face_bgr*prd_face_mask_area_a, dst_face_bgr*prd_face_mask_area_a) + if 'seamless' not in cfg.mode and cfg.color_transfer_mode != 0: + if cfg.color_transfer_mode == 1: #rct + prd_face_bgr = imagelib.reinhard_color_transfer (prd_face_bgr, dst_face_bgr, target_mask=wrk_face_mask_area_a, source_mask=wrk_face_mask_area_a) + elif cfg.color_transfer_mode == 2: #lct + prd_face_bgr = imagelib.linear_color_transfer (prd_face_bgr, dst_face_bgr) + elif cfg.color_transfer_mode == 3: #mkl + prd_face_bgr = imagelib.color_transfer_mkl (prd_face_bgr, dst_face_bgr) + elif cfg.color_transfer_mode == 4: #mkl-m + prd_face_bgr = imagelib.color_transfer_mkl (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a) + elif cfg.color_transfer_mode == 5: #idt + prd_face_bgr = imagelib.color_transfer_idt (prd_face_bgr, dst_face_bgr) + elif cfg.color_transfer_mode == 6: #idt-m + prd_face_bgr = imagelib.color_transfer_idt (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a) + elif cfg.color_transfer_mode == 7: #sot-m + prd_face_bgr = imagelib.color_transfer_sot (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a, steps=10, batch_size=30) + prd_face_bgr = np.clip (prd_face_bgr, 0.0, 1.0) + elif cfg.color_transfer_mode == 8: #mix-m + prd_face_bgr = imagelib.color_transfer_mix (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a) - if cfg.mode == 'hist-match': - hist_mask_a = np.ones ( prd_face_bgr.shape[:2] + (1,) , dtype=np.float32) + if cfg.mode == 'hist-match': + hist_mask_a = np.ones ( prd_face_bgr.shape[:2] + (1,) , dtype=np.float32) - if cfg.masked_hist_match: - hist_mask_a *= prd_face_mask_area_a + if cfg.masked_hist_match: + hist_mask_a *= wrk_face_mask_area_a - white = (1.0-hist_mask_a)* np.ones ( prd_face_bgr.shape[:2] + (1,) , dtype=np.float32) + white = (1.0-hist_mask_a)* np.ones ( prd_face_bgr.shape[:2] + (1,) , dtype=np.float32) - hist_match_1 = prd_face_bgr*hist_mask_a + white - hist_match_1[ hist_match_1 > 1.0 ] = 1.0 + hist_match_1 = prd_face_bgr*hist_mask_a + white + hist_match_1[ hist_match_1 > 1.0 ] = 1.0 - hist_match_2 = dst_face_bgr*hist_mask_a + white - hist_match_2[ hist_match_1 > 1.0 ] = 1.0 + hist_match_2 = dst_face_bgr*hist_mask_a + white + hist_match_2[ hist_match_1 > 1.0 ] = 1.0 - prd_face_bgr = imagelib.color_hist_match(hist_match_1, hist_match_2, cfg.hist_match_threshold ).astype(dtype=np.float32) + prd_face_bgr = imagelib.color_hist_match(hist_match_1, hist_match_2, cfg.hist_match_threshold ).astype(dtype=np.float32) - if 'seamless' in cfg.mode: - #mask used for cv2.seamlessClone - img_face_seamless_mask_a = None - for i in range(1,10): - a = img_face_mask_a > i / 10.0 - if len(np.argwhere(a)) == 0: - continue - img_face_seamless_mask_a = img_face_mask_a.copy() - img_face_seamless_mask_a[a] = 1.0 - img_face_seamless_mask_a[img_face_seamless_mask_a <= i / 10.0] = 0.0 - break + if 'seamless' in cfg.mode: + #mask used for cv2.seamlessClone + img_face_seamless_mask_a = None + for i in range(1,10): + a = img_face_mask_a > i / 10.0 + if len(np.argwhere(a)) == 0: + continue + img_face_seamless_mask_a = img_face_mask_a.copy() + img_face_seamless_mask_a[a] = 1.0 + img_face_seamless_mask_a[img_face_seamless_mask_a <= i / 10.0] = 0.0 + break - 
out_img = cv2.warpAffine( prd_face_bgr, face_output_mat, img_size, out_img, cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC, cv2.BORDER_TRANSPARENT ) + out_img = cv2.warpAffine( prd_face_bgr, face_output_mat, img_size, np.empty_like(img_bgr), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC ) + out_img = np.clip(out_img, 0.0, 1.0) - out_img = np.clip(out_img, 0.0, 1.0) + if 'seamless' in cfg.mode: + try: + #calc same bounding rect and center point as in cv2.seamlessClone to prevent jittering (not flickering) + l,t,w,h = cv2.boundingRect( (img_face_seamless_mask_a*255).astype(np.uint8) ) + s_maskx, s_masky = int(l+w/2), int(t+h/2) + out_img = cv2.seamlessClone( (out_img*255).astype(np.uint8), img_bgr_uint8, (img_face_seamless_mask_a*255).astype(np.uint8), (s_maskx,s_masky) , cv2.NORMAL_CLONE ) + out_img = out_img.astype(dtype=np.float32) / 255.0 + except Exception as e: + #seamlessClone may fail in some cases + e_str = traceback.format_exc() - if 'seamless' in cfg.mode: - try: - #calc same bounding rect and center point as in cv2.seamlessClone to prevent jittering (not flickering) - l,t,w,h = cv2.boundingRect( (img_face_seamless_mask_a*255).astype(np.uint8) ) - s_maskx, s_masky = int(l+w/2), int(t+h/2) - out_img = cv2.seamlessClone( (out_img*255).astype(np.uint8), img_bgr_uint8, (img_face_seamless_mask_a*255).astype(np.uint8), (s_maskx,s_masky) , cv2.NORMAL_CLONE ) - out_img = out_img.astype(dtype=np.float32) / 255.0 - except Exception as e: - #seamlessClone may fail in some cases - e_str = traceback.format_exc() + if 'MemoryError' in e_str: + raise Exception("Seamless fail: " + e_str) #reraise MemoryError in order to reprocess this data by other processes + else: + print ("Seamless fail: " + e_str) - if 'MemoryError' in e_str: - raise Exception("Seamless fail: " + e_str) #reraise MemoryError in order to reprocess this data by other processes + cfg_mp = cfg.motion_blur_power / 100.0 + + out_img = img_bgr*(1-img_face_mask_a) + (out_img*img_face_mask_a) + + if ('seamless' in cfg.mode and cfg.color_transfer_mode != 0) or \ + cfg.mode == 'seamless-hist-match' or \ + cfg_mp != 0 or \ + cfg.blursharpen_amount != 0 or \ + cfg.image_denoise_power != 0 or \ + cfg.bicubic_degrade_power != 0: + + out_face_bgr = cv2.warpAffine( out_img, face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC ) + + if 'seamless' in cfg.mode and cfg.color_transfer_mode != 0: + if cfg.color_transfer_mode == 1: + out_face_bgr = imagelib.reinhard_color_transfer (out_face_bgr, dst_face_bgr, target_mask=wrk_face_mask_area_a, source_mask=wrk_face_mask_area_a) + elif cfg.color_transfer_mode == 2: #lct + out_face_bgr = imagelib.linear_color_transfer (out_face_bgr, dst_face_bgr) + elif cfg.color_transfer_mode == 3: #mkl + out_face_bgr = imagelib.color_transfer_mkl (out_face_bgr, dst_face_bgr) + elif cfg.color_transfer_mode == 4: #mkl-m + out_face_bgr = imagelib.color_transfer_mkl (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a) + elif cfg.color_transfer_mode == 5: #idt + out_face_bgr = imagelib.color_transfer_idt (out_face_bgr, dst_face_bgr) + elif cfg.color_transfer_mode == 6: #idt-m + out_face_bgr = imagelib.color_transfer_idt (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a) + elif cfg.color_transfer_mode == 7: #sot-m + out_face_bgr = imagelib.color_transfer_sot (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a, steps=10, batch_size=30) + out_face_bgr = np.clip (out_face_bgr, 0.0, 1.0) + elif cfg.color_transfer_mode == 8: #mix-m + out_face_bgr = imagelib.color_transfer_mix 
(out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a) + + if cfg.mode == 'seamless-hist-match': + out_face_bgr = imagelib.color_hist_match(out_face_bgr, dst_face_bgr, cfg.hist_match_threshold) + + if cfg_mp != 0: + k_size = int(frame_info.motion_power*cfg_mp) + if k_size >= 1: + k_size = np.clip (k_size+1, 2, 50) + if cfg.super_resolution_power != 0: + k_size *= 2 + out_face_bgr = imagelib.LinearMotionBlur (out_face_bgr, k_size , frame_info.motion_deg) + + if cfg.blursharpen_amount != 0: + out_face_bgr = imagelib.blursharpen ( out_face_bgr, cfg.sharpen_mode, 3, cfg.blursharpen_amount) + + if cfg.image_denoise_power != 0: + n = cfg.image_denoise_power + while n > 0: + img_bgr_denoised = cv2.medianBlur(img_bgr, 5) + if int(n / 100) != 0: + img_bgr = img_bgr_denoised + else: + pass_power = (n % 100) / 100.0 + img_bgr = img_bgr*(1.0-pass_power)+img_bgr_denoised*pass_power + n = max(n-10,0) + + if cfg.bicubic_degrade_power != 0: + p = 1.0 - cfg.bicubic_degrade_power / 101.0 + img_bgr_downscaled = cv2.resize (img_bgr, ( int(img_size[0]*p), int(img_size[1]*p ) ), interpolation=cv2.INTER_CUBIC) + img_bgr = cv2.resize (img_bgr_downscaled, img_size, interpolation=cv2.INTER_CUBIC) + + new_out = cv2.warpAffine( out_face_bgr, face_mat, img_size, np.empty_like(img_bgr), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC ) + + out_img = np.clip( img_bgr*(1-img_face_mask_a) + (new_out*img_face_mask_a) , 0, 1.0 ) + + if cfg.color_degrade_power != 0: + out_img_reduced = imagelib.reduce_colors(out_img, 256) + if cfg.color_degrade_power == 100: + out_img = out_img_reduced else: - print ("Seamless fail: " + e_str) - - - out_img = img_bgr*(1-img_face_mask_a) + (out_img*img_face_mask_a) - - out_face_bgr = cv2.warpAffine( out_img, face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC ) - - if 'seamless' in cfg.mode and cfg.color_transfer_mode != 0: - if cfg.color_transfer_mode == 1: - out_face_bgr = imagelib.reinhard_color_transfer ( np.clip(out_face_bgr*prd_face_mask_area_a*255, 0, 255).astype(np.uint8), - np.clip(dst_face_bgr*prd_face_mask_area_a*255, 0, 255).astype(np.uint8) ) - out_face_bgr = np.clip( out_face_bgr.astype(np.float32) / 255.0, 0.0, 1.0) - elif cfg.color_transfer_mode == 2: #lct - out_face_bgr = imagelib.linear_color_transfer (out_face_bgr, dst_face_bgr) - elif cfg.color_transfer_mode == 3: #mkl - out_face_bgr = imagelib.color_transfer_mkl (out_face_bgr, dst_face_bgr) - elif cfg.color_transfer_mode == 4: #mkl-m - out_face_bgr = imagelib.color_transfer_mkl (out_face_bgr*prd_face_mask_area_a, dst_face_bgr*prd_face_mask_area_a) - elif cfg.color_transfer_mode == 5: #idt - out_face_bgr = imagelib.color_transfer_idt (out_face_bgr, dst_face_bgr) - elif cfg.color_transfer_mode == 6: #idt-m - out_face_bgr = imagelib.color_transfer_idt (out_face_bgr*prd_face_mask_area_a, dst_face_bgr*prd_face_mask_area_a) - elif cfg.color_transfer_mode == 7: #sot-m - out_face_bgr = imagelib.color_transfer_sot (out_face_bgr*prd_face_mask_area_a, dst_face_bgr*prd_face_mask_area_a) - out_face_bgr = np.clip (out_face_bgr, 0.0, 1.0) - elif cfg.color_transfer_mode == 8: #mix-m - out_face_bgr = imagelib.color_transfer_mix (out_face_bgr*prd_face_mask_area_a, dst_face_bgr*prd_face_mask_area_a) - - if cfg.mode == 'seamless-hist-match': - out_face_bgr = imagelib.color_hist_match(out_face_bgr, dst_face_bgr, cfg.hist_match_threshold) - - cfg_mp = cfg.motion_blur_power / 100.0 - if cfg_mp != 0: - k_size = int(frame_info.motion_power*cfg_mp) - if k_size >= 1: - k_size = np.clip (k_size+1, 2, 50) - if 
cfg.super_resolution_power != 0: - k_size *= 2 - out_face_bgr = imagelib.LinearMotionBlur (out_face_bgr, k_size , frame_info.motion_deg) - - if cfg.blursharpen_amount != 0: - out_face_bgr = imagelib.blursharpen ( out_face_bgr, cfg.sharpen_mode, 3, cfg.blursharpen_amount) - - - if cfg.image_denoise_power != 0: - n = cfg.image_denoise_power - while n > 0: - img_bgr_denoised = cv2.medianBlur(img_bgr, 5) - if int(n / 100) != 0: - img_bgr = img_bgr_denoised - else: - pass_power = (n % 100) / 100.0 - img_bgr = img_bgr*(1.0-pass_power)+img_bgr_denoised*pass_power - n = max(n-10,0) - - if cfg.bicubic_degrade_power != 0: - p = 1.0 - cfg.bicubic_degrade_power / 101.0 - img_bgr_downscaled = cv2.resize (img_bgr, ( int(img_size[0]*p), int(img_size[1]*p ) ), cv2.INTER_CUBIC) - img_bgr = cv2.resize (img_bgr_downscaled, img_size, cv2.INTER_CUBIC) - - new_out = cv2.warpAffine( out_face_bgr, face_mat, img_size, img_bgr.copy(), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC, cv2.BORDER_TRANSPARENT ) - out_img = np.clip( img_bgr*(1-img_face_mask_a) + (new_out*img_face_mask_a) , 0, 1.0 ) - - if cfg.color_degrade_power != 0: - out_img_reduced = imagelib.reduce_colors(out_img, 256) - if cfg.color_degrade_power == 100: - out_img = out_img_reduced - else: - alpha = cfg.color_degrade_power / 100.0 - out_img = (out_img*(1.0-alpha) + out_img_reduced*alpha) - + alpha = cfg.color_degrade_power / 100.0 + out_img = (out_img*(1.0-alpha) + out_img_reduced*alpha) out_merging_mask_a = img_face_mask_a + if out_img is None: + out_img = img_bgr.copy() + return out_img, out_merging_mask_a -def MergeMasked (predictor_func, +def MergeMasked (predictor_func, predictor_input_shape, face_enhancer_func, - fanseg_full_face_256_extract_func, - xseg_256_extract_func, - cfg, + xseg_256_extract_func, + cfg, frame_info): img_bgr_uint8 = cv2_imread(frame_info.filepath) img_bgr_uint8 = imagelib.normalize_channels (img_bgr_uint8, 3) @@ -356,7 +327,7 @@ def MergeMasked (predictor_func, outs = [] for face_num, img_landmarks in enumerate( frame_info.landmarks_list ): - out_img, out_img_merging_mask = MergeMaskedFace (predictor_func, predictor_input_shape, face_enhancer_func, fanseg_full_face_256_extract_func, xseg_256_extract_func, cfg, frame_info, img_bgr_uint8, img_bgr, img_landmarks) + out_img, out_img_merging_mask = MergeMaskedFace (predictor_func, predictor_input_shape, face_enhancer_func, xseg_256_extract_func, cfg, frame_info, img_bgr_uint8, img_bgr, img_landmarks) outs += [ (out_img, out_img_merging_mask) ] #Combining multiple face outputs @@ -374,4 +345,4 @@ def MergeMasked (predictor_func, final_img = np.concatenate ( [final_img, final_mask], -1) - return (final_img*255).astype(np.uint8) \ No newline at end of file + return (final_img*255).astype(np.uint8) diff --git a/merger/MergerConfig.py b/merger/MergerConfig.py index c385769..eba1493 100644 --- a/merger/MergerConfig.py +++ b/merger/MergerConfig.py @@ -76,41 +76,23 @@ mode_dict = {0:'original', 2:'hist-match', 3:'seamless', 4:'seamless-hist-match', - 5:'raw-rgb',} + 5:'raw-rgb', + 6:'raw-predict'} -mode_str_dict = {} +mode_str_dict = { mode_dict[key] : key for key in mode_dict.keys() } -for key in mode_dict.keys(): - mode_str_dict[ mode_dict[key] ] = key +mask_mode_dict = {0:'full', + 1:'dst', + 2:'learned-prd', + 3:'learned-dst', + 4:'learned-prd*learned-dst', + 5:'learned-prd+learned-dst', + 6:'XSeg-prd', + 7:'XSeg-dst', + 8:'XSeg-prd*XSeg-dst', + 9:'learned-prd*learned-dst*XSeg-prd*XSeg-dst' + } -""" -whole_face_mask_mode_dict = {1:'learned', - 2:'dst', - 3:'FAN-prd', - 4:'FAN-dst', - 
5:'FAN-prd*FAN-dst', - 6:'learned*FAN-prd*FAN-dst' - } -""" -whole_face_mask_mode_dict = {1:'learned', - 2:'dst', - 8:'XSeg-prd', - 9:'XSeg-dst', - 10:'XSeg-prd*XSeg-dst', - 11:'learned*XSeg-prd*XSeg-dst' - } - -full_face_mask_mode_dict = {1:'learned', - 2:'dst', - 3:'FAN-prd', - 4:'FAN-dst', - 5:'FAN-prd*FAN-dst', - 6:'learned*FAN-prd*FAN-dst'} - -half_face_mask_mode_dict = {1:'learned', - 2:'dst', - 4:'FAN-dst', - 7:'learned*FAN-dst'} ctm_dict = { 0: "None", 1:"rct", 2:"lct", 3:"mkl", 4:"mkl-m", 5:"idt", 6:"idt-m", 7:"sot-m", 8:"mix-m" } ctm_str_dict = {None:0, "rct":1, "lct":2, "mkl":3, "mkl-m":4, "idt":5, "idt-m":6, "sot-m":7, "mix-m":8 } @@ -122,7 +104,7 @@ class MergerConfigMasked(MergerConfig): mode='overlay', masked_hist_match=True, hist_match_threshold = 238, - mask_mode = 1, + mask_mode = 4, erode_mask_modifier = 0, blur_mask_modifier = 0, motion_blur_power = 0, @@ -138,7 +120,7 @@ class MergerConfigMasked(MergerConfig): super().__init__(type=MergerConfig.TYPE_MASKED, **kwargs) self.face_type = face_type - if self.face_type not in [FaceType.HALF, FaceType.MID_FULL, FaceType.FULL, FaceType.WHOLE_FACE ]: + if self.face_type not in [FaceType.HALF, FaceType.MID_FULL, FaceType.FULL, FaceType.WHOLE_FACE, FaceType.HEAD ]: raise ValueError("MergerConfigMasked does not support this type of face.") self.default_mode = default_mode @@ -176,12 +158,7 @@ class MergerConfigMasked(MergerConfig): self.hist_match_threshold = np.clip ( self.hist_match_threshold+diff , 0, 255) def toggle_mask_mode(self): - if self.face_type == FaceType.WHOLE_FACE: - a = list( whole_face_mask_mode_dict.keys() ) - elif self.face_type == FaceType.FULL: - a = list( full_face_mask_mode_dict.keys() ) - else: - a = list( half_face_mask_mode_dict.keys() ) + a = list( mask_mode_dict.keys() ) self.mask_mode = a[ (a.index(self.mask_mode)+1) % len(a) ] def add_erode_mask_modifier(self, diff): @@ -227,26 +204,11 @@ class MergerConfigMasked(MergerConfig): if self.mode == 'hist-match' or self.mode == 'seamless-hist-match': self.hist_match_threshold = np.clip ( io.input_int("Hist match threshold", 255, add_info="0..255"), 0, 255) - if self.face_type == FaceType.WHOLE_FACE: - s = """Choose mask mode: \n""" - for key in whole_face_mask_mode_dict.keys(): - s += f"""({key}) {whole_face_mask_mode_dict[key]}\n""" - io.log_info(s) - - self.mask_mode = io.input_int ("", 1, valid_list=whole_face_mask_mode_dict.keys() ) - elif self.face_type == FaceType.FULL: - s = """Choose mask mode: \n""" - for key in full_face_mask_mode_dict.keys(): - s += f"""({key}) {full_face_mask_mode_dict[key]}\n""" - io.log_info(s) - - self.mask_mode = io.input_int ("", 1, valid_list=full_face_mask_mode_dict.keys(), help_message="If you learned the mask, then option 1 should be choosed. 'dst' mask is raw shaky mask from dst aligned images. 'FAN-prd' - using super smooth mask by pretrained FAN-model from predicted face. 'FAN-dst' - using super smooth mask by pretrained FAN-model from dst face. 'FAN-prd*FAN-dst' or 'learned*FAN-prd*FAN-dst' - using multiplied masks.") - else: - s = """Choose mask mode: \n""" - for key in half_face_mask_mode_dict.keys(): - s += f"""({key}) {half_face_mask_mode_dict[key]}\n""" - io.log_info(s) - self.mask_mode = io.input_int ("", 1, valid_list=half_face_mask_mode_dict.keys(), help_message="If you learned the mask, then option 1 should be choosed. 
'dst' mask is raw shaky mask from dst aligned images.") + s = """Choose mask mode: \n""" + for key in mask_mode_dict.keys(): + s += f"""({key}) {mask_mode_dict[key]}\n""" + io.log_info(s) + self.mask_mode = io.input_int ("", 1, valid_list=mask_mode_dict.keys() ) if 'raw' not in self.mode: self.erode_mask_modifier = np.clip ( io.input_int ("Choose erode mask modifier", 0, add_info="-400..400"), -400, 400) @@ -303,12 +265,7 @@ class MergerConfigMasked(MergerConfig): if self.mode == 'hist-match' or self.mode == 'seamless-hist-match': r += f"""hist_match_threshold: {self.hist_match_threshold}\n""" - if self.face_type == FaceType.WHOLE_FACE: - r += f"""mask_mode: { whole_face_mask_mode_dict[self.mask_mode] }\n""" - elif self.face_type == FaceType.FULL: - r += f"""mask_mode: { full_face_mask_mode_dict[self.mask_mode] }\n""" - else: - r += f"""mask_mode: { half_face_mask_mode_dict[self.mask_mode] }\n""" + r += f"""mask_mode: { mask_mode_dict[self.mask_mode] }\n""" if 'raw' not in self.mode: r += (f"""erode_mask_modifier: {self.erode_mask_modifier}\n""" @@ -319,8 +276,8 @@ class MergerConfigMasked(MergerConfig): if 'raw' not in self.mode: r += f"""color_transfer_mode: {ctm_dict[self.color_transfer_mode]}\n""" + r += super().to_string(filename) - r += super().to_string(filename) r += f"""super_resolution_power: {self.super_resolution_power}\n""" if 'raw' not in self.mode: diff --git a/merger/gfx/help_merger_masked.jpg b/merger/gfx/help_merger_masked.jpg index f3c31e0..f1822f3 100644 Binary files a/merger/gfx/help_merger_masked.jpg and b/merger/gfx/help_merger_masked.jpg differ diff --git a/merger/gfx/help_merger_masked_source.psd b/merger/gfx/help_merger_masked_source.psd index 25a440f..b88c8ff 100644 Binary files a/merger/gfx/help_merger_masked_source.psd and b/merger/gfx/help_merger_masked_source.psd differ diff --git a/models/ModelBase.py b/models/ModelBase.py index 7ec1085..f446efa 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -1,6 +1,7 @@ import colorsys import inspect import json +import multiprocessing import operator import os import pickle @@ -12,16 +13,16 @@ from pathlib import Path import cv2 import numpy as np -from core import imagelib +from core import imagelib, pathex +from core.cv2ex import * from core.interact import interact as io from core.leras import nn from samplelib import SampleGeneratorBase -from core import pathex -from core.cv2ex import * class ModelBase(object): def __init__(self, is_training=False, + is_exporting=False, saved_models_path=None, training_data_src_path=None, training_data_dst_path=None, @@ -33,8 +34,10 @@ class ModelBase(object): cpu_only=False, debug=False, force_model_class_name=None, + silent_start=False, **kwargs): self.is_training = is_training + self.is_exporting = is_exporting self.saved_models_path = saved_models_path self.training_data_src_path = training_data_src_path self.training_data_dst_path = training_data_dst_path @@ -61,71 +64,77 @@ class ModelBase(object): saved_models_names = sorted(saved_models_names, key=operator.itemgetter(1), reverse=True ) saved_models_names = [ x[0] for x in saved_models_names ] + if len(saved_models_names) != 0: - io.log_info ("Choose one of saved models, or enter a name to create a new model.") - io.log_info ("[r] : rename") - io.log_info ("[d] : delete") - io.log_info ("") - for i, model_name in enumerate(saved_models_names): - s = f"[{i}] : {model_name} " - if i == 0: - s += "- latest" - io.log_info (s) + if silent_start: + self.model_name = saved_models_names[0] + io.log_info(f'Silent start: 
choosed model "{self.model_name}"') + else: + io.log_info ("Choose one of saved models, or enter a name to create a new model.") + io.log_info ("[r] : rename") + io.log_info ("[d] : delete") + io.log_info ("") + for i, model_name in enumerate(saved_models_names): + s = f"[{i}] : {model_name} " + if i == 0: + s += "- latest" + io.log_info (s) - inp = io.input_str(f"", "0", show_default_value=False ) - model_idx = -1 - try: - model_idx = np.clip ( int(inp), 0, len(saved_models_names)-1 ) - except: - pass + inp = io.input_str(f"", "0", show_default_value=False ) + model_idx = -1 + try: + model_idx = np.clip ( int(inp), 0, len(saved_models_names)-1 ) + except: + pass - if model_idx == -1: - if len(inp) == 1: - is_rename = inp[0] == 'r' - is_delete = inp[0] == 'd' + if model_idx == -1: + if len(inp) == 1: + is_rename = inp[0] == 'r' + is_delete = inp[0] == 'd' - if is_rename or is_delete: - if len(saved_models_names) != 0: - - if is_rename: - name = io.input_str(f"Enter the name of the model you want to rename") - elif is_delete: - name = io.input_str(f"Enter the name of the model you want to delete") - - if name in saved_models_names: + if is_rename or is_delete: + if len(saved_models_names) != 0: if is_rename: - new_model_name = io.input_str(f"Enter new name of the model") + name = io.input_str(f"Enter the name of the model you want to rename") + elif is_delete: + name = io.input_str(f"Enter the name of the model you want to delete") - for filepath in pathex.get_paths(saved_models_path): - filepath_name = filepath.name + if name in saved_models_names: - model_filename, remain_filename = filepath_name.split('_', 1) - if model_filename == name: + if is_rename: + new_model_name = io.input_str(f"Enter new name of the model") - if is_rename: - new_filepath = filepath.parent / ( new_model_name + '_' + remain_filename ) - filepath.rename (new_filepath) - elif is_delete: - filepath.unlink() - continue + for filepath in pathex.get_paths(saved_models_path): + filepath_name = filepath.name - self.model_name = inp - else: - self.model_name = saved_models_names[model_idx] + model_filename, remain_filename = filepath_name.split('_', 1) + if model_filename == name: + + if is_rename: + new_filepath = filepath.parent / ( new_model_name + '_' + remain_filename ) + filepath.rename (new_filepath) + elif is_delete: + filepath.unlink() + continue + + self.model_name = inp + else: + self.model_name = saved_models_names[model_idx] else: self.model_name = io.input_str(f"No saved models found. 
Enter a name of a new model", "new") self.model_name = self.model_name.replace('_', ' ') break - + self.model_name = self.model_name + '_' + self.model_class_name else: self.model_name = force_model_class_name self.iter = 0 self.options = {} + self.options_show_override = {} self.loss_history = [] self.sample_for_preview = None self.choosed_gpu_indexes = None @@ -145,8 +154,12 @@ class ModelBase(object): if self.is_first_run(): io.log_info ("\nModel first run.") - self.device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(suggest_best_multi_gpu=True)) \ - if not cpu_only else nn.DeviceConfig.CPU() + if silent_start: + self.device_config = nn.DeviceConfig.BestGPU() + io.log_info (f"Silent start: choosed device {'CPU' if self.device_config.cpu_only else self.device_config.devices[0].name}") + else: + self.device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(suggest_best_multi_gpu=True)) \ + if not cpu_only else nn.DeviceConfig.CPU() nn.initialize(self.device_config) @@ -174,10 +187,13 @@ class ModelBase(object): self.write_preview_history = self.options.get('write_preview_history', False) self.target_iter = self.options.get('target_iter',0) self.random_flip = self.options.get('random_flip',True) - + self.random_src_flip = self.options.get('random_src_flip', False) + self.random_dst_flip = self.options.get('random_dst_flip', True) + self.on_initialize() self.options['batch_size'] = self.batch_size + self.preview_history_writer = None if self.is_training: self.preview_history_path = self.saved_models_path / ( f'{self.get_model_name()}_history' ) self.autobackups_path = self.saved_models_path / ( f'{self.get_model_name()}_autobackups' ) @@ -198,7 +214,7 @@ class ModelBase(object): raise ValueError('training data generator is not subclass of SampleGeneratorBase') self.update_sample_for_preview(choose_preview_history=self.choose_preview_history) - + if self.autobackup_hour != 0: self.autobackup_start_time = time.time() @@ -206,19 +222,21 @@ class ModelBase(object): self.autobackups_path.mkdir(exist_ok=True) io.log_info( self.get_summary_text() ) - + def update_sample_for_preview(self, choose_preview_history=False, force_new=False): if self.sample_for_preview is None or choose_preview_history or force_new: if choose_preview_history and io.is_support_windows(): - io.log_info ("Choose image for the preview history. [p] - next. [enter] - confirm.") - wnd_name = "[p] - next. [enter] - confirm." + wnd_name = "[p] - next. [space] - switch preview type. [enter] - confirm." + io.log_info (f"Choose image for the preview history. 
{wnd_name}") io.named_window(wnd_name) io.capture_keys(wnd_name) choosed = False + preview_id_counter = 0 while not choosed: self.sample_for_preview = self.generate_next_samples() - preview = self.get_static_preview() - io.show_image( wnd_name, (preview*255).astype(np.uint8) ) + previews = self.get_history_previews() + + io.show_image( wnd_name, ( previews[preview_id_counter % len(previews) ][1] *255).astype(np.uint8) ) while True: key_events = io.get_key_events(wnd_name) @@ -226,6 +244,9 @@ class ModelBase(object): if key == ord('\n') or key == ord('\r'): choosed = True break + elif key == ord(' '): + preview_id_counter += 1 + break elif key == ord('p'): break @@ -239,12 +260,12 @@ class ModelBase(object): self.sample_for_preview = self.generate_next_samples() try: - self.get_static_preview() + self.get_history_previews() except: self.sample_for_preview = self.generate_next_samples() self.last_sample = self.sample_for_preview - + def load_or_def_option(self, name, def_value): options_val = self.options.get(name, None) if options_val is not None: @@ -280,10 +301,24 @@ class ModelBase(object): def ask_random_flip(self): default_random_flip = self.load_or_def_option('random_flip', True) self.options['random_flip'] = io.input_bool("Flip faces randomly", default_random_flip, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.") + + def ask_random_src_flip(self): + default_random_src_flip = self.load_or_def_option('random_src_flip', False) + self.options['random_src_flip'] = io.input_bool("Flip SRC faces randomly", default_random_src_flip, help_message="Random horizontal flip SRC faceset. Covers more angles, but the face may look less naturally.") - def ask_batch_size(self, suggest_batch_size=None): + def ask_random_dst_flip(self): + default_random_dst_flip = self.load_or_def_option('random_dst_flip', True) + self.options['random_dst_flip'] = io.input_bool("Flip DST faces randomly", default_random_dst_flip, help_message="Random horizontal flip DST faceset. Makes generalization of src->dst better, if src random flip is not enabled.") + + def ask_batch_size(self, suggest_batch_size=None, range=None): default_batch_size = self.load_or_def_option('batch_size', suggest_batch_size or self.batch_size) - self.options['batch_size'] = self.batch_size = max(0, io.input_int("Batch_size", default_batch_size, help_message="Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually.")) + + batch_size = max(0, io.input_int("Batch_size", default_batch_size, valid_range=range, help_message="Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually.")) + + if range is not None: + batch_size = np.clip(batch_size, range[0], range[1]) + + self.options['batch_size'] = self.batch_size = batch_size #overridable @@ -314,7 +349,7 @@ class ModelBase(object): return ( ('loss_src', 0), ('loss_dst', 0) ) #overridable - def onGetPreview(self, sample): + def onGetPreview(self, sample, for_history=False): #you can return multiple previews #return [ ('preview_name',preview_rgb), ... 
] return [] @@ -344,8 +379,13 @@ class ModelBase(object): def get_previews(self): return self.onGetPreview ( self.last_sample ) - def get_static_preview(self): - return self.onGetPreview (self.sample_for_preview)[0][1] #first preview, and bgr + def get_history_previews(self): + return self.onGetPreview (self.sample_for_preview, for_history=True) + + def get_preview_history_writer(self): + if self.preview_history_writer is None: + self.preview_history_writer = PreviewHistoryWriter() + return self.preview_history_writer def save(self): Path( self.get_summary_path() ).write_text( self.get_summary_text() ) @@ -367,16 +407,16 @@ class ModelBase(object): if diff_hour > 0 and diff_hour % self.autobackup_hour == 0: self.autobackup_start_time += self.autobackup_hour*3600 self.create_backup() - + def create_backup(self): io.log_info ("Creating backup...", end='\r') - + if not self.autobackups_path.exists(): self.autobackups_path.mkdir(exist_ok=True) - + bckp_filename_list = [ self.get_strpath_storage_for_file(filename) for _, filename in self.get_model_filename_list() ] bckp_filename_list += [ str(self.get_summary_path()), str(self.model_data_path) ] - + for i in range(24,0,-1): idx_str = '%.2d' % i next_idx_str = '%.2d' % (i+1) @@ -402,10 +442,8 @@ class ModelBase(object): name, bgr = previews[i] plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ] - for preview, filepath in plist: - preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2]) - img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8) - cv2_imwrite (filepath, img ) + if len(plist) != 0: + self.get_preview_history_writer().post(plist, self.loss_history, self.iter) def debug_one_iter(self): images = [] @@ -417,7 +455,7 @@ class ModelBase(object): return imagelib.equalize_and_stack_square (images) def generate_next_samples(self): - sample = [] + sample = [] for generator in self.generator_list: if generator.is_initialized(): sample.append ( generator.generate_next() ) @@ -426,6 +464,10 @@ class ModelBase(object): self.last_sample = sample return sample + #overridable + def should_save_preview_history(self): + return (not io.is_colab() and self.iter % 10 == 0) or (io.is_colab() and self.iter % 100 == 0) + def train_one_iter(self): iter_time = time.time() @@ -434,7 +476,7 @@ class ModelBase(object): self.loss_history.append ( [float(loss[1]) for loss in losses] ) - if self.iter % 10 == 0: + if self.should_save_preview_history(): plist = [] if io.is_colab(): @@ -444,12 +486,16 @@ class ModelBase(object): plist += [ (bgr, self.get_strpath_storage_for_file('preview_%s.jpg' % (name) ) ) ] if self.write_preview_history: - plist += [ (self.get_static_preview(), str (self.preview_history_path / ('%.6d.jpg' % (self.iter))) ) ] + previews = self.get_history_previews() + for i in range(len(previews)): + name, bgr = previews[i] + path = self.preview_history_path / name + plist += [ ( bgr, str ( path / ( f'{self.iter:07d}.jpg') ) ) ] + if not io.is_colab(): + plist += [ ( bgr, str ( path / ( '_last.jpg' ) )) ] - for preview, filepath in plist: - preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2]) - img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8) - cv2_imwrite (filepath, img ) + if len(plist) != 0: + self.get_preview_history_writer().post(plist, self.loss_history, self.iter) self.iter += 1 @@ -497,12 +543,15 @@ class ModelBase(object): def get_summary_path(self): 
return self.get_strpath_storage_for_file('summary.txt') - + def get_summary_text(self): + visible_options = self.options.copy() + visible_options.update(self.options_show_override) + ###Generate text summary of model hyperparameters #Find the longest key name and value string. Used as column widths. - width_name = max([len(k) for k in self.options.keys()] + [17]) + 1 # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration" - width_value = max([len(str(x)) for x in self.options.values()] + [len(str(self.get_iter())), len(self.get_model_name())]) + 1 # Single space buffer to right edge + width_name = max([len(k) for k in visible_options.keys()] + [17]) + 1 # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration" + width_value = max([len(str(x)) for x in visible_options.values()] + [len(str(self.get_iter())), len(self.get_model_name())]) + 1 # Single space buffer to right edge if len(self.device_config.devices) != 0: #Check length of GPU names width_value = max([len(device.name)+1 for device in self.device_config.devices] + [width_value]) width_total = width_name + width_value + 2 #Plus 2 for ": " @@ -517,8 +566,8 @@ class ModelBase(object): summary_text += [f'=={" Model Options ":-^{width_total}}=='] # Model options summary_text += [f'=={" "*width_total}=='] - for key in self.options.keys(): - summary_text += [f'=={key: >{width_name}}: {str(self.options[key]): <{width_value}}=='] # self.options key/value pairs + for key in visible_options.keys(): + summary_text += [f'=={key: >{width_name}}: {str(visible_options[key]): <{width_value}}=='] # visible_options key/value pairs summary_text += [f'=={" "*width_total}=='] summary_text += [f'=={" Running On ":-^{width_total}}=='] # Training hardware info @@ -596,3 +645,41 @@ class ModelBase(object): lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image ( (last_line_b-last_line_t,w,c), lh_text, color=[0.8]*c ) return lh_img + +class PreviewHistoryWriter(): + def __init__(self): + self.sq = multiprocessing.Queue() + self.p = multiprocessing.Process(target=self.process, args=( self.sq, )) + self.p.daemon = True + self.p.start() + + def process(self, sq): + while True: + while not sq.empty(): + plist, loss_history, iter = sq.get() + + preview_lh_cache = {} + for preview, filepath in plist: + filepath = Path(filepath) + i = (preview.shape[1], preview.shape[2]) + + preview_lh = preview_lh_cache.get(i, None) + if preview_lh is None: + preview_lh = ModelBase.get_loss_history_preview(loss_history, iter, preview.shape[1], preview.shape[2]) + preview_lh_cache[i] = preview_lh + + img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8) + + filepath.parent.mkdir(parents=True, exist_ok=True) + cv2_imwrite (filepath, img ) + + time.sleep(0.01) + + def post(self, plist, loss_history, iter): + self.sq.put ( (plist, loss_history, iter) ) + + # disable pickling + def __getstate__(self): + return dict() + def __setstate__(self, d): + self.__dict__.update(d) diff --git a/models/Model_AMP/Model.py b/models/Model_AMP/Model.py new file mode 100644 index 0000000..82b0dc5 --- /dev/null +++ b/models/Model_AMP/Model.py @@ -0,0 +1,725 @@ +import multiprocessing +import operator +from functools import partial + +import numpy as np + +from core import mathlib +from core.interact import interact as io +from core.leras import nn +from facelib import FaceType +from models import ModelBase +from samplelib import * +from core.cv2ex import 
* + +class AMPModel(ModelBase): + + #override + def on_initialize_options(self): + default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 224) + default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf') + default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True) + + default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256) + default_inter_dims = self.options['inter_dims'] = self.load_or_def_option('inter_dims', 1024) + + default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64) + default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None) + default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None) + default_morph_factor = self.options['morph_factor'] = self.options.get('morph_factor', 0.5) + default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False) + default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False) + default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', 'n') + default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True) + default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none') + default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False) + + ask_override = self.ask_override() + if self.is_first_run() or ask_override: + self.ask_autobackup_hour() + self.ask_write_preview_history() + self.ask_target_iter() + self.ask_random_src_flip() + self.ask_random_dst_flip() + self.ask_batch_size(8) + + if self.is_first_run(): + resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="Higher resolution requires more VRAM and time to train. The value will be adjusted to a multiple of 32.") + resolution = np.clip ( (resolution // 32) * 32, 64, 640) + self.options['resolution'] = resolution + self.options['face_type'] = io.input_str ("Face type", default_face_type, ['f','wf','head'], help_message="full face / whole face / head").lower() + + + default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64) + + default_d_mask_dims = default_d_dims // 3 + default_d_mask_dims += default_d_mask_dims % 2 + default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims) + + if self.is_first_run(): + self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will be packed into AE dims. If the amount of AE dims is not enough, then, for example, closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 ) + self.options['inter_dims'] = np.clip ( io.input_int("Inter dimensions", default_inter_dims, add_info="32-2048", help_message="Should be equal to or greater than AutoEncoder dimensions. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 2048 ) + + e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve a sharper result, but require more VRAM. You can fine-tune model size to fit your GPU."
), 16, 256 ) + self.options['e_dims'] = e_dims + e_dims % 2 + + d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve a sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) + self.options['d_dims'] = d_dims + d_dims % 2 + + d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 ) + self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2 + + morph_factor = np.clip ( io.input_number ("Morph factor.", default_morph_factor, add_info="0.1 .. 0.5", help_message="Typical fine value is 0.5"), 0.1, 0.5 ) + self.options['morph_factor'] = morph_factor + + if self.is_first_run() or ask_override: + self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces caused by the small number of them in the faceset.') + self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs the area outside of the applied face mask of training samples. As a result, the background near the face is smoothed and less noticeable on the swapped face. An exact xseg mask in the src and dst facesets is required.') + self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake in fewer iterations. Enable it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This avoids using extra VRAM, sacrificing about 20% of iteration time.") + + default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0) + default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8) + default_gan_dims = self.options['gan_dims'] = self.load_or_def_option('gan_dims', 16) + + if self.is_first_run() or ask_override: + self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place them on CPU to free up extra VRAM and thus set bigger dimensions.") + + self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake in fewer iterations.") + + self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 5.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with random_warp(off), and do not disable it afterwards. The higher the value, the higher the chances of artifacts. 
Typical fine value is 0.1"), 0.0, 5.0 ) + + if self.options['gan_power'] != 0.0: + gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher the patch size, the higher the quality and the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is resolution / 8." ), 3, 640 ) + self.options['gan_patch_size'] = gan_patch_size + + gan_dims = np.clip ( io.input_int("GAN dimensions", default_gan_dims, add_info="4-512", help_message="The dimensions of the GAN network. The higher the dimensions, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is 16." ), 4, 512 ) + self.options['gan_dims'] = gan_dims + + self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. If the src faceset is diverse enough, then lct mode is fine in most cases.") + self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces the chance of model collapse, sacrificing training speed.") + + self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims']) + + #override + def on_initialize(self): + device_config = nn.getCurrentDeviceConfig() + devices = device_config.devices + self.model_data_format = "NCHW" + nn.initialize(data_format=self.model_data_format) + tf = nn.tf + + input_ch=3 + resolution = self.resolution = self.options['resolution'] + e_dims = self.options['e_dims'] + ae_dims = self.options['ae_dims'] + inter_dims = self.inter_dims = self.options['inter_dims'] + inter_res = self.inter_res = resolution // 32 + d_dims = self.options['d_dims'] + d_mask_dims = self.options['d_mask_dims'] + face_type = self.face_type = {'f' : FaceType.FULL, + 'wf' : FaceType.WHOLE_FACE, + 'head' : FaceType.HEAD}[ self.options['face_type'] ] + morph_factor = self.options['morph_factor'] + gan_power = self.gan_power = self.options['gan_power'] + random_warp = self.options['random_warp'] + + blur_out_mask = self.options['blur_out_mask'] + + ct_mode = self.options['ct_mode'] + if ct_mode == 'none': + ct_mode = None + + use_fp16 = False + if self.is_exporting: + use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. 
If you have problems, disable this option.') + + conv_dtype = tf.float16 if use_fp16 else tf.float32 + + class Downscale(nn.ModelBase): + def on_build(self, in_ch, out_ch, kernel_size=5 ): + self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=conv_dtype) + + def forward(self, x): + return tf.nn.leaky_relu(self.conv1(x), 0.1) + + class Upscale(nn.ModelBase): + def on_build(self, in_ch, out_ch, kernel_size=3 ): + self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + + def forward(self, x): + x = nn.depth_to_space(tf.nn.leaky_relu(self.conv1(x), 0.1), 2) + return x + + class ResidualBlock(nn.ModelBase): + def on_build(self, ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + + def forward(self, inp): + x = self.conv1(inp) + x = tf.nn.leaky_relu(x, 0.2) + x = self.conv2(x) + x = tf.nn.leaky_relu(inp+x, 0.2) + return x + + class Encoder(nn.ModelBase): + def on_build(self): + self.down1 = Downscale(input_ch, e_dims, kernel_size=5) + self.res1 = ResidualBlock(e_dims) + self.down2 = Downscale(e_dims, e_dims*2, kernel_size=5) + self.down3 = Downscale(e_dims*2, e_dims*4, kernel_size=5) + self.down4 = Downscale(e_dims*4, e_dims*8, kernel_size=5) + self.down5 = Downscale(e_dims*8, e_dims*8, kernel_size=5) + self.res5 = ResidualBlock(e_dims*8) + self.dense1 = nn.Dense( (( resolution//(2**5) )**2) * e_dims*8, ae_dims ) + + def forward(self, x): + if use_fp16: + x = tf.cast(x, tf.float16) + x = self.down1(x) + x = self.res1(x) + x = self.down2(x) + x = self.down3(x) + x = self.down4(x) + x = self.down5(x) + x = self.res5(x) + if use_fp16: + x = tf.cast(x, tf.float32) + x = nn.pixel_norm(nn.flatten(x), axes=-1) + x = self.dense1(x) + return x + + + class Inter(nn.ModelBase): + def on_build(self): + self.dense2 = nn.Dense(ae_dims, inter_res * inter_res * inter_dims) + + def forward(self, inp): + x = inp + x = self.dense2(x) + x = nn.reshape_4D (x, inter_res, inter_res, inter_dims) + return x + + + class Decoder(nn.ModelBase): + def on_build(self ): + self.upscale0 = Upscale(inter_dims, d_dims*8, kernel_size=3) + self.upscale1 = Upscale(d_dims*8, d_dims*8, kernel_size=3) + self.upscale2 = Upscale(d_dims*8, d_dims*4, kernel_size=3) + self.upscale3 = Upscale(d_dims*4, d_dims*2, kernel_size=3) + + self.res0 = ResidualBlock(d_dims*8, kernel_size=3) + self.res1 = ResidualBlock(d_dims*8, kernel_size=3) + self.res2 = ResidualBlock(d_dims*4, kernel_size=3) + self.res3 = ResidualBlock(d_dims*2, kernel_size=3) + + self.upscalem0 = Upscale(inter_dims, d_mask_dims*8, kernel_size=3) + self.upscalem1 = Upscale(d_mask_dims*8, d_mask_dims*8, kernel_size=3) + self.upscalem2 = Upscale(d_mask_dims*8, d_mask_dims*4, kernel_size=3) + self.upscalem3 = Upscale(d_mask_dims*4, d_mask_dims*2, kernel_size=3) + self.upscalem4 = Upscale(d_mask_dims*2, d_mask_dims*1, kernel_size=3) + self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) + + self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype) + self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + + def forward(self, z): + if use_fp16: + z = 
tf.cast(z, tf.float16) + + x = self.upscale0(z) + x = self.res0(x) + x = self.upscale1(x) + x = self.res1(x) + x = self.upscale2(x) + x = self.res2(x) + x = self.upscale3(x) + x = self.res3(x) + + x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x), + self.out_conv1(x), + self.out_conv2(x), + self.out_conv3(x)), nn.conv2d_ch_axis), 2) ) + m = self.upscalem0(z) + m = self.upscalem1(m) + m = self.upscalem2(m) + m = self.upscalem3(m) + m = self.upscalem4(m) + m = tf.nn.sigmoid(self.out_convm(m)) + + if use_fp16: + x = tf.cast(x, tf.float32) + m = tf.cast(m, tf.float32) + return x, m + + models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu'] + models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0' + optimizer_vars_on_cpu = models_opt_device=='/CPU:0' + + bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) + mask_shape = nn.get4Dshape(resolution,resolution,1) + self.model_filename_list = [] + + with tf.device ('/CPU:0'): + #Place holders on CPU + self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src') + self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst') + + self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src') + self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst') + + self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm') + self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em') + self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm') + self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em') + + self.morph_value_t = tf.placeholder (nn.floatx, (1,), name='morph_value_t') + + # Initializing model classes + with tf.device (models_opt_device): + self.encoder = Encoder(name='encoder') + self.inter_src = Inter(name='inter_src') + self.inter_dst = Inter(name='inter_dst') + self.decoder = Decoder(name='decoder') + + self.model_filename_list += [ [self.encoder, 'encoder.npy'], + [self.inter_src, 'inter_src.npy'], + [self.inter_dst , 'inter_dst.npy'], + [self.decoder , 'decoder.npy'] ] + + if self.is_training: + # Initialize optimizers + clipnorm = 1.0 if self.options['clipgrad'] else 0.0 + if self.options['lr_dropout'] in ['y','cpu']: + lr_cos = 500 + lr_dropout = 0.3 + else: + lr_cos = 0 + lr_dropout = 1.0 + self.G_weights = self.encoder.get_weights() + self.decoder.get_weights() + + self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='src_dst_opt') + self.src_dst_opt.initialize_variables (self.G_weights, vars_on_cpu=optimizer_vars_on_cpu) + self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] + + if gan_power != 0: + self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN") + self.GAN_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='GAN_opt') + self.GAN_opt.initialize_variables ( self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu) + self.model_filename_list += [ [self.GAN, 'GAN.npy'], + [self.GAN_opt, 'GAN_opt.npy'] ] + + if self.is_training: + # Adjust batch size for multiple GPU + gpu_count = max(1, len(devices) ) + bs_per_gpu = max(1, self.get_batch_size() // gpu_count) + self.set_batch_size( gpu_count*bs_per_gpu) + + # Compute losses per GPU + gpu_pred_src_src_list = [] + gpu_pred_dst_dst_list = [] + 
gpu_pred_src_dst_list = [] + gpu_pred_src_srcm_list = [] + gpu_pred_dst_dstm_list = [] + gpu_pred_src_dstm_list = [] + + gpu_src_losses = [] + gpu_dst_losses = [] + gpu_G_loss_gradients = [] + gpu_GAN_loss_gradients = [] + + def DLossOnes(logits): + return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits), axis=[1,2,3]) + + def DLossZeros(logits): + return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1,2,3]) + + for gpu_id in range(gpu_count): + with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): + with tf.device(f'/CPU:0'): + # slice on CPU, otherwise all batch data will be transfered to GPU first + batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) + gpu_warped_src = self.warped_src [batch_slice,:,:,:] + gpu_warped_dst = self.warped_dst [batch_slice,:,:,:] + gpu_target_src = self.target_src [batch_slice,:,:,:] + gpu_target_dst = self.target_dst [batch_slice,:,:,:] + gpu_target_srcm = self.target_srcm[batch_slice,:,:,:] + gpu_target_srcm_em = self.target_srcm_em[batch_slice,:,:,:] + gpu_target_dstm = self.target_dstm[batch_slice,:,:,:] + gpu_target_dstm_em = self.target_dstm_em[batch_slice,:,:,:] + + # process model tensors + gpu_src_code = self.encoder (gpu_warped_src) + gpu_dst_code = self.encoder (gpu_warped_dst) + + gpu_src_inter_src_code, gpu_src_inter_dst_code = self.inter_src (gpu_src_code), self.inter_dst (gpu_src_code) + gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code) + + inter_dims_bin = int(inter_dims*morph_factor) + with tf.device(f'/CPU:0'): + inter_rnd_binomial = tf.stack([tf.random.shuffle(tf.concat([tf.tile(tf.constant([1], tf.float32), ( inter_dims_bin, )), + tf.tile(tf.constant([0], tf.float32), ( inter_dims-inter_dims_bin, ))], 0 )) for _ in range(bs_per_gpu)], 0) + + inter_rnd_binomial = tf.stop_gradient(inter_rnd_binomial[...,None,None]) + + gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial) + gpu_dst_code = gpu_dst_inter_dst_code + + inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32) + gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]), + tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 ) + + gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) + gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + + gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) + gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) + gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) + + gpu_target_srcm_anti = 1-gpu_target_srcm + gpu_target_dstm_anti = 1-gpu_target_dstm + + gpu_target_srcm_gblur = nn.gaussian_blur(gpu_target_srcm, resolution // 32) + gpu_target_dstm_gblur = nn.gaussian_blur(gpu_target_dstm, resolution // 32) + + gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_gblur, 0, 0.5) * 2 + gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_gblur, 0, 0.5) * 2 + gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur + gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur + + if blur_out_mask: + sigma = 
resolution / 128 + + x = nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, sigma) + y = 1-nn.gaussian_blur(gpu_target_srcm, sigma) + y = tf.where(tf.equal(y, 0), tf.ones_like(y), y) + gpu_target_src = gpu_target_src*gpu_target_srcm + (x/y)*gpu_target_srcm_anti + + x = nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, sigma) + y = 1-nn.gaussian_blur(gpu_target_dstm, sigma) + y = tf.where(tf.equal(y, 0), tf.ones_like(y), y) + gpu_target_dst = gpu_target_dst*gpu_target_dstm + (x/y)*gpu_target_dstm_anti + + gpu_target_src_masked = gpu_target_src*gpu_target_srcm_blur + gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur + gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur + gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur + + gpu_pred_src_src_masked = gpu_pred_src_src*gpu_target_srcm_blur + gpu_pred_dst_dst_masked = gpu_pred_dst_dst*gpu_target_dstm_blur + gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur + gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur + + # Structural loss + gpu_src_loss = tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) + gpu_dst_loss = tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + gpu_dst_loss += tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1]) + + # Pixel loss + gpu_src_loss += tf.reduce_mean (10*tf.square(gpu_target_src_masked-gpu_pred_src_src_masked), axis=[1,2,3]) + gpu_dst_loss += tf.reduce_mean (10*tf.square(gpu_target_dst_masked-gpu_pred_dst_dst_masked), axis=[1,2,3]) + + # Eyes+mouth prio loss + gpu_src_loss += tf.reduce_mean (300*tf.abs (gpu_target_src*gpu_target_srcm_em-gpu_pred_src_src*gpu_target_srcm_em), axis=[1,2,3]) + gpu_dst_loss += tf.reduce_mean (300*tf.abs (gpu_target_dst*gpu_target_dstm_em-gpu_pred_dst_dst*gpu_target_dstm_em), axis=[1,2,3]) + + # Mask loss + gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] ) + gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) + + gpu_src_losses += [gpu_src_loss] + gpu_dst_losses += [gpu_dst_loss] + gpu_G_loss = gpu_src_loss + gpu_dst_loss + # dst-dst background weak loss + gpu_G_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] ) + gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked) + + + if gan_power != 0: + gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked) + gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked) + gpu_target_src_d, gpu_target_src_d2 = self.GAN(gpu_target_src_masked) + gpu_target_dst_d, gpu_target_dst_d2 = self.GAN(gpu_target_dst_masked) + + gpu_GAN_loss = (DLossOnes (gpu_target_src_d) + DLossOnes (gpu_target_src_d2) + \ + DLossZeros(gpu_pred_src_src_d) + DLossZeros(gpu_pred_src_src_d2) + \ + DLossOnes (gpu_target_dst_d) + DLossOnes (gpu_target_dst_d2) + \ + DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2) + ) * (1.0 / 8) + + gpu_GAN_loss_gradients += [ nn.gradients (gpu_GAN_loss, self.GAN.get_weights() ) ] + + gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + 
DLossOnes(gpu_pred_src_src_d2) + \ + DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2) + ) * gan_power + + # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan + gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src) + gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] ) + + gpu_G_loss_gradients += [ nn.gradients ( gpu_G_loss, self.G_weights ) ] + + # Average losses and gradients, and create optimizer update ops + with tf.device(f'/CPU:0'): + pred_src_src = nn.concat(gpu_pred_src_src_list, 0) + pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0) + pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0) + pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0) + pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0) + pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0) + + with tf.device (models_opt_device): + src_loss = tf.concat(gpu_src_losses, 0) + dst_loss = tf.concat(gpu_dst_losses, 0) + train_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gradients)) + + if gan_power != 0: + GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_gradients) ) + + # Initializing training and view functions + def train(warped_src, target_src, target_srcm, target_srcm_em, \ + warped_dst, target_dst, target_dstm, target_dstm_em, ): + s, d, _ = nn.tf_sess.run ([src_loss, dst_loss, train_op], + feed_dict={self.warped_src :warped_src, + self.target_src :target_src, + self.target_srcm:target_srcm, + self.target_srcm_em:target_srcm_em, + self.warped_dst :warped_dst, + self.target_dst :target_dst, + self.target_dstm:target_dstm, + self.target_dstm_em:target_dstm_em, + }) + return s, d + self.train = train + + if gan_power != 0: + def GAN_train(warped_src, target_src, target_srcm, target_srcm_em, \ + warped_dst, target_dst, target_dstm, target_dstm_em, ): + nn.tf_sess.run ([GAN_train_op], feed_dict={self.warped_src :warped_src, + self.target_src :target_src, + self.target_srcm:target_srcm, + self.target_srcm_em:target_srcm_em, + self.warped_dst :warped_dst, + self.target_dst :target_dst, + self.target_dstm:target_dstm, + self.target_dstm_em:target_dstm_em}) + self.GAN_train = GAN_train + + def AE_view(warped_src, warped_dst, morph_value): + return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm], + feed_dict={self.warped_src:warped_src, self.warped_dst:warped_dst, self.morph_value_t:[morph_value] }) + + self.AE_view = AE_view + else: + #Initializing merge function + with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'): + gpu_dst_code = self.encoder (self.warped_dst) + gpu_dst_inter_src_code = self.inter_src (gpu_dst_code) + gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code) + + inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32) + gpu_src_dst_code = tf.concat( ( tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]), + tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 ) + + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + _, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code) + + def AE_merge(warped_dst, morph_value): + return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst, self.morph_value_t:[morph_value] }) + + self.AE_merge = AE_merge + + # Loading/initializing all 
models/optimizers weights + for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): + do_init = self.is_first_run() + if self.is_training and gan_power != 0 and model == self.GAN: + if self.gan_model_changed: + do_init = True + if not do_init: + do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) ) + if do_init: + model.init_weights() + ############### + + # initializing sample generators + if self.is_training: + training_data_src_path = self.training_data_src_path #if not self.pretrain else self.get_pretraining_data_path() + training_data_dst_path = self.training_data_dst_path #if not self.pretrain else self.get_pretraining_data_path() + + random_ct_samples_path=training_data_dst_path if ct_mode is not None else None #and not self.pretrain + + cpu_count = multiprocessing.cpu_count() + src_generators_count = cpu_count // 2 + dst_generators_count = cpu_count // 2 + if ct_mode is not None: + src_generators_count = int(src_generators_count * 1.5) + + + + self.set_training_data_generators ([ + SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=self.random_src_flip), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain, + generators_count=src_generators_count ), + + SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=self.random_dst_flip), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 
'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain, + generators_count=dst_generators_count ) + ]) + + def export_dfm (self): + output_path=self.get_strpath_storage_for_file('model.dfm') + + io.log_info(f'Dumping .dfm to {output_path}') + + tf = nn.tf + with tf.device (nn.tf_default_device_name): + warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face') + warped_dst = tf.transpose(warped_dst, (0,3,1,2)) + morph_value = tf.placeholder (nn.floatx, (1,), name='morph_value') + + gpu_dst_code = self.encoder (warped_dst) + gpu_dst_inter_src_code = self.inter_src ( gpu_dst_code) + gpu_dst_inter_dst_code = self.inter_dst ( gpu_dst_code) + + inter_dims_slice = tf.cast(self.inter_dims*morph_value[0], tf.int32) + gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , self.inter_res, self.inter_res]), + tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,self.inter_dims-inter_dims_slice, self.inter_res,self.inter_res]) ), 1 ) + + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + _, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code) + + gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1)) + gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1)) + gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1)) + + tf.identity(gpu_pred_dst_dstm, name='out_face_mask') + tf.identity(gpu_pred_src_dst, name='out_celeb_face') + tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask') + + output_graph_def = tf.graph_util.convert_variables_to_constants( + nn.tf_sess, + tf.get_default_graph().as_graph_def(), + ['out_face_mask','out_celeb_face','out_celeb_face_mask'] + ) + + import tf2onnx + with tf.device("/CPU:0"): + model_proto, _ = tf2onnx.convert._convert_common( + output_graph_def, + name='AMP', + input_names=['in_face:0','morph_value:0'], + output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'], + opset=12, + output_path=output_path) + + #override + def get_model_filename_list(self): + return self.model_filename_list + + #override + def onSave(self): + for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False): + model.save_weights ( self.get_strpath_storage_for_file(filename) ) + + #override + def should_save_preview_history(self): + return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \ + (io.is_colab() and self.iter % 100 == 0) + + #override + def onTrainOneIter(self): + bs = self.get_batch_size() + + ( (warped_src, target_src, target_srcm, target_srcm_em), \ + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples() + + src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + + if self.gan_power != 0: + self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + + return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) + + #override + def onGetPreview(self, samples, for_history=False): + ( (warped_src, target_src, target_srcm, target_srcm_em), + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples + + S, D, SS, DD, 
DDM_000, _, _ = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst, 0.0) ) ] + + _, _, DDM_025, SD_025, SDM_025 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.25) ] + _, _, DDM_050, SD_050, SDM_050 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.50) ] + _, _, DDM_065, SD_065, SDM_065 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.65) ] + _, _, DDM_075, SD_075, SDM_075 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.75) ] + _, _, DDM_100, SD_100, SDM_100 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 1.00) ] + + (DDM_000, + DDM_025, SDM_025, + DDM_050, SDM_050, + DDM_065, SDM_065, + DDM_075, SDM_075, + DDM_100, SDM_100) = [ np.repeat (x, (3,), -1) for x in (DDM_000, + DDM_025, SDM_025, + DDM_050, SDM_050, + DDM_065, SDM_065, + DDM_075, SDM_075, + DDM_100, SDM_100) ] + + target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )] + + n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) + + result = [] + + i = np.random.randint(n_samples) if not for_history else 0 + + st = [ np.concatenate ((S[i], D[i], DD[i]*DDM_000[i]), axis=1) ] + st += [ np.concatenate ((SS[i], DD[i], SD_100[i] ), axis=1) ] + + result += [ ('AMP morph 1.0', np.concatenate (st, axis=0 )), ] + + st = [ np.concatenate ((DD[i], SD_025[i], SD_050[i]), axis=1) ] + st += [ np.concatenate ((SD_065[i], SD_075[i], SD_100[i]), axis=1) ] + result += [ ('AMP morph list', np.concatenate (st, axis=0 )), ] + + st = [ np.concatenate ((DD[i], SD_025[i]*DDM_025[i]*SDM_025[i], SD_050[i]*DDM_050[i]*SDM_050[i]), axis=1) ] + st += [ np.concatenate ((SD_065[i]*DDM_065[i]*SDM_065[i], SD_075[i]*DDM_075[i]*SDM_075[i], SD_100[i]*DDM_100[i]*SDM_100[i]), axis=1) ] + result += [ ('AMP morph list masked', np.concatenate (st, axis=0 )), ] + + return result + + def predictor_func (self, face, morph_value): + face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC") + + bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face, morph_value) ] + + return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0] + + #override + def get_MergerConfig(self): + morph_factor = np.clip ( io.input_number ("Morph factor", 1.0, add_info="0.0 .. 
1.0"), 0.0, 1.0 ) + + def predictor_morph(face): + return self.predictor_func(face, morph_factor) + + + import merger + return predictor_morph, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay') + +Model = AMPModel diff --git a/models/Model_FANSeg/__init__.py b/models/Model_AMP/__init__.py similarity index 100% rename from models/Model_FANSeg/__init__.py rename to models/Model_AMP/__init__.py diff --git a/models/Model_FANSeg/Model.py b/models/Model_FANSeg/Model.py deleted file mode 100644 index d41adbd..0000000 --- a/models/Model_FANSeg/Model.py +++ /dev/null @@ -1,188 +0,0 @@ -import multiprocessing -import operator -from functools import partial - -import numpy as np - -from core import mathlib -from core.interact import interact as io -from core.leras import nn -from facelib import FaceType, TernausNet -from models import ModelBase -from samplelib import * - -class FANSegModel(ModelBase): - - def __init__(self, *args, **kwargs): - super().__init__(*args, force_model_class_name='FANSeg', **kwargs) - - #override - def on_initialize_options(self): - device_config = nn.getCurrentDeviceConfig() - yn_str = {True:'y',False:'n'} - - ask_override = self.ask_override() - if self.is_first_run() or ask_override: - self.ask_autobackup_hour() - self.ask_target_iter() - self.ask_batch_size(24) - - default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', False) - - if self.is_first_run() or ask_override: - self.options['lr_dropout'] = io.input_bool ("Use learning rate dropout", default_lr_dropout, help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations.") - - #override - def on_initialize(self): - device_config = nn.getCurrentDeviceConfig() - nn.initialize(data_format="NHWC") - tf = nn.tf - - device_config = nn.getCurrentDeviceConfig() - devices = device_config.devices - - self.resolution = resolution = 256 - self.face_type = FaceType.FULL - - place_model_on_cpu = len(devices) == 0 - models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0' - - bgr_shape = nn.get4Dshape(resolution,resolution,3) - mask_shape = nn.get4Dshape(resolution,resolution,1) - - # Initializing model classes - self.model = TernausNet(f'FANSeg_{FaceType.toString(self.face_type)}', - resolution, - load_weights=not self.is_first_run(), - weights_file_root=self.get_model_root_path(), - training=True, - place_model_on_cpu=place_model_on_cpu, - optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3 if self.options['lr_dropout'] else 1.0,name='opt') ) - - if self.is_training: - # Adjust batch size for multiple GPU - gpu_count = max(1, len(devices) ) - bs_per_gpu = max(1, self.get_batch_size() // gpu_count) - self.set_batch_size( gpu_count*bs_per_gpu) - - - # Compute losses per GPU - gpu_pred_list = [] - - gpu_losses = [] - gpu_loss_gvs = [] - - for gpu_id in range(gpu_count): - with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): - - with tf.device(f'/CPU:0'): - # slice on CPU, otherwise all batch data will be transfered to GPU first - batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) - gpu_input_t = self.model.input_t [batch_slice,:,:,:] - gpu_target_t = self.model.target_t [batch_slice,:,:,:] - - # process model tensors - gpu_pred_logits_t, gpu_pred_t = self.model.net([gpu_input_t]) - gpu_pred_list.append(gpu_pred_t) - - gpu_loss = tf.reduce_mean( 
tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3]) - gpu_losses += [gpu_loss] - - gpu_loss_gvs += [ nn.gradients ( gpu_loss, self.model.net_weights ) ] - - - # Average losses and gradients, and create optimizer update ops - with tf.device (models_opt_device): - pred = nn.concat(gpu_pred_list, 0) - loss = tf.reduce_mean(gpu_losses) - - loss_gv_op = self.model.opt.get_update_op (nn.average_gv_list (gpu_loss_gvs)) - - - # Initializing training and view functions - def train(input_np, target_np): - l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np }) - return l - self.train = train - - def view(input_np): - return nn.tf_sess.run ( [pred], feed_dict={self.model.input_t :input_np}) - self.view = view - - # initializing sample generators - training_data_src_path = self.training_data_src_path - training_data_dst_path = self.training_data_dst_path - - cpu_count = min(multiprocessing.cpu_count(), 8) - src_generators_count = cpu_count // 2 - dst_generators_count = cpu_count // 2 - src_generators_count = int(src_generators_count * 1.5) - - src_generator = SampleGeneratorFace(training_data_src_path, random_ct_samples_path=training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), - sample_process_options=SampleProcessor.Options(random_flip=True), - output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'ct_mode':'lct', 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'random_motion_blur':(25, 5), 'random_gaussian_blur':(25,5), 'data_format':nn.data_format, 'resolution': resolution}, - {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, - ], - generators_count=src_generators_count ) - - dst_generator = SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), - sample_process_options=SampleProcessor.Options(random_flip=True), - output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, - ], - generators_count=dst_generators_count, - raise_on_no_data=False ) - if not dst_generator.is_initialized(): - io.log_info(f"\nTo view the model on unseen faces, place any aligned faces in {training_data_dst_path}.\n") - - self.set_training_data_generators ([src_generator, dst_generator]) - - #override - def get_model_filename_list(self): - return self.model.model_filename_list - - #override - def onSave(self): - self.model.save_weights() - - #override - def onTrainOneIter(self): - source_np, target_np = self.generate_next_samples()[0] - loss = self.train (source_np, target_np) - - return ( ('loss', loss ), ) - - #override - def onGetPreview(self, samples): - n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) - - src_samples, dst_samples = samples - source_np, target_np = src_samples - - S, TM, SM, = [ np.clip(x, 0.0, 1.0) for x in ([source_np,target_np] + self.view (source_np) ) ] - TM, SM, = [ np.repeat (x, (3,), -1) for x in [TM, SM] ] - - green_bg = np.tile( np.array([0,1,0], dtype=np.float32)[None,None,...], 
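# Note: an illustrative NumPy reference for the per-pixel loss the removed
# FANSeg trainer used (numerically stable sigmoid cross-entropy on logits,
# averaged over H, W and C). This is a reference formula only, not the TF
# graph built above.
import numpy as np

def sigmoid_cross_entropy_with_logits(labels, logits):
    # stable form: max(x, 0) - x*z + log(1 + exp(-|x|))
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

labels = np.random.randint(0, 2, size=(2, 256, 256, 1)).astype(np.float32)
logits = np.random.randn(2, 256, 256, 1).astype(np.float32)
per_sample_loss = sigmoid_cross_entropy_with_logits(labels, logits).mean(axis=(1, 2, 3))
print(per_sample_loss.shape)   # (2,) -- one loss value per sample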
(self.resolution,self.resolution,1) ) - - result = [] - st = [] - for i in range(n_samples): - ar = S[i]*TM[i] + 0.5*S[i]*(1-TM[i]) + 0.5*green_bg*(1-TM[i]), SM[i], S[i]*SM[i] + green_bg*(1-SM[i]) - st.append ( np.concatenate ( ar, axis=1) ) - result += [ ('FANSeg training faces', np.concatenate (st, axis=0 )), ] - - if len(dst_samples) != 0: - dst_np, = dst_samples - - D, DM, = [ np.clip(x, 0.0, 1.0) for x in ([dst_np] + self.view (dst_np) ) ] - DM, = [ np.repeat (x, (3,), -1) for x in [DM] ] - - st = [] - for i in range(n_samples): - ar = D[i], DM[i], D[i]*DM[i]+ green_bg*(1-DM[i]) - st.append ( np.concatenate ( ar, axis=1) ) - - result += [ ('FANSeg unseen faces', np.concatenate (st, axis=0 )), ] - - return result - -Model = FANSegModel diff --git a/models/Model_Quick96/Model.py b/models/Model_Quick96/Model.py index eccdbc7..fa9e215 100644 --- a/models/Model_Quick96/Model.py +++ b/models/Model_Quick96/Model.py @@ -22,15 +22,16 @@ class QModel(ModelBase): resolution = self.resolution = 96 self.face_type = FaceType.FULL ae_dims = 128 - e_dims = 128 + e_dims = 64 d_dims = 64 + d_mask_dims = 16 self.pretrain = False self.pretrain_just_disabled = False masked_training = True models_opt_on_gpu = len(devices) >= 1 and all([dev.total_mem_gb >= 4 for dev in devices]) - models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0' + models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0' optimizer_vars_on_cpu = models_opt_device=='/CPU:0' input_ch = 3 @@ -39,7 +40,7 @@ class QModel(ModelBase): self.model_filename_list = [] - model_archi = nn.DeepFakeArchi(resolution, mod='quick') + model_archi = nn.DeepFakeArchi(resolution, opts='ud') with tf.device ('/CPU:0'): #Place holders on CPU @@ -55,13 +56,13 @@ class QModel(ModelBase): # Initializing model classes with tf.device (models_opt_device): self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder') - encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape)) + encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2 - self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, d_ch=d_dims, name='inter') - inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch))) + self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter') + inter_out_ch = self.inter.get_out_ch() - self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_src') - self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_dst') + self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src') + self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst') self.model_filename_list += [ [self.encoder, 'encoder.npy' ], [self.inter, 'inter.npy' ], @@ -95,7 +96,7 @@ class QModel(ModelBase): gpu_src_dst_loss_gvs = [] for gpu_id in range(gpu_count): - with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): + with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) with tf.device(f'/CPU:0'): # slice on CPU, otherwise all batch data will be transfered to GPU first @@ -189,7 +190,7 @@ class QModel(ModelBase): self.AE_view = AE_view else: # Initializing merge function - with 
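# Note: small NumPy sketch of the green-background compositing used by the
# removed FANSeg previews in this hunk: the face is kept where the mask is 1
# and blended toward solid green where it is 0. Shapes are illustrative.
import numpy as np

res = 256
face = np.random.rand(res, res, 3).astype(np.float32)
mask = np.random.rand(res, res, 1).astype(np.float32)
green_bg = np.tile(np.array([0, 1, 0], dtype=np.float32)[None, None, :], (res, res, 1))

hard_composite = face * mask + green_bg * (1.0 - mask)
dimmed = face * mask + 0.5 * face * (1 - mask) + 0.5 * green_bg * (1 - mask)  # half face, half green outside the mask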
tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'): + with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'): gpu_dst_code = self.inter(self.encoder(self.warped_dst)) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) @@ -277,7 +278,7 @@ class QModel(ModelBase): return ( ('src_loss', src_loss), ('dst_loss', dst_loss), ) #override - def onGetPreview(self, samples): + def onGetPreview(self, samples, for_history=False): ( (warped_src, target_src, target_srcm), (warped_dst, target_dst, target_dstm) ) = samples @@ -308,8 +309,7 @@ class QModel(ModelBase): face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC") bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x, "NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face) ] - mask = mask_dst_dstm[0] * mask_src_dstm[0] - return bgr[0], mask[...,0] + return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0] #override def get_MergerConfig(self): diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index b3dde30..ecfaa73 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -27,20 +27,33 @@ class SAEHDModel(ModelBase): suggest_batch_size = 4 yn_str = {True:'y',False:'n'} + min_res = 64 + max_res = 640 + #default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False) default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 128) default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f') default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True) - default_archi = self.options['archi'] = self.load_or_def_option('archi', 'df') + + default_archi = self.options['archi'] = self.load_or_def_option('archi', 'liae-ud') + default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256) default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64) default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None) default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None) default_masked_training = self.options['masked_training'] = self.load_or_def_option('masked_training', True) - default_eyes_prio = self.options['eyes_prio'] = self.load_or_def_option('eyes_prio', False) - default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', False) + default_eyes_mouth_prio = self.options['eyes_mouth_prio'] = self.load_or_def_option('eyes_mouth_prio', False) + default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False) + default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False) + + default_adabelief = self.options['adabelief'] = self.load_or_def_option('adabelief', True) + + lr_dropout = self.load_or_def_option('lr_dropout', 'n') + lr_dropout = {True:'y', False:'n'}.get(lr_dropout, lr_dropout) #backward comp + default_lr_dropout = self.options['lr_dropout'] = lr_dropout + default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True) - default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0) + default_random_hsv_power = self.options['random_hsv_power'] = self.load_or_def_option('random_hsv_power', 0.0) default_true_face_power = self.options['true_face_power'] = 
self.load_or_def_option('true_face_power', 0.0) default_face_style_power = self.options['face_style_power'] = self.load_or_def_option('face_style_power', 0.0) default_bg_style_power = self.options['bg_style_power'] = self.load_or_def_option('bg_style_power', 0.0) @@ -53,18 +66,55 @@ class SAEHDModel(ModelBase): self.ask_autobackup_hour() self.ask_write_preview_history() self.ask_target_iter() - self.ask_random_flip() + self.ask_random_src_flip() + self.ask_random_dst_flip() self.ask_batch_size(suggest_batch_size) + #self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters.') if self.is_first_run(): - resolution = io.input_int("Resolution", default_resolution, add_info="64-512", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.") - resolution = np.clip ( (resolution // 16) * 16, 64, 512) + resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16 and 32 for -d archi.") + resolution = np.clip ( (resolution // 16) * 16, min_res, max_res) self.options['resolution'] = resolution - self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf'], help_message="Half / mid face / full face / whole face. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face. 'Whole face' covers full area of face include forehead, but requires manual merge in Adobe After Effects.").lower() - self.options['archi'] = io.input_str ("AE architecture", default_archi, ['df','liae','dfhd','liaehd'], help_message="'df' keeps faces more natural.\n'liae' can fix overly different face shapes.\n'hd' are experimental versions.").lower() - default_d_dims = 48 if self.options['archi'] == 'dfhd' else 64 - default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', default_d_dims) + + + self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face. 'Whole face' covers full area of face include forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower() + + while True: + archi = io.input_str ("AE architecture", default_archi, help_message=\ +""" +'df' keeps more identity-preserved face. +'liae' can fix overly different face shapes. +'-u' increased likeness of the face. +'-d' (experimental) doubling the resolution using the same computation cost. +Examples: df, liae, df-d, df-ud, liae-ud, ... 
+""").lower() + + archi_split = archi.split('-') + + if len(archi_split) == 2: + archi_type, archi_opts = archi_split + elif len(archi_split) == 1: + archi_type, archi_opts = archi_split[0], None + else: + continue + + if archi_type not in ['df', 'liae']: + continue + + if archi_opts is not None: + if len(archi_opts) == 0: + continue + if len([ 1 for opt in archi_opts if opt not in ['u','d','t','c'] ]) != 0: + continue + + if 'd' in archi_opts: + self.options['resolution'] = np.clip ( (self.options['resolution'] // 32) * 32, min_res, max_res) + + break + self.options['archi'] = archi + + default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64) default_d_mask_dims = default_d_dims // 3 default_d_mask_dims += default_d_mask_dims % 2 @@ -76,7 +126,6 @@ class SAEHDModel(ModelBase): e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) self.options['e_dims'] = e_dims + e_dims % 2 - d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) self.options['d_dims'] = d_dims + d_dims % 2 @@ -84,36 +133,55 @@ class SAEHDModel(ModelBase): self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2 if self.is_first_run() or ask_override: - if self.options['face_type'] == 'wf': - self.options['masked_training'] = io.input_bool ("Masked training", default_masked_training, help_message="This option is available only for 'whole_face' type. Masked training clips training area to full_face mask, thus network will train the faces properly. When the face is trained enough, disable this option to train all area of the frame. Merge with 'raw-rgb' mode, then use Adobe After Effects to manually mask and compose whole face include forehead.") - - self.options['eyes_prio'] = io.input_bool ("Eyes priority", default_eyes_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction ( especially on HD architectures ) by forcing the neural network to train eyes with higher priority. before/after https://i.imgur.com/YQHOuSR.jpg ') - + if self.options['face_type'] == 'wf' or self.options['face_type'] == 'head': + self.options['masked_training'] = io.input_bool ("Masked training", default_masked_training, help_message="This option is available only for 'whole_face' or 'head' type. Masked training clips training area to full_face mask or XSeg mask, thus network will train the faces properly.") + + self.options['eyes_mouth_prio'] = io.input_bool ("Eyes and mouth priority", default_eyes_mouth_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction. Also makes the detail of the teeth higher.') + self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.') + self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. 
The exact xseg mask in src and dst faceset is required.') + + default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0) + default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8) + default_gan_dims = self.options['gan_dims'] = self.load_or_def_option('gan_dims', 16) + if self.is_first_run() or ask_override: self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.") - self.options['lr_dropout'] = io.input_bool ("Use learning rate dropout", default_lr_dropout, help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations.") + self.options['adabelief'] = io.input_bool ("Use AdaBelief optimizer?", default_adabelief, help_message="Use AdaBelief optimizer. It requires more VRAM, but the accuracy and the generalization of the model is higher.") + + self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.") + self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.") - self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 10.0", help_message="Train the network in Generative Adversarial manner. Accelerates the speed of training. Forces the neural network to learn small details of the face. You can enable/disable this option at any time. Typical value is 1.0"), 0.0, 10.0 ) + self.options['random_hsv_power'] = np.clip ( io.input_number ("Random hue/saturation/light intensity", default_random_hsv_power, add_info="0.0 .. 0.3", help_message="Random hue/saturation/light intensity applied to the src face set only at the input of the neural network. Stabilizes color perturbations during face swapping. Reduces the quality of the color transfer by selecting the closest one in the src faceset. Thus the src faceset must be diverse enough. Typical fine value is 0.05"), 0.0, 0.3 ) + + self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 5.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with lr_dropout(on) and random_warp(off), and don't disable. The higher the value, the higher the chances of artifacts. Typical fine value is 0.1"), 0.0, 5.0 ) + + if self.options['gan_power'] != 0.0: + gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher patch size, the higher the quality, the more VRAM is required. You can get sharper edges even at the lowest setting. 
Typical fine value is resolution / 8." ), 3, 640 ) + self.options['gan_patch_size'] = gan_patch_size + + gan_dims = np.clip ( io.input_int("GAN dimensions", default_gan_dims, add_info="4-512", help_message="The dimensions of the GAN network. The higher dimensions, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is 16." ), 4, 512 ) + self.options['gan_dims'] = gan_dims if 'df' in self.options['archi']: self.options['true_face_power'] = np.clip ( io.input_number ("'True face' power.", default_true_face_power, add_info="0.0000 .. 1.0", help_message="Experimental option. Discriminates result face to be more like src face. Higher value - stronger discrimination. Typical value is 0.01 . Comparison - https://i.imgur.com/czScS9q.png"), 0.0, 1.0 ) else: self.options['true_face_power'] = 0.0 - if self.options['face_type'] != 'wf': - self.options['face_style_power'] = np.clip ( io.input_number("Face style power", default_face_style_power, add_info="0.0..100.0", help_message="Learn to transfer face style details such as light and color conditions. Warning: Enable it only after 10k iters, when predicted face is clear enough to start learn style. Start from 0.001 value and check history changes. Enabling this option increases the chance of model collapse."), 0.0, 100.0 ) - self.options['bg_style_power'] = np.clip ( io.input_number("Background style power", default_bg_style_power, add_info="0.0..100.0", help_message="Learn to transfer background around face. This can make face more like dst. Enabling this option increases the chance of model collapse. Typical value is 2.0"), 0.0, 100.0 ) - + self.options['face_style_power'] = np.clip ( io.input_number("Face style power", default_face_style_power, add_info="0.0..100.0", help_message="Learn the color of the predicted face to be the same as dst inside mask. If you want to use this option with 'whole_face' you have to use XSeg trained mask. Warning: Enable it only after 10k iters, when predicted face is clear enough to start learn style. Start from 0.001 value and check history changes. Enabling this option increases the chance of model collapse."), 0.0, 100.0 ) + self.options['bg_style_power'] = np.clip ( io.input_number("Background style power", default_bg_style_power, add_info="0.0..100.0", help_message="Learn the area outside mask of the predicted face to be the same as dst. If you want to use this option with 'whole_face' you have to use XSeg trained mask. For whole_face you have to use XSeg trained mask. This can make face more like dst. Enabling this option increases the chance of model collapse. Typical value is 2.0"), 0.0, 100.0 ) + self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best.") self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.") - - self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly.") + + self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly. 
Forces random_warp=N, random_flips=Y, gan_power=0.0, lr_dropout=N, styles=0.0, uniform_yaw=Y") if self.options['pretrain'] and self.get_pretraining_data_path() is None: raise Exception("pretraining_data_path is not defined") + self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims']) + self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False) #override @@ -128,11 +196,23 @@ class SAEHDModel(ModelBase): self.face_type = {'h' : FaceType.HALF, 'mf' : FaceType.MID_FULL, 'f' : FaceType.FULL, - 'wf' : FaceType.WHOLE_FACE}[ self.options['face_type'] ] + 'wf' : FaceType.WHOLE_FACE, + 'head' : FaceType.HEAD}[ self.options['face_type'] ] + + if 'eyes_prio' in self.options: + self.options.pop('eyes_prio') + + eyes_mouth_prio = self.options['eyes_mouth_prio'] + + archi_split = self.options['archi'].split('-') + + if len(archi_split) == 2: + archi_type, archi_opts = archi_split + elif len(archi_split) == 1: + archi_type, archi_opts = archi_split[0], None + + self.archi_type = archi_type - eyes_prio = self.options['eyes_prio'] - archi = self.options['archi'] - is_hd = 'hd' in archi ae_dims = self.options['ae_dims'] e_dims = self.options['e_dims'] d_dims = self.options['d_dims'] @@ -141,46 +221,69 @@ class SAEHDModel(ModelBase): if self.pretrain_just_disabled: self.set_iter(0) - self.gan_power = gan_power = self.options['gan_power'] if not self.pretrain else 0.0 + adabelief = self.options['adabelief'] + + use_fp16 = False + if self.is_exporting: + use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.') + + self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power'] + random_warp = False if self.pretrain else self.options['random_warp'] + random_src_flip = self.random_src_flip if not self.pretrain else True + random_dst_flip = self.random_dst_flip if not self.pretrain else True + random_hsv_power = self.options['random_hsv_power'] if not self.pretrain else 0.0 + blur_out_mask = self.options['blur_out_mask'] + + if self.pretrain: + self.options_show_override['lr_dropout'] = 'n' + self.options_show_override['random_warp'] = False + self.options_show_override['gan_power'] = 0.0 + self.options_show_override['random_hsv_power'] = 0.0 + self.options_show_override['face_style_power'] = 0.0 + self.options_show_override['bg_style_power'] = 0.0 + self.options_show_override['uniform_yaw'] = True masked_training = self.options['masked_training'] ct_mode = self.options['ct_mode'] if ct_mode == 'none': ct_mode = None + models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu'] - models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0' + models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0' optimizer_vars_on_cpu = models_opt_device=='/CPU:0' input_ch=3 - bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) + bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) mask_shape = nn.get4Dshape(resolution,resolution,1) self.model_filename_list = [] with tf.device ('/CPU:0'): #Place holders on CPU - self.warped_src = tf.placeholder (nn.floatx, bgr_shape) - self.warped_dst = tf.placeholder (nn.floatx, bgr_shape) + self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src') + self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst') - self.target_src = 
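# Note: a small sketch of the backward-compatibility handling in this hunk and
# in the options block above: an old saved 'eyes_prio' flag is dropped (it is
# superseded by 'eyes_mouth_prio'), and an old boolean 'lr_dropout' is mapped
# onto the new 'n'/'y'/'cpu' string values. 'options' stands in for the model's
# persisted option dict.
options = {'eyes_prio': True, 'lr_dropout': True}

options.pop('eyes_prio', None)
lr_dropout = options.get('lr_dropout', 'n')
options['lr_dropout'] = {True: 'y', False: 'n'}.get(lr_dropout, lr_dropout)
print(options)   # {'lr_dropout': 'y'}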
tf.placeholder (nn.floatx, bgr_shape) - self.target_dst = tf.placeholder (nn.floatx, bgr_shape) + self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src') + self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst') + + self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm') + self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em') + self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm') + self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em') - self.target_srcm_all = tf.placeholder (nn.floatx, mask_shape) - self.target_dstm_all = tf.placeholder (nn.floatx, mask_shape) - # Initializing model classes - model_archi = nn.DeepFakeArchi(resolution) - + model_archi = nn.DeepFakeArchi(resolution, use_fp16=use_fp16, opts=archi_opts) + with tf.device (models_opt_device): - if 'df' in archi: - self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder') - encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape)) + if 'df' in archi_type: + self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder') + encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2 - self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, is_hd=is_hd, name='inter') - inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch))) + self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter') + inter_out_ch = self.inter.get_out_ch() - self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_src') - self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_dst') + self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src') + self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst') self.model_filename_list += [ [self.encoder, 'encoder.npy' ], [self.inter, 'inter.npy' ], @@ -189,20 +292,19 @@ class SAEHDModel(ModelBase): if self.is_training: if self.options['true_face_power'] != 0: - self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=model_archi.Inter.get_code_res()*2, name='dis' ) + self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=self.inter.get_out_res(), name='dis' ) self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ] - elif 'liae' in archi: - self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder') - encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape)) + elif 'liae' in archi_type: + self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder') + encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2 - self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_AB') - self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_B') + self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_AB') + self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_B') - inter_AB_out_ch = 
self.inter_AB.compute_output_channels ( (nn.floatx, (None,encoder_out_ch))) - inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.floatx, (None,encoder_out_ch))) - inters_out_ch = inter_AB_out_ch+inter_B_out_ch - self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder') + inter_out_ch = self.inter_AB.get_out_ch() + inters_out_ch = inter_out_ch*2 + self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder') self.model_filename_list += [ [self.encoder, 'encoder.npy'], [self.inter_AB, 'inter_AB.npy'], @@ -211,33 +313,43 @@ class SAEHDModel(ModelBase): if self.is_training: if gan_power != 0: - self.D_src = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_src") - self.D_dst = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_dst") - self.model_filename_list += [ [self.D_src, 'D_src.npy'] ] - self.model_filename_list += [ [self.D_dst, 'D_dst.npy'] ] + self.D_src = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="D_src") + self.model_filename_list += [ [self.D_src, 'GAN.npy'] ] # Initialize optimizers lr=5e-5 - lr_dropout = 0.3 if self.options['lr_dropout'] and not self.pretrain else 1.0 + if self.options['lr_dropout'] in ['y','cpu'] and not self.pretrain: + lr_cos = 500 + lr_dropout = 0.3 + else: + lr_cos = 0 + lr_dropout = 1.0 + OptimizerClass = nn.AdaBelief if adabelief else nn.RMSprop clipnorm = 1.0 if self.options['clipgrad'] else 0.0 - self.src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt') - self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] - if 'df' in archi: - self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights() - elif 'liae' in archi: - self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights() - self.src_dst_opt.initialize_variables (self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu) + if 'df' in archi_type: + self.src_dst_saveable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights() + self.src_dst_trainable_weights = self.src_dst_saveable_weights + elif 'liae' in archi_type: + self.src_dst_saveable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights() + if random_warp: + self.src_dst_trainable_weights = self.src_dst_saveable_weights + else: + self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights() + + self.src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='src_dst_opt') + self.src_dst_opt.initialize_variables (self.src_dst_saveable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu') + self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] if self.options['true_face_power'] != 0: - self.D_code_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_code_opt') - self.D_code_opt.initialize_variables ( self.code_discriminator.get_weights(), vars_on_cpu=optimizer_vars_on_cpu) + self.D_code_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, lr_cos=lr_cos, 
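# Note: sketch of the new trainable/saveable weight split for the 'liae'
# architecture: every weight is still saved, but when random warp is disabled
# inter_AB is frozen (left out of the gradient update) while encoder, inter_B
# and the decoder keep training. The list contents are placeholders for the
# real layer weight collections.
def liae_weight_split(encoder_w, inter_AB_w, inter_B_w, decoder_w, random_warp):
    saveable = encoder_w + inter_AB_w + inter_B_w + decoder_w
    trainable = saveable if random_warp else encoder_w + inter_B_w + decoder_w
    return saveable, trainable

saveable, trainable = liae_weight_split(['enc'], ['inter_AB'], ['inter_B'], ['dec'], random_warp=False)
print(trainable)   # ['enc', 'inter_B', 'dec'] -- inter_AB excluded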
clipnorm=clipnorm, name='D_code_opt') + self.D_code_opt.initialize_variables ( self.code_discriminator.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu') self.model_filename_list += [ (self.D_code_opt, 'D_code_opt.npy') ] if gan_power != 0: - self.D_src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_src_dst_opt') - self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights()+self.D_dst.get_weights(), vars_on_cpu=optimizer_vars_on_cpu) - self.model_filename_list += [ (self.D_src_dst_opt, 'D_src_dst_opt.npy') ] + self.D_src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='GAN_opt') + self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')#+self.D_src_x2.get_weights() + self.model_filename_list += [ (self.D_src_dst_opt, 'GAN_opt.npy') ] if self.is_training: # Adjust batch size for multiple GPU @@ -245,7 +357,6 @@ class SAEHDModel(ModelBase): bs_per_gpu = max(1, self.get_batch_size() // gpu_count) self.set_batch_size( gpu_count*bs_per_gpu) - # Compute losses per GPU gpu_pred_src_src_list = [] gpu_pred_dst_dst_list = [] @@ -259,9 +370,9 @@ class SAEHDModel(ModelBase): gpu_G_loss_gvs = [] gpu_D_code_loss_gvs = [] gpu_D_src_dst_loss_gvs = [] - for gpu_id in range(gpu_count): - with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): + for gpu_id in range(gpu_count): + with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): with tf.device(f'/CPU:0'): # slice on CPU, otherwise all batch data will be transfered to GPU first batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) @@ -269,18 +380,38 @@ class SAEHDModel(ModelBase): gpu_warped_dst = self.warped_dst [batch_slice,:,:,:] gpu_target_src = self.target_src [batch_slice,:,:,:] gpu_target_dst = self.target_dst [batch_slice,:,:,:] - gpu_target_srcm_all = self.target_srcm_all[batch_slice,:,:,:] - gpu_target_dstm_all = self.target_dstm_all[batch_slice,:,:,:] - + gpu_target_srcm = self.target_srcm[batch_slice,:,:,:] + gpu_target_srcm_em = self.target_srcm_em[batch_slice,:,:,:] + gpu_target_dstm = self.target_dstm[batch_slice,:,:,:] + gpu_target_dstm_em = self.target_dstm_em[batch_slice,:,:,:] + + gpu_target_srcm_anti = 1-gpu_target_srcm + gpu_target_dstm_anti = 1-gpu_target_dstm + + if blur_out_mask: + sigma = resolution / 128 + + x = nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, sigma) + y = 1-nn.gaussian_blur(gpu_target_srcm, sigma) + y = tf.where(tf.equal(y, 0), tf.ones_like(y), y) + gpu_target_src = gpu_target_src*gpu_target_srcm + (x/y)*gpu_target_srcm_anti + + x = nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, sigma) + y = 1-nn.gaussian_blur(gpu_target_dstm, sigma) + y = tf.where(tf.equal(y, 0), tf.ones_like(y), y) + gpu_target_dst = gpu_target_dst*gpu_target_dstm + (x/y)*gpu_target_dstm_anti + + # process model tensors - if 'df' in archi: + if 'df' in archi_type: gpu_src_code = self.inter(self.encoder(gpu_warped_src)) gpu_dst_code = self.inter(self.encoder(gpu_warped_dst)) gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) + gpu_pred_src_dst_no_code_grad, _ = self.decoder_src(tf.stop_gradient(gpu_dst_code)) - elif 'liae' in archi: + elif 'liae' in archi_type: gpu_src_code = 
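# Note: NumPy/SciPy sketch of the new blur_out_mask preprocessing in this hunk.
# The area outside the face mask is replaced by a blurred version of itself,
# normalised by the blurred mask complement (the x/y division) so colours that
# were attenuated by blurring across the mask border keep their brightness.
# nn.gaussian_blur in the real graph is a TF op; gaussian_filter is only a
# stand-in with a comparable effect.
import numpy as np
from scipy.ndimage import gaussian_filter

def blur_out_mask(img, mask, sigma):
    anti = 1.0 - mask
    x = gaussian_filter(img * anti, sigma=(sigma, sigma, 0))
    y = 1.0 - gaussian_filter(mask, sigma=(sigma, sigma, 0))
    y = np.where(y == 0, 1.0, y)                 # avoid division by zero inside the face
    return img * mask + (x / y) * anti

img  = np.random.rand(256, 256, 3).astype(np.float32)
mask = np.zeros((256, 256, 1), dtype=np.float32)
mask[64:192, 64:192] = 1.0
out = blur_out_mask(img, mask, sigma=256 / 128)   # sigma = resolution / 128, as above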
self.encoder (gpu_warped_src) gpu_src_inter_AB_code = self.inter_AB (gpu_src_code) gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis ) @@ -293,6 +424,7 @@ class SAEHDModel(ModelBase): gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + gpu_pred_src_dst_no_code_grad, _ = self.decoder(tf.stop_gradient(gpu_src_dst_code)) gpu_pred_src_src_list.append(gpu_pred_src_src) gpu_pred_dst_dst_list.append(gpu_pred_dst_dst) @@ -301,51 +433,62 @@ class SAEHDModel(ModelBase): gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) - - # unpack masks from one combined mask - gpu_target_srcm = tf.clip_by_value (gpu_target_srcm_all, 0, 1) - gpu_target_dstm = tf.clip_by_value (gpu_target_dstm_all, 0, 1) - gpu_target_srcm_eyes = tf.clip_by_value (gpu_target_srcm_all-1, 0, 1) - gpu_target_dstm_eyes = tf.clip_by_value (gpu_target_dstm_all-1, 0, 1) gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) ) - gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) ) + gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2 + gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur - gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur - gpu_target_dst_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_blur) + gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) ) + gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2 + + gpu_style_mask_blur = nn.gaussian_blur(gpu_pred_src_dstm*gpu_pred_dst_dstm, max(1, resolution // 32) ) + gpu_style_mask_blur = tf.stop_gradient(tf.clip_by_value(gpu_target_srcm_blur, 0, 1.0)) + gpu_style_mask_anti_blur = 1.0 - gpu_style_mask_blur + + gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur + + gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur + gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst - gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst - gpu_psd_target_dst_masked = gpu_pred_src_dst*gpu_target_dstm_blur - gpu_psd_target_dst_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_blur) - - gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + if resolution < 256: + gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + else: + gpu_src_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3]) - - if eyes_prio: 
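# Note: NumPy illustration of the mask softening used for the losses in this
# hunk: the training mask is blurred, clipped at 0.5 and rescaled by 2, so it
# still reaches 1.0 over the whole face interior while only the outside of the
# original edge gets a soft falloff. The blur radius in the real code is
# max(1, resolution // 32); sigma here is illustrative.
import numpy as np
from scipy.ndimage import gaussian_filter

mask = np.zeros((128, 128, 1), dtype=np.float32)
mask[32:96, 32:96] = 1.0

blurred = gaussian_filter(mask, sigma=(4, 4, 0))
soft = np.clip(blurred, 0, 0.5) * 2.0     # 1.0 inside the face, soft ramp outside
anti = 1.0 - soft                         # used for the background-only penalties
print(soft.max(), soft.min())             # 1.0  0.0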
- gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_eyes - gpu_pred_src_src*gpu_target_srcm_eyes ), axis=[1,2,3]) - + + if eyes_mouth_prio: + gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_em - gpu_pred_src_src*gpu_target_srcm_em ), axis=[1,2,3]) + gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] ) face_style_power = self.options['face_style_power'] / 100.0 if face_style_power != 0 and not self.pretrain: - gpu_src_loss += nn.style_loss(gpu_psd_target_dst_masked, gpu_target_dst_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power) + gpu_src_loss += nn.style_loss(gpu_pred_src_dst_no_code_grad*tf.stop_gradient(gpu_pred_src_dstm), tf.stop_gradient(gpu_pred_dst_dst*gpu_pred_dst_dstm), gaussian_blur_radius=resolution//8, loss_weight=10000*face_style_power) bg_style_power = self.options['bg_style_power'] / 100.0 if bg_style_power != 0 and not self.pretrain: - gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim(gpu_psd_target_dst_anti_masked, gpu_target_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) - gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square( gpu_psd_target_dst_anti_masked - gpu_target_dst_anti_masked), axis=[1,2,3] ) + gpu_target_dst_style_anti_masked = gpu_target_dst*gpu_style_mask_anti_blur + gpu_psd_style_anti_masked = gpu_pred_src_dst*gpu_style_mask_anti_blur - gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim( gpu_psd_style_anti_masked, gpu_target_dst_style_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square(gpu_psd_style_anti_masked - gpu_target_dst_style_anti_masked), axis=[1,2,3] ) + + if resolution < 256: + gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + else: + gpu_dst_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + gpu_dst_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1]) gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3]) - - if eyes_prio: - gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_eyes - gpu_pred_dst_dst*gpu_target_dstm_eyes ), axis=[1,2,3]) - + + if eyes_mouth_prio: + gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_em - gpu_pred_dst_dst*gpu_target_dstm_em ), axis=[1,2,3]) + gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) gpu_src_losses += [gpu_src_loss] @@ -365,38 +508,49 @@ class SAEHDModel(ModelBase): gpu_G_loss += self.options['true_face_power']*DLoss(gpu_src_code_d_ones, gpu_src_code_d) - gpu_D_code_loss = (DLoss(gpu_src_code_d_ones , gpu_dst_code_d) + \ + gpu_D_code_loss = (DLoss(gpu_dst_code_d_ones , gpu_dst_code_d) + \ DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5 gpu_D_code_loss_gvs += [ nn.gradients (gpu_D_code_loss, self.code_discriminator.get_weights() ) ] if gan_power != 0: - gpu_pred_src_src_d = self.D_src(gpu_pred_src_src_masked_opt) 
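# Note: NumPy form of the new eyes/mouth priority term added to both the src
# and dst losses above: an extra L1 penalty, weighted by 300 and restricted to
# the eyes+mouth mask, so those regions are trained with higher priority.
# Shapes are illustrative.
import numpy as np

target = np.random.rand(4, 128, 128, 3).astype(np.float32)
pred   = np.random.rand(4, 128, 128, 3).astype(np.float32)
m_em   = np.random.rand(4, 128, 128, 1).astype(np.float32)   # eyes+mouth mask

extra_loss = np.mean(300.0 * np.abs(target * m_em - pred * m_em), axis=(1, 2, 3))
print(extra_loss.shape)   # (4,) -- one value per sample in the batch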
+ gpu_pred_src_src_d, \ + gpu_pred_src_src_d2 = self.D_src(gpu_pred_src_src_masked_opt) + gpu_pred_src_src_d_ones = tf.ones_like (gpu_pred_src_src_d) gpu_pred_src_src_d_zeros = tf.zeros_like(gpu_pred_src_src_d) - gpu_target_src_d = self.D_src(gpu_target_src_masked_opt) + + gpu_pred_src_src_d2_ones = tf.ones_like (gpu_pred_src_src_d2) + gpu_pred_src_src_d2_zeros = tf.zeros_like(gpu_pred_src_src_d2) + + gpu_target_src_d, \ + gpu_target_src_d2 = self.D_src(gpu_target_src_masked_opt) + gpu_target_src_d_ones = tf.ones_like(gpu_target_src_d) - gpu_pred_dst_dst_d = self.D_dst(gpu_pred_dst_dst_masked_opt) - gpu_pred_dst_dst_d_ones = tf.ones_like (gpu_pred_dst_dst_d) - gpu_pred_dst_dst_d_zeros = tf.zeros_like(gpu_pred_dst_dst_d) - gpu_target_dst_d = self.D_dst(gpu_target_dst_masked_opt) - gpu_target_dst_d_ones = tf.ones_like(gpu_target_dst_d) + gpu_target_src_d2_ones = tf.ones_like(gpu_target_src_d2) - gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones , gpu_target_src_d) + \ - DLoss(gpu_pred_src_src_d_zeros, gpu_pred_src_src_d) ) * 0.5 + \ - (DLoss(gpu_target_dst_d_ones , gpu_target_dst_d) + \ - DLoss(gpu_pred_dst_dst_d_zeros, gpu_pred_dst_dst_d) ) * 0.5 + gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones , gpu_target_src_d) + \ + DLoss(gpu_pred_src_src_d_zeros , gpu_pred_src_src_d) ) * 0.5 + \ + (DLoss(gpu_target_src_d2_ones , gpu_target_src_d2) + \ + DLoss(gpu_pred_src_src_d2_zeros , gpu_pred_src_src_d2) ) * 0.5 - gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.D_src.get_weights()+self.D_dst.get_weights() ) ] + gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.D_src.get_weights() ) ]#+self.D_src_x2.get_weights() - gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + DLoss(gpu_pred_dst_dst_d_ones, gpu_pred_dst_dst_d)) + gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + \ + DLoss(gpu_pred_src_src_d2_ones, gpu_pred_src_src_d2)) + + if masked_training: + # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan + gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src) + gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] ) + + gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights )] - gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ] # Average losses and gradients, and create optimizer update ops - with tf.device (models_opt_device): + with tf.device(f'/CPU:0'): pred_src_src = nn.concat(gpu_pred_src_src_list, 0) pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0) pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0) @@ -404,6 +558,7 @@ class SAEHDModel(ModelBase): pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0) pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0) + with tf.device (models_opt_device): src_loss = tf.concat(gpu_src_losses, 0) dst_loss = tf.concat(gpu_dst_losses, 0) src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs)) @@ -416,16 +571,18 @@ class SAEHDModel(ModelBase): # Initializing training and view functions - def src_dst_train(warped_src, target_src, target_srcm_all, \ - warped_dst, target_dst, target_dstm_all): - s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op], + def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \ + warped_dst, target_dst, target_dstm, target_dstm_em, ): + s, d = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op], 
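# Note: sketch of the two small generator-side regularisers enabled above when
# masked training and GAN are both active. The total-variation definition here
# (mean squared difference between neighbouring pixels, NHWC) is an assumption;
# the real nn.total_variation_mse may differ in detail. The second term ties
# the background of the src->src prediction back to the real background, which
# is what suppresses the random bright dots the GAN can otherwise produce.
import numpy as np

def total_variation_mse(img):   # img: NHWC
    dy = img[:, 1:, :, :] - img[:, :-1, :, :]
    dx = img[:, :, 1:, :] - img[:, :, :-1, :]
    return (dy ** 2).mean(axis=(1, 2, 3)) + (dx ** 2).mean(axis=(1, 2, 3))

pred      = np.random.rand(2, 128, 128, 3).astype(np.float32)
target    = np.random.rand(2, 128, 128, 3).astype(np.float32)
anti_mask = np.random.rand(2, 128, 128, 1).astype(np.float32)   # 1 - blurred face mask

g_loss_extra = (0.000001 * total_variation_mse(pred)
                + 0.02 * np.mean(np.square(pred * anti_mask - target * anti_mask), axis=(1, 2, 3)))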
feed_dict={self.warped_src :warped_src, self.target_src :target_src, - self.target_srcm_all:target_srcm_all, + self.target_srcm:target_srcm, + self.target_srcm_em:target_srcm_em, self.warped_dst :warped_dst, self.target_dst :target_dst, - self.target_dstm_all:target_dstm_all, - }) + self.target_dstm:target_dstm, + self.target_dstm_em:target_dstm_em, + })[:2] return s, d self.src_dst_train = src_dst_train @@ -435,17 +592,19 @@ class SAEHDModel(ModelBase): self.D_train = D_train if gan_power != 0: - def D_src_dst_train(warped_src, target_src, target_srcm_all, \ - warped_dst, target_dst, target_dstm_all): + def D_src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \ + warped_dst, target_dst, target_dstm, target_dstm_em, ): nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src, self.target_src :target_src, - self.target_srcm_all:target_srcm_all, + self.target_srcm:target_srcm, + self.target_srcm_em:target_srcm_em, self.warped_dst :warped_dst, self.target_dst :target_dst, - self.target_dstm_all:target_dstm_all}) + self.target_dstm:target_dstm, + self.target_dstm_em:target_dstm_em}) self.D_src_dst_train = D_src_dst_train - + def AE_view(warped_src, warped_dst): return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm], feed_dict={self.warped_src:warped_src, @@ -453,13 +612,13 @@ class SAEHDModel(ModelBase): self.AE_view = AE_view else: # Initializing merge function - with tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'): - if 'df' in archi: + with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'): + if 'df' in archi_type: gpu_dst_code = self.inter(self.encoder(self.warped_dst)) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) - elif 'liae' in archi: + elif 'liae' in archi_type: gpu_dst_code = self.encoder (self.warped_dst) gpu_dst_inter_B_code = self.inter_B (gpu_dst_code) gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code) @@ -469,7 +628,7 @@ class SAEHDModel(ModelBase): gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) - + def AE_merge( warped_dst): return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst}) @@ -479,14 +638,17 @@ class SAEHDModel(ModelBase): for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): if self.pretrain_just_disabled: do_init = False - if 'df' in archi: + if 'df' in archi_type: if model == self.inter: do_init = True - elif 'liae' in archi: - if model == self.inter_AB: + elif 'liae' in archi_type: + if model == self.inter_AB or model == self.inter_B: do_init = True else: do_init = self.is_first_run() + if self.is_training and gan_power != 0 and model == self.D_src: + if self.gan_model_changed: + do_init = True if not do_init: do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) ) @@ -494,6 +656,9 @@ class SAEHDModel(ModelBase): if do_init: model.init_weights() + + ############### + # initializing sample generators if self.is_training: training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path() @@ -501,7 +666,7 @@ class SAEHDModel(ModelBase): random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None - cpu_count = min(multiprocessing.cpu_count(), 8) + cpu_count = multiprocessing.cpu_count() 
src_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2 if ct_mode is not None: @@ -509,28 +674,81 @@ class SAEHDModel(ModelBase): self.set_training_data_generators ([ SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), - sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), - output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, - {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, - {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=random_src_flip), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'random_hsv_shift_amount' : random_hsv_power, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, ], + uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain, generators_count=src_generators_count ), SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), - sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), - output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, - {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, - {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + 
sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=random_dst_flip), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, ], + uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain, generators_count=dst_generators_count ) ]) - - self.last_src_samples_loss = [] - self.last_dst_samples_loss = [] - + if self.pretrain_just_disabled: self.update_sample_for_preview(force_new=True) + def export_dfm (self): + output_path=self.get_strpath_storage_for_file('model.dfm') + + io.log_info(f'Dumping .dfm to {output_path}') + + tf = nn.tf + nn.set_data_format('NCHW') + + with tf.device (nn.tf_default_device_name): + warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face') + warped_dst = tf.transpose(warped_dst, (0,3,1,2)) + + + if 'df' in self.archi_type: + gpu_dst_code = self.inter(self.encoder(warped_dst)) + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) + _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) + + elif 'liae' in self.archi_type: + gpu_dst_code = self.encoder (warped_dst) + gpu_dst_inter_B_code = self.inter_B (gpu_dst_code) + gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code) + gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) + gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) + + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) + + gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1)) + gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1)) + gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1)) + + tf.identity(gpu_pred_dst_dstm, name='out_face_mask') + tf.identity(gpu_pred_src_dst, name='out_celeb_face') + tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask') + + output_graph_def = tf.graph_util.convert_variables_to_constants( + nn.tf_sess, + tf.get_default_graph().as_graph_def(), + ['out_face_mask','out_celeb_face','out_celeb_face_mask'] + ) + + import tf2onnx + with tf.device("/CPU:0"): + model_proto, _ = tf2onnx.convert._convert_common( + output_graph_def, + name='SAEHD', + input_names=['in_face:0'], + output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'], + opset=12, + output_path=output_path) + #override def get_model_filename_list(self): return self.model_filename_list @@ -540,113 +758,100 @@ class SAEHDModel(ModelBase): for model, filename in 
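# Note: a hedged sketch of consuming the .dfm written by export_dfm above with
# onnxruntime. The file is the ONNX graph produced by tf2onnx, and the tensor
# names follow the input_names/output_names passed to the converter
# ('in_face:0', 'out_celeb_face:0', ...); whether the runtime keeps the ':0'
# suffix can vary, so the input name is read back from the session. The face is
# assumed to be NHWC float32 in [0,1] at the model's training resolution
# (224 here is only an example).
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('model.dfm')            # path is illustrative
input_name = sess.get_inputs()[0].name

face = np.random.rand(1, 224, 224, 3).astype(np.float32)
out_face_mask, out_celeb_face, out_celeb_face_mask = sess.run(None, {input_name: face})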
io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False): model.save_weights ( self.get_strpath_storage_for_file(filename) ) + #override + def should_save_preview_history(self): + return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \ + (io.is_colab() and self.iter % 100 == 0) #override def onTrainOneIter(self): - bs = self.get_batch_size() - - ( (warped_src, target_src, target_srcm_all), \ - (warped_dst, target_dst, target_dstm_all) ) = self.generate_next_samples() + if self.get_iter() == 0 and not self.pretrain and not self.pretrain_just_disabled: + io.log_info('You are training the model from scratch. It is strongly recommended to use a pretrained model to speed up the training and improve the quality.\n') - src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm_all, warped_dst, target_dst, target_dstm_all) - - for i in range(bs): - self.last_src_samples_loss.append ( (target_src[i], target_srcm_all[i], src_loss[i] ) ) - self.last_dst_samples_loss.append ( (target_dst[i], target_dstm_all[i], dst_loss[i] ) ) - - if len(self.last_src_samples_loss) >= bs*16: - src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(2), reverse=True) - dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(2), reverse=True) - - target_src = np.stack( [ x[0] for x in src_samples_loss[:bs] ] ) - target_srcm_all = np.stack( [ x[1] for x in src_samples_loss[:bs] ] ) - - target_dst = np.stack( [ x[0] for x in dst_samples_loss[:bs] ] ) - target_dstm_all = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] ) + ( (warped_src, target_src, target_srcm, target_srcm_em), \ + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples() - src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm_all, target_dst, target_dst, target_dstm_all) - self.last_src_samples_loss = [] - self.last_dst_samples_loss = [] + src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) if self.options['true_face_power'] != 0 and not self.pretrain: self.D_train (warped_src, warped_dst) if self.gan_power != 0: - self.D_src_dst_train (warped_src, target_src, target_srcm_all, warped_dst, target_dst, target_dstm_all) + self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) #override - def onGetPreview(self, samples): - ( (warped_src, target_src, target_srcm_all,), - (warped_dst, target_dst, target_dstm_all,) ) = samples + def onGetPreview(self, samples, for_history=False): + ( (warped_src, target_src, target_srcm, target_srcm_em), + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ] DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ] - target_srcm_all, target_dstm_all = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm_all, target_dstm_all] )] - - target_srcm = np.clip(target_srcm_all, 0, 1) - target_dstm = np.clip(target_dstm_all, 0, 1) - + target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )] + n_samples = min(4, self.get_batch_size(), 800 // 
self.resolution ) if self.resolution <= 256: result = [] - + st = [] for i in range(n_samples): ar = S[i], SS[i], D[i], DD[i], SD[i] st.append ( np.concatenate ( ar, axis=1) ) result += [ ('SAEHD', np.concatenate (st, axis=0 )), ] - + st_m = [] for i in range(n_samples): - ar = S[i]*target_srcm[i], SS[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i]) + SD_mask = DDM[i]*SDM[i] if self.face_type < FaceType.HEAD else SDM[i] + + ar = S[i]*target_srcm[i], SS[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*SD_mask st_m.append ( np.concatenate ( ar, axis=1) ) result += [ ('SAEHD masked', np.concatenate (st_m, axis=0 )), ] else: result = [] - + st = [] for i in range(n_samples): ar = S[i], SS[i] st.append ( np.concatenate ( ar, axis=1) ) result += [ ('SAEHD src-src', np.concatenate (st, axis=0 )), ] - + st = [] for i in range(n_samples): ar = D[i], DD[i] st.append ( np.concatenate ( ar, axis=1) ) result += [ ('SAEHD dst-dst', np.concatenate (st, axis=0 )), ] - + st = [] for i in range(n_samples): ar = D[i], SD[i] st.append ( np.concatenate ( ar, axis=1) ) result += [ ('SAEHD pred', np.concatenate (st, axis=0 )), ] - + st_m = [] for i in range(n_samples): ar = S[i]*target_srcm[i], SS[i] - st_m.append ( np.concatenate ( ar, axis=1) ) + st_m.append ( np.concatenate ( ar, axis=1) ) result += [ ('SAEHD masked src-src', np.concatenate (st_m, axis=0 )), ] - + st_m = [] for i in range(n_samples): ar = D[i]*target_dstm[i], DD[i]*DDM[i] - st_m.append ( np.concatenate ( ar, axis=1) ) + st_m.append ( np.concatenate ( ar, axis=1) ) result += [ ('SAEHD masked dst-dst', np.concatenate (st_m, axis=0 )), ] - + st_m = [] for i in range(n_samples): - ar = D[i]*target_dstm[i], SD[i]*(DDM[i]*SDM[i]) - st_m.append ( np.concatenate ( ar, axis=1) ) + SD_mask = DDM[i]*SDM[i] if self.face_type < FaceType.HEAD else SDM[i] + ar = D[i]*target_dstm[i], SD[i]*SD_mask + st_m.append ( np.concatenate ( ar, axis=1) ) result += [ ('SAEHD masked pred', np.concatenate (st_m, axis=0 )), ] - + return result def predictor_func (self, face=None): @@ -654,8 +859,7 @@ class SAEHDModel(ModelBase): bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face) ] - mask = mask_dst_dstm[0] * mask_src_dstm[0] - return bgr[0], mask[...,0] + return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0] #override def get_MergerConfig(self): diff --git a/models/Model_XSeg/Model.py b/models/Model_XSeg/Model.py index b3578fb..b0addfd 100644 --- a/models/Model_XSeg/Model.py +++ b/models/Model_XSeg/Model.py @@ -7,7 +7,7 @@ import numpy as np from core import mathlib from core.interact import interact as io from core.leras import nn -from facelib import FaceType, TernausNet, XSegNet +from facelib import FaceType, XSegNet from models import ModelBase from samplelib import * @@ -15,15 +15,34 @@ class XSegModel(ModelBase): def __init__(self, *args, **kwargs): super().__init__(*args, force_model_class_name='XSeg', **kwargs) - + #override def on_initialize_options(self): - self.set_batch_size(4) + ask_override = self.ask_override() + + if not self.is_first_run() and ask_override: + if io.input_bool(f"Restart training?", False, help_message="Reset model weights and start training from scratch."): + self.set_iter(0) + + default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf') + default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False) + + if self.is_first_run(): + self.options['face_type'] = io.input_str ("Face 
type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Choose the same as your deepfake model.").lower() + + if self.is_first_run() or ask_override: + self.ask_batch_size(4, range=[2,16]) + self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain) + + if not self.is_exporting and (self.options['pretrain'] and self.get_pretraining_data_path() is None): + raise Exception("pretraining_data_path is not defined") + + self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False) #override def on_initialize(self): device_config = nn.getCurrentDeviceConfig() - self.model_data_format = "NCHW" if len(device_config.devices) != 0 and not self.is_debug() else "NHWC" + self.model_data_format = "NCHW" if self.is_exporting or (len(device_config.devices) != 0 and not self.is_debug()) else "NHWC" nn.initialize(data_format=self.model_data_format) tf = nn.tf @@ -31,68 +50,91 @@ class XSegModel(ModelBase): devices = device_config.devices self.resolution = resolution = 256 - self.face_type = FaceType.WHOLE_FACE + + + self.face_type = {'h' : FaceType.HALF, + 'mf' : FaceType.MID_FULL, + 'f' : FaceType.FULL, + 'wf' : FaceType.WHOLE_FACE, + 'head' : FaceType.HEAD}[ self.options['face_type'] ] + place_model_on_cpu = len(devices) == 0 - models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0' + models_opt_device = '/CPU:0' if place_model_on_cpu else nn.tf_default_device_name bgr_shape = nn.get4Dshape(resolution,resolution,3) mask_shape = nn.get4Dshape(resolution,resolution,1) - + # Initializing model classes - self.model = XSegNet(name=f'XSeg', - resolution=resolution, + self.model = XSegNet(name='XSeg', + resolution=resolution, load_weights=not self.is_first_run(), weights_file_root=self.get_model_root_path(), training=True, place_model_on_cpu=place_model_on_cpu, optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3, name='opt'), data_format=nn.data_format) - + + self.pretrain = self.options['pretrain'] + if self.pretrain_just_disabled: + self.set_iter(0) + if self.is_training: # Adjust batch size for multiple GPU gpu_count = max(1, len(devices) ) bs_per_gpu = max(1, self.get_batch_size() // gpu_count) self.set_batch_size( gpu_count*bs_per_gpu) - # Compute losses per GPU gpu_pred_list = [] gpu_losses = [] gpu_loss_gvs = [] - - for gpu_id in range(gpu_count): - with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): + for gpu_id in range(gpu_count): + with tf.device(f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): with tf.device(f'/CPU:0'): # slice on CPU, otherwise all batch data will be transfered to GPU first batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) gpu_input_t = self.model.input_t [batch_slice,:,:,:] - gpu_target_t = self.model.target_t [batch_slice,:,:,:] - + gpu_target_t = self.model.target_t [batch_slice,:,:,:] + # process model tensors - gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t) + gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t, pretrain=self.pretrain) gpu_pred_list.append(gpu_pred_t) - - gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3]) + + + if self.pretrain: + # Structural loss + gpu_loss = tf.reduce_mean (5*nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_loss += tf.reduce_mean (5*nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, 
filter_size=int(resolution/23.2)), axis=[1]) + # Pixel loss + gpu_loss += tf.reduce_mean (10*tf.square(gpu_target_t-gpu_pred_t), axis=[1,2,3]) + else: + gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3]) + gpu_losses += [gpu_loss] gpu_loss_gvs += [ nn.gradients ( gpu_loss, self.model.get_weights() ) ] # Average losses and gradients, and create optimizer update ops + #with tf.device(f'/CPU:0'): # Temporary fix. Unknown bug with training freeze starts from 2.4.0, but 2.3.1 was ok with tf.device (models_opt_device): - pred = nn.concat(gpu_pred_list, 0) - loss = tf.reduce_mean(gpu_losses) - + pred = tf.concat(gpu_pred_list, 0) + loss = tf.concat(gpu_losses, 0) loss_gv_op = self.model.opt.get_update_op (nn.average_gv_list (gpu_loss_gvs)) - - + + # Initializing training and view functions - def train(input_np, target_np): - l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np }) - return l + if self.pretrain: + def train(input_np, target_np): + l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np}) + return l + else: + def train(input_np, target_np): + l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np }) + return l self.train = train def view(input_np): @@ -105,29 +147,38 @@ class XSegModel(ModelBase): src_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2 - - srcdst_generator = SampleGeneratorFaceXSeg([self.training_data_src_path, self.training_data_dst_path], - debug=self.is_debug(), - batch_size=self.get_batch_size(), - resolution=resolution, - face_type=self.face_type, - generators_count=src_dst_generators_count, - data_format=nn.data_format) - - src_generator = SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), - sample_process_options=SampleProcessor.Options(random_flip=False), - output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, - ], - generators_count=src_generators_count, - raise_on_no_data=False ) - dst_generator = SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), - sample_process_options=SampleProcessor.Options(random_flip=False), - output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, - ], - generators_count=dst_generators_count, - raise_on_no_data=False ) - - self.set_training_data_generators ([srcdst_generator, src_generator, dst_generator]) + if self.pretrain: + pretrain_gen = SampleGeneratorFace(self.get_pretraining_data_path(), debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=True), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': 
SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=False, + generators_count=cpu_count ) + self.set_training_data_generators ([pretrain_gen]) + else: + srcdst_generator = SampleGeneratorFaceXSeg([self.training_data_src_path, self.training_data_dst_path], + debug=self.is_debug(), + batch_size=self.get_batch_size(), + resolution=resolution, + face_type=self.face_type, + generators_count=src_dst_generators_count, + data_format=nn.data_format) + + src_generator = SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=False), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + generators_count=src_generators_count, + raise_on_no_data=False ) + dst_generator = SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=False), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + generators_count=dst_generators_count, + raise_on_no_data=False ) + + self.set_training_data_generators ([srcdst_generator, src_generator, dst_generator]) #override def get_model_filename_list(self): @@ -136,63 +187,97 @@ class XSegModel(ModelBase): #override def onSave(self): self.model.save_weights() - - #override - def onTrainOneIter(self): - - - image_np, mask_np = self.generate_next_samples()[0] - loss = self.train (image_np, mask_np) - - return ( ('loss', loss ), ) #override - def onGetPreview(self, samples): + def onTrainOneIter(self): + image_np, target_np = self.generate_next_samples()[0] + loss = self.train (image_np, target_np) + + return ( ('loss', np.mean(loss) ), ) + + #override + def onGetPreview(self, samples, for_history=False): n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) - - srcdst_samples, src_samples, dst_samples = samples - image_np, mask_np = srcdst_samples + + if self.pretrain: + srcdst_samples, = samples + image_np, mask_np = srcdst_samples + else: + srcdst_samples, src_samples, dst_samples = samples + image_np, mask_np = srcdst_samples I, M, IM, = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([image_np,mask_np] + self.view (image_np) ) ] M, IM, = [ np.repeat (x, (3,), -1) for x in [M, IM] ] green_bg = np.tile( np.array([0,1,0], dtype=np.float32)[None,None,...], (self.resolution,self.resolution,1) ) - result = [] + result = [] st = [] for i in range(n_samples): - ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i]+0.5*I[i]*(1-IM[i]) + 0.5*green_bg*(1-IM[i]) + if self.pretrain: + ar = I[i], IM[i] + else: + ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i]+0.5*I[i]*(1-IM[i]) + 0.5*green_bg*(1-IM[i]) st.append ( np.concatenate ( ar, axis=1) ) result += [ ('XSeg training faces', np.concatenate (st, axis=0 )), ] - - if len(src_samples) != 0: + + if not 
self.pretrain and len(src_samples) != 0: src_np, = src_samples - + D, DM, = [ np.clip(nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([src_np] + self.view (src_np) ) ] DM, = [ np.repeat (x, (3,), -1) for x in [DM] ] - + st = [] for i in range(n_samples): ar = D[i], DM[i], D[i]*DM[i] + 0.5*D[i]*(1-DM[i]) + 0.5*green_bg*(1-DM[i]) st.append ( np.concatenate ( ar, axis=1) ) - + result += [ ('XSeg src faces', np.concatenate (st, axis=0 )), ] - - if len(dst_samples) != 0: + + if not self.pretrain and len(dst_samples) != 0: dst_np, = dst_samples - + D, DM, = [ np.clip(nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([dst_np] + self.view (dst_np) ) ] DM, = [ np.repeat (x, (3,), -1) for x in [DM] ] - + st = [] for i in range(n_samples): ar = D[i], DM[i], D[i]*DM[i] + 0.5*D[i]*(1-DM[i]) + 0.5*green_bg*(1-DM[i]) st.append ( np.concatenate ( ar, axis=1) ) - - result += [ ('XSeg dst faces', np.concatenate (st, axis=0 )), ] - - return result + result += [ ('XSeg dst faces', np.concatenate (st, axis=0 )), ] + + return result + + def export_dfm (self): + output_path = self.get_strpath_storage_for_file(f'model.onnx') + io.log_info(f'Dumping .onnx to {output_path}') + tf = nn.tf + + with tf.device (nn.tf_default_device_name): + input_t = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face') + input_t = tf.transpose(input_t, (0,3,1,2)) + _, pred_t = self.model.flow(input_t) + pred_t = tf.transpose(pred_t, (0,2,3,1)) + + tf.identity(pred_t, name='out_mask') + + output_graph_def = tf.graph_util.convert_variables_to_constants( + nn.tf_sess, + tf.get_default_graph().as_graph_def(), + ['out_mask'] + ) + + import tf2onnx + with tf.device("/CPU:0"): + model_proto, _ = tf2onnx.convert._convert_common( + output_graph_def, + name='XSeg', + input_names=['in_face:0'], + output_names=['out_mask:0'], + opset=13, + output_path=output_path) + Model = XSegModel \ No newline at end of file diff --git a/project.code-workspace b/project.code-workspace deleted file mode 100644 index 07fae2f..0000000 --- a/project.code-workspace +++ /dev/null @@ -1,50 +0,0 @@ -{ - "folders": [ - { - "path": "." 
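# --- Illustrative sketch, not part of the patch above ---
# Both new export_dfm() methods follow the same freeze-then-convert pattern:
# build an inference graph with named input/output tensors, freeze variables
# into constants, and hand the frozen GraphDef to tf2onnx. The sketch below is
# a generic version of that pattern, assuming TF1-style graph mode
# (tf.compat.v1), an existing session, and tf2onnx installed; it reuses the
# same private tf2onnx.convert._convert_common helper the patch itself calls.
# The function name and tensor names ('in_face', output_names) are placeholders.
import tensorflow.compat.v1 as tf
import tf2onnx

def export_frozen_graph_to_onnx(sess, output_path, output_names, model_name='ExportedModel'):
    # Freeze: replace variables in the current default graph with constants.
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        tf.get_default_graph().as_graph_def(),
        output_names,                                   # e.g. ['out_mask']
    )
    # Convert the frozen GraphDef to ONNX on the CPU, as the patch does.
    with tf.device("/CPU:0"):
        model_proto, _ = tf2onnx.convert._convert_common(
            frozen_graph_def,
            name=model_name,
            input_names=['in_face:0'],                  # placeholder input tensor name
            output_names=[n + ':0' for n in output_names],
            opset=13,
            output_path=output_path,
        )
    return model_proto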
- } - ], - "settings": { - "workbench.colorTheme": "Visual Studio Light", - "diffEditor.ignoreTrimWhitespace": true, - "workbench.sideBar.location": "right", - "breadcrumbs.enabled": false, - "editor.renderWhitespace": "none", - "editor.minimap.enabled": false, - "workbench.activityBar.visible": true, - "window.menuBarVisibility": "default", - "editor.fastScrollSensitivity": 10, - "editor.mouseWheelScrollSensitivity": 2, - "window.zoomLevel": 0, - "extensions.ignoreRecommendations": true, - - "python.linting.pylintEnabled": false, - "python.linting.enabled": false, - "python.linting.pylamaEnabled": false, - "python.linting.pydocstyleEnabled": false, - "python.pythonPath": "${env:PYTHON_EXECUTABLE}", - "workbench.editor.tabCloseButton": "off", - "workbench.editor.tabSizing": "shrink", - "workbench.editor.highlightModifiedTabs": true, - "editor.mouseWheelScrollSensitivity": 3, - "editor.folding": false, - "editor.glyphMargin": false, - "files.exclude": { - "**/__pycache__": true, - "**/.github": true, - "**/.vscode": true, - "**/*.dat": true, - "**/*.h5": true, - "**/*.npy": true - }, - "editor.quickSuggestions": { - "other": false, - "comments": false, - "strings": false - }, - "editor.trimAutoWhitespace": false, - "python.linting.pylintArgs": [ - "--disable=import-error" - ] - } -} \ No newline at end of file diff --git a/requirements-colab.txt b/requirements-colab.txt index 128a518..33546b2 100644 --- a/requirements-colab.txt +++ b/requirements-colab.txt @@ -1,9 +1,11 @@ tqdm -numpy==1.17.0 -h5py==2.9.0 +numpy==1.19.3 +numexpr +h5py==2.10.0 opencv-python==4.1.0.25 ffmpeg-python==0.1.17 scikit-image==0.14.2 scipy==1.4.1 colorama -tensorflow-gpu==1.13.2 \ No newline at end of file +tensorflow-gpu==2.4.0 +tf2onnx==1.9.3 \ No newline at end of file diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 905aaef..b70520d 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -1,10 +1,12 @@ tqdm -numpy==1.17.0 -h5py==2.9.0 +numpy==1.19.3 +numexpr +h5py==2.10.0 opencv-python==4.1.0.25 ffmpeg-python==0.1.17 scikit-image==0.14.2 scipy==1.4.1 colorama -labelme==4.2.9 -tensorflow-gpu==1.13.2 \ No newline at end of file +tensorflow-gpu==2.4.0 +pyqt5 +tf2onnx==1.9.3 \ No newline at end of file diff --git a/samplelib/PackedFaceset.py b/samplelib/PackedFaceset.py index 867fcd9..e7ae1d4 100644 --- a/samplelib/PackedFaceset.py +++ b/samplelib/PackedFaceset.py @@ -84,17 +84,18 @@ class PackedFaceset(): of.write ( struct.pack("Q", offset) ) of.seek(0,2) of.close() + + if io.input_bool(f"Delete original files?", True): + for filename in io.progress_bar_generator(image_paths, "Deleting files"): + Path(filename).unlink() - for filename in io.progress_bar_generator(image_paths, "Deleting files"): - Path(filename).unlink() - - if as_person_faceset: - for dir_name in io.progress_bar_generator(dir_names, "Deleting dirs"): - dir_path = samples_path / dir_name - try: - shutil.rmtree(dir_path) - except: - io.log_info (f"unable to remove: {dir_path} ") + if as_person_faceset: + for dir_name in io.progress_bar_generator(dir_names, "Deleting dirs"): + dir_path = samples_path / dir_name + try: + shutil.rmtree(dir_path) + except: + io.log_info (f"unable to remove: {dir_path} ") @staticmethod def unpack(samples_path): @@ -120,6 +121,11 @@ class PackedFaceset(): samples_dat_path.unlink() + @staticmethod + def path_contains(samples_path): + samples_dat_path = samples_path / packed_faceset_filename + return samples_dat_path.exists() + @staticmethod def load(samples_path): samples_dat_path = samples_path / 
packed_faceset_filename diff --git a/samplelib/Sample.py b/samplelib/Sample.py index 604c02f..a379275 100644 --- a/samplelib/Sample.py +++ b/samplelib/Sample.py @@ -5,9 +5,9 @@ import cv2 import numpy as np from core.cv2ex import * -from DFLIMG import * from facelib import LandmarksProcessor -from core.imagelib import IEPolys +from core import imagelib +from core.imagelib import SegIEPolys class SampleType(IntEnum): IMAGE = 0 #raw image @@ -26,8 +26,9 @@ class Sample(object): 'face_type', 'shape', 'landmarks', - 'ie_polys', 'seg_ie_polys', + 'xseg_mask', + 'xseg_mask_compressed', 'eyebrows_expand_mod', 'source_filename', 'person_name', @@ -40,8 +41,9 @@ class Sample(object): face_type=None, shape=None, landmarks=None, - ie_polys=None, seg_ie_polys=None, + xseg_mask=None, + xseg_mask_compressed=None, eyebrows_expand_mod=None, source_filename=None, person_name=None, @@ -53,15 +55,41 @@ class Sample(object): self.face_type = face_type self.shape = shape self.landmarks = np.array(landmarks) if landmarks is not None else None - self.ie_polys = IEPolys.load(ie_polys) - self.seg_ie_polys = IEPolys.load(seg_ie_polys) - self.eyebrows_expand_mod = eyebrows_expand_mod + + if isinstance(seg_ie_polys, SegIEPolys): + self.seg_ie_polys = seg_ie_polys + else: + self.seg_ie_polys = SegIEPolys.load(seg_ie_polys) + + self.xseg_mask = xseg_mask + self.xseg_mask_compressed = xseg_mask_compressed + + if self.xseg_mask_compressed is None and self.xseg_mask is not None: + xseg_mask = np.clip( imagelib.normalize_channels(xseg_mask, 1)*255, 0, 255 ).astype(np.uint8) + ret, xseg_mask_compressed = cv2.imencode('.png', xseg_mask) + if not ret: + raise Exception("Sample(): unable to generate xseg_mask_compressed") + self.xseg_mask_compressed = xseg_mask_compressed + self.xseg_mask = None + + self.eyebrows_expand_mod = eyebrows_expand_mod if eyebrows_expand_mod is not None else 1.0 self.source_filename = source_filename self.person_name = person_name self.pitch_yaw_roll = pitch_yaw_roll self._filename_offset_size = None + def has_xseg_mask(self): + return self.xseg_mask is not None or self.xseg_mask_compressed is not None + + def get_xseg_mask(self): + if self.xseg_mask_compressed is not None: + xseg_mask = cv2.imdecode(self.xseg_mask_compressed, cv2.IMREAD_UNCHANGED) + if len(xseg_mask.shape) == 2: + xseg_mask = xseg_mask[...,None] + return xseg_mask.astype(np.float32) / 255.0 + return self.xseg_mask + def get_pitch_yaw_roll(self): if self.pitch_yaw_roll is None: self.pitch_yaw_roll = LandmarksProcessor.estimate_pitch_yaw_roll(self.landmarks, size=self.shape[1]) @@ -90,25 +118,10 @@ class Sample(object): 'face_type': self.face_type, 'shape': self.shape, 'landmarks': self.landmarks.tolist(), - 'ie_polys': self.ie_polys.dump(), 'seg_ie_polys': self.seg_ie_polys.dump(), + 'xseg_mask' : self.xseg_mask, + 'xseg_mask_compressed' : self.xseg_mask_compressed, 'eyebrows_expand_mod': self.eyebrows_expand_mod, 'source_filename': self.source_filename, 'person_name': self.person_name } - -""" -def copy_and_set(self, sample_type=None, filename=None, face_type=None, shape=None, landmarks=None, ie_polys=None, pitch_yaw_roll=None, eyebrows_expand_mod=None, source_filename=None, fanseg_mask=None, person_name=None): - return Sample( - sample_type=sample_type if sample_type is not None else self.sample_type, - filename=filename if filename is not None else self.filename, - face_type=face_type if face_type is not None else self.face_type, - shape=shape if shape is not None else self.shape, - landmarks=landmarks if landmarks is not None 
else self.landmarks.copy(), - ie_polys=ie_polys if ie_polys is not None else self.ie_polys, - pitch_yaw_roll=pitch_yaw_roll if pitch_yaw_roll is not None else self.pitch_yaw_roll, - eyebrows_expand_mod=eyebrows_expand_mod if eyebrows_expand_mod is not None else self.eyebrows_expand_mod, - source_filename=source_filename if source_filename is not None else self.source_filename, - person_name=person_name if person_name is not None else self.person_name) - -""" \ No newline at end of file diff --git a/samplelib/SampleGeneratorFace.py b/samplelib/SampleGeneratorFace.py index 515c6fe..605d327 100644 --- a/samplelib/SampleGeneratorFace.py +++ b/samplelib/SampleGeneratorFace.py @@ -1,5 +1,4 @@ import multiprocessing -import pickle import time import traceback @@ -7,6 +6,7 @@ import cv2 import numpy as np from core import mplib +from core.interact import interact as io from core.joblib import SubprocessGenerator, ThisThreadGenerator from facelib import LandmarksProcessor from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, @@ -25,15 +25,15 @@ class SampleGeneratorFace(SampleGeneratorBase): random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(), output_sample_types=[], - add_sample_idx=False, + uniform_yaw_distribution=False, generators_count=4, - raise_on_no_data=True, + raise_on_no_data=True, **kwargs): super().__init__(debug, batch_size) + self.initialized = False self.sample_process_options = sample_process_options self.output_sample_types = output_sample_types - self.add_sample_idx = add_sample_idx if self.debug: self.generators_count = 1 @@ -42,15 +42,40 @@ class SampleGeneratorFace(SampleGeneratorBase): samples = SampleLoader.load (SampleType.FACE, samples_path) self.samples_len = len(samples) - - self.initialized = False + if self.samples_len == 0: if raise_on_no_data: raise ValueError('No training data provided.') else: return + + if uniform_yaw_distribution: + samples_pyr = [ ( idx, sample.get_pitch_yaw_roll() ) for idx, sample in enumerate(samples) ] + + grads = 128 + #instead of math.pi / 2, using -1.2,+1.2 because actually maximum yaw for 2DFAN landmarks are -1.2+1.2 + grads_space = np.linspace (-1.2, 1.2,grads) - index_host = mplib.IndexHost(self.samples_len) + yaws_sample_list = [None]*grads + for g in io.progress_bar_generator ( range(grads), "Sort by yaw"): + yaw = grads_space[g] + next_yaw = grads_space[g+1] if g < grads-1 else yaw + + yaw_samples = [] + for idx, pyr in samples_pyr: + s_yaw = -pyr[1] + if (g == 0 and s_yaw < next_yaw) or \ + (g < grads-1 and s_yaw >= yaw and s_yaw < next_yaw) or \ + (g == grads-1 and s_yaw >= yaw): + yaw_samples += [ idx ] + if len(yaw_samples) > 0: + yaws_sample_list[g] = yaw_samples + + yaws_sample_list = [ y for y in yaws_sample_list if y is not None ] + + index_host = mplib.Index2DHost( yaws_sample_list ) + else: + index_host = mplib.IndexHost(self.samples_len) if random_ct_samples_path is not None: ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) @@ -59,13 +84,10 @@ class SampleGeneratorFace(SampleGeneratorBase): ct_samples = None ct_index_host = None - pickled_samples = pickle.dumps(samples, 4) - ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None - if self.debug: - self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )] + self.generators = [ThisThreadGenerator ( self.batch_func, (samples, 
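# --- Illustrative sketch, not part of the patch above ---
# The uniform_yaw_distribution option added to SampleGeneratorFace works by
# bucketing sample indexes into yaw bins over roughly [-1.2, 1.2] (the usable
# yaw range of 2DFAN landmarks) and then sampling bins uniformly, so profile
# views are drawn about as often as frontal ones. This is a simplified,
# self-contained version of that idea; `get_yaw` is a placeholder for
# sample.get_pitch_yaw_roll()[1], and the binning details differ slightly
# from the patch's loop.
import numpy as np

def bucket_indices_by_yaw(samples, get_yaw, grads=128):
    edges = np.linspace(-1.2, 1.2, grads + 1)            # bin edges over the yaw range
    buckets = [[] for _ in range(grads)]
    for idx, sample in enumerate(samples):
        yaw = -get_yaw(sample)                           # the patch negates yaw before binning
        b = int(np.clip(np.searchsorted(edges, yaw) - 1, 0, grads - 1))
        buckets[b].append(idx)
    return [b for b in buckets if b]                     # drop empty bins, like the patch

def draw_uniform_yaw(buckets, count, rng=np.random):
    # Pick a bin uniformly at random, then a sample uniformly inside that bin.
    return [rng.choice(buckets[rng.randint(len(buckets))]) for _ in range(count)]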
index_host.create_cli(), ct_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )] else: - self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \ + self.generators = [SubprocessGenerator ( self.batch_func, (samples, index_host.create_cli(), ct_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \ for i in range(self.generators_count) ] SubprocessGenerator.start_in_parallel( self.generators ) @@ -90,11 +112,8 @@ class SampleGeneratorFace(SampleGeneratorBase): return next(generator) def batch_func(self, param ): - pickled_samples, index_host, ct_pickled_samples, ct_index_host = param - - samples = pickle.loads(pickled_samples) - ct_samples = pickle.loads(ct_pickled_samples) if ct_pickled_samples is not None else None - + samples, index_host, ct_samples, ct_index_host = param + bs = self.batch_size while True: batches = None @@ -118,14 +137,8 @@ class SampleGeneratorFace(SampleGeneratorBase): if batches is None: batches = [ [] for _ in range(len(x)) ] - if self.add_sample_idx: - batches += [ [] ] - i_sample_idx = len(batches)-1 for i in range(len(x)): batches[i].append ( x[i] ) - if self.add_sample_idx: - batches[i_sample_idx].append (sample_idx) - yield [ np.array(batch) for batch in batches] diff --git a/samplelib/SampleGeneratorFacePerson.py b/samplelib/SampleGeneratorFacePerson.py index a72cf59..0fbd2c3 100644 --- a/samplelib/SampleGeneratorFacePerson.py +++ b/samplelib/SampleGeneratorFacePerson.py @@ -12,6 +12,98 @@ from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleType) + +class Index2DHost(): + """ + Provides random shuffled 2D indexes for multiprocesses + """ + def __init__(self, indexes2D): + self.sq = multiprocessing.Queue() + self.cqs = [] + self.clis = [] + self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) ) + self.thread.daemon = True + self.thread.start() + + def host_thread(self, indexes2D): + indexes_counts_len = len(indexes2D) + + idxs = [*range(indexes_counts_len)] + idxs_2D = [None]*indexes_counts_len + shuffle_idxs = [] + shuffle_idxs_2D = [None]*indexes_counts_len + for i in range(indexes_counts_len): + idxs_2D[i] = indexes2D[i] + shuffle_idxs_2D[i] = [] + + sq = self.sq + + while True: + while not sq.empty(): + obj = sq.get() + cq_id, cmd = obj[0], obj[1] + + if cmd == 0: #get_1D + count = obj[2] + + result = [] + for i in range(count): + if len(shuffle_idxs) == 0: + shuffle_idxs = idxs.copy() + np.random.shuffle(shuffle_idxs) + result.append(shuffle_idxs.pop()) + self.cqs[cq_id].put (result) + elif cmd == 1: #get_2D + targ_idxs,count = obj[2], obj[3] + result = [] + + for targ_idx in targ_idxs: + sub_idxs = [] + for i in range(count): + ar = shuffle_idxs_2D[targ_idx] + if len(ar) == 0: + ar = shuffle_idxs_2D[targ_idx] = idxs_2D[targ_idx].copy() + np.random.shuffle(ar) + sub_idxs.append(ar.pop()) + result.append (sub_idxs) + self.cqs[cq_id].put (result) + + time.sleep(0.001) + + def create_cli(self): + cq = multiprocessing.Queue() + self.cqs.append ( cq ) + cq_id = len(self.cqs)-1 + return Index2DHost.Cli(self.sq, cq, cq_id) + + # disable pickling + def __getstate__(self): + return dict() + def __setstate__(self, d): + self.__dict__.update(d) + + class Cli(): + def __init__(self, sq, cq, cq_id): + self.sq = sq + self.cq = cq + self.cq_id = cq_id + + def get_1D(self, count): + self.sq.put ( 
(self.cq_id,0, count) ) + + while True: + if not self.cq.empty(): + return self.cq.get() + time.sleep(0.001) + + def get_2D(self, idxs, count): + self.sq.put ( (self.cq_id,1,idxs,count) ) + + while True: + if not self.cq.empty(): + return self.cq.get() + time.sleep(0.001) + ''' arg output_sample_types = [ @@ -45,7 +137,7 @@ class SampleGeneratorFacePerson(SampleGeneratorBase): for i,sample in enumerate(samples): persons_name_idxs[sample.person_name].append (i) indexes2D = [ persons_name_idxs[person_name] for person_name in unique_person_names ] - index2d_host = mplib.Index2DHost(indexes2D) + index2d_host = Index2DHost(indexes2D) if self.debug: self.generators_count = 1 diff --git a/samplelib/SampleGeneratorFaceXSeg.py b/samplelib/SampleGeneratorFaceXSeg.py index c744d40..7e38e64 100644 --- a/samplelib/SampleGeneratorFaceXSeg.py +++ b/samplelib/SampleGeneratorFaceXSeg.py @@ -6,12 +6,12 @@ from enum import IntEnum import cv2 import numpy as np - +from pathlib import Path from core import imagelib, mplib, pathex from core.imagelib import sd from core.cv2ex import * from core.interact import interact as io -from core.joblib import SubprocessGenerator, ThisThreadGenerator +from core.joblib import Subprocessor, SubprocessGenerator, ThisThreadGenerator from facelib import LandmarksProcessor from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleType) @@ -23,29 +23,28 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase): super().__init__(debug, batch_size) self.initialized = False - samples = [] - for path in paths: - samples += SampleLoader.load (SampleType.FACE, path) - - seg_samples = [ sample for sample in samples if sample.seg_ie_polys.get_total_points() != 0] - seg_samples_len = len(seg_samples) - if seg_samples_len == 0: - raise Exception(f"No segmented faces found.") + samples = sum([ SampleLoader.load (SampleType.FACE, path) for path in paths ] ) + seg_sample_idxs = SegmentedSampleFilterSubprocessor(samples).run() + + if len(seg_sample_idxs) == 0: + seg_sample_idxs = SegmentedSampleFilterSubprocessor(samples, count_xseg_mask=True).run() + if len(seg_sample_idxs) == 0: + raise Exception(f"No segmented faces found.") + else: + io.log_info(f"Using {len(seg_sample_idxs)} xseg labeled samples.") else: - io.log_info(f"Using {seg_samples_len} segmented samples.") - - pickled_samples = pickle.dumps(seg_samples, 4) - + io.log_info(f"Using {len(seg_sample_idxs)} segmented samples.") + if self.debug: self.generators_count = 1 else: self.generators_count = max(1, generators_count) + args = (samples, seg_sample_idxs, resolution, face_type, data_format) if self.debug: - self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, resolution, face_type, data_format) )] + self.generators = [ThisThreadGenerator ( self.batch_func, args )] else: - self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, resolution, face_type, data_format), start_now=False ) \ - for i in range(self.generators_count) ] + self.generators = [SubprocessGenerator ( self.batch_func, args, start_now=False ) for i in range(self.generators_count) ] SubprocessGenerator.start_in_parallel( self.generators ) @@ -66,13 +65,11 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase): return next(generator) def batch_func(self, param ): - pickled_samples, resolution, face_type, data_format = param + samples, seg_sample_idxs, resolution, face_type, data_format = param - samples = pickle.loads(pickled_samples) - shuffle_idxs = [] - idxs = [*range(len(samples))] - + bg_shuffle_idxs = [] 
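# --- Illustrative sketch, not part of the patch above ---
# How the Index2DHost added to SampleGeneratorFacePerson.py is typically
# consumed: the host lives in the parent process, each generator worker holds a
# lightweight Cli, asks for shuffled person indexes (get_1D) and then for
# shuffled sample indexes belonging to those persons (get_2D). The import path
# and the small indexes2D literal below are assumptions for illustration only.
from samplelib.SampleGeneratorFacePerson import Index2DHost

indexes2D = [[0, 1, 2], [3, 4], [5, 6, 7, 8]]    # sample idxs grouped by person
host = Index2DHost(indexes2D)
cli = host.create_cli()                          # one Cli per worker process

person_idxs = cli.get_1D(4)                      # 4 shuffled person picks
sample_idxs = cli.get_2D(person_idxs, 1)         # 1 shuffled sample per picked person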
+ random_flip = True rotation_range=[-10,10] scale_range=[-0.05, 0.05] @@ -80,8 +77,37 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase): ty_range=[-0.05, 0.05] random_bilinear_resize_chance, random_bilinear_resize_max_size_per = 25,75 + sharpen_chance, sharpen_kernel_max_size = 25, 5 motion_blur_chance, motion_blur_mb_max_size = 25, 5 gaussian_blur_chance, gaussian_blur_kernel_max_size = 25, 5 + random_jpeg_compress_chance = 25 + + def gen_img_mask(sample): + img = sample.load_bgr() + h,w,c = img.shape + + if sample.seg_ie_polys.has_polys(): + mask = np.zeros ((h,w,1), dtype=np.float32) + sample.seg_ie_polys.overlay_mask(mask) + elif sample.has_xseg_mask(): + mask = sample.get_xseg_mask() + mask[mask < 0.5] = 0.0 + mask[mask >= 0.5] = 1.0 + else: + raise Exception(f'no mask in sample {sample.filename}') + + if face_type == sample.face_type: + if w != resolution: + img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4 ) + mask = cv2.resize( mask, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4 ) + else: + mat = LandmarksProcessor.get_transform_mat (sample.landmarks, resolution, face_type) + img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 ) + mask = cv2.warpAffine( mask, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 ) + + if len(mask.shape) == 2: + mask = mask[...,None] + return img, mask bs = self.batch_size while True: @@ -91,49 +117,72 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase): while n_batch < bs: try: if len(shuffle_idxs) == 0: - shuffle_idxs = idxs.copy() + shuffle_idxs = seg_sample_idxs.copy() np.random.shuffle(shuffle_idxs) - idx = shuffle_idxs.pop() - - sample = samples[idx] + sample = samples[shuffle_idxs.pop()] + img, mask = gen_img_mask(sample) - img = sample.load_bgr() - h,w,c = img.shape + if np.random.randint(2) == 0: + if len(bg_shuffle_idxs) == 0: + bg_shuffle_idxs = seg_sample_idxs.copy() + np.random.shuffle(bg_shuffle_idxs) + bg_sample = samples[bg_shuffle_idxs.pop()] - mask = np.zeros ((h,w,1), dtype=np.float32) - sample.seg_ie_polys.overlay_mask(mask) + bg_img, bg_mask = gen_img_mask(bg_sample) + + bg_wp = imagelib.gen_warp_params(resolution, True, rotation_range=[-180,180], scale_range=[-0.10, 0.10], tx_range=[-0.10, 0.10], ty_range=[-0.10, 0.10] ) + bg_img = imagelib.warp_by_params (bg_wp, bg_img, can_warp=False, can_transform=True, can_flip=True, border_replicate=True) + bg_mask = imagelib.warp_by_params (bg_wp, bg_mask, can_warp=False, can_transform=True, can_flip=True, border_replicate=False) + bg_img = bg_img*(1-bg_mask) + if np.random.randint(2) == 0: + bg_img = imagelib.apply_random_hsv_shift(bg_img) + else: + bg_img = imagelib.apply_random_rgb_levels(bg_img) + + c_mask = 1.0 - (1-bg_mask) * (1-mask) + rnd = 0.15 + np.random.uniform()*0.85 + img = img*(c_mask) + img*(1-c_mask)*rnd + bg_img*(1-c_mask)*(1-rnd) warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range ) - - if face_type == sample.face_type: - if w != resolution: - img = cv2.resize( img, (resolution, resolution), cv2.INTER_LANCZOS4 ) - mask = cv2.resize( mask, (resolution, resolution), cv2.INTER_LANCZOS4 ) - else: - mat = LandmarksProcessor.get_transform_mat (sample.landmarks, resolution, face_type) - img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 ) - mask = cv2.warpAffine( 
mask, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 ) - - if len(mask.shape) == 2: - mask = mask[...,None] - - img = imagelib.warp_by_params (warp_params, img, can_warp=True, can_transform=True, can_flip=True, border_replicate=False) + img = imagelib.warp_by_params (warp_params, img, can_warp=True, can_transform=True, can_flip=True, border_replicate=True) mask = imagelib.warp_by_params (warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False) img = np.clip(img.astype(np.float32), 0, 1) mask[mask < 0.5] = 0.0 mask[mask >= 0.5] = 1.0 mask = np.clip(mask, 0, 1) + + if np.random.randint(2) == 0: + # random face flare + krn = np.random.randint( resolution//4, resolution ) + krn = krn - krn % 2 + 1 + img = img + cv2.GaussianBlur(img*mask, (krn,krn), 0) + + if np.random.randint(2) == 0: + # random bg flare + krn = np.random.randint( resolution//4, resolution ) + krn = krn - krn % 2 + 1 + img = img + cv2.GaussianBlur(img*(1-mask), (krn,krn), 0) if np.random.randint(2) == 0: img = imagelib.apply_random_hsv_shift(img, mask=sd.random_circle_faded ([resolution,resolution])) - else: + else: img = imagelib.apply_random_rgb_levels(img, mask=sd.random_circle_faded ([resolution,resolution])) - - img = imagelib.apply_random_motion_blur( img, motion_blur_chance, motion_blur_mb_max_size, mask=sd.random_circle_faded ([resolution,resolution])) - img = imagelib.apply_random_gaussian_blur( img, gaussian_blur_chance, gaussian_blur_kernel_max_size, mask=sd.random_circle_faded ([resolution,resolution])) - img = imagelib.apply_random_bilinear_resize( img, random_bilinear_resize_chance, random_bilinear_resize_max_size_per, mask=sd.random_circle_faded ([resolution,resolution])) - + + if np.random.randint(2) == 0: + img = imagelib.apply_random_sharpen( img, sharpen_chance, sharpen_kernel_max_size, mask=sd.random_circle_faded ([resolution,resolution])) + else: + img = imagelib.apply_random_motion_blur( img, motion_blur_chance, motion_blur_mb_max_size, mask=sd.random_circle_faded ([resolution,resolution])) + img = imagelib.apply_random_gaussian_blur( img, gaussian_blur_chance, gaussian_blur_kernel_max_size, mask=sd.random_circle_faded ([resolution,resolution])) + + if np.random.randint(2) == 0: + img = imagelib.apply_random_nearest_resize( img, random_bilinear_resize_chance, random_bilinear_resize_max_size_per, mask=sd.random_circle_faded ([resolution,resolution])) + else: + img = imagelib.apply_random_bilinear_resize( img, random_bilinear_resize_chance, random_bilinear_resize_max_size_per, mask=sd.random_circle_faded ([resolution,resolution])) + img = np.clip(img, 0, 1) + + img = imagelib.apply_random_jpeg_compress( img, random_jpeg_compress_chance, mask=sd.random_circle_faded ([resolution,resolution])) + if data_format == "NCHW": img = np.transpose(img, (2,0,1) ) mask = np.transpose(mask, (2,0,1) ) @@ -146,3 +195,103 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase): io.log_err ( traceback.format_exc() ) yield [ np.array(batch) for batch in batches] + +class SegmentedSampleFilterSubprocessor(Subprocessor): + #override + def __init__(self, samples, count_xseg_mask=False ): + self.samples = samples + self.samples_len = len(self.samples) + self.count_xseg_mask = count_xseg_mask + + self.idxs = [*range(self.samples_len)] + self.result = [] + super().__init__('SegmentedSampleFilterSubprocessor', SegmentedSampleFilterSubprocessor.Cli, 60) + + #override + def process_info_generator(self): + for i in range(multiprocessing.cpu_count()): + yield 
'CPU%d' % (i), {}, {'samples':self.samples, 'count_xseg_mask':self.count_xseg_mask} + + #override + def on_clients_initialized(self): + io.progress_bar ("Filtering", self.samples_len) + + #override + def on_clients_finalized(self): + io.progress_bar_close() + + #override + def get_data(self, host_dict): + if len (self.idxs) > 0: + return self.idxs.pop(0) + + return None + + #override + def on_data_return (self, host_dict, data): + self.idxs.insert(0, data) + + #override + def on_result (self, host_dict, data, result): + idx, is_ok = result + if is_ok: + self.result.append(idx) + io.progress_bar_inc(1) + def get_result(self): + return self.result + + class Cli(Subprocessor.Cli): + #overridable optional + def on_initialize(self, client_dict): + self.samples = client_dict['samples'] + self.count_xseg_mask = client_dict['count_xseg_mask'] + + def process_data(self, idx): + if self.count_xseg_mask: + return idx, self.samples[idx].has_xseg_mask() + else: + return idx, self.samples[idx].seg_ie_polys.get_pts_count() != 0 + +""" + bg_path = None + for path in paths: + bg_path = Path(path) / 'backgrounds' + if bg_path.exists(): + + break + if bg_path is None: + io.log_info(f'Random backgrounds will not be used. Place no face jpg images to aligned\backgrounds folder. ') + bg_pathes = None + else: + bg_pathes = pathex.get_image_paths(bg_path, image_extensions=['.jpg'], return_Path_class=True) + io.log_info(f'Using {len(bg_pathes)} random backgrounds from {bg_path}') + +if bg_pathes is not None: + bg_path = bg_pathes[ np.random.randint(len(bg_pathes)) ] + + bg_img = cv2_imread(bg_path) + if bg_img is not None: + bg_img = bg_img.astype(np.float32) / 255.0 + bg_img = imagelib.normalize_channels(bg_img, 3) + + bg_img = imagelib.random_crop(bg_img, resolution, resolution) + bg_img = cv2.resize(bg_img, (resolution, resolution), interpolation=cv2.INTER_LINEAR) + + if np.random.randint(2) == 0: + bg_img = imagelib.apply_random_hsv_shift(bg_img) + else: + bg_img = imagelib.apply_random_rgb_levels(bg_img) + + bg_wp = imagelib.gen_warp_params(resolution, True, rotation_range=[-180,180], scale_range=[0,0], tx_range=[0,0], ty_range=[0,0]) + bg_img = imagelib.warp_by_params (bg_wp, bg_img, can_warp=False, can_transform=True, can_flip=True, border_replicate=True) + + bg = img*(1-mask) + fg = img*mask + + c_mask = sd.random_circle_faded ([resolution,resolution]) + bg = ( bg_img*c_mask + bg*(1-c_mask) )*(1-mask) + + img = fg+bg + + else: +""" \ No newline at end of file diff --git a/samplelib/SampleLoader.py b/samplelib/SampleLoader.py index 32a8ba7..2989354 100644 --- a/samplelib/SampleLoader.py +++ b/samplelib/SampleLoader.py @@ -6,6 +6,7 @@ from pathlib import Path import samplelib.PackedFaceset from core import pathex +from core.mplib import MPSharedList from core.interact import interact as io from core.joblib import Subprocessor from DFLIMG import * @@ -22,7 +23,7 @@ class SampleLoader: try: samples = samplelib.PackedFaceset.load(samples_path) except: - io.log_err(f"Error occured while loading samplelib.PackedFaceset.load {str(samples_dat_path)}, {traceback.format_exc()}") + io.log_err(f"Error occured while loading samplelib.PackedFaceset.load {str(samples_path)}, {traceback.format_exc()}") if samples is None: raise ValueError("packed faceset not found.") @@ -33,6 +34,9 @@ class SampleLoader: @staticmethod def load(sample_type, samples_path, subdirs=False): + """ + Return MPSharedList of samples + """ samples_cache = SampleLoader.samples_cache if str(samples_path) not in samples_cache.keys(): @@ -56,12 +60,12 @@ 
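# --- Illustrative sketch, not part of the patch above ---
# What SegmentedSampleFilterSubprocessor computes, minus the subprocess and
# progress-bar plumbing: keep the indexes of samples that carry usable
# segmentation data, preferring hand-drawn seg polys and falling back to
# embedded XSeg masks, mirroring how SampleGeneratorFaceXSeg calls it.
# Assumption: `samples` is the list returned by SampleLoader.load().
def filter_segmented_sample_idxs(samples):
    idxs = [i for i, s in enumerate(samples) if s.seg_ie_polys.get_pts_count() != 0]
    if len(idxs) == 0:
        # fall back to samples that already have an applied XSeg mask
        idxs = [i for i, s in enumerate(samples) if s.has_xseg_mask()]
    return idxs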
class SampleLoader: if result is None: result = SampleLoader.load_face_samples( pathex.get_image_paths(samples_path, subdirs=subdirs) ) - samples[sample_type] = result + samples[sample_type] = MPSharedList(result) elif sample_type == SampleType.FACE_TEMPORAL_SORTED: result = SampleLoader.load (SampleType.FACE, samples_path) result = SampleLoader.upgradeToFaceTemporalSortedSamples(result) - samples[sample_type] = result + samples[sample_type] = MPSharedList(result) return samples[sample_type] @@ -70,49 +74,29 @@ class SampleLoader: result = FaceSamplesLoaderSubprocessor(image_paths).run() sample_list = [] - for filename, \ - ( face_type, - shape, - landmarks, - ie_polys, - seg_ie_polys, - eyebrows_expand_mod, - source_filename, - ) in result: + for filename, data in result: + if data is None: + continue + ( face_type, + shape, + landmarks, + seg_ie_polys, + xseg_mask_compressed, + eyebrows_expand_mod, + source_filename ) = data + sample_list.append( Sample(filename=filename, sample_type=SampleType.FACE, face_type=FaceType.fromString (face_type), shape=shape, landmarks=landmarks, - ie_polys=ie_polys, seg_ie_polys=seg_ie_polys, + xseg_mask_compressed=xseg_mask_compressed, eyebrows_expand_mod=eyebrows_expand_mod, source_filename=source_filename, )) return sample_list - """ - @staticmethod - def load_face_samples ( image_paths): - sample_list = [] - - for filename in io.progress_bar_generator (image_paths, desc="Loading"): - dflimg = DFLIMG.load (Path(filename)) - if dflimg is None: - io.log_err (f"{filename} is not a dfl image file.") - else: - sample_list.append( Sample(filename=filename, - sample_type=SampleType.FACE, - face_type=FaceType.fromString ( dflimg.get_face_type() ), - shape=dflimg.get_shape(), - landmarks=dflimg.get_landmarks(), - ie_polys=dflimg.get_ie_polys(), - eyebrows_expand_mod=dflimg.get_eyebrows_expand_mod(), - source_filename=dflimg.get_source_filename(), - )) - return sample_list - """ - @staticmethod def upgradeToFaceTemporalSortedSamples( samples ): new_s = [ (s, s.source_filename) for s in samples] @@ -171,15 +155,15 @@ class FaceSamplesLoaderSubprocessor(Subprocessor): idx, filename = data dflimg = DFLIMG.load (Path(filename)) - if dflimg is None: + if dflimg is None or not dflimg.has_data(): self.log_err (f"FaceSamplesLoader: {filename} is not a dfl image file.") data = None else: data = (dflimg.get_face_type(), dflimg.get_shape(), dflimg.get_landmarks(), - dflimg.get_ie_polys(), dflimg.get_seg_ie_polys(), + dflimg.get_xseg_mask_compressed(), dflimg.get_eyebrows_expand_mod(), dflimg.get_source_filename() ) diff --git a/samplelib/SampleProcessor.py b/samplelib/SampleProcessor.py index b15232d..7432e75 100644 --- a/samplelib/SampleProcessor.py +++ b/samplelib/SampleProcessor.py @@ -6,6 +6,8 @@ import cv2 import numpy as np from core import imagelib +from core.cv2ex import * +from core.imagelib import sd from facelib import FaceType, LandmarksProcessor @@ -24,16 +26,12 @@ class SampleProcessor(object): BGR = 1 #BGR G = 2 #Grayscale GGG = 3 #3xGrayscale - BGR_SHUFFLE = 4 #BGR shuffle - BGR_RANDOM_HSV_SHIFT = 5 - BGR_RANDOM_RGB_LEVELS = 6 - G_MASK = 7 class FaceMaskType(IntEnum): NONE = 0 - FULL_FACE = 1 #mask all hull as grayscale - EYES = 2 #mask eyes hull as grayscale - FULL_FACE_EYES = 3 #combo all + eyes as grayscale + FULL_FACE = 1 # mask all hull as grayscale + EYES = 2 # mask eyes hull as grayscale + EYES_MOUTH = 3 # eyes and mouse class Options(object): def __init__(self, random_flip = True, rotation_range=[-10,10], scale_range=[-0.05, 0.05], tx_range=[-0.05, 
0.05], ty_range=[-0.05, 0.05] ): @@ -49,60 +47,72 @@ class SampleProcessor(object): SPCT = SampleProcessor.ChannelType SPFMT = SampleProcessor.FaceMaskType - sample_rnd_seed = np.random.randint(0x80000000) - + outputs = [] for sample in samples: + sample_rnd_seed = np.random.randint(0x80000000) + sample_face_type = sample.face_type sample_bgr = sample.load_bgr() sample_landmarks = sample.landmarks ct_sample_bgr = None h,w,c = sample_bgr.shape - def get_full_face_mask(): - if sample.eyebrows_expand_mod is not None: - full_face_mask = LandmarksProcessor.get_image_hull_mask (sample_bgr.shape, sample_landmarks, eyebrows_expand_mod=sample.eyebrows_expand_mod ) + def get_full_face_mask(): + xseg_mask = sample.get_xseg_mask() + if xseg_mask is not None: + if xseg_mask.shape[0] != h or xseg_mask.shape[1] != w: + xseg_mask = cv2.resize(xseg_mask, (w,h), interpolation=cv2.INTER_CUBIC) + xseg_mask = imagelib.normalize_channels(xseg_mask, 1) + return np.clip(xseg_mask, 0, 1) else: - full_face_mask = LandmarksProcessor.get_image_hull_mask (sample_bgr.shape, sample_landmarks) - return np.clip(full_face_mask, 0, 1) + full_face_mask = LandmarksProcessor.get_image_hull_mask (sample_bgr.shape, sample_landmarks, eyebrows_expand_mod=sample.eyebrows_expand_mod ) + return np.clip(full_face_mask, 0, 1) def get_eyes_mask(): eyes_mask = LandmarksProcessor.get_image_eye_mask (sample_bgr.shape, sample_landmarks) return np.clip(eyes_mask, 0, 1) - + + def get_eyes_mouth_mask(): + eyes_mask = LandmarksProcessor.get_image_eye_mask (sample_bgr.shape, sample_landmarks) + mouth_mask = LandmarksProcessor.get_image_mouth_mask (sample_bgr.shape, sample_landmarks) + mask = eyes_mask + mouth_mask + return np.clip(mask, 0, 1) + is_face_sample = sample_landmarks is not None if debug and is_face_sample: LandmarksProcessor.draw_landmarks (sample_bgr, sample_landmarks, (0, 1, 0)) - - params_per_resolution = {} - warp_rnd_state = np.random.RandomState (sample_rnd_seed-1) - for opts in output_sample_types: - resolution = opts.get('resolution', None) - if resolution is None: - continue - params_per_resolution[resolution] = imagelib.gen_warp_params(resolution, - sample_process_options.random_flip, - rotation_range=sample_process_options.rotation_range, - scale_range=sample_process_options.scale_range, - tx_range=sample_process_options.tx_range, - ty_range=sample_process_options.ty_range, - rnd_state=warp_rnd_state) outputs_sample = [] for opts in output_sample_types: + resolution = opts.get('resolution', 0) sample_type = opts.get('sample_type', SPST.NONE) channel_type = opts.get('channel_type', SPCT.NONE) - resolution = opts.get('resolution', 0) + nearest_resize_to = opts.get('nearest_resize_to', None) warp = opts.get('warp', False) transform = opts.get('transform', False) - motion_blur = opts.get('motion_blur', None) - gaussian_blur = opts.get('gaussian_blur', None) - random_bilinear_resize = opts.get('random_bilinear_resize', None) + random_hsv_shift_amount = opts.get('random_hsv_shift_amount', 0) normalize_tanh = opts.get('normalize_tanh', False) ct_mode = opts.get('ct_mode', None) data_format = opts.get('data_format', 'NHWC') + rnd_seed_shift = opts.get('rnd_seed_shift', 0) + warp_rnd_seed_shift = opts.get('warp_rnd_seed_shift', rnd_seed_shift) + + rnd_state = np.random.RandomState (sample_rnd_seed+rnd_seed_shift) + warp_rnd_state = np.random.RandomState (sample_rnd_seed+warp_rnd_seed_shift) + + warp_params = imagelib.gen_warp_params(resolution, + sample_process_options.random_flip, + 
rotation_range=sample_process_options.rotation_range, + scale_range=sample_process_options.scale_range, + tx_range=sample_process_options.tx_range, + ty_range=sample_process_options.ty_range, + rnd_state=rnd_state, + warp_rnd_state=warp_rnd_state, + ) + if sample_type == SPST.FACE_MASK or sample_type == SPST.IMAGE: border_replicate = False elif sample_type == SPST.FACE_IMAGE: @@ -124,40 +134,40 @@ class SampleProcessor(object): if face_type is None: raise ValueError("face_type must be defined for face samples") - if face_type > sample.face_type: - raise Exception ('sample %s type %s does not match model requirement %s. Consider extract necessary type of faces.' % (sample.filename, sample.face_type, face_type) ) - - - if sample_type == SPST.FACE_MASK: - + if sample_type == SPST.FACE_MASK: if face_mask_type == SPFMT.FULL_FACE: img = get_full_face_mask() elif face_mask_type == SPFMT.EYES: img = get_eyes_mask() - elif face_mask_type == SPFMT.FULL_FACE_EYES: - img = get_full_face_mask() + get_eyes_mask() + elif face_mask_type == SPFMT.EYES_MOUTH: + mask = get_full_face_mask().copy() + mask[mask != 0.0] = 1.0 + img = get_eyes_mouth_mask()*mask else: img = np.zeros ( sample_bgr.shape[0:2]+(1,), dtype=np.float32) - - if sample.ie_polys is not None: - sample.ie_polys.overlay_mask(img) if sample_face_type == FaceType.MARK_ONLY: + raise NotImplementedError() mat = LandmarksProcessor.get_transform_mat (sample_landmarks, warp_resolution, face_type) img = cv2.warpAffine( img, mat, (warp_resolution, warp_resolution), flags=cv2.INTER_LINEAR ) - img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate, cv2_inter=cv2.INTER_LINEAR) - img = cv2.resize( img, (resolution,resolution), cv2.INTER_LINEAR ) + img = imagelib.warp_by_params (warp_params, img, warp, transform, can_flip=True, border_replicate=border_replicate, cv2_inter=cv2.INTER_LINEAR) + img = cv2.resize( img, (resolution,resolution), interpolation=cv2.INTER_LINEAR ) else: if face_type != sample_face_type: mat = LandmarksProcessor.get_transform_mat (sample_landmarks, resolution, face_type) img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=borderMode, flags=cv2.INTER_LINEAR ) else: if w != resolution: - img = cv2.resize( img, (resolution, resolution), cv2.INTER_CUBIC ) + img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_LINEAR ) - img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate, cv2_inter=cv2.INTER_LINEAR) + img = imagelib.warp_by_params (warp_params, img, warp, transform, can_flip=True, border_replicate=border_replicate, cv2_inter=cv2.INTER_LINEAR) + if face_mask_type == SPFMT.EYES_MOUTH: + div = img.max() + if div != 0.0: + img = img / div # normalize to 1.0 after warp + if len(img.shape) == 2: img = img[...,None] @@ -168,91 +178,45 @@ class SampleProcessor(object): elif sample_type == SPST.FACE_IMAGE: img = sample_bgr - - + if face_type != sample_face_type: mat = LandmarksProcessor.get_transform_mat (sample_landmarks, resolution, face_type) img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=borderMode, flags=cv2.INTER_CUBIC ) else: if w != resolution: - img = cv2.resize( img, (resolution, resolution), cv2.INTER_CUBIC ) - - img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate) - - img = np.clip(img.astype(np.float32), 0, 1) - - + img = 
cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_CUBIC ) + # Apply random color transfer if ct_mode is not None and ct_sample is not None: if ct_sample_bgr is None: ct_sample_bgr = ct_sample.load_bgr() - img = imagelib.color_transfer (ct_mode, img, cv2.resize( ct_sample_bgr, (resolution,resolution), cv2.INTER_LINEAR ) ) + img = imagelib.color_transfer (ct_mode, img, cv2.resize( ct_sample_bgr, (resolution,resolution), interpolation=cv2.INTER_LINEAR ) ) + + if random_hsv_shift_amount != 0: + a = random_hsv_shift_amount + h_amount = max(1, int(360*a*0.5)) + img_h, img_s, img_v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV)) + img_h = (img_h + rnd_state.randint(-h_amount, h_amount+1) ) % 360 + img_s = np.clip (img_s + (rnd_state.random()-0.5)*a, 0, 1 ) + img_v = np.clip (img_v + (rnd_state.random()-0.5)*a, 0, 1 ) + img = np.clip( cv2.cvtColor(cv2.merge([img_h, img_s, img_v]), cv2.COLOR_HSV2BGR) , 0, 1 ) - if motion_blur is not None: - chance, mb_max_size = motion_blur - chance = np.clip(chance, 0, 100) + img = imagelib.warp_by_params (warp_params, img, warp, transform, can_flip=True, border_replicate=border_replicate) + + img = np.clip(img.astype(np.float32), 0, 1) - l_rnd_state = np.random.RandomState (sample_rnd_seed) - mblur_rnd_chance = l_rnd_state.randint(100) - mblur_rnd_kernel = l_rnd_state.randint(mb_max_size)+1 - mblur_rnd_deg = l_rnd_state.randint(360) - - if mblur_rnd_chance < chance: - img = imagelib.LinearMotionBlur (img, mblur_rnd_kernel, mblur_rnd_deg ) - - if gaussian_blur is not None: - chance, kernel_max_size = gaussian_blur - chance = np.clip(chance, 0, 100) - - l_rnd_state = np.random.RandomState (sample_rnd_seed+1) - gblur_rnd_chance = l_rnd_state.randint(100) - gblur_rnd_kernel = l_rnd_state.randint(kernel_max_size)*2+1 - - if gblur_rnd_chance < chance: - img = cv2.GaussianBlur(img, (gblur_rnd_kernel,) *2 , 0) - - if random_bilinear_resize is not None: - l_rnd_state = np.random.RandomState (sample_rnd_seed+2) - - chance, max_size_per = random_bilinear_resize - chance = np.clip(chance, 0, 100) - pick_chance = l_rnd_state.randint(100) - resize_to = resolution - int( l_rnd_state.rand()* int(resolution*(max_size_per/100.0)) ) - img = cv2.resize (img, (resize_to,resize_to), cv2.INTER_LINEAR ) - img = cv2.resize (img, (resolution,resolution), cv2.INTER_LINEAR ) - # Transform from BGR to desired channel_type if channel_type == SPCT.BGR: out_sample = img - elif channel_type == SPCT.BGR_SHUFFLE: - l_rnd_state = np.random.RandomState (sample_rnd_seed) - out_sample = np.take (img, l_rnd_state.permutation(img.shape[-1]), axis=-1) - elif channel_type == SPCT.BGR_RANDOM_HSV_SHIFT: - l_rnd_state = np.random.RandomState (sample_rnd_seed) - hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) - h, s, v = cv2.split(hsv) - h = (h + l_rnd_state.randint(360) ) % 360 - s = np.clip ( s + l_rnd_state.random()-0.5, 0, 1 ) - v = np.clip ( v + l_rnd_state.random()/2-0.25, 0, 1 ) - hsv = cv2.merge([h, s, v]) - out_sample = np.clip( cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) , 0, 1 ) - elif channel_type == SPCT.BGR_RANDOM_RGB_LEVELS: - l_rnd_state = np.random.RandomState (sample_rnd_seed) - np_rnd = l_rnd_state.rand - inBlack = np.array([np_rnd()*0.25 , np_rnd()*0.25 , np_rnd()*0.25], dtype=np.float32) - inWhite = np.array([1.0-np_rnd()*0.25, 1.0-np_rnd()*0.25, 1.0-np_rnd()*0.25], dtype=np.float32) - inGamma = np.array([0.5+np_rnd(), 0.5+np_rnd(), 0.5+np_rnd()], dtype=np.float32) - outBlack = np.array([0.0, 0.0, 0.0], dtype=np.float32) - outWhite = np.array([1.0, 1.0, 1.0], dtype=np.float32) - 
out_sample = np.clip( (img - inBlack) / (inWhite - inBlack), 0, 1 ) - out_sample = ( out_sample ** (1/inGamma) ) * (outWhite - outBlack) + outBlack - out_sample = np.clip(out_sample, 0, 1) elif channel_type == SPCT.G: out_sample = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[...,None] elif channel_type == SPCT.GGG: out_sample = np.repeat ( np.expand_dims(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY),-1), (3,), -1) # Final transformations + if nearest_resize_to is not None: + out_sample = cv2_resize(out_sample, (nearest_resize_to,nearest_resize_to), interpolation=cv2.INTER_NEAREST) + if not debug: if normalize_tanh: out_sample = np.clip (out_sample * 2.0 - 1.0, -1.0, 1.0) @@ -260,8 +224,8 @@ class SampleProcessor(object): out_sample = np.transpose(out_sample, (2,0,1) ) elif sample_type == SPST.IMAGE: img = sample_bgr - img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=True) - img = cv2.resize( img, (resolution, resolution), cv2.INTER_CUBIC ) + img = imagelib.warp_by_params (warp_params, img, warp, transform, can_flip=True, border_replicate=True) + img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_CUBIC ) out_sample = img if data_format == "NCHW": @@ -275,7 +239,7 @@ class SampleProcessor(object): out_sample = l elif sample_type == SPST.PITCH_YAW_ROLL or sample_type == SPST.PITCH_YAW_ROLL_SIGMOID: pitch,yaw,roll = sample.get_pitch_yaw_roll() - if params_per_resolution[resolution]['flip']: + if warp_params['flip']: yaw = -yaw if sample_type == SPST.PITCH_YAW_ROLL_SIGMOID: @@ -292,65 +256,3 @@ class SampleProcessor(object): return outputs -""" - - STRUCT = 4 #mask structure as grayscale - elif face_mask_type == SPFMT.STRUCT: - if sample.eyebrows_expand_mod is not None: - img = LandmarksProcessor.get_face_struct_mask (sample_bgr.shape, sample_landmarks, eyebrows_expand_mod=sample.eyebrows_expand_mod ) - else: - img = LandmarksProcessor.get_face_struct_mask (sample_bgr.shape, sample_landmarks) - - - - close_sample = sample.close_target_list[ np.random.randint(0, len(sample.close_target_list)) ] if sample.close_target_list is not None else None - close_sample_bgr = close_sample.load_bgr() if close_sample is not None else None - - if debug and close_sample_bgr is not None: - LandmarksProcessor.draw_landmarks (close_sample_bgr, close_sample.landmarks, (0, 1, 0)) - RANDOM_CLOSE = 0x00000040, #currently unused - MORPH_TO_RANDOM_CLOSE = 0x00000080, #currently unused - -if f & SPTF.RANDOM_CLOSE != 0: - img_type += 10 - elif f & SPTF.MORPH_TO_RANDOM_CLOSE != 0: - img_type += 20 -if img_type >= 10 and img_type <= 19: #RANDOM_CLOSE - img_type -= 10 - img = close_sample_bgr - cur_sample = close_sample - -elif img_type >= 20 and img_type <= 29: #MORPH_TO_RANDOM_CLOSE - img_type -= 20 - res = sample.shape[0] - - s_landmarks = sample.landmarks.copy() - d_landmarks = close_sample.landmarks.copy() - idxs = list(range(len(s_landmarks))) - #remove landmarks near boundaries - for i in idxs[:]: - s_l = s_landmarks[i] - d_l = d_landmarks[i] - if s_l[0] < 5 or s_l[1] < 5 or s_l[0] >= res-5 or s_l[1] >= res-5 or \ - d_l[0] < 5 or d_l[1] < 5 or d_l[0] >= res-5 or d_l[1] >= res-5: - idxs.remove(i) - #remove landmarks that close to each other in 5 dist - for landmarks in [s_landmarks, d_landmarks]: - for i in idxs[:]: - s_l = landmarks[i] - for j in idxs[:]: - if i == j: - continue - s_l_2 = landmarks[j] - diff_l = np.abs(s_l - s_l_2) - if np.sqrt(diff_l.dot(diff_l)) < 5: - idxs.remove(i) - break - s_landmarks = s_landmarks[idxs] - d_landmarks 
= d_landmarks[idxs] - s_landmarks = np.concatenate ( [s_landmarks, [ [0,0], [ res // 2, 0], [ res-1, 0], [0, res//2], [res-1, res//2] ,[0,res-1] ,[res//2, res-1] ,[res-1,res-1] ] ] ) - d_landmarks = np.concatenate ( [d_landmarks, [ [0,0], [ res // 2, 0], [ res-1, 0], [0, res//2], [res-1, res//2] ,[0,res-1] ,[res//2, res-1] ,[res-1,res-1] ] ] ) - img = imagelib.morph_by_points (sample_bgr, s_landmarks, d_landmarks) - cur_sample = close_sample -else: - """
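Usage sketch (not part of the patch): with the shared params_per_resolution table replaced by per-output warp_params, every entry in output_sample_types now carries its own resolution, seed shifts and augmentation options. A minimal example of what a caller might pass is below; the option keys come from this patch, while the SPST alias, the face type and the numeric values are illustrative assumptions (other new keys such as nearest_resize_to and warp_rnd_seed_shift are available but omitted here).

from facelib import FaceType
from samplelib import SampleProcessor

SPST = SampleProcessor.SampleType    # assumed alias, mirroring SPCT/SPFMT in this file
SPCT = SampleProcessor.ChannelType
SPFMT = SampleProcessor.FaceMaskType

output_sample_types = [
    # warped input and unwarped target share the same per-sample seed, so the
    # flip and transform stay consistent within the pair
    {'sample_type': SPST.FACE_IMAGE, 'warp': True,  'transform': True,
     'channel_type': SPCT.BGR, 'face_type': FaceType.WHOLE_FACE, 'resolution': 256},
    {'sample_type': SPST.FACE_IMAGE, 'warp': False, 'transform': True,
     'channel_type': SPCT.BGR, 'face_type': FaceType.WHOLE_FACE, 'resolution': 256},
    # an extra view: rnd_seed_shift decorrelates its augmentation stream from the
    # first pair, random_hsv_shift_amount enables the new bounded HSV jitter
    {'sample_type': SPST.FACE_IMAGE, 'warp': True, 'transform': True,
     'channel_type': SPCT.BGR, 'rnd_seed_shift': 1, 'random_hsv_shift_amount': 0.1,
     'face_type': FaceType.WHOLE_FACE, 'resolution': 256},
    # mask target drawn from the XSeg mask or the landmark hull
    {'sample_type': SPST.FACE_MASK, 'warp': False, 'transform': True,
     'channel_type': SPCT.G, 'face_mask_type': SPFMT.FULL_FACE,
     'face_type': FaceType.WHOLE_FACE, 'resolution': 256},
]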
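The random_hsv_shift_amount path takes over the role of the removed BGR_RANDOM_HSV_SHIFT channel type and runs before warping, on the float32 BGR image in [0,1]. A standalone sketch of the same jitter follows; the function name and the 0.1 amount are illustrative assumptions, not names from the codebase.

import cv2
import numpy as np

def random_hsv_shift(img, amount, rnd_state=None):
    # img: float32 BGR image in [0,1]; for float input OpenCV reports hue in degrees [0,360)
    if rnd_state is None:
        rnd_state = np.random.RandomState()
    h_amount = max(1, int(360 * amount * 0.5))
    h, s, v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    h = (h + rnd_state.randint(-h_amount, h_amount + 1)) % 360   # bounded hue rotation
    s = np.clip(s + (rnd_state.uniform() - 0.5) * amount, 0, 1)  # small saturation shift
    v = np.clip(v + (rnd_state.uniform() - 0.5) * amount, 0, 1)  # small value shift
    return np.clip(cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR), 0, 1)

# e.g. jittered = random_hsv_shift(sample_bgr, 0.1, np.random.RandomState(sample_rnd_seed))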