mirror of
https://github.com/iperov/DeepFaceLive
synced 2025-07-14 17:13:43 -07:00
refactoring
This commit is contained in:
parent
8489949f2c
commit
30ba51edf7
24 changed files with 663 additions and 459 deletions
8
main.py
8
main.py
|
@ -79,12 +79,12 @@ def main():
|
||||||
p.set_defaults(func=train_FaceAligner)
|
p.set_defaults(func=train_FaceAligner)
|
||||||
|
|
||||||
def train_CTSOT(args):
|
def train_CTSOT(args):
|
||||||
from apps.trainers.CTSOT.CTSOTTrainerApp import run_app
|
from apps.trainers.CTSOT.CTSOTTrainerApp import CTSOTTrainerApp
|
||||||
run_app(userdata_path=Path(args.userdata_dir), faceset_path=Path(args.faceset_path))
|
CTSOTTrainerApp(workspace_path=Path(args.workspace_dir), faceset_path=Path(args.faceset_path))
|
||||||
|
|
||||||
p = train_parsers.add_parser('CTSOT')
|
p = train_parsers.add_parser('CTSOT')
|
||||||
p.add_argument('--userdata-dir', default=None, action=fixPathAction, help="Directory to save app data.")
|
p.add_argument('--workspace-dir', default=None, action=fixPathAction, help="Workspace directory.")
|
||||||
p.add_argument('--faceset-path', default=None, action=fixPathAction, help=".dfs path")
|
p.add_argument('--faceset-path', default=None, action=fixPathAction, help=".dfs faceset path")
|
||||||
p.set_defaults(func=train_CTSOT)
|
p.set_defaults(func=train_CTSOT)
|
||||||
|
|
||||||
def bad_args(arguments):
|
def bad_args(arguments):
|
||||||
|
|
|
@ -58,7 +58,7 @@ def get_available_devices() -> List[ORTDeviceInfo]:
|
||||||
class DFMModel:
|
class DFMModel:
|
||||||
def __init__(self, model_path : Path, device : ORTDeviceInfo = None):
|
def __init__(self, model_path : Path, device : ORTDeviceInfo = None):
|
||||||
if device is None:
|
if device is None:
|
||||||
device = lib_ort.get_cpu_device()
|
device = lib_ort.get_cpu_device_info()
|
||||||
self._model_path = model_path
|
self._model_path = model_path
|
||||||
|
|
||||||
sess = self._sess = lib_ort.InferenceSession_with_device(str(model_path), device)
|
sess = self._sess = lib_ort.InferenceSession_with_device(str(model_path), device)
|
||||||
|
|
121
modelhub/torch/CTSOT/CTSOT.py
Normal file
121
modelhub/torch/CTSOT/CTSOT.py
Normal file
|
@ -0,0 +1,121 @@
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from xlib.file import SplittedFile
|
||||||
|
from xlib.torch import TorchDeviceInfo, get_cpu_device_info
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
|
||||||
|
def __init__(self, ch):
|
||||||
|
super().__init__()
|
||||||
|
self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
def forward(self, inp):
|
||||||
|
x = inp
|
||||||
|
x = F.leaky_relu(self.conv1(x), 0.2)
|
||||||
|
x = F.leaky_relu(self.conv2(x) + inp, 0.2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class AutoEncoder(nn.Module):
|
||||||
|
def __init__(self, resolution, in_ch, ae_ch):
|
||||||
|
super().__init__()
|
||||||
|
self._resolution = resolution
|
||||||
|
self._in_ch = in_ch
|
||||||
|
self._ae_ch = ae_ch
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(in_ch*1, in_ch*2, kernel_size=5, stride=2, padding=2)
|
||||||
|
self.conv2 = nn.Conv2d(in_ch*2, in_ch*4, kernel_size=5, stride=2, padding=2)
|
||||||
|
self.conv3 = nn.Conv2d(in_ch*4, in_ch*8, kernel_size=5, stride=2, padding=2)
|
||||||
|
self.conv4 = nn.Conv2d(in_ch*8, in_ch*8, kernel_size=5, stride=2, padding=2)
|
||||||
|
|
||||||
|
self.dense_in = nn.Linear( in_ch*8 * ( resolution // (2**4) )**2, ae_ch)
|
||||||
|
self.dense_out = nn.Linear( ae_ch, in_ch*8 * ( resolution // (2**4) )**2 )
|
||||||
|
|
||||||
|
self.up_conv4 = nn.ConvTranspose2d(in_ch*8, in_ch*8, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
|
||||||
|
self.up_conv3 = nn.ConvTranspose2d(in_ch*8, in_ch*4, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
|
||||||
|
self.up_conv2 = nn.ConvTranspose2d(in_ch*4, in_ch*2, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
|
||||||
|
self.up_conv1 = nn.ConvTranspose2d(in_ch*2, in_ch*1, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, inp):
|
||||||
|
x = inp
|
||||||
|
x = F.leaky_relu(self.conv1(x), 0.1)
|
||||||
|
x = F.leaky_relu(self.conv2(x), 0.1)
|
||||||
|
x = F.leaky_relu(self.conv3(x), 0.1)
|
||||||
|
x = F.leaky_relu(self.conv4(x), 0.1)
|
||||||
|
|
||||||
|
x = x.view( (x.shape[0], np.prod(x.shape[1:])) )
|
||||||
|
|
||||||
|
x = self.dense_in(x)
|
||||||
|
x = self.dense_out(x)
|
||||||
|
|
||||||
|
x = x.view( (x.shape[0], self._in_ch*8, self._resolution // (2**4), self._resolution // (2**4) ))
|
||||||
|
x = F.leaky_relu(self.up_conv4(x), 0.1)
|
||||||
|
x = F.leaky_relu(self.up_conv3(x), 0.1)
|
||||||
|
x = F.leaky_relu(self.up_conv2(x), 0.1)
|
||||||
|
x = F.leaky_relu(self.up_conv1(x), 0.1)
|
||||||
|
|
||||||
|
# from xlib.console import diacon
|
||||||
|
# diacon.Diacon.stop()
|
||||||
|
# import code
|
||||||
|
# code.interact(local=dict(globals(), **locals()))
|
||||||
|
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CTSOTNet(nn.Module):
|
||||||
|
def __init__(self, resolution, in_ch=6, inner_ch=64, ae_ch=256, out_ch=3, res_block_count = 12):
|
||||||
|
super().__init__()
|
||||||
|
self.in_conv = nn.Conv2d(in_ch, inner_ch, kernel_size=1, stride=1, padding=0)
|
||||||
|
self.ae = AutoEncoder(resolution, inner_ch, ae_ch)
|
||||||
|
self.out_conv = nn.Conv2d(inner_ch, out_ch, kernel_size=1, stride=1, padding=0)
|
||||||
|
|
||||||
|
def forward(self, img1_t, img2_t):
|
||||||
|
x = torch.cat([img1_t, img2_t], dim=1)
|
||||||
|
|
||||||
|
x = self.in_conv(x)
|
||||||
|
|
||||||
|
x = self.ae(x)
|
||||||
|
x = self.out_conv(x)
|
||||||
|
x = torch.tanh(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CTSOT:
|
||||||
|
def __init__(self, device_info : TorchDeviceInfo = None,
|
||||||
|
state_dict : Union[dict, None] = None,
|
||||||
|
training : bool = False):
|
||||||
|
if device_info is None:
|
||||||
|
device_info = get_cpu_device_info()
|
||||||
|
self.device_info = device_info
|
||||||
|
|
||||||
|
self._net = net = CTSOTNet()
|
||||||
|
|
||||||
|
if state_dict is not None:
|
||||||
|
net.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
if training:
|
||||||
|
net.train()
|
||||||
|
else:
|
||||||
|
net.eval()
|
||||||
|
|
||||||
|
self.set_device(device_info)
|
||||||
|
|
||||||
|
def set_device(self, device_info : TorchDeviceInfo = None):
|
||||||
|
if device_info is None or device_info.is_cpu():
|
||||||
|
self._net.cpu()
|
||||||
|
else:
|
||||||
|
self._net.cuda(device_info.get_index())
|
||||||
|
|
||||||
|
def get_state_dict(self):
|
||||||
|
return self.net.state_dict()
|
||||||
|
|
||||||
|
def get_net(self) -> CTSOTNet:
|
||||||
|
return self._net
|
|
@ -8,13 +8,13 @@ import torch.nn.functional as F
|
||||||
from xlib import math as lib_math
|
from xlib import math as lib_math
|
||||||
from xlib.file import SplittedFile
|
from xlib.file import SplittedFile
|
||||||
from xlib.image import ImageProcessor
|
from xlib.image import ImageProcessor
|
||||||
from xlib.torch import TorchDeviceInfo, get_cpu_device
|
from xlib.torch import TorchDeviceInfo, get_cpu_device_info
|
||||||
|
|
||||||
|
|
||||||
class S3FD:
|
class S3FD:
|
||||||
def __init__(self, device_info : TorchDeviceInfo = None ):
|
def __init__(self, device_info : TorchDeviceInfo = None ):
|
||||||
if device_info is None:
|
if device_info is None:
|
||||||
device_info = get_cpu_device()
|
device_info = get_cpu_device_info()
|
||||||
self.device_info = device_info
|
self.device_info = device_info
|
||||||
|
|
||||||
path = Path(__file__).parent / 'S3FD.pth'
|
path = Path(__file__).parent / 'S3FD.pth'
|
||||||
|
|
|
@ -1,2 +1,3 @@
|
||||||
from .CenterFace.CenterFace import CenterFace, CenterFace_to_onnx
|
from .CenterFace.CenterFace import CenterFace_to_onnx
|
||||||
from .S3FD.S3FD import S3FD
|
from .S3FD.S3FD import S3FD
|
||||||
|
from .CTSOT.CTSOT import CTSOT, CTSOTNet
|
|
@ -53,14 +53,10 @@ def extract_FaceSynthetics(inputdir_path : Path, faceset_path : Path):
|
||||||
"""
|
"""
|
||||||
if faceset_path.suffix != '.dfs':
|
if faceset_path.suffix != '.dfs':
|
||||||
raise ValueError('faceset_path must have .dfs extension.')
|
raise ValueError('faceset_path must have .dfs extension.')
|
||||||
|
|
||||||
filepaths = lib_path.get_files_paths(inputdir_path)
|
filepaths = lib_path.get_files_paths(inputdir_path)
|
||||||
|
fs = lib_face.Faceset(faceset_path, write_access=True, recreate=True)
|
||||||
fs = lib_face.Faceset(faceset_path)
|
for filepath in lib_con.progress_bar_iterator(filepaths, desc='Processing'):
|
||||||
fs.recreate()
|
|
||||||
|
|
||||||
|
|
||||||
for filepath in lib_con.progress_bar_iterator(filepaths, 'Processing'):
|
|
||||||
if filepath.suffix == '.txt':
|
if filepath.suffix == '.txt':
|
||||||
|
|
||||||
image_filepath = filepath.parent / f'{filepath.name.split("_")[0]}.png'
|
image_filepath = filepath.parent / f'{filepath.name.split("_")[0]}.png'
|
||||||
|
@ -93,17 +89,13 @@ def extract_FaceSynthetics(inputdir_path : Path, faceset_path : Path):
|
||||||
ufm.set_UImage_uuid(uimg.get_uuid())
|
ufm.set_UImage_uuid(uimg.get_uuid())
|
||||||
ufm.set_FRect(flmrks.get_FRect())
|
ufm.set_FRect(flmrks.get_FRect())
|
||||||
ufm.add_FLandmarks2D(flmrks)
|
ufm.add_FLandmarks2D(flmrks)
|
||||||
|
|
||||||
|
fs.add_UFaceMark(ufm)
|
||||||
fs.add_UImage(uimg, format='png')
|
fs.add_UImage(uimg, format='png')
|
||||||
fs.add_UFaceMark(ufm)
|
|
||||||
|
fs.optimize()
|
||||||
|
|
||||||
fs.shrink()
|
|
||||||
fs.close()
|
fs.close()
|
||||||
|
|
||||||
import code
|
|
||||||
code.interact(local=dict(globals(), **locals()))
|
|
||||||
|
|
||||||
# seg_filepath = input_path / ( Path(image_filepath).stem + '_seg.png')
|
# seg_filepath = input_path / ( Path(image_filepath).stem + '_seg.png')
|
||||||
# if not seg_filepath.exists():
|
# if not seg_filepath.exists():
|
||||||
# raise ValueError(f'{seg_filepath} does not exist')
|
# raise ValueError(f'{seg_filepath} does not exist')
|
||||||
|
|
|
@ -15,7 +15,7 @@ def gaussian_blur (input_t : Tensor, sigma, dtype=None) -> Tensor:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if sigma <= 0.0:
|
if sigma <= 0.0:
|
||||||
return input_t.copy() #TODO
|
return input_t.copy()
|
||||||
|
|
||||||
device = input_t.get_device()
|
device = input_t.get_device()
|
||||||
|
|
||||||
|
|
|
@ -3,113 +3,220 @@ import threading
|
||||||
import time
|
import time
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import Any, Callable, List, Tuple, Union
|
from typing import Any, Callable, List, Tuple, Union
|
||||||
|
from numbers import Number
|
||||||
from ... import text as lib_text
|
from ... import text as lib_text
|
||||||
|
|
||||||
class EDlgMode(IntEnum):
|
class EDlgMode(IntEnum):
|
||||||
UNDEFINED = 0
|
UNHANDLED = 0
|
||||||
BACK = 1
|
BACK = 1
|
||||||
RELOAD = 2
|
RELOAD = 2
|
||||||
WRONG_INPUT = 3
|
WRONG_INPUT = 3
|
||||||
SUCCESS = 4
|
HANDLED = 4
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class DlgChoice:
|
class DlgChoice:
|
||||||
def __init__(self, name : str = None, row_desc : str = None):
|
def __init__(self, name : str = None,
|
||||||
super().__init__()
|
row_def : str = None,
|
||||||
|
on_choose : Callable = None):
|
||||||
if len(name) == 0:
|
if len(name) == 0:
|
||||||
raise ValueError('Zero len name is not valid.')
|
raise ValueError('Zero len name is not valid.')
|
||||||
self._name = name
|
self._name = name
|
||||||
self._row_desc = row_desc
|
self._row_def = row_def
|
||||||
|
self._on_choose = on_choose
|
||||||
|
|
||||||
def get_name(self) -> Union[str, None]: return self._name
|
def get_name(self) -> Union[str, None]: return self._name
|
||||||
def get_row_desc(self) -> Union[str, None]: return self._row_desc
|
def get_row_def(self) -> Union[str, None]: return self._row_def
|
||||||
|
def get_on_choose(self) -> Callable: return self._on_choose
|
||||||
|
|
||||||
class Dlg:
|
class Dlg:
|
||||||
def __init__(self, title : str = None, has_go_back=True):
|
def __init__(self, on_recreate : Callable[ [], 'Dlg'] = None,
|
||||||
|
on_back : Callable = None,
|
||||||
|
top_rows_def : Union[str, List[str]] = None,
|
||||||
|
bottom_rows_def : Union[str, List[str]] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
|
base class for Diacon dialogs.
|
||||||
"""
|
"""
|
||||||
self._title = title
|
self._on_recreate = on_recreate
|
||||||
self._has_go_back = has_go_back
|
self._on_back = on_back
|
||||||
|
|
||||||
|
self._top_rows_def = top_rows_def
|
||||||
|
self._bottom_rows_def = bottom_rows_def
|
||||||
|
|
||||||
|
def recreate(self):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
if self._on_recreate is not None:
|
||||||
|
return self._on_recreate(self)
|
||||||
|
else:
|
||||||
|
raise Exception('on_recreate() is not defined.')
|
||||||
|
|
||||||
def get_name(self) -> str: return self._name
|
def get_name(self) -> str: return self._name
|
||||||
|
|
||||||
def handle_user_input(self, s : str) -> EDlgMode:
|
def set_current(self, print=True):
|
||||||
|
Diacon.update_dlg(self, print=print)
|
||||||
|
|
||||||
|
def handle_user_input(self, s : str):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
"""
|
"""
|
||||||
s = s.strip()
|
mode = self.on_user_input(s.strip())
|
||||||
|
|
||||||
# ? and < available in any dialog, handle them first
|
if mode == EDlgMode.UNHANDLED:
|
||||||
s_len = len(s)
|
mode = EDlgMode.RELOAD
|
||||||
if s_len == 0:
|
if mode == EDlgMode.WRONG_INPUT:
|
||||||
|
print('\nWrong input')
|
||||||
|
mode = EDlgMode.RELOAD
|
||||||
|
|
||||||
|
if mode == EDlgMode.RELOAD:
|
||||||
|
self.recreate().set_current()
|
||||||
|
if mode == EDlgMode.BACK:
|
||||||
|
if self._on_back is not None:
|
||||||
|
self._on_back(self)
|
||||||
|
|
||||||
|
#overridable
|
||||||
|
def on_user_input(self, s : str) -> EDlgMode:
|
||||||
|
if len(s) == 0:
|
||||||
return EDlgMode.RELOAD
|
return EDlgMode.RELOAD
|
||||||
if s_len == 1:
|
if self._on_back is not None and len(s) == 1:
|
||||||
#if s == '?':
|
|
||||||
# return EDlgMode.RELOAD
|
|
||||||
if s == '<':
|
if s == '<':
|
||||||
return EDlgMode.BACK
|
return EDlgMode.BACK
|
||||||
|
return EDlgMode.UNHANDLED
|
||||||
return self.on_user_input(s)
|
|
||||||
|
|
||||||
def print(self, table_width_max=80, col_spacing = 3):
|
def print(self, table_width_max=80, col_spacing = 3):
|
||||||
"""
|
"""
|
||||||
print dialog
|
print dialog
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Gather table lines
|
|
||||||
table_def : List[str]= []
|
table_def : List[str]= []
|
||||||
|
|
||||||
if self._has_go_back:
|
trd = self._top_rows_def
|
||||||
|
brd = self._bottom_rows_def
|
||||||
|
if trd is not None:
|
||||||
|
if not isinstance(trd, (list,tuple)):
|
||||||
|
trd = [trd]
|
||||||
|
table_def += trd
|
||||||
|
|
||||||
|
if self._on_back is not None:
|
||||||
table_def.append('| < | Go back.')
|
table_def.append('| < | Go back.')
|
||||||
|
|
||||||
table_def.append('|99')
|
table_def.append('|99')
|
||||||
table_def = self.on_print(table_def)
|
table_def = self.on_print(table_def)
|
||||||
|
|
||||||
table = lib_text.ascii_table(table_def, max_table_width=80,
|
if brd is not None:
|
||||||
left_border = None,
|
if not isinstance(brd, (list,tuple)):
|
||||||
right_border = None,
|
brd = [brd]
|
||||||
|
table_def += brd
|
||||||
|
|
||||||
|
table = lib_text.ascii_table(table_def, max_table_width=80,
|
||||||
|
left_border = '| ',
|
||||||
|
right_border = ' |',
|
||||||
border = ' | ',
|
border = ' | ',
|
||||||
row_symbol = None)
|
row_symbol = None,
|
||||||
|
)
|
||||||
print()
|
print()
|
||||||
print(table)
|
print(table)
|
||||||
|
|
||||||
|
|
||||||
#overridable
|
#overridable
|
||||||
def on_print(self, table_lines : List[Tuple[str,str]]):
|
def on_print(self, table_lines : List[Tuple[str,str]]):
|
||||||
return table_lines
|
return table_lines
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class DlgNumber(Dlg):
|
||||||
|
def __init__(self, is_float : bool,
|
||||||
|
current_value = None,
|
||||||
|
min_value = None,
|
||||||
|
max_value = None,
|
||||||
|
clip_min_value = None,
|
||||||
|
clip_max_value = None,
|
||||||
|
on_value : Callable[ [Dlg, Number], None] = None,
|
||||||
|
on_recreate : Callable[ [], 'Dlg'] = None,
|
||||||
|
on_back : Callable = None,
|
||||||
|
top_rows_def : Union[str, List[str]] = None,
|
||||||
|
bottom_rows_def : Union[str, List[str]] = None, ):
|
||||||
|
super().__init__(on_recreate=on_recreate, on_back=on_back, top_rows_def=top_rows_def, bottom_rows_def=bottom_rows_def)
|
||||||
|
|
||||||
|
if min_value is not None and max_value is not None and min_value > max_value:
|
||||||
|
raise ValueError('min_value > max_value')
|
||||||
|
if clip_min_value is not None and clip_max_value is not None and clip_min_value > clip_max_value:
|
||||||
|
raise ValueError('clip_min_value > clip_max_value')
|
||||||
|
|
||||||
|
self._is_float = is_float
|
||||||
|
self._current_value = current_value
|
||||||
|
self._min_value = min_value
|
||||||
|
self._max_value = max_value
|
||||||
|
self._clip_min_value = clip_min_value
|
||||||
|
self._clip_max_value = clip_max_value
|
||||||
|
self._on_value = on_value
|
||||||
|
|
||||||
#overridable
|
#overridable
|
||||||
def on_user_input(self, s : str) -> EDlgMode:
|
def on_print(self, table_def : List[str]):
|
||||||
"""
|
|
||||||
handle user input
|
minv, maxv = self._min_value, self._max_value
|
||||||
return False if input is invalid
|
|
||||||
"""
|
if self._is_float:
|
||||||
return EDlgMode.UNDEFINED
|
line = '| * | Enter float number'
|
||||||
|
else:
|
||||||
|
line = '| * | Enter integer number'
|
||||||
|
|
||||||
|
if minv is not None and maxv is None:
|
||||||
|
line += f' in range: [{minv} ... )'
|
||||||
|
elif minv is None and maxv is not None:
|
||||||
|
line += f' in range: ( ... {maxv} ]'
|
||||||
|
elif minv is not None and maxv is not None:
|
||||||
|
line += f' in range: [{minv} ... {maxv} ]'
|
||||||
|
|
||||||
|
table_def.append(line)
|
||||||
|
|
||||||
|
return table_def
|
||||||
|
|
||||||
|
#overridable
|
||||||
|
def on_user_input(self, s : str) -> bool:
|
||||||
|
result = super().on_user_input(s)
|
||||||
|
if result == EDlgMode.UNHANDLED:
|
||||||
|
try:
|
||||||
|
print(s)
|
||||||
|
v = float(s) if self._is_float else int(s)
|
||||||
|
|
||||||
|
if self._min_value is not None:
|
||||||
|
if v < self._min_value:
|
||||||
|
return EDlgMode.WRONG_INPUT
|
||||||
|
|
||||||
|
if self._max_value is not None:
|
||||||
|
if v > self._max_value:
|
||||||
|
return EDlgMode.WRONG_INPUT
|
||||||
|
|
||||||
|
if self._clip_min_value is not None:
|
||||||
|
if v < self._clip_min_value:
|
||||||
|
v = self._clip_min_value
|
||||||
|
|
||||||
|
if self._clip_max_value is not None:
|
||||||
|
if v > self._clip_max_value:
|
||||||
|
v = self._clip_max_value
|
||||||
|
|
||||||
|
if self._on_value is not None:
|
||||||
|
self._on_value(self, v)
|
||||||
|
return EDlgMode.HANDLED
|
||||||
|
except:
|
||||||
|
return EDlgMode.WRONG_INPUT
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
class DlgChoices(Dlg):
|
class DlgChoices(Dlg):
|
||||||
def __init__(self, choices : List[DlgChoice], multiple_choices=False, title : str = None, has_go_back = True):
|
def __init__(self, choices : List[DlgChoice],
|
||||||
"""
|
on_multi_choice : Callable[ [ List[DlgChoice] ], None] = None,
|
||||||
|
on_recreate : Callable[ [Dlg], Dlg] = None,
|
||||||
"""
|
on_back : Callable = None,
|
||||||
super().__init__(title=title, has_go_back=has_go_back)
|
top_rows_def : Union[str, List[str]] = None,
|
||||||
|
bottom_rows_def : Union[str, List[str]] = None,
|
||||||
|
):
|
||||||
|
super().__init__(on_recreate=on_recreate, on_back=on_back, top_rows_def=top_rows_def, bottom_rows_def=bottom_rows_def)
|
||||||
self._choices = choices
|
self._choices = choices
|
||||||
self._multiple_choices = multiple_choices
|
self._on_multi_choice = on_multi_choice
|
||||||
|
|
||||||
self._results = None
|
|
||||||
self._results_id = None
|
|
||||||
|
|
||||||
self._short_names = [choice.get_name() for choice in choices]
|
self._short_names = [choice.get_name() for choice in choices]
|
||||||
|
|
||||||
# if any([x is not None for x in self._short_names]):
|
|
||||||
# # Using short names from choices
|
|
||||||
# if any([x is None for x in self._short_names]):
|
|
||||||
# raise Exception('No short name for one of choices.')
|
|
||||||
# if len(set(self._short_names)) != len(self._short_names):
|
|
||||||
# raise ValueError(f'Contains duplicate short names: {self._short_names}')
|
|
||||||
# else:
|
|
||||||
|
|
||||||
# Make short names for all choices
|
# Make short names for all choices
|
||||||
names = [ choice.get_name() for choice in choices ]
|
names = [ choice.get_name() for choice in choices ]
|
||||||
names_len = len(names)
|
names_len = len(names)
|
||||||
|
@ -139,27 +246,14 @@ class DlgChoices(Dlg):
|
||||||
break
|
break
|
||||||
self._short_names = short_names
|
self._short_names = short_names
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_selected_choices(self) -> List[DlgChoice]:
|
|
||||||
"""
|
|
||||||
returns selected choices
|
|
||||||
"""
|
|
||||||
return self._results
|
|
||||||
|
|
||||||
def get_selected_choices_id(self) -> List[int]:
|
|
||||||
"""
|
|
||||||
returns selected choice
|
|
||||||
"""
|
|
||||||
return self._results_id
|
|
||||||
|
|
||||||
#overridable
|
#overridable
|
||||||
def on_print(self, table_def : List[str]):
|
def on_print(self, table_def : List[str]):
|
||||||
|
|
||||||
for short_name, choice in zip(self._short_names, self._choices):
|
for short_name, choice in zip(self._short_names, self._choices):
|
||||||
row_def = f'| {short_name}'
|
row_def = f'| {short_name}'
|
||||||
row_desc = choice.get_row_desc()
|
x = choice.get_row_def()
|
||||||
if row_desc is not None:
|
if x is not None:
|
||||||
row_def += row_desc
|
row_def += x
|
||||||
table_def.append(row_def)
|
table_def.append(row_def)
|
||||||
|
|
||||||
return table_def
|
return table_def
|
||||||
|
@ -167,35 +261,36 @@ class DlgChoices(Dlg):
|
||||||
#overridable
|
#overridable
|
||||||
def on_user_input(self, s : str) -> bool:
|
def on_user_input(self, s : str) -> bool:
|
||||||
result = super().on_user_input(s)
|
result = super().on_user_input(s)
|
||||||
if result == EDlgMode.UNDEFINED:
|
if result == EDlgMode.UNHANDLED:
|
||||||
|
|
||||||
if self._multiple_choices:
|
if self._on_multi_choice is not None:
|
||||||
multi_s = s.split(',')
|
multi_s = s.split(',')
|
||||||
else:
|
else:
|
||||||
multi_s = [s]
|
multi_s = [s]
|
||||||
|
|
||||||
results = []
|
choices_id = []
|
||||||
results_id = []
|
|
||||||
for s in multi_s:
|
for s in multi_s:
|
||||||
s = s.strip()
|
x = [ i for i, short_name in enumerate(self._short_names) if s.strip() == short_name ]
|
||||||
|
|
||||||
x = [ i for i,short_name in enumerate(self._short_names) if s == short_name ]
|
|
||||||
if len(x) == 0:
|
if len(x) == 0:
|
||||||
# no short name match
|
# No short name match
|
||||||
return EDlgMode.WRONG_INPUT
|
return EDlgMode.WRONG_INPUT
|
||||||
else:
|
else:
|
||||||
id = x[0]
|
id = x[0]
|
||||||
results_id.append(id)
|
choices_id.append(id)
|
||||||
results.append(self._choices[id])
|
|
||||||
|
if len(set(choices_id)) != len(choices_id):
|
||||||
if len(set(results_id)) != len(results_id):
|
|
||||||
# Duplicate input
|
# Duplicate input
|
||||||
return EDlgMode.WRONG_INPUT
|
return EDlgMode.WRONG_INPUT
|
||||||
|
|
||||||
self._results = results
|
for id in choices_id:
|
||||||
self._results_id = results_id
|
on_choose = self._choices[id].get_on_choose()
|
||||||
|
if on_choose is not None:
|
||||||
|
on_choose(self)
|
||||||
|
|
||||||
|
if self._on_multi_choice is not None:
|
||||||
|
self._on_multi_choice(choices_id)
|
||||||
|
|
||||||
return EDlgMode.SUCCESS
|
return EDlgMode.HANDLED
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -204,42 +299,9 @@ class DlgChoices(Dlg):
|
||||||
class _Diacon:
|
class _Diacon:
|
||||||
"""
|
"""
|
||||||
User dialog with via console.
|
User dialog with via console.
|
||||||
|
|
||||||
Internal architecture:
|
|
||||||
|
|
||||||
[
|
|
||||||
Main-Thread
|
|
||||||
|
|
||||||
current thread from which __init__() called
|
|
||||||
]
|
|
||||||
|
|
||||||
[
|
|
||||||
Dialog-Thread
|
|
||||||
|
|
||||||
separate thread where dialogs are handled and dynamically created
|
|
||||||
|
|
||||||
we need this thread, because main thread can be busy,
|
|
||||||
for example training neural network
|
|
||||||
|
|
||||||
calls on_dlg() provided with __init__
|
|
||||||
|
|
||||||
thus keep in mind on_dlg() works in separate thread
|
|
||||||
|
|
||||||
This thread must not be blocked inside on_dlg(),
|
|
||||||
because Diacon.stop() can be called that stops all threads.
|
|
||||||
]
|
|
||||||
|
|
||||||
[
|
|
||||||
Input-Thread
|
|
||||||
|
|
||||||
separate thread where user input is accepted in non-blocking mode,
|
|
||||||
and transfered to processing thread
|
|
||||||
]
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._on_dlg : Callable = None
|
|
||||||
|
|
||||||
self._lock = threading.RLock()
|
self._lock = threading.RLock()
|
||||||
self._current_dlg : Dlg = None
|
self._current_dlg : Dlg = None
|
||||||
self._new_dlg : Dlg = None
|
self._new_dlg : Dlg = None
|
||||||
|
@ -250,11 +312,10 @@ class _Diacon:
|
||||||
self._input_request = False
|
self._input_request = False
|
||||||
self._input_result : str = None
|
self._input_result : str = None
|
||||||
|
|
||||||
def start(self, on_dlg : Callable):
|
def start(self):
|
||||||
if self._started:
|
if self._started:
|
||||||
raise Exception('Diacon already started.')
|
raise Exception('Diacon already started.')
|
||||||
self._started = True
|
self._started = True
|
||||||
self._on_dlg = on_dlg
|
|
||||||
|
|
||||||
self._input_t = threading.Thread(target=self._input_thread, daemon=True)
|
self._input_t = threading.Thread(target=self._input_thread, daemon=True)
|
||||||
self._input_t.start()
|
self._input_t.start()
|
||||||
|
@ -285,8 +346,6 @@ class _Diacon:
|
||||||
|
|
||||||
|
|
||||||
def _dialog_thread(self, ):
|
def _dialog_thread(self, ):
|
||||||
self._on_dlg(None, EDlgMode.RELOAD)
|
|
||||||
|
|
||||||
while self._started:
|
while self._started:
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
@ -303,13 +362,7 @@ class _Diacon:
|
||||||
if input_result is not None:
|
if input_result is not None:
|
||||||
|
|
||||||
if self._current_dlg is not None:
|
if self._current_dlg is not None:
|
||||||
mode = self._current_dlg.handle_user_input(input_result)
|
self._current_dlg.handle_user_input(input_result)
|
||||||
if mode == EDlgMode.WRONG_INPUT:
|
|
||||||
print('\nWrong input')
|
|
||||||
mode = EDlgMode.RELOAD
|
|
||||||
if mode == EDlgMode.UNDEFINED:
|
|
||||||
mode = EDlgMode.RELOAD
|
|
||||||
self._on_dlg(self._current_dlg, mode)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
|
@ -332,6 +385,9 @@ class _Diacon:
|
||||||
show current or set new Dialog
|
show current or set new Dialog
|
||||||
Can be called from any thread.
|
Can be called from any thread.
|
||||||
"""
|
"""
|
||||||
|
if not self._started:
|
||||||
|
self.start()
|
||||||
|
|
||||||
self._new_dlg = (new_dlg, print)
|
self._new_dlg = (new_dlg, print)
|
||||||
|
|
||||||
Diacon = _Diacon()
|
Diacon = _Diacon()
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
from .Diacon import Diacon, Dlg, DlgChoice, DlgChoices, EDlgMode
|
from .Diacon import Diacon, Dlg, DlgChoice, DlgChoices, EDlgMode, DlgNumber
|
||||||
|
|
|
@ -9,7 +9,7 @@ class FMask:
|
||||||
def __init__(self, _from_pickled=False):
|
def __init__(self, _from_pickled=False):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
self._uuid : Union[bytes, None] = uuid.uuid4().bytes_le if not _from_pickled else None
|
self._uuid : Union[bytes, None] = uuid.uuid4().bytes if not _from_pickled else None
|
||||||
self._mask_type : Union[FMask.Type, None] = None
|
self._mask_type : Union[FMask.Type, None] = None
|
||||||
self._FImage_uuid : Union[bytes, None] = None
|
self._FImage_uuid : Union[bytes, None] = None
|
||||||
|
|
||||||
|
|
|
@ -52,9 +52,25 @@ class FaceWarper:
|
||||||
self._rw_grid_tx = rnd_state.uniform(*rw_grid_tx) if isinstance(rw_grid_tx, Iterable) else rw_grid_tx
|
self._rw_grid_tx = rnd_state.uniform(*rw_grid_tx) if isinstance(rw_grid_tx, Iterable) else rw_grid_tx
|
||||||
self._rw_grid_ty = rnd_state.uniform(*rw_grid_ty) if isinstance(rw_grid_ty, Iterable) else rw_grid_ty
|
self._rw_grid_ty = rnd_state.uniform(*rw_grid_ty) if isinstance(rw_grid_ty, Iterable) else rw_grid_ty
|
||||||
|
|
||||||
|
self._warp_rnd_mat = Affine2DUniMat.from_transformation(0.5, 0.5, self._rw_grid_rot_deg, 1.0+self._rw_grid_scale, self._rw_grid_tx, self._rw_grid_ty)
|
||||||
|
self._align_rnd_mat = Affine2DUniMat.from_transformation(0.5, 0.5, self._align_rot_deg, 1.0+self._align_scale, self._align_tx, self._align_ty)
|
||||||
|
|
||||||
self._rnd_state_state = rnd_state.get_state()
|
self._rnd_state_state = rnd_state.get_state()
|
||||||
self._cached = {}
|
self._cached = {}
|
||||||
|
|
||||||
|
def get_aligned_random_transform_mat(self) -> Affine2DUniMat:
|
||||||
|
"""
|
||||||
|
returns Affine2DUniMat that represents transformation from aligned face to randomly transformed aligned face
|
||||||
|
"""
|
||||||
|
mat1 = self._img_to_face_uni_mat
|
||||||
|
mat2 = (self._face_to_img_uni_mat * self._align_rnd_mat).invert()
|
||||||
|
|
||||||
|
pts = [ [0,0], [1,0], [1,1]]
|
||||||
|
src_pts = mat1.transform_points(pts)
|
||||||
|
dst_pts = mat2.transform_points(pts)
|
||||||
|
|
||||||
|
return Affine2DUniMat.from_3_pairs(src_pts, dst_pts)
|
||||||
|
|
||||||
def transform(self, img : np.ndarray, out_res : int, random_warp : bool = True) -> np.ndarray:
|
def transform(self, img : np.ndarray, out_res : int, random_warp : bool = True) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
transform an image.
|
transform an image.
|
||||||
|
@ -91,9 +107,7 @@ class FaceWarper:
|
||||||
face_warp_grid = FaceWarper._gen_random_warp_uni_grid_diff(out_res, self._rw_grid_cell_count, 0.12, rnd_state)
|
face_warp_grid = FaceWarper._gen_random_warp_uni_grid_diff(out_res, self._rw_grid_cell_count, 0.12, rnd_state)
|
||||||
|
|
||||||
# make a randomly transformable mat to transform face_warp_grid from face to image
|
# make a randomly transformable mat to transform face_warp_grid from face to image
|
||||||
face_warp_grid_mat = (self._face_to_img_uni_mat *
|
face_warp_grid_mat = self._face_to_img_uni_mat * self._warp_rnd_mat
|
||||||
Affine2DUniMat.from_transformation(0.5, 0.5, self._rw_grid_rot_deg, 1.0+self._rw_grid_scale, self._rw_grid_tx, self._rw_grid_ty)
|
|
||||||
)
|
|
||||||
|
|
||||||
# warp face_warp_grid to the space of image and merge with image_grid
|
# warp face_warp_grid to the space of image and merge with image_grid
|
||||||
image_grid += cv2.warpAffine(face_warp_grid, face_warp_grid_mat.to_exact_mat(out_res,out_res, W, H), (W,H), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
|
image_grid += cv2.warpAffine(face_warp_grid, face_warp_grid_mat.to_exact_mat(out_res,out_res, W, H), (W,H), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
|
||||||
|
@ -101,9 +115,10 @@ class FaceWarper:
|
||||||
# scale uniform grid from to image size
|
# scale uniform grid from to image size
|
||||||
image_grid *= (H-1, W-1)
|
image_grid *= (H-1, W-1)
|
||||||
|
|
||||||
# apply random transormations for align mat
|
# apply random transformations for align mat
|
||||||
img_to_face_rnd_mat = (self._face_to_img_uni_mat * Affine2DMat.from_transformation(0.5, 0.5, self._align_rot_deg, 1.0+self._align_scale, self._align_tx, self._align_ty)
|
#img_to_face_rnd_uni_mat = (self._face_to_img_uni_mat * self._align_rnd_mat).invert()
|
||||||
).invert().to_exact_mat(W,H,out_res,out_res)
|
|
||||||
|
img_to_face_rnd_mat = (self._face_to_img_uni_mat * self._align_rnd_mat).invert().to_exact_mat(W,H,out_res,out_res)
|
||||||
|
|
||||||
# warp image_grid to face space
|
# warp image_grid to face space
|
||||||
image_grid = cv2.warpAffine(image_grid, img_to_face_rnd_mat, (out_res,out_res), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE )
|
image_grid = cv2.warpAffine(image_grid, img_to_face_rnd_mat, (out_res,out_res), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE )
|
||||||
|
|
|
@ -1,19 +1,22 @@
|
||||||
import pickle
|
import pickle
|
||||||
import sqlite3
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Generator, List, Union, Iterable
|
from typing import Generator, Iterable, List, Union
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from .. import console as lib_con
|
||||||
from .FMask import FMask
|
from .FMask import FMask
|
||||||
from .UFaceMark import UFaceMark
|
from .UFaceMark import UFaceMark
|
||||||
from .UImage import UImage
|
from .UImage import UImage
|
||||||
from .UPerson import UPerson
|
from .UPerson import UPerson
|
||||||
|
|
||||||
|
|
||||||
class Faceset:
|
class Faceset:
|
||||||
|
|
||||||
def __init__(self, path = None):
|
def __init__(self, path = None, write_access=False, recreate=False):
|
||||||
"""
|
"""
|
||||||
Faceset is a class to store and manage face related data.
|
Faceset is a class to store and manage face related data.
|
||||||
|
|
||||||
|
@ -21,205 +24,155 @@ class Faceset:
|
||||||
|
|
||||||
path path to faceset .dfs file
|
path path to faceset .dfs file
|
||||||
|
|
||||||
|
write_access
|
||||||
|
|
||||||
|
recreate
|
||||||
|
|
||||||
Can be pickled.
|
Can be pickled.
|
||||||
"""
|
"""
|
||||||
|
self._f = None
|
||||||
|
|
||||||
self._path = path = Path(path)
|
self._path = path = Path(path)
|
||||||
if path.suffix != '.dfs':
|
if path.suffix != '.dfs':
|
||||||
raise ValueError('Path must be a .dfs file')
|
raise ValueError('Path must be a .dfs file')
|
||||||
|
|
||||||
self._conn = conn = sqlite3.connect(path, isolation_level=None)
|
if path.exists():
|
||||||
self._cur = cur = conn.cursor()
|
if write_access and recreate:
|
||||||
|
path.unlink()
|
||||||
|
elif not write_access:
|
||||||
|
raise FileNotFoundError(f'File {path} not found.')
|
||||||
|
|
||||||
cur = self._get_cursor()
|
self._mode = 'a' if write_access else 'r'
|
||||||
cur.execute('BEGIN IMMEDIATE')
|
self._open()
|
||||||
if not self._is_table_exists('FacesetInfo'):
|
|
||||||
self.recreate(shrink=False, _transaction=False)
|
|
||||||
cur.execute('COMMIT')
|
|
||||||
self.shrink()
|
|
||||||
else:
|
|
||||||
cur.execute('END')
|
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
return {'_path' : self._path}
|
return {'_path' : self._path, '_mode' : self._mode}
|
||||||
|
|
||||||
def __setstate__(self, d):
|
def __setstate__(self, d):
|
||||||
self.__init__( d['_path'] )
|
self._f = None
|
||||||
|
self._path = d['_path']
|
||||||
|
self._mode = d['_mode']
|
||||||
|
self._open()
|
||||||
|
|
||||||
def __repr__(self): return self.__str__()
|
def __repr__(self): return self.__str__()
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"Faceset. UImage:{self.get_UImage_count()} UFaceMark:{self.get_UFaceMark_count()} UPerson:{self.get_UPerson_count()}"
|
return f"Faceset. UImage:{self.get_UImage_count()} UFaceMark:{self.get_UFaceMark_count()} UPerson:{self.get_UPerson_count()}"
|
||||||
|
|
||||||
def _is_table_exists(self, name):
|
def _open(self):
|
||||||
return self._cur.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", [name]).fetchone()[0] != 0
|
if self._f is None:
|
||||||
|
self._f = f = h5py.File(self._path, mode=self._mode)
|
||||||
|
self._UFaceMark_grp = f.require_group('UFaceMark')
|
||||||
|
self._UImage_grp = f.require_group('UImage')
|
||||||
|
self._UImage_image_data_grp = f.require_group('UImage_image_data')
|
||||||
|
self._UPerson_grp = f.require_group('UPerson')
|
||||||
|
|
||||||
def _get_cursor(self) -> sqlite3.Cursor: return self._cur
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self._cur is not None:
|
if self._f is not None:
|
||||||
self._cur.close()
|
self._f.close()
|
||||||
self._cur = None
|
self._f = None
|
||||||
|
|
||||||
if self._conn is not None:
|
def optimize(self, verbose=True):
|
||||||
self._conn.close()
|
|
||||||
self._conn = None
|
|
||||||
|
|
||||||
def shrink(self):
|
|
||||||
self._cur.execute('VACUUM')
|
|
||||||
|
|
||||||
def recreate(self, shrink=True, _transaction=True):
|
|
||||||
"""
|
"""
|
||||||
delete all data and recreate Faceset structure.
|
recreate Faceset with optimized structure.
|
||||||
"""
|
"""
|
||||||
cur = self._get_cursor()
|
if verbose:
|
||||||
|
print(f'Optimizing {self._path.name}...')
|
||||||
|
|
||||||
if _transaction:
|
tmp_path = self._path.parent / (self._path.stem + '_optimizing' + self._path.suffix)
|
||||||
cur.execute('BEGIN IMMEDIATE')
|
|
||||||
|
|
||||||
for table_name, in cur.execute("SELECT name from sqlite_master where type = 'table';").fetchall():
|
tmp_fs = Faceset(tmp_path, write_access=True, recreate=True)
|
||||||
cur.execute(f'DROP TABLE {table_name}')
|
self._group_copy(tmp_fs._UFaceMark_grp, self._UFaceMark_grp, verbose=verbose)
|
||||||
|
self._group_copy(tmp_fs._UPerson_grp, self._UPerson_grp, verbose=verbose)
|
||||||
|
self._group_copy(tmp_fs._UImage_grp, self._UImage_grp, verbose=verbose)
|
||||||
|
self._group_copy(tmp_fs._UImage_image_data_grp, self._UImage_image_data_grp, verbose=verbose)
|
||||||
|
tmp_fs.close()
|
||||||
|
|
||||||
(cur.execute('CREATE TABLE FacesetInfo (version INT)')
|
self.close()
|
||||||
.execute('INSERT INTO FacesetInfo VALUES (1)')
|
self._path.unlink()
|
||||||
|
tmp_path.rename(self._path)
|
||||||
|
self._open()
|
||||||
|
|
||||||
.execute('CREATE TABLE UImage (uuid BLOB, name TEXT, format TEXT, data BLOB)')
|
def _group_copy(self, group_dst : h5py.Group, group_src : h5py.Group, verbose=True):
|
||||||
.execute('CREATE TABLE UPerson (uuid BLOB, data BLOB)')
|
for key, value in lib_con.progress_bar_iterator(group_src.items(), desc=f'Copying {group_src.name} -> {group_dst.name}', suppress_print=not verbose):
|
||||||
.execute('CREATE TABLE UFaceMark (uuid BLOB, UImage_uuid BLOB, UPerson_uuid BLOB, data BLOB)')
|
d = group_dst.create_dataset(key, shape=value.shape, dtype=value.dtype )
|
||||||
)
|
d[:] = value[:]
|
||||||
|
for a_key, a_value in value.attrs.items():
|
||||||
|
d.attrs[a_key] = a_value
|
||||||
|
|
||||||
if _transaction:
|
def _group_read_bytes(self, group : h5py.Group, key : str, check_key=True) -> Union[bytes, None]:
|
||||||
cur.execute('COMMIT')
|
if check_key and key not in group:
|
||||||
|
return None
|
||||||
|
dataset = group[key]
|
||||||
|
data_bytes = bytearray(len(dataset))
|
||||||
|
dataset.read_direct(np.frombuffer(data_bytes, dtype=np.uint8))
|
||||||
|
return data_bytes
|
||||||
|
|
||||||
if shrink:
|
def _group_write_bytes(self, group : h5py.Group, key : str, data : bytes, update_existing=True) -> Union[h5py.Dataset, None]:
|
||||||
self.shrink()
|
if key in group:
|
||||||
|
if not update_existing:
|
||||||
|
return None
|
||||||
|
del group[key]
|
||||||
|
|
||||||
|
return group.create_dataset(key, data=np.frombuffer(data, dtype=np.uint8) )
|
||||||
|
|
||||||
###################
|
###################
|
||||||
### UFaceMark
|
### UFaceMark
|
||||||
###################
|
###################
|
||||||
def _UFaceMark_from_db_row(self, db_row) -> UFaceMark:
|
def add_UFaceMark(self, ufacemark_or_list : UFaceMark, update_existing=True):
|
||||||
uuid, UImage_uuid, UPerson_uuid, data = db_row
|
|
||||||
|
|
||||||
ufm = UFaceMark()
|
|
||||||
ufm.restore_state(pickle.loads(data))
|
|
||||||
return ufm
|
|
||||||
|
|
||||||
def add_UFaceMark(self, ufacemark_or_list : UFaceMark):
|
|
||||||
"""
|
"""
|
||||||
add or update UFaceMark in DB
|
add or update UFaceMark in DB
|
||||||
"""
|
"""
|
||||||
if not isinstance(ufacemark_or_list, Iterable):
|
if not isinstance(ufacemark_or_list, Iterable):
|
||||||
ufacemark_or_list : List[UFaceMark] = [ufacemark_or_list]
|
ufacemark_or_list : List[UFaceMark] = [ufacemark_or_list]
|
||||||
|
|
||||||
cur = self._cur
|
|
||||||
cur.execute('BEGIN IMMEDIATE')
|
|
||||||
for ufm in ufacemark_or_list:
|
for ufm in ufacemark_or_list:
|
||||||
uuid = ufm.get_uuid()
|
self._group_write_bytes(self._UFaceMark_grp, ufm.get_uuid().hex(), pickle.dumps(ufm.dump_state()), update_existing=update_existing )
|
||||||
UImage_uuid = ufm.get_UImage_uuid()
|
|
||||||
UPerson_uuid = ufm.get_UPerson_uuid()
|
|
||||||
|
|
||||||
data = pickle.dumps(ufm.dump_state())
|
|
||||||
|
|
||||||
if cur.execute('SELECT COUNT(*) from UFaceMark where uuid=?', [uuid] ).fetchone()[0] != 0:
|
|
||||||
cur.execute('UPDATE UFaceMark SET UImage_uuid=?, UPerson_uuid=?, data=? WHERE uuid=?',
|
|
||||||
[UImage_uuid, UPerson_uuid, data, uuid])
|
|
||||||
else:
|
|
||||||
cur.execute('INSERT INTO UFaceMark VALUES (?, ?, ?, ?)', [uuid, UImage_uuid, UPerson_uuid, data])
|
|
||||||
cur.execute('COMMIT')
|
|
||||||
|
|
||||||
def get_UFaceMark_count(self) -> int:
|
def get_UFaceMark_count(self) -> int:
|
||||||
return self._cur.execute('SELECT COUNT(*) FROM UFaceMark').fetchone()[0]
|
return len(self._UFaceMark_grp.keys())
|
||||||
|
|
||||||
def get_all_UFaceMark(self) -> List[UFaceMark]:
|
def get_all_UFaceMark(self) -> List[UFaceMark]:
|
||||||
return [ self._UFaceMark_from_db_row(db_row) for db_row in self._cur.execute('SELECT * FROM UFaceMark').fetchall() ]
|
return [ UFaceMark.from_state(pickle.loads(self._group_read_bytes(self._UFaceMark_grp, key, check_key=False))) for key in self._UFaceMark_grp.keys() ]
|
||||||
|
|
||||||
|
def get_all_UFaceMark_uuids(self) -> List[bytes]:
|
||||||
|
return [ uuid.UUID(key).bytes for key in self._UFaceMark_grp.keys() ]
|
||||||
|
|
||||||
def get_UFaceMark_by_uuid(self, uuid : bytes) -> Union[UFaceMark, None]:
|
def get_UFaceMark_by_uuid(self, uuid : bytes) -> Union[UFaceMark, None]:
|
||||||
c = self._cur.execute('SELECT * FROM UFaceMark WHERE uuid=?', [uuid])
|
data = self._group_read_bytes(self._UFaceMark_grp, uuid.hex())
|
||||||
db_row = c.fetchone()
|
if data is None:
|
||||||
if db_row is None:
|
|
||||||
return None
|
return None
|
||||||
|
return UFaceMark.from_state(pickle.loads(data))
|
||||||
|
|
||||||
return self._UFaceMark_from_db_row(db_row)
|
def delete_UFaceMark_by_uuid(self, uuid : bytes) -> bool:
|
||||||
|
key = uuid.hex()
|
||||||
|
if key in self._UFaceMark_grp:
|
||||||
|
del self._UFaceMark_grp[key]
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def iter_UFaceMark(self) -> Generator[UFaceMark, None, None]:
|
def iter_UFaceMark(self) -> Generator[UFaceMark, None, None]:
|
||||||
"""
|
"""
|
||||||
returns Generator of UFaceMark
|
returns Generator of UFaceMark
|
||||||
"""
|
"""
|
||||||
for db_row in self._cur.execute('SELECT * FROM UFaceMark').fetchall():
|
for key in self._UFaceMark_grp.keys():
|
||||||
yield self._UFaceMark_from_db_row(db_row)
|
yield UFaceMark.from_state(pickle.loads(self._group_read_bytes(self._UFaceMark_grp, key, check_key=False)))
|
||||||
|
|
||||||
def delete_all_UFaceMark(self):
|
def delete_all_UFaceMark(self):
|
||||||
"""
|
"""
|
||||||
deletes all UFaceMark from DB
|
deletes all UFaceMark from DB
|
||||||
"""
|
"""
|
||||||
(self._cur.execute('BEGIN IMMEDIATE')
|
for key in self._UFaceMark_grp.keys():
|
||||||
.execute('DELETE FROM UFaceMark')
|
del self._UFaceMark_grp[key]
|
||||||
.execute('COMMIT') )
|
|
||||||
###################
|
|
||||||
### UPerson
|
|
||||||
###################
|
|
||||||
def _UPerson_from_db_row(self, db_row) -> UPerson:
|
|
||||||
uuid, data = db_row
|
|
||||||
up = UPerson()
|
|
||||||
up.restore_state(pickle.loads(data))
|
|
||||||
return up
|
|
||||||
|
|
||||||
def add_UPerson(self, uperson_or_list : UPerson):
|
|
||||||
"""
|
|
||||||
add or update UPerson in DB
|
|
||||||
"""
|
|
||||||
if not isinstance(uperson_or_list, Iterable):
|
|
||||||
uperson_or_list : List[UPerson] = [uperson_or_list]
|
|
||||||
|
|
||||||
cur = self._cur
|
|
||||||
cur.execute('BEGIN IMMEDIATE')
|
|
||||||
for uperson in uperson_or_list:
|
|
||||||
uuid = uperson.get_uuid()
|
|
||||||
|
|
||||||
data = pickle.dumps(uperson.dump_state())
|
|
||||||
|
|
||||||
if cur.execute('SELECT COUNT(*) from UPerson where uuid=?', [uuid]).fetchone()[0] != 0:
|
|
||||||
cur.execute('UPDATE UPerson SET data=? WHERE uuid=?', [data])
|
|
||||||
else:
|
|
||||||
cur.execute('INSERT INTO UPerson VALUES (?, ?)', [uuid, data])
|
|
||||||
cur.execute('COMMIT')
|
|
||||||
|
|
||||||
def get_UPerson_count(self) -> int:
|
|
||||||
return self._cur.execute('SELECT COUNT(*) FROM UPerson').fetchone()[0]
|
|
||||||
|
|
||||||
def get_all_UPerson(self) -> List[UPerson]:
|
|
||||||
return [ self._UPerson_from_db_row(db_row) for db_row in self._cur.execute('SELECT * FROM UPerson').fetchall() ]
|
|
||||||
|
|
||||||
def iter_UPerson(self) -> Generator[UPerson, None, None]:
|
|
||||||
"""
|
|
||||||
iterator of all UPerson's
|
|
||||||
"""
|
|
||||||
for db_row in self._cur.execute('SELECT * FROM UPerson').fetchall():
|
|
||||||
yield self._UPerson_from_db_row(db_row)
|
|
||||||
|
|
||||||
def delete_all_UPerson(self):
|
|
||||||
"""
|
|
||||||
deletes all UPerson from DB
|
|
||||||
"""
|
|
||||||
(self._cur.execute('BEGIN IMMEDIATE')
|
|
||||||
.execute('DELETE FROM UPerson')
|
|
||||||
.execute('COMMIT') )
|
|
||||||
|
|
||||||
###################
|
###################
|
||||||
### UImage
|
### UImage
|
||||||
###################
|
###################
|
||||||
def _UImage_from_db_row(self, db_row) -> UImage:
|
def add_UImage(self, uimage_or_list : UImage, format : str = 'png', quality : int = 100, update_existing=True):
|
||||||
uuid, name, format, data_bytes = db_row
|
|
||||||
img = cv2.imdecode(np.frombuffer(data_bytes, dtype=np.uint8), flags=cv2.IMREAD_UNCHANGED)
|
|
||||||
|
|
||||||
uimg = UImage()
|
|
||||||
uimg.set_uuid(uuid)
|
|
||||||
uimg.set_name(name)
|
|
||||||
uimg.assign_image(img)
|
|
||||||
return uimg
|
|
||||||
|
|
||||||
def add_UImage(self, uimage_or_list : UImage, format : str = 'webp', quality : int = 100):
|
|
||||||
"""
|
"""
|
||||||
add or update UImage in DB
|
add or update UImage in DB
|
||||||
|
|
||||||
|
@ -239,9 +192,8 @@ class Faceset:
|
||||||
raise ValueError('quality must be in range [0..100]')
|
raise ValueError('quality must be in range [0..100]')
|
||||||
|
|
||||||
if not isinstance(uimage_or_list, Iterable):
|
if not isinstance(uimage_or_list, Iterable):
|
||||||
uimage_or_list = [uimage_or_list]
|
uimage_or_list : List[UImage] = [uimage_or_list]
|
||||||
|
|
||||||
uimage_datas = []
|
|
||||||
for uimage in uimage_or_list:
|
for uimage in uimage_or_list:
|
||||||
if format == 'webp':
|
if format == 'webp':
|
||||||
imencode_args = [int(cv2.IMWRITE_WEBP_QUALITY), quality]
|
imencode_args = [int(cv2.IMWRITE_WEBP_QUALITY), quality]
|
||||||
|
@ -251,44 +203,112 @@ class Faceset:
|
||||||
imencode_args = [int(cv2.IMWRITE_JPEG2000_COMPRESSION_X1000), quality*10]
|
imencode_args = [int(cv2.IMWRITE_JPEG2000_COMPRESSION_X1000), quality*10]
|
||||||
else:
|
else:
|
||||||
imencode_args = []
|
imencode_args = []
|
||||||
|
|
||||||
ret, data_bytes = cv2.imencode( f'.{format}', uimage.get_image(), imencode_args)
|
ret, data_bytes = cv2.imencode( f'.{format}', uimage.get_image(), imencode_args)
|
||||||
if not ret:
|
if not ret:
|
||||||
raise Exception(f'Unable to encode image format {format}')
|
raise Exception(f'Unable to encode image format {format}')
|
||||||
uimage_datas.append(data_bytes.data)
|
|
||||||
|
|
||||||
cur = self._cur
|
key = uimage.get_uuid().hex()
|
||||||
cur.execute('BEGIN IMMEDIATE')
|
|
||||||
for uimage, data in zip(uimage_or_list, uimage_datas):
|
|
||||||
uuid = uimage.get_uuid()
|
|
||||||
if cur.execute('SELECT COUNT(*) from UImage where uuid=?', [uuid] ).fetchone()[0] != 0:
|
|
||||||
cur.execute('UPDATE UImage SET name=?, format=?, data=? WHERE uuid=?', [uimage.get_name(), format, data, uuid])
|
|
||||||
else:
|
|
||||||
cur.execute('INSERT INTO UImage VALUES (?, ?, ?, ?)', [uuid, uimage.get_name(), format, data])
|
|
||||||
cur.execute('COMMIT')
|
|
||||||
|
|
||||||
def get_UImage_count(self) -> int: return self._cur.execute('SELECT COUNT(*) FROM UImage').fetchone()[0]
|
self._group_write_bytes(self._UImage_grp, key, pickle.dumps(uimage.dump_state(exclude_image=True)), update_existing=update_existing )
|
||||||
def get_UImage_by_uuid(self, uuid : Union[bytes, None]) -> Union[UImage, None]:
|
d = self._group_write_bytes(self._UImage_image_data_grp, key, data_bytes.data, update_existing=update_existing )
|
||||||
"""
|
d.attrs['format'] = format
|
||||||
"""
|
d.attrs['quality'] = quality
|
||||||
if uuid is None:
|
|
||||||
|
def get_UImage_count(self) -> int:
|
||||||
|
return len(self._UImage_grp.keys())
|
||||||
|
|
||||||
|
def get_all_UImage(self) -> List[UImage]:
|
||||||
|
return [ self._get_UImage_by_key(key) for key in self._UImage_grp.keys() ]
|
||||||
|
|
||||||
|
def get_all_UImage_uuids(self) -> List[bytes]:
|
||||||
|
return [ uuid.UUID(key).bytes for key in self._UImage_grp.keys() ]
|
||||||
|
|
||||||
|
def _get_UImage_by_key(self, key, check_key=True) -> Union[UImage, None]:
|
||||||
|
data = self._group_read_bytes(self._UImage_grp, key, check_key=check_key)
|
||||||
|
if data is None:
|
||||||
return None
|
return None
|
||||||
|
uimg = UImage.from_state(pickle.loads(data))
|
||||||
|
|
||||||
db_row = self._cur.execute('SELECT * FROM UImage where uuid=?', [uuid]).fetchone()
|
image_data = self._group_read_bytes(self._UImage_image_data_grp, key, check_key=check_key)
|
||||||
if db_row is None:
|
if image_data is not None:
|
||||||
return None
|
uimg.assign_image (cv2.imdecode(np.frombuffer(image_data, dtype=np.uint8), flags=cv2.IMREAD_UNCHANGED))
|
||||||
return self._UImage_from_db_row(db_row)
|
|
||||||
|
|
||||||
def iter_UImage(self) -> Generator[UImage, None, None]:
|
return uimg
|
||||||
|
|
||||||
|
def get_UImage_by_uuid(self, uuid : bytes) -> Union[UImage, None]:
|
||||||
|
return self._get_UImage_by_key(uuid.hex())
|
||||||
|
|
||||||
|
def delete_UImage_by_uuid(self, uuid : bytes):
|
||||||
|
key = uuid.hex()
|
||||||
|
if key in self._UImage_grp:
|
||||||
|
del self._UImage_grp[key]
|
||||||
|
if key in self._UImage_image_data_grp:
|
||||||
|
del self._UImage_image_data_grp[key]
|
||||||
|
|
||||||
|
def iter_UImage(self, include_key=False) -> Generator[UImage, None, None]:
|
||||||
"""
|
"""
|
||||||
iterator of all UImage's
|
returns Generator of UImage
|
||||||
"""
|
"""
|
||||||
for db_row in self._cur.execute('SELECT * FROM UImage').fetchall():
|
for key in self._UImage_grp.keys():
|
||||||
yield self._UImage_from_db_row(db_row)
|
uimg = self._get_UImage_by_key(key, check_key=False)
|
||||||
|
yield (uimg, key) if include_key else uimg
|
||||||
|
|
||||||
def delete_all_UImage(self):
|
def delete_all_UImage(self):
|
||||||
"""
|
"""
|
||||||
deletes all UImage from DB
|
deletes all UImage from DB
|
||||||
"""
|
"""
|
||||||
(self._cur.execute('BEGIN IMMEDIATE')
|
for key in self._UImage_grp.keys():
|
||||||
.execute('DELETE FROM UImage')
|
del self._UImage_grp[key]
|
||||||
.execute('COMMIT') )
|
for key in self._UImage_image_data_grp.keys():
|
||||||
|
del self._UImage_image_data_grp[key]
|
||||||
|
|
||||||
|
###################
|
||||||
|
### UPerson
|
||||||
|
###################
|
||||||
|
|
||||||
|
def add_UPerson(self, uperson_or_list : UPerson, update_existing=True):
|
||||||
|
"""
|
||||||
|
add or update UPerson in DB
|
||||||
|
"""
|
||||||
|
if not isinstance(uperson_or_list, Iterable):
|
||||||
|
uperson_or_list : List[UPerson] = [uperson_or_list]
|
||||||
|
|
||||||
|
for uperson in uperson_or_list:
|
||||||
|
self._group_write_bytes(self._UPerson_grp, uperson.get_uuid().hex(), pickle.dumps(uperson.dump_state()), update_existing=update_existing )
|
||||||
|
|
||||||
|
def get_UPerson_count(self) -> int:
|
||||||
|
return len(self._UPerson_grp.keys())
|
||||||
|
|
||||||
|
def get_all_UPerson(self) -> List[UPerson]:
|
||||||
|
return [ UPerson.from_state(pickle.loads(self._group_read_bytes(self._UPerson_grp, key, check_key=False))) for key in self._UPerson_grp.keys() ]
|
||||||
|
|
||||||
|
def get_all_UPerson_uuids(self) -> List[bytes]:
|
||||||
|
return [ uuid.UUID(key).bytes for key in self._UPerson_grp.keys() ]
|
||||||
|
|
||||||
|
def get_UPerson_by_uuid(self, uuid : bytes) -> Union[UPerson, None]:
|
||||||
|
data = self._group_read_bytes(self._UPerson_grp, uuid.hex())
|
||||||
|
if data is None:
|
||||||
|
return None
|
||||||
|
return UPerson.from_state(pickle.loads(data))
|
||||||
|
|
||||||
|
def delete_UPerson_by_uuid(self, uuid : bytes) -> bool:
|
||||||
|
key = uuid.hex()
|
||||||
|
if key in self._UPerson_grp:
|
||||||
|
del self._UPerson_grp[key]
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def iter_UPerson(self) -> Generator[UPerson, None, None]:
|
||||||
|
"""
|
||||||
|
returns Generator of UPerson
|
||||||
|
"""
|
||||||
|
for key in self._UPerson_grp.keys():
|
||||||
|
yield UPerson.from_state(pickle.loads(self._group_read_bytes(self._UPerson_grp, key, check_key=False)))
|
||||||
|
|
||||||
|
def delete_all_UPerson(self):
|
||||||
|
"""
|
||||||
|
deletes all UPerson from DB
|
||||||
|
"""
|
||||||
|
for key in self._UPerson_grp.keys():
|
||||||
|
del self._UPerson_grp[key]
|
||||||
|
|
|
@ -25,7 +25,13 @@ class UFaceMark(IState):
|
||||||
def __repr__(self): return self.__str__()
|
def __repr__(self): return self.__str__()
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"UFaceMark UUID:[...{self.get_uuid()[-4:].hex()}]"
|
return f"UFaceMark UUID:[...{self.get_uuid()[-4:].hex()}]"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_state(state : dict) -> 'UFaceMark':
|
||||||
|
ufm = UFaceMark()
|
||||||
|
ufm.restore_state(state)
|
||||||
|
return ufm
|
||||||
|
|
||||||
def restore_state(self, state : dict):
|
def restore_state(self, state : dict):
|
||||||
self._uuid = state.get('_uuid', None)
|
self._uuid = state.get('_uuid', None)
|
||||||
self._UImage_uuid = state.get('_UImage_uuid', None)
|
self._UImage_uuid = state.get('_UImage_uuid', None)
|
||||||
|
@ -45,7 +51,7 @@ class UFaceMark(IState):
|
||||||
|
|
||||||
def get_uuid(self) -> Union[bytes, None]:
|
def get_uuid(self) -> Union[bytes, None]:
|
||||||
if self._uuid is None:
|
if self._uuid is None:
|
||||||
self._uuid = uuid.uuid4().bytes_le
|
self._uuid = uuid.uuid4().bytes
|
||||||
return self._uuid
|
return self._uuid
|
||||||
|
|
||||||
def set_uuid(self, uuid : Union[bytes, None]):
|
def set_uuid(self, uuid : Union[bytes, None]):
|
||||||
|
@ -72,6 +78,16 @@ class UFaceMark(IState):
|
||||||
self._FRect = face_urect
|
self._FRect = face_urect
|
||||||
|
|
||||||
def get_all_FLandmarks2D(self) -> List[FLandmarks2D]: return self._FLandmarks2D_list
|
def get_all_FLandmarks2D(self) -> List[FLandmarks2D]: return self._FLandmarks2D_list
|
||||||
|
|
||||||
|
def get_FLandmarks2D_best(self) -> Union[FLandmarks2D, None]:
|
||||||
|
"""get best available FLandmarks2D """
|
||||||
|
lmrks = self.get_FLandmarks2D_by_type(ELandmarks2D.L468)
|
||||||
|
if lmrks is None:
|
||||||
|
lmrks = self.get_FLandmarks2D_by_type(ELandmarks2D.L68)
|
||||||
|
if lmrks is None:
|
||||||
|
lmrks = self.get_FLandmarks2D_by_type(ELandmarks2D.L5)
|
||||||
|
return lmrks
|
||||||
|
|
||||||
def get_FLandmarks2D_by_type(self, type : ELandmarks2D) -> Union[FLandmarks2D, None]:
|
def get_FLandmarks2D_by_type(self, type : ELandmarks2D) -> Union[FLandmarks2D, None]:
|
||||||
"""get FLandmarks2D from list by type"""
|
"""get FLandmarks2D from list by type"""
|
||||||
if not isinstance(type, ELandmarks2D):
|
if not isinstance(type, ELandmarks2D):
|
||||||
|
|
|
@ -18,10 +18,31 @@ class UImage(IState):
|
||||||
|
|
||||||
def __str__(self): return f"UImage UUID:[...{self.get_uuid()[-4:].hex()}] name:[{self._name}] image:[{ (self._image.shape, self._image.dtype) if self._image is not None else None}]"
|
def __str__(self): return f"UImage UUID:[...{self.get_uuid()[-4:].hex()}] name:[{self._name}] image:[{ (self._image.shape, self._image.dtype) if self._image is not None else None}]"
|
||||||
def __repr__(self): return self.__str__()
|
def __repr__(self): return self.__str__()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_state(state : dict) -> 'UImage':
|
||||||
|
ufm = UImage()
|
||||||
|
ufm.restore_state(state)
|
||||||
|
return ufm
|
||||||
|
|
||||||
|
def restore_state(self, state : dict):
|
||||||
|
self._uuid = state.get('_uuid', None)
|
||||||
|
self._name = state.get('_name', None)
|
||||||
|
self._image = state.get('_image', None)
|
||||||
|
|
||||||
|
def dump_state(self, exclude_image=False) -> dict:
|
||||||
|
d = {'_uuid' : self._uuid,
|
||||||
|
'_name' : self._name,
|
||||||
|
}
|
||||||
|
|
||||||
|
if not exclude_image:
|
||||||
|
d['_image'] = self._image
|
||||||
|
|
||||||
|
return d
|
||||||
|
|
||||||
def get_uuid(self) -> Union[bytes, None]:
|
def get_uuid(self) -> Union[bytes, None]:
|
||||||
if self._uuid is None:
|
if self._uuid is None:
|
||||||
self._uuid = uuid.uuid4().bytes_le
|
self._uuid = uuid.uuid4().bytes
|
||||||
return self._uuid
|
return self._uuid
|
||||||
|
|
||||||
def set_uuid(self, uuid : Union[bytes, None]):
|
def set_uuid(self, uuid : Union[bytes, None]):
|
||||||
|
|
|
@ -15,6 +15,12 @@ class UPerson(IState):
|
||||||
def __str__(self): return f"UPerson UUID:[...{self._uuid[-4:].hex()}] name:[{self._name}] age:[{self._age}]"
|
def __str__(self): return f"UPerson UUID:[...{self._uuid[-4:].hex()}] name:[{self._name}] age:[{self._age}]"
|
||||||
def __repr__(self): return self.__str__()
|
def __repr__(self): return self.__str__()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_state(state : dict) -> 'UPerson':
|
||||||
|
ufm = UPerson()
|
||||||
|
ufm.restore_state(state)
|
||||||
|
return ufm
|
||||||
|
|
||||||
def restore_state(self, state : dict):
|
def restore_state(self, state : dict):
|
||||||
self._uuid = state.get('_uuid', None)
|
self._uuid = state.get('_uuid', None)
|
||||||
self._name = state.get('_name', None)
|
self._name = state.get('_name', None)
|
||||||
|
@ -28,7 +34,7 @@ class UPerson(IState):
|
||||||
|
|
||||||
def get_uuid(self) -> Union[bytes, None]:
|
def get_uuid(self) -> Union[bytes, None]:
|
||||||
if self._uuid is None:
|
if self._uuid is None:
|
||||||
self._uuid = uuid.uuid4().bytes_le
|
self._uuid = uuid.uuid4().bytes
|
||||||
return self._uuid
|
return self._uuid
|
||||||
|
|
||||||
def set_uuid(self, uuid : Union[bytes, None]):
|
def set_uuid(self, uuid : Union[bytes, None]):
|
||||||
|
|
|
@ -2,12 +2,14 @@ import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.linalg as npla
|
import numpy.linalg as npla
|
||||||
|
|
||||||
def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0):
|
def sot(src,trg, mask=None, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0, return_diff=False):
|
||||||
"""
|
"""
|
||||||
Color Transform via Sliced Optimal Transfer, ported from https://github.com/dcoeurjo/OTColorTransfer
|
Color Transform via Sliced Optimal Transfer, ported from https://github.com/dcoeurjo/OTColorTransfer
|
||||||
|
|
||||||
src - any float range any channel image
|
src - any float range any channel image
|
||||||
dst - any float range any channel image, same shape as src
|
dst - any float range any channel image, same shape as src
|
||||||
|
mask -
|
||||||
|
|
||||||
steps - number of solver steps
|
steps - number of solver steps
|
||||||
batch_size - solver batch size
|
batch_size - solver batch size
|
||||||
reg_sigmaXY - apply regularization and sigmaXY of filter, otherwise set to 0.0
|
reg_sigmaXY - apply regularization and sigmaXY of filter, otherwise set to 0.0
|
||||||
|
@ -39,8 +41,8 @@ def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0):
|
||||||
dir = np.random.normal(size=c).astype(src_dtype)
|
dir = np.random.normal(size=c).astype(src_dtype)
|
||||||
dir /= npla.norm(dir)
|
dir /= npla.norm(dir)
|
||||||
|
|
||||||
projsource = np.sum( new_src*dir, axis=-1).reshape ((h*w))
|
projsource = np.sum( new_src*dir*mask, axis=-1).reshape ((h*w))
|
||||||
projtarget = np.sum( trg*dir, axis=-1).reshape ((h*w))
|
projtarget = np.sum( trg*dir*mask, axis=-1).reshape ((h*w))
|
||||||
|
|
||||||
idSource = np.argsort (projsource)
|
idSource = np.argsort (projsource)
|
||||||
idTarget = np.argsort (projtarget)
|
idTarget = np.argsort (projtarget)
|
||||||
|
@ -48,12 +50,18 @@ def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0):
|
||||||
a = projtarget[idTarget]-projsource[idSource]
|
a = projtarget[idTarget]-projsource[idSource]
|
||||||
for i_c in range(c):
|
for i_c in range(c):
|
||||||
advect[idSource,i_c] += a * dir[i_c]
|
advect[idSource,i_c] += a * dir[i_c]
|
||||||
new_src += advect.reshape( (h,w,c) ) / batch_size
|
|
||||||
|
new_src += (advect.reshape( (h,w,c) ) * mask) / batch_size
|
||||||
|
|
||||||
if reg_sigmaXY != 0.0:
|
if reg_sigmaXY != 0.0:
|
||||||
src_diff = new_src-src
|
src_diff = new_src-src
|
||||||
src_diff_filt = cv2.bilateralFilter (src_diff, 0, reg_sigmaV, reg_sigmaXY )
|
src_diff_filt = cv2.bilateralFilter (src_diff, 0, reg_sigmaV, reg_sigmaXY )
|
||||||
if len(src_diff_filt.shape) == 2:
|
if len(src_diff_filt.shape) == 2:
|
||||||
src_diff_filt = src_diff_filt[...,None]
|
src_diff_filt = src_diff_filt[...,None]
|
||||||
new_src = src + src_diff_filt
|
if return_diff:
|
||||||
return new_src
|
return src_diff_filt
|
||||||
|
return src + src_diff_filt
|
||||||
|
else:
|
||||||
|
if return_diff:
|
||||||
|
return new_src-src
|
||||||
|
return new_src
|
||||||
|
|
|
@ -37,7 +37,7 @@ class MPSPSCMRRingData:
|
||||||
|
|
||||||
# Initialize first block at 0 index
|
# Initialize first block at 0 index
|
||||||
wid = 0
|
wid = 0
|
||||||
wid_uuid = uuid.uuid4().bytes_le
|
wid_uuid = uuid.uuid4().bytes
|
||||||
wid_heap_offset = 0
|
wid_heap_offset = 0
|
||||||
wid_data_size = 0
|
wid_data_size = 0
|
||||||
|
|
||||||
|
@ -82,7 +82,7 @@ class MPSPSCMRRingData:
|
||||||
raise Exception('data_size more than heap_size')
|
raise Exception('data_size more than heap_size')
|
||||||
|
|
||||||
fmv = FormattedMemoryViewIO(self._shared_mem.get_mv())
|
fmv = FormattedMemoryViewIO(self._shared_mem.get_mv())
|
||||||
wid_uuid = uuid.uuid4().bytes_le
|
wid_uuid = uuid.uuid4().bytes
|
||||||
|
|
||||||
if self._write_lock is not None:
|
if self._write_lock is not None:
|
||||||
self._write_lock.acquire()
|
self._write_lock.acquire()
|
||||||
|
|
|
@ -49,7 +49,7 @@ class MPWeakHeap:
|
||||||
|
|
||||||
# Entire block
|
# Entire block
|
||||||
fmv.seek(self._first_block_offset)
|
fmv.seek(self._first_block_offset)
|
||||||
fmv.write_fmt('qq', self._heap_size-self._first_block_offset, 0), fmv.write(uuid.uuid4().bytes_le)
|
fmv.write_fmt('qq', self._heap_size-self._first_block_offset, 0), fmv.write(uuid.uuid4().bytes)
|
||||||
|
|
||||||
|
|
||||||
def add_data(self, data : Union[bytes, bytearray, memoryview] ) -> 'MPWeakHeap.DataRef':
|
def add_data(self, data : Union[bytes, bytearray, memoryview] ) -> 'MPWeakHeap.DataRef':
|
||||||
|
@ -92,7 +92,7 @@ class MPWeakHeap:
|
||||||
if block_remain_size >= block_header_size:
|
if block_remain_size >= block_header_size:
|
||||||
# the remain space of the block is enough for next block, split the block
|
# the remain space of the block is enough for next block, split the block
|
||||||
next_block_offset = cur_block_offset + block_new_size
|
next_block_offset = cur_block_offset + block_new_size
|
||||||
fmv.seek(next_block_offset), fmv.write_fmt('qq', block_remain_size, 0), fmv.write(uuid.uuid4().bytes_le)
|
fmv.seek(next_block_offset), fmv.write_fmt('qq', block_remain_size, 0), fmv.write(uuid.uuid4().bytes)
|
||||||
else:
|
else:
|
||||||
# otherwise do not split
|
# otherwise do not split
|
||||||
next_block_offset = cur_block_offset + block_size
|
next_block_offset = cur_block_offset + block_size
|
||||||
|
@ -101,7 +101,7 @@ class MPWeakHeap:
|
||||||
block_new_size = block_size
|
block_new_size = block_size
|
||||||
|
|
||||||
# update current block structure
|
# update current block structure
|
||||||
uid = uuid.uuid4().bytes_le
|
uid = uuid.uuid4().bytes
|
||||||
fmv.seek(cur_block_offset), fmv.write_fmt('qq', block_new_size, data_size ), fmv.write(uid)
|
fmv.seek(cur_block_offset), fmv.write_fmt('qq', block_new_size, data_size ), fmv.write(uid)
|
||||||
|
|
||||||
# update ring_head_block_offset
|
# update ring_head_block_offset
|
||||||
|
@ -135,7 +135,7 @@ class MPWeakHeap:
|
||||||
next_block_size, = fmv.get_fmt('q')
|
next_block_size, = fmv.get_fmt('q')
|
||||||
|
|
||||||
# erase data of next block
|
# erase data of next block
|
||||||
fmv.write_fmt('qq', 0, 0), fmv.write(uuid.uuid4().bytes_le)
|
fmv.write_fmt('qq', 0, 0), fmv.write(uuid.uuid4().bytes)
|
||||||
|
|
||||||
# overwrite current block size with expanded block size
|
# overwrite current block size with expanded block size
|
||||||
fmv.seek(cur_block_offset)
|
fmv.seek(cur_block_offset)
|
||||||
|
|
|
@ -11,7 +11,6 @@ def _host_thread_proc(wref):
|
||||||
break
|
break
|
||||||
ref._host_process_messages(0.005)
|
ref._host_process_messages(0.005)
|
||||||
del ref
|
del ref
|
||||||
print('_host_thread_proc exit')
|
|
||||||
|
|
||||||
class SPMTWorker:
|
class SPMTWorker:
|
||||||
def __init__(self, *sub_args, **sub_kwargs):
|
def __init__(self, *sub_args, **sub_kwargs):
|
||||||
|
@ -152,7 +151,7 @@ class SPMTWorker:
|
||||||
break
|
break
|
||||||
|
|
||||||
self._threads_running = False
|
self._threads_running = False
|
||||||
|
|
||||||
self._threads_exit_barrier.wait()
|
self._threads_exit_barrier.wait()
|
||||||
|
|
||||||
self._on_sub_finalize()
|
self._on_sub_finalize()
|
|
@ -1,3 +1,3 @@
|
||||||
from .device import (ORTDeviceInfo, get_available_devices_info,
|
from .device import (ORTDeviceInfo, get_available_devices_info,
|
||||||
get_cpu_device)
|
get_cpu_device_info)
|
||||||
from .InferenceSession import InferenceSession_with_device
|
from .InferenceSession import InferenceSession_with_device
|
||||||
|
|
|
@ -67,33 +67,34 @@ class ORTDeviceInfo:
|
||||||
|
|
||||||
_ort_devices_info = None
|
_ort_devices_info = None
|
||||||
|
|
||||||
def get_cpu_device() -> ORTDeviceInfo:
|
def get_cpu_device_info() -> ORTDeviceInfo:
|
||||||
return ORTDeviceInfo(index=-1, execution_provider='CPUExecutionProvider', name='CPU', total_memory=0, free_memory=0)
|
return ORTDeviceInfo(index=-1, execution_provider='CPUExecutionProvider', name='CPU', total_memory=0, free_memory=0)
|
||||||
|
|
||||||
def get_available_devices_info(include_cpu=True, cpu_only=False) -> List[ORTDeviceInfo]:
|
def get_available_devices_info(include_cpu=True, cpu_only=False) -> List[ORTDeviceInfo]:
|
||||||
"""
|
"""
|
||||||
returns a list of available ORTDeviceInfo
|
returns a list of available ORTDeviceInfo
|
||||||
"""
|
"""
|
||||||
global _ort_devices_info
|
devices = []
|
||||||
if _ort_devices_info is None:
|
if not cpu_only:
|
||||||
_initialize_ort_devices()
|
global _ort_devices_info
|
||||||
devices = []
|
if _ort_devices_info is None:
|
||||||
if not cpu_only:
|
_initialize_ort_devices_info()
|
||||||
|
_ort_devices_info = []
|
||||||
for i in range ( int(os.environ.get('ORT_DEVICES_COUNT',0)) ):
|
for i in range ( int(os.environ.get('ORT_DEVICES_COUNT',0)) ):
|
||||||
devices.append ( ORTDeviceInfo(index=int(os.environ[f'ORT_DEVICE_{i}_INDEX']),
|
_ort_devices_info.append ( ORTDeviceInfo(index=int(os.environ[f'ORT_DEVICE_{i}_INDEX']),
|
||||||
execution_provider=os.environ[f'ORT_DEVICE_{i}_EP'],
|
execution_provider=os.environ[f'ORT_DEVICE_{i}_EP'],
|
||||||
name=os.environ[f'ORT_DEVICE_{i}_NAME'],
|
name=os.environ[f'ORT_DEVICE_{i}_NAME'],
|
||||||
total_memory=int(os.environ[f'ORT_DEVICE_{i}_TOTAL_MEM']),
|
total_memory=int(os.environ[f'ORT_DEVICE_{i}_TOTAL_MEM']),
|
||||||
free_memory=int(os.environ[f'ORT_DEVICE_{i}_FREE_MEM']),
|
free_memory=int(os.environ[f'ORT_DEVICE_{i}_FREE_MEM']),
|
||||||
) )
|
) )
|
||||||
if include_cpu or cpu_only:
|
devices += _ort_devices_info
|
||||||
devices.append(get_cpu_device())
|
if include_cpu:
|
||||||
_ort_devices_info = devices
|
devices.append(get_cpu_device_info())
|
||||||
|
|
||||||
return _ort_devices_info
|
return devices
|
||||||
|
|
||||||
|
|
||||||
def _initialize_ort_devices():
|
def _initialize_ort_devices_info():
|
||||||
"""
|
"""
|
||||||
Determine available ORT devices, and place info about them to os.environ,
|
Determine available ORT devices, and place info about them to os.environ,
|
||||||
they will be available in spawned subprocesses.
|
they will be available in spawned subprocesses.
|
||||||
|
@ -184,4 +185,4 @@ def _initialize_ort_devices():
|
||||||
os.environ[f'ORT_DEVICE_{i}_TOTAL_MEM'] = str(device['total_mem'])
|
os.environ[f'ORT_DEVICE_{i}_TOTAL_MEM'] = str(device['total_mem'])
|
||||||
os.environ[f'ORT_DEVICE_{i}_FREE_MEM'] = str(device['free_mem'])
|
os.environ[f'ORT_DEVICE_{i}_FREE_MEM'] = str(device['free_mem'])
|
||||||
|
|
||||||
_initialize_ort_devices()
|
_initialize_ort_devices_info()
|
||||||
|
|
|
@ -247,7 +247,8 @@ def ascii_table(table_def : List[str],
|
||||||
if row_line is not None:
|
if row_line is not None:
|
||||||
lines.append(row_line)
|
lines.append(row_line)
|
||||||
for sub_rows in rows:
|
for sub_rows in rows:
|
||||||
|
|
||||||
|
|
||||||
for row in sub_rows:
|
for row in sub_rows:
|
||||||
line = ''
|
line = ''
|
||||||
|
|
||||||
|
@ -288,7 +289,8 @@ def ascii_table(table_def : List[str],
|
||||||
line += right_border
|
line += right_border
|
||||||
|
|
||||||
lines.append(line)
|
lines.append(line)
|
||||||
if row_line is not None:
|
|
||||||
|
if len(sub_rows) != 0 and row_line is not None:
|
||||||
lines.append(row_line)
|
lines.append(row_line)
|
||||||
|
|
||||||
return '\n'.join(lines)
|
return '\n'.join(lines)
|
|
@ -1,2 +1,2 @@
|
||||||
from .device import (TorchDeviceInfo, TorchDevicesInfo, get_available_devices,
|
from .device import (TorchDeviceInfo, get_available_devices_info,
|
||||||
get_cpu_device)
|
get_cpu_device_info, get_device, get_device_info_by_index)
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
from typing import List
|
from typing import List, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class TorchDeviceInfo:
|
class TorchDeviceInfo:
|
||||||
|
@ -45,96 +47,40 @@ class TorchDeviceInfo:
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.__class__.__name__} object: ' + self.__str__()
|
return f'{self.__class__.__name__} object: ' + self.__str__()
|
||||||
|
|
||||||
# class TorchDevicesInfo:
|
|
||||||
# """
|
|
||||||
# picklable list of TorchDeviceInfo
|
|
||||||
# """
|
|
||||||
# def __init__(self, devices : List[TorchDeviceInfo] = None):
|
|
||||||
# if devices is None:
|
|
||||||
# devices = []
|
|
||||||
# self._devices = devices
|
|
||||||
|
|
||||||
# def __getstate__(self):
|
|
||||||
# return self.__dict__.copy()
|
|
||||||
|
|
||||||
# def __setstate__(self, d):
|
|
||||||
# self.__init__()
|
|
||||||
# self.__dict__.update(d)
|
|
||||||
|
|
||||||
# def add(self, device_or_devices : TorchDeviceInfo):
|
|
||||||
# if isinstance(device_or_devices, TorchDeviceInfo):
|
|
||||||
# if device_or_devices not in self._devices:
|
|
||||||
# self._devices.append(device_or_devices)
|
|
||||||
# elif isinstance(device_or_devices, TorchDevicesInfo):
|
|
||||||
# for device in device_or_devices:
|
|
||||||
# self.add(device)
|
|
||||||
|
|
||||||
# def copy(self):
|
|
||||||
# return copy.deepcopy(self)
|
|
||||||
|
|
||||||
# def get_count(self): return len(self._devices)
|
|
||||||
|
|
||||||
# def get_largest_total_memory_device(self) -> TorchDeviceInfo:
|
|
||||||
# raise NotImplementedError()
|
|
||||||
# result = None
|
|
||||||
# idx_mem = 0
|
|
||||||
# for device in self._devices:
|
|
||||||
# mem = device.get_total_memory()
|
|
||||||
# if result is None or (mem is not None and mem > idx_mem):
|
|
||||||
# result = device
|
|
||||||
# idx_mem = mem
|
|
||||||
# return result
|
|
||||||
|
|
||||||
# def get_smallest_total_memory_device(self) -> TorchDeviceInfo:
|
|
||||||
# raise NotImplementedError()
|
|
||||||
# result = None
|
|
||||||
# idx_mem = sys.maxsize
|
|
||||||
# for device in self._devices:
|
|
||||||
# mem = device.get_total_memory()
|
|
||||||
# if result is None or (mem is not None and mem < idx_mem):
|
|
||||||
# result = device
|
|
||||||
# idx_mem = mem
|
|
||||||
# return result
|
|
||||||
|
|
||||||
# def __len__(self):
|
|
||||||
# return len(self._devices)
|
|
||||||
|
|
||||||
# def __getitem__(self, key):
|
|
||||||
# result = self._devices[key]
|
|
||||||
# if isinstance(key, slice):
|
|
||||||
# return self.__class__(result)
|
|
||||||
# return result
|
|
||||||
|
|
||||||
# def __iter__(self):
|
|
||||||
# for device in self._devices:
|
|
||||||
# yield device
|
|
||||||
|
|
||||||
# def __str__(self): return f'{self.__class__.__name__}:[' + ', '.join([ device.__str__() for device in self._devices ]) + ']'
|
|
||||||
# def __repr__(self): return f'{self.__class__.__name__}:[' + ', '.join([ device.__repr__() for device in self._devices ]) + ']'
|
|
||||||
|
|
||||||
|
|
||||||
_torch_devices = None
|
_torch_devices = None
|
||||||
|
|
||||||
def get_cpu_device() -> TorchDeviceInfo:
|
def get_cpu_device_info() -> TorchDeviceInfo:
|
||||||
return TorchDeviceInfo(index=-1, name='CPU', total_memory=0)
|
return TorchDeviceInfo(index=-1, name='CPU', total_memory=0)
|
||||||
|
|
||||||
def get_available_devices(include_cpu=True, cpu_only=False) -> List[TorchDeviceInfo]:
|
def get_device_info_by_index(index) -> Union[TorchDeviceInfo, None]:
|
||||||
|
for device in get_available_devices_info(include_cpu=False):
|
||||||
|
if device.get_index() == index:
|
||||||
|
return device
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_device(device_info : TorchDeviceInfo) -> torch.device:
|
||||||
|
if device_info.is_cpu():
|
||||||
|
return torch.device('cpu')
|
||||||
|
return torch.device(f'cuda:{device_info.get_index()}')
|
||||||
|
|
||||||
|
def get_available_devices_info(include_cpu=True, cpu_only=False) -> List[TorchDeviceInfo]:
|
||||||
"""
|
"""
|
||||||
returns a list of available TorchDeviceInfo
|
returns a list of available TorchDeviceInfo
|
||||||
"""
|
"""
|
||||||
global _torch_devices
|
devices = []
|
||||||
if _torch_devices is None:
|
if not cpu_only:
|
||||||
import torch
|
global _torch_devices
|
||||||
devices = []
|
if _torch_devices is None:
|
||||||
|
|
||||||
if not cpu_only:
|
_torch_devices = []
|
||||||
for i in range (torch.cuda.device_count()):
|
for i in range (torch.cuda.device_count()):
|
||||||
device_props = torch.cuda.get_device_properties(i)
|
device_props = torch.cuda.get_device_properties(i)
|
||||||
devices.append ( TorchDeviceInfo(index=i, name=device_props.name, total_memory=device_props.total_memory))
|
_torch_devices.append ( TorchDeviceInfo(index=i, name=device_props.name, total_memory=device_props.total_memory))
|
||||||
|
devices += _torch_devices
|
||||||
|
|
||||||
if include_cpu or cpu_only:
|
if include_cpu:
|
||||||
devices.append ( get_cpu_device() )
|
devices.append ( get_cpu_device_info() )
|
||||||
|
|
||||||
_torch_devices = devices
|
return devices
|
||||||
return _torch_devices
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue