From a22f8b571a2a1c045d27020423c6a062e32d256c Mon Sep 17 00:00:00 2001 From: iperov Date: Sat, 9 Oct 2021 13:27:21 +0400 Subject: [PATCH] upd xlib.avecl --- xlib/avecl/_internal/op/remap_np_affine.py | 30 ++++++++++++++++------ 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/xlib/avecl/_internal/op/remap_np_affine.py b/xlib/avecl/_internal/op/remap_np_affine.py index 0ec0afe..a02f7db 100644 --- a/xlib/avecl/_internal/op/remap_np_affine.py +++ b/xlib/avecl/_internal/op/remap_np_affine.py @@ -8,7 +8,7 @@ from ..SCacheton import SCacheton from ..Tensor import Tensor -def remap_np_affine (input_t : Tensor, affine_n : np.ndarray, interpolation : EInterpolation = None, inverse=False, output_size=None, dtype=None) -> Tensor: +def remap_np_affine (input_t : Tensor, affine_n : np.ndarray, interpolation : EInterpolation = None, inverse=False, output_size=None, post_op_text=None, dtype=None) -> Tensor: """ remap affine operator for all channels using single numpy affine mat @@ -20,12 +20,18 @@ def remap_np_affine (input_t : Tensor, affine_n : np.ndarray, interpolation : EI interpolation EInterpolation + post_op_text cl kernel + post operation with output float value named 'O' + example 'O = 2*O;' + + output_size (w,h) + dtype """ if affine_n.shape != (2,3): raise ValueError('affine_n.shape must be (2,3)') - op = SCacheton.get(_RemapAffineOp, input_t.shape, input_t.dtype, interpolation, output_size, dtype) + op = SCacheton.get(_RemapAffineOp, input_t.shape, input_t.dtype, interpolation, output_size, post_op_text, dtype) output_t = Tensor( op.o_shape, op.o_dtype, device=input_t.get_device() ) @@ -45,7 +51,7 @@ def remap_np_affine (input_t : Tensor, affine_n : np.ndarray, interpolation : EI class _RemapAffineOp(): - def __init__(self, i_shape : AShape, i_dtype, interpolation, o_size, o_dtype): + def __init__(self, i_shape : AShape, i_dtype, interpolation, o_size, post_op_text, o_dtype): if np.dtype(i_dtype).type == np.bool_: raise ValueError('np.bool_ dtype of i_dtype is not supported.') if i_shape.ndim < 2: @@ -65,7 +71,11 @@ class _RemapAffineOp(): self.o_shape = o_shape self.o_dtype = o_dtype = o_dtype if o_dtype is not None else i_dtype - + + if post_op_text is None: + post_op_text = '' + + if interpolation == EInterpolation.LINEAR: self.forward_krn = Kernel(global_shape=(o_shape.size,), kernel_text=f""" @@ -97,8 +107,12 @@ __kernel void impl(__global O_PTR_TYPE* O_PTR_NAME, __global const I_PTR_TYPE* I p01 *= (cx01 - cx0f)*(cy1f - cy01)*(cy0 >= 0 & cy0 < Im2 & cx1 >= 0 & cx1 < Im1); p10 *= (cx1f - cx01)*(cy01 - cy0f)*(cy1 >= 0 & cy1 < Im2 & cx0 >= 0 & cx0 < Im1); p11 *= (cx01 - cx0f)*(cy01 - cy0f)*(cy1 >= 0 & cy1 < Im2 & cx1 >= 0 & cx1 < Im1); - - O_GLOBAL_STORE(gid, p00 + p01 + p10 + p11); + + float O = p00 + p01 + p10 + p11; + + {post_op_text} + + O_GLOBAL_STORE(gid, O); }} """) elif interpolation == EInterpolation.CUBIC: @@ -150,7 +164,7 @@ __kernel void impl(__global O_PTR_TYPE* O_PTR_NAME, __global const I_PTR_TYPE* I }} float O = cubic(row[0], row[1], row[2], row[3], dy); - + {post_op_text} O_GLOBAL_STORE(gid, O); }} """) @@ -213,7 +227,7 @@ __kernel void impl(__global O_PTR_TYPE* O_PTR_NAME, __global const I_PTR_TYPE* I O += sxy*Fxyv*(y >= 0 & y < Im2 & x >= 0 & x < Im1); }} O = O / FxFysum; - + {post_op_text} O_GLOBAL_STORE(gid, O); }} """)