update xlib.avecl

This commit is contained in:
iperov 2021-10-22 16:35:05 +04:00
commit 7aef4d2b1e
3 changed files with 42 additions and 2 deletions

View file

@ -13,6 +13,8 @@ The lib is not thread-safe.
made by @iperov from scratch. made by @iperov from scratch.
""" """
from xlib.avecl._internal.initializer.InitConst import InitConst
from ._internal.AAxes import AAxes from ._internal.AAxes import AAxes
from ._internal.AShape import AShape from ._internal.AShape import AShape
from ._internal.backend import (Device, DeviceInfo, Kernel, from ._internal.backend import (Device, DeviceInfo, Kernel,
@ -24,11 +26,11 @@ from ._internal.HArgs import HArgs
from ._internal.HKernel import HKernel from ._internal.HKernel import HKernel
from ._internal.HTensor import HTensor from ._internal.HTensor import HTensor
from ._internal.HType import HType from ._internal.HType import HType
from ._internal.initializer import (InitCoords2DArange, Initializer, from ._internal.initializer import (InitConst, InitCoords2DArange, Initializer,
InitRandomUniform) InitRandomUniform)
from ._internal.NCore import NCore from ._internal.NCore import NCore
from ._internal.NTest import NTest from ._internal.NTest import NTest
from ._internal.op import * from ._internal.op import *
from ._internal.SCacheton import SCacheton from ._internal.SCacheton import SCacheton
from ._internal.Tensor import Tensor from ._internal.Tensor import Tensor
from ._internal.TensorImpl import * from ._internal.TensorImpl import *

View file

@ -0,0 +1,37 @@
from ..backend import Kernel
from ..HKernel import HKernel
from ..SCacheton import SCacheton
from ..Tensor import Tensor
from .Initializer import Initializer
class InitConst(Initializer):
def __init__(self, value=0):
"""
arguments
value(0)
"""
super().__init__()
self._value = value
def initialize_tensor(self, tensor : Tensor):
key = (InitConst, self._value, tensor.dtype)
kernel = SCacheton.get_var(key)
if kernel is None:
kernel = Kernel(kernel_text=f"""
{HKernel.define_tensor('O', (tensor.shape.size,), tensor.dtype )}
__kernel void impl(__global O_PTR_TYPE* O_PTR_NAME)
{{
O_GLOBAL_STORE(get_global_id(0), (O_TYPE){self._value} );
}}
""")
SCacheton.set_var(key, kernel)
tensor.get_device().run_kernel( kernel, tensor.get_buffer(),
global_shape=(tensor.shape.size,) )
def __str__(self): return f'InitConst low={self._low}, high={self._high}'

View file

@ -1,3 +1,4 @@
from .InitConst import InitConst
from .InitCoords2DArange import InitCoords2DArange from .InitCoords2DArange import InitCoords2DArange
from .Initializer import Initializer from .Initializer import Initializer
from .InitRandomUniform import InitRandomUniform from .InitRandomUniform import InitRandomUniform