mirror of https://github.com/iperov/DeepFaceLab.git
synced 2025-07-05 12:36:42 -07:00

global refactoring and fixes

Removed support of extracted (aligned) PNG faces; use old builds to convert from PNG to JPG. The fanseg model file in facelib/ has been renamed.

parent 921b464d5b
commit 61472cdaf7

82 changed files with 3838 additions and 3812 deletions
@@ -1,15 +1,12 @@
 from pathlib import Path
 
 from .DFLJPG import DFLJPG
-from .DFLPNG import DFLPNG
 
 class DFLIMG():
 
     @staticmethod
     def load(filepath, loader_func=None):
-        if filepath.suffix == '.png':
-            return DFLPNG.load( str(filepath), loader_func=loader_func )
-        elif filepath.suffix == '.jpg':
+        if filepath.suffix == '.jpg':
            return DFLJPG.load ( str(filepath), loader_func=loader_func )
         else:
             return None
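With this change DFLIMG.load recognizes only JPEG faces and returns None for anything else. A minimal usage sketch (the path is hypothetical, not from this commit):

    from pathlib import Path
    from DFLIMG import DFLIMG

    dflimg = DFLIMG.load( Path('workspace/data_src/aligned/00001.jpg') )  # hypothetical file
    if dflimg is None:
        print('not a DFL jpg face')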
DFLIMG/DFLPNG.py (file deleted, 435 lines)
@@ -1,435 +0,0 @@
-import pickle
-import string
-import struct
-import zlib
-
-import cv2
-import numpy as np
-
-from facelib import FaceType
-
-PNG_HEADER = b"\x89PNG\r\n\x1a\n"
-
-class Chunk(object):
-    def __init__(self, name=None, data=None):
-        self.length = 0
-        self.crc = 0
-        self.name = name if name else "noNe"
-        self.data = data if data else b""
-
-    @classmethod
-    def load(cls, data):
-        """Load a chunk including header and footer"""
-        inst = cls()
-        if len(data) < 12:
-            msg = "Chunk-data too small"
-            raise ValueError(msg)
-
-        # chunk header & data
-        (inst.length, raw_name) = struct.unpack("!I4s", data[0:8])
-        inst.data = data[8:-4]
-        inst.verify_length()
-        inst.name = raw_name.decode("ascii")
-        inst.verify_name()
-
-        # chunk crc
-        inst.crc = struct.unpack("!I", data[8+inst.length:8+inst.length+4])[0]
-        inst.verify_crc()
-
-        return inst
-
-    def dump(self, auto_crc=True, auto_length=True):
-        """Return the chunk including header and footer"""
-        if auto_length: self.update_length()
-        if auto_crc: self.update_crc()
-        self.verify_name()
-        return struct.pack("!I", self.length) + self.get_raw_name() + self.data + struct.pack("!I", self.crc)
-
-    def verify_length(self):
-        if len(self.data) != self.length:
-            msg = "Data length ({}) does not match length in chunk header ({})".format(len(self.data), self.length)
-            raise ValueError(msg)
-        return True
-
-    def verify_name(self):
-        for c in self.name:
-            if c not in string.ascii_letters:
-                msg = "Invalid character in chunk name: {}".format(repr(self.name))
-                raise ValueError(msg)
-        return True
-
-    def verify_crc(self):
-        calculated_crc = self.get_crc()
-        if self.crc != calculated_crc:
-            msg = "CRC mismatch: {:08X} (header), {:08X} (calculated)".format(self.crc, calculated_crc)
-            raise ValueError(msg)
-        return True
-
-    def update_length(self):
-        self.length = len(self.data)
-
-    def update_crc(self):
-        self.crc = self.get_crc()
-
-    def get_crc(self):
-        return zlib.crc32(self.get_raw_name() + self.data)
-
-    def get_raw_name(self):
-        return self.name if isinstance(self.name, bytes) else self.name.encode("ascii")
-
-    # name helper methods
-
-    def ancillary(self, set=None):
-        """Set and get ancillary=True/critical=False bit"""
-        if set is True:
-            self.name[0] = self.name[0].lower()
-        elif set is False:
-            self.name[0] = self.name[0].upper()
-        return self.name[0].islower()
-
-    def private(self, set=None):
-        """Set and get private=True/public=False bit"""
-        if set is True:
-            self.name[1] = self.name[1].lower()
-        elif set is False:
-            self.name[1] = self.name[1].upper()
-        return self.name[1].islower()
-
-    def reserved(self, set=None):
-        """Set and get reserved_valid=True/invalid=False bit"""
-        if set is True:
-            self.name[2] = self.name[2].upper()
-        elif set is False:
-            self.name[2] = self.name[2].lower()
-        return self.name[2].isupper()
-
-    def safe_to_copy(self, set=None):
-        """Set and get save_to_copy=True/unsafe=False bit"""
-        if set is True:
-            self.name[3] = self.name[3].lower()
-        elif set is False:
-            self.name[3] = self.name[3].upper()
-        return self.name[3].islower()
-
-    def __str__(self):
-        return "<Chunk '{name}' length={length} crc={crc:08X}>".format(**self.__dict__)
-
-class IHDR(Chunk):
-    """IHDR Chunk
-    width, height, bit_depth, color_type, compression_method,
-    filter_method, interlace_method contain the data extracted
-    from the chunk. Modify those and use and build() to recreate
-    the chunk. Valid values for bit_depth depend on the color_type
-    and can be looked up in color_types or in the PNG specification
-
-    See:
-    http://www.libpng.org/pub/png/spec/1.2/PNG-Chunks.html#C.IHDR
-    """
-    # color types with name & allowed bit depths
-    COLOR_TYPE_GRAY = 0
-    COLOR_TYPE_RGB = 2
-    COLOR_TYPE_PLTE = 3
-    COLOR_TYPE_GRAYA = 4
-    COLOR_TYPE_RGBA = 6
-    color_types = {
-        COLOR_TYPE_GRAY: ("Grayscale", (1,2,4,8,16)),
-        COLOR_TYPE_RGB: ("RGB", (8,16)),
-        COLOR_TYPE_PLTE: ("Palette", (1,2,4,8)),
-        COLOR_TYPE_GRAYA: ("Greyscale+Alpha", (8,16)),
-        COLOR_TYPE_RGBA: ("RGBA", (8,16)),
-    }
-
-    def __init__(self, width=0, height=0, bit_depth=8, color_type=2, \
-                 compression_method=0, filter_method=0, interlace_method=0):
-        self.width = width
-        self.height = height
-        self.bit_depth = bit_depth
-        self.color_type = color_type
-        self.compression_method = compression_method
-        self.filter_method = filter_method
-        self.interlace_method = interlace_method
-        super().__init__("IHDR")
-
-    @classmethod
-    def load(cls, data):
-        inst = super().load(data)
-        fields = struct.unpack("!IIBBBBB", inst.data)
-        inst.width = fields[0]
-        inst.height = fields[1]
-        inst.bit_depth = fields[2]           # per channel
-        inst.color_type = fields[3]          # see specs
-        inst.compression_method = fields[4]  # always 0(=deflate/inflate)
-        inst.filter_method = fields[5]       # always 0(=adaptive filtering with 5 methods)
-        inst.interlace_method = fields[6]    # 0(=no interlace) or 1(=Adam7 interlace)
-        return inst
-
-    def dump(self):
-        self.data = struct.pack("!IIBBBBB", \
-                                self.width, self.height, self.bit_depth, self.color_type, \
-                                self.compression_method, self.filter_method, self.interlace_method)
-        return super().dump()
-
-    def __str__(self):
-        return "<Chunk:IHDR geometry={width}x{height} bit_depth={bit_depth} color_type={}>" \
-               .format(self.color_types[self.color_type][0], **self.__dict__)
-
-class IEND(Chunk):
-    def __init__(self):
-        super().__init__("IEND")
-
-    def dump(self):
-        if len(self.data) != 0:
-            msg = "IEND has data which is not allowed"
-            raise ValueError(msg)
-        if self.length != 0:
-            msg = "IEND data lenght is not 0 which is not allowed"
-            raise ValueError(msg)
-        return super().dump()
-
-    def __str__(self):
-        return "<Chunk:IEND>".format(**self.__dict__)
-
-class DFLChunk(Chunk):
-    def __init__(self, dict_data=None):
-        super().__init__("fcWp")
-        self.dict_data = dict_data
-
-    def setDictData(self, dict_data):
-        self.dict_data = dict_data
-
-    def getDictData(self):
-        return self.dict_data
-
-    @classmethod
-    def load(cls, data):
-        inst = super().load(data)
-        inst.dict_data = pickle.loads( inst.data )
-        return inst
-
-    def dump(self):
-        self.data = pickle.dumps (self.dict_data)
-        return super().dump()
-
-chunk_map = {
-    b"IHDR": IHDR,
-    b"fcWp": DFLChunk,
-    b"IEND": IEND
-}
-
-class DFLPNG(object):
-    def __init__(self):
-        self.data = b""
-        self.length = 0
-        self.chunks = []
-        self.dfl_dict = None
-
-    @staticmethod
-    def load_raw(filename, loader_func=None):
-        try:
-            if loader_func is not None:
-                data = loader_func(filename)
-            else:
-                with open(filename, "rb") as f:
-                    data = f.read()
-        except:
-            raise FileNotFoundError(filename)
-
-        inst = DFLPNG()
-        inst.data = data
-        inst.length = len(data)
-
-        if data[0:8] != PNG_HEADER:
-            msg = "No Valid PNG header"
-            raise ValueError(msg)
-
-        chunk_start = 8
-        while chunk_start < inst.length:
-            (chunk_length, chunk_name) = struct.unpack("!I4s", data[chunk_start:chunk_start+8])
-            chunk_end = chunk_start + chunk_length + 12
-
-            chunk = chunk_map.get(chunk_name, Chunk).load(data[chunk_start:chunk_end])
-            inst.chunks.append(chunk)
-            chunk_start = chunk_end
-
-        return inst
-
-    @staticmethod
-    def load(filename, loader_func=None):
-        try:
-            inst = DFLPNG.load_raw (filename, loader_func=loader_func)
-            inst.dfl_dict = inst.getDFLDictData()
-
-            if inst.dfl_dict is not None:
-                if 'face_type' not in inst.dfl_dict:
-                    inst.dfl_dict['face_type'] = FaceType.toString (FaceType.FULL)
-
-                if 'fanseg_mask' in inst.dfl_dict:
-                    fanseg_mask = inst.dfl_dict['fanseg_mask']
-                    if fanseg_mask is not None:
-                        numpyarray = np.asarray( inst.dfl_dict['fanseg_mask'], dtype=np.uint8)
-                        inst.dfl_dict['fanseg_mask'] = cv2.imdecode(numpyarray, cv2.IMREAD_UNCHANGED)
-
-            if inst.dfl_dict == None:
-                return None
-
-            return inst
-        except Exception as e:
-            print(e)
-            return None
-
-    @staticmethod
-    def embed_dfldict(filename, dfl_dict):
-        inst = DFLPNG.load_raw (filename)
-        inst.setDFLDictData (dfl_dict)
-
-        try:
-            with open(filename, "wb") as f:
-                f.write ( inst.dump() )
-        except:
-            raise Exception( 'cannot save %s' % (filename) )
-
-    @staticmethod
-    def embed_data(filename, face_type=None,
-                             landmarks=None,
-                             ie_polys=None,
-                             source_filename=None,
-                             source_rect=None,
-                             source_landmarks=None,
-                             image_to_face_mat=None,
-                             fanseg_mask=None,
-                             eyebrows_expand_mod=None,
-                             relighted=None,
-                             **kwargs
-                   ):
-
-        if fanseg_mask is not None:
-            fanseg_mask = np.clip ( (fanseg_mask*255).astype(np.uint8), 0, 255 )
-
-            ret, buf = cv2.imencode( '.jpg', fanseg_mask, [int(cv2.IMWRITE_JPEG_QUALITY), 85] )
-
-            if ret and len(buf) < 60000:
-                fanseg_mask = buf
-            else:
-                io.log_err("Unable to encode fanseg_mask for %s" % (filename) )
-                fanseg_mask = None
-
-        if ie_polys is not None:
-            if not isinstance(ie_polys, list):
-                ie_polys = ie_polys.dump()
-
-        DFLPNG.embed_dfldict (filename, {'face_type': face_type,
-                                         'landmarks': landmarks,
-                                         'ie_polys' : ie_polys,
-                                         'source_filename': source_filename,
-                                         'source_rect': source_rect,
-                                         'source_landmarks': source_landmarks,
-                                         'image_to_face_mat':image_to_face_mat,
-                                         'fanseg_mask' : fanseg_mask,
-                                         'eyebrows_expand_mod' : eyebrows_expand_mod,
-                                         'relighted' : relighted
-                                        })
-
-    def embed_and_set(self, filename, face_type=None,
-                                      landmarks=None,
-                                      ie_polys=None,
-                                      source_filename=None,
-                                      source_rect=None,
-                                      source_landmarks=None,
-                                      image_to_face_mat=None,
-                                      fanseg_mask=None,
-                                      eyebrows_expand_mod=None,
-                                      relighted=None,
-                                      **kwargs
-                        ):
-        if face_type is None: face_type = self.get_face_type()
-        if landmarks is None: landmarks = self.get_landmarks()
-        if ie_polys is None: ie_polys = self.get_ie_polys()
-        if source_filename is None: source_filename = self.get_source_filename()
-        if source_rect is None: source_rect = self.get_source_rect()
-        if source_landmarks is None: source_landmarks = self.get_source_landmarks()
-        if image_to_face_mat is None: image_to_face_mat = self.get_image_to_face_mat()
-        if fanseg_mask is None: fanseg_mask = self.get_fanseg_mask()
-        if eyebrows_expand_mod is None: eyebrows_expand_mod = self.get_eyebrows_expand_mod()
-        if relighted is None: relighted = self.get_relighted()
-
-        DFLPNG.embed_data (filename, face_type=face_type,
-                                     landmarks=landmarks,
-                                     ie_polys=ie_polys,
-                                     source_filename=source_filename,
-                                     source_rect=source_rect,
-                                     source_landmarks=source_landmarks,
-                                     image_to_face_mat=image_to_face_mat,
-                                     fanseg_mask=fanseg_mask,
-                                     eyebrows_expand_mod=eyebrows_expand_mod,
-                                     relighted=relighted)
-
-    def remove_ie_polys(self):
-        self.dfl_dict['ie_polys'] = None
-
-    def remove_fanseg_mask(self):
-        self.dfl_dict['fanseg_mask'] = None
-
-    def remove_source_filename(self):
-        self.dfl_dict['source_filename'] = None
-
-    def dump(self):
-        data = PNG_HEADER
-        for chunk in self.chunks:
-            data += chunk.dump()
-        return data
-
-    def get_shape(self):
-        for chunk in self.chunks:
-            if type(chunk) == IHDR:
-                c = 3 if chunk.color_type == IHDR.COLOR_TYPE_RGB else 4
-                w = chunk.width
-                h = chunk.height
-                return (h,w,c)
-        return (0,0,0)
-
-    def get_height(self):
-        for chunk in self.chunks:
-            if type(chunk) == IHDR:
-                return chunk.height
-        return 0
-
-    def getDFLDictData(self):
-        for chunk in self.chunks:
-            if type(chunk) == DFLChunk:
-                return chunk.getDictData()
-        return None
-
-    def setDFLDictData (self, dict_data=None):
-        for chunk in self.chunks:
-            if type(chunk) == DFLChunk:
-                self.chunks.remove(chunk)
-                break
-
-        if not dict_data is None:
-            chunk = DFLChunk(dict_data)
-            self.chunks.insert(-1, chunk)
-
-    def get_face_type(self): return self.dfl_dict['face_type']
-    def get_landmarks(self): return np.array ( self.dfl_dict['landmarks'] )
-    def get_ie_polys(self): return self.dfl_dict.get('ie_polys',None)
-    def get_source_filename(self): return self.dfl_dict['source_filename']
-    def get_source_rect(self): return self.dfl_dict['source_rect']
-    def get_source_landmarks(self): return np.array ( self.dfl_dict['source_landmarks'] )
-    def get_image_to_face_mat(self):
-        mat = self.dfl_dict.get ('image_to_face_mat', None)
-        if mat is not None:
-            return np.array (mat)
-        return None
-    def get_fanseg_mask(self):
-        fanseg_mask = self.dfl_dict.get ('fanseg_mask', None)
-        if fanseg_mask is not None:
-            return np.clip ( np.array (fanseg_mask) / 255.0, 0.0, 1.0 )[...,np.newaxis]
-        return None
-    def get_eyebrows_expand_mod(self):
-        return self.dfl_dict.get ('eyebrows_expand_mod', None)
-    def get_relighted(self):
-        return self.dfl_dict.get ('relighted', False)
-
-    def __str__(self):
-        return "<PNG length={length} chunks={}>".format(len(self.chunks), **self.__dict__)
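For reference, the layout this deleted parser walked is the standard PNG chunk format: a big-endian 4-byte length, a 4-byte ASCII name, the payload, then a CRC32 over name plus payload. A self-contained sketch of the same walk, independent of the deleted classes:

    import struct, zlib

    def iter_png_chunks(data):
        # data: raw bytes of a PNG file; the first 8 bytes are the signature
        pos = 8
        while pos < len(data):
            length, name = struct.unpack("!I4s", data[pos:pos+8])
            payload = data[pos+8:pos+8+length]
            crc, = struct.unpack("!I", data[pos+8+length:pos+12+length])
            assert crc == zlib.crc32(name + payload)  # the check Chunk.verify_crc performed
            yield name, payload
            pos += length + 12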
@@ -1,3 +1,2 @@
 from .DFLIMG import DFLIMG
-from .DFLJPG import DFLJPG
-from .DFLPNG import DFLPNG
+from .DFLJPG import DFLJPG
@@ -93,7 +93,8 @@ deepfake quality progress
 ||[Donate via Paypal](https://www.paypal.com/cgi-bin/webscr?cmd=_donations&business=lepersorium@gmail.com&lc=US&no_note=0&item_name=Support+DeepFaceLab&cn=&curency_code=USD&bn=PP-DonationsBF:btn_donateCC_LG.gif:NonHosted)||
 ||[Donate via Yandex.Money](https://money.yandex.ru/to/41001142318065)||
 ||bitcoin:31mPd6DxPCzbpCMZk4k1koWAbErSyqkAXr||
-|Last donations|200$ ( 12 march 2020 VFXChris Ume )||
+|Last donations|20$ ( 12 march 2020 maria.d )||
+||200$ ( 12 march 2020 VFXChris Ume )||
 ||300$ ( 12 march 2020 Thiago O. )||
 ||50$ ( 8 march 2020 blanuk )||
 ||||
@@ -17,4 +17,6 @@ from .common import normalize_channels, cut_odd_image, overlay_alpha_image
 
 from .IEPolys import IEPolys
 
-from .blur import LinearMotionBlur
+from .blursharpen import LinearMotionBlur, blursharpen
+
+from .filters import apply_random_hsv_shift, apply_random_motion_blur, apply_random_gaussian_blur, apply_random_bilinear_resize
@@ -1,9 +0,0 @@
-import cv2
-import numpy as np
-
-def LinearMotionBlur(image, size, angle):
-    k = np.zeros((size, size), dtype=np.float32)
-    k[ (size-1)// 2 , :] = np.ones(size, dtype=np.float32)
-    k = cv2.warpAffine(k, cv2.getRotationMatrix2D( (size / 2 -0.5 , size / 2 -0.5 ) , angle, 1.0), (size, size) )
-    k = k * ( 1.0 / np.sum(k) )
-    return cv2.filter2D(image, -1, k)
core/imagelib/blursharpen.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+import cv2
+import numpy as np
+
+def LinearMotionBlur(image, size, angle):
+    k = np.zeros((size, size), dtype=np.float32)
+    k[ (size-1)// 2 , :] = np.ones(size, dtype=np.float32)
+    k = cv2.warpAffine(k, cv2.getRotationMatrix2D( (size / 2 -0.5 , size / 2 -0.5 ) , angle, 1.0), (size, size) )
+    k = k * ( 1.0 / np.sum(k) )
+    return cv2.filter2D(image, -1, k)
+
+def blursharpen (img, sharpen_mode=0, kernel_size=3, amount=100):
+    if kernel_size % 2 == 0:
+        kernel_size += 1
+    if amount > 0:
+        if sharpen_mode == 1: #box
+            kernel = np.zeros( (kernel_size, kernel_size), dtype=np.float32)
+            kernel[ kernel_size//2, kernel_size//2] = 1.0
+            box_filter = np.ones( (kernel_size, kernel_size), dtype=np.float32) / (kernel_size**2)
+            kernel = kernel + (kernel - box_filter) * amount
+            return cv2.filter2D(img, -1, kernel)
+        elif sharpen_mode == 2: #gaussian
+            blur = cv2.GaussianBlur(img, (kernel_size, kernel_size) , 0)
+            img = cv2.addWeighted(img, 1.0 + (0.5 * amount), blur, -(0.5 * amount), 0)
+            return img
+    elif amount < 0:
+        n = -amount
+        while n > 0:
+
+            img_blur = cv2.medianBlur(img, 5)
+            if int(n / 10) != 0:
+                img = img_blur
+            else:
+                pass_power = (n % 10) / 10.0
+                img = img*(1.0-pass_power)+img_blur*pass_power
+            n = max(n-10,0)
+
+        return img
+    return img
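A hedged usage sketch of the new helper on a float image in [0,1] (the image and amount values are illustrative): sharpen_mode 1 builds a box unsharp kernel, mode 2 uses a Gaussian unsharp mask, and a negative amount blends in repeated median blurs.

    import numpy as np
    from core.imagelib import blursharpen

    img = np.random.rand(256, 256, 3).astype(np.float32)  # stand-in for a face image
    sharpened = blursharpen(img, sharpen_mode=2, kernel_size=3, amount=4)
    softened  = blursharpen(img, sharpen_mode=2, kernel_size=3, amount=-15)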
core/imagelib/filters.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+import numpy as np
+from .blursharpen import LinearMotionBlur
+import cv2
+
+def apply_random_hsv_shift(img, rnd_state=None):
+    if rnd_state is None:
+        rnd_state = np.random
+
+    h, s, v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
+    h = ( h + rnd_state.randint(360) ) % 360
+    s = np.clip ( s + rnd_state.random()-0.5, 0, 1 )
+    v = np.clip ( v + rnd_state.random()/2-0.25, 0, 1 )
+    img = np.clip( cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR) , 0, 1 )
+    return img
+
+def apply_random_motion_blur( img, chance, mb_max_size, rnd_state=None ):
+    if rnd_state is None:
+        rnd_state = np.random
+
+    mblur_rnd_kernel = rnd_state.randint(mb_max_size)+1
+    mblur_rnd_deg    = rnd_state.randint(360)
+
+    if rnd_state.randint(100) < np.clip(chance, 0, 100):
+        img = LinearMotionBlur (img, mblur_rnd_kernel, mblur_rnd_deg )
+
+    return img
+
+def apply_random_gaussian_blur( img, chance, kernel_max_size, rnd_state=None ):
+    if rnd_state is None:
+        rnd_state = np.random
+
+    if rnd_state.randint(100) < np.clip(chance, 0, 100):
+        gblur_rnd_kernel = rnd_state.randint(kernel_max_size)*2+1
+        img = cv2.GaussianBlur(img, (gblur_rnd_kernel,)*2 , 0)
+
+    return img
+
+
+def apply_random_bilinear_resize( img, chance, max_size_per, rnd_state=None ):
+    if rnd_state is None:
+        rnd_state = np.random
+
+    if rnd_state.randint(100) < np.clip(chance, 0, 100):
+        h,w,c = img.shape
+
+        trg = rnd_state.rand()
+        rw = w - int( trg * int(w*(max_size_per/100.0)) )
+        rh = h - int( trg * int(h*(max_size_per/100.0)) )
+
+        img = cv2.resize (img, (rw,rh), cv2.INTER_LINEAR )
+        img = cv2.resize (img, (w,h), cv2.INTER_LINEAR )
+
+    return img
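Each of these helpers fires with a percent chance, so they chain naturally into a random-augmentation pass. A hedged sketch over a normalized float image (all parameter values illustrative):

    import numpy as np
    from core.imagelib import (apply_random_hsv_shift, apply_random_motion_blur,
                               apply_random_gaussian_blur, apply_random_bilinear_resize)

    rnd = np.random.RandomState(42)                  # seeded for reproducible augmentation
    img = rnd.rand(128, 128, 3).astype(np.float32)   # stand-in for a training sample in [0,1]

    img = apply_random_hsv_shift(img, rnd_state=rnd)
    img = apply_random_motion_blur(img, chance=25, mb_max_size=5, rnd_state=rnd)
    img = apply_random_gaussian_blur(img, chance=25, kernel_max_size=5, rnd_state=rnd)
    img = apply_random_bilinear_resize(img, chance=25, max_size_per=75, rnd_state=rnd)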
@@ -1,13 +1,16 @@
 import multiprocessing
 import os
 import sys
+import threading
 import time
 import types
 
 import colorama
 import cv2
 from tqdm import tqdm
 
+from core import stdex
+
 try:
     import IPython #if success we are in colab
     from IPython.display import display, clear_output
@@ -38,6 +41,8 @@ class InteractBase(object):
         self.focus_wnd_name = None
         self.error_log_line_prefix = '/!\\ '
 
+        self.process_messages_callbacks = {}
+
     def is_support_windows(self):
         return False
 
@@ -164,7 +169,21 @@ class InteractBase(object):
             self.pg_bar.close()
             self.pg_bar = None
 
+    def add_process_messages_callback(self, func ):
+        tid = threading.get_ident()
+        callbacks = self.process_messages_callbacks.get(tid, None)
+        if callbacks is None:
+            callbacks = []
+            self.process_messages_callbacks[tid] = callbacks
+
+        callbacks.append ( func )
+
     def process_messages(self, sleep_time=0):
+        callbacks = self.process_messages_callbacks.get(threading.get_ident(), None)
+        if callbacks is not None:
+            for func in callbacks:
+                func()
+
         self.on_process_messages(sleep_time)
 
     def wait_any_key(self):
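The new registry is keyed by thread id, so each thread only runs the callbacks it registered. A hedged sketch of how the MP helpers added below rely on it (the callback body is illustrative):

    from core.interact import interact as io

    def pump_pending_calls():
        # e.g. service queued requests coming from subprocess workers
        pass

    io.add_process_messages_callback(pump_pending_calls)
    io.process_messages(0.005)  # runs pump_pending_calls, then the normal message handling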
@@ -359,11 +378,11 @@ class InteractBase(object):
     def input_process(self, stdin_fd, sq, str):
         sys.stdin = os.fdopen(stdin_fd)
         try:
-            inp = input (str)
+            inp = input (str)
             sq.put (True)
         except:
-            sq.put (False)
-
+            sq.put (False)
+
     def input_in_time (self, str, max_time_sec):
         sq = multiprocessing.Queue()
         p = multiprocessing.Process(target=self.input_process, args=( sys.stdin.fileno(), sq, str))
@@ -377,14 +396,14 @@ class InteractBase(object):
                 break
             if time.time() - t > max_time_sec:
                 break
 
-        p.terminate()
-
 
+        p.terminate()
+        p.join()
 
         old_stdin = sys.stdin
         sys.stdin = os.fdopen( os.dup(sys.stdin.fileno()) )
-        old_stdin.close()
+        old_stdin.close()
         return inp
 
     def input_process_skip_pending(self, stdin_fd):
core/joblib/MPClassFuncOnDemand.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+import multiprocessing
+from core.interact import interact as io
+
+class MPClassFuncOnDemand():
+    def __init__(self, class_handle, class_func_name, **class_kwargs):
+        self.class_handle = class_handle
+        self.class_func_name = class_func_name
+        self.class_kwargs = class_kwargs
+
+        self.class_func = None
+
+        self.s2c = multiprocessing.Queue()
+        self.c2s = multiprocessing.Queue()
+        self.lock = multiprocessing.Lock()
+
+        io.add_process_messages_callback(self.io_callback)
+
+    def io_callback(self):
+        while not self.c2s.empty():
+            func_args, func_kwargs = self.c2s.get()
+            if self.class_func is None:
+                self.class_func = getattr( self.class_handle(**self.class_kwargs), self.class_func_name)
+            self.s2c.put ( self.class_func (*func_args, **func_kwargs) )
+
+    def __call__(self, *args, **kwargs):
+        with self.lock:
+            self.c2s.put ( (args, kwargs) )
+            return self.s2c.get()
+
+    def __getstate__(self):
+        return {'s2c':self.s2c, 'c2s':self.c2s, 'lock':self.lock}
core/joblib/MPFunc.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+import multiprocessing
+from core.interact import interact as io
+
+class MPFunc():
+    def __init__(self, func):
+        self.func = func
+
+        self.s2c = multiprocessing.Queue()
+        self.c2s = multiprocessing.Queue()
+        self.lock = multiprocessing.Lock()
+
+        io.add_process_messages_callback(self.io_callback)
+
+    def io_callback(self):
+        while not self.c2s.empty():
+            func_args, func_kwargs = self.c2s.get()
+            self.s2c.put ( self.func (*func_args, **func_kwargs) )
+
+    def __call__(self, *args, **kwargs):
+        with self.lock:
+            self.c2s.put ( (args, kwargs) )
+            return self.s2c.get()
+
+    def __getstate__(self):
+        return {'s2c':self.s2c, 'c2s':self.c2s, 'lock':self.lock}
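MPFunc and MPClassFuncOnDemand replace SubprocessFunctionCaller (deleted just below) with the same host-side call proxy: the real callable lives only in the host process, __getstate__ ships just the queues and lock to workers, and a worker's __call__ blocks on s2c until the host's io_callback services c2s from process_messages(). A hedged usage sketch (the lookup function is hypothetical):

    from core.joblib import MPFunc
    from core.interact import interact as io

    def lookup(x):                    # host-only resource, e.g. a loaded model
        return x * 2

    mp_lookup = MPFunc(lookup)        # construct in the host process
    # Pass mp_lookup into a Subprocessor worker; inside the worker:
    #     y = mp_lookup(21)           # round-trips through the host
    # while the host keeps pumping:
    #     io.process_messages(0.005)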
@@ -1,49 +0,0 @@
-import time
-import multiprocessing
-
-class SubprocessFunctionCaller(object):
-    class CliFunction(object):
-        def __init__(self, s2c, c2s, lock):
-            self.s2c = s2c
-            self.c2s = c2s
-            self.lock = lock
-
-        def __call__(self, *args, **kwargs):
-            self.lock.acquire()
-            self.c2s.put ( {'args':args, 'kwargs':kwargs} )
-            while True:
-                if not self.s2c.empty():
-                    obj = self.s2c.get()
-                    self.lock.release()
-                    return obj
-                time.sleep(0.005)
-
-    class HostProcessor(object):
-        def __init__(self, s2c, c2s, func):
-            self.s2c = s2c
-            self.c2s = c2s
-            self.func = func
-
-        def process_messages(self):
-            while not self.c2s.empty():
-                obj = self.c2s.get()
-                result = self.func ( *obj['args'], **obj['kwargs'] )
-                self.s2c.put (result)
-
-        def __getstate__(self):
-            #disable pickling this class
-            return dict()
-
-        def __setstate__(self, d):
-            self.__dict__.update(d)
-
-    @staticmethod
-    def make_pair(func):
-        s2c = multiprocessing.Queue()
-        c2s = multiprocessing.Queue()
-        lock = multiprocessing.Lock()
-
-        host_processor = SubprocessFunctionCaller.HostProcessor (s2c, c2s, func)
-        cli_func = SubprocessFunctionCaller.CliFunction (s2c, c2s, lock)
-
-        return host_processor, cli_func
@@ -1,4 +1,5 @@
 from .SubprocessorBase import Subprocessor
-from .SubprocessFunctionCaller import SubprocessFunctionCaller
 from .ThisThreadGenerator import ThisThreadGenerator
-from .SubprocessGenerator import SubprocessGenerator
+from .SubprocessGenerator import SubprocessGenerator
+from .MPFunc import MPFunc
+from .MPClassFuncOnDemand import MPClassFuncOnDemand
@@ -1,569 +0,0 @@
-
-def initialize_archis(nn):
-    tf = nn.tf
-
-    def get_ae_models(resolution):
-        lowest_dense_res = resolution // 16
-        conv_kernel_initializer = nn.initializers.ca()
-
-        class Downscale(nn.ModelBase):
-            def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
-                self.in_ch = in_ch
-                self.out_ch = out_ch
-                self.kernel_size = kernel_size
-                self.dilations = dilations
-                self.subpixel = subpixel
-                self.use_activator = use_activator
-                super().__init__(*kwargs)
-
-            def on_build(self, *args, **kwargs ):
-                self.conv1 = nn.Conv2D( self.in_ch,
-                                        self.out_ch // (4 if self.subpixel else 1),
-                                        kernel_size=self.kernel_size,
-                                        strides=1 if self.subpixel else 2,
-                                        padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
-
-            def forward(self, x):
-                x = self.conv1(x)
-                if self.subpixel:
-                    x = nn.tf_space_to_depth(x, 2)
-                if self.use_activator:
-                    x = tf.nn.leaky_relu(x, 0.1)
-                return x
-
-            def get_out_ch(self):
-                return (self.out_ch // 4) * 4
-
-        class DownscaleBlock(nn.ModelBase):
-            def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
-                self.downs = []
-
-                last_ch = in_ch
-                for i in range(n_downscales):
-                    cur_ch = ch*( min(2**i, 8) )
-                    self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
-                    last_ch = self.downs[-1].get_out_ch()
-
-            def forward(self, inp):
-                x = inp
-                for down in self.downs:
-                    x = down(x)
-                return x
-
-        class Upscale(nn.ModelBase):
-            def on_build(self, in_ch, out_ch, kernel_size=3 ):
-                self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-
-            def forward(self, x):
-                x = self.conv1(x)
-                x = tf.nn.leaky_relu(x, 0.1)
-                x = nn.tf_depth_to_space(x, 2)
-                return x
-
-        class ResidualBlock(nn.ModelBase):
-            def on_build(self, ch, kernel_size=3 ):
-                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-
-            def forward(self, inp):
-                x = self.conv1(inp)
-                x = tf.nn.leaky_relu(x, 0.2)
-                x = self.conv2(x)
-                x = tf.nn.leaky_relu(inp + x, 0.2)
-                return x
-
-        class UpdownResidualBlock(nn.ModelBase):
-            def on_build(self, ch, inner_ch, kernel_size=3 ):
-                self.up = Upscale (ch, inner_ch, kernel_size=kernel_size)
-                self.res = ResidualBlock (inner_ch, kernel_size=kernel_size)
-                self.down = Downscale (inner_ch, ch, kernel_size=kernel_size, use_activator=False)
-
-            def forward(self, inp):
-                x = self.up(inp)
-                x = upx = self.res(x)
-                x = self.down(x)
-                x = x + inp
-                x = tf.nn.leaky_relu(x, 0.2)
-                return x, upx
-
-        class Encoder(nn.ModelBase):
-            def on_build(self, in_ch, e_ch, is_hd):
-                self.is_hd=is_hd
-                if self.is_hd:
-                    self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, dilations=1)
-                    self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, dilations=1)
-                    self.down3 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=5, dilations=2)
-                    self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2)
-                else:
-                    self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False)
-
-            def forward(self, inp):
-                if self.is_hd:
-                    x = tf.concat([ nn.tf_flatten(self.down1(inp)),
-                                    nn.tf_flatten(self.down2(inp)),
-                                    nn.tf_flatten(self.down3(inp)),
-                                    nn.tf_flatten(self.down4(inp)) ], -1 )
-                else:
-                    x = nn.tf_flatten(self.down1(inp))
-                return x
-
-        class Inter(nn.ModelBase):
-            def __init__(self, in_ch, ae_ch, ae_out_ch, is_hd=False, **kwargs):
-                self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch
-                super().__init__(**kwargs)
-
-            def on_build(self):
-                in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch
-
-                self.dense1 = nn.Dense( in_ch, ae_ch )
-                self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
-                self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
-
-            def forward(self, inp):
-                x = self.dense1(inp)
-                x = self.dense2(x)
-                x = nn.tf_reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
-                x = self.upscale1(x)
-                return x
-
-            def get_out_ch(self):
-                return self.ae_out_ch
-
-        class Decoder(nn.ModelBase):
-            def on_build(self, in_ch, d_ch, d_mask_ch, is_hd ):
-                self.is_hd = is_hd
-
-                self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
-                self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
-                self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
-
-                if is_hd:
-                    self.res0 = UpdownResidualBlock(in_ch, d_ch*8, kernel_size=3)
-                    self.res1 = UpdownResidualBlock(d_ch*8, d_ch*4, kernel_size=3)
-                    self.res2 = UpdownResidualBlock(d_ch*4, d_ch*2, kernel_size=3)
-                    self.res3 = UpdownResidualBlock(d_ch*2, d_ch, kernel_size=3)
-                else:
-                    self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
-                    self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
-                    self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
-
-                self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
-
-                self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
-                self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
-                self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
-                self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
-
-            def get_weights_ex(self, include_mask):
-                # Call internal get_weights in order to initialize inner logic
-                self.get_weights()
-
-                weights = self.upscale0.get_weights() + self.upscale1.get_weights() + self.upscale2.get_weights() \
-                          + self.res0.get_weights() + self.res1.get_weights() + self.res2.get_weights() + self.out_conv.get_weights()
-
-                if self.is_hd:
-                    weights += self.res3.get_weights()
-
-                if include_mask:
-                    weights += self.upscalem0.get_weights() + self.upscalem1.get_weights() + self.upscalem2.get_weights() \
-                               + self.out_convm.get_weights()
-                return weights
-
-
-            def forward(self, inp):
-                z = inp
-
-                if self.is_hd:
-                    x, upx = self.res0(z)
-                    x = self.upscale0(x)
-                    x = tf.nn.leaky_relu(x + upx, 0.2)
-                    x, upx = self.res1(x)
-
-                    x = self.upscale1(x)
-                    x = tf.nn.leaky_relu(x + upx, 0.2)
-                    x, upx = self.res2(x)
-
-                    x = self.upscale2(x)
-                    x = tf.nn.leaky_relu(x + upx, 0.2)
-                    x, upx = self.res3(x)
-                else:
-                    x = self.upscale0(z)
-                    x = self.res0(x)
-                    x = self.upscale1(x)
-                    x = self.res1(x)
-                    x = self.upscale2(x)
-                    x = self.res2(x)
-
-                m = self.upscalem0(z)
-                m = self.upscalem1(m)
-                m = self.upscalem2(m)
-
-                return tf.nn.sigmoid(self.out_conv(x)), \
-                       tf.nn.sigmoid(self.out_convm(m))
-
-        return lowest_dense_res, Encoder, Inter, Decoder
-
-    nn.get_ae_models = get_ae_models
-
-    def get_ae_models_chervonij(resolution):
-        lowest_dense_res = resolution // 32
-        """
-        by @chervonij
-        """
-        conv_kernel_initializer = nn.initializers.ca()
-
-        class Downscale(nn.ModelBase):
-            def __init__(self, in_ch, kernel_size=3, dilations=1, *kwargs ):
-                self.in_ch = in_ch
-                self.kernel_size = kernel_size
-                self.dilations = dilations
-                super().__init__(*kwargs)
-
-            def on_build(self, *args, **kwargs ):
-                self.conv_base1 = nn.Conv2D( self.in_ch, self.in_ch//2, kernel_size=1, strides=1, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
-                self.conv_l1 = nn.Conv2D( self.in_ch//2, self.in_ch//2, kernel_size=self.kernel_size, strides=1, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
-                self.conv_l2 = nn.Conv2D( self.in_ch//2, self.in_ch//2, kernel_size=self.kernel_size, strides=2, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
-
-                self.conv_base2 = nn.Conv2D( self.in_ch, self.in_ch//2, kernel_size=1, strides=1, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
-                self.conv_r1 = nn.Conv2D( self.in_ch//2, self.in_ch//2, kernel_size=self.kernel_size, strides=2, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
-
-                self.pool_size = [1,1,2,2] if nn.data_format == 'NCHW' else [1,2,2,1]
-
-            def forward(self, x):
-                x_l = self.conv_base1(x)
-                x_l = self.conv_l1(x_l)
-                x_l = self.conv_l2(x_l)
-
-                x_r = self.conv_base2(x)
-                x_r = self.conv_r1(x_r)
-
-                x_pool = tf.nn.max_pool(x, ksize=self.pool_size, strides=self.pool_size, padding='SAME', data_format=nn.data_format)
-
-                x = tf.concat([x_l, x_r, x_pool], axis=nn.conv2d_ch_axis)
-                x = nn.tf_gelu(x)
-                return x
-
-        class Upscale(nn.ModelBase):
-            def on_build(self, in_ch, out_ch, kernel_size=3 ):
-                self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-                self.conv2 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-                self.conv3 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-                self.conv4 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-
-            def forward(self, x):
-                x0 = self.conv1(x)
-                x1 = self.conv2(x0)
-                x2 = self.conv3(x1)
-                x3 = self.conv4(x2)
-                x = tf.concat([x0, x1, x2, x3], axis=nn.conv2d_ch_axis)
-                x = nn.tf_gelu(x)
-                x = nn.tf_depth_to_space(x, 2)
-                return x
-
-        class ResidualBlock(nn.ModelBase):
-            def on_build(self, ch, kernel_size=3 ):
-                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-
-            def forward(self, inp):
-                x = self.conv1(inp)
-                x = tf.nn.leaky_relu(x, 0.2)
-                x = self.conv2(x)
-                x = tf.nn.leaky_relu(inp + x, 0.2)
-                return x
-
-        class Encoder(nn.ModelBase):
-            def on_build(self, in_ch, e_ch, **kwargs):
-                self.conv0 = nn.Conv2D(in_ch, e_ch, kernel_size=3, padding='SAME', kernel_initializer=conv_kernel_initializer)
-
-                self.down0 = Downscale(e_ch)
-                self.down1 = Downscale(e_ch*2)
-                self.down2 = Downscale(e_ch*4)
-                self.down3 = Downscale(e_ch*8)
-                self.down4 = Downscale(e_ch*16)
-
-            def forward(self, inp):
-                x = self.conv0(inp)
-                x = self.down0(x)
-                x = self.down1(x)
-                x = self.down2(x)
-                x = self.down3(x)
-                x = self.down4(x)
-                x = nn.tf_flatten(x)
-                return x
-
-        class Inter(nn.ModelBase):
-            def __init__(self, in_ch, ae_ch, ae_out_ch, **kwargs):
-                self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch
-                super().__init__(**kwargs)
-
-            def on_build(self, **kwargs):
-                in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch
-
-                self.dense_l = nn.Dense( in_ch, ae_ch//2, kernel_initializer=tf.initializers.orthogonal)
-                self.dense_r = nn.Dense( in_ch, ae_ch//2, maxout_features=4, kernel_initializer=tf.initializers.orthogonal)
-                self.dense = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * (ae_out_ch//2), kernel_initializer=tf.initializers.orthogonal)
-                self.upscale1 = Upscale(ae_out_ch//2, ae_out_ch//2)
-
-            def forward(self, inp):
-                x0 = self.dense_l(inp)
-                x1 = self.dense_r(inp)
-                x = tf.concat([x0, x1], axis=-1)
-                x = self.dense(x)
-                x = nn.tf_reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch//2)
-                x = self.upscale1(x)
-                return x
-
-            def get_out_ch(self):
-                return self.ae_out_ch//2
-
-        class Decoder(nn.ModelBase):
-            def on_build(self, in_ch, d_ch, d_mask_ch, **kwargs):
-
-                self.upscale0 = Upscale(in_ch, d_ch*8)
-                self.upscale1 = Upscale(d_ch*8, d_ch*4)
-                self.upscale2 = Upscale(d_ch*4, d_ch*2)
-                self.upscale3 = Upscale(d_ch*2, d_ch)
-
-                self.out_conv = nn.Conv2D( d_ch, 3, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
-
-                self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
-                self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
-                self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
-                self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch, kernel_size=3)
-                self.out_convm = nn.Conv2D( d_mask_ch, 1, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
-
-            def get_weights_ex(self, include_mask):
-                # Call internal get_weights in order to initialize inner logic
-                self.get_weights()
-
-                weights = self.upscale0.get_weights() + self.upscale1.get_weights() + self.upscale2.get_weights() + self.upscale3.get_weights() + self.out_conv.get_weights()
-
-                if include_mask:
-                    weights += self.upscalem0.get_weights() + self.upscalem1.get_weights() + self.upscalem2.get_weights() + self.upscale3.get_weights() + self.out_convm.get_weights()
-                return weights
-
-
-            def forward(self, inp):
-                z = inp
-
-                x = self.upscale0(inp)
-                x = self.upscale1(x)
-                x = self.upscale2(x)
-                x = self.upscale3(x)
-
-                m = self.upscalem0(z)
-                m = self.upscalem1(m)
-                m = self.upscalem2(m)
-                m = self.upscalem3(m)
-
-                return tf.nn.sigmoid(self.out_conv(x)), \
-                       tf.nn.sigmoid(self.out_convm(m))
-
-        return lowest_dense_res, Encoder, Inter, Decoder
-
-    nn.get_ae_models_chervonij = get_ae_models_chervonij
-
-    """
-    def get_ae_models2():
-        conv_kernel_initializer = nn.initializers.ca()
-        class Downscale(nn.ModelBase):
-            def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, is_hd=False, *kwargs ):
-                self.in_ch = in_ch
-                self.out_ch = out_ch
-                self.kernel_size = kernel_size
-                self.dilations = dilations
-                self.is_hd = is_hd
-                super().__init__(*kwargs)
-
-            def on_build(self, *args, **kwargs ):
-                if not self.is_hd:
-                    self.conv1 = nn.Conv2D( self.in_ch, self.out_ch, kernel_size=self.kernel_size, strides=2, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
-                else:
-                    self.conv0 = nn.Conv2D( self.in_ch, self.out_ch//4, kernel_size=self.kernel_size, strides=2, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
-                    self.conv1 = nn.Conv2D( self.out_ch//4, self.out_ch//4, kernel_size=self.kernel_size, strides=1, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
-                    self.conv2 = nn.Conv2D( self.out_ch//4, self.out_ch//4, kernel_size=self.kernel_size, strides=1, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
-                    self.conv3 = nn.Conv2D( self.out_ch//4, self.out_ch//4, kernel_size=self.kernel_size, strides=1, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
-
-
-            def forward(self, x):
-                if not self.is_hd:
-                    x = self.conv1(x)
-                else:
-                    x = x0 = self.conv0(x)
-                    x = x1 = self.conv1(x)
-                    x = x2 = self.conv2(x)
-                    x = x3 = self.conv3(x)
-                    x = tf.concat([x0,x1,x2,x3], axis=nn.conv2d_ch_axis)
-                x = tf.nn.leaky_relu(x, 0.1)
-                return x
-
-            def get_out_ch(self):
-                return (self.out_ch // 4) * 4
-
-        class DownscaleBlock(nn.ModelBase):
-            def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, is_hd=False):
-                self.downs = []
-
-                last_ch = in_ch
-                for i in range(n_downscales):
-                    cur_ch = ch*( min(2**i, 8) )
-                    self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, is_hd=is_hd) )
-                    last_ch = self.downs[-1].get_out_ch()
-
-            def forward(self, inp):
-                x = inp
-                for down in self.downs:
-                    x = down(x)
-                return x
-
-        class Upscale(nn.ModelBase):
-            def on_build(self, in_ch, out_ch, kernel_size=3, is_hd=False ):
-                self.is_hd = is_hd
-
-                if not self.is_hd:
-                    self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-                else:
-                    self.conv0 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-                    self.conv1 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-                    self.conv2 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-                    self.conv3 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-
-            def forward(self, x):
-                if not self.is_hd:
-                    x = self.conv1(x)
-                else:
-                    x = x0 = self.conv0(x)
-                    x = x1 = self.conv1(x)
-                    x = x2 = self.conv2(x)
-                    x = x3 = self.conv3(x)
-                    x = tf.concat([x0,x1,x2,x3], axis=nn.conv2d_ch_axis)
-
-                x = tf.nn.leaky_relu(x, 0.1)
-                x = nn.tf_depth_to_space(x, 2)
-                return x
-
-        class ResidualBlock(nn.ModelBase):
-            def on_build(self, ch, kernel_size=3, is_hd=False ):
-                self.is_hd = is_hd
-                if not is_hd:
-                    self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-                    self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-                else:
-                    self.conv10 = nn.Conv2D( ch, ch//4, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-                    self.conv11 = nn.Conv2D( ch//4, ch//4, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-                    self.conv12 = nn.Conv2D( ch//4, ch//4, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-                    self.conv13 = nn.Conv2D( ch//4, ch//4, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-
-                    self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
-
-            def forward(self, inp):
-                if not self.is_hd:
-                    x = self.conv1(inp)
-                else:
-                    x = x0 = self.conv10(inp)
-                    x = x1 = self.conv11(x)
-                    x = x2 = self.conv12(x)
-                    x = x3 = self.conv13(x)
-                    x = tf.concat([x0,x1,x2,x3], axis=nn.conv2d_ch_axis)
-
-                x = tf.nn.leaky_relu(x, 0.2)
-                x = self.conv2(x)
-                x = tf.nn.leaky_relu(inp + x, 0.2)
-                return x
-
-        class Encoder(nn.ModelBase):
-            def on_build(self, in_ch, e_ch, is_hd):
-                self.is_hd=is_hd
-                if self.is_hd:
-                    self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, is_hd=is_hd)
-                    self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, is_hd=is_hd)
-                else:
-                    self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, is_hd=is_hd)
-
-            def forward(self, inp):
-                if self.is_hd:
-                    x = tf.concat([ nn.tf_flatten(self.down1(inp)),
-                                    nn.tf_flatten(self.down2(inp)) ], -1 )
-                else:
-                    x = nn.tf_flatten(self.down1(inp))
-                return x
-
-        class Inter(nn.ModelBase):
-            def __init__(self, in_ch, lowest_dense_res, ae_ch, ae_out_ch, is_hd=False, **kwargs):
-                self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch, self.is_hd = in_ch, lowest_dense_res, ae_ch, ae_out_ch, is_hd
-                super().__init__(**kwargs)
-
-            def on_build(self):
-                in_ch, lowest_dense_res, ae_ch, ae_out_ch = self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch
-
-                self.dense1 = nn.Dense( in_ch, ae_ch )
-                self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
-                self.upscale1 = Upscale(ae_out_ch, ae_out_ch, is_hd=self.is_hd)
-
-            def forward(self, inp):
-                x = self.dense1(inp)
-                x = self.dense2(x)
-                x = nn.tf_reshape_4D (x, self.lowest_dense_res, self.lowest_dense_res, self.ae_out_ch)
-                x = self.upscale1(x)
-                return x
-
-            def get_out_ch(self):
-                return self.ae_out_ch
-
-        class Decoder(nn.ModelBase):
-            def on_build(self, in_ch, d_ch, d_mask_ch, is_hd ):
-                self.is_hd = is_hd
-
-                self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
-                self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
-                self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
-
-                self.res0 = ResidualBlock(d_ch*8, kernel_size=3, is_hd=is_hd)
-                self.res1 = ResidualBlock(d_ch*4, kernel_size=3, is_hd=is_hd)
-                self.res2 = ResidualBlock(d_ch*2, kernel_size=3, is_hd=is_hd)
-
-                self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
-
-                self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
-                self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
-                self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
-                self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
-
-            def get_weights_ex(self, include_mask):
-                # Call internal get_weights in order to initialize inner logic
-                self.get_weights()
-
-                weights = self.upscale0.get_weights() + self.upscale1.get_weights() + self.upscale2.get_weights() \
-                          + self.res0.get_weights() + self.res1.get_weights() + self.res2.get_weights() + self.out_conv.get_weights()
-
-                if include_mask:
-                    weights += self.upscalem0.get_weights() + self.upscalem1.get_weights() + self.upscalem2.get_weights() \
-                               + self.out_convm.get_weights()
-                return weights
-
-
-            def forward(self, inp):
-                z = inp
-
-                x = self.upscale0(z)
-                x = self.res0(x)
-                x = self.upscale1(x)
-                x = self.res1(x)
-                x = self.upscale2(x)
-                x = self.res2(x)
-
-                m = self.upscalem0(z)
-                m = self.upscalem1(m)
-                m = self.upscalem2(m)
-
-                return tf.nn.sigmoid(self.out_conv(x)), \
-                       tf.nn.sigmoid(self.out_convm(m))
-
-        return Encoder, Inter, Decoder
-
-    nn.get_ae_models2 = get_ae_models2
-    """
core/leras/archis/ArchiBase.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+from core.leras import nn
+
+class ArchiBase():
+
+    def __init__(self, *args, name=None, **kwargs):
+        self.name=name
+
+
+    #overridable
+    def flow(self, *args, **kwargs):
+        raise Exception("this archi does not support flow. Use model classes directly.")
+
+    #overridable
+    def get_weights(self):
+        pass
+
+nn.ArchiBase = ArchiBase
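ArchiBase is a namespace-style container; concrete archis subclass it and attach their nn.ModelBase classes as attributes, as the two files below do. A minimal hedged sketch of the pattern (the subclass and its contents are hypothetical):

    from core.leras import nn

    class MyArchi(nn.ArchiBase):
        def __init__(self):
            super().__init__()

            class Net(nn.ModelBase):
                def on_build(self, in_ch, out_ch):
                    self.conv = nn.Conv2D(in_ch, out_ch, kernel_size=3, padding='SAME')
                def forward(self, x):
                    return self.conv(x)

            self.Net = Net   # exposed for callers, like DFLSegnetArchi.Encoder below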
core/leras/archis/DFLSegnet.py (new file, 151 lines)
@@ -0,0 +1,151 @@
+from core.leras import nn
+tf = nn.tf
+
+class DFLSegnetArchi(nn.ArchiBase):
+    def __init__(self):
+        super().__init__()
+
+        class ConvBlock(nn.ModelBase):
+            def on_build(self, in_ch, out_ch):
+                self.conv = nn.Conv2D (in_ch, out_ch, kernel_size=3, padding='SAME')
+                self.frn = nn.FRNorm2D(out_ch)
+                self.tlu = nn.TLU(out_ch)
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.frn(x)
+                x = self.tlu(x)
+                return x
+
+        class UpConvBlock(nn.ModelBase):
+            def on_build(self, in_ch, out_ch):
+                self.conv = nn.Conv2DTranspose (in_ch, out_ch, kernel_size=3, padding='SAME')
+                self.frn = nn.FRNorm2D(out_ch)
+                self.tlu = nn.TLU(out_ch)
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.frn(x)
+                x = self.tlu(x)
+                return x
+
+        class Encoder(nn.ModelBase):
+            def on_build(self, in_ch, base_ch):
+                self.conv01 = ConvBlock(in_ch, base_ch)
+                self.conv02 = ConvBlock(base_ch, base_ch)
+                self.bp0 = nn.BlurPool (filt_size=3)
+
+                self.conv11 = ConvBlock(base_ch, base_ch*2)
+                self.conv12 = ConvBlock(base_ch*2, base_ch*2)
+                self.bp1 = nn.BlurPool (filt_size=3)
+
+                self.conv21 = ConvBlock(base_ch*2, base_ch*4)
+                self.conv22 = ConvBlock(base_ch*4, base_ch*4)
+                self.conv23 = ConvBlock(base_ch*4, base_ch*4)
+                self.bp2 = nn.BlurPool (filt_size=3)
+
+                self.conv31 = ConvBlock(base_ch*4, base_ch*8)
+                self.conv32 = ConvBlock(base_ch*8, base_ch*8)
+                self.conv33 = ConvBlock(base_ch*8, base_ch*8)
+                self.bp3 = nn.BlurPool (filt_size=3)
+
+                self.conv41 = ConvBlock(base_ch*8, base_ch*8)
+                self.conv42 = ConvBlock(base_ch*8, base_ch*8)
+                self.conv43 = ConvBlock(base_ch*8, base_ch*8)
+                self.bp4 = nn.BlurPool (filt_size=3)
+
+                self.conv_center = ConvBlock(base_ch*8, base_ch*8)
+
+            def forward(self, inp):
+                x = inp
+
+                x = self.conv01(x)
+                x = x0 = self.conv02(x)
+                x = self.bp0(x)
+
+                x = self.conv11(x)
+                x = x1 = self.conv12(x)
+                x = self.bp1(x)
+
+                x = self.conv21(x)
+                x = self.conv22(x)
+                x = x2 = self.conv23(x)
+                x = self.bp2(x)
+
+                x = self.conv31(x)
+                x = self.conv32(x)
+                x = x3 = self.conv33(x)
+                x = self.bp3(x)
+
+                x = self.conv41(x)
+                x = self.conv42(x)
+                x = x4 = self.conv43(x)
+                x = self.bp4(x)
+
+                x = self.conv_center(x)
+
+                return x0,x1,x2,x3,x4, x
+
+        class Decoder(nn.ModelBase):
+            def on_build(self, base_ch, out_ch):
+
+                self.up4 = UpConvBlock (base_ch*8, base_ch*4)
+                self.conv43 = ConvBlock(base_ch*12, base_ch*8)
+                self.conv42 = ConvBlock(base_ch*8, base_ch*8)
+                self.conv41 = ConvBlock(base_ch*8, base_ch*8)
+
+                self.up3 = UpConvBlock (base_ch*8, base_ch*4)
+                self.conv33 = ConvBlock(base_ch*12, base_ch*8)
+                self.conv32 = ConvBlock(base_ch*8, base_ch*8)
+                self.conv31 = ConvBlock(base_ch*8, base_ch*8)
+
+                self.up2 = UpConvBlock (base_ch*8, base_ch*4)
+                self.conv23 = ConvBlock(base_ch*8, base_ch*4)
+                self.conv22 = ConvBlock(base_ch*4, base_ch*4)
+                self.conv21 = ConvBlock(base_ch*4, base_ch*4)
+
+                self.up1 = UpConvBlock (base_ch*4, base_ch*2)
+                self.conv12 = ConvBlock(base_ch*4, base_ch*2)
+                self.conv11 = ConvBlock(base_ch*2, base_ch*2)
+
+                self.up0 = UpConvBlock (base_ch*2, base_ch)
+                self.conv02 = ConvBlock(base_ch*2, base_ch)
+                self.conv01 = ConvBlock(base_ch, base_ch)
+                self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME')
+
+            def forward(self, inp):
+                x0,x1,x2,x3,x4,x = inp
+
+                x = self.up4(x)
+                x = self.conv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis))
+                x = self.conv42(x)
+                x = self.conv41(x)
+
+                x = self.up3(x)
+                x = self.conv33(tf.concat([x,x3],axis=nn.conv2d_ch_axis))
+                x = self.conv32(x)
+                x = self.conv31(x)
+
+                x = self.up2(x)
+                x = self.conv23(tf.concat([x,x2],axis=nn.conv2d_ch_axis))
+                x = self.conv22(x)
+                x = self.conv21(x)
+
+                x = self.up1(x)
+                x = self.conv12(tf.concat([x,x1],axis=nn.conv2d_ch_axis))
+                x = self.conv11(x)
+
+                x = self.up0(x)
+                x = self.conv02(tf.concat([x,x0],axis=nn.conv2d_ch_axis))
+                x = self.conv01(x)
+
+                logits = self.out_conv(x)
+                return logits, tf.nn.sigmoid(logits)
+
+        self.Encoder = Encoder
+        self.Decoder = Decoder
+
+nn.DFLSegnetArchi = DFLSegnetArchi
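The encoder/decoder pair is a U-Net-style segmenter: the encoder returns five skip tensors plus the bottleneck, and the decoder concatenates each skip back in. A hedged sketch of instantiating it (channel counts illustrative; assumes the leras backend has already been initialized):

    from core.leras import nn

    archi = nn.DFLSegnetArchi()
    enc = archi.Encoder(3, 64, name='enc')   # in_ch=3, base_ch=64
    dec = archi.Decoder(64, 1, name='dec')   # base_ch=64, out_ch=1 (mask logits)
    # dec(enc(x)) returns (logits, sigmoid(logits)) for a batch of images x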
core/leras/archis/DeepFakeArchi.py (new file, 490 lines; listing truncated below)
@ -0,0 +1,490 @@
|
from core.leras import nn
tf = nn.tf

class DeepFakeArchi(nn.ArchiBase):
    """
    resolution

    mod     None - default
            'chervonij'
            'quick'
    """
    def __init__(self, resolution, mod=None):
        super().__init__()

        if mod is None:
            class Downscale(nn.ModelBase):
                def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, **kwargs ):
                    self.in_ch = in_ch
                    self.out_ch = out_ch
                    self.kernel_size = kernel_size
                    self.dilations = dilations
                    self.subpixel = subpixel
                    self.use_activator = use_activator
                    super().__init__(**kwargs)

                def on_build(self, *args, **kwargs ):
                    self.conv1 = nn.Conv2D( self.in_ch,
                                            self.out_ch // (4 if self.subpixel else 1),
                                            kernel_size=self.kernel_size,
                                            strides=1 if self.subpixel else 2,
                                            padding='SAME', dilations=self.dilations)

                def forward(self, x):
                    x = self.conv1(x)
                    if self.subpixel:
                        x = nn.space_to_depth(x, 2)
                    if self.use_activator:
                        x = tf.nn.leaky_relu(x, 0.1)
                    return x

                def get_out_ch(self):
                    return (self.out_ch // 4) * 4

            class DownscaleBlock(nn.ModelBase):
                def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
                    self.downs = []

                    last_ch = in_ch
                    for i in range(n_downscales):
                        cur_ch = ch*( min(2**i, 8) )
                        self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
                        last_ch = self.downs[-1].get_out_ch()

                def forward(self, inp):
                    x = inp
                    for down in self.downs:
                        x = down(x)
                    return x

            class Upscale(nn.ModelBase):
                def on_build(self, in_ch, out_ch, kernel_size=3 ):
                    self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')

                def forward(self, x):
                    x = self.conv1(x)
                    x = tf.nn.leaky_relu(x, 0.1)
                    x = nn.depth_to_space(x, 2)
                    return x

            class ResidualBlock(nn.ModelBase):
                def on_build(self, ch, kernel_size=3 ):
                    self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
                    self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')

                def forward(self, inp):
                    x = self.conv1(inp)
                    x = tf.nn.leaky_relu(x, 0.2)
                    x = self.conv2(x)
                    x = tf.nn.leaky_relu(inp + x, 0.2)
                    return x

            class UpdownResidualBlock(nn.ModelBase):
                def on_build(self, ch, inner_ch, kernel_size=3 ):
                    self.up   = Upscale (ch, inner_ch, kernel_size=kernel_size)
                    self.res  = ResidualBlock (inner_ch, kernel_size=kernel_size)
                    self.down = Downscale (inner_ch, ch, kernel_size=kernel_size, use_activator=False)

                def forward(self, inp):
                    x = self.up(inp)
                    x = upx = self.res(x)
                    x = self.down(x)
                    x = x + inp
                    x = tf.nn.leaky_relu(x, 0.2)
                    return x, upx

            class Encoder(nn.ModelBase):
                def on_build(self, in_ch, e_ch, is_hd):
                    self.is_hd = is_hd
                    if self.is_hd:
                        self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, dilations=1)
                        self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, dilations=1)
                        self.down3 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=5, dilations=2)
                        self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2)
                    else:
                        self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False)

                def forward(self, inp):
                    if self.is_hd:
                        x = tf.concat([ nn.flatten(self.down1(inp)),
                                        nn.flatten(self.down2(inp)),
                                        nn.flatten(self.down3(inp)),
                                        nn.flatten(self.down4(inp)) ], -1 )
                    else:
                        x = nn.flatten(self.down1(inp))
                    return x

            lowest_dense_res = resolution // 16

            class Inter(nn.ModelBase):
                def __init__(self, in_ch, ae_ch, ae_out_ch, is_hd=False, **kwargs):
                    self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch
                    super().__init__(**kwargs)

                def on_build(self):
                    in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch

                    self.dense1 = nn.Dense( in_ch, ae_ch )
                    self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
                    self.upscale1 = Upscale(ae_out_ch, ae_out_ch)

                def forward(self, inp):
                    x = self.dense1(inp)
                    x = self.dense2(x)
                    x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
                    x = self.upscale1(x)
                    return x

                @staticmethod
                def get_code_res():
                    return lowest_dense_res

                def get_out_ch(self):
                    return self.ae_out_ch

            class Decoder(nn.ModelBase):
                def on_build(self, in_ch, d_ch, d_mask_ch, is_hd ):
                    self.is_hd = is_hd

                    self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
                    self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
                    self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)

                    if is_hd:
                        self.res0 = UpdownResidualBlock(in_ch, d_ch*8, kernel_size=3)
                        self.res1 = UpdownResidualBlock(d_ch*8, d_ch*4, kernel_size=3)
                        self.res2 = UpdownResidualBlock(d_ch*4, d_ch*2, kernel_size=3)
                        self.res3 = UpdownResidualBlock(d_ch*2, d_ch, kernel_size=3)
                    else:
                        self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
                        self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
                        self.res2 = ResidualBlock(d_ch*2, kernel_size=3)

                    self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME')

                    self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
                    self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
                    self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
                    self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')

                def forward(self, inp):
                    z = inp

                    if self.is_hd:
                        x, upx = self.res0(z)
                        x = self.upscale0(x)
                        x = tf.nn.leaky_relu(x + upx, 0.2)
                        x, upx = self.res1(x)

                        x = self.upscale1(x)
                        x = tf.nn.leaky_relu(x + upx, 0.2)
                        x, upx = self.res2(x)

                        x = self.upscale2(x)
                        x = tf.nn.leaky_relu(x + upx, 0.2)
                        x, upx = self.res3(x)
                    else:
                        x = self.upscale0(z)
                        x = self.res0(x)
                        x = self.upscale1(x)
                        x = self.res1(x)
                        x = self.upscale2(x)
                        x = self.res2(x)

                    m = self.upscalem0(z)
                    m = self.upscalem1(m)
                    m = self.upscalem2(m)

                    return tf.nn.sigmoid(self.out_conv(x)), \
                           tf.nn.sigmoid(self.out_convm(m))
        elif mod == 'chervonij':
            class Downscale(nn.ModelBase):
                def __init__(self, in_ch, kernel_size=3, dilations=1, **kwargs ):
                    self.in_ch = in_ch
                    self.kernel_size = kernel_size
                    self.dilations = dilations
                    super().__init__(**kwargs)

                def on_build(self, *args, **kwargs ):
                    self.conv_base1 = nn.Conv2D( self.in_ch, self.in_ch//2, kernel_size=1, strides=1, padding='SAME', dilations=self.dilations)
                    self.conv_l1 = nn.Conv2D( self.in_ch//2, self.in_ch//2, kernel_size=self.kernel_size, strides=1, padding='SAME', dilations=self.dilations)
                    self.conv_l2 = nn.Conv2D( self.in_ch//2, self.in_ch//2, kernel_size=self.kernel_size, strides=2, padding='SAME', dilations=self.dilations)

                    self.conv_base2 = nn.Conv2D( self.in_ch, self.in_ch//2, kernel_size=1, strides=1, padding='SAME', dilations=self.dilations)
                    self.conv_r1 = nn.Conv2D( self.in_ch//2, self.in_ch//2, kernel_size=self.kernel_size, strides=2, padding='SAME', dilations=self.dilations)

                    self.pool_size = [1,1,2,2] if nn.data_format == 'NCHW' else [1,2,2,1]

                def forward(self, x):
                    x_l = self.conv_base1(x)
                    x_l = self.conv_l1(x_l)
                    x_l = self.conv_l2(x_l)

                    x_r = self.conv_base2(x)
                    x_r = self.conv_r1(x_r)

                    x_pool = tf.nn.max_pool(x, ksize=self.pool_size, strides=self.pool_size, padding='SAME', data_format=nn.data_format)

                    x = tf.concat([x_l, x_r, x_pool], axis=nn.conv2d_ch_axis)
                    x = tf.nn.leaky_relu(x, 0.1)
                    return x

            class Upscale(nn.ModelBase):
                def on_build(self, in_ch, out_ch, kernel_size=3 ):
                    self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, padding='SAME')
                    self.conv2 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME')
                    self.conv3 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME')
                    self.conv4 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME')

                def forward(self, x):
                    x0 = self.conv1(x)
                    x1 = self.conv2(x0)
                    x2 = self.conv3(x1)
                    x3 = self.conv4(x2)
                    x = tf.concat([x0, x1, x2, x3], axis=nn.conv2d_ch_axis)
                    x = tf.nn.leaky_relu(x, 0.1)
                    x = nn.depth_to_space(x, 2)
                    return x

            class ResidualBlock(nn.ModelBase):
                def on_build(self, ch, kernel_size=3 ):
                    self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
                    self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
                    self.norm = nn.FRNorm2D(ch)

                def forward(self, inp):
                    x = self.conv1(inp)
                    x = tf.nn.leaky_relu(x, 0.2)
                    x = self.conv2(x)
                    x = self.norm(inp + x)
                    x = tf.nn.leaky_relu(x, 0.2)
                    return x

            class Encoder(nn.ModelBase):
                def on_build(self, in_ch, e_ch, **kwargs):
                    self.conv0 = nn.Conv2D(in_ch, e_ch, kernel_size=3, padding='SAME')

                    self.down0 = Downscale(e_ch)
                    self.down1 = Downscale(e_ch*2)
                    self.down2 = Downscale(e_ch*4)
                    self.down3 = Downscale(e_ch*8)
                    self.down4 = Downscale(e_ch*16)

                def forward(self, inp):
                    x = self.conv0(inp)
                    x = self.down0(x)
                    x = self.down1(x)
                    x = self.down2(x)
                    x = self.down3(x)
                    x = self.down4(x)
                    x = nn.flatten(x)
                    return x

            lowest_dense_res = resolution // 32

            class Inter(nn.ModelBase):
                def __init__(self, in_ch, ae_ch, ae_out_ch, **kwargs):
                    self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch
                    super().__init__(**kwargs)

                def on_build(self, **kwargs):
                    in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch

                    self.dense_l = nn.Dense( in_ch, ae_ch//2, kernel_initializer=tf.initializers.orthogonal)
                    self.dense_r = nn.Dense( in_ch, ae_ch//2, kernel_initializer=tf.initializers.orthogonal) #maxout_ch=4,
                    self.dense = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * (ae_out_ch//2), kernel_initializer=tf.initializers.orthogonal)
                    self.upscale1 = Upscale(ae_out_ch//2, ae_out_ch//2)

                def forward(self, inp):
                    x0 = self.dense_l(inp)
                    x1 = self.dense_r(inp)
                    x = tf.concat([x0, x1], axis=-1)
                    x = self.dense(x)
                    x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch//2)
                    x = self.upscale1(x)

                    return x

                def get_out_ch(self):
                    return self.ae_out_ch//2

            class Decoder(nn.ModelBase):
                def on_build(self, in_ch, d_ch, d_mask_ch, **kwargs):
                    self.upscale0 = Upscale(in_ch, d_ch*8)
                    self.upscale1 = Upscale(d_ch*8, d_ch*4)
                    self.upscale2 = Upscale(d_ch*4, d_ch*2)
                    self.upscale3 = Upscale(d_ch*2, d_ch)

                    self.res0 = ResidualBlock(d_ch*8)
                    self.res1 = ResidualBlock(d_ch*4)
                    self.res2 = ResidualBlock(d_ch*2)
                    self.res3 = ResidualBlock(d_ch)

                    self.out_conv = nn.Conv2D( d_ch, 3, kernel_size=1, padding='SAME')

                    self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
                    self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
                    self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
                    self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch, kernel_size=3)
                    self.out_convm = nn.Conv2D( d_mask_ch, 1, kernel_size=1, padding='SAME')

                def forward(self, inp):
                    z = inp

                    x = self.upscale0(z)
                    x = self.res0(x)
                    x = self.upscale1(x)
                    x = self.res1(x)
                    x = self.upscale2(x)
                    x = self.res2(x)
                    x = self.upscale3(x)
                    x = self.res3(x)

                    m = self.upscalem0(z)
                    m = self.upscalem1(m)
                    m = self.upscalem2(m)
                    m = self.upscalem3(m)

                    return tf.nn.sigmoid(self.out_conv(x)), \
                           tf.nn.sigmoid(self.out_convm(m))
        elif mod == 'quick':
            class Downscale(nn.ModelBase):
                def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, **kwargs ):
                    self.in_ch = in_ch
                    self.out_ch = out_ch
                    self.kernel_size = kernel_size
                    self.dilations = dilations
                    self.subpixel = subpixel
                    self.use_activator = use_activator
                    super().__init__(**kwargs)

                def on_build(self, *args, **kwargs ):
                    self.conv1 = nn.Conv2D( self.in_ch,
                                            self.out_ch // (4 if self.subpixel else 1),
                                            kernel_size=self.kernel_size,
                                            strides=1 if self.subpixel else 2,
                                            padding='SAME', dilations=self.dilations )

                def forward(self, x):
                    x = self.conv1(x)

                    if self.subpixel:
                        x = nn.space_to_depth(x, 2)

                    if self.use_activator:
                        x = nn.gelu(x)
                    return x

                def get_out_ch(self):
                    return (self.out_ch // 4) * 4

            class DownscaleBlock(nn.ModelBase):
                def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
                    self.downs = []

                    last_ch = in_ch
                    for i in range(n_downscales):
                        cur_ch = ch*( min(2**i, 8) )
                        self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
                        last_ch = self.downs[-1].get_out_ch()

                def forward(self, inp):
                    x = inp
                    for down in self.downs:
                        x = down(x)
                    return x

            class Upscale(nn.ModelBase):
                def on_build(self, in_ch, out_ch, kernel_size=3 ):
                    self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')

                def forward(self, x):
                    x = self.conv1(x)
                    x = nn.gelu(x)
                    x = nn.depth_to_space(x, 2)
                    return x

            class ResidualBlock(nn.ModelBase):
                def on_build(self, ch, kernel_size=3 ):
                    self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
                    self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')

                def forward(self, inp):
                    x = self.conv1(inp)
                    x = nn.gelu(x)
                    x = self.conv2(x)
                    x = inp + x
                    x = nn.gelu(x)
                    return x

            class Encoder(nn.ModelBase):
                def on_build(self, in_ch, e_ch):
                    self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5)

                def forward(self, inp):
                    return nn.flatten(self.down1(inp))

            lowest_dense_res = resolution // 16

            class Inter(nn.ModelBase):
                def __init__(self, in_ch, ae_ch, ae_out_ch, d_ch, **kwargs):
                    self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch = in_ch, ae_ch, ae_out_ch, d_ch
                    super().__init__(**kwargs)

                def on_build(self):
                    in_ch, ae_ch, ae_out_ch, d_ch = self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch

                    self.dense1 = nn.Dense( in_ch, ae_ch, kernel_initializer=tf.initializers.orthogonal )
                    self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch, kernel_initializer=tf.initializers.orthogonal )
                    self.upscale1 = Upscale(ae_out_ch, d_ch*8)
                    self.res1 = ResidualBlock(d_ch*8)

                def forward(self, inp):
                    x = self.dense1(inp)
                    x = self.dense2(x)
                    x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
                    x = self.upscale1(x)
                    x = self.res1(x)
                    return x

                def get_out_ch(self):
                    return self.ae_out_ch

            class Decoder(nn.ModelBase):
                def on_build(self, in_ch, d_ch):
                    self.upscale1 = Upscale(in_ch, d_ch*4)
                    self.res1 = ResidualBlock(d_ch*4)
                    self.upscale2 = Upscale(d_ch*4, d_ch*2)
                    self.res2 = ResidualBlock(d_ch*2)
                    self.upscale3 = Upscale(d_ch*2, d_ch*1)
                    self.res3 = ResidualBlock(d_ch*1)

                    self.upscalem1 = Upscale(in_ch, d_ch)
                    self.upscalem2 = Upscale(d_ch, d_ch//2)
                    self.upscalem3 = Upscale(d_ch//2, d_ch//2)

                    self.out_conv = nn.Conv2D( d_ch*1, 3, kernel_size=1, padding='SAME')
                    self.out_convm = nn.Conv2D( d_ch//2, 1, kernel_size=1, padding='SAME')

                def forward(self, inp):
                    z = inp
                    x = self.upscale1 (z)
                    x = self.res1 (x)
                    x = self.upscale2 (x)
                    x = self.res2 (x)
                    x = self.upscale3 (x)
                    x = self.res3 (x)

                    y = self.upscalem1 (z)
                    y = self.upscalem2 (y)
                    y = self.upscalem3 (y)

                    return tf.nn.sigmoid(self.out_conv(x)), \
                           tf.nn.sigmoid(self.out_convm(y))

        self.Encoder = Encoder
        self.Inter = Inter
        self.Decoder = Decoder

nn.DeepFakeArchi = DeepFakeArchi
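# A minimal sketch (not part of the commit) of how the three parts are
# wired in the SAEHD-style models: encoder -> inter -> decoder. Channel
# numbers and the flattened-size formula are illustrative assumptions.
from core.leras import nn
nn.initialize(data_format="NHWC")
tf = nn.tf

resolution = 128
archi = nn.DeepFakeArchi(resolution)                 # mod=None -> default blocks
encoder = archi.Encoder(in_ch=3, e_ch=64, is_hd=False, name='encoder')
enc_out_ch = 64*8 * (resolution//16)**2              # non-hd: e_ch*8 maps at res/16
inter = archi.Inter(in_ch=enc_out_ch, ae_ch=256, ae_out_ch=256, name='inter')
decoder = archi.Decoder(in_ch=inter.get_out_ch(), d_ch=64, d_mask_ch=22, is_hd=False, name='decoder')

face = tf.placeholder(nn.floatx, (None, resolution, resolution, 3))
pred_face, pred_mask = decoder(inter(encoder(face)))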
core/leras/archis/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .ArchiBase import *
from .DeepFakeArchi import *
from .DFLSegnet import *
@@ -1,25 +1,6 @@
import multiprocessing

import numpy as np

from core.joblib import Subprocessor


def initialize_initializers(nn):
    tf = nn.tf
    from tensorflow.python.ops import init_ops

    class initializers():
        class ca (init_ops.Initializer):
            def __call__(self, shape, dtype=None, partition_info=None):
                return tf.zeros( shape, dtype=dtype, name="_cai_")

            @staticmethod
            def generate_batch( data_list, eps_std=0.05 ):
                # data_list is a list of (shape, np.dtype)
                return CAInitializerSubprocessor (data_list).run()

    nn.initializers = initializers

class CAInitializerSubprocessor(Subprocessor):
    @staticmethod
@@ -98,4 +79,4 @@ class CAInitializerSubprocessor(Subprocessor):
    #override
    def get_result(self):
        return self.result
core/leras/initializers/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
import numpy as np
from tensorflow.python.ops import init_ops

from core.leras import nn

tf = nn.tf

from .CA import CAInitializerSubprocessor

class initializers():
    class ca (init_ops.Initializer):
        def __call__(self, shape, dtype=None, partition_info=None):
            return tf.zeros( shape, dtype=dtype, name="_cai_")

        @staticmethod
        def generate_batch( data_list, eps_std=0.05 ):
            # data_list is a list of (shape, np.dtype)
            return CAInitializerSubprocessor (data_list).run()

nn.initializers = initializers
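# Sketch (not part of the commit) of the intended two-phase CA
# ("convolution aware") init: variables are first created with the `ca`
# placeholder (zeros tagged "_cai_"), then real values are computed in
# worker processes from (shape, np.dtype) pairs and assigned in a batch.
# `ca_weights` below is a hypothetical list of such variables.
import numpy as np
data_list = [ ((3,3,3,64), np.float32), ((3,3,64,128), np.float32) ]
values = nn.initializers.ca.generate_batch(data_list)
# nn.batch_set_value( list(zip(ca_weights, values)) )   # push into the graph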
@@ -1,573 +0,0 @@
import pickle
from pathlib import Path
from core import pathex
from core.interact import interact as io
import numpy as np


def initialize_layers(nn):
    tf = nn.tf

    class Saveable():
        def __init__(self, name=None):
            self.name = name

        #override
        def get_weights(self):
            #return tf tensors that should be initialized/loaded/saved
            pass

        def save_weights(self, filename, force_dtype=None):
            d = {}
            weights = self.get_weights()

            if self.name is None:
                raise Exception("name must be defined.")

            name = self.name
            for w, w_val in zip(weights, nn.tf_sess.run (weights)):
                w_name_split = w.name.split('/', 1)
                if name != w_name_split[0]:
                    raise Exception("weight first name != Saveable.name")

                if force_dtype is not None:
                    w_val = w_val.astype(force_dtype)

                d[ w_name_split[1] ] = w_val

            d_dumped = pickle.dumps (d, 4)
            pathex.write_bytes_safe ( Path(filename), d_dumped )

        def load_weights(self, filename):
            """
            returns True if file exists
            """
            filepath = Path(filename)
            if filepath.exists():
                result = True
                d_dumped = filepath.read_bytes()
                d = pickle.loads(d_dumped)
            else:
                return False

            weights = self.get_weights()

            if self.name is None:
                raise Exception("name must be defined.")

            tuples = []
            for w in weights:
                w_name_split = w.name.split('/')
                if self.name != w_name_split[0]:
                    raise Exception("weight first name != Saveable.name")

                sub_w_name = "/".join(w_name_split[1:])

                w_val = d.get(sub_w_name, None)

                if w_val is None:
                    #io.log_err(f"Weight {w.name} was not loaded from file {filename}")
                    tuples.append ( (w, w.initializer) )
                else:
                    w_val = np.reshape( w_val, w.shape.as_list() )
                    tuples.append ( (w, w_val) )

            nn.tf_batch_set_value(tuples)

            return True

        def init_weights(self):
            nn.tf_init_weights(self.get_weights())

    nn.Saveable = Saveable
    class LayerBase():
        def __init__(self, name=None, **kwargs):
            self.name = name

        #override
        def build_weights(self):
            pass

        #override
        def get_weights(self):
            return []

        def set_weights(self, new_weights):
            weights = self.get_weights()
            if len(weights) != len(new_weights):
                raise ValueError ('len of lists mismatch')

            tuples = []
            for w, new_w in zip(weights, new_weights):
                if len(w.shape) != len(new_w.shape):
                    new_w = new_w.reshape(w.shape)

                tuples.append ( (w, new_w) )

            nn.tf_batch_set_value (tuples)
    nn.LayerBase = LayerBase
    class Conv2D(LayerBase):
        """
        use_wscale      bool enables equalized learning rate, if kernel_initializer is None, it will be forced to random_normal
        """
        def __init__(self, in_ch, out_ch, kernel_size, strides=1, padding='SAME', dilations=1, use_bias=True, use_wscale=False, kernel_initializer=None, bias_initializer=None, trainable=True, dtype=None, **kwargs ):
            if not isinstance(strides, int):
                raise ValueError ("strides must be an int type")
            if not isinstance(dilations, int):
                raise ValueError ("dilations must be an int type")
            kernel_size = int(kernel_size)

            if dtype is None:
                dtype = nn.tf_floatx

            if isinstance(padding, str):
                if padding == "SAME":
                    padding = ( (kernel_size - 1) * dilations + 1 ) // 2
                elif padding == "VALID":
                    padding = 0
                else:
                    raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs")

            if isinstance(padding, int):
                if padding != 0:
                    if nn.data_format == "NHWC":
                        padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
                    else:
                        padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
                else:
                    padding = None

            if nn.data_format == "NHWC":
                strides = [1,strides,strides,1]
            else:
                strides = [1,1,strides,strides]

            if nn.data_format == "NHWC":
                dilations = [1,dilations,dilations,1]
            else:
                dilations = [1,1,dilations,dilations]

            self.in_ch = in_ch
            self.out_ch = out_ch
            self.kernel_size = kernel_size
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.use_bias = use_bias
            self.use_wscale = use_wscale
            self.kernel_initializer = kernel_initializer
            self.bias_initializer = bias_initializer
            self.trainable = trainable
            self.dtype = dtype
            super().__init__(**kwargs)

        def build_weights(self):
            kernel_initializer = self.kernel_initializer
            if self.use_wscale:
                gain = 1.0 if self.kernel_size == 1 else np.sqrt(2)
                fan_in = self.kernel_size*self.kernel_size*self.in_ch
                he_std = gain / np.sqrt(fan_in) # He init
                self.wscale = tf.constant(he_std, dtype=self.dtype )
                if kernel_initializer is None:
                    kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)

            if kernel_initializer is None:
                kernel_initializer = tf.initializers.glorot_uniform(dtype=self.dtype)

            self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.in_ch,self.out_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )

            if self.use_bias:
                bias_initializer = self.bias_initializer
                if bias_initializer is None:
                    bias_initializer = tf.initializers.zeros(dtype=self.dtype)

                self.bias = tf.get_variable("bias", (self.out_ch,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable )

        def get_weights(self):
            weights = [self.weight]
            if self.use_bias:
                weights += [self.bias]
            return weights

        def __call__(self, x):
            weight = self.weight
            if self.use_wscale:
                weight = weight * self.wscale

            if self.padding is not None:
                x = tf.pad (x, self.padding, mode='CONSTANT')

            x = tf.nn.conv2d(x, weight, self.strides, 'VALID', dilations=self.dilations, data_format=nn.data_format)
            if self.use_bias:
                if nn.data_format == "NHWC":
                    bias = tf.reshape (self.bias, (1,1,1,self.out_ch) )
                else:
                    bias = tf.reshape (self.bias, (1,self.out_ch,1,1) )
                x = tf.add(x, bias)
            return x

        def __str__(self):
            r = f"{self.__class__.__name__} : in_ch:{self.in_ch} out_ch:{self.out_ch} "
            return r
    nn.Conv2D = Conv2D
    class Conv2DTranspose(LayerBase):
        """
        use_wscale      enables weight scale (equalized learning rate)
                        if kernel_initializer is None, it will be forced to random_normal
        """
        def __init__(self, in_ch, out_ch, kernel_size, strides=2, padding='SAME', use_bias=True, use_wscale=False, kernel_initializer=None, bias_initializer=None, trainable=True, dtype=None, **kwargs ):
            if not isinstance(strides, int):
                raise ValueError ("strides must be an int type")
            kernel_size = int(kernel_size)

            if dtype is None:
                dtype = nn.tf_floatx

            self.in_ch = in_ch
            self.out_ch = out_ch
            self.kernel_size = kernel_size
            self.strides = strides
            self.padding = padding
            self.use_bias = use_bias
            self.use_wscale = use_wscale
            self.kernel_initializer = kernel_initializer
            self.bias_initializer = bias_initializer
            self.trainable = trainable
            self.dtype = dtype
            super().__init__(**kwargs)

        def build_weights(self):
            kernel_initializer = self.kernel_initializer
            if self.use_wscale:
                gain = 1.0 if self.kernel_size == 1 else np.sqrt(2)
                fan_in = self.kernel_size*self.kernel_size*self.in_ch
                he_std = gain / np.sqrt(fan_in) # He init
                self.wscale = tf.constant(he_std, dtype=self.dtype )
                if kernel_initializer is None:
                    kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)

            if kernel_initializer is None:
                kernel_initializer = tf.initializers.glorot_uniform(dtype=self.dtype)
            self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.out_ch,self.in_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )

            if self.use_bias:
                bias_initializer = self.bias_initializer
                if bias_initializer is None:
                    bias_initializer = tf.initializers.zeros(dtype=self.dtype)

                self.bias = tf.get_variable("bias", (self.out_ch,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable )

        def get_weights(self):
            weights = [self.weight]
            if self.use_bias:
                weights += [self.bias]
            return weights

        def __call__(self, x):
            shape = x.shape

            if nn.data_format == "NHWC":
                h,w,c = shape[1], shape[2], shape[3]
                output_shape = tf.stack ( (tf.shape(x)[0],
                                           self.deconv_length(w, self.strides, self.kernel_size, self.padding),
                                           self.deconv_length(h, self.strides, self.kernel_size, self.padding),
                                           self.out_ch) )

                strides = [1,self.strides,self.strides,1]
            else:
                c,h,w = shape[1], shape[2], shape[3]
                output_shape = tf.stack ( (tf.shape(x)[0],
                                           self.out_ch,
                                           self.deconv_length(w, self.strides, self.kernel_size, self.padding),
                                           self.deconv_length(h, self.strides, self.kernel_size, self.padding),
                                           ) )
                strides = [1,1,self.strides,self.strides]
            weight = self.weight
            if self.use_wscale:
                weight = weight * self.wscale

            x = tf.nn.conv2d_transpose(x, weight, output_shape, strides, padding=self.padding, data_format=nn.data_format)

            if self.use_bias:
                if nn.data_format == "NHWC":
                    bias = tf.reshape (self.bias, (1,1,1,self.out_ch) )
                else:
                    bias = tf.reshape (self.bias, (1,self.out_ch,1,1) )
                x = tf.add(x, bias)
            return x

        def __str__(self):
            r = f"{self.__class__.__name__} : in_ch:{self.in_ch} out_ch:{self.out_ch} "
            return r

        def deconv_length(self, dim_size, stride_size, kernel_size, padding):
            assert padding in {'SAME', 'VALID', 'FULL'}
            if dim_size is None:
                return None
            if padding == 'VALID':
                dim_size = dim_size * stride_size + max(kernel_size - stride_size, 0)
            elif padding == 'FULL':
                dim_size = dim_size * stride_size - (stride_size + kernel_size - 2)
            elif padding == 'SAME':
                dim_size = dim_size * stride_size
            return dim_size
    nn.Conv2DTranspose = Conv2DTranspose
    class BlurPool(LayerBase):
        def __init__(self, filt_size=3, stride=2, **kwargs ):
            if nn.data_format == "NHWC":
                self.strides = [1,stride,stride,1]
            else:
                self.strides = [1,1,stride,stride]

            self.filt_size = filt_size
            pad = [ int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)) ]

            if nn.data_format == "NHWC":
                self.padding = [ [0,0], pad, pad, [0,0] ]
            else:
                self.padding = [ [0,0], [0,0], pad, pad ]

            if self.filt_size == 1:
                a = np.array([1.,])
            elif self.filt_size == 2:
                a = np.array([1., 1.])
            elif self.filt_size == 3:
                a = np.array([1., 2., 1.])
            elif self.filt_size == 4:
                a = np.array([1., 3., 3., 1.])
            elif self.filt_size == 5:
                a = np.array([1., 4., 6., 4., 1.])
            elif self.filt_size == 6:
                a = np.array([1., 5., 10., 10., 5., 1.])
            elif self.filt_size == 7:
                a = np.array([1., 6., 15., 20., 15., 6., 1.])

            a = a[:,None]*a[None,:]
            a = a / np.sum(a)
            a = a[:,:,None,None]
            self.a = a
            super().__init__(**kwargs)

        def build_weights(self):
            self.k = tf.constant (self.a, dtype=nn.tf_floatx )

        def __call__(self, x):
            k = tf.tile (self.k, (1,1,x.shape[nn.conv2d_ch_axis],1) )
            x = tf.pad(x, self.padding )
            x = tf.nn.depthwise_conv2d(x, k, self.strides, 'VALID', data_format=nn.data_format)
            return x
    nn.BlurPool = BlurPool
    class Dense(LayerBase):
        def __init__(self, in_ch, out_ch, use_bias=True, use_wscale=False, maxout_ch=0, kernel_initializer=None, bias_initializer=None, trainable=True, dtype=None, **kwargs ):
            """
            use_wscale      enables weight scale (equalized learning rate)
                            if kernel_initializer is None, it will be forced to random_normal

            maxout_ch       https://link.springer.com/article/10.1186/s40537-019-0233-0
                            typical 2-4 if you want to enable DenseMaxout behaviour
            """
            self.in_ch = in_ch
            self.out_ch = out_ch
            self.use_bias = use_bias
            self.use_wscale = use_wscale
            self.maxout_ch = maxout_ch
            self.kernel_initializer = kernel_initializer
            self.bias_initializer = bias_initializer
            self.trainable = trainable
            if dtype is None:
                dtype = nn.tf_floatx

            self.dtype = dtype
            super().__init__(**kwargs)

        def build_weights(self):
            if self.maxout_ch > 1:
                weight_shape = (self.in_ch,self.out_ch*self.maxout_ch)
            else:
                weight_shape = (self.in_ch,self.out_ch)

            kernel_initializer = self.kernel_initializer

            if self.use_wscale:
                gain = 1.0
                fan_in = np.prod( weight_shape[:-1] )
                he_std = gain / np.sqrt(fan_in) # He init
                self.wscale = tf.constant(he_std, dtype=self.dtype )
                if kernel_initializer is None:
                    kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)

            if kernel_initializer is None:
                kernel_initializer = tf.initializers.glorot_uniform(dtype=self.dtype)

            self.weight = tf.get_variable("weight", weight_shape, dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )

            if self.use_bias:
                bias_initializer = self.bias_initializer
                if bias_initializer is None:
                    bias_initializer = tf.initializers.zeros(dtype=self.dtype)
                self.bias = tf.get_variable("bias", (self.out_ch,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable )

        def get_weights(self):
            weights = [self.weight]
            if self.use_bias:
                weights += [self.bias]
            return weights

        def __call__(self, x):
            weight = self.weight
            if self.use_wscale:
                weight = weight * self.wscale

            x = tf.matmul(x, weight)

            if self.maxout_ch > 1:
                x = tf.reshape (x, (-1, self.out_ch, self.maxout_ch) )
                x = tf.reduce_max(x, axis=-1)

            if self.use_bias:
                x = tf.add(x, tf.reshape(self.bias, (1,self.out_ch) ) )

            return x
    nn.Dense = Dense
    class InstanceNorm2D(LayerBase):
        def __init__(self, in_ch, dtype=None, **kwargs):
            self.in_ch = in_ch

            if dtype is None:
                dtype = nn.tf_floatx
            self.dtype = dtype

            super().__init__(**kwargs)

        def build_weights(self):
            kernel_initializer = tf.initializers.glorot_uniform(dtype=self.dtype)
            self.weight = tf.get_variable("weight", (self.in_ch,), dtype=self.dtype, initializer=kernel_initializer )
            self.bias = tf.get_variable("bias", (self.in_ch,), dtype=self.dtype, initializer=tf.initializers.zeros() )

        def get_weights(self):
            return [self.weight, self.bias]

        def __call__(self, x):
            if nn.data_format == "NHWC":
                shape = (1,1,1,self.in_ch)
            else:
                shape = (1,self.in_ch,1,1)

            weight = tf.reshape ( self.weight, shape )
            bias = tf.reshape ( self.bias, shape )

            x_mean = tf.reduce_mean(x, axis=nn.conv2d_spatial_axes, keepdims=True )
            x_std = tf.math.reduce_std(x, axis=nn.conv2d_spatial_axes, keepdims=True ) + 1e-5

            x = (x - x_mean) / x_std
            x *= weight
            x += bias

            return x

    nn.InstanceNorm2D = InstanceNorm2D
    class BatchNorm2D(LayerBase):
        """
        currently not for training
        """
        def __init__(self, dim, eps=1e-05, momentum=0.1, dtype=None, **kwargs):
            self.dim = dim
            self.eps = eps
            self.momentum = momentum
            if dtype is None:
                dtype = nn.tf_floatx
            self.dtype = dtype
            super().__init__(**kwargs)

        def build_weights(self):
            self.weight = tf.get_variable("weight", (self.dim,), dtype=self.dtype, initializer=tf.initializers.ones() )
            self.bias = tf.get_variable("bias", (self.dim,), dtype=self.dtype, initializer=tf.initializers.zeros() )
            self.running_mean = tf.get_variable("running_mean", (self.dim,), dtype=self.dtype, initializer=tf.initializers.zeros(), trainable=False )
            self.running_var = tf.get_variable("running_var", (self.dim,), dtype=self.dtype, initializer=tf.initializers.zeros(), trainable=False )

        def get_weights(self):
            return [self.weight, self.bias, self.running_mean, self.running_var]

        def __call__(self, x):
            if nn.data_format == "NHWC":
                shape = (1,1,1,self.dim)
            else:
                shape = (1,self.dim,1,1)

            weight = tf.reshape ( self.weight, shape )
            bias = tf.reshape ( self.bias, shape )
            running_mean = tf.reshape ( self.running_mean, shape )
            running_var = tf.reshape ( self.running_var, shape )

            x = (x - running_mean) / tf.sqrt( running_var + self.eps )
            x *= weight
            x += bias
            return x

    nn.BatchNorm2D = BatchNorm2D
    class AdaIN(LayerBase):
        def __init__(self, in_ch, mlp_ch, kernel_initializer=None, dtype=None, **kwargs):
            self.in_ch = in_ch
            self.mlp_ch = mlp_ch
            self.kernel_initializer = kernel_initializer

            if dtype is None:
                dtype = nn.tf_floatx
            self.dtype = dtype

            super().__init__(**kwargs)

        def build_weights(self):
            kernel_initializer = self.kernel_initializer
            if kernel_initializer is None:
                kernel_initializer = tf.initializers.he_normal() #(dtype=self.dtype)

            self.weight1 = tf.get_variable("weight1", (self.mlp_ch, self.in_ch), dtype=self.dtype, initializer=kernel_initializer)
            self.bias1 = tf.get_variable("bias1", (self.in_ch,), dtype=self.dtype, initializer=tf.initializers.zeros())
            self.weight2 = tf.get_variable("weight2", (self.mlp_ch, self.in_ch), dtype=self.dtype, initializer=kernel_initializer)
            self.bias2 = tf.get_variable("bias2", (self.in_ch,), dtype=self.dtype, initializer=tf.initializers.zeros())

        def get_weights(self):
            return [self.weight1, self.bias1, self.weight2, self.bias2]

        def __call__(self, inputs):
            x, mlp = inputs

            gamma = tf.matmul(mlp, self.weight1)
            gamma = tf.add(gamma, tf.reshape(self.bias1, (1,self.in_ch) ) )

            beta = tf.matmul(mlp, self.weight2)
            beta = tf.add(beta, tf.reshape(self.bias2, (1,self.in_ch) ) )

            if nn.data_format == "NHWC":
                shape = (-1,1,1,self.in_ch)
            else:
                shape = (-1,self.in_ch,1,1)

            x_mean = tf.reduce_mean(x, axis=nn.conv2d_spatial_axes, keepdims=True )
            x_std = tf.math.reduce_std(x, axis=nn.conv2d_spatial_axes, keepdims=True ) + 1e-5

            x = (x - x_mean) / x_std
            x *= tf.reshape(gamma, shape)

            x += tf.reshape(beta, shape)

            return x

    nn.AdaIN = AdaIN
core/leras/layers/AdaIN.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from core.leras import nn
tf = nn.tf

class AdaIN(nn.LayerBase):
    """
    Adaptive Instance Normalization: instance-normalizes x and modulates it
    with a scale/shift predicted from a style (mlp) vector.
    """
    def __init__(self, in_ch, mlp_ch, kernel_initializer=None, dtype=None, **kwargs):
        self.in_ch = in_ch
        self.mlp_ch = mlp_ch
        self.kernel_initializer = kernel_initializer

        if dtype is None:
            dtype = nn.floatx
        self.dtype = dtype

        super().__init__(**kwargs)

    def build_weights(self):
        kernel_initializer = self.kernel_initializer
        if kernel_initializer is None:
            kernel_initializer = tf.initializers.he_normal()

        self.weight1 = tf.get_variable("weight1", (self.mlp_ch, self.in_ch), dtype=self.dtype, initializer=kernel_initializer)
        self.bias1 = tf.get_variable("bias1", (self.in_ch,), dtype=self.dtype, initializer=tf.initializers.zeros())
        self.weight2 = tf.get_variable("weight2", (self.mlp_ch, self.in_ch), dtype=self.dtype, initializer=kernel_initializer)
        self.bias2 = tf.get_variable("bias2", (self.in_ch,), dtype=self.dtype, initializer=tf.initializers.zeros())

    def get_weights(self):
        return [self.weight1, self.bias1, self.weight2, self.bias2]

    def forward(self, inputs):
        x, mlp = inputs

        gamma = tf.matmul(mlp, self.weight1)
        gamma = tf.add(gamma, tf.reshape(self.bias1, (1,self.in_ch) ) )

        beta = tf.matmul(mlp, self.weight2)
        beta = tf.add(beta, tf.reshape(self.bias2, (1,self.in_ch) ) )

        if nn.data_format == "NHWC":
            shape = (-1,1,1,self.in_ch)
        else:
            shape = (-1,self.in_ch,1,1)

        x_mean = tf.reduce_mean(x, axis=nn.conv2d_spatial_axes, keepdims=True )
        x_std = tf.math.reduce_std(x, axis=nn.conv2d_spatial_axes, keepdims=True ) + 1e-5

        x = (x - x_mean) / x_std
        x *= tf.reshape(gamma, shape)

        x += tf.reshape(beta, shape)

        return x

nn.AdaIN = AdaIN
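# Sketch (not part of the commit) of what AdaIN computes per channel:
# instance-normalize x, then scale/shift with the gamma/beta that the two
# linear maps above predict from the style vector (numpy, one channel,
# illustrative values).
import numpy as np
x = np.array([1.0, 2.0, 3.0, 4.0])     # spatial activations of one channel
gamma, beta = 1.5, 0.2                 # mlp @ weight1 + bias1, mlp @ weight2 + bias2
y = (x - x.mean()) / (x.std() + 1e-5) * gamma + beta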
core/leras/layers/BatchNorm2D.py (new file, 42 lines)
@@ -0,0 +1,42 @@
from core.leras import nn
tf = nn.tf

class BatchNorm2D(nn.LayerBase):
    """
    currently not for training
    """
    def __init__(self, dim, eps=1e-05, momentum=0.1, dtype=None, **kwargs):
        self.dim = dim
        self.eps = eps
        self.momentum = momentum
        if dtype is None:
            dtype = nn.floatx
        self.dtype = dtype
        super().__init__(**kwargs)

    def build_weights(self):
        self.weight = tf.get_variable("weight", (self.dim,), dtype=self.dtype, initializer=tf.initializers.ones() )
        self.bias = tf.get_variable("bias", (self.dim,), dtype=self.dtype, initializer=tf.initializers.zeros() )
        self.running_mean = tf.get_variable("running_mean", (self.dim,), dtype=self.dtype, initializer=tf.initializers.zeros(), trainable=False )
        self.running_var = tf.get_variable("running_var", (self.dim,), dtype=self.dtype, initializer=tf.initializers.zeros(), trainable=False )

    def get_weights(self):
        return [self.weight, self.bias, self.running_mean, self.running_var]

    def forward(self, x):
        if nn.data_format == "NHWC":
            shape = (1,1,1,self.dim)
        else:
            shape = (1,self.dim,1,1)

        weight = tf.reshape ( self.weight, shape )
        bias = tf.reshape ( self.bias, shape )
        running_mean = tf.reshape ( self.running_mean, shape )
        running_var = tf.reshape ( self.running_var, shape )

        x = (x - running_mean) / tf.sqrt( running_var + self.eps )
        x *= weight
        x += bias
        return x

nn.BatchNorm2D = BatchNorm2D
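# Sketch (not part of the commit) of the arithmetic this inference-only
# BatchNorm2D applies per channel, checked in plain numpy (values
# illustrative; the stored statistics come from loaded weights):
import numpy as np
x = np.array([2.0, 4.0])
running_mean, running_var, eps = 3.0, 1.0, 1e-5
weight, bias = 2.0, 0.5
y = (x - running_mean) / np.sqrt(running_var + eps) * weight + bias   # ~[-1.5, 2.5]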
core/leras/layers/BlurPool.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import numpy as np
from core.leras import nn
tf = nn.tf

class BlurPool(nn.LayerBase):
    def __init__(self, filt_size=3, stride=2, **kwargs ):
        if nn.data_format == "NHWC":
            self.strides = [1,stride,stride,1]
        else:
            self.strides = [1,1,stride,stride]

        self.filt_size = filt_size
        pad = [ int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)) ]

        if nn.data_format == "NHWC":
            self.padding = [ [0,0], pad, pad, [0,0] ]
        else:
            self.padding = [ [0,0], [0,0], pad, pad ]

        if self.filt_size == 1:
            a = np.array([1.,])
        elif self.filt_size == 2:
            a = np.array([1., 1.])
        elif self.filt_size == 3:
            a = np.array([1., 2., 1.])
        elif self.filt_size == 4:
            a = np.array([1., 3., 3., 1.])
        elif self.filt_size == 5:
            a = np.array([1., 4., 6., 4., 1.])
        elif self.filt_size == 6:
            a = np.array([1., 5., 10., 10., 5., 1.])
        elif self.filt_size == 7:
            a = np.array([1., 6., 15., 20., 15., 6., 1.])

        a = a[:,None]*a[None,:]
        a = a / np.sum(a)
        a = a[:,:,None,None]
        self.a = a
        super().__init__(**kwargs)

    def build_weights(self):
        self.k = tf.constant (self.a, dtype=nn.floatx )

    def forward(self, x):
        k = tf.tile (self.k, (1,1,x.shape[nn.conv2d_ch_axis],1) )
        x = tf.pad(x, self.padding )
        x = tf.nn.depthwise_conv2d(x, k, self.strides, 'VALID', data_format=nn.data_format)
        return x
nn.BlurPool = BlurPool
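# The filter above is a normalized binomial kernel, so BlurPool acts as an
# anti-aliased replacement for strided pooling (cf. Zhang, "Making
# Convolutional Networks Shift-Invariant Again"). A sketch (not part of
# the commit) of the kernel it builds for filt_size=3:
import numpy as np
a = np.array([1., 2., 1.])
k = a[:, None] * a[None, :]
k = k / k.sum()       # [[1,2,1],[2,4,2],[1,2,1]] / 16, applied depthwise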
core/leras/layers/Conv2D.py (new file, 112 lines)
@@ -0,0 +1,112 @@
import numpy as np
from core.leras import nn
tf = nn.tf

class Conv2D(nn.LayerBase):
    """
    default kernel_initializer - CA
    use_wscale      bool enables equalized learning rate, if kernel_initializer is None, it will be forced to random_normal
    """
    def __init__(self, in_ch, out_ch, kernel_size, strides=1, padding='SAME', dilations=1, use_bias=True, use_wscale=False, kernel_initializer=None, bias_initializer=None, trainable=True, dtype=None, **kwargs ):
        if not isinstance(strides, int):
            raise ValueError ("strides must be an int type")
        if not isinstance(dilations, int):
            raise ValueError ("dilations must be an int type")
        kernel_size = int(kernel_size)

        if dtype is None:
            dtype = nn.floatx

        if isinstance(padding, str):
            if padding == "SAME":
                padding = ( (kernel_size - 1) * dilations + 1 ) // 2
            elif padding == "VALID":
                padding = 0
            else:
                raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs")

        if isinstance(padding, int):
            if padding != 0:
                if nn.data_format == "NHWC":
                    padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
                else:
                    padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
            else:
                padding = None

        if nn.data_format == "NHWC":
            strides = [1,strides,strides,1]
        else:
            strides = [1,1,strides,strides]

        if nn.data_format == "NHWC":
            dilations = [1,dilations,dilations,1]
        else:
            dilations = [1,1,dilations,dilations]

        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.dilations = dilations
        self.use_bias = use_bias
        self.use_wscale = use_wscale
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.trainable = trainable
        self.dtype = dtype
        super().__init__(**kwargs)

    def build_weights(self):
        kernel_initializer = self.kernel_initializer
        if self.use_wscale:
            gain = 1.0 if self.kernel_size == 1 else np.sqrt(2)
            fan_in = self.kernel_size*self.kernel_size*self.in_ch
            he_std = gain / np.sqrt(fan_in)
            self.wscale = tf.constant(he_std, dtype=self.dtype )
            if kernel_initializer is None:
                kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)

        if kernel_initializer is None:
            kernel_initializer = nn.initializers.ca()

        self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.in_ch,self.out_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )

        if self.use_bias:
            bias_initializer = self.bias_initializer
            if bias_initializer is None:
                bias_initializer = tf.initializers.zeros(dtype=self.dtype)

            self.bias = tf.get_variable("bias", (self.out_ch,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable )

    def get_weights(self):
        weights = [self.weight]
        if self.use_bias:
            weights += [self.bias]
        return weights

    def forward(self, x):
        weight = self.weight
        if self.use_wscale:
            weight = weight * self.wscale

        if self.padding is not None:
            x = tf.pad (x, self.padding, mode='CONSTANT')

        x = tf.nn.conv2d(x, weight, self.strides, 'VALID', dilations=self.dilations, data_format=nn.data_format)
        if self.use_bias:
            if nn.data_format == "NHWC":
                bias = tf.reshape (self.bias, (1,1,1,self.out_ch) )
            else:
                bias = tf.reshape (self.bias, (1,self.out_ch,1,1) )
            x = tf.add(x, bias)
        return x

    def __str__(self):
        r = f"{self.__class__.__name__} : in_ch:{self.in_ch} out_ch:{self.out_ch} "
        return r
nn.Conv2D = Conv2D
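# Sketch (not part of the commit) of the 'SAME' translation above: the
# string is turned into an explicit symmetric pad, then a VALID conv is
# run, so a stride-1 output keeps the input size. For kernel_size=5,
# dilations=2:
kernel_size, dilations = 5, 2
pad = ((kernel_size - 1) * dilations + 1) // 2    # -> 4
# width W -> W + 2*4 after padding; the effective dilated kernel spans
# (5-1)*2 + 1 = 9, and a VALID conv returns width W again.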
core/leras/layers/Conv2DTranspose.py (new file, 107 lines)
@@ -0,0 +1,107 @@
import numpy as np
from core.leras import nn
tf = nn.tf

class Conv2DTranspose(nn.LayerBase):
    """
    use_wscale      enables weight scale (equalized learning rate)
                    if kernel_initializer is None, it will be forced to random_normal
    """
    def __init__(self, in_ch, out_ch, kernel_size, strides=2, padding='SAME', use_bias=True, use_wscale=False, kernel_initializer=None, bias_initializer=None, trainable=True, dtype=None, **kwargs ):
        if not isinstance(strides, int):
            raise ValueError ("strides must be an int type")
        kernel_size = int(kernel_size)

        if dtype is None:
            dtype = nn.floatx

        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.use_bias = use_bias
        self.use_wscale = use_wscale
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.trainable = trainable
        self.dtype = dtype
        super().__init__(**kwargs)

    def build_weights(self):
        kernel_initializer = self.kernel_initializer
        if self.use_wscale:
            gain = 1.0 if self.kernel_size == 1 else np.sqrt(2)
            fan_in = self.kernel_size*self.kernel_size*self.in_ch
            he_std = gain / np.sqrt(fan_in) # He init
            self.wscale = tf.constant(he_std, dtype=self.dtype )
            if kernel_initializer is None:
                kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)

        if kernel_initializer is None:
            kernel_initializer = nn.initializers.ca()
        self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.out_ch,self.in_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )

        if self.use_bias:
            bias_initializer = self.bias_initializer
            if bias_initializer is None:
                bias_initializer = tf.initializers.zeros(dtype=self.dtype)

            self.bias = tf.get_variable("bias", (self.out_ch,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable )

    def get_weights(self):
        weights = [self.weight]
        if self.use_bias:
            weights += [self.bias]
        return weights

    def forward(self, x):
        shape = x.shape

        if nn.data_format == "NHWC":
            h,w,c = shape[1], shape[2], shape[3]
            output_shape = tf.stack ( (tf.shape(x)[0],
                                       self.deconv_length(w, self.strides, self.kernel_size, self.padding),
                                       self.deconv_length(h, self.strides, self.kernel_size, self.padding),
                                       self.out_ch) )

            strides = [1,self.strides,self.strides,1]
        else:
            c,h,w = shape[1], shape[2], shape[3]
            output_shape = tf.stack ( (tf.shape(x)[0],
                                       self.out_ch,
                                       self.deconv_length(w, self.strides, self.kernel_size, self.padding),
                                       self.deconv_length(h, self.strides, self.kernel_size, self.padding),
                                       ) )
            strides = [1,1,self.strides,self.strides]
        weight = self.weight
        if self.use_wscale:
            weight = weight * self.wscale

        x = tf.nn.conv2d_transpose(x, weight, output_shape, strides, padding=self.padding, data_format=nn.data_format)

        if self.use_bias:
            if nn.data_format == "NHWC":
                bias = tf.reshape (self.bias, (1,1,1,self.out_ch) )
            else:
                bias = tf.reshape (self.bias, (1,self.out_ch,1,1) )
            x = tf.add(x, bias)
        return x

    def __str__(self):
        r = f"{self.__class__.__name__} : in_ch:{self.in_ch} out_ch:{self.out_ch} "
        return r

    def deconv_length(self, dim_size, stride_size, kernel_size, padding):
        assert padding in {'SAME', 'VALID', 'FULL'}
        if dim_size is None:
            return None
        if padding == 'VALID':
            dim_size = dim_size * stride_size + max(kernel_size - stride_size, 0)
        elif padding == 'FULL':
            dim_size = dim_size * stride_size - (stride_size + kernel_size - 2)
        elif padding == 'SAME':
            dim_size = dim_size * stride_size
        return dim_size
nn.Conv2DTranspose = Conv2DTranspose
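# Sketch (not part of the commit) of the output-size rule deconv_length
# implements, for a 64px map with stride 2 and kernel 3:
assert 64*2 == 128                   # SAME
assert 64*2 + max(3 - 2, 0) == 129   # VALID
assert 64*2 - (2 + 3 - 2) == 125     # FULL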
core/leras/layers/Dense.py (new file, 76 lines)
@@ -0,0 +1,76 @@
import numpy as np
from core.leras import nn
tf = nn.tf

class Dense(nn.LayerBase):
    def __init__(self, in_ch, out_ch, use_bias=True, use_wscale=False, maxout_ch=0, kernel_initializer=None, bias_initializer=None, trainable=True, dtype=None, **kwargs ):
        """
        use_wscale      enables weight scale (equalized learning rate)
                        if kernel_initializer is None, it will be forced to random_normal

        maxout_ch       https://link.springer.com/article/10.1186/s40537-019-0233-0
                        typical 2-4 if you want to enable DenseMaxout behaviour
        """
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.use_bias = use_bias
        self.use_wscale = use_wscale
        self.maxout_ch = maxout_ch
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.trainable = trainable
        if dtype is None:
            dtype = nn.floatx

        self.dtype = dtype
        super().__init__(**kwargs)

    def build_weights(self):
        if self.maxout_ch > 1:
            weight_shape = (self.in_ch,self.out_ch*self.maxout_ch)
        else:
            weight_shape = (self.in_ch,self.out_ch)

        kernel_initializer = self.kernel_initializer

        if self.use_wscale:
            gain = 1.0
            fan_in = np.prod( weight_shape[:-1] )
            he_std = gain / np.sqrt(fan_in) # He init
            self.wscale = tf.constant(he_std, dtype=self.dtype )
            if kernel_initializer is None:
                kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype)

        if kernel_initializer is None:
            kernel_initializer = tf.initializers.glorot_uniform(dtype=self.dtype)

        self.weight = tf.get_variable("weight", weight_shape, dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable )

        if self.use_bias:
            bias_initializer = self.bias_initializer
            if bias_initializer is None:
                bias_initializer = tf.initializers.zeros(dtype=self.dtype)
            self.bias = tf.get_variable("bias", (self.out_ch,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable )

    def get_weights(self):
        weights = [self.weight]
        if self.use_bias:
            weights += [self.bias]
        return weights

    def forward(self, x):
        weight = self.weight
        if self.use_wscale:
            weight = weight * self.wscale

        x = tf.matmul(x, weight)

        if self.maxout_ch > 1:
            x = tf.reshape (x, (-1, self.out_ch, self.maxout_ch) )
            x = tf.reduce_max(x, axis=-1)

        if self.use_bias:
            x = tf.add(x, tf.reshape(self.bias, (1,self.out_ch) ) )

        return x
nn.Dense = Dense
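# Sketch (not part of the commit) of the maxout path above in plain
# numpy: out_ch*maxout_ch pre-activations are grouped and the max is kept
# per output unit (bias is added after the max, as in forward()).
import numpy as np
out_ch, maxout_ch = 2, 4
x = np.arange(8.0).reshape(-1, out_ch, maxout_ch)   # matmul result, one sample
y = x.max(-1)                                       # -> [[3., 7.]]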
core/leras/layers/FRNorm2D.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from core.leras import nn
tf = nn.tf

class FRNorm2D(nn.LayerBase):
    """
    Tensorflow implementation of
    Filter Response Normalization Layer: Eliminating Batch Dependence in the Training of Deep Neural Networks
    https://arxiv.org/pdf/1911.09737.pdf
    """
    def __init__(self, in_ch, dtype=None, **kwargs):
        self.in_ch = in_ch

        if dtype is None:
            dtype = nn.floatx
        self.dtype = dtype

        super().__init__(**kwargs)

    def build_weights(self):
        self.weight = tf.get_variable("weight", (self.in_ch,), dtype=self.dtype, initializer=tf.initializers.ones() )
        self.bias   = tf.get_variable("bias",   (self.in_ch,), dtype=self.dtype, initializer=tf.initializers.zeros() )
        self.eps    = tf.get_variable("eps",    (1,),          dtype=self.dtype, initializer=tf.initializers.constant(1e-6) )

    def get_weights(self):
        return [self.weight, self.bias, self.eps]

    def forward(self, x):
        if nn.data_format == "NHWC":
            shape = (1,1,1,self.in_ch)
        else:
            shape = (1,self.in_ch,1,1)
        weight = tf.reshape ( self.weight, shape )
        bias   = tf.reshape ( self.bias,   shape )
        nu2 = tf.reduce_mean(tf.square(x), axis=nn.conv2d_spatial_axes, keepdims=True)
        x = x * ( 1.0 / tf.sqrt(nu2 + tf.abs(self.eps) ) )

        return x*weight + bias
nn.FRNorm2D = FRNorm2D
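For reference (editor's note, not part of the commit): per channel this computes y = weight * x / sqrt(mean_hw(x^2) + |eps|) + bias using only spatial statistics, so the result does not depend on batch size. The same computation in numpy:

# Editor's sketch of the FRN formula; array shapes are illustrative.
import numpy as np
x = np.random.rand(2, 8, 8, 64).astype(np.float32)       # NHWC
nu2 = np.mean(np.square(x), axis=(1, 2), keepdims=True)  # per-channel mean of x^2
y = x / np.sqrt(nu2 + 1e-6)                              # then * weight + bias per channel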
core/leras/layers/InstanceNorm2D.py (new file, 40 lines)
@@ -0,0 +1,40 @@
from core.leras import nn
tf = nn.tf

class InstanceNorm2D(nn.LayerBase):
    def __init__(self, in_ch, dtype=None, **kwargs):
        self.in_ch = in_ch

        if dtype is None:
            dtype = nn.floatx
        self.dtype = dtype

        super().__init__(**kwargs)

    def build_weights(self):
        kernel_initializer = tf.initializers.glorot_uniform(dtype=self.dtype)
        self.weight = tf.get_variable("weight", (self.in_ch,), dtype=self.dtype, initializer=kernel_initializer )
        self.bias   = tf.get_variable("bias",   (self.in_ch,), dtype=self.dtype, initializer=tf.initializers.zeros() )

    def get_weights(self):
        return [self.weight, self.bias]

    def forward(self, x):
        if nn.data_format == "NHWC":
            shape = (1,1,1,self.in_ch)
        else:
            shape = (1,self.in_ch,1,1)

        weight = tf.reshape ( self.weight, shape )
        bias   = tf.reshape ( self.bias,   shape )

        x_mean = tf.reduce_mean(x, axis=nn.conv2d_spatial_axes, keepdims=True )
        x_std  = tf.math.reduce_std(x, axis=nn.conv2d_spatial_axes, keepdims=True ) + 1e-5

        x = (x - x_mean) / x_std
        x *= weight
        x += bias

        return x

nn.InstanceNorm2D = InstanceNorm2D
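Editor's note (not part of the commit): this is per-sample, per-channel standardization over the spatial axes, equivalent in numpy to:

# Editor's sketch; shapes are illustrative.
import numpy as np
x = np.random.rand(2, 8, 8, 64).astype(np.float32)   # NHWC
mean = x.mean(axis=(1, 2), keepdims=True)
std  = x.std(axis=(1, 2), keepdims=True) + 1e-5
y = (x - mean) / std                                 # then * weight + bias per channel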
core/leras/layers/LayerBase.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from core.leras import nn
tf = nn.tf

class LayerBase(nn.Saveable):
    #override
    def build_weights(self):
        pass

    #override
    def forward(self, *args, **kwargs):
        pass

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

nn.LayerBase = LayerBase
core/leras/layers/Saveable.py (new file, 103 lines)
@@ -0,0 +1,103 @@
import pickle
from pathlib import Path
from core import pathex
import numpy as np

from core.leras import nn

tf = nn.tf

class Saveable():
    def __init__(self, name=None):
        self.name = name

    #override
    def get_weights(self):
        #return tf tensors that should be initialized/loaded/saved
        return []

    #override
    def get_weights_np(self):
        weights = self.get_weights()
        if len(weights) == 0:
            return []
        return nn.tf_sess.run (weights)

    def set_weights(self, new_weights):
        weights = self.get_weights()
        if len(weights) != len(new_weights):
            raise ValueError ('len of lists mismatch')

        tuples = []
        for w, new_w in zip(weights, new_weights):
            # reshape when the incoming array's shape differs from the variable's
            if tuple(w.shape.as_list()) != new_w.shape:
                new_w = new_w.reshape( w.shape.as_list() )

            tuples.append ( (w, new_w) )

        nn.batch_set_value (tuples)

    def save_weights(self, filename, force_dtype=None):
        d = {}
        weights = self.get_weights()

        if self.name is None:
            raise Exception("name must be defined.")

        name = self.name
        for w, w_val in zip(weights, nn.tf_sess.run (weights)):
            w_name_split = w.name.split('/', 1)
            if name != w_name_split[0]:
                raise Exception("weight first name != Saveable.name")

            if force_dtype is not None:
                w_val = w_val.astype(force_dtype)

            d[ w_name_split[1] ] = w_val

        d_dumped = pickle.dumps (d, 4)
        pathex.write_bytes_safe ( Path(filename), d_dumped )

    def load_weights(self, filename):
        """
        returns True if file exists
        """
        filepath = Path(filename)
        if filepath.exists():
            result = True
            d_dumped = filepath.read_bytes()
            d = pickle.loads(d_dumped)
        else:
            return False

        weights = self.get_weights()

        if self.name is None:
            raise Exception("name must be defined.")

        tuples = []
        for w in weights:
            w_name_split = w.name.split('/')
            if self.name != w_name_split[0]:
                raise Exception("weight first name != Saveable.name")

            sub_w_name = "/".join(w_name_split[1:])

            w_val = d.get(sub_w_name, None)

            if w_val is None:
                #io.log_err(f"Weight {w.name} was not loaded from file {filename}")
                tuples.append ( (w, w.initializer) )
            else:
                w_val = np.reshape( w_val, w.shape.as_list() )
                tuples.append ( (w, w_val) )

        nn.batch_set_value(tuples)

        return True

    def init_weights(self):
        nn.init_weights(self.get_weights())

nn.Saveable = Saveable
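Usage sketch of the naming contract this enforces (editor's illustration, not part of the commit): every variable must live under a scope equal to Saveable.name, and the pickled dict is keyed by the remainder of the variable name.

# Editor's sketch; the model class and filename are hypothetical.
model = SomeModel(name='encoder')     # variables become 'encoder/...'
model.init_weights()
model.save_weights('enc.npy')         # pickled dict with keys like 'conv1/weight:0'
ok = model.load_weights('enc.npy')    # False if the file does not exist;
                                      # missing keys fall back to their initializers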
core/leras/layers/TLU.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from core.leras import nn
tf = nn.tf

class TLU(nn.LayerBase):
    """
    Tensorflow implementation of
    Filter Response Normalization Layer: Eliminating Batch Dependence in the Training of Deep Neural Networks
    https://arxiv.org/pdf/1911.09737.pdf
    """
    def __init__(self, in_ch, dtype=None, **kwargs):
        self.in_ch = in_ch

        if dtype is None:
            dtype = nn.floatx
        self.dtype = dtype

        super().__init__(**kwargs)

    def build_weights(self):
        self.tau = tf.get_variable("tau", (self.in_ch,), dtype=self.dtype, initializer=tf.initializers.zeros() )

    def get_weights(self):
        return [self.tau]

    def forward(self, x):
        if nn.data_format == "NHWC":
            shape = (1,1,1,self.in_ch)
        else:
            shape = (1,self.in_ch,1,1)

        tau = tf.reshape ( self.tau, shape )
        return tf.math.maximum(x, tau)
nn.TLU = TLU
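Editor's note (not part of the commit): TLU is max(x, tau) with a learned per-channel threshold instead of ReLU's fixed zero; the FRN paper recommends pairing it with FRNorm2D above, since FRN output is not mean-centered. A hedged sketch:

# Editor's sketch; the wrapper class is hypothetical.
class NormAct(nn.ModelBase):
    def on_build(self, ch):
        self.frn = nn.FRNorm2D(ch)
        self.tlu = nn.TLU(ch)
    def forward(self, x):
        return self.tlu(self.frn(x))   # FRN + TLU in place of BatchNorm + ReLU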
core/leras/layers/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from .Saveable import *
from .LayerBase import *

from .Conv2D import *
from .Conv2DTranspose import *
from .Dense import *
from .BlurPool import *

from .BatchNorm2D import *
from .FRNorm2D import *

from .TLU import *
core/leras/models.py (deleted file; the import "from .models import initialize_models" removed from nn.py below identifies it)
@@ -1,367 +0,0 @@
import types
import numpy as np
from core.interact import interact as io

def initialize_models(nn):
    tf = nn.tf

    class ModelBase(nn.Saveable):
        def __init__(self, *args, name=None, **kwargs):
            super().__init__(name=name)
            self.layers = []
            self.layers_by_name = {}
            self.built = False
            self.args = args
            self.kwargs = kwargs
            self.run_placeholders = None

        def _build_sub(self, layer, name):
            if isinstance (layer, list):
                for i,sublayer in enumerate(layer):
                    self._build_sub(sublayer, f"{name}_{i}")
            elif isinstance (layer, nn.LayerBase) or \
                 isinstance (layer, ModelBase):

                if layer.name is None:
                    layer.name = name

                if isinstance (layer, nn.LayerBase):
                    with tf.variable_scope(layer.name):
                        layer.build_weights()
                elif isinstance (layer, ModelBase):
                    layer.build()

                self.layers.append (layer)
                self.layers_by_name[layer.name] = layer

        def xor_list(self, lst1, lst2):
            return [value for value in lst1+lst2 if (value not in lst1) or (value not in lst2) ]

        def build(self):
            with tf.variable_scope(self.name):

                current_vars = []
                generator = None
                while True:

                    if generator is None:
                        generator = self.on_build(*self.args, **self.kwargs)
                        if not isinstance(generator, types.GeneratorType):
                            generator = None

                    if generator is not None:
                        try:
                            next(generator)
                        except StopIteration:
                            generator = None

                    v = vars(self)
                    new_vars = self.xor_list (current_vars, list(v.keys()) )

                    for name in new_vars:
                        self._build_sub(v[name], name)

                    current_vars += new_vars

                    if generator is None:
                        break

            self.built = True

        #override
        def get_weights(self):
            if not self.built:
                self.build()

            weights = []
            for layer in self.layers:
                weights += layer.get_weights()
            return weights

        def get_layer_by_name(self, name):
            return self.layers_by_name.get(name, None)

        def get_layers(self):
            if not self.built:
                self.build()
            layers = []
            for layer in self.layers:
                if isinstance (layer, nn.LayerBase):
                    layers.append(layer)
                else:
                    layers += layer.get_layers()
            return layers

        #override
        def on_build(self, *args, **kwargs):
            """
            init model layers here

            use 'yield' if build is not finished,
            so that dependency models will be initialized first
            """
            pass

        #override
        def forward(self, *args, **kwargs):
            #flow layers/models/tensors here
            pass

        def __call__(self, *args, **kwargs):
            if not self.built:
                self.build()

            return self.forward(*args, **kwargs)

        def compute_output_shape(self, shapes):
            if not self.built:
                self.build()

            not_list = False
            if not isinstance(shapes, list):
                not_list = True
                shapes = [shapes]

            with tf.device('/CPU:0'):
                # CPU tensors will not impact any performance, only slightly RAM "leakage"
                phs = []
                for dtype,sh in shapes:
                    phs += [ tf.placeholder(dtype, sh) ]

                result = self.__call__(phs[0] if not_list else phs)

                if not isinstance(result, list):
                    result = [result]

                result_shapes = []

                for t in result:
                    result_shapes += [ t.shape.as_list() ]

                return result_shapes[0] if not_list else result_shapes

        def compute_output_channels(self, shapes):
            shape = self.compute_output_shape(shapes)
            shape_len = len(shape)

            if shape_len == 4:
                if nn.data_format == "NCHW":
                    return shape[1]
            return shape[-1]

        def build_for_run(self, shapes_list):
            if not isinstance(shapes_list, list):
                raise ValueError("shapes_list must be a list.")

            self.run_placeholders = []
            for dtype,sh in shapes_list:
                self.run_placeholders.append ( tf.placeholder(dtype, sh) )

            self.run_output = self.__call__(self.run_placeholders)

        def run (self, inputs):
            if self.run_placeholders is None:
                raise Exception ("Model didn't build for run.")

            if len(inputs) != len(self.run_placeholders):
                raise ValueError("len(inputs) != self.run_placeholders")

            feed_dict = {}
            for ph, inp in zip(self.run_placeholders, inputs):
                feed_dict[ph] = inp

            return nn.tf_sess.run ( self.run_output, feed_dict=feed_dict)

        def summary(self):
            layers = self.get_layers()
            layers_names = []
            layers_params = []

            max_len_str = 0
            max_len_param_str = 0
            delim_str = "-"

            total_params = 0

            #Get layers names and str length for delim
            for l in layers:
                if len(str(l))>max_len_str:
                    max_len_str = len(str(l))
                layers_names+=[str(l).capitalize()]

            #Get params for each layer
            layers_params = [ int(np.sum(np.prod(w.shape) for w in l.get_weights())) for l in layers ]
            total_params = np.sum(layers_params)

            #Get str length for delim
            for p in layers_params:
                if len(str(p))>max_len_param_str:
                    max_len_param_str=len(str(p))

            #Set delim
            for i in range(max_len_str+max_len_param_str+3):
                delim_str += "-"

            output = "\n"+delim_str+"\n"

            #Format model name str
            model_name_str = "| "+self.name.capitalize()
            len_model_name_str = len(model_name_str)
            for i in range(len(delim_str)-len_model_name_str):
                model_name_str+= " " if i!=(len(delim_str)-len_model_name_str-2) else " |"

            output += model_name_str +"\n"
            output += delim_str +"\n"

            #Format layers table
            for i in range(len(layers_names)):
                output += delim_str +"\n"

                l_name = layers_names[i]
                l_param = str(layers_params[i])
                l_param_str = ""
                if len(l_name)<=max_len_str:
                    for j in range(max_len_str - len(l_name)):
                        l_name+= " "

                if len(l_param)<=max_len_param_str:
                    for j in range(max_len_param_str - len(l_param)):
                        l_param_str+= " "

                l_param_str += l_param

                output +="| "+l_name+"|"+l_param_str+"| \n"

            output += delim_str +"\n"

            #Format sum of params
            total_params_str = "| Total params count: "+str(total_params)
            len_total_params_str = len(total_params_str)
            for i in range(len(delim_str)-len_total_params_str):
                total_params_str+= " " if i!=(len(delim_str)-len_total_params_str-2) else " |"

            output += total_params_str +"\n"
            output += delim_str +"\n"

            io.log_info(output)

    nn.ModelBase = ModelBase

    class PatchDiscriminator(nn.ModelBase):
        def on_build(self, patch_size, in_ch, base_ch=None, conv_kernel_initializer=None):
            suggested_base_ch, kernels_strides = patch_discriminator_kernels[patch_size]

            if base_ch is None:
                base_ch = suggested_base_ch

            prev_ch = in_ch
            self.convs = []
            for i, (kernel_size, strides) in enumerate(kernels_strides):
                cur_ch = base_ch * min( (2**i), 8 )

                self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=conv_kernel_initializer) )
                prev_ch = cur_ch

            self.out_conv = nn.Conv2D( prev_ch, 1, kernel_size=1, padding='VALID', kernel_initializer=conv_kernel_initializer)

        def forward(self, x):
            for conv in self.convs:
                x = tf.nn.leaky_relu( conv(x), 0.1 )
            return self.out_conv(x)

    nn.PatchDiscriminator = PatchDiscriminator

    class IllumDiscriminator(nn.ModelBase):
        def on_build(self, patch_size, in_ch, base_ch=None, conv_kernel_initializer=None):
            suggested_base_ch, kernels_strides = patch_discriminator_kernels[patch_size]
            if base_ch is None:
                base_ch = suggested_base_ch

            prev_ch = in_ch
            self.convs = []
            for i, (kernel_size, strides) in enumerate(kernels_strides):
                cur_ch = base_ch * min( (2**i), 8 )
                self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=conv_kernel_initializer) )
                prev_ch = cur_ch

            self.out1 = nn.Conv2D( 1, 1024, kernel_size=1, strides=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
            self.out2 = nn.Conv2D( 1024, 1, kernel_size=1, strides=1, padding='SAME', kernel_initializer=conv_kernel_initializer)

        def forward(self, x):
            for conv in self.convs:
                x = tf.nn.leaky_relu( conv(x), 0.1 )

            x = tf.reduce_mean(x, axis=nn.conv2d_ch_axis, keep_dims=True)

            x = self.out1(x)
            x = tf.nn.leaky_relu(x, 0.1 )
            x = self.out2(x)

            return x

    nn.IllumDiscriminator = IllumDiscriminator

    class CodeDiscriminator(nn.ModelBase):
        def on_build(self, in_ch, code_res, ch=256, conv_kernel_initializer=None):
            if conv_kernel_initializer is None:
                conv_kernel_initializer = nn.initializers.ca()

            n_downscales = 1 + code_res // 8

            self.convs = []
            prev_ch = in_ch
            for i in range(n_downscales):
                cur_ch = ch * min( (2**i), 8 )
                self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=4 if i == 0 else 3, strides=2, padding='SAME', kernel_initializer=conv_kernel_initializer) )
                prev_ch = cur_ch

            self.out_conv = nn.Conv2D( prev_ch, 1, kernel_size=1, padding='VALID', kernel_initializer=conv_kernel_initializer)

        def forward(self, x):
            for conv in self.convs:
                x = tf.nn.leaky_relu( conv(x), 0.1 )
            return self.out_conv(x)
    nn.CodeDiscriminator = CodeDiscriminator

    patch_discriminator_kernels = \
        {  1 : (512, [ [1,1] ]),
           2 : (512, [ [2,1] ]),
           3 : (512, [ [2,1], [2,1] ]),
           4 : (512, [ [2,2], [2,2] ]),
           5 : (512, [ [3,2], [2,2] ]),
           6 : (512, [ [4,2], [2,2] ]),
           7 : (512, [ [3,2], [3,2] ]),
           8 : (512, [ [4,2], [3,2] ]),
           9 : (512, [ [3,2], [4,2] ]),
          10 : (512, [ [4,2], [4,2] ]),
          11 : (512, [ [3,2], [3,2], [2,1] ]),
          12 : (512, [ [4,2], [3,2], [2,1] ]),
          13 : (512, [ [3,2], [4,2], [2,1] ]),
          14 : (512, [ [4,2], [4,2], [2,1] ]),
          15 : (512, [ [3,2], [3,2], [3,1] ]),
          16 : (512, [ [4,2], [3,2], [3,1] ]),
          17 : (512, [ [3,2], [4,2], [3,1] ]),
          18 : (512, [ [4,2], [4,2], [3,1] ]),
          19 : (512, [ [3,2], [3,2], [4,1] ]),
          20 : (512, [ [4,2], [3,2], [4,1] ]),
          21 : (512, [ [3,2], [4,2], [4,1] ]),
          22 : (512, [ [4,2], [4,2], [4,1] ]),
          23 : (256, [ [3,2], [3,2], [3,2], [2,1] ]),
          24 : (256, [ [4,2], [3,2], [3,2], [2,1] ]),
          25 : (256, [ [3,2], [4,2], [3,2], [2,1] ]),
          26 : (256, [ [4,2], [4,2], [3,2], [2,1] ]),
          27 : (256, [ [3,2], [4,2], [4,2], [2,1] ]),
          28 : (256, [ [4,2], [3,2], [4,2], [2,1] ]),
          29 : (256, [ [3,2], [4,2], [4,2], [2,1] ]),
          30 : (256, [ [4,2], [4,2], [4,2], [2,1] ]),
          31 : (256, [ [3,2], [3,2], [3,2], [3,1] ]),
          32 : (256, [ [4,2], [3,2], [3,2], [3,1] ]),
          33 : (256, [ [3,2], [4,2], [3,2], [3,1] ]),
          34 : (256, [ [4,2], [4,2], [3,2], [3,1] ]),
          35 : (256, [ [3,2], [4,2], [4,2], [3,1] ]),
          36 : (256, [ [4,2], [3,2], [4,2], [3,1] ]),
          37 : (256, [ [3,2], [4,2], [4,2], [3,1] ]),
          38 : (256, [ [4,2], [4,2], [4,2], [3,1] ]),
        }
core/leras/models/CodeDiscriminator.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from core.leras import nn
tf = nn.tf

class CodeDiscriminator(nn.ModelBase):
    def on_build(self, in_ch, code_res, ch=256, conv_kernel_initializer=None):
        n_downscales = 1 + code_res // 8

        self.convs = []
        prev_ch = in_ch
        for i in range(n_downscales):
            cur_ch = ch * min( (2**i), 8 )
            self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=4 if i == 0 else 3, strides=2, padding='SAME', kernel_initializer=conv_kernel_initializer) )
            prev_ch = cur_ch

        self.out_conv = nn.Conv2D( prev_ch, 1, kernel_size=1, padding='VALID', kernel_initializer=conv_kernel_initializer)

    def forward(self, x):
        for conv in self.convs:
            x = tf.nn.leaky_relu( conv(x), 0.1 )
        return self.out_conv(x)

nn.CodeDiscriminator = CodeDiscriminator
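Usage sketch (editor's illustration, not part of this commit):

# Discriminate a spatial latent code; code_res=16 gives 1 + 16//8 = 3
# stride-2 convolutions before the 1x1 output head. 'latent_code' is a
# hypothetical tensor.
code_disc = nn.CodeDiscriminator(in_ch=256, code_res=16, name='code_disc')
logits = code_disc(latent_code)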
core/leras/models/ModelBase.py (new file, 249 lines)
@@ -0,0 +1,249 @@
import types
import numpy as np
from core.interact import interact as io
from core.leras import nn
tf = nn.tf

class ModelBase(nn.Saveable):
    def __init__(self, *args, name=None, **kwargs):
        super().__init__(name=name)
        self.layers = []
        self.layers_by_name = {}
        self.built = False
        self.args = args
        self.kwargs = kwargs
        self.run_placeholders = None

    def _build_sub(self, layer, name):
        if isinstance (layer, list):
            for i,sublayer in enumerate(layer):
                self._build_sub(sublayer, f"{name}_{i}")
        elif isinstance (layer, nn.LayerBase) or \
             isinstance (layer, ModelBase):

            if layer.name is None:
                layer.name = name

            if isinstance (layer, nn.LayerBase):
                with tf.variable_scope(layer.name):
                    layer.build_weights()
            elif isinstance (layer, ModelBase):
                layer.build()

            self.layers.append (layer)
            self.layers_by_name[layer.name] = layer

    def xor_list(self, lst1, lst2):
        return [value for value in lst1+lst2 if (value not in lst1) or (value not in lst2) ]

    def build(self):
        with tf.variable_scope(self.name):

            current_vars = []
            generator = None
            while True:

                if generator is None:
                    generator = self.on_build(*self.args, **self.kwargs)
                    if not isinstance(generator, types.GeneratorType):
                        generator = None

                if generator is not None:
                    try:
                        next(generator)
                    except StopIteration:
                        generator = None

                v = vars(self)
                new_vars = self.xor_list (current_vars, list(v.keys()) )

                for name in new_vars:
                    self._build_sub(v[name], name)

                current_vars += new_vars

                if generator is None:
                    break

        self.built = True

    #override
    def get_weights(self):
        if not self.built:
            self.build()

        weights = []
        for layer in self.layers:
            weights += layer.get_weights()
        return weights

    def get_layer_by_name(self, name):
        return self.layers_by_name.get(name, None)

    def get_layers(self):
        if not self.built:
            self.build()
        layers = []
        for layer in self.layers:
            if isinstance (layer, nn.LayerBase):
                layers.append(layer)
            else:
                layers += layer.get_layers()
        return layers

    #override
    def on_build(self, *args, **kwargs):
        """
        init model layers here

        use 'yield' if build is not finished,
        so that dependency models will be initialized first
        """
        pass

    #override
    def forward(self, *args, **kwargs):
        #flow layers/models/tensors here
        pass

    def __call__(self, *args, **kwargs):
        if not self.built:
            self.build()

        return self.forward(*args, **kwargs)

    def compute_output_shape(self, shapes):
        if not self.built:
            self.build()

        not_list = False
        if not isinstance(shapes, list):
            not_list = True
            shapes = [shapes]

        with tf.device('/CPU:0'):
            # CPU tensors will not impact any performance, only slightly RAM "leakage"
            phs = []
            for dtype,sh in shapes:
                phs += [ tf.placeholder(dtype, sh) ]

            result = self.__call__(phs[0] if not_list else phs)

            if not isinstance(result, list):
                result = [result]

            result_shapes = []

            for t in result:
                result_shapes += [ t.shape.as_list() ]

            return result_shapes[0] if not_list else result_shapes

    def compute_output_channels(self, shapes):
        shape = self.compute_output_shape(shapes)
        shape_len = len(shape)

        if shape_len == 4:
            if nn.data_format == "NCHW":
                return shape[1]
        return shape[-1]

    def build_for_run(self, shapes_list):
        if not isinstance(shapes_list, list):
            raise ValueError("shapes_list must be a list.")

        self.run_placeholders = []
        for dtype,sh in shapes_list:
            self.run_placeholders.append ( tf.placeholder(dtype, sh) )

        self.run_output = self.__call__(self.run_placeholders)

    def run (self, inputs):
        if self.run_placeholders is None:
            raise Exception ("Model didn't build for run.")

        if len(inputs) != len(self.run_placeholders):
            raise ValueError("len(inputs) != self.run_placeholders")

        feed_dict = {}
        for ph, inp in zip(self.run_placeholders, inputs):
            feed_dict[ph] = inp

        return nn.tf_sess.run ( self.run_output, feed_dict=feed_dict)

    def summary(self):
        layers = self.get_layers()
        layers_names = []
        layers_params = []

        max_len_str = 0
        max_len_param_str = 0
        delim_str = "-"

        total_params = 0

        #Get layers names and str length for delim
        for l in layers:
            if len(str(l))>max_len_str:
                max_len_str = len(str(l))
            layers_names+=[str(l).capitalize()]

        #Get params for each layer
        layers_params = [ int(np.sum(np.prod(w.shape) for w in l.get_weights())) for l in layers ]
        total_params = np.sum(layers_params)

        #Get str length for delim
        for p in layers_params:
            if len(str(p))>max_len_param_str:
                max_len_param_str=len(str(p))

        #Set delim
        for i in range(max_len_str+max_len_param_str+3):
            delim_str += "-"

        output = "\n"+delim_str+"\n"

        #Format model name str
        model_name_str = "| "+self.name.capitalize()
        len_model_name_str = len(model_name_str)
        for i in range(len(delim_str)-len_model_name_str):
            model_name_str+= " " if i!=(len(delim_str)-len_model_name_str-2) else " |"

        output += model_name_str +"\n"
        output += delim_str +"\n"

        #Format layers table
        for i in range(len(layers_names)):
            output += delim_str +"\n"

            l_name = layers_names[i]
            l_param = str(layers_params[i])
            l_param_str = ""
            if len(l_name)<=max_len_str:
                for j in range(max_len_str - len(l_name)):
                    l_name+= " "

            if len(l_param)<=max_len_param_str:
                for j in range(max_len_param_str - len(l_param)):
                    l_param_str+= " "

            l_param_str += l_param

            output +="| "+l_name+"|"+l_param_str+"| \n"

        output += delim_str +"\n"

        #Format sum of params
        total_params_str = "| Total params count: "+str(total_params)
        len_total_params_str = len(total_params_str)
        for i in range(len(delim_str)-len_total_params_str):
            total_params_str+= " " if i!=(len(delim_str)-len_total_params_str-2) else " |"

        output += total_params_str +"\n"
        output += delim_str +"\n"

        io.log_info(output)

nn.ModelBase = ModelBase
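Usage sketch (editor's illustration, not part of this commit): attributes assigned inside on_build are auto-discovered via vars(self) and built inside variable scopes; a 'yield' splits building into phases so earlier layers are fully built before later code in on_build runs.

# Editor's sketch; class and names are hypothetical, assumes nn.initialize().
class TinyModel(nn.ModelBase):
    def on_build(self, in_ch, out_ch):
        self.conv1 = nn.Conv2D(in_ch, 32, kernel_size=3, padding='SAME')
        yield                          # conv1 is built before the next phase
        self.conv2 = nn.Conv2D(32, out_ch, kernel_size=3, padding='SAME')

    def forward(self, x):
        x = tf.nn.leaky_relu(self.conv1(x), 0.1)
        return self.conv2(x)

model = TinyModel(3, 8, name='tiny')
model.init_weights()                   # builds lazily, then runs initializers
model.summary()                        # prints the layer/params table above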
core/leras/models/PatchDiscriminator.py (new file, 69 lines)
@@ -0,0 +1,69 @@
from core.leras import nn
tf = nn.tf


patch_discriminator_kernels = \
    {  1 : (512, [ [1,1] ]),
       2 : (512, [ [2,1] ]),
       3 : (512, [ [2,1], [2,1] ]),
       4 : (512, [ [2,2], [2,2] ]),
       5 : (512, [ [3,2], [2,2] ]),
       6 : (512, [ [4,2], [2,2] ]),
       7 : (512, [ [3,2], [3,2] ]),
       8 : (512, [ [4,2], [3,2] ]),
       9 : (512, [ [3,2], [4,2] ]),
      10 : (512, [ [4,2], [4,2] ]),
      11 : (512, [ [3,2], [3,2], [2,1] ]),
      12 : (512, [ [4,2], [3,2], [2,1] ]),
      13 : (512, [ [3,2], [4,2], [2,1] ]),
      14 : (512, [ [4,2], [4,2], [2,1] ]),
      15 : (512, [ [3,2], [3,2], [3,1] ]),
      16 : (512, [ [4,2], [3,2], [3,1] ]),
      17 : (512, [ [3,2], [4,2], [3,1] ]),
      18 : (512, [ [4,2], [4,2], [3,1] ]),
      19 : (512, [ [3,2], [3,2], [4,1] ]),
      20 : (512, [ [4,2], [3,2], [4,1] ]),
      21 : (512, [ [3,2], [4,2], [4,1] ]),
      22 : (512, [ [4,2], [4,2], [4,1] ]),
      23 : (256, [ [3,2], [3,2], [3,2], [2,1] ]),
      24 : (256, [ [4,2], [3,2], [3,2], [2,1] ]),
      25 : (256, [ [3,2], [4,2], [3,2], [2,1] ]),
      26 : (256, [ [4,2], [4,2], [3,2], [2,1] ]),
      27 : (256, [ [3,2], [4,2], [4,2], [2,1] ]),
      28 : (256, [ [4,2], [3,2], [4,2], [2,1] ]),
      29 : (256, [ [3,2], [4,2], [4,2], [2,1] ]),
      30 : (256, [ [4,2], [4,2], [4,2], [2,1] ]),
      31 : (256, [ [3,2], [3,2], [3,2], [3,1] ]),
      32 : (256, [ [4,2], [3,2], [3,2], [3,1] ]),
      33 : (256, [ [3,2], [4,2], [3,2], [3,1] ]),
      34 : (256, [ [4,2], [4,2], [3,2], [3,1] ]),
      35 : (256, [ [3,2], [4,2], [4,2], [3,1] ]),
      36 : (256, [ [4,2], [3,2], [4,2], [3,1] ]),
      37 : (256, [ [3,2], [4,2], [4,2], [3,1] ]),
      38 : (256, [ [4,2], [4,2], [4,2], [3,1] ]),
    }


class PatchDiscriminator(nn.ModelBase):
    def on_build(self, patch_size, in_ch, base_ch=None, conv_kernel_initializer=None):
        suggested_base_ch, kernels_strides = patch_discriminator_kernels[patch_size]

        if base_ch is None:
            base_ch = suggested_base_ch

        prev_ch = in_ch
        self.convs = []
        for i, (kernel_size, strides) in enumerate(kernels_strides):
            cur_ch = base_ch * min( (2**i), 8 )

            self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=conv_kernel_initializer) )
            prev_ch = cur_ch

        self.out_conv = nn.Conv2D( prev_ch, 1, kernel_size=1, padding='VALID', kernel_initializer=conv_kernel_initializer)

    def forward(self, x):
        for conv in self.convs:
            x = tf.nn.leaky_relu( conv(x), 0.1 )
        return self.out_conv(x)

nn.PatchDiscriminator = PatchDiscriminator
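Usage sketch (editor's illustration, not part of this commit):

# patch_size selects a kernel/stride schedule from the table above whose
# receptive field is roughly patch_size pixels; the output is a grid of
# per-patch real/fake logits rather than a single scalar. 'img' is a
# hypothetical image tensor.
patch_disc = nn.PatchDiscriminator(patch_size=16, in_ch=3, name='D_patch')
logits_map = patch_disc(img)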
core/leras/models/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .ModelBase import *
from .PatchDiscriminator import *
from .CodeDiscriminator import *
from .Ternaus import *
core/leras/nn.py
@@ -11,7 +11,7 @@ Provides:
+ convenient and understandable logic

Reasons why we cannot import tensorflow or any tensorflow.sub modules right here:
1) change env variables based on DeviceConfig before import tensorflow
1) program is changing env variables based on DeviceConfig before import tensorflow
2) multiprocesses will import tensorflow every spawn

NCHW speeds up training by 10-20%.
@@ -19,12 +19,11 @@ NCHW speeds up training by 10-20%.
import os
import sys
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from pathlib import Path

import numpy as np

from core.interact import interact as io

from .device import Devices
@@ -40,57 +39,7 @@ class nn():
    conv2d_ch_axis = None
    conv2d_spatial_axes = None

    tf_floatx = None
    np_floatx = None

    # Tensor ops
    tf_get_value = None
    tf_batch_set_value = None
    tf_init_weights = None
    tf_gradients = None
    tf_average_gv_list = None
    tf_average_tensor_list = None
    tf_concat = None
    tf_gelu = None
    tf_upsample2d = None
    tf_resize2d_bilinear = None
    tf_flatten = None
    tf_max_pool = None
    tf_reshape_4D = None
    tf_random_binomial = None
    tf_gaussian_blur = None
    tf_style_loss = None
    tf_dssim = None
    tf_space_to_depth = None
    tf_depth_to_space = None

    # Layers
    Saveable = None
    LayerBase = None
    ModelBase = None
    Conv2D = None
    Conv2DTranspose = None
    BlurPool = None
    Dense = None
    InstanceNorm2D = None
    BatchNorm2D = None
    AdaIN = None

    # Initializers
    initializers = None

    # Optimizers
    TFBaseOptimizer = None
    TFRMSpropOptimizer = None

    # Models
    PatchDiscriminator = None
    IllumDiscriminator = None
    CodeDiscriminator = None

    # Archis
    get_ae_models = None
    get_ae_models_chervonij = None
    floatx = None

    @staticmethod
    def initialize(device_config=None, floatx="float32", data_format="NHWC"):
@@ -98,15 +47,17 @@ class nn():
        if nn.tf is None:
            if device_config is None:
                device_config = nn.getCurrentDeviceConfig()
            else:
                nn.setCurrentDeviceConfig(device_config)
            nn.setCurrentDeviceConfig(device_config)

            # Manipulate environment variables before import tensorflow

            if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
                os.environ.pop('CUDA_VISIBLE_DEVICES')

            first_run = False
            if len(device_config.devices) != 0:
                if sys.platform[0:3] == 'win':
                    # Windows specific env vars
                    if all( [ x.name == device_config.devices[0].name for x in device_config.devices ] ):
                        devices_str = "_" + device_config.devices[0].name.replace(' ','_')
                    else:
@@ -123,18 +74,25 @@ class nn():
            os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2'
            os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # tf log errors only

            import warnings
            warnings.simplefilter(action='ignore', category=FutureWarning)

            if first_run:
                io.log_info("Caching GPU kernels...")

            import tensorflow as tf
            nn.tf = tf

            import logging
            # Disable tensorflow warnings
            logging.getLogger('tensorflow').setLevel(logging.ERROR)

            nn.tf = tf

            # Initialize framework
            import core.leras.ops
            import core.leras.layers
            import core.leras.initializers
            import core.leras.optimizers
            import core.leras.models
            import core.leras.archis

            # Configure tensorflow session-config
            if len(device_config.devices) == 0:
                nn.tf_default_device = "/CPU:0"
                config = tf.ConfigProto(device_count={'GPU': 0})
@@ -146,20 +104,6 @@ class nn():
                config.gpu_options.force_gpu_compatible = True
                config.gpu_options.allow_growth = True
            nn.tf_sess_config = config

            from .tensor_ops import initialize_tensor_ops
            from .layers import initialize_layers
            from .initializers import initialize_initializers
            from .optimizers import initialize_optimizers
            from .models import initialize_models
            from .archis import initialize_archis

            initialize_tensor_ops(nn)
            initialize_layers(nn)
            initialize_initializers(nn)
            initialize_optimizers(nn)
            initialize_models(nn)
            initialize_archis(nn)

            if nn.tf_sess is None:
                nn.tf_sess = tf.Session(config=nn.tf_sess_config)
@@ -182,8 +126,7 @@ class nn():
        """
        set default float type for all layers when dtype is None for them
        """
        nn.tf_floatx = tf_dtype
        nn.np_floatx = tf_dtype.as_numpy_dtype
        nn.floatx = tf_dtype

    @staticmethod
    def set_data_format(data_format):
@@ -231,7 +174,7 @@ class nn():
        nn.current_DeviceConfig = device_config

    @staticmethod
    def tf_reset_session():
    def reset_session():
        if nn.tf is not None:
            if nn.tf_sess is not None:
                nn.tf.reset_default_graph()
@@ -239,14 +182,14 @@ class nn():
            nn.tf_sess = nn.tf.Session(config=nn.tf_sess_config)

    @staticmethod
    def tf_close_session():
    def close_session():
        if nn.tf_sess is not None:
            nn.tf.reset_default_graph()
            nn.tf_sess.close()
            nn.tf_sess = None

    @staticmethod
    def tf_get_current_device():
    def get_current_device():
        # Undocumented access to last tf.device(...)
        objs = nn.tf.get_default_graph()._device_function_stack.peek_objs()
        if len(objs) != 0:
@@ -254,7 +197,7 @@ class nn():
        return nn.tf_default_device

    @staticmethod
    def ask_choose_device_idxs(choose_only_one=False, allow_cpu=True, suggest_best_multi_gpu=False, suggest_all_gpu=False, return_device_config=False):
    def ask_choose_device_idxs(choose_only_one=False, allow_cpu=True, suggest_best_multi_gpu=False, suggest_all_gpu=False):
        devices = Devices.getDevices()
        if len(devices) == 0:
            return []
@@ -310,12 +253,13 @@ class nn():
                pass
        io.log_info ("")

        if return_device_config:
            return nn.DeviceConfig.GPUIndexes(choosed_idxs)
        else:
            return choosed_idxs
        return choosed_idxs

    class DeviceConfig():
        @staticmethod
        def ask_choose_device(*args, **kwargs):
            return nn.DeviceConfig.GPUIndexes( nn.ask_choose_device_idxs(*args,**kwargs) )

        def __init__ (self, devices=None):
            devices = devices or []
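A hedged sketch of the refactored entry points (editor's illustration, not part of this commit):

# Pick devices interactively, initialize the framework, then tear down the
# session with the renamed helper.
device_config = nn.DeviceConfig.ask_choose_device(choose_only_one=True)
nn.initialize(device_config, floatx="float32", data_format="NCHW")
# ... build models and train ...
nn.close_session()        # was nn.tf_close_session() before this commit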
core/leras/ops/__init__.py (new file, 345 lines)
@@ -0,0 +1,345 @@
import numpy as np
from core.leras import nn
tf = nn.tf
from tensorflow.python.ops import array_ops, random_ops, math_ops, sparse_ops, gradients
from tensorflow.python.framework import sparse_tensor

def tf_get_value(tensor):
    return nn.tf_sess.run (tensor)
nn.tf_get_value = tf_get_value


def batch_set_value(tuples):
    if len(tuples) != 0:
        with nn.tf.device('/CPU:0'):
            assign_ops = []
            feed_dict = {}

            for x, value in tuples:
                if isinstance(value, nn.tf.Operation) or \
                   isinstance(value, nn.tf.Variable):
                    assign_ops.append(value)
                else:
                    value = np.asarray(value, dtype=x.dtype.as_numpy_dtype)
                    assign_placeholder = nn.tf.placeholder( x.dtype.base_dtype, shape=[None]*value.ndim )
                    assign_op = nn.tf.assign (x, assign_placeholder )
                    assign_ops.append(assign_op)
                    feed_dict[assign_placeholder] = value

            nn.tf_sess.run(assign_ops, feed_dict=feed_dict)
nn.batch_set_value = batch_set_value

def init_weights(weights):
    ops = []

    ca_tuples_w = []
    ca_tuples = []
    for w in weights:
        initializer = w.initializer
        for input in initializer.inputs:
            if "_cai_" in input.name:
                ca_tuples_w.append (w)
                ca_tuples.append ( (w.shape.as_list(), w.dtype.as_numpy_dtype) )
                break
        else:
            ops.append (initializer)

    if len(ops) != 0:
        nn.tf_sess.run (ops)

    if len(ca_tuples) != 0:
        nn.batch_set_value( [*zip(ca_tuples_w, nn.initializers.ca.generate_batch (ca_tuples))] )
nn.init_weights = init_weights

def tf_gradients ( loss, vars ):
    grads = gradients.gradients(loss, vars, colocate_gradients_with_ops=True )
    gv = [*zip(grads,vars)]
    for g,v in gv:
        if g is None:
            raise Exception(f"No gradient for variable {v.name}")
    return gv
nn.gradients = tf_gradients

def average_gv_list(grad_var_list, tf_device_string=None):
    if len(grad_var_list) == 1:
        return grad_var_list[0]

    e = tf.device(tf_device_string) if tf_device_string is not None else None
    if e is not None: e.__enter__()
    result = []
    for i, (gv) in enumerate(grad_var_list):
        for j,(g,v) in enumerate(gv):
            g = tf.expand_dims(g, 0)
            if i == 0:
                result += [ [[g], v] ]
            else:
                result[j][0] += [g]

    for i,(gs,v) in enumerate(result):
        result[i] = ( tf.reduce_mean( tf.concat (gs, 0), 0 ), v )
    if e is not None: e.__exit__(None,None,None)
    return result
nn.average_gv_list = average_gv_list

def average_tensor_list(tensors_list, tf_device_string=None):
    if len(tensors_list) == 1:
        return tensors_list[0]

    e = tf.device(tf_device_string) if tf_device_string is not None else None
    if e is not None: e.__enter__()
    result = tf.reduce_mean(tf.concat ([tf.expand_dims(t, 0) for t in tensors_list], 0), 0)
    if e is not None: e.__exit__(None,None,None)
    return result
nn.average_tensor_list = average_tensor_list

def concat (tensors_list, axis):
    """
    Like tf.concat, but skips the op entirely when the list has a single element.
    """
    if len(tensors_list) == 1:
        return tensors_list[0]
    return tf.concat(tensors_list, axis)
nn.concat = concat

def gelu(x):
    cdf = 0.5 * (1.0 + tf.nn.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
    return x * cdf
nn.gelu = gelu

def upsample2d(x, size=2):
    if nn.data_format == "NCHW":
        b,c,h,w = x.shape.as_list()
        x = tf.reshape (x, (-1,c,h,1,w,1) )
        x = tf.tile(x, (1,1,1,size,1,size) )
        x = tf.reshape (x, (-1,c,h*size,w*size) )
        return x
    else:
        return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )
nn.upsample2d = upsample2d

def resize2d_bilinear(x, size=2):
    h = x.shape[nn.conv2d_spatial_axes[0]].value
    w = x.shape[nn.conv2d_spatial_axes[1]].value

    if nn.data_format == "NCHW":
        x = tf.transpose(x, (0,2,3,1))

    if size > 0:
        new_size = (h*size,w*size)
    else:
        new_size = (h//-size,w//-size)

    x = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.BILINEAR)

    if nn.data_format == "NCHW":
        x = tf.transpose(x, (0,3,1,2))

    return x
nn.resize2d_bilinear = resize2d_bilinear


def flatten(x):
    if nn.data_format == "NHWC":
        # match NCHW version in order to switch data_format without problems
        x = tf.transpose(x, (0,3,1,2) )
    return tf.reshape (x, (-1, np.prod(x.shape[1:])) )

nn.flatten = flatten

def max_pool(x, kernel_size=2, strides=2):
    if nn.data_format == "NHWC":
        return tf.nn.max_pool(x, [1,kernel_size,kernel_size,1], [1,strides,strides,1], 'SAME', data_format=nn.data_format)
    else:
        return tf.nn.max_pool(x, [1,1,kernel_size,kernel_size], [1,1,strides,strides], 'SAME', data_format=nn.data_format)

nn.max_pool = max_pool

def reshape_4D(x, w,h,c):
    if nn.data_format == "NHWC":
        # match NCHW version in order to switch data_format without problems
        x = tf.reshape (x, (-1,c,h,w))
        x = tf.transpose(x, (0,2,3,1) )
        return x
    else:
        return tf.reshape (x, (-1,c,h,w))
nn.reshape_4D = reshape_4D

def random_binomial(shape, p=0.0, dtype=None, seed=None):
    if dtype is None:
        dtype=tf.float32

    if seed is None:
        seed = np.random.randint(10e6)
    return array_ops.where(
        random_ops.random_uniform(shape, dtype=tf.float16, seed=seed) < p,
        array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype))
nn.random_binomial = random_binomial

def gaussian_blur(input, radius=2.0):
    def gaussian(x, mu, sigma):
        return np.exp(-(float(x) - float(mu)) ** 2 / (2 * sigma ** 2))

    def make_kernel(sigma):
        kernel_size = max(3, int(2 * 2 * sigma + 1))
        mean = np.floor(0.5 * kernel_size)
        kernel_1d = np.array([gaussian(x, mean, sigma) for x in range(kernel_size)])
        np_kernel = np.outer(kernel_1d, kernel_1d).astype(np.float32)
        kernel = np_kernel / np.sum(np_kernel)
        return kernel, kernel_size

    gauss_kernel, kernel_size = make_kernel(radius)
    padding = kernel_size//2
    if padding != 0:
        if nn.data_format == "NHWC":
            padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
        else:
            padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
    else:
        padding = None
    gauss_kernel = gauss_kernel[:,:,None,None]

    x = input
    k = tf.tile (gauss_kernel, (1,1,x.shape[nn.conv2d_ch_axis],1) )
    x = tf.pad(x, padding )
    x = tf.nn.depthwise_conv2d(x, k, strides=[1,1,1,1], padding='VALID', data_format=nn.data_format)
    return x
nn.gaussian_blur = gaussian_blur

def style_loss(target, style, gaussian_blur_radius=0.0, loss_weight=1.0, step_size=1):
    def sd(content, style, loss_weight):
        content_nc = content.shape[ nn.conv2d_ch_axis ]
        style_nc = style.shape[nn.conv2d_ch_axis]
        if content_nc != style_nc:
            raise Exception("style_loss() content_nc != style_nc")
        c_mean, c_var = tf.nn.moments(content, axes=nn.conv2d_spatial_axes, keep_dims=True)
        s_mean, s_var = tf.nn.moments(style, axes=nn.conv2d_spatial_axes, keep_dims=True)
        c_std, s_std = tf.sqrt(c_var + 1e-5), tf.sqrt(s_var + 1e-5)
        mean_loss = tf.reduce_sum(tf.square(c_mean-s_mean), axis=[1,2,3])
        std_loss = tf.reduce_sum(tf.square(c_std-s_std), axis=[1,2,3])
        return (mean_loss + std_loss) * ( loss_weight / content_nc.value )

    if gaussian_blur_radius > 0.0:
        target = gaussian_blur(target, gaussian_blur_radius)
        style = gaussian_blur(style, gaussian_blur_radius)

    return sd( target, style, loss_weight=loss_weight )

nn.style_loss = style_loss

def dssim(img1, img2, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
    if img1.dtype != img2.dtype:
        raise ValueError("img1.dtype != img2.dtype")

    not_float32 = img1.dtype != tf.float32

    if not_float32:
        img_dtype = img1.dtype
        img1 = tf.cast(img1, tf.float32)
        img2 = tf.cast(img2, tf.float32)

    kernel = np.arange(0, filter_size, dtype=np.float32)
    kernel -= (filter_size - 1 ) / 2.0
    kernel = kernel**2
    kernel *= ( -0.5 / (filter_sigma**2) )
    kernel = np.reshape (kernel, (1,-1)) + np.reshape(kernel, (-1,1) )
    kernel = tf.constant ( np.reshape (kernel, (1,-1)), dtype=tf.float32 )
    kernel = tf.nn.softmax(kernel)
    kernel = tf.reshape (kernel, (filter_size, filter_size, 1, 1))
    kernel = tf.tile (kernel, (1,1, img1.shape[ nn.conv2d_ch_axis ] ,1))

    def reducer(x):
        return tf.nn.depthwise_conv2d(x, kernel, strides=[1,1,1,1], padding='VALID', data_format=nn.data_format)

    c1 = (k1 * max_val) ** 2
    c2 = (k2 * max_val) ** 2

    mean0 = reducer(img1)
    mean1 = reducer(img2)
    num0 = mean0 * mean1 * 2.0
    den0 = tf.square(mean0) + tf.square(mean1)
    luminance = (num0 + c1) / (den0 + c1)

    num1 = reducer(img1 * img2) * 2.0
    den1 = reducer(tf.square(img1) + tf.square(img2))
    c2 *= 1.0 # compensation factor
    cs = (num1 - num0 + c2) / (den1 - den0 + c2)

    ssim_val = tf.reduce_mean(luminance * cs, axis=nn.conv2d_spatial_axes )
    dssim = (1.0 - ssim_val ) / 2.0

    if not_float32:
        dssim = tf.cast(dssim, img_dtype)
    return dssim

nn.dssim = dssim
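Usage sketch of the perceptual losses above (editor's illustration, not part of this commit):

# 'pred' and 'target' are hypothetical image tensors in [0,1]; axes follow
# the nn.data_format chosen at initialize time.
dssim_loss = tf.reduce_mean( nn.dssim(pred, target, max_val=1.0) )
style = nn.style_loss(pred, target, gaussian_blur_radius=2.0, loss_weight=1.0)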
def space_to_depth(x, size):
    if nn.data_format == "NHWC":
        # match NCHW version in order to switch data_format without problems
        b,h,w,c = x.shape.as_list()
        oh, ow = h // size, w // size
        x = tf.reshape(x, (-1, size, oh, size, ow, c))
        x = tf.transpose(x, (0, 2, 4, 1, 3, 5))
        x = tf.reshape(x, (-1, oh, ow, size* size* c ))
        return x
    else:
        return tf.space_to_depth(x, size, data_format=nn.data_format)
nn.space_to_depth = space_to_depth

def depth_to_space(x, size):
    if nn.data_format == "NHWC":
        # match NCHW version in order to switch data_format without problems

        b,h,w,c = x.shape.as_list()
        oh, ow = h * size, w * size
        oc = c // (size * size)

        x = tf.reshape(x, (-1, h, w, size, size, oc, ) )
        x = tf.transpose(x, (0, 1, 3, 2, 4, 5))
        x = tf.reshape(x, (-1, oh, ow, oc, ))
        return x
    else:
        return tf.depth_to_space(x, size, data_format=nn.data_format)
nn.depth_to_space = depth_to_space

def rgb_to_lab(srgb):
    srgb_pixels = tf.reshape(srgb, [-1, 3])
    linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
    exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
    rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
    rgb_to_xyz = tf.constant([
        #    X        Y         Z
        [0.412453, 0.212671, 0.019334], # R
        [0.357580, 0.715160, 0.119193], # G
        [0.180423, 0.072169, 0.950227], # B
    ])
    xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)

    xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754])

    epsilon = 6/29
    linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
    exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
    fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask

    fxfyfz_to_lab = tf.constant([
        #   l       a       b
        [  0.0,  500.0,    0.0], # fx
        [116.0, -500.0,  200.0], # fy
        [  0.0,    0.0, -200.0], # fz
    ])
    lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])
    return tf.reshape(lab_pixels, tf.shape(srgb))
nn.rgb_to_lab = rgb_to_lab
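Editor's note (not part of the commit): rgb_to_lab expects sRGB in [0,1] (the 0.04045 threshold implies this range) and returns L in roughly [0,100] with a/b centered on zero, which makes it handy for perceptual color losses:

# Editor's sketch; 'pred' and 'target' are hypothetical image tensors in [0,1].
lab_loss = tf.reduce_mean(tf.square( nn.rgb_to_lab(pred) - nn.rgb_to_lab(target) ))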
"""
def tf_suppress_lower_mean(t, eps=0.00001):
    if t.shape.ndims != 1:
        raise ValueError("tf_suppress_lower_mean: t rank must be 1")
    t_mean_eps = tf.reduce_mean(t) - eps
    q = tf.clip_by_value(t, t_mean_eps, tf.reduce_max(t) )
    q = tf.clip_by_value(q-t_mean_eps, 0, eps)
    q = q * (t/eps)
    return q
"""
core/leras/optimizers.py (deleted file; the import "from .optimizers import initialize_optimizers" removed from nn.py above identifies it)
@@ -1,110 +0,0 @@
import copy

def initialize_optimizers(nn):
    tf = nn.tf
    from tensorflow.python.ops import state_ops, control_flow_ops

    class TFBaseOptimizer(nn.Saveable):
        def __init__(self, name=None):
            super().__init__(name=name)

        def tf_clip_norm(self, g, c, n):
            """Clip the gradient `g` if the L2 norm `n` exceeds `c`.
            # Arguments
                g: Tensor, the gradient tensor
                c: float >= 0. Gradients will be clipped
                    when their L2 norm exceeds this value.
                n: Tensor, actual norm of `g`.
            # Returns
                Tensor, the gradient clipped if required.
            """
            if c <= 0:  # if clipnorm == 0 no need to add ops to the graph
                return g

            condition = n >= c
            then_expression = tf.scalar_mul(c / n, g)
            else_expression = g

            # saving the shape to avoid converting sparse tensor to dense
            if isinstance(then_expression, tf.Tensor):
                g_shape = copy.copy(then_expression.get_shape())
            elif isinstance(then_expression, tf.IndexedSlices):
                g_shape = copy.copy(then_expression.dense_shape)
            if condition.dtype != tf.bool:
                condition = tf.cast(condition, 'bool')
            g = tf.cond(condition,
                        lambda: then_expression,
                        lambda: else_expression)
            if isinstance(then_expression, tf.Tensor):
                g.set_shape(g_shape)
            elif isinstance(then_expression, tf.IndexedSlices):
                g._dense_shape = g_shape

            return g
    nn.TFBaseOptimizer = TFBaseOptimizer

    class TFRMSpropOptimizer(TFBaseOptimizer):
        def __init__(self, lr=0.001, rho=0.9, lr_dropout=1.0, epsilon=1e-7, clipnorm=0.0, name=None):
            super().__init__(name=name)

            if name is None:
                raise ValueError('name must be defined.')

            self.lr_dropout = lr_dropout
            self.clipnorm = clipnorm

            with tf.device('/CPU:0') :
                with tf.variable_scope(self.name):
                    self.lr = tf.Variable (lr, name="lr")
                    self.rho = tf.Variable (rho, name="rho")
                    self.epsilon = tf.Variable (epsilon, name="epsilon")
                    self.iterations = tf.Variable(0, dtype=tf.int64, name='iters')

            self.accumulators_dict = {}
            self.lr_rnds_dict = {}

        def get_weights(self):
            return [self.lr, self.rho, self.epsilon, self.iterations] + list(self.accumulators_dict.values())

        def initialize_variables(self, trainable_weights, vars_on_cpu=True):
            # Initialize here all trainable variables used in training
            e = tf.device('/CPU:0') if vars_on_cpu else None
            if e: e.__enter__()
            with tf.variable_scope(self.name):
                accumulators = { v.name : tf.get_variable ( f'acc_{v.name}'.replace(':','_'), v.shape, dtype=v.dtype, initializer=tf.initializers.constant(0.0), trainable=False) for v in trainable_weights }
                self.accumulators_dict.update ( accumulators)

                if self.lr_dropout != 1.0:
                    lr_rnds = [ nn.tf_random_binomial( v.shape, p=self.lr_dropout, dtype=v.dtype) for v in trainable_weights ]
                    self.lr_rnds_dict.update ( { v.name : rnd for v,rnd in zip(trainable_weights,lr_rnds) } )
            if e: e.__exit__(None, None, None)

        def get_update_op(self, grads_vars):
            updates = []

            if self.clipnorm > 0.0:
                norm = tf.sqrt( sum([tf.reduce_sum(tf.square(g)) for g,v in grads_vars]))
            updates += [ state_ops.assign_add( self.iterations, 1) ]
            for i, (g,v) in enumerate(grads_vars):
                if self.clipnorm > 0.0:
                    g = self.tf_clip_norm(g, self.clipnorm, norm)

                a = self.accumulators_dict[ v.name ]

                rho = tf.cast(self.rho, a.dtype)
                new_a = rho * a + (1. - rho) * tf.square(g)

                lr = tf.cast(self.lr, a.dtype)
                epsilon = tf.cast(self.epsilon, a.dtype)

                v_diff = - lr * g / (tf.sqrt(new_a) + epsilon)
                if self.lr_dropout != 1.0:
                    lr_rnd = self.lr_rnds_dict[v.name]
                    v_diff *= lr_rnd
                new_v = v + v_diff

                updates.append (state_ops.assign(a, new_a))
                updates.append (state_ops.assign(v, new_v))

            return control_flow_ops.group ( *updates, name=self.name+'_updates')
    nn.TFRMSpropOptimizer = TFRMSpropOptimizer
core/leras/optimizers/OptimizerBase.py (new file, 42 lines)
@@ -0,0 +1,42 @@
import copy
from core.leras import nn
tf = nn.tf

class OptimizerBase(nn.Saveable):
    def __init__(self, name=None):
        super().__init__(name=name)

    def tf_clip_norm(self, g, c, n):
        """Clip the gradient `g` if the L2 norm `n` exceeds `c`.
        # Arguments
            g: Tensor, the gradient tensor
            c: float >= 0. Gradients will be clipped
                when their L2 norm exceeds this value.
            n: Tensor, actual norm of `g`.
        # Returns
            Tensor, the gradient clipped if required.
        """
        if c <= 0:  # if clipnorm == 0 no need to add ops to the graph
            return g

        condition = n >= c
        then_expression = tf.scalar_mul(c / n, g)
        else_expression = g

        # saving the shape to avoid converting sparse tensor to dense
        if isinstance(then_expression, tf.Tensor):
            g_shape = copy.copy(then_expression.get_shape())
        elif isinstance(then_expression, tf.IndexedSlices):
            g_shape = copy.copy(then_expression.dense_shape)
        if condition.dtype != tf.bool:
            condition = tf.cast(condition, 'bool')
        g = tf.cond(condition,
                    lambda: then_expression,
                    lambda: else_expression)
        if isinstance(then_expression, tf.Tensor):
            g.set_shape(g_shape)
        elif isinstance(then_expression, tf.IndexedSlices):
            g._dense_shape = g_shape

        return g
nn.OptimizerBase = OptimizerBase
core/leras/optimizers/RMSprop.py (new file, 69 lines)
@@ -0,0 +1,69 @@
from tensorflow.python.ops import control_flow_ops, state_ops
from core.leras import nn
tf = nn.tf

class RMSprop(nn.OptimizerBase):
    def __init__(self, lr=0.001, rho=0.9, lr_dropout=1.0, epsilon=1e-7, clipnorm=0.0, name=None):
        super().__init__(name=name)

        if name is None:
            raise ValueError('name must be defined.')

        self.lr_dropout = lr_dropout
        self.clipnorm = clipnorm

        with tf.device('/CPU:0') :
            with tf.variable_scope(self.name):
                self.lr = tf.Variable (lr, name="lr")
                self.rho = tf.Variable (rho, name="rho")
                self.epsilon = tf.Variable (epsilon, name="epsilon")
                self.iterations = tf.Variable(0, dtype=tf.int64, name='iters')

        self.accumulators_dict = {}
        self.lr_rnds_dict = {}

    def get_weights(self):
        return [self.lr, self.rho, self.epsilon, self.iterations] + list(self.accumulators_dict.values())

    def initialize_variables(self, trainable_weights, vars_on_cpu=True):
        # Initialize here all trainable variables used in training
        e = tf.device('/CPU:0') if vars_on_cpu else None
        if e: e.__enter__()
        with tf.variable_scope(self.name):
            accumulators = { v.name : tf.get_variable ( f'acc_{v.name}'.replace(':','_'), v.shape, dtype=v.dtype, initializer=tf.initializers.constant(0.0), trainable=False) for v in trainable_weights }
            self.accumulators_dict.update ( accumulators)

            if self.lr_dropout != 1.0:
                lr_rnds = [ nn.random_binomial( v.shape, p=self.lr_dropout, dtype=v.dtype) for v in trainable_weights ]
                self.lr_rnds_dict.update ( { v.name : rnd for v,rnd in zip(trainable_weights,lr_rnds) } )
        if e: e.__exit__(None, None, None)

    def get_update_op(self, grads_vars):
        updates = []

        if self.clipnorm > 0.0:
            norm = tf.sqrt( sum([tf.reduce_sum(tf.square(g)) for g,v in grads_vars]))
        updates += [ state_ops.assign_add( self.iterations, 1) ]
        for i, (g,v) in enumerate(grads_vars):
            if self.clipnorm > 0.0:
                g = self.tf_clip_norm(g, self.clipnorm, norm)

            a = self.accumulators_dict[ v.name ]

            rho = tf.cast(self.rho, a.dtype)
            new_a = rho * a + (1. - rho) * tf.square(g)

            lr = tf.cast(self.lr, a.dtype)
            epsilon = tf.cast(self.epsilon, a.dtype)

            v_diff = - lr * g / (tf.sqrt(new_a) + epsilon)
            if self.lr_dropout != 1.0:
                lr_rnd = self.lr_rnds_dict[v.name]
                v_diff *= lr_rnd
            new_v = v + v_diff

            updates.append (state_ops.assign(a, new_a))
            updates.append (state_ops.assign(v, new_v))

        return control_flow_ops.group ( *updates, name=self.name+'_updates')
nn.RMSprop = RMSprop
2
core/leras/optimizers/__init__.py
Normal file
2
core/leras/optimizers/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
from .OptimizerBase import *
|
||||
from .RMSprop import *
|
|
@ -1,374 +0,0 @@
|
|||
import numpy as np
|
||||
|
||||
def initialize_tensor_ops(nn):
|
||||
tf = nn.tf
|
||||
from tensorflow.python.ops import array_ops, random_ops, math_ops, sparse_ops, gradients
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
|
||||
def tf_get_value(tensor):
|
||||
return nn.tf_sess.run (tensor)
|
||||
nn.tf_get_value = tf_get_value
|
||||
|
||||
|
||||
def tf_batch_set_value(tuples):
|
||||
if len(tuples) != 0:
|
||||
with nn.tf.device('/CPU:0'):
|
||||
assign_ops = []
|
||||
feed_dict = {}
|
||||
|
||||
for x, value in tuples:
|
||||
if isinstance(value, nn.tf.Operation):
|
||||
assign_ops.append(value)
|
||||
else:
|
||||
value = np.asarray(value, dtype=x.dtype.as_numpy_dtype)
|
||||
assign_placeholder = nn.tf.placeholder( x.dtype.base_dtype, shape=[None]*value.ndim )
|
||||
assign_op = nn.tf.assign (x, assign_placeholder )
|
||||
assign_ops.append(assign_op)
|
||||
feed_dict[assign_placeholder] = value
|
||||
|
||||
nn.tf_sess.run(assign_ops, feed_dict=feed_dict)
|
||||
nn.tf_batch_set_value = tf_batch_set_value
|
||||
|
||||
def tf_init_weights(weights):
|
||||
ops = []
|
||||
|
||||
ca_tuples_w = []
|
||||
ca_tuples = []
|
||||
for w in weights:
|
||||
initializer = w.initializer
|
||||
for input in initializer.inputs:
|
||||
if "_cai_" in input.name:
|
||||
ca_tuples_w.append (w)
|
||||
ca_tuples.append ( (w.shape.as_list(), w.dtype.as_numpy_dtype) )
|
||||
break
|
||||
else:
|
||||
ops.append (initializer)
|
||||
|
||||
if len(ops) != 0:
|
||||
nn.tf_sess.run (ops)
|
||||
|
||||
if len(ca_tuples) != 0:
|
||||
nn.tf_batch_set_value( [*zip(ca_tuples_w, nn.initializers.ca.generate_batch (ca_tuples))] )
|
||||
nn.tf_init_weights = tf_init_weights
|
||||
|
||||
def tf_gradients ( loss, vars ):
|
||||
grads = gradients.gradients(loss, vars, colocate_gradients_with_ops=True )
|
||||
gv = [*zip(grads,vars)]
|
||||
for g,v in gv:
|
||||
if g is None:
|
||||
raise Exception(f"No gradient for variable {v.name}")
|
||||
return gv
|
||||
nn.tf_gradients = tf_gradients
|
||||
|
||||
def tf_average_gv_list(grad_var_list, tf_device_string=None):
|
||||
if len(grad_var_list) == 1:
|
||||
return grad_var_list[0]
|
||||
|
||||
e = tf.device(tf_device_string) if tf_device_string is not None else None
|
||||
if e is not None: e.__enter__()
|
||||
result = []
|
||||
for i, (gv) in enumerate(grad_var_list):
|
||||
for j,(g,v) in enumerate(gv):
|
||||
g = tf.expand_dims(g, 0)
|
||||
if i == 0:
|
||||
result += [ [[g], v] ]
|
||||
else:
|
||||
result[j][0] += [g]
|
||||
|
||||
for i,(gs,v) in enumerate(result):
|
||||
result[i] = ( tf.reduce_mean( tf.concat (gs, 0), 0 ), v )
|
||||
if e is not None: e.__exit__(None,None,None)
|
||||
return result
|
||||
nn.tf_average_gv_list = tf_average_gv_list
|
||||
|
||||
def tf_average_tensor_list(tensors_list, tf_device_string=None):
|
||||
if len(tensors_list) == 1:
|
||||
return tensors_list[0]
|
||||
|
||||
e = tf.device(tf_device_string) if tf_device_string is not None else None
|
||||
if e is not None: e.__enter__()
|
||||
result = tf.reduce_mean(tf.concat ([tf.expand_dims(t, 0) for t in tensors_list], 0), 0)
|
||||
if e is not None: e.__exit__(None,None,None)
|
||||
return result
|
||||
nn.tf_average_tensor_list = tf_average_tensor_list
|
||||
|
||||
def tf_concat (tensors_list, axis):
|
||||
"""
|
||||
Better version.
|
||||
"""
|
||||
if len(tensors_list) == 1:
|
||||
return tensors_list[0]
|
||||
return tf.concat(tensors_list, axis)
|
||||
nn.tf_concat = tf_concat
|
||||
|
||||
def tf_gelu(x):
|
||||
cdf = 0.5 * (1.0 + tf.nn.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
|
||||
return x * cdf
|
||||
nn.tf_gelu = tf_gelu
|
||||
|
||||
def tf_upsample2d(x, size=2):
|
||||
if nn.data_format == "NCHW":
|
||||
b,c,h,w = x.shape.as_list()
|
||||
x = tf.reshape (x, (-1,c,h,1,w,1) )
|
||||
x = tf.tile(x, (1,1,1,size,1,size) )
|
||||
x = tf.reshape (x, (-1,c,h*size,w*size) )
|
||||
return x
|
||||
else:
|
||||
return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )
|
||||
nn.tf_upsample2d = tf_upsample2d
|
||||
|
||||
def tf_resize2d_bilinear(x, size=2):
|
||||
h = x.shape[nn.conv2d_spatial_axes[0]].value
|
||||
w = x.shape[nn.conv2d_spatial_axes[1]].value
|
||||
|
||||
if nn.data_format == "NCHW":
|
||||
x = tf.transpose(x, (0,2,3,1))
|
||||
|
||||
if size > 0:
|
||||
new_size = (h*size,w*size)
|
||||
else:
|
||||
new_size = (h//-size,w//-size)
|
||||
|
||||
x = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.BILINEAR)
|
||||
|
||||
if nn.data_format == "NCHW":
|
||||
x = tf.transpose(x, (0,3,1,2))
|
||||
|
||||
return x
|
||||
nn.tf_resize2d_bilinear = tf_resize2d_bilinear
|
||||
|
||||
|
||||
|
||||
def tf_flatten(x):
|
||||
if nn.data_format == "NHWC":
|
||||
# match NCHW version in order to switch data_format without problems
|
||||
x = tf.transpose(x, (0,3,1,2) )
|
||||
return tf.reshape (x, (-1, np.prod(x.shape[1:])) )
|
||||
|
||||
nn.tf_flatten = tf_flatten
|
||||
|
||||
def tf_max_pool(x, kernel_size, strides):
|
||||
if nn.data_format == "NHWC":
|
||||
return tf.nn.max_pool(x, [1,kernel_size,kernel_size,1], [1,strides,strides,1], "VALID", data_format=nn.data_format)
|
||||
else:
|
||||
return tf.nn.max_pool(x, [1,1,kernel_size,kernel_size], [1,1,strides,strides], "VALID", data_format=nn.data_format)
|
||||
|
||||
nn.tf_max_pool = tf_max_pool
|
||||
|
||||
def tf_reshape_4D(x, w,h,c):
|
||||
if nn.data_format == "NHWC":
|
||||
# match NCHW version in order to switch data_format without problems
|
||||
x = tf.reshape (x, (-1,c,h,w))
|
||||
x = tf.transpose(x, (0,2,3,1) )
|
||||
return x
|
||||
else:
|
||||
return tf.reshape (x, (-1,c,h,w))
|
||||
nn.tf_reshape_4D = tf_reshape_4D
|
||||
|
||||
def tf_random_binomial(shape, p=0.0, dtype=None, seed=None):
|
||||
if dtype is None:
|
||||
dtype=tf.float32
|
||||
|
||||
if seed is None:
|
||||
seed = np.random.randint(10e6)
|
||||
return array_ops.where(
|
||||
random_ops.random_uniform(shape, dtype=tf.float16, seed=seed) < p,
|
||||
array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype))
|
||||
nn.tf_random_binomial = tf_random_binomial
|
||||
|
||||
def tf_gaussian_blur(input, radius=2.0):
|
||||
def gaussian(x, mu, sigma):
|
||||
return np.exp(-(float(x) - float(mu)) ** 2 / (2 * sigma ** 2))
|
||||
|
||||
def make_kernel(sigma):
|
||||
kernel_size = max(3, int(2 * 2 * sigma + 1))
|
||||
mean = np.floor(0.5 * kernel_size)
|
||||
kernel_1d = np.array([gaussian(x, mean, sigma) for x in range(kernel_size)])
|
||||
np_kernel = np.outer(kernel_1d, kernel_1d).astype(np.float32)
|
||||
kernel = np_kernel / np.sum(np_kernel)
|
||||
return kernel, kernel_size
|
||||
|
||||
gauss_kernel, kernel_size = make_kernel(radius)
|
||||
padding = kernel_size//2
|
||||
if padding != 0:
|
||||
if nn.data_format == "NHWC":
|
||||
padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
|
||||
else:
|
||||
padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
|
||||
else:
|
||||
padding = None
|
||||
gauss_kernel = gauss_kernel[:,:,None,None]
|
||||
|
||||
x = input
|
||||
k = tf.tile (gauss_kernel, (1,1,x.shape[nn.conv2d_ch_axis],1) )
|
||||
x = tf.pad(x, padding )
|
||||
x = tf.nn.depthwise_conv2d(x, k, strides=[1,1,1,1], padding='VALID', data_format=nn.data_format)
|
||||
return x
|
||||
nn.tf_gaussian_blur = tf_gaussian_blur
|
||||
|
||||
def tf_style_loss(target, style, gaussian_blur_radius=0.0, loss_weight=1.0, step_size=1):
|
||||
def sd(content, style, loss_weight):
|
||||
content_nc = content.shape[ nn.conv2d_ch_axis ]
|
||||
style_nc = style.shape[nn.conv2d_ch_axis]
|
||||
if content_nc != style_nc:
|
||||
raise Exception("style_loss() content_nc != style_nc")
|
||||
c_mean, c_var = tf.nn.moments(content, axes=nn.conv2d_spatial_axes, keep_dims=True)
|
||||
s_mean, s_var = tf.nn.moments(style, axes=nn.conv2d_spatial_axes, keep_dims=True)
|
||||
c_std, s_std = tf.sqrt(c_var + 1e-5), tf.sqrt(s_var + 1e-5)
|
||||
mean_loss = tf.reduce_sum(tf.square(c_mean-s_mean), axis=[1,2,3])
|
||||
std_loss = tf.reduce_sum(tf.square(c_std-s_std), axis=[1,2,3])
|
||||
return (mean_loss + std_loss) * ( loss_weight / content_nc.value )
|
||||
|
||||
if gaussian_blur_radius > 0.0:
|
||||
target = tf_gaussian_blur(target, gaussian_blur_radius)
|
||||
style = tf_gaussian_blur(style, gaussian_blur_radius)
|
||||
|
||||
return sd( target, style, loss_weight=loss_weight )
|
||||
|
||||
nn.tf_style_loss = tf_style_loss
|
||||
|
||||
def tf_dssim(img1,img2, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
|
||||
if img1.dtype != img2.dtype:
|
||||
raise ValueError("img1.dtype != img2.dtype")
|
||||
|
||||
not_float32 = img1.dtype != tf.float32
|
||||
|
||||
if not_float32:
|
||||
img_dtype = img1.dtype
|
||||
img1 = tf.cast(img1, tf.float32)
|
||||
img2 = tf.cast(img2, tf.float32)
|
||||
|
||||
kernel = np.arange(0, filter_size, dtype=np.float32)
|
||||
kernel -= (filter_size - 1 ) / 2.0
|
||||
kernel = kernel**2
|
||||
kernel *= ( -0.5 / (filter_sigma**2) )
|
||||
kernel = np.reshape (kernel, (1,-1)) + np.reshape(kernel, (-1,1) )
|
||||
kernel = tf.constant ( np.reshape (kernel, (1,-1)), dtype=tf.float32 )
|
||||
kernel = tf.nn.softmax(kernel)
|
||||
kernel = tf.reshape (kernel, (filter_size, filter_size, 1, 1))
|
||||
kernel = tf.tile (kernel, (1,1, img1.shape[ nn.conv2d_ch_axis ] ,1))
|
||||
|
||||
def reducer(x):
|
||||
return tf.nn.depthwise_conv2d(x, kernel, strides=[1,1,1,1], padding='VALID', data_format=nn.data_format)
|
||||
|
||||
c1 = (k1 * max_val) ** 2
|
||||
c2 = (k2 * max_val) ** 2
|
||||
|
||||
mean0 = reducer(img1)
|
||||
mean1 = reducer(img2)
|
||||
num0 = mean0 * mean1 * 2.0
|
||||
den0 = tf.square(mean0) + tf.square(mean1)
|
||||
luminance = (num0 + c1) / (den0 + c1)
|
||||
|
||||
num1 = reducer(img1 * img2) * 2.0
|
||||
den1 = reducer(tf.square(img1) + tf.square(img2))
|
||||
c2 *= 1.0 #compensation factor
|
||||
cs = (num1 - num0 + c2) / (den1 - den0 + c2)
|
||||
|
||||
ssim_val = tf.reduce_mean(luminance * cs, axis=nn.conv2d_spatial_axes )
|
||||
dssim = (1.0 - ssim_val ) / 2.0
|
||||
|
||||
if not_float32:
|
||||
dssim = tf.cast(dssim, img_dtype)
|
||||
return dssim
|
||||
|
||||
nn.tf_dssim = tf_dssim
|
||||
|
||||
def tf_space_to_depth(x, size):
|
||||
if nn.data_format == "NHWC":
|
||||
# match NCHW version in order to switch data_format without problems
|
||||
b,h,w,c = x.shape.as_list()
|
||||
oh, ow = h // size, w // size
|
||||
x = tf.reshape(x, (-1, size, oh, size, ow, c))
|
||||
x = tf.transpose(x, (0, 2, 4, 1, 3, 5))
|
||||
x = tf.reshape(x, (-1, oh, ow, size* size* c ))
|
||||
return x
|
||||
else:
|
||||
return tf.space_to_depth(x, size, data_format=nn.data_format)
|
||||
nn.tf_space_to_depth = tf_space_to_depth
|
||||
|
||||
def tf_depth_to_space(x, size):
|
||||
if nn.data_format == "NHWC":
|
||||
# match NCHW version in order to switch data_format without problems
|
||||
|
||||
b,h,w,c = x.shape.as_list()
|
||||
oh, ow = h * size, w * size
|
||||
oc = c // (size * size)
|
||||
|
||||
x = tf.reshape(x, (-1, h, w, size, size, oc, ) )
|
||||
x = tf.transpose(x, (0, 1, 3, 2, 4, 5))
|
||||
x = tf.reshape(x, (-1, oh, ow, oc, ))
|
||||
return x
|
||||
else:
|
||||
return tf.depth_to_space(x, size, data_format=nn.data_format)
|
||||
nn.tf_depth_to_space = tf_depth_to_space
|
||||
|
||||
def tf_rgb_to_lab(srgb):
|
||||
srgb_pixels = tf.reshape(srgb, [-1, 3])
|
||||
linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
|
||||
exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
|
||||
rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
|
||||
rgb_to_xyz = tf.constant([
|
||||
# X Y Z
|
||||
[0.412453, 0.212671, 0.019334], # R
|
||||
[0.357580, 0.715160, 0.119193], # G
|
||||
[0.180423, 0.072169, 0.950227], # B
|
||||
])
|
||||
xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)
|
||||
|
||||
xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754])
|
||||
|
||||
epsilon = 6/29
|
||||
linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
|
||||
exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
|
||||
fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask
|
||||
|
||||
fxfyfz_to_lab = tf.constant([
|
||||
# l a b
|
||||
[ 0.0, 500.0, 0.0], # fx
|
||||
[116.0, -500.0, 200.0], # fy
|
||||
[ 0.0, 0.0, -200.0], # fz
|
||||
])
|
||||
lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])
|
||||
return tf.reshape(lab_pixels, tf.shape(srgb))
|
||||
nn.tf_rgb_to_lab = tf_rgb_to_lab
|
||||
|
||||
def tf_suppress_lower_mean(t, eps=0.00001):
|
||||
if t.shape.ndims != 1:
|
||||
raise ValueError("tf_suppress_lower_mean: t rank must be 1")
|
||||
t_mean_eps = tf.reduce_mean(t) - eps
|
||||
q = tf.clip_by_value(t, t_mean_eps, tf.reduce_max(t) )
|
||||
q = tf.clip_by_value(q-t_mean_eps, 0, eps)
|
||||
q = q * (t/eps)
|
||||
return q
|
||||
"""
|
||||
class GeLU(KL.Layer):
|
||||
Gaussian Error Linear Unit.
|
||||
A smoother version of ReLU generally used
|
||||
in the BERT or BERT architecture based models.
|
||||
Original paper: https://arxiv.org/abs/1606.08415
|
||||
Input shape:
|
||||
Arbitrary. Use the keyword argument `input_shape`
|
||||
(tuple of integers, does not include the samples axis)
|
||||
when using this layer as the first layer in a model.
|
||||
Output shape:
|
||||
Same shape as the input.
|
||||
|
||||
def __init__(self, approximate=True, **kwargs):
|
||||
super(GeLU, self).__init__(**kwargs)
|
||||
self.approximate = approximate
|
||||
self.supports_masking = True
|
||||
|
||||
def call(self, inputs):
|
||||
cdf = 0.5 * (1.0 + K.tanh((np.sqrt(2 / np.pi) * (inputs + 0.044715 * K.pow(inputs, 3)))))
|
||||
return inputs * cdf
|
||||
|
||||
def get_config(self):
|
||||
config = {'approximate': self.approximate}
|
||||
base_config = super(GeLU, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return input_shape
|
||||
nn.GeLU = GeLU
|
||||
"""
|
|
@ -33,4 +33,5 @@ def get_screen_size():
|
|||
elif 'linux' in sys.platform:
|
||||
pass
|
||||
|
||||
return (1366, 768)
|
||||
return (1366, 768)
|
||||
|
95
facelib/DFLSegNet.py
Normal file
95
facelib/DFLSegNet.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
import os
|
||||
import pickle
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from core.interact import interact as io
|
||||
from core.leras import nn
|
||||
|
||||
|
||||
class DFLSegNet(object):
|
||||
VERSION = 1
|
||||
|
||||
def __init__ (self, name,
|
||||
resolution,
|
||||
load_weights=True,
|
||||
weights_file_root=None,
|
||||
training=False,
|
||||
place_model_on_cpu=False,
|
||||
run_on_cpu=False,
|
||||
optimizer=None,
|
||||
data_format="NHWC"):
|
||||
|
||||
nn.initialize(data_format=data_format)
|
||||
tf = nn.tf
|
||||
|
||||
self.weights_file_root = Path(weights_file_root) if weights_file_root is not None else Path(__file__).parent
|
||||
|
||||
with tf.device ('/CPU:0'):
|
||||
#Place holders on CPU
|
||||
self.input_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,3) )
|
||||
self.target_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,1) )
|
||||
|
||||
# Initializing model classes
|
||||
archi = nn.DFLSegnetArchi()
|
||||
with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'):
|
||||
self.enc = archi.Encoder(3, 64, name='Encoder')
|
||||
self.dec = archi.Decoder(64, 1, name='Decoder')
|
||||
self.enc_dec_weights = self.enc.get_weights()+self.dec.get_weights()
|
||||
|
||||
model_name = f'{name}_{resolution}'
|
||||
|
||||
self.model_filename_list = [ [self.enc, f'{model_name}_enc.npy'],
|
||||
[self.dec, f'{model_name}_dec.npy'],
|
||||
]
|
||||
|
||||
if training:
|
||||
if optimizer is None:
|
||||
raise ValueError("Optimizer should be provided for training mode.")
|
||||
|
||||
self.opt = optimizer
|
||||
self.opt.initialize_variables (self.enc_dec_weights, vars_on_cpu=place_model_on_cpu)
|
||||
self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]
|
||||
else:
|
||||
with tf.device ('/CPU:0' if run_on_cpu else '/GPU:0'):
|
||||
_, pred = self.dec(self.enc(self.input_t))
|
||||
|
||||
def net_run(input_np):
|
||||
return nn.tf_sess.run ( [pred], feed_dict={self.input_t :input_np})[0]
|
||||
self.net_run = net_run
|
||||
|
||||
# Loading/initializing all models/optimizers weights
|
||||
for model, filename in self.model_filename_list:
|
||||
do_init = not load_weights
|
||||
|
||||
if not do_init:
|
||||
do_init = not model.load_weights( self.weights_file_root / filename )
|
||||
|
||||
if do_init:
|
||||
model.init_weights()
|
||||
|
||||
def flow(self, x):
|
||||
return self.dec(self.enc(x))
|
||||
|
||||
def get_weights(self):
|
||||
return self.enc_dec_weights
|
||||
|
||||
def save_weights(self):
|
||||
for model, filename in io.progress_bar_generator(self.model_filename_list, "Saving", leave=False):
|
||||
model.save_weights( self.weights_file_root / filename )
|
||||
|
||||
def extract (self, input_image):
|
||||
input_shape_len = len(input_image.shape)
|
||||
if input_shape_len == 3:
|
||||
input_image = input_image[None,...]
|
||||
|
||||
result = np.clip ( self.net_run(input_image), 0, 1.0 )
|
||||
result[result < 0.1] = 0 #get rid of noise
|
||||
|
||||
if input_shape_len == 3:
|
||||
result = result[0]
|
||||
|
||||
return result
|
|
@ -89,7 +89,7 @@ class FANExtractor(object):
|
|||
low2 = self.b2_plus(low1)
|
||||
low3 = self.b3(low2)
|
||||
|
||||
up2 = nn.tf_upsample2d(low3)
|
||||
up2 = nn.upsample2d(low3)
|
||||
|
||||
return up1+up2
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ class FaceEnhancer(object):
|
|||
"""
|
||||
x4 face enhancer
|
||||
"""
|
||||
def __init__(self, place_model_on_cpu=False):
|
||||
def __init__(self, place_model_on_cpu=False, run_on_cpu=False):
|
||||
nn.initialize(data_format="NHWC")
|
||||
tf = nn.tf
|
||||
|
||||
|
@ -111,23 +111,23 @@ class FaceEnhancer(object):
|
|||
x = tf.nn.leaky_relu(self.center_conv2(x), 0.1)
|
||||
x = tf.nn.leaky_relu(self.center_conv3(x), 0.1)
|
||||
|
||||
x = tf.concat( [nn.tf_resize2d_bilinear(x), e4], -1 )
|
||||
x = tf.concat( [nn.resize2d_bilinear(x), e4], -1 )
|
||||
x = tf.nn.leaky_relu(self.d4_conv0(x), 0.1)
|
||||
x = tf.nn.leaky_relu(self.d4_conv1(x), 0.1)
|
||||
|
||||
x = tf.concat( [nn.tf_resize2d_bilinear(x), e3], -1 )
|
||||
x = tf.concat( [nn.resize2d_bilinear(x), e3], -1 )
|
||||
x = tf.nn.leaky_relu(self.d3_conv0(x), 0.1)
|
||||
x = tf.nn.leaky_relu(self.d3_conv1(x), 0.1)
|
||||
|
||||
x = tf.concat( [nn.tf_resize2d_bilinear(x), e2], -1 )
|
||||
x = tf.concat( [nn.resize2d_bilinear(x), e2], -1 )
|
||||
x = tf.nn.leaky_relu(self.d2_conv0(x), 0.1)
|
||||
x = tf.nn.leaky_relu(self.d2_conv1(x), 0.1)
|
||||
|
||||
x = tf.concat( [nn.tf_resize2d_bilinear(x), e1], -1 )
|
||||
x = tf.concat( [nn.resize2d_bilinear(x), e1], -1 )
|
||||
x = tf.nn.leaky_relu(self.d1_conv0(x), 0.1)
|
||||
x = tf.nn.leaky_relu(self.d1_conv1(x), 0.1)
|
||||
|
||||
x = tf.concat( [nn.tf_resize2d_bilinear(x), e0], -1 )
|
||||
x = tf.concat( [nn.resize2d_bilinear(x), e0], -1 )
|
||||
x = tf.nn.leaky_relu(self.d0_conv0(x), 0.1)
|
||||
x = d0 = tf.nn.leaky_relu(self.d0_conv1(x), 0.1)
|
||||
|
||||
|
@ -138,22 +138,22 @@ class FaceEnhancer(object):
|
|||
x = d0
|
||||
x = tf.nn.leaky_relu(self.dec2x_conv0(x), 0.1)
|
||||
x = tf.nn.leaky_relu(self.dec2x_conv1(x), 0.1)
|
||||
x = d2x = nn.tf_resize2d_bilinear(x)
|
||||
x = d2x = nn.resize2d_bilinear(x)
|
||||
|
||||
x = tf.nn.leaky_relu(self.out2x_conv0(x), 0.1)
|
||||
x = self.out2x_conv1(x)
|
||||
|
||||
out2x = nn.tf_resize2d_bilinear(out1x) + tf.nn.tanh(x)
|
||||
out2x = nn.resize2d_bilinear(out1x) + tf.nn.tanh(x)
|
||||
|
||||
x = d2x
|
||||
x = tf.nn.leaky_relu(self.dec4x_conv0(x), 0.1)
|
||||
x = tf.nn.leaky_relu(self.dec4x_conv1(x), 0.1)
|
||||
x = d4x = nn.tf_resize2d_bilinear(x)
|
||||
x = d4x = nn.resize2d_bilinear(x)
|
||||
|
||||
x = tf.nn.leaky_relu(self.out4x_conv0(x), 0.1)
|
||||
x = self.out4x_conv1(x)
|
||||
|
||||
out4x = nn.tf_resize2d_bilinear(out2x) + tf.nn.tanh(x)
|
||||
out4x = nn.resize2d_bilinear(out2x) + tf.nn.tanh(x)
|
||||
|
||||
return out4x
|
||||
|
||||
|
@ -161,17 +161,15 @@ class FaceEnhancer(object):
|
|||
if not model_path.exists():
|
||||
raise Exception("Unable to load FaceEnhancer.npy")
|
||||
|
||||
e = tf.device("/CPU:0") if place_model_on_cpu else None
|
||||
if e is not None: e.__enter__()
|
||||
self.model = FaceEnhancer()
|
||||
self.model.load_weights (model_path)
|
||||
if e is not None: e.__exit__(None,None,None)
|
||||
|
||||
self.model.build_for_run ([ (tf.float32, nn.get4Dshape (192,192,3) ),
|
||||
(tf.float32, (None,1,) ),
|
||||
(tf.float32, (None,1,) ),
|
||||
])
|
||||
|
||||
with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'):
|
||||
self.model = FaceEnhancer()
|
||||
self.model.load_weights (model_path)
|
||||
|
||||
with tf.device ('/CPU:0' if run_on_cpu else '/GPU:0'):
|
||||
self.model.build_for_run ([ (tf.float32, nn.get4Dshape (192,192,3) ),
|
||||
(tf.float32, (None,1,) ),
|
||||
(tf.float32, (None,1,) ),
|
||||
])
|
||||
|
||||
def enhance (self, inp_img, is_tanh=False, preserve_size=True):
|
||||
if not is_tanh:
|
||||
|
|
|
@ -21,7 +21,7 @@ class S3FDExtractor(object):
|
|||
super().__init__(**kwargs)
|
||||
|
||||
def build_weights(self):
|
||||
self.weight = tf.get_variable ("weight", (1, 1, 1, self.n_channels), dtype=nn.tf_floatx, initializer=tf.initializers.ones )
|
||||
self.weight = tf.get_variable ("weight", (1, 1, 1, self.n_channels), dtype=nn.floatx, initializer=tf.initializers.ones )
|
||||
|
||||
def get_weights(self):
|
||||
return [self.weight]
|
||||
|
@ -36,7 +36,7 @@ class S3FDExtractor(object):
|
|||
super().__init__(name='S3FD')
|
||||
|
||||
def on_build(self):
|
||||
self.minus = tf.constant([104,117,123], dtype=nn.tf_floatx )
|
||||
self.minus = tf.constant([104,117,123], dtype=nn.floatx )
|
||||
self.conv1_1 = nn.Conv2D(3, 64, kernel_size=3, strides=1, padding='SAME')
|
||||
self.conv1_2 = nn.Conv2D(64, 64, kernel_size=3, strides=1, padding='SAME')
|
||||
|
||||
|
|
|
@ -9,106 +9,13 @@ import numpy as np
|
|||
from core.interact import interact as io
|
||||
from core.leras import nn
|
||||
|
||||
"""
|
||||
Dataset used to train located in official DFL mega.nz folder
|
||||
https://mega.nz/#F!b9MzCK4B!zEAG9txu7uaRUjXz9PtBqg
|
||||
|
||||
using https://github.com/ternaus/TernausNet
|
||||
TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation
|
||||
"""
|
||||
|
||||
class TernausNet(object):
|
||||
VERSION = 1
|
||||
def __init__ (self, name, resolution, face_type_str=None, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False, data_format="NHWC"):
|
||||
|
||||
def __init__ (self, name, resolution, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False, run_on_cpu=False, optimizer=None, data_format="NHWC"):
|
||||
nn.initialize(data_format=data_format)
|
||||
tf = nn.tf
|
||||
|
||||
class Ternaus(nn.ModelBase):
|
||||
def on_build(self, in_ch, base_ch):
|
||||
|
||||
self.features_0 = nn.Conv2D (in_ch, base_ch, kernel_size=3, padding='SAME')
|
||||
self.blurpool_0 = nn.BlurPool (filt_size=3)
|
||||
|
||||
self.features_3 = nn.Conv2D (base_ch, base_ch*2, kernel_size=3, padding='SAME')
|
||||
self.blurpool_3 = nn.BlurPool (filt_size=3)
|
||||
|
||||
self.features_6 = nn.Conv2D (base_ch*2, base_ch*4, kernel_size=3, padding='SAME')
|
||||
self.features_8 = nn.Conv2D (base_ch*4, base_ch*4, kernel_size=3, padding='SAME')
|
||||
self.blurpool_8 = nn.BlurPool (filt_size=3)
|
||||
|
||||
self.features_11 = nn.Conv2D (base_ch*4, base_ch*8, kernel_size=3, padding='SAME')
|
||||
self.features_13 = nn.Conv2D (base_ch*8, base_ch*8, kernel_size=3, padding='SAME')
|
||||
self.blurpool_13 = nn.BlurPool (filt_size=3)
|
||||
|
||||
self.features_16 = nn.Conv2D (base_ch*8, base_ch*8, kernel_size=3, padding='SAME')
|
||||
self.features_18 = nn.Conv2D (base_ch*8, base_ch*8, kernel_size=3, padding='SAME')
|
||||
self.blurpool_18 = nn.BlurPool (filt_size=3)
|
||||
|
||||
self.conv_center = nn.Conv2D (base_ch*8, base_ch*8, kernel_size=3, padding='SAME')
|
||||
|
||||
self.conv1_up = nn.Conv2DTranspose (base_ch*8, base_ch*4, kernel_size=3, padding='SAME')
|
||||
self.conv1 = nn.Conv2D (base_ch*12, base_ch*8, kernel_size=3, padding='SAME')
|
||||
|
||||
self.conv2_up = nn.Conv2DTranspose (base_ch*8, base_ch*4, kernel_size=3, padding='SAME')
|
||||
self.conv2 = nn.Conv2D (base_ch*12, base_ch*8, kernel_size=3, padding='SAME')
|
||||
|
||||
self.conv3_up = nn.Conv2DTranspose (base_ch*8, base_ch*2, kernel_size=3, padding='SAME')
|
||||
self.conv3 = nn.Conv2D (base_ch*6, base_ch*4, kernel_size=3, padding='SAME')
|
||||
|
||||
self.conv4_up = nn.Conv2DTranspose (base_ch*4, base_ch, kernel_size=3, padding='SAME')
|
||||
self.conv4 = nn.Conv2D (base_ch*3, base_ch*2, kernel_size=3, padding='SAME')
|
||||
|
||||
self.conv5_up = nn.Conv2DTranspose (base_ch*2, base_ch//2, kernel_size=3, padding='SAME')
|
||||
self.conv5 = nn.Conv2D (base_ch//2+base_ch, base_ch, kernel_size=3, padding='SAME')
|
||||
|
||||
self.out_conv = nn.Conv2D (base_ch, 1, kernel_size=3, padding='SAME')
|
||||
|
||||
def forward(self, inp):
|
||||
x, = inp
|
||||
|
||||
x = x0 = tf.nn.relu(self.features_0(x))
|
||||
x = self.blurpool_0(x)
|
||||
|
||||
x = x1 = tf.nn.relu(self.features_3(x))
|
||||
x = self.blurpool_3(x)
|
||||
|
||||
x = tf.nn.relu(self.features_6(x))
|
||||
x = x2 = tf.nn.relu(self.features_8(x))
|
||||
x = self.blurpool_8(x)
|
||||
|
||||
x = tf.nn.relu(self.features_11(x))
|
||||
x = x3 = tf.nn.relu(self.features_13(x))
|
||||
x = self.blurpool_13(x)
|
||||
|
||||
x = tf.nn.relu(self.features_16(x))
|
||||
x = x4 = tf.nn.relu(self.features_18(x))
|
||||
x = self.blurpool_18(x)
|
||||
|
||||
x = self.conv_center(x)
|
||||
|
||||
x = tf.nn.relu(self.conv1_up(x))
|
||||
x = tf.concat( [x,x4], nn.conv2d_ch_axis)
|
||||
x = tf.nn.relu(self.conv1(x))
|
||||
|
||||
x = tf.nn.relu(self.conv2_up(x))
|
||||
x = tf.concat( [x,x3], nn.conv2d_ch_axis)
|
||||
x = tf.nn.relu(self.conv2(x))
|
||||
|
||||
x = tf.nn.relu(self.conv3_up(x))
|
||||
x = tf.concat( [x,x2], nn.conv2d_ch_axis)
|
||||
x = tf.nn.relu(self.conv3(x))
|
||||
|
||||
x = tf.nn.relu(self.conv4_up(x))
|
||||
x = tf.concat( [x,x1], nn.conv2d_ch_axis)
|
||||
x = tf.nn.relu(self.conv4(x))
|
||||
|
||||
x = tf.nn.relu(self.conv5_up(x))
|
||||
x = tf.concat( [x,x0], nn.conv2d_ch_axis)
|
||||
x = tf.nn.relu(self.conv5(x))
|
||||
|
||||
logits = self.out_conv(x)
|
||||
return logits, tf.nn.sigmoid(logits)
|
||||
|
||||
if weights_file_root is not None:
|
||||
weights_file_root = Path(weights_file_root)
|
||||
else:
|
||||
|
@ -117,39 +24,42 @@ class TernausNet(object):
|
|||
|
||||
with tf.device ('/CPU:0'):
|
||||
#Place holders on CPU
|
||||
self.input_t = tf.placeholder (nn.tf_floatx, nn.get4Dshape(resolution,resolution,3) )
|
||||
self.target_t = tf.placeholder (nn.tf_floatx, nn.get4Dshape(resolution,resolution,1) )
|
||||
self.input_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,3) )
|
||||
self.target_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,1) )
|
||||
|
||||
# Initializing model classes
|
||||
with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'):
|
||||
self.net = Ternaus(3, 64, name='Ternaus')
|
||||
self.net = nn.Ternaus(3, 64, name='Ternaus')
|
||||
self.net_weights = self.net.get_weights()
|
||||
|
||||
model_name = f'{name}_{resolution}'
|
||||
if face_type_str is not None:
|
||||
model_name += f'_{face_type_str}'
|
||||
model_name = f'{name}_{resolution}'
|
||||
|
||||
self.model_filename_list = [ [self.net, f'{model_name}.npy'] ]
|
||||
|
||||
if training:
|
||||
self.opt = nn.TFRMSpropOptimizer(lr=0.0001, name='opt')
|
||||
if optimizer is None:
|
||||
raise ValueError("Optimizer should be provided for traning mode.")
|
||||
|
||||
self.opt = optimizer
|
||||
self.opt.initialize_variables (self.net_weights, vars_on_cpu=place_model_on_cpu)
|
||||
self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]
|
||||
else:
|
||||
_, pred = self.net([self.input_t])
|
||||
with tf.device ('/CPU:0' if run_on_cpu else '/GPU:0'):
|
||||
_, pred = self.net([self.input_t])
|
||||
|
||||
def net_run(input_np):
|
||||
return nn.tf_sess.run ( [pred], feed_dict={self.input_t :input_np})[0]
|
||||
self.net_run = net_run
|
||||
|
||||
# Loading/initializing all models/optimizers weights
|
||||
for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
|
||||
for model, filename in self.model_filename_list:
|
||||
do_init = not load_weights
|
||||
|
||||
if not do_init:
|
||||
do_init = not model.load_weights( self.weights_file_root / filename )
|
||||
|
||||
if do_init:
|
||||
model.init_weights()
|
||||
model.init_weights()
|
||||
if model == self.net:
|
||||
try:
|
||||
with open( Path(__file__).parent / 'vgg11_enc_weights.npy', 'rb' ) as f:
|
||||
|
@ -177,7 +87,7 @@ class TernausNet(object):
|
|||
|
||||
return result
|
||||
|
||||
"""
|
||||
"""
|
||||
if load_weights:
|
||||
self.net.load_weights (self.weights_path)
|
||||
else:
|
||||
|
|
|
@ -2,4 +2,5 @@ from .FaceType import FaceType
|
|||
from .S3FDExtractor import S3FDExtractor
|
||||
from .FANExtractor import FANExtractor
|
||||
from .FaceEnhancer import FaceEnhancer
|
||||
from .TernausNet import TernausNet
|
||||
from .TernausNet import TernausNet
|
||||
from .DFLSegNet import DFLSegNet
|
29
main.py
29
main.py
|
@ -5,7 +5,6 @@ if __name__ == "__main__":
|
|||
|
||||
from core.leras import nn
|
||||
nn.initialize_main_env()
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
@ -116,6 +115,26 @@ if __name__ == "__main__":
|
|||
|
||||
p.set_defaults (func=process_dev_segmented_extract)
|
||||
|
||||
def process_dev_segmented_trash(arguments):
|
||||
osex.set_process_lowest_prio()
|
||||
from mainscripts import dev_misc
|
||||
dev_misc.dev_segmented_trash(arguments.input_dir)
|
||||
|
||||
p = subparsers.add_parser( "dev_segmented_trash", help="")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
|
||||
|
||||
p.set_defaults (func=process_dev_segmented_trash)
|
||||
|
||||
def process_dev_resave_pngs(arguments):
|
||||
osex.set_process_lowest_prio()
|
||||
from mainscripts import dev_misc
|
||||
dev_misc.dev_resave_pngs(arguments.input_dir)
|
||||
|
||||
p = subparsers.add_parser( "dev_resave_pngs", help="")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
|
||||
|
||||
p.set_defaults (func=process_dev_resave_pngs)
|
||||
|
||||
def process_sort(arguments):
|
||||
osex.set_process_lowest_prio()
|
||||
from mainscripts import Sorter
|
||||
|
@ -130,9 +149,6 @@ if __name__ == "__main__":
|
|||
osex.set_process_lowest_prio()
|
||||
from mainscripts import Util
|
||||
|
||||
if arguments.convert_png_to_jpg:
|
||||
Util.convert_png_to_jpg_folder (input_path=arguments.input_dir)
|
||||
|
||||
if arguments.add_landmarks_debug_images:
|
||||
Util.add_landmarks_debug_images (input_path=arguments.input_dir)
|
||||
|
||||
|
@ -163,7 +179,6 @@ if __name__ == "__main__":
|
|||
|
||||
p = subparsers.add_parser( "util", help="Utilities.")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
|
||||
p.add_argument('--convert-png-to-jpg', action="store_true", dest="convert_png_to_jpg", default=False, help="Convert DeepFaceLAB PNG files to JPEG.")
|
||||
p.add_argument('--add-landmarks-debug-images', action="store_true", dest="add_landmarks_debug_images", default=False, help="Add landmarks debug image for aligned faces.")
|
||||
p.add_argument('--recover-original-aligned-filename', action="store_true", dest="recover_original_aligned_filename", default=False, help="Recover original aligned filename.")
|
||||
#p.add_argument('--remove-fanseg', action="store_true", dest="remove_fanseg", default=False, help="Remove fanseg mask from aligned faces.")
|
||||
|
@ -215,7 +230,6 @@ if __name__ == "__main__":
|
|||
from mainscripts import Merger
|
||||
Merger.main ( model_class_name = arguments.model_name,
|
||||
saved_models_path = Path(arguments.model_dir),
|
||||
training_data_src_path = Path(arguments.training_data_src_dir) if arguments.training_data_src_dir is not None else None,
|
||||
force_model_name = arguments.force_model_name,
|
||||
input_path = Path(arguments.input_dir),
|
||||
output_path = Path(arguments.output_dir),
|
||||
|
@ -225,7 +239,6 @@ if __name__ == "__main__":
|
|||
cpu_only = arguments.cpu_only)
|
||||
|
||||
p = subparsers.add_parser( "merge", help="Merger")
|
||||
p.add_argument('--training-data-src-dir', action=fixPathAction, dest="training_data_src_dir", default=None, help="(optional, may be required by some models) Dir of extracted SRC faceset.")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
|
||||
p.add_argument('--output-dir', required=True, action=fixPathAction, dest="output_dir", help="Output directory. This is where the merged files will be stored.")
|
||||
p.add_argument('--output-mask-dir', required=True, action=fixPathAction, dest="output_mask_dir", help="Output mask directory. This is where the mask files will be stored.")
|
||||
|
@ -360,7 +373,7 @@ if __name__ == "__main__":
|
|||
arguments.func(arguments)
|
||||
|
||||
print ("Done.")
|
||||
|
||||
|
||||
'''
|
||||
import code
|
||||
code.interact(local=dict(globals(), **locals()))
|
||||
|
|
|
@ -93,7 +93,7 @@ class FacesetEnhancerSubprocessor(Subprocessor):
|
|||
self.log_info (intro_str)
|
||||
|
||||
from facelib import FaceEnhancer
|
||||
self.fe = FaceEnhancer( place_model_on_cpu=(device_vram<=2) )
|
||||
self.fe = FaceEnhancer( place_model_on_cpu=(device_vram<=2 or cpu_only), run_on_cpu=cpu_only )
|
||||
|
||||
#override
|
||||
def process_data(self, filepath):
|
||||
|
|
|
@ -1,635 +1,19 @@
|
|||
import math
|
||||
import multiprocessing
|
||||
import operator
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import numpy.linalg as npla
|
||||
|
||||
from core import imagelib
|
||||
import samplelib
|
||||
from merger import (MergerConfig, MergeFaceAvatar, MergeMasked,
|
||||
FrameInfo)
|
||||
from DFLIMG import DFLIMG
|
||||
from facelib import FaceEnhancer, FaceType, LandmarksProcessor, TernausNet
|
||||
from core.interact import interact as io
|
||||
from core.joblib import SubprocessFunctionCaller, Subprocessor
|
||||
from core.leras import nn
|
||||
from core import pathex
|
||||
from core.cv2ex import *
|
||||
|
||||
from .MergerScreen import Screen, ScreenManager
|
||||
|
||||
MERGER_DEBUG = False
|
||||
|
||||
class MergeSubprocessor(Subprocessor):
|
||||
|
||||
class Frame(object):
|
||||
def __init__(self, prev_temporal_frame_infos=None,
|
||||
frame_info=None,
|
||||
next_temporal_frame_infos=None):
|
||||
self.prev_temporal_frame_infos = prev_temporal_frame_infos
|
||||
self.frame_info = frame_info
|
||||
self.next_temporal_frame_infos = next_temporal_frame_infos
|
||||
self.output_filepath = None
|
||||
self.output_mask_filepath = None
|
||||
|
||||
self.idx = None
|
||||
self.cfg = None
|
||||
self.is_done = False
|
||||
self.is_processing = False
|
||||
self.is_shown = False
|
||||
self.image = None
|
||||
|
||||
class ProcessingFrame(object):
|
||||
def __init__(self, idx=None,
|
||||
cfg=None,
|
||||
prev_temporal_frame_infos=None,
|
||||
frame_info=None,
|
||||
next_temporal_frame_infos=None,
|
||||
output_filepath=None,
|
||||
output_mask_filepath=None,
|
||||
need_return_image = False):
|
||||
self.idx = idx
|
||||
self.cfg = cfg
|
||||
self.prev_temporal_frame_infos = prev_temporal_frame_infos
|
||||
self.frame_info = frame_info
|
||||
self.next_temporal_frame_infos = next_temporal_frame_infos
|
||||
self.output_filepath = output_filepath
|
||||
self.output_mask_filepath = output_mask_filepath
|
||||
|
||||
self.need_return_image = need_return_image
|
||||
if self.need_return_image:
|
||||
self.image = None
|
||||
|
||||
class Cli(Subprocessor.Cli):
|
||||
|
||||
#override
|
||||
def on_initialize(self, client_dict):
|
||||
self.log_info ('Running on %s.' % (client_dict['device_name']) )
|
||||
self.device_idx = client_dict['device_idx']
|
||||
self.device_name = client_dict['device_name']
|
||||
self.predictor_func = client_dict['predictor_func']
|
||||
self.predictor_input_shape = client_dict['predictor_input_shape']
|
||||
self.superres_func = client_dict['superres_func']
|
||||
self.fanseg_input_size = client_dict['fanseg_input_size']
|
||||
self.fanseg_extract_func = client_dict['fanseg_extract_func']
|
||||
|
||||
#transfer and set stdin in order to work code.interact in debug subprocess
|
||||
stdin_fd = client_dict['stdin_fd']
|
||||
if stdin_fd is not None:
|
||||
sys.stdin = os.fdopen(stdin_fd)
|
||||
|
||||
def blursharpen_func (img, sharpen_mode=0, kernel_size=3, amount=100):
|
||||
if kernel_size % 2 == 0:
|
||||
kernel_size += 1
|
||||
if amount > 0:
|
||||
if sharpen_mode == 1: #box
|
||||
kernel = np.zeros( (kernel_size, kernel_size), dtype=np.float32)
|
||||
kernel[ kernel_size//2, kernel_size//2] = 1.0
|
||||
box_filter = np.ones( (kernel_size, kernel_size), dtype=np.float32) / (kernel_size**2)
|
||||
kernel = kernel + (kernel - box_filter) * amount
|
||||
return cv2.filter2D(img, -1, kernel)
|
||||
elif sharpen_mode == 2: #gaussian
|
||||
blur = cv2.GaussianBlur(img, (kernel_size, kernel_size) , 0)
|
||||
img = cv2.addWeighted(img, 1.0 + (0.5 * amount), blur, -(0.5 * amount), 0)
|
||||
return img
|
||||
elif amount < 0:
|
||||
n = -amount
|
||||
while n > 0:
|
||||
|
||||
img_blur = cv2.medianBlur(img, 5)
|
||||
if int(n / 10) != 0:
|
||||
img = img_blur
|
||||
else:
|
||||
pass_power = (n % 10) / 10.0
|
||||
img = img*(1.0-pass_power)+img_blur*pass_power
|
||||
n = max(n-10,0)
|
||||
|
||||
return img
|
||||
return img
|
||||
self.blursharpen_func = blursharpen_func
|
||||
|
||||
return None
|
||||
|
||||
#override
|
||||
def process_data(self, pf): #pf=ProcessingFrame
|
||||
cfg = pf.cfg.copy()
|
||||
cfg.blursharpen_func = self.blursharpen_func
|
||||
cfg.superres_func = self.superres_func
|
||||
|
||||
frame_info = pf.frame_info
|
||||
filepath = frame_info.filepath
|
||||
|
||||
if len(frame_info.landmarks_list) == 0:
|
||||
self.log_info (f'no faces found for {filepath.name}, copying without faces')
|
||||
|
||||
img_bgr = cv2_imread(filepath)
|
||||
imagelib.normalize_channels(img_bgr, 3)
|
||||
cv2_imwrite (pf.output_filepath, img_bgr)
|
||||
h,w,c = img_bgr.shape
|
||||
|
||||
img_mask = np.zeros( (h,w,1), dtype=img_bgr.dtype)
|
||||
cv2_imwrite (pf.output_mask_filepath, img_mask)
|
||||
|
||||
if pf.need_return_image:
|
||||
pf.image = np.concatenate ([img_bgr, img_mask], axis=-1)
|
||||
|
||||
else:
|
||||
if cfg.type == MergerConfig.TYPE_MASKED:
|
||||
cfg.fanseg_input_size = self.fanseg_input_size
|
||||
cfg.fanseg_extract_func = self.fanseg_extract_func
|
||||
|
||||
try:
|
||||
final_img = MergeMasked (self.predictor_func, self.predictor_input_shape, cfg, frame_info)
|
||||
except Exception as e:
|
||||
e_str = traceback.format_exc()
|
||||
if 'MemoryError' in e_str:
|
||||
raise Subprocessor.SilenceException
|
||||
else:
|
||||
raise Exception( f'Error while merging file [{filepath}]: {e_str}' )
|
||||
|
||||
elif cfg.type == MergerConfig.TYPE_FACE_AVATAR:
|
||||
final_img = MergeFaceAvatar (self.predictor_func, self.predictor_input_shape,
|
||||
cfg, pf.prev_temporal_frame_infos,
|
||||
pf.frame_info,
|
||||
pf.next_temporal_frame_infos )
|
||||
|
||||
cv2_imwrite (pf.output_filepath, final_img[...,0:3] )
|
||||
cv2_imwrite (pf.output_mask_filepath, final_img[...,3:4] )
|
||||
|
||||
if pf.need_return_image:
|
||||
pf.image = final_img
|
||||
|
||||
return pf
|
||||
|
||||
#overridable
|
||||
def get_data_name (self, pf):
|
||||
#return string identificator of your data
|
||||
return pf.frame_info.filepath
|
||||
|
||||
#override
|
||||
def __init__(self, is_interactive, merger_session_filepath, predictor_func, predictor_input_shape, merger_config, frames, frames_root_path, output_path, output_mask_path, model_iter):
|
||||
if len (frames) == 0:
|
||||
raise ValueError ("len (frames) == 0")
|
||||
|
||||
super().__init__('Merger', MergeSubprocessor.Cli, io_loop_sleep_time=0.001)
|
||||
|
||||
self.is_interactive = is_interactive
|
||||
self.merger_session_filepath = Path(merger_session_filepath)
|
||||
self.merger_config = merger_config
|
||||
|
||||
self.predictor_func_host, self.predictor_func = SubprocessFunctionCaller.make_pair(predictor_func)
|
||||
self.predictor_input_shape = predictor_input_shape
|
||||
|
||||
self.face_enhancer = None
|
||||
def superres_func(face_bgr):
|
||||
if self.face_enhancer is None:
|
||||
self.face_enhancer = FaceEnhancer(place_model_on_cpu=True)
|
||||
|
||||
return self.face_enhancer.enhance (face_bgr, is_tanh=True, preserve_size=False)
|
||||
|
||||
self.superres_host, self.superres_func = SubprocessFunctionCaller.make_pair(superres_func)
|
||||
|
||||
self.fanseg_by_face_type = {}
|
||||
self.fanseg_input_size = 256
|
||||
def fanseg_extract_func(face_type, *args, **kwargs):
|
||||
fanseg = self.fanseg_by_face_type.get(face_type, None)
|
||||
if self.fanseg_by_face_type.get(face_type, None) is None:
|
||||
cpu_only = len(nn.getCurrentDeviceConfig().devices) == 0
|
||||
|
||||
with nn.tf.device('/CPU:0' if cpu_only else '/GPU:0'):
|
||||
fanseg = TernausNet("FANSeg", self.fanseg_input_size , FaceType.toString( face_type ), place_model_on_cpu=True )
|
||||
|
||||
self.fanseg_by_face_type[face_type] = fanseg
|
||||
return fanseg.extract(*args, **kwargs)
|
||||
|
||||
self.fanseg_host, self.fanseg_extract_func = SubprocessFunctionCaller.make_pair(fanseg_extract_func)
|
||||
|
||||
self.frames_root_path = frames_root_path
|
||||
self.output_path = output_path
|
||||
self.output_mask_path = output_mask_path
|
||||
self.model_iter = model_iter
|
||||
|
||||
self.prefetch_frame_count = self.process_count = multiprocessing.cpu_count()
|
||||
|
||||
session_data = None
|
||||
if self.is_interactive and self.merger_session_filepath.exists():
|
||||
io.input_skip_pending()
|
||||
if io.input_bool ("Use saved session?", True):
|
||||
try:
|
||||
with open( str(self.merger_session_filepath), "rb") as f:
|
||||
session_data = pickle.loads(f.read())
|
||||
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
rewind_to_frame_idx = None
|
||||
self.frames = frames
|
||||
self.frames_idxs = [ *range(len(self.frames)) ]
|
||||
self.frames_done_idxs = []
|
||||
|
||||
if self.is_interactive and session_data is not None:
|
||||
# Loaded session data, check it
|
||||
s_frames = session_data.get('frames', None)
|
||||
s_frames_idxs = session_data.get('frames_idxs', None)
|
||||
s_frames_done_idxs = session_data.get('frames_done_idxs', None)
|
||||
s_model_iter = session_data.get('model_iter', None)
|
||||
|
||||
frames_equal = (s_frames is not None) and \
|
||||
(s_frames_idxs is not None) and \
|
||||
(s_frames_done_idxs is not None) and \
|
||||
(s_model_iter is not None) and \
|
||||
(len(frames) == len(s_frames)) # frames count must match
|
||||
|
||||
if frames_equal:
|
||||
for i in range(len(frames)):
|
||||
frame = frames[i]
|
||||
s_frame = s_frames[i]
|
||||
# frames filenames must match
|
||||
if frame.frame_info.filepath.name != s_frame.frame_info.filepath.name:
|
||||
frames_equal = False
|
||||
if not frames_equal:
|
||||
break
|
||||
|
||||
if frames_equal:
|
||||
io.log_info ('Using saved session from ' + '/'.join (self.merger_session_filepath.parts[-2:]) )
|
||||
|
||||
for frame in s_frames:
|
||||
if frame.cfg is not None:
|
||||
# recreate MergerConfig class using constructor with get_config() as dict params
|
||||
# so if any new param will be added, old merger session will work properly
|
||||
frame.cfg = frame.cfg.__class__( **frame.cfg.get_config() )
|
||||
|
||||
self.frames = s_frames
|
||||
self.frames_idxs = s_frames_idxs
|
||||
self.frames_done_idxs = s_frames_done_idxs
|
||||
|
||||
if self.model_iter != s_model_iter:
|
||||
# model was more trained, recompute all frames
|
||||
rewind_to_frame_idx = -1
|
||||
for frame in self.frames:
|
||||
frame.is_done = False
|
||||
elif len(self.frames_idxs) == 0:
|
||||
# all frames are done?
|
||||
rewind_to_frame_idx = -1
|
||||
|
||||
if len(self.frames_idxs) != 0:
|
||||
cur_frame = self.frames[self.frames_idxs[0]]
|
||||
cur_frame.is_shown = False
|
||||
|
||||
if not frames_equal:
|
||||
session_data = None
|
||||
|
||||
if session_data is None:
|
||||
for filename in pathex.get_image_paths(self.output_path): #remove all images in output_path
|
||||
Path(filename).unlink()
|
||||
|
||||
for filename in pathex.get_image_paths(self.output_mask_path): #remove all images in output_mask_path
|
||||
Path(filename).unlink()
|
||||
|
||||
|
||||
frames[0].cfg = self.merger_config.copy()
|
||||
|
||||
for i in range( len(self.frames) ):
|
||||
frame = self.frames[i]
|
||||
frame.idx = i
|
||||
frame.output_filepath = self.output_path / ( frame.frame_info.filepath.stem + '.png' )
|
||||
frame.output_mask_filepath = self.output_mask_path / ( frame.frame_info.filepath.stem + '.png' )
|
||||
|
||||
if not frame.output_filepath.exists() or \
|
||||
not frame.output_mask_filepath.exists():
|
||||
# if some frame does not exist, recompute and rewind
|
||||
frame.is_done = False
|
||||
frame.is_shown = False
|
||||
|
||||
if rewind_to_frame_idx is None:
|
||||
rewind_to_frame_idx = i-1
|
||||
else:
|
||||
rewind_to_frame_idx = min(rewind_to_frame_idx, i-1)
|
||||
|
||||
if rewind_to_frame_idx is not None:
|
||||
while len(self.frames_done_idxs) > 0:
|
||||
if self.frames_done_idxs[-1] > rewind_to_frame_idx:
|
||||
prev_frame = self.frames[self.frames_done_idxs.pop()]
|
||||
self.frames_idxs.insert(0, prev_frame.idx)
|
||||
else:
|
||||
break
|
||||
#override
|
||||
def process_info_generator(self):
|
||||
r = [0] if MERGER_DEBUG else range(self.process_count)
|
||||
|
||||
for i in r:
|
||||
yield 'CPU%d' % (i), {}, {'device_idx': i,
|
||||
'device_name': 'CPU%d' % (i),
|
||||
'predictor_func': self.predictor_func,
|
||||
'predictor_input_shape' : self.predictor_input_shape,
|
||||
'superres_func': self.superres_func,
|
||||
'fanseg_input_size' : self.fanseg_input_size,
|
||||
'fanseg_extract_func' : self.fanseg_extract_func,
|
||||
'stdin_fd': sys.stdin.fileno() if MERGER_DEBUG else None
|
||||
}
|
||||
|
||||
#overridable optional
|
||||
def on_clients_initialized(self):
|
||||
io.progress_bar ("Merging", len(self.frames_idxs)+len(self.frames_done_idxs), initial=len(self.frames_done_idxs) )
|
||||
|
||||
self.process_remain_frames = not self.is_interactive
|
||||
self.is_interactive_quitting = not self.is_interactive
|
||||
|
||||
if self.is_interactive:
|
||||
help_images = {
|
||||
MergerConfig.TYPE_MASKED : cv2_imread ( str(Path(__file__).parent / 'gfx' / 'help_merger_masked.jpg') ),
|
||||
MergerConfig.TYPE_FACE_AVATAR : cv2_imread ( str(Path(__file__).parent / 'gfx' / 'help_merger_face_avatar.jpg') ),
|
||||
}
|
||||
|
||||
self.main_screen = Screen(initial_scale_to_width=1368, image=None, waiting_icon=True)
|
||||
self.help_screen = Screen(initial_scale_to_height=768, image=help_images[self.merger_config.type], waiting_icon=False)
|
||||
self.screen_manager = ScreenManager( "Merger", [self.main_screen, self.help_screen], capture_keys=True )
|
||||
self.screen_manager.set_current (self.help_screen)
|
||||
self.screen_manager.show_current()
|
||||
|
||||
self.masked_keys_funcs = {
|
||||
'`' : lambda cfg,shift_pressed: cfg.set_mode(0),
|
||||
'1' : lambda cfg,shift_pressed: cfg.set_mode(1),
|
||||
'2' : lambda cfg,shift_pressed: cfg.set_mode(2),
|
||||
'3' : lambda cfg,shift_pressed: cfg.set_mode(3),
|
||||
'4' : lambda cfg,shift_pressed: cfg.set_mode(4),
|
||||
'5' : lambda cfg,shift_pressed: cfg.set_mode(5),
|
||||
'q' : lambda cfg,shift_pressed: cfg.add_hist_match_threshold(1 if not shift_pressed else 5),
|
||||
'a' : lambda cfg,shift_pressed: cfg.add_hist_match_threshold(-1 if not shift_pressed else -5),
|
||||
'w' : lambda cfg,shift_pressed: cfg.add_erode_mask_modifier(1 if not shift_pressed else 5),
|
||||
's' : lambda cfg,shift_pressed: cfg.add_erode_mask_modifier(-1 if not shift_pressed else -5),
|
||||
'e' : lambda cfg,shift_pressed: cfg.add_blur_mask_modifier(1 if not shift_pressed else 5),
|
||||
'd' : lambda cfg,shift_pressed: cfg.add_blur_mask_modifier(-1 if not shift_pressed else -5),
|
||||
'r' : lambda cfg,shift_pressed: cfg.add_motion_blur_power(1 if not shift_pressed else 5),
|
||||
'f' : lambda cfg,shift_pressed: cfg.add_motion_blur_power(-1 if not shift_pressed else -5),
|
||||
't' : lambda cfg,shift_pressed: cfg.add_super_resolution_power(1 if not shift_pressed else 5),
|
||||
'g' : lambda cfg,shift_pressed: cfg.add_super_resolution_power(-1 if not shift_pressed else -5),
|
||||
'y' : lambda cfg,shift_pressed: cfg.add_blursharpen_amount(1 if not shift_pressed else 5),
|
||||
'h' : lambda cfg,shift_pressed: cfg.add_blursharpen_amount(-1 if not shift_pressed else -5),
|
||||
'u' : lambda cfg,shift_pressed: cfg.add_output_face_scale(1 if not shift_pressed else 5),
|
||||
'j' : lambda cfg,shift_pressed: cfg.add_output_face_scale(-1 if not shift_pressed else -5),
|
||||
'i' : lambda cfg,shift_pressed: cfg.add_image_denoise_power(1 if not shift_pressed else 5),
|
||||
'k' : lambda cfg,shift_pressed: cfg.add_image_denoise_power(-1 if not shift_pressed else -5),
|
||||
'o' : lambda cfg,shift_pressed: cfg.add_bicubic_degrade_power(1 if not shift_pressed else 5),
|
||||
'l' : lambda cfg,shift_pressed: cfg.add_bicubic_degrade_power(-1 if not shift_pressed else -5),
|
||||
'p' : lambda cfg,shift_pressed: cfg.add_color_degrade_power(1 if not shift_pressed else 5),
|
||||
';' : lambda cfg,shift_pressed: cfg.add_color_degrade_power(-1),
|
||||
':' : lambda cfg,shift_pressed: cfg.add_color_degrade_power(-5),
|
||||
'z' : lambda cfg,shift_pressed: cfg.toggle_masked_hist_match(),
|
||||
'x' : lambda cfg,shift_pressed: cfg.toggle_mask_mode(),
|
||||
'c' : lambda cfg,shift_pressed: cfg.toggle_color_transfer_mode(),
|
||||
'n' : lambda cfg,shift_pressed: cfg.toggle_sharpen_mode(),
|
||||
}
|
||||
self.masked_keys = list(self.masked_keys_funcs.keys())
|
||||
|
||||
#overridable optional
|
||||
def on_clients_finalized(self):
|
||||
io.progress_bar_close()
|
||||
|
||||
if self.is_interactive:
|
||||
self.screen_manager.finalize()
|
||||
|
||||
for frame in self.frames:
|
||||
frame.output_filepath = None
|
||||
frame.output_mask_filepath = None
|
||||
frame.image = None
|
||||
|
||||
session_data = {
|
||||
'frames': self.frames,
|
||||
'frames_idxs': self.frames_idxs,
|
||||
'frames_done_idxs': self.frames_done_idxs,
|
||||
'model_iter' : self.model_iter,
|
||||
}
|
||||
self.merger_session_filepath.write_bytes( pickle.dumps(session_data) )
|
||||
|
||||
io.log_info ("Session is saved to " + '/'.join (self.merger_session_filepath.parts[-2:]) )
    #override
    def on_tick(self):
        self.predictor_func_host.process_messages()
        self.superres_host.process_messages()
        self.fanseg_host.process_messages()

        go_prev_frame = False
        go_first_frame = False
        go_prev_frame_overriding_cfg = False
        go_first_frame_overriding_cfg = False

        go_next_frame = self.process_remain_frames
        go_next_frame_overriding_cfg = False
        go_last_frame_overriding_cfg = False

        cur_frame = None
        if len(self.frames_idxs) != 0:
            cur_frame = self.frames[self.frames_idxs[0]]

        if self.is_interactive:
            screen_image = None if self.process_remain_frames else \
                           self.main_screen.get_image()

            self.main_screen.set_waiting_icon( self.process_remain_frames or \
                                               self.is_interactive_quitting )

            if cur_frame is not None and not self.is_interactive_quitting:
                if not self.process_remain_frames:
                    if cur_frame.is_done:
                        if not cur_frame.is_shown:
                            if cur_frame.image is None:
                                image = cv2_imread (cur_frame.output_filepath, verbose=False)
                                image_mask = cv2_imread (cur_frame.output_mask_filepath, verbose=False)
                                if image is None or image_mask is None:
                                    # unable to read? recompute then
                                    cur_frame.is_done = False
                                else:
                                    image_mask = imagelib.normalize_channels(image_mask, 1)
                                    cur_frame.image = np.concatenate([image, image_mask], -1)

                            if cur_frame.is_done:
                                io.log_info (cur_frame.cfg.to_string( cur_frame.frame_info.filepath.name) )
                                cur_frame.is_shown = True
                                screen_image = cur_frame.image
                    else:
                        self.main_screen.set_waiting_icon(True)

            self.main_screen.set_image(screen_image)
            self.screen_manager.show_current()

            key_events = self.screen_manager.get_key_events()
            key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False)

            if key == 9: #tab
                self.screen_manager.switch_screens()
            else:
                if key == 27: #esc
                    self.is_interactive_quitting = True
                elif self.screen_manager.get_current() is self.main_screen:
                    if self.merger_config.type == MergerConfig.TYPE_MASKED and chr_key in self.masked_keys:
                        self.process_remain_frames = False

                        if cur_frame is not None:
                            cfg = cur_frame.cfg
                            prev_cfg = cfg.copy()

                            if cfg.type == MergerConfig.TYPE_MASKED:
                                self.masked_keys_funcs[chr_key](cfg, shift_pressed)

                            if prev_cfg != cfg:
                                io.log_info ( cfg.to_string(cur_frame.frame_info.filepath.name) )
                                cur_frame.is_done = False
                                cur_frame.is_shown = False
                    else:
                        if chr_key == ',' or chr_key == 'm':
                            self.process_remain_frames = False
                            go_prev_frame = True

                            if chr_key == ',':
                                if shift_pressed:
                                    go_first_frame = True

                            elif chr_key == 'm':
                                if not shift_pressed:
                                    go_prev_frame_overriding_cfg = True
                                else:
                                    go_first_frame_overriding_cfg = True

                        elif chr_key == '.' or chr_key == '/':
                            self.process_remain_frames = False
                            go_next_frame = True

                            if chr_key == '.':
                                if shift_pressed:
                                    self.process_remain_frames = not self.process_remain_frames

                            elif chr_key == '/':
                                if not shift_pressed:
                                    go_next_frame_overriding_cfg = True
                                else:
                                    go_last_frame_overriding_cfg = True

                        elif chr_key == '-':
                            self.screen_manager.get_current().diff_scale(-0.1)
                        elif chr_key == '=':
                            self.screen_manager.get_current().diff_scale(0.1)
                        elif chr_key == 'v':
                            self.screen_manager.get_current().toggle_show_checker_board()

        if go_prev_frame:
            if cur_frame is None or cur_frame.is_done:
                if cur_frame is not None:
                    cur_frame.image = None

                while True:
                    if len(self.frames_done_idxs) > 0:
                        prev_frame = self.frames[self.frames_done_idxs.pop()]
                        self.frames_idxs.insert(0, prev_frame.idx)
                        prev_frame.is_shown = False
                        io.progress_bar_inc(-1)

                        if cur_frame is not None and (go_prev_frame_overriding_cfg or go_first_frame_overriding_cfg):
                            if prev_frame.cfg != cur_frame.cfg:
                                prev_frame.cfg = cur_frame.cfg.copy()
                                prev_frame.is_done = False

                        cur_frame = prev_frame

                    if go_first_frame_overriding_cfg or go_first_frame:
                        if len(self.frames_done_idxs) > 0:
                            continue
                    break

        elif go_next_frame:
            if cur_frame is not None and cur_frame.is_done:
                cur_frame.image = None
                cur_frame.is_shown = True
                self.frames_done_idxs.append(cur_frame.idx)
                self.frames_idxs.pop(0)
                io.progress_bar_inc(1)

                f = self.frames

                if len(self.frames_idxs) != 0:
                    next_frame = f[ self.frames_idxs[0] ]
                    next_frame.is_shown = False

                    if go_next_frame_overriding_cfg or go_last_frame_overriding_cfg:
                        if go_next_frame_overriding_cfg:
                            to_frames = next_frame.idx+1
                        else:
                            to_frames = len(f)

                        for i in range( next_frame.idx, to_frames ):
                            f[i].cfg = None

                    for i in range( min(len(self.frames_idxs), self.prefetch_frame_count) ):
                        frame = f[ self.frames_idxs[i] ]
                        if frame.cfg is None:
                            if i == 0:
                                frame.cfg = cur_frame.cfg.copy()
                            else:
                                frame.cfg = f[ self.frames_idxs[i-1] ].cfg.copy()

                            frame.is_done = False #initiate solve again
                            frame.is_shown = False

        if len(self.frames_idxs) == 0:
            self.process_remain_frames = False

        return (self.is_interactive and self.is_interactive_quitting) or \
               (not self.is_interactive and self.process_remain_frames == False)
    #override
    def on_data_return (self, host_dict, pf):
        frame = self.frames[pf.idx]
        frame.is_done = False
        frame.is_processing = False

    #override
    def on_result (self, host_dict, pf_sent, pf_result):
        frame = self.frames[pf_result.idx]
        frame.is_processing = False
        if frame.cfg == pf_result.cfg:
            frame.is_done = True
            frame.image = pf_result.image

    #override
    def get_data(self, host_dict):
        if self.is_interactive and self.is_interactive_quitting:
            return None

        for i in range ( min(len(self.frames_idxs), self.prefetch_frame_count) ):
            frame = self.frames[ self.frames_idxs[i] ]

            if not frame.is_done and not frame.is_processing and frame.cfg is not None:
                frame.is_processing = True
                return MergeSubprocessor.ProcessingFrame(idx=frame.idx,
                                                         cfg=frame.cfg.copy(),
                                                         prev_temporal_frame_infos=frame.prev_temporal_frame_infos,
                                                         frame_info=frame.frame_info,
                                                         next_temporal_frame_infos=frame.next_temporal_frame_infos,
                                                         output_filepath=frame.output_filepath,
                                                         output_mask_filepath=frame.output_mask_filepath,
                                                         need_return_image=True )

        return None

    #override
    def get_result(self):
        return 0
from core.interact import interact as io
from core.joblib import MPClassFuncOnDemand, MPFunc
from core.leras import nn
from DFLIMG import DFLIMG
from facelib import FaceEnhancer, FaceType, LandmarksProcessor, TernausNet, DFLSegNet
from merger import FrameInfo, MergerConfig, InteractiveMergerSubprocessor


def main (model_class_name=None,
          saved_models_path=None,
@ -658,23 +42,42 @@ def main (model_class_name=None,
        io.log_err('Model directory not found. Please ensure it exists.')
        return

    is_interactive = io.input_bool ("Use interactive merger?", True) if not io.is_colab() else False

    # Initialize model
    import models
    model = models.import_model(model_class_name)(is_training=False,
                                                  saved_models_path=saved_models_path,
                                                  training_data_src_path=training_data_src_path,
                                                  force_gpu_idxs=force_gpu_idxs,
                                                  cpu_only=cpu_only)
    merger_session_filepath = model.get_strpath_storage_for_file('merger_session.dat')

    predictor_func, predictor_input_shape, cfg = model.get_MergerConfig()

    # Preparing MP functions
    predictor_func = MPFunc(predictor_func)

    run_on_cpu = len(nn.getCurrentDeviceConfig().devices) == 0
    fanseg_full_face_256_extract_func = MPClassFuncOnDemand(TernausNet, 'extract',
                                                            name=f'FANSeg_{FaceType.toString(FaceType.FULL)}',
                                                            resolution=256,
                                                            place_model_on_cpu=True,
                                                            run_on_cpu=run_on_cpu)

    skinseg_256_extract_func = MPClassFuncOnDemand(DFLSegNet, 'extract',
                                                   name='SkinSeg',
                                                   resolution=256,
                                                   place_model_on_cpu=True,
                                                   run_on_cpu=run_on_cpu)

    face_enhancer_func = MPClassFuncOnDemand(FaceEnhancer, 'enhance',
                                             place_model_on_cpu=True,
                                             run_on_cpu=run_on_cpu)
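
    # MPFunc and MPClassFuncOnDemand wrap the heavy models as callables that can
    # be handed to the merger worker processes; the "OnDemand" name suggests the
    # wrapped class is only constructed on first call. A minimal stand-in
    # illustrating that pattern (a sketch, not the core.joblib implementation):
    #
    #   class FuncOnDemand:
    #       def __init__(self, factory, method_name, **kwargs):
    #           self.factory, self.method_name, self.kwargs = factory, method_name, kwargs
    #           self._obj = None                      # heavy object not built yet
    #
    #       def __call__(self, *args):
    #           if self._obj is None:                 # build on first use only
    #               self._obj = self.factory(**self.kwargs)
    #           return getattr(self._obj, self.method_name)(*args)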
    is_interactive = io.input_bool ("Use interactive merger?", True) if not io.is_colab() else False

    if not is_interactive:
        cfg.ask_settings()

    input_path_image_paths = pathex.get_image_paths(input_path)

    if cfg.type == MergerConfig.TYPE_MASKED:
        if not aligned_path.exists():
            io.log_err('Aligned directory not found. Please ensure it exists.')
@ -719,7 +122,7 @@ def main (model_class_name=None,
                alignments_ar = alignments[ source_filename_stem ]
                alignments_ar.append ( (dflimg.get_source_landmarks(), filepath, source_filepath ) )

                if len(alignments_ar) > 1:
                    multiple_faces_detected = True
@ -727,22 +130,22 @@ def main (model_class_name=None,
            io.log_info ("")
            io.log_info ("Warning: multiple faces detected. Only one alignment file should refer to one source file.")
            io.log_info ("")

            for a_key in list(alignments.keys()):
                a_ar = alignments[a_key]
                if len(a_ar) > 1:
                    for _, filepath, source_filepath in a_ar:
                        io.log_info (f"alignment {filepath.name} refers to {source_filepath.name} ")
                    io.log_info ("")

                alignments[a_key] = [ a[0] for a in a_ar]

        if multiple_faces_detected:
            io.log_info ("It is strongly recommended to process the faces separately.")
            io.log_info ("Use 'recover original filename' to determine the exact duplicates.")
            io.log_info ("")
        frames = [ MergeSubprocessor.Frame( frame_info=FrameInfo(filepath=Path(p),
        frames = [ InteractiveMergerSubprocessor.Frame( frame_info=FrameInfo(filepath=Path(p),
                                                                 landmarks_list=alignments.get(Path(p).stem, None)
                                                                )
                   )
@ -783,60 +186,66 @@ def main (model_class_name=None,
                fi.motion_deg = -math.atan2(motion_vector[1],motion_vector[0])*180 / math.pi


    elif cfg.type == MergerConfig.TYPE_FACE_AVATAR:
        pass
    """
    filesdata = []
    for filepath in io.progress_bar_generator(input_path_image_paths, "Collecting info"):
        filepath = Path(filepath)

        dflimg = DFLIMG.load(filepath)
        if dflimg is None:
            io.log_err ("%s is not a dfl image file" % (filepath.name) )
            continue
        filesdata += [ ( FrameInfo(filepath=filepath, landmarks_list=[dflimg.get_landmarks()] ), dflimg.get_source_filename() ) ]

    filesdata = sorted(filesdata, key=operator.itemgetter(1)) #sort by source_filename
    frames = []
    filesdata_len = len(filesdata)
    for i in range(len(filesdata)):
        frame_info = filesdata[i][0]

        prev_temporal_frame_infos = []
        next_temporal_frame_infos = []

        for t in range (cfg.temporal_face_count):
            prev_frame_info = filesdata[ max(i -t, 0) ][0]
            next_frame_info = filesdata[ min(i +t, filesdata_len-1 )][0]

            prev_temporal_frame_infos.insert (0, prev_frame_info )
            next_temporal_frame_infos.append ( next_frame_info )

        frames.append ( MergeSubprocessor.Frame(prev_temporal_frame_infos=prev_temporal_frame_infos,
                                                frame_info=frame_info,
                                                next_temporal_frame_infos=next_temporal_frame_infos) )
    """
    if len(frames) == 0:
        io.log_info ("No frames to merge in input_dir.")
    else:
        MergeSubprocessor (
            is_interactive = is_interactive,
            merger_session_filepath = merger_session_filepath,
            predictor_func = predictor_func,
            predictor_input_shape = predictor_input_shape,
            merger_config = cfg,
            frames = frames,
            frames_root_path = input_path,
            output_path = output_path,
            output_mask_path = output_mask_path,
            model_iter = model.get_iter()
        ).run()
        if False:
            pass
        else:
            InteractiveMergerSubprocessor (
                is_interactive = is_interactive,
                merger_session_filepath = model.get_strpath_storage_for_file('merger_session.dat'),
                predictor_func = predictor_func,
                predictor_input_shape = predictor_input_shape,
                face_enhancer_func = face_enhancer_func,
                fanseg_full_face_256_extract_func = fanseg_full_face_256_extract_func,
                skinseg_256_extract_func = skinseg_256_extract_func,
                merger_config = cfg,
                frames = frames,
                frames_root_path = input_path,
                output_path = output_path,
                output_mask_path = output_mask_path,
                model_iter = model.get_iter()
            ).run()

    model.finalize()

    except Exception as e:
        print ( 'Error: %s' % (str(e)))
        traceback.print_exc()
        print ( traceback.format_exc() )


    """
    elif cfg.type == MergerConfig.TYPE_FACE_AVATAR:
        filesdata = []
        for filepath in io.progress_bar_generator(input_path_image_paths, "Collecting info"):
            filepath = Path(filepath)

            dflimg = DFLIMG.load(filepath)
            if dflimg is None:
                io.log_err ("%s is not a dfl image file" % (filepath.name) )
                continue
            filesdata += [ ( FrameInfo(filepath=filepath, landmarks_list=[dflimg.get_landmarks()] ), dflimg.get_source_filename() ) ]

        filesdata = sorted(filesdata, key=operator.itemgetter(1)) #sort by source_filename
        frames = []
        filesdata_len = len(filesdata)
        for i in range(len(filesdata)):
            frame_info = filesdata[i][0]

            prev_temporal_frame_infos = []
            next_temporal_frame_infos = []

            for t in range (cfg.temporal_face_count):
                prev_frame_info = filesdata[ max(i -t, 0) ][0]
                next_frame_info = filesdata[ min(i +t, filesdata_len-1 )][0]

                prev_temporal_frame_infos.insert (0, prev_frame_info )
                next_temporal_frame_infos.append ( next_frame_info )

            frames.append ( InteractiveMergerSubprocessor.Frame(prev_temporal_frame_infos=prev_temporal_frame_infos,
                                                                frame_info=frame_info,
                                                                next_temporal_frame_infos=next_temporal_frame_infos) )
    """

    #interpolate landmarks
    #from facelib import LandmarksProcessor
@ -754,7 +754,7 @@ def sort_by_absdiff(input_path):
    from core.leras import nn

    device_config = nn.ask_choose_device_idxs(choose_only_one=True, return_device_config=True)
    device_config = nn.DeviceConfig.ask_choose_device(choose_only_one=True)
    nn.initialize( device_config=device_config, data_format="NHWC" )
    tf = nn.tf
@ -66,9 +66,7 @@ def restore_faceset_metadata_folder(input_path):
        elif filepath.suffix == '.jpg':
            cv2_imwrite (str(filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] )

        if filepath.suffix == '.png':
            DFLPNG.embed_dfldict( str(filepath), dfl_dict )
        elif filepath.suffix == '.jpg':
        if filepath.suffix == '.jpg':
            DFLJPG.embed_dfldict( str(filepath), dfl_dict )
        else:
            continue
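
        # With PNG support removed, only .jpg faces carry embedded metadata and
        # DFLIMG.load returns None for anything else. A hedged sketch for spotting
        # faces that still need conversion with an old build (folder path is
        # hypothetical):
        #
        #   from pathlib import Path
        #   from DFLIMG import DFLIMG
        #
        #   for p in Path('workspace/data_dst/aligned').iterdir():
        #       if p.suffix == '.png':
        #           print(f'{p.name}: PNG faces are no longer read, convert with an old build')
        #       elif DFLIMG.load(p) is None:
        #           print(f'{p.name} is not a dfl image file')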
@ -118,42 +116,6 @@ def remove_fanseg_folder(input_path):
        filepath = Path(filepath)
        remove_fanseg_file(filepath)

def convert_png_to_jpg_file (filepath):
    filepath = Path(filepath)

    if filepath.suffix != '.png':
        return

    dflpng = DFLPNG.load (str(filepath) )
    if dflpng is None:
        io.log_err ("%s is not a dfl png image file" % (filepath.name) )
        return

    dfl_dict = dflpng.getDFLDictData()

    img = cv2_imread (str(filepath))
    new_filepath = str(filepath.parent / (filepath.stem + '.jpg'))
    cv2_imwrite ( new_filepath, img, [int(cv2.IMWRITE_JPEG_QUALITY), 100])

    DFLJPG.embed_data( new_filepath,
                       face_type=dfl_dict.get('face_type', None),
                       landmarks=dfl_dict.get('landmarks', None),
                       ie_polys=dfl_dict.get('ie_polys', None),
                       source_filename=dfl_dict.get('source_filename', None),
                       source_rect=dfl_dict.get('source_rect', None),
                       source_landmarks=dfl_dict.get('source_landmarks', None) )

    filepath.unlink()

def convert_png_to_jpg_folder (input_path):
    input_path = Path(input_path)

    io.log_info ("Converting PNG to JPG...\r\n")

    for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Converting"):
        filepath = Path(filepath)
        convert_png_to_jpg_file(filepath)

def add_landmarks_debug_images(input_path):
    io.log_info ("Adding landmarks debug images...")
@ -236,3 +198,42 @@ def recover_original_aligned_filename(input_path):
            fs.rename (fd)
        except:
            io.log_err ('failed to rename %s' % (fs.name) )


"""
def convert_png_to_jpg_file (filepath):
    filepath = Path(filepath)

    if filepath.suffix != '.png':
        return

    dflpng = DFLPNG.load (str(filepath) )
    if dflpng is None:
        io.log_err ("%s is not a dfl png image file" % (filepath.name) )
        return

    dfl_dict = dflpng.getDFLDictData()

    img = cv2_imread (str(filepath))
    new_filepath = str(filepath.parent / (filepath.stem + '.jpg'))
    cv2_imwrite ( new_filepath, img, [int(cv2.IMWRITE_JPEG_QUALITY), 100])

    DFLJPG.embed_data( new_filepath,
                       face_type=dfl_dict.get('face_type', None),
                       landmarks=dfl_dict.get('landmarks', None),
                       ie_polys=dfl_dict.get('ie_polys', None),
                       source_filename=dfl_dict.get('source_filename', None),
                       source_rect=dfl_dict.get('source_rect', None),
                       source_landmarks=dfl_dict.get('source_landmarks', None) )

    filepath.unlink()

def convert_png_to_jpg_folder (input_path):
    input_path = Path(input_path)

    io.log_info ("Converting PNG to JPG...\r\n")

    for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Converting"):
        filepath = Path(filepath)
        convert_png_to_jpg_file(filepath)
"""
@ -552,8 +552,42 @@ def dev_test1(input_dir):
    #import code
    #code.interact(local=dict(globals(), **locals()))

def dev_resave_pngs(input_dir):
    input_path = Path(input_dir)
    if not input_path.exists():
        raise ValueError('input_dir not found. Please ensure it exists.')

    images_paths = pathex.get_image_paths(input_path, image_extensions=['.png'], subdirs=True, return_Path_class=True)

    for filepath in io.progress_bar_generator(images_paths,"Processing"):
        cv2_imwrite(filepath, cv2_imread(filepath))


def dev_segmented_trash(input_dir):
    input_path = Path(input_dir)
    if not input_path.exists():
        raise ValueError('input_dir not found. Please ensure it exists.')

    output_path = input_path.parent / (input_path.name+'_trash')
    output_path.mkdir(parents=True, exist_ok=True)

    images_paths = pathex.get_image_paths(input_path, return_Path_class=True)

    trash_paths = []
    for filepath in images_paths:
        json_file = filepath.parent / (filepath.stem +'.json')
        if not json_file.exists():
            trash_paths.append(filepath)

    for filepath in trash_paths:
        try:
            filepath.rename ( output_path / filepath.name )
        except:
            io.log_info ('failed to trash %s' % (filepath.name) )


def dev_segmented_extract(input_dir, output_dir ):
    # extract and merge .json labelme files within the faces
571 merger/InteractiveMergerSubprocessor.py Normal file
@ -0,0 +1,571 @@
import multiprocessing
import os
import pickle
import sys
import traceback
from pathlib import Path

import numpy as np

from core import imagelib, pathex
from core.cv2ex import *
from core.interact import interact as io
from core.joblib import Subprocessor
from merger import MergeFaceAvatar, MergeMasked, MergerConfig

from .MergerScreen import Screen, ScreenManager

MERGER_DEBUG = False
class InteractiveMergerSubprocessor(Subprocessor):

    class Frame(object):
        def __init__(self, prev_temporal_frame_infos=None,
                           frame_info=None,
                           next_temporal_frame_infos=None):
            self.prev_temporal_frame_infos = prev_temporal_frame_infos
            self.frame_info = frame_info
            self.next_temporal_frame_infos = next_temporal_frame_infos
            self.output_filepath = None
            self.output_mask_filepath = None

            self.idx = None
            self.cfg = None
            self.is_done = False
            self.is_processing = False
            self.is_shown = False
            self.image = None

    class ProcessingFrame(object):
        def __init__(self, idx=None,
                           cfg=None,
                           prev_temporal_frame_infos=None,
                           frame_info=None,
                           next_temporal_frame_infos=None,
                           output_filepath=None,
                           output_mask_filepath=None,
                           need_return_image = False):
            self.idx = idx
            self.cfg = cfg
            self.prev_temporal_frame_infos = prev_temporal_frame_infos
            self.frame_info = frame_info
            self.next_temporal_frame_infos = next_temporal_frame_infos
            self.output_filepath = output_filepath
            self.output_mask_filepath = output_mask_filepath

            self.need_return_image = need_return_image
            if self.need_return_image:
                self.image = None

    class Cli(Subprocessor.Cli):

        #override
        def on_initialize(self, client_dict):
            self.log_info ('Running on %s.' % (client_dict['device_name']) )
            self.device_idx = client_dict['device_idx']
            self.device_name = client_dict['device_name']
            self.predictor_func = client_dict['predictor_func']
            self.predictor_input_shape = client_dict['predictor_input_shape']
            self.face_enhancer_func = client_dict['face_enhancer_func']
            self.fanseg_full_face_256_extract_func = client_dict['fanseg_full_face_256_extract_func']
            self.skinseg_256_extract_func = client_dict['skinseg_256_extract_func']

            #transfer and set stdin in order to work code.interact in debug subprocess
            stdin_fd = client_dict['stdin_fd']
            if stdin_fd is not None:
                sys.stdin = os.fdopen(stdin_fd)

            return None

        #override
        def process_data(self, pf): #pf=ProcessingFrame
            cfg = pf.cfg.copy()

            frame_info = pf.frame_info
            filepath = frame_info.filepath

            if len(frame_info.landmarks_list) == 0:
                self.log_info (f'no faces found for {filepath.name}, copying without faces')

                img_bgr = cv2_imread(filepath)
                imagelib.normalize_channels(img_bgr, 3)
                cv2_imwrite (pf.output_filepath, img_bgr)
                h,w,c = img_bgr.shape

                img_mask = np.zeros( (h,w,1), dtype=img_bgr.dtype)
                cv2_imwrite (pf.output_mask_filepath, img_mask)

                if pf.need_return_image:
                    pf.image = np.concatenate ([img_bgr, img_mask], axis=-1)

            else:
                if cfg.type == MergerConfig.TYPE_MASKED:
                    try:
                        final_img = MergeMasked (self.predictor_func, self.predictor_input_shape,
                                                 face_enhancer_func=self.face_enhancer_func,
                                                 fanseg_full_face_256_extract_func=self.fanseg_full_face_256_extract_func,
                                                 skinseg_256_extract_func=self.skinseg_256_extract_func,
                                                 cfg=cfg,
                                                 frame_info=frame_info)
                    except Exception as e:
                        e_str = traceback.format_exc()
                        if 'MemoryError' in e_str:
                            raise Subprocessor.SilenceException
                        else:
                            raise Exception( f'Error while merging file [{filepath}]: {e_str}' )

                elif cfg.type == MergerConfig.TYPE_FACE_AVATAR:
                    final_img = MergeFaceAvatar (self.predictor_func, self.predictor_input_shape,
                                                 cfg, pf.prev_temporal_frame_infos,
                                                 pf.frame_info,
                                                 pf.next_temporal_frame_infos )

                cv2_imwrite (pf.output_filepath, final_img[...,0:3] )
                cv2_imwrite (pf.output_mask_filepath, final_img[...,3:4] )

                if pf.need_return_image:
                    pf.image = final_img

            return pf

        #overridable
        def get_data_name (self, pf):
            #return string identificator of your data
            return pf.frame_info.filepath


    #override
    def __init__(self, is_interactive, merger_session_filepath, predictor_func, predictor_input_shape, face_enhancer_func, fanseg_full_face_256_extract_func, skinseg_256_extract_func, merger_config, frames, frames_root_path, output_path, output_mask_path, model_iter):
        if len (frames) == 0:
            raise ValueError ("len (frames) == 0")

        super().__init__('Merger', InteractiveMergerSubprocessor.Cli, io_loop_sleep_time=0.001)

        self.is_interactive = is_interactive
        self.merger_session_filepath = Path(merger_session_filepath)
        self.merger_config = merger_config

        self.predictor_func = predictor_func
        self.predictor_input_shape = predictor_input_shape

        self.face_enhancer_func = face_enhancer_func
        self.fanseg_full_face_256_extract_func = fanseg_full_face_256_extract_func
        self.skinseg_256_extract_func = skinseg_256_extract_func

        self.frames_root_path = frames_root_path
        self.output_path = output_path
        self.output_mask_path = output_mask_path
        self.model_iter = model_iter

        self.prefetch_frame_count = self.process_count = multiprocessing.cpu_count()

        session_data = None
        if self.is_interactive and self.merger_session_filepath.exists():
            io.input_skip_pending()
            if io.input_bool ("Use saved session?", True):
                try:
                    with open( str(self.merger_session_filepath), "rb") as f:
                        session_data = pickle.loads(f.read())

                except Exception as e:
                    pass

        rewind_to_frame_idx = None
        self.frames = frames
        self.frames_idxs = [ *range(len(self.frames)) ]
        self.frames_done_idxs = []

        if self.is_interactive and session_data is not None:
            # Loaded session data, check it
            s_frames = session_data.get('frames', None)
            s_frames_idxs = session_data.get('frames_idxs', None)
            s_frames_done_idxs = session_data.get('frames_done_idxs', None)
            s_model_iter = session_data.get('model_iter', None)

            frames_equal = (s_frames is not None) and \
                           (s_frames_idxs is not None) and \
                           (s_frames_done_idxs is not None) and \
                           (s_model_iter is not None) and \
                           (len(frames) == len(s_frames)) # frames count must match

            if frames_equal:
                for i in range(len(frames)):
                    frame = frames[i]
                    s_frame = s_frames[i]
                    # frames filenames must match
                    if frame.frame_info.filepath.name != s_frame.frame_info.filepath.name:
                        frames_equal = False
                    if not frames_equal:
                        break

            if frames_equal:
                io.log_info ('Using saved session from ' + '/'.join (self.merger_session_filepath.parts[-2:]) )

                for frame in s_frames:
                    if frame.cfg is not None:
                        # recreate MergerConfig class using constructor with get_config() as dict params
                        # so if any new param will be added, old merger session will work properly
                        frame.cfg = frame.cfg.__class__( **frame.cfg.get_config() )

                self.frames = s_frames
                self.frames_idxs = s_frames_idxs
                self.frames_done_idxs = s_frames_done_idxs

                if self.model_iter != s_model_iter:
                    # model was more trained, recompute all frames
                    rewind_to_frame_idx = -1
                    for frame in self.frames:
                        frame.is_done = False
                elif len(self.frames_idxs) == 0:
                    # all frames are done?
                    rewind_to_frame_idx = -1

                if len(self.frames_idxs) != 0:
                    cur_frame = self.frames[self.frames_idxs[0]]
                    cur_frame.is_shown = False

            if not frames_equal:
                session_data = None

        if session_data is None:
            for filename in pathex.get_image_paths(self.output_path): #remove all images in output_path
                Path(filename).unlink()

            for filename in pathex.get_image_paths(self.output_mask_path): #remove all images in output_mask_path
                Path(filename).unlink()

            frames[0].cfg = self.merger_config.copy()

        for i in range( len(self.frames) ):
            frame = self.frames[i]
            frame.idx = i
            frame.output_filepath = self.output_path / ( frame.frame_info.filepath.stem + '.png' )
            frame.output_mask_filepath = self.output_mask_path / ( frame.frame_info.filepath.stem + '.png' )

            if not frame.output_filepath.exists() or \
               not frame.output_mask_filepath.exists():
                # if some frame does not exist, recompute and rewind
                frame.is_done = False
                frame.is_shown = False

                if rewind_to_frame_idx is None:
                    rewind_to_frame_idx = i-1
                else:
                    rewind_to_frame_idx = min(rewind_to_frame_idx, i-1)

        if rewind_to_frame_idx is not None:
            while len(self.frames_done_idxs) > 0:
                if self.frames_done_idxs[-1] > rewind_to_frame_idx:
                    prev_frame = self.frames[self.frames_done_idxs.pop()]
                    self.frames_idxs.insert(0, prev_frame.idx)
                else:
                    break

    #override
    def process_info_generator(self):
        r = [0] if MERGER_DEBUG else range(self.process_count)

        for i in r:
            yield 'CPU%d' % (i), {}, {'device_idx': i,
                                      'device_name': 'CPU%d' % (i),
                                      'predictor_func': self.predictor_func,
                                      'predictor_input_shape' : self.predictor_input_shape,
                                      'face_enhancer_func': self.face_enhancer_func,
                                      'fanseg_full_face_256_extract_func' : self.fanseg_full_face_256_extract_func,
                                      'skinseg_256_extract_func' : self.skinseg_256_extract_func,
                                      'stdin_fd': sys.stdin.fileno() if MERGER_DEBUG else None
                                      }

    #overridable optional
    def on_clients_initialized(self):
        io.progress_bar ("Merging", len(self.frames_idxs)+len(self.frames_done_idxs), initial=len(self.frames_done_idxs) )

        self.process_remain_frames = not self.is_interactive
        self.is_interactive_quitting = not self.is_interactive

        if self.is_interactive:
            help_images = {
                MergerConfig.TYPE_MASKED :      cv2_imread ( str(Path(__file__).parent / 'gfx' / 'help_merger_masked.jpg') ),
                MergerConfig.TYPE_FACE_AVATAR : cv2_imread ( str(Path(__file__).parent / 'gfx' / 'help_merger_face_avatar.jpg') ),
            }

            self.main_screen = Screen(initial_scale_to_width=1368, image=None, waiting_icon=True)
            self.help_screen = Screen(initial_scale_to_height=768, image=help_images[self.merger_config.type], waiting_icon=False)
            self.screen_manager = ScreenManager( "Merger", [self.main_screen, self.help_screen], capture_keys=True )
            self.screen_manager.set_current (self.help_screen)
            self.screen_manager.show_current()

            self.masked_keys_funcs = {
                '`' : lambda cfg,shift_pressed: cfg.set_mode(0),
                '1' : lambda cfg,shift_pressed: cfg.set_mode(1),
                '2' : lambda cfg,shift_pressed: cfg.set_mode(2),
                '3' : lambda cfg,shift_pressed: cfg.set_mode(3),
                '4' : lambda cfg,shift_pressed: cfg.set_mode(4),
                '5' : lambda cfg,shift_pressed: cfg.set_mode(5),
                'q' : lambda cfg,shift_pressed: cfg.add_hist_match_threshold(1 if not shift_pressed else 5),
                'a' : lambda cfg,shift_pressed: cfg.add_hist_match_threshold(-1 if not shift_pressed else -5),
                'w' : lambda cfg,shift_pressed: cfg.add_erode_mask_modifier(1 if not shift_pressed else 5),
                's' : lambda cfg,shift_pressed: cfg.add_erode_mask_modifier(-1 if not shift_pressed else -5),
                'e' : lambda cfg,shift_pressed: cfg.add_blur_mask_modifier(1 if not shift_pressed else 5),
                'd' : lambda cfg,shift_pressed: cfg.add_blur_mask_modifier(-1 if not shift_pressed else -5),
                'r' : lambda cfg,shift_pressed: cfg.add_motion_blur_power(1 if not shift_pressed else 5),
                'f' : lambda cfg,shift_pressed: cfg.add_motion_blur_power(-1 if not shift_pressed else -5),
                't' : lambda cfg,shift_pressed: cfg.add_super_resolution_power(1 if not shift_pressed else 5),
                'g' : lambda cfg,shift_pressed: cfg.add_super_resolution_power(-1 if not shift_pressed else -5),
                'y' : lambda cfg,shift_pressed: cfg.add_blursharpen_amount(1 if not shift_pressed else 5),
                'h' : lambda cfg,shift_pressed: cfg.add_blursharpen_amount(-1 if not shift_pressed else -5),
                'u' : lambda cfg,shift_pressed: cfg.add_output_face_scale(1 if not shift_pressed else 5),
                'j' : lambda cfg,shift_pressed: cfg.add_output_face_scale(-1 if not shift_pressed else -5),
                'i' : lambda cfg,shift_pressed: cfg.add_image_denoise_power(1 if not shift_pressed else 5),
                'k' : lambda cfg,shift_pressed: cfg.add_image_denoise_power(-1 if not shift_pressed else -5),
                'o' : lambda cfg,shift_pressed: cfg.add_bicubic_degrade_power(1 if not shift_pressed else 5),
                'l' : lambda cfg,shift_pressed: cfg.add_bicubic_degrade_power(-1 if not shift_pressed else -5),
                'p' : lambda cfg,shift_pressed: cfg.add_color_degrade_power(1 if not shift_pressed else 5),
                ';' : lambda cfg,shift_pressed: cfg.add_color_degrade_power(-1),
                ':' : lambda cfg,shift_pressed: cfg.add_color_degrade_power(-5),
                'z' : lambda cfg,shift_pressed: cfg.toggle_masked_hist_match(),
                'x' : lambda cfg,shift_pressed: cfg.toggle_mask_mode(),
                'c' : lambda cfg,shift_pressed: cfg.toggle_color_transfer_mode(),
                'n' : lambda cfg,shift_pressed: cfg.toggle_sharpen_mode(),
            }
            self.masked_keys = list(self.masked_keys_funcs.keys())
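
            # The table above is plain dict dispatch: every hotkey maps to a lambda
            # taking (cfg, shift_pressed), so adding a binding is one new entry.
            # A self-contained illustration of the pattern (toy config, not MergerConfig):
            #
            #   class ToyCfg:
            #       def __init__(self): self.blur = 0
            #       def add_blur(self, diff): self.blur += diff
            #
            #   handlers = { 'e': lambda cfg, shift: cfg.add_blur(5 if shift else 1) }
            #   cfg = ToyCfg()
            #   handlers['e'](cfg, True)
            #   assert cfg.blur == 5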
    #overridable optional
    def on_clients_finalized(self):
        io.progress_bar_close()

        if self.is_interactive:
            self.screen_manager.finalize()

        for frame in self.frames:
            frame.output_filepath = None
            frame.output_mask_filepath = None
            frame.image = None

        session_data = {
            'frames': self.frames,
            'frames_idxs': self.frames_idxs,
            'frames_done_idxs': self.frames_done_idxs,
            'model_iter' : self.model_iter,
        }
        self.merger_session_filepath.write_bytes( pickle.dumps(session_data) )

        io.log_info ("Session is saved to " + '/'.join (self.merger_session_filepath.parts[-2:]) )

    #override
    def on_tick(self):
        io.process_messages()

        go_prev_frame = False
        go_first_frame = False
        go_prev_frame_overriding_cfg = False
        go_first_frame_overriding_cfg = False

        go_next_frame = self.process_remain_frames
        go_next_frame_overriding_cfg = False
        go_last_frame_overriding_cfg = False

        cur_frame = None
        if len(self.frames_idxs) != 0:
            cur_frame = self.frames[self.frames_idxs[0]]

        if self.is_interactive:
            screen_image = None if self.process_remain_frames else \
                           self.main_screen.get_image()

            self.main_screen.set_waiting_icon( self.process_remain_frames or \
                                               self.is_interactive_quitting )

            if cur_frame is not None and not self.is_interactive_quitting:
                if not self.process_remain_frames:
                    if cur_frame.is_done:
                        if not cur_frame.is_shown:
                            if cur_frame.image is None:
                                image = cv2_imread (cur_frame.output_filepath, verbose=False)
                                image_mask = cv2_imread (cur_frame.output_mask_filepath, verbose=False)
                                if image is None or image_mask is None:
                                    # unable to read? recompute then
                                    cur_frame.is_done = False
                                else:
                                    image_mask = imagelib.normalize_channels(image_mask, 1)
                                    cur_frame.image = np.concatenate([image, image_mask], -1)

                            if cur_frame.is_done:
                                io.log_info (cur_frame.cfg.to_string( cur_frame.frame_info.filepath.name) )
                                cur_frame.is_shown = True
                                screen_image = cur_frame.image
                    else:
                        self.main_screen.set_waiting_icon(True)

            self.main_screen.set_image(screen_image)
            self.screen_manager.show_current()

            key_events = self.screen_manager.get_key_events()
            key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False)

            if key == 9: #tab
                self.screen_manager.switch_screens()
            else:
                if key == 27: #esc
                    self.is_interactive_quitting = True
                elif self.screen_manager.get_current() is self.main_screen:
                    if self.merger_config.type == MergerConfig.TYPE_MASKED and chr_key in self.masked_keys:
                        self.process_remain_frames = False

                        if cur_frame is not None:
                            cfg = cur_frame.cfg
                            prev_cfg = cfg.copy()

                            if cfg.type == MergerConfig.TYPE_MASKED:
                                self.masked_keys_funcs[chr_key](cfg, shift_pressed)

                            if prev_cfg != cfg:
                                io.log_info ( cfg.to_string(cur_frame.frame_info.filepath.name) )
                                cur_frame.is_done = False
                                cur_frame.is_shown = False
                    else:
                        if chr_key == ',' or chr_key == 'm':
                            self.process_remain_frames = False
                            go_prev_frame = True

                            if chr_key == ',':
                                if shift_pressed:
                                    go_first_frame = True

                            elif chr_key == 'm':
                                if not shift_pressed:
                                    go_prev_frame_overriding_cfg = True
                                else:
                                    go_first_frame_overriding_cfg = True

                        elif chr_key == '.' or chr_key == '/':
                            self.process_remain_frames = False
                            go_next_frame = True

                            if chr_key == '.':
                                if shift_pressed:
                                    self.process_remain_frames = not self.process_remain_frames

                            elif chr_key == '/':
                                if not shift_pressed:
                                    go_next_frame_overriding_cfg = True
                                else:
                                    go_last_frame_overriding_cfg = True

                        elif chr_key == '-':
                            self.screen_manager.get_current().diff_scale(-0.1)
                        elif chr_key == '=':
                            self.screen_manager.get_current().diff_scale(0.1)
                        elif chr_key == 'v':
                            self.screen_manager.get_current().toggle_show_checker_board()

        if go_prev_frame:
            if cur_frame is None or cur_frame.is_done:
                if cur_frame is not None:
                    cur_frame.image = None

                while True:
                    if len(self.frames_done_idxs) > 0:
                        prev_frame = self.frames[self.frames_done_idxs.pop()]
                        self.frames_idxs.insert(0, prev_frame.idx)
                        prev_frame.is_shown = False
                        io.progress_bar_inc(-1)

                        if cur_frame is not None and (go_prev_frame_overriding_cfg or go_first_frame_overriding_cfg):
                            if prev_frame.cfg != cur_frame.cfg:
                                prev_frame.cfg = cur_frame.cfg.copy()
                                prev_frame.is_done = False

                        cur_frame = prev_frame

                    if go_first_frame_overriding_cfg or go_first_frame:
                        if len(self.frames_done_idxs) > 0:
                            continue
                    break

        elif go_next_frame:
            if cur_frame is not None and cur_frame.is_done:
                cur_frame.image = None
                cur_frame.is_shown = True
                self.frames_done_idxs.append(cur_frame.idx)
                self.frames_idxs.pop(0)
                io.progress_bar_inc(1)

                f = self.frames

                if len(self.frames_idxs) != 0:
                    next_frame = f[ self.frames_idxs[0] ]
                    next_frame.is_shown = False

                    if go_next_frame_overriding_cfg or go_last_frame_overriding_cfg:
                        if go_next_frame_overriding_cfg:
                            to_frames = next_frame.idx+1
                        else:
                            to_frames = len(f)

                        for i in range( next_frame.idx, to_frames ):
                            f[i].cfg = None

                    for i in range( min(len(self.frames_idxs), self.prefetch_frame_count) ):
                        frame = f[ self.frames_idxs[i] ]
                        if frame.cfg is None:
                            if i == 0:
                                frame.cfg = cur_frame.cfg.copy()
                            else:
                                frame.cfg = f[ self.frames_idxs[i-1] ].cfg.copy()

                            frame.is_done = False #initiate solve again
                            frame.is_shown = False

        if len(self.frames_idxs) == 0:
            self.process_remain_frames = False

        return (self.is_interactive and self.is_interactive_quitting) or \
               (not self.is_interactive and self.process_remain_frames == False)


    #override
    def on_data_return (self, host_dict, pf):
        frame = self.frames[pf.idx]
        frame.is_done = False
        frame.is_processing = False

    #override
    def on_result (self, host_dict, pf_sent, pf_result):
        frame = self.frames[pf_result.idx]
        frame.is_processing = False
        if frame.cfg == pf_result.cfg:
            frame.is_done = True
            frame.image = pf_result.image

    #override
    def get_data(self, host_dict):
        if self.is_interactive and self.is_interactive_quitting:
            return None

        for i in range ( min(len(self.frames_idxs), self.prefetch_frame_count) ):
            frame = self.frames[ self.frames_idxs[i] ]

            if not frame.is_done and not frame.is_processing and frame.cfg is not None:
                frame.is_processing = True
                return InteractiveMergerSubprocessor.ProcessingFrame(idx=frame.idx,
                                                                     cfg=frame.cfg.copy(),
                                                                     prev_temporal_frame_infos=frame.prev_temporal_frame_infos,
                                                                     frame_info=frame.frame_info,
                                                                     next_temporal_frame_infos=frame.next_temporal_frame_infos,
                                                                     output_filepath=frame.output_filepath,
                                                                     output_mask_filepath=frame.output_mask_filepath,
                                                                     need_return_image=True )

        return None

    #override
    def get_result(self):
        return 0
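
# The class above follows the Subprocessor pull model: the host hands out undone
# frames from get_data, workers run Cli.process_data, and on_result keeps the
# image only if the frame's cfg is still the one it was rendered with (stale
# results are dropped). A schematic single-process stand-in of that loop, not
# the real core.joblib API:
#
#   def run_host(frames, process_data):
#       pending = [f for f in frames if not f['done']]     # get_data
#       for pf in pending:
#           sent_cfg = pf['cfg']
#           image = process_data(sent_cfg)                 # Cli.process_data in a worker
#           if pf['cfg'] == sent_cfg:                      # on_result: drop stale results
#               pf['image'], pf['done'] = image, True
#
#   frames = [{'cfg': 1, 'done': False, 'image': None}]
#   run_host(frames, lambda cfg: 'img')
#   assert frames[0]['done'] and frames[0]['image'] == 'img'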
@ -8,7 +8,14 @@ from facelib import FaceType, LandmarksProcessor
from core.interact import interact as io
from core.cv2ex import *

def MergeMaskedFace (predictor_func, predictor_input_shape, cfg, frame_info, img_bgr_uint8, img_bgr, img_face_landmarks):
fanseg_input_size = 256
skinseg_input_size = 256

def MergeMaskedFace (predictor_func, predictor_input_shape,
                     face_enhancer_func,
                     fanseg_full_face_256_extract_func,
                     skinseg_256_extract_func,
                     cfg, frame_info, img_bgr_uint8, img_bgr, img_face_landmarks):
    img_size = img_bgr.shape[1], img_bgr.shape[0]
    img_face_mask_a = LandmarksProcessor.get_image_hull_mask (img_bgr.shape, img_face_landmarks)

@ -53,7 +60,7 @@ def MergeMaskedFace (predictor_func, predictor_input_shape, cfg, frame_info, img
        predictor_masked = False

    if cfg.super_resolution_power != 0:
        prd_face_bgr_enhanced = cfg.superres_func(prd_face_bgr)
        prd_face_bgr_enhanced = face_enhancer_func(prd_face_bgr, is_tanh=True, preserve_size=False)
        mod = cfg.super_resolution_power / 100.0
        prd_face_bgr = cv2.resize(prd_face_bgr, (output_size,output_size))*(1.0-mod) + prd_face_bgr_enhanced*mod
        prd_face_bgr = np.clip(prd_face_bgr, 0, 1)
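
        # The super-resolution path above is a plain linear blend: with
        # mod = power/100, out = raw*(1-mod) + enhanced*mod, so at power 30 the
        # result keeps 70% of the raw prediction. A quick numeric check with
        # illustrative values:
        #
        #   import numpy as np
        #
        #   raw, enhanced = np.full((2, 2), 0.2), np.full((2, 2), 1.0)
        #   mod = 30 / 100.0
        #   out = np.clip(raw * (1.0 - mod) + enhanced * mod, 0, 1)
        #   assert np.allclose(out, 0.44)   # 0.2*0.7 + 1.0*0.3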
@ -66,29 +73,29 @@ def MergeMaskedFace (predictor_func, predictor_input_shape, cfg, frame_info, img

        if cfg.mask_mode == 2: #dst
            prd_face_mask_a_0 = cv2.resize (dst_face_mask_a_0, (output_size,output_size), cv2.INTER_CUBIC)
        elif cfg.mask_mode >= 3 and cfg.mask_mode <= 8:
        elif cfg.mask_mode >= 3 and cfg.mask_mode <= 7:

            if cfg.mask_mode == 3 or cfg.mask_mode == 5 or cfg.mask_mode == 6:
                prd_face_fanseg_bgr = cv2.resize (prd_face_bgr, (cfg.fanseg_input_size,)*2 )
                prd_face_fanseg_mask = cfg.fanseg_extract_func(FaceType.FULL, prd_face_fanseg_bgr)
                prd_face_fanseg_bgr = cv2.resize (prd_face_bgr, (fanseg_input_size,)*2 )
                prd_face_fanseg_mask = fanseg_full_face_256_extract_func(prd_face_fanseg_bgr)
                FAN_prd_face_mask_a_0 = cv2.resize ( prd_face_fanseg_mask, (output_size, output_size), cv2.INTER_CUBIC)

            if cfg.mask_mode >= 4 and cfg.mask_mode <= 7:

                full_face_fanseg_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, cfg.fanseg_input_size, face_type=FaceType.FULL)
                dst_face_fanseg_bgr = cv2.warpAffine(img_bgr, full_face_fanseg_mat, (cfg.fanseg_input_size,)*2, flags=cv2.INTER_CUBIC )
                dst_face_fanseg_mask = cfg.fanseg_extract_func( FaceType.FULL, dst_face_fanseg_bgr )
                full_face_fanseg_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, fanseg_input_size, face_type=FaceType.FULL)
                dst_face_fanseg_bgr = cv2.warpAffine(img_bgr, full_face_fanseg_mat, (fanseg_input_size,)*2, flags=cv2.INTER_CUBIC )
                dst_face_fanseg_mask = fanseg_full_face_256_extract_func(dst_face_fanseg_bgr )

                if cfg.face_type == FaceType.FULL:
                    FAN_dst_face_mask_a_0 = cv2.resize (dst_face_fanseg_mask, (output_size,output_size), cv2.INTER_CUBIC)
                else:
                    face_fanseg_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, cfg.fanseg_input_size, face_type=cfg.face_type)
                    face_fanseg_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, fanseg_input_size, face_type=cfg.face_type)

                    fanseg_rect_corner_pts = np.array ( [ [0,0], [cfg.fanseg_input_size-1,0], [0,cfg.fanseg_input_size-1] ], dtype=np.float32 )
                    fanseg_rect_corner_pts = np.array ( [ [0,0], [fanseg_input_size-1,0], [0,fanseg_input_size-1] ], dtype=np.float32 )
                    a = LandmarksProcessor.transform_points (fanseg_rect_corner_pts, face_fanseg_mat, invert=True )
                    b = LandmarksProcessor.transform_points (a, full_face_fanseg_mat )
                    m = cv2.getAffineTransform(b, fanseg_rect_corner_pts)
                    FAN_dst_face_mask_a_0 = cv2.warpAffine(dst_face_fanseg_mask, m, (cfg.fanseg_input_size,)*2, flags=cv2.INTER_CUBIC )
                    FAN_dst_face_mask_a_0 = cv2.warpAffine(dst_face_fanseg_mask, m, (fanseg_input_size,)*2, flags=cv2.INTER_CUBIC )
                    FAN_dst_face_mask_a_0 = cv2.resize (FAN_dst_face_mask_a_0, (output_size,output_size), cv2.INTER_CUBIC)

            if cfg.mask_mode == 3: #FAN-prd

@ -101,7 +108,28 @@ def MergeMaskedFace (predictor_func, predictor_input_shape, cfg, frame_info, img
                prd_face_mask_a_0 = prd_face_mask_a_0 * FAN_prd_face_mask_a_0 * FAN_dst_face_mask_a_0
            elif cfg.mask_mode == 7:
                prd_face_mask_a_0 = prd_face_mask_a_0 * FAN_dst_face_mask_a_0

        elif cfg.mask_mode >= 8 and cfg.mask_mode <= 11:
            if cfg.mask_mode == 8 or cfg.mask_mode == 10 or cfg.mask_mode == 11:
                prd_face_skinseg_bgr = cv2.resize (prd_face_bgr, (skinseg_input_size,)*2 )
                prd_face_skinseg_mask = skinseg_256_extract_func(prd_face_skinseg_bgr)
                X_prd_face_mask_a_0 = cv2.resize ( prd_face_skinseg_mask, (output_size, output_size), cv2.INTER_CUBIC)

            if cfg.mask_mode >= 9 and cfg.mask_mode <= 11:
                whole_face_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, skinseg_input_size, face_type=FaceType.WHOLE_FACE)
                dst_face_skinseg_bgr = cv2.warpAffine(img_bgr, whole_face_mat, (skinseg_input_size,)*2, flags=cv2.INTER_CUBIC )
                dst_face_skinseg_mask = skinseg_256_extract_func(dst_face_skinseg_bgr)
                X_dst_face_mask_a_0 = cv2.resize (dst_face_skinseg_mask, (output_size,output_size), cv2.INTER_CUBIC)

            if cfg.mask_mode == 8: #'SkinSeg-prd',
                prd_face_mask_a_0 = X_prd_face_mask_a_0
            elif cfg.mask_mode == 9: #'SkinSeg-dst',
                prd_face_mask_a_0 = X_dst_face_mask_a_0
            elif cfg.mask_mode == 10: #'SkinSeg-prd*SkinSeg-dst',
                prd_face_mask_a_0 = X_prd_face_mask_a_0 * X_dst_face_mask_a_0
            elif cfg.mask_mode == 11: #learned*SkinSeg-prd*SkinSeg-dst'
                prd_face_mask_a_0 = prd_face_mask_a_0 * X_prd_face_mask_a_0 * X_dst_face_mask_a_0

        prd_face_mask_a_0[ prd_face_mask_a_0 < (1.0/255.0) ] = 0.0 # get rid of noise
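
        # Modes 10 and 11 combine soft masks by element-wise multiplication (an
        # intersection of the masks), and the final line zeroes anything below
        # 1/255. Illustrative values:
        #
        #   import numpy as np
        #
        #   prd = np.array([0.9, 0.002], dtype=np.float32)
        #   dst = np.array([0.8, 1.0  ], dtype=np.float32)
        #   mask = prd * dst                      # 'SkinSeg-prd*SkinSeg-dst'
        #   mask[mask < (1.0 / 255.0)] = 0.0      # same noise cut as above
        #   assert np.allclose(mask, [0.72, 0.0])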
        # resize to mask_subres_size

@ -280,7 +308,7 @@ def MergeMaskedFace (predictor_func, predictor_input_shape, cfg, frame_info, img
                out_face_bgr = imagelib.LinearMotionBlur (out_face_bgr, k_size , frame_info.motion_deg)

            if cfg.blursharpen_amount != 0:
                out_face_bgr = cfg.blursharpen_func ( out_face_bgr, cfg.sharpen_mode, 3, cfg.blursharpen_amount)
                out_face_bgr = imagelib.blursharpen ( out_face_bgr, cfg.sharpen_mode, 3, cfg.blursharpen_amount)


            if cfg.image_denoise_power != 0:

@ -315,14 +343,20 @@ def MergeMaskedFace (predictor_func, predictor_input_shape, cfg, frame_info, img
    return out_img, out_merging_mask_a


def MergeMasked (predictor_func, predictor_input_shape, cfg, frame_info):
def MergeMasked (predictor_func,
                 predictor_input_shape,
                 face_enhancer_func,
                 fanseg_full_face_256_extract_func,
                 skinseg_256_extract_func,
                 cfg,
                 frame_info):
    img_bgr_uint8 = cv2_imread(frame_info.filepath)
    img_bgr_uint8 = imagelib.normalize_channels (img_bgr_uint8, 3)
    img_bgr = img_bgr_uint8.astype(np.float32) / 255.0

    outs = []
    for face_num, img_landmarks in enumerate( frame_info.landmarks_list ):
        out_img, out_img_merging_mask = MergeMaskedFace (predictor_func, predictor_input_shape, cfg, frame_info, img_bgr_uint8, img_bgr, img_landmarks)
        out_img, out_img_merging_mask = MergeMaskedFace (predictor_func, predictor_input_shape, face_enhancer_func, fanseg_full_face_256_extract_func, skinseg_256_extract_func, cfg, frame_info, img_bgr_uint8, img_bgr, img_landmarks)
        outs += [ (out_img, out_img_merging_mask) ]

    #Combining multiple face outputs

@ -21,11 +21,6 @@ class MergerConfig(object):
                ):
        self.type = type

        self.superres_func = None
        self.blursharpen_func = None
        self.fanseg_input_size = None
        self.fanseg_extract_func = None

        self.sharpen_dict = {0:"None", 1:'box', 2:'gaussian'}

        #default changeable params

@ -88,6 +83,19 @@ mode_str_dict = {}
for key in mode_dict.keys():
    mode_str_dict[ mode_dict[key] ] = key

whole_face_mask_mode_dict = {1:'learned',
                             2:'dst',
                             3:'FAN-prd',
                             4:'FAN-dst',
                             5:'FAN-prd*FAN-dst',
                             6:'learned*FAN-prd*FAN-dst'
                             }
"""
8:'SkinSeg-prd',
9:'SkinSeg-dst',
10:'SkinSeg-prd*SkinSeg-dst',
11:'learned*SkinSeg-prd*SkinSeg-dst'
"""
full_face_mask_mode_dict = {1:'learned',
                            2:'dst',
                            3:'FAN-prd',

@ -164,7 +172,9 @@ class MergerConfigMasked(MergerConfig):
        self.hist_match_threshold = np.clip ( self.hist_match_threshold+diff , 0, 255)

    def toggle_mask_mode(self):
        if self.face_type == FaceType.FULL:
        if self.face_type == FaceType.WHOLE_FACE:
            a = list( whole_face_mask_mode_dict.keys() )
        elif self.face_type == FaceType.FULL:
            a = list( full_face_mask_mode_dict.keys() )
        else:
            a = list( half_face_mask_mode_dict.keys() )

@ -213,7 +223,14 @@ class MergerConfigMasked(MergerConfig):
        if self.mode == 'hist-match' or self.mode == 'seamless-hist-match':
            self.hist_match_threshold = np.clip ( io.input_int("Hist match threshold", 255, add_info="0..255"), 0, 255)

        if self.face_type == FaceType.FULL:
        if self.face_type == FaceType.WHOLE_FACE:
            s = """Choose mask mode: \n"""
            for key in whole_face_mask_mode_dict.keys():
                s += f"""({key}) {whole_face_mask_mode_dict[key]}\n"""
            io.log_info(s)

            self.mask_mode = io.input_int ("", 1, valid_list=whole_face_mask_mode_dict.keys() )
        elif self.face_type == FaceType.FULL:
            s = """Choose mask mode: \n"""
            for key in full_face_mask_mode_dict.keys():
                s += f"""({key}) {full_face_mask_mode_dict[key]}\n"""

@ -282,7 +299,9 @@ class MergerConfigMasked(MergerConfig):
        if self.mode == 'hist-match' or self.mode == 'seamless-hist-match':
            r += f"""hist_match_threshold: {self.hist_match_threshold}\n"""

        if self.face_type == FaceType.FULL:
        if self.face_type == FaceType.WHOLE_FACE:
            r += f"""mask_mode: { whole_face_mask_mode_dict[self.mask_mode] }\n"""
        elif self.face_type == FaceType.FULL:
            r += f"""mask_mode: { full_face_mask_mode_dict[self.mask_mode] }\n"""
        else:
            r += f"""mask_mode: { half_face_mask_mode_dict[self.mask_mode] }\n"""
Before Width: | Height: | Size: 3.2 KiB After Width: | Height: | Size: 3.2 KiB |
@ -2,3 +2,4 @@ from .FrameInfo import FrameInfo
from .MergerConfig import MergerConfig, MergerConfigMasked, MergerConfigFaceAvatar
from .MergeMasked import MergeMasked
from .MergeAvatar import MergeFaceAvatar
from .InteractiveMergerSubprocessor import InteractiveMergerSubprocessor
Before Width: | Height: | Size: 222 KiB After Width: | Height: | Size: 222 KiB |
Before Width: | Height: | Size: 307 KiB After Width: | Height: | Size: 307 KiB |
@ -454,7 +454,7 @@ class ModelBase(object):
            self.generate_next_samples()

    def finalize(self):
        nn.tf_close_session()
        nn.close_session()

    def is_first_run(self):
        return self.iter == 0
@ -18,21 +18,17 @@ class FANSegModel(ModelBase):
        device_config = nn.getCurrentDeviceConfig()
        yn_str = {True:'y',False:'n'}

        #default_resolution = 256
        #default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f')

        ask_override = self.ask_override()
        if self.is_first_run() or ask_override:
            self.ask_autobackup_hour()
            self.ask_target_iter()
            self.ask_batch_size(24)

        #if self.is_first_run():
            #resolution = io.input_int("Resolution", default_resolution, add_info="64-512")
            #resolution = np.clip ( (resolution // 16) * 16, 64, 512)
            #self.options['resolution'] = resolution
            #self.options['face_type'] = io.input_str ("Face type", default_face_type, ['f']).lower()

        default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', False)

        if self.is_first_run() or ask_override:
            self.options['lr_dropout'] = io.input_bool ("Use learning rate dropout", default_lr_dropout, help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations.")

    #override
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()

@ -42,11 +38,7 @@ class FANSegModel(ModelBase):
        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices

        self.resolution = resolution = 256#self.options['resolution']
        #self.face_type = {'h'  : FaceType.HALF,
        #                  'mf' : FaceType.MID_FULL,
        #                  'f'  : FaceType.FULL,
        #                  'wf' : FaceType.WHOLE_FACE}[ self.options['face_type'] ]
        self.resolution = resolution = 256
        self.face_type = FaceType.FULL

        place_model_on_cpu = len(devices) == 0

@ -56,13 +48,13 @@ class FANSegModel(ModelBase):
        mask_shape = nn.get4Dshape(resolution,resolution,1)

        # Initializing model classes
        self.model = TernausNet(f'{self.model_name}_FANSeg',
        self.model = TernausNet(f'{self.model_name}_FANSeg_{FaceType.toString(self.face_type)}',
                                resolution,
                                FaceType.toString(self.face_type),
                                load_weights=not self.is_first_run(),
                                weights_file_root=self.get_model_root_path(),
                                training=True,
                                place_model_on_cpu=place_model_on_cpu)
                                place_model_on_cpu=place_model_on_cpu,
                                optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3 if self.options['lr_dropout'] else 1.0, name='opt') )

        if self.is_training:
            # Adjust batch size for multiple GPU

@ -93,15 +85,15 @@ class FANSegModel(ModelBase):
                    gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3])
                    gpu_losses += [gpu_loss]

                    gpu_loss_gvs += [ nn.tf_gradients ( gpu_loss, self.model.net_weights ) ]
                    gpu_loss_gvs += [ nn.gradients ( gpu_loss, self.model.net_weights ) ]


            # Average losses and gradients, and create optimizer update ops
            with tf.device (models_opt_device):
                pred = nn.tf_concat(gpu_pred_list, 0)
                pred = nn.concat(gpu_pred_list, 0)
                loss = tf.reduce_mean(gpu_losses)

                loss_gv_op = self.model.opt.get_update_op (nn.tf_average_gv_list (gpu_loss_gvs))
                loss_gv_op = self.model.opt.get_update_op (nn.average_gv_list (gpu_loss_gvs))
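
                # The renamed helpers keep the standard data-parallel recipe:
                # per-GPU gradients are averaged and applied in one optimizer step.
                # A framework-neutral sketch of the averaging (illustrative names,
                # not the leras API):
                #
                #   import numpy as np
                #
                #   def average_gv_list(per_gpu_grads):
                #       # one list of gradient arrays per GPU, same order everywhere
                #       return [np.mean(np.stack(gs), axis=0) for gs in zip(*per_gpu_grads)]
                #
                #   gpu0 = [np.array([1.0, 2.0])]
                #   gpu1 = [np.array([3.0, 4.0])]
                #   assert np.allclose(average_gv_list([gpu0, gpu1])[0], [2.0, 3.0])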
            # Initializing training and view functions

@ -171,7 +163,7 @@ class FANSegModel(ModelBase):
        result = []
        st = []
        for i in range(n_samples):
            ar = S[i]*TM[i]+ green_bg*(1-TM[i]), SM[i], S[i]*SM[i] + green_bg*(1-SM[i])
            ar = S[i]*TM[i] + 0.5*S[i]*(1-TM[i]) + 0.5*green_bg*(1-TM[i]), SM[i], S[i]*SM[i] + green_bg*(1-SM[i])
            st.append ( np.concatenate ( ar, axis=1) )
        result += [ ('FANSeg training faces', np.concatenate (st, axis=0 )), ]
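
        # The changed preview line no longer fills the unmasked region with pure
        # green: it mixes 50% source and 50% green there, which keeps face context
        # visible around the target mask. At a fully unmasked pixel (TM = 0) the
        # output is the even mix of source and background:
        #
        #   import numpy as np
        #
        #   S, TM, green = np.array([0.6]), np.array([0.0]), np.array([1.0])
        #   out = S * TM + 0.5 * S * (1 - TM) + 0.5 * green * (1 - TM)
        #   assert np.allclose(out, 0.8)   # (0.6 + 1.0) / 2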
|
||||
|
||||
|
|
|
@@ -14,146 +14,11 @@ class QModel(ModelBase):
#override
def on_initialize(self):
device_config = nn.getCurrentDeviceConfig()
self.model_data_format = "NCHW" if len(device_config.devices) != 0 and not self.is_debug() else "NHWC"
devices = device_config.devices
self.model_data_format = "NCHW" if len(devices) != 0 and not self.is_debug() else "NHWC"
nn.initialize(data_format=self.model_data_format)
tf = nn.tf

conv_kernel_initializer = nn.initializers.ca()

class Downscale(nn.ModelBase):
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
self.in_ch = in_ch
self.out_ch = out_ch
self.kernel_size = kernel_size
self.dilations = dilations
self.subpixel = subpixel
self.use_activator = use_activator
super().__init__(*kwargs)

def on_build(self, *args, **kwargs ):
self.conv1 = nn.Conv2D( self.in_ch,
self.out_ch // (4 if self.subpixel else 1),
kernel_size=self.kernel_size,
strides=1 if self.subpixel else 2,
padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer )

def forward(self, x):
x = self.conv1(x)

if self.subpixel:
x = nn.tf_space_to_depth(x, 2)

if self.use_activator:
x = nn.tf_gelu(x)
return x

def get_out_ch(self):
return (self.out_ch // 4) * 4

class DownscaleBlock(nn.ModelBase):
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
self.downs = []

last_ch = in_ch
for i in range(n_downscales):
cur_ch = ch*( min(2**i, 8) )
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
last_ch = self.downs[-1].get_out_ch()

def forward(self, inp):
x = inp
for down in self.downs:
x = down(x)
return x

class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)

def forward(self, x):
x = self.conv1(x)
x = nn.tf_gelu(x)
x = nn.tf_depth_to_space(x, 2)
return x

class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)

def forward(self, inp):
x = self.conv1(inp)
x = nn.tf_gelu(x)
x = self.conv2(x)
x = inp + x
x = nn.tf_gelu(x)
return x

class Encoder(nn.ModelBase):
def on_build(self, in_ch, e_ch):
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5)
def forward(self, inp):
return nn.tf_flatten(self.down1(inp))

class Inter(nn.ModelBase):
def __init__(self, in_ch, lowest_dense_res, ae_ch, ae_out_ch, d_ch, **kwargs):
self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch, self.d_ch = in_ch, lowest_dense_res, ae_ch, ae_out_ch, d_ch
super().__init__(**kwargs)

def on_build(self):
in_ch, lowest_dense_res, ae_ch, ae_out_ch, d_ch = self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch, self.d_ch

self.dense1 = nn.Dense( in_ch, ae_ch, kernel_initializer=tf.initializers.orthogonal )
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch, maxout_features=4, kernel_initializer=tf.initializers.orthogonal )
self.upscale1 = Upscale(ae_out_ch, d_ch*8)
self.res1 = ResidualBlock(d_ch*8)

def forward(self, inp):
x = self.dense1(inp)
x = self.dense2(x)
x = nn.tf_reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
x = self.upscale1(x)
x = self.res1(x)
return x

def get_out_ch(self):
return self.ae_out_ch

class Decoder(nn.ModelBase):
def on_build(self, in_ch, d_ch):
self.upscale1 = Upscale(in_ch, d_ch*4)
self.res1 = ResidualBlock(d_ch*4)
self.upscale2 = Upscale(d_ch*4, d_ch*2)
self.res2 = ResidualBlock(d_ch*2)
self.upscale3 = Upscale(d_ch*2, d_ch*1)
self.res3 = ResidualBlock(d_ch*1)

self.upscalem1 = Upscale(in_ch, d_ch)
self.upscalem2 = Upscale(d_ch, d_ch//2)
self.upscalem3 = Upscale(d_ch//2, d_ch//2)

self.out_conv = nn.Conv2D( d_ch*1, 3, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
self.out_convm = nn.Conv2D( d_ch//2, 1, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)

def forward(self, inp):
z = inp
x = self.upscale1 (z)
x = self.res1 (x)
x = self.upscale2 (x)
x = self.res2 (x)
x = self.upscale3 (x)
x = self.res3 (x)

y = self.upscalem1 (z)
y = self.upscalem2 (y)
y = self.upscalem3 (y)

return tf.nn.sigmoid(self.out_conv(x)), \
tf.nn.sigmoid(self.out_convm(y))

device_config = nn.getCurrentDeviceConfig()
devices = device_config.devices

resolution = self.resolution = 96
self.face_type = FaceType.FULL
ae_dims = 128

@@ -169,35 +34,34 @@ class QModel(ModelBase):
optimizer_vars_on_cpu = models_opt_device=='/CPU:0'

input_ch = 3
output_ch = 3
bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
mask_shape = nn.get4Dshape(resolution,resolution,1)
lowest_dense_res = resolution // 16

self.model_filename_list = []

model_archi = nn.DeepFakeArchi(resolution, mod='quick')

with tf.device ('/CPU:0'):
#Place holders on CPU
self.warped_src = tf.placeholder (nn.tf_floatx, bgr_shape)
self.warped_dst = tf.placeholder (nn.tf_floatx, bgr_shape)
self.warped_src = tf.placeholder (nn.floatx, bgr_shape)
self.warped_dst = tf.placeholder (nn.floatx, bgr_shape)

self.target_src = tf.placeholder (nn.tf_floatx, bgr_shape)
self.target_dst = tf.placeholder (nn.tf_floatx, bgr_shape)
self.target_src = tf.placeholder (nn.floatx, bgr_shape)
self.target_dst = tf.placeholder (nn.floatx, bgr_shape)

self.target_srcm = tf.placeholder (nn.tf_floatx, mask_shape)
self.target_dstm = tf.placeholder (nn.tf_floatx, mask_shape)
self.target_srcm = tf.placeholder (nn.floatx, mask_shape)
self.target_dstm = tf.placeholder (nn.floatx, mask_shape)

# Initializing model classes
with tf.device (models_opt_device):
self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
encoder_out_ch = self.encoder.compute_output_channels ( (nn.tf_floatx, bgr_shape))
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))

self.inter = Inter (in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims, d_ch=d_dims, name='inter')
inter_out_ch = self.inter.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch)))
self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, d_ch=d_dims, name='inter')
inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))

self.decoder_src = Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_src')
self.decoder_dst = Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_dst')
self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_src')
self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_dst')

self.model_filename_list += [ [self.encoder, 'encoder.npy' ],
[self.inter, 'inter.npy' ],

@@ -208,7 +72,7 @@ class QModel(ModelBase):
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()

# Initialize optimizers
self.src_dst_opt = nn.TFRMSpropOptimizer(lr=2e-4, lr_dropout=0.3, name='src_dst_opt')
self.src_dst_opt = nn.RMSprop(lr=2e-4, lr_dropout=0.3, name='src_dst_opt')
self.src_dst_opt.initialize_variables(self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu )
self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]

@@ -257,8 +121,8 @@ class QModel(ModelBase):
gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

gpu_target_srcm_blur = nn.tf_gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
gpu_target_dstm_blur = nn.tf_gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )
gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )

gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
gpu_target_dst_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_blur)

@@ -272,11 +136,11 @@ class QModel(ModelBase):
gpu_psd_target_dst_masked = gpu_pred_src_dst*gpu_target_dstm_blur
gpu_psd_target_dst_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_blur)

gpu_src_loss = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])
gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )

gpu_dst_loss = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )

@@ -284,21 +148,21 @@ class QModel(ModelBase):
gpu_dst_losses += [gpu_dst_loss]

gpu_G_loss = gpu_src_loss + gpu_dst_loss
gpu_src_dst_loss_gvs += [ nn.tf_gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ]
gpu_src_dst_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ]

# Average losses and gradients, and create optimizer update ops
with tf.device (models_opt_device):
pred_src_src = nn.tf_concat(gpu_pred_src_src_list, 0)
pred_dst_dst = nn.tf_concat(gpu_pred_dst_dst_list, 0)
pred_src_dst = nn.tf_concat(gpu_pred_src_dst_list, 0)
pred_src_srcm = nn.tf_concat(gpu_pred_src_srcm_list, 0)
pred_dst_dstm = nn.tf_concat(gpu_pred_dst_dstm_list, 0)
pred_src_dstm = nn.tf_concat(gpu_pred_src_dstm_list, 0)
pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

src_loss = nn.tf_average_tensor_list(gpu_src_losses)
dst_loss = nn.tf_average_tensor_list(gpu_dst_losses)
src_dst_loss_gv = nn.tf_average_gv_list (gpu_src_dst_loss_gvs)
src_loss = nn.average_tensor_list(gpu_src_losses)
dst_loss = nn.average_tensor_list(gpu_dst_losses)
src_dst_loss_gv = nn.average_gv_list (gpu_src_dst_loss_gvs)
src_dst_loss_gv_op = self.src_dst_opt.get_update_op (src_dst_loss_gv)

# Initializing training and view functions
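The QModel rewrite above folds the locally defined blocks into the shared nn.DeepFakeArchi(resolution, mod='quick') factory and drops the tf_ prefix from the leras helpers. A hypothetical alias table summarizing the renames visible in this diff (illustrative only, not part of the codebase):

# old leras name            -> refactored name, as seen in this commit
LERAS_RENAMES = {
    'nn.tf_floatx':              'nn.floatx',
    'nn.tf_concat':              'nn.concat',
    'nn.tf_gradients':           'nn.gradients',
    'nn.tf_average_gv_list':     'nn.average_gv_list',
    'nn.tf_average_tensor_list': 'nn.average_tensor_list',
    'nn.tf_gaussian_blur':       'nn.gaussian_blur',
    'nn.tf_dssim':               'nn.dssim',
    'nn.tf_style_loss':          'nn.style_loss',
    'nn.TFRMSpropOptimizer':     'nn.RMSprop',
}

def port(source_line):
    # rewrite a line of pre-refactor code to the new helper names
    for old, new in LERAS_RENAMES.items():
        source_line = source_line.replace(old, new)
    return source_line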
@@ -119,13 +119,11 @@ class SAEHDModel(ModelBase):
#override
def on_initialize(self):
device_config = nn.getCurrentDeviceConfig()
self.model_data_format = "NCHW" if len(device_config.devices) != 0 and not self.is_debug() else "NHWC"
devices = device_config.devices
self.model_data_format = "NCHW" if len(devices) != 0 and not self.is_debug() else "NHWC"
nn.initialize(data_format=self.model_data_format)
tf = nn.tf

device_config = nn.getCurrentDeviceConfig()
devices = device_config.devices

self.resolution = resolution = self.options['resolution']
self.face_type = {'h' : FaceType.HALF,
'mf' : FaceType.MID_FULL,

@@ -154,41 +152,38 @@ class SAEHDModel(ModelBase):
models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
optimizer_vars_on_cpu = models_opt_device=='/CPU:0'

input_ch = 3
output_ch = 3
input_ch=3
bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
mask_shape = nn.get4Dshape(resolution,resolution,1)
self.model_filename_list = []

with tf.device ('/CPU:0'):
#Place holders on CPU
self.warped_src = tf.placeholder (nn.tf_floatx, bgr_shape)
self.warped_dst = tf.placeholder (nn.tf_floatx, bgr_shape)
self.warped_src = tf.placeholder (nn.floatx, bgr_shape)
self.warped_dst = tf.placeholder (nn.floatx, bgr_shape)

self.target_src = tf.placeholder (nn.tf_floatx, bgr_shape)
self.target_dst = tf.placeholder (nn.tf_floatx, bgr_shape)
self.target_src = tf.placeholder (nn.floatx, bgr_shape)
self.target_dst = tf.placeholder (nn.floatx, bgr_shape)

self.target_srcm_all = tf.placeholder (nn.tf_floatx, mask_shape)
self.target_dstm_all = tf.placeholder (nn.tf_floatx, mask_shape)
self.target_srcm_all = tf.placeholder (nn.floatx, mask_shape)
self.target_dstm_all = tf.placeholder (nn.floatx, mask_shape)

# Initializing model classes
if archi == 'liaech':
lowest_dense_res, Encoder, Inter, Decoder = nn.get_ae_models_chervonij(resolution)
model_archi = nn.DeepFakeArchi(resolution, mod='chervonij')
else:
lowest_dense_res, Encoder, Inter, Decoder = nn.get_ae_models(resolution)
model_archi = nn.DeepFakeArchi(resolution)

with tf.device (models_opt_device):
if 'df' in archi:
self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder')
encoder_out_ch = self.encoder.compute_output_channels ( (nn.tf_floatx, bgr_shape))
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder')
encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))

self.inter = Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, is_hd=is_hd, name='inter')
inter_out_ch = self.inter.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch)))
self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, is_hd=is_hd, name='inter')
inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))

self.decoder_src = Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_src')
self.decoder_dst = Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_dst')
self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_src')
self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_dst')

self.model_filename_list += [ [self.encoder, 'encoder.npy' ],
[self.inter, 'inter.npy' ],

@@ -197,20 +192,20 @@ class SAEHDModel(ModelBase):

if self.is_training:
if self.options['true_face_power'] != 0:
self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=lowest_dense_res*2, name='dis' )
self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=model_archi.Inter.get_code_res()*2, name='dis' )
self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ]

elif 'liae' in archi:
self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder')
encoder_out_ch = self.encoder.compute_output_channels ( (nn.tf_floatx, bgr_shape))
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder')
encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))

self.inter_AB = Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_AB')
self.inter_B = Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_B')
self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_AB')
self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_B')

inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch)))
inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch)))
inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
inters_out_ch = inter_AB_out_ch+inter_B_out_ch
self.decoder = Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder')
self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder')

self.model_filename_list += [ [self.encoder, 'encoder.npy'],
[self.inter_AB, 'inter_AB.npy'],

@@ -219,8 +214,8 @@ class SAEHDModel(ModelBase):

if self.is_training:
if gan_power != 0:
self.D_src = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=output_ch, name="D_src")
self.D_dst = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=output_ch, name="D_dst")
self.D_src = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_src")
self.D_dst = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_dst")
self.model_filename_list += [ [self.D_src, 'D_src.npy'] ]
self.model_filename_list += [ [self.D_dst, 'D_dst.npy'] ]

@@ -228,7 +223,7 @@ class SAEHDModel(ModelBase):
lr=5e-5
lr_dropout = 0.3 if self.options['lr_dropout'] and not self.pretrain else 1.0
clipnorm = 1.0 if self.options['clipgrad'] else 0.0
self.src_dst_opt = nn.TFRMSpropOptimizer(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
self.src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
if 'df' in archi:
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()

@@ -238,12 +233,12 @@ class SAEHDModel(ModelBase):
self.src_dst_opt.initialize_variables (self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu)

if self.options['true_face_power'] != 0:
self.D_code_opt = nn.TFRMSpropOptimizer(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_code_opt')
self.D_code_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_code_opt')
self.D_code_opt.initialize_variables ( self.code_discriminator.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
self.model_filename_list += [ (self.D_code_opt, 'D_code_opt.npy') ]

if gan_power != 0:
self.D_src_dst_opt = nn.TFRMSpropOptimizer(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_src_dst_opt')
self.D_src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_src_dst_opt')
self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights()+self.D_dst.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
self.model_filename_list += [ (self.D_src_dst_opt, 'D_src_dst_opt.npy') ]

@@ -316,8 +311,8 @@ class SAEHDModel(ModelBase):
gpu_target_srcm_eyes = tf.clip_by_value (gpu_target_srcm_all-1, 0, 1)
gpu_target_dstm_eyes = tf.clip_by_value (gpu_target_dstm_all-1, 0, 1)

gpu_target_srcm_blur = nn.tf_gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
gpu_target_dstm_blur = nn.tf_gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )
gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )

gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
gpu_target_dst_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_blur)

@@ -331,7 +326,7 @@ class SAEHDModel(ModelBase):
gpu_psd_target_dst_masked = gpu_pred_src_dst*gpu_target_dstm_blur
gpu_psd_target_dst_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_blur)

gpu_src_loss = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])

if eyes_prio:

@@ -341,14 +336,14 @@ class SAEHDModel(ModelBase):

face_style_power = self.options['face_style_power'] / 100.0
if face_style_power != 0 and not self.pretrain:
gpu_src_loss += nn.tf_style_loss(gpu_psd_target_dst_masked, gpu_target_dst_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power)
gpu_src_loss += nn.style_loss(gpu_psd_target_dst_masked, gpu_target_dst_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power)

bg_style_power = self.options['bg_style_power'] / 100.0
if bg_style_power != 0 and not self.pretrain:
gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.tf_dssim(gpu_psd_target_dst_anti_masked, gpu_target_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim(gpu_psd_target_dst_anti_masked, gpu_target_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square( gpu_psd_target_dst_anti_masked - gpu_target_dst_anti_masked), axis=[1,2,3] )

gpu_dst_loss = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])

if eyes_prio:

@@ -376,7 +371,7 @@ class SAEHDModel(ModelBase):
gpu_D_code_loss = (DLoss(gpu_src_code_d_ones , gpu_dst_code_d) + \
DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5

gpu_D_code_loss_gvs += [ nn.tf_gradients (gpu_D_code_loss, self.code_discriminator.get_weights() ) ]
gpu_D_code_loss_gvs += [ nn.gradients (gpu_D_code_loss, self.code_discriminator.get_weights() ) ]

if gan_power != 0:
gpu_pred_src_src_d = self.D_src(gpu_pred_src_src_masked_opt)

@@ -395,32 +390,32 @@ class SAEHDModel(ModelBase):
(DLoss(gpu_target_dst_d_ones , gpu_target_dst_d) + \
DLoss(gpu_pred_dst_dst_d_zeros, gpu_pred_dst_dst_d) ) * 0.5

gpu_D_src_dst_loss_gvs += [ nn.tf_gradients (gpu_D_src_dst_loss, self.D_src.get_weights()+self.D_dst.get_weights() ) ]
gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.D_src.get_weights()+self.D_dst.get_weights() ) ]

gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + DLoss(gpu_pred_dst_dst_d_ones, gpu_pred_dst_dst_d))

gpu_G_loss_gvs += [ nn.tf_gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ]
gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ]

# Average losses and gradients, and create optimizer update ops
with tf.device (models_opt_device):
pred_src_src = nn.tf_concat(gpu_pred_src_src_list, 0)
pred_dst_dst = nn.tf_concat(gpu_pred_dst_dst_list, 0)
pred_src_dst = nn.tf_concat(gpu_pred_src_dst_list, 0)
pred_src_srcm = nn.tf_concat(gpu_pred_src_srcm_list, 0)
pred_dst_dstm = nn.tf_concat(gpu_pred_dst_dstm_list, 0)
pred_src_dstm = nn.tf_concat(gpu_pred_src_dstm_list, 0)
pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

src_loss = tf.concat(gpu_src_losses, 0)
dst_loss = tf.concat(gpu_dst_losses, 0)
src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.tf_average_gv_list (gpu_G_loss_gvs))
src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs))

if self.options['true_face_power'] != 0:
D_loss_gv_op = self.D_code_opt.get_update_op (nn.tf_average_gv_list(gpu_D_code_loss_gvs))
D_loss_gv_op = self.D_code_opt.get_update_op (nn.average_gv_list(gpu_D_code_loss_gvs))

if gan_power != 0:
src_D_src_dst_loss_gv_op = self.D_src_dst_opt.get_update_op (nn.tf_average_gv_list(gpu_D_src_dst_loss_gvs) )
src_D_src_dst_loss_gv_op = self.D_src_dst_opt.get_update_op (nn.average_gv_list(gpu_D_src_dst_loss_gvs) )

# Initializing training and view functions
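The SAEHD loss hunks above keep the same structure after the rename: DSSIM plus L2, both restricted to a gaussian-blurred face mask. A rough NumPy sketch of the masked L2 term only (dssim omitted; names are assumptions, not repo API):

import numpy as np

def masked_l2_loss(target, pred, mask_blur):
    # target, pred: NxHxWxC float32; mask_blur: NxHxWx1 blurred face mask.
    # Only pixels inside the (soft) mask contribute, mirroring
    # gpu_target_src_masked_opt vs gpu_pred_src_src_masked_opt above.
    diff = target * mask_blur - pred * mask_blur
    return 10.0 * np.mean(np.square(diff), axis=(1, 2, 3))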
@@ -7,30 +7,29 @@ import numpy as np
from core import mathlib
from core.interact import interact as io
from core.leras import nn
from facelib import FaceType, TernausNet
from facelib import FaceType, TernausNet, DFLSegNet
from models import ModelBase
from samplelib import *

class XSegModel(ModelBase):
class SkinSegModel(ModelBase):

#override
def on_initialize_options(self):
device_config = nn.getCurrentDeviceConfig()
yn_str = {True:'y',False:'n'}

#default_resolution = 256

ask_override = self.ask_override()
if self.is_first_run() or ask_override:
self.ask_autobackup_hour()
self.ask_write_preview_history()
self.ask_target_iter()
self.ask_batch_size(24)

#if self.is_first_run():
#resolution = io.input_int("Resolution", default_resolution, add_info="64-512")
#resolution = np.clip ( (resolution // 16) * 16, 64, 512)
#self.options['resolution'] = resolution

self.ask_batch_size(8)

default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', False)

if self.is_first_run() or ask_override:
self.options['lr_dropout'] = io.input_bool ("Use learning rate dropout", default_lr_dropout, help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations.")

#override
def on_initialize(self):
device_config = nn.getCurrentDeviceConfig()

@@ -41,22 +40,24 @@ class XSegModel(ModelBase):
device_config = nn.getCurrentDeviceConfig()
devices = device_config.devices

self.resolution = resolution = 256#self.options['resolution']

place_model_on_cpu = True#len(devices) == 0
self.resolution = resolution = 256
self.face_type = FaceType.WHOLE_FACE

place_model_on_cpu = True #len(devices) == 0
models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0'

bgr_shape = nn.get4Dshape(resolution,resolution,3)
mask_shape = nn.get4Dshape(resolution,resolution,1)

# Initializing model classes
self.model = TernausNet(f'{self.model_name}_SkinSeg',
resolution,
load_weights=not self.is_first_run(),
weights_file_root=self.get_model_root_path(),
training=True,
place_model_on_cpu=place_model_on_cpu,
data_format=nn.data_format)
self.model = DFLSegNet(name=f'{self.model_name}_SkinSeg',
resolution=resolution,
load_weights=not self.is_first_run(),
weights_file_root=self.get_model_root_path(),
training=True,
place_model_on_cpu=place_model_on_cpu,
optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3 if self.options['lr_dropout'] else 1.0, name='opt'),
data_format=nn.data_format)

if self.is_training:
# Adjust batch size for multiple GPU
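As with FANSeg above, the segmentation wrapper now receives its optimizer at construction time instead of building one internally. A hedged sketch of how the lr_dropout option feeds that optimizer (the helper below is illustrative, not repo API):

def build_opt(lr_dropout_enabled, nn):
    # lr_dropout only toggles the dropout factor on the update mask:
    # 0.3 keeps ~30% of per-weight updates each step, 1.0 disables the effect
    factor = 0.3 if lr_dropout_enabled else 1.0
    return nn.RMSprop(lr=0.0001, lr_dropout=factor, name='opt')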
@@ -81,21 +82,21 @@ class XSegModel(ModelBase):
gpu_target_t = self.model.target_t [batch_slice,:,:,:]

# process model tensors
gpu_pred_logits_t, gpu_pred_t = self.model.net([gpu_input_t])
gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t)
gpu_pred_list.append(gpu_pred_t)

gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3])
gpu_losses += [gpu_loss]

gpu_loss_gvs += [ nn.tf_gradients ( gpu_loss, self.model.net_weights ) ]
gpu_loss_gvs += [ nn.gradients ( gpu_loss, self.model.get_weights() ) ]

# Average losses and gradients, and create optimizer update ops
with tf.device (models_opt_device):
pred = nn.tf_concat(gpu_pred_list, 0)
pred = nn.concat(gpu_pred_list, 0)
loss = tf.reduce_mean(gpu_losses)

loss_gv_op = self.model.opt.get_update_op (nn.tf_average_gv_list (gpu_loss_gvs))
loss_gv_op = self.model.opt.get_update_op (nn.average_gv_list (gpu_loss_gvs))

# Initializing training and view functions

@@ -114,14 +115,30 @@ class XSegModel(ModelBase):
dst_generators_count = cpu_count // 2
src_generators_count = int(src_generators_count * 1.5)

src_generator = SampleGeneratorFaceCelebAMaskHQ ( root_path=self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), resolution=256, generators_count=src_generators_count, data_format = nn.data_format)

dst_generator = SampleGeneratorImage(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
"""
src_generator = SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=True),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.IMAGE, 'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'data_format':nn.data_format, 'resolution': resolution} ],
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR_RANDOM_HSV_SHIFT, 'border_replicate':False, 'face_type':self.face_type, 'motion_blur':(25, 5), 'gaussian_blur':(25,5), 'random_bilinear_resize':(25,75), 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.NONE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
],
generators_count=src_generators_count )
"""
src_generator = SampleGeneratorFaceSkinSegDataset(self.training_data_src_path,
debug=self.is_debug(),
batch_size=self.get_batch_size(),
resolution=resolution,
face_type=self.face_type,
generators_count=src_generators_count,
data_format=nn.data_format)

dst_generator = SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=True),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'motion_blur':(25, 5), 'gaussian_blur':(25,5), 'random_bilinear_resize':(25,75), 'data_format':nn.data_format, 'resolution': resolution},
],
generators_count=dst_generators_count,
raise_on_no_data=False )

if not dst_generator.is_initialized():
io.log_info(f"\nTo view the model on unseen faces, place any image faces in {self.training_data_dst_path}.\n")

@@ -157,9 +174,9 @@ class XSegModel(ModelBase):
result = []
st = []
for i in range(n_samples):
ar = I[i]*M[i]+ green_bg*(1-M[i]), IM[i], I[i]*IM[i] + green_bg*(1-IM[i])
ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i] + green_bg*(1-IM[i])
st.append ( np.concatenate ( ar, axis=1) )
result += [ ('XSeg training faces', np.concatenate (st, axis=0 )), ]
result += [ ('SkinSeg training faces', np.concatenate (st, axis=0 )), ]

if len(dst_samples) != 0:
dst_np, = dst_samples

@@ -173,8 +190,8 @@ class XSegModel(ModelBase):
ar = D[i], DM[i], D[i]*DM[i]+ green_bg*(1-DM[i])
st.append ( np.concatenate ( ar, axis=1) )

result += [ ('XSeg unseen faces', np.concatenate (st, axis=0 )), ]
result += [ ('SkinSeg unseen faces', np.concatenate (st, axis=0 )), ]

return result

Model = XSegModel
Model = SkinSegModel
@@ -82,6 +82,9 @@ class SampleGeneratorFace(SampleGeneratorBase):
return self

def __next__(self):
if not self.initialized:
return []

self.generator_counter += 1
generator = self.generators[self.generator_counter % len(self.generators) ]
return next(generator)
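The __next__ change above makes an uninitialized generator pool return an empty batch instead of raising. A standalone sketch of the same round-robin dispatch (names are assumptions):

class RoundRobinGenerators:
    # minimal stand-in for the sample-generator dispatch pattern
    def __init__(self, generators, initialized=True):
        self.generators = generators
        self.initialized = initialized
        self.counter = -1

    def __iter__(self):
        return self

    def __next__(self):
        if not self.initialized:
            return []  # new behavior: degrade gracefully instead of raising
        self.counter += 1
        return next(self.generators[self.counter % len(self.generators)])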
@@ -152,15 +152,15 @@ class SampleGeneratorFaceCelebAMaskHQ(SampleGeneratorBase):
file_ids = list(mask_file_id_hash.keys())

shuffle_file_ids = []
shuffle_file_ids_random_ct = []

resolution = 256
random_flip = True
rotation_range=[-15,15]
scale_range=[-0.25, 0.75]
scale_range=[-0.10, 0.95]
tx_range=[-0.3, 0.3]
ty_range=[-0.3, 0.3]

random_bilinear_resize = (25,75)
motion_blur = (25, 5)
gaussian_blur = (25, 5)

@@ -174,22 +174,15 @@ class SampleGeneratorFaceCelebAMaskHQ(SampleGeneratorBase):
if len(shuffle_file_ids) == 0:
shuffle_file_ids = file_ids.copy()
np.random.shuffle(shuffle_file_ids)
if len(shuffle_file_ids_random_ct) == 0:
shuffle_file_ids_random_ct = file_ids.copy()
np.random.shuffle(shuffle_file_ids_random_ct)

file_id = shuffle_file_ids.pop()
#file_id_random_ct = shuffle_file_ids_random_ct.pop()

masks = mask_file_id_hash[file_id]

image_path = images_path / f'{file_id}.jpg'
#image_random_ct_path = images_path / f'{file_id_random_ct}.jpg'

skin_path = masks.get(MaskType.skin, None)
hair_path = masks.get(MaskType.hair, None)
hat_path = masks.get(MaskType.hat, None)
neck_path = masks.get(MaskType.neck, None)
#neck_path = masks.get(MaskType.neck, None)

img = cv2_imread(image_path).astype(np.float32) / 255.0
mask = cv2_imread(masks_path / skin_path)[...,0:1].astype(np.float32) / 255.0

@@ -205,7 +198,13 @@ class SampleGeneratorFaceCelebAMaskHQ(SampleGeneratorBase):
if hat_path.exists():
hat = cv2_imread(hat_path)[...,0:1].astype(np.float32) / 255.0
mask *= (1-hat)

#if neck_path is not None:
# neck_path = masks_path / neck_path
# if neck_path.exists():
# neck = cv2_imread(neck_path)[...,0:1].astype(np.float32) / 255.0
# mask = np.clip(mask+neck, 0, 1)

warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range )

img = cv2.resize( img, (resolution,resolution), cv2.INTER_LANCZOS4 )

@@ -215,9 +214,6 @@ class SampleGeneratorFaceCelebAMaskHQ(SampleGeneratorBase):
v = np.clip ( v + np.random.random()/2-0.25, 0, 1 )
img = np.clip( cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR) , 0, 1 )

#img_random_ct = cv2.resize( cv2_imread(image_random_ct_path).astype(np.float32) / 255.0, (resolution,resolution), cv2.INTER_LANCZOS4 )
#img = imagelib.color_transfer ('idt', img, img_random_ct )

if motion_blur is not None:
chance, mb_max_size = motion_blur
chance = np.clip(chance, 0, 100)

@@ -241,6 +237,15 @@ class SampleGeneratorFaceCelebAMaskHQ(SampleGeneratorBase):
if gblur_rnd_chance < chance:
img = cv2.GaussianBlur(img, (gblur_rnd_kernel,) *2 , 0)

if random_bilinear_resize is not None:
chance, max_size_per = random_bilinear_resize
chance = np.clip(chance, 0, 100)
pick_chance = np.random.randint(100)
resize_to = resolution - int( np.random.rand()* int(resolution*(max_size_per/100.0)) )
img = cv2.resize (img, (resize_to,resize_to), cv2.INTER_LINEAR )
img = cv2.resize (img, (resolution,resolution), cv2.INTER_LINEAR )

mask = cv2.resize( mask, (resolution,resolution), cv2.INTER_LANCZOS4 )[...,None]
mask = imagelib.warp_by_params (warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)
mask[mask < 0.5] = 0.0
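The new random_bilinear_resize augmentation above degrades the image with a bilinear down/up round trip. A self-contained sketch (here the chance draw guards the resize; in the hunk above pick_chance is drawn but the guard is not shown in the visible context):

import cv2
import numpy as np

def random_bilinear_resize(img, chance=25, max_size_per=75):
    resolution = img.shape[1]
    if np.random.randint(100) < np.clip(chance, 0, 100):
        # shrink by up to max_size_per percent, then restore the size
        resize_to = resolution - int(np.random.rand() * int(resolution * (max_size_per / 100.0)))
        img = cv2.resize(img, (resize_to, resize_to), interpolation=cv2.INTER_LINEAR)
        img = cv2.resize(img, (resolution, resolution), interpolation=cv2.INTER_LINEAR)
    return img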
samplelib/SampleGeneratorFaceSkinSegDataset.py (new file, 263 lines)
@@ -0,0 +1,263 @@
import multiprocessing
import pickle
import time
import traceback
from enum import IntEnum

import cv2
import numpy as np

from core import imagelib, mplib, pathex
from core.cv2ex import *
from core.interact import interact as io
from core.joblib import SubprocessGenerator, ThisThreadGenerator
from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleType)

class MaskType(IntEnum):
none = 0,
cloth = 1,
ear_r = 2,
eye_g = 3,
hair = 4,
hat = 5,
l_brow = 6,
l_ear = 7,
l_eye = 8,
l_lip = 9,
mouth = 10,
neck = 11,
neck_l = 12,
nose = 13,
r_brow = 14,
r_ear = 15,
r_eye = 16,
skin = 17,
u_lip = 18

MaskType_to_name = {
int(MaskType.none ) : 'none',
int(MaskType.cloth ) : 'cloth',
int(MaskType.ear_r ) : 'ear_r',
int(MaskType.eye_g ) : 'eye_g',
int(MaskType.hair ) : 'hair',
int(MaskType.hat ) : 'hat',
int(MaskType.l_brow) : 'l_brow',
int(MaskType.l_ear ) : 'l_ear',
int(MaskType.l_eye ) : 'l_eye',
int(MaskType.l_lip ) : 'l_lip',
int(MaskType.mouth ) : 'mouth',
int(MaskType.neck ) : 'neck',
int(MaskType.neck_l) : 'neck_l',
int(MaskType.nose ) : 'nose',
int(MaskType.r_brow) : 'r_brow',
int(MaskType.r_ear ) : 'r_ear',
int(MaskType.r_eye ) : 'r_eye',
int(MaskType.skin ) : 'skin',
int(MaskType.u_lip ) : 'u_lip',
}

MaskType_from_name = { MaskType_to_name[k] : k for k in MaskType_to_name.keys() }

class SampleGeneratorFaceSkinSegDataset(SampleGeneratorBase):
def __init__ (self, root_path, debug=False, batch_size=1, resolution=256, face_type=None,
generators_count=4, data_format="NHWC",
**kwargs):

super().__init__(debug, batch_size)
self.initialized = False

dataset_path = root_path / 'XSegDataset'
if not dataset_path.exists():
raise ValueError(f'Unable to find {dataset_path}')

aligned_path = dataset_path /'aligned'
if not aligned_path.exists():
raise ValueError(f'Unable to find {aligned_path}')

obstructions_path = dataset_path / 'obstructions'

obstructions_images_paths = pathex.get_image_paths(obstructions_path, image_extensions=['.png'], subdirs=True)

samples = SampleLoader.load (SampleType.FACE, aligned_path, subdirs=True)
self.samples_len = len(samples)

pickled_samples = pickle.dumps(samples, 4)

if self.debug:
self.generators_count = 1
else:
self.generators_count = max(1, generators_count)

if self.debug:
self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, obstructions_images_paths, resolution, face_type, data_format) )]
else:
self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, obstructions_images_paths, resolution, face_type, data_format), start_now=False ) \
for i in range(self.generators_count) ]

SubprocessGenerator.start_in_parallel( self.generators )

self.generator_counter = -1

self.initialized = True

#overridable
def is_initialized(self):
return self.initialized

def __iter__(self):
return self

def __next__(self):
self.generator_counter += 1
generator = self.generators[self.generator_counter % len(self.generators) ]
return next(generator)

def batch_func(self, param ):
pickled_samples, obstructions_images_paths, resolution, face_type, data_format = param

samples = pickle.loads(pickled_samples)

shuffle_o_idxs = []
o_idxs = [*range(len(obstructions_images_paths))]

shuffle_idxs = []
idxs = [*range(len(samples))]

random_flip = True
rotation_range=[-10,10]
scale_range=[-0.05, 0.05]
tx_range=[-0.05, 0.05]
ty_range=[-0.05, 0.05]

o_random_flip = True
o_rotation_range=[-180,180]
o_scale_range=[-0.5, 0.05]
o_tx_range=[-0.5, 0.5]
o_ty_range=[-0.5, 0.5]

random_bilinear_resize_chance, random_bilinear_resize_max_size_per = 25,75
motion_blur_chance, motion_blur_mb_max_size = 25, 5
gaussian_blur_chance, gaussian_blur_kernel_max_size = 25, 5

bs = self.batch_size
while True:
batches = [ [], [] ]

n_batch = 0
while n_batch < bs:
try:
if len(shuffle_idxs) == 0:
shuffle_idxs = idxs.copy()
np.random.shuffle(shuffle_idxs)

idx = shuffle_idxs.pop()

sample = samples[idx]

img = sample.load_bgr()
h,w,c = img.shape

mask = np.zeros ((h,w,1), dtype=np.float32)
sample.ie_polys.overlay_mask(mask)

warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range )

if face_type == sample.face_type:
if w != resolution:
img = cv2.resize( img, (resolution, resolution), cv2.INTER_LANCZOS4 )
mask = cv2.resize( mask, (resolution, resolution), cv2.INTER_LANCZOS4 )
else:
mat = LandmarksProcessor.get_transform_mat (sample.landmarks, resolution, face_type)
img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 )
mask = cv2.warpAffine( mask, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 )

if len(mask.shape) == 2:
mask = mask[...,None]

# apply obstruction
if len(shuffle_o_idxs) == 0:
shuffle_o_idxs = o_idxs.copy()
np.random.shuffle(shuffle_o_idxs)
o_idx = shuffle_o_idxs.pop()
o_img = cv2_imread (obstructions_images_paths[o_idx]).astype(np.float32) / 255.0
oh,ow,oc = o_img.shape
if oc == 4:
ohw = max(oh,ow)
scale = resolution / ohw

#o_img = cv2.resize (o_img, ( int(ow*rate), int(oh*rate), ), cv2.INTER_CUBIC)

mat = cv2.getRotationMatrix2D( (ow/2,oh/2),
np.random.uniform( o_rotation_range[0], o_rotation_range[1] ),
1.0 )

mat += np.float32( [[0,0, -ow/2 ],
[0,0, -oh/2 ]])
mat *= scale * np.random.uniform(1 +o_scale_range[0], 1 +o_scale_range[1])
mat += np.float32( [[0, 0, resolution/2 + resolution*np.random.uniform( o_tx_range[0], o_tx_range[1] ) ],
[0, 0, resolution/2 + resolution*np.random.uniform( o_ty_range[0], o_ty_range[1] ) ] ])

o_img = cv2.warpAffine( o_img, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 )

if o_random_flip and np.random.randint(10) < 4:
o_img = o_img[:,::-1,...]

o_mask = o_img[...,3:4]
o_mask[o_mask>0] = 1.0

o_mask = cv2.erode (o_mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(5,5)), iterations = 1 )
o_mask = cv2.GaussianBlur(o_mask, (5, 5) , 0)[...,None]

img = img*(1-o_mask) + o_img[...,0:3]*o_mask

o_mask[o_mask<0.5] = 0.0

#import code
#code.interact(local=dict(globals(), **locals()))
mask *= (1-o_mask)

#cv2.imshow ("", np.clip(o_img*255, 0,255).astype(np.uint8) )
#cv2.waitKey(0)

img = imagelib.warp_by_params (warp_params, img, can_warp=True, can_transform=True, can_flip=True, border_replicate=False)
mask = imagelib.warp_by_params (warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False)

img = np.clip(img.astype(np.float32), 0, 1)
mask[mask < 0.5] = 0.0
mask[mask >= 0.5] = 1.0
mask = np.clip(mask, 0, 1)

img = imagelib.apply_random_hsv_shift(img)

#todo random mask for blur

img = imagelib.apply_random_motion_blur( img, motion_blur_chance, motion_blur_mb_max_size )
img = imagelib.apply_random_gaussian_blur( img, gaussian_blur_chance, gaussian_blur_kernel_max_size )
img = imagelib.apply_random_bilinear_resize( img, random_bilinear_resize_chance, random_bilinear_resize_max_size_per )

if data_format == "NCHW":
img = np.transpose(img, (2,0,1) )
mask = np.transpose(mask, (2,0,1) )

batches[0].append ( img )
batches[1].append ( mask )

n_batch += 1
except:
io.log_err ( traceback.format_exc() )

yield [ np.array(batch) for batch in batches]
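The heart of the new generator is compositing a random RGBA obstruction over the face and subtracting its silhouette from the skin mask. A condensed NumPy/OpenCV sketch of that step (variable names are assumptions):

import cv2
import numpy as np

def apply_obstruction(img, mask, o_rgba):
    # img: HxWx3, mask: HxWx1, o_rgba: HxWx4, all float32 in [0,1]
    o_mask = (o_rgba[..., 3:4] > 0).astype(np.float32)
    # soften the silhouette edge before compositing, as the generator does
    o_mask = cv2.erode(o_mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)), iterations=1)
    o_mask = cv2.GaussianBlur(o_mask, (5, 5), 0)[..., None]
    img = img * (1 - o_mask) + o_rgba[..., 0:3] * o_mask
    o_mask = (o_mask >= 0.5).astype(np.float32)
    mask = mask * (1 - o_mask)  # the obstruction is not skin
    return img, mask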
@@ -32,7 +32,7 @@ class SampleLoader:
return len(list(persons_name_idxs.keys()))

@staticmethod
def load(sample_type, samples_path):
def load(sample_type, samples_path, subdirs=False):
samples_cache = SampleLoader.samples_cache

if str(samples_path) not in samples_cache.keys():

@@ -42,7 +42,7 @@ class SampleLoader:

if sample_type == SampleType.IMAGE:
if samples[sample_type] is None:
samples[sample_type] = [ Sample(filename=filename) for filename in io.progress_bar_generator( pathex.get_image_paths(samples_path), "Loading") ]
samples[sample_type] = [ Sample(filename=filename) for filename in io.progress_bar_generator( pathex.get_image_paths(samples_path, subdirs=subdirs), "Loading") ]

elif sample_type == SampleType.FACE:
if samples[sample_type] is None:

@@ -55,7 +55,7 @@ class SampleLoader:
io.log_info (f"Loaded {len(result)} packed faces from {samples_path}")

if result is None:
result = SampleLoader.load_face_samples( pathex.get_image_paths(samples_path) )
result = SampleLoader.load_face_samples( pathex.get_image_paths(samples_path, subdirs=subdirs) )
samples[sample_type] = result

elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
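The new subdirs flag simply recurses when collecting image paths. A minimal stand-in for pathex.get_image_paths illustrating the assumed semantics:

from pathlib import Path

def get_image_paths(samples_path, image_extensions=('.jpg', '.jpeg', '.png'), subdirs=False):
    # '**/*' walks the whole tree; '*' stays at the top level
    pattern = '**/*' if subdirs else '*'
    return sorted(str(p) for p in Path(samples_path).glob(pattern)
                  if p.suffix.lower() in image_extensions)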
@@ -73,12 +73,19 @@ class SampleProcessor(object):
if debug and is_face_sample:
LandmarksProcessor.draw_landmarks (sample_bgr, sample_landmarks, (0, 1, 0))

if sample_face_type == FaceType.MARK_ONLY:
warp_resolution = np.max( [ opts.get('resolution', 0) for opts in output_sample_types ] )
else:
warp_resolution = w

params = imagelib.gen_warp_params(warp_resolution, sample_process_options.random_flip, rotation_range=sample_process_options.rotation_range, scale_range=sample_process_options.scale_range, tx_range=sample_process_options.tx_range, ty_range=sample_process_options.ty_range )
params_per_resolution = {}
warp_rnd_state = np.random.RandomState (sample_rnd_seed-1)
for opts in output_sample_types:
resolution = opts.get('resolution', None)
if resolution is None:
continue
params_per_resolution[resolution] = imagelib.gen_warp_params(resolution,
sample_process_options.random_flip,
rotation_range=sample_process_options.rotation_range,
scale_range=sample_process_options.scale_range,
tx_range=sample_process_options.tx_range,
ty_range=sample_process_options.ty_range,
rnd_state=warp_rnd_state)

outputs_sample = []
for opts in output_sample_types:
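Instead of one warp-parameter set at a single working resolution, the processor now pre-generates a set per requested output resolution, all drawn from one RandomState seeded per sample. A sketch under the assumption that gen_warp_params accepts an rnd_state argument:

import numpy as np

def build_params_per_resolution(output_sample_types, sample_rnd_seed, gen_warp_params, options):
    params_per_resolution = {}
    warp_rnd_state = np.random.RandomState(sample_rnd_seed - 1)  # one stream per sample
    for opts in output_sample_types:
        resolution = opts.get('resolution', None)
        if resolution is None:
            continue
        params_per_resolution[resolution] = gen_warp_params(
            resolution, options.random_flip,
            rotation_range=options.rotation_range, scale_range=options.scale_range,
            tx_range=options.tx_range, ty_range=options.ty_range,
            rnd_state=warp_rnd_state)
    return params_per_resolution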
@@ -89,6 +96,7 @@ class SampleProcessor(object):
transform = opts.get('transform', False)
motion_blur = opts.get('motion_blur', None)
gaussian_blur = opts.get('gaussian_blur', None)
random_bilinear_resize = opts.get('random_bilinear_resize', None)
normalize_tanh = opts.get('normalize_tanh', False)
ct_mode = opts.get('ct_mode', None)
data_format = opts.get('data_format', 'NHWC')

@@ -137,31 +145,38 @@ class SampleProcessor(object):
mat = LandmarksProcessor.get_transform_mat (sample_landmarks, warp_resolution, face_type)
img = cv2.warpAffine( img, mat, (warp_resolution, warp_resolution), flags=cv2.INTER_LINEAR )

img = imagelib.warp_by_params (params, img, warp, transform, can_flip=True, border_replicate=border_replicate, cv2_inter=cv2.INTER_LINEAR)
img = cv2.resize( img, (resolution,resolution), cv2.INTER_LINEAR )[...,None]
img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate, cv2_inter=cv2.INTER_LINEAR)
img = cv2.resize( img, (resolution,resolution), cv2.INTER_LINEAR )
else:
img = imagelib.warp_by_params (params, img, warp, transform, can_flip=True, border_replicate=border_replicate, cv2_inter=cv2.INTER_LINEAR)
mat = LandmarksProcessor.get_transform_mat (sample_landmarks, resolution, face_type)
img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=borderMode, flags=cv2.INTER_LINEAR )[...,None]

if face_type != sample_face_type:
mat = LandmarksProcessor.get_transform_mat (sample_landmarks, resolution, face_type)
img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=borderMode, flags=cv2.INTER_LINEAR )
else:
if w != resolution:
img = cv2.resize( img, (resolution, resolution), cv2.INTER_CUBIC )

img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate, cv2_inter=cv2.INTER_LINEAR)

if len(img.shape) == 2:
img = img[...,None]

if channel_type == SPCT.G:
out_sample = img.astype(np.float32)
else:
raise ValueError("only channel_type.G supported for the mask")

elif sample_type == SPST.FACE_IMAGE:
img = sample_bgr
img = sample_bgr

if sample_face_type == FaceType.MARK_ONLY:
mat = LandmarksProcessor.get_transform_mat (sample_landmarks, warp_resolution, face_type)
img = cv2.warpAffine( img, mat, (warp_resolution,warp_resolution), flags=cv2.INTER_CUBIC )
img = imagelib.warp_by_params (params, img, warp, transform, can_flip=True, border_replicate=border_replicate)
img = cv2.resize( img, (resolution,resolution), cv2.INTER_CUBIC )
else:
img = imagelib.warp_by_params (params, img, warp, transform, can_flip=True, border_replicate=border_replicate)

if face_type != sample_face_type:
mat = LandmarksProcessor.get_transform_mat (sample_landmarks, resolution, face_type)
img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=borderMode, flags=cv2.INTER_CUBIC )

img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=borderMode, flags=cv2.INTER_CUBIC )
else:
if w != resolution:
img = cv2.resize( img, (resolution, resolution), cv2.INTER_CUBIC )
img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate)

img = np.clip(img.astype(np.float32), 0, 1)

@@ -195,6 +210,16 @@ class SampleProcessor(object):
if gblur_rnd_chance < chance:
img = cv2.GaussianBlur(img, (gblur_rnd_kernel,) *2 , 0)

if random_bilinear_resize is not None:
l_rnd_state = np.random.RandomState (sample_rnd_seed+2)

chance, max_size_per = random_bilinear_resize
chance = np.clip(chance, 0, 100)
pick_chance = l_rnd_state.randint(100)
resize_to = resolution - int( l_rnd_state.rand()* int(resolution*(max_size_per/100.0)) )
img = cv2.resize (img, (resize_to,resize_to), cv2.INTER_LINEAR )
img = cv2.resize (img, (resolution,resolution), cv2.INTER_LINEAR )

# Transform from BGR to desired channel_type
if channel_type == SPCT.BGR:
out_sample = img
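Note that the mask path above now restores a dropped channel axis after warping, since cv2.warpAffine and cv2.resize return HxW for single-channel HxWx1 inputs. A tiny sketch of that guard:

import numpy as np

def ensure_channel_axis(mask):
    # OpenCV drops the trailing axis of HxWx1 inputs; put it back
    if len(mask.shape) == 2:
        mask = mask[..., None]
    return mask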
@@ -207,7 +232,7 @@ class SampleProcessor(object):
h, s, v = cv2.split(hsv)
h = (h + l_rnd_state.randint(360) ) % 360
s = np.clip ( s + l_rnd_state.random()-0.5, 0, 1 )
v = np.clip ( v + l_rnd_state.random()-0.5, 0, 1 )
v = np.clip ( v + l_rnd_state.random()/2-0.25, 0, 1 )
hsv = cv2.merge([h, s, v])
out_sample = np.clip( cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) , 0, 1 )
elif channel_type == SPCT.BGR_RANDOM_RGB_LEVELS:
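The HSV hunk above narrows the value jitter from +-0.5 to +-0.25 while leaving hue and saturation as before, all drawn from a per-sample RandomState. A self-contained sketch of the adjusted shift:

import cv2
import numpy as np

def hsv_shift(img_bgr, seed):
    # img_bgr: HxWx3 float32 in [0,1]; same seed -> same jitter
    rnd = np.random.RandomState(seed)
    h, s, v = cv2.split(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV))
    h = (h + rnd.randint(360)) % 360               # hue: full-circle shift
    s = np.clip(s + rnd.rand() - 0.5, 0, 1)        # saturation: +-0.5
    v = np.clip(v + rnd.rand() / 2 - 0.25, 0, 1)   # value: now only +-0.25
    return np.clip(cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR), 0, 1)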
@@ -234,7 +259,7 @@ class SampleProcessor(object):
out_sample = np.transpose(out_sample, (2,0,1) )
elif sample_type == SPST.IMAGE:
img = sample_bgr
img = imagelib.warp_by_params (params, img, warp, transform, can_flip=True, border_replicate=True)
img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=True)
img = cv2.resize( img, (resolution, resolution), cv2.INTER_CUBIC )
out_sample = img

@@ -9,4 +9,5 @@ from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal
from .SampleGeneratorImage import SampleGeneratorImage
from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal
from .SampleGeneratorFaceCelebAMaskHQ import SampleGeneratorFaceCelebAMaskHQ
from .SampleGeneratorFaceSkinSegDataset import SampleGeneratorFaceSkinSegDataset
from .PackedFaceset import PackedFaceset