transfercolor via Lab converter is now implemented with tensorflow-cpu, which is about 2x faster than skimage.

We cannot use the GPU for the Lab converter in the converter subprocesses: the model process eats almost all VRAM, so even ~300 MB of free VRAM is not enough for the TensorFlow Lab converter.
Removed skimage dependency.
Refactorings.
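
The fix amounts to the converter subprocesses building a TensorFlow session that cannot see the GPU at all. Below is a minimal illustrative sketch of that idea, assuming the TensorFlow 1.x API this commit targets; it is not the project code, and the tiny graph merely stands in for the real Lab-conversion graph.

import numpy as np
import tensorflow as tf

# Hide every GPU from this process so the session allocates no VRAM.
config = tf.ConfigProto(device_count={'GPU': 0})
session = tf.Session(config=config)

# Placeholder for an HxWx3 image; the multiply is a stand-in for the Lab graph.
bgr_input = tf.placeholder(tf.float32, [None, None, 3])
dummy_output = bgr_input * 2.0

result = session.run(dummy_output, feed_dict={bgr_input: np.zeros((2, 2, 3), np.float32)})
print(result.shape)  # (2, 2, 3), computed entirely on the CPU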
iperov 2018-12-01 12:11:54 +04:00
parent 6f0d38d171
commit 5c43f4245e
7 changed files with 229 additions and 67 deletions

View file

@@ -5,6 +5,8 @@ import contextlib
 from utils import std_utils
 from .pynvml import *
 
 dlib_module = None
 
 def import_dlib(device_idx):
     global dlib_module
@@ -22,18 +24,26 @@ keras_module = None
 keras_contrib_module = None
 keras_vggface_module = None
 
+def set_prefer_GPUConfig(gpu_config):
+    global prefer_GPUConfig
+    prefer_GPUConfig = gpu_config
+
 def get_tf_session():
     global tf_session
     return tf_session
 
-#allow_growth=False for keras model
-#allow_growth=True for tf only model
-def import_tf( device_idxs_list, allow_growth ):
+def import_tf( gpu_config = None ):
+    global prefer_GPUConfig
     global tf_module
     global tf_session
 
+    if gpu_config is None:
+        gpu_config = prefer_GPUConfig
+    else:
+        prefer_GPUConfig = gpu_config
+
     if tf_module is not None:
-        raise Exception ('Multiple import of tf is not allowed, reorganize your program.')
+        return tf_module
 
     if 'TF_SUPPRESS_STD' in os.environ.keys() and os.environ['TF_SUPPRESS_STD'] == '1':
         suppressor = std_utils.suppress_stdout_stderr().__enter__()
@@ -48,14 +58,18 @@ def import_tf( device_idxs_list, allow_growth ):
     import tensorflow as tf
     tf_module = tf
 
-    visible_device_list = ''
-    for idx in device_idxs_list: visible_device_list += str(idx) + ','
-    visible_device_list = visible_device_list[:-1]
-
-    config = tf_module.ConfigProto()
-    config.gpu_options.allow_growth = allow_growth
-    config.gpu_options.visible_device_list=visible_device_list
-    config.gpu_options.force_gpu_compatible = True
+    if gpu_config.cpu_only:
+        config = tf_module.ConfigProto( device_count = {'GPU': 0} )
+    else:
+        config = tf_module.ConfigProto()
+        visible_device_list = ''
+        for idx in gpu_config.gpu_idxs: visible_device_list += str(idx) + ','
+        visible_device_list = visible_device_list[:-1]
+        config.gpu_options.visible_device_list=visible_device_list
+        config.gpu_options.force_gpu_compatible = True
+    config.gpu_options.allow_growth = gpu_config.allow_growth
 
     tf_session = tf_module.Session(config=config)
 
     if suppressor is not None:
@@ -71,11 +85,15 @@ def finalize_tf():
     tf_session = None
     tf_module = None
 
+def get_keras():
+    global keras_module
+    return keras_module
+
 def import_keras():
     global keras_module
 
     if keras_module is not None:
-        raise Exception ('Multiple import of keras is not allowed, reorganize your program.')
+        return keras_module
 
     sess = get_tf_session()
     if sess is None:
@@ -242,3 +260,47 @@ def getDeviceName (idx):
     result = nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(idx)).decode()
     nvmlShutdown()
     return result
+
+class GPUConfig():
+    force_best_gpu_idx = -1
+    multi_gpu = False
+    force_gpu_idxs = None
+    choose_worst_gpu = False
+    gpu_idxs = []
+    gpu_total_vram_gb = 0
+    allow_growth = True
+    cpu_only = False
+
+    def __init__ (self, force_best_gpu_idx = -1,
+                        multi_gpu = False,
+                        force_gpu_idxs = None,
+                        choose_worst_gpu = False,
+                        allow_growth = True,
+                        cpu_only = False,
+                        **in_options):
+        if cpu_only:
+            self.cpu_only = cpu_only
+        else:
+            self.force_best_gpu_idx = force_best_gpu_idx
+            self.multi_gpu = multi_gpu
+            self.force_gpu_idxs = force_gpu_idxs
+            self.choose_worst_gpu = choose_worst_gpu
+            self.allow_growth = allow_growth
+
+            gpu_idx = force_best_gpu_idx if (force_best_gpu_idx >= 0 and isValidDeviceIdx(force_best_gpu_idx)) else getBestDeviceIdx() if not choose_worst_gpu else getWorstDeviceIdx()
+
+            if force_gpu_idxs is not None:
+                self.gpu_idxs = [ int(x) for x in force_gpu_idxs.split(',') ]
+            else:
+                if self.multi_gpu:
+                    self.gpu_idxs = getDeviceIdxsEqualModel( gpu_idx )
+                    if len(self.gpu_idxs) <= 1:
+                        self.multi_gpu = False
+                else:
+                    self.gpu_idxs = [gpu_idx]
+
+            self.gpu_total_vram_gb = getDeviceVRAMTotalGb ( self.gpu_idxs[0] )
+
+prefer_GPUConfig = GPUConfig()
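
The new GPUConfig class, together with set_prefer_GPUConfig(), lets a process register a default device configuration once and have every later import_tf() call without an explicit config reuse it. A short usage sketch based only on the calls visible above (not itself part of the commit):

import gpufmkmgr

# A worker that must stay off the GPU pins a CPU-only preference once...
gpufmkmgr.set_prefer_GPUConfig( gpufmkmgr.GPUConfig(cpu_only=True) )

# ...so any later import without an explicit config builds a CPU-only session.
tf = gpufmkmgr.import_tf()
sess = gpufmkmgr.get_tf_session()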

View file

@@ -64,14 +64,14 @@ from utils.SubprocessorBase import SubprocessorBase
 class ConvertSubprocessor(SubprocessorBase):
 
     #override
-    def __init__(self, converter, input_path_image_paths, output_path, alignments, debug):
+    def __init__(self, converter, input_path_image_paths, output_path, alignments, debug = False, **in_options):
         super().__init__('Converter')
         self.converter = converter
         self.input_path_image_paths = input_path_image_paths
         self.output_path = output_path
         self.alignments = alignments
         self.debug = debug
+        self.in_options = in_options
         self.input_data = self.input_path_image_paths
         self.files_processed = 0
         self.faces_processed = 0
@@ -79,13 +79,18 @@ class ConvertSubprocessor(SubprocessorBase):
 
     #override
     def process_info_generator(self):
         r = [0] if self.debug else range(multiprocessing.cpu_count())
         for i in r:
             yield 'CPU%d' % (i), {}, {'device_idx': i,
                                       'device_name': 'CPU%d' % (i),
                                       'converter' : self.converter,
                                       'output_dir' : str(self.output_path),
                                       'alignments' : self.alignments,
-                                      'debug': self.debug }
+                                      'debug': self.debug,
+                                      'in_options': self.in_options
+                                      }
 
     #override
     def get_no_process_started_message(self):
@@ -118,6 +123,12 @@ class ConvertSubprocessor(SubprocessorBase):
         self.output_path = Path(client_dict['output_dir']) if 'output_dir' in client_dict.keys() else None
         self.alignments = client_dict['alignments']
         self.debug = client_dict['debug']
+
+        import gpufmkmgr
+        #model process ate all GPU mem,
+        #so we cannot use GPU for any TF operations in converter processes (for example image_utils.TFLabConverter)
+        gpufmkmgr.set_prefer_GPUConfig ( gpufmkmgr.GPUConfig (cpu_only=True) )
+
         return None
 
     #override
@@ -211,7 +222,6 @@ def main (input_dir, output_dir, aligned_dir, model_dir, model_name, **in_options):
     model_sq = multiprocessing.Queue()
     model_cq = multiprocessing.Queue()
     model_lock = multiprocessing.Lock()
 
     model_p = multiprocessing.Process(target=model_process, args=(model_name, model_dir, in_options, model_sq, model_cq))
     model_p.start()
@@ -242,12 +252,13 @@ def main (input_dir, output_dir, aligned_dir, model_dir, model_name, **in_options):
                 alignments[ source_filename_stem ].append ( np.array(d['source_landmarks']) )
 
     files_processed, faces_processed = ConvertSubprocessor (
                 converter = converter.copy_and_set_predictor( model_process_predictor(model_sq,model_cq,model_lock) ),
                 input_path_image_paths = Path_utils.get_image_paths(input_path),
                 output_path = output_path,
                 alignments = alignments,
-                debug = debug ).process()
+                **in_options ).process()
 
     model_sq.put ( {'op':'close'} )
     model_p.join()

View file

@@ -241,7 +241,9 @@ class ExtractSubprocessor(SubprocessorBase):
         if self.type == 'rects':
             if self.detector is not None:
                 if self.detector == 'mt':
-                    self.tf = gpufmkmgr.import_tf ([self.device_idx], allow_growth=True)
+                    self.gpu_config = gpufmkmgr.GPUConfig ( force_best_gpu_idx=self.device_idx, allow_growth=True)
+                    self.tf = gpufmkmgr.import_tf ( self.gpu_config )
                     self.tf_session = gpufmkmgr.get_tf_session()
                     self.keras = gpufmkmgr.import_keras()
                     self.e = facelib.MTCExtractor(self.keras, self.tf, self.tf_session)
@@ -251,7 +253,8 @@ class ExtractSubprocessor(SubprocessorBase):
             self.e.__enter__()
         elif self.type == 'landmarks':
-            self.tf = gpufmkmgr.import_tf([self.device_idx], allow_growth=True)
+            self.gpu_config = gpufmkmgr.GPUConfig ( force_best_gpu_idx=self.device_idx, allow_growth=True)
+            self.tf = gpufmkmgr.import_tf ( self.gpu_config )
             self.tf_session = gpufmkmgr.get_tf_session()
             self.keras = gpufmkmgr.import_keras()
             self.e = facelib.LandmarksExtractor(self.keras)

View file

@@ -39,6 +39,7 @@ class ConverterMasked(ConverterBase):
         self.blur_mask_modifier = blur_mask_modifier
         self.output_face_scale = np.clip(1.0 + output_face_scale_modifier*0.01, 0.5, 1.0)
         self.transfercolor = transfercolor
+        self.TFLabConverter = None
         self.final_image_color_degrade_power = np.clip (final_image_color_degrade_power, 0, 100)
         self.alpha = alpha
@@ -199,14 +200,14 @@ class ConverterMasked(ConverterBase):
             new_out = cv2.warpAffine( new_out_face_bgr, face_mat, img_size, img_bgr.copy(), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4, cv2.BORDER_TRANSPARENT )
             out_img = np.clip( img_bgr*(1-img_mask_blurry_aaa) + (new_out*img_mask_blurry_aaa) , 0, 1.0 )
 
-        if self.transfercolor:   #making transfer color from original DST image to fake
-            from skimage import io, color
-            lab_clr = color.rgb2lab(img_bgr)       #original DST, converting RGB to LAB color space
-            lab_bw = color.rgb2lab(out_img)        #fake, converting RGB to LAB color space
-            tmp_channel, a_channel, b_channel = cv2.split(lab_clr)     #taking color channel A and B from original dst image
-            l_channel, tmp2_channel, tmp3_channel = cv2.split(lab_bw)  #taking lightness channel L from merged fake
-            img_LAB = cv2.merge((l_channel, a_channel, b_channel))     #merging light and color
-            out_img = color.lab2rgb(img_LAB)       #converting LAB to RGB
+        if self.transfercolor:
+            if self.TFLabConverter is None:
+                self.TFLabConverter = image_utils.TFLabConverter()
+
+            img_lab_l, img_lab_a, img_lab_b = np.split ( self.TFLabConverter.bgr2lab (img_bgr), 3, axis=-1 )
+            out_img_lab_l, out_img_lab_a, out_img_lab_b = np.split ( self.TFLabConverter.bgr2lab (out_img), 3, axis=-1 )
+
+            out_img = self.TFLabConverter.lab2bgr ( np.concatenate([out_img_lab_l, img_lab_a, img_lab_b], axis=-1) )
 
         if self.final_image_color_degrade_power != 0:
             if debug:
@@ -235,3 +236,4 @@ class ConverterMasked(ConverterBase):
 
         return debugs if debug else out_img

View file

@@ -20,10 +20,6 @@ class ModelBase(object):
 
     #DONT OVERRIDE
     def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None,
                        batch_size=0,
-                       multi_gpu = False,
-                       choose_worst_gpu = False,
-                       force_best_gpu_idx = -1,
-                       force_gpu_idxs = None,
                        write_preview_history = False,
                        debug = False, **in_options
                        ):
@@ -70,37 +66,23 @@ class ModelBase(object):
             for filename in Path_utils.get_image_paths(self.preview_history_path):
                 Path(filename).unlink()
 
-        self.multi_gpu = multi_gpu
-
-        gpu_idx = force_best_gpu_idx if (force_best_gpu_idx >= 0 and gpufmkmgr.isValidDeviceIdx(force_best_gpu_idx)) else gpufmkmgr.getBestDeviceIdx() if not choose_worst_gpu else gpufmkmgr.getWorstDeviceIdx()
-        gpu_total_vram_gb = gpufmkmgr.getDeviceVRAMTotalGb (gpu_idx)
-        is_gpu_low_mem = (gpu_total_vram_gb < 4)
-
-        self.gpu_total_vram_gb = gpu_total_vram_gb
+        self.gpu_config = gpufmkmgr.GPUConfig(allow_growth=False, **in_options)
+        self.gpu_total_vram_gb = self.gpu_config.gpu_total_vram_gb
 
         if self.epoch == 0:
             #first run
-            self.options['created_vram_gb'] = gpu_total_vram_gb
-            self.created_vram_gb = gpu_total_vram_gb
+            self.options['created_vram_gb'] = self.gpu_total_vram_gb
+            self.created_vram_gb = self.gpu_total_vram_gb
         else:
             #not first run
             if 'created_vram_gb' in self.options.keys():
                 self.created_vram_gb = self.options['created_vram_gb']
             else:
-                self.options['created_vram_gb'] = gpu_total_vram_gb
-                self.created_vram_gb = gpu_total_vram_gb
+                self.options['created_vram_gb'] = self.gpu_total_vram_gb
+                self.created_vram_gb = self.gpu_total_vram_gb
 
-        if force_gpu_idxs is not None:
-            self.gpu_idxs = [ int(x) for x in force_gpu_idxs.split(',') ]
-        else:
-            if self.multi_gpu:
-                self.gpu_idxs = gpufmkmgr.getDeviceIdxsEqualModel( gpu_idx )
-                if len(self.gpu_idxs) <= 1:
-                    self.multi_gpu = False
-            else:
-                self.gpu_idxs = [gpu_idx]
-
-        self.tf = gpufmkmgr.import_tf(self.gpu_idxs,allow_growth=False)
+        self.tf = gpufmkmgr.import_tf( self.gpu_config )
         self.tf_sess = gpufmkmgr.get_tf_session()
         self.keras = gpufmkmgr.import_keras()
         self.keras_contrib = gpufmkmgr.import_keras_contrib()
@@ -131,12 +113,12 @@ class ModelBase(object):
         print ("==")
         print ("== Options:")
         print ("== |== batch_size : %s " % (self.batch_size) )
-        print ("== |== multi_gpu : %s " % (self.multi_gpu) )
+        print ("== |== multi_gpu : %s " % (self.gpu_config.multi_gpu) )
         for key in self.options.keys():
             print ("== |== %s : %s" % (key, self.options[key]) )
         print ("== Running on:")
-        for idx in self.gpu_idxs:
+        for idx in self.gpu_config.gpu_idxs:
             print ("== |== [%d : %s]" % (idx, gpufmkmgr.getDeviceName(idx)) )
 
         if self.gpu_total_vram_gb == 2:
@@ -188,18 +170,18 @@ class ModelBase(object):
         return ConverterBase(self, **in_options)
 
     def to_multi_gpu_model_if_possible (self, models_list):
-        if len(self.gpu_idxs) > 1:
+        if len(self.gpu_config.gpu_idxs) > 1:
             #make batch_size to divide on GPU count without remainder
-            self.batch_size = int( self.batch_size / len(self.gpu_idxs) )
+            self.batch_size = int( self.batch_size / len(self.gpu_config.gpu_idxs) )
             if self.batch_size == 0:
                 self.batch_size = 1
-            self.batch_size *= len(self.gpu_idxs)
+            self.batch_size *= len(self.gpu_config.gpu_idxs)
 
             result = []
             for model in models_list:
                 for i in range( len(model.output_names) ):
                     model.output_names = 'output_%d' % (i)
-                result += [ self.keras.utils.multi_gpu_model( model, self.gpu_idxs ) ]
+                result += [ self.keras.utils.multi_gpu_model( model, self.gpu_config.gpu_idxs ) ]
 
             return result
         else:

View file

@@ -8,4 +8,3 @@ scikit-image
 dlib==19.10.0
 tqdm
 git+https://www.github.com/keras-team/keras-contrib.git
-skimage

View file

@@ -273,3 +273,106 @@ def reduce_colors (img_bgr, n_colors):
     img_bgr = cv2.cvtColor( np.array(img_rgb_p, dtype=np.float32) / 255.0, cv2.COLOR_RGB2BGR )
     return img_bgr
+
+class TFLabConverter():
+    def __init__(self,):
+        import gpufmkmgr
+
+        self.tf_module = gpufmkmgr.import_tf()
+        self.tf_session = gpufmkmgr.get_tf_session()
+
+        self.bgr_input_tensor = self.tf_module.placeholder("float", [None, None, 3])
+        self.lab_input_tensor = self.tf_module.placeholder("float", [None, None, 3])
+
+        self.lab_output_tensor = self.rgb_to_lab(self.tf_module, self.bgr_input_tensor)
+        self.bgr_output_tensor = self.lab_to_rgb(self.tf_module, self.lab_input_tensor)
+
+    def bgr2lab(self, bgr):
+        return self.tf_session.run(self.lab_output_tensor, feed_dict={self.bgr_input_tensor: bgr})
+
+    def lab2bgr(self, lab):
+        return self.tf_session.run(self.bgr_output_tensor, feed_dict={self.lab_input_tensor: lab})
+
+    def rgb_to_lab(self, tf, rgb_input):
+        with tf.name_scope("rgb_to_lab"):
+            srgb_pixels = tf.reshape(rgb_input, [-1, 3])
+            with tf.name_scope("srgb_to_xyz"):
+                linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
+                exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
+                rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
+                rgb_to_xyz = tf.constant([
+                    #    X        Y          Z
+                    [0.412453, 0.212671, 0.019334], # R
+                    [0.357580, 0.715160, 0.119193], # G
+                    [0.180423, 0.072169, 0.950227], # B
+                ])
+                xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)
+
+            # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
+            with tf.name_scope("xyz_to_cielab"):
+                # convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn)
+                # normalize for D65 white point
+                xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754])
+
+                epsilon = 6/29
+                linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
+                exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
+                fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask
+
+                # convert to lab
+                fxfyfz_to_lab = tf.constant([
+                    #   l       a        b
+                    [  0.0,  500.0,    0.0], # fx
+                    [116.0, -500.0,  200.0], # fy
+                    [  0.0,    0.0, -200.0], # fz
+                ])
+                lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])
+
+            return tf.reshape(lab_pixels, tf.shape(rgb_input))
+
+    def lab_to_rgb(self, tf, lab):
+        with tf.name_scope("lab_to_rgb"):
+            lab_pixels = tf.reshape(lab, [-1, 3])
+
+            # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
+            with tf.name_scope("cielab_to_xyz"):
+                # convert to fxfyfz
+                lab_to_fxfyfz = tf.constant([
+                    #   fx       fy        fz
+                    [1/116.0, 1/116.0,  1/116.0], # l
+                    [1/500.0,     0.0,      0.0], # a
+                    [    0.0,     0.0, -1/200.0], # b
+                ])
+                fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz)
+
+                # convert to xyz
+                epsilon = 6/29
+                linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32)
+                exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32)
+                xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4/29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask
+
+                # denormalize for D65 white point
+                xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754])
+
+            with tf.name_scope("xyz_to_srgb"):
+                xyz_to_rgb = tf.constant([
+                    #      r           g          b
+                    [ 3.2404542, -0.9692660,  0.0556434], # x
+                    [-1.5371385,  1.8760108, -0.2040259], # y
+                    [-0.4985314,  0.0415560,  1.0572252], # z
+                ])
+                rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb)
+                # avoid a slightly negative number messing up the conversion
+                rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0)
+                linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32)
+                exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32)
+                srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1/2.4) * 1.055) - 0.055) * exponential_mask
+
+            return tf.reshape(srgb_pixels, tf.shape(lab))
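
To make the transfercolor path concrete: ConverterMasked keeps the fake's lightness (L) and borrows the destination frame's color channels (a, b), as shown in the earlier hunk. A usage sketch of TFLabConverter along those lines, assuming float32 BGR images in [0, 1] as the converter uses (the random arrays below are placeholders, not project data):

import numpy as np
import image_utils

converter = image_utils.TFLabConverter()

dst_bgr = np.random.rand(64, 64, 3).astype(np.float32)   # original destination frame
fake_bgr = np.random.rand(64, 64, 3).astype(np.float32)  # merged fake frame

dst_l, dst_a, dst_b = np.split(converter.bgr2lab(dst_bgr), 3, axis=-1)
fake_l, fake_a, fake_b = np.split(converter.bgr2lab(fake_bgr), 3, axis=-1)

# Keep the fake's lightness, take the destination's color channels, go back to BGR.
recolored = converter.lab2bgr(np.concatenate([fake_l, dst_a, dst_b], axis=-1))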