mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 06:23:20 -07:00
Adds masking to RCT (Reinhard color transfer)
This commit is contained in:
parent
ccd51c52ac
commit
3d4886cd36
2 changed files with 23 additions and 26 deletions
|
@ -2,7 +2,7 @@ import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, source_mask=None, target_mask=None):
|
def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, target_mask=None, source_mask=None):
|
||||||
"""
|
"""
|
||||||
Transfers the color distribution from the source to the target
|
Transfers the color distribution from the source to the target
|
||||||
image using the mean and standard deviations of the L*a*b*
|
image using the mean and standard deviations of the L*a*b*
|
||||||
|
@ -51,8 +51,6 @@ def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, so
|
||||||
target = cv2.cvtColor(target.astype(np.float32), cv2.COLOR_BGR2LAB)
|
target = cv2.cvtColor(target.astype(np.float32), cv2.COLOR_BGR2LAB)
|
||||||
|
|
||||||
# compute color statistics for the source and target images
|
# compute color statistics for the source and target images
|
||||||
src_input = source if source_mask is None else source * source_mask
|
|
||||||
tgt_input = target if target_mask is None else target * target_mask
|
|
||||||
(lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = lab_image_stats(src_input)
|
(lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = lab_image_stats(src_input)
|
||||||
(lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = lab_image_stats(tgt_input)
|
(lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = lab_image_stats(tgt_input)
|
||||||
|
|
||||||
|
@ -138,32 +136,20 @@ def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5):
|
||||||
return matched_img
|
return matched_img
|
||||||
|
|
||||||
|
|
||||||
def linear_lab_color_transform(target_img, source_img, eps=1e-5, mode='pca'):
|
def lab_image_stats(image, mask=None):
|
||||||
"""doesn't work yet"""
|
|
||||||
np.clip(source_img, 0, 1, out=source_img)
|
|
||||||
np.clip(target_img, 0, 1, out=target_img)
|
|
||||||
|
|
||||||
# convert the images from the RGB to L*ab* color space, being
|
|
||||||
# sure to utilizing the floating point data type (note: OpenCV
|
|
||||||
# expects floats to be 32-bit, so use that instead of 64-bit)
|
|
||||||
source_img = cv2.cvtColor(source_img.astype(np.float32), cv2.COLOR_BGR2LAB)
|
|
||||||
target_img = cv2.cvtColor(target_img.astype(np.float32), cv2.COLOR_BGR2LAB)
|
|
||||||
|
|
||||||
target_img = linear_color_transfer(target_img, source_img, mode=mode, eps=eps)
|
|
||||||
target_img = cv2.cvtColor(np.clip(target_img, 0, 1).astype(np.float32), cv2.COLOR_LAB2BGR)
|
|
||||||
np.clip(target_img, 0, 1, out=target_img)
|
|
||||||
return target_img
|
|
||||||
|
|
||||||
|
|
||||||
def lab_image_stats(image):
|
|
||||||
# compute the mean and standard deviation of each channel
|
# compute the mean and standard deviation of each channel
|
||||||
|
if mask is not None:
|
||||||
|
mask = np.squeeze(mask)
|
||||||
|
image = image[mask == 1]
|
||||||
|
|
||||||
(l, a, b) = cv2.split(image)
|
(l, a, b) = cv2.split(image)
|
||||||
(lMean, lStd) = (l.mean(), l.std())
|
|
||||||
(aMean, aStd) = (a.mean(), a.std())
|
l_mean, l_std = np.mean(l), np.std(l)
|
||||||
(bMean, bStd) = (b.mean(), b.std())
|
a_mean, a_std = np.mean(l), np.std(l)
|
||||||
|
b_mean, b_std = np.mean(l), np.std(l)
|
||||||
|
|
||||||
# return the color statistics
|
# return the color statistics
|
||||||
return (lMean, lStd, aMean, aStd, bMean, bStd)
|
return l_mean, l_std, a_mean, a_std, b_mean, b_std
|
||||||
|
|
||||||
|
|
||||||
def _scale_array(arr, clip=True):
|
def _scale_array(arr, clip=True):
|
||||||
|
|
|
@ -223,7 +223,18 @@ class SampleProcessor(object):
|
||||||
if apply_ct and ct_sample is not None:
|
if apply_ct and ct_sample is not None:
|
||||||
if ct_sample_bgr is None:
|
if ct_sample_bgr is None:
|
||||||
ct_sample_bgr = ct_sample.load_bgr()
|
ct_sample_bgr = ct_sample.load_bgr()
|
||||||
img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=True)
|
if ct_sample_mask is None:
|
||||||
|
ct_sample_mask = ct_sample.load_fanseg_mask() or \
|
||||||
|
LandmarksProcessor.get_image_hull_mask(ct_sample_bgr.shape, ct_sample.landmarks)
|
||||||
|
|
||||||
|
img_bgr = imagelib.reinhard_color_transfer(img_bgr,
|
||||||
|
ct_sample_bgr,
|
||||||
|
clip=True,
|
||||||
|
target_mask=img_mask,
|
||||||
|
source_mask=ct_sample_mask)
|
||||||
|
# img_bgr = imagelib.reinhard_color_transfer(img_bgr,
|
||||||
|
# ct_sample_bgr,
|
||||||
|
# clip=True)
|
||||||
|
|
||||||
if normalize_std_dev:
|
if normalize_std_dev:
|
||||||
img_bgr = (img_bgr - img_bgr.mean((0, 1))) / img_bgr.std((0, 1))
|
img_bgr = (img_bgr - img_bgr.mean((0, 1))) / img_bgr.std((0, 1))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue