DeepFaceLab/mainscripts/Sorter.py
Colombo 6d3607a13d New script:
5.XSeg) data_dst/src mask for XSeg trainer - fetch.bat
Copies faces containing XSeg polygons to the aligned_xseg\ directory.
Useful only if you want to collect labeled faces and reuse them in other fakes.

Now you can use the trained XSeg mask in the SAEHD training process.
This means the default ‘full_face’ mask obtained from landmarks will be replaced with the mask obtained from the trained XSeg model.
Use
5.XSeg.optional) trained mask for data_dst/data_src - apply.bat
5.XSeg.optional) trained mask for data_dst/data_src - remove.bat

Normally you don’t need it. You can use it if you want to use ‘face_style’ and ‘bg_style’ with obstructions.

XSeg trainer: you can now choose the face type.
XSeg trainer: you can now restart training in “override settings”.
Merger: XSeg-* modes can now be used with all face types.

Therefore the old MaskEditor, FANSEG models, and FAN-x modes have been removed,
because the new XSeg solution is better, simpler, and more convenient, and costs only about an hour of manual masking for a regular deepfake.
2020-03-30 14:00:40 +04:00


import math
import multiprocessing
import operator
import os
import sys
import tempfile
from functools import cmp_to_key
from pathlib import Path
import cv2
import numpy as np
from numpy import linalg as npla
from core import imagelib, mathlib, pathex
from core.cv2ex import *
from core.imagelib import estimate_sharpness
from core.interact import interact as io
from core.joblib import Subprocessor
from core.leras import nn
from DFLIMG import *
from facelib import LandmarksProcessor
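# Estimates the sharpness of every aligned face in parallel CPU workers.
# Files that are not DFL images get a score of 0 and are routed to the trash list by the host.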
class BlurEstimatorSubprocessor(Subprocessor):
class Cli(Subprocessor.Cli):
#override
def process_data(self, data):
filepath = Path( data[0] )
dflimg = DFLIMG.load (filepath)
if dflimg is None or not dflimg.has_data():
self.log_err (f"{filepath.name} is not a dfl image file")
return [ str(filepath), 0 ]
else:
image = cv2_imread( str(filepath) )
return [ str(filepath), estimate_sharpness(image) ]
#override
def get_data_name (self, data):
#return a string identifier for the data
return data[0]
#override
def __init__(self, input_data ):
self.input_data = input_data
self.img_list = []
self.trash_img_list = []
super().__init__('BlurEstimator', BlurEstimatorSubprocessor.Cli, 60)
#override
def on_clients_initialized(self):
io.progress_bar ("", len (self.input_data))
#override
def on_clients_finalized(self):
io.progress_bar_close ()
#override
def process_info_generator(self):
cpu_count = multiprocessing.cpu_count()
io.log_info(f'Running on {cpu_count} CPUs')
for i in range(cpu_count):
yield 'CPU%d' % (i), {}, {}
#override
def get_data(self, host_dict):
if len (self.input_data) > 0:
return self.input_data.pop(0)
return None
#override
def on_data_return (self, host_dict, data):
self.input_data.insert(0, data)
#override
def on_result (self, host_dict, data, result):
if result[1] == 0:
self.trash_img_list.append ( result )
else:
self.img_list.append ( result )
io.progress_bar_inc(1)
#override
def get_result(self):
return self.img_list, self.trash_img_list
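# Orders faces from sharpest to blurriest using the parallel sharpness estimator above;
# unreadable files are returned separately so they can be trashed.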
def sort_by_blur(input_path):
io.log_info ("Sorting by blur...")
img_list = [ (filename,[]) for filename in pathex.get_image_paths(input_path) ]
img_list, trash_img_list = BlurEstimatorSubprocessor (img_list).run()
io.log_info ("Sorting...")
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
return img_list, trash_img_list
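# The next three sorters order faces by head yaw, by head pitch (both estimated from the landmarks),
# or by the area of the face rectangle in the source frame, largest value first.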
def sort_by_face_yaw(input_path):
io.log_info ("Sorting by face yaw...")
img_list = []
trash_img_list = []
for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Loading"):
filepath = Path(filepath)
dflimg = DFLIMG.load (filepath)
if dflimg is None or not dflimg.has_data():
io.log_err (f"{filepath.name} is not a dfl image file")
trash_img_list.append ( [str(filepath)] )
continue
pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks(), size=dflimg.get_shape()[1] )
img_list.append( [str(filepath), yaw ] )
io.log_info ("Sorting...")
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
return img_list, trash_img_list
def sort_by_face_pitch(input_path):
io.log_info ("Sorting by face pitch...")
img_list = []
trash_img_list = []
for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Loading"):
filepath = Path(filepath)
dflimg = DFLIMG.load (filepath)
if dflimg is None or not dflimg.has_data():
io.log_err (f"{filepath.name} is not a dfl image file")
trash_img_list.append ( [str(filepath)] )
continue
pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks(), size=dflimg.get_shape()[1] )
img_list.append( [str(filepath), pitch ] )
io.log_info ("Sorting...")
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
return img_list, trash_img_list
def sort_by_face_source_rect_size(input_path):
io.log_info ("Sorting by face rect size...")
img_list = []
trash_img_list = []
for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Loading"):
filepath = Path(filepath)
dflimg = DFLIMG.load (filepath)
if dflimg is None or not dflimg.has_data():
io.log_err (f"{filepath.name} is not a dfl image file")
trash_img_list.append ( [str(filepath)] )
continue
source_rect = dflimg.get_source_rect()
rect_area = mathlib.polygon_area(np.array(source_rect[[0,2,2,0]]).astype(np.float32), np.array(source_rect[[1,1,3,3]]).astype(np.float32))
img_list.append( [str(filepath), rect_area ] )
io.log_info ("Sorting...")
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
return img_list, trash_img_list
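# Splits the image list into at most ~12 chunks; each worker greedily reorders its chunk so that
# consecutive images have the most similar per-channel BGR histograms (smallest summed
# Bhattacharyya distance), and the host simply concatenates the reordered chunks.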
class HistSsimSubprocessor(Subprocessor):
class Cli(Subprocessor.Cli):
#override
def process_data(self, data):
img_list = []
for x in data:
img = cv2_imread(x)
img_list.append ([x, cv2.calcHist([img], [0], None, [256], [0, 256]),
cv2.calcHist([img], [1], None, [256], [0, 256]),
cv2.calcHist([img], [2], None, [256], [0, 256])
])
img_list_len = len(img_list)
for i in range(img_list_len-1):
min_score = float("inf")
j_min_score = i+1
for j in range(i+1,len(img_list)):
score = cv2.compareHist(img_list[i][1], img_list[j][1], cv2.HISTCMP_BHATTACHARYYA) + \
cv2.compareHist(img_list[i][2], img_list[j][2], cv2.HISTCMP_BHATTACHARYYA) + \
cv2.compareHist(img_list[i][3], img_list[j][3], cv2.HISTCMP_BHATTACHARYYA)
if score < min_score:
min_score = score
j_min_score = j
img_list[i+1], img_list[j_min_score] = img_list[j_min_score], img_list[i+1]
self.progress_bar_inc(1)
return img_list
#override
def get_data_name (self, data):
return "Bunch of images"
#override
def __init__(self, img_list ):
self.img_list = img_list
self.img_list_len = len(img_list)
slice_count = 20000
sliced_count = self.img_list_len // slice_count
if sliced_count > 12:
sliced_count = 11.9
slice_count = int(self.img_list_len / sliced_count)
sliced_count = self.img_list_len // slice_count
self.img_chunks_list = [ self.img_list[i*slice_count : (i+1)*slice_count] for i in range(sliced_count) ] + \
[ self.img_list[sliced_count*slice_count:] ]
self.result = []
super().__init__('HistSsim', HistSsimSubprocessor.Cli, 0)
#override
def process_info_generator(self):
cpu_count = len(self.img_chunks_list)
io.log_info(f'Running on {cpu_count} threads')
for i in range(cpu_count):
yield 'CPU%d' % (i), {'i':i}, {}
#override
def on_clients_initialized(self):
io.progress_bar ("Sorting", len(self.img_list))
io.progress_bar_inc(len(self.img_chunks_list))
#override
def on_clients_finalized(self):
io.progress_bar_close()
#override
def get_data(self, host_dict):
if len (self.img_chunks_list) > 0:
return self.img_chunks_list.pop(0)
return None
#override
def on_data_return (self, host_dict, data):
raise Exception("Fail to process data. Decrease number of images and try again.")
#override
def on_result (self, host_dict, data, result):
self.result += result
return 0
#override
def get_result(self):
return self.result
def sort_by_hist(input_path):
io.log_info ("Sorting by histogram similarity...")
img_list = HistSsimSubprocessor(pathex.get_image_paths(input_path)).run()
return img_list, []
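# For each image, sums the Bhattacharyya distances between its grayscale histogram and the
# histograms of all other images; the total is used as a dissimilarity score.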
class HistDissimSubprocessor(Subprocessor):
class Cli(Subprocessor.Cli):
#override
def on_initialize(self, client_dict):
self.img_list = client_dict['img_list']
self.img_list_len = len(self.img_list)
#override
def process_data(self, data):
i = data[0]
score_total = 0
for j in range( 0, self.img_list_len):
if i == j:
continue
score_total += cv2.compareHist(self.img_list[i][1], self.img_list[j][1], cv2.HISTCMP_BHATTACHARYYA)
return score_total
#override
def get_data_name (self, data):
#return a string identifier for the data
return self.img_list[data[0]][0]
#override
def __init__(self, img_list ):
self.img_list = img_list
self.img_list_range = [i for i in range(0, len(img_list) )]
self.result = []
super().__init__('HistDissim', HistDissimSubprocessor.Cli, 60)
#override
def on_clients_initialized(self):
io.progress_bar ("Sorting", len (self.img_list) )
#override
def on_clients_finalized(self):
io.progress_bar_close()
#override
def process_info_generator(self):
cpu_count = min(multiprocessing.cpu_count(), 8)
io.log_info(f'Running on {cpu_count} CPUs')
for i in range(cpu_count):
yield 'CPU%d' % (i), {}, {'img_list' : self.img_list}
#override
def get_data(self, host_dict):
if len (self.img_list_range) > 0:
return [self.img_list_range.pop(0)]
return None
#override
def on_data_return (self, host_dict, data):
self.img_list_range.insert(0, data[0])
#override
def on_result (self, host_dict, data, result):
self.img_list[data[0]][2] = result
io.progress_bar_inc(1)
#override
def get_result(self):
return self.img_list
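# Masks each face with its landmark hull (when DFL metadata is present) before building the
# grayscale histogram, then sorts the most dissimilar faces first.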
def sort_by_hist_dissim(input_path):
io.log_info ("Sorting by histogram dissimilarity...")
img_list = []
trash_img_list = []
for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Loading"):
filepath = Path(filepath)
dflimg = DFLIMG.load (filepath)
image = cv2_imread(str(filepath))
if dflimg is not None and dflimg.has_data():
face_mask = LandmarksProcessor.get_image_hull_mask (image.shape, dflimg.get_landmarks())
image = (image*face_mask).astype(np.uint8)
img_list.append ([str(filepath), cv2.calcHist([cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)], [0], None, [256], [0, 256]), 0 ])
img_list = HistDissimSubprocessor(img_list).run()
io.log_info ("Sorting...")
img_list = sorted(img_list, key=operator.itemgetter(2), reverse=True)
return img_list, trash_img_list
def sort_by_brightness(input_path):
io.log_info ("Sorting by brightness...")
img_list = [ [x, np.mean ( cv2.cvtColor(cv2_imread(x), cv2.COLOR_BGR2HSV)[...,2].flatten() )] for x in io.progress_bar_generator( pathex.get_image_paths(input_path), "Loading") ]
io.log_info ("Sorting...")
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
return img_list, []
def sort_by_hue(input_path):
io.log_info ("Sorting by hue...")
img_list = [ [x, np.mean ( cv2.cvtColor(cv2_imread(x), cv2.COLOR_BGR2HSV)[...,0].flatten() )] for x in io.progress_bar_generator( pathex.get_image_paths(input_path), "Loading") ]
io.log_info ("Sorting...")
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
return img_list, []
def sort_by_black(input_path):
io.log_info ("Sorting by amount of black pixels...")
img_list = []
for x in io.progress_bar_generator( pathex.get_image_paths(input_path), "Loading"):
img = cv2_imread(x)
img_list.append ([x, img[(img == 0)].size ])
io.log_info ("Sorting...")
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=False)
return img_list, []
def sort_by_origname(input_path):
io.log_info ("Sort by original filename...")
img_list = []
trash_img_list = []
for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Loading"):
filepath = Path(filepath)
dflimg = DFLIMG.load (filepath)
if dflimg is None or not dflimg.has_data():
io.log_err (f"{filepath.name} is not a dfl image file")
trash_img_list.append( [str(filepath)] )
continue
img_list.append( [str(filepath), dflimg.get_source_filename()] )
io.log_info ("Sorting...")
img_list = sorted(img_list, key=operator.itemgetter(1))
return img_list, trash_img_list
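# Expects aligned faces named '<frame>_<face_index>'; any frame that produced more than one face
# has all of its faces moved to trash, keeping only frames with a single detected face.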
def sort_by_oneface_in_image(input_path):
io.log_info ("Sort by one face in images...")
image_paths = pathex.get_image_paths(input_path)
a = np.array ([ ( int(x[0]), int(x[1]) ) \
for x in [ Path(filepath).stem.split('_') for filepath in image_paths ] if len(x) == 2
])
if len(a) > 0:
idxs = np.ndarray.flatten ( np.argwhere ( a[:,1] != 0 ) )
idxs = np.unique ( a[idxs][:,0] )
idxs = np.ndarray.flatten ( np.argwhere ( np.array([ x[0] in idxs for x in a ]) == True ) )
if len(idxs) > 0:
io.log_info ("Found %d images." % (len(idxs)) )
img_list = [ (path,) for i,path in enumerate(image_paths) if i not in idxs ]
trash_img_list = [ (image_paths[x],) for x in idxs ]
return img_list, trash_img_list
io.log_info ("Nothing found. Possible recover original filenames first.")
return [], []
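# Loads each face and collects the features used by sort_best(): sharpness (or the source-rect
# area in 'faster' mode), a grayscale histogram, and the estimated yaw and pitch.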
class FinalLoaderSubprocessor(Subprocessor):
class Cli(Subprocessor.Cli):
#override
def on_initialize(self, client_dict):
self.faster = client_dict['faster']
#override
def process_data(self, data):
filepath = Path(data[0])
try:
dflimg = DFLIMG.load (filepath)
if dflimg is None or not dflimg.has_data():
self.log_err (f"{filepath.name} is not a dfl image file")
return [ 1, [str(filepath)] ]
bgr = cv2_imread(str(filepath))
if bgr is None:
raise Exception ("Unable to load %s" % (filepath.name) )
gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
if self.faster:
source_rect = dflimg.get_source_rect()
sharpness = mathlib.polygon_area(np.array(source_rect[[0,2,2,0]]).astype(np.float32), np.array(source_rect[[1,1,3,3]]).astype(np.float32))
else:
sharpness = estimate_sharpness(gray)
pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks(), size=dflimg.get_shape()[1] )
hist = cv2.calcHist([gray], [0], None, [256], [0, 256])
except Exception as e:
self.log_err (e)
return [ 1, [str(filepath)] ]
return [ 0, [str(filepath), sharpness, hist, yaw, pitch ] ]
#override
def get_data_name (self, data):
#return a string identifier for the data
return data[0]
#override
def __init__(self, img_list, faster ):
self.img_list = img_list
self.faster = faster
self.result = []
self.result_trash = []
super().__init__('FinalLoader', FinalLoaderSubprocessor.Cli, 60)
#override
def on_clients_initialized(self):
io.progress_bar ("Loading", len (self.img_list))
#override
def on_clients_finalized(self):
io.progress_bar_close()
#override
def process_info_generator(self):
cpu_count = min(multiprocessing.cpu_count(), 8)
io.log_info(f'Running on {cpu_count} CPUs')
for i in range(cpu_count):
yield 'CPU%d' % (i), {}, {'faster': self.faster}
#override
def get_data(self, host_dict):
if len (self.img_list) > 0:
return [self.img_list.pop(0)]
return None
#override
def on_data_return (self, host_dict, data):
self.img_list.insert(0, data[0])
#override
def on_result (self, host_dict, data, result):
if result[0] == 0:
self.result.append (result[1])
else:
self.result_trash.append (result[1])
io.progress_bar_inc(1)
#override
def get_result(self):
return self.result, self.result_trash
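# For one yaw bin at a time, scores every face by how dissimilar its grayscale histogram is to the
# other faces in the same pitch bin, then sorts each pitch bin with the most dissimilar faces first.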
class FinalHistDissimSubprocessor(Subprocessor):
class Cli(Subprocessor.Cli):
#override
def process_data(self, data):
idx, pitch_yaw_img_list = data
for p in range ( len(pitch_yaw_img_list) ):
img_list = pitch_yaw_img_list[p]
if img_list is not None:
for i in range( len(img_list) ):
score_total = 0
for j in range( len(img_list) ):
if i == j:
continue
score_total += cv2.compareHist(img_list[i][2], img_list[j][2], cv2.HISTCMP_BHATTACHARYYA)
img_list[i][3] = score_total
pitch_yaw_img_list[p] = sorted(img_list, key=operator.itemgetter(3), reverse=True)
return idx, pitch_yaw_img_list
#override
def get_data_name (self, data):
return "Bunch of images"
#override
def __init__(self, pitch_yaw_sample_list ):
self.pitch_yaw_sample_list = pitch_yaw_sample_list
self.pitch_yaw_sample_list_len = len(pitch_yaw_sample_list)
self.pitch_yaw_sample_list_idxs = [ i for i in range(self.pitch_yaw_sample_list_len) if self.pitch_yaw_sample_list[i] is not None ]
self.result = [ None for _ in range(self.pitch_yaw_sample_list_len) ]
super().__init__('FinalHistDissimSubprocessor', FinalHistDissimSubprocessor.Cli)
#override
def process_info_generator(self):
cpu_count = min(multiprocessing.cpu_count(), 8)
io.log_info(f'Running on {cpu_count} CPUs')
for i in range(cpu_count):
yield 'CPU%d' % (i), {}, {}
#override
def on_clients_initialized(self):
io.progress_bar ("Sort by hist-dissim", len(self.pitch_yaw_sample_list_idxs) )
#override
def on_clients_finalized(self):
io.progress_bar_close()
#override
def get_data(self, host_dict):
if len (self.pitch_yaw_sample_list_idxs) > 0:
idx = self.pitch_yaw_sample_list_idxs.pop(0)
return idx, self.pitch_yaw_sample_list[idx]
return None
#override
def on_data_return (self, host_dict, data):
self.pitch_yaw_sample_list_idxs.insert(0, data[0])
#override
def on_result (self, host_dict, data, result):
idx, yaws_sample_list = data
self.result[idx] = yaws_sample_list
io.progress_bar_inc(1)
#override
def get_result(self):
return self.result
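# sort_best() spreads the target face count over 128 yaw bins, keeps the sharpest candidates in each
# bin, splits each bin by pitch, ranks candidates by histogram dissimilarity, and then round-robins
# picks across the pitch bins until each yaw bin's quota is filled; everything left over is trashed.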
def sort_best_faster(input_path):
return sort_best(input_path, faster=True)
def sort_best(input_path, faster=False):
target_count = io.input_int ("Target number of faces?", 2000)
io.log_info ("Performing sort by best faces.")
if faster:
io.log_info("Using faster algorithm. Faces will be sorted by source-rect-area instead of blur.")
img_list, trash_img_list = FinalLoaderSubprocessor( pathex.get_image_paths(input_path), faster ).run()
final_img_list = []
grads = 128
imgs_per_grad = round (target_count / grads)
#instead of math.pi / 2, use [-1.2, +1.2], because the actual maximum yaw for 2DFAN landmarks is about ±1.2
grads_space = np.linspace (-1.2, 1.2,grads)
yaws_sample_list = [None]*grads
for g in io.progress_bar_generator ( range(grads), "Sort by yaw"):
yaw = grads_space[g]
next_yaw = grads_space[g+1] if g < grads-1 else yaw
yaw_samples = []
for img in img_list:
s_yaw = -img[3]
if (g == 0 and s_yaw < next_yaw) or \
(g < grads-1 and s_yaw >= yaw and s_yaw < next_yaw) or \
(g == grads-1 and s_yaw >= yaw):
yaw_samples += [ img ]
if len(yaw_samples) > 0:
yaws_sample_list[g] = yaw_samples
total_lack = 0
for g in io.progress_bar_generator ( range(grads), ""):
img_list = yaws_sample_list[g]
img_list_len = len(img_list) if img_list is not None else 0
lack = imgs_per_grad - img_list_len
total_lack += max(lack, 0)
imgs_per_grad += total_lack // grads
sharpened_imgs_per_grad = imgs_per_grad*10
for g in io.progress_bar_generator ( range (grads), "Sort by blur"):
img_list = yaws_sample_list[g]
if img_list is None:
continue
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
if len(img_list) > sharpened_imgs_per_grad:
trash_img_list += img_list[sharpened_imgs_per_grad:]
img_list = img_list[0:sharpened_imgs_per_grad]
yaws_sample_list[g] = img_list
yaw_pitch_sample_list = [None]*grads
pitch_grads = imgs_per_grad
for g in io.progress_bar_generator ( range (grads), "Sort by pitch"):
img_list = yaws_sample_list[g]
if img_list is None:
continue
pitch_sample_list = [None]*pitch_grads
grads_space = np.linspace (-math.pi / 2,math.pi / 2, pitch_grads )
for pg in range (pitch_grads):
pitch = grads_space[pg]
next_pitch = grads_space[pg+1] if pg < pitch_grads-1 else pitch
pitch_samples = []
for img in img_list:
s_pitch = img[4]
if (pg == 0 and s_pitch < next_pitch) or \
(pg < pitch_grads-1 and s_pitch >= pitch and s_pitch < next_pitch) or \
(pg == pitch_grads-1 and s_pitch >= pitch):
pitch_samples += [ img ]
if len(pitch_samples) > 0:
pitch_sample_list[pg] = pitch_samples
yaw_pitch_sample_list[g] = pitch_sample_list
yaw_pitch_sample_list = FinalHistDissimSubprocessor(yaw_pitch_sample_list).run()
for g in io.progress_bar_generator (range (grads), "Fetching the best"):
pitch_sample_list = yaw_pitch_sample_list[g]
if pitch_sample_list is None:
continue
n = imgs_per_grad
while n > 0:
n_prev = n
for pg in range(pitch_grads):
img_list = pitch_sample_list[pg]
if img_list is None:
continue
final_img_list += [ img_list.pop(0) ]
if len(img_list) == 0:
pitch_sample_list[pg] = None
n -= 1
if n == 0:
break
if n_prev == n:
break
for pg in range(pitch_grads):
img_list = pitch_sample_list[pg]
if img_list is None:
continue
trash_img_list += img_list
return final_img_list, trash_img_list
"""
def sort_by_vggface(input_path):
io.log_info ("Sorting by face similarity using VGGFace model...")
model = VGGFace()
final_img_list = []
trash_img_list = []
image_paths = pathex.get_image_paths(input_path)
img_list = [ (x,) for x in image_paths ]
img_list_len = len(img_list)
img_list_range = [*range(img_list_len)]
feats = [None]*img_list_len
for i in io.progress_bar_generator(img_list_range, "Loading"):
img = cv2_imread( img_list[i][0] ).astype(np.float32)
img = imagelib.normalize_channels (img, 3)
img = cv2.resize (img, (224,224) )
img = img[..., ::-1]
img[..., 0] -= 93.5940
img[..., 1] -= 104.7624
img[..., 2] -= 129.1863
feats[i] = model.predict( img[None,...] )[0]
tmp = np.zeros( (img_list_len,) )
float_inf = float("inf")
for i in io.progress_bar_generator ( range(img_list_len-1), "Sorting" ):
i_feat = feats[i]
for j in img_list_range:
tmp[j] = npla.norm(i_feat-feats[j]) if j >= i+1 else float_inf
idx = np.argmin(tmp)
img_list[i+1], img_list[idx] = img_list[idx], img_list[i+1]
feats[i+1], feats[idx] = feats[idx], feats[i+1]
return img_list, trash_img_list
"""
def sort_by_absdiff(input_path):
io.log_info ("Sorting by absolute difference...")
is_sim = io.input_bool ("Sort by similar?", True, help_message="Otherwise sort by dissimilar.")
from core.leras import nn
device_config = nn.DeviceConfig.ask_choose_device(choose_only_one=True)
nn.initialize( device_config=device_config, data_format="NHWC" )
tf = nn.tf
image_paths = pathex.get_image_paths(input_path)
image_paths_len = len(image_paths)
batch_size = 512
batch_size_remain = image_paths_len % batch_size
i_t = tf.placeholder (tf.float32, (None,None,None,None) )
j_t = tf.placeholder (tf.float32, (None,None,None,None) )
outputs_full = []
outputs_remain = []
for i in range(batch_size):
diff_t = tf.reduce_sum( tf.abs(i_t-j_t[i]), axis=[1,2,3] )
outputs_full.append(diff_t)
if i < batch_size_remain:
outputs_remain.append(diff_t)
def func_bs_full(i,j):
return nn.tf_sess.run (outputs_full, feed_dict={i_t:i,j_t:j})
def func_bs_remain(i,j):
return nn.tf_sess.run (outputs_remain, feed_dict={i_t:i,j_t:j})
import h5py
db_file_path = Path(tempfile.gettempdir()) / 'sort_cache.hdf5'
db_file = h5py.File( str(db_file_path), "w")
db = db_file.create_dataset("results", (image_paths_len,image_paths_len), compression="gzip")
pg_len = image_paths_len // batch_size
if batch_size_remain != 0:
pg_len += 1
pg_len = int( ( pg_len*pg_len - pg_len ) / 2 + pg_len )
io.progress_bar ("Computing", pg_len)
j=0
while j < image_paths_len:
j_images = [ cv2_imread(x) for x in image_paths[j:j+batch_size] ]
j_images_len = len(j_images)
func = func_bs_remain if image_paths_len-j < batch_size else func_bs_full
i=0
while i < image_paths_len:
if i >= j:
i_images = [ cv2_imread(x) for x in image_paths[i:i+batch_size] ]
i_images_len = len(i_images)
result = func (i_images,j_images)
db[j:j+j_images_len,i:i+i_images_len] = np.array(result)
io.progress_bar_inc(1)
i += batch_size
db_file.flush()
j += batch_size
io.progress_bar_close()
next_id = 0
sorted = [next_id]
for i in io.progress_bar_generator ( range(image_paths_len-1), "Sorting" ):
id_ar = np.concatenate ( [ db[:next_id,next_id], db[next_id,next_id:] ] )
id_ar = np.argsort(id_ar)
next_id = np.setdiff1d(id_ar, sorted, True)[ 0 if is_sim else -1]
sorted += [next_id]
db_file.close()
db_file_path.unlink()
img_list = [ (image_paths[x],) for x in sorted]
return img_list, []
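# Moves trashed items into a sibling '<input>_trash' folder, then renames the kept images in two
# passes (index + original name, then index + extension) so the final order matches the sort result.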
def final_process(input_path, img_list, trash_img_list):
if len(trash_img_list) != 0:
parent_input_path = input_path.parent
trash_path = parent_input_path / (input_path.stem + '_trash')
trash_path.mkdir (exist_ok=True)
io.log_info ("Trashing %d items to %s" % ( len(trash_img_list), str(trash_path) ) )
for filename in pathex.get_image_paths(trash_path):
Path(filename).unlink()
for i in io.progress_bar_generator( range(len(trash_img_list)), "Moving trash", leave=False):
src = Path (trash_img_list[i][0])
dst = trash_path / src.name
try:
src.rename (dst)
except:
io.log_info ('failed to trash %s' % (src.name) )
io.log_info ("")
if len(img_list) != 0:
for i in io.progress_bar_generator( [*range(len(img_list))], "Renaming", leave=False):
src = Path (img_list[i][0])
dst = input_path / ('%.5d_%s' % (i, src.name ))
try:
src.rename (dst)
except:
io.log_info ('failed to rename %s' % (src.name) )
for i in io.progress_bar_generator( [*range(len(img_list))], "Renaming"):
src = Path (img_list[i][0])
src = input_path / ('%.5d_%s' % (i, src.name))
dst = input_path / ('%.5d%s' % (i, src.suffix))
try:
src.rename (dst)
except:
io.log_info ('failed to rename %s' % (src.name) )
sort_func_methods = {
'blur': ("blur", sort_by_blur),
'face-yaw': ("face yaw direction", sort_by_face_yaw),
'face-pitch': ("face pitch direction", sort_by_face_pitch),
'face-source-rect-size' : ("face rect size in source image", sort_by_face_source_rect_size),
'hist': ("histogram similarity", sort_by_hist),
'hist-dissim': ("histogram dissimilarity", sort_by_hist_dissim),
'brightness': ("brightness", sort_by_brightness),
'hue': ("hue", sort_by_hue),
'black': ("amount of black pixels", sort_by_black),
'origname': ("original filename", sort_by_origname),
'oneface': ("one face in image", sort_by_oneface_in_image),
'absdiff': ("absolute pixel difference", sort_by_absdiff),
'final': ("best faces", sort_best),
'final-fast': ("best faces faster", sort_best_faster),
}
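# Prompts for a sorting method when one is not supplied, runs the chosen sort function, and applies
# final_process() to rename the results and move rejected images to the trash folder.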
def main (input_path, sort_by_method=None):
io.log_info ("Running sort tool.\r\n")
if sort_by_method is None:
io.log_info(f"Choose sorting method:")
key_list = list(sort_func_methods.keys())
for i, key in enumerate(key_list):
desc, func = sort_func_methods[key]
io.log_info(f"[{i}] {desc}")
io.log_info("")
id = io.input_int("", 4, valid_list=[*range(len(key_list))] )
sort_by_method = key_list[id]
else:
sort_by_method = sort_by_method.lower()
desc, func = sort_func_methods[sort_by_method]
img_list, trash_img_list = func(input_path)
final_process (input_path, img_list, trash_img_list)