mirror of https://github.com/iperov/DeepFaceLab.git
commit ea33541177 (parent 21b25038ac)

fix

2 changed files with 92 additions and 100 deletions
Changed file 1 of 2 (SampleHost.py):

@@ -1,5 +1,6 @@
 import multiprocessing
 import operator
+import pickle
 import traceback
 from pathlib import Path

@@ -14,23 +15,23 @@ from .Sample import Sample, SampleType


class SampleHost:



    samples_cache = dict()
    @staticmethod
    def get_person_id_max_count(samples_path):
        samples = None
        try:
            samples = samplelib.PackedFaceset.load(samples_path)
        except:
            io.log_err(f"Error occured while loading samplelib.PackedFaceset.load {str(samples_dat_path)}, {traceback.format_exc()}")

        if samples is None:
            raise ValueError("packed faceset not found.")
        persons_name_idxs = {}
        for sample in samples:
            persons_name_idxs[sample.person_name] = 0
        return len(list(persons_name_idxs.keys()))

@@ -49,41 +50,37 @@ class SampleHost:
        elif sample_type == SampleType.FACE or \
             sample_type == SampleType.FACE_TEMPORAL_SORTED:
            result = None

            if samples[sample_type] is None:
                try:
                    result = samplelib.PackedFaceset.load(samples_path)
                except:
                    io.log_err(f"Error occured while loading samplelib.PackedFaceset.load {str(samples_dat_path)}, {traceback.format_exc()}")

                if result is not None:
                    io.log_info (f"Loaded {len(result)} packed faces from {samples_path}")

                if result is None:
                    result = SampleHost.load_face_samples( Path_utils.get_image_paths(samples_path) )

-                samples[sample_type] = mp_utils.ListHost()

                if sample_type == SampleType.FACE_TEMPORAL_SORTED:
                    result = SampleHost.upgradeToFaceTemporalSortedSamples(result)

+                result_dumped = pickle.dumps(result)
+                del result
+                result = pickle.loads(result_dumped)
+                samples[sample_type] = mp_utils.ListHost(result)

            list_host = samples[sample_type]

            clis = [ list_host.create_cli() for _ in range(number_of_clis) ]

-            if result is not None:
-                while True:
-                    if len(result) == 0:
-                        break
-                    items = result[0:10000]
-                    del result[0:10000]
-                    clis[0].extend(items)
            return clis

        return samples[sample_type]

    @staticmethod
    def load_face_samples ( image_paths):
        result = FaceSamplesLoaderSubprocessor(image_paths).run()
        sample_list = []

@@ -93,7 +90,7 @@ class SampleHost:
              landmarks,
              ie_polys,
              eyebrows_expand_mod,
              source_filename,
            ) in result:
            sample_list.append( Sample(filename=filename,
                                       sample_type=SampleType.FACE,
@@ -102,7 +99,7 @@ class SampleHost:
                                       landmarks=landmarks,
                                       ie_polys=ie_polys,
                                       eyebrows_expand_mod=eyebrows_expand_mod,
                                       source_filename=source_filename,
                                      ))
        return sample_list

@@ -112,14 +109,14 @@ class SampleHost:
        new_s = sorted(new_s, key=operator.itemgetter(1))

        return [ s[0] for s in new_s]


class FaceSamplesLoaderSubprocessor(Subprocessor):
    #override
    def __init__(self, image_paths ):
        self.image_paths = image_paths
        self.image_paths_len = len(image_paths)
        self.idxs = [*range(self.image_paths_len)]
        self.result = [None]*self.image_paths_len
        super().__init__('FaceSamplesLoader', FaceSamplesLoaderSubprocessor.Cli, 60, initialize_subprocesses_in_serial=False)

@@ -169,7 +166,7 @@ class FaceSamplesLoaderSubprocessor(Subprocessor):
        def process_data(self, data):
            idx, filename = data
            dflimg = DFLIMG.load (Path(filename))

            if dflimg is None:
                self.log_err (f"FaceSamplesLoader: {filename} is not a dfl image file.")
                data = None
@@ -179,8 +176,8 @@ class FaceSamplesLoaderSubprocessor(Subprocessor):
                         dflimg.get_landmarks(),
                         dflimg.get_ie_polys(),
                         dflimg.get_eyebrows_expand_mod(),
                         dflimg.get_source_filename() )

            return idx, data

        #override
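As a reading aid, a minimal sketch of the new load path (hypothetical helper name host_samples, not the repo's code; it assumes the mp_utils.ListHost API changed in the second file below): the loaded sample list is round-tripped through pickle and handed to the host in one piece, instead of being extend()-ed into an empty host in 10000-item chunks as before.

    # Hypothetical sketch only, mirroring the flow in the diff above.
    import pickle

    def host_samples(result, number_of_clis, ListHost):
        # Serialize and rebuild the freshly loaded sample list, then
        # give the whole list to the host up front.
        result = pickle.loads(pickle.dumps(result))
        list_host = ListHost(result)
        # Each worker gets its own client handle onto the hosted list.
        return [list_host.create_cli() for _ in range(number_of_clis)]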
Changed file 2 of 2 (mp_utils.py):

@@ -1,25 +1,22 @@
 import multiprocessing
 import threading
 import time
-import traceback

 import numpy as np


class Index2DHost():
    """
    Provides random shuffled 2D indexes for multiprocesses
    """
-    def __init__(self, indexes2D, max_number_of_clis=128):
+    def __init__(self, indexes2D):
        self.sq = multiprocessing.Queue()
-        self.cqs = [ multiprocessing.Queue() for _ in range(max_number_of_clis) ]
-        self.n_clis = 0
-        self.max_number_of_clis = max_number_of_clis
-
-        self.p = multiprocessing.Process(target=self.host_proc, args=(indexes2D, self.sq, self.cqs) )
-        self.p.daemon = True
-        self.p.start()
-
-    def host_proc(self, indexes2D, sq, cqs):
+        self.cqs = []
+        self.clis = []
+        self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) )
+        self.thread.daemon = True
+        self.thread.start()
+
+    def host_thread(self, indexes2D):
        indexes_counts_len = len(indexes2D)

        idxs = [*range(indexes_counts_len)]
@@ -30,6 +27,8 @@ class Index2DHost():
            idxs_2D[i] = indexes2D[i]
            shuffle_idxs_2D[i] = []

+        sq = self.sq
+
        while True:
            while not sq.empty():
                obj = sq.get()
@@ -44,7 +43,7 @@ class Index2DHost():
                            shuffle_idxs = idxs.copy()
                            np.random.shuffle(shuffle_idxs)
                        result.append(shuffle_idxs.pop())
-                    cqs[cq_id].put (result)
+                    self.cqs[cq_id].put (result)
                elif cmd == 1: #get_2D
                    targ_idxs,count = obj[2], obj[3]
                    result = []
@@ -58,7 +57,7 @@ class Index2DHost():
                                np.random.shuffle(ar)
                            sub_idxs.append(ar.pop())
                        result.append (sub_idxs)
-                    cqs[cq_id].put (result)
+                    self.cqs[cq_id].put (result)

            time.sleep(0.005)

@@ -100,19 +99,18 @@ class IndexHost():
    """
    Provides random shuffled indexes for multiprocesses
    """
-    def __init__(self, indexes_count, max_number_of_clis=128):
+    def __init__(self, indexes_count):
        self.sq = multiprocessing.Queue()
-        self.cqs = [ multiprocessing.Queue() for _ in range(max_number_of_clis) ]
-        self.n_clis = 0
-        self.max_number_of_clis = max_number_of_clis
-
-        self.p = multiprocessing.Process(target=self.host_proc, args=(indexes_count, self.sq, self.cqs) )
-        self.p.daemon = True
-        self.p.start()
-
-    def host_proc(self, indexes_count, sq, cqs):
+        self.cqs = []
+        self.clis = []
+        self.thread = threading.Thread(target=self.host_thread, args=(indexes_count,) )
+        self.thread.daemon = True
+        self.thread.start()
+
+    def host_thread(self, indexes_count):
        idxs = [*range(indexes_count)]
        shuffle_idxs = []
+        sq = self.sq

        while True:
            while not sq.empty():
@@ -125,18 +123,15 @@ class IndexHost():
                        shuffle_idxs = idxs.copy()
                        np.random.shuffle(shuffle_idxs)
                    result.append(shuffle_idxs.pop())
-                cqs[cq_id].put (result)
+                self.cqs[cq_id].put (result)

            time.sleep(0.005)

    def create_cli(self):
-        if self.n_clis == self.max_number_of_clis:
-            raise Exception("")
-
-        cq_id = self.n_clis
-        self.n_clis += 1
-
-        return IndexHost.Cli(self.sq, self.cqs[cq_id], cq_id)
+        cq = multiprocessing.Queue()
+        self.cqs.append ( cq )
+        cq_id = len(self.cqs)-1
+        return IndexHost.Cli(self.sq, cq, cq_id)

    # disable pickling
    def __getstate__(self):
@@ -159,50 +154,50 @@ class IndexHost():
            time.sleep(0.001)

class ListHost():
-    def __init__(self, list_=None, max_number_of_clis=128):
+    def __init__(self, list_):
        self.sq = multiprocessing.Queue()
-        self.cqs = [ multiprocessing.Queue() for _ in range(max_number_of_clis) ]
-        self.n_clis = 0
-        self.max_number_of_clis = max_number_of_clis
-
-        self.p = multiprocessing.Process(target=self.host_proc, args=(self.sq, self.cqs) )
-        self.p.daemon = True
-        self.p.start()
-
-    def host_proc(self, sq, cqs):
-        m_list = list()
+        self.cqs = []
+        self.clis = []
+        self.m_list = list_
+        self.thread = threading.Thread(target=self.host_thread)
+        self.thread.daemon = True
+        self.thread.start()
+
+    def host_thread(self):
+        sq = self.sq

        while True:
            while not sq.empty():
                obj = sq.get()
                cq_id, cmd = obj[0], obj[1]

                if cmd == 0:
-                    cqs[cq_id].put ( len(m_list) )
+                    self.cqs[cq_id].put ( len(self.m_list) )
                elif cmd == 1:
                    idx = obj[2]
-                    item = m_list[idx ]
-                    cqs[cq_id].put ( item )
+                    item = self.m_list[idx ]
+                    self.cqs[cq_id].put ( item )
                elif cmd == 2:
                    result = []
                    for item in obj[2]:
-                        result.append ( m_list[item] )
-                    cqs[cq_id].put ( result )
+                        result.append ( self.m_list[item] )
+                    self.cqs[cq_id].put ( result )
                elif cmd == 3:
-                    m_list.insert(obj[2], obj[3])
+                    self.m_list.insert(obj[2], obj[3])
                elif cmd == 4:
-                    m_list.append(obj[2])
+                    self.m_list.append(obj[2])
                elif cmd == 5:
-                    m_list.extend(obj[2])
+                    self.m_list.extend(obj[2])

            time.sleep(0.005)

    def create_cli(self):
-        if self.n_clis == self.max_number_of_clis:
-            raise Exception("")
-
-        cq_id = self.n_clis
-        self.n_clis += 1
-
-        return ListHost.Cli(self.sq, self.cqs[cq_id], cq_id)
+        cq = multiprocessing.Queue()
+        self.cqs.append ( cq )
+        cq_id = len(self.cqs)-1
+        return ListHost.Cli(self.sq, cq, cq_id)
+
+    def get_list(self):
+        return self.list_

    # disable pickling
    def __getstate__(self):
@@ -223,7 +218,7 @@ class ListHost():
                if not self.cq.empty():
                    return self.cq.get()
                time.sleep(0.001)

        def __getitem__(self, key):
            self.sq.put ( (self.cq_id,1,key) )
@@ -231,7 +226,7 @@ class ListHost():
                if not self.cq.empty():
                    return self.cq.get()
                time.sleep(0.001)

        def multi_get(self, keys):
            self.sq.put ( (self.cq_id,2,keys) )
@@ -239,17 +234,17 @@ class ListHost():
                if not self.cq.empty():
                    return self.cq.get()
                time.sleep(0.001)

        def insert(self, index, item):
            self.sq.put ( (self.cq_id,3,index,item) )

        def append(self, item):
            self.sq.put ( (self.cq_id,4,item) )

        def extend(self, items):
            self.sq.put ( (self.cq_id,5,items) )



class DictHost():
    def __init__(self, d, num_users):
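To make the queue protocol above concrete, here is a standalone sketch (a simplified, hypothetical MiniListHost, not DeepFaceLab code) of how a host thread and its clients exchange messages: a client posts (cq_id, cmd, *args) on the shared send queue, and the host thread answers on that client's own reply queue. Only the len and getitem commands are shown.

    import multiprocessing
    import threading
    import time

    class MiniListHost:
        def __init__(self, list_):
            self.sq = multiprocessing.Queue()   # shared command queue (client -> host)
            self.cqs = []                       # one reply queue per client (host -> client)
            self.m_list = list_
            threading.Thread(target=self._host, daemon=True).start()

        def _host(self):
            # Poll the command queue forever and answer each client on its own queue.
            while True:
                while not self.sq.empty():
                    cq_id, cmd, *args = self.sq.get()
                    if cmd == 0:                                  # length of the hosted list
                        self.cqs[cq_id].put(len(self.m_list))
                    elif cmd == 1:                                # item by index
                        self.cqs[cq_id].put(self.m_list[args[0]])
                time.sleep(0.005)

        def create_cli(self):
            # Allocate a dedicated reply queue for the new client, as the diff does.
            cq = multiprocessing.Queue()
            self.cqs.append(cq)
            return MiniListHost.Cli(self.sq, cq, len(self.cqs) - 1)

        class Cli:
            def __init__(self, sq, cq, cq_id):
                self.sq, self.cq, self.cq_id = sq, cq, cq_id

            def __len__(self):
                self.sq.put((self.cq_id, 0))
                return self.cq.get()

            def __getitem__(self, key):
                self.sq.put((self.cq_id, 1, key))
                return self.cq.get()

    # usage (same process): cli = MiniListHost(list(range(100))).create_cli(); print(len(cli), cli[5])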