mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 06:23:20 -07:00
fixed bug fix
This commit is contained in:
parent
4adeb0aa79
commit
82ce28c784
2 changed files with 5 additions and 4 deletions
|
@ -145,8 +145,8 @@ class ModelBase(object):
|
||||||
" Memory error. Tune this value for your"
|
" Memory error. Tune this value for your"
|
||||||
" videocard manually."))
|
" videocard manually."))
|
||||||
self.options['ping_pong'] = io.input_bool(
|
self.options['ping_pong'] = io.input_bool(
|
||||||
"Enable ping-pong? (y/n ?:help skip:%s) : " % self.options.get('batch_cap', False),
|
"Enable ping-pong? (y/n ?:help skip:%s) : " % self.options.get('ping_pong', False),
|
||||||
self.options.get('batch_cap', False),
|
self.options.get('ping_pong', False),
|
||||||
help_message="Cycles batch size between 1 and chosen batch size, simulating super convergence")
|
help_message="Cycles batch size between 1 and chosen batch size, simulating super convergence")
|
||||||
self.options['paddle'] = self.options.get('paddle','ping')
|
self.options['paddle'] = self.options.get('paddle','ping')
|
||||||
if self.options.get('ping_pong',False):
|
if self.options.get('ping_pong',False):
|
||||||
|
|
|
@ -539,12 +539,13 @@ class SAEModel(ModelBase):
|
||||||
|
|
||||||
# override
|
# override
|
||||||
def set_batch_size(self, batch_size):
|
def set_batch_size(self, batch_size):
|
||||||
|
self.batch_size = batch_size
|
||||||
self.set_training_data_generators(None)
|
self.set_training_data_generators(None)
|
||||||
self.set_training_data_generators([
|
self.set_training_data_generators([
|
||||||
SampleGeneratorFace(training_data_src_path,
|
SampleGeneratorFace(training_data_src_path,
|
||||||
sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
|
sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
|
||||||
random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None,
|
random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None,
|
||||||
debug=self.is_debug(), batch_size=batch_size,
|
debug=self.is_debug(), batch_size=self.batch_size,
|
||||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
|
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
|
||||||
scale_range=np.array([-0.05,
|
scale_range=np.array([-0.05,
|
||||||
0.05]) + self.src_scale_mod / 100.0),
|
0.05]) + self.src_scale_mod / 100.0),
|
||||||
|
@ -559,7 +560,7 @@ class SAEModel(ModelBase):
|
||||||
range(ms_count)]
|
range(ms_count)]
|
||||||
),
|
),
|
||||||
|
|
||||||
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=batch_size,
|
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
|
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
|
||||||
output_sample_types=[{'types': (
|
output_sample_types=[{'types': (
|
||||||
t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
|
t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue