transfercolor via LAB converter is now implemented with TensorFlow on CPU, which is about 2x faster than skimage.

We cannot use the GPU for the LAB converter in the converter subprocesses, because almost all VRAM is eaten by the model process, so even 300 MB of free VRAM is not enough for the TensorFlow LAB converter.
Removed skimage dependency.
Refactorings.
iperov 2018-12-01 12:11:54 +04:00
parent 6f0d38d171
commit 5c43f4245e
7 changed files with 229 additions and 67 deletions
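To illustrate the change, here is a minimal sketch of how a converter worker is expected to use the new pieces (based on the calls visible in the diff below; the "from utils import image_utils" path and the image variable are assumptions for illustration, and a TensorFlow 1.x environment is assumed):

import numpy as np
import gpufmkmgr
from utils import image_utils  # assumed path of the module that defines TFLabConverter

# The model process holds almost all VRAM, so converter workers prefer CPU-only TensorFlow.
gpufmkmgr.set_prefer_GPUConfig( gpufmkmgr.GPUConfig(cpu_only=True) )

img_bgr = np.random.rand(128, 128, 3).astype(np.float32)  # placeholder for a frame in [0,1]

lab_converter = image_utils.TFLabConverter()  # builds its TF graph in the CPU-only session
img_lab = lab_converter.bgr2lab(img_bgr)
img_back = lab_converter.lab2bgr(img_lab)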

View file

@@ -5,6 +5,8 @@ import contextlib
from utils import std_utils
from .pynvml import *
dlib_module = None
def import_dlib(device_idx):
global dlib_module
@@ -22,18 +24,26 @@ keras_module = None
keras_contrib_module = None
keras_vggface_module = None
def set_prefer_GPUConfig(gpu_config):
global prefer_GPUConfig
prefer_GPUConfig = gpu_config
def get_tf_session():
global tf_session
return tf_session
#allow_growth=False for keras model
#allow_growth=True for tf only model
def import_tf( device_idxs_list, allow_growth ):
def import_tf( gpu_config = None ):
global prefer_GPUConfig
global tf_module
global tf_session
if gpu_config is None:
gpu_config = prefer_GPUConfig
else:
prefer_GPUConfig = gpu_config
if tf_module is not None:
raise Exception ('Multiple import of tf is not allowed, reorganize your program.')
return tf_module
if 'TF_SUPPRESS_STD' in os.environ.keys() and os.environ['TF_SUPPRESS_STD'] == '1':
suppressor = std_utils.suppress_stdout_stderr().__enter__()
@@ -48,14 +58,18 @@ def import_tf( device_idxs_list, allow_growth ):
import tensorflow as tf
tf_module = tf
visible_device_list = ''
for idx in device_idxs_list: visible_device_list += str(idx) + ','
visible_device_list = visible_device_list[:-1]
if gpu_config.cpu_only:
config = tf_module.ConfigProto( device_count = {'GPU': 0} )
else:
config = tf_module.ConfigProto()
config.gpu_options.allow_growth = allow_growth
visible_device_list = ''
for idx in gpu_config.gpu_idxs: visible_device_list += str(idx) + ','
visible_device_list = visible_device_list[:-1]
config.gpu_options.visible_device_list=visible_device_list
config.gpu_options.force_gpu_compatible = True
config.gpu_options.allow_growth = gpu_config.allow_growth
tf_session = tf_module.Session(config=config)
if suppressor is not None:
@@ -71,11 +85,15 @@ def finalize_tf():
tf_session = None
tf_module = None
def get_keras():
global keras_module
return keras_module
def import_keras():
global keras_module
if keras_module is not None:
raise Exception ('Multiple import of keras is not allowed, reorganize your program.')
return keras_module
sess = get_tf_session()
if sess is None:
@@ -242,3 +260,47 @@ def getDeviceName (idx):
    result = nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(idx)).decode()
    nvmlShutdown()
    return result

class GPUConfig():
    force_best_gpu_idx = -1
    multi_gpu = False
    force_gpu_idxs = None
    choose_worst_gpu = False
    gpu_idxs = []
    gpu_total_vram_gb = 0
    allow_growth = True
    cpu_only = False

    def __init__ (self, force_best_gpu_idx = -1,
                        multi_gpu = False,
                        force_gpu_idxs = None,
                        choose_worst_gpu = False,
                        allow_growth = True,
                        cpu_only = False,
                        **in_options):
        if cpu_only:
            self.cpu_only = cpu_only
        else:
            self.force_best_gpu_idx = force_best_gpu_idx
            self.multi_gpu = multi_gpu
            self.force_gpu_idxs = force_gpu_idxs
            self.choose_worst_gpu = choose_worst_gpu
            self.allow_growth = allow_growth

            gpu_idx = force_best_gpu_idx if (force_best_gpu_idx >= 0 and isValidDeviceIdx(force_best_gpu_idx)) else getBestDeviceIdx() if not choose_worst_gpu else getWorstDeviceIdx()

            if force_gpu_idxs is not None:
                self.gpu_idxs = [ int(x) for x in force_gpu_idxs.split(',') ]
            else:
                if self.multi_gpu:
                    self.gpu_idxs = getDeviceIdxsEqualModel( gpu_idx )
                    if len(self.gpu_idxs) <= 1:
                        self.multi_gpu = False
                else:
                    self.gpu_idxs = [gpu_idx]

            self.gpu_total_vram_gb = getDeviceVRAMTotalGb ( self.gpu_idxs[0] )

prefer_GPUConfig = GPUConfig()
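A hedged usage sketch of the new GPUConfig / import_tf API above (argument values are illustrative; the two scenarios belong to different processes):

import gpufmkmgr

# Training process: pick the best GPU and pre-allocate VRAM (allow_growth=False).
cfg = gpufmkmgr.GPUConfig(allow_growth=False)
tf = gpufmkmgr.import_tf(cfg)            # creates the session from cfg and stores cfg as the preference
sess = gpufmkmgr.get_tf_session()
keras = gpufmkmgr.import_keras()

# Converter worker process: register a CPU-only preference up front, so that a later
# import_tf() call with no argument (e.g. inside TFLabConverter) stays off the GPU.
gpufmkmgr.set_prefer_GPUConfig( gpufmkmgr.GPUConfig(cpu_only=True) )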

View file

@@ -64,14 +64,14 @@ from utils.SubprocessorBase import SubprocessorBase
class ConvertSubprocessor(SubprocessorBase):
#override
def __init__(self, converter, input_path_image_paths, output_path, alignments, debug):
def __init__(self, converter, input_path_image_paths, output_path, alignments, debug = False, **in_options):
super().__init__('Converter')
self.converter = converter
self.input_path_image_paths = input_path_image_paths
self.output_path = output_path
self.alignments = alignments
self.debug = debug
self.in_options = in_options
self.input_data = self.input_path_image_paths
self.files_processed = 0
self.faces_processed = 0
@@ -79,13 +79,18 @@ class ConvertSubprocessor(SubprocessorBase):
#override
def process_info_generator(self):
r = [0] if self.debug else range(multiprocessing.cpu_count())
for i in r:
yield 'CPU%d' % (i), {}, {'device_idx': i,
'device_name': 'CPU%d' % (i),
'converter' : self.converter,
'output_dir' : str(self.output_path),
'alignments' : self.alignments,
'debug': self.debug }
'debug': self.debug,
'in_options': self.in_options
}
#override
def get_no_process_started_message(self):
@@ -118,6 +123,12 @@ class ConvertSubprocessor(SubprocessorBase):
self.output_path = Path(client_dict['output_dir']) if 'output_dir' in client_dict.keys() else None
self.alignments = client_dict['alignments']
self.debug = client_dict['debug']
import gpufmkmgr
#model process ate all GPU mem,
#so we cannot use GPU for any TF operations in converter processes (for example image_utils.TFLabConverter)
gpufmkmgr.set_prefer_GPUConfig ( gpufmkmgr.GPUConfig (cpu_only=True) )
return None
#override
@@ -211,7 +222,6 @@ def main (input_dir, output_dir, aligned_dir, model_dir, model_name, **in_option
model_sq = multiprocessing.Queue()
model_cq = multiprocessing.Queue()
model_lock = multiprocessing.Lock()
model_p = multiprocessing.Process(target=model_process, args=(model_name, model_dir, in_options, model_sq, model_cq))
model_p.start()
@@ -242,12 +252,13 @@ def main (input_dir, output_dir, aligned_dir, model_dir, model_name, **in_option
alignments[ source_filename_stem ].append ( np.array(d['source_landmarks']) )
files_processed, faces_processed = ConvertSubprocessor (
converter = converter.copy_and_set_predictor( model_process_predictor(model_sq,model_cq,model_lock) ),
input_path_image_paths = Path_utils.get_image_paths(input_path),
output_path = output_path,
alignments = alignments,
debug = debug ).process()
**in_options ).process()
model_sq.put ( {'op':'close'} )
model_p.join()

View file

@@ -241,7 +241,9 @@ class ExtractSubprocessor(SubprocessorBase):
if self.type == 'rects':
if self.detector is not None:
if self.detector == 'mt':
self.tf = gpufmkmgr.import_tf ([self.device_idx], allow_growth=True)
self.gpu_config = gpufmkmgr.GPUConfig ( force_best_gpu_idx=self.device_idx, allow_growth=True)
self.tf = gpufmkmgr.import_tf ( self.gpu_config )
self.tf_session = gpufmkmgr.get_tf_session()
self.keras = gpufmkmgr.import_keras()
self.e = facelib.MTCExtractor(self.keras, self.tf, self.tf_session)
@@ -251,7 +253,8 @@ class ExtractSubprocessor(SubprocessorBase):
self.e.__enter__()
elif self.type == 'landmarks':
self.tf = gpufmkmgr.import_tf([self.device_idx], allow_growth=True)
self.gpu_config = gpufmkmgr.GPUConfig ( force_best_gpu_idx=self.device_idx, allow_growth=True)
self.tf = gpufmkmgr.import_tf ( self.gpu_config )
self.tf_session = gpufmkmgr.get_tf_session()
self.keras = gpufmkmgr.import_keras()
self.e = facelib.LandmarksExtractor(self.keras)

View file

@@ -39,6 +39,7 @@ class ConverterMasked(ConverterBase):
self.blur_mask_modifier = blur_mask_modifier
self.output_face_scale = np.clip(1.0 + output_face_scale_modifier*0.01, 0.5, 1.0)
self.transfercolor = transfercolor
self.TFLabConverter = None
self.final_image_color_degrade_power = np.clip (final_image_color_degrade_power, 0, 100)
self.alpha = alpha
@@ -199,14 +200,14 @@ class ConverterMasked(ConverterBase):
new_out = cv2.warpAffine( new_out_face_bgr, face_mat, img_size, img_bgr.copy(), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4, cv2.BORDER_TRANSPARENT )
out_img = np.clip( img_bgr*(1-img_mask_blurry_aaa) + (new_out*img_mask_blurry_aaa) , 0, 1.0 )
if self.transfercolor: #making transfer color from original DST image to fake
from skimage import io, color
lab_clr = color.rgb2lab(img_bgr) #original DST, converting RGB to LAB color space
lab_bw = color.rgb2lab(out_img) #fake, converting RGB to LAB color space
tmp_channel, a_channel, b_channel = cv2.split(lab_clr) #taking color channel A and B from original dst image
l_channel, tmp2_channel, tmp3_channel = cv2.split(lab_bw) #taking lightness channel L from merged fake
img_LAB = cv2.merge((l_channel,a_channel, b_channel)) #merging light and color
out_img = color.lab2rgb(img_LAB) #converting LAB to RGB
if self.transfercolor:
if self.TFLabConverter is None:
self.TFLabConverter = image_utils.TFLabConverter()
img_lab_l, img_lab_a, img_lab_b = np.split ( self.TFLabConverter.bgr2lab (img_bgr), 3, axis=-1 )
out_img_lab_l, out_img_lab_a, out_img_lab_b = np.split ( self.TFLabConverter.bgr2lab (out_img), 3, axis=-1 )
out_img = self.TFLabConverter.lab2bgr ( np.concatenate([out_img_lab_l, img_lab_a, img_lab_b], axis=-1) )
if self.final_image_color_degrade_power != 0:
if debug:
@@ -235,3 +236,4 @@ class ConverterMasked(ConverterBase):
return debugs if debug else out_img
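The transfercolor block above keeps the lightness (L) channel of the blended fake and copies the color (a, b) channels from the original destination frame. A hedged standalone sketch of the same operation (function and variable names are illustrative; the import path of image_utils is assumed):

import numpy as np
from utils import image_utils  # assumed path of the module that defines TFLabConverter

def transfer_color (out_img, img_bgr, lab_converter):
    # Lab of the original destination frame (source of a/b) and of the fake (source of L)
    img_l, img_a, img_b = np.split ( lab_converter.bgr2lab(img_bgr), 3, axis=-1 )
    out_l, out_a, out_b = np.split ( lab_converter.bgr2lab(out_img), 3, axis=-1 )
    # keep the fake's lightness, take the original's color, and go back to BGR
    return lab_converter.lab2bgr ( np.concatenate ([out_l, img_a, img_b], axis=-1) )

# usage: out_img = transfer_color (out_img, img_bgr, image_utils.TFLabConverter())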

View file

@@ -20,10 +20,6 @@ class ModelBase(object):
#DONT OVERRIDE
def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None,
batch_size=0,
multi_gpu = False,
choose_worst_gpu = False,
force_best_gpu_idx = -1,
force_gpu_idxs = None,
write_preview_history = False,
debug = False, **in_options
):
@@ -70,37 +66,23 @@ class ModelBase(object):
for filename in Path_utils.get_image_paths(self.preview_history_path):
Path(filename).unlink()
self.multi_gpu = multi_gpu
gpu_idx = force_best_gpu_idx if (force_best_gpu_idx >= 0 and gpufmkmgr.isValidDeviceIdx(force_best_gpu_idx)) else gpufmkmgr.getBestDeviceIdx() if not choose_worst_gpu else gpufmkmgr.getWorstDeviceIdx()
gpu_total_vram_gb = gpufmkmgr.getDeviceVRAMTotalGb (gpu_idx)
is_gpu_low_mem = (gpu_total_vram_gb < 4)
self.gpu_total_vram_gb = gpu_total_vram_gb
self.gpu_config = gpufmkmgr.GPUConfig(allow_growth=False, **in_options)
self.gpu_total_vram_gb = self.gpu_config.gpu_total_vram_gb
if self.epoch == 0:
#first run
self.options['created_vram_gb'] = gpu_total_vram_gb
self.created_vram_gb = gpu_total_vram_gb
self.options['created_vram_gb'] = self.gpu_total_vram_gb
self.created_vram_gb = self.gpu_total_vram_gb
else:
#not first run
if 'created_vram_gb' in self.options.keys():
self.created_vram_gb = self.options['created_vram_gb']
else:
self.options['created_vram_gb'] = gpu_total_vram_gb
self.created_vram_gb = gpu_total_vram_gb
self.options['created_vram_gb'] = self.gpu_total_vram_gb
self.created_vram_gb = self.gpu_total_vram_gb
if force_gpu_idxs is not None:
self.gpu_idxs = [ int(x) for x in force_gpu_idxs.split(',') ]
else:
if self.multi_gpu:
self.gpu_idxs = gpufmkmgr.getDeviceIdxsEqualModel( gpu_idx )
if len(self.gpu_idxs) <= 1:
self.multi_gpu = False
else:
self.gpu_idxs = [gpu_idx]
self.tf = gpufmkmgr.import_tf(self.gpu_idxs,allow_growth=False)
self.tf = gpufmkmgr.import_tf( self.gpu_config )
self.tf_sess = gpufmkmgr.get_tf_session()
self.keras = gpufmkmgr.import_keras()
self.keras_contrib = gpufmkmgr.import_keras_contrib()
@@ -131,12 +113,12 @@ class ModelBase(object):
print ("==")
print ("== Options:")
print ("== |== batch_size : %s " % (self.batch_size) )
print ("== |== multi_gpu : %s " % (self.multi_gpu) )
print ("== |== multi_gpu : %s " % (self.gpu_config.multi_gpu) )
for key in self.options.keys():
print ("== |== %s : %s" % (key, self.options[key]) )
print ("== Running on:")
for idx in self.gpu_idxs:
for idx in self.gpu_config.gpu_idxs:
print ("== |== [%d : %s]" % (idx, gpufmkmgr.getDeviceName(idx)) )
if self.gpu_total_vram_gb == 2:
@@ -188,18 +170,18 @@ class ModelBase(object):
return ConverterBase(self, **in_options)
def to_multi_gpu_model_if_possible (self, models_list):
if len(self.gpu_idxs) > 1:
if len(self.gpu_config.gpu_idxs) > 1:
#make batch_size to divide on GPU count without remainder
self.batch_size = int( self.batch_size / len(self.gpu_idxs) )
self.batch_size = int( self.batch_size / len(self.gpu_config.gpu_idxs) )
if self.batch_size == 0:
self.batch_size = 1
self.batch_size *= len(self.gpu_idxs)
self.batch_size *= len(self.gpu_config.gpu_idxs)
result = []
for model in models_list:
for i in range( len(model.output_names) ):
model.output_names = 'output_%d' % (i)
result += [ self.keras.utils.multi_gpu_model( model, self.gpu_idxs ) ]
result += [ self.keras.utils.multi_gpu_model( model, self.gpu_config.gpu_idxs ) ]
return result
else:
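A worked instance of the batch-size rounding in to_multi_gpu_model_if_possible above (numbers chosen for illustration):

# 3 GPUs with a requested batch_size of 8:
gpu_idxs = [0, 1, 2]
batch_size = 8
batch_size = int( batch_size / len(gpu_idxs) )  # int(8 / 3) -> 2 samples per GPU
if batch_size == 0:
    batch_size = 1
batch_size *= len(gpu_idxs)                     # 2 * 3 -> 6, divisible by the GPU count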

View file

@@ -8,4 +8,3 @@ scikit-image
dlib==19.10.0
tqdm
git+https://www.github.com/keras-team/keras-contrib.git
skimage

View file

@@ -273,3 +273,106 @@ def reduce_colors (img_bgr, n_colors):
img_bgr = cv2.cvtColor( np.array(img_rgb_p, dtype=np.float32) / 255.0, cv2.COLOR_RGB2BGR )
return img_bgr
class TFLabConverter():
    def __init__(self,):
        import gpufmkmgr
        self.tf_module = gpufmkmgr.import_tf()
        self.tf_session = gpufmkmgr.get_tf_session()

        self.bgr_input_tensor = self.tf_module.placeholder("float", [None, None, 3])
        self.lab_input_tensor = self.tf_module.placeholder("float", [None, None, 3])

        self.lab_output_tensor = self.rgb_to_lab(self.tf_module, self.bgr_input_tensor)
        self.bgr_output_tensor = self.lab_to_rgb(self.tf_module, self.lab_input_tensor)

    def bgr2lab(self, bgr):
        return self.tf_session.run(self.lab_output_tensor, feed_dict={self.bgr_input_tensor: bgr})

    def lab2bgr(self, lab):
        return self.tf_session.run(self.bgr_output_tensor, feed_dict={self.lab_input_tensor: lab})

    def rgb_to_lab(self, tf, rgb_input):
        with tf.name_scope("rgb_to_lab"):
            srgb_pixels = tf.reshape(rgb_input, [-1, 3])
            with tf.name_scope("srgb_to_xyz"):
                linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
                exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
                rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
                rgb_to_xyz = tf.constant([
                    #     X         Y         Z
                    [0.412453, 0.212671, 0.019334], # R
                    [0.357580, 0.715160, 0.119193], # G
                    [0.180423, 0.072169, 0.950227], # B
                ])
                xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)

            # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
            with tf.name_scope("xyz_to_cielab"):
                # convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn)
                # normalize for D65 white point
                xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754])

                epsilon = 6/29
                linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
                exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
                fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask

                # convert to lab
                fxfyfz_to_lab = tf.constant([
                    #   l       a       b
                    [  0.0,  500.0,    0.0], # fx
                    [116.0, -500.0,  200.0], # fy
                    [  0.0,    0.0, -200.0], # fz
                ])
                lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])

            return tf.reshape(lab_pixels, tf.shape(rgb_input))

    def lab_to_rgb(self, tf, lab):
        with tf.name_scope("lab_to_rgb"):
            lab_pixels = tf.reshape(lab, [-1, 3])

            # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
            with tf.name_scope("cielab_to_xyz"):
                # convert to fxfyfz
                lab_to_fxfyfz = tf.constant([
                    #    fx       fy        fz
                    [1/116.0, 1/116.0,  1/116.0], # l
                    [1/500.0,     0.0,      0.0], # a
                    [    0.0,     0.0, -1/200.0], # b
                ])
                fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz)

                # convert to xyz
                epsilon = 6/29
                linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32)
                exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32)
                xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4/29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask

                # denormalize for D65 white point
                xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754])

            with tf.name_scope("xyz_to_srgb"):
                xyz_to_rgb = tf.constant([
                    #      r           g          b
                    [ 3.2404542, -0.9692660,  0.0556434], # x
                    [-1.5371385,  1.8760108, -0.2040259], # y
                    [-0.4985314,  0.0415560,  1.0572252], # z
                ])
                rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb)
                # avoid a slightly negative number messing up the conversion
                rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0)
                linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32)
                exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32)
                srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1/2.4) * 1.055) - 0.055) * exponential_mask

            return tf.reshape(srgb_pixels, tf.shape(lab))
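A hedged round-trip check for TFLabConverter (assumes a TensorFlow 1.x environment; the CPU-only preference mirrors how the converter subprocesses use it):

import numpy as np
import gpufmkmgr

gpufmkmgr.set_prefer_GPUConfig( gpufmkmgr.GPUConfig(cpu_only=True) )

conv = TFLabConverter()                     # class defined above in this file
img = np.random.rand(64, 64, 3).astype(np.float32)

lab = conv.bgr2lab(img)                     # same HxWx3 shape, Lab values
restored = conv.lab2bgr(lab)
print ( np.abs(restored - img).max() )      # expected to be near 0 for inputs in [0,1]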