ImageProcessor.py refactoring

iperov 2022-05-18 14:24:39 +04:00
parent 2d3d9874bf
commit b3bc4e7345


@@ -57,14 +57,14 @@ class ImageProcessor:
        dtype = self.get_dtype()
        self.to_ufloat32()
        img = orig_img = self._img
        img = np.power(img, np.array([1.0 / blue, 1.0 / green, 1.0 / red], np.float32) )
        np.clip(img, 0, 1.0, out=img)
        if mask is not None:
            mask = self._check_normalize_mask(mask)
            img = ne.evaluate('orig_img*(1-mask) + img*mask')
        self._img = img
        self.to_dtype(dtype)
        return self
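
The masked update at the end of this hunk is the blend pattern repeated throughout the class: the processed image replaces the original only where the mask is non-zero. A minimal stand-alone sketch of that numexpr blend (array names and shapes here are illustrative, not part of the commit):

    # Masked blend as used above: keep orig where mask==0, take processed where mask==1.
    import numpy as np
    import numexpr as ne

    orig_img = np.random.rand(1, 4, 4, 3).astype(np.float32)   # N,H,W,C original
    img      = np.clip(orig_img ** 2.0, 0, 1)                  # some processed version
    mask     = np.zeros((1, 4, 4, 1), np.float32)              # broadcasts over N and C
    mask[:, :2] = 1.0                                          # apply only to the top half

    blended = ne.evaluate('orig_img*(1-mask) + img*mask')
    assert np.allclose(blended[:, 2:], orig_img[:, 2:])        # bottom half untouched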
@@ -86,11 +86,11 @@ class ImageProcessor:
        img = func(img).astype(orig_img.dtype)
        if img.ndim != 4:
            raise Exception('func used in ImageProcessor.apply changed format of image')
        if mask is not None:
            mask = self._check_normalize_mask(mask)
            img = ne.evaluate('orig_img*(1-mask) + img*mask').astype(orig_img.dtype)
        self._img = img
        return self
@@ -178,7 +178,7 @@ class ImageProcessor:
        img = cv2.resize (img, (W_lr,H_lr), interpolation=_cv_inter[interpolation])
        img = cv2.resize (img, (W,H)      , interpolation=_cv_inter[interpolation])
        img = img.reshape( (H,W,N,C) ).transpose( (2,0,1,3) )
        if mask is not None:
            mask = self._check_normalize_mask(mask)
            img = ne.evaluate('orig_img*(1-mask) + img*mask').astype(orig_img.dtype)
@@ -195,7 +195,7 @@ class ImageProcessor:
        power = max(0, power)
        if power == 0:
            return self
        if size % 2 == 0:
            size += 1
@@ -206,7 +206,7 @@ class ImageProcessor:
        N,H,W,C = img.shape
        img = img.transpose( (1,2,0,3) ).reshape( (H,W,N*C) )
        kernel = np.zeros( (size, size), dtype=np.float32)
        kernel[ size//2, size//2] = 1.0
        box_filter = np.ones( (size, size), dtype=np.float32) / (size**2)
@@ -215,15 +215,15 @@ class ImageProcessor:
        img = np.clip(img, 0, 1, out=img)
        img = img.reshape( (H,W,N,C) ).transpose( (2,0,1,3) )
        if mask is not None:
            mask = self._check_normalize_mask(mask)
            img = ne.evaluate('orig_img*(1-mask) + img*mask')
        self._img = img
        self.to_dtype(dtype)
        return self

    def gaussian_sharpen(self, sigma : float, power : float, mask = None) -> 'ImageProcessor':
        """
        sigma   float
@@ -241,16 +241,16 @@ class ImageProcessor:
        N,H,W,C = img.shape
        img = img.transpose( (1,2,0,3) ).reshape( (H,W,N*C) )
        img = cv2.addWeighted(img, 1.0 + power,
                              cv2.GaussianBlur(img, (0, 0), sigma), -power, 0)
        img = np.clip(img, 0, 1, out=img)
        img = img.reshape( (H,W,N,C) ).transpose( (2,0,1,3) )
        if mask is not None:
            mask = self._check_normalize_mask(mask)
            img = ne.evaluate('orig_img*(1-mask) + img*mask')
        self._img = img
        self.to_dtype(dtype)
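
This is classic unsharp masking: a Gaussian-blurred copy is subtracted from an amplified original via cv2.addWeighted. A self-contained sketch under the same conventions (values are made up):

    # Unsharp mask: img*(1+power) + blur(img)*(-power).
    import cv2
    import numpy as np

    img = np.random.rand(64, 64, 3).astype(np.float32)
    power, sigma = 1.0, 2.0

    sharpened = cv2.addWeighted(img, 1.0 + power,
                                cv2.GaussianBlur(img, (0, 0), sigma), -power, 0)
    sharpened = np.clip(sharpened, 0, 1)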
@@ -268,7 +268,7 @@ class ImageProcessor:
        opacity = np.float32( min(1, max(0, opacity)) )
        if opacity == 0:
            return self
        dtype = self.get_dtype()
        self.to_ufloat32()
@@ -280,19 +280,19 @@ class ImageProcessor:
        img_blur = cv2.GaussianBlur(img, (0,0), sigma)
        f32_1 = np.float32(1.0)
        img = ne.evaluate('img*(f32_1-opacity) + img_blur*opacity')
        img = np.clip(img, 0, 1, out=img)
        img = img.reshape( (H,W,N,C) ).transpose( (2,0,1,3) )
        if mask is not None:
            mask = self._check_normalize_mask(mask)
            img = ne.evaluate('orig_img*(1-mask) + img*mask')
        self._img = img
        self.to_dtype(dtype)
        return self

    def median_blur(self, size : int, opacity : float = 1.0, mask = None) -> 'ImageProcessor':
        """
        size     int    median kernel size
@@ -302,7 +302,7 @@ class ImageProcessor:
        if size % 2 == 0:
            size += 1
        size = max(1, size)
        opacity = min(1, max(0, opacity))
        if opacity == 0:
            return self
@@ -320,29 +320,29 @@ class ImageProcessor:
        img = ne.evaluate('img*(f32_1-opacity) + img_blur*opacity')
        img = np.clip(img, 0, 1, out=img)
        img = img.reshape( (H,W,N,C) ).transpose( (2,0,1,3) )
        if mask is not None:
            mask = self._check_normalize_mask(mask)
            img = ne.evaluate('orig_img*(1-mask) + img*mask')
        self._img = img
        self.to_dtype(dtype)
        return self

    def motion_blur( self, size, angle, mask=None ):
        """
        size    [1..]
        angle   degrees
        mask    H,W
                H,W,C
                N,H,W,C int/float 0-1 will be applied
        """
        if size % 2 == 0:
            size += 1
        dtype = self.get_dtype()
        self.to_ufloat32()
@@ -350,25 +350,25 @@ class ImageProcessor:
        N,H,W,C = img.shape
        img = img.transpose( (1,2,0,3) ).reshape( (H,W,N*C) )
        k = np.zeros((size, size), dtype=np.float32)
        k[ (size-1)// 2 , :] = np.ones(size, dtype=np.float32)
        k = cv2.warpAffine(k, cv2.getRotationMatrix2D( (size / 2 -0.5 , size / 2 -0.5 ) , angle, 1.0), (size, size) )
        k = k * ( 1.0 / np.sum(k) )
        img = cv2.filter2D(img, -1, k)
        img = np.clip(img, 0, 1, out=img)
        img = img.reshape( (H,W,N,C) ).transpose( (2,0,1,3) )
        if mask is not None:
            mask = self._check_normalize_mask(mask)
            img = ne.evaluate('orig_img*(1-mask) + img*mask')
        self._img = img
        self.to_dtype(dtype)
        return self
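
The motion-blur kernel is a centered horizontal line of ones rotated by `angle` degrees and normalized to sum to 1, so the convolution smears pixels along that direction without changing overall brightness. The kernel construction in isolation (a sketch with arbitrary parameters):

    # Build a normalized line kernel rotated by `angle` degrees.
    import cv2
    import numpy as np

    size, angle = 9, 30.0
    k = np.zeros((size, size), np.float32)
    k[(size - 1) // 2, :] = 1.0                                 # horizontal line through the center row
    M = cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0)
    k = cv2.warpAffine(k, M, (size, size))                      # rotate the line
    k /= np.sum(k)                                              # normalize so brightness is preserved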
    def erode_blur(self, erode : int, blur : int, fade_to_border : bool = False) -> 'ImageProcessor':
        """
        apply erode and blur to the mask image
@@ -413,25 +413,25 @@ class ImageProcessor:
        self._img = img
        return self

    def levels(self, in_bwg_out_bw, mask = None) -> 'ImageProcessor':
        """
        in_bwg_out_bw   ( [N],[C], 5)
                        optional per channel/batch input black,white,gamma and out black,white floats

                        in black  = [0.0 .. 1.0]   default:0.0
                        in white  = [0.0 .. 1.0]   default:1.0
                        in gamma  = [0.0 .. 2.0++] default:1.0
                        out black = [0.0 .. 1.0]   default:0.0
                        out white = [0.0 .. 1.0]   default:1.0
        """
        dtype = self.get_dtype()
        self.to_ufloat32()
        img = orig_img = self._img
        N,H,W,C = img.shape
        v = np.array(in_bwg_out_bw, np.float32)
        if v.ndim == 1:
@@ -442,151 +442,151 @@ class ImageProcessor:
            v = np.tile(v, (N,1,1))
        elif v.ndim > 3:
            raise ValueError('in_bwg_out_bw.ndim > 3')
        VN, VC, VD = v.shape
        if N != VN or C != VC or VD != 5:
            raise ValueError('wrong in_bwg_out_bw size. Must have 5 floats at last dim.')
        v = v[:,None,None,:,:]
        img = np.clip( (img - v[...,0]) / (v[...,1] - v[...,0]), 0, 1 )
        img = ( img ** (1/v[...,2]) ) * (v[...,4] - v[...,3]) + v[...,3]
        img = np.clip(img, 0, 1, out=img)
        if mask is not None:
            mask = self._check_normalize_mask(mask)
            img = ne.evaluate('orig_img*(1-mask) + img*mask')
        self._img = img
        self.to_dtype(dtype)
        return self
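
In scalar terms, levels maps each value through clip((x - in_black) / (in_white - in_black), 0, 1) ** (1/gamma) and then rescales the result into [out_black, out_white]. A quick numeric check with arbitrary numbers:

    # Worked example of the levels mapping for a single value.
    import numpy as np

    x = np.float32(0.5)
    in_black, in_white, gamma, out_black, out_white = 0.1, 0.9, 2.0, 0.0, 1.0

    y = np.clip((x - in_black) / (in_white - in_black), 0, 1)        # 0.4 / 0.8 = 0.5
    y = (y ** (1.0 / gamma)) * (out_white - out_black) + out_black   # 0.5**0.5 ~= 0.707
    print(y)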
    def hsv(self, h_diff : float, s_diff : float, v_diff : float, mask = None) -> 'ImageProcessor':
        """
        apply HSV modification for BGR image
        h_diff = [-1.0 .. 1.0]
        s_diff = [-1.0 .. 1.0]
        v_diff = [-1.0 .. 1.0]
        """
        dtype = self.get_dtype()
        self.to_ufloat32()
        img = orig_img = self._img
        N,H,W,C = img.shape
        if C != 3:
            raise Exception('Image channels must be == 3')
        img = img.reshape( (N*H,W,C) )
        h, s, v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
        h = ( h + h_diff*360.0 ) % 360
        s += s_diff
        np.clip (s, 0, 1, out=s )
        v += v_diff
        np.clip (v, 0, 1, out=v )
        img = np.clip( cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR) , 0, 1 )
        img = img.reshape( (N,H,W,C) )
        if mask is not None:
            mask = self._check_normalize_mask(mask)
            img = ne.evaluate('orig_img*(1-mask) + img*mask')
        self._img = img
        self.to_dtype(dtype)
        return self
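
For float32 input OpenCV keeps hue in 0..360 and saturation/value in 0..1, which is why only the hue offset is scaled by 360 and wrapped with a modulo. A hedged single-image sketch of the same steps:

    # Hue/saturation/value shift on one float32 BGR image.
    import cv2
    import numpy as np

    bgr = np.random.rand(32, 32, 3).astype(np.float32)
    h_diff, s_diff, v_diff = 0.1, 0.2, -0.1

    h, s, v = cv2.split(cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV))
    h = (h + h_diff * 360.0) % 360          # hue wraps around the color circle
    s = np.clip(s + s_diff, 0, 1)
    v = np.clip(v + v_diff, 0, 1)
    out = np.clip(cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR), 0, 1)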
    def to_lab(self) -> 'ImageProcessor':
        """
        """
        img = self._img
        N,H,W,C = img.shape
        if C != 3:
            raise Exception('Image channels must be == 3')
        img = img.reshape( (N*H,W,C) )
        img = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        img = img.reshape( (N,H,W,C) )
        self._img = img
        return self

    def from_lab(self) -> 'ImageProcessor':
        """
        """
        img = self._img
        N,H,W,C = img.shape
        if C != 3:
            raise Exception('Image channels must be == 3')
        img = img.reshape( (N*H,W,C) )
        img = cv2.cvtColor(img, cv2.COLOR_LAB2BGR)
        img = img.reshape( (N,H,W,C) )
        self._img = img
        return self
    def jpeg_recompress(self, quality : int, mask = None ) -> 'ImageProcessor':
        """
        quality  0-100
        """
        dtype = self.get_dtype()
        self.to_uint8()
        img = orig_img = self._img
        _,_,_,C = img.shape
        if C != 3:
            raise Exception('Image channels must be == 3')
        new_imgs = []
        for x in img:
            ret, result = cv2.imencode('.jpg', x, [int(cv2.IMWRITE_JPEG_QUALITY), quality] )
            if not ret:
                raise Exception('unable to compress jpeg')
            x = cv2.imdecode(result, flags=cv2.IMREAD_UNCHANGED)
            new_imgs.append(x)
        img = np.array(new_imgs)
        if mask is not None:
            mask = self._check_normalize_mask(mask)
            img = ne.evaluate('orig_img*(1-mask) + img*mask').astype(np.uint8)
        self._img = img
        self.to_dtype(dtype)
        return self
    def patch_to_batch(self, patch_size : int) -> 'ImageProcessor':
        img = self._img
        N,H,W,C = img.shape
        OH, OW = H // patch_size, W // patch_size
        img = img.reshape(N,OH,patch_size,OW,patch_size,C)
        img = img.transpose(0,2,4,1,3,5)
        img = img.reshape(N*patch_size*patch_size,OH,OW,C)
        self._img = img
        return self
    def patch_from_batch(self, patch_size : int) -> 'ImageProcessor':
        img = self._img
        N,H,W,C = img.shape
        ON = N//(patch_size*patch_size)
        img = img.reshape(ON,patch_size,patch_size,H,W,C )
        img = img.transpose(0,3,1,4,2,5)
        img = img.reshape(ON,H*patch_size,W*patch_size,C )
        self._img = img
        return self
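
patch_to_batch moves every offset inside a patch_size x patch_size block into its own batch entry, turning (N,H,W,C) into (N*patch_size^2, H/patch_size, W/patch_size, C); patch_from_batch inverts the reshapes. A round-trip check with illustrative shapes:

    # Round-trip of the patch_to_batch / patch_from_batch reshapes.
    import numpy as np

    N, H, W, C, ps = 2, 8, 8, 3, 4
    x = np.random.rand(N, H, W, C).astype(np.float32)

    b = x.reshape(N, H // ps, ps, W // ps, ps, C).transpose(0, 2, 4, 1, 3, 5)
    b = b.reshape(N * ps * ps, H // ps, W // ps, C)              # (32, 2, 2, 3)

    y = b.reshape(N, ps, ps, H // ps, W // ps, C).transpose(0, 3, 1, 4, 2, 5)
    y = y.reshape(N, H, W, C)
    assert np.array_equal(x, y)                                  # lossless round trip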
    def rct(self, like : np.ndarray, mask : np.ndarray = None, like_mask : np.ndarray = None, mask_cutoff=0.5) -> 'ImageProcessor':
        """
@@ -596,7 +596,7 @@ class ImageProcessor:
        mask(None)        np.ndarray [N][HW][1C] np.uint8/np.float32
        like_mask(None)   np.ndarray [N][HW][1C] np.uint8/np.float32
        mask_cutoff(0.5)  float
        masks are used to limit the space where color statistics will be computed to adjust the image
@@ -610,41 +610,41 @@ class ImageProcessor:
        like_for_stat = ImageProcessor(like).to_ufloat32().to_lab().get_image('NHWC')
        if like_mask is not None:
            like_mask = ImageProcessor(like_mask).to_ufloat32().ch(1).get_image('NHW')
            like_for_stat = like_for_stat.copy()
            like_for_stat[like_mask < mask_cutoff] = [0,0,0]
        img_for_stat = img = self._img
        if mask is not None:
            mask = ImageProcessor(mask).to_ufloat32().ch(1).get_image('NHW')
            img_for_stat = img_for_stat.copy()
            img_for_stat[mask < mask_cutoff] = [0,0,0]
        source_l_mean, source_l_std, source_a_mean, source_a_std, source_b_mean, source_b_std, \
            = img_for_stat[...,0].mean((1,2), keepdims=True), img_for_stat[...,0].std((1,2), keepdims=True), img_for_stat[...,1].mean((1,2), keepdims=True), img_for_stat[...,1].std((1,2), keepdims=True), img_for_stat[...,2].mean((1,2), keepdims=True), img_for_stat[...,2].std((1,2), keepdims=True)
        like_l_mean, like_l_std, like_a_mean, like_a_std, like_b_mean, like_b_std, \
            = like_for_stat[...,0].mean((1,2), keepdims=True), like_for_stat[...,0].std((1,2), keepdims=True), like_for_stat[...,1].mean((1,2), keepdims=True), like_for_stat[...,1].std((1,2), keepdims=True), like_for_stat[...,2].mean((1,2), keepdims=True), like_for_stat[...,2].std((1,2), keepdims=True)
        # not as in the paper: scale by the standard deviations using reciprocal of paper proposed factor
        source_l = img[...,0]
        source_l = ne.evaluate('(source_l - source_l_mean) * like_l_std / source_l_std + like_l_mean')
        source_a = img[...,1]
        source_a = ne.evaluate('(source_a - source_a_mean) * like_a_std / source_a_std + like_a_mean')
        source_b = img[...,2]
        source_b = ne.evaluate('(source_b - source_b_mean) * like_b_std / source_b_std + like_b_mean')
        np.clip(source_l,    0, 100, out=source_l)
        np.clip(source_a, -127, 127, out=source_a)
        np.clip(source_b, -127, 127, out=source_b)
        self._img = np.stack([source_l,source_a,source_b], -1)
        self.from_lab()
        self.to_dtype(dtype)
        return self
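
rct is a Reinhard-style color transfer in LAB: each channel is re-centered and re-scaled so its mean and standard deviation match the reference ('like') image, i.e. out = (src - mean_src) * std_like / std_src + mean_like. A one-channel numpy sketch with made-up data:

    # Per-channel statistic matching as in rct (single channel).
    import numpy as np

    src  = np.random.rand(1, 16, 16).astype(np.float32)          # N,H,W channel
    like = np.random.rand(1, 16, 16).astype(np.float32) * 0.5 + 0.25

    s_mean, s_std = src.mean((1, 2), keepdims=True),  src.std((1, 2), keepdims=True)
    l_mean, l_std = like.mean((1, 2), keepdims=True), like.std((1, 2), keepdims=True)

    out = (src - s_mean) * l_std / s_std + l_mean                 # now carries like's statistics
    assert np.allclose(out.mean((1, 2), keepdims=True), l_mean, atol=1e-5)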
    def rotate90(self) -> 'ImageProcessor':
        self._img = np.rot90(self._img, k=1, axes=(1,2) )
        return self
@@ -847,7 +847,7 @@ class ImageProcessor:
        else:
            raise ValueError('unsupported dtype')

-    def to_ufloat32(self) -> 'ImageProcessor':
+    def to_ufloat32(self, as_tanh=False) -> 'ImageProcessor':
""" """
Convert to uniform float32 Convert to uniform float32
if current image dtype uint8, then image will be divided by / 255.0 if current image dtype uint8, then image will be divided by / 255.0
@ -855,7 +855,11 @@ class ImageProcessor:
""" """
if self._img.dtype == np.uint8: if self._img.dtype == np.uint8:
self._img = self._img.astype(np.float32) self._img = self._img.astype(np.float32)
self._img /= 255.0 if as_tanh:
self._img /= 127.5
self._img -= 1.0
else:
self._img /= 255.0
return self return self
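
The new as_tanh flag maps uint8 input into [-1, 1] (x / 127.5 - 1) instead of [0, 1] (x / 255), the normalization expected by tanh-output networks. Assuming the constructor and get_image behave as they do elsewhere in this file, usage would look like:

    # Usage sketch of the new as_tanh option (assumes ImageProcessor as shown above).
    import numpy as np

    frame = np.random.randint(0, 256, (64, 64, 3), np.uint8)

    img01 = ImageProcessor(frame).to_ufloat32().get_image('NHWC')               # values in [0, 1]
    img11 = ImageProcessor(frame).to_ufloat32(as_tanh=True).get_image('NHWC')   # values in [-1, 1]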
@@ -876,26 +880,26 @@ class ImageProcessor:
    def _check_normalize_mask(self, mask : np.ndarray):
        N,H,W,C = self._img.shape
        if mask.ndim == 2:
            mask = mask[None,...,None]
        elif mask.ndim == 3:
            mask = mask[None,...]
        if mask.ndim != 4:
            raise ValueError('mask must have ndim == 4')
        MN, MH, MW, MC = mask.shape
        if H != MH or W != MW:
            raise ValueError('mask H,W, mismatch')
        if MN != 1 and N != MN:
            raise ValueError(f'mask N dim must be 1 or == {N}')
        if MC != 1 and C != MC:
            raise ValueError(f'mask C dim must be 1 or == {C}')
        return mask
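
In effect _check_normalize_mask promotes an H,W or H,W,C mask to N,H,W,C and only requires the N and C dims to be 1 or to match the image, so a single 2-D mask broadcasts over the whole batch in the blend expressions above. A quick shape illustration:

    # Mask shape normalization and broadcasting (shapes only).
    import numpy as np

    img  = np.zeros((4, 64, 64, 3), np.float32)   # N,H,W,C
    mask = np.ones((64, 64), np.float32)          # plain H,W mask

    mask = mask[None, ..., None]                  # -> (1, 64, 64, 1)
    assert np.broadcast_shapes(img.shape, mask.shape) == (4, 64, 64, 3)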
class Interpolation(IntEnum):
    NEAREST = 0,
    LINEAR = 1