Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions python/pyabacus/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@ set(BASE_PATH "${PROJECT_SOURCE_DIR}/../../source/module_base")
set(ABACUS_SOURCE_DIR "${PROJECT_SOURCE_DIR}/../../source")
include_directories(${BASE_PATH} ${ABACUS_SOURCE_DIR})
list(APPEND _sources
${ABACUS_SOURCE_DIR}/module_basis/module_nao/numerical_radial.h
${ABACUS_SOURCE_DIR}/module_basis/module_nao/numerical_radial.cpp
${PROJECT_SOURCE_DIR}/src/py_numerical_radial.cpp)
#${ABACUS_SOURCE_DIR}/module_basis/module_nao/numerical_radial.h
#${ABACUS_SOURCE_DIR}/module_basis/module_nao/numerical_radial.cpp
${ABACUS_SOURCE_DIR}/module_base/constants.h
${ABACUS_SOURCE_DIR}/module_base/math_sphbes.h
${ABACUS_SOURCE_DIR}/module_base/math_sphbes.cpp
${PROJECT_SOURCE_DIR}/src/py_abacus.cpp
#${PROJECT_SOURCE_DIR}/src/py_numerical_radial.cpp
${PROJECT_SOURCE_DIR}/src/py_math_base.cpp)
python_add_library(_core MODULE ${_sources} WITH_SOABI)
target_link_libraries(_core PRIVATE pybind11::headers)
target_compile_definitions(_core PRIVATE VERSION_INFO=${PROJECT_VERSION})
Expand Down
13 changes: 13 additions & 0 deletions python/pyabacus/src/py_abacus.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

namespace py = pybind11;

void bind_numerical_radial(py::module& m);
void bind_math_base(py::module& m);

PYBIND11_MODULE(_core, m)
{
// bind_numerical_radial(m);
bind_math_base(m);
}
63 changes: 63 additions & 0 deletions python/pyabacus/src/py_math_base.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

#include "module_base/math_sphbes.h"

namespace py = pybind11;
using namespace pybind11::literals;
template <typename... Args>
using overload_cast_ = pybind11::detail::overload_cast_impl<Args...>;

void bind_math_base(py::module& m)
{
py::module module_base = m.def_submodule("ModuleBase");

py::class_<ModuleBase::Sphbes>(module_base, "Sphbes")
.def(py::init<>())
.def_static("sphbesj", overload_cast_<const int, const double>()(&ModuleBase::Sphbes::sphbesj), "l"_a, "x"_a)
.def_static("dsphbesj", overload_cast_<const int, const double>()(&ModuleBase::Sphbes::dsphbesj), "l"_a, "x"_a)
.def_static("sphbesj",
[](const int n, py::array_t<double> r, const double q, const int l, py::array_t<double> jl) {
py::buffer_info r_info = r.request();
if (r_info.ndim != 1)
{
throw std::runtime_error("r array must be 1-dimensional");
}
py::buffer_info jl_info = jl.request();
if (jl_info.ndim != 1)
{
throw std::runtime_error("jl array must be 1-dimensional");
}
ModuleBase::Sphbes::sphbesj(n,
static_cast<const double* const>(r_info.ptr),
q,
l,
static_cast<double* const>(jl_info.ptr));
})
.def_static("dsphbesj",
[](const int n, py::array_t<double> r, const double q, const int l, py::array_t<double> djl) {
py::buffer_info r_info = r.request();
if (r_info.ndim != 1)
{
throw std::runtime_error("r array must be 1-dimensional");
}
py::buffer_info djl_info = djl.request();
if (djl_info.ndim != 1)
{
throw std::runtime_error("djl array must be 1-dimensional");
}
ModuleBase::Sphbes::dsphbesj(n,
static_cast<const double* const>(r_info.ptr),
q,
l,
static_cast<double* const>(djl_info.ptr));
})
.def_static("sphbes_zeros", [](const int l, const int n, py::array_t<double> zeros) {
py::buffer_info zeros_info = zeros.request();
if (zeros_info.ndim != 1)
{
throw std::runtime_error("zeros array must be 1-dimensional");
}
ModuleBase::Sphbes::sphbes_zeros(l, n, static_cast<double* const>(zeros_info.ptr));
});
}
4 changes: 2 additions & 2 deletions python/pyabacus/src/py_numerical_radial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using namespace pybind11::literals;
template <typename... Args>
using overload_cast_ = pybind11::detail::overload_cast_impl<Args...>;

PYBIND11_MODULE(_core, m)
void bind_numerical_radial(py::module& m)
{
// Create the submodule for NumericalRadial
py::module m_numerical_radial = m.def_submodule("NumericalRadial");
Expand Down Expand Up @@ -165,4 +165,4 @@ PYBIND11_MODULE(_core, m)
.def_property_readonly("kgrid", overload_cast_<int>()(&NumericalRadial::kgrid, py::const_))
.def_property_readonly("rvalue", overload_cast_<int>()(&NumericalRadial::rvalue, py::const_))
.def_property_readonly("kvalue", overload_cast_<int>()(&NumericalRadial::kvalue, py::const_));
}
}
5 changes: 3 additions & 2 deletions python/pyabacus/src/pyabacus/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
from ._core import __doc__, __version__, NumericalRadial
__all__ = ["__doc__", "__version__", "NumericalRadial"]
# from ._core import __doc__, __version__, NumericalRadial, ModuleBase
from ._core import ModuleBase
__all__ = ["ModuleBase"]
15 changes: 15 additions & 0 deletions python/pyabacus/tests/test_base_math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from __future__ import annotations

import pyabacus as m
import numpy as np


def test_version():
assert m.__version__ == "0.0.1"

def test_sphbes():
s = m.ModuleBase.Sphbes()
# test for sphbesj
assert s.sphbesj(1, 0.0) == 0.0
assert s.sphbesj(0, 0.0) == 1.0

25 changes: 0 additions & 25 deletions python/pyabacus/tests/test_nr.py

This file was deleted.

7 changes: 1 addition & 6 deletions source/module_base/math_sphbes.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "math_sphbes.h"
#include "timer.h"
#include "constants.h"
#include <algorithm>
#include <iostream>

#include <cassert>

Expand Down Expand Up @@ -425,7 +425,6 @@ void Sphbes::Spherical_Bessel
double *jl // jl(1:msh) = j_l(q*r(i)),spherical bessel function
)
{
ModuleBase::timer::tick("Sphbes","Spherical_Bessel");
double x1=0.0;

int i=0;
Expand Down Expand Up @@ -598,7 +597,6 @@ void Sphbes::Spherical_Bessel
}
}

ModuleBase::timer::tick("Sphbes","Spherical_Bessel");
return;
}

Expand All @@ -613,7 +611,6 @@ void Sphbes::Spherical_Bessel
double *sjp
)
{
ModuleBase::timer::tick("Sphbes","Spherical_Bessel");

//calculate jlx first
Spherical_Bessel (msh, r, q, l, sj);
Expand All @@ -634,7 +631,6 @@ void Sphbes::dSpherical_Bessel_dx
double *djl // jl(1:msh) = j_l(q*r(i)),spherical bessel function
)
{
ModuleBase::timer::tick("Sphbes","dSpherical_Bessel_dq");
if (l < 0 )
{
std::cout << "We temporarily only calculate derivative of l >= 0." << std::endl;
Expand Down Expand Up @@ -682,7 +678,6 @@ void Sphbes::dSpherical_Bessel_dx
}
delete[] jl;
}
ModuleBase::timer::tick("Sphbes","dSpherical_Bessel_dq");
return;
}

Expand Down