improved sort by blur. Now it's better than laplacian, and multiprocessed.

This commit is contained in:
iperov 2018-12-18 16:51:36 +04:00
parent 1111aa519e
commit 43da6f84d6

View file

@ -10,15 +10,134 @@ from pathlib import Path
from utils import Path_utils
from utils.AlignedPNG import AlignedPNG
from facelib import LandmarksProcessor
from utils.SubprocessorBase import SubprocessorBase
import multiprocessing
def estimate_blur(image):
def estimate_sharpness(image):
height, width = image.shape[:2]
if image.ndim == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
blur_map = cv2.Laplacian(image, cv2.CV_64F)
score = np.var(blur_map)
return score
sharpness = 0
for y in range(height):
for x in range(width-1):
sharpness += abs( int(image[y, x]) - int(image[y, x+1]) )
for x in range(width):
for y in range(height-1):
sharpness += abs( int(image[y, x]) - int(image[y+1, x]) )
return sharpness
class BlurEstimatorSubprocessor(SubprocessorBase):
#override
def __init__(self, input_data ):
self.input_data = input_data
self.result = []
super().__init__('BlurEstimator', 60)
#override
def onHostClientsInitialized(self):
pass
#override
def process_info_generator(self):
for i in range(0, multiprocessing.cpu_count() ):
yield 'CPU%d' % (i), {}, {'device_idx': i,
'device_name': 'CPU%d' % (i),
}
#override
def get_no_process_started_message(self):
print ( 'Unable to start CPU processes.')
#override
def onHostGetProgressBarDesc(self):
return None
#override
def onHostGetProgressBarLen(self):
return len (self.input_data)
#override
def onHostGetData(self):
if len (self.input_data) > 0:
return self.input_data.pop(0)
return None
#override
def onHostDataReturn (self, data):
self.input_data.insert(0, data)
#override
def onClientInitialize(self, client_dict):
self.safe_print ('Running on %s.' % (client_dict['device_name']) )
return None
#override
def onClientFinalize(self):
pass
#override
def onClientProcessData(self, data):
filename_path = Path( data[0] )
image = cv2.imread( str(filename_path) )
face_mask = None
a_png = AlignedPNG.load( str(filename_path) )
if a_png is not None:
d = a_png.getFaceswapDictData()
if (d is not None) and (d['landmarks'] is not None):
face_mask = LandmarksProcessor.get_image_hull_mask (image, np.array(d['landmarks']))
if face_mask is not None:
image = (image*face_mask).astype(np.uint8)
else:
print ( "%s - no embedded data found." % (str(filename_path)) )
return [ str(filename_path), 0 ]
return [ str(filename_path), estimate_sharpness( image ) ]
#override
def onClientGetDataName (self, data):
#return string identificator of your data
return data[0]
#override
def onHostResult (self, data, result):
if result[1] == 0:
filename_path = Path( data[0] )
print ( "{0} - invalid image, renaming to {0}_invalid.".format(str(filename_path)) )
filename_path.rename ( str(filename_path) + '_invalid' )
else:
self.result.append ( result )
return 1
#override
def onHostProcessEnd(self):
pass
#override
def get_start_return(self):
return self.result
def sort_by_blur(input_path):
print ("Sorting by blur...")
img_list = [ (filename,[]) for filename in Path_utils.get_image_paths(input_path) ]
img_list = BlurEstimatorSubprocessor (img_list).process()
print ("Sorting...")
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
return img_list
def sort_by_brightness(input_path):
print ("Sorting by brightness...")
img_list = [ [x, np.mean ( cv2.cvtColor(cv2.imread(x), cv2.COLOR_BGR2HSV)[...,2].flatten() )] for x in tqdm( Path_utils.get_image_paths(input_path), desc="Loading") ]
@ -33,18 +152,6 @@ def sort_by_hue(input_path):
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
return img_list
def sort_by_blur(input_path):
img_list = []
print ("Sorting by blur...")
for filepath in tqdm( Path_utils.get_image_paths(input_path), desc="Loading"):
#never mask it by face hull, it worse than whole image blur estimate
img_list.append ( [filepath, estimate_blur (cv2.imread( filepath ))] )
print ("Sorting...")
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
return img_list
def sort_by_face(input_path):
print ("Sorting by face similarity...")
@ -168,7 +275,7 @@ def sort_by_hist_blur(input_path):
img_list.append ([x, cv2.calcHist([img], [0], None, [256], [0, 256]),
cv2.calcHist([img], [1], None, [256], [0, 256]),
cv2.calcHist([img], [2], None, [256], [0, 256]),
estimate_blur(img)
estimate_sharpness(img)
])
img_list_len = len(img_list)