mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
fix rmsprop
This commit is contained in:
parent
874a7eba18
commit
c516454566
1 changed files with 9 additions and 9 deletions
|
@ -10,20 +10,22 @@ class RMSprop(nn.OptimizerBase):
|
|||
raise ValueError('name must be defined.')
|
||||
|
||||
self.lr_dropout = lr_dropout
|
||||
self.lr = lr
|
||||
self.rho = rho
|
||||
self.epsilon = epsilon
|
||||
|
||||
self.clipnorm = clipnorm
|
||||
|
||||
with tf.device('/CPU:0') :
|
||||
with tf.variable_scope(self.name):
|
||||
self.lr = tf.Variable (lr, name="lr")
|
||||
self.rho = tf.Variable (rho, name="rho")
|
||||
self.epsilon = tf.Variable (epsilon, name="epsilon")
|
||||
|
||||
self.iterations = tf.Variable(0, dtype=tf.int64, name='iters')
|
||||
|
||||
self.accumulators_dict = {}
|
||||
self.lr_rnds_dict = {}
|
||||
|
||||
def get_weights(self):
|
||||
return [self.lr, self.rho, self.epsilon, self.iterations] + list(self.accumulators_dict.values())
|
||||
return [self.iterations] + list(self.accumulators_dict.values())
|
||||
|
||||
def initialize_variables(self, trainable_weights, vars_on_cpu=True, lr_dropout_on_cpu=False):
|
||||
# Initialize here all trainable variables used in training
|
||||
|
@ -53,13 +55,11 @@ class RMSprop(nn.OptimizerBase):
|
|||
|
||||
a = self.accumulators_dict[ v.name ]
|
||||
|
||||
rho = tf.cast(self.rho, a.dtype)
|
||||
new_a = rho * a + (1. - rho) * tf.square(g)
|
||||
new_a = self.rho * a + (1. - self.rho) * tf.square(g)
|
||||
|
||||
lr = tf.cast(self.lr, a.dtype)
|
||||
epsilon = tf.cast(self.epsilon, a.dtype)
|
||||
lr = tf.constant(self.lr, g.dtype)
|
||||
|
||||
v_diff = - lr * g / (tf.sqrt(new_a) + epsilon)
|
||||
v_diff = - lr * g / (tf.sqrt(new_a) + self.epsilon)
|
||||
if self.lr_dropout != 1.0:
|
||||
lr_rnd = self.lr_rnds_dict[v.name]
|
||||
v_diff *= lr_rnd
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue