From 09301265359dce7e7148ede38faca59e1ea5d036 Mon Sep 17 00:00:00 2001 From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com> Date: Fri, 1 Nov 2024 12:52:00 -0500 Subject: [PATCH] Revert "temporary decompose for decode (#353)" This reverts commit 0e93b6483d56927d7bf33ccc45244395235aae58. --- sharktank/sharktank/models/llama/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 656b4432b..ef3c4800d 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -269,7 +269,6 @@ def decode( for block_idx, block in enumerate(self.attn_blocks): if block_idx == 0: self.trace_tensor(f"llama.attn_block.{block_idx}.input", h) - block.attn.attention_kernel = "decomposed" h = block( h, start_positions=start_positions,