diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 7784ec6..a6edd55 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -38,8 +38,11 @@ class SAEHDModel(ModelBase): default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None) default_masked_training = self.options['masked_training'] = self.load_or_def_option('masked_training', True) default_eyes_prio = self.options['eyes_prio'] = self.load_or_def_option('eyes_prio', False) - default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', 'n') - default_lr_dropout = {True:'y', False:'n'}.get(default_lr_dropout, default_lr_dropout) #backward comp + + lr_dropout = self.load_or_def_option('lr_dropout', 'n') + lr_dropout = {True:'y', False:'n'}.get(lr_dropout, lr_dropout) #backward comp + default_lr_dropout = self.options['lr_dropout'] = lr_dropout + default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True) default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0) default_true_face_power = self.options['true_face_power'] = self.load_or_def_option('true_face_power', 0.0)