diff --git a/Makefile b/Makefile
index e4f4feb27..bc75d96f4 100644
--- a/Makefile
+++ b/Makefile
@@ -5,6 +5,8 @@ _clean_coverage:
coverage erase
_clean_py:
+ find src/ tests/ contrib/ -name '*.nbi' -exec rm -f {} +
+ find src/ tests/ contrib/ -name '*.nbc' -exec rm -f {} +
find src/ tests/ contrib/ -name '*.pyc' -exec rm -f {} +
find src/ tests/ contrib/ -name '*.pyo' -exec rm -f {} +
find src/ tests/ contrib/ -name '*~' -exec rm -f {} +
@@ -49,6 +51,6 @@ upload:
done
test:
- DISPLAY= tox
+ DISPLAY= HAPSIRA_CACHE=0 tox -e style,tests-fast,tests-slow,docs
.PHONY: docs docker image release upload
diff --git a/contrib/CR3BP/CR3BP.py b/contrib/CR3BP/CR3BP.py
index 1b93cabc4..2a78611d1 100644
--- a/contrib/CR3BP/CR3BP.py
+++ b/contrib/CR3BP/CR3BP.py
@@ -32,7 +32,8 @@
from numba import njit as jit
import numpy as np
-from hapsira._math.ivp import DOP853, solve_ivp
+# from hapsira.core.math.ivp import solve_ivp
+from scipy.integrate import solve_ivp, DOP853
@jit
diff --git a/docs/source/changelog.md b/docs/source/changelog.md
index 6a5502281..e9671ec4c 100644
--- a/docs/source/changelog.md
+++ b/docs/source/changelog.md
@@ -1,4 +1,31 @@
-# What's new
+# Changes
+
+## hapsira 0.19.0 - 2024-XX-XX
+
+**CAUTION**: A number changes at least partially **BREAK BACKWARDS COMPATIBILITY** for certain use cases.
+
+This release features a significant refactoring of `core`, see [hapsira #7](https://github.com/pleiszenburg/hapsira/pull/7) for details. All relevant `core` functions are now designed to work equally on CPUs and GPUs as either [universal functions](https://numba.readthedocs.io/en/stable/user/vectorize.html#the-vectorize-decorator) or [generalized universal functions](https://numba.readthedocs.io/en/stable/user/vectorize.html#the-guvectorize-decorator). As a "side-effect", all relevant `core` functions allow parallel operation with full [broadcasting semantics](https://numpy.org/doc/stable/user/basics.broadcasting.html). Their single-thread performance was also increased depending on use-case by around two orders of magnitude. All refactored **functions** in `core` were **renamed**, now carrying additional suffixes to indicate how they can or can not be invoked.
+
+Critical fix changing behaviour: The Loss of Signal (LOS) event would previously produce wrong results.
+
+Module layout change: `core.earth_atmosphere` became `core.earth.atmosphere`.
+
+- FEATURE: New `core.math` module, see [hapsira #7](https://github.com/pleiszenburg/hapsira/pull/7), including fast replacements for many `numpy` and some `scipy` functions, most notably:
+ - [scipy.interpolate.interp1d](https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.interp1d.html) is replaced by `core.math.interpolate.interp_hb`. It custom-compiles 1D linear interpolators, embedding data statically into the compiled functions.
+ - [scipy.integrate.solve_ivp](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html), [scipy.integrate.DOP853](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.DOP853.html) and [scipy.optimize.brentq](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.brentq.html) are replaced by `core.math.ivp`, a purely functional compiled implementation running entirely on the stack.
+- FEATURE: New `core.jit` module, wrapping a number of `numba` functions to have a central place to apply settings, see [hapsira #7](https://github.com/pleiszenburg/hapsira/pull/7)
+- FEATURE: New `settings` module, mainly for handling JIT compiler settings, see [hapsira #7](https://github.com/pleiszenburg/hapsira/pull/7)
+- FEATURE: New `debug` module, including logging capabilities, mainly logging JIT compiler issues, see [hapsira #7](https://github.com/pleiszenburg/hapsira/pull/7)
+- FEATURE: Significant portions of the COESA76 atmophere model are now compiled and available as part of `core`, see [hapsira #7](https://github.com/pleiszenburg/hapsira/pull/7). `core.earth_atmosphere` was renamed into `core.earth.atmosphere`.
+- DOCS: The `core` module is technically user-facing and as such now explicitly documented, see [hapsira #7](https://github.com/pleiszenburg/hapsira/pull/7)
+- DOCS: The "quickstart" section received an update and now includes all previously missing imports.
+- FIX: The Loss of Signal (LOS) event would misshandle the position of the secondary body i.e. producing wrong results, see [hapsira #7](https://github.com/pleiszenburg/hapsira/pull/7) as well as the [relevant commit](https://github.com/pleiszenburg/hapsira/commit/988a91cd22ff1de285c33af35b13d288963fcaf7)
+- FIX: The Cowell propagator could produce wrong results if times of flight (tof) where provided in units other than seconds, see [hapsira #7](https://github.com/pleiszenburg/hapsira/pull/7)
+- FIX: Broken plots in example notebooks, see [hapsira #7](https://github.com/pleiszenburg/hapsira/pull/7)
+- FIX: Typo in `bodies`, see [hapsira #6](https://github.com/pleiszenburg/hapsira/pull/6)
+- FIX: Some notebooks in the documentation had disappeared due to incomplete rebranding
+- DEV: Parallel (multi-core) testing enabled by default, see [hapsira #5](https://github.com/pleiszenburg/hapsira/pull/5)
+- DEV: Deactivated warning due to too many simultaneously opened `matplotlib` plots
## hapsira 0.18.0 - 2023-12-24
@@ -233,7 +260,7 @@ as well as the results from Google Summer of Code 2021.
The interactive orbit plotters {py:class}`~poliastro.plotting.OrbitPlotter2D`
and {py:class}`~poliastro.plotting.OrbitPlotter3D`
now have a new method to easily display impulsive burns.
- See {doc}`/examples/going-to-jupiter-with-python-using-jupyter-and-poliastro`
+ See {doc}`/examples/going-to-jupiter-with-python-using-jupyter-and-hapsira`
for an example.
- **Many performance improvements**
Several contributors have helped accelerate more algorithms
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 244d5dae9..20fc4979f 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -87,7 +87,7 @@
}
project = "hapsira"
-copyright = "2023 Sebastian M. Ernst"
+copyright = "2023-2024 Sebastian M. Ernst"
project_ver = version(project)
version = ".".join(project_ver.split(".")[:2])
diff --git a/docs/source/core.md b/docs/source/core.md
new file mode 100644
index 000000000..4efb54707
--- /dev/null
+++ b/docs/source/core.md
@@ -0,0 +1,227 @@
+(coremodule)=
+# Core module
+
+The `core` module handles most actual heavy computations. It is compiled via [numba](https://numba.readthedocs.io). For both working with functions from `core` directly and contributing to it, it is highly recommended to gain some basic understanding of how `numba` works.
+
+`core` is designed to work equally on CPUs and GPUs. (Most) exported `core` APIs interfacing with the rest of `hapsira` are either [universal functions](https://numba.readthedocs.io/en/stable/user/vectorize.html#the-vectorize-decorator) or [generalized universal functions](https://numba.readthedocs.io/en/stable/user/vectorize.html#the-guvectorize-decorator), which allow parallel operation with full [broadcasting semantics](https://numpy.org/doc/stable/user/basics.broadcasting.html).
+
+```{warning}
+Some `core` functions have yet not been refactored into this shape and will soon follow the same approach.
+```
+
+## Compiler targets
+
+There are three compiler targets, which can be controlled through settings and/or environment variables:
+
+- `cpu`: Single-threaded on CPUs
+- `parallel`: Parallelized by `numba` via [threading layers](https://numba.readthedocs.io/en/stable/user/threading-layer.html) on CPUs
+- `cuda`: Parallelized by `numba` via CUDA on Nvidia GPUs
+
+All code of `core` will be compiled for one of the above listed targets. If multiple targets are supposed to be used simultaneously, this can only be achieved by multiple Python processes running in parallel.
+
+## Compiler decorators
+
+`core` offers the follwing JIT compiler decorators provided via `core.jit`:
+
+- `vjit`: Wraps `numba.vectorize`. Functions decorated by it carry the suffix `_vf`.
+- `gjit`: Wraps `numba.guvectorize`. Functions decorated by it carry the suffix `_gf`.
+- `hjit`: Wraps `numba.jit` or `numba.cuda.jit`, depending on compiler target. Functions decorated by it carry the suffix `_hf`.
+- `djit`: Variation of `hjit` with fixed function signature for user-provided functions used by `Cowell`
+
+`core` functions dynamically generating (and compiling) functions within their scope carry `_hb`, `_vb` and `_gb` suffixes.
+
+Wrapping `numba` functions allows to centralize compiler options and target switching as well as to simplify typing.
+
+The decorators are applied in a **hierarchy**:
+
+- Functions decorated by either `vjit` and `gjit` serve as the **only** interface between regular uncompiled Python code and `core`
+- Functions decorated by `vjit` and `hjit` only call functions decorated by `hjit`
+- Functions decorated by `hjit` can only call each other.
+
+```{note}
+The "hierarchy" of decorators is imposed by CUDA-compatibility. While functions decorated by `numba.jit` (targets `cpu` and `parallel`) can be called from uncompiled Python code, functions decorated by `numba.cuda.jit` (target `cuda`) are considered [device functions](https://numba.readthedocs.io/en/stable/cuda/device-functions.html) and can not be called by uncompiled Python code directly. They are supposed to be called by CUDA-kernels (or other device functions) only (slightly simplifying the actual situation as implemented by `numba`). If the target is set to `cuda`, functions decorated by `numba.vectorize` and `numba.guvectorize` become CUDA kernels.
+```
+
+```{warning}
+As a result of name suffixes as of version `0.19.0`, many `core` module functions have been renamed making the package intentionally backwards-incompatible. Functions not yet using the new infrastructure can be recognized based on lack of suffix. Eventually all `core` functions will use this infrastructure and carry matching suffixes.
+```
+
+```{note}
+Some functions decorated by `gjit` must receive a dummy parameter, also explicitly named `dummy`. It is usually an empty `numpy` array of shape `(3,)` of data type `u1` (unsigned one-byte integer). This is a work-around for [numba #2797](https://github.com/numba/numba/issues/2797).
+```
+
+## Compiler errors
+
+Misconfigured compiler decorators or unavailable targets raise an `errors.JitError` exception.
+
+## Keyword arguments and defaults
+
+Due to incompletely documented limitations in `numba`, see [documentation](https://numba.readthedocs.io/en/stable/reference/pysupported.html#function-calls) and [numba #7870](https://github.com/numba/numba/issues/7870), functions decorated by `hjit`, `vjit` and `gjit` can not have defaults for any of their arguments. In this context, those functions can not reliably be called with keyword arguments, too, which must therefore be avoided. Defaults are provided as constants within the same submodule, usually the function name in capital letters followed by the name of the argument, also in capital letters.
+
+## Dependencies
+
+Functions decorated by `vjit`, `gjit` and `hjit` are only allowed to depend on Python's standard library's [math module](https://docs.python.org/3/library/math.html), but **not** on other third-party packages like [numpy](https://numpy.org/doc/stable/) or [scipy](https://docs.scipy.org/doc/scipy/) for that matter - except for certain details like [enforcing floating point precision](https://numpy.org/doc/stable/user/basics.types.html) as provided by `core.math.ieee754`
+
+```{note}
+Eliminating `numpy` and other dependencies serves two purposes. While it is critical for [CUDA-compatiblity](https://numba.readthedocs.io/en/stable/cuda/cudapysupported.html), it additionally makes the code significantly faster on CPUs.
+```
+
+## Typing
+
+All functions decorated by `hjit`, `vjit` and `gjit` must by typed using [signatures similar to those of numba](https://numba.readthedocs.io/en/stable/reference/types.html).
+
+All compiled code enforces a single floating point precision level, which can be configured. The default is FP64 / double precision. For simplicity, the type shortcut is `f`, replacing `f2`, `f4` or `f8`. Consider the following example:
+
+```python
+from numba import vectorize
+from hapsira.core.jit import vjit
+
+@vectorize("f8(f8)")
+def foo(x):
+ return x ** 2
+
+@vjit("f(f)")
+def bar_vf(x):
+ return x ** 2
+```
+
+Additional infrastructure can be found in `core.math.ieee754`. The default floating point type is exposed as `core.math.ieee754.float_` for explicit conversions. A matching epsilon is exposed as `core.math.ieee754.EPS`.
+
+```{note}
+Divisions by zero should, regardless of compiler target or even entirely deactivated compiler, always result in `inf` (infinity) instead of `ZeroDivisionError` exceptions. Most divisions within `core` are therefore explicitly guarded.
+```
+
+3D vectors are expressed as tuples, type shortcut `V`, replacing `Tuple([f,f,f])`. Consider the following example:
+
+```python
+from numba import njit
+from hapsira.core.jit import hjit
+
+@njit("f8(Tuple([f8,f8,f8]))")
+def foo(x):
+ return x[0] + x[1] + x[2]
+
+@hjit("f(V)")
+def bar_hf(x):
+ return x[0] + x[1] + x[2]
+```
+
+Matrices are expressed as tuples of tuples, type shortcut `M`, replacing `Tuple([V,V,V])`. Consider the following example:
+
+```python
+from numba import njit
+from hapsira.core.jit import hjit
+
+@njit("f8(Tuple([Tuple([f8,f8,f8]),Tuple([f8,f8,f8]),Tuple([f8,f8,f8])]))")
+def foo(x):
+ sum_ = 0
+ for idx in range(3):
+ for jdx in range(3):
+ sum_ += x[idx][jdx]
+ return sum_
+
+@hjit("f(M)")
+def bar_hf(x):
+ sum_ = 0
+ for idx in range(3):
+ for jdx in range(3):
+ sum_ += x[idx][jdx]
+ return sum_
+```
+
+Function types use the shortcut `F`, replacing `FunctionType`.
+
+## Cowell’s formulation
+
+Cowell’s formulation is one of the few places where `core` is exposed directly to the user.
+
+### Two-body function
+
+In its most simple form, the `CowellPropagator` relies on a variation of `func_twobody_hf` as a parameter, a function compiled by `hjit`, which can technically be omitted:
+
+```python
+from hapsira.core.propagation.base import func_twobody_hf
+from hapsira.twobody.propagation import CowellPropagator
+
+prop = CowellPropagator(f=func_twobody_hf)
+prop = CowellPropagator() # identical to the above
+```
+
+If perturbations are applied, however, `func_twobody_hf` needs to be altered. It is important that the new altered function is compiled via the `hjit` decorator and that is has the correct signature. To simplify the matter for users, a variation of `hjit` named `djit` carries the correct signature implicitly:
+
+```python
+from hapsira.core.jit import djit, hjit
+from hapsira.core.math.linalg import mul_Vs_hf
+from hapsira.core.propagation.base import func_twobody_hf
+from hapsira.twobody.propagation import CowellPropagator
+
+@hjit("Tuple([V,V])(f,V,V,f)")
+def foo_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ return du_kep_rr, mul_Vs_hf(du_kep_vv, 1.1) # multiply speed vector by 1.1
+
+@djit
+def bar_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ return du_kep_rr, mul_Vs_hf(du_kep_vv, 1.1) # multiply speed vector by 1.1
+
+prop = CowellPropagator(f=foo_hf)
+prop = CowellPropagator(f=bar_hf) # identical to the above
+```
+
+### Events
+
+The core of each event's implementation must also be compiled by `hjit`. New events must inherit from `BaseEvent`. The compiled implementation should be an attribute, a function or a static method, named `_impl_hf`. Once this attribute is specified, an explicit call to the `_wrap` method, most likely from the constructor, automatically generates a second version of `_impl_hf` named `_impl_dense_hf` that is used to not only evaluate but also to approximate the exact time of flight of an event based on dense output of the underlying solver.
+
+## Settings
+
+The following settings, available via `settings.settings`, allow to alter the compiler's behaviour:
+
+- `DEBUG`: `bool`, default `False`
+- `CACHE`: `bool`, default `not DEBUG`
+- `TARGET`: `str`, default `cpu`, alternatives `parallel` and `cuda`
+- `INLINE`: `bool`, default `TARGET == "cuda"`
+- `NOPYTHON`: `bool`, default `True`
+- `FORCEOBJ`: `bool`, default `False`
+- `PRECISION`: `str`, default `f8`, alternatives `f2` and `f4`
+
+```{note}
+Settings can be switched by either setting environment variables or importing the `settings` module **before** any other (sub-) module is imported.
+```
+
+The `DEBUG` setting disables caching and enables the highest log level, among other things.
+
+`CACHE` only works for `cpu` and `parallel` targets. It speeds up import times drastically if the package gets reused. Dynamically generated functions can not be cached and must be exempt from caching by passing `cache = False` as a parameter to the JIT compiler decorator.
+
+```{warning}
+Building the cache should not be done in parallel processes - this will most likely result in non-deterministic segmentation faults, see [numba #4807](https://github.com/numba/numba/issues/4807). Once `core` is fully compiled and cached, it can however be used in parallel processes. Rebuilding the cache can usually reliably resolve segmentation faults.
+```
+
+Inlining via `INLINE` drastically increases performance but also compile times. It is the default behaviour for target `cuda`. See [relevant chapter in numba documentation](https://numba.readthedocs.io/en/stable/developer/inlining.html#notes-on-inlining) for details.
+
+`NOPYTHON` and `FORCEOBJ` provide additional debugging capabilities but should not be changed for regular use. For details, see [nopython mode](https://numba.readthedocs.io/en/stable/glossary.html#term-nopython-mode) and [object mode](https://numba.readthedocs.io/en/stable/glossary.html#term-object-mode) in `numba`'s documentation.
+
+The default `PRECISION` of all floating point operations is FP64 / double precision float.
+
+```{warning}
+`hapsira`, formerly `poliastro`, was validated for FP64. Certain parts like Cowell reliably operate at this precision only. Other parts like for instance atmospheric models can easily handle single precision. This option is therefore provided for experimental purposes only.
+```
+
+## Logging
+
+Compiler issues are logged via logging channel `hapsira` using Python's standard library's [logging module](https://docs.python.org/3/howto/logging.html), also available as `debug.logger`. All compiler activity can be observed by enabling log level `debug`.
+
+## Math
+
+The former `_math` module, version `0.18` and earlier, has become a first-class citizen as `core.math`, fully compiled by the above mentioned infrastructure. `core.math` contains a number of replacements for `numpy` operations, mostly found in `core.math.linalg`. All of those functions do not allocate memory and are free of side-effects including a lack of changes to their parameters.
+
+Functions in `core.math` follow a loose naming convention, indicating for what types of parameters they can be used. `mul_Vs_hf` for instance is a multiplication of a vector `V` and a scalar `s` (floating point). `M` indicates matricis.
+
+`core.math` also replaces (some) required `scipy` functions:
+
+- [scipy.interpolate.interp1d](https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.interp1d.html) is replaced by `core.math.interpolate.interp_hb`. It custom-compiles 1D linear interpolators, embedding data statically into the compiled functions.
+- [scipy.integrate.solve_ivp](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html), [scipy.integrate.DOP853](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.DOP853.html) and [scipy.optimize.brentq](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.brentq.html) are replaced by `core.math.ivp`.
+
+```{note}
+Future releases might remove more dependencies to `scipy` from `core` for full CUDA compatibility and additional performance.
+```
diff --git a/docs/source/examples/detecting-events.myst.md b/docs/source/examples/detecting-events.myst.md
index de3a48447..ec41a376a 100644
--- a/docs/source/examples/detecting-events.myst.md
+++ b/docs/source/examples/detecting-events.myst.md
@@ -4,7 +4,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.14.0
+ jupytext_version: 1.16.0
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -30,7 +30,7 @@ an event during an orbit's propagation is fairly simple:
2. Pass the `Event` object(s) as an argument to `CowellPropagator`.
3. Detect events! Optionally, the `terminal` and `direction` attributes can be set as required.
-```{code-cell}
+```{code-cell} ipython3
# Imports
import numpy as np
from numpy.linalg import norm
@@ -63,10 +63,12 @@ from hapsira.util import time_range
## Altitude Crossing Event
Let's define some natural perturbation conditions for our orbit so that its altitude decreases with time.
-```{code-cell}
+```{code-cell} ipython3
from hapsira.constants import H0_earth, rho0_earth
-from hapsira.core.perturbations import atmospheric_drag_exponential
-from hapsira.core.propagation import func_twobody
+from hapsira.core.jit import djit
+from hapsira.core.math.linalg import add_VV_hf
+from hapsira.core.perturbations import atmospheric_drag_exponential_hf
+from hapsira.core.propagation.base import func_twobody_hf
R = Earth.R.to_value(u.km)
@@ -80,19 +82,18 @@ A_over_m = ((np.pi / 4.0) * (u.m**2) / (100 * u.kg)).to_value(
rho0 = rho0_earth.to_value(u.kg / u.km**3) # kg/km^3
H0 = H0_earth.to_value(u.km) # km
-
-def f(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
- ax, ay, az = atmospheric_drag_exponential(
- t0, u_, k, R=R, C_D=C_D, A_over_m=A_over_m, H0=H0, rho0=rho0
+@djit
+def f_hf(t0, r0, v0, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, r0, v0, k)
+ a = atmospheric_drag_exponential_hf(
+ t0, r0, v0, k, R=R, C_D=C_D, A_over_m=A_over_m, H0=H0, rho0=rho0
)
- du_ad = np.array([0, 0, 0, ax, ay, az])
- return du_kep + du_ad
+ return du_kep_rr, add_VV_hf(du_kep_vv, a)
```
We shall use the `CowellPropagator` with the above perturbating conditions and pass the events we want to keep track of, in this case only the `AltitudeCrossEvent`.
-```{code-cell}
+```{code-cell} ipython3
tofs = np.arange(0, 2400, 100) << u.s
orbit = Orbit.circular(Earth, 150 * u.km)
@@ -101,7 +102,7 @@ thresh_alt = 50 # in km
altitude_cross_event = AltitudeCrossEvent(thresh_alt, R) # Set up the event.
events = [altitude_cross_event]
-method = CowellPropagator(events=events, f=f)
+method = CowellPropagator(events=events, f=f_hf)
rr, _ = orbit.to_ephem(
EpochsArray(orbit.epoch + tofs, method=method),
).rv()
@@ -113,7 +114,7 @@ print(
Let's see how did the orbit's altitude vary with time:
-```{code-cell}
+```{code-cell} ipython3
altitudes = np.apply_along_axis(
norm, 1, (rr << u.km).value
) - Earth.R.to_value(u.km)
@@ -130,7 +131,7 @@ Refer to the API documentation of the events to check the default values for `te
Similar to the `AltitudeCrossEvent`, just pass the threshold latitude while instantiating the event.
-```{code-cell}
+```{code-cell} ipython3
orbit = Orbit.from_classical(
Earth,
6900 << u.km,
@@ -142,7 +143,7 @@ orbit = Orbit.from_classical(
)
```
-```{code-cell}
+```{code-cell} ipython3
thresh_lat = 35 << u.deg
latitude_cross_event = LatitudeCrossEvent(orbit, thresh_lat, terminal=True)
events = [latitude_cross_event]
@@ -157,16 +158,14 @@ print(
Let's plot the latitude varying with time:
-```{code-cell}
-from hapsira.core.spheroid_location import cartesian_to_ellipsoidal
+```{code-cell} ipython3
+from hapsira.core.spheroid_location import cartesian_to_ellipsoidal_gf
-latitudes = []
-for r in rr:
- position_on_body = (r / norm(r)) * Earth.R
- _, lat, _ = cartesian_to_ellipsoidal(
- Earth.R, Earth.R_polar, *position_on_body
- )
- latitudes.append(np.rad2deg(lat))
+position_on_body = (rr.to_value(u.km) / norm(rr.to_value(u.km), axis = 1)[:, None]) * Earth.R.to_value(u.km)
+_, latitudes, _ = cartesian_to_ellipsoidal_gf(
+ Earth.R.to_value(u.km), Earth.R_polar.to_value(u.km), *position_on_body.T
+)
+latitudes = np.rad2deg(latitudes)
plt.plot(tofs[: len(rr)].to_value(u.s), latitudes)
plt.title("Latitude variation")
plt.ylabel("Latitude (in degrees)")
@@ -178,7 +177,7 @@ The orbit's latitude would not change after the event was detected since we had
Since the attractor is `Earth`, we could use `GroundtrackPlotter` for showing the groundtrack of the
orbit on Earth.
-```{code-cell}
+```{code-cell} ipython3
from hapsira.earth import EarthSatellite
from hapsira.earth.plotting import GroundtrackPlotter
from hapsira.plotting import OrbitPlotter
@@ -208,7 +207,7 @@ gp.plot(
Viewing it in the `orthographic` projection mode,
-```{code-cell}
+```{code-cell} ipython3
gp.update_geos(projection_type="orthographic")
gp.fig.show()
```
@@ -220,9 +219,7 @@ and voila! The groundtrack terminates almost at the 35 degree latitude mark.
Users can detect umbra/penumbra crossings using the `UmbraEvent` and `PenumbraEvent` event classes,
respectively. As seen from the above examples, the procedure doesn't change much.
-```{code-cell}
-from hapsira.core.events import eclipse_function
-
+```{code-cell} ipython3
attractor = Earth
tof = 2 * u.d
# Classical orbital elements
@@ -239,11 +236,12 @@ orbit = Orbit.from_classical(attractor, *coe)
Let's search for umbra crossings.
-```{code-cell}
-umbra_event = UmbraEvent(orbit, terminal=True)
+```{code-cell} ipython3
+tofs = np.arange(0, 600, 30) << u.s
+
+umbra_event = UmbraEvent(orbit, tof, terminal=True)
events = [umbra_event]
-tofs = np.arange(0, 600, 30) << u.s
method = CowellPropagator(events=events)
rr, vv = orbit.to_ephem(EpochsArray(orbit.epoch + tofs, method=method)).rv()
print(
@@ -255,7 +253,9 @@ print(
Let us plot the eclipse functions' variation with time.
-```{code-cell}
+```{code-cell} ipython3
+from hapsira.core.events import eclipse_function_gf, ECLIPSE_UMBRA
+
k = Earth.k.to_value(u.km**3 / u.s**2)
R_sec = Sun.R.to_value(u.km)
R_pri = Earth.R.to_value(u.km)
@@ -268,12 +268,7 @@ r_sec = ((r_sec_ssb - r_pri_ssb).xyz << u.km).value
rr = (rr << u.km).value
vv = (vv << u.km / u.s).value
-eclipses = [] # List to store values of eclipse_function.
-for i in range(len(rr)):
- r = rr[i]
- v = vv[i]
- eclipse = eclipse_function(k, np.hstack((r, v)), r_sec, R_sec, R_pri)
- eclipses.append(eclipse)
+eclipses = eclipse_function_gf(k, rr, vv, r_sec, R_sec, R_pri, ECLIPSE_UMBRA)
plt.xlabel("Time (s)")
plt.ylabel("Eclipse function")
@@ -285,7 +280,7 @@ plt.plot(tofs[: len(rr)].to_value(u.s), eclipses)
We could get some geometrical insights by plotting the orbit:
-```{code-cell}
+```{code-cell} ipython3
# Plot `Earth` at the instant of event occurence.
Earth.plot(
orbit.epoch.tdb + umbra_event.last_t,
@@ -311,7 +306,7 @@ It seems our satellite is exiting the umbra region, as is evident from the orang
This event detector aims to check for ascending and descending node crossings. Note that it could
yield inaccurate results if the orbit is near-equatorial.
-```{code-cell}
+```{code-cell} ipython3
r = [-3182930.668, 94242.56, -85767.257] << u.km
v = [505.848, 942.781, 7435.922] << u.km / u.s
orbit = Orbit.from_vectors(Earth, r, v)
@@ -319,13 +314,13 @@ orbit = Orbit.from_vectors(Earth, r, v)
As a sanity check, let's check the orbit's inclination to ensure it is not near-zero:
-```{code-cell}
+```{code-cell} ipython3
print(orbit.inc)
```
Indeed, it isn't!
-```{code-cell}
+```{code-cell} ipython3
node_event = NodeCrossEvent(terminal=True)
events = [node_event]
@@ -338,7 +333,7 @@ print(f"The nodal cross time was {node_event.last_t} after the orbit's epoch")
The plot below shows us the variation of the z coordinate of the orbit's position vector with time:
-```{code-cell}
+```{code-cell} ipython3
z_coords = [r[-1].to_value(u.km) for r in rr]
plt.xlabel("Time (s)")
plt.ylabel("Z coordinate of the position vector")
@@ -348,7 +343,7 @@ plt.plot(tofs[: len(rr)].to_value(u.s), z_coords)
We could do the same plotting done in `LatitudeCrossEvent` to check for equatorial crossings:
-```{code-cell}
+```{code-cell} ipython3
es = EarthSatellite(orbit, None)
# Show the groundtrack plot from
@@ -374,7 +369,7 @@ gp.plot(
)
```
-```{code-cell}
+```{code-cell} ipython3
gp.update_geos(projection_type="orthographic")
gp.fig.show()
```
@@ -388,7 +383,7 @@ either of the two crossings, the `direction` attribute is at our disposal!
If we would like to track multiple events while propagating an orbit, we just need to add the concerned events inside `events`.
Below, we show the case where `NodeCrossEvent` and `LatitudeCrossEvent` events are to be detected.
-```{code-cell}
+```{code-cell} ipython3
# NodeCrossEvent is detected earlier than the LatitudeCrossEvent.
r = [-6142438.668, 3492467.56, -25767.257] << u.km
v = [505.848, 942.781, 7435.922] << u.km / u.s
@@ -406,9 +401,9 @@ tofs = [1, 2, 4, 6, 8, 10, 12] << u.s
method = CowellPropagator(events=events)
rr, vv = orbit.to_ephem(EpochsArray(orbit.epoch + tofs, method=method)).rv()
-print(f"Node cross event termination time: {node_cross_event.last_t} s")
+print(f"Node cross event termination time: {node_cross_event.last_t}")
print(
- f"Latitude cross event termination time: {latitude_cross_event.last_t} s"
+ f"Latitude cross event termination time: {latitude_cross_event.last_t}"
)
```
diff --git a/docs/source/examples/going-to-jupiter-with-python-using-jupyter-and-poliastro.myst.md b/docs/source/examples/going-to-jupiter-with-python-using-jupyter-and-hapsira.myst.md
similarity index 98%
rename from docs/source/examples/going-to-jupiter-with-python-using-jupyter-and-poliastro.myst.md
rename to docs/source/examples/going-to-jupiter-with-python-using-jupyter-and-hapsira.myst.md
index 61c73e3b0..e77c29649 100644
--- a/docs/source/examples/going-to-jupiter-with-python-using-jupyter-and-poliastro.myst.md
+++ b/docs/source/examples/going-to-jupiter-with-python-using-jupyter-and-hapsira.myst.md
@@ -37,7 +37,7 @@ from hapsira.twobody import Orbit
from hapsira.util import norm, time_range
```
-All the data for Juno's mission is sorted [here](https://github.com/hapsira/hapsira/wiki/EuroPython:-Per-Python-ad-Astra). The main maneuvers that the spacecraft will perform are listed down:
+All the data for Juno's mission is sorted [here](https://github.com/poliastro/poliastro/wiki/EuroPython:-Per-Python-ad-Astra). The main maneuvers that the spacecraft will perform are listed down:
* Inner cruise phase 1: This will set Juno in a new orbit around the sun.
* Inner cruise phase 2: Fly-by around Earth. Gravity assist is performed.
diff --git a/docs/source/examples/going-to-mars-with-python-using-poliastro.myst.md b/docs/source/examples/going-to-mars-with-python-using-hapsira.myst.md
similarity index 91%
rename from docs/source/examples/going-to-mars-with-python-using-poliastro.myst.md
rename to docs/source/examples/going-to-mars-with-python-using-hapsira.myst.md
index 676c8ed7a..6c59378e0 100644
--- a/docs/source/examples/going-to-mars-with-python-using-poliastro.myst.md
+++ b/docs/source/examples/going-to-mars-with-python-using-hapsira.myst.md
@@ -13,7 +13,7 @@ kernelspec:
# Going to Mars with Python using hapsira
-This is an example on how to use [hapsira](https://github.com/hapsira/hapsira), a little library I've been working on to use in my Astrodynamics lessons. It features conversion between **classical orbital elements** and position vectors, propagation of **Keplerian orbits**, initial orbit determination using the solution of the **Lambert's problem** and **orbit plotting**.
+This is an example on how to use [hapsira](https://github.com/pleiszenburg/hapsira), a little library I've been working on to use in my Astrodynamics lessons. It features conversion between **classical orbital elements** and position vectors, propagation of **Keplerian orbits**, initial orbit determination using the solution of the **Lambert's problem** and **orbit plotting**.
In this example we're going to draw the trajectory of the mission [Mars Science Laboratory (MSL)](http://mars.jpl.nasa.gov/msl/), which carried the rover Curiosity to the surface of Mars in a period of something less than 9 months.
diff --git a/docs/source/examples/loading-OMM-and-TLE-satellite-data.myst.md b/docs/source/examples/loading-OMM-and-TLE-satellite-data.myst.md
index a4778698e..b4c9c3b97 100644
--- a/docs/source/examples/loading-OMM-and-TLE-satellite-data.myst.md
+++ b/docs/source/examples/loading-OMM-and-TLE-satellite-data.myst.md
@@ -28,7 +28,7 @@ kernelspec:
+++
-However, it turns out that GP data in general, and TLEs in particular, are poorly understood even by professionals ([[1]](https://www.linkedin.com/posts/tom-johnson-32333a2_flawed-data-activity-6825845118990381056-yJX7), [[2]](https://twitter.com/flightclubio/status/1435303066085982209), [[3]](https://github.com/hapsira/hapsira/issues/1185)). The core issue is that TLEs and OMMs contain _Brouwer mean elements_, which **cannot be directly translated to osculating elements**.
+However, it turns out that GP data in general, and TLEs in particular, are poorly understood even by professionals ([[1]](https://www.linkedin.com/posts/tom-johnson-32333a2_flawed-data-activity-6825845118990381056-yJX7), [[2]](https://twitter.com/flightclubio/status/1435303066085982209), [[3]](https://github.com/poliastro/poliastro/issues/1185)). The core issue is that TLEs and OMMs contain _Brouwer mean elements_, which **cannot be directly translated to osculating elements**.
From "Spacetrack Report #3":
@@ -58,7 +58,7 @@ Therefore, the **correct** way of using GP data is:
As explained in the [Orbit Mean-Elements Messages (OMMs) support assessment](https://opensatcom.org/2020/12/28/omm-assessment-sgp4-benchmarks/) deliverable of OpenSatCom, OMM input/output support in open source libraries is somewhat scattered. Luckily, [python-sgp4](https://pypi.org/project/sgp4/) supports reading OMM in CSV and XML format, as well as usual TLE and 3LE formats. On the other hand, Astropy has accurate transformations from TEME to other reference frames.
```{code-cell} ipython3
-# From https://github.com/hapsira/hapsira/blob/main/contrib/satgpio.py
+# From https://github.com/pleiszenburg/hapsira/blob/main/contrib/satgpio.py
"""
Author: Juan Luis Cano Rodríguez
diff --git a/docs/source/examples/natural-and-artificial-perturbations.myst.md b/docs/source/examples/natural-and-artificial-perturbations.myst.md
index 0f80764f2..0a64d80a3 100644
--- a/docs/source/examples/natural-and-artificial-perturbations.myst.md
+++ b/docs/source/examples/natural-and-artificial-perturbations.myst.md
@@ -4,7 +4,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.14.1
+ jupytext_version: 1.16.0
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -24,13 +24,15 @@ from astropy import units as u
from hapsira.bodies import Earth, Moon
from hapsira.constants import rho0_earth, H0_earth
-from hapsira.core.elements import rv2coe
+from hapsira.core.elements import rv2coe_gf, RV2COE_TOL
+from hapsira.core.jit import djit, hjit
+from hapsira.core.math.linalg import add_VV_hf
from hapsira.core.perturbations import (
- atmospheric_drag_exponential,
- third_body,
- J2_perturbation,
+ atmospheric_drag_exponential_hf,
+ third_body_hf,
+ J2_perturbation_hf,
)
-from hapsira.core.propagation import func_twobody
+from hapsira.core.propagation.base import func_twobody_hf
from hapsira.ephem import build_ephem_interpolant
from hapsira.plotting import OrbitPlotter
from hapsira.plotting.orbit.backends import Plotly3D
@@ -64,12 +66,13 @@ H0 = H0_earth.to(u.km).value
tofs = TimeDelta(np.linspace(0 * u.h, 100000 * u.s, num=2000))
-
-def f(t0, state, k):
- du_kep = func_twobody(t0, state, k)
- ax, ay, az = atmospheric_drag_exponential(
+@djit
+def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ a = atmospheric_drag_exponential_hf(
t0,
- state,
+ rr,
+ vv,
k,
R=R,
C_D=C_D,
@@ -77,13 +80,10 @@ def f(t0, state, k):
H0=H0,
rho0=rho0,
)
- du_ad = np.array([0, 0, 0, ax, ay, az])
-
- return du_kep + du_ad
-
+ return du_kep_rr, add_VV_hf(du_kep_vv, a)
rr, _ = orbit.to_ephem(
- EpochsArray(orbit.epoch + tofs, method=CowellPropagator(f=f)),
+ EpochsArray(orbit.epoch + tofs, method=CowellPropagator(f=f_hf)),
).rv()
```
@@ -116,7 +116,7 @@ events = [lithobrake_event]
rr, _ = orbit.to_ephem(
EpochsArray(
- orbit.epoch + tofs, method=CowellPropagator(f=f, events=events)
+ orbit.epoch + tofs, method=CowellPropagator(f=f_hf, events=events)
),
).rv()
@@ -144,26 +144,26 @@ v0 = np.array([-7.36138, -2.98997, 1.64354]) * u.km / u.s
orbit = Orbit.from_vectors(Earth, r0, v0)
tofs = TimeDelta(np.linspace(0, 48.0 * u.h, num=2000))
-
-
-def f(t0, state, k):
- du_kep = func_twobody(t0, state, k)
- ax, ay, az = J2_perturbation(
- t0, state, k, J2=Earth.J2.value, R=Earth.R.to(u.km).value
+_J2 = Earth.J2.value
+_R = Earth.R.to(u.km).value
+
+@djit
+def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ a = J2_perturbation_hf(
+ t0, rr, vv, k, J2=_J2, R=_R
)
- du_ad = np.array([0, 0, 0, ax, ay, az])
-
- return du_kep + du_ad
+ return du_kep_rr, add_VV_hf(du_kep_vv, a)
rr, vv = orbit.to_ephem(
- EpochsArray(orbit.epoch + tofs, method=CowellPropagator(f=f)),
+ EpochsArray(orbit.epoch + tofs, method=CowellPropagator(f=f_hf)),
).rv()
# This will be easier to compute when this is solved:
-# https://github.com/hapsira/hapsira/issues/380
+# https://github.com/poliastro/poliastro/issues/380
raans = [
- rv2coe(k, r, v)[3]
+ rv2coe_gf(k, r, v, RV2COE_TOL)[3]
for r, v in zip(rr.to_value(u.km), vv.to_value(u.km / u.s))
]
```
@@ -205,33 +205,33 @@ initial = Orbit.from_classical(
)
tofs = TimeDelta(np.linspace(0, 60 * u.day, num=1000))
+_moon_k = Moon.k.to(u.km**3 / u.s**2).value
-
-def f(t0, state, k):
- du_kep = func_twobody(t0, state, k)
- ax, ay, az = third_body(
+@djit(cache = False)
+def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ a = third_body_hf(
t0,
- state,
+ rr,
+ vv,
k,
- k_third=400 * Moon.k.to(u.km**3 / u.s**2).value,
+ k_third=400 * _moon_k,
perturbation_body=body_r,
)
- du_ad = np.array([0, 0, 0, ax, ay, az])
-
- return du_kep + du_ad
+ return du_kep_rr, add_VV_hf(du_kep_vv, a)
# multiply Moon gravity by 400 so that effect is visible :)
ephem = initial.to_ephem(
- EpochsArray(initial.epoch + tofs, method=CowellPropagator(rtol=1e-6, f=f)),
+ EpochsArray(initial.epoch + tofs, method=CowellPropagator(rtol=1e-6, f=f_hf)),
)
```
```{code-cell} ipython3
frame = OrbitPlotter(backend=Plotly3D())
-
frame.set_attractor(Earth)
frame.plot_ephem(ephem, label="orbit influenced by Moon")
+frame.show()
```
## Applying thrust
@@ -260,33 +260,31 @@ orb0 = Orbit.from_classical(
epoch=Time(0, format="jd", scale="tdb"),
)
-a_d, _, t_f = change_ecc_inc(orb0, ecc_f, inc_f, f)
-
+a_d_hf, _, t_f = change_ecc_inc(orb0, ecc_f, inc_f, f)
-def f(t0, state, k):
- du_kep = func_twobody(t0, state, k)
- ax, ay, az = a_d(
+@djit
+def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ a = a_d_hf(
t0,
- state,
+ rr,
+ vv,
k,
)
- du_ad = np.array([0, 0, 0, ax, ay, az])
-
- return du_kep + du_ad
-
+ return du_kep_rr, add_VV_hf(du_kep_vv, a)
tofs = TimeDelta(np.linspace(0, t_f, num=1000))
ephem2 = orb0.to_ephem(
- EpochsArray(orb0.epoch + tofs, method=CowellPropagator(rtol=1e-6, f=f)),
+ EpochsArray(orb0.epoch + tofs, method=CowellPropagator(rtol=1e-6, f=f_hf)),
)
```
```{code-cell} ipython3
frame = OrbitPlotter(backend=Plotly3D())
-
frame.set_attractor(Earth)
frame.plot_ephem(ephem2, label="orbit with artificial thrust")
+frame.show()
```
## Combining multiple perturbations
@@ -294,13 +292,13 @@ frame.plot_ephem(ephem2, label="orbit with artificial thrust")
It might be of interest to determine what effect multiple perturbations have on a single object. In order to add multiple perturbations we can create a custom function that adds them up:
```{code-cell} ipython3
-from numba import njit as jit
-
-# Add @jit for speed!
-@jit
-def a_d(t0, state, k, J2, R, C_D, A_over_m, H0, rho0):
- return J2_perturbation(t0, state, k, J2, R) + atmospheric_drag_exponential(
- t0, state, k, R, C_D, A_over_m, H0, rho0
+@hjit("V(f,V,V,f,f,f,f,f,f,f)")
+def a_d_hf(t0, rr, vv, k, J2, R, C_D, A_over_m, H0, rho0):
+ return add_VV_hf(
+ J2_perturbation_hf(t0, rr, vv, k, J2, R),
+ atmospheric_drag_exponential_hf(
+ t0, rr, vv, k, R, C_D, A_over_m, H0, rho0
+ )
)
```
@@ -308,52 +306,50 @@ def a_d(t0, state, k, J2, R, C_D, A_over_m, H0, rho0):
# propagation times of flight and orbit
tofs = TimeDelta(np.linspace(0, 10 * u.day, num=10 * 500))
orbit = Orbit.circular(Earth, 250 * u.km) # recall orbit from drag example
+_J2 = Earth.J2.value
-
-def f(t0, state, k):
- du_kep = func_twobody(t0, state, k)
- ax, ay, az = a_d(
+@djit
+def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ a = a_d_hf(
t0,
- state,
+ rr,
+ vv,
k,
- R=R,
- C_D=C_D,
- A_over_m=A_over_m,
- H0=H0,
- rho0=rho0,
- J2=Earth.J2.value,
+ _J2,
+ R,
+ C_D,
+ A_over_m,
+ H0,
+ rho0,
)
- du_ad = np.array([0, 0, 0, ax, ay, az])
-
- return du_kep + du_ad
-
+ return du_kep_rr, add_VV_hf(du_kep_vv, a)
# propagate with J2 and atmospheric drag
rr3, _ = orbit.to_ephem(
- EpochsArray(orbit.epoch + tofs, method=CowellPropagator(f=f)),
+ EpochsArray(orbit.epoch + tofs, method=CowellPropagator(f=f_hf)),
).rv()
-
-def f(t0, state, k):
- du_kep = func_twobody(t0, state, k)
- ax, ay, az = atmospheric_drag_exponential(
+@djit
+def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ a = atmospheric_drag_exponential_hf(
t0,
- state,
+ rr,
+ vv,
k,
- R=R,
- C_D=C_D,
- A_over_m=A_over_m,
- H0=H0,
- rho0=rho0,
+ R,
+ C_D,
+ A_over_m,
+ H0,
+ rho0,
)
- du_ad = np.array([0, 0, 0, ax, ay, az])
-
- return du_kep + du_ad
+ return du_kep_rr, add_VV_hf(du_kep_vv, a)
# propagate with only atmospheric drag
rr4, _ = orbit.to_ephem(
- EpochsArray(orbit.epoch + tofs, method=CowellPropagator(f=f)),
+ EpochsArray(orbit.epoch + tofs, method=CowellPropagator(f=f_hf)),
).rv()
```
diff --git a/docs/source/examples/porkchops-with-poliastro.myst.md b/docs/source/examples/porkchops-with-hapsira.myst.md
similarity index 100%
rename from docs/source/examples/porkchops-with-poliastro.myst.md
rename to docs/source/examples/porkchops-with-hapsira.myst.md
diff --git a/docs/source/examples/propagation-using-cowells-formulation.myst.md b/docs/source/examples/propagation-using-cowells-formulation.myst.md
index 74c87eb21..3019474d9 100644
--- a/docs/source/examples/propagation-using-cowells-formulation.myst.md
+++ b/docs/source/examples/propagation-using-cowells-formulation.myst.md
@@ -4,7 +4,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.14.1
+ jupytext_version: 1.16.0
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -25,7 +25,7 @@ $$\ddot{\mathbb{r}} = -\frac{\mu}{|\mathbb{r}|^3} \mathbb{r} + \mathbb{a}_d$$
+++
-
An earlier version of this notebook allowed for more flexibility and interactivity, but was considerably more complex. Future versions of hapsira and plotly might bring back part of that functionality, depending on user feedback. You can still download the older version [here](https://github.com/hapsira/hapsira/blob/0.8.x/docs/source/examples/Propagation%20using%20Cowell's%20formulation.ipynb).
+An earlier version of this notebook allowed for more flexibility and interactivity, but was considerably more complex. Future versions of hapsira and plotly might bring back part of that functionality, depending on user feedback. You can still download the older version [here](https://github.com/pleiszenburg/hapsira/blob/0.8.x/docs/source/examples/Propagation%20using%20Cowell's%20formulation.ipynb).
+++
@@ -40,7 +40,9 @@ from astropy import units as u
import numpy as np
from hapsira.bodies import Earth
-from hapsira.core.propagation import func_twobody
+from hapsira.core.jit import djit, hjit
+from hapsira.core.math.linalg import add_VV_hf, mul_Vs_hf, norm_V_hf
+from hapsira.core.propagation.base import func_twobody_hf
from hapsira.examples import iss
from hapsira.plotting import OrbitPlotter
from hapsira.plotting.orbit.backends import Plotly3D
@@ -57,22 +59,24 @@ accel = 2e-5
```
```{code-cell} ipython3
-def constant_accel_factory(accel):
- def constant_accel(t0, u, k):
- v = u[3:]
- norm_v = (v[0] ** 2 + v[1] ** 2 + v[2] ** 2) ** 0.5
- return accel * v / norm_v
+def constant_accel_hb(accel):
- return constant_accel
+ @hjit("V(f,V,V,f)", cache = False)
+ def constant_accel_hf(t0, rr, vv, k):
+ norm_v = norm_V_hf(vv)
+ return mul_Vs_hf(vv, accel / norm_v)
+
+ return constant_accel_hf
```
```{code-cell} ipython3
-def f(t0, state, k):
- du_kep = func_twobody(t0, state, k)
- ax, ay, az = constant_accel_factory(accel)(t0, state, k)
- du_ad = np.array([0, 0, 0, ax, ay, az])
+constant_accel_hf = constant_accel_hb(accel)
- return du_kep + du_ad
+@djit
+def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ a = constant_accel_hf(t0, rr, vv, k)
+ return du_kep_rr, add_VV_hf(du_kep_vv, a)
```
```{code-cell} ipython3
@@ -82,7 +86,7 @@ times
```{code-cell} ipython3
ephem = iss.to_ephem(
- EpochsArray(iss.epoch + times, method=CowellPropagator(rtol=1e-11, f=f)),
+ EpochsArray(iss.epoch + times, method=CowellPropagator(rtol=1e-11, f=f_hf)),
)
```
@@ -93,6 +97,7 @@ frame = OrbitPlotter(backend=Plotly3D())
frame.set_attractor(Earth)
frame.plot_ephem(ephem, label="ISS")
+frame.show()
```
## Error checking
@@ -165,18 +170,15 @@ So let's create a new circular orbit and perform the necessary checks, assuming
orb = Orbit.circular(Earth, 500 << u.km)
tof = 20 * orb.period
-ad = constant_accel_factory(1e-7)
-
-
-def f(t0, state, k):
- du_kep = func_twobody(t0, state, k)
- ax, ay, az = ad(t0, state, k)
- du_ad = np.array([0, 0, 0, ax, ay, az])
-
- return du_kep + du_ad
+ad_hf = constant_accel_hb(1e-7)
+@djit
+def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ a = ad_hf(t0, rr, vv, k)
+ return du_kep_rr, add_VV_hf(du_kep_vv, a)
-orb_final = orb.propagate(tof, method=CowellPropagator(f=f))
+orb_final = orb.propagate(tof, method=CowellPropagator(f=f_hf))
```
```{code-cell} ipython3
diff --git a/docs/source/examples/revisiting-lamberts-problem-in-python.myst.md b/docs/source/examples/revisiting-lamberts-problem-in-python.myst.md
index d931ea226..b96a1cf6f 100644
--- a/docs/source/examples/revisiting-lamberts-problem-in-python.myst.md
+++ b/docs/source/examples/revisiting-lamberts-problem-in-python.myst.md
@@ -4,9 +4,9 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.14.0
+ jupytext_version: 1.16.0
kernelspec:
- display_name: Python 3
+ display_name: Python 3 (ipykernel)
language: python
name: python3
---
@@ -15,25 +15,25 @@ kernelspec:
The Izzo algorithm to solve the Lambert problem is available in hapsira and was implemented from [this paper](https://arxiv.org/abs/1403.2705).
-```{code-cell}
+```{code-cell} ipython3
from cycler import cycler
from matplotlib import pyplot as plt
import numpy as np
-from hapsira.core import iod
+from hapsira.core.iod import compute_y_vf, tof_equation_y_vf, compute_T_min_gf, find_xy_gf
from hapsira.iod import izzo
```
## Part 1: Reproducing the original figure
-```{code-cell}
+```{code-cell} ipython3
x = np.linspace(-1, 2, num=1000)
M_list = 0, 1, 2, 3
ll_list = 1, 0.9, 0.7, 0, -0.7, -0.9, -1
```
-```{code-cell}
+```{code-cell} ipython3
fig, ax = plt.subplots(figsize=(10, 8))
ax.set_prop_cycle(
cycler("linestyle", ["-", "--"])
@@ -43,8 +43,8 @@ for M in M_list:
for ll in ll_list:
T_x0 = np.zeros_like(x)
for ii in range(len(x)):
- y = iod._compute_y(x[ii], ll)
- T_x0[ii] = iod._tof_equation_y(x[ii], y, 0.0, ll, M)
+ y = compute_y_vf(x[ii], ll)
+ T_x0[ii] = tof_equation_y_vf(x[ii], y, 0.0, ll, M)
if M == 0 and ll == 1:
T_x0[x > 0] = np.nan
elif M > 0:
@@ -86,12 +86,12 @@ ax.set_ylabel("$T$")
## Part 2: Locating $T_{min}$
-```{code-cell}
+```{code-cell} ipython3
:tags: [nbsphinx-thumbnail]
for M in M_list:
for ll in ll_list:
- x_T_min, T_min = iod._compute_T_min(ll, M, 10, 1e-8)
+ x_T_min, T_min = compute_T_min_gf(ll, M, 10, 1e-8)
ax.plot(x_T_min, T_min, "kx", mew=2)
fig
@@ -99,17 +99,21 @@ fig
## Part 3: Try out solution
-```{code-cell}
+```{code-cell} ipython3
T_ref = 1
ll_ref = 0
-x_ref, _ = iod._find_xy(
- ll_ref, T_ref, M=0, numiter=10, lowpath=True, rtol=1e-8
+x_ref, _ = find_xy_gf(
+ ll_ref, T_ref,
+ 0, # M
+ 10, # numiter
+ True, # lowpath
+ 1e-8, # rtol
)
x_ref
```
-```{code-cell}
+```{code-cell} ipython3
ax.plot(x_ref, T_ref, "o", mew=2, mec="red", mfc="none")
fig
@@ -117,7 +121,7 @@ fig
## Part 4: Run some examples
-```{code-cell}
+```{code-cell} ipython3
from astropy import units as u
from hapsira.bodies import Earth
@@ -125,7 +129,7 @@ from hapsira.bodies import Earth
### Single revolution
-```{code-cell}
+```{code-cell} ipython3
k = Earth.k
r0 = [15945.34, 0.0, 0.0] * u.km
r = [12214.83399, 10249.46731, 0.0] * u.km
@@ -138,7 +142,7 @@ v0, v = izzo.lambert(k, r0, r, tof)
v
```
-```{code-cell}
+```{code-cell} ipython3
k = Earth.k
r0 = [5000.0, 10000.0, 2100.0] * u.km
r = [-14600.0, 2500.0, 7000.0] * u.km
@@ -153,7 +157,7 @@ v
### Multiple revolutions
-```{code-cell}
+```{code-cell} ipython3
k = Earth.k
r0 = [22592.145603, -1599.915239, -19783.950506] * u.km
r = [1922.067697, 4054.157051, -8925.727465] * u.km
@@ -169,20 +173,20 @@ expected_va_r = [-2.45759553, 1.16945801, 0.43161258] * u.km / u.s
expected_vb_r = [-5.53841370, 0.01822220, 5.49641054] * u.km / u.s
```
-```{code-cell}
+```{code-cell} ipython3
v0, v = izzo.lambert(k, r0, r, tof, M=0)
v
```
-```{code-cell}
+```{code-cell} ipython3
_, v_l = izzo.lambert(k, r0, r, tof, M=1, lowpath=True)
_, v_r = izzo.lambert(k, r0, r, tof, M=1, lowpath=False)
```
-```{code-cell}
+```{code-cell} ipython3
v_l
```
-```{code-cell}
+```{code-cell} ipython3
v_r
```
diff --git a/docs/source/index.md b/docs/source/index.md
index 581eba39d..c60ecd940 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -161,6 +161,7 @@ caption: How-to guides & Examples
---
gallery
contributing
+core
```
```{toctree}
diff --git a/docs/source/quickstart.md b/docs/source/quickstart.md
index 10033448e..e89cfd278 100644
--- a/docs/source/quickstart.md
+++ b/docs/source/quickstart.md
@@ -155,6 +155,7 @@ To explore different propagation algorithms, check out the {py:mod}`hapsira.twob
The `propagate` method gives you the final orbit at the epoch you designated. To retrieve the whole trajectory instead, you can use {py:meth}`hapsira.twobody.orbit.scalar.Orbit.to_ephem`, which returns an {{ Ephem }} instance:
```python
+from astropy.time import Time
from hapsira.twobody.sampling import EpochsArray, TrueAnomalyBounds, EpochBounds
from hapsira.util import time_range
@@ -165,7 +166,7 @@ end_date = Time("2022-07-11 07:05", scale="utc")
ephem1 = iss.to_ephem()
# Explicit times given
-ephem2 = iss.to_ephem(strategy=EpochsArray(epochs=time_range(start_date, end_date)))
+ephem2 = iss.to_ephem(strategy=EpochsArray(epochs=time_range(start_date, end=end_date)))
# Automatic grid, true anomaly limits
ephem3 = iss.to_ephem(strategy=TrueAnomalyBounds(min_nu=0 << u.deg, max_nu=180 << u.deg))
@@ -193,44 +194,45 @@ ephem4 = iss.to_ephem(strategy=EpochBounds(min_epoch=start_date, max_epoch=end_d
Apart from the Keplerian propagators, hapsira also allows you to define custom perturbation accelerations to study non Keplerian orbits, thanks to Cowell's method:
```python
->>> from numba import njit
>>> import numpy as np
->>> from hapsira.core.propagation import func_twobody
+>>> from hapsira.core.jit import hjit, djit
+>>> from hapsira.core.math.linalg import add_VV_hf, mul_Vs_hf, norm_V_hf
+>>> from hapsira.core.propagation.base import func_twobody_hf
>>> from hapsira.twobody.propagation import CowellPropagator
>>> r0 = [-2384.46, 5729.01, 3050.46] << u.km
>>> v0 = [-7.36138, -2.98997, 1.64354] << (u.km / u.s)
>>> initial = Orbit.from_vectors(Earth, r0, v0)
->>> @njit
-... def accel(t0, state, k):
+>>> @hjit("V(f,V,V,f)")
+... def accel_hf(t0, rr, vv, k):
... """Constant acceleration aligned with the velocity. """
-... v_vec = state[3:]
-... norm_v = (v_vec * v_vec).sum() ** 0.5
-... return 1e-5 * v_vec / norm_v
+... norm_v = norm_V_hf(vv)
+... return mul_Vs_hf(vv, 1e-5 / norm_v)
...
-... def f(t0, u_, k):
-... du_kep = func_twobody(t0, u_, k)
-... ax, ay, az = accel(t0, u_, k)
-... du_ad = np.array([0, 0, 0, ax, ay, az])
-... return du_kep + du_ad
+... @djit
+... def f_hf(t0, rr, vv, k):
+... du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+... a = accel_hf(t0, rr, vv, k)
+... return du_kep_rr, add_VV_hf(du_kep_vv, a)
->>> initial.propagate(3 << u.day, method=CowellPropagator(f=f))
+>>> initial.propagate(3 << u.day, method=CowellPropagator(f=f_hf))
18255 x 21848 km x 28.0 deg (GCRS) orbit around Earth (♁) at epoch J2000.008 (TT)
```
Some natural perturbations are available in hapsira to be used directly in this way. For instance, to examine the effect of J2 perturbation:
```python
->>> from hapsira.core.perturbations import J2_perturbation
->>> tofs = [48.0] << u.h
->>> def f(t0, u_, k):
-... du_kep = func_twobody(t0, u_, k)
-... ax, ay, az = J2_perturbation(
-... t0, u_, k, J2=Earth.J2.value, R=Earth.R.to(u.km).value
+>>> from hapsira.core.perturbations import J2_perturbation_hf
+>>> tofs = 48.0 << u.h
+>>> _J2, _R = Earth.J2.value, Earth.R.to(u.km).value
+>>> @djit
+... def f_hf(t0, rr, vv, k):
+... du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+... a = J2_perturbation_hf(
+... t0, rr, vv, k, _J2, _R
... )
-... du_ad = np.array([0, 0, 0, ax, ay, az])
-... return du_kep + du_ad
+... return du_kep_rr, add_VV_hf(du_kep_vv, a)
->>> final = initial.propagate(tofs, method=CowellPropagator(f=f))
+>>> final = initial.propagate(tofs, method=CowellPropagator(f=f_hf))
```
The J2 perturbation changes the orbit parameters (from Curtis example 12.2):
@@ -247,6 +249,7 @@ The J2 perturbation changes the orbit parameters (from Curtis example 12.2):
In addition to natural perturbations, hapsira also has built-in artificial perturbations (thrust guidance laws) aimed at intentional change of some orbital elements. For example, to simultaneously change eccentricity and inclination:
```python
+>>> from hapsira.twobody.thrust import change_ecc_inc
>>> ecc_0, ecc_f = [0.4, 0.0] << u.one
>>> a = 42164 << u.km
>>> inc_0 = 0.0 << u.deg # baseline
@@ -260,20 +263,20 @@ In addition to natural perturbations, hapsira also has built-in artificial pertu
... a,
... ecc_0,
... inc_0,
-... 0,
+... 0 << u.deg,
... argp,
-... 0,
+... 0 << u.deg,
... )
->>> a_d, _, t_f = change_ecc_inc(orb0, ecc_f, inc_f, f)
+>>> a_d_hf, _, t_f = change_ecc_inc(orb0, ecc_f, inc_f, f)
# Propagate orbit
->>> def f_geo(t0, u_, k):
-... du_kep = func_twobody(t0, u_, k)
-... ax, ay, az = a_d(t0, u_, k)
-... du_ad = np.array([0, 0, 0, ax, ay, az])
-... return du_kep + du_ad
+>>> @djit
+... def f_geo_hf(t0, rr, vv, k):
+... du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+... a = a_d_hf(t0, rr, vv, k)
+... return du_kep_rr, add_VV_hf(du_kep_vv, a)
->>> orbf = orb0.propagate(t_f << u.s, method=CowellPropagator(f=f_geo, rtol=1e-8))
+>>> orbf = orb0.propagate(t_f << u.s, method=CowellPropagator(f=f_geo_hf, rtol=1e-8))
```
The thrust changes orbit parameters as desired (within errors):
@@ -346,8 +349,9 @@ To easily visualize several orbits in two dimensions, you can run this code:
```python
from hapsira.plotting import OrbitPlotter
+from hapsira.plotting.orbit.backends import Matplotlib2D
-op = OrbitPlotter(backend_name="matplotlib2D")
+op = OrbitPlotter(backend=Matplotlib2D())
orb_a, orb_f = orb_i.apply_maneuver(hoh, intermediate=True)
op.plot(orb_i, label="Initial orbit")
op.plot(orb_a, label="Transfer orbit")
@@ -377,7 +381,7 @@ The {py:class}`hapsira.ephem.Ephem` class allows you to retrieve a planetary orb
```python
>>> from astropy.time import Time
->>> epoch = time.Time("2020-04-29 10:43") # UTC by default
+>>> epoch = Time("2020-04-29 10:43") # UTC by default
>>> from hapsira.ephem import Ephem
>>> earth = Ephem.from_body(Earth, epoch.tdb)
>>> earth
@@ -440,9 +444,9 @@ And these are the results:
```python
>>> dv_a
-(, )
+(, )
>>> dv_b
-(, )
+(, )
```
```{figure} _static/msl.png
diff --git a/pyproject.toml b/pyproject.toml
index 441079dde..403d9199f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -131,6 +131,7 @@ select = [
ignore = [
"E501", # Line too long. Ignoring this so "black" manages line length.
+ "I001", # import sort
]
exclude = [
@@ -143,7 +144,7 @@ exclude = [
[tool.ruff.isort]
combine-as-imports = true
-force-sort-within-sections = true
+force-sort-within-sections = false
known-first-party = ["hapsira"]
[tool.ruff.mccabe]
@@ -176,6 +177,10 @@ name = "hapsira.core does not import astropy.units"
type = "forbidden"
source_modules = ["hapsira.core"]
forbidden_modules = ["astropy.units"]
+ignore_imports = [
+ "hapsira.core.earth.atmosphere.coesa76 -> astropy.io.ascii",
+ "hapsira.core.earth.atmosphere.coesa76 -> astropy.utils.data"
+]
[tool.pytest.ini_options]
norecursedirs = [".git", ".tox", "dist", "build", ".venv"]
diff --git a/src/hapsira/_math/interpolate.py b/src/hapsira/_math/interpolate.py
deleted file mode 100644
index 711b1579b..000000000
--- a/src/hapsira/_math/interpolate.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import numpy as np
-from scipy.interpolate import interp1d
-
-__all__ = ["interp1d", "spline_interp", "sinc_interp"]
-
-
-def spline_interp(y, x, u, *, kind="cubic"):
- """Interpolates y, sampled at x instants, at u instants using `scipy.interpolate.interp1d`."""
- y_u = interp1d(x, y, kind=kind)(u)
- return y_u
-
-
-def sinc_interp(y, x, u):
- """Interpolates y, sampled at x instants, at u instants using sinc interpolation.
-
- Notes
- -----
- Taken from https://gist.github.com/endolith/1297227.
- Possibly equivalent to `scipy.signal.resample`,
- see https://mail.python.org/pipermail/scipy-user/2012-January/031255.html.
- However, quick experiments show different ringing behavior.
-
- """
- if len(y) != len(x):
- raise ValueError("x and s must be the same length")
-
- # Find the period and assume it's constant
- T = x[1] - x[0]
-
- sincM = np.tile(u, (len(x), 1)) - np.tile(x[:, np.newaxis], (1, len(u)))
- y_u = y @ np.sinc(sincM / T)
-
- return y_u
diff --git a/src/hapsira/_math/ivp.py b/src/hapsira/_math/ivp.py
deleted file mode 100644
index f5b9c01b1..000000000
--- a/src/hapsira/_math/ivp.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from scipy.integrate import DOP853, solve_ivp
-
-__all__ = ["DOP853", "solve_ivp"]
diff --git a/src/hapsira/_math/linalg.py b/src/hapsira/_math/linalg.py
deleted file mode 100644
index a2845a43e..000000000
--- a/src/hapsira/_math/linalg.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from numba import njit as jit
-import numpy as np
-
-
-@jit
-def norm(arr):
- return np.sqrt(arr @ arr)
diff --git a/src/hapsira/_math/optimize.py b/src/hapsira/_math/optimize.py
deleted file mode 100644
index 9ddcedca7..000000000
--- a/src/hapsira/_math/optimize.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from scipy.optimize import brentq
-
-__all__ = ["brentq"]
diff --git a/src/hapsira/_math/special.py b/src/hapsira/_math/special.py
deleted file mode 100644
index 6f704a965..000000000
--- a/src/hapsira/_math/special.py
+++ /dev/null
@@ -1,88 +0,0 @@
-from math import gamma
-
-from numba import njit as jit
-import numpy as np
-
-
-@jit
-def hyp2f1b(x):
- """Hypergeometric function 2F1(3, 1, 5/2, x), see [Battin].
-
- .. todo::
- Add more information about this function
-
- Notes
- -----
- More information about hypergeometric function can be checked at
- https://en.wikipedia.org/wiki/Hypergeometric_function
-
- """
- if x >= 1.0:
- return np.inf
- else:
- res = 1.0
- term = 1.0
- ii = 0
- while True:
- term = term * (3 + ii) * (1 + ii) / (5 / 2 + ii) * x / (ii + 1)
- res_old = res
- res += term
- if res_old == res:
- return res
- ii += 1
-
-
-@jit
-def stumpff_c2(psi):
- r"""Second Stumpff function.
-
- For positive arguments:
-
- .. math::
-
- c_2(\psi) = \frac{1 - \cos{\sqrt{\psi}}}{\psi}
-
- """
- eps = 1.0
- if psi > eps:
- res = (1 - np.cos(np.sqrt(psi))) / psi
- elif psi < -eps:
- res = (np.cosh(np.sqrt(-psi)) - 1) / (-psi)
- else:
- res = 1.0 / 2.0
- delta = (-psi) / gamma(2 + 2 + 1)
- k = 1
- while res + delta != res:
- res = res + delta
- k += 1
- delta = (-psi) ** k / gamma(2 * k + 2 + 1)
-
- return res
-
-
-@jit
-def stumpff_c3(psi):
- r"""Third Stumpff function.
-
- For positive arguments:
-
- .. math::
-
- c_3(\psi) = \frac{\sqrt{\psi} - \sin{\sqrt{\psi}}}{\sqrt{\psi^3}}
-
- """
- eps = 1.0
- if psi > eps:
- res = (np.sqrt(psi) - np.sin(np.sqrt(psi))) / (psi * np.sqrt(psi))
- elif psi < -eps:
- res = (np.sinh(np.sqrt(-psi)) - np.sqrt(-psi)) / (-psi * np.sqrt(-psi))
- else:
- res = 1.0 / 6.0
- delta = (-psi) / gamma(2 + 3 + 1)
- k = 1
- while res + delta != res:
- res = res + delta
- k += 1
- delta = (-psi) ** k / gamma(2 * k + 3 + 1)
-
- return res
diff --git a/src/hapsira/core/angles.py b/src/hapsira/core/angles.py
index 5e139f99f..92b44c354 100644
--- a/src/hapsira/core/angles.py
+++ b/src/hapsira/core/angles.py
@@ -1,53 +1,193 @@
-from numba import njit as jit
-import numpy as np
+from math import (
+ asinh,
+ atan,
+ atan2,
+ atanh,
+ cos,
+ cosh,
+ nan,
+ pi,
+ sin,
+ sinh,
+ sqrt,
+ tan,
+ tanh,
+)
+from .jit import hjit, vjit
+
+
+_TOL = 1.48e-08
+
+
+__all__ = [
+ "E_to_M_hf",
+ "E_to_M_vf",
+ "F_to_M_hf",
+ "F_to_M_vf",
+ "kepler_equation_hf",
+ "kepler_equation_prime_hf",
+ "kepler_equation_hyper_hf",
+ "kepler_equation_prime_hyper_hf",
+ "D_to_nu_hf",
+ "D_to_nu_vf",
+ "nu_to_D_hf",
+ "nu_to_D_vf",
+ "nu_to_E_hf",
+ "nu_to_E_vf",
+ "nu_to_F_hf",
+ "nu_to_F_vf",
+ "E_to_nu_hf",
+ "E_to_nu_vf",
+ "F_to_nu_hf",
+ "F_to_nu_vf",
+ "M_to_E_hf",
+ "M_to_E_vf",
+ "M_to_F_hf",
+ "M_to_F_vf",
+ "M_to_D_hf",
+ "M_to_D_vf",
+ "D_to_M_hf",
+ "D_to_M_vf",
+ "fp_angle_hf",
+ "fp_angle_vf",
+]
+
+
+@hjit("f(f,f)", inline=True)
+def E_to_M_hf(E, ecc):
+ r"""Mean anomaly from eccentric anomaly.
-@jit
-def _kepler_equation(E, M, ecc):
- return E_to_M(E, ecc) - M
+ .. versionadded:: 0.4.0
+ Parameters
+ ----------
+ E : float
+ Eccentric anomaly in radians.
+ ecc : float
+ Eccentricity.
-@jit
-def _kepler_equation_prime(E, M, ecc):
- return 1 - ecc * np.cos(E)
+ Returns
+ -------
+ M : float
+ Mean anomaly.
+ Warnings
+ --------
+ The mean anomaly will be outside of (-π, π]
+ if the eccentric anomaly is.
+ No validation or wrapping is performed.
-@jit
-def _kepler_equation_hyper(F, M, ecc):
- return F_to_M(F, ecc) - M
+ Notes
+ -----
+ The implementation uses the plain original Kepler equation:
+ .. math::
+ M = E - e \sin{E}
-@jit
-def _kepler_equation_prime_hyper(F, M, ecc):
- return ecc * np.cosh(F) - 1
+ """
+ M = E - ecc * sin(E)
+ return M
-def newton_factory(func, fprime):
- @jit
- def jit_newton_wrapper(x0, args=(), tol=1.48e-08, maxiter=50):
- p0 = float(x0)
- for _ in range(maxiter):
- fval = func(p0, *args)
- fder = fprime(p0, *args)
- newton_step = fval / fder
- p = p0 - newton_step
- if abs(p - p0) < tol:
- return p
- p0 = p
+@vjit("f(f,f)")
+def E_to_M_vf(E, ecc):
+ """
+ Vectorized E_to_M
+ """
- return np.nan
+ return E_to_M_hf(E, ecc)
- return jit_newton_wrapper
+@hjit("f(f,f)", inline=True)
+def F_to_M_hf(F, ecc):
+ r"""Mean anomaly from hyperbolic anomaly.
+
+ Parameters
+ ----------
+ F : float
+ Hyperbolic anomaly.
+ ecc : float
+ Eccentricity (>1).
+
+ Returns
+ -------
+ M : float
+ Mean anomaly.
+
+ Notes
+ -----
+ As noted in [5]_, by manipulating
+ the parametric equations of the hyperbola
+ we can derive a quantity that is equivalent
+ to the mean anomaly in the elliptic case:
+
+ .. math::
+
+ M = e \sinh{F} - F
+
+ """
+ M = ecc * sinh(F) - F
+ return M
+
+
+@vjit("f(f,f)")
+def F_to_M_vf(F, ecc):
+ """
+ Vectorized F_to_M
+ """
+
+ return F_to_M_hf(F, ecc)
+
+
+@hjit("f(f,f,f)", inline=True)
+def kepler_equation_hf(E, M, ecc):
+ return E_to_M_hf(E, ecc) - M
+
+
+@hjit("f(f,f,f)", inline=True)
+def kepler_equation_prime_hf(E, M, ecc):
+ return 1 - ecc * cos(E)
+
+
+@hjit("f(f,f,f)", inline=True)
+def kepler_equation_hyper_hf(F, M, ecc):
+ return F_to_M_hf(F, ecc) - M
+
+
+@hjit("f(f,f,f)", inline=True)
+def kepler_equation_prime_hyper_hf(F, M, ecc):
+ return ecc * cosh(F) - 1
-_newton_elliptic = newton_factory(_kepler_equation, _kepler_equation_prime)
-_newton_hyperbolic = newton_factory(
- _kepler_equation_hyper, _kepler_equation_prime_hyper
-)
+@hjit("f(f,f,f,f,i8)", inline=True)
+def _newton_elliptic_hf(p0, M, ecc, tol, maxiter):
+ for _ in range(maxiter):
+ fval = kepler_equation_hf(p0, M, ecc)
+ fder = kepler_equation_prime_hf(p0, M, ecc)
+ newton_step = fval / fder
+ p = p0 - newton_step
+ if abs(p - p0) < tol:
+ return p
+ p0 = p
+ return nan
-@jit
-def D_to_nu(D):
+
+@hjit("f(f,f,f,f,i8)", inline=True)
+def _newton_hyperbolic_hf(p0, M, ecc, tol, maxiter):
+ for _ in range(maxiter):
+ fval = kepler_equation_hyper_hf(p0, M, ecc)
+ fder = kepler_equation_prime_hyper_hf(p0, M, ecc)
+ newton_step = fval / fder
+ p = p0 - newton_step
+ if abs(p - p0) < tol:
+ return p
+ p0 = p
+ return nan
+
+
+@hjit("f(f)", inline=True)
+def D_to_nu_hf(D):
r"""True anomaly from parabolic anomaly.
Parameters
@@ -69,11 +209,20 @@ def D_to_nu(D):
\nu = 2 \arctan{D}
"""
- return 2.0 * np.arctan(D)
+ return 2 * atan(D)
-@jit
-def nu_to_D(nu):
+@vjit("f(f)")
+def D_to_nu_vf(D):
+ """
+ Vectorized D_to_nu
+ """
+
+ return D_to_nu_hf(D)
+
+
+@hjit("f(f)", inline=True)
+def nu_to_D_hf(nu):
r"""Parabolic anomaly from true anomaly.
Parameters
@@ -121,11 +270,20 @@ def nu_to_D(nu):
"""
# TODO: Rename to B
- return np.tan(nu / 2.0)
+ return tan(nu / 2)
+
+
+@vjit("f(f)")
+def nu_to_D_vf(nu):
+ """
+ Vectorized nu_to_D
+ """
+
+ return nu_to_D_hf(nu)
-@jit
-def nu_to_E(nu, ecc):
+@hjit("f(f,f)", inline=True)
+def nu_to_E_hf(nu, ecc):
r"""Eccentric anomaly from true anomaly.
.. versionadded:: 0.4.0
@@ -156,12 +314,21 @@ def nu_to_E(nu, ecc):
\in (-\pi, \pi]
"""
- E = 2 * np.arctan(np.sqrt((1 - ecc) / (1 + ecc)) * np.tan(nu / 2))
+ E = 2 * atan(sqrt((1 - ecc) / (1 + ecc)) * tan(nu / 2))
return E
-@jit
-def nu_to_F(nu, ecc):
+@vjit("f(f,f)")
+def nu_to_E_vf(nu, ecc):
+ """
+ Vectorized nu_to_E
+ """
+
+ return nu_to_E_hf(nu, ecc)
+
+
+@hjit("f(f,f)", inline=True)
+def nu_to_F_hf(nu, ecc):
r"""Hyperbolic anomaly from true anomaly.
Parameters
@@ -192,12 +359,21 @@ def nu_to_F(nu, ecc):
F = 2 \operatorname{arctanh} \left( \sqrt{\frac{e-1}{e+1}} \tan{\frac{\nu}{2}} \right)
"""
- F = 2 * np.arctanh(np.sqrt((ecc - 1) / (ecc + 1)) * np.tan(nu / 2))
+ F = 2 * atanh(sqrt((ecc - 1) / (ecc + 1)) * tan(nu / 2))
return F
-@jit
-def E_to_nu(E, ecc):
+@vjit("f(f,f)")
+def nu_to_F_vf(nu, ecc):
+ """
+ Vectorized nu_to_F
+ """
+
+ return nu_to_F_hf(nu, ecc)
+
+
+@hjit("f(f,f)", inline=True)
+def E_to_nu_hf(E, ecc):
r"""True anomaly from eccentric anomaly.
.. versionadded:: 0.4.0
@@ -228,12 +404,21 @@ def E_to_nu(E, ecc):
\in (-\pi, \pi]
"""
- nu = 2 * np.arctan(np.sqrt((1 + ecc) / (1 - ecc)) * np.tan(E / 2))
+ nu = 2 * atan(sqrt((1 + ecc) / (1 - ecc)) * tan(E / 2))
return nu
-@jit
-def F_to_nu(F, ecc):
+@vjit("f(f,f)")
+def E_to_nu_vf(E, ecc):
+ """
+ Vectorized E_to_nu
+ """
+
+ return E_to_nu_hf(E, ecc)
+
+
+@hjit("f(f,f)", inline=True)
+def F_to_nu_hf(F, ecc):
r"""True anomaly from hyperbolic anomaly.
Parameters
@@ -257,12 +442,21 @@ def F_to_nu(F, ecc):
\in (-\pi, \pi]
"""
- nu = 2 * np.arctan(np.sqrt((ecc + 1) / (ecc - 1)) * np.tanh(F / 2))
+ nu = 2 * atan(sqrt((ecc + 1) / (ecc - 1)) * tanh(F / 2))
return nu
-@jit
-def M_to_E(M, ecc):
+@vjit("f(f,f)")
+def F_to_nu_vf(F, ecc):
+ """
+ Vectorized F_to_nu
+ """
+
+ return F_to_nu_hf(F, ecc)
+
+
+@hjit("f(f,f)", inline=True)
+def M_to_E_hf(M, ecc):
"""Eccentric anomaly from mean anomaly.
.. versionadded:: 0.4.0
@@ -284,16 +478,25 @@ def M_to_E(M, ecc):
This uses a Newton iteration on the Kepler equation.
"""
- if -np.pi < M < 0 or np.pi < M:
+ if -pi < M < 0 or pi < M:
E0 = M - ecc
else:
E0 = M + ecc
- E = _newton_elliptic(E0, args=(M, ecc))
+ E = _newton_elliptic_hf(E0, M, ecc, _TOL, 50)
return E
-@jit
-def M_to_F(M, ecc):
+@vjit("f(f,f)")
+def M_to_E_vf(M, ecc):
+ """
+ Vectorized M_to_E
+ """
+
+ return M_to_E_hf(M, ecc)
+
+
+@hjit("f(f,f)", inline=True)
+def M_to_F_hf(M, ecc):
"""Hyperbolic anomaly from mean anomaly.
Parameters
@@ -313,13 +516,22 @@ def M_to_F(M, ecc):
This uses a Newton iteration on the hyperbolic Kepler equation.
"""
- F0 = np.arcsinh(M / ecc)
- F = _newton_hyperbolic(F0, args=(M, ecc), maxiter=100)
+ F0 = asinh(M / ecc)
+ F = _newton_hyperbolic_hf(F0, M, ecc, _TOL, 100)
return F
-@jit
-def M_to_D(M):
+@vjit("f(f,f)")
+def M_to_F_vf(M, ecc):
+ """
+ Vectorized M_to_F
+ """
+
+ return M_to_F_hf(M, ecc)
+
+
+@hjit("f(f)", inline=True)
+def M_to_D_hf(M):
"""Parabolic anomaly from mean anomaly.
Parameters
@@ -343,76 +555,17 @@ def M_to_D(M):
return D
-@jit
-def E_to_M(E, ecc):
- r"""Mean anomaly from eccentric anomaly.
-
- .. versionadded:: 0.4.0
-
- Parameters
- ----------
- E : float
- Eccentric anomaly in radians.
- ecc : float
- Eccentricity.
-
- Returns
- -------
- M : float
- Mean anomaly.
-
- Warnings
- --------
- The mean anomaly will be outside of (-π, π]
- if the eccentric anomaly is.
- No validation or wrapping is performed.
-
- Notes
- -----
- The implementation uses the plain original Kepler equation:
-
- .. math::
- M = E - e \sin{E}
-
+@vjit("f(f)")
+def M_to_D_vf(M):
"""
- M = E - ecc * np.sin(E)
- return M
-
-
-@jit
-def F_to_M(F, ecc):
- r"""Mean anomaly from hyperbolic anomaly.
-
- Parameters
- ----------
- F : float
- Hyperbolic anomaly.
- ecc : float
- Eccentricity (>1).
-
- Returns
- -------
- M : float
- Mean anomaly.
-
- Notes
- -----
- As noted in [5]_, by manipulating
- the parametric equations of the hyperbola
- we can derive a quantity that is equivalent
- to the mean anomaly in the elliptic case:
-
- .. math::
-
- M = e \sinh{F} - F
-
+ Vectorized M_to_D
"""
- M = ecc * np.sinh(F) - F
- return M
+ return M_to_D_hf(M)
-@jit
-def D_to_M(D):
+
+@hjit("f(f)", inline=True)
+def D_to_M_hf(D):
r"""Mean anomaly from parabolic anomaly.
Parameters
@@ -444,8 +597,17 @@ def D_to_M(D):
return M
-@jit
-def fp_angle(nu, ecc):
+@vjit("f(f)")
+def D_to_M_vf(D):
+ """
+ Vectorized D_to_M
+ """
+
+ return D_to_M_hf(D)
+
+
+@hjit("f(f,f)", inline=True)
+def fp_angle_hf(nu, ecc):
r"""Returns the flight path angle.
Parameters
@@ -469,4 +631,13 @@ def fp_angle(nu, ecc):
\phi = \arctan(\frac {e \sin{\nu}}{1 + e \cos{\nu}})
"""
- return np.arctan2(ecc * np.sin(nu), 1 + ecc * np.cos(nu))
+ return atan2(ecc * sin(nu), 1 + ecc * cos(nu))
+
+
+@vjit("f(f,f)")
+def fp_angle_vf(nu, ecc):
+ """
+ Vectorized fp_angle
+ """
+
+ return fp_angle_hf(nu, ecc)
diff --git a/src/hapsira/core/czml_utils.py b/src/hapsira/core/czml_utils.py
index 54313646b..4e7a515e1 100644
--- a/src/hapsira/core/czml_utils.py
+++ b/src/hapsira/core/czml_utils.py
@@ -1,7 +1,7 @@
from numba import njit as jit
import numpy as np
-from hapsira._math.linalg import norm
+from .math.linalg import norm_V_hf
@jit
@@ -95,7 +95,7 @@ def project_point_on_ellipsoid(x, y, z, a, b, c):
"""
p1, p2 = intersection_ellipsoid_line(x, y, z, x, y, z, a, b, c)
- norm_1 = norm(np.array([p1[0] - x, p1[1] - y, p1[2] - z]))
- norm_2 = norm(np.array([p2[0] - x, p2[1] - y, p2[2] - z]))
+ norm_1 = norm_V_hf((p1[0] - x, p1[1] - y, p1[2] - z))
+ norm_2 = norm_V_hf((p2[0] - x, p2[1] - y, p2[2] - z))
return p1 if norm_1 <= norm_2 else p2
diff --git a/src/hapsira/core/earth_atmosphere/__init__.py b/src/hapsira/core/earth/__init__.py
similarity index 100%
rename from src/hapsira/core/earth_atmosphere/__init__.py
rename to src/hapsira/core/earth/__init__.py
diff --git a/src/hapsira/core/earth/atmosphere/__init__.py b/src/hapsira/core/earth/atmosphere/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/hapsira/core/earth_atmosphere/util.py b/src/hapsira/core/earth/atmosphere/coesa.py
similarity index 70%
rename from src/hapsira/core/earth_atmosphere/util.py
rename to src/hapsira/core/earth/atmosphere/coesa.py
index b940e4e84..8435d4706 100644
--- a/src/hapsira/core/earth_atmosphere/util.py
+++ b/src/hapsira/core/earth/atmosphere/coesa.py
@@ -1,10 +1,15 @@
"""This script holds several utilities related to atmospheric computations."""
-from numba import njit as jit
+from ...jit import hjit
+__all__ = [
+ "get_index_hf",
+ "check_altitude_hf",
+]
-@jit
-def geometric_to_geopotential(z, r0):
+
+@hjit("f(f,f)")
+def _geometric_to_geopotential_hf(z, r0):
"""Converts from given geometric altitude to geopotential one.
Parameters
@@ -16,18 +21,15 @@ def geometric_to_geopotential(z, r0):
Returns
-------
- h: float
+ h : float
Geopotential altitude.
"""
h = r0 * z / (r0 + z)
return h
-z_to_h = geometric_to_geopotential
-
-
-@jit
-def geopotential_to_geometric(h, r0):
+@hjit("f(f,f)")
+def _geopotential_to_geometric_hf(h, r0):
"""Converts from given geopotential altitude to geometric one.
Parameters
@@ -39,18 +41,19 @@ def geopotential_to_geometric(h, r0):
Returns
-------
- z: float
+ z : float
Geometric altitude.
"""
z = r0 * h / (r0 - h)
return z
-h_to_z = geopotential_to_geometric
+_z_to_h_hf = _geometric_to_geopotential_hf
+_h_to_z_hf = _geopotential_to_geometric_hf
-@jit
-def gravity(z, g0, r0):
+@hjit("f(f,f,f)")
+def _gravity_hf(z, g0, r0):
"""Relates Earth gravity field magnitude with the geometric height.
Parameters
@@ -64,15 +67,15 @@ def gravity(z, g0, r0):
Returns
-------
- g: float
+ g : float
Gravity value at given geometric altitude.
"""
g = g0 * (r0 / (r0 + z)) ** 2
return g
-@jit
-def _get_index(x, x_levels):
+@hjit # ("i8(f,f)") # TODO use tuple with fixed length
+def get_index_hf(x, x_levels):
"""Finds element in list and returns index.
Parameters
@@ -84,7 +87,7 @@ def _get_index(x, x_levels):
Returns
-------
- i: int
+ i : int
Index for the value.
"""
@@ -97,14 +100,14 @@ def _get_index(x, x_levels):
return i - 1
-@jit
-def _check_altitude(alt, r0, geometric):
- # Get geometric and geopotential altitudes
+@hjit("Tuple([f,f])(f,f,b1)")
+def check_altitude_hf(alt, r0, geometric):
+ """Get geometric and geopotential altitudes"""
if geometric:
z = alt
- h = z_to_h(z, r0)
+ h = _z_to_h_hf(z, r0)
else:
h = alt
- z = h_to_z(h, r0)
+ z = _h_to_z_hf(h, r0)
return z, h
diff --git a/src/hapsira/core/earth/atmosphere/coesa76.py b/src/hapsira/core/earth/atmosphere/coesa76.py
new file mode 100644
index 000000000..43a02025a
--- /dev/null
+++ b/src/hapsira/core/earth/atmosphere/coesa76.py
@@ -0,0 +1,421 @@
+from math import exp
+
+from astropy.io import ascii as ascii_
+from astropy.utils.data import get_pkg_data_filename
+
+from numpy import float32 as float_, int64 as i8
+
+from .coesa import check_altitude_hf
+from ...jit import hjit, vjit
+
+__all__ = [
+ "R",
+ "R_air",
+ "k",
+ "Na",
+ "g0",
+ "r0",
+ "M0",
+ "P0",
+ "T0",
+ "Tinf",
+ "gamma",
+ "alpha",
+ "beta",
+ "S",
+ "b_levels",
+ "zb_levels",
+ "hb_levels",
+ "Tb_levels",
+ "Lb_levels",
+ "pb_levels",
+ "z_coeff",
+ "p_coeff",
+ "rho_coeff",
+ "COESA76_GEOMETRIC",
+ "pressure_hf",
+ "pressure_vf",
+ "temperature_hf",
+ "temperature_vf",
+ "density_hf",
+ "density_vf",
+]
+
+# Following constants come from original U.S Atmosphere 1962 paper so a pure
+# model of this atmosphere can be implemented
+R = float_(8314.32) # u.J / u.kmol / u.K
+R_air = float_(287.053) # u.J / u.kg / u.K
+k = float_(1.380622e-23) # u.J / u.K
+Na = float_(6.022169e-26) # 1 / u.kmol
+g0 = float_(9.80665) # u.m / u.s**2
+r0 = float_(6356.766) # u.km
+M0 = float_(28.9644) # u.kg / u.kmol
+P0 = float_(101325) # u.Pa
+T0 = float_(288.15) # u.K
+Tinf = 1000 # u.K
+gamma = float_(1.4) # one
+alpha = float_(34.1632) # u.K / u.km
+beta = float_(1.458e-6) # (u.kg / u.s / u.m / (u.K) ** 0.5)
+S = float_(110.4) # u.K
+
+# Reading layer parameters file
+coesa76_data = ascii_.read(get_pkg_data_filename("data/coesa76.dat"))
+b_levels = tuple(i8(number) for number in coesa76_data["b"].data)
+zb_levels = tuple(float_(number) for number in coesa76_data["Zb [km]"].data) # u.km
+hb_levels = tuple(float_(number) for number in coesa76_data["Hb [km]"].data) # u.km
+Tb_levels = tuple(float_(number) for number in coesa76_data["Tb [K]"].data) # u.K
+Lb_levels = tuple(
+ float_(number) for number in coesa76_data["Lb [K/km]"].data
+) # u.K / u.km
+pb_levels = tuple(float_(number) for number in coesa76_data["pb [mbar]"].data) # u.mbar
+
+# Reading pressure and density coefficients files
+p_data = ascii_.read(get_pkg_data_filename("data/coesa76_p.dat"))
+rho_data = ascii_.read(get_pkg_data_filename("data/coesa76_rho.dat"))
+
+# Zip coefficients for each altitude
+z_coeff = tuple(i8(number) for number in p_data["z [km]"].data) # u.km
+p_coeff = (
+ tuple(float_(number) for number in p_data["A"].data),
+ tuple(float_(number) for number in p_data["B"].data),
+ tuple(float_(number) for number in p_data["C"].data),
+ tuple(float_(number) for number in p_data["D"].data),
+ tuple(float_(number) for number in p_data["E"].data),
+)
+rho_coeff = (
+ tuple(float_(number) for number in rho_data["A"].data),
+ tuple(float_(number) for number in rho_data["B"].data),
+ tuple(float_(number) for number in rho_data["C"].data),
+ tuple(float_(number) for number in rho_data["D"].data),
+ tuple(float_(number) for number in rho_data["E"].data),
+)
+
+COESA76_GEOMETRIC = True
+
+
+@hjit("Tuple([f,f])(f,f,b1)")
+def _check_altitude_hf(alt, r0, geometric):
+ """Checks if altitude is inside valid range.
+
+ Parameters
+ ----------
+ alt : float
+ Altitude to be checked.
+ r0 : float
+ Attractor radius.
+ geometric : bool
+ If `True`, assumes geometric altitude kind.
+ Default `True`.
+
+ Returns
+ -------
+ z : float
+ Geometric altitude.
+ h : float
+ Geopotential altitude.
+
+ """
+ z, h = check_altitude_hf(alt, r0, geometric)
+ assert zb_levels[0] <= z <= zb_levels[-1]
+
+ return z, h
+
+
+@hjit("i8(f)")
+def _get_index_zb_levels_hf(x):
+ """Finds element in list and returns index.
+
+ Parameters
+ ----------
+ x : float
+ Element to be searched.
+
+ Returns
+ -------
+ i : int
+ Index for the value. `999` if there was an error.
+
+ """
+ for i, value in enumerate(zb_levels):
+ if i < len(zb_levels) and value < x:
+ continue
+ if x == value:
+ return i
+ return i - 1
+ return 999 # HACK error ... ?
+
+
+@hjit("i8(f)")
+def _get_index_z_coeff_hf(x):
+ """Finds element in list and returns index.
+
+ Parameters
+ ----------
+ x : float
+ Element to be searched.
+
+ Returns
+ -------
+ i : int
+ Index for the value.
+ Index for the value. `999` if there was an error.
+
+ """
+ for i, value in enumerate(z_coeff):
+ if i < len(z_coeff) and value < x:
+ continue
+ if x == value:
+ return i
+ return i - 1
+ return 999 # HACK error ... ?
+
+
+@hjit("Tuple([f,f,f,f,f])(f)")
+def _get_coefficients_avobe_86_p_coeff_hf(z):
+ """Returns corresponding coefficients for 4th order polynomial approximation.
+
+ Parameters
+ ----------
+ z : float
+ Geometric altitude
+
+ Returns
+ -------
+ coeffs : tuple[float,float,float,float,float]
+ Tuple of corresponding coefficients
+
+ """
+ # Get corresponding coefficients
+ i = _get_index_z_coeff_hf(z)
+ return p_coeff[0][i], p_coeff[1][i], p_coeff[2][i], p_coeff[3][i], p_coeff[4][i]
+
+
+@hjit("Tuple([f,f,f,f,f])(f)")
+def _get_coefficients_avobe_86_rho_coeff_hf(z):
+ """Returns corresponding coefficients for 4th order polynomial approximation.
+
+ Parameters
+ ----------
+ z : float
+ Geometric altitude
+
+ Returns
+ -------
+ coeffs : tuple[float,float,float,float,float]
+ Tuple of corresponding coefficients
+
+ """
+ # Get corresponding coefficients
+ i = _get_index_z_coeff_hf(z)
+ return (
+ rho_coeff[0][i],
+ rho_coeff[1][i],
+ rho_coeff[2][i],
+ rho_coeff[3][i],
+ rho_coeff[4][i],
+ )
+
+
+@hjit("f(f,b1)")
+def temperature_hf(alt, geometric):
+ """Solves for temperature at given altitude.
+
+ Parameters
+ ----------
+ alt : float
+ Geometric/Geopotential altitude.
+ geometric : bool
+ If `True`, assumes geometric altitude kind.
+ Default `True`.
+
+ Returns
+ -------
+ T : float
+ Kinetic temeperature.
+
+ """
+ # Test if altitude is inside valid range
+ z, h = _check_altitude_hf(alt, r0, geometric)
+
+ # Get base parameters
+ i = _get_index_zb_levels_hf(z)
+ Tb = Tb_levels[i]
+ Lb = Lb_levels[i]
+ hb = hb_levels[i]
+
+ # Apply different equations
+ if z < zb_levels[7]:
+ # Below 86km
+ # TODO: Apply air mean molecular weight ratio factor
+ Tm = Tb + Lb * (h - hb)
+ T = Tm
+ elif zb_levels[7] <= z and z < zb_levels[8]:
+ # [86km, 91km)
+ T = 186.87
+ elif zb_levels[8] <= z and z < zb_levels[9]:
+ # [91km, 110km]
+ Tc = 263.1905
+ A = -76.3232
+ a = -19.9429
+ T = Tc + A * (1 - ((z - zb_levels[8]) / a) ** 2) ** 0.5
+ elif zb_levels[9] <= z and z < zb_levels[10]:
+ # [110km, 120km]
+ T = 240 + Lb * (z - zb_levels[9])
+ else:
+ T10 = 360.0
+ _gamma = Lb_levels[9] / (Tinf - T10)
+ epsilon = (z - zb_levels[10]) * (r0 + zb_levels[10]) / (r0 + z)
+ T = Tinf - (Tinf - T10) * exp(-_gamma * epsilon)
+
+ return T
+
+
+@vjit("f(f,b1)")
+def temperature_vf(alt, geometric):
+ """Solves for temperature at given altitude.
+
+ Parameters
+ ----------
+ alt : float
+ Geometric/Geopotential altitude.
+ geometric : bool
+ If `True`, assumes geometric altitude kind.
+ Default `True`.
+
+ Returns
+ -------
+ T : float
+ Kinetic temeperature.
+
+ """
+ return temperature_hf(alt, geometric)
+
+
+@hjit("f(f,b1)")
+def pressure_hf(alt, geometric):
+ """Solves pressure at given altitude.
+
+ Parameters
+ ----------
+ alt : float
+ Geometric/Geopotential altitude.
+ geometric : bool
+ If `True`, assumes geometric altitude kind.
+ Default `True`.
+
+ Returns
+ -------
+ p : float
+ Pressure at given altitude.
+
+ """
+ # Test if altitude is inside valid range
+ z, h = _check_altitude_hf(alt, r0, geometric)
+
+ # Obtain gravity magnitude
+ # Get base parameters
+ i = _get_index_zb_levels_hf(z)
+ Tb = Tb_levels[i]
+ Lb = Lb_levels[i]
+ hb = hb_levels[i]
+ pb = pb_levels[i]
+
+ # If above 86[km] usual formulation is applied
+ if z < 86:
+ if Lb == 0.0:
+ p = pb * exp(-alpha * (h - hb) / Tb) * 100 # HACK 100 ... SI-prefix change?
+ else:
+ T = temperature_hf(z, geometric)
+ p = pb * (Tb / T) ** (alpha / Lb) * 100 # HACK 100 ... SI-prefix change?
+ else:
+ # TODO: equation (33c) should be applied instead of using coefficients
+
+ # A 4th order polynomial is used to approximate pressure. This was
+ # directly taken from: http://www.braeunig.us/space/atmmodel.htm
+ A, B, C, D, E = _get_coefficients_avobe_86_p_coeff_hf(z)
+
+ # Solve the polynomial
+ p = exp(A * z**4 + B * z**3 + C * z**2 + D * z + E)
+
+ return p
+
+
+@vjit("f(f,b1)")
+def pressure_vf(alt, geometric):
+ """Solves pressure at given altitude.
+
+ Parameters
+ ----------
+ alt : float
+ Geometric/Geopotential altitude.
+ geometric : bool
+ If `True`, assumes geometric altitude kind.
+ Default `True`.
+
+ Returns
+ -------
+ p : float
+ Pressure at given altitude.
+
+ """
+ return pressure_hf(alt, geometric)
+
+
+@hjit("f(f,b1)")
+def density_hf(alt, geometric):
+ """Solves density at given height.
+
+ Parameters
+ ----------
+ alt : float
+ Geometric/Geopotential height.
+ geometric : bool
+ If `True`, assumes that `alt` argument is geometric kind.
+ Default `True`.
+
+ Returns
+ -------
+ rho : float
+ Density at given height.
+
+ """
+ # Test if altitude is inside valid range
+ z, _ = _check_altitude_hf(alt, r0, geometric)
+
+ # Solve temperature and pressure
+ if z <= 86:
+ T = temperature_hf(z, geometric)
+ p = pressure_hf(z, geometric)
+ rho = p / R_air / T
+ else:
+ # TODO: equation (42) should be applied instead of using coefficients
+
+ # A 4th order polynomial is used to approximate pressure. This was
+ # directly taken from: http://www.braeunig.us/space/atmmodel.htm
+ A, B, C, D, E = _get_coefficients_avobe_86_rho_coeff_hf(z)
+
+ # Solve the polynomial
+ rho = exp(A * z**4 + B * z**3 + C * z**2 + D * z + E)
+
+ return rho
+
+
+@vjit("f(f,b1)")
+def density_vf(alt, geometric):
+ """Solves density at given height.
+
+ Parameters
+ ----------
+ alt : float
+ Geometric/Geopotential height.
+ geometric : bool
+ If `True`, assumes that `alt` argument is geometric kind.
+ Default `True`.
+
+ Returns
+ -------
+ rho : float
+ Density at given height.
+
+ """
+ return density_hf(alt, geometric)
diff --git a/src/hapsira/earth/atmosphere/data/coesa62.dat b/src/hapsira/core/earth/atmosphere/data/coesa62.dat
similarity index 100%
rename from src/hapsira/earth/atmosphere/data/coesa62.dat
rename to src/hapsira/core/earth/atmosphere/data/coesa62.dat
diff --git a/src/hapsira/earth/atmosphere/data/coesa76.dat b/src/hapsira/core/earth/atmosphere/data/coesa76.dat
similarity index 100%
rename from src/hapsira/earth/atmosphere/data/coesa76.dat
rename to src/hapsira/core/earth/atmosphere/data/coesa76.dat
diff --git a/src/hapsira/earth/atmosphere/data/coesa76_p.dat b/src/hapsira/core/earth/atmosphere/data/coesa76_p.dat
similarity index 100%
rename from src/hapsira/earth/atmosphere/data/coesa76_p.dat
rename to src/hapsira/core/earth/atmosphere/data/coesa76_p.dat
diff --git a/src/hapsira/earth/atmosphere/data/coesa76_rho.dat b/src/hapsira/core/earth/atmosphere/data/coesa76_rho.dat
similarity index 100%
rename from src/hapsira/earth/atmosphere/data/coesa76_rho.dat
rename to src/hapsira/core/earth/atmosphere/data/coesa76_rho.dat
diff --git a/src/hapsira/core/earth_atmosphere/jacchia.py b/src/hapsira/core/earth/atmosphere/jacchia.py
similarity index 100%
rename from src/hapsira/core/earth_atmosphere/jacchia.py
rename to src/hapsira/core/earth/atmosphere/jacchia.py
diff --git a/src/hapsira/core/elements.py b/src/hapsira/core/elements.py
index a9be9bd08..151005e58 100644
--- a/src/hapsira/core/elements.py
+++ b/src/hapsira/core/elements.py
@@ -2,19 +2,52 @@
convert between different elements that define the orbit of a body.
"""
-import sys
-
-from numba import njit as jit, prange
-import numpy as np
-from numpy import cos, cross, sin, sqrt
-
-from hapsira._math.linalg import norm
-from hapsira.core.angles import E_to_nu, F_to_nu
-from hapsira.core.util import rotation_matrix
-
-
-@jit
-def eccentricity_vector(k, r, v):
+from math import acos, atan, atan2, cos, fabs, log, pi, sin, sqrt, tan
+
+from .angles import E_to_nu_hf, F_to_nu_hf
+from .jit import array_to_V_hf, hjit, gjit, vjit
+from .math.linalg import (
+ cross_VV_hf,
+ div_Vs_hf,
+ matmul_MM_hf,
+ matmul_VM_hf,
+ matmul_VV_hf,
+ mul_Vs_hf,
+ norm_V_hf,
+ sub_VV_hf,
+ transpose_M_hf,
+)
+from .util import rotation_matrix_hf
+
+
+__all__ = [
+ "eccentricity_vector_hf",
+ "eccentricity_vector_gf",
+ "circular_velocity_hf",
+ "circular_velocity_vf",
+ "rv_pqw_hf",
+ "coe_rotation_matrix_hf",
+ "coe2rv_hf",
+ "coe2rv_gf",
+ "coe2mee_hf",
+ "coe2mee_gf",
+ "RV2COE_TOL",
+ "rv2coe_hf",
+ "rv2coe_gf",
+ "mee2coe_hf",
+ "mee2coe_gf",
+ "mee2rv_hf",
+ "mee2rv_gf",
+ "mean_motion_vf",
+ "period_vf",
+]
+
+
+RV2COE_TOL = 1e-8
+
+
+@hjit("V(f,V,V)")
+def eccentricity_vector_hf(k, r, v):
r"""Eccentricity vector.
.. math::
@@ -27,17 +60,28 @@ def eccentricity_vector(k, r, v):
----------
k : float
Standard gravitational parameter (km^3 / s^2).
- r : numpy.ndarray
+ r : tuple[float,float,float]
Position vector (km)
- v : numpy.ndarray
+ v : tuple[float,float,float]
Velocity vector (km / s)
"""
- return ((v @ v - k / norm(r)) * r - (r @ v) * v) / k
+ a = matmul_VV_hf(v, v) - k / norm_V_hf(r)
+ b = matmul_VV_hf(r, v)
+ return div_Vs_hf(sub_VV_hf(mul_Vs_hf(r, a), mul_Vs_hf(v, b)), k)
+
+
+@gjit("void(f,f[:],f[:],f[:])", "(),(n),(n)->(n)")
+def eccentricity_vector_gf(k, r, v, e):
+ """
+ Vectorized eccentricity_vector
+ """
+ e[0], e[1], e[2] = eccentricity_vector_hf(k, array_to_V_hf(r), array_to_V_hf(v))
-@jit
-def circular_velocity(k, a):
- r"""Compute circular velocity for a given body given thegravitational parameter and the semimajor axis.
+
+@hjit("f(f,f)")
+def circular_velocity_hf(k, a):
+ r"""Compute circular velocity for a given body given the gravitational parameter and the semimajor axis.
.. math::
@@ -51,11 +95,20 @@ def circular_velocity(k, a):
Semimajor Axis
"""
- return np.sqrt(k / a)
+ return sqrt(k / a)
+
+
+@vjit("f(f,f)")
+def circular_velocity_vf(k, a):
+ """
+ Vectorized circular_velocity
+ """
+
+ return circular_velocity_hf(k, a)
-@jit
-def rv_pqw(k, p, ecc, nu):
+@hjit("Tuple([V,V])(f,f,f,f)")
+def rv_pqw_hf(k, p, ecc, nu):
r"""Returns r and v vectors in perifocal frame.
Parameters
@@ -71,9 +124,9 @@ def rv_pqw(k, p, ecc, nu):
Returns
-------
- r: numpy.ndarray
+ r: tuple[float,float,float]
Position. Dimension 3 vector
- v: numpy.ndarray
+ v: tuple[float,float,float]
Velocity. Dimension 3 vector
Notes
@@ -109,23 +162,29 @@ def rv_pqw(k, p, ecc, nu):
v = [-5753.30180931 -1328.66813933 0] [m]/[s]
"""
- pqw = np.array([[cos(nu), sin(nu), 0], [-sin(nu), ecc + cos(nu), 0]]) * np.array(
- [[p / (1 + ecc * cos(nu))], [sqrt(k / p)]]
+
+ sin_nu = sin(nu)
+ cos_nu = cos(nu)
+ a = p / (1 + ecc * cos_nu)
+ b = sqrt(k / p)
+
+ return (
+ (cos_nu * a, sin_nu * a, 0.0),
+ (-sin_nu * b, (ecc + cos_nu) * b, 0.0),
)
- return pqw
-@jit
-def coe_rotation_matrix(inc, raan, argp):
+@hjit("M(f,f,f)")
+def coe_rotation_matrix_hf(inc, raan, argp):
"""Create a rotation matrix for coe transformation."""
- r = rotation_matrix(raan, 2)
- r = r @ rotation_matrix(inc, 0)
- r = r @ rotation_matrix(argp, 2)
+ r = rotation_matrix_hf(raan, 2)
+ r = matmul_MM_hf(r, rotation_matrix_hf(inc, 0))
+ r = matmul_MM_hf(r, rotation_matrix_hf(argp, 2))
return r
-@jit
-def coe2rv(k, p, ecc, inc, raan, argp, nu):
+@hjit("Tuple([V,V])(f,f,f,f,f,f,f)")
+def coe2rv_hf(k, p, ecc, inc, raan, argp, nu):
r"""Converts from classical orbital to state vectors.
Classical orbital elements are converted into position and velocity
@@ -151,9 +210,9 @@ def coe2rv(k, p, ecc, inc, raan, argp, nu):
Returns
-------
- r_ijk: numpy.ndarray
+ r_ijk: tuple[float,float,float]
Position vector in basis ijk.
- v_ijk: numpy.ndarray
+ v_ijk: tuple[float,float,float]
Velocity vector in basis ijk.
Notes
@@ -179,30 +238,27 @@ def coe2rv(k, p, ecc, inc, raan, argp, nu):
\end{bmatrix}
"""
- pqw = rv_pqw(k, p, ecc, nu)
- rm = coe_rotation_matrix(inc, raan, argp)
+ r, v = rv_pqw_hf(k, p, ecc, nu)
+ rm = transpose_M_hf(coe_rotation_matrix_hf(inc, raan, argp))
+ return matmul_VM_hf(r, rm), matmul_VM_hf(v, rm)
- ijk = pqw @ rm.T
- return ijk
-
-
-@jit(parallel=sys.maxsize > 2**31)
-def coe2rv_many(k, p, ecc, inc, raan, argp, nu):
- """Parallel version of coe2rv."""
- n = nu.shape[0]
- rr = np.zeros((n, 3))
- vv = np.zeros((n, 3))
+@gjit("void(f,f,f,f,f,f,f,u1[:],f[:],f[:])", "(),(),(),(),(),(),(),(n)->(n),(n)")
+def coe2rv_gf(k, p, ecc, inc, raan, argp, nu, dummy, rr, vv):
+ """
+ Vectorized coe2rv
- # Disabling pylint warning, see https://github.com/PyCQA/pylint/issues/2910
- for i in prange(n): # pylint: disable=not-an-iterable
- rr[i, :], vv[i, :] = coe2rv(k[i], p[i], ecc[i], inc[i], raan[i], argp[i], nu[i])
+ `dummy` because of https://github.com/numba/numba/issues/2797
+ """
+ assert dummy.shape == (3,)
- return rr, vv
+ (rr[0], rr[1], rr[2]), (vv[0], vv[1], vv[2]) = coe2rv_hf(
+ k, p, ecc, inc, raan, argp, nu
+ )
-@jit
-def coe2mee(p, ecc, inc, raan, argp, nu):
+@hjit("Tuple([f,f,f,f,f,f])(f,f,f,f,f,f)")
+def coe2mee_hf(p, ecc, inc, raan, argp, nu):
r"""Converts from classical orbital elements to modified equinoctial orbital elements.
The definition of the modified equinoctial orbital elements is taken from [Walker, 1985].
@@ -259,31 +315,43 @@ def coe2mee(p, ecc, inc, raan, argp, nu):
\end{align}
"""
- if inc == np.pi:
+ if inc == pi:
raise ValueError(
"Cannot compute modified equinoctial set for 180 degrees orbit inclination due to `h` and `k` singularity."
)
lonper = raan + argp
- f = ecc * np.cos(lonper)
- g = ecc * np.sin(lonper)
- h = np.tan(inc / 2) * np.cos(raan)
- k = np.tan(inc / 2) * np.sin(raan)
+ f = ecc * cos(lonper)
+ g = ecc * sin(lonper)
+ h = tan(inc / 2) * cos(raan)
+ k = tan(inc / 2) * sin(raan)
L = lonper + nu
return p, f, g, h, k, L
-@jit
-def rv2coe(k, r, v, tol=1e-8):
+@gjit(
+ "void(f,f,f,f,f,f,f[:],f[:],f[:],f[:],f[:],f[:])",
+ "(),(),(),(),(),()->(),(),(),(),(),()",
+)
+def coe2mee_gf(p, ecc, inc, raan, argp, nu, p_, f, g, h, k, L):
+ """
+ Vectorized coe2mee
+ """
+
+ p_[0], f[0], g[0], h[0], k[0], L[0] = coe2mee_hf(p, ecc, inc, raan, argp, nu)
+
+
+@hjit("Tuple([f,f,f,f,f,f])(f,V,V,f)")
+def rv2coe_hf(k, r, v, tol):
r"""Converts from vectors to classical orbital elements.
Parameters
----------
k : float
Standard gravitational parameter (km^3 / s^2)
- r : numpy.ndarray
+ r : tuple[float,float,float]
Position vector (km)
- v : numpy.ndarray
+ v : tuple[float,float,float]
Velocity vector (km / s)
tol : float, optional
Tolerance for eccentricity and inclination checks, default to 1e-8
@@ -329,7 +397,7 @@ def rv2coe(k, r, v, tol=1e-8):
N &= \sqrt{\vec{N}\cdot\vec{N}}
\end{align}
- 4. The rigth ascension node is computed:
+ 4. The right ascension node is computed:
.. math::
\Omega = \left\{ \begin{array}{lcc}
@@ -378,53 +446,80 @@ def rv2coe(k, r, v, tol=1e-8):
nu: 28.445804984192122 [deg]
"""
- h = cross(r, v)
- n = cross([0, 0, 1], h)
- e = ((v @ v - k / norm(r)) * r - (r @ v) * v) / k
- ecc = norm(e)
- p = (h @ h) / k
- inc = np.arccos(h[2] / norm(h))
+ h = cross_VV_hf(r, v)
+ n = cross_VV_hf((0, 0, 1), h)
+ e = mul_Vs_hf(
+ sub_VV_hf(
+ mul_Vs_hf(r, (matmul_VV_hf(v, v) - k / norm_V_hf(r))),
+ mul_Vs_hf(v, matmul_VV_hf(r, v)),
+ ),
+ 1 / k,
+ )
+ ecc = norm_V_hf(e)
+ p = matmul_VV_hf(h, h) / k
+ inc = acos(h[2] / norm_V_hf(h))
circular = ecc < tol
equatorial = abs(inc) < tol
if equatorial and not circular:
raan = 0
- argp = np.arctan2(e[1], e[0]) % (2 * np.pi) # Longitude of periapsis
- nu = np.arctan2((h @ cross(e, r)) / norm(h), r @ e)
+ argp = atan2(e[1], e[0]) % (2 * pi) # Longitude of periapsis
+ nu = atan2(
+ matmul_VV_hf(h, cross_VV_hf(e, r)) / norm_V_hf(h), matmul_VV_hf(r, e)
+ )
elif not equatorial and circular:
- raan = np.arctan2(n[1], n[0]) % (2 * np.pi)
+ raan = atan2(n[1], n[0]) % (2 * pi)
argp = 0
# Argument of latitude
- nu = np.arctan2((r @ cross(h, n)) / norm(h), r @ n)
+ nu = atan2(
+ matmul_VV_hf(r, cross_VV_hf(h, n)) / norm_V_hf(h), matmul_VV_hf(r, n)
+ )
elif equatorial and circular:
raan = 0
argp = 0
- nu = np.arctan2(r[1], r[0]) % (2 * np.pi) # True longitude
+ nu = atan2(r[1], r[0]) % (2 * pi) # True longitude
else:
a = p / (1 - (ecc**2))
ka = k * a
if a > 0:
- e_se = (r @ v) / sqrt(ka)
- e_ce = norm(r) * (v @ v) / k - 1
- nu = E_to_nu(np.arctan2(e_se, e_ce), ecc)
+ e_se = matmul_VV_hf(r, v) / sqrt(ka)
+ e_ce = norm_V_hf(r) * matmul_VV_hf(v, v) / k - 1
+ nu = E_to_nu_hf(atan2(e_se, e_ce), ecc)
else:
- e_sh = (r @ v) / sqrt(-ka)
- e_ch = norm(r) * (norm(v) ** 2) / k - 1
- nu = F_to_nu(np.log((e_ch + e_sh) / (e_ch - e_sh)) / 2, ecc)
+ e_sh = matmul_VV_hf(r, v) / sqrt(-ka)
+ e_ch = norm_V_hf(r) * (norm_V_hf(v) ** 2) / k - 1
+ nu = F_to_nu_hf(log((e_ch + e_sh) / (e_ch - e_sh)) / 2, ecc)
- raan = np.arctan2(n[1], n[0]) % (2 * np.pi)
- px = r @ n
- py = (r @ cross(h, n)) / norm(h)
- argp = (np.arctan2(py, px) - nu) % (2 * np.pi)
+ raan = atan2(n[1], n[0]) % (2 * pi)
+ px = matmul_VV_hf(r, n)
+ py = matmul_VV_hf(r, cross_VV_hf(h, n)) / norm_V_hf(h)
+ argp = (atan2(py, px) - nu) % (2 * pi)
- nu = (nu + np.pi) % (2 * np.pi) - np.pi
+ nu = (nu + pi) % (2 * pi) - pi
return p, ecc, inc, raan, argp, nu
-@jit
-def mee2coe(p, f, g, h, k, L):
+@gjit(
+ "void(f,f[:],f[:],f,f[:],f[:],f[:],f[:],f[:],f[:])",
+ "(),(n),(n),()->(),(),(),(),(),()",
+)
+def rv2coe_gf(k, r, v, tol, p, ecc, inc, raan, argp, nu):
+ """
+ Vectorized rv2coe
+ """
+
+ p[0], ecc[0], inc[0], raan[0], argp[0], nu[0] = rv2coe_hf(
+ k,
+ array_to_V_hf(r),
+ array_to_V_hf(v),
+ tol,
+ )
+
+
+@hjit("Tuple([f,f,f,f,f,f])(f,f,f,f,f,f)")
+def mee2coe_hf(p, f, g, h, k, L):
r"""Converts from modified equinoctial orbital elements to classical
orbital elements.
@@ -478,17 +573,29 @@ def mee2coe(p, f, g, h, k, L):
arguments.
"""
- ecc = np.sqrt(f**2 + g**2)
- inc = 2 * np.arctan(np.sqrt(h**2 + k**2))
- lonper = np.arctan2(g, f)
- raan = np.arctan2(k, h) % (2 * np.pi)
- argp = (lonper - raan) % (2 * np.pi)
- nu = (L - lonper) % (2 * np.pi)
+ ecc = sqrt(f**2 + g**2)
+ inc = 2 * atan(sqrt(h**2 + k**2))
+ lonper = atan2(g, f)
+ raan = atan2(k, h) % (2 * pi)
+ argp = (lonper - raan) % (2 * pi)
+ nu = (L - lonper) % (2 * pi)
return p, ecc, inc, raan, argp, nu
-@jit
-def mee2rv(p, f, g, h, k, L):
+@gjit(
+ "void(f,f,f,f,f,f,f[:],f[:],f[:],f[:],f[:],f[:])",
+ "(),(),(),(),(),()->(),(),(),(),(),()",
+)
+def mee2coe_gf(p, f, g, h, k, L, p_, ecc, inc, raan, argp, nu):
+ """
+ Vectorized mee2coe
+ """
+
+ p_[0], ecc[0], inc[0], raan[0], argp[0], nu[0] = mee2coe_hf(p, f, g, h, k, L)
+
+
+@hjit("Tuple([V,V])(f,f,f,f,f,f)")
+def mee2rv_hf(p, f, g, h, k, L): # TODO untested
"""Calculates position and velocity vector from modified equinoctial elements.
Parameters
@@ -508,9 +615,9 @@ def mee2rv(p, f, g, h, k, L):
Returns
-------
- r: numpy.ndarray
+ r: tuple[float,float,float]
Position vector.
- v: numpy.ndarray
+ v: tuple[float,float,float]
Velocity vector.
Note
@@ -520,22 +627,22 @@ def mee2rv(p, f, g, h, k, L):
Equation 3a and 3b.
"""
- w = 1 + f * np.cos(L) + g * np.sin(L)
+ w = 1 + f * cos(L) + g * sin(L)
r = p / w
s2 = 1 + h**2 + k**2
alpha2 = h**2 - k**2
- rx = (r / s2)(np.cos(L) + alpha2**2 * np.cos(L) + 2 * h * k * np.sin(L))
- ry = (r / s2)(np.sin(L) - alpha2**2 * np.sin(L) + 2 * h * k * np.cos(L))
- rz = (2 * r / s2)(h * np.sin(L) - k * np.cos(L))
+ rx = (r / s2) * (cos(L) + alpha2**2 * cos(L) + 2 * h * k * sin(L))
+ ry = (r / s2) * (sin(L) - alpha2**2 * sin(L) + 2 * h * k * cos(L))
+ rz = (2 * r / s2) * (h * sin(L) - k * cos(L))
vx = (
(-1 / s2)
- * (np.sqrt(k / p))
+ * (sqrt(k / p))
* (
- np.sin(L)
- + alpha2 * np.sin(L)
- - 2 * h * k * np.cos(L)
+ sin(L)
+ + alpha2 * sin(L)
+ - 2 * h * k * cos(L)
+ g
- 2 * f * h * k
+ alpha2 * g
@@ -543,16 +650,60 @@ def mee2rv(p, f, g, h, k, L):
)
vy = (
(-1 / s2)
- * (np.sqrt(k / p))
+ * (sqrt(k / p))
* (
- -np.cos(L)
- + alpha2 * np.cos(L)
- + 2 * h * k * np.sin(L)
+ -cos(L)
+ + alpha2 * cos(L)
+ + 2 * h * k * sin(L)
- f
+ 2 * g * h * k
+ alpha2 * f
)
)
- vz = (2 / s2) * (np.sqrt(k / p)) * (h * np.cos(L) + k * np.sin(L) + f * h + g * k)
+ vz = (2 / s2) * (sqrt(k / p)) * (h * cos(L) + k * sin(L) + f * h + g * k)
+
+ return (rx, ry, rz), (vx, vy, vz)
+
+
+@gjit("void(f,f,f,f,f,f,u1[:],f[:],f[:])", "(),(),(),(),(),(),(n)->(n),(n)")
+def mee2rv_gf(p, f, g, h, k, L, dummy, r, v):
+ """
+ Vectorized mee2rv
+ """
+ assert dummy.shape == (3,)
+
+ (r[0], r[1], r[2]), (v[0], v[1], v[2]) = mee2rv_hf(p, f, g, h, k, L)
+
+
+@hjit("f(f,f)", inline=True)
+def mean_motion_hf(k, a):
+ """
+ Mean motion given body (k) and semimajor axis (a).
+ """
+ return sqrt(k / fabs(a * a * a))
+
+
+@vjit("f(f,f)")
+def mean_motion_vf(k, a):
+ """
+ Vectorized mean_motion
+ """
+ return mean_motion_hf(k, a)
+
+
+@hjit("f(f,f)", inline=True)
+def period_hf(k, a):
+ """
+ Period given body (k) and semimajor axis (a).
+ """
+ n = mean_motion_hf(k, a)
+ return 2 * pi / n
+
+
+@vjit("f(f,f)")
+def period_vf(k, a):
+ """
+ Vectorized period
+ """
- return np.array([rx, ry, rz]), np.array([vx, vy, vz])
+ return period_hf(k, a)
diff --git a/src/hapsira/core/events.py b/src/hapsira/core/events.py
index 921ba0c80..eed0e1c65 100644
--- a/src/hapsira/core/events.py
+++ b/src/hapsira/core/events.py
@@ -1,22 +1,38 @@
-from numba import njit as jit
-import numpy as np
+from math import acos, asin, cos, sin, sqrt
-from hapsira._math.linalg import norm
-from hapsira.core.elements import coe_rotation_matrix, rv2coe
-from hapsira.core.util import planetocentric_to_AltAz
+from .elements import coe_rotation_matrix_hf, rv2coe_hf, RV2COE_TOL
+from .jit import array_to_V_hf, hjit, gjit
+from .math.linalg import div_Vs_hf, matmul_MV_hf, matmul_VV_hf, norm_V_hf, sub_VV_hf
+from .util import planetocentric_to_AltAz_hf
-@jit
-def eclipse_function(k, u_, r_sec, R_sec, R_primary, umbra=True):
+__all__ = [
+ "ECLIPSE_UMBRA",
+ "eclipse_function_hf",
+ "eclipse_function_gf",
+ "line_of_sight_hf",
+ "line_of_sight_gf",
+ "elevation_function_hf",
+ "elevation_function_gf",
+]
+
+
+ECLIPSE_UMBRA = True
+
+
+@hjit("f(f,V,V,V,f,f,b1)")
+def eclipse_function_hf(k, rr, vv, r_sec, R_sec, R_primary, umbra):
"""Calculates a continuous shadow function.
Parameters
----------
k : float
Standard gravitational parameter (km^3 / s^2).
- u_ : numpy.ndarray
- Satellite position and velocity vector with respect to the primary body.
- r_sec : numpy.ndarray
+ rr : tuple[float,float,float]
+ Satellite position vector with respect to the primary body.
+ vv : tuple[float,float,float]
+ Satellite velocity vector with respect to the primary body.
+ r_sec : tuple[float,float,float]
Position vector of the secondary body with respect to the primary body.
R_sec : float
Equatorial radius of the secondary body.
@@ -32,40 +48,61 @@ def eclipse_function(k, u_, r_sec, R_sec, R_primary, umbra=True):
The current implementation assumes circular bodies and doesn't account for flattening.
"""
+
# Plus or minus condition
pm = 1 if umbra else -1
- p, ecc, inc, raan, argp, nu = rv2coe(k, u_[:3], u_[3:])
+ p, ecc, inc, raan, argp, nu = rv2coe_hf(k, rr, vv, RV2COE_TOL)
- PQW = coe_rotation_matrix(inc, raan, argp)
- # Make arrays contiguous for faster dot product with numba.
- P_, Q_ = np.ascontiguousarray(PQW[:, 0]), np.ascontiguousarray(PQW[:, 1])
+ PQW = coe_rotation_matrix_hf(inc, raan, argp)
+ P_ = PQW[0][0], PQW[1][0], PQW[2][0]
+ Q_ = PQW[0][1], PQW[1][1], PQW[2][1]
- r_sec_norm = norm(r_sec)
- beta = (P_ @ r_sec) / r_sec_norm
- zeta = (Q_ @ r_sec) / r_sec_norm
+ r_sec_norm = norm_V_hf(r_sec)
+ beta = matmul_VV_hf(P_, r_sec) / r_sec_norm
+ zeta = matmul_VV_hf(Q_, r_sec) / r_sec_norm
- sin_delta_shadow = np.sin((R_sec - pm * R_primary) / r_sec_norm)
+ sin_delta_shadow = sin((R_sec - pm * R_primary) / r_sec_norm)
- cos_psi = beta * np.cos(nu) + zeta * np.sin(nu)
+ cos_psi = beta * cos(nu) + zeta * sin(nu)
shadow_function = (
- ((R_primary**2) * (1 + ecc * np.cos(nu)) ** 2)
+ ((R_primary**2) * (1 + ecc * cos(nu)) ** 2)
+ (p**2) * (cos_psi**2)
- p**2
- + pm * (2 * p * R_primary * cos_psi) * (1 + ecc * np.cos(nu)) * sin_delta_shadow
+ + pm * (2 * p * R_primary * cos_psi) * (1 + ecc * cos(nu)) * sin_delta_shadow
)
return shadow_function
-@jit
-def line_of_sight(r1, r2, R):
+@gjit(
+ "void(f,f[:],f[:],f[:],f,f,b1,f[:])",
+ "(),(n),(n),(n),(),(),()->()",
+)
+def eclipse_function_gf(k, rr, vv, r_sec, R_sec, R_primary, umbra, eclipse):
+ """
+ Vectorized eclipse_function
+ """
+
+ eclipse[0] = eclipse_function_hf(
+ k,
+ array_to_V_hf(rr),
+ array_to_V_hf(vv),
+ array_to_V_hf(r_sec),
+ R_sec,
+ R_primary,
+ umbra,
+ )
+
+
+@hjit("f(V,V,f)")
+def line_of_sight_hf(r1, r2, R):
"""Calculates the line of sight condition between two position vectors, r1 and r2.
Parameters
----------
- r1 : numpy.ndarray
+ r1 : tuple[float,float,float]
The position vector of the first object with respect to a central attractor.
- r2 : numpy.ndarray
+ r2 : tuple[float,float,float]
The position vector of the second object with respect to a central attractor.
R : float
The radius of the central attractor.
@@ -77,27 +114,34 @@ def line_of_sight(r1, r2, R):
located by r1 and r2, else negative.
"""
- r1_norm = norm(r1)
- r2_norm = norm(r2)
+ r1_norm = norm_V_hf(r1)
+ r2_norm = norm_V_hf(r2)
- theta = np.arccos((r1 @ r2) / r1_norm / r2_norm)
- theta_1 = np.arccos(R / r1_norm)
- theta_2 = np.arccos(R / r2_norm)
+ theta = acos(matmul_VV_hf(r1, r2) / r1_norm / r2_norm)
+ theta_1 = acos(R / r1_norm)
+ theta_2 = acos(R / r2_norm)
return (theta_1 + theta_2) - theta
-@jit
-def elevation_function(k, u_, phi, theta, R, R_p, H):
+@gjit("void(f[:],f[:],f,f[:])", "(n),(n),()->()")
+def line_of_sight_gf(r1, r2, R, delta_theta):
+ """
+ Vectorized line_of_sight
+ """
+
+ delta_theta[0] = line_of_sight_hf(array_to_V_hf(r1), array_to_V_hf(r2), R)
+
+
+@hjit("f(V,f,f,f,f,f)")
+def elevation_function_hf(rr, phi, theta, R, R_p, H):
"""Calculates the elevation angle of an object in orbit with respect to
a location on attractor.
Parameters
----------
- k: float
- Standard gravitational parameter.
- u_: numpy.ndarray
- Satellite position and velocity vector with respect to the central attractor.
+ rr: tuple[float,float,float]
+ Satellite position vector with respect to the central attractor.
phi: float
Geodetic Latitude of the station.
theta: float
@@ -109,26 +153,38 @@ def elevation_function(k, u_, phi, theta, R, R_p, H):
H: float
Elevation, above the ellipsoidal surface.
"""
- ecc = np.sqrt(1 - (R_p / R) ** 2)
- denom = np.sqrt(1 - ecc**2 * np.sin(phi) ** 2)
+
+ cos_phi = cos(phi)
+ sin_phi = sin(phi)
+
+ ecc = sqrt(1 - (R_p / R) ** 2)
+ denom = sqrt(1 - ecc * ecc * sin_phi * sin_phi)
g1 = H + (R / denom)
- g2 = H + (1 - ecc**2) * R / denom
+ g2 = H + (1 - ecc * ecc) * R / denom
+
# Coordinates of location on attractor.
- coords = np.array(
- [
- g1 * np.cos(phi) * np.cos(theta),
- g1 * np.cos(phi) * np.sin(theta),
- g2 * np.sin(phi),
- ]
+ coords = (
+ g1 * cos_phi * cos(theta),
+ g1 * cos_phi * sin(theta),
+ g2 * sin_phi,
)
# Position of satellite with respect to a point on attractor.
- rho = np.subtract(u_[:3], coords)
+ rho = sub_VV_hf(rr, coords)
- rot_matrix = planetocentric_to_AltAz(theta, phi)
+ rot_matrix = planetocentric_to_AltAz_hf(theta, phi)
- new_rho = rot_matrix @ rho
- new_rho = new_rho / np.linalg.norm(new_rho)
- el = np.arcsin(new_rho[-1])
+ new_rho = matmul_MV_hf(rot_matrix, rho)
+ new_rho = div_Vs_hf(new_rho, norm_V_hf(new_rho))
+ el = asin(new_rho[-1])
return el
+
+
+@gjit("void(f[:],f,f,f,f,f,f[:])", "(n),(),(),(),(),()->()")
+def elevation_function_gf(rr, phi, theta, R, R_p, H, el):
+ """
+ Vectorized elevation_function
+ """
+
+ el[0] = elevation_function_hf(array_to_V_hf(rr), phi, theta, R, R_p, H)
diff --git a/src/hapsira/core/flybys.py b/src/hapsira/core/flybys.py
index be8e8d656..53339ff60 100644
--- a/src/hapsira/core/flybys.py
+++ b/src/hapsira/core/flybys.py
@@ -4,7 +4,8 @@
import numpy as np
from numpy import cross
-from hapsira._math.linalg import norm
+from .jit import array_to_V_hf
+from .math.linalg import norm_V_hf
@jit
@@ -33,7 +34,7 @@ def compute_flyby(v_spacecraft, v_body, k, r_p, theta):
"""
v_inf_1 = v_spacecraft - v_body # Hyperbolic excess velocity
- v_inf = norm(v_inf_1)
+ v_inf = norm_V_hf(array_to_V_hf(v_inf_1))
ecc = 1 + r_p * v_inf**2 / k # Eccentricity of the entry hyperbola
delta = 2 * np.arcsin(1 / ecc) # Turn angle
@@ -48,7 +49,7 @@ def compute_flyby(v_spacecraft, v_body, k, r_p, theta):
S_vec = v_inf_1 / v_inf
c_vec = np.array([0, 0, 1])
T_vec = cross(S_vec, c_vec)
- T_vec = T_vec / norm(T_vec)
+ T_vec = T_vec / norm_V_hf(array_to_V_hf(T_vec))
R_vec = cross(S_vec, T_vec)
# This vector defines the B-Plane
@@ -62,7 +63,7 @@ def compute_flyby(v_spacecraft, v_body, k, r_p, theta):
# And now we rotate the outbound hyperbolic excess velocity
# u_vec = v_inf_1 / norm(v_inf) = S_vec
v_vec = cross(rot_v, v_inf_1)
- v_vec = v_vec / norm(v_vec)
+ v_vec = v_vec / norm_V_hf(array_to_V_hf(v_vec))
v_inf_2 = v_inf * (np.cos(delta) * S_vec + np.sin(delta) * v_vec)
diff --git a/src/hapsira/core/iod.py b/src/hapsira/core/iod.py
index 601705ff2..d7cf19793 100644
--- a/src/hapsira/core/iod.py
+++ b/src/hapsira/core/iod.py
@@ -1,13 +1,299 @@
-from numba import njit as jit
-import numpy as np
-from numpy import cross, pi
+from math import acos, asinh, exp, floor, inf, log, pi, sqrt
+
+from .jit import array_to_V_hf, hjit, gjit, vjit
+from .math.linalg import (
+ add_VV_hf,
+ cross_VV_hf,
+ div_ss_hf,
+ div_Vs_hf,
+ matmul_VV_hf,
+ mul_Vs_hf,
+ norm_V_hf,
+ sub_VV_hf,
+)
+from .math.special import hyp2f1b_hf, stumpff_c2_hf, stumpff_c3_hf
+
+
+__all__ = [
+ "vallado_hf",
+ "vallado_gf",
+ "izzo_hf",
+ "izzo_gf",
+ "compute_T_min_hf",
+ "compute_T_min_gf",
+ "compute_y_hf",
+ "compute_y_vf",
+ "tof_equation_y_hf",
+ "tof_equation_y_vf",
+ "find_xy_hf",
+ "find_xy_gf",
+]
+
+
+@hjit("f(f,f)")
+def compute_y_hf(x, ll):
+ """Computes y."""
+ return sqrt(1 - ll**2 * (1 - x**2))
+
+
+@vjit("f(f,f)")
+def compute_y_vf(x, ll):
+ """
+ Vectorized compute_y
+ """
+
+ return compute_y_hf(x, ll)
+
+
+@hjit("f(f,f,f,f)")
+def _tof_equation_p_hf(x, y, T, ll):
+ # TODO: What about derivatives when x approaches 1?
+ return (3 * T * x - 2 + 2 * ll**3 * x / y) / (1 - x**2)
+
+
+@hjit("f(f,f,f,f,f)")
+def _tof_equation_p2_hf(x, y, T, dT, ll):
+ return (3 * T + 5 * x * dT + 2 * (1 - ll**2) * ll**3 / y**3) / (1 - x**2)
+
+
+@hjit("f(f,f,f,f,f,f)")
+def _tof_equation_p3_hf(x, y, _, dT, ddT, ll):
+ return (7 * x * ddT + 8 * dT - 6 * (1 - ll**2) * ll**5 * x / y**5) / (
+ 1 - x**2
+ )
+
+
+@hjit("f(f,f,i8,b1)")
+def _initial_guess_hf(T, ll, M, lowpath):
+ """Initial guess."""
+ if M == 0:
+ # Single revolution
+ T_0 = acos(ll) + ll * sqrt(1 - ll**2) + M * pi # Equation 19
+ T_1 = 2 * (1 - ll**3) / 3 # Equation 21
+ if T >= T_0:
+ x_0 = (T_0 / T) ** (2 / 3) - 1
+ elif T < T_1:
+ x_0 = 5 / 2 * T_1 / T * (T_1 - T) / (1 - ll**5) + 1
+ else:
+ # This is the real condition, which is not exactly equivalent
+ # elif T_1 < T < T_0
+ # Corrected initial guess,
+ # piecewise equation right after expression (30) in the original paper is incorrect
+ # See https://github.com/poliastro/poliastro/issues/1362
+ x_0 = exp(log(2) * log(T / T_0) / log(T_1 / T_0)) - 1
+
+ return x_0
+ else:
+ # Multiple revolution
+ x_0l = (((M * pi + pi) / (8 * T)) ** (2 / 3) - 1) / (
+ ((M * pi + pi) / (8 * T)) ** (2 / 3) + 1
+ )
+ x_0r = (((8 * T) / (M * pi)) ** (2 / 3) - 1) / (
+ ((8 * T) / (M * pi)) ** (2 / 3) + 1
+ )
+
+ # Select one of the solutions according to desired type of path
+ x_0 = max((x_0l, x_0r)) if lowpath else min((x_0l, x_0r))
+
+ return x_0
+
+
+@hjit("Tuple([f,f,f,f])(f,f,f,f,f,f,f,f)")
+def _reconstruct_hf(x, y, r1, r2, ll, gamma, rho, sigma):
+ """Reconstruct solution velocity vectors."""
+ V_r1 = gamma * ((ll * y - x) - rho * (ll * y + x)) / r1
+ V_r2 = -gamma * ((ll * y - x) + rho * (ll * y + x)) / r2
+ V_t1 = gamma * sigma * (y + ll * x) / r1
+ V_t2 = gamma * sigma * (y + ll * x) / r2
+ return V_r1, V_r2, V_t1, V_t2
+
+
+@hjit("f(f,f,f)")
+def _compute_psi_hf(x, y, ll):
+ """Computes psi.
+
+ "The auxiliary angle psi is computed using Eq.(17) by the appropriate
+ inverse function"
+
+ """
+ if -1 <= x < 1:
+ # Elliptic motion
+ # Use arc cosine to avoid numerical errors
+ return acos(x * y + ll * (1 - x**2))
+ elif x > 1:
+ # Hyperbolic motion
+ # The hyperbolic sine is bijective
+ return asinh((y - x * ll) * sqrt(x**2 - 1))
+ else:
+ # Parabolic motion
+ return 0.0
+
+
+@hjit("f(f,f,f,f,i8)")
+def tof_equation_y_hf(x, y, T0, ll, M):
+ """Time of flight equation with externally computated y."""
+ if M == 0 and sqrt(0.6) < x < sqrt(1.4):
+ eta = y - ll * x
+ S_1 = (1 - ll - x * eta) * 0.5
+ Q = 4 / 3 * hyp2f1b_hf(S_1)
+ T_ = (eta**3 * Q + 4 * ll * eta) * 0.5
+ else:
+ psi = _compute_psi_hf(x, y, ll)
+ T_ = div_ss_hf(
+ (div_ss_hf((psi + M * pi), sqrt(abs(1 - x**2))) - x + ll * y),
+ (1 - x**2),
+ )
+ return T_ - T0
+
+
+@vjit("f(f,f,f,f,i8)")
+def tof_equation_y_vf(x, y, T0, ll, M):
+ """
+ Vectorized tof_equation_y
+ """
+
+ return tof_equation_y_hf(x, y, T0, ll, M)
+
+
+@hjit("f(f,f,f,i8)")
+def _tof_equation_hf(x, T0, ll, M):
+ """Time of flight equation."""
+ return tof_equation_y_hf(x, compute_y_hf(x, ll), T0, ll, M)
+
+
+@hjit("f(f,f,f,f,i8)")
+def _halley_hf(p0, T0, ll, tol, maxiter):
+ """Find a minimum of time of flight equation using the Halley method.
+
+ Notes
+ -----
+ This function is private because it assumes a calling convention specific to
+ this module and is not really reusable.
+
+ """
+ for ii in range(maxiter):
+ y = compute_y_hf(p0, ll)
+ fder = _tof_equation_p_hf(p0, y, T0, ll)
+ fder2 = _tof_equation_p2_hf(p0, y, T0, fder, ll)
+ if fder2 == 0:
+ raise RuntimeError("Derivative was zero")
+ fder3 = _tof_equation_p3_hf(p0, y, T0, fder, fder2, ll)
+
+ # Halley step (cubic)
+ p = p0 - 2 * fder * fder2 / (2 * fder2**2 - fder * fder3)
+
+ if abs(p - p0) < tol:
+ return p
+ p0 = p
+
+ raise RuntimeError("Failed to converge")
+
+
+@hjit("Tuple([f,f])(f,i8,i8,f)")
+def compute_T_min_hf(ll, M, numiter, rtol):
+ """Compute minimum T."""
+ if ll == 1:
+ x_T_min = 0.0
+ T_min = _tof_equation_hf(x_T_min, 0.0, ll, M)
+ else:
+ if M == 0:
+ x_T_min = inf
+ T_min = 0.0
+ else:
+ # Set x_i > 0 to avoid problems at ll = -1
+ x_i = 0.1
+ T_i = _tof_equation_hf(x_i, 0.0, ll, M)
+ x_T_min = _halley_hf(x_i, T_i, ll, rtol, numiter)
+ T_min = _tof_equation_hf(x_T_min, 0.0, ll, M)
+
+ return x_T_min, T_min
+
+
+@gjit("void(f,i8,i8,f,f[:],f[:])", "(),(),(),()->(),()")
+def compute_T_min_gf(ll, M, numiter, rtol, x_T_min, T_min):
+ """
+ Vectorized compute_T_min
+ """
+
+ x_T_min[0], T_min[0] = compute_T_min_hf(ll, M, numiter, rtol)
+
+
+@hjit("f(f,f,f,i8,f,i8)")
+def _householder_hf(p0, T0, ll, M, tol, maxiter):
+ """Find a zero of time of flight equation using the Householder method.
+
+ Notes
+ -----
+ This function is private because it assumes a calling convention specific to
+ this module and is not really reusable.
+
+ """
+ for ii in range(maxiter):
+ y = compute_y_hf(p0, ll)
+ fval = tof_equation_y_hf(p0, y, T0, ll, M)
+ T = fval + T0
+ fder = _tof_equation_p_hf(p0, y, T, ll)
+ fder2 = _tof_equation_p2_hf(p0, y, T, fder, ll)
+ fder3 = _tof_equation_p3_hf(p0, y, T, fder, fder2, ll)
+
+ # Householder step (quartic)
+ p = p0 - fval * (
+ (fder**2 - fval * fder2 / 2)
+ / (fder * (fder**2 - fval * fder2) + fder3 * fval**2 / 6)
+ )
+
+ if abs(p - p0) < tol:
+ return p
+ p0 = p
+
+ raise RuntimeError("Failed to converge")
+
+
+@hjit("Tuple([f,f])(f,f,i8,i8,b1,f)")
+def find_xy_hf(ll, T, M, numiter, lowpath, rtol):
+ """Computes all x, y for given number of revolutions."""
+ # For abs(ll) == 1 the derivative is not continuous
+ assert abs(ll) < 1
+ assert T > 0 # Mistake in the original paper
+
+ M_max = floor(T / pi)
+ T_00 = acos(ll) + ll * sqrt(1 - ll**2) # T_xM
+
+ # Refine maximum number of revolutions if necessary
+ if T < T_00 + M_max * pi and M_max > 0:
+ _, T_min = compute_T_min_hf(ll, M_max, numiter, rtol)
+ if T < T_min:
+ M_max -= 1
+
+ # Check if a feasible solution exist for the given number of revolutions
+ # This departs from the original paper in that we do not compute all solutions
+ if M > M_max:
+ raise ValueError("No feasible solution, try lower M")
+
+ # Initial guess
+ x_0 = _initial_guess_hf(T, ll, M, lowpath)
+
+ # Start Householder iterations from x_0 and find x, y
+ x = _householder_hf(x_0, T, ll, M, rtol, numiter)
+ y = compute_y_hf(x, ll)
+
+ return x, y
-from hapsira._math.linalg import norm
-from hapsira._math.special import hyp2f1b, stumpff_c2 as c2, stumpff_c3 as c3
+@gjit(
+ "void(f,f,i8,i8,b1,f,f[:],f[:])",
+ "(),(),(),(),(),()->(),()",
+)
+def find_xy_gf(ll, T, M, numiter, lowpath, rtol, x, y):
+ """
+ Vectorized find_xy
+ """
+
+ x[0], y[0] = find_xy_hf(ll, T, M, numiter, lowpath, rtol)
-@jit
-def vallado(k, r0, r, tof, M, prograde, lowpath, numiter, rtol):
+
+@hjit("Tuple([V,V])(f,V,V,f,i8,b1,b1,i8,f)")
+def vallado_hf(k, r0, r, tof, M, prograde, lowpath, numiter, rtol):
r"""Solves the Lambert's problem.
The algorithm returns the initial velocity vector and the final one, these are
@@ -53,9 +339,9 @@ def vallado(k, r0, r, tof, M, prograde, lowpath, numiter, rtol):
----------
k : float
Gravitational Parameter
- r0 : numpy.ndarray
+ r0 : tuple[float,float,float]
Initial position vector
- r : numpy.ndarray
+ r : tuple[float,float,float]
Final position vector
tof : float
Time of flight
@@ -73,9 +359,9 @@ def vallado(k, r0, r, tof, M, prograde, lowpath, numiter, rtol):
Returns
-------
- v0: numpy.ndarray
+ v0: tuple[float,float,float]
Initial velocity vector
- v: numpy.ndarray
+ v: tuple[float,float,float]
Final velocity vector
Examples
@@ -102,20 +388,20 @@ def vallado(k, r0, r, tof, M, prograde, lowpath, numiter, rtol):
"""
# TODO: expand for the multi-revolution case.
- # Issue: https://github.com/hapsira/hapsira/issues/858
+ # Issue: https://github.com/poliastro/poliastro/issues/858
if M > 0:
raise NotImplementedError(
- "Multi-revolution scenario not supported for Vallado. See issue https://github.com/hapsira/hapsira/issues/858"
+ "Multi-revolution scenario not supported for Vallado. See issue https://github.com/poliastro/poliastro/issues/858"
)
t_m = 1 if prograde else -1
- norm_r0 = norm(r0)
- norm_r = norm(r)
+ norm_r0 = norm_V_hf(r0)
+ norm_r = norm_V_hf(r)
norm_r0_times_norm_r = norm_r0 * norm_r
norm_r0_plus_norm_r = norm_r0 + norm_r
- cos_dnu = (r0 @ r) / norm_r0_times_norm_r
+ cos_dnu = matmul_VV_hf(r0, r) / norm_r0_times_norm_r
A = t_m * (norm_r * norm_r0 * (1 + cos_dnu)) ** 0.5
@@ -123,13 +409,16 @@ def vallado(k, r0, r, tof, M, prograde, lowpath, numiter, rtol):
raise RuntimeError("Cannot compute orbit, phase angle is 180 degrees")
psi = 0.0
- psi_low = -4 * np.pi**2
- psi_up = 4 * np.pi**2
+ psi_low = -4 * pi**2
+ psi_up = 4 * pi**2
count = 0
while count < numiter:
- y = norm_r0_plus_norm_r + A * (psi * c3(psi) - 1) / c2(psi) ** 0.5
+ y = (
+ norm_r0_plus_norm_r
+ + A * (psi * stumpff_c3_hf(psi) - 1) / stumpff_c2_hf(psi) ** 0.5
+ )
if A > 0.0:
# Readjust xi_low until y > 0.0
# Translated directly from Vallado
@@ -137,16 +426,19 @@ def vallado(k, r0, r, tof, M, prograde, lowpath, numiter, rtol):
psi_low = psi
psi = (
0.8
- * (1.0 / c3(psi))
- * (1.0 - norm_r0_times_norm_r * np.sqrt(c2(psi)) / A)
+ * (1.0 / stumpff_c3_hf(psi))
+ * (1.0 - norm_r0_times_norm_r * sqrt(stumpff_c2_hf(psi)) / A)
+ )
+ y = (
+ norm_r0_plus_norm_r
+ + A * (psi * stumpff_c3_hf(psi) - 1) / stumpff_c2_hf(psi) ** 0.5
)
- y = norm_r0_plus_norm_r + A * (psi * c3(psi) - 1) / c2(psi) ** 0.5
- xi = np.sqrt(y / c2(psi))
- tof_new = (xi**3 * c3(psi) + A * np.sqrt(y)) / np.sqrt(k)
+ xi = sqrt(y / stumpff_c2_hf(psi))
+ tof_new = (xi**3 * stumpff_c3_hf(psi) + A * sqrt(y)) / sqrt(k)
# Convergence check
- if np.abs((tof_new - tof) / tof) < rtol:
+ if abs((tof_new - tof) / tof) < rtol:
break
count += 1
# Bisection check
@@ -159,27 +451,41 @@ def vallado(k, r0, r, tof, M, prograde, lowpath, numiter, rtol):
raise RuntimeError("Maximum number of iterations reached")
f = 1 - y / norm_r0
- g = A * np.sqrt(y / k)
+ g = A * sqrt(y / k)
gdot = 1 - y / norm_r
- v0 = (r - f * r0) / g
- v = (gdot * r - r0) / g
+ v0 = div_Vs_hf(sub_VV_hf(r, mul_Vs_hf(r0, f)), g)
+ v = div_Vs_hf(sub_VV_hf(mul_Vs_hf(r, gdot), r0), g)
return v0, v
-@jit
-def izzo(k, r1, r2, tof, M, prograde, lowpath, numiter, rtol):
+@gjit(
+ "void(f,f[:],f[:],f,i8,b1,b1,i8,f,f[:],f[:])",
+ "(),(n),(n),(),(),(),(),(),()->(n),(n)",
+)
+def vallado_gf(k, r0, r, tof, M, prograde, lowpath, numiter, rtol, v0, v):
+ """
+ Vectorized vallado
+ """
+
+ ((v0[0], v0[1], v0[2]), (v[0], v[1], v[2])) = vallado_hf(
+ k, array_to_V_hf(r0), array_to_V_hf(r), tof, M, prograde, lowpath, numiter, rtol
+ )
+
+
+@hjit("Tuple([V,V])(f,V,V,f,i8,b1,b1,i8,f)")
+def izzo_hf(k, r1, r2, tof, M, prograde, lowpath, numiter, rtol):
"""Aplies izzo algorithm to solve Lambert's problem.
Parameters
----------
k : float
Gravitational Constant
- r1 : numpy.ndarray
+ r1 : tuple[float,float,float]
Initial position vector
- r2 : numpy.ndarray
+ r2 : tuple[float,float,float]
Final position vector
tof : float
Time of flight between both positions
@@ -197,9 +503,9 @@ def izzo(k, r1, r2, tof, M, prograde, lowpath, numiter, rtol):
Returns
-------
- v1: numpy.ndarray
+ v1: tuple[float,float,float]
Initial velocity vector
- v2: numpy.ndarray
+ v2: tuple[float,float,float]
Final velocity vector
"""
@@ -208,278 +514,83 @@ def izzo(k, r1, r2, tof, M, prograde, lowpath, numiter, rtol):
assert k > 0
# Check collinearity of r1 and r2
- if not cross(r1, r2).any():
+ cl = cross_VV_hf(r1, r2)
+ if cl[0] == 0 and cl[1] == 0 and cl[2] == 0:
raise ValueError("Lambert solution cannot be computed for collinear vectors")
# Chord
- c = r2 - r1
- c_norm, r1_norm, r2_norm = norm(c), norm(r1), norm(r2)
+ c = sub_VV_hf(r2, r1)
+ c_norm, r1_norm, r2_norm = (
+ norm_V_hf(c),
+ norm_V_hf(r1),
+ norm_V_hf(r2),
+ )
# Semiperimeter
s = (r1_norm + r2_norm + c_norm) * 0.5
# Versors
- i_r1, i_r2 = r1 / r1_norm, r2 / r2_norm
- i_h = cross(i_r1, i_r2)
- i_h = i_h / norm(i_h) # Fixed from paper
+ i_r1, i_r2 = div_Vs_hf(r1, r1_norm), div_Vs_hf(r2, r2_norm)
+ i_h = cross_VV_hf(i_r1, i_r2)
+ i_h = div_Vs_hf(i_h, norm_V_hf(i_h)) # Fixed from paper
# Geometry of the problem
- ll = np.sqrt(1 - min(1.0, c_norm / s))
+ ll = sqrt(1 - min(1.0, c_norm / s))
# Compute the fundamental tangential directions
if i_h[2] < 0:
ll = -ll
- i_t1, i_t2 = cross(i_r1, i_h), cross(i_r2, i_h)
+ i_t1, i_t2 = cross_VV_hf(i_r1, i_h), cross_VV_hf(i_r2, i_h)
else:
- i_t1, i_t2 = cross(i_h, i_r1), cross(i_h, i_r2)
+ i_t1, i_t2 = cross_VV_hf(i_h, i_r1), cross_VV_hf(i_h, i_r2)
# Correct transfer angle parameter and tangential vectors if required
- ll, i_t1, i_t2 = (ll, i_t1, i_t2) if prograde else (-ll, -i_t1, -i_t2)
+ ll, i_t1, i_t2 = (
+ (ll, i_t1, i_t2)
+ if prograde
+ else (-ll, mul_Vs_hf(i_t1, -1), mul_Vs_hf(i_t2, -1))
+ )
# Non dimensional time of flight
- T = np.sqrt(2 * k / s**3) * tof
+ T = sqrt(2 * k / s**3) * tof
# Find solutions
- x, y = _find_xy(ll, T, M, numiter, lowpath, rtol)
+ x, y = find_xy_hf(ll, T, M, numiter, lowpath, rtol)
# Reconstruct
- gamma = np.sqrt(k * s / 2)
+ gamma = sqrt(k * s / 2)
rho = (r1_norm - r2_norm) / c_norm
- sigma = np.sqrt(1 - rho**2)
+ sigma = sqrt(1 - rho**2)
# Compute the radial and tangential components at r0 and r
- V_r1, V_r2, V_t1, V_t2 = _reconstruct(x, y, r1_norm, r2_norm, ll, gamma, rho, sigma)
+ V_r1, V_r2, V_t1, V_t2 = _reconstruct_hf(
+ x, y, r1_norm, r2_norm, ll, gamma, rho, sigma
+ )
# Solve for the initial and final velocity
- v1 = V_r1 * (r1 / r1_norm) + V_t1 * i_t1
- v2 = V_r2 * (r2 / r2_norm) + V_t2 * i_t2
+ v1 = add_VV_hf(mul_Vs_hf(div_Vs_hf(r1, r1_norm), V_r1), mul_Vs_hf(i_t1, V_t1))
+ v2 = add_VV_hf(mul_Vs_hf(div_Vs_hf(r2, r2_norm), V_r2), mul_Vs_hf(i_t2, V_t2))
return v1, v2
-@jit
-def _reconstruct(x, y, r1, r2, ll, gamma, rho, sigma):
- """Reconstruct solution velocity vectors."""
- V_r1 = gamma * ((ll * y - x) - rho * (ll * y + x)) / r1
- V_r2 = -gamma * ((ll * y - x) + rho * (ll * y + x)) / r2
- V_t1 = gamma * sigma * (y + ll * x) / r1
- V_t2 = gamma * sigma * (y + ll * x) / r2
- return V_r1, V_r2, V_t1, V_t2
-
-
-@jit
-def _find_xy(ll, T, M, numiter, lowpath, rtol):
- """Computes all x, y for given number of revolutions."""
- # For abs(ll) == 1 the derivative is not continuous
- assert abs(ll) < 1
- assert T > 0 # Mistake in the original paper
-
- M_max = np.floor(T / pi)
- T_00 = np.arccos(ll) + ll * np.sqrt(1 - ll**2) # T_xM
-
- # Refine maximum number of revolutions if necessary
- if T < T_00 + M_max * pi and M_max > 0:
- _, T_min = _compute_T_min(ll, M_max, numiter, rtol)
- if T < T_min:
- M_max -= 1
-
- # Check if a feasible solution exist for the given number of revolutions
- # This departs from the original paper in that we do not compute all solutions
- if M > M_max:
- raise ValueError("No feasible solution, try lower M")
-
- # Initial guess
- x_0 = _initial_guess(T, ll, M, lowpath)
-
- # Start Householder iterations from x_0 and find x, y
- x = _householder(x_0, T, ll, M, rtol, numiter)
- y = _compute_y(x, ll)
-
- return x, y
-
-
-@jit
-def _compute_y(x, ll):
- """Computes y."""
- return np.sqrt(1 - ll**2 * (1 - x**2))
-
-
-@jit
-def _compute_psi(x, y, ll):
- """Computes psi.
-
- "The auxiliary angle psi is computed using Eq.(17) by the appropriate
- inverse function"
-
+@gjit(
+ "void(f,f[:],f[:],f,i8,b1,b1,i8,f,f[:],f[:])",
+ "(),(n),(n),(),(),(),(),(),()->(n),(n)",
+)
+def izzo_gf(k, r1, r2, tof, M, prograde, lowpath, numiter, rtol, v1, v2):
"""
- if -1 <= x < 1:
- # Elliptic motion
- # Use arc cosine to avoid numerical errors
- return np.arccos(x * y + ll * (1 - x**2))
- elif x > 1:
- # Hyperbolic motion
- # The hyperbolic sine is bijective
- return np.arcsinh((y - x * ll) * np.sqrt(x**2 - 1))
- else:
- # Parabolic motion
- return 0.0
-
-
-@jit
-def _tof_equation(x, T0, ll, M):
- """Time of flight equation."""
- return _tof_equation_y(x, _compute_y(x, ll), T0, ll, M)
-
-
-@jit
-def _tof_equation_y(x, y, T0, ll, M):
- """Time of flight equation with externally computated y."""
- if M == 0 and np.sqrt(0.6) < x < np.sqrt(1.4):
- eta = y - ll * x
- S_1 = (1 - ll - x * eta) * 0.5
- Q = 4 / 3 * hyp2f1b(S_1)
- T_ = (eta**3 * Q + 4 * ll * eta) * 0.5
- else:
- psi = _compute_psi(x, y, ll)
- T_ = np.divide(
- np.divide(psi + M * pi, np.sqrt(np.abs(1 - x**2))) - x + ll * y,
- (1 - x**2),
- )
-
- return T_ - T0
-
-
-@jit
-def _tof_equation_p(x, y, T, ll):
- # TODO: What about derivatives when x approaches 1?
- return (3 * T * x - 2 + 2 * ll**3 * x / y) / (1 - x**2)
-
-
-@jit
-def _tof_equation_p2(x, y, T, dT, ll):
- return (3 * T + 5 * x * dT + 2 * (1 - ll**2) * ll**3 / y**3) / (1 - x**2)
-
-
-@jit
-def _tof_equation_p3(x, y, _, dT, ddT, ll):
- return (7 * x * ddT + 8 * dT - 6 * (1 - ll**2) * ll**5 * x / y**5) / (
- 1 - x**2
- )
-
-
-@jit
-def _compute_T_min(ll, M, numiter, rtol):
- """Compute minimum T."""
- if ll == 1:
- x_T_min = 0.0
- T_min = _tof_equation(x_T_min, 0.0, ll, M)
- else:
- if M == 0:
- x_T_min = np.inf
- T_min = 0.0
- else:
- # Set x_i > 0 to avoid problems at ll = -1
- x_i = 0.1
- T_i = _tof_equation(x_i, 0.0, ll, M)
- x_T_min = _halley(x_i, T_i, ll, rtol, numiter)
- T_min = _tof_equation(x_T_min, 0.0, ll, M)
-
- return x_T_min, T_min
-
-
-@jit
-def _initial_guess(T, ll, M, lowpath):
- """Initial guess."""
- if M == 0:
- # Single revolution
- T_0 = np.arccos(ll) + ll * np.sqrt(1 - ll**2) + M * pi # Equation 19
- T_1 = 2 * (1 - ll**3) / 3 # Equation 21
- if T >= T_0:
- x_0 = (T_0 / T) ** (2 / 3) - 1
- elif T < T_1:
- x_0 = 5 / 2 * T_1 / T * (T_1 - T) / (1 - ll**5) + 1
- else:
- # This is the real condition, which is not exactly equivalent
- # elif T_1 < T < T_0
- # Corrected initial guess,
- # piecewise equation right after expression (30) in the original paper is incorrect
- # See https://github.com/hapsira/hapsira/issues/1362
- x_0 = np.exp(np.log(2) * np.log(T / T_0) / np.log(T_1 / T_0)) - 1
-
- return x_0
- else:
- # Multiple revolution
- x_0l = (((M * pi + pi) / (8 * T)) ** (2 / 3) - 1) / (
- ((M * pi + pi) / (8 * T)) ** (2 / 3) + 1
- )
- x_0r = (((8 * T) / (M * pi)) ** (2 / 3) - 1) / (
- ((8 * T) / (M * pi)) ** (2 / 3) + 1
- )
-
- # Select one of the solutions according to desired type of path
- x_0 = (
- np.max(np.array([x_0l, x_0r]))
- if lowpath
- else np.min(np.array([x_0l, x_0r]))
- )
-
- return x_0
-
-
-@jit
-def _halley(p0, T0, ll, tol, maxiter):
- """Find a minimum of time of flight equation using the Halley method.
-
- Notes
- -----
- This function is private because it assumes a calling convention specific to
- this module and is not really reusable.
-
- """
- for ii in range(maxiter):
- y = _compute_y(p0, ll)
- fder = _tof_equation_p(p0, y, T0, ll)
- fder2 = _tof_equation_p2(p0, y, T0, fder, ll)
- if fder2 == 0:
- raise RuntimeError("Derivative was zero")
- fder3 = _tof_equation_p3(p0, y, T0, fder, fder2, ll)
-
- # Halley step (cubic)
- p = p0 - 2 * fder * fder2 / (2 * fder2**2 - fder * fder3)
-
- if abs(p - p0) < tol:
- return p
- p0 = p
-
- raise RuntimeError("Failed to converge")
-
-
-@jit
-def _householder(p0, T0, ll, M, tol, maxiter):
- """Find a zero of time of flight equation using the Householder method.
-
- Notes
- -----
- This function is private because it assumes a calling convention specific to
- this module and is not really reusable.
-
+ Vectorized izzo
"""
- for ii in range(maxiter):
- y = _compute_y(p0, ll)
- fval = _tof_equation_y(p0, y, T0, ll, M)
- T = fval + T0
- fder = _tof_equation_p(p0, y, T, ll)
- fder2 = _tof_equation_p2(p0, y, T, fder, ll)
- fder3 = _tof_equation_p3(p0, y, T, fder, fder2, ll)
-
- # Householder step (quartic)
- p = p0 - fval * (
- (fder**2 - fval * fder2 / 2)
- / (fder * (fder**2 - fval * fder2) + fder3 * fval**2 / 6)
- )
- if abs(p - p0) < tol:
- return p
- p0 = p
-
- raise RuntimeError("Failed to converge")
+ ((v1[0], v1[1], v1[2]), (v2[0], v2[1], v2[2])) = izzo_hf(
+ k,
+ array_to_V_hf(r1),
+ array_to_V_hf(r2),
+ tof,
+ M,
+ prograde,
+ lowpath,
+ numiter,
+ rtol,
+ )
diff --git a/src/hapsira/core/jit.py b/src/hapsira/core/jit.py
new file mode 100644
index 000000000..51d97f7cd
--- /dev/null
+++ b/src/hapsira/core/jit.py
@@ -0,0 +1,323 @@
+from typing import Callable, List, Union
+
+import numba as nb
+from numba import cuda
+
+from hapsira.debug import logger
+from hapsira.errors import JitError
+from hapsira.settings import settings
+
+
+__all__ = [
+ "DSIG",
+ "hjit",
+ "djit",
+ "vjit",
+ "gjit",
+ "sjit",
+ "array_to_V_hf",
+]
+
+
+logger.debug("jit target: %s", settings["TARGET"].value)
+if settings["TARGET"].value == "cuda" and not cuda.is_available():
+ raise JitError('selected target "cuda" is not available')
+
+logger.debug("jit inline: %s", "yes" if settings["INLINE"].value else "no")
+logger.debug("jit nopython: %s", "yes" if settings["NOPYTHON"].value else "no")
+
+_PRECISIONS = (
+ settings["PRECISION"].value,
+) # TODO again allow to compile for multiple precision?
+logger.debug("jit precision: %s", settings["PRECISION"].value)
+if settings["PRECISION"].value != "f8":
+ logger.warning("jit precision: DOP853 as used by Cowell's method requires f8!")
+
+DSIG = "Tuple([V,V])(f,V,V,f)"
+
+
+def _parse_signatures(signature: str, noreturn: bool = False) -> Union[str, List[str]]:
+ """
+ Automatically generate signatures for single and double
+ """
+
+ if "->" in signature: # this is likely a layout for guvectorize
+ logger.warning(
+ "jit signature: likely a layout for guvectorize, not parsing (%s)",
+ signature,
+ )
+ return signature
+
+ if noreturn and not signature.startswith("void("):
+ raise JitError(
+ "function does not allow return values, likely compiled via guvectorize"
+ )
+
+ if not any(
+ notation in signature for notation in ("f", "V", "M")
+ ): # leave this signature as it is
+ logger.warning(
+ "jit signature: no special notation, not parsing (%s)", signature
+ )
+ return signature
+
+ if any(
+ level in signature for level in _PRECISIONS
+ ): # leave this signature as it is
+ logger.warning(
+ "jit signature: precision specified, not parsing (%s)", signature
+ )
+ return signature
+
+ signature = signature.replace("M", "Tuple([V,V,V])") # matrix is a tuple of vectors
+ signature = signature.replace("V", "Tuple([f,f,f])") # vector is a tuple of floats
+ signature = signature.replace(
+ "F", "FunctionType"
+ ) # TODO does not work for CUDA yet
+
+ return [signature.replace("f", dtype) for dtype in _PRECISIONS]
+
+
+def hjit(*args, **kwargs) -> Callable:
+ """
+ Scalar helper, pre-configured, internal, switches compiler targets.
+ Functions decorated by it can only be called directly if TARGET is cpu or parallel.
+ """
+
+ if len(args) == 1 and callable(args[0]):
+ outer_func = args[0]
+ args = tuple()
+ else:
+ outer_func = None
+
+ if len(args) > 0 and isinstance(args[0], str):
+ args = _parse_signatures(args[0]), *args[1:]
+
+ try:
+ inline = kwargs.pop("inline")
+ except KeyError:
+ inline = settings["INLINE"].value
+
+ def wrapper(inner_func: Callable) -> Callable:
+ """
+ Applies JIT
+ """
+
+ if settings["TARGET"].value == "cuda":
+ wjit = cuda.jit
+ cfg = dict(
+ device=True,
+ inline=inline,
+ cache=settings["CACHE"].value,
+ )
+ else:
+ assert settings["NOPYTHON"].value != settings["FORCEOBJ"].value
+ wjit = nb.jit
+ cfg = dict(
+ nopython=settings["NOPYTHON"].value,
+ forceobj=settings["FORCEOBJ"].value,
+ inline="always" if inline else "never",
+ cache=settings["CACHE"].value,
+ )
+ cfg.update(kwargs)
+
+ logger.debug(
+ "hjit: func=%s, args=%s, kwargs=%s",
+ getattr(inner_func, "__name__", repr(inner_func)),
+ repr(args),
+ repr(cfg),
+ )
+
+ return wjit(
+ *args,
+ **cfg,
+ )(inner_func)
+
+ if outer_func is not None:
+ return wrapper(outer_func)
+
+ return wrapper
+
+
+def djit(*args, **kwargs) -> Callable:
+ """
+ Wrapper for hjit to track differential equations
+ """
+
+ if len(args) == 1 and callable(args[0]):
+ outer_func = args[0]
+ args = tuple()
+ else:
+ outer_func = None
+
+ def wrapper(inner_func: Callable) -> Callable:
+ """
+ Applies JIT
+ """
+
+ logger.debug(
+ "djit: func=%s, args=%s, kwargs=%s",
+ getattr(inner_func, "__name__", repr(inner_func)),
+ repr(args),
+ repr(kwargs),
+ )
+
+ compiled = hjit(
+ DSIG,
+ *args,
+ **kwargs,
+ )(inner_func)
+ compiled.djit = None # attribute for debugging
+ return compiled
+
+ if outer_func is not None:
+ return wrapper(outer_func)
+
+ return wrapper
+
+
+def vjit(*args, **kwargs) -> Callable:
+ """
+ Vectorize on array, pre-configured, user-facing, switches compiler targets.
+ Functions decorated by it can always be called directly if needed.
+ """
+
+ if len(args) == 1 and callable(args[0]):
+ outer_func = args[0]
+ args = tuple()
+ else:
+ outer_func = None
+
+ if len(args) > 0 and isinstance(args[0], str):
+ args = _parse_signatures(args[0]), *args[1:]
+
+ def wrapper(inner_func: Callable) -> Callable:
+ """
+ Applies JIT
+ """
+
+ cfg = dict(
+ target=settings["TARGET"].value,
+ cache=settings["CACHE"].value,
+ )
+ if settings["TARGET"].value != "cuda":
+ assert settings["NOPYTHON"].value != settings["FORCEOBJ"].value
+ cfg["nopython"] = settings["NOPYTHON"].value
+ cfg["forceobj"] = settings["FORCEOBJ"].value
+ cfg.update(kwargs)
+
+ logger.debug(
+ "vjit: func=%s, args=%s, kwargs=%s",
+ getattr(inner_func, "__name__", repr(inner_func)),
+ repr(args),
+ repr(cfg),
+ )
+
+ return nb.vectorize(
+ *args,
+ **cfg,
+ )(inner_func)
+
+ if outer_func is not None:
+ return wrapper(outer_func)
+
+ return wrapper
+
+
+def gjit(*args, **kwargs) -> Callable:
+ """
+ General vectorize on array, pre-configured, user-facing, switches compiler targets.
+ Functions decorated by it can always be called directly if needed.
+ """
+
+ if len(args) == 1 and callable(args[0]):
+ outer_func = args[0]
+ args = tuple()
+ else:
+ outer_func = None
+
+ if len(args) > 0 and isinstance(args[0], str):
+ args = _parse_signatures(args[0], noreturn=True), *args[1:]
+
+ def wrapper(inner_func: Callable) -> Callable:
+ """
+ Applies JIT
+ """
+
+ cfg = dict(
+ target=settings["TARGET"].value,
+ cache=settings["CACHE"].value,
+ )
+ if settings["TARGET"].value != "cuda":
+ assert settings["NOPYTHON"].value != settings["FORCEOBJ"].value
+ cfg["nopython"] = settings["NOPYTHON"].value
+ cfg["forceobj"] = settings["FORCEOBJ"].value
+ cfg.update(kwargs)
+
+ logger.debug(
+ "gjit: func=%s, args=%s, kwargs=%s",
+ getattr(inner_func, "__name__", repr(inner_func)),
+ repr(args),
+ repr(cfg),
+ )
+
+ return nb.guvectorize(
+ *args,
+ **cfg,
+ )(inner_func)
+
+ if outer_func is not None:
+ return wrapper(outer_func)
+
+ return wrapper
+
+
+def sjit(*args, **kwargs) -> Callable:
+ """
+ Regular "scalar" (n)jit, pre-configured, potentially user-facing, always CPU compiler target.
+ Functions decorated by it can always be called directly if needed.
+ """
+
+ if len(args) == 1 and callable(args[0]):
+ outer_func = args[0]
+ args = tuple()
+ else:
+ outer_func = None
+
+ if len(args) > 0 and isinstance(args[0], str):
+ args = _parse_signatures(args[0]), *args[1:]
+
+ def wrapper(inner_func: Callable) -> Callable:
+ """
+ Applies JIT
+ """
+
+ assert settings["NOPYTHON"].value != settings["FORCEOBJ"].value
+ cfg = dict(
+ nopython=settings["NOPYTHON"].value,
+ forceobj=settings["FORCEOBJ"].value,
+ inline="always" if settings["INLINE"].value else "never",
+ **kwargs,
+ )
+
+ logger.debug(
+ "sjit: func=%s, args=%s, kwargs=%s",
+ getattr(inner_func, "__name__", repr(inner_func)),
+ repr(args),
+ repr(cfg),
+ )
+
+ return nb.jit(
+ *args,
+ **cfg,
+ )(inner_func)
+
+ if outer_func is not None:
+ return wrapper(outer_func)
+
+ return wrapper
+
+
+@hjit("V(f[:])")
+def array_to_V_hf(x):
+ return x[0], x[1], x[2]
diff --git a/src/hapsira/core/maneuver.py b/src/hapsira/core/maneuver.py
index 548e5aea6..25559aaa4 100644
--- a/src/hapsira/core/maneuver.py
+++ b/src/hapsira/core/maneuver.py
@@ -4,8 +4,15 @@
import numpy as np
from numpy import cross
-from hapsira._math.linalg import norm
-from hapsira.core.elements import coe_rotation_matrix, rv2coe, rv_pqw
+from hapsira.core.elements import (
+ coe_rotation_matrix_hf,
+ rv2coe_hf,
+ RV2COE_TOL,
+ rv_pqw_hf,
+)
+
+from .jit import array_to_V_hf
+from .math.linalg import norm_V_hf
@jit
@@ -40,14 +47,16 @@ def hohmann(k, rv, r_f):
Final orbital radius
"""
- _, ecc, inc, raan, argp, nu = rv2coe(k, *rv)
- h_i = norm(cross(*rv))
+ _, ecc, inc, raan, argp, nu = rv2coe_hf(
+ k, array_to_V_hf(rv[0]), array_to_V_hf(rv[1]), RV2COE_TOL
+ )
+ h_i = norm_V_hf(array_to_V_hf(cross(*rv)))
p_i = h_i**2 / k
- r_i, v_i = rv_pqw(k, p_i, ecc, nu)
+ r_i, v_i = rv_pqw_hf(k, p_i, ecc, nu)
- r_i = norm(r_i)
- v_i = norm(v_i)
+ r_i = norm_V_hf(r_i)
+ v_i = norm_V_hf(v_i)
a_trans = (r_i + r_f) / 2
dv_a = np.sqrt(2 * k / r_i - k / a_trans) - v_i
@@ -56,7 +65,7 @@ def hohmann(k, rv, r_f):
dv_a = np.array([0, dv_a, 0])
dv_b = np.array([0, -dv_b, 0])
- rot_matrix = coe_rotation_matrix(inc, raan, argp)
+ rot_matrix = np.array(coe_rotation_matrix_hf(inc, raan, argp))
dv_a = rot_matrix @ dv_a
dv_b = rot_matrix @ dv_b
@@ -110,14 +119,16 @@ def bielliptic(k, r_b, r_f, rv):
Position and velocity vectors
"""
- _, ecc, inc, raan, argp, nu = rv2coe(k, *rv)
- h_i = norm(cross(*rv))
+ _, ecc, inc, raan, argp, nu = rv2coe_hf(
+ k, array_to_V_hf(rv[0]), array_to_V_hf(rv[1]), RV2COE_TOL
+ )
+ h_i = norm_V_hf(array_to_V_hf(cross(*rv)))
p_i = h_i**2 / k
- r_i, v_i = rv_pqw(k, p_i, ecc, nu)
+ r_i, v_i = rv_pqw_hf(k, p_i, ecc, nu)
- r_i = norm(r_i)
- v_i = norm(v_i)
+ r_i = norm_V_hf(r_i)
+ v_i = norm_V_hf(v_i)
a_trans1 = (r_i + r_b) / 2
a_trans2 = (r_b + r_f) / 2
@@ -129,7 +140,7 @@ def bielliptic(k, r_b, r_f, rv):
dv_b = np.array([0, -dv_b, 0])
dv_c = np.array([0, dv_c, 0])
- rot_matrix = coe_rotation_matrix(inc, raan, argp)
+ rot_matrix = np.array(coe_rotation_matrix_hf(inc, raan, argp))
dv_a = rot_matrix @ dv_a
dv_b = rot_matrix @ dv_b
@@ -190,6 +201,6 @@ def correct_pericenter(k, R, J2, max_delta_r, v, a, inc, ecc):
delta_t = abs(delta_w / dw)
delta_v = 0.5 * n * a * ecc * abs(delta_w)
- vf_ = v / norm(v) * delta_v
+ vf_ = v / norm_V_hf(array_to_V_hf(v)) * delta_v
return delta_t, vf_
diff --git a/src/hapsira/_math/__init__.py b/src/hapsira/core/math/__init__.py
similarity index 100%
rename from src/hapsira/_math/__init__.py
rename to src/hapsira/core/math/__init__.py
diff --git a/src/hapsira/core/math/ieee754.py b/src/hapsira/core/math/ieee754.py
new file mode 100644
index 000000000..ae9cb3ec0
--- /dev/null
+++ b/src/hapsira/core/math/ieee754.py
@@ -0,0 +1,32 @@
+from numpy import (
+ float64 as f8,
+ float32 as f4,
+ float16 as f2,
+ finfo,
+ nextafter, # TODO switch to math module
+)
+
+from ...settings import settings
+
+
+__all__ = [
+ "EPS",
+ "f8",
+ "f4",
+ "f2",
+ "float_",
+ "nextafter",
+]
+
+
+if settings["PRECISION"].value == "f8":
+ float_ = f8
+elif settings["PRECISION"].value == "f4":
+ float_ = f4
+elif settings["PRECISION"].value == "f2":
+ float_ = f2
+else:
+ raise ValueError("unsupported precision")
+
+
+EPS = finfo(float_).eps
diff --git a/src/hapsira/_math/integrate.py b/src/hapsira/core/math/integrate.py
similarity index 56%
rename from src/hapsira/_math/integrate.py
rename to src/hapsira/core/math/integrate.py
index 49ad08a13..3a4565bad 100644
--- a/src/hapsira/_math/integrate.py
+++ b/src/hapsira/core/math/integrate.py
@@ -1,3 +1,5 @@
from scipy.integrate import quad
-__all__ = ["quad"]
+__all__ = [
+ "quad",
+]
diff --git a/src/hapsira/core/math/interpolate.py b/src/hapsira/core/math/interpolate.py
new file mode 100644
index 000000000..4b46a299e
--- /dev/null
+++ b/src/hapsira/core/math/interpolate.py
@@ -0,0 +1,135 @@
+from typing import Callable
+
+import numpy as np
+from scipy.interpolate import interp1d as _scipy_interp1d
+
+from ..jit import hjit
+from .linalg import add_VV_hf, div_Vs_hf, mul_Vs_hf, sub_VV_hf
+
+__all__ = [
+ "interp_hb",
+ "spline_interp",
+ "sinc_interp",
+]
+
+
+def interp_hb(x: np.ndarray, y: np.ndarray) -> Callable:
+ """
+ Builds compiled linear 1D interpolator for 3D vectors,
+ embedding x and y as const values into the binary.
+ Does not extrapolate!
+
+ Parameters
+ ----------
+ x : np.ndarray
+ Values for x
+ y : np.ndarray
+ Values for y
+
+ Returns
+ -------
+ rho : Callable
+ 1D interpolator
+
+ """
+
+ assert x.ndim == 1
+ assert y.ndim == 2
+ assert x.shape[0] >= 1 # > instead of >=
+ assert y.shape[0] == 3
+ assert y.shape[1] == x.shape[0]
+
+ y = tuple(tuple(record) for record in y.T)
+ x = tuple(x)
+ x_len = len(x)
+
+ @hjit("V(f)", cache=False)
+ def interp_hf(x_new):
+ """
+ 1D interpolator
+
+ Parameters
+ ----------
+ x_new : float
+ New value for x
+
+ Returns
+ -------
+ y_new : float
+ New value for y
+
+ """
+
+ assert x_new >= x[0]
+ assert x_new <= x[-1]
+
+ # bisect left
+ x_new_index = 0
+ hi = x_len
+ while x_new_index < hi:
+ mid = (x_new_index + hi) // 2
+ if x[mid] < x_new:
+ x_new_index = mid + 1
+ else:
+ hi = mid
+
+ # clip
+ if x_new_index > x_len:
+ x_new_index = x_len
+ if x_new_index < 1:
+ x_new_index = 1
+
+ # slope
+ lo = x_new_index - 1
+ hi = x_new_index
+ x_lo = x[lo]
+ x_hi = x[hi]
+ y_lo = y[lo] # tuple
+ y_hi = y[hi] # tuple
+ slope = div_Vs_hf(sub_VV_hf(y_hi, y_lo), x_hi - x_lo) # tuple
+
+ # new value
+ y_new = add_VV_hf(mul_Vs_hf(slope, x_new - x_lo), y_lo) # tuple
+
+ return y_new
+
+ return interp_hf
+
+
+def spline_interp(y, x, u, *, kind="cubic"):
+ """
+ Interpolates y, sampled at x instants, at u instants using `scipy.interpolate.interp1d`.
+
+ TODO compile
+
+ """
+
+ y_u = _scipy_interp1d(x, y, kind=kind)(u)
+ return y_u
+
+
+def sinc_interp(y, x, u):
+ """
+ Interpolates y, sampled at x instants, at u instants using sinc interpolation.
+
+ Notes
+ -----
+ Taken from https://gist.github.com/endolith/1297227.
+ Possibly equivalent to `scipy.signal.resample`,
+ see https://mail.python.org/pipermail/scipy-user/2012-January/031255.html.
+ However, quick experiments show different ringing behavior.
+
+ TODO compile
+
+ """
+
+ if len(y) != len(x):
+ raise ValueError("x and s must be the same length")
+
+ # Find the period and assume it's constant
+ T = x[1] - x[0]
+
+ sincM = np.tile(u, (len(x), 1)) - np.tile(x[:, np.newaxis], (1, len(u)))
+ y_u = y @ np.sinc(sincM / T)
+
+ return y_u
diff --git a/src/hapsira/core/math/ivp/__init__.py b/src/hapsira/core/math/ivp/__init__.py
new file mode 100644
index 000000000..6af349f72
--- /dev/null
+++ b/src/hapsira/core/math/ivp/__init__.py
@@ -0,0 +1,70 @@
+from ._brentq import (
+ BRENTQ_CONVERGED,
+ BRENTQ_SIGNERR,
+ BRENTQ_CONVERR,
+ BRENTQ_ERROR,
+ BRENTQ_XTOL,
+ BRENTQ_RTOL,
+ BRENTQ_MAXITER,
+ brentq_gb,
+ brentq_dense_hf,
+)
+from ._const import DENSE_SIG
+from ._rkcore import (
+ dop853_init_hf,
+ dop853_step_hf,
+ DOP853_FINISHED,
+ DOP853_FAILED,
+ DOP853_ARGK,
+ DOP853_FR,
+ DOP853_FUN,
+ DOP853_FV,
+ DOP853_H_PREVIOUS,
+ DOP853_K,
+ DOP853_RR,
+ DOP853_RR_OLD,
+ DOP853_STATUS,
+ DOP853_T,
+ DOP853_T_OLD,
+ DOP853_VV,
+ DOP853_VV_OLD,
+)
+from ._rkdenseinterp import dop853_dense_interp_brentq_hb, dop853_dense_interp_hf
+from ._rkdenseoutput import dop853_dense_output_hf
+from ._solve import event_is_active_hf, dispatcher_hb
+
+
+__all__ = [
+ "BRENTQ_CONVERGED",
+ "BRENTQ_SIGNERR",
+ "BRENTQ_CONVERR",
+ "BRENTQ_ERROR",
+ "BRENTQ_XTOL",
+ "BRENTQ_RTOL",
+ "BRENTQ_MAXITER",
+ "DENSE_SIG",
+ "DOP853_FINISHED",
+ "DOP853_FAILED",
+ "DOP853_ARGK",
+ "DOP853_FR",
+ "DOP853_FUN",
+ "DOP853_FV",
+ "DOP853_H_PREVIOUS",
+ "DOP853_K",
+ "DOP853_RR",
+ "DOP853_RR_OLD",
+ "DOP853_STATUS",
+ "DOP853_T",
+ "DOP853_T_OLD",
+ "DOP853_VV",
+ "DOP853_VV_OLD",
+ "brentq_gb",
+ "brentq_dense_hf",
+ "dispatcher_hb",
+ "dop853_dense_interp_brentq_hb",
+ "dop853_dense_interp_hf",
+ "dop853_dense_output_hf",
+ "dop853_init_hf",
+ "dop853_step_hf",
+ "event_is_active_hf",
+]
diff --git a/src/hapsira/core/math/ivp/_brentq.py b/src/hapsira/core/math/ivp/_brentq.py
new file mode 100644
index 000000000..3371c3519
--- /dev/null
+++ b/src/hapsira/core/math/ivp/_brentq.py
@@ -0,0 +1,429 @@
+from math import fabs, isnan, nan
+from typing import Callable
+
+from ._const import DENSE_SIG
+from ..ieee754 import EPS
+from ...jit import hjit, gjit
+
+
+__all__ = [
+ "BRENTQ_CONVERGED",
+ "BRENTQ_SIGNERR",
+ "BRENTQ_CONVERR",
+ "BRENTQ_ERROR",
+ "BRENTQ_XTOL",
+ "BRENTQ_RTOL",
+ "BRENTQ_MAXITER",
+ "brentq_gb",
+ "brentq_dense_hf",
+]
+
+
+BRENTQ_CONVERGED = 0
+BRENTQ_SIGNERR = -1
+BRENTQ_CONVERR = -2
+BRENTQ_ERROR = -3
+
+BRENTQ_XTOL = 2e-12
+BRENTQ_RTOL = 4 * EPS
+BRENTQ_MAXITER = 100
+
+
+@hjit("f(f,f)", inline=True)
+def _min_ss_hf(a, b):
+ """
+ The smaller of two scalars.
+ Inline by default.
+
+ Parameters
+ ----------
+ a : float
+ Scalar
+ b : float
+ Scalar
+
+ Returns
+ -------
+ c : float
+ Scalar
+
+ """
+
+ return a if a < b else b
+
+
+@hjit("b1(f)", inline=True)
+def _signbit_s_hf(a):
+ """
+ Sign bit of float.
+ Inline by default.
+
+ Parameters
+ ----------
+ a : float
+ Scalar
+
+ Returns
+ -------
+ b : boolean
+ Scalar
+
+ """
+
+ return a < 0
+
+
+@hjit("Tuple([f,i8])(F(f(f)),f,f,f,f,f)")
+def _brentq_hf(
+ func, # callback_type
+ xa, # double
+ xb, # double
+ xtol, # double
+ rtol, # double
+ maxiter, # int
+):
+ """
+ Find a root of a function in a bracketing interval using Brent's method.
+
+ Loosely adapted from
+ - https://github.com/scipy/scipy/blob/d23363809572e9a44074a3f06f66137083446b48/scipy/optimize/_zeros_py.py#L682
+ - https://github.com/scipy/scipy/blob/bd60b4ef9d886a9171345fc064c80aad6d171e73/scipy/optimize/Zeros/brentq.c#L37
+
+ Parameters
+ ----------
+ func : Callable, float of float
+ The function :math:`f` must be continuous, and :math:`f(xa)`
+ and :math:`f(xb)` must have opposite signs.
+ xa : float
+ One end of the bracketing interval :math:`[xa, xb]`.
+ xb : float
+ The other end of the bracketing interval :math:`[xa, xb]`.
+ xtol : float
+ The computed root ``x0`` will satisfy ``np.allclose(x, x0,
+ atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The
+ parameter must be positive. For nice functions, Brent's
+ method will often satisfy the above condition with ``xtol/2``
+ and ``rtol/2``.
+ rtol : float
+ The computed root ``x0`` will satisfy ``np.allclose(x, x0,
+ atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The
+ parameter cannot be smaller than its default value of
+ ``4*np.finfo(float).eps``. For nice functions, Brent's
+ method will often satisfy the above condition with ``xtol/2``
+ and ``rtol/2``.
+ maxiter : float
+ If convergence is not achieved in `maxiter` iterations, an error is
+ raised. Must be >= 0.
+
+ Returns
+ -------
+ xc : float
+ Root of `f` between `a` and `b`.
+ status : int
+ Solver status code.
+
+ """
+
+ # if not xtol + 0. > 0:
+ # return 0., BRENTQ_ERROR
+ # if not rtol + 0. >= BRENTQ_RTOL:
+ # return 0., BRENTQ_ERROR
+ # if not maxiter + 0 >= 0:
+ # return 0., BRENTQ_ERROR
+
+ xpre, xcur = xa, xb
+ xblk = 0.0
+ fpre, fcur, fblk = 0.0, 0.0, 0.0
+ spre, scur = 0.0, 0.0
+
+ fpre = func(xpre)
+ if isnan(fpre):
+ return 0.0, BRENTQ_ERROR
+
+ fcur = func(xcur)
+ if isnan(fcur):
+ return 0.0, BRENTQ_ERROR
+
+ if fpre == 0:
+ return xpre, BRENTQ_CONVERGED
+ if fcur == 0:
+ return xcur, BRENTQ_CONVERGED
+ if _signbit_s_hf(fpre) == _signbit_s_hf(fcur):
+ return 0.0, BRENTQ_SIGNERR
+
+ for _ in range(0, maxiter):
+ if fpre != 0 and fcur != 0 and _signbit_s_hf(fpre) != _signbit_s_hf(fcur):
+ xblk = xpre
+ fblk = fpre
+ scur = xcur - xpre
+ spre = scur
+ if fabs(fblk) < fabs(fcur):
+ xpre = xcur
+ xcur = xblk
+ xblk = xpre
+
+ fpre = fcur
+ fcur = fblk
+ fblk = fpre
+
+ delta = (xtol + rtol * fabs(xcur)) / 2
+ sbis = (xblk - xcur) / 2
+ if fcur == 0 or fabs(sbis) < delta:
+ return xcur, BRENTQ_CONVERGED
+
+ if fabs(spre) > delta and fabs(fcur) < fabs(fpre):
+ if xpre == xblk:
+ stry = -fcur * (xcur - xpre) / (fcur - fpre)
+ else:
+ dpre = (fpre - fcur) / (xpre - xcur)
+ dblk = (fblk - fcur) / (xblk - xcur)
+ stry = (
+ -fcur * (fblk * dblk - fpre * dpre) / (dblk * dpre * (fblk - fpre))
+ )
+ if 2 * fabs(stry) < _min_ss_hf(fabs(spre), 3 * fabs(sbis) - delta):
+ spre = scur
+ scur = stry
+ else:
+ spre = sbis
+ scur = sbis
+ else:
+ spre = sbis
+ scur = sbis
+
+ xpre = xcur
+ fpre = fcur
+ if fabs(scur) > delta:
+ xcur += scur
+ else:
+ xcur += delta if sbis > 0 else -delta
+
+ fcur = func(xcur)
+ if isnan(fcur):
+ return 0.0, BRENTQ_ERROR
+
+ return xcur, BRENTQ_CONVERR
+
+
+def brentq_gb(func: Callable) -> Callable:
+ """
+ Builds vectorized brentq
+
+ Parameters
+ ----------
+ func : Callable
+
+ Returns
+ -------
+ brentq_gf : Callable
+ """
+
+ @gjit(
+ "void(f,f,f,f,f,f[:],i8[:])",
+ "(),(),(),(),()->(),()",
+ cache=False,
+ )
+ def brentq_gf(
+ xa,
+ xb,
+ xtol,
+ rtol,
+ maxiter,
+ xcur,
+ status,
+ ):
+ """
+ Find a root of a function in a bracketing interval using Brent's method.
+
+ Loosely adapted from
+ - https://github.com/scipy/scipy/blob/d23363809572e9a44074a3f06f66137083446b48/scipy/optimize/_zeros_py.py#L682
+ - https://github.com/scipy/scipy/blob/bd60b4ef9d886a9171345fc064c80aad6d171e73/scipy/optimize/Zeros/brentq.c#L37
+
+ Parameters
+ ----------
+ xa : float
+ One end of the bracketing interval :math:`[xa, xb]`.
+ xb : float
+ The other end of the bracketing interval :math:`[xa, xb]`.
+ xtol : float
+ The computed root ``x0`` will satisfy ``np.allclose(x, x0,
+ atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The
+ parameter must be positive. For nice functions, Brent's
+ method will often satisfy the above condition with ``xtol/2``
+ and ``rtol/2``.
+ rtol : float
+ The computed root ``x0`` will satisfy ``np.allclose(x, x0,
+ atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The
+ parameter cannot be smaller than its default value of
+ ``4*np.finfo(float).eps``. For nice functions, Brent's
+ method will often satisfy the above condition with ``xtol/2``
+ and ``rtol/2``.
+ maxiter : float
+ If convergence is not achieved in `maxiter` iterations, an error is
+ raised. Must be >= 0.
+
+ Returns
+ -------
+ xc : float
+ Root of `f` between `a` and `b`.
+ status : int
+ Solver status code.
+
+ """
+
+ xcur[0], status[0] = _brentq_hf(func, xa, xb, xtol, rtol, maxiter)
+
+ return brentq_gf
+
+
+@hjit(f"Tuple([f,f,i8])(F(f(i8,f,{DENSE_SIG:s},f)),i8,f,f,f,f,f,{DENSE_SIG:s},f)")
+def brentq_dense_hf(
+ func, # callback_type
+ idx,
+ xa, # double
+ xb, # double
+ xtol, # double
+ rtol, # double
+ maxiter, # int
+ sol1,
+ sol2,
+ sol3,
+ sol4,
+ sol5,
+ argk,
+):
+ """
+ Find a root of a function in a bracketing interval using Brent's method.
+ Virtually identical to `_brentq_hf`, except that it passes extra arguments
+ through to `func`. Architecturally required due to limiations of `numba`.
+
+ Loosely adapted from
+ - https://github.com/scipy/scipy/blob/d23363809572e9a44074a3f06f66137083446b48/scipy/optimize/_zeros_py.py#L682
+ - https://github.com/scipy/scipy/blob/bd60b4ef9d886a9171345fc064c80aad6d171e73/scipy/optimize/Zeros/brentq.c#L37
+
+ Parameters
+ ----------
+ func : Callable, float of float
+ The function :math:`f` must be continuous, and :math:`f(xa)`
+ and :math:`f(xb)` must have opposite signs.
+ idx : int
+ Selects function in dispatcher, passed to `func`.
+ xa : float
+ One end of the bracketing interval :math:`[xa, xb]`.
+ xb : float
+ The other end of the bracketing interval :math:`[xa, xb]`.
+ xtol : float
+ The computed root ``x0`` will satisfy ``np.allclose(x, x0,
+ atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The
+ parameter must be positive. For nice functions, Brent's
+ method will often satisfy the above condition with ``xtol/2``
+ and ``rtol/2``.
+ rtol : float
+ The computed root ``x0`` will satisfy ``np.allclose(x, x0,
+ atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The
+ parameter cannot be smaller than its default value of
+ ``4*np.finfo(float).eps``. For nice functions, Brent's
+ method will often satisfy the above condition with ``xtol/2``
+ and ``rtol/2``.
+ maxiter : float
+ If convergence is not achieved in `maxiter` iterations, an error is
+ raised. Must be >= 0.
+ sol1 : any
+ Passed to `func`.
+ sol2 : any
+ Passed to `func`.
+ sol3 : any
+ Passed to `func`.
+ sol4 : any
+ Passed to `func`.
+ sol5 : any
+ Passed to `func`.
+ argk : float
+ Passed to `func`.
+
+ Returns
+ -------
+ xc : float
+ Root of `f` between `a` and `b`.
+ status : int
+ Solver status code.
+
+ """
+
+ if not xtol + 0.0 > 0:
+ return nan, 0.0, BRENTQ_ERROR
+ if not rtol + 0.0 >= BRENTQ_RTOL:
+ return nan, 0.0, BRENTQ_ERROR
+ if not maxiter + 0 >= 0:
+ return nan, 0.0, BRENTQ_ERROR
+
+ xpre, xcur = xa, xb
+ xblk = 0.0
+ fpre, fcur, fblk = 0.0, 0.0, 0.0
+ spre, scur = 0.0, 0.0
+
+ fpre = func(idx, xpre, sol1, sol2, sol3, sol4, sol5, argk)
+ if isnan(fpre):
+ return xpre, 0.0, BRENTQ_ERROR
+
+ fcur = func(idx, xcur, sol1, sol2, sol3, sol4, sol5, argk)
+ if isnan(fcur):
+ return xcur, 0.0, BRENTQ_ERROR
+
+ if fpre == 0:
+ return xcur, xpre, BRENTQ_CONVERGED
+ if fcur == 0:
+ return xcur, xcur, BRENTQ_CONVERGED
+ if _signbit_s_hf(fpre) == _signbit_s_hf(fcur):
+ return xcur, 0.0, BRENTQ_SIGNERR
+
+ for _ in range(0, maxiter):
+ if fpre != 0 and fcur != 0 and _signbit_s_hf(fpre) != _signbit_s_hf(fcur):
+ xblk = xpre
+ fblk = fpre
+ scur = xcur - xpre
+ spre = scur
+ if fabs(fblk) < fabs(fcur):
+ xpre = xcur
+ xcur = xblk
+ xblk = xpre
+
+ fpre = fcur
+ fcur = fblk
+ fblk = fpre
+
+ delta = (xtol + rtol * fabs(xcur)) / 2
+ sbis = (xblk - xcur) / 2
+ if fcur == 0 or fabs(sbis) < delta:
+ return xcur, xcur, BRENTQ_CONVERGED
+
+ if fabs(spre) > delta and fabs(fcur) < fabs(fpre):
+ if xpre == xblk:
+ stry = -fcur * (xcur - xpre) / (fcur - fpre)
+ else:
+ dpre = (fpre - fcur) / (xpre - xcur)
+ dblk = (fblk - fcur) / (xblk - xcur)
+ stry = (
+ -fcur * (fblk * dblk - fpre * dpre) / (dblk * dpre * (fblk - fpre))
+ )
+ if 2 * fabs(stry) < _min_ss_hf(fabs(spre), 3 * fabs(sbis) - delta):
+ spre = scur
+ scur = stry
+ else:
+ spre = sbis
+ scur = sbis
+ else:
+ spre = sbis
+ scur = sbis
+
+ xpre = xcur
+ fpre = fcur
+ if fabs(scur) > delta:
+ xcur += scur
+ else:
+ xcur += delta if sbis > 0 else -delta
+
+ fcur = func(idx, xcur, sol1, sol2, sol3, sol4, sol5, argk)
+ if isnan(fcur):
+ return xcur, 0.0, BRENTQ_ERROR
+
+ return xcur, xcur, BRENTQ_CONVERR
diff --git a/src/hapsira/core/math/ivp/_const.py b/src/hapsira/core/math/ivp/_const.py
new file mode 100644
index 000000000..ff276ba04
--- /dev/null
+++ b/src/hapsira/core/math/ivp/_const.py
@@ -0,0 +1,39 @@
+__all__ = [
+ "N_RV",
+ "N_STAGES",
+ "SAFETY",
+ "MIN_FACTOR",
+ "MAX_FACTOR",
+ "INTERPOLATOR_POWER",
+ "N_STAGES_EXTENDED",
+ "ERROR_ESTIMATOR_ORDER",
+ "ERROR_EXPONENT",
+ "KSIG",
+]
+
+N_RV = 6
+N_STAGES = 12
+N_STAGES_EXTENDED = 16
+
+SAFETY = 0.9 # Multiply steps computed from asymptotic behaviour of errors by this.
+
+MIN_FACTOR = 0.2 # Minimum allowed decrease in a step size.
+MAX_FACTOR = 10 # Maximum allowed increase in a step size.
+
+INTERPOLATOR_POWER = 7
+ERROR_ESTIMATOR_ORDER = 7
+ERROR_EXPONENT = -1 / (ERROR_ESTIMATOR_ORDER + 1)
+
+KSIG = (
+ "Tuple(["
+ + ",".join(["Tuple([" + ",".join(["f"] * N_RV) + "])"] * (N_STAGES + 1))
+ + "])"
+)
+
+FSIG = (
+ "Tuple(["
+ + ",".join(["Tuple([" + ",".join(["f"] * N_RV) + "])"] * INTERPOLATOR_POWER)
+ + "])"
+)
+
+DENSE_SIG = f"f,f,V,V,{FSIG:s}"
diff --git a/src/hapsira/core/math/ivp/_dop853_coefficients.py b/src/hapsira/core/math/ivp/_dop853_coefficients.py
new file mode 100644
index 000000000..fec0bea2c
--- /dev/null
+++ b/src/hapsira/core/math/ivp/_dop853_coefficients.py
@@ -0,0 +1,207 @@
+import numpy as np
+
+from ._const import N_STAGES, N_STAGES_EXTENDED, INTERPOLATOR_POWER
+
+__all__ = [
+ "A",
+ "B",
+ "C",
+ "D",
+ "E3",
+ "E5",
+]
+
+# Based on
+# https://github.com/scipy/scipy/blob/v1.12.0/scipy/integrate/_ivp/dop853_coefficients.py
+
+C = np.array(
+ [
+ 0.0,
+ 0.526001519587677318785587544488e-01,
+ 0.789002279381515978178381316732e-01,
+ 0.118350341907227396726757197510,
+ 0.281649658092772603273242802490,
+ 0.333333333333333333333333333333,
+ 0.25,
+ 0.307692307692307692307692307692,
+ 0.651282051282051282051282051282,
+ 0.6,
+ 0.857142857142857142857142857142,
+ 1.0,
+ 1.0,
+ 0.1,
+ 0.2,
+ 0.777777777777777777777777777778,
+ ]
+)
+
+A = np.zeros((N_STAGES_EXTENDED, N_STAGES_EXTENDED))
+A[1, 0] = 5.26001519587677318785587544488e-2
+
+A[2, 0] = 1.97250569845378994544595329183e-2
+A[2, 1] = 5.91751709536136983633785987549e-2
+
+A[3, 0] = 2.95875854768068491816892993775e-2
+A[3, 2] = 8.87627564304205475450678981324e-2
+
+A[4, 0] = 2.41365134159266685502369798665e-1
+A[4, 2] = -8.84549479328286085344864962717e-1
+A[4, 3] = 9.24834003261792003115737966543e-1
+
+A[5, 0] = 3.7037037037037037037037037037e-2
+A[5, 3] = 1.70828608729473871279604482173e-1
+A[5, 4] = 1.25467687566822425016691814123e-1
+
+A[6, 0] = 3.7109375e-2
+A[6, 3] = 1.70252211019544039314978060272e-1
+A[6, 4] = 6.02165389804559606850219397283e-2
+A[6, 5] = -1.7578125e-2
+
+A[7, 0] = 3.70920001185047927108779319836e-2
+A[7, 3] = 1.70383925712239993810214054705e-1
+A[7, 4] = 1.07262030446373284651809199168e-1
+A[7, 5] = -1.53194377486244017527936158236e-2
+A[7, 6] = 8.27378916381402288758473766002e-3
+
+A[8, 0] = 6.24110958716075717114429577812e-1
+A[8, 3] = -3.36089262944694129406857109825
+A[8, 4] = -8.68219346841726006818189891453e-1
+A[8, 5] = 2.75920996994467083049415600797e1
+A[8, 6] = 2.01540675504778934086186788979e1
+A[8, 7] = -4.34898841810699588477366255144e1
+
+A[9, 0] = 4.77662536438264365890433908527e-1
+A[9, 3] = -2.48811461997166764192642586468
+A[9, 4] = -5.90290826836842996371446475743e-1
+A[9, 5] = 2.12300514481811942347288949897e1
+A[9, 6] = 1.52792336328824235832596922938e1
+A[9, 7] = -3.32882109689848629194453265587e1
+A[9, 8] = -2.03312017085086261358222928593e-2
+
+A[10, 0] = -9.3714243008598732571704021658e-1
+A[10, 3] = 5.18637242884406370830023853209
+A[10, 4] = 1.09143734899672957818500254654
+A[10, 5] = -8.14978701074692612513997267357
+A[10, 6] = -1.85200656599969598641566180701e1
+A[10, 7] = 2.27394870993505042818970056734e1
+A[10, 8] = 2.49360555267965238987089396762
+A[10, 9] = -3.0467644718982195003823669022
+
+A[11, 0] = 2.27331014751653820792359768449
+A[11, 3] = -1.05344954667372501984066689879e1
+A[11, 4] = -2.00087205822486249909675718444
+A[11, 5] = -1.79589318631187989172765950534e1
+A[11, 6] = 2.79488845294199600508499808837e1
+A[11, 7] = -2.85899827713502369474065508674
+A[11, 8] = -8.87285693353062954433549289258
+A[11, 9] = 1.23605671757943030647266201528e1
+A[11, 10] = 6.43392746015763530355970484046e-1
+
+A[12, 0] = 5.42937341165687622380535766363e-2
+A[12, 5] = 4.45031289275240888144113950566
+A[12, 6] = 1.89151789931450038304281599044
+A[12, 7] = -5.8012039600105847814672114227
+A[12, 8] = 3.1116436695781989440891606237e-1
+A[12, 9] = -1.52160949662516078556178806805e-1
+A[12, 10] = 2.01365400804030348374776537501e-1
+A[12, 11] = 4.47106157277725905176885569043e-2
+
+A[13, 0] = 5.61675022830479523392909219681e-2
+A[13, 6] = 2.53500210216624811088794765333e-1
+A[13, 7] = -2.46239037470802489917441475441e-1
+A[13, 8] = -1.24191423263816360469010140626e-1
+A[13, 9] = 1.5329179827876569731206322685e-1
+A[13, 10] = 8.20105229563468988491666602057e-3
+A[13, 11] = 7.56789766054569976138603589584e-3
+A[13, 12] = -8.298e-3
+
+A[14, 0] = 3.18346481635021405060768473261e-2
+A[14, 5] = 2.83009096723667755288322961402e-2
+A[14, 6] = 5.35419883074385676223797384372e-2
+A[14, 7] = -5.49237485713909884646569340306e-2
+A[14, 10] = -1.08347328697249322858509316994e-4
+A[14, 11] = 3.82571090835658412954920192323e-4
+A[14, 12] = -3.40465008687404560802977114492e-4
+A[14, 13] = 1.41312443674632500278074618366e-1
+
+A[15, 0] = -4.28896301583791923408573538692e-1
+A[15, 5] = -4.69762141536116384314449447206
+A[15, 6] = 7.68342119606259904184240953878
+A[15, 7] = 4.06898981839711007970213554331
+A[15, 8] = 3.56727187455281109270669543021e-1
+A[15, 12] = -1.39902416515901462129418009734e-3
+A[15, 13] = 2.9475147891527723389556272149
+A[15, 14] = -9.15095847217987001081870187138
+
+
+B = A[N_STAGES, :N_STAGES]
+
+E3 = np.zeros(N_STAGES + 1)
+E3[:-1] = B.copy()
+E3[0] -= 0.244094488188976377952755905512
+E3[8] -= 0.733846688281611857341361741547
+E3[11] -= 0.220588235294117647058823529412e-1
+
+E5 = np.zeros(N_STAGES + 1)
+E5[0] = 0.1312004499419488073250102996e-1
+E5[5] = -0.1225156446376204440720569753e1
+E5[6] = -0.4957589496572501915214079952
+E5[7] = 0.1664377182454986536961530415e1
+E5[8] = -0.3503288487499736816886487290
+E5[9] = 0.3341791187130174790297318841
+E5[10] = 0.8192320648511571246570742613e-1
+E5[11] = -0.2235530786388629525884427845e-1
+
+# First 3 coefficients are computed separately.
+D = np.zeros((INTERPOLATOR_POWER - 3, N_STAGES_EXTENDED))
+D[0, 0] = -0.84289382761090128651353491142e1
+D[0, 5] = 0.56671495351937776962531783590
+D[0, 6] = -0.30689499459498916912797304727e1
+D[0, 7] = 0.23846676565120698287728149680e1
+D[0, 8] = 0.21170345824450282767155149946e1
+D[0, 9] = -0.87139158377797299206789907490
+D[0, 10] = 0.22404374302607882758541771650e1
+D[0, 11] = 0.63157877876946881815570249290
+D[0, 12] = -0.88990336451333310820698117400e-1
+D[0, 13] = 0.18148505520854727256656404962e2
+D[0, 14] = -0.91946323924783554000451984436e1
+D[0, 15] = -0.44360363875948939664310572000e1
+
+D[1, 0] = 0.10427508642579134603413151009e2
+D[1, 5] = 0.24228349177525818288430175319e3
+D[1, 6] = 0.16520045171727028198505394887e3
+D[1, 7] = -0.37454675472269020279518312152e3
+D[1, 8] = -0.22113666853125306036270938578e2
+D[1, 9] = 0.77334326684722638389603898808e1
+D[1, 10] = -0.30674084731089398182061213626e2
+D[1, 11] = -0.93321305264302278729567221706e1
+D[1, 12] = 0.15697238121770843886131091075e2
+D[1, 13] = -0.31139403219565177677282850411e2
+D[1, 14] = -0.93529243588444783865713862664e1
+D[1, 15] = 0.35816841486394083752465898540e2
+
+D[2, 0] = 0.19985053242002433820987653617e2
+D[2, 5] = -0.38703730874935176555105901742e3
+D[2, 6] = -0.18917813819516756882830838328e3
+D[2, 7] = 0.52780815920542364900561016686e3
+D[2, 8] = -0.11573902539959630126141871134e2
+D[2, 9] = 0.68812326946963000169666922661e1
+D[2, 10] = -0.10006050966910838403183860980e1
+D[2, 11] = 0.77771377980534432092869265740
+D[2, 12] = -0.27782057523535084065932004339e1
+D[2, 13] = -0.60196695231264120758267380846e2
+D[2, 14] = 0.84320405506677161018159903784e2
+D[2, 15] = 0.11992291136182789328035130030e2
+
+D[3, 0] = -0.25693933462703749003312586129e2
+D[3, 5] = -0.15418974869023643374053993627e3
+D[3, 6] = -0.23152937917604549567536039109e3
+D[3, 7] = 0.35763911791061412378285349910e3
+D[3, 8] = 0.93405324183624310003907691704e2
+D[3, 9] = -0.37458323136451633156875139351e2
+D[3, 10] = 0.10409964950896230045147246184e3
+D[3, 11] = 0.29840293426660503123344363579e2
+D[3, 12] = -0.43533456590011143754432175058e2
+D[3, 13] = 0.96324553959188282948394950600e2
+D[3, 14] = -0.39177261675615439165231486172e2
+D[3, 15] = -0.14972683625798562581422125276e3
diff --git a/src/hapsira/core/math/ivp/_rkcore.py b/src/hapsira/core/math/ivp/_rkcore.py
new file mode 100644
index 000000000..c70667348
--- /dev/null
+++ b/src/hapsira/core/math/ivp/_rkcore.py
@@ -0,0 +1,392 @@
+from math import nan
+
+from ._const import ERROR_ESTIMATOR_ORDER, KSIG
+from ._rkstepinit import select_initial_step_hf
+from ._rkstepimpl import step_impl_hf
+from ..ieee754 import EPS
+from ..linalg import sign_hf
+from ...jit import hjit, DSIG
+
+__all__ = [
+ "dop853_init_hf",
+ "dop853_step_hf",
+ "DOP853_RUNNING",
+ "DOP853_FINISHED",
+ "DOP853_FAILED",
+ "DOP853_ARGK",
+ "DOP853_FR",
+ "DOP853_FUN",
+ "DOP853_FV",
+ "DOP853_H_PREVIOUS",
+ "DOP853_K",
+ "DOP853_RR",
+ "DOP853_RR_OLD",
+ "DOP853_STATUS",
+ "DOP853_T",
+ "DOP853_T_OLD",
+ "DOP853_VV",
+ "DOP853_VV_OLD",
+]
+
+
+DOP853_RUNNING = 0
+DOP853_FINISHED = 1
+DOP853_FAILED = 2
+
+DOP853_ARGK = 5
+DOP853_FR = 15
+DOP853_FUN = 4
+DOP853_FV = 16
+DOP853_H_PREVIOUS = 13
+DOP853_K = 9
+DOP853_RR = 1
+DOP853_RR_OLD = 10
+DOP853_STATUS = 14
+DOP853_T = 0
+DOP853_T_OLD = 12
+DOP853_VV = 2
+DOP853_VV_OLD = 11
+
+DOP853_SIG = f"f,V,V,f,F({DSIG}),f,f,f,f,{KSIG:s},V,V,f,f,f,V,V,f"
+
+
+@hjit(f"Tuple([{DOP853_SIG:s}])(F({DSIG}),f,V,V,f,f,f,f)")
+def dop853_init_hf(fun, t0, rr, vv, t_bound, argk, rtol, atol):
+ """
+ Explicit Runge-Kutta method of order 8.
+ Functional re-write of constructor of class `DOP853` within `scipy.integrate`.
+
+ Based on
+ - https://github.com/scipy/scipy/blob/4edfcaa3ce8a387450b6efce968572def71be089/scipy/integrate/_ivp/rk.py#L502
+ - https://github.com/scipy/scipy/blob/4edfcaa3ce8a387450b6efce968572def71be089/scipy/integrate/_ivp/rk.py#L85
+ - https://github.com/scipy/scipy/blob/4edfcaa3ce8a387450b6efce968572def71be089/scipy/integrate/_ivp/base.py#L131
+
+ Parameters
+ ----------
+ fun : float
+ Right-hand side of the system.
+ t0 : float
+ Initial time.
+ rr : float
+ Initial state 0:3
+ vv : float
+ Initial state 3:6
+ t_bound : float
+ Boundary time
+ argk : float
+ Standard gravitational parameter for `fun`
+ rtol : float
+ Relative tolerance
+ atol : float
+ Absolute tolerance
+
+ Returns
+ -------
+ t0 : float
+ Initial time.
+ rr : tuple[float,float,float]
+ Initial state 0:3
+ vv : tuple[float,float,float]
+ Initial state 3:6
+ t_bound : float
+ Boundary time
+ fun : Callable
+ Right-hand side of the system
+ argk : float
+ Standard gravitational parameter for `fun`
+ rtol : float
+ Relative tolerance
+ atol : float
+ Absolute tolerance
+ direction : float
+ Integration direction
+ K : tuple[[float,...],...]
+ Storage array for RK stages
+ rr_old : tuple[float,float,float]
+ Last state 0:3
+ vv_old : tuple[float,float,float]
+ Last state 3:6
+ t_old : float
+ Last time
+ h_previous : float
+ Last step length
+ status : float
+ Solver status
+ fr : tuple[float,float,float]
+ Current value of the derivative 0:3
+ fv : tuple[float,float,float]
+ Current value of the derivative 3:6
+ h_abs : float
+ Absolute step
+
+ """
+
+ assert atol >= 0
+
+ if rtol < 100 * EPS:
+ rtol = 100 * EPS
+
+ direction = sign_hf(t_bound - t0) if t_bound != t0 else 1
+
+ K = (
+ (0.0, 0.0, 0.0, 0.0, 0.0, 0.0), # 0
+ (0.0, 0.0, 0.0, 0.0, 0.0, 0.0), # 1
+ (0.0, 0.0, 0.0, 0.0, 0.0, 0.0), # 2
+ (0.0, 0.0, 0.0, 0.0, 0.0, 0.0), # 3
+ (0.0, 0.0, 0.0, 0.0, 0.0, 0.0), # 4
+ (0.0, 0.0, 0.0, 0.0, 0.0, 0.0), # 5
+ (0.0, 0.0, 0.0, 0.0, 0.0, 0.0), # 6
+ (0.0, 0.0, 0.0, 0.0, 0.0, 0.0), # 7
+ (0.0, 0.0, 0.0, 0.0, 0.0, 0.0), # 8
+ (0.0, 0.0, 0.0, 0.0, 0.0, 0.0), # 9
+ (0.0, 0.0, 0.0, 0.0, 0.0, 0.0), # 10
+ (0.0, 0.0, 0.0, 0.0, 0.0, 0.0), # 11
+ (0.0, 0.0, 0.0, 0.0, 0.0, 0.0), # 12
+ )
+
+ rr_old = (nan, nan, nan)
+ vv_old = (nan, nan, nan)
+ t_old = nan
+ h_previous = nan
+
+ status = DOP853_RUNNING
+
+ fr, fv = fun(
+ t0,
+ rr,
+ vv,
+ argk,
+ )
+
+ h_abs = select_initial_step_hf(
+ fun,
+ t0,
+ rr,
+ vv,
+ argk,
+ fr,
+ fv,
+ direction,
+ ERROR_ESTIMATOR_ORDER,
+ rtol,
+ atol,
+ )
+
+ return (
+ t0, # 0 -> t
+ rr, # 1
+ vv, # 2
+ t_bound, # 3
+ fun, # 4
+ argk, # 5
+ rtol, # 6
+ atol, # 7
+ direction, # 8
+ K, # 9
+ rr_old, # 10
+ vv_old, # 11
+ t_old, # 12
+ h_previous, # 13
+ status, # 14
+ fr, # 15
+ fv, # 16
+ h_abs, # 17
+ )
+
+
+@hjit(f"Tuple([{DOP853_SIG:s}])({DOP853_SIG:s})")
+def dop853_step_hf(
+ t,
+ rr,
+ vv,
+ t_bound,
+ fun,
+ argk,
+ rtol,
+ atol,
+ direction,
+ K,
+ rr_old,
+ vv_old,
+ t_old,
+ h_previous,
+ status,
+ fr,
+ fv,
+ h_abs,
+):
+ """
+ Perform one integration step.
+ Functional re-write of method `step` of class `OdeSolver` within `scipy.integrate`.
+
+ Based on
+ https://github.com/scipy/scipy/blob/4edfcaa3ce8a387450b6efce968572def71be089/scipy/integrate/_ivp/base.py#L175
+
+ Parameters
+ ----------
+ t : float
+ Current time.
+ rr : tuple[float,float,float]
+ Current state 0:3
+ vv : tuple[float,float,float]
+ Current state 3:6
+ t_bound : float
+ Boundary time
+ fun : Callable
+ Right-hand side of the system
+ argk : float
+ Standard gravitational parameter for `fun`
+ rtol : float
+ Relative tolerance
+ atol : float
+ Absolute tolerance
+ direction : float
+ Integration direction
+ K : tuple[[float,...],...]
+ Storage array for RK stages
+ rr_old : tuple[float,float,float]
+ Last state 0:3
+ vv_old : tuple[float,float,float]
+ Last state 3:6
+ t_old : float
+ Last time
+ h_previous : float
+ Last step length
+ status : float
+ Solver status
+ fr : tuple[float,float,float]
+ Current value of the derivative 0:3
+ fv : tuple[float,float,float]
+ Current value of the derivative 3:6
+ h_abs : float
+ Absolute step
+
+ Returns
+ -------
+ t : float
+ Current time.
+ rr : tuple[float,float,float]
+ Current state 0:3
+ vv : tuple[float,float,float]
+ Current state 3:6
+ t_bound : float
+ Boundary time
+ fun : Callable
+ Right-hand side of the system
+ argk : float
+ Standard gravitational parameter for `fun`
+ rtol : float
+ Relative tolerance
+ atol : float
+ Absolute tolerance
+ direction : float
+ Integration direction
+ K : tuple[[float,...],...]
+ Storage array for RK stages
+ rr_old : tuple[float,float,float]
+ Last state 0:3
+ vv_old : tuple[float,float,float]
+ Last state 3:6
+ t_old : float
+ Last time
+ h_previous : float
+ Last step length
+ status : float
+ Solver status
+ fr : tuple[float,float,float]
+ Current value of the derivative 0:3
+ fv : tuple[float,float,float]
+ Current value of the derivative 3:6
+ h_abs : float
+ Absolute step
+
+ """
+
+ if status != DOP853_RUNNING:
+ raise RuntimeError("Attempt to step on a failed or finished " "solver.")
+
+ if t == t_bound:
+ # Handle corner cases of empty solver or no integration.
+ t_old = t
+ t = t_bound
+ status = DOP853_FINISHED
+ return (
+ t,
+ rr,
+ vv,
+ t_bound,
+ fun,
+ argk,
+ rtol,
+ atol,
+ direction,
+ K,
+ rr_old,
+ vv_old,
+ t_old,
+ h_previous,
+ status,
+ fr,
+ fv,
+ h_abs,
+ )
+
+ t_tmp = t
+ rets = step_impl_hf(
+ fun,
+ argk,
+ t,
+ rr,
+ vv,
+ fr,
+ fv,
+ rtol,
+ atol,
+ direction,
+ h_abs,
+ t_bound,
+ K,
+ )
+ success = rets[0]
+
+ if success:
+ rr_old = rr
+ vv_old = vv
+ (
+ h_previous,
+ t,
+ rr,
+ vv,
+ h_abs,
+ fr,
+ fv,
+ K,
+ ) = rets[1:]
+
+ if not success:
+ status = DOP853_FAILED
+ else:
+ t_old = t_tmp
+ if not direction * (t - t_bound) < 0:
+ status = DOP853_FINISHED
+
+ return (
+ t,
+ rr,
+ vv,
+ t_bound,
+ fun,
+ argk,
+ rtol,
+ atol,
+ direction,
+ K,
+ rr_old,
+ vv_old,
+ t_old,
+ h_previous,
+ status,
+ fr,
+ fv,
+ h_abs,
+ )
diff --git a/src/hapsira/core/math/ivp/_rkdenseinterp.py b/src/hapsira/core/math/ivp/_rkdenseinterp.py
new file mode 100644
index 000000000..6b675b686
--- /dev/null
+++ b/src/hapsira/core/math/ivp/_rkdenseinterp.py
@@ -0,0 +1,93 @@
+from ._const import DENSE_SIG
+from ..linalg import add_VV_hf, mul_Vs_hf
+from ...jit import hjit
+
+
+__all__ = [
+ "dop853_dense_interp_brentq_hb",
+ "dop853_dense_interp_hf",
+]
+
+
+@hjit(f"Tuple([V,V])(f,{DENSE_SIG:s})")
+def dop853_dense_interp_hf(t, t_old, h, rr_old, vv_old, F):
+ """
+ Local interpolant over step made by an ODE solver.
+ Evaluate the interpolant.
+
+ Based on
+ https://github.com/scipy/scipy/blob/4edfcaa3ce8a387450b6efce968572def71be089/scipy/integrate/_ivp/rk.py#L584
+
+ Parameters
+ ----------
+ t : float
+ Current time.
+ t_old : float
+ Previous time.
+ h : float
+ Step to use.
+ rr_rold : tuple[float,float,float]
+ Last values 0:3.
+ vv_vold : tuple[float,float,float]
+ Last values 3:6.
+ F : tuple[tuple[float,...]...]
+ Dense output coefficients.
+
+ Returns
+ -------
+ rr : tuple[float,float,float]
+ Computed values 0:3.
+ vv : tuple[float,float,float]
+ Computed values 3:6.
+ """
+
+ F00, F01, F02, F03, F04, F05, F06 = F
+
+ x = (t - t_old) / h
+
+ rr_new = mul_Vs_hf(F06[:3], x)
+ vv_new = mul_Vs_hf(F06[3:], x)
+
+ rr_new = add_VV_hf(rr_new, F05[:3])
+ vv_new = add_VV_hf(vv_new, F05[3:])
+ rr_new = mul_Vs_hf(rr_new, 1 - x)
+ vv_new = mul_Vs_hf(vv_new, 1 - x)
+
+ rr_new = add_VV_hf(rr_new, F04[:3])
+ vv_new = add_VV_hf(vv_new, F04[3:])
+ rr_new = mul_Vs_hf(rr_new, x)
+ vv_new = mul_Vs_hf(vv_new, x)
+
+ rr_new = add_VV_hf(rr_new, F03[:3])
+ vv_new = add_VV_hf(vv_new, F03[3:])
+ rr_new = mul_Vs_hf(rr_new, 1 - x)
+ vv_new = mul_Vs_hf(vv_new, 1 - x)
+
+ rr_new = add_VV_hf(rr_new, F02[:3])
+ vv_new = add_VV_hf(vv_new, F02[3:])
+ rr_new = mul_Vs_hf(rr_new, x)
+ vv_new = mul_Vs_hf(vv_new, x)
+
+ rr_new = add_VV_hf(rr_new, F01[:3])
+ vv_new = add_VV_hf(vv_new, F01[3:])
+ rr_new = mul_Vs_hf(rr_new, 1 - x)
+ vv_new = mul_Vs_hf(vv_new, 1 - x)
+
+ rr_new = add_VV_hf(rr_new, F00[:3])
+ vv_new = add_VV_hf(vv_new, F00[3:])
+ rr_new = mul_Vs_hf(rr_new, x)
+ vv_new = mul_Vs_hf(vv_new, x)
+
+ rr_new = add_VV_hf(rr_new, rr_old)
+ vv_new = add_VV_hf(vv_new, vv_old)
+
+ return rr_new, vv_new
+
+
+def dop853_dense_interp_brentq_hb(func):
+ @hjit(f"f(f,{DENSE_SIG:s},f)", cache=False)
+ def event_wrapper(t, t_old, h, rr_old, vv_old, F, argk):
+ rr, vv = dop853_dense_interp_hf(t, t_old, h, rr_old, vv_old, F)
+ return func(t, rr, vv, argk)
+
+ return event_wrapper
diff --git a/src/hapsira/core/math/ivp/_rkdenseoutput.py b/src/hapsira/core/math/ivp/_rkdenseoutput.py
new file mode 100644
index 000000000..c5ee39ef7
--- /dev/null
+++ b/src/hapsira/core/math/ivp/_rkdenseoutput.py
@@ -0,0 +1,877 @@
+from ._const import FSIG, KSIG, N_STAGES
+from ._dop853_coefficients import A as _A, C as _C, D as _D
+from ..ieee754 import float_
+from ..linalg import (
+ add_VV_hf,
+ mul_Vs_hf,
+ sub_VV_hf,
+)
+from ...jit import hjit, DSIG
+
+__all__ = [
+ "dop853_dense_output_hf",
+]
+
+
+A00 = tuple(float_(number) for number in _A[N_STAGES + 1, :13])
+A01 = tuple(float_(number) for number in _A[N_STAGES + 2, :14])
+A02 = tuple(float_(number) for number in _A[N_STAGES + 3, :15])
+C_EXTRA = tuple(float_(number) for number in _C[N_STAGES + 1 :])
+D00 = tuple(float_(number) for number in _D[0, :])
+D01 = tuple(float_(number) for number in _D[1, :])
+D02 = tuple(float_(number) for number in _D[2, :])
+D03 = tuple(float_(number) for number in _D[3, :])
+
+
+@hjit(f"Tuple([f,f,V,V,{FSIG:s}])(F({DSIG:s}),f,f,f,f,V,V,V,V,V,V,{KSIG:s})")
+def dop853_dense_output_hf(fun, argk, t_old, t, h, rr, vv, rr_old, vv_old, fr, fv, K):
+ """Compute a local interpolant over the last successful step.
+
+ Returns
+ -------
+ sol : `DenseOutput`
+ Local interpolant over the last successful step.
+ """
+
+ assert t_old is not None
+ assert t != t_old
+
+ K00, K01, K02, K03, K04, K05, K06, K07, K08, K09, K10, K11, K12 = K
+
+ dr = (
+ (
+ K00[0] * A00[0]
+ + K01[0] * A00[1]
+ + K02[0] * A00[2]
+ + K03[0] * A00[3]
+ + K04[0] * A00[4]
+ + K05[0] * A00[5]
+ + K06[0] * A00[6]
+ + K07[0] * A00[7]
+ + K08[0] * A00[8]
+ + K09[0] * A00[9]
+ + K10[0] * A00[10]
+ + K11[0] * A00[11]
+ + K12[0] * A00[12]
+ )
+ * h,
+ (
+ K00[1] * A00[0]
+ + K01[1] * A00[1]
+ + K02[1] * A00[2]
+ + K03[1] * A00[3]
+ + K04[1] * A00[4]
+ + K05[1] * A00[5]
+ + K06[1] * A00[6]
+ + K07[1] * A00[7]
+ + K08[1] * A00[8]
+ + K09[1] * A00[9]
+ + K10[1] * A00[10]
+ + K11[1] * A00[11]
+ + K12[1] * A00[12]
+ )
+ * h,
+ (
+ K00[2] * A00[0]
+ + K01[2] * A00[1]
+ + K02[2] * A00[2]
+ + K03[2] * A00[3]
+ + K04[2] * A00[4]
+ + K05[2] * A00[5]
+ + K06[2] * A00[6]
+ + K07[2] * A00[7]
+ + K08[2] * A00[8]
+ + K09[2] * A00[9]
+ + K10[2] * A00[10]
+ + K11[2] * A00[11]
+ + K12[2] * A00[12]
+ )
+ * h,
+ )
+ dv = (
+ (
+ K00[3] * A00[0]
+ + K01[3] * A00[1]
+ + K02[3] * A00[2]
+ + K03[3] * A00[3]
+ + K04[3] * A00[4]
+ + K05[3] * A00[5]
+ + K06[3] * A00[6]
+ + K07[3] * A00[7]
+ + K08[3] * A00[8]
+ + K09[3] * A00[9]
+ + K10[3] * A00[10]
+ + K11[3] * A00[11]
+ + K12[3] * A00[12]
+ )
+ * h,
+ (
+ K00[4] * A00[0]
+ + K01[4] * A00[1]
+ + K02[4] * A00[2]
+ + K03[4] * A00[3]
+ + K04[4] * A00[4]
+ + K05[4] * A00[5]
+ + K06[4] * A00[6]
+ + K07[4] * A00[7]
+ + K08[4] * A00[8]
+ + K09[4] * A00[9]
+ + K10[4] * A00[10]
+ + K11[4] * A00[11]
+ + K12[4] * A00[12]
+ )
+ * h,
+ (
+ K00[5] * A00[0]
+ + K01[5] * A00[1]
+ + K02[5] * A00[2]
+ + K03[5] * A00[3]
+ + K04[5] * A00[4]
+ + K05[5] * A00[5]
+ + K06[5] * A00[6]
+ + K07[5] * A00[7]
+ + K08[5] * A00[8]
+ + K09[5] * A00[9]
+ + K10[5] * A00[10]
+ + K11[5] * A00[11]
+ + K12[5] * A00[12]
+ )
+ * h,
+ )
+ rr_ = add_VV_hf(rr_old, dr)
+ vv_ = add_VV_hf(vv_old, dv)
+ rr_, vv_ = fun(
+ t_old + C_EXTRA[0] * h,
+ rr_,
+ vv_,
+ argk,
+ )
+ K13 = *rr_, *vv_
+
+ dr = (
+ (
+ K00[0] * A01[0]
+ + K01[0] * A01[1]
+ + K02[0] * A01[2]
+ + K03[0] * A01[3]
+ + K04[0] * A01[4]
+ + K05[0] * A01[5]
+ + K06[0] * A01[6]
+ + K07[0] * A01[7]
+ + K08[0] * A01[8]
+ + K09[0] * A01[9]
+ + K10[0] * A01[10]
+ + K11[0] * A01[11]
+ + K12[0] * A01[12]
+ + K13[0] * A01[13]
+ )
+ * h,
+ (
+ K00[1] * A01[0]
+ + K01[1] * A01[1]
+ + K02[1] * A01[2]
+ + K03[1] * A01[3]
+ + K04[1] * A01[4]
+ + K05[1] * A01[5]
+ + K06[1] * A01[6]
+ + K07[1] * A01[7]
+ + K08[1] * A01[8]
+ + K09[1] * A01[9]
+ + K10[1] * A01[10]
+ + K11[1] * A01[11]
+ + K12[1] * A01[12]
+ + K13[1] * A01[13]
+ )
+ * h,
+ (
+ K00[2] * A01[0]
+ + K01[2] * A01[1]
+ + K02[2] * A01[2]
+ + K03[2] * A01[3]
+ + K04[2] * A01[4]
+ + K05[2] * A01[5]
+ + K06[2] * A01[6]
+ + K07[2] * A01[7]
+ + K08[2] * A01[8]
+ + K09[2] * A01[9]
+ + K10[2] * A01[10]
+ + K11[2] * A01[11]
+ + K12[2] * A01[12]
+ + K13[2] * A01[13]
+ )
+ * h,
+ )
+ dv = (
+ (
+ K00[3] * A01[0]
+ + K01[3] * A01[1]
+ + K02[3] * A01[2]
+ + K03[3] * A01[3]
+ + K04[3] * A01[4]
+ + K05[3] * A01[5]
+ + K06[3] * A01[6]
+ + K07[3] * A01[7]
+ + K08[3] * A01[8]
+ + K09[3] * A01[9]
+ + K10[3] * A01[10]
+ + K11[3] * A01[11]
+ + K12[3] * A01[12]
+ + K13[3] * A01[13]
+ )
+ * h,
+ (
+ K00[4] * A01[0]
+ + K01[4] * A01[1]
+ + K02[4] * A01[2]
+ + K03[4] * A01[3]
+ + K04[4] * A01[4]
+ + K05[4] * A01[5]
+ + K06[4] * A01[6]
+ + K07[4] * A01[7]
+ + K08[4] * A01[8]
+ + K09[4] * A01[9]
+ + K10[4] * A01[10]
+ + K11[4] * A01[11]
+ + K12[4] * A01[12]
+ + K13[4] * A01[13]
+ )
+ * h,
+ (
+ K00[5] * A01[0]
+ + K01[5] * A01[1]
+ + K02[5] * A01[2]
+ + K03[5] * A01[3]
+ + K04[5] * A01[4]
+ + K05[5] * A01[5]
+ + K06[5] * A01[6]
+ + K07[5] * A01[7]
+ + K08[5] * A01[8]
+ + K09[5] * A01[9]
+ + K10[5] * A01[10]
+ + K11[5] * A01[11]
+ + K12[5] * A01[12]
+ + K13[5] * A01[13]
+ )
+ * h,
+ )
+ rr_ = add_VV_hf(rr_old, dr)
+ vv_ = add_VV_hf(vv_old, dv)
+ rr_, vv_ = fun(
+ t_old + C_EXTRA[1] * h,
+ rr_,
+ vv_,
+ argk,
+ )
+ K14 = *rr_, *vv_
+
+ dr = (
+ (
+ K00[0] * A02[0]
+ + K01[0] * A02[1]
+ + K02[0] * A02[2]
+ + K03[0] * A02[3]
+ + K04[0] * A02[4]
+ + K05[0] * A02[5]
+ + K06[0] * A02[6]
+ + K07[0] * A02[7]
+ + K08[0] * A02[8]
+ + K09[0] * A02[9]
+ + K10[0] * A02[10]
+ + K11[0] * A02[11]
+ + K12[0] * A02[12]
+ + K13[0] * A02[13]
+ + K14[0] * A02[14]
+ )
+ * h,
+ (
+ K00[1] * A02[0]
+ + K01[1] * A02[1]
+ + K02[1] * A02[2]
+ + K03[1] * A02[3]
+ + K04[1] * A02[4]
+ + K05[1] * A02[5]
+ + K06[1] * A02[6]
+ + K07[1] * A02[7]
+ + K08[1] * A02[8]
+ + K09[1] * A02[9]
+ + K10[1] * A02[10]
+ + K11[1] * A02[11]
+ + K12[1] * A02[12]
+ + K13[1] * A02[13]
+ + K14[1] * A02[14]
+ )
+ * h,
+ (
+ K00[2] * A02[0]
+ + K01[2] * A02[1]
+ + K02[2] * A02[2]
+ + K03[2] * A02[3]
+ + K04[2] * A02[4]
+ + K05[2] * A02[5]
+ + K06[2] * A02[6]
+ + K07[2] * A02[7]
+ + K08[2] * A02[8]
+ + K09[2] * A02[9]
+ + K10[2] * A02[10]
+ + K11[2] * A02[11]
+ + K12[2] * A02[12]
+ + K13[2] * A02[13]
+ + K14[2] * A02[14]
+ )
+ * h,
+ )
+ dv = (
+ (
+ K00[3] * A02[0]
+ + K01[3] * A02[1]
+ + K02[3] * A02[2]
+ + K03[3] * A02[3]
+ + K04[3] * A02[4]
+ + K05[3] * A02[5]
+ + K06[3] * A02[6]
+ + K07[3] * A02[7]
+ + K08[3] * A02[8]
+ + K09[3] * A02[9]
+ + K10[3] * A02[10]
+ + K11[3] * A02[11]
+ + K12[3] * A02[12]
+ + K13[3] * A02[13]
+ + K14[3] * A02[14]
+ )
+ * h,
+ (
+ K00[4] * A02[0]
+ + K01[4] * A02[1]
+ + K02[4] * A02[2]
+ + K03[4] * A02[3]
+ + K04[4] * A02[4]
+ + K05[4] * A02[5]
+ + K06[4] * A02[6]
+ + K07[4] * A02[7]
+ + K08[4] * A02[8]
+ + K09[4] * A02[9]
+ + K10[4] * A02[10]
+ + K11[4] * A02[11]
+ + K12[4] * A02[12]
+ + K13[4] * A02[13]
+ + K14[4] * A02[14]
+ )
+ * h,
+ (
+ K00[5] * A02[0]
+ + K01[5] * A02[1]
+ + K02[5] * A02[2]
+ + K03[5] * A02[3]
+ + K04[5] * A02[4]
+ + K05[5] * A02[5]
+ + K06[5] * A02[6]
+ + K07[5] * A02[7]
+ + K08[5] * A02[8]
+ + K09[5] * A02[9]
+ + K10[5] * A02[10]
+ + K11[5] * A02[11]
+ + K12[5] * A02[12]
+ + K13[5] * A02[13]
+ + K14[5] * A02[14]
+ )
+ * h,
+ )
+ rr_ = add_VV_hf(rr_old, dr)
+ vv_ = add_VV_hf(vv_old, dv)
+ rr_, vv_ = fun(
+ t_old + C_EXTRA[2] * h,
+ rr_,
+ vv_,
+ argk,
+ )
+ K15 = *rr_, *vv_
+
+ fr_old = K00[:3]
+ fv_old = K00[3:]
+
+ delta_rr = sub_VV_hf(rr, rr_old)
+ delta_vv = sub_VV_hf(vv, vv_old)
+
+ F00 = *delta_rr, *delta_vv
+ F01 = *sub_VV_hf(mul_Vs_hf(fr_old, h), delta_rr), *sub_VV_hf(
+ mul_Vs_hf(fv_old, h), delta_vv
+ )
+ F02 = *sub_VV_hf(
+ mul_Vs_hf(delta_rr, 2), mul_Vs_hf(add_VV_hf(fr, fr_old), h)
+ ), *sub_VV_hf(mul_Vs_hf(delta_vv, 2), mul_Vs_hf(add_VV_hf(fv, fv_old), h))
+
+ F03 = (
+ (
+ D00[0] * K00[0]
+ + D00[1] * K01[0]
+ + D00[2] * K02[0]
+ + D00[3] * K03[0]
+ + D00[4] * K04[0]
+ + D00[5] * K05[0]
+ + D00[6] * K06[0]
+ + D00[7] * K07[0]
+ + D00[8] * K08[0]
+ + D00[9] * K09[0]
+ + D00[10] * K10[0]
+ + D00[11] * K11[0]
+ + D00[12] * K12[0]
+ + D00[13] * K13[0]
+ + D00[14] * K14[0]
+ + D00[15] * K15[0]
+ )
+ * h,
+ (
+ D00[0] * K00[1]
+ + D00[1] * K01[1]
+ + D00[2] * K02[1]
+ + D00[3] * K03[1]
+ + D00[4] * K04[1]
+ + D00[5] * K05[1]
+ + D00[6] * K06[1]
+ + D00[7] * K07[1]
+ + D00[8] * K08[1]
+ + D00[9] * K09[1]
+ + D00[10] * K10[1]
+ + D00[11] * K11[1]
+ + D00[12] * K12[1]
+ + D00[13] * K13[1]
+ + D00[14] * K14[1]
+ + D00[15] * K15[1]
+ )
+ * h,
+ (
+ D00[0] * K00[2]
+ + D00[1] * K01[2]
+ + D00[2] * K02[2]
+ + D00[3] * K03[2]
+ + D00[4] * K04[2]
+ + D00[5] * K05[2]
+ + D00[6] * K06[2]
+ + D00[7] * K07[2]
+ + D00[8] * K08[2]
+ + D00[9] * K09[2]
+ + D00[10] * K10[2]
+ + D00[11] * K11[2]
+ + D00[12] * K12[2]
+ + D00[13] * K13[2]
+ + D00[14] * K14[2]
+ + D00[15] * K15[2]
+ )
+ * h,
+ (
+ D00[0] * K00[3]
+ + D00[1] * K01[3]
+ + D00[2] * K02[3]
+ + D00[3] * K03[3]
+ + D00[4] * K04[3]
+ + D00[5] * K05[3]
+ + D00[6] * K06[3]
+ + D00[7] * K07[3]
+ + D00[8] * K08[3]
+ + D00[9] * K09[3]
+ + D00[10] * K10[3]
+ + D00[11] * K11[3]
+ + D00[12] * K12[3]
+ + D00[13] * K13[3]
+ + D00[14] * K14[3]
+ + D00[15] * K15[3]
+ )
+ * h,
+ (
+ D00[0] * K00[4]
+ + D00[1] * K01[4]
+ + D00[2] * K02[4]
+ + D00[3] * K03[4]
+ + D00[4] * K04[4]
+ + D00[5] * K05[4]
+ + D00[6] * K06[4]
+ + D00[7] * K07[4]
+ + D00[8] * K08[4]
+ + D00[9] * K09[4]
+ + D00[10] * K10[4]
+ + D00[11] * K11[4]
+ + D00[12] * K12[4]
+ + D00[13] * K13[4]
+ + D00[14] * K14[4]
+ + D00[15] * K15[4]
+ )
+ * h,
+ (
+ D00[0] * K00[5]
+ + D00[1] * K01[5]
+ + D00[2] * K02[5]
+ + D00[3] * K03[5]
+ + D00[4] * K04[5]
+ + D00[5] * K05[5]
+ + D00[6] * K06[5]
+ + D00[7] * K07[5]
+ + D00[8] * K08[5]
+ + D00[9] * K09[5]
+ + D00[10] * K10[5]
+ + D00[11] * K11[5]
+ + D00[12] * K12[5]
+ + D00[13] * K13[5]
+ + D00[14] * K14[5]
+ + D00[15] * K15[5]
+ )
+ * h,
+ )
+
+ F04 = (
+ (
+ D01[0] * K00[0]
+ + D01[1] * K01[0]
+ + D01[2] * K02[0]
+ + D01[3] * K03[0]
+ + D01[4] * K04[0]
+ + D01[5] * K05[0]
+ + D01[6] * K06[0]
+ + D01[7] * K07[0]
+ + D01[8] * K08[0]
+ + D01[9] * K09[0]
+ + D01[10] * K10[0]
+ + D01[11] * K11[0]
+ + D01[12] * K12[0]
+ + D01[13] * K13[0]
+ + D01[14] * K14[0]
+ + D01[15] * K15[0]
+ )
+ * h,
+ (
+ D01[0] * K00[1]
+ + D01[1] * K01[1]
+ + D01[2] * K02[1]
+ + D01[3] * K03[1]
+ + D01[4] * K04[1]
+ + D01[5] * K05[1]
+ + D01[6] * K06[1]
+ + D01[7] * K07[1]
+ + D01[8] * K08[1]
+ + D01[9] * K09[1]
+ + D01[10] * K10[1]
+ + D01[11] * K11[1]
+ + D01[12] * K12[1]
+ + D01[13] * K13[1]
+ + D01[14] * K14[1]
+ + D01[15] * K15[1]
+ )
+ * h,
+ (
+ D01[0] * K00[2]
+ + D01[1] * K01[2]
+ + D01[2] * K02[2]
+ + D01[3] * K03[2]
+ + D01[4] * K04[2]
+ + D01[5] * K05[2]
+ + D01[6] * K06[2]
+ + D01[7] * K07[2]
+ + D01[8] * K08[2]
+ + D01[9] * K09[2]
+ + D01[10] * K10[2]
+ + D01[11] * K11[2]
+ + D01[12] * K12[2]
+ + D01[13] * K13[2]
+ + D01[14] * K14[2]
+ + D01[15] * K15[2]
+ )
+ * h,
+ (
+ D01[0] * K00[3]
+ + D01[1] * K01[3]
+ + D01[2] * K02[3]
+ + D01[3] * K03[3]
+ + D01[4] * K04[3]
+ + D01[5] * K05[3]
+ + D01[6] * K06[3]
+ + D01[7] * K07[3]
+ + D01[8] * K08[3]
+ + D01[9] * K09[3]
+ + D01[10] * K10[3]
+ + D01[11] * K11[3]
+ + D01[12] * K12[3]
+ + D01[13] * K13[3]
+ + D01[14] * K14[3]
+ + D01[15] * K15[3]
+ )
+ * h,
+ (
+ D01[0] * K00[4]
+ + D01[1] * K01[4]
+ + D01[2] * K02[4]
+ + D01[3] * K03[4]
+ + D01[4] * K04[4]
+ + D01[5] * K05[4]
+ + D01[6] * K06[4]
+ + D01[7] * K07[4]
+ + D01[8] * K08[4]
+ + D01[9] * K09[4]
+ + D01[10] * K10[4]
+ + D01[11] * K11[4]
+ + D01[12] * K12[4]
+ + D01[13] * K13[4]
+ + D01[14] * K14[4]
+ + D01[15] * K15[4]
+ )
+ * h,
+ (
+ D01[0] * K00[5]
+ + D01[1] * K01[5]
+ + D01[2] * K02[5]
+ + D01[3] * K03[5]
+ + D01[4] * K04[5]
+ + D01[5] * K05[5]
+ + D01[6] * K06[5]
+ + D01[7] * K07[5]
+ + D01[8] * K08[5]
+ + D01[9] * K09[5]
+ + D01[10] * K10[5]
+ + D01[11] * K11[5]
+ + D01[12] * K12[5]
+ + D01[13] * K13[5]
+ + D01[14] * K14[5]
+ + D01[15] * K15[5]
+ )
+ * h,
+ )
+
+ F05 = (
+ (
+ D02[0] * K00[0]
+ + D02[1] * K01[0]
+ + D02[2] * K02[0]
+ + D02[3] * K03[0]
+ + D02[4] * K04[0]
+ + D02[5] * K05[0]
+ + D02[6] * K06[0]
+ + D02[7] * K07[0]
+ + D02[8] * K08[0]
+ + D02[9] * K09[0]
+ + D02[10] * K10[0]
+ + D02[11] * K11[0]
+ + D02[12] * K12[0]
+ + D02[13] * K13[0]
+ + D02[14] * K14[0]
+ + D02[15] * K15[0]
+ )
+ * h,
+ (
+ D02[0] * K00[1]
+ + D02[1] * K01[1]
+ + D02[2] * K02[1]
+ + D02[3] * K03[1]
+ + D02[4] * K04[1]
+ + D02[5] * K05[1]
+ + D02[6] * K06[1]
+ + D02[7] * K07[1]
+ + D02[8] * K08[1]
+ + D02[9] * K09[1]
+ + D02[10] * K10[1]
+ + D02[11] * K11[1]
+ + D02[12] * K12[1]
+ + D02[13] * K13[1]
+ + D02[14] * K14[1]
+ + D02[15] * K15[1]
+ )
+ * h,
+ (
+ D02[0] * K00[2]
+ + D02[1] * K01[2]
+ + D02[2] * K02[2]
+ + D02[3] * K03[2]
+ + D02[4] * K04[2]
+ + D02[5] * K05[2]
+ + D02[6] * K06[2]
+ + D02[7] * K07[2]
+ + D02[8] * K08[2]
+ + D02[9] * K09[2]
+ + D02[10] * K10[2]
+ + D02[11] * K11[2]
+ + D02[12] * K12[2]
+ + D02[13] * K13[2]
+ + D02[14] * K14[2]
+ + D02[15] * K15[2]
+ )
+ * h,
+ (
+ D02[0] * K00[3]
+ + D02[1] * K01[3]
+ + D02[2] * K02[3]
+ + D02[3] * K03[3]
+ + D02[4] * K04[3]
+ + D02[5] * K05[3]
+ + D02[6] * K06[3]
+ + D02[7] * K07[3]
+ + D02[8] * K08[3]
+ + D02[9] * K09[3]
+ + D02[10] * K10[3]
+ + D02[11] * K11[3]
+ + D02[12] * K12[3]
+ + D02[13] * K13[3]
+ + D02[14] * K14[3]
+ + D02[15] * K15[3]
+ )
+ * h,
+ (
+ D02[0] * K00[4]
+ + D02[1] * K01[4]
+ + D02[2] * K02[4]
+ + D02[3] * K03[4]
+ + D02[4] * K04[4]
+ + D02[5] * K05[4]
+ + D02[6] * K06[4]
+ + D02[7] * K07[4]
+ + D02[8] * K08[4]
+ + D02[9] * K09[4]
+ + D02[10] * K10[4]
+ + D02[11] * K11[4]
+ + D02[12] * K12[4]
+ + D02[13] * K13[4]
+ + D02[14] * K14[4]
+ + D02[15] * K15[4]
+ )
+ * h,
+ (
+ D02[0] * K00[5]
+ + D02[1] * K01[5]
+ + D02[2] * K02[5]
+ + D02[3] * K03[5]
+ + D02[4] * K04[5]
+ + D02[5] * K05[5]
+ + D02[6] * K06[5]
+ + D02[7] * K07[5]
+ + D02[8] * K08[5]
+ + D02[9] * K09[5]
+ + D02[10] * K10[5]
+ + D02[11] * K11[5]
+ + D02[12] * K12[5]
+ + D02[13] * K13[5]
+ + D02[14] * K14[5]
+ + D02[15] * K15[5]
+ )
+ * h,
+ )
+
+ F06 = (
+ (
+ D03[0] * K00[0]
+ + D03[1] * K01[0]
+ + D03[2] * K02[0]
+ + D03[3] * K03[0]
+ + D03[4] * K04[0]
+ + D03[5] * K05[0]
+ + D03[6] * K06[0]
+ + D03[7] * K07[0]
+ + D03[8] * K08[0]
+ + D03[9] * K09[0]
+ + D03[10] * K10[0]
+ + D03[11] * K11[0]
+ + D03[12] * K12[0]
+ + D03[13] * K13[0]
+ + D03[14] * K14[0]
+ + D03[15] * K15[0]
+ )
+ * h,
+ (
+ D03[0] * K00[1]
+ + D03[1] * K01[1]
+ + D03[2] * K02[1]
+ + D03[3] * K03[1]
+ + D03[4] * K04[1]
+ + D03[5] * K05[1]
+ + D03[6] * K06[1]
+ + D03[7] * K07[1]
+ + D03[8] * K08[1]
+ + D03[9] * K09[1]
+ + D03[10] * K10[1]
+ + D03[11] * K11[1]
+ + D03[12] * K12[1]
+ + D03[13] * K13[1]
+ + D03[14] * K14[1]
+ + D03[15] * K15[1]
+ )
+ * h,
+ (
+ D03[0] * K00[2]
+ + D03[1] * K01[2]
+ + D03[2] * K02[2]
+ + D03[3] * K03[2]
+ + D03[4] * K04[2]
+ + D03[5] * K05[2]
+ + D03[6] * K06[2]
+ + D03[7] * K07[2]
+ + D03[8] * K08[2]
+ + D03[9] * K09[2]
+ + D03[10] * K10[2]
+ + D03[11] * K11[2]
+ + D03[12] * K12[2]
+ + D03[13] * K13[2]
+ + D03[14] * K14[2]
+ + D03[15] * K15[2]
+ )
+ * h,
+ (
+ D03[0] * K00[3]
+ + D03[1] * K01[3]
+ + D03[2] * K02[3]
+ + D03[3] * K03[3]
+ + D03[4] * K04[3]
+ + D03[5] * K05[3]
+ + D03[6] * K06[3]
+ + D03[7] * K07[3]
+ + D03[8] * K08[3]
+ + D03[9] * K09[3]
+ + D03[10] * K10[3]
+ + D03[11] * K11[3]
+ + D03[12] * K12[3]
+ + D03[13] * K13[3]
+ + D03[14] * K14[3]
+ + D03[15] * K15[3]
+ )
+ * h,
+ (
+ D03[0] * K00[4]
+ + D03[1] * K01[4]
+ + D03[2] * K02[4]
+ + D03[3] * K03[4]
+ + D03[4] * K04[4]
+ + D03[5] * K05[4]
+ + D03[6] * K06[4]
+ + D03[7] * K07[4]
+ + D03[8] * K08[4]
+ + D03[9] * K09[4]
+ + D03[10] * K10[4]
+ + D03[11] * K11[4]
+ + D03[12] * K12[4]
+ + D03[13] * K13[4]
+ + D03[14] * K14[4]
+ + D03[15] * K15[4]
+ )
+ * h,
+ (
+ D03[0] * K00[5]
+ + D03[1] * K01[5]
+ + D03[2] * K02[5]
+ + D03[3] * K03[5]
+ + D03[4] * K04[5]
+ + D03[5] * K05[5]
+ + D03[6] * K06[5]
+ + D03[7] * K07[5]
+ + D03[8] * K08[5]
+ + D03[9] * K09[5]
+ + D03[10] * K10[5]
+ + D03[11] * K11[5]
+ + D03[12] * K12[5]
+ + D03[13] * K13[5]
+ + D03[14] * K14[5]
+ + D03[15] * K15[5]
+ )
+ * h,
+ )
+
+ return (
+ t_old,
+ t - t_old, # h
+ rr_old,
+ vv_old,
+ (F00, F01, F02, F03, F04, F05, F06),
+ )
diff --git a/src/hapsira/core/math/ivp/_rkerror.py b/src/hapsira/core/math/ivp/_rkerror.py
new file mode 100644
index 000000000..98e04f6eb
--- /dev/null
+++ b/src/hapsira/core/math/ivp/_rkerror.py
@@ -0,0 +1,241 @@
+from math import fabs, sqrt
+
+from ._const import N_RV, KSIG
+from ._dop853_coefficients import E3 as _E3, E5 as _E5
+from ..ieee754 import float_
+from ..linalg import div_ss_hf
+from ...jit import hjit
+
+
+__all__ = [
+ "estimate_error_norm_V_hf",
+]
+
+
+E3 = tuple(float_(number) for number in _E3) # N_STAGES + 1
+E5 = tuple(float_(number) for number in _E5) # N_STAGES + 1
+
+
+@hjit(f"f({KSIG:s},f,V,V)")
+def estimate_error_norm_V_hf(K, h, scale_r, scale_v):
+ K00, K01, K02, K03, K04, K05, K06, K07, K08, K09, K10, K11, K12 = K
+
+ err3 = (
+ div_ss_hf(
+ K00[0] * E3[0]
+ + K01[0] * E3[1]
+ + K02[0] * E3[2]
+ + K03[0] * E3[3]
+ + K04[0] * E3[4]
+ + K05[0] * E3[5]
+ + K06[0] * E3[6]
+ + K07[0] * E3[7]
+ + K08[0] * E3[8]
+ + K09[0] * E3[9]
+ + K10[0] * E3[10]
+ + K11[0] * E3[11]
+ + K12[0] * E3[12],
+ scale_r[0],
+ ),
+ div_ss_hf(
+ K00[1] * E3[0]
+ + K01[1] * E3[1]
+ + K02[1] * E3[2]
+ + K03[1] * E3[3]
+ + K04[1] * E3[4]
+ + K05[1] * E3[5]
+ + K06[1] * E3[6]
+ + K07[1] * E3[7]
+ + K08[1] * E3[8]
+ + K09[1] * E3[9]
+ + K10[1] * E3[10]
+ + K11[1] * E3[11]
+ + K12[1] * E3[12],
+ scale_r[1],
+ ),
+ div_ss_hf(
+ K00[2] * E3[0]
+ + K01[2] * E3[1]
+ + K02[2] * E3[2]
+ + K03[2] * E3[3]
+ + K04[2] * E3[4]
+ + K05[2] * E3[5]
+ + K06[2] * E3[6]
+ + K07[2] * E3[7]
+ + K08[2] * E3[8]
+ + K09[2] * E3[9]
+ + K10[2] * E3[10]
+ + K11[2] * E3[11]
+ + K12[2] * E3[12],
+ scale_r[2],
+ ),
+ div_ss_hf(
+ K00[3] * E3[0]
+ + K01[3] * E3[1]
+ + K02[3] * E3[2]
+ + K03[3] * E3[3]
+ + K04[3] * E3[4]
+ + K05[3] * E3[5]
+ + K06[3] * E3[6]
+ + K07[3] * E3[7]
+ + K08[3] * E3[8]
+ + K09[3] * E3[9]
+ + K10[3] * E3[10]
+ + K11[3] * E3[11]
+ + K12[3] * E3[12],
+ scale_v[0],
+ ),
+ div_ss_hf(
+ K00[4] * E3[0]
+ + K01[4] * E3[1]
+ + K02[4] * E3[2]
+ + K03[4] * E3[3]
+ + K04[4] * E3[4]
+ + K05[4] * E3[5]
+ + K06[4] * E3[6]
+ + K07[4] * E3[7]
+ + K08[4] * E3[8]
+ + K09[4] * E3[9]
+ + K10[4] * E3[10]
+ + K11[4] * E3[11]
+ + K12[4] * E3[12],
+ scale_v[1],
+ ),
+ div_ss_hf(
+ K00[5] * E3[0]
+ + K01[5] * E3[1]
+ + K02[5] * E3[2]
+ + K03[5] * E3[3]
+ + K04[5] * E3[4]
+ + K05[5] * E3[5]
+ + K06[5] * E3[6]
+ + K07[5] * E3[7]
+ + K08[5] * E3[8]
+ + K09[5] * E3[9]
+ + K10[5] * E3[10]
+ + K11[5] * E3[11]
+ + K12[5] * E3[12],
+ scale_v[2],
+ ),
+ )
+ err5 = (
+ div_ss_hf(
+ K00[0] * E5[0]
+ + K01[0] * E5[1]
+ + K02[0] * E5[2]
+ + K03[0] * E5[3]
+ + K04[0] * E5[4]
+ + K05[0] * E5[5]
+ + K06[0] * E5[6]
+ + K07[0] * E5[7]
+ + K08[0] * E5[8]
+ + K09[0] * E5[9]
+ + K10[0] * E5[10]
+ + K11[0] * E5[11]
+ + K12[0] * E5[12],
+ scale_r[0],
+ ),
+ div_ss_hf(
+ K00[1] * E5[0]
+ + K01[1] * E5[1]
+ + K02[1] * E5[2]
+ + K03[1] * E5[3]
+ + K04[1] * E5[4]
+ + K05[1] * E5[5]
+ + K06[1] * E5[6]
+ + K07[1] * E5[7]
+ + K08[1] * E5[8]
+ + K09[1] * E5[9]
+ + K10[1] * E5[10]
+ + K11[1] * E5[11]
+ + K12[1] * E5[12],
+ scale_r[1],
+ ),
+ div_ss_hf(
+ K00[2] * E5[0]
+ + K01[2] * E5[1]
+ + K02[2] * E5[2]
+ + K03[2] * E5[3]
+ + K04[2] * E5[4]
+ + K05[2] * E5[5]
+ + K06[2] * E5[6]
+ + K07[2] * E5[7]
+ + K08[2] * E5[8]
+ + K09[2] * E5[9]
+ + K10[2] * E5[10]
+ + K11[2] * E5[11]
+ + K12[2] * E5[12],
+ scale_r[2],
+ ),
+ div_ss_hf(
+ K00[3] * E5[0]
+ + K01[3] * E5[1]
+ + K02[3] * E5[2]
+ + K03[3] * E5[3]
+ + K04[3] * E5[4]
+ + K05[3] * E5[5]
+ + K06[3] * E5[6]
+ + K07[3] * E5[7]
+ + K08[3] * E5[8]
+ + K09[3] * E5[9]
+ + K10[3] * E5[10]
+ + K11[3] * E5[11]
+ + K12[3] * E5[12],
+ scale_v[0],
+ ),
+ div_ss_hf(
+ K00[4] * E5[0]
+ + K01[4] * E5[1]
+ + K02[4] * E5[2]
+ + K03[4] * E5[3]
+ + K04[4] * E5[4]
+ + K05[4] * E5[5]
+ + K06[4] * E5[6]
+ + K07[4] * E5[7]
+ + K08[4] * E5[8]
+ + K09[4] * E5[9]
+ + K10[4] * E5[10]
+ + K11[4] * E5[11]
+ + K12[4] * E5[12],
+ scale_v[1],
+ ),
+ div_ss_hf(
+ K00[5] * E5[0]
+ + K01[5] * E5[1]
+ + K02[5] * E5[2]
+ + K03[5] * E5[3]
+ + K04[5] * E5[4]
+ + K05[5] * E5[5]
+ + K06[5] * E5[6]
+ + K07[5] * E5[7]
+ + K08[5] * E5[8]
+ + K09[5] * E5[9]
+ + K10[5] * E5[10]
+ + K11[5] * E5[11]
+ + K12[5] * E5[12],
+ scale_v[2],
+ ),
+ )
+
+ err5_norm_2 = (
+ err5[0] ** 2
+ + err5[1] ** 2
+ + err5[2] ** 2
+ + err5[3] ** 2
+ + err5[4] ** 2
+ + err5[5] ** 2
+ )
+ err3_norm_2 = (
+ err3[0] ** 2
+ + err3[1] ** 2
+ + err3[2] ** 2
+ + err3[3] ** 2
+ + err3[4] ** 2
+ + err3[5] ** 2
+ )
+
+ if err5_norm_2 == 0 and err3_norm_2 == 0:
+ return 0.0
+ denom = err5_norm_2 + 0.01 * err3_norm_2
+
+ return fabs(h) * div_ss_hf(err5_norm_2, sqrt(denom * N_RV))
diff --git a/src/hapsira/core/math/ivp/_rkstep.py b/src/hapsira/core/math/ivp/_rkstep.py
new file mode 100644
index 000000000..37fbb832d
--- /dev/null
+++ b/src/hapsira/core/math/ivp/_rkstep.py
@@ -0,0 +1,823 @@
+from ._const import N_STAGES, KSIG
+from ._dop853_coefficients import A as _A, B as _B, C as _C
+from ..ieee754 import float_
+from ..linalg import add_VV_hf
+from ...jit import hjit, DSIG
+
+
+__all__ = [
+ "rk_step_hf",
+]
+
+
+A01 = tuple(float_(number) for number in _A[1, :N_STAGES])
+A02 = tuple(float_(number) for number in _A[2, :N_STAGES])
+A03 = tuple(float_(number) for number in _A[3, :N_STAGES])
+A04 = tuple(float_(number) for number in _A[4, :N_STAGES])
+A05 = tuple(float_(number) for number in _A[5, :N_STAGES])
+A06 = tuple(float_(number) for number in _A[6, :N_STAGES])
+A07 = tuple(float_(number) for number in _A[7, :N_STAGES])
+A08 = tuple(float_(number) for number in _A[8, :N_STAGES])
+A09 = tuple(float_(number) for number in _A[9, :N_STAGES])
+A10 = tuple(float_(number) for number in _A[10, :N_STAGES])
+A11 = tuple(float_(number) for number in _A[11, :N_STAGES])
+B = tuple(float_(number) for number in _B)
+C = tuple(float_(number) for number in _C[:N_STAGES])
+
+
+@hjit(f"Tuple([V,V,V,V,{KSIG:s}])(F({DSIG:s}),f,V,V,V,V,f,f)")
+def rk_step_hf(fun, t, rr, vv, fr, fv, h, argk):
+ """Perform a single Runge-Kutta step.
+
+ This function computes a prediction of an explicit Runge-Kutta method and
+ also estimates the error of a less accurate method.
+
+ Notation for Butcher tableau is as in [1]_.
+
+ Parameters
+ ----------
+ fun : callable
+ Right-hand side of the system.
+ t : float
+ Current time.
+ r : tuple[float,float,float]
+ Current r.
+ v : tuple[float,float,float]
+ Current v.
+ fr : tuple[float,float,float]
+ Current value of the derivative, i.e., ``fun(x, y)``.
+ fv : tuple[float,float,float]
+ Current value of the derivative, i.e., ``fun(x, y)``.
+ h : float
+ Step to use.
+
+ Returns
+ -------
+ y_new : ndarray, shape (n,)
+ Solution at t + h computed with a higher accuracy.
+ f_new : ndarray, shape (n,)
+ Derivative ``fun(t + h, y_new)``.
+ K : ndarray, shape (n_stages + 1, n)
+ Storage array for putting RK stages here. Stages are stored in rows.
+ The last row is a linear combination of the previous rows with
+ coefficients
+
+ Const
+ -----
+ A : ndarray, shape (n_stages, n_stages)
+ Coefficients for combining previous RK stages to compute the next
+ stage. For explicit methods the coefficients at and above the main
+ diagonal are zeros.
+ B : ndarray, shape (n_stages,)
+ Coefficients for combining RK stages for computing the final
+ prediction.
+ C : ndarray, shape (n_stages,)
+ Coefficients for incrementing time for consecutive RK stages.
+ The value for the first stage is always zero.
+
+ References
+ ----------
+ .. [1] E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential
+ Equations I: Nonstiff Problems", Sec. II.4.
+ """
+
+ K00 = *fr, *fv
+
+ dr = (
+ (K00[0] * A01[0]) * h,
+ (K00[1] * A01[0]) * h,
+ (K00[2] * A01[0]) * h,
+ )
+ dv = (
+ (K00[3] * A01[0]) * h,
+ (K00[4] * A01[0]) * h,
+ (K00[5] * A01[0]) * h,
+ )
+ fr, fv = fun(
+ t + C[1] * h,
+ add_VV_hf(rr, dr),
+ add_VV_hf(vv, dv),
+ argk,
+ )
+ K01 = *fr, *fv
+
+ dr = (
+ (K00[0] * A02[0] + K01[0] * A02[1]) * h,
+ (K00[1] * A02[0] + K01[1] * A02[1]) * h,
+ (K00[2] * A02[0] + K01[2] * A02[1]) * h,
+ )
+ dv = (
+ (K00[3] * A02[0] + K01[3] * A02[1]) * h,
+ (K00[4] * A02[0] + K01[4] * A02[1]) * h,
+ (K00[5] * A02[0] + K01[5] * A02[1]) * h,
+ )
+ fr, fv = fun(
+ t + C[2] * h,
+ add_VV_hf(rr, dr),
+ add_VV_hf(vv, dv),
+ argk,
+ )
+ K02 = *fr, *fv
+
+ dr = (
+ (K00[0] * A03[0] + K01[0] * A03[1] + K02[0] * A03[2]) * h,
+ (K00[1] * A03[0] + K01[1] * A03[1] + K02[1] * A03[2]) * h,
+ (K00[2] * A03[0] + K01[2] * A03[1] + K02[2] * A03[2]) * h,
+ )
+ dv = (
+ (K00[3] * A03[0] + K01[3] * A03[1] + K02[3] * A03[2]) * h,
+ (K00[4] * A03[0] + K01[4] * A03[1] + K02[4] * A03[2]) * h,
+ (K00[5] * A03[0] + K01[5] * A03[1] + K02[5] * A03[2]) * h,
+ )
+ fr, fv = fun(
+ t + C[3] * h,
+ add_VV_hf(rr, dr),
+ add_VV_hf(vv, dv),
+ argk,
+ )
+ K03 = *fr, *fv
+
+ dr = (
+ (K00[0] * A04[0] + K01[0] * A04[1] + K02[0] * A04[2] + K03[0] * A04[3]) * h,
+ (K00[1] * A04[0] + K01[1] * A04[1] + K02[1] * A04[2] + K03[1] * A04[3]) * h,
+ (K00[2] * A04[0] + K01[2] * A04[1] + K02[2] * A04[2] + K03[2] * A04[3]) * h,
+ )
+ dv = (
+ (K00[3] * A04[0] + K01[3] * A04[1] + K02[3] * A04[2] + K03[3] * A04[3]) * h,
+ (K00[4] * A04[0] + K01[4] * A04[1] + K02[4] * A04[2] + K03[4] * A04[3]) * h,
+ (K00[5] * A04[0] + K01[5] * A04[1] + K02[5] * A04[2] + K03[5] * A04[3]) * h,
+ )
+ fr, fv = fun(
+ t + C[4] * h,
+ add_VV_hf(rr, dr),
+ add_VV_hf(vv, dv),
+ argk,
+ )
+ K04 = *fr, *fv
+
+ dr = (
+ (
+ K00[0] * A05[0]
+ + K01[0] * A05[1]
+ + K02[0] * A05[2]
+ + K03[0] * A05[3]
+ + K04[0] * A05[4]
+ )
+ * h,
+ (
+ K00[1] * A05[0]
+ + K01[1] * A05[1]
+ + K02[1] * A05[2]
+ + K03[1] * A05[3]
+ + K04[1] * A05[4]
+ )
+ * h,
+ (
+ K00[2] * A05[0]
+ + K01[2] * A05[1]
+ + K02[2] * A05[2]
+ + K03[2] * A05[3]
+ + K04[2] * A05[4]
+ )
+ * h,
+ )
+ dv = (
+ (
+ K00[3] * A05[0]
+ + K01[3] * A05[1]
+ + K02[3] * A05[2]
+ + K03[3] * A05[3]
+ + K04[3] * A05[4]
+ )
+ * h,
+ (
+ K00[4] * A05[0]
+ + K01[4] * A05[1]
+ + K02[4] * A05[2]
+ + K03[4] * A05[3]
+ + K04[4] * A05[4]
+ )
+ * h,
+ (
+ K00[5] * A05[0]
+ + K01[5] * A05[1]
+ + K02[5] * A05[2]
+ + K03[5] * A05[3]
+ + K04[5] * A05[4]
+ )
+ * h,
+ )
+ fr, fv = fun(
+ t + C[5] * h,
+ add_VV_hf(rr, dr),
+ add_VV_hf(vv, dv),
+ argk,
+ )
+ K05 = *fr, *fv
+
+ dr = (
+ (
+ K00[0] * A06[0]
+ + K01[0] * A06[1]
+ + K02[0] * A06[2]
+ + K03[0] * A06[3]
+ + K04[0] * A06[4]
+ + K05[0] * A06[5]
+ )
+ * h,
+ (
+ K00[1] * A06[0]
+ + K01[1] * A06[1]
+ + K02[1] * A06[2]
+ + K03[1] * A06[3]
+ + K04[1] * A06[4]
+ + K05[1] * A06[5]
+ )
+ * h,
+ (
+ K00[2] * A06[0]
+ + K01[2] * A06[1]
+ + K02[2] * A06[2]
+ + K03[2] * A06[3]
+ + K04[2] * A06[4]
+ + K05[2] * A06[5]
+ )
+ * h,
+ )
+ dv = (
+ (
+ K00[3] * A06[0]
+ + K01[3] * A06[1]
+ + K02[3] * A06[2]
+ + K03[3] * A06[3]
+ + K04[3] * A06[4]
+ + K05[3] * A06[5]
+ )
+ * h,
+ (
+ K00[4] * A06[0]
+ + K01[4] * A06[1]
+ + K02[4] * A06[2]
+ + K03[4] * A06[3]
+ + K04[4] * A06[4]
+ + K05[4] * A06[5]
+ )
+ * h,
+ (
+ K00[5] * A06[0]
+ + K01[5] * A06[1]
+ + K02[5] * A06[2]
+ + K03[5] * A06[3]
+ + K04[5] * A06[4]
+ + K05[5] * A06[5]
+ )
+ * h,
+ )
+ fr, fv = fun(
+ t + C[6] * h,
+ add_VV_hf(rr, dr),
+ add_VV_hf(vv, dv),
+ argk,
+ )
+ K06 = *fr, *fv
+
+ dr = (
+ (
+ K00[0] * A07[0]
+ + K01[0] * A07[1]
+ + K02[0] * A07[2]
+ + K03[0] * A07[3]
+ + K04[0] * A07[4]
+ + K05[0] * A07[5]
+ + K06[0] * A07[6]
+ )
+ * h,
+ (
+ K00[1] * A07[0]
+ + K01[1] * A07[1]
+ + K02[1] * A07[2]
+ + K03[1] * A07[3]
+ + K04[1] * A07[4]
+ + K05[1] * A07[5]
+ + K06[1] * A07[6]
+ )
+ * h,
+ (
+ K00[2] * A07[0]
+ + K01[2] * A07[1]
+ + K02[2] * A07[2]
+ + K03[2] * A07[3]
+ + K04[2] * A07[4]
+ + K05[2] * A07[5]
+ + K06[2] * A07[6]
+ )
+ * h,
+ )
+ dv = (
+ (
+ K00[3] * A07[0]
+ + K01[3] * A07[1]
+ + K02[3] * A07[2]
+ + K03[3] * A07[3]
+ + K04[3] * A07[4]
+ + K05[3] * A07[5]
+ + K06[3] * A07[6]
+ )
+ * h,
+ (
+ K00[4] * A07[0]
+ + K01[4] * A07[1]
+ + K02[4] * A07[2]
+ + K03[4] * A07[3]
+ + K04[4] * A07[4]
+ + K05[4] * A07[5]
+ + K06[4] * A07[6]
+ )
+ * h,
+ (
+ K00[5] * A07[0]
+ + K01[5] * A07[1]
+ + K02[5] * A07[2]
+ + K03[5] * A07[3]
+ + K04[5] * A07[4]
+ + K05[5] * A07[5]
+ + K06[5] * A07[6]
+ )
+ * h,
+ )
+ fr, fv = fun(
+ t + C[7] * h,
+ add_VV_hf(rr, dr),
+ add_VV_hf(vv, dv),
+ argk,
+ )
+ K07 = *fr, *fv
+
+ dr = (
+ (
+ K00[0] * A08[0]
+ + K01[0] * A08[1]
+ + K02[0] * A08[2]
+ + K03[0] * A08[3]
+ + K04[0] * A08[4]
+ + K05[0] * A08[5]
+ + K06[0] * A08[6]
+ + K07[0] * A08[7]
+ )
+ * h,
+ (
+ K00[1] * A08[0]
+ + K01[1] * A08[1]
+ + K02[1] * A08[2]
+ + K03[1] * A08[3]
+ + K04[1] * A08[4]
+ + K05[1] * A08[5]
+ + K06[1] * A08[6]
+ + K07[1] * A08[7]
+ )
+ * h,
+ (
+ K00[2] * A08[0]
+ + K01[2] * A08[1]
+ + K02[2] * A08[2]
+ + K03[2] * A08[3]
+ + K04[2] * A08[4]
+ + K05[2] * A08[5]
+ + K06[2] * A08[6]
+ + K07[2] * A08[7]
+ )
+ * h,
+ )
+ dv = (
+ (
+ K00[3] * A08[0]
+ + K01[3] * A08[1]
+ + K02[3] * A08[2]
+ + K03[3] * A08[3]
+ + K04[3] * A08[4]
+ + K05[3] * A08[5]
+ + K06[3] * A08[6]
+ + K07[3] * A08[7]
+ )
+ * h,
+ (
+ K00[4] * A08[0]
+ + K01[4] * A08[1]
+ + K02[4] * A08[2]
+ + K03[4] * A08[3]
+ + K04[4] * A08[4]
+ + K05[4] * A08[5]
+ + K06[4] * A08[6]
+ + K07[4] * A08[7]
+ )
+ * h,
+ (
+ K00[5] * A08[0]
+ + K01[5] * A08[1]
+ + K02[5] * A08[2]
+ + K03[5] * A08[3]
+ + K04[5] * A08[4]
+ + K05[5] * A08[5]
+ + K06[5] * A08[6]
+ + K07[5] * A08[7]
+ )
+ * h,
+ )
+ fr, fv = fun(
+ t + C[8] * h,
+ add_VV_hf(rr, dr),
+ add_VV_hf(vv, dv),
+ argk,
+ )
+ K08 = *fr, *fv
+
+ dr = (
+ (
+ K00[0] * A09[0]
+ + K01[0] * A09[1]
+ + K02[0] * A09[2]
+ + K03[0] * A09[3]
+ + K04[0] * A09[4]
+ + K05[0] * A09[5]
+ + K06[0] * A09[6]
+ + K07[0] * A09[7]
+ + K08[0] * A09[8]
+ )
+ * h,
+ (
+ K00[1] * A09[0]
+ + K01[1] * A09[1]
+ + K02[1] * A09[2]
+ + K03[1] * A09[3]
+ + K04[1] * A09[4]
+ + K05[1] * A09[5]
+ + K06[1] * A09[6]
+ + K07[1] * A09[7]
+ + K08[1] * A09[8]
+ )
+ * h,
+ (
+ K00[2] * A09[0]
+ + K01[2] * A09[1]
+ + K02[2] * A09[2]
+ + K03[2] * A09[3]
+ + K04[2] * A09[4]
+ + K05[2] * A09[5]
+ + K06[2] * A09[6]
+ + K07[2] * A09[7]
+ + K08[2] * A09[8]
+ )
+ * h,
+ )
+ dv = (
+ (
+ K00[3] * A09[0]
+ + K01[3] * A09[1]
+ + K02[3] * A09[2]
+ + K03[3] * A09[3]
+ + K04[3] * A09[4]
+ + K05[3] * A09[5]
+ + K06[3] * A09[6]
+ + K07[3] * A09[7]
+ + K08[3] * A09[8]
+ )
+ * h,
+ (
+ K00[4] * A09[0]
+ + K01[4] * A09[1]
+ + K02[4] * A09[2]
+ + K03[4] * A09[3]
+ + K04[4] * A09[4]
+ + K05[4] * A09[5]
+ + K06[4] * A09[6]
+ + K07[4] * A09[7]
+ + K08[4] * A09[8]
+ )
+ * h,
+ (
+ K00[5] * A09[0]
+ + K01[5] * A09[1]
+ + K02[5] * A09[2]
+ + K03[5] * A09[3]
+ + K04[5] * A09[4]
+ + K05[5] * A09[5]
+ + K06[5] * A09[6]
+ + K07[5] * A09[7]
+ + K08[5] * A09[8]
+ )
+ * h,
+ )
+ fr, fv = fun(
+ t + C[9] * h,
+ add_VV_hf(rr, dr),
+ add_VV_hf(vv, dv),
+ argk,
+ )
+ K09 = *fr, *fv
+
+ dr = (
+ (
+ K00[0] * A10[0]
+ + K01[0] * A10[1]
+ + K02[0] * A10[2]
+ + K03[0] * A10[3]
+ + K04[0] * A10[4]
+ + K05[0] * A10[5]
+ + K06[0] * A10[6]
+ + K07[0] * A10[7]
+ + K08[0] * A10[8]
+ + K09[0] * A10[9]
+ )
+ * h,
+ (
+ K00[1] * A10[0]
+ + K01[1] * A10[1]
+ + K02[1] * A10[2]
+ + K03[1] * A10[3]
+ + K04[1] * A10[4]
+ + K05[1] * A10[5]
+ + K06[1] * A10[6]
+ + K07[1] * A10[7]
+ + K08[1] * A10[8]
+ + K09[1] * A10[9]
+ )
+ * h,
+ (
+ K00[2] * A10[0]
+ + K01[2] * A10[1]
+ + K02[2] * A10[2]
+ + K03[2] * A10[3]
+ + K04[2] * A10[4]
+ + K05[2] * A10[5]
+ + K06[2] * A10[6]
+ + K07[2] * A10[7]
+ + K08[2] * A10[8]
+ + K09[2] * A10[9]
+ )
+ * h,
+ )
+ dv = (
+ (
+ K00[3] * A10[0]
+ + K01[3] * A10[1]
+ + K02[3] * A10[2]
+ + K03[3] * A10[3]
+ + K04[3] * A10[4]
+ + K05[3] * A10[5]
+ + K06[3] * A10[6]
+ + K07[3] * A10[7]
+ + K08[3] * A10[8]
+ + K09[3] * A10[9]
+ )
+ * h,
+ (
+ K00[4] * A10[0]
+ + K01[4] * A10[1]
+ + K02[4] * A10[2]
+ + K03[4] * A10[3]
+ + K04[4] * A10[4]
+ + K05[4] * A10[5]
+ + K06[4] * A10[6]
+ + K07[4] * A10[7]
+ + K08[4] * A10[8]
+ + K09[4] * A10[9]
+ )
+ * h,
+ (
+ K00[5] * A10[0]
+ + K01[5] * A10[1]
+ + K02[5] * A10[2]
+ + K03[5] * A10[3]
+ + K04[5] * A10[4]
+ + K05[5] * A10[5]
+ + K06[5] * A10[6]
+ + K07[5] * A10[7]
+ + K08[5] * A10[8]
+ + K09[5] * A10[9]
+ )
+ * h,
+ )
+ fr, fv = fun(
+ t + C[10] * h,
+ add_VV_hf(rr, dr),
+ add_VV_hf(vv, dv),
+ argk,
+ )
+ K10 = *fr, *fv
+
+ dr = (
+ (
+ K00[0] * A11[0]
+ + K01[0] * A11[1]
+ + K02[0] * A11[2]
+ + K03[0] * A11[3]
+ + K04[0] * A11[4]
+ + K05[0] * A11[5]
+ + K06[0] * A11[6]
+ + K07[0] * A11[7]
+ + K08[0] * A11[8]
+ + K09[0] * A11[9]
+ + K10[0] * A11[10]
+ )
+ * h,
+ (
+ K00[1] * A11[0]
+ + K01[1] * A11[1]
+ + K02[1] * A11[2]
+ + K03[1] * A11[3]
+ + K04[1] * A11[4]
+ + K05[1] * A11[5]
+ + K06[1] * A11[6]
+ + K07[1] * A11[7]
+ + K08[1] * A11[8]
+ + K09[1] * A11[9]
+ + K10[1] * A11[10]
+ )
+ * h,
+ (
+ K00[2] * A11[0]
+ + K01[2] * A11[1]
+ + K02[2] * A11[2]
+ + K03[2] * A11[3]
+ + K04[2] * A11[4]
+ + K05[2] * A11[5]
+ + K06[2] * A11[6]
+ + K07[2] * A11[7]
+ + K08[2] * A11[8]
+ + K09[2] * A11[9]
+ + K10[2] * A11[10]
+ )
+ * h,
+ )
+ dv = (
+ (
+ K00[3] * A11[0]
+ + K01[3] * A11[1]
+ + K02[3] * A11[2]
+ + K03[3] * A11[3]
+ + K04[3] * A11[4]
+ + K05[3] * A11[5]
+ + K06[3] * A11[6]
+ + K07[3] * A11[7]
+ + K08[3] * A11[8]
+ + K09[3] * A11[9]
+ + K10[3] * A11[10]
+ )
+ * h,
+ (
+ K00[4] * A11[0]
+ + K01[4] * A11[1]
+ + K02[4] * A11[2]
+ + K03[4] * A11[3]
+ + K04[4] * A11[4]
+ + K05[4] * A11[5]
+ + K06[4] * A11[6]
+ + K07[4] * A11[7]
+ + K08[4] * A11[8]
+ + K09[4] * A11[9]
+ + K10[4] * A11[10]
+ )
+ * h,
+ (
+ K00[5] * A11[0]
+ + K01[5] * A11[1]
+ + K02[5] * A11[2]
+ + K03[5] * A11[3]
+ + K04[5] * A11[4]
+ + K05[5] * A11[5]
+ + K06[5] * A11[6]
+ + K07[5] * A11[7]
+ + K08[5] * A11[8]
+ + K09[5] * A11[9]
+ + K10[5] * A11[10]
+ )
+ * h,
+ )
+ fr, fv = fun(
+ t + C[11] * h,
+ add_VV_hf(rr, dr),
+ add_VV_hf(vv, dv),
+ argk,
+ )
+ K11 = *fr, *fv
+
+ dr = (
+ (
+ K00[0] * B[0]
+ + K01[0] * B[1]
+ + K02[0] * B[2]
+ + K03[0] * B[3]
+ + K04[0] * B[4]
+ + K05[0] * B[5]
+ + K06[0] * B[6]
+ + K07[0] * B[7]
+ + K08[0] * B[8]
+ + K09[0] * B[9]
+ + K10[0] * B[10]
+ + K11[0] * B[11]
+ )
+ * h,
+ (
+ K00[1] * B[0]
+ + K01[1] * B[1]
+ + K02[1] * B[2]
+ + K03[1] * B[3]
+ + K04[1] * B[4]
+ + K05[1] * B[5]
+ + K06[1] * B[6]
+ + K07[1] * B[7]
+ + K08[1] * B[8]
+ + K09[1] * B[9]
+ + K10[1] * B[10]
+ + K11[1] * B[11]
+ )
+ * h,
+ (
+ K00[2] * B[0]
+ + K01[2] * B[1]
+ + K02[2] * B[2]
+ + K03[2] * B[3]
+ + K04[2] * B[4]
+ + K05[2] * B[5]
+ + K06[2] * B[6]
+ + K07[2] * B[7]
+ + K08[2] * B[8]
+ + K09[2] * B[9]
+ + K10[2] * B[10]
+ + K11[2] * B[11]
+ )
+ * h,
+ )
+ dv = (
+ (
+ K00[3] * B[0]
+ + K01[3] * B[1]
+ + K02[3] * B[2]
+ + K03[3] * B[3]
+ + K04[3] * B[4]
+ + K05[3] * B[5]
+ + K06[3] * B[6]
+ + K07[3] * B[7]
+ + K08[3] * B[8]
+ + K09[3] * B[9]
+ + K10[3] * B[10]
+ + K11[3] * B[11]
+ )
+ * h,
+ (
+ K00[4] * B[0]
+ + K01[4] * B[1]
+ + K02[4] * B[2]
+ + K03[4] * B[3]
+ + K04[4] * B[4]
+ + K05[4] * B[5]
+ + K06[4] * B[6]
+ + K07[4] * B[7]
+ + K08[4] * B[8]
+ + K09[4] * B[9]
+ + K10[4] * B[10]
+ + K11[4] * B[11]
+ )
+ * h,
+ (
+ K00[5] * B[0]
+ + K01[4] * B[1]
+ + K02[5] * B[2]
+ + K03[5] * B[3]
+ + K04[5] * B[4]
+ + K05[5] * B[5]
+ + K06[5] * B[6]
+ + K07[5] * B[7]
+ + K08[5] * B[8]
+ + K09[5] * B[9]
+ + K10[5] * B[10]
+ + K11[5] * B[11]
+ )
+ * h,
+ )
+ rr_new = add_VV_hf(rr, dr)
+ vv_new = add_VV_hf(vv, dv)
+ fr_new, fv_new = fun(t + h, rr_new, vv_new, argk)
+ K12 = *fr_new, *fv_new
+
+ return (
+ rr_new,
+ vv_new,
+ fr_new,
+ fv_new,
+ (
+ K00,
+ K01,
+ K02,
+ K03,
+ K04,
+ K05,
+ K06,
+ K07,
+ K08,
+ K09,
+ K10,
+ K11,
+ K12,
+ ),
+ )
diff --git a/src/hapsira/core/math/ivp/_rkstepimpl.py b/src/hapsira/core/math/ivp/_rkstepimpl.py
new file mode 100644
index 000000000..0e4935a29
--- /dev/null
+++ b/src/hapsira/core/math/ivp/_rkstepimpl.py
@@ -0,0 +1,108 @@
+from math import inf, fabs
+
+from ._const import ERROR_EXPONENT, KSIG, MAX_FACTOR, MIN_FACTOR, SAFETY
+from ._rkstep import rk_step_hf
+from ._rkerror import estimate_error_norm_V_hf
+from ..ieee754 import nextafter
+from ..linalg import abs_V_hf, add_Vs_hf, max_VV_hf, mul_Vs_hf
+from ...jit import hjit, DSIG
+
+
+__all__ = [
+ "step_impl_hf",
+]
+
+
+@hjit(
+ f"Tuple([b1,f,f,V,V,f,V,V,{KSIG:s}])"
+ f"(F({DSIG:s}),f,f,V,V,V,V,f,f,f,f,f,{KSIG:s})"
+)
+def step_impl_hf(
+ fun, argk, t, rr, vv, fr, fv, rtol, atol, direction, h_abs, t_bound, K
+):
+ min_step = 10 * fabs(nextafter(t, direction * inf) - t)
+
+ if h_abs < min_step:
+ h_abs = min_step
+
+ step_accepted = False
+ step_rejected = False
+
+ while not step_accepted:
+ if h_abs < min_step:
+ return (
+ False,
+ 0.0,
+ 0.0,
+ (0.0, 0.0, 0.0),
+ (0.0, 0.0, 0.0),
+ 0.0,
+ (0.0, 0.0, 0.0),
+ (0.0, 0.0, 0.0),
+ K,
+ )
+
+ h = h_abs * direction
+ t_new = t + h
+
+ if direction * (t_new - t_bound) > 0:
+ t_new = t_bound
+
+ h = t_new - t
+ h_abs = fabs(h)
+
+ rr_new, vv_new, fr_new, fv_new, K_new = rk_step_hf(
+ fun,
+ t,
+ rr,
+ vv,
+ fr,
+ fv,
+ h,
+ argk,
+ )
+
+ scale_r = add_Vs_hf(
+ mul_Vs_hf(
+ max_VV_hf(
+ abs_V_hf(rr),
+ abs_V_hf(rr_new),
+ ),
+ rtol,
+ ),
+ atol,
+ )
+ scale_v = add_Vs_hf(
+ mul_Vs_hf(
+ max_VV_hf(
+ abs_V_hf(vv),
+ abs_V_hf(vv_new),
+ ),
+ rtol,
+ ),
+ atol,
+ )
+ error_norm = estimate_error_norm_V_hf(
+ K_new,
+ h,
+ scale_r,
+ scale_v,
+ )
+
+ if error_norm < 1:
+ if error_norm == 0:
+ factor = MAX_FACTOR
+ else:
+ factor = min(MAX_FACTOR, SAFETY * error_norm**ERROR_EXPONENT)
+
+ if step_rejected:
+ factor = min(1, factor)
+
+ h_abs *= factor
+
+ step_accepted = True
+ else:
+ h_abs *= max(MIN_FACTOR, SAFETY * error_norm**ERROR_EXPONENT)
+ step_rejected = True
+
+ return True, h, t_new, rr_new, vv_new, h_abs, fr_new, fv_new, K_new
diff --git a/src/hapsira/core/math/ivp/_rkstepinit.py b/src/hapsira/core/math/ivp/_rkstepinit.py
new file mode 100644
index 000000000..b975bec50
--- /dev/null
+++ b/src/hapsira/core/math/ivp/_rkstepinit.py
@@ -0,0 +1,57 @@
+from math import sqrt
+
+from ...jit import hjit, DSIG
+from ..linalg import add_VV_hf, div_VV_hf, mul_Vs_hf, norm_VV_hf, sub_VV_hf
+
+
+__all__ = [
+ "select_initial_step_hf",
+]
+
+
+@hjit(f"f(F({DSIG:s}),f,V,V,f,V,V,f,f,f,f)")
+def select_initial_step_hf(fun, t0, rr, vv, argk, fr, fv, direction, order, rtol, atol):
+ scale_r = (
+ atol + abs(rr[0]) * rtol,
+ atol + abs(rr[1]) * rtol,
+ atol + abs(rr[2]) * rtol,
+ )
+ scale_v = (
+ atol + abs(vv[0]) * rtol,
+ atol + abs(vv[1]) * rtol,
+ atol + abs(vv[2]) * rtol,
+ )
+
+ factor = 1 / sqrt(6)
+ d0 = norm_VV_hf(div_VV_hf(rr, scale_r), div_VV_hf(vv, scale_v)) * factor
+ d1 = norm_VV_hf(div_VV_hf(fr, scale_r), div_VV_hf(fv, scale_v)) * factor
+
+ if d0 < 1e-5 or d1 < 1e-5:
+ h0 = 1e-6
+ else:
+ h0 = 0.01 * d0 / d1
+
+ yr1 = add_VV_hf(rr, mul_Vs_hf(fr, h0 * direction))
+ yv1 = add_VV_hf(vv, mul_Vs_hf(fv, h0 * direction))
+
+ fr1, fv1 = fun(
+ t0 + h0 * direction,
+ yr1,
+ yv1,
+ argk,
+ )
+
+ d2 = (
+ norm_VV_hf(
+ div_VV_hf(sub_VV_hf(fr1, fr), scale_r),
+ div_VV_hf(sub_VV_hf(fv1, fv), scale_v),
+ )
+ / h0
+ )
+
+ if d1 <= 1e-15 and d2 <= 1e-15:
+ h1 = max(1e-6, h0 * 1e-3)
+ else:
+ h1 = (0.01 / max(d1, d2)) ** (1 / (order + 1))
+
+ return min(100 * h0, h1)
diff --git a/src/hapsira/core/math/ivp/_solve.py b/src/hapsira/core/math/ivp/_solve.py
new file mode 100644
index 000000000..96b13a6e4
--- /dev/null
+++ b/src/hapsira/core/math/ivp/_solve.py
@@ -0,0 +1,107 @@
+from math import nan
+from typing import Callable, Tuple
+
+from ...jit import hjit
+
+
+__all__ = [
+ "event_is_active_hf",
+ "dispatcher_hb",
+]
+
+
+TEMPLATE = """
+@hjit("{RESTYPE:s}(i8,{ARGTYPES:s})", cache = False)
+def dispatcher_hf(idx, {ARGUMENTS:s}):
+{DISPATCHER:s}
+ return {ERROR:s}
+"""
+
+_ = nan # keep import alive
+
+
+def dispatcher_hb(
+ funcs: Tuple[Callable, ...],
+ argtypes: str,
+ restype: str,
+ arguments: str,
+ error: str = "nan",
+) -> Callable:
+ """
+ Workaround for https://github.com/numba/numba/issues/9420
+ Compiles a dispatcher for a list of functions that can eventually called by index.
+
+ Parameters
+ ----------
+ funcs : tuple[Callable, ...]
+ One or multiple callables that require dispatching.
+ Dispatching will be based on position in tuple.
+ All callables must have the same signature.
+ argtypes : argument portion of signature for callables
+ restype : return type portion of signature for callables
+ arguments : names of arguments for callables
+
+ Returns
+ -------
+ b : Callable
+ Dispatcher function
+
+ """
+
+ funcs = [
+ (f"func_{id(func):x}", func) for func in funcs
+ ] # names are not unique, ids are
+ globals_, locals_ = globals(), locals() # HACK https://stackoverflow.com/a/71560563
+ globals_.update({name: handle for name, handle in funcs})
+
+ def switch(idx):
+ return "if" if idx == 0 else "elif"
+
+ code = TEMPLATE.format(
+ DISPATCHER="\n".join(
+ [
+ f" {switch(idx):s} idx == {idx:d}:\n return {name:s}({arguments:s})"
+ for idx, (name, _) in enumerate(funcs)
+ ]
+ ), # TODO tree-like dispatch, faster
+ ARGTYPES=argtypes,
+ RESTYPE=restype,
+ ARGUMENTS=arguments,
+ ERROR=error,
+ )
+ exec(code, globals_, locals_) # pylint: disable=W0122
+ globals_["dispatcher_hf"] = locals_[
+ "dispatcher_hf"
+ ] # HACK https://stackoverflow.com/a/71560563
+ return dispatcher_hf # pylint: disable=E0602 # noqa: F821
+
+
+@hjit("b1(f,f,f)")
+def event_is_active_hf(g_old, g_new, direction):
+ """
+ Find which event occurred during an integration step.
+
+ Based on
+ https://github.com/scipy/scipy/blob/4edfcaa3ce8a387450b6efce968572def71be089/scipy/integrate/_ivp/ivp.py#L130
+
+ Parameters
+ ----------
+ g_old : float
+ Value of event function at current point.
+ g_new : float
+ Value of event function at next point.
+ direction : float
+ Event "direction".
+
+ Returns
+ -------
+ active : boolean
+ Status of event (active or not)
+
+ """
+
+ up = (g_old <= 0) & (g_new >= 0)
+ down = (g_old >= 0) & (g_new <= 0)
+ either = up | down
+ active = up & (direction > 0) | down & (direction < 0) | either & (direction == 0)
+ return active
diff --git a/src/hapsira/core/math/linalg.py b/src/hapsira/core/math/linalg.py
new file mode 100644
index 000000000..503740cb4
--- /dev/null
+++ b/src/hapsira/core/math/linalg.py
@@ -0,0 +1,527 @@
+from math import fabs, inf, sqrt
+
+from ..jit import hjit, vjit
+
+__all__ = [
+ "abs_V_hf",
+ "add_Vs_hf",
+ "add_VV_hf",
+ "cross_VV_hf",
+ "div_Vs_hf",
+ "div_VV_hf",
+ "div_ss_hf",
+ "matmul_MM_hf",
+ "matmul_MV_hf",
+ "matmul_VM_hf",
+ "matmul_VV_hf",
+ "max_VV_hf",
+ "mul_Vs_hf",
+ "mul_VV_hf",
+ "norm_V_hf",
+ "norm_V_vf",
+ "norm_VV_hf",
+ "sign_hf",
+ "sub_VV_hf",
+ "transpose_M_hf",
+]
+
+
+@hjit("V(V)", inline=True)
+def abs_V_hf(a):
+ """
+ Abs 3D vector of 3D vector element-wise.
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : tuple[float,float,float]
+ Vector
+
+ Returns
+ -------
+ b : tuple[float,float,float]
+ Vector
+
+ """
+
+ return fabs(a[0]), fabs(a[1]), fabs(a[2])
+
+
+@hjit("V(V,f)", inline=True)
+def add_Vs_hf(a, b):
+ """
+ Adds a 3D vector and a scalar element-wise.
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : tuple[float,float,float]
+ Vector
+ b : float
+ Scalar
+
+ Returns
+ -------
+ c : tuple[float,float,float]
+ Vector
+
+ """
+
+ return a[0] + b, a[1] + b, a[2] + b
+
+
+@hjit("V(V,V)", inline=True)
+def add_VV_hf(a, b):
+ """
+ Adds two 3D vectors.
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : tuple[float,float,float]
+ Vector
+ b : tuple[float,float,float]
+ Vector
+
+ Returns
+ -------
+ c : tuple[float,float,float]
+ Vector
+
+ """
+
+ return a[0] + b[0], a[1] + b[1], a[2] + b[2]
+
+
+@hjit("V(V,V)", inline=True)
+def cross_VV_hf(a, b):
+ """
+ Cross-product of two 3D vectors.
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : tuple[float,float,float]
+ Vector
+ b : tuple[float,float,float]
+ Vector
+
+ Returns
+ -------
+ c : tuple[float,float,float]
+ Vector
+
+ """
+
+ return (
+ a[1] * b[2] - a[2] * b[1],
+ a[2] * b[0] - a[0] * b[2],
+ a[0] * b[1] - a[1] * b[0],
+ )
+
+
+@hjit("f(f,f)", inline=True)
+def div_ss_hf(a, b):
+ """
+ Division of two scalars. Similar to `numpy.divide` as it returns
+ +/- (depending on the sign of `a`) infinity if `b` is zero.
+ Required for compatibility if `core` is not compiled for debugging purposes.
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : float
+ Scalar
+ b : float
+ Scalar
+
+ Returns
+ -------
+ c : float
+ Scalar
+
+ """
+
+ if b == 0:
+ return inf if a >= 0 else -inf
+ return a / b
+
+
+@hjit("V(V,V)", inline=True)
+def div_VV_hf(a, b):
+ """
+ Division of two 3D vectors element-wise. Similar to `numpy.divide` as
+ it returns +/- (depending on the sign of `a`) infinity if `b` is zero.
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : tuple[float,float,float]
+ Vector
+ b : tuple[float,float,float]
+ Vector
+
+ Returns
+ -------
+ c : tuple[float,float,float]
+ Vector
+
+ """
+
+ return div_ss_hf(a[0], b[0]), div_ss_hf(a[1], b[1]), div_ss_hf(a[2], b[2])
+
+
+@hjit("V(V,f)", inline=True)
+def div_Vs_hf(a, b):
+ """
+ Division of a 3D vector by a scalar element-wise. Similar to `numpy.divide` as
+ it returns +/- (depending on the sign of `a`) infinity if `b` is zero.
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : tuple[float,float,float]
+ Vector
+ b : float
+ Scalar
+
+ Returns
+ -------
+ c : tuple[float,float,float]
+ Vector
+
+ """
+
+ return div_ss_hf(a[0], b), div_ss_hf(a[1], b), div_ss_hf(a[2], b)
+
+
+@hjit("M(M,M)", inline=True)
+def matmul_MM_hf(a, b):
+ """
+ Matmul (dot product) between two 3x3 matrices.
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : tuple[tuple[float,float,float],tuple[float,float,float],tuple[float,float,float]]
+ Matrix
+ b : tuple[tuple[float,float,float],tuple[float,float,float],tuple[float,float,float]]
+ Matrix
+
+ Returns
+ -------
+ c : tuple[tuple[float,float,float],tuple[float,float,float],tuple[float,float,float]]
+ Matrix
+
+ """
+
+ return (
+ (
+ a[0][0] * b[0][0] + a[0][1] * b[1][0] + a[0][2] * b[2][0],
+ a[0][0] * b[0][1] + a[0][1] * b[1][1] + a[0][2] * b[2][1],
+ a[0][0] * b[0][2] + a[0][1] * b[1][2] + a[0][2] * b[2][2],
+ ),
+ (
+ a[1][0] * b[0][0] + a[1][1] * b[1][0] + a[1][2] * b[2][0],
+ a[1][0] * b[0][1] + a[1][1] * b[1][1] + a[1][2] * b[2][1],
+ a[1][0] * b[0][2] + a[1][1] * b[1][2] + a[1][2] * b[2][2],
+ ),
+ (
+ a[2][0] * b[0][0] + a[2][1] * b[1][0] + a[2][2] * b[2][0],
+ a[2][0] * b[0][1] + a[2][1] * b[1][1] + a[2][2] * b[2][1],
+ a[2][0] * b[0][2] + a[2][1] * b[1][2] + a[2][2] * b[2][2],
+ ),
+ )
+
+
+@hjit("V(V,M)", inline=True)
+def matmul_VM_hf(a, b):
+ """
+ Matmul (dot product) between a 3D row vector and a 3x3 matrix
+ resulting in a 3D vector.
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : tuple[float,float,float]
+ Vector
+ b : tuple[tuple[float,float,float],tuple[float,float,float],tuple[float,float,float]]
+ Matrix
+
+ Returns
+ -------
+ c : tuple[float,float,float]
+ Vector
+
+ """
+
+ return (
+ a[0] * b[0][0] + a[1] * b[1][0] + a[2] * b[2][0],
+ a[0] * b[0][1] + a[1] * b[1][1] + a[2] * b[2][1],
+ a[0] * b[0][2] + a[1] * b[1][2] + a[2] * b[2][2],
+ )
+
+
+@hjit("V(M,V)", inline=True)
+def matmul_MV_hf(a, b):
+ """
+ Matmul (dot product) between a 3x3 matrix and a 3D column vector
+ resulting in a 3D vector.
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : tuple[tuple[float,float,float],tuple[float,float,float],tuple[float,float,float]]
+ Matrix
+ b : tuple[float,float,float]
+ Vector
+
+ Returns
+ -------
+ c : tuple[float,float,float]
+ Vector
+
+ """
+
+ return (
+ b[0] * a[0][0] + b[1] * a[0][1] + b[2] * a[0][2],
+ b[0] * a[1][0] + b[1] * a[1][1] + b[2] * a[1][2],
+ b[0] * a[2][0] + b[1] * a[2][1] + b[2] * a[2][2],
+ )
+
+
+@hjit("f(V,V)", inline=True)
+def matmul_VV_hf(a, b):
+ """
+ Matmul (dot product) between two 3D vectors resulting in a scalar.
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : tuple[float,float,float]
+ Vector
+ b : tuple[float,float,float]
+ Vector
+
+ Returns
+ -------
+ c : float
+ Scalar
+
+ """
+
+ return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]
+
+
+@hjit("V(V,V)", inline=True)
+def max_VV_hf(a, b):
+ """
+ Max elements element-wise from two 3D vectors.
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : tuple[float,float,float]
+ Vector
+ b : tuple[float,float,float]
+ Vector
+
+ Returns
+ -------
+ c : tuple[float,float,float]
+ Vector
+
+ """
+
+ return (
+ a[0] if a[0] > b[0] else b[0],
+ a[1] if a[1] > b[1] else b[1],
+ a[2] if a[2] > b[2] else b[2],
+ )
+
+
+@hjit("V(V,f)", inline=True)
+def mul_Vs_hf(a, b):
+ """
+ Multiplication of a 3D vector by a scalar element-wise.
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : tuple[float,float,float]
+ Vector
+ b : float
+ Scalar
+
+ Returns
+ -------
+ c : tuple[float,float,float]
+ Vector
+
+ """
+
+ return a[0] * b, a[1] * b, a[2] * b
+
+
+@hjit("V(V,V)", inline=True)
+def mul_VV_hf(a, b):
+ """
+ Multiplication of two 3D vectors element-wise.
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : tuple[float,float,float]
+ Vector
+ b : tuple[float,float,float]
+ Vector
+
+ Returns
+ -------
+ c : tuple[float,float,float]
+ Vector
+
+ """
+
+ return a[0] * b[0], a[1] * b[1], a[2] * b[2]
+
+
+@hjit("f(V)", inline=True)
+def norm_V_hf(a):
+ """
+ Norm of a 3D vector.
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : tuple[float,float,float]
+ Vector
+
+ Returns
+ -------
+ b : float
+ Scalar
+
+ """
+
+ return sqrt(matmul_VV_hf(a, a))
+
+
+@vjit("f(f,f,f)")
+def norm_V_vf(a, b, c):
+ """
+ Norm of a 3D vector.
+
+ Parameters
+ ----------
+ a : float
+ First dimension scalar
+ b : float
+ Second dimension scalar
+ c : float
+ Third dimension scalar
+
+ Returns
+ -------
+ d : float
+ Scalar
+
+ """
+
+ return norm_V_hf((a, b, c))
+
+
+@hjit("f(V,V)", inline=True)
+def norm_VV_hf(a, b):
+ """
+ Combined norm of two 3D vectors treated like a single 6D vector.
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : tuple[float,float,float]
+ Vector
+ b : tuple[float,float,float]
+ Vector
+
+ Returns
+ -------
+ c : float
+ Scalar
+
+ """
+
+ return sqrt(matmul_VV_hf(a, a) + matmul_VV_hf(b, b))
+
+
+@hjit("f(f)", inline=True)
+def sign_hf(a):
+ """
+ Sign of a float represented as another float (-1, 0, +1).
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : float
+ Scalar
+
+ Returns
+ -------
+ b : float
+ Scalar
+
+ """
+
+ if a < 0.0:
+ return -1.0
+ if a == 0.0:
+ return 0.0
+ return 1.0 # if x > 0
+
+
+@hjit("V(V,V)", inline=True)
+def sub_VV_hf(a, b):
+ """
+ Subtraction of two 3D vectors element-wise.
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : tuple[float,float,float]
+ Vector
+ b : tuple[float,float,float]
+ Vector
+
+ Returns
+ -------
+ c : tuple[float,float,float]
+ Vector
+
+ """
+
+ return a[0] - b[0], a[1] - b[1], a[2] - b[2]
+
+
+@hjit("M(M)", inline=True)
+def transpose_M_hf(a):
+ """
+ Transposition of a matrix.
+ Inline-compiled by default.
+
+ Parameters
+ ----------
+ a : tuple[tuple[float,float,float],tuple[float,float,float],tuple[float,float,float]]
+ Matrix
+
+ Returns
+ -------
+ b : tuple[tuple[float,float,float],tuple[float,float,float],tuple[float,float,float]]
+ Matrix
+
+ """
+
+ return (
+ (a[0][0], a[1][0], a[2][0]),
+ (a[0][1], a[1][1], a[2][1]),
+ (a[0][2], a[1][2], a[2][2]),
+ )
diff --git a/src/hapsira/core/math/special.py b/src/hapsira/core/math/special.py
new file mode 100644
index 000000000..8c021630f
--- /dev/null
+++ b/src/hapsira/core/math/special.py
@@ -0,0 +1,214 @@
+from math import cos, cosh, gamma, inf, sin, sinh, sqrt
+
+from ..jit import hjit, vjit
+
+__all__ = [
+ "hyp2f1b_hf",
+ "hyp2f1b_vf",
+ "stumpff_c2_hf",
+ "stumpff_c2_vf",
+ "stumpff_c3_hf",
+ "stumpff_c3_vf",
+]
+
+
+@hjit("f(f)")
+def hyp2f1b_hf(x):
+ """
+ Hypergeometric function 2F1(3, 1, 5/2, x), see [Battin].
+
+ .. todo::
+ Add more information about this function
+
+ Notes
+ -----
+ More information about hypergeometric function can be checked at
+ https://en.wikipedia.org/wiki/Hypergeometric_function
+
+ Parameters
+ ----------
+ x : float
+ Scalar
+
+ Returns
+ -------
+ res : float
+ Scalar
+
+ """
+
+ if x >= 1.0:
+ return inf
+
+ res = 1.0
+ term = 1.0
+ ii = 0
+ while True:
+ term = term * (3 + ii) * (1 + ii) / (5 / 2 + ii) * x / (ii + 1)
+ res_old = res
+ res += term
+ if res_old == res:
+ return res
+ ii += 1
+
+
+@vjit("f(f)")
+def hyp2f1b_vf(x):
+ """
+ Hypergeometric function 2F1(3, 1, 5/2, x), see [Battin].
+
+ .. todo::
+ Add more information about this function
+
+ Notes
+ -----
+ More information about hypergeometric function can be checked at
+ https://en.wikipedia.org/wiki/Hypergeometric_function
+
+ Parameters
+ ----------
+ x : float
+ Scalar
+
+ Returns
+ -------
+ b : float
+ Scalar
+
+ """
+
+ return hyp2f1b_hf(x)
+
+
+@hjit("f(f)")
+def stumpff_c2_hf(psi):
+ r"""
+ Second Stumpff function.
+
+ For positive arguments:
+
+ .. math::
+
+ c_2(\psi) = \frac{1 - \cos{\sqrt{\psi}}}{\psi}
+
+ Parameters
+ ----------
+ psi : float
+ Scalar
+
+ Returns
+ -------
+ res : float
+ Scalar
+
+ """
+
+ eps = 1.0
+
+ if psi > eps:
+ return (1 - cos(sqrt(psi))) / psi
+
+ if psi < -eps:
+ return (cosh(sqrt(-psi)) - 1) / (-psi)
+
+ res = 1.0 / 2.0
+ delta = (-psi) / gamma(2 + 2 + 1)
+ k = 1
+ while res + delta != res:
+ res = res + delta
+ k += 1
+ delta = (-psi) ** k / gamma(2 * k + 2 + 1)
+ return res
+
+
+@vjit("f(f)")
+def stumpff_c2_vf(psi):
+ r"""
+ Second Stumpff function.
+
+ For positive arguments:
+
+ .. math::
+
+ c_2(\psi) = \frac{1 - \cos{\sqrt{\psi}}}{\psi}
+
+ Parameters
+ ----------
+ psi : float
+ Scalar
+
+ Returns
+ -------
+ res : float
+ Scalar
+
+ """
+
+ return stumpff_c2_hf(psi)
+
+
+@hjit("f(f)")
+def stumpff_c3_hf(psi):
+ r"""
+ Third Stumpff function.
+
+ For positive arguments:
+
+ .. math::
+
+ c_3(\psi) = \frac{\sqrt{\psi} - \sin{\sqrt{\psi}}}{\sqrt{\psi^3}}
+
+ Parameters
+ ----------
+ psi : float
+ Scalar
+
+ Returns
+ -------
+ res : float
+ Scalar
+
+ """
+
+ eps = 1.0
+
+ if psi > eps:
+ return (sqrt(psi) - sin(sqrt(psi))) / (psi * sqrt(psi))
+
+ if psi < -eps:
+ return (sinh(sqrt(-psi)) - sqrt(-psi)) / (-psi * sqrt(-psi))
+
+ res = 1.0 / 6.0
+ delta = (-psi) / gamma(2 + 3 + 1)
+ k = 1
+ while res + delta != res:
+ res = res + delta
+ k += 1
+ delta = (-psi) ** k / gamma(2 * k + 3 + 1)
+ return res
+
+
+@vjit("f(f)")
+def stumpff_c3_vf(psi):
+ r"""
+ Third Stumpff function.
+
+ For positive arguments:
+
+ .. math::
+
+ c_3(\psi) = \frac{\sqrt{\psi} - \sin{\sqrt{\psi}}}{\sqrt{\psi^3}}
+
+ Parameters
+ ----------
+ psi : float
+ Scalar
+
+ Returns
+ -------
+ res : float
+ Scalar
+
+ """
+
+ return stumpff_c3_hf(psi)
diff --git a/src/hapsira/core/perturbations.py b/src/hapsira/core/perturbations.py
index eb76a3582..4d154fe91 100644
--- a/src/hapsira/core/perturbations.py
+++ b/src/hapsira/core/perturbations.py
@@ -1,12 +1,22 @@
-from numba import njit as jit
-import numpy as np
+from math import exp, pow as pow_
-from hapsira._math.linalg import norm
-from hapsira.core.events import line_of_sight as line_of_sight_fast
+from .events import line_of_sight_hf
+from .jit import hjit
+from .math.linalg import norm_V_hf, mul_Vs_hf, mul_VV_hf, sub_VV_hf
-@jit
-def J2_perturbation(t0, state, k, J2, R):
+__all__ = [
+ "J2_perturbation_hf",
+ "J3_perturbation_hf",
+ "atmospheric_drag_exponential_hf",
+ "atmospheric_drag_hf",
+ "third_body_hf",
+ "radiation_pressure_hf",
+]
+
+
+@hjit("V(f,V,V,f,f,f)")
+def J2_perturbation_hf(t0, rr, vv, k, J2, R):
r"""Calculates J2_perturbation acceleration (km/s2).
.. math::
@@ -19,8 +29,10 @@ def J2_perturbation(t0, state, k, J2, R):
----------
t0 : float
Current time (s)
- state : numpy.ndarray
- Six component state vector [x, y, z, vx, vy, vz] (km, km/s).
+ rr : tuple[float,float,float]
+ Vector [x, y, z] (km)
+ vv : tuple[float,float,float]
+ Vector [vx, vy, vz] (km/s)
k : float
Standard Gravitational parameter. (km^3/s^2)
J2 : float
@@ -34,27 +46,26 @@ def J2_perturbation(t0, state, k, J2, R):
Howard Curtis, (12.30)
"""
- r_vec = state[:3]
- r = norm(r_vec)
-
- factor = (3.0 / 2.0) * k * J2 * (R**2) / (r**5)
+ r = norm_V_hf(rr)
- a_x = 5.0 * r_vec[2] ** 2 / r**2 - 1
- a_y = 5.0 * r_vec[2] ** 2 / r**2 - 1
- a_z = 5.0 * r_vec[2] ** 2 / r**2 - 3
- return np.array([a_x, a_y, a_z]) * r_vec * factor
+ factor = 1.5 * k * J2 * R * R / pow_(r, 5)
+ a_base = 5.0 * rr[2] * rr[2] / (r * r)
+ a = a_base - 1, a_base - 1, a_base - 3
+ return mul_Vs_hf(mul_VV_hf(a, rr), factor)
-@jit
-def J3_perturbation(t0, state, k, J3, R):
+@hjit("V(f,V,V,f,f,f)")
+def J3_perturbation_hf(t0, rr, vv, k, J3, R):
r"""Calculates J3_perturbation acceleration (km/s2).
Parameters
----------
t0 : float
Current time (s)
- state : numpy.ndarray
- Six component state vector [x, y, z, vx, vy, vz] (km, km/s).
+ rr : tuple[float,float,float]
+ Vector [x, y, z] (km)
+ vv : tuple[float,float,float]
+ Vector [vx, vy, vz] (km/s)
k : float
Standard Gravitational parameter. (km^3/s^2)
J3 : float
@@ -66,23 +77,22 @@ def J3_perturbation(t0, state, k, J3, R):
-----
The J3 accounts for the oblateness of the attractor. The formula is given in
Howard Curtis, problem 12.8
- This perturbation has not been fully validated, see https://github.com/hapsira/hapsira/pull/398
+ This perturbation has not been fully validated, see https://github.com/poliastro/poliastro/pull/398
"""
- r_vec = state[:3]
- r = norm(r_vec)
+ r = norm_V_hf(rr)
factor = (1.0 / 2.0) * k * J3 * (R**3) / (r**5)
- cos_phi = r_vec[2] / r
+ cos_phi = rr[2] / r
- a_x = 5.0 * r_vec[0] / r * (7.0 * cos_phi**3 - 3.0 * cos_phi)
- a_y = 5.0 * r_vec[1] / r * (7.0 * cos_phi**3 - 3.0 * cos_phi)
+ a_x = 5.0 * rr[0] / r * (7.0 * cos_phi**3 - 3.0 * cos_phi)
+ a_y = 5.0 * rr[1] / r * (7.0 * cos_phi**3 - 3.0 * cos_phi)
a_z = 3.0 * (35.0 / 3.0 * cos_phi**4 - 10.0 * cos_phi**2 + 1)
- return np.array([a_x, a_y, a_z]) * factor
+ return a_x * factor, a_y * factor, a_z * factor
-@jit
-def atmospheric_drag_exponential(t0, state, k, R, C_D, A_over_m, H0, rho0):
+@hjit("V(f,V,V,f,f,f,f,f,f)")
+def atmospheric_drag_exponential_hf(t0, rr, vv, k, R, C_D, A_over_m, H0, rho0):
r"""Calculates atmospheric drag acceleration (km/s2).
.. math::
@@ -95,8 +105,10 @@ def atmospheric_drag_exponential(t0, state, k, R, C_D, A_over_m, H0, rho0):
----------
t0 : float
Current time (s)
- state : numpy.ndarray
- Six component state vector [x, y, z, vx, vy, vz] (km, km/s).
+ rr : tuple[float,float,float]
+ Vector [x, y, z] (km)
+ vv : tuple[float,float,float]
+ Vector [vx, vy, vz] (km/s)
k : float
Standard Gravitational parameter (km^3/s^2).
R : float
@@ -118,18 +130,17 @@ def atmospheric_drag_exponential(t0, state, k, R, C_D, A_over_m, H0, rho0):
the atmospheric density model is rho(H) = rho0 x exp(-H / H0)
"""
- H = norm(state[:3])
+ H = norm_V_hf(rr)
- v_vec = state[3:]
- v = norm(v_vec)
+ v = norm_V_hf(vv)
B = C_D * A_over_m
- rho = rho0 * np.exp(-(H - R) / H0)
+ rho = rho0 * exp(-(H - R) / H0)
- return -(1.0 / 2.0) * rho * B * v * v_vec
+ return mul_Vs_hf(vv, -(1.0 / 2.0) * rho * B * v)
-@jit
-def atmospheric_drag(t0, state, k, C_D, A_over_m, rho):
+@hjit("V(f,V,V,f,f,f,f)")
+def atmospheric_drag_hf(t0, rr, vv, k, C_D, A_over_m, rho):
r"""Calculates atmospheric drag acceleration (km/s2).
.. math::
@@ -142,8 +153,10 @@ def atmospheric_drag(t0, state, k, C_D, A_over_m, rho):
----------
t0 : float
Current time (s).
- state : numpy.ndarray
- Six component state vector [x, y, z, vx, vy, vz] (km, km/s).
+ rr : tuple[float,float,float]
+ Vector [x, y, z] (km)
+ vv : tuple[float,float,float]
+ Vector [vx, vy, vz] (km/s)
k : float
Standard Gravitational parameter (km^3/s^2)
C_D : float
@@ -159,14 +172,14 @@ def atmospheric_drag(t0, state, k, C_D, A_over_m, rho):
computed by a model from hapsira.earth.atmosphere
"""
- v_vec = state[3:]
- v = norm(v_vec)
+ v = norm_V_hf(vv)
B = C_D * A_over_m
- return -(1.0 / 2.0) * rho * B * v * v_vec
+ return mul_Vs_hf(vv, -(1.0 / 2.0) * rho * B * v)
-def third_body(t0, state, k, k_third, perturbation_body):
+@hjit("V(f,V,V,f,f,F(V(f)))")
+def third_body_hf(t0, rr, vv, k, k_third, perturbation_body):
r"""Calculate third body acceleration (km/s2).
.. math::
@@ -177,8 +190,10 @@ def third_body(t0, state, k, k_third, perturbation_body):
----------
t0 : float
Current time (s).
- state : numpy.ndarray
- Six component state vector [x, y, z, vx, vy, vz] (km, km/s).
+ rr : tuple[float,float,float]
+ Vector [x, y, z] (km)
+ vv : tuple[float,float,float]
+ Vector [vx, vy, vz] (km/s)
k : float
Standard Gravitational parameter of the attractor (km^3/s^2).
k_third : float
@@ -194,11 +209,15 @@ def third_body(t0, state, k, k_third, perturbation_body):
"""
body_r = perturbation_body(t0)
- delta_r = body_r - state[:3]
- return k_third * delta_r / norm(delta_r) ** 3 - k_third * body_r / norm(body_r) ** 3
+ delta_r = sub_VV_hf(body_r, rr)
+ return sub_VV_hf(
+ mul_Vs_hf(delta_r, k_third / norm_V_hf(delta_r) ** 3),
+ mul_Vs_hf(body_r, k_third / norm_V_hf(body_r) ** 3),
+ )
-def radiation_pressure(t0, state, k, R, C_R, A_over_m, Wdivc_s, star):
+@hjit("V(f,V,V,f,f,f,f,f,F(V(f)))")
+def radiation_pressure_hf(t0, rr, vv, k, R, C_R, A_over_m, Wdivc_s, star):
r"""Calculates radiation pressure acceleration (km/s2).
.. math::
@@ -209,8 +228,10 @@ def radiation_pressure(t0, state, k, R, C_R, A_over_m, Wdivc_s, star):
----------
t0 : float
Current time (s).
- state : numpy.ndarray
- Six component state vector [x, y, z, vx, vy, vz] (km, km/s).
+ rr : tuple[float,float,float]
+ Vector [x, y, z] (km)
+ vv : tuple[float,float,float]
+ Vector [vx, vy, vz] (km/s)
k : float
Standard Gravitational parameter (km^3/s^2).
R : float
@@ -232,8 +253,10 @@ def radiation_pressure(t0, state, k, R, C_R, A_over_m, Wdivc_s, star):
"""
r_star = star(t0)
- r_sat = state[:3]
- P_s = Wdivc_s / (norm(r_star) ** 2)
+ P_s = Wdivc_s / (norm_V_hf(r_star) ** 2)
- nu = float(line_of_sight_fast(r_sat, r_star, R) > 0)
- return -nu * P_s * (C_R * A_over_m) * r_star / norm(r_star)
+ if line_of_sight_hf(rr, r_star, R) > 0:
+ nu = 1.0
+ else:
+ nu = 0.0
+ return mul_Vs_hf(r_star, -nu * P_s * (C_R * A_over_m) / norm_V_hf(r_star))
diff --git a/src/hapsira/core/propagation/__init__.py b/src/hapsira/core/propagation/__init__.py
index 5460386d5..c7bda8291 100644
--- a/src/hapsira/core/propagation/__init__.py
+++ b/src/hapsira/core/propagation/__init__.py
@@ -1,35 +1 @@
"""Low level propagation algorithms."""
-
-from hapsira.core.propagation.base import func_twobody
-from hapsira.core.propagation.cowell import cowell
-from hapsira.core.propagation.danby import danby, danby_coe
-from hapsira.core.propagation.farnocchia import (
- farnocchia_coe,
- farnocchia_rv as farnocchia,
-)
-from hapsira.core.propagation.gooding import gooding, gooding_coe
-from hapsira.core.propagation.markley import markley, markley_coe
-from hapsira.core.propagation.mikkola import mikkola, mikkola_coe
-from hapsira.core.propagation.pimienta import pimienta, pimienta_coe
-from hapsira.core.propagation.recseries import recseries, recseries_coe
-from hapsira.core.propagation.vallado import vallado
-
-__all__ = [
- "cowell",
- "func_twobody",
- "farnocchia_coe",
- "farnocchia",
- "vallado",
- "mikkola_coe",
- "mikkola",
- "markley_coe",
- "markley",
- "pimienta_coe",
- "pimienta",
- "gooding_coe",
- "gooding",
- "danby_coe",
- "danby",
- "recseries_coe",
- "recseries",
-]
diff --git a/src/hapsira/core/propagation/base.py b/src/hapsira/core/propagation/base.py
index 5d289a5cb..f96066ae2 100644
--- a/src/hapsira/core/propagation/base.py
+++ b/src/hapsira/core/propagation/base.py
@@ -1,23 +1,31 @@
-from numba import njit as jit
-import numpy as np
+from math import pow as pow_
+from ..jit import djit
-@jit
-def func_twobody(t0, u_, k):
+
+__all__ = [
+ "func_twobody_hf",
+]
+
+
+@djit
+def func_twobody_hf(t0, rr, vv, k):
"""Differential equation for the initial value two body problem.
Parameters
----------
t0 : float
Time.
- u_ : numpy.ndarray
- Six component state vector [x, y, z, vx, vy, vz] (km, km/s).
+ rr : tuple[float,float,float]
+ Position vector
+ vv : tuple[float,float,float]
+ Velocity vector.
k : float
Standard gravitational parameter.
"""
- x, y, z, vx, vy, vz = u_
- r3 = (x**2 + y**2 + z**2) ** 1.5
+ x, y, z = rr
+ vx, vy, vz = vv
+ r3 = pow_(x * x + y * y + z * z, 1.5)
- du = np.array([vx, vy, vz, -k * x / r3, -k * y / r3, -k * z / r3])
- return du
+ return (vx, vy, vz), (-k * x / r3, -k * y / r3, -k * z / r3)
diff --git a/src/hapsira/core/propagation/cowell.py b/src/hapsira/core/propagation/cowell.py
index 876545b11..28b85def7 100644
--- a/src/hapsira/core/propagation/cowell.py
+++ b/src/hapsira/core/propagation/cowell.py
@@ -1,49 +1,234 @@
-import numpy as np
+from math import isnan, nan
+from typing import Callable, Tuple
-from hapsira._math.ivp import DOP853, solve_ivp
-from hapsira.core.propagation.base import func_twobody
+from ..jit import gjit, array_to_V_hf
+from ..math.ieee754 import EPS
+from ..math.ivp import (
+ BRENTQ_CONVERGED,
+ BRENTQ_MAXITER,
+ DENSE_SIG,
+ DOP853_FINISHED,
+ DOP853_FAILED,
+ DOP853_ARGK,
+ DOP853_FR,
+ DOP853_FUN,
+ DOP853_FV,
+ DOP853_H_PREVIOUS,
+ DOP853_K,
+ DOP853_RR,
+ DOP853_RR_OLD,
+ DOP853_STATUS,
+ DOP853_T,
+ DOP853_T_OLD,
+ DOP853_VV,
+ DOP853_VV_OLD,
+ brentq_dense_hf,
+ event_is_active_hf,
+ dispatcher_hb,
+ dop853_dense_interp_hf,
+ dop853_dense_output_hf,
+ dop853_init_hf,
+ dop853_step_hf,
+)
+from ..propagation.base import func_twobody_hf
-def cowell(k, r, v, tofs, rtol=1e-11, *, events=None, f=func_twobody):
- x, y, z = r
- vx, vy, vz = v
+__all__ = [
+ "cowell_gb",
+ "SOLVE_BRENTQFAILED",
+ "SOLVE_FAILED",
+ "SOLVE_RUNNING",
+ "SOLVE_FINISHED",
+ "SOLVE_TERMINATED",
+]
- u0 = np.array([x, y, z, vx, vy, vz])
- result = solve_ivp(
- f,
- (0, max(tofs)),
- u0,
- args=(k,),
- rtol=rtol,
- atol=1e-12,
- method=DOP853,
- dense_output=True,
- events=events,
+SOLVE_BRENTQFAILED = -3
+SOLVE_FAILED = -2
+SOLVE_RUNNING = -1
+SOLVE_FINISHED = 0
+SOLVE_TERMINATED = 1
+
+
+def cowell_gb(
+ events: Tuple = tuple(),
+ func: Callable = func_twobody_hf,
+) -> Callable:
+ """
+ Builds vectorized cowell
+ """
+
+ assert hasattr(func, "djit") # DEBUG check for compiler flag
+
+ EVENTS = len(events)
+
+ event_impl_hf = dispatcher_hb(
+ funcs=tuple(event.impl_hf for event in events),
+ argtypes="f,V,V,f",
+ restype="f",
+ arguments="t, rr, vv, k",
)
- if not result.success:
- raise RuntimeError("Integration failed")
-
- if events is not None:
- # Collect only the terminal events
- terminal_events = [event for event in events if event.terminal]
-
- # If there are no terminal events, then the last time of integration is the
- # greatest one from the original array of propagation times
- if not terminal_events:
- last_t = max(tofs)
- else:
- # Filter the event which triggered first
- last_t = min(event._last_t for event in terminal_events)
- # FIXME: Here last_t has units, but tofs don't
- tofs = [tof for tof in tofs if tof < last_t] + [last_t]
-
- rrs = []
- vvs = []
- for i in range(len(tofs)):
- t = tofs[i]
- y = result.sol(t)
- rrs.append(y[:3])
- vvs.append(y[3:])
-
- return rrs, vvs
+ event_impl_dense_hf = dispatcher_hb(
+ funcs=tuple(event.impl_dense_hf for event in events),
+ argtypes=f"f,{DENSE_SIG:s},f",
+ restype="f",
+ arguments="t, t_old, h, rr_old, vv_old, F, argk",
+ )
+
+ @gjit(
+ "void(f[:],f[:],f[:],f,f,f,b1[:],f[:],f[:],f[:],f[:],f[:],i8[:],i8[:],f[:,:],f[:,:])",
+ "(n),(m),(m),(),(),(),(o),(o)->(o),(o),(o),(o),(),(),(n,m),(n,m)",
+ cache=False,
+ )
+ def cowell_gf(
+ tofs,
+ rr,
+ vv,
+ argk,
+ rtol,
+ atol,
+ event_terminals,
+ event_directions,
+ event_g_olds,
+ event_g_news,
+ event_actives,
+ event_last_ts,
+ status,
+ t_idx,
+ rrs,
+ vvs,
+ ):
+ """
+ Solve an initial value problem for a system of ODEs.
+
+ Can theoretically be reversed: https://github.com/poliastro/poliastro/issues/1630
+ """
+
+ # assert isinstance(rtol, float)
+ # assert all(tof >= 0 for tof in tofs)
+ # assert sorted(tofs) == list(tofs)
+
+ T0 = 0.0
+
+ solver = dop853_init_hf(
+ func, T0, array_to_V_hf(rr), array_to_V_hf(vv), tofs[-1], argk, rtol, atol
+ )
+
+ t_idx[0] = 0
+ t_last = T0
+
+ for event_idx in range(EVENTS):
+ event_g_olds[event_idx] = event_impl_hf(
+ event_idx, T0, array_to_V_hf(rr), array_to_V_hf(vv), argk
+ )
+ event_last_ts[event_idx] = T0
+
+ status[0] = SOLVE_RUNNING
+ while status[0] == SOLVE_RUNNING:
+ solver = dop853_step_hf(*solver)
+
+ if solver[DOP853_STATUS] == DOP853_FINISHED:
+ status[0] = SOLVE_FINISHED
+ elif solver[DOP853_STATUS] == DOP853_FAILED:
+ status[0] = SOLVE_FAILED
+ break
+
+ t_old = solver[DOP853_T_OLD]
+ t = solver[DOP853_T]
+
+ interpolant = dop853_dense_output_hf(
+ solver[DOP853_FUN],
+ solver[DOP853_ARGK],
+ solver[DOP853_T_OLD],
+ solver[DOP853_T],
+ solver[DOP853_H_PREVIOUS],
+ solver[DOP853_RR],
+ solver[DOP853_VV],
+ solver[DOP853_RR_OLD],
+ solver[DOP853_VV_OLD],
+ solver[DOP853_FR],
+ solver[DOP853_FV],
+ solver[DOP853_K],
+ )
+
+ at_least_one_active = False
+ for event_idx in range(EVENTS):
+ event_g_news[event_idx] = event_impl_hf(
+ event_idx, t, solver[DOP853_RR], solver[DOP853_VV], argk
+ )
+ event_last_ts[event_idx] = t
+ event_actives[event_idx] = event_is_active_hf(
+ event_g_olds[event_idx],
+ event_g_news[event_idx],
+ event_directions[event_idx],
+ )
+ if event_actives[event_idx]:
+ at_least_one_active = True
+
+ if at_least_one_active:
+ root_pivot = nan # set initial value
+ terminate = False
+
+ for event_idx in range(EVENTS):
+ if not event_actives[event_idx]:
+ continue
+
+ if not event_terminals[event_idx]:
+ continue
+
+ terminate = True
+
+ event_last_ts[event_idx], root, brentq_status = brentq_dense_hf(
+ event_impl_dense_hf,
+ event_idx,
+ t_old,
+ t,
+ 4 * EPS,
+ 4 * EPS,
+ BRENTQ_MAXITER,
+ *interpolant,
+ argk,
+ )
+ if brentq_status != BRENTQ_CONVERGED:
+ status[0] = SOLVE_BRENTQFAILED
+ return # failed on event
+
+ if isnan(root_pivot):
+ root_pivot = root
+ continue
+
+ if t > t_old: # smallest root of all active events
+ if root < root_pivot:
+ root_pivot = root
+ continue
+
+ # largest root of all active events
+ if root > root_pivot:
+ root_pivot = root
+ raise ValueError("not t > t_old", t, t_old) # TODO remove
+
+ if terminate:
+ assert not isnan(root_pivot)
+ status[0] = SOLVE_TERMINATED
+ t = root_pivot
+
+ for event_idx in range(EVENTS):
+ event_g_olds[event_idx] = event_g_news[event_idx]
+
+ if not t_last <= t:
+ raise ValueError("not t_last <= t")
+
+ while t_idx[0] < tofs.shape[0] and tofs[t_idx[0]] < t:
+ rrs[t_idx[0], :], vvs[t_idx[0], :] = dop853_dense_interp_hf(
+ tofs[t_idx[0]], *interpolant
+ )
+ t_idx[0] += 1
+ if status[0] == SOLVE_TERMINATED or tofs[t_idx[0]] == t:
+ rrs[t_idx[0], :], vvs[t_idx[0], :] = dop853_dense_interp_hf(
+ t, *interpolant
+ )
+ t_idx[0] += 1
+
+ t_last = t
+
+ return cowell_gf
diff --git a/src/hapsira/core/propagation/danby.py b/src/hapsira/core/propagation/danby.py
index f6a71c2cc..d428569b9 100644
--- a/src/hapsira/core/propagation/danby.py
+++ b/src/hapsira/core/propagation/danby.py
@@ -1,63 +1,82 @@
-from numba import njit as jit
-import numpy as np
+from math import atan2, cos, cosh, floor, log, pi, sin, sinh, sqrt
-from hapsira.core.angles import E_to_M, F_to_M, nu_to_E, nu_to_F
-from hapsira.core.elements import coe2rv, rv2coe
+from ..angles import E_to_M_hf, F_to_M_hf, nu_to_E_hf, nu_to_F_hf
+from ..elements import coe2rv_hf, rv2coe_hf, RV2COE_TOL
+from ..math.linalg import sign_hf
+from ..jit import array_to_V_hf, hjit, gjit, vjit
-@jit
-def danby_coe(k, p, ecc, inc, raan, argp, nu, tof, numiter=20, rtol=1e-8):
+__all__ = [
+ "danby_coe_hf",
+ "danby_coe_vf",
+ "danby_rv_hf",
+ "danby_rv_gf",
+ "DANBY_NUMITER",
+ "DANBY_RTOL",
+]
+
+
+DANBY_NUMITER = 20
+DANBY_RTOL = 1e-8
+
+
+@hjit("f(f,f,f,f,f,f,f,f,i8,f)")
+def danby_coe_hf(k, p, ecc, inc, raan, argp, nu, tof, numiter, rtol):
+ """
+ Scalar danby_coe
+ """
+
semi_axis_a = p / (1 - ecc**2)
- n = np.sqrt(k / np.abs(semi_axis_a) ** 3)
+ n = sqrt(k / abs(semi_axis_a) ** 3)
if ecc == 0:
# Solving for circular orbit
M0 = nu # for circular orbit M = E = nu
M = M0 + n * tof
- nu = M - 2 * np.pi * np.floor(M / 2 / np.pi)
+ nu = M - 2 * pi * floor(M / 2 / pi)
return nu
elif ecc < 1.0:
# For elliptical orbit
- M0 = E_to_M(nu_to_E(nu, ecc), ecc)
+ M0 = E_to_M_hf(nu_to_E_hf(nu, ecc), ecc)
M = M0 + n * tof
- xma = M - 2 * np.pi * np.floor(M / 2 / np.pi)
- E = xma + 0.85 * np.sign(np.sin(xma)) * ecc
+ xma = M - 2 * pi * floor(M / 2 / pi)
+ E = xma + 0.85 * sign_hf(sin(xma)) * ecc
else:
# For parabolic and hyperbolic
- M0 = F_to_M(nu_to_F(nu, ecc), ecc)
+ M0 = F_to_M_hf(nu_to_F_hf(nu, ecc), ecc)
M = M0 + n * tof
- xma = M - 2 * np.pi * np.floor(M / 2 / np.pi)
- E = np.log(2 * xma / ecc + 1.8)
+ xma = M - 2 * pi * floor(M / 2 / pi)
+ E = log(2 * xma / ecc + 1.8)
# Iterations begin
n = 0
while n <= numiter:
if ecc < 1.0:
- s = ecc * np.sin(E)
- c = ecc * np.cos(E)
+ s = ecc * sin(E)
+ c = ecc * cos(E)
f = E - s - xma
fp = 1 - c
fpp = s
fppp = c
else:
- s = ecc * np.sinh(E)
- c = ecc * np.cosh(E)
+ s = ecc * sinh(E)
+ c = ecc * cosh(E)
f = s - E - xma
fp = c - 1
fpp = s
fppp = c
- if np.abs(f) <= rtol:
+ if abs(f) <= rtol:
if ecc < 1.0:
- sta = np.sqrt(1 - ecc**2) * np.sin(E)
- cta = np.cos(E) - ecc
+ sta = sqrt(1 - ecc**2) * sin(E)
+ cta = cos(E) - ecc
else:
- sta = np.sqrt(ecc**2 - 1) * np.sinh(E)
- cta = ecc - np.cosh(E)
+ sta = sqrt(ecc**2 - 1) * sinh(E)
+ cta = ecc - cosh(E)
- nu = np.arctan2(sta, cta)
+ nu = atan2(sta, cta)
break
else:
delta = -f / fp
@@ -71,8 +90,17 @@ def danby_coe(k, p, ecc, inc, raan, argp, nu, tof, numiter=20, rtol=1e-8):
return nu
-@jit
-def danby(k, r0, v0, tof, numiter=20, rtol=1e-8):
+@vjit("f(f,f,f,f,f,f,f,f,i8,f)")
+def danby_coe_vf(k, p, ecc, inc, raan, argp, nu, tof, numiter, rtol):
+ """
+ Vectorized danby_coe
+ """
+
+ return danby_coe_hf(k, p, ecc, inc, raan, argp, nu, tof, numiter, rtol)
+
+
+@hjit("Tuple([V,V])(f,V,V,f,i8,f)")
+def danby_rv_hf(k, r0, v0, tof, numiter, rtol):
"""Kepler solver for both elliptic and parabolic orbits based on Danby's
algorithm.
@@ -80,9 +108,9 @@ def danby(k, r0, v0, tof, numiter=20, rtol=1e-8):
----------
k : float
Standard gravitational parameter of the attractor.
- r0 : numpy.ndarray
+ r0 : tuple[float,float,float]
Position vector.
- v0 : numpy.ndarray
+ v0 : tuple[float,float,float]
Velocity vector.
tof : float
Time of flight.
@@ -93,9 +121,9 @@ def danby(k, r0, v0, tof, numiter=20, rtol=1e-8):
Returns
-------
- rr : numpy.ndarray
+ rr : tuple[float,float,float]
Final position vector.
- vv : numpy.ndarray
+ vv : tuple[float,float,float]
Final velocity vector.
Notes
@@ -104,7 +132,26 @@ def danby(k, r0, v0, tof, numiter=20, rtol=1e-8):
Equation* with DOI: https://doi.org/10.1007/BF01686811
"""
# Solve first for eccentricity and mean anomaly
- p, ecc, inc, raan, argp, nu = rv2coe(k, r0, v0)
- nu = danby_coe(k, p, ecc, inc, raan, argp, nu, tof, numiter, rtol)
+ p, ecc, inc, raan, argp, nu = rv2coe_hf(k, r0, v0, RV2COE_TOL)
+ nu = danby_coe_hf(k, p, ecc, inc, raan, argp, nu, tof, numiter, rtol)
+
+ return coe2rv_hf(k, p, ecc, inc, raan, argp, nu)
+
+
+@gjit(
+ "void(f,f[:],f[:],f,i8,f,f[:],f[:])",
+ "(),(n),(n),(),(),()->(n),(n)",
+)
+def danby_rv_gf(k, r0, v0, tof, numiter, rtol, rr, vv):
+ """
+ Vectorized danby_rv
+ """
- return coe2rv(k, p, ecc, inc, raan, argp, nu)
+ (rr[0], rr[1], rr[2]), (vv[0], vv[1], vv[2]) = danby_rv_hf(
+ k,
+ array_to_V_hf(r0),
+ array_to_V_hf(v0),
+ tof,
+ numiter,
+ rtol,
+ )
diff --git a/src/hapsira/core/propagation/farnocchia.py b/src/hapsira/core/propagation/farnocchia.py
index d20e7124b..e135c13c2 100644
--- a/src/hapsira/core/propagation/farnocchia.py
+++ b/src/hapsira/core/propagation/farnocchia.py
@@ -1,38 +1,47 @@
-from numba import njit as jit
-import numpy as np
-
-from hapsira.core.angles import (
- D_to_M,
- D_to_nu,
- E_to_M,
- E_to_nu,
- F_to_M,
- F_to_nu,
- M_to_D,
- M_to_E,
- M_to_F,
- nu_to_D,
- nu_to_E,
- nu_to_F,
+from math import acos, acosh, cos, cosh, nan, pi, sqrt
+
+from ..angles import (
+ D_to_M_hf,
+ D_to_nu_hf,
+ E_to_M_hf,
+ E_to_nu_hf,
+ F_to_M_hf,
+ F_to_nu_hf,
+ M_to_D_hf,
+ M_to_E_hf,
+ M_to_F_hf,
+ nu_to_D_hf,
+ nu_to_E_hf,
+ nu_to_F_hf,
)
-from hapsira.core.elements import coe2rv, rv2coe
+from ..elements import coe2rv_hf, rv2coe_hf, RV2COE_TOL
+from ..jit import array_to_V_hf, hjit, vjit, gjit
-@jit
-def _kepler_equation_near_parabolic(D, M, ecc):
- return D_to_M_near_parabolic(D, ecc) - M
+__all__ = [
+ "delta_t_from_nu_hf",
+ "delta_t_from_nu_vf",
+ "farnocchia_coe_hf",
+ "farnocchia_coe_vf",
+ "farnocchia_rv_hf",
+ "farnocchia_rv_gf",
+ "FARNOCCHIA_K",
+ "FARNOCCHIA_Q",
+ "FARNOCCHIA_DELTA",
+]
-@jit
-def _kepler_equation_prime_near_parabolic(D, M, ecc):
- x = (ecc - 1.0) / (ecc + 1.0) * (D**2)
- assert abs(x) < 1
- S = dS_x_alt(ecc, x)
- return np.sqrt(2.0 / (1.0 + ecc)) + np.sqrt(2.0 / (1.0 + ecc) ** 3) * (D**2) * S
+FARNOCCHIA_K = 1.0
+FARNOCCHIA_Q = 1.0
+FARNOCCHIA_DELTA = 1e-2
+_ATOL = 1e-12
+_TOL = 1.48e-08
+_MAXITER = 50
-@jit
-def S_x(ecc, x, atol=1e-12):
+
+@hjit("f(f,f,f)")
+def _S_x_hf(ecc, x, atol):
assert abs(x) < 1
S = 0
k = 0
@@ -44,8 +53,8 @@ def S_x(ecc, x, atol=1e-12):
return S
-@jit
-def dS_x_alt(ecc, x, atol=1e-12):
+@hjit("f(f,f,f)")
+def _dS_x_alt_hf(ecc, x, atol):
# Notice that this is not exactly
# the partial derivative of S with respect to D,
# but the result of arranging the terms
@@ -61,8 +70,8 @@ def dS_x_alt(ecc, x, atol=1e-12):
return S
-@jit
-def d2S_x_alt(ecc, x, atol=1e-12):
+@hjit("f(f,f,f)")
+def _d2S_x_alt_hf(ecc, x, atol):
# Notice that this is not exactly
# the second partial derivative of S with respect to D,
# but the result of arranging the terms
@@ -79,18 +88,29 @@ def d2S_x_alt(ecc, x, atol=1e-12):
return S
-@jit
-def D_to_M_near_parabolic(D, ecc):
+@hjit("f(f,f,f)")
+def _kepler_equation_prime_near_parabolic_hf(D, M, ecc):
x = (ecc - 1.0) / (ecc + 1.0) * (D**2)
assert abs(x) < 1
- S = S_x(ecc, x)
- return (
- np.sqrt(2.0 / (1.0 + ecc)) * D + np.sqrt(2.0 / (1.0 + ecc) ** 3) * (D**3) * S
- )
+ S = _dS_x_alt_hf(ecc, x, _ATOL)
+ return sqrt(2.0 / (1.0 + ecc)) + sqrt(2.0 / (1.0 + ecc) ** 3) * (D**2) * S
+
+
+@hjit("f(f,f)")
+def _D_to_M_near_parabolic_hf(D, ecc):
+ x = (ecc - 1.0) / (ecc + 1.0) * (D**2)
+ assert abs(x) < 1
+ S = _S_x_hf(ecc, x, _ATOL)
+ return sqrt(2.0 / (1.0 + ecc)) * D + sqrt(2.0 / (1.0 + ecc) ** 3) * (D**3) * S
-@jit
-def M_to_D_near_parabolic(M, ecc, tol=1.48e-08, maxiter=50):
+@hjit("f(f,f,f)")
+def _kepler_equation_near_parabolic_hf(D, M, ecc):
+ return _D_to_M_near_parabolic_hf(D, ecc) - M
+
+
+@hjit("f(f,f,f,i8)")
+def _M_to_D_near_parabolic_hf(M, ecc, tol, maxiter):
"""Parabolic eccentric anomaly from mean anomaly, near parabolic case.
Parameters
@@ -110,11 +130,11 @@ def M_to_D_near_parabolic(M, ecc, tol=1.48e-08, maxiter=50):
Parabolic eccentric anomaly.
"""
- D0 = M_to_D(M)
+ D0 = M_to_D_hf(M)
for _ in range(maxiter):
- fval = _kepler_equation_near_parabolic(D0, M, ecc)
- fder = _kepler_equation_prime_near_parabolic(D0, M, ecc)
+ fval = _kepler_equation_near_parabolic_hf(D0, M, ecc)
+ fder = _kepler_equation_prime_near_parabolic_hf(D0, M, ecc)
newton_step = fval / fder
D = D0 - newton_step
@@ -123,11 +143,11 @@ def M_to_D_near_parabolic(M, ecc, tol=1.48e-08, maxiter=50):
D0 = D
- return np.nan
+ return nan
-@jit
-def delta_t_from_nu(nu, ecc, k=1.0, q=1.0, delta=1e-2):
+@hjit("f(f,f,f,f,f)")
+def delta_t_from_nu_hf(nu, ecc, k, q, delta):
"""Time elapsed since periapsis for given true anomaly.
Parameters
@@ -149,62 +169,71 @@ def delta_t_from_nu(nu, ecc, k=1.0, q=1.0, delta=1e-2):
Time elapsed since periapsis.
"""
- assert -np.pi <= nu < np.pi
+ assert -pi <= nu < pi
if ecc < 1 - delta:
# Strong elliptic
- E = nu_to_E(nu, ecc) # (-pi, pi]
- M = E_to_M(E, ecc) # (-pi, pi]
- n = np.sqrt(k * (1 - ecc) ** 3 / q**3)
+ E = nu_to_E_hf(nu, ecc) # (-pi, pi]
+ M = E_to_M_hf(E, ecc) # (-pi, pi]
+ n = sqrt(k * (1 - ecc) ** 3 / q**3)
elif 1 - delta <= ecc < 1:
- E = nu_to_E(nu, ecc) # (-pi, pi]
- if delta <= 1 - ecc * np.cos(E):
+ E = nu_to_E_hf(nu, ecc) # (-pi, pi]
+ if delta <= 1 - ecc * cos(E):
# Strong elliptic
- M = E_to_M(E, ecc) # (-pi, pi]
- n = np.sqrt(k * (1 - ecc) ** 3 / q**3)
+ M = E_to_M_hf(E, ecc) # (-pi, pi]
+ n = sqrt(k * (1 - ecc) ** 3 / q**3)
else:
# Near parabolic
- D = nu_to_D(nu) # (-∞, ∞)
+ D = nu_to_D_hf(nu) # (-∞, ∞)
# If |nu| is far from pi this result is bounded
# because the near parabolic region shrinks in its vicinity,
# otherwise the eccentricity is very close to 1
# and we are really far away
- M = D_to_M_near_parabolic(D, ecc)
- n = np.sqrt(k / (2 * q**3))
+ M = _D_to_M_near_parabolic_hf(D, ecc)
+ n = sqrt(k / (2 * q**3))
elif ecc == 1:
# Parabolic
- D = nu_to_D(nu) # (-∞, ∞)
- M = D_to_M(D) # (-∞, ∞)
- n = np.sqrt(k / (2 * q**3))
- elif 1 + ecc * np.cos(nu) < 0:
+ D = nu_to_D_hf(nu) # (-∞, ∞)
+ M = D_to_M_hf(D) # (-∞, ∞)
+ n = sqrt(k / (2 * q**3))
+ elif 1 + ecc * cos(nu) < 0:
# Unfeasible region
- return np.nan
+ return nan
elif 1 < ecc <= 1 + delta:
# NOTE: Do we need to wrap nu here?
# For hyperbolic orbits, it should anyway be in
# (-arccos(-1 / ecc), +arccos(-1 / ecc))
- F = nu_to_F(nu, ecc) # (-∞, ∞)
- if delta <= ecc * np.cosh(F) - 1:
+ F = nu_to_F_hf(nu, ecc) # (-∞, ∞)
+ if delta <= ecc * cosh(F) - 1:
# Strong hyperbolic
- M = F_to_M(F, ecc) # (-∞, ∞)
- n = np.sqrt(k * (ecc - 1) ** 3 / q**3)
+ M = F_to_M_hf(F, ecc) # (-∞, ∞)
+ n = sqrt(k * (ecc - 1) ** 3 / q**3)
else:
# Near parabolic
- D = nu_to_D(nu) # (-∞, ∞)
- M = D_to_M_near_parabolic(D, ecc) # (-∞, ∞)
- n = np.sqrt(k / (2 * q**3))
+ D = nu_to_D_hf(nu) # (-∞, ∞)
+ M = _D_to_M_near_parabolic_hf(D, ecc) # (-∞, ∞)
+ n = sqrt(k / (2 * q**3))
elif 1 + delta < ecc:
# Strong hyperbolic
- F = nu_to_F(nu, ecc) # (-∞, ∞)
- M = F_to_M(F, ecc) # (-∞, ∞)
- n = np.sqrt(k * (ecc - 1) ** 3 / q**3)
+ F = nu_to_F_hf(nu, ecc) # (-∞, ∞)
+ M = F_to_M_hf(F, ecc) # (-∞, ∞)
+ n = sqrt(k * (ecc - 1) ** 3 / q**3)
else:
raise RuntimeError
return M / n
-@jit
-def nu_from_delta_t(delta_t, ecc, k=1.0, q=1.0, delta=1e-2):
+@vjit("f(f,f,f,f,f)")
+def delta_t_from_nu_vf(nu, ecc, k, q, delta):
+ """
+ Vectorized delta_t_from_nu
+ """
+
+ return delta_t_from_nu_hf(nu, ecc, k, q, delta)
+
+
+@hjit("f(f,f,f,f,f)")
+def _nu_from_delta_t_hf(delta_t, ecc, k, q, delta):
"""True anomaly for given elapsed time since periapsis.
Parameters
@@ -228,77 +257,90 @@ def nu_from_delta_t(delta_t, ecc, k=1.0, q=1.0, delta=1e-2):
"""
if ecc < 1 - delta:
# Strong elliptic
- n = np.sqrt(k * (1 - ecc) ** 3 / q**3)
+ n = sqrt(k * (1 - ecc) ** 3 / q**3)
M = n * delta_t
# This might represent several revolutions,
# so we wrap the true anomaly
- E = M_to_E((M + np.pi) % (2 * np.pi) - np.pi, ecc)
- nu = E_to_nu(E, ecc)
+ E = M_to_E_hf((M + pi) % (2 * pi) - pi, ecc)
+ nu = E_to_nu_hf(E, ecc)
elif 1 - delta <= ecc < 1:
- E_delta = np.arccos((1 - delta) / ecc)
+ E_delta = acos((1 - delta) / ecc)
# We compute M assuming we are in the strong elliptic case
# and verify later
- n = np.sqrt(k * (1 - ecc) ** 3 / q**3)
+ n = sqrt(k * (1 - ecc) ** 3 / q**3)
M = n * delta_t
# We check against abs(M) because E_delta could also be negative
- if E_to_M(E_delta, ecc) <= abs(M):
+ if E_to_M_hf(E_delta, ecc) <= abs(M):
# Strong elliptic, proceed
# This might represent several revolutions,
# so we wrap the true anomaly
- E = M_to_E((M + np.pi) % (2 * np.pi) - np.pi, ecc)
- nu = E_to_nu(E, ecc)
+ E = M_to_E_hf((M + pi) % (2 * pi) - pi, ecc)
+ nu = E_to_nu_hf(E, ecc)
else:
# Near parabolic, recompute M
- n = np.sqrt(k / (2 * q**3))
+ n = sqrt(k / (2 * q**3))
M = n * delta_t
- D = M_to_D_near_parabolic(M, ecc)
- nu = D_to_nu(D)
+ D = _M_to_D_near_parabolic_hf(M, ecc, _TOL, _MAXITER)
+ nu = D_to_nu_hf(D)
elif ecc == 1:
# Parabolic
- n = np.sqrt(k / (2 * q**3))
+ n = sqrt(k / (2 * q**3))
M = n * delta_t
- D = M_to_D(M)
- nu = D_to_nu(D)
+ D = M_to_D_hf(M)
+ nu = D_to_nu_hf(D)
elif 1 < ecc <= 1 + delta:
- F_delta = np.arccosh((1 + delta) / ecc)
+ F_delta = acosh((1 + delta) / ecc)
# We compute M assuming we are in the strong hyperbolic case
# and verify later
- n = np.sqrt(k * (ecc - 1) ** 3 / q**3)
+ n = sqrt(k * (ecc - 1) ** 3 / q**3)
M = n * delta_t
# We check against abs(M) because F_delta could also be negative
- if F_to_M(F_delta, ecc) <= abs(M):
+ if F_to_M_hf(F_delta, ecc) <= abs(M):
# Strong hyperbolic, proceed
- F = M_to_F(M, ecc)
- nu = F_to_nu(F, ecc)
+ F = M_to_F_hf(M, ecc)
+ nu = F_to_nu_hf(F, ecc)
else:
# Near parabolic, recompute M
- n = np.sqrt(k / (2 * q**3))
+ n = sqrt(k / (2 * q**3))
M = n * delta_t
- D = M_to_D_near_parabolic(M, ecc)
- nu = D_to_nu(D)
+ D = _M_to_D_near_parabolic_hf(M, ecc, _TOL, _MAXITER)
+ nu = D_to_nu_hf(D)
# elif 1 + delta < ecc:
else:
# Strong hyperbolic
- n = np.sqrt(k * (ecc - 1) ** 3 / q**3)
+ n = sqrt(k * (ecc - 1) ** 3 / q**3)
M = n * delta_t
- F = M_to_F(M, ecc)
- nu = F_to_nu(F, ecc)
+ F = M_to_F_hf(M, ecc)
+ nu = F_to_nu_hf(F, ecc)
return nu
-@jit
-def farnocchia_coe(k, p, ecc, inc, raan, argp, nu, tof):
- q = p / (1 + ecc)
+@hjit("f(f,f,f,f,f,f,f,f)")
+def farnocchia_coe_hf(k, p, ecc, inc, raan, argp, nu, tof):
+ """
+ Scalar farnocchia_coe
+ """
- delta_t0 = delta_t_from_nu(nu, ecc, k, q)
+ q = p / (1.0 + ecc)
+
+ delta_t0 = delta_t_from_nu_hf(nu, ecc, k, q, FARNOCCHIA_DELTA)
delta_t = delta_t0 + tof
- return nu_from_delta_t(delta_t, ecc, k, q)
+ return _nu_from_delta_t_hf(delta_t, ecc, k, q, FARNOCCHIA_DELTA)
+
+
+@vjit("f(f,f,f,f,f,f,f,f)")
+def farnocchia_coe_vf(k, p, ecc, inc, raan, argp, nu, tof):
+ """
+ Vectorized farnocchia_coe
+ """
+
+ return farnocchia_coe_hf(k, p, ecc, inc, raan, argp, nu, tof)
-@jit
-def farnocchia_rv(k, r0, v0, tof):
+@hjit("Tuple([V,V])(f,V,V,f)")
+def farnocchia_rv_hf(k, r0, v0, tof):
r"""Propagates orbit using mean motion.
This algorithm depends on the geometric shape of the orbit.
@@ -314,9 +356,9 @@ def farnocchia_rv(k, r0, v0, tof):
----------
k : float
Standar Gravitational parameter
- r0 : numpy.ndarray
+ r0 : tuple[float,float,float]
Initial position vector wrt attractor center.
- v0 : numpy.ndarray
+ v0 : tuple[float,float,float]
Initial velocity vector.
tof : float
Time of flight (s).
@@ -329,7 +371,21 @@ def farnocchia_rv(k, r0, v0, tof):
"""
# get the initial true anomaly and orbit parameters that are constant over time
- p, ecc, inc, raan, argp, nu0 = rv2coe(k, r0, v0)
- nu = farnocchia_coe(k, p, ecc, inc, raan, argp, nu0, tof)
+ p, ecc, inc, raan, argp, nu0 = rv2coe_hf(k, r0, v0, RV2COE_TOL)
+ nu = farnocchia_coe_hf(k, p, ecc, inc, raan, argp, nu0, tof)
- return coe2rv(k, p, ecc, inc, raan, argp, nu)
+ return coe2rv_hf(k, p, ecc, inc, raan, argp, nu)
+
+
+@gjit(
+ "void(f,f[:],f[:],f,f[:],f[:])",
+ "(),(n),(n),()->(n),(n)",
+)
+def farnocchia_rv_gf(k, r0, v0, tof, rr, vv):
+ """
+ Vectorized farnocchia_rv
+ """
+
+ (rr[0], rr[1], rr[2]), (vv[0], vv[1], vv[2]) = farnocchia_rv_hf(
+ k, array_to_V_hf(r0), array_to_V_hf(v0), tof
+ )
diff --git a/src/hapsira/core/propagation/gooding.py b/src/hapsira/core/propagation/gooding.py
index 59f53bc0b..bd290f394 100644
--- a/src/hapsira/core/propagation/gooding.py
+++ b/src/hapsira/core/propagation/gooding.py
@@ -1,32 +1,49 @@
-from numba import njit as jit
-import numpy as np
+from math import cos, sin, sqrt
-from hapsira.core.angles import E_to_M, E_to_nu, nu_to_E
-from hapsira.core.elements import coe2rv, rv2coe
+from ..angles import E_to_M_hf, E_to_nu_hf, nu_to_E_hf
+from ..elements import coe2rv_hf, rv2coe_hf, RV2COE_TOL
+from ..jit import array_to_V_hf, hjit, vjit, gjit
-@jit
-def gooding_coe(k, p, ecc, inc, raan, argp, nu, tof, numiter=150, rtol=1e-8):
+__all__ = [
+ "gooding_coe_hf",
+ "gooding_coe_vf",
+ "gooding_rv_hf",
+ "gooding_rv_gf",
+ "GOODING_NUMITER",
+ "GOODING_RTOL",
+]
+
+
+GOODING_NUMITER = 150
+GOODING_RTOL = 1e-8
+
+
+@hjit("f(f,f,f,f,f,f,f,f,i8,f)")
+def gooding_coe_hf(k, p, ecc, inc, raan, argp, nu, tof, numiter, rtol):
+ """
+ Scalar gooding_coe
+ """
# TODO: parabolic and hyperbolic not implemented cases
if ecc >= 1.0:
raise NotImplementedError(
"Parabolic/Hyperbolic cases still not implemented in gooding."
)
- M0 = E_to_M(nu_to_E(nu, ecc), ecc)
+ M0 = E_to_M_hf(nu_to_E_hf(nu, ecc), ecc)
semi_axis_a = p / (1 - ecc**2)
- n = np.sqrt(k / np.abs(semi_axis_a) ** 3)
+ n = sqrt(k / abs(semi_axis_a) ** 3)
M = M0 + n * tof
# Start the computation
n = 0
- c = ecc * np.cos(M)
- s = ecc * np.sin(M)
- psi = s / np.sqrt(1 - 2 * c + ecc**2)
+ c = ecc * cos(M)
+ s = ecc * sin(M)
+ psi = s / sqrt(1 - 2 * c + ecc**2)
f = 1.0
while f**2 >= rtol and n <= numiter:
- xi = np.cos(psi)
- eta = np.sin(psi)
+ xi = cos(psi)
+ eta = sin(psi)
fd = (1 - c * xi) + s * eta
fdd = c * eta + s * xi
f = psi - fdd
@@ -34,11 +51,20 @@ def gooding_coe(k, p, ecc, inc, raan, argp, nu, tof, numiter=150, rtol=1e-8):
n += 1
E = M + psi
- return E_to_nu(E, ecc)
+ return E_to_nu_hf(E, ecc)
+
+
+@vjit("f(f,f,f,f,f,f,f,f,i8,f)")
+def gooding_coe_vf(k, p, ecc, inc, raan, argp, nu, tof, numiter, rtol):
+ """
+ Vectorized gooding_coe
+ """
+ return gooding_coe_hf(k, p, ecc, inc, raan, argp, nu, tof, numiter, rtol)
-@jit
-def gooding(k, r0, v0, tof, numiter=150, rtol=1e-8):
+
+@hjit("Tuple([V,V])(f,V,V,f,i8,f)")
+def gooding_rv_hf(k, r0, v0, tof, numiter, rtol):
"""Solves the Elliptic Kepler Equation with a cubic convergence and
accuracy better than 10e-12 rad is normally achieved. It is not valid for
eccentricities equal or higher than 1.0.
@@ -47,9 +73,9 @@ def gooding(k, r0, v0, tof, numiter=150, rtol=1e-8):
----------
k : float
Standard gravitational parameter of the attractor.
- r0 : numpy.ndarray
+ r0 : tuple[float,float,float]
Position vector.
- v0 : numpy.ndarray
+ v0 : tuple[float,float,float]
Velocity vector.
tof : float
Time of flight.
@@ -60,16 +86,27 @@ def gooding(k, r0, v0, tof, numiter=150, rtol=1e-8):
Returns
-------
- rr : numpy.ndarray
+ rr : tuple[float,float,float]
Final position vector.
- vv : numpy.ndarray
+ vv : tuple[float,float,float]
Final velocity vector.
Note
----
Original paper for the algorithm: https://doi.org/10.1007/BF01238923
"""
# Solve first for eccentricity and mean anomaly
- p, ecc, inc, raan, argp, nu = rv2coe(k, r0, v0)
- nu = gooding_coe(k, p, ecc, inc, raan, argp, nu, tof, numiter, rtol)
+ p, ecc, inc, raan, argp, nu = rv2coe_hf(k, r0, v0, RV2COE_TOL)
+ nu = gooding_coe_hf(k, p, ecc, inc, raan, argp, nu, tof, numiter, rtol)
+
+ return coe2rv_hf(k, p, ecc, inc, raan, argp, nu)
+
+
+@gjit("void(f,f[:],f[:],f,i8,f,f[:],f[:])", "(),(n),(n),(),(),()->(),()")
+def gooding_rv_gf(k, r0, v0, tof, numiter, rtol, rr, vv):
+ """
+ Vectorized gooding_rv
+ """
- return coe2rv(k, p, ecc, inc, raan, argp, nu)
+ (rr[0], rr[1], rr[2]), (vv[0], vv[1], vv[2]) = gooding_rv_hf(
+ k, array_to_V_hf(r0), array_to_V_hf(v0), tof, numiter, rtol
+ )
diff --git a/src/hapsira/core/propagation/markley.py b/src/hapsira/core/propagation/markley.py
index 2cedb5352..743f70d73 100644
--- a/src/hapsira/core/propagation/markley.py
+++ b/src/hapsira/core/propagation/markley.py
@@ -1,28 +1,40 @@
-from numba import njit as jit
-import numpy as np
-
-from hapsira.core.angles import (
- E_to_M,
- E_to_nu,
- _kepler_equation,
- _kepler_equation_prime,
- nu_to_E,
+from math import cos, pi, sin, sqrt
+
+from ..angles import (
+ E_to_M_hf,
+ E_to_nu_hf,
+ kepler_equation_hf,
+ kepler_equation_prime_hf,
+ nu_to_E_hf,
)
-from hapsira.core.elements import coe2rv, rv2coe
+from ..elements import coe2rv_hf, rv2coe_hf, RV2COE_TOL
+from ..jit import array_to_V_hf, hjit, vjit, gjit
-@jit
-def markley_coe(k, p, ecc, inc, raan, argp, nu, tof):
- M0 = E_to_M(nu_to_E(nu, ecc), ecc)
+__all__ = [
+ "markley_coe_hf",
+ "markley_coe_vf",
+ "markley_rv_hf",
+ "markley_rv_gf",
+]
+
+
+@hjit("f(f,f,f,f,f,f,f,f)")
+def markley_coe_hf(k, p, ecc, inc, raan, argp, nu, tof):
+ """
+ Scalar markley_coe
+ """
+
+ M0 = E_to_M_hf(nu_to_E_hf(nu, ecc), ecc)
a = p / (1 - ecc**2)
- n = np.sqrt(k / a**3)
+ n = sqrt(k / a**3)
M = M0 + n * tof
# Range between -pi and pi
- M = (M + np.pi) % (2 * np.pi) - np.pi
+ M = (M + pi) % (2 * pi) - pi
# Equation (20)
- alpha = (3 * np.pi**2 + 1.6 * (np.pi - np.abs(M)) / (1 + ecc)) / (np.pi**2 - 6)
+ alpha = (3 * pi**2 + 1.6 * (pi - abs(M)) / (1 + ecc)) / (pi**2 - 6)
# Equation (5)
d = 3 * (1 - ecc) + alpha * ecc
@@ -34,16 +46,16 @@ def markley_coe(k, p, ecc, inc, raan, argp, nu, tof):
r = 3 * alpha * d * (d - 1 + ecc) * M + M**3
# Equation (14)
- w = (np.abs(r) + np.sqrt(q**3 + r**2)) ** (2 / 3)
+ w = (abs(r) + sqrt(q**3 + r**2)) ** (2 / 3)
# Equation (15)
E = (2 * r * w / (w**2 + w * q + q**2) + M) / d
# Equation (26)
- f0 = _kepler_equation(E, M, ecc)
- f1 = _kepler_equation_prime(E, M, ecc)
- f2 = ecc * np.sin(E)
- f3 = ecc * np.cos(E)
+ f0 = kepler_equation_hf(E, M, ecc)
+ f1 = kepler_equation_prime_hf(E, M, ecc)
+ f2 = ecc * sin(E)
+ f3 = ecc * cos(E)
f4 = -f2
# Equation (22)
@@ -54,13 +66,22 @@ def markley_coe(k, p, ecc, inc, raan, argp, nu, tof):
)
E += delta5
- nu = E_to_nu(E, ecc)
+ nu = E_to_nu_hf(E, ecc)
return nu
-@jit
-def markley(k, r0, v0, tof):
+@vjit("f(f,f,f,f,f,f,f,f)")
+def markley_coe_vf(k, p, ecc, inc, raan, argp, nu, tof):
+ """
+ Vectorized markley_coe
+ """
+
+ return markley_coe_hf(k, p, ecc, inc, raan, argp, nu, tof)
+
+
+@hjit("Tuple([V,V])(f,V,V,f)")
+def markley_rv_hf(k, r0, v0, tof):
"""Solves the kepler problem by a non-iterative method. Relative error is
around 1e-18, only limited by machine double-precision errors.
@@ -68,18 +89,18 @@ def markley(k, r0, v0, tof):
----------
k : float
Standar Gravitational parameter.
- r0 : numpy.ndarray
+ r0 : tuple[float,float,float]
Initial position vector wrt attractor center.
- v0 : numpy.ndarray
+ v0 : tuple[float,float,float]
Initial velocity vector.
tof : float
Time of flight.
Returns
-------
- rr: numpy.ndarray
+ rr: tuple[float,float,float]
Final position vector.
- vv: numpy.ndarray
+ vv: tuple[float,float,float]
Final velocity vector.
Notes
@@ -88,7 +109,18 @@ def markley(k, r0, v0, tof):
"""
# Solve first for eccentricity and mean anomaly
- p, ecc, inc, raan, argp, nu = rv2coe(k, r0, v0)
- nu = markley_coe(k, p, ecc, inc, raan, argp, nu, tof)
+ p, ecc, inc, raan, argp, nu = rv2coe_hf(k, r0, v0, RV2COE_TOL)
+ nu = markley_coe_hf(k, p, ecc, inc, raan, argp, nu, tof)
+
+ return coe2rv_hf(k, p, ecc, inc, raan, argp, nu)
+
- return coe2rv(k, p, ecc, inc, raan, argp, nu)
+@gjit("void(f,f[:],f[:],f,f[:],f[:])", "(),(n),(n),()->(n),(n)")
+def markley_rv_gf(k, r0, v0, tof, rr, vv):
+ """
+ Vectorized markley_rv
+ """
+
+ (rr[0], rr[1], rr[2]), (vv[0], vv[1], vv[2]) = markley_rv_hf(
+ k, array_to_V_hf(r0), array_to_V_hf(v0), tof
+ )
diff --git a/src/hapsira/core/propagation/mikkola.py b/src/hapsira/core/propagation/mikkola.py
index f40d18412..5e4a87605 100644
--- a/src/hapsira/core/propagation/mikkola.py
+++ b/src/hapsira/core/propagation/mikkola.py
@@ -1,40 +1,52 @@
-from numba import njit as jit
-import numpy as np
-
-from hapsira.core.angles import (
- D_to_nu,
- E_to_M,
- E_to_nu,
- F_to_M,
- F_to_nu,
- nu_to_E,
- nu_to_F,
+from math import cos, cosh, log, sin, sinh, sqrt
+
+from ..angles import (
+ D_to_nu_hf,
+ E_to_M_hf,
+ E_to_nu_hf,
+ F_to_M_hf,
+ F_to_nu_hf,
+ nu_to_E_hf,
+ nu_to_F_hf,
)
-from hapsira.core.elements import coe2rv, rv2coe
+from ..elements import coe2rv_hf, rv2coe_hf, RV2COE_TOL
+from ..jit import array_to_V_hf, hjit, vjit, gjit
-@jit
-def mikkola_coe(k, p, ecc, inc, raan, argp, nu, tof):
+__all__ = [
+ "mikkola_coe_hf",
+ "mikkola_coe_vf",
+ "mikkola_rv_hf",
+ "mikkola_rv_gf",
+]
+
+
+@hjit("f(f,f,f,f,f,f,f,f)")
+def mikkola_coe_hf(k, p, ecc, inc, raan, argp, nu, tof):
+ """
+ Scalar mikkola_coe
+ """
+
a = p / (1 - ecc**2)
- n = np.sqrt(k / np.abs(a) ** 3)
+ n = sqrt(k / abs(a) ** 3)
# Solve for specific geometrical case
if ecc < 1.0:
# Equation (9a)
alpha = (1 - ecc) / (4 * ecc + 1 / 2)
- M0 = E_to_M(nu_to_E(nu, ecc), ecc)
+ M0 = E_to_M_hf(nu_to_E_hf(nu, ecc), ecc)
else:
alpha = (ecc - 1) / (4 * ecc + 1 / 2)
- M0 = F_to_M(nu_to_F(nu, ecc), ecc)
+ M0 = F_to_M_hf(nu_to_F_hf(nu, ecc), ecc)
M = M0 + n * tof
beta = M / 2 / (4 * ecc + 1 / 2)
# Equation (9b)
if beta >= 0:
- z = (beta + np.sqrt(beta**2 + alpha**3)) ** (1 / 3)
+ z = (beta + sqrt(beta**2 + alpha**3)) ** (1 / 3)
else:
- z = (beta - np.sqrt(beta**2 + alpha**3)) ** (1 / 3)
+ z = (beta - sqrt(beta**2 + alpha**3)) ** (1 / 3)
s = z - alpha / z
@@ -49,18 +61,18 @@ def mikkola_coe(k, p, ecc, inc, raan, argp, nu, tof):
# Solving for the true anomaly
if ecc < 1.0:
E = M + ecc * (3 * s - 4 * s**3)
- f = E - ecc * np.sin(E) - M
- f1 = 1.0 - ecc * np.cos(E)
- f2 = ecc * np.sin(E)
- f3 = ecc * np.cos(E)
+ f = E - ecc * sin(E) - M
+ f1 = 1.0 - ecc * cos(E)
+ f2 = ecc * sin(E)
+ f3 = ecc * cos(E)
f4 = -f2
f5 = -f3
else:
- E = 3 * np.log(s + np.sqrt(1 + s**2))
- f = -E + ecc * np.sinh(E) - M
- f1 = -1.0 + ecc * np.cosh(E)
- f2 = ecc * np.sinh(E)
- f3 = ecc * np.cosh(E)
+ E = 3 * log(s + sqrt(1 + s**2))
+ f = -E + ecc * sinh(E) - M
+ f1 = -1.0 + ecc * cosh(E)
+ f2 = ecc * sinh(E)
+ f3 = ecc * cosh(E)
f4 = f2
f5 = f3
@@ -82,47 +94,65 @@ def mikkola_coe(k, p, ecc, inc, raan, argp, nu, tof):
E += u5
if ecc < 1.0:
- nu = E_to_nu(E, ecc)
+ nu = E_to_nu_hf(E, ecc)
else:
if ecc == 1.0:
# Parabolic
- nu = D_to_nu(E)
+ nu = D_to_nu_hf(E)
else:
# Hyperbolic
- nu = F_to_nu(E, ecc)
+ nu = F_to_nu_hf(E, ecc)
return nu
-@jit
-def mikkola(k, r0, v0, tof, rtol=None):
+@vjit("f(f,f,f,f,f,f,f,f)")
+def mikkola_coe_vf(k, p, ecc, inc, raan, argp, nu, tof):
+ """
+ Vectorized mikkola_coe
+ """
+
+ return mikkola_coe_hf(k, p, ecc, inc, raan, argp, nu, tof)
+
+
+@hjit("Tuple([V,V])(f,V,V,f)")
+def mikkola_rv_hf(k, r0, v0, tof):
"""Raw algorithm for Mikkola's Kepler solver.
Parameters
----------
k : float
Standard gravitational parameter of the attractor.
- r0 : numpy.ndarray
+ r0 : tuple[float,float,float]
Position vector.
- v0 : numpy.ndarray
+ v0 : tuple[float,float,float]
Velocity vector.
tof : float
Time of flight.
- rtol : float
- This method does not require tolerance since it is non-iterative.
Returns
-------
- rr : numpy.ndarray
+ rr : tuple[float,float,float]
Final velocity vector.
- vv : numpy.ndarray
+ vv : tuple[float,float,float]
Final velocity vector.
Note
----
Original paper: https://doi.org/10.1007/BF01235850
"""
# Solving for the classical elements
- p, ecc, inc, raan, argp, nu = rv2coe(k, r0, v0)
- nu = mikkola_coe(k, p, ecc, inc, raan, argp, nu, tof)
+ p, ecc, inc, raan, argp, nu = rv2coe_hf(k, r0, v0, RV2COE_TOL)
+ nu = mikkola_coe_hf(k, p, ecc, inc, raan, argp, nu, tof)
+
+ return coe2rv_hf(k, p, ecc, inc, raan, argp, nu)
+
- return coe2rv(k, p, ecc, inc, raan, argp, nu)
+@gjit("void(f,f[:],f[:],f,f[:],f[:])", "(),(n),(n),()->(n),(n)")
+def mikkola_rv_gf(k, r0, v0, tof, rr, vv):
+ """
+ Vectorized mikkola_rv
+ """
+
+ (rr[0], rr[1], rr[2]), (vv[0], vv[1], vv[2]) = mikkola_rv_hf(
+ k, array_to_V_hf(r0), array_to_V_hf(v0), tof
+ )
diff --git a/src/hapsira/core/propagation/pimienta.py b/src/hapsira/core/propagation/pimienta.py
index 3914466ad..dfc471aa6 100644
--- a/src/hapsira/core/propagation/pimienta.py
+++ b/src/hapsira/core/propagation/pimienta.py
@@ -1,17 +1,29 @@
-from numba import njit as jit
-import numpy as np
+from math import sqrt
-from hapsira.core.angles import E_to_M, E_to_nu, nu_to_E
-from hapsira.core.elements import coe2rv, rv2coe
+from ..angles import E_to_M_hf, E_to_nu_hf, nu_to_E_hf
+from ..elements import coe2rv_hf, rv2coe_hf, RV2COE_TOL
+from ..jit import array_to_V_hf, hjit, vjit, gjit
-@jit
-def pimienta_coe(k, p, ecc, inc, raan, argp, nu, tof):
+__all__ = [
+ "pimienta_coe_hf",
+ "pimienta_coe_vf",
+ "pimienta_rv_hf",
+ "pimienta_rv_gf",
+]
+
+
+@hjit("f(f,f,f,f,f,f,f,f)")
+def pimienta_coe_hf(k, p, ecc, inc, raan, argp, nu, tof):
+ """
+ Scalar pimienta_coe
+ """
+
q = p / (1 + ecc)
# TODO: Do something to allow parabolic and hyperbolic orbits?
- n = np.sqrt(k * (1 - ecc) ** 3 / q**3)
- M0 = E_to_M(nu_to_E(nu, ecc), ecc)
+ n = sqrt(k * (1 - ecc) ** 3 / q**3)
+ M0 = E_to_M_hf(nu_to_E_hf(nu, ecc), ecc)
M = M0 + n * tof
@@ -19,7 +31,7 @@ def pimienta_coe(k, p, ecc, inc, raan, argp, nu, tof):
c3 = 5 / 2 + 560 * ecc
a = 15 * (1 - ecc) / c3
b = -M / c3
- y = np.sqrt(b**2 / 4 + a**3 / 27)
+ y = sqrt(b**2 / 4 + a**3 / 27)
# Equation (33)
x_bar = (-b / 2 + y) ** (1 / 3) - (b / 2 + y) ** (1 / 3)
@@ -333,11 +345,20 @@ def pimienta_coe(k, p, ecc, inc, raan, argp, nu, tof):
+ 15 * w
)
- return E_to_nu(E, ecc)
+ return E_to_nu_hf(E, ecc)
+
+
+@vjit("f(f,f,f,f,f,f,f,f)")
+def pimienta_coe_vf(k, p, ecc, inc, raan, argp, nu, tof):
+ """
+ Vectorized pimienta_coe
+ """
+
+ return pimienta_coe_hf(k, p, ecc, inc, raan, argp, nu, tof)
-@jit
-def pimienta(k, r0, v0, tof):
+@hjit("Tuple([V,V])(f,V,V,f)")
+def pimienta_rv_hf(k, r0, v0, tof):
"""Raw algorithm for Adonis' Pimienta and John L. Crassidis 15th order
polynomial Kepler solver.
@@ -368,7 +389,18 @@ def pimienta(k, r0, v0, tof):
# TODO: implement hyperbolic case
# Solve first for eccentricity and mean anomaly
- p, ecc, inc, raan, argp, nu = rv2coe(k, r0, v0)
- nu = pimienta_coe(k, p, ecc, inc, raan, argp, nu, tof)
+ p, ecc, inc, raan, argp, nu = rv2coe_hf(k, r0, v0, RV2COE_TOL)
+ nu = pimienta_coe_hf(k, p, ecc, inc, raan, argp, nu, tof)
+
+ return coe2rv_hf(k, p, ecc, inc, raan, argp, nu)
+
- return coe2rv(k, p, ecc, inc, raan, argp, nu)
+@gjit("void(f,f[:],f[:],f,f[:],f[:])", "(),(n),(n),()->(n),(n)")
+def pimienta_rv_gf(k, r0, v0, tof, rr, vv):
+ """
+ Vectorized pimienta_rv
+ """
+
+ (rr[0], rr[1], rr[2]), (vv[0], vv[1], vv[2]) = pimienta_rv_hf(
+ k, array_to_V_hf(r0), array_to_V_hf(v0), tof
+ )
diff --git a/src/hapsira/core/propagation/recseries.py b/src/hapsira/core/propagation/recseries.py
index ae99feba2..44ef6a92e 100644
--- a/src/hapsira/core/propagation/recseries.py
+++ b/src/hapsira/core/propagation/recseries.py
@@ -1,12 +1,32 @@
-from numba import njit as jit
-import numpy as np
+from math import floor, pi, sin, sqrt
-from hapsira.core.angles import E_to_M, E_to_nu, nu_to_E
-from hapsira.core.elements import coe2rv, rv2coe
+from ..angles import E_to_M_hf, E_to_nu_hf, nu_to_E_hf
+from ..elements import coe2rv_hf, rv2coe_hf, RV2COE_TOL
+from ..jit import array_to_V_hf, hjit, vjit, gjit
-@jit
-def recseries_coe(
+__all__ = [
+ "recseries_coe_hf",
+ "recseries_coe_vf",
+ "recseries_rv_hf",
+ "recseries_rv_gf",
+ "RECSERIES_METHOD_RTOL",
+ "RECSERIES_METHOD_ORDER",
+ "RECSERIES_ORDER",
+ "RECSERIES_NUMITER",
+ "RECSERIES_RTOL",
+]
+
+
+RECSERIES_METHOD_RTOL = 0
+RECSERIES_METHOD_ORDER = 1
+RECSERIES_ORDER = 8
+RECSERIES_NUMITER = 100
+RECSERIES_RTOL = 1e-8
+
+
+@hjit("f(f,f,f,f,f,f,f,f,i8,i8,i8,f)")
+def recseries_coe_hf(
k,
p,
ecc,
@@ -15,15 +35,19 @@ def recseries_coe(
argp,
nu,
tof,
- method="rtol",
- order=8,
- numiter=100,
- rtol=1e-8,
+ method,
+ order,
+ numiter,
+ rtol,
):
+ """
+ Scalar recseries_coe
+ """
+
# semi-major axis
semi_axis_a = p / (1 - ecc**2)
# mean angular motion
- n = np.sqrt(k / np.abs(semi_axis_a) ** 3)
+ n = sqrt(k / abs(semi_axis_a) ** 3)
if ecc == 0:
# Solving for circular orbit
@@ -33,7 +57,7 @@ def recseries_coe(
# final mean anaomaly
M = M0 + n * tof
# snapping anomaly to [0,pi] range
- nu = M - 2 * np.pi * np.floor(M / 2 / np.pi)
+ nu = M - 2 * pi * floor(M / 2 / pi)
return nu
@@ -41,16 +65,16 @@ def recseries_coe(
# Solving for elliptical orbit
# compute initial mean anoamly
- M0 = E_to_M(nu_to_E(nu, ecc), ecc)
+ M0 = E_to_M_hf(nu_to_E_hf(nu, ecc), ecc)
# final mean anaomaly
M = M0 + n * tof
# snapping anomaly to [0,pi] range
- M = M - 2 * np.pi * np.floor(M / 2 / np.pi)
+ M = M - 2 * pi * floor(M / 2 / pi)
# set recursion iteration
- if method == "rtol":
+ if method == RECSERIES_METHOD_RTOL:
Niter = numiter
- elif method == "order":
+ elif method == RECSERIES_METHOD_ORDER:
Niter = order
else:
raise ValueError("Unknown recursion termination method ('rtol','order').")
@@ -58,13 +82,13 @@ def recseries_coe(
# compute eccentric anomaly through recursive series
E = M + ecc # Using initial guess from vallado to improve convergence
for i in range(0, Niter):
- En = M + ecc * np.sin(E)
+ En = M + ecc * sin(E)
# check for break condition
if method == "rtol" and (abs(En - E) / abs(E)) < rtol:
break
E = En
- return E_to_nu(E, ecc)
+ return E_to_nu_hf(E, ecc)
else:
# Parabolic/Hyperbolic orbits are not supported
@@ -73,8 +97,43 @@ def recseries_coe(
return nu
-@jit
-def recseries(k, r0, v0, tof, method="rtol", order=8, numiter=100, rtol=1e-8):
+@vjit("f(f,f,f,f,f,f,f,f,i8,i8,i8,f)")
+def recseries_coe_vf(
+ k,
+ p,
+ ecc,
+ inc,
+ raan,
+ argp,
+ nu,
+ tof,
+ method,
+ order,
+ numiter,
+ rtol,
+):
+ """
+ Vectorized recseries_coe
+ """
+
+ return recseries_coe_hf(
+ k,
+ p,
+ ecc,
+ inc,
+ raan,
+ argp,
+ nu,
+ tof,
+ method,
+ order,
+ numiter,
+ rtol,
+ )
+
+
+@hjit("Tuple([V,V])(f,V,V,f,i8,i8,i8,f)")
+def recseries_rv_hf(k, r0, v0, tof, method, order, numiter, rtol):
"""Kepler solver for elliptical orbits with recursive series approximation
method. The order of the series is a user defined parameter.
@@ -112,9 +171,20 @@ def recseries(k, r0, v0, tof, method="rtol", order=8, numiter=100, rtol=1e-8):
with DOI: http://dx.doi.org/10.13140/RG.2.2.18578.58563/1
"""
# Solve first for eccentricity and mean anomaly
- p, ecc, inc, raan, argp, nu = rv2coe(k, r0, v0)
- nu = recseries_coe(
+ p, ecc, inc, raan, argp, nu = rv2coe_hf(k, r0, v0, RV2COE_TOL)
+ nu = recseries_coe_hf(
k, p, ecc, inc, raan, argp, nu, tof, method, order, numiter, rtol
)
- return coe2rv(k, p, ecc, inc, raan, argp, nu)
+ return coe2rv_hf(k, p, ecc, inc, raan, argp, nu)
+
+
+@gjit("void(f,f[:],f[:],f,i8,i8,i8,f,f[:],f[:])", "(),(n),(n),(),(),(),(),()->(n),(n)")
+def recseries_rv_gf(k, r0, v0, tof, method, order, numiter, rtol, rr, vv):
+ """
+ Vectorized recseries_rv
+ """
+
+ (rr[0], rr[1], rr[2]), (vv[0], vv[1], vv[2]) = recseries_rv_hf(
+ k, array_to_V_hf(r0), array_to_V_hf(v0), tof, method, order, numiter, rtol
+ )
diff --git a/src/hapsira/core/propagation/vallado.py b/src/hapsira/core/propagation/vallado.py
index d52fc3a97..df6ee56c7 100644
--- a/src/hapsira/core/propagation/vallado.py
+++ b/src/hapsira/core/propagation/vallado.py
@@ -1,12 +1,25 @@
-from numba import njit as jit
-import numpy as np
+from math import log, sqrt
-from hapsira._math.linalg import norm
-from hapsira._math.special import stumpff_c2 as c2, stumpff_c3 as c3
+from ..elements import coe2rv_hf, rv2coe_hf, RV2COE_TOL
+from ..math.linalg import add_VV_hf, matmul_VV_hf, mul_Vs_hf, norm_V_hf, sign_hf
+from ..math.special import stumpff_c2_hf, stumpff_c3_hf
+from ..jit import array_to_V_hf, hjit, vjit, gjit
-@jit
-def vallado(k, r0, v0, tof, numiter):
+__all__ = [
+ "vallado_coe_hf",
+ "vallado_coe_vf",
+ "vallado_rv_hf",
+ "vallado_rv_gf",
+ "VALLADO_NUMITER",
+]
+
+
+VALLADO_NUMITER = 350
+
+
+@hjit("Tuple([f,f,f,f])(f,V,V,f,i8)")
+def _vallado_hf(k, r0, v0, tof, numiter):
r"""Solves Kepler's Equation by applying a Newton-Raphson method.
If the position of a body along its orbit wants to be computed
@@ -45,9 +58,9 @@ def vallado(k, r0, v0, tof, numiter):
----------
k : float
Standard gravitational parameter.
- r0 : numpy.ndarray
+ r0 : tuple[float,float,float]
Initial position vector.
- v0 : numpy.ndarray
+ v0 : tuple[float,float,float]
Initial velocity vector.
tof : float
Time of flight.
@@ -72,10 +85,10 @@ def vallado(k, r0, v0, tof, numiter):
"""
# Cache some results
- dot_r0v0 = r0 @ v0
- norm_r0 = norm(r0)
+ dot_r0v0 = matmul_VV_hf(r0, v0)
+ norm_r0 = norm_V_hf(r0)
sqrt_mu = k**0.5
- alpha = -(v0 @ v0) / k + 2 / norm_r0
+ alpha = -matmul_VV_hf(v0, v0) / k + 2 / norm_r0
# First guess
if alpha > 0:
@@ -84,14 +97,11 @@ def vallado(k, r0, v0, tof, numiter):
elif alpha < 0:
# Hyperbolic orbit
xi_new = (
- np.sign(tof)
+ sign_hf(tof)
* (-1 / alpha) ** 0.5
- * np.log(
+ * log(
(-2 * k * alpha * tof)
- / (
- dot_r0v0
- + np.sign(tof) * np.sqrt(-k / alpha) * (1 - norm_r0 * alpha)
- )
+ / (dot_r0v0 + sign_hf(tof) * sqrt(-k / alpha) * (1 - norm_r0 * alpha))
)
)
else:
@@ -104,8 +114,8 @@ def vallado(k, r0, v0, tof, numiter):
while count < numiter:
xi = xi_new
psi = xi * xi * alpha
- c2_psi = c2(psi)
- c3_psi = c3(psi)
+ c2_psi = stumpff_c2_hf(psi)
+ c3_psi = stumpff_c3_hf(psi)
norm_r = (
xi * xi * c2_psi
+ dot_r0v0 / sqrt_mu * xi * (1 - psi * c3_psi)
@@ -136,3 +146,56 @@ def vallado(k, r0, v0, tof, numiter):
fdot = sqrt_mu / (norm_r * norm_r0) * xi * (psi * c3_psi - 1)
return f, g, fdot, gdot
+
+
+@hjit("Tuple([V,V])(f,V,V,f,i8)")
+def vallado_rv_hf(k, r0, v0, tof, numiter):
+ """
+ Scalar vallado_rv
+ """
+
+ # Compute Lagrange coefficients
+ f, g, fdot, gdot = _vallado_hf(k, r0, v0, tof, numiter)
+
+ assert (
+ abs(f * gdot - fdot * g - 1) < 1e-5
+ ), "Internal error, solution is not consistent" # Fixed tolerance
+
+ # Return position and velocity vectors
+ r = add_VV_hf(mul_Vs_hf(r0, f), mul_Vs_hf(v0, g))
+ v = add_VV_hf(mul_Vs_hf(r0, fdot), mul_Vs_hf(v0, gdot))
+
+ return r, v
+
+
+@gjit("void(f,f[:],f[:],f,i8,f[:],f[:])", "(),(n),(n),(),()->(n),(n)")
+def vallado_rv_gf(k, r0, v0, tof, numiter, rr, vv):
+ """
+ Vectorized vallado_rv
+ """
+
+ (rr[0], rr[1], rr[2]), (vv[0], vv[1], vv[2]) = vallado_rv_hf(
+ k, array_to_V_hf(r0), array_to_V_hf(v0), tof, numiter
+ )
+
+
+@hjit("f(f,f,f,f,f,f,f,f,i8)")
+def vallado_coe_hf(k, p, ecc, inc, raan, argp, nu, tof, numiter):
+ """
+ Scalar vallado_coe
+ """
+
+ r0, v0 = coe2rv_hf(k, p, ecc, inc, raan, argp, nu)
+ rr, vv = vallado_rv_hf(k, r0, v0, tof, numiter)
+ _, _, _, _, _, nu_ = rv2coe_hf(k, rr, vv, RV2COE_TOL)
+
+ return nu_
+
+
+@vjit("f(f,f,f,f,f,f,f,f,i8)")
+def vallado_coe_vf(k, p, ecc, inc, raan, argp, nu, tof, numiter):
+ """
+ Vectorized vallado_coe
+ """
+
+ return vallado_coe_hf(k, p, ecc, inc, raan, argp, nu, tof, numiter)
diff --git a/src/hapsira/core/spheroid_location.py b/src/hapsira/core/spheroid_location.py
index dd6b54651..6c552f329 100644
--- a/src/hapsira/core/spheroid_location.py
+++ b/src/hapsira/core/spheroid_location.py
@@ -1,9 +1,25 @@
"""Low level calculations for oblate spheroid locations."""
+from math import atan, atan2, cos, sin, sqrt
+
from numba import njit as jit
import numpy as np
-from hapsira._math.linalg import norm
+from .jit import array_to_V_hf, hjit, gjit
+from .math.linalg import norm_V_hf
+
+
+__all__ = [
+ "cartesian_cords",
+ "f",
+ "N",
+ "tangential_vecs",
+ "radius_of_curvature",
+ "distance",
+ "is_visible",
+ "cartesian_to_ellipsoidal_hf",
+ "cartesian_to_ellipsoidal_gf",
+]
@jit
@@ -66,7 +82,7 @@ def N(a, b, c, cartesian_cords):
"""
x, y, z = cartesian_cords
N = np.array([2 * x / a**2, 2 * y / b**2, 2 * z / c**2])
- N /= norm(N)
+ N /= norm_V_hf(array_to_V_hf(N))
return N
@@ -82,7 +98,7 @@ def tangential_vecs(N):
"""
u = np.array([1.0, 0, 0])
u -= (u @ N) * N
- u /= norm(u)
+ u /= norm_V_hf(array_to_V_hf(u))
v = np.cross(N, u)
return u, v
@@ -125,7 +141,7 @@ def distance(cartesian_cords, px, py, pz):
"""
c = cartesian_cords
u = np.array([px, py, pz])
- d = norm(c - u)
+ d = norm_V_hf(array_to_V_hf(c - u))
return d
@@ -155,8 +171,8 @@ def is_visible(cartesian_cords, px, py, pz, N):
return p >= 0
-@jit
-def cartesian_to_ellipsoidal(a, c, x, y, z):
+@hjit("Tuple([f,f,f])(f,f,f,f,f)")
+def cartesian_to_ellipsoidal_hf(a, c, x, y, z):
"""Converts cartesian coordinates to ellipsoidal coordinates for the given ellipsoid.
Instead of the iterative formula, the function uses the approximation introduced in
Bowring, B. R. (1976). TRANSFORMATION FROM SPATIAL TO GEOGRAPHICAL COORDINATES.
@@ -177,16 +193,25 @@ def cartesian_to_ellipsoidal(a, c, x, y, z):
"""
e2 = 1 - (c / a) ** 2
e2_ = e2 / (1 - e2)
- p = np.sqrt(x**2 + y**2)
- th = np.arctan(z * a / (p * c))
- lon = np.arctan2(y, x) # Use `arctan2` so that lon lies in the range: [-pi, +pi]
- lat = np.arctan((z + e2_ * c * np.sin(th) ** 3) / (p - e2 * a * np.cos(th) ** 3))
+ p = sqrt(x**2 + y**2)
+ th = atan(z * a / (p * c))
+ lon = atan2(y, x) # Use `atan2` so that lon lies in the range: [-pi, +pi]
+ lat = atan((z + e2_ * c * sin(th) ** 3) / (p - e2 * a * cos(th) ** 3))
- v = a / np.sqrt(1 - e2 * np.sin(lat) ** 2)
+ v = a / sqrt(1 - e2 * sin(lat) ** 2)
h = (
- np.sqrt(x**2 + y**2) / np.cos(lat) - v
+ sqrt(x**2 + y**2) / cos(lat) - v
if lat < abs(1e-18) # to avoid errors very close and at zero
- else z / np.sin(lat) - (1 - e2) * v
+ else z / sin(lat) - (1 - e2) * v
)
return lon, lat, h
+
+
+@gjit("void(f,f,f,f,f,f[:],f[:],f[:])", "(),(),(),(),()->(),(),()")
+def cartesian_to_ellipsoidal_gf(a, c, x, y, z, lon, lat, h):
+ """
+ Vectorized cartesian_to_ellipsoidal
+ """
+
+ lon[0], lat[0], h[0] = cartesian_to_ellipsoidal_hf(a, c, x, y, z)
diff --git a/src/hapsira/core/thrust/__init__.py b/src/hapsira/core/thrust/__init__.py
index 873a4c3f7..e69de29bb 100644
--- a/src/hapsira/core/thrust/__init__.py
+++ b/src/hapsira/core/thrust/__init__.py
@@ -1,5 +0,0 @@
-from hapsira.core.thrust.change_a_inc import change_a_inc
-from hapsira.core.thrust.change_argp import change_argp
-from hapsira.core.thrust.change_ecc_inc import change_ecc_inc
-
-__all__ = ["change_a_inc", "change_argp", "change_ecc_inc"]
diff --git a/src/hapsira/core/thrust/change_a_inc.py b/src/hapsira/core/thrust/change_a_inc.py
index 2dbb3abed..1027193a8 100644
--- a/src/hapsira/core/thrust/change_a_inc.py
+++ b/src/hapsira/core/thrust/change_a_inc.py
@@ -1,59 +1,86 @@
-from numba import njit as jit
-import numpy as np
-from numpy import cross
+from math import atan2, cos, pi, sin, tan
-from hapsira._math.linalg import norm
-from hapsira.core.elements import circular_velocity
+from ..jit import hjit, gjit
+from ..elements import circular_velocity_hf
+from ..math.linalg import (
+ add_VV_hf,
+ cross_VV_hf,
+ div_Vs_hf,
+ mul_Vs_hf,
+ norm_V_hf,
+ sign_hf,
+)
-@jit
-def extra_quantities(k, a_0, a_f, inc_0, inc_f, f):
- """Extra quantities given by the Edelbaum (a, i) model."""
- V_0, V_f, beta_0_ = compute_parameters(k, a_0, a_f, inc_0, inc_f)
- delta_V_ = delta_V(V_0, V_f, beta_0_, inc_0, inc_f)
- t_f_ = delta_V_ / f
-
- return delta_V_, t_f_
+__all__ = [
+ "change_a_inc_hb",
+]
-@jit
-def beta(t, V_0, f, beta_0):
- """Compute yaw angle (β) as a function of time and the problem parameters."""
- return np.arctan2(V_0 * np.sin(beta_0), V_0 * np.cos(beta_0) - f * t)
-
-
-@jit
-def beta_0(V_0, V_f, inc_0, inc_f):
+@hjit("f(f,f,f,f)")
+def _beta_0_hf(V_0, V_f, inc_0, inc_f):
"""Compute initial yaw angle (β) as a function of the problem parameters."""
delta_i_f = abs(inc_f - inc_0)
- return np.arctan2(
- np.sin(np.pi / 2 * delta_i_f),
- V_0 / V_f - np.cos(np.pi / 2 * delta_i_f),
+ return atan2(
+ sin(pi / 2 * delta_i_f),
+ V_0 / V_f - cos(pi / 2 * delta_i_f),
)
-@jit
-def compute_parameters(k, a_0, a_f, inc_0, inc_f):
+@hjit("Tuple([f,f,f])(f,f,f,f,f)")
+def _compute_parameters_hf(k, a_0, a_f, inc_0, inc_f):
"""Compute parameters of the model."""
- V_0 = circular_velocity(k, a_0)
- V_f = circular_velocity(k, a_f)
- beta_0_ = beta_0(V_0, V_f, inc_0, inc_f)
+ V_0 = circular_velocity_hf(k, a_0)
+ V_f = circular_velocity_hf(k, a_f)
+ beta_0_ = _beta_0_hf(V_0, V_f, inc_0, inc_f)
return V_0, V_f, beta_0_
-@jit
-def delta_V(V_0, V_f, beta_0, inc_0, inc_f):
+@gjit("void(f,f,f,f,f,f[:],f[:],f[:])", "(),(),(),(),()->(),(),()")
+def _compute_parameters_gf(k, a_0, a_f, inc_0, inc_f, V_0, V_f, beta_0_):
+ """
+ Vectorized compute_parameters
+ """
+
+ V_0[0], V_f[0], beta_0_[0] = _compute_parameters_hf(k, a_0, a_f, inc_0, inc_f)
+
+
+@hjit("f(f,f,f,f,f)")
+def _delta_V_hf(V_0, V_f, beta_0, inc_0, inc_f):
"""Compute required increment of velocity."""
delta_i_f = abs(inc_f - inc_0)
if delta_i_f == 0:
return abs(V_f - V_0)
- return V_0 * np.cos(beta_0) - V_0 * np.sin(beta_0) / np.tan(
- np.pi / 2 * delta_i_f + beta_0
- )
+ return V_0 * cos(beta_0) - V_0 * sin(beta_0) / tan(pi / 2 * delta_i_f + beta_0)
+
+
+@hjit("Tuple([f,f])(f,f,f,f,f,f)")
+def _extra_quantities_hf(k, a_0, a_f, inc_0, inc_f, f):
+ """Extra quantities given by the Edelbaum (a, i) model."""
+ V_0, V_f, beta_0_ = _compute_parameters_hf(k, a_0, a_f, inc_0, inc_f)
+ delta_V = _delta_V_hf(V_0, V_f, beta_0_, inc_0, inc_f)
+ t_f_ = delta_V / f
+ return delta_V, t_f_
-def change_a_inc(k, a_0, a_f, inc_0, inc_f, f):
+
+@gjit("void(f,f,f,f,f,f,f[:],f[:])", "(),(),(),(),(),()->(),()")
+def _extra_quantities_gf(k, a_0, a_f, inc_0, inc_f, f, delta_V, t_f_):
+ """
+ Vectorized extra_quantities
+ """
+
+ delta_V[0], t_f_[0] = _extra_quantities_hf(k, a_0, a_f, inc_0, inc_f, f)
+
+
+@hjit("f(f,f,f,f)")
+def _beta_hf(t, V_0, f, beta_0):
+ """Compute yaw angle (β) as a function of time and the problem parameters."""
+ return atan2(V_0 * sin(beta_0), V_0 * cos(beta_0) - f * t)
+
+
+def change_a_inc_hb(k, a_0, a_f, inc_0, inc_f, f):
"""Change semimajor axis and inclination.
Guidance law from the Edelbaum/Kéchichian theory, optimal transfer between circular inclined orbits
(a_0, i_0) --> (a_f, i_f), ecc = 0.
@@ -76,7 +103,7 @@ def change_a_inc(k, a_0, a_f, inc_0, inc_f, f):
Returns
-------
a_d : function
- delta_V : numpy.ndarray
+ delta_V : float
t_f : float
Notes
@@ -85,25 +112,29 @@ def change_a_inc(k, a_0, a_f, inc_0, inc_f, f):
References
----------
- * Edelbaum, T. N. "Propulsion Requirements for Controllable
+ * Edelbaum, T. N. "Propulsion Requirements delta_V for Controllable
Satellites", 1961.
* Kéchichian, J. A. "Reformulation of Edelbaum's Low-Thrust
Transfer Problem Using Optimal Control Theory", 1997.
"""
- V_0, V_f, beta_0_ = compute_parameters(k, a_0, a_f, inc_0, inc_f)
-
- @jit
- def a_d(t0, u_, k):
- r = u_[:3]
- v = u_[3:]
+ V_0, _, beta_0_ = _compute_parameters_gf( # pylint: disable=E1120,E0633
+ k, a_0, a_f, inc_0, inc_f
+ )
+ @hjit("V(f,V,V,f)", cache=False)
+ def a_d_hf(t0, rr, vv, k):
# Change sign of beta with the out-of-plane velocity
- beta_ = beta(t0, V_0, f, beta_0_) * np.sign(r[0] * (inc_f - inc_0))
-
- t_ = v / norm(v)
- w_ = cross(r, v) / norm(cross(r, v))
- accel_v = f * (np.cos(beta_) * t_ + np.sin(beta_) * w_)
+ beta_ = _beta_hf(t0, V_0, f, beta_0_) * sign_hf(rr[0] * (inc_f - inc_0))
+
+ t_ = div_Vs_hf(vv, norm_V_hf(vv))
+ crv = cross_VV_hf(rr, vv)
+ w_ = div_Vs_hf(crv, norm_V_hf(crv))
+ accel_v = mul_Vs_hf(
+ add_VV_hf(mul_Vs_hf(t_, cos(beta_)), mul_Vs_hf(w_, sin(beta_))), f
+ )
return accel_v
- delta_V, t_f = extra_quantities(k, a_0, a_f, inc_0, inc_f, f)
- return a_d, delta_V, t_f
+ delta_V, t_f = _extra_quantities_gf( # pylint: disable=E1120,E0633
+ k, a_0, a_f, inc_0, inc_f, f
+ )
+ return a_d_hf, delta_V, t_f
diff --git a/src/hapsira/core/thrust/change_argp.py b/src/hapsira/core/thrust/change_argp.py
index ffaf7d888..20da694a3 100644
--- a/src/hapsira/core/thrust/change_argp.py
+++ b/src/hapsira/core/thrust/change_argp.py
@@ -1,31 +1,51 @@
-from numba import njit as jit
-import numpy as np
-from numpy import cross
+from math import cos, pi, sin, sqrt
-from hapsira._math.linalg import norm
-from hapsira.core.elements import circular_velocity, rv2coe
+from ..elements import circular_velocity_hf, rv2coe_hf, RV2COE_TOL
+from ..jit import hjit, gjit
+from ..math.linalg import (
+ add_VV_hf,
+ cross_VV_hf,
+ div_Vs_hf,
+ mul_Vs_hf,
+ norm_V_hf,
+ sign_hf,
+)
-@jit
-def delta_V(V, ecc, argp_0, argp_f, f, A):
+__all__ = [
+ "change_argp_hb",
+]
+
+
+@hjit("f(f,f,f,f,f,f)")
+def _delta_V_hf(V, ecc, argp_0, argp_f, f, A):
"""Compute required increment of velocity."""
delta_argp = argp_f - argp_0
return delta_argp / (
- 3 * np.sign(delta_argp) / 2 * np.sqrt(1 - ecc**2) / ecc / V + A / f
+ 3 * sign_hf(delta_argp) / 2 * sqrt(1 - ecc**2) / ecc / V + A / f
)
-@jit
-def extra_quantities(k, a, ecc, argp_0, argp_f, f, A=0.0):
+@hjit("Tuple([f,f])(f,f,f,f,f,f,f)")
+def _extra_quantities_hf(k, a, ecc, argp_0, argp_f, f, A):
"""Extra quantities given by the model."""
- V = circular_velocity(k, a)
- delta_V_ = delta_V(V, ecc, argp_0, argp_f, f, A)
+ V = circular_velocity_hf(k, a)
+ delta_V_ = _delta_V_hf(V, ecc, argp_0, argp_f, f, A)
t_f_ = delta_V_ / f
return delta_V_, t_f_
-def change_argp(k, a, ecc, argp_0, argp_f, f):
+@gjit("void(f,f,f,f,f,f,f,f[:],f[:])", "(),(),(),(),(),(),()->(),()")
+def _extra_quantities_gf(k, a, ecc, argp_0, argp_f, f, A, delta_V_, t_f_):
+ """
+ Vectorized extra_quantities
+ """
+
+ delta_V_[0], t_f_[0] = _extra_quantities_hf(k, a, ecc, argp_0, argp_f, f, A)
+
+
+def change_argp_hb(k, a, ecc, argp_0, argp_f, f):
"""Guidance law from the model.
Thrust is aligned with an inertially fixed direction perpendicular to the
semimajor axis of the orbit.
@@ -48,24 +68,27 @@ def change_argp(k, a, ecc, argp_0, argp_f, f):
Returns
-------
a_d : function
- delta_V : numpy.ndarray
+ delta_V : float
t_f : float
"""
- @jit
- def a_d(t0, u_, k):
- r = u_[:3]
- v = u_[3:]
- nu = rv2coe(k, r, v)[-1]
+ @hjit("V(f,V,V,f)", cache=False)
+ def a_d_hf(t0, rr, vv, k):
+ nu = rv2coe_hf(k, rr, vv, RV2COE_TOL)[-1]
- alpha_ = nu - np.pi / 2
+ alpha_ = nu - pi / 2
- r_ = r / norm(r)
- w_ = cross(r, v) / norm(cross(r, v))
- s_ = cross(w_, r_)
- accel_v = f * (np.cos(alpha_) * s_ + np.sin(alpha_) * r_)
+ r_ = div_Vs_hf(rr, norm_V_hf(rr))
+ crv = cross_VV_hf(rr, vv)
+ w_ = div_Vs_hf(crv, norm_V_hf(crv))
+ s_ = cross_VV_hf(w_, r_)
+ accel_v = mul_Vs_hf(
+ add_VV_hf(mul_Vs_hf(s_, cos(alpha_)), mul_Vs_hf(r_, sin(alpha_))), f
+ )
return accel_v
- delta_V, t_f = extra_quantities(k, a, ecc, argp_0, argp_f, f, A=0.0)
+ delta_V, t_f = _extra_quantities_gf( # pylint: disable=E1120,E0633
+ k, a, ecc, argp_0, argp_f, f, 0.0
+ )
- return a_d, delta_V, t_f
+ return a_d_hf, delta_V, t_f
diff --git a/src/hapsira/core/thrust/change_ecc_inc.py b/src/hapsira/core/thrust/change_ecc_inc.py
index dfa97c86b..4b2ebc2ac 100644
--- a/src/hapsira/core/thrust/change_ecc_inc.py
+++ b/src/hapsira/core/thrust/change_ecc_inc.py
@@ -5,82 +5,148 @@
* Pollard, J. E. "Simplified Analysis of Low-Thrust Orbital Maneuvers", 2000.
"""
-from numba import njit as jit
-import numpy as np
-from numpy import cross
-
-from hapsira._math.linalg import norm
-from hapsira.core.elements import (
- circular_velocity,
- eccentricity_vector,
- rv2coe,
+
+from math import asin, atan, cos, pi, log, sin
+
+from numpy import array
+
+from ..elements import (
+ circular_velocity_hf,
+ eccentricity_vector_hf,
+ rv2coe_hf,
+ RV2COE_TOL,
)
+from ..jit import array_to_V_hf, hjit, vjit, gjit
+from ..math.linalg import (
+ add_VV_hf,
+ cross_VV_hf,
+ div_Vs_hf,
+ mul_Vs_hf,
+ norm_V_hf,
+ sign_hf,
+)
+
+
+__all__ = [
+ "beta_hf",
+ "beta_vf",
+ "change_ecc_inc_hb",
+]
-@jit
-def beta(ecc_0, ecc_f, inc_0, inc_f, argp):
+@hjit("f(f,f,f,f,f)")
+def beta_hf(ecc_0, ecc_f, inc_0, inc_f, argp):
+ """
+ Scalar beta
+ """
# Note: "The argument of perigee will vary during the orbit transfer
# due to the natural drift and because e may approach zero.
# However, [the equation] still gives a good estimate of the desired
# thrust angle."
- return np.arctan(
+ return atan(
abs(
3
- * np.pi
+ * pi
* (inc_f - inc_0)
/ (
4
- * np.cos(argp)
+ * cos(argp)
* (
ecc_0
- ecc_f
- + np.log((1 + ecc_f) * (-1 + ecc_0) / ((1 + ecc_0) * (-1 + ecc_f)))
+ + log((1 + ecc_f) * (-1 + ecc_0) / ((1 + ecc_0) * (-1 + ecc_f)))
)
)
)
)
-@jit
-def delta_V(V_0, ecc_0, ecc_f, beta_):
- """Compute required increment of velocity."""
- return 2 * V_0 * np.abs(np.arcsin(ecc_0) - np.arcsin(ecc_f)) / (3 * np.cos(beta_))
+@vjit("f(f,f,f,f,f)")
+def beta_vf(ecc_0, ecc_f, inc_0, inc_f, argp):
+ """
+ Vectorized beta
+ """
+ return beta_hf(ecc_0, ecc_f, inc_0, inc_f, argp)
-@jit
-def delta_t(delta_v, f):
- """Compute required increment of velocity."""
+
+@hjit("f(f,f,f,f)")
+def _delta_V_hf(V_0, ecc_0, ecc_f, beta_):
+ """
+ Compute required increment of velocity.
+ """
+
+ return 2 * V_0 * abs(asin(ecc_0) - asin(ecc_f)) / (3 * cos(beta_))
+
+
+@hjit("f(f,f)")
+def _delta_t_hf(delta_v, f):
+ """
+ Compute required increment of velocity.
+ """
return delta_v / f
-def change_ecc_inc(k, a, ecc_0, ecc_f, inc_0, inc_f, argp, r, v, f):
+@hjit("Tuple([V,f,f,f])(f,f,f,f,f,f,f,V,V,f)")
+def _prepare_hf(k, a, ecc_0, ecc_f, inc_0, inc_f, argp, r, v, f):
+ """
+ Vectorized prepare
+ """
+
# We fix the inertial direction at the beginning
if ecc_0 > 0.001: # Arbitrary tolerance
- e_vec = eccentricity_vector(k, r, v)
- ref_vec = e_vec / ecc_0
+ e_vec = eccentricity_vector_hf(k, r, v)
+ ref_vec = div_Vs_hf(e_vec, ecc_0)
else:
- ref_vec = r / norm(r)
+ ref_vec = div_Vs_hf(r, norm_V_hf(r))
+
+ h_vec = cross_VV_hf(r, v) # Specific angular momentum vector
+ h_unit = div_Vs_hf(h_vec, norm_V_hf(h_vec))
+ thrust_unit = mul_Vs_hf(cross_VV_hf(h_unit, ref_vec), sign_hf(ecc_f - ecc_0))
- h_vec = cross(r, v) # Specific angular momentum vector
- h_unit = h_vec / norm(h_vec)
- thrust_unit = cross(h_unit, ref_vec) * np.sign(ecc_f - ecc_0)
+ beta_0 = beta_hf(ecc_0, ecc_f, inc_0, inc_f, argp)
- beta_0 = beta(ecc_0, ecc_f, inc_0, inc_f, argp)
+ delta_v = _delta_V_hf(circular_velocity_hf(k, a), ecc_0, ecc_f, beta_0)
+ t_f = _delta_t_hf(delta_v, f)
- @jit
- def a_d(t0, u_, k_):
- r_ = u_[:3]
- v_ = u_[3:]
- nu = rv2coe(k_, r_, v_)[-1]
- beta_ = beta_0 * np.sign(
- np.cos(nu)
+ return thrust_unit, beta_0, delta_v, t_f
+
+
+@gjit(
+ "void(f,f,f,f,f,f,f,f[:],f[:],f,f[:],f[:],f[:],f[:])",
+ "(),(),(),(),(),(),(),(n),(n),()->(n),(),(),()",
+)
+def _prepare_gf(
+ k, a, ecc_0, ecc_f, inc_0, inc_f, argp, r, v, f, thrust_unit, beta_0, delta_v, t_f
+):
+ """
+ Vectorized prepare
+ """
+
+ thrust_unit[:], beta_0[0], delta_v[0], t_f[0] = _prepare_hf(
+ k, a, ecc_0, ecc_f, inc_0, inc_f, argp, array_to_V_hf(r), array_to_V_hf(v), f
+ )
+
+
+def change_ecc_inc_hb(k, a, ecc_0, ecc_f, inc_0, inc_f, argp, r, v, f):
+ thrust_unit, beta_0, delta_v, t_f = _prepare_gf( # pylint: disable=E1120,E0633
+ k, a, ecc_0, ecc_f, inc_0, inc_f, argp, array(r), array(v), f
+ )
+ thrust_unit = tuple(thrust_unit)
+
+ @hjit("V(f,V,V,f)", cache=False)
+ def a_d_hf(t0, rr, vv, k_):
+ nu = rv2coe_hf(k_, rr, vv, RV2COE_TOL)[-1]
+ beta_ = beta_0 * sign_hf(
+ cos(nu)
) # The sign of ß reverses at minor axis crossings
- w_ = (cross(r_, v_) / norm(cross(r_, v_))) * np.sign(inc_f - inc_0)
- accel_v = f * (np.cos(beta_) * thrust_unit + np.sin(beta_) * w_)
+ w_ = mul_Vs_hf(
+ cross_VV_hf(rr, vv), sign_hf(inc_f - inc_0) / norm_V_hf(cross_VV_hf(rr, vv))
+ )
+ accel_v = mul_Vs_hf(
+ add_VV_hf(mul_Vs_hf(thrust_unit, cos(beta_)), mul_Vs_hf(w_, sin(beta_))), f
+ )
return accel_v
- delta_v = delta_V(circular_velocity(k, a), ecc_0, ecc_f, beta_0)
- t_f = delta_t(delta_v, f)
-
- return a_d, delta_v, t_f
+ return a_d_hf, delta_v, t_f
diff --git a/src/hapsira/core/thrust/change_ecc_quasioptimal.py b/src/hapsira/core/thrust/change_ecc_quasioptimal.py
index c78d2487c..2fe6fe287 100644
--- a/src/hapsira/core/thrust/change_ecc_quasioptimal.py
+++ b/src/hapsira/core/thrust/change_ecc_quasioptimal.py
@@ -1,20 +1,89 @@
-from numba import njit as jit
-import numpy as np
+from math import asin
-from hapsira.core.elements import circular_velocity
+from numpy import array
+from ..elements import circular_velocity_hf
+from ..jit import array_to_V_hf, hjit, gjit
+from ..math.linalg import cross_VV_hf, div_Vs_hf, mul_Vs_hf, norm_V_hf, sign_hf
-@jit
-def delta_V(V_0, ecc_0, ecc_f):
- """Compute required increment of velocity."""
- return 2 / 3 * V_0 * np.abs(np.arcsin(ecc_0) - np.arcsin(ecc_f))
+__all__ = [
+ "change_ecc_quasioptimal_hb",
+]
-@jit
-def extra_quantities(k, a, ecc_0, ecc_f, f):
- """Extra quantities given by the model."""
- V_0 = circular_velocity(k, a)
- delta_V_ = delta_V(V_0, ecc_0, ecc_f)
+@hjit("f(f,f,f)")
+def _delta_V_hf(V_0, ecc_0, ecc_f):
+ """
+ Compute required increment of velocity.
+ """
+
+ return 2 / 3 * V_0 * abs(asin(ecc_0) - asin(ecc_f))
+
+
+@hjit("Tuple([f,f])(f,f,f,f,f)")
+def _extra_quantities_hf(k, a, ecc_0, ecc_f, f):
+ """
+ Extra quantities given by the model.
+ """
+
+ V_0 = circular_velocity_hf(k, a)
+ delta_V_ = _delta_V_hf(V_0, ecc_0, ecc_f)
t_f_ = delta_V_ / f
return delta_V_, t_f_
+
+
+@gjit("void(f,f,f,f,f,f[:],f[:])", "(),(),(),(),()->(),()")
+def _extra_quantities_gf(k, a, ecc_0, ecc_f, f, delta_V_, t_f_):
+ """
+ Vectorized extra_quantities
+ """
+
+ delta_V_[0], t_f_[0] = _extra_quantities_hf(k, a, ecc_0, ecc_f, f)
+
+
+@hjit("V(f,f,f,f,V,V,V)")
+def _prepare_hf(k, a, ecc_0, ecc_f, e_vec, h_vec, r):
+ """
+ Scalar prepare
+ """
+
+ if ecc_0 > 0.001: # Arbitrary tolerance
+ ref_vec = div_Vs_hf(e_vec, ecc_0)
+ else:
+ ref_vec = div_Vs_hf(r, norm_V_hf(r))
+
+ h_unit = div_Vs_hf(h_vec, norm_V_hf(h_vec))
+ thrust_unit = mul_Vs_hf(cross_VV_hf(h_unit, ref_vec), sign_hf(ecc_f - ecc_0))
+
+ return thrust_unit
+
+
+@gjit("void(f,f,f,f,f[:],f[:],f[:],f[:])", "(),(),(),(),(n),(n),(n)->(n)")
+def _prepare_gf(k, a, ecc_0, ecc_f, e_vec, h_vec, r, thrust_unit):
+ """
+ Vectorized prepare
+ """
+
+ thrust_unit[:] = _prepare_hf(
+ k, a, ecc_0, ecc_f, array_to_V_hf(e_vec), array_to_V_hf(h_vec), array_to_V_hf(r)
+ )
+
+
+def change_ecc_quasioptimal_hb(k, a, ecc_0, ecc_f, e_vec, h_vec, r, f):
+ # We fix the inertial direction at the beginning
+
+ thrust_unit = _prepare_gf( # pylint: disable=E1120,E0633
+ k, a, ecc_0, array(ecc_f), array(e_vec), array(h_vec), r
+ )
+ thrust_unit = tuple(thrust_unit)
+
+ @hjit("V(f,V,V,f)", cache=False)
+ def a_d_hf(t0, rr, vv, k):
+ accel_v = mul_Vs_hf(thrust_unit, f)
+ return accel_v
+
+ delta_V, t_f = _extra_quantities_gf( # pylint: disable=E1120,E0633
+ k, a, ecc_0, ecc_f, f
+ )
+ return a_d_hf, delta_V, t_f
diff --git a/src/hapsira/core/util.py b/src/hapsira/core/util.py
index 8689cc0f9..c05483c31 100644
--- a/src/hapsira/core/util.py
+++ b/src/hapsira/core/util.py
@@ -1,24 +1,58 @@
+from math import cos, sin
+
from numba import njit as jit
import numpy as np
-from numpy import cos, sin
+from .jit import hjit, gjit
-@jit
-def rotation_matrix(angle, axis):
- assert axis in (0, 1, 2)
- angle = np.asarray(angle)
+
+__all__ = [
+ "rotation_matrix_hf",
+ "rotation_matrix_gf",
+ "alinspace",
+ "spherical_to_cartesian",
+ "planetocentric_to_AltAz_hf",
+]
+
+
+@hjit("M(f,i8)")
+def rotation_matrix_hf(angle, axis):
c = cos(angle)
s = sin(angle)
+ if axis == 0:
+ return (
+ (1.0, 0.0, 0.0),
+ (0.0, c, -s),
+ (0.0, s, c),
+ )
+ if axis == 1:
+ return (
+ (c, 0.0, s),
+ (0.0, 1.0, 0.0),
+ (-s, 0.0, c),
+ )
+ if axis == 2:
+ return (
+ (c, -s, 0.0),
+ (s, c, 0.0),
+ (0.0, 0.0, 1.0),
+ )
+ raise ValueError("Invalid axis: must be one of 0, 1 or 2")
+
+
+@gjit("void(f,i8,u1[:],f[:,:])", "(),(),(n)->(n,n)")
+def rotation_matrix_gf(angle, axis, dummy, r):
+ """
+ Vectorized rotation_matrix
- a1 = (axis + 1) % 3
- a2 = (axis + 2) % 3
- R = np.zeros(angle.shape + (3, 3))
- R[..., axis, axis] = 1.0
- R[..., a1, a1] = c
- R[..., a1, a2] = -s
- R[..., a2, a1] = s
- R[..., a2, a2] = c
- return R
+ `dummy` because of https://github.com/numba/numba/issues/2797
+ """
+ assert dummy.shape == (3,)
+ (
+ (r[0, 0], r[0, 1], r[0, 2]),
+ (r[1, 0], r[1, 1], r[1, 2]),
+ (r[2, 0], r[2, 1], r[2, 2]),
+ ) = rotation_matrix_hf(angle, axis)
@jit
@@ -72,8 +106,8 @@ def spherical_to_cartesian(v):
return norm_vecs * np.stack((x, y, z), axis=-1)
-@jit
-def planetocentric_to_AltAz(theta, phi):
+@hjit("M(f,f)")
+def planetocentric_to_AltAz_hf(theta, phi):
r"""Defines transformation matrix to convert from Planetocentric coordinate system
to the Altitude-Azimuth system.
@@ -93,23 +127,25 @@ def planetocentric_to_AltAz(theta, phi):
Returns
-------
- t_matrix: numpy.ndarray
+ t_matrix: tuple[tuple[float,float,float],...]
Transformation matrix
"""
# Transformation matrix for converting planetocentric equatorial coordinates to topocentric horizon system.
- t_matrix = np.array(
- [
- [-np.sin(theta), np.cos(theta), 0],
- [
- -np.sin(phi) * np.cos(theta),
- -np.sin(phi) * np.sin(theta),
- np.cos(phi),
- ],
- [
- np.cos(phi) * np.cos(theta),
- np.cos(phi) * np.sin(theta),
- np.sin(phi),
- ],
- ]
+ st = sin(theta)
+ ct = cos(theta)
+ sp = sin(phi)
+ cp = cos(phi)
+
+ return (
+ (-st, ct, 0.0),
+ (
+ -sp * ct,
+ -sp * st,
+ cp,
+ ),
+ (
+ cp * ct,
+ cp * st,
+ sp,
+ ),
)
- return t_matrix
diff --git a/src/hapsira/debug.py b/src/hapsira/debug.py
new file mode 100644
index 000000000..de206d539
--- /dev/null
+++ b/src/hapsira/debug.py
@@ -0,0 +1,19 @@
+import logging
+
+from hapsira.settings import settings
+
+__all__ = [
+ "logger",
+]
+
+logger = logging.getLogger("hapsira")
+
+if settings["LOGLEVEL"].value != "NOTSET":
+ logger.setLevel(
+ logging.DEBUG
+ if settings["DEBUG"].value
+ else getattr(logging, settings["LOGLEVEL"].value)
+ )
+
+logger.debug("logging level: %s", logging.getLevelName(logger.level))
+logger.debug("debug mode: %s", "on" if settings["DEBUG"].value else "off")
diff --git a/src/hapsira/earth/__init__.py b/src/hapsira/earth/__init__.py
index 55865dcff..09e650fc6 100644
--- a/src/hapsira/earth/__init__.py
+++ b/src/hapsira/earth/__init__.py
@@ -1,13 +1,13 @@
"""Earth focused orbital mechanics routines."""
-from typing import Dict
from astropy import units as u
-import numpy as np
from hapsira.bodies import Earth
-from hapsira.core.perturbations import J2_perturbation
-from hapsira.core.propagation import func_twobody
+from hapsira.core.jit import hjit, djit
+from hapsira.core.math.linalg import add_VV_hf
+from hapsira.core.perturbations import J2_perturbation_hf
+from hapsira.core.propagation.base import func_twobody_hf
from hapsira.earth.enums import EarthGravity
from hapsira.twobody.propagation import CowellPropagator
@@ -73,36 +73,34 @@ def propagate(self, tof, atmosphere=None, gravity=None, *args):
A new EarthSatellite with the propagated Orbit
"""
- ad_kwargs: Dict[object, dict] = {}
- perturbations: Dict[object, dict] = {}
-
- def ad(t0, state, k, perturbations):
- if perturbations:
- return np.sum(
- [f(t0=t0, state=state, k=k, **p) for f, p in perturbations.items()],
- axis=0,
- )
- else:
- return np.array([0, 0, 0])
-
- if gravity is EarthGravity.J2:
- perturbations[J2_perturbation] = {
- "J2": Earth.J2.value,
- "R": Earth.R.to_value(u.km),
- }
+
+ if gravity not in (None, EarthGravity.J2):
+ raise NotImplementedError
+
if atmosphere is not None:
# Cannot compute density without knowing the state,
# the perturbations parameters are not always fixed
- # TODO: This whole function probably needs a refactoring
raise NotImplementedError
- def f(t0, state, k):
- du_kep = func_twobody(t0, state, k)
- ax, ay, az = ad(t0, state, k, perturbations)
- du_ad = np.array([0, 0, 0, ax, ay, az])
+ if gravity:
+ J2_ = Earth.J2.value
+ R_ = Earth.R.to_value(u.km)
+
+ @hjit("V(f,V,V,f)", cache=False)
+ def ad_hf(t0, rr, vv, k):
+ return J2_perturbation_hf(t0, rr, vv, k, J2_, R_)
+
+ else:
+
+ @hjit("V(f,V,V,f)")
+ def ad_hf(t0, rr, vv, k):
+ return 0.0, 0.0, 0.0
- return du_kep + du_ad
+ @djit(cache=False)
+ def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ du_ad_vv = ad_hf(t0, rr, vv, k)
+ return du_kep_rr, add_VV_hf(du_kep_vv, du_ad_vv)
- ad_kwargs.update(perturbations=perturbations)
- new_orbit = self.orbit.propagate(tof, method=CowellPropagator(f=f))
+ new_orbit = self.orbit.propagate(tof, method=CowellPropagator(f=f_hf))
return EarthSatellite(new_orbit, self.spacecraft)
diff --git a/src/hapsira/earth/atmosphere/base.py b/src/hapsira/earth/atmosphere/base.py
index dc5eeb2f7..2f81f9631 100644
--- a/src/hapsira/earth/atmosphere/base.py
+++ b/src/hapsira/earth/atmosphere/base.py
@@ -2,9 +2,9 @@
import astropy.units as u
-from hapsira.core.earth_atmosphere.util import (
- _check_altitude as _check_altitude_fast,
- _get_index as _get_index_fast,
+from hapsira.core.earth.atmosphere.coesa import (
+ check_altitude_hf,
+ get_index_hf,
)
@@ -69,7 +69,7 @@ def _check_altitude(self, alt, r0, geometric=True):
alt = alt.to_value(u.km)
r0 = r0.to_value(u.km)
- z, h = _check_altitude_fast(alt, r0, geometric)
+ z, h = check_altitude_hf(alt, r0, geometric) # TODO call from compiled context
z, h = z * u.km, h * u.km
# Assert in range
@@ -98,5 +98,5 @@ def _get_index(self, x, x_levels):
"""
x = x.to_value(u.km)
x_levels = (x_levels << u.km).value
- i = _get_index_fast(x, x_levels)
+ i = get_index_hf(x, x_levels) # TODO call from compiled context
return i
diff --git a/src/hapsira/earth/atmosphere/coesa62.py b/src/hapsira/earth/atmosphere/coesa62.py
index ee6f5dcc4..eee4ed245 100644
--- a/src/hapsira/earth/atmosphere/coesa62.py
+++ b/src/hapsira/earth/atmosphere/coesa62.py
@@ -53,12 +53,12 @@
"""
from astropy import units as u
-from astropy.io import ascii
+from astropy.io import ascii as ascii_
from astropy.units import imperial
from astropy.utils.data import get_pkg_data_filename
import numpy as np
-from hapsira._math.integrate import quad
+from hapsira.core.math.integrate import quad
from hapsira.earth.atmosphere.base import COESA
# Constants come from the original paper to achieve pure implementation
@@ -78,8 +78,8 @@
alpha = 34.1632 * u.K / u.km
# Reading layer parameters file
-coesa_file = get_pkg_data_filename("data/coesa62.dat")
-coesa62_data = ascii.read(coesa_file)
+coesa_file = get_pkg_data_filename("../../core/earth/atmosphere/data/coesa62.dat")
+coesa62_data = ascii_.read(coesa_file)
b_levels = coesa62_data["b"].data
zb_levels = coesa62_data["Zb [km]"].data * u.km
hb_levels = coesa62_data["Hb [km]"].data * u.km
diff --git a/src/hapsira/earth/atmosphere/coesa76.py b/src/hapsira/earth/atmosphere/coesa76.py
index 465d5d435..95cb2bc90 100644
--- a/src/hapsira/earth/atmosphere/coesa76.py
+++ b/src/hapsira/earth/atmosphere/coesa76.py
@@ -44,58 +44,69 @@
"""
from astropy import units as u
-from astropy.io import ascii
-from astropy.utils.data import get_pkg_data_filename
import numpy as np
from hapsira.earth.atmosphere.base import COESA
-# Following constants come from original U.S Atmosphere 1962 paper so a pure
-# model of this atmosphere can be implemented
-R = 8314.32 * u.J / u.kmol / u.K
-R_air = 287.053 * u.J / u.kg / u.K
-k = 1.380622e-23 * u.J / u.K
-Na = 6.022169e-26 / u.kmol
-g0 = 9.80665 * u.m / u.s**2
-r0 = 6356.766 * u.km
-M0 = 28.9644 * u.kg / u.kmol
-P0 = 101325 * u.Pa
-T0 = 288.15 * u.K
-Tinf = 1000 * u.K
-gamma = 1.4
-alpha = 34.1632 * u.K / u.km
-beta = 1.458e-6 * (u.kg / u.s / u.m / (u.K) ** 0.5)
-S = 110.4 * u.K
+from hapsira.core.earth.atmosphere.coesa76 import (
+ R,
+ R_air,
+ k,
+ Na,
+ g0,
+ r0,
+ M0,
+ P0,
+ T0,
+ Tinf,
+ gamma,
+ alpha,
+ beta,
+ S,
+ b_levels,
+ zb_levels,
+ hb_levels,
+ Tb_levels,
+ Lb_levels,
+ pb_levels,
+ z_coeff,
+ p_coeff,
+ rho_coeff,
+ pressure_vf,
+ density_vf,
+ temperature_vf,
+)
+
+__all__ = [
+ "COESA76",
+]
+
+R = R * u.J / u.kmol / u.K
+R_air = R_air * u.J / u.kg / u.K
+k = k * u.J / u.K
+Na = Na / u.kmol
+g0 = g0 * u.m / u.s**2
+r0 = r0 * u.km
+M0 = M0 * u.kg / u.kmol
+P0 = P0 * u.Pa
+T0 = T0 * u.K
+Tinf = Tinf * u.K
+alpha = alpha * u.K / u.km
+beta = beta * (u.kg / u.s / u.m / (u.K) ** 0.5)
+S = S * u.K
# Reading layer parameters file
-coesa76_data = ascii.read(get_pkg_data_filename("data/coesa76.dat"))
-b_levels = coesa76_data["b"].data
-zb_levels = coesa76_data["Zb [km]"].data * u.km
-hb_levels = coesa76_data["Hb [km]"].data * u.km
-Tb_levels = coesa76_data["Tb [K]"].data * u.K
-Lb_levels = coesa76_data["Lb [K/km]"].data * u.K / u.km
-pb_levels = coesa76_data["pb [mbar]"].data * u.mbar
-
-# Reading pressure and density coefficients files
-p_data = ascii.read(get_pkg_data_filename("data/coesa76_p.dat"))
-rho_data = ascii.read(get_pkg_data_filename("data/coesa76_rho.dat"))
+b_levels = np.array(b_levels)
+zb_levels = np.array(zb_levels) * u.km
+hb_levels = np.array(hb_levels) * u.km
+Tb_levels = np.array(Tb_levels) * u.K
+Lb_levels = np.array(Lb_levels) * u.K / u.km
+pb_levels = np.array(pb_levels) * u.mbar
# Zip coefficients for each altitude
-z_coeff = p_data["z [km]"].data * u.km
-p_coeff = [
- p_data["A"].data,
- p_data["B"].data,
- p_data["C"].data,
- p_data["D"].data,
- p_data["E"].data,
-]
-rho_coeff = [
- rho_data["A"].data,
- rho_data["B"].data,
- rho_data["C"].data,
- rho_data["D"].data,
- rho_data["E"].data,
-]
+z_coeff = z_coeff * u.km
+p_coeff = [np.array(entry) for entry in p_coeff]
+rho_coeff = [np.array(entry) for entry in rho_coeff]
class COESA76(COESA):
@@ -145,40 +156,8 @@ def temperature(self, alt, geometric=True):
T: ~astropy.units.Quantity
Kinetic temeperature.
"""
- # Test if altitude is inside valid range
- z, h = self._check_altitude(alt, r0, geometric=geometric)
- # Get base parameters
- i = self._get_index(z, self.zb_levels)
- Tb = self.Tb_levels[i]
- Lb = self.Lb_levels[i]
- hb = self.hb_levels[i]
-
- # Apply different equations
- if z < self.zb_levels[7]:
- # Below 86km
- # TODO: Apply air mean molecular weight ratio factor
- Tm = Tb + Lb * (h - hb)
- T = Tm
- elif self.zb_levels[7] <= z and z < self.zb_levels[8]:
- # [86km, 91km)
- T = 186.87 * u.K
- elif self.zb_levels[8] <= z and z < self.zb_levels[9]:
- # [91km, 110km]
- Tc = 263.1905 * u.K
- A = -76.3232 * u.K
- a = -19.9429 * u.km
- T = Tc + A * (1 - ((z - self.zb_levels[8]) / a) ** 2) ** 0.5
- elif self.zb_levels[9] <= z and z < self.zb_levels[10]:
- # [110km, 120km]
- T = 240 * u.K + Lb * (z - self.zb_levels[9])
- else:
- T10 = 360.0 * u.K
- _gamma = self.Lb_levels[9] / (Tinf - T10)
- epsilon = (z - self.zb_levels[10]) * (r0 + self.zb_levels[10]) / (r0 + z)
- T = Tinf - (Tinf - T10) * np.exp(-_gamma * epsilon)
-
- return T.to(u.K)
+ return temperature_vf(alt.to_value(u.km), geometric) * u.K
def pressure(self, alt, geometric=True):
"""Solves pressure at given altitude.
@@ -195,36 +174,8 @@ def pressure(self, alt, geometric=True):
p: ~astropy.units.Quantity
Pressure at given altitude.
"""
- # Test if altitude is inside valid range
- z, h = self._check_altitude(alt, r0, geometric=geometric)
- # Obtain gravity magnitude
- # Get base parameters
- i = self._get_index(z, self.zb_levels)
- Tb = self.Tb_levels[i]
- Lb = self.Lb_levels[i]
- hb = self.hb_levels[i]
- pb = self.pb_levels[i]
-
- # If above 86[km] usual formulation is applied
- if z < 86 * u.km:
- if Lb == 0.0 * u.K / u.km:
- p = pb * np.exp(-alpha * (h - hb) / Tb)
- else:
- T = self.temperature(z)
- p = pb * (Tb / T) ** (alpha / Lb)
- else:
- # TODO: equation (33c) should be applied instead of using coefficients
-
- # A 4th order polynomial is used to approximate pressure. This was
- # directly taken from: http://www.braeunig.us/space/atmmodel.htm
- A, B, C, D, E = self._get_coefficients_avobe_86(z, p_coeff)
-
- # Solve the polynomial
- z = z.to_value(u.km)
- p = np.exp(A * z**4 + B * z**3 + C * z**2 + D * z + E) * u.Pa
-
- return p.to(u.Pa)
+ return pressure_vf(alt.to_value(u.km), geometric) * u.Pa
def density(self, alt, geometric=True):
"""Solves density at given height.
@@ -241,30 +192,8 @@ def density(self, alt, geometric=True):
rho: ~astropy.units.Quantity
Density at given height.
"""
- # Test if altitude is inside valid range
- z, h = self._check_altitude(alt, r0, geometric=geometric)
-
- # Solve temperature and pressure
- if z <= 86 * u.km:
- T = self.temperature(z)
- p = self.pressure(z)
- rho = p / R_air / T
- else:
- # TODO: equation (42) should be applied instead of using coefficients
-
- # A 4th order polynomial is used to approximate pressure. This was
- # directly taken from: http://www.braeunig.us/space/atmmodel.htm
- A, B, C, D, E = self._get_coefficients_avobe_86(z, rho_coeff)
-
- # Solve the polynomial
- z = z.to_value(u.km)
- rho = (
- np.exp(A * z**4 + B * z**3 + C * z**2 + D * z + E)
- * u.kg
- / u.m**3
- )
- return rho.to(u.kg / u.m**3)
+ return density_vf(alt.to_value(u.km), geometric) * u.kg / u.m**3
def properties(self, alt, geometric=True):
"""Solves temperature, pressure, density at given height.
diff --git a/src/hapsira/earth/atmosphere/jacchia.py b/src/hapsira/earth/atmosphere/jacchia.py
index c2695248b..fcfb6e733 100644
--- a/src/hapsira/earth/atmosphere/jacchia.py
+++ b/src/hapsira/earth/atmosphere/jacchia.py
@@ -1,7 +1,7 @@
from astropy import units as u
import numpy as np
-from hapsira.core.earth_atmosphere.jacchia import (
+from hapsira.core.earth.atmosphere.jacchia import (
_altitude_profile as _altitude_profile_fast,
_H_correction as _H_correction_fast,
_O_and_O2_correction as _O_and_O2_correction_fast,
diff --git a/src/hapsira/ephem.py b/src/hapsira/ephem.py
index 5f46772ef..38705e3d5 100644
--- a/src/hapsira/ephem.py
+++ b/src/hapsira/ephem.py
@@ -10,8 +10,8 @@
)
from astroquery.jplhorizons import Horizons
-from hapsira._math.interpolate import interp1d, sinc_interp, spline_interp
from hapsira.bodies import Earth
+from hapsira.core.math.interpolate import interp_hb, sinc_interp, spline_interp
from hapsira.frames import Planes
from hapsira.frames.util import get_frame
from hapsira.twobody.sampling import EpochsArray
@@ -43,7 +43,7 @@ def build_ephem_interpolant(body, epochs, attractor=Earth):
"""
ephem = Ephem.from_body(body, epochs, attractor=attractor)
- interpolant = interp1d(
+ interpolant = interp_hb(
(epochs - epochs[0]).to_value(u.s),
ephem._coordinates.xyz.to_value(u.km),
)
diff --git a/src/hapsira/errors.py b/src/hapsira/errors.py
new file mode 100644
index 000000000..838069523
--- /dev/null
+++ b/src/hapsira/errors.py
@@ -0,0 +1,2 @@
+class JitError(Exception):
+ pass
diff --git a/src/hapsira/frames/ecliptic.py b/src/hapsira/frames/ecliptic.py
index 65c229162..6bb4044c8 100644
--- a/src/hapsira/frames/ecliptic.py
+++ b/src/hapsira/frames/ecliptic.py
@@ -12,7 +12,6 @@
)
from astropy.coordinates.builtin_frames.utils import DEFAULT_OBSTIME, get_jd12
from astropy.coordinates.matrix_utilities import (
- matrix_product,
matrix_transpose,
rotation_matrix,
)
@@ -71,7 +70,7 @@ def gcrs_to_geosolarecliptic(gcrs_coo, to_frame):
rot_matrix = _make_rotation_matrix_from_reprs(sun_earth_detilt, x_axis)
- return matrix_product(rot_matrix, _earth_detilt_matrix)
+ return rot_matrix @ _earth_detilt_matrix
@frame_transform_graph.transform(DynamicMatrixTransform, GeocentricSolarEcliptic, GCRS)
diff --git a/src/hapsira/iod/izzo.py b/src/hapsira/iod/izzo.py
index 42af60b85..8832459ed 100644
--- a/src/hapsira/iod/izzo.py
+++ b/src/hapsira/iod/izzo.py
@@ -1,7 +1,8 @@
"""Izzo's algorithm for Lambert's problem."""
from astropy import units as u
+import numpy as np
-from hapsira.core.iod import izzo as izzo_fast
+from hapsira.core.iod import izzo_gf
kms = u.km / u.s
@@ -44,5 +45,7 @@ def lambert(k, r0, r, tof, M=0, prograde=True, lowpath=True, numiter=35, rtol=1e
r_ = r.to_value(u.km)
tof_ = tof.to_value(u.s)
- v0, v = izzo_fast(k_, r0_, r_, tof_, M, prograde, lowpath, numiter, rtol)
+ v0, v = izzo_gf( # pylint: disable=E1120,E0633
+ k_, r0_, r_, tof_, M, np.asarray(prograde), np.asarray(lowpath), numiter, rtol
+ )
return v0 << kms, v << kms
diff --git a/src/hapsira/iod/vallado.py b/src/hapsira/iod/vallado.py
index a267d744e..7b2e40b46 100644
--- a/src/hapsira/iod/vallado.py
+++ b/src/hapsira/iod/vallado.py
@@ -1,7 +1,8 @@
"""Initial orbit determination."""
from astropy import units as u
+import numpy as np
-from hapsira.core.iod import vallado as vallado_fast
+from hapsira.core.iod import vallado_gf
kms = u.km / u.s
@@ -55,6 +56,8 @@ def lambert(k, r0, r, tof, M=0, prograde=True, lowpath=True, numiter=35, rtol=1e
r_ = r.to_value(u.km)
tof_ = tof.to_value(u.s)
- v0, v = vallado_fast(k_, r0_, r_, tof_, M, prograde, lowpath, numiter, rtol)
+ v0, v = vallado_gf( # pylint: disable=E1120,E0633
+ k_, r0_, r_, tof_, M, np.asarray(prograde), np.asarray(lowpath), numiter, rtol
+ )
return v0 << kms, v << kms
diff --git a/src/hapsira/settings.py b/src/hapsira/settings.py
new file mode 100644
index 000000000..71ba3e024
--- /dev/null
+++ b/src/hapsira/settings.py
@@ -0,0 +1,187 @@
+import os
+from typing import Any, Generator, Optional, Tuple, Type
+
+__all__ = [
+ "Setting",
+ "Settings",
+ "settings",
+]
+
+
+def _str2bool(value: str) -> bool:
+ """
+ Helper for parsing environment variables
+ """
+
+ if value.strip().lower() in ("true", "1", "yes", "y"):
+ return True
+ if value.strip().lower() in ("false", "0", "no", "n"):
+ return False
+
+ raise ValueError(f'can not convert value "{value:s}" to bool')
+
+
+class Setting:
+ """
+ Holds one setting settable by user before sub-module import
+ """
+
+ def __init__(self, name: str, default: Any, options: Optional[Tuple[Any]] = None):
+ self._name = name
+ self._type = type(default)
+ self._value = default
+ self._options = options
+ self._check_env()
+
+ def _check_env(self):
+ """
+ Check for environment variables
+ """
+ value = os.environ.get(f"HAPSIRA_{self._name:s}")
+ if value is None:
+ return
+ if self._type is bool:
+ value = _str2bool(value)
+ self.value = value # Run through setter for checks!
+
+ @property
+ def name(self) -> str:
+ """
+ Return name of setting
+ """
+ return self._name
+
+ @property
+ def type_(self) -> Type:
+ """
+ Return type of setting
+ """
+ return self._type
+
+ @property
+ def options(self) -> Optional[Tuple[Any]]:
+ """
+ Return options for value
+ """
+ return self._options
+
+ @property
+ def value(self) -> Any:
+ """
+ Change value of setting
+ """
+ return self._value
+
+ @value.setter
+ def value(self, new_value: Any):
+ """
+ Return value of setting
+ """
+ if not isinstance(new_value, self._type):
+ raise TypeError(
+ f"{repr(new_value):s} has type {repr(type(new_value)):s}, expected type {repr(self._type):s}"
+ )
+ if self._options is not None and new_value not in self._options:
+ raise ValueError(
+ f"value {repr(new_value):s} not a valid option, valid options are {repr(self._options):s}"
+ )
+ self._value = new_value
+
+
+class Settings:
+ """
+ Holds settings settable by user before sub-module import
+ """
+
+ def __init__(self):
+ self._settings = {}
+ self._add(
+ Setting(
+ "DEBUG",
+ False,
+ )
+ )
+ self._add(
+ Setting(
+ "CACHE",
+ not self["DEBUG"].value,
+ )
+ )
+ self._add(
+ Setting(
+ "LOGLEVEL",
+ "NOTSET",
+ options=("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL", "NOTSET"),
+ )
+ )
+ self._add(
+ Setting(
+ "TARGET",
+ "cpu",
+ options=("cpu", "parallel", "cuda"),
+ )
+ )
+ self._add(
+ Setting(
+ "INLINE",
+ self["TARGET"].value == "cuda",
+ )
+ )
+ self._add(
+ Setting(
+ "NOPYTHON",
+ True,
+ )
+ )
+ self._add(
+ Setting(
+ "FORCEOBJ",
+ False,
+ )
+ ),
+ self._add(
+ Setting(
+ "PRECISION",
+ "f8",
+ options=("f2", "f4", "f8"),
+ )
+ )
+
+ def _add(self, setting: Setting):
+ """
+ Add new setting
+ """
+ self._settings[setting.name] = setting
+
+ def _validate(self, name: str):
+ """
+ Validate name
+ """
+ if name in self._settings.keys():
+ return
+ raise KeyError(
+ f'setting "{name:s}" unknown, possible settings are {repr(list(self._settings.keys())):s}'
+ )
+
+ def __getitem__(self, name: str) -> Setting:
+ """
+ Return setting by name
+ """
+ self._validate(name)
+ return self._settings[name]
+
+ def __setitem__(self, name: str, new_value: Any):
+ """
+ Return setting by name
+ """
+ self._validate(name)
+ self._settings[name].value = new_value
+
+ def keys(self) -> Generator:
+ """
+ Generator of all setting names
+ """
+ return (name for name in self._settings.keys())
+
+
+settings = Settings()
diff --git a/src/hapsira/spheroid_location.py b/src/hapsira/spheroid_location.py
index c1a9f5073..e466f6258 100644
--- a/src/hapsira/spheroid_location.py
+++ b/src/hapsira/spheroid_location.py
@@ -4,7 +4,7 @@
from hapsira.core.spheroid_location import (
N as N_fast,
cartesian_cords as cartesian_cords_fast,
- cartesian_to_ellipsoidal as cartesian_to_ellipsoidal_fast,
+ cartesian_to_ellipsoidal_gf,
distance as distance_fast,
f as f_fast,
is_visible as is_visible_fast,
@@ -145,5 +145,11 @@ def cartesian_to_ellipsoidal(self, x, y, z):
"""
_a, _c = self._a.to_value(u.m), self._c.to_value(u.m)
x, y, z = x.to_value(u.m), y.to_value(u.m), z.to_value(u.m)
- lon, lat, h = cartesian_to_ellipsoidal_fast(_a, _c, x, y, z)
+ lon, lat, h = cartesian_to_ellipsoidal_gf( # pylint: disable=E1120,E0633
+ _a,
+ _c,
+ x,
+ y,
+ z,
+ )
return lon * u.rad, lat * u.rad, h * u.m
diff --git a/src/hapsira/threebody/restricted.py b/src/hapsira/threebody/restricted.py
index 76c5e1d10..fc7dda17b 100644
--- a/src/hapsira/threebody/restricted.py
+++ b/src/hapsira/threebody/restricted.py
@@ -7,7 +7,14 @@
from astropy import units as u
import numpy as np
-from hapsira._math.optimize import brentq
+from hapsira.core.jit import hjit
+from hapsira.core.math.ivp import (
+ brentq_gb,
+ BRENTQ_XTOL,
+ BRENTQ_RTOL,
+ BRENTQ_MAXITER,
+ BRENTQ_CONVERGED,
+)
from hapsira.util import norm
@@ -37,27 +44,39 @@ def lagrange_points(r12, m1, m2):
"""
pi2 = (m2 / (m1 + m2)).value
+ @hjit("f(f)", cache=False)
def eq_L123(xi):
aux = (1 - pi2) * (xi + pi2) / abs(xi + pi2) ** 3
aux += pi2 * (xi + pi2 - 1) / abs(xi + pi2 - 1) ** 3
aux -= xi
return aux
+ brentq_gf = brentq_gb(eq_L123)
+
lp = np.zeros((5,))
# L1
tol = 1e-11 # `brentq` uses a xtol of 2e-12, so it should be covered
a = -pi2 + tol
b = 1 - pi2 - tol
- xi = brentq(eq_L123, a, b)
+ xi, status = brentq_gf( # pylint: disable=E0633,E1120
+ a, b, BRENTQ_XTOL, BRENTQ_RTOL, BRENTQ_MAXITER
+ )
+ assert status == BRENTQ_CONVERGED
lp[0] = xi + pi2
# L2
- xi = brentq(eq_L123, 1, 1.5)
+ xi, status = brentq_gf( # pylint: disable=E0633,E1120
+ 1, 1.5, BRENTQ_XTOL, BRENTQ_RTOL, BRENTQ_MAXITER
+ )
+ assert status == BRENTQ_CONVERGED
lp[1] = xi + pi2
# L3
- xi = brentq(eq_L123, -1.5, -1)
+ xi, status = brentq_gf( # pylint: disable=E0633,E1120
+ -1.5, -1, BRENTQ_XTOL, BRENTQ_RTOL, BRENTQ_MAXITER
+ )
+ assert status == BRENTQ_CONVERGED
lp[2] = xi + pi2
# L4, L5
diff --git a/src/hapsira/twobody/angles.py b/src/hapsira/twobody/angles.py
index a07091be4..1cf38326f 100644
--- a/src/hapsira/twobody/angles.py
+++ b/src/hapsira/twobody/angles.py
@@ -2,19 +2,19 @@
from astropy import units as u
from hapsira.core.angles import (
- D_to_M as D_to_M_fast,
- D_to_nu as D_to_nu_fast,
- E_to_M as E_to_M_fast,
- E_to_nu as E_to_nu_fast,
- F_to_M as F_to_M_fast,
- F_to_nu as F_to_nu_fast,
- M_to_D as M_to_D_fast,
- M_to_E as M_to_E_fast,
- M_to_F as M_to_F_fast,
- fp_angle as fp_angle_fast,
- nu_to_D as nu_to_D_fast,
- nu_to_E as nu_to_E_fast,
- nu_to_F as nu_to_F_fast,
+ D_to_M_vf,
+ D_to_nu_vf,
+ E_to_M_vf,
+ E_to_nu_vf,
+ F_to_M_vf,
+ F_to_nu_vf,
+ M_to_D_vf,
+ M_to_E_vf,
+ M_to_F_vf,
+ fp_angle_vf,
+ nu_to_D_vf,
+ nu_to_E_vf,
+ nu_to_F_vf,
)
@@ -38,7 +38,7 @@ def D_to_nu(D):
"Robust resolution of Kepler’s equation in all eccentricity regimes."
Celestial Mechanics and Dynamical Astronomy 116, no. 1 (2013): 21-34.
"""
- return (D_to_nu_fast(D.to_value(u.rad)) * u.rad).to(D.unit)
+ return (D_to_nu_vf(D.to_value(u.rad)) * u.rad).to(D.unit)
@u.quantity_input(nu=u.rad)
@@ -61,7 +61,7 @@ def nu_to_D(nu):
"Robust resolution of Kepler’s equation in all eccentricity regimes."
Celestial Mechanics and Dynamical Astronomy 116, no. 1 (2013): 21-34.
"""
- return (nu_to_D_fast(nu.to_value(u.rad)) * u.rad).to(nu.unit)
+ return (nu_to_D_vf(nu.to_value(u.rad)) * u.rad).to(nu.unit)
@u.quantity_input(nu=u.rad, ecc=u.one)
@@ -83,7 +83,7 @@ def nu_to_E(nu, ecc):
Eccentric anomaly.
"""
- return (nu_to_E_fast(nu.to_value(u.rad), ecc.value) * u.rad).to(nu.unit)
+ return (nu_to_E_vf(nu.to_value(u.rad), ecc.value) * u.rad).to(nu.unit)
@u.quantity_input(nu=u.rad, ecc=u.one)
@@ -107,7 +107,7 @@ def nu_to_F(nu, ecc):
Taken from Curtis, H. (2013). *Orbital mechanics for engineering students*. 167
"""
- return (nu_to_F_fast(nu.to_value(u.rad), ecc.value) * u.rad).to(nu.unit)
+ return (nu_to_F_vf(nu.to_value(u.rad), ecc.value) * u.rad).to(nu.unit)
@u.quantity_input(E=u.rad, ecc=u.one)
@@ -129,7 +129,7 @@ def E_to_nu(E, ecc):
True anomaly.
"""
- return (E_to_nu_fast(E.to_value(u.rad), ecc.value) * u.rad).to(E.unit)
+ return (E_to_nu_vf(E.to_value(u.rad), ecc.value) * u.rad).to(E.unit)
@u.quantity_input(F=u.rad, ecc=u.one)
@@ -149,7 +149,7 @@ def F_to_nu(F, ecc):
True anomaly.
"""
- return (F_to_nu_fast(F.to_value(u.rad), ecc.value) * u.rad).to(F.unit)
+ return (F_to_nu_vf(F.to_value(u.rad), ecc.value) * u.rad).to(F.unit)
@u.quantity_input(M=u.rad, ecc=u.one)
@@ -171,7 +171,7 @@ def M_to_E(M, ecc):
Eccentric anomaly.
"""
- return (M_to_E_fast(M.to_value(u.rad), ecc.value) * u.rad).to(M.unit)
+ return (M_to_E_vf(M.to_value(u.rad), ecc.value) * u.rad).to(M.unit)
@u.quantity_input(M=u.rad, ecc=u.one)
@@ -191,7 +191,7 @@ def M_to_F(M, ecc):
Hyperbolic eccentric anomaly.
"""
- return (M_to_F_fast(M.to_value(u.rad), ecc.value) * u.rad).to(M.unit)
+ return (M_to_F_vf(M.to_value(u.rad), ecc.value) * u.rad).to(M.unit)
@u.quantity_input(M=u.rad, ecc=u.one)
@@ -209,7 +209,7 @@ def M_to_D(M):
Parabolic eccentric anomaly.
"""
- return (M_to_D_fast(M.to_value(u.rad)) * u.rad).to(M.unit)
+ return (M_to_D_vf(M.to_value(u.rad)) * u.rad).to(M.unit)
@u.quantity_input(E=u.rad, ecc=u.one)
@@ -231,7 +231,7 @@ def E_to_M(E, ecc):
Mean anomaly.
"""
- return (E_to_M_fast(E.to_value(u.rad), ecc.value) * u.rad).to(E.unit)
+ return (E_to_M_vf(E.to_value(u.rad), ecc.value) * u.rad).to(E.unit)
@u.quantity_input(F=u.rad, ecc=u.one)
@@ -251,7 +251,7 @@ def F_to_M(F, ecc):
Mean anomaly.
"""
- return (F_to_M_fast(F.to_value(u.rad), ecc.value) * u.rad).to(F.unit)
+ return (F_to_M_vf(F.to_value(u.rad), ecc.value) * u.rad).to(F.unit)
@u.quantity_input(D=u.rad, ecc=u.one)
@@ -269,7 +269,7 @@ def D_to_M(D):
Mean anomaly.
"""
- return (D_to_M_fast(D.to_value(u.rad)) * u.rad).to(D.unit)
+ return (D_to_M_vf(D.to_value(u.rad)) * u.rad).to(D.unit)
@u.quantity_input(nu=u.rad, ecc=u.one)
@@ -290,4 +290,4 @@ def fp_angle(nu, ecc):
Algorithm taken from Vallado 2007, pp. 113.
"""
- return (fp_angle_fast(nu.to_value(u.rad), ecc.value) * u.rad).to(nu.unit)
+ return (fp_angle_vf(nu.to_value(u.rad), ecc.value) * u.rad).to(nu.unit)
diff --git a/src/hapsira/twobody/elements.py b/src/hapsira/twobody/elements.py
index 7b14f08d3..735c21e1f 100644
--- a/src/hapsira/twobody/elements.py
+++ b/src/hapsira/twobody/elements.py
@@ -2,14 +2,13 @@
import numpy as np
from hapsira.core.elements import (
- circular_velocity as circular_velocity_fast,
- coe2rv as coe2rv_fast,
- coe2rv_many as coe2rv_many_fast,
- eccentricity_vector as eccentricity_vector_fast,
-)
-from hapsira.core.propagation.farnocchia import (
- delta_t_from_nu as delta_t_from_nu_fast,
+ circular_velocity_vf,
+ coe2rv_gf,
+ eccentricity_vector_gf,
+ mean_motion_vf,
+ period_vf,
)
+from hapsira.core.propagation.farnocchia import delta_t_from_nu_vf, FARNOCCHIA_DELTA
u_kms = u.km / u.s
u_km3s2 = u.km**3 / u.s**2
@@ -18,20 +17,19 @@
@u.quantity_input(k=u_km3s2, a=u.km)
def circular_velocity(k, a):
"""Circular velocity for a given body (k) and semimajor axis (a)."""
- return circular_velocity_fast(k.to_value(u_km3s2), a.to_value(u.km)) * u_kms
+ return circular_velocity_vf(k.to_value(u_km3s2), a.to_value(u.km)) * u_kms
@u.quantity_input(k=u_km3s2, a=u.km)
def mean_motion(k, a):
"""Mean motion given body (k) and semimajor axis (a)."""
- return np.sqrt(k / abs(a**3)).to(1 / u.s) * u.rad
+ return mean_motion_vf(k.to_value(u_km3s2), a.to_value(u.km)) * u.rad / u.s
@u.quantity_input(k=u_km3s2, a=u.km)
def period(k, a):
"""Period given body (k) and semimajor axis (a)."""
- n = mean_motion(k, a)
- return 2 * np.pi * u.rad / n
+ return period_vf(k.to_value(u_km3s2), a.to_value(u.km)) * u.s
@u.quantity_input(k=u_km3s2, r=u.km, v=u_kms)
@@ -43,12 +41,10 @@ def energy(k, r, v):
@u.quantity_input(k=u_km3s2, r=u.km, v=u_kms)
def eccentricity_vector(k, r, v):
"""Eccentricity vector."""
- return (
- eccentricity_vector_fast(
- k.to_value(u_km3s2), r.to_value(u.km), v.to_value(u_kms)
- )
- * u.one
+ e = eccentricity_vector_gf( # pylint: disable=E1120
+ k.to_value(u_km3s2), r.to_value(u.km), v.to_value(u_kms)
)
+ return e << u.one
@u.quantity_input(nu=u.rad, ecc=u.one, k=u_km3s2, r_p=u.km)
@@ -56,11 +52,12 @@ def t_p(nu, ecc, k, r_p):
"""Elapsed time since latest perifocal passage."""
# TODO: Make this a propagator method
t_p = (
- delta_t_from_nu_fast(
+ delta_t_from_nu_vf(
nu.to_value(u.rad),
ecc.value,
k.to_value(u_km3s2),
r_p.to_value(u.km),
+ FARNOCCHIA_DELTA,
)
* u.s
)
@@ -193,35 +190,39 @@ def get_eccentricity_critical_inc(ecc=None):
return ecc
-def coe2rv(k, p, ecc, inc, raan, argp, nu):
- rr, vv = coe2rv_fast(
- k.to_value(u_km3s2),
- p.to_value(u.km),
- ecc.to_value(u.one),
- inc.to_value(u.rad),
- raan.to_value(u.rad),
- argp.to_value(u.rad),
- nu.to_value(u.rad),
- )
+def coe2rv(k, p, ecc, inc, raan, argp, nu, rr=None, vv=None):
+ """
+ TODO document function
+
+ Function works on scalars and arrays
+ """
+
+ if rr is None and vv is None:
+ rr, vv = coe2rv_gf( # pylint: disable=E1120,E0633
+ k.to_value(u_km3s2),
+ p.to_value(u.km),
+ ecc.to_value(u.one),
+ inc.to_value(u.rad),
+ raan.to_value(u.rad),
+ argp.to_value(u.rad),
+ nu.to_value(u.rad),
+ np.zeros((3,), dtype="u1"), # dummy
+ )
+ else:
+ coe2rv_gf(
+ k.to_value(u_km3s2),
+ p.to_value(u.km),
+ ecc.to_value(u.one),
+ inc.to_value(u.rad),
+ raan.to_value(u.rad),
+ argp.to_value(u.rad),
+ nu.to_value(u.rad),
+ np.zeros((3,), dtype="u1"), # dummy
+ rr,
+ vv,
+ )
rr = rr << u.km
vv = vv << (u.km / u.s)
return rr, vv
-
-
-def coe2rv_many(k_arr, p_arr, ecc_arr, inc_arr, raan_arr, argp_arr, nu_arr):
- rr_arr, vv_arr = coe2rv_many_fast(
- k_arr.to_value(u_km3s2),
- p_arr.to_value(u.km),
- ecc_arr.to_value(u.one),
- inc_arr.to_value(u.rad),
- raan_arr.to_value(u.rad),
- argp_arr.to_value(u.rad),
- nu_arr.to_value(u.rad),
- )
-
- rr_arr = rr_arr << u.km
- vv_arr = vv_arr << (u.km / u.s)
-
- return rr_arr, vv_arr
diff --git a/src/hapsira/twobody/events.py b/src/hapsira/twobody/events.py
index e0604b934..6b60ccdb0 100644
--- a/src/hapsira/twobody/events.py
+++ b/src/hapsira/twobody/events.py
@@ -1,20 +1,36 @@
-from warnings import warn
+from abc import ABC, abstractmethod
+from math import degrees as rad2deg
+from typing import Callable
from astropy import units as u
from astropy.coordinates import get_body_barycentric_posvel
-import numpy as np
-from hapsira._math.linalg import norm
+from hapsira.core.jit import hjit
+from hapsira.core.math.ivp import dop853_dense_interp_brentq_hb
+from hapsira.core.math.linalg import mul_Vs_hf, norm_V_hf
from hapsira.core.events import (
- eclipse_function as eclipse_function_fast,
- line_of_sight as line_of_sight_fast,
-)
-from hapsira.core.spheroid_location import (
- cartesian_to_ellipsoidal as cartesian_to_ellipsoidal_fast,
+ eclipse_function_hf,
+ line_of_sight_hf,
)
+from hapsira.core.math.interpolate import interp_hb
+from hapsira.core.spheroid_location import cartesian_to_ellipsoidal_hf
+from hapsira.util import time_range
+
+
+__all__ = [
+ "BaseEvent",
+ "AltitudeCrossEvent",
+ "LithobrakeEvent",
+ "LatitudeCrossEvent",
+ "BaseEclipseEvent",
+ "PenumbraEvent",
+ "UmbraEvent",
+ "NodeCrossEvent",
+ "LosEvent",
+]
-class Event:
+class BaseEvent(ABC):
"""Base class for event functionalities.
Parameters
@@ -26,9 +42,12 @@ class Event:
"""
+ @abstractmethod
def __init__(self, terminal, direction):
self._terminal, self._direction = terminal, direction
self._last_t = None
+ self._impl_hf = None
+ self._impl_dense_hf = None
@property
def terminal(self):
@@ -42,17 +61,33 @@ def direction(self):
def last_t(self):
return self._last_t << u.s
- def __call__(self, t, u, k):
- raise NotImplementedError
+ @property
+ def last_t_raw(self) -> float:
+ return self._last_t
+
+ @last_t_raw.setter
+ def last_t_raw(self, value: float):
+ self._last_t = value
+
+ @property
+ def impl_hf(self) -> Callable:
+ return self._impl_hf
+
+ @property
+ def impl_dense_hf(self) -> Callable:
+ return self._impl_dense_hf
+ def _wrap(self):
+ self._impl_dense_hf = dop853_dense_interp_brentq_hb(self._impl_hf)
-class AltitudeCrossEvent(Event):
+
+class AltitudeCrossEvent(BaseEvent):
"""Detect if a satellite crosses a specific threshold altitude.
Parameters
----------
alt: float
- Threshold altitude (km).
+ Threshold altitude from the ground (km).
R: float
Radius of the attractor (km).
terminal: bool
@@ -66,16 +101,16 @@ class AltitudeCrossEvent(Event):
def __init__(self, alt, R, terminal=True, direction=-1):
super().__init__(terminal, direction)
- self._R = R
- self._alt = alt # Threshold altitude from the ground.
- def __call__(self, t, u, k):
- self._last_t = t
- r_norm = norm(u[:3])
+ @hjit("f(f,V,V,f)", cache=False)
+ def impl_hf(t, rr, vv, k):
+ r_norm = norm_V_hf(rr)
+ return (
+ r_norm - R - alt
+ ) # If this goes from +ve to -ve, altitude is decreasing.
- return (
- r_norm - self._R - self._alt
- ) # If this goes from +ve to -ve, altitude is decreasing.
+ self._impl_hf = impl_hf
+ self._wrap()
class LithobrakeEvent(AltitudeCrossEvent):
@@ -94,7 +129,7 @@ def __init__(self, R, terminal=True):
super().__init__(0, R, terminal, direction=-1)
-class LatitudeCrossEvent(Event):
+class LatitudeCrossEvent(BaseEvent):
"""Detect if a satellite crosses a specific threshold latitude.
Parameters
@@ -114,26 +149,31 @@ class LatitudeCrossEvent(Event):
def __init__(self, orbit, lat, terminal=False, direction=0):
super().__init__(terminal, direction)
- self._R = orbit.attractor.R.to_value(u.m)
- self._R_polar = orbit.attractor.R_polar.to_value(u.m)
- self._epoch = orbit.epoch
- self._lat = lat.to_value(u.deg) # Threshold latitude (in degrees).
+ R = orbit.attractor.R.to_value(u.m)
+ R_polar = orbit.attractor.R_polar.to_value(u.m)
+ lat = lat.to_value(u.deg) # Threshold latitude (in degrees).
- def __call__(self, t, u_, k):
- self._last_t = t
- pos_on_body = (u_[:3] / norm(u_[:3])) * self._R
- _, lat_, _ = cartesian_to_ellipsoidal_fast(self._R, self._R_polar, *pos_on_body)
+ @hjit("f(f,V,V,f)", cache=False)
+ def impl_hf(t, rr, vv, k):
+ pos_on_body = mul_Vs_hf(rr, R / norm_V_hf(rr))
+ _, lat_, _ = cartesian_to_ellipsoidal_hf(R, R_polar, *pos_on_body)
+ return rad2deg(lat_) - lat
- return np.rad2deg(lat_) - self._lat
+ self._impl_hf = impl_hf
+ self._wrap()
-class EclipseEvent(Event):
+class BaseEclipseEvent(BaseEvent):
"""Base class for the eclipse event.
Parameters
----------
orbit: hapsira.twobody.orbit.Orbit
Orbit of the satellite.
+ tof: ~astropy.units.Quantity
+ Maximum time of flight for interpolator
+ steps: int
+ Steps for interpolator
terminal: bool, optional
Whether to terminate integration when the event occurs, defaults to False.
direction: float, optional
@@ -141,28 +181,27 @@ class EclipseEvent(Event):
"""
- def __init__(self, orbit, terminal=False, direction=0):
+ def __init__(self, orbit, tof, steps=50, terminal=False, direction=0):
super().__init__(terminal, direction)
- self._primary_body = orbit.attractor
- self._secondary_body = orbit.attractor.parent
- self._epoch = orbit.epoch
- self.k = self._primary_body.k.to_value(u.km**3 / u.s**2)
- self.R_sec = self._secondary_body.R.to_value(u.km)
- self.R_primary = self._primary_body.R.to_value(u.km)
-
- def __call__(self, t, u_, k):
- # Solve for primary and secondary bodies position w.r.t. solar system
- # barycenter at a particular epoch.
- (r_primary_wrt_ssb, _), (r_secondary_wrt_ssb, _) = (
- get_body_barycentric_posvel(body.name, self._epoch + t * u.s)
- for body in (self._primary_body, self._secondary_body)
- )
- r_sec = ((r_secondary_wrt_ssb - r_primary_wrt_ssb).xyz << u.km).value
+ primary_body = orbit.attractor
+ secondary_body = orbit.attractor.parent
+ epoch = orbit.epoch
+
+ self._R_sec = secondary_body.R.to_value(u.km)
+ self._R_primary = primary_body.R.to_value(u.km)
- return r_sec
+ epochs = time_range(start=epoch, end=epoch + tof, num_values=steps)
+ r_primary_wrt_ssb, _ = get_body_barycentric_posvel(primary_body.name, epochs)
+ r_secondary_wrt_ssb, _ = get_body_barycentric_posvel(
+ secondary_body.name, epochs
+ )
+ self._r_sec_hf = interp_hb(
+ (epochs - epoch).to_value(u.s),
+ (r_secondary_wrt_ssb - r_primary_wrt_ssb).xyz.to_value(u.km),
+ )
-class PenumbraEvent(EclipseEvent):
+class PenumbraEvent(BaseEclipseEvent):
"""Detect whether a satellite is in penumbra or not.
Parameters
@@ -177,26 +216,32 @@ class PenumbraEvent(EclipseEvent):
"""
- def __init__(self, orbit, terminal=False, direction=0):
- super().__init__(orbit, terminal, direction)
-
- def __call__(self, t, u_, k):
- self._last_t = t
-
- r_sec = super().__call__(t, u_, k)
- shadow_function = eclipse_function_fast(
- self.k,
- u_,
- r_sec,
- self.R_sec,
- self.R_primary,
- umbra=False,
- )
+ def __init__(self, orbit, tof, steps=50, terminal=False, direction=0):
+ super().__init__(orbit, tof, steps, terminal, direction)
+
+ R_sec = self._R_sec
+ R_primary = self._R_primary
+ r_sec_hf = self._r_sec_hf
+
+ @hjit("f(f,V,V,f)", cache=False)
+ def impl_hf(t, rr, vv, k):
+ r_sec = r_sec_hf(t)
+ shadow_function = eclipse_function_hf(
+ k,
+ rr,
+ vv,
+ r_sec,
+ R_sec,
+ R_primary,
+ False,
+ )
+ return shadow_function
- return shadow_function
+ self._impl_hf = impl_hf
+ self._wrap()
-class UmbraEvent(EclipseEvent):
+class UmbraEvent(BaseEclipseEvent):
"""Detect whether a satellite is in umbra or not.
Parameters
@@ -211,21 +256,32 @@ class UmbraEvent(EclipseEvent):
"""
- def __init__(self, orbit, terminal=False, direction=0):
- super().__init__(orbit, terminal, direction)
-
- def __call__(self, t, u_, k):
- self._last_t = t
-
- r_sec = super().__call__(t, u_, k)
- shadow_function = eclipse_function_fast(
- self.k, u_, r_sec, self.R_sec, self.R_primary
- )
+ def __init__(self, orbit, tof, steps=50, terminal=False, direction=0):
+ super().__init__(orbit, tof, steps, terminal, direction)
+
+ R_sec = self._R_sec
+ R_primary = self._R_primary
+ r_sec_hf = self._r_sec_hf
+
+ @hjit("f(f,V,V,f)", cache=False)
+ def impl_hf(t, rr, vv, k):
+ r_sec = r_sec_hf(t)
+ shadow_function = eclipse_function_hf(
+ k,
+ rr,
+ vv,
+ r_sec,
+ R_sec,
+ R_primary,
+ True,
+ )
+ return shadow_function
- return shadow_function
+ self._impl_hf = impl_hf
+ self._wrap()
-class NodeCrossEvent(Event):
+class NodeCrossEvent(BaseEvent):
"""Detect equatorial node (ascending or descending) crossings.
Parameters
@@ -242,13 +298,16 @@ class NodeCrossEvent(Event):
def __init__(self, terminal=False, direction=0):
super().__init__(terminal, direction)
- def __call__(self, t, u_, k):
- self._last_t = t
- # Check if the z coordinate of the satellite is zero.
- return u_[2]
+ @hjit("f(f,V,V,f)", cache=False)
+ def impl_hf(t, rr, vv, k):
+ # Check if the z coordinate of the satellite is zero.
+ return rr[2]
+
+ self._impl_hf = impl_hf
+ self._wrap()
-class LosEvent(Event):
+class LosEvent(BaseEvent):
"""Detect whether there exists a LOS between two satellites.
Parameters
@@ -261,25 +320,25 @@ class LosEvent(Event):
"""
- def __init__(self, attractor, pos_coords, terminal=False, direction=0):
+ def __init__(self, attractor, tofs, secondary_rr, terminal=False, direction=0):
super().__init__(terminal, direction)
- self._attractor = attractor
- self._pos_coords = (pos_coords << u.km).value.tolist()
- self._last_coord = (
- self._pos_coords[-1] << u.km
- ).value # Used to prevent any errors if `self._pos_coords` gets exhausted early.
- self._R = self._attractor.R.to_value(u.km)
-
- def __call__(self, t, u_, k):
- self._last_t = t
-
- if norm(u_[:3]) < self._R:
- warn(
- "The norm of the position vector of the primary body is less than the radius of the attractor."
+ secondary_hf = interp_hb(tofs.to_value(u.s), secondary_rr.to_value(u.km))
+ R = attractor.R.to_value(u.km)
+
+ @hjit("f(f,V,V,f)", cache=False)
+ def impl_hf(t, rr, vv, k):
+ # Can currently not warn due to: https://github.com/numba/numba/issues/1243
+ # TODO Matching test deactivated ...
+ # if norm_V_hf(rr) < R:
+ # warn(
+ # "The norm of the position vector of the primary body is less than the radius of the attractor."
+ # )
+ delta_angle = line_of_sight_hf(
+ rr,
+ secondary_hf(t),
+ R,
)
+ return delta_angle
- pos_coord = self._pos_coords.pop(0) if self._pos_coords else self._last_coord
-
- # Need to cast `pos_coord` to array since `norm` inside numba only works for arrays, not lists.
- delta_angle = line_of_sight_fast(u_[:3], np.array(pos_coord), self._R)
- return delta_angle
+ self._impl_hf = impl_hf
+ self._wrap()
diff --git a/src/hapsira/twobody/orbit/scalar.py b/src/hapsira/twobody/orbit/scalar.py
index c7e27949a..fb4b115fe 100644
--- a/src/hapsira/twobody/orbit/scalar.py
+++ b/src/hapsira/twobody/orbit/scalar.py
@@ -11,7 +11,7 @@
import numpy as np
from hapsira.bodies import Earth
-from hapsira.core.events import elevation_function as elevation_function_fast
+from hapsira.core.events import elevation_function_gf
from hapsira.frames.util import get_frame
from hapsira.threebody.soi import laplace_radius
from hapsira.twobody.elements import eccentricity_vector, energy, t_p
@@ -682,13 +682,8 @@ def elevation(self, lat, theta, h):
"Elevation implementation is currently only supported for orbits having Earth as the attractor."
)
- x, y, z = self.r.to_value(u.km)
- vx, vy, vz = self.v.to_value(u.km / u.s)
- u_ = np.array([x, y, z, vx, vy, vz])
-
- elevation = elevation_function_fast(
- self.attractor.k.to_value(u.km**3 / u.s**2),
- u_,
+ elevation = elevation_function_gf( # pylint: disable=E1120
+ self.r.to_value(u.km),
lat.to_value(u.rad),
theta.to_value(u.rad),
self.attractor.R.to(u.km).value,
diff --git a/src/hapsira/twobody/propagation/cowell.py b/src/hapsira/twobody/propagation/cowell.py
index 693369ecb..62b44fd2b 100644
--- a/src/hapsira/twobody/propagation/cowell.py
+++ b/src/hapsira/twobody/propagation/cowell.py
@@ -1,9 +1,17 @@
import sys
from astropy import units as u
-
-from hapsira.core.propagation import cowell
-from hapsira.core.propagation.base import func_twobody
+import numpy as np
+
+from hapsira.core.math.ieee754 import float_
+from hapsira.core.propagation.cowell import (
+ cowell_gb,
+ SOLVE_FINISHED,
+ SOLVE_TERMINATED,
+ SOLVE_BRENTQFAILED,
+ SOLVE_FAILED,
+)
+from hapsira.core.propagation.base import func_twobody_hf
from hapsira.twobody.propagation.enums import PropagatorKind
from hapsira.twobody.states import RVState
@@ -27,44 +35,88 @@ class CowellPropagator:
PropagatorKind.ELLIPTIC | PropagatorKind.PARABOLIC | PropagatorKind.HYPERBOLIC
)
- def __init__(self, rtol=1e-11, events=None, f=func_twobody):
+ def __init__(self, rtol=1e-11, atol=1e-12, events=tuple(), f=func_twobody_hf):
self._rtol = rtol
+ self._atol = atol
self._events = events
- self._f = f
+ self._terminals = np.array([event.terminal for event in events], dtype=bool)
+ self._directions = np.array([event.direction for event in events], dtype=float_)
+ self._cowell_gf = cowell_gb(events=events, func=f)
def propagate(self, state, tof):
state = state.to_vectors()
tofs = tof.reshape(-1)
-
- rrs, vvs = cowell(
- state.attractor.k.to_value(u.km**3 / u.s**2),
- *state.to_value(),
- tofs.to_value(u.s),
- self._rtol,
- events=self._events,
- f=self._f,
+ # TODO make sure tofs is sorted
+
+ r0, v0 = state.to_value()
+ ( # pylint: disable=E0633,E1120
+ _,
+ _,
+ _,
+ last_ts,
+ status,
+ t_idx,
+ rrs,
+ vvs,
+ ) = self._cowell_gf( # pylint: disable=E0633,E1120
+ tofs.to_value(u.s), # tofs
+ r0, # rr
+ v0, # vv
+ state.attractor.k.to_value(u.km**3 / u.s**2), # argk
+ self._rtol, # rtol
+ self._atol, # atol
+ self._terminals, # event_terminals
+ self._directions, # event_directions
)
- r = rrs[-1] << u.km
- v = vvs[-1] << (u.km / u.s)
+
+ assert np.all((status != SOLVE_FAILED))
+ assert np.all((status != SOLVE_BRENTQFAILED))
+ assert np.all((status == SOLVE_FINISHED) | (status == SOLVE_TERMINATED))
+
+ for last_t, event in zip(last_ts, self._events):
+ event.last_t_raw = last_t
+
+ r = rrs[t_idx - 1] << u.km
+ v = vvs[t_idx - 1] << (u.km / u.s)
new_state = RVState(state.attractor, (r, v), state.plane)
return new_state
def propagate_many(self, state, tofs):
state = state.to_vectors()
-
- rrs, vvs = cowell(
- state.attractor.k.to_value(u.km**3 / u.s**2),
- *state.to_value(),
- tofs.to_value(u.s),
- self._rtol,
- events=self._events,
- f=self._f,
+ # TODO make sure tofs is sorted
+
+ r0, v0 = state.to_value()
+ ( # pylint: disable=E0633,E1120
+ _,
+ _,
+ _,
+ last_ts,
+ status,
+ t_idx,
+ rrs,
+ vvs,
+ ) = self._cowell_gf( # pylint: disable=E0633,E1120
+ tofs.to_value(u.s), # tofs
+ r0, # rr
+ v0, # vv
+ state.attractor.k.to_value(u.km**3 / u.s**2), # argk
+ self._rtol, # rtol
+ self._atol, # atol
+ self._terminals, # event_terminals
+ self._directions, # event_directions
)
+ assert np.all((status != SOLVE_FAILED))
+ assert np.all((status != SOLVE_BRENTQFAILED))
+ assert np.all((status == SOLVE_FINISHED) | (status == SOLVE_TERMINATED))
+
+ for last_t, event in zip(last_ts, self._events):
+ event.last_t_raw = last_t
+
# TODO: This should probably return a RVStateArray instead,
- # see discussion at https://github.com/hapsira/hapsira/pull/1492
+ # see discussion at https://github.com/poliastro/poliastro/pull/1492
return (
- rrs << u.km,
- vvs << (u.km / u.s),
+ rrs[:t_idx, :] << u.km,
+ vvs[:t_idx, :] << (u.km / u.s),
)
diff --git a/src/hapsira/twobody/propagation/danby.py b/src/hapsira/twobody/propagation/danby.py
index e8c4c3d43..1d4e75458 100644
--- a/src/hapsira/twobody/propagation/danby.py
+++ b/src/hapsira/twobody/propagation/danby.py
@@ -2,7 +2,7 @@
from astropy import units as u
-from hapsira.core.propagation import danby_coe as danby_fast
+from hapsira.core.propagation.danby import danby_coe_vf, DANBY_NUMITER, DANBY_RTOL
from hapsira.twobody.propagation.enums import PropagatorKind
from hapsira.twobody.states import ClassicalState
@@ -27,10 +27,12 @@ def propagate(self, state, tof):
state = state.to_classical()
nu = (
- danby_fast(
+ danby_coe_vf(
state.attractor.k.to_value(u.km**3 / u.s**2),
*state.to_value(),
tof.to_value(u.s),
+ DANBY_NUMITER,
+ DANBY_RTOL,
)
<< u.rad
)
diff --git a/src/hapsira/twobody/propagation/farnocchia.py b/src/hapsira/twobody/propagation/farnocchia.py
index 4b26f57d6..dbad0b863 100644
--- a/src/hapsira/twobody/propagation/farnocchia.py
+++ b/src/hapsira/twobody/propagation/farnocchia.py
@@ -1,11 +1,10 @@
import sys
from astropy import units as u
-import numpy as np
from hapsira.core.propagation.farnocchia import (
- farnocchia_coe as farnocchia_coe_fast,
- farnocchia_rv as farnocchia_rv_fast,
+ farnocchia_coe_vf,
+ farnocchia_rv_gf,
)
from hapsira.twobody.propagation.enums import PropagatorKind
from hapsira.twobody.states import ClassicalState
@@ -34,7 +33,7 @@ def propagate(self, state, tof):
state = state.to_classical()
nu = (
- farnocchia_coe_fast(
+ farnocchia_coe_vf(
state.attractor.k.to_value(u.km**3 / u.s**2),
*state.to_value(),
tof.to_value(u.s),
@@ -53,11 +52,10 @@ def propagate_many(self, state, tofs):
rv0 = state.to_value()
# TODO: This should probably return a ClassicalStateArray instead,
- # see discussion at https://github.com/hapsira/hapsira/pull/1492
- results = np.array(
- [farnocchia_rv_fast(k, *rv0, tof) for tof in tofs.to_value(u.s)]
- )
+ # see discussion at https://github.com/poliastro/poliastro/pull/1492
+ rr, vv = farnocchia_rv_gf(k, *rv0, tofs.to_value(u.s)) # pylint: disable=E0633
+
return (
- results[:, 0] << u.km,
- results[:, 1] << (u.km / u.s),
+ rr << u.km,
+ vv << (u.km / u.s),
)
diff --git a/src/hapsira/twobody/propagation/gooding.py b/src/hapsira/twobody/propagation/gooding.py
index 3c0da8fcd..11f9eed75 100644
--- a/src/hapsira/twobody/propagation/gooding.py
+++ b/src/hapsira/twobody/propagation/gooding.py
@@ -2,7 +2,11 @@
from astropy import units as u
-from hapsira.core.propagation import gooding_coe as gooding_fast
+from hapsira.core.propagation.gooding import (
+ gooding_coe_vf,
+ GOODING_RTOL,
+ GOODING_NUMITER,
+)
from hapsira.twobody.propagation.enums import PropagatorKind
from hapsira.twobody.states import ClassicalState
@@ -32,10 +36,12 @@ def propagate(self, state, tof):
state = state.to_classical()
nu = (
- gooding_fast(
+ gooding_coe_vf(
state.attractor.k.to_value(u.km**3 / u.s**2),
*state.to_value(),
tof.to_value(u.s),
+ GOODING_NUMITER,
+ GOODING_RTOL,
)
<< u.rad
)
diff --git a/src/hapsira/twobody/propagation/markley.py b/src/hapsira/twobody/propagation/markley.py
index ed99efec2..530ac7a4e 100644
--- a/src/hapsira/twobody/propagation/markley.py
+++ b/src/hapsira/twobody/propagation/markley.py
@@ -2,7 +2,7 @@
from astropy import units as u
-from hapsira.core.propagation import markley_coe as markley_fast
+from hapsira.core.propagation.markley import markley_coe_vf
from hapsira.twobody.propagation.enums import PropagatorKind
from hapsira.twobody.states import ClassicalState
@@ -28,7 +28,7 @@ def propagate(self, state, tof):
state = state.to_classical()
nu = (
- markley_fast(
+ markley_coe_vf(
state.attractor.k.to_value(u.km**3 / u.s**2),
*state.to_value(),
tof.to_value(u.s),
diff --git a/src/hapsira/twobody/propagation/mikkola.py b/src/hapsira/twobody/propagation/mikkola.py
index 1f5233af7..a0f999fb8 100644
--- a/src/hapsira/twobody/propagation/mikkola.py
+++ b/src/hapsira/twobody/propagation/mikkola.py
@@ -2,7 +2,7 @@
from astropy import units as u
-from hapsira.core.propagation import mikkola_coe as mikkola_fast
+from hapsira.core.propagation.mikkola import mikkola_coe_vf
from hapsira.twobody.propagation.enums import PropagatorKind
from hapsira.twobody.states import ClassicalState
@@ -27,7 +27,7 @@ def propagate(self, state, tof):
state = state.to_classical()
nu = (
- mikkola_fast(
+ mikkola_coe_vf(
state.attractor.k.to_value(u.km**3 / u.s**2),
*state.to_value(),
tof.to_value(u.s),
diff --git a/src/hapsira/twobody/propagation/pimienta.py b/src/hapsira/twobody/propagation/pimienta.py
index 0e21908db..0a61b1735 100644
--- a/src/hapsira/twobody/propagation/pimienta.py
+++ b/src/hapsira/twobody/propagation/pimienta.py
@@ -2,7 +2,7 @@
from astropy import units as u
-from hapsira.core.propagation import pimienta_coe as pimienta_fast
+from hapsira.core.propagation.pimienta import pimienta_coe_vf
from hapsira.twobody.propagation.enums import PropagatorKind
from hapsira.twobody.states import ClassicalState
@@ -30,7 +30,7 @@ def propagate(self, state, tof):
state = state.to_classical()
nu = (
- pimienta_fast(
+ pimienta_coe_vf(
state.attractor.k.to_value(u.km**3 / u.s**2),
*state.to_value(),
tof.to_value(u.s),
diff --git a/src/hapsira/twobody/propagation/recseries.py b/src/hapsira/twobody/propagation/recseries.py
index cff5a0c23..5536152b8 100644
--- a/src/hapsira/twobody/propagation/recseries.py
+++ b/src/hapsira/twobody/propagation/recseries.py
@@ -2,7 +2,13 @@
from astropy import units as u
-from hapsira.core.propagation import recseries_coe as recseries_fast
+from hapsira.core.propagation.recseries import (
+ recseries_coe_vf,
+ RECSERIES_METHOD_RTOL,
+ RECSERIES_ORDER,
+ RECSERIES_NUMITER,
+ RECSERIES_RTOL,
+)
from hapsira.twobody.propagation.enums import PropagatorKind
from hapsira.twobody.states import ClassicalState
@@ -29,10 +35,10 @@ class RecseriesPropagator:
def __init__(
self,
- method="rtol",
- order=8,
- numiter=100,
- rtol=1e-8,
+ method=RECSERIES_METHOD_RTOL,
+ order=RECSERIES_ORDER,
+ numiter=RECSERIES_NUMITER,
+ rtol=RECSERIES_RTOL,
):
self._method = method
self._order = order
@@ -43,14 +49,14 @@ def propagate(self, state, tof):
state = state.to_classical()
nu = (
- recseries_fast(
+ recseries_coe_vf(
state.attractor.k.to_value(u.km**3 / u.s**2),
*state.to_value(),
tof.to_value(u.s),
- method=self._method,
- order=self._order,
- numiter=self._numiter,
- rtol=self._rtol,
+ self._method,
+ self._order,
+ self._numiter,
+ self._rtol,
)
<< u.rad
)
diff --git a/src/hapsira/twobody/propagation/vallado.py b/src/hapsira/twobody/propagation/vallado.py
index 06d203929..2fb197117 100644
--- a/src/hapsira/twobody/propagation/vallado.py
+++ b/src/hapsira/twobody/propagation/vallado.py
@@ -1,9 +1,8 @@
import sys
from astropy import units as u
-import numpy as np
-from hapsira.core.propagation import vallado as vallado_fast
+from hapsira.core.propagation.vallado import vallado_rv_gf, VALLADO_NUMITER
from hapsira.twobody.propagation.enums import PropagatorKind
from hapsira.twobody.states import RVState
@@ -12,21 +11,6 @@
sys.modules[__name__].__class__ = OldPropagatorModule
-def vallado(k, r0, v0, tof, *, numiter):
- # Compute Lagrange coefficients
- f, g, fdot, gdot = vallado_fast(k, r0, v0, tof, numiter)
-
- assert (
- np.abs(f * gdot - fdot * g - 1) < 1e-5
- ), "Internal error, solution is not consistent" # Fixed tolerance
-
- # Return position and velocity vectors
- r = f * r0 + g * v0
- v = fdot * r0 + gdot * v0
-
- return r, v
-
-
class ValladoPropagator:
"""Propagates Keplerian orbit using Vallado's method.
@@ -43,17 +27,17 @@ class ValladoPropagator:
PropagatorKind.ELLIPTIC | PropagatorKind.PARABOLIC | PropagatorKind.HYPERBOLIC
)
- def __init__(self, numiter=350):
+ def __init__(self, numiter=VALLADO_NUMITER):
self._numiter = numiter
def propagate(self, state, tof):
state = state.to_vectors()
- r_raw, v_raw = vallado(
+ r_raw, v_raw = vallado_rv_gf(
state.attractor.k.to_value(u.km**3 / u.s**2),
*state.to_value(),
tof.to_value(u.s),
- numiter=self._numiter,
+ self._numiter,
)
r = r_raw << u.km
v = v_raw << (u.km / u.s)
diff --git a/src/hapsira/twobody/sampling.py b/src/hapsira/twobody/sampling.py
index 89f4bcf73..fde8eb20b 100644
--- a/src/hapsira/twobody/sampling.py
+++ b/src/hapsira/twobody/sampling.py
@@ -3,7 +3,7 @@
import numpy as np
from hapsira.twobody.angles import E_to_nu, nu_to_E
-from hapsira.twobody.elements import coe2rv_many, hyp_nu_limit, t_p
+from hapsira.twobody.elements import coe2rv, hyp_nu_limit, t_p
from hapsira.twobody.propagation import FarnocchiaPropagator
from hapsira.util import alinspace, wrap_angle
@@ -91,7 +91,7 @@ def sample(self, orbit):
# However, we are also returning the epochs
# (since computing them here is more efficient than doing it from the outside)
# but there are open questions around StateArrays and epochs.
- # See discussion at https://github.com/hapsira/hapsira/pull/1492
+ # See discussion at https://github.com/poliastro/poliastro/pull/1492
cartesian = CartesianRepresentation(
rr, differentials=CartesianDifferential(vv, xyz_axis=1), xyz_axis=1
)
@@ -144,7 +144,7 @@ def sample(self, orbit):
epochs = orbit.epoch + (delta_ts - orbit.t_p)
n = nu_values.shape[0]
- rr, vv = coe2rv_many(
+ rr, vv = coe2rv(
np.tile(orbit.attractor.k, n),
np.tile(orbit.p, n),
np.tile(orbit.ecc, n),
@@ -160,7 +160,7 @@ def sample(self, orbit):
# However, we are also returning the epochs
# (since computing them here is more efficient than doing it from the outside)
# but there are open questions around StateArrays and epochs.
- # See discussion at https://github.com/hapsira/hapsira/pull/1492
+ # See discussion at https://github.com/poliastro/poliastro/pull/1492
cartesian = CartesianRepresentation(
rr, differentials=CartesianDifferential(vv, xyz_axis=1), xyz_axis=1
)
diff --git a/src/hapsira/twobody/states.py b/src/hapsira/twobody/states.py
index 53899ca37..f80e65e00 100644
--- a/src/hapsira/twobody/states.py
+++ b/src/hapsira/twobody/states.py
@@ -1,9 +1,22 @@
from functools import cached_property
from astropy import units as u
+import numpy as np
-from hapsira.core.elements import coe2mee, coe2rv, mee2coe, mee2rv, rv2coe
-from hapsira.twobody.elements import mean_motion, period, t_p
+from hapsira.core.elements import (
+ coe2mee_gf,
+ coe2rv_gf,
+ mee2coe_gf,
+ mee2rv_gf,
+ rv2coe_gf,
+ RV2COE_TOL,
+ mean_motion_vf,
+ period_vf,
+)
+from hapsira.core.propagation.farnocchia import delta_t_from_nu_vf, FARNOCCHIA_DELTA
+
+
+u_km3s2 = u.km**3 / u.s**2
class BaseState:
@@ -39,12 +52,25 @@ def attractor(self):
@cached_property
def n(self):
"""Mean motion."""
- return mean_motion(self.attractor.k, self.to_classical().a)
+ return (
+ mean_motion_vf(
+ self.attractor.k.to_value(u_km3s2),
+ self.to_classical().a.to_value(u.km),
+ )
+ * u.rad
+ / u.s
+ )
@cached_property
def period(self):
"""Period of the orbit."""
- return period(self.attractor.k, self.to_classical().a)
+ return (
+ period_vf(
+ self.attractor.k.to_value(u_km3s2),
+ self.to_classical().a.to_value(u.km),
+ )
+ * u.s
+ )
@cached_property
def r_p(self):
@@ -59,11 +85,16 @@ def r_a(self):
@cached_property
def t_p(self):
"""Elapsed time since latest perifocal passage."""
- return t_p(
- self.to_classical().nu,
- self.to_classical().ecc,
- self.attractor.k,
- self.r_p,
+ self_classical = self.to_classical()
+ return (
+ delta_t_from_nu_vf(
+ self_classical.nu.to_value(u.rad),
+ self_classical.ecc.value,
+ self.attractor.k.to_value(u_km3s2),
+ self.r_p.to_value(u.km),
+ FARNOCCHIA_DELTA,
+ )
+ * u.s
)
def to_tuple(self):
@@ -171,7 +202,17 @@ def to_value(self):
def to_vectors(self):
"""Converts to position and velocity vector representation."""
- r, v = coe2rv(self.attractor.k.to_value(u.km**3 / u.s**2), *self.to_value())
+
+ r = np.zeros(self.attractor.k.shape + (3,), dtype=self.attractor.k.dtype)
+ v = np.zeros(self.attractor.k.shape + (3,), dtype=self.attractor.k.dtype)
+
+ coe2rv_gf(
+ self.attractor.k.to_value(u.km**3 / u.s**2),
+ *self.to_value(),
+ np.zeros((3,), dtype="u1"), # dummy
+ r,
+ v,
+ )
return RVState(self.attractor, (r << u.km, v << u.km / u.s), self.plane)
@@ -181,7 +222,8 @@ def to_classical(self):
def to_equinoctial(self):
"""Converts to modified equinoctial elements representation."""
- p, f, g, h, k, L = coe2mee(*self.to_value())
+
+ p, f, g, h, k, L = coe2mee_gf(*self.to_value()) # pylint: disable=E1120,E0633
return ModifiedEquinoctialState(
self.attractor,
@@ -231,9 +273,10 @@ def to_vectors(self):
def to_classical(self):
"""Converts to classical orbital elements representation."""
- (p, ecc, inc, raan, argp, nu) = rv2coe(
+ (p, ecc, inc, raan, argp, nu) = rv2coe_gf( # pylint: disable=E1120,E0633
self.attractor.k.to_value(u.km**3 / u.s**2),
*self.to_value(),
+ RV2COE_TOL,
)
return ClassicalState(
@@ -312,7 +355,9 @@ def to_value(self):
def to_classical(self):
"""Converts to classical orbital elements representation."""
- p, ecc, inc, raan, argp, nu = mee2coe(*self.to_value())
+ p, ecc, inc, raan, argp, nu = mee2coe_gf( # pylint: disable=E1120,E0633
+ *self.to_value()
+ )
return ClassicalState(
self.attractor,
@@ -329,5 +374,7 @@ def to_classical(self):
def to_vectors(self):
"""Converts to position and velocity vector representation."""
- r, v = mee2rv(*self.to_value())
+ r, v = mee2rv_gf( # pylint: disable=E1120,E0633
+ *self.to_value(), np.zeros((3,), dtype="u1")
+ )
return RVState(self.attractor, (r << u.km, v << u.km / u.s), self.plane)
diff --git a/src/hapsira/twobody/thrust/__init__.py b/src/hapsira/twobody/thrust/__init__.py
index 924fcb10a..f91438807 100644
--- a/src/hapsira/twobody/thrust/__init__.py
+++ b/src/hapsira/twobody/thrust/__init__.py
@@ -1,9 +1,7 @@
-from hapsira.twobody.thrust.change_a_inc import change_a_inc
-from hapsira.twobody.thrust.change_argp import change_argp
-from hapsira.twobody.thrust.change_ecc_inc import change_ecc_inc
-from hapsira.twobody.thrust.change_ecc_quasioptimal import (
- change_ecc_quasioptimal,
-)
+from .change_a_inc import change_a_inc
+from .change_argp import change_argp
+from .change_ecc_inc import change_ecc_inc
+from .change_ecc_quasioptimal import change_ecc_quasioptimal
__all__ = [
"change_a_inc",
diff --git a/src/hapsira/twobody/thrust/change_a_inc.py b/src/hapsira/twobody/thrust/change_a_inc.py
index fb391255f..c2e677299 100644
--- a/src/hapsira/twobody/thrust/change_a_inc.py
+++ b/src/hapsira/twobody/thrust/change_a_inc.py
@@ -1,8 +1,6 @@
from astropy import units as u
-from hapsira.core.thrust.change_a_inc import (
- change_a_inc as change_a_inc_fast,
-)
+from hapsira.core.thrust.change_a_inc import change_a_inc_hb
def change_a_inc(k, a_0, a_f, inc_0, inc_f, f):
@@ -40,7 +38,7 @@ def change_a_inc(k, a_0, a_f, inc_0, inc_f, f):
* Kéchichian, J. A. "Reformulation of Edelbaum's Low-Thrust
Transfer Problem Using Optimal Control Theory", 1997.
"""
- a_d, delta_V, t_f = change_a_inc_fast(
+ a_d_hf, delta_V, t_f = change_a_inc_hb(
k=k.to_value(u.km**3 / u.s**2),
a_0=a_0.to_value(u.km),
a_f=a_f.to_value(u.km),
@@ -48,4 +46,8 @@ def change_a_inc(k, a_0, a_f, inc_0, inc_f, f):
inc_f=inc_f.to_value(u.rad),
f=f.to_value(u.km / u.s**2),
)
- return a_d, delta_V, t_f * u.s
+ return (
+ a_d_hf,
+ delta_V,
+ t_f * u.s,
+ ) # TODO delta_V is not a vector and does not carry a unit??
diff --git a/src/hapsira/twobody/thrust/change_argp.py b/src/hapsira/twobody/thrust/change_argp.py
index cc574234e..62eacbd6b 100644
--- a/src/hapsira/twobody/thrust/change_argp.py
+++ b/src/hapsira/twobody/thrust/change_argp.py
@@ -9,7 +9,7 @@
"""
from astropy import units as u
-from hapsira.core.thrust.change_argp import change_argp as change_a_inc_fast
+from hapsira.core.thrust.change_argp import change_argp_hb
def change_argp(k, a, ecc, argp_0, argp_f, f):
@@ -36,13 +36,13 @@ def change_argp(k, a, ecc, argp_0, argp_f, f):
-------
a_d, delta_V, t_f : tuple (function, ~astropy.units.quantity.Quantity, ~astropy.time.Time)
"""
- a_d, delta_V, t_f = change_a_inc_fast(
+ a_d_hf, delta_V, t_f = change_argp_hb(
k=k.to_value(u.km**3 / u.s**2),
a=a.to_value(u.km),
- ecc=ecc,
+ ecc=ecc.to_value() if hasattr(ecc, "to_value") else ecc,
argp_0=argp_0.to_value(u.rad),
argp_f=argp_f.to_value(u.rad),
f=f.to_value(u.km / u.s**2),
)
- return a_d, delta_V, t_f * u.s
+ return a_d_hf, delta_V, t_f * u.s # delta_V is scalar, TODO add unit to it?
diff --git a/src/hapsira/twobody/thrust/change_ecc_inc.py b/src/hapsira/twobody/thrust/change_ecc_inc.py
index 4c3ca0c45..4eb752c61 100644
--- a/src/hapsira/twobody/thrust/change_ecc_inc.py
+++ b/src/hapsira/twobody/thrust/change_ecc_inc.py
@@ -1,8 +1,6 @@
from astropy import units as u
-from hapsira.core.thrust.change_ecc_inc import (
- change_ecc_inc as change_ecc_inc_fast,
-)
+from hapsira.core.thrust.change_ecc_inc import change_ecc_inc_hb
def change_ecc_inc(orb_0, ecc_f, inc_f, f):
@@ -28,11 +26,11 @@ def change_ecc_inc(orb_0, ecc_f, inc_f, f):
* Pollard, J. E. "Simplified Analysis of Low-Thrust Orbital Maneuvers", 2000.
"""
r, v = orb_0.rv()
- a_d, delta_V, t_f = change_ecc_inc_fast(
+ a_d_hf, delta_V, t_f = change_ecc_inc_hb(
k=orb_0.attractor.k.to_value(u.km**3 / u.s**2),
a=orb_0.a.to_value(u.km),
ecc_0=orb_0.ecc.value,
- ecc_f=ecc_f,
+ ecc_f=getattr(ecc_f, "value", ecc_f), # in case of u.one
inc_0=orb_0.inc.to_value(u.rad),
inc_f=inc_f.to_value(u.rad),
argp=orb_0.argp.to_value(u.rad),
@@ -40,4 +38,4 @@ def change_ecc_inc(orb_0, ecc_f, inc_f, f):
v=v.to_value(u.km / u.s),
f=f.to_value(u.km / u.s**2),
)
- return a_d, delta_V << (u.km / u.s), t_f << u.s
+ return a_d_hf, delta_V << (u.km / u.s), t_f << u.s
diff --git a/src/hapsira/twobody/thrust/change_ecc_quasioptimal.py b/src/hapsira/twobody/thrust/change_ecc_quasioptimal.py
index c7f13305b..842a28e74 100644
--- a/src/hapsira/twobody/thrust/change_ecc_quasioptimal.py
+++ b/src/hapsira/twobody/thrust/change_ecc_quasioptimal.py
@@ -7,12 +7,9 @@
"""
from astropy import units as u
-from numba import njit
-import numpy as np
-from numpy import cross
-from hapsira.core.thrust.change_ecc_quasioptimal import extra_quantities
-from hapsira.util import norm
+from hapsira.core.jit import array_to_V_hf
+from hapsira.core.thrust.change_ecc_quasioptimal import change_ecc_quasioptimal_hb
def change_ecc_quasioptimal(orb_0, ecc_f, f):
@@ -29,22 +26,16 @@ def change_ecc_quasioptimal(orb_0, ecc_f, f):
f : float
Magnitude of constant acceleration
"""
- # We fix the inertial direction at the beginning
- k = orb_0.attractor.k.to(u.km**3 / u.s**2).value
- a = orb_0.a.to(u.km).value
- ecc_0 = orb_0.ecc.value
- if ecc_0 > 0.001: # Arbitrary tolerance
- ref_vec = orb_0.e_vec / ecc_0
- else:
- ref_vec = orb_0.r / norm(orb_0.r)
-
- h_unit = orb_0.h_vec / norm(orb_0.h_vec)
- thrust_unit = cross(h_unit, ref_vec) * np.sign(ecc_f - ecc_0)
-
- @njit
- def a_d(t0, u_, k):
- accel_v = f * thrust_unit
- return accel_v
-
- delta_V, t_f = extra_quantities(k, a, ecc_0, ecc_f, f)
- return a_d, delta_V, t_f
+
+ a_d_hf, delta_V, t_f = change_ecc_quasioptimal_hb(
+ orb_0.attractor.k.to(u.km**3 / u.s**2).value, # k
+ orb_0.a.to(u.km).value, # a
+ orb_0.ecc.value, # ecc_0
+ ecc_f,
+ array_to_V_hf(orb_0.e_vec), # e_vec,
+ array_to_V_hf(orb_0.h_vec), # h_vec,
+ array_to_V_hf(orb_0.r), # r
+ f,
+ )
+
+ return a_d_hf, delta_V, t_f
diff --git a/src/hapsira/util.py b/src/hapsira/util.py
index d1886867a..99b702ecf 100644
--- a/src/hapsira/util.py
+++ b/src/hapsira/util.py
@@ -4,7 +4,7 @@
from astropy.time import Time
import numpy as np
-from hapsira._math.linalg import norm as norm_fast
+from hapsira.core.math.linalg import norm_V_vf
from hapsira.core.util import alinspace as alinspace_fast
@@ -26,7 +26,7 @@ def norm(vec, axis=None):
result = norm_np(vec.value, axis=axis)
else:
- result = norm_fast(vec.value)
+ result = norm_V_vf(*vec.value)
return result << vec.unit
diff --git a/tests/conftest.py b/tests/conftest.py
index e99c92396..9835e434d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,11 +1,14 @@
from astropy import units as u
from astropy.coordinates import solar_system_ephemeris
from astropy.time import Time
+import matplotlib as mpl
import pytest
from hapsira.bodies import Earth, Sun
from hapsira.twobody import Orbit
+mpl.rc("figure", max_open_warning=0)
+
solar_system_ephemeris.set("builtin")
diff --git a/tests/test_hyper.py b/tests/test_hyper.py
index 1caf64fe7..4baae59d3 100644
--- a/tests/test_hyper.py
+++ b/tests/test_hyper.py
@@ -3,12 +3,12 @@
import pytest
from scipy import special
-from hapsira._math.special import hyp2f1b as hyp2f1
+from hapsira.core.math.special import hyp2f1b_vf
@pytest.mark.parametrize("x", np.linspace(0, 1, num=11))
def test_hyp2f1_battin_scalar(x):
expected_res = special.hyp2f1(3, 1, 5 / 2, x)
- res = hyp2f1(x)
+ res = hyp2f1b_vf(x)
assert_allclose(res, expected_res)
diff --git a/tests/test_iod.py b/tests/test_iod.py
index ba44a1f37..63aa90f02 100644
--- a/tests/test_iod.py
+++ b/tests/test_iod.py
@@ -3,7 +3,7 @@
import pytest
from hapsira.bodies import Earth
-from hapsira.core import iod
+from hapsira.core.iod import compute_T_min_gf, compute_y_vf, tof_equation_y_vf
from hapsira.iod import izzo, vallado
@@ -119,9 +119,11 @@ def test_collinear_vectors_input(lambert):
@pytest.mark.parametrize("M", [1, 2, 3])
def test_minimum_time_of_flight_convergence(M):
ll = -1
- x_T_min_expected, T_min_expected = iod._compute_T_min(ll, M, numiter=10, rtol=1e-8)
- y = iod._compute_y(x_T_min_expected, ll)
- T_min = iod._tof_equation_y(x_T_min_expected, y, 0.0, ll, M)
+ x_T_min_expected, T_min_expected = compute_T_min_gf( # pylint: disable=E1120,E0633
+ ll, M, 10, 1e-8
+ )
+ y = compute_y_vf(x_T_min_expected, ll)
+ T_min = tof_equation_y_vf(x_T_min_expected, y, 0.0, ll, M)
assert T_min_expected == T_min
@@ -171,6 +173,6 @@ def test_vallado_not_implemented_multirev():
with pytest.raises(NotImplementedError) as excinfo:
vallado.lambert(k, r0, r, tof, M=1)
assert (
- "Multi-revolution scenario not supported for Vallado. See issue https://github.com/hapsira/hapsira/issues/858"
+ "Multi-revolution scenario not supported for Vallado. See issue https://github.com/poliastro/poliastro/issues/858"
in excinfo.exconly()
)
diff --git a/tests/test_stumpff.py b/tests/test_stumpff.py
index 4ee7b0be5..621867884 100644
--- a/tests/test_stumpff.py
+++ b/tests/test_stumpff.py
@@ -1,7 +1,7 @@
from numpy import cos, cosh, sin, sinh
from numpy.testing import assert_allclose
-from hapsira._math.special import stumpff_c2 as c2, stumpff_c3 as c3
+from hapsira.core.math.special import stumpff_c2_vf, stumpff_c3_vf
def test_stumpff_functions_near_zero():
@@ -9,8 +9,8 @@ def test_stumpff_functions_near_zero():
expected_c2 = (1 - cos(psi**0.5)) / psi
expected_c3 = (psi**0.5 - sin(psi**0.5)) / psi**1.5
- assert_allclose(c2(psi), expected_c2)
- assert_allclose(c3(psi), expected_c3)
+ assert_allclose(stumpff_c2_vf(psi), expected_c2)
+ assert_allclose(stumpff_c3_vf(psi), expected_c3)
def test_stumpff_functions_above_zero():
@@ -18,8 +18,8 @@ def test_stumpff_functions_above_zero():
expected_c2 = (1 - cos(psi**0.5)) / psi
expected_c3 = (psi**0.5 - sin(psi**0.5)) / psi**1.5
- assert_allclose(c2(psi), expected_c2, rtol=1e-10)
- assert_allclose(c3(psi), expected_c3, rtol=1e-10)
+ assert_allclose(stumpff_c2_vf(psi), expected_c2, rtol=1e-10)
+ assert_allclose(stumpff_c3_vf(psi), expected_c3, rtol=1e-10)
def test_stumpff_functions_under_zero():
@@ -27,5 +27,5 @@ def test_stumpff_functions_under_zero():
expected_c2 = (cosh((-psi) ** 0.5) - 1) / (-psi)
expected_c3 = (sinh((-psi) ** 0.5) - (-psi) ** 0.5) / (-psi) ** 1.5
- assert_allclose(c2(psi), expected_c2, rtol=1e-10)
- assert_allclose(c3(psi), expected_c3, rtol=1e-10)
+ assert_allclose(stumpff_c2_vf(psi), expected_c2, rtol=1e-10)
+ assert_allclose(stumpff_c3_vf(psi), expected_c3, rtol=1e-10)
diff --git a/tests/tests_core/test_core_propagation.py b/tests/tests_core/test_core_propagation.py
index 3c9f7e100..8dc4a3711 100644
--- a/tests/tests_core/test_core_propagation.py
+++ b/tests/tests_core/test_core_propagation.py
@@ -2,26 +2,44 @@
from astropy.tests.helper import assert_quantity_allclose
import pytest
-from hapsira.core.propagation import (
- danby_coe,
- gooding_coe,
- markley_coe,
- mikkola_coe,
- pimienta_coe,
+from hapsira.core.propagation.danby import danby_coe_vf, DANBY_NUMITER, DANBY_RTOL
+from hapsira.core.propagation.farnocchia import farnocchia_coe_vf
+from hapsira.core.propagation.gooding import (
+ gooding_coe_vf,
+ GOODING_NUMITER,
+ GOODING_RTOL,
)
-from hapsira.core.propagation.farnocchia import farnocchia_coe
+from hapsira.core.propagation.markley import markley_coe_vf
+from hapsira.core.propagation.mikkola import mikkola_coe_vf
+from hapsira.core.propagation.pimienta import pimienta_coe_vf
+from hapsira.core.propagation.recseries import (
+ recseries_coe_vf,
+ RECSERIES_METHOD_RTOL,
+ RECSERIES_ORDER,
+ RECSERIES_NUMITER,
+ RECSERIES_RTOL,
+)
+from hapsira.core.propagation.vallado import vallado_coe_vf, VALLADO_NUMITER
from hapsira.examples import iss
@pytest.mark.parametrize(
"propagator_coe",
[
- danby_coe,
- markley_coe,
- pimienta_coe,
- mikkola_coe,
- farnocchia_coe,
- gooding_coe,
+ lambda *args: danby_coe_vf(*args, DANBY_NUMITER, DANBY_RTOL),
+ markley_coe_vf,
+ pimienta_coe_vf,
+ mikkola_coe_vf,
+ farnocchia_coe_vf,
+ lambda *args: gooding_coe_vf(*args, GOODING_NUMITER, GOODING_RTOL),
+ lambda *args: recseries_coe_vf(
+ *args,
+ RECSERIES_METHOD_RTOL,
+ RECSERIES_ORDER,
+ RECSERIES_NUMITER,
+ RECSERIES_RTOL,
+ ),
+ lambda *args: vallado_coe_vf(*args, VALLADO_NUMITER),
],
)
def test_propagate_with_coe(propagator_coe):
diff --git a/tests/tests_core/test_core_util.py b/tests/tests_core/test_core_util.py
index cc613dc58..ec65c3430 100644
--- a/tests/tests_core/test_core_util.py
+++ b/tests/tests_core/test_core_util.py
@@ -10,20 +10,25 @@
from hapsira.core.util import (
alinspace,
- rotation_matrix as rotation_matrix_hapsira,
+ rotation_matrix_gf,
spherical_to_cartesian,
)
def _test_rotation_matrix_with_v(v, angle, axis):
exp = rotation_matrix_astropy(np.degrees(-angle), "xyz"[axis]) @ v
- res = rotation_matrix_hapsira(angle, axis) @ v
+ res = rotation_matrix_gf(
+ angle, axis, np.zeros((3,), dtype="u1")
+ ) # pylint: disable=E1120
+ res = res @ v
assert_allclose(exp, res)
def _test_rotation_matrix(angle, axis):
expected = rotation_matrix_astropy(-np.rad2deg(angle), "xyz"[axis])
- result = rotation_matrix_hapsira(angle, axis)
+ result = rotation_matrix_gf(
+ angle, axis, np.zeros((3,), dtype="u1")
+ ) # pylint: disable=E1120
assert_allclose(expected, result)
@@ -37,23 +42,38 @@ def test_rotation_matrix():
# These tests are adapted from astropy:
# https://github.com/astropy/astropy/blob/main/astropy/coordinates/tests/test_matrix_utilities.py
def test_rotation_matrix_astropy():
- assert_array_equal(rotation_matrix_hapsira(0, 0), np.eye(3))
+ exp = np.eye(3)
+ res = rotation_matrix_gf(0, 0, np.zeros((3,), dtype="u1")) # pylint: disable=E1120
+ assert_array_equal(res, exp)
+
+ exp = np.array([[0, 0, -1], [0, 1, 0], [1, 0, 0]], dtype=float)
+ res = rotation_matrix_gf(
+ np.deg2rad(-90), 1, np.zeros((3,), dtype="u1")
+ ) # pylint: disable=E1120
assert_allclose(
- rotation_matrix_hapsira(np.deg2rad(-90), 1),
- [[0, 0, -1], [0, 1, 0], [1, 0, 0]],
+ res,
+ exp,
atol=1e-12,
)
+ exp = np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]], dtype=float)
+ res = rotation_matrix_gf(
+ np.deg2rad(90), 2, np.zeros((3,), dtype="u1")
+ ) # pylint: disable=E1120
assert_allclose(
- rotation_matrix_hapsira(np.deg2rad(90), 2),
- [[0, -1, 0], [1, 0, 0], [0, 0, 1]],
+ res,
+ exp,
atol=1e-12,
)
# make sure it also works for very small angles
+ exp = rotation_matrix_astropy(-0.000001, "x")
+ res = rotation_matrix_gf(
+ np.deg2rad(0.000001), 0, np.zeros((3,), dtype="u1")
+ ) # pylint: disable=E1120
assert_allclose(
- rotation_matrix_astropy(-0.000001, "x"),
- rotation_matrix_hapsira(np.deg2rad(0.000001), 0),
+ exp,
+ res,
)
diff --git a/tests/tests_earth/tests_atmosphere/test_coesa76.py b/tests/tests_earth/tests_atmosphere/test_coesa76.py
index 774a70498..962da527a 100644
--- a/tests/tests_earth/tests_atmosphere/test_coesa76.py
+++ b/tests/tests_earth/tests_atmosphere/test_coesa76.py
@@ -1,5 +1,6 @@
from astropy import units as u
from astropy.tests.helper import assert_quantity_allclose
+from numpy.testing import assert_allclose
import pytest
from hapsira.earth.atmosphere import COESA76
@@ -42,8 +43,10 @@ def test_coefficients_over_86km():
-12.89844,
]
- assert coesa76._get_coefficients_avobe_86(350 * u.km, p_coeff) == expected_p
- assert coesa76._get_coefficients_avobe_86(350 * u.km, rho_coeff) == expected_rho
+ assert_allclose(coesa76._get_coefficients_avobe_86(350 * u.km, p_coeff), expected_p)
+ assert_allclose(
+ coesa76._get_coefficients_avobe_86(350 * u.km, rho_coeff), expected_rho
+ )
# SOLUTIONS DIRECTLY TAKEN FROM COESA76 REPORT
diff --git a/tests/tests_plotting/test_orbit_plotter.py b/tests/tests_plotting/test_orbit_plotter.py
index b769979f8..897f21c62 100644
--- a/tests/tests_plotting/test_orbit_plotter.py
+++ b/tests/tests_plotting/test_orbit_plotter.py
@@ -242,7 +242,7 @@ def test_set_frame_plots_same_colors():
def test_redraw_keeps_trajectories():
- # See https://github.com/hapsira/hapsira/issues/518
+ # See https://github.com/poliastro/poliastro/issues/518
op = OrbitPlotter()
trajectory = churi.sample()
op.plot_body_orbit(Mars, J2000_TDB, label="Mars")
diff --git a/tests/tests_twobody/test_angles.py b/tests/tests_twobody/test_angles.py
index 45da70fc3..81ec43309 100644
--- a/tests/tests_twobody/test_angles.py
+++ b/tests/tests_twobody/test_angles.py
@@ -5,7 +5,13 @@
import pytest
from hapsira.bodies import Earth
-from hapsira.core.elements import coe2mee, coe2rv, mee2coe, rv2coe
+from hapsira.core.elements import (
+ coe2mee_gf,
+ coe2rv_gf,
+ mee2coe_gf,
+ rv2coe_gf,
+ RV2COE_TOL,
+)
from hapsira.twobody.angles import (
E_to_M,
E_to_nu,
@@ -211,34 +217,60 @@ def test_eccentric_to_true_range(E, ecc):
def test_convert_between_coe_and_rv_is_transitive(classical):
k = Earth.k.to(u.km**3 / u.s**2).value # u.km**3 / u.s**2
- res = rv2coe(k, *coe2rv(k, *classical))
+ expected_res = classical
+
+ r, v = np.zeros((3,), dtype=float), np.zeros((3,), dtype=float)
+ coe2rv_gf(k, *expected_res, np.zeros((3,), dtype="u1"), r, v)
+
+ res = rv2coe_gf(k, r, v, RV2COE_TOL) # pylint: disable=E1120
+
assert_allclose(res, classical)
def test_convert_between_coe_and_mee_is_transitive(classical):
- res = mee2coe(*coe2mee(*classical))
+ res = mee2coe_gf(*coe2mee_gf(*classical)) # pylint: disable=E1133
assert_allclose(res, classical)
def test_convert_coe_and_rv_circular(circular):
k, expected_res = circular
- res = rv2coe(k, *coe2rv(k, *expected_res))
+
+ r, v = np.zeros((3,), dtype=float), np.zeros((3,), dtype=float)
+ coe2rv_gf(k, *expected_res, np.zeros((3,), dtype="u1"), r, v)
+
+ res = rv2coe_gf(k, r, v, RV2COE_TOL) # pylint: disable=E1120
+
assert_allclose(res, expected_res, atol=1e-8)
def test_convert_coe_and_rv_hyperbolic(hyperbolic):
k, expected_res = hyperbolic
- res = rv2coe(k, *coe2rv(k, *expected_res))
+
+ r, v = np.zeros((3,), dtype=float), np.zeros((3,), dtype=float)
+ coe2rv_gf(k, *expected_res, np.zeros((3,), dtype="u1"), r, v)
+
+ res = rv2coe_gf(k, r, v, RV2COE_TOL) # pylint: disable=E1120
+
assert_allclose(res, expected_res, atol=1e-8)
def test_convert_coe_and_rv_equatorial(equatorial):
k, expected_res = equatorial
- res = rv2coe(k, *coe2rv(k, *expected_res))
+
+ r, v = np.zeros((3,), dtype=float), np.zeros((3,), dtype=float)
+ coe2rv_gf(k, *expected_res, np.zeros((3,), dtype="u1"), r, v)
+
+ res = rv2coe_gf(k, r, v, RV2COE_TOL) # pylint: disable=E1120
+
assert_allclose(res, expected_res, atol=1e-8)
def test_convert_coe_and_rv_circular_equatorial(circular_equatorial):
k, expected_res = circular_equatorial
- res = rv2coe(k, *coe2rv(k, *expected_res))
+
+ r, v = np.zeros((3,), dtype=float), np.zeros((3,), dtype=float)
+ coe2rv_gf(k, *expected_res, np.zeros((3,), dtype="u1"), r, v)
+
+ res = rv2coe_gf(k, r, v, RV2COE_TOL) # pylint: disable=E1120
+
assert_allclose(res, expected_res, atol=1e-8)
diff --git a/tests/tests_twobody/test_events.py b/tests/tests_twobody/test_events.py
index b24e7cb01..af2b9c7ff 100644
--- a/tests/tests_twobody/test_events.py
+++ b/tests/tests_twobody/test_events.py
@@ -7,9 +7,11 @@
from hapsira.bodies import Earth
from hapsira.constants import H0_earth, rho0_earth
-from hapsira.core.events import line_of_sight
-from hapsira.core.perturbations import atmospheric_drag_exponential
-from hapsira.core.propagation import func_twobody
+from hapsira.core.events import line_of_sight_gf
+from hapsira.core.jit import djit
+from hapsira.core.math.linalg import add_VV_hf
+from hapsira.core.perturbations import atmospheric_drag_exponential_hf
+from hapsira.core.propagation.base import func_twobody_hf
from hapsira.twobody import Orbit
from hapsira.twobody.events import (
AltitudeCrossEvent,
@@ -47,15 +49,23 @@ def test_altitude_crossing():
altitude_cross_event = AltitudeCrossEvent(thresh_alt, R)
events = [altitude_cross_event]
- def f(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
- ax, ay, az = atmospheric_drag_exponential(
- t0, u_, k, R=R, C_D=C_D, A_over_m=A_over_m, H0=H0, rho0=rho0
+ @djit(cache=False)
+ def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ du_ad = atmospheric_drag_exponential_hf(
+ t0,
+ rr,
+ vv,
+ k,
+ R,
+ C_D,
+ A_over_m,
+ H0,
+ rho0,
)
- du_ad = np.array([0, 0, 0, ax, ay, az])
- return du_kep + du_ad
+ return du_kep_rr, add_VV_hf(du_kep_vv, du_ad)
- method = CowellPropagator(events=events, f=f)
+ method = CowellPropagator(events=events, f=f_hf)
rr, _ = method.propagate_many(
orbit._state,
tofs,
@@ -112,7 +122,7 @@ def test_penumbra_event_not_triggering_is_ok():
v0 = np.array([7.36138, 2.98997, 1.64354])
orbit = Orbit.from_vectors(attractor, r0 * u.km, v0 * u.km / u.s)
- penumbra_event = PenumbraEvent(orbit)
+ penumbra_event = PenumbraEvent(orbit, tof=tof)
method = CowellPropagator(events=[penumbra_event])
rr, _ = method.propagate_many(
orbit._state,
@@ -129,7 +139,7 @@ def test_umbra_event_not_triggering_is_ok():
v0 = np.array([7.36138, 2.98997, 1.64354])
orbit = Orbit.from_vectors(attractor, r0 * u.km, v0 * u.km / u.s)
- umbra_event = UmbraEvent(orbit)
+ umbra_event = UmbraEvent(orbit, tof=tof)
method = CowellPropagator(events=[umbra_event])
rr, _ = method.propagate_many(
@@ -156,7 +166,7 @@ def test_umbra_event_crossing():
epoch=epoch,
)
- umbra_event = UmbraEvent(orbit, terminal=True)
+ umbra_event = UmbraEvent(orbit, tof=tof, terminal=True)
method = CowellPropagator(events=[umbra_event])
rr, _ = method.propagate_many(
@@ -183,7 +193,7 @@ def test_penumbra_event_crossing():
epoch=epoch,
)
- penumbra_event = PenumbraEvent(orbit, terminal=True)
+ penumbra_event = PenumbraEvent(orbit, tof=tof, terminal=True)
method = CowellPropagator(events=[penumbra_event])
rr, _ = method.propagate_many(
orbit._state,
@@ -300,7 +310,11 @@ def test_propagation_stops_if_atleast_one_event_has_terminal_set_to_True(
epoch=epoch,
)
- penumbra_event = PenumbraEvent(orbit, terminal=penumbra_terminal)
+ penumbra_event = PenumbraEvent(
+ orbit,
+ tof=600 * u.s,
+ terminal=penumbra_terminal,
+ )
thresh_lat = 30 * u.deg
latitude_cross_event = LatitudeCrossEvent(
@@ -330,8 +344,8 @@ def test_line_of_sight():
r_sun = np.array([122233179, -76150708, 33016374]) << u.km
R = Earth.R.to(u.km).value
- los = line_of_sight(r1.value, r2.value, R)
- los_with_sun = line_of_sight(r1.value, r_sun.value, R)
+ los = line_of_sight_gf(r1.value, r2.value, R) # pylint: disable=E1120
+ los_with_sun = line_of_sight_gf(r1.value, r_sun.value, R) # pylint: disable=E1120
assert los < 0 # No LOS condition.
assert los_with_sun >= 0 # LOS condition.
@@ -342,14 +356,13 @@ def test_LOS_event_raises_warning_if_norm_of_r1_less_than_attractor_radius_durin
v2 = np.array([5021.38, -2900.7, 1000.354]) << u.km / u.s
orbit = Orbit.from_vectors(Earth, r2, v2)
- tofs = [100, 500, 1000, 2000] << u.s
+ tofs = tofs = np.arange(0, 2000, 10) << u.s
# Propagate the secondary body to generate its position coordinates
method = CowellPropagator()
- rr, vv = method.propagate_many(
+ secondary_rr, _ = method.propagate_many(
orbit._state,
tofs,
)
- pos_coords = rr # Trajectory of the secondary body.
r1 = (
np.array([0, -5010.696, -5102.509]) << u.km
@@ -357,16 +370,18 @@ def test_LOS_event_raises_warning_if_norm_of_r1_less_than_attractor_radius_durin
v1 = np.array([736.138, 29899.7, 164.354]) << u.km / u.s
orb = Orbit.from_vectors(Earth, r1, v1)
- los_event = LosEvent(Earth, pos_coords, terminal=True)
+ los_event = LosEvent(Earth, tofs, secondary_rr.T, terminal=True)
events = [los_event]
tofs = [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.5] << u.s
- with pytest.warns(UserWarning, match="The norm of the position vector"):
- method = CowellPropagator(events=events)
- r, v = method.propagate_many(
- orb._state,
- tofs,
- )
+ # Can currently not warn due to: https://github.com/numba/numba/issues/1243
+ # TODO Matching implementation in LosEvent deactivated
+ # with pytest.warns(UserWarning, match="The norm of the position vector"):
+ method = CowellPropagator(events=events)
+ _, _ = method.propagate_many(
+ orb._state,
+ tofs,
+ ) # should trigger waring
@pytest.mark.filterwarnings("ignore::UserWarning")
@@ -375,20 +390,19 @@ def test_LOS_event_with_lithobrake_event_raises_warning_when_satellite_cuts_attr
v2 = np.array([5021.38, -2900.7, 1000.354]) << u.km / u.s
orbit = Orbit.from_vectors(Earth, r2, v2)
- tofs = [100, 500, 1000, 2000] << u.s
+ tofs = tofs = np.arange(0, 2000, 10) << u.s
# Propagate the secondary body to generate its position coordinates
method = CowellPropagator()
- rr, vv = method.propagate_many(
+ secondary_rr, _ = method.propagate_many(
orbit._state,
tofs,
)
- pos_coords = rr # Trajectory of the secondary body.
r1 = np.array([0, -5010.696, -5102.509]) << u.km
v1 = np.array([736.138, 2989.7, 164.354]) << u.km / u.s
orb = Orbit.from_vectors(Earth, r1, v1)
- los_event = LosEvent(Earth, pos_coords, terminal=True)
+ los_event = LosEvent(Earth, tofs, secondary_rr.T, terminal=True)
tofs = [
0.003,
0.004,
@@ -407,7 +421,7 @@ def test_LOS_event_with_lithobrake_event_raises_warning_when_satellite_cuts_attr
lithobrake_event = LithobrakeEvent(Earth.R.to_value(u.km))
method = CowellPropagator(events=[lithobrake_event, los_event])
- r, v = method.propagate_many(
+ _, _ = method.propagate_many(
orb._state,
tofs,
)
@@ -416,19 +430,19 @@ def test_LOS_event_with_lithobrake_event_raises_warning_when_satellite_cuts_attr
def test_LOS_event():
- t_los = 2327.165 * u.s
+ t_los = 2327.381434 * u.s
r2 = np.array([-500, 1500, 4012.09]) << u.km
v2 = np.array([5021.38, -2900.7, 1000.354]) << u.km / u.s
orbit = Orbit.from_vectors(Earth, r2, v2)
- tofs = [100, 500, 1000, 2000] << u.s
+ tofs = np.arange(0, 5000, 10) << u.s
+
# Propagate the secondary body to generate its position coordinates
method = CowellPropagator()
- rr, vv = method.propagate_many(
+ secondary_rr, _ = method.propagate_many(
orbit._state,
tofs,
)
- pos_coords = rr # Trajectory of the secondary body.
orb = Orbit.from_classical(
attractor=Earth,
@@ -440,12 +454,11 @@ def test_LOS_event():
nu=30 * u.deg,
)
- los_event = LosEvent(Earth, pos_coords, terminal=True)
+ los_event = LosEvent(Earth, tofs, secondary_rr.T, terminal=True)
events = [los_event]
- tofs = [1, 5, 10, 100, 1000, 2000, 3000, 5000] << u.s
method = CowellPropagator(events=events)
- r, v = method.propagate_many(
+ _, _ = method.propagate_many(
orb._state,
tofs,
)
diff --git a/tests/tests_twobody/test_orbit.py b/tests/tests_twobody/test_orbit.py
index 5eb5e63f7..59713ca33 100644
--- a/tests/tests_twobody/test_orbit.py
+++ b/tests/tests_twobody/test_orbit.py
@@ -419,7 +419,7 @@ def test_sample_numpoints():
def test_sample_big_orbits():
- # See https://github.com/hapsira/hapsira/issues/265
+ # See https://github.com/poliastro/poliastro/issues/265
ss = Orbit.from_vectors(
Sun,
[-9_018_878.6, -94_116_055, 22_619_059] * u.km,
@@ -1199,14 +1199,14 @@ def test_time_to_anomaly(expected_nu):
# In some corner cases the resulting anomaly goes out of range,
# and rather than trying to fix it right now
# we will wait until we remove the round tripping,
- # see https://github.com/hapsira/hapsira/issues/921
+ # see https://github.com/poliastro/poliastro/issues/921
# FIXME: Add test that verifies that `orbit.nu` is always within range
assert_quantity_allclose(iss_propagated.nu, expected_nu, atol=1e-12 * u.rad)
@pytest.mark.xfail
def test_can_set_iss_attractor_to_earth():
- # See https://github.com/hapsira/hapsira/issues/798
+ # See https://github.com/poliastro/poliastro/issues/798
epoch = Time("2019-11-10 12:00:00")
ephem = Ephem.from_horizons(
"International Space Station",
@@ -1235,7 +1235,7 @@ def test_issue_916(mock_query):
def test_near_parabolic_M_does_not_hang(near_parabolic):
- # See https://github.com/hapsira/hapsira/issues/907
+ # See https://github.com/poliastro/poliastro/issues/907
expected_nu = -168.65 * u.deg
orb = near_parabolic.propagate_to_anomaly(expected_nu)
@@ -1253,7 +1253,7 @@ def test_propagation_near_parabolic_orbits_zero_seconds_gives_same_anomaly(
def test_propagation_near_parabolic_orbits_does_not_hang(near_parabolic):
- # See https://github.com/hapsira/hapsira/issues/475
+ # See https://github.com/poliastro/poliastro/issues/475
orb_final = near_parabolic.propagate(near_parabolic.period)
# Smoke test
diff --git a/tests/tests_twobody/test_perturbations.py b/tests/tests_twobody/test_perturbations.py
index 630a41b7e..94abf302b 100644
--- a/tests/tests_twobody/test_perturbations.py
+++ b/tests/tests_twobody/test_perturbations.py
@@ -1,5 +1,3 @@
-import functools
-
from astropy import units as u
from astropy.coordinates import Angle
from astropy.tests.helper import assert_quantity_allclose
@@ -10,17 +8,20 @@
from hapsira.bodies import Earth, Moon, Sun
from hapsira.constants import H0_earth, Wdivc_sun, rho0_earth
-from hapsira.core.elements import rv2coe
-from hapsira.core.perturbations import (
- J2_perturbation,
- J3_perturbation,
- atmospheric_drag,
- atmospheric_drag_exponential,
- radiation_pressure,
- third_body,
+from hapsira.core.earth.atmosphere.coesa76 import density_hf as coesa76_density_hf
+from hapsira.core.elements import rv2coe_gf, RV2COE_TOL
+from hapsira.core.jit import hjit, djit
+from hapsira.core.math.linalg import add_VV_hf, mul_Vs_hf, norm_V_hf
+from hapsira.core.perturbations import ( # pylint: disable=E1120,E1136
+ J2_perturbation_hf,
+ J3_perturbation_hf,
+ atmospheric_drag_hf,
+ atmospheric_drag_exponential_hf,
+ radiation_pressure_hf,
+ third_body_hf,
)
-from hapsira.core.propagation import func_twobody
-from hapsira.earth.atmosphere import COESA76
+from hapsira.core.propagation.base import func_twobody_hf
+
from hapsira.ephem import build_ephem_interpolant
from hapsira.twobody import Orbit
from hapsira.twobody.events import LithobrakeEvent
@@ -37,22 +38,33 @@ def test_J2_propagation_Earth():
orbit = Orbit.from_vectors(Earth, r0 * u.km, v0 * u.km / u.s)
tofs = [48.0] * u.h
+ J2 = Earth.J2.value
+ R_ = Earth.R.to(u.km).value
- def f(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
- ax, ay, az = J2_perturbation(
- t0, u_, k, J2=Earth.J2.value, R=Earth.R.to(u.km).value
+ @djit(cache=False)
+ def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ du_ad = J2_perturbation_hf(
+ t0,
+ rr,
+ vv,
+ k,
+ J2,
+ R_,
)
- du_ad = np.array([0, 0, 0, ax, ay, az])
- return du_kep + du_ad
+ return du_kep_rr, add_VV_hf(du_kep_vv, du_ad)
- method = CowellPropagator(f=f)
+ method = CowellPropagator(f=f_hf)
rr, vv = method.propagate_many(orbit._state, tofs)
k = Earth.k.to(u.km**3 / u.s**2).value
- _, _, _, raan0, argp0, _ = rv2coe(k, r0, v0)
- _, _, _, raan, argp, _ = rv2coe(k, rr[0].to(u.km).value, vv[0].to(u.km / u.s).value)
+ _, _, _, raan0, argp0, _ = rv2coe_gf( # pylint: disable=E1120,E0633
+ k, r0, v0, RV2COE_TOL
+ )
+ _, _, _, raan, argp, _ = rv2coe_gf( # pylint: disable=E1120,E0633
+ k, rr[0].to(u.km).value, vv[0].to(u.km / u.s).value, RV2COE_TOL
+ )
raan_variation_rate = (raan - raan0) / tofs[0].to(u.s).value # type: ignore
argp_variation_rate = (argp - argp0) / tofs[0].to(u.s).value # type: ignore
@@ -117,30 +129,53 @@ def test_J3_propagation_Earth(test_params):
nu=nu_ini,
)
- def f(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
- ax, ay, az = J2_perturbation(
- t0, u_, k, J2=Earth.J2.value, R=Earth.R.to(u.km).value
+ J2 = Earth.J2.value
+ R_ = Earth.R.to(u.km).value
+
+ @djit(cache=False)
+ def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ du_ad = J2_perturbation_hf(
+ t0,
+ rr,
+ vv,
+ k,
+ J2,
+ R_,
)
- du_ad = np.array([0, 0, 0, ax, ay, az])
- return du_kep + du_ad
+ return du_kep_rr, add_VV_hf(du_kep_vv, du_ad)
tofs = np.linspace(0, 10.0 * u.day, 1000)
- method = CowellPropagator(rtol=1e-8, f=f)
+ method = CowellPropagator(rtol=1e-8, f=f_hf)
r_J2, v_J2 = method.propagate_many(
orbit._state,
tofs,
)
- def f_combined(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
- ax, ay, az = J2_perturbation(
- t0, u_, k, J2=Earth.J2.value, R=Earth.R.to_value(u.km)
- ) + J3_perturbation(t0, u_, k, J3=Earth.J3.value, R=Earth.R.to_value(u.km))
- du_ad = np.array([0, 0, 0, ax, ay, az])
- return du_kep + du_ad
+ J3 = Earth.J3.value
- method = CowellPropagator(rtol=1e-8, f=f_combined)
+ @djit(cache=False)
+ def f_combined_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ du_ad_J2 = J2_perturbation_hf(
+ t0,
+ rr,
+ vv,
+ k,
+ J2,
+ R_,
+ )
+ du_ad_J3 = J3_perturbation_hf(
+ t0,
+ rr,
+ vv,
+ k,
+ J3,
+ R_,
+ )
+ return du_kep_rr, add_VV_hf(du_kep_vv, add_VV_hf(du_ad_J2, du_ad_J3))
+
+ method = CowellPropagator(rtol=1e-8, f=f_combined_hf)
r_J3, v_J3 = method.propagate_many(
orbit._state,
tofs,
@@ -148,13 +183,23 @@ def f_combined(t0, u_, k):
a_values_J2 = np.array(
[
- rv2coe(k, ri, vi)[0] / (1.0 - rv2coe(k, ri, vi)[1] ** 2)
+ rv2coe_gf(k, ri, vi, RV2COE_TOL)[0] # pylint: disable=E1120,E1136
+ / (
+ 1.0
+ - rv2coe_gf(k, ri, vi, RV2COE_TOL)[1] # pylint: disable=E1120,E1136
+ ** 2
+ )
for ri, vi in zip(r_J2.to(u.km).value, v_J2.to(u.km / u.s).value)
]
)
a_values_J3 = np.array(
[
- rv2coe(k, ri, vi)[0] / (1.0 - rv2coe(k, ri, vi)[1] ** 2)
+ rv2coe_gf(k, ri, vi, RV2COE_TOL)[0] # pylint: disable=E1120,E1136
+ / (
+ 1.0
+ - rv2coe_gf(k, ri, vi, RV2COE_TOL)[1] # pylint: disable=E1120,E1136
+ ** 2
+ )
for ri, vi in zip(r_J3.to(u.km).value, v_J3.to(u.km / u.s).value)
]
)
@@ -162,13 +207,13 @@ def f_combined(t0, u_, k):
ecc_values_J2 = np.array(
[
- rv2coe(k, ri, vi)[1]
+ rv2coe_gf(k, ri, vi, RV2COE_TOL)[1] # pylint: disable=E1120,E1136
for ri, vi in zip(r_J2.to(u.km).value, v_J2.to(u.km / u.s).value)
]
)
ecc_values_J3 = np.array(
[
- rv2coe(k, ri, vi)[1]
+ rv2coe_gf(k, ri, vi, RV2COE_TOL)[1] # pylint: disable=E1120,E1136
for ri, vi in zip(r_J3.to(u.km).value, v_J3.to(u.km / u.s).value)
]
)
@@ -176,13 +221,13 @@ def f_combined(t0, u_, k):
inc_values_J2 = np.array(
[
- rv2coe(k, ri, vi)[2]
+ rv2coe_gf(k, ri, vi, RV2COE_TOL)[2] # pylint: disable=E1120,E1136
for ri, vi in zip(r_J2.to(u.km).value, v_J2.to(u.km / u.s).value)
]
)
inc_values_J3 = np.array(
[
- rv2coe(k, ri, vi)[2]
+ rv2coe_gf(k, ri, vi, RV2COE_TOL)[2] # pylint: disable=E1120,E1136
for ri, vi in zip(r_J3.to(u.km).value, v_J3.to(u.km / u.s).value)
]
)
@@ -226,15 +271,23 @@ def test_atmospheric_drag_exponential():
# dr_expected = F_r * tof (Newton's integration formula), where
# F_r = -B rho(r) |r|^2 sqrt(k / |r|^3) = -B rho(r) sqrt(k |r|)
- def f(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
- ax, ay, az = atmospheric_drag_exponential(
- t0, u_, k, R=R, C_D=C_D, A_over_m=A_over_m, H0=H0, rho0=rho0
+ @djit(cache=False)
+ def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ du_ad = atmospheric_drag_exponential_hf(
+ t0,
+ rr,
+ vv,
+ k,
+ R,
+ C_D,
+ A_over_m,
+ H0,
+ rho0,
)
- du_ad = np.array([0, 0, 0, ax, ay, az])
- return du_kep + du_ad
+ return du_kep_rr, add_VV_hf(du_kep_vv, du_ad)
- method = CowellPropagator(f=f)
+ method = CowellPropagator(f=f_hf)
rr, _ = method.propagate_many(
orbit._state,
[tof] * u.s,
@@ -268,15 +321,23 @@ def test_atmospheric_demise():
lithobrake_event = LithobrakeEvent(R)
events = [lithobrake_event]
- def f(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
- ax, ay, az = atmospheric_drag_exponential(
- t0, u_, k, R=R, C_D=C_D, A_over_m=A_over_m, H0=H0, rho0=rho0
+ @djit(cache=False)
+ def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ du_ad = atmospheric_drag_exponential_hf(
+ t0,
+ rr,
+ vv,
+ k,
+ R,
+ C_D,
+ A_over_m,
+ H0,
+ rho0,
)
- du_ad = np.array([0, 0, 0, ax, ay, az])
- return du_kep + du_ad
+ return du_kep_rr, add_VV_hf(du_kep_vv, du_ad)
- method = CowellPropagator(events=events, f=f)
+ method = CowellPropagator(events=events, f=f_hf)
rr, _ = method.propagate_many(
orbit._state,
tofs,
@@ -291,7 +352,7 @@ def f(t0, u_, k):
lithobrake_event = LithobrakeEvent(R)
events = [lithobrake_event]
- method = CowellPropagator(events=events, f=f)
+ method = CowellPropagator(events=events, f=f_hf)
rr, _ = method.propagate_many(
orbit._state,
tofs,
@@ -319,27 +380,31 @@ def test_atmospheric_demise_coesa76():
lithobrake_event = LithobrakeEvent(R)
events = [lithobrake_event]
- coesa76 = COESA76()
-
- def f(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
+ @djit(cache=False)
+ def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
# Avoid undershooting H below attractor radius R
- H = max(norm(u_[:3]), R)
- rho = coesa76.density((H - R) * u.km).to_value(u.kg / u.km**3)
+ H = norm_V_hf(rr)
+ if H < R:
+ H = R
- ax, ay, az = atmospheric_drag(
+ rho = (
+ coesa76_density_hf(H - R, True) * 1e9
+ ) # HACK convert from kg/m**3 to kg/km**3
+
+ du_ad = atmospheric_drag_hf(
t0,
- u_,
+ rr,
+ vv,
k,
- C_D=C_D,
- A_over_m=A_over_m,
- rho=rho,
+ C_D,
+ A_over_m,
+ rho,
)
- du_ad = np.array([0, 0, 0, ax, ay, az])
- return du_kep + du_ad
+ return du_kep_rr, add_VV_hf(du_kep_vv, du_ad)
- method = CowellPropagator(events=events, f=f)
+ method = CowellPropagator(events=events, f=f_hf)
rr, _ = method.propagate_many(
orbit._state,
tofs,
@@ -373,18 +438,17 @@ def test_cowell_works_with_small_perturbations():
initial = Orbit.from_vectors(Earth, r0, v0)
- def accel(t0, state, k):
- v_vec = state[3:]
- norm_v = (v_vec * v_vec).sum() ** 0.5
- return 1e-5 * v_vec / norm_v
+ @hjit("V(f,V,V,f)")
+ def accel_hf(t0, rr, vv, k):
+ return mul_Vs_hf(vv, 1e-5 / norm_V_hf(vv))
- def f(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
- ax, ay, az = accel(t0, u_, k)
- du_ad = np.array([0, 0, 0, ax, ay, az])
- return du_kep + du_ad
+ @djit
+ def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ du_ad = accel_hf(t0, rr, vv, k)
+ return du_kep_rr, add_VV_hf(du_kep_vv, du_ad)
- final = initial.propagate(3 * u.day, method=CowellPropagator(f=f))
+ final = initial.propagate(3 * u.day, method=CowellPropagator(f=f_hf))
# TODO: Accuracy reduced after refactor,
# but unclear what are we comparing against
@@ -399,18 +463,18 @@ def test_cowell_converges_with_small_perturbations():
initial = Orbit.from_vectors(Earth, r0, v0)
- def accel(t0, state, k):
- v_vec = state[3:]
- norm_v = (v_vec * v_vec).sum() ** 0.5
- return 0.0 * v_vec / norm_v
+ @hjit("V(f,V,V,f)")
+ def accel_hf(t0, rr, vv, k):
+ norm_v = norm_V_hf(vv)
+ return mul_Vs_hf(vv, 0.0 / norm_v)
- def f(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
- ax, ay, az = accel(t0, u_, k)
- du_ad = np.array([0, 0, 0, ax, ay, az])
- return du_kep + du_ad
+ @djit
+ def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ du_ad = accel_hf(t0, rr, vv, k)
+ return du_kep_rr, add_VV_hf(du_kep_vv, du_ad)
- final = initial.propagate(initial.period, method=CowellPropagator(f=f))
+ final = initial.propagate(initial.period, method=CowellPropagator(f=f_hf))
assert_quantity_allclose(final.r, initial.r)
assert_quantity_allclose(final.v, initial.v)
@@ -551,20 +615,22 @@ def test_3rd_body_Curtis(test_params):
end=epoch + test_params["tof"],
)
body_r = build_ephem_interpolant(body, body_epochs)
+ k_third = body.k.to_value(u.km**3 / u.s**2)
- def f(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
- ax, ay, az = third_body(
+ @djit(cache=False)
+ def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ du_ad = third_body_hf(
t0,
- u_,
+ rr,
+ vv,
k,
- k_third=body.k.to_value(u.km**3 / u.s**2),
- perturbation_body=body_r,
+ k_third,
+ body_r, # perturbation_body
)
- du_ad = np.array([0, 0, 0, ax, ay, az])
- return du_kep + du_ad
+ return du_kep_rr, add_VV_hf(du_kep_vv, du_ad)
- method = CowellPropagator(rtol=1e-10, f=f)
+ method = CowellPropagator(rtol=1e-10, f=f_hf)
rr, vv = method.propagate_many(
initial._state,
np.linspace(0, tof, 400) << u.s,
@@ -573,7 +639,10 @@ def f(t0, u_, k):
incs, raans, argps = [], [], []
for ri, vi in zip(rr.to_value(u.km), vv.to_value(u.km / u.s)):
angles = Angle(
- rv2coe(Earth.k.to_value(u.km**3 / u.s**2), ri, vi)[2:5] * u.rad
+ rv2coe_gf( # pylint: disable=E1120,E1136
+ Earth.k.to_value(u.km**3 / u.s**2), ri, vi, RV2COE_TOL
+ )[2:5]
+ * u.rad
) # inc, raan, argp
angles = angles.wrap_at(180 * u.deg)
incs.append(angles[0].value)
@@ -604,12 +673,7 @@ def sun_r():
tof = 600 * u.day
epoch = Time(j_date, format="jd", scale="tdb")
ephem_epochs = time_range(epoch, num_values=164, end=epoch + tof)
- return build_ephem_interpolant(Sun, ephem_epochs)
-
-
-def normalize_to_Curtis(t0, sun_r):
- r = sun_r(t0)
- return 149600000 * r / norm(r)
+ return build_ephem_interpolant(Sun, ephem_epochs) # returns hf
@pytest.mark.slow
@@ -641,27 +705,35 @@ def test_solar_pressure(t_days, deltas_expected, sun_r):
nu=343.4268 * u.deg,
epoch=epoch,
)
- # In Curtis, the mean distance to Sun is used. In order to validate against it, we have to do the same thing
- sun_normalized = functools.partial(normalize_to_Curtis, sun_r=sun_r)
- def f(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
- ax, ay, az = radiation_pressure(
+ # In Curtis, the mean distance to Sun is used. In order to validate against it, we have to do the same thing
+ @hjit("V(f)", cache=False)
+ def sun_normalized_hf(t0):
+ r = sun_r(t0) # sun_r is hf, returns V
+ return mul_Vs_hf(r, 149600000 / norm_V_hf(r))
+
+ R_ = Earth.R.to(u.km).value
+ Wdivc_s = Wdivc_sun.value
+
+ @djit(cache=False)
+ def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ du_ad = radiation_pressure_hf(
t0,
- u_,
+ rr,
+ vv,
k,
- R=Earth.R.to(u.km).value,
- C_R=2.0,
- A_over_m=2e-4 / 100,
- Wdivc_s=Wdivc_sun.value,
- star=sun_normalized,
+ R_,
+ 2.0, # C_R
+ 2e-4 / 100, # A_over_m
+ Wdivc_s,
+ sun_normalized_hf, # star
)
- du_ad = np.array([0, 0, 0, ax, ay, az])
- return du_kep + du_ad
+ return du_kep_rr, add_VV_hf(du_kep_vv, du_ad)
method = CowellPropagator(
rtol=1e-8,
- f=f,
+ f=f_hf,
)
rr, vv = method.propagate_many(
initial._state,
@@ -670,7 +742,9 @@ def f(t0, u_, k):
delta_eccs, delta_incs, delta_raans, delta_argps = [], [], [], []
for ri, vi in zip(rr.to(u.km).value, vv.to(u.km / u.s).value):
- orbit_params = rv2coe(Earth.k.to(u.km**3 / u.s**2).value, ri, vi)
+ orbit_params = rv2coe_gf( # pylint: disable=E1120,E1136
+ Earth.k.to(u.km**3 / u.s**2).value, ri, vi, RV2COE_TOL
+ )
delta_eccs.append(orbit_params[1] - initial.ecc.value)
delta_incs.append((orbit_params[2] * u.rad).to(u.deg).value - initial.inc.value)
delta_raans.append(
diff --git a/tests/tests_twobody/test_propagation.py b/tests/tests_twobody/test_propagation.py
index 0d20983bf..6de438993 100644
--- a/tests/tests_twobody/test_propagation.py
+++ b/tests/tests_twobody/test_propagation.py
@@ -9,8 +9,10 @@
from hapsira.bodies import Earth, Moon, Sun
from hapsira.constants import J2000
-from hapsira.core.elements import rv2coe
-from hapsira.core.propagation import func_twobody
+from hapsira.core.elements import rv2coe_gf, RV2COE_TOL
+from hapsira.core.jit import djit, hjit
+from hapsira.core.math.linalg import add_VV_hf, mul_Vs_hf, norm_V_hf
+from hapsira.core.propagation.base import func_twobody_hf
from hapsira.examples import iss
from hapsira.frames import Planes
from hapsira.twobody import Orbit
@@ -203,10 +205,11 @@ def test_propagation_parabolic(propagator):
orbit = Orbit.parabolic(Earth, p, _a, _a, _a, _a)
orbit = orbit.propagate(0.8897 / 2.0 * u.h, method=propagator())
- _, _, _, _, _, nu0 = rv2coe(
+ _, _, _, _, _, nu0 = rv2coe_gf( # pylint: disable=E1120,E0633
Earth.k.to(u.km**3 / u.s**2).value,
orbit.r.to(u.km).value,
orbit.v.to(u.km / u.s).value,
+ RV2COE_TOL,
)
assert_quantity_allclose(nu0, np.deg2rad(90.0), rtol=1e-4)
@@ -288,22 +291,21 @@ def test_cowell_propagation_circle_to_circle():
# From [Edelbaum, 1961]
accel = 1e-7
- def constant_accel(t0, u_, k):
- v = u_[3:]
- norm_v = (v[0] ** 2 + v[1] ** 2 + v[2] ** 2) ** 0.5
- return accel * v / norm_v
+ @hjit("V(f,V,V,f)")
+ def constant_accel_hf(t0, rr, vv, k):
+ norm_v = norm_V_hf(vv)
+ return mul_Vs_hf(vv, accel / norm_v)
- def f(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
- ax, ay, az = constant_accel(t0, u_, k)
- du_ad = np.array([0, 0, 0, ax, ay, az])
-
- return du_kep + du_ad
+ @djit(cache=False)
+ def f_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ du_ad = constant_accel_hf(t0, rr, vv, k)
+ return du_kep_rr, add_VV_hf(du_kep_vv, du_ad)
ss = Orbit.circular(Earth, 500 * u.km)
tofs = [20] * ss.period
- method = CowellPropagator(f=f)
+ method = CowellPropagator(f=f_hf)
rrs, vvs = method.propagate_many(ss._state, tofs)
orb_final = Orbit.from_vectors(Earth, rrs[0], vvs[0])
@@ -345,7 +347,7 @@ def test_propagate_to_date_has_proper_epoch():
)
def test_propagate_long_times_keeps_geometry(method):
# TODO: Extend to other propagators?
- # See https://github.com/hapsira/hapsira/issues/265
+ # See https://github.com/poliastro/poliastro/issues/265
time_of_flight = 100 * u.year
res = iss.propagate(time_of_flight, method=method)
@@ -434,7 +436,7 @@ def test_propagation_sets_proper_epoch():
def test_sample_around_moon_works():
- # See https://github.com/hapsira/hapsira/issues/649
+ # See https://github.com/poliastro/poliastro/issues/649
orbit = Orbit.circular(Moon, 100 << u.km)
coords = orbit.sample(10)
@@ -444,7 +446,7 @@ def test_sample_around_moon_works():
def test_propagate_around_moon_works():
- # See https://github.com/hapsira/hapsira/issues/649
+ # See https://github.com/poliastro/poliastro/issues/649
orbit = Orbit.circular(Moon, 100 << u.km)
new_orbit = orbit.propagate(1 << u.h)
diff --git a/tests/tests_twobody/test_thrust.py b/tests/tests_twobody/test_thrust.py
index 38a1b65f2..cf0d1c83d 100644
--- a/tests/tests_twobody/test_thrust.py
+++ b/tests/tests_twobody/test_thrust.py
@@ -4,12 +4,12 @@
import pytest
from hapsira.bodies import Earth
-from hapsira.core.propagation import func_twobody
-from hapsira.core.thrust import (
- change_a_inc as change_a_inc_fast,
- change_argp as change_argp_fast,
-)
-from hapsira.core.thrust.change_ecc_inc import beta as beta_change_ecc_inc
+from hapsira.core.jit import djit
+from hapsira.core.math.linalg import add_VV_hf
+from hapsira.core.propagation.base import func_twobody_hf
+from hapsira.core.thrust.change_a_inc import change_a_inc_hb
+from hapsira.core.thrust.change_argp import change_argp_hb
+from hapsira.core.thrust.change_ecc_inc import beta_vf as beta_change_ecc_inc
from hapsira.twobody import Orbit
from hapsira.twobody.propagation import CowellPropagator
from hapsira.twobody.thrust import (
@@ -34,19 +34,19 @@ def test_leo_geo_numerical_safe(inc_0):
k = Earth.k.to(u.km**3 / u.s**2)
- a_d, _, t_f = change_a_inc(k, a_0, a_f, inc_0, inc_f, f)
+ a_d_hf, _, t_f = change_a_inc(k, a_0, a_f, inc_0, inc_f, f)
# Retrieve r and v from initial orbit
s0 = Orbit.circular(Earth, a_0 - Earth.R, inc_0)
# Propagate orbit
- def f_leo_geo(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
- ax, ay, az = a_d(t0, u_, k)
- du_ad = np.array([0, 0, 0, ax, ay, az])
- return du_kep + du_ad
+ @djit(cache=False)
+ def f_leo_geo_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ du_ad = a_d_hf(t0, rr, vv, k)
+ return du_kep_rr, add_VV_hf(du_kep_vv, du_ad)
- sf = s0.propagate(t_f, method=CowellPropagator(rtol=1e-6, f=f_leo_geo))
+ sf = s0.propagate(t_f, method=CowellPropagator(rtol=1e-6, f=f_leo_geo_hf))
assert_allclose(sf.a.to(u.km).value, a_f.value, rtol=1e-3)
assert_allclose(sf.ecc.value, 0.0, atol=1e-2)
@@ -66,19 +66,19 @@ def test_leo_geo_numerical_fast(inc_0):
k = Earth.k.to(u.km**3 / u.s**2).value
- a_d, _, t_f = change_a_inc_fast(k, a_0, a_f, inc_0, inc_f, f)
+ a_d_hf, _, t_f = change_a_inc_hb(k, a_0, a_f, inc_0, inc_f, f)
# Retrieve r and v from initial orbit
s0 = Orbit.circular(Earth, a_0 * u.km - Earth.R, inc_0 * u.rad)
# Propagate orbit
- def f_leo_geo(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
- ax, ay, az = a_d(t0, u_, k)
- du_ad = np.array([0, 0, 0, ax, ay, az])
- return du_kep + du_ad
+ @djit(cache=False)
+ def f_leo_geo_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ du_ad = a_d_hf(t0, rr, vv, k)
+ return du_kep_rr, add_VV_hf(du_kep_vv, du_ad)
- sf = s0.propagate(t_f * u.s, method=CowellPropagator(rtol=1e-6, f=f_leo_geo))
+ sf = s0.propagate(t_f * u.s, method=CowellPropagator(rtol=1e-6, f=f_leo_geo_hf))
assert_allclose(sf.a.to(u.km).value, a_f, rtol=1e-3)
assert_allclose(sf.ecc.value, 0.0, atol=1e-2)
@@ -128,16 +128,18 @@ def test_sso_disposal_numerical(ecc_0, ecc_f):
argp=0 * u.deg,
nu=0 * u.deg,
)
- a_d, _, t_f = change_ecc_quasioptimal(s0, ecc_f, f)
+ a_d_hf, _, t_f = change_ecc_quasioptimal(s0, ecc_f, f)
# Propagate orbit
- def f_ss0_disposal(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
- ax, ay, az = a_d(t0, u_, k)
- du_ad = np.array([0, 0, 0, ax, ay, az])
- return du_kep + du_ad
+ @djit(cache=False)
+ def f_ss0_disposal_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ du_ad = a_d_hf(t0, rr, vv, k)
+ return du_kep_rr, add_VV_hf(du_kep_vv, du_ad)
- sf = s0.propagate(t_f * u.s, method=CowellPropagator(rtol=1e-8, f=f_ss0_disposal))
+ sf = s0.propagate(
+ t_f * u.s, method=CowellPropagator(rtol=1e-8, f=f_ss0_disposal_hf)
+ )
assert_allclose(sf.ecc.value, ecc_f, rtol=1e-4, atol=1e-4)
@@ -172,9 +174,7 @@ def test_geo_cases_beta_dnd_delta_v(ecc_0, inc_f, expected_beta, expected_delta_
nu=0 * u.deg,
)
- beta = beta_change_ecc_inc(
- ecc_0=ecc_0, ecc_f=ecc_f, inc_0=inc_0, inc_f=inc_f, argp=argp
- )
+ beta = beta_change_ecc_inc(ecc_0, ecc_f, inc_0, inc_f, argp)
_, delta_V, _ = change_ecc_inc(orb_0=s0, ecc_f=ecc_f, inc_f=inc_f * u.rad, f=f)
assert_allclose(delta_V.to_value(u.km / u.s), expected_delta_V, rtol=1e-2)
@@ -199,16 +199,16 @@ def test_geo_cases_numerical(ecc_0, ecc_f):
argp=argp * u.deg,
nu=0 * u.deg,
)
- a_d, _, t_f = change_ecc_inc(orb_0=s0, ecc_f=ecc_f, inc_f=inc_f, f=f)
+ a_d_hf, _, t_f = change_ecc_inc(orb_0=s0, ecc_f=ecc_f, inc_f=inc_f, f=f)
# Propagate orbit
- def f_geo(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
- ax, ay, az = a_d(t0, u_, k)
- du_ad = np.array([0, 0, 0, ax, ay, az])
- return du_kep + du_ad
+ @djit(cache=False)
+ def f_geo_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ du_ad = a_d_hf(t0, rr, vv, k)
+ return du_kep_rr, add_VV_hf(du_kep_vv, du_ad)
- sf = s0.propagate(t_f, method=CowellPropagator(rtol=1e-8, f=f_geo))
+ sf = s0.propagate(t_f, method=CowellPropagator(rtol=1e-8, f=f_geo_hf))
assert_allclose(sf.ecc.value, ecc_f, rtol=1e-2, atol=1e-2)
assert_allclose(sf.inc.to_value(u.rad), inc_f.to_value(u.rad), rtol=1e-1)
@@ -249,7 +249,7 @@ def test_soyuz_standard_gto_delta_v_fast():
k = Earth.k.to(u.km**3 / u.s**2).value
- _, delta_V, t_f = change_argp_fast(k, a, ecc, argp_0, argp_f, f)
+ _, delta_V, t_f = change_argp_hb(k, a, ecc, argp_0, argp_f, f)
expected_t_f = 12.0 # days, approximate
expected_delta_V = 0.2489 # km / s
@@ -271,7 +271,7 @@ def test_soyuz_standard_gto_numerical_safe():
k = Earth.k.to(u.km**3 / u.s**2)
- a_d, _, t_f = change_argp(k, a, ecc, argp_0, argp_f, f)
+ a_d_hf, _, t_f = change_argp(k, a, ecc, argp_0, argp_f, f)
# Retrieve r and v from initial orbit
s0 = Orbit.from_classical(
@@ -285,13 +285,13 @@ def test_soyuz_standard_gto_numerical_safe():
)
# Propagate orbit
- def f_soyuz(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
- ax, ay, az = a_d(t0, u_, k)
- du_ad = np.array([0, 0, 0, ax, ay, az])
- return du_kep + du_ad
+ @djit(cache=False)
+ def f_soyuz_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ du_ad = a_d_hf(t0, rr, vv, k)
+ return du_kep_rr, add_VV_hf(du_kep_vv, du_ad)
- sf = s0.propagate(t_f, method=CowellPropagator(rtol=1e-8, f=f_soyuz))
+ sf = s0.propagate(t_f, method=CowellPropagator(rtol=1e-8, f=f_soyuz_hf))
assert_allclose(sf.argp.to_value(u.rad), argp_f.to_value(u.rad), rtol=1e-4)
@@ -309,7 +309,7 @@ def test_soyuz_standard_gto_numerical_fast():
k = Earth.k.to(u.km**3 / u.s**2).value
- a_d, _, t_f = change_argp_fast(k, a, ecc, argp_0, argp_f, f)
+ a_d_hf, _, t_f = change_argp_hb(k, a, ecc, argp_0, argp_f, f)
# Retrieve r and v from initial orbit
s0 = Orbit.from_classical(
@@ -323,15 +323,15 @@ def test_soyuz_standard_gto_numerical_fast():
)
# Propagate orbit
- def f_soyuz(t0, u_, k):
- du_kep = func_twobody(t0, u_, k)
- ax, ay, az = a_d(t0, u_, k)
- du_ad = np.array([0, 0, 0, ax, ay, az])
- return du_kep + du_ad
+ @djit(cache=False)
+ def f_soyuz_hf(t0, rr, vv, k):
+ du_kep_rr, du_kep_vv = func_twobody_hf(t0, rr, vv, k)
+ du_ad = a_d_hf(t0, rr, vv, k)
+ return du_kep_rr, add_VV_hf(du_kep_vv, du_ad)
sf = s0.propagate(
t_f * u.s,
- method=CowellPropagator(rtol=1e-8, f=f_soyuz),
+ method=CowellPropagator(rtol=1e-8, f=f_soyuz_hf),
)
assert_allclose(sf.argp.to(u.rad).value, argp_f, rtol=1e-4)
@@ -354,7 +354,7 @@ def test_leo_geo_time_and_delta_v(inc_0, expected_t_f, expected_delta_V, rtol):
k = Earth.k.to(u.km**3 / u.s**2).value
inc_0 = np.radians(inc_0) # rad
- _, delta_V, t_f = change_a_inc_fast(k, a_0, a_f, inc_0, inc_f, f)
+ _, delta_V, t_f = change_a_inc_hb(k, a_0, a_f, inc_0, inc_f, f)
assert_allclose(delta_V, expected_delta_V, rtol=rtol)
assert_allclose((t_f * u.s).to(u.day).value, expected_t_f, rtol=rtol)
diff --git a/tox.ini b/tox.ini
index a62b0fc6a..6c44bce33 100644
--- a/tox.ini
+++ b/tox.ini
@@ -24,6 +24,7 @@ setenv =
PYTHONUNBUFFERED = yes
PIP_PREFER_BINARY = 1
NPY_DISABLE_CPU_FEATURES = AVX512_SKX
+ HAPSIRA_CACHE = 0
coverage: NUMBA_DISABLE_JIT = 1
fast: PYTEST_MARKERS = -m "not slow and not mpl_image_compare" -n auto
online: PYTEST_MARKERS = -m "remote_data" -n auto