mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-20 13:33:24 -07:00
Fixed exploding bug
This commit is contained in:
parent
347e5bcdf3
commit
9179464561
2 changed files with 6 additions and 1 deletions
|
@ -43,6 +43,10 @@ class BatchNorm2D(nn.LayerBase):
|
|||
x = (x - running_mean) / tf.sqrt( running_var + self.eps )
|
||||
x *= weight
|
||||
x += bias
|
||||
|
||||
tf.assign(self.running_mean, tf.reshape(running_mean, (self.dim,)))
|
||||
tf.assign(self.running_var, tf.reshape(running_var, (self.dim,)))
|
||||
|
||||
return x
|
||||
|
||||
nn.BatchNorm2D = BatchNorm2D
|
|
@ -335,13 +335,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
|
||||
if 'df' in archi_type:
|
||||
self.src_dst_saveable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
|
||||
self.src_dst_trainable_weights = self.src_dst_saveable_weights
|
||||
self.src_dst_trainable_weights = [x for x in self.src_dst_saveable_weights if x.trainable]
|
||||
elif 'liae' in archi_type:
|
||||
self.src_dst_saveable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()
|
||||
if random_warp:
|
||||
self.src_dst_trainable_weights = self.src_dst_saveable_weights
|
||||
else:
|
||||
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()
|
||||
self.src_dst_trainable_weights = [x for x in self.src_dst_trainable_weights if x.trainable]
|
||||
|
||||
self.src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='src_dst_opt')
|
||||
self.src_dst_opt.initialize_variables (self.src_dst_saveable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue