Skip to content

Commit

Permalink
Added the ability to use a callable that returns a System as an ODE (#…
Browse files Browse the repository at this point in the history
…528)

* The phase instantiation can now take ode_class as a callable that returns an OpenMDAO system, where the callable has the arguments (num_nodes=<int>, **ode_init_kwargs).
Added some documentation.
Added tests for invalid callable ODEs with appropriate error messages.
Added test for copying an existing ODE instance via deepcopy, though this is discouraged in the documentation.
  • Loading branch information
robfalck authored Jan 26, 2021
1 parent 77e5ad5 commit 5c7895c
Show file tree
Hide file tree
Showing 7 changed files with 574 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,339 @@
import os
import unittest

import numpy as np
import openmdao.api as om
from openmdao.utils.testing_utils import use_tempdirs
from openmdao.utils.assert_utils import assert_near_equal
import dymos as dm


@use_tempdirs
class TestBrachExecCompODE(unittest.TestCase):

def _make_problem(self, transcription='gauss-lobatto', num_segments=8, transcription_order=3,
compressed=True, optimizer='SLSQP', run_driver=True, force_alloc_complex=False,
solve_segments=False):

p = om.Problem(model=om.Group())

p.driver = om.pyOptSparseDriver()
p.driver.options['optimizer'] = optimizer
p.driver.declare_coloring(tol=1.0E-12)

if transcription == 'gauss-lobatto':
t = dm.GaussLobatto(num_segments=num_segments,
order=transcription_order,
compressed=compressed)
elif transcription == 'radau-ps':
t = dm.Radau(num_segments=num_segments,
order=transcription_order,
compressed=compressed)

ode = lambda num_nodes: om.ExecComp(['vdot = g * cos(theta)',
'xdot = v * sin(theta)',
'ydot = -v * cos(theta)'],
g={'value': 9.80665, 'units': 'm/s**2'},
v={'shape': (num_nodes,), 'units': 'm/s'},
theta={'shape': (num_nodes,), 'units': 'rad'},
vdot={'shape': (num_nodes,),
'units': 'm/s**2',
'tags': ['state_rate_source:v']},
xdot={'shape': (num_nodes,),
'units': 'm/s',
'tags': ['state_rate_source:x']},
ydot={'shape': (num_nodes,),
'units': 'm/s',
'tags': ['state_rate_source:y']},
has_diag_partials=True)

traj = dm.Trajectory()
phase = dm.Phase(ode_class=ode, transcription=t)
p.model.add_subsystem('traj0', traj)
traj.add_phase('phase0', phase)

phase.set_time_options(fix_initial=True, duration_bounds=(.5, 10))

phase.add_state('x', fix_initial=True, fix_final=False, solve_segments=solve_segments)
phase.add_state('y', fix_initial=True, fix_final=False, solve_segments=solve_segments)

# Note that by omitting the targets here Dymos will automatically attempt to connect
# to a top-level input named 'v' in the ODE, and connect to nothing if it's not found.
phase.add_state('v', fix_initial=True, fix_final=False, solve_segments=solve_segments)

phase.add_control('theta',
continuity=True, rate_continuity=True,
units='deg', lower=0.01, upper=179.9)

phase.add_parameter('g', units='m/s**2', dynamic=False)

phase.add_boundary_constraint('x', loc='final', equals=10)
phase.add_boundary_constraint('y', loc='final', equals=5)
# Minimize time at the end of the phase
phase.add_objective('time_phase', loc='final', scaler=10)

p.setup(check=['unconnected_inputs'], force_alloc_complex=force_alloc_complex)

p['traj0.phase0.t_initial'] = 0.0
p['traj0.phase0.t_duration'] = 2.0

p['traj0.phase0.states:x'] = phase.interpolate(ys=[0, 10], nodes='state_input')
p['traj0.phase0.states:y'] = phase.interpolate(ys=[10, 5], nodes='state_input')
p['traj0.phase0.states:v'] = phase.interpolate(ys=[0, 9.9], nodes='state_input')
p['traj0.phase0.controls:theta'] = phase.interpolate(ys=[5, 100], nodes='control_input')
p['traj0.phase0.parameters:g'] = 9.80665

dm.run_problem(p, run_driver=run_driver, simulate=True)

return p

def run_asserts(self):

for db in ['dymos_solution.db', 'dymos_simulation.db']:
p = om.CaseReader(db).get_case('final')

t_initial = p.get_val('traj0.phase0.timeseries.time')[0]
tf = p.get_val('traj0.phase0.timeseries.time')[-1]

x0 = p.get_val('traj0.phase0.timeseries.states:x')[0]
xf = p.get_val('traj0.phase0.timeseries.states:x')[-1]

y0 = p.get_val('traj0.phase0.timeseries.states:y')[0]
yf = p.get_val('traj0.phase0.timeseries.states:y')[-1]

v0 = p.get_val('traj0.phase0.timeseries.states:v')[0]
vf = p.get_val('traj0.phase0.timeseries.states:v')[-1]

g = p.get_val('traj0.phase0.timeseries.parameters:g')[0]

thetaf = p.get_val('traj0.phase0.timeseries.controls:theta')[-1]

assert_near_equal(t_initial, 0.0)
assert_near_equal(x0, 0.0)
assert_near_equal(y0, 10.0)
assert_near_equal(v0, 0.0)

assert_near_equal(tf, 1.8016, tolerance=0.01)
assert_near_equal(xf, 10.0, tolerance=0.01)
assert_near_equal(yf, 5.0, tolerance=0.01)
assert_near_equal(vf, 9.902, tolerance=0.01)
assert_near_equal(g, 9.80665, tolerance=0.01)

assert_near_equal(thetaf, 100.12, tolerance=0.01)

def test_ex_brachistochrone_radau_uncompressed(self):
self._make_problem(transcription='radau-ps', compressed=False)
self.run_asserts()

def test_ex_brachistochrone_gl_uncompressed(self):
self._make_problem(transcription='gauss-lobatto', compressed=False)
self.run_asserts()


@use_tempdirs
class TestInvalidCallableODEClass(unittest.TestCase):

def test_invalid_callable(self):
num_segments = 10
transcription_order = 3
compressed = False

p = om.Problem(model=om.Group())

p.driver = om.pyOptSparseDriver()
p.driver.options['optimizer'] = 'SLSQP'
p.driver.declare_coloring(tol=1.0E-12)

t = dm.Radau(num_segments=num_segments,
order=transcription_order,
compressed=compressed)

ode = lambda num_nodes: num_nodes*2

traj = dm.Trajectory()
phase = dm.Phase(ode_class=ode, transcription=t)
p.model.add_subsystem('traj0', traj)
traj.add_phase('phase0', phase)

phase.set_time_options(fix_initial=True, duration_bounds=(.5, 10))

phase.add_state('x', fix_initial=True, fix_final=False, rate_source='xdot')
phase.add_state('y', fix_initial=True, fix_final=False, rate_source='ydot')

# Note that by omitting the targets here Dymos will automatically attempt to connect
# to a top-level input named 'v' in the ODE, and connect to nothing if it's not found.
phase.add_state('v', fix_initial=True, fix_final=False, rate_source='vdot')

phase.add_control('theta',
continuity=True, rate_continuity=True,
units='deg', lower=0.01, upper=179.9)

phase.add_parameter('g', units='m/s**2', dynamic=False)

phase.add_boundary_constraint('x', loc='final', equals=10)
phase.add_boundary_constraint('y', loc='final', equals=5)
# Minimize time at the end of the phase
phase.add_objective('time_phase', loc='final', scaler=10)

with self.assertRaises(expected_exception=ValueError) as e:
p.setup()
expected = "When provided as a callable, ode_class must return an instance of " \
"openmdao.core.System. Got <class 'int'>"
self.assertEqual(expected, str(e.exception))


class CallableBrachistochroneODE(om.ExplicitComponent):

def initialize(self):
self.options.declare('num_nodes', types=int)

def __call__(self, num_nodes, **kwargs):
from copy import deepcopy
ret = deepcopy(self)
ret.options['num_nodes'] = num_nodes
return ret

def setup(self):
nn = self.options['num_nodes']

# Inputs
self.add_input('v', val=np.zeros(nn), desc='velocity', units='m/s')

self.add_input('g', val=9.80665, desc='grav. acceleration', units='m/s/s')

self.add_input('theta', val=np.ones(nn), desc='angle of wire', units='rad')

self.add_output('xdot', val=np.zeros(nn), desc='velocity component in x', units='m/s',
tags=['state_rate_source:x', 'state_units:m'])

self.add_output('ydot', val=np.zeros(nn), desc='velocity component in y', units='m/s',
tags=['state_rate_source:y', 'state_units:m'])

self.add_output('vdot', val=np.zeros(nn), desc='acceleration magnitude', units='m/s**2',
tags=['state_rate_source:v', 'state_units:m/s'])

self.declare_partials(of='*', wrt='*', method='cs')

def compute(self, inputs, outputs):
theta = inputs['theta']
cos_theta = np.cos(theta)
sin_theta = np.sin(theta)
g = inputs['g']
v = inputs['v']

outputs['vdot'] = g * cos_theta
outputs['xdot'] = v * sin_theta
outputs['ydot'] = -v * cos_theta


@use_tempdirs
class TestBrachCallableODE(unittest.TestCase):

def setUp(self):
self.ode = CallableBrachistochroneODE(num_nodes=1)

def _make_problem(self, transcription='gauss-lobatto', num_segments=8, transcription_order=3,
compressed=True, optimizer='SLSQP', run_driver=True,
force_alloc_complex=False,
solve_segments=False):

p = om.Problem(model=om.Group())

p.driver = om.pyOptSparseDriver()
p.driver.options['optimizer'] = optimizer
p.driver.declare_coloring(tol=1.0E-12)

if transcription == 'gauss-lobatto':
t = dm.GaussLobatto(num_segments=num_segments,
order=transcription_order,
compressed=compressed)
elif transcription == 'radau-ps':
t = dm.Radau(num_segments=num_segments,
order=transcription_order,
compressed=compressed)

traj = dm.Trajectory()
phase = dm.Phase(ode_class=self.ode, transcription=t)
p.model.add_subsystem('traj0', traj)
traj.add_phase('phase0', phase)

phase.set_time_options(fix_initial=True, duration_bounds=(.5, 10))

phase.add_state('x', fix_initial=True, fix_final=False, solve_segments=solve_segments,
rate_source='xdot')
phase.add_state('y', fix_initial=True, fix_final=False, solve_segments=solve_segments,
rate_source='ydot')

# Note that by omitting the targets here Dymos will automatically attempt to connect
# to a top-level input named 'v' in the ODE, and connect to nothing if it's not found.
phase.add_state('v', fix_initial=True, fix_final=False, solve_segments=solve_segments,
rate_source='vdot')

phase.add_control('theta',
continuity=True, rate_continuity=True,
units='deg', lower=0.01, upper=179.9)

phase.add_parameter('g', units='m/s**2', dynamic=False)

phase.add_boundary_constraint('x', loc='final', equals=10)
phase.add_boundary_constraint('y', loc='final', equals=5)
# Minimize time at the end of the phase
phase.add_objective('time_phase', loc='final', scaler=10)

p.setup(check=['unconnected_inputs'], force_alloc_complex=force_alloc_complex)

p['traj0.phase0.t_initial'] = 0.0
p['traj0.phase0.t_duration'] = 2.0

p['traj0.phase0.states:x'] = phase.interpolate(ys=[0, 10], nodes='state_input')
p['traj0.phase0.states:y'] = phase.interpolate(ys=[10, 5], nodes='state_input')
p['traj0.phase0.states:v'] = phase.interpolate(ys=[0, 9.9], nodes='state_input')
p['traj0.phase0.controls:theta'] = phase.interpolate(ys=[5, 100], nodes='control_input')
p['traj0.phase0.parameters:g'] = 9.80665

dm.run_problem(p, run_driver=run_driver, simulate=True)

return p

def run_asserts(self):

for db in ['dymos_solution.db', 'dymos_simulation.db']:
p = om.CaseReader(db).get_case('final')

t_initial = p.get_val('traj0.phase0.timeseries.time')[0]
tf = p.get_val('traj0.phase0.timeseries.time')[-1]

x0 = p.get_val('traj0.phase0.timeseries.states:x')[0]
xf = p.get_val('traj0.phase0.timeseries.states:x')[-1]

y0 = p.get_val('traj0.phase0.timeseries.states:y')[0]
yf = p.get_val('traj0.phase0.timeseries.states:y')[-1]

v0 = p.get_val('traj0.phase0.timeseries.states:v')[0]
vf = p.get_val('traj0.phase0.timeseries.states:v')[-1]

g = p.get_val('traj0.phase0.timeseries.parameters:g')[0]

thetaf = p.get_val('traj0.phase0.timeseries.controls:theta')[-1]

assert_near_equal(t_initial, 0.0)
assert_near_equal(x0, 0.0)
assert_near_equal(y0, 10.0)
assert_near_equal(v0, 0.0)

assert_near_equal(tf, 1.8016, tolerance=0.01)
assert_near_equal(xf, 10.0, tolerance=0.01)
assert_near_equal(yf, 5.0, tolerance=0.01)
assert_near_equal(vf, 9.902, tolerance=0.01)
assert_near_equal(g, 9.80665, tolerance=0.01)

assert_near_equal(thetaf, 100.12, tolerance=0.01)

def test_ex_brachistochrone_radau_uncompressed(self):
self._make_problem(transcription='radau-ps', compressed=False)
self.run_asserts()

def test_in_series(self):
self._make_problem(transcription='gauss-lobatto', compressed=False)
self._make_problem(transcription='radau-ps', compressed=False)
self.run_asserts()
17 changes: 13 additions & 4 deletions dymos/phase/phase.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import Callable
from collections.abc import Iterable, Sequence
import inspect
import warnings
Expand Down Expand Up @@ -1623,8 +1624,9 @@ def _check_ode(self):
"""
Check that the provided ODE class meets minimum requirements.
* The ode_class must be a class, not an instance.
* The ode_class must derive from openmdao.core.System
* The ode_class must be provided as a class or a callable.
* When given as a callable, ode_class must return an instance derived from openmdao.core.System.
* When given as a class, ode_class must derive from openmdao.core.System
Raises
------
Expand All @@ -1634,9 +1636,16 @@ def _check_ode(self):
"""
ode_class = self.options['ode_class']
if not inspect.isclass(ode_class):
raise ValueError('ode_class must be a class, not an instance.')
if not isinstance(ode_class, Callable):
raise ValueError('ode_class must be given as a callable object that returns an '
'object derived from openmdao.core.System, or as a class derived '
'from openmdao.core.System.')
test_instance = ode_class(num_nodes=1, **self.options['ode_init_kwargs'])
if not isinstance(test_instance, System):
raise ValueError(f'When provided as a callable, ode_class must return an instance '
f'of openmdao.core.System. Got {type(test_instance)}')
elif not issubclass(ode_class, System):
raise ValueError('ode_class must be derived from openmdao.core.System.')
raise ValueError('If given as a class, ode_class must be derived from openmdao.core.System.')

def setup(self):
"""
Expand Down
7 changes: 5 additions & 2 deletions dymos/phase/test/test_phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def test_invalid_ode_wrong_class(self):
with self.assertRaises(ValueError) as e:
p.setup(check=True)

self.assertEqual(str(e.exception), 'ode_class must be derived from openmdao.core.System.')
expected = 'If given as a class, ode_class must be derived from openmdao.core.System.'
self.assertEqual(expected, str(e.exception))

def test_invalid_ode_instance(self):

Expand Down Expand Up @@ -97,7 +98,9 @@ def test_invalid_ode_instance(self):
with self.assertRaises(ValueError) as e:
p.setup(check=True)

self.assertEqual(str(e.exception), 'ode_class must be a class, not an instance.')
expected = 'ode_class must be given as a callable object that returns an object derived ' \
'from openmdao.core.System, or as a class derived from openmdao.core.System.'
self.assertEqual(expected, str(e.exception))

def test_add_existing_parameter_as_parameter(self):

Expand Down
Loading

0 comments on commit 5c7895c

Please sign in to comment.