Merge pull request #1 from faceshiftlabs/pull-request/fix/fs-aug-random-seeding

Fixes bug in fs-aug implementation
This commit is contained in:
Ognjen 2021-03-24 22:19:43 +01:00 committed by GitHub
commit 239e0fdd5d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 14 deletions

View file

@ -1,7 +1,7 @@
import cv2 import cv2
import numpy as np import numpy as np
from numpy import linalg as npla from numpy import linalg as npla
from random import random, shuffle, choice import random
from scipy.stats import special_ortho_group from scipy.stats import special_ortho_group
import scipy as sp import scipy as sp
@ -92,7 +92,7 @@ def color_transfer_mkl(x0, x1):
def color_transfer_idt(i0, i1, bins=256, n_rot=20): def color_transfer_idt(i0, i1, bins=256, n_rot=20):
import scipy.stats import scipy.stats
relaxation = 1 / n_rot relaxation = 1 / n_rot
h,w,c = i0.shape h,w,c = i0.shape
h1,w1,c1 = i1.shape h1,w1,c1 = i1.shape
@ -371,23 +371,24 @@ def color_transfer(ct_mode, img_src, img_trg):
# imported from faceswap # imported from faceswap
def color_augmentation(img): def color_augmentation(img, seed=None):
""" Color adjust RGB image """ """ Color adjust RGB image """
face = img face = img
face = np.clip(face*255.0, 0, 255).astype(np.uint8) face = np.clip(face*255.0, 0, 255).astype(np.uint8)
face = random_clahe(face) face = random_clahe(face, seed)
face = random_lab(face) face = random_lab(face, seed)
img[:, :, :3] = face img[:, :, :3] = face
return (face / 255.0).astype(np.float32) return (face / 255.0).astype(np.float32)
def random_lab(image): def random_lab(image, seed=None):
""" Perform random color/lightness adjustment in L*a*b* colorspace """ """ Perform random color/lightness adjustment in L*a*b* colorspace """
random.seed(seed)
amount_l = 30 / 100 amount_l = 30 / 100
amount_ab = 8 / 100 amount_ab = 8 / 100
randoms = [(random() * amount_l * 2) - amount_l, # L adjust randoms = [(random.random() * amount_l * 2) - amount_l, # L adjust
(random() * amount_ab * 2) - amount_ab, # A adjust (random.random() * amount_ab * 2) - amount_ab, # A adjust
(random() * amount_ab * 2) - amount_ab] # B adjust (random.random() * amount_ab * 2) - amount_ab] # B adjust
image = cv2.cvtColor( # pylint:disable=no-member image = cv2.cvtColor( # pylint:disable=no-member
image, cv2.COLOR_BGR2LAB).astype("float32") / 255.0 # pylint:disable=no-member image, cv2.COLOR_BGR2LAB).astype("float32") / 255.0 # pylint:disable=no-member
@ -400,15 +401,16 @@ def random_lab(image):
cv2.COLOR_LAB2BGR) # pylint:disable=no-member cv2.COLOR_LAB2BGR) # pylint:disable=no-member
return image return image
def random_clahe(image): def random_clahe(image, seed=None):
""" Randomly perform Contrast Limited Adaptive Histogram Equalization """ """ Randomly perform Contrast Limited Adaptive Histogram Equalization """
contrast_random = random() random.seed(seed)
contrast_random = random.random()
if contrast_random > 50 / 100: if contrast_random > 50 / 100:
return image return image
# base_contrast = image.shape[0] // 128 # base_contrast = image.shape[0] // 128
base_contrast = 1 # testing because it breaks on small sizes base_contrast = 1 # testing because it breaks on small sizes
grid_base = random() * 4 grid_base = random.random() * 4
contrast_adjustment = int(grid_base * (base_contrast / 2)) contrast_adjustment = int(grid_base * (base_contrast / 2))
grid_size = base_contrast + contrast_adjustment grid_size = base_contrast + contrast_adjustment
@ -416,4 +418,4 @@ def random_clahe(image):
tileGridSize=(grid_size, grid_size)) tileGridSize=(grid_size, grid_size))
for chan in range(3): for chan in range(3):
image[:, :, chan] = clahe.apply(image[:, :, chan]) image[:, :, chan] = clahe.apply(image[:, :, chan])
return image return image

View file

@ -204,7 +204,7 @@ class SampleProcessor(object):
# Apply random color transfer # Apply random color transfer
if ct_mode is not None and ct_sample is not None or ct_mode == 'fs-aug': if ct_mode is not None and ct_sample is not None or ct_mode == 'fs-aug':
if ct_mode == 'fs-aug': if ct_mode == 'fs-aug':
img = imagelib.color_augmentation(img) img = imagelib.color_augmentation(img, sample_rnd_seed)
else: else:
if ct_sample_bgr is None: if ct_sample_bgr is None:
ct_sample_bgr = ct_sample.load_bgr() ct_sample_bgr = ct_sample.load_bgr()