Commit

Added logsumexp to backend.
phipleg committed Apr 21, 2017
1 parent 9217eff commit 58775b8
Showing 3 changed files with 49 additions and 0 deletions.
7 changes: 7 additions & 0 deletions keras/backend/tensorflow_backend.py
@@ -1274,6 +1274,13 @@ def log(x):
    return tf.log(x)


def logsumexp(x, axis=None, keepdims=False):
    """Returns `log(sum(exp(x), axis=axis, keepdims=keepdims))`.

    This is more numerically stable than computing the expression
    naively: it avoids overflow in `exp` for large inputs and
    underflow in `log` for small sums.
    """
    axis = _normalize_axis(axis, ndim(x))
    return tf.reduce_logsumexp(x, reduction_indices=axis, keep_dims=keepdims)


def round(x):
"""Element-wise rounding to the closest integer.
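For context: `tf.reduce_logsumexp` gains its stability from the standard max-shift identity, `log(sum(exp(x))) = m + log(sum(exp(x - m)))` with `m = max(x)`, which keeps the largest argument to `exp` at zero. A minimal NumPy sketch of the same identity (illustrative only; `logsumexp_np` is a hypothetical helper, not part of this diff):

import numpy as np

def logsumexp_np(x, axis=None, keepdims=False):
    # Subtract the max so the largest argument to exp() is 0; this
    # prevents overflow, and the shift is added back outside the log.
    m = np.max(x, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))
    return out if keepdims else np.squeeze(out, axis=axis)

print(logsumexp_np(np.array([1.1, 0.8, 0.9])))  # same value as the naive form for small inputs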
8 changes: 8 additions & 0 deletions keras/backend/theano_backend.py
@@ -515,6 +515,14 @@ def log(x):
    return T.log(x)


def logsumexp(x, axis=None, keepdims=False):
    """Returns `log(sum(exp(x), axis=axis, keepdims=keepdims))`.

    This is more numerically stable than computing the expression
    naively: it avoids overflow in `exp` for large inputs and
    underflow in `log` for small sums.
    """
    # Theano's graph optimizer rewrites this naive expression into a
    # numerically stable form (see https://github.com/Theano/Theano/pull/4736),
    # so we can write it directly:
    return T.log(T.sum(T.exp(x), axis=axis, keepdims=keepdims))


def round(x):
    return T.round(x, mode='half_to_even')

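The Theano version leans on the graph optimizer instead of writing the shift by hand. A plain-NumPy comparison of the naive and shifted forms shows what the rewrite buys (illustrative sketch, not part of the diff):

import numpy as np

x = np.array([1000.0, 1000.0])

# Naive form: exp(1000) overflows to inf, so the whole result is inf.
naive = np.log(np.sum(np.exp(x)))             # inf, with an overflow warning

# Shifted form (what the optimizer effectively produces): finite and exact.
m = np.max(x)
stable = m + np.log(np.sum(np.exp(x - m)))    # 1000.6931... = 1000 + log(2)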
34 changes: 34 additions & 0 deletions tests/keras/backend/backend_test.py
@@ -576,6 +576,40 @@ def step_function(x, states):
    assert_allclose(tf_last_output, th_last_output, atol=1e-04)
    assert_allclose(tf_outputs, th_outputs, atol=1e-04)

@pytest.mark.parametrize('x_np,axis,keepdims', [
    (np.array([1.1, 0.8, 0.9]), 0, False),
    (np.array([[1.1, 0.8, 0.9]]), 0, False),
    (np.array([[1.1, 0.8, 0.9]]), 1, False),
    (np.array([[1.1, 0.8, 0.9]]), -1, False),
    (np.array([[1.1, 0.8, 0.9]]), 1, True),
    (np.array([[1.1], [1.2]]), 0, False),
    (np.array([[1.1], [1.2]]), 1, False),
    (np.array([[1.1], [1.2]]), -1, False),
    (np.array([[1.1], [1.2]]), -1, True),
    (np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), 0, False),
    (np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), 1, False),
    (np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), -1, False),
])
@pytest.mark.parametrize('K', [KTH, KTF], ids=["KTH", "KTF"])
def test_logsumexp(self, x_np, axis, keepdims, K):
    '''
    Check that `K.logsumexp` matches a naive NumPy reference for values
    close to one, where the naive expression is numerically safe.
    '''
    x = K.variable(x_np)
    assert_allclose(K.eval(K.logsumexp(x, axis=axis, keepdims=keepdims)),
                    np.log(np.sum(np.exp(x_np), axis=axis, keepdims=keepdims)),
                    rtol=1e-5)

@pytest.mark.parametrize('K', [KTH, KTF], ids=["KTH", "KTF"])
def test_logsumexp_optim(self, K):
    '''
    Check that the stable implementation does not overflow for large
    inputs, where the naive expression would return `inf`.
    '''
    # exp(1e4) overflows in floating point, so only a numerically stable
    # implementation can recover the correct value of ~1e4.
    x_np = np.array([1e+4, 1e-4])
    assert_allclose(K.eval(K.logsumexp(K.variable(x_np), axis=0)),
                    1e4,
                    rtol=1e-5)

def test_switch(self):
    val = np.random.random()
    xth = KTH.variable(val)
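As a quick sanity check of the expected value in test_logsumexp_optim: with x = [1e4, 1e-4], the max-shift form gives 1e4 + log(exp(0) + exp(1e-4 - 1e4)), and the second exponential underflows to zero, so the result is 1e4 to within rounding. An illustrative NumPy check (not part of the test suite):

import numpy as np

x = np.array([1e+4, 1e-4])
m = np.max(x)                              # 10000.0
# exp(1e-4 - 1e4) underflows to 0.0, leaving log(1.0) == 0.0.
print(m + np.log(np.sum(np.exp(x - m))))   # 10000.0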
