mirror of
https://github.com/iperov/DeepFaceLive
synced 2025-08-14 02:37:01 -07:00
upd xlib.avecl
This commit is contained in:
parent
5b3398053e
commit
a22f8b571a
1 changed files with 22 additions and 8 deletions
|
@ -8,7 +8,7 @@ from ..SCacheton import SCacheton
|
|||
from ..Tensor import Tensor
|
||||
|
||||
|
||||
def remap_np_affine (input_t : Tensor, affine_n : np.ndarray, interpolation : EInterpolation = None, inverse=False, output_size=None, dtype=None) -> Tensor:
|
||||
def remap_np_affine (input_t : Tensor, affine_n : np.ndarray, interpolation : EInterpolation = None, inverse=False, output_size=None, post_op_text=None, dtype=None) -> Tensor:
|
||||
"""
|
||||
remap affine operator for all channels using single numpy affine mat
|
||||
|
||||
|
@ -20,12 +20,18 @@ def remap_np_affine (input_t : Tensor, affine_n : np.ndarray, interpolation : EI
|
|||
|
||||
interpolation EInterpolation
|
||||
|
||||
post_op_text cl kernel
|
||||
post operation with output float value named 'O'
|
||||
example 'O = 2*O;'
|
||||
|
||||
output_size (w,h)
|
||||
|
||||
dtype
|
||||
"""
|
||||
if affine_n.shape != (2,3):
|
||||
raise ValueError('affine_n.shape must be (2,3)')
|
||||
|
||||
op = SCacheton.get(_RemapAffineOp, input_t.shape, input_t.dtype, interpolation, output_size, dtype)
|
||||
op = SCacheton.get(_RemapAffineOp, input_t.shape, input_t.dtype, interpolation, output_size, post_op_text, dtype)
|
||||
|
||||
output_t = Tensor( op.o_shape, op.o_dtype, device=input_t.get_device() )
|
||||
|
||||
|
@ -45,7 +51,7 @@ def remap_np_affine (input_t : Tensor, affine_n : np.ndarray, interpolation : EI
|
|||
|
||||
|
||||
class _RemapAffineOp():
|
||||
def __init__(self, i_shape : AShape, i_dtype, interpolation, o_size, o_dtype):
|
||||
def __init__(self, i_shape : AShape, i_dtype, interpolation, o_size, post_op_text, o_dtype):
|
||||
if np.dtype(i_dtype).type == np.bool_:
|
||||
raise ValueError('np.bool_ dtype of i_dtype is not supported.')
|
||||
if i_shape.ndim < 2:
|
||||
|
@ -65,7 +71,11 @@ class _RemapAffineOp():
|
|||
|
||||
self.o_shape = o_shape
|
||||
self.o_dtype = o_dtype = o_dtype if o_dtype is not None else i_dtype
|
||||
|
||||
|
||||
if post_op_text is None:
|
||||
post_op_text = ''
|
||||
|
||||
|
||||
if interpolation == EInterpolation.LINEAR:
|
||||
self.forward_krn = Kernel(global_shape=(o_shape.size,), kernel_text=f"""
|
||||
|
||||
|
@ -97,8 +107,12 @@ __kernel void impl(__global O_PTR_TYPE* O_PTR_NAME, __global const I_PTR_TYPE* I
|
|||
p01 *= (cx01 - cx0f)*(cy1f - cy01)*(cy0 >= 0 & cy0 < Im2 & cx1 >= 0 & cx1 < Im1);
|
||||
p10 *= (cx1f - cx01)*(cy01 - cy0f)*(cy1 >= 0 & cy1 < Im2 & cx0 >= 0 & cx0 < Im1);
|
||||
p11 *= (cx01 - cx0f)*(cy01 - cy0f)*(cy1 >= 0 & cy1 < Im2 & cx1 >= 0 & cx1 < Im1);
|
||||
|
||||
O_GLOBAL_STORE(gid, p00 + p01 + p10 + p11);
|
||||
|
||||
float O = p00 + p01 + p10 + p11;
|
||||
|
||||
{post_op_text}
|
||||
|
||||
O_GLOBAL_STORE(gid, O);
|
||||
}}
|
||||
""")
|
||||
elif interpolation == EInterpolation.CUBIC:
|
||||
|
@ -150,7 +164,7 @@ __kernel void impl(__global O_PTR_TYPE* O_PTR_NAME, __global const I_PTR_TYPE* I
|
|||
}}
|
||||
|
||||
float O = cubic(row[0], row[1], row[2], row[3], dy);
|
||||
|
||||
{post_op_text}
|
||||
O_GLOBAL_STORE(gid, O);
|
||||
}}
|
||||
""")
|
||||
|
@ -213,7 +227,7 @@ __kernel void impl(__global O_PTR_TYPE* O_PTR_NAME, __global const I_PTR_TYPE* I
|
|||
O += sxy*Fxyv*(y >= 0 & y < Im2 & x >= 0 & x < Im1);
|
||||
}}
|
||||
O = O / FxFysum;
|
||||
|
||||
{post_op_text}
|
||||
O_GLOBAL_STORE(gid, O);
|
||||
}}
|
||||
""")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue