Skip to content

Commit

Permalink
PyO3 0.22 (#1665)
Browse files Browse the repository at this point in the history
* PyO3 0.22

* Fix python stubs

* Remove name arg from PyModel::save Python signature

---------

Co-authored-by: Dimitris Iliopoulos <diliopoulos@fb.com>
  • Loading branch information
diliop and Dimitris Iliopoulos authored Nov 1, 2024
1 parent 41e0eaa commit 6ade8c2
Show file tree
Hide file tree
Showing 11 changed files with 51 additions and 51 deletions.
8 changes: 4 additions & 4 deletions bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ crate-type = ["cdylib"]

[dependencies]
rayon = "1.10"
serde = { version = "1.0", features = [ "rc", "derive" ]}
serde = { version = "1.0", features = ["rc", "derive"] }
serde_json = "1.0"
libc = "0.2"
env_logger = "0.11"
pyo3 = { version = "0.21" }
numpy = "0.21"
pyo3 = { version = "0.22", features = ["py-clone"] }
numpy = "0.22"
ndarray = "0.15"
itertools = "0.12"

Expand All @@ -24,7 +24,7 @@ path = "../../tokenizers"

[dev-dependencies]
tempfile = "3.10"
pyo3 = { version = "0.21", features = ["auto-initialize"] }
pyo3 = { version = "0.22", features = ["auto-initialize", "py-clone"] }

[features]
defaut = ["pyo3/extension-module"]
4 changes: 2 additions & 2 deletions bindings/python/src/decoders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ impl PyDecoder {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
match state.extract::<&[u8]>(py) {
Ok(s) => {
self.decoder = serde_json::from_slice(s.as_bytes()).map_err(|e| {
self.decoder = serde_json::from_slice(s).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to unpickle Decoder: {}",
e
Expand Down
4 changes: 2 additions & 2 deletions bindings/python/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ impl PyEncoding {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
match state.extract::<&[u8]>(py) {
Ok(s) => {
self.encoding = serde_json::from_slice(s.as_bytes()).map_err(|e| {
self.encoding = serde_json::from_slice(s).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to unpickle Encoding: {}",
e
Expand Down
8 changes: 4 additions & 4 deletions bindings/python/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ impl PyModel {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
match state.extract::<&[u8]>(py) {
Ok(s) => {
self.model = serde_json::from_slice(s.as_bytes()).map_err(|e| {
self.model = serde_json::from_slice(s).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to unpickle Model: {}",
e
Expand Down Expand Up @@ -181,7 +181,7 @@ impl PyModel {
///
/// Returns:
/// :obj:`List[str]`: The list of saved files
#[pyo3(text_signature = "(self, folder, prefix)")]
#[pyo3(signature = (folder, prefix=None, name=None), text_signature = "(self, folder, prefix)")]
fn save<'a>(
&self,
py: Python<'_>,
Expand Down Expand Up @@ -835,7 +835,7 @@ pub struct PyUnigram {}
#[pymethods]
impl PyUnigram {
#[new]
#[pyo3(text_signature = "(self, vocab, unk_id, byte_fallback)")]
#[pyo3(signature = (vocab=None, unk_id=None, byte_fallback=None), text_signature = "(self, vocab, unk_id, byte_fallback)")]
fn new(
vocab: Option<Vec<(String, f64)>>,
unk_id: Option<usize>,
Expand Down
4 changes: 2 additions & 2 deletions bindings/python/src/normalizers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ impl PyNormalizer {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
match state.extract::<&[u8]>(py) {
Ok(s) => {
self.normalizer = serde_json::from_slice(s.as_bytes()).map_err(|e| {
self.normalizer = serde_json::from_slice(s).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to unpickle Normalizer: {}",
e
Expand Down
4 changes: 2 additions & 2 deletions bindings/python/src/pre_tokenizers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ impl PyPreTokenizer {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
match state.extract::<&[u8]>(py) {
Ok(s) => {
let unpickled = serde_json::from_slice(s.as_bytes()).map_err(|e| {
let unpickled = serde_json::from_slice(s).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to unpickle PreTokenizer: {}",
e
Expand Down
8 changes: 4 additions & 4 deletions bindings/python/src/processors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ impl PyPostProcessor {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
match state.extract::<&[u8]>(py) {
Ok(s) => {
self.processor = serde_json::from_slice(s.as_bytes()).map_err(|e| {
self.processor = serde_json::from_slice(s).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to unpickle PostProcessor: {}",
e
Expand Down Expand Up @@ -272,7 +272,7 @@ impl From<PySpecialToken> for SpecialToken {
}

impl FromPyObject<'_> for PySpecialToken {
fn extract(ob: &PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
if let Ok(v) = ob.extract::<(String, u32)>() {
Ok(Self(v.into()))
} else if let Ok(v) = ob.extract::<(u32, String)>() {
Expand Down Expand Up @@ -312,7 +312,7 @@ impl From<PyTemplate> for Template {
}

impl FromPyObject<'_> for PyTemplate {
fn extract(ob: &PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
if let Ok(s) = ob.extract::<&str>() {
Ok(Self(
s.try_into().map_err(exceptions::PyValueError::new_err)?,
Expand Down
46 changes: 23 additions & 23 deletions bindings/python/src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use serde::Serialize;
use std::collections::{hash_map::DefaultHasher, HashMap};
use std::hash::{Hash, Hasher};

use numpy::{npyffi, PyArray1};
use numpy::{npyffi, PyArray1, PyArrayMethods};
use pyo3::class::basic::CompareOp;
use pyo3::exceptions;
use pyo3::intern;
Expand Down Expand Up @@ -156,7 +156,7 @@ impl PyAddedToken {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyDict>(py) {
match state.downcast_bound::<PyDict>(py) {
Ok(state) => {
for (key, value) in state {
let key: &str = key.extract()?;
Expand All @@ -172,7 +172,7 @@ impl PyAddedToken {
}
Ok(())
}
Err(e) => Err(e),
Err(e) => Err(e.into()),
}
}

Expand Down Expand Up @@ -263,10 +263,10 @@ impl PyAddedToken {

struct TextInputSequence<'s>(tk::InputSequence<'s>);
impl<'s> FromPyObject<'s> for TextInputSequence<'s> {
fn extract(ob: &'s PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'s, PyAny>) -> PyResult<Self> {
let err = exceptions::PyTypeError::new_err("TextInputSequence must be str");
if let Ok(s) = ob.downcast::<PyString>() {
Ok(Self(s.to_string_lossy().into()))
if let Ok(s) = ob.extract::<String>() {
Ok(Self(s.into()))
} else {
Err(err)
}
Expand All @@ -280,7 +280,7 @@ impl<'s> From<TextInputSequence<'s>> for tk::InputSequence<'s> {

struct PyArrayUnicode(Vec<String>);
impl FromPyObject<'_> for PyArrayUnicode {
fn extract(ob: &PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
// SAFETY Making sure the pointer is a valid numpy array requires calling numpy C code
if unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) } == 0 {
return Err(exceptions::PyTypeError::new_err("Expected an np.array"));
Expand All @@ -291,8 +291,8 @@ impl FromPyObject<'_> for PyArrayUnicode {
let desc = (*arr).descr;
(
(*desc).type_num,
(*desc).elsize as usize,
(*desc).alignment as usize,
npyffi::PyDataType_ELSIZE(ob.py(), desc) as usize,
npyffi::PyDataType_ALIGNMENT(ob.py(), desc) as usize,
(*arr).data,
(*arr).nd,
(*arr).flags,
Expand Down Expand Up @@ -347,7 +347,7 @@ impl From<PyArrayUnicode> for tk::InputSequence<'_> {

struct PyArrayStr(Vec<String>);
impl FromPyObject<'_> for PyArrayStr {
fn extract(ob: &PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
let array = ob.downcast::<PyArray1<PyObject>>()?;
let seq = array
.readonly()
Expand All @@ -370,7 +370,7 @@ impl From<PyArrayStr> for tk::InputSequence<'_> {

struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
fn extract(ob: &'s PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'s, PyAny>) -> PyResult<Self> {
if let Ok(seq) = ob.extract::<PyArrayUnicode>() {
return Ok(Self(seq.into()));
}
Expand Down Expand Up @@ -400,17 +400,17 @@ impl<'s> From<PreTokenizedInputSequence<'s>> for tk::InputSequence<'s> {

struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
impl<'s> FromPyObject<'s> for TextEncodeInput<'s> {
fn extract(ob: &'s PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'s, PyAny>) -> PyResult<Self> {
if let Ok(i) = ob.extract::<TextInputSequence>() {
return Ok(Self(i.into()));
}
if let Ok((i1, i2)) = ob.extract::<(TextInputSequence, TextInputSequence)>() {
return Ok(Self((i1, i2).into()));
}
if let Ok(arr) = ob.extract::<Vec<&PyAny>>() {
if let Ok(arr) = ob.downcast::<PyList>() {
if arr.len() == 2 {
let first = arr[0].extract::<TextInputSequence>()?;
let second = arr[1].extract::<TextInputSequence>()?;
let first = arr.get_item(0)?.extract::<TextInputSequence>()?;
let second = arr.get_item(1)?.extract::<TextInputSequence>()?;
return Ok(Self((first, second).into()));
}
}
Expand All @@ -426,18 +426,18 @@ impl<'s> From<TextEncodeInput<'s>> for tk::tokenizer::EncodeInput<'s> {
}
struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
impl<'s> FromPyObject<'s> for PreTokenizedEncodeInput<'s> {
fn extract(ob: &'s PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'s, PyAny>) -> PyResult<Self> {
if let Ok(i) = ob.extract::<PreTokenizedInputSequence>() {
return Ok(Self(i.into()));
}
if let Ok((i1, i2)) = ob.extract::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>()
{
return Ok(Self((i1, i2).into()));
}
if let Ok(arr) = ob.extract::<Vec<&PyAny>>() {
if let Ok(arr) = ob.downcast::<PyList>() {
if arr.len() == 2 {
let first = arr[0].extract::<PreTokenizedInputSequence>()?;
let second = arr[1].extract::<PreTokenizedInputSequence>()?;
let first = arr.get_item(0)?.extract::<PreTokenizedInputSequence>()?;
let second = arr.get_item(1)?.extract::<PreTokenizedInputSequence>()?;
return Ok(Self((first, second).into()));
}
}
Expand Down Expand Up @@ -498,9 +498,9 @@ impl PyTokenizer {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
match state.extract::<&[u8]>(py) {
Ok(s) => {
self.tokenizer = serde_json::from_slice(s.as_bytes()).map_err(|e| {
self.tokenizer = serde_json::from_slice(s).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to unpickle Tokenizer: {}",
e
Expand Down Expand Up @@ -1030,7 +1030,7 @@ impl PyTokenizer {
fn encode_batch(
&self,
py: Python<'_>,
input: Vec<&PyAny>,
input: Bound<'_, PyList>,
is_pretokenized: bool,
add_special_tokens: bool,
) -> PyResult<Vec<PyEncoding>> {
Expand Down Expand Up @@ -1091,7 +1091,7 @@ impl PyTokenizer {
fn encode_batch_fast(
&self,
py: Python<'_>,
input: Vec<&PyAny>,
input: Bound<'_, PyList>,
is_pretokenized: bool,
add_special_tokens: bool,
) -> PyResult<Vec<PyEncoding>> {
Expand Down
4 changes: 2 additions & 2 deletions bindings/python/src/trainers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ impl PyTrainer {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
match state.extract::<&[u8]>(py) {
Ok(s) => {
let unpickled = serde_json::from_slice(s.as_bytes()).map_err(|e| {
let unpickled = serde_json::from_slice(s).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to unpickle PyTrainer: {}",
e
Expand Down
6 changes: 3 additions & 3 deletions bindings/python/src/utils/normalization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub enum PyRange<'s> {
#[pyo3(annotation = "Tuple[uint, uint]")]
Range(usize, usize),
#[pyo3(annotation = "slice")]
Slice(&'s PySlice),
Slice(Bound<'s, PySlice>),
}
impl PyRange<'_> {
pub fn to_range(&self, max_len: usize) -> PyResult<std::ops::Range<usize>> {
Expand All @@ -83,7 +83,7 @@ impl PyRange<'_> {
}
PyRange::Range(s, e) => Ok(*s..*e),
PyRange::Slice(s) => {
let r = s.indices(max_len as std::os::raw::c_long)?;
let r = s.indices(max_len.try_into()?)?;
Ok(r.start as usize..r.stop as usize)
}
}
Expand All @@ -94,7 +94,7 @@ impl PyRange<'_> {
pub struct PySplitDelimiterBehavior(pub SplitDelimiterBehavior);

impl FromPyObject<'_> for PySplitDelimiterBehavior {
fn extract(obj: &PyAny) -> PyResult<Self> {
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
let s = obj.extract::<&str>()?;

Ok(Self(match s {
Expand Down
6 changes: 3 additions & 3 deletions bindings/python/src/utils/pretokenization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ fn tokenize(pretok: &mut PreTokenizedString, func: &Bound<'_, PyAny>) -> PyResul
ToPyResult(pretok.tokenize(|normalized| {
let output = func.call((normalized.get(),), None)?;
Ok(output
.extract::<&PyList>()?
.extract::<Bound<PyList>>()?
.into_iter()
.map(|obj| Ok(Token::from(obj.extract::<PyToken>()?)))
.collect::<PyResult<Vec<_>>>()?)
Expand All @@ -69,7 +69,7 @@ fn tokenize(pretok: &mut PreTokenizedString, func: &Bound<'_, PyAny>) -> PyResul
#[derive(Clone)]
pub struct PyOffsetReferential(OffsetReferential);
impl FromPyObject<'_> for PyOffsetReferential {
fn extract(obj: &PyAny) -> PyResult<Self> {
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
let s = obj.extract::<&str>()?;

Ok(Self(match s {
Expand All @@ -85,7 +85,7 @@ impl FromPyObject<'_> for PyOffsetReferential {
#[derive(Clone)]
pub struct PyOffsetType(OffsetType);
impl FromPyObject<'_> for PyOffsetType {
fn extract(obj: &PyAny) -> PyResult<Self> {
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
let s = obj.extract::<&str>()?;

Ok(Self(match s {
Expand Down

0 comments on commit 6ade8c2

Please sign in to comment.