refactoring

This commit is contained in:
iperov 2021-10-20 18:05:56 +04:00
parent 7463515bfc
commit 78d80f9c5c
5 changed files with 135 additions and 67 deletions

View file

@ -0,0 +1,22 @@
from xlib import facemeta as lib_fm
from xlib import time as lib_time
class FaceAlignerTrainer:
def __init__(self, faceset_path):
#fs = self._fs = lib_fm.Faceset(faceset_path)
fs = lib_fm.Faceset(faceset_path)
#fs.close()
with lib_time.timeit():
for x in fs.iter_UImage():
x.get_image()
#fs = lib_fm.Faceset(faceset_path)
#fs.add_UFaceMark( [ lib_fm.UFaceMark() for _ in range(1000)] )
import code
code.interact(local=dict(globals(), **locals()))
def run(self):
...

13
main.py
View file

@ -52,6 +52,19 @@ def main():
p.add_argument('--delete-parts', action="store_true", default=False) p.add_argument('--delete-parts', action="store_true", default=False)
p.set_defaults(func=run_merge_large_files) p.set_defaults(func=run_merge_large_files)
train_parser = subparsers.add_parser( "train", help="Train neural network.")
train_parsers = train_parser.add_subparsers()
def train_FaceAligner(args):
faceset_path = Path(args.faceset_path)
from apps.trainers.FaceAligner.FaceAlignerTrainer import FaceAlignerTrainer
FaceAlignerTrainer(faceset_path=faceset_path).run()
p = train_parsers.add_parser('FaceAligner')
p.add_argument('--faceset-path', default=None, action=fixPathAction, help=".dfs path")
p.set_defaults(func=train_FaceAligner)
def bad_args(arguments): def bad_args(arguments):
parser.print_help() parser.print_help()
exit(0) exit(0)

View file

@ -1,7 +1,7 @@
import pickle import pickle
import sqlite3 import sqlite3
from pathlib import Path from pathlib import Path
from typing import Generator, List, Union from typing import Generator, List, Union, Iterable
import cv2 import cv2
import numpy as np import numpy as np
@ -11,50 +11,69 @@ from .UFaceMark import UFaceMark
from .UImage import UImage from .UImage import UImage
from .UPerson import UPerson from .UPerson import UPerson
class Faceset: class Faceset:
def __init__(self, path): def __init__(self, path = None):
""" """
Faceset is a class to store and manage face related data. Faceset is a class to store and manage face related data.
arguments: arguments:
path path to faceset .dfs file path path to faceset .dfs file
Can be pickled.
""" """
self._path = path = Path(path) self._path = path = Path(path)
if path.suffix != '.dfs': if path.suffix != '.dfs':
raise ValueError('Path must be a .dfs file') raise ValueError('Path must be a .dfs file')
self._conn = conn = sqlite3.connect(path, isolation_level=None) self._conn = conn = sqlite3.connect(path, isolation_level=None)
self._cur = cur = conn.cursor() self._cur = cur = conn.cursor()
cur = self._get_cursor()
cur.execute('BEGIN IMMEDIATE') cur.execute('BEGIN IMMEDIATE')
if not self._is_table_exists('FacesetInfo'): if not self._is_table_exists('FacesetInfo'):
self.clear_db(transaction=False) self.recreate(shrink=False, _transaction=False)
cur.execute('COMMIT') cur.execute('COMMIT')
self.shrink()
def close(self): def __del__(self):
self._cur.close() self.close()
self._cur = None
self._conn.close() def __getstate__(self):
self._conn = None return {'_path' : self._path}
def __setstate__(self, d):
self.__init__( d['_path'] )
def __repr__(self): return self.__str__()
def __str__(self):
return f"Faceset. UImage:{self.get_UImage_count()} UFaceMark:{self.get_UFaceMark_count()}"
def _is_table_exists(self, name): def _is_table_exists(self, name):
return self._cur.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", [name]).fetchone()[0] != 0 return self._cur.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", [name]).fetchone()[0] != 0
def _get_cursor(self) -> sqlite3.Cursor: return self._cur
def close(self):
if self._cur is not None:
self._cur.close()
self._cur = None
if self._conn is not None:
self._conn.close()
self._conn = None
def shrink(self): def shrink(self):
self._cur.execute('VACUUM') self._cur.execute('VACUUM')
def clear_db(self, transaction=True): def recreate(self, shrink=True, _transaction=True):
""" """
delete all data and recreate DB delete all data and recreate Faceset structure.
""" """
cur = self._cur cur = self._get_cursor()
if transaction: if _transaction:
cur.execute('BEGIN IMMEDIATE') cur.execute('BEGIN IMMEDIATE')
for table_name, in cur.execute("SELECT name from sqlite_master where type = 'table';").fetchall(): for table_name, in cur.execute("SELECT name from sqlite_master where type = 'table';").fetchall():
@ -68,9 +87,12 @@ class Faceset:
.execute('CREATE TABLE UFaceMark (uuid BLOB, UImage_uuid BLOB, UPerson_uuid BLOB, pickled_bytes BLOB)') .execute('CREATE TABLE UFaceMark (uuid BLOB, UImage_uuid BLOB, UPerson_uuid BLOB, pickled_bytes BLOB)')
) )
if transaction: if _transaction:
cur.execute('COMMIT') cur.execute('COMMIT')
if shrink:
self.shrink()
################### ###################
### UFaceMark ### UFaceMark
################### ###################
@ -78,17 +100,21 @@ class Faceset:
uuid, UImage_uuid, UPerson_uuid, pickled_bytes = db_row uuid, UImage_uuid, UPerson_uuid, pickled_bytes = db_row
return pickle.loads(pickled_bytes) return pickle.loads(pickled_bytes)
def add_UFaceMark(self, fm : UFaceMark): def add_UFaceMark(self, ufacemark_or_list : UFaceMark):
""" """
add or update UFaceMark in DB add or update UFaceMark in DB
""" """
pickled_bytes = pickle.dumps(fm) if not isinstance(ufacemark_or_list, Iterable):
uuid = fm.get_uuid() ufacemark_or_list = [ufacemark_or_list]
UImage_uuid = fm.get_UImage_uuid()
UPerson_uuid = fm.get_UPerson_uuid()
cur = self._cur cur = self._cur
cur.execute('BEGIN IMMEDIATE') cur.execute('BEGIN IMMEDIATE')
for ufacemark in ufacemark_or_list:
pickled_bytes = pickle.dumps(ufacemark)
uuid = ufacemark.get_uuid()
UImage_uuid = ufacemark.get_UImage_uuid()
UPerson_uuid = ufacemark.get_UPerson_uuid()
if cur.execute('SELECT COUNT(*) from UFaceMark where uuid=?', [uuid] ).fetchone()[0] != 0: if cur.execute('SELECT COUNT(*) from UFaceMark where uuid=?', [uuid] ).fetchone()[0] != 0:
cur.execute('UPDATE UFaceMark SET UImage_uuid=?, UPerson_uuid=?, pickled_bytes=? WHERE uuid=?', cur.execute('UPDATE UFaceMark SET UImage_uuid=?, UPerson_uuid=?, pickled_bytes=? WHERE uuid=?',
[UImage_uuid, UPerson_uuid, pickled_bytes, uuid]) [UImage_uuid, UPerson_uuid, pickled_bytes, uuid])
@ -96,6 +122,7 @@ class Faceset:
cur.execute('INSERT INTO UFaceMark VALUES (?, ?, ?, ?)', [uuid, UImage_uuid, UPerson_uuid, pickled_bytes]) cur.execute('INSERT INTO UFaceMark VALUES (?, ?, ?, ?)', [uuid, UImage_uuid, UPerson_uuid, pickled_bytes])
cur.execute('COMMIT') cur.execute('COMMIT')
def get_UFaceMark_count(self) -> int: def get_UFaceMark_count(self) -> int:
return self._cur.execute('SELECT COUNT(*) FROM UFaceMark').fetchone()[0] return self._cur.execute('SELECT COUNT(*) FROM UFaceMark').fetchone()[0]
@ -116,26 +143,27 @@ class Faceset:
(self._cur.execute('BEGIN IMMEDIATE') (self._cur.execute('BEGIN IMMEDIATE')
.execute('DELETE FROM UFaceMark') .execute('DELETE FROM UFaceMark')
.execute('COMMIT') ) .execute('COMMIT') )
################### ###################
### UPerson ### UPerson
################### ###################
def add_UPerson(self, uperson : UPerson): def add_UPerson(self, uperson_or_list : UPerson):
""" """
add or update UPerson in DB add or update UPerson in DB
""" """
if not isinstance(uperson_or_list, Iterable):
uperson_or_list = [uperson_or_list]
cur = self._cur
cur.execute('BEGIN IMMEDIATE')
for uperson in uperson_or_list:
uuid = uperson.get_uuid() uuid = uperson.get_uuid()
name = uperson.get_name() name = uperson.get_name()
age = uperson.get_age() age = uperson.get_age()
cur = self._conn.cursor()
cur.execute('BEGIN IMMEDIATE')
if cur.execute('SELECT COUNT(*) from UPerson where uuid=?', [uuid]).fetchone()[0] != 0: if cur.execute('SELECT COUNT(*) from UPerson where uuid=?', [uuid]).fetchone()[0] != 0:
cur.execute('UPDATE UPerson SET name=?, age=? WHERE uuid=?', [name, age, uuid]) cur.execute('UPDATE UPerson SET name=?, age=? WHERE uuid=?', [name, age, uuid])
else: else:
cur.execute('INSERT INTO UPerson VALUES (?, ?, ?)', [uuid, name, age]) cur.execute('INSERT INTO UPerson VALUES (?, ?, ?)', [uuid, name, age])
cur.execute('COMMIT') cur.execute('COMMIT')
cur.close()
def iter_UPerson(self) -> Generator[UPerson, None, None]: def iter_UPerson(self) -> Generator[UPerson, None, None]:
""" """
@ -169,7 +197,7 @@ class Faceset:
uimg.assign_image(img) uimg.assign_image(img)
return uimg return uimg
def add_UImage(self, uimage : UImage, format : str = 'webp', quality : int = 100): def add_UImage(self, uimage_or_list : UImage, format : str = 'webp', quality : int = 100):
""" """
add or update UImage in DB add or update UImage in DB
@ -188,6 +216,13 @@ class Faceset:
if format in ['jpg','jp2'] and quality < 0 or quality > 100: if format in ['jpg','jp2'] and quality < 0 or quality > 100:
raise ValueError('quality must be in range [0..100]') raise ValueError('quality must be in range [0..100]')
if not isinstance(uimage_or_list, Iterable):
uimage_or_list = [uimage_or_list]
cur = self._cur
cur.execute('BEGIN IMMEDIATE')
for uimage in uimage_or_list:
# TODO optimize move encoding to out of transaction
img = uimage.get_image() img = uimage.get_image()
uuid = uimage.get_uuid() uuid = uimage.get_uuid()
@ -204,8 +239,6 @@ class Faceset:
if not ret: if not ret:
raise Exception(f'Unable to encode image format {format}') raise Exception(f'Unable to encode image format {format}')
cur = self._cur
cur.execute('BEGIN IMMEDIATE')
if cur.execute('SELECT COUNT(*) from UImage where uuid=?', [uuid] ).fetchone()[0] != 0: if cur.execute('SELECT COUNT(*) from UImage where uuid=?', [uuid] ).fetchone()[0] != 0:
cur.execute('UPDATE UImage SET name=?, format=?, data=? WHERE uuid=?', [uimage.get_name(), format, data_bytes.data, uuid]) cur.execute('UPDATE UImage SET name=?, format=?, data=? WHERE uuid=?', [uimage.get_name(), format, data_bytes.data, uuid])
else: else:

View file

@ -1,2 +1,2 @@
from .ImageProcessor import ImageProcessor from .ImageProcessor import ImageProcessor
from .misc import get_NHWC_shape from ._misc import get_NHWC_shape