mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
added XSeg model
This commit is contained in:
parent
a030ff6951
commit
b0b9072981
5 changed files with 513 additions and 0 deletions
66
samplelib/SampleGeneratorImage.py
Normal file
66
samplelib/SampleGeneratorImage.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
import traceback
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from core.joblib import SubprocessGenerator, ThisThreadGenerator
|
||||
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
|
||||
SampleType)
|
||||
|
||||
|
||||
class SampleGeneratorImage(SampleGeneratorBase):
|
||||
def __init__ (self, samples_path, debug, batch_size, sample_process_options=SampleProcessor.Options(), output_sample_types=[], raise_on_no_data=True, **kwargs):
|
||||
super().__init__(debug, batch_size)
|
||||
self.initialized = False
|
||||
self.sample_process_options = sample_process_options
|
||||
self.output_sample_types = output_sample_types
|
||||
|
||||
samples = SampleLoader.load (SampleType.IMAGE, samples_path)
|
||||
|
||||
if len(samples) == 0:
|
||||
if raise_on_no_data:
|
||||
raise ValueError('No training data provided.')
|
||||
return
|
||||
|
||||
self.generators = [ThisThreadGenerator ( self.batch_func, samples )] if self.debug else \
|
||||
[SubprocessGenerator ( self.batch_func, samples )]
|
||||
|
||||
self.generator_counter = -1
|
||||
self.initialized = True
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
self.generator_counter += 1
|
||||
generator = self.generators[self.generator_counter % len(self.generators) ]
|
||||
return next(generator)
|
||||
|
||||
def batch_func(self, samples):
|
||||
samples_len = len(samples)
|
||||
|
||||
|
||||
idxs = [ *range(samples_len) ]
|
||||
shuffle_idxs = []
|
||||
|
||||
while True:
|
||||
|
||||
batches = None
|
||||
for n_batch in range(self.batch_size):
|
||||
|
||||
if len(shuffle_idxs) == 0:
|
||||
shuffle_idxs = idxs.copy()
|
||||
np.random.shuffle (shuffle_idxs)
|
||||
|
||||
idx = shuffle_idxs.pop()
|
||||
sample = samples[idx]
|
||||
|
||||
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug)
|
||||
|
||||
if batches is None:
|
||||
batches = [ [] for _ in range(len(x)) ]
|
||||
|
||||
for i in range(len(x)):
|
||||
batches[i].append ( x[i] )
|
||||
|
||||
yield [ np.array(batch) for batch in batches]
|
Loading…
Add table
Add a link
Reference in a new issue