diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 6e76c3538e71d..66b401c731e34 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -705,12 +705,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
   this->PrintIndent();
   int32_t constant_size = op->constant_allocation_size();
   ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
-  const VarNode* buffer = op->buffer_var.as<VarNode>();
-  auto it = alloc_storage_scope_.find(buffer);
-  ICHECK(it != alloc_storage_scope_.end())
-      << "Buffer " << op->buffer_var << " is missing an AttrStmt with a \"storage_scope\" key";
-
-  std::string scope = it->second;
+  std::string scope = GetStorageScope(op->buffer_var);
   if (scope.find("wmma.") == 0) {
     if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
       ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
@@ -724,6 +719,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
              op->dtype == DataType::Int(32))
           << "Accumulator only support half, float and int type for now";
     }
+    const VarNode* buffer = op->buffer_var.as<VarNode>();
     constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
     PrintWmmaScope(scope, op->dtype, buffer, stream);
   } else {
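Note: the diff replaces the hand-rolled lookup in alloc_storage_scope_ with a call to GetStorageScope(op->buffer_var), and narrows the raw VarNode pointer to the wmma branch, which is the only remaining use. A minimal sketch of what the helper is assumed to do, reconstructed from the deleted lines (the enclosing class, exact signature, and error message are assumptions, not taken from this diff):

// Presumably declared on a shared codegen base class; sketch only.
std::string GetStorageScope(const Var& buffer_var) const {
  // Same lookup the deleted lines performed inline: the scope is recorded
  // by an AttrStmt with a "storage_scope" key before the Allocate is visited.
  auto it = alloc_storage_scope_.find(buffer_var.as<VarNode>());
  ICHECK(it != alloc_storage_scope_.end())
      << "Buffer " << buffer_var << " is missing an AttrStmt with a \"storage_scope\" key";
  return it->second;
}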