From ae553df8f44be162cae4c0624dc18971e3aec337 Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Tue, 8 Aug 2023 10:20:40 -0700
Subject: [PATCH] fix: Update embedding to reflect ATen schema

- Remove arguments not present in initial schema for embedding
- Improve coverage of embedding operator by expanding set of convertible
  implementations
- Update parameter-checking function accordingly
---
 .../dynamo/conversion/aten_ops_converters.py | 27 +++++--------------
 .../dynamo/conversion/impl/embedding.py      | 19 +------------
 2 files changed, 7 insertions(+), 39 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 240ea47308..0ef7266624 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -101,23 +101,9 @@ def aten_ops_div(
     )
 
 
-def embedding_param_validator(embedding_node: Node) -> bool:
-    max_norm = args_bounds_check(embedding_node.args, 2)
-    norm_type = args_bounds_check(embedding_node.args, 3)
-    scale_grad_by_freq = args_bounds_check(embedding_node.args, 4)
-    sparse = args_bounds_check(embedding_node.args, 5)
-
-    if max_norm is not None:
-        _LOGGER.debug(
-            f"Currently we don't support specifying max_norm, got {max_norm}."
-        )
-        return False
-
-    if norm_type is not None and norm_type != 2.0:
-        _LOGGER.debug(
-            f"Currently we don't support specifying norm_type, got {norm_type}."
-        )
-        return False
+def embedding_param_validator(embedding_node: Node):
+    scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
+    sparse = args_bounds_check(embedding_node.args, 4)
 
     if scale_grad_by_freq is not None:
         _LOGGER.debug(
@@ -149,10 +135,9 @@ def aten_ops_embedding(
         name,
         input=args[1],
         weight=args[0],
-        max_norm=args_bounds_check(args, 2),
-        norm_type=args_bounds_check(args, 3),
-        scale_grad_by_freq=args_bounds_check(args, 4),
-        sparse=args_bounds_check(args, 5),
+        # args[2] is the padding index, which is useful for training only
+        scale_grad_by_freq=args_bounds_check(args, 3),
+        sparse=args_bounds_check(args, 4),
     )
 
 
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py
index 48d5b55d7e..26064f621c 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py
@@ -14,16 +14,9 @@ def embedding(
     name: str,
     input: TRTTensor,
     weight: TRTTensor,
-    max_norm: None,
-    norm_type: None,
     scale_grad_by_freq: bool,
     sparse: bool,
 ) -> TRTTensor:
-    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `embedding` function should be called with explicit batch dimension."
-        )
-
     indices_tensor = input
     embedding_tensor = weight
     if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64:
@@ -37,16 +30,6 @@ def embedding(
 
     # unsupported parameters
     # ignore padding_idx since it is meaningful for training only
-    if max_norm is not None:
-        raise RuntimeError(
-            f"Currently we don't support specifying max_norm, got {max_norm}."
-        )
-
-    if norm_type is not None and norm_type != 2.0:
-        raise RuntimeError(
-            f"Currently we don't support specifying max_norm, got {norm_type} for norm_type."
-        )
-
    if scale_grad_by_freq:
         raise RuntimeError(
             "Currently we don't support scale gradient by word frequency."
@@ -57,5 +40,5 @@ def embedding(
 
     # Implement embedding lookup with gather layer
     gather_layer = network.add_gather(embedding_tensor, indices_tensor, axis=0)
-    set_layer_name(gather_layer, target, name + "_gather")
+    set_layer_name(gather_layer, target, name + "_gather", source_ir)
     return gather_layer.get_output(0)
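
Note (illustration only, not part of the patch): the new argument indices follow
the positional order of the ATen schema for embedding, roughly
aten::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse).
A minimal sketch of that layout, assuming a recent PyTorch build; the tensor
shapes and values below are made up for demonstration:

    # Positional layout assumed by the converter: args[0] = weight,
    # args[1] = indices, args[2] = padding_idx (training-only, ignored),
    # args[3] = scale_grad_by_freq, args[4] = sparse.
    import torch

    weight = torch.randn(10, 4)          # embedding table: 10 rows, dim 4
    indices = torch.tensor([1, 2, 4, 5]) # int64 lookup indices

    # Call the ATen op positionally, mirroring the FX node's args tuple.
    out = torch.ops.aten.embedding.default(weight, indices, -1, False, False)
    assert out.shape == (4, 4)           # one embedding row per index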