mirror of
https://github.com/iperov/DeepFaceLive
synced 2025-07-13 00:23:52 -07:00
update xlib.avecl
This commit is contained in:
parent
63adc2995e
commit
2d401f47f8
7 changed files with 337 additions and 27 deletions
|
@ -1,6 +1,9 @@
|
||||||
from collections import Iterable
|
from collections import Iterable
|
||||||
|
from typing import Tuple, List
|
||||||
|
|
||||||
from .AAxes import AAxes
|
from .AAxes import AAxes
|
||||||
|
|
||||||
|
|
||||||
class AShape(Iterable):
|
class AShape(Iterable):
|
||||||
__slots__ = ['shape','size','ndim']
|
__slots__ = ['shape','size','ndim']
|
||||||
|
|
||||||
|
@ -13,6 +16,8 @@ class AShape(Iterable):
|
||||||
shape AShape
|
shape AShape
|
||||||
Iterable
|
Iterable
|
||||||
|
|
||||||
|
AShape cannot be scalar shape, thus minimal AShape is (1,)
|
||||||
|
|
||||||
can raise ValueError during the construction
|
can raise ValueError during the construction
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -46,6 +51,12 @@ class AShape(Iterable):
|
||||||
else:
|
else:
|
||||||
raise ValueError('Invalid type to create AShape')
|
raise ValueError('Invalid type to create AShape')
|
||||||
|
|
||||||
|
def copy(self) -> 'AShape':
    """
    Returns a new AShape built by passing this instance to the AShape constructor.
    """
    duplicate = AShape(self)
    return duplicate
|
||||||
|
|
||||||
|
def as_list(self) -> List[int]:
    """
    Returns the dims stored in self.shape as a freshly created list.
    """
    return [*self.shape]
|
||||||
|
|
||||||
def axes_arange(self) -> AAxes:
|
def axes_arange(self) -> AAxes:
|
||||||
"""
|
"""
|
||||||
Returns tuple of axes arange.
|
Returns tuple of axes arange.
|
||||||
|
@ -54,6 +65,35 @@ class AShape(Iterable):
|
||||||
"""
|
"""
|
||||||
return AAxes(range(self.ndim))
|
return AAxes(range(self.ndim))
|
||||||
|
|
||||||
|
def replaced_axes(self, axes, dims) -> 'AShape':
    """
    Returns a new AShape in which each axis listed in `axes` gets the matching value from `dims`.

        axes    iterable of int axis indexes, negative values count from the end

        dims    iterable of new dim values, paired with `axes` by position

    raises ValueError if an axis is out of range
    """
    rank = self.ndim
    dims_out = list(self.shape)

    for ax, new_dim in zip(axes, dims):
        # wrap a negative index once, then range-check the wrapped value
        wrapped = ax + rank if ax < 0 else ax
        if not (0 <= wrapped < rank):
            raise ValueError(f'invalid axis value {wrapped}')
        dims_out[wrapped] = new_dim

    return AShape(dims_out)
|
||||||
|
|
||||||
|
|
||||||
|
def split(self, axis) -> Tuple['AShape', 'AShape']:
    """
    Splits this AShape in two at `axis`.

        axis    int, negative values count from the end

    returns (self[:axis], self[axis:]), so the dim at `axis` is the first one of the second part

    raises ValueError if axis is out of range
    """
    rank = self.ndim
    pivot = axis + rank if axis < 0 else axis
    if not (0 <= pivot < rank):
        raise ValueError(f'invalid axis value {pivot}')

    head = self[:pivot]
    tail = self[pivot:]
    return head, tail
|
||||||
|
|
||||||
def transpose_by_axes(self, axes) -> 'AShape':
|
def transpose_by_axes(self, axes) -> 'AShape':
|
||||||
"""
|
"""
|
||||||
Same as AShape[axes]
|
Same as AShape[axes]
|
||||||
|
|
|
@ -142,10 +142,28 @@ class HKernel:
|
||||||
out += [f'#define {name_upper}_TO_FLOATX(x) ((double)x)']
|
out += [f'#define {name_upper}_TO_FLOATX(x) ((double)x)']
|
||||||
return '\n'.join(out)
|
return '\n'.join(out)
|
||||||
|
|
||||||
|
@staticmethod
def define_ndim_idx(ndim):
    """
    Returns two macro definitions that turn `ndim` per-axis indexes (t0..) and
    `ndim` axis sizes (T0..) into one flat index. The _MOD variant takes every
    index modulo the size of its axis first.

    example for ndim=3
        #define NDIM3_IDX(t0,t1,t2,T0,T1,T2) (((size_t)(t0))*T1*T2+((size_t)(t1))*T2+((size_t)(t2)))
        #define NDIM3_IDX_MOD(t0,t1,t2,T0,T1,T2) (((size_t)(t0) % T0)*T1*T2+((size_t)(t1) % T1)*T2+((size_t)(t2) % T2))
    """
    axes = range(ndim)

    # macro parameter list: all indexes first, then all sizes
    params = ','.join([f't{i}' for i in axes] + [f'T{i}' for i in axes])

    def trailing_sizes(i):
        # product of the sizes of all axes that follow axis i
        return ''.join(f'*T{j}' for j in range(i+1, ndim))

    idx_expr = '+'.join(f'((size_t)(t{i}))' + trailing_sizes(i) for i in axes)
    idx_mod_expr = '+'.join(f'((size_t)(t{i}) % T{i})' + trailing_sizes(i) for i in axes)

    return '\n'.join([f'#define NDIM{ndim}_IDX({params}) ({idx_expr})',
                      f'#define NDIM{ndim}_IDX_MOD({params}) ({idx_mod_expr})'])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def define_tensor_shape(name, shape, axes_symbols=None):
|
def define_tensor_shape(name, shape, axes_symbols=None):
|
||||||
"""
|
"""
|
||||||
Returns a definitions for operations with tensor
|
Returns a definitions for operations with tensor shape
|
||||||
|
|
||||||
example for 'O', (7,3),
|
example for 'O', (7,3),
|
||||||
|
|
||||||
|
|
|
@ -44,6 +44,7 @@ class NTest():
|
||||||
binary_erode_circle_test,
|
binary_erode_circle_test,
|
||||||
binary_dilate_circle_test,
|
binary_dilate_circle_test,
|
||||||
binary_morph_test,
|
binary_morph_test,
|
||||||
|
cvt_color_test,
|
||||||
]
|
]
|
||||||
|
|
||||||
for test_func in test_funcs:
|
for test_func in test_funcs:
|
||||||
|
@ -61,6 +62,30 @@ class NTest():
|
||||||
def _all_close(x,y, atol=1, btol=1):
|
def _all_close(x,y, atol=1, btol=1):
|
||||||
return np.allclose( np.ndarray.flatten(x[None,...]), np.ndarray.flatten(y[None,...]), atol, btol )
|
return np.allclose( np.ndarray.flatten(x[None,...]), np.ndarray.flatten(y[None,...]), atol, btol )
|
||||||
|
|
||||||
|
def cvt_color_test():
    """
    Round-trip test of op.cvt_color.

    For random shapes of rank 2..5 with a randomly placed channel axis of size 3,
    converts in_mode -> out_mode -> in_mode for every pair of modes and every float
    dtype, and requires the result to match the input within _all_close(atol=0.1, btol=0.1).

    raises Exception on the first mismatch
    """
    modes = ['RGB','BGR','XYZ','LAB']
    float_dtypes = [np.float16, np.float32, np.float64]

    for _ in range(10):
        for rank in range(2,6):
            for src_mode in modes:
                for dst_mode in modes:
                    for dtype in float_dtypes:
                        shape = list(np.random.randint(1, 8, size=rank) )

                        # pick the channel axis at random and force it to 3 channels
                        ch_axis = np.random.randint(len(shape))
                        shape[ch_axis] = 3

                        dtype_name = str(np.dtype(dtype).name)
                        print(f'cvt_color {shape} {dtype_name} {src_mode}->{dst_mode} ... ', end='')

                        src_n = np.random.uniform(size=shape ).astype(dtype)
                        src_t = Tensor.from_value(src_n)

                        converted_t = op.cvt_color(src_t, in_mode=src_mode, out_mode=dst_mode, ch_axis=ch_axis)
                        restored_t = op.cvt_color(converted_t, in_mode=dst_mode, out_mode=src_mode, ch_axis=ch_axis)

                        if _all_close(src_t.np(), restored_t.np(), atol=0.1, btol=0.1):
                            print('pass')
                        else:
                            raise Exception('data is not equal')
|
||||||
|
|
||||||
def cast_test():
|
def cast_test():
|
||||||
for in_dtype in HType.get_np_scalar_types():
|
for in_dtype in HType.get_np_scalar_types():
|
||||||
for out_dtype in HType.get_np_scalar_types():
|
for out_dtype in HType.get_np_scalar_types():
|
||||||
|
|
|
@ -27,24 +27,24 @@ class SCacheton:
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def set_var(var_name, value):
|
def set_var(key, value):
|
||||||
"""
|
"""
|
||||||
Set data cached by var_name
|
Set data cached by key
|
||||||
All cached data will be freed with cleanup()
|
All cached data will be freed with cleanup()
|
||||||
|
|
||||||
You must not to store Tensor in SCacheton, use per-device cache vars
|
You must not to store Tensor in SCacheton, use per-device cache vars
|
||||||
"""
|
"""
|
||||||
SCacheton.cachevars[var_name] = value
|
SCacheton.cachevars[key] = value
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_var(var_name):
|
def get_var(key):
|
||||||
"""
|
"""
|
||||||
Get data cached by var_name
|
Get data cached by key
|
||||||
All cached data will be freed with cleanup()
|
All cached data will be freed with cleanup()
|
||||||
|
|
||||||
You must not to store Tensor in SCacheton, use per-device cache vars
|
You must not to store Tensor in SCacheton, use per-device cache vars
|
||||||
"""
|
"""
|
||||||
return SCacheton.cachevars.get(var_name, None)
|
return SCacheton.cachevars.get(key, None)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def cleanup():
|
def cleanup():
|
||||||
|
|
|
@ -36,7 +36,7 @@ class Device:
|
||||||
|
|
||||||
self._cached_data = {} # cached data (per device) by key
|
self._cached_data = {} # cached data (per device) by key
|
||||||
self._pooled_buffers = {} # Pool of cached device buffers.
|
self._pooled_buffers = {} # Pool of cached device buffers.
|
||||||
self._compiled_kernels = {} # compiled kernels by key
|
self._cached_kernels = {} # compiled kernels by key
|
||||||
self._ctx_q = None # CL command queue
|
self._ctx_q = None # CL command queue
|
||||||
self._ctx = None # CL context
|
self._ctx = None # CL context
|
||||||
|
|
||||||
|
@ -86,7 +86,7 @@ class Device:
|
||||||
compile or get cached kernel
|
compile or get cached kernel
|
||||||
"""
|
"""
|
||||||
|
|
||||||
compiled_krn, prog = self._compiled_kernels.get(key, (None, None) )
|
compiled_krn, prog = self._cached_kernels.get(key, (None, None) )
|
||||||
|
|
||||||
if compiled_krn is None:
|
if compiled_krn is None:
|
||||||
clr = CL.CLRESULT()
|
clr = CL.CLRESULT()
|
||||||
|
@ -123,7 +123,7 @@ class Device:
|
||||||
raise Exception(f'clCreateKernelsInProgram error: {clr}')
|
raise Exception(f'clCreateKernelsInProgram error: {clr}')
|
||||||
|
|
||||||
compiled_krn = kernels[0]
|
compiled_krn = kernels[0]
|
||||||
self._compiled_kernels[key] = (compiled_krn, prog)
|
self._cached_kernels[key] = (compiled_krn, prog)
|
||||||
|
|
||||||
return compiled_krn
|
return compiled_krn
|
||||||
|
|
||||||
|
@ -176,7 +176,7 @@ class Device:
|
||||||
mem = self._cl_mem_alloc(size)
|
mem = self._cl_mem_alloc(size)
|
||||||
if mem is None:
|
if mem is None:
|
||||||
# MemoryError.
|
# MemoryError.
|
||||||
if not self._free_random_pooled_buffers():
|
if not self._release_random_pooled_buffers():
|
||||||
raise Exception(f'Unable to allocate {size // 1024**2}Mb on {self.get_description()}')
|
raise Exception(f'Unable to allocate {size // 1024**2}Mb on {self.get_description()}')
|
||||||
continue
|
continue
|
||||||
break
|
break
|
||||||
|
@ -202,7 +202,7 @@ class Device:
|
||||||
self._total_memory_pooled += size
|
self._total_memory_pooled += size
|
||||||
self._total_buffers_pooled += 1
|
self._total_buffers_pooled += 1
|
||||||
|
|
||||||
def _free_random_pooled_buffers(self) -> bool:
|
def _release_random_pooled_buffers(self) -> bool:
|
||||||
"""
|
"""
|
||||||
remove random 25% of pooled boofers
|
remove random 25% of pooled boofers
|
||||||
|
|
||||||
|
@ -221,7 +221,8 @@ class Device:
|
||||||
def _keep_target_memory_usage(self):
|
def _keep_target_memory_usage(self):
|
||||||
targ = self._target_memory_usage
|
targ = self._target_memory_usage
|
||||||
if targ != 0 and self.get_total_allocated_memory() >= targ:
|
if targ != 0 and self.get_total_allocated_memory() >= targ:
|
||||||
self._free_random_pooled_buffers()
|
self.cleanup_cached_kernels()
|
||||||
|
self._release_random_pooled_buffers()
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.get_description()
|
return self.get_description()
|
||||||
|
@ -242,6 +243,17 @@ class Device:
|
||||||
self._total_memory_pooled = 0
|
self._total_memory_pooled = 0
|
||||||
self._total_buffers_pooled = 0
|
self._total_buffers_pooled = 0
|
||||||
|
|
||||||
|
def cleanup_cached_kernels(self):
    """
    Releases every cached kernel together with its program and empties the kernel cache.

    The cache is detached before the first release call and every cached entry is
    attempted even if an earlier release fails, so a failure can never leave already
    released handles behind in self._cached_kernels.

    raises Exception describing the first failed clReleaseKernel / clReleaseProgram call
    """
    # Swap in an empty cache first: whatever happens below, no released handle stays cached.
    cached, self._cached_kernels = self._cached_kernels, {}

    first_error = None
    for kernel, prog in cached.values():
        clr = CL.clReleaseKernel(kernel)
        if clr != CL.CLERROR.SUCCESS and first_error is None:
            first_error = f'clReleaseKernel error: {clr}'

        clr = CL.clReleaseProgram(prog)
        if clr != CL.CLERROR.SUCCESS and first_error is None:
            first_error = f'clReleaseProgram error: {clr}'

    # Report the first failure only after all entries have been processed.
    if first_error is not None:
        raise Exception(first_error)
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
"""
|
"""
|
||||||
Frees all resources from this Device.
|
Frees all resources from this Device.
|
||||||
|
@ -253,15 +265,7 @@ class Device:
|
||||||
if self._total_memory_allocated != 0:
|
if self._total_memory_allocated != 0:
|
||||||
raise Exception('Unable to cleanup CLDevice, while not all Buffers are deallocated.')
|
raise Exception('Unable to cleanup CLDevice, while not all Buffers are deallocated.')
|
||||||
|
|
||||||
for kernel, prog in self._compiled_kernels.values():
|
self.cleanup_cached_kernels()
|
||||||
clr = CL.clReleaseKernel(kernel)
|
|
||||||
if clr != CL.CLERROR.SUCCESS:
|
|
||||||
raise Exception(f'clReleaseKernel error: {clr}')
|
|
||||||
|
|
||||||
clr = CL.clReleaseProgram(prog)
|
|
||||||
if clr != CL.CLERROR.SUCCESS:
|
|
||||||
raise Exception(f'clReleaseProgram error: {clr}')
|
|
||||||
self._compiled_kernels = {}
|
|
||||||
|
|
||||||
if self._ctx_q is not None:
|
if self._ctx_q is not None:
|
||||||
clr = CL.clReleaseCommandQueue(self._ctx_q)
|
clr = CL.clReleaseCommandQueue(self._ctx_q)
|
||||||
|
@ -308,7 +312,7 @@ Total memory allocated: {self._total_memory_allocated}
|
||||||
Total buffers allocated: {self._total_buffers_allocated}
|
Total buffers allocated: {self._total_buffers_allocated}
|
||||||
Total memory pooled: {self._total_memory_pooled}
|
Total memory pooled: {self._total_memory_pooled}
|
||||||
Total buffers pooled: {self._total_buffers_pooled}
|
Total buffers pooled: {self._total_buffers_pooled}
|
||||||
N of compiled kernels: {len(self._compiled_kernels)}
|
N of compiled kernels: {len(self._cached_kernels)}
|
||||||
N of cacheddata: {len(self._cached_data)}
|
N of cacheddata: {len(self._cached_data)}
|
||||||
'''
|
'''
|
||||||
print(s)
|
print(s)
|
||||||
|
|
|
@ -4,6 +4,7 @@ from .binary_erode_circle import binary_erode_circle
|
||||||
from .binary_morph import binary_morph
|
from .binary_morph import binary_morph
|
||||||
from .cast import cast
|
from .cast import cast
|
||||||
from .concat import concat
|
from .concat import concat
|
||||||
|
from .cvt_color import cvt_color
|
||||||
from .depthwise_conv2D import depthwise_conv2D
|
from .depthwise_conv2D import depthwise_conv2D
|
||||||
from .gaussian_blur import gaussian_blur
|
from .gaussian_blur import gaussian_blur
|
||||||
from .matmul import matmul, matmulc
|
from .matmul import matmul, matmulc
|
||||||
|
|
222
xlib/avecl/_internal/op/cvt_color.py
Normal file
222
xlib/avecl/_internal/op/cvt_color.py
Normal file
|
@ -0,0 +1,222 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ..AShape import AShape
|
||||||
|
from ..backend import Kernel
|
||||||
|
from ..HKernel import HKernel
|
||||||
|
from ..SCacheton import SCacheton
|
||||||
|
from ..Tensor import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def cvt_color (input_t : Tensor, in_mode : str, out_mode : str, ch_axis=1, dtype=None):
    """
    Converts the color mode of a tensor and returns the result as a new Tensor.

        input_t     Tensor  (...,C,...) float16/32/64

        in_mode     str     'RGB', 'BGR', 'XYZ', 'LAB'

        out_mode    str     'RGB', 'BGR', 'XYZ', 'LAB'

        ch_axis(1)  int     index of the axis that holds the channels
                            default 1 (assuming NCHW input)

        dtype(None)         output dtype float16/32/64
                            None: same as input dtype

    If in_mode equals out_mode no kernel is run: the input is copied and,
    when dtype is given, cast to it.
    """
    op = SCacheton.get(_CvtColor32Op, input_t.shape, input_t.dtype, in_mode, dtype, out_mode, ch_axis)
    device = input_t.get_device()

    if not op.output_same_as_input:
        output_t = Tensor(op.o_shape, op.o_dtype, device=device)
        device.run_kernel(op.forward_krn, output_t.get_buffer(), input_t.get_buffer(), op.krn_S0, op.krn_S1,
                          global_shape=op.global_shape )
        return output_t

    # identical modes: only a copy and an optional dtype cast are needed
    output_t = input_t.copy()
    return output_t if dtype is None else output_t.cast(dtype)
|
||||||
|
|
||||||
|
# Color modes and tensor dtypes accepted by cvt_color.
_allowed_modes = ['RGB', 'BGR', 'XYZ', 'LAB']
_allowed_dtypes = [np.float16, np.float32, np.float64]

class _CvtColor32Op():
    """
    Description of one color conversion: output shape and dtype, kernel launch
    arguments and, when the two modes differ, the kernel itself.

    The generated kernel declares its channel variables as `float`, whatever the
    input and output tensor dtypes are.
    """
    def __init__(self, i_shape : 'AShape', i_dtype, in_mode, o_dtype, out_mode, ch_axis):
        """
            i_shape     AShape  shape of the input tensor

            i_dtype             input dtype, one of _allowed_dtypes

            in_mode     str     one of _allowed_modes

            o_dtype             output dtype, one of _allowed_dtypes
                                None: same as i_dtype

            out_mode    str     one of _allowed_modes

            ch_axis     int     index of the channel axis in i_shape

        raises ValueError for an unknown mode, an unsupported dtype, a channel axis
        of the wrong size or an unsupported pair of modes
        """
        self.o_dtype = o_dtype = o_dtype if o_dtype is not None else i_dtype

        if in_mode not in _allowed_modes:
            raise ValueError(f'in_mode {in_mode} not in allowed modes: {_allowed_modes}')
        if out_mode not in _allowed_modes:
            raise ValueError(f'out_mode {out_mode} not in allowed modes: {_allowed_modes}')
        # ValueError (a subclass of Exception) keeps all argument checks consistent.
        if i_dtype not in _allowed_dtypes:
            raise ValueError(f'input dtype not in {_allowed_dtypes}')
        if o_dtype not in _allowed_dtypes:
            raise ValueError(f'output dtype not in {_allowed_dtypes}')

        in_ch  = 3 if in_mode  in ['RGB', 'BGR', 'XYZ', 'LAB'] else None
        # The output channel count belongs to out_mode, not in_mode.
        out_ch = 3 if out_mode in ['RGB', 'BGR', 'XYZ', 'LAB'] else None
        if i_shape[ch_axis] != in_ch:
            raise ValueError(f'input ch_axis must have size {in_ch} for {in_mode} mode')

        self.o_shape = i_shape.replaced_axes([ch_axis], [out_ch])

        # View the tensor as (S0, channels, S1): S0 is the size of everything before
        # the channel axis, S1 the size of everything after it.
        s0_shape, s1_shape = i_shape.split(ch_axis)
        s1_shape = s1_shape[1:]

        self.krn_S0 = np.int64(s0_shape.size)
        self.krn_S1 = np.int64(s1_shape.size)

        # one work item per (s0, s1) position
        self.global_shape = (s0_shape.size*s1_shape.size,)

        self.output_same_as_input = in_mode == out_mode

        if not self.output_same_as_input:
            # The kernel text depends only on modes and dtypes, not on the shape,
            # so it is stored under a shape-independent key.
            key = (_CvtColor32Op, in_mode, out_mode, i_dtype, o_dtype)
            krn = SCacheton.get_var(key)
            if krn is None:
                # BGR is handled by passing the channel variables to the body
                # builder in reversed order.
                if in_mode in ['RGB','XYZ','LAB']:
                    in_args = ['I0','I1','I2']
                elif in_mode == 'BGR':
                    in_args = ['I2','I1','I0']

                if out_mode in ['RGB','XYZ','LAB']:
                    out_args = ['O0','O1','O2']
                elif out_mode == 'BGR':
                    out_args = ['O2','O1','O0']

                get_body_func = _modes_to_body_func.get( (in_mode, out_mode), None )
                if get_body_func is None:
                    raise ValueError(f'{in_mode} -> {out_mode} is not supported.')

                body = get_body_func( *(in_args+out_args) )

                krn = Kernel(kernel_text=_CvtColor32Op.fused_kernel(in_ch, i_dtype, out_ch, o_dtype, body=body))
                SCacheton.set_var(key, krn)

            self.forward_krn = krn

    @staticmethod
    def get_RGB_to_LAB_body(R,G,B,L,a,b,lab_type='') -> str:
        """
        Kernel statements for RGB -> LAB, going through temporary float variables X, Y, Z.

        R,G,B are the names of the input variables, L,a,b the names of the outputs.
        lab_type is put in front of every output assignment ('' assigns to existing variables).
        """
        return f"""
{_CvtColor32Op.get_RGB_to_XYZ_body(R,G,B,'X','Y','Z', xyz_type='float')}
{_CvtColor32Op.get_XYZ_to_LAB_body('X','Y','Z',L,a,b, lab_type=lab_type)}
"""

    @staticmethod
    def get_LAB_to_RGB_body(L,a,b,R,G,B,rgb_type='') -> str:
        """
        Kernel statements for LAB -> RGB, going through temporary float variables X, Y, Z.

        L,a,b are the names of the input variables, R,G,B the names of the outputs.
        rgb_type is put in front of every output assignment ('' assigns to existing variables).
        """
        return f"""
{_CvtColor32Op.get_LAB_to_XYZ_body(L,a,b,'X','Y','Z', xyz_type='float')}
{_CvtColor32Op.get_XYZ_to_RGB_body('X','Y','Z',R,G,B,rgb_type=rgb_type)}
"""

    @staticmethod
    def get_RGB_to_XYZ_body(R,G,B,X,Y,Z,xyz_type='') -> str:
        """
        Kernel statements for RGB -> XYZ: one 3x3 matrix multiply written with fma().

        xyz_type is put in front of every output assignment; pass a type name to
        declare X, Y, Z in place.
        """
        return f"""
{xyz_type} {X} = fma(0.4124564, {R}, fma(0.3575761, {G}, 0.1804375*{B}));
{xyz_type} {Y} = fma(0.2126729, {R}, fma(0.7151522, {G}, 0.0721750*{B}));
{xyz_type} {Z} = fma(0.0193339, {R}, fma(0.1191920, {G}, 0.9503041*{B}));
"""

    @staticmethod
    def get_XYZ_to_RGB_body(X,Y,Z,R,G,B,rgb_type='') -> str:
        """
        Kernel statements for XYZ -> RGB: one 3x3 matrix multiply written with fma().

        rgb_type is put in front of every output assignment.
        """
        return f"""
{rgb_type} {R} = fma( 3.2404542, {X}, fma(-1.5371385, {Y}, -0.4985314*{Z}));
{rgb_type} {G} = fma(-0.9692660, {X}, fma( 1.8760108, {Y}, 0.0415560*{Z}));
{rgb_type} {B} = fma( 0.0556434, {X}, fma(-0.2040259, {Y}, 1.0572252*{Z}));
"""

    @staticmethod
    def get_RGB_to_BGR_body(R,G,B,b,g,r,bgr_type='') -> str:
        """
        Kernel statements that copy the three inputs to the three outputs slot by slot
        (first output = first input, and so on).

        The channel swap itself comes from the caller, which passes the variables of
        the BGR side in reversed order.
        """
        return f"""
{bgr_type} {b} = {R};
{bgr_type} {g} = {G};
{bgr_type} {r} = {B};
"""

    @staticmethod
    def get_BGR_to_RGB_body(B,G,R,r,g,b,rgb_type='') -> str:
        """
        Kernel statements that copy the three inputs to the three outputs slot by slot
        (first output = first input, and so on).

        The channel swap itself comes from the caller, which passes the variables of
        the BGR side in reversed order.
        """
        return f"""
{rgb_type} {r} = {B};
{rgb_type} {g} = {G};
{rgb_type} {b} = {R};
"""

    @staticmethod
    def get_XYZ_to_LAB_body(X,Y,Z,L,A,B,lab_type='') -> str:
        """
        Kernel statements for XYZ -> LAB.

        The statements overwrite X, Y and Z, so these must name assignable variables.
        The two branches of the transfer function are selected by multiplying each
        branch with the result of its comparison.
        """
        beta3 = '((6.0/29.0)*(6.0/29.0)*(6.0/29.0))'
        # NOTE(review): 0.9556 differs from the commonly quoted D65 Xn of about 0.9505 - confirm it is intended.
        # get_LAB_to_XYZ_body uses the same constants, so the round trip stays consistent.
        xyz_xn = '(0.9556)'
        xyz_zn = '(1.088754)'
        return f"""
{X} /= {xyz_xn};
{Z} /= {xyz_zn};

{X} = ({X} > {beta3})*rootn({X}, 3) + ({X} <= {beta3})*(7.787*{X}+4.0/29.0);
{Y} = ({Y} > {beta3})*rootn({Y}, 3) + ({Y} <= {beta3})*(7.787*{Y}+4.0/29.0);
{Z} = ({Z} > {beta3})*rootn({Z}, 3) + ({Z} <= {beta3})*(7.787*{Z}+4.0/29.0);

{lab_type} {L} = 116.0*{Y}-16.0;
{lab_type} {A} = 500.0*({X}-{Y});
{lab_type} {B} = 200.0*({Y}-{Z});
"""

    @staticmethod
    def get_LAB_to_XYZ_body(L,A,B,X,Y,Z,xyz_type='') -> str:
        """
        Kernel statements for LAB -> XYZ, the inverse of get_XYZ_to_LAB_body.

        xyz_type is put in front of the first assignment of X, Y, Z; pass a type name
        to declare them in place.
        """
        beta = '(6.0/29.0)'
        beta2 = '((6.0/29.0)*(6.0/29.0))'
        # NOTE(review): same white point constants as get_XYZ_to_LAB_body - keep both in sync.
        xyz_xn = '(0.9556)'
        xyz_zn = '(1.088754)'
        return f"""
{xyz_type} {Y} = ({L} + 16.0) / 116.0;
{xyz_type} {X} = {Y} + {A} / 500.0;
{xyz_type} {Z} = {Y} - {B} / 200.0;

{Y} = ({Y} > {beta})*({Y}*{Y}*{Y}) + ({Y} <= {beta})*({Y}-16.0/116.0)*3*{beta2};
{X} = ({X} > {beta})*({X}*{X}*{X}*{xyz_xn}) + ({X} <= {beta})*({X}-16.0/116.0)*3*{beta2}*{xyz_xn};
{Z} = ({Z} > {beta})*({Z}*{Z}*{Z}*{xyz_zn}) + ({Z} <= {beta})*({Z}-16.0/116.0)*3*{beta2}*{xyz_zn};
"""

    @staticmethod
    def fused_kernel(i_ch : int, i_dtype, o_ch : int, o_dtype, body : str) -> str:
        """
        Returns the kernel source of one conversion.

            i_ch, o_ch          number of input / output channels

            i_dtype, o_dtype    dtypes of the input / output tensors

            body                statements that compute O0..O{o_ch-1} from I0..I{i_ch-1}

        The kernel loads the input channels of one (s0, s1) position into float
        variables I*, runs `body` and stores the float variables O*.
        """
        # a backslash is not allowed inside an f-string expression, so keep '\n' in a variable
        line_sep = '\n'
        return f"""
{HKernel.define_ndim_idx(o_ch)}
{HKernel.define_ndim_idx(i_ch)}
{HKernel.define_tensor_type('O', o_dtype)}
{HKernel.define_tensor_type('I', i_dtype)}

__kernel void impl(__global O_PTR_TYPE* O_PTR_NAME, __global const I_PTR_TYPE* I_PTR_NAME, long S0, long S1)
{{
    size_t gid = get_global_id(0);
    {HKernel.decompose_idx_to_axes_idxs('gid', 'S', 2)}

    {line_sep.join([f'size_t i_idx{i} = NDIM{i_ch}_IDX(s0, {i}, s1, S0, {i_ch}, S1);' for i in range(i_ch)])}
    {line_sep.join([f'size_t o_idx{o} = NDIM{o_ch}_IDX(s0, {o}, s1, S0, {o_ch}, S1);' for o in range(o_ch)])}

    {line_sep.join([f'float I{i} = I_GLOBAL_LOAD(i_idx{i});' for i in range(i_ch)])}
    {line_sep.join([f'float O{o};' for o in range(o_ch)])}

    {body}

    {line_sep.join([f'O_GLOBAL_STORE(o_idx{o}, O{o});' for o in range(o_ch)])}
}}
"""
|
||||||
|
|
||||||
|
# (in_mode, out_mode) -> builder of the kernel body statements.
# BGR pairs reuse the RGB builders: _CvtColor32Op.__init__ passes the channel
# variables of a BGR side in reversed order.
_modes_to_body_func = {
    # channel order only
    ('RGB','BGR') : _CvtColor32Op.get_RGB_to_BGR_body,
    ('BGR','RGB') : _CvtColor32Op.get_BGR_to_RGB_body,

    # from RGB / BGR
    ('RGB','XYZ') : _CvtColor32Op.get_RGB_to_XYZ_body,
    ('RGB','LAB') : _CvtColor32Op.get_RGB_to_LAB_body,
    ('BGR','XYZ') : _CvtColor32Op.get_RGB_to_XYZ_body,
    ('BGR','LAB') : _CvtColor32Op.get_RGB_to_LAB_body,

    # to RGB / BGR
    ('XYZ','RGB') : _CvtColor32Op.get_XYZ_to_RGB_body,
    ('LAB','RGB') : _CvtColor32Op.get_LAB_to_RGB_body,
    ('XYZ','BGR') : _CvtColor32Op.get_XYZ_to_RGB_body,
    ('LAB','BGR') : _CvtColor32Op.get_LAB_to_RGB_body,

    # between XYZ and LAB
    ('XYZ','LAB') : _CvtColor32Op.get_XYZ_to_LAB_body,
    ('LAB','XYZ') : _CvtColor32Op.get_LAB_to_XYZ_body,
}
|
Loading…
Add table
Add a link
Reference in a new issue