Revert embedding dataformat cast (#59)
We don't need an explicit embedding dataformat cast in TVM (from float32 to
bf16), since the dataformat workaround for this case is implemented in MLIR.

PRs for reference:

- [TVM cast workaround](#55)
- [Embedding Op workaround](tenstorrent/tt-mlir#1583)
- [EmbeddingBackward Op workaround](tenstorrent/tt-mlir#1756)

Related issue: tenstorrent/tt-forge-fe#1112
pmarkovicTT authored Jan 27, 2025
2 parents 4304c2f + 0484763 commit 705bca2
Showing 1 changed file with 1 addition and 4 deletions: python/tvm/relay/frontend/pytorch.py
```diff
@@ -2700,10 +2700,7 @@ def embedding(self, inputs, input_types):
         # exposes a few bugs in tt-mlir https://github.com/tenstorrent/tt-mlir/issues/1215
         logger.warning("Casting input indices of embedding op from {} to int32", indicies_dtype)
         indices = tvm.relay.cast(indices, "int32")
-        # cast the weight to bfloat16 if it is float32
-        if weight.type_annotation.dtype == "float32":
-            weight = tvm.relay.cast(weight, "bfloat16")
-        return tvm.relay.cast(_op.embedding(weight, indices, axis=0), "float32")
+        return _op.embedding(weight, indices, axis=0)

     def embedding_bag(self, inputs, input_types):
```
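The effect of the revert can be sketched in plain NumPy (illustrative only: the real lowering emits Relay ops, and the function name and signature here are assumptions, not TVM's API). Indices are still cast to int32, but the weight table is no longer round-tripped through bfloat16, so the output dtype simply follows the weight dtype:

```python
import numpy as np

def embedding(weight: np.ndarray, indices: np.ndarray) -> np.ndarray:
    """Sketch of the embedding lowering after this revert (hypothetical helper).

    The index-dtype workaround (cast to int32) is kept, but the weight is no
    longer cast float32 -> bfloat16 -> float32; that dataformat workaround now
    lives in the MLIR pipeline instead of the TVM frontend.
    """
    indices = indices.astype(np.int32)  # index-dtype workaround remains
    # No weight dtype cast anymore: output dtype == weight dtype.
    return weight[indices]

table = np.arange(12, dtype=np.float32).reshape(4, 3)  # 4 embeddings of dim 3
out = embedding(table, np.array([0, 2], dtype=np.int64))
assert out.dtype == np.float32  # dtype preserved, no bf16 round-trip
```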
