SAE: removed simple_optimizer. Added an optimizer mode (TensorFlow only, NVIDIA cards) that allows training x2-x3 bigger networks with the normal Adam optimizer by keeping optimizer state on the CPU, trading VRAM for system RAM and CPU power.

iperov 2019-03-13 11:54:17 +04:00
parent 7d6ca32250
commit 58763756f5
3 changed files with 100 additions and 37 deletions
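The mechanism behind the new optimizer mode, as a minimal standalone sketch (Keras 2.x on the TensorFlow 1.x backend; the helper name make_adam_slots and its on_cpu flag are illustrative and not part of the repository): creating Adam's moment accumulators inside K.tf.device("/cpu:0") pins those variables to host memory, so only the weights, activations and gradients occupy VRAM.

import keras.backend as K

def make_adam_slots(params, on_cpu=True):
    # Create Adam's first- and second-moment accumulators, one per
    # trainable parameter. With on_cpu=True the variables live in system
    # RAM, which is roughly what AdamCPU does when tf_cpu_mode > 0.
    def zeros_like_params():
        return [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    if on_cpu:
        # Pin the slot variables to the host; the moment tensors are read
        # and written across the PCIe bus each step, trading VRAM for
        # system RAM and transfer time.
        with K.tf.device("/cpu:0"):
            return zeros_like_params(), zeros_like_params()
    return zeros_like_params(), zeros_like_params()

With tf_cpu_mode=2 the moment update arithmetic itself (the m_t / v_t expressions in get_updates below) is also placed on the CPU, saving a little more VRAM at the cost of extra host computation.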


@@ -72,7 +72,7 @@ RandomNormal = keras.initializers.RandomNormal
 Model = keras.models.Model
 Adam = keras.optimizers.Adam
-DFLOptimizer = nnlib.DFLOptimizer
+AdamCPU = nnlib.AdamCPU
 modelify = nnlib.modelify
 gaussian_blur = nnlib.gaussian_blur
@@ -434,28 +434,93 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
         return dict(list(base_config.items()) + list(config.items()))
 nnlib.Scale = Scale
-class DFLOptimizer(keras.optimizers.Optimizer):
-    def __init__(self, lr=0.001, **kwargs):
-        super(DFLOptimizer, self).__init__(**kwargs)
+class AdamCPU(keras.optimizers.Optimizer):
+    """Adam optimizer.
+    Default parameters follow those provided in the original paper.
+    # Arguments
+        lr: float >= 0. Learning rate.
+        beta_1: float, 0 < beta < 1. Generally close to 1.
+        beta_2: float, 0 < beta < 1. Generally close to 1.
+        epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
+        decay: float >= 0. Learning rate decay over each update.
+        amsgrad: boolean. Whether to apply the AMSGrad variant of this
+            algorithm from the paper "On the Convergence of Adam and
+            Beyond".
+    # References
+        - [Adam - A Method for Stochastic Optimization](
+          https://arxiv.org/abs/1412.6980v8)
+        - [On the Convergence of Adam and Beyond](
+          https://openreview.net/forum?id=ryQu7f-RZ)
+    """
+    def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
+                 epsilon=None, decay=0., amsgrad=False, tf_cpu_mode=0, **kwargs):
+        super(AdamCPU, self).__init__(**kwargs)
         with K.name_scope(self.__class__.__name__):
             self.iterations = K.variable(0, dtype='int64', name='iterations')
             self.lr = K.variable(lr, name='lr')
-            self.beta_1 = K.variable(0.9, name='beta_1')
-            self.beta_2 = K.variable(0.998, name='beta_2')
+            self.beta_1 = K.variable(beta_1, name='beta_1')
+            self.beta_2 = K.variable(beta_2, name='beta_2')
+            self.decay = K.variable(decay, name='decay')
+        if epsilon is None:
+            epsilon = K.epsilon()
+        self.epsilon = epsilon
+        self.initial_decay = decay
+        self.amsgrad = amsgrad
+        self.tf_cpu_mode = tf_cpu_mode
     @keras.legacy.interfaces.legacy_get_updates_support
     def get_updates(self, loss, params):
         grads = self.get_gradients(loss, params)
         self.updates = [K.update_add(self.iterations, 1)]
-        lr_t = self.lr * ( ( K.cast(self.iterations, K.floatx()) ) % 100 + 1 ) / 100.0
+        lr = self.lr
+        if self.initial_decay > 0:
+            lr = lr * (1. / (1. + self.decay * K.cast(self.iterations,
+                                                      K.dtype(self.decay))))
+        t = K.cast(self.iterations, K.floatx()) + 1
+        lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
+                     (1. - K.pow(self.beta_1, t)))
-        self.weights = []
-        for p, g in zip(params, grads):
+        if self.tf_cpu_mode > 0:
+            with K.tf.device("/cpu:0"):
+                ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
+                vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
+                if self.amsgrad:
+                    vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
+                else:
+                    vhats = [K.zeros(1) for _ in params]
+        else:
+            ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
+            vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
+            if self.amsgrad:
+                vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
+            else:
+                vhats = [K.zeros(1) for _ in params]
-            m_t = (1. - self.beta_1) * g
-            v_t = (1. - self.beta_2) * K.square(g)
-            new_p = p - lr_t * m_t / (K.sqrt(v_t) + K.epsilon() )
+        self.weights = [self.iterations] + ms + vs + vhats
+        for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
+            if self.tf_cpu_mode == 2:
+                with K.tf.device("/cpu:0"):
+                    m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
+                    v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
+            else:
+                m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
+                v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
+            if self.amsgrad:
+                vhat_t = K.maximum(vhat, v_t)
+                p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
+                self.updates.append(K.update(vhat, vhat_t))
+            else:
+                p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
+            self.updates.append(K.update(m, m_t))
+            self.updates.append(K.update(v, v_t))
+            new_p = p_t
             # Apply constraints.
             if getattr(p, 'constraint', None) is not None:
@@ -467,13 +532,14 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
     def get_config(self):
         config = {'lr': float(K.get_value(self.lr)),
                   'beta_1': float(K.get_value(self.beta_1)),
-                  'beta_2': float(K.get_value(self.beta_2))
-                  }
-        base_config = super(DFLOptimizer, self).get_config()
+                  'beta_2': float(K.get_value(self.beta_2)),
+                  'decay': float(K.get_value(self.decay)),
+                  'epsilon': self.epsilon,
+                  'amsgrad': self.amsgrad}
+        base_config = super(AdamCPU, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
-nnlib.DFLOptimizer = DFLOptimizer
+nnlib.AdamCPU = AdamCPU
 '''
 not implemented in plaidML
 class ReflectionPadding2D(keras.layers.Layer):