diff --git a/xlib/image/ImageProcessor.py b/xlib/image/ImageProcessor.py index 6be6bae..a730625 100644 --- a/xlib/image/ImageProcessor.py +++ b/xlib/image/ImageProcessor.py @@ -839,19 +839,17 @@ class ImageProcessor: self._img = self._img.astype(np.uint8) return self - def to_dtype(self, dtype) -> 'ImageProcessor': + def to_dtype(self, dtype, from_tanh=False) -> 'ImageProcessor': if dtype == np.float32: - return self.to_ufloat32() + return self.to_ufloat32(from_tanh=from_tanh) elif dtype == np.uint8: - return self.to_uint8() + return self.to_uint8(from_tanh=from_tanh) else: raise ValueError('unsupported dtype') - def to_ufloat32(self, as_tanh=False) -> 'ImageProcessor': + def to_ufloat32(self, as_tanh=False, from_tanh=False) -> 'ImageProcessor': """ Convert to uniform float32 - if current image dtype uint8, then image will be divided by / 255.0 - otherwise no operation """ if self._img.dtype == np.uint8: self._img = self._img.astype(np.float32) @@ -860,10 +858,14 @@ class ImageProcessor: self._img -= 1.0 else: self._img /= 255.0 + elif self._img.dtype in [np.float32, np.float64]: + if from_tanh: + self._img += 1.0 + self._img /= 2.0 return self - def to_uint8(self) -> 'ImageProcessor': + def to_uint8(self, from_tanh=False) -> 'ImageProcessor': """ Convert to uint8 @@ -872,6 +874,10 @@ class ImageProcessor: img = self._img if img.dtype in [np.float32, np.float64]: + if from_tanh: + img += 1.0 + img /= 2.0 + img *= 255.0 np.clip(img, 0, 255, out=img)