Skip to content

Commit

Permalink
Add constitution.jax.Hyperelastic (#876)
Browse files Browse the repository at this point in the history
* Add `constitution.jax.Hyperelastic`

* Update test_constitution_jax.py

* Update test_constitution_jax.py

* Update test_constitution_jax.py

* Update _hyperelastic.py
  • Loading branch information
adtzlr authored Oct 31, 2024
1 parent ffe4c44 commit ac460a9
Show file tree
Hide file tree
Showing 9 changed files with 441 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ All notable changes to this project will be documented in this file. The format
- Add optional keyword-arguments to `math.transpose(**kwargs)` to support optional `out` and `order`-keywords.
- Add the attribute `RegionBoundary.tangents`, which contains a list of tangent unit vectors. For `quad` cell-types the length of this list is one and for `hexahedron` cell-types it is of length two.
- Add `math.inplane(A, vectors)` to return the in-plane components of a symmetric tensor `A`, where the plane is defined by its standard unit vectors.
- Add `constitution.jax.Hyperelastic` as a feature-equivalent alternative to `Hyperelastic` with `jax` as backend.

### Changed
- Change default `np.einsum(..., order="K")` to `np.einsum(..., order="C")` in the methods of `Field`, `FieldAxisymmetric`, `FieldPlaneStrain` and `FieldContainer`.
Expand Down
1 change: 1 addition & 0 deletions docs/felupe/constitution.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ This module provides :class:`constitutive material <felupe.ConstitutiveMaterial>
constitution/hyperelasticity
constitution/lagrange
constitution/tools
constitution/jax

There are many different pre-defined constitutive material formulations available, including definitions for linear-elasticity, small-strain plasticity, hyperelasticity or pseudo-elasticity. The generation of user materials may be simplified when using frameworks for user-defined functions, like hyperelasticity (with automatic differentiation) or a small-strain based framework with state variables. However, the most general case is given by a framework with functions for the evaluation of stress and elasticity tensors in terms of the deformation gradient.

Expand Down
21 changes: 21 additions & 0 deletions docs/felupe/constitution/jax.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
.. _felupe-api-constitution-jax:

JAX-based Materials
~~~~~~~~~~~~~~~~~~~

This page contains hyperelastic material model formulations with automatic differentiation using :mod:`jax`. These material model formulations are defined by a strain energy density function.

**Frameworks**

.. currentmodule:: felupe

.. autosummary::

constitution.jax.Hyperelastic

**Detailed API Reference**

.. autoclass:: felupe.constitution.jax.Hyperelastic
:members:
:undoc-members:
:inherited-members:
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ docs = [
"sphinx-gallery",
"pypandoc",
]
constitution = ["jax"]
examples = [
"contique",
"imageio",
Expand All @@ -81,8 +82,8 @@ progress = ["tqdm"]
plot = ["matplotlib"]
view = ["pyvista[jupyter]"]

test = ["felupe[io,plot]"]
all = ["felupe[io,parallel,plot,progress,view]"]
test = ["felupe[io,constitution,plot]"]
all = ["felupe[io,constitution,parallel,plot,progress,view]"]

[tool.setuptools.dynamic]
version = {attr = "felupe.__about__.__version__"}
Expand Down
2 changes: 2 additions & 0 deletions src/felupe/constitution/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import jax
from ._base import CompositeMaterial, ConstitutiveMaterial, constitutive_material
from ._kinematics import AreaChange, LineChange, VolumeChange
from ._material import Material
Expand Down Expand Up @@ -97,4 +98,5 @@
"constitutive_material",
"CompositeMaterial",
"Volumetric",
"jax",
]
3 changes: 3 additions & 0 deletions src/felupe/constitution/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ._hyperelastic import Hyperelastic, vmap

__all__ = ["Hyperelastic", "vmap"]
282 changes: 282 additions & 0 deletions src/felupe/constitution/jax/_hyperelastic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
# -*- coding: utf-8 -*-
"""
This file is part of FElupe.
FElupe is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
FElupe is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with FElupe. If not, see <http://www.gnu.org/licenses/>.
"""

import inspect
import warnings
from functools import wraps

import numpy as np

from .._material import Material


def vmap(fun, in_axes=0, out_axes=0, **kwargs):
"""Vectorizing map. Creates a function which maps ``fun`` over argument axes. This
decorator treats all non-specified arguments and keyword-arguments as static.
See Also
--------
jax.vmap : Vectorizing map. Creates a function which maps ``fun`` over argument
axes.
"""

import jax

@wraps(fun)
def vmap_with_static_kwargs(*args, **keywordargs):
# sorted list of all parameter keys, including kwargs with default values
sig = inspect.signature(fun)
keys = [
key
for key, value in sig.parameters.items()
if not (key in ["args", "kwargs"] and value.default == inspect._empty)
]

if not (
"kwargs" in sig.parameters.keys()
and sig.parameters["kwargs"].default == inspect._empty
):
# check if unexpected keyword-argument is given
for key in keywordargs.keys():
if key not in keys:
raise TypeError(
f"{fun.__name__}() got an unexpected keyword argument '{key}'"
)

# dict with default values for all parameters
parameters = dict(
[(key, value.default) for key, value in sig.parameters.items()]
)

# merge dict of default values with custom keyword arguments
items = {**parameters, **keywordargs}

# create sorted list of values of keyword-arguments, including default kwargs
keyword_args = [items[key] for key in keys[len(args) :]]

# don't map non-given arguments and keyword-arguments
if not hasattr(in_axes, "__len__"):
in_axes_tuple = (in_axes,)
else:
in_axes_tuple = in_axes

static_argnums = len(args) + len(keyword_args) - len(in_axes_tuple)
in_axes_new = (*in_axes_tuple, *([None] * static_argnums))

vfun = jax.vmap(fun, in_axes=in_axes_new, out_axes=out_axes, **kwargs)

return vfun(*args, *keyword_args)

return vmap_with_static_kwargs


def vmap2(fun, in_axes=0, out_axes=0, **kwargs):
"Nested vectorizing map."
return vmap(
vmap(fun, in_axes=in_axes, out_axes=out_axes, **kwargs),
in_axes=in_axes,
out_axes=out_axes,
**kwargs,
)


def total_lagrange(fun):
@wraps(fun)
def evaluate(F, *args, **kwargs):
C = F.T @ F
return fun(C, *args, **kwargs)

return evaluate


class Hyperelastic(Material):
r"""A hyperelastic material definition with a given function for the strain energy
density function per unit undeformed volume with Automatic Differentiation provided
by :mod:`jax`.
Parameters
----------
fun : callable
A strain energy density function in terms of the right Cauchy-Green deformation
tensor :math:`\boldsymbol{C}`. Function signature must be
``fun = lambda C, **kwargs: psi`` for functions without state variables and
``fun = lambda C, statevars, **kwargs: [psi, statevars_new]`` for functions
with state variables. It is important to only use differentiable math-functions
from :mod:`jax`.
nstatevars : int, optional
Number of state variables (default is 0).
jit : bool, optional
A flag to invoke just-in-time compilation (default is True).
parallel : bool, optional
A flag to invoke threaded strain energy density function evaluations (default
is False). Not implemented.
**kwargs : dict, optional
Optional keyword-arguments for the strain energy density function.
Notes
-----
The strain energy density function :math:`\psi` must be given in terms of the right
Cauchy-Green deformation tensor
:math:`\boldsymbol{C} = \boldsymbol{F}^T \boldsymbol{F}`.
.. warning::
It is important to only use differentiable math-functions from :mod:`jax`!
Take this minimal code-block as template
.. math::
\psi = \psi(\boldsymbol{C})
.. code-block::
import felupe as fem
import jax.numpy as jnp
def neo_hooke(C, mu):
"Strain energy function of the Neo-Hookean material formulation."
return mu / 2 * (jnp.linalg.det(C) ** (-1/3) * jnp.trace(C) - 3)
umat = fem.constitution.jax.Hyperelastic(neo_hooke, mu=1)
and this code-block for material formulations with state variables.
.. math::
\psi = \psi(\boldsymbol{C}, \boldsymbol{\zeta})
.. code-block::
import felupe as fem
import jax.numpy as np
def viscoelastic(C, Cin, mu, eta, dtime):
"Finite strain viscoelastic material formulation."
# unimodular part of the right Cauchy-Green deformation tensor
Cu = jnp.linalg.det(C) ** (-1 / 3) * C
# update of state variables by evolution equation
Ci = Cin.reshape(3, 3) + mu / eta * dtime * Cu
Ci = jnp.linalg.det(Ci) ** (-1 / 3) * Ci
# first invariant of elastic part of right Cauchy-Green deformation tensor
I1 = jnp.trace(Cu @ jnp.linalg.inv(Ci))
# strain energy function and state variable
return mu / 2 * (I1 - 3), Ci.ravel()
umat = fem.constitution.jax.Hyperelastic(
viscoelastic, mu=1, eta=1, dtime=1, nstatevars=9
)
.. note::
See the `documentation of JAX <https://jax.readthedocs.io>`_ for further
details. JAX uses single-precision (32bit) data types by default. This requires
to relax the tolerance of :func:`~felupe.newtonrhapson` to ``tol=1e-4``. If
required, JAX may be enforced to use double-precision at startup with
``jax.config.update("jax_enable_x64", True)``.
Examples
--------
View force-stretch curves on elementary incompressible deformations.
.. pyvista-plot::
:context:
>>> import felupe as fem
>>> import jax.numpy as jnp
>>>
>>> def neo_hooke(C, mu):
... "Strain energy function of the Neo-Hookean material formulation."
... return mu / 2 * (jnp.linalg.det(C) ** (-1/3) * jnp.trace(C) - 3)
>>>
>>> umat = fem.constitution.jax.Hyperelastic(neo_hooke, mu=1)
>>> ax = umat.plot(incompressible=True)
.. pyvista-plot::
:include-source: False
:context:
:force_static:
>>> import pyvista as pv
>>>
>>> fig = ax.get_figure()
>>> chart = pv.ChartMPL(fig)
>>> chart.show()
"""

def __init__(self, fun, nstatevars=0, jit=True, parallel=False, **kwargs):
import jax

has_aux = nstatevars > 0
self.fun = total_lagrange(fun)

if parallel:
warnings.warn("Parallel execution is not implemented.")

keyword_args = kwargs
if hasattr(fun, "kwargs"):
keyword_args = {**fun.kwargs, **keyword_args}

super().__init__(
stress=self._stress,
elasticity=self._elasticity,
nstatevars=nstatevars,
**keyword_args,
)

kwargs_jax = dict(in_axes=-1, out_axes=-1)
if nstatevars > 0:
kwargs_jax["in_axes"] = (-1, -1)

self._grad = vmap2(jax.grad(self.fun, has_aux=has_aux), **kwargs_jax)
self._hess = vmap2(
jax.jacfwd(jax.grad(self.fun, has_aux=has_aux), has_aux=has_aux),
**kwargs_jax,
)

if jit:
self._grad = jax.jit(self._grad)
self._hess = jax.jit(self._hess)

def _stress(self, x, **kwargs):
if self.nstatevars > 0:
statevars = x[1]

F = x[0]
if self.nstatevars > 0:
dWdF, statevars_new = self._grad(F, statevars, **kwargs)
statevars_new = np.array(statevars_new)
else:
dWdF = self._grad(F, **kwargs)
statevars_new = None

return [np.array(dWdF), statevars_new]

def _elasticity(self, x, **kwargs):
if self.nstatevars > 0:
statevars = x[1]

F = x[0]
if self.nstatevars > 0:
d2WdFdF, statevars_new = self._hess(F, statevars, **kwargs)
else:
d2WdFdF = self._hess(F, **kwargs)
return [np.array(d2WdFdF)]
2 changes: 1 addition & 1 deletion src/felupe/region/_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def _init_faces(self):
dX_1[0] = -dX_1[0]

tangents.append(dX_1 / np.linalg.norm(dX_1, axis=0))

if self.ensure_3d:
tangents[0] = np.insert(tangents[0], len(tangents[0]), 0, axis=0)
other_tangent = np.zeros_like(tangents[0])
Expand Down
Loading

0 comments on commit ac460a9

Please sign in to comment.