refactoring

This commit is contained in:
iperov 2021-11-07 10:03:15 +04:00
parent 8489949f2c
commit 30ba51edf7
24 changed files with 663 additions and 459 deletions

View file

@ -79,12 +79,12 @@ def main():
p.set_defaults(func=train_FaceAligner) p.set_defaults(func=train_FaceAligner)
def train_CTSOT(args): def train_CTSOT(args):
from apps.trainers.CTSOT.CTSOTTrainerApp import run_app from apps.trainers.CTSOT.CTSOTTrainerApp import CTSOTTrainerApp
run_app(userdata_path=Path(args.userdata_dir), faceset_path=Path(args.faceset_path)) CTSOTTrainerApp(workspace_path=Path(args.workspace_dir), faceset_path=Path(args.faceset_path))
p = train_parsers.add_parser('CTSOT') p = train_parsers.add_parser('CTSOT')
p.add_argument('--userdata-dir', default=None, action=fixPathAction, help="Directory to save app data.") p.add_argument('--workspace-dir', default=None, action=fixPathAction, help="Workspace directory.")
p.add_argument('--faceset-path', default=None, action=fixPathAction, help=".dfs path") p.add_argument('--faceset-path', default=None, action=fixPathAction, help=".dfs faceset path")
p.set_defaults(func=train_CTSOT) p.set_defaults(func=train_CTSOT)
def bad_args(arguments): def bad_args(arguments):

View file

@ -58,7 +58,7 @@ def get_available_devices() -> List[ORTDeviceInfo]:
class DFMModel: class DFMModel:
def __init__(self, model_path : Path, device : ORTDeviceInfo = None): def __init__(self, model_path : Path, device : ORTDeviceInfo = None):
if device is None: if device is None:
device = lib_ort.get_cpu_device() device = lib_ort.get_cpu_device_info()
self._model_path = model_path self._model_path = model_path
sess = self._sess = lib_ort.InferenceSession_with_device(str(model_path), device) sess = self._sess = lib_ort.InferenceSession_with_device(str(model_path), device)

View file

@ -0,0 +1,121 @@
from pathlib import Path
from typing import Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from xlib.file import SplittedFile
from xlib.torch import TorchDeviceInfo, get_cpu_device_info
class ResBlock(nn.Module):
def __init__(self, ch):
super().__init__()
self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)
def forward(self, inp):
x = inp
x = F.leaky_relu(self.conv1(x), 0.2)
x = F.leaky_relu(self.conv2(x) + inp, 0.2)
return x
class AutoEncoder(nn.Module):
def __init__(self, resolution, in_ch, ae_ch):
super().__init__()
self._resolution = resolution
self._in_ch = in_ch
self._ae_ch = ae_ch
self.conv1 = nn.Conv2d(in_ch*1, in_ch*2, kernel_size=5, stride=2, padding=2)
self.conv2 = nn.Conv2d(in_ch*2, in_ch*4, kernel_size=5, stride=2, padding=2)
self.conv3 = nn.Conv2d(in_ch*4, in_ch*8, kernel_size=5, stride=2, padding=2)
self.conv4 = nn.Conv2d(in_ch*8, in_ch*8, kernel_size=5, stride=2, padding=2)
self.dense_in = nn.Linear( in_ch*8 * ( resolution // (2**4) )**2, ae_ch)
self.dense_out = nn.Linear( ae_ch, in_ch*8 * ( resolution // (2**4) )**2 )
self.up_conv4 = nn.ConvTranspose2d(in_ch*8, in_ch*8, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
self.up_conv3 = nn.ConvTranspose2d(in_ch*8, in_ch*4, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
self.up_conv2 = nn.ConvTranspose2d(in_ch*4, in_ch*2, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
self.up_conv1 = nn.ConvTranspose2d(in_ch*2, in_ch*1, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
def forward(self, inp):
x = inp
x = F.leaky_relu(self.conv1(x), 0.1)
x = F.leaky_relu(self.conv2(x), 0.1)
x = F.leaky_relu(self.conv3(x), 0.1)
x = F.leaky_relu(self.conv4(x), 0.1)
x = x.view( (x.shape[0], np.prod(x.shape[1:])) )
x = self.dense_in(x)
x = self.dense_out(x)
x = x.view( (x.shape[0], self._in_ch*8, self._resolution // (2**4), self._resolution // (2**4) ))
x = F.leaky_relu(self.up_conv4(x), 0.1)
x = F.leaky_relu(self.up_conv3(x), 0.1)
x = F.leaky_relu(self.up_conv2(x), 0.1)
x = F.leaky_relu(self.up_conv1(x), 0.1)
# from xlib.console import diacon
# diacon.Diacon.stop()
# import code
# code.interact(local=dict(globals(), **locals()))
return x
class CTSOTNet(nn.Module):
def __init__(self, resolution, in_ch=6, inner_ch=64, ae_ch=256, out_ch=3, res_block_count = 12):
super().__init__()
self.in_conv = nn.Conv2d(in_ch, inner_ch, kernel_size=1, stride=1, padding=0)
self.ae = AutoEncoder(resolution, inner_ch, ae_ch)
self.out_conv = nn.Conv2d(inner_ch, out_ch, kernel_size=1, stride=1, padding=0)
def forward(self, img1_t, img2_t):
x = torch.cat([img1_t, img2_t], dim=1)
x = self.in_conv(x)
x = self.ae(x)
x = self.out_conv(x)
x = torch.tanh(x)
return x
class CTSOT:
def __init__(self, device_info : TorchDeviceInfo = None,
state_dict : Union[dict, None] = None,
training : bool = False):
if device_info is None:
device_info = get_cpu_device_info()
self.device_info = device_info
self._net = net = CTSOTNet()
if state_dict is not None:
net.load_state_dict(state_dict)
if training:
net.train()
else:
net.eval()
self.set_device(device_info)
def set_device(self, device_info : TorchDeviceInfo = None):
if device_info is None or device_info.is_cpu():
self._net.cpu()
else:
self._net.cuda(device_info.get_index())
def get_state_dict(self):
return self.net.state_dict()
def get_net(self) -> CTSOTNet:
return self._net

View file

@ -8,13 +8,13 @@ import torch.nn.functional as F
from xlib import math as lib_math from xlib import math as lib_math
from xlib.file import SplittedFile from xlib.file import SplittedFile
from xlib.image import ImageProcessor from xlib.image import ImageProcessor
from xlib.torch import TorchDeviceInfo, get_cpu_device from xlib.torch import TorchDeviceInfo, get_cpu_device_info
class S3FD: class S3FD:
def __init__(self, device_info : TorchDeviceInfo = None ): def __init__(self, device_info : TorchDeviceInfo = None ):
if device_info is None: if device_info is None:
device_info = get_cpu_device() device_info = get_cpu_device_info()
self.device_info = device_info self.device_info = device_info
path = Path(__file__).parent / 'S3FD.pth' path = Path(__file__).parent / 'S3FD.pth'

View file

@ -1,2 +1,3 @@
from .CenterFace.CenterFace import CenterFace, CenterFace_to_onnx from .CenterFace.CenterFace import CenterFace_to_onnx
from .S3FD.S3FD import S3FD from .S3FD.S3FD import S3FD
from .CTSOT.CTSOT import CTSOT, CTSOTNet

View file

@ -53,14 +53,10 @@ def extract_FaceSynthetics(inputdir_path : Path, faceset_path : Path):
""" """
if faceset_path.suffix != '.dfs': if faceset_path.suffix != '.dfs':
raise ValueError('faceset_path must have .dfs extension.') raise ValueError('faceset_path must have .dfs extension.')
filepaths = lib_path.get_files_paths(inputdir_path) filepaths = lib_path.get_files_paths(inputdir_path)
fs = lib_face.Faceset(faceset_path, write_access=True, recreate=True)
fs = lib_face.Faceset(faceset_path) for filepath in lib_con.progress_bar_iterator(filepaths, desc='Processing'):
fs.recreate()
for filepath in lib_con.progress_bar_iterator(filepaths, 'Processing'):
if filepath.suffix == '.txt': if filepath.suffix == '.txt':
image_filepath = filepath.parent / f'{filepath.name.split("_")[0]}.png' image_filepath = filepath.parent / f'{filepath.name.split("_")[0]}.png'
@ -93,17 +89,13 @@ def extract_FaceSynthetics(inputdir_path : Path, faceset_path : Path):
ufm.set_UImage_uuid(uimg.get_uuid()) ufm.set_UImage_uuid(uimg.get_uuid())
ufm.set_FRect(flmrks.get_FRect()) ufm.set_FRect(flmrks.get_FRect())
ufm.add_FLandmarks2D(flmrks) ufm.add_FLandmarks2D(flmrks)
fs.add_UFaceMark(ufm)
fs.add_UImage(uimg, format='png') fs.add_UImage(uimg, format='png')
fs.add_UFaceMark(ufm)
fs.optimize()
fs.shrink()
fs.close() fs.close()
import code
code.interact(local=dict(globals(), **locals()))
# seg_filepath = input_path / ( Path(image_filepath).stem + '_seg.png') # seg_filepath = input_path / ( Path(image_filepath).stem + '_seg.png')
# if not seg_filepath.exists(): # if not seg_filepath.exists():
# raise ValueError(f'{seg_filepath} does not exist') # raise ValueError(f'{seg_filepath} does not exist')

View file

@ -15,7 +15,7 @@ def gaussian_blur (input_t : Tensor, sigma, dtype=None) -> Tensor:
""" """
if sigma <= 0.0: if sigma <= 0.0:
return input_t.copy() #TODO return input_t.copy()
device = input_t.get_device() device = input_t.get_device()

View file

@ -3,113 +3,220 @@ import threading
import time import time
from enum import IntEnum from enum import IntEnum
from typing import Any, Callable, List, Tuple, Union from typing import Any, Callable, List, Tuple, Union
from numbers import Number
from ... import text as lib_text from ... import text as lib_text
class EDlgMode(IntEnum): class EDlgMode(IntEnum):
UNDEFINED = 0 UNHANDLED = 0
BACK = 1 BACK = 1
RELOAD = 2 RELOAD = 2
WRONG_INPUT = 3 WRONG_INPUT = 3
SUCCESS = 4 HANDLED = 4
class DlgChoice: class DlgChoice:
def __init__(self, name : str = None, row_desc : str = None): def __init__(self, name : str = None,
super().__init__() row_def : str = None,
on_choose : Callable = None):
if len(name) == 0: if len(name) == 0:
raise ValueError('Zero len name is not valid.') raise ValueError('Zero len name is not valid.')
self._name = name self._name = name
self._row_desc = row_desc self._row_def = row_def
self._on_choose = on_choose
def get_name(self) -> Union[str, None]: return self._name def get_name(self) -> Union[str, None]: return self._name
def get_row_desc(self) -> Union[str, None]: return self._row_desc def get_row_def(self) -> Union[str, None]: return self._row_def
def get_on_choose(self) -> Callable: return self._on_choose
class Dlg: class Dlg:
def __init__(self, title : str = None, has_go_back=True): def __init__(self, on_recreate : Callable[ [], 'Dlg'] = None,
on_back : Callable = None,
top_rows_def : Union[str, List[str]] = None,
bottom_rows_def : Union[str, List[str]] = None,
):
""" """
base class for Diacon dialogs.
""" """
self._title = title self._on_recreate = on_recreate
self._has_go_back = has_go_back self._on_back = on_back
self._top_rows_def = top_rows_def
self._bottom_rows_def = bottom_rows_def
def recreate(self):
"""
"""
if self._on_recreate is not None:
return self._on_recreate(self)
else:
raise Exception('on_recreate() is not defined.')
def get_name(self) -> str: return self._name def get_name(self) -> str: return self._name
def handle_user_input(self, s : str) -> EDlgMode: def set_current(self, print=True):
Diacon.update_dlg(self, print=print)
def handle_user_input(self, s : str):
""" """
""" """
s = s.strip() mode = self.on_user_input(s.strip())
# ? and < available in any dialog, handle them first if mode == EDlgMode.UNHANDLED:
s_len = len(s) mode = EDlgMode.RELOAD
if s_len == 0: if mode == EDlgMode.WRONG_INPUT:
print('\nWrong input')
mode = EDlgMode.RELOAD
if mode == EDlgMode.RELOAD:
self.recreate().set_current()
if mode == EDlgMode.BACK:
if self._on_back is not None:
self._on_back(self)
#overridable
def on_user_input(self, s : str) -> EDlgMode:
if len(s) == 0:
return EDlgMode.RELOAD return EDlgMode.RELOAD
if s_len == 1: if self._on_back is not None and len(s) == 1:
#if s == '?':
# return EDlgMode.RELOAD
if s == '<': if s == '<':
return EDlgMode.BACK return EDlgMode.BACK
return EDlgMode.UNHANDLED
return self.on_user_input(s)
def print(self, table_width_max=80, col_spacing = 3): def print(self, table_width_max=80, col_spacing = 3):
""" """
print dialog print dialog
""" """
# Gather table lines
table_def : List[str]= [] table_def : List[str]= []
if self._has_go_back: trd = self._top_rows_def
brd = self._bottom_rows_def
if trd is not None:
if not isinstance(trd, (list,tuple)):
trd = [trd]
table_def += trd
if self._on_back is not None:
table_def.append('| < | Go back.') table_def.append('| < | Go back.')
table_def.append('|99') table_def.append('|99')
table_def = self.on_print(table_def) table_def = self.on_print(table_def)
table = lib_text.ascii_table(table_def, max_table_width=80, if brd is not None:
left_border = None, if not isinstance(brd, (list,tuple)):
right_border = None, brd = [brd]
table_def += brd
table = lib_text.ascii_table(table_def, max_table_width=80,
left_border = '| ',
right_border = ' |',
border = ' | ', border = ' | ',
row_symbol = None) row_symbol = None,
)
print() print()
print(table) print(table)
#overridable #overridable
def on_print(self, table_lines : List[Tuple[str,str]]): def on_print(self, table_lines : List[Tuple[str,str]]):
return table_lines return table_lines
class DlgNumber(Dlg):
def __init__(self, is_float : bool,
current_value = None,
min_value = None,
max_value = None,
clip_min_value = None,
clip_max_value = None,
on_value : Callable[ [Dlg, Number], None] = None,
on_recreate : Callable[ [], 'Dlg'] = None,
on_back : Callable = None,
top_rows_def : Union[str, List[str]] = None,
bottom_rows_def : Union[str, List[str]] = None, ):
super().__init__(on_recreate=on_recreate, on_back=on_back, top_rows_def=top_rows_def, bottom_rows_def=bottom_rows_def)
if min_value is not None and max_value is not None and min_value > max_value:
raise ValueError('min_value > max_value')
if clip_min_value is not None and clip_max_value is not None and clip_min_value > clip_max_value:
raise ValueError('clip_min_value > clip_max_value')
self._is_float = is_float
self._current_value = current_value
self._min_value = min_value
self._max_value = max_value
self._clip_min_value = clip_min_value
self._clip_max_value = clip_max_value
self._on_value = on_value
#overridable #overridable
def on_user_input(self, s : str) -> EDlgMode: def on_print(self, table_def : List[str]):
"""
handle user input minv, maxv = self._min_value, self._max_value
return False if input is invalid
""" if self._is_float:
return EDlgMode.UNDEFINED line = '| * | Enter float number'
else:
line = '| * | Enter integer number'
if minv is not None and maxv is None:
line += f' in range: [{minv} ... )'
elif minv is None and maxv is not None:
line += f' in range: ( ... {maxv} ]'
elif minv is not None and maxv is not None:
line += f' in range: [{minv} ... {maxv} ]'
table_def.append(line)
return table_def
#overridable
def on_user_input(self, s : str) -> bool:
result = super().on_user_input(s)
if result == EDlgMode.UNHANDLED:
try:
print(s)
v = float(s) if self._is_float else int(s)
if self._min_value is not None:
if v < self._min_value:
return EDlgMode.WRONG_INPUT
if self._max_value is not None:
if v > self._max_value:
return EDlgMode.WRONG_INPUT
if self._clip_min_value is not None:
if v < self._clip_min_value:
v = self._clip_min_value
if self._clip_max_value is not None:
if v > self._clip_max_value:
v = self._clip_max_value
if self._on_value is not None:
self._on_value(self, v)
return EDlgMode.HANDLED
except:
return EDlgMode.WRONG_INPUT
return result
class DlgChoices(Dlg): class DlgChoices(Dlg):
def __init__(self, choices : List[DlgChoice], multiple_choices=False, title : str = None, has_go_back = True): def __init__(self, choices : List[DlgChoice],
""" on_multi_choice : Callable[ [ List[DlgChoice] ], None] = None,
on_recreate : Callable[ [Dlg], Dlg] = None,
""" on_back : Callable = None,
super().__init__(title=title, has_go_back=has_go_back) top_rows_def : Union[str, List[str]] = None,
bottom_rows_def : Union[str, List[str]] = None,
):
super().__init__(on_recreate=on_recreate, on_back=on_back, top_rows_def=top_rows_def, bottom_rows_def=bottom_rows_def)
self._choices = choices self._choices = choices
self._multiple_choices = multiple_choices self._on_multi_choice = on_multi_choice
self._results = None
self._results_id = None
self._short_names = [choice.get_name() for choice in choices] self._short_names = [choice.get_name() for choice in choices]
# if any([x is not None for x in self._short_names]):
# # Using short names from choices
# if any([x is None for x in self._short_names]):
# raise Exception('No short name for one of choices.')
# if len(set(self._short_names)) != len(self._short_names):
# raise ValueError(f'Contains duplicate short names: {self._short_names}')
# else:
# Make short names for all choices # Make short names for all choices
names = [ choice.get_name() for choice in choices ] names = [ choice.get_name() for choice in choices ]
names_len = len(names) names_len = len(names)
@ -139,27 +246,14 @@ class DlgChoices(Dlg):
break break
self._short_names = short_names self._short_names = short_names
def get_selected_choices(self) -> List[DlgChoice]:
"""
returns selected choices
"""
return self._results
def get_selected_choices_id(self) -> List[int]:
"""
returns selected choice
"""
return self._results_id
#overridable #overridable
def on_print(self, table_def : List[str]): def on_print(self, table_def : List[str]):
for short_name, choice in zip(self._short_names, self._choices): for short_name, choice in zip(self._short_names, self._choices):
row_def = f'| {short_name}' row_def = f'| {short_name}'
row_desc = choice.get_row_desc() x = choice.get_row_def()
if row_desc is not None: if x is not None:
row_def += row_desc row_def += x
table_def.append(row_def) table_def.append(row_def)
return table_def return table_def
@ -167,35 +261,36 @@ class DlgChoices(Dlg):
#overridable #overridable
def on_user_input(self, s : str) -> bool: def on_user_input(self, s : str) -> bool:
result = super().on_user_input(s) result = super().on_user_input(s)
if result == EDlgMode.UNDEFINED: if result == EDlgMode.UNHANDLED:
if self._multiple_choices: if self._on_multi_choice is not None:
multi_s = s.split(',') multi_s = s.split(',')
else: else:
multi_s = [s] multi_s = [s]
results = [] choices_id = []
results_id = []
for s in multi_s: for s in multi_s:
s = s.strip() x = [ i for i, short_name in enumerate(self._short_names) if s.strip() == short_name ]
x = [ i for i,short_name in enumerate(self._short_names) if s == short_name ]
if len(x) == 0: if len(x) == 0:
# no short name match # No short name match
return EDlgMode.WRONG_INPUT return EDlgMode.WRONG_INPUT
else: else:
id = x[0] id = x[0]
results_id.append(id) choices_id.append(id)
results.append(self._choices[id])
if len(set(choices_id)) != len(choices_id):
if len(set(results_id)) != len(results_id):
# Duplicate input # Duplicate input
return EDlgMode.WRONG_INPUT return EDlgMode.WRONG_INPUT
self._results = results for id in choices_id:
self._results_id = results_id on_choose = self._choices[id].get_on_choose()
if on_choose is not None:
on_choose(self)
if self._on_multi_choice is not None:
self._on_multi_choice(choices_id)
return EDlgMode.SUCCESS return EDlgMode.HANDLED
return result return result
@ -204,42 +299,9 @@ class DlgChoices(Dlg):
class _Diacon: class _Diacon:
""" """
User dialog with via console. User dialog with via console.
Internal architecture:
[
Main-Thread
current thread from which __init__() called
]
[
Dialog-Thread
separate thread where dialogs are handled and dynamically created
we need this thread, because main thread can be busy,
for example training neural network
calls on_dlg() provided with __init__
thus keep in mind on_dlg() works in separate thread
This thread must not be blocked inside on_dlg(),
because Diacon.stop() can be called that stops all threads.
]
[
Input-Thread
separate thread where user input is accepted in non-blocking mode,
and transfered to processing thread
]
""" """
def __init__(self): def __init__(self):
self._on_dlg : Callable = None
self._lock = threading.RLock() self._lock = threading.RLock()
self._current_dlg : Dlg = None self._current_dlg : Dlg = None
self._new_dlg : Dlg = None self._new_dlg : Dlg = None
@ -250,11 +312,10 @@ class _Diacon:
self._input_request = False self._input_request = False
self._input_result : str = None self._input_result : str = None
def start(self, on_dlg : Callable): def start(self):
if self._started: if self._started:
raise Exception('Diacon already started.') raise Exception('Diacon already started.')
self._started = True self._started = True
self._on_dlg = on_dlg
self._input_t = threading.Thread(target=self._input_thread, daemon=True) self._input_t = threading.Thread(target=self._input_thread, daemon=True)
self._input_t.start() self._input_t.start()
@ -285,8 +346,6 @@ class _Diacon:
def _dialog_thread(self, ): def _dialog_thread(self, ):
self._on_dlg(None, EDlgMode.RELOAD)
while self._started: while self._started:
with self._lock: with self._lock:
@ -303,13 +362,7 @@ class _Diacon:
if input_result is not None: if input_result is not None:
if self._current_dlg is not None: if self._current_dlg is not None:
mode = self._current_dlg.handle_user_input(input_result) self._current_dlg.handle_user_input(input_result)
if mode == EDlgMode.WRONG_INPUT:
print('\nWrong input')
mode = EDlgMode.RELOAD
if mode == EDlgMode.UNDEFINED:
mode = EDlgMode.RELOAD
self._on_dlg(self._current_dlg, mode)
continue continue
time.sleep(0.005) time.sleep(0.005)
@ -332,6 +385,9 @@ class _Diacon:
show current or set new Dialog show current or set new Dialog
Can be called from any thread. Can be called from any thread.
""" """
if not self._started:
self.start()
self._new_dlg = (new_dlg, print) self._new_dlg = (new_dlg, print)
Diacon = _Diacon() Diacon = _Diacon()

View file

@ -1 +1 @@
from .Diacon import Diacon, Dlg, DlgChoice, DlgChoices, EDlgMode from .Diacon import Diacon, Dlg, DlgChoice, DlgChoices, EDlgMode, DlgNumber

View file

@ -9,7 +9,7 @@ class FMask:
def __init__(self, _from_pickled=False): def __init__(self, _from_pickled=False):
""" """
""" """
self._uuid : Union[bytes, None] = uuid.uuid4().bytes_le if not _from_pickled else None self._uuid : Union[bytes, None] = uuid.uuid4().bytes if not _from_pickled else None
self._mask_type : Union[FMask.Type, None] = None self._mask_type : Union[FMask.Type, None] = None
self._FImage_uuid : Union[bytes, None] = None self._FImage_uuid : Union[bytes, None] = None

View file

@ -52,9 +52,25 @@ class FaceWarper:
self._rw_grid_tx = rnd_state.uniform(*rw_grid_tx) if isinstance(rw_grid_tx, Iterable) else rw_grid_tx self._rw_grid_tx = rnd_state.uniform(*rw_grid_tx) if isinstance(rw_grid_tx, Iterable) else rw_grid_tx
self._rw_grid_ty = rnd_state.uniform(*rw_grid_ty) if isinstance(rw_grid_ty, Iterable) else rw_grid_ty self._rw_grid_ty = rnd_state.uniform(*rw_grid_ty) if isinstance(rw_grid_ty, Iterable) else rw_grid_ty
self._warp_rnd_mat = Affine2DUniMat.from_transformation(0.5, 0.5, self._rw_grid_rot_deg, 1.0+self._rw_grid_scale, self._rw_grid_tx, self._rw_grid_ty)
self._align_rnd_mat = Affine2DUniMat.from_transformation(0.5, 0.5, self._align_rot_deg, 1.0+self._align_scale, self._align_tx, self._align_ty)
self._rnd_state_state = rnd_state.get_state() self._rnd_state_state = rnd_state.get_state()
self._cached = {} self._cached = {}
def get_aligned_random_transform_mat(self) -> Affine2DUniMat:
"""
returns Affine2DUniMat that represents transformation from aligned face to randomly transformed aligned face
"""
mat1 = self._img_to_face_uni_mat
mat2 = (self._face_to_img_uni_mat * self._align_rnd_mat).invert()
pts = [ [0,0], [1,0], [1,1]]
src_pts = mat1.transform_points(pts)
dst_pts = mat2.transform_points(pts)
return Affine2DUniMat.from_3_pairs(src_pts, dst_pts)
def transform(self, img : np.ndarray, out_res : int, random_warp : bool = True) -> np.ndarray: def transform(self, img : np.ndarray, out_res : int, random_warp : bool = True) -> np.ndarray:
""" """
transform an image. transform an image.
@ -91,9 +107,7 @@ class FaceWarper:
face_warp_grid = FaceWarper._gen_random_warp_uni_grid_diff(out_res, self._rw_grid_cell_count, 0.12, rnd_state) face_warp_grid = FaceWarper._gen_random_warp_uni_grid_diff(out_res, self._rw_grid_cell_count, 0.12, rnd_state)
# make a randomly transformable mat to transform face_warp_grid from face to image # make a randomly transformable mat to transform face_warp_grid from face to image
face_warp_grid_mat = (self._face_to_img_uni_mat * face_warp_grid_mat = self._face_to_img_uni_mat * self._warp_rnd_mat
Affine2DUniMat.from_transformation(0.5, 0.5, self._rw_grid_rot_deg, 1.0+self._rw_grid_scale, self._rw_grid_tx, self._rw_grid_ty)
)
# warp face_warp_grid to the space of image and merge with image_grid # warp face_warp_grid to the space of image and merge with image_grid
image_grid += cv2.warpAffine(face_warp_grid, face_warp_grid_mat.to_exact_mat(out_res,out_res, W, H), (W,H), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT) image_grid += cv2.warpAffine(face_warp_grid, face_warp_grid_mat.to_exact_mat(out_res,out_res, W, H), (W,H), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
@ -101,9 +115,10 @@ class FaceWarper:
# scale uniform grid from to image size # scale uniform grid from to image size
image_grid *= (H-1, W-1) image_grid *= (H-1, W-1)
# apply random transormations for align mat # apply random transformations for align mat
img_to_face_rnd_mat = (self._face_to_img_uni_mat * Affine2DMat.from_transformation(0.5, 0.5, self._align_rot_deg, 1.0+self._align_scale, self._align_tx, self._align_ty) #img_to_face_rnd_uni_mat = (self._face_to_img_uni_mat * self._align_rnd_mat).invert()
).invert().to_exact_mat(W,H,out_res,out_res)
img_to_face_rnd_mat = (self._face_to_img_uni_mat * self._align_rnd_mat).invert().to_exact_mat(W,H,out_res,out_res)
# warp image_grid to face space # warp image_grid to face space
image_grid = cv2.warpAffine(image_grid, img_to_face_rnd_mat, (out_res,out_res), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE ) image_grid = cv2.warpAffine(image_grid, img_to_face_rnd_mat, (out_res,out_res), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE )

View file

@ -1,19 +1,22 @@
import pickle import pickle
import sqlite3 import uuid
from pathlib import Path from pathlib import Path
from typing import Generator, List, Union, Iterable from typing import Generator, Iterable, List, Union
import cv2 import cv2
import h5py
import numpy as np import numpy as np
from .. import console as lib_con
from .FMask import FMask from .FMask import FMask
from .UFaceMark import UFaceMark from .UFaceMark import UFaceMark
from .UImage import UImage from .UImage import UImage
from .UPerson import UPerson from .UPerson import UPerson
class Faceset: class Faceset:
def __init__(self, path = None): def __init__(self, path = None, write_access=False, recreate=False):
""" """
Faceset is a class to store and manage face related data. Faceset is a class to store and manage face related data.
@ -21,205 +24,155 @@ class Faceset:
path path to faceset .dfs file path path to faceset .dfs file
write_access
recreate
Can be pickled. Can be pickled.
""" """
self._f = None
self._path = path = Path(path) self._path = path = Path(path)
if path.suffix != '.dfs': if path.suffix != '.dfs':
raise ValueError('Path must be a .dfs file') raise ValueError('Path must be a .dfs file')
self._conn = conn = sqlite3.connect(path, isolation_level=None) if path.exists():
self._cur = cur = conn.cursor() if write_access and recreate:
path.unlink()
elif not write_access:
raise FileNotFoundError(f'File {path} not found.')
cur = self._get_cursor() self._mode = 'a' if write_access else 'r'
cur.execute('BEGIN IMMEDIATE') self._open()
if not self._is_table_exists('FacesetInfo'):
self.recreate(shrink=False, _transaction=False)
cur.execute('COMMIT')
self.shrink()
else:
cur.execute('END')
def __del__(self): def __del__(self):
self.close() self.close()
def __getstate__(self): def __getstate__(self):
return {'_path' : self._path} return {'_path' : self._path, '_mode' : self._mode}
def __setstate__(self, d): def __setstate__(self, d):
self.__init__( d['_path'] ) self._f = None
self._path = d['_path']
self._mode = d['_mode']
self._open()
def __repr__(self): return self.__str__() def __repr__(self): return self.__str__()
def __str__(self): def __str__(self):
return f"Faceset. UImage:{self.get_UImage_count()} UFaceMark:{self.get_UFaceMark_count()} UPerson:{self.get_UPerson_count()}" return f"Faceset. UImage:{self.get_UImage_count()} UFaceMark:{self.get_UFaceMark_count()} UPerson:{self.get_UPerson_count()}"
def _is_table_exists(self, name): def _open(self):
return self._cur.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", [name]).fetchone()[0] != 0 if self._f is None:
self._f = f = h5py.File(self._path, mode=self._mode)
self._UFaceMark_grp = f.require_group('UFaceMark')
self._UImage_grp = f.require_group('UImage')
self._UImage_image_data_grp = f.require_group('UImage_image_data')
self._UPerson_grp = f.require_group('UPerson')
def _get_cursor(self) -> sqlite3.Cursor: return self._cur
def close(self): def close(self):
if self._cur is not None: if self._f is not None:
self._cur.close() self._f.close()
self._cur = None self._f = None
if self._conn is not None: def optimize(self, verbose=True):
self._conn.close()
self._conn = None
def shrink(self):
self._cur.execute('VACUUM')
def recreate(self, shrink=True, _transaction=True):
""" """
delete all data and recreate Faceset structure. recreate Faceset with optimized structure.
""" """
cur = self._get_cursor() if verbose:
print(f'Optimizing {self._path.name}...')
if _transaction: tmp_path = self._path.parent / (self._path.stem + '_optimizing' + self._path.suffix)
cur.execute('BEGIN IMMEDIATE')
for table_name, in cur.execute("SELECT name from sqlite_master where type = 'table';").fetchall(): tmp_fs = Faceset(tmp_path, write_access=True, recreate=True)
cur.execute(f'DROP TABLE {table_name}') self._group_copy(tmp_fs._UFaceMark_grp, self._UFaceMark_grp, verbose=verbose)
self._group_copy(tmp_fs._UPerson_grp, self._UPerson_grp, verbose=verbose)
self._group_copy(tmp_fs._UImage_grp, self._UImage_grp, verbose=verbose)
self._group_copy(tmp_fs._UImage_image_data_grp, self._UImage_image_data_grp, verbose=verbose)
tmp_fs.close()
(cur.execute('CREATE TABLE FacesetInfo (version INT)') self.close()
.execute('INSERT INTO FacesetInfo VALUES (1)') self._path.unlink()
tmp_path.rename(self._path)
self._open()
.execute('CREATE TABLE UImage (uuid BLOB, name TEXT, format TEXT, data BLOB)') def _group_copy(self, group_dst : h5py.Group, group_src : h5py.Group, verbose=True):
.execute('CREATE TABLE UPerson (uuid BLOB, data BLOB)') for key, value in lib_con.progress_bar_iterator(group_src.items(), desc=f'Copying {group_src.name} -> {group_dst.name}', suppress_print=not verbose):
.execute('CREATE TABLE UFaceMark (uuid BLOB, UImage_uuid BLOB, UPerson_uuid BLOB, data BLOB)') d = group_dst.create_dataset(key, shape=value.shape, dtype=value.dtype )
) d[:] = value[:]
for a_key, a_value in value.attrs.items():
d.attrs[a_key] = a_value
if _transaction: def _group_read_bytes(self, group : h5py.Group, key : str, check_key=True) -> Union[bytes, None]:
cur.execute('COMMIT') if check_key and key not in group:
return None
dataset = group[key]
data_bytes = bytearray(len(dataset))
dataset.read_direct(np.frombuffer(data_bytes, dtype=np.uint8))
return data_bytes
if shrink: def _group_write_bytes(self, group : h5py.Group, key : str, data : bytes, update_existing=True) -> Union[h5py.Dataset, None]:
self.shrink() if key in group:
if not update_existing:
return None
del group[key]
return group.create_dataset(key, data=np.frombuffer(data, dtype=np.uint8) )
################### ###################
### UFaceMark ### UFaceMark
################### ###################
def _UFaceMark_from_db_row(self, db_row) -> UFaceMark: def add_UFaceMark(self, ufacemark_or_list : UFaceMark, update_existing=True):
uuid, UImage_uuid, UPerson_uuid, data = db_row
ufm = UFaceMark()
ufm.restore_state(pickle.loads(data))
return ufm
def add_UFaceMark(self, ufacemark_or_list : UFaceMark):
""" """
add or update UFaceMark in DB add or update UFaceMark in DB
""" """
if not isinstance(ufacemark_or_list, Iterable): if not isinstance(ufacemark_or_list, Iterable):
ufacemark_or_list : List[UFaceMark] = [ufacemark_or_list] ufacemark_or_list : List[UFaceMark] = [ufacemark_or_list]
cur = self._cur
cur.execute('BEGIN IMMEDIATE')
for ufm in ufacemark_or_list: for ufm in ufacemark_or_list:
uuid = ufm.get_uuid() self._group_write_bytes(self._UFaceMark_grp, ufm.get_uuid().hex(), pickle.dumps(ufm.dump_state()), update_existing=update_existing )
UImage_uuid = ufm.get_UImage_uuid()
UPerson_uuid = ufm.get_UPerson_uuid()
data = pickle.dumps(ufm.dump_state())
if cur.execute('SELECT COUNT(*) from UFaceMark where uuid=?', [uuid] ).fetchone()[0] != 0:
cur.execute('UPDATE UFaceMark SET UImage_uuid=?, UPerson_uuid=?, data=? WHERE uuid=?',
[UImage_uuid, UPerson_uuid, data, uuid])
else:
cur.execute('INSERT INTO UFaceMark VALUES (?, ?, ?, ?)', [uuid, UImage_uuid, UPerson_uuid, data])
cur.execute('COMMIT')
def get_UFaceMark_count(self) -> int: def get_UFaceMark_count(self) -> int:
return self._cur.execute('SELECT COUNT(*) FROM UFaceMark').fetchone()[0] return len(self._UFaceMark_grp.keys())
def get_all_UFaceMark(self) -> List[UFaceMark]: def get_all_UFaceMark(self) -> List[UFaceMark]:
return [ self._UFaceMark_from_db_row(db_row) for db_row in self._cur.execute('SELECT * FROM UFaceMark').fetchall() ] return [ UFaceMark.from_state(pickle.loads(self._group_read_bytes(self._UFaceMark_grp, key, check_key=False))) for key in self._UFaceMark_grp.keys() ]
def get_all_UFaceMark_uuids(self) -> List[bytes]:
return [ uuid.UUID(key).bytes for key in self._UFaceMark_grp.keys() ]
def get_UFaceMark_by_uuid(self, uuid : bytes) -> Union[UFaceMark, None]: def get_UFaceMark_by_uuid(self, uuid : bytes) -> Union[UFaceMark, None]:
c = self._cur.execute('SELECT * FROM UFaceMark WHERE uuid=?', [uuid]) data = self._group_read_bytes(self._UFaceMark_grp, uuid.hex())
db_row = c.fetchone() if data is None:
if db_row is None:
return None return None
return UFaceMark.from_state(pickle.loads(data))
return self._UFaceMark_from_db_row(db_row) def delete_UFaceMark_by_uuid(self, uuid : bytes) -> bool:
key = uuid.hex()
if key in self._UFaceMark_grp:
del self._UFaceMark_grp[key]
return True
return False
def iter_UFaceMark(self) -> Generator[UFaceMark, None, None]: def iter_UFaceMark(self) -> Generator[UFaceMark, None, None]:
""" """
returns Generator of UFaceMark returns Generator of UFaceMark
""" """
for db_row in self._cur.execute('SELECT * FROM UFaceMark').fetchall(): for key in self._UFaceMark_grp.keys():
yield self._UFaceMark_from_db_row(db_row) yield UFaceMark.from_state(pickle.loads(self._group_read_bytes(self._UFaceMark_grp, key, check_key=False)))
def delete_all_UFaceMark(self): def delete_all_UFaceMark(self):
""" """
deletes all UFaceMark from DB deletes all UFaceMark from DB
""" """
(self._cur.execute('BEGIN IMMEDIATE') for key in self._UFaceMark_grp.keys():
.execute('DELETE FROM UFaceMark') del self._UFaceMark_grp[key]
.execute('COMMIT') )
###################
### UPerson
###################
def _UPerson_from_db_row(self, db_row) -> UPerson:
uuid, data = db_row
up = UPerson()
up.restore_state(pickle.loads(data))
return up
def add_UPerson(self, uperson_or_list : UPerson):
"""
add or update UPerson in DB
"""
if not isinstance(uperson_or_list, Iterable):
uperson_or_list : List[UPerson] = [uperson_or_list]
cur = self._cur
cur.execute('BEGIN IMMEDIATE')
for uperson in uperson_or_list:
uuid = uperson.get_uuid()
data = pickle.dumps(uperson.dump_state())
if cur.execute('SELECT COUNT(*) from UPerson where uuid=?', [uuid]).fetchone()[0] != 0:
cur.execute('UPDATE UPerson SET data=? WHERE uuid=?', [data])
else:
cur.execute('INSERT INTO UPerson VALUES (?, ?)', [uuid, data])
cur.execute('COMMIT')
def get_UPerson_count(self) -> int:
return self._cur.execute('SELECT COUNT(*) FROM UPerson').fetchone()[0]
def get_all_UPerson(self) -> List[UPerson]:
return [ self._UPerson_from_db_row(db_row) for db_row in self._cur.execute('SELECT * FROM UPerson').fetchall() ]
def iter_UPerson(self) -> Generator[UPerson, None, None]:
"""
iterator of all UPerson's
"""
for db_row in self._cur.execute('SELECT * FROM UPerson').fetchall():
yield self._UPerson_from_db_row(db_row)
def delete_all_UPerson(self):
"""
deletes all UPerson from DB
"""
(self._cur.execute('BEGIN IMMEDIATE')
.execute('DELETE FROM UPerson')
.execute('COMMIT') )
################### ###################
### UImage ### UImage
################### ###################
def _UImage_from_db_row(self, db_row) -> UImage: def add_UImage(self, uimage_or_list : UImage, format : str = 'png', quality : int = 100, update_existing=True):
uuid, name, format, data_bytes = db_row
img = cv2.imdecode(np.frombuffer(data_bytes, dtype=np.uint8), flags=cv2.IMREAD_UNCHANGED)
uimg = UImage()
uimg.set_uuid(uuid)
uimg.set_name(name)
uimg.assign_image(img)
return uimg
def add_UImage(self, uimage_or_list : UImage, format : str = 'webp', quality : int = 100):
""" """
add or update UImage in DB add or update UImage in DB
@ -239,9 +192,8 @@ class Faceset:
raise ValueError('quality must be in range [0..100]') raise ValueError('quality must be in range [0..100]')
if not isinstance(uimage_or_list, Iterable): if not isinstance(uimage_or_list, Iterable):
uimage_or_list = [uimage_or_list] uimage_or_list : List[UImage] = [uimage_or_list]
uimage_datas = []
for uimage in uimage_or_list: for uimage in uimage_or_list:
if format == 'webp': if format == 'webp':
imencode_args = [int(cv2.IMWRITE_WEBP_QUALITY), quality] imencode_args = [int(cv2.IMWRITE_WEBP_QUALITY), quality]
@ -251,44 +203,112 @@ class Faceset:
imencode_args = [int(cv2.IMWRITE_JPEG2000_COMPRESSION_X1000), quality*10] imencode_args = [int(cv2.IMWRITE_JPEG2000_COMPRESSION_X1000), quality*10]
else: else:
imencode_args = [] imencode_args = []
ret, data_bytes = cv2.imencode( f'.{format}', uimage.get_image(), imencode_args) ret, data_bytes = cv2.imencode( f'.{format}', uimage.get_image(), imencode_args)
if not ret: if not ret:
raise Exception(f'Unable to encode image format {format}') raise Exception(f'Unable to encode image format {format}')
uimage_datas.append(data_bytes.data)
cur = self._cur key = uimage.get_uuid().hex()
cur.execute('BEGIN IMMEDIATE')
for uimage, data in zip(uimage_or_list, uimage_datas):
uuid = uimage.get_uuid()
if cur.execute('SELECT COUNT(*) from UImage where uuid=?', [uuid] ).fetchone()[0] != 0:
cur.execute('UPDATE UImage SET name=?, format=?, data=? WHERE uuid=?', [uimage.get_name(), format, data, uuid])
else:
cur.execute('INSERT INTO UImage VALUES (?, ?, ?, ?)', [uuid, uimage.get_name(), format, data])
cur.execute('COMMIT')
def get_UImage_count(self) -> int: return self._cur.execute('SELECT COUNT(*) FROM UImage').fetchone()[0] self._group_write_bytes(self._UImage_grp, key, pickle.dumps(uimage.dump_state(exclude_image=True)), update_existing=update_existing )
def get_UImage_by_uuid(self, uuid : Union[bytes, None]) -> Union[UImage, None]: d = self._group_write_bytes(self._UImage_image_data_grp, key, data_bytes.data, update_existing=update_existing )
""" d.attrs['format'] = format
""" d.attrs['quality'] = quality
if uuid is None:
def get_UImage_count(self) -> int:
return len(self._UImage_grp.keys())
def get_all_UImage(self) -> List[UImage]:
return [ self._get_UImage_by_key(key) for key in self._UImage_grp.keys() ]
def get_all_UImage_uuids(self) -> List[bytes]:
return [ uuid.UUID(key).bytes for key in self._UImage_grp.keys() ]
def _get_UImage_by_key(self, key, check_key=True) -> Union[UImage, None]:
data = self._group_read_bytes(self._UImage_grp, key, check_key=check_key)
if data is None:
return None return None
uimg = UImage.from_state(pickle.loads(data))
db_row = self._cur.execute('SELECT * FROM UImage where uuid=?', [uuid]).fetchone() image_data = self._group_read_bytes(self._UImage_image_data_grp, key, check_key=check_key)
if db_row is None: if image_data is not None:
return None uimg.assign_image (cv2.imdecode(np.frombuffer(image_data, dtype=np.uint8), flags=cv2.IMREAD_UNCHANGED))
return self._UImage_from_db_row(db_row)
def iter_UImage(self) -> Generator[UImage, None, None]: return uimg
def get_UImage_by_uuid(self, uuid : bytes) -> Union[UImage, None]:
return self._get_UImage_by_key(uuid.hex())
def delete_UImage_by_uuid(self, uuid : bytes):
key = uuid.hex()
if key in self._UImage_grp:
del self._UImage_grp[key]
if key in self._UImage_image_data_grp:
del self._UImage_image_data_grp[key]
def iter_UImage(self, include_key=False) -> Generator[UImage, None, None]:
""" """
iterator of all UImage's returns Generator of UImage
""" """
for db_row in self._cur.execute('SELECT * FROM UImage').fetchall(): for key in self._UImage_grp.keys():
yield self._UImage_from_db_row(db_row) uimg = self._get_UImage_by_key(key, check_key=False)
yield (uimg, key) if include_key else uimg
def delete_all_UImage(self): def delete_all_UImage(self):
""" """
deletes all UImage from DB deletes all UImage from DB
""" """
(self._cur.execute('BEGIN IMMEDIATE') for key in self._UImage_grp.keys():
.execute('DELETE FROM UImage') del self._UImage_grp[key]
.execute('COMMIT') ) for key in self._UImage_image_data_grp.keys():
del self._UImage_image_data_grp[key]
###################
### UPerson
###################
def add_UPerson(self, uperson_or_list : UPerson, update_existing=True):
"""
add or update UPerson in DB
"""
if not isinstance(uperson_or_list, Iterable):
uperson_or_list : List[UPerson] = [uperson_or_list]
for uperson in uperson_or_list:
self._group_write_bytes(self._UPerson_grp, uperson.get_uuid().hex(), pickle.dumps(uperson.dump_state()), update_existing=update_existing )
def get_UPerson_count(self) -> int:
return len(self._UPerson_grp.keys())
def get_all_UPerson(self) -> List[UPerson]:
return [ UPerson.from_state(pickle.loads(self._group_read_bytes(self._UPerson_grp, key, check_key=False))) for key in self._UPerson_grp.keys() ]
def get_all_UPerson_uuids(self) -> List[bytes]:
return [ uuid.UUID(key).bytes for key in self._UPerson_grp.keys() ]
def get_UPerson_by_uuid(self, uuid : bytes) -> Union[UPerson, None]:
data = self._group_read_bytes(self._UPerson_grp, uuid.hex())
if data is None:
return None
return UPerson.from_state(pickle.loads(data))
def delete_UPerson_by_uuid(self, uuid : bytes) -> bool:
key = uuid.hex()
if key in self._UPerson_grp:
del self._UPerson_grp[key]
return True
return False
def iter_UPerson(self) -> Generator[UPerson, None, None]:
"""
returns Generator of UPerson
"""
for key in self._UPerson_grp.keys():
yield UPerson.from_state(pickle.loads(self._group_read_bytes(self._UPerson_grp, key, check_key=False)))
def delete_all_UPerson(self):
"""
deletes all UPerson from DB
"""
for key in self._UPerson_grp.keys():
del self._UPerson_grp[key]

View file

@ -25,7 +25,13 @@ class UFaceMark(IState):
def __repr__(self): return self.__str__() def __repr__(self): return self.__str__()
def __str__(self): def __str__(self):
return f"UFaceMark UUID:[...{self.get_uuid()[-4:].hex()}]" return f"UFaceMark UUID:[...{self.get_uuid()[-4:].hex()}]"
@staticmethod
def from_state(state : dict) -> 'UFaceMark':
ufm = UFaceMark()
ufm.restore_state(state)
return ufm
def restore_state(self, state : dict): def restore_state(self, state : dict):
self._uuid = state.get('_uuid', None) self._uuid = state.get('_uuid', None)
self._UImage_uuid = state.get('_UImage_uuid', None) self._UImage_uuid = state.get('_UImage_uuid', None)
@ -45,7 +51,7 @@ class UFaceMark(IState):
def get_uuid(self) -> Union[bytes, None]: def get_uuid(self) -> Union[bytes, None]:
if self._uuid is None: if self._uuid is None:
self._uuid = uuid.uuid4().bytes_le self._uuid = uuid.uuid4().bytes
return self._uuid return self._uuid
def set_uuid(self, uuid : Union[bytes, None]): def set_uuid(self, uuid : Union[bytes, None]):
@ -72,6 +78,16 @@ class UFaceMark(IState):
self._FRect = face_urect self._FRect = face_urect
def get_all_FLandmarks2D(self) -> List[FLandmarks2D]: return self._FLandmarks2D_list def get_all_FLandmarks2D(self) -> List[FLandmarks2D]: return self._FLandmarks2D_list
def get_FLandmarks2D_best(self) -> Union[FLandmarks2D, None]:
"""get best available FLandmarks2D """
lmrks = self.get_FLandmarks2D_by_type(ELandmarks2D.L468)
if lmrks is None:
lmrks = self.get_FLandmarks2D_by_type(ELandmarks2D.L68)
if lmrks is None:
lmrks = self.get_FLandmarks2D_by_type(ELandmarks2D.L5)
return lmrks
def get_FLandmarks2D_by_type(self, type : ELandmarks2D) -> Union[FLandmarks2D, None]: def get_FLandmarks2D_by_type(self, type : ELandmarks2D) -> Union[FLandmarks2D, None]:
"""get FLandmarks2D from list by type""" """get FLandmarks2D from list by type"""
if not isinstance(type, ELandmarks2D): if not isinstance(type, ELandmarks2D):

View file

@ -18,10 +18,31 @@ class UImage(IState):
def __str__(self): return f"UImage UUID:[...{self.get_uuid()[-4:].hex()}] name:[{self._name}] image:[{ (self._image.shape, self._image.dtype) if self._image is not None else None}]" def __str__(self): return f"UImage UUID:[...{self.get_uuid()[-4:].hex()}] name:[{self._name}] image:[{ (self._image.shape, self._image.dtype) if self._image is not None else None}]"
def __repr__(self): return self.__str__() def __repr__(self): return self.__str__()
@staticmethod
def from_state(state : dict) -> 'UImage':
ufm = UImage()
ufm.restore_state(state)
return ufm
def restore_state(self, state : dict):
self._uuid = state.get('_uuid', None)
self._name = state.get('_name', None)
self._image = state.get('_image', None)
def dump_state(self, exclude_image=False) -> dict:
d = {'_uuid' : self._uuid,
'_name' : self._name,
}
if not exclude_image:
d['_image'] = self._image
return d
def get_uuid(self) -> Union[bytes, None]: def get_uuid(self) -> Union[bytes, None]:
if self._uuid is None: if self._uuid is None:
self._uuid = uuid.uuid4().bytes_le self._uuid = uuid.uuid4().bytes
return self._uuid return self._uuid
def set_uuid(self, uuid : Union[bytes, None]): def set_uuid(self, uuid : Union[bytes, None]):

View file

@ -15,6 +15,12 @@ class UPerson(IState):
def __str__(self): return f"UPerson UUID:[...{self._uuid[-4:].hex()}] name:[{self._name}] age:[{self._age}]" def __str__(self): return f"UPerson UUID:[...{self._uuid[-4:].hex()}] name:[{self._name}] age:[{self._age}]"
def __repr__(self): return self.__str__() def __repr__(self): return self.__str__()
@staticmethod
def from_state(state : dict) -> 'UPerson':
ufm = UPerson()
ufm.restore_state(state)
return ufm
def restore_state(self, state : dict): def restore_state(self, state : dict):
self._uuid = state.get('_uuid', None) self._uuid = state.get('_uuid', None)
self._name = state.get('_name', None) self._name = state.get('_name', None)
@ -28,7 +34,7 @@ class UPerson(IState):
def get_uuid(self) -> Union[bytes, None]: def get_uuid(self) -> Union[bytes, None]:
if self._uuid is None: if self._uuid is None:
self._uuid = uuid.uuid4().bytes_le self._uuid = uuid.uuid4().bytes
return self._uuid return self._uuid
def set_uuid(self, uuid : Union[bytes, None]): def set_uuid(self, uuid : Union[bytes, None]):

View file

@ -2,12 +2,14 @@ import cv2
import numpy as np import numpy as np
import numpy.linalg as npla import numpy.linalg as npla
def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0): def sot(src,trg, mask=None, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0, return_diff=False):
""" """
Color Transform via Sliced Optimal Transfer, ported from https://github.com/dcoeurjo/OTColorTransfer Color Transform via Sliced Optimal Transfer, ported from https://github.com/dcoeurjo/OTColorTransfer
src - any float range any channel image src - any float range any channel image
dst - any float range any channel image, same shape as src dst - any float range any channel image, same shape as src
mask -
steps - number of solver steps steps - number of solver steps
batch_size - solver batch size batch_size - solver batch size
reg_sigmaXY - apply regularization and sigmaXY of filter, otherwise set to 0.0 reg_sigmaXY - apply regularization and sigmaXY of filter, otherwise set to 0.0
@ -39,8 +41,8 @@ def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0):
dir = np.random.normal(size=c).astype(src_dtype) dir = np.random.normal(size=c).astype(src_dtype)
dir /= npla.norm(dir) dir /= npla.norm(dir)
projsource = np.sum( new_src*dir, axis=-1).reshape ((h*w)) projsource = np.sum( new_src*dir*mask, axis=-1).reshape ((h*w))
projtarget = np.sum( trg*dir, axis=-1).reshape ((h*w)) projtarget = np.sum( trg*dir*mask, axis=-1).reshape ((h*w))
idSource = np.argsort (projsource) idSource = np.argsort (projsource)
idTarget = np.argsort (projtarget) idTarget = np.argsort (projtarget)
@ -48,12 +50,18 @@ def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0):
a = projtarget[idTarget]-projsource[idSource] a = projtarget[idTarget]-projsource[idSource]
for i_c in range(c): for i_c in range(c):
advect[idSource,i_c] += a * dir[i_c] advect[idSource,i_c] += a * dir[i_c]
new_src += advect.reshape( (h,w,c) ) / batch_size
new_src += (advect.reshape( (h,w,c) ) * mask) / batch_size
if reg_sigmaXY != 0.0: if reg_sigmaXY != 0.0:
src_diff = new_src-src src_diff = new_src-src
src_diff_filt = cv2.bilateralFilter (src_diff, 0, reg_sigmaV, reg_sigmaXY ) src_diff_filt = cv2.bilateralFilter (src_diff, 0, reg_sigmaV, reg_sigmaXY )
if len(src_diff_filt.shape) == 2: if len(src_diff_filt.shape) == 2:
src_diff_filt = src_diff_filt[...,None] src_diff_filt = src_diff_filt[...,None]
new_src = src + src_diff_filt if return_diff:
return new_src return src_diff_filt
return src + src_diff_filt
else:
if return_diff:
return new_src-src
return new_src

View file

@ -37,7 +37,7 @@ class MPSPSCMRRingData:
# Initialize first block at 0 index # Initialize first block at 0 index
wid = 0 wid = 0
wid_uuid = uuid.uuid4().bytes_le wid_uuid = uuid.uuid4().bytes
wid_heap_offset = 0 wid_heap_offset = 0
wid_data_size = 0 wid_data_size = 0
@ -82,7 +82,7 @@ class MPSPSCMRRingData:
raise Exception('data_size more than heap_size') raise Exception('data_size more than heap_size')
fmv = FormattedMemoryViewIO(self._shared_mem.get_mv()) fmv = FormattedMemoryViewIO(self._shared_mem.get_mv())
wid_uuid = uuid.uuid4().bytes_le wid_uuid = uuid.uuid4().bytes
if self._write_lock is not None: if self._write_lock is not None:
self._write_lock.acquire() self._write_lock.acquire()

View file

@ -49,7 +49,7 @@ class MPWeakHeap:
# Entire block # Entire block
fmv.seek(self._first_block_offset) fmv.seek(self._first_block_offset)
fmv.write_fmt('qq', self._heap_size-self._first_block_offset, 0), fmv.write(uuid.uuid4().bytes_le) fmv.write_fmt('qq', self._heap_size-self._first_block_offset, 0), fmv.write(uuid.uuid4().bytes)
def add_data(self, data : Union[bytes, bytearray, memoryview] ) -> 'MPWeakHeap.DataRef': def add_data(self, data : Union[bytes, bytearray, memoryview] ) -> 'MPWeakHeap.DataRef':
@ -92,7 +92,7 @@ class MPWeakHeap:
if block_remain_size >= block_header_size: if block_remain_size >= block_header_size:
# the remain space of the block is enough for next block, split the block # the remain space of the block is enough for next block, split the block
next_block_offset = cur_block_offset + block_new_size next_block_offset = cur_block_offset + block_new_size
fmv.seek(next_block_offset), fmv.write_fmt('qq', block_remain_size, 0), fmv.write(uuid.uuid4().bytes_le) fmv.seek(next_block_offset), fmv.write_fmt('qq', block_remain_size, 0), fmv.write(uuid.uuid4().bytes)
else: else:
# otherwise do not split # otherwise do not split
next_block_offset = cur_block_offset + block_size next_block_offset = cur_block_offset + block_size
@ -101,7 +101,7 @@ class MPWeakHeap:
block_new_size = block_size block_new_size = block_size
# update current block structure # update current block structure
uid = uuid.uuid4().bytes_le uid = uuid.uuid4().bytes
fmv.seek(cur_block_offset), fmv.write_fmt('qq', block_new_size, data_size ), fmv.write(uid) fmv.seek(cur_block_offset), fmv.write_fmt('qq', block_new_size, data_size ), fmv.write(uid)
# update ring_head_block_offset # update ring_head_block_offset
@ -135,7 +135,7 @@ class MPWeakHeap:
next_block_size, = fmv.get_fmt('q') next_block_size, = fmv.get_fmt('q')
# erase data of next block # erase data of next block
fmv.write_fmt('qq', 0, 0), fmv.write(uuid.uuid4().bytes_le) fmv.write_fmt('qq', 0, 0), fmv.write(uuid.uuid4().bytes)
# overwrite current block size with expanded block size # overwrite current block size with expanded block size
fmv.seek(cur_block_offset) fmv.seek(cur_block_offset)

View file

@ -11,7 +11,6 @@ def _host_thread_proc(wref):
break break
ref._host_process_messages(0.005) ref._host_process_messages(0.005)
del ref del ref
print('_host_thread_proc exit')
class SPMTWorker: class SPMTWorker:
def __init__(self, *sub_args, **sub_kwargs): def __init__(self, *sub_args, **sub_kwargs):
@ -152,7 +151,7 @@ class SPMTWorker:
break break
self._threads_running = False self._threads_running = False
self._threads_exit_barrier.wait() self._threads_exit_barrier.wait()
self._on_sub_finalize() self._on_sub_finalize()

View file

@ -1,3 +1,3 @@
from .device import (ORTDeviceInfo, get_available_devices_info, from .device import (ORTDeviceInfo, get_available_devices_info,
get_cpu_device) get_cpu_device_info)
from .InferenceSession import InferenceSession_with_device from .InferenceSession import InferenceSession_with_device

View file

@ -67,33 +67,34 @@ class ORTDeviceInfo:
_ort_devices_info = None _ort_devices_info = None
def get_cpu_device() -> ORTDeviceInfo: def get_cpu_device_info() -> ORTDeviceInfo:
return ORTDeviceInfo(index=-1, execution_provider='CPUExecutionProvider', name='CPU', total_memory=0, free_memory=0) return ORTDeviceInfo(index=-1, execution_provider='CPUExecutionProvider', name='CPU', total_memory=0, free_memory=0)
def get_available_devices_info(include_cpu=True, cpu_only=False) -> List[ORTDeviceInfo]: def get_available_devices_info(include_cpu=True, cpu_only=False) -> List[ORTDeviceInfo]:
""" """
returns a list of available ORTDeviceInfo returns a list of available ORTDeviceInfo
""" """
global _ort_devices_info devices = []
if _ort_devices_info is None: if not cpu_only:
_initialize_ort_devices() global _ort_devices_info
devices = [] if _ort_devices_info is None:
if not cpu_only: _initialize_ort_devices_info()
_ort_devices_info = []
for i in range ( int(os.environ.get('ORT_DEVICES_COUNT',0)) ): for i in range ( int(os.environ.get('ORT_DEVICES_COUNT',0)) ):
devices.append ( ORTDeviceInfo(index=int(os.environ[f'ORT_DEVICE_{i}_INDEX']), _ort_devices_info.append ( ORTDeviceInfo(index=int(os.environ[f'ORT_DEVICE_{i}_INDEX']),
execution_provider=os.environ[f'ORT_DEVICE_{i}_EP'], execution_provider=os.environ[f'ORT_DEVICE_{i}_EP'],
name=os.environ[f'ORT_DEVICE_{i}_NAME'], name=os.environ[f'ORT_DEVICE_{i}_NAME'],
total_memory=int(os.environ[f'ORT_DEVICE_{i}_TOTAL_MEM']), total_memory=int(os.environ[f'ORT_DEVICE_{i}_TOTAL_MEM']),
free_memory=int(os.environ[f'ORT_DEVICE_{i}_FREE_MEM']), free_memory=int(os.environ[f'ORT_DEVICE_{i}_FREE_MEM']),
) ) ) )
if include_cpu or cpu_only: devices += _ort_devices_info
devices.append(get_cpu_device()) if include_cpu:
_ort_devices_info = devices devices.append(get_cpu_device_info())
return _ort_devices_info return devices
def _initialize_ort_devices(): def _initialize_ort_devices_info():
""" """
Determine available ORT devices, and place info about them to os.environ, Determine available ORT devices, and place info about them to os.environ,
they will be available in spawned subprocesses. they will be available in spawned subprocesses.
@ -184,4 +185,4 @@ def _initialize_ort_devices():
os.environ[f'ORT_DEVICE_{i}_TOTAL_MEM'] = str(device['total_mem']) os.environ[f'ORT_DEVICE_{i}_TOTAL_MEM'] = str(device['total_mem'])
os.environ[f'ORT_DEVICE_{i}_FREE_MEM'] = str(device['free_mem']) os.environ[f'ORT_DEVICE_{i}_FREE_MEM'] = str(device['free_mem'])
_initialize_ort_devices() _initialize_ort_devices_info()

View file

@ -247,7 +247,8 @@ def ascii_table(table_def : List[str],
if row_line is not None: if row_line is not None:
lines.append(row_line) lines.append(row_line)
for sub_rows in rows: for sub_rows in rows:
for row in sub_rows: for row in sub_rows:
line = '' line = ''
@ -288,7 +289,8 @@ def ascii_table(table_def : List[str],
line += right_border line += right_border
lines.append(line) lines.append(line)
if row_line is not None:
if len(sub_rows) != 0 and row_line is not None:
lines.append(row_line) lines.append(row_line)
return '\n'.join(lines) return '\n'.join(lines)

View file

@ -1,2 +1,2 @@
from .device import (TorchDeviceInfo, TorchDevicesInfo, get_available_devices, from .device import (TorchDeviceInfo, get_available_devices_info,
get_cpu_device) get_cpu_device_info, get_device, get_device_info_by_index)

View file

@ -1,4 +1,6 @@
from typing import List from typing import List, Union
import torch
class TorchDeviceInfo: class TorchDeviceInfo:
@ -45,96 +47,40 @@ class TorchDeviceInfo:
def __repr__(self): def __repr__(self):
return f'{self.__class__.__name__} object: ' + self.__str__() return f'{self.__class__.__name__} object: ' + self.__str__()
# class TorchDevicesInfo:
# """
# picklable list of TorchDeviceInfo
# """
# def __init__(self, devices : List[TorchDeviceInfo] = None):
# if devices is None:
# devices = []
# self._devices = devices
# def __getstate__(self):
# return self.__dict__.copy()
# def __setstate__(self, d):
# self.__init__()
# self.__dict__.update(d)
# def add(self, device_or_devices : TorchDeviceInfo):
# if isinstance(device_or_devices, TorchDeviceInfo):
# if device_or_devices not in self._devices:
# self._devices.append(device_or_devices)
# elif isinstance(device_or_devices, TorchDevicesInfo):
# for device in device_or_devices:
# self.add(device)
# def copy(self):
# return copy.deepcopy(self)
# def get_count(self): return len(self._devices)
# def get_largest_total_memory_device(self) -> TorchDeviceInfo:
# raise NotImplementedError()
# result = None
# idx_mem = 0
# for device in self._devices:
# mem = device.get_total_memory()
# if result is None or (mem is not None and mem > idx_mem):
# result = device
# idx_mem = mem
# return result
# def get_smallest_total_memory_device(self) -> TorchDeviceInfo:
# raise NotImplementedError()
# result = None
# idx_mem = sys.maxsize
# for device in self._devices:
# mem = device.get_total_memory()
# if result is None or (mem is not None and mem < idx_mem):
# result = device
# idx_mem = mem
# return result
# def __len__(self):
# return len(self._devices)
# def __getitem__(self, key):
# result = self._devices[key]
# if isinstance(key, slice):
# return self.__class__(result)
# return result
# def __iter__(self):
# for device in self._devices:
# yield device
# def __str__(self): return f'{self.__class__.__name__}:[' + ', '.join([ device.__str__() for device in self._devices ]) + ']'
# def __repr__(self): return f'{self.__class__.__name__}:[' + ', '.join([ device.__repr__() for device in self._devices ]) + ']'
_torch_devices = None _torch_devices = None
def get_cpu_device() -> TorchDeviceInfo: def get_cpu_device_info() -> TorchDeviceInfo:
return TorchDeviceInfo(index=-1, name='CPU', total_memory=0) return TorchDeviceInfo(index=-1, name='CPU', total_memory=0)
def get_available_devices(include_cpu=True, cpu_only=False) -> List[TorchDeviceInfo]: def get_device_info_by_index(index) -> Union[TorchDeviceInfo, None]:
for device in get_available_devices_info(include_cpu=False):
if device.get_index() == index:
return device
return None
def get_device(device_info : TorchDeviceInfo) -> torch.device:
if device_info.is_cpu():
return torch.device('cpu')
return torch.device(f'cuda:{device_info.get_index()}')
def get_available_devices_info(include_cpu=True, cpu_only=False) -> List[TorchDeviceInfo]:
""" """
returns a list of available TorchDeviceInfo returns a list of available TorchDeviceInfo
""" """
global _torch_devices devices = []
if _torch_devices is None: if not cpu_only:
import torch global _torch_devices
devices = [] if _torch_devices is None:
if not cpu_only: _torch_devices = []
for i in range (torch.cuda.device_count()): for i in range (torch.cuda.device_count()):
device_props = torch.cuda.get_device_properties(i) device_props = torch.cuda.get_device_properties(i)
devices.append ( TorchDeviceInfo(index=i, name=device_props.name, total_memory=device_props.total_memory)) _torch_devices.append ( TorchDeviceInfo(index=i, name=device_props.name, total_memory=device_props.total_memory))
devices += _torch_devices
if include_cpu or cpu_only: if include_cpu:
devices.append ( get_cpu_device() ) devices.append ( get_cpu_device_info() )
_torch_devices = devices return devices
return _torch_devices