mirror of
https://github.com/iperov/DeepFaceLive
synced 2025-07-16 10:03:42 -07:00
111 lines
4.2 KiB
Python
111 lines
4.2 KiB
Python
import numpy as np
|
|
|
|
from ..AShape import AShape
|
|
from ..backend import Kernel
|
|
from ..HArgs import HArgs
|
|
from ..HKernel import HKernel
|
|
from ..HType import HType
|
|
from ..info import BroadcastInfo
|
|
from ..SCacheton import SCacheton
|
|
from ..Tensor import Tensor
|
|
|
|
|
|
def any_wise(op_text : str,
|
|
*args,
|
|
dtype : np.dtype = None,
|
|
output_t:Tensor=None) -> Tensor:
|
|
"""
|
|
operator for N-wise ops with N inputs
|
|
|
|
arguments
|
|
op_text example: O=(2*I0*I1)+I2
|
|
|
|
*args List[ Tensor | number ]
|
|
|
|
dtype
|
|
|
|
output_t compute result to this Tensor.
|
|
Tensor may be with different shape, but should match total size.
|
|
"""
|
|
HArgs.check_zero_get_length(args)
|
|
tensor_args = HArgs.filter_tensor(args, raise_on_empty=True)
|
|
device = HArgs.check_get_same_device(tensor_args)
|
|
|
|
shape_list, dtype_list, krn_args = HArgs.decompose(args)
|
|
|
|
op = SCacheton.get(_AnyWiseOp, shape_list, dtype_list, dtype, op_text)
|
|
|
|
if output_t is None:
|
|
output_t = Tensor ( op.o_shape, op.o_dtype, device=device )
|
|
elif output_t.shape.size != op.o_shape.size:
|
|
raise ValueError(f'output_t must have size {op.o_shape.size}')
|
|
|
|
device.run_kernel(op.forward_krn, output_t.get_buffer(), *krn_args)
|
|
|
|
return output_t
|
|
|
|
class _AnyWiseOp:
|
|
def __init__(self, shape_list, dtype_list, o_dtype, op_text : str):
|
|
if len(shape_list) != len(dtype_list):
|
|
raise ValueError('len(shape_list) != len(dtype_list)')
|
|
|
|
self.o_dtype = o_dtype = o_dtype if o_dtype is not None else HType.get_most_weighted_dtype (dtype_list)
|
|
|
|
if len(shape_list) == 1:
|
|
# element-wise.
|
|
i_shape, i_dtype = shape_list[0], dtype_list[0]
|
|
self.o_shape = o_shape = i_shape
|
|
|
|
self.forward_krn = Kernel(global_shape=(o_shape.size,), kernel_text=f"""
|
|
{HKernel.define_tensor('O', o_shape, o_dtype)}
|
|
{HKernel.define_tensor('IN', i_shape, i_dtype)}
|
|
__kernel void impl(__global O_PTR_TYPE* O_PTR_NAME, __global const IN_PTR_TYPE* IN_PTR_NAME)
|
|
{{
|
|
size_t gid = get_global_id(0);
|
|
|
|
O_TYPE O = O_GLOBAL_LOAD(gid);
|
|
IN_TYPE I0 = IN_GLOBAL_LOAD(gid);
|
|
{op_text};
|
|
O_GLOBAL_STORE(gid, O);
|
|
}}
|
|
""")
|
|
else:
|
|
# Multi arg.
|
|
self.info = info = BroadcastInfo( [ shape if shape is not None else AShape((1,)) for shape in shape_list ])
|
|
|
|
self.o_shape = o_shape = info.o_shape
|
|
|
|
defs, arg_defs, impls = [], [], []
|
|
for i, (t_shape, t_dtype) in enumerate(zip(shape_list, dtype_list)):
|
|
t_name = f'I{i}'
|
|
if t_shape is not None:
|
|
defs.append( HKernel.define_tensor(t_name, info.br_shapes[i], t_dtype) )
|
|
arg_defs.append( f", __global const {t_name}_PTR_TYPE* {t_name}_PTR_NAME" )
|
|
impls.append( f"{t_name}_TYPE {t_name} = {t_name}_GLOBAL_LOAD({t_name}_IDX_MOD({HKernel.axes_seq_enum('O', info.o_shape.ndim)}));")
|
|
else:
|
|
arg_defs.append( f", {HKernel.define_scalar_func_arg(t_name, t_dtype)}" )
|
|
|
|
defs, arg_defs, impls = '\n'.join(defs), '\n'.join(arg_defs), '\n'.join(impls)
|
|
|
|
self.forward_krn = Kernel(global_shape=(o_shape.size,), kernel_text=f"""
|
|
{defs}
|
|
{HKernel.define_tensor('O', o_shape, o_dtype)}
|
|
__kernel void impl(__global O_PTR_TYPE* O_PTR_NAME{arg_defs})
|
|
{{
|
|
size_t gid = get_global_id(0);
|
|
{HKernel.decompose_idx_to_axes_idxs('gid', 'o', o_shape.ndim)}
|
|
{impls}
|
|
O_TYPE O;
|
|
{op_text};
|
|
O_GLOBAL_STORE(gid, O);
|
|
}}
|
|
""")
|
|
|
|
def add(a_t : Tensor, b_t : Tensor) -> Tensor: return any_wise('O=I0+I1', a_t, b_t)
|
|
def sub(a_t : Tensor, b_t : Tensor) -> Tensor: return any_wise('O=I0-I1', a_t, b_t)
|
|
def mul(a_t : Tensor, b_t : Tensor) -> Tensor: return any_wise('O=I0*I1', a_t, b_t)
|
|
def div(a_t : Tensor, b_t : Tensor) -> Tensor: return any_wise('O=I0/I1', a_t, b_t)
|
|
def min_(a_t : Tensor, b_t : Tensor) -> Tensor: return any_wise('O=fmin( I0_TO_FLOATX(I0), I1_TO_FLOATX(I1) )', a_t, b_t)
|
|
def max_(a_t : Tensor, b_t : Tensor) -> Tensor: return any_wise('O=fmax( I0_TO_FLOATX(I0), I1_TO_FLOATX(I1) )', a_t, b_t)
|
|
def square(a_t : Tensor) -> Tensor: return any_wise('O=I0*I0', a_t)
|
|
def sqrt(a_t : Tensor) -> Tensor: return any_wise('O=sqrt(I0_TO_FLOATX(I0))', a_t)
|