@@ -60,6 +60,7 @@ def rand_validator(rand_node: Node) -> bool:
     if layout is not None:
         _LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
         return False
+    return True
 
 
 @dynamo_tensorrt_converter(
@@ -76,21 +77,8 @@ def aten_ops_rand(
     return np.random.rand(*args)
 
 
-def randn_validator(randn_node: Node) -> bool:
-    dtype = randn_node.kwargs.get("dtype", None)
-    layout = randn_node.kwargs.get("layout", None)
-    if dtype is not None:
-        _LOGGER.debug(
-            f"Currently we don't support specifying output dtype, got {dtype}."
-        )
-        return False
-    if layout is not None:
-        _LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
-        return False
-
-
 @dynamo_tensorrt_converter(
-    torch.ops.aten.randn.default, capability_validator=randn_validator
+    torch.ops.aten.randn.default, capability_validator=rand_validator
 )
 def aten_ops_randn(
     ctx: ConversionContext,
@@ -118,6 +106,7 @@ def randperm_validator(randperm_node: Node) -> bool:
     if layout is not None:
         _LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
         return False
+    return True
 
 
 @dynamo_tensorrt_converter(
@@ -131,7 +120,4 @@ def aten_ops_randperm(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     device = kwargs.get("device", None)
-    input = args[0]
-    if not isinstance(input, int):
-        raise RuntimeError(f"The input must be an integer")
     return np.random.permutation(*args)