mirror of
https://github.com/iperov/DeepFaceLive
synced 2025-08-21 05:53:25 -07:00
dev ...
This commit is contained in:
parent
64116844f2
commit
6978619d7f
3 changed files with 697 additions and 22 deletions
|
@ -1,22 +0,0 @@
|
|||
from xlib import face as lib_face
|
||||
from xlib import time as lib_time
|
||||
|
||||
|
||||
class FaceAlignerTrainer:
|
||||
def __init__(self, faceset_path):
|
||||
#fs = self._fs = lib_face.Faceset(faceset_path)
|
||||
fs = lib_face.Faceset(faceset_path)
|
||||
#fs.close()
|
||||
|
||||
with lib_time.timeit():
|
||||
for x in fs.iter_UImage():
|
||||
x.get_image()
|
||||
#fs = lib_face.Faceset(faceset_path)
|
||||
#fs.add_UFaceMark( [ lib_face.UFaceMark() for _ in range(1000)] )
|
||||
|
||||
import code
|
||||
code.interact(local=dict(globals(), **locals()))
|
||||
|
||||
|
||||
def run(self):
|
||||
...
|
449
apps/trainers/FaceAligner/FaceAlignerTrainerApp.py
Normal file
449
apps/trainers/FaceAligner/FaceAlignerTrainerApp.py
Normal file
|
@ -0,0 +1,449 @@
|
|||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.autograd
|
||||
from localization import L, Localization
|
||||
from modelhub import torch as torch_models
|
||||
from xlib import torch as lib_torch
|
||||
from xlib.console import diacon as dc
|
||||
from xlib.torch.optim import AdaBelief
|
||||
from xlib.torch.device import TorchDeviceInfo
|
||||
import torchvision.models as tv
|
||||
from xlib import math as lib_math
|
||||
from .TrainingDataGenerator import Data, TrainingDataGenerator
|
||||
|
||||
class FaceAlignerTrainerApp:
|
||||
def __init__(self, workspace_path : Path, faceset_path : Path):
|
||||
print('Initializing trainer.\n')
|
||||
print(f'Workspace path: {workspace_path}')
|
||||
print(f'Faceset path: {faceset_path}\n')
|
||||
|
||||
workspace_path.mkdir(parents=True, exist_ok=True)
|
||||
self._workspace_path = workspace_path
|
||||
self._faceset_path = faceset_path
|
||||
self._model_data_path = workspace_path / 'model.dat'
|
||||
|
||||
# system vars
|
||||
self._model_lock = threading.Lock()
|
||||
self._ev_request_reset_model = threading.Event()
|
||||
self._ev_request_preview = threading.Event()
|
||||
self._ev_request_export_model = threading.Event()
|
||||
self._ev_request_save = threading.Event()
|
||||
self._ev_request_quit = threading.Event()
|
||||
|
||||
# state vars
|
||||
self._is_quit = False
|
||||
self._is_device_changed = False
|
||||
self._new_viewing_data = None
|
||||
self._is_previewing_samples = False
|
||||
self._new_preview_data : 'PreviewData' = None
|
||||
self._last_save_time = None
|
||||
|
||||
# settings / params
|
||||
self._model_data = None
|
||||
self._device = None
|
||||
self._device_info = None
|
||||
self._batch_size = None
|
||||
self._resolution = None
|
||||
self._iteration = None
|
||||
self._autosave_period = None
|
||||
self._is_training = None
|
||||
self._loss_history = {}
|
||||
|
||||
# Generators
|
||||
self._training_generator = TrainingDataGenerator(faceset_path)
|
||||
|
||||
# Load and start
|
||||
self.load()
|
||||
threading.Thread(target=self.preview_thread_proc, daemon=True).start()
|
||||
self._training_generator.set_running(True)
|
||||
|
||||
self.get_main_dlg().set_current()
|
||||
self.main_loop()
|
||||
|
||||
# Finalizing
|
||||
self._training_generator.stop()
|
||||
dc.Diacon.stop()
|
||||
|
||||
def get_device_info(self): return self._device_info
|
||||
def set_device_info(self, device_info : TorchDeviceInfo):
|
||||
self._device = lib_torch.get_device(device_info)
|
||||
self._device_info = device_info
|
||||
self._is_device_changed = True
|
||||
|
||||
def get_batch_size(self) -> int: return self._batch_size
|
||||
def set_batch_size(self, batch_size : int):
|
||||
self._batch_size = batch_size
|
||||
self._training_generator.set_batch_size(batch_size)
|
||||
|
||||
def get_resolution(self) -> int: return self._resolution
|
||||
def set_resolution(self, resolution : int):
|
||||
self._resolution = resolution
|
||||
self._training_generator.set_resolution(resolution)
|
||||
|
||||
def get_iteration(self) -> int: return self._iteration
|
||||
def set_iteration(self, iteration : int):
|
||||
self._iteration = iteration
|
||||
for key in self._loss_history.keys():
|
||||
self._loss_history[key] = self._loss_history[key][:iteration]
|
||||
|
||||
def get_autosave_period(self): return self._autosave_period
|
||||
def set_autosave_period(self, mins : int):
|
||||
self._autosave_period = mins
|
||||
|
||||
def get_is_training(self) -> bool: return self._is_training
|
||||
def set_is_training(self, training : bool):
|
||||
if self._is_training != training:
|
||||
if training:
|
||||
self._last_save_time = time.time()
|
||||
else:
|
||||
self._last_save_time = None
|
||||
self._is_training = training
|
||||
|
||||
def get_loss_history(self): return self._loss_history
|
||||
def set_loss_history(self, lh): self._loss_history = lh
|
||||
|
||||
def load(self):
|
||||
self._model_data = model_data = torch.load(self._model_data_path, map_location='cpu') if self._model_data_path.exists() else {}
|
||||
|
||||
self.set_device_info( lib_torch.get_device_info_by_index(model_data.get('device_index', -1)) )
|
||||
self.set_batch_size( model_data.get('batch_size', 64) )
|
||||
self.set_resolution( model_data.get('resolution', 224) )
|
||||
self.set_iteration( model_data.get('iteration', 0) )
|
||||
self.set_autosave_period( model_data.get('autosave_period', 25) )
|
||||
self.set_is_training( model_data.get('training', False) )
|
||||
self.set_loss_history ( model_data.get('loss_history', {}) )
|
||||
|
||||
self.reset_model(load=True)
|
||||
|
||||
|
||||
|
||||
def reset_model(self, load : bool = True):
|
||||
while True:
|
||||
model = tv.mobilenet.mobilenet_v3_large(num_classes=6)
|
||||
|
||||
model.train()
|
||||
model.to(self._device)
|
||||
|
||||
model_optimizer = AdaBelief(model.parameters(), lr=5e-5, lr_dropout=0.3)
|
||||
|
||||
if load:
|
||||
model_state_dict = self._model_data.get('model_state_dict', None)
|
||||
if model_state_dict is not None:
|
||||
try:
|
||||
model.load_state_dict(model_state_dict)
|
||||
model.to(self._device)
|
||||
|
||||
model_optimizer_state_dict = self._model_data.get('model_optimizer_state_dict', None)
|
||||
if model_optimizer_state_dict is not None:
|
||||
model_optimizer.load_state_dict(model_optimizer_state_dict)
|
||||
|
||||
except:
|
||||
print('Network weights have been reseted.')
|
||||
self._model_data['model_state_dict'] = None
|
||||
self._model_data['model_optimizer_state_dict'] = None
|
||||
continue
|
||||
else:
|
||||
print('Network weights have been reseted.')
|
||||
break
|
||||
|
||||
self._model = model
|
||||
self._model_optimizer = model_optimizer
|
||||
|
||||
def save(self):
|
||||
if self._model_data is not None:
|
||||
d = {'device_index' : self._device_info.get_index(),
|
||||
'batch_size' : self.get_batch_size(),
|
||||
'resolution' : self.get_resolution(),
|
||||
'iteration' : self.get_iteration(),
|
||||
'autosave_period' : self.get_autosave_period(),
|
||||
'training' : self.get_is_training(),
|
||||
'loss_history' : self.get_loss_history(),
|
||||
'model_state_dict' : self._model.state_dict(),
|
||||
'model_optimizer_state_dict': self._model_optimizer.state_dict(),
|
||||
}
|
||||
|
||||
torch.save(d, self._model_data_path)
|
||||
|
||||
def export(self):
|
||||
self._model.to('cpu')
|
||||
self._model.eval()
|
||||
|
||||
torch.onnx.export( self._model,
|
||||
(torch.from_numpy( np.zeros( (1,3,self._resolution,self._resolution), dtype=np.float32)),) ,
|
||||
str(self._workspace_path / 'FaceAligner.onnx'),
|
||||
verbose=True,
|
||||
training=torch.onnx.TrainingMode.EVAL,
|
||||
opset_version=12,
|
||||
do_constant_folding=True,
|
||||
input_names=['in'],
|
||||
output_names=['mat'],
|
||||
dynamic_axes={'in' : {0:'batch_size'}, 'mat' : {0:'batch_size'}}, )
|
||||
|
||||
self._model.to(self._device)
|
||||
self._model.train()
|
||||
|
||||
|
||||
def preview_thread_proc(self):
|
||||
|
||||
|
||||
while not self._is_quit:
|
||||
preview_data, self._new_preview_data = self._new_preview_data, None
|
||||
if preview_data is not None:
|
||||
# new preview data to show
|
||||
data = preview_data.training_data
|
||||
n = np.random.randint(data.batch_size)
|
||||
img_aligned = data.img_aligned[n].transpose((1,2,0))
|
||||
img_aligned_shifted = data.img_aligned_shifted[n].transpose((1,2,0))
|
||||
|
||||
H,W = img_aligned_shifted.shape[:2]
|
||||
|
||||
|
||||
shift_mat = lib_math.Affine2DUniMat(data.shift_uni_mats[n]).invert().to_exact_mat(W,H, W, H)
|
||||
shift_mat_pred = lib_math.Affine2DUniMat(preview_data.shift_uni_mats_pred[n]).invert().to_exact_mat(W,H, W, H)
|
||||
|
||||
img_aligned_unshifted = cv2.warpAffine(img_aligned_shifted, shift_mat, (W,H))
|
||||
img_aligned_unshifted_pred = cv2.warpAffine(img_aligned_shifted, shift_mat_pred, (W,H))
|
||||
|
||||
screen = np.concatenate([img_aligned, img_aligned_shifted, img_aligned_unshifted, img_aligned_unshifted_pred], 1)
|
||||
cv2.imshow('Preview', screen)
|
||||
|
||||
|
||||
viewing_data, self._new_viewing_data = self._new_viewing_data, None
|
||||
if viewing_data is not None:
|
||||
n = np.random.randint(viewing_data.batch_size)
|
||||
img_aligned_shifted = viewing_data.img_aligned_shifted[n].transpose((1,2,0))
|
||||
screen = np.concatenate([img_aligned_shifted], 1)
|
||||
cv2.imshow('Viewing samples', screen)
|
||||
|
||||
|
||||
cv2.waitKey(5)
|
||||
time.sleep(0.005)
|
||||
|
||||
def main_loop(self):
|
||||
|
||||
while not self._is_quit:
|
||||
|
||||
if self._ev_request_reset_model.is_set():
|
||||
self._ev_request_reset_model.clear()
|
||||
self.reset_model(load=False)
|
||||
|
||||
if self._is_device_changed:
|
||||
self._is_device_changed = False
|
||||
self._model.to(self._device)
|
||||
self._model_optimizer.load_state_dict(self._model_optimizer.state_dict())
|
||||
|
||||
if self._is_training or \
|
||||
self._is_previewing_samples or \
|
||||
self._ev_request_preview.is_set():
|
||||
training_data = self._training_generator.get_next_data(wait=False)
|
||||
if training_data is not None and \
|
||||
training_data.resolution == self.get_resolution(): # Skip if resolution is different, due to delay
|
||||
|
||||
if self._is_training:
|
||||
self._model_optimizer.zero_grad()
|
||||
|
||||
if self._ev_request_preview.is_set() or \
|
||||
self._is_training:
|
||||
# Inference for both preview and training
|
||||
img_aligned_shifted_t = torch.tensor(training_data.img_aligned_shifted).to(self._device)
|
||||
shift_uni_mats_pred_t = self._model(img_aligned_shifted_t).view( (-1,2,3) )
|
||||
|
||||
if self._is_training:
|
||||
# Training optimization step
|
||||
shift_uni_mats_t = torch.tensor(training_data.shift_uni_mats).to(self._device)
|
||||
|
||||
loss_t = (shift_uni_mats_pred_t-shift_uni_mats_t).square().mean()*10.0
|
||||
loss_t.backward()
|
||||
self._model_optimizer.step()
|
||||
|
||||
loss = loss_t.detach().cpu().numpy()
|
||||
|
||||
rec_loss_history = self._loss_history.get('reconstruct', None)
|
||||
if rec_loss_history is None:
|
||||
rec_loss_history = self._loss_history['reconstruct'] = []
|
||||
rec_loss_history.append(float(loss))
|
||||
|
||||
self.set_iteration( self.get_iteration() + 1 )
|
||||
|
||||
if self._ev_request_preview.is_set():
|
||||
self._ev_request_preview.clear()
|
||||
# Preview request
|
||||
pd = PreviewData()
|
||||
pd.training_data = training_data
|
||||
pd.shift_uni_mats_pred = shift_uni_mats_pred_t.detach().cpu().numpy()
|
||||
self._new_preview_data = pd
|
||||
|
||||
if self._is_previewing_samples:
|
||||
self._new_viewing_data = training_data
|
||||
|
||||
if self._is_training:
|
||||
if self._last_save_time is not None:
|
||||
while (time.time()-self._last_save_time)/60 >= self._autosave_period:
|
||||
self._last_save_time += self._autosave_period*60
|
||||
self._ev_request_save.set()
|
||||
|
||||
|
||||
if self._ev_request_export_model.is_set():
|
||||
self._ev_request_export_model.clear()
|
||||
print('Exporting...')
|
||||
self.export()
|
||||
print('Exporting done.')
|
||||
dc.Diacon.update_dlg()
|
||||
|
||||
if self._ev_request_save.is_set():
|
||||
self._ev_request_save.clear()
|
||||
print('Saving...')
|
||||
self.save()
|
||||
print('Saving done.')
|
||||
|
||||
dc.Diacon.update_dlg()
|
||||
|
||||
if self._ev_request_quit.is_set():
|
||||
self._ev_request_quit.clear()
|
||||
self._is_quit = True
|
||||
|
||||
time.sleep(0.005)
|
||||
|
||||
|
||||
def get_main_dlg(self):
|
||||
last_loss = 0
|
||||
rec_loss_history = self._loss_history.get('reconstruct', None)
|
||||
if rec_loss_history is not None:
|
||||
if len(rec_loss_history) != 0:
|
||||
last_loss = rec_loss_history[-1]
|
||||
|
||||
return dc.DlgChoices([
|
||||
dc.DlgChoice(short_name='sgm', row_def='| Sample generator menu.',
|
||||
on_choose=lambda dlg: self.get_sample_generator_dlg(dlg).set_current()),
|
||||
|
||||
dc.DlgChoice(short_name='d', row_def=f'| Device | {self._device_info}',
|
||||
on_choose=lambda dlg: self.get_training_device_dlg(dlg).set_current()),
|
||||
|
||||
dc.DlgChoice(short_name='i', row_def=f'| Iteration | {self.get_iteration()}',
|
||||
on_choose=lambda dlg: self.get_iteration_dlg(parent_dlg=dlg).set_current() ),
|
||||
|
||||
dc.DlgChoice(short_name='l', row_def=f'| Print loss history | Last loss = {last_loss:.5f} ',
|
||||
on_choose=self.on_main_dlg_print_loss_history ),
|
||||
|
||||
dc.DlgChoice(short_name='p', row_def='| Show current preview.',
|
||||
on_choose=lambda dlg: (self._ev_request_preview.set(), dlg.recreate().set_current())),
|
||||
|
||||
dc.DlgChoice(short_name='t', row_def=f'| Training | {self._is_training}',
|
||||
on_choose=lambda dlg: (self.set_is_training(not self.get_is_training()), dlg.recreate().set_current()) ),
|
||||
|
||||
dc.DlgChoice(short_name='reset', row_def='| Reset model.',
|
||||
on_choose=lambda dlg: (self._ev_request_reset_model.set(), dlg.recreate().set_current()) ),
|
||||
|
||||
dc.DlgChoice(short_name='export', row_def='| Export model.',
|
||||
on_choose=lambda dlg: self._ev_request_export_model.set() ),
|
||||
|
||||
dc.DlgChoice(short_name='se', row_def=f'| Autosave period | {self.get_autosave_period()} minutes',
|
||||
on_choose=lambda dlg: self.get_autosave_period_dlg(dlg).set_current()),
|
||||
|
||||
dc.DlgChoice(short_name='s', row_def='| Save all.',
|
||||
on_choose=lambda dlg: self._ev_request_save.set() ),
|
||||
|
||||
dc.DlgChoice(short_name='q', row_def='| Quit now.',
|
||||
on_choose=self.on_main_dlg_quit )
|
||||
], on_recreate=lambda dlg: self.get_main_dlg(),
|
||||
top_rows_def='|c9 Main menu' )
|
||||
|
||||
def on_main_dlg_quit(self, dlg):
|
||||
self._ev_request_quit.set()
|
||||
|
||||
|
||||
|
||||
def on_main_dlg_print_loss_history(self, dlg):
|
||||
max_lines = 20
|
||||
for key in self._loss_history.keys():
|
||||
lh = self._loss_history[key]
|
||||
|
||||
print(f'Loss history for: {key}')
|
||||
|
||||
d = len(lh) // max_lines
|
||||
|
||||
lh_ar = np.array(lh[-d*max_lines:], np.float32)
|
||||
lh_ar = lh_ar.reshape( (max_lines, d)).mean(-1)
|
||||
|
||||
print( '\n'.join( f'{value:.5f}' for value in lh_ar ) )
|
||||
|
||||
dlg.recreate().set_current()
|
||||
|
||||
|
||||
def get_sample_generator_dlg(self, parent_dlg):
|
||||
return dc.DlgChoices([
|
||||
dc.DlgChoice(short_name='v', row_def=f'| Previewing samples | {self._is_previewing_samples}',
|
||||
on_choose=self.on_sample_generator_dlg_previewing_last_samples,
|
||||
),
|
||||
|
||||
dc.DlgChoice(short_name='r', row_def=f'| Running | {self._training_generator.is_running()}',
|
||||
on_choose=lambda dlg: (self._training_generator.set_running(not self._training_generator.is_running()), dlg.recreate().set_current()) ),
|
||||
|
||||
],
|
||||
on_recreate=lambda dlg: self.get_sample_generator_dlg(parent_dlg),
|
||||
on_back =lambda dlg: parent_dlg.recreate().set_current(),
|
||||
top_rows_def='|c9 Sample generator menu' )
|
||||
|
||||
def on_sample_generator_dlg_previewing_last_samples(self, dlg):
|
||||
self._is_previewing_samples = not self._is_previewing_samples
|
||||
dlg.recreate().set_current()
|
||||
|
||||
def get_training_dlg(self, parent_dlg):
|
||||
return dc.DlgChoices([
|
||||
|
||||
],
|
||||
on_recreate=lambda dlg: self.get_training_dlg(parent_dlg),
|
||||
on_back =lambda dlg: parent_dlg.recreate().set_current(),
|
||||
top_rows_def='|c9 Training menu' )
|
||||
|
||||
def get_autosave_period_dlg(self, parent_dlg):
|
||||
return dc.DlgNumber(is_float=False, min_value=1,
|
||||
on_value = lambda dlg, value: (self.set_autosave_period(value), parent_dlg.recreate().set_current()),
|
||||
on_recreate = lambda dlg: self.get_autosave_period_dlg(parent_dlg),
|
||||
on_back = lambda dlg: parent_dlg.recreate().set_current(),
|
||||
top_rows_def='|c9 Set save every min', )
|
||||
|
||||
def get_iteration_dlg(self, parent_dlg):
|
||||
return dc.DlgNumber(is_float=False, min_value=0,
|
||||
on_value = lambda dlg, value: (self.set_iteration(value), parent_dlg.recreate().set_current()),
|
||||
on_recreate = lambda dlg: self.get_iteration_dlg(parent_dlg),
|
||||
on_back = lambda dlg: parent_dlg.recreate().set_current(),
|
||||
top_rows_def='|c9 Set iteration', )
|
||||
|
||||
def get_training_device_dlg(self, parent_dlg):
|
||||
return DlgTorchDevicesInfo(on_device_choice = lambda dlg, device_info: (self.set_device_info(device_info), parent_dlg.recreate().set_current()),
|
||||
on_recreate = lambda dlg: self.get_training_device_dlg(parent_dlg),
|
||||
on_back = lambda dlg: parent_dlg.recreate().set_current(),
|
||||
top_rows_def='|c9 Choose device'
|
||||
)
|
||||
|
||||
|
||||
class DlgTorchDevicesInfo(dc.DlgChoices):
|
||||
def __init__(self, on_device_choice : Callable = None,
|
||||
on_device_multi_choice : Callable = None,
|
||||
on_recreate = None,
|
||||
on_back : Callable = None,
|
||||
top_rows_def : Union[str, List[str]] = None,
|
||||
bottom_rows_def : Union[str, List[str]] = None,):
|
||||
devices = lib_torch.get_available_devices_info()
|
||||
super().__init__(choices=[
|
||||
dc.DlgChoice(short_name=f'{device.get_index()}' if not device.is_cpu() else 'c',
|
||||
row_def= f"| {str(device.get_name())} " +
|
||||
(f"| {(device.get_total_memory() / 1024**3) :.3}Gb" if not device.is_cpu() else ""),
|
||||
on_choose= ( lambda dlg, i=i: on_device_choice(dlg, devices[i]) ) \
|
||||
if on_device_choice is not None else None)
|
||||
for i, device in enumerate(devices) ],
|
||||
on_multi_choice=(lambda idxs: on_device_multi_choice([ devices[idx] for idx in idxs ])) \
|
||||
if on_device_multi_choice is not None else None,
|
||||
on_recreate=on_recreate, on_back=on_back,
|
||||
top_rows_def=top_rows_def, bottom_rows_def=bottom_rows_def)
|
||||
|
||||
class PreviewData:
|
||||
training_data : Data = None
|
||||
shift_uni_mats_pred = None
|
248
apps/trainers/FaceAligner/TrainingDataGenerator.py
Normal file
248
apps/trainers/FaceAligner/TrainingDataGenerator.py
Normal file
|
@ -0,0 +1,248 @@
|
|||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from xlib import face as lib_face
|
||||
from xlib import image as lib_img
|
||||
from xlib import mp as lib_mp
|
||||
from xlib import mt as lib_mt
|
||||
from xlib.image import sd as lib_sd
|
||||
|
||||
|
||||
class Data:
|
||||
def __init__(self):
|
||||
self.batch_size : int = None
|
||||
self.resolution : int = None
|
||||
self.img_aligned : np.ndarray = None
|
||||
self.img_aligned_shifted : np.ndarray = None
|
||||
self.shift_uni_mats : np.ndarray = None
|
||||
|
||||
class TrainingDataGenerator(lib_mp.MPWorker):
|
||||
def __init__(self, faceset_path : Path):
|
||||
faceset_path = Path(faceset_path)
|
||||
if not faceset_path.exists():
|
||||
raise Exception (f'{faceset_path} does not exist.')
|
||||
|
||||
super().__init__(sub_args=[faceset_path])
|
||||
|
||||
self._datas = [ deque() for _ in range(self.get_process_count())]
|
||||
self._datas_counter = 0
|
||||
self._running = False
|
||||
|
||||
def get_next_data(self, wait : bool) -> Union[Data, None]:
|
||||
"""
|
||||
wait and returns new generated data
|
||||
"""
|
||||
while True:
|
||||
for _ in range(self.get_process_count()):
|
||||
process_id, self._datas_counter = self._datas_counter % len(self._datas), self._datas_counter + 1
|
||||
data = self._datas[process_id]
|
||||
|
||||
if len(data) != 0:
|
||||
self._send_msg('data_received', process_id=process_id)
|
||||
return data.popleft()
|
||||
|
||||
if not wait:
|
||||
return None
|
||||
time.sleep(0.005)
|
||||
|
||||
def is_running(self) -> bool: return self._running
|
||||
def set_running(self, running : bool):
|
||||
self._running = running
|
||||
self._send_msg('running', running)
|
||||
|
||||
def set_batch_size(self, batch_size):
|
||||
self._send_msg('batch_size', batch_size)
|
||||
|
||||
def set_resolution(self, resolution):
|
||||
self._send_msg('resolution', resolution)
|
||||
|
||||
###### IMPL HOST
|
||||
def _on_host_sub_message(self, process_id, name, *args, **kwargs):
|
||||
"""
|
||||
message from sub
|
||||
"""
|
||||
if name == 'data':
|
||||
self._datas[process_id].append(args[0])
|
||||
|
||||
###### IMPL SUB
|
||||
def _on_sub_initialize(self, faceset_path : Path):
|
||||
self._fs = fs = lib_face.Faceset(faceset_path)
|
||||
self._ufm_uuids = fs.get_all_UFaceMark_uuids()
|
||||
self._ufm_uuid_indexes = []
|
||||
self._sent_buffers_count = 0
|
||||
self._running = False
|
||||
self._batch_size = None
|
||||
self._resolution = None
|
||||
|
||||
self._n_batch = 0
|
||||
self._img_aligned_list = []
|
||||
self._img_aligned_shifted_list = []
|
||||
self._shift_mat_list = []
|
||||
|
||||
def _on_sub_finalize(self):
|
||||
self._fs.close()
|
||||
|
||||
def _on_sub_host_message(self, name, *args, **kwargs):
|
||||
"""
|
||||
a message from host
|
||||
"""
|
||||
if name == 'data_received':
|
||||
self._sent_buffers_count -= 1
|
||||
elif name == 'batch_size':
|
||||
self._batch_size, = args
|
||||
elif name == 'resolution':
|
||||
self._resolution, = args
|
||||
elif name == 'running':
|
||||
self._running , = args
|
||||
|
||||
####### IMPL SUB THREAD
|
||||
def _on_sub_tick(self, process_id):
|
||||
running = self._running
|
||||
if running:
|
||||
if self._batch_size is None:
|
||||
print('Unable to start TrainingGenerator: batch_size must be set')
|
||||
running = False
|
||||
if self._resolution is None:
|
||||
print('Unable to start TrainingGenerator: resolution must be set')
|
||||
running = False
|
||||
|
||||
if running:
|
||||
if self._sent_buffers_count < 2:
|
||||
batch_size = self._batch_size
|
||||
resolution = self._resolution
|
||||
face_coverage = 1.0
|
||||
|
||||
rw_grid_cell_range = [3,7]
|
||||
rw_grid_rot_deg_range = [-180,180]
|
||||
rw_grid_scale_range = [-0.25, 2.5]
|
||||
rw_grid_tx_range = [-0.50, 0.50]
|
||||
rw_grid_ty_range = [-0.50, 0.50]
|
||||
|
||||
align_rot_deg_range = [-180,180]
|
||||
align_scale_range = [0.0,2.5]
|
||||
align_tx_range = [-0.50, 0.50]
|
||||
align_ty_range = [-0.50, 0.50]
|
||||
|
||||
|
||||
|
||||
random_mask_complexity = 3
|
||||
sharpen_chance = 25
|
||||
motion_blur_chance = 25
|
||||
gaussian_blur_chance = 25
|
||||
reresize_chance = 25
|
||||
recompress_chance = 25
|
||||
|
||||
img_aligned_list = []
|
||||
img_aligned_shifted_list = []
|
||||
shift_mat_list = []
|
||||
|
||||
if self._n_batch < batch_size:
|
||||
# Make only 1 sample per tick
|
||||
while True:
|
||||
uuid1 = self._get_next_UFaceMark_uuid()
|
||||
|
||||
ufm1 = self._fs.get_UFaceMark_by_uuid(uuid1)
|
||||
|
||||
flmrks1 = ufm1.get_FLandmarks2D_best()
|
||||
if flmrks1 is None:
|
||||
print(f'Corrupted faceset, no FLandmarks2D for UFaceMark {ufm1.get_uuid()}')
|
||||
continue
|
||||
|
||||
uimg1 = self._fs.get_UImage_by_uuid(ufm1.get_UImage_uuid())
|
||||
if uimg1 is None:
|
||||
print(f'Corrupted faceset, no UImage for UFaceMark {ufm1.get_uuid()}')
|
||||
continue
|
||||
|
||||
img1 = uimg1.get_image()
|
||||
|
||||
if img1 is None:
|
||||
print(f'Corrupted faceset, no image in UImage {uimg1.get_uuid()}')
|
||||
continue
|
||||
|
||||
img_aligned, _ = flmrks1.cut(img1, face_coverage, resolution)
|
||||
img_aligned = img_aligned.astype(np.float32) / 255.0
|
||||
|
||||
_, img_to_face_uni_mat1 = flmrks1.calc_cut( img1.shape[0:2], face_coverage, resolution)
|
||||
|
||||
|
||||
fw1 = lib_face.FaceWarper(img_to_face_uni_mat1,
|
||||
align_rot_deg=align_rot_deg_range,
|
||||
align_scale=align_scale_range,
|
||||
align_tx=align_tx_range,
|
||||
align_ty=align_ty_range,
|
||||
rw_grid_cell_count=rw_grid_cell_range,
|
||||
rw_grid_rot_deg=rw_grid_rot_deg_range,
|
||||
rw_grid_scale=rw_grid_scale_range,
|
||||
rw_grid_tx=rw_grid_tx_range,
|
||||
rw_grid_ty=rw_grid_ty_range,
|
||||
)
|
||||
|
||||
img_aligned_shifted = fw1.transform(img1, resolution, random_warp=True).astype(np.float32) / 255.0
|
||||
|
||||
ip = lib_img.ImageProcessor(img_aligned_shifted)
|
||||
rnd = np.random
|
||||
if rnd.randint(2) == 0:
|
||||
ip.hsv( rnd.randint(0, 360), rnd.uniform(-0.5,0.5), rnd.uniform(-0.5,0.5), mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity))
|
||||
else:
|
||||
ip.levels( [ [rnd.uniform(0,0.25),rnd.uniform(0.75,1.0),rnd.uniform(0.5,1.5), rnd.uniform(0,0.25),rnd.uniform(0.75,1.0), ],
|
||||
[rnd.uniform(0,0.25),rnd.uniform(0.75,1.0),rnd.uniform(0.5,1.5), rnd.uniform(0,0.25),rnd.uniform(0.75,1.0),],
|
||||
[rnd.uniform(0,0.25),rnd.uniform(0.75,1.0),rnd.uniform(0.5,1.5), rnd.uniform(0,0.25),rnd.uniform(0.75,1.0),], ], mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity))
|
||||
|
||||
if rnd.randint(2) == 0:
|
||||
if rnd.randint(100) < sharpen_chance:
|
||||
if rnd.randint(2) == 0:
|
||||
ip.box_sharpen(size=rnd.randint(1,11), power=rnd.uniform(0.5,5.0), mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity) )
|
||||
else:
|
||||
ip.gaussian_sharpen(sigma=1.0, power=rnd.uniform(0.5,5.0), mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity) )
|
||||
|
||||
else:
|
||||
if rnd.randint(100) < motion_blur_chance:
|
||||
ip.motion_blur (size=rnd.randint(1,11), angle=rnd.randint(360), mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity))
|
||||
if rnd.randint(100) < gaussian_blur_chance:
|
||||
ip.gaussian_blur (sigma=rnd.uniform(0.5,3.0), mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity))
|
||||
|
||||
if np.random.randint(2) == 0:
|
||||
if rnd.randint(100) < reresize_chance:
|
||||
ip.reresize( rnd.uniform(0.0,0.75), interpolation=ip.Interpolation.NEAREST, mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity) )
|
||||
if np.random.randint(2) == 0:
|
||||
if rnd.randint(100) < reresize_chance:
|
||||
ip.reresize( rnd.uniform(0.0,0.75), interpolation=ip.Interpolation.LINEAR, mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity) )
|
||||
if rnd.randint(100) < recompress_chance:
|
||||
ip.jpeg_recompress(quality=rnd.randint(10,75), mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity) )
|
||||
|
||||
img_aligned_shifted = ip.get_image('HWC')
|
||||
|
||||
self._img_aligned_list.append(img_aligned)
|
||||
self._img_aligned_shifted_list.append(img_aligned_shifted)
|
||||
self._shift_mat_list.append( fw1.get_aligned_random_transform_mat() )
|
||||
self._n_batch += 1
|
||||
break
|
||||
|
||||
if self._n_batch == batch_size:
|
||||
data = Data()
|
||||
data.batch_size = batch_size
|
||||
data.resolution = resolution
|
||||
data.img_aligned = np.array(self._img_aligned_list).transpose( (0,3,1,2))
|
||||
data.img_aligned_shifted = np.array(self._img_aligned_shifted_list).transpose( (0,3,1,2))
|
||||
data.shift_uni_mats = np.array(self._shift_mat_list)
|
||||
|
||||
self._send_msg('data', data)
|
||||
self._sent_buffers_count +=1
|
||||
|
||||
self._n_batch = 0
|
||||
self._img_aligned_list = []
|
||||
self._img_aligned_shifted_list = []
|
||||
self._shift_mat_list = []
|
||||
|
||||
|
||||
def _get_next_UFaceMark_uuid(self) -> bytes:
|
||||
if len(self._ufm_uuid_indexes) == 0:
|
||||
self._ufm_uuid_indexes = [*range(len(self._ufm_uuids))]
|
||||
np.random.shuffle(self._ufm_uuid_indexes)
|
||||
idx = self._ufm_uuid_indexes.pop()
|
||||
return self._ufm_uuids[idx]
|
Loading…
Add table
Add a link
Reference in a new issue