mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-14 02:37:00 -07:00
initial
This commit is contained in:
parent
73de93b4f1
commit
6bd5a44264
71 changed files with 8448 additions and 0 deletions
296
utils/AlignedPNG.py
Normal file
296
utils/AlignedPNG.py
Normal file
|
@ -0,0 +1,296 @@
|
|||
# 8-byte signature expected at the very start of a PNG file.
PNG_HEADER = b"\x89PNG\r\n\x1a\n"
|
||||
|
||||
import string
|
||||
import struct
|
||||
import zlib
|
||||
import pickle
|
||||
|
||||
class Chunk(object):
    """A single PNG chunk: 4-byte length, 4-byte name, payload, 4-byte CRC.

    The case of each of the four name letters encodes a property bit
    (ancillary, private, reserved, safe-to-copy); see the name helper
    methods at the end of the class.
    """

    def __init__(self, name=None, data=None):
        self.length = 0
        self.crc = 0
        self.name = name if name else "noNe"
        self.data = data if data else b""

    @classmethod
    def load(cls, data):
        """Load a chunk including header and footer.

        ``data`` must hold exactly one serialized chunk. Raises ValueError
        when it is too short, when the length field, name or CRC do not
        validate.
        """
        inst = cls()
        if len(data) < 12:
            msg = "Chunk-data too small"
            raise ValueError(msg)

        # chunk header & data
        (inst.length, raw_name) = struct.unpack("!I4s", data[0:8])
        inst.data = data[8:-4]
        inst.verify_length()
        inst.name = raw_name.decode("ascii")
        inst.verify_name()

        # chunk crc
        inst.crc = struct.unpack("!I", data[8+inst.length:8+inst.length+4])[0]
        inst.verify_crc()

        return inst

    def dump(self, auto_crc=True, auto_length=True):
        """Return the chunk including header and footer.

        With ``auto_length`` / ``auto_crc`` the stored length and CRC are
        recomputed from the current payload before serializing.
        """
        if auto_length: self.update_length()
        if auto_crc: self.update_crc()
        self.verify_name()
        return struct.pack("!I", self.length) + self.get_raw_name() + self.data + struct.pack("!I", self.crc)

    def verify_length(self):
        """Return True, or raise ValueError if ``length`` != ``len(data)``."""
        if len(self.data) != self.length:
            msg = "Data length ({}) does not match length in chunk header ({})".format(len(self.data), self.length)
            raise ValueError(msg)
        return True

    def verify_name(self):
        """Return True, or raise ValueError on a non-ASCII-letter in the name."""
        for c in self.name:
            if c not in string.ascii_letters:
                msg = "Invalid character in chunk name: {}".format(repr(self.name))
                raise ValueError(msg)
        return True

    def verify_crc(self):
        """Return True, or raise ValueError if the stored CRC is wrong."""
        calculated_crc = self.get_crc()
        if self.crc != calculated_crc:
            msg = "CRC mismatch: {:08X} (header), {:08X} (calculated)".format(self.crc, calculated_crc)
            raise ValueError(msg)
        return True

    def update_length(self):
        self.length = len(self.data)

    def update_crc(self):
        self.crc = self.get_crc()

    def get_crc(self):
        """CRC-32 over the raw name followed by the payload."""
        return zlib.crc32(self.get_raw_name() + self.data)

    def get_raw_name(self):
        """Return the chunk name as bytes."""
        return self.name if isinstance(self.name, bytes) else self.name.encode("ascii")

    # name helper methods

    def _set_name_letter_case(self, index, lower):
        """Replace name letter ``index`` with its lower/upper-case form.

        str is immutable, so the name is rebuilt instead of assigning to
        ``self.name[index]`` (which raised TypeError).
        """
        letter = self.name[index]
        letter = letter.lower() if lower else letter.upper()
        self.name = self.name[:index] + letter + self.name[index + 1:]

    def ancillary(self, set=None):
        """Set and get ancillary=True/critical=False bit"""
        if set is True:
            self._set_name_letter_case(0, lower=True)
        elif set is False:
            self._set_name_letter_case(0, lower=False)
        return self.name[0].islower()

    def private(self, set=None):
        """Set and get private=True/public=False bit"""
        if set is True:
            self._set_name_letter_case(1, lower=True)
        elif set is False:
            self._set_name_letter_case(1, lower=False)
        return self.name[1].islower()

    def reserved(self, set=None):
        """Set and get reserved_valid=True/invalid=False bit"""
        # Unlike the other bits, "valid" is the upper-case form here.
        if set is True:
            self._set_name_letter_case(2, lower=False)
        elif set is False:
            self._set_name_letter_case(2, lower=True)
        return self.name[2].isupper()

    def safe_to_copy(self, set=None):
        """Set and get safe_to_copy=True/unsafe=False bit"""
        if set is True:
            self._set_name_letter_case(3, lower=True)
        elif set is False:
            self._set_name_letter_case(3, lower=False)
        return self.name[3].islower()

    def __str__(self):
        return "<Chunk '{name}' length={length} crc={crc:08X}>".format(**self.__dict__)
|
||||
|
||||
class IHDR(Chunk):
    """IHDR Chunk

    width, height, bit_depth, color_type, compression_method,
    filter_method, interlace_method contain the data extracted
    from the chunk. Modify those and use dump() to recreate
    the chunk. Valid values for bit_depth depend on the color_type
    and can be looked up in color_types or in the PNG specification

    See:
    http://www.libpng.org/pub/png/spec/1.2/PNG-Chunks.html#C.IHDR
    """
    # color types with name & allowed bit depths
    COLOR_TYPE_GRAY = 0
    COLOR_TYPE_RGB = 2
    COLOR_TYPE_PLTE = 3
    COLOR_TYPE_GRAYA = 4
    COLOR_TYPE_RGBA = 6
    color_types = {
        COLOR_TYPE_GRAY: ("Grayscale", (1,2,4,8,16)),
        COLOR_TYPE_RGB: ("RGB", (8,16)),
        COLOR_TYPE_PLTE: ("Palette", (1,2,4,8)),
        COLOR_TYPE_GRAYA: ("Greyscale+Alpha", (8,16)),
        COLOR_TYPE_RGBA: ("RGBA", (8,16)),
    }

    def __init__(self, width=0, height=0, bit_depth=8, color_type=2,
                 compression_method=0, filter_method=0, interlace_method=0):
        self.width = width
        self.height = height
        self.bit_depth = bit_depth
        self.color_type = color_type
        self.compression_method = compression_method
        self.filter_method = filter_method
        self.interlace_method = interlace_method
        super().__init__("IHDR")

    @classmethod
    def load(cls, data):
        """Load the chunk and unpack its payload into the header fields."""
        inst = super().load(data)
        fields = struct.unpack("!IIBBBBB", inst.data)
        inst.width = fields[0]
        inst.height = fields[1]
        inst.bit_depth = fields[2] # per channel
        inst.color_type = fields[3] # see specs
        inst.compression_method = fields[4] # always 0(=deflate/inflate)
        inst.filter_method = fields[5] # always 0(=adaptive filtering with 5 methods)
        inst.interlace_method = fields[6] # 0(=no interlace) or 1(=Adam7 interlace)
        return inst

    def dump(self):
        """Re-pack the header fields into the payload and serialize the chunk."""
        self.data = struct.pack("!IIBBBBB",
            self.width, self.height, self.bit_depth, self.color_type,
            self.compression_method, self.filter_method, self.interlace_method)
        return super().dump()

    def __str__(self):
        # A color_type outside color_types used to raise KeyError here;
        # fall back to a readable placeholder so str() never fails.
        known = self.color_types.get(self.color_type)
        color_name = known[0] if known is not None else "Unknown({})".format(self.color_type)
        return "<Chunk:IHDR geometry={width}x{height} bit_depth={bit_depth} color_type={}>" \
            .format(color_name, **self.__dict__)
|
||||
|
||||
class IEND(Chunk):
    """Terminating chunk of a PNG stream; it must not carry any payload."""

    def __init__(self):
        super().__init__("IEND")

    def dump(self):
        """Serialize the chunk, raising ValueError if data or length is non-zero."""
        has_payload = len(self.data) != 0
        if has_payload:
            raise ValueError("IEND has data which is not allowed")
        if self.length != 0:
            raise ValueError("IEND data lenght is not 0 which is not allowed")
        return super().dump()

    def __str__(self):
        # The original format() call had no placeholders, so the text is constant.
        return "<Chunk:IEND>"
|
||||
|
||||
class FaceswapChunk(Chunk):
    """Private 'fcWp' chunk whose payload is a pickled Python object."""

    def __init__(self, dict_data=None):
        super().__init__("fcWp")
        self.dict_data = dict_data

    def setDictData(self, dict_data):
        """Replace the object that dump() will pickle."""
        self.dict_data = dict_data

    def getDictData(self):
        """Return the object held by this chunk."""
        return self.dict_data

    @classmethod
    def load(cls, data):
        """Load the chunk and unpickle its payload into ``dict_data``."""
        chunk = super().load(data)
        # NOTE(security): pickle.loads runs arbitrary code embedded in the
        # payload; only load PNG files that come from a trusted source.
        payload = chunk.data
        chunk.dict_data = pickle.loads(payload)
        return chunk

    def dump(self):
        """Pickle ``dict_data`` into the payload and serialize the chunk."""
        pickled = pickle.dumps(self.dict_data)
        self.data = pickled
        return super().dump()
|
||||
|
||||
# Raw 4-byte chunk name -> class used to parse it; names that are not
# listed here are parsed with the generic Chunk class by AlignedPNG.load.
chunk_map = dict([
    (b"IHDR", IHDR),
    (b"fcWp", FaceswapChunk),
    (b"IEND", IEND),
])
|
||||
|
||||
class AlignedPNG(object):
    """In-memory PNG split into chunks, with access to the 'fcWp' metadata chunk."""

    def __init__(self):
        self.data = b""
        self.length = 0
        self.chunks = []

    @staticmethod
    def load(data):
        """Read the PNG file at path ``data`` and parse all of its chunks.

        Raises FileNotFoundError when the file cannot be read (the original
        error is chained as the cause) and ValueError when the content is
        not a well-formed PNG chunk stream.
        """
        filename = data
        try:
            with open(filename, "rb") as f:
                data = f.read()
        except Exception as err:
            # Keep the historical exception type for callers, but no longer
            # swallow KeyboardInterrupt/SystemExit or hide the real cause.
            raise FileNotFoundError(filename) from err

        inst = AlignedPNG()
        inst.data = data
        inst.length = len(data)

        if data[0:8] != PNG_HEADER:
            msg = "No Valid PNG header"
            raise ValueError(msg)

        chunk_start = 8
        while chunk_start < inst.length:
            header = data[chunk_start:chunk_start+8]
            if len(header) < 8:
                # struct.unpack would raise struct.error, which is not a ValueError.
                raise ValueError("Truncated chunk header at offset {}".format(chunk_start))
            (chunk_length, chunk_name) = struct.unpack("!I4s", header)
            chunk_end = chunk_start + chunk_length + 12

            chunk = chunk_map.get(chunk_name, Chunk).load(data[chunk_start:chunk_end])
            inst.chunks.append(chunk)
            chunk_start = chunk_end

        return inst

    def save(self, filename):
        """Serialize all chunks to ``filename``; raises Exception on failure."""
        try:
            with open(filename, "wb") as f:
                f.write ( self.dump() )
        except Exception as err:
            raise Exception( 'cannot save %s' % (filename) ) from err

    def dump(self):
        """Return the PNG signature followed by every serialized chunk."""
        return b"".join([PNG_HEADER] + [chunk.dump() for chunk in self.chunks])

    def _first_chunk(self, chunk_type):
        """Return the first chunk of ``chunk_type`` or None."""
        for chunk in self.chunks:
            if isinstance(chunk, chunk_type):
                return chunk
        return None

    def get_shape(self):
        """Return (height, width, channels) from IHDR, or (0,0,0) without one.

        Channels is 3 for RGB and 4 for every other color type.
        """
        ihdr = self._first_chunk(IHDR)
        if ihdr is None:
            return (0,0,0)
        c = 3 if ihdr.color_type == IHDR.COLOR_TYPE_RGB else 4
        return (ihdr.height, ihdr.width, c)

    def get_height(self):
        """Return the image height from IHDR, or 0 without one."""
        ihdr = self._first_chunk(IHDR)
        return 0 if ihdr is None else ihdr.height

    def getFaceswapDictData(self):
        """Return the object stored in the first FaceswapChunk, or None."""
        chunk = self._first_chunk(FaceswapChunk)
        return None if chunk is None else chunk.getDictData()

    def setFaceswapDictData (self, dict_data=None):
        """Replace the FaceswapChunk; ``None`` just removes the existing one."""
        existing = self._first_chunk(FaceswapChunk)
        if existing is not None:
            self.chunks.remove(existing)

        if dict_data is not None:
            chunk = FaceswapChunk(dict_data)
            # insert(-1) places the chunk before the last one (normally IEND).
            self.chunks.insert(-1, chunk)

    def __str__(self):
        return "<PNG length={length} chunks={}>".format(len(self.chunks), **self.__dict__)
|
40
utils/Path_utils.py
Normal file
40
utils/Path_utils.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
from pathlib import Path
|
||||
from scandir import scandir
|
||||
|
||||
# Lower-case file suffixes that are treated as images.
image_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff"]

def get_image_paths(dir_path):
    """Return paths of entries directly inside ``dir_path`` with an image suffix.

    The suffix comparison is case-insensitive. A missing directory yields
    an empty list.
    """
    root = Path(dir_path)
    if not root.exists():
        return []

    suffixes = tuple(image_extensions)
    return [entry.path
            for entry in scandir(str(root))
            if entry.name.lower().endswith(suffixes)]
|
||||
|
||||
def get_image_unique_filestem_paths(dir_path, verbose=False):
    """Return image paths from ``dir_path``, keeping one path per file stem.

    The first path seen for a stem is kept; later ones are skipped (and
    reported on stdout when ``verbose`` is true). The result is built in
    one pass instead of calling list.remove() per duplicate.
    """
    seen_stems = set()
    unique_paths = []

    for f in get_image_paths(dir_path):
        f_stem = Path(f).stem
        if f_stem in seen_stems:
            if verbose:
                print ("Duplicate filenames are not allowed, skipping: %s" % Path(f).name )
            continue
        seen_stems.add(f_stem)
        unique_paths.append(f)

    return unique_paths
|
||||
|
||||
def get_all_dir_names_startswith (dir_path, startswith):
    """Return entry names in ``dir_path`` that begin with ``startswith``.

    The match is case-insensitive and the matched prefix is cut off each
    returned name. A missing directory yields an empty list.
    """
    root = Path(dir_path)
    prefix = startswith.lower()
    if not root.exists():
        return []

    cut = len(prefix)
    return [entry.name[cut:]
            for entry in scandir(str(root))
            if entry.name.lower().startswith(prefix)]
|
246
utils/SubprocessorBase.py
Normal file
246
utils/SubprocessorBase.py
Normal file
|
@ -0,0 +1,246 @@
|
|||
import traceback
|
||||
from tqdm import tqdm
|
||||
import multiprocessing
|
||||
import time
|
||||
import sys
|
||||
|
||||
class SubprocessorBase(object):
    """Host/client work distributor built on multiprocessing.

    process() starts one client process per item yielded by
    process_info_generator(), feeds each free client with data from
    onHostGetData() and hands results to onHostResult(). Subclasses
    override the methods marked ``#overridable``.
    """

    #overridable
    def __init__(self, name, no_response_time_sec = 60):
        self.name = name
        self.no_response_time_sec = no_response_time_sec

    #overridable
    def process_info_generator(self):
        #yield name, host_dict, client_dict - per process
        yield 'first process', {}, {}

    #overridable
    def get_no_process_started_message(self):
        return "No process started."

    #overridable
    def onHostGetProgressBarDesc(self):
        return "Processing"

    #overridable
    def onHostGetProgressBarLen(self):
        return 0

    #overridable
    def onHostGetData(self):
        #return data here
        return None

    #overridable
    def onHostDataReturn (self, data):
        #input_data.insert(0, obj['data'])
        pass

    #overridable
    def onClientInitialize(self, client_dict):
        #return fail message or None if ok
        return None

    #overridable
    def onClientFinalize(self):
        pass

    #overridable
    def onClientProcessData(self, data):
        #return result object
        return None

    #overridable
    def onClientGetDataName (self, data):
        #return string identificator of your data
        return "undefined"

    #overridable
    def onHostClientsInitialized(self):
        pass

    #overridable
    def onHostResult (self, data, result):
        #return count of progress bar update
        return 1

    #overridable
    def onHostProcessEnd(self):
        pass

    #overridable
    def get_start_return(self):
        return None

    def inc_progress_bar(self, c):
        self.progress_bar.update(c)

    def safe_print(self, msg):
        # 'with' releases the lock even if print() raises.
        with self.print_lock:
            print (msg)

    def _close_process(self, p, graceful):
        """Stop client ``p`` and drop it from self.processes.

        The client never exits by itself (it idles forever after sending
        'finalized'), so a bare join() would block the host for good.
        When ``graceful`` the client is asked to finalize and given up to
        no_response_time_sec to confirm; it is then terminated either way.
        """
        if graceful:
            p['sq'].put ( {'op': 'close'} )
            deadline = time.time() + self.no_response_time_sec
            finalized = False
            while not finalized and time.time() < deadline and p['process'].is_alive():
                if p['cq'].empty():
                    time.sleep(0.005)
                else:
                    finalized = p['cq'].get()['op'] == 'finalized'

        p['process'].terminate()
        p['process'].join()
        self.processes.remove(p)

    def process(self):
        #returns start_return

        self.processes = []

        self.print_lock = multiprocessing.Lock()
        for name, host_dict, client_dict in self.process_info_generator():
            sq = multiprocessing.Queue()
            cq = multiprocessing.Queue()

            client_dict.update ( {'print_lock' : self.print_lock} )

            p = multiprocessing.Process(target=self.subprocess, args=(sq,cq,client_dict))
            p.daemon = True
            p.start()
            self.processes.append ( { 'process' : p,
                                      'sq' : sq,
                                      'cq' : cq,
                                      'state' : 'busy',
                                      'sent_time': time.time(),
                                      'name': name,
                                      'host_dict' : host_dict
                                      } )

        # Wait until every client has reported 'init_ok' or was removed.
        while True:
            for p in self.processes[:]:
                while not p['cq'].empty():
                    obj = p['cq'].get()
                    obj_op = obj['op']

                    if obj_op == 'init_ok':
                        p['state'] = 'free'
                    elif obj_op == 'error':
                        if obj['close'] == True:
                            p['process'].terminate()
                            p['process'].join()
                            self.processes.remove(p)
                            break

            if all ([ p['state'] == 'free' for p in self.processes ] ):
                break

            # Yield the CPU instead of spinning while clients initialize.
            time.sleep(0.005)

        if len(self.processes) == 0:
            print ( self.get_no_process_started_message() )
            return self.get_start_return()

        self.onHostClientsInitialized()

        self.progress_bar = tqdm( total=self.onHostGetProgressBarLen(), desc=self.onHostGetProgressBarDesc() )

        try:
            while True:
                for p in self.processes[:]:
                    while not p['cq'].empty():
                        obj = p['cq'].get()
                        obj_op = obj['op']

                        if obj_op == 'success':
                            data = obj['data']
                            result = obj['result']

                            c = self.onHostResult (data, result)
                            if c > 0:
                                self.progress_bar.update(c)

                            p['state'] = 'free'

                        elif obj_op == 'error':
                            if 'data' in obj.keys():
                                self.onHostDataReturn ( obj['data'] )

                            if obj['close'] == True:
                                self._close_process(p, graceful=True)
                                break
                            p['state'] = 'free'

                for p in self.processes[:]:
                    if p['state'] == 'free':
                        data = self.onHostGetData()
                        if data is not None:
                            p['sq'].put ( {'op': 'data', 'data' : data} )
                            p['sent_time'] = time.time()
                            p['sent_data'] = data
                            p['state'] = 'busy'

                    elif p['state'] == 'busy':
                        if (time.time() - p['sent_time']) > self.no_response_time_sec:
                            print ( '%s doesnt response, terminating it.' % (p['name']) )
                            self.onHostDataReturn ( p['sent_data'] )
                            # An unresponsive client cannot finalize; kill it directly.
                            self._close_process(p, graceful=False)

                if all ([p['state'] == 'free' for p in self.processes]):
                    break

                time.sleep(0.005)
        except:
            # Deliberately broad: any failure (including Ctrl-C) is reported
            # and the clients are still shut down below.
            print ("Exception occured in Subprocessor.start(): %s" % (traceback.format_exc()) )

        self.progress_bar.close()

        for p in self.processes[:]:
            p['sq'].put ( {'op': 'close'} )

        # Wait for every remaining client to confirm finalization.
        while True:
            for p in self.processes[:]:
                while not p['cq'].empty():
                    obj = p['cq'].get()
                    obj_op = obj['op']
                    if obj_op == 'finalized':
                        p['state'] = 'finalized'

                # A dead client can never answer; do not wait for it forever.
                if not p['process'].is_alive():
                    p['state'] = 'finalized'

            if all ([p['state'] == 'finalized' for p in self.processes]):
                break

            time.sleep(0.005)

        for p in self.processes[:]:
            p['process'].terminate()

        self.onHostProcessEnd()

        return self.get_start_return()

    def subprocess(self, sq, cq, client_dict):
        """Client entry point: initialize, serve 'data' ops until 'close', finalize."""
        self.print_lock = client_dict['print_lock']

        try:
            fail_message = self.onClientInitialize(client_dict)
        except:
            # Deliberately broad: every init failure is reported to the host.
            fail_message = 'Exception while initialization: %s' % (traceback.format_exc())

        if fail_message is None:
            cq.put ( {'op': 'init_ok'} )
        else:
            print (fail_message)
            cq.put ( {'op': 'error', 'close': True} )
            return

        while True:
            obj = sq.get()
            obj_op = obj['op']

            if obj_op == 'data':
                data = obj['data']
                try:
                    result = self.onClientProcessData (data)
                    cq.put ( {'op': 'success', 'data' : data, 'result' : result} )
                except:
                    # Deliberately broad: the data item is handed back to the host.
                    print ( 'Exception while process data [%s]: %s' % (self.onClientGetDataName(data), traceback.format_exc()) )
                    cq.put ( {'op': 'error', 'close': True, 'data' : data } )
            elif obj_op == 'close':
                break

            time.sleep(0.005)

        self.onClientFinalize()
        cq.put ( {'op': 'finalized'} )
        # Idle until the host terminates this process.
        while True:
            time.sleep(0.1)
|
264
utils/image_utils.py
Normal file
264
utils/image_utils.py
Normal file
|
@ -0,0 +1,264 @@
|
|||
import sys
|
||||
from utils import random_utils
|
||||
import numpy as np
|
||||
import cv2
|
||||
import localization
|
||||
from scipy.spatial import Delaunay
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
def channel_hist_match(source, template, mask=None):
    """Remap ``source`` values so their histogram matches ``template``.

    Returns an array with the shape of ``source`` holding interpolated
    template values.

    NOTE(review): ``mask`` is accepted for API compatibility but does not
    influence the result. The original code built masked copies and two
    extra np.unique() passes whose outputs were never used; that dead work
    has been removed without changing the returned values.
    """
    # Code borrowed from:
    # https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x
    oldshape = source.shape
    source = source.ravel()
    template = template.ravel()

    _, bin_idx, s_counts = np.unique(source, return_inverse=True,
                                     return_counts=True)
    t_values, t_counts = np.unique(template, return_counts=True)

    # Normalized cumulative histograms of both inputs.
    s_quantiles = np.cumsum(s_counts).astype(np.float64)
    s_quantiles /= s_quantiles[-1]
    t_quantiles = np.cumsum(t_counts).astype(np.float64)
    t_quantiles /= t_quantiles[-1]

    # For each distinct source value, the template value at the same quantile.
    interp_t_values = np.interp(s_quantiles, t_quantiles, t_values)

    return interp_t_values[bin_idx].reshape(oldshape)
|
||||
|
||||
def color_hist_match(src_im, tar_im, mask=None):
    """Histogram-match the first three channels of ``src_im`` to ``tar_im``.

    Channels beyond the third are copied from ``src_im`` unchanged. The
    result has the dtype of ``src_im``.
    """
    _, _, channel_count = src_im.shape

    planes = [channel_hist_match(src_im[:,:,idx], tar_im[:,:,idx], mask)
              for idx in range(3)]
    planes.extend(src_im[:,:,idx] for idx in range(3, channel_count))

    return np.stack(planes, axis=-1).astype(src_im.dtype)
|
||||
|
||||
|
||||
# Cache of loaded fonts keyed by "<font name>_<size>".
pil_fonts = {}
def _get_pil_font (font, size):
    """Return a cached TrueType font ``font``.ttf at ``size``.

    Falls back to PIL's default font when the font cannot be loaded.
    """
    global pil_fonts
    try:
        font_str_id = '%s_%d' % (font, size)
        if font_str_id not in pil_fonts:
            pil_fonts[font_str_id] = ImageFont.truetype(font + ".ttf", size=size, encoding="unic")
        return pil_fonts[font_str_id]
    except Exception:
        # Best effort, but no longer swallows KeyboardInterrupt/SystemExit.
        return ImageFont.load_default()
|
||||
|
||||
def get_text_image( shape, text, color=(1,1,1), border=0.2, font=None):
    """Render ``text`` on a black canvas and return it as a float array.

    ``shape`` is (width, height, channels); the result is indexed
    (height, width, channels) with values in 0..1. Channels beyond the
    third are filled with ones. ``border`` and ``font`` are currently
    unused. On any rendering failure an all-zero image is returned.
    """
    try:
        size = shape[1]
        pil_font = _get_pil_font( localization.get_default_ttf_font_name() , size)

        # The former pil_font.getsize() call was dropped: its result was
        # unused and the method no longer exists in current Pillow, which
        # made every call fall into the except branch.
        canvas = Image.new('RGB', (int(shape[0]), int(shape[1])), (0,0,0) )
        draw = ImageDraw.Draw(canvas)
        offset = ( 0, 0)
        # np.int was removed from NumPy; builtin int truncates the same way.
        fill = tuple( int(channel * 255) for channel in color )
        draw.text(offset, text, font=pil_font, fill=fill )

        result = np.asarray(canvas) / 255
        if shape[2] != 3:
            result = np.concatenate ( (result, np.ones ( (shape[1],) + (shape[0],) + (shape[2]-3,)) ), axis=2 )

        return result
    except Exception:
        # Best effort: drawing text must never break the caller.
        return np.zeros ( (shape[1], shape[0], shape[2]), dtype=np.float32 )
|
||||
|
||||
def draw_text( image, rect, text, color=(1,1,1), border=0.2, font=None):
    """Add rendered ``text`` onto ``image`` in place inside ``rect``.

    ``rect`` is (left, top, right, bottom); each edge is clipped to the
    image bounds before drawing.
    """
    img_h, img_w, img_c = image.shape
    left, top, right, bottom = rect

    left = np.clip(left, 0, img_w-1)
    right = np.clip(right, 0, img_w-1)
    top = np.clip(top, 0, img_h-1)
    bottom = np.clip(bottom, 0, img_h-1)

    text_shape = (right-left, bottom-top, img_c)
    image[top:bottom, left:right] += get_text_image(text_shape, text, color, border, font)
|
||||
|
||||
def draw_text_lines (image, rect, text_lines, color=(1,1,1), border=0.2, font=None):
    """Draw ``text_lines`` stacked vertically inside ``rect`` on ``image``.

    ``rect`` is (left, top, right, bottom); its height is split evenly
    between the lines. Nothing is drawn for an empty ``text_lines``.
    """
    text_lines_len = len(text_lines)
    if text_lines_len == 0:
        return

    l,t,r,b = rect
    h = b-t
    h_per_line = h // text_lines_len

    for i, line in enumerate(text_lines):
        # Offset by the rect's top edge; the original started every block
        # at y=0 and therefore ignored ``t``.
        line_top = t + i*h_per_line
        draw_text (image, (l, line_top, r, line_top + h_per_line), line, color, border, font)
|
||||
|
||||
def get_draw_text_lines ( image, rect, text_lines, color=(1,1,1), border=0.2, font=None):
    """Return a new float64 image of ``image.shape`` holding only the drawn text.

    The passed ``image`` is used for its shape only and is not modified.
    """
    # np.float was an alias of the builtin float (float64) and has been
    # removed from NumPy; np.float64 gives the same dtype.
    image = np.zeros ( image.shape, dtype=np.float64 )
    draw_text_lines ( image, rect, text_lines, color, border, font)
    return image
|
||||
|
||||
|
||||
def draw_polygon (image, points, color, thickness = 1):
    """Draw the closed outline through ``points`` on ``image`` with cv2.line."""
    count = len(points)
    for index, start in enumerate(points):
        # The last point connects back to the first one.
        end = points[(index + 1) % count]
        cv2.line (image, tuple(start), tuple(end), color, thickness=thickness)
|
||||
|
||||
def draw_rect(image, rect, color, thickness=1):
    """Draw the outline of ``rect`` = (left, top, right, bottom) on ``image``."""
    left, top, right, bottom = rect
    corners = [ (left, top), (right, top), (right, bottom), (left, bottom) ]
    draw_polygon (image, corners, color, thickness)
|
||||
|
||||
def rectContains(rect, point) :
    """Return True when ``point`` (x, y) is inside ``rect`` (l, t, r, b).

    The left and top edges are inclusive, the right and bottom exclusive.
    """
    x, y = point[0], point[1]
    if x < rect[0] or x >= rect[2]:
        return False
    if y < rect[1] or y >= rect[3]:
        return False
    return True
|
||||
|
||||
def applyAffineTransform(src, srcTri, dstTri, size) :
    """Warp ``src`` with the affine map taking ``srcTri`` onto ``dstTri``.

    ``size`` gives the output (width, height). Uses bilinear interpolation
    and reflected borders.
    """
    src_points = np.float32(srcTri)
    dst_points = np.float32(dstTri)
    warp_matrix = cv2.getAffineTransform( src_points, dst_points )
    out_size = (size[0], size[1])
    return cv2.warpAffine( src, warp_matrix, out_size, None,
                           flags=cv2.INTER_LINEAR,
                           borderMode=cv2.BORDER_REFLECT_101 )
|
||||
|
||||
def morphTriangle(dst_img, src_img, st, dt) :
    """Warp triangle ``st`` of ``src_img`` onto triangle ``dt`` of ``dst_img`` in place."""
    _, _, channels = dst_img.shape

    # Bounding rectangles (x, y, w, h) of the source and destination triangles.
    src_rect = np.array( cv2.boundingRect(np.float32(st)) )
    dst_rect = np.array( cv2.boundingRect(np.float32(dt)) )
    src_x, src_y, src_w, src_h = src_rect[0], src_rect[1], src_rect[2], src_rect[3]
    dst_x, dst_y, dst_w, dst_h = dst_rect[0], dst_rect[1], dst_rect[2], dst_rect[3]

    # Triangle coordinates relative to their bounding rectangles.
    src_local = st - src_rect[0:2]
    dst_local = dt - dst_rect[0:2]

    # 1.0 inside the destination triangle, 0.0 elsewhere.
    tri_mask = np.zeros((dst_h, dst_w, channels), dtype = np.float32)
    cv2.fillConvexPoly(tri_mask, np.int32(dst_local), (1.0,)*channels, 8, 0)

    src_patch = src_img[src_y:src_y + src_h, src_x:src_x + src_w]
    warped = applyAffineTransform(src_patch, src_local, dst_local, (dst_w, dst_h))

    rows = slice(dst_y, dst_y + dst_h)
    cols = slice(dst_x, dst_x + dst_w)
    dst_img[rows, cols] = dst_img[rows, cols]*(1-tri_mask) + warped * tri_mask
|
||||
|
||||
def morph_by_points (image, sp, dp):
    """Return ``image`` morphed so that points ``sp`` move to ``dp``.

    The destination points are triangulated with Delaunay and every
    triangle is warped separately. Raises ValueError when the two point
    arrays differ in shape.
    """
    if sp.shape != dp.shape:
        raise ValueError ('morph_by_points() sp.shape != dp.shape')
    _h, _w, _c = image.shape  # same 3-D requirement as before

    morphed = np.zeros(image.shape, dtype = image.dtype)

    triangulation = Delaunay(dp)
    for vertex_ids in triangulation.simplices:
        morphTriangle(morphed, image, sp[vertex_ids], dp[vertex_ids])

    return morphed
|
||||
|
||||
def equalize_and_stack_square (images, axis=1):
    """Bring all ``images`` to one square size and channel count, then stack.

    The target side is the smallest height/width among the inputs (capped
    at 99999) and the target channel count is the largest among them.
    ``images`` is updated in place with the converted arrays, as before.
    The result is concatenated along ``axis`` (previously always axis 1,
    whatever was passed).
    """
    max_c = max ([ 1 if len(image.shape) == 2 else image.shape[2] for image in images ] )

    target_wh = 99999
    for image in images:
        h, w = image.shape[0], image.shape[1]
        target_wh = min(target_wh, h, w)

    for i,image in enumerate(images):
        if len(image.shape) == 2:
            h,w = image.shape
            c = 1
        else:
            h,w,c = image.shape

        if c < max_c:
            if c == 1:
                if len(image.shape) == 2:
                    image = np.expand_dims ( image, -1 )
                image = np.concatenate ( (image,)*max_c, -1 )
            elif c == 2: #GA
                image = np.expand_dims ( image[...,0], -1 )
                image = np.concatenate ( (image,)*max_c, -1 )
            else:
                image = np.concatenate ( (image, np.ones((h,w,max_c - c))), -1 )

        if h != target_wh or w != target_wh:
            # The old 'h,w,c = image.shape' after this call was unused and
            # raised when cv2.resize returned a 2-D array.
            image = cv2.resize ( image, (target_wh, target_wh) )

        images[i] = image

    return np.concatenate ( images, axis = axis )
|
||||
|
||||
def bgr2hsv (img):
    """Convert a BGR image to HSV with cv2.cvtColor."""
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    return hsv
|
||||
|
||||
def hsv2bgr (img):
    """Convert an HSV image to BGR with cv2.cvtColor."""
    bgr = cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
    return bgr
|
||||
|
||||
def bgra2hsva (img):
    """Convert channels 0-2 from BGR to HSV and append channel 3 unchanged."""
    hsv = cv2.cvtColor(img[...,0:3], cv2.COLOR_BGR2HSV )
    alpha = img[...,3][..., np.newaxis]
    return np.concatenate ( (hsv, alpha), -1 )
|
||||
|
||||
def bgra2hsva_list (imgs):
    """Apply bgra2hsva to every image and return the results as a list."""
    return list(map(bgra2hsva, imgs))
|
||||
|
||||
def hsva2bgra (img):
    """Convert channels 0-2 from HSV to BGR and append channel 3 unchanged."""
    bgr = cv2.cvtColor(img[...,0:3], cv2.COLOR_HSV2BGR )
    alpha = img[...,3][..., np.newaxis]
    return np.concatenate ( (bgr, alpha), -1 )
|
||||
|
||||
def hsva2bgra_list (imgs):
    """Apply hsva2bgra to every image and return the results as a list."""
    return list(map(hsva2bgra, imgs))
|
||||
|
||||
def gen_warp_params (source, flip, rotation_range=(-10,10), scale_range=(-0.5, 0.5), tx_range=(-0.05, 0.05), ty_range=(-0.05, 0.05) ):
    """Draw random warp/transform parameters for a square ``source`` image.

    ``source`` must be square with a side of 64, 128, 256, 512 or 1024,
    otherwise ValueError is raised. Each ``*_range`` is a (low, high) pair
    (any two-item sequence); the defaults are tuples now so they cannot be
    mutated between calls.

    Returns a dict with 'mapx'/'mapy' (float32 maps for cv2.remap), 'rmat'
    (2x3 affine matrix), 'w' (side length) and 'flip'.
    """
    h,w,c = source.shape
    if (h != w) or (w not in (64, 128, 256, 512, 1024)):
        raise ValueError ('TrainingDataGenerator accepts only square power of 2 images.')

    rotation = np.random.uniform( rotation_range[0], rotation_range[1] )
    scale = np.random.uniform(1 +scale_range[0], 1 +scale_range[1])
    tx = np.random.uniform( tx_range[0], tx_range[1] )
    ty = np.random.uniform( ty_range[0], ty_range[1] )

    #random warp by grid
    cell_size = [ w // (2**i) for i in range(1,4) ] [ np.random.randint(3) ]
    cell_count = w // cell_size + 1

    grid_points = np.linspace( 0, w, cell_count)
    mapx = np.broadcast_to(grid_points, (cell_count, cell_count)).copy()
    # NOTE(review): mapy is a transposed VIEW of mapx, so the two in-place
    # updates below write into the same buffer and the x/y displacements
    # end up as transposes of each other. Kept as is because changing it
    # alters the augmentation; confirm whether a .copy() was intended.
    mapy = mapx.T

    mapx[1:-1,1:-1] = mapx[1:-1,1:-1] + random_utils.random_normal( size=(cell_count-2, cell_count-2) )*(cell_size*0.24)
    mapy[1:-1,1:-1] = mapy[1:-1,1:-1] + random_utils.random_normal( size=(cell_count-2, cell_count-2) )*(cell_size*0.24)

    half_cell_size = cell_size // 2

    # Upscale the coarse grid to pixel resolution and crop the padding.
    crop = slice(half_cell_size, -half_cell_size-1)
    mapx = cv2.resize(mapx, (w+cell_size,)*2 )[crop, crop].astype(np.float32)
    mapy = cv2.resize(mapy, (w+cell_size,)*2 )[crop, crop].astype(np.float32)

    #random transform
    random_transform_mat = cv2.getRotationMatrix2D((w // 2, w // 2), rotation, scale)
    random_transform_mat[:, 2] += (tx*w, ty*w)

    params = dict()
    params['mapx'] = mapx
    params['mapy'] = mapy
    params['rmat'] = random_transform_mat
    params['w'] = w
    # Flip with probability 0.4, and only when the caller allows flipping.
    params['flip'] = flip and np.random.randint(10) < 4

    return params
|
||||
|
||||
def warp_by_params (params, img, warp, transform, flip):
    """Apply the steps selected by ``warp``/``transform``/``flip`` to ``img``.

    ``params`` is a dict as produced by gen_warp_params. The horizontal
    flip happens only when both ``flip`` and ``params['flip']`` are truthy.
    """
    result = img
    if warp:
        result = cv2.remap(result, params['mapx'], params['mapy'], cv2.INTER_LANCZOS4 )
    if transform:
        side = params['w']
        result = cv2.warpAffine( result, params['rmat'], (side, side), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 )
    if flip and params['flip']:
        result = result[:,::-1,:]
    return result
|
63
utils/iter_utils.py
Normal file
63
utils/iter_utils.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
import threading
|
||||
import queue as Queue
|
||||
import multiprocessing
|
||||
import time
|
||||
|
||||
|
||||
class ThisThreadGenerator(object):
    """Iterator that lazily creates a generator in the calling thread.

    ``generator_func(user_param)`` is called on the first next(); from
    then on ``generator_func`` holds the generator it returned.
    """

    def __init__(self, generator_func, user_param=None):
        super().__init__()
        self.generator_func = generator_func
        self.user_param = user_param
        self.initialized = False

    def __iter__(self):
        return self

    def __next__(self):
        if self.initialized:
            return next(self.generator_func)

        # First call: swap the factory for the generator it produces.
        self.initialized = True
        self.generator_func = self.generator_func(self.user_param)
        return next(self.generator_func)
|
||||
|
||||
class SubprocessGenerator(object):
    """Iterator that runs ``generator_func(user_param)`` in a child process.

    The child keeps producing items ahead of the consumer, governed by
    ``prefetch``. The process is started on the first next().

    NOTE: None is the end-of-stream marker on the queue, so a generator
    that yields None ends the iteration early.
    """

    def __init__(self, generator_func, user_param=None, prefetch=2):
        super().__init__()
        self.prefetch = prefetch
        self.generator_func = generator_func
        self.user_param = user_param
        self.sc_queue = multiprocessing.Queue()
        self.cs_queue = multiprocessing.Queue()
        self.p = None
        # Set once the child reported exhaustion, so further next() calls
        # raise StopIteration instead of blocking on the empty queue.
        self._exhausted = False

    def process_func(self):
        """Child-process loop: emit items while the prefetch budget allows."""
        self.generator_func = self.generator_func(self.user_param)
        while True:
            while self.prefetch > -1:
                try:
                    gen_data = next (self.generator_func)
                except StopIteration:
                    self.cs_queue.put (None)
                    return
                self.cs_queue.put (gen_data)
                self.prefetch -= 1
            # Block until the consumer acknowledges one item.
            self.sc_queue.get()
            self.prefetch += 1

    def __iter__(self):
        return self

    def __next__(self):
        if self._exhausted:
            raise StopIteration()

        if self.p is None:
            self.p = multiprocessing.Process(target=self.process_func, args=())
            self.p.daemon = True
            self.p.start()

        gen_data = self.cs_queue.get()
        if gen_data is None:
            self._exhausted = True
            self.p.terminate()
            self.p.join()
            raise StopIteration()
        self.sc_queue.put (1)
        return gen_data
|
18
utils/os_utils.py
Normal file
18
utils/os_utils.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
import sys
|
||||
|
||||
if sys.platform[0:3] == 'win':
|
||||
from ctypes import windll
|
||||
from ctypes import wintypes
|
||||
|
||||
def set_process_lowest_prio():
    """On Windows, set the current process priority class to 0x40; no-op elsewhere."""
    if sys.platform[0:3] != 'win':
        return

    kernel32 = windll.kernel32

    get_current_process = kernel32.GetCurrentProcess
    get_current_process.restype = wintypes.HANDLE

    set_priority_class = kernel32.SetPriorityClass
    set_priority_class.argtypes = (wintypes.HANDLE, wintypes.DWORD)

    # 0x00000040 is presumably IDLE_PRIORITY_CLASS - verify against the WinAPI docs.
    set_priority_class ( get_current_process(), 0x00000040 )
|
||||
|
||||
def set_process_dpi_aware():
    """On Windows, mark the current process as DPI aware; no-op elsewhere."""
    if sys.platform[0:3] == 'win':
        # SetProcessDPIAware() takes no parameters. The stray ``True`` that
        # used to be passed makes ctypes raise on 32-bit stdcall builds
        # ("called with too many arguments").
        windll.user32.SetProcessDPIAware()
|
14
utils/random_utils.py
Normal file
14
utils/random_utils.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
import numpy as np
|
||||
|
||||
def random_normal( size=(1,), trunc_val = 2.5 ):
    """Return float32 samples of shape ``size`` in [-1, 1].

    Each value is a standard normal draw, redrawn until it lies within
    +/- ``trunc_val``, then divided by ``trunc_val``.
    """
    count = np.array(size).prod()
    samples = np.empty ( (count,) , dtype=np.float32)

    for idx in range (count):
        x = np.random.normal()
        # Rejection sampling: redraw anything outside the truncation window.
        while not (x >= -trunc_val and x <= trunc_val):
            x = np.random.normal()
        samples[idx] = (x / trunc_val)

    return samples.reshape ( size )
|
36
utils/std_utils.py
Normal file
36
utils/std_utils.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
class suppress_stdout_stderr(object):
    """Context manager that silences stdout and stderr.

    Both the Python-level sys.stdout/sys.stderr objects and the file
    descriptors behind them are pointed at os.devnull on entry and put
    back on exit.
    """

    def __enter__(self):
        self.outnull_file = open(os.devnull, 'w')
        self.errnull_file = open(os.devnull, 'w')

        # Descriptors currently behind the two streams ...
        out_fd = sys.stdout.fileno()
        err_fd = sys.stderr.fileno()
        self.old_stdout_fileno_undup = out_fd
        self.old_stderr_fileno_undup = err_fd

        # ... and duplicates of them so they can be restored later.
        self.old_stdout_fileno = os.dup ( sys.stdout.fileno() )
        self.old_stderr_fileno = os.dup ( sys.stderr.fileno() )

        self.old_stdout, self.old_stderr = sys.stdout, sys.stderr

        # Redirect at descriptor level first, then at Python level.
        os.dup2 ( self.outnull_file.fileno(), out_fd )
        os.dup2 ( self.errnull_file.fileno(), err_fd )

        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        # Point the original descriptors back at their saved targets.
        os.dup2 ( self.old_stdout_fileno, self.old_stdout_fileno_undup )
        os.dup2 ( self.old_stderr_fileno, self.old_stderr_fileno_undup )

        # The duplicates are no longer needed.
        os.close ( self.old_stdout_fileno )
        os.close ( self.old_stderr_fileno )

        self.outnull_file.close()
        self.errnull_file.close()
|
Loading…
Add table
Add a link
Reference in a new issue