From 78d80f9c5cd15aaa928f8ec2edfbd8372515e298 Mon Sep 17 00:00:00 2001 From: iperov Date: Wed, 20 Oct 2021 18:05:56 +0400 Subject: [PATCH] refactoring --- .../FaceAligner/FaceAlignerTrainer.py | 22 +++ main.py | 13 ++ xlib/facemeta/Faceset.py | 165 +++++++++++------- xlib/image/__init__.py | 2 +- xlib/image/{misc.py => _misc.py} | 0 5 files changed, 135 insertions(+), 67 deletions(-) create mode 100644 apps/trainers/FaceAligner/FaceAlignerTrainer.py rename xlib/image/{misc.py => _misc.py} (100%) diff --git a/apps/trainers/FaceAligner/FaceAlignerTrainer.py b/apps/trainers/FaceAligner/FaceAlignerTrainer.py new file mode 100644 index 0000000..df1c70e --- /dev/null +++ b/apps/trainers/FaceAligner/FaceAlignerTrainer.py @@ -0,0 +1,22 @@ +from xlib import facemeta as lib_fm +from xlib import time as lib_time + + +class FaceAlignerTrainer: + def __init__(self, faceset_path): + #fs = self._fs = lib_fm.Faceset(faceset_path) + fs = lib_fm.Faceset(faceset_path) + #fs.close() + + with lib_time.timeit(): + for x in fs.iter_UImage(): + x.get_image() + #fs = lib_fm.Faceset(faceset_path) + #fs.add_UFaceMark( [ lib_fm.UFaceMark() for _ in range(1000)] ) + + import code + code.interact(local=dict(globals(), **locals())) + + + def run(self): + ... \ No newline at end of file diff --git a/main.py b/main.py index 9c4f407..ade4f4e 100644 --- a/main.py +++ b/main.py @@ -52,6 +52,19 @@ def main(): p.add_argument('--delete-parts', action="store_true", default=False) p.set_defaults(func=run_merge_large_files) + train_parser = subparsers.add_parser( "train", help="Train neural network.") + train_parsers = train_parser.add_subparsers() + + def train_FaceAligner(args): + faceset_path = Path(args.faceset_path) + + from apps.trainers.FaceAligner.FaceAlignerTrainer import FaceAlignerTrainer + FaceAlignerTrainer(faceset_path=faceset_path).run() + + p = train_parsers.add_parser('FaceAligner') + p.add_argument('--faceset-path', default=None, action=fixPathAction, help=".dfs path") + p.set_defaults(func=train_FaceAligner) + def bad_args(arguments): parser.print_help() exit(0) diff --git a/xlib/facemeta/Faceset.py b/xlib/facemeta/Faceset.py index 365303b..c690495 100644 --- a/xlib/facemeta/Faceset.py +++ b/xlib/facemeta/Faceset.py @@ -1,7 +1,7 @@ import pickle import sqlite3 from pathlib import Path -from typing import Generator, List, Union +from typing import Generator, List, Union, Iterable import cv2 import numpy as np @@ -11,50 +11,69 @@ from .UFaceMark import UFaceMark from .UImage import UImage from .UPerson import UPerson - class Faceset: - def __init__(self, path): + def __init__(self, path = None): """ Faceset is a class to store and manage face related data. arguments: path path to faceset .dfs file + + Can be pickled. """ - self._path = path = Path(path) - if path.suffix != '.dfs': raise ValueError('Path must be a .dfs file') self._conn = conn = sqlite3.connect(path, isolation_level=None) self._cur = cur = conn.cursor() + cur = self._get_cursor() cur.execute('BEGIN IMMEDIATE') if not self._is_table_exists('FacesetInfo'): - self.clear_db(transaction=False) + self.recreate(shrink=False, _transaction=False) cur.execute('COMMIT') + self.shrink() - def close(self): - self._cur.close() - self._cur = None - self._conn.close() - self._conn = None + def __del__(self): + self.close() + + def __getstate__(self): + return {'_path' : self._path} + + def __setstate__(self, d): + self.__init__( d['_path'] ) + + def __repr__(self): return self.__str__() + def __str__(self): + return f"Faceset. UImage:{self.get_UImage_count()} UFaceMark:{self.get_UFaceMark_count()}" def _is_table_exists(self, name): return self._cur.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", [name]).fetchone()[0] != 0 + def _get_cursor(self) -> sqlite3.Cursor: return self._cur + + def close(self): + if self._cur is not None: + self._cur.close() + self._cur = None + + if self._conn is not None: + self._conn.close() + self._conn = None + def shrink(self): self._cur.execute('VACUUM') - - def clear_db(self, transaction=True): - """ - delete all data and recreate DB - """ - cur = self._cur - if transaction: + def recreate(self, shrink=True, _transaction=True): + """ + delete all data and recreate Faceset structure. + """ + cur = self._get_cursor() + + if _transaction: cur.execute('BEGIN IMMEDIATE') for table_name, in cur.execute("SELECT name from sqlite_master where type = 'table';").fetchall(): @@ -68,34 +87,42 @@ class Faceset: .execute('CREATE TABLE UFaceMark (uuid BLOB, UImage_uuid BLOB, UPerson_uuid BLOB, pickled_bytes BLOB)') ) - if transaction: + if _transaction: cur.execute('COMMIT') - + + if shrink: + self.shrink() + ################### ### UFaceMark ################### def _UFaceMark_from_db_row(self, db_row) -> UFaceMark: uuid, UImage_uuid, UPerson_uuid, pickled_bytes = db_row return pickle.loads(pickled_bytes) - - def add_UFaceMark(self, fm : UFaceMark): + + def add_UFaceMark(self, ufacemark_or_list : UFaceMark): """ add or update UFaceMark in DB """ - pickled_bytes = pickle.dumps(fm) - uuid = fm.get_uuid() - UImage_uuid = fm.get_UImage_uuid() - UPerson_uuid = fm.get_UPerson_uuid() + if not isinstance(ufacemark_or_list, Iterable): + ufacemark_or_list = [ufacemark_or_list] cur = self._cur cur.execute('BEGIN IMMEDIATE') - if cur.execute('SELECT COUNT(*) from UFaceMark where uuid=?', [uuid] ).fetchone()[0] != 0: - cur.execute('UPDATE UFaceMark SET UImage_uuid=?, UPerson_uuid=?, pickled_bytes=? WHERE uuid=?', - [UImage_uuid, UPerson_uuid, pickled_bytes, uuid]) - else: - cur.execute('INSERT INTO UFaceMark VALUES (?, ?, ?, ?)', [uuid, UImage_uuid, UPerson_uuid, pickled_bytes]) + for ufacemark in ufacemark_or_list: + pickled_bytes = pickle.dumps(ufacemark) + uuid = ufacemark.get_uuid() + UImage_uuid = ufacemark.get_UImage_uuid() + UPerson_uuid = ufacemark.get_UPerson_uuid() + + if cur.execute('SELECT COUNT(*) from UFaceMark where uuid=?', [uuid] ).fetchone()[0] != 0: + cur.execute('UPDATE UFaceMark SET UImage_uuid=?, UPerson_uuid=?, pickled_bytes=? WHERE uuid=?', + [UImage_uuid, UPerson_uuid, pickled_bytes, uuid]) + else: + cur.execute('INSERT INTO UFaceMark VALUES (?, ?, ?, ?)', [uuid, UImage_uuid, UPerson_uuid, pickled_bytes]) cur.execute('COMMIT') + def get_UFaceMark_count(self) -> int: return self._cur.execute('SELECT COUNT(*) FROM UFaceMark').fetchone()[0] @@ -116,26 +143,27 @@ class Faceset: (self._cur.execute('BEGIN IMMEDIATE') .execute('DELETE FROM UFaceMark') .execute('COMMIT') ) - ################### ### UPerson ################### - def add_UPerson(self, uperson : UPerson): + def add_UPerson(self, uperson_or_list : UPerson): """ add or update UPerson in DB """ - uuid = uperson.get_uuid() - name = uperson.get_name() - age = uperson.get_age() + if not isinstance(uperson_or_list, Iterable): + uperson_or_list = [uperson_or_list] - cur = self._conn.cursor() + cur = self._cur cur.execute('BEGIN IMMEDIATE') - if cur.execute('SELECT COUNT(*) from UPerson where uuid=?', [uuid]).fetchone()[0] != 0: - cur.execute('UPDATE UPerson SET name=?, age=? WHERE uuid=?', [name, age, uuid]) - else: - cur.execute('INSERT INTO UPerson VALUES (?, ?, ?)', [uuid, name, age]) + for uperson in uperson_or_list: + uuid = uperson.get_uuid() + name = uperson.get_name() + age = uperson.get_age() + if cur.execute('SELECT COUNT(*) from UPerson where uuid=?', [uuid]).fetchone()[0] != 0: + cur.execute('UPDATE UPerson SET name=?, age=? WHERE uuid=?', [name, age, uuid]) + else: + cur.execute('INSERT INTO UPerson VALUES (?, ?, ?)', [uuid, name, age]) cur.execute('COMMIT') - cur.close() def iter_UPerson(self) -> Generator[UPerson, None, None]: """ @@ -155,7 +183,7 @@ class Faceset: (self._cur.execute('BEGIN IMMEDIATE') .execute('DELETE FROM UPerson') .execute('COMMIT') ) - + ################### ### UImage ################### @@ -168,8 +196,8 @@ class Faceset: uimg.set_name(name) uimg.assign_image(img) return uimg - - def add_UImage(self, uimage : UImage, format : str = 'webp', quality : int = 100): + + def add_UImage(self, uimage_or_list : UImage, format : str = 'webp', quality : int = 100): """ add or update UImage in DB @@ -188,30 +216,35 @@ class Faceset: if format in ['jpg','jp2'] and quality < 0 or quality > 100: raise ValueError('quality must be in range [0..100]') - img = uimage.get_image() - uuid = uimage.get_uuid() - - if format == 'webp': - imencode_args = [int(cv2.IMWRITE_WEBP_QUALITY), quality] - elif format == 'jpg': - imencode_args = [int(cv2.IMWRITE_JPEG_QUALITY), quality] - elif format == 'jp2': - imencode_args = [int(cv2.IMWRITE_JPEG2000_COMPRESSION_X1000), quality*10] - else: - imencode_args = [] - - ret, data_bytes = cv2.imencode( f'.{format}', img, imencode_args) - if not ret: - raise Exception(f'Unable to encode image format {format}') + if not isinstance(uimage_or_list, Iterable): + uimage_or_list = [uimage_or_list] cur = self._cur cur.execute('BEGIN IMMEDIATE') - if cur.execute('SELECT COUNT(*) from UImage where uuid=?', [uuid] ).fetchone()[0] != 0: - cur.execute('UPDATE UImage SET name=?, format=?, data=? WHERE uuid=?', [uimage.get_name(), format, data_bytes.data, uuid]) - else: - cur.execute('INSERT INTO UImage VALUES (?, ?, ?, ?)', [uuid, uimage.get_name(), format, data_bytes.data]) + for uimage in uimage_or_list: + # TODO optimize move encoding to out of transaction + img = uimage.get_image() + uuid = uimage.get_uuid() + + if format == 'webp': + imencode_args = [int(cv2.IMWRITE_WEBP_QUALITY), quality] + elif format == 'jpg': + imencode_args = [int(cv2.IMWRITE_JPEG_QUALITY), quality] + elif format == 'jp2': + imencode_args = [int(cv2.IMWRITE_JPEG2000_COMPRESSION_X1000), quality*10] + else: + imencode_args = [] + + ret, data_bytes = cv2.imencode( f'.{format}', img, imencode_args) + if not ret: + raise Exception(f'Unable to encode image format {format}') + + if cur.execute('SELECT COUNT(*) from UImage where uuid=?', [uuid] ).fetchone()[0] != 0: + cur.execute('UPDATE UImage SET name=?, format=?, data=? WHERE uuid=?', [uimage.get_name(), format, data_bytes.data, uuid]) + else: + cur.execute('INSERT INTO UImage VALUES (?, ?, ?, ?)', [uuid, uimage.get_name(), format, data_bytes.data]) cur.execute('COMMIT') - + def get_UImage_count(self) -> int: return self._cur.execute('SELECT COUNT(*) FROM UImage').fetchone()[0] def get_UImage_by_uuid(self, uuid : bytes) -> Union[UImage, None]: """ @@ -220,7 +253,7 @@ class Faceset: if db_row is None: return None return self._UImage_from_db_row(db_row) - + def iter_UImage(self) -> Generator[UImage, None, None]: """ iterator of all UImage's diff --git a/xlib/image/__init__.py b/xlib/image/__init__.py index ff00996..727286a 100644 --- a/xlib/image/__init__.py +++ b/xlib/image/__init__.py @@ -1,2 +1,2 @@ from .ImageProcessor import ImageProcessor -from .misc import get_NHWC_shape \ No newline at end of file +from ._misc import get_NHWC_shape \ No newline at end of file diff --git a/xlib/image/misc.py b/xlib/image/_misc.py similarity index 100% rename from xlib/image/misc.py rename to xlib/image/_misc.py