mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
XSeg trainer: added random relighting sample augmentation to improve generalization
This commit is contained in:
parent
23130cd56a
commit
e53d1b1820
3 changed files with 103 additions and 2 deletions
|
@ -27,4 +27,5 @@ from .filters import apply_random_rgb_levels, \
|
|||
apply_random_gaussian_blur, \
|
||||
apply_random_nearest_resize, \
|
||||
apply_random_bilinear_resize, \
|
||||
apply_random_jpeg_compress
|
||||
apply_random_jpeg_compress, \
|
||||
apply_random_relight
|
||||
|
|
|
@ -126,4 +126,98 @@ def apply_random_jpeg_compress( img, chance, mask=None, rnd_state=None ):
|
|||
if mask is not None:
|
||||
result = img*(1-mask) + result*mask
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _min_resize(x, m):
    """Resize ``x`` so that its shorter side becomes ``m`` pixels long.

    The longer side is scaled by the same factor (truncated to an int), and
    the image is resampled with Lanczos interpolation.

    Args:
        x: image array; only ``x.shape[0]`` (rows) and ``x.shape[1]``
            (columns) are inspected here.
        m: target length of the shorter side.

    Returns:
        The resized array returned by ``cv2.resize``.
    """
    rows, cols = x.shape[0], x.shape[1]
    if rows < cols:
        new_rows = m
        new_cols = int(float(m) / float(rows) * float(cols))
    else:
        new_rows = int(float(m) / float(cols) * float(rows))
        new_cols = m
    # cv2.resize expects dsize as (width, height), i.e. (columns, rows).
    return cv2.resize(x, (new_cols, new_rows), interpolation=cv2.INTER_LANCZOS4)
|
||||
|
||||
def _d_resize(x, d, fac=1.0):
    """Resize ``x`` to the size described by ``d`` scaled by ``fac``.

    Args:
        x: image array to resize.
        d: indexable whose first two entries are the target rows and columns
            (callers in this file pass another array's ``.shape``).
        fac: multiplier applied to both target dimensions before truncation
            to int.

    Returns:
        The resized array. Area interpolation is used when the target's
        shorter side is smaller than the source's shorter side, Lanczos
        otherwise.
    """
    target_rows = int(d[0] * fac)
    target_cols = int(d[1] * fac)
    source_min = min(x.shape[0], x.shape[1])
    shrinking = min(target_cols, target_rows) < source_min
    interpolation = cv2.INTER_AREA if shrinking else cv2.INTER_LANCZOS4
    # cv2.resize expects dsize as (width, height), i.e. (columns, rows).
    return cv2.resize(x, (target_cols, target_rows), interpolation=interpolation)
|
||||
|
||||
def _get_image_gradient(dist):
    """Return the horizontal and vertical 3x3 Sobel responses of ``dist``.

    Args:
        dist: image array passed straight to ``cv2.filter2D``.

    Returns:
        ``(cols, rows)``: the response to the column-direction kernel and the
        response to the row-direction kernel, both requested as ``CV_32F``.
    """
    kernel_cols = np.array([[-1, 0, +1], [-2, 0, +2], [-1, 0, +1]])
    kernel_rows = np.array([[-1, -2, -1], [0, 0, 0], [+1, +2, +1]])
    grad_cols = cv2.filter2D(dist, cv2.CV_32F, kernel_cols)
    grad_rows = cv2.filter2D(dist, cv2.CV_32F, kernel_rows)
    return grad_cols, grad_rows
|
||||
|
||||
def _generate_lighting_effects(content):
    """Build a normalized multi-scale gradient field used for relighting.

    A six-level pyramid (the input plus five ``cv2.pyrDown`` reductions) is
    built, image gradients are taken at every level, and the gradients are
    merged coarse-to-fine: the running sum is upsampled with ``cv2.pyrUp``,
    resized to the next finer level, multiplied by 4 and added to that
    level's gradient.

    Args:
        content: source image.
            NOTE(review): presumably a float (H, W, C) array; confirm against
            the caller.

    Returns:
        An array with one extra trailing axis of length 3 holding
        ``(zeros, row_effect, col_effect)``; the row/column effects are divided
        by the maximum gradient magnitude taken over the first two axes.
    """
    # Finest level first: [input, /2, /4, /8, /16, /32].
    pyramid = [content]
    for _ in range(5):
        pyramid.append(cv2.pyrDown(pyramid[-1]))

    gradients = [_get_image_gradient(level) for level in pyramid]

    # Seed with the coarsest level, then fold in each finer level in turn.
    effect_cols, effect_rows = gradients[-1]
    for finer_cols, finer_rows in reversed(gradients[:-1]):
        effect_cols = _d_resize(cv2.pyrUp(effect_cols), finer_cols.shape) * 4.0 + finer_cols
        effect_rows = _d_resize(cv2.pyrUp(effect_rows), finer_rows.shape) * 4.0 + finer_rows

    eps = 1e-10  # keeps the division defined when the gradient field is all zero

    magnitude = (effect_cols**2 + effect_rows**2)**0.5
    peak = magnitude.max(axis=(0, 1), keepdims=True)
    effect_cols = (effect_cols + eps) / (peak + eps)
    effect_rows = (effect_rows + eps) / (peak + eps)

    zero_plane = np.zeros_like(effect_rows)
    return np.stack([zero_plane, effect_rows, effect_cols], axis=-1)
|
||||
|
||||
def apply_random_relight(img, mask=None, rnd_state=None):
    """Relight ``img`` with a randomly placed directional light.

    The light is positioned on one of the four borders of the ``[-1, 1]``
    square (one coordinate pinned to +/-1, the other uniform in ``[-1, 1)``),
    with a random height and intensity, and combined with a fixed ambient term.

    Args:
        img: source image.
            NOTE(review): the final clip to [0, 1] suggests a float image in
            that range; confirm against callers.
        mask: optional blend weights; where given, the output is
            ``img*(1-mask) + relit*mask``.
        rnd_state: object providing ``randint`` and ``uniform``; defaults to
            the global ``np.random`` module.

    Returns:
        The relit image, clipped to ``[0, 1]`` before the optional mask blend.
    """
    rng = np.random if rnd_state is None else rnd_state

    def _border():
        # One coordinate sits exactly on the +1 or -1 border.
        return 1.0 if rng.randint(2) == 0 else -1.0

    def _along():
        # The other coordinate slides freely along that border.
        return rng.uniform()*2-1.0

    # The order of the random draws below is significant for reproducibility
    # with a seeded rnd_state: y is always drawn before x.
    if rng.randint(2) == 0:
        light_pos_y = _border()
        light_pos_x = _along()
    else:
        light_pos_y = _along()
        light_pos_x = _border()

    # NOTE(review): this is a product (range [0, 0.21)); it may have been meant
    # as 0.3 + uniform()*0.7. Left unchanged; confirm intent.
    light_source_height = 0.3*rng.uniform()*0.7
    light_intensity = 1.0+rng.uniform()
    ambient_intensity = 0.5

    location = np.array([[[light_source_height, light_pos_y, light_pos_x]]], dtype=np.float32)
    direction = location / np.sqrt(np.sum(np.square(location)))

    # Project the gradient field onto the light direction, keep only the lit
    # side, then average over the remaining last axis.
    shading = np.sum(_generate_lighting_effects(img) * direction, axis=-1).clip(0, 1)
    shading = np.mean(shading, axis=-1, keepdims=True)

    relit = np.clip(img * (ambient_intensity + shading * light_intensity), 0, 1)

    if mask is None:
        return relit
    return img*(1-mask) + relit*mask
|
|
@ -138,6 +138,8 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
|
|||
bg_img = imagelib.apply_random_hsv_shift(bg_img)
|
||||
else:
|
||||
bg_img = imagelib.apply_random_rgb_levels(bg_img)
|
||||
|
||||
|
||||
|
||||
c_mask = 1.0 - (1-bg_mask) * (1-mask)
|
||||
rnd = np.random.uniform()
|
||||
|
@ -151,12 +153,16 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
|
|||
mask[mask < 0.5] = 0.0
|
||||
mask[mask >= 0.5] = 1.0
|
||||
mask = np.clip(mask, 0, 1)
|
||||
|
||||
if np.random.randint(4) < 3:
|
||||
img = imagelib.apply_random_relight(img)
|
||||
|
||||
if np.random.randint(2) == 0:
|
||||
img = imagelib.apply_random_hsv_shift(img, mask=sd.random_circle_faded ([resolution,resolution]))
|
||||
else:
|
||||
img = imagelib.apply_random_rgb_levels(img, mask=sd.random_circle_faded ([resolution,resolution]))
|
||||
|
||||
|
||||
|
||||
if np.random.randint(2) == 0:
|
||||
# random face flare
|
||||
krn = np.random.randint( resolution//4, resolution )
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue