mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
transfercolor via lab converter now implemented by tensorflow-cpu, which is x2 faster than skimage.
We cannot use GPU for lab converter in converter multiprocesses, because almost all VRAM ate by model process, so even 300Mb free VRAM not enough for tensorflow lab converter. Removed skimage dependency. Refactorings.
This commit is contained in:
parent
6f0d38d171
commit
5c43f4245e
7 changed files with 229 additions and 67 deletions
|
@ -5,6 +5,8 @@ import contextlib
|
||||||
from utils import std_utils
|
from utils import std_utils
|
||||||
from .pynvml import *
|
from .pynvml import *
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
dlib_module = None
|
dlib_module = None
|
||||||
def import_dlib(device_idx):
|
def import_dlib(device_idx):
|
||||||
global dlib_module
|
global dlib_module
|
||||||
|
@ -22,18 +24,26 @@ keras_module = None
|
||||||
keras_contrib_module = None
|
keras_contrib_module = None
|
||||||
keras_vggface_module = None
|
keras_vggface_module = None
|
||||||
|
|
||||||
|
def set_prefer_GPUConfig(gpu_config):
|
||||||
|
global prefer_GPUConfig
|
||||||
|
prefer_GPUConfig = gpu_config
|
||||||
|
|
||||||
def get_tf_session():
|
def get_tf_session():
|
||||||
global tf_session
|
global tf_session
|
||||||
return tf_session
|
return tf_session
|
||||||
|
|
||||||
#allow_growth=False for keras model
|
def import_tf( gpu_config = None ):
|
||||||
#allow_growth=True for tf only model
|
global prefer_GPUConfig
|
||||||
def import_tf( device_idxs_list, allow_growth ):
|
|
||||||
global tf_module
|
global tf_module
|
||||||
global tf_session
|
global tf_session
|
||||||
|
|
||||||
|
if gpu_config is None:
|
||||||
|
gpu_config = prefer_GPUConfig
|
||||||
|
else:
|
||||||
|
prefer_GPUConfig = gpu_config
|
||||||
|
|
||||||
if tf_module is not None:
|
if tf_module is not None:
|
||||||
raise Exception ('Multiple import of tf is not allowed, reorganize your program.')
|
return tf_module
|
||||||
|
|
||||||
if 'TF_SUPPRESS_STD' in os.environ.keys() and os.environ['TF_SUPPRESS_STD'] == '1':
|
if 'TF_SUPPRESS_STD' in os.environ.keys() and os.environ['TF_SUPPRESS_STD'] == '1':
|
||||||
suppressor = std_utils.suppress_stdout_stderr().__enter__()
|
suppressor = std_utils.suppress_stdout_stderr().__enter__()
|
||||||
|
@ -48,14 +58,18 @@ def import_tf( device_idxs_list, allow_growth ):
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
tf_module = tf
|
tf_module = tf
|
||||||
|
|
||||||
visible_device_list = ''
|
if gpu_config.cpu_only:
|
||||||
for idx in device_idxs_list: visible_device_list += str(idx) + ','
|
config = tf_module.ConfigProto( device_count = {'GPU': 0} )
|
||||||
visible_device_list = visible_device_list[:-1]
|
else:
|
||||||
|
config = tf_module.ConfigProto()
|
||||||
|
visible_device_list = ''
|
||||||
|
for idx in gpu_config.gpu_idxs: visible_device_list += str(idx) + ','
|
||||||
|
visible_device_list = visible_device_list[:-1]
|
||||||
|
config.gpu_options.visible_device_list=visible_device_list
|
||||||
|
config.gpu_options.force_gpu_compatible = True
|
||||||
|
|
||||||
config = tf_module.ConfigProto()
|
config.gpu_options.allow_growth = gpu_config.allow_growth
|
||||||
config.gpu_options.allow_growth = allow_growth
|
|
||||||
config.gpu_options.visible_device_list=visible_device_list
|
|
||||||
config.gpu_options.force_gpu_compatible = True
|
|
||||||
tf_session = tf_module.Session(config=config)
|
tf_session = tf_module.Session(config=config)
|
||||||
|
|
||||||
if suppressor is not None:
|
if suppressor is not None:
|
||||||
|
@ -71,11 +85,15 @@ def finalize_tf():
|
||||||
tf_session = None
|
tf_session = None
|
||||||
tf_module = None
|
tf_module = None
|
||||||
|
|
||||||
|
def get_keras():
|
||||||
|
global keras_module
|
||||||
|
return keras_module
|
||||||
|
|
||||||
def import_keras():
|
def import_keras():
|
||||||
global keras_module
|
global keras_module
|
||||||
|
|
||||||
if keras_module is not None:
|
if keras_module is not None:
|
||||||
raise Exception ('Multiple import of keras is not allowed, reorganize your program.')
|
return keras_module
|
||||||
|
|
||||||
sess = get_tf_session()
|
sess = get_tf_session()
|
||||||
if sess is None:
|
if sess is None:
|
||||||
|
@ -241,4 +259,48 @@ def getDeviceName (idx):
|
||||||
if idx < nvmlDeviceGetCount():
|
if idx < nvmlDeviceGetCount():
|
||||||
result = nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(idx)).decode()
|
result = nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(idx)).decode()
|
||||||
nvmlShutdown()
|
nvmlShutdown()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class GPUConfig():
|
||||||
|
force_best_gpu_idx = -1
|
||||||
|
multi_gpu = False
|
||||||
|
force_gpu_idxs = None
|
||||||
|
choose_worst_gpu = False
|
||||||
|
gpu_idxs = []
|
||||||
|
gpu_total_vram_gb = 0
|
||||||
|
allow_growth = True
|
||||||
|
cpu_only = False
|
||||||
|
|
||||||
|
def __init__ (self, force_best_gpu_idx = -1,
|
||||||
|
multi_gpu = False,
|
||||||
|
force_gpu_idxs = None,
|
||||||
|
choose_worst_gpu = False,
|
||||||
|
allow_growth = True,
|
||||||
|
cpu_only = False,
|
||||||
|
**in_options):
|
||||||
|
|
||||||
|
if cpu_only:
|
||||||
|
self.cpu_only = cpu_only
|
||||||
|
else:
|
||||||
|
self.force_best_gpu_idx = force_best_gpu_idx
|
||||||
|
self.multi_gpu = multi_gpu
|
||||||
|
self.force_gpu_idxs = force_gpu_idxs
|
||||||
|
self.choose_worst_gpu = choose_worst_gpu
|
||||||
|
self.allow_growth = allow_growth
|
||||||
|
|
||||||
|
gpu_idx = force_best_gpu_idx if (force_best_gpu_idx >= 0 and isValidDeviceIdx(force_best_gpu_idx)) else getBestDeviceIdx() if not choose_worst_gpu else getWorstDeviceIdx()
|
||||||
|
|
||||||
|
if force_gpu_idxs is not None:
|
||||||
|
self.gpu_idxs = [ int(x) for x in force_gpu_idxs.split(',') ]
|
||||||
|
else:
|
||||||
|
if self.multi_gpu:
|
||||||
|
self.gpu_idxs = getDeviceIdxsEqualModel( gpu_idx )
|
||||||
|
if len(self.gpu_idxs) <= 1:
|
||||||
|
self.multi_gpu = False
|
||||||
|
else:
|
||||||
|
self.gpu_idxs = [gpu_idx]
|
||||||
|
|
||||||
|
self.gpu_total_vram_gb = getDeviceVRAMTotalGb ( self.gpu_idxs[0] )
|
||||||
|
|
||||||
|
prefer_GPUConfig = GPUConfig()
|
|
@ -64,14 +64,14 @@ from utils.SubprocessorBase import SubprocessorBase
|
||||||
class ConvertSubprocessor(SubprocessorBase):
|
class ConvertSubprocessor(SubprocessorBase):
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def __init__(self, converter, input_path_image_paths, output_path, alignments, debug):
|
def __init__(self, converter, input_path_image_paths, output_path, alignments, debug = False, **in_options):
|
||||||
super().__init__('Converter')
|
super().__init__('Converter')
|
||||||
self.converter = converter
|
self.converter = converter
|
||||||
self.input_path_image_paths = input_path_image_paths
|
self.input_path_image_paths = input_path_image_paths
|
||||||
self.output_path = output_path
|
self.output_path = output_path
|
||||||
self.alignments = alignments
|
self.alignments = alignments
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
|
self.in_options = in_options
|
||||||
self.input_data = self.input_path_image_paths
|
self.input_data = self.input_path_image_paths
|
||||||
self.files_processed = 0
|
self.files_processed = 0
|
||||||
self.faces_processed = 0
|
self.faces_processed = 0
|
||||||
|
@ -79,13 +79,18 @@ class ConvertSubprocessor(SubprocessorBase):
|
||||||
#override
|
#override
|
||||||
def process_info_generator(self):
|
def process_info_generator(self):
|
||||||
r = [0] if self.debug else range(multiprocessing.cpu_count())
|
r = [0] if self.debug else range(multiprocessing.cpu_count())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for i in r:
|
for i in r:
|
||||||
yield 'CPU%d' % (i), {}, {'device_idx': i,
|
yield 'CPU%d' % (i), {}, {'device_idx': i,
|
||||||
'device_name': 'CPU%d' % (i),
|
'device_name': 'CPU%d' % (i),
|
||||||
'converter' : self.converter,
|
'converter' : self.converter,
|
||||||
'output_dir' : str(self.output_path),
|
'output_dir' : str(self.output_path),
|
||||||
'alignments' : self.alignments,
|
'alignments' : self.alignments,
|
||||||
'debug': self.debug }
|
'debug': self.debug,
|
||||||
|
'in_options': self.in_options
|
||||||
|
}
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def get_no_process_started_message(self):
|
def get_no_process_started_message(self):
|
||||||
|
@ -118,6 +123,12 @@ class ConvertSubprocessor(SubprocessorBase):
|
||||||
self.output_path = Path(client_dict['output_dir']) if 'output_dir' in client_dict.keys() else None
|
self.output_path = Path(client_dict['output_dir']) if 'output_dir' in client_dict.keys() else None
|
||||||
self.alignments = client_dict['alignments']
|
self.alignments = client_dict['alignments']
|
||||||
self.debug = client_dict['debug']
|
self.debug = client_dict['debug']
|
||||||
|
|
||||||
|
import gpufmkmgr
|
||||||
|
#model process ate all GPU mem,
|
||||||
|
#so we cannot use GPU for any TF operations in converter processes (for example image_utils.TFLabConverter)
|
||||||
|
gpufmkmgr.set_prefer_GPUConfig ( gpufmkmgr.GPUConfig (cpu_only=True) )
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
#override
|
#override
|
||||||
|
@ -211,7 +222,6 @@ def main (input_dir, output_dir, aligned_dir, model_dir, model_name, **in_option
|
||||||
model_sq = multiprocessing.Queue()
|
model_sq = multiprocessing.Queue()
|
||||||
model_cq = multiprocessing.Queue()
|
model_cq = multiprocessing.Queue()
|
||||||
model_lock = multiprocessing.Lock()
|
model_lock = multiprocessing.Lock()
|
||||||
|
|
||||||
model_p = multiprocessing.Process(target=model_process, args=(model_name, model_dir, in_options, model_sq, model_cq))
|
model_p = multiprocessing.Process(target=model_process, args=(model_name, model_dir, in_options, model_sq, model_cq))
|
||||||
model_p.start()
|
model_p.start()
|
||||||
|
|
||||||
|
@ -241,13 +251,14 @@ def main (input_dir, output_dir, aligned_dir, model_dir, model_name, **in_option
|
||||||
alignments[ source_filename_stem ] = []
|
alignments[ source_filename_stem ] = []
|
||||||
|
|
||||||
alignments[ source_filename_stem ].append ( np.array(d['source_landmarks']) )
|
alignments[ source_filename_stem ].append ( np.array(d['source_landmarks']) )
|
||||||
|
|
||||||
|
|
||||||
files_processed, faces_processed = ConvertSubprocessor (
|
files_processed, faces_processed = ConvertSubprocessor (
|
||||||
converter = converter.copy_and_set_predictor( model_process_predictor(model_sq,model_cq,model_lock) ),
|
converter = converter.copy_and_set_predictor( model_process_predictor(model_sq,model_cq,model_lock) ),
|
||||||
input_path_image_paths = Path_utils.get_image_paths(input_path),
|
input_path_image_paths = Path_utils.get_image_paths(input_path),
|
||||||
output_path = output_path,
|
output_path = output_path,
|
||||||
alignments = alignments,
|
alignments = alignments,
|
||||||
debug = debug ).process()
|
**in_options ).process()
|
||||||
|
|
||||||
model_sq.put ( {'op':'close'} )
|
model_sq.put ( {'op':'close'} )
|
||||||
model_p.join()
|
model_p.join()
|
||||||
|
|
|
@ -241,7 +241,9 @@ class ExtractSubprocessor(SubprocessorBase):
|
||||||
if self.type == 'rects':
|
if self.type == 'rects':
|
||||||
if self.detector is not None:
|
if self.detector is not None:
|
||||||
if self.detector == 'mt':
|
if self.detector == 'mt':
|
||||||
self.tf = gpufmkmgr.import_tf ([self.device_idx], allow_growth=True)
|
|
||||||
|
self.gpu_config = gpufmkmgr.GPUConfig ( force_best_gpu_idx=self.device_idx, allow_growth=True)
|
||||||
|
self.tf = gpufmkmgr.import_tf ( self.gpu_config )
|
||||||
self.tf_session = gpufmkmgr.get_tf_session()
|
self.tf_session = gpufmkmgr.get_tf_session()
|
||||||
self.keras = gpufmkmgr.import_keras()
|
self.keras = gpufmkmgr.import_keras()
|
||||||
self.e = facelib.MTCExtractor(self.keras, self.tf, self.tf_session)
|
self.e = facelib.MTCExtractor(self.keras, self.tf, self.tf_session)
|
||||||
|
@ -251,7 +253,8 @@ class ExtractSubprocessor(SubprocessorBase):
|
||||||
self.e.__enter__()
|
self.e.__enter__()
|
||||||
|
|
||||||
elif self.type == 'landmarks':
|
elif self.type == 'landmarks':
|
||||||
self.tf = gpufmkmgr.import_tf([self.device_idx], allow_growth=True)
|
self.gpu_config = gpufmkmgr.GPUConfig ( force_best_gpu_idx=self.device_idx, allow_growth=True)
|
||||||
|
self.tf = gpufmkmgr.import_tf ( self.gpu_config )
|
||||||
self.tf_session = gpufmkmgr.get_tf_session()
|
self.tf_session = gpufmkmgr.get_tf_session()
|
||||||
self.keras = gpufmkmgr.import_keras()
|
self.keras = gpufmkmgr.import_keras()
|
||||||
self.e = facelib.LandmarksExtractor(self.keras)
|
self.e = facelib.LandmarksExtractor(self.keras)
|
||||||
|
|
|
@ -38,7 +38,8 @@ class ConverterMasked(ConverterBase):
|
||||||
self.erode_mask_modifier = erode_mask_modifier
|
self.erode_mask_modifier = erode_mask_modifier
|
||||||
self.blur_mask_modifier = blur_mask_modifier
|
self.blur_mask_modifier = blur_mask_modifier
|
||||||
self.output_face_scale = np.clip(1.0 + output_face_scale_modifier*0.01, 0.5, 1.0)
|
self.output_face_scale = np.clip(1.0 + output_face_scale_modifier*0.01, 0.5, 1.0)
|
||||||
self.transfercolor = transfercolor
|
self.transfercolor = transfercolor
|
||||||
|
self.TFLabConverter = None
|
||||||
self.final_image_color_degrade_power = np.clip (final_image_color_degrade_power, 0, 100)
|
self.final_image_color_degrade_power = np.clip (final_image_color_degrade_power, 0, 100)
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
|
|
||||||
|
@ -199,14 +200,14 @@ class ConverterMasked(ConverterBase):
|
||||||
new_out = cv2.warpAffine( new_out_face_bgr, face_mat, img_size, img_bgr.copy(), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4, cv2.BORDER_TRANSPARENT )
|
new_out = cv2.warpAffine( new_out_face_bgr, face_mat, img_size, img_bgr.copy(), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4, cv2.BORDER_TRANSPARENT )
|
||||||
out_img = np.clip( img_bgr*(1-img_mask_blurry_aaa) + (new_out*img_mask_blurry_aaa) , 0, 1.0 )
|
out_img = np.clip( img_bgr*(1-img_mask_blurry_aaa) + (new_out*img_mask_blurry_aaa) , 0, 1.0 )
|
||||||
|
|
||||||
if self.transfercolor: #making transfer color from original DST image to fake
|
if self.transfercolor:
|
||||||
from skimage import io, color
|
if self.TFLabConverter is None:
|
||||||
lab_clr = color.rgb2lab(img_bgr) #original DST, converting RGB to LAB color space
|
self.TFLabConverter = image_utils.TFLabConverter()
|
||||||
lab_bw = color.rgb2lab(out_img) #fake, converting RGB to LAB color space
|
|
||||||
tmp_channel, a_channel, b_channel = cv2.split(lab_clr) #taking color channel A and B from original dst image
|
img_lab_l, img_lab_a, img_lab_b = np.split ( self.TFLabConverter.bgr2lab (img_bgr), 3, axis=-1 )
|
||||||
l_channel, tmp2_channel, tmp3_channel = cv2.split(lab_bw) #taking lightness channel L from merged fake
|
out_img_lab_l, out_img_lab_a, out_img_lab_b = np.split ( self.TFLabConverter.bgr2lab (out_img), 3, axis=-1 )
|
||||||
img_LAB = cv2.merge((l_channel,a_channel, b_channel)) #merging light and color
|
|
||||||
out_img = color.lab2rgb(img_LAB) #converting LAB to RGB
|
out_img = self.TFLabConverter.lab2bgr ( np.concatenate([out_img_lab_l, img_lab_a, img_lab_b], axis=-1) )
|
||||||
|
|
||||||
if self.final_image_color_degrade_power != 0:
|
if self.final_image_color_degrade_power != 0:
|
||||||
if debug:
|
if debug:
|
||||||
|
@ -234,4 +235,5 @@ class ConverterMasked(ConverterBase):
|
||||||
debugs += [out_img.copy()]
|
debugs += [out_img.copy()]
|
||||||
|
|
||||||
return debugs if debug else out_img
|
return debugs if debug else out_img
|
||||||
|
|
||||||
|
|
|
@ -20,10 +20,6 @@ class ModelBase(object):
|
||||||
#DONT OVERRIDE
|
#DONT OVERRIDE
|
||||||
def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None,
|
def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None,
|
||||||
batch_size=0,
|
batch_size=0,
|
||||||
multi_gpu = False,
|
|
||||||
choose_worst_gpu = False,
|
|
||||||
force_best_gpu_idx = -1,
|
|
||||||
force_gpu_idxs = None,
|
|
||||||
write_preview_history = False,
|
write_preview_history = False,
|
||||||
debug = False, **in_options
|
debug = False, **in_options
|
||||||
):
|
):
|
||||||
|
@ -70,37 +66,23 @@ class ModelBase(object):
|
||||||
for filename in Path_utils.get_image_paths(self.preview_history_path):
|
for filename in Path_utils.get_image_paths(self.preview_history_path):
|
||||||
Path(filename).unlink()
|
Path(filename).unlink()
|
||||||
|
|
||||||
self.multi_gpu = multi_gpu
|
|
||||||
|
self.gpu_config = gpufmkmgr.GPUConfig(allow_growth=False, **in_options)
|
||||||
gpu_idx = force_best_gpu_idx if (force_best_gpu_idx >= 0 and gpufmkmgr.isValidDeviceIdx(force_best_gpu_idx)) else gpufmkmgr.getBestDeviceIdx() if not choose_worst_gpu else gpufmkmgr.getWorstDeviceIdx()
|
self.gpu_total_vram_gb = self.gpu_config.gpu_total_vram_gb
|
||||||
gpu_total_vram_gb = gpufmkmgr.getDeviceVRAMTotalGb (gpu_idx)
|
|
||||||
is_gpu_low_mem = (gpu_total_vram_gb < 4)
|
|
||||||
|
|
||||||
self.gpu_total_vram_gb = gpu_total_vram_gb
|
|
||||||
|
|
||||||
if self.epoch == 0:
|
if self.epoch == 0:
|
||||||
#first run
|
#first run
|
||||||
self.options['created_vram_gb'] = gpu_total_vram_gb
|
self.options['created_vram_gb'] = self.gpu_total_vram_gb
|
||||||
self.created_vram_gb = gpu_total_vram_gb
|
self.created_vram_gb = self.gpu_total_vram_gb
|
||||||
else:
|
else:
|
||||||
#not first run
|
#not first run
|
||||||
if 'created_vram_gb' in self.options.keys():
|
if 'created_vram_gb' in self.options.keys():
|
||||||
self.created_vram_gb = self.options['created_vram_gb']
|
self.created_vram_gb = self.options['created_vram_gb']
|
||||||
else:
|
else:
|
||||||
self.options['created_vram_gb'] = gpu_total_vram_gb
|
self.options['created_vram_gb'] = self.gpu_total_vram_gb
|
||||||
self.created_vram_gb = gpu_total_vram_gb
|
self.created_vram_gb = self.gpu_total_vram_gb
|
||||||
|
|
||||||
if force_gpu_idxs is not None:
|
self.tf = gpufmkmgr.import_tf( self.gpu_config )
|
||||||
self.gpu_idxs = [ int(x) for x in force_gpu_idxs.split(',') ]
|
|
||||||
else:
|
|
||||||
if self.multi_gpu:
|
|
||||||
self.gpu_idxs = gpufmkmgr.getDeviceIdxsEqualModel( gpu_idx )
|
|
||||||
if len(self.gpu_idxs) <= 1:
|
|
||||||
self.multi_gpu = False
|
|
||||||
else:
|
|
||||||
self.gpu_idxs = [gpu_idx]
|
|
||||||
|
|
||||||
self.tf = gpufmkmgr.import_tf(self.gpu_idxs,allow_growth=False)
|
|
||||||
self.tf_sess = gpufmkmgr.get_tf_session()
|
self.tf_sess = gpufmkmgr.get_tf_session()
|
||||||
self.keras = gpufmkmgr.import_keras()
|
self.keras = gpufmkmgr.import_keras()
|
||||||
self.keras_contrib = gpufmkmgr.import_keras_contrib()
|
self.keras_contrib = gpufmkmgr.import_keras_contrib()
|
||||||
|
@ -131,12 +113,12 @@ class ModelBase(object):
|
||||||
print ("==")
|
print ("==")
|
||||||
print ("== Options:")
|
print ("== Options:")
|
||||||
print ("== |== batch_size : %s " % (self.batch_size) )
|
print ("== |== batch_size : %s " % (self.batch_size) )
|
||||||
print ("== |== multi_gpu : %s " % (self.multi_gpu) )
|
print ("== |== multi_gpu : %s " % (self.gpu_config.multi_gpu) )
|
||||||
for key in self.options.keys():
|
for key in self.options.keys():
|
||||||
print ("== |== %s : %s" % (key, self.options[key]) )
|
print ("== |== %s : %s" % (key, self.options[key]) )
|
||||||
|
|
||||||
print ("== Running on:")
|
print ("== Running on:")
|
||||||
for idx in self.gpu_idxs:
|
for idx in self.gpu_config.gpu_idxs:
|
||||||
print ("== |== [%d : %s]" % (idx, gpufmkmgr.getDeviceName(idx)) )
|
print ("== |== [%d : %s]" % (idx, gpufmkmgr.getDeviceName(idx)) )
|
||||||
|
|
||||||
if self.gpu_total_vram_gb == 2:
|
if self.gpu_total_vram_gb == 2:
|
||||||
|
@ -188,18 +170,18 @@ class ModelBase(object):
|
||||||
return ConverterBase(self, **in_options)
|
return ConverterBase(self, **in_options)
|
||||||
|
|
||||||
def to_multi_gpu_model_if_possible (self, models_list):
|
def to_multi_gpu_model_if_possible (self, models_list):
|
||||||
if len(self.gpu_idxs) > 1:
|
if len(self.gpu_config.gpu_idxs) > 1:
|
||||||
#make batch_size to divide on GPU count without remainder
|
#make batch_size to divide on GPU count without remainder
|
||||||
self.batch_size = int( self.batch_size / len(self.gpu_idxs) )
|
self.batch_size = int( self.batch_size / len(self.gpu_config.gpu_idxs) )
|
||||||
if self.batch_size == 0:
|
if self.batch_size == 0:
|
||||||
self.batch_size = 1
|
self.batch_size = 1
|
||||||
self.batch_size *= len(self.gpu_idxs)
|
self.batch_size *= len(self.gpu_config.gpu_idxs)
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for model in models_list:
|
for model in models_list:
|
||||||
for i in range( len(model.output_names) ):
|
for i in range( len(model.output_names) ):
|
||||||
model.output_names = 'output_%d' % (i)
|
model.output_names = 'output_%d' % (i)
|
||||||
result += [ self.keras.utils.multi_gpu_model( model, self.gpu_idxs ) ]
|
result += [ self.keras.utils.multi_gpu_model( model, self.gpu_config.gpu_idxs ) ]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -8,4 +8,3 @@ scikit-image
|
||||||
dlib==19.10.0
|
dlib==19.10.0
|
||||||
tqdm
|
tqdm
|
||||||
git+https://www.github.com/keras-team/keras-contrib.git
|
git+https://www.github.com/keras-team/keras-contrib.git
|
||||||
skimage
|
|
||||||
|
|
|
@ -272,4 +272,107 @@ def reduce_colors (img_bgr, n_colors):
|
||||||
img_rgb_p = img_rgb_pil_p.convert('RGB')
|
img_rgb_p = img_rgb_pil_p.convert('RGB')
|
||||||
img_bgr = cv2.cvtColor( np.array(img_rgb_p, dtype=np.float32) / 255.0, cv2.COLOR_RGB2BGR )
|
img_bgr = cv2.cvtColor( np.array(img_rgb_p, dtype=np.float32) / 255.0, cv2.COLOR_RGB2BGR )
|
||||||
|
|
||||||
return img_bgr
|
return img_bgr
|
||||||
|
|
||||||
|
|
||||||
|
class TFLabConverter():
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self,):
|
||||||
|
import gpufmkmgr
|
||||||
|
|
||||||
|
self.tf_module = gpufmkmgr.import_tf()
|
||||||
|
self.tf_session = gpufmkmgr.get_tf_session()
|
||||||
|
|
||||||
|
self.bgr_input_tensor = self.tf_module.placeholder("float", [None, None, 3])
|
||||||
|
self.lab_input_tensor = self.tf_module.placeholder("float", [None, None, 3])
|
||||||
|
|
||||||
|
self.lab_output_tensor = self.rgb_to_lab(self.tf_module, self.bgr_input_tensor)
|
||||||
|
self.bgr_output_tensor = self.lab_to_rgb(self.tf_module, self.lab_input_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def bgr2lab(self, bgr):
|
||||||
|
return self.tf_session.run(self.lab_output_tensor, feed_dict={self.bgr_input_tensor: bgr})
|
||||||
|
|
||||||
|
def lab2bgr(self, lab):
|
||||||
|
return self.tf_session.run(self.bgr_output_tensor, feed_dict={self.lab_input_tensor: lab})
|
||||||
|
|
||||||
|
def rgb_to_lab(self, tf, rgb_input):
|
||||||
|
with tf.name_scope("rgb_to_lab"):
|
||||||
|
srgb_pixels = tf.reshape(rgb_input, [-1, 3])
|
||||||
|
|
||||||
|
with tf.name_scope("srgb_to_xyz"):
|
||||||
|
linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
|
||||||
|
exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
|
||||||
|
rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
|
||||||
|
rgb_to_xyz = tf.constant([
|
||||||
|
# X Y Z
|
||||||
|
[0.412453, 0.212671, 0.019334], # R
|
||||||
|
[0.357580, 0.715160, 0.119193], # G
|
||||||
|
[0.180423, 0.072169, 0.950227], # B
|
||||||
|
])
|
||||||
|
xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)
|
||||||
|
|
||||||
|
# https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
|
||||||
|
with tf.name_scope("xyz_to_cielab"):
|
||||||
|
# convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn)
|
||||||
|
|
||||||
|
# normalize for D65 white point
|
||||||
|
xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754])
|
||||||
|
|
||||||
|
epsilon = 6/29
|
||||||
|
linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
|
||||||
|
exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
|
||||||
|
fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask
|
||||||
|
|
||||||
|
# convert to lab
|
||||||
|
fxfyfz_to_lab = tf.constant([
|
||||||
|
# l a b
|
||||||
|
[ 0.0, 500.0, 0.0], # fx
|
||||||
|
[116.0, -500.0, 200.0], # fy
|
||||||
|
[ 0.0, 0.0, -200.0], # fz
|
||||||
|
])
|
||||||
|
lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])
|
||||||
|
|
||||||
|
return tf.reshape(lab_pixels, tf.shape(rgb_input))
|
||||||
|
|
||||||
|
def lab_to_rgb(self, tf, lab):
|
||||||
|
with tf.name_scope("lab_to_rgb"):
|
||||||
|
lab_pixels = tf.reshape(lab, [-1, 3])
|
||||||
|
|
||||||
|
# https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
|
||||||
|
with tf.name_scope("cielab_to_xyz"):
|
||||||
|
# convert to fxfyfz
|
||||||
|
lab_to_fxfyfz = tf.constant([
|
||||||
|
# fx fy fz
|
||||||
|
[1/116.0, 1/116.0, 1/116.0], # l
|
||||||
|
[1/500.0, 0.0, 0.0], # a
|
||||||
|
[ 0.0, 0.0, -1/200.0], # b
|
||||||
|
])
|
||||||
|
fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz)
|
||||||
|
|
||||||
|
# convert to xyz
|
||||||
|
epsilon = 6/29
|
||||||
|
linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32)
|
||||||
|
exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32)
|
||||||
|
xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4/29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask
|
||||||
|
|
||||||
|
# denormalize for D65 white point
|
||||||
|
xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754])
|
||||||
|
|
||||||
|
with tf.name_scope("xyz_to_srgb"):
|
||||||
|
xyz_to_rgb = tf.constant([
|
||||||
|
# r g b
|
||||||
|
[ 3.2404542, -0.9692660, 0.0556434], # x
|
||||||
|
[-1.5371385, 1.8760108, -0.2040259], # y
|
||||||
|
[-0.4985314, 0.0415560, 1.0572252], # z
|
||||||
|
])
|
||||||
|
rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb)
|
||||||
|
# avoid a slightly negative number messing up the conversion
|
||||||
|
rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0)
|
||||||
|
linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32)
|
||||||
|
exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32)
|
||||||
|
srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1/2.4) * 1.055) - 0.055) * exponential_mask
|
||||||
|
|
||||||
|
return tf.reshape(srgb_pixels, tf.shape(lab))
|
Loading…
Add table
Add a link
Reference in a new issue