fixed small things dont even worry about it

This commit is contained in:
Brigham Lysenko 2019-08-23 19:37:04 -06:00
commit 1a5484df8f
3 changed files with 5 additions and 36 deletions

View file

@ -542,40 +542,8 @@ class SAEModel(ModelBase):
# override # override
def set_batch_size(self, batch_size): def set_batch_size(self, batch_size):
self.batch_size = batch_size self.batch_size = batch_size
self.set_training_data_generators(None) for generators in self.get_training_data_generators():
self.set_training_data_generators([ generators.update_batch(batch_size)
SampleGeneratorFace(training_data_src_path,
sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None,
debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
scale_range=np.array([-0.05,
0.05]) + self.src_scale_mod / 100.0),
output_sample_types=[{'types': (
t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution, 'apply_ct': apply_random_ct}] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution // (2 ** i),
'apply_ct': apply_random_ct} for i in range(ms_count)] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
'resolution': resolution // (2 ** i)} for i in
range(ms_count)],
ping_pong=self.ping_pong_options,
),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
output_sample_types=[{'types': (
t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution}] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution // (2 ** i)} for i in
range(ms_count)] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
'resolution': resolution // (2 ** i)} for i in
range(ms_count)],
ping_pong=self.ping_pong_options,)
])
# override # override
def onTrainOneIter(self, generators_samples, generators_list): def onTrainOneIter(self, generators_samples, generators_list):

View file

@ -7,7 +7,7 @@ import numpy as np
from facelib import LandmarksProcessor from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleGeneratorPingPong, from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleGeneratorPingPong,
SampleType) SampleType)
from samplelib.SampleGeneratorPingPong import PingPongOptions from samplelib.SampleGeneratorPingPong import PingPongOptions, SampleGeneratorPingPong
from utils import iter_utils from utils import iter_utils
@ -25,7 +25,7 @@ class SampleGeneratorFace(SampleGeneratorPingPong):
random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(), random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(),
output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None,
ping_pong=PingPongOptions(), **kwargs): ping_pong=PingPongOptions(), **kwargs):
super().__init__(samples_path, debug, batch_size, ping_pong) super().__init__(samples_path, debug, batch_size=batch_size, ping_pong=ping_pong)
self.sample_process_options = sample_process_options self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types self.output_sample_types = output_sample_types
self.add_sample_idx = add_sample_idx self.add_sample_idx = add_sample_idx

View file

@ -4,5 +4,6 @@ from .SampleLoader import SampleLoader
from .SampleProcessor import SampleProcessor from .SampleProcessor import SampleProcessor
from .SampleGeneratorBase import SampleGeneratorBase from .SampleGeneratorBase import SampleGeneratorBase
from .SampleGeneratorFace import SampleGeneratorFace from .SampleGeneratorFace import SampleGeneratorFace
from .SampleGeneratorPingPong import SampleGeneratorPingPong
from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal
from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal