[Jax][Pallas][Mosaic] Implement platform dependent diag, with branch selection driven by constant prop in mosaic lowering. #679

TPU test (jaxlib=head, v5e-8)

succeeded Dec 22, 2024 in 7m 5s