
Commit

fix build
Ubuntu committed Feb 3, 2023
1 parent 99ac780 commit 4b0afd8
Showing 3 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc
@@ -158,7 +158,7 @@ Status GatedRelativePositionBias<T>::ComputeInternal(OpKernelContext* context) c
  reinterpret_cast<const CudaT*>(query_tensor.template Data<T>()),
  reinterpret_cast<const CudaT*>(query_bias_tensor.template Data<T>()),
  reinterpret_cast<CudaT*>(workspace.get()),
- false, head_size, reinterpret_cast<CudaT*>(nullptr), total_maxtrix);
+ false, head_size, reinterpret_cast<CudaT*>(static_cast<CudaT*>(nullptr)), total_maxtrix);

  // reuse output if possible
  CudaT* gemm_output = (seq_len < D) ? (reinterpret_cast<CudaT*>(workspace.get()) + elements_in_query)
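The one-line change above routes the null pointer through static_cast before the reinterpret_cast. A plausible reading (the commit message only says "fix build") is that strict compilers reject reinterpret_cast applied directly to std::nullptr_t, since nullptr_t is not among the source types reinterpret_cast accepts, whereas a pointer-to-pointer reinterpret_cast of an already-typed null pointer is well-formed. A minimal standalone sketch, with an ordinary float* standing in for ORT's CudaT*:

  // Illustration only; float* stands in for CudaT*.
  int main() {
    // float* bad = reinterpret_cast<float*>(nullptr);  // ill-formed on strict compilers:
    //                                                   // std::nullptr_t is not a valid source
    //                                                   // type for reinterpret_cast
    float* ok = reinterpret_cast<float*>(static_cast<float*>(nullptr));  // pointer-to-pointer cast
    return ok == nullptr ? 0 : 1;
  }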
16 changes: 9 additions & 7 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -688,13 +688,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
  .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
    propagateElemTypeFromInputToOutput(ctx, 0, 0);
    int64_t num_heads = getAttribute(ctx, "num_heads", -1L);
-   auto& query_layer_shape = getInputShape(ctx, 0);
-   TensorShapeProto output_shape;
-   *output_shape.add_dim() = query_layer_shape.dim(0);
-   output_shape.add_dim()->set_dim_value(num_heads);
-   *output_shape.add_dim() = query_layer_shape.dim(1);
-   *output_shape.add_dim() = query_layer_shape.dim(1);
-   updateOutputShape(ctx, 0, output_shape);
+   if (hasInputShape(ctx, 0)) {
+     auto& query_layer_shape = getInputShape(ctx, 0);
+     TensorShapeProto output_shape;
+     *output_shape.add_dim() = query_layer_shape.dim(0);
+     output_shape.add_dim()->set_dim_value(num_heads);
+     *output_shape.add_dim() = query_layer_shape.dim(1);
+     *output_shape.add_dim() = query_layer_shape.dim(1);
+     updateOutputShape(ctx, 0, output_shape);
+   }
  }));

} // namespace contrib
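The bert_defs.cc change wraps the whole output-shape computation for GatedRelativePositionBias in a hasInputShape check, so shape inference only fills in the output shape when the query shape is statically known and otherwise leaves it dynamic; the element type is still propagated unconditionally. A minimal sketch of the same guard pattern, using the ONNX shape-inference helpers that already appear in the diff (the function name InferShapeGuarded is made up for illustration):

  #include "onnx/defs/shape_inference.h"

  // Propagate the element type unconditionally, but only write an output shape
  // when the input shape is available; otherwise leave the output shape dynamic.
  void InferShapeGuarded(ONNX_NAMESPACE::InferenceContext& ctx) {
    ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
    if (!ONNX_NAMESPACE::hasInputShape(ctx, 0)) {
      return;  // input shape unknown at this point in graph construction
    }
    const auto& input_shape = ONNX_NAMESPACE::getInputShape(ctx, 0);
    ONNX_NAMESPACE::TensorShapeProto output_shape;
    *output_shape.add_dim() = input_shape.dim(0);  // carry the batch dimension through
    ONNX_NAMESPACE::updateOutputShape(ctx, 0, output_shape);
  }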
@@ -231,7 +231,6 @@ static void RunGatedRelativePositionBiasTest(
  int min_cuda_architecture = use_float16 ? 530 : 0;

  bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
- bool enable_cpu = false;
  if (enable_cuda) {
    OpTester tester("GatedRelativePositionBias", 1, onnxruntime::kMSDomain);
    tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
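The only edit to the test helper is dropping a local flag that the remaining code does not appear to read; under a warnings-as-errors build, an unused local like that is enough to fail compilation, which fits the commit title. A hypothetical reproduction outside ORT's build (illustrative compiler invocation, not ORT's actual build line):

  // g++ -Wall -Werror -c unused.cc
  // yields a diagnostic along the lines of:
  //   error: unused variable 'enable_cpu' [-Werror=unused-variable]
  void RunTest() {
    bool enable_cpu = false;  // declared and initialized, but never read
  }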
