diff --git a/models/Model_AVATAR/Model.py b/models/Model_AVATAR/Model.py index 2486b07..3aef73d 100644 --- a/models/Model_AVATAR/Model.py +++ b/models/Model_AVATAR/Model.py @@ -58,13 +58,12 @@ class AVATARModel(ModelBase): self.D = modelify(AVATARModel.Discriminator() ) (Input(df_bgr_shape)) self.C = modelify(AVATARModel.ResNet (9, n_blocks=6, ngf=128, use_dropout=False))( Input(res_bgr_t_shape)) - if self.is_first_run(): - conv_weights_list = [] + self.CA_conv_weights_list = [] + if self.is_first_run(): for model, _ in self.get_model_filename_list(): for layer in model.layers: if type(layer) == keras.layers.Conv2D: - conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights - CAInitializerMP ( conv_weights_list ) + self.CA_conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights if not self.is_first_run(): self.load_weights_safe( self.get_model_filename_list() ) @@ -247,7 +246,14 @@ class AVATARModel(ModelBase): #override def onSave(self): self.save_weights_safe( self.get_model_filename_list() ) - + + #override + def on_success_train_one_iter(self): + if len(self.CA_conv_weights_list) != 0: + exec(nnlib.import_all(), locals(), globals()) + CAInitializerMP ( self.CA_conv_weights_list ) + self.CA_conv_weights_list = [] + #override def onTrainOneIter(self, generators_samples, generators_list): warped_src64, src64, src64m = generators_samples[0] diff --git a/samplelib/SampleGeneratorFace.py b/samplelib/SampleGeneratorFace.py index edee7ea..a76e006 100644 --- a/samplelib/SampleGeneratorFace.py +++ b/samplelib/SampleGeneratorFace.py @@ -1,9 +1,9 @@ import multiprocessing import traceback - +import pickle import cv2 import numpy as np - +import time from facelib import LandmarksProcessor from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor, SampleType) @@ -23,6 +23,7 @@ class SampleGeneratorFace(SampleGeneratorBase): sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, + generators_count=4, **kwargs): super().__init__(samples_path, debug, batch_size) @@ -33,10 +34,10 @@ class SampleGeneratorFace(SampleGeneratorBase): if self.debug: self.generators_count = 1 else: - self.generators_count = np.clip(multiprocessing.cpu_count(), 2, 6) + self.generators_count = np.clip(multiprocessing.cpu_count(), 2, generators_count) - samples_clis = SampleHost.host (SampleType.FACE, self.samples_path, number_of_clis=self.generators_count) - self.samples_len = len(samples_clis[0]) + samples = SampleHost.load (SampleType.FACE, self.samples_path) + self.samples_len = len(samples) if self.samples_len == 0: raise ValueError('No training data provided.') @@ -44,16 +45,19 @@ class SampleGeneratorFace(SampleGeneratorBase): index_host = mp_utils.IndexHost(self.samples_len) if random_ct_samples_path is not None: - ct_samples_clis = SampleHost.host (SampleType.FACE, random_ct_samples_path, number_of_clis=self.generators_count) - ct_index_host = mp_utils.IndexHost( len(ct_samples_clis[0]) ) + ct_samples = SampleHost.load (SampleType.FACE, random_ct_samples_path) + ct_index_host = mp_utils.IndexHost( len(ct_samples) ) else: - ct_samples_clis = None + ct_samples = None ct_index_host = None + pickled_samples = pickle.dumps(samples, 4) + ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None + if self.debug: - self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_clis[0], index_host.create_cli(), ct_samples_clis[0] if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None) )] + self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )] else: - self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_clis[i], index_host.create_cli(), ct_samples_clis[i] if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ] + self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ] self.generator_counter = -1 @@ -70,7 +74,11 @@ class SampleGeneratorFace(SampleGeneratorBase): return next(generator) def batch_func(self, param ): - samples, index_host, ct_samples, ct_index_host = param + pickled_samples, index_host, ct_pickled_samples, ct_index_host = param + + samples = pickle.loads(pickled_samples) + ct_samples = pickle.loads(ct_pickled_samples) if ct_pickled_samples is not None else None + bs = self.batch_size while True: batches = None @@ -78,13 +86,14 @@ class SampleGeneratorFace(SampleGeneratorBase): indexes = index_host.multi_get(bs) ct_indexes = ct_index_host.multi_get(bs) if ct_samples is not None else None - batch_samples = samples.multi_get (indexes) - batch_ct_samples = ct_samples.multi_get (ct_indexes) if ct_samples is not None else None - + t = time.time() for n_batch in range(bs): sample_idx = indexes[n_batch] - sample = batch_samples[n_batch] - ct_sample = batch_ct_samples[n_batch] if ct_samples is not None else None + sample = samples[sample_idx] + + ct_sample = None + if ct_samples is not None: + ct_sample = ct_samples[ct_indexes[n_batch]] try: x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample) @@ -102,4 +111,5 @@ class SampleGeneratorFace(SampleGeneratorBase): if self.add_sample_idx: batches[i_sample_idx].append (sample_idx) + yield [ np.array(batch) for batch in batches] diff --git a/samplelib/SampleGeneratorFaceTemporal.py b/samplelib/SampleGeneratorFaceTemporal.py index 1ba06c5..d1a6500 100644 --- a/samplelib/SampleGeneratorFaceTemporal.py +++ b/samplelib/SampleGeneratorFaceTemporal.py @@ -1,10 +1,13 @@ +import pickle import traceback -import numpy as np -import cv2 +import cv2 +import numpy as np + +from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor, + SampleType) from utils import iter_utils -from samplelib import SampleType, SampleProcessor, SampleHost, SampleGeneratorBase ''' output_sample_types = [ @@ -24,14 +27,18 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase): self.generators_count = 1 else: self.generators_count = generators_count - - samples_clis = SampleHost.host (SampleType.FACE_TEMPORAL_SORTED, self.samples_path, number_of_clis=self.generators_count) + samples = SampleHost.load (SampleType.FACE_TEMPORAL_SORTED, self.samples_path) + samples_len = len(samples) + if samples_len == 0: + raise ValueError('No training data provided.') + + pickled_samples = pickle.dumps(samples, 4) if self.debug: - self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples_clis[0]) )] + self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, pickled_samples) )] else: - self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples_clis[i]) ) for i in range(self.generators_count) ] - + self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, pickled_samples) ) for i in range(self.generators_count) ] + self.generator_counter = -1 def __iter__(self): @@ -43,22 +50,20 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase): return next(generator) def batch_func(self, param): - generator_id, samples = param - + generator_id, pickled_samples = param + samples = pickle.loads(pickled_samples) samples_len = len(samples) - if samples_len == 0: - raise ValueError('No training data provided.') - + mult_max = 1 l = samples_len - ( (self.temporal_image_count)*mult_max - (mult_max-1) ) - samples_idxs = [ *range(l+1) ] [generator_id::self.generators_count] - + samples_idxs = [ *range(l+1) ] + if len(samples_idxs) - self.temporal_image_count < 0: raise ValueError('Not enough samples to fit temporal line.') - + shuffle_idxs = [] - + while True: batches = None for n_batch in range(self.batch_size): diff --git a/samplelib/SampleHost.py b/samplelib/SampleHost.py index 25f35d6..19ffd13 100644 --- a/samplelib/SampleHost.py +++ b/samplelib/SampleHost.py @@ -2,7 +2,7 @@ import multiprocessing import operator import traceback from pathlib import Path - +import pickle import samplelib.PackedFaceset from DFLIMG import * from facelib import FaceType, LandmarksProcessor @@ -35,7 +35,7 @@ class SampleHost: return len(list(persons_name_idxs.keys())) @staticmethod - def host(sample_type, samples_path, number_of_clis): + def load(sample_type, samples_path): samples_cache = SampleHost.samples_cache if str(samples_path) not in samples_cache.keys(): @@ -46,10 +46,8 @@ class SampleHost: if sample_type == SampleType.IMAGE: if samples[sample_type] is None: samples[sample_type] = [ Sample(filename=filename) for filename in io.progress_bar_generator( Path_utils.get_image_paths(samples_path), "Loading") ] - elif sample_type == SampleType.FACE or \ - sample_type == SampleType.FACE_TEMPORAL_SORTED: - result = None - + + elif sample_type == SampleType.FACE: if samples[sample_type] is None: try: result = samplelib.PackedFaceset.load(samples_path) @@ -61,18 +59,13 @@ class SampleHost: if result is None: result = SampleHost.load_face_samples( Path_utils.get_image_paths(samples_path) ) - - if sample_type == SampleType.FACE_TEMPORAL_SORTED: - result = SampleHost.upgradeToFaceTemporalSortedSamples(result) - - samples[sample_type] = mp_utils.ListHost(result) - - list_host = samples[sample_type] - - clis = [ list_host.create_cli() for _ in range(number_of_clis) ] - - return clis - + samples[sample_type] = result + + elif sample_type == SampleType.FACE_TEMPORAL_SORTED: + result = SampleHost.load (SampleType.FACE, samples_path) + result = SampleHost.upgradeToFaceTemporalSortedSamples(result) + samples[sample_type] = result + return samples[sample_type] @staticmethod diff --git a/utils/iter_utils.py b/utils/iter_utils.py index 5cb1b2b..e690e3b 100644 --- a/utils/iter_utils.py +++ b/utils/iter_utils.py @@ -22,7 +22,7 @@ class ThisThreadGenerator(object): return next(self.generator_func) class SubprocessGenerator(object): - def __init__(self, generator_func, user_param=None, prefetch=3, start_now=False): + def __init__(self, generator_func, user_param=None, prefetch=2, start_now=False): super().__init__() self.prefetch = prefetch self.generator_func = generator_func diff --git a/utils/mp_utils.py b/utils/mp_utils.py index cde9ad0..63a14ea 100644 --- a/utils/mp_utils.py +++ b/utils/mp_utils.py @@ -125,7 +125,7 @@ class IndexHost(): result.append(shuffle_idxs.pop()) self.cqs[cq_id].put (result) - time.sleep(0.005) + time.sleep(0.001) def create_cli(self): cq = multiprocessing.Queue()