This commit is contained in:
iperov 2021-11-11 23:17:43 +04:00
commit 6978619d7f
3 changed files with 697 additions and 22 deletions

View file

@ -1,22 +0,0 @@
from xlib import face as lib_face
from xlib import time as lib_time
class FaceAlignerTrainer:
def __init__(self, faceset_path):
#fs = self._fs = lib_face.Faceset(faceset_path)
fs = lib_face.Faceset(faceset_path)
#fs.close()
with lib_time.timeit():
for x in fs.iter_UImage():
x.get_image()
#fs = lib_face.Faceset(faceset_path)
#fs.add_UFaceMark( [ lib_face.UFaceMark() for _ in range(1000)] )
import code
code.interact(local=dict(globals(), **locals()))
def run(self):
...

View file

@ -0,0 +1,449 @@
import threading
import time
from pathlib import Path
from typing import Any, Callable, List, Tuple, Union
import cv2
import numpy as np
import torch
import torch.autograd
from localization import L, Localization
from modelhub import torch as torch_models
from xlib import torch as lib_torch
from xlib.console import diacon as dc
from xlib.torch.optim import AdaBelief
from xlib.torch.device import TorchDeviceInfo
import torchvision.models as tv
from xlib import math as lib_math
from .TrainingDataGenerator import Data, TrainingDataGenerator
class FaceAlignerTrainerApp:
def __init__(self, workspace_path : Path, faceset_path : Path):
print('Initializing trainer.\n')
print(f'Workspace path: {workspace_path}')
print(f'Faceset path: {faceset_path}\n')
workspace_path.mkdir(parents=True, exist_ok=True)
self._workspace_path = workspace_path
self._faceset_path = faceset_path
self._model_data_path = workspace_path / 'model.dat'
# system vars
self._model_lock = threading.Lock()
self._ev_request_reset_model = threading.Event()
self._ev_request_preview = threading.Event()
self._ev_request_export_model = threading.Event()
self._ev_request_save = threading.Event()
self._ev_request_quit = threading.Event()
# state vars
self._is_quit = False
self._is_device_changed = False
self._new_viewing_data = None
self._is_previewing_samples = False
self._new_preview_data : 'PreviewData' = None
self._last_save_time = None
# settings / params
self._model_data = None
self._device = None
self._device_info = None
self._batch_size = None
self._resolution = None
self._iteration = None
self._autosave_period = None
self._is_training = None
self._loss_history = {}
# Generators
self._training_generator = TrainingDataGenerator(faceset_path)
# Load and start
self.load()
threading.Thread(target=self.preview_thread_proc, daemon=True).start()
self._training_generator.set_running(True)
self.get_main_dlg().set_current()
self.main_loop()
# Finalizing
self._training_generator.stop()
dc.Diacon.stop()
def get_device_info(self): return self._device_info
def set_device_info(self, device_info : TorchDeviceInfo):
self._device = lib_torch.get_device(device_info)
self._device_info = device_info
self._is_device_changed = True
def get_batch_size(self) -> int: return self._batch_size
def set_batch_size(self, batch_size : int):
self._batch_size = batch_size
self._training_generator.set_batch_size(batch_size)
def get_resolution(self) -> int: return self._resolution
def set_resolution(self, resolution : int):
self._resolution = resolution
self._training_generator.set_resolution(resolution)
def get_iteration(self) -> int: return self._iteration
def set_iteration(self, iteration : int):
self._iteration = iteration
for key in self._loss_history.keys():
self._loss_history[key] = self._loss_history[key][:iteration]
def get_autosave_period(self): return self._autosave_period
def set_autosave_period(self, mins : int):
self._autosave_period = mins
def get_is_training(self) -> bool: return self._is_training
def set_is_training(self, training : bool):
if self._is_training != training:
if training:
self._last_save_time = time.time()
else:
self._last_save_time = None
self._is_training = training
def get_loss_history(self): return self._loss_history
def set_loss_history(self, lh): self._loss_history = lh
def load(self):
self._model_data = model_data = torch.load(self._model_data_path, map_location='cpu') if self._model_data_path.exists() else {}
self.set_device_info( lib_torch.get_device_info_by_index(model_data.get('device_index', -1)) )
self.set_batch_size( model_data.get('batch_size', 64) )
self.set_resolution( model_data.get('resolution', 224) )
self.set_iteration( model_data.get('iteration', 0) )
self.set_autosave_period( model_data.get('autosave_period', 25) )
self.set_is_training( model_data.get('training', False) )
self.set_loss_history ( model_data.get('loss_history', {}) )
self.reset_model(load=True)
def reset_model(self, load : bool = True):
while True:
model = tv.mobilenet.mobilenet_v3_large(num_classes=6)
model.train()
model.to(self._device)
model_optimizer = AdaBelief(model.parameters(), lr=5e-5, lr_dropout=0.3)
if load:
model_state_dict = self._model_data.get('model_state_dict', None)
if model_state_dict is not None:
try:
model.load_state_dict(model_state_dict)
model.to(self._device)
model_optimizer_state_dict = self._model_data.get('model_optimizer_state_dict', None)
if model_optimizer_state_dict is not None:
model_optimizer.load_state_dict(model_optimizer_state_dict)
except:
print('Network weights have been reseted.')
self._model_data['model_state_dict'] = None
self._model_data['model_optimizer_state_dict'] = None
continue
else:
print('Network weights have been reseted.')
break
self._model = model
self._model_optimizer = model_optimizer
def save(self):
if self._model_data is not None:
d = {'device_index' : self._device_info.get_index(),
'batch_size' : self.get_batch_size(),
'resolution' : self.get_resolution(),
'iteration' : self.get_iteration(),
'autosave_period' : self.get_autosave_period(),
'training' : self.get_is_training(),
'loss_history' : self.get_loss_history(),
'model_state_dict' : self._model.state_dict(),
'model_optimizer_state_dict': self._model_optimizer.state_dict(),
}
torch.save(d, self._model_data_path)
def export(self):
self._model.to('cpu')
self._model.eval()
torch.onnx.export( self._model,
(torch.from_numpy( np.zeros( (1,3,self._resolution,self._resolution), dtype=np.float32)),) ,
str(self._workspace_path / 'FaceAligner.onnx'),
verbose=True,
training=torch.onnx.TrainingMode.EVAL,
opset_version=12,
do_constant_folding=True,
input_names=['in'],
output_names=['mat'],
dynamic_axes={'in' : {0:'batch_size'}, 'mat' : {0:'batch_size'}}, )
self._model.to(self._device)
self._model.train()
def preview_thread_proc(self):
while not self._is_quit:
preview_data, self._new_preview_data = self._new_preview_data, None
if preview_data is not None:
# new preview data to show
data = preview_data.training_data
n = np.random.randint(data.batch_size)
img_aligned = data.img_aligned[n].transpose((1,2,0))
img_aligned_shifted = data.img_aligned_shifted[n].transpose((1,2,0))
H,W = img_aligned_shifted.shape[:2]
shift_mat = lib_math.Affine2DUniMat(data.shift_uni_mats[n]).invert().to_exact_mat(W,H, W, H)
shift_mat_pred = lib_math.Affine2DUniMat(preview_data.shift_uni_mats_pred[n]).invert().to_exact_mat(W,H, W, H)
img_aligned_unshifted = cv2.warpAffine(img_aligned_shifted, shift_mat, (W,H))
img_aligned_unshifted_pred = cv2.warpAffine(img_aligned_shifted, shift_mat_pred, (W,H))
screen = np.concatenate([img_aligned, img_aligned_shifted, img_aligned_unshifted, img_aligned_unshifted_pred], 1)
cv2.imshow('Preview', screen)
viewing_data, self._new_viewing_data = self._new_viewing_data, None
if viewing_data is not None:
n = np.random.randint(viewing_data.batch_size)
img_aligned_shifted = viewing_data.img_aligned_shifted[n].transpose((1,2,0))
screen = np.concatenate([img_aligned_shifted], 1)
cv2.imshow('Viewing samples', screen)
cv2.waitKey(5)
time.sleep(0.005)
def main_loop(self):
while not self._is_quit:
if self._ev_request_reset_model.is_set():
self._ev_request_reset_model.clear()
self.reset_model(load=False)
if self._is_device_changed:
self._is_device_changed = False
self._model.to(self._device)
self._model_optimizer.load_state_dict(self._model_optimizer.state_dict())
if self._is_training or \
self._is_previewing_samples or \
self._ev_request_preview.is_set():
training_data = self._training_generator.get_next_data(wait=False)
if training_data is not None and \
training_data.resolution == self.get_resolution(): # Skip if resolution is different, due to delay
if self._is_training:
self._model_optimizer.zero_grad()
if self._ev_request_preview.is_set() or \
self._is_training:
# Inference for both preview and training
img_aligned_shifted_t = torch.tensor(training_data.img_aligned_shifted).to(self._device)
shift_uni_mats_pred_t = self._model(img_aligned_shifted_t).view( (-1,2,3) )
if self._is_training:
# Training optimization step
shift_uni_mats_t = torch.tensor(training_data.shift_uni_mats).to(self._device)
loss_t = (shift_uni_mats_pred_t-shift_uni_mats_t).square().mean()*10.0
loss_t.backward()
self._model_optimizer.step()
loss = loss_t.detach().cpu().numpy()
rec_loss_history = self._loss_history.get('reconstruct', None)
if rec_loss_history is None:
rec_loss_history = self._loss_history['reconstruct'] = []
rec_loss_history.append(float(loss))
self.set_iteration( self.get_iteration() + 1 )
if self._ev_request_preview.is_set():
self._ev_request_preview.clear()
# Preview request
pd = PreviewData()
pd.training_data = training_data
pd.shift_uni_mats_pred = shift_uni_mats_pred_t.detach().cpu().numpy()
self._new_preview_data = pd
if self._is_previewing_samples:
self._new_viewing_data = training_data
if self._is_training:
if self._last_save_time is not None:
while (time.time()-self._last_save_time)/60 >= self._autosave_period:
self._last_save_time += self._autosave_period*60
self._ev_request_save.set()
if self._ev_request_export_model.is_set():
self._ev_request_export_model.clear()
print('Exporting...')
self.export()
print('Exporting done.')
dc.Diacon.update_dlg()
if self._ev_request_save.is_set():
self._ev_request_save.clear()
print('Saving...')
self.save()
print('Saving done.')
dc.Diacon.update_dlg()
if self._ev_request_quit.is_set():
self._ev_request_quit.clear()
self._is_quit = True
time.sleep(0.005)
def get_main_dlg(self):
last_loss = 0
rec_loss_history = self._loss_history.get('reconstruct', None)
if rec_loss_history is not None:
if len(rec_loss_history) != 0:
last_loss = rec_loss_history[-1]
return dc.DlgChoices([
dc.DlgChoice(short_name='sgm', row_def='| Sample generator menu.',
on_choose=lambda dlg: self.get_sample_generator_dlg(dlg).set_current()),
dc.DlgChoice(short_name='d', row_def=f'| Device | {self._device_info}',
on_choose=lambda dlg: self.get_training_device_dlg(dlg).set_current()),
dc.DlgChoice(short_name='i', row_def=f'| Iteration | {self.get_iteration()}',
on_choose=lambda dlg: self.get_iteration_dlg(parent_dlg=dlg).set_current() ),
dc.DlgChoice(short_name='l', row_def=f'| Print loss history | Last loss = {last_loss:.5f} ',
on_choose=self.on_main_dlg_print_loss_history ),
dc.DlgChoice(short_name='p', row_def='| Show current preview.',
on_choose=lambda dlg: (self._ev_request_preview.set(), dlg.recreate().set_current())),
dc.DlgChoice(short_name='t', row_def=f'| Training | {self._is_training}',
on_choose=lambda dlg: (self.set_is_training(not self.get_is_training()), dlg.recreate().set_current()) ),
dc.DlgChoice(short_name='reset', row_def='| Reset model.',
on_choose=lambda dlg: (self._ev_request_reset_model.set(), dlg.recreate().set_current()) ),
dc.DlgChoice(short_name='export', row_def='| Export model.',
on_choose=lambda dlg: self._ev_request_export_model.set() ),
dc.DlgChoice(short_name='se', row_def=f'| Autosave period | {self.get_autosave_period()} minutes',
on_choose=lambda dlg: self.get_autosave_period_dlg(dlg).set_current()),
dc.DlgChoice(short_name='s', row_def='| Save all.',
on_choose=lambda dlg: self._ev_request_save.set() ),
dc.DlgChoice(short_name='q', row_def='| Quit now.',
on_choose=self.on_main_dlg_quit )
], on_recreate=lambda dlg: self.get_main_dlg(),
top_rows_def='|c9 Main menu' )
def on_main_dlg_quit(self, dlg):
self._ev_request_quit.set()
def on_main_dlg_print_loss_history(self, dlg):
max_lines = 20
for key in self._loss_history.keys():
lh = self._loss_history[key]
print(f'Loss history for: {key}')
d = len(lh) // max_lines
lh_ar = np.array(lh[-d*max_lines:], np.float32)
lh_ar = lh_ar.reshape( (max_lines, d)).mean(-1)
print( '\n'.join( f'{value:.5f}' for value in lh_ar ) )
dlg.recreate().set_current()
def get_sample_generator_dlg(self, parent_dlg):
return dc.DlgChoices([
dc.DlgChoice(short_name='v', row_def=f'| Previewing samples | {self._is_previewing_samples}',
on_choose=self.on_sample_generator_dlg_previewing_last_samples,
),
dc.DlgChoice(short_name='r', row_def=f'| Running | {self._training_generator.is_running()}',
on_choose=lambda dlg: (self._training_generator.set_running(not self._training_generator.is_running()), dlg.recreate().set_current()) ),
],
on_recreate=lambda dlg: self.get_sample_generator_dlg(parent_dlg),
on_back =lambda dlg: parent_dlg.recreate().set_current(),
top_rows_def='|c9 Sample generator menu' )
def on_sample_generator_dlg_previewing_last_samples(self, dlg):
self._is_previewing_samples = not self._is_previewing_samples
dlg.recreate().set_current()
def get_training_dlg(self, parent_dlg):
return dc.DlgChoices([
],
on_recreate=lambda dlg: self.get_training_dlg(parent_dlg),
on_back =lambda dlg: parent_dlg.recreate().set_current(),
top_rows_def='|c9 Training menu' )
def get_autosave_period_dlg(self, parent_dlg):
return dc.DlgNumber(is_float=False, min_value=1,
on_value = lambda dlg, value: (self.set_autosave_period(value), parent_dlg.recreate().set_current()),
on_recreate = lambda dlg: self.get_autosave_period_dlg(parent_dlg),
on_back = lambda dlg: parent_dlg.recreate().set_current(),
top_rows_def='|c9 Set save every min', )
def get_iteration_dlg(self, parent_dlg):
return dc.DlgNumber(is_float=False, min_value=0,
on_value = lambda dlg, value: (self.set_iteration(value), parent_dlg.recreate().set_current()),
on_recreate = lambda dlg: self.get_iteration_dlg(parent_dlg),
on_back = lambda dlg: parent_dlg.recreate().set_current(),
top_rows_def='|c9 Set iteration', )
def get_training_device_dlg(self, parent_dlg):
return DlgTorchDevicesInfo(on_device_choice = lambda dlg, device_info: (self.set_device_info(device_info), parent_dlg.recreate().set_current()),
on_recreate = lambda dlg: self.get_training_device_dlg(parent_dlg),
on_back = lambda dlg: parent_dlg.recreate().set_current(),
top_rows_def='|c9 Choose device'
)
class DlgTorchDevicesInfo(dc.DlgChoices):
def __init__(self, on_device_choice : Callable = None,
on_device_multi_choice : Callable = None,
on_recreate = None,
on_back : Callable = None,
top_rows_def : Union[str, List[str]] = None,
bottom_rows_def : Union[str, List[str]] = None,):
devices = lib_torch.get_available_devices_info()
super().__init__(choices=[
dc.DlgChoice(short_name=f'{device.get_index()}' if not device.is_cpu() else 'c',
row_def= f"| {str(device.get_name())} " +
(f"| {(device.get_total_memory() / 1024**3) :.3}Gb" if not device.is_cpu() else ""),
on_choose= ( lambda dlg, i=i: on_device_choice(dlg, devices[i]) ) \
if on_device_choice is not None else None)
for i, device in enumerate(devices) ],
on_multi_choice=(lambda idxs: on_device_multi_choice([ devices[idx] for idx in idxs ])) \
if on_device_multi_choice is not None else None,
on_recreate=on_recreate, on_back=on_back,
top_rows_def=top_rows_def, bottom_rows_def=bottom_rows_def)
class PreviewData:
training_data : Data = None
shift_uni_mats_pred = None

View file

@ -0,0 +1,248 @@
import threading
import time
from collections import deque
from pathlib import Path
from typing import Any, List, Tuple, Union
import cv2
import numpy as np
from xlib import face as lib_face
from xlib import image as lib_img
from xlib import mp as lib_mp
from xlib import mt as lib_mt
from xlib.image import sd as lib_sd
class Data:
def __init__(self):
self.batch_size : int = None
self.resolution : int = None
self.img_aligned : np.ndarray = None
self.img_aligned_shifted : np.ndarray = None
self.shift_uni_mats : np.ndarray = None
class TrainingDataGenerator(lib_mp.MPWorker):
def __init__(self, faceset_path : Path):
faceset_path = Path(faceset_path)
if not faceset_path.exists():
raise Exception (f'{faceset_path} does not exist.')
super().__init__(sub_args=[faceset_path])
self._datas = [ deque() for _ in range(self.get_process_count())]
self._datas_counter = 0
self._running = False
def get_next_data(self, wait : bool) -> Union[Data, None]:
"""
wait and returns new generated data
"""
while True:
for _ in range(self.get_process_count()):
process_id, self._datas_counter = self._datas_counter % len(self._datas), self._datas_counter + 1
data = self._datas[process_id]
if len(data) != 0:
self._send_msg('data_received', process_id=process_id)
return data.popleft()
if not wait:
return None
time.sleep(0.005)
def is_running(self) -> bool: return self._running
def set_running(self, running : bool):
self._running = running
self._send_msg('running', running)
def set_batch_size(self, batch_size):
self._send_msg('batch_size', batch_size)
def set_resolution(self, resolution):
self._send_msg('resolution', resolution)
###### IMPL HOST
def _on_host_sub_message(self, process_id, name, *args, **kwargs):
"""
message from sub
"""
if name == 'data':
self._datas[process_id].append(args[0])
###### IMPL SUB
def _on_sub_initialize(self, faceset_path : Path):
self._fs = fs = lib_face.Faceset(faceset_path)
self._ufm_uuids = fs.get_all_UFaceMark_uuids()
self._ufm_uuid_indexes = []
self._sent_buffers_count = 0
self._running = False
self._batch_size = None
self._resolution = None
self._n_batch = 0
self._img_aligned_list = []
self._img_aligned_shifted_list = []
self._shift_mat_list = []
def _on_sub_finalize(self):
self._fs.close()
def _on_sub_host_message(self, name, *args, **kwargs):
"""
a message from host
"""
if name == 'data_received':
self._sent_buffers_count -= 1
elif name == 'batch_size':
self._batch_size, = args
elif name == 'resolution':
self._resolution, = args
elif name == 'running':
self._running , = args
####### IMPL SUB THREAD
def _on_sub_tick(self, process_id):
running = self._running
if running:
if self._batch_size is None:
print('Unable to start TrainingGenerator: batch_size must be set')
running = False
if self._resolution is None:
print('Unable to start TrainingGenerator: resolution must be set')
running = False
if running:
if self._sent_buffers_count < 2:
batch_size = self._batch_size
resolution = self._resolution
face_coverage = 1.0
rw_grid_cell_range = [3,7]
rw_grid_rot_deg_range = [-180,180]
rw_grid_scale_range = [-0.25, 2.5]
rw_grid_tx_range = [-0.50, 0.50]
rw_grid_ty_range = [-0.50, 0.50]
align_rot_deg_range = [-180,180]
align_scale_range = [0.0,2.5]
align_tx_range = [-0.50, 0.50]
align_ty_range = [-0.50, 0.50]
random_mask_complexity = 3
sharpen_chance = 25
motion_blur_chance = 25
gaussian_blur_chance = 25
reresize_chance = 25
recompress_chance = 25
img_aligned_list = []
img_aligned_shifted_list = []
shift_mat_list = []
if self._n_batch < batch_size:
# Make only 1 sample per tick
while True:
uuid1 = self._get_next_UFaceMark_uuid()
ufm1 = self._fs.get_UFaceMark_by_uuid(uuid1)
flmrks1 = ufm1.get_FLandmarks2D_best()
if flmrks1 is None:
print(f'Corrupted faceset, no FLandmarks2D for UFaceMark {ufm1.get_uuid()}')
continue
uimg1 = self._fs.get_UImage_by_uuid(ufm1.get_UImage_uuid())
if uimg1 is None:
print(f'Corrupted faceset, no UImage for UFaceMark {ufm1.get_uuid()}')
continue
img1 = uimg1.get_image()
if img1 is None:
print(f'Corrupted faceset, no image in UImage {uimg1.get_uuid()}')
continue
img_aligned, _ = flmrks1.cut(img1, face_coverage, resolution)
img_aligned = img_aligned.astype(np.float32) / 255.0
_, img_to_face_uni_mat1 = flmrks1.calc_cut( img1.shape[0:2], face_coverage, resolution)
fw1 = lib_face.FaceWarper(img_to_face_uni_mat1,
align_rot_deg=align_rot_deg_range,
align_scale=align_scale_range,
align_tx=align_tx_range,
align_ty=align_ty_range,
rw_grid_cell_count=rw_grid_cell_range,
rw_grid_rot_deg=rw_grid_rot_deg_range,
rw_grid_scale=rw_grid_scale_range,
rw_grid_tx=rw_grid_tx_range,
rw_grid_ty=rw_grid_ty_range,
)
img_aligned_shifted = fw1.transform(img1, resolution, random_warp=True).astype(np.float32) / 255.0
ip = lib_img.ImageProcessor(img_aligned_shifted)
rnd = np.random
if rnd.randint(2) == 0:
ip.hsv( rnd.randint(0, 360), rnd.uniform(-0.5,0.5), rnd.uniform(-0.5,0.5), mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity))
else:
ip.levels( [ [rnd.uniform(0,0.25),rnd.uniform(0.75,1.0),rnd.uniform(0.5,1.5), rnd.uniform(0,0.25),rnd.uniform(0.75,1.0), ],
[rnd.uniform(0,0.25),rnd.uniform(0.75,1.0),rnd.uniform(0.5,1.5), rnd.uniform(0,0.25),rnd.uniform(0.75,1.0),],
[rnd.uniform(0,0.25),rnd.uniform(0.75,1.0),rnd.uniform(0.5,1.5), rnd.uniform(0,0.25),rnd.uniform(0.75,1.0),], ], mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity))
if rnd.randint(2) == 0:
if rnd.randint(100) < sharpen_chance:
if rnd.randint(2) == 0:
ip.box_sharpen(size=rnd.randint(1,11), power=rnd.uniform(0.5,5.0), mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity) )
else:
ip.gaussian_sharpen(sigma=1.0, power=rnd.uniform(0.5,5.0), mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity) )
else:
if rnd.randint(100) < motion_blur_chance:
ip.motion_blur (size=rnd.randint(1,11), angle=rnd.randint(360), mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity))
if rnd.randint(100) < gaussian_blur_chance:
ip.gaussian_blur (sigma=rnd.uniform(0.5,3.0), mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity))
if np.random.randint(2) == 0:
if rnd.randint(100) < reresize_chance:
ip.reresize( rnd.uniform(0.0,0.75), interpolation=ip.Interpolation.NEAREST, mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity) )
if np.random.randint(2) == 0:
if rnd.randint(100) < reresize_chance:
ip.reresize( rnd.uniform(0.0,0.75), interpolation=ip.Interpolation.LINEAR, mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity) )
if rnd.randint(100) < recompress_chance:
ip.jpeg_recompress(quality=rnd.randint(10,75), mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity) )
img_aligned_shifted = ip.get_image('HWC')
self._img_aligned_list.append(img_aligned)
self._img_aligned_shifted_list.append(img_aligned_shifted)
self._shift_mat_list.append( fw1.get_aligned_random_transform_mat() )
self._n_batch += 1
break
if self._n_batch == batch_size:
data = Data()
data.batch_size = batch_size
data.resolution = resolution
data.img_aligned = np.array(self._img_aligned_list).transpose( (0,3,1,2))
data.img_aligned_shifted = np.array(self._img_aligned_shifted_list).transpose( (0,3,1,2))
data.shift_uni_mats = np.array(self._shift_mat_list)
self._send_msg('data', data)
self._sent_buffers_count +=1
self._n_batch = 0
self._img_aligned_list = []
self._img_aligned_shifted_list = []
self._shift_mat_list = []
def _get_next_UFaceMark_uuid(self) -> bytes:
if len(self._ufm_uuid_indexes) == 0:
self._ufm_uuid_indexes = [*range(len(self._ufm_uuids))]
np.random.shuffle(self._ufm_uuid_indexes)
idx = self._ufm_uuid_indexes.pop()
return self._ufm_uuids[idx]