# DeepFaceLive/xlib/avecl/_internal/op/any_wise.py

import numpy as np
from ..AShape import AShape
from ..backend import Kernel
from ..HArgs import HArgs
from ..HKernel import HKernel
from ..HType import HType
from ..info import BroadcastInfo
from ..SCacheton import SCacheton
from ..Tensor import Tensor

def any_wise(op_text : str,
             *args,
             dtype : np.dtype = None,
             output_t : Tensor = None) -> Tensor:
    """
    Operator for N-wise ops with N inputs.

    arguments

        op_text     expression over O and I0..In, for example: O=(2*I0*I1)+I2

        *args       List[ Tensor | number ]

        dtype       output dtype; if None, the most weighted dtype among the inputs is used

        output_t    compute the result into this Tensor.
                    The Tensor may have a different shape, but its total size must match.
    """
    HArgs.check_zero_get_length(args)

    tensor_args = HArgs.filter_tensor(args, raise_on_empty=True)
    device = HArgs.check_get_same_device(tensor_args)

    shape_list, dtype_list, krn_args = HArgs.decompose(args)

    # The compiled op is cached via SCacheton per (shapes, dtypes, dtype, op_text),
    # so repeated calls with the same signature reuse the same kernel.
    op = SCacheton.get(_AnyWiseOp, shape_list, dtype_list, dtype, op_text)

    if output_t is None:
        output_t = Tensor( op.o_shape, op.o_dtype, device=device )
    elif output_t.shape.size != op.o_shape.size:
        raise ValueError(f'output_t must have size {op.o_shape.size}')

    device.run_kernel(op.forward_krn, output_t.get_buffer(), *krn_args)

    return output_t
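
# Usage sketch (illustrative; `a_t` and `b_t` stand for two avecl Tensors that
# already live on the same device, their construction is out of scope here):
#
#   out_t = any_wise('O=(2*I0*I1)+I2', a_t, b_t, 0.5)
#
# I0 and I1 bind to the Tensor arguments in order, I2 binds to the scalar 0.5,
# and O is the output element. Tensor shapes are broadcast together by
# _AnyWiseOp below.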


class _AnyWiseOp:
    def __init__(self, shape_list, dtype_list, o_dtype, op_text : str):
        if len(shape_list) != len(dtype_list):
            raise ValueError('len(shape_list) != len(dtype_list)')

        # If no output dtype is requested, fall back to the most weighted dtype among the inputs.
        self.o_dtype = o_dtype = o_dtype if o_dtype is not None else HType.get_most_weighted_dtype(dtype_list)

        if len(shape_list) == 1:
            # element-wise: single input, output keeps the input shape.
            i_shape, i_dtype = shape_list[0], dtype_list[0]
            self.o_shape = o_shape = i_shape

            self.forward_krn = Kernel(global_shape=(o_shape.size,), kernel_text=f"""
{HKernel.define_tensor('O', o_shape, o_dtype)}
{HKernel.define_tensor('IN', i_shape, i_dtype)}

__kernel void impl(__global O_PTR_TYPE* O_PTR_NAME, __global const IN_PTR_TYPE* IN_PTR_NAME)
{{
    size_t gid = get_global_id(0);

    O_TYPE O = O_GLOBAL_LOAD(gid);
    IN_TYPE I0 = IN_GLOBAL_LOAD(gid);
    {op_text};
    O_GLOBAL_STORE(gid, O);
}}
""")
        else:
            # Multi-arg: Tensor shapes are broadcast together; scalar args
            # (shape is None) are passed as plain kernel arguments.
            self.info = info = BroadcastInfo([ shape if shape is not None else AShape((1,)) for shape in shape_list ])
            self.o_shape = o_shape = info.o_shape

            defs, arg_defs, impls = [], [], []
            for i, (t_shape, t_dtype) in enumerate(zip(shape_list, dtype_list)):
                t_name = f'I{i}'
                if t_shape is not None:
                    # Tensor arg: declare it, add a buffer kernel argument, and load
                    # its element through broadcast index remapping (IDX_MOD).
                    defs.append( HKernel.define_tensor(t_name, info.br_shapes[i], t_dtype) )
                    arg_defs.append( f", __global const {t_name}_PTR_TYPE* {t_name}_PTR_NAME" )
                    impls.append( f"{t_name}_TYPE {t_name} = {t_name}_GLOBAL_LOAD({t_name}_IDX_MOD({HKernel.axes_seq_enum('O', info.o_shape.ndim)}));" )
                else:
                    # Scalar arg: passed by value as a kernel function argument.
                    arg_defs.append( f", {HKernel.define_scalar_func_arg(t_name, t_dtype)}" )

            defs, arg_defs, impls = '\n'.join(defs), '\n'.join(arg_defs), '\n'.join(impls)
            self.forward_krn = Kernel(global_shape=(o_shape.size,), kernel_text=f"""
{defs}
{HKernel.define_tensor('O', o_shape, o_dtype)}

__kernel void impl(__global O_PTR_TYPE* O_PTR_NAME{arg_defs})
{{
    size_t gid = get_global_id(0);
    {HKernel.decompose_idx_to_axes_idxs('gid', 'o', o_shape.ndim)}
    {impls}

    O_TYPE O;
    {op_text};
    O_GLOBAL_STORE(gid, O);
}}
""")


# Convenience wrappers built on any_wise.
def add(a_t : Tensor, b_t : Tensor) -> Tensor: return any_wise('O=I0+I1', a_t, b_t)
def sub(a_t : Tensor, b_t : Tensor) -> Tensor: return any_wise('O=I0-I1', a_t, b_t)
def mul(a_t : Tensor, b_t : Tensor) -> Tensor: return any_wise('O=I0*I1', a_t, b_t)
def div(a_t : Tensor, b_t : Tensor) -> Tensor: return any_wise('O=I0/I1', a_t, b_t)
def min_(a_t : Tensor, b_t : Tensor) -> Tensor: return any_wise('O=fmin( I0_TO_FLOATX(I0), I1_TO_FLOATX(I1) )', a_t, b_t)
def max_(a_t : Tensor, b_t : Tensor) -> Tensor: return any_wise('O=fmax( I0_TO_FLOATX(I0), I1_TO_FLOATX(I1) )', a_t, b_t)
def square(a_t : Tensor) -> Tensor: return any_wise('O=I0*I0', a_t)
def sqrt(a_t : Tensor) -> Tensor: return any_wise('O=sqrt(I0_TO_FLOATX(I0))', a_t)
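
# Adding another element-wise op is a one-liner on top of any_wise. A sketch,
# relying on the OpenCL builtin rsqrt; the wrapper name `rsqrt` below is
# hypothetical and not part of this module:
#
# def rsqrt(a_t : Tensor) -> Tensor: return any_wise('O=rsqrt(I0_TO_FLOATX(I0))', a_t)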