mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-11 15:47:01 -07:00
increased speed of sort by hist sim for ten thousands of faces
This commit is contained in:
parent
08550ac856
commit
b5ba7d52cb
6 changed files with 226 additions and 140 deletions
|
@ -65,14 +65,14 @@ class BlurEstimatorSubprocessor(SubprocessorBase):
|
|||
return len (self.input_data)
|
||||
|
||||
#override
|
||||
def onHostGetData(self):
|
||||
def onHostGetData(self, host_dict):
|
||||
if len (self.input_data) > 0:
|
||||
return self.input_data.pop(0)
|
||||
|
||||
return None
|
||||
|
||||
#override
|
||||
def onHostDataReturn (self, data):
|
||||
def onHostDataReturn (self, host_dict, data):
|
||||
self.input_data.insert(0, data)
|
||||
|
||||
#override
|
||||
|
@ -104,7 +104,7 @@ class BlurEstimatorSubprocessor(SubprocessorBase):
|
|||
return data[0]
|
||||
|
||||
#override
|
||||
def onHostResult (self, data, result):
|
||||
def onHostResult (self, host_dict, data, result):
|
||||
if result[1] == 0:
|
||||
filename_path = Path( data[0] )
|
||||
print ( "{0} - invalid image, renaming to {0}_invalid.".format(str(filename_path)) )
|
||||
|
@ -113,12 +113,8 @@ class BlurEstimatorSubprocessor(SubprocessorBase):
|
|||
self.result.append ( result )
|
||||
return 1
|
||||
|
||||
#override
|
||||
def onHostProcessEnd(self):
|
||||
pass
|
||||
|
||||
#override
|
||||
def get_start_return(self):
|
||||
def onFinalizeAndGetResult(self):
|
||||
return self.result
|
||||
|
||||
|
||||
|
@ -239,32 +235,121 @@ def sort_by_face_yaw(input_path):
|
|||
|
||||
return img_list
|
||||
|
||||
class HistSsimSubprocessor(SubprocessorBase):
|
||||
#override
|
||||
def __init__(self, img_list ):
|
||||
self.img_list = img_list
|
||||
self.img_list_len = len(img_list)
|
||||
|
||||
slice_count = 20000
|
||||
sliced_count = self.img_list_len // slice_count
|
||||
|
||||
if sliced_count > 12:
|
||||
sliced_count = 11.9
|
||||
slice_count = int(self.img_list_len / sliced_count)
|
||||
sliced_count = self.img_list_len // slice_count
|
||||
|
||||
self.img_chunks_list = [ self.img_list[i*slice_count : (i+1)*slice_count] for i in range(sliced_count) ] + \
|
||||
[ self.img_list[sliced_count*slice_count:] ]
|
||||
|
||||
self.result = []
|
||||
|
||||
super().__init__('HistSsim', 0)
|
||||
|
||||
#override
|
||||
def onHostClientsInitialized(self):
|
||||
pass
|
||||
|
||||
#override
|
||||
def process_info_generator(self):
|
||||
for i in range( len(self.img_chunks_list) ):
|
||||
yield 'CPU%d' % (i), {'i':i}, {'device_idx': i,
|
||||
'device_name': 'CPU%d' % (i)
|
||||
}
|
||||
|
||||
#override
|
||||
def get_no_process_started_message(self):
|
||||
print ( 'Unable to start CPU processes.')
|
||||
|
||||
#override
|
||||
def onHostGetProgressBarDesc(self):
|
||||
return "Sorting"
|
||||
|
||||
#override
|
||||
def onHostGetProgressBarLen(self):
|
||||
return len(self.img_list)
|
||||
|
||||
#override
|
||||
def onHostClientsInitialized(self):
|
||||
self.inc_progress_bar(len(self.img_chunks_list))
|
||||
|
||||
#override
|
||||
def onHostGetData(self, host_dict):
|
||||
if len (self.img_chunks_list) > 0:
|
||||
return self.img_chunks_list.pop(0)
|
||||
|
||||
return None
|
||||
|
||||
#override
|
||||
def onHostDataReturn (self, host_dict, data):
|
||||
raise Exception("Fail to process data. Decrease number of images and try again.")
|
||||
|
||||
#override
|
||||
def onClientInitialize(self, client_dict):
|
||||
self.safe_print ('Running on %s.' % (client_dict['device_name']) )
|
||||
return None
|
||||
|
||||
#override
|
||||
def onClientFinalize(self):
|
||||
pass
|
||||
|
||||
#override
|
||||
def onClientProcessData(self, data):
|
||||
|
||||
img_list = []
|
||||
for x in data:
|
||||
img = cv2.imread(x)
|
||||
img_list.append ([x, cv2.calcHist([img], [0], None, [256], [0, 256]),
|
||||
cv2.calcHist([img], [1], None, [256], [0, 256]),
|
||||
cv2.calcHist([img], [2], None, [256], [0, 256])
|
||||
])
|
||||
|
||||
img_list_len = len(img_list)
|
||||
for i in range(img_list_len-1):
|
||||
min_score = float("inf")
|
||||
j_min_score = i+1
|
||||
for j in range(i+1,len(img_list)):
|
||||
score = cv2.compareHist(img_list[i][1], img_list[j][1], cv2.HISTCMP_BHATTACHARYYA) + \
|
||||
cv2.compareHist(img_list[i][2], img_list[j][2], cv2.HISTCMP_BHATTACHARYYA) + \
|
||||
cv2.compareHist(img_list[i][3], img_list[j][3], cv2.HISTCMP_BHATTACHARYYA)
|
||||
if score < min_score:
|
||||
min_score = score
|
||||
j_min_score = j
|
||||
img_list[i+1], img_list[j_min_score] = img_list[j_min_score], img_list[i+1]
|
||||
|
||||
self.inc_progress_bar(1)
|
||||
|
||||
return img_list
|
||||
|
||||
#override
|
||||
def onClientGetDataName (self, data):
|
||||
#return string identificator of your data
|
||||
return "Bunch of images"
|
||||
|
||||
#override
|
||||
def onHostResult (self, host_dict, data, result):
|
||||
self.result += result
|
||||
return 0
|
||||
|
||||
#override
|
||||
def onFinalizeAndGetResult(self):
|
||||
return self.result
|
||||
|
||||
def sort_by_hist(input_path):
|
||||
print ("Sorting by histogram similarity...")
|
||||
|
||||
img_list = []
|
||||
for x in tqdm( Path_utils.get_image_paths(input_path), desc="Loading"):
|
||||
img = cv2.imread(x)
|
||||
img_list.append ([x, cv2.calcHist([img], [0], None, [256], [0, 256]),
|
||||
cv2.calcHist([img], [1], None, [256], [0, 256]),
|
||||
cv2.calcHist([img], [2], None, [256], [0, 256])
|
||||
])
|
||||
|
||||
img_list_len = len(img_list)
|
||||
for i in tqdm( range(0, img_list_len-1), desc="Sorting"):
|
||||
min_score = float("inf")
|
||||
j_min_score = i+1
|
||||
for j in range(i+1,len(img_list)):
|
||||
score = cv2.compareHist(img_list[i][1], img_list[j][1], cv2.HISTCMP_BHATTACHARYYA) + \
|
||||
cv2.compareHist(img_list[i][2], img_list[j][2], cv2.HISTCMP_BHATTACHARYYA) + \
|
||||
cv2.compareHist(img_list[i][3], img_list[j][3], cv2.HISTCMP_BHATTACHARYYA)
|
||||
if score < min_score:
|
||||
min_score = score
|
||||
j_min_score = j
|
||||
img_list[i+1], img_list[j_min_score] = img_list[j_min_score], img_list[i+1]
|
||||
|
||||
img_list = HistSsimSubprocessor(Path_utils.get_image_paths(input_path)).process()
|
||||
return img_list
|
||||
|
||||
|
||||
class HistDissimSubprocessor(SubprocessorBase):
|
||||
#override
|
||||
def __init__(self, img_list ):
|
||||
|
@ -300,14 +385,14 @@ class HistDissimSubprocessor(SubprocessorBase):
|
|||
return len (self.img_list)
|
||||
|
||||
#override
|
||||
def onHostGetData(self):
|
||||
def onHostGetData(self, host_dict):
|
||||
if len (self.img_list_range) > 0:
|
||||
return [self.img_list_range.pop(0)]
|
||||
|
||||
return None
|
||||
|
||||
#override
|
||||
def onHostDataReturn (self, data):
|
||||
def onHostDataReturn (self, host_dict, data):
|
||||
self.img_list_range.insert(0, data[0])
|
||||
|
||||
#override
|
||||
|
@ -339,16 +424,12 @@ class HistDissimSubprocessor(SubprocessorBase):
|
|||
return self.img_list[data[0]][0]
|
||||
|
||||
#override
|
||||
def onHostResult (self, data, result):
|
||||
def onHostResult (self, host_dict, data, result):
|
||||
self.img_list[data[0]][2] = result
|
||||
return 1
|
||||
|
||||
#override
|
||||
def onHostProcessEnd(self):
|
||||
pass
|
||||
|
||||
|
||||
#override
|
||||
def get_start_return(self):
|
||||
def onFinalizeAndGetResult(self):
|
||||
return self.img_list
|
||||
|
||||
def sort_by_hist_dissim(input_path):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue