mirror of https://github.com/iperov/DeepFaceLab.git, synced 2025-07-06 13:02:15 -07:00

Commit d3e6b435aa (parent 842a48964f)
fixes and optimizations

6 changed files with 72 additions and 58 deletions
@@ -58,13 +58,12 @@ class AVATARModel(ModelBase):
             self.D = modelify(AVATARModel.Discriminator() ) (Input(df_bgr_shape))
             self.C = modelify(AVATARModel.ResNet (9, n_blocks=6, ngf=128, use_dropout=False))( Input(res_bgr_t_shape))
 
+        self.CA_conv_weights_list = []
         if self.is_first_run():
-            conv_weights_list = []
             for model, _ in self.get_model_filename_list():
                 for layer in model.layers:
                     if type(layer) == keras.layers.Conv2D:
-                        conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
-            CAInitializerMP ( conv_weights_list )
+                        self.CA_conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
 
         if not self.is_first_run():
             self.load_weights_safe( self.get_model_filename_list() )
@@ -248,6 +247,13 @@ class AVATARModel(ModelBase):
     def onSave(self):
         self.save_weights_safe( self.get_model_filename_list() )
 
+    #override
+    def on_success_train_one_iter(self):
+        if len(self.CA_conv_weights_list) != 0:
+            exec(nnlib.import_all(), locals(), globals())
+            CAInitializerMP ( self.CA_conv_weights_list )
+            self.CA_conv_weights_list = []
+
     #override
     def onTrainOneIter(self, generators_samples, generators_list):
         warped_src64, src64, src64m = generators_samples[0]
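Note on the two hunks above: convolution-aware (CA) initialization is no longer run while the model graph is being built. On a first run the Conv2D kernel tensors are only collected into self.CA_conv_weights_list, and CAInitializerMP is applied once, after the first successful training iteration, after which the list is cleared. A minimal sketch of that deferral pattern is shown below; the class and argument names are illustrative, not DeepFaceLab's API.

    # Sketch: defer an expensive weight initialization until after the first
    # successful training iteration (illustrative names, not the real model code).
    class DeferredCAInit:
        def __init__(self, models, is_first_run):
            self.pending_kernels = []
            if is_first_run:
                for model in models:
                    for layer in model.layers:
                        # only convolution kernels get the special initializer
                        if layer.__class__.__name__ == "Conv2D":
                            self.pending_kernels.append(layer.weights[0])

        def on_success_train_one_iter(self, ca_initializer):
            # run the initializer once, then clear the list so later
            # iterations skip this branch entirely
            if len(self.pending_kernels) != 0:
                ca_initializer(self.pending_kernels)
                self.pending_kernels = []

Presumably the point of the deferral is that model construction stays fast and the CA pass is only paid for once the first iteration has confirmed the graph actually runs.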
@@ -1,9 +1,9 @@
 import multiprocessing
 import traceback
-
+import pickle
 import cv2
 import numpy as np
-
+import time
 from facelib import LandmarksProcessor
 from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
                        SampleType)
@@ -23,6 +23,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
                         sample_process_options=SampleProcessor.Options(),
                         output_sample_types=[],
                         add_sample_idx=False,
+                        generators_count=4,
                         **kwargs):
 
         super().__init__(samples_path, debug, batch_size)
@@ -33,10 +34,10 @@ class SampleGeneratorFace(SampleGeneratorBase):
         if self.debug:
             self.generators_count = 1
         else:
-            self.generators_count = np.clip(multiprocessing.cpu_count(), 2, 6)
+            self.generators_count = np.clip(multiprocessing.cpu_count(), 2, generators_count)
 
-        samples_clis = SampleHost.host (SampleType.FACE, self.samples_path, number_of_clis=self.generators_count)
-        self.samples_len = len(samples_clis[0])
+        samples = SampleHost.load (SampleType.FACE, self.samples_path)
+        self.samples_len = len(samples)
 
         if self.samples_len == 0:
             raise ValueError('No training data provided.')
@@ -44,16 +45,19 @@ class SampleGeneratorFace(SampleGeneratorBase):
         index_host = mp_utils.IndexHost(self.samples_len)
 
         if random_ct_samples_path is not None:
-            ct_samples_clis = SampleHost.host (SampleType.FACE, random_ct_samples_path, number_of_clis=self.generators_count)
-            ct_index_host = mp_utils.IndexHost( len(ct_samples_clis[0]) )
+            ct_samples = SampleHost.load (SampleType.FACE, random_ct_samples_path)
+            ct_index_host = mp_utils.IndexHost( len(ct_samples) )
         else:
-            ct_samples_clis = None
+            ct_samples = None
             ct_index_host = None
 
+        pickled_samples = pickle.dumps(samples, 4)
+        ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None
+
         if self.debug:
-            self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_clis[0], index_host.create_cli(), ct_samples_clis[0] if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None) )]
+            self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )]
         else:
-            self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_clis[i], index_host.create_cli(), ct_samples_clis[i] if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ]
+            self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ]
 
         self.generator_counter = -1
 
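In the constructor hunks above, the mp_utils.ListHost client handles (samples_clis) are replaced by a single pickle.dumps(samples, 4): every generator process receives the same pickled byte string and rebuilds its own private copy of the sample list, instead of routing each lookup through a multiprocessing proxy. A rough sketch of that hand-off using multiprocessing directly; the worker and queue here are illustrative stand-ins, not the repository's iter_utils.SubprocessGenerator.

    import multiprocessing
    import pickle

    def worker(pickled_samples, out_queue):
        # each process rebuilds a private, read-only copy of the sample list,
        # so batch lookups become plain local indexing instead of proxy calls
        samples = pickle.loads(pickled_samples)
        out_queue.put(len(samples))

    if __name__ == "__main__":
        samples = [{"filename": "face_%d.jpg" % i} for i in range(1000)]  # stand-in samples
        pickled_samples = pickle.dumps(samples, 4)  # protocol 4, as in the diff

        q = multiprocessing.Queue()
        procs = [multiprocessing.Process(target=worker, args=(pickled_samples, q)) for _ in range(2)]
        for p in procs:
            p.start()
        print([q.get() for _ in procs])  # -> [1000, 1000]
        for p in procs:
            p.join()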
@@ -70,7 +74,11 @@ class SampleGeneratorFace(SampleGeneratorBase):
         return next(generator)
 
     def batch_func(self, param ):
-        samples, index_host, ct_samples, ct_index_host = param
+        pickled_samples, index_host, ct_pickled_samples, ct_index_host = param
+
+        samples = pickle.loads(pickled_samples)
+        ct_samples = pickle.loads(ct_pickled_samples) if ct_pickled_samples is not None else None
+
         bs = self.batch_size
         while True:
             batches = None
@@ -78,13 +86,14 @@ class SampleGeneratorFace(SampleGeneratorBase):
             indexes = index_host.multi_get(bs)
             ct_indexes = ct_index_host.multi_get(bs) if ct_samples is not None else None
 
-            batch_samples = samples.multi_get (indexes)
-            batch_ct_samples = ct_samples.multi_get (ct_indexes) if ct_samples is not None else None
-
+            t = time.time()
             for n_batch in range(bs):
                 sample_idx = indexes[n_batch]
-                sample = batch_samples[n_batch]
-                ct_sample = batch_ct_samples[n_batch] if ct_samples is not None else None
+                sample = samples[sample_idx]
+
+                ct_sample = None
+                if ct_samples is not None:
+                    ct_sample = ct_samples[ct_indexes[n_batch]]
 
                 try:
                     x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
@@ -102,4 +111,5 @@ class SampleGeneratorFace(SampleGeneratorBase):
 
                 if self.add_sample_idx:
                     batches[i_sample_idx].append (sample_idx)
+
             yield [ np.array(batch) for batch in batches]
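After this change, batch_func keeps only the index host as a shared object: it requests bs shuffled indexes per batch and reads samples (a plain local list after pickle.loads) by position, so there is no per-sample round trip to a host process. A simplified sketch of that loop shape with an in-process stand-in for mp_utils.IndexHost; the stand-in only approximates its behavior.

    import random

    class FakeIndexHost:
        # stand-in for mp_utils.IndexHost: hands out shuffled sample indexes
        def __init__(self, count):
            self.count = count
            self.shuffled = []

        def multi_get(self, n):
            result = []
            for _ in range(n):
                if not self.shuffled:
                    self.shuffled = list(range(self.count))
                    random.shuffle(self.shuffled)
                result.append(self.shuffled.pop())
            return result

    samples = ["sample_%d" % i for i in range(10)]   # local list, as after pickle.loads
    index_host = FakeIndexHost(len(samples))

    bs = 4
    indexes = index_host.multi_get(bs)
    batch = [samples[idx] for idx in indexes]        # plain indexing, no proxy round trip
    print(indexes, batch)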
@@ -1,10 +1,13 @@
+import pickle
 import traceback
-import numpy as np
-import cv2
 
+import cv2
+import numpy as np
+
+from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
+                       SampleType)
 from utils import iter_utils
 
-from samplelib import SampleType, SampleProcessor, SampleHost, SampleGeneratorBase
 
 '''
 output_sample_types = [
@@ -25,12 +28,16 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
         else:
             self.generators_count = generators_count
 
-        samples_clis = SampleHost.host (SampleType.FACE_TEMPORAL_SORTED, self.samples_path, number_of_clis=self.generators_count)
+        samples = SampleHost.load (SampleType.FACE_TEMPORAL_SORTED, self.samples_path)
+        samples_len = len(samples)
+        if samples_len == 0:
+            raise ValueError('No training data provided.')
 
+        pickled_samples = pickle.dumps(samples, 4)
         if self.debug:
-            self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples_clis[0]) )]
+            self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, pickled_samples) )]
         else:
-            self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples_clis[i]) ) for i in range(self.generators_count) ]
+            self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, pickled_samples) ) for i in range(self.generators_count) ]
 
         self.generator_counter = -1
 
@@ -43,16 +50,14 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
         return next(generator)
 
     def batch_func(self, param):
-        generator_id, samples = param
-
+        generator_id, pickled_samples = param
+        samples = pickle.loads(pickled_samples)
         samples_len = len(samples)
-        if samples_len == 0:
-            raise ValueError('No training data provided.')
 
         mult_max = 1
         l = samples_len - ( (self.temporal_image_count)*mult_max - (mult_max-1) )
 
-        samples_idxs = [ *range(l+1) ] [generator_id::self.generators_count]
+        samples_idxs = [ *range(l+1) ]
 
         if len(samples_idxs) - self.temporal_image_count < 0:
             raise ValueError('Not enough samples to fit temporal line.')
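For the temporal generator, l = samples_len - ( (temporal_image_count)*mult_max - (mult_max-1) ) is the largest index that can still start a full temporal window; with mult_max = 1 it reduces to samples_len - temporal_image_count, giving l + 1 valid start positions. A quick check of that arithmetic with made-up numbers:

    samples_len = 100
    temporal_image_count = 3
    mult_max = 1

    l = samples_len - ((temporal_image_count) * mult_max - (mult_max - 1))
    samples_idxs = [*range(l + 1)]

    # every start index must leave room for a full window of consecutive frames
    assert all(i + temporal_image_count <= samples_len for i in samples_idxs)
    print(l, len(samples_idxs))  # -> 97 98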
@@ -2,7 +2,7 @@ import multiprocessing
 import operator
 import traceback
 from pathlib import Path
-
+import pickle
 import samplelib.PackedFaceset
 from DFLIMG import *
 from facelib import FaceType, LandmarksProcessor
@@ -35,7 +35,7 @@ class SampleHost:
         return len(list(persons_name_idxs.keys()))
 
     @staticmethod
-    def host(sample_type, samples_path, number_of_clis):
+    def load(sample_type, samples_path):
         samples_cache = SampleHost.samples_cache
 
         if str(samples_path) not in samples_cache.keys():
@@ -46,10 +46,8 @@ class SampleHost:
         if sample_type == SampleType.IMAGE:
             if samples[sample_type] is None:
                 samples[sample_type] = [ Sample(filename=filename) for filename in io.progress_bar_generator( Path_utils.get_image_paths(samples_path), "Loading") ]
-        elif sample_type == SampleType.FACE or \
-             sample_type == SampleType.FACE_TEMPORAL_SORTED:
-            result = None
 
+        elif sample_type == SampleType.FACE:
             if samples[sample_type] is None:
                 try:
                     result = samplelib.PackedFaceset.load(samples_path)
@@ -61,17 +59,12 @@ class SampleHost:
 
             if result is None:
                 result = SampleHost.load_face_samples( Path_utils.get_image_paths(samples_path) )
+            samples[sample_type] = result
 
-        if sample_type == SampleType.FACE_TEMPORAL_SORTED:
+        elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
+            result = SampleHost.load (SampleType.FACE, samples_path)
             result = SampleHost.upgradeToFaceTemporalSortedSamples(result)
 
-        samples[sample_type] = mp_utils.ListHost(result)
-
-        list_host = samples[sample_type]
-
-        clis = [ list_host.create_cli() for _ in range(number_of_clis) ]
-
-        return clis
+        samples[sample_type] = result
 
+        return samples[sample_type]
 
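SampleHost.load (formerly host) now returns the loaded sample list itself and memoizes it per samples_path and sample type, rather than wrapping it in mp_utils.ListHost and handing out client handles. A condensed sketch of that cache shape; the loader argument here stands in for the PackedFaceset / folder-scan logic and is not part of the real signature.

    class SampleCache:
        samples_cache = {}  # str(path) -> {sample_type: samples}

        @staticmethod
        def load(sample_type, samples_path, loader):
            cache = SampleCache.samples_cache
            key = str(samples_path)
            if key not in cache:
                cache[key] = {}
            per_path = cache[key]
            if sample_type not in per_path or per_path[sample_type] is None:
                # loader is a placeholder for the real packed-faceset / folder scan
                per_path[sample_type] = loader(samples_path)
            return per_path[sample_type]

    # usage: repeated calls for the same path reuse the first result
    faces = SampleCache.load("FACE", "/data/faces", lambda p: ["s0", "s1", "s2"])
    again = SampleCache.load("FACE", "/data/faces", lambda p: 1/0)  # loader not called: cache hit
    assert faces is again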
@@ -22,7 +22,7 @@ class ThisThreadGenerator(object):
         return next(self.generator_func)
 
 class SubprocessGenerator(object):
-    def __init__(self, generator_func, user_param=None, prefetch=3, start_now=False):
+    def __init__(self, generator_func, user_param=None, prefetch=2, start_now=False):
         super().__init__()
         self.prefetch = prefetch
         self.generator_func = generator_func
@@ -125,7 +125,7 @@ class IndexHost():
                     result.append(shuffle_idxs.pop())
                 self.cqs[cq_id].put (result)
 
-            time.sleep(0.005)
+            time.sleep(0.001)
 
     def create_cli(self):
         cq = multiprocessing.Queue()
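The last two hunks are small tuning changes: SubprocessGenerator now defaults to keeping 2 prefetched batches in flight instead of 3, and IndexHost polls its request queues every 1 ms instead of 5 ms, trading a little CPU for lower index latency. A toy illustration of how a bounded prefetch depth caps how far a producer process can run ahead; this is not how iter_utils implements it, only the underlying idea.

    import multiprocessing
    import time

    def producer(q):
        # q.put blocks once the queue already holds `maxsize` batches,
        # so the producer can never run more than `prefetch` batches ahead
        for i in range(6):
            q.put("batch_%d" % i)

    if __name__ == "__main__":
        prefetch = 2  # as in the updated SubprocessGenerator default
        q = multiprocessing.Queue(maxsize=prefetch)
        p = multiprocessing.Process(target=producer, args=(q,))
        p.start()
        for _ in range(6):
            time.sleep(0.05)   # a slow consumer; the producer stays only `prefetch` ahead
            print(q.get())
        p.join()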