diff --git a/.gitignore b/.gitignore index 2748d37e0b..46268102ed 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,6 @@ candle-wasm-examples/*/package-lock.json .DS_Store .idea/* + +# Python virtual environment +.venv/ diff --git a/candle-pyo3/.gitignore b/candle-pyo3/.gitignore new file mode 100644 index 0000000000..49cac2118a --- /dev/null +++ b/candle-pyo3/.gitignore @@ -0,0 +1,161 @@ +# Default python .gitignore file: https://github.com/github/gitignore/blob/main/Python.gitignore +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/candle-pyo3/candle/__init__.py b/candle-pyo3/candle/__init__.py new file mode 100644 index 0000000000..b67e83b91c --- /dev/null +++ b/candle-pyo3/candle/__init__.py @@ -0,0 +1,5 @@ +from .candle import * + +__doc__ = candle.__doc__ +if hasattr(candle, "__all__"): + __all__ = candle.__all__ \ No newline at end of file diff --git a/candle-pyo3/candle/candle.pyi b/candle-pyo3/candle/candle.pyi new file mode 100644 index 0000000000..a325edd774 --- /dev/null +++ b/candle-pyo3/candle/candle.pyi @@ -0,0 +1,111 @@ +from typing import Any, List, Tuple, Union +from enum import Enum, auto + +class Device(Enum): + """ + Backend device for a tensor. + """ + Cpu = auto(), + Cuda = auto(), + +class DType: + """ + The DType of a tensor. + """ + ... + +class QTensor: + """ + Represents a quantized `candle` Tensor. + """ + def dequantize(self) -> Tensor: + ... + +class Tensor: + """ + Represents a internal `candle` Tensor. + """ + def __init__(self, data: Any) -> None: + """ + Create a Tensor from given data. + """ + + @property + def shape(self) -> List[int]: + """ + Returns the shape of the Tensor. + """ + + @property + def rank(self) -> int: + """ + Returns the rank of the Tensor. + """ + + @property + def device(self) -> Device: + """ + Returns the device of the Tensor. + """ + + @property + def dtype(self) -> DType: + """ + Returns the dtype of the Tensor. + """ + + def values(self) -> Any: + """ + Return the values of the Tensor as a python object. + """ + + def reshape(self, shape: List[int])-> Tensor: + """ + Reshape the Tensor. + """ + + def t(self)-> Tensor: + """ + Transpose the Tensor. + """ + + def matmul(self, other: Tensor) -> Tensor: + """ + Matrix multiplication. + """ + + def to_dtype(self, dtype: Union[DType, str]) -> Tensor: + """ + Convert the Tensor to a different dtype. + """ + + def quantize(self, qtype: str) -> QTensor: + """ + Quantize the Tensor. + """ + + def __add__(self, other: Tensor) -> Tensor: + """ + Add two Tensors. + """ + + def __sub__(self, other: Tensor) -> Tensor: + """ + Subtract two Tensors. + """ + + def sqr(self) -> Tensor: + """ + Square the Tensor. + """ + + def mean_all(self) -> Tensor: + """ + Mean value of the Tensor. + """ + + +def randn(shape:Tuple[int])->Tensor: + """ + Create a Tensor with random values. + """ \ No newline at end of file diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index eddc0fdacf..5a70fe25a3 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -72,6 +72,7 @@ impl PyDType { static CUDA_DEVICE: std::sync::Mutex> = std::sync::Mutex::new(None); #[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[pyclass(name = "Device")] enum PyDevice { Cpu, Cuda, @@ -101,28 +102,6 @@ impl PyDevice { } } -impl<'source> FromPyObject<'source> for PyDevice { - fn extract(ob: &'source PyAny) -> PyResult { - let device: &str = ob.extract()?; - let device = match device { - "cpu" => PyDevice::Cpu, - "cuda" => PyDevice::Cuda, - _ => Err(PyTypeError::new_err(format!("invalid device '{device}'")))?, - }; - Ok(device) - } -} - -impl ToPyObject for PyDevice { - fn to_object(&self, py: Python<'_>) -> PyObject { - let str = match self { - PyDevice::Cpu => "cpu", - PyDevice::Cuda => "cuda", - }; - str.to_object(py) - } -} - trait PyWithDType: WithDType { fn to_py(&self, py: Python<'_>) -> PyObject; } @@ -295,8 +274,8 @@ impl PyTensor { } #[getter] - fn device(&self, py: Python<'_>) -> PyObject { - PyDevice::from_device(self.0.device()).to_object(py) + fn device(&self, _py: Python<'_>) -> PyDevice { + PyDevice::from_device(self.0.device()) } #[getter] @@ -863,6 +842,7 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add("u8", PyDType(DType::U8))?; m.add("u32", PyDType(DType::U32))?; m.add("i16", PyDType(DType::I64))?; diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index 7f24b49d7e..f7d26eb697 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -1,17 +1,22 @@ import candle +from candle import Tensor, Device -t = candle.Tensor(42.0) +t = Tensor(42.0) print(t) print(t.shape, t.rank, t.device) print(t + t) +print(t.device) -t = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6]) +t = Tensor([3.0, 1, 4, 1, 5, 9, 2, 6]) print(t) print(t+t) t = t.reshape([2, 4]) print(t.matmul(t.t())) +device = Device.Cuda +print(device) + print(t.to_dtype(candle.u8)) print(t.to_dtype("u8"))