@@ -1374,54 +1374,37 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
1374
1374
return mlir::success ();
1375
1375
}
1376
1376
1377
- LogicalResult tosa::TransposeOp::getConstantPerms (SmallVector<int32_t > &perms) {
1378
- // Perms must be constants.
1379
- DenseIntElementsAttr permsAttr;
1380
- if (!matchPattern (getPerms (), m_Constant (&permsAttr)))
1381
- return failure ();
1382
-
1383
- perms.clear ();
1384
- for (auto v : permsAttr.getValues <APInt>())
1385
- perms.push_back (v.getSExtValue ());
1386
-
1387
- return success ();
1388
- }
1389
-
1390
1377
LogicalResult tosa::TransposeOp::inferReturnTypeComponents (
1391
1378
MLIRContext *context, ::std::optional<Location> location,
1392
1379
TransposeOp::Adaptor adaptor,
1393
1380
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1394
1381
ShapeAdaptor inputShape (adaptor.getInput1 ().getType ());
1395
- ShapeAdaptor permsShape (adaptor.getPerms ().getType ());
1396
-
1397
- // We cannot infer anything from a rank-0 "permutation" tensor.
1398
- if (permsShape.hasRank () && permsShape.getRank () == 0 )
1399
- return failure ();
1400
1382
1401
1383
// If input rank and permutation length is unknown, the output rank is
1402
1384
// unknown.
1403
- if (!inputShape.hasRank () || !permsShape.hasRank () ||
1404
- permsShape.isDynamicDim (0 )) {
1385
+ if (!inputShape.hasRank ()) {
1405
1386
inferredReturnShapes.push_back (ShapedTypeComponents ());
1406
1387
return success ();
1407
1388
}
1408
1389
1390
+ const auto inputRank = inputShape.getRank ();
1391
+
1409
1392
// This would imply the number of permutations does not match the rank of
1410
1393
// the input which is illegal.
1411
- if (permsShape. getDimSize ( 0 ) != inputShape. getRank ( )) {
1394
+ if (adaptor. getPerms (). size () != static_cast < size_t >(inputRank )) {
1412
1395
return failure ();
1413
1396
}
1414
1397
1415
1398
SmallVector<int64_t > outputShape;
1416
1399
// Rank-0 means no permutations matter.
1417
- if (inputShape. getRank () == 0 ) {
1400
+ if (inputRank == 0 ) {
1418
1401
inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
1419
1402
return success ();
1420
1403
}
1421
1404
1422
1405
// Check whether the input dimensions are all the same.
1423
1406
bool allTheSame = true ;
1424
- for (int i = 1 , s = inputShape. getRank () ; i < s; i++) {
1407
+ for (int i = 1 , s = inputRank ; i < s; i++) {
1425
1408
if (inputShape.getDimSize (0 ) != inputShape.getDimSize (i)) {
1426
1409
allTheSame = false ;
1427
1410
break ;
@@ -1431,34 +1414,21 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1431
1414
// If all of the input dimensions are the same we don't care about the
1432
1415
// permutation.
1433
1416
if (allTheSame) {
1434
- outputShape.resize (inputShape. getRank () , inputShape.getDimSize (0 ));
1417
+ outputShape.resize (inputRank , inputShape.getDimSize (0 ));
1435
1418
inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
1436
1419
return success ();
1437
1420
}
1438
1421
1439
- outputShape.resize (inputShape.getRank (), ShapedType::kDynamic );
1440
- // If the permuations are a constant we can directly determine the output
1441
- // shape.
1442
- DenseIntElementsAttr attr;
1443
- if (matchPattern (adaptor.getPerms (), m_Constant (&attr)) &&
1444
- attr.getType ().getRank () == 1 ) {
1445
- ShapeAdaptor permShape = attr;
1446
- // Constant permutation must be the same length as the input rank.
1447
- if (inputShape.getRank () != permShape.getRank ())
1448
- return emitOptionalError (location,
1449
- " constant permutation must be the same length"
1450
- " as the input rank" );
1451
-
1452
- // Constant permutation values must be within the input rank.
1453
- for (int i = 0 , e = inputShape.getRank (); i < e; i++) {
1454
- if (inputShape.getRank () <= permShape.getDimSize (i))
1455
- return failure ();
1456
- }
1422
+ outputShape.resize (inputRank, ShapedType::kDynamic );
1457
1423
1458
- outputShape.reserve (inputShape.getRank ());
1459
- for (int i = 0 , s = inputShape.getRank (); i < s; i++) {
1460
- outputShape[i] = inputShape.getDimSize (permShape.getDimSize (i));
1461
- }
1424
+ // Constant permutation values must be within the input rank.
1425
+ if (llvm::any_of (adaptor.getPerms (),
1426
+ [inputRank](const auto i) { return i >= inputRank; }))
1427
+ return failure ();
1428
+
1429
+ outputShape.reserve (inputRank);
1430
+ for (int i = 0 , s = inputRank; i < s; i++) {
1431
+ outputShape[i] = inputShape.getDimSize (adaptor.getPerms ()[i]);
1462
1432
}
1463
1433
1464
1434
inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
@@ -1467,75 +1437,60 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1467
1437
1468
1438
LogicalResult tosa::TransposeOp::verify () {
1469
1439
TensorType inputType = getInput1 ().getType ();
1470
- TensorType permType = getPerms ().getType ();
1471
1440
TensorType outputType = getOutput ().getType ();
1441
+ const llvm::ArrayRef<int32_t > constantPerms = getPerms ();
1472
1442
1473
- if (permType.hasRank () && permType.getRank () != 1 )
1474
- return emitOpError ()
1475
- << " expected permutation tensor to be rank 1 but got rank "
1476
- << permType.getRank ();
1477
- if (inputType.hasRank () && permType.hasRank ())
1478
- if (!permType.isDynamicDim (0 ) &&
1479
- permType.getDimSize (0 ) != inputType.getRank ())
1480
- return emitOpError () << " expected permutation tensor dim 0 to have size "
1481
- << inputType.getRank ()
1482
- << " (input rank) but got size "
1483
- << permType.getDimSize (0 );
1443
+ if (inputType.hasRank () &&
1444
+ constantPerms.size () != static_cast <size_t >(inputType.getRank ()))
1445
+ return emitOpError () << " expected perms attribute to have size "
1446
+ << inputType.getRank () << " (input rank) but got size "
1447
+ << constantPerms.size ();
1484
1448
if (inputType.hasRank () && outputType.hasRank () &&
1485
1449
inputType.getRank () != outputType.getRank ())
1486
1450
return emitOpError ()
1487
1451
<< " expected input tensor rank to equal result tensor rank" ;
1488
- if (outputType.hasRank () && permType.hasRank ())
1489
- if (!permType.isDynamicDim (0 ) &&
1490
- permType.getDimSize (0 ) != outputType.getRank ())
1491
- return emitOpError () << " expected permutation tensor dim 0 to have size "
1492
- << outputType.getRank ()
1493
- << " (output rank) but got size "
1494
- << permType.getDimSize (0 );
1495
-
1496
- SmallVector<int32_t > constantPerms;
1497
- if (succeeded (getConstantPerms (constantPerms))) {
1498
- // Assert that the permutation tensor has a rank, which means that the
1499
- // rank has been verified above.
1500
- assert (permType.hasRank () &&
1501
- " Unexpectedly found permutation tensor without rank" );
1502
- if (!llvm::all_of (constantPerms,
1503
- [&constantPerms](int32_t s) {
1504
- return s >= 0 &&
1505
- static_cast <size_t >(s) < constantPerms.size ();
1506
- }) ||
1507
- !isPermutationVector (llvm::to_vector (llvm::map_range (
1508
- constantPerms, [](int32_t v) -> int64_t { return v; }))))
1509
- return emitOpError () << " expected valid permutation tensor" ;
1510
-
1511
- // Verify that the types of the input and output tensors are properly
1512
- // permuted.
1513
- if (inputType.hasRank () && outputType.hasRank ()) {
1514
- assert (constantPerms.size () == static_cast <size_t >(inputType.getRank ()) &&
1515
- inputType.getRank () == outputType.getRank ());
1516
-
1517
- for (auto i = 0 ; i < outputType.getRank (); i++) {
1518
- if (inputType.isDynamicDim (constantPerms[i]) ||
1519
- outputType.isDynamicDim (i))
1520
- continue ;
1521
-
1522
- if (inputType.getDimSize (constantPerms[i]) != outputType.getDimSize (i))
1523
- return emitOpError ()
1524
- << " expected output tensor dim " << i << " to match "
1525
- << " input dim " << constantPerms[i] << " with value of "
1526
- << inputType.getDimSize (constantPerms[i]);
1527
- }
1452
+ if (outputType.hasRank () &&
1453
+ constantPerms.size () != static_cast <size_t >(outputType.getRank ()))
1454
+ return emitOpError () << " expected perms attribute to have size "
1455
+ << outputType.getRank ()
1456
+ << " (output rank) but got size "
1457
+ << constantPerms.size ();
1458
+
1459
+ if (!llvm::all_of (constantPerms,
1460
+ [&constantPerms](int32_t s) {
1461
+ return s >= 0 &&
1462
+ static_cast <size_t >(s) < constantPerms.size ();
1463
+ }) ||
1464
+ !isPermutationVector (llvm::to_vector (llvm::map_range (
1465
+ constantPerms, [](int32_t v) -> int64_t { return v; }))))
1466
+ return emitOpError () << " expected valid permutation indices" ;
1467
+
1468
+ // Verify that the types of the input and output tensors are properly
1469
+ // permuted.
1470
+ if (inputType.hasRank () && outputType.hasRank ()) {
1471
+ assert (constantPerms.size () == static_cast <size_t >(inputType.getRank ()) &&
1472
+ inputType.getRank () == outputType.getRank ());
1473
+
1474
+ for (auto i = 0 ; i < outputType.getRank (); i++) {
1475
+ if (inputType.isDynamicDim (constantPerms[i]) ||
1476
+ outputType.isDynamicDim (i))
1477
+ continue ;
1478
+
1479
+ if (inputType.getDimSize (constantPerms[i]) != outputType.getDimSize (i))
1480
+ return emitOpError ()
1481
+ << " expected output tensor dim " << i << " to match "
1482
+ << " input dim " << constantPerms[i] << " with value of "
1483
+ << inputType.getDimSize (constantPerms[i]);
1528
1484
}
1529
1485
}
1486
+
1530
1487
return success ();
1531
1488
}
1532
1489
1533
1490
LogicalResult TransposeOp::reifyResultShapes (
1534
1491
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1535
1492
1536
- SmallVector<int32_t > transposePerms;
1537
- if (getConstantPerms (transposePerms).failed ())
1538
- return failure ();
1493
+ const llvm::ArrayRef<int32_t > transposePerms = getPerms ();
1539
1494
1540
1495
Value input = getInput1 ();
1541
1496
auto inputType = cast<TensorType>(input.getType ());
0 commit comments