Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate NVText Byte Pair Encoding APIs to pylibcudf #17101

Merged
merged 6 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
================
byte_pair_encode
================

.. automodule:: pylibcudf.nvtext.byte_pair_encode
:members:
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ nvtext
generate_ngrams
jaccard
minhash
byte_pair_encode
ngrams_tokenize
normalize
replace
Expand Down
45 changes: 9 additions & 36 deletions python/cudf/cudf/_lib/nvtext/byte_pair_encode.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,22 @@

from cudf.core.buffer import acquire_spill_lock

from libcpp.memory cimport unique_ptr
from libcpp.utility cimport move

from pylibcudf.libcudf.column.column cimport column
from pylibcudf.libcudf.column.column_view cimport column_view
from pylibcudf.libcudf.nvtext.byte_pair_encode cimport (
bpe_merge_pairs as cpp_bpe_merge_pairs,
byte_pair_encoding as cpp_byte_pair_encoding,
load_merge_pairs as cpp_load_merge_pairs,
)
from pylibcudf.libcudf.scalar.scalar cimport string_scalar

from cudf._lib.column cimport Column
from cudf._lib.scalar cimport DeviceScalar


cdef class BPEMergePairs:
cdef unique_ptr[cpp_bpe_merge_pairs] c_obj

def __cinit__(self, Column merge_pairs):
cdef column_view c_pairs = merge_pairs.view()
with nogil:
self.c_obj = move(cpp_load_merge_pairs(c_pairs))
from pylibcudf import nvtext
from pylibcudf.nvtext.byte_pair_encode import BPEMergePairs # no-cython-lint


@acquire_spill_lock()
def byte_pair_encoding(
Column strings,
BPEMergePairs merge_pairs,
object merge_pairs,
object separator
):
cdef column_view c_strings = strings.view()
cdef DeviceScalar d_separator = separator.device_value
cdef const string_scalar* c_separator = <const string_scalar*>d_separator\
.get_raw_ptr()
cdef unique_ptr[column] c_result
with nogil:
c_result = move(
cpp_byte_pair_encoding(
c_strings,
merge_pairs.c_obj.get()[0],
c_separator[0]
)
return Column.from_pylibcudf(
nvtext.byte_pair_encode.byte_pair_encoding(
strings.to_pylibcudf(mode="read"),
merge_pairs,
separator.device_value.c_value
)

return Column.from_unique_ptr(move(c_result))
)
7 changes: 5 additions & 2 deletions python/cudf/cudf/core/byte_pair_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from __future__ import annotations

import pylibcudf as plc

import cudf
from cudf._lib.nvtext.byte_pair_encode import (
BPEMergePairs as cpp_merge_pairs,
byte_pair_encoding as cpp_byte_pair_encoding,
)

Expand All @@ -25,7 +26,9 @@ class BytePairEncoder:
"""

def __init__(self, merges_pair: "cudf.Series"):
self.merge_pairs = cpp_merge_pairs(merges_pair._column)
self.merge_pairs = plc.nvtext.byte_pair_encode.BPEMergePairs(
merges_pair._column.to_pylibcudf(mode="read")
)

def __call__(self, text, separator: str = " ") -> cudf.Series:
"""
Expand Down
5 changes: 3 additions & 2 deletions python/pylibcudf/pylibcudf/nvtext/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# the License.
# =============================================================================

set(cython_sources edit_distance.pyx generate_ngrams.pyx jaccard.pyx minhash.pyx
ngrams_tokenize.pyx normalize.pyx replace.pyx stemmer.pyx tokenize.pyx
set(cython_sources
edit_distance.pyx generate_ngrams.pyx jaccard.pyx minhash.pyx ngrams_tokenize.pyx normalize.pyx
replace.pyx stemmer.pyx tokenize.pyx byte_pair_encode.pyx
)

set(linked_libraries cudf::cudf)
Expand Down
2 changes: 2 additions & 0 deletions python/pylibcudf/pylibcudf/nvtext/__init__.pxd
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from . cimport (
byte_pair_encode,
edit_distance,
generate_ngrams,
jaccard,
Expand All @@ -17,6 +18,7 @@ __all__ = [
"generate_ngrams",
"jaccard",
"minhash",
"byte_pair_encode",
"ngrams_tokenize",
"normalize",
"replace",
Expand Down
2 changes: 2 additions & 0 deletions python/pylibcudf/pylibcudf/nvtext/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from . import (
byte_pair_encode,
edit_distance,
generate_ngrams,
jaccard,
Expand All @@ -17,6 +18,7 @@
"generate_ngrams",
"jaccard",
"minhash",
"byte_pair_encode",
"ngrams_tokenize",
"normalize",
"replace",
Expand Down
16 changes: 16 additions & 0 deletions python/pylibcudf/pylibcudf/nvtext/byte_pair_encode.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from libcpp.memory cimport unique_ptr
from pylibcudf.column cimport Column
from pylibcudf.libcudf.nvtext.byte_pair_encode cimport bpe_merge_pairs
from pylibcudf.scalar cimport Scalar


# Extension type wrapping a libcudf ``bpe_merge_pairs`` object.
cdef class BPEMergePairs:
# Owning pointer to the underlying C++ ``bpe_merge_pairs`` instance.
cdef unique_ptr[bpe_merge_pairs] c_obj

# Declaration of the byte pair encoding entry point: takes an input Column,
# a BPEMergePairs table and an optional separator Scalar; returns a Column.
cpdef Column byte_pair_encoding(
Column input,
BPEMergePairs merge_pairs,
# ``=*`` marks the argument as optional; the actual default value is
# supplied by the implementation (.pyx) file.
Scalar separator=*
)
70 changes: 70 additions & 0 deletions python/pylibcudf/pylibcudf/nvtext/byte_pair_encode.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from cython.operator cimport dereference
from libcpp.memory cimport unique_ptr
from libcpp.utility cimport move
from pylibcudf.column cimport Column
from pylibcudf.libcudf.column.column cimport column
from pylibcudf.libcudf.column.column_view cimport column_view
from pylibcudf.libcudf.nvtext.byte_pair_encode cimport (
byte_pair_encoding as cpp_byte_pair_encoding,
load_merge_pairs as cpp_load_merge_pairs,
)
from pylibcudf.libcudf.scalar.scalar cimport string_scalar
from pylibcudf.libcudf.scalar.scalar_factories cimport (
make_string_scalar as cpp_make_string_scalar,
)
from pylibcudf.scalar cimport Scalar


cdef class BPEMergePairs:
    """Merge-pairs table used by the BPE encoder.

    The table is built once, at construction time, from the given
    column of merge pairs.

    For details, see :cpp:class:`cudf::nvtext::bpe_merge_pairs`.
    """
    def __cinit__(self, Column merge_pairs):
        # Take the libcudf view while the GIL is still held; only the
        # C++ call below runs without it.
        cdef column_view pairs_view = merge_pairs.view()

        with nogil:
            # Keep ownership of the loaded table on this object.
            self.c_obj = move(cpp_load_merge_pairs(pairs_view))

cpdef Column byte_pair_encoding(
    Column input,
    BPEMergePairs merge_pairs,
    Scalar separator=None
):
    """
    Byte pair encode the input strings.

    For details, see :cpp:func:`cudf::nvtext::byte_pair_encoding`.

    Parameters
    ----------
    input : Column
        Strings to encode.
    merge_pairs : BPEMergePairs
        Substrings to rebuild each string on.
    separator : Scalar, optional
        String used to build the output after encoding. Default is a space.

    Returns
    -------
    Column
        An encoded column of strings.
    """
    cdef unique_ptr[column] c_result

    if separator is None:
        # No separator supplied: fall back to a string scalar holding a
        # single space, matching the documented default.
        separator = Scalar.from_libcudf(
            cpp_make_string_scalar(" ".encode())
        )

    with nogil:
        c_result = move(
            cpp_byte_pair_encoding(
                input.view(),
                dereference(merge_pairs.c_obj.get()),
                # The generic scalar pointer is reinterpreted as the
                # string scalar the libcudf API expects.
                dereference(<const string_scalar*>separator.c_obj.get()),
            )
        )

    return Column.from_libcudf(move(c_result))
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import pyarrow as pa
import pytest
from utils import assert_column_eq

import pylibcudf as plc


@pytest.fixture(scope="module")
def input_col():
    """Arrow array of merge-pair strings shared by the tests in this module."""
    merge_pairs = [
        "e n",
        "i t",
        "i s",
        "e s",
        "en t",
        "c e",
        "es t",
        "en ce",
        "t est",
        "s ent",
    ]
    return pa.array(merge_pairs)


@pytest.mark.parametrize(
    "separator", [None, plc.interop.from_arrow(pa.scalar("e"))]
)
def test_byte_pair_encoding(input_col, separator):
    """Encode two strings with the default and an explicit separator."""
    bpe = plc.nvtext.byte_pair_encode

    strings = plc.interop.from_arrow(
        pa.array(["test sentence", "thisis test"])
    )
    merge_pairs = bpe.BPEMergePairs(plc.interop.from_arrow(input_col))
    result = bpe.byte_pair_encoding(strings, merge_pairs, separator)

    # ``None`` exercises the default separator; otherwise the explicit
    # "e" scalar from the parametrization is used.
    expected = pa.array(
        ["test sent ence", "t h is is test"]
        if separator is None
        else ["teste esenteence", "teheiseise etest"]
    )
    assert_column_eq(result, expected)
Loading