diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 9957561..922c4e4 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -198,8 +198,6 @@ class SAEHDModel(ModelBase): if dims % 2 != 0: dims += 1 - - def func(x): for i in [8,4,2]: @@ -455,6 +453,7 @@ class SAEHDModel(ModelBase): psd_target_dst_anti_masked = self.model.pred_src_dst*(1.0 - target_dstm) if self.is_training_mode: + lr_dropout = 0.3 if self.options['lr_dropout'] else 0.0 self.src_dst_opt = RMSprop(lr=5e-5, lr_dropout=lr_dropout, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1) self.src_dst_mask_opt = RMSprop(lr=5e-5, lr_dropout=lr_dropout, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)