Skip to content

Commit

Permalink
Full codegen erf, erfc, erfinv, and exp (#3659)
Browse files Browse the repository at this point in the history
  • Loading branch information
wonjoolee95 authored Jun 23, 2022
1 parent 105f077 commit 05fa2aa
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 43 deletions.
21 changes: 0 additions & 21 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1322,27 +1322,6 @@ at::Tensor XLANativeFunctions::eq(const at::Tensor& self,
XLATensor::eq(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
}

// ATen entry point for erf: route the aten tensor through the XLA bridge,
// apply the lazy-tensor erf, and convert the result back to an at::Tensor.
at::Tensor XLANativeFunctions::erf(const at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
  auto xla_self = bridge::GetXlaTensor(self);
  auto result = XLATensor::erf(xla_self);
  return bridge::AtenFromXlaTensor(result);
}

// ATen entry point for erfc (complementary error function): bridge in,
// dispatch to the lazy-tensor implementation, bridge back out.
at::Tensor XLANativeFunctions::erfc(const at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
  auto xla_self = bridge::GetXlaTensor(self);
  auto result = XLATensor::erfc(xla_self);
  return bridge::AtenFromXlaTensor(result);
}

// ATen entry point for erfinv (inverse error function): bridge in,
// dispatch to the lazy-tensor implementation, bridge back out.
at::Tensor XLANativeFunctions::erfinv(const at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
  auto xla_self = bridge::GetXlaTensor(self);
  auto result = XLATensor::erfinv(xla_self);
  return bridge::AtenFromXlaTensor(result);
}

// ATen entry point for exp: bridge in, dispatch to the lazy-tensor
// implementation, bridge back out.
at::Tensor XLANativeFunctions::exp(const at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
  auto xla_self = bridge::GetXlaTensor(self);
  auto result = XLATensor::exp(xla_self);
  return bridge::AtenFromXlaTensor(result);
}

at::Tensor XLANativeFunctions::expand(const at::Tensor& self,
at::IntArrayRef size, bool implicit) {
XLA_FN_COUNTER("xla::");
Expand Down
20 changes: 20 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,26 @@ torch_xla::XlaOpVector Cosh::Lower(LoweringContext* loctx) const {
return ReturnOp(xla::Cosh(xla_input), loctx);
}

// Lower the Erf IR node: fetch the single operand's XLA op and wrap it in
// the XLA client's Erf builder call.
torch_xla::XlaOpVector Erf::Lower(LoweringContext* loctx) const {
  xla::XlaOp input = loctx->GetOutputOp(operand(0));
  xla::XlaOp output = xla::Erf(input);
  return ReturnOp(output, loctx);
}

// Lower the Erfc IR node: fetch the single operand's XLA op and wrap it in
// the XLA client's Erfc builder call.
torch_xla::XlaOpVector Erfc::Lower(LoweringContext* loctx) const {
  xla::XlaOp input = loctx->GetOutputOp(operand(0));
  xla::XlaOp output = xla::Erfc(input);
  return ReturnOp(output, loctx);
}

// Lower the Erfinv IR node: fetch the single operand's XLA op and wrap it in
// the XLA client's ErfInv builder call.
torch_xla::XlaOpVector Erfinv::Lower(LoweringContext* loctx) const {
  xla::XlaOp input = loctx->GetOutputOp(operand(0));
  xla::XlaOp output = xla::ErfInv(input);
  return ReturnOp(output, loctx);
}

// Lower the Exp IR node: fetch the single operand's XLA op and wrap it in
// the XLA client's Exp builder call.
torch_xla::XlaOpVector Exp::Lower(LoweringContext* loctx) const {
  xla::XlaOp input = loctx->GetOutputOp(operand(0));
  xla::XlaOp output = xla::Exp(input);
  return ReturnOp(output, loctx);
}

torch_xla::XlaOpVector Floor::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
return ReturnOp(xla::Floor(xla_input), loctx);
Expand Down
16 changes: 16 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,22 @@ xla::Shape CoshOutputShape(const torch::lazy::Value& input) {
return GetXlaShape(input);
}

// Elementwise unary op: the output shape is identical to the input's shape.
xla::Shape ErfOutputShape(const torch::lazy::Value& input) {
  xla::Shape shape = GetXlaShape(input);
  return shape;
}

// Elementwise unary op: the output shape is identical to the input's shape.
xla::Shape ErfcOutputShape(const torch::lazy::Value& input) {
  xla::Shape shape = GetXlaShape(input);
  return shape;
}

// Elementwise unary op: the output shape is identical to the input's shape.
xla::Shape ErfinvOutputShape(const torch::lazy::Value& input) {
  xla::Shape shape = GetXlaShape(input);
  return shape;
}

// Elementwise unary op: the output shape is identical to the input's shape.
xla::Shape ExpOutputShape(const torch::lazy::Value& input) {
  xla::Shape shape = GetXlaShape(input);
  return shape;
}

// Elementwise unary op: the output shape is identical to the input's shape.
xla::Shape FloorOutputShape(const torch::lazy::Value& input) {
  xla::Shape shape = GetXlaShape(input);
  return shape;
}
Expand Down
8 changes: 8 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ xla::Shape CosOutputShape(const torch::lazy::Value& input);

xla::Shape CoshOutputShape(const torch::lazy::Value& input);

// Shape-inference helpers for the full-codegen elementwise unary ops.
// Each returns a shape equal to its input's shape (see the matching
// definitions in ops_xla_shape_fn.cpp, which forward to GetXlaShape).
xla::Shape ErfOutputShape(const torch::lazy::Value& input);

xla::Shape ErfcOutputShape(const torch::lazy::Value& input);

xla::Shape ErfinvOutputShape(const torch::lazy::Value& input);

xla::Shape ExpOutputShape(const torch::lazy::Value& input);

xla::Shape FloorOutputShape(const torch::lazy::Value& input);

xla::Shape InverseOutputShape(const torch::lazy::Value& input);
Expand Down
6 changes: 0 additions & 6 deletions torch_xla/csrc/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -555,12 +555,6 @@ class XLATensor : public c10::intrusive_ptr_target {

static XLATensor eq(const XLATensor& input, const XLATensor& other);

// Elementwise unary ops; each builds the corresponding IR node from the
// input's IR value and returns a new XLATensor created from it (see
// tensor_methods.cpp).
static XLATensor erf(const XLATensor& input);

static XLATensor erfc(const XLATensor& input);

static XLATensor erfinv(const XLATensor& input);

static XLATensor exp(const XLATensor& input);

static XLATensor expand(const XLATensor& input, std::vector<int64_t> size);
Expand Down
12 changes: 0 additions & 12 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1271,18 +1271,6 @@ XLATensor XLATensor::embedding_dense_backward(const XLATensor& grad_output,
padding_idx, scale_grad_by_freq);
}

// Build an Erf IR node over this tensor's IR value and wrap it in a new
// XLATensor derived from the input.
XLATensor XLATensor::erf(const XLATensor& input) {
  auto node = Erf(input.GetIrValue());
  return input.CreateFrom(node);
}

// Build an Erfc IR node over this tensor's IR value and wrap it in a new
// XLATensor derived from the input.
XLATensor XLATensor::erfc(const XLATensor& input) {
  auto node = Erfc(input.GetIrValue());
  return input.CreateFrom(node);
}

// Build an Erfinv IR node over this tensor's IR value and wrap it in a new
// XLATensor derived from the input.
XLATensor XLATensor::erfinv(const XLATensor& input) {
  auto node = Erfinv(input.GetIrValue());
  return input.CreateFrom(node);
}

// Build an Exp IR node over this tensor's IR value and wrap it in a new
// XLATensor derived from the input.
XLATensor XLATensor::exp(const XLATensor& input) {
  auto node = Exp(input.GetIrValue());
  return input.CreateFrom(node);
}
Expand Down
8 changes: 4 additions & 4 deletions xla_native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ full_codegen:
- atanh
- cos
- cosh
- erf
- erfc
- erfinv
- exp
- floor
- inverse
- logdet
Expand Down Expand Up @@ -121,10 +125,6 @@ supported:
- empty_strided
- eq.Scalar
- eq.Tensor
- erf
- erfc
- erfinv
- exp
- expand
- expm1
- exponential_
Expand Down

0 comments on commit 05fa2aa

Please sign in to comment.