@tf_export('keras.optimizers.adamsss')
class Adamsss(Optimizer):
    """Adam-style optimizer whose momentum coefficient follows a schedule.

    Per-parameter first and second moment estimates are kept as in Adam.
    The first-moment coefficient is additionally scaled at every step by
    `1 - 0.5 * 0.96 ** (step * schedule_decay)`, and the applied update
    mixes the bias-corrected gradient with the bias-corrected first moment.

    NOTE(review): the update rule looks like Nesterov-accelerated Adam
    (Nadam) -- confirm against the intended reference before relying on
    that name.

    Arguments:
        lr: Initial learning rate, stored in a backend variable.
        beta_1: Decay rate of the first moment, stored in a backend
            variable.
        beta_2: Decay rate of the second moment, stored in a backend
            variable.
        epsilon: Constant added to the denominator of the update. When
            `None`, `K.epsilon()` is used instead.
        schedule_decay: Exponent scale of the momentum schedule; kept as
            the plain value passed in.
        **kwargs: Forwarded unchanged to the `Optimizer` constructor.
    """

    def __init__(self,
                 lr=0.002,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=None,
                 schedule_decay=0.004,
                 **kwargs):
        super(Adamsss, self).__init__(**kwargs)
        scope_name = self.__class__.__name__
        with K.name_scope(scope_name):
            # Step counter; incremented once per `get_updates` execution.
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            # Running product of the scheduled momentum coefficients.
            self.m_schedule = K.variable(1., name='m_schedule')
            self.lr = K.variable(lr, name='lr')
            self.beta_1 = K.variable(beta_1, name='beta_1')
            self.beta_2 = K.variable(beta_2, name='beta_2')
        # epsilon and schedule_decay stay plain Python values.
        self.epsilon = K.epsilon() if epsilon is None else epsilon
        self.schedule_decay = schedule_decay

    def get_updates(self, loss, params):
        """Build the list of update operations for one optimization step.

        Arguments:
            loss: Value handed to `self.get_gradients`.
            params: Iterable of parameters to update. A parameter may
                expose a `constraint` callable, which is applied to its
                new value before assignment.

        Returns:
            `self.updates`: the iteration increment, the `m_schedule`
            `(variable, new_value)` pair, and then, for each parameter,
            assignments to its first moment, second moment and value.
        """
        gradients = self.get_gradients(loss, params)
        self.updates = [state_ops.assign_add(self.iterations, 1)]

        step = math_ops.cast(self.iterations, K.floatx()) + 1

        def scheduled_momentum(at_step):
            # beta_1 * (1 - 0.5 * 0.96 ** (at_step * schedule_decay))
            decay_term = math_ops.pow(
                K.cast_to_floatx(0.96), at_step * self.schedule_decay)
            return self.beta_1 * (1. - 0.5 * decay_term)

        momentum_now = scheduled_momentum(step)
        momentum_next = scheduled_momentum(step + 1)

        # Schedule products including the current step, and the current
        # plus the following step.
        schedule_now = self.m_schedule * momentum_now
        schedule_next = self.m_schedule * momentum_now * momentum_next
        # Registered as a (variable, new_value) pair, not an assign op.
        self.updates.append((self.m_schedule, schedule_now))

        # All first-moment slots are created before any second-moment slot.
        first_moments = [K.zeros(K.int_shape(param)) for param in params]
        second_moments = [K.zeros(K.int_shape(param)) for param in params]

        # NOTE(review): m_schedule is not part of `weights`.
        self.weights = [self.iterations] + first_moments + second_moments

        slots = zip(params, gradients, first_moments, second_moments)
        for param, grad, first, second in slots:
            grad_hat = grad / (1. - schedule_now)

            first_new = self.beta_1 * first + (1. - self.beta_1) * grad
            first_hat = first_new / (1. - schedule_next)

            second_new = (self.beta_2 * second +
                          (1. - self.beta_2) * math_ops.square(grad))
            second_hat = second_new / (1. - math_ops.pow(self.beta_2, step))

            # Blend of the corrected gradient and the corrected momentum.
            blended = ((1. - momentum_now) * grad_hat +
                       momentum_next * first_hat)

            self.updates.append(state_ops.assign(first, first_new))
            self.updates.append(state_ops.assign(second, second_new))

            candidate = param - self.lr * blended / (
                K.sqrt(second_hat) + self.epsilon)

            # Apply the parameter's constraint when one is attached.
            constraint = getattr(param, 'constraint', None)
            if constraint is not None:
                candidate = constraint(candidate)

            self.updates.append(state_ops.assign(param, candidate))
        return self.updates

    def get_config(self):
        """Return the optimizer configuration as a dict.

        The base-class configuration is extended with this class's
        hyperparameters; on a key clash the value from this class wins.
        """
        own_config = {
            'lr': float(K.get_value(self.lr)),
            'beta_1': float(K.get_value(self.beta_1)),
            'beta_2': float(K.get_value(self.beta_2)),
            'epsilon': self.epsilon,
            'schedule_decay': self.schedule_decay,
        }
        merged = dict(super(Adamsss, self).get_config())
        merged.update(own_config)
        return merged