Skip to content

Commit

Permalink
Fixed python 3.12 compatibility issues. (#1004)
Browse files Browse the repository at this point in the history
* correctly handle jax version == "latest"

* Changed playwright to only be a dependency if python version is less than 3.12 due to an incompatibility in its dependencies

* New pycodestyle issues with 3.12.
unittest.TestCase.assertEquals -> assertEqual
  • Loading branch information
robfalck authored Oct 16, 2023
1 parent 1c0cb06 commit 9e02030
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 36 deletions.
9 changes: 6 additions & 3 deletions .github/workflows/dymos_docs_workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
SNOPT: 7.7
OPENMDAO: 'dev'
OPTIONAL: '[docs]'
JAX: '0.3.24'
JAX: 'latest'
PUBLISH_DOCS: 0

steps:
Expand Down Expand Up @@ -101,8 +101,11 @@ jobs:
echo "============================================================="
echo "Install jax"
echo "============================================================="
python -m pip install jaxlib==${{ matrix.JAX }} jax==${{ matrix.JAX }}
if [[ "${{ matrix.JAX }}" == "latest" ]]; then
python -m pip install jaxlib jax
else
python -m pip install jaxlib==${{ matrix.JAX }} jax==${{ matrix.JAX }}
fi
- name: Install PETSc
if: matrix.PETSc
shell: bash -l {0}
Expand Down
17 changes: 15 additions & 2 deletions .github/workflows/dymos_tests_workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ jobs:
SNOPT: 7.7
OPENMDAO: 'dev'
OPTIONAL: '[test]'
JAX: '0.3.24'
JAX: 'latest'

# oldest supported versions
- NAME: oldest
Expand Down Expand Up @@ -160,7 +160,20 @@ jobs:
echo "============================================================="
echo "Install jax"
echo "============================================================="
python -m pip install jaxlib==${{ matrix.JAX }} jax==${{ matrix.JAX }}
if [[ "${{ matrix.JAX }}" == "latest" ]]; then
python -m pip install jaxlib jax
else
python -m pip install jaxlib==${{ matrix.JAX }} jax==${{ matrix.JAX }}
fi
- name: Install greenlet
if: env.NAME == 'latest'
shell: bash -l {0}
run: |
echo "============================================================="
echo "Install greenlet from wheels"
echo "============================================================="
pip install --only-binary :all: greenlet
- name: Install PETSc
if: env.RUN_BUILD && matrix.PETSc
Expand Down
18 changes: 9 additions & 9 deletions dymos/trajectory/test/test_t_initial_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_pair_fixed_t_initial_below(self):

msg = ("'traj' <class Trajectory>: Fixed t_initial of 5.0 is outside of allowed bounds "
"(10.0, 15.0) for phase 'phase1'.")
self.assertEquals(cm.exception.args[0], msg)
self.assertEqual(cm.exception.args[0], msg)

def test_pair_fixed_t_initial_above(self):
kwargs = {
Expand All @@ -172,7 +172,7 @@ def test_pair_fixed_t_initial_above(self):

msg = ("'traj' <class Trajectory>: Fixed t_initial of 99.0 is outside of allowed bounds "
"(10.0, 15.0) for phase 'phase1'.")
self.assertEquals(cm.exception.args[0], msg)
self.assertEqual(cm.exception.args[0], msg)

def test_pair_t_initial_bounds_below(self):
kwargs = {
Expand All @@ -194,7 +194,7 @@ def test_pair_t_initial_bounds_below(self):

msg = ("'traj' <class Trajectory>: t_initial bounds of (5.0, 7.0) do not overlap with "
"allowed bounds (8.0, 17.0) for phase 'phase1'.")
self.assertEquals(cm.exception.args[0], msg)
self.assertEqual(cm.exception.args[0], msg)

def test_pair_t_initial_bounds_above(self):
kwargs = {
Expand All @@ -216,7 +216,7 @@ def test_pair_t_initial_bounds_above(self):

msg = ("'traj' <class Trajectory>: t_initial bounds of (20.0, 22.0) do not overlap with "
"allowed bounds (8.0, 17.0) for phase 'phase1'.")
self.assertEquals(cm.exception.args[0], msg)
self.assertEqual(cm.exception.args[0], msg)

def test_pair_no_duration_bounds(self):
kwargs = {
Expand Down Expand Up @@ -255,7 +255,7 @@ def test_all_fixed_t_initial(self):

msg = ("'traj' <class Trajectory>: Fixed t_initial of 5.0 is outside of allowed "
"bounds (10.0, 15.0) for phase 'phase2'.")
self.assertEquals(cm.exception.args[0], msg)
self.assertEqual(cm.exception.args[0], msg)

def test_all_t_initial_bounds(self):
nphases = 3
Expand All @@ -278,7 +278,7 @@ def test_all_t_initial_bounds(self):

msg = ("'traj' <class Trajectory>: t_initial bounds of (99.0, 104.0) do not overlap with "
"allowed bounds (10.0, 20.0) for phase 'phase2'.")
self.assertEquals(cm.exception.args[0], msg)
self.assertEqual(cm.exception.args[0], msg)

def test_odd_fixed_t_initial(self):
nphases = 4
Expand Down Expand Up @@ -424,7 +424,7 @@ def test_branching_all_fixed_t_initial(self):
"(15.0, 20.0) for phase 'br0_phase0'.\n"
"Fixed t_initial of 10.0 is outside of allowed bounds (15.0, 20.0) for phase "
"'br1_phase0'.")
self.assertEquals(cm.exception.args[0], msg)
self.assertEqual(cm.exception.args[0], msg)

def test_branching_all_t_initial_bounds(self):
nphases = 3 # number of phases in trunk and each branch
Expand Down Expand Up @@ -473,7 +473,7 @@ def test_branching_all_t_initial_bounds(self):
"allowed bounds (20.0, 30.0) for phase 'br0_phase1'.\n"
"t_initial bounds of (0.0, 5) do not overlap with allowed bounds (20.0, 30.0) "
"for phase 'br1_phase1'.")
self.assertEquals(cm.exception.args[0], msg)
self.assertEqual(cm.exception.args[0], msg)

def test_branching_odd_fixed_t_initial(self):
nphases = 3 # number of phases in trunk and each branch
Expand Down Expand Up @@ -521,4 +521,4 @@ def test_branching_odd_fixed_t_initial(self):
"(40.0, 50.0) for phase 'br0_phase3'.\n"
"Fixed t_initial of 60.0 is outside of allowed bounds (40.0, 50.0) for phase "
"'br1_phase3'.")
self.assertEquals(cm.exception.args[0], msg)
self.assertEqual(cm.exception.args[0], msg)
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def _get_rate_source_path(self, state_name, nodes, phase):
try:
var = phase.state_options[state_name]['rate_source']
except RuntimeError:
raise ValueError(f"state '{state_name}' in phase '{ phase.name}' was not given a "
raise ValueError(f"state '{state_name}' in phase '{phase.name}' was not given a "
"rate_source")

# Note the rate source must be shape-compatible with the state
Expand Down
32 changes: 16 additions & 16 deletions dymos/utils/test/test_testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def test_unequal_time_series_rel_only(self):
self.assertTrue(actual_errmsg.startswith(start_of_expected_errmsg),
f"Error message expected to start with {start_of_expected_errmsg} but "
f"instead was {actual_errmsg}")
self.assertEquals(actual_errmsg.count('>REL_TOL'), 2)
self.assertEquals(actual_errmsg.count('>ABS_TOL'), 0)
self.assertEqual(actual_errmsg.count('>REL_TOL'), 2)
self.assertEqual(actual_errmsg.count('>ABS_TOL'), 0)

def test_unequal_time_series_abs_only(self):
# slightly modify the "to be checked" time series and check that the assert is working
Expand Down Expand Up @@ -102,8 +102,8 @@ def test_unequal_time_series_abs_only(self):
self.assertTrue(actual_errmsg.startswith(start_of_expected_errmsg),
f"Error message expected to start with {start_of_expected_errmsg} but "
f"instead was {actual_errmsg}")
self.assertEquals(actual_errmsg.count('>ABS_TOL'), 2)
self.assertEquals(actual_errmsg.count('>REL_TOL'), 0)
self.assertEqual(actual_errmsg.count('>ABS_TOL'), 2)
self.assertEqual(actual_errmsg.count('>REL_TOL'), 0)

def test_unequal_time_series_abs_and_rel(self):
# slightly modify the "to be checked" time series and check that the assert is working
Expand Down Expand Up @@ -139,8 +139,8 @@ def test_unequal_time_series_abs_and_rel(self):
self.assertTrue(actual_errmsg.startswith(start_of_expected_errmsg),
f"Error message expected to start with '{start_of_expected_errmsg}' but "
f"instead was '{actual_errmsg}'")
self.assertEquals(actual_errmsg.count('>ABS_TOL'), 2)
self.assertEquals(actual_errmsg.count('>REL_TOL'), 2)
self.assertEqual(actual_errmsg.count('>ABS_TOL'), 2)
self.assertEqual(actual_errmsg.count('>REL_TOL'), 2)

# for > 100, uses the rel, x_check[15] is ~ 150
x_check[5] = x_check_5_orig
Expand All @@ -166,8 +166,8 @@ def test_unequal_time_series_abs_and_rel(self):
self.assertTrue(actual_errmsg.startswith(start_of_expected_errmsg),
f"Error message expected to start with '{start_of_expected_errmsg}' but "
f"instead was '{actual_errmsg}'")
self.assertEquals(actual_errmsg.count('>ABS_TOL'), 2)
self.assertEquals(actual_errmsg.count('>REL_TOL'), 2)
self.assertEqual(actual_errmsg.count('>ABS_TOL'), 2)
self.assertEqual(actual_errmsg.count('>REL_TOL'), 2)

# Combine the two cases where one data paint fails because of abs error and one because
# of rel error
Expand All @@ -187,8 +187,8 @@ def test_unequal_time_series_abs_and_rel(self):
self.assertTrue(actual_errmsg.startswith(start_of_expected_errmsg),
f"Error message expected to start with '{start_of_expected_errmsg}' but "
f"instead was '{actual_errmsg}'")
self.assertEquals(actual_errmsg.count('>ABS_TOL'), 3)
self.assertEquals(actual_errmsg.count('>REL_TOL'), 3)
self.assertEqual(actual_errmsg.count('>ABS_TOL'), 3)
self.assertEqual(actual_errmsg.count('>REL_TOL'), 3)

def test_no_overlapping_time(self):
t_ref, x_ref = create_linear_time_series(100, 0.0, 500.0, 0.0, 1000.0)
Expand Down Expand Up @@ -253,8 +253,8 @@ def test_multi_dimensional_unequal(self):
self.assertTrue(actual_errmsg.startswith(start_of_expected_errmsg),
f"Error message expected to start with {start_of_expected_errmsg} but "
f"instead was {actual_errmsg}")
self.assertEquals(actual_errmsg.count('>ABS_TOL'), 0)
self.assertEquals(actual_errmsg.count('>REL_TOL'), 199)
self.assertEqual(actual_errmsg.count('>ABS_TOL'), 0)
self.assertEqual(actual_errmsg.count('>REL_TOL'), 199)

def test_multi_dimensional_unequal_abs_and_rel(self):
t_ref, x_ref_1 = create_linear_time_series(10, 0.0, 500.0, 0.0, 1000.0)
Expand All @@ -279,8 +279,8 @@ def test_multi_dimensional_unequal_abs_and_rel(self):
self.assertTrue(actual_errmsg.startswith(start_of_expected_errmsg),
f"Error message expected to start with {start_of_expected_errmsg} but "
f"instead was {actual_errmsg}")
self.assertEquals(actual_errmsg.count('>ABS_TOL'), 19)
self.assertEquals(actual_errmsg.count('>REL_TOL'), 19)
self.assertEqual(actual_errmsg.count('>ABS_TOL'), 19)
self.assertEqual(actual_errmsg.count('>REL_TOL'), 19)

def test_multi_dimensional_with_overlapping_times(self):
t_ref, x_ref_1 = create_linear_time_series(100, 0.0, 500.0, 0.0, 1000.0)
Expand Down Expand Up @@ -314,8 +314,8 @@ def test_multi_dimensional_unequal_with_overlapping_times(self):
self.assertTrue(actual_errmsg.startswith(start_of_expected_errmsg),
f"Error message expected to start with {start_of_expected_errmsg} but "
f"instead was {actual_errmsg}")
self.assertEquals(actual_errmsg.count('>ABS_TOL'), 0)
self.assertEquals(actual_errmsg.count('>REL_TOL'), 99)
self.assertEqual(actual_errmsg.count('>ABS_TOL'), 0)
self.assertEqual(actual_errmsg.count('>REL_TOL'), 99)


@use_tempdirs
Expand Down
11 changes: 9 additions & 2 deletions dymos/visualization/linkage/test/test_gui.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
"""Test Dymos Linkage report GUI with using Playwright."""
import unittest
import os
try:
import playwright
except ImportError:
playwright = None


if playwright is not None:
os.system("playwright install")
from linkage_report_ui_test import dymos_linkage_gui_test_case # nopep8: E402

os.system("playwright install")
from linkage_report_ui_test import dymos_linkage_gui_test_case # nopep8: E402

if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion dymos/visualization/timeseries_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _mpl_timeseries_plots(time_units, var_units, phase_names, phases_node_path,
plt.subplots_adjust(bottom=0.23, top=0.9, left=0.2)

# save to file
plot_file_path = plot_dir_path.joinpath(f'{var_name.replace(":","_")}.png')
plot_file_path = plot_dir_path.joinpath(f'{var_name.replace(":", "_")}.png')
plt.savefig(plot_file_path, dpi=dpi)
plt.close(fig)
plotfiles.append(plot_file_path)
Expand Down
8 changes: 6 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from packaging.version import Version
from platform import python_version
from setuptools import find_packages, setup

# Setup optional dependencies
Expand Down Expand Up @@ -27,11 +29,13 @@
'testflo>=1.3.6',
'matplotlib',
'numpydoc>=1.1',
'playwright>=1.20',
'aiounittest'
]
}

# playwright dependencies are currently incompatible with python 3.12
if Version(python_version()) < Version('3.12.0'):
optional_dependencies['test'].extend(['playwright>=1.20', 'aiounittest'])

# Add an optional dependency that concatenates all others
optional_dependencies['all'] = sorted([
dependency
Expand Down

0 comments on commit 9e02030

Please sign in to comment.