22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
5+ """Declare operator support for ``aten.embedding`` in TOSA.
56
7+ Permit embeddings with int32 indices (TOSA lacks int64 support); other dtypes
8+ are rejected by this check.
9+
10+ """
611
712import torch
813
1722
1823@register_tosa_support_check
1924class EmbeddingSupported (SupportedTOSAOperatorCheck ):
25+ """Provide TOSA support check for ``aten.embedding``."""
26+
2027 targets = [exir_ops .edge .aten .embedding .default ]
2128
2229 tosa_specs = [
@@ -27,16 +34,20 @@ class EmbeddingSupported(SupportedTOSAOperatorCheck):
2734 def is_node_tosa_supported (
2835 self , node : fx .Node , tosa_spec : TosaSpecification
2936 ) -> bool : # type: ignore[override, misc]
30- # Note aten.embedding.default requires int64 indices and TOSA does not
31- # support it. Int32 indices here for aten.embedding.default is ok since
32- # it will be decomposed into ops that can handle it.
37+ """Return True if the node is supported by TOSA.
3338
39+ PyTorch's ``aten.embedding`` typically takes int64 indices, but for
40+ TOSA we only allow int32 indices. The export path decomposes the op so
41+ that int32 indices are ok.
42+
43+ """
3444 if len (node .all_input_nodes ) != 2 :
3545 self .reporter .report_reject (
3646 node ,
3747 (f"Expected exactly two input nodes, got { len (node .all_input_nodes )} " ),
3848 )
3949 return False
50+
4051 indices_val = node .all_input_nodes [1 ].meta ["val" ]
4152 indices_dtype = indices_val .dtype
4253
0 commit comments