mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
SAE: added test option: 'Apply random color transfer to src faceset'
This commit is contained in:
parent
bde700243c
commit
a805f81142
9 changed files with 152 additions and 129 deletions
|
@ -1,11 +1,14 @@
|
|||
import traceback
|
||||
import numpy as np
|
||||
import cv2
|
||||
import multiprocessing
|
||||
from utils import iter_utils
|
||||
from facelib import LandmarksProcessor
|
||||
import traceback
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from facelib import LandmarksProcessor
|
||||
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
|
||||
SampleType)
|
||||
from utils import iter_utils
|
||||
|
||||
from samplelib import SampleType, SampleProcessor, SampleLoader, SampleGeneratorBase
|
||||
|
||||
'''
|
||||
arg
|
||||
|
@ -15,7 +18,7 @@ output_sample_types = [
|
|||
]
|
||||
'''
|
||||
class SampleGeneratorFace(SampleGeneratorBase):
|
||||
def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs):
|
||||
def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs):
|
||||
super().__init__(samples_path, debug, batch_size)
|
||||
self.sample_process_options = sample_process_options
|
||||
self.output_sample_types = output_sample_types
|
||||
|
@ -32,15 +35,17 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
raise ValueError("len(generators_random_seed) != generators_count")
|
||||
|
||||
self.generators_random_seed = generators_random_seed
|
||||
|
||||
|
||||
samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)
|
||||
|
||||
|
||||
ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) if random_ct_samples_path is not None else None
|
||||
|
||||
if self.debug:
|
||||
self.generators_count = 1
|
||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples) )]
|
||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples, ct_samples) )]
|
||||
else:
|
||||
self.generators_count = min ( generators_count, len(samples) )
|
||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count] ) ) for i in range(self.generators_count) ]
|
||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count], ct_samples ) ) for i in range(self.generators_count) ]
|
||||
|
||||
self.generator_counter = -1
|
||||
|
||||
|
@ -53,14 +58,16 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
return next(generator)
|
||||
|
||||
def batch_func(self, param ):
|
||||
generator_id, samples = param
|
||||
|
||||
generator_id, samples, ct_samples = param
|
||||
|
||||
if self.generators_random_seed is not None:
|
||||
np.random.seed ( self.generators_random_seed[generator_id] )
|
||||
|
||||
samples_len = len(samples)
|
||||
samples_idxs = [*range(samples_len)]
|
||||
|
||||
ct_samples_len = len(ct_samples) if ct_samples is not None else 0
|
||||
|
||||
if len(samples_idxs) == 0:
|
||||
raise ValueError('No training data provided.')
|
||||
|
||||
|
@ -106,7 +113,8 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
|
||||
if sample is not None:
|
||||
try:
|
||||
x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug)
|
||||
x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug,
|
||||
ct_sample=ct_samples[np.random.randint(ct_samples_len)] if ct_samples is not None else None )
|
||||
except:
|
||||
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue