Skip to content

Commit

Permalink
Merge branch 'fix/fem_array_axpy_warning' into 'main'
Browse files Browse the repository at this point in the history
Fix warning being printed when capturing `array_axpy` on Tape

See merge request omniverse/warp!868
  • Loading branch information
gdaviet committed Nov 18, 2024
2 parents 4f43dc2 + f1c61e3 commit a7af93a
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 5 deletions.
11 changes: 9 additions & 2 deletions warp/fem/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,16 @@ def array_axpy(x: wp.array, y: wp.array, alpha: float = 1.0, beta: float = 1.0):
if x.shape != y.shape or x.device != y.device:
raise ValueError("x and y arrays must have the same shape and device")

# array_axpy requires a custom adjoint; unfortunately we cannot use `wp.func_grad`
# as generic functions are not supported yet. Instead we use a non-differentiable kernel
# and record a custom adjoint function on the tape.

# temporarilly disable tape to avoid printing warning that kernel is not differentiable
(tape, wp.context.runtime.tape) = (wp.context.runtime.tape, None)
wp.launch(kernel=_array_axpy_kernel, dim=x.shape, device=x.device, inputs=[x, y, alpha, beta])
wp.context.runtime.tape = tape

if (x.requires_grad or y.requires_grad) and wp.context.runtime.tape is not None:
if tape is not None and (x.requires_grad or y.requires_grad):

def backward_axpy():
# adj_x += adj_y * alpha
Expand All @@ -383,7 +390,7 @@ def backward_axpy():
if beta != 1.0:
array_axpy(x=y.grad, y=y.grad, alpha=0.0, beta=beta)

wp.context.runtime.tape.record_func(backward_axpy, arrays=[x, y])
tape.record_func(backward_axpy, arrays=[x, y])


@wp.kernel(enable_backward=False)
Expand Down
33 changes: 31 additions & 2 deletions warp/tests/test_fem.py
Original file line number Diff line number Diff line change
Expand Up @@ -1898,6 +1898,28 @@ def test_qr_inverse():
wp.expect_near(wp.ddot(Err, Err), 0.0, tol)


def test_array_axpy(test, device):
N = 10
alpha = 0.5
beta = 4.0

x = wp.full(N, 2.0, device=device, dtype=float, requires_grad=True)
y = wp.array(np.arange(N), device=device, dtype=wp.float64, requires_grad=True)

tape = wp.Tape()
with tape:
fem.utils.array_axpy(x=x, y=y, alpha=alpha, beta=beta)

assert_np_equal(x.numpy(), np.full(N, 2.0))
assert_np_equal(y.numpy(), alpha * x.numpy() + beta * np.arange(N))

y.grad.fill_(1.0)
tape.backward()

assert_np_equal(x.grad.numpy(), alpha * np.ones(N))
assert_np_equal(y.grad.numpy(), beta * np.ones(N))


devices = get_test_devices()
cuda_devices = get_selected_cuda_test_devices()

Expand Down Expand Up @@ -1928,8 +1950,15 @@ class TestFem(unittest.TestCase):
add_function_test(TestFem, "test_particle_quadratures", test_particle_quadratures)
add_function_test(TestFem, "test_nodal_quadrature", test_nodal_quadrature)
add_function_test(TestFem, "test_implicit_fields", test_implicit_fields)
add_kernel_test(TestFem, test_qr_eigenvalues, dim=1, devices=devices)
add_kernel_test(TestFem, test_qr_inverse, dim=100, devices=devices)


class TestFemUtilities(unittest.TestCase):
pass


add_kernel_test(TestFemUtilities, test_qr_eigenvalues, dim=1, devices=devices)
add_kernel_test(TestFemUtilities, test_qr_inverse, dim=100, devices=devices)
add_function_test(TestFemUtilities, "test_array_axpy", test_array_axpy)


class TestFemShapeFunctions(unittest.TestCase):
Expand Down
3 changes: 2 additions & 1 deletion warp/tests/unittest_suites.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
)
from warp.tests.test_fabricarray import TestFabricArray
from warp.tests.test_fast_math import TestFastMath
from warp.tests.test_fem import TestFem, TestFemShapeFunctions
from warp.tests.test_fem import TestFem, TestFemShapeFunctions, TestFemUtilities
from warp.tests.test_fp16 import TestFp16
from warp.tests.test_func import TestFunc
from warp.tests.test_future_annotations import TestFutureAnnotations
Expand Down Expand Up @@ -216,6 +216,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
TestFastMath,
TestFem,
TestFemShapeFunctions,
TestFemUtilities,
TestFp16,
TestFunc,
TestFutureAnnotations,
Expand Down

0 comments on commit a7af93a

Please sign in to comment.