diff --git a/openrl/configs/utils.py b/openrl/configs/utils.py index 53e1f4d2..b3198748 100644 --- a/openrl/configs/utils.py +++ b/openrl/configs/utils.py @@ -16,7 +16,7 @@ """""" - +import os import re import tempfile @@ -83,9 +83,19 @@ def __call__(self, parser, cfg, values, option_string=None): # Load the rendered content as a dictionary data = yaml.safe_load(rendered_content) - # Write the result to a temporary file - with tempfile.NamedTemporaryFile("w", delete=True, suffix=".yaml") as temp_file: + # Write the result to a temporary file. Not work on Windows. + # with tempfile.NamedTemporaryFile("w", delete=True, suffix=".yaml") as temp_file: + # yaml.dump(data, temp_file) + # temp_file.seek(0) # Move to the beginning of the file + # # Use the default behavior of ActionConfigFile to handle the temporary file + # super().__call__(parser, cfg, temp_file.name, option_string) + + # Write the result to a temporary file. This works on all platforms. + temp_fd, temp_filename = tempfile.mkstemp(suffix=".yaml") + with os.fdopen(temp_fd, 'w') as temp_file: yaml.dump(data, temp_file) - temp_file.seek(0) # Move to the beginning of the file + try: # Use the default behavior of ActionConfigFile to handle the temporary file - super().__call__(parser, cfg, temp_file.name, option_string) + super().__call__(parser, cfg, temp_filename, option_string) + finally: + os.remove(temp_filename)