From 677d9dbacfe970201c75825de0990e60ea2a543c Mon Sep 17 00:00:00 2001 From: Marko Bakovic Date: Thu, 7 Sep 2023 12:02:29 +0200 Subject: [PATCH 1/6] iterator protocol --- example/iterators.zig | 46 ++++++++++++++++++++++++++++++++++++++++ pydust/src/functions.zig | 2 ++ pydust/src/pytypes.zig | 37 ++++++++++++++++++++++++++++++++ pyproject.toml | 4 ++++ test/test_iterator.py | 11 ++++++++++ 5 files changed, 100 insertions(+) create mode 100644 example/iterators.zig create mode 100644 test/test_iterator.py diff --git a/example/iterators.zig b/example/iterators.zig new file mode 100644 index 00000000..af37cfbb --- /dev/null +++ b/example/iterators.zig @@ -0,0 +1,46 @@ +const std = @import("std"); +const py = @import("pydust"); + +pub const Range = py.class("Range", struct { + pub const __doc__ = "An example of iterable class"; + + const Self = @This(); + + lower: i64, + upper: i64, + step: i64, + + pub fn __new__(args: struct { lower: i64, upper: i64, step: i64 }) !Self { + return .{ .lower = args.lower, .upper = args.upper, .step = args.step }; + } + + pub fn __iter__(self: *const Self) !*RangeIterator { + return try py.init(RangeIterator, .{ .next = self.lower, .stop = self.upper, .step = self.step }); + } +}); + +pub const RangeIterator = py.class("Iterable", struct { + pub const __doc__ = "Range iterator"; + + const Self = @This(); + + next: i64, + stop: i64, + step: i64, + + pub fn __new__(args: struct { next: i64, stop: i64, step: i64 }) !Self { + return .{ .next = args.next, .stop = args.stop, .step = args.step }; + } + + pub fn __next__(self: *Self) !?py.PyLong { + if (self.next >= self.stop) { + return null; + } + defer self.next += self.step; + return try py.PyLong.from(i64, self.next); + } +}); + +comptime { + py.module(@This()); +} diff --git a/pydust/src/functions.zig b/pydust/src/functions.zig index 812069f9..aa6bb8eb 100644 --- a/pydust/src/functions.zig +++ b/pydust/src/functions.zig @@ -23,6 +23,8 @@ const reservedNames = .{ "__del__", "__buffer__", "__release_buffer__", + "__iter__", + "__next__", }; /// Parse the arguments of a Zig function into a Pydust function siganture. diff --git a/pydust/src/pytypes.zig b/pydust/src/pytypes.zig index 5cda31cd..9bb2eb5b 100644 --- a/pydust/src/pytypes.zig +++ b/pydust/src/pytypes.zig @@ -110,6 +110,20 @@ fn Slots(comptime definition: type, comptime Instance: type) type { }}; } + if (@hasDecl(definition, "__iter__")) { + slots_ = slots_ ++ .{ffi.PyType_Slot{ + .slot = ffi.Py_tp_iter, + .pfunc = @ptrCast(@constCast(&tp_iter)), + }}; + } + + if (@hasDecl(definition, "__next__")) { + slots_ = slots_ ++ .{ffi.PyType_Slot{ + .slot = ffi.Py_tp_iternext, + .pfunc = @ptrCast(@constCast(&tp_iternext)), + }}; + } + slots_ = slots_ ++ .{ffi.PyType_Slot{ .slot = ffi.Py_tp_methods, .pfunc = @ptrCast(@constCast(&methods.pydefs)), @@ -187,6 +201,29 @@ fn Slots(comptime definition: type, comptime Instance: type) type { } return result; } + + fn tp_iter(pyself: *ffi.PyObject) callconv(.C) ?*ffi.PyObject { + const iterFunc = @field(definition, "__iter__"); + const self: *const Instance = @ptrCast(pyself); + const result = iterFunc(&self.state); + const returnType = @typeInfo(@TypeOf(iterFunc)).Fn.return_type.?; + if (@typeInfo(returnType) == .ErrorUnion) { + return (tramp.Trampoline(returnType).wrap(result catch return null) catch return null).py; + } + return (tramp.Trampoline(returnType).wrap(result) catch return null).py; + } + + fn tp_iternext(pyself: *ffi.PyObject) callconv(.C) ?*ffi.PyObject { + const iterFunc = @field(definition, "__next__"); + var self: *Instance = @constCast(@ptrCast(pyself)); + const result = iterFunc(@constCast(&self.state)); + const returnType = @typeInfo(@TypeOf(iterFunc)).Fn.return_type.?; + if (@typeInfo(returnType) == .ErrorUnion) { + const optional_result = result catch return null; + return (tramp.Trampoline(returnType).wrap(optional_result orelse return null) catch return null).py; + } + return (tramp.Trampoline(returnType).wrap(result orelse return null) catch return null).py; + } }; } diff --git a/pyproject.toml b/pyproject.toml index 0f1d2cad..e0120216 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,3 +85,7 @@ root = "example/classes.zig" [[tool.pydust.ext_module]] name = "example.buffers" root = "example/buffers.zig" + +[[tool.pydust.ext_module]] +name = "example.iterators" +root = "example/iterators.zig" diff --git a/test/test_iterator.py b/test/test_iterator.py new file mode 100644 index 00000000..11a42328 --- /dev/null +++ b/test/test_iterator.py @@ -0,0 +1,11 @@ +import pytest + +from example import iterators + + +def test_range_iterator(): + range_iterator = iter(iterators.Range(0, 10, 1)) + for i in range(10): + assert next(range_iterator) == i + with pytest.raises(StopIteration) as exc: + next(range_iterator) From 25d05e29ded5738d9f1de3a6655566cdbaf205c2 Mon Sep 17 00:00:00 2001 From: Marko Bakovic Date: Thu, 7 Sep 2023 12:05:51 +0200 Subject: [PATCH 2/6] improvements --- pydust/src/pytypes.zig | 13 +++++++++---- test/test_iterator.py | 2 +- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/pydust/src/pytypes.zig b/pydust/src/pytypes.zig index 9bb2eb5b..acd68746 100644 --- a/pydust/src/pytypes.zig +++ b/pydust/src/pytypes.zig @@ -207,10 +207,12 @@ fn Slots(comptime definition: type, comptime Instance: type) type { const self: *const Instance = @ptrCast(pyself); const result = iterFunc(&self.state); const returnType = @typeInfo(@TypeOf(iterFunc)).Fn.return_type.?; + const trampoline = tramp.Trampoline(returnType); if (@typeInfo(returnType) == .ErrorUnion) { - return (tramp.Trampoline(returnType).wrap(result catch return null) catch return null).py; + const non_optional_result = result catch return null; + return (trampoline.wrap(non_optional_result) catch return null).py; } - return (tramp.Trampoline(returnType).wrap(result) catch return null).py; + return (trampoline.wrap(result) catch return null).py; } fn tp_iternext(pyself: *ffi.PyObject) callconv(.C) ?*ffi.PyObject { @@ -218,11 +220,14 @@ fn Slots(comptime definition: type, comptime Instance: type) type { var self: *Instance = @constCast(@ptrCast(pyself)); const result = iterFunc(@constCast(&self.state)); const returnType = @typeInfo(@TypeOf(iterFunc)).Fn.return_type.?; + const trampoline = tramp.Trampoline(returnType); if (@typeInfo(returnType) == .ErrorUnion) { const optional_result = result catch return null; - return (tramp.Trampoline(returnType).wrap(optional_result orelse return null) catch return null).py; + const non_optional_result = optional_result orelse return null; + return (trampoline.wrap(non_optional_result) catch return null).py; } - return (tramp.Trampoline(returnType).wrap(result orelse return null) catch return null).py; + const non_optional_result = result orelse return null; + return (trampoline.wrap(non_optional_result) catch return null).py; } }; } diff --git a/test/test_iterator.py b/test/test_iterator.py index 11a42328..2e4f7fcf 100644 --- a/test/test_iterator.py +++ b/test/test_iterator.py @@ -7,5 +7,5 @@ def test_range_iterator(): range_iterator = iter(iterators.Range(0, 10, 1)) for i in range(10): assert next(range_iterator) == i - with pytest.raises(StopIteration) as exc: + with pytest.raises(StopIteration): next(range_iterator) From af8f2b6ed3a9adc2b2192a5129514fb30282c68f Mon Sep 17 00:00:00 2001 From: Marko Bakovic Date: Thu, 7 Sep 2023 12:09:06 +0200 Subject: [PATCH 3/6] This commit will be squashed. --- example/iterators.zig | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/iterators.zig b/example/iterators.zig index af37cfbb..8eaf04c2 100644 --- a/example/iterators.zig +++ b/example/iterators.zig @@ -32,12 +32,12 @@ pub const RangeIterator = py.class("Iterable", struct { return .{ .next = args.next, .stop = args.stop, .step = args.step }; } - pub fn __next__(self: *Self) !?py.PyLong { + pub fn __next__(self: *Self) ?i64 { if (self.next >= self.stop) { return null; } defer self.next += self.step; - return try py.PyLong.from(i64, self.next); + return self.next; } }); From a97f2e1007f150703fded538420c804d3b20652a Mon Sep 17 00:00:00 2001 From: Marko Bakovic Date: Thu, 7 Sep 2023 12:47:58 +0200 Subject: [PATCH 4/6] comments --- pydust/src/pytypes.zig | 44 ++++++++++++++++++++++++++++-------------- pyproject.toml | 1 - 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/pydust/src/pytypes.zig b/pydust/src/pytypes.zig index acd68746..3b7c8479 100644 --- a/pydust/src/pytypes.zig +++ b/pydust/src/pytypes.zig @@ -203,31 +203,45 @@ fn Slots(comptime definition: type, comptime Instance: type) type { } fn tp_iter(pyself: *ffi.PyObject) callconv(.C) ?*ffi.PyObject { - const iterFunc = @field(definition, "__iter__"); + const iterFunc = definition.__iter__; const self: *const Instance = @ptrCast(pyself); - const result = iterFunc(&self.state); - const returnType = @typeInfo(@TypeOf(iterFunc)).Fn.return_type.?; + if (tp_iter_internal( + iterFunc(&self.state), + @typeInfo(@TypeOf(iterFunc)).Fn.return_type.?, + ) catch return null) |obj| { + return obj.py; + } else { + return null; + } + } + + fn tp_iter_internal(result: anytype, comptime returnType: type) !?py.PyObject { const trampoline = tramp.Trampoline(returnType); if (@typeInfo(returnType) == .ErrorUnion) { - const non_optional_result = result catch return null; - return (trampoline.wrap(non_optional_result) catch return null).py; + return try trampoline.wrap(result catch return null); } - return (trampoline.wrap(result) catch return null).py; + return try trampoline.wrap(result); } fn tp_iternext(pyself: *ffi.PyObject) callconv(.C) ?*ffi.PyObject { - const iterFunc = @field(definition, "__next__"); - var self: *Instance = @constCast(@ptrCast(pyself)); - const result = iterFunc(@constCast(&self.state)); - const returnType = @typeInfo(@TypeOf(iterFunc)).Fn.return_type.?; + const iterFunc = definition.__next__; + const self: *Instance = @constCast(@ptrCast(pyself)); + if (tp_iternext_internal( + iterFunc(&self.state), + @typeInfo(@TypeOf(iterFunc)).Fn.return_type.?, + ) catch return null) |obj| { + return obj.py; + } else { + return null; + } + } + + fn tp_iternext_internal(result: anytype, comptime returnType: type) !?py.PyObject { const trampoline = tramp.Trampoline(returnType); if (@typeInfo(returnType) == .ErrorUnion) { - const optional_result = result catch return null; - const non_optional_result = optional_result orelse return null; - return (trampoline.wrap(non_optional_result) catch return null).py; + return try trampoline.wrap((result catch return null) orelse return null); } - const non_optional_result = result orelse return null; - return (trampoline.wrap(non_optional_result) catch return null).py; + return try trampoline.wrap(result orelse return null); } }; } diff --git a/pyproject.toml b/pyproject.toml index e0120216..ca6344ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,6 @@ line-length = 120 [tool.ruff] line-length = 120 select = ["F", "E", "W", "UP", "I001", "I002"] -target-version = "py310" [build-system] requires = ["poetry-core"] From eaea8ac8716c5a66cc4c74f96bb9a92922c21e52 Mon Sep 17 00:00:00 2001 From: Marko Bakovic Date: Thu, 7 Sep 2023 13:24:24 +0200 Subject: [PATCH 5/6] Update example/iterators.zig Co-authored-by: Nicholas Gates --- example/iterators.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/iterators.zig b/example/iterators.zig index 8eaf04c2..bcfadcd7 100644 --- a/example/iterators.zig +++ b/example/iterators.zig @@ -19,7 +19,7 @@ pub const Range = py.class("Range", struct { } }); -pub const RangeIterator = py.class("Iterable", struct { +pub const RangeIterator = py.class("RangeIterator", struct { pub const __doc__ = "Range iterator"; const Self = @This(); From 45640181f9895399343f3369bef373db2ebe611c Mon Sep 17 00:00:00 2001 From: Marko Bakovic Date: Thu, 7 Sep 2023 13:56:05 +0200 Subject: [PATCH 6/6] force error --- example/iterators.zig | 2 +- pydust/src/pytypes.zig | 42 +++++++++--------------------------------- 2 files changed, 10 insertions(+), 34 deletions(-) diff --git a/example/iterators.zig b/example/iterators.zig index 8eaf04c2..ee8aaf45 100644 --- a/example/iterators.zig +++ b/example/iterators.zig @@ -32,7 +32,7 @@ pub const RangeIterator = py.class("Iterable", struct { return .{ .next = args.next, .stop = args.stop, .step = args.step }; } - pub fn __next__(self: *Self) ?i64 { + pub fn __next__(self: *Self) !?i64 { if (self.next >= self.stop) { return null; } diff --git a/pydust/src/pytypes.zig b/pydust/src/pytypes.zig index 3b7c8479..15d5534d 100644 --- a/pydust/src/pytypes.zig +++ b/pydust/src/pytypes.zig @@ -193,7 +193,7 @@ fn Slots(comptime definition: type, comptime Instance: type) type { } fn mp_length(pyself: *ffi.PyObject) callconv(.C) isize { - const lenFunc = @field(definition, "__len__"); + const lenFunc = definition.__len__; const self: *const Instance = @ptrCast(pyself); const result = @as(isize, @intCast(lenFunc(&self.state))); if (@typeInfo(@typeInfo(@TypeOf(lenFunc)).Fn.return_type.?) == .ErrorUnion) { @@ -204,44 +204,20 @@ fn Slots(comptime definition: type, comptime Instance: type) type { fn tp_iter(pyself: *ffi.PyObject) callconv(.C) ?*ffi.PyObject { const iterFunc = definition.__iter__; + const trampoline = tramp.Trampoline(@typeInfo(@TypeOf(iterFunc)).Fn.return_type.?); const self: *const Instance = @ptrCast(pyself); - if (tp_iter_internal( - iterFunc(&self.state), - @typeInfo(@TypeOf(iterFunc)).Fn.return_type.?, - ) catch return null) |obj| { - return obj.py; - } else { - return null; - } - } - - fn tp_iter_internal(result: anytype, comptime returnType: type) !?py.PyObject { - const trampoline = tramp.Trampoline(returnType); - if (@typeInfo(returnType) == .ErrorUnion) { - return try trampoline.wrap(result catch return null); - } - return try trampoline.wrap(result); + const result = iterFunc(&self.state) catch return null; + const obj = trampoline.wrap(result) catch return null; + return obj.py; } fn tp_iternext(pyself: *ffi.PyObject) callconv(.C) ?*ffi.PyObject { const iterFunc = definition.__next__; + const trampoline = tramp.Trampoline(@typeInfo(@TypeOf(iterFunc)).Fn.return_type.?); const self: *Instance = @constCast(@ptrCast(pyself)); - if (tp_iternext_internal( - iterFunc(&self.state), - @typeInfo(@TypeOf(iterFunc)).Fn.return_type.?, - ) catch return null) |obj| { - return obj.py; - } else { - return null; - } - } - - fn tp_iternext_internal(result: anytype, comptime returnType: type) !?py.PyObject { - const trampoline = tramp.Trampoline(returnType); - if (@typeInfo(returnType) == .ErrorUnion) { - return try trampoline.wrap((result catch return null) orelse return null); - } - return try trampoline.wrap(result orelse return null); + const result = iterFunc(&self.state) catch return null; + const obj = trampoline.wrap(result orelse return null) catch return null; + return obj.py; } }; }