Skip to content

Commit

Permalink
Set unbiased to False as Onnx need
Browse files Browse the repository at this point in the history
  • Loading branch information
vivekkhandelwal1 committed Aug 9, 2024
1 parent c7ef086 commit 6ed8a18
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
3 changes: 1 addition & 2 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,14 +412,13 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
dimsToReduce);
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
Value currentMean = rewriter.create<Torch::AtenMeanDimOp>(
loc, meanResultType, input, reduceDimsList,
/*keepdim=*/cstFalse,
/*dtype=*/noneVal);
Value currentVar = rewriter.create<Torch::AtenVarDimOp>(
loc, varResultType, input, reduceDimsList,
/*unbiased=*/cstTrue,
/*unbiased=*/cstFalse,
/*keepdim=*/cstFalse);

// Computing running_mean.
Expand Down
3 changes: 1 addition & 2 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1275,9 +1275,8 @@ func.func @test_batchnorm_training(%arg0: !torch.vtensor<[1,16,27],f32>, %arg1:
// CHECK: %[[CST2:.*]] = torch.constant.int 2
// CHECK: %[[REDUCE_DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[CURRENT_MEAN:.*]] = torch.aten.mean.dim %arg0, %[[REDUCE_DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,16,27],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[16],f32>
// CHECK: %[[CURRENT_VAR:.*]] = torch.aten.var.dim %arg0, %[[REDUCE_DIMS]], %[[TRUE]], %[[FALSE]] : !torch.vtensor<[1,16,27],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[16],f32>
// CHECK: %[[CURRENT_VAR:.*]] = torch.aten.var.dim %arg0, %[[REDUCE_DIMS]], %[[FALSE]], %[[FALSE]] : !torch.vtensor<[1,16,27],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[16],f32>
// CHECK: %[[MEAN_MUL_MOMENTUM:.*]] = torch.aten.mul.Scalar %arg3, %[[MOMENTUM]] : !torch.vtensor<[16],f32>, !torch.float -> !torch.vtensor<[16],f32>
// CHECK: %[[CURR_MEAN_MUL_MOMENTUM:.*]] = torch.aten.mul.Scalar %[[CURRENT_MEAN]], %[[MOMENTUM]] : !torch.vtensor<[16],f32>, !torch.float -> !torch.vtensor<[16],f32>
// CHECK: %[[CST1:.*]] = torch.constant.int 1
Expand Down

0 comments on commit 6ed8a18

Please sign in to comment.