Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Simulation Save and Load in JSON Files #695

Open
wants to merge 10 commits into
base: develop
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Attention: The newest changes should be on top -->

### Added

- ENH: Expansion of Encoders Implementation for Full Flights. [#679](https://github.com/RocketPy-Team/RocketPy/pull/679)
- ENH: Generic Surfaces and Generic Linear Surfaces [#680](https://github.com/RocketPy-Team/RocketPy/pull/680)
- ENH: Free-Form Fins [#694](https://github.com/RocketPy-Team/RocketPy/pull/694)
- ENH: Expand Polation Options for ND Functions. [#691](https://github.com/RocketPy-Team/RocketPy/pull/691)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ netCDF4>=1.6.4
requests
pytz
simplekml
dill
123 changes: 113 additions & 10 deletions rocketpy/_encoders.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
"""Defines a custom JSON encoder for RocketPy objects."""

import base64
import json
import types
from datetime import datetime
from importlib import import_module

import dill
import numpy as np

from rocketpy.mathutils.function import Function


class RocketPyEncoder(json.JSONEncoder):
"""NOTE: This is still under construction, please don't use it yet."""
"""Custom JSON encoder for RocketPy objects. It defines how to encode
different types of objects to a JSON supported format."""

def default(self, o):
if isinstance(
Expand All @@ -33,11 +35,112 @@
return float(o)
elif isinstance(o, np.ndarray):
return o.tolist()
elif isinstance(o, datetime):
return [o.year, o.month, o.day, o.hour]

Check warning on line 39 in rocketpy/_encoders.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/_encoders.py#L39

Added line #L39 was not covered by tests
elif hasattr(o, "__iter__") and not isinstance(o, str):
return list(o)
elif hasattr(o, "to_dict"):
return o.to_dict()
# elif isinstance(o, Function):
# return o.__dict__()
elif isinstance(o, (Function, types.FunctionType)):
return repr(o)
encoding = o.to_dict()

encoding["signature"] = get_class_signature(o)

return encoding

elif hasattr(o, "__dict__"):
exception_set = {"prints", "plots"}
encoding = {
key: value
for key, value in o.__dict__.items()
if key not in exception_set
}

if "rocketpy" in o.__class__.__module__ and not any(
subclass in o.__class__.__name__
for subclass in ["FlightPhase", "TimeNode"]
):
encoding["signature"] = get_class_signature(o)

return encoding
else:
return json.JSONEncoder.default(self, o)
return super().default(o)

Check warning on line 65 in rocketpy/_encoders.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/_encoders.py#L65

Added line #L65 was not covered by tests


class RocketPyDecoder(json.JSONDecoder):
"""Custom JSON decoder for RocketPy objects. It defines how to decode
different types of objects from a JSON supported format."""

def __init__(self, *args, **kwargs):
super().__init__(object_hook=self.object_hook, *args, **kwargs)

def object_hook(self, obj):
if "signature" in obj:
signature = obj.pop("signature")

try:
class_ = get_class_from_signature(signature)

if hasattr(class_, "from_dict"):
return class_.from_dict(obj)
else:
# Filter keyword arguments
kwargs = {
key: value
for key, value in obj.items()
if key in class_.__init__.__code__.co_varnames
}

return class_(**kwargs)
except ImportError: # AttributeException
return obj

Check warning on line 94 in rocketpy/_encoders.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/_encoders.py#L93-L94

Added lines #L93 - L94 were not covered by tests
else:
return obj


def get_class_signature(obj):
class_ = obj.__class__

return f"{class_.__module__}.{class_.__name__}"


def get_class_from_signature(signature):
module_name, class_name = signature.rsplit(".", 1)

module = import_module(module_name)

return getattr(module, class_name)


def to_hex_encode(obj, encoder=base64.b85encode):
"""Converts an object to hex representation using dill.

Parameters
----------
obj : object
Object to be converted to hex.
encoder : callable, optional
Function to encode the bytes. Default is base64.b85encode.

Returns
-------
bytes
Object converted to bytes.
"""
return encoder(dill.dumps(obj)).hex()


def from_hex_decode(obj_bytes, decoder=base64.b85decode):
"""Converts an object from hex representation using dill.

Parameters
----------
obj_bytes : str
Hex string to be converted to object.
decoder : callable, optional
Function to decode the bytes. Default is base64.b85decode.

Returns
-------
object
Object converted from bytes.
"""
return dill.loads(decoder(bytes.fromhex(obj_bytes)))
117 changes: 110 additions & 7 deletions rocketpy/environment/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,12 +367,15 @@
self.standard_g = 9.80665
self.__weather_model_map = WeatherModelMapping()
self.__atm_type_file_to_function_map = {
("forecast", "GFS"): fetch_gfs_file_return_dataset,
("forecast", "NAM"): fetch_nam_file_return_dataset,
("forecast", "RAP"): fetch_rap_file_return_dataset,
("forecast", "HIRESW"): fetch_hiresw_file_return_dataset,
("ensemble", "GEFS"): fetch_gefs_ensemble,
# ("ensemble", "CMC"): fetch_cmc_ensemble,
"forecast": {
"GFS": fetch_gfs_file_return_dataset,
"NAM": fetch_nam_file_return_dataset,
"RAP": fetch_rap_file_return_dataset,
"HIRESW": fetch_hiresw_file_return_dataset,
},
"ensemble": {
"GEFS": fetch_gefs_ensemble,
},
}
self.__standard_atmosphere_layers = {
"geopotential_height": [ # in geopotential m
Expand Down Expand Up @@ -1287,7 +1290,10 @@
self.process_windy_atmosphere(file)
elif type in ["forecast", "reanalysis", "ensemble"]:
dictionary = self.__validate_dictionary(file, dictionary)
fetch_function = self.__atm_type_file_to_function_map.get((type, file))
try:
fetch_function = self.__atm_type_file_to_function_map[type][file]
except KeyError:
fetch_function = None

# Fetches the dataset using OpenDAP protocol or uses the file path
dataset = fetch_function() if fetch_function is not None else file
Expand Down Expand Up @@ -2847,6 +2853,103 @@
arc_seconds = (remainder * 60 - arc_minutes) * 60
return degrees, arc_minutes, arc_seconds

def to_dict(self):
return {
"gravity": self.gravity,
"date": self.date,
"latitude": self.latitude,
"longitude": self.longitude,
"elevation": self.elevation,
"datum": self.datum,
"timezone": self.timezone,
"_max_expected_height": self.max_expected_height,
"atmospheric_model_type": self.atmospheric_model_type,
"pressure": self.pressure,
"barometric_height": self.barometric_height,
"temperature": self.temperature,
"wind_velocity_x": self.wind_velocity_x,
"wind_velocity_y": self.wind_velocity_y,
"wind_heading": self.wind_heading,
"wind_direction": self.wind_direction,
"wind_speed": self.wind_speed,
}

@classmethod
def from_dict(cls, data): # pylint: disable=too-many-statements
environment = cls(
gravity=data["gravity"],
date=data["date"],
latitude=data["latitude"],
longitude=data["longitude"],
elevation=data["elevation"],
datum=data["datum"],
timezone=data["timezone"],
max_expected_height=data["_max_expected_height"],
)
atmospheric_model = data["atmospheric_model_type"]

if atmospheric_model == "standard_atmosphere":
environment.set_atmospheric_model("standard_atmosphere")
elif atmospheric_model == "custom_atmosphere":
environment.set_atmospheric_model(

Check warning on line 2894 in rocketpy/environment/environment.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/environment/environment.py#L2894

Added line #L2894 was not covered by tests
type="custom_atmosphere",
pressure=data["pressure"],
temperature=data["temperature"],
wind_u=data["wind_velocity_x"],
wind_v=data["wind_velocity_y"],
)
else:
environment.__set_pressure_function(data["pressure"])
environment.__set_barometric_height_function(data["barometric_height"])
environment.__set_temperature_function(data["temperature"])
environment.__set_wind_velocity_x_function(data["wind_velocity_x"])
environment.__set_wind_velocity_y_function(data["wind_velocity_y"])
environment.__set_wind_heading_function(data["wind_heading"])
environment.__set_wind_direction_function(data["wind_direction"])
environment.__set_wind_speed_function(data["wind_speed"])
environment.elevation = data["elevation"]
environment.max_expected_height = data["_max_expected_height"]

if atmospheric_model in ["windy", "forecast", "reanalysis", "ensemble"]:
environment.atmospheric_model_init_date = data[

Check warning on line 2914 in rocketpy/environment/environment.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/environment/environment.py#L2914

Added line #L2914 was not covered by tests
"atmospheric_model_init_date"
]
environment.atmospheric_model_end_date = data[

Check warning on line 2917 in rocketpy/environment/environment.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/environment/environment.py#L2917

Added line #L2917 was not covered by tests
"atmospheric_model_end_date"
]
environment.atmospheric_model_interval = data[

Check warning on line 2920 in rocketpy/environment/environment.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/environment/environment.py#L2920

Added line #L2920 was not covered by tests
"atmospheric_model_interval"
]
environment.atmospheric_model_init_lat = data[

Check warning on line 2923 in rocketpy/environment/environment.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/environment/environment.py#L2923

Added line #L2923 was not covered by tests
"atmospheric_model_init_lat"
]
environment.atmospheric_model_end_lat = data[

Check warning on line 2926 in rocketpy/environment/environment.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/environment/environment.py#L2926

Added line #L2926 was not covered by tests
"atmospheric_model_end_lat"
]
environment.atmospheric_model_init_lon = data[

Check warning on line 2929 in rocketpy/environment/environment.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/environment/environment.py#L2929

Added line #L2929 was not covered by tests
"atmospheric_model_init_lon"
]
environment.atmospheric_model_end_lon = data[

Check warning on line 2932 in rocketpy/environment/environment.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/environment/environment.py#L2932

Added line #L2932 was not covered by tests
"atmospheric_model_end_lon"
]

if atmospheric_model == "ensemble":
environment.level_ensemble = data["level_ensemble"]
environment.height_ensemble = data["height_ensemble"]
environment.temperature_ensemble = data["temperature_ensemble"]
environment.wind_u_ensemble = data["wind_u_ensemble"]
environment.wind_v_ensemble = data["wind_v_ensemble"]
environment.wind_heading_ensemble = data["wind_heading_ensemble"]
environment.wind_direction_ensemble = data["wind_direction_ensemble"]
environment.wind_speed_ensemble = data["wind_speed_ensemble"]
environment.num_ensemble_members = data["num_ensemble_members"]

Check warning on line 2945 in rocketpy/environment/environment.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/environment/environment.py#L2937-L2945

Added lines #L2937 - L2945 were not covered by tests

environment.calculate_density_profile()
environment.calculate_speed_of_sound_profile()
environment.calculate_dynamic_viscosity()

return environment


if __name__ == "__main__":
import doctest
Expand Down
46 changes: 46 additions & 0 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
RBFInterpolator,
)

from rocketpy._encoders import from_hex_decode, to_hex_encode

# Numpy 1.x compatibility,
# TODO: remove these lines when all dependencies support numpy>=2.0.0
if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
Expand Down Expand Up @@ -3387,6 +3389,50 @@ def __validate_extrapolation(self, extrapolation):
extrapolation = "natural"
return extrapolation

def to_dict(self):
"""Serializes the Function instance to a dictionary.

Returns
-------
dict
A dictionary containing the Function's attributes.
"""
source = self.source

if callable(source):
source = to_hex_encode(source)

return {
"source": source,
"title": self.title,
"inputs": self.__inputs__,
"outputs": self.__outputs__,
"interpolation": self.__interpolation__,
"extrapolation": self.__extrapolation__,
}

@classmethod
def from_dict(cls, func_dict):
"""Creates a Function instance from a dictionary.

Parameters
----------
func_dict
The JSON like Function dictionary.
"""
source = func_dict["source"]
if func_dict["interpolation"] is None and func_dict["extrapolation"] is None:
source = from_hex_decode(source)

return cls(
source=source,
interpolation=func_dict["interpolation"],
extrapolation=func_dict["extrapolation"],
inputs=func_dict["inputs"],
outputs=func_dict["outputs"],
title=func_dict["title"],
)


class PiecewiseFunction(Function):
"""Class for creating piecewise functions. These kind of functions are
Expand Down
34 changes: 34 additions & 0 deletions rocketpy/motors/hybrid_motor.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,3 +612,37 @@ def all_info(self):
"""Prints out all data and graphs available about the Motor."""
self.prints.all()
self.plots.all()

@classmethod
def from_dict(cls, data):
motor = cls(
thrust_source=data["thrust"],
burn_time=data["_burn_time"],
nozzle_radius=data["nozzle_radius"],
dry_mass=data["_dry_mass"],
center_of_dry_mass_position=data["center_of_dry_mass_position"],
dry_inertia=(
data["dry_I_11"],
data["dry_I_22"],
data["dry_I_33"],
data["dry_I_12"],
data["dry_I_13"],
data["dry_I_23"],
),
interpolation_method=data["interpolate"],
coordinate_system_orientation=data["coordinate_system_orientation"],
grain_number=data["grain_number"],
grain_density=data["grain_density"],
grain_outer_radius=data["grain_outer_radius"],
grain_initial_inner_radius=data["grain_initial_inner_radius"],
grain_initial_height=data["grain_initial_height"],
grain_separation=data["grain_separation"],
grains_center_of_mass_position=data["grains_center_of_mass_position"],
nozzle_position=data["nozzle_position"],
throat_radius=data["throat_radius"],
)

for tank in data["positioned_tanks"]:
motor.add_tank(tank["tank"], tank["position"])

return motor
Loading
Loading