diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 3494aa2..9703d3a 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -33,7 +33,8 @@ class SAEHDModel(ModelBase): default_archi = self.options['archi'] = self.load_or_def_option('archi', 'dfhd') default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256) default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64) - + self.options['d_dims'] = None + self.options['d_mask_dims'] = None default_use_float16 = self.options['use_float16'] = self.load_or_def_option('use_float16', False) default_learn_mask = self.options['learn_mask'] = self.load_or_def_option('learn_mask', True) default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', False)