Skip to content

Commit

Permalink
Fix erfinv and swapaxes (#7217)
Browse files Browse the repository at this point in the history
* Fix erfinv and swapaxes

* Fix

* Fix bug and add test

* Modify name

* Fix arg

* Modify pi

* Fix

Co-authored-by: ZZK <42901638+MARD1NO@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 12, 2022
1 parent 2c49940 commit 7065d35
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 14 deletions.
3 changes: 0 additions & 3 deletions oneflow/core/functional/impl/math_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -648,9 +648,6 @@ class SwapaxesFunctor {
<< "], but got " << dim_1 << ")";
return Transpose2dim(x, dim0, dim1);
}

private:
std::shared_ptr<OpExpr> op_;
};

class ArangeFunctor {
Expand Down
23 changes: 12 additions & 11 deletions oneflow/user/kernels/erfinv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include <math.h>
#include <cmath>
namespace oneflow {

template<typename T>
Expand All @@ -31,6 +31,7 @@ class CpuErfinvKernel final : public user_op::OpKernel {
const T* x_ptr = x->dptr<T>();
T* y_ptr = y->mut_dptr<T>();
constexpr float central_range = 0.7;
const T temp = static_cast<T>(2.0) / static_cast<T>(std::sqrt(M_PI));
T a[4] = {T(0.886226899), T(-1.645349621), T(0.914624893), T(-0.140543331)};
T b[4] = {T(-2.118377725), T(1.442710462), T(-0.329097515), T(0.012229801)};
T c[4] = {T(-1.970840454), T(-1.624906493), T(3.429567803), T(1.641345311)};
Expand All @@ -39,8 +40,14 @@ class CpuErfinvKernel final : public user_op::OpKernel {
T z, num, dem;
T x = x_ptr[i]; // Promise the correctness of inplace version.
T x_abs = std::abs(x);
if (x_abs > 1.0) y_ptr[i] = std::numeric_limits<T>::quiet_NaN();
if (x_abs == 1.0) y_ptr[i] = std::copysign(std::numeric_limits<T>::infinity(), x);
if (x_abs > 1.0) {
y_ptr[i] = std::numeric_limits<T>::quiet_NaN();
continue;
}
if (x_abs == 1.0) {
y_ptr[i] = std::copysign(std::numeric_limits<T>::infinity(), x);
continue;
}
if (x_abs <= static_cast<T>(central_range)) {
z = x * x;
num = (((a[3] * z + a[2]) * z + a[1]) * z + a[0]);
Expand All @@ -52,14 +59,8 @@ class CpuErfinvKernel final : public user_op::OpKernel {
dem = (d[1] * z + d[0]) * z + static_cast<T>(1.0);
y_ptr[i] = std::copysign(num, x) / dem;
}
y_ptr[i] = y_ptr[i]
- (std::erf(y_ptr[i]) - x)
/ ((static_cast<T>(2.0) / static_cast<T>(std::sqrt(M_PI)))
* std::exp(-y_ptr[i] * y_ptr[i]));
y_ptr[i] = y_ptr[i]
- (std::erf(y_ptr[i]) - x)
/ ((static_cast<T>(2.0) / static_cast<T>(std::sqrt(M_PI)))
* std::exp(-y_ptr[i] * y_ptr[i]));
y_ptr[i] = y_ptr[i] - (std::erf(y_ptr[i]) - x) / (temp * std::exp(-y_ptr[i] * y_ptr[i]));
y_ptr[i] = y_ptr[i] - (std::erf(y_ptr[i]) - x) / (temp * std::exp(-y_ptr[i] * y_ptr[i]));
}
}

Expand Down
26 changes: 26 additions & 0 deletions python/oneflow/test/modules/test_erfinv.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,34 @@
from oneflow.test_utils.automated_test_util import *


def _test_flow_erfinv_with_inf_data(test_case, device):
    """Check that ``flow.erfinv`` of an all-ones tensor is +inf everywhere.

    Args:
        test_case: object whose ``assertTrue`` is used to report the result.
        device: device name forwarded to ``flow.device``.
    """
    all_ones = np.ones((5, 5))
    input_tensor = flow.tensor(
        all_ones, dtype=flow.float32, device=flow.device(device)
    )
    result = flow.erfinv(input_tensor)
    # Every input element is exactly 1, so every output element must be +inf.
    expected = np.full((5, 5), np.inf)
    is_equal = np.array_equal(result.numpy(), expected)
    test_case.assertTrue(is_equal)


def _test_flow_erfinv_with_nan_data(test_case, device):
    """Check that ``flow.erfinv`` yields NaN for the inputs 2..21.

    Args:
        test_case: object whose ``assertTrue`` is used to report the result.
        device: device name forwarded to ``flow.device``.
    """
    # Integers 2..21 laid out as a 4x5 grid; all have magnitude above 1.
    out_of_range = np.arange(2, 22).reshape(4, 5)
    input_tensor = flow.tensor(
        out_of_range, dtype=flow.float32, device=flow.device(device)
    )
    result = flow.erfinv(input_tensor)
    expected = np.full((4, 5), np.nan)
    # equal_nan=True is required because NaN never compares equal to itself.
    is_equal = np.array_equal(result.numpy(), expected, equal_nan=True)
    test_case.assertTrue(is_equal)


@flow.unittest.skip_unless_1n1d()
class TestErfinvModule(flow.unittest.TestCase):
def test_flow_erfinv(test_case):
    """Run each erfinv edge-case helper with every listed device.

    NOTE(review): ``GenArgList`` presumably expands ``arg_dict`` into one
    argument list per (helper, device) combination -- confirm against its
    definition.
    """
    arg_dict = OrderedDict(
        [
            (
                "test_fun",
                [
                    _test_flow_erfinv_with_inf_data,
                    _test_flow_erfinv_with_nan_data,
                ],
            ),
            ("device", ["cpu", "cuda"]),
        ]
    )
    # First element is the helper to call; the rest are its extra arguments.
    for test_fun, *extra_args in GenArgList(arg_dict):
        test_fun(test_case, *extra_args)

@autotest(check_graph=True, auto_backward=False)
def test_flow_erfinv_with_random_data(test_case):
device = random_device()
Expand Down

0 comments on commit 7065d35

Please sign in to comment.