mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 06:23:20 -07:00
fixed bug fix
This commit is contained in:
parent
4adeb0aa79
commit
82ce28c784
2 changed files with 5 additions and 4 deletions
|
@ -145,8 +145,8 @@ class ModelBase(object):
|
|||
" Memory error. Tune this value for your"
|
||||
" videocard manually."))
|
||||
self.options['ping_pong'] = io.input_bool(
|
||||
"Enable ping-pong? (y/n ?:help skip:%s) : " % self.options.get('batch_cap', False),
|
||||
self.options.get('batch_cap', False),
|
||||
"Enable ping-pong? (y/n ?:help skip:%s) : " % self.options.get('ping_pong', False),
|
||||
self.options.get('ping_pong', False),
|
||||
help_message="Cycles batch size between 1 and chosen batch size, simulating super convergence")
|
||||
self.options['paddle'] = self.options.get('paddle','ping')
|
||||
if self.options.get('ping_pong',False):
|
||||
|
|
|
@ -539,12 +539,13 @@ class SAEModel(ModelBase):
|
|||
|
||||
# override
|
||||
def set_batch_size(self, batch_size):
|
||||
self.batch_size = batch_size
|
||||
self.set_training_data_generators(None)
|
||||
self.set_training_data_generators([
|
||||
SampleGeneratorFace(training_data_src_path,
|
||||
sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
|
||||
random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None,
|
||||
debug=self.is_debug(), batch_size=batch_size,
|
||||
debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
|
||||
scale_range=np.array([-0.05,
|
||||
0.05]) + self.src_scale_mod / 100.0),
|
||||
|
@ -559,7 +560,7 @@ class SAEModel(ModelBase):
|
|||
range(ms_count)]
|
||||
),
|
||||
|
||||
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=batch_size,
|
||||
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
|
||||
output_sample_types=[{'types': (
|
||||
t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue