@@ -78,8 +78,6 @@ def _unsortedSegmentMax(self, data, indices, num_segments):
7878 return self ._segmentReduction (math_ops .unsorted_segment_max , data , indices ,
7979 num_segments )
8080
81- @test .disable_with_predicate (
82- pred = test .is_built_with_rocm , skip_message = "Test fails on ROCm." ) #TODO(rocm): weekly sync 25-05-14
8381 def testSegmentSum (self ):
8482 for dtype in self .numeric_types :
8583 self .assertAllClose (
@@ -88,8 +86,6 @@ def testSegmentSum(self):
8886 np .array ([0 , 1 , 2 , 3 , 4 , 5 ], dtype = dtype ),
8987 np .array ([0 , 0 , 2 , 3 , 3 , 3 ], dtype = np .int32 ), 4 ))
9088
91- @test .disable_with_predicate (
92- pred = test .is_built_with_rocm , skip_message = "Test fails on ROCm." ) #TODO(rocm): weekly sync 24-11-05
9389 def testSegmentProd (self ):
9490 for dtype in self .numeric_types :
9591 self .assertAllClose (
@@ -98,8 +94,6 @@ def testSegmentProd(self):
9894 np .array ([0 , 1 , 2 , 3 , 4 , 5 ], dtype = dtype ),
9995 np .array ([0 , 0 , 2 , 3 , 3 , 3 ], dtype = np .int32 ), 4 ))
10096
101- @test .disable_with_predicate (
102- pred = test .is_built_with_rocm , skip_message = "Test fails on ROCm." ) #TODO(rocm): weekly sync 24-11-05
10397 def testSegmentProdNumSegmentsLess (self ):
10498 for dtype in self .numeric_types :
10599 self .assertAllClose (
@@ -108,8 +102,6 @@ def testSegmentProdNumSegmentsLess(self):
108102 np .array ([0 , 1 , 2 , 3 , 4 , 5 ], dtype = dtype ),
109103 np .array ([0 , 0 , 2 , 3 , 3 , 3 ], dtype = np .int32 ), 3 ))
110104
111- @test .disable_with_predicate (
112- pred = test .is_built_with_rocm , skip_message = "Test fails on ROCm." ) #TODO(rocm): weekly sync 24-11-05
113105 def testSegmentProdNumSegmentsMore (self ):
114106 for dtype in self .numeric_types :
115107 self .assertAllClose (
@@ -194,8 +186,6 @@ def testUnsortedSegmentSum0DIndices1DData(self):
194186 self ._unsortedSegmentSum (
195187 np .array ([0 , 1 , 2 , 3 , 4 , 5 ], dtype = dtype ), 2 , 4 ))
196188
197- @test .disable_with_predicate (
198- pred = test .is_built_with_rocm , skip_message = "Test fails on ROCm." ) #TODO(rocm): weekly sync 25-05-14
199189 def testUnsortedSegmentSum1DIndices1DData (self ):
200190 for dtype in self .numeric_types :
201191 self .assertAllClose (
@@ -204,8 +194,6 @@ def testUnsortedSegmentSum1DIndices1DData(self):
204194 np .array ([0 , 1 , 2 , 3 , 4 , 5 ], dtype = dtype ),
205195 np .array ([3 , 0 , 2 , 1 , 3 , 3 ], dtype = np .int32 ), 4 ))
206196
207- @test .disable_with_predicate (
208- pred = test .is_built_with_rocm , skip_message = "Test fails on ROCm." ) #TODO(rocm): weekly sync 25-05-14
209197 def testUnsortedSegmentSum1DIndices1DDataNegativeIndices (self ):
210198 for dtype in self .numeric_types :
211199 self .assertAllClose (
@@ -214,8 +202,6 @@ def testUnsortedSegmentSum1DIndices1DDataNegativeIndices(self):
214202 np .array ([0 , 1 , 2 , 3 , 4 , 5 , 6 ], dtype = dtype ),
215203 np .array ([3 , - 1 , 0 , 1 , 0 , - 1 , 3 ], dtype = np .int32 ), 4 ))
216204
217- @test .disable_with_predicate (
218- pred = test .is_built_with_rocm , skip_message = "Test fails on ROCm." ) #TODO(rocm): weekly sync 25-05-14
219205 def testUnsortedSegmentSum1DIndices2DDataDisjoint (self ):
220206 for dtype in self .numeric_types :
221207 data = np .array (
@@ -232,8 +218,6 @@ def testUnsortedSegmentSum1DIndices2DDataDisjoint(self):
232218 [50 , 51 , 52 , 53 ], [0 , 1 , 2 , 3 ], [0 , 0 , 0 , 0 ]],
233219 dtype = dtype ), y )
234220
235- @test .disable_with_predicate (
236- pred = test .is_built_with_rocm , skip_message = "Test fails on ROCm." ) #TODO(rocm): weekly sync 25-05-14
237221 def testUnsortedSegmentSum1DIndices2DDataNonDisjoint (self ):
238222 for dtype in self .numeric_types :
239223 data = np .array (
@@ -249,8 +233,6 @@ def testUnsortedSegmentSum1DIndices2DDataNonDisjoint(self):
249233 [0 , 0 , 0 , 0 ]],
250234 dtype = dtype ), y )
251235
252- @test .disable_with_predicate (
253- pred = test .is_built_with_rocm , skip_message = "Test fails on ROCm." ) #TODO(rocm): weekly sync 25-05-14
254236 def testUnsortedSegmentSum2DIndices3DData (self ):
255237 for dtype in self .numeric_types :
256238 data = np .array (
@@ -268,8 +250,6 @@ def testUnsortedSegmentSum2DIndices3DData(self):
268250 ], [0 , 0 , 0. ], [90 , 92 , 94 ], [103 , 104 , 105 ], [0 , 0 , 0 ]],
269251 dtype = dtype ), y )
270252
271- @test .disable_with_predicate (
272- pred = test .is_built_with_rocm , skip_message = "Test fails on ROCm." ) #TODO(rocm): weekly sync 25-05-14
273253 def testUnsortedSegmentSum1DIndices3DData (self ):
274254 for dtype in self .numeric_types :
275255 data = np .array (
@@ -298,8 +278,6 @@ def testUnsortedSegmentSumShapeError(self):
298278 math_ops .unsorted_segment_sum , data , indices ,
299279 num_segments ))
300280
301- @test .disable_with_predicate (
302- pred = test .is_built_with_rocm , skip_message = "Test fails on ROCm." ) #TODO(rocm): weekly sync 24-11-05
303281 def testUnsortedSegmentOps1DIndices1DDataNegativeIndices (self ):
304282 """Tests for min, max, and prod ops.
305283
0 commit comments