Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(v3.5.9) - Action delay fix #443

Merged
merged 10 commits into from
Sep 14, 2024
4 changes: 4 additions & 0 deletions docs/source/pages/wrappers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ updated per episode. The structure of this file is defined by the **LoggerWrappe
Each episode directory includes a **monitor** folder with several CSV files for data such as observations, actions,
rewards, infos, and custom metrics. For more details, see :ref:`Output Format`.

The observations and infos CSV files have one more row than the rest of the files, because they are also saved
at the beginning of the episode (reset). Therefore, rows with the same index across files contain the observation
and info for a state, the action taken in that state, and the reward obtained from that action in that state.

WandBLogger
-------------

Expand Down
35 changes: 18 additions & 17 deletions sinergym/simulators/eplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,13 @@ def stop(self) -> None:
# Kill progress bar
if self.progress_bar is not None:
self.progress_bar.close()
# Flush all queues and wait to thread to finish (without control)
# Flush all queues and unblock thread if needed
self._flush_queues()
if self.act_queue.empty():
self.act_queue.put([0] * len(self.actuators))
# Wait for the thread to finish (without control)
self.energyplus_thread.join()
self._flush_queues()
# Delete thread
self.energyplus_thread = None
# Clean runtime callbacks
Expand Down Expand Up @@ -319,25 +323,22 @@ def _process_action(self, state_argument: int) -> None:
self._init_system(self.energyplus_state)
if not self.system_ready:
return
# If not value in action queue --> do nothing
if self.act_queue.empty():
return
# Get next action from queue and check type
next_action = self.act_queue.get()
# self.logger.debug('ACTION get from queue: {}'.format(next_action))

# Set the action values obtained in actuator handlers
for i, (act_name, act_handle) in enumerate(
self.actuator_handlers.items()):
self.exchange.set_actuator_value(
state=state_argument,
actuator_handle=act_handle,
actuator_value=next_action[i]
)

# self.logger.debug(
# 'Set in actuator {} value {}.'.format(
# act_name, next_action[i]))
if not self.simulation_complete:
# Set the action values obtained in actuator handlers
for i, (act_name, act_handle) in enumerate(
self.actuator_handlers.items()):
self.exchange.set_actuator_value(
state=state_argument,
actuator_handle=act_handle,
actuator_value=next_action[i]
)

# self.logger.debug(
# 'Set in actuator {} value {}.'.format(
# act_name, next_action[i]))

def _init_system(self, state_argument: int) -> None:
"""Indicate whether the system is ready to work. After waiting until API data is available, handlers are initialized, and warmup flag is correct.
Expand Down
20 changes: 20 additions & 0 deletions sinergym/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ def log_interaction(self,
truncated (bool): Truncation flag.
custom_metrics (List[Any]): Custom metric data. Default is None.
"""
if isinstance(action, np.ndarray):
action = action.tolist()
if isinstance(obs, np.ndarray):
obs = obs.tolist()
self.observations.append(obs)
self.actions.append(action)
self.rewards.append(reward)
Expand All @@ -146,6 +150,22 @@ def log_norm_obs(self, norm_obs: List[float]) -> None:
"""
self.normalized_observations.append(norm_obs)

def log_obs(self, obs: List[float]) -> None:
    """Log observation data.

    NumPy arrays are converted to plain Python lists before being stored,
    in the same way ``log_interaction`` does, so that every entry in
    ``self.observations`` has a consistent type regardless of which
    method logged it.

    Args:
        obs (List[float]): Observation data.
    """
    # Keep stored observations homogeneous with those appended by
    # log_interaction, which also converts ndarrays to lists.
    if isinstance(obs, np.ndarray):
        obs = obs.tolist()
    self.observations.append(obs)

def log_info(self, info: Dict[str, Any]) -> None:
    """Append an info dictionary to the logger's stored infos.

    Args:
        info (Dict[str, Any]): Info data to store in ``self.infos``.
    """
    stored_infos = self.infos
    stored_infos.append(info)

def reset_data(self) -> None:
"""Reset logger interactions data"""
self.interactions = 0
Expand Down
34 changes: 23 additions & 11 deletions sinergym/utils/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,16 @@ def reset(self,
# Environment reset
obs, info = self.env.reset(seed=seed, options=options)

# Log reset observation
if is_wrapped(self.env, NormalizeObservation):
self.data_logger.log_norm_obs(obs)
self.data_logger.log_obs(
self.get_wrapper_attr('unwrapped_observation'))
else:
self.data_logger.log_obs(obs)

self.data_logger.log_info(info)

return obs, info

def step(self, action: Union[int, np.ndarray]) -> Tuple[
Expand Down Expand Up @@ -951,17 +961,17 @@ def calculate_custom_metrics(self,
def get_episode_summary(self) -> Dict[str, float]:
# Get information from logger
comfort_terms = [info['comfort_term']
for info in self.data_logger.infos]
for info in self.data_logger.infos[1:]]
energy_terms = [info['energy_term']
for info in self.data_logger.infos]
for info in self.data_logger.infos[1:]]
abs_comfort_penalties = [info['abs_comfort_penalty']
for info in self.data_logger.infos]
for info in self.data_logger.infos[1:]]
abs_energy_penalties = [info['abs_energy_penalty']
for info in self.data_logger.infos]
for info in self.data_logger.infos[1:]]
temperature_violations = [info['total_temperature_violation']
for info in self.data_logger.infos]
for info in self.data_logger.infos[1:]]
power_demands = [info['total_power_demand']
for info in self.data_logger.infos]
for info in self.data_logger.infos[1:]]
try:
comfort_violation_time = len(
[value for value in temperature_violations if value > 0]) / self.get_wrapper_attr('timestep') * 100
Expand Down Expand Up @@ -1108,12 +1118,14 @@ def dump_log_files(self) -> None:
# Infos (except excluded keys)
with open(monitor_path + '/infos.csv', 'w') as f:
writer = csv.writer(f)
column_names = [key for key in episode_data.infos[0].keys(
column_names = [key for key in episode_data.infos[-1].keys(
) if key not in self.get_wrapper_attr('info_excluded_keys')]
# reset_values = [None for _ in column_names]
# Skip reset row
rows = [[value for key, value in info.items() if key not in self.get_wrapper_attr(
'info_excluded_keys')] for info in episode_data.infos]
'info_excluded_keys')] for info in episode_data.infos[1:]]
writer.writerow(column_names)
# write null row for reset
writer.writerow([None for _ in range(len(column_names))])
writer.writerows(rows)

# Agent Actions
Expand All @@ -1132,8 +1144,8 @@ def dump_log_files(self) -> None:
writer.writerow(self.get_wrapper_attr('action_variables'))
# reset_action = [None for _ in range(
# len(self.get_wrapper_attr('action_variables')))]
simulated_actions = [info['action']
for info in episode_data.infos]
simulated_actions = [list(info['action'])
for info in episode_data.infos[1:]]
if isinstance(simulated_actions[0], list):
writer.writerows(simulated_actions)
else:
Expand Down
2 changes: 1 addition & 1 deletion sinergym/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.5.8
3.5.9
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def calculate_custom_metrics(self,
def get_episode_summary(self) -> Dict[str, float]:
# Get information from logger
power_demands = [info['total_power_demand']
for info in self.data_logger.infos]
for info in self.data_logger.infos[1:]]

# Data summary
data_summary = {
Expand Down
22 changes: 11 additions & 11 deletions tests/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,10 +444,10 @@ def test_logger_wrapper(env_demo):

env.reset()
# Reset logs only the initial observation and info
assert len(logger.observations) == 0
assert len(logger.observations) == 1
assert len(logger.actions) == 0
assert len(logger.rewards) == 0
assert len(logger.infos) == 0
assert len(logger.infos) == 1
assert len(logger.terminateds) == 0
assert len(logger.truncateds) == 0
assert len(logger.custom_metrics) == 0
Expand All @@ -459,10 +459,10 @@ def test_logger_wrapper(env_demo):
env.step(a)

# Check that the logger has stored the data
assert len(logger.observations) == 3
assert len(logger.observations) == 4
assert len(logger.actions) == 3
assert len(logger.rewards) == 3
assert len(logger.infos) == 3
assert len(logger.infos) == 4
assert len(logger.terminateds) == 3
assert len(logger.truncateds) == 3
assert len(logger.custom_metrics) == 0
Expand All @@ -483,10 +483,10 @@ def test_logger_wrapper(env_demo):

# Check if reset method reset logger data too
env.reset()
assert len(logger.observations) == 0
assert len(logger.observations) == 1
assert len(logger.actions) == 0
assert len(logger.rewards) == 0
assert len(logger.infos) == 0
assert len(logger.infos) == 1
assert len(logger.terminateds) == 0
assert len(logger.truncateds) == 0
assert len(logger.custom_metrics) == 0
Expand Down Expand Up @@ -526,10 +526,10 @@ def test_custom_loggers(env_demo, custom_logger_wrapper):
env.step(a)

# Check that the logger has stored the data (custom metrics too)
assert len(logger.observations) == 3
assert len(logger.observations) == 4
assert len(logger.actions) == 3
assert len(logger.rewards) == 3
assert len(logger.infos) == 3
assert len(logger.infos) == 4
assert len(logger.terminateds) == 3
assert len(logger.truncateds) == 3
assert len(logger.custom_metrics) == 3
Expand All @@ -543,10 +543,10 @@ def test_custom_loggers(env_demo, custom_logger_wrapper):

# Check if reset method reset logger data too (custom_metrics too)
env.reset()
assert len(logger.observations) == 0
assert len(logger.observations) == 1
assert len(logger.actions) == 0
assert len(logger.rewards) == 0
assert len(logger.infos) == 0
assert len(logger.infos) == 1
assert len(logger.terminateds) == 0
assert len(logger.truncateds) == 0
assert len(logger.custom_metrics) == 0
Expand Down Expand Up @@ -605,7 +605,7 @@ def test_CSVlogger_wrapper(env_demo):
with open(episode_path + '/monitor/observations.csv', mode='r', newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',')
# Header row, reset and 10 steps (12)
assert len(list(reader)) == 11
assert len(list(reader)) == 12

# If env is wrapped with normalize obs...
if is_wrapped(env, NormalizeObservation):
Expand Down
Loading