diff --git a/apps/trainers/FaceAligner/FaceAlignerTrainer.py b/apps/trainers/FaceAligner/FaceAlignerTrainer.py
deleted file mode 100644
index 4c739d0..0000000
--- a/apps/trainers/FaceAligner/FaceAlignerTrainer.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from xlib import face as lib_face
-from xlib import time as lib_time
-
-
-class FaceAlignerTrainer:
-    def __init__(self, faceset_path):
-        #fs = self._fs = lib_face.Faceset(faceset_path)
-        fs = lib_face.Faceset(faceset_path)
-        #fs.close()
-
-        with lib_time.timeit():
-            for x in fs.iter_UImage():
-                x.get_image()
-        #fs = lib_face.Faceset(faceset_path)
-        #fs.add_UFaceMark( [ lib_face.UFaceMark() for _ in range(1000)] )
-
-        import code
-        code.interact(local=dict(globals(), **locals()))
-
-
-    def run(self):
-        ...
\ No newline at end of file
diff --git a/apps/trainers/FaceAligner/FaceAlignerTrainerApp.py b/apps/trainers/FaceAligner/FaceAlignerTrainerApp.py
new file mode 100644
index 0000000..da70167
--- /dev/null
+++ b/apps/trainers/FaceAligner/FaceAlignerTrainerApp.py
@@ -0,0 +1,449 @@
+import threading
+import time
+from pathlib import Path
+from typing import Any, Callable, List, Tuple, Union
+
+import cv2
+import numpy as np
+import torch
+import torch.autograd
+from localization import L, Localization
+from modelhub import torch as torch_models
+from xlib import torch as lib_torch
+from xlib.console import diacon as dc
+from xlib.torch.optim import AdaBelief
+from xlib.torch.device import TorchDeviceInfo
+import torchvision.models as tv
+from xlib import math as lib_math
+from .TrainingDataGenerator import Data, TrainingDataGenerator
+
+class FaceAlignerTrainerApp:
+    def __init__(self, workspace_path : Path, faceset_path : Path):
+        print('Initializing trainer.\n')
+        print(f'Workspace path: {workspace_path}')
+        print(f'Faceset path: {faceset_path}\n')
+
+        workspace_path.mkdir(parents=True, exist_ok=True)
+        self._workspace_path = workspace_path
+        self._faceset_path = faceset_path
+        self._model_data_path = workspace_path / 'model.dat'
+
+        # system vars
+        self._model_lock = threading.Lock()
+        self._ev_request_reset_model = threading.Event()
+        self._ev_request_preview = threading.Event()
+        self._ev_request_export_model = threading.Event()
+        self._ev_request_save = threading.Event()
+        self._ev_request_quit = threading.Event()
+
+        # state vars
+        self._is_quit = False
+        self._is_device_changed = False
+        self._new_viewing_data = None
+        self._is_previewing_samples = False
+        self._new_preview_data : 'PreviewData' = None
+        self._last_save_time = None
+
+        # settings / params
+        self._model_data = None
+        self._device = None
+        self._device_info = None
+        self._batch_size = None
+        self._resolution = None
+        self._iteration = None
+        self._autosave_period = None
+        self._is_training = None
+        self._loss_history = {}
+
+        # Generators
+        self._training_generator = TrainingDataGenerator(faceset_path)
+
+        # Load and start
+        self.load()
+        threading.Thread(target=self.preview_thread_proc, daemon=True).start()
+        self._training_generator.set_running(True)
+
+        self.get_main_dlg().set_current()
+        self.main_loop()
+
+        # Finalizing
+        self._training_generator.stop()
+        dc.Diacon.stop()
+
+    def get_device_info(self): return self._device_info
+    def set_device_info(self, device_info : TorchDeviceInfo):
+        self._device = lib_torch.get_device(device_info)
+        self._device_info = device_info
+        self._is_device_changed = True
+
+    def get_batch_size(self) -> int: return self._batch_size
+    def set_batch_size(self, batch_size : int):
+        self._batch_size = batch_size
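+        # keep the generator worker processes in sync with the new batch size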
+        self._training_generator.set_batch_size(batch_size)
+
+    def get_resolution(self) -> int: return self._resolution
+    def set_resolution(self, resolution : int):
+        self._resolution = resolution
+        self._training_generator.set_resolution(resolution)
+
+    def get_iteration(self) -> int: return self._iteration
+    def set_iteration(self, iteration : int):
+        self._iteration = iteration
+        for key in self._loss_history.keys():
+            self._loss_history[key] = self._loss_history[key][:iteration]
+
+    def get_autosave_period(self): return self._autosave_period
+    def set_autosave_period(self, mins : int):
+        self._autosave_period = mins
+
+    def get_is_training(self) -> bool: return self._is_training
+    def set_is_training(self, training : bool):
+        if self._is_training != training:
+            if training:
+                self._last_save_time = time.time()
+            else:
+                self._last_save_time = None
+            self._is_training = training
+
+    def get_loss_history(self): return self._loss_history
+    def set_loss_history(self, lh): self._loss_history = lh
+
+    def load(self):
+        self._model_data = model_data = torch.load(self._model_data_path, map_location='cpu') if self._model_data_path.exists() else {}
+
+        self.set_device_info( lib_torch.get_device_info_by_index(model_data.get('device_index', -1)) )
+        self.set_batch_size( model_data.get('batch_size', 64) )
+        self.set_resolution( model_data.get('resolution', 224) )
+        self.set_iteration( model_data.get('iteration', 0) )
+        self.set_autosave_period( model_data.get('autosave_period', 25) )
+        self.set_is_training( model_data.get('training', False) )
+        self.set_loss_history( model_data.get('loss_history', {}) )
+
+        self.reset_model(load=True)
+
+    def reset_model(self, load : bool = True):
+        while True:
+            model = tv.mobilenet.mobilenet_v3_large(num_classes=6)
+
+            model.train()
+            model.to(self._device)
+
+            model_optimizer = AdaBelief(model.parameters(), lr=5e-5, lr_dropout=0.3)
+
+            if load:
+                model_state_dict = self._model_data.get('model_state_dict', None)
+                if model_state_dict is not None:
+                    try:
+                        model.load_state_dict(model_state_dict)
+                        model.to(self._device)
+
+                        model_optimizer_state_dict = self._model_data.get('model_optimizer_state_dict', None)
+                        if model_optimizer_state_dict is not None:
+                            model_optimizer.load_state_dict(model_optimizer_state_dict)
+
+                    except Exception:
+                        print('Network weights have been reset.')
+                        self._model_data['model_state_dict'] = None
+                        self._model_data['model_optimizer_state_dict'] = None
+                        continue
+            else:
+                print('Network weights have been reset.')
+            break
+
+        self._model = model
+        self._model_optimizer = model_optimizer
+
+    def save(self):
+        if self._model_data is not None:
+            d = {'device_index'              : self._device_info.get_index(),
+                 'batch_size'                : self.get_batch_size(),
+                 'resolution'                : self.get_resolution(),
+                 'iteration'                 : self.get_iteration(),
+                 'autosave_period'           : self.get_autosave_period(),
+                 'training'                  : self.get_is_training(),
+                 'loss_history'              : self.get_loss_history(),
+                 'model_state_dict'          : self._model.state_dict(),
+                 'model_optimizer_state_dict': self._model_optimizer.state_dict(),
+                }
+
+            torch.save(d, self._model_data_path)
+
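+    # Exports the model to ONNX. The batch axis of input 'in' and output 'mat'
+    # is declared dynamic, so FaceAligner.onnx accepts any batch size at runtime.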
+    def export(self):
+        self._model.to('cpu')
+        self._model.eval()
+
+        torch.onnx.export( self._model,
+                           (torch.from_numpy( np.zeros( (1,3,self._resolution,self._resolution), dtype=np.float32)),),
+                           str(self._workspace_path / 'FaceAligner.onnx'),
+                           verbose=True,
+                           training=torch.onnx.TrainingMode.EVAL,
+                           opset_version=12,
+                           do_constant_folding=True,
+                           input_names=['in'],
+                           output_names=['mat'],
+                           dynamic_axes={'in' : {0:'batch_size'}, 'mat' : {0:'batch_size'}}, )
+
+        self._model.to(self._device)
+        self._model.train()
+
+    def preview_thread_proc(self):
+        while not self._is_quit:
+            preview_data, self._new_preview_data = self._new_preview_data, None
+            if preview_data is not None:
+                # new preview data to show
+                data = preview_data.training_data
+                n = np.random.randint(data.batch_size)
+                img_aligned = data.img_aligned[n].transpose((1,2,0))
+                img_aligned_shifted = data.img_aligned_shifted[n].transpose((1,2,0))
+
+                H,W = img_aligned_shifted.shape[:2]
+
+                shift_mat = lib_math.Affine2DUniMat(data.shift_uni_mats[n]).invert().to_exact_mat(W,H, W,H)
+                shift_mat_pred = lib_math.Affine2DUniMat(preview_data.shift_uni_mats_pred[n]).invert().to_exact_mat(W,H, W,H)
+
+                img_aligned_unshifted = cv2.warpAffine(img_aligned_shifted, shift_mat, (W,H))
+                img_aligned_unshifted_pred = cv2.warpAffine(img_aligned_shifted, shift_mat_pred, (W,H))
+
+                screen = np.concatenate([img_aligned, img_aligned_shifted, img_aligned_unshifted, img_aligned_unshifted_pred], 1)
+                cv2.imshow('Preview', screen)
+
+            viewing_data, self._new_viewing_data = self._new_viewing_data, None
+            if viewing_data is not None:
+                n = np.random.randint(viewing_data.batch_size)
+                img_aligned_shifted = viewing_data.img_aligned_shifted[n].transpose((1,2,0))
+                cv2.imshow('Viewing samples', img_aligned_shifted)
+
+            cv2.waitKey(5)
+            time.sleep(0.005)
+
+    def main_loop(self):
+        while not self._is_quit:
+            if self._ev_request_reset_model.is_set():
+                self._ev_request_reset_model.clear()
+                self.reset_model(load=False)
+
+            if self._is_device_changed:
+                self._is_device_changed = False
+                self._model.to(self._device)
+                # reloading the optimizer's own state dict casts its state tensors
+                # to the device of the (already moved) model parameters
+                self._model_optimizer.load_state_dict(self._model_optimizer.state_dict())
+
+            if self._is_training or \
+               self._is_previewing_samples or \
+               self._ev_request_preview.is_set():
+                training_data = self._training_generator.get_next_data(wait=False)
+                if training_data is not None and \
+                   training_data.resolution == self.get_resolution(): # skip stale data generated for a previous resolution
+
+                    if self._is_training:
+                        self._model_optimizer.zero_grad()
+
+                    if self._ev_request_preview.is_set() or \
+                       self._is_training:
+                        # Inference for both preview and training
+                        img_aligned_shifted_t = torch.tensor(training_data.img_aligned_shifted).to(self._device)
+                        shift_uni_mats_pred_t = self._model(img_aligned_shifted_t).view( (-1,2,3) )
+
+                    if self._is_training:
+                        # Training optimization step
+                        shift_uni_mats_t = torch.tensor(training_data.shift_uni_mats).to(self._device)
+
+                        loss_t = (shift_uni_mats_pred_t-shift_uni_mats_t).square().mean()*10.0
+                        loss_t.backward()
+                        self._model_optimizer.step()
+
+                        loss = loss_t.detach().cpu().numpy()
+
+                        rec_loss_history = self._loss_history.get('reconstruct', None)
+                        if rec_loss_history is None:
+                            rec_loss_history = self._loss_history['reconstruct'] = []
+                        rec_loss_history.append(float(loss))
+
+                        self.set_iteration( self.get_iteration() + 1 )
+
+                    if self._ev_request_preview.is_set():
+                        self._ev_request_preview.clear()
+                        # Preview request
+                        pd = PreviewData()
+                        pd.training_data = training_data
+                        pd.shift_uni_mats_pred = shift_uni_mats_pred_t.detach().cpu().numpy()
+                        self._new_preview_data = pd
+
+                    if self._is_previewing_samples:
+                        self._new_viewing_data = training_data
+
+            if self._is_training:
+                if self._last_save_time is not None:
+                    # catch up on missed autosaves, one period at a time
+                    while (time.time()-self._last_save_time)/60 >= self._autosave_period:
+                        self._last_save_time += self._autosave_period*60
+                        self._ev_request_save.set()
+
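+            # one-shot requests raised from the dialog thread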
+            if self._ev_request_export_model.is_set():
+                self._ev_request_export_model.clear()
+                print('Exporting...')
+                self.export()
+                print('Exporting done.')
+                dc.Diacon.update_dlg()
+
+            if self._ev_request_save.is_set():
+                self._ev_request_save.clear()
+                print('Saving...')
+                self.save()
+                print('Saving done.')
+                dc.Diacon.update_dlg()
+
+            if self._ev_request_quit.is_set():
+                self._ev_request_quit.clear()
+                self._is_quit = True
+
+            time.sleep(0.005)
+
+    def get_main_dlg(self):
+        last_loss = 0
+        rec_loss_history = self._loss_history.get('reconstruct', None)
+        if rec_loss_history is not None:
+            if len(rec_loss_history) != 0:
+                last_loss = rec_loss_history[-1]
+
+        return dc.DlgChoices([
+            dc.DlgChoice(short_name='sgm', row_def='| Sample generator menu.',
+                         on_choose=lambda dlg: self.get_sample_generator_dlg(dlg).set_current()),
+
+            dc.DlgChoice(short_name='d', row_def=f'| Device | {self._device_info}',
+                         on_choose=lambda dlg: self.get_training_device_dlg(dlg).set_current()),
+
+            dc.DlgChoice(short_name='i', row_def=f'| Iteration | {self.get_iteration()}',
+                         on_choose=lambda dlg: self.get_iteration_dlg(parent_dlg=dlg).set_current() ),
+
+            dc.DlgChoice(short_name='l', row_def=f'| Print loss history | Last loss = {last_loss:.5f} ',
+                         on_choose=self.on_main_dlg_print_loss_history ),
+
+            dc.DlgChoice(short_name='p', row_def='| Show current preview.',
+                         on_choose=lambda dlg: (self._ev_request_preview.set(), dlg.recreate().set_current())),
+
+            dc.DlgChoice(short_name='t', row_def=f'| Training | {self._is_training}',
+                         on_choose=lambda dlg: (self.set_is_training(not self.get_is_training()), dlg.recreate().set_current()) ),
+
+            dc.DlgChoice(short_name='reset', row_def='| Reset model.',
+                         on_choose=lambda dlg: (self._ev_request_reset_model.set(), dlg.recreate().set_current()) ),
+
+            dc.DlgChoice(short_name='export', row_def='| Export model.',
+                         on_choose=lambda dlg: self._ev_request_export_model.set() ),
+
+            dc.DlgChoice(short_name='se', row_def=f'| Autosave period | {self.get_autosave_period()} minutes',
+                         on_choose=lambda dlg: self.get_autosave_period_dlg(dlg).set_current()),
+
+            dc.DlgChoice(short_name='s', row_def='| Save all.',
+                         on_choose=lambda dlg: self._ev_request_save.set() ),
+
+            dc.DlgChoice(short_name='q', row_def='| Quit now.',
+                         on_choose=self.on_main_dlg_quit )
+            ], on_recreate=lambda dlg: self.get_main_dlg(),
+               top_rows_def='|c9 Main menu' )
+
+    def on_main_dlg_quit(self, dlg):
+        self._ev_request_quit.set()
+
+    def on_main_dlg_print_loss_history(self, dlg):
+        max_lines = 20
+        for key in self._loss_history.keys():
+            lh = self._loss_history[key]
+
+            print(f'Loss history for: {key}')
+
+            d = len(lh) // max_lines
+            if d != 0:
+                # average the most recent d*max_lines values down to max_lines rows
+                lh_ar = np.array(lh[-d*max_lines:], np.float32)
+                lh_ar = lh_ar.reshape( (max_lines, d) ).mean(-1)
+            else:
+                # fewer than max_lines values recorded, print them as-is
+                lh_ar = np.array(lh, np.float32)
+
+            print( '\n'.join( f'{value:.5f}' for value in lh_ar ) )
+
+        dlg.recreate().set_current()
+
+    def get_sample_generator_dlg(self, parent_dlg):
+        return dc.DlgChoices([
+            dc.DlgChoice(short_name='v', row_def=f'| Previewing samples | {self._is_previewing_samples}',
+                         on_choose=self.on_sample_generator_dlg_previewing_last_samples,
+                         ),
+
+            dc.DlgChoice(short_name='r', row_def=f'| Running | {self._training_generator.is_running()}',
+                         on_choose=lambda dlg: (self._training_generator.set_running(not self._training_generator.is_running()), dlg.recreate().set_current()) ),
+            ],
+            on_recreate=lambda dlg: self.get_sample_generator_dlg(parent_dlg),
+            on_back    =lambda dlg: parent_dlg.recreate().set_current(),
+            top_rows_def='|c9 Sample generator menu' )
+
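+    # Dialog callbacks follow one pattern: mutate the state, then recreate the
+    # dialog so it is redrawn with the updated values.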
+    def on_sample_generator_dlg_previewing_last_samples(self, dlg):
+        self._is_previewing_samples = not self._is_previewing_samples
+        dlg.recreate().set_current()
+
+    def get_training_dlg(self, parent_dlg):
+        return dc.DlgChoices([
+            ],
+            on_recreate=lambda dlg: self.get_training_dlg(parent_dlg),
+            on_back    =lambda dlg: parent_dlg.recreate().set_current(),
+            top_rows_def='|c9 Training menu' )
+
+    def get_autosave_period_dlg(self, parent_dlg):
+        return dc.DlgNumber(is_float=False, min_value=1,
+                            on_value    = lambda dlg, value: (self.set_autosave_period(value), parent_dlg.recreate().set_current()),
+                            on_recreate = lambda dlg: self.get_autosave_period_dlg(parent_dlg),
+                            on_back     = lambda dlg: parent_dlg.recreate().set_current(),
+                            top_rows_def='|c9 Set autosave period in minutes', )
+
+    def get_iteration_dlg(self, parent_dlg):
+        return dc.DlgNumber(is_float=False, min_value=0,
+                            on_value    = lambda dlg, value: (self.set_iteration(value), parent_dlg.recreate().set_current()),
+                            on_recreate = lambda dlg: self.get_iteration_dlg(parent_dlg),
+                            on_back     = lambda dlg: parent_dlg.recreate().set_current(),
+                            top_rows_def='|c9 Set iteration', )
+
+    def get_training_device_dlg(self, parent_dlg):
+        return DlgTorchDevicesInfo(on_device_choice = lambda dlg, device_info: (self.set_device_info(device_info), parent_dlg.recreate().set_current()),
+                                   on_recreate      = lambda dlg: self.get_training_device_dlg(parent_dlg),
+                                   on_back          = lambda dlg: parent_dlg.recreate().set_current(),
+                                   top_rows_def='|c9 Choose device' )
+
+
+class DlgTorchDevicesInfo(dc.DlgChoices):
+    def __init__(self, on_device_choice : Callable = None,
+                       on_device_multi_choice : Callable = None,
+                       on_recreate = None,
+                       on_back : Callable = None,
+                       top_rows_def : Union[str, List[str]] = None,
+                       bottom_rows_def : Union[str, List[str]] = None,):
+        devices = lib_torch.get_available_devices_info()
+        super().__init__(choices=[ dc.DlgChoice(short_name=f'{device.get_index()}' if not device.is_cpu() else 'c',
+                                                row_def= f"| {str(device.get_name())} " +
+                                                         (f"| {(device.get_total_memory() / 1024**3) :.3}Gb" if not device.is_cpu() else ""),
+                                                on_choose= ( lambda dlg, i=i: on_device_choice(dlg, devices[i]) ) \
+                                                           if on_device_choice is not None else None)
+                                   for i, device in enumerate(devices) ],
+                         on_multi_choice=(lambda idxs: on_device_multi_choice([ devices[idx] for idx in idxs ])) \
+                                         if on_device_multi_choice is not None else None,
+                         on_recreate=on_recreate, on_back=on_back,
+                         top_rows_def=top_rows_def, bottom_rows_def=bottom_rows_def)
+
+
+class PreviewData:
+    training_data : Data = None
+    shift_uni_mats_pred = None
diff --git a/apps/trainers/FaceAligner/TrainingDataGenerator.py b/apps/trainers/FaceAligner/TrainingDataGenerator.py
new file mode 100644
index 0000000..119c9e5
--- /dev/null
+++ b/apps/trainers/FaceAligner/TrainingDataGenerator.py
@@ -0,0 +1,248 @@
+import threading
+import time
+from collections import deque
+from pathlib import Path
+from typing import Any, List, Tuple, Union
+
+import cv2
+import numpy as np
+from xlib import face as lib_face
+from xlib import image as lib_img
+from xlib import mp as lib_mp
+from xlib import mt as lib_mt
+from xlib.image import sd as lib_sd
+
+
+class Data:
+    def __init__(self):
+        self.batch_size : int = None
+        self.resolution : int = None
+        self.img_aligned : np.ndarray = None
+        self.img_aligned_shifted : np.ndarray = None
+        self.shift_uni_mats : np.ndarray = None
+
+class TrainingDataGenerator(lib_mp.MPWorker):
+    def __init__(self, faceset_path : Path):
+        faceset_path = Path(faceset_path)
+        if not faceset_path.exists():
+            raise Exception(f'{faceset_path} does not exist.')
+
+        super().__init__(sub_args=[faceset_path])
+
+        self._datas = [ deque() for _ in range(self.get_process_count()) ]
+        self._datas_counter = 0
+        self._running = False
+
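+    # The host polls the per-process deques in round-robin order; each
+    # 'data_received' acknowledgement lets that worker queue another batch.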
+    def get_next_data(self, wait : bool) -> Union[Data, None]:
+        """
+        Returns the next generated batch, polling the worker processes in
+        round-robin order. If wait is False and no data is ready, returns
+        None immediately instead of blocking.
+        """
+        while True:
+            for _ in range(self.get_process_count()):
+                process_id, self._datas_counter = self._datas_counter % len(self._datas), self._datas_counter + 1
+                data = self._datas[process_id]
+
+                if len(data) != 0:
+                    self._send_msg('data_received', process_id=process_id)
+                    return data.popleft()
+
+            if not wait:
+                return None
+            time.sleep(0.005)
+
+    def is_running(self) -> bool: return self._running
+    def set_running(self, running : bool):
+        self._running = running
+        self._send_msg('running', running)
+
+    def set_batch_size(self, batch_size):
+        self._send_msg('batch_size', batch_size)
+
+    def set_resolution(self, resolution):
+        self._send_msg('resolution', resolution)
+
+    ###### IMPL HOST
+    def _on_host_sub_message(self, process_id, name, *args, **kwargs):
+        """
+        message from sub
+        """
+        if name == 'data':
+            self._datas[process_id].append(args[0])
+
+    ###### IMPL SUB
+    def _on_sub_initialize(self, faceset_path : Path):
+        self._fs = fs = lib_face.Faceset(faceset_path)
+        self._ufm_uuids = fs.get_all_UFaceMark_uuids()
+        self._ufm_uuid_indexes = []
+        self._sent_buffers_count = 0
+        self._running = False
+        self._batch_size = None
+        self._resolution = None
+
+        self._n_batch = 0
+        self._img_aligned_list = []
+        self._img_aligned_shifted_list = []
+        self._shift_mat_list = []
+
+    def _on_sub_finalize(self):
+        self._fs.close()
+
+    def _on_sub_host_message(self, name, *args, **kwargs):
+        """
+        a message from host
+        """
+        if name == 'data_received':
+            self._sent_buffers_count -= 1
+        elif name == 'batch_size':
+            self._batch_size, = args
+        elif name == 'resolution':
+            self._resolution, = args
+        elif name == 'running':
+            self._running, = args
+
+    ####### IMPL SUB THREAD
+    def _on_sub_tick(self, process_id):
+        running = self._running
+        if running:
+            if self._batch_size is None:
+                print('Unable to start TrainingGenerator: batch_size must be set')
+                running = False
+            if self._resolution is None:
+                print('Unable to start TrainingGenerator: resolution must be set')
+                running = False
+
+        if running:
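+            # produce ahead of the host by at most 2 unacknowledged batches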
+            if self._sent_buffers_count < 2:
+                batch_size = self._batch_size
+                resolution = self._resolution
+                face_coverage = 1.0
+
+                rw_grid_cell_range    = [3,7]
+                rw_grid_rot_deg_range = [-180,180]
+                rw_grid_scale_range   = [-0.25, 2.5]
+                rw_grid_tx_range      = [-0.50, 0.50]
+                rw_grid_ty_range      = [-0.50, 0.50]
+
+                align_rot_deg_range = [-180,180]
+                align_scale_range   = [0.0,2.5]
+                align_tx_range      = [-0.50, 0.50]
+                align_ty_range      = [-0.50, 0.50]
+
+                random_mask_complexity = 3
+                sharpen_chance       = 25
+                motion_blur_chance   = 25
+                gaussian_blur_chance = 25
+                reresize_chance      = 25
+                recompress_chance    = 25
+
+                if self._n_batch < batch_size:
+                    # Make only 1 sample per tick
+                    while True:
+                        uuid1 = self._get_next_UFaceMark_uuid()
+
+                        ufm1 = self._fs.get_UFaceMark_by_uuid(uuid1)
+
+                        flmrks1 = ufm1.get_FLandmarks2D_best()
+                        if flmrks1 is None:
+                            print(f'Corrupted faceset, no FLandmarks2D for UFaceMark {ufm1.get_uuid()}')
+                            continue
+
+                        uimg1 = self._fs.get_UImage_by_uuid(ufm1.get_UImage_uuid())
+                        if uimg1 is None:
+                            print(f'Corrupted faceset, no UImage for UFaceMark {ufm1.get_uuid()}')
+                            continue
+
+                        img1 = uimg1.get_image()
+                        if img1 is None:
+                            print(f'Corrupted faceset, no image in UImage {uimg1.get_uuid()}')
+                            continue
+
+                        img_aligned, _ = flmrks1.cut(img1, face_coverage, resolution)
+                        img_aligned = img_aligned.astype(np.float32) / 255.0
+
+                        _, img_to_face_uni_mat1 = flmrks1.calc_cut( img1.shape[0:2], face_coverage, resolution)
+
+                        fw1 = lib_face.FaceWarper(img_to_face_uni_mat1,
+                                                  align_rot_deg=align_rot_deg_range,
+                                                  align_scale=align_scale_range,
+                                                  align_tx=align_tx_range,
+                                                  align_ty=align_ty_range,
+                                                  rw_grid_cell_count=rw_grid_cell_range,
+                                                  rw_grid_rot_deg=rw_grid_rot_deg_range,
+                                                  rw_grid_scale=rw_grid_scale_range,
+                                                  rw_grid_tx=rw_grid_tx_range,
+                                                  rw_grid_ty=rw_grid_ty_range,
+                                                  )
+
+                        img_aligned_shifted = fw1.transform(img1, resolution, random_warp=True).astype(np.float32) / 255.0
+
+                        ip = lib_img.ImageProcessor(img_aligned_shifted)
+                        rnd = np.random
+                        if rnd.randint(2) == 0:
+                            ip.hsv( rnd.randint(0, 360), rnd.uniform(-0.5,0.5), rnd.uniform(-0.5,0.5), mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity))
+                        else:
+                            ip.levels( [ [rnd.uniform(0,0.25),rnd.uniform(0.75,1.0),rnd.uniform(0.5,1.5), rnd.uniform(0,0.25),rnd.uniform(0.75,1.0), ],
+                                         [rnd.uniform(0,0.25),rnd.uniform(0.75,1.0),rnd.uniform(0.5,1.5), rnd.uniform(0,0.25),rnd.uniform(0.75,1.0),],
+                                         [rnd.uniform(0,0.25),rnd.uniform(0.75,1.0),rnd.uniform(0.5,1.5), rnd.uniform(0,0.25),rnd.uniform(0.75,1.0),], ], mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity))
+
+                        if rnd.randint(2) == 0:
+                            if rnd.randint(100) < sharpen_chance:
+                                if rnd.randint(2) == 0:
+                                    ip.box_sharpen(size=rnd.randint(1,11), power=rnd.uniform(0.5,5.0), mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity) )
+                                else:
+                                    ip.gaussian_sharpen(sigma=1.0, power=rnd.uniform(0.5,5.0), mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity) )
+                        else:
+                            if rnd.randint(100) < motion_blur_chance:
+                                ip.motion_blur (size=rnd.randint(1,11), angle=rnd.randint(360), mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity))
+                            if rnd.randint(100) < gaussian_blur_chance:
+                                ip.gaussian_blur (sigma=rnd.uniform(0.5,3.0), mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity))
+
+                        if rnd.randint(2) == 0:
+                            if rnd.randint(100) < reresize_chance:
+                                ip.reresize( rnd.uniform(0.0,0.75), interpolation=ip.Interpolation.NEAREST, mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity) )
+                        if rnd.randint(2) == 0:
+                            if rnd.randint(100) < reresize_chance:
+                                ip.reresize( rnd.uniform(0.0,0.75), interpolation=ip.Interpolation.LINEAR, mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity) )
+                        if rnd.randint(100) < recompress_chance:
+                            ip.jpeg_recompress(quality=rnd.randint(10,75), mask=lib_sd.random_circle_faded_multi((resolution,resolution), complexity=random_mask_complexity) )
+
+                        img_aligned_shifted = ip.get_image('HWC')
+
+                        self._img_aligned_list.append(img_aligned)
+                        self._img_aligned_shifted_list.append(img_aligned_shifted)
+                        self._shift_mat_list.append( fw1.get_aligned_random_transform_mat() )
+                        self._n_batch += 1
+                        break
+
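+                # a full batch has accumulated: pack the HWC samples into
+                # NCHW arrays and ship them to the host process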
+                if self._n_batch == batch_size:
+                    data = Data()
+                    data.batch_size = batch_size
+                    data.resolution = resolution
+                    data.img_aligned = np.array(self._img_aligned_list).transpose( (0,3,1,2) )
+                    data.img_aligned_shifted = np.array(self._img_aligned_shifted_list).transpose( (0,3,1,2) )
+                    data.shift_uni_mats = np.array(self._shift_mat_list)
+
+                    self._send_msg('data', data)
+                    self._sent_buffers_count += 1
+
+                    self._n_batch = 0
+                    self._img_aligned_list = []
+                    self._img_aligned_shifted_list = []
+                    self._shift_mat_list = []
+
+    def _get_next_UFaceMark_uuid(self) -> bytes:
+        # epoch-style shuffle: visit every UFaceMark once before reshuffling
+        if len(self._ufm_uuid_indexes) == 0:
+            self._ufm_uuid_indexes = [*range(len(self._ufm_uuids))]
+            np.random.shuffle(self._ufm_uuid_indexes)
+        idx = self._ufm_uuid_indexes.pop()
+        return self._ufm_uuids[idx]