diff --git a/example/buffers.zig b/example/buffers.zig new file mode 100644 index 00000000..bc7943d3 --- /dev/null +++ b/example/buffers.zig @@ -0,0 +1,48 @@ +const std = @import("std"); +const py = @import("pydust"); + +pub const ConstantBuffer = py.class("ConstantBuffer", struct { + pub const __doc__ = "A class implementing a buffer protocol"; + const Self = @This(); + + values: []i64, + pylength: isize, // isize to be compatible with Python API + format: [:0]const u8 = "l", // i64 + + pub fn __init__(self: *Self, args: *const extern struct { elem: py.PyLong, size: py.PyLong }) !void { + self.values = try py.allocator.alloc(i64, try args.size.as(u64)); + @memset(self.values, try args.elem.as(i64)); + self.pylength = @intCast(self.values.len); + } + + pub fn __buffer__(self: *const Self, view: *py.PyBuffer, flags: c_int) !void { + // For more details on request types, see https://docs.python.org/3/c-api/buffer.html#buffer-request-types + if (flags & py.PyBuffer.Flags.WRITABLE != 0) { + return py.BufferError.raise("request for writable buffer is rejected"); + } + const pyObj = try py.self(@constCast(self)); + view.initFromSlice(i64, self.values, @ptrCast(&self.pylength), pyObj); + } + + pub fn __release_buffer__(self: *const Self, view: *py.PyBuffer) void { + py.allocator.free(self.values); + // It might be necessary to clear the view here in case the __bufferr__ method allocates view properties. + _ = view; + } +}); + +// A function that accepts an object implementing the buffer protocol. +pub fn sum(args: *const extern struct { buf: py.PyObject }) !i64 { + var view: py.PyBuffer = undefined; + // ND is required by asSlice. + try args.buf.getBuffer(&view, py.PyBuffer.Flags.ND); + defer view.release(); + + var bufferSum: i64 = 0; + for (view.asSlice(i64)) |value| bufferSum += value; + return bufferSum; +} + +comptime { + py.module(@This()); +} diff --git a/pydust/src/errors.zig b/pydust/src/errors.zig index 277414ce..bcb2fc83 100644 --- a/pydust/src/errors.zig +++ b/pydust/src/errors.zig @@ -1,8 +1,8 @@ -const std = @import("std"); +const Allocator = @import("std").mem.Allocator; pub const PyError = error{ // Propagate an error raised from another Python function call. // This is the equivalent of returning PyNULL and allowing the already set error info to remain. Propagate, Raised, -} || std.mem.Allocator.Error; +} || Allocator.Error; diff --git a/pydust/src/functions.zig b/pydust/src/functions.zig index 3616adef..8f630c55 100644 --- a/pydust/src/functions.zig +++ b/pydust/src/functions.zig @@ -20,6 +20,8 @@ const reservedNames = .{ "__init__", "__len__", "__del__", + "__buffer__", + "__release_buffer__", }; /// Parse the arguments of a Zig function into a Pydust function siganture. diff --git a/pydust/src/pytypes.zig b/pydust/src/pytypes.zig index f1373a11..94287cfb 100644 --- a/pydust/src/pytypes.zig +++ b/pydust/src/pytypes.zig @@ -98,6 +98,20 @@ fn Slots(comptime name: [:0]const u8, comptime definition: type, comptime Instan }}; } + if (@hasDecl(definition, "__buffer__")) { + slots_ = slots_ ++ .{ffi.PyType_Slot{ + .slot = ffi.Py_bf_getbuffer, + .pfunc = @ptrCast(@constCast(&bf_getbuffer)), + }}; + } + + if (@hasDecl(definition, "__release_buffer__")) { + slots_ = slots_ ++ .{ffi.PyType_Slot{ + .slot = ffi.Py_bf_releasebuffer, + .pfunc = @ptrCast(@constCast(&bf_releasebuffer)), + }}; + } + slots_ = slots_ ++ .{ffi.PyType_Slot{ .slot = ffi.Py_tp_methods, .pfunc = @ptrCast(@constCast(&methods.pydefs)), @@ -136,6 +150,19 @@ fn Slots(comptime name: [:0]const u8, comptime definition: type, comptime Instan ffi.PyErr_Restore(error_type, error_value, error_tb); } + + fn bf_getbuffer(self: *ffi.PyObject, view: *ffi.Py_buffer, flags: c_int) callconv(.C) c_int { + // In case of any error, the view.obj field must be set to NULL. + view.obj = null; + + const instance: *Instance = @ptrCast(self); + return tramp.errVoid(definition.__buffer__(&instance.state, @ptrCast(view), flags)); + } + + fn bf_releasebuffer(self: *ffi.PyObject, view: *ffi.Py_buffer) callconv(.C) void { + const instance: *Instance = @ptrCast(self); + return definition.__release_buffer__(&instance.state, @ptrCast(view)); + } }; } diff --git a/pydust/src/types.zig b/pydust/src/types.zig index 7384202f..96f349dd 100644 --- a/pydust/src/types.zig +++ b/pydust/src/types.zig @@ -1,4 +1,5 @@ pub usingnamespace @import("types/bool.zig"); +pub usingnamespace @import("types/buffer.zig"); pub usingnamespace @import("types/dict.zig"); pub usingnamespace @import("types/error.zig"); pub usingnamespace @import("types/float.zig"); diff --git a/pydust/src/types/buffer.zig b/pydust/src/types/buffer.zig new file mode 100644 index 00000000..d0673d27 --- /dev/null +++ b/pydust/src/types/buffer.zig @@ -0,0 +1,120 @@ +const std = @import("std"); +const py = @import("../pydust.zig"); +const ffi = py.ffi; +const PyError = @import("../errors.zig").PyError; + +/// Wrapper for Python Py_buffer. +/// See: https://docs.python.org/3/c-api/buffer.html +pub const PyBuffer = extern struct { + const Self = @This(); + + pub const Flags = struct { + pub const SIMPLE: c_int = 0; + pub const WRITABLE: c_int = 0x0001; + pub const FORMAT: c_int = 0x0004; + pub const ND: c_int = 0x0008; + pub const STRIDES: c_int = 0x0010 | ND; + pub const C_CONTIGUOUS: c_int = 0x0020 | STRIDES; + pub const F_CONTIGUOUS: c_int = 0x0040 | STRIDES; + pub const ANY_CONTIGUOUS: c_int = 0x0080 | STRIDES; + pub const INDIRECT: c_int = 0x0100 | STRIDES; + pub const CONTIG: c_int = STRIDES | WRITABLE; + pub const CONTIG_RO: c_int = ND; + pub const STRIDED: c_int = STRIDES | WRITABLE; + pub const STRIDED_RO: c_int = STRIDES; + pub const RECORDS: c_int = STRIDES | FORMAT | WRITABLE; + pub const RECORDS_RO: c_int = STRIDES | FORMAT; + pub const FULL: c_int = STRIDES | FORMAT | WRITABLE | ND; + pub const FULL_RO: c_int = STRIDES | FORMAT | ND; + }; + + buf: ?[*]u8, + + // Use pyObj to get the PyObject. + // This must be an optional pointer so we can set null value. + obj: ?*ffi.PyObject, + + // product(shape) * itemsize. + // For contiguous arrays, this is the length of the underlying memory block. + // For non-contiguous arrays, it is the length that the logical structure would + // have if it were copied to a contiguous representation. + len: isize, + itemsize: isize, + readonly: c_int, + + // If ndim == 0, the memory location pointed to by buf is interpreted as a scalar of size itemsize. + // In that case, both shape and strides are NULL. + ndim: c_int, + format: [*:0]const u8, + + shape: ?[*]const isize = null, + // If strides is NULL, the array is interpreted as a standard n-dimensional C-array. + // Otherwise, the consumer must access an n-dimensional array as follows: + // ptr = (char *)buf + indices[0] * strides[0] + ... + indices[n-1] * strides[n-1]; + strides: ?[*]isize = null, + // If all suboffsets are negative (i.e. no de-referencing is needed), + // then this field must be NULL (the default value). + suboffsets: ?[*]isize = null, + internal: ?*anyopaque = null, + + pub fn release(self: *Self) void { + ffi.PyBuffer_Release(@ptrCast(self)); + } + + pub fn pyObj(self: *Self) py.PyObject { + return .{ .py = self.obj orelse unreachable }; + } + + pub fn initFromSlice(self: *Self, comptime value_type: type, values: []value_type, shape: [*]const isize, obj: py.PyObject) void { + self.* = .{ + .buf = std.mem.sliceAsBytes(values).ptr, + .obj = obj.py, + .len = @intCast(values.len * @sizeOf(value_type)), + .itemsize = @sizeOf(value_type), + .readonly = 1, + .ndim = 1, + .format = getFormat(value_type).ptr, + .shape = shape, + }; + // We need to incref the self object because it's being used by the view. + obj.incref(); + } + + // asSlice returns buf property as Zig slice. The view must have been created with ND flag. + pub fn asSlice(self: *const Self, comptime value_type: type) []value_type { + return @alignCast(std.mem.bytesAsSlice(value_type, self.buf.?[0..@intCast(self.len)])); + } + + fn getFormat(comptime value_type: type) [:0]const u8 { + switch (@typeInfo(value_type)) { + .Int => |i| { + switch (i.signedness) { + .unsigned => switch (i.bits) { + 8 => return "B", + 16 => return "H", + 32 => return "I", + 64 => return "L", + else => {}, + }, + .signed => switch (i.bits) { + 8 => return "b", + 16 => return "h", + 32 => return "i", + 64 => return "l", + else => {}, + }, + } + }, + .Float => |f| { + switch (f.bits) { + 32 => return "f", + 64 => return "d", + else => {}, + } + }, + else => {}, + } + + @compileError("Unsupported buffer value type" ++ @typeName(value_type)); + } +}; diff --git a/pydust/src/types/obj.zig b/pydust/src/types/obj.zig index ee522d2b..4ece4a41 100644 --- a/pydust/src/types/obj.zig +++ b/pydust/src/types/obj.zig @@ -54,6 +54,17 @@ pub const PyObject = extern struct { return .{ .py = ffi.PyObject_GetAttrString(self.py, attr) orelse return PyError.Propagate }; } + // See: https://docs.python.org/3/c-api/buffer.html#buffer-request-types + pub fn getBuffer(self: py.PyObject, out: *py.PyBuffer, flags: c_int) !void { + if (ffi.PyObject_CheckBuffer(self.py) != 1) { + return py.BufferError.raise("object does not support buffer interface"); + } + if (ffi.PyObject_GetBuffer(self.py, @ptrCast(out), flags) != 0) { + // Error is already raised. + return PyError.Propagate; + } + } + pub fn set(self: PyObject, attr: [:0]const u8, value: PyObject) !PyObject { if (ffi.PyObject_SetAttrString(self.py, attr, value.py) < 0) { return PyError.Propagate; diff --git a/pyproject.toml b/pyproject.toml index 1606e918..0f1d2cad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,3 +81,7 @@ root = "example/functions.zig" [[tool.pydust.ext_module]] name = "example.classes" root = "example/classes.zig" + +[[tool.pydust.ext_module]] +name = "example.buffers" +root = "example/buffers.zig" diff --git a/test/test_buffers.py b/test/test_buffers.py new file mode 100644 index 00000000..ba0444fa --- /dev/null +++ b/test/test_buffers.py @@ -0,0 +1,17 @@ +from array import array + +from example import buffers + + +def test_view(): + buffer = buffers.ConstantBuffer(1, 10) + view = memoryview(buffer) + for i in range(10): + assert view[i] == 1 + view.release() + + +def test_sum(): + # array implements a buffer protocol + arr = array("l", [1, 2, 3, 4, 5]) + assert buffers.sum(arr) == 15