Skip to content

Commit

Permalink
Merge branch 'ccrouzet/gh-392-struct-op-overloading' into 'main'
Browse files Browse the repository at this point in the history
Add Operator Overloading for `wp.struct`

See merge request omniverse/warp!990
  • Loading branch information
mmacklin committed Jan 17, 2025
2 parents e4efbff + 72cf73d commit 1336132
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
([GH-379](https://github.com/NVIDIA/warp/issues/379)).
- Add per-module option to add CUDA-C line information for profiling, use `wp.set_module_options({"lineinfo": True})`.
- Add `example_tile_walker.py`, which reworks the existing `walker.py` to use Warp's tile API for matrix multiplication.
- Add operator overloads for `wp.struct` objects by defining `wp.func` functions ([GH-392](https://github.com/NVIDIA/warp/issues/392)).

### Changed

Expand Down
41 changes: 41 additions & 0 deletions docs/modules/runtime.rst
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,47 @@ Example: Using a struct in gradient computation
[[1. 2. 3.]
[4. 5. 6.]]

Example: Defining Operator Overloads
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code:: python
@wp.struct
class Complex:
real: float
imag: float
@wp.func
def add(
a: Complex,
b: Complex,
) -> Complex:
return Complex(a.real + b.real, a.imag + b.imag)
@wp.func
def mul(
a: Complex,
b: Complex,
) -> Complex:
return Complex(
a.real * b.real - a.imag * b.imag,
a.real * b.imag + a.imag * b.real,
)
@wp.kernel
def kernel():
a = Complex(1.0, 2.0)
b = Complex(3.0, 4.0)
c = a + b
wp.printf("%.0f %+.0fi\n", c.real, c.imag)
d = a * b
wp.printf("%.0f %+.0fi\n", d.real, d.imag)
wp.launch(kernel, dim=(1,))
wp.synchronize()
Type Conversions
################

Expand Down
8 changes: 8 additions & 0 deletions warp/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1914,6 +1914,14 @@ def emit_BinOp(adj, node):

name = builtin_operators[type(node.op)]

try:
# Check if there is any user-defined overload for this operator
user_func = adj.resolve_external_reference(name)
if isinstance(user_func, warp.context.Function):
return adj.add_call(user_func, (left, right), {}, {})
except WarpCodegenError:
pass

return adj.add_builtin_call(name, [left, right])

def emit_UnaryOp(adj, node):
Expand Down
43 changes: 43 additions & 0 deletions warp/tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,48 @@ def test_operators_mat44():
expect_eq(r0[3], wp.vec4(39.0, 42.0, 45.0, 48.0))


@wp.struct
class Complex:
real: float
imag: float


@wp.func
def add(
a: Complex,
b: Complex,
) -> Complex:
return Complex(
a.real + b.real,
a.imag + b.imag,
)


@wp.func
def mul(
a: Complex,
b: Complex,
) -> Complex:
return Complex(
a.real * b.real - a.imag * b.imag,
a.real * b.imag + a.imag * b.real,
)


@wp.kernel
def test_operators_overload():
a = Complex(1.0, 2.0)
b = Complex(3.0, 4.0)

c = a + b
expect_eq(c.real, 4.0)
expect_eq(c.imag, 6.0)

d = a * b
expect_eq(d.real, -5.0)
expect_eq(d.imag, 10.0)


devices = get_test_devices()


Expand All @@ -241,6 +283,7 @@ class TestOperators(unittest.TestCase):
add_kernel_test(TestOperators, test_operators_mat22, dim=1, devices=devices)
add_kernel_test(TestOperators, test_operators_mat33, dim=1, devices=devices)
add_kernel_test(TestOperators, test_operators_mat44, dim=1, devices=devices)
add_kernel_test(TestOperators, test_operators_overload, dim=1, devices=devices)


if __name__ == "__main__":
Expand Down

0 comments on commit 1336132

Please sign in to comment.