upd ImageProcessor

This commit is contained in:
iperov 2022-09-15 22:15:35 +04:00
parent 5eb9ad130e
commit 042867d19d

View file

@ -839,19 +839,17 @@ class ImageProcessor:
self._img = self._img.astype(np.uint8)
return self
def to_dtype(self, dtype) -> 'ImageProcessor':
def to_dtype(self, dtype, from_tanh=False) -> 'ImageProcessor':
if dtype == np.float32:
return self.to_ufloat32()
return self.to_ufloat32(from_tanh=from_tanh)
elif dtype == np.uint8:
return self.to_uint8()
return self.to_uint8(from_tanh=from_tanh)
else:
raise ValueError('unsupported dtype')
def to_ufloat32(self, as_tanh=False) -> 'ImageProcessor':
def to_ufloat32(self, as_tanh=False, from_tanh=False) -> 'ImageProcessor':
"""
Convert to uniform float32
if current image dtype uint8, then image will be divided by / 255.0
otherwise no operation
"""
if self._img.dtype == np.uint8:
self._img = self._img.astype(np.float32)
@ -860,10 +858,14 @@ class ImageProcessor:
self._img -= 1.0
else:
self._img /= 255.0
elif self._img.dtype in [np.float32, np.float64]:
if from_tanh:
self._img += 1.0
self._img /= 2.0
return self
def to_uint8(self) -> 'ImageProcessor':
def to_uint8(self, from_tanh=False) -> 'ImageProcessor':
"""
Convert to uint8
@ -872,6 +874,10 @@ class ImageProcessor:
img = self._img
if img.dtype in [np.float32, np.float64]:
if from_tanh:
img += 1.0
img /= 2.0
img *= 255.0
np.clip(img, 0, 255, out=img)