fixes and optimizations

This commit is contained in:
Colombo 2020-01-07 13:45:54 +04:00
parent 842a48964f
commit d3e6b435aa
6 changed files with 72 additions and 58 deletions

View file

@ -58,13 +58,12 @@ class AVATARModel(ModelBase):
self.D = modelify(AVATARModel.Discriminator() ) (Input(df_bgr_shape)) self.D = modelify(AVATARModel.Discriminator() ) (Input(df_bgr_shape))
self.C = modelify(AVATARModel.ResNet (9, n_blocks=6, ngf=128, use_dropout=False))( Input(res_bgr_t_shape)) self.C = modelify(AVATARModel.ResNet (9, n_blocks=6, ngf=128, use_dropout=False))( Input(res_bgr_t_shape))
if self.is_first_run(): self.CA_conv_weights_list = []
conv_weights_list = [] if self.is_first_run():
for model, _ in self.get_model_filename_list(): for model, _ in self.get_model_filename_list():
for layer in model.layers: for layer in model.layers:
if type(layer) == keras.layers.Conv2D: if type(layer) == keras.layers.Conv2D:
conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights self.CA_conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
CAInitializerMP ( conv_weights_list )
if not self.is_first_run(): if not self.is_first_run():
self.load_weights_safe( self.get_model_filename_list() ) self.load_weights_safe( self.get_model_filename_list() )
@ -247,7 +246,14 @@ class AVATARModel(ModelBase):
#override #override
def onSave(self): def onSave(self):
self.save_weights_safe( self.get_model_filename_list() ) self.save_weights_safe( self.get_model_filename_list() )
#override
def on_success_train_one_iter(self):
if len(self.CA_conv_weights_list) != 0:
exec(nnlib.import_all(), locals(), globals())
CAInitializerMP ( self.CA_conv_weights_list )
self.CA_conv_weights_list = []
#override #override
def onTrainOneIter(self, generators_samples, generators_list): def onTrainOneIter(self, generators_samples, generators_list):
warped_src64, src64, src64m = generators_samples[0] warped_src64, src64, src64m = generators_samples[0]

View file

@ -1,9 +1,9 @@
import multiprocessing import multiprocessing
import traceback import traceback
import pickle
import cv2 import cv2
import numpy as np import numpy as np
import time
from facelib import LandmarksProcessor from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor, from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
SampleType) SampleType)
@ -23,6 +23,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
sample_process_options=SampleProcessor.Options(), sample_process_options=SampleProcessor.Options(),
output_sample_types=[], output_sample_types=[],
add_sample_idx=False, add_sample_idx=False,
generators_count=4,
**kwargs): **kwargs):
super().__init__(samples_path, debug, batch_size) super().__init__(samples_path, debug, batch_size)
@ -33,10 +34,10 @@ class SampleGeneratorFace(SampleGeneratorBase):
if self.debug: if self.debug:
self.generators_count = 1 self.generators_count = 1
else: else:
self.generators_count = np.clip(multiprocessing.cpu_count(), 2, 6) self.generators_count = np.clip(multiprocessing.cpu_count(), 2, generators_count)
samples_clis = SampleHost.host (SampleType.FACE, self.samples_path, number_of_clis=self.generators_count) samples = SampleHost.load (SampleType.FACE, self.samples_path)
self.samples_len = len(samples_clis[0]) self.samples_len = len(samples)
if self.samples_len == 0: if self.samples_len == 0:
raise ValueError('No training data provided.') raise ValueError('No training data provided.')
@ -44,16 +45,19 @@ class SampleGeneratorFace(SampleGeneratorBase):
index_host = mp_utils.IndexHost(self.samples_len) index_host = mp_utils.IndexHost(self.samples_len)
if random_ct_samples_path is not None: if random_ct_samples_path is not None:
ct_samples_clis = SampleHost.host (SampleType.FACE, random_ct_samples_path, number_of_clis=self.generators_count) ct_samples = SampleHost.load (SampleType.FACE, random_ct_samples_path)
ct_index_host = mp_utils.IndexHost( len(ct_samples_clis[0]) ) ct_index_host = mp_utils.IndexHost( len(ct_samples) )
else: else:
ct_samples_clis = None ct_samples = None
ct_index_host = None ct_index_host = None
pickled_samples = pickle.dumps(samples, 4)
ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None
if self.debug: if self.debug:
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_clis[0], index_host.create_cli(), ct_samples_clis[0] if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None) )] self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )]
else: else:
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_clis[i], index_host.create_cli(), ct_samples_clis[i] if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ] self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ]
self.generator_counter = -1 self.generator_counter = -1
@ -70,7 +74,11 @@ class SampleGeneratorFace(SampleGeneratorBase):
return next(generator) return next(generator)
def batch_func(self, param ): def batch_func(self, param ):
samples, index_host, ct_samples, ct_index_host = param pickled_samples, index_host, ct_pickled_samples, ct_index_host = param
samples = pickle.loads(pickled_samples)
ct_samples = pickle.loads(ct_pickled_samples) if ct_pickled_samples is not None else None
bs = self.batch_size bs = self.batch_size
while True: while True:
batches = None batches = None
@ -78,13 +86,14 @@ class SampleGeneratorFace(SampleGeneratorBase):
indexes = index_host.multi_get(bs) indexes = index_host.multi_get(bs)
ct_indexes = ct_index_host.multi_get(bs) if ct_samples is not None else None ct_indexes = ct_index_host.multi_get(bs) if ct_samples is not None else None
batch_samples = samples.multi_get (indexes) t = time.time()
batch_ct_samples = ct_samples.multi_get (ct_indexes) if ct_samples is not None else None
for n_batch in range(bs): for n_batch in range(bs):
sample_idx = indexes[n_batch] sample_idx = indexes[n_batch]
sample = batch_samples[n_batch] sample = samples[sample_idx]
ct_sample = batch_ct_samples[n_batch] if ct_samples is not None else None
ct_sample = None
if ct_samples is not None:
ct_sample = ct_samples[ct_indexes[n_batch]]
try: try:
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample) x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
@ -102,4 +111,5 @@ class SampleGeneratorFace(SampleGeneratorBase):
if self.add_sample_idx: if self.add_sample_idx:
batches[i_sample_idx].append (sample_idx) batches[i_sample_idx].append (sample_idx)
yield [ np.array(batch) for batch in batches] yield [ np.array(batch) for batch in batches]

View file

@ -1,10 +1,13 @@
import pickle
import traceback import traceback
import numpy as np
import cv2
import cv2
import numpy as np
from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
SampleType)
from utils import iter_utils from utils import iter_utils
from samplelib import SampleType, SampleProcessor, SampleHost, SampleGeneratorBase
''' '''
output_sample_types = [ output_sample_types = [
@ -24,14 +27,18 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
self.generators_count = 1 self.generators_count = 1
else: else:
self.generators_count = generators_count self.generators_count = generators_count
samples_clis = SampleHost.host (SampleType.FACE_TEMPORAL_SORTED, self.samples_path, number_of_clis=self.generators_count)
samples = SampleHost.load (SampleType.FACE_TEMPORAL_SORTED, self.samples_path)
samples_len = len(samples)
if samples_len == 0:
raise ValueError('No training data provided.')
pickled_samples = pickle.dumps(samples, 4)
if self.debug: if self.debug:
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples_clis[0]) )] self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, pickled_samples) )]
else: else:
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples_clis[i]) ) for i in range(self.generators_count) ] self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, pickled_samples) ) for i in range(self.generators_count) ]
self.generator_counter = -1 self.generator_counter = -1
def __iter__(self): def __iter__(self):
@ -43,22 +50,20 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
return next(generator) return next(generator)
def batch_func(self, param): def batch_func(self, param):
generator_id, samples = param generator_id, pickled_samples = param
samples = pickle.loads(pickled_samples)
samples_len = len(samples) samples_len = len(samples)
if samples_len == 0:
raise ValueError('No training data provided.')
mult_max = 1 mult_max = 1
l = samples_len - ( (self.temporal_image_count)*mult_max - (mult_max-1) ) l = samples_len - ( (self.temporal_image_count)*mult_max - (mult_max-1) )
samples_idxs = [ *range(l+1) ] [generator_id::self.generators_count] samples_idxs = [ *range(l+1) ]
if len(samples_idxs) - self.temporal_image_count < 0: if len(samples_idxs) - self.temporal_image_count < 0:
raise ValueError('Not enough samples to fit temporal line.') raise ValueError('Not enough samples to fit temporal line.')
shuffle_idxs = [] shuffle_idxs = []
while True: while True:
batches = None batches = None
for n_batch in range(self.batch_size): for n_batch in range(self.batch_size):

View file

@ -2,7 +2,7 @@ import multiprocessing
import operator import operator
import traceback import traceback
from pathlib import Path from pathlib import Path
import pickle
import samplelib.PackedFaceset import samplelib.PackedFaceset
from DFLIMG import * from DFLIMG import *
from facelib import FaceType, LandmarksProcessor from facelib import FaceType, LandmarksProcessor
@ -35,7 +35,7 @@ class SampleHost:
return len(list(persons_name_idxs.keys())) return len(list(persons_name_idxs.keys()))
@staticmethod @staticmethod
def host(sample_type, samples_path, number_of_clis): def load(sample_type, samples_path):
samples_cache = SampleHost.samples_cache samples_cache = SampleHost.samples_cache
if str(samples_path) not in samples_cache.keys(): if str(samples_path) not in samples_cache.keys():
@ -46,10 +46,8 @@ class SampleHost:
if sample_type == SampleType.IMAGE: if sample_type == SampleType.IMAGE:
if samples[sample_type] is None: if samples[sample_type] is None:
samples[sample_type] = [ Sample(filename=filename) for filename in io.progress_bar_generator( Path_utils.get_image_paths(samples_path), "Loading") ] samples[sample_type] = [ Sample(filename=filename) for filename in io.progress_bar_generator( Path_utils.get_image_paths(samples_path), "Loading") ]
elif sample_type == SampleType.FACE or \
sample_type == SampleType.FACE_TEMPORAL_SORTED: elif sample_type == SampleType.FACE:
result = None
if samples[sample_type] is None: if samples[sample_type] is None:
try: try:
result = samplelib.PackedFaceset.load(samples_path) result = samplelib.PackedFaceset.load(samples_path)
@ -61,18 +59,13 @@ class SampleHost:
if result is None: if result is None:
result = SampleHost.load_face_samples( Path_utils.get_image_paths(samples_path) ) result = SampleHost.load_face_samples( Path_utils.get_image_paths(samples_path) )
samples[sample_type] = result
if sample_type == SampleType.FACE_TEMPORAL_SORTED:
result = SampleHost.upgradeToFaceTemporalSortedSamples(result) elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
result = SampleHost.load (SampleType.FACE, samples_path)
samples[sample_type] = mp_utils.ListHost(result) result = SampleHost.upgradeToFaceTemporalSortedSamples(result)
samples[sample_type] = result
list_host = samples[sample_type]
clis = [ list_host.create_cli() for _ in range(number_of_clis) ]
return clis
return samples[sample_type] return samples[sample_type]
@staticmethod @staticmethod

View file

@ -22,7 +22,7 @@ class ThisThreadGenerator(object):
return next(self.generator_func) return next(self.generator_func)
class SubprocessGenerator(object): class SubprocessGenerator(object):
def __init__(self, generator_func, user_param=None, prefetch=3, start_now=False): def __init__(self, generator_func, user_param=None, prefetch=2, start_now=False):
super().__init__() super().__init__()
self.prefetch = prefetch self.prefetch = prefetch
self.generator_func = generator_func self.generator_func = generator_func

View file

@ -125,7 +125,7 @@ class IndexHost():
result.append(shuffle_idxs.pop()) result.append(shuffle_idxs.pop())
self.cqs[cq_id].put (result) self.cqs[cq_id].put (result)
time.sleep(0.005) time.sleep(0.001)
def create_cli(self): def create_cli(self):
cq = multiprocessing.Queue() cq = multiprocessing.Queue()