@tf_export('keras.optimizers.adamsss')
class Adamsss(Optimizer):
    """Adaptive-moment optimizer with a scheduled momentum coefficient.

    Every parameter gets two accumulators: an exponential moving average of
    its gradient (decay `beta_1`) and one of its squared gradient (decay
    `beta_2`). The applied step mixes the bias-corrected gradient with the
    bias-corrected first moment, using a momentum coefficient derived from
    `beta_1` that changes with the iteration count.

    NOTE(review): the update rule looks like Nadam (Adam with Nesterov
    momentum) -- confirm against the upstream implementation.

    Arguments:
        lr: learning rate; stored as a backend variable.
        beta_1: decay rate of the first-moment accumulator and base of the
            momentum schedule; stored as a backend variable.
        beta_2: decay rate of the second-moment accumulator; stored as a
            backend variable.
        epsilon: constant added to the denominator of the update. `None`
            selects `K.epsilon()`.
        schedule_decay: factor applied to the step count in the exponent
            of the momentum schedule.
        **kwargs: forwarded unchanged to the `Optimizer` constructor.
    """

    def __init__(self,
                 lr=0.002,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=None,
                 schedule_decay=0.004,
                 **kwargs):
        super(Adamsss, self).__init__(**kwargs)
        # Backend variables are created under a scope named after the class.
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            # Running product of the momentum coefficients; starts at 1.
            self.m_schedule = K.variable(1., name='m_schedule')
            self.lr = K.variable(lr, name='lr')
            self.beta_1 = K.variable(beta_1, name='beta_1')
            self.beta_2 = K.variable(beta_2, name='beta_2')
        # epsilon and schedule_decay are kept as plain Python values.
        self.epsilon = K.epsilon() if epsilon is None else epsilon
        self.schedule_decay = schedule_decay

    def _momentum_cache(self, step):
        """Return `beta_1 * (1 - 0.5 * 0.96 ** (step * schedule_decay))`."""
        decay_term = math_ops.pow(
            K.cast_to_floatx(0.96), step * self.schedule_decay)
        return self.beta_1 * (1. - 0.5 * decay_term)

    def get_updates(self, loss, params):
        """Build the list of update operations for one optimization step.

        Arguments:
            loss: passed to `self.get_gradients` together with `params`.
            params: the variables to update; one pair of zero-initialised
                accumulators is created for each.

        Returns:
            `self.updates`: the iteration-counter increment, the
            `m_schedule` update, and for every parameter the assignments
            to both accumulators and to the parameter itself.
        """
        grads = self.get_gradients(loss, params)
        self.updates = [state_ops.assign_add(self.iterations, 1)]

        # 1-based step index in the backend float type.
        t = math_ops.cast(self.iterations, K.floatx()) + 1

        # Momentum coefficients for the current and the following step.
        mu_now = self._momentum_cache(t)
        mu_next = self._momentum_cache(t + 1)
        schedule_now = self.m_schedule * mu_now
        schedule_next = self.m_schedule * mu_now * mu_next
        # Queued as a (variable, new_value) pair, unlike the assign ops used
        # for every other update in this method.
        self.updates.append((self.m_schedule, schedule_now))

        # All first-moment slots are created before any second-moment slot.
        first_moments = [K.zeros(K.int_shape(p)) for p in params]
        second_moments = [K.zeros(K.int_shape(p)) for p in params]
        # NOTE(review): m_schedule is not part of `weights` -- confirm this
        # is intended.
        self.weights = [self.iterations] + first_moments + second_moments

        for param, grad, m_acc, v_acc in zip(
                params, grads, first_moments, second_moments):
            grad_hat = grad / (1. - schedule_now)

            m_new = self.beta_1 * m_acc + (1. - self.beta_1) * grad
            m_hat = m_new / (1. - schedule_next)

            v_new = (
                self.beta_2 * v_acc +
                (1. - self.beta_2) * math_ops.square(grad))
            v_hat = v_new / (1. - math_ops.pow(self.beta_2, t))

            # Blend of the corrected gradient and the corrected first moment.
            m_bar = (1. - mu_now) * grad_hat + mu_next * m_hat

            self.updates.extend([
                state_ops.assign(m_acc, m_new),
                state_ops.assign(v_acc, v_new),
            ])

            step = self.lr * m_bar / (K.sqrt(v_hat) + self.epsilon)
            new_param = param - step

            # Apply the parameter's constraint when it defines one.
            constraint = getattr(param, 'constraint', None)
            if constraint is not None:
                new_param = constraint(new_param)

            self.updates.append(state_ops.assign(param, new_param))
        return self.updates

    def get_config(self):
        """Return the base optimizer config extended with this class's
        hyperparameters as plain Python values."""
        merged = dict(super(Adamsss, self).get_config().items())
        # Entries set here take precedence over same-named base entries.
        merged.update({
            'lr': float(K.get_value(self.lr)),
            'beta_1': float(K.get_value(self.beta_1)),
            'beta_2': float(K.get_value(self.beta_2)),
            'epsilon': self.epsilon,
            'schedule_decay': self.schedule_decay,
        })
        return merged