mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
fixes
This commit is contained in:
parent
9ed0111824
commit
7e609542db
2 changed files with 96 additions and 8 deletions
|
@ -1,15 +1,17 @@
|
|||
import multiprocessing
|
||||
import operator
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
import samplelib.PackedFaceset
|
||||
from DFLIMG import *
|
||||
from facelib import FaceType, LandmarksProcessor
|
||||
from interact import interact as io
|
||||
from joblib import Subprocessor
|
||||
from utils import Path_utils, mp_utils
|
||||
from DFLIMG import *
|
||||
|
||||
from .Sample import Sample, SampleType
|
||||
|
||||
import samplelib.PackedFaceset
|
||||
|
||||
class SampleHost:
|
||||
samples_cache = dict()
|
||||
|
@ -79,17 +81,32 @@ class SampleHost:
|
|||
return hosts[sample_type]
|
||||
|
||||
@staticmethod
|
||||
def load_face_samples ( image_paths, silent=False):
|
||||
def load_face_samples ( image_paths):
|
||||
result = FaceSamplesLoaderSubprocessor(image_paths).run()
|
||||
sample_list = []
|
||||
|
||||
for filename in (image_paths if silent else io.progress_bar_generator( image_paths, "Loading")):
|
||||
for filename,dflimg in io.progress_bar_generator(result, "Processing"):
|
||||
sample_list.append( Sample(filename=filename,
|
||||
sample_type=SampleType.FACE,
|
||||
face_type=FaceType.fromString (dflimg.get_face_type()),
|
||||
shape=dflimg.get_shape(),
|
||||
landmarks=dflimg.get_landmarks(),
|
||||
ie_polys=dflimg.get_ie_polys(),
|
||||
pitch_yaw_roll=dflimg.get_pitch_yaw_roll(),
|
||||
eyebrows_expand_mod=dflimg.get_eyebrows_expand_mod(),
|
||||
source_filename=dflimg.get_source_filename(),
|
||||
))
|
||||
return sample_list
|
||||
|
||||
"""
|
||||
sample_list = []
|
||||
|
||||
for filename in io.progress_bar_generator(image_paths, "Loading"):
|
||||
filename_path = Path(filename)
|
||||
try:
|
||||
dflimg = DFLIMG.load (filename_path)
|
||||
|
||||
if dflimg is None:
|
||||
io.log_err ("load_face_samples: %s is not a dfl image file required for training" % (filename_path.name) )
|
||||
continue
|
||||
|
||||
|
||||
sample_list.append( Sample(filename=filename,
|
||||
sample_type=SampleType.FACE,
|
||||
|
@ -105,6 +122,7 @@ class SampleHost:
|
|||
io.log_err ("Unable to load %s , error: %s" % (filename, traceback.format_exc() ) )
|
||||
|
||||
return sample_list
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def upgradeToFaceTemporalSortedSamples( samples ):
|
||||
|
@ -112,3 +130,70 @@ class SampleHost:
|
|||
new_s = sorted(new_s, key=operator.itemgetter(1))
|
||||
|
||||
return [ s[0] for s in new_s]
|
||||
|
||||
|
||||
class FaceSamplesLoaderSubprocessor(Subprocessor):
|
||||
#override
|
||||
def __init__(self, image_paths ):
|
||||
self.image_paths = image_paths
|
||||
self.image_paths_len = len(image_paths)
|
||||
self.idxs = [*range(self.image_paths_len)]
|
||||
self.result = [None]*self.image_paths_len
|
||||
super().__init__('FaceSamplesLoader', FaceSamplesLoaderSubprocessor.Cli, 60, initialize_subprocesses_in_serial=False)
|
||||
|
||||
#override
|
||||
def on_clients_initialized(self):
|
||||
io.progress_bar ("Loading", len (self.image_paths))
|
||||
|
||||
#override
|
||||
def on_clients_finalized(self):
|
||||
io.progress_bar_close()
|
||||
|
||||
#override
|
||||
def process_info_generator(self):
|
||||
for i in range(min(multiprocessing.cpu_count(), 8) ):
|
||||
yield 'CPU%d' % (i), {}, {'device_idx': i,
|
||||
'device_name': 'CPU%d' % (i),
|
||||
}
|
||||
|
||||
#override
|
||||
def get_data(self, host_dict):
|
||||
if len (self.idxs) > 0:
|
||||
idx = self.idxs.pop(0)
|
||||
return idx, self.image_paths[idx]
|
||||
|
||||
return None
|
||||
|
||||
#override
|
||||
def on_data_return (self, host_dict, data):
|
||||
self.idxs.insert(0, data[0])
|
||||
|
||||
#override
|
||||
def on_result (self, host_dict, data, result):
|
||||
idx, dflimg = result
|
||||
self.result[idx] = (self.image_paths[idx], dflimg)
|
||||
io.progress_bar_inc(1)
|
||||
|
||||
#override
|
||||
def get_result(self):
|
||||
return self.result
|
||||
|
||||
class Cli(Subprocessor.Cli):
|
||||
#override
|
||||
def on_initialize(self, client_dict):
|
||||
self.log_info ('Running on %s.' % (client_dict['device_name']) )
|
||||
|
||||
#override
|
||||
def process_data(self, data):
|
||||
idx, filename = data
|
||||
dflimg = DFLIMG.load (Path(filename))
|
||||
|
||||
if dflimg is None:
|
||||
self.log_err (f"FaceSamplesLoader: {filename} is not a dfl image file.")
|
||||
|
||||
return idx, dflimg
|
||||
|
||||
#override
|
||||
def get_data_name (self, data):
|
||||
#return string identificator of your data
|
||||
return data[1]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue