diff --git a/core/leras/layers/BatchNorm2D.py b/core/leras/layers/BatchNorm2D.py index 8ba1030..80310ab 100644 --- a/core/leras/layers/BatchNorm2D.py +++ b/core/leras/layers/BatchNorm2D.py @@ -43,6 +43,10 @@ class BatchNorm2D(nn.LayerBase): x = (x - running_mean) / tf.sqrt( running_var + self.eps ) x *= weight x += bias + + tf.assign(self.running_mean, tf.reshape(running_mean, (self.dim,))) + tf.assign(self.running_var, tf.reshape(running_var, (self.dim,))) + return x nn.BatchNorm2D = BatchNorm2D \ No newline at end of file diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index fdd42af..6aea84b 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -335,13 +335,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... if 'df' in archi_type: self.src_dst_saveable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights() - self.src_dst_trainable_weights = self.src_dst_saveable_weights + self.src_dst_trainable_weights = [x for x in self.src_dst_saveable_weights if x.trainable] elif 'liae' in archi_type: self.src_dst_saveable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights() if random_warp: self.src_dst_trainable_weights = self.src_dst_saveable_weights else: self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights() + self.src_dst_trainable_weights = [x for x in self.src_dst_trainable_weights if x.trainable] self.src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='src_dst_opt') self.src_dst_opt.initialize_variables (self.src_dst_saveable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')