import numpy as np

from ..AAxes import AAxes
from ..AShape import AShape
from ..backend import Kernel
from ..HArgs import HArgs
from ..HKernel import HKernel
from ..HType import HType
from ..info import BroadcastInfo, ReductionInfo
from ..SCacheton import SCacheton
from ..Tensor import Tensor


def any_wise(op_text : str,
             *args,
             dim_wise_axis : int = None,
             dtype : np.dtype = None,
             output_t : Tensor = None) -> Tensor:
    """
    Element-wise operator with N inputs.

    arguments

        op_text     C expression computing output O from inputs I0..IN.
                    example: O=(2*I0*I1)+I2

        *args       List[ Tensor | number ]

        dim_wise_axis(None)     unroll the inputs along this axis: op_text then
                                sees per-element variables I{n}_{i} and must
                                assign every O_{i}. The axis size must not
                                exceed 16.

        dtype       output dtype. By default the most weighted dtype among
                    the input dtypes is used.

        output_t    compute the result into this Tensor.
                    Tensor may have a different shape, but must match the total size.
    """
    HArgs.check_zero_get_length(args)
    tensor_args = HArgs.filter_tensor(args, raise_on_empty=True)
    device = HArgs.check_get_same_device(tensor_args)

    shape_list, dtype_list, krn_args = HArgs.decompose(args)

    # Compiled kernels are cached in SCacheton per unique argument signature.
    op = SCacheton.get(_AnyWiseOp, shape_list, dtype_list, dim_wise_axis, dtype, op_text)

    if output_t is None:
        output_t = Tensor ( op.o_shape, op.o_dtype, device=device )
    elif output_t.shape.size != op.o_shape.size:
        raise ValueError(f'output_t must have size {op.o_shape.size}')

    device.run_kernel(op.forward_krn, output_t.get_buffer(), *krn_args)

    return output_t
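
# A NumPy sketch of the semantics (hypothetical values, reference math only):
# any_wise('O=(2*I0*I1)+I2', a_t, b_t, 0.5) broadcasts its inputs NumPy-style,
# so for a_t of shape (2,3) and b_t of shape (3,) it matches (2*a*b) + 0.5
# computed on the equivalent NumPy arrays, with result shape (2,3).
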
class _AnyWiseOp:
    def __init__(self, shape_list, dtype_list, dim_wise_axis, o_dtype, op_text : str):
        if len(shape_list) != len(dtype_list):
            raise ValueError('len(shape_list) != len(dtype_list)')

        self.o_dtype = o_dtype = o_dtype if o_dtype is not None else HType.get_most_weighted_dtype(dtype_list)
        # Scalars arrive with shape None and take part in broadcasting as shape (1,).
        self.info = info = BroadcastInfo([ shape if shape is not None else AShape((1,)) for shape in shape_list ])
        self.o_shape = o_shape = info.o_shape

        g_shape = o_shape
        if dim_wise_axis is not None:
            dim_wise_axis = o_shape.check_axis(dim_wise_axis)

            dim_wise_axis_size = o_shape[dim_wise_axis]
            if dim_wise_axis_size > 16:
                raise ValueError(f'dim_wise_axis size > 16: {dim_wise_axis_size}')

            # Remove dim_wise_axis from the global work shape: every work-item
            # handles all elements along that axis itself.
            g_shape = ReductionInfo( o_shape, AAxes(dim_wise_axis), False ).o_shape

        defs, arg_defs, impls = [], [], []
        for i, (t_shape, t_dtype) in enumerate(zip(shape_list, dtype_list)):
            t_name = f'I{i}'
            if t_shape is not None:
                # Tensor input: passed as a global pointer and loaded per work-item.
                defs.append( HKernel.define_tensor(t_name, info.br_shapes[i], t_dtype) )
                arg_defs.append( f", __global const {t_name}_PTR_TYPE* {t_name}_PTR_NAME" )

                if dim_wise_axis is not None:
                    # Unrolled loads: expose I{i}_0 .. I{i}_{size-1} to op_text.
                    for i_elem in range(dim_wise_axis_size):
                        impls.append( f"{t_name}_TYPE {t_name}_{i_elem} = {t_name}_GLOBAL_LOAD({t_name}_IDX_MOD({HKernel.axes_seq_enum('G', g_shape.ndim, new_axis=(f'{i_elem}', dim_wise_axis) )}));")
                else:
                    impls.append( f"{t_name}_TYPE {t_name} = {t_name}_GLOBAL_LOAD({t_name}_IDX_MOD({HKernel.axes_seq_enum('G', g_shape.ndim)}));")
            else:
                # Scalar input: passed by value as a kernel argument.
                arg_defs.append( f", {HKernel.define_scalar_func_arg(t_name, t_dtype)}" )

        defs, arg_defs, impls = '\n'.join(defs), '\n'.join(arg_defs), '\n'.join(impls)

        if dim_wise_axis is not None:
            o_def = '\n'.join( f"O_TYPE O_{i_elem};" for i_elem in range(dim_wise_axis_size) )
            o_store = '\n'.join( f"O_GLOBAL_STORE(O_IDX({HKernel.axes_seq_enum('G', g_shape.ndim, new_axis=(f'{i_elem}', dim_wise_axis) )}), O_{i_elem});" for i_elem in range(dim_wise_axis_size) )
        else:
            o_def = 'O_TYPE O;'
            o_store = 'O_GLOBAL_STORE(gid, O);'

        # One work-item per element of g_shape; op_text is pasted verbatim
        # between the input loads and the output store(s).
        self.forward_krn = Kernel(global_shape=(g_shape.size,), kernel_text=f"""
{defs}
{HKernel.define_tensor('O', o_shape, o_dtype)}
{HKernel.define_tensor_shape('G', g_shape)}
__kernel void impl(__global O_PTR_TYPE* O_PTR_NAME{arg_defs})
{{
    size_t gid = get_global_id(0);
    {HKernel.decompose_idx_to_axes_idxs('gid', 'G', g_shape.ndim)}
    {impls}
    {o_def}
    {op_text};
    {o_store}
}}
""")
# Convenience wrappers over any_wise.
def add(a_t : Tensor, b_t : Tensor) -> Tensor: return any_wise('O=I0+I1', a_t, b_t)
def sub(a_t : Tensor, b_t : Tensor) -> Tensor: return any_wise('O=I0-I1', a_t, b_t)
def mul(a_t : Tensor, b_t : Tensor) -> Tensor: return any_wise('O=I0*I1', a_t, b_t)
def div(a_t : Tensor, b_t : Tensor) -> Tensor: return any_wise('O=I0/I1', a_t, b_t)
def min_(a_t : Tensor, b_t : Tensor) -> Tensor: return any_wise('O=fmin( I0_TO_FLOATX(I0), I1_TO_FLOATX(I1) )', a_t, b_t)
def max_(a_t : Tensor, b_t : Tensor) -> Tensor: return any_wise('O=fmax( I0_TO_FLOATX(I0), I1_TO_FLOATX(I1) )', a_t, b_t)
def square(a_t : Tensor) -> Tensor: return any_wise('O=I0*I0', a_t)
def sqrt(a_t : Tensor) -> Tensor: return any_wise('O=sqrt(I0_TO_FLOATX(I0))', a_t)
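
# A minimal NumPy cross-check of the intended semantics (hypothetical values;
# exercises only the reference math, not the OpenCL path). The relative imports
# above mean this must be run as a module inside its package, e.g. `python -m ...`.
if __name__ == '__main__':
    a = np.random.rand(2, 3).astype(np.float32)
    b = np.random.rand(3).astype(np.float32)   # broadcasts against a, NumPy-style

    print((2 * a * b) + 0.5)   # reference for any_wise('O=(2*I0*I1)+I2', a_t, b_t, 0.5)
    print(np.minimum(a, b))    # reference for min_(a_t, b_t)
    print(np.sqrt(a))          # reference for sqrt(a_t)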