Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing boolean + numpy > 1.20 #326

Merged
merged 2 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 35 additions & 15 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,8 @@ impl Open {
Storage::TorchStorage(storage) => {
Python::with_gil(|py| -> PyResult<PyObject> {
let torch = get_module(py, &TORCH_MODULE)?;
let dtype: PyObject = get_pydtype(torch, info.dtype)?;
let torch_uint8: PyObject = get_pydtype(torch, Dtype::U8)?;
let dtype: PyObject = get_pydtype(torch, info.dtype, false)?;
let torch_uint8: PyObject = get_pydtype(torch, Dtype::U8, false)?;
let kwargs = [(intern!(py, "dtype"), torch_uint8)].into_py_dict(py);
let view_kwargs = [(intern!(py, "dtype"), dtype)].into_py_dict(py);
let shape = info.shape.to_vec();
Expand Down Expand Up @@ -504,7 +504,7 @@ impl Open {
let inplace_kwargs =
[(intern!(py, "inplace"), false.into_py(py))].into_py_dict(py);
if info.dtype == Dtype::BF16 {
let torch_f16: PyObject = get_pydtype(torch, Dtype::F16)?;
let torch_f16: PyObject = get_pydtype(torch, Dtype::F16, false)?;
tensor = tensor.getattr(intern!(py, "to"))?.call(
(),
Some([(intern!(py, "dtype"), torch_f16)].into_py_dict(py)),
Expand All @@ -519,7 +519,7 @@ impl Open {
tensor = torch.getattr(intern!(py, "from_numpy"))?.call1((numpy,))?;

if info.dtype == Dtype::BF16 {
let torch_bf16: PyObject = get_pydtype(torch, Dtype::BF16)?;
let torch_bf16: PyObject = get_pydtype(torch, Dtype::BF16, false)?;
tensor = tensor.getattr(intern!(py, "to"))?.call(
(),
Some([(intern!(py, "dtype"), torch_bf16)].into_py_dict(py)),
Expand Down Expand Up @@ -796,8 +796,8 @@ impl PySafeSlice {
}
Storage::TorchStorage(storage) => Python::with_gil(|py| -> PyResult<PyObject> {
let torch = get_module(py, &TORCH_MODULE)?;
let dtype: PyObject = get_pydtype(torch, self.info.dtype)?;
let torch_uint8: PyObject = get_pydtype(torch, Dtype::U8)?;
let dtype: PyObject = get_pydtype(torch, self.info.dtype, false)?;
let torch_uint8: PyObject = get_pydtype(torch, Dtype::U8, false)?;
let kwargs = [(intern!(py, "dtype"), torch_uint8)].into_py_dict(py);
let view_kwargs = [(intern!(py, "dtype"), dtype)].into_py_dict(py);
let shape = self.info.shape.to_vec();
Expand Down Expand Up @@ -873,13 +873,27 @@ fn create_tensor(
device: &Device,
) -> PyResult<PyObject> {
Python::with_gil(|py| -> PyResult<PyObject> {
let module: &PyModule = match framework {
Framework::Pytorch => TORCH_MODULE.get(py),
_ => NUMPY_MODULE.get(py),
}
.ok_or_else(|| SafetensorError::new_err(format!("Could not find module {framework:?}",)))?
.as_ref(py);
let dtype: PyObject = get_pydtype(module, dtype)?;
let (module, is_numpy): (&PyModule, bool) = match framework {
Framework::Pytorch => (
TORCH_MODULE
.get(py)
.ok_or_else(|| {
SafetensorError::new_err(format!("Could not find module {framework:?}",))
})?
.as_ref(py),
false,
),
_ => (
NUMPY_MODULE
.get(py)
.ok_or_else(|| {
SafetensorError::new_err(format!("Could not find module {framework:?}",))
})?
.as_ref(py),
true,
),
};
let dtype: PyObject = get_pydtype(module, dtype, is_numpy)?;
let count: usize = shape.iter().product();
let shape = shape.to_vec();
let shape: PyObject = shape.into_py(py);
Expand Down Expand Up @@ -939,7 +953,7 @@ fn create_tensor(
})
}

fn get_pydtype(module: &PyModule, dtype: Dtype) -> PyResult<PyObject> {
fn get_pydtype(module: &PyModule, dtype: Dtype, is_numpy: bool) -> PyResult<PyObject> {
Python::with_gil(|py| {
let dtype: PyObject = match dtype {
Dtype::F64 => module.getattr(intern!(py, "float64"))?.into(),
Expand All @@ -954,7 +968,13 @@ fn get_pydtype(module: &PyModule, dtype: Dtype) -> PyResult<PyObject> {
Dtype::I16 => module.getattr(intern!(py, "int16"))?.into(),
Dtype::U8 => module.getattr(intern!(py, "uint8"))?.into(),
Dtype::I8 => module.getattr(intern!(py, "int8"))?.into(),
Dtype::BOOL => module.getattr(intern!(py, "bool"))?.into(),
Dtype::BOOL => {
if is_numpy {
py.import("builtins")?.getattr(intern!(py, "bool"))?.into()
} else {
module.getattr(intern!(py, "bool"))?.into()
}
}
dtype => {
return Err(SafetensorError::new_err(format!(
"Dtype not understood: {dtype:?}"
Expand Down
13 changes: 13 additions & 0 deletions bindings/python/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,19 @@ def test_numpy_example(self):
loaded = load(out)
self.assertTensorEqual(tensors, loaded, np.allclose)

def test_numpy_bool(self):
tensors = {"a": np.asarray(False)}

save_file(tensors, "./out_bool.safetensors")
out = save(tensors)

# Now loading
loaded = load_file("./out_bool.safetensors")
self.assertTensorEqual(tensors, loaded, np.allclose)

loaded = load(out)
self.assertTensorEqual(tensors, loaded, np.allclose)

def test_torch_example(self):
tensors = {
"a": torch.zeros((2, 2)),
Expand Down
Loading