mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
upd SampleGenerator
This commit is contained in:
parent
1898bd6881
commit
9860a38907
4 changed files with 42 additions and 33 deletions
|
@ -2,14 +2,12 @@ import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
from core import randomex
|
from core import randomex
|
||||||
|
|
||||||
def gen_warp_params (source, flip, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_seed=None ):
|
def gen_warp_params (source, flip, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_state=None ):
|
||||||
h,w,c = source.shape
|
h,w,c = source.shape
|
||||||
if (h != w):
|
if (h != w):
|
||||||
raise ValueError ('gen_warp_params accepts only square images.')
|
raise ValueError ('gen_warp_params accepts only square images.')
|
||||||
|
|
||||||
if rnd_seed != None:
|
if rnd_state is None:
|
||||||
rnd_state = np.random.RandomState (rnd_seed)
|
|
||||||
else:
|
|
||||||
rnd_state = np.random
|
rnd_state = np.random
|
||||||
|
|
||||||
rotation = rnd_state.uniform( rotation_range[0], rotation_range[1] )
|
rotation = rnd_state.uniform( rotation_range[0], rotation_range[1] )
|
||||||
|
|
|
@ -99,15 +99,17 @@ class IndexHost():
|
||||||
"""
|
"""
|
||||||
Provides random shuffled indexes for multiprocesses
|
Provides random shuffled indexes for multiprocesses
|
||||||
"""
|
"""
|
||||||
def __init__(self, indexes_count):
|
def __init__(self, indexes_count, rnd_seed=None):
|
||||||
self.sq = multiprocessing.Queue()
|
self.sq = multiprocessing.Queue()
|
||||||
self.cqs = []
|
self.cqs = []
|
||||||
self.clis = []
|
self.clis = []
|
||||||
self.thread = threading.Thread(target=self.host_thread, args=(indexes_count,) )
|
self.thread = threading.Thread(target=self.host_thread, args=(indexes_count,rnd_seed) )
|
||||||
self.thread.daemon = True
|
self.thread.daemon = True
|
||||||
self.thread.start()
|
self.thread.start()
|
||||||
|
|
||||||
def host_thread(self, indexes_count):
|
def host_thread(self, indexes_count, rnd_seed):
|
||||||
|
rnd_state = np.random.RandomState(rnd_seed) if rnd_seed is not None else np.random
|
||||||
|
|
||||||
idxs = [*range(indexes_count)]
|
idxs = [*range(indexes_count)]
|
||||||
shuffle_idxs = []
|
shuffle_idxs = []
|
||||||
sq = self.sq
|
sq = self.sq
|
||||||
|
@ -121,7 +123,7 @@ class IndexHost():
|
||||||
for i in range(count):
|
for i in range(count):
|
||||||
if len(shuffle_idxs) == 0:
|
if len(shuffle_idxs) == 0:
|
||||||
shuffle_idxs = idxs.copy()
|
shuffle_idxs = idxs.copy()
|
||||||
np.random.shuffle(shuffle_idxs)
|
rnd_state.shuffle(shuffle_idxs)
|
||||||
result.append(shuffle_idxs.pop())
|
result.append(shuffle_idxs.pop())
|
||||||
self.cqs[cq_id].put (result)
|
self.cqs[cq_id].put (result)
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
output_sample_types=[],
|
output_sample_types=[],
|
||||||
add_sample_idx=False,
|
add_sample_idx=False,
|
||||||
generators_count=4,
|
generators_count=4,
|
||||||
|
rnd_seed=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
||||||
super().__init__(samples_path, debug, batch_size)
|
super().__init__(samples_path, debug, batch_size)
|
||||||
|
@ -34,6 +35,9 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
self.output_sample_types = output_sample_types
|
self.output_sample_types = output_sample_types
|
||||||
self.add_sample_idx = add_sample_idx
|
self.add_sample_idx = add_sample_idx
|
||||||
|
|
||||||
|
if rnd_seed is None:
|
||||||
|
rnd_seed = np.random.randint(0x80000000)
|
||||||
|
|
||||||
if self.debug:
|
if self.debug:
|
||||||
self.generators_count = 1
|
self.generators_count = 1
|
||||||
else:
|
else:
|
||||||
|
@ -45,11 +49,11 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
if self.samples_len == 0:
|
if self.samples_len == 0:
|
||||||
raise ValueError('No training data provided.')
|
raise ValueError('No training data provided.')
|
||||||
|
|
||||||
index_host = mplib.IndexHost(self.samples_len)
|
index_host = mplib.IndexHost(self.samples_len, rnd_seed=rnd_seed)
|
||||||
|
|
||||||
if random_ct_samples_path is not None:
|
if random_ct_samples_path is not None:
|
||||||
ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path)
|
ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path)
|
||||||
ct_index_host = mplib.IndexHost( len(ct_samples) )
|
ct_index_host = mplib.IndexHost( len(ct_samples), rnd_seed=rnd_seed )
|
||||||
else:
|
else:
|
||||||
ct_samples = None
|
ct_samples = None
|
||||||
ct_index_host = None
|
ct_index_host = None
|
||||||
|
@ -58,9 +62,9 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None
|
ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None
|
||||||
|
|
||||||
if self.debug:
|
if self.debug:
|
||||||
self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )]
|
self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None, rnd_seed) )]
|
||||||
else:
|
else:
|
||||||
self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \
|
self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None, rnd_seed), start_now=False ) \
|
||||||
for i in range(self.generators_count) ]
|
for i in range(self.generators_count) ]
|
||||||
|
|
||||||
SubprocessGenerator.start_in_parallel( self.generators )
|
SubprocessGenerator.start_in_parallel( self.generators )
|
||||||
|
@ -76,7 +80,9 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
return next(generator)
|
return next(generator)
|
||||||
|
|
||||||
def batch_func(self, param ):
|
def batch_func(self, param ):
|
||||||
pickled_samples, index_host, ct_pickled_samples, ct_index_host = param
|
pickled_samples, index_host, ct_pickled_samples, ct_index_host, rnd_seed = param
|
||||||
|
|
||||||
|
rnd_state = np.random.RandomState(rnd_seed)
|
||||||
|
|
||||||
samples = pickle.loads(pickled_samples)
|
samples = pickle.loads(pickled_samples)
|
||||||
ct_samples = pickle.loads(ct_pickled_samples) if ct_pickled_samples is not None else None
|
ct_samples = pickle.loads(ct_pickled_samples) if ct_pickled_samples is not None else None
|
||||||
|
@ -98,7 +104,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
ct_sample = ct_samples[ct_indexes[n_batch]]
|
ct_sample = ct_samples[ct_indexes[n_batch]]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
|
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample, rnd_state=rnd_state)
|
||||||
except:
|
except:
|
||||||
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
|
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
|
||||||
|
|
||||||
|
|
|
@ -63,10 +63,13 @@ class SampleProcessor(object):
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def process (samples, sample_process_options, output_sample_types, debug, ct_sample=None):
|
def process (samples, sample_process_options, output_sample_types, debug, ct_sample=None, rnd_state=None):
|
||||||
SPTF = SampleProcessor.Types
|
SPTF = SampleProcessor.Types
|
||||||
|
|
||||||
sample_rnd_seed = np.random.randint(0x80000000)
|
if rnd_state is None:
|
||||||
|
rnd_state = np.random.RandomState( np.random.randint(0x80000000) )
|
||||||
|
|
||||||
|
sample_rnd_seed = rnd_state.randint(0x80000000)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for sample in samples:
|
for sample in samples:
|
||||||
|
@ -79,7 +82,7 @@ class SampleProcessor(object):
|
||||||
if debug and is_face_sample:
|
if debug and is_face_sample:
|
||||||
LandmarksProcessor.draw_landmarks (sample_bgr, sample.landmarks, (0, 1, 0))
|
LandmarksProcessor.draw_landmarks (sample_bgr, sample.landmarks, (0, 1, 0))
|
||||||
|
|
||||||
params = imagelib.gen_warp_params(sample_bgr, sample_process_options.random_flip, rotation_range=sample_process_options.rotation_range, scale_range=sample_process_options.scale_range, tx_range=sample_process_options.tx_range, ty_range=sample_process_options.ty_range, rnd_seed=sample_rnd_seed )
|
params = imagelib.gen_warp_params(sample_bgr, sample_process_options.random_flip, rotation_range=sample_process_options.rotation_range, scale_range=sample_process_options.scale_range, tx_range=sample_process_options.tx_range, ty_range=sample_process_options.ty_range, rnd_state=rnd_state )
|
||||||
|
|
||||||
outputs_sample = []
|
outputs_sample = []
|
||||||
for opts in output_sample_types:
|
for opts in output_sample_types:
|
||||||
|
@ -186,10 +189,10 @@ class SampleProcessor(object):
|
||||||
chance, mb_max_size = motion_blur
|
chance, mb_max_size = motion_blur
|
||||||
chance = np.clip(chance, 0, 100)
|
chance = np.clip(chance, 0, 100)
|
||||||
|
|
||||||
rnd_state = np.random.RandomState (sample_rnd_seed)
|
l_rnd_state = np.random.RandomState (sample_rnd_seed)
|
||||||
mblur_rnd_chance = rnd_state.randint(100)
|
mblur_rnd_chance = l_rnd_state.randint(100)
|
||||||
mblur_rnd_kernel = rnd_state.randint(mb_max_size)+1
|
mblur_rnd_kernel = l_rnd_state.randint(mb_max_size)+1
|
||||||
mblur_rnd_deg = rnd_state.randint(360)
|
mblur_rnd_deg = l_rnd_state.randint(360)
|
||||||
|
|
||||||
if mblur_rnd_chance < chance:
|
if mblur_rnd_chance < chance:
|
||||||
img = imagelib.LinearMotionBlur (img, mblur_rnd_kernel, mblur_rnd_deg )
|
img = imagelib.LinearMotionBlur (img, mblur_rnd_kernel, mblur_rnd_deg )
|
||||||
|
@ -198,9 +201,9 @@ class SampleProcessor(object):
|
||||||
chance, kernel_max_size = gaussian_blur
|
chance, kernel_max_size = gaussian_blur
|
||||||
chance = np.clip(chance, 0, 100)
|
chance = np.clip(chance, 0, 100)
|
||||||
|
|
||||||
rnd_state = np.random.RandomState (sample_rnd_seed+1)
|
l_rnd_state = np.random.RandomState (sample_rnd_seed+1)
|
||||||
gblur_rnd_chance = rnd_state.randint(100)
|
gblur_rnd_chance = l_rnd_state.randint(100)
|
||||||
gblur_rnd_kernel = rnd_state.randint(kernel_max_size)*2+1
|
gblur_rnd_kernel = l_rnd_state.randint(kernel_max_size)*2+1
|
||||||
|
|
||||||
if gblur_rnd_chance < chance:
|
if gblur_rnd_chance < chance:
|
||||||
img = cv2.GaussianBlur(img, (gblur_rnd_kernel,) *2 , 0)
|
img = cv2.GaussianBlur(img, (gblur_rnd_kernel,) *2 , 0)
|
||||||
|
@ -260,22 +263,22 @@ class SampleProcessor(object):
|
||||||
if mode_type == SPTF.MODE_BGR:
|
if mode_type == SPTF.MODE_BGR:
|
||||||
out_sample = img
|
out_sample = img
|
||||||
elif mode_type == SPTF.MODE_BGR_SHUFFLE:
|
elif mode_type == SPTF.MODE_BGR_SHUFFLE:
|
||||||
rnd_state = np.random.RandomState (sample_rnd_seed)
|
l_rnd_state = np.random.RandomState (sample_rnd_seed)
|
||||||
out_sample = np.take (img, rnd_state.permutation(img.shape[-1]), axis=-1)
|
out_sample = np.take (img, l_rnd_state.permutation(img.shape[-1]), axis=-1)
|
||||||
|
|
||||||
elif mode_type == SPTF.MODE_BGR_RANDOM_HSV_SHIFT:
|
elif mode_type == SPTF.MODE_BGR_RANDOM_HSV_SHIFT:
|
||||||
rnd_state = np.random.RandomState (sample_rnd_seed)
|
l_rnd_state = np.random.RandomState (sample_rnd_seed)
|
||||||
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
||||||
h, s, v = cv2.split(hsv)
|
h, s, v = cv2.split(hsv)
|
||||||
h = (h + rnd_state.randint(360) ) % 360
|
h = (h + l_rnd_state.randint(360) ) % 360
|
||||||
s = np.clip ( s + rnd_state.random()-0.5, 0, 1 )
|
s = np.clip ( s + l_rnd_state.random()-0.5, 0, 1 )
|
||||||
v = np.clip ( v + rnd_state.random()-0.5, 0, 1 )
|
v = np.clip ( v + l_rnd_state.random()-0.5, 0, 1 )
|
||||||
hsv = cv2.merge([h, s, v])
|
hsv = cv2.merge([h, s, v])
|
||||||
out_sample = np.clip( cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) , 0, 1 )
|
out_sample = np.clip( cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) , 0, 1 )
|
||||||
|
|
||||||
elif mode_type == SPTF.MODE_BGR_RANDOM_RGB_LEVELS:
|
elif mode_type == SPTF.MODE_BGR_RANDOM_RGB_LEVELS:
|
||||||
rnd_state = np.random.RandomState (sample_rnd_seed)
|
l_rnd_state = np.random.RandomState (sample_rnd_seed)
|
||||||
np_rnd = rnd_state.rand
|
np_rnd = l_rnd_state.rand
|
||||||
|
|
||||||
inBlack = np.array([np_rnd()*0.25 , np_rnd()*0.25 , np_rnd()*0.25], dtype=np.float32)
|
inBlack = np.array([np_rnd()*0.25 , np_rnd()*0.25 , np_rnd()*0.25], dtype=np.float32)
|
||||||
inWhite = np.array([1.0-np_rnd()*0.25, 1.0-np_rnd()*0.25, 1.0-np_rnd()*0.25], dtype=np.float32)
|
inWhite = np.array([1.0-np_rnd()*0.25, 1.0-np_rnd()*0.25, 1.0-np_rnd()*0.25], dtype=np.float32)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue