refactoring

This commit is contained in:
iperov 2021-10-20 18:05:56 +04:00
parent 7463515bfc
commit 78d80f9c5c
5 changed files with 135 additions and 67 deletions

View file

@ -1,17 +0,0 @@
import numpy as np
def get_NHWC_shape(img : np.ndarray):
"""
returns NHWC shape where missed dims are 1
"""
ndim = img.ndim
if ndim not in [2,3,4]:
raise ValueError(f'img.ndim must be 2,3,4, not {ndim}.')
if ndim == 2:
N, (H,W), C = 1, img.shape, 1
elif ndim == 3:
N, (H,W,C) = 1, img.shape
else:
N,H,W,C = img.shape
return N,H,W,C