add xlib.avecl

This commit is contained in:
iperov 2021-09-30 18:21:30 +04:00
commit 0058474da7
56 changed files with 5569 additions and 0 deletions

View file

@ -0,0 +1,106 @@
from typing import List
import numpy as np
from .backend import Device
from .HTensor import HTensor
from .HType import HType
from .Tensor import Tensor
from .AShape import AShape
class HArgs:
"""
Helper functions for list of arguments
"""
@staticmethod
def decompose(args):
"""
decompose list of args of Tensor and supported numeric values
returns ( shape_list, # if scalar value -> shape is None
dtype_list, #
kernel_args_list #
)
"""
shape_list = []
dtype_list = []
kernel_args_list = []
for arg in args:
if isinstance(arg, Tensor):
shape_list.append(arg.shape)
dtype_list.append(arg.dtype)
kernel_args_list.append(arg.get_buffer())
else:
if isinstance(arg, int):
dtype, arg = np.int32, np.int32(arg)
elif isinstance(arg, float):
dtype, arg = np.float32, np.float32(arg)
elif HType.is_obj_of_np_scalar_type(arg):
dtype = arg.__class__
else:
raise ValueError(f'Unsupported type of arg: {arg.__class__} Use Tensor or number type.')
shape_list.append(None)
dtype_list.append(dtype)
kernel_args_list.append(arg)
return tuple(shape_list), tuple(dtype_list), tuple(kernel_args_list)
@staticmethod
def get_shapes(args : List[Tensor]) -> List[AShape]:
"""
"""
return tuple(t.shape for t in args)
@staticmethod
def check_zero_get_length(args) -> int:
"""
raises an error if len(args) == 0, otherwise returns len
"""
args_len = len(args)
if len(args) == 0:
raise ValueError('args must be specified')
return args_len
@staticmethod
def check_get_same_device(args : List[Tensor]) -> Device:
"""
check all device of tensors are the same and return the device
"""
result = HTensor.all_same_device(args)
if not result:
raise ValueError('all Tensors must have the same device')
return args[0].get_device()
@staticmethod
def check_all_tensors(args : List[Tensor]):
"""
"""
if not all (isinstance(tensor, Tensor) for tensor in args):
raise ValueError('All values must have type of Tensor')
@staticmethod
def check_get_same_shape(args : List[Tensor]) -> AShape:
"""
check all shapes of tensors are the same and return the shape
"""
shape = args[0].shape
if not all (t.shape == shape for t in args):
raise ValueError('All tensors must have the same shape')
return shape
@staticmethod
def filter_tensor(args, raise_on_empty : bool):
"""
get only tensors from the list
"""
tensor_args = [arg for arg in args if isinstance(arg, Tensor) ]
if raise_on_empty and len(tensor_args) == 0:
raise ValueError('At least one arg must be a Tensor')
return tensor_args
__all__ = ['HArgs']