fix AveCL

This commit is contained in:
iperov 2021-10-01 18:10:06 +04:00
commit 8f530cc0a7
4 changed files with 55 additions and 25 deletions

View file

@ -7,6 +7,7 @@ Applicable for high-performance general purpose n-dim array computations for eve
Works in python 3.5+. Dependencies: numpy. Works in python 3.5+. Dependencies: numpy.
This lib uses relative import, thus you can place it in any subfolder. This lib uses relative import, thus you can place it in any subfolder.
The lib is not thread-safe.
made by @iperov from scratch. made by @iperov from scratch.
""" """

View file

@ -175,10 +175,7 @@ class HKernel:
line = f'#define {name_upper}_IDX({HKernel.axes_seq_enum(name, ndim)}) ' line = f'#define {name_upper}_IDX({HKernel.axes_seq_enum(name, ndim)}) '
for i in range(ndim): for i in range(ndim):
if i == 0: line += f'( (size_t)({name_lower}{i}) )'
line += f'( (size_t)({name_lower}{i}) )'
else:
line += f'( {name_lower}{i} )'
for j in range(i+1,ndim): for j in range(i+1,ndim):
line += f'*{shape[j]} ' line += f'*{shape[j]} '
@ -190,10 +187,7 @@ class HKernel:
line = f'#define {name_upper}_IDX_MOD({HKernel.axes_seq_enum(name, ndim)}) ' line = f'#define {name_upper}_IDX_MOD({HKernel.axes_seq_enum(name, ndim)}) '
for i in range(ndim): for i in range(ndim):
if i == 0: line += f'( (size_t)({name_lower}{i}) % {shape[i]} )'
line += f'( (size_t)({name_lower}{i}) % {shape[i]} )'
else:
line += f'( ({name_lower}{i}) % {shape[i]} )'
for j in range(i+1,ndim): for j in range(i+1,ndim):
line += f'*{shape[j]} ' line += f'*{shape[j]} '

View file

@ -15,10 +15,6 @@ class NCore:
can raise Exception can raise Exception
""" """
SCacheton.cleanup() SCacheton.cleanup()
if Tensor._object_count != 0:
raise Exception(f'Unable to cleanup while {Tensor._object_count} Tensor objects exist.')
cleanup_devices() cleanup_devices()
__all__ = ['NCore'] __all__ = ['NCore']

View file

@ -1,3 +1,5 @@
import random
from collections import deque
from typing import List, Union from typing import List, Union
import numpy as np import numpy as np
@ -38,6 +40,8 @@ class Device:
self._ctx_q = None # CL command queue self._ctx_q = None # CL command queue
self._ctx = None # CL context self._ctx = None # CL context
self._target_memory_usage = 0
self._total_memory_allocated = 0 self._total_memory_allocated = 0
self._total_buffers_allocated = 0 self._total_buffers_allocated = 0
self._total_memory_pooled = 0 self._total_memory_pooled = 0
@ -97,6 +101,9 @@ class Device:
return self._cached_data.get(key, None) return self._cached_data.get(key, None)
def get_total_allocated_memory(self): def get_total_allocated_memory(self):
"""
get total bytes of used and pooled memory
"""
return self._total_memory_allocated return self._total_memory_allocated
def get_max_malloc_size(self) -> int: def get_max_malloc_size(self) -> int:
@ -153,6 +160,8 @@ class Device:
return compiled_krn return compiled_krn
def _cl_mem_alloc(self, size) -> CL.cl_mem: def _cl_mem_alloc(self, size) -> CL.cl_mem:
self._keep_target_memory_usage()
clr = CL.CLRESULT() clr = CL.CLRESULT()
mem = CL.clCreateBuffer(self._get_ctx(), CL.CL_MEM_READ_WRITE, size, None, clr) mem = CL.clCreateBuffer(self._get_ctx(), CL.CL_MEM_READ_WRITE, size, None, clr)
if clr == CL.CLERROR.SUCCESS: if clr == CL.CLERROR.SUCCESS:
@ -172,6 +181,7 @@ class Device:
if clr != CL.CLERROR.SUCCESS: if clr != CL.CLERROR.SUCCESS:
raise Exception(f'clGetMemObjectInfo error: {clr}') raise Exception(f'clGetMemObjectInfo error: {clr}')
size = size.value size = size.value
self._total_memory_allocated -= size self._total_memory_allocated -= size
self._total_buffers_allocated -= 1 self._total_buffers_allocated -= 1
clr = CL.clReleaseMemObject(mem) clr = CL.clReleaseMemObject(mem)
@ -182,12 +192,14 @@ class Device:
""" """
allocate or get cl_mem from pool allocate or get cl_mem from pool
""" """
self._keep_target_memory_usage()
pool = self._pooled_buffers pool = self._pooled_buffers
# First try to get pooled buffer # First try to get pooled buffer
ar = pool.get(size, None) ar = pool.get(size, None)
if ar is not None and len(ar) != 0: if ar is not None and len(ar) != 0:
mem = ar.pop(-1) mem = ar.pop()
self._total_memory_pooled -= size self._total_memory_pooled -= size
self._total_buffers_pooled -= 1 self._total_buffers_pooled -= 1
else: else:
@ -200,7 +212,7 @@ class Device:
for size_key in sorted(list(pool.keys()), reverse=True): for size_key in sorted(list(pool.keys()), reverse=True):
ar = pool[size_key] ar = pool[size_key]
if len(ar) != 0: if len(ar) != 0:
buf_to_release = ar.pop(-1) buf_to_release = ar.pop()
break break
if buf_to_release is not None: if buf_to_release is not None:
@ -208,8 +220,7 @@ class Device:
self._cl_mem_free(buf_to_release) self._cl_mem_free(buf_to_release)
continue continue
raise Exception(f'Unable to allocate {size // 1024**2}Mb on {str(self)}') raise Exception(f'Unable to allocate {size // 1024**2}Mb on {self.get_description()}')
break break
@ -228,7 +239,7 @@ class Device:
pool = self._pooled_buffers pool = self._pooled_buffers
ar = pool.get(size, None) ar = pool.get(size, None)
if ar is None: if ar is None:
ar = pool[size] = [] ar = pool[size] = deque()
ar.append(mem) ar.append(mem)
self._total_memory_pooled += size self._total_memory_pooled += size
@ -336,13 +347,33 @@ N of cacheddata: {len(self._cached_data)}
clr = CL.clFinish(self._get_ctx_q()) clr = CL.clFinish(self._get_ctx_q())
if clr != CL.CLERROR.SUCCESS: if clr != CL.CLERROR.SUCCESS:
raise Exception(f'clFinish error: {clr}') raise Exception(f'clFinish error: {clr}')
def cleanup(self): def set_target_memory_usage(self, mb : int):
""" """
Frees all resources from this Device. keep memory usage at specified position
when total allocated memory reaches the target and a new allocation is performed,
randomly chosen pooled memory is freed
""" """
self._cached_data = {} self._target_memory_usage = mb*1024*1024
def _keep_target_memory_usage(self):
targ = self._target_memory_usage
if targ != 0 and self.get_total_allocated_memory() >= targ:
# remove random 25% of pooled boofers
print('remove random 25% of pooled boofers')
pool = self._pooled_buffers
mems = [ (k,x) for k in pool.keys() for x in pool[k] ]
for k, mem in random.sample(mems, max(1,int(len(mems)*0.25)) ):
self._cl_mem_free(mem)
pool[k].remove(mem)
def clear_pooled_memory(self):
"""
frees pooled memory
"""
pool = self._pooled_buffers pool = self._pooled_buffers
for size_key in pool.keys(): for size_key in pool.keys():
for mem in pool[size_key]: for mem in pool[size_key]:
@ -350,6 +381,14 @@ N of cacheddata: {len(self._cached_data)}
self._pooled_buffers = {} self._pooled_buffers = {}
self._total_memory_pooled = 0 self._total_memory_pooled = 0
self._total_buffers_pooled = 0 self._total_buffers_pooled = 0
def cleanup(self):
"""
Frees all resources from this Device.
"""
self._cached_data = {}
self.clear_pooled_memory()
if self._total_memory_allocated != 0: if self._total_memory_allocated != 0:
raise Exception('Unable to cleanup CLDevice, while not all Buffers are deallocated.') raise Exception('Unable to cleanup CLDevice, while not all Buffers are deallocated.')
@ -459,7 +498,7 @@ def get_available_devices_info() -> List[DeviceInfo]:
def get_default_device() -> Union[Device, None]: def get_default_device() -> Union[Device, None]:
global _default_device global _default_device
if _default_device is None: if _default_device is None:
_default_device = get_device(0) _default_device = get_best_device()
return _default_device return _default_device
def set_default_device(device : Device): def set_default_device(device : Device):
@ -473,7 +512,7 @@ def get_device(arg : Union[None, int, Device, DeviceInfo]) -> Union[Device, None
""" """
get physical TensorCL device. get physical TensorCL device.
arg None - get best device arg None - get default device
int - by index int - by index
DeviceInfo - by device info DeviceInfo - by device info
Device - returns the same Device - returns the same
@ -481,7 +520,7 @@ def get_device(arg : Union[None, int, Device, DeviceInfo]) -> Union[Device, None
global _devices global _devices
if arg is None: if arg is None:
return get_best_device() return get_default_device()
elif isinstance(arg, int): elif isinstance(arg, int):
devices_info = get_available_devices_info() devices_info = get_available_devices_info()
if arg < len(devices_info): if arg < len(devices_info):