mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
19 lines
No EOL
593 B
Python
19 lines
No EOL
593 B
Python
def normalize_channels(img, target_channels):
    """Adjust the channel count of an image array towards ``target_channels``.

    ``img`` must be 2-D ``(H, W)`` or 3-D ``(H, W, C)``; the last axis is
    treated as the channel axis. The conversions performed are:

    * 2-D input with ``target_channels > 0`` gains a trailing axis of size 1
      (and is then tiled by the next rule if more channels are requested).
    * A single channel is tiled ``target_channels`` times (e.g. 1 -> 3).
    * More channels than requested are truncated to the first
      ``target_channels`` (e.g. 4 -> 3).

    Any other combination (e.g. 3 channels with 4 requested) is returned
    unchanged; no channels are synthesised.

    Args:
        img: numpy array with 2 or 3 dimensions.
        target_channels: desired number of channels.

    Returns:
        The resulting array. It may be a view of ``img``; ``img`` itself is
        never modified in place.

    Raises:
        ValueError: if ``img`` is neither 2-D nor 3-D.
    """
    # Local import so this function does not rely on a module-level numpy
    # import being present.
    import numpy as np

    ndim = len(img.shape)
    if ndim == 2:
        c = 0  # no channel axis yet
    elif ndim == 3:
        c = img.shape[2]
    else:
        raise ValueError("normalize: incorrect image dimensions.")

    if c == 0 and target_channels > 0:
        img = img[..., np.newaxis]
        # Record the new axis so the tiling step below also applies to 2-D
        # input; without this a (H, W) image asked for 3 channels came back
        # as (H, W, 1).
        c = 1

    if c == 1 and target_channels > 1:
        img = np.repeat(img, target_channels, -1)
        c = target_channels

    if c > target_channels:
        # Keep only the leading channels.
        img = img[..., :target_channels]

    return img