diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index a077877..97bab15 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -389,18 +389,17 @@ class SAEModel(ModelBase): ar += [ [self.decoderm, 'decoderm.h5'] ] elif 'df' in self.options['archi']: - ar += [ [self.encoder, 'encoder.h5'], - ] - if not self.pretrain or self.iter == 0: - ar += [ [self.decoder_src, 'decoder_src.h5'], - [self.decoder_dst, 'decoder_dst.h5'] + ar += [ [self.encoder, 'encoder.h5'], ] + + ar += [ [self.decoder_src, 'decoder_src.h5'], + [self.decoder_dst, 'decoder_dst.h5'] + ] if self.options['learn_mask']: - if not self.pretrain or self.iter == 0: - ar += [ [self.decoder_srcm, 'decoder_srcm.h5'], - [self.decoder_dstm, 'decoder_dstm.h5'] ] + ar += [ [self.decoder_srcm, 'decoder_srcm.h5'], + [self.decoder_dstm, 'decoder_dstm.h5'] ] self.save_weights_safe(ar)