mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 06:23:20 -07:00
If using mask, scale the masked portion of image
This commit is contained in:
parent
aaff2f811e
commit
390a9638d5
2 changed files with 40 additions and 37 deletions
|
@ -77,11 +77,11 @@ def reinhard_color_transfer(source, target, clip=False, preserve_paper=False, so
|
||||||
a += aMeanTar
|
a += aMeanTar
|
||||||
b += bMeanTar
|
b += bMeanTar
|
||||||
|
|
||||||
# clip/scale the pixel intensities to [0, 255] if they fall
|
# clip/scale the pixel intensities if they fall
|
||||||
# outside this range
|
# outside the ranges for LAB
|
||||||
l = _scale_array(l, 0, 100, clip=clip)
|
l = _scale_array(l, 0, 100, clip=clip, mask=source_mask)
|
||||||
a = _scale_array(a, -127, 127, clip=clip)
|
a = _scale_array(a, -127, 127, clip=clip, mask=source_mask)
|
||||||
b = _scale_array(b, -127, 127, clip=clip)
|
b = _scale_array(b, -127, 127, clip=clip, mask=source_mask)
|
||||||
|
|
||||||
# merge the channels together and convert back to the RGB color
|
# merge the channels together and convert back to the RGB color
|
||||||
transfer = cv2.merge([l, a, b])
|
transfer = cv2.merge([l, a, b])
|
||||||
|
@ -180,7 +180,7 @@ def _min_max_scale(arr, new_range=(0, 255)):
|
||||||
return scaled
|
return scaled
|
||||||
|
|
||||||
|
|
||||||
def _scale_array(arr, mn, mx, clip=True):
|
def _scale_array(arr, mn, mx, clip=True, mask=None):
|
||||||
"""
|
"""
|
||||||
Trim NumPy array values to be in [0, 255] range with option of
|
Trim NumPy array values to be in [0, 255] range with option of
|
||||||
clipping or scaling.
|
clipping or scaling.
|
||||||
|
@ -197,7 +197,10 @@ def _scale_array(arr, mn, mx, clip=True):
|
||||||
if clip:
|
if clip:
|
||||||
scaled = np.clip(arr, mn, mx)
|
scaled = np.clip(arr, mn, mx)
|
||||||
else:
|
else:
|
||||||
scale_range = (max([arr.min(), mn]), min([arr.max(), mx]))
|
if mask is not None:
|
||||||
|
scale_range = (max([np.min(mask * arr), mn]), min([np.max(mask * arr), mx]))
|
||||||
|
else:
|
||||||
|
scale_range = (max([np.min(arr), mn]), min([np.max(arr), mx]))
|
||||||
scaled = _min_max_scale(arr, new_range=scale_range)
|
scaled = _min_max_scale(arr, new_range=scale_range)
|
||||||
|
|
||||||
return scaled
|
return scaled
|
||||||
|
|
|
@ -15,44 +15,44 @@ class ColorTranfer(unittest.TestCase):
|
||||||
src_samples = SampleLoader.load(SampleType.FACE, './test_src', None)
|
src_samples = SampleLoader.load(SampleType.FACE, './test_src', None)
|
||||||
dst_samples = SampleLoader.load(SampleType.FACE, './test_dst', None)
|
dst_samples = SampleLoader.load(SampleType.FACE, './test_dst', None)
|
||||||
|
|
||||||
src_sample = src_samples[2]
|
for src_sample in src_samples:
|
||||||
src_img = src_sample.load_bgr()
|
src_img = src_sample.load_bgr()
|
||||||
src_mask = src_sample.load_mask()
|
src_mask = src_sample.load_mask()
|
||||||
|
|
||||||
# Toggle to see masks
|
# Toggle to see masks
|
||||||
show_masks = False
|
show_masks = False
|
||||||
|
|
||||||
grid = []
|
grid = []
|
||||||
for ct_sample in dst_samples:
|
for ct_sample in dst_samples:
|
||||||
print(src_sample.filename, ct_sample.filename)
|
print(src_sample.filename, ct_sample.filename)
|
||||||
ct_img = ct_sample.load_bgr()
|
ct_img = ct_sample.load_bgr()
|
||||||
ct_mask = ct_sample.load_mask()
|
ct_mask = ct_sample.load_mask()
|
||||||
|
|
||||||
lct_img = linear_color_transfer(src_img, ct_img)
|
lct_img = linear_color_transfer(src_img, ct_img)
|
||||||
rct_img = reinhard_color_transfer(src_img, ct_img)
|
rct_img = reinhard_color_transfer(src_img, ct_img)
|
||||||
rct_img_clip = reinhard_color_transfer(src_img, ct_img, clip=True)
|
rct_img_clip = reinhard_color_transfer(src_img, ct_img, clip=True)
|
||||||
rct_img_paper = reinhard_color_transfer(src_img, ct_img, preserve_paper=True)
|
rct_img_paper = reinhard_color_transfer(src_img, ct_img, preserve_paper=True)
|
||||||
rct_img_paper_clip = reinhard_color_transfer(src_img, ct_img, clip=True, preserve_paper=True)
|
rct_img_paper_clip = reinhard_color_transfer(src_img, ct_img, clip=True, preserve_paper=True)
|
||||||
|
|
||||||
masked_rct_img = reinhard_color_transfer(src_img, ct_img, source_mask=src_mask, target_mask=ct_mask)
|
masked_rct_img = reinhard_color_transfer(src_img, ct_img, source_mask=src_mask, target_mask=ct_mask)
|
||||||
masked_rct_img_clip = reinhard_color_transfer(src_img, ct_img, clip=True, source_mask=src_mask, target_mask=ct_mask)
|
masked_rct_img_clip = reinhard_color_transfer(src_img, ct_img, clip=True, source_mask=src_mask, target_mask=ct_mask)
|
||||||
masked_rct_img_paper = reinhard_color_transfer(src_img, ct_img, preserve_paper=True, source_mask=src_mask, target_mask=ct_mask)
|
masked_rct_img_paper = reinhard_color_transfer(src_img, ct_img, preserve_paper=True, source_mask=src_mask, target_mask=ct_mask)
|
||||||
masked_rct_img_paper_clip = reinhard_color_transfer(src_img, ct_img, clip=True, preserve_paper=True, source_mask=src_mask, target_mask=ct_mask)
|
masked_rct_img_paper_clip = reinhard_color_transfer(src_img, ct_img, clip=True, preserve_paper=True, source_mask=src_mask, target_mask=ct_mask)
|
||||||
|
|
||||||
results = [lct_img, rct_img, rct_img_clip, rct_img_paper, rct_img_paper_clip,
|
results = [lct_img, rct_img, rct_img_clip, rct_img_paper, rct_img_paper_clip,
|
||||||
masked_rct_img, masked_rct_img_clip, masked_rct_img_paper, masked_rct_img_paper_clip]
|
masked_rct_img, masked_rct_img_clip, masked_rct_img_paper, masked_rct_img_paper_clip]
|
||||||
|
|
||||||
if show_masks:
|
if show_masks:
|
||||||
results = [src_mask * im for im in results]
|
results = [src_mask * im for im in results]
|
||||||
src_img *= src_mask
|
src_img *= src_mask
|
||||||
ct_img *= ct_mask
|
ct_img *= ct_mask
|
||||||
|
|
||||||
results = np.concatenate((src_img, ct_img, *results), axis=1)
|
results = np.concatenate((src_img, ct_img, *results), axis=1)
|
||||||
grid.append(results)
|
grid.append(results)
|
||||||
|
|
||||||
cv2.namedWindow('test output', cv2.WINDOW_NORMAL)
|
cv2.namedWindow('test output', cv2.WINDOW_NORMAL)
|
||||||
cv2.imshow('test output', np.concatenate(grid, axis=0))
|
cv2.imshow('test output', np.concatenate(grid, axis=0))
|
||||||
cv2.waitKey(0)
|
cv2.waitKey(0)
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue