mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-19 21:13:20 -07:00
initial
This commit is contained in:
parent
73de93b4f1
commit
6bd5a44264
71 changed files with 8448 additions and 0 deletions
283
mainscripts/Converter.py
Normal file
283
mainscripts/Converter.py
Normal file
|
@ -0,0 +1,283 @@
|
|||
import traceback
|
||||
from pathlib import Path
|
||||
from utils import Path_utils
|
||||
import cv2
|
||||
from tqdm import tqdm
|
||||
from utils.AlignedPNG import AlignedPNG
|
||||
from utils import image_utils
|
||||
import shutil
|
||||
import numpy as np
|
||||
import time
|
||||
import multiprocessing
|
||||
from models import ConverterBase
|
||||
|
||||
class model_process_predictor(object):
|
||||
def __init__(self, sq, cq, lock):
|
||||
self.sq = sq
|
||||
self.cq = cq
|
||||
self.lock = lock
|
||||
|
||||
def __call__(self, face):
|
||||
self.lock.acquire()
|
||||
|
||||
self.sq.put ( {'op': 'predict', 'face' : face} )
|
||||
while True:
|
||||
if not self.cq.empty():
|
||||
obj = self.cq.get()
|
||||
obj_op = obj['op']
|
||||
if obj_op == 'predict_result':
|
||||
self.lock.release()
|
||||
return obj['result']
|
||||
time.sleep(0.005)
|
||||
|
||||
def model_process(model_name, model_dir, in_options, sq, cq):
|
||||
try:
|
||||
model_path = Path(model_dir)
|
||||
|
||||
import models
|
||||
model = models.import_model(model_name)(model_path, **in_options)
|
||||
converter = model.get_converter(**in_options)
|
||||
converter.dummy_predict()
|
||||
|
||||
cq.put ( {'op':'init', 'converter' : converter.copy_and_set_predictor( None ) } )
|
||||
|
||||
closing = False
|
||||
while not closing:
|
||||
while not sq.empty():
|
||||
obj = sq.get()
|
||||
obj_op = obj['op']
|
||||
if obj_op == 'predict':
|
||||
result = converter.predictor ( obj['face'] )
|
||||
cq.put ( {'op':'predict_result', 'result':result} )
|
||||
elif obj_op == 'close':
|
||||
closing = True
|
||||
break
|
||||
time.sleep(0.005)
|
||||
|
||||
model.finalize()
|
||||
|
||||
except Exception as e:
|
||||
print ( 'Error: %s' % (str(e)))
|
||||
traceback.print_exc()
|
||||
|
||||
from utils.SubprocessorBase import SubprocessorBase
|
||||
class ConvertSubprocessor(SubprocessorBase):
|
||||
|
||||
#override
|
||||
def __init__(self, converter, input_path_image_paths, output_path, alignments, debug):
|
||||
super().__init__('Converter')
|
||||
self.converter = converter
|
||||
self.input_path_image_paths = input_path_image_paths
|
||||
self.output_path = output_path
|
||||
self.alignments = alignments
|
||||
self.debug = debug
|
||||
|
||||
self.input_data = self.input_path_image_paths
|
||||
self.files_processed = 0
|
||||
self.faces_processed = 0
|
||||
|
||||
#override
|
||||
def process_info_generator(self):
|
||||
r = [0] if self.debug else range(multiprocessing.cpu_count())
|
||||
for i in r:
|
||||
yield 'CPU%d' % (i), {}, {'device_idx': i,
|
||||
'device_name': 'CPU%d' % (i),
|
||||
'converter' : self.converter,
|
||||
'output_dir' : str(self.output_path),
|
||||
'alignments' : self.alignments,
|
||||
'debug': self.debug }
|
||||
|
||||
#override
|
||||
def get_no_process_started_message(self):
|
||||
return 'Unable to start CPU processes.'
|
||||
|
||||
#override
|
||||
def onHostGetProgressBarDesc(self):
|
||||
return "Converting"
|
||||
|
||||
#override
|
||||
def onHostGetProgressBarLen(self):
|
||||
return len (self.input_data)
|
||||
|
||||
#override
|
||||
def onHostGetData(self):
|
||||
if len (self.input_data) > 0:
|
||||
return self.input_data.pop(0)
|
||||
return None
|
||||
|
||||
#override
|
||||
def onHostDataReturn (self, data):
|
||||
self.input_data.insert(0, data)
|
||||
|
||||
#override
|
||||
def onClientInitialize(self, client_dict):
|
||||
print ('Running on %s.' % (client_dict['device_name']) )
|
||||
self.device_idx = client_dict['device_idx']
|
||||
self.device_name = client_dict['device_name']
|
||||
self.converter = client_dict['converter']
|
||||
self.output_path = Path(client_dict['output_dir']) if 'output_dir' in client_dict.keys() else None
|
||||
self.alignments = client_dict['alignments']
|
||||
self.debug = client_dict['debug']
|
||||
return None
|
||||
|
||||
#override
|
||||
def onClientFinalize(self):
|
||||
pass
|
||||
|
||||
#override
|
||||
def onClientProcessData(self, data):
|
||||
filename_path = Path(data)
|
||||
|
||||
files_processed = 1
|
||||
faces_processed = 0
|
||||
|
||||
output_filename_path = self.output_path / filename_path.name
|
||||
if self.converter.get_mode() == ConverterBase.MODE_FACE and filename_path.stem not in self.alignments.keys():
|
||||
if not self.debug:
|
||||
print ( 'no faces found for %s, copying without faces' % (filename_path.name) )
|
||||
shutil.copy ( str(filename_path), str(output_filename_path) )
|
||||
else:
|
||||
image = (cv2.imread(str(filename_path)) / 255.0).astype(np.float32)
|
||||
|
||||
if self.converter.get_mode() == ConverterBase.MODE_IMAGE:
|
||||
image_landmarks = None
|
||||
a_png = AlignedPNG.load( str(filename_path) )
|
||||
if a_png is not None:
|
||||
d = a_png.getFaceswapDictData()
|
||||
if d is not None and 'landmarks' in d.keys():
|
||||
image_landmarks = np.array(d['landmarks'])
|
||||
|
||||
image = self.converter.convert_image(image, image_landmarks, self.debug)
|
||||
if self.debug:
|
||||
for img in image:
|
||||
cv2.imshow ('Debug convert', img )
|
||||
cv2.waitKey(0)
|
||||
faces_processed = 1
|
||||
elif self.converter.get_mode() == ConverterBase.MODE_FACE:
|
||||
faces = self.alignments[filename_path.stem]
|
||||
for image_landmarks in faces:
|
||||
image = self.converter.convert_face(image, image_landmarks, self.debug)
|
||||
if self.debug:
|
||||
for img in image:
|
||||
cv2.imshow ('Debug convert', img )
|
||||
cv2.waitKey(0)
|
||||
faces_processed = len(faces)
|
||||
|
||||
if not self.debug:
|
||||
cv2.imwrite (str(output_filename_path), (image*255).astype(np.uint8) )
|
||||
|
||||
|
||||
return (files_processed, faces_processed)
|
||||
|
||||
#override
|
||||
def onHostResult (self, data, result):
|
||||
self.files_processed += result[0]
|
||||
self.faces_processed += result[1]
|
||||
return 1
|
||||
|
||||
#override
|
||||
def get_start_return(self):
|
||||
return self.files_processed, self.faces_processed
|
||||
|
||||
def main (input_dir, output_dir, aligned_dir, model_dir, model_name, **in_options):
|
||||
print ("Running converter.\r\n")
|
||||
|
||||
debug = in_options['debug']
|
||||
|
||||
try:
|
||||
input_path = Path(input_dir)
|
||||
output_path = Path(output_dir)
|
||||
aligned_path = Path(aligned_dir)
|
||||
model_path = Path(model_dir)
|
||||
|
||||
if not input_path.exists():
|
||||
print('Input directory not found. Please ensure it exists.')
|
||||
return
|
||||
|
||||
if output_path.exists():
|
||||
for filename in Path_utils.get_image_paths(output_path):
|
||||
Path(filename).unlink()
|
||||
else:
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not aligned_path.exists():
|
||||
print('Aligned directory not found. Please ensure it exists.')
|
||||
return
|
||||
|
||||
if not model_path.exists():
|
||||
print('Model directory not found. Please ensure it exists.')
|
||||
return
|
||||
|
||||
model_sq = multiprocessing.Queue()
|
||||
model_cq = multiprocessing.Queue()
|
||||
model_lock = multiprocessing.Lock()
|
||||
|
||||
model_p = multiprocessing.Process(target=model_process, args=(model_name, model_dir, in_options, model_sq, model_cq))
|
||||
model_p.start()
|
||||
|
||||
while True:
|
||||
if not model_cq.empty():
|
||||
obj = model_cq.get()
|
||||
obj_op = obj['op']
|
||||
if obj_op == 'init':
|
||||
converter = obj['converter']
|
||||
break
|
||||
|
||||
alignments = {}
|
||||
if converter.get_mode() == ConverterBase.MODE_FACE:
|
||||
aligned_path_image_paths = Path_utils.get_image_paths(aligned_path)
|
||||
for filename in tqdm(aligned_path_image_paths, desc= "Collecting alignments" ):
|
||||
a_png = AlignedPNG.load( str(filename) )
|
||||
if a_png is None:
|
||||
print ( "%s - no embedded data found." % (filename) )
|
||||
continue
|
||||
d = a_png.getFaceswapDictData()
|
||||
if d is None or d['source_filename'] is None or d['source_rect'] is None or d['source_landmarks'] is None:
|
||||
print ( "%s - no embedded data found." % (filename) )
|
||||
continue
|
||||
|
||||
source_filename_stem = Path(d['source_filename']).stem
|
||||
if source_filename_stem not in alignments.keys():
|
||||
alignments[ source_filename_stem ] = []
|
||||
|
||||
alignments[ source_filename_stem ].append ( np.array(d['source_landmarks']) )
|
||||
|
||||
files_processed, faces_processed = ConvertSubprocessor (
|
||||
converter = converter.copy_and_set_predictor( model_process_predictor(model_sq,model_cq,model_lock) ),
|
||||
input_path_image_paths = Path_utils.get_image_paths(input_path),
|
||||
output_path = output_path,
|
||||
alignments = alignments,
|
||||
debug = debug ).process()
|
||||
|
||||
model_sq.put ( {'op':'close'} )
|
||||
model_p.join()
|
||||
|
||||
'''
|
||||
if model_name == 'AVATAR':
|
||||
output_path_image_paths = Path_utils.get_image_paths(output_path)
|
||||
|
||||
last_ok_frame = -1
|
||||
for filename in output_path_image_paths:
|
||||
filename_path = Path(filename)
|
||||
stem = Path(filename).stem
|
||||
try:
|
||||
frame = int(stem)
|
||||
except:
|
||||
raise Exception ('Aligned avatars must be created from indexed sequence files.')
|
||||
|
||||
if frame-last_ok_frame > 1:
|
||||
start = last_ok_frame + 1
|
||||
end = frame - 1
|
||||
|
||||
print ("Filling gaps: [%d...%d]" % (start, end) )
|
||||
for i in range (start, end+1):
|
||||
shutil.copy ( str(filename), str( output_path / ('%.5d%s' % (i, filename_path.suffix )) ) )
|
||||
|
||||
last_ok_frame = frame
|
||||
'''
|
||||
|
||||
except Exception as e:
|
||||
print ( 'Error: %s' % (str(e)))
|
||||
traceback.print_exc()
|
||||
|
||||
|
378
mainscripts/Extractor.py
Normal file
378
mainscripts/Extractor.py
Normal file
|
@ -0,0 +1,378 @@
|
|||
import traceback
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import multiprocessing
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import cv2
|
||||
from utils import Path_utils
|
||||
from utils.AlignedPNG import AlignedPNG
|
||||
from utils import image_utils
|
||||
from facelib import FaceType
|
||||
import facelib
|
||||
import gpufmkmgr
|
||||
|
||||
from utils.SubprocessorBase import SubprocessorBase
|
||||
class ExtractSubprocessor(SubprocessorBase):
|
||||
|
||||
#override
|
||||
def __init__(self, input_data, type, image_size, face_type, debug, multi_gpu=False, manual=False, manual_window_size=0, detector=None, output_path=None ):
|
||||
self.input_data = input_data
|
||||
self.type = type
|
||||
self.image_size = image_size
|
||||
self.face_type = face_type
|
||||
self.debug = debug
|
||||
self.multi_gpu = multi_gpu
|
||||
self.detector = detector
|
||||
self.output_path = output_path
|
||||
self.manual = manual
|
||||
self.manual_window_size = manual_window_size
|
||||
self.result = []
|
||||
|
||||
no_response_time_sec = 60 if not self.manual else 999999
|
||||
super().__init__('Extractor', no_response_time_sec)
|
||||
|
||||
#override
|
||||
def onHostClientsInitialized(self):
|
||||
if self.manual == True:
|
||||
self.wnd_name = 'Manual pass'
|
||||
cv2.namedWindow(self.wnd_name)
|
||||
|
||||
self.landmarks = None
|
||||
self.param_x = -1
|
||||
self.param_y = -1
|
||||
self.param_rect_size = -1
|
||||
self.param = {'x': 0, 'y': 0, 'rect_size' : 5}
|
||||
|
||||
def onMouse(event, x, y, flags, param):
|
||||
if event == cv2.EVENT_MOUSEWHEEL:
|
||||
mod = 1 if flags > 0 else -1
|
||||
param['rect_size'] = max (5, param['rect_size'] + 10*mod)
|
||||
else:
|
||||
param['x'] = x
|
||||
param['y'] = y
|
||||
|
||||
cv2.setMouseCallback(self.wnd_name, onMouse, self.param)
|
||||
|
||||
def get_devices_for_type (self, type, multi_gpu):
|
||||
if (type == 'rects' or type == 'landmarks'):
|
||||
if not multi_gpu:
|
||||
devices = [gpufmkmgr.getBestDeviceIdx()]
|
||||
else:
|
||||
devices = gpufmkmgr.getDevicesWithAtLeastTotalMemoryGB(2)
|
||||
devices = [ (idx, gpufmkmgr.getDeviceName(idx), gpufmkmgr.getDeviceVRAMTotalGb(idx) ) for idx in devices]
|
||||
|
||||
elif type == 'final':
|
||||
devices = [ (i, 'CPU%d' % (i), 0 ) for i in range(0, multiprocessing.cpu_count()) ]
|
||||
|
||||
return devices
|
||||
|
||||
#override
|
||||
def process_info_generator(self):
|
||||
for (device_idx, device_name, device_total_vram_gb) in self.get_devices_for_type(self.type, self.multi_gpu):
|
||||
num_processes = 1
|
||||
if not self.manual and self.type == 'rects' and self.detector == 'mt':
|
||||
num_processes = int ( max (1, device_total_vram_gb / 2) )
|
||||
|
||||
for i in range(0, num_processes ):
|
||||
device_name_for_process = device_name if num_processes == 1 else '%s #%d' % (device_name,i)
|
||||
yield device_name_for_process, {}, {'type' : self.type,
|
||||
'device_idx' : device_idx,
|
||||
'device_name' : device_name_for_process,
|
||||
'image_size': self.image_size,
|
||||
'face_type': self.face_type,
|
||||
'debug': self.debug,
|
||||
'output_dir': str(self.output_path),
|
||||
'detector': self.detector}
|
||||
|
||||
#override
|
||||
def get_no_process_started_message(self):
|
||||
if (self.type == 'rects' or self.type == 'landmarks'):
|
||||
print ( 'You have no capable GPUs. Try to close programs which can consume VRAM, and run again.')
|
||||
elif self.type == 'final':
|
||||
print ( 'Unable to start CPU processes.')
|
||||
|
||||
#override
|
||||
def onHostGetProgressBarDesc(self):
|
||||
return None
|
||||
|
||||
#override
|
||||
def onHostGetProgressBarLen(self):
|
||||
return len (self.input_data)
|
||||
|
||||
#override
|
||||
def onHostGetData(self):
|
||||
if not self.manual:
|
||||
if len (self.input_data) > 0:
|
||||
return self.input_data.pop(0)
|
||||
else:
|
||||
while len (self.input_data) > 0:
|
||||
data = self.input_data[0]
|
||||
filename, faces = data
|
||||
is_frame_done = False
|
||||
if len(faces) == 0:
|
||||
self.original_image = cv2.imread(filename)
|
||||
|
||||
(h,w,c) = self.original_image.shape
|
||||
self.view_scale = 1.0 if self.manual_window_size == 0 else self.manual_window_size / (w if w > h else h)
|
||||
self.original_image = cv2.resize (self.original_image, ( int(w*self.view_scale), int(h*self.view_scale) ), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
self.text_lines_img = (image_utils.get_draw_text_lines ( self.original_image, (0,0, self.original_image.shape[1], min(100, self.original_image.shape[0]) ),
|
||||
[ 'Match landmarks with face exactly.',
|
||||
'[Enter] - confirm frame',
|
||||
'[Space] - skip frame',
|
||||
'[Mouse wheel] - change rect'
|
||||
], (1, 1, 1) )*255).astype(np.uint8)
|
||||
|
||||
while True:
|
||||
key = cv2.waitKey(1) & 0xFF
|
||||
|
||||
if key == ord('\r') or key == ord('\n'):
|
||||
faces.append ( [(self.rect), self.landmarks] )
|
||||
is_frame_done = True
|
||||
break
|
||||
elif key == ord(' '):
|
||||
is_frame_done = True
|
||||
break
|
||||
|
||||
if self.param_x != self.param['x'] / self.view_scale or \
|
||||
self.param_y != self.param['y'] / self.view_scale or \
|
||||
self.param_rect_size != self.param['rect_size']:
|
||||
|
||||
self.param_x = self.param['x'] / self.view_scale
|
||||
self.param_y = self.param['y'] / self.view_scale
|
||||
self.param_rect_size = self.param['rect_size']
|
||||
|
||||
self.rect = (self.param_x-self.param_rect_size, self.param_y-self.param_rect_size, self.param_x+self.param_rect_size, self.param_y+self.param_rect_size)
|
||||
return [filename, [self.rect]]
|
||||
|
||||
else:
|
||||
is_frame_done = True
|
||||
|
||||
if is_frame_done:
|
||||
self.result.append ( data )
|
||||
self.input_data.pop(0)
|
||||
self.inc_progress_bar(1)
|
||||
|
||||
return None
|
||||
|
||||
#override
|
||||
def onHostDataReturn (self, data):
|
||||
if not self.manual:
|
||||
self.input_data.insert(0, data)
|
||||
|
||||
#override
|
||||
def onClientInitialize(self, client_dict):
|
||||
self.safe_print ('Running on %s.' % (client_dict['device_name']) )
|
||||
self.type = client_dict['type']
|
||||
self.image_size = client_dict['image_size']
|
||||
self.face_type = client_dict['face_type']
|
||||
self.device_idx = client_dict['device_idx']
|
||||
self.output_path = Path(client_dict['output_dir']) if 'output_dir' in client_dict.keys() else None
|
||||
self.debug = client_dict['debug']
|
||||
self.detector = client_dict['detector']
|
||||
|
||||
self.keras = None
|
||||
self.tf = None
|
||||
self.tf_session = None
|
||||
|
||||
self.e = None
|
||||
if self.type == 'rects':
|
||||
if self.detector is not None:
|
||||
if self.detector == 'mt':
|
||||
self.tf = gpufmkmgr.import_tf ([self.device_idx], allow_growth=True)
|
||||
self.tf_session = gpufmkmgr.get_tf_session()
|
||||
self.keras = gpufmkmgr.import_keras()
|
||||
self.e = facelib.MTCExtractor(self.keras, self.tf, self.tf_session)
|
||||
elif self.detector == 'dlib':
|
||||
self.dlib = gpufmkmgr.import_dlib( self.device_idx )
|
||||
self.e = facelib.DLIBExtractor(self.dlib)
|
||||
self.e.__enter__()
|
||||
|
||||
elif self.type == 'landmarks':
|
||||
self.tf = gpufmkmgr.import_tf([self.device_idx], allow_growth=True)
|
||||
self.tf_session = gpufmkmgr.get_tf_session()
|
||||
self.keras = gpufmkmgr.import_keras()
|
||||
self.e = facelib.LandmarksExtractor(self.keras)
|
||||
self.e.__enter__()
|
||||
|
||||
elif self.type == 'final':
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
#override
|
||||
def onClientFinalize(self):
|
||||
if self.e is not None:
|
||||
self.e.__exit__()
|
||||
|
||||
#override
|
||||
def onClientProcessData(self, data):
|
||||
filename_path = Path( data[0] )
|
||||
|
||||
image = cv2.imread( str(filename_path) )
|
||||
if image is None:
|
||||
print ( 'Failed to extract %s, reason: cv2.imread() fail.' % ( str(filename_path) ) )
|
||||
else:
|
||||
if self.type == 'rects':
|
||||
rects = self.e.extract_from_bgr (image)
|
||||
return [str(filename_path), rects]
|
||||
|
||||
elif self.type == 'landmarks':
|
||||
rects = data[1]
|
||||
landmarks = self.e.extract_from_bgr (image, rects)
|
||||
return [str(filename_path), landmarks]
|
||||
|
||||
elif self.type == 'final':
|
||||
result = []
|
||||
faces = data[1]
|
||||
|
||||
if self.debug:
|
||||
debug_output_file = '{}_{}'.format( str(Path(str(self.output_path) + '_debug') / filename_path.stem), 'debug.png')
|
||||
debug_image = image.copy()
|
||||
|
||||
for (face_idx, face) in enumerate(faces):
|
||||
output_file = '{}_{}{}'.format(str(self.output_path / filename_path.stem), str(face_idx), '.png')
|
||||
|
||||
rect = face[0]
|
||||
image_landmarks = np.array(face[1])
|
||||
|
||||
if self.debug:
|
||||
facelib.LandmarksProcessor.draw_rect_landmarks (debug_image, rect, image_landmarks, self.image_size, self.face_type)
|
||||
|
||||
if self.face_type == FaceType.MARK_ONLY:
|
||||
face_image = image
|
||||
face_image_landmarks = image_landmarks
|
||||
else:
|
||||
image_to_face_mat = facelib.LandmarksProcessor.get_transform_mat (image_landmarks, self.image_size, self.face_type)
|
||||
face_image = cv2.warpAffine(image, image_to_face_mat, (self.image_size, self.image_size), cv2.INTER_LANCZOS4)
|
||||
face_image_landmarks = facelib.LandmarksProcessor.transform_points (image_landmarks, image_to_face_mat)
|
||||
|
||||
cv2.imwrite(output_file, face_image)
|
||||
|
||||
a_png = AlignedPNG.load (output_file)
|
||||
|
||||
d = {
|
||||
'face_type': FaceType.toString(self.face_type),
|
||||
'landmarks': face_image_landmarks.tolist(),
|
||||
'yaw_value': facelib.LandmarksProcessor.calc_face_yaw (face_image_landmarks),
|
||||
'pitch_value': facelib.LandmarksProcessor.calc_face_pitch (face_image_landmarks),
|
||||
'source_filename': filename_path.name,
|
||||
'source_rect': rect,
|
||||
'source_landmarks': image_landmarks.tolist()
|
||||
}
|
||||
a_png.setFaceswapDictData (d)
|
||||
a_png.save(output_file)
|
||||
|
||||
result.append (output_file)
|
||||
|
||||
if self.debug:
|
||||
cv2.imwrite(debug_output_file, debug_image )
|
||||
|
||||
return result
|
||||
return None
|
||||
|
||||
#overridable
|
||||
def onClientGetDataName (self, data):
|
||||
#return string identificator of your data
|
||||
return data[0]
|
||||
|
||||
#override
|
||||
def onHostResult (self, data, result):
|
||||
if self.manual == True:
|
||||
self.landmarks = result[1][0][1]
|
||||
|
||||
image = cv2.addWeighted (self.original_image,1.0,self.text_lines_img,1.0,0)
|
||||
view_rect = (np.array(self.rect) * self.view_scale).astype(np.int).tolist()
|
||||
view_landmarks = (np.array(self.landmarks) * self.view_scale).astype(np.int).tolist()
|
||||
facelib.LandmarksProcessor.draw_rect_landmarks (image, view_rect, view_landmarks, self.image_size, self.face_type)
|
||||
|
||||
cv2.imshow (self.wnd_name, image)
|
||||
return 0
|
||||
else:
|
||||
if self.type == 'rects':
|
||||
self.result.append ( result )
|
||||
elif self.type == 'landmarks':
|
||||
self.result.append ( result )
|
||||
elif self.type == 'final':
|
||||
self.result += result
|
||||
|
||||
return 1
|
||||
|
||||
#override
|
||||
def onHostProcessEnd(self):
|
||||
if self.manual == True:
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
#override
|
||||
def get_start_return(self):
|
||||
return self.result
|
||||
|
||||
'''
|
||||
detector
|
||||
'dlib'
|
||||
'mt'
|
||||
'manual'
|
||||
|
||||
face_type
|
||||
'full_face'
|
||||
'avatar'
|
||||
'''
|
||||
def main (input_dir, output_dir, debug, detector='mt', multi_gpu=True, manual_fix=False, manual_window_size=0, image_size=256, face_type='full_face'):
|
||||
print ("Running extractor.\r\n")
|
||||
|
||||
input_path = Path(input_dir)
|
||||
output_path = Path(output_dir)
|
||||
face_type = FaceType.fromString(face_type)
|
||||
|
||||
if not input_path.exists():
|
||||
print('Input directory not found. Please ensure it exists.')
|
||||
return
|
||||
|
||||
if output_path.exists():
|
||||
for filename in Path_utils.get_image_paths(output_path):
|
||||
Path(filename).unlink()
|
||||
else:
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if debug:
|
||||
debug_output_path = Path(str(output_path) + '_debug')
|
||||
if debug_output_path.exists():
|
||||
for filename in Path_utils.get_image_paths(debug_output_path):
|
||||
Path(filename).unlink()
|
||||
else:
|
||||
debug_output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
input_path_image_paths = Path_utils.get_image_unique_filestem_paths(input_path, verbose=True)
|
||||
images_found = len(input_path_image_paths)
|
||||
faces_detected = 0
|
||||
if images_found != 0:
|
||||
if detector == 'manual':
|
||||
print ('Performing manual extract...')
|
||||
extracted_faces = ExtractSubprocessor ([ (filename,[]) for filename in input_path_image_paths ], 'landmarks', image_size, face_type, debug, manual=True, manual_window_size=manual_window_size).process()
|
||||
else:
|
||||
print ('Performing 1st pass...')
|
||||
extracted_rects = ExtractSubprocessor ([ (x,) for x in input_path_image_paths ], 'rects', image_size, face_type, debug, multi_gpu=multi_gpu, manual=False, detector=detector).process()
|
||||
|
||||
print ('Performing 2nd pass...')
|
||||
extracted_faces = ExtractSubprocessor (extracted_rects, 'landmarks', image_size, face_type, debug, multi_gpu=multi_gpu, manual=False).process()
|
||||
|
||||
if manual_fix:
|
||||
print ('Performing manual fix...')
|
||||
|
||||
if all ( np.array ( [ len(data[1]) > 0 for data in extracted_faces] ) == True ):
|
||||
print ('All faces are detected, manual fix not needed.')
|
||||
else:
|
||||
extracted_faces = ExtractSubprocessor (extracted_faces, 'landmarks', image_size, face_type, debug, manual=True, manual_window_size=manual_window_size).process()
|
||||
|
||||
if len(extracted_faces) > 0:
|
||||
print ('Performing 3rd pass...')
|
||||
final_imgs_paths = ExtractSubprocessor (extracted_faces, 'final', image_size, face_type, debug, multi_gpu=multi_gpu, manual=False, output_path=output_path).process()
|
||||
faces_detected = len(final_imgs_paths)
|
||||
|
||||
print('-------------------------')
|
||||
print('Images found: %d' % (images_found) )
|
||||
print('Faces detected: %d' % (faces_detected) )
|
||||
print('-------------------------')
|
351
mainscripts/Sorter.py
Normal file
351
mainscripts/Sorter.py
Normal file
|
@ -0,0 +1,351 @@
|
|||
import os
|
||||
import sys
|
||||
import operator
|
||||
import numpy as np
|
||||
import cv2
|
||||
from tqdm import tqdm
|
||||
from shutil import copyfile
|
||||
|
||||
from pathlib import Path
|
||||
from utils import Path_utils
|
||||
from utils.AlignedPNG import AlignedPNG
|
||||
from facelib import LandmarksProcessor
|
||||
|
||||
def estimate_blur(image):
|
||||
if image.ndim == 3:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
blur_map = cv2.Laplacian(image, cv2.CV_64F)
|
||||
score = np.var(blur_map)
|
||||
return score
|
||||
|
||||
def sort_by_brightness(input_path):
|
||||
print ("Sorting by brightness...")
|
||||
img_list = [ [x, np.mean ( cv2.cvtColor(cv2.imread(x), cv2.COLOR_BGR2HSV)[...,2].flatten() )] for x in tqdm( Path_utils.get_image_paths(input_path), desc="Loading") ]
|
||||
print ("Sorting...")
|
||||
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
|
||||
return img_list
|
||||
|
||||
def sort_by_hue(input_path):
|
||||
print ("Sorting by hue...")
|
||||
img_list = [ [x, np.mean ( cv2.cvtColor(cv2.imread(x), cv2.COLOR_BGR2HSV)[...,0].flatten() )] for x in tqdm( Path_utils.get_image_paths(input_path), desc="Loading") ]
|
||||
print ("Sorting...")
|
||||
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
|
||||
return img_list
|
||||
|
||||
def sort_by_blur(input_path):
|
||||
img_list = []
|
||||
print ("Sorting by blur...")
|
||||
for filepath in tqdm( Path_utils.get_image_paths(input_path), desc="Loading"):
|
||||
#never mask it by face hull, it worse than whole image blur estimate
|
||||
img_list.append ( [filepath, estimate_blur (cv2.imread( filepath ))] )
|
||||
|
||||
print ("Sorting...")
|
||||
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
|
||||
|
||||
return img_list
|
||||
|
||||
def sort_by_face(input_path):
|
||||
|
||||
print ("Sorting by face similarity...")
|
||||
|
||||
img_list = []
|
||||
for filepath in tqdm( Path_utils.get_image_paths(input_path), desc="Loading"):
|
||||
filepath = Path(filepath)
|
||||
|
||||
if filepath.suffix != '.png':
|
||||
print ("%s is not a png file required for sort_by_face" % (filepath.name) )
|
||||
continue
|
||||
|
||||
a_png = AlignedPNG.load (str(filepath))
|
||||
if a_png is None:
|
||||
print ("%s failed to load" % (filepath.name) )
|
||||
continue
|
||||
|
||||
d = a_png.getFaceswapDictData()
|
||||
|
||||
if d is None or d['landmarks'] is None:
|
||||
print ("%s - no embedded data found required for sort_by_face" % (filepath.name) )
|
||||
continue
|
||||
|
||||
img_list.append( [str(filepath), np.array(d['landmarks']) ] )
|
||||
|
||||
|
||||
img_list_len = len(img_list)
|
||||
for i in tqdm ( range(0, img_list_len-1), desc="Sorting"):
|
||||
min_score = float("inf")
|
||||
j_min_score = i+1
|
||||
for j in range(i+1,len(img_list)):
|
||||
|
||||
fl1 = img_list[i][1]
|
||||
fl2 = img_list[j][1]
|
||||
score = np.sum ( np.absolute ( (fl2 - fl1).flatten() ) )
|
||||
|
||||
if score < min_score:
|
||||
min_score = score
|
||||
j_min_score = j
|
||||
img_list[i+1], img_list[j_min_score] = img_list[j_min_score], img_list[i+1]
|
||||
|
||||
return img_list
|
||||
|
||||
def sort_by_face_dissim(input_path):
|
||||
|
||||
print ("Sorting by face dissimilarity...")
|
||||
|
||||
img_list = []
|
||||
for filepath in tqdm( Path_utils.get_image_paths(input_path), desc="Loading"):
|
||||
filepath = Path(filepath)
|
||||
|
||||
if filepath.suffix != '.png':
|
||||
print ("%s is not a png file required for sort_by_face_dissim" % (filepath.name) )
|
||||
continue
|
||||
|
||||
a_png = AlignedPNG.load (str(filepath))
|
||||
if a_png is None:
|
||||
print ("%s failed to load" % (filepath.name) )
|
||||
continue
|
||||
|
||||
d = a_png.getFaceswapDictData()
|
||||
|
||||
if d is None or d['landmarks'] is None:
|
||||
print ("%s - no embedded data found required for sort_by_face_dissim" % (filepath.name) )
|
||||
continue
|
||||
|
||||
img_list.append( [str(filepath), np.array(d['landmarks']), 0 ] )
|
||||
|
||||
img_list_len = len(img_list)
|
||||
for i in tqdm( range(0, img_list_len-1), desc="Sorting"):
|
||||
score_total = 0
|
||||
for j in range(i+1,len(img_list)):
|
||||
if i == j:
|
||||
continue
|
||||
fl1 = img_list[i][1]
|
||||
fl2 = img_list[j][1]
|
||||
score_total += np.sum ( np.absolute ( (fl2 - fl1).flatten() ) )
|
||||
|
||||
img_list[i][2] = score_total
|
||||
|
||||
print ("Sorting...")
|
||||
img_list = sorted(img_list, key=operator.itemgetter(2), reverse=True)
|
||||
|
||||
return img_list
|
||||
|
||||
def sort_by_face_yaw(input_path):
|
||||
print ("Sorting by face yaw...")
|
||||
img_list = []
|
||||
for filepath in tqdm( Path_utils.get_image_paths(input_path), desc="Loading"):
|
||||
filepath = Path(filepath)
|
||||
|
||||
if filepath.suffix != '.png':
|
||||
print ("%s is not a png file required for sort_by_face_dissim" % (filepath.name) )
|
||||
continue
|
||||
|
||||
a_png = AlignedPNG.load (str(filepath))
|
||||
if a_png is None:
|
||||
print ("%s failed to load" % (filepath.name) )
|
||||
continue
|
||||
|
||||
d = a_png.getFaceswapDictData()
|
||||
|
||||
if d is None or d['yaw_value'] is None:
|
||||
print ("%s - no embedded data found required for sort_by_face_dissim" % (filepath.name) )
|
||||
continue
|
||||
|
||||
img_list.append( [str(filepath), np.array(d['yaw_value']) ] )
|
||||
|
||||
print ("Sorting...")
|
||||
img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)
|
||||
|
||||
return img_list
|
||||
|
||||
def sort_by_hist_blur(input_path):
|
||||
|
||||
print ("Sorting by histogram similarity and blur...")
|
||||
|
||||
img_list = []
|
||||
for x in tqdm( Path_utils.get_image_paths(input_path), desc="Loading"):
|
||||
img = cv2.imread(x)
|
||||
img_list.append ([x, cv2.calcHist([img], [0], None, [256], [0, 256]),
|
||||
cv2.calcHist([img], [1], None, [256], [0, 256]),
|
||||
cv2.calcHist([img], [2], None, [256], [0, 256]),
|
||||
estimate_blur(img)
|
||||
])
|
||||
|
||||
img_list_len = len(img_list)
|
||||
for i in tqdm( range(0, img_list_len-1), desc="Sorting"):
|
||||
min_score = float("inf")
|
||||
j_min_score = i+1
|
||||
for j in range(i+1,len(img_list)):
|
||||
score = cv2.compareHist(img_list[i][1], img_list[j][1], cv2.HISTCMP_BHATTACHARYYA) + \
|
||||
cv2.compareHist(img_list[i][2], img_list[j][2], cv2.HISTCMP_BHATTACHARYYA) + \
|
||||
cv2.compareHist(img_list[i][3], img_list[j][3], cv2.HISTCMP_BHATTACHARYYA)
|
||||
if score < min_score:
|
||||
min_score = score
|
||||
j_min_score = j
|
||||
img_list[i+1], img_list[j_min_score] = img_list[j_min_score], img_list[i+1]
|
||||
|
||||
l = []
|
||||
for i in range(0, img_list_len-1):
|
||||
score = cv2.compareHist(img_list[i][1], img_list[i+1][1], cv2.HISTCMP_BHATTACHARYYA) + \
|
||||
cv2.compareHist(img_list[i][2], img_list[i+1][2], cv2.HISTCMP_BHATTACHARYYA) + \
|
||||
cv2.compareHist(img_list[i][3], img_list[i+1][3], cv2.HISTCMP_BHATTACHARYYA)
|
||||
l += [score]
|
||||
l = np.array(l)
|
||||
v = np.mean(l)
|
||||
if v*2 < np.max(l):
|
||||
v *= 2
|
||||
|
||||
new_img_list = []
|
||||
|
||||
start_group_i = 0
|
||||
odd_counter = 0
|
||||
for i in tqdm( range(0, img_list_len), desc="Sorting"):
|
||||
end_group_i = -1
|
||||
if i < img_list_len-1:
|
||||
score = cv2.compareHist(img_list[i][1], img_list[i+1][1], cv2.HISTCMP_BHATTACHARYYA) + \
|
||||
cv2.compareHist(img_list[i][2], img_list[i+1][2], cv2.HISTCMP_BHATTACHARYYA) + \
|
||||
cv2.compareHist(img_list[i][3], img_list[i+1][3], cv2.HISTCMP_BHATTACHARYYA)
|
||||
|
||||
if score >= v:
|
||||
end_group_i = i
|
||||
|
||||
elif i == img_list_len-1:
|
||||
end_group_i = i
|
||||
|
||||
if end_group_i >= start_group_i:
|
||||
odd_counter += 1
|
||||
|
||||
s = sorted(img_list[start_group_i:end_group_i+1] , key=operator.itemgetter(4), reverse=True)
|
||||
if odd_counter % 2 == 0:
|
||||
new_img_list = new_img_list + s
|
||||
else:
|
||||
new_img_list = s + new_img_list
|
||||
|
||||
start_group_i = i + 1
|
||||
|
||||
return new_img_list
|
||||
|
||||
def sort_by_hist(input_path):
|
||||
|
||||
print ("Sorting by histogram similarity...")
|
||||
|
||||
img_list = []
|
||||
for x in tqdm( Path_utils.get_image_paths(input_path), desc="Loading"):
|
||||
img = cv2.imread(x)
|
||||
img_list.append ([x, cv2.calcHist([img], [0], None, [256], [0, 256]),
|
||||
cv2.calcHist([img], [1], None, [256], [0, 256]),
|
||||
cv2.calcHist([img], [2], None, [256], [0, 256])
|
||||
])
|
||||
|
||||
img_list_len = len(img_list)
|
||||
for i in tqdm( range(0, img_list_len-1), desc="Sorting"):
|
||||
min_score = float("inf")
|
||||
j_min_score = i+1
|
||||
for j in range(i+1,len(img_list)):
|
||||
score = cv2.compareHist(img_list[i][1], img_list[j][1], cv2.HISTCMP_BHATTACHARYYA) + \
|
||||
cv2.compareHist(img_list[i][2], img_list[j][2], cv2.HISTCMP_BHATTACHARYYA) + \
|
||||
cv2.compareHist(img_list[i][3], img_list[j][3], cv2.HISTCMP_BHATTACHARYYA)
|
||||
if score < min_score:
|
||||
min_score = score
|
||||
j_min_score = j
|
||||
img_list[i+1], img_list[j_min_score] = img_list[j_min_score], img_list[i+1]
|
||||
|
||||
return img_list
|
||||
|
||||
def sort_by_hist_dissim(input_path):
|
||||
|
||||
print ("Sorting by histogram dissimilarity...")
|
||||
|
||||
img_list = []
|
||||
for x in tqdm( Path_utils.get_image_paths(input_path), desc="Loading"):
|
||||
img = cv2.imread(x)
|
||||
img_list.append ([x, cv2.calcHist([img], [0], None, [256], [0, 256]),
|
||||
cv2.calcHist([img], [1], None, [256], [0, 256]),
|
||||
cv2.calcHist([img], [2], None, [256], [0, 256]), 0
|
||||
])
|
||||
|
||||
img_list_len = len(img_list)
|
||||
for i in tqdm ( range(0, img_list_len), desc="Sorting"):
|
||||
score_total = 0
|
||||
for j in range( 0, img_list_len):
|
||||
if i == j:
|
||||
continue
|
||||
score_total += cv2.compareHist(img_list[i][1], img_list[j][1], cv2.HISTCMP_BHATTACHARYYA) + \
|
||||
cv2.compareHist(img_list[i][2], img_list[j][2], cv2.HISTCMP_BHATTACHARYYA) + \
|
||||
cv2.compareHist(img_list[i][3], img_list[j][3], cv2.HISTCMP_BHATTACHARYYA)
|
||||
|
||||
img_list[i][4] = score_total
|
||||
|
||||
|
||||
print ("Sorting...")
|
||||
img_list = sorted(img_list, key=operator.itemgetter(4), reverse=True)
|
||||
|
||||
return img_list
|
||||
|
||||
def final_rename(input_path, img_list):
|
||||
for i in tqdm( range(0,len(img_list)), desc="Renaming" , leave=False):
|
||||
src = Path (img_list[i][0])
|
||||
dst = input_path / ('%.5d_%s' % (i, src.name ))
|
||||
try:
|
||||
src.rename (dst)
|
||||
except:
|
||||
print ('fail to rename %s' % (src.name) )
|
||||
|
||||
for i in tqdm( range(0,len(img_list)) , desc="Renaming" ):
|
||||
src = Path (img_list[i][0])
|
||||
|
||||
src = input_path / ('%.5d_%s' % (i, src.name))
|
||||
dst = input_path / ('%.5d%s' % (i, src.suffix))
|
||||
try:
|
||||
src.rename (dst)
|
||||
except:
|
||||
print ('fail to rename %s' % (src.name) )
|
||||
|
||||
def sort_by_origname(input_path):
|
||||
print ("Sort by original filename...")
|
||||
|
||||
img_list = []
|
||||
for filepath in tqdm( Path_utils.get_image_paths(input_path), desc="Loading"):
|
||||
filepath = Path(filepath)
|
||||
|
||||
if filepath.suffix != '.png':
|
||||
print ("%s is not a png file required for sort_by_origname" % (filepath.name) )
|
||||
continue
|
||||
|
||||
a_png = AlignedPNG.load (str(filepath))
|
||||
if a_png is None:
|
||||
print ("%s failed to load" % (filepath.name) )
|
||||
continue
|
||||
|
||||
d = a_png.getFaceswapDictData()
|
||||
|
||||
if d is None or d['source_filename'] is None:
|
||||
print ("%s - no embedded data found required for sort_by_origname" % (filepath.name) )
|
||||
continue
|
||||
|
||||
img_list.append( [str(filepath), d['source_filename']] )
|
||||
|
||||
print ("Sorting...")
|
||||
img_list = sorted(img_list, key=operator.itemgetter(1))
|
||||
return img_list
|
||||
|
||||
def main (input_path, sort_by_method):
|
||||
input_path = Path(input_path)
|
||||
sort_by_method = sort_by_method.lower()
|
||||
|
||||
print ("Running sort tool.\r\n")
|
||||
|
||||
img_list = []
|
||||
|
||||
if sort_by_method == 'blur': img_list = sort_by_blur (input_path)
|
||||
elif sort_by_method == 'face': img_list = sort_by_face (input_path)
|
||||
elif sort_by_method == 'face-dissim': img_list = sort_by_face_dissim (input_path)
|
||||
elif sort_by_method == 'face-yaw': img_list = sort_by_face_yaw (input_path)
|
||||
elif sort_by_method == 'hist': img_list = sort_by_hist (input_path)
|
||||
elif sort_by_method == 'hist-dissim': img_list = sort_by_hist_dissim (input_path)
|
||||
elif sort_by_method == 'hist-blur': img_list = sort_by_hist_blur (input_path)
|
||||
elif sort_by_method == 'brightness': img_list = sort_by_brightness (input_path)
|
||||
elif sort_by_method == 'hue': img_list = sort_by_hue (input_path)
|
||||
elif sort_by_method == 'origname': img_list = sort_by_origname (input_path)
|
||||
|
||||
final_rename (input_path, img_list)
|
289
mainscripts/Trainer.py
Normal file
289
mainscripts/Trainer.py
Normal file
|
@ -0,0 +1,289 @@
|
|||
import sys
|
||||
import traceback
|
||||
import queue
|
||||
import colorsys
|
||||
import time
|
||||
import numpy as np
|
||||
import itertools
|
||||
|
||||
from pathlib import Path
|
||||
from utils import Path_utils
|
||||
from utils import image_utils
|
||||
import cv2
|
||||
|
||||
def trainerThread (input_queue, output_queue, training_data_src_dir, training_data_dst_dir, model_path, model_name, save_interval_min=10, debug=False, target_epoch=0, **in_options):
|
||||
|
||||
while True:
|
||||
try:
|
||||
training_data_src_path = Path(training_data_src_dir)
|
||||
training_data_dst_path = Path(training_data_dst_dir)
|
||||
model_path = Path(model_path)
|
||||
|
||||
if not training_data_src_path.exists():
|
||||
print( 'Training data src directory is not exists.')
|
||||
return
|
||||
|
||||
if not training_data_dst_path.exists():
|
||||
print( 'Training data dst directory is not exists.')
|
||||
return
|
||||
|
||||
if not model_path.exists():
|
||||
model_path.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
|
||||
import models
|
||||
model = models.import_model(model_name)(
|
||||
model_path,
|
||||
training_data_src_path=training_data_src_path,
|
||||
training_data_dst_path=training_data_dst_path,
|
||||
debug=debug,
|
||||
**in_options)
|
||||
|
||||
is_reached_goal = (target_epoch > 0 and model.get_epoch() >= target_epoch)
|
||||
|
||||
def model_save():
|
||||
if not debug and not is_reached_goal:
|
||||
model.save()
|
||||
|
||||
def send_preview():
|
||||
if not debug:
|
||||
previews = model.get_previews()
|
||||
output_queue.put ( {'op':'show', 'previews': previews, 'epoch':model.get_epoch(), 'loss_history': model.get_loss_history().copy() } )
|
||||
else:
|
||||
previews = [( 'debug, press update for new', model.debug_one_epoch())]
|
||||
output_queue.put ( {'op':'show', 'previews': previews} )
|
||||
|
||||
|
||||
if model.is_first_run():
|
||||
model_save()
|
||||
|
||||
if target_epoch != 0:
|
||||
if is_reached_goal:
|
||||
print ('Model already trained to target epoch. You can use preview.')
|
||||
else:
|
||||
print('Starting. Target epoch: %d. Press "Enter" to stop training and save model.' % (target_epoch) )
|
||||
else:
|
||||
print('Starting. Press "Enter" to stop training and save model.')
|
||||
|
||||
last_save_time = time.time()
|
||||
for i in itertools.count(0,1):
|
||||
if not debug:
|
||||
if not is_reached_goal:
|
||||
loss_string = model.train_one_epoch()
|
||||
|
||||
print (loss_string, end='\r')
|
||||
if target_epoch != 0 and model.get_epoch() >= target_epoch:
|
||||
print ('Reached target epoch.')
|
||||
model_save()
|
||||
is_reached_goal = True
|
||||
print ('You can use preview now.')
|
||||
|
||||
if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
|
||||
last_save_time = time.time()
|
||||
model_save()
|
||||
send_preview()
|
||||
|
||||
if i==0:
|
||||
if is_reached_goal:
|
||||
model.pass_one_epoch()
|
||||
send_preview()
|
||||
|
||||
if debug:
|
||||
time.sleep(0.005)
|
||||
|
||||
while not input_queue.empty():
|
||||
input = input_queue.get()
|
||||
op = input['op']
|
||||
if op == 'save':
|
||||
model_save()
|
||||
elif op == 'preview':
|
||||
if is_reached_goal:
|
||||
model.pass_one_epoch()
|
||||
send_preview()
|
||||
elif op == 'close':
|
||||
model_save()
|
||||
i = -1
|
||||
break
|
||||
|
||||
if i == -1:
|
||||
break
|
||||
|
||||
|
||||
|
||||
model.finalize()
|
||||
|
||||
except Exception as e:
|
||||
print ('Error: %s' % (str(e)))
|
||||
traceback.print_exc()
|
||||
break
|
||||
output_queue.put ( {'op':'close'} )
|
||||
|
||||
def previewThread (input_queue, output_queue):
|
||||
|
||||
|
||||
previews = None
|
||||
loss_history = None
|
||||
selected_preview = 0
|
||||
update_preview = False
|
||||
is_showing = False
|
||||
is_waiting_preview = False
|
||||
epoch = 0
|
||||
while True:
|
||||
if not input_queue.empty():
|
||||
input = input_queue.get()
|
||||
op = input['op']
|
||||
if op == 'show':
|
||||
is_waiting_preview = False
|
||||
loss_history = input['loss_history'] if 'loss_history' in input.keys() else None
|
||||
previews = input['previews'] if 'previews' in input.keys() else None
|
||||
epoch = input['epoch'] if 'epoch' in input.keys() else 0
|
||||
if previews is not None:
|
||||
max_w = 0
|
||||
max_h = 0
|
||||
for (preview_name, preview_rgb) in previews:
|
||||
(h, w, c) = preview_rgb.shape
|
||||
max_h = max (max_h, h)
|
||||
max_w = max (max_w, w)
|
||||
|
||||
max_size = 800
|
||||
if max_h > max_size:
|
||||
max_w = int( max_w / (max_h / max_size) )
|
||||
max_h = max_size
|
||||
|
||||
#make all previews size equal
|
||||
for preview in previews[:]:
|
||||
(preview_name, preview_rgb) = preview
|
||||
(h, w, c) = preview_rgb.shape
|
||||
if h != max_h or w != max_w:
|
||||
previews.remove(preview)
|
||||
previews.append ( (preview_name, cv2.resize(preview_rgb, (max_w, max_h))) )
|
||||
selected_preview = selected_preview % len(previews)
|
||||
update_preview = True
|
||||
elif op == 'close':
|
||||
break
|
||||
|
||||
if update_preview:
|
||||
update_preview = False
|
||||
(h,w,c) = previews[0][1].shape
|
||||
|
||||
selected_preview_name = previews[selected_preview][0]
|
||||
selected_preview_rgb = previews[selected_preview][1]
|
||||
|
||||
# HEAD
|
||||
head_text_color = [0.8]*c
|
||||
head_lines = [
|
||||
'[s]:save [enter]:exit',
|
||||
'[p]:update [space]:next preview',
|
||||
'Preview: "%s" [%d/%d]' % (selected_preview_name,selected_preview+1, len(previews) )
|
||||
]
|
||||
head_line_height = 15
|
||||
head_height = len(head_lines) * head_line_height
|
||||
head = np.ones ( (head_height,w,c) ) * 0.1
|
||||
|
||||
for i in range(0, len(head_lines)):
|
||||
t = i*head_line_height
|
||||
b = (i+1)*head_line_height
|
||||
head[t:b, 0:w] += image_utils.get_text_image ( (w,head_line_height,c) , head_lines[i], color=head_text_color )
|
||||
|
||||
final = head
|
||||
|
||||
if loss_history is not None:
|
||||
# LOSS HISTORY
|
||||
loss_history = np.array (loss_history)
|
||||
|
||||
lh_height = 100
|
||||
lh_img = np.ones ( (lh_height,w,c) ) * 0.1
|
||||
loss_count = len(loss_history[0])
|
||||
lh_len = len(loss_history)
|
||||
|
||||
l_per_col = lh_len / w
|
||||
plist_max = [ [ max (0.0, 0.0, *[ loss_history[i_ab][p]
|
||||
for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) )
|
||||
]
|
||||
)
|
||||
for p in range(0,loss_count)
|
||||
]
|
||||
for col in range(0, w)
|
||||
]
|
||||
|
||||
|
||||
plist_min = [ [ min (plist_max[col][p],
|
||||
plist_max[col][p],
|
||||
*[ loss_history[i_ab][p]
|
||||
for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) )
|
||||
]
|
||||
)
|
||||
for p in range(0,loss_count)
|
||||
]
|
||||
for col in range(0, w)
|
||||
]
|
||||
plist_abs_max = np.mean(loss_history[ len(loss_history) // 5 : ]) * 2
|
||||
|
||||
if l_per_col >= 1.0:
|
||||
for col in range(0, w):
|
||||
for p in range(0,loss_count):
|
||||
point_color = [1.0]*c
|
||||
point_color[0:3] = colorsys.hsv_to_rgb ( p * (1.0/loss_count), 1.0, 1.0 )
|
||||
|
||||
ph_max = int ( (plist_max[col][p] / plist_abs_max) * (lh_height-1) )
|
||||
ph_max = np.clip( ph_max, 0, lh_height-1 )
|
||||
|
||||
ph_min = int ( (plist_min[col][p] / plist_abs_max) * (lh_height-1) )
|
||||
ph_min = np.clip( ph_min, 0, lh_height-1 )
|
||||
|
||||
for ph in range(ph_min, ph_max+1):
|
||||
lh_img[ (lh_height-ph-1), col ] = point_color
|
||||
|
||||
lh_lines = 5
|
||||
lh_line_height = (lh_height-1)/lh_lines
|
||||
for i in range(0,lh_lines+1):
|
||||
lh_img[ int(i*lh_line_height), : ] = (0.8,)*c
|
||||
|
||||
last_line_t = int((lh_lines-1)*lh_line_height)
|
||||
last_line_b = int(lh_lines*lh_line_height)
|
||||
|
||||
if epoch != 0:
|
||||
lh_text = 'Loss history. Epoch: %d' % (epoch)
|
||||
else:
|
||||
lh_text = 'Loss history.'
|
||||
|
||||
lh_img[last_line_t:last_line_b, 0:w] += image_utils.get_text_image ( (w,last_line_b-last_line_t,c), lh_text, color=head_text_color )
|
||||
|
||||
final = np.concatenate ( [final, lh_img], axis=0 )
|
||||
|
||||
final = np.concatenate ( [final, selected_preview_rgb], axis=0 )
|
||||
|
||||
cv2.imshow ( 'Training preview', final)
|
||||
is_showing = True
|
||||
|
||||
if is_showing:
|
||||
key = cv2.waitKey(100)
|
||||
else:
|
||||
time.sleep(0.1)
|
||||
key = 0
|
||||
|
||||
if key == ord('\n') or key == ord('\r'):
|
||||
output_queue.put ( {'op': 'close'} )
|
||||
elif key == ord('s'):
|
||||
output_queue.put ( {'op': 'save'} )
|
||||
elif key == ord('p'):
|
||||
if not is_waiting_preview:
|
||||
is_waiting_preview = True
|
||||
output_queue.put ( {'op': 'preview'} )
|
||||
elif key == ord(' '):
|
||||
selected_preview = (selected_preview + 1) % len(previews)
|
||||
update_preview = True
|
||||
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
def main (training_data_src_dir, training_data_dst_dir, model_path, model_name, **in_options):
|
||||
print ("Running trainer.\r\n")
|
||||
|
||||
output_queue = queue.Queue()
|
||||
input_queue = queue.Queue()
|
||||
import threading
|
||||
thread = threading.Thread(target=trainerThread, args=(output_queue, input_queue, training_data_src_dir, training_data_dst_dir, model_path, model_name), kwargs=in_options )
|
||||
thread.start()
|
||||
|
||||
previewThread (input_queue, output_queue)
|
Loading…
Add table
Add a link
Reference in a new issue