diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 0513d74..8643ce5 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -249,7 +249,7 @@ class SAEModel(ModelBase): psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))] psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))] - alpha_rec = 50 + alpha_rec = 10 if self.is_training_mode: self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)