From 083ccf9893a98bffb5cbcdf4c06bbeed11c56fef Mon Sep 17 00:00:00 2001 From: Nathan Simpson Date: Tue, 19 Oct 2021 23:34:15 +0200 Subject: [PATCH] Cast the input of jax.concatenate to a jax array This should make the behaviour consistent with numpy, e.g. when using a list as input. --- src/pyhf/tensor/jax_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pyhf/tensor/jax_backend.py b/src/pyhf/tensor/jax_backend.py index f5867ded10..fb0b810d75 100644 --- a/src/pyhf/tensor/jax_backend.py +++ b/src/pyhf/tensor/jax_backend.py @@ -295,7 +295,7 @@ def concatenate(self, sequence, axis=0): output: the concatenated tensor """ - return jnp.concatenate(sequence, axis=axis) + return jnp.concatenate(jnp.asarray(sequence), axis=axis) def simple_broadcast(self, *args): """