mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 06:23:20 -07:00
Merge branch 'master' into master
This commit is contained in:
commit
f62fff7f0f
51 changed files with 2279 additions and 1287 deletions
15
DFLIMG/DFLIMG.py
Normal file
15
DFLIMG/DFLIMG.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
from pathlib import Path
|
||||
|
||||
from .DFLJPG import DFLJPG
|
||||
from .DFLPNG import DFLPNG
|
||||
|
||||
class DFLIMG():
|
||||
|
||||
@staticmethod
|
||||
def load(filepath, loader_func=None):
|
||||
if filepath.suffix == '.png':
|
||||
return DFLPNG.load( str(filepath), loader_func=loader_func )
|
||||
elif filepath.suffix == '.jpg':
|
||||
return DFLJPG.load ( str(filepath), loader_func=loader_func )
|
||||
else:
|
||||
return None
|
|
@ -5,7 +5,6 @@ import cv2
|
|||
import numpy as np
|
||||
|
||||
from facelib import FaceType
|
||||
from imagelib import IEPolys
|
||||
from utils.struct_utils import *
|
||||
from interact import interact as io
|
||||
|
||||
|
@ -18,10 +17,13 @@ class DFLJPG(object):
|
|||
self.shape = (0,0,0)
|
||||
|
||||
@staticmethod
|
||||
def load_raw(filename):
|
||||
def load_raw(filename, loader_func=None):
|
||||
try:
|
||||
with open(filename, "rb") as f:
|
||||
data = f.read()
|
||||
if loader_func is not None:
|
||||
data = loader_func(filename)
|
||||
else:
|
||||
with open(filename, "rb") as f:
|
||||
data = f.read()
|
||||
except:
|
||||
raise FileNotFoundError(filename)
|
||||
|
||||
|
@ -116,9 +118,9 @@ class DFLJPG(object):
|
|||
raise Exception ("Corrupted JPG file: %s" % (str(e)))
|
||||
|
||||
@staticmethod
|
||||
def load(filename):
|
||||
def load(filename, loader_func=None):
|
||||
try:
|
||||
inst = DFLJPG.load_raw (filename)
|
||||
inst = DFLJPG.load_raw (filename, loader_func=loader_func)
|
||||
inst.dfl_dict = None
|
||||
|
||||
for chunk in inst.chunks:
|
||||
|
@ -159,6 +161,17 @@ class DFLJPG(object):
|
|||
print (e)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def embed_dfldict(filename, dfl_dict):
|
||||
inst = DFLJPG.load_raw (filename)
|
||||
inst.setDFLDictData (dfl_dict)
|
||||
|
||||
try:
|
||||
with open(filename, "wb") as f:
|
||||
f.write ( inst.dump() )
|
||||
except:
|
||||
raise Exception( 'cannot save %s' % (filename) )
|
||||
|
||||
@staticmethod
|
||||
def embed_data(filename, face_type=None,
|
||||
landmarks=None,
|
||||
|
@ -168,7 +181,6 @@ class DFLJPG(object):
|
|||
source_landmarks=None,
|
||||
image_to_face_mat=None,
|
||||
fanseg_mask=None,
|
||||
pitch_yaw_roll=None,
|
||||
eyebrows_expand_mod=None,
|
||||
relighted=None,
|
||||
**kwargs
|
||||
|
@ -185,26 +197,17 @@ class DFLJPG(object):
|
|||
io.log_err("Unable to encode fanseg_mask for %s" % (filename) )
|
||||
fanseg_mask = None
|
||||
|
||||
inst = DFLJPG.load_raw (filename)
|
||||
inst.setDFLDictData ({
|
||||
'face_type': face_type,
|
||||
'landmarks': landmarks,
|
||||
'ie_polys' : ie_polys.dump() if ie_polys is not None else None,
|
||||
'source_filename': source_filename,
|
||||
'source_rect': source_rect,
|
||||
'source_landmarks': source_landmarks,
|
||||
'image_to_face_mat': image_to_face_mat,
|
||||
'fanseg_mask' : fanseg_mask,
|
||||
'pitch_yaw_roll' : pitch_yaw_roll,
|
||||
'eyebrows_expand_mod' : eyebrows_expand_mod,
|
||||
'relighted' : relighted
|
||||
})
|
||||
|
||||
try:
|
||||
with open(filename, "wb") as f:
|
||||
f.write ( inst.dump() )
|
||||
except:
|
||||
raise Exception( 'cannot save %s' % (filename) )
|
||||
DFLJPG.embed_dfldict (filename, {'face_type': face_type,
|
||||
'landmarks': landmarks,
|
||||
'ie_polys' : ie_polys.dump() if ie_polys is not None else None,
|
||||
'source_filename': source_filename,
|
||||
'source_rect': source_rect,
|
||||
'source_landmarks': source_landmarks,
|
||||
'image_to_face_mat': image_to_face_mat,
|
||||
'fanseg_mask' : fanseg_mask,
|
||||
'eyebrows_expand_mod' : eyebrows_expand_mod,
|
||||
'relighted' : relighted
|
||||
})
|
||||
|
||||
def embed_and_set(self, filename, face_type=None,
|
||||
landmarks=None,
|
||||
|
@ -214,7 +217,6 @@ class DFLJPG(object):
|
|||
source_landmarks=None,
|
||||
image_to_face_mat=None,
|
||||
fanseg_mask=None,
|
||||
pitch_yaw_roll=None,
|
||||
eyebrows_expand_mod=None,
|
||||
relighted=None,
|
||||
**kwargs
|
||||
|
@ -227,7 +229,6 @@ class DFLJPG(object):
|
|||
if source_landmarks is None: source_landmarks = self.get_source_landmarks()
|
||||
if image_to_face_mat is None: image_to_face_mat = self.get_image_to_face_mat()
|
||||
if fanseg_mask is None: fanseg_mask = self.get_fanseg_mask()
|
||||
if pitch_yaw_roll is None: pitch_yaw_roll = self.get_pitch_yaw_roll()
|
||||
if eyebrows_expand_mod is None: eyebrows_expand_mod = self.get_eyebrows_expand_mod()
|
||||
if relighted is None: relighted = self.get_relighted()
|
||||
DFLJPG.embed_data (filename, face_type=face_type,
|
||||
|
@ -238,7 +239,6 @@ class DFLJPG(object):
|
|||
source_landmarks=source_landmarks,
|
||||
image_to_face_mat=image_to_face_mat,
|
||||
fanseg_mask=fanseg_mask,
|
||||
pitch_yaw_roll=pitch_yaw_roll,
|
||||
relighted=relighted)
|
||||
|
||||
def remove_ie_polys(self):
|
||||
|
@ -300,7 +300,7 @@ class DFLJPG(object):
|
|||
|
||||
def get_face_type(self): return self.dfl_dict['face_type']
|
||||
def get_landmarks(self): return np.array ( self.dfl_dict['landmarks'] )
|
||||
def get_ie_polys(self): return IEPolys.load(self.dfl_dict.get('ie_polys',None))
|
||||
def get_ie_polys(self): return self.dfl_dict.get('ie_polys',None)
|
||||
def get_source_filename(self): return self.dfl_dict['source_filename']
|
||||
def get_source_rect(self): return self.dfl_dict['source_rect']
|
||||
def get_source_landmarks(self): return np.array ( self.dfl_dict['source_landmarks'] )
|
||||
|
@ -314,8 +314,6 @@ class DFLJPG(object):
|
|||
if fanseg_mask is not None:
|
||||
return np.clip ( np.array (fanseg_mask) / 255.0, 0.0, 1.0 )[...,np.newaxis]
|
||||
return None
|
||||
def get_pitch_yaw_roll(self):
|
||||
return self.dfl_dict.get ('pitch_yaw_roll', None)
|
||||
def get_eyebrows_expand_mod(self):
|
||||
return self.dfl_dict.get ('eyebrows_expand_mod', None)
|
||||
def get_relighted(self):
|
|
@ -7,7 +7,6 @@ import cv2
|
|||
import numpy as np
|
||||
|
||||
from facelib import FaceType
|
||||
from imagelib import IEPolys
|
||||
|
||||
PNG_HEADER = b"\x89PNG\r\n\x1a\n"
|
||||
|
||||
|
@ -225,10 +224,13 @@ class DFLPNG(object):
|
|||
self.dfl_dict = None
|
||||
|
||||
@staticmethod
|
||||
def load_raw(filename):
|
||||
def load_raw(filename, loader_func=None):
|
||||
try:
|
||||
with open(filename, "rb") as f:
|
||||
data = f.read()
|
||||
if loader_func is not None:
|
||||
data = loader_func(filename)
|
||||
else:
|
||||
with open(filename, "rb") as f:
|
||||
data = f.read()
|
||||
except:
|
||||
raise FileNotFoundError(filename)
|
||||
|
||||
|
@ -252,9 +254,9 @@ class DFLPNG(object):
|
|||
return inst
|
||||
|
||||
@staticmethod
|
||||
def load(filename):
|
||||
def load(filename, loader_func=None):
|
||||
try:
|
||||
inst = DFLPNG.load_raw (filename)
|
||||
inst = DFLPNG.load_raw (filename, loader_func=loader_func)
|
||||
inst.dfl_dict = inst.getDFLDictData()
|
||||
|
||||
if inst.dfl_dict is not None:
|
||||
|
@ -275,6 +277,17 @@ class DFLPNG(object):
|
|||
print(e)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def embed_dfldict(filename, dfl_dict):
|
||||
inst = DFLPNG.load_raw (filename)
|
||||
inst.setDFLDictData (dfl_dict)
|
||||
|
||||
try:
|
||||
with open(filename, "wb") as f:
|
||||
f.write ( inst.dump() )
|
||||
except:
|
||||
raise Exception( 'cannot save %s' % (filename) )
|
||||
|
||||
@staticmethod
|
||||
def embed_data(filename, face_type=None,
|
||||
landmarks=None,
|
||||
|
@ -284,7 +297,6 @@ class DFLPNG(object):
|
|||
source_landmarks=None,
|
||||
image_to_face_mat=None,
|
||||
fanseg_mask=None,
|
||||
pitch_yaw_roll=None,
|
||||
eyebrows_expand_mod=None,
|
||||
relighted=None,
|
||||
**kwargs
|
||||
|
@ -301,26 +313,17 @@ class DFLPNG(object):
|
|||
io.log_err("Unable to encode fanseg_mask for %s" % (filename) )
|
||||
fanseg_mask = None
|
||||
|
||||
inst = DFLPNG.load_raw (filename)
|
||||
inst.setDFLDictData ({
|
||||
'face_type': face_type,
|
||||
'landmarks': landmarks,
|
||||
'ie_polys' : ie_polys.dump() if ie_polys is not None else None,
|
||||
'source_filename': source_filename,
|
||||
'source_rect': source_rect,
|
||||
'source_landmarks': source_landmarks,
|
||||
'image_to_face_mat':image_to_face_mat,
|
||||
'fanseg_mask' : fanseg_mask,
|
||||
'pitch_yaw_roll' : pitch_yaw_roll,
|
||||
'eyebrows_expand_mod' : eyebrows_expand_mod,
|
||||
'relighted' : relighted
|
||||
})
|
||||
|
||||
try:
|
||||
with open(filename, "wb") as f:
|
||||
f.write ( inst.dump() )
|
||||
except:
|
||||
raise Exception( 'cannot save %s' % (filename) )
|
||||
DFLPNG.embed_dfldict (filename, {'face_type': face_type,
|
||||
'landmarks': landmarks,
|
||||
'ie_polys' : ie_polys.dump() if ie_polys is not None else None,
|
||||
'source_filename': source_filename,
|
||||
'source_rect': source_rect,
|
||||
'source_landmarks': source_landmarks,
|
||||
'image_to_face_mat':image_to_face_mat,
|
||||
'fanseg_mask' : fanseg_mask,
|
||||
'eyebrows_expand_mod' : eyebrows_expand_mod,
|
||||
'relighted' : relighted
|
||||
})
|
||||
|
||||
def embed_and_set(self, filename, face_type=None,
|
||||
landmarks=None,
|
||||
|
@ -330,7 +333,6 @@ class DFLPNG(object):
|
|||
source_landmarks=None,
|
||||
image_to_face_mat=None,
|
||||
fanseg_mask=None,
|
||||
pitch_yaw_roll=None,
|
||||
eyebrows_expand_mod=None,
|
||||
relighted=None,
|
||||
**kwargs
|
||||
|
@ -343,7 +345,6 @@ class DFLPNG(object):
|
|||
if source_landmarks is None: source_landmarks = self.get_source_landmarks()
|
||||
if image_to_face_mat is None: image_to_face_mat = self.get_image_to_face_mat()
|
||||
if fanseg_mask is None: fanseg_mask = self.get_fanseg_mask()
|
||||
if pitch_yaw_roll is None: pitch_yaw_roll = self.get_pitch_yaw_roll()
|
||||
if eyebrows_expand_mod is None: eyebrows_expand_mod = self.get_eyebrows_expand_mod()
|
||||
if relighted is None: relighted = self.get_relighted()
|
||||
|
||||
|
@ -355,7 +356,6 @@ class DFLPNG(object):
|
|||
source_landmarks=source_landmarks,
|
||||
image_to_face_mat=image_to_face_mat,
|
||||
fanseg_mask=fanseg_mask,
|
||||
pitch_yaw_roll=pitch_yaw_roll,
|
||||
eyebrows_expand_mod=eyebrows_expand_mod,
|
||||
relighted=relighted)
|
||||
|
||||
|
@ -407,7 +407,7 @@ class DFLPNG(object):
|
|||
|
||||
def get_face_type(self): return self.dfl_dict['face_type']
|
||||
def get_landmarks(self): return np.array ( self.dfl_dict['landmarks'] )
|
||||
def get_ie_polys(self): return IEPolys.load(self.dfl_dict.get('ie_polys',None))
|
||||
def get_ie_polys(self): return self.dfl_dict.get('ie_polys',None)
|
||||
def get_source_filename(self): return self.dfl_dict['source_filename']
|
||||
def get_source_rect(self): return self.dfl_dict['source_rect']
|
||||
def get_source_landmarks(self): return np.array ( self.dfl_dict['source_landmarks'] )
|
||||
|
@ -421,8 +421,6 @@ class DFLPNG(object):
|
|||
if fanseg_mask is not None:
|
||||
return np.clip ( np.array (fanseg_mask) / 255.0, 0.0, 1.0 )[...,np.newaxis]
|
||||
return None
|
||||
def get_pitch_yaw_roll(self):
|
||||
return self.dfl_dict.get ('pitch_yaw_roll', None)
|
||||
def get_eyebrows_expand_mod(self):
|
||||
return self.dfl_dict.get ('eyebrows_expand_mod', None)
|
||||
def get_relighted(self):
|
3
DFLIMG/__init__.py
Normal file
3
DFLIMG/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
from .DFLIMG import DFLIMG
|
||||
from .DFLJPG import DFLJPG
|
||||
from .DFLPNG import DFLPNG
|
42
README.md
42
README.md
|
@ -6,6 +6,28 @@
|
|||
|
||||
## **DeepFaceLab** is a tool that utilizes machine learning to replace faces in videos.
|
||||
|
||||
- ### [Gallery](doc/gallery/doc_gallery.md)
|
||||
|
||||
- ### Manuals:
|
||||
|
||||
[English (google translated)](doc/manual_en_google_translated.pdf)
|
||||
|
||||
[На русском](doc/manual_ru.pdf)
|
||||
|
||||
- ### [Windows Desktop App](doc/doc_windows_desktop_app.md)
|
||||
|
||||
- ### Forks
|
||||
|
||||
[Google Colab fork](https://github.com/chervonij/DFL-Colab) by @chervonij
|
||||
|
||||
[Linux fork](https://github.com/lbfs/DeepFaceLab_Linux) by @lbfs - may be outdated
|
||||
|
||||
- ### [Ready to work facesets](doc/doc_ready_to_work_facesets.md)
|
||||
|
||||
- ### [Build and repository info](doc/doc_build_and_repository_info.md)
|
||||
|
||||
- ### How I can help the project?
|
||||
|
||||
If you like this software, please consider a donation.
|
||||
|
||||
GOAL: next DeepFacelab update.
|
||||
|
@ -16,25 +38,11 @@ GOAL: next DeepFacelab update.
|
|||
|
||||
bitcoin:31mPd6DxPCzbpCMZk4k1koWAbErSyqkAXr
|
||||
|
||||
- ### [Gallery](doc/gallery/doc_gallery.md)
|
||||

|
||||
|
||||
- ### Manuals:
|
||||
You can collect faceset of any celebrities that can be used in DeepFaceLab (described in manual)
|
||||
|
||||
[English (google translated)](doc/manual_en_google_translated.pdf)
|
||||
|
||||
[На русском](doc/manual_ru.pdf)
|
||||
|
||||
- ### [Prebuilt windows app](doc/doc_prebuilt_windows_app.md)
|
||||
|
||||
- ### Forks
|
||||
|
||||
[Google Colab fork](https://github.com/chervonij/DFL-Colab) by @chervonij
|
||||
|
||||
[Linux fork](https://github.com/lbfs/DeepFaceLab_Linux) by @lbfs - may be outdated
|
||||
|
||||
- ### [Ready to work facesets](doc/doc_ready_to_work_facesets.md)
|
||||
|
||||
- ### [Build and repository info](doc/doc_build_and_repository_info.md)
|
||||
and share it here [mrdeepfakes celebrity-facesets](https://mrdeepfakes.com/forums/forum-celebrity-facesets)
|
||||
|
||||
- ### Communication groups:
|
||||
|
||||
|
|
|
@ -28,7 +28,10 @@ def ConvertMaskedFace (predictor_func, predictor_input_shape, cfg, frame_info, i
|
|||
face_output_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, output_size, face_type=cfg.face_type, scale= 1.0 + 0.01*cfg.output_face_scale )
|
||||
|
||||
dst_face_bgr = cv2.warpAffine( img_bgr , face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC )
|
||||
dst_face_bgr = np.clip(dst_face_bgr, 0, 1)
|
||||
|
||||
dst_face_mask_a_0 = cv2.warpAffine( img_face_mask_a, face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC )
|
||||
dst_face_mask_a_0 = np.clip(dst_face_mask_a_0, 0, 1)
|
||||
|
||||
predictor_input_bgr = cv2.resize (dst_face_bgr, predictor_input_shape[0:2] )
|
||||
|
||||
|
@ -46,6 +49,7 @@ def ConvertMaskedFace (predictor_func, predictor_input_shape, cfg, frame_info, i
|
|||
|
||||
if cfg.super_resolution_mode:
|
||||
prd_face_bgr = cfg.superres_func(cfg.super_resolution_mode, prd_face_bgr)
|
||||
prd_face_bgr = np.clip(prd_face_bgr, 0, 1)
|
||||
|
||||
if predictor_masked:
|
||||
prd_face_mask_a_0 = cv2.resize (prd_face_mask_a_0, (output_size, output_size), cv2.INTER_CUBIC)
|
||||
|
|
|
@ -10,13 +10,9 @@ if the download qouta is exceeded, add the file to your own google drive and dow
|
|||
|
||||
Available builds:
|
||||
|
||||
* DeepFaceLabCUDA9.2SSE - for NVIDIA cards and any 64-bit CPU
|
||||
* DeepFaceLab_CUDA - for NVIDIA cards
|
||||
|
||||
* DeepFaceLabCUDA10.1AVX - for NVIDIA cards and CPU with AVX instructions support
|
||||
|
||||
* DeepFaceLabOpenCLSSE - for AMD/IntelHD cards and any 64-bit CPU
|
||||
|
||||
If your card does not work with CUDA 10.1 version, try CUDA 9.2.
|
||||
* DeepFaceLab_OpenCL - for NVIDIA/AMD/IntelHD cards
|
||||
|
||||
Important: you don't need to install CUDA !
|
||||
|
BIN
doc/example_faceset.jpg
Normal file
BIN
doc/example_faceset.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 25 KiB |
BIN
doc/landmarks_98.jpg
Normal file
BIN
doc/landmarks_98.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 215 KiB |
|
@ -83,7 +83,7 @@ class FANExtractor(object):
|
|||
for i, lmrks in enumerate(landmarks):
|
||||
try:
|
||||
if lmrks is not None:
|
||||
image_to_face_mat = LandmarksProcessor.get_transform_mat (lmrks, 256, FaceType.FULL)
|
||||
image_to_face_mat = LandmarksProcessor.get_transform_mat (lmrks, 256, FaceType.FULL, full_face_align_top=False)
|
||||
face_image = cv2.warpAffine(input_image, image_to_face_mat, (256, 256), cv2.INTER_CUBIC )
|
||||
|
||||
rects2 = second_pass_extractor.extract(face_image, is_bgr=is_bgr)
|
||||
|
|
BIN
facelib/FaceEnhancer.h5
Normal file
BIN
facelib/FaceEnhancer.h5
Normal file
Binary file not shown.
154
facelib/FaceEnhancer.py
Normal file
154
facelib/FaceEnhancer.py
Normal file
|
@ -0,0 +1,154 @@
|
|||
import operator
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
class FaceEnhancer(object):
|
||||
"""
|
||||
x4 face enhancer
|
||||
"""
|
||||
def __init__(self):
|
||||
from nnlib import nnlib
|
||||
exec( nnlib.import_all(), locals(), globals() )
|
||||
|
||||
model_path = Path(__file__).parent / "FaceEnhancer.h5"
|
||||
if not model_path.exists():
|
||||
return
|
||||
|
||||
bgr_inp = Input ( (192,192,3) )
|
||||
t_param_inp = Input ( (1,) )
|
||||
t_param1_inp = Input ( (1,) )
|
||||
x = Conv2D (64, 3, strides=1, padding='same' )(bgr_inp)
|
||||
|
||||
a = Dense (64, use_bias=False) ( t_param_inp )
|
||||
a = Reshape( (1,1,64) )(a)
|
||||
b = Dense (64, use_bias=False ) ( t_param1_inp )
|
||||
b = Reshape( (1,1,64) )(b)
|
||||
x = Add()([x,a,b])
|
||||
|
||||
x = LeakyReLU(0.1)(x)
|
||||
|
||||
x = LeakyReLU(0.1)(Conv2D (64, 3, strides=1, padding='same' )(x))
|
||||
x = e0 = LeakyReLU(0.1)(Conv2D (64, 3, strides=1, padding='same')(x))
|
||||
|
||||
x = AveragePooling2D()(x)
|
||||
x = LeakyReLU(0.1)(Conv2D (112, 3, strides=1, padding='same')(x))
|
||||
x = e1 = LeakyReLU(0.1)(Conv2D (112, 3, strides=1, padding='same')(x))
|
||||
|
||||
x = AveragePooling2D()(x)
|
||||
x = LeakyReLU(0.1)(Conv2D (192, 3, strides=1, padding='same')(x))
|
||||
x = e2 = LeakyReLU(0.1)(Conv2D (192, 3, strides=1, padding='same')(x))
|
||||
|
||||
x = AveragePooling2D()(x)
|
||||
x = LeakyReLU(0.1)(Conv2D (336, 3, strides=1, padding='same')(x))
|
||||
x = e3 = LeakyReLU(0.1)(Conv2D (336, 3, strides=1, padding='same')(x))
|
||||
|
||||
x = AveragePooling2D()(x)
|
||||
x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x))
|
||||
x = e4 = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x))
|
||||
|
||||
x = AveragePooling2D()(x)
|
||||
x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x))
|
||||
x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x))
|
||||
x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x))
|
||||
x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x))
|
||||
|
||||
x = Concatenate()([ BilinearInterpolation()(x), e4 ])
|
||||
|
||||
x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x))
|
||||
x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x))
|
||||
|
||||
x = Concatenate()([ BilinearInterpolation()(x), e3 ])
|
||||
x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x))
|
||||
x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x))
|
||||
|
||||
x = Concatenate()([ BilinearInterpolation()(x), e2 ])
|
||||
x = LeakyReLU(0.1)(Conv2D (288, 3, strides=1, padding='same')(x))
|
||||
x = LeakyReLU(0.1)(Conv2D (288, 3, strides=1, padding='same')(x))
|
||||
|
||||
x = Concatenate()([ BilinearInterpolation()(x), e1 ])
|
||||
x = LeakyReLU(0.1)(Conv2D (160, 3, strides=1, padding='same')(x))
|
||||
x = LeakyReLU(0.1)(Conv2D (160, 3, strides=1, padding='same')(x))
|
||||
|
||||
x = Concatenate()([ BilinearInterpolation()(x), e0 ])
|
||||
x = LeakyReLU(0.1)(Conv2D (96, 3, strides=1, padding='same')(x))
|
||||
x = d0 = LeakyReLU(0.1)(Conv2D (96, 3, strides=1, padding='same')(x))
|
||||
|
||||
x = LeakyReLU(0.1)(Conv2D (48, 3, strides=1, padding='same')(x))
|
||||
|
||||
x = Conv2D (3, 3, strides=1, padding='same', activation='tanh')(x)
|
||||
out1x = Add()([bgr_inp, x])
|
||||
|
||||
x = d0
|
||||
x = LeakyReLU(0.1)(Conv2D (96, 3, strides=1, padding='same')(x))
|
||||
x = LeakyReLU(0.1)(Conv2D (96, 3, strides=1, padding='same')(x))
|
||||
x = d2x = BilinearInterpolation()(x)
|
||||
|
||||
x = LeakyReLU(0.1)(Conv2D (48, 3, strides=1, padding='same')(x))
|
||||
x = Conv2D (3, 3, strides=1, padding='same', activation='tanh')(x)
|
||||
|
||||
out2x = Add()([BilinearInterpolation()(out1x), x])
|
||||
|
||||
x = d2x
|
||||
x = LeakyReLU(0.1)(Conv2D (72, 3, strides=1, padding='same')(x))
|
||||
x = LeakyReLU(0.1)(Conv2D (72, 3, strides=1, padding='same')(x))
|
||||
x = d4x = BilinearInterpolation()(x)
|
||||
|
||||
x = LeakyReLU(0.1)(Conv2D (36, 3, strides=1, padding='same')(x))
|
||||
x = Conv2D (3, 3, strides=1, padding='same', activation='tanh')(x)
|
||||
out4x = Add()([BilinearInterpolation()(out2x), x ])
|
||||
|
||||
self.model = keras.models.Model ( [bgr_inp,t_param_inp,t_param1_inp], [out4x] )
|
||||
self.model.load_weights (str(model_path))
|
||||
|
||||
|
||||
def enhance (self, inp_img, is_tanh=False, preserve_size=True):
|
||||
if not is_tanh:
|
||||
inp_img = np.clip( inp_img * 2 -1, -1, 1 )
|
||||
|
||||
param = np.array([0.2])
|
||||
param1 = np.array([1.0])
|
||||
up_res = 4
|
||||
patch_size = 192
|
||||
patch_size_half = patch_size // 2
|
||||
|
||||
h,w,c = inp_img.shape
|
||||
|
||||
i_max = w-patch_size+1
|
||||
j_max = h-patch_size+1
|
||||
|
||||
final_img = np.zeros ( (h*up_res,w*up_res,c), dtype=np.float32 )
|
||||
final_img_div = np.zeros ( (h*up_res,w*up_res,1), dtype=np.float32 )
|
||||
|
||||
x = np.concatenate ( [ np.linspace (0,1,patch_size_half*up_res), np.linspace (1,0,patch_size_half*up_res) ] )
|
||||
x,y = np.meshgrid(x,x)
|
||||
patch_mask = (x*y)[...,None]
|
||||
|
||||
j=0
|
||||
while j < j_max:
|
||||
i = 0
|
||||
while i < i_max:
|
||||
patch_img = inp_img[j:j+patch_size, i:i+patch_size,:]
|
||||
x = self.model.predict( [ patch_img[None,...], param, param1 ] )[0]
|
||||
final_img [j*up_res:(j+patch_size)*up_res, i*up_res:(i+patch_size)*up_res,:] += x*patch_mask
|
||||
final_img_div[j*up_res:(j+patch_size)*up_res, i*up_res:(i+patch_size)*up_res,:] += patch_mask
|
||||
if i == i_max-1:
|
||||
break
|
||||
i = min( i+patch_size_half, i_max-1)
|
||||
if j == j_max-1:
|
||||
break
|
||||
j = min( j+patch_size_half, j_max-1)
|
||||
|
||||
final_img_div[final_img_div==0] = 1.0
|
||||
final_img /= final_img_div
|
||||
|
||||
if preserve_size:
|
||||
final_img = cv2.resize (final_img, (w,h), cv2.INTER_LANCZOS4)
|
||||
|
||||
if not is_tanh:
|
||||
final_img = np.clip( final_img/2+0.5, 0, 1 )
|
||||
|
||||
return final_img
|
|
@ -183,6 +183,15 @@ landmarks_68_3D = np.array( [
|
|||
[0.205322 , 31.408738 , -21.903670 ],
|
||||
[-7.198266 , 30.844876 , -20.328022 ] ], dtype=np.float32)
|
||||
|
||||
FaceType_to_padding_remove_align = {
|
||||
FaceType.HALF: (0.0, False),
|
||||
FaceType.MID_FULL: (0.0675, False),
|
||||
FaceType.FULL: (0.2109375, False),
|
||||
FaceType.FULL_NO_ALIGN: (0.2109375, True),
|
||||
FaceType.HEAD: (0.369140625, False),
|
||||
FaceType.HEAD_NO_ALIGN: (0.369140625, True),
|
||||
}
|
||||
|
||||
def convert_98_to_68(lmrks):
|
||||
#jaw
|
||||
result = [ lmrks[0] ]
|
||||
|
@ -240,66 +249,62 @@ def transform_points(points, mat, invert=False):
|
|||
points = np.squeeze(points)
|
||||
return points
|
||||
|
||||
def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0):
|
||||
def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0, full_face_align_top=True):
|
||||
if not isinstance(image_landmarks, np.ndarray):
|
||||
image_landmarks = np.array (image_landmarks)
|
||||
|
||||
"""
|
||||
if face_type == FaceType.AVATAR:
|
||||
centroid = np.mean (image_landmarks, axis=0)
|
||||
padding, remove_align = FaceType_to_padding_remove_align.get(face_type, 0.0)
|
||||
|
||||
mat = umeyama(image_landmarks[17:], landmarks_2D, True)[0:2]
|
||||
a, c = mat[0,0], mat[1,0]
|
||||
scale = math.sqrt((a * a) + (c * c))
|
||||
|
||||
padding = (output_size / 64) * 32
|
||||
|
||||
mat = np.eye ( 2,3 )
|
||||
mat[0,2] = -centroid[0]
|
||||
mat[1,2] = -centroid[1]
|
||||
mat = mat * scale * (output_size / 3)
|
||||
mat[:,2] += output_size / 2
|
||||
else:
|
||||
"""
|
||||
remove_align = False
|
||||
if face_type == FaceType.FULL_NO_ALIGN:
|
||||
face_type = FaceType.FULL
|
||||
remove_align = True
|
||||
elif face_type == FaceType.HEAD_NO_ALIGN:
|
||||
face_type = FaceType.HEAD
|
||||
remove_align = True
|
||||
|
||||
if face_type == FaceType.HALF:
|
||||
padding = 0
|
||||
elif face_type == FaceType.MID_FULL:
|
||||
padding = int(output_size * 0.06)
|
||||
elif face_type == FaceType.FULL:
|
||||
padding = (output_size / 64) * 12
|
||||
elif face_type == FaceType.HEAD:
|
||||
padding = (output_size / 64) * 21
|
||||
else:
|
||||
raise ValueError ('wrong face_type: ', face_type)
|
||||
|
||||
#mat = umeyama(image_landmarks[17:], landmarks_2D, True)[0:2]
|
||||
mat = umeyama( np.concatenate ( [ image_landmarks[17:49] , image_landmarks[54:55] ] ) , landmarks_2D_new, True)[0:2]
|
||||
l_p = transform_points ( np.float32([(0,0),(1,0),(1,1),(0,1),(0.5,0.5)]) , mat, True)
|
||||
l_c = l_p[4]
|
||||
|
||||
mat = mat * (output_size - 2 * padding)
|
||||
mat[:,2] += padding
|
||||
mat *= (1 / scale)
|
||||
mat[:,2] += -output_size*( ( (1 / scale) - 1.0 ) / 2 )
|
||||
tb_diag_vec = (l_p[2]-l_p[0]).astype(np.float32)
|
||||
tb_diag_vec /= npla.norm(tb_diag_vec)
|
||||
bt_diag_vec = (l_p[1]-l_p[3]).astype(np.float32)
|
||||
bt_diag_vec /= npla.norm(bt_diag_vec)
|
||||
|
||||
mod = (1.0 / scale)* ( npla.norm(l_p[0]-l_p[2])*(padding*np.sqrt(2.0) + 0.5) )
|
||||
|
||||
l_t = np.array( [ np.round( l_c - tb_diag_vec*mod ),
|
||||
np.round( l_c + bt_diag_vec*mod ),
|
||||
np.round( l_c + tb_diag_vec*mod ) ] )
|
||||
|
||||
pts2 = np.float32(( (0,0),(output_size,0),(output_size,output_size) ))
|
||||
mat = cv2.getAffineTransform(l_t,pts2)
|
||||
|
||||
#if full_face_align_top and (face_type == FaceType.FULL or face_type == FaceType.FULL_NO_ALIGN):
|
||||
# #lmrks2 = expand_eyebrows(image_landmarks)
|
||||
# #lmrks2_ = transform_points( [ lmrks2[19], lmrks2[24] ], mat, False )
|
||||
# #y_diff = np.float32( (0,np.min(lmrks2_[:,1])) )
|
||||
# #y_diff = transform_points( [ np.float32( (0,0) ), y_diff], mat, True)
|
||||
# #y_diff = y_diff[1]-y_diff[0]
|
||||
#
|
||||
# x_diff = np.float32((0,0))
|
||||
#
|
||||
# lmrks2_ = transform_points( [ image_landmarks[0], image_landmarks[16] ], mat, False )
|
||||
# if lmrks2_[0,0] < 0:
|
||||
# x_diff = lmrks2_[0,0]
|
||||
# x_diff = transform_points( [ np.float32( (0,0) ), np.float32((x_diff,0)) ], mat, True)
|
||||
# x_diff = x_diff[1]-x_diff[0]
|
||||
# elif lmrks2_[1,0] >= output_size:
|
||||
# x_diff = lmrks2_[1,0]-(output_size-1)
|
||||
# x_diff = transform_points( [ np.float32( (0,0) ), np.float32((x_diff,0)) ], mat, True)
|
||||
# x_diff = x_diff[1]-x_diff[0]
|
||||
#
|
||||
# mat = cv2.getAffineTransform( l_t+y_diff+x_diff ,pts2)
|
||||
|
||||
if remove_align:
|
||||
bbox = transform_points ( [ (0,0), (0,output_size-1), (output_size-1, output_size-1), (output_size-1,0) ], mat, True)
|
||||
bbox = transform_points ( [ (0,0), (0,output_size), (output_size, output_size), (output_size,0) ], mat, True)
|
||||
area = mathlib.polygon_area(bbox[:,0], bbox[:,1] )
|
||||
side = math.sqrt(area) / 2
|
||||
center = transform_points ( [(output_size/2,output_size/2)], mat, True)
|
||||
|
||||
pts1 = np.float32([ center+[-side,-side], center+[side,-side], center+[-side,side] ])
|
||||
pts2 = np.float32([[0,0],[output_size-1,0],[0,output_size-1]])
|
||||
pts1 = np.float32(( center+[-side,-side], center+[side,-side], center+[-side,side] ))
|
||||
mat = cv2.getAffineTransform(pts1,pts2)
|
||||
|
||||
return mat
|
||||
|
||||
|
||||
def expand_eyebrows(lmrks, eyebrows_expand_mod=1.0):
|
||||
if len(lmrks) != 68:
|
||||
raise Exception('works only with 68 landmarks')
|
||||
|
|
|
@ -4,3 +4,4 @@ from .MTCExtractor import MTCExtractor
|
|||
from .S3FDExtractor import S3FDExtractor
|
||||
from .FANExtractor import FANExtractor
|
||||
from .PoseEstimator import PoseEstimator
|
||||
from .FaceEnhancer import FaceEnhancer
|
|
@ -97,7 +97,7 @@ class IEPolys:
|
|||
@staticmethod
|
||||
def load(ie_polys=None):
|
||||
obj = IEPolys()
|
||||
if ie_polys is not None:
|
||||
if ie_polys is not None and isinstance(ie_polys, list):
|
||||
for (type, points) in ie_polys:
|
||||
obj.add(type)
|
||||
obj.n_list().set_points(points)
|
||||
|
|
|
@ -2,18 +2,24 @@ import numpy as np
|
|||
import cv2
|
||||
from utils import random_utils
|
||||
|
||||
def gen_warp_params (source, flip, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05] ):
|
||||
def gen_warp_params (source, flip, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_seed=None ):
|
||||
h,w,c = source.shape
|
||||
if (h != w):
|
||||
raise ValueError ('gen_warp_params accepts only square images.')
|
||||
|
||||
rotation = np.random.uniform( rotation_range[0], rotation_range[1] )
|
||||
scale = np.random.uniform(1 +scale_range[0], 1 +scale_range[1])
|
||||
tx = np.random.uniform( tx_range[0], tx_range[1] )
|
||||
ty = np.random.uniform( ty_range[0], ty_range[1] )
|
||||
if rnd_seed != None:
|
||||
rnd_state = np.random.RandomState (rnd_seed)
|
||||
else:
|
||||
rnd_state = np.random
|
||||
|
||||
rotation = rnd_state.uniform( rotation_range[0], rotation_range[1] )
|
||||
scale = rnd_state.uniform(1 +scale_range[0], 1 +scale_range[1])
|
||||
tx = rnd_state.uniform( tx_range[0], tx_range[1] )
|
||||
ty = rnd_state.uniform( ty_range[0], ty_range[1] )
|
||||
p_flip = flip and rnd_state.randint(10) < 4
|
||||
|
||||
#random warp by grid
|
||||
cell_size = [ w // (2**i) for i in range(1,4) ] [ np.random.randint(3) ]
|
||||
cell_size = [ w // (2**i) for i in range(1,4) ] [ rnd_state.randint(3) ]
|
||||
cell_count = w // cell_size + 1
|
||||
|
||||
grid_points = np.linspace( 0, w, cell_count)
|
||||
|
@ -37,7 +43,7 @@ def gen_warp_params (source, flip, rotation_range=[-10,10], scale_range=[-0.5, 0
|
|||
params['mapy'] = mapy
|
||||
params['rmat'] = random_transform_mat
|
||||
params['w'] = w
|
||||
params['flip'] = flip and np.random.randint(10) < 4
|
||||
params['flip'] = p_flip
|
||||
|
||||
return params
|
||||
|
||||
|
|
101
main.py
101
main.py
|
@ -1,21 +1,21 @@
|
|||
import os
|
||||
import sys
|
||||
import time
|
||||
import argparse
|
||||
import multiprocessing
|
||||
from utils import Path_utils
|
||||
from utils import os_utils
|
||||
from pathlib import Path
|
||||
|
||||
if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and sys.version_info[1] < 6):
|
||||
raise Exception("This program requires at least Python 3.6")
|
||||
|
||||
class fixPathAction(argparse.Action):
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
setattr(namespace, self.dest, os.path.abspath(os.path.expanduser(values)))
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import argparse
|
||||
import multiprocessing
|
||||
multiprocessing.set_start_method("spawn")
|
||||
from utils import Path_utils
|
||||
from utils import os_utils
|
||||
from pathlib import Path
|
||||
from interact import interact as io
|
||||
|
||||
if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and sys.version_info[1] < 6):
|
||||
raise Exception("This program requires at least Python 3.6")
|
||||
|
||||
class fixPathAction(argparse.Action):
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
setattr(namespace, self.dest, os.path.abspath(os.path.expanduser(values)))
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers()
|
||||
|
@ -66,8 +66,8 @@ if __name__ == "__main__":
|
|||
|
||||
def process_dev_extract_umd_csv(arguments):
|
||||
os_utils.set_process_lowest_prio()
|
||||
from mainscripts import Extractor
|
||||
Extractor.extract_umd_csv( arguments.input_csv_file,
|
||||
from mainscripts import dev_misc
|
||||
dev_misc.extract_umd_csv( arguments.input_csv_file,
|
||||
device_args={'cpu_only' : arguments.cpu_only,
|
||||
'multi_gpu' : arguments.multi_gpu,
|
||||
}
|
||||
|
@ -88,22 +88,15 @@ if __name__ == "__main__":
|
|||
p = subparsers.add_parser( "dev_apply_celebamaskhq", help="")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
|
||||
p.set_defaults (func=process_dev_apply_celebamaskhq)
|
||||
"""
|
||||
def process_extract_fanseg(arguments):
|
||||
os_utils.set_process_lowest_prio()
|
||||
from mainscripts import Extractor
|
||||
Extractor.extract_fanseg( arguments.input_dir,
|
||||
device_args={'cpu_only' : arguments.cpu_only,
|
||||
'multi_gpu' : arguments.multi_gpu,
|
||||
}
|
||||
)
|
||||
|
||||
p = subparsers.add_parser( "extract_fanseg", help="Extract fanseg mask from faces.")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
|
||||
p.add_argument('--multi-gpu', action="store_true", dest="multi_gpu", default=False, help="Enables multi GPU.")
|
||||
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Extract on CPU.")
|
||||
p.set_defaults (func=process_extract_fanseg)
|
||||
"""
|
||||
def process_dev_test(arguments):
|
||||
os_utils.set_process_lowest_prio()
|
||||
from mainscripts import dev_misc
|
||||
dev_misc.dev_test( arguments.input_dir )
|
||||
|
||||
p = subparsers.add_parser( "dev_test", help="")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
|
||||
p.set_defaults (func=process_dev_test)
|
||||
|
||||
def process_sort(arguments):
|
||||
os_utils.set_process_lowest_prio()
|
||||
|
@ -112,7 +105,7 @@ if __name__ == "__main__":
|
|||
|
||||
p = subparsers.add_parser( "sort", help="Sort faces in a directory.")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
|
||||
p.add_argument('--by', required=True, dest="sort_by_method", choices=("blur", "face", "face-dissim", "face-yaw", "face-pitch", "hist", "hist-dissim", "brightness", "hue", "black", "origname", "oneface", "final", "final-no-blur", "vggface", "test"), help="Method of sorting. 'origname' sort by original filename to recover original sequence." )
|
||||
p.add_argument('--by', required=True, dest="sort_by_method", choices=("blur", "face", "face-dissim", "face-yaw", "face-pitch", "hist", "hist-dissim", "brightness", "hue", "black", "origname", "oneface", "final", "final-no-blur", "vggface", "absdiff", "test"), help="Method of sorting. 'origname' sort by original filename to recover original sequence." )
|
||||
p.set_defaults (func=process_sort)
|
||||
|
||||
def process_util(arguments):
|
||||
|
@ -134,6 +127,22 @@ if __name__ == "__main__":
|
|||
if arguments.remove_ie_polys:
|
||||
Util.remove_ie_polys_folder (input_path=arguments.input_dir)
|
||||
|
||||
if arguments.save_faceset_metadata:
|
||||
Util.save_faceset_metadata_folder (input_path=arguments.input_dir)
|
||||
|
||||
if arguments.restore_faceset_metadata:
|
||||
Util.restore_faceset_metadata_folder (input_path=arguments.input_dir)
|
||||
|
||||
if arguments.pack_faceset:
|
||||
io.log_info ("Performing faceset packing...\r\n")
|
||||
from samplelib import PackedFaceset
|
||||
PackedFaceset.pack( Path(arguments.input_dir) )
|
||||
|
||||
if arguments.unpack_faceset:
|
||||
io.log_info ("Performing faceset unpacking...\r\n")
|
||||
from samplelib import PackedFaceset
|
||||
PackedFaceset.unpack( Path(arguments.input_dir) )
|
||||
|
||||
p = subparsers.add_parser( "util", help="Utilities.")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
|
||||
p.add_argument('--convert-png-to-jpg', action="store_true", dest="convert_png_to_jpg", default=False, help="Convert DeepFaceLAB PNG files to JPEG.")
|
||||
|
@ -141,6 +150,10 @@ if __name__ == "__main__":
|
|||
p.add_argument('--recover-original-aligned-filename', action="store_true", dest="recover_original_aligned_filename", default=False, help="Recover original aligned filename.")
|
||||
#p.add_argument('--remove-fanseg', action="store_true", dest="remove_fanseg", default=False, help="Remove fanseg mask from aligned faces.")
|
||||
p.add_argument('--remove-ie-polys', action="store_true", dest="remove_ie_polys", default=False, help="Remove ie_polys from aligned faces.")
|
||||
p.add_argument('--save-faceset-metadata', action="store_true", dest="save_faceset_metadata", default=False, help="Save faceset metadata to file.")
|
||||
p.add_argument('--restore-faceset-metadata', action="store_true", dest="restore_faceset_metadata", default=False, help="Restore faceset metadata to file. Image filenames must be the same as used with save.")
|
||||
p.add_argument('--pack-faceset', action="store_true", dest="pack_faceset", default=False, help="")
|
||||
p.add_argument('--unpack-faceset', action="store_true", dest="unpack_faceset", default=False, help="")
|
||||
|
||||
p.set_defaults (func=process_util)
|
||||
|
||||
|
@ -273,16 +286,31 @@ if __name__ == "__main__":
|
|||
|
||||
p.set_defaults(func=process_labelingtool_edit_mask)
|
||||
|
||||
facesettool_parser = subparsers.add_parser( "facesettool", help="Faceset tools.").add_subparsers()
|
||||
|
||||
def process_faceset_enhancer(arguments):
|
||||
os_utils.set_process_lowest_prio()
|
||||
from mainscripts import FacesetEnhancer
|
||||
FacesetEnhancer.process_folder ( Path(arguments.input_dir), multi_gpu=arguments.multi_gpu, cpu_only=arguments.cpu_only )
|
||||
|
||||
p = facesettool_parser.add_parser ("enhance", help="Enhance details in DFL faceset.")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.")
|
||||
p.add_argument('--multi-gpu', action="store_true", dest="multi_gpu", default=False, help="Enables multi GPU.")
|
||||
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Process on CPU.")
|
||||
|
||||
p.set_defaults(func=process_faceset_enhancer)
|
||||
|
||||
"""
|
||||
def process_relight_faceset(arguments):
|
||||
os_utils.set_process_lowest_prio()
|
||||
from mainscripts import FacesetRelighter
|
||||
FacesetRelighter.relight (arguments.input_dir, arguments.lighten, arguments.random_one)
|
||||
|
||||
def process_delete_relighted(arguments):
|
||||
os_utils.set_process_lowest_prio()
|
||||
from mainscripts import FacesetRelighter
|
||||
FacesetRelighter.delete_relighted (arguments.input_dir)
|
||||
|
||||
facesettool_parser = subparsers.add_parser( "facesettool", help="Faceset tools.").add_subparsers()
|
||||
|
||||
p = facesettool_parser.add_parser ("relight", help="Synthesize new faces from existing ones by relighting them. With the relighted faces neural network will better reproduce face shadows.")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.")
|
||||
p.add_argument('--lighten', action="store_true", dest="lighten", default=None, help="Lighten the faces.")
|
||||
|
@ -292,6 +320,7 @@ if __name__ == "__main__":
|
|||
p = facesettool_parser.add_parser ("delete_relighted", help="Delete relighted faces.")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.")
|
||||
p.set_defaults(func=process_delete_relighted)
|
||||
"""
|
||||
|
||||
def bad_args(arguments):
|
||||
parser.print_help()
|
||||
|
|
|
@ -14,17 +14,16 @@ import numpy as np
|
|||
import numpy.linalg as npla
|
||||
|
||||
import imagelib
|
||||
import samplelib
|
||||
from converters import (ConverterConfig, ConvertFaceAvatar, ConvertMasked,
|
||||
FrameInfo)
|
||||
from facelib import FaceType, LandmarksProcessor
|
||||
from nnlib import TernausNet
|
||||
|
||||
from interact import interact as io
|
||||
from joblib import SubprocessFunctionCaller, Subprocessor
|
||||
from nnlib import TernausNet
|
||||
from utils import Path_utils
|
||||
from utils.cv2_utils import *
|
||||
from utils.DFLJPG import DFLJPG
|
||||
from utils.DFLPNG import DFLPNG
|
||||
from DFLIMG import DFLIMG
|
||||
|
||||
from .ConverterScreen import Screen, ScreenManager
|
||||
|
||||
|
@ -670,19 +669,29 @@ def main (args, device_args):
|
|||
io.log_err('Aligned directory not found. Please ensure it exists.')
|
||||
return
|
||||
|
||||
packed_samples = None
|
||||
try:
|
||||
packed_samples = samplelib.PackedFaceset.load(aligned_path)
|
||||
except:
|
||||
io.log_err(f"Error occured while loading samplelib.PackedFaceset.load {str(aligned_path)}, {traceback.format_exc()}")
|
||||
|
||||
|
||||
if packed_samples is not None:
|
||||
io.log_info ("Using packed faceset.")
|
||||
def generator():
|
||||
for sample in io.progress_bar_generator( packed_samples, "Collecting alignments"):
|
||||
filepath = Path(sample.filename)
|
||||
yield DFLIMG.load(filepath, loader_func=lambda x: sample.read_raw_file() )
|
||||
else:
|
||||
def generator():
|
||||
for filepath in io.progress_bar_generator( Path_utils.get_image_paths(aligned_path), "Collecting alignments"):
|
||||
filepath = Path(filepath)
|
||||
yield DFLIMG.load(filepath)
|
||||
|
||||
alignments = {}
|
||||
multiple_faces_detected = False
|
||||
aligned_path_image_paths = Path_utils.get_image_paths(aligned_path)
|
||||
for filepath in io.progress_bar_generator(aligned_path_image_paths, "Collecting alignments"):
|
||||
filepath = Path(filepath)
|
||||
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
|
||||
for dflimg in generator():
|
||||
if dflimg is None:
|
||||
io.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
continue
|
||||
|
@ -746,13 +755,7 @@ def main (args, device_args):
|
|||
for filepath in io.progress_bar_generator(input_path_image_paths, "Collecting info"):
|
||||
filepath = Path(filepath)
|
||||
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
|
||||
dflimg = DFLIMG.load(filepath)
|
||||
if dflimg is None:
|
||||
io.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
continue
|
||||
|
|
|
@ -20,20 +20,19 @@ from joblib import Subprocessor
|
|||
from nnlib import TernausNet, nnlib
|
||||
from utils import Path_utils
|
||||
from utils.cv2_utils import *
|
||||
from utils.DFLJPG import DFLJPG
|
||||
from utils.DFLPNG import DFLPNG
|
||||
from DFLIMG import *
|
||||
|
||||
DEBUG = False
|
||||
|
||||
class ExtractSubprocessor(Subprocessor):
|
||||
class Data(object):
|
||||
def __init__(self, filename=None, rects=None, landmarks = None, landmarks_accurate=True, pitch_yaw_roll=None, force_output_path=None, final_output_files = None):
|
||||
def __init__(self, filename=None, rects=None, landmarks = None, landmarks_accurate=True, manual=False, force_output_path=None, final_output_files = None):
|
||||
self.filename = filename
|
||||
self.rects = rects or []
|
||||
self.rects_rotation = 0
|
||||
self.landmarks_accurate = landmarks_accurate
|
||||
self.manual = manual
|
||||
self.landmarks = landmarks or []
|
||||
self.pitch_yaw_roll = pitch_yaw_roll
|
||||
self.force_output_path = force_output_path
|
||||
self.final_output_files = final_output_files or []
|
||||
self.faces_detected = 0
|
||||
|
@ -110,8 +109,11 @@ class ExtractSubprocessor(Subprocessor):
|
|||
#override
|
||||
def process_data(self, data):
|
||||
filename_path = Path( data.filename )
|
||||
|
||||
filename_path_str = str(filename_path)
|
||||
|
||||
if self.type == 'landmarks' and len(data.rects) == 0:
|
||||
return data
|
||||
|
||||
if self.cached_image[0] == filename_path_str:
|
||||
image = self.cached_image[1] #cached image for manual extractor
|
||||
else:
|
||||
|
@ -133,10 +135,7 @@ class ExtractSubprocessor(Subprocessor):
|
|||
h, w, ch = image.shape
|
||||
if h == w:
|
||||
#extracting from already extracted jpg image?
|
||||
if filename_path.suffix == '.png':
|
||||
src_dflimg = DFLPNG.load ( str(filename_path) )
|
||||
if filename_path.suffix == '.jpg':
|
||||
src_dflimg = DFLJPG.load ( str(filename_path) )
|
||||
src_dflimg = DFLIMG.load (filename_path)
|
||||
|
||||
if 'rects' in self.type:
|
||||
if min(w,h) < 128:
|
||||
|
@ -164,7 +163,6 @@ class ExtractSubprocessor(Subprocessor):
|
|||
return data
|
||||
|
||||
elif self.type == 'landmarks':
|
||||
|
||||
if data.rects_rotation == 0:
|
||||
rotated_image = image
|
||||
elif data.rects_rotation == 90:
|
||||
|
@ -243,7 +241,7 @@ class ExtractSubprocessor(Subprocessor):
|
|||
rect_area = mathlib.polygon_area(np.array(rect[[0,2,2,0]]), np.array(rect[[1,1,3,3]]))
|
||||
landmarks_area = mathlib.polygon_area(landmarks_bbox[:,0], landmarks_bbox[:,1] )
|
||||
|
||||
if self.face_type <= FaceType.FULL_NO_ALIGN and landmarks_area > 4*rect_area: #get rid of faces which umeyama-landmark-area > 4*detector-rect-area
|
||||
if not data.manual and self.face_type <= FaceType.FULL_NO_ALIGN and landmarks_area > 4*rect_area: #get rid of faces which umeyama-landmark-area > 4*detector-rect-area
|
||||
continue
|
||||
|
||||
if self.debug_dir is not None:
|
||||
|
@ -268,8 +266,7 @@ class ExtractSubprocessor(Subprocessor):
|
|||
source_filename=filename_path.name,
|
||||
source_rect=rect,
|
||||
source_landmarks=image_landmarks.tolist(),
|
||||
image_to_face_mat=image_to_face_mat,
|
||||
pitch_yaw_roll=data.pitch_yaw_roll
|
||||
image_to_face_mat=image_to_face_mat
|
||||
)
|
||||
|
||||
data.final_output_files.append (output_file)
|
||||
|
@ -315,7 +312,7 @@ class ExtractSubprocessor(Subprocessor):
|
|||
else:
|
||||
no_response_time_sec = 60
|
||||
|
||||
super().__init__('Extractor', ExtractSubprocessor.Cli, no_response_time_sec, initialize_subprocesses_in_serial=(type != 'final'))
|
||||
super().__init__('Extractor', ExtractSubprocessor.Cli, no_response_time_sec)
|
||||
|
||||
#override
|
||||
def on_check_run(self):
|
||||
|
@ -768,7 +765,7 @@ def main(input_dir,
|
|||
if images_found != 0:
|
||||
if detector == 'manual':
|
||||
io.log_info ('Performing manual extract...')
|
||||
data = ExtractSubprocessor ([ ExtractSubprocessor.Data(filename) for filename in input_path_image_paths ], 'landmarks', image_size, face_type, debug_dir, cpu_only=cpu_only, manual=True, manual_window_size=manual_window_size).run()
|
||||
data = ExtractSubprocessor ([ ExtractSubprocessor.Data(filename, manual=True) for filename in input_path_image_paths ], 'landmarks', image_size, face_type, debug_dir, cpu_only=cpu_only, manual=True, manual_window_size=manual_window_size).run()
|
||||
else:
|
||||
io.log_info ('Performing 1st pass...')
|
||||
data = ExtractSubprocessor ([ ExtractSubprocessor.Data(filename) for filename in input_path_image_paths ], 'rects-'+detector, image_size, face_type, debug_dir, multi_gpu=multi_gpu, cpu_only=cpu_only, manual=False, max_faces_from_image=max_faces_from_image).run()
|
||||
|
@ -784,118 +781,13 @@ def main(input_dir,
|
|||
if all ( np.array ( [ d.faces_detected > 0 for d in data] ) == True ):
|
||||
io.log_info ('All faces are detected, manual fix not needed.')
|
||||
else:
|
||||
fix_data = [ ExtractSubprocessor.Data(d.filename) for d in data if d.faces_detected == 0 ]
|
||||
fix_data = [ ExtractSubprocessor.Data(d.filename, manual=True) for d in data if d.faces_detected == 0 ]
|
||||
io.log_info ('Performing manual fix for %d images...' % (len(fix_data)) )
|
||||
fix_data = ExtractSubprocessor (fix_data, 'landmarks', image_size, face_type, debug_dir, manual=True, manual_window_size=manual_window_size).run()
|
||||
fix_data = ExtractSubprocessor (fix_data, 'final', image_size, face_type, debug_dir, multi_gpu=multi_gpu, cpu_only=cpu_only, manual=False, final_output_path=output_path).run()
|
||||
faces_detected += sum([d.faces_detected for d in fix_data])
|
||||
|
||||
|
||||
io.log_info ('-------------------------')
|
||||
io.log_info ('Images found: %d' % (images_found) )
|
||||
io.log_info ('Faces detected: %d' % (faces_detected) )
|
||||
io.log_info ('-------------------------')
|
||||
|
||||
#unused in end user workflow
|
||||
def extract_fanseg(input_dir, device_args={} ):
|
||||
multi_gpu = device_args.get('multi_gpu', False)
|
||||
cpu_only = device_args.get('cpu_only', False)
|
||||
|
||||
input_path = Path(input_dir)
|
||||
if not input_path.exists():
|
||||
raise ValueError('Input directory not found. Please ensure it exists.')
|
||||
|
||||
paths_to_extract = []
|
||||
for filename in Path_utils.get_image_paths(input_path) :
|
||||
filepath = Path(filename)
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
|
||||
if dflimg is not None:
|
||||
paths_to_extract.append (filepath)
|
||||
|
||||
paths_to_extract_len = len(paths_to_extract)
|
||||
if paths_to_extract_len > 0:
|
||||
io.log_info ("Performing extract fanseg for %d files..." % (paths_to_extract_len) )
|
||||
data = ExtractSubprocessor ([ ExtractSubprocessor.Data(filename) for filename in paths_to_extract ], 'fanseg', multi_gpu=multi_gpu, cpu_only=cpu_only).run()
|
||||
|
||||
#unused in end user workflow
|
||||
def extract_umd_csv(input_file_csv,
|
||||
image_size=256,
|
||||
face_type='full_face',
|
||||
device_args={} ):
|
||||
|
||||
#extract faces from umdfaces.io dataset csv file with pitch,yaw,roll info.
|
||||
multi_gpu = device_args.get('multi_gpu', False)
|
||||
cpu_only = device_args.get('cpu_only', False)
|
||||
face_type = FaceType.fromString(face_type)
|
||||
|
||||
input_file_csv_path = Path(input_file_csv)
|
||||
if not input_file_csv_path.exists():
|
||||
raise ValueError('input_file_csv not found. Please ensure it exists.')
|
||||
|
||||
input_file_csv_root_path = input_file_csv_path.parent
|
||||
output_path = input_file_csv_path.parent / ('aligned_' + input_file_csv_path.name)
|
||||
|
||||
io.log_info("Output dir is %s." % (str(output_path)) )
|
||||
|
||||
if output_path.exists():
|
||||
output_images_paths = Path_utils.get_image_paths(output_path)
|
||||
if len(output_images_paths) > 0:
|
||||
io.input_bool("WARNING !!! \n %s contains files! \n They will be deleted. \n Press enter to continue." % (str(output_path)), False )
|
||||
for filename in output_images_paths:
|
||||
Path(filename).unlink()
|
||||
else:
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
with open( str(input_file_csv_path), 'r') as f:
|
||||
csv_file = f.read()
|
||||
except Exception as e:
|
||||
io.log_err("Unable to open or read file " + str(input_file_csv_path) + ": " + str(e) )
|
||||
return
|
||||
|
||||
strings = csv_file.split('\n')
|
||||
keys = strings[0].split(',')
|
||||
keys_len = len(keys)
|
||||
csv_data = []
|
||||
for i in range(1, len(strings)):
|
||||
values = strings[i].split(',')
|
||||
if keys_len != len(values):
|
||||
io.log_err("Wrong string in csv file, skipping.")
|
||||
continue
|
||||
|
||||
csv_data += [ { keys[n] : values[n] for n in range(keys_len) } ]
|
||||
|
||||
data = []
|
||||
for d in csv_data:
|
||||
filename = input_file_csv_root_path / d['FILE']
|
||||
|
||||
pitch, yaw, roll = float(d['PITCH']), float(d['YAW']), float(d['ROLL'])
|
||||
if pitch < -90 or pitch > 90 or yaw < -90 or yaw > 90 or roll < -90 or roll > 90:
|
||||
continue
|
||||
|
||||
pitch_yaw_roll = pitch/90.0, yaw/90.0, roll/90.0
|
||||
|
||||
x,y,w,h = float(d['FACE_X']), float(d['FACE_Y']), float(d['FACE_WIDTH']), float(d['FACE_HEIGHT'])
|
||||
|
||||
data += [ ExtractSubprocessor.Data(filename=filename, rects=[ [x,y,x+w,y+h] ], pitch_yaw_roll=pitch_yaw_roll) ]
|
||||
|
||||
images_found = len(data)
|
||||
faces_detected = 0
|
||||
if len(data) > 0:
|
||||
io.log_info ("Performing 2nd pass from csv file...")
|
||||
data = ExtractSubprocessor (data, 'landmarks', multi_gpu=multi_gpu, cpu_only=cpu_only).run()
|
||||
|
||||
io.log_info ('Performing 3rd pass...')
|
||||
data = ExtractSubprocessor (data, 'final', image_size, face_type, None, multi_gpu=multi_gpu, cpu_only=cpu_only, manual=False, final_output_path=output_path).run()
|
||||
faces_detected += sum([d.faces_detected for d in data])
|
||||
|
||||
|
||||
io.log_info ('-------------------------')
|
||||
io.log_info ('Images found: %d' % (images_found) )
|
||||
io.log_info ('Faces detected: %d' % (faces_detected) )
|
||||
|
|
168
mainscripts/FacesetEnhancer.py
Normal file
168
mainscripts/FacesetEnhancer.py
Normal file
|
@ -0,0 +1,168 @@
|
|||
import multiprocessing
|
||||
import shutil
|
||||
|
||||
from DFLIMG import *
|
||||
from interact import interact as io
|
||||
from joblib import Subprocessor
|
||||
from nnlib import nnlib
|
||||
from utils import Path_utils
|
||||
from utils.cv2_utils import *
|
||||
|
||||
|
||||
class FacesetEnhancerSubprocessor(Subprocessor):
|
||||
|
||||
#override
|
||||
def __init__(self, image_paths, output_dirpath, multi_gpu=False, cpu_only=False):
|
||||
self.image_paths = image_paths
|
||||
self.output_dirpath = output_dirpath
|
||||
self.result = []
|
||||
self.devices = FacesetEnhancerSubprocessor.get_devices_for_config(multi_gpu, cpu_only)
|
||||
|
||||
super().__init__('FacesetEnhancer', FacesetEnhancerSubprocessor.Cli, 600)
|
||||
|
||||
#override
|
||||
def on_clients_initialized(self):
|
||||
io.progress_bar (None, len (self.image_paths))
|
||||
|
||||
#override
|
||||
def on_clients_finalized(self):
|
||||
io.progress_bar_close()
|
||||
|
||||
#override
|
||||
def process_info_generator(self):
|
||||
base_dict = {'output_dirpath':self.output_dirpath}
|
||||
|
||||
for (device_idx, device_type, device_name, device_total_vram_gb) in self.devices:
|
||||
client_dict = base_dict.copy()
|
||||
client_dict['device_idx'] = device_idx
|
||||
client_dict['device_name'] = device_name
|
||||
client_dict['device_type'] = device_type
|
||||
yield client_dict['device_name'], {}, client_dict
|
||||
|
||||
#override
|
||||
def get_data(self, host_dict):
|
||||
if len (self.image_paths) > 0:
|
||||
return self.image_paths.pop(0)
|
||||
|
||||
#override
|
||||
def on_data_return (self, host_dict, data):
|
||||
self.image_paths.insert(0, data)
|
||||
|
||||
#override
|
||||
def on_result (self, host_dict, data, result):
|
||||
io.progress_bar_inc(1)
|
||||
if result[0] == 1:
|
||||
self.result +=[ (result[1], result[2]) ]
|
||||
|
||||
#override
|
||||
def get_result(self):
|
||||
return self.result
|
||||
|
||||
@staticmethod
|
||||
def get_devices_for_config (multi_gpu, cpu_only):
|
||||
backend = nnlib.device.backend
|
||||
if 'cpu' in backend:
|
||||
cpu_only = True
|
||||
|
||||
if not cpu_only and backend == "plaidML":
|
||||
cpu_only = True
|
||||
|
||||
if not cpu_only:
|
||||
devices = []
|
||||
if multi_gpu:
|
||||
devices = nnlib.device.getValidDevicesWithAtLeastTotalMemoryGB(2)
|
||||
|
||||
if len(devices) == 0:
|
||||
idx = nnlib.device.getBestValidDeviceIdx()
|
||||
if idx != -1:
|
||||
devices = [idx]
|
||||
|
||||
if len(devices) == 0:
|
||||
cpu_only = True
|
||||
|
||||
result = []
|
||||
for idx in devices:
|
||||
dev_name = nnlib.device.getDeviceName(idx)
|
||||
dev_vram = nnlib.device.getDeviceVRAMTotalGb(idx)
|
||||
|
||||
result += [ (idx, 'GPU', dev_name, dev_vram) ]
|
||||
|
||||
return result
|
||||
|
||||
if cpu_only:
|
||||
return [ (i, 'CPU', 'CPU%d' % (i), 0 ) for i in range( min(8, multiprocessing.cpu_count() // 2) ) ]
|
||||
|
||||
class Cli(Subprocessor.Cli):
|
||||
|
||||
#override
|
||||
def on_initialize(self, client_dict):
|
||||
device_idx = client_dict['device_idx']
|
||||
cpu_only = client_dict['device_type'] == 'CPU'
|
||||
self.output_dirpath = client_dict['output_dirpath']
|
||||
|
||||
device_config = nnlib.DeviceConfig ( cpu_only=cpu_only, force_gpu_idx=device_idx, allow_growth=True)
|
||||
nnlib.import_all (device_config)
|
||||
|
||||
device_vram = device_config.gpu_vram_gb[0]
|
||||
|
||||
intro_str = 'Running on %s.' % (client_dict['device_name'])
|
||||
if not cpu_only and device_vram <= 2:
|
||||
intro_str += " Recommended to close all programs using this device."
|
||||
|
||||
self.log_info (intro_str)
|
||||
|
||||
from facelib import FaceEnhancer
|
||||
self.fe = FaceEnhancer()
|
||||
|
||||
#override
|
||||
def process_data(self, filepath):
|
||||
try:
|
||||
dflimg = DFLIMG.load (filepath)
|
||||
if dflimg is None:
|
||||
self.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
else:
|
||||
img = cv2_imread(filepath).astype(np.float32) / 255.0
|
||||
|
||||
img = self.fe.enhance(img)
|
||||
|
||||
img = np.clip (img*255, 0, 255).astype(np.uint8)
|
||||
|
||||
output_filepath = self.output_dirpath / filepath.name
|
||||
|
||||
cv2_imwrite ( str(output_filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] )
|
||||
dflimg.embed_and_set ( str(output_filepath) )
|
||||
return (1, filepath, output_filepath)
|
||||
except:
|
||||
self.log_err (f"Exception occured while processing file {filepath}. Error: {traceback.format_exc()}")
|
||||
|
||||
return (0, filepath, None)
|
||||
|
||||
def process_folder ( dirpath, multi_gpu=False, cpu_only=False ):
|
||||
output_dirpath = dirpath.parent / (dirpath.name + '_enhanced')
|
||||
output_dirpath.mkdir (exist_ok=True, parents=True)
|
||||
|
||||
dirpath_parts = '/'.join( dirpath.parts[-2:])
|
||||
output_dirpath_parts = '/'.join( output_dirpath.parts[-2:] )
|
||||
io.log_info (f"Enhancing faceset in {dirpath_parts}")
|
||||
io.log_info ( f"Processing to {output_dirpath_parts}")
|
||||
|
||||
output_images_paths = Path_utils.get_image_paths(output_dirpath)
|
||||
if len(output_images_paths) > 0:
|
||||
for filename in output_images_paths:
|
||||
Path(filename).unlink()
|
||||
|
||||
image_paths = [Path(x) for x in Path_utils.get_image_paths( dirpath )]
|
||||
result = FacesetEnhancerSubprocessor ( image_paths, output_dirpath, multi_gpu=multi_gpu, cpu_only=cpu_only).run()
|
||||
|
||||
is_merge = io.input_bool (f"\r\nMerge {output_dirpath_parts} to {dirpath_parts} ? (y/n skip:y) : ", True)
|
||||
if is_merge:
|
||||
io.log_info (f"Copying processed files to {dirpath_parts}")
|
||||
|
||||
for (filepath, output_filepath) in result:
|
||||
try:
|
||||
shutil.copy (output_filepath, filepath)
|
||||
except:
|
||||
pass
|
||||
|
||||
io.log_info (f"Removing {output_dirpath_parts}")
|
||||
shutil.rmtree(output_dirpath)
|
|
@ -6,8 +6,7 @@ from interact import interact as io
|
|||
from nnlib import DeepPortraitRelighting
|
||||
from utils import Path_utils
|
||||
from utils.cv2_utils import *
|
||||
from utils.DFLJPG import DFLJPG
|
||||
from utils.DFLPNG import DFLPNG
|
||||
from DFLIMG import *
|
||||
|
||||
class RelightEditor:
|
||||
def __init__(self, image_paths, dpr, lighten):
|
||||
|
@ -183,12 +182,7 @@ def relight(input_dir, lighten=None, random_one=None):
|
|||
filtered_image_paths = []
|
||||
for filepath in io.progress_bar_generator(image_paths, "Collecting fileinfo"):
|
||||
try:
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
dflimg = DFLIMG.load (Path(filepath))
|
||||
|
||||
if dflimg is None:
|
||||
io.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
|
@ -210,13 +204,7 @@ def relight(input_dir, lighten=None, random_one=None):
|
|||
|
||||
for filepath in io.progress_bar_generator(image_paths, "Relighting"):
|
||||
try:
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
|
||||
dflimg = DFLIMG.load ( Path(filepath) )
|
||||
if dflimg is None:
|
||||
io.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
continue
|
||||
|
@ -262,12 +250,7 @@ def delete_relighted(input_dir):
|
|||
|
||||
files_to_delete = []
|
||||
for filepath in io.progress_bar_generator(image_paths, "Loading"):
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
dflimg = DFLIMG.load ( Path(filepath) )
|
||||
|
||||
if dflimg is None:
|
||||
io.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
|
|
|
@ -9,13 +9,13 @@ import numpy as np
|
|||
import numpy.linalg as npl
|
||||
|
||||
import imagelib
|
||||
from DFLIMG import *
|
||||
from facelib import LandmarksProcessor
|
||||
from imagelib import IEPolys
|
||||
from interact import interact as io
|
||||
from utils import Path_utils
|
||||
from utils.cv2_utils import *
|
||||
from utils.DFLJPG import DFLJPG
|
||||
from utils.DFLPNG import DFLPNG
|
||||
|
||||
|
||||
class MaskEditor:
|
||||
STATE_NONE=0
|
||||
|
@ -396,19 +396,14 @@ def mask_editor_main(input_dir, confirmed_dir=None, skipped_dir=None, no_default
|
|||
cached_images[path.name] = cv2_imread(str(path)) / 255.0
|
||||
|
||||
if filepath is not None:
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
dflimg = DFLIMG.load (filepath)
|
||||
|
||||
if dflimg is None:
|
||||
io.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
continue
|
||||
else:
|
||||
lmrks = dflimg.get_landmarks()
|
||||
ie_polys = dflimg.get_ie_polys()
|
||||
ie_polys = IEPolys.load(dflimg.get_ie_polys())
|
||||
fanseg_mask = dflimg.get_fanseg_mask()
|
||||
|
||||
if filepath.name in cached_images:
|
||||
|
@ -573,4 +568,3 @@ def mask_editor_main(input_dir, confirmed_dir=None, skipped_dir=None, no_default
|
|||
io.process_messages(0.005)
|
||||
|
||||
io.destroy_all_windows()
|
||||
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import os
|
||||
import multiprocessing
|
||||
import multiprocessing
|
||||
import operator
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from functools import cmp_to_key
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
|
||||
|
@ -11,16 +13,13 @@ from numpy import linalg as npla
|
|||
|
||||
import imagelib
|
||||
from facelib import LandmarksProcessor
|
||||
from functools import cmp_to_key
|
||||
from imagelib import estimate_sharpness
|
||||
from interact import interact as io
|
||||
from joblib import Subprocessor
|
||||
from nnlib import VGGFace
|
||||
from nnlib import VGGFace, nnlib
|
||||
from utils import Path_utils
|
||||
from utils.cv2_utils import *
|
||||
from utils.DFLJPG import DFLJPG
|
||||
from utils.DFLPNG import DFLPNG
|
||||
|
||||
from DFLIMG import *
|
||||
|
||||
class BlurEstimatorSubprocessor(Subprocessor):
|
||||
class Cli(Subprocessor.Cli):
|
||||
|
@ -32,13 +31,7 @@ class BlurEstimatorSubprocessor(Subprocessor):
|
|||
#override
|
||||
def process_data(self, data):
|
||||
filepath = Path( data[0] )
|
||||
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
dflimg = DFLIMG.load (filepath)
|
||||
|
||||
if dflimg is not None:
|
||||
image = cv2_imread( str(filepath) )
|
||||
|
@ -118,12 +111,7 @@ def sort_by_face(input_path):
|
|||
for filepath in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Loading"):
|
||||
filepath = Path(filepath)
|
||||
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
dflimg = DFLIMG.load (filepath)
|
||||
|
||||
if dflimg is None:
|
||||
io.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
|
@ -159,12 +147,7 @@ def sort_by_face_dissim(input_path):
|
|||
for filepath in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Loading"):
|
||||
filepath = Path(filepath)
|
||||
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
dflimg = DFLIMG.load (filepath)
|
||||
|
||||
if dflimg is None:
|
||||
io.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
|
@ -197,23 +180,14 @@ def sort_by_face_yaw(input_path):
|
|||
for filepath in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Loading"):
|
||||
filepath = Path(filepath)
|
||||
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
dflimg = DFLIMG.load (filepath)
|
||||
|
||||
if dflimg is None:
|
||||
io.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
trash_img_list.append ( [str(filepath)] )
|
||||
continue
|
||||
|
||||
pitch_yaw_roll = dflimg.get_pitch_yaw_roll()
|
||||
if pitch_yaw_roll is not None:
|
||||
pitch, yaw, roll = pitch_yaw_roll
|
||||
else:
|
||||
pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks() )
|
||||
pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks() )
|
||||
|
||||
img_list.append( [str(filepath), yaw ] )
|
||||
|
||||
|
@ -229,23 +203,14 @@ def sort_by_face_pitch(input_path):
|
|||
for filepath in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Loading"):
|
||||
filepath = Path(filepath)
|
||||
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
dflimg = DFLIMG.load (filepath)
|
||||
|
||||
if dflimg is None:
|
||||
io.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
trash_img_list.append ( [str(filepath)] )
|
||||
continue
|
||||
|
||||
pitch_yaw_roll = dflimg.get_pitch_yaw_roll()
|
||||
if pitch_yaw_roll is not None:
|
||||
pitch, yaw, roll = pitch_yaw_roll
|
||||
else:
|
||||
pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks() )
|
||||
pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks() )
|
||||
|
||||
img_list.append( [str(filepath), pitch ] )
|
||||
|
||||
|
@ -423,12 +388,7 @@ def sort_by_hist_dissim(input_path):
|
|||
for filepath in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Loading"):
|
||||
filepath = Path(filepath)
|
||||
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
dflimg = DFLIMG.load (filepath)
|
||||
|
||||
image = cv2_imread(str(filepath))
|
||||
|
||||
|
@ -480,12 +440,7 @@ def sort_by_origname(input_path):
|
|||
for filepath in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Loading"):
|
||||
filepath = Path(filepath)
|
||||
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
dflimg = DFLIMG.load (filepath)
|
||||
|
||||
if dflimg is None:
|
||||
io.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
|
@ -527,12 +482,7 @@ class FinalLoaderSubprocessor(Subprocessor):
|
|||
filepath = Path(data[0])
|
||||
|
||||
try:
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
dflimg = DFLIMG.load (filepath)
|
||||
|
||||
if dflimg is None:
|
||||
self.log_err("%s is not a dfl image file" % (filepath.name))
|
||||
|
@ -837,38 +787,84 @@ def sort_by_vggface(input_path):
|
|||
|
||||
return img_list, trash_img_list
|
||||
|
||||
"""
|
||||
img_list_len = len(img_list)
|
||||
def sort_by_absdiff(input_path):
|
||||
io.log_info ("Sorting by absolute difference...")
|
||||
|
||||
for i in io.progress_bar_generator ( range(img_list_len-1), "Sorting" ):
|
||||
a = []
|
||||
i_1 = img_list[i][1]
|
||||
is_sim = io.input_bool ("Sort by similar? ( y/n ?:help skip:y ) : ", True, help_message="Otherwise sort by dissimilar.")
|
||||
|
||||
from nnlib import nnlib
|
||||
exec( nnlib.import_all( device_config=nnlib.device.Config() ), locals(), globals() )
|
||||
|
||||
image_paths = Path_utils.get_image_paths(input_path)
|
||||
image_paths_len = len(image_paths)
|
||||
|
||||
batch_size = 1024
|
||||
batch_size_remain = image_paths_len % batch_size
|
||||
|
||||
i_t = Input ( (256,256,3) )
|
||||
j_t = Input ( (256,256,3) )
|
||||
|
||||
outputs = []
|
||||
for i in range(batch_size):
|
||||
outputs += [ K.sum( K.abs(i_t-j_t[i]), axis=[1,2,3] ) ]
|
||||
|
||||
func_bs_full = K.function ( [i_t,j_t], outputs)
|
||||
|
||||
outputs = []
|
||||
for i in range(batch_size_remain):
|
||||
outputs += [ K.sum( K.abs(i_t-j_t[i]), axis=[1,2,3] ) ]
|
||||
|
||||
func_bs_remain = K.function ( [i_t,j_t], outputs)
|
||||
|
||||
import h5py
|
||||
db_file_path = Path(tempfile.gettempdir()) / 'sort_cache.hdf5'
|
||||
db_file = h5py.File( str(db_file_path), "w")
|
||||
db = db_file.create_dataset("results", (image_paths_len,image_paths_len), compression="gzip")
|
||||
|
||||
|
||||
for j in range(i+1, img_list_len):
|
||||
a.append ( [ j, np.linalg.norm(i_1-img_list[j][1]) ] )
|
||||
pg_len = image_paths_len // batch_size
|
||||
if batch_size_remain != 0:
|
||||
pg_len += 1
|
||||
|
||||
x = sorted(a, key=operator.itemgetter(1) )[0][0]
|
||||
saved = img_list[i+1]
|
||||
img_list[i+1] = img_list[x]
|
||||
img_list[x] = saved
|
||||
pg_len = int( ( pg_len*pg_len - pg_len ) / 2 + pg_len )
|
||||
|
||||
io.progress_bar ("Computing", pg_len)
|
||||
j=0
|
||||
while j < image_paths_len:
|
||||
j_images = [ cv2_imread(x) for x in image_paths[j:j+batch_size] ]
|
||||
j_images_len = len(j_images)
|
||||
|
||||
func = func_bs_remain if image_paths_len-j < batch_size else func_bs_full
|
||||
|
||||
i=0
|
||||
while i < image_paths_len:
|
||||
if i >= j:
|
||||
i_images = [ cv2_imread(x) for x in image_paths[i:i+batch_size] ]
|
||||
i_images_len = len(i_images)
|
||||
result = func ([i_images,j_images])
|
||||
db[j:j+j_images_len,i:i+i_images_len] = np.array(result)
|
||||
io.progress_bar_inc(1)
|
||||
|
||||
i += batch_size
|
||||
db_file.flush()
|
||||
j += batch_size
|
||||
|
||||
io.progress_bar_close()
|
||||
|
||||
next_id = 0
|
||||
sorted = [next_id]
|
||||
for i in io.progress_bar_generator ( range(image_paths_len-1), "Sorting" ):
|
||||
id_ar = np.concatenate ( [ db[:next_id,next_id], db[next_id,next_id:] ] )
|
||||
id_ar = np.argsort(id_ar)
|
||||
|
||||
|
||||
q = np.array ( [ x[1] for x in img_list ] )
|
||||
next_id = np.setdiff1d(id_ar, sorted, True)[ 0 if is_sim else -1]
|
||||
sorted += [next_id]
|
||||
db_file.close()
|
||||
db_file_path.unlink()
|
||||
|
||||
for i in io.progress_bar_generator ( range(img_list_len-1), "Sorting" ):
|
||||
|
||||
a = np.linalg.norm( q[i] - q[i+1:], axis=1 )
|
||||
a = i+1+np.argmin(a)
|
||||
|
||||
saved = img_list[i+1]
|
||||
img_list[i+1] = img_list[a]
|
||||
img_list[a] = saved
|
||||
|
||||
saved = q[i+1]
|
||||
q[i+1] = q[a]
|
||||
q[a] = saved
|
||||
"""
|
||||
img_list = [ (image_paths[x],) for x in sorted]
|
||||
return img_list, []
|
||||
|
||||
def final_process(input_path, img_list, trash_img_list):
|
||||
if len(trash_img_list) != 0:
|
||||
|
@ -909,8 +905,6 @@ def final_process(input_path, img_list, trash_img_list):
|
|||
except:
|
||||
io.log_info ('fail to rename %s' % (src.name) )
|
||||
|
||||
|
||||
|
||||
def main (input_path, sort_by_method):
|
||||
input_path = Path(input_path)
|
||||
sort_by_method = sort_by_method.lower()
|
||||
|
@ -932,6 +926,7 @@ def main (input_path, sort_by_method):
|
|||
elif sort_by_method == 'origname': img_list, trash_img_list = sort_by_origname (input_path)
|
||||
elif sort_by_method == 'oneface': img_list, trash_img_list = sort_by_oneface_in_image (input_path)
|
||||
elif sort_by_method == 'vggface': img_list, trash_img_list = sort_by_vggface (input_path)
|
||||
elif sort_by_method == 'absdiff': img_list, trash_img_list = sort_by_absdiff (input_path)
|
||||
elif sort_by_method == 'final': img_list, trash_img_list = sort_final (input_path)
|
||||
elif sort_by_method == 'final-no-blur': img_list, trash_img_list = sort_final (input_path, include_by_blur=False)
|
||||
|
||||
|
|
|
@ -1,22 +1,84 @@
|
|||
import cv2
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from utils import Path_utils
|
||||
from utils.DFLPNG import DFLPNG
|
||||
from utils.DFLJPG import DFLJPG
|
||||
from utils.cv2_utils import *
|
||||
|
||||
import cv2
|
||||
|
||||
from DFLIMG import *
|
||||
from facelib import LandmarksProcessor
|
||||
from imagelib import IEPolys
|
||||
from interact import interact as io
|
||||
from utils import Path_utils
|
||||
from utils.cv2_utils import *
|
||||
|
||||
|
||||
def save_faceset_metadata_folder(input_path):
|
||||
input_path = Path(input_path)
|
||||
|
||||
metadata_filepath = input_path / 'meta.dat'
|
||||
|
||||
io.log_info (f"Saving metadata to {str(metadata_filepath)}\r\n")
|
||||
|
||||
d = {}
|
||||
for filepath in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Processing"):
|
||||
filepath = Path(filepath)
|
||||
dflimg = DFLIMG.load (filepath)
|
||||
|
||||
dfl_dict = dflimg.getDFLDictData()
|
||||
d[filepath.name] = ( dflimg.get_shape(), dfl_dict )
|
||||
|
||||
try:
|
||||
with open(metadata_filepath, "wb") as f:
|
||||
f.write ( pickle.dumps(d) )
|
||||
except:
|
||||
raise Exception( 'cannot save %s' % (filename) )
|
||||
|
||||
io.log_info("Now you can edit images.")
|
||||
io.log_info("!!! Keep same filenames in the folder.")
|
||||
io.log_info("You can change size of images, restoring process will downscale back to original size.")
|
||||
io.log_info("After that, use restore metadata.")
|
||||
|
||||
def restore_faceset_metadata_folder(input_path):
|
||||
input_path = Path(input_path)
|
||||
|
||||
metadata_filepath = input_path / 'meta.dat'
|
||||
io.log_info (f"Restoring metadata from {str(metadata_filepath)}.\r\n")
|
||||
|
||||
if not metadata_filepath.exists():
|
||||
io.log_err(f"Unable to find {str(metadata_filepath)}.")
|
||||
|
||||
try:
|
||||
with open(metadata_filepath, "rb") as f:
|
||||
d = pickle.loads(f.read())
|
||||
except:
|
||||
raise FileNotFoundError(filename)
|
||||
|
||||
for filepath in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Processing"):
|
||||
filepath = Path(filepath)
|
||||
|
||||
shape, dfl_dict = d.get(filepath.name, None)
|
||||
|
||||
img = cv2_imread (str(filepath))
|
||||
if img.shape != shape:
|
||||
img = cv2.resize (img, (shape[1], shape[0]), cv2.INTER_LANCZOS4 )
|
||||
|
||||
if filepath.suffix == '.png':
|
||||
cv2_imwrite (str(filepath), img)
|
||||
elif filepath.suffix == '.jpg':
|
||||
cv2_imwrite (str(filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] )
|
||||
|
||||
if filepath.suffix == '.png':
|
||||
DFLPNG.embed_dfldict( str(filepath), dfl_dict )
|
||||
elif filepath.suffix == '.jpg':
|
||||
DFLJPG.embed_dfldict( str(filepath), dfl_dict )
|
||||
else:
|
||||
continue
|
||||
|
||||
metadata_filepath.unlink()
|
||||
|
||||
def remove_ie_polys_file (filepath):
|
||||
filepath = Path(filepath)
|
||||
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
return
|
||||
|
||||
dflimg = DFLIMG.load (filepath)
|
||||
if dflimg is None:
|
||||
io.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
return
|
||||
|
@ -37,12 +99,7 @@ def remove_ie_polys_folder(input_path):
|
|||
def remove_fanseg_file (filepath):
|
||||
filepath = Path(filepath)
|
||||
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
return
|
||||
dflimg = DFLIMG.load (filepath)
|
||||
|
||||
if dflimg is None:
|
||||
io.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
|
@ -69,14 +126,14 @@ def convert_png_to_jpg_file (filepath):
|
|||
|
||||
dflpng = DFLPNG.load (str(filepath) )
|
||||
if dflpng is None:
|
||||
io.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
io.log_err ("%s is not a dfl png image file" % (filepath.name) )
|
||||
return
|
||||
|
||||
dfl_dict = dflpng.getDFLDictData()
|
||||
|
||||
img = cv2_imread (str(filepath))
|
||||
new_filepath = str(filepath.parent / (filepath.stem + '.jpg'))
|
||||
cv2_imwrite ( new_filepath, img, [int(cv2.IMWRITE_JPEG_QUALITY), 85])
|
||||
cv2_imwrite ( new_filepath, img, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
|
||||
|
||||
DFLJPG.embed_data( new_filepath,
|
||||
face_type=dfl_dict.get('face_type', None),
|
||||
|
@ -105,12 +162,7 @@ def add_landmarks_debug_images(input_path):
|
|||
|
||||
img = cv2_imread(str(filepath))
|
||||
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
dflimg = DFLIMG.load (filepath)
|
||||
|
||||
if dflimg is None:
|
||||
io.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
|
@ -118,7 +170,7 @@ def add_landmarks_debug_images(input_path):
|
|||
|
||||
if img is not None:
|
||||
face_landmarks = dflimg.get_landmarks()
|
||||
LandmarksProcessor.draw_landmarks(img, face_landmarks, transparent_mask=True, ie_polys=dflimg.get_ie_polys() )
|
||||
LandmarksProcessor.draw_landmarks(img, face_landmarks, transparent_mask=True, ie_polys=IEPolys.load(dflimg.get_ie_polys()) )
|
||||
|
||||
output_file = '{}{}'.format( str(Path(str(input_path)) / filepath.stem), '_debug.jpg')
|
||||
cv2_imwrite(output_file, img, [int(cv2.IMWRITE_JPEG_QUALITY), 50] )
|
||||
|
@ -130,12 +182,7 @@ def recover_original_aligned_filename(input_path):
|
|||
for filepath in io.progress_bar_generator( Path_utils.get_image_paths(input_path), "Processing"):
|
||||
filepath = Path(filepath)
|
||||
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
dflimg = DFLIMG.load (filepath)
|
||||
|
||||
if dflimg is None:
|
||||
io.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
|
|
|
@ -5,13 +5,12 @@ from pathlib import Path
|
|||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from DFLIMG import *
|
||||
from facelib import FaceType, LandmarksProcessor
|
||||
from interact import interact as io
|
||||
from joblib import Subprocessor
|
||||
from utils import Path_utils
|
||||
from utils.cv2_utils import *
|
||||
from utils.DFLJPG import DFLJPG
|
||||
from utils.DFLPNG import DFLPNG
|
||||
|
||||
from . import Extractor, Sorter
|
||||
from .Extractor import ExtractSubprocessor
|
||||
|
@ -227,13 +226,8 @@ class CelebAMASKHQSubprocessor(Subprocessor):
|
|||
#override
|
||||
def process_data(self, data):
|
||||
filename = data[0]
|
||||
filepath = Path(filename)
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
|
||||
dflimg = DFLIMG.load(Path(filename))
|
||||
|
||||
image_to_face_mat = dflimg.get_image_to_face_mat()
|
||||
src_filename = dflimg.get_source_filename()
|
||||
|
@ -330,12 +324,7 @@ def apply_celebamaskhq(input_dir ):
|
|||
paths_to_extract = []
|
||||
for filename in io.progress_bar_generator(Path_utils.get_image_paths(img_path), desc="Processing"):
|
||||
filepath = Path(filename)
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
dflimg = DFLIMG.load(filepath)
|
||||
|
||||
if dflimg is not None:
|
||||
paths_to_extract.append (filepath)
|
||||
|
@ -379,3 +368,126 @@ def apply_celebamaskhq(input_dir ):
|
|||
|
||||
#import code
|
||||
#code.interact(local=dict(globals(), **locals()))
|
||||
|
||||
|
||||
|
||||
#unused in end user workflow
|
||||
def extract_fanseg(input_dir, device_args={} ):
|
||||
multi_gpu = device_args.get('multi_gpu', False)
|
||||
cpu_only = device_args.get('cpu_only', False)
|
||||
|
||||
input_path = Path(input_dir)
|
||||
if not input_path.exists():
|
||||
raise ValueError('Input directory not found. Please ensure it exists.')
|
||||
|
||||
paths_to_extract = []
|
||||
for filename in Path_utils.get_image_paths(input_path) :
|
||||
filepath = Path(filename)
|
||||
dflimg = DFLIMG.load ( filepath )
|
||||
if dflimg is not None:
|
||||
paths_to_extract.append (filepath)
|
||||
|
||||
paths_to_extract_len = len(paths_to_extract)
|
||||
if paths_to_extract_len > 0:
|
||||
io.log_info ("Performing extract fanseg for %d files..." % (paths_to_extract_len) )
|
||||
data = ExtractSubprocessor ([ ExtractSubprocessor.Data(filename) for filename in paths_to_extract ], 'fanseg', multi_gpu=multi_gpu, cpu_only=cpu_only).run()
|
||||
|
||||
#unused in end user workflow
|
||||
def extract_umd_csv(input_file_csv,
|
||||
image_size=256,
|
||||
face_type='full_face',
|
||||
device_args={} ):
|
||||
|
||||
#extract faces from umdfaces.io dataset csv file with pitch,yaw,roll info.
|
||||
multi_gpu = device_args.get('multi_gpu', False)
|
||||
cpu_only = device_args.get('cpu_only', False)
|
||||
face_type = FaceType.fromString(face_type)
|
||||
|
||||
input_file_csv_path = Path(input_file_csv)
|
||||
if not input_file_csv_path.exists():
|
||||
raise ValueError('input_file_csv not found. Please ensure it exists.')
|
||||
|
||||
input_file_csv_root_path = input_file_csv_path.parent
|
||||
output_path = input_file_csv_path.parent / ('aligned_' + input_file_csv_path.name)
|
||||
|
||||
io.log_info("Output dir is %s." % (str(output_path)) )
|
||||
|
||||
if output_path.exists():
|
||||
output_images_paths = Path_utils.get_image_paths(output_path)
|
||||
if len(output_images_paths) > 0:
|
||||
io.input_bool("WARNING !!! \n %s contains files! \n They will be deleted. \n Press enter to continue." % (str(output_path)), False )
|
||||
for filename in output_images_paths:
|
||||
Path(filename).unlink()
|
||||
else:
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
with open( str(input_file_csv_path), 'r') as f:
|
||||
csv_file = f.read()
|
||||
except Exception as e:
|
||||
io.log_err("Unable to open or read file " + str(input_file_csv_path) + ": " + str(e) )
|
||||
return
|
||||
|
||||
strings = csv_file.split('\n')
|
||||
keys = strings[0].split(',')
|
||||
keys_len = len(keys)
|
||||
csv_data = []
|
||||
for i in range(1, len(strings)):
|
||||
values = strings[i].split(',')
|
||||
if keys_len != len(values):
|
||||
io.log_err("Wrong string in csv file, skipping.")
|
||||
continue
|
||||
|
||||
csv_data += [ { keys[n] : values[n] for n in range(keys_len) } ]
|
||||
|
||||
data = []
|
||||
for d in csv_data:
|
||||
filename = input_file_csv_root_path / d['FILE']
|
||||
|
||||
#pitch, yaw, roll = float(d['PITCH']), float(d['YAW']), float(d['ROLL'])
|
||||
#if pitch < -90 or pitch > 90 or yaw < -90 or yaw > 90 or roll < -90 or roll > 90:
|
||||
# continue
|
||||
#
|
||||
#pitch_yaw_roll = pitch/90.0, yaw/90.0, roll/90.0
|
||||
|
||||
x,y,w,h = float(d['FACE_X']), float(d['FACE_Y']), float(d['FACE_WIDTH']), float(d['FACE_HEIGHT'])
|
||||
|
||||
data += [ ExtractSubprocessor.Data(filename=filename, rects=[ [x,y,x+w,y+h] ]) ]
|
||||
|
||||
images_found = len(data)
|
||||
faces_detected = 0
|
||||
if len(data) > 0:
|
||||
io.log_info ("Performing 2nd pass from csv file...")
|
||||
data = ExtractSubprocessor (data, 'landmarks', multi_gpu=multi_gpu, cpu_only=cpu_only).run()
|
||||
|
||||
io.log_info ('Performing 3rd pass...')
|
||||
data = ExtractSubprocessor (data, 'final', image_size, face_type, None, multi_gpu=multi_gpu, cpu_only=cpu_only, manual=False, final_output_path=output_path).run()
|
||||
faces_detected += sum([d.faces_detected for d in data])
|
||||
|
||||
|
||||
io.log_info ('-------------------------')
|
||||
io.log_info ('Images found: %d' % (images_found) )
|
||||
io.log_info ('Faces detected: %d' % (faces_detected) )
|
||||
io.log_info ('-------------------------')
|
||||
|
||||
def dev_test(input_dir):
|
||||
input_path = Path(input_dir)
|
||||
|
||||
dir_names = Path_utils.get_all_dir_names(input_path)
|
||||
|
||||
for dir_name in io.progress_bar_generator(dir_names, desc="Processing"):
|
||||
|
||||
img_paths = Path_utils.get_image_paths (input_path / dir_name)
|
||||
for filename in img_paths:
|
||||
filepath = Path(filename)
|
||||
|
||||
dflimg = DFLIMG.load (filepath)
|
||||
if dflimg is None:
|
||||
raise ValueError
|
||||
|
||||
dflimg.embed_and_set(filename, person_name=dir_name)
|
||||
|
||||
#import code
|
||||
#code.interact(local=dict(globals(), **locals()))
|
||||
|
||||
|
|
@ -28,12 +28,10 @@ class ModelBase(object):
|
|||
ask_write_preview_history=True,
|
||||
ask_target_iter=True,
|
||||
ask_batch_size=True,
|
||||
ask_sort_by_yaw=True,
|
||||
ask_random_flip=True,
|
||||
ask_src_scale_mod=True):
|
||||
ask_random_flip=True, **kwargs):
|
||||
|
||||
device_args['force_gpu_idx'] = device_args.get('force_gpu_idx',-1)
|
||||
device_args['cpu_only'] = device_args.get('cpu_only',False)
|
||||
device_args['cpu_only'] = True if debug else device_args.get('cpu_only',False)
|
||||
|
||||
if device_args['force_gpu_idx'] == -1 and not device_args['cpu_only']:
|
||||
idxs_names_list = nnlib.device.getValidDevicesIdxsWithNamesList()
|
||||
|
@ -115,13 +113,6 @@ class ModelBase(object):
|
|||
else:
|
||||
self.batch_size = self.options.get('batch_size', 0)
|
||||
|
||||
if ask_sort_by_yaw:
|
||||
if (self.iter == 0 or ask_override):
|
||||
default_sort_by_yaw = self.options.get('sort_by_yaw', False)
|
||||
self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:%s) : " % (yn_str[default_sort_by_yaw]), default_sort_by_yaw, help_message="NN will not learn src face directions that don't match dst face directions. Do not use if the dst face has hair that covers the jaw." )
|
||||
else:
|
||||
self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False)
|
||||
|
||||
if ask_random_flip:
|
||||
default_random_flip = self.options.get('random_flip', True)
|
||||
if (self.iter == 0 or ask_override):
|
||||
|
@ -129,12 +120,6 @@ class ModelBase(object):
|
|||
else:
|
||||
self.options['random_flip'] = self.options.get('random_flip', default_random_flip)
|
||||
|
||||
if ask_src_scale_mod:
|
||||
if (self.iter == 0):
|
||||
self.options['src_scale_mod'] = np.clip( io.input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
|
||||
else:
|
||||
self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
|
||||
|
||||
self.autobackup = self.options.get('autobackup', False)
|
||||
if not self.autobackup and 'autobackup' in self.options:
|
||||
self.options.pop('autobackup')
|
||||
|
@ -151,10 +136,6 @@ class ModelBase(object):
|
|||
self.sort_by_yaw = self.options.get('sort_by_yaw',False)
|
||||
self.random_flip = self.options.get('random_flip',True)
|
||||
|
||||
self.src_scale_mod = self.options.get('src_scale_mod',0)
|
||||
if self.src_scale_mod == 0 and 'src_scale_mod' in self.options:
|
||||
self.options.pop('src_scale_mod')
|
||||
|
||||
self.onInitializeOptions(self.iter == 0, ask_override)
|
||||
|
||||
nnlib.import_all(self.device_config)
|
||||
|
|
|
@ -16,9 +16,7 @@ class AVATARModel(ModelBase):
|
|||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs,
|
||||
ask_sort_by_yaw=False,
|
||||
ask_random_flip=False,
|
||||
ask_src_scale_mod=False)
|
||||
ask_random_flip=False)
|
||||
|
||||
#override
|
||||
def onInitializeOptions(self, is_first_run, ask_override):
|
||||
|
|
|
@ -13,9 +13,7 @@ class Model(ModelBase):
|
|||
ask_enable_autobackup=False,
|
||||
ask_write_preview_history=False,
|
||||
ask_target_iter=False,
|
||||
ask_sort_by_yaw=False,
|
||||
ask_random_flip=False,
|
||||
ask_src_scale_mod=False)
|
||||
ask_random_flip=False)
|
||||
|
||||
#override
|
||||
def onInitializeOptions(self, is_first_run, ask_override):
|
||||
|
@ -47,13 +45,13 @@ class Model(ModelBase):
|
|||
self.set_training_data_generators ([
|
||||
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=True),
|
||||
output_sample_types=[ { 'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR), 'resolution' : self.resolution, 'motion_blur':(25, 5), 'gaussian_blur':(25,5), 'border_replicate':False, 'random_hsv_shift' : True },
|
||||
output_sample_types=[ { 'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR_RANDOM_HSV_SHIFT), 'resolution' : self.resolution, 'motion_blur':(25, 5), 'gaussian_blur':(25,5), 'border_replicate':False},
|
||||
{ 'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_M), 'resolution': self.resolution },
|
||||
]),
|
||||
|
||||
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=True ),
|
||||
output_sample_types=[ { 'types': (t.IMG_TRANSFORMED , face_type, t.MODE_BGR), 'resolution' : self.resolution, 'random_hsv_shift' : True},
|
||||
output_sample_types=[ { 'types': (t.IMG_TRANSFORMED , face_type, t.MODE_BGR_RANDOM_HSV_SHIFT), 'resolution' : self.resolution},
|
||||
])
|
||||
])
|
||||
|
||||
|
|
|
@ -16,16 +16,14 @@ class FUNITModel(ModelBase):
|
|||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs,
|
||||
ask_sort_by_yaw=False,
|
||||
ask_random_flip=False,
|
||||
ask_src_scale_mod=False)
|
||||
ask_random_flip=False)
|
||||
|
||||
#override
|
||||
def onInitializeOptions(self, is_first_run, ask_override):
|
||||
|
||||
default_resolution = 96
|
||||
default_resolution = 64
|
||||
if is_first_run:
|
||||
self.options['resolution'] = io.input_int(f"Resolution ( 96,128,224 ?:help skip:{default_resolution}) : ", default_resolution, [128,224])
|
||||
self.options['resolution'] = io.input_int(f"Resolution ( 64,96,128,224 ?:help skip:{default_resolution}) : ", default_resolution, [64,96,128,224])
|
||||
else:
|
||||
self.options['resolution'] = self.options.get('resolution', default_resolution)
|
||||
|
||||
|
@ -48,7 +46,7 @@ class FUNITModel(ModelBase):
|
|||
|
||||
resolution = self.options['resolution']
|
||||
face_type = FaceType.FULL if self.options['face_type'] == 'f' else FaceType.HALF
|
||||
person_id_max_count = SampleGeneratorFace.get_person_id_max_count(self.training_data_src_path)
|
||||
person_id_max_count = SampleGeneratorFacePerson.get_person_id_max_count(self.training_data_src_path)
|
||||
|
||||
|
||||
self.model = FUNIT( face_type_str=FaceType.toString(face_type),
|
||||
|
@ -85,21 +83,21 @@ class FUNITModel(ModelBase):
|
|||
output_sample_types1=[ {'types': (t.IMG_SOURCE, face_type, t.MODE_BGR), 'resolution':resolution, 'normalize_tanh':True} ]
|
||||
|
||||
self.set_training_data_generators ([
|
||||
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
SampleGeneratorFacePerson(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=True, rotation_range=[0,0] ),
|
||||
output_sample_types=output_sample_types, person_id_mode=True ),
|
||||
output_sample_types=output_sample_types, person_id_mode=1, ),
|
||||
|
||||
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
SampleGeneratorFacePerson(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=True, rotation_range=[0,0] ),
|
||||
output_sample_types=output_sample_types, person_id_mode=True ),
|
||||
output_sample_types=output_sample_types, person_id_mode=1, ),
|
||||
|
||||
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
SampleGeneratorFacePerson(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=True, rotation_range=[0,0]),
|
||||
output_sample_types=output_sample_types1, person_id_mode=True ),
|
||||
output_sample_types=output_sample_types1, person_id_mode=1, ),
|
||||
|
||||
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
SampleGeneratorFacePerson(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=True, rotation_range=[0,0]),
|
||||
output_sample_types=output_sample_types1, person_id_mode=True ),
|
||||
output_sample_types=output_sample_types1, person_id_mode=1, ),
|
||||
])
|
||||
|
||||
#override
|
||||
|
|
|
@ -15,9 +15,7 @@ class Model(ModelBase):
|
|||
ask_enable_autobackup=False,
|
||||
ask_write_preview_history=False,
|
||||
ask_target_iter=False,
|
||||
ask_sort_by_yaw=False,
|
||||
ask_random_flip=False,
|
||||
ask_src_scale_mod=False)
|
||||
ask_random_flip=False)
|
||||
|
||||
#override
|
||||
def onInitializeOptions(self, is_first_run, ask_override):
|
||||
|
|
|
@ -50,9 +50,8 @@ class Model(ModelBase):
|
|||
{ 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_M), 'resolution':128} ]
|
||||
|
||||
self.set_training_data_generators ([
|
||||
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
|
||||
debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
|
||||
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05]) ),
|
||||
output_sample_types=output_sample_types),
|
||||
|
||||
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
|
|
|
@ -60,9 +60,8 @@ class Model(ModelBase):
|
|||
{ 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_HALF, t.MODE_M), 'resolution':128} ]
|
||||
|
||||
self.set_training_data_generators ([
|
||||
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
|
||||
debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
|
||||
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05]) ),
|
||||
output_sample_types=output_sample_types ),
|
||||
|
||||
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
|
|
|
@ -61,9 +61,8 @@ class Model(ModelBase):
|
|||
{ 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_HALF, t.MODE_M), 'resolution':64} ]
|
||||
|
||||
self.set_training_data_generators ([
|
||||
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
|
||||
debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
|
||||
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05]) ),
|
||||
output_sample_types=output_sample_types),
|
||||
|
||||
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
|
|
|
@ -55,9 +55,8 @@ class Model(ModelBase):
|
|||
{ 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_M), 'resolution':128} ]
|
||||
|
||||
self.set_training_data_generators ([
|
||||
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
|
||||
debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
|
||||
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05]) ),
|
||||
output_sample_types=output_sample_types),
|
||||
|
||||
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
|
|
|
@ -16,16 +16,14 @@ class Quick96Model(ModelBase):
|
|||
super().__init__(*args, **kwargs,
|
||||
ask_enable_autobackup=False,
|
||||
ask_write_preview_history=False,
|
||||
ask_target_iter=False,
|
||||
ask_target_iter=True,
|
||||
ask_batch_size=False,
|
||||
ask_sort_by_yaw=False,
|
||||
ask_random_flip=False,
|
||||
ask_src_scale_mod=False)
|
||||
ask_random_flip=False)
|
||||
|
||||
#override
|
||||
def onInitialize(self):
|
||||
exec(nnlib.import_all(), locals(), globals())
|
||||
self.set_vram_batch_requirements({1.5:2,2:4})#,3:4,4:8})
|
||||
self.set_vram_batch_requirements({1.5:2,2:4})
|
||||
|
||||
resolution = self.resolution = 96
|
||||
|
||||
|
@ -94,8 +92,8 @@ class Quick96Model(ModelBase):
|
|||
y = self.upscale(d_dims//2)(y)
|
||||
y = self.upscale(d_dims//4)(y)
|
||||
|
||||
return Conv2D(3, kernel_size=5, padding='same', activation='tanh')(x), \
|
||||
Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(y)
|
||||
return Conv2D(3, kernel_size=1, padding='same', activation='tanh')(x), \
|
||||
Conv2D(1, kernel_size=1, padding='same', activation='sigmoid')(y)
|
||||
|
||||
return func
|
||||
|
||||
|
@ -143,8 +141,8 @@ class Quick96Model(ModelBase):
|
|||
self.CA_conv_weights_list += [layer.weights[0]] #- is Conv2D kernel_weights
|
||||
|
||||
if self.is_training_mode:
|
||||
self.src_dst_opt = RMSprop(lr=2e-4)
|
||||
self.src_dst_mask_opt = RMSprop(lr=2e-4)
|
||||
self.src_dst_opt = RMSprop(lr=2e-4, lr_dropout=0.3)
|
||||
self.src_dst_mask_opt = RMSprop(lr=2e-4, lr_dropout=0.3)
|
||||
|
||||
target_src_masked = self.model.target_src*self.model.target_srcm
|
||||
target_dst_masked = self.model.target_dst*self.model.target_dstm
|
||||
|
@ -171,7 +169,7 @@ class Quick96Model(ModelBase):
|
|||
|
||||
self.set_training_data_generators ([
|
||||
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=False, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
|
||||
sample_process_options=SampleProcessor.Options(random_flip=False, scale_range=np.array([-0.05, 0.05]) ),
|
||||
output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_BGR), 'resolution': resolution, 'normalize_tanh':True },
|
||||
{'types' : (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_BGR), 'resolution': resolution, 'normalize_tanh':True },
|
||||
{'types' : (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_M), 'resolution': resolution } ]
|
||||
|
|
|
@ -466,18 +466,15 @@ class SAEModel(ModelBase):
|
|||
|
||||
training_data_src_path = self.training_data_src_path
|
||||
training_data_dst_path = self.training_data_dst_path
|
||||
sort_by_yaw = self.sort_by_yaw
|
||||
|
||||
if self.pretrain and self.pretraining_data_path is not None:
|
||||
training_data_src_path = self.pretraining_data_path
|
||||
training_data_dst_path = self.pretraining_data_path
|
||||
sort_by_yaw = False
|
||||
|
||||
self.set_training_data_generators ([
|
||||
SampleGeneratorFace(training_data_src_path, sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
|
||||
random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' else None,
|
||||
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' else None,
|
||||
debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05]) ),
|
||||
output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution, 'ct_mode': self.options['ct_mode'] },
|
||||
{'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution, 'ct_mode': self.options['ct_mode'] },
|
||||
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution } ]
|
||||
|
|
|
@ -65,6 +65,9 @@ class SAEHDModel(ModelBase):
|
|||
default_bg_style_power = self.options.get('bg_style_power', 0.0)
|
||||
|
||||
if is_first_run or ask_override:
|
||||
default_lr_dropout = self.options.get('lr_dropout', False)
|
||||
self.options['lr_dropout'] = io.input_bool ( f"Use learning rate dropout? (y/n, ?:help skip:{yn_str[default_lr_dropout]} ) : ", default_lr_dropout, help_message="When the face is trained enough, you can enable this option to get extra sharpness for less amount of iterations.")
|
||||
|
||||
default_random_warp = self.options.get('random_warp', True)
|
||||
self.options['random_warp'] = io.input_bool (f"Enable random warp of samples? ( y/n, ?:help skip:{yn_str[default_random_warp]}) : ", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness for less amount of iterations.")
|
||||
|
||||
|
@ -84,8 +87,8 @@ class SAEHDModel(ModelBase):
|
|||
self.options['clipgrad'] = io.input_bool (f"Enable gradient clipping? (y/n, ?:help skip:{yn_str[default_clipgrad]}) : ", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
|
||||
else:
|
||||
self.options['clipgrad'] = False
|
||||
|
||||
else:
|
||||
self.options['lr_dropout'] = self.options.get('lr_dropout', False)
|
||||
self.options['random_warp'] = self.options.get('random_warp', True)
|
||||
self.options['true_face_training'] = self.options.get('true_face_training', default_true_face_training)
|
||||
self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
|
||||
|
@ -195,8 +198,6 @@ class SAEHDModel(ModelBase):
|
|||
if dims % 2 != 0:
|
||||
dims += 1
|
||||
|
||||
|
||||
|
||||
def func(x):
|
||||
|
||||
for i in [8,4,2]:
|
||||
|
@ -210,7 +211,7 @@ class SAEHDModel(ModelBase):
|
|||
x = Add()([x, x0])
|
||||
x = LeakyReLU(0.2)(x)
|
||||
|
||||
return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid')(x)
|
||||
return Conv2D(output_nc, kernel_size=1, padding='same', activation='sigmoid')(x)
|
||||
|
||||
return func
|
||||
|
||||
|
@ -324,7 +325,7 @@ class SAEHDModel(ModelBase):
|
|||
x = Add()([x, x0])
|
||||
x = LeakyReLU(0.2)(x)
|
||||
|
||||
return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid')(x)
|
||||
return Conv2D(output_nc, kernel_size=1, padding='same', activation='sigmoid')(x)
|
||||
|
||||
return func
|
||||
|
||||
|
@ -452,9 +453,11 @@ class SAEHDModel(ModelBase):
|
|||
psd_target_dst_anti_masked = self.model.pred_src_dst*(1.0 - target_dstm)
|
||||
|
||||
if self.is_training_mode:
|
||||
self.src_dst_opt = RMSprop(lr=1e-5, lr_dropout=0.3, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
|
||||
self.src_dst_mask_opt = RMSprop(lr=1e-5, lr_dropout=0.3, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
|
||||
self.D_opt = RMSprop(lr=1e-5, lr_dropout=0.3, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
|
||||
|
||||
lr_dropout = 0.3 if self.options['lr_dropout'] else 0.0
|
||||
self.src_dst_opt = RMSprop(lr=5e-5, lr_dropout=lr_dropout, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
|
||||
self.src_dst_mask_opt = RMSprop(lr=5e-5, lr_dropout=lr_dropout, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
|
||||
self.D_opt = RMSprop(lr=5e-5, lr_dropout=lr_dropout, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
|
||||
|
||||
src_loss = K.mean ( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_opt, pred_src_src_masked_opt) )
|
||||
src_loss += K.mean ( 10*K.square( target_src_masked_opt - pred_src_src_masked_opt ) )
|
||||
|
@ -525,20 +528,17 @@ class SAEHDModel(ModelBase):
|
|||
|
||||
training_data_src_path = self.training_data_src_path
|
||||
training_data_dst_path = self.training_data_dst_path
|
||||
sort_by_yaw = self.sort_by_yaw
|
||||
|
||||
if self.pretrain and self.pretraining_data_path is not None:
|
||||
training_data_src_path = self.pretraining_data_path
|
||||
training_data_dst_path = self.pretraining_data_path
|
||||
sort_by_yaw = False
|
||||
|
||||
t_img_warped = t.IMG_WARPED_TRANSFORMED if self.options['random_warp'] else t.IMG_TRANSFORMED
|
||||
|
||||
self.set_training_data_generators ([
|
||||
SampleGeneratorFace(training_data_src_path, sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
|
||||
random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' else None,
|
||||
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' else None,
|
||||
debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05]) ),
|
||||
output_sample_types = [ {'types' : (t_img_warped, face_type, t_mode_bgr), 'resolution':resolution, 'ct_mode': self.options['ct_mode'] },
|
||||
{'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution, 'ct_mode': self.options['ct_mode'] },
|
||||
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution } ]
|
||||
|
|
209
nnlib/nnlib.py
209
nnlib/nnlib.py
|
@ -28,6 +28,7 @@ class nnlib(object):
|
|||
|
||||
tf = None
|
||||
tf_sess = None
|
||||
tf_sess_config = None
|
||||
|
||||
PML = None
|
||||
PMLK = None
|
||||
|
@ -69,6 +70,7 @@ PixelNormalization = nnlib.PixelNormalization
|
|||
Activation = KL.Activation
|
||||
LeakyReLU = KL.LeakyReLU
|
||||
ELU = KL.ELU
|
||||
GeLU = nnlib.GeLU
|
||||
ReLU = KL.ReLU
|
||||
PReLU = KL.PReLU
|
||||
tanh = KL.Activation('tanh')
|
||||
|
@ -93,6 +95,7 @@ Model = keras.models.Model
|
|||
Adam = nnlib.Adam
|
||||
RMSprop = nnlib.RMSprop
|
||||
LookaheadOptimizer = nnlib.LookaheadOptimizer
|
||||
SGD = nnlib.keras.optimizers.SGD
|
||||
|
||||
modelify = nnlib.modelify
|
||||
gaussian_blur = nnlib.gaussian_blur
|
||||
|
@ -104,6 +107,7 @@ PixelShuffler = nnlib.PixelShuffler
|
|||
SubpixelUpscaler = nnlib.SubpixelUpscaler
|
||||
SubpixelDownscaler = nnlib.SubpixelDownscaler
|
||||
Scale = nnlib.Scale
|
||||
BilinearInterpolation = nnlib.BilinearInterpolation
|
||||
BlurPool = nnlib.BlurPool
|
||||
FUNITAdain = nnlib.FUNITAdain
|
||||
SelfAttention = nnlib.SelfAttention
|
||||
|
@ -191,6 +195,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
|
||||
config.gpu_options.force_gpu_compatible = True
|
||||
config.gpu_options.allow_growth = device_config.allow_growth
|
||||
nnlib.tf_sess_config = config
|
||||
|
||||
nnlib.tf_sess = tf.Session(config=config)
|
||||
|
||||
|
@ -710,6 +715,141 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
return dict(list(base_config.items()) + list(config.items()))
|
||||
nnlib.Scale = Scale
|
||||
|
||||
|
||||
"""
|
||||
unable to work in plaidML, due to unimplemented ops
|
||||
|
||||
class BilinearInterpolation(KL.Layer):
|
||||
def __init__(self, size=(2,2), **kwargs):
|
||||
self.size = size
|
||||
super(BilinearInterpolation, self).__init__(**kwargs)
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return (input_shape[0], input_shape[1]*self.size[1], input_shape[2]*self.size[0], input_shape[3])
|
||||
|
||||
|
||||
def call(self, X):
|
||||
_,h,w,_ = K.int_shape(X)
|
||||
|
||||
X = K.concatenate( [ X, X[:,:,-2:-1,:] ],axis=2 )
|
||||
X = K.concatenate( [ X, X[:,:,-2:-1,:] ],axis=2 )
|
||||
X = K.concatenate( [ X, X[:,-2:-1,:,:] ],axis=1 )
|
||||
X = K.concatenate( [ X, X[:,-2:-1,:,:] ],axis=1 )
|
||||
|
||||
X_sh = K.shape(X)
|
||||
batch_size, height, width, num_channels = X_sh[0], X_sh[1], X_sh[2], X_sh[3]
|
||||
|
||||
output_h, output_w = (h*self.size[1]+4, w*self.size[0]+4)
|
||||
|
||||
x_linspace = np.linspace(-1. , 1. - 2/output_w, output_w)#
|
||||
y_linspace = np.linspace(-1. , 1. - 2/output_h, output_h)#
|
||||
|
||||
x_coordinates, y_coordinates = np.meshgrid(x_linspace, y_linspace)
|
||||
x_coordinates = K.flatten(K.constant(x_coordinates, dtype=K.floatx() ))
|
||||
y_coordinates = K.flatten(K.constant(y_coordinates, dtype=K.floatx() ))
|
||||
|
||||
grid = K.concatenate([x_coordinates, y_coordinates, K.ones_like(x_coordinates)], 0)
|
||||
grid = K.flatten(grid)
|
||||
|
||||
|
||||
grids = K.tile(grid, ( batch_size, ) )
|
||||
grids = K.reshape(grids, (batch_size, 3, output_h * output_w ))
|
||||
|
||||
|
||||
x = K.cast(K.flatten(grids[:, 0:1, :]), dtype='float32')
|
||||
y = K.cast(K.flatten(grids[:, 1:2, :]), dtype='float32')
|
||||
x = .5 * (x + 1.0) * K.cast(width, dtype='float32')
|
||||
y = .5 * (y + 1.0) * K.cast(height, dtype='float32')
|
||||
x0 = K.cast(x, 'int32')
|
||||
x1 = x0 + 1
|
||||
y0 = K.cast(y, 'int32')
|
||||
y1 = y0 + 1
|
||||
max_x = int(K.int_shape(X)[2] -1)
|
||||
max_y = int(K.int_shape(X)[1] -1)
|
||||
|
||||
x0 = K.clip(x0, 0, max_x)
|
||||
x1 = K.clip(x1, 0, max_x)
|
||||
y0 = K.clip(y0, 0, max_y)
|
||||
y1 = K.clip(y1, 0, max_y)
|
||||
|
||||
|
||||
pixels_batch = K.constant ( np.arange(0, batch_size) * (height * width), dtype=K.floatx() )
|
||||
|
||||
pixels_batch = K.expand_dims(pixels_batch, axis=-1)
|
||||
|
||||
base = K.tile(pixels_batch, (1, output_h * output_w ) )
|
||||
base = K.flatten(base)
|
||||
|
||||
# base_y0 = base + (y0 * width)
|
||||
base_y0 = y0 * width
|
||||
base_y0 = base + base_y0
|
||||
# base_y1 = base + (y1 * width)
|
||||
base_y1 = y1 * width
|
||||
base_y1 = base_y1 + base
|
||||
|
||||
indices_a = base_y0 + x0
|
||||
indices_b = base_y1 + x0
|
||||
indices_c = base_y0 + x1
|
||||
indices_d = base_y1 + x1
|
||||
|
||||
flat_image = K.reshape(X, (-1, num_channels) )
|
||||
flat_image = K.cast(flat_image, dtype='float32')
|
||||
pixel_values_a = K.gather(flat_image, indices_a)
|
||||
pixel_values_b = K.gather(flat_image, indices_b)
|
||||
pixel_values_c = K.gather(flat_image, indices_c)
|
||||
pixel_values_d = K.gather(flat_image, indices_d)
|
||||
|
||||
x0 = K.cast(x0, 'float32')
|
||||
x1 = K.cast(x1, 'float32')
|
||||
y0 = K.cast(y0, 'float32')
|
||||
y1 = K.cast(y1, 'float32')
|
||||
|
||||
area_a = K.expand_dims(((x1 - x) * (y1 - y)), 1)
|
||||
area_b = K.expand_dims(((x1 - x) * (y - y0)), 1)
|
||||
area_c = K.expand_dims(((x - x0) * (y1 - y)), 1)
|
||||
area_d = K.expand_dims(((x - x0) * (y - y0)), 1)
|
||||
|
||||
values_a = area_a * pixel_values_a
|
||||
values_b = area_b * pixel_values_b
|
||||
values_c = area_c * pixel_values_c
|
||||
values_d = area_d * pixel_values_d
|
||||
interpolated_image = values_a + values_b + values_c + values_d
|
||||
|
||||
new_shape = (batch_size, output_h, output_w, num_channels)
|
||||
interpolated_image = K.reshape(interpolated_image, new_shape)
|
||||
|
||||
interpolated_image = interpolated_image[:,:-4,:-4,:]
|
||||
return interpolated_image
|
||||
|
||||
def get_config(self):
|
||||
config = {"size": self.size}
|
||||
base_config = super(BilinearInterpolation, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
"""
|
||||
class BilinearInterpolation(KL.Layer):
|
||||
def __init__(self, size=(2,2), **kwargs):
|
||||
self.size = size
|
||||
super(BilinearInterpolation, self).__init__(**kwargs)
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return (input_shape[0], input_shape[1]*self.size[1], input_shape[2]*self.size[0], input_shape[3])
|
||||
|
||||
def call(self, X):
|
||||
_,h,w,_ = K.int_shape(X)
|
||||
|
||||
return K.cast( K.tf.image.resize_images(X, (h*self.size[1],w*self.size[0]) ), K.floatx() )
|
||||
|
||||
def get_config(self):
|
||||
config = {"size": self.size}
|
||||
base_config = super(BilinearInterpolation, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
nnlib.BilinearInterpolation = BilinearInterpolation
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class SelfAttention(KL.Layer):
|
||||
def __init__(self, nc, squeeze_factor=8, **kwargs):
|
||||
assert nc//squeeze_factor > 0, f"Input channels must be >= {squeeze_factor}, recieved nc={nc}"
|
||||
|
@ -765,9 +905,10 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
2 - allows to train x3 bigger network on same VRAM consuming RAM*2 and CPU power.
|
||||
"""
|
||||
|
||||
def __init__(self, learning_rate=0.001, rho=0.9, tf_cpu_mode=0, **kwargs):
|
||||
def __init__(self, learning_rate=0.001, rho=0.9, lr_dropout=0, tf_cpu_mode=0, **kwargs):
|
||||
self.initial_decay = kwargs.pop('decay', 0.0)
|
||||
self.epsilon = kwargs.pop('epsilon', K.epsilon())
|
||||
self.lr_dropout = lr_dropout
|
||||
self.tf_cpu_mode = tf_cpu_mode
|
||||
|
||||
learning_rate = kwargs.pop('lr', learning_rate)
|
||||
|
@ -788,6 +929,8 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
dtype=K.dtype(p),
|
||||
name='accumulator_' + str(i))
|
||||
for (i, p) in enumerate(params)]
|
||||
if self.lr_dropout != 0:
|
||||
lr_rnds = [ K.random_binomial(K.int_shape(p), p=self.lr_dropout, dtype=K.dtype(p)) for p in params ]
|
||||
if e: e.__exit__(None, None, None)
|
||||
|
||||
self.weights = [self.iterations] + accumulators
|
||||
|
@ -798,12 +941,15 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
lr = lr * (1. / (1. + self.decay * K.cast(self.iterations,
|
||||
K.dtype(self.decay))))
|
||||
|
||||
for p, g, a in zip(params, grads, accumulators):
|
||||
for i, (p, g, a) in enumerate(zip(params, grads, accumulators)):
|
||||
# update accumulator
|
||||
e = K.tf.device("/cpu:0") if self.tf_cpu_mode == 2 else None
|
||||
if e: e.__enter__()
|
||||
new_a = self.rho * a + (1. - self.rho) * K.square(g)
|
||||
new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)
|
||||
p_diff = - lr * g / (K.sqrt(new_a) + self.epsilon)
|
||||
if self.lr_dropout != 0:
|
||||
p_diff *= lr_rnds[i]
|
||||
new_p = p + p_diff
|
||||
if e: e.__exit__(None, None, None)
|
||||
|
||||
self.updates.append(K.update(a, new_a))
|
||||
|
@ -828,7 +974,8 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
config = {'learning_rate': float(K.get_value(self.learning_rate)),
|
||||
'rho': float(K.get_value(self.rho)),
|
||||
'decay': float(K.get_value(self.decay)),
|
||||
'epsilon': self.epsilon}
|
||||
'epsilon': self.epsilon,
|
||||
'lr_dropout' : self.lr_dropout }
|
||||
base_config = super(RMSprop, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
nnlib.RMSprop = RMSprop
|
||||
|
@ -847,6 +994,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
amsgrad: boolean. Whether to apply the AMSGrad variant of this
|
||||
algorithm from the paper "On the Convergence of Adam and
|
||||
Beyond".
|
||||
lr_dropout: float [0.0 .. 1.0] Learning rate dropout https://arxiv.org/pdf/1912.00144
|
||||
tf_cpu_mode: only for tensorflow backend
|
||||
0 - default, no changes.
|
||||
1 - allows to train x2 bigger network on same VRAM consuming RAM
|
||||
|
@ -860,7 +1008,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
"""
|
||||
|
||||
def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
|
||||
epsilon=None, decay=0., amsgrad=False, tf_cpu_mode=0, **kwargs):
|
||||
epsilon=None, decay=0., amsgrad=False, lr_dropout=0, tf_cpu_mode=0, **kwargs):
|
||||
super(Adam, self).__init__(**kwargs)
|
||||
with K.name_scope(self.__class__.__name__):
|
||||
self.iterations = K.variable(0, dtype='int64', name='iterations')
|
||||
|
@ -873,6 +1021,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
self.epsilon = epsilon
|
||||
self.initial_decay = decay
|
||||
self.amsgrad = amsgrad
|
||||
self.lr_dropout = lr_dropout
|
||||
self.tf_cpu_mode = tf_cpu_mode
|
||||
|
||||
def get_updates(self, loss, params):
|
||||
|
@ -896,11 +1045,16 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
|
||||
else:
|
||||
vhats = [K.zeros(1) for _ in params]
|
||||
|
||||
|
||||
if self.lr_dropout != 0:
|
||||
lr_rnds = [ K.random_binomial(K.int_shape(p), p=self.lr_dropout, dtype=K.dtype(p)) for p in params ]
|
||||
|
||||
if e: e.__exit__(None, None, None)
|
||||
|
||||
self.weights = [self.iterations] + ms + vs + vhats
|
||||
|
||||
for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
|
||||
for i, (p, g, m, v, vhat) in enumerate( zip(params, grads, ms, vs, vhats) ):
|
||||
e = K.tf.device("/cpu:0") if self.tf_cpu_mode == 2 else None
|
||||
if e: e.__enter__()
|
||||
m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
|
||||
|
@ -912,13 +1066,16 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
if e: e.__exit__(None, None, None)
|
||||
|
||||
if self.amsgrad:
|
||||
p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
|
||||
p_diff = - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
|
||||
else:
|
||||
p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
|
||||
p_diff = - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
|
||||
|
||||
if self.lr_dropout != 0:
|
||||
p_diff *= lr_rnds[i]
|
||||
|
||||
self.updates.append(K.update(m, m_t))
|
||||
self.updates.append(K.update(v, v_t))
|
||||
new_p = p_t
|
||||
new_p = p + p_diff
|
||||
|
||||
# Apply constraints.
|
||||
if getattr(p, 'constraint', None) is not None:
|
||||
|
@ -933,7 +1090,8 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
'beta_2': float(K.get_value(self.beta_2)),
|
||||
'decay': float(K.get_value(self.decay)),
|
||||
'epsilon': self.epsilon,
|
||||
'amsgrad': self.amsgrad}
|
||||
'amsgrad': self.amsgrad,
|
||||
'lr_dropout' : self.lr_dropout}
|
||||
base_config = super(Adam, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
nnlib.Adam = Adam
|
||||
|
@ -1143,6 +1301,37 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
return dict(list(base_config.items()) + list(config.items()))
|
||||
nnlib.DenseMaxout = DenseMaxout
|
||||
|
||||
class GeLU(KL.Layer):
|
||||
"""Gaussian Error Linear Unit.
|
||||
A smoother version of ReLU generally used
|
||||
in the BERT or BERT architecture based models.
|
||||
Original paper: https://arxiv.org/abs/1606.08415
|
||||
Input shape:
|
||||
Arbitrary. Use the keyword argument `input_shape`
|
||||
(tuple of integers, does not include the samples axis)
|
||||
when using this layer as the first layer in a model.
|
||||
Output shape:
|
||||
Same shape as the input.
|
||||
"""
|
||||
|
||||
def __init__(self, approximate=True, **kwargs):
|
||||
super(GeLU, self).__init__(**kwargs)
|
||||
self.approximate = approximate
|
||||
self.supports_masking = True
|
||||
|
||||
def call(self, inputs):
|
||||
cdf = 0.5 * (1.0 + K.tanh((np.sqrt(2 / np.pi) * (inputs + 0.044715 * K.pow(inputs, 3)))))
|
||||
return inputs * cdf
|
||||
|
||||
def get_config(self):
|
||||
config = {'approximate': self.approximate}
|
||||
base_config = super(GeLU, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return input_shape
|
||||
nnlib.GeLU = GeLU
|
||||
|
||||
def CAInitializerMP( conv_weights_list ):
|
||||
#Convolution Aware Initialization https://arxiv.org/abs/1702.06295
|
||||
data = [ (i, K.int_shape(conv_weights)) for i, conv_weights in enumerate(conv_weights_list) ]
|
||||
|
|
149
samplelib/PackedFaceset.py
Normal file
149
samplelib/PackedFaceset.py
Normal file
|
@ -0,0 +1,149 @@
|
|||
import pickle
|
||||
import shutil
|
||||
import struct
|
||||
from pathlib import Path
|
||||
|
||||
import samplelib.SampleHost
|
||||
from interact import interact as io
|
||||
from samplelib import Sample
|
||||
from utils import Path_utils
|
||||
|
||||
packed_faceset_filename = 'faceset.pak'
|
||||
|
||||
class PackedFaceset():
|
||||
VERSION = 1
|
||||
|
||||
@staticmethod
|
||||
def pack(samples_path):
|
||||
samples_dat_path = samples_path / packed_faceset_filename
|
||||
|
||||
if samples_dat_path.exists():
|
||||
io.log_info(f"{samples_dat_path} : file already exists !")
|
||||
io.input_bool("Press enter to continue and overwrite.", False)
|
||||
|
||||
as_person_faceset = False
|
||||
dir_names = Path_utils.get_all_dir_names(samples_path)
|
||||
if len(dir_names) != 0:
|
||||
as_person_faceset = io.input_bool(f"{len(dir_names)} subdirectories found, process as person faceset? (y/n) skip:y : ", True)
|
||||
|
||||
if as_person_faceset:
|
||||
image_paths = []
|
||||
|
||||
for dir_name in dir_names:
|
||||
image_paths += Path_utils.get_image_paths(samples_path / dir_name)
|
||||
else:
|
||||
image_paths = Path_utils.get_image_paths(samples_path)
|
||||
|
||||
samples = samplelib.SampleHost.load_face_samples(image_paths)
|
||||
samples_len = len(samples)
|
||||
|
||||
samples_configs = []
|
||||
for sample in io.progress_bar_generator (samples, "Processing"):
|
||||
sample_filepath = Path(sample.filename)
|
||||
sample.filename = sample_filepath.name
|
||||
|
||||
if as_person_faceset:
|
||||
sample.person_name = sample_filepath.parent.name
|
||||
samples_configs.append ( sample.get_config() )
|
||||
samples_bytes = pickle.dumps(samples_configs, 4)
|
||||
|
||||
of = open(samples_dat_path, "wb")
|
||||
of.write ( struct.pack ("Q", PackedFaceset.VERSION ) )
|
||||
of.write ( struct.pack ("Q", len(samples_bytes) ) )
|
||||
of.write ( samples_bytes )
|
||||
|
||||
del samples_bytes #just free mem
|
||||
del samples_configs
|
||||
|
||||
sample_data_table_offset = of.tell()
|
||||
of.write ( bytes( 8*(samples_len+1) ) ) #sample data offset table
|
||||
|
||||
data_start_offset = of.tell()
|
||||
offsets = []
|
||||
|
||||
for sample in io.progress_bar_generator(samples, "Packing"):
|
||||
try:
|
||||
if sample.person_name is not None:
|
||||
sample_path = samples_path / sample.person_name / sample.filename
|
||||
else:
|
||||
sample_path = samples_path / sample.filename
|
||||
|
||||
|
||||
with open(sample_path, "rb") as f:
|
||||
b = f.read()
|
||||
|
||||
offsets.append ( of.tell() - data_start_offset )
|
||||
of.write(b)
|
||||
except:
|
||||
raise Exception(f"error while processing sample {sample_path}")
|
||||
|
||||
offsets.append ( of.tell() )
|
||||
|
||||
of.seek(sample_data_table_offset, 0)
|
||||
for offset in offsets:
|
||||
of.write ( struct.pack("Q", offset) )
|
||||
of.seek(0,2)
|
||||
of.close()
|
||||
|
||||
for filename in io.progress_bar_generator(image_paths, "Deleting files"):
|
||||
Path(filename).unlink()
|
||||
|
||||
if as_person_faceset:
|
||||
for dir_name in io.progress_bar_generator(dir_names, "Deleting dirs"):
|
||||
dir_path = samples_path / dir_name
|
||||
try:
|
||||
shutil.rmtree(dir_path)
|
||||
except:
|
||||
io.log_info (f"unable to remove: {dir_path} ")
|
||||
|
||||
@staticmethod
|
||||
def unpack(samples_path):
|
||||
samples_dat_path = samples_path / packed_faceset_filename
|
||||
if not samples_dat_path.exists():
|
||||
io.log_info(f"{samples_dat_path} : file not found.")
|
||||
return
|
||||
|
||||
samples = PackedFaceset.load(samples_path)
|
||||
|
||||
for sample in io.progress_bar_generator(samples, "Unpacking"):
|
||||
person_name = sample.person_name
|
||||
if person_name is not None:
|
||||
person_path = samples_path / person_name
|
||||
person_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
target_filepath = person_path / sample.filename
|
||||
else:
|
||||
target_filepath = samples_path / sample.filename
|
||||
|
||||
with open(target_filepath, "wb") as f:
|
||||
f.write( sample.read_raw_file() )
|
||||
|
||||
samples_dat_path.unlink()
|
||||
|
||||
@staticmethod
|
||||
def load(samples_path):
|
||||
samples_dat_path = samples_path / packed_faceset_filename
|
||||
if not samples_dat_path.exists():
|
||||
return None
|
||||
|
||||
f = open(samples_dat_path, "rb")
|
||||
version, = struct.unpack("Q", f.read(8) )
|
||||
if version != PackedFaceset.VERSION:
|
||||
raise NotImplementedError
|
||||
|
||||
sizeof_samples_bytes, = struct.unpack("Q", f.read(8) )
|
||||
|
||||
samples_configs = pickle.loads ( f.read(sizeof_samples_bytes) )
|
||||
samples = []
|
||||
for sample_config in samples_configs:
|
||||
samples.append ( Sample (**sample_config) )
|
||||
|
||||
offsets = [ struct.unpack("Q", f.read(8) )[0] for _ in range(len(samples)+1) ]
|
||||
data_start_offset = f.tell()
|
||||
f.close()
|
||||
|
||||
for i, sample in enumerate(samples):
|
||||
start_offset, end_offset = offsets[i], offsets[i+1]
|
||||
sample.set_filename_offset_size( str(samples_dat_path), data_start_offset+start_offset, end_offset-start_offset )
|
||||
|
||||
return samples
|
|
@ -5,43 +5,99 @@ import cv2
|
|||
import numpy as np
|
||||
|
||||
from utils.cv2_utils import *
|
||||
from utils.DFLJPG import DFLJPG
|
||||
from utils.DFLPNG import DFLPNG
|
||||
|
||||
from DFLIMG import *
|
||||
from facelib import LandmarksProcessor
|
||||
from imagelib import IEPolys
|
||||
|
||||
class SampleType(IntEnum):
|
||||
IMAGE = 0 #raw image
|
||||
|
||||
FACE_BEGIN = 1
|
||||
FACE = 1 #aligned face unsorted
|
||||
FACE_YAW_SORTED = 2 #sorted by yaw
|
||||
FACE_YAW_SORTED_AS_TARGET = 3 #sorted by yaw and included only yaws which exist in TARGET also automatic mirrored
|
||||
FACE_TEMPORAL_SORTED = 4
|
||||
FACE_END = 4
|
||||
FACE = 1 #aligned face unsorted
|
||||
FACE_PERSON = 2 #aligned face person
|
||||
FACE_TEMPORAL_SORTED = 3 #sorted by source filename
|
||||
FACE_END = 3
|
||||
|
||||
QTY = 5
|
||||
QTY = 4
|
||||
|
||||
class Sample(object):
|
||||
def __init__(self, sample_type=None, filename=None, person_id=None, face_type=None, shape=None, landmarks=None, ie_polys=None, pitch_yaw_roll=None, eyebrows_expand_mod=None, source_filename=None, mirror=None, close_target_list=None, fanseg_mask_exist=False):
|
||||
__slots__ = ['sample_type',
|
||||
'filename',
|
||||
'face_type',
|
||||
'shape',
|
||||
'landmarks',
|
||||
'ie_polys',
|
||||
'eyebrows_expand_mod',
|
||||
'source_filename',
|
||||
'person_name',
|
||||
'pitch_yaw_roll',
|
||||
'_filename_offset_size',
|
||||
]
|
||||
|
||||
def __init__(self, sample_type=None,
|
||||
filename=None,
|
||||
face_type=None,
|
||||
shape=None,
|
||||
landmarks=None,
|
||||
ie_polys=None,
|
||||
eyebrows_expand_mod=None,
|
||||
source_filename=None,
|
||||
person_name=None,
|
||||
pitch_yaw_roll=None,
|
||||
**kwargs):
|
||||
|
||||
self.sample_type = sample_type if sample_type is not None else SampleType.IMAGE
|
||||
self.filename = filename
|
||||
self.person_id = person_id
|
||||
self.face_type = face_type
|
||||
self.shape = shape
|
||||
self.landmarks = np.array(landmarks) if landmarks is not None else None
|
||||
self.ie_polys = ie_polys
|
||||
self.pitch_yaw_roll = pitch_yaw_roll
|
||||
self.ie_polys = IEPolys.load(ie_polys)
|
||||
self.eyebrows_expand_mod = eyebrows_expand_mod
|
||||
self.source_filename = source_filename
|
||||
self.mirror = mirror
|
||||
self.close_target_list = close_target_list
|
||||
self.fanseg_mask_exist = fanseg_mask_exist
|
||||
self.person_name = person_name
|
||||
self.pitch_yaw_roll = pitch_yaw_roll
|
||||
|
||||
def copy_and_set(self, sample_type=None, filename=None, person_id=None, face_type=None, shape=None, landmarks=None, ie_polys=None, pitch_yaw_roll=None, eyebrows_expand_mod=None, source_filename=None, mirror=None, close_target_list=None, fanseg_mask=None, fanseg_mask_exist=None):
|
||||
self._filename_offset_size = None
|
||||
|
||||
def get_pitch_yaw_roll(self):
|
||||
if self.pitch_yaw_roll is None:
|
||||
self.pitch_yaw_roll = LandmarksProcessor.estimate_pitch_yaw_roll(landmarks)
|
||||
return self.pitch_yaw_roll
|
||||
|
||||
def set_filename_offset_size(self, filename, offset, size):
|
||||
self._filename_offset_size = (filename, offset, size)
|
||||
|
||||
def read_raw_file(self, filename=None):
|
||||
if self._filename_offset_size is not None:
|
||||
filename, offset, size = self._filename_offset_size
|
||||
with open(filename, "rb") as f:
|
||||
f.seek( offset, 0)
|
||||
return f.read (size)
|
||||
else:
|
||||
with open(filename, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
def load_bgr(self):
|
||||
img = cv2_imread (self.filename, loader_func=self.read_raw_file).astype(np.float32) / 255.0
|
||||
return img
|
||||
|
||||
def get_config(self):
|
||||
return {'sample_type': self.sample_type,
|
||||
'filename': self.filename,
|
||||
'face_type': self.face_type,
|
||||
'shape': self.shape,
|
||||
'landmarks': self.landmarks.tolist(),
|
||||
'ie_polys': self.ie_polys.dump(),
|
||||
'eyebrows_expand_mod': self.eyebrows_expand_mod,
|
||||
'source_filename': self.source_filename,
|
||||
'person_name': self.person_name
|
||||
}
|
||||
|
||||
"""
|
||||
def copy_and_set(self, sample_type=None, filename=None, face_type=None, shape=None, landmarks=None, ie_polys=None, pitch_yaw_roll=None, eyebrows_expand_mod=None, source_filename=None, fanseg_mask=None, person_name=None):
|
||||
return Sample(
|
||||
sample_type=sample_type if sample_type is not None else self.sample_type,
|
||||
filename=filename if filename is not None else self.filename,
|
||||
person_id=person_id if person_id is not None else self.person_id,
|
||||
face_type=face_type if face_type is not None else self.face_type,
|
||||
shape=shape if shape is not None else self.shape,
|
||||
landmarks=landmarks if landmarks is not None else self.landmarks.copy(),
|
||||
|
@ -49,30 +105,6 @@ class Sample(object):
|
|||
pitch_yaw_roll=pitch_yaw_roll if pitch_yaw_roll is not None else self.pitch_yaw_roll,
|
||||
eyebrows_expand_mod=eyebrows_expand_mod if eyebrows_expand_mod is not None else self.eyebrows_expand_mod,
|
||||
source_filename=source_filename if source_filename is not None else self.source_filename,
|
||||
mirror=mirror if mirror is not None else self.mirror,
|
||||
close_target_list=close_target_list if close_target_list is not None else self.close_target_list,
|
||||
fanseg_mask_exist=fanseg_mask_exist if fanseg_mask_exist is not None else self.fanseg_mask_exist)
|
||||
person_name=person_name if person_name is not None else self.person_name)
|
||||
|
||||
def load_bgr(self):
|
||||
img = cv2_imread (self.filename).astype(np.float32) / 255.0
|
||||
if self.mirror:
|
||||
img = img[:,::-1].copy()
|
||||
return img
|
||||
|
||||
def load_fanseg_mask(self):
|
||||
if self.fanseg_mask_exist:
|
||||
filepath = Path(self.filename)
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load ( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
return dflimg.get_fanseg_mask()
|
||||
|
||||
return None
|
||||
|
||||
def get_random_close_target_sample(self):
|
||||
if self.close_target_list is None:
|
||||
return None
|
||||
return self.close_target_list[randint (0, len(self.close_target_list)-1)]
|
||||
"""
|
|
@ -5,10 +5,10 @@ import cv2
|
|||
import numpy as np
|
||||
|
||||
from facelib import LandmarksProcessor
|
||||
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
|
||||
from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
|
||||
SampleType)
|
||||
from utils import iter_utils
|
||||
|
||||
from utils import mp_utils
|
||||
|
||||
'''
|
||||
arg
|
||||
|
@ -19,15 +19,10 @@ output_sample_types = [
|
|||
'''
|
||||
class SampleGeneratorFace(SampleGeneratorBase):
|
||||
def __init__ (self, samples_path, debug=False, batch_size=1,
|
||||
sort_by_yaw=False,
|
||||
sort_by_yaw_target_samples_path=None,
|
||||
random_ct_samples_path=None,
|
||||
sample_process_options=SampleProcessor.Options(),
|
||||
output_sample_types=[],
|
||||
add_sample_idx=False,
|
||||
use_caching=False,
|
||||
generators_count=2,
|
||||
generators_random_seed=None,
|
||||
**kwargs):
|
||||
|
||||
super().__init__(samples_path, debug, batch_size)
|
||||
|
@ -35,34 +30,27 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
self.output_sample_types = output_sample_types
|
||||
self.add_sample_idx = add_sample_idx
|
||||
|
||||
if sort_by_yaw_target_samples_path is not None:
|
||||
self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET
|
||||
elif sort_by_yaw:
|
||||
self.sample_type = SampleType.FACE_YAW_SORTED
|
||||
else:
|
||||
self.sample_type = SampleType.FACE
|
||||
|
||||
if generators_random_seed is not None and len(generators_random_seed) != generators_count:
|
||||
raise ValueError("len(generators_random_seed) != generators_count")
|
||||
|
||||
self.generators_random_seed = generators_random_seed
|
||||
|
||||
samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path, use_caching=use_caching)
|
||||
np.random.shuffle(samples)
|
||||
self.samples_len = len(samples)
|
||||
samples_host = SampleHost.mp_host (SampleType.FACE, self.samples_path)
|
||||
self.samples_len = len(samples_host.get_list())
|
||||
|
||||
if self.samples_len == 0:
|
||||
raise ValueError('No training data provided.')
|
||||
|
||||
ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) if random_ct_samples_path is not None else None
|
||||
self.random_ct_sample_chance = 100
|
||||
index_host = mp_utils.IndexHost(self.samples_len)
|
||||
|
||||
if random_ct_samples_path is not None:
|
||||
ct_samples_host = SampleHost.mp_host (SampleType.FACE, random_ct_samples_path)
|
||||
ct_index_host = mp_utils.IndexHost( len(ct_samples_host.get_list()) )
|
||||
else:
|
||||
ct_samples_host = None
|
||||
ct_index_host = None
|
||||
|
||||
if self.debug:
|
||||
self.generators_count = 1
|
||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples, ct_samples) )]
|
||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_host.create_cli(), index_host.create_cli(), ct_samples_host.create_cli() if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None) )]
|
||||
else:
|
||||
self.generators_count = min ( generators_count, self.samples_len )
|
||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count], ct_samples ) ) for i in range(self.generators_count) ]
|
||||
self.generators_count = np.clip(multiprocessing.cpu_count(), 2, 4)
|
||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_host.create_cli(), index_host.create_cli(), ct_samples_host.create_cli() if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ]
|
||||
|
||||
self.generator_counter = -1
|
||||
|
||||
|
@ -79,86 +67,33 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
return next(generator)
|
||||
|
||||
def batch_func(self, param ):
|
||||
generator_id, samples, ct_samples = param
|
||||
|
||||
if self.generators_random_seed is not None:
|
||||
np.random.seed ( self.generators_random_seed[generator_id] )
|
||||
|
||||
samples_len = len(samples)
|
||||
samples_idxs = [*range(samples_len)]
|
||||
|
||||
ct_samples_len = len(ct_samples) if ct_samples is not None else 0
|
||||
|
||||
if self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
|
||||
if all ( [ samples[idx] == None for idx in samples_idxs] ):
|
||||
raise ValueError('Not enough training data. Gather more faces!')
|
||||
|
||||
if self.sample_type == SampleType.FACE:
|
||||
shuffle_idxs = []
|
||||
elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
|
||||
shuffle_idxs = []
|
||||
shuffle_idxs_2D = [[]]*samples_len
|
||||
|
||||
samples, index_host, ct_samples, ct_index_host = param
|
||||
bs = self.batch_size
|
||||
while True:
|
||||
batches = None
|
||||
for n_batch in range(self.batch_size):
|
||||
while True:
|
||||
sample = None
|
||||
|
||||
if self.sample_type == SampleType.FACE:
|
||||
if len(shuffle_idxs) == 0:
|
||||
shuffle_idxs = samples_idxs.copy()
|
||||
np.random.shuffle(shuffle_idxs)
|
||||
indexes = index_host.get(bs)
|
||||
ct_indexes = ct_index_host.get(bs) if ct_samples is not None else None
|
||||
|
||||
idx = shuffle_idxs.pop()
|
||||
sample = samples[ idx ]
|
||||
for n_batch in range(bs):
|
||||
sample_idx = indexes[n_batch]
|
||||
sample = samples[ sample_idx ]
|
||||
ct_sample = ct_samples[ ct_indexes[n_batch] ] if ct_samples is not None else None
|
||||
|
||||
elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
|
||||
if len(shuffle_idxs) == 0:
|
||||
shuffle_idxs = samples_idxs.copy()
|
||||
np.random.shuffle(shuffle_idxs)
|
||||
try:
|
||||
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
|
||||
except:
|
||||
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
|
||||
|
||||
idx = shuffle_idxs.pop()
|
||||
if samples[idx] != None:
|
||||
if len(shuffle_idxs_2D[idx]) == 0:
|
||||
a = shuffle_idxs_2D[idx] = [ *range(len(samples[idx])) ]
|
||||
np.random.shuffle (a)
|
||||
if batches is None:
|
||||
batches = [ [] for _ in range(len(x)) ]
|
||||
if self.add_sample_idx:
|
||||
batches += [ [] ]
|
||||
i_sample_idx = len(batches)-1
|
||||
|
||||
idx2 = shuffle_idxs_2D[idx].pop()
|
||||
sample = samples[idx][idx2]
|
||||
|
||||
idx = (idx << 16) | (idx2 & 0xFFFF)
|
||||
|
||||
if sample is not None:
|
||||
try:
|
||||
ct_sample=None
|
||||
if ct_samples is not None:
|
||||
if np.random.randint(100) < self.random_ct_sample_chance:
|
||||
ct_sample=ct_samples[np.random.randint(ct_samples_len)]
|
||||
|
||||
x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
|
||||
except:
|
||||
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
|
||||
|
||||
if type(x) != tuple and type(x) != list:
|
||||
raise Exception('SampleProcessor.process returns NOT tuple/list')
|
||||
|
||||
if batches is None:
|
||||
batches = [ [] for _ in range(len(x)) ]
|
||||
if self.add_sample_idx:
|
||||
batches += [ [] ]
|
||||
i_sample_idx = len(batches)-1
|
||||
|
||||
for i in range(len(x)):
|
||||
batches[i].append ( x[i] )
|
||||
|
||||
if self.add_sample_idx:
|
||||
batches[i_sample_idx].append (idx)
|
||||
|
||||
break
|
||||
for i in range(len(x)):
|
||||
batches[i].append ( x[i] )
|
||||
|
||||
if self.add_sample_idx:
|
||||
batches[i_sample_idx].append (sample_idx)
|
||||
yield [ np.array(batch) for batch in batches]
|
||||
|
||||
@staticmethod
|
||||
def get_person_id_max_count(samples_path):
|
||||
return SampleLoader.get_person_id_max_count(samples_path)
|
|
@ -1,3 +1,4 @@
|
|||
import copy
|
||||
import multiprocessing
|
||||
import traceback
|
||||
|
||||
|
@ -5,9 +6,9 @@ import cv2
|
|||
import numpy as np
|
||||
|
||||
from facelib import LandmarksProcessor
|
||||
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
|
||||
from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
|
||||
SampleType)
|
||||
from utils import iter_utils
|
||||
from utils import iter_utils, mp_utils
|
||||
|
||||
|
||||
'''
|
||||
|
@ -22,9 +23,6 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
|
|||
sample_process_options=SampleProcessor.Options(),
|
||||
output_sample_types=[],
|
||||
person_id_mode=1,
|
||||
use_caching=False,
|
||||
generators_count=2,
|
||||
generators_random_seed=None,
|
||||
**kwargs):
|
||||
|
||||
super().__init__(samples_path, debug, batch_size)
|
||||
|
@ -32,47 +30,27 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
|
|||
self.output_sample_types = output_sample_types
|
||||
self.person_id_mode = person_id_mode
|
||||
|
||||
if generators_random_seed is not None and len(generators_random_seed) != generators_count:
|
||||
raise ValueError("len(generators_random_seed) != generators_count")
|
||||
self.generators_random_seed = generators_random_seed
|
||||
|
||||
samples = SampleLoader.load (SampleType.FACE, self.samples_path, person_id_mode=True, use_caching=use_caching)
|
||||
|
||||
if person_id_mode==1:
|
||||
np.random.shuffle(samples)
|
||||
|
||||
new_samples = []
|
||||
while len(samples) > 0:
|
||||
for i in range( len(samples)-1, -1, -1):
|
||||
sample = samples[i]
|
||||
|
||||
if len(sample) > 0:
|
||||
new_samples.append(sample.pop(0))
|
||||
|
||||
if len(sample) == 0:
|
||||
samples.pop(i)
|
||||
samples = new_samples
|
||||
#new_samples = []
|
||||
#for s in samples:
|
||||
# new_samples += s
|
||||
#samples = new_samples
|
||||
#np.random.shuffle(samples)
|
||||
|
||||
samples_host = SampleHost.mp_host (SampleType.FACE, self.samples_path)
|
||||
samples = samples_host.get_list()
|
||||
self.samples_len = len(samples)
|
||||
|
||||
if self.samples_len == 0:
|
||||
raise ValueError('No training data provided.')
|
||||
|
||||
unique_person_names = { sample.person_name for sample in samples }
|
||||
persons_name_idxs = { person_name : [] for person_name in unique_person_names }
|
||||
for i,sample in enumerate(samples):
|
||||
persons_name_idxs[sample.person_name].append (i)
|
||||
indexes2D = [ persons_name_idxs[person_name] for person_name in unique_person_names ]
|
||||
index2d_host = mp_utils.Index2DHost(indexes2D)
|
||||
|
||||
if self.debug:
|
||||
self.generators_count = 1
|
||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples) )]
|
||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_host.create_cli(), index2d_host.create_cli(),) )]
|
||||
else:
|
||||
self.generators_count = min ( generators_count, self.samples_len )
|
||||
|
||||
if person_id_mode==1:
|
||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count]) ) for i in range(self.generators_count) ]
|
||||
else:
|
||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples) ) for i in range(self.generators_count) ]
|
||||
self.generators_count = np.clip(multiprocessing.cpu_count(), 2, 4)
|
||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_host.create_cli(), index2d_host.create_cli(),), start_now=True ) for i in range(self.generators_count) ]
|
||||
|
||||
self.generator_counter = -1
|
||||
|
||||
|
@ -89,12 +67,43 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
|
|||
return next(generator)
|
||||
|
||||
def batch_func(self, param ):
|
||||
generator_id, samples = param
|
||||
samples, index2d_host, = param
|
||||
bs = self.batch_size
|
||||
|
||||
if self.generators_random_seed is not None:
|
||||
np.random.seed ( self.generators_random_seed[generator_id] )
|
||||
while True:
|
||||
person_idxs = index2d_host.get_1D(bs)
|
||||
samples_idxs = index2d_host.get_2D(person_idxs, 1)
|
||||
|
||||
if self.person_id_mode==1:
|
||||
batches = None
|
||||
for n_batch in range(bs):
|
||||
person_id = person_idxs[n_batch]
|
||||
sample_idx = samples_idxs[n_batch][0]
|
||||
|
||||
sample = samples[ sample_idx ]
|
||||
try:
|
||||
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug)
|
||||
except:
|
||||
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
|
||||
|
||||
if batches is None:
|
||||
batches = [ [] for _ in range(len(x)) ]
|
||||
|
||||
batches += [ [] ]
|
||||
i_person_id = len(batches)-1
|
||||
|
||||
for i in range(len(x)):
|
||||
batches[i].append ( x[i] )
|
||||
|
||||
batches[i_person_id].append ( np.array([person_id]) )
|
||||
|
||||
yield [ np.array(batch) for batch in batches]
|
||||
|
||||
@staticmethod
|
||||
def get_person_id_max_count(samples_path):
|
||||
return SampleHost.get_person_id_max_count(samples_path)
|
||||
|
||||
"""
|
||||
if self.person_id_mode==1:
|
||||
samples_len = len(samples)
|
||||
samples_idxs = [*range(samples_len)]
|
||||
shuffle_idxs = []
|
||||
|
@ -114,10 +123,20 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
|
|||
for i in range(persons_count):
|
||||
samples_idxs[i] = [*range(len(samples[i]))]
|
||||
shuffle_idxs[i] = []
|
||||
elif self.person_id_mode==3:
|
||||
persons_count = len(samples)
|
||||
|
||||
while True:
|
||||
person_idxs = [ *range(persons_count) ]
|
||||
shuffle_person_idxs = []
|
||||
|
||||
if self.person_id_mode==2:
|
||||
samples_idxs = [None]*persons_count
|
||||
shuffle_idxs = [None]*persons_count
|
||||
|
||||
for i in range(persons_count):
|
||||
samples_idxs[i] = [*range(len(samples[i]))]
|
||||
shuffle_idxs[i] = []
|
||||
|
||||
if self.person_id_mode==2:
|
||||
if len(shuffle_person_idxs) == 0:
|
||||
shuffle_person_idxs = person_idxs.copy()
|
||||
np.random.shuffle(shuffle_person_idxs)
|
||||
|
@ -130,13 +149,13 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
|
|||
if self.person_id_mode==1:
|
||||
if len(shuffle_idxs) == 0:
|
||||
shuffle_idxs = samples_idxs.copy()
|
||||
#np.random.shuffle(shuffle_idxs)
|
||||
np.random.shuffle(shuffle_idxs) ###
|
||||
|
||||
idx = shuffle_idxs.pop()
|
||||
sample = samples[ idx ]
|
||||
|
||||
try:
|
||||
x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug)
|
||||
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug)
|
||||
except:
|
||||
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
|
||||
|
||||
|
@ -155,7 +174,7 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
|
|||
batches[i_person_id].append ( np.array([sample.person_id]) )
|
||||
|
||||
|
||||
else:
|
||||
elif self.person_id_mode==2:
|
||||
person_id1, person_id2 = person_ids
|
||||
|
||||
if len(shuffle_idxs[person_id1]) == 0:
|
||||
|
@ -174,12 +193,12 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
|
|||
|
||||
if sample1 is not None and sample2 is not None:
|
||||
try:
|
||||
x1 = SampleProcessor.process (sample1, self.sample_process_options, self.output_sample_types, self.debug)
|
||||
x1, = SampleProcessor.process ([sample1], self.sample_process_options, self.output_sample_types, self.debug)
|
||||
except:
|
||||
raise Exception ("Exception occured in sample %s. Error: %s" % (sample1.filename, traceback.format_exc() ) )
|
||||
|
||||
try:
|
||||
x2 = SampleProcessor.process (sample2, self.sample_process_options, self.output_sample_types, self.debug)
|
||||
x2, = SampleProcessor.process ([sample2], self.sample_process_options, self.output_sample_types, self.debug)
|
||||
except:
|
||||
raise Exception ("Exception occured in sample %s. Error: %s" % (sample2.filename, traceback.format_exc() ) )
|
||||
|
||||
|
@ -203,10 +222,54 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
|
|||
|
||||
batches[i_person_id2].append ( np.array([sample2.person_id]) )
|
||||
|
||||
elif self.person_id_mode==3:
|
||||
if len(shuffle_person_idxs) == 0:
|
||||
shuffle_person_idxs = person_idxs.copy()
|
||||
np.random.shuffle(shuffle_person_idxs)
|
||||
person_id = shuffle_person_idxs.pop()
|
||||
|
||||
if len(shuffle_idxs[person_id]) == 0:
|
||||
shuffle_idxs[person_id] = samples_idxs[person_id].copy()
|
||||
np.random.shuffle(shuffle_idxs[person_id])
|
||||
|
||||
yield [ np.array(batch) for batch in batches]
|
||||
idx = shuffle_idxs[person_id].pop()
|
||||
sample1 = samples[person_id][idx]
|
||||
|
||||
@staticmethod
|
||||
def get_person_id_max_count(samples_path):
|
||||
return SampleLoader.get_person_id_max_count(samples_path)
|
||||
if len(shuffle_idxs[person_id]) == 0:
|
||||
shuffle_idxs[person_id] = samples_idxs[person_id].copy()
|
||||
np.random.shuffle(shuffle_idxs[person_id])
|
||||
|
||||
idx = shuffle_idxs[person_id].pop()
|
||||
sample2 = samples[person_id][idx]
|
||||
|
||||
if sample1 is not None and sample2 is not None:
|
||||
try:
|
||||
x1, = SampleProcessor.process ([sample1], self.sample_process_options, self.output_sample_types, self.debug)
|
||||
except:
|
||||
raise Exception ("Exception occured in sample %s. Error: %s" % (sample1.filename, traceback.format_exc() ) )
|
||||
|
||||
try:
|
||||
x2, = SampleProcessor.process ([sample2], self.sample_process_options, self.output_sample_types, self.debug)
|
||||
except:
|
||||
raise Exception ("Exception occured in sample %s. Error: %s" % (sample2.filename, traceback.format_exc() ) )
|
||||
|
||||
x1_len = len(x1)
|
||||
if batches is None:
|
||||
batches = [ [] for _ in range(x1_len) ]
|
||||
batches += [ [] ]
|
||||
i_person_id1 = len(batches)-1
|
||||
|
||||
batches += [ [] for _ in range(len(x2)) ]
|
||||
batches += [ [] ]
|
||||
i_person_id2 = len(batches)-1
|
||||
|
||||
for i in range(x1_len):
|
||||
batches[i].append ( x1[i] )
|
||||
|
||||
for i in range(len(x2)):
|
||||
batches[x1_len+1+i].append ( x2[i] )
|
||||
|
||||
batches[i_person_id1].append ( np.array([sample1.person_id]) )
|
||||
|
||||
batches[i_person_id2].append ( np.array([sample2.person_id]) )
|
||||
"""
|
|
@ -4,7 +4,7 @@ import cv2
|
|||
|
||||
from utils import iter_utils
|
||||
|
||||
from samplelib import SampleType, SampleProcessor, SampleLoader, SampleGeneratorBase
|
||||
from samplelib import SampleType, SampleProcessor, SampleHost, SampleGeneratorBase
|
||||
|
||||
'''
|
||||
output_sample_types = [
|
||||
|
@ -20,7 +20,7 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
|
|||
self.sample_process_options = sample_process_options
|
||||
self.output_sample_types = output_sample_types
|
||||
|
||||
self.samples = SampleLoader.load (SampleType.FACE_TEMPORAL_SORTED, self.samples_path)
|
||||
self.samples = SampleHost.load (SampleType.FACE_TEMPORAL_SORTED, self.samples_path)
|
||||
|
||||
if self.debug:
|
||||
self.generators_count = 1
|
||||
|
@ -71,7 +71,7 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
|
|||
for i in range( self.temporal_image_count ):
|
||||
sample = samples[ idx+i*mult ]
|
||||
try:
|
||||
temporal_samples += SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug)
|
||||
temporal_samples += SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug)[0]
|
||||
except:
|
||||
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import cv2
|
|||
|
||||
from utils import iter_utils
|
||||
|
||||
from samplelib import SampleType, SampleProcessor, SampleLoader, SampleGeneratorBase
|
||||
from samplelib import SampleType, SampleProcessor, SampleHost, SampleGeneratorBase
|
||||
|
||||
'''
|
||||
output_sample_types = [
|
||||
|
@ -20,7 +20,7 @@ class SampleGeneratorImageTemporal(SampleGeneratorBase):
|
|||
self.sample_process_options = sample_process_options
|
||||
self.output_sample_types = output_sample_types
|
||||
|
||||
self.samples = SampleLoader.load (SampleType.IMAGE, self.samples_path)
|
||||
self.samples = SampleHost.load (SampleType.IMAGE, self.samples_path)
|
||||
|
||||
self.generator_samples = [ self.samples ]
|
||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )] if self.debug else \
|
||||
|
@ -66,7 +66,7 @@ class SampleGeneratorImageTemporal(SampleGeneratorBase):
|
|||
for i in range( self.temporal_image_count ):
|
||||
sample = samples[ idx+i*mult ]
|
||||
try:
|
||||
temporal_samples += SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug)
|
||||
temporal_samples += SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug)[0]
|
||||
except:
|
||||
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
|
||||
|
||||
|
|
187
samplelib/SampleHost.py
Normal file
187
samplelib/SampleHost.py
Normal file
|
@ -0,0 +1,187 @@
|
|||
import multiprocessing
|
||||
import operator
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
import samplelib.PackedFaceset
|
||||
from DFLIMG import *
|
||||
from facelib import FaceType, LandmarksProcessor
|
||||
from interact import interact as io
|
||||
from joblib import Subprocessor
|
||||
from utils import Path_utils, mp_utils
|
||||
|
||||
from .Sample import Sample, SampleType
|
||||
|
||||
|
||||
class SampleHost:
|
||||
samples_cache = dict()
|
||||
host_cache = dict()
|
||||
|
||||
@staticmethod
|
||||
def get_person_id_max_count(samples_path):
|
||||
samples = None
|
||||
try:
|
||||
samples = samplelib.PackedFaceset.load(samples_path)
|
||||
except:
|
||||
io.log_err(f"Error occured while loading samplelib.PackedFaceset.load {str(samples_dat_path)}, {traceback.format_exc()}")
|
||||
|
||||
if samples is None:
|
||||
raise ValueError("packed faceset not found.")
|
||||
persons_name_idxs = {}
|
||||
for sample in samples:
|
||||
persons_name_idxs[sample.person_name] = 0
|
||||
return len(list(persons_name_idxs.keys()))
|
||||
|
||||
@staticmethod
|
||||
def load(sample_type, samples_path):
|
||||
samples_cache = SampleHost.samples_cache
|
||||
|
||||
if str(samples_path) not in samples_cache.keys():
|
||||
samples_cache[str(samples_path)] = [None]*SampleType.QTY
|
||||
|
||||
samples = samples_cache[str(samples_path)]
|
||||
|
||||
if sample_type == SampleType.IMAGE:
|
||||
if samples[sample_type] is None:
|
||||
samples[sample_type] = [ Sample(filename=filename) for filename in io.progress_bar_generator( Path_utils.get_image_paths(samples_path), "Loading") ]
|
||||
elif sample_type == SampleType.FACE:
|
||||
if samples[sample_type] is None:
|
||||
result = None
|
||||
try:
|
||||
result = samplelib.PackedFaceset.load(samples_path)
|
||||
except:
|
||||
io.log_err(f"Error occured while loading samplelib.PackedFaceset.load {str(samples_dat_path)}, {traceback.format_exc()}")
|
||||
|
||||
if result is not None:
|
||||
io.log_info (f"Loaded {len(result)} packed faces from {samples_path}")
|
||||
|
||||
if result is None:
|
||||
result = SampleHost.load_face_samples( Path_utils.get_image_paths(samples_path) )
|
||||
|
||||
samples[sample_type] = result
|
||||
|
||||
elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
|
||||
if samples[sample_type] is None:
|
||||
samples[sample_type] = SampleHost.upgradeToFaceTemporalSortedSamples( SampleHost.load(SampleType.FACE, samples_path) )
|
||||
|
||||
return samples[sample_type]
|
||||
|
||||
@staticmethod
|
||||
def mp_host(sample_type, samples_path):
|
||||
result = SampleHost.load (sample_type, samples_path)
|
||||
|
||||
host_cache = SampleHost.host_cache
|
||||
if str(samples_path) not in host_cache.keys():
|
||||
host_cache[str(samples_path)] = [None]*SampleType.QTY
|
||||
hosts = host_cache[str(samples_path)]
|
||||
|
||||
if hosts[sample_type] is None:
|
||||
hosts[sample_type] = mp_utils.ListHost(result)
|
||||
|
||||
return hosts[sample_type]
|
||||
|
||||
@staticmethod
|
||||
def load_face_samples ( image_paths):
|
||||
result = FaceSamplesLoaderSubprocessor(image_paths).run()
|
||||
sample_list = []
|
||||
|
||||
for filename, \
|
||||
( face_type,
|
||||
shape,
|
||||
landmarks,
|
||||
ie_polys,
|
||||
eyebrows_expand_mod,
|
||||
source_filename,
|
||||
) in result:
|
||||
sample_list.append( Sample(filename=filename,
|
||||
sample_type=SampleType.FACE,
|
||||
face_type=FaceType.fromString (face_type),
|
||||
shape=shape,
|
||||
landmarks=landmarks,
|
||||
ie_polys=ie_polys,
|
||||
eyebrows_expand_mod=eyebrows_expand_mod,
|
||||
source_filename=source_filename,
|
||||
))
|
||||
return sample_list
|
||||
|
||||
@staticmethod
|
||||
def upgradeToFaceTemporalSortedSamples( samples ):
|
||||
new_s = [ (s, s.source_filename) for s in samples]
|
||||
new_s = sorted(new_s, key=operator.itemgetter(1))
|
||||
|
||||
return [ s[0] for s in new_s]
|
||||
|
||||
|
||||
class FaceSamplesLoaderSubprocessor(Subprocessor):
|
||||
#override
|
||||
def __init__(self, image_paths ):
|
||||
self.image_paths = image_paths
|
||||
self.image_paths_len = len(image_paths)
|
||||
self.idxs = [*range(self.image_paths_len)]
|
||||
self.result = [None]*self.image_paths_len
|
||||
super().__init__('FaceSamplesLoader', FaceSamplesLoaderSubprocessor.Cli, 60, initialize_subprocesses_in_serial=False)
|
||||
|
||||
#override
|
||||
def on_clients_initialized(self):
|
||||
io.progress_bar ("Loading", len (self.image_paths))
|
||||
|
||||
#override
|
||||
def on_clients_finalized(self):
|
||||
io.progress_bar_close()
|
||||
|
||||
#override
|
||||
def process_info_generator(self):
|
||||
for i in range(min(multiprocessing.cpu_count(), 8) ):
|
||||
yield 'CPU%d' % (i), {}, {'device_idx': i,
|
||||
'device_name': 'CPU%d' % (i),
|
||||
}
|
||||
|
||||
#override
|
||||
def get_data(self, host_dict):
|
||||
if len (self.idxs) > 0:
|
||||
idx = self.idxs.pop(0)
|
||||
return idx, self.image_paths[idx]
|
||||
|
||||
return None
|
||||
|
||||
#override
|
||||
def on_data_return (self, host_dict, data):
|
||||
self.idxs.insert(0, data[0])
|
||||
|
||||
#override
|
||||
def on_result (self, host_dict, data, result):
|
||||
idx, dflimg = result
|
||||
self.result[idx] = (self.image_paths[idx], dflimg)
|
||||
io.progress_bar_inc(1)
|
||||
|
||||
#override
|
||||
def get_result(self):
|
||||
return self.result
|
||||
|
||||
class Cli(Subprocessor.Cli):
|
||||
#override
|
||||
def on_initialize(self, client_dict):
|
||||
pass
|
||||
|
||||
#override
|
||||
def process_data(self, data):
|
||||
idx, filename = data
|
||||
dflimg = DFLIMG.load (Path(filename))
|
||||
|
||||
if dflimg is None:
|
||||
self.log_err (f"FaceSamplesLoader: {filename} is not a dfl image file.")
|
||||
data = None
|
||||
else:
|
||||
data = (dflimg.get_face_type(),
|
||||
dflimg.get_shape(),
|
||||
dflimg.get_landmarks(),
|
||||
dflimg.get_ie_polys(),
|
||||
dflimg.get_eyebrows_expand_mod(),
|
||||
dflimg.get_source_filename() )
|
||||
|
||||
return idx, data
|
||||
|
||||
#override
|
||||
def get_data_name (self, data):
|
||||
#return string identificator of your data
|
||||
return data[1]
|
|
@ -1,204 +0,0 @@
|
|||
import operator
|
||||
import pickle
|
||||
import traceback
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from facelib import FaceType, LandmarksProcessor
|
||||
from interact import interact as io
|
||||
from utils import Path_utils
|
||||
from utils.DFLJPG import DFLJPG
|
||||
from utils.DFLPNG import DFLPNG
|
||||
|
||||
from .Sample import Sample, SampleType
|
||||
|
||||
|
||||
class SampleLoader:
|
||||
cache = dict()
|
||||
|
||||
@staticmethod
|
||||
def get_person_id_max_count(samples_path):
|
||||
return len ( Path_utils.get_all_dir_names(samples_path) )
|
||||
|
||||
@staticmethod
|
||||
def load(sample_type, samples_path, target_samples_path=None, person_id_mode=False, use_caching=False):
|
||||
cache = SampleLoader.cache
|
||||
|
||||
if str(samples_path) not in cache.keys():
|
||||
cache[str(samples_path)] = [None]*SampleType.QTY
|
||||
|
||||
datas = cache[str(samples_path)]
|
||||
|
||||
if sample_type == SampleType.IMAGE:
|
||||
if datas[sample_type] is None:
|
||||
datas[sample_type] = [ Sample(filename=filename) for filename in io.progress_bar_generator( Path_utils.get_image_paths(samples_path), "Loading") ]
|
||||
elif sample_type == SampleType.FACE:
|
||||
if datas[sample_type] is None:
|
||||
|
||||
if not use_caching:
|
||||
datas[sample_type] = SampleLoader.upgradeToFaceSamples( [ Sample(filename=filename) for filename in Path_utils.get_image_paths(samples_path) ] )
|
||||
else:
|
||||
samples_dat = samples_path / 'samples.dat'
|
||||
if samples_dat.exists():
|
||||
io.log_info (f"Using saved samples info from '{samples_dat}' ")
|
||||
|
||||
all_samples = pickle.loads(samples_dat.read_bytes())
|
||||
|
||||
if person_id_mode:
|
||||
for samples in all_samples:
|
||||
for sample in samples:
|
||||
sample.filename = str( samples_path / Path(sample.filename) )
|
||||
else:
|
||||
for sample in all_samples:
|
||||
sample.filename = str( samples_path / Path(sample.filename) )
|
||||
|
||||
datas[sample_type] = all_samples
|
||||
|
||||
else:
|
||||
if person_id_mode:
|
||||
dir_names = Path_utils.get_all_dir_names(samples_path)
|
||||
all_samples = []
|
||||
for i, dir_name in io.progress_bar_generator( [*enumerate(dir_names)] , "Loading"):
|
||||
all_samples += [ SampleLoader.upgradeToFaceSamples( [ Sample(filename=filename, person_id=i) for filename in Path_utils.get_image_paths( samples_path / dir_name ) ], silent=True ) ]
|
||||
datas[sample_type] = all_samples
|
||||
else:
|
||||
datas[sample_type] = all_samples = SampleLoader.upgradeToFaceSamples( [ Sample(filename=filename) for filename in Path_utils.get_image_paths(samples_path) ] )
|
||||
|
||||
if person_id_mode:
|
||||
for samples in all_samples:
|
||||
for sample in samples:
|
||||
sample.filename = str(Path(sample.filename).relative_to(samples_path))
|
||||
else:
|
||||
for sample in all_samples:
|
||||
sample.filename = str(Path(sample.filename).relative_to(samples_path))
|
||||
|
||||
samples_dat.write_bytes (pickle.dumps(all_samples))
|
||||
|
||||
if person_id_mode:
|
||||
for samples in all_samples:
|
||||
for sample in samples:
|
||||
sample.filename = str( samples_path / Path(sample.filename) )
|
||||
else:
|
||||
for sample in all_samples:
|
||||
sample.filename = str( samples_path / Path(sample.filename) )
|
||||
|
||||
elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
|
||||
if datas[sample_type] is None:
|
||||
datas[sample_type] = SampleLoader.upgradeToFaceTemporalSortedSamples( SampleLoader.load(SampleType.FACE, samples_path) )
|
||||
|
||||
elif sample_type == SampleType.FACE_YAW_SORTED:
|
||||
if datas[sample_type] is None:
|
||||
datas[sample_type] = SampleLoader.upgradeToFaceYawSortedSamples( SampleLoader.load(SampleType.FACE, samples_path) )
|
||||
|
||||
elif sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
|
||||
if datas[sample_type] is None:
|
||||
if target_samples_path is None:
|
||||
raise Exception('target_samples_path is None for FACE_YAW_SORTED_AS_TARGET')
|
||||
datas[sample_type] = SampleLoader.upgradeToFaceYawSortedAsTargetSamples( SampleLoader.load(SampleType.FACE_YAW_SORTED, samples_path), SampleLoader.load(SampleType.FACE_YAW_SORTED, target_samples_path) )
|
||||
|
||||
return datas[sample_type]
|
||||
|
||||
@staticmethod
|
||||
def upgradeToFaceSamples ( samples, silent=False ):
|
||||
sample_list = []
|
||||
|
||||
for s in (samples if silent else io.progress_bar_generator(samples, "Loading")):
|
||||
s_filename_path = Path(s.filename)
|
||||
try:
|
||||
if s_filename_path.suffix == '.png':
|
||||
dflimg = DFLPNG.load ( str(s_filename_path) )
|
||||
elif s_filename_path.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(s_filename_path) )
|
||||
else:
|
||||
dflimg = None
|
||||
|
||||
if dflimg is None:
|
||||
print ("%s is not a dfl image file required for training" % (s_filename_path.name) )
|
||||
continue
|
||||
|
||||
landmarks = dflimg.get_landmarks()
|
||||
pitch_yaw_roll = dflimg.get_pitch_yaw_roll()
|
||||
eyebrows_expand_mod = dflimg.get_eyebrows_expand_mod()
|
||||
|
||||
if pitch_yaw_roll is None:
|
||||
pitch_yaw_roll = LandmarksProcessor.estimate_pitch_yaw_roll(landmarks)
|
||||
|
||||
sample_list.append( s.copy_and_set(sample_type=SampleType.FACE,
|
||||
face_type=FaceType.fromString (dflimg.get_face_type()),
|
||||
shape=dflimg.get_shape(),
|
||||
landmarks=landmarks,
|
||||
ie_polys=dflimg.get_ie_polys(),
|
||||
pitch_yaw_roll=pitch_yaw_roll,
|
||||
eyebrows_expand_mod=eyebrows_expand_mod,
|
||||
source_filename=dflimg.get_source_filename(),
|
||||
fanseg_mask_exist=dflimg.get_fanseg_mask() is not None, ) )
|
||||
except:
|
||||
print ("Unable to load %s , error: %s" % (str(s_filename_path), traceback.format_exc() ) )
|
||||
|
||||
return sample_list
|
||||
|
||||
@staticmethod
|
||||
def upgradeToFaceTemporalSortedSamples( samples ):
|
||||
new_s = [ (s, s.source_filename) for s in samples]
|
||||
new_s = sorted(new_s, key=operator.itemgetter(1))
|
||||
|
||||
return [ s[0] for s in new_s]
|
||||
|
||||
@staticmethod
|
||||
def upgradeToFaceYawSortedSamples( samples ):
|
||||
|
||||
lowest_yaw, highest_yaw = -1.0, 1.0
|
||||
gradations = 64
|
||||
diff_rot_per_grad = abs(highest_yaw-lowest_yaw) / gradations
|
||||
|
||||
yaws_sample_list = [None]*gradations
|
||||
|
||||
for i in io.progress_bar_generator(range(gradations), "Sorting"):
|
||||
yaw = lowest_yaw + i*diff_rot_per_grad
|
||||
next_yaw = lowest_yaw + (i+1)*diff_rot_per_grad
|
||||
|
||||
yaw_samples = []
|
||||
for s in samples:
|
||||
s_yaw = s.pitch_yaw_roll[1]
|
||||
if (i == 0 and s_yaw < next_yaw) or \
|
||||
(i < gradations-1 and s_yaw >= yaw and s_yaw < next_yaw) or \
|
||||
(i == gradations-1 and s_yaw >= yaw):
|
||||
yaw_samples.append ( s.copy_and_set(sample_type=SampleType.FACE_YAW_SORTED) )
|
||||
|
||||
if len(yaw_samples) > 0:
|
||||
yaws_sample_list[i] = yaw_samples
|
||||
|
||||
return yaws_sample_list
|
||||
|
||||
@staticmethod
|
||||
def upgradeToFaceYawSortedAsTargetSamples (s, t):
|
||||
l = len(s)
|
||||
if l != len(t):
|
||||
raise Exception('upgradeToFaceYawSortedAsTargetSamples() s_len != t_len')
|
||||
b = l // 2
|
||||
|
||||
s_idxs = np.argwhere ( np.array ( [ 1 if x != None else 0 for x in s] ) == 1 )[:,0]
|
||||
t_idxs = np.argwhere ( np.array ( [ 1 if x != None else 0 for x in t] ) == 1 )[:,0]
|
||||
|
||||
new_s = [None]*l
|
||||
|
||||
for t_idx in t_idxs:
|
||||
search_idxs = []
|
||||
for i in range(0,l):
|
||||
search_idxs += [t_idx - i, (l-t_idx-1) - i, t_idx + i, (l-t_idx-1) + i]
|
||||
|
||||
for search_idx in search_idxs:
|
||||
if search_idx in s_idxs:
|
||||
mirrored = ( t_idx != search_idx and ((t_idx < b and search_idx >= b) or (search_idx < b and t_idx >= b)) )
|
||||
new_s[t_idx] = [ sample.copy_and_set(sample_type=SampleType.FACE_YAW_SORTED_AS_TARGET,
|
||||
mirror=True,
|
||||
pitch_yaw_roll=(sample.pitch_yaw_roll[0],-sample.pitch_yaw_roll[1],sample.pitch_yaw_roll[2]),
|
||||
landmarks=LandmarksProcessor.mirror_landmarks (sample.landmarks, sample.shape[1] ))
|
||||
for sample in s[search_idx]
|
||||
] if mirrored else s[search_idx]
|
||||
break
|
||||
|
||||
return new_s
|
|
@ -92,231 +92,218 @@ class SampleProcessor(object):
|
|||
}
|
||||
|
||||
@staticmethod
|
||||
def process (sample, sample_process_options, output_sample_types, debug, ct_sample=None):
|
||||
def process (samples, sample_process_options, output_sample_types, debug, ct_sample=None):
|
||||
SPTF = SampleProcessor.Types
|
||||
|
||||
sample_bgr = sample.load_bgr()
|
||||
ct_sample_bgr = None
|
||||
ct_sample_mask = None
|
||||
h,w,c = sample_bgr.shape
|
||||
|
||||
is_face_sample = sample.landmarks is not None
|
||||
|
||||
if debug and is_face_sample:
|
||||
LandmarksProcessor.draw_landmarks (sample_bgr, sample.landmarks, (0, 1, 0))
|
||||
|
||||
params = imagelib.gen_warp_params(sample_bgr, sample_process_options.random_flip, rotation_range=sample_process_options.rotation_range, scale_range=sample_process_options.scale_range, tx_range=sample_process_options.tx_range, ty_range=sample_process_options.ty_range )
|
||||
|
||||
cached_images = collections.defaultdict(dict)
|
||||
|
||||
sample_rnd_seed = np.random.randint(0x80000000)
|
||||
|
||||
outputs = []
|
||||
for opts in output_sample_types:
|
||||
for sample in samples:
|
||||
sample_bgr = sample.load_bgr()
|
||||
ct_sample_bgr = None
|
||||
ct_sample_mask = None
|
||||
h,w,c = sample_bgr.shape
|
||||
|
||||
resolution = opts.get('resolution', 0)
|
||||
types = opts.get('types', [] )
|
||||
is_face_sample = sample.landmarks is not None
|
||||
|
||||
border_replicate = opts.get('border_replicate', True)
|
||||
random_sub_res = opts.get('random_sub_res', 0)
|
||||
normalize_std_dev = opts.get('normalize_std_dev', False)
|
||||
normalize_vgg = opts.get('normalize_vgg', False)
|
||||
motion_blur = opts.get('motion_blur', None)
|
||||
gaussian_blur = opts.get('gaussian_blur', None)
|
||||
if debug and is_face_sample:
|
||||
LandmarksProcessor.draw_landmarks (sample_bgr, sample.landmarks, (0, 1, 0))
|
||||
|
||||
random_hsv_shift = opts.get('random_hsv_shift', None)
|
||||
ct_mode = opts.get('ct_mode', 'None')
|
||||
normalize_tanh = opts.get('normalize_tanh', False)
|
||||
params = imagelib.gen_warp_params(sample_bgr, sample_process_options.random_flip, rotation_range=sample_process_options.rotation_range, scale_range=sample_process_options.scale_range, tx_range=sample_process_options.tx_range, ty_range=sample_process_options.ty_range, rnd_seed=sample_rnd_seed )
|
||||
|
||||
img_type = SPTF.NONE
|
||||
target_face_type = SPTF.NONE
|
||||
face_mask_type = SPTF.NONE
|
||||
mode_type = SPTF.NONE
|
||||
for t in types:
|
||||
if t >= SPTF.IMG_TYPE_BEGIN and t < SPTF.IMG_TYPE_END:
|
||||
img_type = t
|
||||
elif t >= SPTF.FACE_TYPE_BEGIN and t < SPTF.FACE_TYPE_END:
|
||||
target_face_type = t
|
||||
elif t >= SPTF.MODE_BEGIN and t < SPTF.MODE_END:
|
||||
mode_type = t
|
||||
outputs_sample = []
|
||||
for opts in output_sample_types:
|
||||
|
||||
if img_type == SPTF.NONE:
|
||||
raise ValueError ('expected IMG_ type')
|
||||
resolution = opts.get('resolution', 0)
|
||||
types = opts.get('types', [] )
|
||||
|
||||
if img_type == SPTF.IMG_LANDMARKS_ARRAY:
|
||||
l = sample.landmarks
|
||||
l = np.concatenate ( [ np.expand_dims(l[:,0] / w,-1), np.expand_dims(l[:,1] / h,-1) ], -1 )
|
||||
l = np.clip(l, 0.0, 1.0)
|
||||
img = l
|
||||
elif img_type == SPTF.IMG_PITCH_YAW_ROLL or img_type == SPTF.IMG_PITCH_YAW_ROLL_SIGMOID:
|
||||
pitch_yaw_roll = sample.pitch_yaw_roll
|
||||
if pitch_yaw_roll is not None:
|
||||
pitch, yaw, roll = pitch_yaw_roll
|
||||
border_replicate = opts.get('border_replicate', True)
|
||||
random_sub_res = opts.get('random_sub_res', 0)
|
||||
normalize_std_dev = opts.get('normalize_std_dev', False)
|
||||
normalize_vgg = opts.get('normalize_vgg', False)
|
||||
motion_blur = opts.get('motion_blur', None)
|
||||
gaussian_blur = opts.get('gaussian_blur', None)
|
||||
|
||||
ct_mode = opts.get('ct_mode', 'None')
|
||||
normalize_tanh = opts.get('normalize_tanh', False)
|
||||
|
||||
img_type = SPTF.NONE
|
||||
target_face_type = SPTF.NONE
|
||||
face_mask_type = SPTF.NONE
|
||||
mode_type = SPTF.NONE
|
||||
for t in types:
|
||||
if t >= SPTF.IMG_TYPE_BEGIN and t < SPTF.IMG_TYPE_END:
|
||||
img_type = t
|
||||
elif t >= SPTF.FACE_TYPE_BEGIN and t < SPTF.FACE_TYPE_END:
|
||||
target_face_type = t
|
||||
elif t >= SPTF.MODE_BEGIN and t < SPTF.MODE_END:
|
||||
mode_type = t
|
||||
|
||||
if img_type == SPTF.NONE:
|
||||
raise ValueError ('expected IMG_ type')
|
||||
|
||||
if img_type == SPTF.IMG_LANDMARKS_ARRAY:
|
||||
l = sample.landmarks
|
||||
l = np.concatenate ( [ np.expand_dims(l[:,0] / w,-1), np.expand_dims(l[:,1] / h,-1) ], -1 )
|
||||
l = np.clip(l, 0.0, 1.0)
|
||||
img = l
|
||||
elif img_type == SPTF.IMG_PITCH_YAW_ROLL or img_type == SPTF.IMG_PITCH_YAW_ROLL_SIGMOID:
|
||||
pitch_yaw_roll = sample.get_pitch_yaw_roll()
|
||||
|
||||
if params['flip']:
|
||||
yaw = -yaw
|
||||
|
||||
if img_type == SPTF.IMG_PITCH_YAW_ROLL_SIGMOID:
|
||||
pitch = (pitch+1.0) / 2.0
|
||||
yaw = (yaw+1.0) / 2.0
|
||||
roll = (roll+1.0) / 2.0
|
||||
|
||||
img = (pitch, yaw, roll)
|
||||
else:
|
||||
pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll (sample.landmarks)
|
||||
if params['flip']:
|
||||
yaw = -yaw
|
||||
if mode_type == SPTF.NONE:
|
||||
raise ValueError ('expected MODE_ type')
|
||||
|
||||
if img_type == SPTF.IMG_PITCH_YAW_ROLL_SIGMOID:
|
||||
pitch = (pitch+1.0) / 2.0
|
||||
yaw = (yaw+1.0) / 2.0
|
||||
roll = (roll+1.0) / 2.0
|
||||
def do_transform(img, mask):
|
||||
warp = (img_type==SPTF.IMG_WARPED or img_type==SPTF.IMG_WARPED_TRANSFORMED)
|
||||
transform = (img_type==SPTF.IMG_WARPED_TRANSFORMED or img_type==SPTF.IMG_TRANSFORMED)
|
||||
flip = img_type != SPTF.IMG_WARPED
|
||||
|
||||
img = (pitch, yaw, roll)
|
||||
else:
|
||||
if mode_type == SPTF.NONE:
|
||||
raise ValueError ('expected MODE_ type')
|
||||
img = imagelib.warp_by_params (params, img, warp, transform, flip, border_replicate)
|
||||
if mask is not None:
|
||||
mask = imagelib.warp_by_params (params, mask, warp, transform, flip, False)
|
||||
if len(mask.shape) == 2:
|
||||
mask = mask[...,np.newaxis]
|
||||
|
||||
def do_transform(img, mask):
|
||||
warp = (img_type==SPTF.IMG_WARPED or img_type==SPTF.IMG_WARPED_TRANSFORMED)
|
||||
transform = (img_type==SPTF.IMG_WARPED_TRANSFORMED or img_type==SPTF.IMG_TRANSFORMED)
|
||||
flip = img_type != SPTF.IMG_WARPED
|
||||
|
||||
img = imagelib.warp_by_params (params, img, warp, transform, flip, border_replicate)
|
||||
if mask is not None:
|
||||
mask = imagelib.warp_by_params (params, mask, warp, transform, flip, False)
|
||||
if len(mask.shape) == 2:
|
||||
mask = mask[...,np.newaxis]
|
||||
return img, mask
|
||||
|
||||
img = np.concatenate( (img, mask ), -1 )
|
||||
return img
|
||||
img = sample_bgr
|
||||
|
||||
img = sample_bgr
|
||||
|
||||
### Prepare a mask
|
||||
mask = None
|
||||
if is_face_sample:
|
||||
mask = sample.load_fanseg_mask() #using fanseg_mask if exist
|
||||
|
||||
if mask is None:
|
||||
### Prepare a mask
|
||||
mask = None
|
||||
if is_face_sample:
|
||||
if sample.eyebrows_expand_mod is not None:
|
||||
mask = LandmarksProcessor.get_image_hull_mask (img.shape, sample.landmarks, eyebrows_expand_mod=sample.eyebrows_expand_mod )
|
||||
else:
|
||||
mask = LandmarksProcessor.get_image_hull_mask (img.shape, sample.landmarks)
|
||||
|
||||
if sample.ie_polys is not None:
|
||||
sample.ie_polys.overlay_mask(mask)
|
||||
##################
|
||||
if sample.ie_polys is not None:
|
||||
sample.ie_polys.overlay_mask(mask)
|
||||
##################
|
||||
|
||||
|
||||
if motion_blur is not None:
|
||||
chance, mb_max_size = motion_blur
|
||||
chance = np.clip(chance, 0, 100)
|
||||
if motion_blur is not None:
|
||||
chance, mb_max_size = motion_blur
|
||||
chance = np.clip(chance, 0, 100)
|
||||
|
||||
if np.random.randint(100) < chance:
|
||||
img = imagelib.LinearMotionBlur (img, np.random.randint( mb_max_size )+1, np.random.randint(360) )
|
||||
if np.random.randint(100) < chance:
|
||||
img = imagelib.LinearMotionBlur (img, np.random.randint( mb_max_size )+1, np.random.randint(360) )
|
||||
|
||||
if gaussian_blur is not None:
|
||||
chance, kernel_max_size = gaussian_blur
|
||||
chance = np.clip(chance, 0, 100)
|
||||
if gaussian_blur is not None:
|
||||
chance, kernel_max_size = gaussian_blur
|
||||
chance = np.clip(chance, 0, 100)
|
||||
|
||||
if np.random.randint(100) < chance:
|
||||
img = cv2.GaussianBlur(img, ( np.random.randint( kernel_max_size )*2+1 ,) *2 , 0)
|
||||
if np.random.randint(100) < chance:
|
||||
img = cv2.GaussianBlur(img, ( np.random.randint( kernel_max_size )*2+1 ,) *2 , 0)
|
||||
|
||||
if is_face_sample and target_face_type != SPTF.NONE:
|
||||
target_ft = SampleProcessor.SPTF_FACETYPE_TO_FACETYPE[target_face_type]
|
||||
if target_ft > sample.face_type:
|
||||
raise Exception ('sample %s type %s does not match model requirement %s. Consider extract necessary type of faces.' % (sample.filename, sample.face_type, target_ft) )
|
||||
if is_face_sample and target_face_type != SPTF.NONE:
|
||||
target_ft = SampleProcessor.SPTF_FACETYPE_TO_FACETYPE[target_face_type]
|
||||
if target_ft > sample.face_type:
|
||||
raise Exception ('sample %s type %s does not match model requirement %s. Consider extract necessary type of faces.' % (sample.filename, sample.face_type, target_ft) )
|
||||
|
||||
if sample.face_type == FaceType.MARK_ONLY:
|
||||
#first warp to target facetype
|
||||
img = cv2.warpAffine( img, LandmarksProcessor.get_transform_mat (sample.landmarks, sample.shape[0], target_ft), (sample.shape[0],sample.shape[0]), flags=cv2.INTER_CUBIC )
|
||||
mask = cv2.warpAffine( mask, LandmarksProcessor.get_transform_mat (sample.landmarks, sample.shape[0], target_ft), (sample.shape[0],sample.shape[0]), flags=cv2.INTER_CUBIC )
|
||||
#then apply transforms
|
||||
img = do_transform (img, mask)
|
||||
if sample.face_type == FaceType.MARK_ONLY:
|
||||
#first warp to target facetype
|
||||
img = cv2.warpAffine( img, LandmarksProcessor.get_transform_mat (sample.landmarks, sample.shape[0], target_ft), (sample.shape[0],sample.shape[0]), flags=cv2.INTER_CUBIC )
|
||||
mask = cv2.warpAffine( mask, LandmarksProcessor.get_transform_mat (sample.landmarks, sample.shape[0], target_ft), (sample.shape[0],sample.shape[0]), flags=cv2.INTER_CUBIC )
|
||||
#then apply transforms
|
||||
img, mask = do_transform (img, mask)
|
||||
img = np.concatenate( (img, mask ), -1 )
|
||||
img = cv2.resize( img, (resolution,resolution), cv2.INTER_CUBIC )
|
||||
else:
|
||||
img, mask = do_transform (img, mask)
|
||||
|
||||
mat = LandmarksProcessor.get_transform_mat (sample.landmarks, resolution, target_ft)
|
||||
img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=(cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT), flags=cv2.INTER_CUBIC )
|
||||
mask = cv2.warpAffine( mask, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_CUBIC )
|
||||
img = np.concatenate( (img, mask[...,None] ), -1 )
|
||||
|
||||
else:
|
||||
img, mask = do_transform (img, mask)
|
||||
img = np.concatenate( (img, mask ), -1 )
|
||||
img = cv2.resize( img, (resolution,resolution), cv2.INTER_CUBIC )
|
||||
else:
|
||||
img = do_transform (img, mask)
|
||||
img = cv2.warpAffine( img, LandmarksProcessor.get_transform_mat (sample.landmarks, resolution, target_ft), (resolution,resolution), borderMode=(cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT), flags=cv2.INTER_CUBIC )
|
||||
|
||||
else:
|
||||
img = do_transform (img, mask)
|
||||
img = cv2.resize( img, (resolution,resolution), cv2.INTER_CUBIC )
|
||||
if random_sub_res != 0:
|
||||
sub_size = resolution - random_sub_res
|
||||
rnd_state = np.random.RandomState (sample_rnd_seed+random_sub_res)
|
||||
start_x = rnd_state.randint(sub_size+1)
|
||||
start_y = rnd_state.randint(sub_size+1)
|
||||
img = img[start_y:start_y+sub_size,start_x:start_x+sub_size,:]
|
||||
|
||||
if random_sub_res != 0:
|
||||
sub_size = resolution - random_sub_res
|
||||
rnd_state = np.random.RandomState (sample_rnd_seed+random_sub_res)
|
||||
start_x = rnd_state.randint(sub_size+1)
|
||||
start_y = rnd_state.randint(sub_size+1)
|
||||
img = img[start_y:start_y+sub_size,start_x:start_x+sub_size,:]
|
||||
img = np.clip(img, 0, 1).astype(np.float32)
|
||||
img_bgr = img[...,0:3]
|
||||
img_mask = img[...,3:4]
|
||||
|
||||
img = np.clip(img, 0, 1).astype(np.float32)
|
||||
img_bgr = img[...,0:3]
|
||||
img_mask = img[...,3:4]
|
||||
if ct_mode is not None and ct_sample is not None:
|
||||
if ct_sample_bgr is None:
|
||||
ct_sample_bgr = ct_sample.load_bgr()
|
||||
|
||||
if ct_mode is not None and ct_sample is not None:
|
||||
if ct_sample_bgr is None:
|
||||
ct_sample_bgr = ct_sample.load_bgr()
|
||||
ct_sample_bgr_resized = cv2.resize( ct_sample_bgr, (resolution,resolution), cv2.INTER_LINEAR )
|
||||
|
||||
ct_sample_bgr_resized = cv2.resize( ct_sample_bgr, (resolution,resolution), cv2.INTER_LINEAR )
|
||||
if ct_mode == 'lct':
|
||||
img_bgr = imagelib.linear_color_transfer (img_bgr, ct_sample_bgr_resized)
|
||||
img_bgr = np.clip( img_bgr, 0.0, 1.0)
|
||||
elif ct_mode == 'rct':
|
||||
img_bgr = imagelib.reinhard_color_transfer ( np.clip( (img_bgr*255).astype(np.uint8), 0, 255),
|
||||
np.clip( (ct_sample_bgr_resized*255).astype(np.uint8), 0, 255) )
|
||||
img_bgr = np.clip( img_bgr.astype(np.float32) / 255.0, 0.0, 1.0)
|
||||
elif ct_mode == 'mkl':
|
||||
img_bgr = imagelib.color_transfer_mkl (img_bgr, ct_sample_bgr_resized)
|
||||
elif ct_mode == 'idt':
|
||||
img_bgr = imagelib.color_transfer_idt (img_bgr, ct_sample_bgr_resized)
|
||||
elif ct_mode == 'sot':
|
||||
img_bgr = imagelib.color_transfer_sot (img_bgr, ct_sample_bgr_resized)
|
||||
img_bgr = np.clip( img_bgr, 0.0, 1.0)
|
||||
|
||||
if ct_mode == 'lct':
|
||||
img_bgr = imagelib.linear_color_transfer (img_bgr, ct_sample_bgr_resized)
|
||||
img_bgr = np.clip( img_bgr, 0.0, 1.0)
|
||||
elif ct_mode == 'rct':
|
||||
img_bgr = imagelib.reinhard_color_transfer ( np.clip( (img_bgr*255).astype(np.uint8), 0, 255),
|
||||
np.clip( (ct_sample_bgr_resized*255).astype(np.uint8), 0, 255) )
|
||||
img_bgr = np.clip( img_bgr.astype(np.float32) / 255.0, 0.0, 1.0)
|
||||
elif ct_mode == 'mkl':
|
||||
img_bgr = imagelib.color_transfer_mkl (img_bgr, ct_sample_bgr_resized)
|
||||
elif ct_mode == 'idt':
|
||||
img_bgr = imagelib.color_transfer_idt (img_bgr, ct_sample_bgr_resized)
|
||||
elif ct_mode == 'sot':
|
||||
img_bgr = imagelib.color_transfer_sot (img_bgr, ct_sample_bgr_resized)
|
||||
img_bgr = np.clip( img_bgr, 0.0, 1.0)
|
||||
if normalize_std_dev:
|
||||
img_bgr = (img_bgr - img_bgr.mean( (0,1)) ) / img_bgr.std( (0,1) )
|
||||
elif normalize_vgg:
|
||||
img_bgr = np.clip(img_bgr*255, 0, 255)
|
||||
img_bgr[:,:,0] -= 103.939
|
||||
img_bgr[:,:,1] -= 116.779
|
||||
img_bgr[:,:,2] -= 123.68
|
||||
|
||||
if random_hsv_shift:
|
||||
rnd_state = np.random.RandomState (sample_rnd_seed)
|
||||
hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
|
||||
h, s, v = cv2.split(hsv)
|
||||
if mode_type == SPTF.MODE_BGR:
|
||||
img = img_bgr
|
||||
elif mode_type == SPTF.MODE_BGR_SHUFFLE:
|
||||
rnd_state = np.random.RandomState (sample_rnd_seed)
|
||||
img = np.take (img_bgr, rnd_state.permutation(img_bgr.shape[-1]), axis=-1)
|
||||
|
||||
h = (h + rnd_state.randint(360) ) % 360
|
||||
s = np.clip ( s + rnd_state.random()-0.5, 0, 1 )
|
||||
v = np.clip ( v + rnd_state.random()-0.5, 0, 1 )
|
||||
hsv = cv2.merge([h, s, v])
|
||||
elif mode_type == SPTF.MODE_BGR_RANDOM_HSV_SHIFT:
|
||||
rnd_state = np.random.RandomState (sample_rnd_seed)
|
||||
hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
|
||||
h, s, v = cv2.split(hsv)
|
||||
h = (h + rnd_state.randint(360) ) % 360
|
||||
s = np.clip ( s + rnd_state.random()-0.5, 0, 1 )
|
||||
v = np.clip ( v + rnd_state.random()-0.5, 0, 1 )
|
||||
hsv = cv2.merge([h, s, v])
|
||||
img = np.clip( cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) , 0, 1 )
|
||||
elif mode_type == SPTF.MODE_G:
|
||||
img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)[...,None]
|
||||
elif mode_type == SPTF.MODE_GGG:
|
||||
img = np.repeat ( np.expand_dims(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY),-1), (3,), -1)
|
||||
elif mode_type == SPTF.MODE_M and is_face_sample:
|
||||
img = img_mask
|
||||
|
||||
img_bgr = np.clip( cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) , 0, 1 )
|
||||
if not debug:
|
||||
if normalize_tanh:
|
||||
img = np.clip (img * 2.0 - 1.0, -1.0, 1.0)
|
||||
else:
|
||||
img = np.clip (img, 0.0, 1.0)
|
||||
|
||||
if normalize_std_dev:
|
||||
img_bgr = (img_bgr - img_bgr.mean( (0,1)) ) / img_bgr.std( (0,1) )
|
||||
elif normalize_vgg:
|
||||
img_bgr = np.clip(img_bgr*255, 0, 255)
|
||||
img_bgr[:,:,0] -= 103.939
|
||||
img_bgr[:,:,1] -= 116.779
|
||||
img_bgr[:,:,2] -= 123.68
|
||||
outputs_sample.append ( img )
|
||||
outputs += [outputs_sample]
|
||||
|
||||
if mode_type == SPTF.MODE_BGR:
|
||||
img = img_bgr
|
||||
elif mode_type == SPTF.MODE_BGR_SHUFFLE:
|
||||
rnd_state = np.random.RandomState (sample_rnd_seed)
|
||||
img = np.take (img_bgr, rnd_state.permutation(img_bgr.shape[-1]), axis=-1)
|
||||
elif mode_type == SPTF.MODE_G:
|
||||
img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)[...,None]
|
||||
elif mode_type == SPTF.MODE_GGG:
|
||||
img = np.repeat ( np.expand_dims(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY),-1), (3,), -1)
|
||||
elif mode_type == SPTF.MODE_M and is_face_sample:
|
||||
img = img_mask
|
||||
|
||||
if not debug:
|
||||
if normalize_tanh:
|
||||
img = np.clip (img * 2.0 - 1.0, -1.0, 1.0)
|
||||
else:
|
||||
img = np.clip (img, 0.0, 1.0)
|
||||
|
||||
outputs.append ( img )
|
||||
|
||||
if debug:
|
||||
result = []
|
||||
|
||||
for output in outputs:
|
||||
if output.shape[2] < 4:
|
||||
result += [output,]
|
||||
elif output.shape[2] == 4:
|
||||
result += [output[...,0:3]*output[...,3:4],]
|
||||
|
||||
return result
|
||||
else:
|
||||
return outputs
|
||||
return outputs
|
||||
|
||||
"""
|
||||
close_sample = sample.close_target_list[ np.random.randint(0, len(sample.close_target_list)) ] if sample.close_target_list is not None else None
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
from .Sample import Sample
|
||||
from .Sample import SampleType
|
||||
from .SampleLoader import SampleLoader
|
||||
from .SampleHost import SampleHost
|
||||
from .SampleProcessor import SampleProcessor
|
||||
from .SampleGeneratorBase import SampleGeneratorBase
|
||||
from .SampleGeneratorFace import SampleGeneratorFace
|
||||
from .SampleGeneratorFacePerson import SampleGeneratorFacePerson
|
||||
from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal
|
||||
from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal
|
||||
from .PackedFaceset import PackedFaceset
|
|
@ -1,15 +1,20 @@
|
|||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import traceback
|
||||
|
||||
#allows to open non-english characters path
|
||||
def cv2_imread(filename, flags=cv2.IMREAD_UNCHANGED):
|
||||
def cv2_imread(filename, flags=cv2.IMREAD_UNCHANGED, loader_func=None):
|
||||
try:
|
||||
with open(filename, "rb") as stream:
|
||||
bytes = bytearray(stream.read())
|
||||
numpyarray = np.asarray(bytes, dtype=np.uint8)
|
||||
return cv2.imdecode(numpyarray, flags)
|
||||
if loader_func is not None:
|
||||
bytes = bytearray(loader_func(filename))
|
||||
else:
|
||||
with open(filename, "rb") as stream:
|
||||
bytes = bytearray(stream.read())
|
||||
numpyarray = np.asarray(bytes, dtype=np.uint8)
|
||||
return cv2.imdecode(numpyarray, flags)
|
||||
except:
|
||||
io.log_err(f"Exception occured in cv2_imread : {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
def cv2_imwrite(filename, img, *args):
|
||||
|
|
|
@ -22,7 +22,7 @@ class ThisThreadGenerator(object):
|
|||
return next(self.generator_func)
|
||||
|
||||
class SubprocessGenerator(object):
|
||||
def __init__(self, generator_func, user_param=None, prefetch=2):
|
||||
def __init__(self, generator_func, user_param=None, prefetch=2, start_now=False):
|
||||
super().__init__()
|
||||
self.prefetch = prefetch
|
||||
self.generator_func = generator_func
|
||||
|
@ -30,6 +30,16 @@ class SubprocessGenerator(object):
|
|||
self.sc_queue = multiprocessing.Queue()
|
||||
self.cs_queue = multiprocessing.Queue()
|
||||
self.p = None
|
||||
if start_now:
|
||||
self._start()
|
||||
|
||||
def _start(self):
|
||||
if self.p == None:
|
||||
user_param = self.user_param
|
||||
self.user_param = None
|
||||
self.p = multiprocessing.Process(target=self.process_func, args=(user_param,) )
|
||||
self.p.daemon = True
|
||||
self.p.start()
|
||||
|
||||
def process_func(self, user_param):
|
||||
self.generator_func = self.generator_func(user_param)
|
||||
|
@ -54,13 +64,7 @@ class SubprocessGenerator(object):
|
|||
return self_dict
|
||||
|
||||
def __next__(self):
|
||||
if self.p == None:
|
||||
user_param = self.user_param
|
||||
self.user_param = None
|
||||
self.p = multiprocessing.Process(target=self.process_func, args=(user_param,) )
|
||||
self.p.daemon = True
|
||||
self.p.start()
|
||||
|
||||
self._start()
|
||||
gen_data = self.cs_queue.get()
|
||||
if gen_data is None:
|
||||
self.p.terminate()
|
||||
|
|
269
utils/mp_utils.py
Normal file
269
utils/mp_utils.py
Normal file
|
@ -0,0 +1,269 @@
|
|||
import multiprocessing
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
class Index2DHost():
|
||||
"""
|
||||
Provides random shuffled 2D indexes for multiprocesses
|
||||
"""
|
||||
def __init__(self, indexes2D):
|
||||
self.sq = multiprocessing.Queue()
|
||||
self.cqs = []
|
||||
self.clis = []
|
||||
self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) )
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def host_thread(self, indexes2D):
|
||||
indexes_counts_len = len(indexes2D)
|
||||
|
||||
idxs = [*range(indexes_counts_len)]
|
||||
idxs_2D = [None]*indexes_counts_len
|
||||
shuffle_idxs = []
|
||||
shuffle_idxs_2D = [None]*indexes_counts_len
|
||||
for i in range(indexes_counts_len):
|
||||
idxs_2D[i] = indexes2D[i]
|
||||
shuffle_idxs_2D[i] = []
|
||||
|
||||
sq = self.sq
|
||||
|
||||
while True:
|
||||
while not sq.empty():
|
||||
obj = sq.get()
|
||||
cq_id, cmd = obj[0], obj[1]
|
||||
|
||||
if cmd == 0: #get_1D
|
||||
count = obj[2]
|
||||
|
||||
result = []
|
||||
for i in range(count):
|
||||
if len(shuffle_idxs) == 0:
|
||||
shuffle_idxs = idxs.copy()
|
||||
np.random.shuffle(shuffle_idxs)
|
||||
result.append(shuffle_idxs.pop())
|
||||
self.cqs[cq_id].put (result)
|
||||
elif cmd == 1: #get_2D
|
||||
targ_idxs,count = obj[2], obj[3]
|
||||
result = []
|
||||
|
||||
for targ_idx in targ_idxs:
|
||||
sub_idxs = []
|
||||
for i in range(count):
|
||||
ar = shuffle_idxs_2D[targ_idx]
|
||||
if len(ar) == 0:
|
||||
ar = shuffle_idxs_2D[targ_idx] = idxs_2D[targ_idx].copy()
|
||||
np.random.shuffle(ar)
|
||||
sub_idxs.append(ar.pop())
|
||||
result.append (sub_idxs)
|
||||
self.cqs[cq_id].put (result)
|
||||
|
||||
time.sleep(0.005)
|
||||
|
||||
def create_cli(self):
|
||||
cq = multiprocessing.Queue()
|
||||
self.cqs.append ( cq )
|
||||
cq_id = len(self.cqs)-1
|
||||
return Index2DHost.Cli(self.sq, cq, cq_id)
|
||||
|
||||
# disable pickling
|
||||
def __getstate__(self):
|
||||
return dict()
|
||||
def __setstate__(self, d):
|
||||
self.__dict__.update(d)
|
||||
|
||||
class Cli():
|
||||
def __init__(self, sq, cq, cq_id):
|
||||
self.sq = sq
|
||||
self.cq = cq
|
||||
self.cq_id = cq_id
|
||||
|
||||
def get_1D(self, count):
|
||||
self.sq.put ( (self.cq_id,0, count) )
|
||||
|
||||
while True:
|
||||
if not self.cq.empty():
|
||||
return self.cq.get()
|
||||
time.sleep(0.001)
|
||||
|
||||
def get_2D(self, idxs, count):
|
||||
self.sq.put ( (self.cq_id,1,idxs,count) )
|
||||
|
||||
while True:
|
||||
if not self.cq.empty():
|
||||
return self.cq.get()
|
||||
time.sleep(0.001)
|
||||
|
||||
class IndexHost():
|
||||
"""
|
||||
Provides random shuffled indexes for multiprocesses
|
||||
"""
|
||||
def __init__(self, indexes_count):
|
||||
self.sq = multiprocessing.Queue()
|
||||
self.cqs = []
|
||||
self.clis = []
|
||||
self.thread = threading.Thread(target=self.host_thread, args=(indexes_count,) )
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def host_thread(self, indexes_count):
|
||||
idxs = [*range(indexes_count)]
|
||||
shuffle_idxs = []
|
||||
sq = self.sq
|
||||
|
||||
while True:
|
||||
while not sq.empty():
|
||||
obj = sq.get()
|
||||
cq_id, count = obj[0], obj[1]
|
||||
|
||||
result = []
|
||||
for i in range(count):
|
||||
if len(shuffle_idxs) == 0:
|
||||
shuffle_idxs = idxs.copy()
|
||||
np.random.shuffle(shuffle_idxs)
|
||||
result.append(shuffle_idxs.pop())
|
||||
self.cqs[cq_id].put (result)
|
||||
|
||||
time.sleep(0.005)
|
||||
|
||||
def create_cli(self):
|
||||
cq = multiprocessing.Queue()
|
||||
self.cqs.append ( cq )
|
||||
cq_id = len(self.cqs)-1
|
||||
return IndexHost.Cli(self.sq, cq, cq_id)
|
||||
|
||||
# disable pickling
|
||||
def __getstate__(self):
|
||||
return dict()
|
||||
def __setstate__(self, d):
|
||||
self.__dict__.update(d)
|
||||
|
||||
class Cli():
|
||||
def __init__(self, sq, cq, cq_id):
|
||||
self.sq = sq
|
||||
self.cq = cq
|
||||
self.cq_id = cq_id
|
||||
|
||||
def get(self, count):
|
||||
self.sq.put ( (self.cq_id,count) )
|
||||
|
||||
while True:
|
||||
if not self.cq.empty():
|
||||
return self.cq.get()
|
||||
time.sleep(0.001)
|
||||
|
||||
class ListHost():
|
||||
def __init__(self, list_):
|
||||
self.sq = multiprocessing.Queue()
|
||||
self.cqs = []
|
||||
self.clis = []
|
||||
self.list_ = list_
|
||||
self.thread = threading.Thread(target=self.host_thread)
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def host_thread(self):
|
||||
sq = self.sq
|
||||
while True:
|
||||
while not sq.empty():
|
||||
obj = sq.get()
|
||||
cq_id, cmd = obj[0], obj[1]
|
||||
if cmd == 0:
|
||||
item = self.list_[ obj[2] ]
|
||||
self.cqs[cq_id].put ( item )
|
||||
|
||||
elif cmd == 1:
|
||||
self.cqs[cq_id].put ( len(self.list_) )
|
||||
time.sleep(0.005)
|
||||
|
||||
def create_cli(self):
|
||||
cq = multiprocessing.Queue()
|
||||
self.cqs.append ( cq )
|
||||
cq_id = len(self.cqs)-1
|
||||
return ListHost.Cli(self.sq, cq, cq_id)
|
||||
|
||||
def get_list(self):
|
||||
return self.list_
|
||||
|
||||
# disable pickling
|
||||
def __getstate__(self):
|
||||
return dict()
|
||||
def __setstate__(self, d):
|
||||
self.__dict__.update(d)
|
||||
|
||||
class Cli():
|
||||
def __init__(self, sq, cq, cq_id):
|
||||
self.sq = sq
|
||||
self.cq = cq
|
||||
self.cq_id = cq_id
|
||||
|
||||
def __getitem__(self, key):
|
||||
self.sq.put ( (self.cq_id,0,key) )
|
||||
|
||||
while True:
|
||||
if not self.cq.empty():
|
||||
return self.cq.get()
|
||||
time.sleep(0.001)
|
||||
|
||||
def __len__(self):
|
||||
self.sq.put ( (self.cq_id,1) )
|
||||
|
||||
while True:
|
||||
if not self.cq.empty():
|
||||
return self.cq.get()
|
||||
time.sleep(0.001)
|
||||
|
||||
class DictHost():
|
||||
def __init__(self, d, num_users):
|
||||
self.sqs = [ multiprocessing.Queue() for _ in range(num_users) ]
|
||||
self.cqs = [ multiprocessing.Queue() for _ in range(num_users) ]
|
||||
|
||||
self.thread = threading.Thread(target=self.host_thread, args=(d,) )
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
self.clis = [ DictHostCli(sq,cq) for sq, cq in zip(self.sqs, self.cqs) ]
|
||||
|
||||
def host_thread(self, d):
|
||||
while True:
|
||||
for sq, cq in zip(self.sqs, self.cqs):
|
||||
if not sq.empty():
|
||||
obj = sq.get()
|
||||
cmd = obj[0]
|
||||
if cmd == 0:
|
||||
cq.put (d[ obj[1] ])
|
||||
elif cmd == 1:
|
||||
cq.put ( list(d.keys()) )
|
||||
|
||||
time.sleep(0.005)
|
||||
|
||||
|
||||
def get_cli(self, n_user):
|
||||
return self.clis[n_user]
|
||||
|
||||
# disable pickling
|
||||
def __getstate__(self):
|
||||
return dict()
|
||||
def __setstate__(self, d):
|
||||
self.__dict__.update(d)
|
||||
|
||||
class DictHostCli():
|
||||
def __init__(self, sq, cq):
|
||||
self.sq = sq
|
||||
self.cq = cq
|
||||
|
||||
def __getitem__(self, key):
|
||||
self.sq.put ( (0,key) )
|
||||
|
||||
while True:
|
||||
if not self.cq.empty():
|
||||
return self.cq.get()
|
||||
time.sleep(0.001)
|
||||
|
||||
def keys(self):
|
||||
self.sq.put ( (1,) )
|
||||
while True:
|
||||
if not self.cq.empty():
|
||||
return self.cq.get()
|
||||
time.sleep(0.001)
|
Loading…
Add table
Add a link
Reference in a new issue