mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 06:23:20 -07:00
fixed small things dont even worry about it
This commit is contained in:
parent
1ecc6f1b62
commit
1a5484df8f
3 changed files with 5 additions and 36 deletions
|
@ -542,40 +542,8 @@ class SAEModel(ModelBase):
|
|||
# override
|
||||
def set_batch_size(self, batch_size):
|
||||
self.batch_size = batch_size
|
||||
self.set_training_data_generators(None)
|
||||
self.set_training_data_generators([
|
||||
SampleGeneratorFace(training_data_src_path,
|
||||
sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
|
||||
random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None,
|
||||
debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
|
||||
scale_range=np.array([-0.05,
|
||||
0.05]) + self.src_scale_mod / 100.0),
|
||||
output_sample_types=[{'types': (
|
||||
t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
|
||||
'resolution': resolution, 'apply_ct': apply_random_ct}] + \
|
||||
[{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
|
||||
'resolution': resolution // (2 ** i),
|
||||
'apply_ct': apply_random_ct} for i in range(ms_count)] + \
|
||||
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
|
||||
'resolution': resolution // (2 ** i)} for i in
|
||||
range(ms_count)],
|
||||
ping_pong=self.ping_pong_options,
|
||||
),
|
||||
|
||||
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
|
||||
output_sample_types=[{'types': (
|
||||
t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
|
||||
'resolution': resolution}] + \
|
||||
[{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
|
||||
'resolution': resolution // (2 ** i)} for i in
|
||||
range(ms_count)] + \
|
||||
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
|
||||
'resolution': resolution // (2 ** i)} for i in
|
||||
range(ms_count)],
|
||||
ping_pong=self.ping_pong_options,)
|
||||
])
|
||||
for generators in self.get_training_data_generators():
|
||||
generators.update_batch(batch_size)
|
||||
|
||||
# override
|
||||
def onTrainOneIter(self, generators_samples, generators_list):
|
||||
|
|
|
@ -7,7 +7,7 @@ import numpy as np
|
|||
from facelib import LandmarksProcessor
|
||||
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleGeneratorPingPong,
|
||||
SampleType)
|
||||
from samplelib.SampleGeneratorPingPong import PingPongOptions
|
||||
from samplelib.SampleGeneratorPingPong import PingPongOptions, SampleGeneratorPingPong
|
||||
from utils import iter_utils
|
||||
|
||||
|
||||
|
@ -25,7 +25,7 @@ class SampleGeneratorFace(SampleGeneratorPingPong):
|
|||
random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(),
|
||||
output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None,
|
||||
ping_pong=PingPongOptions(), **kwargs):
|
||||
super().__init__(samples_path, debug, batch_size, ping_pong)
|
||||
super().__init__(samples_path, debug, batch_size=batch_size, ping_pong=ping_pong)
|
||||
self.sample_process_options = sample_process_options
|
||||
self.output_sample_types = output_sample_types
|
||||
self.add_sample_idx = add_sample_idx
|
||||
|
|
|
@ -4,5 +4,6 @@ from .SampleLoader import SampleLoader
|
|||
from .SampleProcessor import SampleProcessor
|
||||
from .SampleGeneratorBase import SampleGeneratorBase
|
||||
from .SampleGeneratorFace import SampleGeneratorFace
|
||||
from .SampleGeneratorPingPong import SampleGeneratorPingPong
|
||||
from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal
|
||||
from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue