Skip to content

Commit

Permalink
add jax to benchmark processor
Browse files Browse the repository at this point in the history
  • Loading branch information
sky-2002 committed Oct 14, 2024
1 parent 0dfb41f commit 384ebad
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions benchmarks/bench_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
except ImportError:
pass

try:
import jax
import jax.numpy as jnp
except ImportError:
pass


def is_mlx_lm_allowed():
try:
Expand All @@ -18,6 +24,14 @@ def is_mlx_lm_allowed():
return mx.metal.is_available()


def is_jax_allowed():
try:
import jax # noqa: F401
except ImportError:
return False
return True


def get_mock_processor_inputs(array_library, num_tokens=30000):
"""
logits: (4, 30,000 ) dtype=float
Expand All @@ -43,6 +57,13 @@ def get_mock_processor_inputs(array_library, num_tokens=30000):
input_ids = mx.random.randint(
low=0, high=num_tokens, shape=(4, 2048), dtype=mx.int32
)
elif array_library == "jax":
logits = jnp.random.uniform(
key=jax.random.PRNGKey(0), shape=(4, num_tokens), dtype=jnp.float32
)
input_ids = jnp.random.randint(
key=jax.random.PRNGKey(0), low=0, high=num_tokens, shape=(4, 2048)
)
else:
raise ValueError

Expand All @@ -67,6 +88,8 @@ class LogitsProcessorPassthroughBenchmark:
params += ["mlx"]
if torch.cuda.is_available():
params += ["torch_cuda"]
if is_jax_allowed():
params += ["jax"]

def setup(self, array_library):
self.logits_processor = HalvingLogitsProcessor()
Expand Down

0 comments on commit 384ebad

Please sign in to comment.