refactoring

This commit is contained in:
iperov 2021-10-26 20:06:50 +04:00
parent d90ec2d024
commit 8b385f6d80
11 changed files with 263 additions and 137 deletions

View file

@ -39,29 +39,31 @@ class FaceWarper:
self._img_to_face_uni_mat = img_to_face_uni_mat
self._face_to_img_uni_mat = img_to_face_uni_mat.invert()
if rnd_state is None:
rnd_state = np.random
self._rnd_state_state = rnd_state.get_state()
rnd_state = np.random.RandomState()
rnd_state.set_state(rnd_state.get_state() if rnd_state is not None else np.random.RandomState().get_state())
self._align_rot_deg = rnd_state.uniform(*align_rot_deg) if isinstance(align_rot_deg, Iterable) else align_rot_deg
self._align_scale = rnd_state.uniform(*align_scale) if isinstance(align_scale, Iterable) else align_scale
self._align_tx = rnd_state.uniform(*align_tx) if isinstance(align_tx, Iterable) else align_tx
self._align_ty = rnd_state.uniform(*align_ty) if isinstance(align_ty, Iterable) else align_ty
self._align_rot_deg = rnd_state.uniform(*align_rot_deg) if isinstance(align_rot_deg, Iterable) else align_rot_deg
self._align_scale = rnd_state.uniform(*align_scale) if isinstance(align_scale, Iterable) else align_scale
self._align_tx = rnd_state.uniform(*align_tx) if isinstance(align_tx, Iterable) else align_tx
self._align_ty = rnd_state.uniform(*align_ty) if isinstance(align_ty, Iterable) else align_ty
self._rw_grid_cell_count = rnd_state.randint(*rw_grid_cell_count) if isinstance(rw_grid_cell_count, Iterable) else rw_grid_cell_count
self._rw_grid_rot_deg = rnd_state.uniform(*rw_grid_rot_deg) if isinstance(rw_grid_rot_deg, Iterable) else rw_grid_rot_deg
self._rw_grid_scale = rnd_state.uniform(*rw_grid_scale) if isinstance(rw_grid_scale, Iterable) else rw_grid_scale
self._rw_grid_tx = rnd_state.uniform(*rw_grid_tx) if isinstance(rw_grid_tx, Iterable) else rw_grid_tx
self._rw_grid_ty = rnd_state.uniform(*rw_grid_ty) if isinstance(rw_grid_ty, Iterable) else rw_grid_ty
self._rw_grid_rot_deg = rnd_state.uniform(*rw_grid_rot_deg) if isinstance(rw_grid_rot_deg, Iterable) else rw_grid_rot_deg
self._rw_grid_scale = rnd_state.uniform(*rw_grid_scale) if isinstance(rw_grid_scale, Iterable) else rw_grid_scale
self._rw_grid_tx = rnd_state.uniform(*rw_grid_tx) if isinstance(rw_grid_tx, Iterable) else rw_grid_tx
self._rw_grid_ty = rnd_state.uniform(*rw_grid_ty) if isinstance(rw_grid_ty, Iterable) else rw_grid_ty
self._rnd_state_state = rnd_state.get_state()
self._cached = {}
def transform(self, img : np.ndarray, out_res : int, random_warp : bool = True) -> np.ndarray:
"""
transform an image. Subsequent calls will output the same result for any img shape and out_res.
transform an image.
img np.ndarray (HWC)
Subsequent calls will output the same result for any img shape and out_res.
img np.ndarray (HWC)
out_res int
out_res int
random_warp(True) bool
"""