mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-05 20:42:11 -07:00
XSeg sample generator: additional sample augmentation for training
This commit is contained in:
parent
87030bdcdf
commit
081d8faa45
3 changed files with 26 additions and 7 deletions
|
@ -21,6 +21,7 @@ from .SegIEPolys import *
|
|||
from .blursharpen import LinearMotionBlur, blursharpen
|
||||
|
||||
from .filters import apply_random_rgb_levels, \
|
||||
apply_random_overlay_triangle, \
|
||||
apply_random_hsv_shift, \
|
||||
apply_random_sharpen, \
|
||||
apply_random_motion_blur, \
|
||||
|
|
|
@ -128,7 +128,29 @@ def apply_random_jpeg_compress( img, chance, mask=None, rnd_state=None ):
|
|||
|
||||
return result
|
||||
|
||||
def apply_random_overlay_triangle( img, max_alpha, mask=None, rnd_state=None ):
|
||||
if rnd_state is None:
|
||||
rnd_state = np.random
|
||||
|
||||
h,w,c = img.shape
|
||||
pt1 = [rnd_state.randint(w), rnd_state.randint(h) ]
|
||||
pt2 = [rnd_state.randint(w), rnd_state.randint(h) ]
|
||||
pt3 = [rnd_state.randint(w), rnd_state.randint(h) ]
|
||||
|
||||
alpha = rnd_state.uniform()*max_alpha
|
||||
|
||||
tri_mask = cv2.fillPoly( np.zeros_like(img), [ np.array([pt1,pt2,pt3], np.int32) ], (alpha,)*c )
|
||||
|
||||
if rnd_state.randint(2) == 0:
|
||||
result = np.clip(img+tri_mask, 0, 1)
|
||||
else:
|
||||
result = np.clip(img-tri_mask, 0, 1)
|
||||
|
||||
if mask is not None:
|
||||
result = img*(1-mask) + result*mask
|
||||
|
||||
return result
|
||||
|
||||
def _min_resize(x, m):
|
||||
if x.shape[0] < x.shape[1]:
|
||||
s0 = m
|
||||
|
|
|
@ -138,11 +138,9 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
|
|||
bg_img = imagelib.apply_random_hsv_shift(bg_img)
|
||||
else:
|
||||
bg_img = imagelib.apply_random_rgb_levels(bg_img)
|
||||
|
||||
|
||||
|
||||
c_mask = 1.0 - (1-bg_mask) * (1-mask)
|
||||
rnd = np.random.uniform()
|
||||
rnd = 0.15 + np.random.uniform()*0.85
|
||||
img = img*(c_mask) + img*(1-c_mask)*rnd + bg_img*(1-c_mask)*(1-rnd)
|
||||
|
||||
warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range )
|
||||
|
@ -153,15 +151,13 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
|
|||
mask[mask < 0.5] = 0.0
|
||||
mask[mask >= 0.5] = 1.0
|
||||
mask = np.clip(mask, 0, 1)
|
||||
|
||||
#if np.random.randint(4) < 3:
|
||||
# img = imagelib.apply_random_relight(img)
|
||||
|
||||
img = imagelib.apply_random_overlay_triangle(img, max_alpha=0.25, mask=sd.random_circle_faded ([resolution,resolution]))
|
||||
|
||||
if np.random.randint(2) == 0:
|
||||
img = imagelib.apply_random_hsv_shift(img, mask=sd.random_circle_faded ([resolution,resolution]))
|
||||
else:
|
||||
img = imagelib.apply_random_rgb_levels(img, mask=sd.random_circle_faded ([resolution,resolution]))
|
||||
|
||||
|
||||
if np.random.randint(2) == 0:
|
||||
# random face flare
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue