Skip to content

Commit

Permalink
Merge pull request #103 from alexander-g/numpy
Browse files Browse the repository at this point in the history
Added numpy() method
  • Loading branch information
axsaucedo authored Dec 27, 2020
2 parents e8b536c + 695fb08 commit 2e717ad
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 0 deletions.
15 changes: 15 additions & 0 deletions python/src/main.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>

#include <kompute/Kompute.hpp>

Expand Down Expand Up @@ -39,6 +40,20 @@ PYBIND11_MODULE(kp, m) {
return std::unique_ptr<kp::Tensor>(new kp::Tensor(data, tensorTypes));
}), "Initialiser with list of data components and tensor GPU memory type.")
.def("data", &kp::Tensor::data, DOC(kp, Tensor, data))
.def("numpy", [](kp::Tensor& self){
ssize_t ndim = 1;
std::vector<ssize_t> shape = { self.size() };
std::vector<ssize_t> strides = { sizeof(float) };

return py::array(py::buffer_info(
self.data().data(),
sizeof(float),
py::format_descriptor<float>::format(),
ndim,
shape,
strides
));
}, "Returns stored data as a new numpy array.")
.def("__getitem__", [](kp::Tensor &self, size_t index) -> float { return self.data()[index]; },
"When only an index is necessary")
.def("__setitem__", [](kp::Tensor &self, size_t index, float value) {
Expand Down
1 change: 1 addition & 0 deletions python/test/requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pyshader==0.7.0
numpy
2 changes: 2 additions & 0 deletions python/test/test_array_multiplication.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pyshader as ps
import kp
import numpy as np


def test_array_multiplication():
Expand Down Expand Up @@ -33,3 +34,4 @@ def compute_shader_multiply(index=("input", "GlobalInvocationId", ps.ivec3),
mgr.eval_tensor_sync_local_def([tensor_out])

assert tensor_out.data() == [2.0, 4.0, 6.0]
assert np.all(tensor_out.numpy() == [2.0, 4.0, 6.0])
4 changes: 4 additions & 0 deletions python/test/test_kompute.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import kp
import numpy as np

DIRNAME = os.path.dirname(os.path.abspath(__file__))

Expand All @@ -22,6 +23,7 @@ def test_opmult():
mgr.eval_tensor_sync_local_def([tensor_out])

assert tensor_out.data() == [2.0, 4.0, 6.0]
assert np.all(tensor_out.numpy() == [2.0, 4.0, 6.0])

def test_opalgobase_data():
"""
Expand Down Expand Up @@ -57,6 +59,7 @@ def test_opalgobase_data():
mgr.eval_tensor_sync_local_def([tensor_out])

assert tensor_out.data() == [2.0, 4.0, 6.0]
assert np.all(tensor_out.numpy() == [2.0, 4.0, 6.0])


def test_opalgobase_file():
Expand Down Expand Up @@ -106,3 +109,4 @@ def test_sequence():
seq.eval()

assert tensor_out.data() == [2.0, 4.0, 6.0]
assert np.all(tensor_out.numpy() == [2.0, 4.0, 6.0])

0 comments on commit 2e717ad

Please sign in to comment.