Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more JAX-based models #884

Merged
merged 3 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ All notable changes to this project will be documented in this file. The format
- Add the attribute `RegionBoundary.tangents`, which contains a list of tangent unit vectors. For `quad` cell-types the length of this list is one and for `hexahedron` cell-types it is of length two.
- Add `math.inplane(A, vectors)` to return the in-plane components of a symmetric tensor `A`, where the plane is defined by its standard unit vectors.
- Add `constitution.jax.Hyperelastic` as a feature-equivalent alternative to `Hyperelastic` with `jax` as backend.
- Add `constitution.jax.Material` as a feature-equivalent alternative to `MaterialAD` with `jax` as backend.
- Add the material models for JAX-based materials `felupe.constitution.jax.models.hyperelastic.mooney_rivlin()`, `felupe.constitution.jax.models.hyperelastic.yeoh()`, `felupe.constitution.jax.models.hyperelastic.third_order_deformation()` and `felupe.constitution.jax.models.lagrange.morph()`.
- Add `constitution.jax.Material(..., jacobian=None)` with JAX as backend. A custom jacobian-callable may be passed to switch between forward- and backward-mode automatic differentiation.
- Add material models for JAX-based materials: `felupe.constitution.jax.models.hyperelastic.mooney_rivlin()`, `felupe.constitution.jax.models.hyperelastic.yeoh()`, `felupe.constitution.jax.models.hyperelastic.third_order_deformation()`, `felupe.constitution.jax.models.lagrange.morph()`, `felupe.constitution.jax.models.lagrange.morph_representative_directions()`.
- Add `felupe.constitution.jax.total_lagrange()` and `felupe.constitution.jax.updated_lagrange()` function decorators for JAX materials.

### Changed
Expand Down
7 changes: 7 additions & 0 deletions docs/felupe/constitution/autodiff/jax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,14 @@ These material model formulations are defined by a strain energy density functio

**Material Models for** :class:`felupe.constitution.jax.Material`

The material model formulations are defined by the first Piola-Kirchhoff stress tensor.
Function-decorators are available to use Total-Lagrange and Updated-Lagrange material
formulations in :class:`~felupe.constitution.jax.Material`.

.. autosummary::

felupe.constitution.jax.models.lagrange.morph
felupe.constitution.jax.models.lagrange.morph_representative_directions

**Tools**

Expand Down Expand Up @@ -62,4 +67,6 @@ These material model formulations are defined by a strain energy density functio

.. autofunction:: felupe.constitution.jax.models.lagrange.morph

.. autofunction:: felupe.constitution.jax.models.lagrange.morph_representative_directions

.. autofunction:: felupe.constitution.jax.vmap
13 changes: 11 additions & 2 deletions src/felupe/constitution/jax/_material.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class Material(MaterialDefault):
parallel : bool, optional
A flag to invoke threaded function evaluations (defaultnis False). Not
implemented.
jacobian : callable or None, optional
A callable for the Jacobian. Default is None, where :func:`jax.jacobian` is
used. This may be used to switch to forward-mode differentian
:func:`jax.jacfwd`.
**kwargs : dict, optional
Optional keyword-arguments for the gradient of the strain energy density
function.
Expand Down Expand Up @@ -151,12 +155,17 @@ def viscoelastic(F, Cin, mu, eta, dtime):

"""

def __init__(self, fun, nstatevars=0, jit=True, parallel=False, **kwargs):
def __init__(
self, fun, nstatevars=0, jit=True, parallel=False, jacobian=None, **kwargs
):
import jax

has_aux = nstatevars > 0
self.fun = fun

if jacobian is None:
jacobian = jax.jacobian

if parallel:
warnings.warn("Parallel execution is not implemented.")

Expand All @@ -176,7 +185,7 @@ def __init__(self, fun, nstatevars=0, jit=True, parallel=False, **kwargs):
kwargs_jax["in_axes"] = (-1, -1)

self._grad = vmap2(self.fun, **kwargs_jax)
self._hess = vmap2(jax.jacfwd(self.fun, has_aux=has_aux), **kwargs_jax)
self._hess = vmap2(jacobian(self.fun, has_aux=has_aux), **kwargs_jax)

if jit:
self._grad = jax.jit(self._grad)
Expand Down
15 changes: 15 additions & 0 deletions src/felupe/constitution/jax/models/lagrange/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
"""
Strain-energy density functions for strain energy-gradient (stress) model formulations.

This module contains material model formulations to be used as the ``fun``-argument in
:func:`~felupe.constitution.jax.Material`. The gradient as well as the hessian of
the strain energy density function is carried out by automatic differentiation using
:mod:`jax`. Hence, all math-functions must be taken from :mod:`jax.numpy`.
"""

from ._morph import morph
from ._morph_representative_directions import morph_representative_directions
from ._morph_uniaxial import morph_uniaxial

__all__ = [
"morph",
"morph_representative_directions",
"morph_uniaxial",
]

# default (stable) material parameters
morph.kwargs = dict(p=[0, 0, 0, 0, 0, 1, 0, 0])
morph_representative_directions.kwargs = dict(p=[0, 0, 0, 0, 0, 1, 0, 0])
morph_uniaxial.kwargs = dict(p=[0, 0, 0, 0, 0, 1, 0, 0])
149 changes: 4 additions & 145 deletions src/felupe/constitution/jax/models/lagrange/_morph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,156 +15,15 @@
You should have received a copy of the GNU General Public License
along with FElupe. If not, see <http://www.gnu.org/licenses/>.
"""
from functools import wraps

from ....tensortrax.models.lagrange import morph as morph_docstring
from ..._total_lagrange import total_lagrange


@wraps(morph_docstring)
@total_lagrange
def morph(F, statevars, p):
r"""First Piola-Kirchhoff stress tensor of the
`MORPH <https://doi.org/10.1016/s0749-6419(02)00091-8>`_ model formulation [1]_.

Parameters
----------
F : jax.Array
Deformation gradient tensor.
statevars : jax.Array
Vector of stacked state variables (CTS, C, SA).
p : list of float
A list which contains the 8 material parameters.

Notes
-----
The MORPH material model is implemented as a second Piola-Kirchhoff stress-based
formulation with automatic differentiation. The Tresca invariant of the distortional
part of the right Cauchy-Green deformation tensor is used as internal state
variable, see Eq. :eq:`morph-state`.

.. warning::
While the `MORPH <https://doi.org/10.1016/s0749-6419(02)00091-8>`_-material
formulation captures the Mullins effect and quasi-static hysteresis effects of
rubber mixtures very nicely, it has been observed to be unstable for medium- to
highly-distorted states of deformation.

.. math::
:label: morph-state

\boldsymbol{C} &= \boldsymbol{F}^T \boldsymbol{F}

I_3 &= \det (\boldsymbol{C})

\hat{\boldsymbol{C}} &= I_3^{-1/3} \boldsymbol{C}

\hat{\lambda}^2_\alpha &= \text{eigvals}(\hat{\boldsymbol{C}})

\hat{C}_T &= \max \left( \hat{\lambda}^2_\alpha - \hat{\lambda}^2_\beta \right)

\hat{C}_T^S &= \max \left( \hat{C}_T, \hat{C}_{T,n}^S \right)

A sigmoid-function is used inside the deformation-dependent variables
:math:`\alpha`, :math:`\beta` and :math:`\gamma`, see Eq. :eq:`morph-sigmoid`.

.. math::
:label: morph-sigmoid

f(x) &= \frac{1}{\sqrt{1 + x^2}}

\alpha &= p_1 + p_2 \ f(p_3\ C_T^S)

\beta &= p_4\ f(p_3\ C_T^S)

\gamma &= p_5\ C_T^S\ \left( 1 - f\left(\frac{C_T^S}{p_6}\right) \right)

The rate of deformation is described by the Lagrangian tensor and its Tresca-
invariant, see Eq. :eq:`morph-rate-of-deformation`.

.. note::
It is important to evaluate the incremental right Cauchy-Green tensor by the
difference of the final and the previous state of deformation, not by its
variation with respect to the deformation gradient tensor.

.. math::
:label: morph-rate-of-deformation

\hat{\boldsymbol{L}} &= \text{sym}\left(
\text{dev}(\boldsymbol{C}^{-1} \Delta\boldsymbol{C})
\right) \hat{\boldsymbol{C}}

\lambda_{\hat{\boldsymbol{L}}, \alpha} &= \text{eigvals}(\hat{\boldsymbol{L}})

\hat{L}_T &= \max \left(
\lambda_{\hat{\boldsymbol{L}}, \alpha}-\lambda_{\hat{\boldsymbol{L}}, \beta}
\right)

\Delta\boldsymbol{C} &= \boldsymbol{C} - \boldsymbol{C}_n

The additional stresses evolve between the limiting stresses, see Eq.
:eq:`morph-stresses`. The additional deviatoric-enforcement terms [1]_ are neglected
in this implementation.

.. math::
:label: morph-stresses

\boldsymbol{S}_L &= \left(
\gamma \exp \left(p_7 \frac{\hat{\boldsymbol{L}}}{\hat{L}_T}
\frac{\hat{C}_T}{\hat{C}_T^S} \right) +
p8 \frac{\hat{\boldsymbol{L}}}{\hat{L}_T}
\right) \boldsymbol{C}^{-1}

\boldsymbol{S}_A &= \frac{
\boldsymbol{S}_{A,n} + \beta\ \hat{L}_T\ \boldsymbol{S}_L
}{1 + \beta\ \hat{L}_T}

\boldsymbol{S} &= 2 \alpha\ \text{dev}( \hat{\boldsymbol{C}} )
\boldsymbol{C}^{-1}+\text{dev}\left(\boldsymbol{S}_A\ \boldsymbol{C}\right)
\boldsymbol{C}^{-1}

Examples
--------
.. pyvista-plot::
:context:

>>> import felupe as fem
>>> import felupe.constitution.jax as mat
>>>
>>> umat = mat.Material(
... mat.models.lagrange.morph,
... p=[0.039, 0.371, 0.174, 2.41, 0.0094, 6.84, 5.65, 0.244],
... nstatevars=13,
... )
>>> ax = umat.plot(
... incompressible=True,
... ux=fem.math.linsteps(
... # [1, 2, 1, 2.75, 1, 3.5, 1, 4.2, 1, 4.8, 1, 4.8, 1],
... [1, 2.75, 1, 2.75],
... num=20,
... ),
... ps=None,
... bx=None,
... )

.. pyvista-plot::
:include-source: False
:context:
:force_static:

>>> import pyvista as pv
>>>
>>> fig = ax.get_figure()
>>> chart = pv.ChartMPL(fig)
>>> chart.show()

References
----------
.. [1] D. Besdo and J. Ihlemann, "A phenomenological constitutive model for
rubberlike materials and its numerical applications", International Journal
of Plasticity, vol. 19, no. 7. Elsevier BV, pp. 1019–1036, Jul. 2003. doi:
`10.1016/s0749-6419(02)00091-8 <https://doi.org/10.1016/s0749-6419(02)00091-8>`_.

See Also
--------
felupe.constitution.tensortrax.models.lagrange.morph : MORPH model (tensortrax)
"""

from jax.numpy import (
array,
concatenate,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
"""
This file is part of FElupe.

FElupe is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

FElupe is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with FElupe. If not, see <http://www.gnu.org/licenses/>.
"""
from functools import wraps

from ....tensortrax.models.lagrange import morph_representative_directions as morph_repr
from ._morph_uniaxial import morph_uniaxial
from .microsphere import affine_force_statevars


@wraps(morph_repr)
def morph_representative_directions(F, statevars, p, ε=1e-6):
def f(λ, statevars, **kwargs):
dψdλ, statevars_new = morph_uniaxial(λ, statevars, **kwargs)
return 5 * dψdλ, statevars_new

return affine_force_statevars(F, statevars, f=f, kwargs={"p": p, "ε": ε})
58 changes: 58 additions & 0 deletions src/felupe/constitution/jax/models/lagrange/_morph_uniaxial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
"""
This file is part of FElupe.

FElupe is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

FElupe is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with FElupe. If not, see <http://www.gnu.org/licenses/>.
"""
from functools import wraps

from ....tensortrax.models.lagrange import morph_uniaxial as morph_ux


@wraps(morph_ux)
def morph_uniaxial(λ, statevars, p, ε=1e-6):
from jax.numpy import abs as jabs
from jax.numpy import concatenate, exp, maximum, sqrt

CTSn = statevars[:21]
λn = 1 + statevars[21:42]
SA1n = statevars[42:63]
SA2n = statevars[63:84]

CT = jabs(λ**2 - 1 / λ)
CTS = maximum(CT, CTSn)

L1 = 2 * (λ**3 / λn - λn**2) / 3
L2 = (λn**2 / λ**3 - 1 / λn) / 3
LT = jabs(L1 - L2)

sigmoid = lambda x: 1 / sqrt(1 + x**2)
α = p[0] + p[1] * sigmoid(p[2] * CTS)
β = p[3] * sigmoid(p[2] * CTS)
γ = p[4] * CTS * (1 - sigmoid(CTS / p[5]))

L1_LT = L1 / (ε + LT)
L2_LT = L2 / (ε + LT)
CT_CTS = CT / (ε + CTS)

SL1 = (γ * exp(p[6] * L1_LT * CT_CTS) + p[7] * L1_LT) / λ**2
SL2 = (γ * exp(p[6] * L2_LT * CT_CTS) + p[7] * L2_LT) * λ

SA1 = (SA1n + β * LT * SL1) / (1 + β * LT)
SA2 = (SA2n + β * LT * SL2) / (1 + β * LT)

dψdλ = (2 * α + SA1) * λ - (2 * α + SA2) / λ**2
statevars_new = concatenate([CTS, (λ - 1), SA1, SA2])

return dψdλ, statevars_new
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ._framework_affine import affine_force_statevars

__all__ = [
"affine_force_statevars",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
"""
This file is part of FElupe.

FElupe is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

FElupe is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with FElupe. If not, see <http://www.gnu.org/licenses/>.
"""
from ......quadrature import BazantOh
from ...._total_lagrange import total_lagrange


@total_lagrange
def affine_force_statevars(F, statevars, f, kwargs, quadrature=BazantOh(n=21)):
"Micro-sphere model: Affine force (stretch) part."

from jax.numpy import einsum, sqrt, trace
from jax.numpy.linalg import det, inv

r = quadrature.points
M = einsum("ai,aj->aij", r, r)
Mw = einsum("aij,a->aij", M, quadrature.weights)

# affine stretches (unimodular part)
J = det(F)
C = F.T @ F
λ = J ** (-1 / 3) * sqrt(einsum("ij...,aij->a...", C, M))

dψdλ, statevars_new = f(λ, statevars, **kwargs)
dψdE = einsum("a...,aij->ij...", dψdλ / λ, Mw)

S = J ** (-2 / 3) * (dψdE - trace(dψdE @ C) / 3 * inv(C))

return S, statevars_new
Loading
Loading