From 04847638c40f8114717a7b432975fc4ceab14ebc Mon Sep 17 00:00:00 2001 From: Pavle Markovic Date: Mon, 27 Jan 2025 11:41:59 +0000 Subject: [PATCH] Revert embedding dataformat cast --- python/tvm/relay/frontend/pytorch.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 37f17057b..c4a7b06b7 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2700,10 +2700,7 @@ def embedding(self, inputs, input_types): # exposes a few bugs in tt-mlir https://github.com/tenstorrent/tt-mlir/issues/1215 logger.warning("Casting input indices of embedding op from {} to int32", indicies_dtype) indices = tvm.relay.cast(indices, "int32") - # cast the weight to bfloat16 if it is float32 - if weight.type_annotation.dtype == "float32": - weight = tvm.relay.cast(weight, "bfloat16") - return tvm.relay.cast(_op.embedding(weight, indices, axis=0), "float32") + return _op.embedding(weight, indices, axis=0) def embedding_bag(self, inputs, input_types):