diff --git a/mainscripts/Sorter.py b/mainscripts/Sorter.py index 2b1e63d..a802f95 100644 --- a/mainscripts/Sorter.py +++ b/mainscripts/Sorter.py @@ -10,15 +10,134 @@ from pathlib import Path from utils import Path_utils from utils.AlignedPNG import AlignedPNG from facelib import LandmarksProcessor +from utils.SubprocessorBase import SubprocessorBase +import multiprocessing -def estimate_blur(image): + +def estimate_sharpness(image): + height, width = image.shape[:2] + if image.ndim == 3: image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) - blur_map = cv2.Laplacian(image, cv2.CV_64F) - score = np.var(blur_map) - return score + sharpness = 0 + for y in range(height): + for x in range(width-1): + sharpness += abs( int(image[y, x]) - int(image[y, x+1]) ) + + for x in range(width): + for y in range(height-1): + sharpness += abs( int(image[y, x]) - int(image[y+1, x]) ) + + return sharpness + + +class BlurEstimatorSubprocessor(SubprocessorBase): + #override + def __init__(self, input_data ): + self.input_data = input_data + self.result = [] + + super().__init__('BlurEstimator', 60) + + #override + def onHostClientsInitialized(self): + pass + #override + def process_info_generator(self): + for i in range(0, multiprocessing.cpu_count() ): + yield 'CPU%d' % (i), {}, {'device_idx': i, + 'device_name': 'CPU%d' % (i), + } + + #override + def get_no_process_started_message(self): + print ( 'Unable to start CPU processes.') + + #override + def onHostGetProgressBarDesc(self): + return None + + #override + def onHostGetProgressBarLen(self): + return len (self.input_data) + + #override + def onHostGetData(self): + if len (self.input_data) > 0: + return self.input_data.pop(0) + + return None + + #override + def onHostDataReturn (self, data): + self.input_data.insert(0, data) + + #override + def onClientInitialize(self, client_dict): + self.safe_print ('Running on %s.' % (client_dict['device_name']) ) + return None + + #override + def onClientFinalize(self): + pass + + #override + def onClientProcessData(self, data): + filename_path = Path( data[0] ) + image = cv2.imread( str(filename_path) ) + face_mask = None + + a_png = AlignedPNG.load( str(filename_path) ) + if a_png is not None: + d = a_png.getFaceswapDictData() + if (d is not None) and (d['landmarks'] is not None): + face_mask = LandmarksProcessor.get_image_hull_mask (image, np.array(d['landmarks'])) + + if face_mask is not None: + image = (image*face_mask).astype(np.uint8) + else: + print ( "%s - no embedded data found." % (str(filename_path)) ) + return [ str(filename_path), 0 ] + + return [ str(filename_path), estimate_sharpness( image ) ] + + #override + def onClientGetDataName (self, data): + #return string identificator of your data + return data[0] + + #override + def onHostResult (self, data, result): + if result[1] == 0: + filename_path = Path( data[0] ) + print ( "{0} - invalid image, renaming to {0}_invalid.".format(str(filename_path)) ) + filename_path.rename ( str(filename_path) + '_invalid' ) + else: + self.result.append ( result ) + return 1 + + #override + def onHostProcessEnd(self): + pass + + #override + def get_start_return(self): + return self.result + + +def sort_by_blur(input_path): + print ("Sorting by blur...") + + img_list = [ (filename,[]) for filename in Path_utils.get_image_paths(input_path) ] + img_list = BlurEstimatorSubprocessor (img_list).process() + + print ("Sorting...") + img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True) + + return img_list + def sort_by_brightness(input_path): print ("Sorting by brightness...") img_list = [ [x, np.mean ( cv2.cvtColor(cv2.imread(x), cv2.COLOR_BGR2HSV)[...,2].flatten() )] for x in tqdm( Path_utils.get_image_paths(input_path), desc="Loading") ] @@ -33,18 +152,6 @@ def sort_by_hue(input_path): img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True) return img_list -def sort_by_blur(input_path): - img_list = [] - print ("Sorting by blur...") - for filepath in tqdm( Path_utils.get_image_paths(input_path), desc="Loading"): - #never mask it by face hull, it worse than whole image blur estimate - img_list.append ( [filepath, estimate_blur (cv2.imread( filepath ))] ) - - print ("Sorting...") - img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True) - - return img_list - def sort_by_face(input_path): print ("Sorting by face similarity...") @@ -168,7 +275,7 @@ def sort_by_hist_blur(input_path): img_list.append ([x, cv2.calcHist([img], [0], None, [256], [0, 256]), cv2.calcHist([img], [1], None, [256], [0, 256]), cv2.calcHist([img], [2], None, [256], [0, 256]), - estimate_blur(img) + estimate_sharpness(img) ]) img_list_len = len(img_list)