diff --git a/xlib/avecl/__init__.py b/xlib/avecl/__init__.py index be2dc5d..bde8071 100644 --- a/xlib/avecl/__init__.py +++ b/xlib/avecl/__init__.py @@ -13,6 +13,8 @@ The lib is not thread-safe. made by @iperov from scratch. """ +from xlib.avecl._internal.initializer.InitConst import InitConst + from ._internal.AAxes import AAxes from ._internal.AShape import AShape from ._internal.backend import (Device, DeviceInfo, Kernel, @@ -24,11 +26,11 @@ from ._internal.HArgs import HArgs from ._internal.HKernel import HKernel from ._internal.HTensor import HTensor from ._internal.HType import HType -from ._internal.initializer import (InitCoords2DArange, Initializer, +from ._internal.initializer import (InitConst, InitCoords2DArange, Initializer, InitRandomUniform) from ._internal.NCore import NCore from ._internal.NTest import NTest from ._internal.op import * from ._internal.SCacheton import SCacheton from ._internal.Tensor import Tensor -from ._internal.TensorImpl import * \ No newline at end of file +from ._internal.TensorImpl import * diff --git a/xlib/avecl/_internal/initializer/InitConst.py b/xlib/avecl/_internal/initializer/InitConst.py new file mode 100644 index 0000000..b299829 --- /dev/null +++ b/xlib/avecl/_internal/initializer/InitConst.py @@ -0,0 +1,37 @@ +from ..backend import Kernel +from ..HKernel import HKernel +from ..SCacheton import SCacheton +from ..Tensor import Tensor +from .Initializer import Initializer + + +class InitConst(Initializer): + + def __init__(self, value=0): + """ + arguments + + value(0) + """ + super().__init__() + self._value = value + + def initialize_tensor(self, tensor : Tensor): + + key = (InitConst, self._value, tensor.dtype) + kernel = SCacheton.get_var(key) + if kernel is None: + kernel = Kernel(kernel_text=f""" +{HKernel.define_tensor('O', (tensor.shape.size,), tensor.dtype )} +__kernel void impl(__global O_PTR_TYPE* O_PTR_NAME) +{{ + O_GLOBAL_STORE(get_global_id(0), (O_TYPE){self._value} ); +}} +""") + SCacheton.set_var(key, kernel) + + tensor.get_device().run_kernel( kernel, tensor.get_buffer(), + global_shape=(tensor.shape.size,) ) + + def __str__(self): return f'InitConst low={self._low}, high={self._high}' + diff --git a/xlib/avecl/_internal/initializer/__init__.py b/xlib/avecl/_internal/initializer/__init__.py index 932ea21..84190a8 100644 --- a/xlib/avecl/_internal/initializer/__init__.py +++ b/xlib/avecl/_internal/initializer/__init__.py @@ -1,3 +1,4 @@ +from .InitConst import InitConst from .InitCoords2DArange import InitCoords2DArange from .Initializer import Initializer from .InitRandomUniform import InitRandomUniform