Skip to content

Commit

Permalink
raise IndexError when transforming unknown tokens
Browse files Browse the repository at this point in the history
Requires Cython from the master branch, since unordered_map::at is broken in the release version.
  • Loading branch information
Dobatymo committed Jan 4, 2022
1 parent f4184d3 commit 42a8c47
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 8 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/pythonpackage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install setuptools cython numpy
python -m pip install setuptools Cython@git+https://github.com/cython/cython.git@c25c87d71107e634162302f7f61a119eff539a48 numpy
python -m pip install -r requirements-test.txt
- name: Build
run: |
Expand All @@ -63,7 +63,7 @@ jobs:
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install setuptools wheel cython cibuildwheel==2.3.1
python -m pip install setuptools wheel Cython@git+https://github.com/cython/cython.git@c25c87d71107e634162302f7f61a119eff539a48 cibuildwheel==2.3.1
- name: Build wheels
run: |
python -m cibuildwheel --output-dir wheelhouse
Expand All @@ -88,7 +88,7 @@ jobs:
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install setuptools cython numpy
python -m pip install setuptools Cython@git+https://github.com/cython/cython.git@c25c87d71107e634162302f7f61a119eff539a48 numpy
- name: Build dists
run: |
python setup.py sdist
Expand Down
7 changes: 6 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ repos:
rev: 'v4.0.1'
hooks:
- id: check-json
- id: check-yaml
- id: check-toml
- id: check-yaml
- id: check-case-conflict
- id: check-added-large-files
- id: debug-statements
Expand All @@ -14,6 +14,11 @@ repos:
- id: trailing-whitespace
args: ["--markdown-linebreak-ext=md"]
- id: end-of-file-fixer
- repo: https://github.com/asottile/pyupgrade
rev: 'v2.29.1'
hooks:
- id: pyupgrade
args: ["--py36-plus"]
- repo: https://github.com/psf/black
rev: '21.10b0'
hooks:
Expand Down
4 changes: 2 additions & 2 deletions encoders/cyfuncs.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ cdef class BytesLabelEncoder:
cdef size_t[::1] out = view.array(shape=(len(seq), ), itemsize=sizeof(size_t), format="Q")

for i, item in enumerate(seq):
out[i] = self._labels[item]
out[i] = self._labels.at(item)

return np.asarray(out)

Expand Down Expand Up @@ -118,7 +118,7 @@ cdef class StringLabelEncoder:
if not isinstance(item, str):
raise TypeError(f"expected bytes, {type(item)} found")

out[i] = self._labels[PyUnicodeSmartPtr(<PyObject *>item)]
out[i] = self._labels.at(PyUnicodeSmartPtr(<PyObject *>item))

return np.asarray(out)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools", "wheel", "Cython", "numpy"]
requires = ["setuptools", "wheel", "Cython@git+https://github.com/cython/cython.git@c25c87d71107e634162302f7f61a119eff539a48", "numpy"]

[tool.black]
line-length = 120
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"warn.unused_result": True,
}

with open("readme.md", "r", encoding="utf-8") as fr:
with open("readme.md", encoding="utf-8") as fr:
long_description = fr.read()

if __name__ == "__main__":
Expand Down
16 changes: 16 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ def test_bytes_inv(self):
result = le.inverse_transform(encoded)
np.testing.assert_array_equal(result, np.array(stuff))

def test_bytes_transform_fail(self):
    """transform() must raise IndexError for byte tokens the encoder never saw."""
    # Finalize an encoder without adding any labels, so every token is unknown.
    encoder = BytesLabelEncoder()
    encoder.finalize()

    unknown_tokens = [b"asd", b"qwe", b"zxc"] * 2

    with self.assertRaises(IndexError):
        encoder.transform(unknown_tokens)

def test_bytes_typeerror(self):

with self.assertRaises(TypeError):
Expand Down Expand Up @@ -71,6 +79,14 @@ def test_str_inv(self):
result = le.inverse_transform(encoded)
np.testing.assert_array_equal(result, np.array(stuff))

def test_str_transform_fail(self):
    """transform() must raise IndexError for str tokens the encoder never saw."""
    # Finalize an encoder without adding any labels, so every token is unknown.
    encoder = StringLabelEncoder()
    encoder.finalize()

    # Non-ASCII and astral-plane characters exercise the unicode lookup path.
    unknown_tokens = ["asü", "😀", "zxä"] * 2

    with self.assertRaises(IndexError):
        encoder.transform(unknown_tokens)

def test_str_typeerror(self):

with self.assertRaises(TypeError):
Expand Down

0 comments on commit 42a8c47

Please sign in to comment.