From afa85150997b5fc09ca5709ac191f79566ee41a7 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 14 Mar 2024 14:38:24 -0400 Subject: [PATCH 1/4] add a warning message when using prng_key outside of seed --- numpyro/primitives.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 99cf35902..cacfdbc3e 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -622,6 +622,8 @@ def prng_key(): :return: a PRNG key of shape (2,) and dtype unit32. """ if not _PYRO_STACK: + warnings.warn("Cannot generate JAX PRNG key outside of `seed handler.`", + stacklevel=find_stack_level()) return initial_msg = { From 568f86ef1e8b443e5804b5d8cd06d7a723e3a3eb Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 14 Mar 2024 14:52:47 -0400 Subject: [PATCH 2/4] format properly --- numpyro/primitives.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/numpyro/primitives.py b/numpyro/primitives.py index cacfdbc3e..e5b21478a 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -622,8 +622,10 @@ def prng_key(): :return: a PRNG key of shape (2,) and dtype unit32. """ if not _PYRO_STACK: - warnings.warn("Cannot generate JAX PRNG key outside of `seed handler.`", - stacklevel=find_stack_level()) + warnings.warn( + "Cannot generate JAX PRNG key outside of `seed handler.`", + stacklevel=find_stack_level(), + ) return initial_msg = { From 58db860e5a6f536a8453a84302ed2276c6a78bd6 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 14 Mar 2024 14:53:38 -0400 Subject: [PATCH 3/4] Fix typos --- numpyro/primitives.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/primitives.py b/numpyro/primitives.py index e5b21478a..ac02a8856 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -623,7 +623,7 @@ def prng_key(): """ if not _PYRO_STACK: warnings.warn( - "Cannot generate JAX PRNG key outside of `seed handler.`", + "Cannot generate JAX PRNG key outside of `seed` handler.", stacklevel=find_stack_level(), ) return From 63ea001e7ddc874e28aea3944aca1ee315406921 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 14 Mar 2024 16:53:56 -0400 Subject: [PATCH 4/4] make sure that the test catch warning for prng_key --- test/test_handlers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_handlers.py b/test/test_handlers.py index e24e22890..518f856dc 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -778,7 +778,8 @@ def guide(): def test_prng_key(): - assert numpyro.prng_key() is None + with pytest.warns(Warning, match="outside of `seed`"): + assert numpyro.prng_key() is None with handlers.seed(rng_seed=0): rng_key = numpyro.prng_key()