Skip to content

Commit

Permalink
Include final ngram in NumpyOps.ngrams (#514)
Browse files Browse the repository at this point in the history
* Include final ngram in NumpyOps.ngrams

* Additionally return empty lists for invalid n-gram lengths

* Temporarily restrict mypy version

* Revert "Temporarily restrict mypy version"

This reverts commit 71bf6ec.
  • Loading branch information
adrianeboyd authored Jul 1, 2021
1 parent 8ac87d9 commit 367c895
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
6 changes: 4 additions & 2 deletions thinc/backends/numpy_ops.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -345,11 +345,13 @@ class NumpyOps(Ops):
return weights, gradient, mom1, mom2

# Compute one 64-bit hash per window of `n` consecutive entries of `keys`.
#
# NOTE: this span is a rendered diff, not a single runnable definition. Lines
# tagged "old" are the pre-commit versions that the commit removes; lines
# tagged "new" are their replacements / additions (4 additions, 2 deletions).
# Untagged lines are unchanged context.
def ngrams(self, int n, const uint64_t[::1] keys):
if n < 1:  # new: invalid (non-positive) n-gram length ...
return self.alloc((0,), dtype="uint64")  # new: ... yields an empty uint64 array
# NOTE(review): &keys[0] is taken before any check on keys.shape[0]; what
# happens for an empty `keys` is not visible from here -- confirm.
keys_ = <uint64_t*>&keys[0]
length = max(0, keys.shape[0]-n)  # old: one window short, so the final n-gram was dropped
length = max(0, keys.shape[0]-(n-1))  # new: number of windows of size n, clamped at 0
cdef np.ndarray output_ = self.alloc((length,), dtype="uint64")
output = <uint64_t*>output_.data  # raw pointer into the freshly allocated output buffer
for i in range(keys.shape[0]-n):  # old: loop bound matching the old length
for i in range(keys.shape[0]-(n-1)):  # new: loop bound matching the new length (no iterations if <= 0)
# Hash the n*sizeof(key) bytes starting at keys_[i], i.e. keys i .. i+n-1.
# The trailing 0 is presumably the hash seed -- confirm against hash64.
output[i] = hash64(&keys_[i], n*sizeof(keys_[0]), 0)
return output_

Expand Down
9 changes: 9 additions & 0 deletions thinc/tests/backends/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,3 +485,12 @@ def test_multibatch():
ops.multibatch(10, (i for i in range(100)), (i for i in range(100)))
with pytest.raises(ValueError):
ops.multibatch(10, arr1, (i for i in range(100)), arr2)


def test_ngrams():
    """Check how many n-grams ``ops.ngrams`` returns for various lengths.

    For a key array of length ``L`` the result is expected to contain
    ``L - (n - 1)`` entries while that is positive and none otherwise. A
    negative ``n`` is expected to give an empty result as well.
    """
    ops = get_current_ops()
    keys = numpy.asarray([1, 2, 3, 4, 5], dtype=numpy.uint64)
    n_keys = keys.shape[0]

    # From unigrams up to sizes well past the end of the array.
    for size in range(1, 10):
        expected = max(0, n_keys - (size - 1))
        assert len(ops.ngrams(size, keys)) == expected

    # Invalid (negative) n-gram length.
    assert len(ops.ngrams(-1, keys)) == 0
    # n-gram length one larger than the array itself.
    assert len(ops.ngrams(n_keys + 1, keys)) == 0

0 comments on commit 367c895

Please sign in to comment.