DeepFaceLive/xlib/avecl/_internal/op/any_wise.py
2021-10-20 18:02:50 +04:00

116 lines
4.8 KiB
Python

import numpy as np
from ..AAxes import AAxes
from ..AShape import AShape
from ..backend import Kernel
from ..HArgs import HArgs
from ..HKernel import HKernel
from ..HType import HType
from ..info import BroadcastInfo, ReductionInfo
from ..SCacheton import SCacheton
from ..Tensor import Tensor
def any_wise(op_text : str,
             *args,
             dim_wise_axis : int = None,
             dtype : np.dtype = None,
             output_t:Tensor=None) -> Tensor:
    """
    Run an element-wise kernel expression over N inputs.

    arguments

        op_text         expression placed verbatim into the kernel body,
                        example: O=(2*I0*I1)+I2
                        Inputs are named I0, I1, ... following the order of *args.

        *args           List[ Tensor | number ]
                        At least one Tensor is required (filter_tensor is
                        called with raise_on_empty=True).

        dim_wise_axis(None)
                        when set, _AnyWiseOp declares I{n}_{k} for every tensor
                        input and O_{k} for the output, one per index k along
                        this axis of the broadcast output shape, in place of
                        the plain I{n} / O variables.
                        _AnyWiseOp rejects an axis larger than 16 elements.

        dtype(None)     output dtype. When None, _AnyWiseOp selects it with
                        HType.get_most_weighted_dtype over the input dtypes.

        output_t(None)  compute result to this Tensor.
                        Its shape may differ, but its total size must match
                        the size of the operator's output shape.

    returns output_t if it was given, otherwise a newly created Tensor.

    raises ValueError if output_t has a different total size.
    The HArgs checks may raise as well (not visible from this function).
    """
    HArgs.check_zero_get_length(args)
    dev = HArgs.check_get_same_device( HArgs.filter_tensor(args, raise_on_empty=True) )
    shapes, dtypes, kernel_args = HArgs.decompose(args)

    # The operator object is obtained through SCacheton, keyed by everything
    # that influences the generated kernel.
    wise_op = SCacheton.get(_AnyWiseOp, shapes, dtypes, dim_wise_axis, dtype, op_text)

    if output_t is not None:
        # Only the element count is validated; the shape itself may differ.
        if output_t.shape.size != wise_op.o_shape.size:
            raise ValueError(f'output_t must have size {wise_op.o_shape.size}')
    else:
        output_t = Tensor ( wise_op.o_shape, wise_op.o_dtype, device=dev )

    # The output buffer is always the first kernel argument, inputs follow.
    dev.run_kernel(wise_op.forward_krn, output_t.get_buffer(), *kernel_args)
    return output_t
class _AnyWiseOp:
    """
    Builds the forward kernel for an element-wise expression over N inputs.

    Attributes set by the constructor:
        o_dtype         output dtype (given, or the most weighted input dtype)
        info            BroadcastInfo of the input shapes
        o_shape         broadcast output shape
        forward_krn     Kernel launched over g_shape.size work items
    """

    def __init__(self, shape_list, dtype_list, dim_wise_axis, o_dtype, op_text : str):
        """
        arguments

            shape_list      per-input shape, None for a non-tensor (scalar) argument

            dtype_list      per-input dtype, same length as shape_list

            dim_wise_axis   None, or the output axis whose elements are all
                            handled by a single work item (size must be <= 16)

            o_dtype         output dtype, None to derive it from dtype_list

            op_text         expression inserted verbatim into the kernel body

        raises ValueError on list length mismatch or if the dim-wise axis
        holds more than 16 elements.
        """
        if len(shape_list) != len(dtype_list):
            raise ValueError('len(shape_list) != len(dtype_list)')

        if o_dtype is None:
            o_dtype = HType.get_most_weighted_dtype (dtype_list)
        self.o_dtype = o_dtype

        # Scalar arguments (shape None) take part in broadcasting as shape (1,).
        broadcast_shapes = []
        for in_shape in shape_list:
            broadcast_shapes.append( AShape((1,)) if in_shape is None else in_shape )
        info = BroadcastInfo(broadcast_shapes)
        self.info = info

        o_shape = info.o_shape
        self.o_shape = o_shape

        # g_shape is the launch grid: the full output shape, or the output
        # shape processed through ReductionInfo over the dim-wise axis.
        g_shape = o_shape
        if dim_wise_axis is not None:
            dim_wise_axis = o_shape.check_axis(dim_wise_axis)
            axis_size = o_shape[dim_wise_axis]
            if axis_size > 16:
                raise ValueError(f'dim_wise_axis size > 16: {axis_size}')
            g_shape = ReductionInfo( o_shape, AAxes(dim_wise_axis), False ).o_shape
        dim_wise = dim_wise_axis is not None

        def elem_axes(k):
            # Grid axes sequence with element index k placed at the dim-wise axis.
            return HKernel.axes_seq_enum('G', g_shape.ndim, new_axis=(f'{k}', dim_wise_axis) )

        tensor_defs = []
        kernel_params = []
        loads = []
        for n, (in_shape, in_dtype) in enumerate(zip(shape_list, dtype_list)):
            name = f'I{n}'

            if in_shape is None:
                # Scalar input: passed as a plain kernel argument, nothing to load.
                kernel_params.append( f", {HKernel.define_scalar_func_arg(name, in_dtype)}" )
                continue

            tensor_defs.append( HKernel.define_tensor(name, info.br_shapes[n], in_dtype) )
            kernel_params.append( f", __global const {name}_PTR_TYPE* {name}_PTR_NAME" )

            if dim_wise:
                # One local variable per element of the dim-wise axis.
                for k in range(axis_size):
                    loads.append( f"{name}_TYPE {name}_{k} = {name}_GLOBAL_LOAD({name}_IDX_MOD({elem_axes(k)}));" )
            else:
                loads.append( f"{name}_TYPE {name} = {name}_GLOBAL_LOAD({name}_IDX_MOD({HKernel.axes_seq_enum('G', g_shape.ndim)}));" )

        tensor_defs = '\n'.join(tensor_defs)
        kernel_params = '\n'.join(kernel_params)
        loads = '\n'.join(loads)

        if dim_wise:
            o_def = '\n'.join( f"O_TYPE O_{k};" for k in range(axis_size) )
            o_store = '\n'.join( f"O_GLOBAL_STORE(O_IDX({elem_axes(k)}), O_{k});" for k in range(axis_size) )
        else:
            o_def = 'O_TYPE O;'
            o_store = 'O_GLOBAL_STORE(gid, O);'

        self.forward_krn = Kernel(global_shape=(g_shape.size,), kernel_text=f"""
{tensor_defs}
{HKernel.define_tensor('O', o_shape, o_dtype)}
{HKernel.define_tensor_shape('G', g_shape)}
__kernel void impl(__global O_PTR_TYPE* O_PTR_NAME{kernel_params})
{{
size_t gid = get_global_id(0);
{HKernel.decompose_idx_to_axes_idxs('gid', 'G', g_shape.ndim)}
{loads}
{o_def}
{op_text};
{o_store}
}}
""")
def add(a_t : Tensor, b_t : Tensor) -> Tensor:
    """Element-wise a_t + b_t, evaluated through any_wise."""
    op_text = 'O=I0+I1'
    return any_wise(op_text, a_t, b_t)
def sub(a_t : Tensor, b_t : Tensor) -> Tensor:
    """Element-wise a_t - b_t, evaluated through any_wise."""
    op_text = 'O=I0-I1'
    return any_wise(op_text, a_t, b_t)
def mul(a_t : Tensor, b_t : Tensor) -> Tensor:
    """Element-wise a_t * b_t, evaluated through any_wise."""
    op_text = 'O=I0*I1'
    return any_wise(op_text, a_t, b_t)
def div(a_t : Tensor, b_t : Tensor) -> Tensor:
    """Element-wise a_t / b_t, evaluated through any_wise."""
    op_text = 'O=I0/I1'
    return any_wise(op_text, a_t, b_t)
def min_(a_t : Tensor, b_t : Tensor) -> Tensor:
    """
    Element-wise minimum of a_t and b_t, evaluated through any_wise.

    Both operands are wrapped in their I{n}_TO_FLOATX macro and passed to fmin.
    """
    op_text = 'O=fmin( I0_TO_FLOATX(I0), I1_TO_FLOATX(I1) )'
    return any_wise(op_text, a_t, b_t)
def max_(a_t : Tensor, b_t : Tensor) -> Tensor:
    """
    Element-wise maximum of a_t and b_t, evaluated through any_wise.

    Both operands are wrapped in their I{n}_TO_FLOATX macro and passed to fmax.
    """
    op_text = 'O=fmax( I0_TO_FLOATX(I0), I1_TO_FLOATX(I1) )'
    return any_wise(op_text, a_t, b_t)
def square(a_t : Tensor) -> Tensor:
    """Element-wise a_t * a_t, evaluated through any_wise."""
    op_text = 'O=I0*I0'
    return any_wise(op_text, a_t)
def sqrt(a_t : Tensor) -> Tensor:
    """
    Element-wise square root of a_t, evaluated through any_wise.

    The operand is wrapped in the I0_TO_FLOATX macro before sqrt is applied.
    """
    op_text = 'O=sqrt(I0_TO_FLOATX(I0))'
    return any_wise(op_text, a_t)