Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Warn users when restarting with fix_final=True #804

Merged
merged 5 commits into from
Aug 18, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions dymos/load_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from openmdao.recorders.case import Case
from .phase.phase import Phase
from .trajectory import Trajectory
from warnings import warn


def find_phases(sys):
Expand Down Expand Up @@ -149,6 +150,12 @@ def load_case(problem, previous_solution):
if init_val_path:
problem.set_val(init_val_path[0], prev_state_val[0, ...], units=prev_state_units)

if options['fix_final']:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to use the OpenMDAO helper issue_warning in instead of warnings.warn. See

issue_warning(f"The timeseries variable name {output_name} is "

warning_message = f"{phase_name}.states:{state_name} specifies 'fix_final=True'. " \
f"If the given restart file has a" \
f" different final value this will overwrite the user-specified value"
warn(warning_message)

# Interpolate the timeseries control outputs from the previous solution onto the new grid.
for control_name, options in phase.control_options.items():
control_path = [s for s in phase_vars if s.endswith(f'{phase_name}.controls:{control_name}')][0]
Expand All @@ -160,6 +167,11 @@ def load_case(problem, previous_solution):
phase.interp(xs=prev_time_val, ys=prev_control_val,
nodes='control_input', kind='slinear'),
units=prev_control_units)
if options['fix_final']:
warning_message = f"{phase_name}.controls:{control_name} specifies 'fix_final=True'. " \
f"If the given restart file has a" \
f" different final value this will overwrite the user-specified value"
warn(warning_message)

# Set the output polynomial control outputs from the previous solution as the value
for pc_name, options in phase.polynomial_control_options.items():
Expand All @@ -169,6 +181,11 @@ def load_case(problem, previous_solution):
prev_pc_val = prev_vars[prev_pc_path]['val']
prev_pc_units = prev_vars[prev_pc_path]['units']
problem.set_val(pc_path, prev_pc_val, units=prev_pc_units)
if options['fix_final']:
warning_message = f"{phase_name}.polynomial_controls:{pc_name} specifies 'fix_final=True'. " \
f"If the given restart file has a" \
f" different final value this will overwrite the user-specified value"
warn(warning_message)

# Set the timeseries parameter outputs from the previous solution as the parameter value
for param_name, options in phase.parameter_options.items():
Expand Down
83 changes: 78 additions & 5 deletions dymos/test/test_load_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
om_version = tuple([int(s) for s in openmdao.__version__.split('-')[0].split('.')])


def setup_problem(trans=dm.GaussLobatto(num_segments=10), polynomial_control=False):
def setup_problem(trans=dm.GaussLobatto(num_segments=10), polynomial_control=False,
fix_final_state=True, fix_final_control=False):
from dymos.examples.brachistochrone.brachistochrone_ode import BrachistochroneODE

p = om.Problem(model=om.Group())
Expand All @@ -20,15 +21,16 @@ def setup_problem(trans=dm.GaussLobatto(num_segments=10), polynomial_control=Fal

phase.set_time_options(fix_initial=True, duration_bounds=(.5, 10))

phase.add_state('x', fix_initial=True, fix_final=True)
phase.add_state('y', fix_initial=True, fix_final=True)
phase.add_state('x', fix_initial=True, fix_final=fix_final_state)
phase.add_state('y', fix_initial=True, fix_final=fix_final_state)
phase.add_state('v', fix_initial=True)

if not polynomial_control:
phase.add_control('theta', units='deg',
rate_continuity=False, lower=0.01, upper=179.9)
rate_continuity=False, lower=0.01, upper=179.9, fix_final=fix_final_control)
else:
phase.add_polynomial_control('theta', order=1, units='deg', lower=0.01, upper=179.9)
phase.add_polynomial_control('theta', order=1, units='deg', lower=0.01, upper=179.9,
fix_final=fix_final_control)

phase.add_parameter('g', units='m/s**2', opt=False, val=9.80665)

Expand Down Expand Up @@ -186,6 +188,77 @@ def test_load_case_radau_to_lgl(self):
q.model.phase0.interp(xs=time_val, ys=theta_val, nodes='all'),
tolerance=1.0E-2)

def test_load_case_warn_fix_final_states(self):
import openmdao.api as om
from openmdao.utils.assert_utils import assert_warnings
import dymos as dm

p = setup_problem(dm.Radau(num_segments=20))

# Solve for the optimal trajectory
dm.run_problem(p)

# Load the solution
case = om.CaseReader('dymos_solution.db').get_case('final')

# create a problem with a different transcription with a different number of variables
q = setup_problem(dm.GaussLobatto(num_segments=50))

msgs = []

# Load the values from the previous solution
for state_name in ['x', 'y']:
msgs.append((UserWarning, f"phase0.states:{state_name} specifies 'fix_final=True'."
f" If the given restart file has a different final value"
f" this will overwrite the user-specified value"))

with assert_warnings(msgs):
dm.load_case(q, case)

def test_load_case_warn_fix_final_control(self):
import openmdao.api as om
from openmdao.utils.assert_utils import assert_warning
import dymos as dm
p = setup_problem(dm.Radau(num_segments=10))

# Solve for the optimal trajectory
dm.run_problem(p)

# Load the solution
case = om.CaseReader('dymos_solution.db').get_case('final')

# create a problem with a different transcription with a different number of variables
q = setup_problem(dm.Radau(num_segments=10), fix_final_state=False, fix_final_control=True)

msg = f"phase0.controls:theta specifies 'fix_final=True'. If the given restart file has a" \
f" different final value this will overwrite the user-specified value"

with assert_warning(UserWarning, msg):
dm.load_case(q, case)

def test_load_case_warn_fix_final_polynomial_control(self):
import openmdao.api as om
from openmdao.utils.assert_utils import assert_warning
import dymos as dm
p = setup_problem(dm.Radau(num_segments=10), polynomial_control=True,)

# Solve for the optimal trajectory
dm.run_problem(p)

# Load the solution
case = om.CaseReader('dymos_solution.db').get_case('final')

# create a problem with a different transcription with a different number of variables
q = setup_problem(dm.Radau(num_segments=10), polynomial_control=True,
fix_final_state=False, fix_final_control=True)

# Load the values from the previous solution
msg = f"phase0.polynomial_controls:theta specifies 'fix_final=True'. If the given restart file has a" \
f" different final value this will overwrite the user-specified value"

with assert_warning(UserWarning, msg):
dm.load_case(q, case)


if __name__ == '__main__': # pragma: no cover
unittest.main()