diff --git a/core/imagelib/warp.py b/core/imagelib/warp.py index 37abc36..385d8f7 100644 --- a/core/imagelib/warp.py +++ b/core/imagelib/warp.py @@ -1,7 +1,140 @@ import numpy as np +import numpy.linalg as npla import cv2 from core import randomex + +def mls_rigid_deformation(vy, vx, p, q, alpha=1.0, eps=1e-8): + """ Rigid deformation + + Parameters + ---------- + vx, vy: ndarray + coordinate grid, generated by np.meshgrid(gridX, gridY) + p: ndarray + an array with size [n, 2], original control points + q: ndarray + an array with size [n, 2], final control points + alpha: float + parameter used by weights + eps: float + epsilon + + Return + ------ + A deformed image. + """ + # Change (x, y) to (row, col) + q = np.ascontiguousarray(q[:, [1, 0]].astype(np.int16)) + p = np.ascontiguousarray(p[:, [1, 0]].astype(np.int16)) + + # Exchange p and q and hence we transform destination pixels to the corresponding source pixels. + p, q = q, p + + grow = vx.shape[0] # grid rows + gcol = vx.shape[1] # grid cols + ctrls = p.shape[0] # control points + + # Compute + reshaped_p = p.reshape(ctrls, 2, 1, 1) # [ctrls, 2, 1, 1] + reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol))) # [2, grow, gcol] + + w = 1.0 / (np.sum((reshaped_p - reshaped_v).astype(np.float32) ** 2, axis=1) + eps) ** alpha # [ctrls, grow, gcol] + w /= np.sum(w, axis=0, keepdims=True) # [ctrls, grow, gcol] + + pstar = np.zeros((2, grow, gcol), np.float32) + for i in range(ctrls): + pstar += w[i] * reshaped_p[i] # [2, grow, gcol] + + vpstar = reshaped_v - pstar # [2, grow, gcol] + reshaped_vpstar = vpstar.reshape(2, 1, grow, gcol) # [2, 1, grow, gcol] + neg_vpstar_verti = vpstar[[1, 0],...] # [2, grow, gcol] + neg_vpstar_verti[1,...] = -neg_vpstar_verti[1,...] + reshaped_neg_vpstar_verti = neg_vpstar_verti.reshape(2, 1, grow, gcol) # [2, 1, grow, gcol] + mul_right = np.concatenate((reshaped_vpstar, reshaped_neg_vpstar_verti), axis=1) # [2, 2, grow, gcol] + reshaped_mul_right = mul_right.reshape(2, 2, grow, gcol) # [2, 2, grow, gcol] + + # Calculate q + reshaped_q = q.reshape((ctrls, 2, 1, 1)) # [ctrls, 2, 1, 1] + qstar = np.zeros((2, grow, gcol), np.float32) + for i in range(ctrls): + qstar += w[i] * reshaped_q[i] # [2, grow, gcol] + + temp = np.zeros((grow, gcol, 2), np.float32) + for i in range(ctrls): + phat = reshaped_p[i] - pstar # [2, grow, gcol] + reshaped_phat = phat.reshape(1, 2, grow, gcol) # [1, 2, grow, gcol] + reshaped_w = w[i].reshape(1, 1, grow, gcol) # [1, 1, grow, gcol] + neg_phat_verti = phat[[1, 0]] # [2, grow, gcol] + neg_phat_verti[1] = -neg_phat_verti[1] + reshaped_neg_phat_verti = neg_phat_verti.reshape(1, 2, grow, gcol) # [1, 2, grow, gcol] + mul_left = np.concatenate((reshaped_phat, reshaped_neg_phat_verti), axis=0) # [2, 2, grow, gcol] + + A = np.matmul((reshaped_w * mul_left).transpose(2, 3, 0, 1), + reshaped_mul_right.transpose(2, 3, 0, 1)) # [grow, gcol, 2, 2] + + qhat = reshaped_q[i] - qstar # [2, grow, gcol] + reshaped_qhat = qhat.reshape(1, 2, grow, gcol).transpose(2, 3, 0, 1) # [grow, gcol, 1, 2] + + # Get final image transfomer -- 3-D array + temp += np.matmul(reshaped_qhat, A).reshape(grow, gcol, 2) # [grow, gcol, 2] + + temp = temp.transpose(2, 0, 1) # [2, grow, gcol] + + normed_temp = np.linalg.norm(temp, axis=0, keepdims=True) # [1, grow, gcol] + normed_vpstar = np.linalg.norm(vpstar, axis=0, keepdims=True) # [1, grow, gcol] + nan_mask = normed_temp[0]==0 + + transformers = np.true_divide(temp, normed_temp, out=np.zeros_like(temp), where= ~nan_mask) * normed_vpstar + qstar + # fix nan values + nan_mask_flat = np.flatnonzero(nan_mask) + nan_mask_anti_flat = np.flatnonzero(~nan_mask) + transformers[0][nan_mask] = np.interp(nan_mask_flat, nan_mask_anti_flat, transformers[0][~nan_mask]) + transformers[1][nan_mask] = np.interp(nan_mask_flat, nan_mask_anti_flat, transformers[1][~nan_mask]) + + return transformers + + +def gen_pts(W, H, rnd_state=None): + + if rnd_state is None: + rnd_state = np.random + + min_pts, max_pts = 4, 16 + n_pts = rnd_state.randint(min_pts, max_pts) + + min_radius_per = 0.00 + max_radius_per = 0.10 + pts = [] + + for i in range(max_pts): + while True: + x, y = rnd_state.randint(W), rnd_state.randint(H) + rad = min_radius_per + rnd_state.rand()*(max_radius_per-min_radius_per) + + intersect = False + for px,py,prad,_,_ in pts: + + dist = npla.norm([x-px, y-py]) + if dist <= (rad+prad)*2: + intersect = True + break + if intersect: + continue + + angle = rnd_state.rand()*(2*np.pi) + x2 = int(x+np.cos(angle)*W*rad) + y2 = int(y+np.sin(angle)*H*rad) + + break + pts.append( (x,y,rad, x2,y2) ) + + pts1 = np.array( [ [pt[0],pt[1]] for pt in pts ] ) + pts2 = np.array( [ [pt[-2],pt[-1]] for pt in pts ] ) + + return pts1, pts2 + + def gen_warp_params (w, flip=False, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_state=None ): if rnd_state is None: rnd_state = np.random @@ -17,22 +150,28 @@ def gen_warp_params (w, flip=False, rotation_range=[-10,10], scale_range=[-0.5, ty = rnd_state.uniform( ty_range[0], ty_range[1] ) p_flip = flip and rnd_state.randint(10) < 4 - #random warp by grid + #random warp V1 cell_size = [ w // (2**i) for i in range(1,4) ] [ rnd_state.randint(3) ] cell_count = w // cell_size + 1 - grid_points = np.linspace( 0, w, cell_count) mapx = np.broadcast_to(grid_points, (cell_count, cell_count)).copy() mapy = mapx.T - mapx[1:-1,1:-1] = mapx[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2) )*(cell_size*0.24) mapy[1:-1,1:-1] = mapy[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2) )*(cell_size*0.24) - half_cell_size = cell_size // 2 - mapx = cv2.resize(mapx, (w+cell_size,)*2 )[half_cell_size:-half_cell_size,half_cell_size:-half_cell_size].astype(np.float32) mapy = cv2.resize(mapy, (w+cell_size,)*2 )[half_cell_size:-half_cell_size,half_cell_size:-half_cell_size].astype(np.float32) - + ############## + + # random warp V2 + # pts1, pts2 = gen_pts(w, w, rnd_state) + # gridX = np.arange(w, dtype=np.int16) + # gridY = np.arange(w, dtype=np.int16) + # vy, vx = np.meshgrid(gridX, gridY) + # drigid = mls_rigid_deformation(vy, vx, pts1, pts2) + # mapy, mapx = drigid.astype(np.float32) + ################ + #random transform random_transform_mat = cv2.getRotationMatrix2D((w // 2, w // 2), rotation, scale) random_transform_mat[:, 2] += (tx*w, ty*w)