mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-05 20:42:11 -07:00
Sample loader : back to serial one core loader
This commit is contained in:
parent
bbe81b20af
commit
52a67a61b3
1 changed files with 14 additions and 93 deletions
|
@ -70,25 +70,21 @@ class SampleHost:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_face_samples ( image_paths):
|
def load_face_samples ( image_paths):
|
||||||
result = FaceSamplesLoaderSubprocessor(image_paths).run()
|
|
||||||
sample_list = []
|
sample_list = []
|
||||||
|
|
||||||
for filename, \
|
for filename in io.progress_bar_generator (image_paths, desc="Loading"):
|
||||||
( face_type,
|
dflimg = DFLIMG.load (Path(filename))
|
||||||
shape,
|
if dflimg is None:
|
||||||
landmarks,
|
io.log_err (f"{filename} is not a dfl image file.")
|
||||||
ie_polys,
|
else:
|
||||||
eyebrows_expand_mod,
|
sample_list.append( Sample(filename=filename,
|
||||||
source_filename,
|
sample_type=SampleType.FACE,
|
||||||
) in result:
|
face_type=FaceType.fromString ( dflimg.get_face_type() ),
|
||||||
sample_list.append( Sample(filename=filename,
|
shape=dflimg.get_shape(),
|
||||||
sample_type=SampleType.FACE,
|
landmarks=dflimg.get_landmarks(),
|
||||||
face_type=FaceType.fromString (face_type),
|
ie_polys=dflimg.get_ie_polys(),
|
||||||
shape=shape,
|
eyebrows_expand_mod=dflimg.get_eyebrows_expand_mod(),
|
||||||
landmarks=landmarks,
|
source_filename=dflimg.get_source_filename(),
|
||||||
ie_polys=ie_polys,
|
|
||||||
eyebrows_expand_mod=eyebrows_expand_mod,
|
|
||||||
source_filename=source_filename,
|
|
||||||
))
|
))
|
||||||
return sample_list
|
return sample_list
|
||||||
|
|
||||||
|
@ -98,78 +94,3 @@ class SampleHost:
|
||||||
new_s = sorted(new_s, key=operator.itemgetter(1))
|
new_s = sorted(new_s, key=operator.itemgetter(1))
|
||||||
|
|
||||||
return [ s[0] for s in new_s]
|
return [ s[0] for s in new_s]
|
||||||
|
|
||||||
|
|
||||||
class FaceSamplesLoaderSubprocessor(Subprocessor):
|
|
||||||
#override
|
|
||||||
def __init__(self, image_paths ):
|
|
||||||
self.image_paths = image_paths
|
|
||||||
self.image_paths_len = len(image_paths)
|
|
||||||
self.idxs = [*range(self.image_paths_len)]
|
|
||||||
self.result = [None]*self.image_paths_len
|
|
||||||
super().__init__('FaceSamplesLoader', FaceSamplesLoaderSubprocessor.Cli, 60, initialize_subprocesses_in_serial=False)
|
|
||||||
|
|
||||||
#override
|
|
||||||
def on_clients_initialized(self):
|
|
||||||
io.progress_bar ("Loading", len (self.image_paths))
|
|
||||||
|
|
||||||
#override
|
|
||||||
def on_clients_finalized(self):
|
|
||||||
io.progress_bar_close()
|
|
||||||
|
|
||||||
#override
|
|
||||||
def process_info_generator(self):
|
|
||||||
for i in range(min(multiprocessing.cpu_count(), 8) ):
|
|
||||||
yield 'CPU%d' % (i), {}, {'device_idx': i,
|
|
||||||
'device_name': 'CPU%d' % (i),
|
|
||||||
}
|
|
||||||
|
|
||||||
#override
|
|
||||||
def get_data(self, host_dict):
|
|
||||||
if len (self.idxs) > 0:
|
|
||||||
idx = self.idxs.pop(0)
|
|
||||||
return idx, self.image_paths[idx]
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
#override
|
|
||||||
def on_data_return (self, host_dict, data):
|
|
||||||
self.idxs.insert(0, data[0])
|
|
||||||
|
|
||||||
#override
|
|
||||||
def on_result (self, host_dict, data, result):
|
|
||||||
idx, dflimg = result
|
|
||||||
self.result[idx] = (self.image_paths[idx], dflimg)
|
|
||||||
io.progress_bar_inc(1)
|
|
||||||
|
|
||||||
#override
|
|
||||||
def get_result(self):
|
|
||||||
return self.result
|
|
||||||
|
|
||||||
class Cli(Subprocessor.Cli):
|
|
||||||
#override
|
|
||||||
def on_initialize(self, client_dict):
|
|
||||||
pass
|
|
||||||
|
|
||||||
#override
|
|
||||||
def process_data(self, data):
|
|
||||||
idx, filename = data
|
|
||||||
dflimg = DFLIMG.load (Path(filename))
|
|
||||||
|
|
||||||
if dflimg is None:
|
|
||||||
self.log_err (f"FaceSamplesLoader: {filename} is not a dfl image file.")
|
|
||||||
data = None
|
|
||||||
else:
|
|
||||||
data = (dflimg.get_face_type(),
|
|
||||||
dflimg.get_shape(),
|
|
||||||
dflimg.get_landmarks(),
|
|
||||||
dflimg.get_ie_polys(),
|
|
||||||
dflimg.get_eyebrows_expand_mod(),
|
|
||||||
dflimg.get_source_filename() )
|
|
||||||
|
|
||||||
return idx, data
|
|
||||||
|
|
||||||
#override
|
|
||||||
def get_data_name (self, data):
|
|
||||||
#return string identificator of your data
|
|
||||||
return data[1]
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue