Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Implement difference between pointers #3508

Open
wants to merge 8 commits into
base: nightly
Choose a base branch
from
29 changes: 29 additions & 0 deletions docs/manual/pointers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,35 @@
" print(float_ptr[offset], end=\", \")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When subtracting two pointers of the same type, it gives the number of elements between them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Second int: 200\nOffset between pointers: 1"
]
}
],
"source": [
"fn main():\n",
" l = List[Int](100, 200) \n",
" p1 = UnsafePointer.address_of(l[0])\n",
" p2 = p1 + 1\n",
" print(\"Second int:\", p2[])\n",
" print(\"Offset between pointers:\", p2 - p1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
12 changes: 12 additions & 0 deletions stdlib/src/memory/unsafe_pointer.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,18 @@ struct UnsafePointer[
"""
self = self - offset

@always_inline
fn __sub__(self, other: Self) -> Int:
"""Return the difference between two pointers.

Args:
other: The other pointer.

Returns:
The offset between the two pointers.
"""
return (int(self) - int(other)) // sizeof[type]()

@always_inline("nodebug")
fn __eq__(self, rhs: Self) -> Bool:
"""Returns True if the two pointers are equal.
Expand Down
9 changes: 9 additions & 0 deletions stdlib/test/memory/test_unsafepointer.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,14 @@ def test_load_and_store_simd():
assert_equal(ptr2[i], i // 4 * 4)


def test_difference():
var ptr = UnsafePointer[Int].alloc(5)
var ptr2 = ptr + 2
assert_equal(ptr2 - ptr, 2)
assert_equal(ptr - ptr2, -2)
ptr.free()


def main():
test_address_of()

Expand All @@ -297,6 +305,7 @@ def main():
test_unsafepointer_string()
test_eq()
test_comparisons()
test_difference()

test_unsafepointer_address_space()
test_indexing()
Expand Down