mirror of
https://github.com/iperov/DeepFaceLive
synced 2025-07-10 23:33:32 -07:00
refactoring
This commit is contained in:
parent
7463515bfc
commit
78d80f9c5c
5 changed files with 135 additions and 67 deletions
22
apps/trainers/FaceAligner/FaceAlignerTrainer.py
Normal file
22
apps/trainers/FaceAligner/FaceAlignerTrainer.py
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
from xlib import facemeta as lib_fm
|
||||||
|
from xlib import time as lib_time
|
||||||
|
|
||||||
|
|
||||||
|
class FaceAlignerTrainer:
|
||||||
|
def __init__(self, faceset_path):
|
||||||
|
#fs = self._fs = lib_fm.Faceset(faceset_path)
|
||||||
|
fs = lib_fm.Faceset(faceset_path)
|
||||||
|
#fs.close()
|
||||||
|
|
||||||
|
with lib_time.timeit():
|
||||||
|
for x in fs.iter_UImage():
|
||||||
|
x.get_image()
|
||||||
|
#fs = lib_fm.Faceset(faceset_path)
|
||||||
|
#fs.add_UFaceMark( [ lib_fm.UFaceMark() for _ in range(1000)] )
|
||||||
|
|
||||||
|
import code
|
||||||
|
code.interact(local=dict(globals(), **locals()))
|
||||||
|
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
...
|
13
main.py
13
main.py
|
@ -52,6 +52,19 @@ def main():
|
||||||
p.add_argument('--delete-parts', action="store_true", default=False)
|
p.add_argument('--delete-parts', action="store_true", default=False)
|
||||||
p.set_defaults(func=run_merge_large_files)
|
p.set_defaults(func=run_merge_large_files)
|
||||||
|
|
||||||
|
train_parser = subparsers.add_parser( "train", help="Train neural network.")
|
||||||
|
train_parsers = train_parser.add_subparsers()
|
||||||
|
|
||||||
|
def train_FaceAligner(args):
|
||||||
|
faceset_path = Path(args.faceset_path)
|
||||||
|
|
||||||
|
from apps.trainers.FaceAligner.FaceAlignerTrainer import FaceAlignerTrainer
|
||||||
|
FaceAlignerTrainer(faceset_path=faceset_path).run()
|
||||||
|
|
||||||
|
p = train_parsers.add_parser('FaceAligner')
|
||||||
|
p.add_argument('--faceset-path', default=None, action=fixPathAction, help=".dfs path")
|
||||||
|
p.set_defaults(func=train_FaceAligner)
|
||||||
|
|
||||||
def bad_args(arguments):
|
def bad_args(arguments):
|
||||||
parser.print_help()
|
parser.print_help()
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import pickle
|
import pickle
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Generator, List, Union
|
from typing import Generator, List, Union, Iterable
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -11,50 +11,69 @@ from .UFaceMark import UFaceMark
|
||||||
from .UImage import UImage
|
from .UImage import UImage
|
||||||
from .UPerson import UPerson
|
from .UPerson import UPerson
|
||||||
|
|
||||||
|
|
||||||
class Faceset:
|
class Faceset:
|
||||||
|
|
||||||
def __init__(self, path):
|
def __init__(self, path = None):
|
||||||
"""
|
"""
|
||||||
Faceset is a class to store and manage face related data.
|
Faceset is a class to store and manage face related data.
|
||||||
|
|
||||||
arguments:
|
arguments:
|
||||||
|
|
||||||
path path to faceset .dfs file
|
path path to faceset .dfs file
|
||||||
|
|
||||||
|
Can be pickled.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._path = path = Path(path)
|
self._path = path = Path(path)
|
||||||
|
|
||||||
if path.suffix != '.dfs':
|
if path.suffix != '.dfs':
|
||||||
raise ValueError('Path must be a .dfs file')
|
raise ValueError('Path must be a .dfs file')
|
||||||
|
|
||||||
self._conn = conn = sqlite3.connect(path, isolation_level=None)
|
self._conn = conn = sqlite3.connect(path, isolation_level=None)
|
||||||
self._cur = cur = conn.cursor()
|
self._cur = cur = conn.cursor()
|
||||||
|
|
||||||
|
cur = self._get_cursor()
|
||||||
cur.execute('BEGIN IMMEDIATE')
|
cur.execute('BEGIN IMMEDIATE')
|
||||||
if not self._is_table_exists('FacesetInfo'):
|
if not self._is_table_exists('FacesetInfo'):
|
||||||
self.clear_db(transaction=False)
|
self.recreate(shrink=False, _transaction=False)
|
||||||
cur.execute('COMMIT')
|
cur.execute('COMMIT')
|
||||||
|
self.shrink()
|
||||||
|
|
||||||
def close(self):
|
def __del__(self):
|
||||||
self._cur.close()
|
self.close()
|
||||||
self._cur = None
|
|
||||||
self._conn.close()
|
def __getstate__(self):
|
||||||
self._conn = None
|
return {'_path' : self._path}
|
||||||
|
|
||||||
|
def __setstate__(self, d):
|
||||||
|
self.__init__( d['_path'] )
|
||||||
|
|
||||||
|
def __repr__(self): return self.__str__()
|
||||||
|
def __str__(self):
|
||||||
|
return f"Faceset. UImage:{self.get_UImage_count()} UFaceMark:{self.get_UFaceMark_count()}"
|
||||||
|
|
||||||
def _is_table_exists(self, name):
|
def _is_table_exists(self, name):
|
||||||
return self._cur.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", [name]).fetchone()[0] != 0
|
return self._cur.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", [name]).fetchone()[0] != 0
|
||||||
|
|
||||||
|
def _get_cursor(self) -> sqlite3.Cursor: return self._cur
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self._cur is not None:
|
||||||
|
self._cur.close()
|
||||||
|
self._cur = None
|
||||||
|
|
||||||
|
if self._conn is not None:
|
||||||
|
self._conn.close()
|
||||||
|
self._conn = None
|
||||||
|
|
||||||
def shrink(self):
|
def shrink(self):
|
||||||
self._cur.execute('VACUUM')
|
self._cur.execute('VACUUM')
|
||||||
|
|
||||||
def clear_db(self, transaction=True):
|
def recreate(self, shrink=True, _transaction=True):
|
||||||
"""
|
"""
|
||||||
delete all data and recreate DB
|
delete all data and recreate Faceset structure.
|
||||||
"""
|
"""
|
||||||
cur = self._cur
|
cur = self._get_cursor()
|
||||||
|
|
||||||
if transaction:
|
if _transaction:
|
||||||
cur.execute('BEGIN IMMEDIATE')
|
cur.execute('BEGIN IMMEDIATE')
|
||||||
|
|
||||||
for table_name, in cur.execute("SELECT name from sqlite_master where type = 'table';").fetchall():
|
for table_name, in cur.execute("SELECT name from sqlite_master where type = 'table';").fetchall():
|
||||||
|
@ -68,9 +87,12 @@ class Faceset:
|
||||||
.execute('CREATE TABLE UFaceMark (uuid BLOB, UImage_uuid BLOB, UPerson_uuid BLOB, pickled_bytes BLOB)')
|
.execute('CREATE TABLE UFaceMark (uuid BLOB, UImage_uuid BLOB, UPerson_uuid BLOB, pickled_bytes BLOB)')
|
||||||
)
|
)
|
||||||
|
|
||||||
if transaction:
|
if _transaction:
|
||||||
cur.execute('COMMIT')
|
cur.execute('COMMIT')
|
||||||
|
|
||||||
|
if shrink:
|
||||||
|
self.shrink()
|
||||||
|
|
||||||
###################
|
###################
|
||||||
### UFaceMark
|
### UFaceMark
|
||||||
###################
|
###################
|
||||||
|
@ -78,17 +100,21 @@ class Faceset:
|
||||||
uuid, UImage_uuid, UPerson_uuid, pickled_bytes = db_row
|
uuid, UImage_uuid, UPerson_uuid, pickled_bytes = db_row
|
||||||
return pickle.loads(pickled_bytes)
|
return pickle.loads(pickled_bytes)
|
||||||
|
|
||||||
def add_UFaceMark(self, fm : UFaceMark):
|
def add_UFaceMark(self, ufacemark_or_list : UFaceMark):
|
||||||
"""
|
"""
|
||||||
add or update UFaceMark in DB
|
add or update UFaceMark in DB
|
||||||
"""
|
"""
|
||||||
pickled_bytes = pickle.dumps(fm)
|
if not isinstance(ufacemark_or_list, Iterable):
|
||||||
uuid = fm.get_uuid()
|
ufacemark_or_list = [ufacemark_or_list]
|
||||||
UImage_uuid = fm.get_UImage_uuid()
|
|
||||||
UPerson_uuid = fm.get_UPerson_uuid()
|
|
||||||
|
|
||||||
cur = self._cur
|
cur = self._cur
|
||||||
cur.execute('BEGIN IMMEDIATE')
|
cur.execute('BEGIN IMMEDIATE')
|
||||||
|
for ufacemark in ufacemark_or_list:
|
||||||
|
pickled_bytes = pickle.dumps(ufacemark)
|
||||||
|
uuid = ufacemark.get_uuid()
|
||||||
|
UImage_uuid = ufacemark.get_UImage_uuid()
|
||||||
|
UPerson_uuid = ufacemark.get_UPerson_uuid()
|
||||||
|
|
||||||
if cur.execute('SELECT COUNT(*) from UFaceMark where uuid=?', [uuid] ).fetchone()[0] != 0:
|
if cur.execute('SELECT COUNT(*) from UFaceMark where uuid=?', [uuid] ).fetchone()[0] != 0:
|
||||||
cur.execute('UPDATE UFaceMark SET UImage_uuid=?, UPerson_uuid=?, pickled_bytes=? WHERE uuid=?',
|
cur.execute('UPDATE UFaceMark SET UImage_uuid=?, UPerson_uuid=?, pickled_bytes=? WHERE uuid=?',
|
||||||
[UImage_uuid, UPerson_uuid, pickled_bytes, uuid])
|
[UImage_uuid, UPerson_uuid, pickled_bytes, uuid])
|
||||||
|
@ -96,6 +122,7 @@ class Faceset:
|
||||||
cur.execute('INSERT INTO UFaceMark VALUES (?, ?, ?, ?)', [uuid, UImage_uuid, UPerson_uuid, pickled_bytes])
|
cur.execute('INSERT INTO UFaceMark VALUES (?, ?, ?, ?)', [uuid, UImage_uuid, UPerson_uuid, pickled_bytes])
|
||||||
cur.execute('COMMIT')
|
cur.execute('COMMIT')
|
||||||
|
|
||||||
|
|
||||||
def get_UFaceMark_count(self) -> int:
|
def get_UFaceMark_count(self) -> int:
|
||||||
return self._cur.execute('SELECT COUNT(*) FROM UFaceMark').fetchone()[0]
|
return self._cur.execute('SELECT COUNT(*) FROM UFaceMark').fetchone()[0]
|
||||||
|
|
||||||
|
@ -116,26 +143,27 @@ class Faceset:
|
||||||
(self._cur.execute('BEGIN IMMEDIATE')
|
(self._cur.execute('BEGIN IMMEDIATE')
|
||||||
.execute('DELETE FROM UFaceMark')
|
.execute('DELETE FROM UFaceMark')
|
||||||
.execute('COMMIT') )
|
.execute('COMMIT') )
|
||||||
|
|
||||||
###################
|
###################
|
||||||
### UPerson
|
### UPerson
|
||||||
###################
|
###################
|
||||||
def add_UPerson(self, uperson : UPerson):
|
def add_UPerson(self, uperson_or_list : UPerson):
|
||||||
"""
|
"""
|
||||||
add or update UPerson in DB
|
add or update UPerson in DB
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(uperson_or_list, Iterable):
|
||||||
|
uperson_or_list = [uperson_or_list]
|
||||||
|
|
||||||
|
cur = self._cur
|
||||||
|
cur.execute('BEGIN IMMEDIATE')
|
||||||
|
for uperson in uperson_or_list:
|
||||||
uuid = uperson.get_uuid()
|
uuid = uperson.get_uuid()
|
||||||
name = uperson.get_name()
|
name = uperson.get_name()
|
||||||
age = uperson.get_age()
|
age = uperson.get_age()
|
||||||
|
|
||||||
cur = self._conn.cursor()
|
|
||||||
cur.execute('BEGIN IMMEDIATE')
|
|
||||||
if cur.execute('SELECT COUNT(*) from UPerson where uuid=?', [uuid]).fetchone()[0] != 0:
|
if cur.execute('SELECT COUNT(*) from UPerson where uuid=?', [uuid]).fetchone()[0] != 0:
|
||||||
cur.execute('UPDATE UPerson SET name=?, age=? WHERE uuid=?', [name, age, uuid])
|
cur.execute('UPDATE UPerson SET name=?, age=? WHERE uuid=?', [name, age, uuid])
|
||||||
else:
|
else:
|
||||||
cur.execute('INSERT INTO UPerson VALUES (?, ?, ?)', [uuid, name, age])
|
cur.execute('INSERT INTO UPerson VALUES (?, ?, ?)', [uuid, name, age])
|
||||||
cur.execute('COMMIT')
|
cur.execute('COMMIT')
|
||||||
cur.close()
|
|
||||||
|
|
||||||
def iter_UPerson(self) -> Generator[UPerson, None, None]:
|
def iter_UPerson(self) -> Generator[UPerson, None, None]:
|
||||||
"""
|
"""
|
||||||
|
@ -169,7 +197,7 @@ class Faceset:
|
||||||
uimg.assign_image(img)
|
uimg.assign_image(img)
|
||||||
return uimg
|
return uimg
|
||||||
|
|
||||||
def add_UImage(self, uimage : UImage, format : str = 'webp', quality : int = 100):
|
def add_UImage(self, uimage_or_list : UImage, format : str = 'webp', quality : int = 100):
|
||||||
"""
|
"""
|
||||||
add or update UImage in DB
|
add or update UImage in DB
|
||||||
|
|
||||||
|
@ -188,6 +216,13 @@ class Faceset:
|
||||||
if format in ['jpg','jp2'] and quality < 0 or quality > 100:
|
if format in ['jpg','jp2'] and quality < 0 or quality > 100:
|
||||||
raise ValueError('quality must be in range [0..100]')
|
raise ValueError('quality must be in range [0..100]')
|
||||||
|
|
||||||
|
if not isinstance(uimage_or_list, Iterable):
|
||||||
|
uimage_or_list = [uimage_or_list]
|
||||||
|
|
||||||
|
cur = self._cur
|
||||||
|
cur.execute('BEGIN IMMEDIATE')
|
||||||
|
for uimage in uimage_or_list:
|
||||||
|
# TODO optimize move encoding to out of transaction
|
||||||
img = uimage.get_image()
|
img = uimage.get_image()
|
||||||
uuid = uimage.get_uuid()
|
uuid = uimage.get_uuid()
|
||||||
|
|
||||||
|
@ -204,8 +239,6 @@ class Faceset:
|
||||||
if not ret:
|
if not ret:
|
||||||
raise Exception(f'Unable to encode image format {format}')
|
raise Exception(f'Unable to encode image format {format}')
|
||||||
|
|
||||||
cur = self._cur
|
|
||||||
cur.execute('BEGIN IMMEDIATE')
|
|
||||||
if cur.execute('SELECT COUNT(*) from UImage where uuid=?', [uuid] ).fetchone()[0] != 0:
|
if cur.execute('SELECT COUNT(*) from UImage where uuid=?', [uuid] ).fetchone()[0] != 0:
|
||||||
cur.execute('UPDATE UImage SET name=?, format=?, data=? WHERE uuid=?', [uimage.get_name(), format, data_bytes.data, uuid])
|
cur.execute('UPDATE UImage SET name=?, format=?, data=? WHERE uuid=?', [uimage.get_name(), format, data_bytes.data, uuid])
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,2 +1,2 @@
|
||||||
from .ImageProcessor import ImageProcessor
|
from .ImageProcessor import ImageProcessor
|
||||||
from .misc import get_NHWC_shape
|
from ._misc import get_NHWC_shape
|
Loading…
Add table
Add a link
Reference in a new issue