From 384ebadf290f175a70e2331dabba7b322932a27c Mon Sep 17 00:00:00 2001 From: Aakash Thatte Date: Mon, 14 Oct 2024 19:11:08 +0530 Subject: [PATCH] add jax to benchmark processor --- benchmarks/bench_processors.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/benchmarks/bench_processors.py b/benchmarks/bench_processors.py index 5b4901540..db1e4a8f1 100644 --- a/benchmarks/bench_processors.py +++ b/benchmarks/bench_processors.py @@ -9,6 +9,12 @@ except ImportError: pass +try: + import jax + import jax.numpy as jnp +except ImportError: + pass + def is_mlx_lm_allowed(): try: @@ -18,6 +24,14 @@ def is_mlx_lm_allowed(): return mx.metal.is_available() +def is_jax_allowed(): + try: + import jax # noqa: F401 + except ImportError: + return False + return True + + def get_mock_processor_inputs(array_library, num_tokens=30000): """ logits: (4, 30,000 ) dtype=float @@ -43,6 +57,13 @@ def get_mock_processor_inputs(array_library, num_tokens=30000): input_ids = mx.random.randint( low=0, high=num_tokens, shape=(4, 2048), dtype=mx.int32 ) + elif array_library == "jax": + logits = jnp.random.uniform( + key=jax.random.PRNGKey(0), shape=(4, num_tokens), dtype=jnp.float32 + ) + input_ids = jnp.random.randint( + key=jax.random.PRNGKey(0), low=0, high=num_tokens, shape=(4, 2048) + ) else: raise ValueError @@ -67,6 +88,8 @@ class LogitsProcessorPassthroughBenchmark: params += ["mlx"] if torch.cuda.is_available(): params += ["torch_cuda"] + if is_jax_allowed(): + params += ["jax"] def setup(self, array_library): self.logits_processor = HalvingLogitsProcessor()