Commit fade1af
Use std::string_view
cyyever committed Jul 17, 2024
1 parent 860cdc8 commit fade1af
Showing 2 changed files with 11 additions and 11 deletions.
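
The change is mechanical: each c10::string_view in these ATen signatures becomes std::string_view, matching the commit title. A minimal standalone sketch (C++17 assumed; Gelu is a hypothetical stand-in, not torch_xla's API) of why call sites need no edits, since string literals and std::string both bind to the new parameter type:

#include <iostream>
#include <string>
#include <string_view>

// Stand-in for a changed signature such as XLANativeFunctions::gelu.
void Gelu(std::string_view approximate) {  // was: c10::string_view
  std::cout << "approximate=" << approximate << '\n';
}

int main() {
  Gelu("none");                // a string literal binds to std::string_view
  std::string mode = "tanh";
  Gelu(mode);                  // std::string converts implicitly as well
  return 0;
}
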
torch_xla/csrc/aten_autograd_ops.cpp (1 addition, 1 deletion)

@@ -22,7 +22,7 @@ bool IsNonTrivialDilation(at::IntArrayRef dilation) {
 namespace aten_autograd_ops {
 
 torch::Tensor EinsumAutogradFunction::forward(
-    torch::autograd::AutogradContext* ctx, const c10::string_view equation,
+    torch::autograd::AutogradContext* ctx, const std::string_view equation,
     at::TensorList tensors) {
   std::string eq_str = std::string(equation);
   ctx->saved_data["equation"] = eq_str;
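
In the forward() above, the incoming view is copied into an owning std::string before being stashed in ctx->saved_data. A hedged sketch of the lifetime concern behind that copy (a std::map stands in for the autograd context; names are illustrative, not torch_xla's API):

#include <map>
#include <string>
#include <string_view>

std::map<std::string, std::string> saved_data;  // stand-in for ctx->saved_data

void Forward(std::string_view equation) {
  // A string_view does not own its characters and may dangle once the
  // caller's buffer goes away; copying into an owning std::string keeps
  // the equation alive until the backward pass needs it.
  saved_data["equation"] = std::string(equation);
}

int main() {
  Forward("ij,jk->ik");
  return saved_data["equation"] == "ij,jk->ik" ? 0 : 1;
}
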
torch_xla/csrc/aten_xla_type.cpp (10 additions, 10 deletions)

@@ -1376,7 +1376,7 @@ at::Tensor XLANativeFunctions::div(const at::Tensor& self,
 
 at::Tensor XLANativeFunctions::div(
     const at::Tensor& self, const at::Tensor& other,
-    std::optional<c10::string_view> rounding_mode) {
+    std::optional<std::string_view> rounding_mode) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   at::ScalarType dtype = at::result_type(self, other);
   auto operands = GetBinaryOperands(self, UnwrapNumber(other, dtype));

@@ -1414,7 +1414,7 @@ at::Tensor XLANativeFunctions::dot(const at::Tensor& self,
       bridge::GetXlaTensor(self), bridge::GetXlaTensor(tensor)));
 }
 
-at::Tensor XLANativeFunctions::einsum(c10::string_view equation,
+at::Tensor XLANativeFunctions::einsum(std::string_view equation,
                                       at::TensorList tensors,
                                       at::OptionalIntArrayRef path) {
   std::string cleansed_equation = std::string(equation);

@@ -1709,15 +1709,15 @@ at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim,
 }
 
 at::Tensor XLANativeFunctions::gelu(const at::Tensor& self,
-                                    c10::string_view approximate) {
+                                    std::string_view approximate) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   return bridge::AtenFromXlaTensor(
       tensor_methods::gelu(bridge::GetXlaTensor(self), approximate));
 }
 
 at::Tensor XLANativeFunctions::gelu_backward(const at::Tensor& grad,
                                              const at::Tensor& self,
-                                             c10::string_view approximate) {
+                                             std::string_view approximate) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   at::ScalarType result_type = at::result_type(grad, self);
   return bridge::AtenFromXlaTensor(tensor_methods::gelu_backward(

@@ -3074,7 +3074,7 @@ at::Tensor XLANativeFunctions::rsub(const at::Tensor& self,
 
 at::Tensor scatter_reduce_helper(const at::Tensor& self, int64_t dim,
                                  const at::Tensor& index, const at::Tensor& src,
-                                 std::optional<c10::string_view> reduce) {
+                                 std::optional<std::string_view> reduce) {
   XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
   if (!reduce.has_value()) {
     return bridge::AtenFromXlaTensor(

@@ -3095,7 +3095,7 @@ at::Tensor scatter_reduce_helper(const at::Tensor& self, int64_t dim,
 at::Tensor scatter_reduce_helper(const at::Tensor& self, int64_t dim,
                                  const at::Tensor& index,
                                  const at::Scalar& value,
-                                 std::optional<c10::string_view> reduce) {
+                                 std::optional<std::string_view> reduce) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
   if (!reduce.has_value()) {

@@ -3129,15 +3129,15 @@ at::Tensor XLANativeFunctions::scatter(const at::Tensor& self, int64_t dim,
 at::Tensor XLANativeFunctions::scatter(const at::Tensor& self, int64_t dim,
                                        const at::Tensor& index,
                                        const at::Tensor& src,
-                                       c10::string_view reduce) {
+                                       std::string_view reduce) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   return scatter_reduce_helper(self, dim, index, src, reduce);
 }
 
 at::Tensor XLANativeFunctions::scatter(const at::Tensor& self, int64_t dim,
                                        const at::Tensor& index,
                                        const at::Scalar& value,
-                                       c10::string_view reduce) {
+                                       std::string_view reduce) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   return scatter_reduce_helper(self, dim, index, value, reduce);
 }

@@ -3153,7 +3153,7 @@ at::Tensor XLANativeFunctions::scatter_add(const at::Tensor& self, int64_t dim,
 // supported
 at::Tensor XLANativeFunctions::scatter_reduce(
     const at::Tensor& self, int64_t dim, const at::Tensor& index,
-    const at::Tensor& src, c10::string_view reduce, bool include_self) {
+    const at::Tensor& src, std::string_view reduce, bool include_self) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if ((reduce == "sum" || reduce == "prod" || reduce == "amin" ||
        reduce == "amax") &&

@@ -3782,7 +3782,7 @@ at::Tensor& XLANativeFunctions::zero_(at::Tensor& self) {
 
 std::tuple<at::Tensor, at::Tensor, at::Tensor> XLANativeFunctions::_linalg_svd(
     const at::Tensor& self, bool full_matrices, bool compute_uv,
-    std::optional<c10::string_view> /* driver */) {
+    std::optional<std::string_view> /* driver */) {
   // The optional driver string is only for CUDA with a cuSOLVER backend.
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   // As per https://pytorch.org/docs/stable/generated/torch.svd.html,
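
Several of the changed signatures wrap the view in std::optional, as in scatter_reduce_helper and _linalg_svd above. A sketch of dispatching on an optional reduce mode (hypothetical ScatterHelper, not torch_xla code; the branch structure mirrors the !reduce.has_value() check in the diff):

#include <iostream>
#include <optional>
#include <string_view>

void ScatterHelper(std::optional<std::string_view> reduce) {
  if (!reduce.has_value()) {
    std::cout << "plain scatter\n";  // no reduction requested
  } else if (*reduce == "add" || *reduce == "sum") {
    std::cout << "scatter with reduction: " << *reduce << '\n';
  } else {
    std::cout << "unsupported reduce mode\n";
  }
}

int main() {
  ScatterHelper(std::nullopt);  // mirrors the !reduce.has_value() branch
  ScatterHelper("add");         // const char* converts to the optional view
  return 0;
}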
