diff --git a/.gitignore b/.gitignore index 16a6020..77e566e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,127 @@ -* -!*.py -!*.md -!*.txt -!*.jpg -!requirements* -!Dockerfile* -!*.sh \ No newline at end of file + +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +.idea/ diff --git a/converters/ConverterMasked.py b/converters/ConverterMasked.py index 8cbe4ca..7dee2c1 100644 --- a/converters/ConverterMasked.py +++ b/converters/ConverterMasked.py @@ -5,9 +5,11 @@ import cv2 import numpy as np import imagelib +import imagelib_legacy from facelib import FaceType, FANSegmentator, LandmarksProcessor from interact import interact as io from joblib import SubprocessFunctionCaller +from samplelib.SampleProcessor import ColorTransferMode from utils.pickle_utils import AntiPickler from .Converter import Converter @@ -34,6 +36,7 @@ class ConverterMasked(Converter): base_blur_mask_modifier=0, default_erode_mask_modifier=0, default_blur_mask_modifier=0, + default_apply_random_ct=ColorTransferMode.NONE, clip_hborder_mask_per=0, force_mask_mode=-1): @@ -116,8 +119,43 @@ class ConverterMasked(Converter): 1.0 + io.input_int("Choose output face scale modifier [-50..50] (skip:0) : ", 0) * 0.01, 0.5, 1.5) if self.mode != 'raw': - self.color_transfer_mode = io.input_str( - "Apply color transfer to predicted face? Choose mode ( rct/lct skip:None ) : ", None, ['rct', 'lct']) + + # FIXME + # self.color_transfer_mode = np.clip(io.input_int( + # "Apply color transfer to predicted face? 
(0) None, (1) LCT, (2) RCT, (3) RCT-c, (4) RCT-p, " # "(5) RCT-pc, (6) mRCT, (7) mRCT-c, (8) mRCT-p, (9) mRCT-pc ?:help skip:%s) : " % default_apply_random_ct, # default_apply_random_ct, # help_message="Increase variability of src samples by applying color transfer from random dst " # "samples. It is like 'face_style' learning, but with more precise color transfer and without " # "risk of model collapse. It does not require additional GPU resources, " # "but the training time may be longer because the src faceset becomes more " # "diverse.\n\n" # "===Modes===:\n" # "(0) None - No transformation\n" # "(1) LCT - Linear Color Transfer\n" # "(2) RCT - Reinhard Color Transfer (uses L*a*b* colorspace)\n" # "(3) RCT-c - RCT, clipping LAB values outside of range instead of scaling\n" # "(4) RCT-p - RCT, preserving the paper's method of sTar/sSrc, instead of sSrc/sTar\n" # "(5) RCT-pc - RCT with both clipping and preserve paper\n" # "(6) mRCT - Masked RCT, computed using only the masked portion of faces\n" # "(7) mRCT-c - Masked RCT with clipping\n" # "(8) mRCT-p - Masked RCT with preserve paper\n" # "(9) mRCT-pc - Masked RCT with both clipping and preserve paper"), # ColorTransferMode.NONE, ColorTransferMode.MASKED_RCT_PAPER_CLIP) + + self.color_transfer_mode = np.clip(io.input_int( + "Apply color transfer to predicted face? (0) None, (1) LCT, (2) RCT-legacy ?:help skip:%s) : " % default_apply_random_ct, + default_apply_random_ct, + help_message="Increase variability of src samples by applying color transfer from random dst " + "samples. It is like 'face_style' learning, but with more precise color transfer and without " + "risk of model collapse. It does not require additional GPU resources, " + "but the training time may be longer because the src faceset becomes more " + "diverse.\n\n" + "===Modes===:\n" + "(0) None - No transformation\n" + "(1) LCT - Linear Color Transfer\n" + "(2) RCT - Reinhard Color Transfer (legacy/original version)"), + ColorTransferMode.NONE, ColorTransferMode.RCT) self.super_resolution = io.input_bool("Apply super resolution? (y/n ?:help skip:n) : ", False, help_message="Enhance details by applying DCSCN network.")
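Aside for reviewers: a standalone sketch (not part of the patch) of how the np.clip bounds above coerce an out-of-range answer into a valid ColorTransferMode; the value 7 is an arbitrary example:

    import numpy as np
    from samplelib.SampleProcessor import ColorTransferMode

    raw_answer = 7  # e.g. the user typed 7 at the 0..2 prompt
    mode = ColorTransferMode(int(np.clip(raw_answer, ColorTransferMode.NONE, ColorTransferMode.RCT)))
    assert mode is ColorTransferMode.RCT  # 7 is clipped down to 2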
@@ -322,41 +360,51 @@ class ConverterMasked(Converter): if debug: debugs += [img_mask_blurry_aaa.copy()] - if 'seamless' not in self.mode and self.color_transfer_mode is not None: - if self.color_transfer_mode == 'rct': + if 'seamless' not in self.mode and self.color_transfer_mode: + if self.color_transfer_mode: if debug: debugs += [np.clip(cv2.warpAffine(prd_face_bgr, face_output_mat, img_size, np.zeros(img_bgr.shape, dtype=np.float32), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4, cv2.BORDER_TRANSPARENT), 0, 1.0)] - prd_face_bgr = imagelib.reinhard_color_transfer( - prd_face_bgr, - dst_face_bgr, - source_mask=prd_face_mask_a, target_mask=prd_face_mask_a) + if self.color_transfer_mode == ColorTransferMode.LCT: + prd_face_bgr = imagelib.linear_color_transfer(prd_face_bgr, dst_face_bgr) - if debug: - debugs += [np.clip(cv2.warpAffine(prd_face_bgr, face_output_mat, img_size, - np.zeros(img_bgr.shape, dtype=np.float32), - cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4, - cv2.BORDER_TRANSPARENT), 0, 1.0)] + elif self.color_transfer_mode == ColorTransferMode.RCT: + prd_face_bgr = imagelib_legacy.reinhard_color_transfer( np.clip( (prd_face_bgr*255).astype(np.uint8), 0, 255), + np.clip( (dst_face_bgr*255).astype(np.uint8), 0, 255), + source_mask=prd_face_mask_a, target_mask=prd_face_mask_a) + prd_face_bgr = np.clip( prd_face_bgr.astype(np.float32) / 255.0, 0.0, 1.0) + # FIXME + # elif ColorTransferMode.RCT <= self.color_transfer_mode <= ColorTransferMode.MASKED_RCT_PAPER_CLIP: + # ct_options = { + # ColorTransferMode.RCT: (False, False, False), + # ColorTransferMode.RCT_CLIP: (False, False, True), + # ColorTransferMode.RCT_PAPER: (False, True, False), + # ColorTransferMode.RCT_PAPER_CLIP: (False, True, True), + # ColorTransferMode.MASKED_RCT: (True, False, False), + # ColorTransferMode.MASKED_RCT_CLIP: (True, False, True), + # ColorTransferMode.MASKED_RCT_PAPER: (True, True, False), + # ColorTransferMode.MASKED_RCT_PAPER_CLIP: (True, True, True), + # } + # + # use_masks, use_paper, use_clip = ct_options[self.color_transfer_mode] + # + # if not use_masks: + # prd_face_bgr = imagelib.reinhard_color_transfer(prd_face_bgr, dst_face_bgr, clip=use_clip, + # preserve_paper=use_paper) + # else: + # prd_face_bgr = imagelib.reinhard_color_transfer(prd_face_bgr, dst_face_bgr, clip=use_clip, + # preserve_paper=use_paper, source_mask=prd_face_mask_a, + # target_mask=prd_face_mask_a) - elif self.color_transfer_mode == 'lct': - if debug: - debugs += [np.clip(cv2.warpAffine(prd_face_bgr, face_output_mat, img_size, - np.zeros(img_bgr.shape, dtype=np.float32), - cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4, - cv2.BORDER_TRANSPARENT), 0, 1.0)] - - prd_face_bgr = imagelib.linear_color_transfer(prd_face_bgr, dst_face_bgr) - prd_face_bgr = np.clip(prd_face_bgr, 0.0, 1.0) - - if debug: - debugs += [np.clip(cv2.warpAffine(prd_face_bgr, face_output_mat, img_size, - np.zeros(img_bgr.shape, dtype=np.float32), - cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4, - cv2.BORDER_TRANSPARENT), 0, 1.0)] + if debug: + debugs += [np.clip(cv2.warpAffine(prd_face_bgr, face_output_mat, img_size, + np.zeros(img_bgr.shape, dtype=np.float32), + cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4, + cv2.BORDER_TRANSPARENT), 0, 1.0)] if self.mode == 'hist-match-bw': prd_face_bgr = cv2.cvtColor(prd_face_bgr, cv2.COLOR_BGR2GRAY) @@ -436,36 +484,47 @@ class ConverterMasked(Converter): out_img = np.clip(img_bgr * (1 - img_mask_blurry_aaa) + (out_img * img_mask_blurry_aaa), 0, 1.0)
- if 'seamless' in self.mode and self.color_transfer_mode is not None: + if 'seamless' in self.mode and self.color_transfer_mode: out_face_bgr = cv2.warpAffine(out_img, face_mat, (output_size, output_size)) - if self.color_transfer_mode == 'rct': + if self.color_transfer_mode: if debug: debugs += [np.clip(cv2.warpAffine(out_face_bgr, face_output_mat, img_size, np.zeros(img_bgr.shape, dtype=np.float32), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4, cv2.BORDER_TRANSPARENT), 0, 1.0)] - new_out_face_bgr = imagelib.reinhard_color_transfer( - out_face_bgr, - dst_face_bgr, - source_mask=face_mask_blurry_aaa, target_mask=face_mask_blurry_aaa) + if self.color_transfer_mode == ColorTransferMode.LCT: + new_out_face_bgr = imagelib.linear_color_transfer(out_face_bgr, dst_face_bgr) - if debug: - debugs += [np.clip(cv2.warpAffine(new_out_face_bgr, face_output_mat, img_size, - np.zeros(img_bgr.shape, dtype=np.float32), - cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4, - cv2.BORDER_TRANSPARENT), 0, 1.0)] + elif self.color_transfer_mode == ColorTransferMode.RCT: + new_out_face_bgr = imagelib_legacy.reinhard_color_transfer( np.clip( (out_face_bgr*255).astype(np.uint8), 0, 255), + np.clip( (dst_face_bgr*255).astype(np.uint8), 0, 255), + source_mask=face_mask_blurry_aaa, target_mask=face_mask_blurry_aaa) + new_out_face_bgr = np.clip( new_out_face_bgr.astype(np.float32) / 255.0, 0.0, 1.0) - elif self.color_transfer_mode == 'lct': - if debug: - debugs += [np.clip(cv2.warpAffine(out_face_bgr, face_output_mat, img_size, - np.zeros(img_bgr.shape, dtype=np.float32), - cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4, - cv2.BORDER_TRANSPARENT), 0, 1.0)] - - new_out_face_bgr = imagelib.linear_color_transfer(out_face_bgr, dst_face_bgr) - new_out_face_bgr = np.clip(new_out_face_bgr, 0.0, 1.0) + # FIXME + # elif ColorTransferMode.RCT <= self.color_transfer_mode <= ColorTransferMode.MASKED_RCT_PAPER_CLIP: + # ct_options = { + # ColorTransferMode.RCT: (False, False, False), + # ColorTransferMode.RCT_CLIP: (False, False, True), + # ColorTransferMode.RCT_PAPER: (False, True, False), + # ColorTransferMode.RCT_PAPER_CLIP: (False, True, True), + # ColorTransferMode.MASKED_RCT: (True, False, False), + # ColorTransferMode.MASKED_RCT_CLIP: (True, False, True), + # ColorTransferMode.MASKED_RCT_PAPER: (True, True, False), + # ColorTransferMode.MASKED_RCT_PAPER_CLIP: (True, True, True), + # } + # + # use_masks, use_paper, use_clip = ct_options[self.color_transfer_mode] + # + # if not use_masks: + # new_out_face_bgr = imagelib.reinhard_color_transfer(out_face_bgr, dst_face_bgr, clip=use_clip, + # preserve_paper=use_paper) + # else: + # new_out_face_bgr = imagelib.reinhard_color_transfer(out_face_bgr, dst_face_bgr, clip=use_clip, + # preserve_paper=use_paper, source_mask=face_mask_blurry_aaa, + # target_mask=face_mask_blurry_aaa) if debug: debugs += [np.clip(cv2.warpAffine(new_out_face_bgr, face_output_mat, img_size,
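Note for reviewers: in both APIs the first positional argument is the face being recolored, but the parameter names are swapped; the new imagelib.reinhard_color_transfer (below) is (source, target) on float32 BGR images in [0, 1], while imagelib_legacy keeps (target, source) on uint8. A hedged sketch of the two calling conventions used in the converter above, with random arrays standing in for faces:

    import numpy as np
    import imagelib
    import imagelib_legacy

    prd_face_bgr = np.random.rand(64, 64, 3).astype(np.float32)  # float32 in [0, 1]
    dst_face_bgr = np.random.rand(64, 64, 3).astype(np.float32)

    # new API: float32 in, float32 out
    out_new = imagelib.reinhard_color_transfer(prd_face_bgr, dst_face_bgr)

    # legacy API: uint8 in, uint8 out
    out_legacy = imagelib_legacy.reinhard_color_transfer(
        np.clip(prd_face_bgr * 255, 0, 255).astype(np.uint8),
        np.clip(dst_face_bgr * 255, 0, 255).astype(np.uint8))
    out_legacy = np.clip(out_legacy.astype(np.float32) / 255.0, 0.0, 1.0)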
diff --git a/imagelib/color_transfer.py b/imagelib/color_transfer.py index be580a0..772be6a 100644 --- a/imagelib/color_transfer.py +++ b/imagelib/color_transfer.py @@ -2,9 +2,9 @@ import numpy as np import cv2 -def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, source_mask=None, target_mask=None): +def reinhard_color_transfer(source, target, clip=False, preserve_paper=False, source_mask=None, target_mask=None): """ - Transfers the color distribution from the source to the target + Transfers the color distribution from the target to the source image using the mean and standard deviations of the L*a*b* color space. @@ -41,51 +41,49 @@ def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, so OpenCV image (w, h, 3) NumPy array (float32) """ - np.clip(source, 0, 1, out=source) - np.clip(target, 0, 1, out=target) + # np.clip(source, 0, 1, out=source) + # np.clip(target, 0, 1, out=target) # convert the images from the RGB to L*ab* color space, being # sure to utilizing the floating point data type (note: OpenCV # expects floats to be 32-bit, so use that instead of 64-bit) - source = cv2.cvtColor(source.astype(np.float32), cv2.COLOR_BGR2LAB) - target = cv2.cvtColor(target.astype(np.float32), cv2.COLOR_BGR2LAB) + source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB) + target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB) # compute color statistics for the source and target images - src_input = source if source_mask is None else source * source_mask - tgt_input = target if target_mask is None else target * target_mask - (lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = lab_image_stats(src_input) - (lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = lab_image_stats(tgt_input) + (lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = lab_image_stats(source, mask=source_mask) + (lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = lab_image_stats(target, mask=target_mask) - # subtract the means from the target image - (l, a, b) = cv2.split(target) - l -= lMeanTar - a -= aMeanTar - b -= bMeanTar + + # subtract the means from the source image + (l, a, b) = cv2.split(source) + l -= lMeanSrc + a -= aMeanSrc + b -= bMeanSrc if preserve_paper: # scale by the standard deviations using paper proposed factor - l = (lStdTar / lStdSrc) * l - a = (aStdTar / aStdSrc) * a - b = (bStdTar / bStdSrc) * b + l = (lStdTar / lStdSrc) * l if lStdSrc != 0 else l + a = (aStdTar / aStdSrc) * a if aStdSrc != 0 else a + b = (bStdTar / bStdSrc) * b if bStdSrc != 0 else b else: # scale by the standard deviations using reciprocal of paper proposed factor - l = (lStdSrc / lStdTar) * l - a = (aStdSrc / aStdTar) * a - b = (bStdSrc / bStdTar) * b + l = (lStdSrc / lStdTar) * l if lStdTar != 0 else l + a = (aStdSrc / aStdTar) * a if aStdTar != 0 else a + b = (bStdSrc / bStdTar) * b if bStdTar != 0 else b - # add in the source mean - l += lMeanSrc - a += aMeanSrc - b += bMeanSrc + # add in the target mean + l += lMeanTar + a += aMeanTar + b += bMeanTar - # clip/scale the pixel intensities to [0, 255] if they fall - # outside this range - l = _scale_array(l, clip=clip) - a = _scale_array(a, clip=clip) - b = _scale_array(b, clip=clip) + # clip/scale the pixel intensities if they fall + # outside the ranges for LAB + l = _scale_array(l, 0, 100, clip=clip, mask=source_mask) + a = _scale_array(a, -127, 127, clip=clip, mask=source_mask) + b = _scale_array(b, -127, 127, clip=clip, mask=source_mask) # merge the channels together and convert back to the RGB color - # space transfer = cv2.merge([l, a, b]) transfer = cv2.cvtColor(transfer, cv2.COLOR_LAB2BGR) np.clip(transfer, 0, 1, out=transfer) @@ -94,7 +92,7 @@ def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, so return transfer -def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5): +def linear_color_transfer(target_img, source_img, mode='sym', eps=1e-3): """ Matches the colour distribution of the target image to that of the source image using a linear transform.
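The defaults of linear_color_transfer change above from mode='pca', eps=1e-5 to mode='sym', eps=1e-3; the larger eps regularizes the channel covariances more strongly (the eps sweep in imagelib/test/color_transfer.py motivates the choice). Callers that depended on the old behaviour can pin the previous defaults explicitly; a minimal sketch, with random arrays standing in for faces:

    import numpy as np
    import imagelib

    target = np.random.rand(64, 64, 3).astype(np.float32)
    source = np.random.rand(64, 64, 3).astype(np.float32)

    out_new = imagelib.linear_color_transfer(target, source)                        # new defaults: 'sym', eps=1e-3
    out_old = imagelib.linear_color_transfer(target, source, mode='pca', eps=1e-5)  # previous defaults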
@@ -130,54 +128,81 @@ Qt_Cs_Qt = Qt.dot(Cs).dot(Qt) eva_QtCsQt, eve_QtCsQt = np.linalg.eigh(Qt_Cs_Qt) QtCsQt = eve_QtCsQt.dot(np.sqrt(np.diag(eva_QtCsQt))).dot(eve_QtCsQt.T) - ts = np.linalg.inv(Qt).dot(QtCsQt).dot(np.linalg.pinv(Qt)).dot(t) + ts = np.linalg.pinv(Qt).dot(QtCsQt).dot(np.linalg.pinv(Qt)).dot(t) matched_img = ts.reshape(*target_img.transpose(2, 0, 1).shape).transpose(1, 2, 0) matched_img += mu_s - matched_img[matched_img > 1] = 1 - matched_img[matched_img < 0] = 0 + np.clip(matched_img, 0, 1, out=matched_img) return matched_img -def linear_lab_color_transform(target_img, source_img, eps=1e-5, mode='pca'): - """doesn't work yet""" - np.clip(source_img, 0, 1, out=source_img) - np.clip(target_img, 0, 1, out=target_img) - - # convert the images from the RGB to L*ab* color space, being - # sure to utilizing the floating point data type (note: OpenCV - # expects floats to be 32-bit, so use that instead of 64-bit) - source_img = cv2.cvtColor(source_img.astype(np.float32), cv2.COLOR_BGR2LAB) - target_img = cv2.cvtColor(target_img.astype(np.float32), cv2.COLOR_BGR2LAB) - - target_img = linear_color_transfer(target_img, source_img, mode=mode, eps=eps) - target_img = cv2.cvtColor(np.clip(target_img, 0, 1).astype(np.float32), cv2.COLOR_LAB2BGR) - np.clip(target_img, 0, 1, out=target_img) - return target_img - - -def lab_image_stats(image): +def lab_image_stats(image, mask=None): # compute the mean and standard deviation of each channel - (l, a, b) = cv2.split(image) - (lMean, lStd) = (l.mean(), l.std()) - (aMean, aStd) = (a.mean(), a.std()) - (bMean, bStd) = (b.mean(), b.std()) + l, a, b = cv2.split(image) + + if mask is not None: + im_mask = np.squeeze(mask) if len(np.shape(mask)) == 3 else mask + l, a, b = l[im_mask == 1], a[im_mask == 1], b[im_mask == 1] + + l_mean, l_std = np.mean(l), np.std(l) + a_mean, a_std = np.mean(a), np.std(a) + b_mean, b_std = np.mean(b), np.std(b) # return the color statistics - return (lMean, lStd, aMean, aStd, bMean, bStd) + return l_mean, l_std, a_mean, a_std, b_mean, b_std -def _scale_array(arr, clip=True): - if clip: - return np.clip(arr, 0, 255) - +def _min_max_scale(arr, new_range=(0, 255)): + """ + Perform min-max scaling to a NumPy array + Parameters: + ------- + arr: NumPy array to be scaled to [new_min, new_max] range + new_range: tuple of form (min, max) specifying range of + transformed array + Returns: + ------- + NumPy array that has been scaled to be in + [new_range[0], new_range[1]] range + """ + # get array's current min and max mn = arr.min() mx = arr.max() - scale_range = (max([mn, 0]), min([mx, 255])) - if mn < scale_range[0] or mx > scale_range[1]: - return (scale_range[1] - scale_range[0]) * (arr - mn) / (mx - mn) + scale_range[0] + # check if scaling needs to be done to be in new_range + if mn < new_range[0] or mx > new_range[1]: + # perform min-max scaling + scaled = (new_range[1] - new_range[0]) * (arr - mn) / (mx - mn) + new_range[0] + else: + # return array if already in range + scaled = arr - return arr + return scaled
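A small sanity check of the new mask parameter on lab_image_stats (illustrative values only; pixels where the mask equals 1 are the ones counted):

    import cv2
    import numpy as np
    from imagelib.color_transfer import lab_image_stats

    lab = cv2.cvtColor(np.random.rand(4, 4, 3).astype(np.float32), cv2.COLOR_BGR2LAB)
    mask = np.zeros((4, 4, 1), dtype=np.float32)
    mask[1:3, 1:3] = 1.0  # statistics should come from this 2x2 patch only

    stats_full = lab_image_stats(lab)
    stats_masked = lab_image_stats(lab, mask=mask)  # means/stds of the masked patch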
+ + +def _scale_array(arr, mn, mx, clip=True, mask=None): + """ + Trim NumPy array values to be in the [mn, mx] range, with the option of + clipping or scaling. + Parameters: + ------- + arr: array to be trimmed to the [mn, mx] range + mn: lower bound of the allowed range + mx: upper bound of the allowed range + clip: should the array be clipped by np.clip? If False then the input + array will be min-max scaled to the range + [max([arr.min(), mn]), min([arr.max(), mx])] + mask: optional mask; when scaling, the scale range is computed from the + masked portion of arr + Returns: + ------- + NumPy array that has been scaled to be in the [mn, mx] range + """ + if clip: + scaled = np.clip(arr, mn, mx) + else: + if mask is not None: + scale_range = (max([np.min(mask * arr), mn]), min([np.max(mask * arr), mx])) + else: + scale_range = (max([np.min(arr), mn]), min([np.max(arr), mx])) + scaled = _min_max_scale(arr, new_range=scale_range) + + return scaled def channel_hist_match(source, template, hist_match_threshold=255, mask=None): diff --git a/imagelib/test/color_transfer.py b/imagelib/test/color_transfer.py new file mode 100644 index 0000000..0705c95 --- /dev/null +++ b/imagelib/test/color_transfer.py @@ -0,0 +1,95 @@ +import unittest + +import cv2 +import numpy as np + +from facelib import LandmarksProcessor +from imagelib import reinhard_color_transfer +from imagelib.color_transfer import _scale_array, lab_image_stats, linear_color_transfer +from interact.interact import InteractDesktop +from samplelib import SampleLoader, SampleType + + +class ColorTransfer(unittest.TestCase): + def test_algorithms(self): + src_samples = SampleLoader.load(SampleType.FACE, './test_src', None) + dst_samples = SampleLoader.load(SampleType.FACE, './test_dst', None) + + for src_sample in src_samples: + src_img = src_sample.load_bgr() + src_mask = src_sample.load_mask() + + # Toggle to see masks + show_masks = False + + grid = [] + for ct_sample in dst_samples: + print(src_sample.filename, ct_sample.filename) + ct_img = ct_sample.load_bgr() + ct_mask = ct_sample.load_mask() + + lct_img = linear_color_transfer(src_img, ct_img) + rct_img = reinhard_color_transfer(src_img, ct_img) + rct_img_clip = reinhard_color_transfer(src_img, ct_img, clip=True) + rct_img_paper = reinhard_color_transfer(src_img, ct_img, preserve_paper=True) + rct_img_paper_clip = reinhard_color_transfer(src_img, ct_img, clip=True, preserve_paper=True) + + masked_rct_img = reinhard_color_transfer(src_img, ct_img, source_mask=src_mask, target_mask=ct_mask) + masked_rct_img_clip = reinhard_color_transfer(src_img, ct_img, clip=True, source_mask=src_mask, target_mask=ct_mask) + masked_rct_img_paper = reinhard_color_transfer(src_img, ct_img, preserve_paper=True, source_mask=src_mask, target_mask=ct_mask) + masked_rct_img_paper_clip = reinhard_color_transfer(src_img, ct_img, clip=True, preserve_paper=True, source_mask=src_mask, target_mask=ct_mask) + + results = [lct_img, rct_img, rct_img_clip, rct_img_paper, rct_img_paper_clip, + masked_rct_img, masked_rct_img_clip, masked_rct_img_paper, masked_rct_img_paper_clip] + + if show_masks: + results = [src_mask * im for im in results] + src_img *= src_mask + ct_img *= ct_mask + + results = np.concatenate((src_img, ct_img, *results), axis=1) + grid.append(results) + + cv2.namedWindow('test output', cv2.WINDOW_NORMAL) + cv2.imshow('test output', np.concatenate(grid, axis=0)) + cv2.waitKey(0) + cv2.destroyAllWindows() + + def test_lct_algorithms(self): + src_samples = SampleLoader.load(SampleType.FACE, './test_src', None) + dst_samples = SampleLoader.load(SampleType.FACE, './test_dst', None) + + for src_sample in src_samples: + src_img = src_sample.load_bgr() + src_mask = src_sample.load_mask() + + # Toggle to see masks + show_masks = True + + grid = [] + for ct_sample in dst_samples: + print(src_sample.filename, ct_sample.filename) + ct_img = ct_sample.load_bgr() + ct_mask = ct_sample.load_mask() + + results = [] + for mode in ['sym']: + for eps in [10**-n for n in range(1, 
10, 2)]: + results.append(linear_color_transfer(src_img, ct_img, mode=mode, eps=eps)) + + if show_masks: + results = [src_mask * im for im in results] + src_img *= src_mask + ct_img *= ct_mask + + results = np.concatenate((src_img, ct_img, *results), axis=1) + grid.append(results) + + cv2.namedWindow('test output', cv2.WINDOW_NORMAL) + cv2.imshow('test output', np.concatenate(grid, axis=0)) + cv2.waitKey(0) + cv2.destroyAllWindows() + + +if __name__ == '__main__': + unittest.main() diff --git a/imagelib/test/test output_screenshot_13.08.2019-1.png b/imagelib/test/test output_screenshot_13.08.2019-1.png new file mode 100644 index 0000000..e9bb0ae Binary files /dev/null and b/imagelib/test/test output_screenshot_13.08.2019-1.png differ diff --git a/imagelib/test/test output_screenshot_13.08.2019-2.png b/imagelib/test/test output_screenshot_13.08.2019-2.png new file mode 100644 index 0000000..be729f4 Binary files /dev/null and b/imagelib/test/test output_screenshot_13.08.2019-2.png differ diff --git a/imagelib/test/test output_screenshot_13.08.2019-3.png b/imagelib/test/test output_screenshot_13.08.2019-3.png new file mode 100644 index 0000000..ed475d5 Binary files /dev/null and b/imagelib/test/test output_screenshot_13.08.2019-3.png differ diff --git a/imagelib/test/test output_screenshot_14.08.2019-lca.png b/imagelib/test/test output_screenshot_14.08.2019-lca.png new file mode 100644 index 0000000..c02611f Binary files /dev/null and b/imagelib/test/test output_screenshot_14.08.2019-lca.png differ diff --git a/imagelib/test/test output_screenshot_14.08.2019-lct-2.png b/imagelib/test/test output_screenshot_14.08.2019-lct-2.png new file mode 100644 index 0000000..1798166 Binary files /dev/null and b/imagelib/test/test output_screenshot_14.08.2019-lct-2.png differ diff --git a/imagelib/test/test output_screenshot_14.08.2019-lct-eps-2.png b/imagelib/test/test output_screenshot_14.08.2019-lct-eps-2.png new file mode 100644 index 0000000..840b96c Binary files /dev/null and b/imagelib/test/test output_screenshot_14.08.2019-lct-eps-2.png differ diff --git a/imagelib/test/test output_screenshot_14.08.2019-lct-eps.png b/imagelib/test/test output_screenshot_14.08.2019-lct-eps.png new file mode 100644 index 0000000..25b28c2 Binary files /dev/null and b/imagelib/test/test output_screenshot_14.08.2019-lct-eps.png differ diff --git a/imagelib/test/test_dst/00109.jpg b/imagelib/test/test_dst/00109.jpg new file mode 100644 index 0000000..0a56e54 Binary files /dev/null and b/imagelib/test/test_dst/00109.jpg differ diff --git a/imagelib/test/test_dst/00283_0.jpg b/imagelib/test/test_dst/00283_0.jpg new file mode 100644 index 0000000..be361b1 Binary files /dev/null and b/imagelib/test/test_dst/00283_0.jpg differ diff --git a/imagelib/test/test_dst/00471_0.jpg b/imagelib/test/test_dst/00471_0.jpg new file mode 100644 index 0000000..6530f92 Binary files /dev/null and b/imagelib/test/test_dst/00471_0.jpg differ diff --git a/imagelib/test/test_dst/00730_0.jpg b/imagelib/test/test_dst/00730_0.jpg new file mode 100644 index 0000000..6ff81ec Binary files /dev/null and b/imagelib/test/test_dst/00730_0.jpg differ diff --git a/imagelib/test/test_dst/00752_0.jpg b/imagelib/test/test_dst/00752_0.jpg new file mode 100644 index 0000000..3d26678 Binary files /dev/null and b/imagelib/test/test_dst/00752_0.jpg differ diff --git a/imagelib/test/test_dst/01038_0.jpg b/imagelib/test/test_dst/01038_0.jpg new file mode 100644 index 0000000..4bba8bb Binary files /dev/null and b/imagelib/test/test_dst/01038_0.jpg differ diff --git 
a/imagelib/test/test_dst/01397_0.jpg b/imagelib/test/test_dst/01397_0.jpg new file mode 100644 index 0000000..4ac23da Binary files /dev/null and b/imagelib/test/test_dst/01397_0.jpg differ diff --git a/imagelib/test/test_dst/01730_0.jpg b/imagelib/test/test_dst/01730_0.jpg new file mode 100644 index 0000000..bad2915 Binary files /dev/null and b/imagelib/test/test_dst/01730_0.jpg differ diff --git a/imagelib/test/test_dst/01905_0.jpg b/imagelib/test/test_dst/01905_0.jpg new file mode 100644 index 0000000..a69f2c5 Binary files /dev/null and b/imagelib/test/test_dst/01905_0.jpg differ diff --git a/imagelib/test/test_src/03242.jpg b/imagelib/test/test_src/03242.jpg new file mode 100644 index 0000000..49a732c Binary files /dev/null and b/imagelib/test/test_src/03242.jpg differ diff --git a/imagelib/test/test_src/03255.jpg b/imagelib/test/test_src/03255.jpg new file mode 100644 index 0000000..6cc223d Binary files /dev/null and b/imagelib/test/test_src/03255.jpg differ diff --git a/imagelib/test/test_src/04710.jpg b/imagelib/test/test_src/04710.jpg new file mode 100644 index 0000000..aae2bd4 Binary files /dev/null and b/imagelib/test/test_src/04710.jpg differ diff --git a/imagelib_legacy/__init__.py b/imagelib_legacy/__init__.py new file mode 100644 index 0000000..6281e0e --- /dev/null +++ b/imagelib_legacy/__init__.py @@ -0,0 +1,3 @@ +from .color_transfer import color_hist_match +from .color_transfer import reinhard_color_transfer +from .color_transfer import linear_color_transfer diff --git a/imagelib_legacy/color_transfer.py b/imagelib_legacy/color_transfer.py new file mode 100644 index 0000000..545df3a --- /dev/null +++ b/imagelib_legacy/color_transfer.py @@ -0,0 +1,191 @@ +import numpy as np +import cv2 + +def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, source_mask=None, target_mask=None): + """ + Transfers the color distribution from the source to the target + image using the mean and standard deviations of the L*a*b* + color space. + + This implementation is (loosely) based on the "Color Transfer + between Images" paper by Reinhard et al., 2001. + + Parameters: + ------- + source: NumPy array + OpenCV image in BGR color space (the source image) + target: NumPy array + OpenCV image in BGR color space (the target image) + clip: Should components of L*a*b* image be scaled by np.clip before + converting back to BGR color space? + If False then components will be min-max scaled appropriately. + Clipping will keep target image brightness truer to the input. + Scaling will adjust image brightness to avoid washed out portions + in the resulting color transfer that can be caused by clipping. + preserve_paper: Should color transfer strictly follow methodology + laid out in original paper? The method does not always produce + aesthetically pleasing results. + If False then L*a*b* components will be scaled using the reciprocal of + the scaling factor proposed in the paper. 
This method seems to produce + more consistently aesthetically pleasing results + + Returns: + ------- + transfer: NumPy array + OpenCV image (w, h, 3) NumPy array (uint8) + """ + + + # convert the images from the RGB to L*ab* color space, being + # sure to utilizing the floating point data type (note: OpenCV + # expects floats to be 32-bit, so use that instead of 64-bit) + source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32) + target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32) + + # compute color statistics for the source and target images + src_input = source if source_mask is None else source*source_mask + tgt_input = target if target_mask is None else target*target_mask + (lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = lab_image_stats(src_input) + (lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = lab_image_stats(tgt_input) + + # subtract the means from the target image + (l, a, b) = cv2.split(target) + l -= lMeanTar + a -= aMeanTar + b -= bMeanTar + + if preserve_paper: + # scale by the standard deviations using paper proposed factor + l = (lStdTar / lStdSrc) * l + a = (aStdTar / aStdSrc) * a + b = (bStdTar / bStdSrc) * b + else: + # scale by the standard deviations using reciprocal of paper proposed factor + l = (lStdSrc / lStdTar) * l + a = (aStdSrc / aStdTar) * a + b = (bStdSrc / bStdTar) * b + + # add in the source mean + l += lMeanSrc + a += aMeanSrc + b += bMeanSrc + + # clip/scale the pixel intensities to [0, 255] if they fall + # outside this range + l = _scale_array(l, clip=clip) + a = _scale_array(a, clip=clip) + b = _scale_array(b, clip=clip) + + # merge the channels together and convert back to the RGB color + # space, being sure to utilize the 8-bit unsigned integer data + # type + transfer = cv2.merge([l, a, b]) + transfer = cv2.cvtColor(transfer.astype(np.uint8), cv2.COLOR_LAB2BGR) + + # return the color transferred image + return transfer + +def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5): + ''' + Matches the colour distribution of the target image to that of the source image + using a linear transform. + Images are expected to be of form (w,h,c) and float in [0,1]. + Modes are chol, pca or sym for different choices of basis. 
+ ''' + mu_t = target_img.mean(0).mean(0) + t = target_img - mu_t + t = t.transpose(2,0,1).reshape(3,-1) + Ct = t.dot(t.T) / t.shape[1] + eps * np.eye(t.shape[0]) + mu_s = source_img.mean(0).mean(0) + s = source_img - mu_s + s = s.transpose(2,0,1).reshape(3,-1) + Cs = s.dot(s.T) / s.shape[1] + eps * np.eye(s.shape[0]) + if mode == 'chol': + chol_t = np.linalg.cholesky(Ct) + chol_s = np.linalg.cholesky(Cs) + ts = chol_s.dot(np.linalg.inv(chol_t)).dot(t) + if mode == 'pca': + eva_t, eve_t = np.linalg.eigh(Ct) + Qt = eve_t.dot(np.sqrt(np.diag(eva_t))).dot(eve_t.T) + eva_s, eve_s = np.linalg.eigh(Cs) + Qs = eve_s.dot(np.sqrt(np.diag(eva_s))).dot(eve_s.T) + ts = Qs.dot(np.linalg.inv(Qt)).dot(t) + if mode == 'sym': + eva_t, eve_t = np.linalg.eigh(Ct) + Qt = eve_t.dot(np.sqrt(np.diag(eva_t))).dot(eve_t.T) + Qt_Cs_Qt = Qt.dot(Cs).dot(Qt) + eva_QtCsQt, eve_QtCsQt = np.linalg.eigh(Qt_Cs_Qt) + QtCsQt = eve_QtCsQt.dot(np.sqrt(np.diag(eva_QtCsQt))).dot(eve_QtCsQt.T) + ts = np.linalg.inv(Qt).dot(QtCsQt).dot(np.linalg.inv(Qt)).dot(t) + matched_img = ts.reshape(*target_img.transpose(2,0,1).shape).transpose(1,2,0) + matched_img += mu_s + matched_img[matched_img>1] = 1 + matched_img[matched_img<0] = 0 + return matched_img + +def lab_image_stats(image): + # compute the mean and standard deviation of each channel + (l, a, b) = cv2.split(image) + (lMean, lStd) = (l.mean(), l.std()) + (aMean, aStd) = (a.mean(), a.std()) + (bMean, bStd) = (b.mean(), b.std()) + + # return the color statistics + return (lMean, lStd, aMean, aStd, bMean, bStd) + +def _scale_array(arr, clip=True): + if clip: + return np.clip(arr, 0, 255) + + mn = arr.min() + mx = arr.max() + scale_range = (max([mn, 0]), min([mx, 255])) + + if mn < scale_range[0] or mx > scale_range[1]: + return (scale_range[1] - scale_range[0]) * (arr - mn) / (mx - mn) + scale_range[0] + + return arr + +def channel_hist_match(source, template, hist_match_threshold=255, mask=None): + # Code borrowed from: + # https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x + masked_source = source + masked_template = template + + if mask is not None: + masked_source = source * mask + masked_template = template * mask + + oldshape = source.shape + source = source.ravel() + template = template.ravel() + masked_source = masked_source.ravel() + masked_template = masked_template.ravel() + s_values, bin_idx, s_counts = np.unique(source, return_inverse=True, + return_counts=True) + t_values, t_counts = np.unique(template, return_counts=True) + ms_values, mbin_idx, ms_counts = np.unique(source, return_inverse=True, + return_counts=True) + mt_values, mt_counts = np.unique(template, return_counts=True) + + s_quantiles = np.cumsum(s_counts).astype(np.float64) + s_quantiles = hist_match_threshold * s_quantiles / s_quantiles[-1] + t_quantiles = np.cumsum(t_counts).astype(np.float64) + t_quantiles = 255 * t_quantiles / t_quantiles[-1] + interp_t_values = np.interp(s_quantiles, t_quantiles, t_values) + + return interp_t_values[bin_idx].reshape(oldshape) + +def color_hist_match(src_im, tar_im, hist_match_threshold=255): + h,w,c = src_im.shape + matched_R = channel_hist_match(src_im[:,:,0], tar_im[:,:,0], hist_match_threshold, None) + matched_G = channel_hist_match(src_im[:,:,1], tar_im[:,:,1], hist_match_threshold, None) + matched_B = channel_hist_match(src_im[:,:,2], tar_im[:,:,2], hist_match_threshold, None) + + to_stack = (matched_R, matched_G, matched_B) + for i in range(3, c): + to_stack += ( src_im[:,:,i],) + + + matched = np.stack(to_stack, 
axis=-1).astype(src_im.dtype) return matched diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 8f492d2..3c0d375 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -7,10 +7,13 @@ from facelib import FaceType from samplelib import * from interact import interact as io +from samplelib.SampleProcessor import ColorTransferMode # SAE - Styled AutoEncoder + + class SAEModel(ModelBase): encoderH5 = 'encoder.h5' inter_BH5 = 'inter_B.h5' @@ -121,11 +124,16 @@ class SAEModel(ModelBase): help_message="Learn to transfer image around face. This can make face more like dst. Enabling this option increases the chance of model collapse."), 0.0, 100.0) - default_apply_random_ct = False if is_first_run else self.options.get('apply_random_ct', False) - self.options['apply_random_ct'] = io.input_bool( - "Apply random color transfer to src faceset? (y/n, ?:help skip:%s) : " % ( - yn_str[bool(default_apply_random_ct)]), bool(default_apply_random_ct), - help_message="Increase variativity of src samples by apply LCT color transfer from random dst samples. It is like 'face_style' learning, but more precise color transfer and without risk of model collapse, also it does not require additional GPU resources, but the training time may be longer, due to the src faceset is becoming more diverse.") + default_apply_random_ct = ColorTransferMode.NONE if is_first_run else self.options.get('apply_random_ct', ColorTransferMode.NONE) + self.options['apply_random_ct'] = np.clip(io.input_int( + "Apply random color transfer to src faceset? (0) None, (1) LCT, (2) RCT, (3) RCT-c, (4) RCT-p, " + "(5) RCT-pc, (6) mRCT, (7) mRCT-c, (8) mRCT-p, (9) mRCT-pc ?:help skip:%s) : " % default_apply_random_ct, + default_apply_random_ct, + help_message="Increase variability of src samples by applying color transfer from random dst " + "samples. It is like 'face_style' learning, but with more precise color transfer and without " + "risk of model collapse. It does not require additional GPU resources, " + "but the training time may be longer because the src faceset becomes more diverse."), + ColorTransferMode.NONE, ColorTransferMode.MASKED_RCT_PAPER_CLIP)
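Reviewer note: 'apply_random_ct' was previously stored as a bool, so options saved by older models will hand True/False to the new integer option. A hypothetical migration shim (not part of this patch; mapping True to LCT is only an example choice, matching the transfer the old help text advertised):

    from samplelib.SampleProcessor import ColorTransferMode

    def coerce_apply_random_ct(value):
        # hypothetical helper: fold a legacy bool option into the new enum
        if isinstance(value, bool):
            return ColorTransferMode.LCT if value else ColorTransferMode.NONE
        return ColorTransferMode(int(value))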
if nnlib.device.backend != 'plaidML': # todo https://github.com/plaidml/plaidml/issues/301 default_clipgrad = False if is_first_run else self.options.get('clipgrad', False) @@ -139,7 +147,7 @@ class SAEModel(ModelBase): self.options['pixel_loss'] = self.options.get('pixel_loss', False) self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power) self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power) - self.options['apply_random_ct'] = self.options.get('apply_random_ct', False) + self.options['apply_random_ct'] = self.options.get('apply_random_ct', ColorTransferMode.NONE) self.options['clipgrad'] = self.options.get('clipgrad', False) if is_first_run: @@ -170,7 +178,8 @@ class SAEModel(ModelBase): self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1 - apply_random_ct = self.options.get('apply_random_ct', False) + + apply_random_ct = self.options.get('apply_random_ct', ColorTransferMode.NONE) masked_training = True warped_src = Input(bgr_shape) @@ -463,7 +472,7 @@ class SAEModel(ModelBase): self.set_training_data_generators([ SampleGeneratorFace(training_data_src_path, sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None, - random_ct_samples_path=training_data_dst_path if apply_random_ct else None, + random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None, debug=self.is_debug(), batch_size=self.batch_size, sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, @@ -604,6 +613,7 @@ class SAEModel(ModelBase): face_type = FaceType.FULL if self.options['face_type'] == 'f' else FaceType.HALF from converters import ConverterMasked + return ConverterMasked(self.predictor_func, predictor_input_size=self.options['resolution'], predictor_masked=self.options['learn_mask'], diff --git a/samplelib/Sample.py b/samplelib/Sample.py index d22a869..1b85f9f 100644 --- a/samplelib/Sample.py +++ b/samplelib/Sample.py @@ -4,6 +4,7 @@ from pathlib import Path import cv2 import numpy as np +from facelib import LandmarksProcessor from utils.cv2_utils import * from utils.DFLJPG import DFLJPG from utils.DFLPNG import DFLPNG @@ -68,6 +69,16 @@ class Sample(object): return None + def load_image_hull_mask(self): + return LandmarksProcessor.get_image_hull_mask(self.load_bgr().shape, self.landmarks) + + def load_mask(self): + mask = self.load_fanseg_mask() + # 'or' is ambiguous on a NumPy array, so test for None explicitly + if mask is None: + mask = self.load_image_hull_mask() + return mask + def get_random_close_target_sample(self): if self.close_target_list is None: return None diff --git a/samplelib/SampleGeneratorFace.py b/samplelib/SampleGeneratorFace.py index 87a7e5b..16f8f47 100644 --- a/samplelib/SampleGeneratorFace.py +++ b/samplelib/SampleGeneratorFace.py @@ -17,8 +17,13 @@ output_sample_types = [ ... 
] ''' + + class SampleGeneratorFace(SampleGeneratorBase): - def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs): + def __init__(self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, + random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(), + output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, + **kwargs): super().__init__(samples_path, debug, batch_size) self.sample_process_options = sample_process_options self.output_sample_types = output_sample_types @@ -36,17 +41,20 @@ class SampleGeneratorFace(SampleGeneratorBase): self.generators_random_seed = generators_random_seed - samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path) + samples = SampleLoader.load(self.sample_type, self.samples_path, sort_by_yaw_target_samples_path) - ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) if random_ct_samples_path is not None else None + ct_samples = SampleLoader.load(SampleType.FACE, + random_ct_samples_path) if random_ct_samples_path is not None else None self.random_ct_sample_chance = 100 if self.debug: self.generators_count = 1 - self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples, ct_samples) )] + self.generators = [iter_utils.ThisThreadGenerator(self.batch_func, (0, samples, ct_samples))] else: - self.generators_count = min ( generators_count, len(samples) ) - self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count], ct_samples ) ) for i in range(self.generators_count) ] + self.generators_count = min(generators_count, len(samples)) + self.generators = [ + iter_utils.SubprocessGenerator(self.batch_func, (i, samples[i::self.generators_count], ct_samples)) for + i in range(self.generators_count)] self.generator_counter = -1 @@ -55,14 +63,14 @@ class SampleGeneratorFace(SampleGeneratorBase): def __next__(self): self.generator_counter += 1 - generator = self.generators[self.generator_counter % len(self.generators) ] + generator = self.generators[self.generator_counter % len(self.generators)] return next(generator) - def batch_func(self, param ): + def batch_func(self, param): generator_id, samples, ct_samples = param if self.generators_random_seed is not None: - np.random.seed ( self.generators_random_seed[generator_id] ) + np.random.seed(self.generators_random_seed[generator_id]) samples_len = len(samples) samples_idxs = [*range(samples_len)] @@ -73,14 +81,14 @@ class SampleGeneratorFace(SampleGeneratorBase): raise ValueError('No training data provided.') if self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET: - if all ( [ samples[idx] == None for idx in samples_idxs] ): + if all([samples[idx] == None for idx in samples_idxs]): raise ValueError('Not enough training data. 
Gather more faces!') if self.sample_type == SampleType.FACE: shuffle_idxs = [] elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET: shuffle_idxs = [] - shuffle_idxs_2D = [[]]*samples_len + shuffle_idxs_2D = [[]] * samples_len while True: batches = None @@ -94,7 +102,7 @@ class SampleGeneratorFace(SampleGeneratorBase): np.random.shuffle(shuffle_idxs) idx = shuffle_idxs.pop() - sample = samples[ idx ] + sample = samples[idx] elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET: if len(shuffle_idxs) == 0: @@ -104,8 +112,8 @@ class SampleGeneratorFace(SampleGeneratorBase): idx = shuffle_idxs.pop() if samples[idx] != None: if len(shuffle_idxs_2D[idx]) == 0: - a = shuffle_idxs_2D[idx] = [ *range(len(samples[idx])) ] - np.random.shuffle (a) + a = shuffle_idxs_2D[idx] = [*range(len(samples[idx]))] + np.random.shuffle(a) idx2 = shuffle_idxs_2D[idx].pop() sample = samples[idx][idx2] @@ -114,32 +122,34 @@ class SampleGeneratorFace(SampleGeneratorBase): if sample is not None: try: - ct_sample=None - if ct_samples is not None: + ct_sample = None + if ct_samples is not None: if np.random.randint(100) < self.random_ct_sample_chance: - ct_sample=ct_samples[np.random.randint(ct_samples_len)] - - x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample) + ct_sample = ct_samples[np.random.randint(ct_samples_len)] + + x = SampleProcessor.process(sample, self.sample_process_options, self.output_sample_types, + self.debug, ct_sample=ct_sample) except: - raise Exception ("Exception occurred in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) ) + raise Exception( + "Exception occurred in sample %s. 
Error: %s" % (sample.filename, traceback.format_exc())) if type(x) != tuple and type(x) != list: raise Exception('SampleProcessor.process returns NOT tuple/list') if batches is None: - batches = [ [] for _ in range(len(x)) ] + batches = [[] for _ in range(len(x))] if self.add_sample_idx: - batches += [ [] ] - i_sample_idx = len(batches)-1 + batches += [[]] + i_sample_idx = len(batches) - 1 for i in range(len(x)): - batches[i].append ( x[i] ) + batches[i].append(x[i]) if self.add_sample_idx: - batches[i_sample_idx].append (idx) + batches[i_sample_idx].append(idx) break - yield [ np.array(batch) for batch in batches] + yield [np.array(batch) for batch in batches] def update_batch(self, batch_size): self.batch_size = batch_size diff --git a/samplelib/SampleProcessor.py b/samplelib/SampleProcessor.py index aafdfa1..28f5298 100644 --- a/samplelib/SampleProcessor.py +++ b/samplelib/SampleProcessor.py @@ -42,6 +42,19 @@ opts: """ +class ColorTransferMode(IntEnum): + NONE = 0 + LCT = 1 + RCT = 2 + RCT_CLIP = 3 + RCT_PAPER = 4 + RCT_PAPER_CLIP = 5 + MASKED_RCT = 6 + MASKED_RCT_CLIP = 7 + MASKED_RCT_PAPER = 8 + MASKED_RCT_PAPER_CLIP = 9 + + class SampleProcessor(object): class Types(IntEnum): NONE = 0 @@ -120,7 +133,7 @@ class SampleProcessor(object): normalize_std_dev = opts.get('normalize_std_dev', False) normalize_vgg = opts.get('normalize_vgg', False) motion_blur = opts.get('motion_blur', None) - apply_ct = opts.get('apply_ct', False) + apply_ct = opts.get('apply_ct', ColorTransferMode.NONE) normalize_tanh = opts.get('normalize_tanh', False) img_type = SPTF.NONE @@ -223,7 +236,32 @@ class SampleProcessor(object): if apply_ct and ct_sample is not None: if ct_sample_bgr is None: ct_sample_bgr = ct_sample.load_bgr() - img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=True) + + if apply_ct == ColorTransferMode.LCT: + img_bgr = imagelib.linear_color_transfer(img_bgr, ct_sample_bgr) + + elif ColorTransferMode.RCT <= apply_ct <= ColorTransferMode.MASKED_RCT_PAPER_CLIP: + ct_options = { + ColorTransferMode.RCT: (False, False, False), + ColorTransferMode.RCT_CLIP: (False, False, True), + ColorTransferMode.RCT_PAPER: (False, True, False), + ColorTransferMode.RCT_PAPER_CLIP: (False, True, True), + ColorTransferMode.MASKED_RCT: (True, False, False), + ColorTransferMode.MASKED_RCT_CLIP: (True, False, True), + ColorTransferMode.MASKED_RCT_PAPER: (True, True, False), + ColorTransferMode.MASKED_RCT_PAPER_CLIP: (True, True, True), + } + + use_masks, use_paper, use_clip = ct_options[apply_ct] + if not use_masks: + img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=use_clip, + preserve_paper=use_paper) + else: + if ct_sample_mask is None: + ct_sample_mask = ct_sample.load_mask() + img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=use_clip, + preserve_paper=use_paper, source_mask=img_mask, + target_mask=ct_sample_mask) if normalize_std_dev: img_bgr = (img_bgr - img_bgr.mean((0, 1))) / img_bgr.std((0, 1))