mirror of
https://github.com/iperov/DeepFaceLive
synced 2025-08-14 02:37:01 -07:00
update xlib.avecl
This commit is contained in:
parent
63adc2995e
commit
2d401f47f8
7 changed files with 337 additions and 27 deletions
|
@ -1,6 +1,9 @@
|
|||
from collections import Iterable
|
||||
from typing import Tuple, List
|
||||
|
||||
from .AAxes import AAxes
|
||||
|
||||
|
||||
class AShape(Iterable):
|
||||
__slots__ = ['shape','size','ndim']
|
||||
|
||||
|
@ -10,8 +13,10 @@ class AShape(Iterable):
|
|||
|
||||
arguments
|
||||
|
||||
shape AShape
|
||||
Iterable
|
||||
shape AShape
|
||||
Iterable
|
||||
|
||||
AShape cannot be scalar shape, thus minimal AShape is (1,)
|
||||
|
||||
can raise ValueError during the construction
|
||||
"""
|
||||
|
@ -45,7 +50,13 @@ class AShape(Iterable):
|
|||
self.size = size
|
||||
else:
|
||||
raise ValueError('Invalid type to create AShape')
|
||||
|
||||
|
||||
def copy(self) -> 'AShape':
|
||||
return AShape(self)
|
||||
|
||||
def as_list(self) -> List[int]:
|
||||
return list(self.shape)
|
||||
|
||||
def axes_arange(self) -> AAxes:
|
||||
"""
|
||||
Returns tuple of axes arange.
|
||||
|
@ -53,6 +64,35 @@ class AShape(Iterable):
|
|||
Example (0,1,2) for ndim 3
|
||||
"""
|
||||
return AAxes(range(self.ndim))
|
||||
|
||||
def replaced_axes(self, axes, dims) -> 'AShape':
|
||||
"""
|
||||
returns new AShape where axes replaced with new dims
|
||||
"""
|
||||
new_shape = list(self.shape)
|
||||
ndim = self.ndim
|
||||
for axis, dim in zip(axes, dims):
|
||||
if axis < 0:
|
||||
axis = ndim + axis
|
||||
if axis < 0 or axis >= ndim:
|
||||
raise ValueError(f'invalid axis value {axis}')
|
||||
|
||||
new_shape[axis] = dim
|
||||
return AShape(new_shape)
|
||||
|
||||
|
||||
def split(self, axis) -> Tuple['AShape', 'AShape']:
|
||||
"""
|
||||
split AShape at specified axis
|
||||
|
||||
returns two AShape before+exclusive and inclusive+after
|
||||
"""
|
||||
if axis < 0:
|
||||
axis = self.ndim + axis
|
||||
if axis < 0 or axis >= self.ndim:
|
||||
raise ValueError(f'invalid axis value {axis}')
|
||||
|
||||
return self[:axis], self[axis:]
|
||||
|
||||
def transpose_by_axes(self, axes) -> 'AShape':
|
||||
"""
|
||||
|
@ -100,4 +140,4 @@ class AShape(Iterable):
|
|||
def __str__(self): return str(self.shape)
|
||||
def __repr__(self): return 'AShape' + self.__str__()
|
||||
|
||||
__all__ = ['AShape']
|
||||
__all__ = ['AShape']
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue