mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
fixes and optimizations
This commit is contained in:
parent
842a48964f
commit
d3e6b435aa
6 changed files with 72 additions and 58 deletions
|
@ -1,10 +1,13 @@
|
|||
import pickle
|
||||
import traceback
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
|
||||
SampleType)
|
||||
from utils import iter_utils
|
||||
|
||||
from samplelib import SampleType, SampleProcessor, SampleHost, SampleGeneratorBase
|
||||
|
||||
'''
|
||||
output_sample_types = [
|
||||
|
@ -24,14 +27,18 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
|
|||
self.generators_count = 1
|
||||
else:
|
||||
self.generators_count = generators_count
|
||||
|
||||
samples_clis = SampleHost.host (SampleType.FACE_TEMPORAL_SORTED, self.samples_path, number_of_clis=self.generators_count)
|
||||
|
||||
samples = SampleHost.load (SampleType.FACE_TEMPORAL_SORTED, self.samples_path)
|
||||
samples_len = len(samples)
|
||||
if samples_len == 0:
|
||||
raise ValueError('No training data provided.')
|
||||
|
||||
pickled_samples = pickle.dumps(samples, 4)
|
||||
if self.debug:
|
||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples_clis[0]) )]
|
||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, pickled_samples) )]
|
||||
else:
|
||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples_clis[i]) ) for i in range(self.generators_count) ]
|
||||
|
||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, pickled_samples) ) for i in range(self.generators_count) ]
|
||||
|
||||
self.generator_counter = -1
|
||||
|
||||
def __iter__(self):
|
||||
|
@ -43,22 +50,20 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
|
|||
return next(generator)
|
||||
|
||||
def batch_func(self, param):
|
||||
generator_id, samples = param
|
||||
|
||||
generator_id, pickled_samples = param
|
||||
samples = pickle.loads(pickled_samples)
|
||||
samples_len = len(samples)
|
||||
if samples_len == 0:
|
||||
raise ValueError('No training data provided.')
|
||||
|
||||
|
||||
mult_max = 1
|
||||
l = samples_len - ( (self.temporal_image_count)*mult_max - (mult_max-1) )
|
||||
|
||||
samples_idxs = [ *range(l+1) ] [generator_id::self.generators_count]
|
||||
|
||||
samples_idxs = [ *range(l+1) ]
|
||||
|
||||
if len(samples_idxs) - self.temporal_image_count < 0:
|
||||
raise ValueError('Not enough samples to fit temporal line.')
|
||||
|
||||
|
||||
shuffle_idxs = []
|
||||
|
||||
|
||||
while True:
|
||||
batches = None
|
||||
for n_batch in range(self.batch_size):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue