Use std::string_view
cyyever committed Jun 16, 2024
1 parent 3f22daa commit eee3b2f
Showing 10 changed files with 28 additions and 28 deletions.
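For context (not part of the original commit message): in recent PyTorch releases c10::string_view is an alias for std::string_view, so this commit is a mechanical rename with no behavioral change; every touched signature keeps accepting the same arguments the call sites already pass. A minimal sketch of that call pattern, with illustrative names that are not from this repository:

```cpp
#include <optional>
#include <string>
#include <string_view>

// Minimal sketch (illustrative names, not from this repo): a signature that
// switched from c10::string_view to std::string_view still accepts string
// literals, std::string values, and std::nullopt for the optional variants.
std::string DescribeRoundingMode(std::optional<std::string_view> rounding_mode) {
  if (!rounding_mode.has_value()) {
    return "true division (no rounding)";
  }
  if (*rounding_mode == "trunc") {
    return "truncate toward zero";
  }
  if (*rounding_mode == "floor") {
    return "round toward negative infinity";
  }
  return "unrecognized rounding mode";
}

int main() {
  DescribeRoundingMode("trunc");               // from a string literal
  DescribeRoundingMode(std::string("floor"));  // from an owning std::string
  DescribeRoundingMode(std::nullopt);          // optional left empty
  return 0;
}
```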
4 changes: 2 additions & 2 deletions test/cpp/test_aten_xla_tensor_4.cpp
@@ -391,7 +391,7 @@ TEST_F(AtenXlaTensorTest, TestDiv) {
}

TEST_F(AtenXlaTensorTest, TestDivWithRoundingMode) {
-std::optional<c10::string_view> rounding_modes[] = {"trunc", "floor",
+std::optional<std::string_view> rounding_modes[] = {"trunc", "floor",
std::nullopt};
for (const auto& rounding_mode : rounding_modes) {
for (torch::ScalarType scalar_type1 :
@@ -453,7 +453,7 @@ TEST_F(AtenXlaTensorTest, TestDivInPlace) {
}

TEST_F(AtenXlaTensorTest, TestDivInPlaceWithRoundingMode) {
-std::optional<c10::string_view> rounding_modes[] = {"trunc", "floor",
+std::optional<std::string_view> rounding_modes[] = {"trunc", "floor",
std::nullopt};
for (const auto& rounding_mode : rounding_modes) {
for (torch::ScalarType scalar_type1 : {torch::kFloat}) {
2 changes: 1 addition & 1 deletion torch_xla/csrc/aten_autograd_ops.cpp
@@ -22,7 +22,7 @@ bool IsNonTrivialDilation(at::IntArrayRef dilation) {
namespace aten_autograd_ops {

torch::Tensor EinsumAutogradFunction::forward(
-torch::autograd::AutogradContext* ctx, const c10::string_view equation,
+torch::autograd::AutogradContext* ctx, const std::string_view equation,
at::TensorList tensors) {
std::string eq_str = std::string(equation);
ctx->saved_data["equation"] = eq_str;
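One detail worth noting in the hunk above: forward() still copies the incoming equation into an owning std::string before stashing it in ctx->saved_data, because a std::string_view only borrows the caller's buffer. A simplified sketch of that pattern; FakeContext stands in for the real autograd context and is not from the repository:

```cpp
#include <string>
#include <string_view>
#include <unordered_map>

// Stand-in for torch::autograd::AutogradContext, for illustration only.
struct FakeContext {
  std::unordered_map<std::string, std::string> saved_data;
};

void SaveEquation(FakeContext* ctx, std::string_view equation) {
  // Copy into an owning std::string: the view's underlying buffer may no
  // longer exist by the time backward() reads saved_data.
  ctx->saved_data["equation"] = std::string(equation);
}
```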
4 changes: 2 additions & 2 deletions torch_xla/csrc/aten_autograd_ops.h
@@ -13,7 +13,7 @@ namespace aten_autograd_ops {
struct EinsumAutogradFunction
: public torch::autograd::Function<EinsumAutogradFunction> {
static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
-c10::string_view equation,
+std::string_view equation,
at::TensorList tensors);
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
@@ -60,4 +60,4 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
} // namespace aten_autograd_ops
} // namespace torch_xla

-#endif // XLA_TORCH_XLA_CSRC_ATEN_AUTOGRAD_OPS_H_
+#endif // XLA_TORCH_XLA_CSRC_ATEN_AUTOGRAD_OPS_H_
20 changes: 10 additions & 10 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1363,7 +1363,7 @@ at::Tensor XLANativeFunctions::div(const at::Tensor& self,

at::Tensor XLANativeFunctions::div(
const at::Tensor& self, const at::Tensor& other,
-std::optional<c10::string_view> rounding_mode) {
+std::optional<std::string_view> rounding_mode) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
at::ScalarType dtype = at::result_type(self, other);
auto operands = GetBinaryOperands(self, UnwrapNumber(other, dtype));
@@ -1401,7 +1401,7 @@ at::Tensor XLANativeFunctions::dot(const at::Tensor& self,
bridge::GetXlaTensor(self), bridge::GetXlaTensor(tensor)));
}

-at::Tensor XLANativeFunctions::einsum(c10::string_view equation,
+at::Tensor XLANativeFunctions::einsum(std::string_view equation,
at::TensorList tensors,
at::OptionalIntArrayRef path) {
std::string cleansed_equation = std::string(equation);
@@ -1660,15 +1660,15 @@ at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim,
}

at::Tensor XLANativeFunctions::gelu(const at::Tensor& self,
-c10::string_view approximate) {
+std::string_view approximate) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
return bridge::AtenFromXlaTensor(
tensor_methods::gelu(bridge::GetXlaTensor(self), approximate));
}

at::Tensor XLANativeFunctions::gelu_backward(const at::Tensor& grad,
const at::Tensor& self,
-c10::string_view approximate) {
+std::string_view approximate) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
at::ScalarType result_type = at::result_type(grad, self);
return bridge::AtenFromXlaTensor(tensor_methods::gelu_backward(
@@ -3031,7 +3031,7 @@ at::Tensor XLANativeFunctions::rsub(const at::Tensor& self,

at::Tensor scatter_reduce_helper(const at::Tensor& self, int64_t dim,
const at::Tensor& index, const at::Tensor& src,
-std::optional<c10::string_view> reduce) {
+std::optional<std::string_view> reduce) {
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
if (!reduce.has_value()) {
return bridge::AtenFromXlaTensor(
@@ -3052,7 +3052,7 @@ at::Tensor scatter_reduce_helper(const at::Tensor& self, int64_t dim,
at::Tensor scatter_reduce_helper(const at::Tensor& self, int64_t dim,
const at::Tensor& index,
const at::Scalar& value,
-std::optional<c10::string_view> reduce) {
+std::optional<std::string_view> reduce) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
if (!reduce.has_value()) {
@@ -3087,15 +3087,15 @@ at::Tensor XLANativeFunctions::scatter(const at::Tensor& self, int64_t dim,
at::Tensor XLANativeFunctions::scatter(const at::Tensor& self, int64_t dim,
const at::Tensor& index,
const at::Tensor& src,
-c10::string_view reduce) {
+std::string_view reduce) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
return scatter_reduce_helper(self, dim, index, src, reduce);
}

at::Tensor XLANativeFunctions::scatter(const at::Tensor& self, int64_t dim,
const at::Tensor& index,
const at::Scalar& value,
-c10::string_view reduce) {
+std::string_view reduce) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
return scatter_reduce_helper(self, dim, index, value, reduce);
}
@@ -3111,7 +3111,7 @@ at::Tensor XLANativeFunctions::scatter_add(const at::Tensor& self, int64_t dim,
// supported
at::Tensor XLANativeFunctions::scatter_reduce(
const at::Tensor& self, int64_t dim, const at::Tensor& index,
-const at::Tensor& src, c10::string_view reduce, bool include_self) {
+const at::Tensor& src, std::string_view reduce, bool include_self) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if ((reduce == "sum" || reduce == "prod" || reduce == "amin" ||
reduce == "amax") &&
@@ -3741,7 +3741,7 @@ at::Tensor& XLANativeFunctions::zero_(at::Tensor& self) {

std::tuple<at::Tensor, at::Tensor, at::Tensor> XLANativeFunctions::_linalg_svd(
const at::Tensor& self, bool full_matrices, bool compute_uv,
-std::optional<c10::string_view> /* driver */) {
+std::optional<std::string_view> /* driver */) {
// The optional driver string is only for CUDA with a cuSOLVER backend.
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
// As per https://pytorch.org/docs/stable/generated/torch.svd.html,
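The scatter overloads above funnel into scatter_reduce_helper, which takes the reduce keyword as std::optional<std::string_view>: an empty optional means a plain scatter, while a present value selects a named reduction (the scatter_reduce path checks for "sum", "prod", "amin", and "amax"). A standalone sketch of that dispatch shape, assuming nothing about the real tensor code:

```cpp
#include <iostream>
#include <optional>
#include <string_view>

// Illustrative dispatch only; the real helper routes into XLA tensor methods.
void DispatchScatter(std::optional<std::string_view> reduce) {
  if (!reduce.has_value()) {
    std::cout << "plain scatter\n";
    return;
  }
  // Keywords mirror the ones checked in the diff above.
  if (*reduce == "sum" || *reduce == "prod" || *reduce == "amin" ||
      *reduce == "amax") {
    std::cout << "reduction scatter: " << *reduce << "\n";
  } else {
    std::cout << "unsupported reduce keyword: " << *reduce << "\n";
  }
}

int main() {
  DispatchScatter(std::nullopt);  // no keyword: plain scatter
  DispatchScatter("sum");         // named reduction
  return 0;
}
```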
2 changes: 1 addition & 1 deletion torch_xla/csrc/cross_replica_reduces.cpp
@@ -441,7 +441,7 @@ void SetAllReduceToken(const torch::lazy::BackendDevice& device,
g_all_reduce_tokens[device.ordinal()] = token;
}

-AllReduceType GetReduceType(c10::string_view reduce_type) {
+AllReduceType GetReduceType(std::string_view reduce_type) {
if (reduce_type == "sum") {
return AllReduceType::kSum;
} else if (reduce_type == "mul") {
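GetReduceType is a small keyword-to-enum mapping, and std::string_view suits it well: comparisons against string literals need no allocation. A self-contained sketch of the same shape (only kSum and kMul are taken from the visible diff; the real enum has more members):

```cpp
#include <stdexcept>
#include <string>
#include <string_view>

// Sketch only: the real AllReduceType enum is larger than shown here.
enum class AllReduceType { kSum, kMul };

AllReduceType GetReduceTypeSketch(std::string_view reduce_type) {
  if (reduce_type == "sum") {
    return AllReduceType::kSum;
  }
  if (reduce_type == "mul") {
    return AllReduceType::kMul;
  }
  // std::string_view has no implicit conversion to std::string, so build the
  // error message explicitly.
  throw std::invalid_argument("unknown reduce type: " +
                              std::string(reduce_type));
}
```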
2 changes: 1 addition & 1 deletion torch_xla/csrc/cross_replica_reduces.h
@@ -115,7 +115,7 @@ const torch::lazy::Value& GetAllReduceToken(
void SetAllReduceToken(const torch::lazy::BackendDevice& device,
const std::shared_ptr<torch::lazy::Value>& token);

-AllReduceType GetReduceType(c10::string_view reduce_type);
+AllReduceType GetReduceType(std::string_view reduce_type);

} // namespace torch_xla

2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/scatter_reduce.cpp
@@ -10,7 +10,7 @@ namespace torch_xla {
ScatterReduce::ScatterReduce(const torch::lazy::Value& input,
const torch::lazy::Value& index,
const torch::lazy::Value& src,
-c10::string_view reduce, bool include_self,
+std::string_view reduce, bool include_self,
int64_t dim)
: XlaNode(torch::lazy::OpKind(at::aten::scatter_reduce),
{input, index, src}, GetXlaShape(input),
4 changes: 2 additions & 2 deletions torch_xla/csrc/ops/scatter_reduce.h
@@ -9,7 +9,7 @@ class ScatterReduce : public XlaNode {
public:
ScatterReduce(const torch::lazy::Value& input,
const torch::lazy::Value& index, const torch::lazy::Value& src,
-c10::string_view reduce, bool include_self, int64_t dim);
+std::string_view reduce, bool include_self, int64_t dim);

std::string ToString() const override;

@@ -27,4 +27,4 @@ class ScatterReduce : public XlaNode {

} // namespace torch_xla

-#endif // XLA_TORCH_XLA_CSRC_OPS_SCATTER_REDUCE_H_
+#endif // XLA_TORCH_XLA_CSRC_OPS_SCATTER_REDUCE_H_
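A constructor that accepts std::string_view but needs to keep the value must copy it into owning storage. A hypothetical member layout showing that pattern; reduce_, include_self_, and dim_ are illustrative and not the node's actual fields:

```cpp
#include <cstdint>
#include <string>
#include <string_view>

// Hypothetical sketch of a node-like class that receives a borrowed
// std::string_view and stores an owning std::string copy.
class ScatterReduceSketch {
 public:
  ScatterReduceSketch(std::string_view reduce, bool include_self, int64_t dim)
      : reduce_(reduce), include_self_(include_self), dim_(dim) {}

  const std::string& reduce() const { return reduce_; }
  bool include_self() const { return include_self_; }
  int64_t dim() const { return dim_; }

 private:
  std::string reduce_;  // owning copy; safe after the caller's buffer is gone
  bool include_self_;
  int64_t dim_;
};
```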
8 changes: 4 additions & 4 deletions torch_xla/csrc/tensor_methods.cpp
@@ -1250,7 +1250,7 @@ XLATensorPtr diagonal(const XLATensorPtr& input, int64_t offset, int64_t dim1,
}

XLATensorPtr div(const XLATensorPtr& input, const XLATensorPtr& other,
-const std::optional<c10::string_view>& rounding_mode,
+const std::optional<std::string_view>& rounding_mode,
std::optional<at::ScalarType> logical_element_type) {
at::ScalarType scalar_type =
at::typeMetaToScalarType(c10::get_default_dtype());
@@ -1548,7 +1548,7 @@ XLATensorPtr ge(const XLATensorPtr& input, const XLATensorPtr& other) {
}

XLATensorPtr gelu(const XLATensorPtr& input,
-const c10::string_view approximate) {
+const std::string_view approximate) {
if (approximate == "none") {
return input->CreateFrom(Gelu(input->GetIrValue()));
} else if (approximate == "tanh") {
@@ -1559,7 +1559,7 @@ XLATensorPtr gelu(const XLATensorPtr& input,
}

XLATensorPtr gelu_backward(const XLATensorPtr& grad, const XLATensorPtr& input,
-const c10::string_view approximate) {
+const std::string_view approximate) {
if (approximate == "none") {
return input->CreateFrom(
GeluBackward(grad->GetIrValue(), input->GetIrValue()));
@@ -2718,7 +2718,7 @@ XLATensorPtr scatter_add(const XLATensorPtr& input, int64_t dim,

XLATensorPtr scatter_reduce(const XLATensorPtr& input, int64_t dim,
const XLATensorPtr& index, const XLATensorPtr& src,
-c10::string_view reduce, bool include_self) {
+std::string_view reduce, bool include_self) {
return input->CreateFrom(torch::lazy::MakeNode<ScatterReduce>(
input->GetIrValue(), index->GetIrValue(), src->GetIrValue(), reduce,
include_self,
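The gelu and gelu_backward methods above branch on the approximate keyword, accepting "none" or "tanh". A scalar sketch of that dispatch follows; the real code emits XLA IR rather than computing values, and the formulas are the standard exact and tanh-approximate GELU definitions:

```cpp
#include <cmath>
#include <stdexcept>
#include <string_view>

// Scalar illustration of the "approximate" dispatch; not the XLA lowering.
double GeluScalar(double x, std::string_view approximate) {
  if (approximate == "none") {
    // Exact GELU: x * Phi(x), with Phi the standard normal CDF.
    return x * 0.5 * (1.0 + std::erf(x / std::sqrt(2.0)));
  }
  if (approximate == "tanh") {
    // Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3))).
    const double kSqrt2OverPi = 0.7978845608028654;
    return 0.5 * x *
           (1.0 + std::tanh(kSqrt2OverPi * (x + 0.044715 * x * x * x)));
  }
  throw std::invalid_argument("unsupported approximate mode");
}
```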
8 changes: 4 additions & 4 deletions torch_xla/csrc/tensor_methods.h
@@ -377,7 +377,7 @@ XLATensorPtr diagonal(const XLATensorPtr& input, int64_t offset, int64_t dim1,

XLATensorPtr div(
const XLATensorPtr& input, const XLATensorPtr& other,
-const std::optional<c10::string_view>& rounding_mode = std::nullopt,
+const std::optional<std::string_view>& rounding_mode = std::nullopt,
std::optional<at::ScalarType> logical_element_type = std::nullopt);
XLATensorPtr div(const XLATensorPtr& input, const at::Scalar& other);

@@ -459,10 +459,10 @@ XLATensorPtr ge(const XLATensorPtr& input, const at::Scalar& other);
XLATensorPtr ge(const XLATensorPtr& input, const XLATensorPtr& other);

XLATensorPtr gelu(const XLATensorPtr& input,
-const c10::string_view approximate);
+const std::string_view approximate);

XLATensorPtr gelu_backward(const XLATensorPtr& grad, const XLATensorPtr& input,
-const c10::string_view approximate);
+const std::string_view approximate);

XLATensorPtr gt(const XLATensorPtr& input, const at::Scalar& other);

@@ -842,7 +842,7 @@ XLATensorPtr scatter_add(const XLATensorPtr& input, int64_t dim,

XLATensorPtr scatter_reduce(const XLATensorPtr& input, int64_t dim,
const XLATensorPtr& index, const XLATensorPtr& src,
-c10::string_view reduce, bool include_self);
+std::string_view reduce, bool include_self);

XLATensorPtr select(const XLATensorPtr& input, int64_t dim, int64_t index);

