mirror of
https://github.com/iperov/DeepFaceLive
synced 2025-08-21 05:53:25 -07:00
add xlib.avecl
This commit is contained in:
parent
932edfe875
commit
0058474da7
56 changed files with 5569 additions and 0 deletions
103
xlib/avecl/_internal/AShape.py
Normal file
103
xlib/avecl/_internal/AShape.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
from collections import Iterable
|
||||
from .AAxes import AAxes
|
||||
|
||||
class AShape(Iterable):
|
||||
__slots__ = ['shape','size','ndim']
|
||||
|
||||
def __init__(self, shape):
|
||||
"""
|
||||
Constructs valid shape from user argument
|
||||
|
||||
arguments
|
||||
|
||||
shape AShape
|
||||
Iterable
|
||||
|
||||
can raise ValueError during the construction
|
||||
"""
|
||||
|
||||
if isinstance(shape, AShape):
|
||||
self.shape = shape.shape
|
||||
self.size = shape.size
|
||||
self.ndim = shape.ndim
|
||||
else:
|
||||
if isinstance(shape, (int,float) ):
|
||||
shape = (int(shape),)
|
||||
|
||||
if isinstance(shape, Iterable):
|
||||
size = 1
|
||||
valid_shape = []
|
||||
for x in shape:
|
||||
if x is None:
|
||||
raise ValueError(f'Incorrent value {x} in shape {shape}')
|
||||
x = int(x)
|
||||
if x < 1:
|
||||
raise ValueError(f'Incorrent value {x} in shape {shape}')
|
||||
valid_shape.append(x)
|
||||
size *= x # Faster than np.prod()
|
||||
|
||||
self.shape = tuple(valid_shape)
|
||||
self.ndim = len(self.shape)
|
||||
if self.ndim == 0:
|
||||
# Force (1,) shape for scalar shape
|
||||
self.ndim = 1
|
||||
self.shape = (1,)
|
||||
self.size = size
|
||||
else:
|
||||
raise ValueError('Invalid type to create AShape')
|
||||
|
||||
def axes_arange(self) -> AAxes:
|
||||
"""
|
||||
Returns tuple of axes arange.
|
||||
|
||||
Example (0,1,2) for ndim 3
|
||||
"""
|
||||
return AAxes(range(self.ndim))
|
||||
|
||||
def transpose_by_axes(self, axes) -> 'AShape':
|
||||
"""
|
||||
Same as AShape[axes]
|
||||
|
||||
Returns AShape transposed by axes.
|
||||
|
||||
axes AAxes
|
||||
Iterable(list,tuple,set,generator)
|
||||
"""
|
||||
return AShape(self.shape[axis] for axis in AAxes(axes) )
|
||||
|
||||
def __hash__(self): return self.shape.__hash__()
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, AShape):
|
||||
return self.shape == other.shape
|
||||
elif isinstance(other, Iterable):
|
||||
return self.shape == tuple(other)
|
||||
return False
|
||||
def __iter__(self): return self.shape.__iter__()
|
||||
def __len__(self): return len(self.shape)
|
||||
def __getitem__(self,key):
|
||||
if isinstance(key, Iterable):
|
||||
if isinstance(key, AAxes):
|
||||
if key.is_none_axes():
|
||||
return self
|
||||
return self.transpose_by_axes(key)
|
||||
elif isinstance(key, slice):
|
||||
return AShape(self.shape[key])
|
||||
|
||||
return self.shape[key]
|
||||
|
||||
def __radd__(self, o):
|
||||
if isinstance(o, Iterable):
|
||||
return AShape( tuple(o) + self.shape)
|
||||
else:
|
||||
raise ValueError(f'unable to use type {o.__class__} in AShape append')
|
||||
def __add__(self, o):
|
||||
if isinstance(o, Iterable):
|
||||
return AShape( self.shape + tuple(o) )
|
||||
else:
|
||||
raise ValueError(f'unable to use type {o.__class__} in AShape append')
|
||||
|
||||
|
||||
def __str__(self): return str(self.shape)
|
||||
def __repr__(self): return 'AShape' + self.__str__()
|
||||
|
||||
__all__ = ['AShape']
|
Loading…
Add table
Add a link
Reference in a new issue