mirror of
https://github.com/iperov/DeepFaceLive
synced 2025-07-15 01:23:45 -07:00
refactoring
This commit is contained in:
parent
8489949f2c
commit
30ba51edf7
24 changed files with 663 additions and 459 deletions
8
main.py
8
main.py
|
@ -79,12 +79,12 @@ def main():
|
|||
p.set_defaults(func=train_FaceAligner)
|
||||
|
||||
def train_CTSOT(args):
|
||||
from apps.trainers.CTSOT.CTSOTTrainerApp import run_app
|
||||
run_app(userdata_path=Path(args.userdata_dir), faceset_path=Path(args.faceset_path))
|
||||
from apps.trainers.CTSOT.CTSOTTrainerApp import CTSOTTrainerApp
|
||||
CTSOTTrainerApp(workspace_path=Path(args.workspace_dir), faceset_path=Path(args.faceset_path))
|
||||
|
||||
p = train_parsers.add_parser('CTSOT')
|
||||
p.add_argument('--userdata-dir', default=None, action=fixPathAction, help="Directory to save app data.")
|
||||
p.add_argument('--faceset-path', default=None, action=fixPathAction, help=".dfs path")
|
||||
p.add_argument('--workspace-dir', default=None, action=fixPathAction, help="Workspace directory.")
|
||||
p.add_argument('--faceset-path', default=None, action=fixPathAction, help=".dfs faceset path")
|
||||
p.set_defaults(func=train_CTSOT)
|
||||
|
||||
def bad_args(arguments):
|
||||
|
|
|
@ -58,7 +58,7 @@ def get_available_devices() -> List[ORTDeviceInfo]:
|
|||
class DFMModel:
|
||||
def __init__(self, model_path : Path, device : ORTDeviceInfo = None):
|
||||
if device is None:
|
||||
device = lib_ort.get_cpu_device()
|
||||
device = lib_ort.get_cpu_device_info()
|
||||
self._model_path = model_path
|
||||
|
||||
sess = self._sess = lib_ort.InferenceSession_with_device(str(model_path), device)
|
||||
|
|
121
modelhub/torch/CTSOT/CTSOT.py
Normal file
121
modelhub/torch/CTSOT/CTSOT.py
Normal file
|
@ -0,0 +1,121 @@
|
|||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from xlib.file import SplittedFile
|
||||
from xlib.torch import TorchDeviceInfo, get_cpu_device_info
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, ch):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)
|
||||
self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, inp):
|
||||
x = inp
|
||||
x = F.leaky_relu(self.conv1(x), 0.2)
|
||||
x = F.leaky_relu(self.conv2(x) + inp, 0.2)
|
||||
return x
|
||||
|
||||
class AutoEncoder(nn.Module):
|
||||
def __init__(self, resolution, in_ch, ae_ch):
|
||||
super().__init__()
|
||||
self._resolution = resolution
|
||||
self._in_ch = in_ch
|
||||
self._ae_ch = ae_ch
|
||||
|
||||
self.conv1 = nn.Conv2d(in_ch*1, in_ch*2, kernel_size=5, stride=2, padding=2)
|
||||
self.conv2 = nn.Conv2d(in_ch*2, in_ch*4, kernel_size=5, stride=2, padding=2)
|
||||
self.conv3 = nn.Conv2d(in_ch*4, in_ch*8, kernel_size=5, stride=2, padding=2)
|
||||
self.conv4 = nn.Conv2d(in_ch*8, in_ch*8, kernel_size=5, stride=2, padding=2)
|
||||
|
||||
self.dense_in = nn.Linear( in_ch*8 * ( resolution // (2**4) )**2, ae_ch)
|
||||
self.dense_out = nn.Linear( ae_ch, in_ch*8 * ( resolution // (2**4) )**2 )
|
||||
|
||||
self.up_conv4 = nn.ConvTranspose2d(in_ch*8, in_ch*8, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
|
||||
self.up_conv3 = nn.ConvTranspose2d(in_ch*8, in_ch*4, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
|
||||
self.up_conv2 = nn.ConvTranspose2d(in_ch*4, in_ch*2, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
|
||||
self.up_conv1 = nn.ConvTranspose2d(in_ch*2, in_ch*1, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
|
||||
|
||||
|
||||
def forward(self, inp):
|
||||
x = inp
|
||||
x = F.leaky_relu(self.conv1(x), 0.1)
|
||||
x = F.leaky_relu(self.conv2(x), 0.1)
|
||||
x = F.leaky_relu(self.conv3(x), 0.1)
|
||||
x = F.leaky_relu(self.conv4(x), 0.1)
|
||||
|
||||
x = x.view( (x.shape[0], np.prod(x.shape[1:])) )
|
||||
|
||||
x = self.dense_in(x)
|
||||
x = self.dense_out(x)
|
||||
|
||||
x = x.view( (x.shape[0], self._in_ch*8, self._resolution // (2**4), self._resolution // (2**4) ))
|
||||
x = F.leaky_relu(self.up_conv4(x), 0.1)
|
||||
x = F.leaky_relu(self.up_conv3(x), 0.1)
|
||||
x = F.leaky_relu(self.up_conv2(x), 0.1)
|
||||
x = F.leaky_relu(self.up_conv1(x), 0.1)
|
||||
|
||||
# from xlib.console import diacon
|
||||
# diacon.Diacon.stop()
|
||||
# import code
|
||||
# code.interact(local=dict(globals(), **locals()))
|
||||
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class CTSOTNet(nn.Module):
|
||||
def __init__(self, resolution, in_ch=6, inner_ch=64, ae_ch=256, out_ch=3, res_block_count = 12):
|
||||
super().__init__()
|
||||
self.in_conv = nn.Conv2d(in_ch, inner_ch, kernel_size=1, stride=1, padding=0)
|
||||
self.ae = AutoEncoder(resolution, inner_ch, ae_ch)
|
||||
self.out_conv = nn.Conv2d(inner_ch, out_ch, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, img1_t, img2_t):
|
||||
x = torch.cat([img1_t, img2_t], dim=1)
|
||||
|
||||
x = self.in_conv(x)
|
||||
|
||||
x = self.ae(x)
|
||||
x = self.out_conv(x)
|
||||
x = torch.tanh(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class CTSOT:
|
||||
def __init__(self, device_info : TorchDeviceInfo = None,
|
||||
state_dict : Union[dict, None] = None,
|
||||
training : bool = False):
|
||||
if device_info is None:
|
||||
device_info = get_cpu_device_info()
|
||||
self.device_info = device_info
|
||||
|
||||
self._net = net = CTSOTNet()
|
||||
|
||||
if state_dict is not None:
|
||||
net.load_state_dict(state_dict)
|
||||
|
||||
if training:
|
||||
net.train()
|
||||
else:
|
||||
net.eval()
|
||||
|
||||
self.set_device(device_info)
|
||||
|
||||
def set_device(self, device_info : TorchDeviceInfo = None):
|
||||
if device_info is None or device_info.is_cpu():
|
||||
self._net.cpu()
|
||||
else:
|
||||
self._net.cuda(device_info.get_index())
|
||||
|
||||
def get_state_dict(self):
|
||||
return self.net.state_dict()
|
||||
|
||||
def get_net(self) -> CTSOTNet:
|
||||
return self._net
|
|
@ -8,13 +8,13 @@ import torch.nn.functional as F
|
|||
from xlib import math as lib_math
|
||||
from xlib.file import SplittedFile
|
||||
from xlib.image import ImageProcessor
|
||||
from xlib.torch import TorchDeviceInfo, get_cpu_device
|
||||
from xlib.torch import TorchDeviceInfo, get_cpu_device_info
|
||||
|
||||
|
||||
class S3FD:
|
||||
def __init__(self, device_info : TorchDeviceInfo = None ):
|
||||
if device_info is None:
|
||||
device_info = get_cpu_device()
|
||||
device_info = get_cpu_device_info()
|
||||
self.device_info = device_info
|
||||
|
||||
path = Path(__file__).parent / 'S3FD.pth'
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
from .CenterFace.CenterFace import CenterFace, CenterFace_to_onnx
|
||||
from .CenterFace.CenterFace import CenterFace_to_onnx
|
||||
from .S3FD.S3FD import S3FD
|
||||
from .CTSOT.CTSOT import CTSOT, CTSOTNet
|
|
@ -55,12 +55,8 @@ def extract_FaceSynthetics(inputdir_path : Path, faceset_path : Path):
|
|||
raise ValueError('faceset_path must have .dfs extension.')
|
||||
|
||||
filepaths = lib_path.get_files_paths(inputdir_path)
|
||||
|
||||
fs = lib_face.Faceset(faceset_path)
|
||||
fs.recreate()
|
||||
|
||||
|
||||
for filepath in lib_con.progress_bar_iterator(filepaths, 'Processing'):
|
||||
fs = lib_face.Faceset(faceset_path, write_access=True, recreate=True)
|
||||
for filepath in lib_con.progress_bar_iterator(filepaths, desc='Processing'):
|
||||
if filepath.suffix == '.txt':
|
||||
|
||||
image_filepath = filepath.parent / f'{filepath.name.split("_")[0]}.png'
|
||||
|
@ -94,16 +90,12 @@ def extract_FaceSynthetics(inputdir_path : Path, faceset_path : Path):
|
|||
ufm.set_FRect(flmrks.get_FRect())
|
||||
ufm.add_FLandmarks2D(flmrks)
|
||||
|
||||
fs.add_UImage(uimg, format='png')
|
||||
fs.add_UFaceMark(ufm)
|
||||
fs.add_UImage(uimg, format='png')
|
||||
|
||||
|
||||
fs.shrink()
|
||||
fs.optimize()
|
||||
fs.close()
|
||||
|
||||
import code
|
||||
code.interact(local=dict(globals(), **locals()))
|
||||
|
||||
# seg_filepath = input_path / ( Path(image_filepath).stem + '_seg.png')
|
||||
# if not seg_filepath.exists():
|
||||
# raise ValueError(f'{seg_filepath} does not exist')
|
||||
|
|
|
@ -15,7 +15,7 @@ def gaussian_blur (input_t : Tensor, sigma, dtype=None) -> Tensor:
|
|||
|
||||
"""
|
||||
if sigma <= 0.0:
|
||||
return input_t.copy() #TODO
|
||||
return input_t.copy()
|
||||
|
||||
device = input_t.get_device()
|
||||
|
||||
|
|
|
@ -3,75 +3,116 @@ import threading
|
|||
import time
|
||||
from enum import IntEnum
|
||||
from typing import Any, Callable, List, Tuple, Union
|
||||
|
||||
from numbers import Number
|
||||
from ... import text as lib_text
|
||||
|
||||
class EDlgMode(IntEnum):
|
||||
UNDEFINED = 0
|
||||
UNHANDLED = 0
|
||||
BACK = 1
|
||||
RELOAD = 2
|
||||
WRONG_INPUT = 3
|
||||
SUCCESS = 4
|
||||
HANDLED = 4
|
||||
|
||||
|
||||
|
||||
class DlgChoice:
|
||||
def __init__(self, name : str = None, row_desc : str = None):
|
||||
super().__init__()
|
||||
def __init__(self, name : str = None,
|
||||
row_def : str = None,
|
||||
on_choose : Callable = None):
|
||||
if len(name) == 0:
|
||||
raise ValueError('Zero len name is not valid.')
|
||||
self._name = name
|
||||
self._row_desc = row_desc
|
||||
self._row_def = row_def
|
||||
self._on_choose = on_choose
|
||||
|
||||
def get_name(self) -> Union[str, None]: return self._name
|
||||
def get_row_desc(self) -> Union[str, None]: return self._row_desc
|
||||
def get_row_def(self) -> Union[str, None]: return self._row_def
|
||||
def get_on_choose(self) -> Callable: return self._on_choose
|
||||
|
||||
class Dlg:
|
||||
def __init__(self, title : str = None, has_go_back=True):
|
||||
def __init__(self, on_recreate : Callable[ [], 'Dlg'] = None,
|
||||
on_back : Callable = None,
|
||||
top_rows_def : Union[str, List[str]] = None,
|
||||
bottom_rows_def : Union[str, List[str]] = None,
|
||||
):
|
||||
"""
|
||||
base class for Diacon dialogs.
|
||||
"""
|
||||
self._on_recreate = on_recreate
|
||||
self._on_back = on_back
|
||||
|
||||
self._top_rows_def = top_rows_def
|
||||
self._bottom_rows_def = bottom_rows_def
|
||||
|
||||
def recreate(self):
|
||||
"""
|
||||
self._title = title
|
||||
self._has_go_back = has_go_back
|
||||
"""
|
||||
if self._on_recreate is not None:
|
||||
return self._on_recreate(self)
|
||||
else:
|
||||
raise Exception('on_recreate() is not defined.')
|
||||
|
||||
def get_name(self) -> str: return self._name
|
||||
|
||||
def handle_user_input(self, s : str) -> EDlgMode:
|
||||
"""
|
||||
def set_current(self, print=True):
|
||||
Diacon.update_dlg(self, print=print)
|
||||
|
||||
def handle_user_input(self, s : str):
|
||||
"""
|
||||
s = s.strip()
|
||||
"""
|
||||
mode = self.on_user_input(s.strip())
|
||||
|
||||
# ? and < available in any dialog, handle them first
|
||||
s_len = len(s)
|
||||
if s_len == 0:
|
||||
if mode == EDlgMode.UNHANDLED:
|
||||
mode = EDlgMode.RELOAD
|
||||
if mode == EDlgMode.WRONG_INPUT:
|
||||
print('\nWrong input')
|
||||
mode = EDlgMode.RELOAD
|
||||
|
||||
if mode == EDlgMode.RELOAD:
|
||||
self.recreate().set_current()
|
||||
if mode == EDlgMode.BACK:
|
||||
if self._on_back is not None:
|
||||
self._on_back(self)
|
||||
|
||||
#overridable
|
||||
def on_user_input(self, s : str) -> EDlgMode:
|
||||
if len(s) == 0:
|
||||
return EDlgMode.RELOAD
|
||||
if s_len == 1:
|
||||
#if s == '?':
|
||||
# return EDlgMode.RELOAD
|
||||
if self._on_back is not None and len(s) == 1:
|
||||
if s == '<':
|
||||
return EDlgMode.BACK
|
||||
|
||||
return self.on_user_input(s)
|
||||
return EDlgMode.UNHANDLED
|
||||
|
||||
def print(self, table_width_max=80, col_spacing = 3):
|
||||
"""
|
||||
print dialog
|
||||
"""
|
||||
|
||||
# Gather table lines
|
||||
table_def : List[str]= []
|
||||
|
||||
if self._has_go_back:
|
||||
trd = self._top_rows_def
|
||||
brd = self._bottom_rows_def
|
||||
if trd is not None:
|
||||
if not isinstance(trd, (list,tuple)):
|
||||
trd = [trd]
|
||||
table_def += trd
|
||||
|
||||
if self._on_back is not None:
|
||||
table_def.append('| < | Go back.')
|
||||
|
||||
table_def.append('|99')
|
||||
table_def = self.on_print(table_def)
|
||||
|
||||
if brd is not None:
|
||||
if not isinstance(brd, (list,tuple)):
|
||||
brd = [brd]
|
||||
table_def += brd
|
||||
|
||||
table = lib_text.ascii_table(table_def, max_table_width=80,
|
||||
left_border = None,
|
||||
right_border = None,
|
||||
left_border = '| ',
|
||||
right_border = ' |',
|
||||
border = ' | ',
|
||||
row_symbol = None)
|
||||
row_symbol = None,
|
||||
)
|
||||
print()
|
||||
print(table)
|
||||
|
||||
|
@ -80,36 +121,102 @@ class Dlg:
|
|||
def on_print(self, table_lines : List[Tuple[str,str]]):
|
||||
return table_lines
|
||||
|
||||
|
||||
|
||||
class DlgNumber(Dlg):
|
||||
def __init__(self, is_float : bool,
|
||||
current_value = None,
|
||||
min_value = None,
|
||||
max_value = None,
|
||||
clip_min_value = None,
|
||||
clip_max_value = None,
|
||||
on_value : Callable[ [Dlg, Number], None] = None,
|
||||
on_recreate : Callable[ [], 'Dlg'] = None,
|
||||
on_back : Callable = None,
|
||||
top_rows_def : Union[str, List[str]] = None,
|
||||
bottom_rows_def : Union[str, List[str]] = None, ):
|
||||
super().__init__(on_recreate=on_recreate, on_back=on_back, top_rows_def=top_rows_def, bottom_rows_def=bottom_rows_def)
|
||||
|
||||
if min_value is not None and max_value is not None and min_value > max_value:
|
||||
raise ValueError('min_value > max_value')
|
||||
if clip_min_value is not None and clip_max_value is not None and clip_min_value > clip_max_value:
|
||||
raise ValueError('clip_min_value > clip_max_value')
|
||||
|
||||
self._is_float = is_float
|
||||
self._current_value = current_value
|
||||
self._min_value = min_value
|
||||
self._max_value = max_value
|
||||
self._clip_min_value = clip_min_value
|
||||
self._clip_max_value = clip_max_value
|
||||
self._on_value = on_value
|
||||
|
||||
#overridable
|
||||
def on_user_input(self, s : str) -> EDlgMode:
|
||||
"""
|
||||
handle user input
|
||||
return False if input is invalid
|
||||
"""
|
||||
return EDlgMode.UNDEFINED
|
||||
def on_print(self, table_def : List[str]):
|
||||
|
||||
minv, maxv = self._min_value, self._max_value
|
||||
|
||||
if self._is_float:
|
||||
line = '| * | Enter float number'
|
||||
else:
|
||||
line = '| * | Enter integer number'
|
||||
|
||||
if minv is not None and maxv is None:
|
||||
line += f' in range: [{minv} ... )'
|
||||
elif minv is None and maxv is not None:
|
||||
line += f' in range: ( ... {maxv} ]'
|
||||
elif minv is not None and maxv is not None:
|
||||
line += f' in range: [{minv} ... {maxv} ]'
|
||||
|
||||
table_def.append(line)
|
||||
|
||||
return table_def
|
||||
|
||||
#overridable
|
||||
def on_user_input(self, s : str) -> bool:
|
||||
result = super().on_user_input(s)
|
||||
if result == EDlgMode.UNHANDLED:
|
||||
try:
|
||||
print(s)
|
||||
v = float(s) if self._is_float else int(s)
|
||||
|
||||
if self._min_value is not None:
|
||||
if v < self._min_value:
|
||||
return EDlgMode.WRONG_INPUT
|
||||
|
||||
if self._max_value is not None:
|
||||
if v > self._max_value:
|
||||
return EDlgMode.WRONG_INPUT
|
||||
|
||||
if self._clip_min_value is not None:
|
||||
if v < self._clip_min_value:
|
||||
v = self._clip_min_value
|
||||
|
||||
if self._clip_max_value is not None:
|
||||
if v > self._clip_max_value:
|
||||
v = self._clip_max_value
|
||||
|
||||
if self._on_value is not None:
|
||||
self._on_value(self, v)
|
||||
return EDlgMode.HANDLED
|
||||
except:
|
||||
return EDlgMode.WRONG_INPUT
|
||||
|
||||
return result
|
||||
|
||||
class DlgChoices(Dlg):
|
||||
def __init__(self, choices : List[DlgChoice], multiple_choices=False, title : str = None, has_go_back = True):
|
||||
"""
|
||||
|
||||
"""
|
||||
super().__init__(title=title, has_go_back=has_go_back)
|
||||
def __init__(self, choices : List[DlgChoice],
|
||||
on_multi_choice : Callable[ [ List[DlgChoice] ], None] = None,
|
||||
on_recreate : Callable[ [Dlg], Dlg] = None,
|
||||
on_back : Callable = None,
|
||||
top_rows_def : Union[str, List[str]] = None,
|
||||
bottom_rows_def : Union[str, List[str]] = None,
|
||||
):
|
||||
super().__init__(on_recreate=on_recreate, on_back=on_back, top_rows_def=top_rows_def, bottom_rows_def=bottom_rows_def)
|
||||
self._choices = choices
|
||||
self._multiple_choices = multiple_choices
|
||||
|
||||
self._results = None
|
||||
self._results_id = None
|
||||
self._on_multi_choice = on_multi_choice
|
||||
|
||||
self._short_names = [choice.get_name() for choice in choices]
|
||||
|
||||
# if any([x is not None for x in self._short_names]):
|
||||
# # Using short names from choices
|
||||
# if any([x is None for x in self._short_names]):
|
||||
# raise Exception('No short name for one of choices.')
|
||||
# if len(set(self._short_names)) != len(self._short_names):
|
||||
# raise ValueError(f'Contains duplicate short names: {self._short_names}')
|
||||
# else:
|
||||
|
||||
# Make short names for all choices
|
||||
names = [ choice.get_name() for choice in choices ]
|
||||
names_len = len(names)
|
||||
|
@ -139,27 +246,14 @@ class DlgChoices(Dlg):
|
|||
break
|
||||
self._short_names = short_names
|
||||
|
||||
|
||||
|
||||
def get_selected_choices(self) -> List[DlgChoice]:
|
||||
"""
|
||||
returns selected choices
|
||||
"""
|
||||
return self._results
|
||||
|
||||
def get_selected_choices_id(self) -> List[int]:
|
||||
"""
|
||||
returns selected choice
|
||||
"""
|
||||
return self._results_id
|
||||
|
||||
#overridable
|
||||
def on_print(self, table_def : List[str]):
|
||||
|
||||
for short_name, choice in zip(self._short_names, self._choices):
|
||||
row_def = f'| {short_name}'
|
||||
row_desc = choice.get_row_desc()
|
||||
if row_desc is not None:
|
||||
row_def += row_desc
|
||||
x = choice.get_row_def()
|
||||
if x is not None:
|
||||
row_def += x
|
||||
table_def.append(row_def)
|
||||
|
||||
return table_def
|
||||
|
@ -167,35 +261,36 @@ class DlgChoices(Dlg):
|
|||
#overridable
|
||||
def on_user_input(self, s : str) -> bool:
|
||||
result = super().on_user_input(s)
|
||||
if result == EDlgMode.UNDEFINED:
|
||||
if result == EDlgMode.UNHANDLED:
|
||||
|
||||
if self._multiple_choices:
|
||||
if self._on_multi_choice is not None:
|
||||
multi_s = s.split(',')
|
||||
else:
|
||||
multi_s = [s]
|
||||
|
||||
results = []
|
||||
results_id = []
|
||||
choices_id = []
|
||||
for s in multi_s:
|
||||
s = s.strip()
|
||||
|
||||
x = [ i for i,short_name in enumerate(self._short_names) if s == short_name ]
|
||||
x = [ i for i, short_name in enumerate(self._short_names) if s.strip() == short_name ]
|
||||
if len(x) == 0:
|
||||
# no short name match
|
||||
# No short name match
|
||||
return EDlgMode.WRONG_INPUT
|
||||
else:
|
||||
id = x[0]
|
||||
results_id.append(id)
|
||||
results.append(self._choices[id])
|
||||
choices_id.append(id)
|
||||
|
||||
if len(set(results_id)) != len(results_id):
|
||||
if len(set(choices_id)) != len(choices_id):
|
||||
# Duplicate input
|
||||
return EDlgMode.WRONG_INPUT
|
||||
|
||||
self._results = results
|
||||
self._results_id = results_id
|
||||
for id in choices_id:
|
||||
on_choose = self._choices[id].get_on_choose()
|
||||
if on_choose is not None:
|
||||
on_choose(self)
|
||||
|
||||
return EDlgMode.SUCCESS
|
||||
if self._on_multi_choice is not None:
|
||||
self._on_multi_choice(choices_id)
|
||||
|
||||
return EDlgMode.HANDLED
|
||||
|
||||
return result
|
||||
|
||||
|
@ -204,42 +299,9 @@ class DlgChoices(Dlg):
|
|||
class _Diacon:
|
||||
"""
|
||||
User dialog with via console.
|
||||
|
||||
Internal architecture:
|
||||
|
||||
[
|
||||
Main-Thread
|
||||
|
||||
current thread from which __init__() called
|
||||
]
|
||||
|
||||
[
|
||||
Dialog-Thread
|
||||
|
||||
separate thread where dialogs are handled and dynamically created
|
||||
|
||||
we need this thread, because main thread can be busy,
|
||||
for example training neural network
|
||||
|
||||
calls on_dlg() provided with __init__
|
||||
|
||||
thus keep in mind on_dlg() works in separate thread
|
||||
|
||||
This thread must not be blocked inside on_dlg(),
|
||||
because Diacon.stop() can be called that stops all threads.
|
||||
]
|
||||
|
||||
[
|
||||
Input-Thread
|
||||
|
||||
separate thread where user input is accepted in non-blocking mode,
|
||||
and transfered to processing thread
|
||||
]
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._on_dlg : Callable = None
|
||||
|
||||
self._lock = threading.RLock()
|
||||
self._current_dlg : Dlg = None
|
||||
self._new_dlg : Dlg = None
|
||||
|
@ -250,11 +312,10 @@ class _Diacon:
|
|||
self._input_request = False
|
||||
self._input_result : str = None
|
||||
|
||||
def start(self, on_dlg : Callable):
|
||||
def start(self):
|
||||
if self._started:
|
||||
raise Exception('Diacon already started.')
|
||||
self._started = True
|
||||
self._on_dlg = on_dlg
|
||||
|
||||
self._input_t = threading.Thread(target=self._input_thread, daemon=True)
|
||||
self._input_t.start()
|
||||
|
@ -285,8 +346,6 @@ class _Diacon:
|
|||
|
||||
|
||||
def _dialog_thread(self, ):
|
||||
self._on_dlg(None, EDlgMode.RELOAD)
|
||||
|
||||
while self._started:
|
||||
|
||||
with self._lock:
|
||||
|
@ -303,13 +362,7 @@ class _Diacon:
|
|||
if input_result is not None:
|
||||
|
||||
if self._current_dlg is not None:
|
||||
mode = self._current_dlg.handle_user_input(input_result)
|
||||
if mode == EDlgMode.WRONG_INPUT:
|
||||
print('\nWrong input')
|
||||
mode = EDlgMode.RELOAD
|
||||
if mode == EDlgMode.UNDEFINED:
|
||||
mode = EDlgMode.RELOAD
|
||||
self._on_dlg(self._current_dlg, mode)
|
||||
self._current_dlg.handle_user_input(input_result)
|
||||
continue
|
||||
|
||||
time.sleep(0.005)
|
||||
|
@ -332,6 +385,9 @@ class _Diacon:
|
|||
show current or set new Dialog
|
||||
Can be called from any thread.
|
||||
"""
|
||||
if not self._started:
|
||||
self.start()
|
||||
|
||||
self._new_dlg = (new_dlg, print)
|
||||
|
||||
Diacon = _Diacon()
|
||||
|
|
|
@ -1 +1 @@
|
|||
from .Diacon import Diacon, Dlg, DlgChoice, DlgChoices, EDlgMode
|
||||
from .Diacon import Diacon, Dlg, DlgChoice, DlgChoices, EDlgMode, DlgNumber
|
||||
|
|
|
@ -9,7 +9,7 @@ class FMask:
|
|||
def __init__(self, _from_pickled=False):
|
||||
"""
|
||||
"""
|
||||
self._uuid : Union[bytes, None] = uuid.uuid4().bytes_le if not _from_pickled else None
|
||||
self._uuid : Union[bytes, None] = uuid.uuid4().bytes if not _from_pickled else None
|
||||
self._mask_type : Union[FMask.Type, None] = None
|
||||
self._FImage_uuid : Union[bytes, None] = None
|
||||
|
||||
|
|
|
@ -52,9 +52,25 @@ class FaceWarper:
|
|||
self._rw_grid_tx = rnd_state.uniform(*rw_grid_tx) if isinstance(rw_grid_tx, Iterable) else rw_grid_tx
|
||||
self._rw_grid_ty = rnd_state.uniform(*rw_grid_ty) if isinstance(rw_grid_ty, Iterable) else rw_grid_ty
|
||||
|
||||
self._warp_rnd_mat = Affine2DUniMat.from_transformation(0.5, 0.5, self._rw_grid_rot_deg, 1.0+self._rw_grid_scale, self._rw_grid_tx, self._rw_grid_ty)
|
||||
self._align_rnd_mat = Affine2DUniMat.from_transformation(0.5, 0.5, self._align_rot_deg, 1.0+self._align_scale, self._align_tx, self._align_ty)
|
||||
|
||||
self._rnd_state_state = rnd_state.get_state()
|
||||
self._cached = {}
|
||||
|
||||
def get_aligned_random_transform_mat(self) -> Affine2DUniMat:
|
||||
"""
|
||||
returns Affine2DUniMat that represents transformation from aligned face to randomly transformed aligned face
|
||||
"""
|
||||
mat1 = self._img_to_face_uni_mat
|
||||
mat2 = (self._face_to_img_uni_mat * self._align_rnd_mat).invert()
|
||||
|
||||
pts = [ [0,0], [1,0], [1,1]]
|
||||
src_pts = mat1.transform_points(pts)
|
||||
dst_pts = mat2.transform_points(pts)
|
||||
|
||||
return Affine2DUniMat.from_3_pairs(src_pts, dst_pts)
|
||||
|
||||
def transform(self, img : np.ndarray, out_res : int, random_warp : bool = True) -> np.ndarray:
|
||||
"""
|
||||
transform an image.
|
||||
|
@ -91,9 +107,7 @@ class FaceWarper:
|
|||
face_warp_grid = FaceWarper._gen_random_warp_uni_grid_diff(out_res, self._rw_grid_cell_count, 0.12, rnd_state)
|
||||
|
||||
# make a randomly transformable mat to transform face_warp_grid from face to image
|
||||
face_warp_grid_mat = (self._face_to_img_uni_mat *
|
||||
Affine2DUniMat.from_transformation(0.5, 0.5, self._rw_grid_rot_deg, 1.0+self._rw_grid_scale, self._rw_grid_tx, self._rw_grid_ty)
|
||||
)
|
||||
face_warp_grid_mat = self._face_to_img_uni_mat * self._warp_rnd_mat
|
||||
|
||||
# warp face_warp_grid to the space of image and merge with image_grid
|
||||
image_grid += cv2.warpAffine(face_warp_grid, face_warp_grid_mat.to_exact_mat(out_res,out_res, W, H), (W,H), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
|
||||
|
@ -101,9 +115,10 @@ class FaceWarper:
|
|||
# scale uniform grid from to image size
|
||||
image_grid *= (H-1, W-1)
|
||||
|
||||
# apply random transormations for align mat
|
||||
img_to_face_rnd_mat = (self._face_to_img_uni_mat * Affine2DMat.from_transformation(0.5, 0.5, self._align_rot_deg, 1.0+self._align_scale, self._align_tx, self._align_ty)
|
||||
).invert().to_exact_mat(W,H,out_res,out_res)
|
||||
# apply random transformations for align mat
|
||||
#img_to_face_rnd_uni_mat = (self._face_to_img_uni_mat * self._align_rnd_mat).invert()
|
||||
|
||||
img_to_face_rnd_mat = (self._face_to_img_uni_mat * self._align_rnd_mat).invert().to_exact_mat(W,H,out_res,out_res)
|
||||
|
||||
# warp image_grid to face space
|
||||
image_grid = cv2.warpAffine(image_grid, img_to_face_rnd_mat, (out_res,out_res), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE )
|
||||
|
|
|
@ -1,19 +1,22 @@
|
|||
import pickle
|
||||
import sqlite3
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Generator, List, Union, Iterable
|
||||
from typing import Generator, Iterable, List, Union
|
||||
|
||||
import cv2
|
||||
import h5py
|
||||
import numpy as np
|
||||
|
||||
from .. import console as lib_con
|
||||
from .FMask import FMask
|
||||
from .UFaceMark import UFaceMark
|
||||
from .UImage import UImage
|
||||
from .UPerson import UPerson
|
||||
|
||||
|
||||
class Faceset:
|
||||
|
||||
def __init__(self, path = None):
|
||||
def __init__(self, path = None, write_access=False, recreate=False):
|
||||
"""
|
||||
Faceset is a class to store and manage face related data.
|
||||
|
||||
|
@ -21,205 +24,155 @@ class Faceset:
|
|||
|
||||
path path to faceset .dfs file
|
||||
|
||||
write_access
|
||||
|
||||
recreate
|
||||
|
||||
Can be pickled.
|
||||
"""
|
||||
self._f = None
|
||||
|
||||
self._path = path = Path(path)
|
||||
if path.suffix != '.dfs':
|
||||
raise ValueError('Path must be a .dfs file')
|
||||
|
||||
self._conn = conn = sqlite3.connect(path, isolation_level=None)
|
||||
self._cur = cur = conn.cursor()
|
||||
if path.exists():
|
||||
if write_access and recreate:
|
||||
path.unlink()
|
||||
elif not write_access:
|
||||
raise FileNotFoundError(f'File {path} not found.')
|
||||
|
||||
cur = self._get_cursor()
|
||||
cur.execute('BEGIN IMMEDIATE')
|
||||
if not self._is_table_exists('FacesetInfo'):
|
||||
self.recreate(shrink=False, _transaction=False)
|
||||
cur.execute('COMMIT')
|
||||
self.shrink()
|
||||
else:
|
||||
cur.execute('END')
|
||||
self._mode = 'a' if write_access else 'r'
|
||||
self._open()
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
def __getstate__(self):
|
||||
return {'_path' : self._path}
|
||||
return {'_path' : self._path, '_mode' : self._mode}
|
||||
|
||||
def __setstate__(self, d):
|
||||
self.__init__( d['_path'] )
|
||||
self._f = None
|
||||
self._path = d['_path']
|
||||
self._mode = d['_mode']
|
||||
self._open()
|
||||
|
||||
def __repr__(self): return self.__str__()
|
||||
def __str__(self):
|
||||
return f"Faceset. UImage:{self.get_UImage_count()} UFaceMark:{self.get_UFaceMark_count()} UPerson:{self.get_UPerson_count()}"
|
||||
|
||||
def _is_table_exists(self, name):
|
||||
return self._cur.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", [name]).fetchone()[0] != 0
|
||||
def _open(self):
|
||||
if self._f is None:
|
||||
self._f = f = h5py.File(self._path, mode=self._mode)
|
||||
self._UFaceMark_grp = f.require_group('UFaceMark')
|
||||
self._UImage_grp = f.require_group('UImage')
|
||||
self._UImage_image_data_grp = f.require_group('UImage_image_data')
|
||||
self._UPerson_grp = f.require_group('UPerson')
|
||||
|
||||
def _get_cursor(self) -> sqlite3.Cursor: return self._cur
|
||||
|
||||
def close(self):
|
||||
if self._cur is not None:
|
||||
self._cur.close()
|
||||
self._cur = None
|
||||
if self._f is not None:
|
||||
self._f.close()
|
||||
self._f = None
|
||||
|
||||
if self._conn is not None:
|
||||
self._conn.close()
|
||||
self._conn = None
|
||||
|
||||
def shrink(self):
|
||||
self._cur.execute('VACUUM')
|
||||
|
||||
def recreate(self, shrink=True, _transaction=True):
|
||||
def optimize(self, verbose=True):
|
||||
"""
|
||||
delete all data and recreate Faceset structure.
|
||||
recreate Faceset with optimized structure.
|
||||
"""
|
||||
cur = self._get_cursor()
|
||||
if verbose:
|
||||
print(f'Optimizing {self._path.name}...')
|
||||
|
||||
if _transaction:
|
||||
cur.execute('BEGIN IMMEDIATE')
|
||||
tmp_path = self._path.parent / (self._path.stem + '_optimizing' + self._path.suffix)
|
||||
|
||||
for table_name, in cur.execute("SELECT name from sqlite_master where type = 'table';").fetchall():
|
||||
cur.execute(f'DROP TABLE {table_name}')
|
||||
tmp_fs = Faceset(tmp_path, write_access=True, recreate=True)
|
||||
self._group_copy(tmp_fs._UFaceMark_grp, self._UFaceMark_grp, verbose=verbose)
|
||||
self._group_copy(tmp_fs._UPerson_grp, self._UPerson_grp, verbose=verbose)
|
||||
self._group_copy(tmp_fs._UImage_grp, self._UImage_grp, verbose=verbose)
|
||||
self._group_copy(tmp_fs._UImage_image_data_grp, self._UImage_image_data_grp, verbose=verbose)
|
||||
tmp_fs.close()
|
||||
|
||||
(cur.execute('CREATE TABLE FacesetInfo (version INT)')
|
||||
.execute('INSERT INTO FacesetInfo VALUES (1)')
|
||||
self.close()
|
||||
self._path.unlink()
|
||||
tmp_path.rename(self._path)
|
||||
self._open()
|
||||
|
||||
.execute('CREATE TABLE UImage (uuid BLOB, name TEXT, format TEXT, data BLOB)')
|
||||
.execute('CREATE TABLE UPerson (uuid BLOB, data BLOB)')
|
||||
.execute('CREATE TABLE UFaceMark (uuid BLOB, UImage_uuid BLOB, UPerson_uuid BLOB, data BLOB)')
|
||||
)
|
||||
def _group_copy(self, group_dst : h5py.Group, group_src : h5py.Group, verbose=True):
|
||||
for key, value in lib_con.progress_bar_iterator(group_src.items(), desc=f'Copying {group_src.name} -> {group_dst.name}', suppress_print=not verbose):
|
||||
d = group_dst.create_dataset(key, shape=value.shape, dtype=value.dtype )
|
||||
d[:] = value[:]
|
||||
for a_key, a_value in value.attrs.items():
|
||||
d.attrs[a_key] = a_value
|
||||
|
||||
if _transaction:
|
||||
cur.execute('COMMIT')
|
||||
def _group_read_bytes(self, group : h5py.Group, key : str, check_key=True) -> Union[bytes, None]:
|
||||
if check_key and key not in group:
|
||||
return None
|
||||
dataset = group[key]
|
||||
data_bytes = bytearray(len(dataset))
|
||||
dataset.read_direct(np.frombuffer(data_bytes, dtype=np.uint8))
|
||||
return data_bytes
|
||||
|
||||
if shrink:
|
||||
self.shrink()
|
||||
def _group_write_bytes(self, group : h5py.Group, key : str, data : bytes, update_existing=True) -> Union[h5py.Dataset, None]:
|
||||
if key in group:
|
||||
if not update_existing:
|
||||
return None
|
||||
del group[key]
|
||||
|
||||
return group.create_dataset(key, data=np.frombuffer(data, dtype=np.uint8) )
|
||||
|
||||
###################
|
||||
### UFaceMark
|
||||
###################
|
||||
def _UFaceMark_from_db_row(self, db_row) -> UFaceMark:
|
||||
uuid, UImage_uuid, UPerson_uuid, data = db_row
|
||||
|
||||
ufm = UFaceMark()
|
||||
ufm.restore_state(pickle.loads(data))
|
||||
return ufm
|
||||
|
||||
def add_UFaceMark(self, ufacemark_or_list : UFaceMark):
|
||||
def add_UFaceMark(self, ufacemark_or_list : UFaceMark, update_existing=True):
|
||||
"""
|
||||
add or update UFaceMark in DB
|
||||
"""
|
||||
if not isinstance(ufacemark_or_list, Iterable):
|
||||
ufacemark_or_list : List[UFaceMark] = [ufacemark_or_list]
|
||||
|
||||
cur = self._cur
|
||||
cur.execute('BEGIN IMMEDIATE')
|
||||
for ufm in ufacemark_or_list:
|
||||
uuid = ufm.get_uuid()
|
||||
UImage_uuid = ufm.get_UImage_uuid()
|
||||
UPerson_uuid = ufm.get_UPerson_uuid()
|
||||
|
||||
data = pickle.dumps(ufm.dump_state())
|
||||
|
||||
if cur.execute('SELECT COUNT(*) from UFaceMark where uuid=?', [uuid] ).fetchone()[0] != 0:
|
||||
cur.execute('UPDATE UFaceMark SET UImage_uuid=?, UPerson_uuid=?, data=? WHERE uuid=?',
|
||||
[UImage_uuid, UPerson_uuid, data, uuid])
|
||||
else:
|
||||
cur.execute('INSERT INTO UFaceMark VALUES (?, ?, ?, ?)', [uuid, UImage_uuid, UPerson_uuid, data])
|
||||
cur.execute('COMMIT')
|
||||
self._group_write_bytes(self._UFaceMark_grp, ufm.get_uuid().hex(), pickle.dumps(ufm.dump_state()), update_existing=update_existing )
|
||||
|
||||
def get_UFaceMark_count(self) -> int:
|
||||
return self._cur.execute('SELECT COUNT(*) FROM UFaceMark').fetchone()[0]
|
||||
return len(self._UFaceMark_grp.keys())
|
||||
|
||||
def get_all_UFaceMark(self) -> List[UFaceMark]:
|
||||
return [ self._UFaceMark_from_db_row(db_row) for db_row in self._cur.execute('SELECT * FROM UFaceMark').fetchall() ]
|
||||
return [ UFaceMark.from_state(pickle.loads(self._group_read_bytes(self._UFaceMark_grp, key, check_key=False))) for key in self._UFaceMark_grp.keys() ]
|
||||
|
||||
def get_all_UFaceMark_uuids(self) -> List[bytes]:
|
||||
return [ uuid.UUID(key).bytes for key in self._UFaceMark_grp.keys() ]
|
||||
|
||||
def get_UFaceMark_by_uuid(self, uuid : bytes) -> Union[UFaceMark, None]:
|
||||
c = self._cur.execute('SELECT * FROM UFaceMark WHERE uuid=?', [uuid])
|
||||
db_row = c.fetchone()
|
||||
if db_row is None:
|
||||
data = self._group_read_bytes(self._UFaceMark_grp, uuid.hex())
|
||||
if data is None:
|
||||
return None
|
||||
return UFaceMark.from_state(pickle.loads(data))
|
||||
|
||||
def delete_UFaceMark_by_uuid(self, uuid : bytes) -> bool:
|
||||
key = uuid.hex()
|
||||
if key in self._UFaceMark_grp:
|
||||
del self._UFaceMark_grp[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
return self._UFaceMark_from_db_row(db_row)
|
||||
|
||||
def iter_UFaceMark(self) -> Generator[UFaceMark, None, None]:
|
||||
"""
|
||||
returns Generator of UFaceMark
|
||||
"""
|
||||
for db_row in self._cur.execute('SELECT * FROM UFaceMark').fetchall():
|
||||
yield self._UFaceMark_from_db_row(db_row)
|
||||
for key in self._UFaceMark_grp.keys():
|
||||
yield UFaceMark.from_state(pickle.loads(self._group_read_bytes(self._UFaceMark_grp, key, check_key=False)))
|
||||
|
||||
def delete_all_UFaceMark(self):
|
||||
"""
|
||||
deletes all UFaceMark from DB
|
||||
"""
|
||||
(self._cur.execute('BEGIN IMMEDIATE')
|
||||
.execute('DELETE FROM UFaceMark')
|
||||
.execute('COMMIT') )
|
||||
###################
|
||||
### UPerson
|
||||
###################
|
||||
def _UPerson_from_db_row(self, db_row) -> UPerson:
|
||||
uuid, data = db_row
|
||||
up = UPerson()
|
||||
up.restore_state(pickle.loads(data))
|
||||
return up
|
||||
|
||||
def add_UPerson(self, uperson_or_list : UPerson):
|
||||
"""
|
||||
add or update UPerson in DB
|
||||
"""
|
||||
if not isinstance(uperson_or_list, Iterable):
|
||||
uperson_or_list : List[UPerson] = [uperson_or_list]
|
||||
|
||||
cur = self._cur
|
||||
cur.execute('BEGIN IMMEDIATE')
|
||||
for uperson in uperson_or_list:
|
||||
uuid = uperson.get_uuid()
|
||||
|
||||
data = pickle.dumps(uperson.dump_state())
|
||||
|
||||
if cur.execute('SELECT COUNT(*) from UPerson where uuid=?', [uuid]).fetchone()[0] != 0:
|
||||
cur.execute('UPDATE UPerson SET data=? WHERE uuid=?', [data])
|
||||
else:
|
||||
cur.execute('INSERT INTO UPerson VALUES (?, ?)', [uuid, data])
|
||||
cur.execute('COMMIT')
|
||||
|
||||
def get_UPerson_count(self) -> int:
|
||||
return self._cur.execute('SELECT COUNT(*) FROM UPerson').fetchone()[0]
|
||||
|
||||
def get_all_UPerson(self) -> List[UPerson]:
|
||||
return [ self._UPerson_from_db_row(db_row) for db_row in self._cur.execute('SELECT * FROM UPerson').fetchall() ]
|
||||
|
||||
def iter_UPerson(self) -> Generator[UPerson, None, None]:
|
||||
"""
|
||||
iterator of all UPerson's
|
||||
"""
|
||||
for db_row in self._cur.execute('SELECT * FROM UPerson').fetchall():
|
||||
yield self._UPerson_from_db_row(db_row)
|
||||
|
||||
def delete_all_UPerson(self):
|
||||
"""
|
||||
deletes all UPerson from DB
|
||||
"""
|
||||
(self._cur.execute('BEGIN IMMEDIATE')
|
||||
.execute('DELETE FROM UPerson')
|
||||
.execute('COMMIT') )
|
||||
for key in self._UFaceMark_grp.keys():
|
||||
del self._UFaceMark_grp[key]
|
||||
|
||||
###################
|
||||
### UImage
|
||||
###################
|
||||
def _UImage_from_db_row(self, db_row) -> UImage:
|
||||
uuid, name, format, data_bytes = db_row
|
||||
img = cv2.imdecode(np.frombuffer(data_bytes, dtype=np.uint8), flags=cv2.IMREAD_UNCHANGED)
|
||||
|
||||
uimg = UImage()
|
||||
uimg.set_uuid(uuid)
|
||||
uimg.set_name(name)
|
||||
uimg.assign_image(img)
|
||||
return uimg
|
||||
|
||||
def add_UImage(self, uimage_or_list : UImage, format : str = 'webp', quality : int = 100):
|
||||
def add_UImage(self, uimage_or_list : UImage, format : str = 'png', quality : int = 100, update_existing=True):
|
||||
"""
|
||||
add or update UImage in DB
|
||||
|
||||
|
@ -239,9 +192,8 @@ class Faceset:
|
|||
raise ValueError('quality must be in range [0..100]')
|
||||
|
||||
if not isinstance(uimage_or_list, Iterable):
|
||||
uimage_or_list = [uimage_or_list]
|
||||
uimage_or_list : List[UImage] = [uimage_or_list]
|
||||
|
||||
uimage_datas = []
|
||||
for uimage in uimage_or_list:
|
||||
if format == 'webp':
|
||||
imencode_args = [int(cv2.IMWRITE_WEBP_QUALITY), quality]
|
||||
|
@ -251,44 +203,112 @@ class Faceset:
|
|||
imencode_args = [int(cv2.IMWRITE_JPEG2000_COMPRESSION_X1000), quality*10]
|
||||
else:
|
||||
imencode_args = []
|
||||
|
||||
ret, data_bytes = cv2.imencode( f'.{format}', uimage.get_image(), imencode_args)
|
||||
if not ret:
|
||||
raise Exception(f'Unable to encode image format {format}')
|
||||
uimage_datas.append(data_bytes.data)
|
||||
|
||||
cur = self._cur
|
||||
cur.execute('BEGIN IMMEDIATE')
|
||||
for uimage, data in zip(uimage_or_list, uimage_datas):
|
||||
uuid = uimage.get_uuid()
|
||||
if cur.execute('SELECT COUNT(*) from UImage where uuid=?', [uuid] ).fetchone()[0] != 0:
|
||||
cur.execute('UPDATE UImage SET name=?, format=?, data=? WHERE uuid=?', [uimage.get_name(), format, data, uuid])
|
||||
else:
|
||||
cur.execute('INSERT INTO UImage VALUES (?, ?, ?, ?)', [uuid, uimage.get_name(), format, data])
|
||||
cur.execute('COMMIT')
|
||||
key = uimage.get_uuid().hex()
|
||||
|
||||
def get_UImage_count(self) -> int: return self._cur.execute('SELECT COUNT(*) FROM UImage').fetchone()[0]
|
||||
def get_UImage_by_uuid(self, uuid : Union[bytes, None]) -> Union[UImage, None]:
|
||||
"""
|
||||
"""
|
||||
if uuid is None:
|
||||
self._group_write_bytes(self._UImage_grp, key, pickle.dumps(uimage.dump_state(exclude_image=True)), update_existing=update_existing )
|
||||
d = self._group_write_bytes(self._UImage_image_data_grp, key, data_bytes.data, update_existing=update_existing )
|
||||
d.attrs['format'] = format
|
||||
d.attrs['quality'] = quality
|
||||
|
||||
def get_UImage_count(self) -> int:
|
||||
return len(self._UImage_grp.keys())
|
||||
|
||||
def get_all_UImage(self) -> List[UImage]:
|
||||
return [ self._get_UImage_by_key(key) for key in self._UImage_grp.keys() ]
|
||||
|
||||
def get_all_UImage_uuids(self) -> List[bytes]:
|
||||
return [ uuid.UUID(key).bytes for key in self._UImage_grp.keys() ]
|
||||
|
||||
def _get_UImage_by_key(self, key, check_key=True) -> Union[UImage, None]:
|
||||
data = self._group_read_bytes(self._UImage_grp, key, check_key=check_key)
|
||||
if data is None:
|
||||
return None
|
||||
uimg = UImage.from_state(pickle.loads(data))
|
||||
|
||||
db_row = self._cur.execute('SELECT * FROM UImage where uuid=?', [uuid]).fetchone()
|
||||
if db_row is None:
|
||||
return None
|
||||
return self._UImage_from_db_row(db_row)
|
||||
image_data = self._group_read_bytes(self._UImage_image_data_grp, key, check_key=check_key)
|
||||
if image_data is not None:
|
||||
uimg.assign_image (cv2.imdecode(np.frombuffer(image_data, dtype=np.uint8), flags=cv2.IMREAD_UNCHANGED))
|
||||
|
||||
def iter_UImage(self) -> Generator[UImage, None, None]:
|
||||
return uimg
|
||||
|
||||
def get_UImage_by_uuid(self, uuid : bytes) -> Union[UImage, None]:
|
||||
return self._get_UImage_by_key(uuid.hex())
|
||||
|
||||
def delete_UImage_by_uuid(self, uuid : bytes):
|
||||
key = uuid.hex()
|
||||
if key in self._UImage_grp:
|
||||
del self._UImage_grp[key]
|
||||
if key in self._UImage_image_data_grp:
|
||||
del self._UImage_image_data_grp[key]
|
||||
|
||||
def iter_UImage(self, include_key=False) -> Generator[UImage, None, None]:
|
||||
"""
|
||||
iterator of all UImage's
|
||||
returns Generator of UImage
|
||||
"""
|
||||
for db_row in self._cur.execute('SELECT * FROM UImage').fetchall():
|
||||
yield self._UImage_from_db_row(db_row)
|
||||
for key in self._UImage_grp.keys():
|
||||
uimg = self._get_UImage_by_key(key, check_key=False)
|
||||
yield (uimg, key) if include_key else uimg
|
||||
|
||||
def delete_all_UImage(self):
|
||||
"""
|
||||
deletes all UImage from DB
|
||||
"""
|
||||
(self._cur.execute('BEGIN IMMEDIATE')
|
||||
.execute('DELETE FROM UImage')
|
||||
.execute('COMMIT') )
|
||||
for key in self._UImage_grp.keys():
|
||||
del self._UImage_grp[key]
|
||||
for key in self._UImage_image_data_grp.keys():
|
||||
del self._UImage_image_data_grp[key]
|
||||
|
||||
###################
|
||||
### UPerson
|
||||
###################
|
||||
|
||||
def add_UPerson(self, uperson_or_list : UPerson, update_existing=True):
|
||||
"""
|
||||
add or update UPerson in DB
|
||||
"""
|
||||
if not isinstance(uperson_or_list, Iterable):
|
||||
uperson_or_list : List[UPerson] = [uperson_or_list]
|
||||
|
||||
for uperson in uperson_or_list:
|
||||
self._group_write_bytes(self._UPerson_grp, uperson.get_uuid().hex(), pickle.dumps(uperson.dump_state()), update_existing=update_existing )
|
||||
|
||||
def get_UPerson_count(self) -> int:
|
||||
return len(self._UPerson_grp.keys())
|
||||
|
||||
def get_all_UPerson(self) -> List[UPerson]:
|
||||
return [ UPerson.from_state(pickle.loads(self._group_read_bytes(self._UPerson_grp, key, check_key=False))) for key in self._UPerson_grp.keys() ]
|
||||
|
||||
def get_all_UPerson_uuids(self) -> List[bytes]:
|
||||
return [ uuid.UUID(key).bytes for key in self._UPerson_grp.keys() ]
|
||||
|
||||
def get_UPerson_by_uuid(self, uuid : bytes) -> Union[UPerson, None]:
|
||||
data = self._group_read_bytes(self._UPerson_grp, uuid.hex())
|
||||
if data is None:
|
||||
return None
|
||||
return UPerson.from_state(pickle.loads(data))
|
||||
|
||||
def delete_UPerson_by_uuid(self, uuid : bytes) -> bool:
|
||||
key = uuid.hex()
|
||||
if key in self._UPerson_grp:
|
||||
del self._UPerson_grp[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
def iter_UPerson(self) -> Generator[UPerson, None, None]:
|
||||
"""
|
||||
returns Generator of UPerson
|
||||
"""
|
||||
for key in self._UPerson_grp.keys():
|
||||
yield UPerson.from_state(pickle.loads(self._group_read_bytes(self._UPerson_grp, key, check_key=False)))
|
||||
|
||||
def delete_all_UPerson(self):
|
||||
"""
|
||||
deletes all UPerson from DB
|
||||
"""
|
||||
for key in self._UPerson_grp.keys():
|
||||
del self._UPerson_grp[key]
|
||||
|
|
|
@ -26,6 +26,12 @@ class UFaceMark(IState):
|
|||
def __str__(self):
|
||||
return f"UFaceMark UUID:[...{self.get_uuid()[-4:].hex()}]"
|
||||
|
||||
@staticmethod
|
||||
def from_state(state : dict) -> 'UFaceMark':
|
||||
ufm = UFaceMark()
|
||||
ufm.restore_state(state)
|
||||
return ufm
|
||||
|
||||
def restore_state(self, state : dict):
|
||||
self._uuid = state.get('_uuid', None)
|
||||
self._UImage_uuid = state.get('_UImage_uuid', None)
|
||||
|
@ -45,7 +51,7 @@ class UFaceMark(IState):
|
|||
|
||||
def get_uuid(self) -> Union[bytes, None]:
|
||||
if self._uuid is None:
|
||||
self._uuid = uuid.uuid4().bytes_le
|
||||
self._uuid = uuid.uuid4().bytes
|
||||
return self._uuid
|
||||
|
||||
def set_uuid(self, uuid : Union[bytes, None]):
|
||||
|
@ -72,6 +78,16 @@ class UFaceMark(IState):
|
|||
self._FRect = face_urect
|
||||
|
||||
def get_all_FLandmarks2D(self) -> List[FLandmarks2D]: return self._FLandmarks2D_list
|
||||
|
||||
def get_FLandmarks2D_best(self) -> Union[FLandmarks2D, None]:
|
||||
"""get best available FLandmarks2D """
|
||||
lmrks = self.get_FLandmarks2D_by_type(ELandmarks2D.L468)
|
||||
if lmrks is None:
|
||||
lmrks = self.get_FLandmarks2D_by_type(ELandmarks2D.L68)
|
||||
if lmrks is None:
|
||||
lmrks = self.get_FLandmarks2D_by_type(ELandmarks2D.L5)
|
||||
return lmrks
|
||||
|
||||
def get_FLandmarks2D_by_type(self, type : ELandmarks2D) -> Union[FLandmarks2D, None]:
|
||||
"""get FLandmarks2D from list by type"""
|
||||
if not isinstance(type, ELandmarks2D):
|
||||
|
|
|
@ -19,9 +19,30 @@ class UImage(IState):
|
|||
def __str__(self): return f"UImage UUID:[...{self.get_uuid()[-4:].hex()}] name:[{self._name}] image:[{ (self._image.shape, self._image.dtype) if self._image is not None else None}]"
|
||||
def __repr__(self): return self.__str__()
|
||||
|
||||
@staticmethod
|
||||
def from_state(state : dict) -> 'UImage':
|
||||
ufm = UImage()
|
||||
ufm.restore_state(state)
|
||||
return ufm
|
||||
|
||||
def restore_state(self, state : dict):
|
||||
self._uuid = state.get('_uuid', None)
|
||||
self._name = state.get('_name', None)
|
||||
self._image = state.get('_image', None)
|
||||
|
||||
def dump_state(self, exclude_image=False) -> dict:
|
||||
d = {'_uuid' : self._uuid,
|
||||
'_name' : self._name,
|
||||
}
|
||||
|
||||
if not exclude_image:
|
||||
d['_image'] = self._image
|
||||
|
||||
return d
|
||||
|
||||
def get_uuid(self) -> Union[bytes, None]:
|
||||
if self._uuid is None:
|
||||
self._uuid = uuid.uuid4().bytes_le
|
||||
self._uuid = uuid.uuid4().bytes
|
||||
return self._uuid
|
||||
|
||||
def set_uuid(self, uuid : Union[bytes, None]):
|
||||
|
|
|
@ -15,6 +15,12 @@ class UPerson(IState):
|
|||
def __str__(self): return f"UPerson UUID:[...{self._uuid[-4:].hex()}] name:[{self._name}] age:[{self._age}]"
|
||||
def __repr__(self): return self.__str__()
|
||||
|
||||
@staticmethod
|
||||
def from_state(state : dict) -> 'UPerson':
|
||||
ufm = UPerson()
|
||||
ufm.restore_state(state)
|
||||
return ufm
|
||||
|
||||
def restore_state(self, state : dict):
|
||||
self._uuid = state.get('_uuid', None)
|
||||
self._name = state.get('_name', None)
|
||||
|
@ -28,7 +34,7 @@ class UPerson(IState):
|
|||
|
||||
def get_uuid(self) -> Union[bytes, None]:
|
||||
if self._uuid is None:
|
||||
self._uuid = uuid.uuid4().bytes_le
|
||||
self._uuid = uuid.uuid4().bytes
|
||||
return self._uuid
|
||||
|
||||
def set_uuid(self, uuid : Union[bytes, None]):
|
||||
|
|
|
@ -2,12 +2,14 @@ import cv2
|
|||
import numpy as np
|
||||
import numpy.linalg as npla
|
||||
|
||||
def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0):
|
||||
def sot(src,trg, mask=None, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0, return_diff=False):
|
||||
"""
|
||||
Color Transform via Sliced Optimal Transfer, ported from https://github.com/dcoeurjo/OTColorTransfer
|
||||
|
||||
src - any float range any channel image
|
||||
dst - any float range any channel image, same shape as src
|
||||
mask -
|
||||
|
||||
steps - number of solver steps
|
||||
batch_size - solver batch size
|
||||
reg_sigmaXY - apply regularization and sigmaXY of filter, otherwise set to 0.0
|
||||
|
@ -39,8 +41,8 @@ def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0):
|
|||
dir = np.random.normal(size=c).astype(src_dtype)
|
||||
dir /= npla.norm(dir)
|
||||
|
||||
projsource = np.sum( new_src*dir, axis=-1).reshape ((h*w))
|
||||
projtarget = np.sum( trg*dir, axis=-1).reshape ((h*w))
|
||||
projsource = np.sum( new_src*dir*mask, axis=-1).reshape ((h*w))
|
||||
projtarget = np.sum( trg*dir*mask, axis=-1).reshape ((h*w))
|
||||
|
||||
idSource = np.argsort (projsource)
|
||||
idTarget = np.argsort (projtarget)
|
||||
|
@ -48,12 +50,18 @@ def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0):
|
|||
a = projtarget[idTarget]-projsource[idSource]
|
||||
for i_c in range(c):
|
||||
advect[idSource,i_c] += a * dir[i_c]
|
||||
new_src += advect.reshape( (h,w,c) ) / batch_size
|
||||
|
||||
new_src += (advect.reshape( (h,w,c) ) * mask) / batch_size
|
||||
|
||||
if reg_sigmaXY != 0.0:
|
||||
src_diff = new_src-src
|
||||
src_diff_filt = cv2.bilateralFilter (src_diff, 0, reg_sigmaV, reg_sigmaXY )
|
||||
if len(src_diff_filt.shape) == 2:
|
||||
src_diff_filt = src_diff_filt[...,None]
|
||||
new_src = src + src_diff_filt
|
||||
if return_diff:
|
||||
return src_diff_filt
|
||||
return src + src_diff_filt
|
||||
else:
|
||||
if return_diff:
|
||||
return new_src-src
|
||||
return new_src
|
||||
|
|
|
@ -37,7 +37,7 @@ class MPSPSCMRRingData:
|
|||
|
||||
# Initialize first block at 0 index
|
||||
wid = 0
|
||||
wid_uuid = uuid.uuid4().bytes_le
|
||||
wid_uuid = uuid.uuid4().bytes
|
||||
wid_heap_offset = 0
|
||||
wid_data_size = 0
|
||||
|
||||
|
@ -82,7 +82,7 @@ class MPSPSCMRRingData:
|
|||
raise Exception('data_size more than heap_size')
|
||||
|
||||
fmv = FormattedMemoryViewIO(self._shared_mem.get_mv())
|
||||
wid_uuid = uuid.uuid4().bytes_le
|
||||
wid_uuid = uuid.uuid4().bytes
|
||||
|
||||
if self._write_lock is not None:
|
||||
self._write_lock.acquire()
|
||||
|
|
|
@ -49,7 +49,7 @@ class MPWeakHeap:
|
|||
|
||||
# Entire block
|
||||
fmv.seek(self._first_block_offset)
|
||||
fmv.write_fmt('qq', self._heap_size-self._first_block_offset, 0), fmv.write(uuid.uuid4().bytes_le)
|
||||
fmv.write_fmt('qq', self._heap_size-self._first_block_offset, 0), fmv.write(uuid.uuid4().bytes)
|
||||
|
||||
|
||||
def add_data(self, data : Union[bytes, bytearray, memoryview] ) -> 'MPWeakHeap.DataRef':
|
||||
|
@ -92,7 +92,7 @@ class MPWeakHeap:
|
|||
if block_remain_size >= block_header_size:
|
||||
# the remain space of the block is enough for next block, split the block
|
||||
next_block_offset = cur_block_offset + block_new_size
|
||||
fmv.seek(next_block_offset), fmv.write_fmt('qq', block_remain_size, 0), fmv.write(uuid.uuid4().bytes_le)
|
||||
fmv.seek(next_block_offset), fmv.write_fmt('qq', block_remain_size, 0), fmv.write(uuid.uuid4().bytes)
|
||||
else:
|
||||
# otherwise do not split
|
||||
next_block_offset = cur_block_offset + block_size
|
||||
|
@ -101,7 +101,7 @@ class MPWeakHeap:
|
|||
block_new_size = block_size
|
||||
|
||||
# update current block structure
|
||||
uid = uuid.uuid4().bytes_le
|
||||
uid = uuid.uuid4().bytes
|
||||
fmv.seek(cur_block_offset), fmv.write_fmt('qq', block_new_size, data_size ), fmv.write(uid)
|
||||
|
||||
# update ring_head_block_offset
|
||||
|
@ -135,7 +135,7 @@ class MPWeakHeap:
|
|||
next_block_size, = fmv.get_fmt('q')
|
||||
|
||||
# erase data of next block
|
||||
fmv.write_fmt('qq', 0, 0), fmv.write(uuid.uuid4().bytes_le)
|
||||
fmv.write_fmt('qq', 0, 0), fmv.write(uuid.uuid4().bytes)
|
||||
|
||||
# overwrite current block size with expanded block size
|
||||
fmv.seek(cur_block_offset)
|
||||
|
|
|
@ -11,7 +11,6 @@ def _host_thread_proc(wref):
|
|||
break
|
||||
ref._host_process_messages(0.005)
|
||||
del ref
|
||||
print('_host_thread_proc exit')
|
||||
|
||||
class SPMTWorker:
|
||||
def __init__(self, *sub_args, **sub_kwargs):
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
from .device import (ORTDeviceInfo, get_available_devices_info,
|
||||
get_cpu_device)
|
||||
get_cpu_device_info)
|
||||
from .InferenceSession import InferenceSession_with_device
|
||||
|
|
|
@ -67,33 +67,34 @@ class ORTDeviceInfo:
|
|||
|
||||
_ort_devices_info = None
|
||||
|
||||
def get_cpu_device() -> ORTDeviceInfo:
|
||||
def get_cpu_device_info() -> ORTDeviceInfo:
|
||||
return ORTDeviceInfo(index=-1, execution_provider='CPUExecutionProvider', name='CPU', total_memory=0, free_memory=0)
|
||||
|
||||
def get_available_devices_info(include_cpu=True, cpu_only=False) -> List[ORTDeviceInfo]:
|
||||
"""
|
||||
returns a list of available ORTDeviceInfo
|
||||
"""
|
||||
global _ort_devices_info
|
||||
if _ort_devices_info is None:
|
||||
_initialize_ort_devices()
|
||||
devices = []
|
||||
if not cpu_only:
|
||||
global _ort_devices_info
|
||||
if _ort_devices_info is None:
|
||||
_initialize_ort_devices_info()
|
||||
_ort_devices_info = []
|
||||
for i in range ( int(os.environ.get('ORT_DEVICES_COUNT',0)) ):
|
||||
devices.append ( ORTDeviceInfo(index=int(os.environ[f'ORT_DEVICE_{i}_INDEX']),
|
||||
_ort_devices_info.append ( ORTDeviceInfo(index=int(os.environ[f'ORT_DEVICE_{i}_INDEX']),
|
||||
execution_provider=os.environ[f'ORT_DEVICE_{i}_EP'],
|
||||
name=os.environ[f'ORT_DEVICE_{i}_NAME'],
|
||||
total_memory=int(os.environ[f'ORT_DEVICE_{i}_TOTAL_MEM']),
|
||||
free_memory=int(os.environ[f'ORT_DEVICE_{i}_FREE_MEM']),
|
||||
) )
|
||||
if include_cpu or cpu_only:
|
||||
devices.append(get_cpu_device())
|
||||
_ort_devices_info = devices
|
||||
devices += _ort_devices_info
|
||||
if include_cpu:
|
||||
devices.append(get_cpu_device_info())
|
||||
|
||||
return _ort_devices_info
|
||||
return devices
|
||||
|
||||
|
||||
def _initialize_ort_devices():
|
||||
def _initialize_ort_devices_info():
|
||||
"""
|
||||
Determine available ORT devices, and place info about them to os.environ,
|
||||
they will be available in spawned subprocesses.
|
||||
|
@ -184,4 +185,4 @@ def _initialize_ort_devices():
|
|||
os.environ[f'ORT_DEVICE_{i}_TOTAL_MEM'] = str(device['total_mem'])
|
||||
os.environ[f'ORT_DEVICE_{i}_FREE_MEM'] = str(device['free_mem'])
|
||||
|
||||
_initialize_ort_devices()
|
||||
_initialize_ort_devices_info()
|
||||
|
|
|
@ -248,6 +248,7 @@ def ascii_table(table_def : List[str],
|
|||
lines.append(row_line)
|
||||
for sub_rows in rows:
|
||||
|
||||
|
||||
for row in sub_rows:
|
||||
line = ''
|
||||
|
||||
|
@ -288,7 +289,8 @@ def ascii_table(table_def : List[str],
|
|||
line += right_border
|
||||
|
||||
lines.append(line)
|
||||
if row_line is not None:
|
||||
|
||||
if len(sub_rows) != 0 and row_line is not None:
|
||||
lines.append(row_line)
|
||||
|
||||
return '\n'.join(lines)
|
|
@ -1,2 +1,2 @@
|
|||
from .device import (TorchDeviceInfo, TorchDevicesInfo, get_available_devices,
|
||||
get_cpu_device)
|
||||
from .device import (TorchDeviceInfo, get_available_devices_info,
|
||||
get_cpu_device_info, get_device, get_device_info_by_index)
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TorchDeviceInfo:
|
||||
|
@ -45,96 +47,40 @@ class TorchDeviceInfo:
|
|||
def __repr__(self):
|
||||
return f'{self.__class__.__name__} object: ' + self.__str__()
|
||||
|
||||
# class TorchDevicesInfo:
|
||||
# """
|
||||
# picklable list of TorchDeviceInfo
|
||||
# """
|
||||
# def __init__(self, devices : List[TorchDeviceInfo] = None):
|
||||
# if devices is None:
|
||||
# devices = []
|
||||
# self._devices = devices
|
||||
|
||||
# def __getstate__(self):
|
||||
# return self.__dict__.copy()
|
||||
|
||||
# def __setstate__(self, d):
|
||||
# self.__init__()
|
||||
# self.__dict__.update(d)
|
||||
|
||||
# def add(self, device_or_devices : TorchDeviceInfo):
|
||||
# if isinstance(device_or_devices, TorchDeviceInfo):
|
||||
# if device_or_devices not in self._devices:
|
||||
# self._devices.append(device_or_devices)
|
||||
# elif isinstance(device_or_devices, TorchDevicesInfo):
|
||||
# for device in device_or_devices:
|
||||
# self.add(device)
|
||||
|
||||
# def copy(self):
|
||||
# return copy.deepcopy(self)
|
||||
|
||||
# def get_count(self): return len(self._devices)
|
||||
|
||||
# def get_largest_total_memory_device(self) -> TorchDeviceInfo:
|
||||
# raise NotImplementedError()
|
||||
# result = None
|
||||
# idx_mem = 0
|
||||
# for device in self._devices:
|
||||
# mem = device.get_total_memory()
|
||||
# if result is None or (mem is not None and mem > idx_mem):
|
||||
# result = device
|
||||
# idx_mem = mem
|
||||
# return result
|
||||
|
||||
# def get_smallest_total_memory_device(self) -> TorchDeviceInfo:
|
||||
# raise NotImplementedError()
|
||||
# result = None
|
||||
# idx_mem = sys.maxsize
|
||||
# for device in self._devices:
|
||||
# mem = device.get_total_memory()
|
||||
# if result is None or (mem is not None and mem < idx_mem):
|
||||
# result = device
|
||||
# idx_mem = mem
|
||||
# return result
|
||||
|
||||
# def __len__(self):
|
||||
# return len(self._devices)
|
||||
|
||||
# def __getitem__(self, key):
|
||||
# result = self._devices[key]
|
||||
# if isinstance(key, slice):
|
||||
# return self.__class__(result)
|
||||
# return result
|
||||
|
||||
# def __iter__(self):
|
||||
# for device in self._devices:
|
||||
# yield device
|
||||
|
||||
# def __str__(self): return f'{self.__class__.__name__}:[' + ', '.join([ device.__str__() for device in self._devices ]) + ']'
|
||||
# def __repr__(self): return f'{self.__class__.__name__}:[' + ', '.join([ device.__repr__() for device in self._devices ]) + ']'
|
||||
|
||||
|
||||
_torch_devices = None
|
||||
|
||||
def get_cpu_device() -> TorchDeviceInfo:
|
||||
def get_cpu_device_info() -> TorchDeviceInfo:
|
||||
return TorchDeviceInfo(index=-1, name='CPU', total_memory=0)
|
||||
|
||||
def get_available_devices(include_cpu=True, cpu_only=False) -> List[TorchDeviceInfo]:
|
||||
def get_device_info_by_index(index) -> Union[TorchDeviceInfo, None]:
|
||||
for device in get_available_devices_info(include_cpu=False):
|
||||
if device.get_index() == index:
|
||||
return device
|
||||
|
||||
return None
|
||||
|
||||
def get_device(device_info : TorchDeviceInfo) -> torch.device:
|
||||
if device_info.is_cpu():
|
||||
return torch.device('cpu')
|
||||
return torch.device(f'cuda:{device_info.get_index()}')
|
||||
|
||||
def get_available_devices_info(include_cpu=True, cpu_only=False) -> List[TorchDeviceInfo]:
|
||||
"""
|
||||
returns a list of available TorchDeviceInfo
|
||||
"""
|
||||
devices = []
|
||||
if not cpu_only:
|
||||
global _torch_devices
|
||||
if _torch_devices is None:
|
||||
import torch
|
||||
devices = []
|
||||
|
||||
if not cpu_only:
|
||||
_torch_devices = []
|
||||
for i in range (torch.cuda.device_count()):
|
||||
device_props = torch.cuda.get_device_properties(i)
|
||||
devices.append ( TorchDeviceInfo(index=i, name=device_props.name, total_memory=device_props.total_memory))
|
||||
_torch_devices.append ( TorchDeviceInfo(index=i, name=device_props.name, total_memory=device_props.total_memory))
|
||||
devices += _torch_devices
|
||||
|
||||
if include_cpu or cpu_only:
|
||||
devices.append ( get_cpu_device() )
|
||||
if include_cpu:
|
||||
devices.append ( get_cpu_device_info() )
|
||||
|
||||
_torch_devices = devices
|
||||
return _torch_devices
|
||||
return devices
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue