TypeWithAttrTypeConverter doesn't handle block args from loops #1344

Open
asraa opened this issue Jan 30, 2025 · 0 comments

@asraa (Collaborator) commented Jan 30, 2025

If a loop has loop-carried variables that are secret, its block arguments (the loop's iter_args) must also be converted to ciphertext types. The secret-to-scheme passes rely on TypeWithAttrTypeConverter, but this code doesn't handle that case:

```cpp
Attribute TypeWithAttrTypeConverter::getValueAttr(Value value) const {
```
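To make the missing case concrete, here is a toy C++ model of an attribute lookup like getValueAttr. The `Op`, `ToyValue`, and `getValueLevel` names are hypothetical stand-ins, not MLIR's `Value`/`BlockArgument` API, and the strategy for loop-carried arguments (use the loop op's own attribute, as the reproducer below places `mgmt.mgmt` on `affine.for`, else fall back to the matching iter_args init operand) is an assumption for illustration, not the actual fix.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Toy model only: hypothetical stand-ins, not MLIR's Value/Operation classes.
struct Op;

struct ToyValue {
  enum class Kind { OpResult, FuncArg, LoopIterArg };
  Kind kind;
  const Op *owner;  // defining op (OpResult) or op whose region owns the arg
  unsigned index;   // result index, function arg index, or iter_args index
};

struct Op {
  std::string name;
  std::optional<int> level;                   // op-level mgmt attribute
  std::vector<std::optional<int>> argLevels;  // per-arg attrs (func.func only)
  std::vector<ToyValue> initOperands;         // iter_args inits (loops only)
};

// Hypothetical analogue of getValueAttr: finds the level that decides which
// ciphertext type a secret value should convert to.
std::optional<int> getValueLevel(const ToyValue &v) {
  switch (v.kind) {
  case ToyValue::Kind::OpResult:
    return v.owner->level;
  case ToyValue::Kind::FuncArg:
    if (v.index < v.owner->argLevels.size()) return v.owner->argLevels[v.index];
    return std::nullopt;
  case ToyValue::Kind::LoopIterArg:
    // The case from this issue: a loop-carried block argument has no defining
    // op. Assumed strategy: use the loop's own attribute, else the level of
    // the matching init operand. (The toy index skips the induction variable.)
    if (v.owner->level) return v.owner->level;
    if (v.index < v.owner->initOperands.size())
      return getValueLevel(v.owner->initOperands[v.index]);
    return std::nullopt;
  }
  return std::nullopt;
}
```

In this toy model, deleting the `LoopIterArg` case leaves an iter_arg such as `%arg2` in the reproducer with no level, so no ciphertext type could be chosen for it.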

In addition, affine.for loops are converted in secret-to-scheme by a generic ConvertAny pattern that is meant to convert operands and regions. However, that code relies only on the base TypeConverter interface, so it doesn't actually convert secret types (see SecretToCKKSTypeConverter's conversion callbacks):

```cpp
LogicalResult convertAnyOperand(const TypeConverter *typeConverter,
                                Operation *op, ArrayRef<Value> operands,
                                ConversionPatternRewriter &rewriter) {
  // Nothing to do if the op is already legal under this converter.
  if (typeConverter->isLegal(op)) {
    return failure();
  }
  // Operand and result types are converted through the base TypeConverter,
  // which sees only types, never the values' attributes.
  SmallVector<Type> newOperandTypes;
  if (failed(
          typeConverter->convertTypes(op->getOperandTypes(), newOperandTypes)))
    return failure();
  SmallVector<Type> newResultTypes;
  if (failed(typeConverter->convertTypes(op->getResultTypes(), newResultTypes)))
    return failure();
  // Clone each region and convert its block argument types (for a loop, this
  // includes the iter_args), again via the base TypeConverter.
  SmallVector<std::unique_ptr<Region>, 1> regions;
  IRMapping mapping;
  for (auto &r : op->getRegions()) {
    Region *newRegion = new Region(op);
    rewriter.cloneRegionBefore(r, *newRegion, newRegion->end(), mapping);
    if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
      return failure();
    regions.emplace_back(newRegion);
  }
  // Recreate the op generically with the adapted operands, converted result
  // types, and cloned regions, then replace the original op.
  Operation *newOp = rewriter.create(OperationState(
      op->getLoc(), op->getName().getStringRef(), operands, newResultTypes,
      op->getAttrs(), op->getSuccessors(), regions));
  rewriter.replaceOp(op, newOp);
  return success();
}
```
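A toy C++ sketch of why a type-only conversion path can leave secret types untouched. The string-based `ToyType` and the `convertType` function are hypothetical; the sketch assumes that the converter falls back to the identity when no attribute is available, and that the attribute-aware path uses the mgmt level to pick a ciphertext type.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical toy type: just a printable name, not an MLIR Type.
using ToyType = std::string;

// `level` stands in for the attribute TypeWithAttrTypeConverter looks up on a
// value. The type-only path (base TypeConverter::convertTypes, as used by
// convertAnyOperand) effectively calls this with no level at all.
ToyType convertType(const ToyType &t, std::optional<int> level) {
  bool isSecret = t.rfind("secret<", 0) == 0;
  // Assumed identity fallback: non-secret types, and secret types with no
  // attribute, come back unchanged.
  if (!isSecret || !level) return t;
  return "ciphertext<level=" + std::to_string(*level) + ">";
}
```

Under these assumptions, a loop's iter_args would still carry a `secret<...>` type after region conversion, which matches the symptom described above.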

Here's a reproducer:

```mlir
module {
  func.func @test(%arg0: !secret.secret<tensor<1x1024xf32>> {mgmt.mgmt = #mgmt.mgmt<level = 1>}) -> !secret.secret<tensor<1x1024xf32>> attributes {llvm.emit_c_interface} {
    %c1 = arith.constant 1 : index
    %2 = affine.for %arg1 = 1 to 1024 iter_args(%arg2 = %arg0) -> (!secret.secret<tensor<1x1024xf32>>) {
      %4 = secret.generic ins(%arg2 : !secret.secret<tensor<1x1024xf32>>) attrs = {mgmt.mgmt = #mgmt.mgmt<level = 1>} {
      ^body(%input0: tensor<1x1024xf32>):
        %7 = tensor_ext.rotate %input0, %c1 : tensor<1x1024xf32>, index
        secret.yield %7 : tensor<1x1024xf32>
      } -> !secret.secret<tensor<1x1024xf32>>
      affine.yield %4 : !secret.secret<tensor<1x1024xf32>>
    } {mgmt.mgmt = #mgmt.mgmt<level = 1>}
    %3 = secret.generic ins(%2 : !secret.secret<tensor<1x1024xf32>>) attrs = {mgmt.mgmt = #mgmt.mgmt<level = 0>} {
    ^body(%input0: tensor<1x1024xf32>):
      %4 = mgmt.modreduce %input0 : tensor<1x1024xf32>
      secret.yield %4 : tensor<1x1024xf32>
    } -> !secret.secret<tensor<1x1024xf32>>
    return %3 : !secret.secret<tensor<1x1024xf32>>
  }
}
```

Fix incoming.

asraa self-assigned this Jan 30, 2025