diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 8ba5ab5..e7b704a 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -43,7 +43,7 @@ class SAEModel(ModelBase): self.options['learn_mask'] = self.options.get('learn_mask', True) - if is_first_run and 'tensorflow' in self.device_config.backend: + if (is_first_run or ask_override) and 'tensorflow' in self.device_config.backend: def_optimizer_mode = self.options.get('optimizer_mode', 1) self.options['optimizer_mode'] = io.input_int ("Optimizer mode? ( 1,2,3 ?:help skip:%d) : " % (def_optimizer_mode), def_optimizer_mode, help_message="1 - no changes. 2 - allows you to train x2 bigger network consuming RAM. 3 - allows you to train x3 bigger network consuming huge amount of RAM and slower, depends on CPU power.") else: