Skip to content

Commit 834d29e

Browse files
authored
uniform_: improve error handling and error messages. (#9635)
This PR refactors the `uniform_` operation implementation by improving its error message, and returning a status type value. **Key Changes:** - Make `tensor_methods::uniform_` return `Status` - Improve error messages and error handling - Add `CheckUniformRangeIsValid` function
1 parent 647f5a7 commit 834d29e

File tree

4 files changed

+32
-8
lines changed

4 files changed

+32
-8
lines changed

test/test_ops_error_message.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,20 @@ def test():
317317
expect="""baddbmm(): cannot apply batch matrix-multiplication to `batch1` f32[2,3,8], the 2nd input tensor, and to `batch2` f32[2,4,3], the 3rd input tensor. Expected the size of dimension 2 of `batch1` (8) to be equal the size of dimension 1 of `batch2` (4)."""
318318
)
319319

320+
def test_uniform__raises_error_on_invalid_range(self):
321+
device = torch_xla.device()
322+
a = torch.empty(5, 5, device=device)
323+
from_ = 5.
324+
to_ = 2.
325+
326+
def test():
327+
return a.uniform_(from_, to_)
328+
329+
self.assertExpectedRaisesInline(
330+
exc_type=RuntimeError,
331+
callable=test,
332+
expect="""uniform_(): expected `from` (5) <= `to` (2).""")
333+
320334

321335
if __name__ == "__main__":
322336
unittest.main()

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3936,8 +3936,9 @@ at::Tensor& XLANativeFunctions::uniform_(
39363936
return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(uniform_)>::call(
39373937
self, from, to, generator);
39383938
}
3939-
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
3940-
tensor_methods::uniform_(xla_self, from, to);
3939+
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
3940+
bridge::GetXlaTensor(self));
3941+
XLA_THROW_IF_ERROR(tensor_methods::uniform_(xla_self, from, to));
39413942
return self;
39423943
}
39433944

torch_xla/csrc/tensor_methods.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,14 @@ absl::Status CheckBmmInputsAreValid(const std::string_view op,
594594
return absl::OkStatus();
595595
}
596596

597+
absl::Status CheckUniformRangeIsValid(double from, double to) {
598+
if (from > to) {
599+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
600+
"uniform_(): expected `from` (", from, ") <= `to` (", to, ").")));
601+
}
602+
return absl::OkStatus();
603+
}
604+
597605
} // namespace
598606

599607
//////////////////////////////////////////////////////////////////////////////
@@ -3759,15 +3767,16 @@ std::vector<XLATensorPtr> unbind(const XLATensorPtr& input, int64_t dim) {
37593767
return slices;
37603768
}
37613769

3762-
void uniform_(XLATensorPtr& input, double from, double to) {
3763-
XLA_CHECK_LE(from, to);
3764-
auto input_shape = input->shape();
3770+
absl::Status uniform_(XLATensorPtr& input, double from, double to) {
3771+
XLA_RETURN_IF_ERROR(CheckUniformRangeIsValid(from, to));
3772+
xla::Shape input_shape = input->shape();
37653773
input->SetInPlaceIrValue(torch_xla::MakeNode<Uniform>(
37663774
XLAGraphExecutor::Get()->GetIrValueForScalar(
3767-
from, input_shape.get().element_type(), input->GetDevice()),
3775+
from, input_shape.element_type(), input->GetDevice()),
37683776
XLAGraphExecutor::Get()->GetIrValueForScalar(
3769-
to, input_shape.get().element_type(), input->GetDevice()),
3777+
to, input_shape.element_type(), input->GetDevice()),
37703778
XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice()), input_shape));
3779+
return absl::OkStatus();
37713780
}
37723781

37733782
XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim) {

torch_xla/csrc/tensor_methods.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -989,7 +989,7 @@ std::tuple<XLATensorPtr, XLATensorPtr> triangular_solve(
989989
// removed.
990990
std::vector<XLATensorPtr> unbind(const XLATensorPtr& input, int64_t dim);
991991

992-
void uniform_(XLATensorPtr& input, double from, double to);
992+
absl::Status uniform_(XLATensorPtr& input, double from, double to);
993993

994994
// Insert a dimension of size one at the specified position.
995995
XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim);

0 commit comments

Comments
 (0)