Compare commits
428 commits
pretrained
...
master
Author | SHA1 | Date | |
---|---|---|---|
|
e4b7543ffa | ||
|
235caf2b79 | ||
|
8489e58e04 | ||
|
a27a30b682 | ||
|
46fc2397c5 | ||
|
9ef04b2207 | ||
|
19ab857cfb | ||
|
501f1e2ea3 | ||
|
cd83f6fedf | ||
|
5c942764df | ||
|
2ac860b074 | ||
|
521a23f557 | ||
|
0d19d8ec8e | ||
|
b9e571d7e1 | ||
|
9704c5d8f8 | ||
|
8c79e70191 | ||
|
292b562eb7 | ||
|
9b331e59dd | ||
|
1badaa4b4c | ||
|
388964e8d0 | ||
|
4b6bf003b8 | ||
|
ce624735c0 | ||
|
7326771c02 | ||
|
1aa9463edf | ||
|
853312e84c | ||
|
71f22957a6 | ||
|
f71a9e6bbd | ||
|
a1ba64be10 | ||
|
14cc9d4e5f | ||
|
d8c7cc3d93 | ||
|
f64b2495d9 | ||
|
f48e852de3 | ||
|
f700bd5954 | ||
|
80cae7d9f5 | ||
|
91f0f30a78 | ||
|
0a203bb25e | ||
|
b1b5d6f482 | ||
|
f8469fe4d7 | ||
|
3aa2b56eda | ||
|
c6d0c547b7 | ||
|
9e0079c6a0 | ||
|
33ff0be722 | ||
|
8a897f236f | ||
|
c1cee5d3ca | ||
|
7b96538cfb | ||
|
9f73ed0e7e | ||
|
cc133c0702 | ||
|
b9c9552db6 | ||
|
bd0b408a9b | ||
|
335da71db6 | ||
|
70c9ee9463 | ||
|
d301a4799f | ||
|
b1990d421a | ||
|
3fe8ce86b1 | ||
|
f346e35e7f | ||
|
6e094d873d | ||
|
01f1a084b4 | ||
|
e53be5e22d | ||
|
91187ecb95 | ||
|
8e63666390 | ||
|
e41f87e682 | ||
|
b256b07e03 | ||
|
57f3393ec2 | ||
|
56e70edc46 | ||
|
26c83f6e35 | ||
|
664c8fd105 | ||
|
8a455fc238 | ||
|
fce90a1b44 | ||
|
f95b3cdd33 | ||
|
60ade74327 | ||
|
f6bd45cce0 | ||
|
c5584fbda0 | ||
|
623eb3856d | ||
|
4aa33d3ba8 | ||
|
a299324166 | ||
|
c8e6f23a31 | ||
|
4581d48ff0 | ||
|
f99b8a0842 | ||
|
931ec36b41 | ||
|
55b947eab5 | ||
|
83b1412da7 | ||
|
f5cc54177f | ||
|
d26d579836 | ||
|
9e1bc5a153 | ||
|
4be135af60 | ||
|
bfa88c5fd9 | ||
|
5c7e2f310a | ||
|
c5044580c7 | ||
|
850b1c5cc9 | ||
|
493f23c618 | ||
|
9ad9728b40 | ||
|
336e0ca944 | ||
|
8b71b83c94 | ||
|
0748b8d043 | ||
|
959a3530f8 | ||
|
da8f33ee85 | ||
|
c9e0ba1779 | ||
|
7b8991a3cc | ||
|
f044c99ddc | ||
|
2edac3df8c | ||
|
35877dbfd7 | ||
|
ee1bc83a14 | ||
|
b4b72d056f | ||
|
9e0bd34253 | ||
|
446b95942f | ||
|
0444a9a565 | ||
|
075422e0c5 | ||
|
63c794b3d0 | ||
|
9d6b6feb1f | ||
|
c43b3b161b | ||
|
90a74efd89 | ||
|
714234bce4 | ||
|
3c6bbe22b9 | ||
|
5783191849 | ||
|
6d89d7fa4c | ||
|
9092e73c41 | ||
|
427e51b413 | ||
|
62f1d57871 | ||
|
62c6fffdde | ||
|
9f0b4bf3cf | ||
|
6e0c8222f7 | ||
|
5dc027a8b0 | ||
|
24ba84d4a5 | ||
|
2a95699b1c | ||
|
8ea7820b9c | ||
|
a6438ca494 | ||
|
5fac5ee4f3 | ||
|
34b62862e0 | ||
|
1981ed0ca8 | ||
|
315f241c51 | ||
|
e52b53f87c | ||
|
11a7993238 | ||
|
6f86c68e65 | ||
|
e6e2ee7466 | ||
|
757283d10e | ||
|
766750941a | ||
|
41b517517e | ||
|
deedd3dd12 | ||
|
66bb72f164 | ||
|
65a703c024 | ||
|
8b90ca0dac | ||
|
b15cdd96a1 | ||
|
aa26089032 | ||
|
d204e049d1 | ||
|
6f5bccaa15 | ||
|
081d8faa45 | ||
|
87030bdcdf | ||
|
e53d1b1820 | ||
|
23130cd56a | ||
|
78f12de819 | ||
|
26b4b6adef | ||
|
fdb143ff47 | ||
|
fc4a49c3e7 | ||
|
dcf146cc16 | ||
|
7a08c0c1d3 | ||
|
93fe480eca | ||
|
af0b3904fc | ||
|
65432d0c3d | ||
|
bee8628d77 | ||
|
d676a365f7 | ||
|
457a39c093 | ||
|
243f73fafc | ||
|
bcfc794a1b | ||
|
d63294a548 | ||
|
30ef9c0808 | ||
|
1652bffeb0 | ||
|
b7245d888b | ||
|
0ad9421101 | ||
|
8d46cd94fd | ||
|
f387179cba | ||
|
e47b602ec8 | ||
|
b333fcea4b | ||
|
3d0e18b0ad | ||
|
48181832a7 | ||
|
a089c1f108 | ||
|
1aef229c72 | ||
|
31a8d05c53 | ||
|
61d345a26e | ||
|
3a207e0ee1 | ||
|
13055dbe60 | ||
|
78fc843c15 | ||
|
eb71b9f256 | ||
|
21a24ffc8d | ||
|
28d7f22802 | ||
|
0b8d1b2672 | ||
|
68e55baa15 | ||
|
1bd59a6c14 | ||
|
1fbfe6ebeb | ||
|
653475fe5e | ||
|
a50041ee9e | ||
|
81b67bcd11 | ||
|
ec5ae22d13 | ||
|
11add4cd4f | ||
|
54fc3162ed | ||
|
140f16f772 | ||
|
4f2efd7985 | ||
|
8ff34be5e4 | ||
|
ad5733c5bb | ||
|
2339a9ab09 | ||
|
ae9e16b4a5 | ||
|
241d1a9c35 | ||
|
299d91c81e | ||
|
cc950f12fb | ||
|
637c2b2a9b | ||
|
fba48fbe9a | ||
|
fc7ed4d0ee | ||
|
e391b5aa4d | ||
|
b7bed0ef5e | ||
|
1bd7c42b7e | ||
|
b12746fb8a | ||
|
977d8a2d77 | ||
|
554217d026 | ||
|
35cb35eb9f | ||
|
db83a21244 | ||
|
dd037d2dea | ||
|
a06d95e45a | ||
|
20d3270a86 | ||
|
40d18896b7 | ||
|
98e31c8171 | ||
|
b21bce458f | ||
|
27f41b3bcc | ||
|
aa79a7deea | ||
|
da058cf17f | ||
|
26ed582aff | ||
|
bcfefb492b | ||
|
1ee7798fb9 | ||
|
704b5dc072 | ||
|
e7d36b4287 | ||
|
254a7cf5cf | ||
|
35945b257c | ||
|
ca03bbca04 | ||
|
764f3069a0 | ||
|
b9c9e7cffd | ||
|
bbf3a71a96 | ||
|
a03493c4f7 | ||
|
a35076a171 | ||
|
573ab6c22c | ||
|
bb432b21f9 | ||
|
e9e7344424 | ||
|
c516454566 | ||
|
874a7eba18 | ||
|
1adad3ece6 | ||
|
9ed662c522 | ||
|
5446e17ea4 | ||
|
53a6df65af | ||
|
0eb7e06ac1 | ||
|
8b9b907682 | ||
|
33f7ca5b62 | ||
|
676310edc4 | ||
|
b53a8e4372 | ||
|
732c1e4ab2 | ||
|
6903841161 | ||
|
b2f9ea8637 | ||
|
ce3f0676f0 | ||
|
d91e79fa64 | ||
|
4c47f0990d | ||
|
06e120a3df | ||
|
00b4f9090f | ||
|
768e56d862 | ||
|
f4a661b742 | ||
|
a61b7ee94d | ||
|
b5bfec84c4 | ||
|
48168eb575 | ||
|
3e7ee22ae3 | ||
|
6134e57762 | ||
|
1435dd3dd1 | ||
|
ce95af9068 | ||
|
58670722dc | ||
|
b9b97861e1 | ||
|
0efebe6bc5 | ||
|
3930073e32 | ||
|
dd21880ecd | ||
|
6c7bb74ad4 | ||
|
e8b04053e4 | ||
|
aacd29269a | ||
|
c963703395 | ||
|
5a40f537cc | ||
|
9a540e644c | ||
|
11dc7d41a0 | ||
|
b0ad36de94 | ||
|
3283dce96a | ||
|
c58d2e8fb3 | ||
|
770da74a9b | ||
|
55c51578e5 | ||
|
57ea3e61b7 | ||
|
4ce4997d1a | ||
|
81ed7c0a2a | ||
|
f56df85f78 | ||
|
e5ed29bd76 | ||
|
e5e9f1689c | ||
|
e6aa996814 | ||
|
e0a1d52d78 | ||
|
bef8403e39 | ||
|
44e6970dc5 | ||
|
3b71ecafa1 | ||
|
b6dd482e05 | ||
|
bb8d2b5b2c | ||
|
7da28b283d | ||
|
d78ec338c6 | ||
|
ddb6bcf416 | ||
|
f0508287f0 | ||
|
4e998aff93 | ||
|
29b1050637 | ||
|
669935c9e6 | ||
|
d0ded46b51 | ||
|
f90723c9f9 | ||
|
27f219fec7 | ||
|
2b8e8f0554 | ||
|
53222839c6 | ||
|
05606b5e15 | ||
|
0c2e1c3944 | ||
|
9fd3a9ff8d | ||
|
38573473da | ||
|
9994f1512e | ||
|
11327cb0c5 | ||
|
70182f8c65 | ||
|
8c467bad90 | ||
|
2e650a7122 | ||
|
7f11713730 | ||
|
d99afebbdc | ||
|
82f405ed49 | ||
|
5c315cab68 | ||
|
dc43f5a891 | ||
|
c2b5f80e3f | ||
|
1d3d417aee | ||
|
af98407f06 | ||
|
4acaf08ebf | ||
|
cfd7803e0d | ||
|
85ede1a7e7 | ||
|
ddf7363eda | ||
|
235315dd70 | ||
|
1e003170bd | ||
|
61f93063b2 | ||
|
addc96fe3e | ||
|
4ed320a86b | ||
|
65bc9273e6 | ||
|
c94c9106be | ||
|
903657ff91 | ||
|
f91a604c6d | ||
|
e77865ce18 | ||
|
6f25ebbc9e | ||
|
c6652ac006 | ||
|
130d72cb8b | ||
|
72e2f771a2 | ||
|
0c18b1f011 | ||
|
dcf380345b | ||
|
f935bf0465 | ||
|
7b49220ee9 | ||
|
4fe45adc7c | ||
|
3a9e851339 | ||
|
a3c91271d0 | ||
|
baff59e421 | ||
|
ea607edfc9 | ||
|
6bbc607312 | ||
|
f8580928ed | ||
|
c42b1b2124 | ||
|
cd8919d95b | ||
|
e549624eeb | ||
|
2cc0e64572 | ||
|
d1af3b51cd | ||
|
835f3adb8e | ||
|
4a3b94832f | ||
|
95847637e5 | ||
|
8ddc7954da | ||
|
80d4eef1fe | ||
|
23961879b4 | ||
|
25324267d6 | ||
|
f72df02950 | ||
|
935d940ace | ||
|
d8a792232c | ||
|
5b4d023712 | ||
|
b6b3936bcd | ||
|
f821ab350f | ||
|
afdb1ef85d | ||
|
453237bfd7 | ||
|
215277b376 | ||
|
33b0aadb4e | ||
|
8e9e346c9d | ||
|
8aa87c3080 | ||
|
4db11aa133 | ||
|
2b7364005d | ||
|
383d4d3736 | ||
|
0fb912e91f | ||
|
2fe86faf01 | ||
|
dbb0988927 | ||
|
652ef75099 | ||
|
3702531898 | ||
|
9a9b7e4f81 | ||
|
6d3607a13d | ||
|
e5bad483ca | ||
|
3d4b9c858e | ||
|
497a7eec94 | ||
|
4e744cf184 | ||
|
c3ce06a588 | ||
|
d1a5639e90 | ||
|
6687213fa5 | ||
|
ca9138d6b7 | ||
|
54fdeb0666 | ||
|
8ad2a5373e | ||
|
2d5e949100 | ||
|
3a9f22f68e | ||
|
01d81674fd | ||
|
eddebedcf6 | ||
|
e5f736680d | ||
|
14640db720 | ||
|
883f72867f | ||
|
7015f6fe09 | ||
|
ed8af226a3 | ||
|
8bffc2b190 | ||
|
bf68cef869 | ||
|
9a015c9bb5 | ||
|
3a1758d719 | ||
|
c6e4ae8613 | ||
|
4a75e91daf | ||
|
51ec884d96 | ||
|
acba34237c | ||
|
f9684cb155 | ||
|
4b2bef8bd8 | ||
|
deb705ad53 | ||
|
46eacd6fd2 | ||
|
9b11a9c9a4 | ||
|
7ab01ade43 | ||
|
57ceba3225 | ||
|
efe3b56683 | ||
|
a9b23e9851 | ||
|
79b8b8a7a7 | ||
|
3fa93da5e7 | ||
|
4c24f9d41c |
2
.vscode/launch.json
vendored
|
@ -12,7 +12,7 @@
|
|||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${env:DFL_ROOT}\\main.py",
|
||||
"pythonPath": "${env:PYTHONEXECUTABLE}",
|
||||
"python": "${env:PYTHONEXECUTABLE}",
|
||||
"cwd": "${env:WORKSPACE}",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["train",
|
||||
|
|
312
DFLIMG/DFLJPG.py
|
@ -1,21 +1,27 @@
|
|||
import pickle
|
||||
import struct
|
||||
import traceback
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from core import imagelib
|
||||
from core.cv2ex import *
|
||||
from core.imagelib import SegIEPolys
|
||||
from core.interact import interact as io
|
||||
from core.structex import *
|
||||
from facelib import FaceType
|
||||
|
||||
|
||||
class DFLJPG(object):
|
||||
def __init__(self):
|
||||
def __init__(self, filename):
|
||||
self.filename = filename
|
||||
self.data = b""
|
||||
self.length = 0
|
||||
self.chunks = []
|
||||
self.dfl_dict = None
|
||||
self.shape = (0,0,0)
|
||||
self.shape = None
|
||||
self.img = None
|
||||
|
||||
@staticmethod
|
||||
def load_raw(filename, loader_func=None):
|
||||
|
@ -29,7 +35,7 @@ class DFLJPG(object):
|
|||
raise FileNotFoundError(filename)
|
||||
|
||||
try:
|
||||
inst = DFLJPG()
|
||||
inst = DFLJPG(filename)
|
||||
inst.data = data
|
||||
inst.length = len(data)
|
||||
inst_length = inst.length
|
||||
|
@ -123,7 +129,7 @@ class DFLJPG(object):
|
|||
def load(filename, loader_func=None):
|
||||
try:
|
||||
inst = DFLJPG.load_raw (filename, loader_func=loader_func)
|
||||
inst.dfl_dict = None
|
||||
inst.dfl_dict = {}
|
||||
|
||||
for chunk in inst.chunks:
|
||||
if chunk['name'] == 'APP0':
|
||||
|
@ -132,8 +138,6 @@ class DFLJPG(object):
|
|||
|
||||
if id == b"JFIF":
|
||||
c, ver_major, ver_minor, units, Xdensity, Ydensity, Xthumbnail, Ythumbnail = struct_unpack (d, c, "=BBBHHBB")
|
||||
#if units == 0:
|
||||
# inst.shape = (Ydensity, Xdensity, 3)
|
||||
else:
|
||||
raise Exception("Unknown jpeg ID: %s" % (id) )
|
||||
elif chunk['name'] == 'SOF0' or chunk['name'] == 'SOF2':
|
||||
|
@ -145,160 +149,30 @@ class DFLJPG(object):
|
|||
if type(chunk['data']) == bytes:
|
||||
inst.dfl_dict = pickle.loads(chunk['data'])
|
||||
|
||||
if (inst.dfl_dict is not None):
|
||||
if 'face_type' not in inst.dfl_dict:
|
||||
inst.dfl_dict['face_type'] = FaceType.toString (FaceType.FULL)
|
||||
|
||||
if 'fanseg_mask' in inst.dfl_dict:
|
||||
fanseg_mask = inst.dfl_dict['fanseg_mask']
|
||||
if fanseg_mask is not None:
|
||||
numpyarray = np.asarray( inst.dfl_dict['fanseg_mask'], dtype=np.uint8)
|
||||
inst.dfl_dict['fanseg_mask'] = cv2.imdecode(numpyarray, cv2.IMREAD_UNCHANGED)
|
||||
|
||||
if inst.dfl_dict == None:
|
||||
return None
|
||||
|
||||
return inst
|
||||
except Exception as e:
|
||||
print (e)
|
||||
io.log_err (f'Exception occured while DFLJPG.load : {traceback.format_exc()}')
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def embed_dfldict(filename, dfl_dict):
|
||||
inst = DFLJPG.load_raw (filename)
|
||||
inst.setDFLDictData (dfl_dict)
|
||||
def has_data(self):
|
||||
return len(self.dfl_dict.keys()) != 0
|
||||
|
||||
def save(self):
|
||||
try:
|
||||
with open(filename, "wb") as f:
|
||||
f.write ( inst.dump() )
|
||||
with open(self.filename, "wb") as f:
|
||||
f.write ( self.dump() )
|
||||
except:
|
||||
raise Exception( 'cannot save %s' % (filename) )
|
||||
|
||||
@staticmethod
|
||||
def embed_data(filename, face_type=None,
|
||||
landmarks=None,
|
||||
ie_polys=None,
|
||||
seg_ie_polys=None,
|
||||
source_filename=None,
|
||||
source_rect=None,
|
||||
source_landmarks=None,
|
||||
image_to_face_mat=None,
|
||||
fanseg_mask=None,
|
||||
eyebrows_expand_mod=None,
|
||||
relighted=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
if fanseg_mask is not None:
|
||||
fanseg_mask = np.clip ( (fanseg_mask*255).astype(np.uint8), 0, 255 )
|
||||
|
||||
ret, buf = cv2.imencode( '.jpg', fanseg_mask, [int(cv2.IMWRITE_JPEG_QUALITY), 85] )
|
||||
|
||||
if ret and len(buf) < 60000:
|
||||
fanseg_mask = buf
|
||||
else:
|
||||
io.log_err("Unable to encode fanseg_mask for %s" % (filename) )
|
||||
fanseg_mask = None
|
||||
|
||||
if ie_polys is not None:
|
||||
if not isinstance(ie_polys, list):
|
||||
ie_polys = ie_polys.dump()
|
||||
|
||||
if seg_ie_polys is not None:
|
||||
if not isinstance(seg_ie_polys, list):
|
||||
seg_ie_polys = seg_ie_polys.dump()
|
||||
|
||||
DFLJPG.embed_dfldict (filename, {'face_type': face_type,
|
||||
'landmarks': landmarks,
|
||||
'ie_polys' : ie_polys,
|
||||
'seg_ie_polys' : seg_ie_polys,
|
||||
'source_filename': source_filename,
|
||||
'source_rect': source_rect,
|
||||
'source_landmarks': source_landmarks,
|
||||
'image_to_face_mat': image_to_face_mat,
|
||||
'fanseg_mask' : fanseg_mask,
|
||||
'eyebrows_expand_mod' : eyebrows_expand_mod,
|
||||
'relighted' : relighted
|
||||
})
|
||||
|
||||
def embed_and_set(self, filename, face_type=None,
|
||||
landmarks=None,
|
||||
ie_polys=None,
|
||||
seg_ie_polys=None,
|
||||
source_filename=None,
|
||||
source_rect=None,
|
||||
source_landmarks=None,
|
||||
image_to_face_mat=None,
|
||||
fanseg_mask=None,
|
||||
eyebrows_expand_mod=None,
|
||||
relighted=None,
|
||||
**kwargs
|
||||
):
|
||||
if face_type is None: face_type = self.get_face_type()
|
||||
if landmarks is None: landmarks = self.get_landmarks()
|
||||
if ie_polys is None: ie_polys = self.get_ie_polys()
|
||||
if seg_ie_polys is None: seg_ie_polys = self.get_seg_ie_polys()
|
||||
if source_filename is None: source_filename = self.get_source_filename()
|
||||
if source_rect is None: source_rect = self.get_source_rect()
|
||||
if source_landmarks is None: source_landmarks = self.get_source_landmarks()
|
||||
if image_to_face_mat is None: image_to_face_mat = self.get_image_to_face_mat()
|
||||
if fanseg_mask is None: fanseg_mask = self.get_fanseg_mask()
|
||||
if eyebrows_expand_mod is None: eyebrows_expand_mod = self.get_eyebrows_expand_mod()
|
||||
if relighted is None: relighted = self.get_relighted()
|
||||
DFLJPG.embed_data (filename, face_type=face_type,
|
||||
landmarks=landmarks,
|
||||
ie_polys=ie_polys,
|
||||
seg_ie_polys=seg_ie_polys,
|
||||
source_filename=source_filename,
|
||||
source_rect=source_rect,
|
||||
source_landmarks=source_landmarks,
|
||||
image_to_face_mat=image_to_face_mat,
|
||||
fanseg_mask=fanseg_mask,
|
||||
eyebrows_expand_mod=eyebrows_expand_mod,
|
||||
relighted=relighted)
|
||||
|
||||
def remove_ie_polys(self):
|
||||
self.dfl_dict['ie_polys'] = None
|
||||
|
||||
def remove_seg_ie_polys(self):
|
||||
self.dfl_dict['seg_ie_polys'] = None
|
||||
|
||||
def remove_fanseg_mask(self):
|
||||
self.dfl_dict['fanseg_mask'] = None
|
||||
|
||||
def remove_source_filename(self):
|
||||
self.dfl_dict['source_filename'] = None
|
||||
raise Exception( f'cannot save {self.filename}' )
|
||||
|
||||
def dump(self):
|
||||
data = b""
|
||||
|
||||
for chunk in self.chunks:
|
||||
data += struct.pack ("BB", 0xFF, chunk['m_h'] )
|
||||
chunk_data = chunk['data']
|
||||
if chunk_data is not None:
|
||||
data += struct.pack (">H", len(chunk_data)+2 )
|
||||
data += chunk_data
|
||||
dict_data = self.dfl_dict
|
||||
|
||||
chunk_ex_data = chunk['ex_data']
|
||||
if chunk_ex_data is not None:
|
||||
data += chunk_ex_data
|
||||
|
||||
return data
|
||||
|
||||
def get_shape(self):
|
||||
return self.shape
|
||||
|
||||
def get_height(self):
|
||||
for chunk in self.chunks:
|
||||
if type(chunk) == IHDR:
|
||||
return chunk.height
|
||||
return 0
|
||||
|
||||
def getDFLDictData(self):
|
||||
return self.dfl_dict
|
||||
|
||||
def setDFLDictData (self, dict_data=None):
|
||||
self.dfl_dict = dict_data
|
||||
# Remove None keys
|
||||
for key in list(dict_data.keys()):
|
||||
if dict_data[key] is None:
|
||||
dict_data.pop(key)
|
||||
|
||||
for chunk in self.chunks:
|
||||
if chunk['name'] == 'APP15':
|
||||
|
@ -317,24 +191,134 @@ class DFLJPG(object):
|
|||
}
|
||||
self.chunks.insert (last_app_chunk+1, dflchunk)
|
||||
|
||||
def get_face_type(self): return self.dfl_dict['face_type']
|
||||
def get_landmarks(self): return np.array ( self.dfl_dict['landmarks'] )
|
||||
def get_ie_polys(self): return self.dfl_dict.get('ie_polys',None)
|
||||
def get_seg_ie_polys(self): return self.dfl_dict.get('seg_ie_polys',None)
|
||||
def get_source_filename(self): return self.dfl_dict['source_filename']
|
||||
def get_source_rect(self): return self.dfl_dict['source_rect']
|
||||
def get_source_landmarks(self): return np.array ( self.dfl_dict['source_landmarks'] )
|
||||
|
||||
for chunk in self.chunks:
|
||||
data += struct.pack ("BB", 0xFF, chunk['m_h'] )
|
||||
chunk_data = chunk['data']
|
||||
if chunk_data is not None:
|
||||
data += struct.pack (">H", len(chunk_data)+2 )
|
||||
data += chunk_data
|
||||
|
||||
chunk_ex_data = chunk['ex_data']
|
||||
if chunk_ex_data is not None:
|
||||
data += chunk_ex_data
|
||||
|
||||
return data
|
||||
|
||||
def get_img(self):
|
||||
if self.img is None:
|
||||
self.img = cv2_imread(self.filename)
|
||||
return self.img
|
||||
|
||||
def get_shape(self):
|
||||
if self.shape is None:
|
||||
img = self.get_img()
|
||||
if img is not None:
|
||||
self.shape = img.shape
|
||||
return self.shape
|
||||
|
||||
def get_height(self):
|
||||
for chunk in self.chunks:
|
||||
if type(chunk) == IHDR:
|
||||
return chunk.height
|
||||
return 0
|
||||
|
||||
def get_dict(self):
|
||||
return self.dfl_dict
|
||||
|
||||
def set_dict (self, dict_data=None):
|
||||
self.dfl_dict = dict_data
|
||||
|
||||
def get_face_type(self): return self.dfl_dict.get('face_type', FaceType.toString (FaceType.FULL) )
|
||||
def set_face_type(self, face_type): self.dfl_dict['face_type'] = face_type
|
||||
|
||||
def get_landmarks(self): return np.array ( self.dfl_dict['landmarks'] )
|
||||
def set_landmarks(self, landmarks): self.dfl_dict['landmarks'] = landmarks
|
||||
|
||||
def get_eyebrows_expand_mod(self): return self.dfl_dict.get ('eyebrows_expand_mod', 1.0)
|
||||
def set_eyebrows_expand_mod(self, eyebrows_expand_mod): self.dfl_dict['eyebrows_expand_mod'] = eyebrows_expand_mod
|
||||
|
||||
def get_source_filename(self): return self.dfl_dict.get ('source_filename', None)
|
||||
def set_source_filename(self, source_filename): self.dfl_dict['source_filename'] = source_filename
|
||||
|
||||
def get_source_rect(self): return self.dfl_dict.get ('source_rect', None)
|
||||
def set_source_rect(self, source_rect): self.dfl_dict['source_rect'] = source_rect
|
||||
|
||||
def get_source_landmarks(self): return np.array ( self.dfl_dict.get('source_landmarks', None) )
|
||||
def set_source_landmarks(self, source_landmarks): self.dfl_dict['source_landmarks'] = source_landmarks
|
||||
|
||||
def get_image_to_face_mat(self):
|
||||
mat = self.dfl_dict.get ('image_to_face_mat', None)
|
||||
if mat is not None:
|
||||
return np.array (mat)
|
||||
return None
|
||||
def get_fanseg_mask(self):
|
||||
fanseg_mask = self.dfl_dict.get ('fanseg_mask', None)
|
||||
if fanseg_mask is not None:
|
||||
return np.clip ( np.array (fanseg_mask) / 255.0, 0.0, 1.0 )[...,np.newaxis]
|
||||
return None
|
||||
def get_eyebrows_expand_mod(self):
|
||||
return self.dfl_dict.get ('eyebrows_expand_mod', None)
|
||||
def get_relighted(self):
|
||||
return self.dfl_dict.get ('relighted', False)
|
||||
def set_image_to_face_mat(self, image_to_face_mat): self.dfl_dict['image_to_face_mat'] = image_to_face_mat
|
||||
|
||||
def has_seg_ie_polys(self):
|
||||
return self.dfl_dict.get('seg_ie_polys',None) is not None
|
||||
|
||||
def get_seg_ie_polys(self):
|
||||
d = self.dfl_dict.get('seg_ie_polys',None)
|
||||
if d is not None:
|
||||
d = SegIEPolys.load(d)
|
||||
else:
|
||||
d = SegIEPolys()
|
||||
|
||||
return d
|
||||
|
||||
def set_seg_ie_polys(self, seg_ie_polys):
|
||||
if seg_ie_polys is not None:
|
||||
if not isinstance(seg_ie_polys, SegIEPolys):
|
||||
raise ValueError('seg_ie_polys should be instance of SegIEPolys')
|
||||
|
||||
if seg_ie_polys.has_polys():
|
||||
seg_ie_polys = seg_ie_polys.dump()
|
||||
else:
|
||||
seg_ie_polys = None
|
||||
|
||||
self.dfl_dict['seg_ie_polys'] = seg_ie_polys
|
||||
|
||||
def has_xseg_mask(self):
|
||||
return self.dfl_dict.get('xseg_mask',None) is not None
|
||||
|
||||
def get_xseg_mask_compressed(self):
|
||||
mask_buf = self.dfl_dict.get('xseg_mask',None)
|
||||
if mask_buf is None:
|
||||
return None
|
||||
|
||||
return mask_buf
|
||||
|
||||
def get_xseg_mask(self):
|
||||
mask_buf = self.dfl_dict.get('xseg_mask',None)
|
||||
if mask_buf is None:
|
||||
return None
|
||||
|
||||
img = cv2.imdecode(mask_buf, cv2.IMREAD_UNCHANGED)
|
||||
if len(img.shape) == 2:
|
||||
img = img[...,None]
|
||||
|
||||
return img.astype(np.float32) / 255.0
|
||||
|
||||
|
||||
def set_xseg_mask(self, mask_a):
|
||||
if mask_a is None:
|
||||
self.dfl_dict['xseg_mask'] = None
|
||||
return
|
||||
|
||||
mask_a = imagelib.normalize_channels(mask_a, 1)
|
||||
img_data = np.clip( mask_a*255, 0, 255 ).astype(np.uint8)
|
||||
|
||||
data_max_len = 50000
|
||||
|
||||
ret, buf = cv2.imencode('.png', img_data)
|
||||
|
||||
if not ret or len(buf) > data_max_len:
|
||||
for jpeg_quality in range(100,-1,-1):
|
||||
ret, buf = cv2.imencode( '.jpg', img_data, [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality] )
|
||||
if ret and len(buf) <= data_max_len:
|
||||
break
|
||||
|
||||
if not ret:
|
||||
raise Exception("set_xseg_mask: unable to generate image data for set_xseg_mask")
|
||||
|
||||
self.dfl_dict['xseg_mask'] = buf
|
||||
|
|
299
README.md
|
@ -1,122 +1,237 @@
|
|||
<table align="center"><tr><td align="center" width="9999">
|
||||
<img src="doc/DFL_welcome.jpg" align="center">
|
||||
<table align="center" border="0">
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
# DeepFaceLab
|
||||
|
||||
<a href="https://arxiv.org/abs/2005.05535">
|
||||
|
||||
<img src="https://static.arxiv.org/static/browse/0.3.0/images/icons/favicon.ico" width=14></img>
|
||||
https://arxiv.org/abs/2005.05535</a>
|
||||
|
||||
</td></tr>
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
<p align="center">
|
||||
|
||||

|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
</p>
|
||||
|
||||
# DeepFaceLab
|
||||
### the leading software for creating deepfakes
|
||||
|
||||
</td></tr>
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
More than 95% of deepfake videos are created with DeepFaceLab.
|
||||
|
||||
DeepFaceLab is used by such popular youtube channels as
|
||||
|
||||
| [Ctrl Shift Face](https://www.youtube.com/channel/UCKpH0CKltc73e4wh0_pgL3g)| [VFXChris Ume](https://www.youtube.com/channel/UCGf4OlX_aTt8DlrgiH3jN3g/videos)|
|
||||
|---|---|
|
||||
|
||||
| [Sham00k](https://www.youtube.com/channel/UCZXbWcv7fSZFTAZV4beckyw/videos)| [Collider videos](https://www.youtube.com/watch?v=A91P2qtPT54&list=PLayt6616lBclvOprvrC8qKGCO-mAhPRux)| [iFake](https://www.youtube.com/channel/UCC0lK2Zo2BMXX-k1Ks0r7dg/videos)|
|
||||
| [deeptomcruise](https://www.tiktok.com/@deeptomcruise)| [1facerussia](https://www.tiktok.com/@1facerussia)| [arnoldschwarzneggar](https://www.tiktok.com/@arnoldschwarzneggar)
|
||||
|---|---|---|
|
||||
</td></tr>
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
# Quality progress
|
||||
| [mariahcareyathome?](https://www.tiktok.com/@mariahcareyathome?)| [diepnep](https://www.tiktok.com/@diepnep)| [mr__heisenberg](https://www.tiktok.com/@mr__heisenberg)| [deepcaprio](https://www.tiktok.com/@deepcaprio)
|
||||
|---|---|---|---|
|
||||
|
||||
</td></tr>
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
|2018| |
|
||||
| [VFXChris Ume](https://www.youtube.com/channel/UCGf4OlX_aTt8DlrgiH3jN3g/videos)| [Sham00k](https://www.youtube.com/channel/UCZXbWcv7fSZFTAZV4beckyw/videos)|
|
||||
|---|---|
|
||||
|
||||
|2020||
|
||||
| [Collider videos](https://www.youtube.com/watch?v=A91P2qtPT54&list=PLayt6616lBclvOprvrC8qKGCO-mAhPRux)| [iFake](https://www.youtube.com/channel/UCC0lK2Zo2BMXX-k1Ks0r7dg/videos)| [NextFace](https://www.youtube.com/channel/UCFh3gL0a8BS21g-DHvXZEeQ/videos)|
|
||||
|---|---|---|
|
||||
|
||||
| [Futuring Machine](https://www.youtube.com/channel/UCC5BbFxqLQgfnWPhprmQLVg)| [RepresentUS](https://www.youtube.com/channel/UCRzgK52MmetD9aG8pDOID3g)| [Corridor Crew](https://www.youtube.com/c/corridorcrew/videos)|
|
||||
|---|---|---|
|
||||
|
||||
| [DeepFaker](https://www.youtube.com/channel/UCkHecfDTcSazNZSKPEhtPVQ)| [DeepFakes in movie](https://www.youtube.com/c/DeepFakesinmovie/videos)|
|
||||
|---|---|
|
||||
|
||||
| [DeepFakeCreator](https://www.youtube.com/channel/UCkNFhcYNLQ5hr6A6lZ56mKA)| [Jarkan](https://www.youtube.com/user/Jarkancio/videos)|
|
||||
|---|---|
|
||||
|
||||
</td></tr>
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
# What can I do using DeepFaceLab?
|
||||
|
||||
</td></tr>
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
## Replace the face
|
||||
|
||||
<img src="doc/replace_the_face.jpg" align="center">
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
## De-age the face
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td align="center" width="50%">
|
||||
|
||||
<img src="doc/deage_0_1.jpg" align="center">
|
||||
|
||||
</td>
|
||||
<td align="center" width="50%">
|
||||
|
||||
<img src="doc/deage_0_2.jpg" align="center">
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
 https://www.youtube.com/watch?v=Ddx5B-84ebo
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
## Replace the head
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td align="center" width="50%">
|
||||
|
||||
<img src="doc/head_replace_1_1.jpg" align="center">
|
||||
|
||||
</td>
|
||||
<td align="center" width="50%">
|
||||
|
||||
<img src="doc/head_replace_1_2.jpg" align="center">
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
 https://www.youtube.com/watch?v=RTjgkhMugVw
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
# Native resolution progress
|
||||
|
||||
</td></tr>
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
<img src="doc/deepfake_progress.png" align="center">
|
||||
|
||||
</td></tr>
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
<img src="doc/make_everything_ok.png" align="center">
|
||||
|
||||
Unfortunately, there is no "make everything ok" button in DeepFaceLab. You should spend time studying the workflow and growing your skills. A skill in programs such as *AfterEffects* or *Davinci Resolve* is also desirable.
|
||||
|
||||
</td></tr>
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
## Mini tutorial
|
||||
|
||||
<a href="https://www.youtube.com/watch?v=kOIMXt8KK8M">
|
||||
|
||||
<img src="doc/mini_tutorial.jpg" align="center">
|
||||
|
||||
</a>
|
||||
|
||||
</td></tr>
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
## Releases
|
||||
|
||||
||||
|
||||
|---|---|---|
|
||||
|Windows|[github releases](https://github.com/iperov/DeepFaceLab/releases)|Direct download|
|
||||
||[Google drive](https://drive.google.com/open?id=1BCFK_L7lPNwMbEQ_kFPqPpDdFEOd_Dci)|if the download quota is exceeded, add the file to your own google drive and download from it|
|
||||
|Google Colab|[github](https://github.com/chervonij/DFL-Colab)|by @chervonij . You can train fakes for free using Google Colab.|
|
||||
|CentOS Linux|[github](https://github.com/elemantalcode/dfl)|by @elemantalcode|
|
||||
|Linux|[github](https://github.com/lbfs/DeepFaceLab_Linux)|by @lbfs |
|
||||
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td align="right">
|
||||
<a href="https://tinyurl.com/2p9cvt25">Windows (magnet link)</a>
|
||||
</td><td align="center">Last release. Use torrent client to download.</td></tr>
|
||||
|
||||
<tr><td align="right">
|
||||
<a href="https://mega.nz/folder/Po0nGQrA#dbbttiNWojCt8jzD4xYaPw">Windows (Mega.nz)</a>
|
||||
</td><td align="center">Contains new and prev releases.</td></tr>
|
||||
|
||||
<tr><td align="right">
|
||||
<a href="https://disk.yandex.ru/d/7i5XTKIKVg5UUg">Windows (yandex.ru)</a>
|
||||
</td><td align="center">Contains new and prev releases.</td></tr>
|
||||
|
||||
<tr><td align="right">
|
||||
<a href="https://github.com/nagadit/DeepFaceLab_Linux">Linux (github)</a>
|
||||
</td><td align="center">by @nagadit</td></tr>
|
||||
|
||||
<tr><td align="right">
|
||||
<a href="https://github.com/elemantalcode/dfl">CentOS Linux (github)</a>
|
||||
</td><td align="center">May be outdated. By @elemantalcode</td></tr>
|
||||
|
||||
</table>
|
||||
|
||||
<table align="center" border="0">
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
### Communication groups
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td align="center" width="9999">
|
||||
<tr><td align="right">
|
||||
<a href="https://discord.gg/rxa7h9M6rH">Discord</a>
|
||||
</td><td align="center">Official discord channel. English / Russian.</td></tr>
|
||||
|
||||
## Links
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
|
||||
||||
|
||||
|---|---|---|
|
||||
|Guides and tutorials|||
|
||||
||[DeepFaceLab guide](https://mrdeepfakes.com/forums/thread-guide-deepfacelab-2-0-explained-and-tutorials-recommended)|Main guide|
|
||||
||[Faceset creation guide](https://mrdeepfakes.com/forums/thread-guide-celebrity-faceset-dataset-creation-how-to-create-celebrity-facesets)|How to create the right faceset |
|
||||
||[Google Colab guide](https://mrdeepfakes.com/forums/thread-guide-deepfacelab-google-colab-tutorial)|Guide how to train the fake on Google Colab|
|
||||
||[Compositing](https://mrdeepfakes.com/forums/thread-deepfacelab-2-0-compositing-in-davinci-resolve-vegas-pro-and-after-effects)|To achieve the highest quality, compose deepfake manually in video editors such as Davince Resolve or Adobe AfterEffects|
|
||||
||[Discussion and suggestions](https://mrdeepfakes.com/forums/thread-deepfacelab-2-0-discussion-tips-suggestions)||
|
||||
||||
|
||||
|Supplementary material|||
|
||||
||[Ready to work facesets](https://mrdeepfakes.com/forums/forum-celebrity-facesets)|Celebrity facesets made by community|
|
||||
||[Pretrained models](https://mrdeepfakes.com/forums/forum-celebrity-facesets)|Use pretrained models made by community to speed up training|
|
||||
||||
|
||||
|Communication groups|||
|
||||
||[telegram (English / Russian)](https://t.me/DeepFaceLab_official)|Don't forget to hide your phone number|
|
||||
||[telegram (English only)](https://t.me/DeepFaceLab_official_en)|Don't forget to hide your phone number|
|
||||
||[mrdeepfakes](https://mrdeepfakes.com/forums/)|the biggest NSFW English community|
|
||||
||QQ 951138799| 中文 Chinese QQ group for ML/AI experts||
|
||||
||[deepfaker.xyz](https://www.deepfaker.xyz)|中文 Chinese guys are localizing DeepFaceLab|
|
||||
||[reddit r/GifFakes/](https://www.reddit.com/r/GifFakes/new/)|Post your deepfakes there !|
|
||||
||[reddit r/SFWdeepfakes/](https://www.reddit.com/r/SFWdeepfakes/new/)|Post your deepfakes there !|
|
||||
## Related works
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
## How I can help the project?
|
||||
|
||||
||||
|
||||
|---|---|---|
|
||||
|Donate|Sponsor deepfake research and DeepFaceLab development.||
|
||||
||[Donate via Paypal](https://www.paypal.com/cgi-bin/webscr?cmd=_donations&business=lepersorium@gmail.com&lc=US&no_note=0&item_name=Support+DeepFaceLab&cn=&curency_code=USD&bn=PP-DonationsBF:btn_donateCC_LG.gif:NonHosted)
|
||||
||[Donate via Yandex.Money](https://money.yandex.ru/to/41001142318065)||
|
||||
||bitcoin:31mPd6DxPCzbpCMZk4k1koWAbErSyqkAXr||
|
||||
||||
|
||||
|Last donations|10$ ( 14 march 2020 Amien Phillips )
|
||||
||20$ ( 12 march 2020 Maria D. )
|
||||
||200$ ( 12 march 2020 VFXChris Ume )
|
||||
||300$ ( 12 march 2020 Thiago O. )
|
||||
||50$ ( 8 march 2020 blanuk )
|
||||
||||
|
||||
|Collect facesets|You can collect faceset of any celebrity that can be used in DeepFaceLab and share it [in the community](https://mrdeepfakes.com/forums/forum-celebrity-facesets)|
|
||||
||||
|
||||
|Star this repo|Register github account and push "Star" button.
|
||||
</td></tr>
|
||||
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
## Meme zone
|
||||
<p align="center">
|
||||
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
</td></tr>
|
||||
<tr><td align="center" width="9999">
|
||||
|
||||
<sub>#deepfacelab #deepfakes #faceswap #face-swap #deep-learning #deeplearning #deep-neural-networks #deepface #deep-face-swap #fakeapp #fake-app #neural-networks #neural-nets #tensorflow #cuda #nvidia</sub>
|
||||
<tr><td align="right">
|
||||
<a href="https://github.com/iperov/DeepFaceLive">DeepFaceLive</a>
|
||||
</td><td align="center">Real-time face swap for PC streaming or video calls</td></tr>
|
||||
|
||||
</td></tr>
|
||||
</table>
|
||||
|
||||
<table align="center" border="0">
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
## How I can help the project?
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
### Star this repo
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
Register github account and push "Star" button.
|
||||
|
||||
</td></tr>
|
||||
|
||||
</table>
|
||||
|
||||
<table align="center" border="0">
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
## Meme zone
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td align="center" width="50%">
|
||||
|
||||
<img src="doc/meme1.jpg" align="center">
|
||||
|
||||
</td>
|
||||
|
||||
<td align="center" width="50%">
|
||||
|
||||
<img src="doc/meme2.jpg" align="center">
|
||||
|
||||
</td></tr>
|
||||
|
||||
<tr><td colspan=2 align="center">
|
||||
|
||||
<sub>#deepfacelab #faceswap #face-swap #deep-learning #deeplearning #deep-neural-networks #deepface #deep-face-swap #neural-networks #neural-nets #tensorflow #cuda #nvidia</sub>
|
||||
|
||||
</td></tr>
|
||||
|
||||
|
||||
|
||||
</table>
|
||||
|
|
10
XSegEditor/QCursorDB.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
from PyQt5.QtCore import *
|
||||
from PyQt5.QtGui import *
|
||||
from PyQt5.QtWidgets import *
|
||||
|
||||
class QCursorDB():
    """Holder for the editor's mouse cursors, exposed as class attributes."""

    @staticmethod
    def initialize(cursor_path):
        """Load every cursor image and publish it on the class.

        cursor_path must support the `/` operator with a file name
        (presumably a pathlib.Path - confirm against the caller).
        After the call, QCursorDB.cross_red, QCursorDB.cross_green and
        QCursorDB.cross_blue are QCursor objects built from the
        same-named '.png' files.
        """
        for cursor_name in ('cross_red', 'cross_green', 'cross_blue'):
            pixmap = QPixmap(str(cursor_path / (cursor_name + '.png')))
            setattr(QCursorDB, cursor_name, QCursor(pixmap))
|
26
XSegEditor/QIconDB.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
from PyQt5.QtCore import *
|
||||
from PyQt5.QtGui import *
|
||||
from PyQt5.QtWidgets import *
|
||||
|
||||
|
||||
class QIconDB():
    """Holder for the editor's icons, exposed as class attributes."""

    @staticmethod
    def initialize(icon_path):
        """Load every icon and publish it on the class.

        icon_path must support the `/` operator with a file name
        (presumably a pathlib.Path - confirm against the caller).
        Each name listed below is read from '<name>.png' and stored as the
        QIconDB attribute of the same name.
        """
        icon_names = (
            'app_icon',
            'delete_poly',
            'undo_pt',
            'redo_pt',
            'poly_color_red',
            'poly_color_green',
            'poly_color_blue',
            'poly_type_include',
            'poly_type_exclude',
            'left',
            'right',
            'trashcan',
            'pt_edit_mode',
            'view_lock_center',
            'view_baked',
            'view_xseg',
            'view_xseg_overlay',
        )
        for icon_name in icon_names:
            setattr(QIconDB, icon_name, QIcon(str(icon_path / (icon_name + '.png'))))
|
||||
|
8
XSegEditor/QImageDB.py
Normal file
|
@ -0,0 +1,8 @@
|
|||
from PyQt5.QtCore import *
|
||||
from PyQt5.QtGui import *
|
||||
from PyQt5.QtWidgets import *
|
||||
|
||||
class QImageDB():
    """Holder for the editor's static images, exposed as class attributes."""

    @staticmethod
    def initialize(image_path):
        """Load 'intro.png' from image_path into QImageDB.intro.

        image_path must support the `/` operator with a file name
        (presumably a pathlib.Path - confirm against the caller).
        """
        intro_file = image_path / 'intro.png'
        QImageDB.intro = QImage(str(intro_file))
|
102
XSegEditor/QStringDB.py
Normal file
|
@ -0,0 +1,102 @@
|
|||
from localization import system_language
|
||||
|
||||
|
||||
class QStringDB():
    """Holder for the editor's localized UI strings, exposed as class attributes."""

    @staticmethod
    def initialize():
        """Pick the UI language and publish one string per attribute.

        The language comes from `system_language`; anything other than
        'en', 'ru' or 'zh' falls back to 'en'. Every key of the table below
        becomes a QStringDB attribute holding the text for that language.
        """
        lang = system_language if system_language in ['en', 'ru', 'zh'] else 'en'

        translations = {
            'btn_poly_color_red_tip': {
                'en': 'Poly color scheme red',
                'ru': 'Красная цветовая схема полигонов',
                'zh': '选区配色方案红色',
            },
            'btn_poly_color_green_tip': {
                'en': 'Poly color scheme green',
                'ru': 'Зелёная цветовая схема полигонов',
                'zh': '选区配色方案绿色',
            },
            'btn_poly_color_blue_tip': {
                'en': 'Poly color scheme blue',
                'ru': 'Синяя цветовая схема полигонов',
                'zh': '选区配色方案蓝色',
            },
            'btn_view_baked_mask_tip': {
                'en': 'View baked mask',
                'ru': 'Посмотреть запечёную маску',
                'zh': '查看遮罩通道',
            },
            'btn_view_xseg_mask_tip': {
                'en': 'View trained XSeg mask',
                'ru': 'Посмотреть тренированную XSeg маску',
                'zh': '查看导入后的XSeg遮罩',
            },
            'btn_view_xseg_overlay_mask_tip': {
                'en': 'View trained XSeg mask overlay face',
                'ru': 'Посмотреть тренированную XSeg маску поверх лица',
                'zh': '查看导入后的XSeg遮罩于脸上方',
            },
            'btn_poly_type_include_tip': {
                'en': 'Poly include mode',
                'ru': 'Режим полигонов - включение',
                'zh': '包含选区模式',
            },
            'btn_poly_type_exclude_tip': {
                'en': 'Poly exclude mode',
                'ru': 'Режим полигонов - исключение',
                'zh': '排除选区模式',
            },
            'btn_undo_pt_tip': {
                'en': 'Undo point',
                'ru': 'Отменить точку',
                'zh': '撤消点',
            },
            'btn_redo_pt_tip': {
                'en': 'Redo point',
                'ru': 'Повторить точку',
                'zh': '重做点',
            },
            'btn_delete_poly_tip': {
                'en': 'Delete poly',
                'ru': 'Удалить полигон',
                'zh': '删除选区',
            },
            'btn_pt_edit_mode_tip': {
                'en': 'Add/delete point mode ( HOLD CTRL )',
                'ru': 'Режим добавления/удаления точек ( удерживайте CTRL )',
                'zh': '点加/删除模式 ( 按住CTRL )',
            },
            'btn_view_lock_center_tip': {
                'en': 'Lock cursor at the center ( HOLD SHIFT )',
                'ru': 'Заблокировать курсор в центре ( удерживайте SHIFT )',
                'zh': '将光标锁定在中心 ( 按住SHIFT )',
            },
            'btn_prev_image_tip': {
                'en': 'Save and Prev image\nHold SHIFT : accelerate\nHold CTRL : skip non masked\n',
                'ru': 'Сохранить и предыдущее изображение\nУдерживать SHIFT : ускорить\nУдерживать CTRL : пропустить неразмеченные\n',
                'zh': '保存并转到上一张图片\n按住SHIFT : 加快\n按住CTRL : 跳过未标记的\n',
            },
            'btn_next_image_tip': {
                'en': 'Save and Next image\nHold SHIFT : accelerate\nHold CTRL : skip non masked\n',
                'ru': 'Сохранить и следующее изображение\nУдерживать SHIFT : ускорить\nУдерживать CTRL : пропустить неразмеченные\n',
                'zh': '保存并转到下一张图片\n按住SHIFT : 加快\n按住CTRL : 跳过未标记的\n',
            },
            'btn_delete_image_tip': {
                'en': 'Move to _trash and Next image\n',
                'ru': 'Переместить в _trash и следующее изображение\n',
                'zh': '移至_trash,转到下一张图片 ',
            },
            'loading_tip': {
                'en': 'Loading',
                'ru': 'Загрузка',
                'zh': '正在载入',
            },
            'labeled_tip': {
                'en': 'labeled',
                'ru': 'размечено',
                'zh': '标记的',
            },
        }

        for attr_name, text_by_lang in translations.items():
            setattr(QStringDB, attr_name, text_by_lang[lang])
|
||||
|
1494
XSegEditor/XSegEditor.py
Normal file
BIN
XSegEditor/gfx/cursors/cross_blue.png
Normal file
After Width: | Height: | Size: 1.6 KiB |
BIN
XSegEditor/gfx/cursors/cross_green.png
Normal file
After Width: | Height: | Size: 1.6 KiB |
BIN
XSegEditor/gfx/cursors/cross_red.png
Normal file
After Width: | Height: | Size: 1.6 KiB |
BIN
XSegEditor/gfx/fonts/NotoSans-Medium.ttf
Normal file
BIN
XSegEditor/gfx/icons/app_icon.png
Normal file
After Width: | Height: | Size: 5.5 KiB |
BIN
XSegEditor/gfx/icons/delete_poly.png
Normal file
After Width: | Height: | Size: 4.7 KiB |
BIN
XSegEditor/gfx/icons/down.png
Normal file
After Width: | Height: | Size: 2.6 KiB |
BIN
XSegEditor/gfx/icons/left.png
Normal file
After Width: | Height: | Size: 2.7 KiB |
BIN
XSegEditor/gfx/icons/poly_color.psd
Normal file
BIN
XSegEditor/gfx/icons/poly_color_blue.png
Normal file
After Width: | Height: | Size: 8.4 KiB |
BIN
XSegEditor/gfx/icons/poly_color_green.png
Normal file
After Width: | Height: | Size: 9 KiB |
BIN
XSegEditor/gfx/icons/poly_color_red.png
Normal file
After Width: | Height: | Size: 8.9 KiB |
BIN
XSegEditor/gfx/icons/poly_type_exclude.png
Normal file
After Width: | Height: | Size: 6.3 KiB |
BIN
XSegEditor/gfx/icons/poly_type_include.png
Normal file
After Width: | Height: | Size: 5.6 KiB |
BIN
XSegEditor/gfx/icons/poly_type_source.psd
Normal file
BIN
XSegEditor/gfx/icons/pt_edit_mode.png
Normal file
After Width: | Height: | Size: 4.2 KiB |
BIN
XSegEditor/gfx/icons/pt_edit_mode_source.psd
Normal file
BIN
XSegEditor/gfx/icons/redo_pt.png
Normal file
After Width: | Height: | Size: 5.4 KiB |
BIN
XSegEditor/gfx/icons/redo_pt_source.psd
Normal file
BIN
XSegEditor/gfx/icons/right.png
Normal file
After Width: | Height: | Size: 2.7 KiB |
BIN
XSegEditor/gfx/icons/trashcan.png
Normal file
After Width: | Height: | Size: 3.2 KiB |
BIN
XSegEditor/gfx/icons/undo_pt.png
Normal file
After Width: | Height: | Size: 5.4 KiB |
BIN
XSegEditor/gfx/icons/undo_pt_source.psd
Normal file
BIN
XSegEditor/gfx/icons/up.png
Normal file
After Width: | Height: | Size: 2.6 KiB |
BIN
XSegEditor/gfx/icons/view_baked.png
Normal file
After Width: | Height: | Size: 8.1 KiB |
BIN
XSegEditor/gfx/icons/view_lock_center.png
Normal file
After Width: | Height: | Size: 4 KiB |
BIN
XSegEditor/gfx/icons/view_xseg.png
Normal file
After Width: | Height: | Size: 9.5 KiB |
BIN
XSegEditor/gfx/icons/view_xseg_overlay.png
Normal file
After Width: | Height: | Size: 12 KiB |
BIN
XSegEditor/gfx/images/intro.png
Normal file
After Width: | Height: | Size: 30 KiB |
BIN
XSegEditor/gfx/images/intro_source.psd
Normal file
10
_config.yml
|
@ -1 +1,9 @@
|
|||
theme: jekyll-theme-cayman
|
||||
theme: jekyll-theme-cayman
|
||||
plugins:
|
||||
- jekyll-relative-links
|
||||
relative_links:
|
||||
enabled: true
|
||||
collections: true
|
||||
|
||||
include:
|
||||
- README.md
|
|
@ -2,6 +2,7 @@ import cv2
|
|||
import numpy as np
|
||||
from pathlib import Path
|
||||
from core.interact import interact as io
|
||||
from core import imagelib
|
||||
import traceback
|
||||
|
||||
def cv2_imread(filename, flags=cv2.IMREAD_UNCHANGED, loader_func=None, verbose=True):
|
||||
|
@ -29,3 +30,11 @@ def cv2_imwrite(filename, img, *args):
|
|||
stream.write( buf )
|
||||
except:
|
||||
pass
|
||||
|
||||
def cv2_resize(x, *args, **kwargs):
    """Resize x with cv2.resize and restore its original channel count.

    x must be a 3-dimensional array; its third dimension is remembered
    before resizing and passed to imagelib.normalize_channels afterwards
    (presumably because cv2.resize can drop a trailing single-channel
    axis - confirm). Extra positional and keyword arguments are forwarded
    to cv2.resize unchanged.
    """
    _, _, channels = x.shape
    resized = cv2.resize(x, *args, **kwargs)
    return imagelib.normalize_channels(resized, channels)
|
||||
|
|
@ -1,109 +0,0 @@
|
|||
import numpy as np
|
||||
import cv2
|
||||
|
||||
class IEPolysPoints:
    """Point list of one polygon with an undo/redo cursor.

    `points` stores every recorded point; only the first `n` are active.
    Mutations that change the visible points set `parent.dirty` to True.
    """

    def __init__(self, IEPolys_parent, type):
        self.parent = IEPolys_parent
        self.type = type
        self.points = np.empty((0, 2), dtype=np.int32)
        self.n = 0
        self.n_max = 0

    def add(self, x, y):
        """Append (x, y) after the active points, dropping any undone ones."""
        active = self.points[0:self.n]
        self.points = np.concatenate((active, [(x, y)]), axis=0)
        self.n += 1
        self.n_max = self.n
        self.parent.dirty = True

    def n_dec(self):
        """Step the cursor back by one (not below 0) and return it."""
        self.n = self.n - 1 if self.n > 0 else 0
        self.parent.dirty = True
        return self.n

    def n_inc(self):
        """Step the cursor forward by one (not past the stored points) and return it."""
        stored = len(self.points)
        self.n = self.n + 1 if self.n < stored else stored
        self.parent.dirty = True
        return self.n

    def n_clip(self):
        """Discard every point beyond the cursor."""
        self.points = self.points[0:self.n]
        self.n_max = self.n

    def cur_point(self):
        """Return the point just before the cursor."""
        return self.points[self.n - 1]

    def points_to_n(self):
        """Return the active points (those before the cursor)."""
        return self.points[0:self.n]

    def set_points(self, points):
        """Replace all points and make every one of them active."""
        self.points = np.array(points)
        count = len(points)
        self.n = count
        self.n_max = count
        self.parent.dirty = True
|
||||
|
||||
class IEPolys:
    """Ordered collection of polygons with an undo/redo cursor.

    `list` stores every polygon; only the first `n` are active. `dirty`
    is raised by mutations and consumed through switch_dirty().
    """

    def __init__(self):
        self.list = []
        self.n = 0
        self.n_max = 0
        self.dirty = True

    def add(self, type):
        """Append a new empty polygon of the given type and return it.

        Polygons beyond the cursor (undone ones) are dropped first.
        """
        self.list = self.list[0:self.n]
        new_poly = IEPolysPoints(self, type)
        self.list.append(new_poly)
        self.n += 1
        self.n_max = self.n
        self.dirty = True
        return new_poly

    def n_dec(self):
        """Step the cursor back by one (not below 0) and return it."""
        self.n = self.n - 1 if self.n > 0 else 0
        self.dirty = True
        return self.n

    def n_inc(self):
        """Step the cursor forward by one (not past the stored polygons) and return it."""
        stored = len(self.list)
        self.n = self.n + 1 if self.n < stored else stored
        self.dirty = True
        return self.n

    def n_list(self):
        """Return the polygon just before the cursor."""
        return self.list[self.n - 1]

    def n_clip(self):
        """Discard polygons beyond the cursor and clip the last kept one."""
        self.list = self.list[0:self.n]
        self.n_max = self.n
        if self.n > 0:
            self.list[-1].n_clip()

    def __iter__(self):
        """Yield the active polygons in order."""
        for index in range(self.n):
            yield self.list[index]

    def switch_dirty(self):
        """Return the dirty flag and reset it to False."""
        was_dirty = self.dirty
        self.dirty = False
        return was_dirty

    def overlay_mask(self, mask):
        """Draw the active polygons into mask in place with cv2.fillPoly.

        mask must be 3-dimensional. Polygons of type 1 are filled with ones,
        all others with zeros; polygons without active points are skipped.
        """
        _, _, channels = mask.shape
        white = (1,) * channels
        black = (0,) * channels
        for poly in self.list[0:self.n]:
            if poly.n > 0:
                fill = white if poly.type == 1 else black
                cv2.fillPoly(mask, [poly.points_to_n()], fill)

    def get_total_points(self):
        """Return the number of active points over all active polygons."""
        return sum(poly.n for poly in self.list[0:self.n])

    def dump(self):
        """Return the active polygons as a list of (type, points-as-list) tuples."""
        return [(poly.type, poly.points_to_n().tolist()) for poly in self.list[0:self.n]]

    @staticmethod
    def load(ie_polys=None):
        """Build an IEPolys from the structure produced by dump().

        Anything that is not a list (including None) yields an empty object.
        """
        obj = IEPolys()
        if not isinstance(ie_polys, list):
            return obj
        for (poly_type, points) in ie_polys:
            obj.add(poly_type).set_points(points)
        return obj
|
158
core/imagelib/SegIEPolys.py
Normal file
|
@ -0,0 +1,158 @@
|
|||
import numpy as np
|
||||
import cv2
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
class SegIEPolyType(IntEnum):
    """Kind of a segmentation polygon; the integer values are what dump() stores."""

    # Polygon area is removed from the mask.
    EXCLUDE = 0
    # Polygon area is added to the mask.
    INCLUDE = 1
|
||||
|
||||
|
||||
|
||||
class SegIEPoly():
    """One segmentation polygon with point-level undo/redo.

    `pts` stores the recorded points as a float32 array of (x, y) rows; only
    the first `n` rows are active. undo()/redo() move `n` without touching
    `pts`; every editing operation drops the undone tail so that `pts`,
    `n` and `n_max` always agree afterwards.
    """

    def __init__(self, type=None, pts=None, **kwargs):
        """Create a polygon of the given type, optionally with initial points.

        Extra keyword arguments are accepted and ignored.
        """
        self.type = type
        self.pts = self._as_pts(pts)
        self.n_max = self.n = len(self.pts)

    @staticmethod
    def _as_pts(pts):
        """Return pts as a float32 array; None or empty input gives shape (0, 2)."""
        if pts is None:
            return np.empty((0, 2), dtype=np.float32)
        arr = np.array(pts, dtype=np.float32)
        if arr.size == 0:
            # Keep the (N, 2) layout so later concatenations still line up.
            arr = arr.reshape(0, 2)
        return arr

    def dump(self):
        """Return {'type': int, 'pts': copy of the active points}."""
        return {'type': int(self.type),
                'pts': self.get_pts(),
                }

    def identical(self, b):
        """Return True when both polygons have the same active points (type is not compared)."""
        if self.n != b.n:
            return False
        return (self.pts[0:self.n] == b.pts[0:b.n]).all()

    def get_type(self):
        """Return the polygon type given at construction."""
        return self.type

    def add_pt(self, x, y):
        """Append (x, y) after the active points, dropping any undone ones."""
        self.pts = np.append(self.pts[0:self.n], [(float(x), float(y))], axis=0).astype(np.float32)
        self.n_max = self.n = self.n + 1

    def undo(self):
        """Deactivate the last active point (if any) and return the new count."""
        self.n = max(0, self.n - 1)
        return self.n

    def redo(self):
        """Reactivate the next stored point (if any) and return the new count."""
        self.n = min(len(self.pts), self.n + 1)
        return self.n

    def redo_clip(self):
        """Discard every stored point beyond the active ones."""
        self.pts = self.pts[0:self.n]
        self.n_max = self.n

    def insert_pt(self, n, pt):
        """Insert pt before active index n (0 <= n <= active count).

        Raises ValueError when n is out of range. Undone points are dropped,
        as in add_pt, so a later redo() cannot revive them.
        """
        if n < 0 or n > self.n:
            raise ValueError("insert_pt out of range")
        new_pt = np.array(pt, dtype=np.float32).reshape(1, -1)
        self.pts = np.concatenate((self.pts[0:n], new_pt, self.pts[n:self.n]), axis=0)
        self.n_max = self.n = self.n + 1

    def remove_pt(self, n):
        """Remove the active point at index n (0 <= n < active count).

        Raises ValueError when n is out of range. Undone points are dropped,
        as in add_pt, so a later redo() cannot revive them.
        """
        if n < 0 or n >= self.n:
            raise ValueError("remove_pt out of range")
        self.pts = np.concatenate((self.pts[0:n], self.pts[n + 1:self.n]), axis=0)
        self.n_max = self.n = self.n - 1

    def get_last_point(self):
        """Return a copy of the last active point (requires at least one)."""
        return self.pts[self.n - 1].copy()

    def get_pts(self):
        """Return a copy of the active points."""
        return self.pts[0:self.n].copy()

    def get_pts_count(self):
        """Return the number of active points."""
        return self.n

    def set_point(self, id, pt):
        """Overwrite the stored point at index id with pt."""
        self.pts[id] = pt

    def set_points(self, pts):
        """Replace all points; every one of them becomes active.

        The input is converted to float32 like everywhere else in the class,
        so integer input can still be scaled in place by mult_points().
        """
        self.pts = self._as_pts(pts)
        self.n_max = self.n = len(self.pts)

    def mult_points(self, val):
        """Scale every stored point in place by val."""
        self.pts *= val
|
||||
|
||||
|
||||
|
||||
class SegIEPolys():
    """Ordered collection of SegIEPoly objects for one image."""

    def __init__(self):
        self.polys = []

    def identical(self, b):
        """Return True when both collections hold pairwise-identical polygons."""
        if len(self.polys) != len(b.polys):
            return False
        for mine, theirs in zip(self.polys, b.polys):
            if not mine.identical(theirs):
                return False
        return True

    def add_poly(self, ie_poly_type):
        """Append a new empty polygon of the given type and return it."""
        new_poly = SegIEPoly(ie_poly_type)
        self.polys.append(new_poly)
        return new_poly

    def remove_poly(self, poly):
        """Remove poly from the collection; do nothing when it is absent."""
        try:
            self.polys.remove(poly)
        except ValueError:
            pass

    def has_polys(self):
        """Return True when at least one polygon is stored."""
        return bool(self.polys)

    def get_poly(self, id):
        """Return the polygon at index id."""
        return self.polys[id]

    def get_polys(self):
        """Return the internal polygon list (not a copy)."""
        return self.polys

    def get_pts_count(self):
        """Return the total number of active points over all polygons."""
        return sum(poly.get_pts_count() for poly in self.polys)

    def sort(self):
        """Reorder polygons so INCLUDE ones come before EXCLUDE ones.

        Relative order inside each group is kept. A polygon whose type is
        neither value raises KeyError.
        """
        buckets = {SegIEPolyType.INCLUDE: [], SegIEPolyType.EXCLUDE: []}
        for poly in self.polys:
            buckets[poly.type].append(poly)
        self.polys = buckets[SegIEPolyType.INCLUDE] + buckets[SegIEPolyType.EXCLUDE]

    def __iter__(self):
        """Yield the stored polygons in order."""
        yield from self.polys

    def overlay_mask(self, mask):
        """Draw every polygon into mask in place with cv2.fillPoly.

        mask must be 3-dimensional. INCLUDE polygons are filled with ones,
        all others with zeros, in list order; points are truncated to int32
        and polygons without points are skipped.
        """
        _, _, channels = mask.shape
        white = (1,) * channels
        black = (0,) * channels
        for poly in self.polys:
            int_pts = poly.get_pts().astype(np.int32)
            if len(int_pts) == 0:
                continue
            fill = white if poly.type == SegIEPolyType.INCLUDE else black
            cv2.fillPoly(mask, [int_pts], fill)

    def dump(self):
        """Return {'polys': [poly.dump(), ...]}."""
        return {'polys': [poly.dump() for poly in self.polys]}

    def mult_points(self, val):
        """Scale the points of every polygon in place by val."""
        for poly in self.polys:
            poly.mult_points(val)

    @staticmethod
    def load(data=None):
        """Build a sorted SegIEPolys from dumped data.

        A list is read as the older format of (type, pts) pairs; a dict is
        read from its 'polys' entry, each item being keyword arguments for
        SegIEPoly. Any other value (including None) yields an empty object.
        """
        ie_polys = SegIEPolys()
        if isinstance(data, list):
            # Backward compatibility with the list-based format.
            ie_polys.polys = [SegIEPoly(type=poly_type, pts=pts) for (poly_type, pts) in data]
        elif isinstance(data, dict):
            ie_polys.polys = [SegIEPoly(**poly_cfg) for poly_cfg in data['polys']]

        ie_polys.sort()
        return ie_polys
|
|
@ -1,4 +1,5 @@
|
|||
from .estimate_sharpness import estimate_sharpness
|
||||
|
||||
from .equalize_and_stack_square import equalize_and_stack_square
|
||||
|
||||
from .text import get_text_image, get_draw_text_lines
|
||||
|
@ -11,16 +12,21 @@ from .warp import gen_warp_params, warp_by_params
|
|||
|
||||
from .reduce_colors import reduce_colors
|
||||
|
||||
from .color_transfer import color_transfer, color_transfer_mix, color_transfer_sot, color_transfer_mkl, color_transfer_idt, color_hist_match, reinhard_color_transfer, linear_color_transfer, seamless_clone
|
||||
from .color_transfer import color_transfer, color_transfer_mix, color_transfer_sot, color_transfer_mkl, color_transfer_idt, color_hist_match, reinhard_color_transfer, linear_color_transfer
|
||||
|
||||
from .common import normalize_channels, cut_odd_image, overlay_alpha_image
|
||||
from .common import random_crop, normalize_channels, cut_odd_image, overlay_alpha_image
|
||||
|
||||
from .IEPolys import IEPolys
|
||||
from .SegIEPolys import *
|
||||
|
||||
from .blursharpen import LinearMotionBlur, blursharpen
|
||||
|
||||
from .filters import apply_random_rgb_levels, \
|
||||
apply_random_overlay_triangle, \
|
||||
apply_random_hsv_shift, \
|
||||
apply_random_sharpen, \
|
||||
apply_random_motion_blur, \
|
||||
apply_random_gaussian_blur, \
|
||||
apply_random_bilinear_resize
|
||||
apply_random_nearest_resize, \
|
||||
apply_random_bilinear_resize, \
|
||||
apply_random_jpeg_compress, \
|
||||
apply_random_relight
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
import cv2
|
||||
import numexpr as ne
|
||||
import numpy as np
|
||||
import scipy as sp
|
||||
from numpy import linalg as npla
|
||||
|
||||
import scipy as sp
|
||||
import scipy.sparse
|
||||
from scipy.sparse.linalg import spsolve
|
||||
|
||||
def color_transfer_sot(src,trg, steps=10, batch_size=5, reg_sigmaXY=16.0, reg_sigmaV=5.0):
|
||||
"""
|
||||
|
@ -35,8 +34,9 @@ def color_transfer_sot(src,trg, steps=10, batch_size=5, reg_sigmaXY=16.0, reg_si
|
|||
h,w,c = src.shape
|
||||
new_src = src.copy()
|
||||
|
||||
advect = np.empty ( (h*w,c), dtype=src_dtype )
|
||||
for step in range (steps):
|
||||
advect = np.zeros ( (h*w,c), dtype=src_dtype )
|
||||
advect.fill(0)
|
||||
for batch in range (batch_size):
|
||||
dir = np.random.normal(size=c).astype(src_dtype)
|
||||
dir /= npla.norm(dir)
|
||||
|
@ -91,6 +91,8 @@ def color_transfer_mkl(x0, x1):
|
|||
return np.clip ( result.reshape ( (h,w,c) ).astype(x0.dtype), 0, 1)
|
||||
|
||||
def color_transfer_idt(i0, i1, bins=256, n_rot=20):
|
||||
import scipy.stats
|
||||
|
||||
relaxation = 1 / n_rot
|
||||
h,w,c = i0.shape
|
||||
h1,w1,c1 = i1.shape
|
||||
|
@ -133,135 +135,57 @@ def color_transfer_idt(i0, i1, bins=256, n_rot=20):
|
|||
|
||||
return np.clip ( d0.T.reshape ( (h,w,c) ).astype(i0.dtype) , 0, 1)
|
||||
|
||||
def laplacian_matrix(n, m):
    """Build the sparse (n*m) x (n*m) matrix used by the Poisson solve.

    The result is a LIL matrix with 4 on the main diagonal, -1 on the
    diagonals at offsets +1/-1 inside each m x m diagonal block, and -1 on
    the diagonals at offsets +m/-m.
    """
    row_block = scipy.sparse.lil_matrix((m, m))
    for value, offset in ((-1, -1), (4, 0), (-1, 1)):
        row_block.setdiag(value, offset)

    grid = scipy.sparse.block_diag([row_block for _ in range(n)]).tolil()
    for offset in (m, -m):
        grid.setdiag(-1, offset)
    return grid
|
||||
def reinhard_color_transfer(target : np.ndarray, source : np.ndarray, target_mask : np.ndarray = None, source_mask : np.ndarray = None, mask_cutoff=0.5) -> np.ndarray:
|
||||
"""
|
||||
Transfer color using rct method.
|
||||
|
||||
def seamless_clone(source, target, mask):
|
||||
h, w,c = target.shape
|
||||
result = []
|
||||
target np.ndarray H W 3C (BGR) np.float32
|
||||
source np.ndarray H W 3C (BGR) np.float32
|
||||
|
||||
mat_A = laplacian_matrix(h, w)
|
||||
laplacian = mat_A.tocsc()
|
||||
target_mask(None) np.ndarray H W 1C np.float32
|
||||
source_mask(None) np.ndarray H W 1C np.float32
|
||||
|
||||
mask_cutoff(0.5) float
|
||||
|
||||
mask[0,:] = 1
|
||||
mask[-1,:] = 1
|
||||
mask[:,0] = 1
|
||||
mask[:,-1] = 1
|
||||
q = np.argwhere(mask==0)
|
||||
masks are used to limit the space where color statistics will be computed to adjust the target
|
||||
|
||||
k = q[:,1]+q[:,0]*w
|
||||
mat_A[k, k] = 1
|
||||
mat_A[k, k + 1] = 0
|
||||
mat_A[k, k - 1] = 0
|
||||
mat_A[k, k + w] = 0
|
||||
mat_A[k, k - w] = 0
|
||||
reference: Color Transfer between Images https://www.cs.tau.ac.il/~turkel/imagepapers/ColorTransfer.pdf
|
||||
"""
|
||||
source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB)
|
||||
target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB)
|
||||
|
||||
mat_A = mat_A.tocsc()
|
||||
mask_flat = mask.flatten()
|
||||
for channel in range(c):
|
||||
source_input = source
|
||||
if source_mask is not None:
|
||||
source_input = source_input.copy()
|
||||
source_input[source_mask[...,0] < mask_cutoff] = [0,0,0]
|
||||
|
||||
target_input = target
|
||||
if target_mask is not None:
|
||||
target_input = target_input.copy()
|
||||
target_input[target_mask[...,0] < mask_cutoff] = [0,0,0]
|
||||
|
||||
source_flat = source[:, :, channel].flatten()
|
||||
target_flat = target[:, :, channel].flatten()
|
||||
target_l_mean, target_l_std, target_a_mean, target_a_std, target_b_mean, target_b_std, \
|
||||
= target_input[...,0].mean(), target_input[...,0].std(), target_input[...,1].mean(), target_input[...,1].std(), target_input[...,2].mean(), target_input[...,2].std()
|
||||
|
||||
source_l_mean, source_l_std, source_a_mean, source_a_std, source_b_mean, source_b_std, \
|
||||
= source_input[...,0].mean(), source_input[...,0].std(), source_input[...,1].mean(), source_input[...,1].std(), source_input[...,2].mean(), source_input[...,2].std()
|
||||
|
||||
# not as in the paper: scale by the standard deviations using reciprocal of paper proposed factor
|
||||
target_l = target[...,0]
|
||||
target_l = ne.evaluate('(target_l - target_l_mean) * source_l_std / target_l_std + source_l_mean')
|
||||
|
||||
mat_b = laplacian.dot(source_flat)*0.75
|
||||
mat_b[mask_flat==0] = target_flat[mask_flat==0]
|
||||
target_a = target[...,1]
|
||||
target_a = ne.evaluate('(target_a - target_a_mean) * source_a_std / target_a_std + source_a_mean')
|
||||
|
||||
target_b = target[...,2]
|
||||
target_b = ne.evaluate('(target_b - target_b_mean) * source_b_std / target_b_std + source_b_mean')
|
||||
|
||||
x = spsolve(mat_A, mat_b).reshape((h, w))
|
||||
result.append (x)
|
||||
np.clip(target_l, 0, 100, out=target_l)
|
||||
np.clip(target_a, -127, 127, out=target_a)
|
||||
np.clip(target_b, -127, 127, out=target_b)
|
||||
|
||||
return cv2.cvtColor(np.stack([target_l,target_a,target_b], -1), cv2.COLOR_LAB2BGR)
|
||||
|
||||
return np.clip( np.dstack(result), 0, 1 )
|
||||
|
||||
def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, source_mask=None, target_mask=None):
|
||||
"""
|
||||
Transfers the color distribution from the source to the target
|
||||
image using the mean and standard deviations of the L*a*b*
|
||||
color space.
|
||||
|
||||
This implementation is (loosely) based on to the "Color Transfer
|
||||
between Images" paper by Reinhard et al., 2001.
|
||||
|
||||
Parameters:
|
||||
-------
|
||||
source: NumPy array
|
||||
OpenCV image in BGR color space (the source image)
|
||||
target: NumPy array
|
||||
OpenCV image in BGR color space (the target image)
|
||||
clip: Should components of L*a*b* image be scaled by np.clip before
|
||||
converting back to BGR color space?
|
||||
If False then components will be min-max scaled appropriately.
|
||||
Clipping will keep target image brightness truer to the input.
|
||||
Scaling will adjust image brightness to avoid washed out portions
|
||||
in the resulting color transfer that can be caused by clipping.
|
||||
preserve_paper: Should color transfer strictly follow methodology
|
||||
layed out in original paper? The method does not always produce
|
||||
aesthetically pleasing results.
|
||||
If False then L*a*b* components will scaled using the reciprocal of
|
||||
the scaling factor proposed in the paper. This method seems to produce
|
||||
more consistently aesthetically pleasing results
|
||||
|
||||
Returns:
|
||||
-------
|
||||
transfer: NumPy array
|
||||
OpenCV image (w, h, 3) NumPy array (uint8)
|
||||
"""
|
||||
|
||||
|
||||
# convert the images from the RGB to L*ab* color space, being
|
||||
# sure to utilizing the floating point data type (note: OpenCV
|
||||
# expects floats to be 32-bit, so use that instead of 64-bit)
|
||||
source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
|
||||
target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
|
||||
|
||||
# compute color statistics for the source and target images
|
||||
src_input = source if source_mask is None else source*source_mask
|
||||
tgt_input = target if target_mask is None else target*target_mask
|
||||
(lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = lab_image_stats(src_input)
|
||||
(lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = lab_image_stats(tgt_input)
|
||||
|
||||
# subtract the means from the target image
|
||||
(l, a, b) = cv2.split(target)
|
||||
l -= lMeanTar
|
||||
a -= aMeanTar
|
||||
b -= bMeanTar
|
||||
|
||||
if preserve_paper:
|
||||
# scale by the standard deviations using paper proposed factor
|
||||
l = (lStdTar / lStdSrc) * l
|
||||
a = (aStdTar / aStdSrc) * a
|
||||
b = (bStdTar / bStdSrc) * b
|
||||
else:
|
||||
# scale by the standard deviations using reciprocal of paper proposed factor
|
||||
l = (lStdSrc / lStdTar) * l
|
||||
a = (aStdSrc / aStdTar) * a
|
||||
b = (bStdSrc / bStdTar) * b
|
||||
|
||||
# add in the source mean
|
||||
l += lMeanSrc
|
||||
a += aMeanSrc
|
||||
b += bMeanSrc
|
||||
|
||||
# clip/scale the pixel intensities to [0, 255] if they fall
|
||||
# outside this range
|
||||
l = _scale_array(l, clip=clip)
|
||||
a = _scale_array(a, clip=clip)
|
||||
b = _scale_array(b, clip=clip)
|
||||
|
||||
# merge the channels together and convert back to the RGB color
|
||||
# space, being sure to utilize the 8-bit unsigned integer data
|
||||
# type
|
||||
transfer = cv2.merge([l, a, b])
|
||||
transfer = cv2.cvtColor(transfer.astype(np.uint8), cv2.COLOR_LAB2BGR)
|
||||
|
||||
# return the color transferred image
|
||||
return transfer
|
||||
|
||||
def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5):
|
||||
'''
|
||||
|
@ -399,9 +323,7 @@ def color_transfer(ct_mode, img_src, img_trg):
|
|||
if ct_mode == 'lct':
|
||||
out = linear_color_transfer (img_src, img_trg)
|
||||
elif ct_mode == 'rct':
|
||||
out = reinhard_color_transfer ( np.clip( img_src*255, 0, 255 ).astype(np.uint8),
|
||||
np.clip( img_trg*255, 0, 255 ).astype(np.uint8) )
|
||||
out = np.clip( out.astype(np.float32) / 255.0, 0.0, 1.0)
|
||||
out = reinhard_color_transfer(img_src, img_trg)
|
||||
elif ct_mode == 'mkl':
|
||||
out = color_transfer_mkl (img_src, img_trg)
|
||||
elif ct_mode == 'idt':
|
||||
|
@ -411,4 +333,4 @@ def color_transfer(ct_mode, img_src, img_trg):
|
|||
out = np.clip( out, 0.0, 1.0)
|
||||
else:
|
||||
raise ValueError(f"unknown ct_mode {ct_mode}")
|
||||
return out
|
||||
return out
|
||||
|
|
|
@ -1,5 +1,16 @@
|
|||
import numpy as np
|
||||
|
||||
def random_crop(img, w, h):
    """Return a randomly positioned crop of img that is w wide and h high.

    img is indexed as img[row, column, ...]. The top-left corner is drawn
    uniformly from every position at which the crop still fits. Along an
    axis where img is not larger than the request the crop starts at 0, so
    the result is at most as large as img there.
    """
    height, width = img.shape[:2]

    # Room left over along each axis once the crop is placed.
    h_rnd = height - h
    w_rnd = width - w

    # randint's upper bound is exclusive; +1 lets the crop reach the far edge.
    y = np.random.randint(0, h_rnd + 1) if h_rnd > 0 else 0
    x = np.random.randint(0, w_rnd + 1) if w_rnd > 0 else 0

    # Slice by the requested size, not by the size of the source image.
    return img[y:y+h, x:x+w]
|
||||
|
||||
def normalize_channels(img, target_channels):
|
||||
img_shape_len = len(img.shape)
|
||||
if img_shape_len == 2:
|
||||
|
|
|
@ -31,9 +31,7 @@ goods or services; loss of use, data, or profits; or business interruption) howe
|
|||
import numpy as np
|
||||
import cv2
|
||||
from math import atan2, pi
|
||||
from scipy.ndimage import convolve
|
||||
from skimage.filters.edges import HSOBEL_WEIGHTS
|
||||
from skimage.feature import canny
|
||||
|
||||
|
||||
def sobel(image):
|
||||
# type: (numpy.ndarray) -> numpy.ndarray
|
||||
|
@ -42,10 +40,11 @@ def sobel(image):
|
|||
|
||||
Inspired by the [Octave implementation](https://sourceforge.net/p/octave/image/ci/default/tree/inst/edge.m#l196).
|
||||
"""
|
||||
|
||||
from skimage.filters.edges import HSOBEL_WEIGHTS
|
||||
h1 = np.array(HSOBEL_WEIGHTS)
|
||||
h1 /= np.sum(abs(h1)) # normalize h1
|
||||
|
||||
|
||||
from scipy.ndimage import convolve
|
||||
strength2 = np.square(convolve(image, h1.T))
|
||||
|
||||
# Note: https://sourceforge.net/p/octave/image/ci/default/tree/inst/edge.m#l59
|
||||
|
@ -103,6 +102,7 @@ def compute(image):
|
|||
# edge detection using canny and sobel canny edge detection is done to
|
||||
# classify the blocks as edge or non-edge blocks and sobel edge
|
||||
# detection is done for the purpose of edge width measurement.
|
||||
from skimage.feature import canny
|
||||
canny_edges = canny(image)
|
||||
sobel_edges = sobel(image)
|
||||
|
||||
|
@ -269,9 +269,10 @@ def get_block_contrast(block):
|
|||
|
||||
|
||||
def estimate_sharpness(image):
|
||||
height, width = image.shape[:2]
|
||||
|
||||
if image.ndim == 3:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
if image.shape[2] > 1:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
else:
|
||||
image = image[...,0]
|
||||
|
||||
return compute(image)
|
||||
|
|
|
@ -1,47 +1,65 @@
|
|||
import numpy as np
|
||||
from .blursharpen import LinearMotionBlur
|
||||
from .blursharpen import LinearMotionBlur, blursharpen
|
||||
import cv2
|
||||
|
||||
def apply_random_rgb_levels(img, mask=None, rnd_state=None):
    """Apply a random per-channel levels adjustment to img.

    For each of the three channels an input black point, input white point,
    gamma, output black point and output white point are drawn from
    rnd_state.rand() (np.random when rnd_state is None). The image is
    normalized to the input range, clipped to [0, 1], gamma-corrected,
    mapped to the output range and clipped again. With a mask, the result
    is blended as img*(1-mask) + result*mask.
    """
    rng = np.random if rnd_state is None else rnd_state
    draw = rng.rand

    def _per_channel(transform):
        # Three consecutive draws, one per channel, in channel order.
        return np.array([transform(draw()) for _ in range(3)], dtype=np.float32)

    # The order of these five lines fixes the order of the random draws.
    in_black = _per_channel(lambda r: r*0.25)
    in_white = _per_channel(lambda r: 1.0-r*0.25)
    in_gamma = _per_channel(lambda r: 0.5+r)
    out_black = _per_channel(lambda r: r*0.25)
    out_white = _per_channel(lambda r: 1.0-r*0.25)

    normalized = np.clip((img - in_black) / (in_white - in_black), 0, 1)
    leveled = (normalized ** (1/in_gamma)) * (out_white - out_black) + out_black
    leveled = np.clip(leveled, 0, 1)

    if mask is None:
        return leveled
    return img*(1-mask) + leveled*mask
|
||||
|
||||
|
||||
def apply_random_hsv_shift(img, mask=None, rnd_state=None):
|
||||
if rnd_state is None:
|
||||
rnd_state = np.random
|
||||
|
||||
|
||||
h, s, v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
|
||||
h = ( h + rnd_state.randint(360) ) % 360
|
||||
s = np.clip ( s + rnd_state.random()-0.5, 0, 1 )
|
||||
v = np.clip ( v + rnd_state.random()/2-0.25, 0, 1 )
|
||||
|
||||
v = np.clip ( v + rnd_state.random()-0.5, 0, 1 )
|
||||
|
||||
result = np.clip( cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR) , 0, 1 )
|
||||
if mask is not None:
|
||||
result = img*(1-mask) + result*mask
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def apply_random_sharpen( img, chance, kernel_max_size, mask=None, rnd_state=None ):
    """
    with probability chance (percent) sharpen img through blursharpen()

        chance            0..100, clipped
        kernel_max_size   kernel size is drawn from 1..kernel_max_size
        mask              optional blend weights, result = img*(1-mask) + sharpened*mask
        rnd_state         object providing randint(), np.random is used if None

    returns img itself when the effect is not applied
    """
    rnd = np.random if rnd_state is None else rnd_state

    # the kernel size is drawn before the chance roll, so it consumes a draw either way
    kernel_size = rnd.randint(kernel_max_size)+1

    if not (rnd.randint(100) < np.clip(chance, 0, 100)):
        return img

    # NOTE(review): indentation is lost in this view; mask blending is assumed
    # to belong to the "effect applied" branch - confirm
    if rnd.randint(2) == 0:
        mode, amount_range = 1, 10
    else:
        mode, amount_range = 2, 50
    sharpened = blursharpen(img, mode, kernel_size, rnd.randint(amount_range) )

    if mask is None:
        return sharpened
    return img*(1-mask) + sharpened*mask
|
||||
|
||||
def apply_random_motion_blur( img, chance, mb_max_size, mask=None, rnd_state=None ):
|
||||
if rnd_state is None:
|
||||
rnd_state = np.random
|
||||
|
||||
|
||||
mblur_rnd_kernel = rnd_state.randint(mb_max_size)+1
|
||||
mblur_rnd_deg = rnd_state.randint(360)
|
||||
|
||||
|
@ -50,38 +68,178 @@ def apply_random_motion_blur( img, chance, mb_max_size, mask=None, rnd_state=Non
|
|||
result = LinearMotionBlur (result, mblur_rnd_kernel, mblur_rnd_deg )
|
||||
if mask is not None:
|
||||
result = img*(1-mask) + result*mask
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def apply_random_gaussian_blur( img, chance, kernel_max_size, mask=None, rnd_state=None ):
    """
    with probability chance (percent) blur img with a random odd-sized gaussian kernel

        chance            0..100, clipped
        kernel_max_size   kernel side is randint(kernel_max_size)*2+1
        mask              optional blend weights, result = img*(1-mask) + blurred*mask
        rnd_state         object providing randint(), np.random is used if None

    returns img itself when the effect is not applied
    """
    rnd = np.random if rnd_state is None else rnd_state

    if not (rnd.randint(100) < np.clip(chance, 0, 100)):
        return img

    # NOTE(review): indentation is lost in this view; mask blending is assumed
    # to belong to the "effect applied" branch - confirm
    kernel_side = rnd.randint(kernel_max_size)*2+1
    blurred = cv2.GaussianBlur(img, (kernel_side, kernel_side), 0)

    if mask is None:
        return blurred
    return img*(1-mask) + blurred*mask
|
||||
|
||||
|
||||
def apply_random_resize( img, chance, max_size_per, interpolation=cv2.INTER_LINEAR, mask=None, rnd_state=None ):
    """
    with probability chance (percent) degrade img by downscaling it and scaling it back

        chance          0..100, clipped
        max_size_per    maximum size reduction in percent of width/height
        interpolation   cv2 interpolation flag used for both resizes
        mask            optional blend weights, result = img*(1-mask) + resized*mask
        rnd_state       object providing randint() and rand(), np.random is used if None

    returns img itself when the effect is not applied
    """
    if rnd_state is None:
        rnd_state = np.random

    result = img
    if rnd_state.randint(100) < np.clip(chance, 0, 100):
        h,w,c = result.shape

        # one draw scales both sides so the aspect ratio is roughly kept
        trg = rnd_state.rand()
        rw = w - int( trg * int(w*(max_size_per/100.0)) )
        rh = h - int( trg * int(h*(max_size_per/100.0)) )

        # interpolation must be passed by keyword: the third positional
        # argument of cv2.resize is dst, not the interpolation flag
        result = cv2.resize (result, (rw,rh), interpolation=interpolation )
        result = cv2.resize (result, (w,h), interpolation=interpolation )
        # NOTE(review): indentation is lost in this view; mask blending is assumed
        # to belong to the "effect applied" branch - confirm
        if mask is not None:
            result = img*(1-mask) + result*mask

    return result
|
||||
|
||||
def apply_random_nearest_resize( img, chance, max_size_per, mask=None, rnd_state=None ):
    """
    apply_random_resize() with nearest-neighbour interpolation (cv2.INTER_NEAREST),
    all other arguments are forwarded unchanged
    """
    return apply_random_resize( img, chance, max_size_per, interpolation=cv2.INTER_NEAREST, mask=mask, rnd_state=rnd_state )
|
||||
|
||||
def apply_random_bilinear_resize( img, chance, max_size_per, mask=None, rnd_state=None ):
    """
    apply_random_resize() with bilinear interpolation (cv2.INTER_LINEAR),
    all other arguments are forwarded unchanged
    """
    return apply_random_resize( img, chance, max_size_per, interpolation=cv2.INTER_LINEAR, mask=mask, rnd_state=rnd_state )
|
||||
|
||||
def apply_random_jpeg_compress( img, chance, mask=None, rnd_state=None ):
    """
    with probability chance (percent) run img through a JPEG encode/decode round trip

        img         float image, scaled by 255 and clipped to uint8 before encoding
        chance      0..100, clipped
        mask        optional blend weights, result = img*(1-mask) + compressed*mask
        rnd_state   object providing randint(), np.random is used if None

    returns img itself when the effect is not applied or the encoder reports failure
    """
    if rnd_state is None:
        rnd_state = np.random

    result = img
    if rnd_state.randint(100) < np.clip(chance, 0, 100):
        quality = rnd_state.randint(10,101)

        img_u8 = np.clip(img*255, 0,255).astype(np.uint8)
        # keep the encode buffer in its own name: on failure the caller must get
        # the untouched image back, not the raw encoder output
        ret, encoded = cv2.imencode('.jpg', img_u8, [int(cv2.IMWRITE_JPEG_QUALITY), quality] )
        if ret:
            result = cv2.imdecode(encoded, flags=cv2.IMREAD_UNCHANGED)
            result = result.astype(np.float32) / 255.0
            # NOTE(review): indentation is lost in this view; mask blending is assumed
            # to happen only after a successful round trip - confirm
            if mask is not None:
                result = img*(1-mask) + result*mask

    return result
|
||||
|
||||
def apply_random_overlay_triangle( img, max_alpha, mask=None, rnd_state=None ):
    """
    add or subtract a random filled triangle of random intensity

        img         [h,w,c] image
        max_alpha   upper bound of the triangle intensity (uniform()*max_alpha)
        mask        optional blend weights, result = img*(1-mask) + overlaid*mask
        rnd_state   object providing randint() and uniform(), np.random is used if None

    returns the overlaid image clipped to [0..1]
    """
    rnd = np.random if rnd_state is None else rnd_state

    h,w,c = img.shape

    # three [x,y] vertices; x is drawn before y for each vertex
    vertices = [ [rnd.randint(w), rnd.randint(h)] for _ in range(3) ]
    alpha = rnd.uniform()*max_alpha

    polygon = np.array(vertices, np.int32)
    tri_mask = cv2.fillPoly( np.zeros_like(img), [ polygon ], (alpha,)*c )

    # coin flip: brighten or darken inside the triangle
    if rnd.randint(2) == 0:
        overlaid = np.clip(img+tri_mask, 0, 1)
    else:
        overlaid = np.clip(img-tri_mask, 0, 1)

    if mask is None:
        return overlaid
    return img*(1-mask) + overlaid*mask
|
||||
|
||||
def _min_resize(x, m):
    """
    resize x so that its shorter side becomes m, the other side is scaled
    by the same factor (truncated to int), Lanczos interpolation

        x   image, shape[0] is height and shape[1] is width
        m   target length of the shorter side
    """
    if x.shape[0] < x.shape[1]:
        s0 = m
        s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
    else:
        s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
        s1 = m
    # cv2.resize takes dsize as (width, height)
    return cv2.resize(x, (s1, s0), interpolation=cv2.INTER_LANCZOS4)
|
||||
|
||||
def _d_resize(x, d, fac=1.0):
    """
    resize x to the size given by d scaled by fac

        x     image, shape[0] is height and shape[1] is width
        d     (height, width, ...) target shape, only d[0] and d[1] are read
        fac   scale factor applied to the target size

    uses INTER_AREA when the smaller target side is below the smaller source side,
    INTER_LANCZOS4 otherwise
    """
    target_h = int(d[0] * fac)
    target_w = int(d[1] * fac)

    shrinking = min(target_w, target_h) < min(x.shape[0], x.shape[1])
    method = cv2.INTER_AREA if shrinking else cv2.INTER_LANCZOS4

    return cv2.resize(x, (target_w, target_h), interpolation=method)
|
||||
|
||||
def _get_image_gradient(dist):
    """
    returns (cols, rows): float32 responses of dist to the horizontal and
    vertical 3x3 derivative kernels
    """
    kernel_cols = np.array([[-1, 0, +1],
                            [-2, 0, +2],
                            [-1, 0, +1]])
    kernel_rows = np.array([[-1, -2, -1],
                            [0, 0, 0],
                            [+1, +2, +1]])

    grad_cols = cv2.filter2D(dist, cv2.CV_32F, kernel_cols)
    grad_rows = cv2.filter2D(dist, cv2.CV_32F, kernel_rows)
    return grad_cols, grad_rows
|
||||
|
||||
def _generate_lighting_effects(content):
    """
    build a per-pixel direction field from multi-scale image gradients

    content is reduced five times with cv2.pyrDown, gradients are taken at
    all six levels and accumulated from the coarsest level up to full size.

    returns an array with a new last axis of 3: [zeros, rows, cols],
    where rows/cols are divided by the largest gradient magnitude
    """
    # pyramid, finest level first
    levels = [content]
    for _ in range(5):
        levels.append(cv2.pyrDown(levels[-1]))

    # (cols, rows) gradient pair for every level
    gradients = [_get_image_gradient(level) for level in levels]

    # start from the coarsest level and fold in every finer one
    cols, rows = gradients[-1]
    for finer_cols, finer_rows in reversed(gradients[:-1]):
        cols = _d_resize(cv2.pyrUp(cols), finer_cols.shape) * 4.0 + finer_cols
        rows = _d_resize(cv2.pyrUp(rows), finer_rows.shape) * 4.0 + finer_rows

    EPS = 1e-10

    # maximum magnitude over the two leading (spatial) axes, kept for broadcasting
    magnitude = (cols**2 + rows**2)**0.5
    max_effect = np.max(magnitude, axis=0, keepdims=True, ).max(1, keepdims=True)

    cols = (cols + EPS) / (max_effect + EPS)
    rows = (rows + EPS) / (max_effect + EPS)

    return np.stack([ np.zeros_like(rows), rows, cols], axis=-1)
|
||||
|
||||
def apply_random_relight(img, mask=None, rnd_state=None):
    """
    relight img from a random direction using _generate_lighting_effects()

        img         image passed to _generate_lighting_effects() and multiplied by the shading
        mask        optional blend weights, result = img*(1-mask) + relit*mask
        rnd_state   object providing randint() and uniform(), np.random is used if None

    returns the relit image clipped to [0..1]
    """
    rnd = np.random if rnd_state is None else rnd_state

    # one light coordinate is pinned to +1/-1, the other is uniform in [-1, 1)
    if rnd.randint(2) == 0:
        light_y = 1.0 if rnd.randint(2) == 0 else -1.0
        light_x = rnd.uniform()*2-1.0
    else:
        light_y = rnd.uniform()*2-1.0
        light_x = 1.0 if rnd.randint(2) == 0 else -1.0

    # NOTE(review): this is a product (range 0..0.21), not 0.3 + u*0.7 - confirm it is intended
    light_height = 0.3*rnd.uniform()*0.7
    light_intensity = 1.0+rnd.uniform()
    ambient_intensity = 0.5

    # unit vector, component order matches the [zeros, rows, cols] field
    light_location = np.array([[[light_height, light_y, light_x ]]], dtype=np.float32)
    light_direction = light_location / np.sqrt(np.sum(np.square(light_location)))

    shading = _generate_lighting_effects(img)
    shading = np.sum(shading * light_direction, axis=-1).clip(0, 1)
    shading = np.mean(shading, axis=-1, keepdims=True)

    relit = np.clip(img * (ambient_intensity + shading * light_intensity), 0, 1)

    if mask is None:
        return relit
    return img*(1-mask) + relit*mask
|
|
@ -1 +1,2 @@
|
|||
from .draw import *
|
||||
from .draw import circle_faded, random_circle_faded, bezier, random_bezier_split_faded, random_faded
|
||||
from .calc import *
|
25
core/imagelib/sd/calc.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
import numpy as np
|
||||
import numpy.linalg as npla
|
||||
|
||||
def dist_to_edges(pts, pt, is_closed=False):
    """
    returns (dists, projections) of pt against every edge of the polyline pts

        pts         [n,2] array of vertices
        pt          [2] point
        is_closed   if True the edge last->first is included

    dists is the distance from pt to each edge segment,
    projections is the closest point on each edge segment
    """
    if is_closed:
        seg_start = pts
        # every vertex paired with its successor, wrapping around
        seg_end = np.roll(pts, -1, axis=0)
    else:
        seg_start, seg_end = pts[:-1,:], pts[1:,:]

    to_pt = pt-seg_start
    seg = seg_end-seg_start

    seg_len_sq = np.einsum('ij,ij->i', seg, seg)
    # zero-length edges: avoid division by zero, their parameter becomes 0
    seg_len_sq[seg_len_sq==0]=1

    # position of the projection along each segment, limited to the segment itself
    t = np.clip( np.einsum('ij,ij->i', to_pt, seg) / seg_len_sq, 0, 1 )
    along = seg*t[...,None]

    return npla.norm ( to_pt - along, axis=1 ), seg_start+along
|
||||
|
|
@ -1,23 +1,36 @@
|
|||
"""
|
||||
Signed distance drawing functions using numpy.
|
||||
"""
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
from numpy import linalg as npla
|
||||
|
||||
def circle_faded( hw, center, fade_dists ):
|
||||
|
||||
def vector2_dot(a,b):
    """dot product of 2D vectors stored along the last axis"""
    ax, ay = a[...,0], a[...,1]
    bx, by = b[...,0], b[...,1]
    return ax*bx+ay*by
|
||||
|
||||
def vector2_dot2(a):
    """squared length of 2D vectors stored along the last axis"""
    ax, ay = a[...,0], a[...,1]
    return ax*ax+ay*ay
|
||||
|
||||
def vector2_cross(a,b):
    """z component of the cross product of 2D vectors stored along the last axis"""
    ax, ay = a[...,0], a[...,1]
    bx, by = b[...,0], b[...,1]
    return ax*by-ay*bx
|
||||
|
||||
|
||||
def circle_faded( wh, center, fade_dists ):
|
||||
"""
|
||||
returns drawn circle in [h,w,1] output range [0..1.0] float32
|
||||
|
||||
hw = [h,w] resolution
|
||||
center = [y,x] center of circle
|
||||
wh = [w,h] resolution
|
||||
center = [x,y] center of circle
|
||||
fade_dists = [fade_start, fade_end] fade values
|
||||
"""
|
||||
h,w = hw
|
||||
w,h = wh
|
||||
|
||||
pts = np.empty( (h,w,2), dtype=np.float32 )
|
||||
pts[...,1] = np.arange(h)[None,:]
|
||||
pts[...,0] = np.arange(w)[:,None]
|
||||
pts[...,1] = np.arange(h)[None,:]
|
||||
|
||||
pts = pts.reshape ( (h*w, -1) )
|
||||
|
||||
pts_dists = np.abs ( npla.norm(pts-center, axis=-1) )
|
||||
|
@ -30,15 +43,158 @@ def circle_faded( hw, center, fade_dists ):
|
|||
pts_dists = np.clip( 1-pts_dists, 0, 1)
|
||||
|
||||
return pts_dists.reshape ( (h,w,1) ).astype(np.float32)
|
||||
|
||||
|
||||
def bezier( wh, A, B, C ):
    """
    returns drawn bezier in [h,w,1] output range float32,
    every pixel contains signed distance to bezier line

        wh      [w,h]       resolution
        A,B,C   points [x,y]

    the degenerate case A - 2B + C == 0 returns zeros of the same [h,w,1] shape
    """
    width,height = wh

    A = np.float32(A)
    B = np.float32(B)
    C = np.float32(C)

    # NOTE(review): np.arange(width)[:,None] has shape (width,1) and is assigned into an
    # (height,width) slice, so this only broadcasts when width == height, and channel 0
    # then varies along the first axis - confirm the transposed grid is intended
    pos = np.empty( (height,width,2), dtype=np.float32 )
    pos[...,0] = np.arange(width)[:,None]
    pos[...,1] = np.arange(height)[None,:]

    a = B-A
    b = A - 2.0*B + C
    c = a * 2.0
    d = A - pos

    b_dot = vector2_dot(b,b)
    if b_dot == 0.0:
        # same rank as the regular return value
        return np.zeros( (height,width,1), dtype=np.float32 )

    kk = 1.0 / b_dot

    # coefficients of the per-pixel cubic in the curve parameter t
    kx = kk * vector2_dot(a,b)
    ky = kk * (2.0*vector2_dot(a,a)+vector2_dot(d,b))/3.0
    kz = kk * vector2_dot(d,a)

    p = ky - kx*kx

    p3 = p*p*p
    q = kx*(2.0*kx*kx - 3.0*ky) + kz
    h = q*q + 4.0*p3

    # h >= 0 : one candidate parameter, built from two cube roots
    hp_sel = h >= 0.0

    hp_p = h[hp_sel]
    hp_p = np.sqrt(hp_p)

    hp_x = ( np.stack( (hp_p,-hp_p), -1) -q[hp_sel,None] ) / 2.0
    hp_uv = np.sign(hp_x) * np.power( np.abs(hp_x), [1.0/3.0, 1.0/3.0] )
    hp_t = np.clip( hp_uv[...,0] + hp_uv[...,1] - kx, 0.0, 1.0 )

    hp_t = hp_t[...,None]
    hp_q = d[hp_sel]+(c+b*hp_t)*hp_t
    hp_res = vector2_dot2(hp_q)
    hp_sgn = vector2_cross(c+2.0*b*hp_t,hp_q)

    # h < 0 : three candidate parameters from the trigonometric form
    hl_sel = h < 0.0

    hl_q = q[hl_sel]
    hl_p = p[hl_sel]
    hl_z = np.sqrt(-hl_p)
    hl_v = np.arccos( hl_q / (hl_p*hl_z*2.0)) / 3.0

    hl_m = np.cos(hl_v)
    hl_n = np.sin(hl_v)*1.732050808

    hl_t = np.clip( np.stack( (hl_m+hl_m,-hl_n-hl_m,hl_n-hl_m), -1)*hl_z[...,None]-kx, 0.0, 1.0 )

    hl_d = d[hl_sel]

    # only the first two candidates are evaluated, hl_t[...,2] is never read
    hl_qx = hl_d+(c+b*hl_t[...,0:1])*hl_t[...,0:1]

    hl_dx = vector2_dot2(hl_qx)
    hl_sx = vector2_cross(c+2.0*b*hl_t[...,0:1], hl_qx)

    hl_qy = hl_d+(c+b*hl_t[...,1:2])*hl_t[...,1:2]
    hl_dy = vector2_dot2(hl_qy)
    hl_sy = vector2_cross(c+2.0*b*hl_t[...,1:2],hl_qy)

    # keep the nearer candidate together with its side indicator
    hl_dx_l_dy = hl_dx<hl_dy
    hl_dx_ge_dy = hl_dx>=hl_dy

    hl_res = np.empty_like(hl_dx)
    hl_res[hl_dx_l_dy] = hl_dx[hl_dx_l_dy]
    hl_res[hl_dx_ge_dy] = hl_dy[hl_dx_ge_dy]

    hl_sgn = np.empty_like(hl_sx)
    hl_sgn[hl_dx_l_dy] = hl_sx[hl_dx_l_dy]
    hl_sgn[hl_dx_ge_dy] = hl_sy[hl_dx_ge_dy]

    # scatter both cases back into full-size maps
    res = np.empty( (height, width), np.float32 )
    res[hp_sel] = hp_res
    res[hl_sel] = hl_res

    sgn = np.empty( (height, width), np.float32 )
    sgn[hp_sel] = hp_sgn
    sgn[hl_sel] = hl_sgn

    # res holds squared distances up to here
    sgn = np.sign(sgn)
    res = np.sqrt(res)*sgn

    return res[...,None]
|
||||
|
||||
def random_faded(wh):
    """
    draw with one of these, picked with equal probability via np.random:
        random_circle_faded
        random_bezier_split_faded

        wh      [w,h]       resolution
    """
    generators = ( random_circle_faded, random_bezier_split_faded )
    return generators[ np.random.randint(2) ](wh)
|
||||
|
||||
def random_circle_faded ( wh, rnd_state=None ):
    """
    returns circle_faded() output for a random center and random fade distances

        wh          [w,h]       resolution
        rnd_state   object providing randint(), np.random is used if None

    fade_start is drawn from 0..max(w,h)-1, fade_end lies in fade_start..max(w,h)-1
    """
    if rnd_state is None:
        rnd_state = np.random

    w,h = wh
    wh_max = max(w,h)
    fade_start = rnd_state.randint(wh_max)
    fade_end = fade_start + rnd_state.randint(wh_max- fade_start)

    # NOTE(review): circle_faded documents center as [x,y] while the first value here
    # is drawn from the h range - identical for square sizes, confirm for w != h
    return circle_faded (wh, [ rnd_state.randint(h), rnd_state.randint(w) ],
                             [fade_start, fade_end] )
|
||||
|
||||
def random_bezier_split_faded( wh ):
    """
    returns bezier() distance map for a random curve through the image area,
    scaled by a random softness, offset by 0.5 and clipped to [0..1]

        wh      [w,h]       resolution

    uses the global np.random generator
    """
    width, height = wh

    # three angles in whole degrees, drawn in order A, B, C
    deg_a, deg_b, deg_c = ( np.random.randint(360) for _ in range(3) )

    deg_2_rad = math.pi / 180.0

    def _direction(deg):
        return np.float32([ math.sin( deg * deg_2_rad), math.cos( deg * deg_2_rad) ] )

    center = np.float32([width / 2.0, height / 2.0])
    radius = max(width, height)

    # end points sit at distance radius from the center, the control point at a random smaller distance
    A = center + radius*_direction(deg_a)
    B = center + np.random.randint(radius)*_direction(deg_b)
    C = center + radius*_direction(deg_c)

    dist = bezier( (width,height), A, B, C )

    # random edge softness, 0.5 is exactly on the curve
    dist = dist / (1+np.random.randint(radius)) + 0.5

    return np.clip(dist, 0, 1)
|
||||
|
|
|
@ -1,33 +1,147 @@
|
|||
import numpy as np
|
||||
import numpy.linalg as npla
|
||||
import cv2
|
||||
from core import randomex
|
||||
|
||||
def gen_warp_params (w, flip, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_state=None ):
|
||||
def mls_rigid_deformation(vy, vx, src_pts, dst_pts, alpha=1.0, eps=1e-8):
    """
    compute a dense coordinate map from control point pairs
    (presumably moving-least-squares rigid deformation, judging by the name - confirm)

        vy, vx      2D grids of equal shape [grow,gcol] with the sample coordinates
        src_pts     [ctrls,2] control points
        dst_pts     [ctrls,2] control points
        alpha       exponent of the inverse squared-distance weights
        eps         added to the squared distance to avoid division by zero

    returns float array [2,grow,gcol]
    """
    # last axis of the points is reversed and the values are truncated to int16
    dst_pts = dst_pts[..., ::-1].astype(np.int16)
    src_pts = src_pts[..., ::-1].astype(np.int16)

    # the roles of the two point sets are swapped from here on
    src_pts, dst_pts = dst_pts, src_pts

    grow = vx.shape[0]
    gcol = vx.shape[1]
    ctrls = src_pts.shape[0]

    reshaped_p = src_pts.reshape(ctrls, 2, 1, 1)
    reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol)))

    # per control point and per grid cell weight, normalized over the control points
    w = 1.0 / (np.sum((reshaped_p - reshaped_v).astype(np.float32) ** 2, axis=1) + eps) ** alpha
    w /= np.sum(w, axis=0, keepdims=True)

    # weighted centroid of the p points for every grid cell
    pstar = np.zeros((2, grow, gcol), np.float32)
    for i in range(ctrls):
        pstar += w[i] * reshaped_p[i]

    vpstar = reshaped_v - pstar

    # [grow,gcol,2,2]: vpstar stacked with its perpendicular (v1, -v0)
    reshaped_mul_right = np.concatenate((vpstar[:,None,...],
                                         np.concatenate((vpstar[1:2,None,...],-vpstar[0:1,None,...]), 0)
                                         ), axis=1).transpose(2, 3, 0, 1)

    reshaped_q = dst_pts.reshape((ctrls, 2, 1, 1))

    # weighted centroid of the q points for every grid cell
    qstar = np.zeros((2, grow, gcol), np.float32)
    for i in range(ctrls):
        qstar += w[i] * reshaped_q[i]

    # accumulate the per control point contribution for every grid cell
    temp = np.zeros((grow, gcol, 2), np.float32)
    for i in range(ctrls):
        phat = reshaped_p[i] - pstar
        qhat = reshaped_q[i] - qstar

        temp += np.matmul(qhat.reshape(1, 2, grow, gcol).transpose(2, 3, 0, 1),

                          np.matmul( ( w[None, i:i+1,...] *
                                       np.concatenate((phat.reshape(1, 2, grow, gcol),
                                                       np.concatenate( (phat[None,1:2], -phat[None,0:1]), 1 )), 0)
                                     ).transpose(2, 3, 0, 1), reshaped_mul_right
                                   )
                         ).reshape(grow, gcol, 2)

    temp = temp.transpose(2, 0, 1)

    normed_temp = np.linalg.norm(temp, axis=0, keepdims=True)
    normed_vpstar = np.linalg.norm(vpstar, axis=0, keepdims=True)
    # cells whose accumulated vector has zero length cannot be normalized
    nan_mask = normed_temp[0]==0

    # rescale the accumulated vector to the length of vpstar and shift by qstar
    transformers = np.true_divide(temp, normed_temp, out=np.zeros_like(temp), where= ~nan_mask) * normed_vpstar + qstar
    nan_mask_flat = np.flatnonzero(nan_mask)
    nan_mask_anti_flat = np.flatnonzero(~nan_mask)

    # fill the masked cells by 1D interpolation over the flattened cell index
    transformers[0][nan_mask] = np.interp(nan_mask_flat, nan_mask_anti_flat, transformers[0][~nan_mask])
    transformers[1][nan_mask] = np.interp(nan_mask_flat, nan_mask_anti_flat, transformers[1][~nan_mask])

    return transformers
|
||||
|
||||
def gen_pts(W, H, rnd_state=None):
    """
    generate random control point pairs inside a W x H area

        W, H        area size, x is drawn from 0..W-1 and y from 0..H-1
        rnd_state   object providing randint() and rand(), np.random is used if None

    returns (pts1, pts2), both [n,2] with n in 4..7:
        pts1    start points [x,y]
        pts2    start points moved by a random angle, by rad*W in x and rad*H in y (rad < 0.10)
    """
    rnd = np.random if rnd_state is None else rnd_state

    min_pts, max_pts = 4, 8
    # randint excludes the upper bound
    n_pts = rnd.randint(min_pts, max_pts)

    min_radius_per = 0.00
    max_radius_per = 0.10

    accepted = []
    while len(accepted) < n_pts:
        x, y = rnd.randint(W), rnd.randint(H)
        rad = min_radius_per + rnd.rand()*(max_radius_per-min_radius_per)

        # NOTE(review): the distance is in pixels while rad is a fraction of the size,
        # so this mostly rejects coinciding points only - confirm the intent
        too_close = any( npla.norm([x-px, y-py]) <= (rad+prad)*2
                         for px,py,prad,_,_ in accepted )
        if too_close:
            # rejected candidate: draw a new one, the angle is not drawn for it
            continue

        angle = rnd.rand()*(2*np.pi)
        x2 = int(x+np.cos(angle)*W*rad)
        y2 = int(y+np.sin(angle)*H*rad)

        accepted.append( (x,y,rad, x2,y2) )

    pts1 = np.array( [ [x, y] for x,y,_,_,_ in accepted ] )
    pts2 = np.array( [ [x2, y2] for _,_,_,x2,y2 in accepted ] )

    return pts1, pts2
|
||||
|
||||
|
||||
def gen_warp_params (w, flip=False, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_state=None, warp_rnd_state=None ):
|
||||
if rnd_state is None:
|
||||
rnd_state = np.random
|
||||
if warp_rnd_state is None:
|
||||
warp_rnd_state = np.random
|
||||
rw = None
|
||||
if w < 64:
|
||||
rw = w
|
||||
w = 64
|
||||
|
||||
rotation = rnd_state.uniform( rotation_range[0], rotation_range[1] )
|
||||
scale = rnd_state.uniform(1 +scale_range[0], 1 +scale_range[1])
|
||||
scale = rnd_state.uniform( 1/(1-scale_range[0]) , 1+scale_range[1] )
|
||||
tx = rnd_state.uniform( tx_range[0], tx_range[1] )
|
||||
ty = rnd_state.uniform( ty_range[0], ty_range[1] )
|
||||
p_flip = flip and rnd_state.randint(10) < 4
|
||||
|
||||
#random warp by grid
|
||||
cell_size = [ w // (2**i) for i in range(1,4) ] [ rnd_state.randint(3) ]
|
||||
#random warp V1
|
||||
cell_size = [ w // (2**i) for i in range(1,4) ] [ warp_rnd_state.randint(3) ]
|
||||
cell_count = w // cell_size + 1
|
||||
|
||||
grid_points = np.linspace( 0, w, cell_count)
|
||||
mapx = np.broadcast_to(grid_points, (cell_count, cell_count)).copy()
|
||||
mapy = mapx.T
|
||||
|
||||
mapx[1:-1,1:-1] = mapx[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2) )*(cell_size*0.24)
|
||||
mapy[1:-1,1:-1] = mapy[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2) )*(cell_size*0.24)
|
||||
|
||||
mapx[1:-1,1:-1] = mapx[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2), rnd_state=warp_rnd_state )*(cell_size*0.24)
|
||||
mapy[1:-1,1:-1] = mapy[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2), rnd_state=warp_rnd_state )*(cell_size*0.24)
|
||||
half_cell_size = cell_size // 2
|
||||
|
||||
mapx = cv2.resize(mapx, (w+cell_size,)*2 )[half_cell_size:-half_cell_size,half_cell_size:-half_cell_size].astype(np.float32)
|
||||
mapy = cv2.resize(mapy, (w+cell_size,)*2 )[half_cell_size:-half_cell_size,half_cell_size:-half_cell_size].astype(np.float32)
|
||||
|
||||
##############
|
||||
|
||||
# random warp V2
|
||||
# pts1, pts2 = gen_pts(w, w, rnd_state)
|
||||
# gridX = np.arange(w, dtype=np.int16)
|
||||
# gridY = np.arange(w, dtype=np.int16)
|
||||
# vy, vx = np.meshgrid(gridX, gridY)
|
||||
# drigid = mls_rigid_deformation(vy, vx, pts1, pts2)
|
||||
# mapy, mapx = drigid.astype(np.float32)
|
||||
################
|
||||
|
||||
#random transform
|
||||
random_transform_mat = cv2.getRotationMatrix2D((w // 2, w // 2), rotation, scale)
|
||||
random_transform_mat[:, 2] += (tx*w, ty*w)
|
||||
|
@ -36,16 +150,30 @@ def gen_warp_params (w, flip, rotation_range=[-10,10], scale_range=[-0.5, 0.5],
|
|||
params['mapx'] = mapx
|
||||
params['mapy'] = mapy
|
||||
params['rmat'] = random_transform_mat
|
||||
u_mat = random_transform_mat.copy()
|
||||
u_mat[:,2] /= w
|
||||
params['umat'] = u_mat
|
||||
params['w'] = w
|
||||
params['rw'] = rw
|
||||
params['flip'] = p_flip
|
||||
|
||||
return params
|
||||
|
||||
def warp_by_params (params, img, can_warp, can_transform, can_flip, border_replicate, cv2_inter=cv2.INTER_CUBIC):
|
||||
rw = params['rw']
|
||||
|
||||
if (can_warp or can_transform) and rw is not None:
|
||||
img = cv2.resize(img, (64,64), interpolation=cv2_inter)
|
||||
|
||||
if can_warp:
|
||||
img = cv2.remap(img, params['mapx'], params['mapy'], cv2_inter )
|
||||
if can_transform:
|
||||
img = cv2.warpAffine( img, params['rmat'], (params['w'], params['w']), borderMode=(cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT), flags=cv2_inter )
|
||||
|
||||
|
||||
if (can_warp or can_transform) and rw is not None:
|
||||
img = cv2.resize(img, (rw,rw), interpolation=cv2_inter)
|
||||
|
||||
if len(img.shape) == 2:
|
||||
img = img[...,None]
|
||||
if can_flip and params['flip']:
|
||||
|
|
|
@ -7,6 +7,7 @@ import types
|
|||
|
||||
import colorama
|
||||
import cv2
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from core import stdex
|
||||
|
@ -197,7 +198,7 @@ class InteractBase(object):
|
|||
def add_key_event(self, wnd_name, ord_key, ctrl_pressed, alt_pressed, shift_pressed):
    """
    append one key event to the queue of window wnd_name

    the stored tuple is (ord_key, chr_key, ctrl_pressed, alt_pressed, shift_pressed);
    chr_key is chr(ord_key) for codes 0..255 and chr(0) for anything else
    """
    if wnd_name not in self.key_events:
        self.key_events[wnd_name] = []

    # extended key codes (above 255) have no single character and can exceed the range
    # accepted by chr(), negative codes are rejected by chr() as well
    chr_key = chr(ord_key) if 0 <= ord_key <= 255 else chr(0)

    # exactly one event per call
    self.key_events[wnd_name] += [ (ord_key, chr_key, ctrl_pressed, alt_pressed, shift_pressed) ]
|
||||
|
||||
def get_mouse_events(self, wnd_name):
|
||||
ar = self.mouse_events.get(wnd_name, [])
|
||||
|
@ -255,7 +256,7 @@ class InteractBase(object):
|
|||
print(result)
|
||||
return result
|
||||
|
||||
def input_int(self, s, default_value, valid_list=None, add_info=None, show_default_value=True, help_message=None):
|
||||
def input_int(self, s, default_value, valid_range=None, valid_list=None, add_info=None, show_default_value=True, help_message=None):
|
||||
if show_default_value:
|
||||
if len(s) != 0:
|
||||
s = f"[{default_value}] {s}"
|
||||
|
@ -263,15 +264,21 @@ class InteractBase(object):
|
|||
s = f"[{default_value}]"
|
||||
|
||||
if add_info is not None or \
|
||||
valid_range is not None or \
|
||||
help_message is not None:
|
||||
s += " ("
|
||||
|
||||
if valid_range is not None:
|
||||
s += f" {valid_range[0]}-{valid_range[1]}"
|
||||
|
||||
if add_info is not None:
|
||||
s += f" {add_info}"
|
||||
|
||||
if help_message is not None:
|
||||
s += " ?:help"
|
||||
|
||||
if add_info is not None or \
|
||||
valid_range is not None or \
|
||||
help_message is not None:
|
||||
s += " )"
|
||||
|
||||
|
@ -288,9 +295,12 @@ class InteractBase(object):
|
|||
continue
|
||||
|
||||
i = int(inp)
|
||||
if valid_range is not None:
|
||||
i = int(np.clip(i, valid_range[0], valid_range[1]))
|
||||
|
||||
if (valid_list is not None) and (i not in valid_list):
|
||||
result = default_value
|
||||
break
|
||||
i = default_value
|
||||
|
||||
result = i
|
||||
break
|
||||
except:
|
||||
|
@ -427,6 +437,7 @@ class InteractBase(object):
|
|||
p.start()
|
||||
time.sleep(0.5)
|
||||
p.terminate()
|
||||
p.join()
|
||||
sys.stdin = os.fdopen( sys.stdin.fileno() )
|
||||
|
||||
|
||||
|
@ -490,10 +501,11 @@ class InteractDesktop(InteractBase):
|
|||
|
||||
if has_windows or has_capture_keys:
|
||||
wait_key_time = max(1, int(sleep_time*1000) )
|
||||
ord_key = cv2.waitKey(wait_key_time)
|
||||
ord_key = cv2.waitKeyEx(wait_key_time)
|
||||
|
||||
shift_pressed = False
|
||||
if ord_key != -1:
|
||||
chr_key = chr(ord_key)
|
||||
chr_key = chr(ord_key) if ord_key <= 255 else chr(0)
|
||||
|
||||
if chr_key >= 'A' and chr_key <= 'Z':
|
||||
shift_pressed = True
|
||||
|
|
|
@ -81,11 +81,8 @@ class Subprocessor(object):
|
|||
except Subprocessor.SilenceException as e:
|
||||
c2s.put ( {'op': 'error', 'data' : data} )
|
||||
except Exception as e:
|
||||
c2s.put ( {'op': 'error', 'data' : data} )
|
||||
if data is not None:
|
||||
print ('Exception while process data [%s]: %s' % (self.get_data_name(data), traceback.format_exc()) )
|
||||
else:
|
||||
print ('Exception: %s' % (traceback.format_exc()) )
|
||||
err_msg = traceback.format_exc()
|
||||
c2s.put ( {'op': 'error', 'data' : data, 'err_msg' : err_msg} )
|
||||
|
||||
c2s.close()
|
||||
s2c.close()
|
||||
|
@ -159,6 +156,24 @@ class Subprocessor(object):
|
|||
|
||||
self.clis = []
|
||||
|
||||
def cli_init_dispatcher(cli):
    """
    drain the pending messages of cli.c2s during initialization

        'init_ok'    sets cli.state to 0
        'log_info'   forwarded to io.log_info
        'log_err'    forwarded to io.log_err
        'error'      logs err_msg if present, kills cli, removes it from self.clis
                     and stops reading further messages
    """
    pending = cli.c2s
    while not pending.empty():
        msg = pending.get()
        op = msg.get('op','')

        if op == 'error':
            err_msg = msg.get('err_msg', None)
            if err_msg is not None:
                io.log_info(f'Error while subprocess initialization: {err_msg}')
            cli.kill()
            self.clis.remove(cli)
            break

        if op == 'init_ok':
            cli.state = 0
        elif op == 'log_info':
            io.log_info(msg['msg'])
        elif op == 'log_err':
            io.log_err(msg['msg'])
|
||||
|
||||
#getting info about name of subprocesses, host and client dicts, and spawning them
|
||||
for name, host_dict, client_dict in self.process_info_generator():
|
||||
try:
|
||||
|
@ -173,19 +188,7 @@ class Subprocessor(object):
|
|||
|
||||
if self.initialize_subprocesses_in_serial:
|
||||
while True:
|
||||
while not cli.c2s.empty():
|
||||
obj = cli.c2s.get()
|
||||
op = obj.get('op','')
|
||||
if op == 'init_ok':
|
||||
cli.state = 0
|
||||
elif op == 'log_info':
|
||||
io.log_info(obj['msg'])
|
||||
elif op == 'log_err':
|
||||
io.log_err(obj['msg'])
|
||||
elif op == 'error':
|
||||
cli.kill()
|
||||
self.clis.remove(cli)
|
||||
break
|
||||
cli_init_dispatcher(cli)
|
||||
if cli.state == 0:
|
||||
break
|
||||
io.process_messages(0.005)
|
||||
|
@ -198,19 +201,7 @@ class Subprocessor(object):
|
|||
#waiting subprocesses their success(or not) initialization
|
||||
while True:
|
||||
for cli in self.clis[:]:
|
||||
while not cli.c2s.empty():
|
||||
obj = cli.c2s.get()
|
||||
op = obj.get('op','')
|
||||
if op == 'init_ok':
|
||||
cli.state = 0
|
||||
elif op == 'log_info':
|
||||
io.log_info(obj['msg'])
|
||||
elif op == 'log_err':
|
||||
io.log_err(obj['msg'])
|
||||
elif op == 'error':
|
||||
cli.kill()
|
||||
self.clis.remove(cli)
|
||||
break
|
||||
cli_init_dispatcher(cli)
|
||||
if all ([cli.state == 0 for cli in self.clis]):
|
||||
break
|
||||
io.process_messages(0.005)
|
||||
|
@ -235,8 +226,12 @@ class Subprocessor(object):
|
|||
cli.state = 0
|
||||
elif op == 'error':
|
||||
#some error occured while process data, returning chunk to on_data_return
|
||||
err_msg = obj.get('err_msg', None)
|
||||
if err_msg is not None:
|
||||
io.log_info(f'Error while processing data: {err_msg}')
|
||||
|
||||
if 'data' in obj.keys():
|
||||
self.on_data_return (cli.host_dict, obj['data'] )
|
||||
self.on_data_return (cli.host_dict, obj['data'] )
|
||||
#and killing process
|
||||
cli.kill()
|
||||
self.clis.remove(cli)
|
||||
|
|
|
@ -1,54 +1,60 @@
|
|||
from core.leras import nn
|
||||
tf = nn.tf
|
||||
|
||||
class DeepFakeArchi(nn.ArchiBase):
|
||||
class DeepFakeArchi(nn.ArchiBase):
|
||||
"""
|
||||
resolution
|
||||
|
||||
|
||||
mod None - default
|
||||
'chervonij'
|
||||
'quick'
|
||||
|
||||
opts ''
|
||||
''
|
||||
't'
|
||||
"""
|
||||
def __init__(self, resolution, mod=None):
|
||||
def __init__(self, resolution, use_fp16=False, mod=None, opts=None):
|
||||
super().__init__()
|
||||
|
||||
if opts is None:
|
||||
opts = ''
|
||||
|
||||
|
||||
conv_dtype = tf.float16 if use_fp16 else tf.float32
|
||||
|
||||
if 'c' in opts:
|
||||
def act(x, alpha=0.1):
|
||||
return x*tf.cos(x)
|
||||
else:
|
||||
def act(x, alpha=0.1):
|
||||
return tf.nn.leaky_relu(x, alpha)
|
||||
|
||||
if mod is None:
|
||||
class Downscale(nn.ModelBase):
|
||||
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
|
||||
def __init__(self, in_ch, out_ch, kernel_size=5, *kwargs ):
|
||||
self.in_ch = in_ch
|
||||
self.out_ch = out_ch
|
||||
self.kernel_size = kernel_size
|
||||
self.dilations = dilations
|
||||
self.subpixel = subpixel
|
||||
self.use_activator = use_activator
|
||||
super().__init__(*kwargs)
|
||||
|
||||
def on_build(self, *args, **kwargs ):
|
||||
self.conv1 = nn.Conv2D( self.in_ch,
|
||||
self.out_ch // (4 if self.subpixel else 1),
|
||||
kernel_size=self.kernel_size,
|
||||
strides=1 if self.subpixel else 2,
|
||||
padding='SAME', dilations=self.dilations)
|
||||
self.conv1 = nn.Conv2D( self.in_ch, self.out_ch, kernel_size=self.kernel_size, strides=2, padding='SAME', dtype=conv_dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
if self.subpixel:
|
||||
x = nn.space_to_depth(x, 2)
|
||||
if self.use_activator:
|
||||
x = tf.nn.leaky_relu(x, 0.1)
|
||||
x = act(x, 0.1)
|
||||
return x
|
||||
|
||||
def get_out_ch(self):
|
||||
return (self.out_ch // 4) * 4
|
||||
return self.out_ch
|
||||
|
||||
class DownscaleBlock(nn.ModelBase):
|
||||
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
|
||||
def on_build(self, in_ch, ch, n_downscales, kernel_size):
|
||||
self.downs = []
|
||||
|
||||
last_ch = in_ch
|
||||
for i in range(n_downscales):
|
||||
cur_ch = ch*( min(2**i, 8) )
|
||||
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
|
||||
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size))
|
||||
last_ch = self.downs[-1].get_out_ch()
|
||||
|
||||
def forward(self, inp):
|
||||
|
@ -58,66 +64,77 @@ class DeepFakeArchi(nn.ArchiBase):
|
|||
return x
|
||||
|
||||
class Upscale(nn.ModelBase):
|
||||
def on_build(self, in_ch, out_ch, kernel_size=3 ):
|
||||
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
|
||||
def on_build(self, in_ch, out_ch, kernel_size=3):
|
||||
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = tf.nn.leaky_relu(x, 0.1)
|
||||
x = act(x, 0.1)
|
||||
x = nn.depth_to_space(x, 2)
|
||||
return x
|
||||
|
||||
class ResidualBlock(nn.ModelBase):
|
||||
def on_build(self, ch, kernel_size=3 ):
|
||||
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
|
||||
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
|
||||
def on_build(self, ch, kernel_size=3):
|
||||
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
|
||||
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
|
||||
|
||||
def forward(self, inp):
|
||||
x = self.conv1(inp)
|
||||
x = tf.nn.leaky_relu(x, 0.2)
|
||||
x = act(x, 0.2)
|
||||
x = self.conv2(x)
|
||||
x = tf.nn.leaky_relu(inp + x, 0.2)
|
||||
x = act(inp + x, 0.2)
|
||||
return x
|
||||
|
||||
class UpdownResidualBlock(nn.ModelBase):
|
||||
def on_build(self, ch, inner_ch, kernel_size=3 ):
|
||||
self.up = Upscale (ch, inner_ch, kernel_size=kernel_size)
|
||||
self.res = ResidualBlock (inner_ch, kernel_size=kernel_size)
|
||||
self.down = Downscale (inner_ch, ch, kernel_size=kernel_size, use_activator=False)
|
||||
|
||||
def forward(self, inp):
|
||||
x = self.up(inp)
|
||||
x = upx = self.res(x)
|
||||
x = self.down(x)
|
||||
x = x + inp
|
||||
x = tf.nn.leaky_relu(x, 0.2)
|
||||
return x, upx
|
||||
|
||||
class Encoder(nn.ModelBase):
|
||||
def on_build(self, in_ch, e_ch, is_hd):
|
||||
self.is_hd=is_hd
|
||||
if self.is_hd:
|
||||
self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, dilations=1)
|
||||
self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, dilations=1)
|
||||
self.down3 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=5, dilations=2)
|
||||
self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2)
|
||||
else:
|
||||
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False)
|
||||
def __init__(self, in_ch, e_ch, **kwargs ):
|
||||
self.in_ch = in_ch
|
||||
self.e_ch = e_ch
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def forward(self, inp):
|
||||
if self.is_hd:
|
||||
x = tf.concat([ nn.flatten(self.down1(inp)),
|
||||
nn.flatten(self.down2(inp)),
|
||||
nn.flatten(self.down3(inp)),
|
||||
nn.flatten(self.down4(inp)) ], -1 )
|
||||
def on_build(self):
|
||||
if 't' in opts:
|
||||
self.down1 = Downscale(self.in_ch, self.e_ch, kernel_size=5)
|
||||
self.res1 = ResidualBlock(self.e_ch)
|
||||
self.down2 = Downscale(self.e_ch, self.e_ch*2, kernel_size=5)
|
||||
self.down3 = Downscale(self.e_ch*2, self.e_ch*4, kernel_size=5)
|
||||
self.down4 = Downscale(self.e_ch*4, self.e_ch*8, kernel_size=5)
|
||||
self.down5 = Downscale(self.e_ch*8, self.e_ch*8, kernel_size=5)
|
||||
self.res5 = ResidualBlock(self.e_ch*8)
|
||||
else:
|
||||
x = nn.flatten(self.down1(inp))
|
||||
self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4 if 't' not in opts else 5, kernel_size=5)
|
||||
|
||||
def forward(self, x):
|
||||
if use_fp16:
|
||||
x = tf.cast(x, tf.float16)
|
||||
|
||||
if 't' in opts:
|
||||
x = self.down1(x)
|
||||
x = self.res1(x)
|
||||
x = self.down2(x)
|
||||
x = self.down3(x)
|
||||
x = self.down4(x)
|
||||
x = self.down5(x)
|
||||
x = self.res5(x)
|
||||
else:
|
||||
x = self.down1(x)
|
||||
x = nn.flatten(x)
|
||||
if 'u' in opts:
|
||||
x = nn.pixel_norm(x, axes=-1)
|
||||
|
||||
if use_fp16:
|
||||
x = tf.cast(x, tf.float32)
|
||||
return x
|
||||
|
||||
lowest_dense_res = resolution // 16
|
||||
|
||||
def get_out_res(self, res):
|
||||
return res // ( (2**4) if 't' not in opts else (2**5) )
|
||||
|
||||
def get_out_ch(self):
|
||||
return self.e_ch * 8
|
||||
|
||||
lowest_dense_res = resolution // (32 if 'd' in opts else 16)
|
||||
|
||||
class Inter(nn.ModelBase):
|
||||
def __init__(self, in_ch, ae_ch, ae_out_ch, is_hd=False, **kwargs):
|
||||
def __init__(self, in_ch, ae_ch, ae_out_ch, **kwargs):
|
||||
self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
@ -126,362 +143,120 @@ class DeepFakeArchi(nn.ArchiBase):
|
|||
|
||||
self.dense1 = nn.Dense( in_ch, ae_ch )
|
||||
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
|
||||
self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
|
||||
if 't' not in opts:
|
||||
self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
|
||||
|
||||
def forward(self, inp):
|
||||
x = self.dense1(inp)
|
||||
x = inp
|
||||
x = self.dense1(x)
|
||||
x = self.dense2(x)
|
||||
x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
|
||||
x = self.upscale1(x)
|
||||
|
||||
if use_fp16:
|
||||
x = tf.cast(x, tf.float16)
|
||||
|
||||
if 't' not in opts:
|
||||
x = self.upscale1(x)
|
||||
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def get_code_res():
|
||||
return lowest_dense_res
|
||||
|
||||
|
||||
def get_out_res(self):
|
||||
return lowest_dense_res * 2 if 't' not in opts else lowest_dense_res
|
||||
|
||||
def get_out_ch(self):
|
||||
return self.ae_out_ch
|
||||
|
||||
class Decoder(nn.ModelBase):
|
||||
def on_build(self, in_ch, d_ch, d_mask_ch, is_hd ):
|
||||
self.is_hd = is_hd
|
||||
|
||||
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
|
||||
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
|
||||
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
|
||||
|
||||
if is_hd:
|
||||
self.res0 = UpdownResidualBlock(in_ch, d_ch*8, kernel_size=3)
|
||||
self.res1 = UpdownResidualBlock(d_ch*8, d_ch*4, kernel_size=3)
|
||||
self.res2 = UpdownResidualBlock(d_ch*4, d_ch*2, kernel_size=3)
|
||||
self.res3 = UpdownResidualBlock(d_ch*2, d_ch, kernel_size=3)
|
||||
else:
|
||||
def on_build(self, in_ch, d_ch, d_mask_ch):
|
||||
if 't' not in opts:
|
||||
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
|
||||
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
|
||||
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
|
||||
self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
|
||||
self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
|
||||
self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
|
||||
|
||||
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
|
||||
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
|
||||
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
|
||||
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
|
||||
|
||||
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
|
||||
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
|
||||
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
|
||||
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
|
||||
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
|
||||
def forward(self, inp):
|
||||
z = inp
|
||||
|
||||
if self.is_hd:
|
||||
x, upx = self.res0(z)
|
||||
x = self.upscale0(x)
|
||||
x = tf.nn.leaky_relu(x + upx, 0.2)
|
||||
x, upx = self.res1(x)
|
||||
|
||||
x = self.upscale1(x)
|
||||
x = tf.nn.leaky_relu(x + upx, 0.2)
|
||||
x, upx = self.res2(x)
|
||||
|
||||
x = self.upscale2(x)
|
||||
x = tf.nn.leaky_relu(x + upx, 0.2)
|
||||
x, upx = self.res3(x)
|
||||
if 'd' in opts:
|
||||
self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
|
||||
self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
else:
|
||||
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
else:
|
||||
x = self.upscale0(z)
|
||||
x = self.res0(x)
|
||||
x = self.upscale1(x)
|
||||
x = self.res1(x)
|
||||
x = self.upscale2(x)
|
||||
x = self.res2(x)
|
||||
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
|
||||
self.upscale1 = Upscale(d_ch*8, d_ch*8, kernel_size=3)
|
||||
self.upscale2 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
|
||||
self.upscale3 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
|
||||
self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
|
||||
self.res1 = ResidualBlock(d_ch*8, kernel_size=3)
|
||||
self.res2 = ResidualBlock(d_ch*4, kernel_size=3)
|
||||
self.res3 = ResidualBlock(d_ch*2, kernel_size=3)
|
||||
|
||||
m = self.upscalem0(z)
|
||||
m = self.upscalem1(m)
|
||||
m = self.upscalem2(m)
|
||||
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
|
||||
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*8, kernel_size=3)
|
||||
self.upscalem2 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
|
||||
self.upscalem3 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
|
||||
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
|
||||
return tf.nn.sigmoid(self.out_conv(x)), \
|
||||
tf.nn.sigmoid(self.out_convm(m))
|
||||
|
||||
elif mod == 'chervonij':
|
||||
class Downscale(nn.ModelBase):
|
||||
def __init__(self, in_ch, kernel_size=3, dilations=1, *kwargs ):
|
||||
self.in_ch = in_ch
|
||||
self.kernel_size = kernel_size
|
||||
self.dilations = dilations
|
||||
super().__init__(*kwargs)
|
||||
|
||||
def on_build(self, *args, **kwargs ):
|
||||
self.conv_base1 = nn.Conv2D( self.in_ch, self.in_ch//2, kernel_size=1, strides=1, padding='SAME', dilations=self.dilations)
|
||||
self.conv_l1 = nn.Conv2D( self.in_ch//2, self.in_ch//2, kernel_size=self.kernel_size, strides=1, padding='SAME', dilations=self.dilations)
|
||||
self.conv_l2 = nn.Conv2D( self.in_ch//2, self.in_ch//2, kernel_size=self.kernel_size, strides=2, padding='SAME', dilations=self.dilations)
|
||||
|
||||
self.conv_base2 = nn.Conv2D( self.in_ch, self.in_ch//2, kernel_size=1, strides=1, padding='SAME', dilations=self.dilations)
|
||||
self.conv_r1 = nn.Conv2D( self.in_ch//2, self.in_ch//2, kernel_size=self.kernel_size, strides=2, padding='SAME', dilations=self.dilations)
|
||||
|
||||
self.pool_size = [1,1,2,2] if nn.data_format == 'NCHW' else [1,2,2,1]
|
||||
def forward(self, x):
|
||||
|
||||
x_l = self.conv_base1(x)
|
||||
x_l = self.conv_l1(x_l)
|
||||
x_l = self.conv_l2(x_l)
|
||||
|
||||
x_r = self.conv_base2(x)
|
||||
x_r = self.conv_r1(x_r)
|
||||
|
||||
x_pool = tf.nn.max_pool(x, ksize=self.pool_size, strides=self.pool_size, padding='SAME', data_format=nn.data_format)
|
||||
|
||||
x = tf.concat([x_l, x_r, x_pool], axis=nn.conv2d_ch_axis)
|
||||
x = tf.nn.leaky_relu(x, 0.1)
|
||||
return x
|
||||
|
||||
class Upscale(nn.ModelBase):
|
||||
def on_build(self, in_ch, out_ch, kernel_size=3 ):
|
||||
self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, padding='SAME')
|
||||
self.conv2 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME')
|
||||
self.conv3 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME')
|
||||
self.conv4 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME')
|
||||
|
||||
def forward(self, x):
|
||||
x0 = self.conv1(x)
|
||||
x1 = self.conv2(x0)
|
||||
x2 = self.conv3(x1)
|
||||
x3 = self.conv4(x2)
|
||||
x = tf.concat([x0, x1, x2, x3], axis=nn.conv2d_ch_axis)
|
||||
x = tf.nn.leaky_relu(x, 0.1)
|
||||
x = nn.depth_to_space(x, 2)
|
||||
return x
|
||||
|
||||
class ResidualBlock(nn.ModelBase):
|
||||
def on_build(self, ch, kernel_size=3 ):
|
||||
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
|
||||
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
|
||||
self.norm = nn.FRNorm2D(ch)
|
||||
|
||||
def forward(self, inp):
|
||||
x = self.conv1(inp)
|
||||
x = tf.nn.leaky_relu(x, 0.2)
|
||||
x = self.conv2(x)
|
||||
x = self.norm(inp + x)
|
||||
x = tf.nn.leaky_relu(x, 0.2)
|
||||
return x
|
||||
|
||||
class Encoder(nn.ModelBase):
|
||||
def on_build(self, in_ch, e_ch, **kwargs):
|
||||
self.conv0 = nn.Conv2D(in_ch, e_ch, kernel_size=3, padding='SAME')
|
||||
|
||||
self.down0 = Downscale(e_ch)
|
||||
self.down1 = Downscale(e_ch*2)
|
||||
self.down2 = Downscale(e_ch*4)
|
||||
self.down3 = Downscale(e_ch*8)
|
||||
self.down4 = Downscale(e_ch*16)
|
||||
|
||||
def forward(self, inp):
|
||||
x = self.conv0(inp)
|
||||
x = self.down0(x)
|
||||
x = self.down1(x)
|
||||
x = self.down2(x)
|
||||
x = self.down3(x)
|
||||
x = self.down4(x)
|
||||
x = nn.flatten(x)
|
||||
return x
|
||||
|
||||
lowest_dense_res = resolution // 32
|
||||
|
||||
class Inter(nn.ModelBase):
|
||||
def __init__(self, in_ch, ae_ch, ae_out_ch, **kwargs):
|
||||
self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def on_build(self, **kwargs):
|
||||
in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch
|
||||
|
||||
self.dense_l = nn.Dense( in_ch, ae_ch//2, kernel_initializer=tf.initializers.orthogonal)
|
||||
self.dense_r = nn.Dense( in_ch, ae_ch//2, kernel_initializer=tf.initializers.orthogonal)#maxout_ch=4,
|
||||
self.dense = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * (ae_out_ch//2), kernel_initializer=tf.initializers.orthogonal)
|
||||
self.upscale1 = Upscale(ae_out_ch//2, ae_out_ch//2)
|
||||
|
||||
def forward(self, inp):
|
||||
x0 = self.dense_l(inp)
|
||||
x1 = self.dense_r(inp)
|
||||
x = tf.concat([x0, x1], axis=-1)
|
||||
x = self.dense(x)
|
||||
x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch//2)
|
||||
x = self.upscale1(x)
|
||||
|
||||
return x
|
||||
|
||||
def get_out_ch(self):
|
||||
return self.ae_out_ch//2
|
||||
|
||||
class Decoder(nn.ModelBase):
|
||||
def on_build(self, in_ch, d_ch, d_mask_ch, **kwargs):
|
||||
|
||||
self.upscale0 = Upscale(in_ch, d_ch*8)
|
||||
self.upscale1 = Upscale(d_ch*8, d_ch*4)
|
||||
self.upscale2 = Upscale(d_ch*4, d_ch*2)
|
||||
self.upscale3 = Upscale(d_ch*2, d_ch)
|
||||
|
||||
self.res0 = ResidualBlock(d_ch*8)
|
||||
self.res1 = ResidualBlock(d_ch*4)
|
||||
self.res2 = ResidualBlock(d_ch*2)
|
||||
self.res3 = ResidualBlock(d_ch)
|
||||
|
||||
self.out_conv = nn.Conv2D( d_ch, 3, kernel_size=1, padding='SAME')
|
||||
|
||||
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
|
||||
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
|
||||
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
|
||||
self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch, kernel_size=3)
|
||||
self.out_convm = nn.Conv2D( d_mask_ch, 1, kernel_size=1, padding='SAME')
|
||||
|
||||
def forward(self, inp):
|
||||
z = inp
|
||||
if 'd' in opts:
|
||||
self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
|
||||
self.upscalem4 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
|
||||
self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
else:
|
||||
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
|
||||
|
||||
|
||||
|
||||
def forward(self, z):
|
||||
x = self.upscale0(z)
|
||||
x = self.res0(x)
|
||||
x = self.upscale1(x)
|
||||
x = self.res1(x)
|
||||
x = self.upscale2(x)
|
||||
x = self.res2(x)
|
||||
x = self.upscale3(x)
|
||||
x = self.res3(x)
|
||||
|
||||
if 't' in opts:
|
||||
x = self.upscale3(x)
|
||||
x = self.res3(x)
|
||||
|
||||
if 'd' in opts:
|
||||
x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x),
|
||||
self.out_conv1(x),
|
||||
self.out_conv2(x),
|
||||
self.out_conv3(x)), nn.conv2d_ch_axis), 2) )
|
||||
else:
|
||||
x = tf.nn.sigmoid(self.out_conv(x))
|
||||
|
||||
|
||||
m = self.upscalem0(z)
|
||||
m = self.upscalem1(m)
|
||||
m = self.upscalem2(m)
|
||||
m = self.upscalem3(m)
|
||||
|
||||
return tf.nn.sigmoid(self.out_conv(x)), \
|
||||
tf.nn.sigmoid(self.out_convm(m))
|
||||
elif mod == 'quick':
|
||||
class Downscale(nn.ModelBase):
|
||||
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
|
||||
self.in_ch = in_ch
|
||||
self.out_ch = out_ch
|
||||
self.kernel_size = kernel_size
|
||||
self.dilations = dilations
|
||||
self.subpixel = subpixel
|
||||
self.use_activator = use_activator
|
||||
super().__init__(*kwargs)
|
||||
if 't' in opts:
|
||||
m = self.upscalem3(m)
|
||||
if 'd' in opts:
|
||||
m = self.upscalem4(m)
|
||||
else:
|
||||
if 'd' in opts:
|
||||
m = self.upscalem3(m)
|
||||
|
||||
def on_build(self, *args, **kwargs ):
|
||||
self.conv1 = nn.Conv2D( self.in_ch,
|
||||
self.out_ch // (4 if self.subpixel else 1),
|
||||
kernel_size=self.kernel_size,
|
||||
strides=1 if self.subpixel else 2,
|
||||
padding='SAME', dilations=self.dilations )
|
||||
m = tf.nn.sigmoid(self.out_convm(m))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
if use_fp16:
|
||||
x = tf.cast(x, tf.float32)
|
||||
m = tf.cast(m, tf.float32)
|
||||
|
||||
if self.subpixel:
|
||||
x = nn.space_to_depth(x, 2)
|
||||
|
||||
if self.use_activator:
|
||||
x = nn.gelu(x)
|
||||
return x
|
||||
|
||||
def get_out_ch(self):
|
||||
return (self.out_ch // 4) * 4
|
||||
|
||||
class DownscaleBlock(nn.ModelBase):
|
||||
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
|
||||
self.downs = []
|
||||
|
||||
last_ch = in_ch
|
||||
for i in range(n_downscales):
|
||||
cur_ch = ch*( min(2**i, 8) )
|
||||
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
|
||||
last_ch = self.downs[-1].get_out_ch()
|
||||
|
||||
def forward(self, inp):
|
||||
x = inp
|
||||
for down in self.downs:
|
||||
x = down(x)
|
||||
return x
|
||||
|
||||
class Upscale(nn.ModelBase):
|
||||
def on_build(self, in_ch, out_ch, kernel_size=3 ):
|
||||
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = nn.gelu(x)
|
||||
x = nn.depth_to_space(x, 2)
|
||||
return x
|
||||
|
||||
class ResidualBlock(nn.ModelBase):
|
||||
def on_build(self, ch, kernel_size=3 ):
|
||||
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
|
||||
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
|
||||
|
||||
def forward(self, inp):
|
||||
x = self.conv1(inp)
|
||||
x = nn.gelu(x)
|
||||
x = self.conv2(x)
|
||||
x = inp + x
|
||||
x = nn.gelu(x)
|
||||
return x
|
||||
|
||||
class Encoder(nn.ModelBase):
|
||||
def on_build(self, in_ch, e_ch):
|
||||
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5)
|
||||
def forward(self, inp):
|
||||
return nn.flatten(self.down1(inp))
|
||||
|
||||
lowest_dense_res = resolution // 16
|
||||
|
||||
class Inter(nn.ModelBase):
|
||||
def __init__(self, in_ch, ae_ch, ae_out_ch, d_ch, **kwargs):
|
||||
self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch = in_ch, ae_ch, ae_out_ch, d_ch
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def on_build(self):
|
||||
in_ch, ae_ch, ae_out_ch, d_ch = self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch
|
||||
|
||||
self.dense1 = nn.Dense( in_ch, ae_ch, kernel_initializer=tf.initializers.orthogonal )
|
||||
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch, kernel_initializer=tf.initializers.orthogonal )
|
||||
self.upscale1 = Upscale(ae_out_ch, d_ch*8)
|
||||
self.res1 = ResidualBlock(d_ch*8)
|
||||
|
||||
def forward(self, inp):
|
||||
x = self.dense1(inp)
|
||||
x = self.dense2(x)
|
||||
x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
|
||||
x = self.upscale1(x)
|
||||
x = self.res1(x)
|
||||
return x
|
||||
|
||||
def get_out_ch(self):
|
||||
return self.ae_out_ch
|
||||
|
||||
class Decoder(nn.ModelBase):
|
||||
def on_build(self, in_ch, d_ch):
|
||||
self.upscale1 = Upscale(in_ch, d_ch*4)
|
||||
self.res1 = ResidualBlock(d_ch*4)
|
||||
self.upscale2 = Upscale(d_ch*4, d_ch*2)
|
||||
self.res2 = ResidualBlock(d_ch*2)
|
||||
self.upscale3 = Upscale(d_ch*2, d_ch*1)
|
||||
self.res3 = ResidualBlock(d_ch*1)
|
||||
|
||||
self.upscalem1 = Upscale(in_ch, d_ch)
|
||||
self.upscalem2 = Upscale(d_ch, d_ch//2)
|
||||
self.upscalem3 = Upscale(d_ch//2, d_ch//2)
|
||||
|
||||
self.out_conv = nn.Conv2D( d_ch*1, 3, kernel_size=1, padding='SAME')
|
||||
self.out_convm = nn.Conv2D( d_ch//2, 1, kernel_size=1, padding='SAME')
|
||||
|
||||
def forward(self, inp):
|
||||
z = inp
|
||||
x = self.upscale1 (z)
|
||||
x = self.res1 (x)
|
||||
x = self.upscale2 (x)
|
||||
x = self.res2 (x)
|
||||
x = self.upscale3 (x)
|
||||
x = self.res3 (x)
|
||||
|
||||
y = self.upscalem1 (z)
|
||||
y = self.upscalem2 (y)
|
||||
y = self.upscalem3 (y)
|
||||
|
||||
return tf.nn.sigmoid(self.out_conv(x)), \
|
||||
tf.nn.sigmoid(self.out_convm(y))
|
||||
return x, m
|
||||
|
||||
self.Encoder = Encoder
|
||||
self.Inter = Inter
|
||||
|
|
|
@ -1,12 +1,19 @@
|
|||
import sys
|
||||
import ctypes
|
||||
import os
|
||||
import multiprocessing
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from core.interact import interact as io
|
||||
|
||||
|
||||
class Device(object):
|
||||
def __init__(self, index, name, total_mem, free_mem, cc=0):
|
||||
def __init__(self, index, tf_dev_type, name, total_mem, free_mem):
|
||||
self.index = index
|
||||
self.tf_dev_type = tf_dev_type
|
||||
self.name = name
|
||||
self.cc = cc
|
||||
|
||||
self.total_mem = total_mem
|
||||
self.total_mem_gb = total_mem / 1024**3
|
||||
self.free_mem = free_mem
|
||||
|
@ -82,8 +89,136 @@ class Devices(object):
|
|||
result.append (device)
|
||||
return Devices(result)
|
||||
|
||||
@staticmethod
|
||||
def _get_tf_devices_proc(q : multiprocessing.Queue):
|
||||
|
||||
if sys.platform[0:3] == 'win':
|
||||
compute_cache_path = Path(os.environ['APPDATA']) / 'NVIDIA' / ('ComputeCache_ALL')
|
||||
os.environ['CUDA_CACHE_PATH'] = str(compute_cache_path)
|
||||
if not compute_cache_path.exists():
|
||||
io.log_info("Caching GPU kernels...")
|
||||
compute_cache_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
import tensorflow
|
||||
|
||||
tf_version = tensorflow.version.VERSION
|
||||
#if tf_version is None:
|
||||
# tf_version = tensorflow.version.GIT_VERSION
|
||||
if tf_version[0] == 'v':
|
||||
tf_version = tf_version[1:]
|
||||
if tf_version[0] == '2':
|
||||
tf = tensorflow.compat.v1
|
||||
else:
|
||||
tf = tensorflow
|
||||
|
||||
import logging
|
||||
# Disable tensorflow warnings
|
||||
tf_logger = logging.getLogger('tensorflow')
|
||||
tf_logger.setLevel(logging.ERROR)
|
||||
|
||||
from tensorflow.python.client import device_lib
|
||||
|
||||
devices = []
|
||||
|
||||
physical_devices = device_lib.list_local_devices()
|
||||
physical_devices_f = {}
|
||||
for dev in physical_devices:
|
||||
dev_type = dev.device_type
|
||||
dev_tf_name = dev.name
|
||||
dev_tf_name = dev_tf_name[ dev_tf_name.index(dev_type) : ]
|
||||
|
||||
dev_idx = int(dev_tf_name.split(':')[-1])
|
||||
|
||||
if dev_type in ['GPU','DML']:
|
||||
dev_name = dev_tf_name
|
||||
|
||||
dev_desc = dev.physical_device_desc
|
||||
if len(dev_desc) != 0:
|
||||
if dev_desc[0] == '{':
|
||||
dev_desc_json = json.loads(dev_desc)
|
||||
dev_desc_json_name = dev_desc_json.get('name',None)
|
||||
if dev_desc_json_name is not None:
|
||||
dev_name = dev_desc_json_name
|
||||
else:
|
||||
for param, value in ( v.split(':') for v in dev_desc.split(',') ):
|
||||
param = param.strip()
|
||||
value = value.strip()
|
||||
if param == 'name':
|
||||
dev_name = value
|
||||
break
|
||||
|
||||
physical_devices_f[dev_idx] = (dev_type, dev_name, dev.memory_limit)
|
||||
|
||||
q.put(physical_devices_f)
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def initialize_main_env():
|
||||
if int(os.environ.get("NN_DEVICES_INITIALIZED", 0)) != 0:
|
||||
return
|
||||
|
||||
if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
|
||||
os.environ.pop('CUDA_VISIBLE_DEVICES')
|
||||
|
||||
os.environ['TF_DIRECTML_KERNEL_CACHE_SIZE'] = '2500'
|
||||
os.environ['CUDA_CACHE_MAXSIZE'] = '2147483647'
|
||||
os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2'
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # tf log errors only
|
||||
|
||||
q = multiprocessing.Queue()
|
||||
p = multiprocessing.Process(target=Devices._get_tf_devices_proc, args=(q,), daemon=True)
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
visible_devices = q.get()
|
||||
|
||||
os.environ['NN_DEVICES_INITIALIZED'] = '1'
|
||||
os.environ['NN_DEVICES_COUNT'] = str(len(visible_devices))
|
||||
|
||||
for i in visible_devices:
|
||||
dev_type, name, total_mem = visible_devices[i]
|
||||
|
||||
os.environ[f'NN_DEVICE_{i}_TF_DEV_TYPE'] = dev_type
|
||||
os.environ[f'NN_DEVICE_{i}_NAME'] = name
|
||||
os.environ[f'NN_DEVICE_{i}_TOTAL_MEM'] = str(total_mem)
|
||||
os.environ[f'NN_DEVICE_{i}_FREE_MEM'] = str(total_mem)
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def getDevices():
|
||||
if Devices.all_devices is None:
|
||||
if int(os.environ.get("NN_DEVICES_INITIALIZED", 0)) != 1:
|
||||
raise Exception("nn devices are not initialized. Run initialize_main_env() in main process.")
|
||||
devices = []
|
||||
for i in range ( int(os.environ['NN_DEVICES_COUNT']) ):
|
||||
devices.append ( Device(index=i,
|
||||
tf_dev_type=os.environ[f'NN_DEVICE_{i}_TF_DEV_TYPE'],
|
||||
name=os.environ[f'NN_DEVICE_{i}_NAME'],
|
||||
total_mem=int(os.environ[f'NN_DEVICE_{i}_TOTAL_MEM']),
|
||||
free_mem=int(os.environ[f'NN_DEVICE_{i}_FREE_MEM']), )
|
||||
)
|
||||
Devices.all_devices = Devices(devices)
|
||||
|
||||
return Devices.all_devices
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# {'name' : name.split(b'\0', 1)[0].decode(),
|
||||
# 'total_mem' : totalMem.value
|
||||
# }
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
return
|
||||
|
||||
|
||||
|
||||
|
||||
min_cc = int(os.environ.get("TF_MIN_REQ_CAP", 35))
|
||||
libnames = ('libcuda.so', 'libcuda.dylib', 'nvcuda.dll')
|
||||
for libname in libnames:
|
||||
|
@ -129,77 +264,10 @@ class Devices(object):
|
|||
})
|
||||
cuda.cuCtxDetach(context)
|
||||
|
||||
os.environ['NN_DEVICES_INITIALIZED'] = '1'
|
||||
os.environ['NN_DEVICES_COUNT'] = str(len(devices))
|
||||
for i, device in enumerate(devices):
|
||||
os.environ[f'NN_DEVICE_{i}_NAME'] = device['name']
|
||||
os.environ[f'NN_DEVICE_{i}_TOTAL_MEM'] = str(device['total_mem'])
|
||||
os.environ[f'NN_DEVICE_{i}_FREE_MEM'] = str(device['free_mem'])
|
||||
os.environ[f'NN_DEVICE_{i}_CC'] = str(device['cc'])
|
||||
|
||||
@staticmethod
|
||||
def getDevices():
|
||||
if Devices.all_devices is None:
|
||||
if int(os.environ.get("NN_DEVICES_INITIALIZED", 0)) != 1:
|
||||
raise Exception("nn devices are not initialized. Run initialize_main_env() in main process.")
|
||||
devices = []
|
||||
for i in range ( int(os.environ['NN_DEVICES_COUNT']) ):
|
||||
devices.append ( Device(index=i,
|
||||
name=os.environ[f'NN_DEVICE_{i}_NAME'],
|
||||
total_mem=int(os.environ[f'NN_DEVICE_{i}_TOTAL_MEM']),
|
||||
free_mem=int(os.environ[f'NN_DEVICE_{i}_FREE_MEM']),
|
||||
cc=int(os.environ[f'NN_DEVICE_{i}_CC']) ))
|
||||
Devices.all_devices = Devices(devices)
|
||||
|
||||
return Devices.all_devices
|
||||
|
||||
"""
|
||||
if Devices.all_devices is None:
|
||||
min_cc = int(os.environ.get("TF_MIN_REQ_CAP", 35))
|
||||
|
||||
libnames = ('libcuda.so', 'libcuda.dylib', 'nvcuda.dll')
|
||||
for libname in libnames:
|
||||
try:
|
||||
cuda = ctypes.CDLL(libname)
|
||||
except:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
return Devices([])
|
||||
|
||||
nGpus = ctypes.c_int()
|
||||
name = b' ' * 200
|
||||
cc_major = ctypes.c_int()
|
||||
cc_minor = ctypes.c_int()
|
||||
freeMem = ctypes.c_size_t()
|
||||
totalMem = ctypes.c_size_t()
|
||||
|
||||
result = ctypes.c_int()
|
||||
device = ctypes.c_int()
|
||||
context = ctypes.c_void_p()
|
||||
error_str = ctypes.c_char_p()
|
||||
|
||||
devices = []
|
||||
|
||||
if cuda.cuInit(0) == 0 and \
|
||||
cuda.cuDeviceGetCount(ctypes.byref(nGpus)) == 0:
|
||||
for i in range(nGpus.value):
|
||||
if cuda.cuDeviceGet(ctypes.byref(device), i) != 0 or \
|
||||
cuda.cuDeviceGetName(ctypes.c_char_p(name), len(name), device) != 0 or \
|
||||
cuda.cuDeviceComputeCapability(ctypes.byref(cc_major), ctypes.byref(cc_minor), device) != 0:
|
||||
continue
|
||||
|
||||
if cuda.cuCtxCreate_v2(ctypes.byref(context), 0, device) == 0:
|
||||
if cuda.cuMemGetInfo_v2(ctypes.byref(freeMem), ctypes.byref(totalMem)) == 0:
|
||||
cc = cc_major.value * 10 + cc_minor.value
|
||||
if cc >= min_cc:
|
||||
devices.append ( Device(index=i,
|
||||
name=name.split(b'\0', 1)[0].decode(),
|
||||
total_mem=totalMem.value,
|
||||
free_mem=freeMem.value,
|
||||
cc=cc) )
|
||||
cuda.cuCtxDetach(context)
|
||||
Devices.all_devices = Devices(devices)
|
||||
return Devices.all_devices
|
||||
"""
|
|
@ -23,28 +23,13 @@ class Conv2D(nn.LayerBase):
|
|||
if padding == "SAME":
|
||||
padding = ( (kernel_size - 1) * dilations + 1 ) // 2
|
||||
elif padding == "VALID":
|
||||
padding = 0
|
||||
padding = None
|
||||
else:
|
||||
raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs")
|
||||
|
||||
if isinstance(padding, int):
|
||||
if padding != 0:
|
||||
if nn.data_format == "NHWC":
|
||||
padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
|
||||
else:
|
||||
padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
|
||||
else:
|
||||
padding = None
|
||||
|
||||
if nn.data_format == "NHWC":
|
||||
strides = [1,strides,strides,1]
|
||||
else:
|
||||
strides = [1,1,strides,strides]
|
||||
|
||||
if nn.data_format == "NHWC":
|
||||
dilations = [1,dilations,dilations,1]
|
||||
else:
|
||||
dilations = [1,1,dilations,dilations]
|
||||
padding = int(padding)
|
||||
|
||||
|
||||
|
||||
self.in_ch = in_ch
|
||||
self.out_ch = out_ch
|
||||
|
@ -70,8 +55,8 @@ class Conv2D(nn.LayerBase):
|
|||
if kernel_initializer is None:
|
||||
kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)
|
||||
|
||||
if kernel_initializer is None:
|
||||
kernel_initializer = nn.initializers.ca()
|
||||
#if kernel_initializer is None:
|
||||
# kernel_initializer = nn.initializers.ca()
|
||||
|
||||
self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.in_ch,self.out_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )
|
||||
|
||||
|
@ -93,10 +78,27 @@ class Conv2D(nn.LayerBase):
|
|||
if self.use_wscale:
|
||||
weight = weight * self.wscale
|
||||
|
||||
if self.padding is not None:
|
||||
x = tf.pad (x, self.padding, mode='CONSTANT')
|
||||
padding = self.padding
|
||||
if padding is not None:
|
||||
if nn.data_format == "NHWC":
|
||||
padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
|
||||
else:
|
||||
padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
|
||||
x = tf.pad (x, padding, mode='CONSTANT')
|
||||
|
||||
strides = self.strides
|
||||
if nn.data_format == "NHWC":
|
||||
strides = [1,strides,strides,1]
|
||||
else:
|
||||
strides = [1,1,strides,strides]
|
||||
|
||||
x = tf.nn.conv2d(x, weight, self.strides, 'VALID', dilations=self.dilations, data_format=nn.data_format)
|
||||
dilations = self.dilations
|
||||
if nn.data_format == "NHWC":
|
||||
dilations = [1,dilations,dilations,1]
|
||||
else:
|
||||
dilations = [1,1,dilations,dilations]
|
||||
|
||||
x = tf.nn.conv2d(x, weight, strides, 'VALID', dilations=dilations, data_format=nn.data_format)
|
||||
if self.use_bias:
|
||||
if nn.data_format == "NHWC":
|
||||
bias = tf.reshape (self.bias, (1,1,1,self.out_ch) )
|
||||
|
|
|
@ -38,8 +38,8 @@ class Conv2DTranspose(nn.LayerBase):
|
|||
if kernel_initializer is None:
|
||||
kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)
|
||||
|
||||
if kernel_initializer is None:
|
||||
kernel_initializer = nn.initializers.ca()
|
||||
#if kernel_initializer is None:
|
||||
# kernel_initializer = nn.initializers.ca()
|
||||
self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.out_ch,self.in_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )
|
||||
|
||||
if self.use_bias:
|
||||
|
|
16
core/leras/layers/DenseNorm.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
from core.leras import nn
|
||||
tf = nn.tf
|
||||
|
||||
class DenseNorm(nn.LayerBase):
    """
    Scales the input by the reciprocal square root of the mean of its squares,
    where the mean is taken over the last axis.
    """

    def __init__(self, dense=False, eps=1e-06, dtype=None, **kwargs):
        # 'dense' is only stored; nothing in this class reads it back.
        self.dense = dense

        const_dtype = nn.floatx if dtype is None else dtype
        # eps is added to the mean of squares before the reciprocal square root.
        self.eps = tf.constant(eps, dtype=const_dtype, name="epsilon")

        super().__init__(**kwargs)

    def __call__(self, x):
        # keepdims=True so the per-position statistic broadcasts against x.
        mean_of_squares = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        inv_rms = tf.rsqrt(mean_of_squares + self.eps)
        return x * inv_rms
|
||||
|
||||
nn.DenseNorm = DenseNorm
|
110
core/leras/layers/DepthwiseConv2D.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
import numpy as np
|
||||
from core.leras import nn
|
||||
tf = nn.tf
|
||||
|
||||
class DepthwiseConv2D(nn.LayerBase):
    """
    Depthwise 2D convolution built on tf.nn.depthwise_conv2d.

    in_ch               number of input channels
    kernel_size         converted with int(); the kernel is kernel_size x kernel_size
    strides             int, applied to both spatial axes
    padding             'SAME', 'VALID' or an int amount of zero padding added on both
                        sides of both spatial axes; any other value is stored unchanged
                        and handed to tf.pad in forward()
    depth_multiplier    last dimension of the weight; the bias has in_ch*depth_multiplier entries
    dilations           int, dilation applied to both spatial axes.
                        tf.nn.depthwise_conv2d documents that all strides must be 1
                        when the dilation rate is greater than 1.
    use_wscale          bool enables equalized learning rate, if kernel_initializer is None, it will be forced to random_normal
    kernel_initializer  when None and use_wscale is False, no initializer is passed to
                        tf.get_variable (the CA default is commented out in build_weights)
    """
    def __init__(self, in_ch, kernel_size, strides=1, padding='SAME', depth_multiplier=1, dilations=1, use_bias=True, use_wscale=False, kernel_initializer=None, bias_initializer=None, trainable=True, dtype=None, **kwargs ):
        if not isinstance(strides, int):
            raise ValueError ("strides must be an int type")
        if not isinstance(dilations, int):
            raise ValueError ("dilations must be an int type")
        kernel_size = int(kernel_size)

        if dtype is None:
            dtype = nn.floatx

        if isinstance(padding, str):
            if padding == "SAME":
                # half of the dilated kernel extent
                padding = ( (kernel_size - 1) * dilations + 1 ) // 2
            elif padding == "VALID":
                padding = 0
            else:
                raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs")

        if isinstance(padding, int):
            if padding != 0:
                if nn.data_format == "NHWC":
                    padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
                else:
                    padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
            else:
                # forward() skips tf.pad entirely when padding is None
                padding = None

        if nn.data_format == "NHWC":
            strides = [1,strides,strides,1]
        else:
            strides = [1,1,strides,strides]

        # tf.nn.depthwise_conv2d takes the dilation as a size-2 [height, width] 'rate',
        # so keep that form next to the 4-element list stored in self.dilations.
        dilation_rate = [dilations, dilations]

        if nn.data_format == "NHWC":
            dilations = [1,dilations,dilations,1]
        else:
            dilations = [1,1,dilations,dilations]

        self.in_ch = in_ch
        self.depth_multiplier = depth_multiplier
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.dilations = dilations
        self.dilation_rate = dilation_rate
        self.use_bias = use_bias
        self.use_wscale = use_wscale
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.trainable = trainable
        self.dtype = dtype
        super().__init__(**kwargs)

    def build_weights(self):
        """Create the 'weight' variable and, when use_bias is set, the 'bias' variable."""
        kernel_initializer = self.kernel_initializer
        if self.use_wscale:
            gain = 1.0 if self.kernel_size == 1 else np.sqrt(2)
            # NOTE: fan_in includes in_ch here, exactly as written in the original code.
            fan_in = self.kernel_size*self.kernel_size*self.in_ch
            he_std = gain / np.sqrt(fan_in)
            # forward() multiplies the weight by this constant at run time
            self.wscale = tf.constant(he_std, dtype=self.dtype )
            if kernel_initializer is None:
                kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)

        #if kernel_initializer is None:
        #    kernel_initializer = nn.initializers.ca()

        self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.in_ch,self.depth_multiplier), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )

        if self.use_bias:
            bias_initializer = self.bias_initializer
            if bias_initializer is None:
                bias_initializer = tf.initializers.zeros(dtype=self.dtype)

            self.bias = tf.get_variable("bias", (self.in_ch*self.depth_multiplier,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable )

    def get_weights(self):
        """Return [weight] or [weight, bias] depending on use_bias."""
        weights = [self.weight]
        if self.use_bias:
            weights += [self.bias]
        return weights

    def forward(self, x):
        """Pad explicitly, apply the depthwise convolution with 'VALID' padding, then add the bias."""
        weight = self.weight
        if self.use_wscale:
            weight = weight * self.wscale

        if self.padding is not None:
            x = tf.pad (x, self.padding, mode='CONSTANT')

        # The padding above was sized for the dilated kernel, so the convolution
        # has to be dilated as well. rate=None keeps the plain call for dilations == 1.
        rate = self.dilation_rate if self.dilation_rate != [1, 1] else None
        x = tf.nn.depthwise_conv2d(x, weight, self.strides, 'VALID', rate=rate, data_format=nn.data_format)
        if self.use_bias:
            # reshape the flat bias so it broadcasts along the channel axis
            if nn.data_format == "NHWC":
                bias = tf.reshape (self.bias, (1,1,1,self.in_ch*self.depth_multiplier) )
            else:
                bias = tf.reshape (self.bias, (1,self.in_ch*self.depth_multiplier,1,1) )
            x = tf.add(x, bias)
        return x

    def __str__(self):
        r = f"{self.__class__.__name__} : in_ch:{self.in_ch} depth_multiplier:{self.depth_multiplier} "
        return r
|
||||
|
||||
nn.DepthwiseConv2D = DepthwiseConv2D
|
|
@ -46,7 +46,9 @@ class Saveable():
|
|||
raise Exception("name must be defined.")
|
||||
|
||||
name = self.name
|
||||
for w, w_val in zip(weights, nn.tf_sess.run (weights)):
|
||||
|
||||
for w in weights:
|
||||
w_val = nn.tf_sess.run (w).copy()
|
||||
w_name_split = w.name.split('/', 1)
|
||||
if name != w_name_split[0]:
|
||||
raise Exception("weight first name != Saveable.name")
|
||||
|
@ -76,28 +78,31 @@ class Saveable():
|
|||
if self.name is None:
|
||||
raise Exception("name must be defined.")
|
||||
|
||||
tuples = []
|
||||
for w in weights:
|
||||
w_name_split = w.name.split('/')
|
||||
if self.name != w_name_split[0]:
|
||||
raise Exception("weight first name != Saveable.name")
|
||||
try:
|
||||
tuples = []
|
||||
for w in weights:
|
||||
w_name_split = w.name.split('/')
|
||||
if self.name != w_name_split[0]:
|
||||
raise Exception("weight first name != Saveable.name")
|
||||
|
||||
sub_w_name = "/".join(w_name_split[1:])
|
||||
sub_w_name = "/".join(w_name_split[1:])
|
||||
|
||||
w_val = d.get(sub_w_name, None)
|
||||
w_val = d.get(sub_w_name, None)
|
||||
|
||||
if w_val is None:
|
||||
#io.log_err(f"Weight {w.name} was not loaded from file {filename}")
|
||||
tuples.append ( (w, w.initializer) )
|
||||
else:
|
||||
w_val = np.reshape( w_val, w.shape.as_list() )
|
||||
tuples.append ( (w, w_val) )
|
||||
if w_val is None:
|
||||
#io.log_err(f"Weight {w.name} was not loaded from file {filename}")
|
||||
tuples.append ( (w, w.initializer) )
|
||||
else:
|
||||
w_val = np.reshape( w_val, w.shape.as_list() )
|
||||
tuples.append ( (w, w_val) )
|
||||
|
||||
nn.batch_set_value(tuples)
|
||||
nn.batch_set_value(tuples)
|
||||
except:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def init_weights(self):
    """Pass the weights returned by self.get_weights() to nn.init_weights."""
    own_weights = self.get_weights()
    nn.init_weights(own_weights)
|
||||
|
||||
|
||||
nn.Saveable = Saveable
|
||||
|
|
31
core/leras/layers/ScaleAdd.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
from core.leras import nn
|
||||
tf = nn.tf
|
||||
|
||||
class ScaleAdd(nn.LayerBase):
    """
    Takes a pair (x0, x1) and returns x0 + x1 * weight, where weight is a
    zero-initialised variable with one value per channel.
    """

    def __init__(self, ch, dtype=None, **kwargs):
        self.dtype = nn.floatx if dtype is None else dtype
        self.ch = ch

        super().__init__(**kwargs)

    def build_weights(self):
        """Create the per-channel 'weight' variable, initialised to zeros."""
        zeros_init = tf.initializers.zeros()
        self.weight = tf.get_variable("weight", (self.ch,), dtype=self.dtype, initializer=zeros_init)

    def get_weights(self):
        """Return the single weight of this layer as a list."""
        return [self.weight]

    def forward(self, inputs):
        """inputs is unpacked as (x0, x1); only x1 is scaled."""
        # put the channel dimension where the active data format expects it
        if nn.data_format == "NHWC":
            broadcast_shape = (1, 1, 1, self.ch)
        else:
            broadcast_shape = (1, self.ch, 1, 1)

        scale = tf.reshape(self.weight, broadcast_shape)

        base, addend = inputs
        return base + addend * scale
|
||||
nn.ScaleAdd = ScaleAdd
|
104
core/leras/layers/TanhPolar.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
import numpy as np
|
||||
from core.leras import nn
|
||||
tf = nn.tf
|
||||
|
||||
class TanhPolar(nn.LayerBase):
    """
    RoI Tanh-polar Transformer Network for Face Parsing in the Wild
    https://github.com/hhj1897/roi_tanh_warping

    The sampling grids for both directions are precomputed with numpy in the
    constructor and stored as TensorFlow constants; warp() and restore() tile them
    to the batch size of the input and call nn.bilinear_sampler.
    """

    def __init__(self, width, height, angular_offset_deg=270, **kwargs):
        self.width = width
        self.height = height

        warp_gridx, warp_gridy = TanhPolar._get_tanh_polar_warp_grids(width,height,angular_offset_deg=angular_offset_deg)
        restore_gridx, restore_gridy = TanhPolar._get_tanh_polar_restore_grids(width,height,angular_offset_deg=angular_offset_deg)

        # [None, ...] adds a leading axis of size 1 that is tiled to the batch size later
        self.warp_gridx_t = tf.constant(warp_gridx[None, ...])
        self.warp_gridy_t = tf.constant(warp_gridy[None, ...])
        self.restore_gridx_t = tf.constant(restore_gridx[None, ...])
        self.restore_gridy_t = tf.constant(restore_gridy[None, ...])

        super().__init__(**kwargs)

    def warp(self, inp_t):
        """Sample inp_t with the precomputed warp grids."""
        batch_t = tf.shape(inp_t)[0]
        warp_gridx_t = tf.tile(self.warp_gridx_t, (batch_t,1,1) )
        warp_gridy_t = tf.tile(self.warp_gridy_t, (batch_t,1,1) )

        # the sampler is fed channels-last data, so NCHW input is transposed around it
        if nn.data_format == "NCHW":
            inp_t = tf.transpose(inp_t,(0,2,3,1))

        out_t = nn.bilinear_sampler(inp_t, warp_gridx_t, warp_gridy_t)

        if nn.data_format == "NCHW":
            out_t = tf.transpose(out_t,(0,3,1,2))

        return out_t

    def restore(self, inp_t):
        """Sample inp_t with the precomputed restore grids."""
        batch_t = tf.shape(inp_t)[0]
        restore_gridx_t = tf.tile(self.restore_gridx_t, (batch_t,1,1) )
        restore_gridy_t = tf.tile(self.restore_gridy_t, (batch_t,1,1) )

        if nn.data_format == "NCHW":
            inp_t = tf.transpose(inp_t,(0,2,3,1))

        # The restore grids are built with a +1.0 offset on both axes, which matches
        # the one-pixel symmetric border added in front of both spatial axes here.
        inp_t = tf.pad(inp_t, [(0,0), (1, 1), (1, 0), (0, 0)], "SYMMETRIC")

        out_t = nn.bilinear_sampler(inp_t, restore_gridx_t, restore_gridy_t)

        if nn.data_format == "NCHW":
            out_t = tf.transpose(out_t,(0,3,1,2))

        return out_t

    @staticmethod
    def _get_tanh_polar_warp_grids(W,H,angular_offset_deg):
        """
        Build the source x/y coordinates for warp().

        Returns two float32 arrays of shape (H, W): columns step through the
        normalised radius in [0, 1), rows step through the angle in [0, 2*pi).
        """
        angular_offset_pi = angular_offset_deg * np.pi / 180.0

        roi_center = np.array([ W//2, H//2], np.float32 )
        roi_radii = np.array([W, H], np.float32 ) / np.pi ** 0.5
        cos_offset, sin_offset = np.cos(angular_offset_pi), np.sin(angular_offset_pi)

        # linspace(..., endpoint=False) always yields exactly W (H) samples. A float-step
        # np.arange does not guarantee its length, and an extra radius sample near 1.0
        # would send arctanh below towards infinity.
        radius_samples = np.linspace(0.0, 1.0, W, endpoint=False)
        angle_samples = np.linspace(0.0, 2.0 * np.pi, H, endpoint=False)
        normalised_dest_indices = np.stack(np.meshgrid(radius_samples, angle_samples), axis=-1)
        radii = normalised_dest_indices[..., 0]
        orientation_x = np.cos(normalised_dest_indices[..., 1])
        orientation_y = np.sin(normalised_dest_indices[..., 1])

        src_radii = np.arctanh(radii) * (roi_radii[0] * roi_radii[1] / np.sqrt(roi_radii[1] ** 2 * orientation_x ** 2 + roi_radii[0] ** 2 * orientation_y ** 2))
        src_x_indices = src_radii * orientation_x
        src_y_indices = src_radii * orientation_y
        # rotate by the angular offset and move the origin to the RoI centre
        src_x_indices, src_y_indices = (roi_center[0] + cos_offset * src_x_indices - sin_offset * src_y_indices,
                                        roi_center[1] + cos_offset * src_y_indices + sin_offset * src_x_indices)

        return src_x_indices.astype(np.float32), src_y_indices.astype(np.float32)

    @staticmethod
    def _get_tanh_polar_restore_grids(W,H,angular_offset_deg):
        """
        Build the source x/y coordinates for restore().

        Returns two float32 arrays of shape (H, W); both carry a +1.0 offset.
        """
        angular_offset_pi = angular_offset_deg * np.pi / 180.0

        roi_center = np.array([ W//2, H//2], np.float32 )
        roi_radii = np.array([W, H], np.float32 ) / np.pi ** 0.5
        cos_offset, sin_offset = np.cos(angular_offset_pi), np.sin(angular_offset_pi)

        dest_indices = np.stack(np.meshgrid(np.arange(W), np.arange(H)), axis=-1).astype(float)
        normalised_dest_indices = np.matmul(dest_indices - roi_center, np.array([[cos_offset, -sin_offset],
                                                                                 [sin_offset, cos_offset]]))
        radii = np.linalg.norm(normalised_dest_indices, axis=-1)
        # the clip avoids a division by zero at the centre pixel
        normalised_dest_indices[..., 0] /= np.clip(radii, 1e-9, None)
        normalised_dest_indices[..., 1] /= np.clip(radii, 1e-9, None)
        radii *= np.sqrt(roi_radii[1] ** 2 * normalised_dest_indices[..., 0] ** 2 +
                         roi_radii[0] ** 2 * normalised_dest_indices[..., 1] ** 2) / roi_radii[0] / roi_radii[1]

        src_radii = np.tanh(radii)

        src_x_indices = src_radii * W + 1.0
        src_y_indices = np.mod((np.arctan2(normalised_dest_indices[..., 1], normalised_dest_indices[..., 0]) /
                                2.0 / np.pi) * H, H) + 1.0

        return src_x_indices.astype(np.float32), src_y_indices.astype(np.float32)
|
||||
|
||||
|
||||
nn.TanhPolar = TanhPolar
|
|
@ -3,10 +3,16 @@ from .LayerBase import *
|
|||
|
||||
from .Conv2D import *
|
||||
from .Conv2DTranspose import *
|
||||
from .DepthwiseConv2D import *
|
||||
from .Dense import *
|
||||
from .BlurPool import *
|
||||
|
||||
from .BatchNorm2D import *
|
||||
from .InstanceNorm2D import *
|
||||
from .FRNorm2D import *
|
||||
|
||||
from .TLU import *
|
||||
from .TLU import *
|
||||
from .ScaleAdd import *
|
||||
from .DenseNorm import *
|
||||
from .AdaIN import *
|
||||
from .TanhPolar import *
|
|
@ -18,6 +18,10 @@ class ModelBase(nn.Saveable):
|
|||
if isinstance (layer, list):
|
||||
for i,sublayer in enumerate(layer):
|
||||
self._build_sub(sublayer, f"{name}_{i}")
|
||||
elif isinstance (layer, dict):
|
||||
for subname in layer.keys():
|
||||
sublayer = layer[subname]
|
||||
self._build_sub(sublayer, f"{name}_{subname}")
|
||||
elif isinstance (layer, nn.LayerBase) or \
|
||||
isinstance (layer, ModelBase):
|
||||
|
||||
|
@ -32,7 +36,7 @@ class ModelBase(nn.Saveable):
|
|||
|
||||
self.layers.append (layer)
|
||||
self.layers_by_name[layer.name] = layer
|
||||
|
||||
|
||||
def xor_list(self, lst1, lst2):
    """
    Return the values of lst1 followed by the values of lst2, keeping only
    those that are not contained in both lists. Order and duplicates are kept.
    """
    exclusive = []
    for value in lst1 + lst2:
        if value in lst1 and value in lst2:
            continue
        exclusive.append(value)
    return exclusive
|
||||
|
||||
|
@ -79,7 +83,7 @@ class ModelBase(nn.Saveable):
|
|||
|
||||
def get_layer_by_name(self, name):
    """Return the entry stored under name in self.layers_by_name, or None when there is none."""
    registry = self.layers_by_name
    return registry.get(name)
|
||||
|
||||
|
||||
def get_layers(self):
|
||||
if not self.built:
|
||||
self.build()
|
||||
|
@ -112,41 +116,32 @@ class ModelBase(nn.Saveable):
|
|||
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def compute_output_shape(self, shapes):
|
||||
if not self.built:
|
||||
self.build()
|
||||
# def compute_output_shape(self, shapes):
|
||||
# if not self.built:
|
||||
# self.build()
|
||||
|
||||
not_list = False
|
||||
if not isinstance(shapes, list):
|
||||
not_list = True
|
||||
shapes = [shapes]
|
||||
# not_list = False
|
||||
# if not isinstance(shapes, list):
|
||||
# not_list = True
|
||||
# shapes = [shapes]
|
||||
|
||||
with tf.device('/CPU:0'):
|
||||
# CPU tensors will not impact any performance, only slightly RAM "leakage"
|
||||
phs = []
|
||||
for dtype,sh in shapes:
|
||||
phs += [ tf.placeholder(dtype, sh) ]
|
||||
# with tf.device('/CPU:0'):
|
||||
# # CPU tensors will not impact any performance, only slightly RAM "leakage"
|
||||
# phs = []
|
||||
# for dtype,sh in shapes:
|
||||
# phs += [ tf.placeholder(dtype, sh) ]
|
||||
|
||||
result = self.__call__(phs[0] if not_list else phs)
|
||||
# result = self.__call__(phs[0] if not_list else phs)
|
||||
|
||||
if not isinstance(result, list):
|
||||
result = [result]
|
||||
# if not isinstance(result, list):
|
||||
# result = [result]
|
||||
|
||||
result_shapes = []
|
||||
# result_shapes = []
|
||||
|
||||
for t in result:
|
||||
result_shapes += [ t.shape.as_list() ]
|
||||
# for t in result:
|
||||
# result_shapes += [ t.shape.as_list() ]
|
||||
|
||||
return result_shapes[0] if not_list else result_shapes
|
||||
|
||||
def compute_output_channels(self, shapes):
|
||||
shape = self.compute_output_shape(shapes)
|
||||
shape_len = len(shape)
|
||||
|
||||
if shape_len == 4:
|
||||
if nn.data_format == "NCHW":
|
||||
return shape[1]
|
||||
return shape[-1]
|
||||
# return result_shapes[0] if not_list else result_shapes
|
||||
|
||||
def build_for_run(self, shapes_list):
|
||||
if not isinstance(shapes_list, list):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import numpy as np
|
||||
from core.leras import nn
|
||||
tf = nn.tf
|
||||
|
||||
|
||||
patch_discriminator_kernels = \
|
||||
{ 1 : (512, [ [1,1] ]),
|
||||
2 : (512, [ [2,1] ]),
|
||||
|
@ -12,7 +12,7 @@ patch_discriminator_kernels = \
|
|||
7 : (512, [ [3,2], [3,2] ]),
|
||||
8 : (512, [ [4,2], [3,2] ]),
|
||||
9 : (512, [ [3,2], [4,2] ]),
|
||||
10 : (512, [ [4,2], [4,2] ]),
|
||||
10 : (512, [ [4,2], [4,2] ]),
|
||||
11 : (512, [ [3,2], [3,2], [2,1] ]),
|
||||
12 : (512, [ [4,2], [3,2], [2,1] ]),
|
||||
13 : (512, [ [3,2], [4,2], [2,1] ]),
|
||||
|
@ -20,42 +20,50 @@ patch_discriminator_kernels = \
|
|||
15 : (512, [ [3,2], [3,2], [3,1] ]),
|
||||
16 : (512, [ [4,2], [3,2], [3,1] ]),
|
||||
17 : (512, [ [3,2], [4,2], [3,1] ]),
|
||||
18 : (512, [ [4,2], [4,2], [3,1] ]),
|
||||
18 : (512, [ [4,2], [4,2], [3,1] ]),
|
||||
19 : (512, [ [3,2], [3,2], [4,1] ]),
|
||||
20 : (512, [ [4,2], [3,2], [4,1] ]),
|
||||
21 : (512, [ [3,2], [4,2], [4,1] ]),
|
||||
22 : (512, [ [4,2], [4,2], [4,1] ]),
|
||||
22 : (512, [ [4,2], [4,2], [4,1] ]),
|
||||
23 : (256, [ [3,2], [3,2], [3,2], [2,1] ]),
|
||||
24 : (256, [ [4,2], [3,2], [3,2], [2,1] ]),
|
||||
25 : (256, [ [3,2], [4,2], [3,2], [2,1] ]),
|
||||
26 : (256, [ [4,2], [4,2], [3,2], [2,1] ]),
|
||||
27 : (256, [ [3,2], [4,2], [4,2], [2,1] ]),
|
||||
26 : (256, [ [4,2], [4,2], [3,2], [2,1] ]),
|
||||
27 : (256, [ [3,2], [4,2], [4,2], [2,1] ]),
|
||||
28 : (256, [ [4,2], [3,2], [4,2], [2,1] ]),
|
||||
29 : (256, [ [3,2], [4,2], [4,2], [2,1] ]),
|
||||
30 : (256, [ [4,2], [4,2], [4,2], [2,1] ]),
|
||||
30 : (256, [ [4,2], [4,2], [4,2], [2,1] ]),
|
||||
31 : (256, [ [3,2], [3,2], [3,2], [3,1] ]),
|
||||
32 : (256, [ [4,2], [3,2], [3,2], [3,1] ]),
|
||||
33 : (256, [ [3,2], [4,2], [3,2], [3,1] ]),
|
||||
34 : (256, [ [4,2], [4,2], [3,2], [3,1] ]),
|
||||
35 : (256, [ [3,2], [4,2], [4,2], [3,1] ]),
|
||||
34 : (256, [ [4,2], [4,2], [3,2], [3,1] ]),
|
||||
35 : (256, [ [3,2], [4,2], [4,2], [3,1] ]),
|
||||
36 : (256, [ [4,2], [3,2], [4,2], [3,1] ]),
|
||||
37 : (256, [ [3,2], [4,2], [4,2], [3,1] ]),
|
||||
38 : (256, [ [4,2], [4,2], [4,2], [3,1] ]),
|
||||
39 : (256, [ [3,2], [3,2], [3,2], [4,1] ]),
|
||||
40 : (256, [ [4,2], [3,2], [3,2], [4,1] ]),
|
||||
41 : (256, [ [3,2], [4,2], [3,2], [4,1] ]),
|
||||
42 : (256, [ [4,2], [4,2], [3,2], [4,1] ]),
|
||||
43 : (256, [ [3,2], [4,2], [4,2], [4,1] ]),
|
||||
44 : (256, [ [4,2], [3,2], [4,2], [4,1] ]),
|
||||
45 : (256, [ [3,2], [4,2], [4,2], [4,1] ]),
|
||||
46 : (256, [ [4,2], [4,2], [4,2], [4,1] ]),
|
||||
}
|
||||
|
||||
|
||||
class PatchDiscriminator(nn.ModelBase):
|
||||
def on_build(self, patch_size, in_ch, base_ch=None, conv_kernel_initializer=None):
|
||||
def on_build(self, patch_size, in_ch, base_ch=None, conv_kernel_initializer=None):
|
||||
suggested_base_ch, kernels_strides = patch_discriminator_kernels[patch_size]
|
||||
|
||||
|
||||
if base_ch is None:
|
||||
base_ch = suggested_base_ch
|
||||
|
||||
|
||||
prev_ch = in_ch
|
||||
self.convs = []
|
||||
for i, (kernel_size, strides) in enumerate(kernels_strides):
|
||||
for i, (kernel_size, strides) in enumerate(kernels_strides):
|
||||
cur_ch = base_ch * min( (2**i), 8 )
|
||||
|
||||
|
||||
self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=conv_kernel_initializer) )
|
||||
prev_ch = cur_ch
|
||||
|
||||
|
@ -66,4 +74,121 @@ class PatchDiscriminator(nn.ModelBase):
|
|||
x = tf.nn.leaky_relu( conv(x), 0.1 )
|
||||
return self.out_conv(x)
|
||||
|
||||
nn.PatchDiscriminator = PatchDiscriminator
|
||||
nn.PatchDiscriminator = PatchDiscriminator
|
||||
|
||||
class UNetPatchDiscriminator(nn.ModelBase):
    """
    Inspired by https://arxiv.org/abs/2002.12655 "A U-Net Based Discriminator for Generative Adversarial Networks"
    """
    def calc_receptive_field_size(self, layers):
        """
        result the same as https://fomoro.com/research/article/receptive-field-calculatorindex.html

        layers is a sequence of (kernel_size, stride) pairs.
        """
        field = 0
        jump = 1  # product of the strides of all layers processed so far
        for index, (kernel, stride) in enumerate(layers):
            field = kernel if index == 0 else field + (kernel - 1) * jump
            jump *= stride
        return field

    def find_archi(self, target_patch_size, max_layers=9):
        """
        Find the best configuration of layers using only 3x3 convs for target patch size

        Every candidate starts with a stride-2 layer; the remaining layers take
        stride 1 or 2. For each receptive field the candidate with the fewest layers
        is kept, ties going to the larger stride sum. Returns the [kernel, stride]
        list whose receptive field is closest to target_patch_size.
        """
        best_by_rf = {}
        for layers_count in range(1, max_layers + 1):
            # each bit of 'mask' selects stride 2 (set) or 1 (clear) for one of the
            # layers after the first; masks are visited from the highest down to 0
            for mask in range((1 << (layers_count - 1)) - 1, -1, -1):
                strides = [2] + [2 if mask & (1 << bit) != 0 else 1 for bit in range(layers_count - 1)]
                layers = [[3, stride] for stride in strides]
                sum_st = sum(strides)

                rf = self.calc_receptive_field_size(layers)

                known = best_by_rf.get(rf, None)
                if known is None or \
                   layers_count < known[0] or \
                   ( layers_count == known[0] and sum_st > known[1] ):
                    best_by_rf[rf] = (layers_count, sum_st, layers)

        rf_sizes = sorted(best_by_rf.keys())
        nearest = rf_sizes[np.abs(np.array(rf_sizes) - target_patch_size).argmin()]
        return best_by_rf[nearest][2]

    def on_build(self, patch_size, in_ch, base_ch = 16, use_fp16 = False):
        self.use_fp16 = use_fp16
        conv_dtype = tf.float16 if use_fp16 else tf.float32

        # Defined here but not instantiated anywhere in this method.
        class ResidualBlock(nn.ModelBase):
            def on_build(self, ch, kernel_size=3 ):
                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)

            def forward(self, inp):
                x = self.conv1(inp)
                x = tf.nn.leaky_relu(x, 0.2)
                x = self.conv2(x)
                x = tf.nn.leaky_relu(inp + x, 0.2)
                return x

        self.convs = []
        self.upconvs = []
        layers = self.find_archi(patch_size)
        depth = len(layers)

        # channel count per level; key -1 is the level right after in_conv
        level_chs = { level - 1 : min( base_ch * (2**level), 512 ) for level in range(depth + 1) }

        self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID', dtype=conv_dtype)

        for i, (kernel_size, strides) in enumerate(layers):
            self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )

            # every up-convolution except the deepest one receives a concatenated
            # [skip, x] tensor, hence twice the channels
            up_in_ch = level_chs[i] * (2 if i != depth - 1 else 1)
            self.upconvs.insert (0, nn.Conv2DTranspose( up_in_ch, level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )

        self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype)

        self.center_out = nn.Conv2D( level_chs[depth-1], 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
        self.center_conv = nn.Conv2D( level_chs[depth-1], level_chs[depth-1], kernel_size=1, padding='VALID', dtype=conv_dtype)

    def forward(self, x):
        """Return (center_out, x): the 1-channel output taken at the bottleneck and the 1-channel output of the decoder."""
        if self.use_fp16:
            x = tf.cast(x, tf.float16)

        x = tf.nn.leaky_relu( self.in_conv(x), 0.2 )

        # inputs of each encoder conv, shallowest first
        skips = []
        for conv in self.convs:
            skips.append(x)
            x = tf.nn.leaky_relu( conv(x), 0.2 )

        center_out = self.center_out(x)
        x = tf.nn.leaky_relu( self.center_conv(x), 0.2 )

        # upconvs are stored deepest first, so walk the skips from the deepest one
        for upconv, skip in zip(self.upconvs, reversed(skips)):
            x = tf.nn.leaky_relu( upconv(x), 0.2 )
            x = tf.concat( [skip, x], axis=nn.conv2d_ch_axis)

        x = self.out_conv(x)

        if self.use_fp16:
            center_out = tf.cast(center_out, tf.float32)
            x = tf.cast(x, tf.float32)

        return center_out, x
|
||||
|
||||
nn.UNetPatchDiscriminator = UNetPatchDiscriminator
|
||||
|
|
|
@ -1,92 +0,0 @@
|
|||
"""
|
||||
using https://github.com/ternaus/TernausNet
|
||||
TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation
|
||||
"""
|
||||
|
||||
from core.leras import nn
|
||||
tf = nn.tf
|
||||
|
||||
class Ternaus(nn.ModelBase):
    """
    U-Net style model: five encoder stages each followed by a BlurPool, a centre
    convolution, and five decoder stages that concatenate the matching encoder
    output along the channel axis.
    """

    def on_build(self, in_ch, base_ch):
        ch = base_ch

        def conv3x3(ch_in, ch_out):
            return nn.Conv2D (ch_in, ch_out, kernel_size=3, padding='SAME')

        def upconv3x3(ch_in, ch_out):
            return nn.Conv2DTranspose (ch_in, ch_out, kernel_size=3, padding='SAME')

        # encoder
        self.features_0 = conv3x3(in_ch, ch)
        self.features_3 = conv3x3(ch, ch*2)
        self.features_6 = conv3x3(ch*2, ch*4)
        self.features_8 = conv3x3(ch*4, ch*4)
        self.features_11 = conv3x3(ch*4, ch*8)
        self.features_13 = conv3x3(ch*8, ch*8)
        self.features_16 = conv3x3(ch*8, ch*8)
        self.features_18 = conv3x3(ch*8, ch*8)

        self.blurpool_0 = nn.BlurPool (filt_size=3)
        self.blurpool_3 = nn.BlurPool (filt_size=3)
        self.blurpool_8 = nn.BlurPool (filt_size=3)
        self.blurpool_13 = nn.BlurPool (filt_size=3)
        self.blurpool_18 = nn.BlurPool (filt_size=3)

        self.conv_center = conv3x3(ch*8, ch*8)

        # decoder: each convN takes the up-convolved tensor concatenated with a skip
        self.conv1_up = upconv3x3(ch*8, ch*4)
        self.conv1 = conv3x3(ch*12, ch*8)

        self.conv2_up = upconv3x3(ch*8, ch*4)
        self.conv2 = conv3x3(ch*12, ch*8)

        self.conv3_up = upconv3x3(ch*8, ch*2)
        self.conv3 = conv3x3(ch*6, ch*4)

        self.conv4_up = upconv3x3(ch*4, ch)
        self.conv4 = conv3x3(ch*3, ch*2)

        self.conv5_up = upconv3x3(ch*2, ch//2)
        self.conv5 = conv3x3(ch//2+ch, ch)

        self.out_conv = conv3x3(ch, 1)

    def forward(self, inp):
        """inp is a one-element sequence; returns (logits, sigmoid(logits))."""
        x, = inp
        relu = tf.nn.relu
        ch_axis = nn.conv2d_ch_axis

        # encoder; skipN is the activation taken right before the N-th BlurPool
        skip0 = relu(self.features_0(x))
        x = self.blurpool_0(skip0)

        skip1 = relu(self.features_3(x))
        x = self.blurpool_3(skip1)

        x = relu(self.features_6(x))
        skip2 = relu(self.features_8(x))
        x = self.blurpool_8(skip2)

        x = relu(self.features_11(x))
        skip3 = relu(self.features_13(x))
        x = self.blurpool_13(skip3)

        x = relu(self.features_16(x))
        skip4 = relu(self.features_18(x))
        x = self.blurpool_18(skip4)

        x = self.conv_center(x)

        # decoder, deepest skip first
        x = relu(self.conv1_up(x))
        x = relu(self.conv1(tf.concat( [x,skip4], ch_axis)))

        x = relu(self.conv2_up(x))
        x = relu(self.conv2(tf.concat( [x,skip3], ch_axis)))

        x = relu(self.conv3_up(x))
        x = relu(self.conv3(tf.concat( [x,skip2], ch_axis)))

        x = relu(self.conv4_up(x))
        x = relu(self.conv4(tf.concat( [x,skip1], ch_axis)))

        x = relu(self.conv5_up(x))
        x = relu(self.conv5(tf.concat( [x,skip0], ch_axis)))

        logits = self.out_conv(x)
        return logits, tf.nn.sigmoid(logits)
|
||||
|
||||
nn.Ternaus = Ternaus
|
|
@ -28,11 +28,12 @@ class XSeg(nn.ModelBase):
|
|||
x = self.frn(x)
|
||||
x = self.tlu(x)
|
||||
return x
|
||||
|
||||
self.base_ch = base_ch
|
||||
|
||||
self.conv01 = ConvBlock(in_ch, base_ch)
|
||||
self.conv02 = ConvBlock(base_ch, base_ch)
|
||||
self.bp0 = nn.BlurPool (filt_size=3)
|
||||
|
||||
self.bp0 = nn.BlurPool (filt_size=4)
|
||||
|
||||
self.conv11 = ConvBlock(base_ch, base_ch*2)
|
||||
self.conv12 = ConvBlock(base_ch*2, base_ch*2)
|
||||
|
@ -40,19 +41,30 @@ class XSeg(nn.ModelBase):
|
|||
|
||||
self.conv21 = ConvBlock(base_ch*2, base_ch*4)
|
||||
self.conv22 = ConvBlock(base_ch*4, base_ch*4)
|
||||
self.conv23 = ConvBlock(base_ch*4, base_ch*4)
|
||||
self.bp2 = nn.BlurPool (filt_size=3)
|
||||
|
||||
self.bp2 = nn.BlurPool (filt_size=2)
|
||||
|
||||
self.conv31 = ConvBlock(base_ch*4, base_ch*8)
|
||||
self.conv32 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv33 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.bp3 = nn.BlurPool (filt_size=3)
|
||||
self.bp3 = nn.BlurPool (filt_size=2)
|
||||
|
||||
self.conv41 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv42 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv43 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.bp4 = nn.BlurPool (filt_size=3)
|
||||
self.bp4 = nn.BlurPool (filt_size=2)
|
||||
|
||||
self.conv51 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv52 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv53 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.bp5 = nn.BlurPool (filt_size=2)
|
||||
|
||||
self.dense1 = nn.Dense ( 4*4* base_ch*8, 512)
|
||||
self.dense2 = nn.Dense ( 512, 4*4* base_ch*8)
|
||||
|
||||
self.up5 = UpConvBlock (base_ch*8, base_ch*4)
|
||||
self.uconv53 = ConvBlock(base_ch*12, base_ch*8)
|
||||
self.uconv52 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.uconv51 = ConvBlock(base_ch*8, base_ch*8)
|
||||
|
||||
self.up4 = UpConvBlock (base_ch*8, base_ch*4)
|
||||
self.uconv43 = ConvBlock(base_ch*12, base_ch*8)
|
||||
|
@ -65,8 +77,7 @@ class XSeg(nn.ModelBase):
|
|||
self.uconv31 = ConvBlock(base_ch*8, base_ch*8)
|
||||
|
||||
self.up2 = UpConvBlock (base_ch*8, base_ch*4)
|
||||
self.uconv23 = ConvBlock(base_ch*8, base_ch*4)
|
||||
self.uconv22 = ConvBlock(base_ch*4, base_ch*4)
|
||||
self.uconv22 = ConvBlock(base_ch*8, base_ch*4)
|
||||
self.uconv21 = ConvBlock(base_ch*4, base_ch*4)
|
||||
|
||||
self.up1 = UpConvBlock (base_ch*4, base_ch*2)
|
||||
|
@ -77,10 +88,9 @@ class XSeg(nn.ModelBase):
|
|||
self.uconv02 = ConvBlock(base_ch*2, base_ch)
|
||||
self.uconv01 = ConvBlock(base_ch, base_ch)
|
||||
self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME')
|
||||
|
||||
|
||||
self.conv_center = ConvBlock(base_ch*8, base_ch*8)
|
||||
|
||||
def forward(self, inp):
|
||||
def forward(self, inp, pretrain=False):
|
||||
x = inp
|
||||
|
||||
x = self.conv01(x)
|
||||
|
@ -92,8 +102,7 @@ class XSeg(nn.ModelBase):
|
|||
x = self.bp1(x)
|
||||
|
||||
x = self.conv21(x)
|
||||
x = self.conv22(x)
|
||||
x = x2 = self.conv23(x)
|
||||
x = x2 = self.conv22(x)
|
||||
x = self.bp2(x)
|
||||
|
||||
x = self.conv31(x)
|
||||
|
@ -106,28 +115,52 @@ class XSeg(nn.ModelBase):
|
|||
x = x4 = self.conv43(x)
|
||||
x = self.bp4(x)
|
||||
|
||||
x = self.conv_center(x)
|
||||
|
||||
x = self.conv51(x)
|
||||
x = self.conv52(x)
|
||||
x = x5 = self.conv53(x)
|
||||
x = self.bp5(x)
|
||||
|
||||
x = nn.flatten(x)
|
||||
x = self.dense1(x)
|
||||
x = self.dense2(x)
|
||||
x = nn.reshape_4D (x, 4, 4, self.base_ch*8 )
|
||||
|
||||
x = self.up5(x)
|
||||
if pretrain:
|
||||
x5 = tf.zeros_like(x5)
|
||||
x = self.uconv53(tf.concat([x,x5],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv52(x)
|
||||
x = self.uconv51(x)
|
||||
|
||||
x = self.up4(x)
|
||||
if pretrain:
|
||||
x4 = tf.zeros_like(x4)
|
||||
x = self.uconv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv42(x)
|
||||
x = self.uconv41(x)
|
||||
|
||||
x = self.up3(x)
|
||||
if pretrain:
|
||||
x3 = tf.zeros_like(x3)
|
||||
x = self.uconv33(tf.concat([x,x3],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv32(x)
|
||||
x = self.uconv31(x)
|
||||
|
||||
x = self.up2(x)
|
||||
x = self.uconv23(tf.concat([x,x2],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv22(x)
|
||||
if pretrain:
|
||||
x2 = tf.zeros_like(x2)
|
||||
x = self.uconv22(tf.concat([x,x2],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv21(x)
|
||||
|
||||
x = self.up1(x)
|
||||
if pretrain:
|
||||
x1 = tf.zeros_like(x1)
|
||||
x = self.uconv12(tf.concat([x,x1],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv11(x)
|
||||
|
||||
x = self.up0(x)
|
||||
if pretrain:
|
||||
x0 = tf.zeros_like(x0)
|
||||
x = self.uconv02(tf.concat([x,x0],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv01(x)
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from .ModelBase import *
|
||||
from .PatchDiscriminator import *
|
||||
from .CodeDiscriminator import *
|
||||
from .Ternaus import *
|
||||
from .XSeg import *
|
|
@ -33,8 +33,8 @@ class nn():
|
|||
tf = None
|
||||
tf_sess = None
|
||||
tf_sess_config = None
|
||||
tf_default_device = None
|
||||
|
||||
tf_default_device_name = None
|
||||
|
||||
data_format = None
|
||||
conv2d_ch_axis = None
|
||||
conv2d_spatial_axes = None
|
||||
|
@ -50,9 +50,6 @@ class nn():
|
|||
nn.setCurrentDeviceConfig(device_config)
|
||||
|
||||
# Manipulate environment variables before import tensorflow
|
||||
|
||||
if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
|
||||
os.environ.pop('CUDA_VISIBLE_DEVICES')
|
||||
|
||||
first_run = False
|
||||
if len(device_config.devices) != 0:
|
||||
|
@ -68,21 +65,32 @@ class nn():
|
|||
compute_cache_path = Path(os.environ['APPDATA']) / 'NVIDIA' / ('ComputeCache' + devices_str)
|
||||
if not compute_cache_path.exists():
|
||||
first_run = True
|
||||
compute_cache_path.mkdir(parents=True, exist_ok=True)
|
||||
os.environ['CUDA_CACHE_PATH'] = str(compute_cache_path)
|
||||
|
||||
os.environ['CUDA_CACHE_MAXSIZE'] = '536870912' #512Mb (32mb default)
|
||||
os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2'
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # tf log errors only
|
||||
|
||||
|
||||
if first_run:
|
||||
io.log_info("Caching GPU kernels...")
|
||||
|
||||
import tensorflow as tf
|
||||
nn.tf = tf
|
||||
|
||||
import tensorflow
|
||||
|
||||
tf_version = tensorflow.version.VERSION
|
||||
#if tf_version is None:
|
||||
# tf_version = tensorflow.version.GIT_VERSION
|
||||
if tf_version[0] == 'v':
|
||||
tf_version = tf_version[1:]
|
||||
if tf_version[0] == '2':
|
||||
tf = tensorflow.compat.v1
|
||||
else:
|
||||
tf = tensorflow
|
||||
|
||||
import logging
|
||||
# Disable tensorflow warnings
|
||||
logging.getLogger('tensorflow').setLevel(logging.ERROR)
|
||||
tf_logger = logging.getLogger('tensorflow')
|
||||
tf_logger.setLevel(logging.ERROR)
|
||||
|
||||
if tf_version[0] == '2':
|
||||
tf.disable_v2_behavior()
|
||||
nn.tf = tf
|
||||
|
||||
# Initialize framework
|
||||
import core.leras.ops
|
||||
|
@ -94,13 +102,14 @@ class nn():
|
|||
|
||||
# Configure tensorflow session-config
|
||||
if len(device_config.devices) == 0:
|
||||
nn.tf_default_device = "/CPU:0"
|
||||
config = tf.ConfigProto(device_count={'GPU': 0})
|
||||
nn.tf_default_device_name = '/CPU:0'
|
||||
else:
|
||||
nn.tf_default_device = "/GPU:0"
|
||||
nn.tf_default_device_name = f'/{device_config.devices[0].tf_dev_type}:0'
|
||||
|
||||
config = tf.ConfigProto()
|
||||
config.gpu_options.visible_device_list = ','.join([str(device.index) for device in device_config.devices])
|
||||
|
||||
|
||||
config.gpu_options.force_gpu_compatible = True
|
||||
config.gpu_options.allow_growth = True
|
||||
nn.tf_sess_config = config
|
||||
|
@ -188,14 +197,6 @@ class nn():
|
|||
nn.tf_sess.close()
|
||||
nn.tf_sess = None
|
||||
|
||||
@staticmethod
def get_current_device():
    """
    Return the display name of the device taken from the default graph's
    device-function stack, or nn.tf_default_device when that stack is empty.
    """
    # Relies on a private, undocumented TensorFlow attribute to reach
    # the last tf.device(...) that was entered.
    device_stack = nn.tf.get_default_graph()._device_function_stack.peek_objs()
    if len(device_stack) == 0:
        return nn.tf_default_device
    return device_stack[0].display_name
|
||||
|
||||
@staticmethod
|
||||
def ask_choose_device_idxs(choose_only_one=False, allow_cpu=True, suggest_best_multi_gpu=False, suggest_all_gpu=False):
|
||||
devices = Devices.getDevices()
|
||||
|
|
|
@ -56,7 +56,7 @@ def tf_gradients ( loss, vars ):
|
|||
gv = [*zip(grads,vars)]
|
||||
for g,v in gv:
|
||||
if g is None:
|
||||
raise Exception(f"No gradient for variable {v.name}")
|
||||
raise Exception(f"Variable {v.name} is declared as trainable, but no tensors flow through it.")
|
||||
return gv
|
||||
nn.gradients = tf_gradients
|
||||
|
||||
|
@ -108,10 +108,15 @@ nn.gelu = gelu
|
|||
|
||||
def upsample2d(x, size=2):
|
||||
if nn.data_format == "NCHW":
|
||||
b,c,h,w = x.shape.as_list()
|
||||
x = tf.reshape (x, (-1,c,h,1,w,1) )
|
||||
x = tf.tile(x, (1,1,1,size,1,size) )
|
||||
x = tf.reshape (x, (-1,c,h*size,w*size) )
|
||||
x = tf.transpose(x, (0,2,3,1))
|
||||
x = tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )
|
||||
x = tf.transpose(x, (0,3,1,2))
|
||||
|
||||
|
||||
# b,c,h,w = x.shape.as_list()
|
||||
# x = tf.reshape (x, (-1,c,h,1,w,1) )
|
||||
# x = tf.tile(x, (1,1,1,size,1,size) )
|
||||
# x = tf.reshape (x, (-1,c,h*size,w*size) )
|
||||
return x
|
||||
else:
|
||||
return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )
|
||||
|
@ -120,25 +125,56 @@ nn.upsample2d = upsample2d
|
|||
def resize2d_bilinear(x, size=2):
|
||||
h = x.shape[nn.conv2d_spatial_axes[0]].value
|
||||
w = x.shape[nn.conv2d_spatial_axes[1]].value
|
||||
|
||||
|
||||
if nn.data_format == "NCHW":
|
||||
x = tf.transpose(x, (0,2,3,1))
|
||||
|
||||
|
||||
if size > 0:
|
||||
new_size = (h*size,w*size)
|
||||
else:
|
||||
new_size = (h//-size,w//-size)
|
||||
|
||||
x = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.BILINEAR)
|
||||
|
||||
|
||||
if nn.data_format == "NCHW":
|
||||
x = tf.transpose(x, (0,3,1,2))
|
||||
|
||||
return x
|
||||
x = tf.transpose(x, (0,3,1,2))
|
||||
|
||||
return x
|
||||
nn.resize2d_bilinear = resize2d_bilinear
|
||||
|
||||
def resize2d_nearest(x, size=2):
|
||||
if size in [-1,0,1]:
|
||||
return x
|
||||
|
||||
|
||||
if size > 0:
|
||||
raise Exception("")
|
||||
else:
|
||||
if nn.data_format == "NCHW":
|
||||
x = x[:,:,::-size,::-size]
|
||||
else:
|
||||
x = x[:,::-size,::-size,:]
|
||||
return x
|
||||
|
||||
h = x.shape[nn.conv2d_spatial_axes[0]].value
|
||||
w = x.shape[nn.conv2d_spatial_axes[1]].value
|
||||
|
||||
if nn.data_format == "NCHW":
|
||||
x = tf.transpose(x, (0,2,3,1))
|
||||
|
||||
if size > 0:
|
||||
new_size = (h*size,w*size)
|
||||
else:
|
||||
new_size = (h//-size,w//-size)
|
||||
|
||||
x = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
|
||||
|
||||
if nn.data_format == "NCHW":
|
||||
x = tf.transpose(x, (0,3,1,2))
|
||||
|
||||
return x
|
||||
nn.resize2d_nearest = resize2d_nearest
|
||||
|
||||
def flatten(x):
|
||||
if nn.data_format == "NHWC":
|
||||
# match NCHW version in order to switch data_format without problems
|
||||
|
@ -173,7 +209,7 @@ def random_binomial(shape, p=0.0, dtype=None, seed=None):
|
|||
seed = np.random.randint(10e6)
|
||||
return array_ops.where(
|
||||
random_ops.random_uniform(shape, dtype=tf.float16, seed=seed) < p,
|
||||
array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype))
|
||||
array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype))
|
||||
nn.random_binomial = random_binomial
|
||||
|
||||
def gaussian_blur(input, radius=2.0):
|
||||
|
@ -181,7 +217,9 @@ def gaussian_blur(input, radius=2.0):
|
|||
return np.exp(-(float(x) - float(mu)) ** 2 / (2 * sigma ** 2))
|
||||
|
||||
def make_kernel(sigma):
|
||||
kernel_size = max(3, int(2 * 2 * sigma + 1))
|
||||
kernel_size = max(3, int(2 * 2 * sigma))
|
||||
if kernel_size % 2 == 0:
|
||||
kernel_size += 1
|
||||
mean = np.floor(0.5 * kernel_size)
|
||||
kernel_1d = np.array([gaussian(x, mean, sigma) for x in range(kernel_size)])
|
||||
np_kernel = np.outer(kernel_1d, kernel_1d).astype(np.float32)
|
||||
|
@ -238,6 +276,8 @@ def dssim(img1,img2, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03
|
|||
img1 = tf.cast(img1, tf.float32)
|
||||
img2 = tf.cast(img2, tf.float32)
|
||||
|
||||
filter_size = max(1, filter_size)
|
||||
|
||||
kernel = np.arange(0, filter_size, dtype=np.float32)
|
||||
kernel -= (filter_size - 1 ) / 2.0
|
||||
kernel = kernel**2
|
||||
|
@ -300,7 +340,17 @@ def depth_to_space(x, size):
|
|||
x = tf.reshape(x, (-1, oh, ow, oc, ))
|
||||
return x
|
||||
else:
|
||||
return tf.depth_to_space(x, size, data_format=nn.data_format)
|
||||
cfg = nn.getCurrentDeviceConfig()
|
||||
if not cfg.cpu_only:
|
||||
return tf.depth_to_space(x, size, data_format=nn.data_format)
|
||||
b,c,h,w = x.shape.as_list()
|
||||
oh, ow = h * size, w * size
|
||||
oc = c // (size * size)
|
||||
|
||||
x = tf.reshape(x, (-1, size, size, oc, h, w, ) )
|
||||
x = tf.transpose(x, (0, 3, 4, 1, 5, 2))
|
||||
x = tf.reshape(x, (-1, oc, oh, ow))
|
||||
return x
|
||||
nn.depth_to_space = depth_to_space
|
||||
|
||||
def rgb_to_lab(srgb):
|
||||
|
@ -333,6 +383,23 @@ def rgb_to_lab(srgb):
|
|||
return tf.reshape(lab_pixels, tf.shape(srgb))
|
||||
nn.rgb_to_lab = rgb_to_lab
|
||||
|
||||
def total_variation_mse(images):
    """
    Same as generic total_variation, but MSE diff instead of MAE.

    images      4D tensor

    Differences between neighbouring elements are taken along axis 1 and
    along axis 2, squared, and summed over axes 1,2,3, so the result holds
    one value per element of axis 0.

    NOTE(review): slicing axes 1 and 2 matches the spatial axes only for an
    NHWC layout - confirm what callers pass when nn.data_format is NCHW.
    """
    diff_axis1 = images[:, 1:, :, :] - images[:, :-1, :, :]
    diff_axis2 = images[:, :, 1:, :] - images[:, :, :-1, :]

    sq_sum_axis1 = tf.reduce_sum(tf.square(diff_axis1), axis=[1,2,3])
    sq_sum_axis2 = tf.reduce_sum(tf.square(diff_axis2), axis=[1,2,3])
    return ( sq_sum_axis1 + sq_sum_axis2 )
|
||||
nn.total_variation_mse = total_variation_mse
|
||||
|
||||
|
||||
def pixel_norm(x, axes):
    """
    Scale x by the reciprocal square root of its mean square.

    The mean of x**2 is reduced over `axes` (dimensions are kept so the
    result broadcasts against x); 1e-06 is added before the rsqrt to keep
    it finite when the mean square is zero.
    """
    mean_square = tf.reduce_mean(tf.square(x), axis=axes, keepdims=True)
    return x * tf.rsqrt(mean_square + 1e-06)
|
||||
nn.pixel_norm = pixel_norm
|
||||
|
||||
"""
|
||||
def tf_suppress_lower_mean(t, eps=0.00001):
|
||||
if t.shape.ndims != 1:
|
||||
|
@ -342,4 +409,70 @@ def tf_suppress_lower_mean(t, eps=0.00001):
|
|||
q = tf.clip_by_value(q-t_mean_eps, 0, eps)
|
||||
q = q * (t/eps)
|
||||
return q
|
||||
"""
|
||||
"""
|
||||
|
||||
|
||||
|
||||
def _get_pixel_value(img, x, y):
    """
    Gather values of img at the integer coordinates (x, y).

    img     tensor indexed as [batch, y, x, ...]
    x, y    integer index tensors; their first three dimensions are read
            from x as (batch, height, width)

    returns tf.gather_nd(img, indices) where indices[..., :] = (batch, y, x)
    """
    coords_shape = tf.shape(x)
    n = coords_shape[0]
    h = coords_shape[1]
    w = coords_shape[2]

    # per-sample batch index, broadcast to the same (n, h, w) grid as x and y
    batch_grid = tf.tile(tf.reshape(tf.range(0, n), (n, 1, 1)), (1, h, w))

    return tf.gather_nd(img, tf.stack([batch_grid, y, x], 3))
|
||||
|
||||
def bilinear_sampler(img, x, y):
    """
    Sample img at fractional coordinates (x, y) using the four surrounding
    integer grid points, weighted by the distance to the opposite corner.

    img     tensor indexed as [batch, y, x, ...]; its height and width are
            read from dimensions 1 and 2
    x, y    float coordinate tensors of identical 3D shape

    Corner indices are clipped to [0, W-1] / [0, H-1] before the lookup and
    the weights are computed from the clipped corners.
    """
    max_row = tf.cast(tf.shape(img)[1] - 1, tf.int32)
    max_col = tf.cast(tf.shape(img)[2] - 1, tf.int32)

    # the 4 nearest grid points of every (x_i, y_i)
    col_lo = tf.cast(tf.floor(x), tf.int32)
    col_hi = col_lo + 1
    row_lo = tf.cast(tf.floor(y), tf.int32)
    row_hi = row_lo + 1

    # keep the indices inside the image
    col_lo = tf.clip_by_value(col_lo, 0, max_col)
    col_hi = tf.clip_by_value(col_hi, 0, max_col)
    row_lo = tf.clip_by_value(row_lo, 0, max_row)
    row_hi = tf.clip_by_value(row_hi, 0, max_row)

    # values at the four corners
    val_lo_lo = _get_pixel_value(img, col_lo, row_lo)
    val_lo_hi = _get_pixel_value(img, col_lo, row_hi)
    val_hi_lo = _get_pixel_value(img, col_hi, row_lo)
    val_hi_hi = _get_pixel_value(img, col_hi, row_hi)

    # float copies of the clipped corners for the weight calculation
    col_lo_f = tf.cast(col_lo, tf.float32)
    col_hi_f = tf.cast(col_hi, tf.float32)
    row_lo_f = tf.cast(row_lo, tf.float32)
    row_hi_f = tf.cast(row_hi, tf.float32)

    # each corner is weighted by the area towards the opposite corner;
    # a trailing axis is added so the weights broadcast over img's last axis
    w_lo_lo = tf.expand_dims((col_hi_f-x) * (row_hi_f-y), axis=3)
    w_lo_hi = tf.expand_dims((col_hi_f-x) * (y-row_lo_f), axis=3)
    w_hi_lo = tf.expand_dims((x-col_lo_f) * (row_hi_f-y), axis=3)
    w_hi_hi = tf.expand_dims((x-col_lo_f) * (y-row_lo_f), axis=3)

    return tf.add_n([w_lo_lo*val_lo_lo, w_lo_hi*val_lo_hi, w_hi_lo*val_hi_lo, w_hi_hi*val_hi_hi])
|
||||
|
||||
nn.bilinear_sampler = bilinear_sampler
|
||||
|
||||
|
|
81
core/leras/optimizers/AdaBelief.py
Normal file
|
@ -0,0 +1,81 @@
|
|||
import numpy as np
|
||||
from core.leras import nn
|
||||
from tensorflow.python.ops import control_flow_ops, state_ops
|
||||
|
||||
tf = nn.tf
|
||||
|
||||
class AdaBelief(nn.OptimizerBase):
    """
    Optimizer that keeps two running averages per trainable variable:
    'ms' - average of the gradient, and 'vs' - average of the squared
    difference between the gradient and 'ms'. The step is
    -lr * ms / (sqrt(vs) + eps). No bias correction is applied.

    lr              base learning rate
    beta_1          decay of the gradient average
    beta_2          decay of the squared-deviation average
    lr_dropout      1.0 disables it; otherwise each update is multiplied by a
                    nn.random_binomial mask created with p=lr_dropout
    lr_cos          0 disables it; otherwise lr is modulated by
                    (cos(iterations * 2*pi/lr_cos) + 1) / 2
    clipnorm        0.0 disables it; otherwise gradients are passed through
                    self.tf_clip_norm using the global gradient norm
    name            required, used as variable scope
    """
    def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, lr_dropout=1.0, lr_cos=0, clipnorm=0.0, name=None, **kwargs):
        super().__init__(name=name)

        if name is None:
            raise ValueError('name must be defined.')

        self.lr = lr
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.lr_dropout = lr_dropout
        self.lr_cos = lr_cos
        self.clipnorm = clipnorm

        with tf.device('/CPU:0') :
            with tf.variable_scope(self.name):
                self.iterations = tf.Variable(0, dtype=tf.int64, name='iters')

        self.ms_dict = {}
        self.vs_dict = {}
        self.lr_rnds_dict = {}

    def get_weights(self):
        """Return the optimizer state: iteration counter, then all 'ms', then all 'vs' variables."""
        return [self.iterations] + list(self.ms_dict.values()) + list(self.vs_dict.values())

    def initialize_variables(self, trainable_weights, vars_on_cpu=True, lr_dropout_on_cpu=False):
        """
        Create the 'ms'/'vs' state variables (and the lr-dropout masks when
        lr_dropout != 1.0) for every variable in trainable_weights.

        vars_on_cpu         create the state variables under tf.device('/CPU:0')
        lr_dropout_on_cpu   create the lr-dropout masks under tf.device('/CPU:0')
        """
        import contextlib

        # Every device scope is owned by its own ExitStack, so each scope that
        # was entered is exited exactly once. Reusing a single variable for
        # both the outer and the inner scope left the outer '/CPU:0' scope
        # open whenever lr_dropout was enabled.
        with contextlib.ExitStack() as vars_scope:
            if vars_on_cpu:
                vars_scope.enter_context( tf.device('/CPU:0') )

            with tf.variable_scope(self.name):
                ms = { v.name : tf.get_variable ( f'ms_{v.name}'.replace(':','_'), v.shape, dtype=v.dtype, initializer=tf.initializers.constant(0.0), trainable=False) for v in trainable_weights }
                vs = { v.name : tf.get_variable ( f'vs_{v.name}'.replace(':','_'), v.shape, dtype=v.dtype, initializer=tf.initializers.constant(0.0), trainable=False) for v in trainable_weights }
                self.ms_dict.update (ms)
                self.vs_dict.update (vs)

                if self.lr_dropout != 1.0:
                    with contextlib.ExitStack() as rnd_scope:
                        if lr_dropout_on_cpu:
                            rnd_scope.enter_context( tf.device('/CPU:0') )
                        lr_rnds = [ nn.random_binomial( v.shape, p=self.lr_dropout, dtype=v.dtype) for v in trainable_weights ]
                    self.lr_rnds_dict.update ( { v.name : rnd for v,rnd in zip(trainable_weights,lr_rnds) } )

    def get_update_op(self, grads_vars):
        """
        Build and return a single grouped op that increments the iteration
        counter and updates 'ms', 'vs' and the variable itself for every
        (gradient, variable) pair in grads_vars.
        """
        updates = []

        if self.clipnorm > 0.0:
            # global norm is accumulated in float32 regardless of the gradient dtype
            norm = tf.sqrt( sum([tf.reduce_sum(tf.square(tf.cast(g, tf.float32))) for g,v in grads_vars]))
        updates += [ state_ops.assign_add( self.iterations, 1) ]
        for g,v in grads_vars:
            if self.clipnorm > 0.0:
                g = self.tf_clip_norm(g, self.clipnorm, tf.cast(norm, g.dtype) )

            ms = self.ms_dict[ v.name ]
            vs = self.vs_dict[ v.name ]

            m_t = self.beta_1*ms + (1.0-self.beta_1) * g
            v_t = self.beta_2*vs + (1.0-self.beta_2) * tf.square(g-m_t)

            lr = tf.constant(self.lr, g.dtype)
            if self.lr_cos != 0:
                lr *= (tf.cos( tf.cast(self.iterations, g.dtype) * (2*3.1415926535/ float(self.lr_cos) ) ) + 1.0) / 2.0

            # the denominator epsilon is the decimal resolution of the gradient dtype
            v_diff = - lr * m_t / (tf.sqrt(v_t) + np.finfo( g.dtype.as_numpy_dtype ).resolution )
            if self.lr_dropout != 1.0:
                lr_rnd = self.lr_rnds_dict[v.name]
                v_diff *= lr_rnd
            new_v = v + v_diff

            updates.append (state_ops.assign(ms, m_t))
            updates.append (state_ops.assign(vs, v_t))
            updates.append (state_ops.assign(v, new_v))

        return control_flow_ops.group ( *updates, name=self.name+'_updates')
|
||||
nn.AdaBelief = AdaBelief
|
|
@ -1,31 +1,33 @@
|
|||
import numpy as np
|
||||
from tensorflow.python.ops import control_flow_ops, state_ops
|
||||
from core.leras import nn
|
||||
tf = nn.tf
|
||||
|
||||
class RMSprop(nn.OptimizerBase):
|
||||
def __init__(self, lr=0.001, rho=0.9, lr_dropout=1.0, epsilon=1e-7, clipnorm=0.0, name=None):
|
||||
def __init__(self, lr=0.001, rho=0.9, lr_dropout=1.0, lr_cos=0, clipnorm=0.0, name=None, **kwargs):
|
||||
super().__init__(name=name)
|
||||
|
||||
if name is None:
|
||||
raise ValueError('name must be defined.')
|
||||
|
||||
self.lr_dropout = lr_dropout
|
||||
self.lr_cos = lr_cos
|
||||
self.lr = lr
|
||||
self.rho = rho
|
||||
self.clipnorm = clipnorm
|
||||
|
||||
with tf.device('/CPU:0') :
|
||||
with tf.variable_scope(self.name):
|
||||
self.lr = tf.Variable (lr, name="lr")
|
||||
self.rho = tf.Variable (rho, name="rho")
|
||||
self.epsilon = tf.Variable (epsilon, name="epsilon")
|
||||
|
||||
self.iterations = tf.Variable(0, dtype=tf.int64, name='iters')
|
||||
|
||||
self.accumulators_dict = {}
|
||||
self.lr_rnds_dict = {}
|
||||
|
||||
def get_weights(self):
|
||||
return [self.lr, self.rho, self.epsilon, self.iterations] + list(self.accumulators_dict.values())
|
||||
return [self.iterations] + list(self.accumulators_dict.values())
|
||||
|
||||
def initialize_variables(self, trainable_weights, vars_on_cpu=True):
|
||||
def initialize_variables(self, trainable_weights, vars_on_cpu=True, lr_dropout_on_cpu=False):
|
||||
# Initialize here all trainable variables used in training
|
||||
e = tf.device('/CPU:0') if vars_on_cpu else None
|
||||
if e: e.__enter__()
|
||||
|
@ -34,7 +36,10 @@ class RMSprop(nn.OptimizerBase):
|
|||
self.accumulators_dict.update ( accumulators)
|
||||
|
||||
if self.lr_dropout != 1.0:
|
||||
e = tf.device('/CPU:0') if lr_dropout_on_cpu else None
|
||||
if e: e.__enter__()
|
||||
lr_rnds = [ nn.random_binomial( v.shape, p=self.lr_dropout, dtype=v.dtype) for v in trainable_weights ]
|
||||
if e: e.__exit__(None, None, None)
|
||||
self.lr_rnds_dict.update ( { v.name : rnd for v,rnd in zip(trainable_weights,lr_rnds) } )
|
||||
if e: e.__exit__(None, None, None)
|
||||
|
||||
|
@ -42,21 +47,21 @@ class RMSprop(nn.OptimizerBase):
|
|||
updates = []
|
||||
|
||||
if self.clipnorm > 0.0:
|
||||
norm = tf.sqrt( sum([tf.reduce_sum(tf.square(g)) for g,v in grads_vars]))
|
||||
norm = tf.sqrt( sum([tf.reduce_sum(tf.square(tf.cast(g, tf.float32))) for g,v in grads_vars]))
|
||||
updates += [ state_ops.assign_add( self.iterations, 1) ]
|
||||
for i, (g,v) in enumerate(grads_vars):
|
||||
if self.clipnorm > 0.0:
|
||||
g = self.tf_clip_norm(g, self.clipnorm, norm)
|
||||
g = self.tf_clip_norm(g, self.clipnorm, tf.cast(norm, g.dtype) )
|
||||
|
||||
a = self.accumulators_dict[ v.name ]
|
||||
|
||||
rho = tf.cast(self.rho, a.dtype)
|
||||
new_a = rho * a + (1. - rho) * tf.square(g)
|
||||
new_a = self.rho * a + (1. - self.rho) * tf.square(g)
|
||||
|
||||
lr = tf.cast(self.lr, a.dtype)
|
||||
epsilon = tf.cast(self.epsilon, a.dtype)
|
||||
lr = tf.constant(self.lr, g.dtype)
|
||||
if self.lr_cos != 0:
|
||||
lr *= (tf.cos( tf.cast(self.iterations, g.dtype) * (2*3.1415926535/ float(self.lr_cos) ) ) + 1.0) / 2.0
|
||||
|
||||
v_diff = - lr * g / (tf.sqrt(new_a) + epsilon)
|
||||
v_diff = - lr * g / (tf.sqrt(new_a) + np.finfo( g.dtype.as_numpy_dtype ).resolution )
|
||||
if self.lr_dropout != 1.0:
|
||||
lr_rnd = self.lr_rnds_dict[v.name]
|
||||
v_diff *= lr_rnd
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
from .OptimizerBase import *
|
||||
from .RMSprop import *
|
||||
from .RMSprop import *
|
||||
from .AdaBelief import *
|
|
@ -1,7 +1,12 @@
|
|||
import numpy as np
|
||||
import math
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import numpy.linalg as npla
|
||||
|
||||
from .umeyama import umeyama
|
||||
|
||||
|
||||
def get_power_of_two(x):
|
||||
i = 0
|
||||
while (1 << i) < x:
|
||||
|
@ -23,3 +28,70 @@ def rotationMatrixToEulerAngles(R) :
|
|||
|
||||
def polygon_area(x,y):
    """
    Area of a polygon given the x and y coordinates of its vertices in
    order (shoelace formula). The sign of the vertex order is ignored.
    """
    cross_sum = np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))
    return 0.5 * np.abs(cross_sum)
|
||||
|
||||
def rotate_point(origin, point, deg):
    """
    Rotate a point counterclockwise by a given angle around a given origin.

    origin      (x, y) of the rotation centre
    point       (x, y) of the point to rotate
    deg         rotation angle in degrees

    returns np.float32 array [x, y] of the rotated point
    """
    ox, oy = origin
    px, py = point

    # math.cos / math.sin work in radians
    rad = math.radians(deg)
    cos_a = math.cos(rad)
    sin_a = math.sin(rad)

    dx = px - ox
    dy = py - oy
    qx = ox + cos_a * dx - sin_a * dy
    qy = oy + sin_a * dx + cos_a * dy
    return np.float32([qx, qy])
|
||||
|
||||
def transform_points(points, mat, invert=False):
    """
    Apply the affine matrix `mat` to `points` with cv2.transform.

    points      array of points; an axis of length 1 is inserted at
                position 1 before the call, as cv2.transform works on
                multi-channel arrays
    mat         affine matrix accepted by cv2.transform
    invert      apply the inverse of mat (cv2.invertAffineTransform) instead

    returns the transformed points with all length-1 axes squeezed out
    """
    affine = cv2.invertAffineTransform (mat) if invert else mat
    pts = np.expand_dims(points, axis=1)
    pts = cv2.transform(pts, affine, pts.shape)
    return np.squeeze(pts)
|
||||
|
||||
|
||||
def transform_mat(mat, res, tx, ty, rotation, scale):
    """
    transform mat in local space of res
    scale -> translate -> rotate

    tx,ty       float, shift expressed as a fraction of the horizontal /
                vertical side length of the source square
    rotation    int degrees
    scale       float

    returns the affine matrix mapping the adjusted square corners
    (left-top, right-top, left-bottom) onto [0,0], [res,0], [0,res]
    """
    # corners and centre of the res x res square mapped back through mat
    lt, rt, lb, ct = transform_points ( np.float32([(0,0),(res,0),(0,res),(res / 2, res/2) ]),mat, True)

    # side lengths of the source square
    hor_v = (rt-lt).astype(np.float32)
    hor_size = npla.norm(hor_v)
    hor_v /= hor_size

    ver_v = (lb-lt).astype(np.float32)
    ver_size = npla.norm(ver_v)
    ver_v /= ver_size

    # unit vector centre -> right-top and its perpendicular
    bt_diag_vec = (rt-ct).astype(np.float32)
    half_diag_len = npla.norm(bt_diag_vec)
    bt_diag_vec /= half_diag_len
    tb_diag_vec = np.float32( [ -bt_diag_vec[1], bt_diag_vec[0] ] )

    # scale: rebuild the three corners around the centre
    rt = ct + bt_diag_vec*half_diag_len*scale
    lb = ct - bt_diag_vec*half_diag_len*scale
    lt = ct - tb_diag_vec*half_diag_len*scale

    # translate: shift every corner by the same offset
    shift_x = tx*hor_size
    shift_y = ty*ver_size
    for corner in (rt, lb, lt):
        corner[0] += shift_x
        corner[1] += shift_y

    # rotate around the (unshifted) centre
    rt, lb, lt = [ rotate_point(ct, corner, rotation) for corner in (rt, lb, lt) ]

    return cv2.getAffineTransform( np.float32([lt, rt, lb]), np.float32([ [0,0], [res,0], [0,res] ]) )
|
||||
|
|
111
core/mplib/MPSharedList.py
Normal file
|
@ -0,0 +1,111 @@
|
|||
import multiprocessing
|
||||
import pickle
|
||||
import struct
|
||||
from core.joblib import Subprocessor
|
||||
|
||||
class MPSharedList():
    """
    Provides read-only pickled list of constant objects via shared memory aka 'multiprocessing.Array'
    Thus no 4GB limit for subprocesses.

    supports list concat via + or sum()

    Every source list is baked into one shared byte block laid out as
        [ table of (count+1) little-endian uint64 offsets ][ pickled objects ]
    where object i occupies data[ table[i] : table[i+1] ].
    Concatenation only collects references to the existing blocks.
    """

    def __init__(self, obj_list):
        """
        obj_list    list of picklable objects, or None to create an
                    instance without any block (used by concatenation).
        """
        if obj_list is None:
            self.obj_counts    = None
            self.table_offsets = None
            self.data_offsets  = None
            self.sh_bs         = None
        else:
            obj_count, table_offset, data_offset, sh_b = MPSharedList.bake_data(obj_list)

            self.obj_counts    = [obj_count]
            self.table_offsets = [table_offset]
            self.data_offsets  = [data_offset]
            self.sh_bs         = [sh_b]

    def __add__(self, o):
        if isinstance(o, MPSharedList):
            m = MPSharedList(None)
            # 'or []' lets an instance created with None take part in a concat
            m.obj_counts    = (self.obj_counts    or []) + (o.obj_counts    or [])
            m.table_offsets = (self.table_offsets or []) + (o.table_offsets or [])
            m.data_offsets  = (self.data_offsets  or []) + (o.data_offsets  or [])
            m.sh_bs         = (self.sh_bs         or []) + (o.sh_bs         or [])
            return m
        elif isinstance(o, int):
            # sum() starts from the int 0
            return self
        else:
            raise ValueError(f"MPSharedList object of class {o.__class__} is not supported for __add__ operator.")

    def __radd__(self, o):
        return self+o

    def __len__(self):
        return sum(self.obj_counts or ())

    def __getitem__(self, key):
        """
        Return the unpickled object at integer index `key`
        (negative indexes count from the end).

        raises ValueError when key is out of range.
        """
        obj_count = len(self)
        if key < 0:
            key = obj_count+key
        if key < 0 or key >= obj_count:
            raise ValueError("out of range")

        # locate the block holding the key; key becomes block-local
        for block_count, table_offset, data_offset, sh_b in zip(self.obj_counts, self.table_offsets, self.data_offsets, self.sh_bs):
            if key < block_count:
                break
            key -= block_count

        sh_b = memoryview(sh_b).cast('B')

        # two adjacent table entries give the start and the end of the object
        offset_start, offset_end = struct.unpack_from('<QQ', sh_b, table_offset + key*8)

        return pickle.loads( sh_b[ data_offset + offset_start : data_offset + offset_end ] )

    def __iter__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)

    @staticmethod
    def bake_data(obj_list):
        """
        Pickle every object of obj_list into one shared RawArray.

        returns (obj_count, table_offset, data_offset, sh_b),
                or (0, 0, 0, None) for an empty list.
        raises ValueError when obj_list is not a list.
        """
        if not isinstance(obj_list, list):
            raise ValueError("MPSharedList: obj_list should be list type.")

        obj_count = len(obj_list)
        if obj_count == 0:
            return 0, 0, 0, None

        obj_pickled_ar = [pickle.dumps(o, 4) for o in obj_list]

        # offsets[i] is the start of object i, the final entry is the total data size
        offsets = [0]
        for obj_pickled in obj_pickled_ar:
            offsets.append( offsets[-1] + len(obj_pickled) )

        table_offset = 0
        table_size   = (obj_count+1)*8
        data_offset  = table_offset + table_size
        data_size    = offsets[-1]

        sh_b = multiprocessing.RawArray('B', table_size + data_size)
        sh_b_view = memoryview(sh_b).cast('B')

        struct.pack_into( '<'+'Q'*len(offsets), sh_b_view, table_offset, *offsets )

        # objects are written one by one so the whole payload is never duplicated in memory
        for obj_offset, obj_pickled in zip(offsets, obj_pickled_ar):
            offset = data_offset+obj_offset
            sh_b_view[offset:offset+len(obj_pickled)] = obj_pickled

        return obj_count, table_offset, data_offset, sh_b
|
||||
|
|
@ -1,99 +1,10 @@
|
|||
from .MPSharedList import MPSharedList
|
||||
import multiprocessing
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
class Index2DHost():
    """
    Provides random shuffled 2D indexes for multiprocesses

    A daemon thread serves requests that clients put on a shared request
    queue; every client created by create_cli() has its own reply queue.
    """
    def __init__(self, indexes2D):
        self.sq = multiprocessing.Queue()
        self.cqs = []
        self.clis = []
        self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) )
        self.thread.daemon = True
        self.thread.start()

    def host_thread(self, indexes2D):
        """
        Serve requests forever.

        cmd 0 (get_1D): reply with `count` indexes of the outer list.
        cmd 1 (get_2D): for every requested outer index reply with `count`
                        elements of the matching inner list.
        Both draw without replacement from a shuffled pool that is refilled
        and reshuffled whenever it runs empty.
        """
        outer_len = len(indexes2D)

        outer_idxs = [*range(outer_len)]
        inner_lists = [ indexes2D[n] for n in range(outer_len) ]
        outer_pool = []
        inner_pools = [ [] for _ in range(outer_len) ]

        requests = self.sq

        while True:
            while not requests.empty():
                request = requests.get()
                cq_id, cmd = request[0], request[1]

                if cmd == 0: #get_1D
                    count = request[2]

                    reply = []
                    for _ in range(count):
                        if len(outer_pool) == 0:
                            outer_pool = outer_idxs.copy()
                            np.random.shuffle(outer_pool)
                        reply.append(outer_pool.pop())
                    self.cqs[cq_id].put (reply)
                elif cmd == 1: #get_2D
                    targ_idxs, count = request[2], request[3]

                    reply = []
                    for targ_idx in targ_idxs:
                        picked = []
                        for _ in range(count):
                            pool = inner_pools[targ_idx]
                            if len(pool) == 0:
                                pool = inner_pools[targ_idx] = inner_lists[targ_idx].copy()
                                np.random.shuffle(pool)
                            picked.append(pool.pop())
                        reply.append (picked)
                    self.cqs[cq_id].put (reply)

            time.sleep(0.005)

    def create_cli(self):
        """Register a new reply queue and return a client bound to it."""
        cq = multiprocessing.Queue()
        self.cqs.append ( cq )
        return Index2DHost.Cli(self.sq, cq, len(self.cqs)-1)

    # disable pickling
    def __getstate__(self):
        return dict()
    def __setstate__(self, d):
        self.__dict__.update(d)

    class Cli():
        """Client side: sends a request to the host and polls its own queue for the reply."""
        def __init__(self, sq, cq, cq_id):
            self.sq = sq
            self.cq = cq
            self.cq_id = cq_id

        def _wait_reply(self):
            # poll the reply queue every millisecond until the host answered
            while self.cq.empty():
                time.sleep(0.001)
            return self.cq.get()

        def get_1D(self, count):
            self.sq.put ( (self.cq_id,0, count) )
            return self._wait_reply()

        def get_2D(self, idxs, count):
            self.sq.put ( (self.cq_id,1,idxs,count) )
            return self._wait_reply()
|
||||
|
||||
class IndexHost():
|
||||
"""
|
||||
|
@ -107,9 +18,9 @@ class IndexHost():
|
|||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def host_thread(self, indexes_count, rnd_seed):
|
||||
def host_thread(self, indexes_count, rnd_seed):
|
||||
rnd_state = np.random.RandomState(rnd_seed) if rnd_seed is not None else np.random
|
||||
|
||||
|
||||
idxs = [*range(indexes_count)]
|
||||
shuffle_idxs = []
|
||||
sq = self.sq
|
||||
|
@ -155,6 +66,95 @@ class IndexHost():
|
|||
return self.cq.get()
|
||||
time.sleep(0.001)
|
||||
|
||||
class Index2DHost():
    """
    Provides random shuffled indexes for multiprocesses

    A daemon thread serves requests that clients put on a shared request
    queue; every client created by create_cli() has its own reply queue.
    """
    def __init__(self, indexes2D):
        self.sq = multiprocessing.Queue()
        self.cqs = []
        self.clis = []
        self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) )
        self.thread.daemon = True
        self.thread.start()

    def host_thread(self, indexes2D):
        """
        Serve requests forever. Each request asks for `count` elements;
        every element is produced by drawing an outer index and then an
        inner index, both without replacement from shuffled pools that are
        refilled and reshuffled whenever they run empty.
        """
        outer_len = len(indexes2D)

        outer_idxs = [*range(outer_len)]
        inner_idxs = [ [*range(len(indexes2D[n]))] for n in range(outer_len) ]
        outer_pool = []
        inner_pools = [ [] for _ in range(outer_len) ]

        requests = self.sq

        while True:
            while not requests.empty():
                request = requests.get()
                cq_id, count = request[0], request[1]

                reply = []
                for _ in range(count):
                    if len(outer_pool) == 0:
                        outer_pool = outer_idxs.copy()
                        np.random.shuffle(outer_pool)

                    idx_1D = outer_pool.pop()

                    if len(inner_pools[idx_1D]) == 0:
                        inner_pools[idx_1D] = inner_idxs[idx_1D].copy()
                        np.random.shuffle( inner_pools[idx_1D] )

                    idx_2D = inner_pools[idx_1D].pop()

                    reply.append( indexes2D[idx_1D][idx_2D])

                self.cqs[cq_id].put (reply)

            time.sleep(0.001)

    def create_cli(self):
        """Register a new reply queue and return a client bound to it."""
        cq = multiprocessing.Queue()
        self.cqs.append ( cq )
        return Index2DHost.Cli(self.sq, cq, len(self.cqs)-1)

    # disable pickling
    def __getstate__(self):
        return dict()
    def __setstate__(self, d):
        self.__dict__.update(d)

    class Cli():
        """Client side: sends a request to the host and polls its own queue for the reply."""
        def __init__(self, sq, cq, cq_id):
            self.sq = sq
            self.cq = cq
            self.cq_id = cq_id

        def multi_get(self, count):
            self.sq.put ( (self.cq_id,count) )

            # poll the reply queue every millisecond until the host answered
            while self.cq.empty():
                time.sleep(0.001)
            return self.cq.get()
|
||||
|
||||
class ListHost():
|
||||
def __init__(self, list_):
|
||||
self.sq = multiprocessing.Queue()
|
||||
|
|
262
core/qtex/QSubprocessor.py
Normal file
|
@ -0,0 +1,262 @@
|
|||
import multiprocessing
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from PyQt5.QtCore import *
|
||||
from PyQt5.QtGui import *
|
||||
from PyQt5.QtWidgets import *
|
||||
|
||||
from core.interact import interact as io
|
||||
|
||||
from .qtex import *
|
||||
|
||||
class QSubprocessor(object):
    """
    Host that spawns worker subprocesses and services them from a QTimer.

    Each worker is an instance of a ``QSubprocessor.Cli`` subclass connected to
    the host by two multiprocessing queues: ``s2c`` (host -> worker) and
    ``c2s`` (worker -> host). Messages are dicts tagged with an ``'op'`` key.

    Host-side worker states stored in ``cli.state``:
        0 - idle, ready to receive data
        1 - busy (waiting for 'init_ok' or for the result of sent data)
        2 - finalized / killed during shutdown
    """

    class Cli(object):
        """Worker side. Constructing an instance starts the subprocess."""

        def __init__ ( self, client_dict ):
            s2c = multiprocessing.Queue()
            c2s = multiprocessing.Queue()
            self.p = multiprocessing.Process(target=self._subprocess_run, args=(client_dict,s2c,c2s) )
            self.s2c = s2c
            self.c2s = c2s
            self.p.daemon = True
            self.p.start()

            # bookkeeping fields, filled in by the host after construction
            self.state = None
            self.sent_time = None
            self.sent_data = None
            self.name = None
            self.host_dict = None

        def kill(self):
            """Terminate the subprocess and wait for it to exit."""
            self.p.terminate()
            self.p.join()

        #overridable optional
        def on_initialize(self, client_dict):
            #initialize your subprocess here using client_dict
            pass

        #overridable optional
        def on_finalize(self):
            #finalize your subprocess here
            pass

        #overridable
        def process_data(self, data):
            #process 'data' given from host and return result
            raise NotImplementedError

        #overridable optional
        def get_data_name (self, data):
            #return string identifier of your 'data'
            return "undefined"

        def log_info(self, msg):
            """Ask the host to log ``msg`` as info."""
            self.c2s.put ( {'op': 'log_info', 'msg':msg } )

        def log_err(self, msg):
            """Ask the host to log ``msg`` as an error."""
            self.c2s.put ( {'op': 'log_err' , 'msg':msg } )

        def progress_bar_inc(self, c):
            """Ask the host to advance its progress bar by ``c``."""
            self.c2s.put ( {'op': 'progress_bar_inc' , 'c':c } )

        def _subprocess_run(self, client_dict, s2c, c2s):
            """
            Subprocess entry point: initialize, then serve 'data' messages
            until a 'close' message arrives; report any exception as 'error'.
            """
            self.c2s = c2s
            data = None
            try:
                self.on_initialize(client_dict)
                c2s.put ( {'op': 'init_ok'} )
                while True:
                    msg = s2c.get()
                    op = msg.get('op','')
                    if op == 'data':
                        data = msg['data']
                        result = self.process_data (data)
                        c2s.put ( {'op': 'success', 'data' : data, 'result' : result} )
                        # nothing is in flight any more
                        data = None
                    elif op == 'close':
                        break
                    time.sleep(0.001)
                self.on_finalize()
                c2s.put ( {'op': 'finalized'} )
            except Exception:
                # 'data' is the item that was being processed, or None
                c2s.put ( {'op': 'error', 'data' : data} )
                if data is not None:
                    print ('Exception while process data [%s]: %s' % (self.get_data_name(data), traceback.format_exc()) )
                else:
                    print ('Exception: %s' % (traceback.format_exc()) )
            c2s.close()
            s2c.close()
            self.c2s = None

        # disable pickling
        def __getstate__(self):
            return dict()
        def __setstate__(self, d):
            self.__dict__.update(d)

    #overridable
    def __init__(self, name, SubprocessorCli_class, no_response_time_sec = 0, io_loop_sleep_time=0.005):
        """
        Spawn one worker per item yielded by process_info_generator(), wait
        until every surviving worker reported 'init_ok', then start the timer
        that drives tick().

        name                  : name of this subprocessor, used in error messages
        SubprocessorCli_class : subclass of QSubprocessor.Cli to instantiate
        no_response_time_sec  : a busy worker silent for longer than this is
                                killed by tick(); 0 disables the check
        io_loop_sleep_time    : stored on the instance

        Raises ValueError if SubprocessorCli_class is not a Cli subclass, and
        Exception if a worker cannot be spawned or none survives initialization.
        """
        if not issubclass(SubprocessorCli_class, QSubprocessor.Cli):
            raise ValueError("SubprocessorCli_class must be subclass of QSubprocessor.Cli")

        self.name = name
        self.SubprocessorCli_class = SubprocessorCli_class
        self.no_response_time_sec = no_response_time_sec
        self.io_loop_sleep_time = io_loop_sleep_time

        self.clis = []

        #getting info about name of subprocesses, host and client dicts, and spawning them
        for cli_name, host_dict, client_dict in self.process_info_generator():
            try:
                cli = self.SubprocessorCli_class(client_dict)
                cli.state = 1
                cli.sent_time = 0
                cli.sent_data = None
                cli.name = cli_name
                cli.host_dict = host_dict

                self.clis.append (cli)
            except Exception as e:
                # KeyboardInterrupt / SystemExit are deliberately not converted
                raise Exception (f"Unable to start subprocess {cli_name}. Error: {traceback.format_exc()}") from e

        if len(self.clis) == 0:
            raise Exception ("Unable to start QSubprocessor '%s' " % (self.name))

        #waiting for the subprocesses to report successful (or failed) initialization
        while True:
            for cli in self.clis[:]:
                while not cli.c2s.empty():
                    obj = cli.c2s.get()
                    op = obj.get('op','')
                    if op == 'init_ok':
                        cli.state = 0
                    elif op == 'log_info':
                        io.log_info(obj['msg'])
                    elif op == 'log_err':
                        io.log_err(obj['msg'])
                    elif op == 'error':
                        cli.kill()
                        self.clis.remove(cli)
                        break
            if all ([cli.state == 0 for cli in self.clis]):
                break
            io.process_messages(0.005)

        if len(self.clis) == 0:
            raise Exception ( "Unable to start subprocesses." )

        #ok some processes survived, initialize host logic
        self.on_clients_initialized()

        self.q_timer = QTimer()
        self.q_timer.timeout.connect(self.tick)
        self.q_timer.start(5)

    #overridable
    def process_info_generator(self):
        #yield per process (name, host_dict, client_dict)
        for i in range(min(multiprocessing.cpu_count(), 8) ):
            yield 'CPU%d' % (i), {}, {}

    #overridable optional
    def on_clients_initialized(self):
        #logic when all subprocesses initialized and ready
        pass

    #overridable optional
    def on_clients_finalized(self):
        #logic when all subprocess finalized
        pass

    #overridable
    def get_data(self, host_dict):
        #return data for processing here
        raise NotImplementedError

    #overridable
    def on_data_return (self, host_dict, data):
        #you have to place returned 'data' back to your queue
        raise NotImplementedError

    #overridable
    def on_result (self, host_dict, data, result):
        #your logic what to do with 'result' of 'data'
        raise NotImplementedError

    def tick(self):
        """
        Timer callback: collect worker messages, kill unresponsive workers,
        feed idle workers, and shut everything down once all workers are idle.
        """
        for cli in self.clis[:]:
            while not cli.c2s.empty():
                obj = cli.c2s.get()
                op = obj.get('op','')
                if op == 'success':
                    #data processed successfully, pass data and result to on_result
                    self.on_result (cli.host_dict, obj['data'], obj['result'])
                    # the in-flight record belongs to the worker, not to the host
                    cli.sent_data = None
                    cli.state = 0
                elif op == 'error':
                    #some error occurred while processing data, returning chunk to on_data_return
                    failed_data = obj.get('data')
                    # the worker sends data=None when nothing was in flight
                    if failed_data is not None:
                        self.on_data_return (cli.host_dict, failed_data )
                    #and killing process
                    cli.kill()
                    self.clis.remove(cli)
                    # stop draining a worker that is no longer tracked
                    break
                elif op == 'log_info':
                    io.log_info(obj['msg'])
                elif op == 'log_err':
                    io.log_err(obj['msg'])
                elif op == 'progress_bar_inc':
                    io.progress_bar_inc(obj['c'])

        for cli in self.clis[:]:
            if cli.state == 1:
                if cli.sent_time != 0 and self.no_response_time_sec != 0 and (time.time() - cli.sent_time) > self.no_response_time_sec:
                    #subprocess busy too long
                    io.log_info ( '%s doesnt response, terminating it.' % (cli.name) )
                    self.on_data_return (cli.host_dict, cli.sent_data )
                    cli.kill()
                    self.clis.remove(cli)

        for cli in self.clis[:]:
            if cli.state == 0:
                #free state of subprocess, get some data from get_data
                data = self.get_data(cli.host_dict)
                if data is not None:
                    #and send it to subprocess
                    cli.s2c.put ( {'op': 'data', 'data' : data} )
                    cli.sent_time = time.time()
                    cli.sent_data = data
                    cli.state = 1

        if all ([cli.state == 0 for cli in self.clis]):
            #gracefully terminating subprocesses
            for cli in self.clis[:]:
                cli.s2c.put ( {'op': 'close'} )
                cli.sent_time = time.time()

            while True:
                for cli in self.clis[:]:
                    if cli.state == 2:
                        # already finalized and killed
                        continue

                    terminate_it = False
                    while not cli.c2s.empty():
                        obj = cli.c2s.get()
                        obj_op = obj['op']
                        if obj_op == 'finalized':
                            terminate_it = True
                            break

                    # give up waiting for 'finalized' after 30 seconds
                    if (time.time() - cli.sent_time) > 30:
                        terminate_it = True

                    if terminate_it:
                        cli.state = 2
                        cli.kill()

                if all ([cli.state == 2 for cli in self.clis]):
                    break

                # avoid a hot spin while waiting for the workers
                time.sleep(0.001)

            #finalizing host logic
            self.q_timer.stop()
            self.q_timer = None
            self.on_clients_finalized()
|
||||
|
83
core/qtex/QXIconButton.py
Normal file
|
@ -0,0 +1,83 @@
|
|||
from PyQt5.QtCore import *
|
||||
from PyQt5.QtGui import *
|
||||
from PyQt5.QtWidgets import *
|
||||
|
||||
from localization import StringsDB
|
||||
from .QXMainWindow import *
|
||||
|
||||
class QXIconButton(QPushButton):
    """
    Custom Icon button that works through keyEvent system, without shortcut of QAction
    works only with QXMainWindow as global window class
    currently works only with one-key shortcut
    """

    def __init__(self, icon,
                       tooltip=None,
                       shortcut=None,
                       click_func=None,
                       first_repeat_delay=300,
                       repeat_delay=20,
                       ):
        """
        icon               : icon shown on the button
        tooltip            : optional tooltip text
        shortcut           : optional key; appended to the tooltip as a hot-key hint
        click_func         : callable invoked on press and on every repeat
        first_repeat_delay : interval passed to the repeat timer when it is started
        repeat_delay       : interval applied to the repeat timer once it has fired
        """

        super().__init__(icon, "")

        self.setIcon(icon)

        if shortcut is not None:
            hot_key_hint = f"( {StringsDB['S_HOT_KEY'] }: {shortcut} )"
            # do not render the literal text "None" when no tooltip was given
            tooltip = hot_key_hint if tooltip is None else f"{tooltip} {hot_key_hint}"

        # leave the tooltip untouched when there is nothing to show
        if tooltip is not None:
            self.setToolTip(tooltip)

        self.seq = QKeySequence(shortcut) if shortcut is not None else None

        # key events are delivered by the global main window, not by a QAction shortcut
        QXMainWindow.inst.add_keyPressEvent_listener ( self.on_keyPressEvent )
        QXMainWindow.inst.add_keyReleaseEvent_listener ( self.on_keyReleaseEvent )

        self.click_func = click_func
        self.first_repeat_delay = first_repeat_delay
        self.repeat_delay = repeat_delay
        self.repeat_timer = None

        self.op_device = None

        self.pressed.connect( lambda : self.action(is_pressed=True) )
        self.released.connect( lambda : self.action(is_pressed=False) )

    def action(self, is_pressed=None, op_device=None):
        """
        Handle a press (is_pressed=True), a release (is_pressed=False) or a
        repeat-timer tick (is_pressed=None). Does nothing without click_func.
        """
        if self.click_func is None:
            return

        if is_pressed is not None:
            if is_pressed:
                if self.repeat_timer is None:
                    # first press: click once and arm the repeat timer
                    self.click_func()
                    self.repeat_timer = QTimer()
                    self.repeat_timer.timeout.connect(self.action)
                    self.repeat_timer.start(self.first_repeat_delay)
            else:
                if self.repeat_timer is not None:
                    self.repeat_timer.stop()
                    self.repeat_timer = None
        else:
            # timer tick: click again and switch to the repeat interval
            self.click_func()
            if self.repeat_timer is not None:
                self.repeat_timer.setInterval(self.repeat_delay)

    def on_keyPressEvent(self, ev):
        """Treat a non-auto-repeat press of the shortcut key as a button press."""
        key = ev.nativeVirtualKey()
        if ev.isAutoRepeat():
            return

        if self.seq is not None:
            if key == self.seq[0]:
                self.action(is_pressed=True)

    def on_keyReleaseEvent(self, ev):
        """Treat a non-auto-repeat release of the shortcut key as a button release."""
        key = ev.nativeVirtualKey()
        if ev.isAutoRepeat():
            return
        if self.seq is not None:
            if key == self.seq[0]:
                self.action(is_pressed=False)
|
34
core/qtex/QXMainWindow.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
from PyQt5.QtCore import *
|
||||
from PyQt5.QtGui import *
|
||||
from PyQt5.QtWidgets import *
|
||||
|
||||
class QXMainWindow(QWidget):
    """
    Main-window widget that exists as a single global instance and forwards
    key press / release events to registered listener callables.
    """
    inst = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # only one instance may exist; it is published as QXMainWindow.inst
        already_created = QXMainWindow.inst is not None
        if already_created:
            raise Exception("QXMainWindow can only be one.")
        QXMainWindow.inst = self

        self.keyPressEvent_listeners = []
        self.keyReleaseEvent_listeners = []
        self.setFocusPolicy(Qt.WheelFocus)

    def add_keyPressEvent_listener(self, func):
        """Register ``func`` to be called with every key press event."""
        self.keyPressEvent_listeners.append (func)

    def add_keyReleaseEvent_listener(self, func):
        """Register ``func`` to be called with every key release event."""
        self.keyReleaseEvent_listeners.append (func)

    def keyPressEvent(self, ev):
        """Run the base handler, then notify listeners in registration order."""
        super().keyPressEvent(ev)
        for listener in self.keyPressEvent_listeners:
            listener(ev)

    def keyReleaseEvent(self, ev):
        """Run the base handler, then notify listeners in registration order."""
        super().keyReleaseEvent(ev)
        for listener in self.keyReleaseEvent_listeners:
            listener(ev)
|
3
core/qtex/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
from .qtex import *
|
||||
from .QSubprocessor import *
|
||||
from .QXIconButton import *
|
80
core/qtex/qtex.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
import numpy as np
|
||||
from PyQt5.QtCore import *
|
||||
from PyQt5.QtGui import *
|
||||
from PyQt5.QtWidgets import *
|
||||
from localization import StringsDB
|
||||
|
||||
from .QXMainWindow import *
|
||||
|
||||
|
||||
class QActionEx(QAction):
    """QAction whose shortcut, tooltip, trigger slot and flags are set from constructor arguments."""

    def __init__(self, icon, text, shortcut=None, trigger_func=None, shortcut_in_tooltip=False, is_checkable=False, is_auto_repeat=False ):
        super().__init__(icon, text)

        has_shortcut = shortcut is not None
        if has_shortcut:
            self.setShortcut(shortcut)

        # the hot-key hint is only shown when a shortcut was actually given
        if has_shortcut and shortcut_in_tooltip:
            hint = f"{text} ( {StringsDB['S_HOT_KEY'] }: {shortcut} )"
            self.setToolTip(hint)

        if trigger_func is not None:
            self.triggered.connect(trigger_func)

        if is_checkable:
            self.setCheckable(True)

        self.setAutoRepeat(is_auto_repeat)
|
||||
|
||||
def QImage_from_np(img):
    """
    Wrap a uint8 array of shape (h, w, c) in a QImage.

    c == 1 maps to Format_Grayscale8, c == 3 to Format_BGR888 and c == 4 to
    Format_ARGB32. Raises ValueError for any other dtype or channel count.
    """
    if img.dtype != np.uint8:
        raise ValueError("img should be in np.uint8 format")

    height, width, channels = img.shape
    if channels not in (1, 3, 4):
        raise ValueError("unsupported channel count")

    if channels == 1:
        pixel_format = QImage.Format_Grayscale8
    elif channels == 3:
        pixel_format = QImage.Format_BGR888
    else:
        pixel_format = QImage.Format_ARGB32

    # rows are passed as tightly packed: channels * width bytes per line
    bytes_per_line = channels * width
    return QImage(img.data, width, height, bytes_per_line, pixel_format )
|
||||
|
||||
def QImage_to_np(q_img, fmt=QImage.Format_BGR888):
    """
    Convert ``q_img`` to ``fmt`` and return its pixels as a uint8 array of
    shape (height, width, channels).

    The channel count is taken from the converted image depth, so formats
    other than the 3-channel default also work. The result is an independent
    copy and does not reference the QImage buffer.

    Raises ValueError when the format does not use whole bytes per pixel.
    """
    q_img = q_img.convertToFormat(fmt)

    width = q_img.width()
    height = q_img.height()

    depth = q_img.depth()
    if depth % 8 != 0:
        raise ValueError("unsupported QImage format: pixel depth is not a whole number of bytes")
    channels = depth // 8

    b = q_img.constBits()
    if b is None:
        # null image: there is no pixel buffer to read
        return np.empty((height, width, channels), dtype=np.uint8)

    # QImage scanlines may be padded, so a row can be longer than width*channels
    stride = q_img.bytesPerLine()
    b.setsize(height * stride)
    rows = np.frombuffer(b, np.uint8).reshape((height, stride))

    # drop the row padding and copy, because the converted QImage is a local
    # temporary whose buffer must not outlive this call
    return rows[:, :width * channels].copy().reshape((height, width, channels))
|
||||
|
||||
def QPixmap_from_np(img):
    """Build a QPixmap from a numpy image by way of QImage_from_np."""
    q_image = QImage_from_np(img)
    return QPixmap.fromImage(q_image)
|
||||
|
||||
def QPoint_from_np(n):
    """Build a QPoint from the elements of array ``n`` cast to integers."""
    # np.int was only an alias of the builtin int and no longer exists in
    # current NumPy, so the builtin is used directly
    return QPoint(*n.astype(int))
|
||||
|
||||
def QPoint_to_np(q):
    """Return the point ``q`` as an int32 array ``[x, y]``."""
    coords = [q.x(), q.y()]
    return np.array(coords, dtype=np.int32)
|
||||
|
||||
def QSize_to_np(q):
    """Return the size ``q`` as an int32 array ``[width, height]``."""
    dims = [q.width(), q.height()]
    return np.array(dims, dtype=np.int32)
|
||||
|
||||
class QDarkPalette(QPalette):
    """QPalette preset with dark backgrounds and light grey text."""

    def __init__(self):
        super().__init__()

        text_color = QColor(200,200,200)

        # (role, color) pairs, applied in this order
        role_colors = (
            (QPalette.Window,          QColor(53, 53, 53)),
            (QPalette.WindowText,      text_color),
            (QPalette.Base,            QColor(25, 25, 25)),
            (QPalette.AlternateBase,   QColor(53, 53, 53)),
            (QPalette.ToolTipBase,     text_color),
            (QPalette.ToolTipText,     text_color),
            (QPalette.Text,            text_color),
            (QPalette.Button,          QColor(53, 53, 53)),
            (QPalette.ButtonText,      Qt.white),
            (QPalette.BrightText,      Qt.red),
            (QPalette.Link,            QColor(42, 130, 218)),
            (QPalette.Highlight,       QColor(42, 130, 218)),
            (QPalette.HighlightedText, Qt.black),
        )
        for role, color in role_colors:
            self.setColor(role, color)
|
|
@ -1,12 +1,14 @@
|
|||
import numpy as np
|
||||
|
||||
def random_normal( size=(1,), trunc_val = 2.5 ):
|
||||
def random_normal( size=(1,), trunc_val = 2.5, rnd_state=None ):
|
||||
if rnd_state is None:
|
||||
rnd_state = np.random
|
||||
len = np.array(size).prod()
|
||||
result = np.empty ( (len,) , dtype=np.float32)
|
||||
|
||||
for i in range (len):
|
||||
while True:
|
||||
x = np.random.normal()
|
||||
x = rnd_state.normal()
|
||||
if x >= -trunc_val and x <= trunc_val:
|
||||
break
|
||||
result[i] = (x / trunc_val)
|
||||
|
|
Before Width: | Height: | Size: 403 KiB |
BIN
doc/DFL_welcome.png
Normal file
After Width: | Height: | Size: 482 KiB |
Before Width: | Height: | Size: 287 KiB |
BIN
doc/deage_0_1.jpg
Normal file
After Width: | Height: | Size: 74 KiB |
BIN
doc/deage_0_2.jpg
Normal file
After Width: | Height: | Size: 68 KiB |
BIN
doc/deepfake_progress.png
Normal file
After Width: | Height: | Size: 1 MiB |
BIN
doc/deepfake_progress_source.psd
Normal file
BIN
doc/head_replace_0_1.jpg
Normal file
After Width: | Height: | Size: 71 KiB |
BIN
doc/head_replace_0_2.jpg
Normal file
After Width: | Height: | Size: 67 KiB |
BIN
doc/head_replace_1_1.jpg
Normal file
After Width: | Height: | Size: 122 KiB |
BIN
doc/head_replace_1_2.jpg
Normal file
After Width: | Height: | Size: 123 KiB |
BIN
doc/head_replace_2_1.jpg
Normal file
After Width: | Height: | Size: 98 KiB |
BIN
doc/head_replace_2_2.jpg
Normal file
After Width: | Height: | Size: 97 KiB |
BIN
doc/logo_directx.png
Normal file
After Width: | Height: | Size: 25 KiB |
BIN
doc/make_everything_ok.png
Normal file
After Width: | Height: | Size: 36 KiB |
BIN
doc/meme1.jpg
Normal file
After Width: | Height: | Size: 139 KiB |
BIN
doc/meme2.jpg
Normal file
After Width: | Height: | Size: 208 KiB |
BIN
doc/meme3.jpg
Normal file
After Width: | Height: | Size: 310 KiB |