upd xlib.avecl

This commit is contained in:
iperov 2021-10-09 13:27:21 +04:00
commit a22f8b571a

View file

@ -8,7 +8,7 @@ from ..SCacheton import SCacheton
from ..Tensor import Tensor
def remap_np_affine (input_t : Tensor, affine_n : np.ndarray, interpolation : EInterpolation = None, inverse=False, output_size=None, dtype=None) -> Tensor:
def remap_np_affine (input_t : Tensor, affine_n : np.ndarray, interpolation : EInterpolation = None, inverse=False, output_size=None, post_op_text=None, dtype=None) -> Tensor:
"""
remap affine operator for all channels using single numpy affine mat
@ -20,12 +20,18 @@ def remap_np_affine (input_t : Tensor, affine_n : np.ndarray, interpolation : EI
interpolation EInterpolation
post_op_text OpenCL kernel code (optional, None means no post operation)
statement(s) inserted after the output float value named 'O' is computed and before it is stored
example: 'O = 2*O;'
output_size (w,h)
dtype
"""
if affine_n.shape != (2,3):
raise ValueError('affine_n.shape must be (2,3)')
op = SCacheton.get(_RemapAffineOp, input_t.shape, input_t.dtype, interpolation, output_size, dtype)
op = SCacheton.get(_RemapAffineOp, input_t.shape, input_t.dtype, interpolation, output_size, post_op_text, dtype)
output_t = Tensor( op.o_shape, op.o_dtype, device=input_t.get_device() )
@ -45,7 +51,7 @@ def remap_np_affine (input_t : Tensor, affine_n : np.ndarray, interpolation : EI
class _RemapAffineOp():
def __init__(self, i_shape : AShape, i_dtype, interpolation, o_size, o_dtype):
def __init__(self, i_shape : AShape, i_dtype, interpolation, o_size, post_op_text, o_dtype):
if np.dtype(i_dtype).type == np.bool_:
raise ValueError('np.bool_ dtype of i_dtype is not supported.')
if i_shape.ndim < 2:
@ -65,7 +71,11 @@ class _RemapAffineOp():
self.o_shape = o_shape
self.o_dtype = o_dtype = o_dtype if o_dtype is not None else i_dtype
if post_op_text is None:
post_op_text = ''
if interpolation == EInterpolation.LINEAR:
self.forward_krn = Kernel(global_shape=(o_shape.size,), kernel_text=f"""
@ -97,8 +107,12 @@ __kernel void impl(__global O_PTR_TYPE* O_PTR_NAME, __global const I_PTR_TYPE* I
p01 *= (cx01 - cx0f)*(cy1f - cy01)*(cy0 >= 0 & cy0 < Im2 & cx1 >= 0 & cx1 < Im1);
p10 *= (cx1f - cx01)*(cy01 - cy0f)*(cy1 >= 0 & cy1 < Im2 & cx0 >= 0 & cx0 < Im1);
p11 *= (cx01 - cx0f)*(cy01 - cy0f)*(cy1 >= 0 & cy1 < Im2 & cx1 >= 0 & cx1 < Im1);
O_GLOBAL_STORE(gid, p00 + p01 + p10 + p11);
float O = p00 + p01 + p10 + p11;
{post_op_text}
O_GLOBAL_STORE(gid, O);
}}
""")
elif interpolation == EInterpolation.CUBIC:
@ -150,7 +164,7 @@ __kernel void impl(__global O_PTR_TYPE* O_PTR_NAME, __global const I_PTR_TYPE* I
}}
float O = cubic(row[0], row[1], row[2], row[3], dy);
{post_op_text}
O_GLOBAL_STORE(gid, O);
}}
""")
@ -213,7 +227,7 @@ __kernel void impl(__global O_PTR_TYPE* O_PTR_NAME, __global const I_PTR_TYPE* I
O += sxy*Fxyv*(y >= 0 & y < Im2 & x >= 0 & x < Im1);
}}
O = O / FxFysum;
{post_op_text}
O_GLOBAL_STORE(gid, O);
}}
""")