mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-11 15:47:01 -07:00
increased speed of sort by hist
This commit is contained in:
parent
5835a8832a
commit
d04e8b1d91
1 changed files with 109 additions and 14 deletions
|
@ -238,9 +238,96 @@ def sort_by_face_yaw(input_path):
|
||||||
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
|
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
|
||||||
|
|
||||||
return img_list
|
return img_list
|
||||||
|
|
||||||
|
class HistSimSubprocessor(SubprocessorBase):
|
||||||
|
#override
|
||||||
|
def __init__(self, img_list):
|
||||||
|
self.img_list = img_list
|
||||||
|
self.img_list_range = [i for i in range(0, len(img_list) )]
|
||||||
|
self.result = [0]*len(img_list)
|
||||||
|
super().__init__('HistSim', 60)
|
||||||
|
|
||||||
|
#override
|
||||||
|
def onHostClientsInitialized(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
#override
|
||||||
|
def process_info_generator(self):
|
||||||
|
for i in range(0, multiprocessing.cpu_count()):
|
||||||
|
yield 'CPU%d' % (i), {}, {'device_idx': i,
|
||||||
|
'device_name': 'CPU%d' % (i),
|
||||||
|
'img_list' : self.img_list
|
||||||
|
}
|
||||||
|
|
||||||
|
#override
|
||||||
|
def get_no_process_started_message(self):
|
||||||
|
print ( 'Unable to start CPU processes.')
|
||||||
|
|
||||||
|
#override
|
||||||
|
def onHostGetProgressBarDesc(self):
|
||||||
|
return "Computing"
|
||||||
|
|
||||||
|
#override
|
||||||
|
def onHostGetProgressBarLen(self):
|
||||||
|
return len (self.img_list)
|
||||||
|
|
||||||
|
#override
|
||||||
|
def onHostGetData(self):
|
||||||
|
if len (self.img_list_range) > 0:
|
||||||
|
return [self.img_list_range.pop(0)]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
#override
|
||||||
|
def onHostDataReturn (self, data):
|
||||||
|
self.img_list_range.insert(0, data[0])
|
||||||
|
|
||||||
|
#override
|
||||||
|
def onClientInitialize(self, client_dict):
|
||||||
|
self.img_list = client_dict['img_list']
|
||||||
|
self.img_list_len = len(self.img_list)
|
||||||
|
|
||||||
|
self.safe_print ('Running on %s.' % (client_dict['device_name']) )
|
||||||
|
return None
|
||||||
|
|
||||||
|
#override
|
||||||
|
def onClientFinalize(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
#override
|
||||||
|
def onClientProcessData(self, data):
|
||||||
|
i = data[0]
|
||||||
|
|
||||||
|
result = [0]*self.img_list_len
|
||||||
|
for j in range( 0, self.img_list_len):
|
||||||
|
if i == j:
|
||||||
|
continue
|
||||||
|
|
||||||
|
result[j] = cv2.compareHist(self.img_list[i][1], self.img_list[j][1], cv2.HISTCMP_BHATTACHARYYA) + \
|
||||||
|
cv2.compareHist(self.img_list[i][2], self.img_list[j][2], cv2.HISTCMP_BHATTACHARYYA) + \
|
||||||
|
cv2.compareHist(self.img_list[i][3], self.img_list[j][3], cv2.HISTCMP_BHATTACHARYYA)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
#override
|
||||||
|
def onClientGetDataName (self, data):
|
||||||
|
#return string identificator of your data
|
||||||
|
return self.img_list[data[0]][0]
|
||||||
|
|
||||||
|
#override
|
||||||
|
def onHostResult (self, data, result):
|
||||||
|
self.result[data[0]] = result
|
||||||
|
return 1
|
||||||
|
|
||||||
|
#override
|
||||||
|
def onHostProcessEnd(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
#override
|
||||||
|
def get_start_return(self):
|
||||||
|
return self.result
|
||||||
|
|
||||||
def sort_by_hist(input_path):
|
def sort_by_hist(input_path):
|
||||||
|
|
||||||
print ("Sorting by histogram similarity...")
|
print ("Sorting by histogram similarity...")
|
||||||
|
|
||||||
img_list = []
|
img_list = []
|
||||||
|
@ -250,21 +337,29 @@ def sort_by_hist(input_path):
|
||||||
cv2.calcHist([img], [1], None, [256], [0, 256]),
|
cv2.calcHist([img], [1], None, [256], [0, 256]),
|
||||||
cv2.calcHist([img], [2], None, [256], [0, 256])
|
cv2.calcHist([img], [2], None, [256], [0, 256])
|
||||||
])
|
])
|
||||||
|
|
||||||
img_list_len = len(img_list)
|
img_list_2D = HistSimSubprocessor(img_list).process()
|
||||||
for i in tqdm( range(0, img_list_len-1), desc="Sorting"):
|
|
||||||
min_score = float("inf")
|
rem_list = [i for i in range(len(img_list))]
|
||||||
j_min_score = i+1
|
|
||||||
for j in range(i+1,len(img_list)):
|
s_list = []
|
||||||
score = cv2.compareHist(img_list[i][1], img_list[j][1], cv2.HISTCMP_BHATTACHARYYA) + \
|
s_list.append ( rem_list.pop(0) )
|
||||||
cv2.compareHist(img_list[i][2], img_list[j][2], cv2.HISTCMP_BHATTACHARYYA) + \
|
|
||||||
cv2.compareHist(img_list[i][3], img_list[j][3], cv2.HISTCMP_BHATTACHARYYA)
|
for x in tqdm( range(len(rem_list)), desc="Sorting"):
|
||||||
|
i = s_list[ len(s_list) -1 ]
|
||||||
|
|
||||||
|
min_score_id = rem_list[0]
|
||||||
|
min_score = img_list_2D[i][min_score_id]
|
||||||
|
|
||||||
|
for score_id in rem_list:
|
||||||
|
score = img_list_2D[i][score_id]
|
||||||
if score < min_score:
|
if score < min_score:
|
||||||
min_score = score
|
min_score = score
|
||||||
j_min_score = j
|
min_score_id = score_id
|
||||||
img_list[i+1], img_list[j_min_score] = img_list[j_min_score], img_list[i+1]
|
|
||||||
|
s_list.append ( rem_list.pop(rem_list.index(min_score_id)) )
|
||||||
|
|
||||||
return img_list
|
return [ (img_list[i][0],) for i in s_list ]
|
||||||
|
|
||||||
class HistDissimSubprocessor(SubprocessorBase):
|
class HistDissimSubprocessor(SubprocessorBase):
|
||||||
#override
|
#override
|
||||||
|
@ -337,7 +432,7 @@ class HistDissimSubprocessor(SubprocessorBase):
|
||||||
#override
|
#override
|
||||||
def onClientGetDataName (self, data):
|
def onClientGetDataName (self, data):
|
||||||
#return string identificator of your data
|
#return string identificator of your data
|
||||||
return data[1]
|
return self.img_list[data[0]][0]
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onHostResult (self, data, result):
|
def onHostResult (self, data, result):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue