diff --git a/example/iterators.zig b/example/iterators.zig new file mode 100644 index 00000000..c0e1eb01 --- /dev/null +++ b/example/iterators.zig @@ -0,0 +1,46 @@ +const std = @import("std"); +const py = @import("pydust"); + +pub const Range = py.class("Range", struct { + pub const __doc__ = "An example of iterable class"; + + const Self = @This(); + + lower: i64, + upper: i64, + step: i64, + + pub fn __new__(args: struct { lower: i64, upper: i64, step: i64 }) !Self { + return .{ .lower = args.lower, .upper = args.upper, .step = args.step }; + } + + pub fn __iter__(self: *const Self) !*RangeIterator { + return try py.init(RangeIterator, .{ .next = self.lower, .stop = self.upper, .step = self.step }); + } +}); + +pub const RangeIterator = py.class("RangeIterator", struct { + pub const __doc__ = "Range iterator"; + + const Self = @This(); + + next: i64, + stop: i64, + step: i64, + + pub fn __new__(args: struct { next: i64, stop: i64, step: i64 }) !Self { + return .{ .next = args.next, .stop = args.stop, .step = args.step }; + } + + pub fn __next__(self: *Self) !?i64 { + if (self.next >= self.stop) { + return null; + } + defer self.next += self.step; + return self.next; + } +}); + +comptime { + py.module(@This()); +} diff --git a/pydust/src/functions.zig b/pydust/src/functions.zig index 812069f9..aa6bb8eb 100644 --- a/pydust/src/functions.zig +++ b/pydust/src/functions.zig @@ -23,6 +23,8 @@ const reservedNames = .{ "__del__", "__buffer__", "__release_buffer__", + "__iter__", + "__next__", }; /// Parse the arguments of a Zig function into a Pydust function siganture. diff --git a/pydust/src/pytypes.zig b/pydust/src/pytypes.zig index 5cda31cd..15d5534d 100644 --- a/pydust/src/pytypes.zig +++ b/pydust/src/pytypes.zig @@ -110,6 +110,20 @@ fn Slots(comptime definition: type, comptime Instance: type) type { }}; } + if (@hasDecl(definition, "__iter__")) { + slots_ = slots_ ++ .{ffi.PyType_Slot{ + .slot = ffi.Py_tp_iter, + .pfunc = @ptrCast(@constCast(&tp_iter)), + }}; + } + + if (@hasDecl(definition, "__next__")) { + slots_ = slots_ ++ .{ffi.PyType_Slot{ + .slot = ffi.Py_tp_iternext, + .pfunc = @ptrCast(@constCast(&tp_iternext)), + }}; + } + slots_ = slots_ ++ .{ffi.PyType_Slot{ .slot = ffi.Py_tp_methods, .pfunc = @ptrCast(@constCast(&methods.pydefs)), @@ -179,7 +193,7 @@ fn Slots(comptime definition: type, comptime Instance: type) type { } fn mp_length(pyself: *ffi.PyObject) callconv(.C) isize { - const lenFunc = @field(definition, "__len__"); + const lenFunc = definition.__len__; const self: *const Instance = @ptrCast(pyself); const result = @as(isize, @intCast(lenFunc(&self.state))); if (@typeInfo(@typeInfo(@TypeOf(lenFunc)).Fn.return_type.?) == .ErrorUnion) { @@ -187,6 +201,24 @@ fn Slots(comptime definition: type, comptime Instance: type) type { } return result; } + + fn tp_iter(pyself: *ffi.PyObject) callconv(.C) ?*ffi.PyObject { + const iterFunc = definition.__iter__; + const trampoline = tramp.Trampoline(@typeInfo(@TypeOf(iterFunc)).Fn.return_type.?); + const self: *const Instance = @ptrCast(pyself); + const result = iterFunc(&self.state) catch return null; + const obj = trampoline.wrap(result) catch return null; + return obj.py; + } + + fn tp_iternext(pyself: *ffi.PyObject) callconv(.C) ?*ffi.PyObject { + const iterFunc = definition.__next__; + const trampoline = tramp.Trampoline(@typeInfo(@TypeOf(iterFunc)).Fn.return_type.?); + const self: *Instance = @constCast(@ptrCast(pyself)); + const result = iterFunc(&self.state) catch return null; + const obj = trampoline.wrap(result orelse return null) catch return null; + return obj.py; + } }; } diff --git a/pyproject.toml b/pyproject.toml index 0f1d2cad..ca6344ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,6 @@ line-length = 120 [tool.ruff] line-length = 120 select = ["F", "E", "W", "UP", "I001", "I002"] -target-version = "py310" [build-system] requires = ["poetry-core"] @@ -85,3 +84,7 @@ root = "example/classes.zig" [[tool.pydust.ext_module]] name = "example.buffers" root = "example/buffers.zig" + +[[tool.pydust.ext_module]] +name = "example.iterators" +root = "example/iterators.zig" diff --git a/test/test_iterator.py b/test/test_iterator.py new file mode 100644 index 00000000..2e4f7fcf --- /dev/null +++ b/test/test_iterator.py @@ -0,0 +1,11 @@ +import pytest + +from example import iterators + + +def test_range_iterator(): + range_iterator = iter(iterators.Range(0, 10, 1)) + for i in range(10): + assert next(range_iterator) == i + with pytest.raises(StopIteration): + next(range_iterator)