mirror of
https://github.com/iperov/DeepFaceLive
synced 2025-08-21 05:53:25 -07:00
update xlib.avecl
This commit is contained in:
parent
2d401f47f8
commit
6da916cc66
14 changed files with 246 additions and 184 deletions
|
@ -1,27 +1,31 @@
|
|||
import numpy as np
|
||||
|
||||
from ..AAxes import AAxes
|
||||
from ..AShape import AShape
|
||||
from ..backend import Kernel
|
||||
from ..HArgs import HArgs
|
||||
from ..HKernel import HKernel
|
||||
from ..HType import HType
|
||||
from ..info import BroadcastInfo
|
||||
from ..info import BroadcastInfo, ReductionInfo
|
||||
from ..SCacheton import SCacheton
|
||||
from ..Tensor import Tensor
|
||||
|
||||
|
||||
def any_wise(op_text : str,
|
||||
*args,
|
||||
dim_wise_axis : int = None,
|
||||
dtype : np.dtype = None,
|
||||
output_t:Tensor=None) -> Tensor:
|
||||
"""
|
||||
operator for N-wise ops with N inputs
|
||||
elements-wise operator with N inputs
|
||||
|
||||
arguments
|
||||
op_text example: O=(2*I0*I1)+I2
|
||||
|
||||
*args List[ Tensor | number ]
|
||||
|
||||
dim_wise_axis(None)
|
||||
|
||||
dtype
|
||||
|
||||
output_t compute result to this Tensor.
|
||||
|
@ -33,7 +37,7 @@ def any_wise(op_text : str,
|
|||
|
||||
shape_list, dtype_list, krn_args = HArgs.decompose(args)
|
||||
|
||||
op = SCacheton.get(_AnyWiseOp, shape_list, dtype_list, dtype, op_text)
|
||||
op = SCacheton.get(_AnyWiseOp, shape_list, dtype_list, dim_wise_axis, dtype, op_text)
|
||||
|
||||
if output_t is None:
|
||||
output_t = Tensor ( op.o_shape, op.o_dtype, device=device )
|
||||
|
@ -45,59 +49,60 @@ def any_wise(op_text : str,
|
|||
return output_t
|
||||
|
||||
class _AnyWiseOp:
|
||||
def __init__(self, shape_list, dtype_list, o_dtype, op_text : str):
|
||||
def __init__(self, shape_list, dtype_list, dim_wise_axis, o_dtype, op_text : str):
|
||||
if len(shape_list) != len(dtype_list):
|
||||
raise ValueError('len(shape_list) != len(dtype_list)')
|
||||
|
||||
self.o_dtype = o_dtype = o_dtype if o_dtype is not None else HType.get_most_weighted_dtype (dtype_list)
|
||||
self.info = info = BroadcastInfo( [ shape if shape is not None else AShape((1,)) for shape in shape_list ])
|
||||
self.o_shape = o_shape = info.o_shape
|
||||
|
||||
if len(shape_list) == 1:
|
||||
# element-wise.
|
||||
i_shape, i_dtype = shape_list[0], dtype_list[0]
|
||||
self.o_shape = o_shape = i_shape
|
||||
g_shape = o_shape
|
||||
if dim_wise_axis is not None:
|
||||
dim_wise_axis = o_shape.check_axis(dim_wise_axis)
|
||||
|
||||
self.forward_krn = Kernel(global_shape=(o_shape.size,), kernel_text=f"""
|
||||
{HKernel.define_tensor('O', o_shape, o_dtype)}
|
||||
{HKernel.define_tensor('IN', i_shape, i_dtype)}
|
||||
__kernel void impl(__global O_PTR_TYPE* O_PTR_NAME, __global const IN_PTR_TYPE* IN_PTR_NAME)
|
||||
{{
|
||||
size_t gid = get_global_id(0);
|
||||
dim_wise_axis_size = o_shape[dim_wise_axis]
|
||||
if dim_wise_axis_size > 16:
|
||||
raise ValueError(f'dim_wise_axis size > 16: {dim_wise_axis_size}')
|
||||
|
||||
O_TYPE O = O_GLOBAL_LOAD(gid);
|
||||
IN_TYPE I0 = IN_GLOBAL_LOAD(gid);
|
||||
{op_text};
|
||||
O_GLOBAL_STORE(gid, O);
|
||||
}}
|
||||
""")
|
||||
else:
|
||||
# Multi arg.
|
||||
self.info = info = BroadcastInfo( [ shape if shape is not None else AShape((1,)) for shape in shape_list ])
|
||||
g_shape = ReductionInfo( o_shape, AAxes(dim_wise_axis), False ).o_shape
|
||||
|
||||
self.o_shape = o_shape = info.o_shape
|
||||
defs, arg_defs, impls = [], [], []
|
||||
for i, (t_shape, t_dtype) in enumerate(zip(shape_list, dtype_list)):
|
||||
t_name = f'I{i}'
|
||||
if t_shape is not None:
|
||||
defs.append( HKernel.define_tensor(t_name, info.br_shapes[i], t_dtype) )
|
||||
arg_defs.append( f", __global const {t_name}_PTR_TYPE* {t_name}_PTR_NAME" )
|
||||
|
||||
defs, arg_defs, impls = [], [], []
|
||||
for i, (t_shape, t_dtype) in enumerate(zip(shape_list, dtype_list)):
|
||||
t_name = f'I{i}'
|
||||
if t_shape is not None:
|
||||
defs.append( HKernel.define_tensor(t_name, info.br_shapes[i], t_dtype) )
|
||||
arg_defs.append( f", __global const {t_name}_PTR_TYPE* {t_name}_PTR_NAME" )
|
||||
impls.append( f"{t_name}_TYPE {t_name} = {t_name}_GLOBAL_LOAD({t_name}_IDX_MOD({HKernel.axes_seq_enum('O', info.o_shape.ndim)}));")
|
||||
if dim_wise_axis is not None:
|
||||
for i_elem in range(dim_wise_axis_size):
|
||||
impls.append( f"{t_name}_TYPE {t_name}_{i_elem} = {t_name}_GLOBAL_LOAD({t_name}_IDX_MOD({HKernel.axes_seq_enum('G', g_shape.ndim, new_axis=(f'{i_elem}', dim_wise_axis) )}));")
|
||||
else:
|
||||
arg_defs.append( f", {HKernel.define_scalar_func_arg(t_name, t_dtype)}" )
|
||||
impls.append( f"{t_name}_TYPE {t_name} = {t_name}_GLOBAL_LOAD({t_name}_IDX_MOD({HKernel.axes_seq_enum('G', g_shape.ndim)}));")
|
||||
else:
|
||||
arg_defs.append( f", {HKernel.define_scalar_func_arg(t_name, t_dtype)}" )
|
||||
|
||||
defs, arg_defs, impls = '\n'.join(defs), '\n'.join(arg_defs), '\n'.join(impls)
|
||||
defs, arg_defs, impls = '\n'.join(defs), '\n'.join(arg_defs), '\n'.join(impls)
|
||||
|
||||
self.forward_krn = Kernel(global_shape=(o_shape.size,), kernel_text=f"""
|
||||
if dim_wise_axis is not None:
|
||||
o_def = '\n'.join( f"O_TYPE O_{i_elem};" for i_elem in range(dim_wise_axis_size) )
|
||||
o_store = '\n'.join( f"O_GLOBAL_STORE(O_IDX({HKernel.axes_seq_enum('G', g_shape.ndim, new_axis=(f'{i_elem}', dim_wise_axis) )}), O_{i_elem});" for i_elem in range(dim_wise_axis_size) )
|
||||
else:
|
||||
o_def = 'O_TYPE O;'
|
||||
o_store = 'O_GLOBAL_STORE(gid, O);'
|
||||
|
||||
self.forward_krn = Kernel(global_shape=(g_shape.size,), kernel_text=f"""
|
||||
{defs}
|
||||
{HKernel.define_tensor('O', o_shape, o_dtype)}
|
||||
{HKernel.define_tensor_shape('G', g_shape)}
|
||||
__kernel void impl(__global O_PTR_TYPE* O_PTR_NAME{arg_defs})
|
||||
{{
|
||||
size_t gid = get_global_id(0);
|
||||
{HKernel.decompose_idx_to_axes_idxs('gid', 'o', o_shape.ndim)}
|
||||
{HKernel.decompose_idx_to_axes_idxs('gid', 'G', g_shape.ndim)}
|
||||
{impls}
|
||||
O_TYPE O;
|
||||
{o_def}
|
||||
{op_text};
|
||||
O_GLOBAL_STORE(gid, O);
|
||||
{o_store}
|
||||
}}
|
||||
""")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue