Skip to content

Commit

Permalink
Misc updates
Browse files Browse the repository at this point in the history
  • Loading branch information
garth-wells committed Feb 9, 2025
1 parent 92c54d2 commit 0b86553
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 33 deletions.
18 changes: 9 additions & 9 deletions cpp/basix/finite-element.h
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,9 @@ class FiniteElement
/// value_size)`. The function tabulate_shape() can be used to get the
/// required shape.
/// - The first index is the derivative, with higher derivatives are
/// stored in triangular (2D) or tetrahedral (3D) ordering, ie for
/// the (x,y) derivatives in 2D: (0,0), (1,0), (0,1), (2,0), (1,1),
/// (0,2), (3,0)... The function indexing::idx can be used to
/// stored in triangular (2D) or tetrahedral (3D) ordering, i.e. for
/// the (x, y) derivatives in 2D: (0, 0), (1, 0), (0, 1), (2, 0), (1,
/// 1), (0, 2), (3, 0), ... The function indexing::idx can be used to
/// find the appropriate derivative.
/// - The second index is the point index
/// - The third index is the basis function index
Expand Down Expand Up @@ -472,9 +472,9 @@ class FiniteElement
/// value_size)`. The function tabulate_shape() can be used to get the
/// required shape.
/// - The first index is the derivative, with higher derivatives are
/// stored in triangular (2D) or tetrahedral (3D) ordering, ie for the
/// (x,y) derivatives in 2D: (0,0), (1,0), (0,1), (2,0), (1,1), (0,2),
/// (3,0)... The function indexing::idx can be used to find the
/// stored in triangular (2D) or tetrahedral (3D) ordering, i.e. for
/// the (x,y) derivatives in 2D: (0,0), (1,0), (0,1), (2,0), (1,1),
/// (0,2), (3,0)... The function indexing::idx can be used to find the
/// appropriate derivative.
/// - The second index is the point index
/// - The third index is the basis function index
Expand Down Expand Up @@ -615,10 +615,10 @@ class FiniteElement
/// flattened with row-major layout, shape=(num_points, ref
/// value_size)
/// - `u` [in] The data on the physical cell that should be pulled
/// back , flattened with row-major layout, shape=(num_points,
/// back, flattened with row-major layout, shape=(num_points,
/// value_size)
/// - `K` [in] The inverse of the Jacobian matrix of the map
/// ,shape=(tdim, gdim)
/// - `K` [in] The inverse of the Jacobian matrix of the map,
/// shape=(tdim, gdim)
/// - `detJ_inv` [in] 1/det(J)
/// - `J` [in] The Jacobian matrix, shape=(gdim, tdim)
template <typename O, typename P, typename Q, typename R>
Expand Down
22 changes: 22 additions & 0 deletions cpp/basix/indexing.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

#pragma once

#include <array>
#include <concepts>
#include <initializer_list>

/// @brief Indexing.
namespace basix::indexing
{
Expand Down Expand Up @@ -32,4 +36,22 @@ constexpr int idx(int p, int q, int r)
return (p + q + r) * (p + q + r + 1) * (p + q + r + 2) / 6
+ (q + r) * (q + r + 1) / 2 + r;
}

// @brief Compute indexing in a 3D tetrahedral array compressed into a
// 1D array.
// @param p Index in x.
// @param q Index in y.
// @param r Index in z.
// @return 1D Index.
template <std::integral T, std::size_t N>
constexpr int idx(std::array<T, N> p)
{
static_assert(p.size() > 0 and p.size() <= 3);
if constexpr (p.size() == 1)
return idx(p[0]);
else if constexpr (p.size() == 2)
return idx(p[0], p[1]);
else
return idx(p[0], p[1], p[2]);
}
} // namespace basix::indexing
26 changes: 10 additions & 16 deletions cpp/basix/maps.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,15 @@ template <typename O, typename P, typename Q, typename R>
void double_covariant_piola(O&& r, const P& U, const Q& J, double /*detJ*/,
const R& K)
{
namespace stdex
= MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
namespace md = MDSPAN_IMPL_STANDARD_NAMESPACE;
using T = typename std::decay_t<O>::value_type;
using Z = typename impl::scalar_value_type_t<T>;
for (std::size_t p = 0; p < U.extent(0); ++p)
{
MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
_U(U.data_handle() + p * U.extent(1), J.extent(1), J.extent(1));
MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
_r(r.data_handle() + p * r.extent(1), K.extent(1), K.extent(1));
md::mdspan<const T, md::dextents<std::size_t, 2>> _U(
U.data_handle() + p * U.extent(1), J.extent(1), J.extent(1));
md::mdspan<T, md::dextents<std::size_t, 2>> _r(
r.data_handle() + p * r.extent(1), K.extent(1), K.extent(1));
// _r = K^T _U K
for (std::size_t i = 0; i < _r.extent(0); ++i)
{
Expand All @@ -135,18 +132,15 @@ template <typename O, typename P, typename Q, typename R>
void double_contravariant_piola(O&& r, const P& U, const Q& J, double detJ,
const R& /*K*/)
{
namespace stdex
= MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
namespace md = MDSPAN_IMPL_STANDARD_NAMESPACE;
using T = typename std::decay_t<O>::value_type;
using Z = typename impl::scalar_value_type_t<T>;
for (std::size_t p = 0; p < U.extent(0); ++p)
{
MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
_U(U.data_handle() + p * U.extent(1), J.extent(1), J.extent(1));
MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
_r(r.data_handle() + p * r.extent(1), J.extent(0), J.extent(0));
md::mdspan<const T, md::dextents<std::size_t, 2>> _U(
U.data_handle() + p * U.extent(1), J.extent(1), J.extent(1));
md::mdspan<T, md::dextents<std::size_t, 2>> _r(
r.data_handle() + p * r.extent(1), J.extent(0), J.extent(0));

// _r = J U J^T
for (std::size_t i = 0; i < _r.extent(0); ++i)
Expand Down
17 changes: 9 additions & 8 deletions python/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -512,14 +512,12 @@ NB_MODULE(_basixcpp, m)
{ return cell::volume<double>(cell_type); });
m.def("cell_facet_normals", [](cell::type cell_type)
{ return as_nbarrayp(cell::facet_normals<double>(cell_type)); });
m.def("cell_facet_reference_volumes",
[](cell::type cell_type) {
return as_nbarray(cell::facet_reference_volumes<double>(cell_type));
});
m.def("cell_facet_outward_normals",
[](cell::type cell_type) {
return as_nbarrayp(cell::facet_outward_normals<double>(cell_type));
});
m.def(
"cell_facet_reference_volumes", [](cell::type cell_type)
{ return as_nbarray(cell::facet_reference_volumes<double>(cell_type)); });
m.def(
"cell_facet_outward_normals", [](cell::type cell_type)
{ return as_nbarrayp(cell::facet_outward_normals<double>(cell_type)); });
m.def("cell_facet_orientations",
[](cell::type cell_type)
{
Expand Down Expand Up @@ -684,6 +682,9 @@ NB_MODULE(_basixcpp, m)
m.def("index", nb::overload_cast<int>(&basix::indexing::idx));
m.def("index", nb::overload_cast<int, int>(&basix::indexing::idx));
m.def("index", nb::overload_cast<int, int, int>(&basix::indexing::idx));
// m.def("index", [](nb::ndarray<const int, nb::ndim<1>, nb::c_contig>) {
// if
// });

declare_float<float>(m, "float32");
declare_float<double>(m, "float64");
Expand Down
3 changes: 3 additions & 0 deletions test/test_cmake/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// SPDX-License-Identifier: MIT

#include <basix/finite-element.h>
#include <basix/indexing.h>
#include <memory>

int main()
Expand All @@ -17,5 +18,7 @@ int main()
basix::element::lagrange_variant::equispaced,
basix::element::dpc_variant::unset, false);

basix::indexing::idxn({1, 2, 1});

return 0;
}

0 comments on commit 0b86553

Please sign in to comment.