Fixed exploding bug

This commit is contained in:
Jose 2023-02-07 16:07:17 +01:00 committed by GitHub
commit 9179464561
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 1 deletion

View file

@@ -43,6 +43,10 @@ class BatchNorm2D(nn.LayerBase):
x = (x - running_mean) / tf.sqrt( running_var + self.eps )
x *= weight
x += bias
tf.assign(self.running_mean, tf.reshape(running_mean, (self.dim,)))
tf.assign(self.running_var, tf.reshape(running_var, (self.dim,)))
return x
nn.BatchNorm2D = BatchNorm2D

View file

@@ -335,13 +335,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
if 'df' in archi_type:
self.src_dst_saveable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
self.src_dst_trainable_weights = self.src_dst_saveable_weights
self.src_dst_trainable_weights = [x for x in self.src_dst_saveable_weights if x.trainable]
elif 'liae' in archi_type:
self.src_dst_saveable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()
if random_warp:
self.src_dst_trainable_weights = self.src_dst_saveable_weights
else:
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()
self.src_dst_trainable_weights = [x for x in self.src_dst_trainable_weights if x.trainable]
self.src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='src_dst_opt')
self.src_dst_opt.initialize_variables (self.src_dst_saveable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')