diff --git a/main.py b/main.py index 55a0107..c3772ab 100644 --- a/main.py +++ b/main.py @@ -79,12 +79,12 @@ def main(): p.set_defaults(func=train_FaceAligner) def train_CTSOT(args): - from apps.trainers.CTSOT.CTSOTTrainerApp import run_app - run_app(userdata_path=Path(args.userdata_dir), faceset_path=Path(args.faceset_path)) + from apps.trainers.CTSOT.CTSOTTrainerApp import CTSOTTrainerApp + CTSOTTrainerApp(workspace_path=Path(args.workspace_dir), faceset_path=Path(args.faceset_path)) p = train_parsers.add_parser('CTSOT') - p.add_argument('--userdata-dir', default=None, action=fixPathAction, help="Directory to save app data.") - p.add_argument('--faceset-path', default=None, action=fixPathAction, help=".dfs path") + p.add_argument('--workspace-dir', default=None, action=fixPathAction, help="Workspace directory.") + p.add_argument('--faceset-path', default=None, action=fixPathAction, help=".dfs faceset path") p.set_defaults(func=train_CTSOT) def bad_args(arguments): diff --git a/modelhub/DFLive/DFMModel.py b/modelhub/DFLive/DFMModel.py index 0668b41..781325b 100644 --- a/modelhub/DFLive/DFMModel.py +++ b/modelhub/DFLive/DFMModel.py @@ -58,7 +58,7 @@ def get_available_devices() -> List[ORTDeviceInfo]: class DFMModel: def __init__(self, model_path : Path, device : ORTDeviceInfo = None): if device is None: - device = lib_ort.get_cpu_device() + device = lib_ort.get_cpu_device_info() self._model_path = model_path sess = self._sess = lib_ort.InferenceSession_with_device(str(model_path), device) diff --git a/modelhub/torch/CTSOT/CTSOT.py b/modelhub/torch/CTSOT/CTSOT.py new file mode 100644 index 0000000..9d50ba7 --- /dev/null +++ b/modelhub/torch/CTSOT/CTSOT.py @@ -0,0 +1,121 @@ +from pathlib import Path +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from xlib.file import SplittedFile +from xlib.torch import TorchDeviceInfo, get_cpu_device_info + + +class ResBlock(nn.Module): + def __init__(self, ch): + super().__init__() + self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1) + + def forward(self, inp): + x = inp + x = F.leaky_relu(self.conv1(x), 0.2) + x = F.leaky_relu(self.conv2(x) + inp, 0.2) + return x + +class AutoEncoder(nn.Module): + def __init__(self, resolution, in_ch, ae_ch): + super().__init__() + self._resolution = resolution + self._in_ch = in_ch + self._ae_ch = ae_ch + + self.conv1 = nn.Conv2d(in_ch*1, in_ch*2, kernel_size=5, stride=2, padding=2) + self.conv2 = nn.Conv2d(in_ch*2, in_ch*4, kernel_size=5, stride=2, padding=2) + self.conv3 = nn.Conv2d(in_ch*4, in_ch*8, kernel_size=5, stride=2, padding=2) + self.conv4 = nn.Conv2d(in_ch*8, in_ch*8, kernel_size=5, stride=2, padding=2) + + self.dense_in = nn.Linear( in_ch*8 * ( resolution // (2**4) )**2, ae_ch) + self.dense_out = nn.Linear( ae_ch, in_ch*8 * ( resolution // (2**4) )**2 ) + + self.up_conv4 = nn.ConvTranspose2d(in_ch*8, in_ch*8, kernel_size=3, stride=2, padding=1, output_padding=(1,1), ) + self.up_conv3 = nn.ConvTranspose2d(in_ch*8, in_ch*4, kernel_size=3, stride=2, padding=1, output_padding=(1,1), ) + self.up_conv2 = nn.ConvTranspose2d(in_ch*4, in_ch*2, kernel_size=3, stride=2, padding=1, output_padding=(1,1), ) + self.up_conv1 = nn.ConvTranspose2d(in_ch*2, in_ch*1, kernel_size=3, stride=2, padding=1, output_padding=(1,1), ) + + + def forward(self, inp): + x = inp + x = F.leaky_relu(self.conv1(x), 0.1) + x = F.leaky_relu(self.conv2(x), 0.1) + x = F.leaky_relu(self.conv3(x), 0.1) + x = F.leaky_relu(self.conv4(x), 0.1) + + x = x.view( (x.shape[0], np.prod(x.shape[1:])) ) + + x = self.dense_in(x) + x = self.dense_out(x) + + x = x.view( (x.shape[0], self._in_ch*8, self._resolution // (2**4), self._resolution // (2**4) )) + x = F.leaky_relu(self.up_conv4(x), 0.1) + x = F.leaky_relu(self.up_conv3(x), 0.1) + x = F.leaky_relu(self.up_conv2(x), 0.1) + x = F.leaky_relu(self.up_conv1(x), 0.1) + + # from xlib.console import diacon + # diacon.Diacon.stop() + # import code + # code.interact(local=dict(globals(), **locals())) + + + return x + + +class CTSOTNet(nn.Module): + def __init__(self, resolution, in_ch=6, inner_ch=64, ae_ch=256, out_ch=3, res_block_count = 12): + super().__init__() + self.in_conv = nn.Conv2d(in_ch, inner_ch, kernel_size=1, stride=1, padding=0) + self.ae = AutoEncoder(resolution, inner_ch, ae_ch) + self.out_conv = nn.Conv2d(inner_ch, out_ch, kernel_size=1, stride=1, padding=0) + + def forward(self, img1_t, img2_t): + x = torch.cat([img1_t, img2_t], dim=1) + + x = self.in_conv(x) + + x = self.ae(x) + x = self.out_conv(x) + x = torch.tanh(x) + return x + + + +class CTSOT: + def __init__(self, device_info : TorchDeviceInfo = None, + state_dict : Union[dict, None] = None, + training : bool = False): + if device_info is None: + device_info = get_cpu_device_info() + self.device_info = device_info + + self._net = net = CTSOTNet() + + if state_dict is not None: + net.load_state_dict(state_dict) + + if training: + net.train() + else: + net.eval() + + self.set_device(device_info) + + def set_device(self, device_info : TorchDeviceInfo = None): + if device_info is None or device_info.is_cpu(): + self._net.cpu() + else: + self._net.cuda(device_info.get_index()) + + def get_state_dict(self): + return self.net.state_dict() + + def get_net(self) -> CTSOTNet: + return self._net \ No newline at end of file diff --git a/modelhub/torch/S3FD/S3FD.py b/modelhub/torch/S3FD/S3FD.py index 141842e..3fd319b 100644 --- a/modelhub/torch/S3FD/S3FD.py +++ b/modelhub/torch/S3FD/S3FD.py @@ -8,13 +8,13 @@ import torch.nn.functional as F from xlib import math as lib_math from xlib.file import SplittedFile from xlib.image import ImageProcessor -from xlib.torch import TorchDeviceInfo, get_cpu_device +from xlib.torch import TorchDeviceInfo, get_cpu_device_info class S3FD: def __init__(self, device_info : TorchDeviceInfo = None ): if device_info is None: - device_info = get_cpu_device() + device_info = get_cpu_device_info() self.device_info = device_info path = Path(__file__).parent / 'S3FD.pth' diff --git a/modelhub/torch/__init__.py b/modelhub/torch/__init__.py index 8ecdd82..fca0256 100644 --- a/modelhub/torch/__init__.py +++ b/modelhub/torch/__init__.py @@ -1,2 +1,3 @@ -from .CenterFace.CenterFace import CenterFace, CenterFace_to_onnx +from .CenterFace.CenterFace import CenterFace_to_onnx from .S3FD.S3FD import S3FD +from .CTSOT.CTSOT import CTSOT, CTSOTNet \ No newline at end of file diff --git a/scripts/dev.py b/scripts/dev.py index 64d2216..311c506 100644 --- a/scripts/dev.py +++ b/scripts/dev.py @@ -53,14 +53,10 @@ def extract_FaceSynthetics(inputdir_path : Path, faceset_path : Path): """ if faceset_path.suffix != '.dfs': raise ValueError('faceset_path must have .dfs extension.') - + filepaths = lib_path.get_files_paths(inputdir_path) - - fs = lib_face.Faceset(faceset_path) - fs.recreate() - - - for filepath in lib_con.progress_bar_iterator(filepaths, 'Processing'): + fs = lib_face.Faceset(faceset_path, write_access=True, recreate=True) + for filepath in lib_con.progress_bar_iterator(filepaths, desc='Processing'): if filepath.suffix == '.txt': image_filepath = filepath.parent / f'{filepath.name.split("_")[0]}.png' @@ -93,17 +89,13 @@ def extract_FaceSynthetics(inputdir_path : Path, faceset_path : Path): ufm.set_UImage_uuid(uimg.get_uuid()) ufm.set_FRect(flmrks.get_FRect()) ufm.add_FLandmarks2D(flmrks) - + + fs.add_UFaceMark(ufm) fs.add_UImage(uimg, format='png') - fs.add_UFaceMark(ufm) - - - fs.shrink() + + fs.optimize() fs.close() - import code - code.interact(local=dict(globals(), **locals())) - # seg_filepath = input_path / ( Path(image_filepath).stem + '_seg.png') # if not seg_filepath.exists(): # raise ValueError(f'{seg_filepath} does not exist') diff --git a/xlib/avecl/_internal/op/gaussian_blur.py b/xlib/avecl/_internal/op/gaussian_blur.py index 69c9fd0..a9c1097 100644 --- a/xlib/avecl/_internal/op/gaussian_blur.py +++ b/xlib/avecl/_internal/op/gaussian_blur.py @@ -15,7 +15,7 @@ def gaussian_blur (input_t : Tensor, sigma, dtype=None) -> Tensor: """ if sigma <= 0.0: - return input_t.copy() #TODO + return input_t.copy() device = input_t.get_device() diff --git a/xlib/console/diacon/Diacon.py b/xlib/console/diacon/Diacon.py index 568ebfc..61497c5 100644 --- a/xlib/console/diacon/Diacon.py +++ b/xlib/console/diacon/Diacon.py @@ -3,113 +3,220 @@ import threading import time from enum import IntEnum from typing import Any, Callable, List, Tuple, Union - +from numbers import Number from ... import text as lib_text class EDlgMode(IntEnum): - UNDEFINED = 0 + UNHANDLED = 0 BACK = 1 RELOAD = 2 WRONG_INPUT = 3 - SUCCESS = 4 + HANDLED = 4 + class DlgChoice: - def __init__(self, name : str = None, row_desc : str = None): - super().__init__() + def __init__(self, name : str = None, + row_def : str = None, + on_choose : Callable = None): if len(name) == 0: raise ValueError('Zero len name is not valid.') self._name = name - self._row_desc = row_desc - + self._row_def = row_def + self._on_choose = on_choose + def get_name(self) -> Union[str, None]: return self._name - def get_row_desc(self) -> Union[str, None]: return self._row_desc - + def get_row_def(self) -> Union[str, None]: return self._row_def + def get_on_choose(self) -> Callable: return self._on_choose + class Dlg: - def __init__(self, title : str = None, has_go_back=True): + def __init__(self, on_recreate : Callable[ [], 'Dlg'] = None, + on_back : Callable = None, + top_rows_def : Union[str, List[str]] = None, + bottom_rows_def : Union[str, List[str]] = None, + ): """ - + base class for Diacon dialogs. """ - self._title = title - self._has_go_back = has_go_back - + self._on_recreate = on_recreate + self._on_back = on_back + + self._top_rows_def = top_rows_def + self._bottom_rows_def = bottom_rows_def + + def recreate(self): + """ + """ + if self._on_recreate is not None: + return self._on_recreate(self) + else: + raise Exception('on_recreate() is not defined.') + def get_name(self) -> str: return self._name - - def handle_user_input(self, s : str) -> EDlgMode: + + def set_current(self, print=True): + Diacon.update_dlg(self, print=print) + + def handle_user_input(self, s : str): """ - """ - s = s.strip() - - # ? and < available in any dialog, handle them first - s_len = len(s) - if s_len == 0: + mode = self.on_user_input(s.strip()) + + if mode == EDlgMode.UNHANDLED: + mode = EDlgMode.RELOAD + if mode == EDlgMode.WRONG_INPUT: + print('\nWrong input') + mode = EDlgMode.RELOAD + + if mode == EDlgMode.RELOAD: + self.recreate().set_current() + if mode == EDlgMode.BACK: + if self._on_back is not None: + self._on_back(self) + + #overridable + def on_user_input(self, s : str) -> EDlgMode: + if len(s) == 0: return EDlgMode.RELOAD - if s_len == 1: - #if s == '?': - # return EDlgMode.RELOAD + if self._on_back is not None and len(s) == 1: if s == '<': return EDlgMode.BACK - - return self.on_user_input(s) - + return EDlgMode.UNHANDLED + def print(self, table_width_max=80, col_spacing = 3): """ print dialog """ - - # Gather table lines table_def : List[str]= [] - if self._has_go_back: + trd = self._top_rows_def + brd = self._bottom_rows_def + if trd is not None: + if not isinstance(trd, (list,tuple)): + trd = [trd] + table_def += trd + + if self._on_back is not None: table_def.append('| < | Go back.') - + table_def.append('|99') table_def = self.on_print(table_def) - table = lib_text.ascii_table(table_def, max_table_width=80, - left_border = None, - right_border = None, + if brd is not None: + if not isinstance(brd, (list,tuple)): + brd = [brd] + table_def += brd + + table = lib_text.ascii_table(table_def, max_table_width=80, + left_border = '| ', + right_border = ' |', border = ' | ', - row_symbol = None) + row_symbol = None, + ) print() print(table) - + #overridable def on_print(self, table_lines : List[Tuple[str,str]]): return table_lines + + +class DlgNumber(Dlg): + def __init__(self, is_float : bool, + current_value = None, + min_value = None, + max_value = None, + clip_min_value = None, + clip_max_value = None, + on_value : Callable[ [Dlg, Number], None] = None, + on_recreate : Callable[ [], 'Dlg'] = None, + on_back : Callable = None, + top_rows_def : Union[str, List[str]] = None, + bottom_rows_def : Union[str, List[str]] = None, ): + super().__init__(on_recreate=on_recreate, on_back=on_back, top_rows_def=top_rows_def, bottom_rows_def=bottom_rows_def) + + if min_value is not None and max_value is not None and min_value > max_value: + raise ValueError('min_value > max_value') + if clip_min_value is not None and clip_max_value is not None and clip_min_value > clip_max_value: + raise ValueError('clip_min_value > clip_max_value') + + self._is_float = is_float + self._current_value = current_value + self._min_value = min_value + self._max_value = max_value + self._clip_min_value = clip_min_value + self._clip_max_value = clip_max_value + self._on_value = on_value + #overridable - def on_user_input(self, s : str) -> EDlgMode: - """ - handle user input - return False if input is invalid - """ - return EDlgMode.UNDEFINED + def on_print(self, table_def : List[str]): + + minv, maxv = self._min_value, self._max_value + + if self._is_float: + line = '| * | Enter float number' + else: + line = '| * | Enter integer number' + + if minv is not None and maxv is None: + line += f' in range: [{minv} ... )' + elif minv is None and maxv is not None: + line += f' in range: ( ... {maxv} ]' + elif minv is not None and maxv is not None: + line += f' in range: [{minv} ... {maxv} ]' + + table_def.append(line) + + return table_def + + #overridable + def on_user_input(self, s : str) -> bool: + result = super().on_user_input(s) + if result == EDlgMode.UNHANDLED: + try: + print(s) + v = float(s) if self._is_float else int(s) + + if self._min_value is not None: + if v < self._min_value: + return EDlgMode.WRONG_INPUT + + if self._max_value is not None: + if v > self._max_value: + return EDlgMode.WRONG_INPUT + + if self._clip_min_value is not None: + if v < self._clip_min_value: + v = self._clip_min_value + + if self._clip_max_value is not None: + if v > self._clip_max_value: + v = self._clip_max_value + + if self._on_value is not None: + self._on_value(self, v) + return EDlgMode.HANDLED + except: + return EDlgMode.WRONG_INPUT + + return result class DlgChoices(Dlg): - def __init__(self, choices : List[DlgChoice], multiple_choices=False, title : str = None, has_go_back = True): - """ - - """ - super().__init__(title=title, has_go_back=has_go_back) + def __init__(self, choices : List[DlgChoice], + on_multi_choice : Callable[ [ List[DlgChoice] ], None] = None, + on_recreate : Callable[ [Dlg], Dlg] = None, + on_back : Callable = None, + top_rows_def : Union[str, List[str]] = None, + bottom_rows_def : Union[str, List[str]] = None, + ): + super().__init__(on_recreate=on_recreate, on_back=on_back, top_rows_def=top_rows_def, bottom_rows_def=bottom_rows_def) self._choices = choices - self._multiple_choices = multiple_choices - - self._results = None - self._results_id = None + self._on_multi_choice = on_multi_choice self._short_names = [choice.get_name() for choice in choices] - # if any([x is not None for x in self._short_names]): - # # Using short names from choices - # if any([x is None for x in self._short_names]): - # raise Exception('No short name for one of choices.') - # if len(set(self._short_names)) != len(self._short_names): - # raise ValueError(f'Contains duplicate short names: {self._short_names}') - # else: - # Make short names for all choices names = [ choice.get_name() for choice in choices ] names_len = len(names) @@ -139,27 +246,14 @@ class DlgChoices(Dlg): break self._short_names = short_names - - - def get_selected_choices(self) -> List[DlgChoice]: - """ - returns selected choices - """ - return self._results - - def get_selected_choices_id(self) -> List[int]: - """ - returns selected choice - """ - return self._results_id - #overridable def on_print(self, table_def : List[str]): + for short_name, choice in zip(self._short_names, self._choices): row_def = f'| {short_name}' - row_desc = choice.get_row_desc() - if row_desc is not None: - row_def += row_desc + x = choice.get_row_def() + if x is not None: + row_def += x table_def.append(row_def) return table_def @@ -167,35 +261,36 @@ class DlgChoices(Dlg): #overridable def on_user_input(self, s : str) -> bool: result = super().on_user_input(s) - if result == EDlgMode.UNDEFINED: + if result == EDlgMode.UNHANDLED: - if self._multiple_choices: + if self._on_multi_choice is not None: multi_s = s.split(',') else: multi_s = [s] - results = [] - results_id = [] + choices_id = [] for s in multi_s: - s = s.strip() - - x = [ i for i,short_name in enumerate(self._short_names) if s == short_name ] + x = [ i for i, short_name in enumerate(self._short_names) if s.strip() == short_name ] if len(x) == 0: - # no short name match + # No short name match return EDlgMode.WRONG_INPUT else: id = x[0] - results_id.append(id) - results.append(self._choices[id]) - - if len(set(results_id)) != len(results_id): + choices_id.append(id) + + if len(set(choices_id)) != len(choices_id): # Duplicate input return EDlgMode.WRONG_INPUT - self._results = results - self._results_id = results_id + for id in choices_id: + on_choose = self._choices[id].get_on_choose() + if on_choose is not None: + on_choose(self) + + if self._on_multi_choice is not None: + self._on_multi_choice(choices_id) - return EDlgMode.SUCCESS + return EDlgMode.HANDLED return result @@ -204,42 +299,9 @@ class DlgChoices(Dlg): class _Diacon: """ User dialog with via console. - - Internal architecture: - - [ - Main-Thread - - current thread from which __init__() called - ] - - [ - Dialog-Thread - - separate thread where dialogs are handled and dynamically created - - we need this thread, because main thread can be busy, - for example training neural network - - calls on_dlg() provided with __init__ - - thus keep in mind on_dlg() works in separate thread - - This thread must not be blocked inside on_dlg(), - because Diacon.stop() can be called that stops all threads. - ] - - [ - Input-Thread - - separate thread where user input is accepted in non-blocking mode, - and transfered to processing thread - ] """ def __init__(self): - self._on_dlg : Callable = None - self._lock = threading.RLock() self._current_dlg : Dlg = None self._new_dlg : Dlg = None @@ -250,11 +312,10 @@ class _Diacon: self._input_request = False self._input_result : str = None - def start(self, on_dlg : Callable): + def start(self): if self._started: raise Exception('Diacon already started.') self._started = True - self._on_dlg = on_dlg self._input_t = threading.Thread(target=self._input_thread, daemon=True) self._input_t.start() @@ -285,8 +346,6 @@ class _Diacon: def _dialog_thread(self, ): - self._on_dlg(None, EDlgMode.RELOAD) - while self._started: with self._lock: @@ -303,13 +362,7 @@ class _Diacon: if input_result is not None: if self._current_dlg is not None: - mode = self._current_dlg.handle_user_input(input_result) - if mode == EDlgMode.WRONG_INPUT: - print('\nWrong input') - mode = EDlgMode.RELOAD - if mode == EDlgMode.UNDEFINED: - mode = EDlgMode.RELOAD - self._on_dlg(self._current_dlg, mode) + self._current_dlg.handle_user_input(input_result) continue time.sleep(0.005) @@ -332,6 +385,9 @@ class _Diacon: show current or set new Dialog Can be called from any thread. """ + if not self._started: + self.start() + self._new_dlg = (new_dlg, print) Diacon = _Diacon() diff --git a/xlib/console/diacon/__init__.py b/xlib/console/diacon/__init__.py index 34d2886..e546071 100644 --- a/xlib/console/diacon/__init__.py +++ b/xlib/console/diacon/__init__.py @@ -1 +1 @@ -from .Diacon import Diacon, Dlg, DlgChoice, DlgChoices, EDlgMode +from .Diacon import Diacon, Dlg, DlgChoice, DlgChoices, EDlgMode, DlgNumber diff --git a/xlib/face/FMask.py b/xlib/face/FMask.py index 5d673f9..ade8bbd 100644 --- a/xlib/face/FMask.py +++ b/xlib/face/FMask.py @@ -9,7 +9,7 @@ class FMask: def __init__(self, _from_pickled=False): """ """ - self._uuid : Union[bytes, None] = uuid.uuid4().bytes_le if not _from_pickled else None + self._uuid : Union[bytes, None] = uuid.uuid4().bytes if not _from_pickled else None self._mask_type : Union[FMask.Type, None] = None self._FImage_uuid : Union[bytes, None] = None diff --git a/xlib/face/FaceWarper.py b/xlib/face/FaceWarper.py index 439316f..e981e76 100644 --- a/xlib/face/FaceWarper.py +++ b/xlib/face/FaceWarper.py @@ -52,9 +52,25 @@ class FaceWarper: self._rw_grid_tx = rnd_state.uniform(*rw_grid_tx) if isinstance(rw_grid_tx, Iterable) else rw_grid_tx self._rw_grid_ty = rnd_state.uniform(*rw_grid_ty) if isinstance(rw_grid_ty, Iterable) else rw_grid_ty + self._warp_rnd_mat = Affine2DUniMat.from_transformation(0.5, 0.5, self._rw_grid_rot_deg, 1.0+self._rw_grid_scale, self._rw_grid_tx, self._rw_grid_ty) + self._align_rnd_mat = Affine2DUniMat.from_transformation(0.5, 0.5, self._align_rot_deg, 1.0+self._align_scale, self._align_tx, self._align_ty) + self._rnd_state_state = rnd_state.get_state() self._cached = {} + def get_aligned_random_transform_mat(self) -> Affine2DUniMat: + """ + returns Affine2DUniMat that represents transformation from aligned face to randomly transformed aligned face + """ + mat1 = self._img_to_face_uni_mat + mat2 = (self._face_to_img_uni_mat * self._align_rnd_mat).invert() + + pts = [ [0,0], [1,0], [1,1]] + src_pts = mat1.transform_points(pts) + dst_pts = mat2.transform_points(pts) + + return Affine2DUniMat.from_3_pairs(src_pts, dst_pts) + def transform(self, img : np.ndarray, out_res : int, random_warp : bool = True) -> np.ndarray: """ transform an image. @@ -91,9 +107,7 @@ class FaceWarper: face_warp_grid = FaceWarper._gen_random_warp_uni_grid_diff(out_res, self._rw_grid_cell_count, 0.12, rnd_state) # make a randomly transformable mat to transform face_warp_grid from face to image - face_warp_grid_mat = (self._face_to_img_uni_mat * - Affine2DUniMat.from_transformation(0.5, 0.5, self._rw_grid_rot_deg, 1.0+self._rw_grid_scale, self._rw_grid_tx, self._rw_grid_ty) - ) + face_warp_grid_mat = self._face_to_img_uni_mat * self._warp_rnd_mat # warp face_warp_grid to the space of image and merge with image_grid image_grid += cv2.warpAffine(face_warp_grid, face_warp_grid_mat.to_exact_mat(out_res,out_res, W, H), (W,H), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT) @@ -101,9 +115,10 @@ class FaceWarper: # scale uniform grid from to image size image_grid *= (H-1, W-1) - # apply random transormations for align mat - img_to_face_rnd_mat = (self._face_to_img_uni_mat * Affine2DMat.from_transformation(0.5, 0.5, self._align_rot_deg, 1.0+self._align_scale, self._align_tx, self._align_ty) - ).invert().to_exact_mat(W,H,out_res,out_res) + # apply random transformations for align mat + #img_to_face_rnd_uni_mat = (self._face_to_img_uni_mat * self._align_rnd_mat).invert() + + img_to_face_rnd_mat = (self._face_to_img_uni_mat * self._align_rnd_mat).invert().to_exact_mat(W,H,out_res,out_res) # warp image_grid to face space image_grid = cv2.warpAffine(image_grid, img_to_face_rnd_mat, (out_res,out_res), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE ) diff --git a/xlib/face/Faceset.py b/xlib/face/Faceset.py index 93f19da..5fd3d44 100644 --- a/xlib/face/Faceset.py +++ b/xlib/face/Faceset.py @@ -1,19 +1,22 @@ import pickle -import sqlite3 +import uuid from pathlib import Path -from typing import Generator, List, Union, Iterable +from typing import Generator, Iterable, List, Union import cv2 +import h5py import numpy as np +from .. import console as lib_con from .FMask import FMask from .UFaceMark import UFaceMark from .UImage import UImage from .UPerson import UPerson + class Faceset: - def __init__(self, path = None): + def __init__(self, path = None, write_access=False, recreate=False): """ Faceset is a class to store and manage face related data. @@ -21,205 +24,155 @@ class Faceset: path path to faceset .dfs file + write_access + + recreate + Can be pickled. """ + self._f = None + self._path = path = Path(path) if path.suffix != '.dfs': raise ValueError('Path must be a .dfs file') - self._conn = conn = sqlite3.connect(path, isolation_level=None) - self._cur = cur = conn.cursor() + if path.exists(): + if write_access and recreate: + path.unlink() + elif not write_access: + raise FileNotFoundError(f'File {path} not found.') - cur = self._get_cursor() - cur.execute('BEGIN IMMEDIATE') - if not self._is_table_exists('FacesetInfo'): - self.recreate(shrink=False, _transaction=False) - cur.execute('COMMIT') - self.shrink() - else: - cur.execute('END') + self._mode = 'a' if write_access else 'r' + self._open() def __del__(self): self.close() def __getstate__(self): - return {'_path' : self._path} + return {'_path' : self._path, '_mode' : self._mode} def __setstate__(self, d): - self.__init__( d['_path'] ) + self._f = None + self._path = d['_path'] + self._mode = d['_mode'] + self._open() def __repr__(self): return self.__str__() def __str__(self): return f"Faceset. UImage:{self.get_UImage_count()} UFaceMark:{self.get_UFaceMark_count()} UPerson:{self.get_UPerson_count()}" - def _is_table_exists(self, name): - return self._cur.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", [name]).fetchone()[0] != 0 + def _open(self): + if self._f is None: + self._f = f = h5py.File(self._path, mode=self._mode) + self._UFaceMark_grp = f.require_group('UFaceMark') + self._UImage_grp = f.require_group('UImage') + self._UImage_image_data_grp = f.require_group('UImage_image_data') + self._UPerson_grp = f.require_group('UPerson') - def _get_cursor(self) -> sqlite3.Cursor: return self._cur def close(self): - if self._cur is not None: - self._cur.close() - self._cur = None + if self._f is not None: + self._f.close() + self._f = None - if self._conn is not None: - self._conn.close() - self._conn = None - - def shrink(self): - self._cur.execute('VACUUM') - - def recreate(self, shrink=True, _transaction=True): + def optimize(self, verbose=True): """ - delete all data and recreate Faceset structure. + recreate Faceset with optimized structure. """ - cur = self._get_cursor() + if verbose: + print(f'Optimizing {self._path.name}...') - if _transaction: - cur.execute('BEGIN IMMEDIATE') + tmp_path = self._path.parent / (self._path.stem + '_optimizing' + self._path.suffix) - for table_name, in cur.execute("SELECT name from sqlite_master where type = 'table';").fetchall(): - cur.execute(f'DROP TABLE {table_name}') + tmp_fs = Faceset(tmp_path, write_access=True, recreate=True) + self._group_copy(tmp_fs._UFaceMark_grp, self._UFaceMark_grp, verbose=verbose) + self._group_copy(tmp_fs._UPerson_grp, self._UPerson_grp, verbose=verbose) + self._group_copy(tmp_fs._UImage_grp, self._UImage_grp, verbose=verbose) + self._group_copy(tmp_fs._UImage_image_data_grp, self._UImage_image_data_grp, verbose=verbose) + tmp_fs.close() - (cur.execute('CREATE TABLE FacesetInfo (version INT)') - .execute('INSERT INTO FacesetInfo VALUES (1)') + self.close() + self._path.unlink() + tmp_path.rename(self._path) + self._open() - .execute('CREATE TABLE UImage (uuid BLOB, name TEXT, format TEXT, data BLOB)') - .execute('CREATE TABLE UPerson (uuid BLOB, data BLOB)') - .execute('CREATE TABLE UFaceMark (uuid BLOB, UImage_uuid BLOB, UPerson_uuid BLOB, data BLOB)') - ) + def _group_copy(self, group_dst : h5py.Group, group_src : h5py.Group, verbose=True): + for key, value in lib_con.progress_bar_iterator(group_src.items(), desc=f'Copying {group_src.name} -> {group_dst.name}', suppress_print=not verbose): + d = group_dst.create_dataset(key, shape=value.shape, dtype=value.dtype ) + d[:] = value[:] + for a_key, a_value in value.attrs.items(): + d.attrs[a_key] = a_value - if _transaction: - cur.execute('COMMIT') + def _group_read_bytes(self, group : h5py.Group, key : str, check_key=True) -> Union[bytes, None]: + if check_key and key not in group: + return None + dataset = group[key] + data_bytes = bytearray(len(dataset)) + dataset.read_direct(np.frombuffer(data_bytes, dtype=np.uint8)) + return data_bytes - if shrink: - self.shrink() + def _group_write_bytes(self, group : h5py.Group, key : str, data : bytes, update_existing=True) -> Union[h5py.Dataset, None]: + if key in group: + if not update_existing: + return None + del group[key] + + return group.create_dataset(key, data=np.frombuffer(data, dtype=np.uint8) ) ################### ### UFaceMark ################### - def _UFaceMark_from_db_row(self, db_row) -> UFaceMark: - uuid, UImage_uuid, UPerson_uuid, data = db_row - - ufm = UFaceMark() - ufm.restore_state(pickle.loads(data)) - return ufm - - def add_UFaceMark(self, ufacemark_or_list : UFaceMark): + def add_UFaceMark(self, ufacemark_or_list : UFaceMark, update_existing=True): """ add or update UFaceMark in DB """ if not isinstance(ufacemark_or_list, Iterable): ufacemark_or_list : List[UFaceMark] = [ufacemark_or_list] - cur = self._cur - cur.execute('BEGIN IMMEDIATE') for ufm in ufacemark_or_list: - uuid = ufm.get_uuid() - UImage_uuid = ufm.get_UImage_uuid() - UPerson_uuid = ufm.get_UPerson_uuid() - - data = pickle.dumps(ufm.dump_state()) - - if cur.execute('SELECT COUNT(*) from UFaceMark where uuid=?', [uuid] ).fetchone()[0] != 0: - cur.execute('UPDATE UFaceMark SET UImage_uuid=?, UPerson_uuid=?, data=? WHERE uuid=?', - [UImage_uuid, UPerson_uuid, data, uuid]) - else: - cur.execute('INSERT INTO UFaceMark VALUES (?, ?, ?, ?)', [uuid, UImage_uuid, UPerson_uuid, data]) - cur.execute('COMMIT') + self._group_write_bytes(self._UFaceMark_grp, ufm.get_uuid().hex(), pickle.dumps(ufm.dump_state()), update_existing=update_existing ) def get_UFaceMark_count(self) -> int: - return self._cur.execute('SELECT COUNT(*) FROM UFaceMark').fetchone()[0] + return len(self._UFaceMark_grp.keys()) def get_all_UFaceMark(self) -> List[UFaceMark]: - return [ self._UFaceMark_from_db_row(db_row) for db_row in self._cur.execute('SELECT * FROM UFaceMark').fetchall() ] - + return [ UFaceMark.from_state(pickle.loads(self._group_read_bytes(self._UFaceMark_grp, key, check_key=False))) for key in self._UFaceMark_grp.keys() ] + + def get_all_UFaceMark_uuids(self) -> List[bytes]: + return [ uuid.UUID(key).bytes for key in self._UFaceMark_grp.keys() ] + def get_UFaceMark_by_uuid(self, uuid : bytes) -> Union[UFaceMark, None]: - c = self._cur.execute('SELECT * FROM UFaceMark WHERE uuid=?', [uuid]) - db_row = c.fetchone() - if db_row is None: + data = self._group_read_bytes(self._UFaceMark_grp, uuid.hex()) + if data is None: return None + return UFaceMark.from_state(pickle.loads(data)) - return self._UFaceMark_from_db_row(db_row) - + def delete_UFaceMark_by_uuid(self, uuid : bytes) -> bool: + key = uuid.hex() + if key in self._UFaceMark_grp: + del self._UFaceMark_grp[key] + return True + return False + + def iter_UFaceMark(self) -> Generator[UFaceMark, None, None]: """ returns Generator of UFaceMark """ - for db_row in self._cur.execute('SELECT * FROM UFaceMark').fetchall(): - yield self._UFaceMark_from_db_row(db_row) + for key in self._UFaceMark_grp.keys(): + yield UFaceMark.from_state(pickle.loads(self._group_read_bytes(self._UFaceMark_grp, key, check_key=False))) def delete_all_UFaceMark(self): """ deletes all UFaceMark from DB """ - (self._cur.execute('BEGIN IMMEDIATE') - .execute('DELETE FROM UFaceMark') - .execute('COMMIT') ) - ################### - ### UPerson - ################### - def _UPerson_from_db_row(self, db_row) -> UPerson: - uuid, data = db_row - up = UPerson() - up.restore_state(pickle.loads(data)) - return up - - def add_UPerson(self, uperson_or_list : UPerson): - """ - add or update UPerson in DB - """ - if not isinstance(uperson_or_list, Iterable): - uperson_or_list : List[UPerson] = [uperson_or_list] - - cur = self._cur - cur.execute('BEGIN IMMEDIATE') - for uperson in uperson_or_list: - uuid = uperson.get_uuid() - - data = pickle.dumps(uperson.dump_state()) - - if cur.execute('SELECT COUNT(*) from UPerson where uuid=?', [uuid]).fetchone()[0] != 0: - cur.execute('UPDATE UPerson SET data=? WHERE uuid=?', [data]) - else: - cur.execute('INSERT INTO UPerson VALUES (?, ?)', [uuid, data]) - cur.execute('COMMIT') - - def get_UPerson_count(self) -> int: - return self._cur.execute('SELECT COUNT(*) FROM UPerson').fetchone()[0] - - def get_all_UPerson(self) -> List[UPerson]: - return [ self._UPerson_from_db_row(db_row) for db_row in self._cur.execute('SELECT * FROM UPerson').fetchall() ] - - def iter_UPerson(self) -> Generator[UPerson, None, None]: - """ - iterator of all UPerson's - """ - for db_row in self._cur.execute('SELECT * FROM UPerson').fetchall(): - yield self._UPerson_from_db_row(db_row) - - def delete_all_UPerson(self): - """ - deletes all UPerson from DB - """ - (self._cur.execute('BEGIN IMMEDIATE') - .execute('DELETE FROM UPerson') - .execute('COMMIT') ) + for key in self._UFaceMark_grp.keys(): + del self._UFaceMark_grp[key] ################### ### UImage ################### - def _UImage_from_db_row(self, db_row) -> UImage: - uuid, name, format, data_bytes = db_row - img = cv2.imdecode(np.frombuffer(data_bytes, dtype=np.uint8), flags=cv2.IMREAD_UNCHANGED) - - uimg = UImage() - uimg.set_uuid(uuid) - uimg.set_name(name) - uimg.assign_image(img) - return uimg - - def add_UImage(self, uimage_or_list : UImage, format : str = 'webp', quality : int = 100): + def add_UImage(self, uimage_or_list : UImage, format : str = 'png', quality : int = 100, update_existing=True): """ add or update UImage in DB @@ -239,9 +192,8 @@ class Faceset: raise ValueError('quality must be in range [0..100]') if not isinstance(uimage_or_list, Iterable): - uimage_or_list = [uimage_or_list] + uimage_or_list : List[UImage] = [uimage_or_list] - uimage_datas = [] for uimage in uimage_or_list: if format == 'webp': imencode_args = [int(cv2.IMWRITE_WEBP_QUALITY), quality] @@ -251,44 +203,112 @@ class Faceset: imencode_args = [int(cv2.IMWRITE_JPEG2000_COMPRESSION_X1000), quality*10] else: imencode_args = [] + ret, data_bytes = cv2.imencode( f'.{format}', uimage.get_image(), imencode_args) if not ret: raise Exception(f'Unable to encode image format {format}') - uimage_datas.append(data_bytes.data) - cur = self._cur - cur.execute('BEGIN IMMEDIATE') - for uimage, data in zip(uimage_or_list, uimage_datas): - uuid = uimage.get_uuid() - if cur.execute('SELECT COUNT(*) from UImage where uuid=?', [uuid] ).fetchone()[0] != 0: - cur.execute('UPDATE UImage SET name=?, format=?, data=? WHERE uuid=?', [uimage.get_name(), format, data, uuid]) - else: - cur.execute('INSERT INTO UImage VALUES (?, ?, ?, ?)', [uuid, uimage.get_name(), format, data]) - cur.execute('COMMIT') + key = uimage.get_uuid().hex() - def get_UImage_count(self) -> int: return self._cur.execute('SELECT COUNT(*) FROM UImage').fetchone()[0] - def get_UImage_by_uuid(self, uuid : Union[bytes, None]) -> Union[UImage, None]: - """ - """ - if uuid is None: + self._group_write_bytes(self._UImage_grp, key, pickle.dumps(uimage.dump_state(exclude_image=True)), update_existing=update_existing ) + d = self._group_write_bytes(self._UImage_image_data_grp, key, data_bytes.data, update_existing=update_existing ) + d.attrs['format'] = format + d.attrs['quality'] = quality + + def get_UImage_count(self) -> int: + return len(self._UImage_grp.keys()) + + def get_all_UImage(self) -> List[UImage]: + return [ self._get_UImage_by_key(key) for key in self._UImage_grp.keys() ] + + def get_all_UImage_uuids(self) -> List[bytes]: + return [ uuid.UUID(key).bytes for key in self._UImage_grp.keys() ] + + def _get_UImage_by_key(self, key, check_key=True) -> Union[UImage, None]: + data = self._group_read_bytes(self._UImage_grp, key, check_key=check_key) + if data is None: return None + uimg = UImage.from_state(pickle.loads(data)) - db_row = self._cur.execute('SELECT * FROM UImage where uuid=?', [uuid]).fetchone() - if db_row is None: - return None - return self._UImage_from_db_row(db_row) + image_data = self._group_read_bytes(self._UImage_image_data_grp, key, check_key=check_key) + if image_data is not None: + uimg.assign_image (cv2.imdecode(np.frombuffer(image_data, dtype=np.uint8), flags=cv2.IMREAD_UNCHANGED)) - def iter_UImage(self) -> Generator[UImage, None, None]: + return uimg + + def get_UImage_by_uuid(self, uuid : bytes) -> Union[UImage, None]: + return self._get_UImage_by_key(uuid.hex()) + + def delete_UImage_by_uuid(self, uuid : bytes): + key = uuid.hex() + if key in self._UImage_grp: + del self._UImage_grp[key] + if key in self._UImage_image_data_grp: + del self._UImage_image_data_grp[key] + + def iter_UImage(self, include_key=False) -> Generator[UImage, None, None]: """ - iterator of all UImage's + returns Generator of UImage """ - for db_row in self._cur.execute('SELECT * FROM UImage').fetchall(): - yield self._UImage_from_db_row(db_row) + for key in self._UImage_grp.keys(): + uimg = self._get_UImage_by_key(key, check_key=False) + yield (uimg, key) if include_key else uimg def delete_all_UImage(self): """ deletes all UImage from DB """ - (self._cur.execute('BEGIN IMMEDIATE') - .execute('DELETE FROM UImage') - .execute('COMMIT') ) + for key in self._UImage_grp.keys(): + del self._UImage_grp[key] + for key in self._UImage_image_data_grp.keys(): + del self._UImage_image_data_grp[key] + + ################### + ### UPerson + ################### + + def add_UPerson(self, uperson_or_list : UPerson, update_existing=True): + """ + add or update UPerson in DB + """ + if not isinstance(uperson_or_list, Iterable): + uperson_or_list : List[UPerson] = [uperson_or_list] + + for uperson in uperson_or_list: + self._group_write_bytes(self._UPerson_grp, uperson.get_uuid().hex(), pickle.dumps(uperson.dump_state()), update_existing=update_existing ) + + def get_UPerson_count(self) -> int: + return len(self._UPerson_grp.keys()) + + def get_all_UPerson(self) -> List[UPerson]: + return [ UPerson.from_state(pickle.loads(self._group_read_bytes(self._UPerson_grp, key, check_key=False))) for key in self._UPerson_grp.keys() ] + + def get_all_UPerson_uuids(self) -> List[bytes]: + return [ uuid.UUID(key).bytes for key in self._UPerson_grp.keys() ] + + def get_UPerson_by_uuid(self, uuid : bytes) -> Union[UPerson, None]: + data = self._group_read_bytes(self._UPerson_grp, uuid.hex()) + if data is None: + return None + return UPerson.from_state(pickle.loads(data)) + + def delete_UPerson_by_uuid(self, uuid : bytes) -> bool: + key = uuid.hex() + if key in self._UPerson_grp: + del self._UPerson_grp[key] + return True + return False + + def iter_UPerson(self) -> Generator[UPerson, None, None]: + """ + returns Generator of UPerson + """ + for key in self._UPerson_grp.keys(): + yield UPerson.from_state(pickle.loads(self._group_read_bytes(self._UPerson_grp, key, check_key=False))) + + def delete_all_UPerson(self): + """ + deletes all UPerson from DB + """ + for key in self._UPerson_grp.keys(): + del self._UPerson_grp[key] diff --git a/xlib/face/UFaceMark.py b/xlib/face/UFaceMark.py index 3c0c3ff..dc23856 100644 --- a/xlib/face/UFaceMark.py +++ b/xlib/face/UFaceMark.py @@ -25,7 +25,13 @@ class UFaceMark(IState): def __repr__(self): return self.__str__() def __str__(self): return f"UFaceMark UUID:[...{self.get_uuid()[-4:].hex()}]" - + + @staticmethod + def from_state(state : dict) -> 'UFaceMark': + ufm = UFaceMark() + ufm.restore_state(state) + return ufm + def restore_state(self, state : dict): self._uuid = state.get('_uuid', None) self._UImage_uuid = state.get('_UImage_uuid', None) @@ -45,7 +51,7 @@ class UFaceMark(IState): def get_uuid(self) -> Union[bytes, None]: if self._uuid is None: - self._uuid = uuid.uuid4().bytes_le + self._uuid = uuid.uuid4().bytes return self._uuid def set_uuid(self, uuid : Union[bytes, None]): @@ -72,6 +78,16 @@ class UFaceMark(IState): self._FRect = face_urect def get_all_FLandmarks2D(self) -> List[FLandmarks2D]: return self._FLandmarks2D_list + + def get_FLandmarks2D_best(self) -> Union[FLandmarks2D, None]: + """get best available FLandmarks2D """ + lmrks = self.get_FLandmarks2D_by_type(ELandmarks2D.L468) + if lmrks is None: + lmrks = self.get_FLandmarks2D_by_type(ELandmarks2D.L68) + if lmrks is None: + lmrks = self.get_FLandmarks2D_by_type(ELandmarks2D.L5) + return lmrks + def get_FLandmarks2D_by_type(self, type : ELandmarks2D) -> Union[FLandmarks2D, None]: """get FLandmarks2D from list by type""" if not isinstance(type, ELandmarks2D): diff --git a/xlib/face/UImage.py b/xlib/face/UImage.py index c22ba8c..c6fd519 100644 --- a/xlib/face/UImage.py +++ b/xlib/face/UImage.py @@ -18,10 +18,31 @@ class UImage(IState): def __str__(self): return f"UImage UUID:[...{self.get_uuid()[-4:].hex()}] name:[{self._name}] image:[{ (self._image.shape, self._image.dtype) if self._image is not None else None}]" def __repr__(self): return self.__str__() + + @staticmethod + def from_state(state : dict) -> 'UImage': + ufm = UImage() + ufm.restore_state(state) + return ufm + + def restore_state(self, state : dict): + self._uuid = state.get('_uuid', None) + self._name = state.get('_name', None) + self._image = state.get('_image', None) + def dump_state(self, exclude_image=False) -> dict: + d = {'_uuid' : self._uuid, + '_name' : self._name, + } + + if not exclude_image: + d['_image'] = self._image + + return d + def get_uuid(self) -> Union[bytes, None]: if self._uuid is None: - self._uuid = uuid.uuid4().bytes_le + self._uuid = uuid.uuid4().bytes return self._uuid def set_uuid(self, uuid : Union[bytes, None]): diff --git a/xlib/face/UPerson.py b/xlib/face/UPerson.py index ac6280d..e0b4181 100644 --- a/xlib/face/UPerson.py +++ b/xlib/face/UPerson.py @@ -15,6 +15,12 @@ class UPerson(IState): def __str__(self): return f"UPerson UUID:[...{self._uuid[-4:].hex()}] name:[{self._name}] age:[{self._age}]" def __repr__(self): return self.__str__() + @staticmethod + def from_state(state : dict) -> 'UPerson': + ufm = UPerson() + ufm.restore_state(state) + return ufm + def restore_state(self, state : dict): self._uuid = state.get('_uuid', None) self._name = state.get('_name', None) @@ -28,7 +34,7 @@ class UPerson(IState): def get_uuid(self) -> Union[bytes, None]: if self._uuid is None: - self._uuid = uuid.uuid4().bytes_le + self._uuid = uuid.uuid4().bytes return self._uuid def set_uuid(self, uuid : Union[bytes, None]): diff --git a/xlib/image/color_transfer/sot.py b/xlib/image/color_transfer/sot.py index 4e3d0e1..29fb011 100644 --- a/xlib/image/color_transfer/sot.py +++ b/xlib/image/color_transfer/sot.py @@ -2,12 +2,14 @@ import cv2 import numpy as np import numpy.linalg as npla -def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0): +def sot(src,trg, mask=None, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0, return_diff=False): """ Color Transform via Sliced Optimal Transfer, ported from https://github.com/dcoeurjo/OTColorTransfer src - any float range any channel image dst - any float range any channel image, same shape as src + mask - + steps - number of solver steps batch_size - solver batch size reg_sigmaXY - apply regularization and sigmaXY of filter, otherwise set to 0.0 @@ -39,8 +41,8 @@ def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0): dir = np.random.normal(size=c).astype(src_dtype) dir /= npla.norm(dir) - projsource = np.sum( new_src*dir, axis=-1).reshape ((h*w)) - projtarget = np.sum( trg*dir, axis=-1).reshape ((h*w)) + projsource = np.sum( new_src*dir*mask, axis=-1).reshape ((h*w)) + projtarget = np.sum( trg*dir*mask, axis=-1).reshape ((h*w)) idSource = np.argsort (projsource) idTarget = np.argsort (projtarget) @@ -48,12 +50,18 @@ def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0): a = projtarget[idTarget]-projsource[idSource] for i_c in range(c): advect[idSource,i_c] += a * dir[i_c] - new_src += advect.reshape( (h,w,c) ) / batch_size + + new_src += (advect.reshape( (h,w,c) ) * mask) / batch_size if reg_sigmaXY != 0.0: src_diff = new_src-src src_diff_filt = cv2.bilateralFilter (src_diff, 0, reg_sigmaV, reg_sigmaXY ) if len(src_diff_filt.shape) == 2: src_diff_filt = src_diff_filt[...,None] - new_src = src + src_diff_filt - return new_src + if return_diff: + return src_diff_filt + return src + src_diff_filt + else: + if return_diff: + return new_src-src + return new_src diff --git a/xlib/mp/MPSPSCMRRingData.py b/xlib/mp/MPSPSCMRRingData.py index 3c95024..f0e66f6 100644 --- a/xlib/mp/MPSPSCMRRingData.py +++ b/xlib/mp/MPSPSCMRRingData.py @@ -37,7 +37,7 @@ class MPSPSCMRRingData: # Initialize first block at 0 index wid = 0 - wid_uuid = uuid.uuid4().bytes_le + wid_uuid = uuid.uuid4().bytes wid_heap_offset = 0 wid_data_size = 0 @@ -82,7 +82,7 @@ class MPSPSCMRRingData: raise Exception('data_size more than heap_size') fmv = FormattedMemoryViewIO(self._shared_mem.get_mv()) - wid_uuid = uuid.uuid4().bytes_le + wid_uuid = uuid.uuid4().bytes if self._write_lock is not None: self._write_lock.acquire() diff --git a/xlib/mp/MPWeakHeap.py b/xlib/mp/MPWeakHeap.py index 77da3b1..378851a 100644 --- a/xlib/mp/MPWeakHeap.py +++ b/xlib/mp/MPWeakHeap.py @@ -49,7 +49,7 @@ class MPWeakHeap: # Entire block fmv.seek(self._first_block_offset) - fmv.write_fmt('qq', self._heap_size-self._first_block_offset, 0), fmv.write(uuid.uuid4().bytes_le) + fmv.write_fmt('qq', self._heap_size-self._first_block_offset, 0), fmv.write(uuid.uuid4().bytes) def add_data(self, data : Union[bytes, bytearray, memoryview] ) -> 'MPWeakHeap.DataRef': @@ -92,7 +92,7 @@ class MPWeakHeap: if block_remain_size >= block_header_size: # the remain space of the block is enough for next block, split the block next_block_offset = cur_block_offset + block_new_size - fmv.seek(next_block_offset), fmv.write_fmt('qq', block_remain_size, 0), fmv.write(uuid.uuid4().bytes_le) + fmv.seek(next_block_offset), fmv.write_fmt('qq', block_remain_size, 0), fmv.write(uuid.uuid4().bytes) else: # otherwise do not split next_block_offset = cur_block_offset + block_size @@ -101,7 +101,7 @@ class MPWeakHeap: block_new_size = block_size # update current block structure - uid = uuid.uuid4().bytes_le + uid = uuid.uuid4().bytes fmv.seek(cur_block_offset), fmv.write_fmt('qq', block_new_size, data_size ), fmv.write(uid) # update ring_head_block_offset @@ -135,7 +135,7 @@ class MPWeakHeap: next_block_size, = fmv.get_fmt('q') # erase data of next block - fmv.write_fmt('qq', 0, 0), fmv.write(uuid.uuid4().bytes_le) + fmv.write_fmt('qq', 0, 0), fmv.write(uuid.uuid4().bytes) # overwrite current block size with expanded block size fmv.seek(cur_block_offset) diff --git a/xlib/mp/SPMTWorker.py b/xlib/mp/SPMTWorker.py index a59f055..cbb96ce 100644 --- a/xlib/mp/SPMTWorker.py +++ b/xlib/mp/SPMTWorker.py @@ -11,7 +11,6 @@ def _host_thread_proc(wref): break ref._host_process_messages(0.005) del ref - print('_host_thread_proc exit') class SPMTWorker: def __init__(self, *sub_args, **sub_kwargs): @@ -152,7 +151,7 @@ class SPMTWorker: break self._threads_running = False - + self._threads_exit_barrier.wait() self._on_sub_finalize() \ No newline at end of file diff --git a/xlib/onnxruntime/__init__.py b/xlib/onnxruntime/__init__.py index 927968c..bdf59e6 100644 --- a/xlib/onnxruntime/__init__.py +++ b/xlib/onnxruntime/__init__.py @@ -1,3 +1,3 @@ from .device import (ORTDeviceInfo, get_available_devices_info, - get_cpu_device) + get_cpu_device_info) from .InferenceSession import InferenceSession_with_device diff --git a/xlib/onnxruntime/device.py b/xlib/onnxruntime/device.py index ad8ddc4..d6f8843 100644 --- a/xlib/onnxruntime/device.py +++ b/xlib/onnxruntime/device.py @@ -67,33 +67,34 @@ class ORTDeviceInfo: _ort_devices_info = None -def get_cpu_device() -> ORTDeviceInfo: +def get_cpu_device_info() -> ORTDeviceInfo: return ORTDeviceInfo(index=-1, execution_provider='CPUExecutionProvider', name='CPU', total_memory=0, free_memory=0) def get_available_devices_info(include_cpu=True, cpu_only=False) -> List[ORTDeviceInfo]: """ returns a list of available ORTDeviceInfo """ - global _ort_devices_info - if _ort_devices_info is None: - _initialize_ort_devices() - devices = [] - if not cpu_only: + devices = [] + if not cpu_only: + global _ort_devices_info + if _ort_devices_info is None: + _initialize_ort_devices_info() + _ort_devices_info = [] for i in range ( int(os.environ.get('ORT_DEVICES_COUNT',0)) ): - devices.append ( ORTDeviceInfo(index=int(os.environ[f'ORT_DEVICE_{i}_INDEX']), - execution_provider=os.environ[f'ORT_DEVICE_{i}_EP'], - name=os.environ[f'ORT_DEVICE_{i}_NAME'], - total_memory=int(os.environ[f'ORT_DEVICE_{i}_TOTAL_MEM']), - free_memory=int(os.environ[f'ORT_DEVICE_{i}_FREE_MEM']), - ) ) - if include_cpu or cpu_only: - devices.append(get_cpu_device()) - _ort_devices_info = devices + _ort_devices_info.append ( ORTDeviceInfo(index=int(os.environ[f'ORT_DEVICE_{i}_INDEX']), + execution_provider=os.environ[f'ORT_DEVICE_{i}_EP'], + name=os.environ[f'ORT_DEVICE_{i}_NAME'], + total_memory=int(os.environ[f'ORT_DEVICE_{i}_TOTAL_MEM']), + free_memory=int(os.environ[f'ORT_DEVICE_{i}_FREE_MEM']), + ) ) + devices += _ort_devices_info + if include_cpu: + devices.append(get_cpu_device_info()) - return _ort_devices_info + return devices -def _initialize_ort_devices(): +def _initialize_ort_devices_info(): """ Determine available ORT devices, and place info about them to os.environ, they will be available in spawned subprocesses. @@ -184,4 +185,4 @@ def _initialize_ort_devices(): os.environ[f'ORT_DEVICE_{i}_TOTAL_MEM'] = str(device['total_mem']) os.environ[f'ORT_DEVICE_{i}_FREE_MEM'] = str(device['free_mem']) -_initialize_ort_devices() +_initialize_ort_devices_info() diff --git a/xlib/text/ascii_table.py b/xlib/text/ascii_table.py index af823cc..2d3e94e 100644 --- a/xlib/text/ascii_table.py +++ b/xlib/text/ascii_table.py @@ -247,7 +247,8 @@ def ascii_table(table_def : List[str], if row_line is not None: lines.append(row_line) for sub_rows in rows: - + + for row in sub_rows: line = '' @@ -288,7 +289,8 @@ def ascii_table(table_def : List[str], line += right_border lines.append(line) - if row_line is not None: + + if len(sub_rows) != 0 and row_line is not None: lines.append(row_line) return '\n'.join(lines) \ No newline at end of file diff --git a/xlib/torch/__init__.py b/xlib/torch/__init__.py index 45b6a5d..8545f72 100644 --- a/xlib/torch/__init__.py +++ b/xlib/torch/__init__.py @@ -1,2 +1,2 @@ -from .device import (TorchDeviceInfo, TorchDevicesInfo, get_available_devices, - get_cpu_device) +from .device import (TorchDeviceInfo, get_available_devices_info, + get_cpu_device_info, get_device, get_device_info_by_index) diff --git a/xlib/torch/device.py b/xlib/torch/device.py index 83de3e9..c94cfac 100644 --- a/xlib/torch/device.py +++ b/xlib/torch/device.py @@ -1,4 +1,6 @@ -from typing import List +from typing import List, Union + +import torch class TorchDeviceInfo: @@ -45,96 +47,40 @@ class TorchDeviceInfo: def __repr__(self): return f'{self.__class__.__name__} object: ' + self.__str__() -# class TorchDevicesInfo: -# """ -# picklable list of TorchDeviceInfo -# """ -# def __init__(self, devices : List[TorchDeviceInfo] = None): -# if devices is None: -# devices = [] -# self._devices = devices - -# def __getstate__(self): -# return self.__dict__.copy() - -# def __setstate__(self, d): -# self.__init__() -# self.__dict__.update(d) - -# def add(self, device_or_devices : TorchDeviceInfo): -# if isinstance(device_or_devices, TorchDeviceInfo): -# if device_or_devices not in self._devices: -# self._devices.append(device_or_devices) -# elif isinstance(device_or_devices, TorchDevicesInfo): -# for device in device_or_devices: -# self.add(device) - -# def copy(self): -# return copy.deepcopy(self) - -# def get_count(self): return len(self._devices) - -# def get_largest_total_memory_device(self) -> TorchDeviceInfo: -# raise NotImplementedError() -# result = None -# idx_mem = 0 -# for device in self._devices: -# mem = device.get_total_memory() -# if result is None or (mem is not None and mem > idx_mem): -# result = device -# idx_mem = mem -# return result - -# def get_smallest_total_memory_device(self) -> TorchDeviceInfo: -# raise NotImplementedError() -# result = None -# idx_mem = sys.maxsize -# for device in self._devices: -# mem = device.get_total_memory() -# if result is None or (mem is not None and mem < idx_mem): -# result = device -# idx_mem = mem -# return result - -# def __len__(self): -# return len(self._devices) - -# def __getitem__(self, key): -# result = self._devices[key] -# if isinstance(key, slice): -# return self.__class__(result) -# return result - -# def __iter__(self): -# for device in self._devices: -# yield device - -# def __str__(self): return f'{self.__class__.__name__}:[' + ', '.join([ device.__str__() for device in self._devices ]) + ']' -# def __repr__(self): return f'{self.__class__.__name__}:[' + ', '.join([ device.__repr__() for device in self._devices ]) + ']' - - _torch_devices = None -def get_cpu_device() -> TorchDeviceInfo: +def get_cpu_device_info() -> TorchDeviceInfo: return TorchDeviceInfo(index=-1, name='CPU', total_memory=0) -def get_available_devices(include_cpu=True, cpu_only=False) -> List[TorchDeviceInfo]: +def get_device_info_by_index(index) -> Union[TorchDeviceInfo, None]: + for device in get_available_devices_info(include_cpu=False): + if device.get_index() == index: + return device + + return None + +def get_device(device_info : TorchDeviceInfo) -> torch.device: + if device_info.is_cpu(): + return torch.device('cpu') + return torch.device(f'cuda:{device_info.get_index()}') + +def get_available_devices_info(include_cpu=True, cpu_only=False) -> List[TorchDeviceInfo]: """ returns a list of available TorchDeviceInfo """ - global _torch_devices - if _torch_devices is None: - import torch - devices = [] - - if not cpu_only: + devices = [] + if not cpu_only: + global _torch_devices + if _torch_devices is None: + + _torch_devices = [] for i in range (torch.cuda.device_count()): device_props = torch.cuda.get_device_properties(i) - devices.append ( TorchDeviceInfo(index=i, name=device_props.name, total_memory=device_props.total_memory)) + _torch_devices.append ( TorchDeviceInfo(index=i, name=device_props.name, total_memory=device_props.total_memory)) + devices += _torch_devices - if include_cpu or cpu_only: - devices.append ( get_cpu_device() ) + if include_cpu: + devices.append ( get_cpu_device_info() ) - _torch_devices = devices - return _torch_devices + return devices