From 5c43f4245ebe08ca44884199ade58a47e2c6e263 Mon Sep 17 00:00:00 2001 From: iperov Date: Sat, 1 Dec 2018 12:11:54 +0400 Subject: [PATCH] transfercolor via lab converter now implemented by tensorflow-cpu, which is x2 faster than skimage. We cannot use GPU for lab converter in converter multiprocesses, because almost all VRAM ate by model process, so even 300Mb free VRAM not enough for tensorflow lab converter. Removed skimage dependency. Refactorings. --- gpufmkmgr/gpufmkmgr.py | 90 +++++++++++++++++++++---- mainscripts/Converter.py | 25 +++++-- mainscripts/Extractor.py | 7 +- models/ConverterMasked.py | 22 ++++--- models/ModelBase.py | 46 ++++--------- requirements-gpu-cuda9-cudnn7.txt | 1 - utils/image_utils.py | 105 +++++++++++++++++++++++++++++- 7 files changed, 229 insertions(+), 67 deletions(-) diff --git a/gpufmkmgr/gpufmkmgr.py b/gpufmkmgr/gpufmkmgr.py index 7fe6ead..a0dbb87 100644 --- a/gpufmkmgr/gpufmkmgr.py +++ b/gpufmkmgr/gpufmkmgr.py @@ -5,6 +5,8 @@ import contextlib from utils import std_utils from .pynvml import * + + dlib_module = None def import_dlib(device_idx): global dlib_module @@ -22,18 +24,26 @@ keras_module = None keras_contrib_module = None keras_vggface_module = None +def set_prefer_GPUConfig(gpu_config): + global prefer_GPUConfig + prefer_GPUConfig = gpu_config + def get_tf_session(): global tf_session return tf_session - -#allow_growth=False for keras model -#allow_growth=True for tf only model -def import_tf( device_idxs_list, allow_growth ): + +def import_tf( gpu_config = None ): + global prefer_GPUConfig global tf_module global tf_session + if gpu_config is None: + gpu_config = prefer_GPUConfig + else: + prefer_GPUConfig = gpu_config + if tf_module is not None: - raise Exception ('Multiple import of tf is not allowed, reorganize your program.') + return tf_module if 'TF_SUPPRESS_STD' in os.environ.keys() and os.environ['TF_SUPPRESS_STD'] == '1': suppressor = std_utils.suppress_stdout_stderr().__enter__() @@ -48,14 +58,18 @@ def import_tf( device_idxs_list, allow_growth ): import tensorflow as tf tf_module = tf - visible_device_list = '' - for idx in device_idxs_list: visible_device_list += str(idx) + ',' - visible_device_list = visible_device_list[:-1] + if gpu_config.cpu_only: + config = tf_module.ConfigProto( device_count = {'GPU': 0} ) + else: + config = tf_module.ConfigProto() + visible_device_list = '' + for idx in gpu_config.gpu_idxs: visible_device_list += str(idx) + ',' + visible_device_list = visible_device_list[:-1] + config.gpu_options.visible_device_list=visible_device_list + config.gpu_options.force_gpu_compatible = True - config = tf_module.ConfigProto() - config.gpu_options.allow_growth = allow_growth - config.gpu_options.visible_device_list=visible_device_list - config.gpu_options.force_gpu_compatible = True + config.gpu_options.allow_growth = gpu_config.allow_growth + tf_session = tf_module.Session(config=config) if suppressor is not None: @@ -71,11 +85,15 @@ def finalize_tf(): tf_session = None tf_module = None +def get_keras(): + global keras_module + return keras_module + def import_keras(): global keras_module if keras_module is not None: - raise Exception ('Multiple import of keras is not allowed, reorganize your program.') + return keras_module sess = get_tf_session() if sess is None: @@ -241,4 +259,48 @@ def getDeviceName (idx): if idx < nvmlDeviceGetCount(): result = nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(idx)).decode() nvmlShutdown() - return result \ No newline at end of file + return result + + +class GPUConfig(): + force_best_gpu_idx = -1 + multi_gpu = False + force_gpu_idxs = None + choose_worst_gpu = False + gpu_idxs = [] + gpu_total_vram_gb = 0 + allow_growth = True + cpu_only = False + + def __init__ (self, force_best_gpu_idx = -1, + multi_gpu = False, + force_gpu_idxs = None, + choose_worst_gpu = False, + allow_growth = True, + cpu_only = False, + **in_options): + + if cpu_only: + self.cpu_only = cpu_only + else: + self.force_best_gpu_idx = force_best_gpu_idx + self.multi_gpu = multi_gpu + self.force_gpu_idxs = force_gpu_idxs + self.choose_worst_gpu = choose_worst_gpu + self.allow_growth = allow_growth + + gpu_idx = force_best_gpu_idx if (force_best_gpu_idx >= 0 and isValidDeviceIdx(force_best_gpu_idx)) else getBestDeviceIdx() if not choose_worst_gpu else getWorstDeviceIdx() + + if force_gpu_idxs is not None: + self.gpu_idxs = [ int(x) for x in force_gpu_idxs.split(',') ] + else: + if self.multi_gpu: + self.gpu_idxs = getDeviceIdxsEqualModel( gpu_idx ) + if len(self.gpu_idxs) <= 1: + self.multi_gpu = False + else: + self.gpu_idxs = [gpu_idx] + + self.gpu_total_vram_gb = getDeviceVRAMTotalGb ( self.gpu_idxs[0] ) + +prefer_GPUConfig = GPUConfig() \ No newline at end of file diff --git a/mainscripts/Converter.py b/mainscripts/Converter.py index 1fc27a4..36a7803 100644 --- a/mainscripts/Converter.py +++ b/mainscripts/Converter.py @@ -64,14 +64,14 @@ from utils.SubprocessorBase import SubprocessorBase class ConvertSubprocessor(SubprocessorBase): #override - def __init__(self, converter, input_path_image_paths, output_path, alignments, debug): + def __init__(self, converter, input_path_image_paths, output_path, alignments, debug = False, **in_options): super().__init__('Converter') self.converter = converter self.input_path_image_paths = input_path_image_paths self.output_path = output_path self.alignments = alignments self.debug = debug - + self.in_options = in_options self.input_data = self.input_path_image_paths self.files_processed = 0 self.faces_processed = 0 @@ -79,13 +79,18 @@ class ConvertSubprocessor(SubprocessorBase): #override def process_info_generator(self): r = [0] if self.debug else range(multiprocessing.cpu_count()) + + + for i in r: yield 'CPU%d' % (i), {}, {'device_idx': i, 'device_name': 'CPU%d' % (i), 'converter' : self.converter, 'output_dir' : str(self.output_path), 'alignments' : self.alignments, - 'debug': self.debug } + 'debug': self.debug, + 'in_options': self.in_options + } #override def get_no_process_started_message(self): @@ -118,6 +123,12 @@ class ConvertSubprocessor(SubprocessorBase): self.output_path = Path(client_dict['output_dir']) if 'output_dir' in client_dict.keys() else None self.alignments = client_dict['alignments'] self.debug = client_dict['debug'] + + import gpufmkmgr + #model process ate all GPU mem, + #so we cannot use GPU for any TF operations in converter processes (for example image_utils.TFLabConverter) + gpufmkmgr.set_prefer_GPUConfig ( gpufmkmgr.GPUConfig (cpu_only=True) ) + return None #override @@ -211,7 +222,6 @@ def main (input_dir, output_dir, aligned_dir, model_dir, model_name, **in_option model_sq = multiprocessing.Queue() model_cq = multiprocessing.Queue() model_lock = multiprocessing.Lock() - model_p = multiprocessing.Process(target=model_process, args=(model_name, model_dir, in_options, model_sq, model_cq)) model_p.start() @@ -241,13 +251,14 @@ def main (input_dir, output_dir, aligned_dir, model_dir, model_name, **in_option alignments[ source_filename_stem ] = [] alignments[ source_filename_stem ].append ( np.array(d['source_landmarks']) ) - + + files_processed, faces_processed = ConvertSubprocessor ( converter = converter.copy_and_set_predictor( model_process_predictor(model_sq,model_cq,model_lock) ), input_path_image_paths = Path_utils.get_image_paths(input_path), output_path = output_path, - alignments = alignments, - debug = debug ).process() + alignments = alignments, + **in_options ).process() model_sq.put ( {'op':'close'} ) model_p.join() diff --git a/mainscripts/Extractor.py b/mainscripts/Extractor.py index bf16a3a..e2d0f07 100644 --- a/mainscripts/Extractor.py +++ b/mainscripts/Extractor.py @@ -241,7 +241,9 @@ class ExtractSubprocessor(SubprocessorBase): if self.type == 'rects': if self.detector is not None: if self.detector == 'mt': - self.tf = gpufmkmgr.import_tf ([self.device_idx], allow_growth=True) + + self.gpu_config = gpufmkmgr.GPUConfig ( force_best_gpu_idx=self.device_idx, allow_growth=True) + self.tf = gpufmkmgr.import_tf ( self.gpu_config ) self.tf_session = gpufmkmgr.get_tf_session() self.keras = gpufmkmgr.import_keras() self.e = facelib.MTCExtractor(self.keras, self.tf, self.tf_session) @@ -251,7 +253,8 @@ class ExtractSubprocessor(SubprocessorBase): self.e.__enter__() elif self.type == 'landmarks': - self.tf = gpufmkmgr.import_tf([self.device_idx], allow_growth=True) + self.gpu_config = gpufmkmgr.GPUConfig ( force_best_gpu_idx=self.device_idx, allow_growth=True) + self.tf = gpufmkmgr.import_tf ( self.gpu_config ) self.tf_session = gpufmkmgr.get_tf_session() self.keras = gpufmkmgr.import_keras() self.e = facelib.LandmarksExtractor(self.keras) diff --git a/models/ConverterMasked.py b/models/ConverterMasked.py index b72fc0d..6c36e39 100644 --- a/models/ConverterMasked.py +++ b/models/ConverterMasked.py @@ -38,7 +38,8 @@ class ConverterMasked(ConverterBase): self.erode_mask_modifier = erode_mask_modifier self.blur_mask_modifier = blur_mask_modifier self.output_face_scale = np.clip(1.0 + output_face_scale_modifier*0.01, 0.5, 1.0) - self.transfercolor = transfercolor + self.transfercolor = transfercolor + self.TFLabConverter = None self.final_image_color_degrade_power = np.clip (final_image_color_degrade_power, 0, 100) self.alpha = alpha @@ -199,14 +200,14 @@ class ConverterMasked(ConverterBase): new_out = cv2.warpAffine( new_out_face_bgr, face_mat, img_size, img_bgr.copy(), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4, cv2.BORDER_TRANSPARENT ) out_img = np.clip( img_bgr*(1-img_mask_blurry_aaa) + (new_out*img_mask_blurry_aaa) , 0, 1.0 ) - if self.transfercolor: #making transfer color from original DST image to fake - from skimage import io, color - lab_clr = color.rgb2lab(img_bgr) #original DST, converting RGB to LAB color space - lab_bw = color.rgb2lab(out_img) #fake, converting RGB to LAB color space - tmp_channel, a_channel, b_channel = cv2.split(lab_clr) #taking color channel A and B from original dst image - l_channel, tmp2_channel, tmp3_channel = cv2.split(lab_bw) #taking lightness channel L from merged fake - img_LAB = cv2.merge((l_channel,a_channel, b_channel)) #merging light and color - out_img = color.lab2rgb(img_LAB) #converting LAB to RGB + if self.transfercolor: + if self.TFLabConverter is None: + self.TFLabConverter = image_utils.TFLabConverter() + + img_lab_l, img_lab_a, img_lab_b = np.split ( self.TFLabConverter.bgr2lab (img_bgr), 3, axis=-1 ) + out_img_lab_l, out_img_lab_a, out_img_lab_b = np.split ( self.TFLabConverter.bgr2lab (out_img), 3, axis=-1 ) + + out_img = self.TFLabConverter.lab2bgr ( np.concatenate([out_img_lab_l, img_lab_a, img_lab_b], axis=-1) ) if self.final_image_color_degrade_power != 0: if debug: @@ -234,4 +235,5 @@ class ConverterMasked(ConverterBase): debugs += [out_img.copy()] return debugs if debug else out_img - + + \ No newline at end of file diff --git a/models/ModelBase.py b/models/ModelBase.py index 1ee368c..f660fc8 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -20,10 +20,6 @@ class ModelBase(object): #DONT OVERRIDE def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, batch_size=0, - multi_gpu = False, - choose_worst_gpu = False, - force_best_gpu_idx = -1, - force_gpu_idxs = None, write_preview_history = False, debug = False, **in_options ): @@ -70,37 +66,23 @@ class ModelBase(object): for filename in Path_utils.get_image_paths(self.preview_history_path): Path(filename).unlink() - self.multi_gpu = multi_gpu - - gpu_idx = force_best_gpu_idx if (force_best_gpu_idx >= 0 and gpufmkmgr.isValidDeviceIdx(force_best_gpu_idx)) else gpufmkmgr.getBestDeviceIdx() if not choose_worst_gpu else gpufmkmgr.getWorstDeviceIdx() - gpu_total_vram_gb = gpufmkmgr.getDeviceVRAMTotalGb (gpu_idx) - is_gpu_low_mem = (gpu_total_vram_gb < 4) - - self.gpu_total_vram_gb = gpu_total_vram_gb + + self.gpu_config = gpufmkmgr.GPUConfig(allow_growth=False, **in_options) + self.gpu_total_vram_gb = self.gpu_config.gpu_total_vram_gb if self.epoch == 0: #first run - self.options['created_vram_gb'] = gpu_total_vram_gb - self.created_vram_gb = gpu_total_vram_gb + self.options['created_vram_gb'] = self.gpu_total_vram_gb + self.created_vram_gb = self.gpu_total_vram_gb else: #not first run if 'created_vram_gb' in self.options.keys(): self.created_vram_gb = self.options['created_vram_gb'] else: - self.options['created_vram_gb'] = gpu_total_vram_gb - self.created_vram_gb = gpu_total_vram_gb + self.options['created_vram_gb'] = self.gpu_total_vram_gb + self.created_vram_gb = self.gpu_total_vram_gb - if force_gpu_idxs is not None: - self.gpu_idxs = [ int(x) for x in force_gpu_idxs.split(',') ] - else: - if self.multi_gpu: - self.gpu_idxs = gpufmkmgr.getDeviceIdxsEqualModel( gpu_idx ) - if len(self.gpu_idxs) <= 1: - self.multi_gpu = False - else: - self.gpu_idxs = [gpu_idx] - - self.tf = gpufmkmgr.import_tf(self.gpu_idxs,allow_growth=False) + self.tf = gpufmkmgr.import_tf( self.gpu_config ) self.tf_sess = gpufmkmgr.get_tf_session() self.keras = gpufmkmgr.import_keras() self.keras_contrib = gpufmkmgr.import_keras_contrib() @@ -131,12 +113,12 @@ class ModelBase(object): print ("==") print ("== Options:") print ("== |== batch_size : %s " % (self.batch_size) ) - print ("== |== multi_gpu : %s " % (self.multi_gpu) ) + print ("== |== multi_gpu : %s " % (self.gpu_config.multi_gpu) ) for key in self.options.keys(): print ("== |== %s : %s" % (key, self.options[key]) ) print ("== Running on:") - for idx in self.gpu_idxs: + for idx in self.gpu_config.gpu_idxs: print ("== |== [%d : %s]" % (idx, gpufmkmgr.getDeviceName(idx)) ) if self.gpu_total_vram_gb == 2: @@ -188,18 +170,18 @@ class ModelBase(object): return ConverterBase(self, **in_options) def to_multi_gpu_model_if_possible (self, models_list): - if len(self.gpu_idxs) > 1: + if len(self.gpu_config.gpu_idxs) > 1: #make batch_size to divide on GPU count without remainder - self.batch_size = int( self.batch_size / len(self.gpu_idxs) ) + self.batch_size = int( self.batch_size / len(self.gpu_config.gpu_idxs) ) if self.batch_size == 0: self.batch_size = 1 - self.batch_size *= len(self.gpu_idxs) + self.batch_size *= len(self.gpu_config.gpu_idxs) result = [] for model in models_list: for i in range( len(model.output_names) ): model.output_names = 'output_%d' % (i) - result += [ self.keras.utils.multi_gpu_model( model, self.gpu_idxs ) ] + result += [ self.keras.utils.multi_gpu_model( model, self.gpu_config.gpu_idxs ) ] return result else: diff --git a/requirements-gpu-cuda9-cudnn7.txt b/requirements-gpu-cuda9-cudnn7.txt index ce57662..aa13b2f 100644 --- a/requirements-gpu-cuda9-cudnn7.txt +++ b/requirements-gpu-cuda9-cudnn7.txt @@ -8,4 +8,3 @@ scikit-image dlib==19.10.0 tqdm git+https://www.github.com/keras-team/keras-contrib.git -skimage diff --git a/utils/image_utils.py b/utils/image_utils.py index fd58309..b2d1d05 100644 --- a/utils/image_utils.py +++ b/utils/image_utils.py @@ -272,4 +272,107 @@ def reduce_colors (img_bgr, n_colors): img_rgb_p = img_rgb_pil_p.convert('RGB') img_bgr = cv2.cvtColor( np.array(img_rgb_p, dtype=np.float32) / 255.0, cv2.COLOR_RGB2BGR ) - return img_bgr \ No newline at end of file + return img_bgr + + +class TFLabConverter(): + + + + def __init__(self,): + import gpufmkmgr + + self.tf_module = gpufmkmgr.import_tf() + self.tf_session = gpufmkmgr.get_tf_session() + + self.bgr_input_tensor = self.tf_module.placeholder("float", [None, None, 3]) + self.lab_input_tensor = self.tf_module.placeholder("float", [None, None, 3]) + + self.lab_output_tensor = self.rgb_to_lab(self.tf_module, self.bgr_input_tensor) + self.bgr_output_tensor = self.lab_to_rgb(self.tf_module, self.lab_input_tensor) + + + def bgr2lab(self, bgr): + return self.tf_session.run(self.lab_output_tensor, feed_dict={self.bgr_input_tensor: bgr}) + + def lab2bgr(self, lab): + return self.tf_session.run(self.bgr_output_tensor, feed_dict={self.lab_input_tensor: lab}) + + def rgb_to_lab(self, tf, rgb_input): + with tf.name_scope("rgb_to_lab"): + srgb_pixels = tf.reshape(rgb_input, [-1, 3]) + + with tf.name_scope("srgb_to_xyz"): + linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32) + exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32) + rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask + rgb_to_xyz = tf.constant([ + # X Y Z + [0.412453, 0.212671, 0.019334], # R + [0.357580, 0.715160, 0.119193], # G + [0.180423, 0.072169, 0.950227], # B + ]) + xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz) + + # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions + with tf.name_scope("xyz_to_cielab"): + # convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn) + + # normalize for D65 white point + xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754]) + + epsilon = 6/29 + linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32) + exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32) + fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask + + # convert to lab + fxfyfz_to_lab = tf.constant([ + # l a b + [ 0.0, 500.0, 0.0], # fx + [116.0, -500.0, 200.0], # fy + [ 0.0, 0.0, -200.0], # fz + ]) + lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0]) + + return tf.reshape(lab_pixels, tf.shape(rgb_input)) + + def lab_to_rgb(self, tf, lab): + with tf.name_scope("lab_to_rgb"): + lab_pixels = tf.reshape(lab, [-1, 3]) + + # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions + with tf.name_scope("cielab_to_xyz"): + # convert to fxfyfz + lab_to_fxfyfz = tf.constant([ + # fx fy fz + [1/116.0, 1/116.0, 1/116.0], # l + [1/500.0, 0.0, 0.0], # a + [ 0.0, 0.0, -1/200.0], # b + ]) + fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz) + + # convert to xyz + epsilon = 6/29 + linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32) + exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32) + xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4/29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask + + # denormalize for D65 white point + xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754]) + + with tf.name_scope("xyz_to_srgb"): + xyz_to_rgb = tf.constant([ + # r g b + [ 3.2404542, -0.9692660, 0.0556434], # x + [-1.5371385, 1.8760108, -0.2040259], # y + [-0.4985314, 0.0415560, 1.0572252], # z + ]) + rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb) + # avoid a slightly negative number messing up the conversion + rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0) + linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32) + exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32) + srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1/2.4) * 1.055) - 0.055) * exponential_mask + + return tf.reshape(srgb_pixels, tf.shape(lab)) \ No newline at end of file