This commit is contained in:
iperov 2018-06-04 17:12:43 +04:00
commit 6bd5a44264
71 changed files with 8448 additions and 0 deletions

296
utils/AlignedPNG.py Normal file
View file

@ -0,0 +1,296 @@
PNG_HEADER = b"\x89PNG\r\n\x1a\n"
import string
import struct
import zlib
import pickle
class Chunk(object):
def __init__(self, name=None, data=None):
self.length = 0
self.crc = 0
self.name = name if name else "noNe"
self.data = data if data else b""
@classmethod
def load(cls, data):
"""Load a chunk including header and footer"""
inst = cls()
if len(data) < 12:
msg = "Chunk-data too small"
raise ValueError(msg)
# chunk header & data
(inst.length, raw_name) = struct.unpack("!I4s", data[0:8])
inst.data = data[8:-4]
inst.verify_length()
inst.name = raw_name.decode("ascii")
inst.verify_name()
# chunk crc
inst.crc = struct.unpack("!I", data[8+inst.length:8+inst.length+4])[0]
inst.verify_crc()
return inst
def dump(self, auto_crc=True, auto_length=True):
"""Return the chunk including header and footer"""
if auto_length: self.update_length()
if auto_crc: self.update_crc()
self.verify_name()
return struct.pack("!I", self.length) + self.get_raw_name() + self.data + struct.pack("!I", self.crc)
def verify_length(self):
if len(self.data) != self.length:
msg = "Data length ({}) does not match length in chunk header ({})".format(len(self.data), self.length)
raise ValueError(msg)
return True
def verify_name(self):
for c in self.name:
if c not in string.ascii_letters:
msg = "Invalid character in chunk name: {}".format(repr(self.name))
raise ValueError(msg)
return True
def verify_crc(self):
calculated_crc = self.get_crc()
if self.crc != calculated_crc:
msg = "CRC mismatch: {:08X} (header), {:08X} (calculated)".format(self.crc, calculated_crc)
raise ValueError(msg)
return True
def update_length(self):
self.length = len(self.data)
def update_crc(self):
self.crc = self.get_crc()
def get_crc(self):
return zlib.crc32(self.get_raw_name() + self.data)
def get_raw_name(self):
return self.name if isinstance(self.name, bytes) else self.name.encode("ascii")
# name helper methods
def ancillary(self, set=None):
"""Set and get ancillary=True/critical=False bit"""
if set is True:
self.name[0] = self.name[0].lower()
elif set is False:
self.name[0] = self.name[0].upper()
return self.name[0].islower()
def private(self, set=None):
"""Set and get private=True/public=False bit"""
if set is True:
self.name[1] = self.name[1].lower()
elif set is False:
self.name[1] = self.name[1].upper()
return self.name[1].islower()
def reserved(self, set=None):
"""Set and get reserved_valid=True/invalid=False bit"""
if set is True:
self.name[2] = self.name[2].upper()
elif set is False:
self.name[2] = self.name[2].lower()
return self.name[2].isupper()
def safe_to_copy(self, set=None):
"""Set and get save_to_copy=True/unsafe=False bit"""
if set is True:
self.name[3] = self.name[3].lower()
elif set is False:
self.name[3] = self.name[3].upper()
return self.name[3].islower()
def __str__(self):
return "<Chunk '{name}' length={length} crc={crc:08X}>".format(**self.__dict__)
class IHDR(Chunk):
"""IHDR Chunk
width, height, bit_depth, color_type, compression_method,
filter_method, interlace_method contain the data extracted
from the chunk. Modify those and use and build() to recreate
the chunk. Valid values for bit_depth depend on the color_type
and can be looked up in color_types or in the PNG specification
See:
http://www.libpng.org/pub/png/spec/1.2/PNG-Chunks.html#C.IHDR
"""
# color types with name & allowed bit depths
COLOR_TYPE_GRAY = 0
COLOR_TYPE_RGB = 2
COLOR_TYPE_PLTE = 3
COLOR_TYPE_GRAYA = 4
COLOR_TYPE_RGBA = 6
color_types = {
COLOR_TYPE_GRAY: ("Grayscale", (1,2,4,8,16)),
COLOR_TYPE_RGB: ("RGB", (8,16)),
COLOR_TYPE_PLTE: ("Palette", (1,2,4,8)),
COLOR_TYPE_GRAYA: ("Greyscale+Alpha", (8,16)),
COLOR_TYPE_RGBA: ("RGBA", (8,16)),
}
def __init__(self, width=0, height=0, bit_depth=8, color_type=2, \
compression_method=0, filter_method=0, interlace_method=0):
self.width = width
self.height = height
self.bit_depth = bit_depth
self.color_type = color_type
self.compression_method = compression_method
self.filter_method = filter_method
self.interlace_method = interlace_method
super().__init__("IHDR")
@classmethod
def load(cls, data):
inst = super().load(data)
fields = struct.unpack("!IIBBBBB", inst.data)
inst.width = fields[0]
inst.height = fields[1]
inst.bit_depth = fields[2] # per channel
inst.color_type = fields[3] # see specs
inst.compression_method = fields[4] # always 0(=deflate/inflate)
inst.filter_method = fields[5] # always 0(=adaptive filtering with 5 methods)
inst.interlace_method = fields[6] # 0(=no interlace) or 1(=Adam7 interlace)
return inst
def dump(self):
self.data = struct.pack("!IIBBBBB", \
self.width, self.height, self.bit_depth, self.color_type, \
self.compression_method, self.filter_method, self.interlace_method)
return super().dump()
def __str__(self):
return "<Chunk:IHDR geometry={width}x{height} bit_depth={bit_depth} color_type={}>" \
.format(self.color_types[self.color_type][0], **self.__dict__)
class IEND(Chunk):
def __init__(self):
super().__init__("IEND")
def dump(self):
if len(self.data) != 0:
msg = "IEND has data which is not allowed"
raise ValueError(msg)
if self.length != 0:
msg = "IEND data lenght is not 0 which is not allowed"
raise ValueError(msg)
return super().dump()
def __str__(self):
return "<Chunk:IEND>".format(**self.__dict__)
class FaceswapChunk(Chunk):
def __init__(self, dict_data=None):
super().__init__("fcWp")
self.dict_data = dict_data
def setDictData(self, dict_data):
self.dict_data = dict_data
def getDictData(self):
return self.dict_data
@classmethod
def load(cls, data):
inst = super().load(data)
inst.dict_data = pickle.loads( inst.data )
return inst
def dump(self):
self.data = pickle.dumps (self.dict_data)
return super().dump()
chunk_map = {
b"IHDR": IHDR,
b"fcWp": FaceswapChunk,
b"IEND": IEND
}
class AlignedPNG(object):
def __init__(self):
self.data = b""
self.length = 0
self.chunks = []
@staticmethod
def load(data):
try:
with open(data, "rb") as f:
data = f.read()
except:
raise FileNotFoundError(data)
inst = AlignedPNG()
inst.data = data
inst.length = len(data)
if data[0:8] != PNG_HEADER:
msg = "No Valid PNG header"
raise ValueError(msg)
chunk_start = 8
while chunk_start < inst.length:
(chunk_length, chunk_name) = struct.unpack("!I4s", data[chunk_start:chunk_start+8])
chunk_end = chunk_start + chunk_length + 12
chunk = chunk_map.get(chunk_name, Chunk).load(data[chunk_start:chunk_end])
inst.chunks.append(chunk)
chunk_start = chunk_end
return inst
def save(self, filename):
try:
with open(filename, "wb") as f:
f.write ( self.dump() )
except:
raise Exception( 'cannot save %s' % (filename) )
def dump(self):
data = PNG_HEADER
for chunk in self.chunks:
data += chunk.dump()
return data
def get_shape(self):
for chunk in self.chunks:
if type(chunk) == IHDR:
c = 3 if chunk.color_type == IHDR.COLOR_TYPE_RGB else 4
w = chunk.width
h = chunk.height
return (h,w,c)
return (0,0,0)
def get_height(self):
for chunk in self.chunks:
if type(chunk) == IHDR:
return chunk.height
return 0
def getFaceswapDictData(self):
for chunk in self.chunks:
if type(chunk) == FaceswapChunk:
return chunk.getDictData()
return None
def setFaceswapDictData (self, dict_data=None):
for chunk in self.chunks:
if type(chunk) == FaceswapChunk:
self.chunks.remove(chunk)
break
if not dict_data is None:
chunk = FaceswapChunk(dict_data)
self.chunks.insert(-1, chunk)
def __str__(self):
return "<PNG length={length} chunks={}>".format(len(self.chunks), **self.__dict__)

40
utils/Path_utils.py Normal file
View file

@ -0,0 +1,40 @@
from pathlib import Path
from scandir import scandir
image_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff"]
def get_image_paths(dir_path):
dir_path = Path (dir_path)
result = []
if dir_path.exists():
for x in list(scandir(str(dir_path))):
if any([x.name.lower().endswith(ext) for ext in image_extensions]):
result.append(x.path)
return result
def get_image_unique_filestem_paths(dir_path, verbose=False):
result = get_image_paths(dir_path)
result_dup = set()
for f in result[:]:
f_stem = Path(f).stem
if f_stem in result_dup:
result.remove(f)
if verbose:
print ("Duplicate filenames are not allowed, skipping: %s" % Path(f).name )
continue
result_dup.add(f_stem)
return result
def get_all_dir_names_startswith (dir_path, startswith):
dir_path = Path (dir_path)
startswith = startswith.lower()
result = []
if dir_path.exists():
for x in list(scandir(str(dir_path))):
if x.name.lower().startswith(startswith):
result.append ( x.name[len(startswith):] )
return result

246
utils/SubprocessorBase.py Normal file
View file

@ -0,0 +1,246 @@
import traceback
from tqdm import tqdm
import multiprocessing
import time
import sys
class SubprocessorBase(object):
#overridable
def __init__(self, name, no_response_time_sec = 60):
self.name = name
self.no_response_time_sec = no_response_time_sec
#overridable
def process_info_generator(self):
#yield name, host_dict, client_dict - per process
yield 'first process', {}, {}
#overridable
def get_no_process_started_message(self):
return "No process started."
#overridable
def onHostGetProgressBarDesc(self):
return "Processing"
#overridable
def onHostGetProgressBarLen(self):
return 0
#overridable
def onHostGetData(self):
#return data here
return None
#overridable
def onHostDataReturn (self, data):
#input_data.insert(0, obj['data'])
pass
#overridable
def onClientInitialize(self, client_dict):
#return fail message or None if ok
return None
#overridable
def onClientFinalize(self):
pass
#overridable
def onClientProcessData(self, data):
#return result object
return None
#overridable
def onClientGetDataName (self, data):
#return string identificator of your data
return "undefined"
#overridable
def onHostClientsInitialized(self):
pass
#overridable
def onHostResult (self, data, result):
#return count of progress bar update
return 1
#overridable
def onHostProcessEnd(self):
pass
#overridable
def get_start_return(self):
return None
def inc_progress_bar(self, c):
self.progress_bar.update(c)
def safe_print(self, msg):
self.print_lock.acquire()
print (msg)
self.print_lock.release()
def process(self):
#returns start_return
self.processes = []
self.print_lock = multiprocessing.Lock()
for name, host_dict, client_dict in self.process_info_generator():
sq = multiprocessing.Queue()
cq = multiprocessing.Queue()
client_dict.update ( {'print_lock' : self.print_lock} )
p = multiprocessing.Process(target=self.subprocess, args=(sq,cq,client_dict))
p.daemon = True
p.start()
self.processes.append ( { 'process' : p,
'sq' : sq,
'cq' : cq,
'state' : 'busy',
'sent_time': time.time(),
'name': name,
'host_dict' : host_dict
} )
while True:
for p in self.processes[:]:
while not p['cq'].empty():
obj = p['cq'].get()
obj_op = obj['op']
if obj_op == 'init_ok':
p['state'] = 'free'
elif obj_op == 'error':
if obj['close'] == True:
p['process'].terminate()
p['process'].join()
self.processes.remove(p)
break
if all ([ p['state'] == 'free' for p in self.processes ] ):
break
if len(self.processes) == 0:
print ( self.get_no_process_started_message() )
return self.get_start_return()
self.onHostClientsInitialized()
self.progress_bar = tqdm( total=self.onHostGetProgressBarLen(), desc=self.onHostGetProgressBarDesc() )
try:
while True:
for p in self.processes[:]:
while not p['cq'].empty():
obj = p['cq'].get()
obj_op = obj['op']
if obj_op == 'success':
data = obj['data']
result = obj['result']
c = self.onHostResult (data, result)
if c > 0:
self.progress_bar.update(c)
p['state'] = 'free'
elif obj_op == 'error':
if 'data' in obj.keys():
self.onHostDataReturn ( obj['data'] )
if obj['close'] == True:
p['sq'].put ( {'op': 'close'} )
p['process'].join()
self.processes.remove(p)
break
p['state'] = 'free'
for p in self.processes[:]:
if p['state'] == 'free':
data = self.onHostGetData()
if data is not None:
p['sq'].put ( {'op': 'data', 'data' : data} )
p['sent_time'] = time.time()
p['sent_data'] = data
p['state'] = 'busy'
elif p['state'] == 'busy':
if (time.time() - p['sent_time']) > self.no_response_time_sec:
print ( '%s doesnt response, terminating it.' % (p['name']) )
self.onHostDataReturn ( p['sent_data'] )
p['sq'].put ( {'op': 'close'} )
p['process'].join()
self.processes.remove(p)
if all ([p['state'] == 'free' for p in self.processes]):
break
time.sleep(0.005)
except:
print ("Exception occured in Subprocessor.start(): %s" % (traceback.format_exc()) )
self.progress_bar.close()
for p in self.processes[:]:
p['sq'].put ( {'op': 'close'} )
while True:
for p in self.processes[:]:
while not p['cq'].empty():
obj = p['cq'].get()
obj_op = obj['op']
if obj_op == 'finalized':
p['state'] = 'finalized'
if all ([p['state'] == 'finalized' for p in self.processes]):
break
for p in self.processes[:]:
p['process'].terminate()
self.onHostProcessEnd()
return self.get_start_return()
def subprocess(self, sq, cq, client_dict):
self.print_lock = client_dict['print_lock']
try:
fail_message = self.onClientInitialize(client_dict)
except:
fail_message = 'Exception while initialization: %s' % (traceback.format_exc())
if fail_message is None:
cq.put ( {'op': 'init_ok'} )
else:
print (fail_message)
cq.put ( {'op': 'error', 'close': True} )
return
while True:
obj = sq.get()
obj_op = obj['op']
if obj_op == 'data':
data = obj['data']
try:
result = self.onClientProcessData (data)
cq.put ( {'op': 'success', 'data' : data, 'result' : result} )
except:
print ( 'Exception while process data [%s]: %s' % (self.onClientGetDataName(data), traceback.format_exc()) )
cq.put ( {'op': 'error', 'close': True, 'data' : data } )
elif obj_op == 'close':
break
time.sleep(0.005)
self.onClientFinalize()
cq.put ( {'op': 'finalized'} )
while True:
time.sleep(0.1)

264
utils/image_utils.py Normal file
View file

@ -0,0 +1,264 @@
import sys
from utils import random_utils
import numpy as np
import cv2
import localization
from scipy.spatial import Delaunay
from PIL import Image, ImageDraw, ImageFont
def channel_hist_match(source, template, mask=None):
# Code borrowed from:
# https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x
masked_source = source
masked_template = template
if mask is not None:
masked_source = source * mask
masked_template = template * mask
oldshape = source.shape
source = source.ravel()
template = template.ravel()
masked_source = masked_source.ravel()
masked_template = masked_template.ravel()
s_values, bin_idx, s_counts = np.unique(source, return_inverse=True,
return_counts=True)
t_values, t_counts = np.unique(template, return_counts=True)
ms_values, mbin_idx, ms_counts = np.unique(source, return_inverse=True,
return_counts=True)
mt_values, mt_counts = np.unique(template, return_counts=True)
s_quantiles = np.cumsum(s_counts).astype(np.float64)
s_quantiles /= s_quantiles[-1]
t_quantiles = np.cumsum(t_counts).astype(np.float64)
t_quantiles /= t_quantiles[-1]
interp_t_values = np.interp(s_quantiles, t_quantiles, t_values)
return interp_t_values[bin_idx].reshape(oldshape)
def color_hist_match(src_im, tar_im, mask=None):
h,w,c = src_im.shape
matched_R = channel_hist_match(src_im[:,:,0], tar_im[:,:,0], mask)
matched_G = channel_hist_match(src_im[:,:,1], tar_im[:,:,1], mask)
matched_B = channel_hist_match(src_im[:,:,2], tar_im[:,:,2], mask)
to_stack = (matched_R, matched_G, matched_B)
for i in range(3, c):
to_stack += ( src_im[:,:,i],)
matched = np.stack(to_stack, axis=-1).astype(src_im.dtype)
return matched
pil_fonts = {}
def _get_pil_font (font, size):
global pil_fonts
try:
font_str_id = '%s_%d' % (font, size)
if font_str_id not in pil_fonts.keys():
pil_fonts[font_str_id] = ImageFont.truetype(font + ".ttf", size=size, encoding="unic")
pil_font = pil_fonts[font_str_id]
return pil_font
except:
return ImageFont.load_default()
def get_text_image( shape, text, color=(1,1,1), border=0.2, font=None):
try:
size = shape[1]
pil_font = _get_pil_font( localization.get_default_ttf_font_name() , size)
text_width, text_height = pil_font.getsize(text)
canvas = Image.new('RGB', shape[0:2], (0,0,0) )
draw = ImageDraw.Draw(canvas)
offset = ( 0, 0)
draw.text(offset, text, font=pil_font, fill=tuple((np.array(color)*255).astype(np.int)) )
result = np.asarray(canvas) / 255
if shape[2] != 3:
result = np.concatenate ( (result, np.ones ( (shape[1],) + (shape[0],) + (shape[2]-3,)) ), axis=2 )
return result
except:
return np.zeros ( (shape[1], shape[0], shape[2]), dtype=np.float32 )
def draw_text( image, rect, text, color=(1,1,1), border=0.2, font=None):
h,w,c = image.shape
l,t,r,b = rect
l = np.clip (l, 0, w-1)
r = np.clip (r, 0, w-1)
t = np.clip (t, 0, h-1)
b = np.clip (b, 0, h-1)
image[t:b, l:r] += get_text_image ( (r-l,b-t,c) , text, color, border, font )
def draw_text_lines (image, rect, text_lines, color=(1,1,1), border=0.2, font=None):
text_lines_len = len(text_lines)
if text_lines_len == 0:
return
l,t,r,b = rect
h = b-t
h_per_line = h // text_lines_len
for i in range(0, text_lines_len):
draw_text (image, (l, i*h_per_line, r, (i+1)*h_per_line), text_lines[i], color, border, font)
def get_draw_text_lines ( image, rect, text_lines, color=(1,1,1), border=0.2, font=None):
image = np.zeros ( image.shape, dtype=np.float )
draw_text_lines ( image, rect, text_lines, color, border, font)
return image
def draw_polygon (image, points, color, thickness = 1):
points_len = len(points)
for i in range (0, points_len):
p0 = tuple( points[i] )
p1 = tuple( points[ (i+1) % points_len] )
cv2.line (image, p0, p1, color, thickness=thickness)
def draw_rect(image, rect, color, thickness=1):
l,t,r,b = rect
draw_polygon (image, [ (l,t), (r,t), (r,b), (l,b ) ], color, thickness)
def rectContains(rect, point) :
return not (point[0] < rect[0] or point[0] >= rect[2] or point[1] < rect[1] or point[1] >= rect[3])
def applyAffineTransform(src, srcTri, dstTri, size) :
warpMat = cv2.getAffineTransform( np.float32(srcTri), np.float32(dstTri) )
return cv2.warpAffine( src, warpMat, (size[0], size[1]), None, flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101 )
def morphTriangle(dst_img, src_img, st, dt) :
(h,w,c) = dst_img.shape
sr = np.array( cv2.boundingRect(np.float32(st)) )
dr = np.array( cv2.boundingRect(np.float32(dt)) )
sRect = st - sr[0:2]
dRect = dt - dr[0:2]
d_mask = np.zeros((dr[3], dr[2], c), dtype = np.float32)
cv2.fillConvexPoly(d_mask, np.int32(dRect), (1.0,)*c, 8, 0);
imgRect = src_img[sr[1]:sr[1] + sr[3], sr[0]:sr[0] + sr[2]]
size = (dr[2], dr[3])
warpImage1 = applyAffineTransform(imgRect, sRect, dRect, size)
dst_img[dr[1]:dr[1]+dr[3], dr[0]:dr[0]+dr[2]] = dst_img[dr[1]:dr[1]+dr[3], dr[0]:dr[0]+dr[2]]*(1-d_mask) + warpImage1 * d_mask
def morph_by_points (image, sp, dp):
if sp.shape != dp.shape:
raise ValueError ('morph_by_points() sp.shape != dp.shape')
(h,w,c) = image.shape
result_image = np.zeros(image.shape, dtype = image.dtype)
for tri in Delaunay(dp).simplices:
morphTriangle(result_image, image, sp[tri], dp[tri])
return result_image
def equalize_and_stack_square (images, axis=1):
max_c = max ([ 1 if len(image.shape) == 2 else image.shape[2] for image in images ] )
target_wh = 99999
for i,image in enumerate(images):
if len(image.shape) == 2:
h,w = image.shape
c = 1
else:
h,w,c = image.shape
if h < target_wh:
target_wh = h
if w < target_wh:
target_wh = w
for i,image in enumerate(images):
if len(image.shape) == 2:
h,w = image.shape
c = 1
else:
h,w,c = image.shape
if c < max_c:
if c == 1:
if len(image.shape) == 2:
image = np.expand_dims ( image, -1 )
image = np.concatenate ( (image,)*max_c, -1 )
elif c == 2: #GA
image = np.expand_dims ( image[...,0], -1 )
image = np.concatenate ( (image,)*max_c, -1 )
else:
image = np.concatenate ( (image, np.ones((h,w,max_c - c))), -1 )
if h != target_wh or w != target_wh:
image = cv2.resize ( image, (target_wh, target_wh) )
h,w,c = image.shape
images[i] = image
return np.concatenate ( images, axis = 1 )
def bgr2hsv (img):
return cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
def hsv2bgr (img):
return cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
def bgra2hsva (img):
return np.concatenate ( (cv2.cvtColor(img[...,0:3], cv2.COLOR_BGR2HSV ), np.expand_dims (img[...,3], -1)), -1 )
def bgra2hsva_list (imgs):
return [ bgra2hsva(img) for img in imgs ]
def hsva2bgra (img):
return np.concatenate ( (cv2.cvtColor(img[...,0:3], cv2.COLOR_HSV2BGR ), np.expand_dims (img[...,3], -1)), -1 )
def hsva2bgra_list (imgs):
return [ hsva2bgra(img) for img in imgs ]
def gen_warp_params (source, flip, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05] ):
h,w,c = source.shape
if (h != w) or (w != 64 and w != 128 and w != 256 and w != 512 and w != 1024):
raise ValueError ('TrainingDataGenerator accepts only square power of 2 images.')
rotation = np.random.uniform( rotation_range[0], rotation_range[1] )
scale = np.random.uniform(1 +scale_range[0], 1 +scale_range[1])
tx = np.random.uniform( tx_range[0], tx_range[1] )
ty = np.random.uniform( ty_range[0], ty_range[1] )
#random warp by grid
cell_size = [ w // (2**i) for i in range(1,4) ] [ np.random.randint(3) ]
cell_count = w // cell_size + 1
grid_points = np.linspace( 0, w, cell_count)
mapx = np.broadcast_to(grid_points, (cell_count, cell_count)).copy()
mapy = mapx.T
mapx[1:-1,1:-1] = mapx[1:-1,1:-1] + random_utils.random_normal( size=(cell_count-2, cell_count-2) )*(cell_size*0.24)
mapy[1:-1,1:-1] = mapy[1:-1,1:-1] + random_utils.random_normal( size=(cell_count-2, cell_count-2) )*(cell_size*0.24)
half_cell_size = cell_size // 2
mapx = cv2.resize(mapx, (w+cell_size,)*2 )[half_cell_size:-half_cell_size-1,half_cell_size:-half_cell_size-1].astype(np.float32)
mapy = cv2.resize(mapy, (w+cell_size,)*2 )[half_cell_size:-half_cell_size-1,half_cell_size:-half_cell_size-1].astype(np.float32)
#random transform
random_transform_mat = cv2.getRotationMatrix2D((w // 2, w // 2), rotation, scale)
random_transform_mat[:, 2] += (tx*w, ty*w)
params = dict()
params['mapx'] = mapx
params['mapy'] = mapy
params['rmat'] = random_transform_mat
params['w'] = w
params['flip'] = flip and np.random.randint(10) < 4
return params
def warp_by_params (params, img, warp, transform, flip):
if warp:
img = cv2.remap(img, params['mapx'], params['mapy'], cv2.INTER_LANCZOS4 )
if transform:
img = cv2.warpAffine( img, params['rmat'], (params['w'], params['w']), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 )
if flip and params['flip']:
img = img[:,::-1,:]
return img

63
utils/iter_utils.py Normal file
View file

@ -0,0 +1,63 @@
import threading
import queue as Queue
import multiprocessing
import time
class ThisThreadGenerator(object):
def __init__(self, generator_func, user_param=None):
super().__init__()
self.generator_func = generator_func
self.user_param = user_param
self.initialized = False
def __iter__(self):
return self
def __next__(self):
if not self.initialized:
self.initialized = True
self.generator_func = self.generator_func(self.user_param)
return next(self.generator_func)
class SubprocessGenerator(object):
def __init__(self, generator_func, user_param=None, prefetch=2):
super().__init__()
self.prefetch = prefetch
self.generator_func = generator_func
self.user_param = user_param
self.sc_queue = multiprocessing.Queue()
self.cs_queue = multiprocessing.Queue()
self.p = None
def process_func(self):
self.generator_func = self.generator_func(self.user_param)
while True:
while self.prefetch > -1:
try:
gen_data = next (self.generator_func)
except StopIteration:
self.cs_queue.put (None)
return
self.cs_queue.put (gen_data)
self.prefetch -= 1
self.sc_queue.get()
self.prefetch += 1
def __iter__(self):
return self
def __next__(self):
if self.p == None:
self.p = multiprocessing.Process(target=self.process_func, args=())
self.p.daemon = True
self.p.start()
gen_data = self.cs_queue.get()
if gen_data is None:
self.p.terminate()
self.p.join()
raise StopIteration()
self.sc_queue.put (1)
return gen_data

18
utils/os_utils.py Normal file
View file

@ -0,0 +1,18 @@
import sys
if sys.platform[0:3] == 'win':
from ctypes import windll
from ctypes import wintypes
def set_process_lowest_prio():
if sys.platform[0:3] == 'win':
GetCurrentProcess = windll.kernel32.GetCurrentProcess
GetCurrentProcess.restype = wintypes.HANDLE
SetPriorityClass = windll.kernel32.SetPriorityClass
SetPriorityClass.argtypes = (wintypes.HANDLE, wintypes.DWORD)
SetPriorityClass ( GetCurrentProcess(), 0x00000040 )
def set_process_dpi_aware():
if sys.platform[0:3] == 'win':
windll.user32.SetProcessDPIAware(True)

14
utils/random_utils.py Normal file
View file

@ -0,0 +1,14 @@
import numpy as np
def random_normal( size=(1,), trunc_val = 2.5 ):
len = np.array(size).prod()
result = np.empty ( (len,) , dtype=np.float32)
for i in range (len):
while True:
x = np.random.normal()
if x >= -trunc_val and x <= trunc_val:
break
result[i] = (x / trunc_val)
return result.reshape ( size )

36
utils/std_utils.py Normal file
View file

@ -0,0 +1,36 @@
import os
import sys
class suppress_stdout_stderr(object):
def __enter__(self):
self.outnull_file = open(os.devnull, 'w')
self.errnull_file = open(os.devnull, 'w')
self.old_stdout_fileno_undup = sys.stdout.fileno()
self.old_stderr_fileno_undup = sys.stderr.fileno()
self.old_stdout_fileno = os.dup ( sys.stdout.fileno() )
self.old_stderr_fileno = os.dup ( sys.stderr.fileno() )
self.old_stdout = sys.stdout
self.old_stderr = sys.stderr
os.dup2 ( self.outnull_file.fileno(), self.old_stdout_fileno_undup )
os.dup2 ( self.errnull_file.fileno(), self.old_stderr_fileno_undup )
sys.stdout = self.outnull_file
sys.stderr = self.errnull_file
return self
def __exit__(self, *_):
sys.stdout = self.old_stdout
sys.stderr = self.old_stderr
os.dup2 ( self.old_stdout_fileno, self.old_stdout_fileno_undup )
os.dup2 ( self.old_stderr_fileno, self.old_stderr_fileno_undup )
os.close ( self.old_stdout_fileno )
os.close ( self.old_stderr_fileno )
self.outnull_file.close()
self.errnull_file.close()