diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 147813d8e0..c314eded3b 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -25,6 +25,7 @@ get_positive_dim, is_only_operator_on_placeholder, ) +from torch_tensorrt.dynamo.utils import DYNAMIC_DIM _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -2721,9 +2722,18 @@ def sort_validator(node: Node, settings: Optional[CompilationSettings] = None) - def topk_sort_validator(k: int) -> bool: + + # topk layer supports dynamic k value but we cannot determine supported dynamic topk value at + # compile time. + if k == DYNAMIC_DIM or not isinstance(k, int): + _LOGGER.warning( + "[top_k validator] It's not expected for k to be a dynamic or data-dependent value. aten::topk will run in PyTorch" + ) + return False + if k > 3840: - _LOGGER.debug( - f"Currently only topk values up to 3840 are supported, got k={k}." + _LOGGER.warning( + f"[top_k validator] Currently only topk values up to 3840 are supported, got k={k}. Therefore, aten::topk will run in PyTorch" ) return False return True diff --git a/py/torch_tensorrt/dynamo/conversion/impl/topk.py b/py/torch_tensorrt/dynamo/conversion/impl/topk.py index 053a46ce2b..638cbf599e 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/topk.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/topk.py @@ -209,10 +209,6 @@ def topk( get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))), ) - # topk layer supports dynamic k value but we cannot dertermin supported dynamic topk value at - # compile time. - assert k != DYNAMIC_DIM, "k value cannot be dynamic!" - # TensorRT ITopKLayer does not have a sorted flag, it is always returning the sorted topk elements # so here no matter sorted is True or False the returned the topk Tensor object is always sorted set_layer_name(topk_layer, target, f"{name}_topk", source_ir)