Colombo 2019-11-24 19:51:07 +04:00
parent 1bfd65abe5
commit 77b390c04b
4 changed files with 150 additions and 25 deletions


@@ -92,6 +92,7 @@ Model = keras.models.Model
Adam = nnlib.Adam
RMSprop = nnlib.RMSprop
LookaheadOptimizer = nnlib.LookaheadOptimizer
modelify = nnlib.modelify
gaussian_blur = nnlib.gaussian_blur
@@ -936,7 +937,85 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
        base_config = super(Adam, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
nnlib.Adam = Adam
class LookaheadOptimizer(keras.optimizers.Optimizer):
    # Lookahead wrapper (Zhang et al., 2019): keeps a slow copy of every weight and,
    # every `sync_period` iterations, moves both the fast and the slow weights to
    # slow + slow_step * (fast - slow).
    def __init__(self, optimizer, sync_period=5, slow_step=0.5, tf_cpu_mode=0, **kwargs):
        super(LookaheadOptimizer, self).__init__(**kwargs)
        self.optimizer = optimizer
        self.tf_cpu_mode = tf_cpu_mode
        with K.name_scope(self.__class__.__name__):
            self.sync_period = K.variable(sync_period, dtype='int64', name='sync_period')
            self.slow_step = K.variable(slow_step, name='slow_step')

    @property
    def lr(self):
        return self.optimizer.lr

    @lr.setter
    def lr(self, lr):
        self.optimizer.lr = lr

    @property
    def learning_rate(self):
        return self.optimizer.learning_rate

    @learning_rate.setter
    def learning_rate(self, learning_rate):
        self.optimizer.learning_rate = learning_rate

    @property
    def iterations(self):
        return self.optimizer.iterations

    def get_updates(self, loss, params):
        # True on every sync_period-th iteration
        sync_cond = K.equal((self.iterations + 1) // self.sync_period * self.sync_period, (self.iterations + 1))

        # Slow-weight copies, optionally placed on the CPU to save GPU memory
        e = K.tf.device("/cpu:0") if self.tf_cpu_mode > 0 else None
        if e: e.__enter__()
        slow_params = [K.variable(K.get_value(p), name='sp_{}'.format(i)) for i, p in enumerate(params)]
        if e: e.__exit__(None, None, None)

        # Regular (fast) updates from the wrapped optimizer
        self.updates = self.optimizer.get_updates(loss, params)

        slow_updates = []
        for p, sp in zip(params, slow_params):
            e = K.tf.device("/cpu:0") if self.tf_cpu_mode == 2 else None
            if e: e.__enter__()
            sp_t = sp + self.slow_step * (p - sp)
            if e: e.__exit__(None, None, None)

            # On sync steps, move the slow weights toward the fast weights...
            slow_updates.append(K.update(sp, K.switch(
                sync_cond,
                sp_t,
                sp,
            )))
            # ...and reset the fast weights to the same interpolated value
            slow_updates.append(K.update_add(p, K.switch(
                sync_cond,
                sp_t - p,
                K.zeros_like(p),
            )))

        self.updates += slow_updates
        self.weights = self.optimizer.weights + slow_params
        return self.updates

    def get_config(self):
        config = {
            'optimizer': keras.optimizers.serialize(self.optimizer),
            'sync_period': int(K.get_value(self.sync_period)),
            'slow_step': float(K.get_value(self.slow_step)),
        }
        base_config = super(LookaheadOptimizer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        optimizer = keras.optimizers.deserialize(config.pop('optimizer'))
        return cls(optimizer, **config)
nnlib.LookaheadOptimizer = LookaheadOptimizer
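
For context, a minimal usage sketch of the wrapper added above (illustrative only, not part of this commit; it assumes LookaheadOptimizer has been brought into scope the way the first hunk exposes it, and wraps a plain keras.optimizers.Adam as the inner optimizer):

    import numpy as np
    import keras

    # Hypothetical example -- any Keras Optimizer instance can serve as the inner "fast" optimizer.
    inner = keras.optimizers.Adam(lr=5e-5)
    opt = LookaheadOptimizer(inner, sync_period=5, slow_step=0.5)  # sync slow/fast weights every 5 steps

    model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer=opt, loss='mse')  # get_updates() above is built when training starts
    model.fit(np.random.rand(32, 4), np.random.rand(32, 1), epochs=1, verbose=0)

Every sync_period batches sync_cond fires, the slow copies move to sp + slow_step * (p - sp), and the fast weights are reset to the same point; on all other batches only the wrapped optimizer's updates apply.
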
class DenseMaxout(keras.layers.Layer):
    """A dense maxout layer.
    A `MaxoutDense` layer takes the element-wise maximum of