Mirror of https://github.com/iperov/DeepFaceLab.git (synced 2025-07-07 05:22:06 -07:00)

Commit b5ba7d52cb: increased speed of sort by hist sim for ten thousands of faces
Parent: 08550ac856
6 changed files with 226 additions and 140 deletions
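The change parallelizes the "sort by histogram similarity" step across CPU subprocesses (the new `HistSsimSubprocessor` added below) and reworks the host callbacks of `SubprocessorBase` that all the subprocessors share. For orientation, here is a minimal sketch, not part of the commit, of the per-channel histogram distance that the sort chains on; `bgr_hist`, `hist_distance` and the file names are illustrative only:

```python
import cv2

def bgr_hist(path):
    # one 256-bin histogram per B, G and R channel, as computed in the sorter code below
    img = cv2.imread(path)
    return [cv2.calcHist([img], [c], None, [256], [0, 256]) for c in range(3)]

def hist_distance(hists_a, hists_b):
    # sum of Bhattacharyya distances over the three channels; lower means more similar
    return sum(cv2.compareHist(hists_a[c], hists_b[c], cv2.HISTCMP_BHATTACHARYYA)
               for c in range(3))

# usage sketch: hist_distance(bgr_hist('face_0001.jpg'), bgr_hist('face_0002.jpg'))
```

The sort repeatedly picks the unplaced image with the smallest such distance to the previous one, which is quadratic in the number of faces; the commit keeps that inner loop but runs it on smaller chunks in parallel.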
@@ -119,7 +119,7 @@ download from here: https://mega.nz/#F!y1ERHDaL!PPwg01PQZk0FhWLVo5_MaQ

 ### **Build info**

-dlib==19.10.0 from pip compiled without CUDA. Therefore you have to compile DLIB manually.
+dlib==19.10.0 from pip compiled without CUDA. Therefore you have to compile DLIB manually, orelse use MT extractor only.

 Command line example for windows: `python setup.py install -G "Visual Studio 14 2015" --yes DLIB_USE_CUDA`
@@ -104,13 +104,13 @@ class ConvertSubprocessor(SubprocessorBase):
         return len (self.input_data)

     #override
-    def onHostGetData(self):
+    def onHostGetData(self, host_dict):
         if len (self.input_data) > 0:
             return self.input_data.pop(0)
         return None

     #override
-    def onHostDataReturn (self, data):
+    def onHostDataReturn (self, host_dict, data):
         self.input_data.insert(0, data)

     #overridable
@@ -195,13 +195,13 @@ class ConvertSubprocessor(SubprocessorBase):
         return (files_processed, faces_processed)

     #override
-    def onHostResult (self, data, result):
+    def onHostResult (self, host_dict, data, result):
         self.files_processed += result[0]
         self.faces_processed += result[1]
         return 1

     #override
-    def get_start_return(self):
+    def onFinalizeAndGetResult(self):
         return self.files_processed, self.faces_processed

 def main (input_dir, output_dir, model_dir, model_name, aligned_dir=None, **in_options):
@@ -125,7 +125,7 @@ class ExtractSubprocessor(SubprocessorBase):
         return len (self.input_data)

     #override
-    def onHostGetData(self):
+    def onHostGetData(self, host_dict):
         if not self.manual:
             if len (self.input_data) > 0:
                 return self.input_data.pop(0)
@@ -237,7 +237,7 @@ class ExtractSubprocessor(SubprocessorBase):
         return None

     #override
-    def onHostDataReturn (self, data):
+    def onHostDataReturn (self, host_dict, data):
         if not self.manual:
             self.input_data.insert(0, data)

@@ -348,7 +348,7 @@ class ExtractSubprocessor(SubprocessorBase):
         return data[0]

     #override
-    def onHostResult (self, data, result):
+    def onHostResult (self, host_dict, data, result):
         if self.manual == True:
             self.landmarks = result[1][0][1]

@@ -374,12 +374,9 @@ class ExtractSubprocessor(SubprocessorBase):
         return 1

     #override
-    def onHostProcessEnd(self):
+    def onFinalizeAndGetResult(self):
         if self.manual == True:
             cv2.destroyAllWindows()
-
-    #override
-    def get_start_return(self):
         return self.result

     '''
@@ -65,14 +65,14 @@ class BlurEstimatorSubprocessor(SubprocessorBase):
         return len (self.input_data)

     #override
-    def onHostGetData(self):
+    def onHostGetData(self, host_dict):
         if len (self.input_data) > 0:
             return self.input_data.pop(0)

         return None

     #override
-    def onHostDataReturn (self, data):
+    def onHostDataReturn (self, host_dict, data):
         self.input_data.insert(0, data)

     #override
@@ -104,7 +104,7 @@ class BlurEstimatorSubprocessor(SubprocessorBase):
         return data[0]

     #override
-    def onHostResult (self, data, result):
+    def onHostResult (self, host_dict, data, result):
         if result[1] == 0:
             filename_path = Path( data[0] )
             print ( "{0} - invalid image, renaming to {0}_invalid.".format(str(filename_path)) )
@@ -114,11 +114,7 @@ class BlurEstimatorSubprocessor(SubprocessorBase):
         return 1

     #override
-    def onHostProcessEnd(self):
-        pass
-
-    #override
-    def get_start_return(self):
+    def onFinalizeAndGetResult(self):
         return self.result

@@ -239,30 +235,119 @@ def sort_by_face_yaw(input_path):

     return img_list

+class HistSsimSubprocessor(SubprocessorBase):
+    #override
+    def __init__(self, img_list ):
+        self.img_list = img_list
+        self.img_list_len = len(img_list)
+
+        slice_count = 20000
+        sliced_count = self.img_list_len // slice_count
+
+        if sliced_count > 12:
+            sliced_count = 11.9
+            slice_count = int(self.img_list_len / sliced_count)
+            sliced_count = self.img_list_len // slice_count
+
+        self.img_chunks_list = [ self.img_list[i*slice_count : (i+1)*slice_count] for i in range(sliced_count) ] + \
+                               [ self.img_list[sliced_count*slice_count:] ]
+
+        self.result = []
+
+        super().__init__('HistSsim', 0)
+
+    #override
+    def onHostClientsInitialized(self):
+        pass
+
+    #override
+    def process_info_generator(self):
+        for i in range( len(self.img_chunks_list) ):
+            yield 'CPU%d' % (i), {'i':i}, {'device_idx': i,
+                                           'device_name': 'CPU%d' % (i)
+                                          }
+
+    #override
+    def get_no_process_started_message(self):
+        print ( 'Unable to start CPU processes.')
+
+    #override
+    def onHostGetProgressBarDesc(self):
+        return "Sorting"
+
+    #override
+    def onHostGetProgressBarLen(self):
+        return len(self.img_list)
+
+    #override
+    def onHostClientsInitialized(self):
+        self.inc_progress_bar(len(self.img_chunks_list))
+
+    #override
+    def onHostGetData(self, host_dict):
+        if len (self.img_chunks_list) > 0:
+            return self.img_chunks_list.pop(0)
+        return None
+
+    #override
+    def onHostDataReturn (self, host_dict, data):
+        raise Exception("Fail to process data. Decrease number of images and try again.")
+
+    #override
+    def onClientInitialize(self, client_dict):
+        self.safe_print ('Running on %s.' % (client_dict['device_name']) )
+        return None
+
+    #override
+    def onClientFinalize(self):
+        pass
+
+    #override
+    def onClientProcessData(self, data):
+        img_list = []
+        for x in data:
+            img = cv2.imread(x)
+            img_list.append ([x, cv2.calcHist([img], [0], None, [256], [0, 256]),
+                                 cv2.calcHist([img], [1], None, [256], [0, 256]),
+                                 cv2.calcHist([img], [2], None, [256], [0, 256])
+                             ])
+
+        img_list_len = len(img_list)
+        for i in range(img_list_len-1):
+            min_score = float("inf")
+            j_min_score = i+1
+            for j in range(i+1,len(img_list)):
+                score = cv2.compareHist(img_list[i][1], img_list[j][1], cv2.HISTCMP_BHATTACHARYYA) + \
+                        cv2.compareHist(img_list[i][2], img_list[j][2], cv2.HISTCMP_BHATTACHARYYA) + \
+                        cv2.compareHist(img_list[i][3], img_list[j][3], cv2.HISTCMP_BHATTACHARYYA)
+                if score < min_score:
+                    min_score = score
+                    j_min_score = j
+            img_list[i+1], img_list[j_min_score] = img_list[j_min_score], img_list[i+1]
+
+            self.inc_progress_bar(1)
+
+        return img_list
+
+    #override
+    def onClientGetDataName (self, data):
+        #return string identificator of your data
+        return "Bunch of images"
+
+    #override
+    def onHostResult (self, host_dict, data, result):
+        self.result += result
+        return 0
+
+    #override
+    def onFinalizeAndGetResult(self):
+        return self.result
+
 def sort_by_hist(input_path):
     print ("Sorting by histogram similarity...")
-    img_list = []
-    for x in tqdm( Path_utils.get_image_paths(input_path), desc="Loading"):
-        img = cv2.imread(x)
-        img_list.append ([x, cv2.calcHist([img], [0], None, [256], [0, 256]),
-                             cv2.calcHist([img], [1], None, [256], [0, 256]),
-                             cv2.calcHist([img], [2], None, [256], [0, 256])
-                         ])
-
-    img_list_len = len(img_list)
-    for i in tqdm( range(0, img_list_len-1), desc="Sorting"):
-        min_score = float("inf")
-        j_min_score = i+1
-        for j in range(i+1,len(img_list)):
-            score = cv2.compareHist(img_list[i][1], img_list[j][1], cv2.HISTCMP_BHATTACHARYYA) + \
-                    cv2.compareHist(img_list[i][2], img_list[j][2], cv2.HISTCMP_BHATTACHARYYA) + \
-                    cv2.compareHist(img_list[i][3], img_list[j][3], cv2.HISTCMP_BHATTACHARYYA)
-            if score < min_score:
-                min_score = score
-                j_min_score = j
-        img_list[i+1], img_list[j_min_score] = img_list[j_min_score], img_list[i+1]
-
+    img_list = HistSsimSubprocessor(Path_utils.get_image_paths(input_path)).process()
     return img_list

 class HistDissimSubprocessor(SubprocessorBase):
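The new `HistSsimSubprocessor` above makes that quadratic pass tractable for tens of thousands of faces by splitting the path list into chunks, sorting each chunk in its own `CPU%d` subprocess (`onClientProcessData`), and concatenating the per-chunk results in `onHostResult`. A self-contained sketch of the chunking arithmetic it uses, with an illustrative function name:

```python
def split_into_chunks(img_list):
    # mirrors the slice_count / sliced_count arithmetic in HistSsimSubprocessor.__init__
    img_list_len = len(img_list)

    slice_count = 20000
    sliced_count = img_list_len // slice_count

    if sliced_count > 12:
        # cap the number of slices at roughly 12 by growing the slice size instead
        sliced_count = 11.9
        slice_count = int(img_list_len / sliced_count)
        sliced_count = img_list_len // slice_count

    return [img_list[i*slice_count : (i+1)*slice_count] for i in range(sliced_count)] + \
           [img_list[sliced_count*slice_count:]]

# e.g. 30000 paths -> a 20000-path chunk plus a 10000-path remainder chunk,
# each sorted independently by its own subprocess
```

Since every chunk is ordered independently, the speedup trades a single globally chained ordering for one ordering per chunk.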
@@ -300,14 +385,14 @@ class HistDissimSubprocessor(SubprocessorBase):
         return len (self.img_list)

     #override
-    def onHostGetData(self):
+    def onHostGetData(self, host_dict):
         if len (self.img_list_range) > 0:
             return [self.img_list_range.pop(0)]

         return None

     #override
-    def onHostDataReturn (self, data):
+    def onHostDataReturn (self, host_dict, data):
         self.img_list_range.insert(0, data[0])

     #override
@@ -339,16 +424,12 @@ class HistDissimSubprocessor(SubprocessorBase):
         return self.img_list[data[0]][0]

     #override
-    def onHostResult (self, data, result):
+    def onHostResult (self, host_dict, data, result):
         self.img_list[data[0]][2] = result
         return 1

     #override
-    def onHostProcessEnd(self):
-        pass
-
-    #override
-    def get_start_return(self):
+    def onFinalizeAndGetResult(self):
         return self.img_list

 def sort_by_hist_dissim(input_path):
@@ -12,13 +12,14 @@ class SampleProcessor(object):
        WARPED = 0x00000002,
        WARPED_TRANSFORMED = 0x00000004,
        TRANSFORMED = 0x00000008,
+       LANDMARKS_ARRAY = 0x00000010, #currently unused

-       FACE_ALIGN_HALF = 0x00000010,
-       FACE_ALIGN_FULL = 0x00000020,
-       FACE_ALIGN_HEAD = 0x00000040,
-       FACE_ALIGN_AVATAR = 0x00000080,
-       FACE_MASK_FULL = 0x00000100,
-       FACE_MASK_EYES = 0x00000200,
+       FACE_ALIGN_HALF = 0x00000100,
+       FACE_ALIGN_FULL = 0x00000200,
+       FACE_ALIGN_HEAD = 0x00000400,
+       FACE_ALIGN_AVATAR = 0x00000800,
+       FACE_MASK_FULL = 0x00001000,
+       FACE_MASK_EYES = 0x00002000,

        MODE_BGR = 0x01000000, #BGR
        MODE_G = 0x02000000, #Grayscale
@@ -47,7 +48,7 @@ class SampleProcessor(object):

        params = image_utils.gen_warp_params(source, sample_process_options.random_flip, rotation_range=sample_process_options.rotation_range, scale_range=sample_process_options.scale_range, tx_range=sample_process_options.tx_range, ty_range=sample_process_options.ty_range )

-       images = [[None]*3 for _ in range(4)]
+       images = [[None]*3 for _ in range(5)]

        sample_rnd_seed = np.random.randint(0x80000000)

@@ -65,6 +66,8 @@ class SampleProcessor(object):
                img_type = 2
            elif f & SampleProcessor.TypeFlags.TRANSFORMED != 0:
                img_type = 3
+           elif f & SampleProcessor.TypeFlags.LANDMARKS_ARRAY != 0:
+               img_type = 4
            else:
                raise ValueError ('expected SampleTypeFlags type')

@@ -84,57 +87,63 @@ class SampleProcessor(object):
            elif f & SampleProcessor.TypeFlags.FACE_ALIGN_AVATAR != 0:
                target_face_type = FaceType.AVATAR

-           if images[img_type][face_mask_type] is None:
-               img = source
-               if is_face_sample:
-                   if face_mask_type == 1:
-                       img = np.concatenate( (img, LandmarksProcessor.get_image_hull_mask (source, sample.landmarks) ), -1 )
-                   elif face_mask_type == 2:
-                       mask = LandmarksProcessor.get_image_eye_mask (source, sample.landmarks)
-                       mask = np.expand_dims (cv2.blur (mask, ( w // 32, w // 32 ) ), -1)
-                       mask[mask > 0.0] = 1.0
-                       img = np.concatenate( (img, mask ), -1 )
-
-               images[img_type][face_mask_type] = image_utils.warp_by_params (params, img, (img_type==1 or img_type==2), (img_type==2 or img_type==3), img_type != 0, face_mask_type == 0)
-
-           img = images[img_type][face_mask_type]
-
-           if is_face_sample and target_face_type != -1:
-               if target_face_type > sample.face_type:
-                   raise Exception ('sample %s type %s does not match model requirement %s. Consider extract necessary type of faces.' % (sample.filename, sample.face_type, target_face_type) )
-
-               img = cv2.warpAffine( img, LandmarksProcessor.get_transform_mat (sample.landmarks, size, target_face_type), (size,size), flags=cv2.INTER_LANCZOS4 )
-           else:
-               img = cv2.resize( img, (size,size), cv2.INTER_LANCZOS4 )
-
-           if random_sub_size != 0:
-               sub_size = size - random_sub_size
-               rnd_state = np.random.RandomState (sample_rnd_seed+random_sub_size)
-               start_x = rnd_state.randint(sub_size+1)
-               start_y = rnd_state.randint(sub_size+1)
-               img = img[start_y:start_y+sub_size,start_x:start_x+sub_size,:]
-
-           img_bgr  = img[...,0:3]
-           img_mask = img[...,3:4]
-
-           if f & SampleProcessor.TypeFlags.MODE_BGR != 0:
-               img = img
-           elif f & SampleProcessor.TypeFlags.MODE_BGR_SHUFFLE != 0:
-               img_bgr = np.take (img_bgr, np.random.permutation(img_bgr.shape[-1]), axis=-1)
-               img = np.concatenate ( (img_bgr,img_mask) , -1 )
-           elif f & SampleProcessor.TypeFlags.MODE_G != 0:
-               img = np.concatenate ( (np.expand_dims(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY),-1),img_mask) , -1 )
-           elif f & SampleProcessor.TypeFlags.MODE_GGG != 0:
-               img = np.concatenate ( ( np.repeat ( np.expand_dims(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY),-1), (3,), -1), img_mask), -1)
-           elif is_face_sample and f & SampleProcessor.TypeFlags.MODE_M != 0:
-               if face_mask_type== 0:
-                   raise ValueError ('no face_mask_type defined')
-               img = img_mask
-           else:
-               raise ValueError ('expected SampleTypeFlags mode')
-
-           if not debug and sample_process_options.normalize_tanh:
-               img = img * 2.0 - 1.0
+           if img_type == 4:
+               l = sample.landmarks
+               l = np.concatenate ( [ np.expand_dims(l[:,0] / w,-1), np.expand_dims(l[:,1] / h,-1) ], -1 )
+               l = np.clip(l, 0.0, 1.0)
+               img = l
+           else:
+               if images[img_type][face_mask_type] is None:
+                   img = source
+                   if is_face_sample:
+                       if face_mask_type == 1:
+                           img = np.concatenate( (img, LandmarksProcessor.get_image_hull_mask (source, sample.landmarks) ), -1 )
+                       elif face_mask_type == 2:
+                           mask = LandmarksProcessor.get_image_eye_mask (source, sample.landmarks)
+                           mask = np.expand_dims (cv2.blur (mask, ( w // 32, w // 32 ) ), -1)
+                           mask[mask > 0.0] = 1.0
+                           img = np.concatenate( (img, mask ), -1 )
+
+                   images[img_type][face_mask_type] = image_utils.warp_by_params (params, img, (img_type==1 or img_type==2), (img_type==2 or img_type==3), img_type != 0, face_mask_type == 0)
+
+               img = images[img_type][face_mask_type]
+
+               if is_face_sample and target_face_type != -1:
+                   if target_face_type > sample.face_type:
+                       raise Exception ('sample %s type %s does not match model requirement %s. Consider extract necessary type of faces.' % (sample.filename, sample.face_type, target_face_type) )
+
+                   img = cv2.warpAffine( img, LandmarksProcessor.get_transform_mat (sample.landmarks, size, target_face_type), (size,size), flags=cv2.INTER_LANCZOS4 )
+               else:
+                   img = cv2.resize( img, (size,size), cv2.INTER_LANCZOS4 )
+
+               if random_sub_size != 0:
+                   sub_size = size - random_sub_size
+                   rnd_state = np.random.RandomState (sample_rnd_seed+random_sub_size)
+                   start_x = rnd_state.randint(sub_size+1)
+                   start_y = rnd_state.randint(sub_size+1)
+                   img = img[start_y:start_y+sub_size,start_x:start_x+sub_size,:]
+
+               img_bgr  = img[...,0:3]
+               img_mask = img[...,3:4]
+
+               if f & SampleProcessor.TypeFlags.MODE_BGR != 0:
+                   img = img
+               elif f & SampleProcessor.TypeFlags.MODE_BGR_SHUFFLE != 0:
+                   img_bgr = np.take (img_bgr, np.random.permutation(img_bgr.shape[-1]), axis=-1)
+                   img = np.concatenate ( (img_bgr,img_mask) , -1 )
+               elif f & SampleProcessor.TypeFlags.MODE_G != 0:
+                   img = np.concatenate ( (np.expand_dims(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY),-1),img_mask) , -1 )
+               elif f & SampleProcessor.TypeFlags.MODE_GGG != 0:
+                   img = np.concatenate ( ( np.repeat ( np.expand_dims(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY),-1), (3,), -1), img_mask), -1)
+               elif is_face_sample and f & SampleProcessor.TypeFlags.MODE_M != 0:
+                   if face_mask_type== 0:
+                       raise ValueError ('no face_mask_type defined')
+                   img = img_mask
+               else:
+                   raise ValueError ('expected SampleTypeFlags mode')
+
+               if not debug and sample_process_options.normalize_tanh:
+                   img = img * 2.0 - 1.0

            outputs.append ( img )
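The SampleProcessor diff above also introduces a `LANDMARKS_ARRAY` sample type (`img_type == 4`) that skips the image pipeline and emits the face landmarks directly. A small sketch of the normalization it applies, assuming `w` and `h` are the source image width and height; the helper name is illustrative:

```python
import numpy as np

def normalized_landmarks(landmarks, w, h):
    # landmarks: (N, 2) array of (x, y) pixel coordinates
    l = np.concatenate([np.expand_dims(landmarks[:, 0] / w, -1),
                        np.expand_dims(landmarks[:, 1] / h, -1)], -1)
    # clamp to [0, 1], exactly as the new branch does
    return np.clip(l, 0.0, 1.0)
```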
@@ -10,6 +10,7 @@ class SubprocessorBase(object):
     def __init__(self, name, no_response_time_sec = 60):
         self.name = name
         self.no_response_time_sec = no_response_time_sec
+        self.is_host = True

     #overridable
     def process_info_generator(self):
@@ -29,12 +30,12 @@ class SubprocessorBase(object):
         return 0

     #overridable
-    def onHostGetData(self):
+    def onHostGetData(self, host_dict):
         #return data here
         return None

     #overridable
-    def onHostDataReturn (self, data):
+    def onHostDataReturn (self, host_dict, data):
         #input_data.insert(0, obj['data'])
         pass

@@ -62,21 +63,20 @@ class SubprocessorBase(object):
         pass

     #overridable
-    def onHostResult (self, data, result):
+    def onHostResult (self, host_dict, data, result):
         #return count of progress bar update
         return 1

     #overridable
-    def onHostProcessEnd(self):
-        pass
-
-    #overridable
-    def get_start_return(self):
+    def onFinalizeAndGetResult(self):
         return None

     def inc_progress_bar(self, c):
-        self.progress_bar.n += c
-        self.progress_bar.refresh()
+        if self.is_host:
+            self.progress_bar.n += c
+            self.progress_bar.refresh()
+        else:
+            self.cq.put ( {'op': 'inc_bar', 'c':c} )

     def safe_print(self, msg):
         self.print_lock.acquire()
@@ -132,11 +132,10 @@ class SubprocessorBase(object):

         if len(self.processes) == 0:
             print ( self.get_no_process_started_message() )
-            return self.get_start_return()
-
-        self.onHostClientsInitialized()
+            return None

         self.progress_bar = tqdm( total=self.onHostGetProgressBarLen(), desc=self.onHostGetProgressBarDesc() )
+        self.onHostClientsInitialized()

         try:
             while True:
@@ -149,15 +148,16 @@ class SubprocessorBase(object):
                         data = obj['data']
                         result = obj['result']

-                        c = self.onHostResult (data, result)
+                        c = self.onHostResult (p['host_dict'], data, result)
                         if c > 0:
                             self.progress_bar.update(c)

                         p['state'] = 'free'
+                    elif obj_op == 'inc_bar':
+                        self.inc_progress_bar(obj['c'])
                     elif obj_op == 'error':
                         if 'data' in obj.keys():
-                            self.onHostDataReturn ( obj['data'] )
+                            self.onHostDataReturn (p['host_dict'], obj['data'] )

                         if obj['close'] == True:
                             p['sq'].put ( {'op': 'close'} )
@@ -168,7 +168,7 @@ class SubprocessorBase(object):

                 for p in self.processes[:]:
                     if p['state'] == 'free':
-                        data = self.onHostGetData()
+                        data = self.onHostGetData(p['host_dict'])
                         if data is not None:
                             p['sq'].put ( {'op': 'data', 'data' : data} )
                             p['sent_time'] = time.time()
@@ -176,9 +176,9 @@ class SubprocessorBase(object):
                             p['state'] = 'busy'

                     elif p['state'] == 'busy':
-                        if (time.time() - p['sent_time']) > self.no_response_time_sec:
+                        if self.no_response_time_sec != 0 and (time.time() - p['sent_time']) > self.no_response_time_sec:
                             print ( '%s doesnt response, terminating it.' % (p['name']) )
-                            self.onHostDataReturn ( p['sent_data'] )
+                            self.onHostDataReturn (p['host_dict'], p['sent_data'] )
                             p['process'].terminate()
                             self.processes.remove(p)

@@ -205,7 +205,7 @@ class SubprocessorBase(object):
                         terminate_it = True
                         break

-                    if (time.time() - p['sent_time']) > self.no_response_time_sec:
+                    if self.no_response_time_sec != 0 and (time.time() - p['sent_time']) > self.no_response_time_sec:
                         terminate_it = True

                 if terminate_it:
@@ -215,13 +215,12 @@ class SubprocessorBase(object):
                 if all ([p['state'] == 'finalized' for p in self.processes]):
                     break

-        self.onHostProcessEnd()
-
-        return self.get_start_return()
+        return self.onFinalizeAndGetResult()

     def subprocess(self, sq, cq, client_dict):
+        self.is_host = False
         self.print_lock = client_dict['print_lock']
+        self.cq = cq
         try:
             fail_message = self.onClientInitialize(client_dict)
         except:
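Taken together, the `SubprocessorBase` changes thread a per-process `host_dict` through every host callback, merge `onHostProcessEnd` and `get_start_return` into a single `onFinalizeAndGetResult`, let `no_response_time_sec = 0` disable the no-response watchdog, and let clients forward progress-bar increments to the host. A hedged skeleton of a subclass under the new API; the class and its workload are invented for illustration, the import path is assumed from the repo layout, and callbacks not overridden here are assumed to fall back to the base-class defaults shown in the diff:

```python
from utils.SubprocessorBase import SubprocessorBase  # assumed import path

class SquareSubprocessor(SubprocessorBase):
    def __init__(self, numbers):
        self.data = list(numbers)
        self.result = []
        super().__init__('Square', no_response_time_sec=0)  # 0 disables the watchdog

    #override
    def process_info_generator(self):
        # one (name, host_dict, client_dict) triple per worker, as HistSsimSubprocessor does
        for i in range(4):
            yield 'CPU%d' % i, {'i': i}, {'device_name': 'CPU%d' % i}

    #override
    def onHostGetProgressBarDesc(self):
        return "Squaring"

    #override
    def onHostGetProgressBarLen(self):
        return len(self.data)

    #override
    def onHostGetData(self, host_dict):            # host_dict identifies the asking worker
        return self.data.pop(0) if len(self.data) > 0 else None

    #override
    def onHostDataReturn(self, host_dict, data):   # work item comes back if the worker failed
        self.data.insert(0, data)

    #override
    def onClientProcessData(self, data):
        return data * data

    #override
    def onHostResult(self, host_dict, data, result):
        self.result.append(result)
        return 1                                   # progress-bar increment

    #override
    def onFinalizeAndGetResult(self):              # replaces onHostProcessEnd + get_start_return
        return self.result

# usage sketch: squares = SquareSubprocessor(range(100)).process()
```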