From 335d28565ade209403cc226ece28f43fe354abe4 Mon Sep 17 00:00:00 2001 From: shuo-yuan Date: Wed, 14 Jan 2026 07:40:09 +0000 Subject: [PATCH] osworld integration --- .../examples/run_tinker/tinker_osworld.sh | 82 + .../examples/run_tinker/tinker_osworld.yaml | 39 + skyrl-agent/skyrl_agent/agents/base.py | 371 ++-- .../skyrl_agent/agents/react/react_agent.py | 187 +- .../skyrl_agent/agents/react/react_runner.py | 75 +- .../skyrl_agent/config/configuration_utils.py | 10 +- .../skyrl_agent/dispatcher/dispatchers.py | 36 +- .../integrations/tinker/tinker_train.py | 144 +- .../skyrl_agent/tasks/osworld/__init__.py | 0 .../tasks/osworld/desktop_env/__init__.py | 1 + .../tasks/osworld/desktop_env/actions.py | 203 +++ .../desktop_env/controllers/__init__.py | 0 .../osworld/desktop_env/controllers/python.py | 472 +++++ .../osworld/desktop_env/controllers/setup.py | 865 +++++++++ .../tasks/osworld/desktop_env/desktop_env.py | 616 +++++++ .../osworld/desktop_env/evaluators/README.md | 224 +++ .../desktop_env/evaluators/__init__.py | 5 + .../evaluators/getters/__init__.py | 39 + .../desktop_env/evaluators/getters/calc.py | 15 + .../desktop_env/evaluators/getters/chrome.py | 1392 +++++++++++++++ .../desktop_env/evaluators/getters/file.py | 125 ++ .../desktop_env/evaluators/getters/general.py | 42 + .../desktop_env/evaluators/getters/gimp.py | 75 + .../desktop_env/evaluators/getters/impress.py | 126 ++ .../desktop_env/evaluators/getters/info.py | 24 + .../desktop_env/evaluators/getters/misc.py | 204 +++ .../desktop_env/evaluators/getters/replay.py | 20 + .../desktop_env/evaluators/getters/vlc.py | 95 + .../desktop_env/evaluators/getters/vscode.py | 35 + .../evaluators/metrics/__init__.py | 159 ++ .../evaluators/metrics/basic_os.py | 68 + .../desktop_env/evaluators/metrics/chrome.py | 421 +++++ .../desktop_env/evaluators/metrics/docs.py | 961 ++++++++++ .../desktop_env/evaluators/metrics/general.py | 506 ++++++ .../desktop_env/evaluators/metrics/gimp.py | 573 ++++++ .../evaluators/metrics/libreoffice.py | 28 + .../desktop_env/evaluators/metrics/others.py | 90 + .../desktop_env/evaluators/metrics/pdf.py | 31 + .../desktop_env/evaluators/metrics/slides.py | 582 ++++++ .../desktop_env/evaluators/metrics/table.py | 517 ++++++ .../evaluators/metrics/thunderbird.py | 176 ++ .../desktop_env/evaluators/metrics/utils.py | 702 ++++++++ .../desktop_env/evaluators/metrics/vlc.py | 524 ++++++ .../desktop_env/evaluators/metrics/vscode.py | 283 +++ .../osworld/desktop_env/providers/README.md | 0 .../osworld/desktop_env/providers/__init__.py | 36 + .../providers/aws/AWS_GUIDELINE.md | 52 + .../desktop_env/providers/aws/__init__.py | 0 .../desktop_env/providers/aws/manager.py | 271 +++ .../desktop_env/providers/aws/provider.py | 186 ++ .../providers/aws/provider_with_proxy.py | 315 ++++ .../desktop_env/providers/aws/proxy_pool.py | 193 ++ .../desktop_env/providers/azure/__init__.py | 0 .../desktop_env/providers/azure/manager.py | 85 + .../desktop_env/providers/azure/provider.py | 205 +++ .../osworld/desktop_env/providers/base.py | 97 + .../providers/docker/DOCKER_GUIDELINE.md | 29 + .../desktop_env/providers/docker/manager.py | 123 ++ .../desktop_env/providers/docker/provider.py | 294 +++ .../desktop_env/providers/gcp/__init__.py | 0 .../desktop_env/providers/gcp/manager.py | 0 .../desktop_env/providers/gcp/provider.py | 0 .../providers/virtualbox/INSTALL_VITUALBOX.md | 11 + .../providers/virtualbox/__init__.py | 0 .../providers/virtualbox/manager.py | 461 +++++ .../providers/virtualbox/provider.py | 120 ++ .../providers/vmware/INSTALL_VMWARE.md | 23 + .../desktop_env/providers/vmware/__init__.py | 0 .../desktop_env/providers/vmware/manager.py | 453 +++++ .../desktop_env/providers/vmware/provider.py | 103 ++ .../osworld/desktop_env/server/README.md | 657 +++++++ .../tasks/osworld/desktop_env/server/main.py | 1570 +++++++++++++++++ .../desktop_env/server/osworld_server.service | 16 + .../osworld/desktop_env/server/pyxcursor.py | 146 ++ .../tasks/osworld/desktop_env_interface.py | 123 ++ .../skyrl_agent/tasks/osworld/osworld_task.py | 265 +++ .../skyrl_agent/tools/osworld_tools.py | 229 +++ 77 files changed, 16790 insertions(+), 416 deletions(-) create mode 100644 skyrl-agent/examples/run_tinker/tinker_osworld.sh create mode 100644 skyrl-agent/examples/run_tinker/tinker_osworld.yaml create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/__init__.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/__init__.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/actions.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/controllers/__init__.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/controllers/python.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/controllers/setup.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/desktop_env.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/README.md create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/__init__.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/__init__.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/calc.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/chrome.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/file.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/general.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/gimp.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/impress.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/info.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/misc.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/replay.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/vlc.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/vscode.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/__init__.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/basic_os.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/chrome.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/docs.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/general.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/gimp.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/libreoffice.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/others.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/pdf.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/slides.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/table.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/thunderbird.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/utils.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/vlc.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/vscode.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/README.md create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/__init__.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/AWS_GUIDELINE.md create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/__init__.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/manager.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/provider.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/provider_with_proxy.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/proxy_pool.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/azure/__init__.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/azure/manager.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/azure/provider.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/base.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/docker/DOCKER_GUIDELINE.md create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/docker/manager.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/docker/provider.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/gcp/__init__.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/gcp/manager.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/gcp/provider.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/virtualbox/INSTALL_VITUALBOX.md create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/virtualbox/__init__.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/virtualbox/manager.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/virtualbox/provider.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/vmware/INSTALL_VMWARE.md create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/vmware/__init__.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/vmware/manager.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/vmware/provider.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/server/README.md create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/server/main.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/server/osworld_server.service create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/server/pyxcursor.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/desktop_env_interface.py create mode 100644 skyrl-agent/skyrl_agent/tasks/osworld/osworld_task.py create mode 100644 skyrl-agent/skyrl_agent/tools/osworld_tools.py diff --git a/skyrl-agent/examples/run_tinker/tinker_osworld.sh b/skyrl-agent/examples/run_tinker/tinker_osworld.sh new file mode 100644 index 000000000..a99b0d605 --- /dev/null +++ b/skyrl-agent/examples/run_tinker/tinker_osworld.sh @@ -0,0 +1,82 @@ +#!/bin/bash +# set -x + +# ============================================================================= +# Tinker RL Training for MemAgent Task +# ============================================================================= +# This script demonstrates how to train a model on ruler/hotpotqa using: +# - GRPO (Group Relative Policy Optimization) for advantages +# - PPO loss for stable training +# - MemAgent tool with multi-turn interactions +# ============================================================================= + +# Data paths +DATASET_FILE="/home/ubuntu/shuo/osworld/OSWorld_llm_agentsynth/osworld_train_8.parquet" + +EVAL_DATASET_FILE="/home/ubuntu/shuo/osworld/OSWorld_llm_agentsynth/osworld_train_8.parquet" + +# Output directory +NAME="${NAME:-jan03_qwen3_8b_osworld_tinker_lr4e_5_rank128}" +OUTPUT_DIR="/home/ubuntu/shuo/osworld/checkpoints/${NAME}" +mkdir -p "$OUTPUT_DIR" + +# Model configuration +MODEL_NAME="${MODEL_NAME:-Qwen/Qwen3-8B}" +LORA_RANK="${LORA_RANK:-128}" + +# Training hyperparameters +BATCH_SIZE="${BATCH_SIZE:-8}" +LEARNING_RATE="${LEARNING_RATE:-4e-5}" +MAX_STEPS="${MAX_STEPS:-50}" +SAVE_EVERY="${SAVE_EVERY:-5}" +EVAL_EVERY="${EVAL_EVERY:-10}" + +# RL configuration +LOSS_FN="${LOSS_FN:-ppo}" +GROUP_SIZE="${GROUP_SIZE:-8}" # Should match num_trajectories in YAML +NORMALIZE_ADVANTAGES="${NORMALIZE_ADVANTAGES:-false}" + +# Logging +WANDB_PROJECT="${WANDB_PROJECT:-tinker-osw}" +WANDB_NAME="${WANDB_NAME:-${NAME}}" + +# Task configuration +TASK_YAML="./examples/run_tinker/tinker_osworld.yaml" + +echo "================================================" +echo "Tinker RL Training Configuration - OSWorld" +echo "================================================" +echo "Model: $MODEL_NAME" +echo "Dataset: $DATASET_FILE" +echo "Task YAML: $TASK_YAML" +echo "Batch Size: $BATCH_SIZE" +echo "Group Size (GRPO): $GROUP_SIZE" +echo "Max Steps: $MAX_STEPS" +echo "Output: $OUTPUT_DIR" +echo "================================================" + +# Run training +# UV_NO_SYNC=1 prevents uv from trying to (re)install dependencies (like vllm); +# make sure required deps are already installed in the active env. +LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libgomp.so.1 UV_NO_SYNC=1 uv run --active --extra tinker --env-file .env -m skyrl_agent.integrations.tinker.tinker_train \ + model_name="$MODEL_NAME" \ + skyrl_agent_task_yaml="$TASK_YAML" \ + dataset_file="$DATASET_FILE" \ + eval_dataset_file="$EVAL_DATASET_FILE" \ + batch_size="$BATCH_SIZE" \ + learning_rate="$LEARNING_RATE" \ + lora_rank="$LORA_RANK" \ + max_steps="$MAX_STEPS" \ + save_every="$SAVE_EVERY" \ + loss_fn="$LOSS_FN" \ + group_size="$GROUP_SIZE" \ + normalize_advantages="$NORMALIZE_ADVANTAGES" \ + wandb_project="$WANDB_PROJECT" \ + wandb_name="$WANDB_NAME" \ + log_dir="$OUTPUT_DIR" \ + "$@" + +echo "================================================" +echo "Training completed!" +echo "Checkpoints saved to: ${OUTPUT_DIR}/${WANDB_NAME}_*" +echo "================================================" diff --git a/skyrl-agent/examples/run_tinker/tinker_osworld.yaml b/skyrl-agent/examples/run_tinker/tinker_osworld.yaml new file mode 100644 index 000000000..41b211287 --- /dev/null +++ b/skyrl-agent/examples/run_tinker/tinker_osworld.yaml @@ -0,0 +1,39 @@ +agent_cls: skyrl_agent.agents.react.ReActAgent + +task: skyrl_agent.tasks.osworld.osworld_task.OSWorldTask + +tools: ["finish", "osworld_action"] + +data: + instance_key: instance + data_source_key: data_source + +generator: + infer_backend: tinker + use_cpu_node: false + backend_config: null + num_trajectories: 8 # need to be the same as the num_trajectories in the verl config + max_iterations: 15 + max_prompt_length: 15000 + sampling_params: + temperature: 1.0 + top_p: 0.95 + max_tokens: 3000 + val_config: + num_trajectories: 1 + sampling_params: + temperature: 0.6 + top_p: 0.95 + max_tokens: 3000 + remove_think_tokens: false + vision_is_active: false + qwen3_enable_thinking: true + qwen3_acc_thinking: false + history_length: 3 + path_to_vm: "/path/to/your/vm.qcow2" + +dispatcher: + type: async_fix_pool + scheduler: naive + max_parallel_agents: 8 + max_eval_parallel_agents: 8 \ No newline at end of file diff --git a/skyrl-agent/skyrl_agent/agents/base.py b/skyrl-agent/skyrl_agent/agents/base.py index b4d801def..11a062803 100644 --- a/skyrl-agent/skyrl_agent/agents/base.py +++ b/skyrl-agent/skyrl_agent/agents/base.py @@ -10,30 +10,24 @@ from transformers import AutoTokenizer from dataclasses import dataclass -from skyrl_agent.integrations.base import ( - build_backend, - build_generator_input, - build_generator_output, - _import_object, - AsyncInferBackend, -) +from skyrl_agent.integrations.base import build_backend, build_generator_input, build_generator_output, _import_object, AsyncInferBackend from skyrl_agent.dispatcher.dispatchers import DISPATCHER_REGISTRY, DispatcherType from skyrl_agent.config.configuration_utils import TASK_CONFIG_REGISTRY, get_field_from_config, TrajectoryConfig from skyrl_agent.functional.chat_template import chat_template, chat_template_qwen3_thinking from skyrl_agent.functional.utils import transitions_to_training_data from .mapping import AGENT_TRAJECTORY_REGISTRY +import math class CompleterOutput: - pass - + pass @dataclass class RuntimeConfig: - runtime_initializer: Optional[Callable] + runtime_initializer: Optional[Callable] instruction_getter: Callable - config_builder: Optional[Callable] - completer: Optional[Callable] + config_builder: Optional[Callable] + completer: Optional[Callable] evaluator: Callable @classmethod @@ -50,13 +44,7 @@ def safe_import(cfg, key): instruction_getter = safe_import(cfg, "instruction_getter") # If optional; else raise if missing completer = safe_import(cfg, "completer") evaluator = safe_import(cfg, "evaluator") - return cls( - runtime_initializer=runtime_initializer, - config_builder=config_builder, - instruction_getter=instruction_getter, - completer=completer, - evaluator=evaluator, - ) + return cls(runtime_initializer=runtime_initializer, config_builder=config_builder, instruction_getter=instruction_getter, completer=completer, evaluator=evaluator) @dataclass @@ -75,26 +63,18 @@ class TrajectoryResult(TypedDict): class BaseTrajectory(ABC): - def __init__( - self, - cfg: TrajectoryConfig, - data: Dict[str, Any], - infer_engine: AsyncInferBackend, - tokenizer: AutoTokenizer, - task: BaseTask, - val_mode: bool = False, - ) -> None: + def __init__(self, cfg: TrajectoryConfig, data: Dict[str, Any], infer_engine: AsyncInferBackend, tokenizer: AutoTokenizer, task: BaseTask, val_mode: bool=False) -> None: super().__init__() self.cfg = cfg self.data = data - self.infer_engine = infer_engine + self.infer_engine = infer_engine self.tokenizer = tokenizer self.task = task self.val_mode = val_mode self.agent_cls = _import_object(cfg.agent_cls) - self.result: TrajectoryResult = None + self.result: TrajectoryResult = None @abstractmethod async def initialize_trajectory(self): @@ -114,19 +94,14 @@ class AgentRunner: def __init__(self, cfg: Dict[str, Any], infer_engine: Any, tokenizer: Any) -> None: """ Initialize the CodeActGenerator with the given configuration. - + Args: generation_config: Configuration dictionary containing parameters like max_prompt_length, max_response_length, etc. """ self.cfg = cfg - + # infer engine - self.infer_engine = build_backend( - cfg.generator.infer_backend, - infer_engine=infer_engine, - tokenizer=tokenizer, - cfg=cfg.generator.backend_config, - ) + self.infer_engine = build_backend(cfg.generator.infer_backend, infer_engine=infer_engine, tokenizer=tokenizer, cfg=cfg.generator.backend_config) self.tokenizer = tokenizer self.traj_cls: Type[BaseTrajectory] = _import_object(AGENT_TRAJECTORY_REGISTRY.get(cfg.agent_cls)) self.task: BaseTask = _import_object(cfg.task)() @@ -136,7 +111,7 @@ def __init__(self, cfg: Dict[str, Any], infer_engine: Any, tokenizer: Any) -> No # Will be set in subclasses self.agent_config = None - + @classmethod def from_task(cls, task: str, infer_engine: Any, tokenizer: Any): # Resolve task name or path @@ -145,39 +120,37 @@ def from_task(cls, task: str, infer_engine: Any, tokenizer: Any): elif task in TASK_CONFIG_REGISTRY: config_path = TASK_CONFIG_REGISTRY[task] else: - raise ValueError( - f"Unknown task '{task}'. Must be a YAML path or one of: {list(TASK_CONFIG_REGISTRY.keys())}" - ) + raise ValueError(f"Unknown task '{task}'. Must be a YAML path or one of: {list(TASK_CONFIG_REGISTRY.keys())}") cfg = OmegaConf.load(config_path) - + return cls(cfg, infer_engine, tokenizer) - + def _get_data(self, content) -> Dict[str, Any]: """Process input data into trajectory input.""" - data_cfg = self.cfg.get("data", {}) + data_cfg = self.cfg.get('data', {}) # Resolve instance payload instance = None - if data_cfg.get("instance_key"): + if data_cfg.get('instance_key'): try: - instance = get_field_from_config(data_cfg.get("instance_key"), content) + instance = get_field_from_config(data_cfg.get('instance_key'), content) except ValueError: instance = None # Resolve instance_id; rely on configured key and downstream default to batch_id if missing instance_id = None - if data_cfg.get("instance_id_key"): + if data_cfg.get('instance_id_key'): try: - instance_id = get_field_from_config(data_cfg.get("instance_id_key"), content) + instance_id = get_field_from_config(data_cfg.get('instance_id_key'), content) except ValueError: instance_id = None # Resolve data_source with default fallback data_source = "default" - if data_cfg.get("data_source_key"): + if data_cfg.get('data_source_key'): try: - data_source = get_field_from_config(data_cfg.get("data_source_key"), content) + data_source = get_field_from_config(data_cfg.get('data_source_key'), content) except ValueError: data_source = "default" @@ -190,20 +163,16 @@ def _get_data(self, content) -> Dict[str, Any]: def _initialize_trajectories(self, val_mode: bool = False): for batch_id, content in enumerate(self.batch): data = self._get_data(content) - instance_id: str = data["instance_id"] if data["instance_id"] else batch_id + instance_id: str = data['instance_id'] if data['instance_id'] else batch_id self.trajectories[instance_id] = {} - sampling_params = ( - self.cfg.generator.val_config.sampling_params if val_mode else self.cfg.generator.sampling_params - ) - sampling_params = OmegaConf.to_container(sampling_params, resolve=True) # e.g. converts ListConfig to list - num_trajectories = ( - self.cfg.generator.val_config.num_trajectories if val_mode else self.cfg.generator.num_trajectories - ) + sampling_params = self.cfg.generator.val_config.sampling_params if val_mode else self.cfg.generator.sampling_params + sampling_params = OmegaConf.to_container(sampling_params, resolve=True) # e.g. converts ListConfig to list + num_trajectories = self.cfg.generator.val_config.num_trajectories if val_mode else self.cfg.generator.num_trajectories # Simple generator-level toggles (defaults) - profile_tools = bool(self.cfg.generator.get("profile_tools", False)) - debug_log = bool(self.cfg.generator.get("debug_log", False)) - enable_turn_reminder = bool(self.cfg.generator.get("enable_turn_reminder", False)) + profile_tools = bool(self.cfg.generator.get('profile_tools', False)) + debug_log = bool(self.cfg.generator.get('debug_log', False)) + enable_turn_reminder = bool(self.cfg.generator.get('enable_turn_reminder', False)) for traj_id in range(num_trajectories): traj_cfg = TrajectoryConfig( @@ -219,8 +188,9 @@ def _initialize_trajectories(self, val_mode: bool = False): agent_cls=self.cfg.agent_cls, profile_tools=profile_tools, debug_log=debug_log, - early_step_threshold=self.cfg.generator.get("early_step_threshold", 0), + early_step_threshold=self.cfg.generator.get('early_step_threshold', 0), enable_turn_reminder=enable_turn_reminder, + generator_cfg=self.cfg.generator if hasattr(self.cfg, 'generator') else None ) traj: BaseTrajectory = self.traj_cls( cfg=traj_cfg, @@ -228,14 +198,14 @@ def _initialize_trajectories(self, val_mode: bool = False): tokenizer=self.tokenizer, infer_engine=self.infer_engine, task=self.task, - val_mode=val_mode, + val_mode=val_mode ) self.trajectories[instance_id][traj_id] = traj - + def _post_process_results(self, return_tensors=False, val_mode: bool = False) -> Dict[str, Any]: """ Post-process the results to convert them into the appropriate output format. - + Returns: A dictionary containing the processed results. """ @@ -249,35 +219,30 @@ def _post_process_results(self, return_tensors=False, val_mode: bool = False) -> has_finish_action_list = [] finish_reason_list = [] - num_trajectories = ( - self.cfg.generator.val_config.num_trajectories if val_mode else self.cfg.generator.num_trajectories - ) + + num_trajectories = self.cfg.generator.val_config.num_trajectories if val_mode else self.cfg.generator.num_trajectories for instance_id in self.trajectories: for trajectory_id in self.trajectories[instance_id]: - all_results.setdefault(instance_id, {})[trajectory_id] = self.trajectories[instance_id][ - trajectory_id - ].result + all_results.setdefault(instance_id, {})[trajectory_id] = self.trajectories[instance_id][trajectory_id].result for batch_idx, content in enumerate(self.batch): data = self._get_data(content) - instance = pd.Series(data["instance"]) - instance_id = data["instance_id"] if data["instance_id"] else batch_idx - instance["instance_id"] = instance_id # safe mutation + instance = pd.Series(data['instance']) + instance_id = data['instance_id'] if data['instance_id'] else batch_idx + instance['instance_id'] = instance_id # safe mutation trajectories = all_results.get(instance_id, {}) matched_results.extend(trajectories.values()) instance_list.extend([instance] * len(trajectories)) - - assert len(matched_results) == num_trajectories * len( - self.batch - ), f"Expected number of results {num_trajectories * len(self.batch)}, got {len(matched_results)}" - + + assert len(matched_results) == num_trajectories * len(self.batch), f"Expected number of results {num_trajectories * len(self.batch)}, got {len(matched_results)}" + # Group results by instance_id for message handling results_by_instance = {} for i, (instance, result) in enumerate(zip(instance_list, matched_results)): - instance_id = instance["instance_id"] + instance_id = instance['instance_id'] results_by_instance.setdefault(instance_id, []).append((i, result)) - + global_fallback_set = None for results in results_by_instance.values(): if all(res.get("messages") for _, res in results): @@ -285,26 +250,27 @@ def _post_process_results(self, return_tensors=False, val_mode: bool = False) -> break # get reward before handling empty messages for idx, result in enumerate(matched_results): - reward = result.get("reward", False) + reward = result.get('reward', False) raw_reward_list.append(reward) raw_reward = sum(raw_reward_list) / len(raw_reward_list) num_empty_messages = sum(1 for res in matched_results if not res.get("messages", [])) # Handle empty messages by copying from another trajectory of the same instance for instance_id, results in results_by_instance.items(): # Look for a non-empty base result - fallback = next((res for _, res in results if res.get("messages")), None) + fallback = next( + (res for _, res in results if res.get("messages")), + None + ) if not fallback: if global_fallback_set: - logger.warning( - f"[WARN] No local fallback for instance_id {instance_id}, using global fallback set." - ) + logger.warning(f"[WARN] No local fallback for instance_id {instance_id}, using global fallback set.") for j, (idx, res) in enumerate(results): # Use corresponding global fallback result (same trajectory index) fallback_res = global_fallback_set[j % len(global_fallback_set)] print(f"Empty messages for instance_id {instance_id}, trajectory {idx}. Using global fallback.") for key, value in fallback_res.items(): matched_results[idx][key] = copy.deepcopy(value) - matched_results[idx]["finish_reason"] = "error_runtime" + matched_results[idx]['finish_reason'] = "error_runtime" else: logger.error(f"[FATAL] No fallback (local/global) for instance_id {instance_id}. Skipping.") @@ -315,26 +281,19 @@ def _post_process_results(self, return_tensors=False, val_mode: bool = False) -> print(f"Empty messages for instance_id {instance_id}, trajectory {idx}. Using local fallback.") for key, value in fallback.items(): matched_results[idx][key] = copy.deepcopy(value) - matched_results[idx]["finish_reason"] = "error_runtime" + matched_results[idx]['finish_reason'] = "error_runtime" + # error evaluation mainly due to timeout during tool execution - mask_out_reason = [ - "CONTEXT_WINDOW_EXCEEDED", - "error_runtime", - "error_evaluation", - "max_iterations_reached", - "BAD_LLM_RESPONSE", - "stuck_in_a_loop", - "cmd_timeout", - ] + mask_out_reason = ["CONTEXT_WINDOW_EXCEEDED", "error_runtime", "error_evaluation", "max_iterations_reached", "BAD_LLM_RESPONSE", "stuck_in_a_loop", "cmd_timeout"] # Get training data - + # backward compatibility for old format # TODO(csy): remove this after oh_agent is updated all_messages = [] all_prompts = [] all_responses = [] - + # step-level prompt_input_ids = [] response_ids = [] @@ -349,60 +308,50 @@ def _post_process_results(self, return_tensors=False, val_mode: bool = False) -> num_turns = [] # assistant-based turns for result in matched_results: current_traj_id = f"{result.get('instance_id')}-traj{result.get('trajectory_id')}" - messages = result.get("messages", []) + messages = result.get('messages', []) # Count assistant messages as turns to match actual steps num_turns.append(sum(1 for m in messages if m.get("role") == "assistant")) # trajectory-level results - error_list.append(result.get("error", None)) - resolved_list.append(result.get("reward", False)) - traj_reward_list.append(result.get("reward", False)) - has_finish_action_list.append(result.get("finish", False)) - finish_reason_list.append(result.get("finish_reason", None)) + error_list.append(result.get('error', None)) + resolved_list.append(result.get('reward', False)) + traj_reward_list.append(result.get('reward', False)) + has_finish_action_list.append(result.get('finish', False)) + finish_reason_list.append(result.get('finish_reason', None)) - transitions = result.get("transitions", []) + transitions = result.get('transitions', []) # backward compatibility for old format # TODO(csy): remove this after oh_agent is updated if not transitions: - logger.info( - f"No transitions found for instance_id {instance_id}, trajectory_id {trajectory_id}. Using messages instead." - ) + logger.info(f"No transitions found for instance_id {instance_id}, trajectory_id {trajectory_id}. Using messages instead.") all_messages.append(messages) starting_index = 0 for i, msg in enumerate(messages): - if msg["role"] == "assistant": + if msg["role"] == 'assistant': starting_index = i break if starting_index == 0: # If we don't find an assistant, all messages are prompts and there are no responses - print( - f'ERROR: Found no assistant message. len(messages) == {len(messages)} and roles are {[msg["role"] for msg in messages]}' - ) + print(f'ERROR: Found no assistant message. len(messages) == {len(messages)} and roles are {[msg["role"] for msg in messages]}') starting_index = len(messages) prompt = messages[:starting_index] all_prompts.append(prompt) response = messages[starting_index:] all_responses.append(response) # filter bad trajectories - if messages and messages[-1]["role"] == "assistant": - finish_reason = result.get("finish_reason", None) + if messages and messages[-1]['role'] == 'assistant': + finish_reason = result.get('finish_reason', None) if finish_reason not in mask_out_reason: - if not ( - "" in messages[-1]["content"] and "" in messages[-1]["content"] - ): - print( - f"[WARN] Last message does not contain finish function call. Marking finish_reason {finish_reason} as BAD_LLM_RESPONSE. Content: {messages[-1]['content']}" - ) - result["finish_reason"] = "BAD_LLM_RESPONSE" - - if messages and messages[-1]["role"] == "user": - finish_reason = result.get("finish_reason", None) + if not ('' in messages[-1]['content'] and '' in messages[-1]['content']): + print(f"[WARN] Last message does not contain finish function call. Marking finish_reason {finish_reason} as BAD_LLM_RESPONSE. Content: {messages[-1]['content']}") + result['finish_reason'] = "BAD_LLM_RESPONSE" + + if messages and messages[-1]['role'] == 'user': + finish_reason = result.get('finish_reason', None) if finish_reason not in mask_out_reason: - print( - f"[WARN] Last message is from user but it's not in mask_out_reason. Marking finish_reason {finish_reason} as error_runtime. Content: {messages[-1]['content']}" - ) - result["finish_reason"] = "error_runtime" + print(f"[WARN] Last message is from user but it's not in mask_out_reason. Marking finish_reason {finish_reason} as error_runtime. Content: {messages[-1]['content']}") + result['finish_reason'] = "error_runtime" continue - + # step-level results data_list = transitions_to_training_data(transitions) for data in data_list: @@ -413,22 +362,23 @@ def _post_process_results(self, return_tensors=False, val_mode: bool = False) -> is_last_episode_list.append(False) is_last_episode_list[-1] = True steps_per_trajectory.append(len(data_list)) - reward_list.extend([result.get("reward", False)] * len(data_list)) - step_finish_reason_list.extend([result.get("finish_reason", None)] * len(data_list)) + reward_list.extend([result.get('reward', False)] * len(data_list)) + step_finish_reason_list.extend([result.get('finish_reason', None)] * len(data_list)) traj_idx_list.extend([current_traj_id] * len(data_list)) - + + # backward compatibility for old format # TODO(csy): remove this after oh_agent is updated if all_messages: # Encode messages, get assitant mask and position ids prompt_encodings = self.tokenizer.apply_chat_template( - all_prompts, + all_prompts, # return_tensors="pt", add_generation_prompt=True, return_dict=True, # padding=True ) - prompt_input_ids = prompt_encodings["input_ids"] + prompt_input_ids = prompt_encodings['input_ids'] response_encodings = self.tokenizer.apply_chat_template( all_responses, @@ -437,9 +387,9 @@ def _post_process_results(self, return_tensors=False, val_mode: bool = False) -> add_generation_prompt=False, return_dict=True, ) - - response_ids = response_encodings["input_ids"] - response_assistant_mask = response_encodings["assistant_masks"] + + response_ids = response_encodings['input_ids'] + response_assistant_mask = response_encodings['assistant_masks'] # to be compatible with new format logprobs = [None] * len(response_ids) step_finish_reason_list = finish_reason_list @@ -450,9 +400,7 @@ def _post_process_results(self, return_tensors=False, val_mode: bool = False) -> truncated_masks = [] truncated_logprobs = [] - for idx, (ids, mask, logprob, reason) in enumerate( - zip(response_ids, response_assistant_mask, logprobs, step_finish_reason_list) - ): + for idx, (ids, mask, logprob, reason) in enumerate(zip(response_ids, response_assistant_mask, logprobs, step_finish_reason_list)): # Check if truncation is needed first_nonzero = mask.index(1) if 1 in mask else len(mask) ids = ids[first_nonzero:] @@ -473,7 +421,7 @@ def _post_process_results(self, return_tensors=False, val_mode: bool = False) -> mask = mask[:max_response_length] if logprob is not None: logprob = logprob[:max_response_length] - + truncated_ids.append(ids) truncated_masks.append(mask) truncated_logprobs.append(logprob) @@ -481,23 +429,27 @@ def _post_process_results(self, return_tensors=False, val_mode: bool = False) -> response_ids = truncated_ids response_assistant_mask = truncated_masks logprobs = truncated_logprobs - + + + loss_mask = [ - [0] * len(mask) if (reason in mask_out_reason) else mask + [0] * len(mask) if ( + reason in mask_out_reason + ) else mask for mask, reason in zip(response_assistant_mask, step_finish_reason_list) - ] + ] rollout_metrics = {} # Compute assistant-based turn average and record metric avg_turn_assistant = (sum(num_turns) / len(num_turns)) if len(num_turns) > 0 else 0.0 - rollout_metrics["rollout_metrics/avg_turn_assistant"] = avg_turn_assistant + rollout_metrics['rollout_metrics/avg_turn_assistant'] = avg_turn_assistant # Note: no backward-compat key kept (removed per request) total_per_instance = defaultdict(int) resolved_per_instance = defaultdict(int) for instance, reward in zip(instance_list, resolved_list): - instance_id = instance["instance_id"] + instance_id = instance['instance_id'] total_per_instance[instance_id] += 1 if reward > 0: resolved_per_instance[instance_id] += 1 @@ -515,37 +467,21 @@ def _post_process_results(self, return_tensors=False, val_mode: bool = False) -> num_resolved_0 += 1 elif resolved == total: num_resolved_1 += 1 - - rollout_metrics["rollout_metrics/num_all_resolved"] = num_resolved_1 - rollout_metrics["rollout_metrics/num_none_resolved"] = num_resolved_0 - rollout_metrics["rollout_metrics/finish_tool_ratio"] = sum( - 1 for reason in finish_reason_list if reason == "FINISH_TOOL" - ) / len(finish_reason_list) - rollout_metrics["rollout_metrics/context_exceed_ratio"] = sum( - 1 for reason in finish_reason_list if reason == "CONTEXT_WINDOW_EXCEEDED" - ) / len(finish_reason_list) + + rollout_metrics['rollout_metrics/num_all_resolved'] = num_resolved_1 + rollout_metrics['rollout_metrics/num_none_resolved'] = num_resolved_0 + rollout_metrics['rollout_metrics/finish_tool_ratio'] = sum(1 for reason in finish_reason_list if reason == "FINISH_TOOL") / len(finish_reason_list) + rollout_metrics['rollout_metrics/context_exceed_ratio'] = sum(1 for reason in finish_reason_list if reason == "CONTEXT_WINDOW_EXCEEDED") / len(finish_reason_list) # Ratio of trajectories stopped by iteration cap; avoid 'max' in key to prevent max-reduction - rollout_metrics["rollout_metrics/iter_cap_ratio"] = sum( - 1 for reason in finish_reason_list if reason == "max_iterations_reached" - ) / len(finish_reason_list) - rollout_metrics["rollout_metrics/stuck_in_a_loop_ratio"] = sum( - 1 for reason in finish_reason_list if reason == "stuck_in_a_loop" - ) / len(finish_reason_list) - rollout_metrics["rollout_metrics/error_runtime"] = sum( - 1 for reason in finish_reason_list if reason == "error_runtime" - ) / len(finish_reason_list) - rollout_metrics["rollout_metrics/error_evaluation"] = sum( - 1 for reason in finish_reason_list if reason == "error_evaluation" - ) / len(finish_reason_list) - rollout_metrics["rollout_metrics/num_mask_out"] = sum(1 for mask in loss_mask if all(m == 0 for m in mask)) - rollout_metrics["rollout_metrics/num_mask_non_zero_reward"] = sum( - 1 for mask, reward in zip(loss_mask, resolved_list) if all(m == 0 for m in mask) and reward > 0 - ) - rollout_metrics["rollout_metrics/bad_llm_response"] = sum( - 1 for reason in finish_reason_list if reason == "BAD_LLM_RESPONSE" - ) / len(finish_reason_list) - rollout_metrics["rollout_metrics/raw_reward"] = raw_reward - rollout_metrics["rollout_metrics/num_empty_messages"] = num_empty_messages + rollout_metrics['rollout_metrics/iter_cap_ratio'] = sum(1 for reason in finish_reason_list if reason == "max_iterations_reached") / len(finish_reason_list) + rollout_metrics['rollout_metrics/stuck_in_a_loop_ratio'] = sum(1 for reason in finish_reason_list if reason == "stuck_in_a_loop") / len(finish_reason_list) + rollout_metrics['rollout_metrics/error_runtime'] = sum(1 for reason in finish_reason_list if reason == "error_runtime") / len(finish_reason_list) + rollout_metrics['rollout_metrics/error_evaluation'] = sum(1 for reason in finish_reason_list if reason == "error_evaluation") / len(finish_reason_list) + rollout_metrics['rollout_metrics/num_mask_out'] = sum(1 for mask in loss_mask if all(m == 0 for m in mask)) + rollout_metrics['rollout_metrics/num_mask_non_zero_reward'] = sum(1 for mask, reward in zip(loss_mask, resolved_list) if all(m == 0 for m in mask) and reward > 0) + rollout_metrics['rollout_metrics/bad_llm_response'] = sum(1 for reason in finish_reason_list if reason == "BAD_LLM_RESPONSE") / len(finish_reason_list) + rollout_metrics['rollout_metrics/raw_reward'] = raw_reward + rollout_metrics['rollout_metrics/num_empty_messages'] = num_empty_messages # Optional aggregation of tool-call profiling if available try: @@ -563,11 +499,11 @@ def _post_process_results(self, return_tensors=False, val_mode: bool = False) -> tool_name_totals = defaultdict(int) if profile_enabled: for res in matched_results: - state = res.get("state") or {} - prof = state.get("tool_profile") if isinstance(state, dict) else None + state = res.get('state') or {} + prof = state.get('tool_profile') if isinstance(state, dict) else None if prof and isinstance(prof, dict): - total = prof.get("tool_calls_total") - by_name = prof.get("tool_calls_by_name") or {} + total = prof.get('tool_calls_total') + by_name = prof.get('tool_calls_by_name') or {} if isinstance(total, int): tool_calls_per_traj.append(total) if isinstance(by_name, dict): @@ -579,7 +515,7 @@ def _post_process_results(self, return_tensors=False, val_mode: bool = False) -> pass # compute no-finish per-traj sum try: - nf_sum = sum(int(v) for k, v in by_name.items() if k != "finish") + nf_sum = sum(int(v) for k, v in by_name.items() if k != 'finish') tool_calls_per_traj_nf.append(nf_sum) except Exception: pass @@ -592,48 +528,49 @@ def emit_distribution(prefix: str, vals: list[int]): mean_val = sum(vals) / n mnv = s[0] mxv = s[-1] - rollout_metrics[f"{prefix}_total"] = int(sum(vals)) - rollout_metrics[f"{prefix}_per_traj_mean"] = float(mean_val) - rollout_metrics[f"{prefix}_per_traj_min"] = float(mnv) - rollout_metrics[f"{prefix}_per_traj_max"] = float(mxv) + rollout_metrics[f'{prefix}_total'] = int(sum(vals)) + rollout_metrics[f'{prefix}_per_traj_mean'] = float(mean_val) + rollout_metrics[f'{prefix}_per_traj_min'] = float(mnv) + rollout_metrics[f'{prefix}_per_traj_max'] = float(mxv) # Emit distributions for overall and no-finish variants - emit_distribution("rollout_metrics/tool_calls", tool_calls_per_traj) - emit_distribution("rollout_metrics/tool_calls_no_finish", tool_calls_per_traj_nf) + emit_distribution('rollout_metrics/tool_calls', tool_calls_per_traj) + emit_distribution('rollout_metrics/tool_calls_no_finish', tool_calls_per_traj_nf) for name, cnt in tool_name_totals.items(): - rollout_metrics[f"rollout_metrics/tool_name/{name}"] = int(cnt) + rollout_metrics[f'rollout_metrics/tool_name/{name}'] = int(cnt) except Exception: pass print("rollout metrics:", rollout_metrics) + print(f"Finish reason: {finish_reason_list}") # Create tensor dictionary output = { - "prompt_token_ids": prompt_input_ids, - "response_ids": response_ids, - "rewards": reward_list, - "traj_rewards": traj_reward_list, - "loss_masks": loss_mask, - "episode_nums": steps_per_trajectory, - "is_last_episode": is_last_episode_list, - "traj_idx": traj_idx_list, - "stop_reasons": None, - "rollout_logprobs": logprobs, - "rollout_metrics": rollout_metrics, + 'prompt_token_ids': prompt_input_ids, + 'response_ids': response_ids, + 'rewards': reward_list, + 'traj_rewards': traj_reward_list, + 'loss_masks': loss_mask, + 'episode_nums': steps_per_trajectory, + 'is_last_episode': is_last_episode_list, + 'traj_idx': traj_idx_list, + 'stop_reasons': None, + 'rollout_logprobs': logprobs, + 'rollout_metrics': rollout_metrics, } - + return output async def run(self, input_batch: Any, val_mode: bool = False) -> Any: """ Generate trajectories for the given prompts using the configured agents. - + Args: prompts: A dictionary containing training instances. val_mode: Whether we're running validation. - + Returns: Results converted to the appropriate output format based on infer backend. """ @@ -643,9 +580,9 @@ async def run(self, input_batch: Any, val_mode: bool = False) -> Any: num_trajectories = self.cfg.generator.val_config.num_trajectories sampling_params = self.cfg.generator.val_config.sampling_params else: - sampling_params = self.cfg.generator.sampling_params + sampling_params = self.cfg.generator.sampling_params num_trajectories = self.cfg.generator.num_trajectories - + # Initialize agents and other components self._initialize_trajectories(val_mode=val_mode) @@ -658,18 +595,15 @@ async def run(self, input_batch: Any, val_mode: bool = False) -> Any: run_fn = "generate_trajectory" eval_fn = "evaluate_trajectory" if val_mode: - max_parallel_agents = self.cfg.dispatcher.get("val_config", {}).get( - "max_parallel_agents", self.cfg.dispatcher.max_parallel_agents - ) - max_eval_parallel_agents = self.cfg.dispatcher.get("val_config", {}).get( - "max_eval_parallel_agents", self.cfg.dispatcher.max_eval_parallel_agents - ) + max_parallel_agents = self.cfg.dispatcher.get('val_config', {}).get('max_parallel_agents', self.cfg.dispatcher.max_parallel_agents) + max_eval_parallel_agents = self.cfg.dispatcher.get('val_config', {}).get('max_eval_parallel_agents', self.cfg.dispatcher.max_eval_parallel_agents) dispatcher_cfg = { "sampling_params": sampling_params, "max_parallel_agents": max_parallel_agents, "max_eval_parallel_agents": max_eval_parallel_agents, "num_instances": len(self.batch), "num_trajectories": num_trajectories, + "generator": self.cfg.generator } else: dispatcher_cfg = { @@ -678,14 +612,17 @@ async def run(self, input_batch: Any, val_mode: bool = False) -> Any: "max_eval_parallel_agents": self.cfg.dispatcher.max_eval_parallel_agents, "num_instances": len(self.batch), "num_trajectories": num_trajectories, + "generator": self.cfg.generator } - await generator_dispatcher( - dispatcher_cfg, self.trajectories, init_fn=init_fn, run_fn=run_fn, eval_fn=eval_fn - ) - + shared_envs = await self.task.initialize_shared_env(self.cfg) + dispatcher_cfg['envs'] = shared_envs + await generator_dispatcher(dispatcher_cfg, self.trajectories, init_fn=init_fn, run_fn=run_fn, eval_fn=eval_fn) + + await self.task.close_shared_env(shared_envs) + output = self._post_process_results(val_mode=val_mode) # reset after run self.trajectories = {} - + return build_generator_output(self.cfg.generator.infer_backend, output).result diff --git a/skyrl-agent/skyrl_agent/agents/react/react_agent.py b/skyrl-agent/skyrl_agent/agents/react/react_agent.py index 61cd2f5a8..f30da76e4 100644 --- a/skyrl-agent/skyrl_agent/agents/react/react_agent.py +++ b/skyrl-agent/skyrl_agent/agents/react/react_agent.py @@ -1,17 +1,18 @@ import json import copy +import os from typing import Any, List, Dict from collections import defaultdict from uuid import uuid4 import traceback from skyrl_agent.functional.utils import ( - Transition, - record_transition, - StepResult, - StepException, - ContextWindowExceeded, - ParseError, - NoToolCall, + Transition, + record_transition, + StepResult, + StepException, + ContextWindowExceeded, + ParseError, + NoToolCall, ToolExecutionFailed, ) @@ -43,6 +44,7 @@ def __init__( infer_engine: AsyncInferBackend, tokenizer: Any, ) -> None: + self.traj_config = traj_config self.tokenizer = tokenizer self.infer_engine = infer_engine self.sampling_params = traj_config.sampling_params @@ -65,14 +67,17 @@ def __init__( self.agent_id = uuid4().hex self._register_tools(traj_config.tools) - + # Message encoder - self.message_encoder = MessageEncoder( - tokenizer, qwen3_enable_thinking=self.qwen3_enable_thinking, qwen3_acc_thinking=self.qwen3_acc_thinking - ) + self.message_encoder = MessageEncoder(tokenizer, + qwen3_enable_thinking=self.qwen3_enable_thinking, + qwen3_acc_thinking=self.qwen3_acc_thinking) self.prompt_token_len = 0 self.response_token_len = 0 + + self.runtime = None + self.history_length = getattr(traj_config.generator_cfg, "history_length", -1) # -1 means no history length limit # Debug and profiling flags/counters self._debug = bool(traj_config.debug_log) @@ -99,7 +104,7 @@ def _register_tools(self, tools: List[str]) -> None: @record_transition async def _generate_with_recording(self, input_ids, sampling_params, request_id): """LLM generation wrapper that records transitions. - + This method is decorated to automatically capture: - input_ids: tokens fed to the LLM - output_tokens: tokens generated by the LLM @@ -113,20 +118,20 @@ async def _generate_with_recording(self, input_ids, sampling_params, request_id) def _prepare_llm_input(self) -> tuple[List[int], Dict]: """Prepare input_ids and sampling params for LLM using incremental encoding. - + When history is reset, retokenizes everything. Otherwise, performs incremental encoding by appending new messages to existing tokens. - + Returns: Tuple of (input_ids, sampling_params) """ - + # Check if history was reset - store flag before clearing it history_was_reset = self.history.was_reset() if history_was_reset: self.prompt_token_len = 0 self.history.clear_reset_flag() - + # Add turn reminder to history (optional via config) if self.enable_turn_reminder: remaining_steps = self.max_iterations - self.step_count + 1 @@ -136,16 +141,15 @@ def _prepare_llm_input(self) -> tuple[List[int], Dict]: early_step_threshold=self.early_step_threshold, ) self.history.add_turn_reminder(reminder_text) - + # Determine if we should retokenize everything or use incremental encoding is_prompt = len(self.history) == 2 # system + user (with reminder appended) should_retokenize = is_prompt or history_was_reset or not self.transitions - + if should_retokenize: # Retokenize everything (first message or after history reset) input_ids = self.message_encoder.encode_messages( - self.history.messages, - self.tool_params, + self.history.messages, self.tool_params, is_first_message=True, ) self.prompt_token_len = len(input_ids) @@ -165,42 +169,43 @@ def _prepare_llm_input(self) -> tuple[List[int], Dict]: last_transition.ac.token_ids = self.message_encoder.encode_messages( message, self.tool_params, add_generation=False ) - + # Encode only the new observation message(s) # Simple default behavior: assuming one env observation message per step # Can be overridden by subclasses new_obs_ids = self.message_encoder.encode_messages( [self.history.messages[-1]], self.tool_params, add_generation=True ) - + # Build input_ids incrementally: previous observation + previous action + new observation input_ids = last_transition.ob.input_ids + last_transition.ac.token_ids + new_obs_ids - + self.response_token_len = len(input_ids) - self.prompt_token_len - + # Prepare sampling params sampling_params = copy.deepcopy(self.sampling_params) - sampling_params["max_tokens"] = self.max_prompt_length - self.response_token_len - + sampling_params['max_tokens'] = self.max_prompt_length - self.response_token_len + return input_ids, sampling_params - + + def _prepare_llm_input_deprecated(self) -> tuple[List[int], Dict]: """[DEPRECATED] Prepare input_ids and sampling params for LLM. - + This is an old version that performs retokenization at every turn. This method is deprecated and will be removed in a future version. Use `_prepare_llm_input()` instead. - + Returns: Tuple of (input_ids, sampling_params) """ - + # Check if history was reset - store flag before clearing it history_was_reset = self.history.was_reset() if history_was_reset: self.prompt_token_len = 0 self.history.clear_reset_flag() - + # Track token lengths # Encode messages to input_ids if self.enable_turn_reminder: @@ -211,63 +216,63 @@ def _prepare_llm_input_deprecated(self) -> tuple[List[int], Dict]: early_step_threshold=self.early_step_threshold, ) self.history.add_turn_reminder(reminder_text) - + input_ids = self.message_encoder.encode_messages( - self.history.messages, - self.tool_params, + self.history.messages, self.tool_params, is_first_message=True, ) # Set prompt_token_len on first message (initial setup) or after history reset is_prompt = len(self.history) == 2 # system + user (possibly with reminder appended) if is_prompt or history_was_reset: self.prompt_token_len = len(input_ids) - + self.response_token_len = len(input_ids) - self.prompt_token_len - + # Prepare sampling params sampling_params = copy.deepcopy(self.sampling_params) - sampling_params["max_tokens"] = self.max_prompt_length - self.response_token_len - + sampling_params['max_tokens'] = self.max_prompt_length - self.response_token_len + return input_ids, sampling_params def _handle_parse_error(self, error: str) -> None: """Handle tool call parsing error and raise ParseError.""" print(f"[Agent Step Error] Converter failed to parse tool call: {error}") guidance = TOOL_CALL_PARSE_ERROR_GUIDANCE.format(error=error) - + self.history.add_tool_error(error) self.history.add_user_guidance(guidance) - + raise ParseError() def _handle_no_tool_call(self, response_str: str) -> None: """Handle case when no tool call is detected and raise NoToolCall.""" print(f"[Agent Step {self.step_count}] No tool call found in response") - + # Check if response was likely truncated during a tool call if check_truncated_tool_call(response_str): - print("[ERROR] Tool call appears incomplete - likely truncated!") + print(f"[ERROR] Tool call appears incomplete - likely truncated!") print(f"[ERROR] Last 500 chars: {response_str[-500:]}") - + self.history.add_user_guidance(NO_TOOL_CALL_DETECTED_GUIDANCE) raise NoToolCall() async def _execute_tool(self, tool_name: str, tool_args: Dict, tool_call_id: str) -> Any: """Execute a tool and return output. - + Raises: ToolExecutionFailed: If tool execution fails """ tool = self.tools[tool_name] - + try: output = await call_sync_from_async( tool.call, tool_args, agent=self, trajectory_id=self.trajectory_id, + runtime = self.runtime, ) - + # Record profiling stats if enabled if self._profile_enabled: try: @@ -276,9 +281,9 @@ async def _execute_tool(self, tool_name: str, tool_args: Dict, tool_call_id: str self._tool_calls_by_name[tool_name] += 1 except Exception: pass - + return output - + except Exception as e: # Tool invocation failed error_str = str(e) @@ -286,31 +291,31 @@ async def _execute_tool(self, tool_name: str, tool_args: Dict, tool_call_id: str self.history.add_tool_error(error_str, tool_call_id) except Exception: self.history.add_tool_error("Tool failed with an exception.", tool_call_id) - + self.history.add_user_guidance(TOOL_INVOCATION_ERROR_GUIDANCE) raise ToolExecutionFailed() def _append_tool_output(self, output: Any, tool_call_id: str) -> None: """Append tool output to message history. - + Args: output: Tool output to append tool_call_id: ID of the tool call """ try: self.history.add_tool_response(output, tool_call_id) - + if self._debug: preview = format_output_preview(output) print(f"[Tool Output Preview] {preview}") - + except Exception as e: print(f"[Agent Step Error] Error appending tool output to messages: {str(e)}") self.history.add_tool_error(str(e), tool_call_id) async def step(self): """Execute one agent step: LLM generation -> tool call -> tool execution. - + Returns: Tuple of (done, finish_reason, result) """ @@ -318,14 +323,14 @@ async def step(self): print(f"[Agent Step {self.step_count}] instance={self.instance_id} traj={self.trajectory_id}") result = None - + try: # 1. Prepare LLM input input_ids, sampling_params = self._prepare_llm_input() - + # Check context window if self.response_token_len >= self.max_prompt_length: - print("[Agent Step] Stopping reason: context_window_exceeded. Stopping agent.") + print(f"[Agent Step] Stopping reason: context_window_exceeded. Stopping agent.") raise ContextWindowExceeded() # 2. Generate LLM response @@ -336,10 +341,10 @@ async def step(self): ) stop_reason = meta_info["finish_reason"] print(f"[Agent Step {self.step_count}] LLM response: {response_str}. Stop reason: {stop_reason}") - + # Add assistant message to history self.history.add_assistant(response_str) - + # Check if generation stopped due to length if stop_reason == "length": print(f"[Agent Step] Stopping reason: {stop_reason}. Stopping agent.") @@ -347,7 +352,7 @@ async def step(self): # 3. Parse tool call from response tool_call, parse_error = parse_tool_call(response_str, self.tool_params) - + # Handle parse error if parse_error: self._handle_parse_error(parse_error) @@ -356,24 +361,26 @@ async def step(self): if not self.tools: print(f"[Agent Step {self.step_count}] No tools provided, returning response.") result = StepResult.finished("FINISH", response_str) - + # Handle no tool call detected elif tool_call is None: self._handle_no_tool_call(response_str) - + else: # 4. Extract tool information tool_name, tool_args = extract_tool_info(tool_call) tool_call_id = tool_call.get("id") - + # Validate tool exists if tool_name not in self.tools: - self.history.add_user_guidance(json.dumps({"error": f"Tool '{tool_name}' not found."})) + self.history.add_user_guidance( + json.dumps({"error": f"Tool '{tool_name}' not found."}) + ) result = StepResult.continuing(response_str) else: # 5. Execute tool output = await self._execute_tool(tool_name, tool_args, tool_call_id) - + # 6. Check if finish tool was called if tool_name == "finish": print(f"[Agent Step {self.step_count}] Finish tool called. Stopping agent.") @@ -381,7 +388,7 @@ async def step(self): else: # Continue agent loop result = StepResult.continuing(response_str) - + # 7. Append tool output to history only if output is not None # Some tools (like next_with_summary) embed feedback in user message # and return None to skip adding tool output @@ -390,29 +397,31 @@ async def step(self): self._append_tool_output(output, tool_call_id) else: print(f"[Tool Output step {self.step_count}] No output (feedback embedded in user message)") - + except StepException as e: # Handle expected control flow exceptions result = e.step_result - + except Exception as e: # Handle unexpected errors print(f"[Agent Step Error] Error during step: {str(e)}") result = StepResult.finished(f"error: {str(e)}", None) - + # Single exit point return result.to_tuple() - + async def run(self, instruction: List[Dict], instance: Dict | None = None) -> List[str]: """Run the agent till the end with the provided user input. Optionally accepts an instance payload for tools (stored on self.instance). """ self.instance = instance + self.runtime = instance['runtime'] if 'runtime' in instance else None self._init_message(instruction) result = None finish_reason = None while self.step_count < self.max_iterations: try: + self.task.get_task_dependent_context_management(self, self.traj_config) done, finish_reason, result = await self.step() if done: break @@ -422,17 +431,19 @@ async def run(self, instruction: List[Dict], instance: Dict | None = None) -> Li # traceback print(traceback.format_exc()) break - else: # If we exit the loop without hitting a break, it means we reached max iterations + else: # If we exit the loop without hitting a break, it means we reached max iterations finish_reason = "max_iterations_reached" - + return finish_reason, result def get_messages(self) -> List[dict]: - return convert_fncall_messages_to_non_fncall_messages(self.history.messages, self.tool_params) + return convert_fncall_messages_to_non_fncall_messages( + self.history.messages, self.tool_params + ) def get_transitions(self) -> List[Transition]: """Return the list of transitions recorded during agent execution. - + Each transition contains: - ob: Observation with input_ids (tokens fed to LLM) - ac: TokensWithLogprobs with output_tokens, logprobs, and generated text @@ -444,30 +455,30 @@ def get_transitions(self) -> List[Transition]: def _init_message(self, instruction: List[Dict]) -> None: """Initialize the agent's message history with the provided instruction. - + Automatically collects system prompt prefixes from registered tools and prepends them to the system message if present. """ if not isinstance(instruction, list): raise ValueError("Instruction must be a list of messages.") - + for msg in instruction: if not isinstance(msg, dict) or "role" not in msg or "content" not in msg: raise ValueError("Each message must be a dictionary with 'role' and 'content'.") - + # Collect system prompt prefixes from registered tools tool_prefixes = [] for tool_name, tool in self.tools.items(): prefix = tool.get_system_prompt_prefix() if prefix: tool_prefixes.append(prefix) - + # Prepend tool prefixes to system message if any exist if tool_prefixes: processed_instruction = copy.deepcopy(instruction) # Combine all tool prefixes combined_prefix = "\n\n---\n\n".join(tool_prefixes) - + # Find the first system message system_msg_found = False for msg in processed_instruction: @@ -476,11 +487,14 @@ def _init_message(self, instruction: List[Dict]) -> None: msg["content"] = combined_prefix + "\n\n---\n\n" + msg["content"] system_msg_found = True break - + # If no system message exists, create one at the beginning if not system_msg_found: - processed_instruction.insert(0, {"role": "system", "content": combined_prefix}) - + processed_instruction.insert(0, { + "role": "system", + "content": combined_prefix + }) + self.history.initialize(processed_instruction) else: self.history.initialize(instruction) @@ -497,7 +511,6 @@ def get_tool_profile(self) -> Dict[str, Any]: except Exception: return None - if __name__ == "__main__": # Example usage for testing from skyrl_agent.config.configuration_utils import TrajectoryConfig @@ -528,7 +541,7 @@ def get_tool_profile(self) -> Dict[str, Any]: backend_config = OpenAIBackendConfig( model_name=model_name, # change this to your desired url and port - api_url="http://localhost:8000", + api_url="http://localhost:8000" ) # TODO: model_name need not be in config infer_engine = OpenAIBackend(infer_engine=None, cfg=backend_config) @@ -542,13 +555,7 @@ def get_tool_profile(self) -> Dict[str, Any]: ) # Define a sample instruction - instruction = [ - {"content": "Please reason step by step, and put your final answer within \\boxed{}.", "role": "system"}, - { - "content": "Points $A,B,C,D,E$ and $F$ lie, in that order, on $\\overline{AF}$, dividing it into five segments, each of length 1. Point $G$ is not on line $AF$. Point $H$ lies on $\\overline{GD}$, and point $J$ lies on $\\overline{GF}$. The line segments $\\overline{HC}, \\overline{JE},$ and $\\overline{AG}$ are parallel. Find $HC/JE$.", - "role": "user", - }, - ] + instruction = [{'content': 'Please reason step by step, and put your final answer within \\boxed{}.', 'role': 'system'}, {'content': 'Points $A,B,C,D,E$ and $F$ lie, in that order, on $\\overline{AF}$, dividing it into five segments, each of length 1. Point $G$ is not on line $AF$. Point $H$ lies on $\\overline{GD}$, and point $J$ lies on $\\overline{GF}$. The line segments $\\overline{HC}, \\overline{JE},$ and $\\overline{AG}$ are parallel. Find $HC/JE$.', 'role': 'user'}] # Run the agent finish_reason, result = asyncio.run(agent.run(instruction)) diff --git a/skyrl-agent/skyrl_agent/agents/react/react_runner.py b/skyrl-agent/skyrl_agent/agents/react/react_runner.py index d1f8d7489..f03999013 100644 --- a/skyrl-agent/skyrl_agent/agents/react/react_runner.py +++ b/skyrl-agent/skyrl_agent/agents/react/react_runner.py @@ -1,27 +1,63 @@ +import os +import importlib +import json +from tokenize import TokenInfo +from typing import Type +from typing import Dict, Any, List, Optional + import pandas as pd +from omegaconf import OmegaConf +from loguru import logger from skyrl_agent.agents.react.react_agent import ReActAgent +from skyrl_agent.integrations.base import _import_object +from skyrl_agent.dispatcher.async_utils import call_sync_from_async +from skyrl_agent.config.configuration_utils import TrajectoryConfig, get_field_from_config -from skyrl_agent.agents.base import BaseTrajectory +from skyrl_agent.tools.base import BaseTool +from skyrl_agent.agents.base import AgentRunner, BaseTrajectory +from skyrl_agent.tasks.osworld.osworld_task import OSWorldTask class ReActTrajectory(BaseTrajectory): - async def initialize_trajectory(self): + async def initialize_trajectory(self, **kwargs): pass - async def generate_trajectory(self) -> None: + async def generate_trajectory(self, **kwargs) -> None: data = self.data - instance_id = data["instance_id"] if data["instance_id"] else self.cfg.instance_id + instance_id = data['instance_id'] if data['instance_id'] else self.cfg.instance_id instance = pd.Series(data["instance"]) + instance['runtime'] = kwargs['env'] if 'env' in kwargs else None + instance['cfg'] = self.cfg # self.agent = ReActAgent(traj_config=self.cfg, infer_engine=self.infer_engine, tokenizer=self.tokenizer) self.agent: ReActAgent = self.agent_cls( traj_config=self.cfg, infer_engine=self.infer_engine, tokenizer=self.tokenizer, ) + self.agent.task = self.task # sys + user messages - instruction = self.task.get_instruction(instance) - + try: + instruction = await self.task.get_instruction(instance) + except Exception as e: + logger.error(f"Failed to get instruction/initialize runtime for instance {instance_id}: {str(e)}") + instruction = [] + logger.error(f"Failed to initialize runtime for instance {instance_id}: {str(e)}") + instance['runtime'] = None + + return_val = { + 'instance_id': instance_id, + 'trajectory_id': self.cfg.trajectory_id, + 'messages': [], + 'state': None, + 'results': None, + 'error': str(e), + 'finish': False, + 'finish_reason': 'error_initialize_runtime', + } + self.result = return_val + return + finish_reason, result = await self.agent.run(instruction, instance) # Optional tool profile snapshot (env-gated inside agent) tool_profile = None @@ -30,28 +66,29 @@ async def generate_trajectory(self) -> None: except Exception: tool_profile = None self.result = { - "instance_id": instance_id, - "trajectory_id": self.cfg.trajectory_id, - "messages": self.agent.get_messages(), - "transitions": self.agent.get_transitions(), - "results": result, - "finish_reason": finish_reason, - "state": {"tool_profile": tool_profile} if tool_profile is not None else {}, + 'instance_id': instance_id, + 'trajectory_id': self.cfg.trajectory_id, + 'messages': self.agent.get_messages(), + 'transitions': self.agent.get_transitions(), + 'results': result, + 'finish_reason': finish_reason, + 'state': { 'tool_profile': tool_profile } if tool_profile is not None else {}, } - async def evaluate_trajectory(self) -> None: + async def evaluate_trajectory(self, **kwargs) -> None: instance_id = self.cfg.instance_id trajectory_id = self.cfg.trajectory_id data = self.data - instance_id = data["instance_id"] if data["instance_id"] else self.cfg.instance_id + instance_id = data['instance_id'] if data['instance_id'] else self.cfg.instance_id instance = data["instance"] + instance['runtime'] = kwargs['env'] if 'env' in kwargs else None # print(f"[react_runner] instance_id={instance_id} original instance type: {type(instance).__name__}") # if isinstance(instance, dict): # # print(f"[react_runner] instance keys: {list(instance.keys())}") # if not isinstance(instance, (dict, pd.Series)): # # print(f"[react_runner] Converting to Series for instance_id={instance_id}") # instance = pd.Series(instance) - result = self.result.get("results") + result = self.result.get('results') try: eval_result = await self.task.evaluate_result( @@ -61,8 +98,8 @@ async def evaluate_trajectory(self) -> None: instance_id, trajectory_id, ) - self.result["reward"] = eval_result + self.result['reward'] = eval_result except Exception as e: print(f"Error evaluating result: {e}") - self.result["reward"] = 0 - self.result["eval_error"] = str(e) + self.result['reward'] = 0 + self.result['eval_error'] = str(e) diff --git a/skyrl-agent/skyrl_agent/config/configuration_utils.py b/skyrl-agent/skyrl_agent/config/configuration_utils.py index a39206c6f..5d197183e 100644 --- a/skyrl-agent/skyrl_agent/config/configuration_utils.py +++ b/skyrl-agent/skyrl_agent/config/configuration_utils.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Any, Optional - # TODO(csy): a smarter way? def get_field_from_config(key_path, context): parts = key_path.split(".") @@ -13,7 +12,6 @@ def get_field_from_config(key_path, context): raise ValueError(f"Path '{key_path}' not found in context.") return value - @dataclass class TrajectoryConfig: instance_id: int @@ -30,13 +28,13 @@ class TrajectoryConfig: debug_log: bool = False early_step_threshold: int = 0 # Step count threshold for early reminder enable_turn_reminder: bool = False - - + generator_cfg: Optional[Any] = None # DEPR @dataclass class AgentConfig: max_iterations: int = 5 tools: Optional[list] = None - -TASK_CONFIG_REGISTRY = {"swe_bench": "swe_bench.yaml"} +TASK_CONFIG_REGISTRY = { + "swe_bench": "swe_bench.yaml" +} diff --git a/skyrl-agent/skyrl_agent/dispatcher/dispatchers.py b/skyrl-agent/skyrl_agent/dispatcher/dispatchers.py index 580334f92..5d3596775 100644 --- a/skyrl-agent/skyrl_agent/dispatcher/dispatchers.py +++ b/skyrl-agent/skyrl_agent/dispatcher/dispatchers.py @@ -1,27 +1,23 @@ import asyncio +from functools import partial from typing import Callable, Any, Dict from loguru import logger -DoFnType = Callable[[int, int], Any] # batch_idx, trajectory_id +DoFnType = Callable[[int, int], Any] # batch_idx, trajectory_id DispatcherType = Callable[[DoFnType, DoFnType, DoFnType], Any] # Dispatcher Registry DISPATCHER_REGISTRY: Dict[str, DispatcherType] = {} - def register_dispatcher(name): def decorator(fn): DISPATCHER_REGISTRY[name] = fn return fn - return decorator - # Async Pipeline Dispatcher (Producer-Consumer Pipelining) @register_dispatcher("async_pipeline") -async def async_pipeline_dispatcher( - cfg, trajectories: Dict[str, Dict[str, Any]], init_fn: str, run_fn: str, eval_fn: str -): +async def async_pipeline_dispatcher(cfg, trajectories: Dict[str,Dict[str, Any]], init_fn: str, run_fn: str, eval_fn: str): async def pipeline(): """Pipeline dispatcher for async processing of init, run, and eval functions.""" # Initialize queues @@ -37,15 +33,11 @@ async def pipeline(): num_trajectories = cfg["num_trajectories"] total_instances = num_instances - max_eval_parallel_agents = min(total_instances * num_trajectories, max_eval_parallel_agents) - max_parallel_agents = min(total_instances * num_trajectories, max_parallel_agents) + max_eval_parallel_agents = min(total_instances*num_trajectories, max_eval_parallel_agents) + max_parallel_agents = min(total_instances*num_trajectories, max_parallel_agents) - logger.info( - f"Using max_parallel_agents of {max_parallel_agents} for {total_instances} instances with {num_trajectories} trajectories each" - ) - logger.info( - f"Using max_eval_parallel_agents of {max_eval_parallel_agents} for {total_instances} instances with {num_trajectories} trajectories each" - ) + logger.info(f"Using max_parallel_agents of {max_parallel_agents} for {total_instances} instances with {num_trajectories} trajectories each") + logger.info(f"Using max_eval_parallel_agents of {max_eval_parallel_agents} for {total_instances} instances with {num_trajectories} trajectories each") # Fill the init queue with tasks for trajectory_id in range(num_trajectories): @@ -124,13 +116,12 @@ async def one_traj(instance_id, trajectory_id): # Async FixedEnv Pool Dispatcher (Env Pool Reuse) @register_dispatcher("async_fix_pool") -async def async_fix_pool_dispatcher(cfg, init_fn, run_fn, eval_fn): +async def async_fix_pool_dispatcher(cfg, trajectories: Dict[str,Dict[str, Any]], init_fn, run_fn, eval_fn): """ Dispatcher for pre-initialized environments. Each trajectory is assigned to a free env. When finished, the env is returned to the pool. """ - - async def dispatcher(): + async def run_all(): envs = cfg["envs"] # List of pre-initialized environments num_envs = len(envs) num_instances = cfg["num_instances"] @@ -154,14 +145,15 @@ async def worker(): while True: try: batch_idx, trajectory_id = await work_queue.get() + traj = trajectories[batch_idx][trajectory_id] env_id = await env_queue.get() # Reset and assign env - await init_fn(batch_idx, trajectory_id, env_id) + await getattr(traj, init_fn)(env=envs[env_id]) # Run and eval - await run_fn(batch_idx, trajectory_id, env_id) - await eval_fn(batch_idx, trajectory_id, env_id) + await getattr(traj, run_fn)(env=envs[env_id]) + await getattr(traj, eval_fn)(env=envs[env_id]) # Mark trajectory and env as done work_queue.task_done() @@ -181,4 +173,4 @@ async def worker(): for w in workers: w.cancel() - await dispatcher() + await run_all() \ No newline at end of file diff --git a/skyrl-agent/skyrl_agent/integrations/tinker/tinker_train.py b/skyrl-agent/skyrl_agent/integrations/tinker/tinker_train.py index 45a772baf..22b94ba06 100644 --- a/skyrl-agent/skyrl_agent/integrations/tinker/tinker_train.py +++ b/skyrl-agent/skyrl_agent/integrations/tinker/tinker_train.py @@ -14,6 +14,7 @@ import tinker import torch import wandb +from termcolor import colored from tinker import types from tinker.types.tensor_data import TensorData from transformers.models.auto.tokenization_auto import AutoTokenizer @@ -35,7 +36,6 @@ def set_seed(seed: int): torch.cuda.manual_seed_all(seed) logger.info(f"Set random seed to {seed}") - @contextmanager def timed(key: str, metrics: dict[str, Any]): logger.info(f"Starting {key}") @@ -44,10 +44,8 @@ def timed(key: str, metrics: dict[str, Any]): logger.info(f"{key} took {time.time() - tstart:.2f} seconds") metrics[f"time/{key}"] = time.time() - tstart - safezip = cast(type[zip], lambda *args, **kwargs: zip(*args, **kwargs, strict=True)) - def normalize_advantages(advantages: List[float]) -> List[float]: """Normalize advantages to have mean 0 and std 1 (standard normalization).""" if not advantages or len(advantages) == 1: @@ -66,53 +64,54 @@ def compute_advantages_grpo( ) -> List[float]: """ GRPO (Group Relative Policy Optimization) advantage estimation. - + For each group of trajectories from the same prompt, compute advantages - as deviations from the group mean. This is particularly useful for + as deviations from the group mean. This is particularly useful for best-of-N sampling scenarios. - + Reference: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py - + Args: rewards: List of rewards for all trajectories - group_size: Number of trajectories per prompt group. + group_size: Number of trajectories per prompt group. If None, treats all as one group. normalize: Whether to apply additional global normalization - + Returns: List of advantages, same length as rewards """ rewards = np.array(rewards) - + if group_size is None: # Treat all trajectories as one group (equivalent to simple baseline) group_size = len(rewards) - + n_groups = len(rewards) // group_size advantages = [] - + for i in range(n_groups): start_idx = i * group_size end_idx = start_idx + group_size group_rewards = rewards[start_idx:end_idx] - + # GRPO: advantage = reward - mean(group_rewards) group_mean = group_rewards.mean() group_advantages = group_rewards - group_mean advantages.extend(group_advantages.tolist()) - + # Handle remaining trajectories if not evenly divisible remaining = len(rewards) % group_size assert remaining == 0, f"Remaining trajectories: {remaining} is not divisible by group_size: {group_size}" - + # Optional: Apply global normalization for extra stability if normalize: advantages = normalize_advantages(advantages) - + return advantages - -def compute_kl_sample_train(data_D: List[tinker.Datum], training_logprobs_D: List[torch.Tensor]) -> Dict[str, float]: +def compute_kl_sample_train( + data_D: List[tinker.Datum], training_logprobs_D: List[torch.Tensor] +) -> Dict[str, float]: """Compute KL divergence metrics between sampling and training logprobs.""" all_diffs: list[torch.Tensor] = [] all_sampling_logprobs: list[torch.Tensor] = [] @@ -160,12 +159,12 @@ class Config: skyrl_agent_task_yaml: str = None dataset_file: str = None # Path to the training dataset parquet file eval_dataset_file: str = None # Path to the evaluation dataset parquet file - + # Loss function configuration loss_fn: Literal["importance_sampling", "ppo", "custom_ppo"] = "ppo" # Options: # "ppo" or "importance_sampling": Use Tinker's built-in loss (forward_backward) - + # GRPO (Group Relative Policy Optimization) settings group_size: int = 8 # Trajectories per prompt group (None = auto-infer from task yaml) normalize_advantages: bool = True # Apply global normalization after group-relative computation @@ -208,7 +207,7 @@ async def save_checkpoint_async( def collate_fn(batch): """Custom collate function that returns batch as-is without tensor collation. - + This is needed because the agent runner expects to handle the raw batch data through build_generator_input, rather than having PyTorch stack tensors. """ @@ -218,13 +217,14 @@ def collate_fn(batch): async def main(config: Config): # Set random seed for reproducibility set_seed(config.seed) - + # Setup logging if config.resume_exp_name: wandb_name = config.resume_exp_name else: wandb_name = config.wandb_name or config.model_name.split("/")[-1] - wandb_name += "_" + datetime.now().strftime("%m%dT%H:%M:%S") + # Use a filesystem-safe timestamp (avoid ':' which breaks on some mounts) + wandb_name += "_" + datetime.now().strftime("%m%dT%H-%M-%S") save_path = os.path.join("./tinker_output", wandb_name) os.makedirs(save_path, exist_ok=True) @@ -240,7 +240,7 @@ async def main(config: Config): print(f"Resuming training from step {resume_from_step}") else: resume_from_step = 0 - print("Starting training from scratch") + print(f"Starting training from scratch") wandb.init( project=config.wandb_project, @@ -252,29 +252,29 @@ async def main(config: Config): # dataset and dataloader train_dataset = load_dataset("parquet", data_files=config.dataset_file)["train"] eval_dataset = load_dataset("parquet", data_files=config.eval_dataset_file)["train"] - + # Calculate steps per epoch for tracking steps_per_epoch = (len(train_dataset) + config.batch_size - 1) // config.batch_size logger.info(f"Dataset size: {len(train_dataset)}, Steps per epoch: {steps_per_epoch}") - + # Create function to get dataloader for a specific epoch def create_train_dataloader(epoch: int): """Create dataloader with epoch-specific seed for different shuffle orders.""" return DataLoader( - train_dataset, - batch_size=config.batch_size, - shuffle=True, + train_dataset, + batch_size=config.batch_size, + shuffle=True, collate_fn=collate_fn, - generator=torch.Generator().manual_seed(config.seed + epoch), # Different shuffle per epoch + generator=torch.Generator().manual_seed(config.seed + epoch) # Different shuffle per epoch ) - + # Initialize iterator state for resuming current_epoch = resume_from_step // steps_per_epoch batch_offset_in_epoch = resume_from_step % steps_per_epoch - + train_dataloader = create_train_dataloader(current_epoch) train_iterator = iter(train_dataloader) - + # Skip batches within the current epoch if resuming mid-epoch if batch_offset_in_epoch > 0: logger.info(f"Resuming from epoch {current_epoch}, batch {batch_offset_in_epoch}/{steps_per_epoch}") @@ -290,8 +290,10 @@ def create_train_dataloader(epoch: int): future = await training_client.load_state_async(load_state_path) _ = await future.result_async() logger.info(f"Loaded state from {load_state_path}") - - adam_params = types.AdamParams(learning_rate=config.learning_rate, beta1=0.9, beta2=0.95, eps=1e-8) + + adam_params = types.AdamParams( + learning_rate=config.learning_rate, beta1=0.9, beta2=0.95, eps=1e-8 + ) skyrl_agent_task_yaml_path = config.skyrl_agent_task_yaml tokenizer = AutoTokenizer.from_pretrained(config.model_name) @@ -305,7 +307,11 @@ def create_train_dataloader(epoch: int): } # save model - if config.save_every > 0 and policy_iteration_step > 0 and policy_iteration_step % config.save_every == 0: + if ( + config.save_every > 0 + and policy_iteration_step > 0 + and policy_iteration_step % config.save_every == 0 + ): await save_checkpoint_async( training_client, f"{policy_iteration_step:06d}", @@ -314,17 +320,25 @@ def create_train_dataloader(epoch: int): loop_state={"policy_iteration_step": policy_iteration_step}, ) - sampling_path = training_client.save_weights_for_sampler(name=f"{policy_iteration_step:06d}").result().path - sampling_client = service_client.create_sampling_client(model_path=sampling_path) + sampling_path = ( + training_client.save_weights_for_sampler( + name=f"{policy_iteration_step:06d}" + ) + .result() + .path + ) + sampling_client = service_client.create_sampling_client( + model_path=sampling_path + ) agent_generator = AutoAgentRunner.from_task( - skyrl_agent_task_yaml_path, infer_engine=sampling_client, tokenizer=tokenizer + skyrl_agent_task_yaml_path, + infer_engine=sampling_client, + tokenizer=tokenizer ) if policy_iteration_step % config.eval_every == 0: - eval_dataloader = DataLoader( - eval_dataset, batch_size=config.eval_batch_size, shuffle=False, collate_fn=collate_fn - ) + eval_dataloader = DataLoader(eval_dataset, batch_size=config.eval_batch_size, shuffle=False, collate_fn=collate_fn) data_source_rewards = {} for batch in eval_dataloader: input_batch = batch @@ -342,7 +356,7 @@ def create_train_dataloader(epoch: int): # Collect rollouts using AgentRunner print(f"🎲 Start collecting episodes at step {policy_iteration_step}") st = time.time() - + # Get next batch, handling epoch transitions try: input_batch = next(train_iterator) @@ -353,12 +367,12 @@ def create_train_dataloader(epoch: int): train_dataloader = create_train_dataloader(current_epoch) train_iterator = iter(train_dataloader) input_batch = next(train_iterator) - + rollouts = await agent_generator.run(input_batch, val_mode=False) metrics["time/sample"] = time.time() - st # rollout time print(f"Rollout time: {metrics['time/sample']}") - + # Write rollout_metrics to wandb rollout_metrics = rollouts.get("rollout_metrics", {}) wandb.log({f"rollout/{k}": v for k, v in rollout_metrics.items()}, step=policy_iteration_step) @@ -373,36 +387,38 @@ def create_train_dataloader(epoch: int): actual_batch_size = len(response_ids) logger.info(f"Processing {actual_batch_size} rollouts for training") - + # Compute advantages using GRPO (Group Relative Policy Optimization) all_returns = [float(r) for r in traj_rewards_list] - + # Determine group size for GRPO group_size = config.group_size if group_size is None: # Try to infer from task config from omegaconf import OmegaConf - task_config = OmegaConf.load(skyrl_agent_task_yaml_path) group_size = task_config.generator.get("num_trajectories", 1) logger.info(f"Auto-inferred group_size={group_size} from task config") - + # Compute GRPO advantages logger.info(f"Computing GRPO advantages: group_size={group_size}, normalize={config.normalize_advantages}") all_advantages = compute_advantages_grpo( - all_returns, group_size=group_size, normalize=config.normalize_advantages + all_returns, + group_size=group_size, + normalize=config.normalize_advantages ) # broadcast advantages to num_steps per trajectory step_advantages = [] for idx, num_steps in enumerate(num_steps_per_trajectory): step_advantages.extend([all_advantages[idx]] * num_steps) - + + metrics["reward/mean"] = np.mean(all_returns) metrics["reward/max"] = np.max(all_returns) metrics["reward/min"] = np.min(all_returns) metrics["advantage/mean"] = np.mean(all_advantages) metrics["advantage/std"] = np.std(all_advantages) - + # Prepare training datums compatible with Tinker API # For each trajectory, we need to provide: # - model_input: the full sequence (prompt + response) @@ -412,14 +428,15 @@ def create_train_dataloader(epoch: int): # Concatenate prompt and response to get full sequence full_sequence = prompt_token_ids[idx] + response_ids[idx] prompt_len = len(prompt_token_ids[idx]) - + # Target tokens are same as input (autoregressive training) target_tokens = full_sequence[1:] logprobs = ([0] * prompt_len + sampled_logprobs[idx])[1:] - + + # Base mask: 0 for prompt, loss_mask value for response - mask = [0] * prompt_len + loss_masks[idx] - + mask = ([0] * prompt_len + loss_masks[idx]) + # Advantages: broadcast the single advantage value across all response tokens advantage_value = step_advantages[idx] advantages = torch.zeros(len(full_sequence)) @@ -431,6 +448,7 @@ def create_train_dataloader(epoch: int): advantages = advantages[1:] mask = mask[1:] + datum = types.Datum( model_input=types.ModelInput.from_ints(tokens=full_sequence[:-1]), loss_fn_inputs={ @@ -444,13 +462,15 @@ def create_train_dataloader(epoch: int): # Training step print(f"🎈 Start training at step {policy_iteration_step}") st = time.time() - + # Use Tinker's built-in loss function ("ppo" or "importance_sampling") - fwd_bwd_future = training_client.forward_backward(training_datums, loss_fn=config.loss_fn) + fwd_bwd_future = training_client.forward_backward( + training_datums, loss_fn=config.loss_fn + ) # Optimize optim_step_future = training_client.optim_step(adam_params) fwd_bwd_result = fwd_bwd_future.result() - + # Extract training logprobs from loss_fn_outputs training_logprobs_D: list[torch.Tensor] = [] for output in fwd_bwd_result.loss_fn_outputs: @@ -459,7 +479,7 @@ def create_train_dataloader(epoch: int): # with timed("compute_kl_sample_train", metrics): # kl_sample_train_metrics = compute_kl_sample_train(training_datums, training_logprobs_D) # metrics.update(kl_sample_train_metrics) - + _ = optim_step_future.result() metrics["time/train"] = time.time() - st @@ -475,10 +495,10 @@ def create_train_dataloader(epoch: int): kind="both", loop_state={"policy_iteration_step": config.max_steps}, ) - + wandb.finish() logger.info("Training completed successfully") if __name__ == "__main__": - asyncio.run(main(chz.entrypoint(Config))) + asyncio.run(main(chz.entrypoint(Config))) \ No newline at end of file diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/__init__.py b/skyrl-agent/skyrl_agent/tasks/osworld/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/__init__.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/__init__.py @@ -0,0 +1 @@ + diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/actions.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/actions.py new file mode 100644 index 000000000..5e286c52b --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/actions.py @@ -0,0 +1,203 @@ +X_MAX = 1920 # TODO: get the screen resolution +Y_MAX = 1080 + +KEYBOARD_KEYS = ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', 'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace', 'browserback', 'browserfavorites', 'browserforward', 'browserhome', 'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear', 'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete', 'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10', 'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20', 'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9', 'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja', 'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail', 'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack', 'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6', 'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn', 'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn', 'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator', 'shift', 'shiftleft', 'shiftright', 'sleep', 'stop', 'subtract', 'tab', 'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen', 'command', 'option', 'optionleft', 'optionright'] + +ACTION_SPACE = [ + { + "action_type": "MOVE_TO", + "note": "move the cursor to the specified position", + "parameters": { + "x": { + "type": float, + "range": [0, X_MAX], + "optional": False, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": False, + } + } + }, + { + "action_type": "CLICK", + "note": "click the left button if the button not specified, otherwise click the specified button; click at the current position if x and y are not specified, otherwise click at the specified position", + "parameters": { + "button": { + "type": str, + "range": ["left", "right", "middle"], + "optional": True, + }, + "x": { + "type": float, + "range": [0, X_MAX], + "optional": True, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": True, + }, + "num_clicks": { + "type": int, + "range": [1, 2, 3], + "optional": True, + }, + } + }, + { + "action_type": "MOUSE_DOWN", + "note": "press the left button if the button not specified, otherwise press the specified button", + "parameters": { + "button": { + "type": str, + "range": ["left", "right", "middle"], + "optional": True, + } + } + }, + { + "action_type": "MOUSE_UP", + "note": "release the left button if the button not specified, otherwise release the specified button", + "parameters": { + "button": { + "type": str, + "range": ["left", "right", "middle"], + "optional": True, + } + } + }, + { + "action_type": "RIGHT_CLICK", + "note": "right click at the current position if x and y are not specified, otherwise right click at the specified position", + "parameters": { + "x": { + "type": float, + "range": [0, X_MAX], + "optional": True, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": True, + } + } + }, + { + "action_type": "DOUBLE_CLICK", + "note": "double click at the current position if x and y are not specified, otherwise double click at the specified position", + "parameters": { + "x": { + "type": float, + "range": [0, X_MAX], + "optional": True, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": True, + } + } + }, + { + "action_type": "DRAG_TO", + "note": "drag the cursor to the specified position with the left button pressed", + "parameters": { + "x": { + "type": float, + "range": [0, X_MAX], + "optional": False, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": False, + } + } + }, + { + "action_type": "SCROLL", + "note": "scroll the mouse wheel up or down", + "parameters": { + "dx": { + "type": int, + "range": None, + "optional": False, + }, + "dy": { + "type": int, + "range": None, + "optional": False, + } + } + }, + { + "action_type": "TYPING", + "note": "type the specified text", + "parameters": { + "text": { + "type": str, + "range": None, + "optional": False, + } + } + }, + { + "action_type": "PRESS", + "note": "press the specified key and release it", + "parameters": { + "key": { + "type": str, + "range": KEYBOARD_KEYS, + "optional": False, + } + } + }, + { + "action_type": "KEY_DOWN", + "note": "press the specified key", + "parameters": { + "key": { + "type": str, + "range": KEYBOARD_KEYS, + "optional": False, + } + } + }, + { + "action_type": "KEY_UP", + "note": "release the specified key", + "parameters": { + "key": { + "type": str, + "range": KEYBOARD_KEYS, + "optional": False, + } + } + }, + { + "action_type": "HOTKEY", + "note": "press the specified key combination", + "parameters": { + "keys": { + "type": list, + "range": [KEYBOARD_KEYS], + "optional": False, + } + } + }, + ############################################################################################################ + { + "action_type": "WAIT", + "note": "wait until the next action", + }, + { + "action_type": "FAIL", + "note": "decide the task can not be performed", + }, + { + "action_type": "DONE", + "note": "decide the task is done", + } +] diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/controllers/__init__.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/controllers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/controllers/python.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/controllers/python.py new file mode 100644 index 000000000..80b038256 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/controllers/python.py @@ -0,0 +1,472 @@ +import json +import logging +import random +from typing import Any, Dict, Optional +import time +import requests + +from skyrl_agent.tasks.osworld.desktop_env.actions import KEYBOARD_KEYS + +logger = logging.getLogger("desktopenv.pycontroller") + + +class PythonController: + def __init__(self, vm_ip: str, + server_port: int, + pkgs_prefix: str = "import pyautogui; import time; pyautogui.FAILSAFE = False; {command}"): + self.vm_ip = vm_ip + self.http_server = f"http://{vm_ip}:{server_port}" + self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages + self.retry_times = 2 + self.retry_interval = 3 + self.request_timeout = 45 # 45 second timeout for HTTP requests + + def get_screenshot(self) -> Optional[bytes]: + """ + Gets a screenshot from the server. With the cursor. None -> no screenshot or unexpected error. + """ + + for _ in range(self.retry_times): + try: + response = requests.get(self.http_server + "/screenshot", timeout=self.request_timeout) + if response.status_code == 200: + logger.info("Got screenshot successfully") + return response.content + else: + logger.error("Failed to get screenshot. Status code: %d", response.status_code) + logger.info("Retrying to get screenshot.") + except Exception as e: + logger.error("An error occurred while trying to get the screenshot: %s", e) + logger.info("Retrying to get screenshot.") + time.sleep(self.retry_interval) + + logger.error("Failed to get screenshot.") + return None + + def get_accessibility_tree(self) -> Optional[str]: + """ + Gets the accessibility tree from the server. None -> no accessibility tree or unexpected error. + """ + + for _ in range(self.retry_times): + try: + response: requests.Response = requests.get(self.http_server + "/accessibility", timeout=self.request_timeout) + if response.status_code == 200: + logger.info("Got accessibility tree successfully") + return response.json()["AT"] + else: + logger.error("Failed to get accessibility tree. Status code: %d", response.status_code) + logger.info("Retrying to get accessibility tree.") + except Exception as e: + logger.error("An error occurred while trying to get the accessibility tree: %s", e) + logger.info("Retrying to get accessibility tree.") + time.sleep(self.retry_interval) + + logger.error("Failed to get accessibility tree.") + return None + + def get_terminal_output(self) -> Optional[str]: + """ + Gets the terminal output from the server. None -> no terminal output or unexpected error. + """ + + for _ in range(self.retry_times): + try: + response = requests.get(self.http_server + "/terminal", timeout=self.request_timeout) + if response.status_code == 200: + logger.info("Got terminal output successfully") + return response.json()["output"] + else: + logger.error("Failed to get terminal output. Status code: %d", response.status_code) + logger.info("Retrying to get terminal output.") + except Exception as e: + logger.error("An error occurred while trying to get the terminal output: %s", e) + logger.info("Retrying to get terminal output.") + time.sleep(self.retry_interval) + + logger.error("Failed to get terminal output.") + return None + + def get_file(self, file_path: str) -> Optional[bytes]: + """ + Gets a file from the server. + """ + + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/file", data={"file_path": file_path}, timeout=self.request_timeout) + if response.status_code == 200: + logger.info("File downloaded successfully") + return response.content + else: + logger.error("Failed to get file. Status code: %d", response.status_code) + logger.info("Retrying to get file.") + except Exception as e: + logger.error("An error occurred while trying to get the file: %s", e) + logger.info("Retrying to get file.") + time.sleep(self.retry_interval) + + logger.error("Failed to get file.") + return None + + def execute_python_command(self, command: str) -> None: + """ + Executes a python command on the server. + It can be used to execute the pyautogui commands, or... any other python command. who knows? + """ + # command_list = ["python", "-c", self.pkgs_prefix.format(command=command)] + command_list = ["python", "-c", self.pkgs_prefix.format(command=command)] + payload = json.dumps({"command": command_list, "shell": False}) + + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/execute", headers={'Content-Type': 'application/json'}, + data=payload, timeout=90) + if response.status_code == 200: + logger.info("Command executed successfully: %s", response.text) + return response.json() + else: + logger.error("Failed to execute command. Status code: %d", response.status_code) + logger.info("Retrying to execute command.") + except requests.exceptions.ReadTimeout: + break + except Exception as e: + logger.error("An error occurred while trying to execute the command: %s", e) + logger.info("Retrying to execute command.") + time.sleep(self.retry_interval) + + logger.error("Failed to execute command.") + return None + + def execute_action(self, action: Dict[str, Any]): + """ + Executes an action on the server computer. + """ + if action in ['WAIT', 'FAIL', 'DONE']: + return + + action_type = action["action_type"] + parameters = action["parameters"] if "parameters" in action else {param: action[param] for param in action if param != 'action_type'} + move_mode = random.choice( + ["pyautogui.easeInQuad", "pyautogui.easeOutQuad", "pyautogui.easeInOutQuad", "pyautogui.easeInBounce", + "pyautogui.easeInElastic"]) + duration = random.uniform(0.5, 1) + + if action_type == "MOVE_TO": + if parameters == {} or None: + self.execute_python_command("pyautogui.moveTo()") + elif "x" in parameters and "y" in parameters: + x = parameters["x"] + y = parameters["y"] + self.execute_python_command(f"pyautogui.moveTo({x}, {y}, {duration}, {move_mode})") + else: + raise Exception(f"Unknown parameters: {parameters}") + + elif action_type == "CLICK": + if parameters == {} or None: + self.execute_python_command("pyautogui.click()") + elif "button" in parameters and "x" in parameters and "y" in parameters: + button = parameters["button"] + x = parameters["x"] + y = parameters["y"] + if "num_clicks" in parameters: + num_clicks = parameters["num_clicks"] + self.execute_python_command( + f"pyautogui.click(button='{button}', x={x}, y={y}, clicks={num_clicks})") + else: + self.execute_python_command(f"pyautogui.click(button='{button}', x={x}, y={y})") + elif "button" in parameters and "x" not in parameters and "y" not in parameters: + button = parameters["button"] + if "num_clicks" in parameters: + num_clicks = parameters["num_clicks"] + self.execute_python_command(f"pyautogui.click(button='{button}', clicks={num_clicks})") + else: + self.execute_python_command(f"pyautogui.click(button='{button}')") + elif "button" not in parameters and "x" in parameters and "y" in parameters: + x = parameters["x"] + y = parameters["y"] + if "num_clicks" in parameters: + num_clicks = parameters["num_clicks"] + self.execute_python_command(f"pyautogui.click(x={x}, y={y}, clicks={num_clicks})") + else: + self.execute_python_command(f"pyautogui.click(x={x}, y={y})") + else: + raise Exception(f"Unknown parameters: {parameters}") + + elif action_type == "MOUSE_DOWN": + if parameters == {} or None: + self.execute_python_command("pyautogui.mouseDown()") + elif "button" in parameters: + button = parameters["button"] + self.execute_python_command(f"pyautogui.mouseDown(button='{button}')") + else: + raise Exception(f"Unknown parameters: {parameters}") + + elif action_type == "MOUSE_UP": + if parameters == {} or None: + self.execute_python_command("pyautogui.mouseUp()") + elif "button" in parameters: + button = parameters["button"] + self.execute_python_command(f"pyautogui.mouseUp(button='{button}')") + else: + raise Exception(f"Unknown parameters: {parameters}") + + elif action_type == "RIGHT_CLICK": + if parameters == {} or None: + self.execute_python_command("pyautogui.rightClick()") + elif "x" in parameters and "y" in parameters: + x = parameters["x"] + y = parameters["y"] + self.execute_python_command(f"pyautogui.rightClick(x={x}, y={y})") + else: + raise Exception(f"Unknown parameters: {parameters}") + + elif action_type == "DOUBLE_CLICK": + if parameters == {} or None: + self.execute_python_command("pyautogui.doubleClick()") + elif "x" in parameters and "y" in parameters: + x = parameters["x"] + y = parameters["y"] + self.execute_python_command(f"pyautogui.doubleClick(x={x}, y={y})") + else: + raise Exception(f"Unknown parameters: {parameters}") + + elif action_type == "DRAG_TO": + if "x" in parameters and "y" in parameters: + x = parameters["x"] + y = parameters["y"] + self.execute_python_command( + f"pyautogui.dragTo({x}, {y}, duration=1.0, button='left', mouseDownUp=True)") + + elif action_type == "SCROLL": + # todo: check if it is related to the operating system, as https://github.com/TheDuckAI/DuckTrack/blob/main/ducktrack/playback.py pointed out + if "dx" in parameters and "dy" in parameters: + dx = parameters["dx"] + dy = parameters["dy"] + self.execute_python_command(f"pyautogui.hscroll({dx})") + self.execute_python_command(f"pyautogui.vscroll({dy})") + elif "dx" in parameters and "dy" not in parameters: + dx = parameters["dx"] + self.execute_python_command(f"pyautogui.hscroll({dx})") + elif "dx" not in parameters and "dy" in parameters: + dy = parameters["dy"] + self.execute_python_command(f"pyautogui.vscroll({dy})") + else: + raise Exception(f"Unknown parameters: {parameters}") + + elif action_type == "TYPING": + if "text" not in parameters: + raise Exception(f"Unknown parameters: {parameters}") + # deal with special ' and \ characters + # text = parameters["text"].replace("\\", "\\\\").replace("'", "\\'") + # self.execute_python_command(f"pyautogui.typewrite('{text}')") + text = parameters["text"] + self.execute_python_command("pyautogui.typewrite({:})".format(repr(text))) + + elif action_type == "PRESS": + if "key" not in parameters: + raise Exception(f"Unknown parameters: {parameters}") + key = parameters["key"] + if key.lower() not in KEYBOARD_KEYS: + raise Exception(f"Key must be one of {KEYBOARD_KEYS}") + self.execute_python_command(f"pyautogui.press('{key}')") + + elif action_type == "KEY_DOWN": + if "key" not in parameters: + raise Exception(f"Unknown parameters: {parameters}") + key = parameters["key"] + if key.lower() not in KEYBOARD_KEYS: + raise Exception(f"Key must be one of {KEYBOARD_KEYS}") + self.execute_python_command(f"pyautogui.keyDown('{key}')") + + elif action_type == "KEY_UP": + if "key" not in parameters: + raise Exception(f"Unknown parameters: {parameters}") + key = parameters["key"] + if key.lower() not in KEYBOARD_KEYS: + raise Exception(f"Key must be one of {KEYBOARD_KEYS}") + self.execute_python_command(f"pyautogui.keyUp('{key}')") + + elif action_type == "HOTKEY": + if "keys" not in parameters: + raise Exception(f"Unknown parameters: {parameters}") + keys = parameters["keys"] + if not isinstance(keys, list): + raise Exception("Keys must be a list of keys") + for key in keys: + if key.lower() not in KEYBOARD_KEYS: + raise Exception(f"Key must be one of {KEYBOARD_KEYS}") + + keys_para_rep = "', '".join(keys) + self.execute_python_command(f"pyautogui.hotkey('{keys_para_rep}')") + + elif action_type in ['WAIT', 'FAIL', 'DONE']: + pass + + else: + raise Exception(f"Unknown action type: {action_type}") + + # Record video + def start_recording(self): + """ + Starts recording the screen. + """ + + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/start_recording", timeout=self.request_timeout) + if response.status_code == 200: + logger.info("Recording started successfully") + return + else: + logger.error("Failed to start recording. Status code: %d", response.status_code) + logger.info("Retrying to start recording.") + except Exception as e: + logger.error("An error occurred while trying to start recording: %s", e) + logger.info("Retrying to start recording.") + time.sleep(self.retry_interval) + + logger.error("Failed to start recording.") + + def end_recording(self, dest: str): + """ + Ends recording the screen. + """ + + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/end_recording", timeout=self.request_timeout) + if response.status_code == 200: + logger.info("Recording stopped successfully") + with open(dest, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + return + else: + logger.error("Failed to stop recording. Status code: %d", response.status_code) + logger.info("Retrying to stop recording.") + except Exception as e: + logger.error("An error occurred while trying to stop recording: %s", e) + logger.info("Retrying to stop recording.") + time.sleep(self.retry_interval) + + logger.error("Failed to stop recording.") + + # Additional info + def get_vm_platform(self): + """ + Gets the size of the vm screen. + """ + return self.execute_python_command("import platform; print(platform.system())")['output'].strip() + + def get_vm_screen_size(self): + """ + Gets the size of the vm screen. + """ + + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/screen_size", timeout=self.request_timeout) + if response.status_code == 200: + logger.info("Got screen size successfully") + return response.json() + else: + logger.error("Failed to get screen size. Status code: %d", response.status_code) + logger.info("Retrying to get screen size.") + except Exception as e: + logger.error("An error occurred while trying to get the screen size: %s", e) + logger.info("Retrying to get screen size.") + time.sleep(self.retry_interval) + + logger.error("Failed to get screen size.") + return None + + def get_vm_window_size(self, app_class_name: str): + """ + Gets the size of the vm app window. + """ + + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/window_size", data={"app_class_name": app_class_name}, timeout=self.request_timeout) + if response.status_code == 200: + logger.info("Got window size successfully") + return response.json() + else: + logger.error("Failed to get window size. Status code: %d", response.status_code) + logger.info("Retrying to get window size.") + except Exception as e: + logger.error("An error occurred while trying to get the window size: %s", e) + logger.info("Retrying to get window size.") + time.sleep(self.retry_interval) + + logger.error("Failed to get window size.") + return None + + def get_vm_wallpaper(self): + """ + Gets the wallpaper of the vm. + """ + + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/wallpaper", timeout=self.request_timeout) + if response.status_code == 200: + logger.info("Got wallpaper successfully") + return response.content + else: + logger.error("Failed to get wallpaper. Status code: %d", response.status_code) + logger.info("Retrying to get wallpaper.") + except Exception as e: + logger.error("An error occurred while trying to get the wallpaper: %s", e) + logger.info("Retrying to get wallpaper.") + time.sleep(self.retry_interval) + + logger.error("Failed to get wallpaper.") + return None + + def get_vm_desktop_path(self) -> Optional[str]: + """ + Gets the desktop path of the vm. + """ + + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/desktop_path", timeout=self.request_timeout) + if response.status_code == 200: + logger.info("Got desktop path successfully") + return response.json()["desktop_path"] + else: + logger.error("Failed to get desktop path. Status code: %d", response.status_code) + logger.info("Retrying to get desktop path.") + except Exception as e: + logger.error("An error occurred while trying to get the desktop path: %s", e) + logger.info("Retrying to get desktop path.") + time.sleep(self.retry_interval) + + logger.error("Failed to get desktop path.") + return None + + def get_vm_directory_tree(self, path) -> Optional[Dict[str, Any]]: + """ + Gets the directory tree of the vm. + """ + payload = json.dumps({"path": path}) + + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/list_directory", headers={'Content-Type': 'application/json'}, data=payload, timeout=self.request_timeout) + if response.status_code == 200: + logger.info("Got directory tree successfully") + return response.json()["directory_tree"] + else: + logger.error("Failed to get directory tree. Status code: %d", response.status_code) + logger.info("Retrying to get directory tree.") + except Exception as e: + logger.error("An error occurred while trying to get directory tree: %s", e) + logger.info("Retrying to get directory tree.") + time.sleep(self.retry_interval) + + logger.error("Failed to get directory tree.") + return None diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/controllers/setup.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/controllers/setup.py new file mode 100644 index 000000000..050eb4134 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/controllers/setup.py @@ -0,0 +1,865 @@ +import asyncio +import json +import logging +import os +import os.path +import platform +import shutil +import sqlite3 +import tempfile +import time +import traceback +import uuid +from datetime import datetime, timedelta +from typing import Any, Union, Optional +from typing import Dict, List + +import requests +from playwright.sync_api import sync_playwright, TimeoutError +from playwright.async_api import async_playwright +from pydrive.auth import GoogleAuth +from pydrive.drive import GoogleDrive, GoogleDriveFile, GoogleDriveFileList +from requests_toolbelt.multipart.encoder import MultipartEncoder + +from skyrl_agent.tasks.osworld.desktop_env.controllers.python import PythonController +from skyrl_agent.tasks.osworld.desktop_env.evaluators.metrics.utils import compare_urls +from skyrl_agent.tasks.osworld.desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool, ProxyInfo + +import dotenv +# Load environment variables from .env file +dotenv.load_dotenv() + +CLIENT_PASSWORD = os.getenv("CLIENT_PASSWORD", "osworld-public-evaluation") # Default password for sudo operations +PROXY_CONFIG_FILE = os.getenv("PROXY_CONFIG_FILE", "evaluation_examples/settings/proxy/dataimpulse.json") # Default proxy config file + +logger = logging.getLogger("desktopenv.setup") + +FILE_PATH = os.path.dirname(os.path.abspath(__file__)) + +init_proxy_pool(PROXY_CONFIG_FILE) # initialize the global proxy pool + +MAX_RETRIES = 20 + +class SetupController: + def __init__(self, vm_ip: str, server_port: int = 5000, chromium_port: int = 9222, vlc_port: int = 8080, cache_dir: str = "cache"): + self.vm_ip: str = vm_ip + self.server_port: int = server_port + self.chromium_port: int = chromium_port + self.vlc_port: int = vlc_port + self.http_server: str = f"http://{vm_ip}:{server_port}" + self.http_server_setup_root: str = f"http://{vm_ip}:{server_port}/setup" + self.cache_dir: str = cache_dir + self.use_proxy: bool = False + + def reset_cache_dir(self, cache_dir: str): + self.cache_dir = cache_dir + + async def setup(self, config: List[Dict[str, Any]], use_proxy: bool = False)-> bool: + """ + Args: + config (List[Dict[str, Any]]): list of dict like {str: Any}. each + config dict has the structure like + { + "type": str, corresponding to the `_{:}_setup` methods of + this class + "parameters": dict like {str, Any} providing the keyword + parameters + } + """ + self.use_proxy = use_proxy + # make sure connection can be established + logger.info(f"try to connect {self.http_server}") + retry = 0 + while retry < MAX_RETRIES: + try: + _ = requests.get(self.http_server + "/terminal") + break + except: + time.sleep(5) + retry += 1 + logger.info(f"retry: {retry}/{MAX_RETRIES}") + + if retry == MAX_RETRIES: + return False + + + for i, cfg in enumerate(config): + config_type: str = cfg["type"] + parameters: Dict[str, Any] = cfg["parameters"] + + # Assumes all the setup the functions should follow this name + # protocol + setup_function: str = "_{:}_setup".format(config_type) + assert hasattr(self, setup_function), f'Setup controller cannot find init function {setup_function}' + + try: + logger.info(f"Executing setup step {i+1}/{len(config)}: {setup_function}") + logger.debug(f"Setup parameters: {parameters}") + + # Check if the function is async + func = getattr(self, setup_function) + if asyncio.iscoroutinefunction(func): + await func(**parameters) + else: + func(**parameters) + + logger.info(f"SETUP COMPLETED: {setup_function}({str(parameters)})") + except Exception as e: + logger.error(f"SETUP FAILED at step {i+1}/{len(config)}: {setup_function}({str(parameters)})") + logger.error(f"Error details: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") + raise Exception(f"Setup step {i+1} failed: {setup_function} - {e}") from e + + return True + + def _download_setup(self, files: List[Dict[str, str]]): + """ + Args: + files (List[Dict[str, str]]): files to download. lisf of dict like + { + "url": str, the url to download + "path": str, the path on the VM to store the downloaded file + } + """ + for f in files: + url: str = f["url"] + path: str = f["path"] + cache_path: str = os.path.join(self.cache_dir, "{:}_{:}".format( + uuid.uuid5(uuid.NAMESPACE_URL, url), + os.path.basename(path))) + if not url or not path: + raise Exception(f"Setup Download - Invalid URL ({url}) or path ({path}).") + + if not os.path.exists(cache_path): + logger.info(f"Cache file not found, downloading from {url} to {cache_path}") + max_retries = 3 + downloaded = False + e = None + for i in range(max_retries): + try: + logger.info(f"Download attempt {i+1}/{max_retries} for {url}") + response = requests.get(url, stream=True, timeout=300) # Add 5 minute timeout + response.raise_for_status() + + # Get file size if available + total_size = int(response.headers.get('content-length', 0)) + if total_size > 0: + logger.info(f"File size: {total_size / (1024*1024):.2f} MB") + + downloaded_size = 0 + with open(cache_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + downloaded_size += len(chunk) + if total_size > 0 and downloaded_size % (1024*1024) == 0: # Log every MB + progress = (downloaded_size / total_size) * 100 + logger.info(f"Download progress: {progress:.1f}%") + + logger.info(f"File downloaded successfully to {cache_path} ({downloaded_size / (1024*1024):.2f} MB)") + downloaded = True + break + + except requests.RequestException as e: + logger.error( + f"Failed to download {url} caused by {e}. Retrying... ({max_retries - i - 1} attempts left)") + # Clean up partial download + if os.path.exists(cache_path): + os.remove(cache_path) + if not downloaded: + raise requests.RequestException(f"Failed to download {url}. No retries left.") + + form = MultipartEncoder({ + "file_path": path, + "file_data": (os.path.basename(path), open(cache_path, "rb")) + }) + headers = {"Content-Type": form.content_type} + logger.debug(form.content_type) + + # send request to server to upload file + try: + logger.info(f"Uploading {os.path.basename(path)} to VM at {path}") + logger.debug("REQUEST ADDRESS: %s", self.http_server + "/setup" + "/upload") + response = requests.post(self.http_server + "/setup" + "/upload", headers=headers, data=form, timeout=600) # 10 minute timeout for upload + if response.status_code == 200: + logger.info(f"File uploaded successfully: {path}") + logger.debug("Upload response: %s", response.text) + else: + logger.error(f"Failed to upload file {path}. Status code: {response.status_code}, Response: {response.text}") + raise requests.RequestException(f"Upload failed with status {response.status_code}") + except requests.exceptions.RequestException as e: + logger.error(f"An error occurred while trying to upload {path}: {e}") + raise + + def _upload_file_setup(self, files: List[Dict[str, str]]): + """ + Args: + files (List[Dict[str, str]]): files to download. lisf of dict like + { + "local_path": str, the local path to the file to upload + "path": str, the path on the VM to store the downloaded file + } + """ + for f in files: + local_path: str = f["local_path"] + path: str = f["path"] + + if not os.path.exists(local_path): + logger.error(f"Setup Upload - Invalid local path ({local_path}).") + return + + form = MultipartEncoder({ + "file_path": path, + "file_data": (os.path.basename(path), open(local_path, "rb")) + }) + headers = {"Content-Type": form.content_type} + logger.debug(form.content_type) + + # send request to server to upload file + try: + logger.debug("REQUEST ADDRESS: %s", self.http_server + "/setup" + "/upload") + response = requests.post(self.http_server + "/setup" + "/upload", headers=headers, data=form) + if response.status_code == 200: + logger.info("Command executed successfully: %s", response.text) + else: + logger.error("Failed to upload file. Status code: %s", response.text) + except requests.exceptions.RequestException as e: + logger.error("An error occurred while trying to send the request: %s", e) + + def _change_wallpaper_setup(self, path: str): + if not path: + raise Exception(f"Setup Wallpaper - Invalid path ({path}).") + + payload = json.dumps({"path": path}) + headers = { + 'Content-Type': 'application/json' + } + + # send request to server to change wallpaper + try: + response = requests.post(self.http_server + "/setup" + "/change_wallpaper", headers=headers, data=payload) + if response.status_code == 200: + logger.info("Command executed successfully: %s", response.text) + else: + logger.error("Failed to change wallpaper. Status code: %s", response.text) + except requests.exceptions.RequestException as e: + logger.error("An error occurred while trying to send the request: %s", e) + + def _tidy_desktop_setup(self, **config): + raise NotImplementedError() + + def _open_setup(self, path: str): + if not path: + raise Exception(f"Setup Open - Invalid path ({path}).") + + payload = json.dumps({"path": path}) + headers = { + 'Content-Type': 'application/json' + } + + # send request to server to open file + try: + # The server-side call is now blocking and can take time. + # We set a timeout that is slightly longer than the server's timeout (1800s). + response = requests.post(self.http_server + "/setup" + "/open_file", headers=headers, data=payload, timeout=1810) + response.raise_for_status() # This will raise an exception for 4xx and 5xx status codes + logger.info("Command executed successfully: %s", response.text) + except requests.exceptions.RequestException as e: + logger.error(f"Failed to open file '{path}'. An error occurred while trying to send the request or the server responded with an error: {e}") + raise Exception(f"Failed to open file '{path}'. An error occurred while trying to send the request or the server responded with an error: {e}") from e + + def _launch_setup(self, command: Union[str, List[str]], shell: bool = False): + if not command: + raise Exception("Empty command to launch.") + + if not shell and isinstance(command, str) and len(command.split()) > 1: + logger.warning("Command should be a list of strings. Now it is a string. Will split it by space.") + command = command.split() + + if command[0] == "google-chrome" and self.use_proxy: + command.append("--proxy-server=http://127.0.0.1:18888") # Use the proxy server set up by _proxy_setup + + payload = json.dumps({"command": command, "shell": shell}) + headers = {"Content-Type": "application/json"} + + try: + logger.info("REQUEST ADDRESS: %s", self.http_server + "/setup" + "/launch") + response = requests.post(self.http_server + "/setup" + "/launch", headers=headers, data=payload) + if response.status_code == 200: + logger.info("Command executed successfully: %s", response.text) + else: + logger.error("Failed to launch application. Status code: %s", response.text) + except requests.exceptions.RequestException as e: + logger.error("An error occurred while trying to send the request: %s", e) + + def _execute_setup( + self, + command: List[str], + stdout: str = "", + stderr: str = "", + shell: bool = False, + until: Optional[Dict[str, Any]] = None + ): + if not command: + raise Exception("Empty command to launch.") + + until: Dict[str, Any] = until or {} + terminates: bool = False + nb_failings = 0 + + payload = json.dumps({"command": command, "shell": shell}) + headers = {"Content-Type": "application/json"} + + while not terminates: + try: + response = requests.post(self.http_server + "/setup" + "/execute", headers=headers, data=payload) + if response.status_code == 200: + results: Dict[str, str] = response.json() + if stdout: + with open(os.path.join(self.cache_dir, stdout), "w") as f: + f.write(results["output"]) + if stderr: + with open(os.path.join(self.cache_dir, stderr), "w") as f: + f.write(results["error"]) + logger.info("Command executed successfully: %s -> %s" + , " ".join(command) if isinstance(command, list) else command + , response.text + ) + else: + logger.error("Failed to launch application. Status code: %s", response.text) + results = None + nb_failings += 1 + except requests.exceptions.RequestException as e: + logger.error("An error occurred while trying to send the request: %s", e) + traceback.print_exc() + + results = None + nb_failings += 1 + + if len(until) == 0: + terminates = True + elif results is not None: + terminates = "returncode" in until and results["returncode"] == until["returncode"] \ + or "stdout" in until and until["stdout"] in results["output"] \ + or "stderr" in until and until["stderr"] in results["error"] + terminates = terminates or nb_failings >= 5 + if not terminates: + time.sleep(0.3) + + def _execute_with_verification_setup( + self, + command: List[str], + verification: Dict[str, Any] = None, + max_wait_time: int = 10, + check_interval: float = 1.0, + shell: bool = False + ): + """Execute command with verification of results + + Args: + command: Command to execute + verification: Dict with verification criteria: + - window_exists: Check if window with this name exists + - command_success: Execute this command and check if it succeeds + max_wait_time: Maximum time to wait for verification + check_interval: Time between verification checks + shell: Whether to use shell + """ + if not command: + raise Exception("Empty command to launch.") + + verification = verification or {} + + payload = json.dumps({ + "command": command, + "shell": shell, + "verification": verification, + "max_wait_time": max_wait_time, + "check_interval": check_interval + }) + headers = {"Content-Type": "application/json"} + + try: + response = requests.post(self.http_server + "/setup" + "/execute_with_verification", + headers=headers, data=payload, timeout=max_wait_time + 10) + if response.status_code == 200: + result = response.json() + logger.info("Command executed and verified successfully: %s -> %s" + , " ".join(command) if isinstance(command, list) else command + , response.text + ) + return result + else: + logger.error("Failed to execute with verification. Status code: %s", response.text) + raise Exception(f"Command verification failed: {response.text}") + except requests.exceptions.RequestException as e: + logger.error("An error occurred while trying to send the request: %s", e) + traceback.print_exc() + raise Exception(f"Request failed: {e}") + + def _command_setup(self, command: List[str], **kwargs): + self._execute_setup(command, **kwargs) + + def _sleep_setup(self, seconds: float): + time.sleep(seconds) + + def _act_setup(self, action_seq: List[Union[Dict[str, Any], str]]): + # TODO + raise NotImplementedError() + + def _replay_setup(self, trajectory: str): + """ + Args: + trajectory (str): path to the replay trajectory file + """ + + # TODO + raise NotImplementedError() + + def _activate_window_setup(self, window_name: str, strict: bool = False, by_class: bool = False): + if not window_name: + raise Exception(f"Setup Open - Invalid path ({window_name}).") + + payload = json.dumps({"window_name": window_name, "strict": strict, "by_class": by_class}) + headers = { + 'Content-Type': 'application/json' + } + + # send request to server to open file + try: + response = requests.post(self.http_server + "/setup" + "/activate_window", headers=headers, data=payload) + if response.status_code == 200: + logger.info("Command executed successfully: %s", response.text) + else: + logger.error(f"Failed to activate window {window_name}. Status code: %s", response.text) + except requests.exceptions.RequestException as e: + logger.error("An error occurred while trying to send the request: %s", e) + + def _close_window_setup(self, window_name: str, strict: bool = False, by_class: bool = False): + if not window_name: + raise Exception(f"Setup Open - Invalid path ({window_name}).") + + payload = json.dumps({"window_name": window_name, "strict": strict, "by_class": by_class}) + headers = { + 'Content-Type': 'application/json' + } + + # send request to server to open file + try: + response = requests.post(self.http_server + "/setup" + "/close_window", headers=headers, data=payload) + if response.status_code == 200: + logger.info("Command executed successfully: %s", response.text) + else: + logger.error(f"Failed to close window {window_name}. Status code: %s", response.text) + except requests.exceptions.RequestException as e: + logger.error("An error occurred while trying to send the request: %s", e) + + def _proxy_setup(self, client_password: str = CLIENT_PASSWORD): + """Setup system-wide proxy configuration using proxy pool + + Args: + client_password (str): Password for sudo operations, defaults to "password" + """ + retry = 0 + while retry < MAX_RETRIES: + try: + _ = requests.get(self.http_server + "/terminal") + break + except: + time.sleep(5) + retry += 1 + logger.info(f"retry: {retry}/{MAX_RETRIES}") + + if retry == MAX_RETRIES: + return False + + # Get proxy from global proxy pool + proxy_pool = get_global_proxy_pool() + current_proxy = proxy_pool.get_next_proxy() + + if not current_proxy: + logger.error("No proxy available from proxy pool") + raise Exception("No proxy available from proxy pool") + + # Format proxy URL + proxy_url = proxy_pool._format_proxy_url(current_proxy) + logger.info(f"Setting up proxy: {current_proxy.host}:{current_proxy.port}") + + # Configure system proxy environment variables + proxy_commands = [ + f"echo '{client_password}' | sudo -S bash -c \"apt-get update\"", ## TODO: remove this line if ami is already updated + f"echo '{client_password}' | sudo -S bash -c \"apt-get install -y tinyproxy\"", ## TODO: remove this line if tinyproxy is already installed + f"echo '{client_password}' | sudo -S bash -c \"echo 'Port 18888' > /tmp/tinyproxy.conf\"", + f"echo '{client_password}' | sudo -S bash -c \"echo 'Allow 127.0.0.1' >> /tmp/tinyproxy.conf\"", + f"echo '{client_password}' | sudo -S bash -c \"echo 'Upstream http {current_proxy.username}:{current_proxy.password}@{current_proxy.host}:{current_proxy.port}' >> /tmp/tinyproxy.conf\"", + + # CML commands to set environment variables for proxy + f"echo 'export http_proxy={proxy_url}' >> ~/.bashrc", + f"echo 'export https_proxy={proxy_url}' >> ~/.bashrc", + f"echo 'export HTTP_PROXY={proxy_url}' >> ~/.bashrc", + f"echo 'export HTTPS_PROXY={proxy_url}' >> ~/.bashrc", + ] + + # Execute all proxy configuration commands + for cmd in proxy_commands: + try: + self._execute_setup([cmd], shell=True) + except Exception as e: + logger.error(f"Failed to execute proxy setup command: {e}") + proxy_pool.mark_proxy_failed(current_proxy) + raise + + self._launch_setup(["tinyproxy -c /tmp/tinyproxy.conf -d"], shell=True) + + # Reload environment variables + reload_cmd = "source /etc/environment" + try: + logger.info(f"Proxy setup completed successfully for {current_proxy.host}:{current_proxy.port}") + proxy_pool.mark_proxy_success(current_proxy) + except Exception as e: + logger.error(f"Failed to reload environment variables: {e}") + proxy_pool.mark_proxy_failed(current_proxy) + raise + + # Chrome setup + async def _chrome_open_tabs_setup(self, urls_to_open: List[str]): + time.sleep(3) # Wait for Chrome to finish launching + + host = self.vm_ip + port = self.chromium_port # fixme: this port is hard-coded, need to be changed from config file + + remote_debugging_url = f"http://{host}:{port}" + logger.info("Connect to Chrome @: %s", remote_debugging_url) + logger.debug("PLAYWRIGHT ENV: %s", repr(os.environ)) + for attempt in range(15): + if attempt > 0: + time.sleep(5) + + browser = None + async with async_playwright() as p: + try: + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + # break + except Exception as e: + if attempt < 14: + logger.error(f"Attempt {attempt + 1}: Failed to connect, retrying. Error: {e}") + # time.sleep(10) + continue + else: + logger.error(f"Failed to connect after multiple attempts: {e}") + raise e + + if not browser: + return + + logger.info("Opening %s...", urls_to_open) + for i, url in enumerate(urls_to_open): + # Use the first context (which should be the only one if using default profile) + if i == 0: + context = browser.contexts[0] + + page = await context.new_page() # Create a new page (tab) within the existing context + try: + await page.goto(url, timeout=60000) + except: + logger.warning("Opening %s exceeds time limit", url) # only for human test + logger.info(f"Opened tab {i + 1}: {url}") + + if i == 0: + # clear the default tab + default_page = context.pages[0] + await default_page.close() + + # Do not close the context or browser; they will remain open after script ends + return browser, context + + async def _chrome_close_tabs_setup(self, urls_to_close: List[str]): + time.sleep(3) # Wait for Chrome to finish launching + + host = self.vm_ip + port = self.chromium_port # fixme: this port is hard-coded, need to be changed from config file + + remote_debugging_url = f"http://{host}:{port}" + async with async_playwright() as p: + browser = None + for attempt in range(15): + try: + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + break + except Exception as e: + if attempt < 14: + logger.error(f"Attempt {attempt + 1}: Failed to connect, retrying. Error: {e}") + time.sleep(5) + else: + logger.error(f"Failed to connect after multiple attempts: {e}") + raise e + + if not browser: + return + + for i, url in enumerate(urls_to_close): + # Use the first context (which should be the only one if using default profile) + if i == 0: + context = browser.contexts[0] + + for page in context.pages: + + # if two urls are the same, close the tab + if compare_urls(page.url, url): + context.pages.pop(context.pages.index(page)) + await page.close() + logger.info(f"Closed tab {i + 1}: {url}") + break + + # Do not close the context or browser; they will remain open after script ends + return browser, context + + # google drive setup + def _googledrive_setup(self, **config): + """ Clean google drive space (eliminate the impact of previous experiments to reset the environment) + @args: + config(Dict[str, Any]): contain keys + settings_file(str): path to google drive settings file, which will be loaded by pydrive.auth.GoogleAuth() + operation(List[str]): each operation is chosen from ['delete', 'upload'] + args(List[Dict[str, Any]]): parameters for each operation + different args dict for different operations: + for delete: + query(str): query pattern string to search files or folder in google drive to delete, please refer to + https://developers.google.com/drive/api/guides/search-files?hl=en about how to write query string. + trash(bool): whether to delete files permanently or move to trash. By default, trash=false, completely delete it. + for mkdirs: + path(List[str]): the path in the google drive to create folder + for upload: + path(str): remote url to download file + dest(List[str]): the path in the google drive to store the downloaded file + """ + settings_file = config.get('settings_file', 'evaluation_examples/settings/googledrive/settings.yml') + gauth = GoogleAuth(settings_file=settings_file) + drive = GoogleDrive(gauth) + + def mkdir_in_googledrive(paths: List[str]): + paths = [paths] if type(paths) != list else paths + parent_id = 'root' + for p in paths: + q = f'"{parent_id}" in parents and title = "{p}" and mimeType = "application/vnd.google-apps.folder" and trashed = false' + folder = drive.ListFile({'q': q}).GetList() + if len(folder) == 0: # not exists, create it + parents = {} if parent_id == 'root' else {'parents': [{'id': parent_id}]} + file = drive.CreateFile({'title': p, 'mimeType': 'application/vnd.google-apps.folder', **parents}) + file.Upload() + parent_id = file['id'] + else: + parent_id = folder[0]['id'] + return parent_id + + for oid, operation in enumerate(config['operation']): + if operation == 'delete': # delete a specific file + # query pattern string, by default, remove all files/folders not in the trash to the trash + params = config['args'][oid] + q = params.get('query', '') + trash = params.get('trash', False) + q_file = f"( {q} ) and mimeType != 'application/vnd.google-apps.folder'" if q.strip() else "mimeType != 'application/vnd.google-apps.folder'" + filelist: GoogleDriveFileList = drive.ListFile({'q': q_file}).GetList() + q_folder = f"( {q} ) and mimeType = 'application/vnd.google-apps.folder'" if q.strip() else "mimeType = 'application/vnd.google-apps.folder'" + folderlist: GoogleDriveFileList = drive.ListFile({'q': q_folder}).GetList() + for file in filelist: # first delete file, then folder + file: GoogleDriveFile + if trash: + file.Trash() + else: + file.Delete() + for folder in folderlist: + folder: GoogleDriveFile + # note that, if a folder is trashed/deleted, all files and folders in it will be trashed/deleted + if trash: + folder.Trash() + else: + folder.Delete() + elif operation == 'mkdirs': + params = config['args'][oid] + mkdir_in_googledrive(params['path']) + elif operation == 'upload': + params = config['args'][oid] + url = params['url'] + with tempfile.NamedTemporaryFile(mode='wb', delete=False) as tmpf: + response = requests.get(url, stream=True) + response.raise_for_status() + for chunk in response.iter_content(chunk_size=8192): + if chunk: + tmpf.write(chunk) + tmpf.close() + paths = [params['path']] if params['path'] != list else params['path'] + parent_id = mkdir_in_googledrive(paths[:-1]) + parents = {} if parent_id == 'root' else {'parents': [{'id': parent_id}]} + file = drive.CreateFile({'title': paths[-1], **parents}) + file.SetContentFile(tmpf.name) + file.Upload() + return + else: + raise ValueError('[ERROR]: not implemented clean type!') + + async def _login_setup(self, **config): + """ Login to a website with account and password information. + @args: + config(Dict[str, Any]): contain keys + settings_file(str): path to the settings file + platform(str): platform to login, implemented platforms include: + googledrive: https://drive.google.com/drive/my-drive + + """ + host = self.vm_ip + port = self.chromium_port + + remote_debugging_url = f"http://{host}:{port}" + async with async_playwright() as p: + browser = None + for attempt in range(15): + try: + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + break + except Exception as e: + if attempt < 14: + logger.error(f"Attempt {attempt + 1}: Failed to connect, retrying. Error: {e}") + await asyncio.sleep(5) + else: + logger.error(f"Failed to connect after multiple attempts: {e}") + raise e + if not browser: + return + + context = browser.contexts[0] + platform = config['platform'] + + if platform == 'googledrive': + url = 'https://drive.google.com/drive/my-drive' + page = await context.new_page() # Create a new page (tab) within the existing context + try: + await page.goto(url, timeout=60000) + except: + logger.warning("Opening %s exceeds time limit", url) # only for human test + logger.info(f"Opened new page: {url}") + settings = json.load(open(config['settings_file'])) + email, password = settings['email'], settings['password'] + + try: + await page.wait_for_selector('input[type="email"]', state="visible", timeout=3000) + await page.fill('input[type="email"]', email) + await page.click('#identifierNext > div > button') + await page.wait_for_selector('input[type="password"]', state="visible", timeout=5000) + await page.fill('input[type="password"]', password) + await page.click('#passwordNext > div > button') + await page.wait_for_load_state('load', timeout=5000) + except TimeoutError: + logger.info('[ERROR]: timeout when waiting for google drive login page to load!') + return + + else: + raise NotImplementedError + + return browser, context + + def _update_browse_history_setup(self, **config): + cache_path = os.path.join(self.cache_dir, "history_new.sqlite") + db_url = "https://drive.usercontent.google.com/u/0/uc?id=1Lv74QkJYDWVX0RIgg0Co-DUcoYpVL0oX&export=download" # google drive + if not os.path.exists(cache_path): + max_retries = 3 + downloaded = False + e = None + for i in range(max_retries): + try: + response = requests.get(db_url, stream=True) + response.raise_for_status() + + with open(cache_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + logger.info("File downloaded successfully") + downloaded = True + break + + except requests.RequestException as e: + logger.error( + f"Failed to download {db_url} caused by {e}. Retrying... ({max_retries - i - 1} attempts left)") + if not downloaded: + raise requests.RequestException(f"Failed to download {db_url}. No retries left. Error: {e}") + else: + logger.info("File already exists in cache directory") + # copy a new history file in the tmp folder + db_path = cache_path + + history = config['history'] + + for history_item in history: + url = history_item['url'] + title = history_item['title'] + visit_time = datetime.now() - timedelta(seconds=history_item['visit_time_from_now_in_seconds']) + + # Chrome use ms from 1601-01-01 as timestamp + epoch_start = datetime(1601, 1, 1) + chrome_timestamp = int((visit_time - epoch_start).total_seconds() * 1000000) + + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + cursor.execute(''' + INSERT INTO urls (url, title, visit_count, typed_count, last_visit_time, hidden) + VALUES (?, ?, ?, ?, ?, ?) + ''', (url, title, 1, 0, chrome_timestamp, 0)) + + url_id = cursor.lastrowid + + cursor.execute(''' + INSERT INTO visits (url, visit_time, from_visit, transition, segment_id, visit_duration) + VALUES (?, ?, ?, ?, ?, ?) + ''', (url_id, chrome_timestamp, 0, 805306368, 0, 0)) + + conn.commit() + conn.close() + + logger.info('Fake browsing history added successfully.') + + controller = PythonController(self.vm_ip, self.server_port) + + # get the path of the history file according to the platform + os_type = controller.get_vm_platform() + + if os_type == 'Windows': + chrome_history_path = controller.execute_python_command( + """import os; print(os.path.join(os.getenv('USERPROFILE'), "AppData", "Local", "Google", "Chrome", "User Data", "Default", "History"))""")[ + 'output'].strip() + elif os_type == 'Darwin': + chrome_history_path = controller.execute_python_command( + """import os; print(os.path.join(os.getenv('HOME'), "Library", "Application Support", "Google", "Chrome", "Default", "History"))""")[ + 'output'].strip() + elif os_type == 'Linux': + if "arm" in platform.machine(): + chrome_history_path = controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'snap', 'chromium', 'common', 'chromium', 'Default', 'History'))")[ + 'output'].strip() + else: + chrome_history_path = controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), '.config', 'google-chrome', 'Default', 'History'))")[ + 'output'].strip() + else: + raise Exception('Unsupported operating system') + + form = MultipartEncoder({ + "file_path": chrome_history_path, + "file_data": (os.path.basename(chrome_history_path), open(db_path, "rb")) + }) + headers = {"Content-Type": form.content_type} + logger.debug(form.content_type) + + # send request to server to upload file + try: + logger.debug("REQUEST ADDRESS: %s", self.http_server + "/setup" + "/upload") + response = requests.post(self.http_server + "/setup" + "/upload", headers=headers, data=form) + if response.status_code == 200: + logger.info("Command executed successfully: %s", response.text) + else: + logger.error("Failed to upload file. Status code: %s", response.text) + except requests.exceptions.RequestException as e: + logger.error("An error occurred while trying to send the request: %s", e) + + self._execute_setup(["sudo chown -R user:user /home/user/.config/google-chrome/Default/History"], shell=True) diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/desktop_env.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/desktop_env.py new file mode 100644 index 000000000..f35ebc408 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/desktop_env.py @@ -0,0 +1,616 @@ +from __future__ import annotations + +import logging +import os +import time +from typing import Callable, Any, Optional, Tuple +from typing import List, Dict, Union +import asyncio +import gymnasium as gym + +from skyrl_agent.tasks.osworld.desktop_env.controllers.python import PythonController +from skyrl_agent.tasks.osworld.desktop_env.controllers.setup import SetupController +from skyrl_agent.tasks.osworld.desktop_env.evaluators import metrics, getters +from skyrl_agent.tasks.osworld.desktop_env.providers import create_vm_manager_and_provider + +logger = logging.getLogger("desktopenv.env") + +Metric = Callable[[Any, Any], float] +Getter = Callable[[gym.Env, Dict[str, Any]], Any] + +MAX_RETRIES = 10 # Maximum retries for environment setup + +class DesktopEnv(gym.Env): + """ + DesktopEnv with OpenAI Gym interface. It provides a desktop environment for setting and evaluating desktop automation tasks. + """ + def __init__( + self, + provider_name: str = "vmware", + region: str = None, + path_to_vm: str = None, + snapshot_name: str = "init_state", + action_space: str = "computer_13", + cache_dir: str = "cache", + screen_size: Tuple[int] = (1920, 1080), + headless: bool = False, + require_a11y_tree: bool = True, + require_terminal: bool = False, + os_type: str = "Ubuntu", + enable_proxy: bool = False, + env_id: int = 0, + ): + """ + Args: + provider_name (str): virtualization provider name, default to "vmware" + region (str): the region for allocate machines, work for cloud services, default to "us-east-1" + path_to_vm (str): path to .vmx file + snapshot_name (str): snapshot name to revert to, default to "init_state" + action_space (str): "computer_13" | "pyautogui" + cache_dir (str): cache directory to cache task-related stuffs like + reference file for evaluation + screen_size (Tuple[int]): screen size of the VM + headless (bool): whether to run the VM in headless mode + require_a11y_tree (bool): whether to require accessibility tree + require_terminal (bool): whether to require terminal output + os_type (str): operating system type, default to "Ubuntu" + enable_proxy (bool): whether to enable proxy support, default to False + env_id (int): environment id for docker provider, default to 0 + """ + # Initialize VM manager and vitualization provider + self.env_id = env_id + self.region = region + self.provider_name = provider_name + self.enable_proxy = enable_proxy # Store proxy enablement setting + + # Default + self.server_port = 5000 + self.chromium_port = 9222 + self.vnc_port = 8006 + self.vlc_port = 8080 + + # Initialize with default (no proxy) provider + self.current_use_proxy = False + if provider_name == "docker": + self.manager, self.provider = create_vm_manager_and_provider(provider_name, region, use_proxy=False, env_id=self.env_id) + else: + self.manager, self.provider = create_vm_manager_and_provider(provider_name, region, use_proxy=False) + + self.os_type = os_type + + # Track whether environment has been used (step/setup) to optimize snapshot revert + # docker, aws, gcp, azure are always unused as the emulator starts from a clean state + # vmware, virtualbox are always used as the emulator starts from a dirty state + if self.provider_name in {"docker", "aws", "gcp", "azure"}: + self.is_environment_used = False + elif self.provider_name in {"vmware", "virtualbox"}: + self.is_environment_used = True + else: + raise ValueError(f"Invalid provider name: {self.provider_name}") + + # Initialize environment variables + if path_to_vm: + self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) \ + if provider_name in {"vmware", "virtualbox"} else path_to_vm + else: + + self.path_to_vm = self.manager.get_vm_path(os_type=self.os_type, region=region) + try: + self.snapshot_name = snapshot_name + self.cache_dir_base: str = cache_dir + # todo: add the logic to get the screen size from the VM + self.headless = headless + self.require_a11y_tree = require_a11y_tree + self.require_terminal = require_terminal + + # Initialize emulator and controller + if provider_name != "docker": # Check if this is applicable to other VM providers + logger.info("Initializing...") + self._start_emulator() + + # mode: human or machine + self.instruction = None + assert action_space in ["computer_13", "pyautogui"] + self.action_space = action_space # todo: refactor it to the ActType + + # episodic stuffs, like counters, will be updated or reset + # when calling self.reset() + self._traj_no: int = -1 + self._step_no: int = 0 + self.action_history: List[Dict[str, any]] = [] + except Exception as e: + logger.error(f"Failed to initialize DesktopEnv: {e}") + # If initialization fails, we should clean up the VM + try: + self.close() + self.manager.delete_vm(self.path_to_vm, self.region) + logger.info(f"Cleaned up VM {self.path_to_vm}.") + except Exception as cleanup_error: + logger.error(f"Failed to clean up VM {self.path_to_vm}: {cleanup_error}") + raise + + def _start_emulator(self): + # Power on the virtual machine + self.provider.start_emulator(self.path_to_vm, self.headless, self.os_type) + + # Get the ip from the virtual machine, and setup the controller + vm_ip_ports = self.provider.get_ip_address(self.path_to_vm).split(':') + self.vm_ip = vm_ip_ports[0] + if len(vm_ip_ports) > 1: + self.server_port = int(vm_ip_ports[1]) + self.chromium_port = int(vm_ip_ports[2]) + self.vnc_port = int(vm_ip_ports[3]) + self.vlc_port = int(vm_ip_ports[4]) + self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port) + self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, vlc_port=self.vlc_port, cache_dir=self.cache_dir_base) + + async def _start_emulator_async(self): + """ + Async version of _start_emulator that uses async provider methods when available. + Falls back to running blocking operations in thread pool. + """ + # Check if provider has async start_emulator method + if hasattr(self.provider, 'start_emulator_async'): + await self.provider.start_emulator_async(self.path_to_vm, self.headless, self.os_type) + else: + # Fallback to running blocking start_emulator in thread pool + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, self.provider.start_emulator, self.path_to_vm, self.headless, self.os_type) + + # Get the ip from the virtual machine, and setup the controller + # These operations are typically fast, so run in thread pool + loop = asyncio.get_event_loop() + vm_ip_ports = await loop.run_in_executor(None, lambda: self.provider.get_ip_address(self.path_to_vm).split(':')) + + self.vm_ip = vm_ip_ports[0] + if len(vm_ip_ports) > 1: + self.server_port = int(vm_ip_ports[1]) + self.chromium_port = int(vm_ip_ports[2]) + self.vnc_port = int(vm_ip_ports[3]) + self.vlc_port = int(vm_ip_ports[4]) + + # Create controllers (these are typically fast initialization operations) + self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port) + self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, vlc_port=self.vlc_port, cache_dir=self.cache_dir_base) + + def _revert_to_snapshot(self): + # Revert to certain snapshot of the virtual machine, and refresh the path to vm and ip of vm + # due to the fact it could be changed when implemented by cloud services + path_to_vm = self.provider.revert_to_snapshot(self.path_to_vm, self.snapshot_name) + if path_to_vm and not path_to_vm == self.path_to_vm: + # path_to_vm has to be a new path + + self.manager.delete_vm(self.path_to_vm, self.region) + self.manager.add_vm(path_to_vm, self.region) + self.manager.occupy_vm(path_to_vm, os.getpid(), self.region) + self.path_to_vm = path_to_vm + + def _save_state(self, snapshot_name=None): + # Save the current virtual machine state to a certain snapshot name + self.provider.save_state(self.path_to_vm, snapshot_name) + + def close(self): + # Close (release) the virtual machine + self.provider.stop_emulator(self.path_to_vm) + + async def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]: + + # Reset to certain task in OSWorld + logger.info("Resetting environment...") + logger.info("Switching task...") + logger.info("Setting counters...") + self._traj_no += 1 + self._step_no = 0 + self.action_history.clear() + + for attempt in range(MAX_RETRIES): + # Check and handle proxy requirement changes BEFORE starting emulator + if task_config is not None: + # Only consider task proxy requirement if proxy is enabled at system level + task_use_proxy = task_config.get("proxy", False) and self.enable_proxy + if not self.enable_proxy and task_config.get("proxy", False): + logger.info("Task requires proxy but proxy is disabled at system level, ignoring proxy requirement.") + + if task_use_proxy != self.current_use_proxy: + logger.info(f"Task proxy requirement changed: {self.current_use_proxy} -> {task_use_proxy}") + + # Close current provider if it exists + if hasattr(self, 'provider') and self.provider: + try: + self.provider.stop_emulator(self.path_to_vm) + except Exception as e: + logger.warning(f"Failed to stop current provider: {e}") + + # Create new provider with appropriate proxy setting + self.current_use_proxy = task_use_proxy + if self.provider_name == "docker": + self.manager, self.provider = create_vm_manager_and_provider( + self.provider_name, + self.region, + use_proxy=task_use_proxy, + env_id=self.env_id + ) + else: + self.manager, self.provider = create_vm_manager_and_provider( + self.provider_name, + self.region, + use_proxy=task_use_proxy + ) + + if task_use_proxy: + logger.info("Using proxy-enabled AWS provider.") + else: + logger.info("Using regular AWS provider.") + + + # Only revert to snapshot if environment has been used (step/setup) + # This optimization is especially important for cloud providers like AWS + # where unnecessary snapshot operations are costly and time-consuming + if self.is_environment_used: + logger.info("Environment has been used, reverting to snapshot {}...".format(self.snapshot_name)) + self._revert_to_snapshot() + logger.info("Starting emulator...") + await self._start_emulator_async() + logger.info("Emulator started.") + # Reset the usage flag after reverting + self.is_environment_used = False + else: + logger.info("Environment is clean, skipping snapshot revert (provider: {}).".format(self.provider_name)) + + if task_config is not None: + if task_config.get("proxy", False) and self.enable_proxy: + # If using proxy and proxy is enabled, set up the proxy configuration + self.setup_controller._proxy_setup() + self._set_task_info(task_config) + self.setup_controller.reset_cache_dir(self.cache_dir) + logger.info("Setting up environment...") + success = await self.setup_controller.setup(self.config, task_config.get("proxy", False) and self.enable_proxy) + if success: + # Mark environment as used when setup is successfully executed + if self.config: # Only mark as used if there were actual setup operations + self.is_environment_used = True + break + else: + logger.error( + "Environment setup failed, retrying (%d/%d)...", + attempt + 1, + MAX_RETRIES, + ) + await asyncio.sleep(5) + else: + break + + logger.info("Environment setup complete.") + + observation = self._get_obs() + return observation + + def _get_obs(self): + # We provide screenshot, accessibility_tree (optional), terminal (optional), and instruction. + # can be customized and scaled + return { + "screenshot": self.controller.get_screenshot(), + "accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None, + "terminal": self.controller.get_terminal_output() if self.require_terminal else None, + "instruction": self.instruction + } + + async def _get_obs_async(self): + """ + Async version of _get_obs that runs controller operations concurrently. + Significantly faster when multiple I/O operations are needed. + """ + loop = asyncio.get_event_loop() + + # Prepare tasks for concurrent execution + tasks = [] + task_names = [] + + # Always get screenshot + tasks.append(loop.run_in_executor(None, self.controller.get_screenshot)) + task_names.append("screenshot") + + # Conditionally get accessibility tree + if self.require_a11y_tree: + tasks.append(loop.run_in_executor(None, self.controller.get_accessibility_tree)) + task_names.append("accessibility_tree") + + # Conditionally get terminal output + if self.require_terminal: + tasks.append(loop.run_in_executor(None, self.controller.get_terminal_output)) + task_names.append("terminal") + + # Execute all I/O operations concurrently + results = await asyncio.gather(*tasks) + + # Build observation dictionary + obs = {"instruction": self.instruction} + + # Map results back to observation keys + for i, name in enumerate(task_names): + obs[name] = results[i] + + # Set optional fields to None if not requested + if not self.require_a11y_tree: + obs["accessibility_tree"] = None + if not self.require_terminal: + obs["terminal"] = None + + return obs + + @property + def vm_platform(self): + return self.controller.get_vm_platform() + + @property + def vm_screen_size(self): + return self.controller.get_vm_screen_size() + + def _set_task_info(self, task_config: Dict[str, Any]): + """Set task info (proxy logic is handled in reset method)""" + self.task_id: str = task_config["id"] + self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id) + os.makedirs(self.cache_dir, exist_ok=True) + self.instruction = task_config["instruction"] + self.config = task_config["config"] if "config" in task_config else [] + + self._set_evaluator_info(task_config) + + def _set_evaluator_info(self, task_config: Dict[str, Any]): + """Set evaluator information from task config""" + # evaluator dict + # func -> metric function string, or list of metric function strings + # conj -> conjunction of multiple metrics if func is a list with length > 1, "and"/"or" + # result -> result getter config, or list of result getter configs + # expected (optional) -> expected getter config, or list of expected getter configs + # options (optional) -> metric options, or list of metric options + # if func is a str list, then result, expected (if exists), options (if exists) should also be lists of the same length + # even if one of the metrics does not need expected or options field, it should be included in the list with None + self.evaluator = task_config["evaluator"] + self.metric: Metric = [getattr(metrics, func) for func in self.evaluator["func"]] \ + if isinstance(self.evaluator["func"], list) \ + else getattr(metrics, self.evaluator["func"]) + self.metric_conj: str = self.evaluator.get("conj", "and") # take conjunction of multiple metrics + if "result" in self.evaluator and len(self.evaluator["result"]) > 0: + self.result_getter: Getter = [getattr(getters, "get_{:}".format(res["type"])) for res in + self.evaluator["result"]] \ + if isinstance(self.evaluator["result"], list) \ + else getattr(getters, "get_{:}".format(self.evaluator["result"]["type"])) + else: + self.result_getter = [None] * len(self.metric) \ + if isinstance(self.metric, list) \ + else None + + if "expected" in self.evaluator and len(self.evaluator["expected"]) > 0: + self.expected_getter: Getter = [getattr(getters, "get_{:}".format(exp["type"])) if exp else None for exp in + self.evaluator["expected"]] \ + if isinstance(self.evaluator["expected"], list) \ + else getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"])) + else: + self.expected_getter = [None] * len(self.metric) \ + if isinstance(self.metric, list) \ + else None + self.metric_options: Union[List[Dict[str, Any]], Dict[str, Any]] = [opt if opt else {} for opt in + self.evaluator["options"]] \ + if isinstance(self.evaluator.get("options", {}), list) \ + else self.evaluator["options"] \ + if "options" in self.evaluator \ + else [{}] * len(self.metric) \ + if isinstance(self.metric, list) \ + else {} + + assert (not isinstance(self.evaluator["func"], list) + or (len(self.metric) == len(self.result_getter) == len(self.expected_getter) == len( + self.metric_options))) + + def step(self, action, pause=2): + self._step_no += 1 + self.action_history.append(action) + + # Mark environment as used when step is called + self.is_environment_used = True + + reward = 0 # todo: Define reward calculation for each example + done = False # todo: Define episode termination condition for each example + info = {} + + # handle the special actions + if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action['action_type'] in ['WAIT', 'FAIL', 'DONE']): + if action == 'WAIT': + time.sleep(pause) + elif action == 'FAIL': + done = True + info = {"fail": True} + elif action == 'DONE': + done = True + info = {"done": True} + + if self.action_space == "computer_13": + # the set of all possible actions defined in the action representation + self.controller.execute_action(action) + elif self.action_space == "pyautogui": + if action in ['WAIT', 'FAIL', 'DONE']: + self.controller.execute_action(action) + else: + # the set of all possible python commands insides `pyautogui` + self.controller.execute_python_command(action) + + time.sleep(pause) + observation = self._get_obs() + + return observation, reward, done, info + + async def step_async(self, action, pause=2): + """ + Async version of step method that uses non-blocking operations. + Provides significant performance improvements for observation gathering. + """ + self._step_no += 1 + self.action_history.append(action) + + # Mark environment as used when step is called + self.is_environment_used = True + + reward = 0 # todo: Define reward calculation for each example + done = False # todo: Define episode termination condition for each example + info = {} + + # handle the special actions + if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action['action_type'] in ['WAIT', 'FAIL', 'DONE']): + if action == 'WAIT': + await asyncio.sleep(pause) # Non-blocking sleep + elif action == 'FAIL': + done = True + info = {"fail": True} + elif action == 'DONE': + done = True + info = {"done": True} + + # Execute controller actions in thread pool to avoid blocking + loop = asyncio.get_event_loop() + if self.action_space == "computer_13": + # the set of all possible actions defined in the action representation + await loop.run_in_executor(None, self.controller.execute_action, action) + elif self.action_space == "pyautogui": + if action in ['WAIT', 'FAIL', 'DONE']: + await loop.run_in_executor(None, self.controller.execute_action, action) + else: + # the set of all possible python commands insides `pyautogui` + await loop.run_in_executor(None, self.controller.execute_python_command, action) + + # Non-blocking pause + await asyncio.sleep(pause) + + # Use async observation gathering for better performance + observation = await self._get_obs_async() + + if observation["accessibility_tree"] is None: + raise ValueError("Accessibility tree is unavailable") + + return observation, reward, done, info + + async def evaluate(self): + """ + Evaluate whether the task is successfully completed. + """ + + postconfig = self.evaluator.get("postconfig", []) + await self.setup_controller.setup(postconfig) + # Mark environment as used if there were postconfig setup operations + if postconfig: + self.is_environment_used = True + + if self.evaluator['func'] == "infeasible": + if len(self.action_history) > 0 and self.action_history[-1] == "FAIL": + return 1 + else: + return 0 + else: + if len(self.action_history) > 0 and self.action_history[-1] == "FAIL": + return 0 + + if type(self.metric) == list: + # Multiple metrics to evaluate whether the task is successfully completed + results = [] + assert len(self.metric) == len(self.result_getter), "The number of metrics and result getters must be the same" + if "expected" in self.evaluator: + assert len(self.metric) == len(self.expected_getter), "The number of metrics and expected getters must be the same" + for idx, metric in enumerate(self.metric): + try: + config = self.evaluator["result"][idx] + result_state = self.result_getter[idx](self, config) + + # Handle case where getter returns None + if result_state is None: + logger.warning(f"Getter returned None for metric {idx}. Skipping evaluation.") + if self.metric_conj == 'and': + return 0 + else: + continue + + except FileNotFoundError: + logger.error("File not found!") + if self.metric_conj == 'and': + return 0 + except Exception as e: + logger.error(f"Error in result getter for metric {idx}: {e}") + if self.metric_conj == 'and': + return 0 + else: + continue + + try: + if "expected" in self.evaluator and self.expected_getter and self.evaluator["expected"]: + expected_state = self.expected_getter[idx](self, self.evaluator["expected"][idx]) + + # Handle case where expected getter returns None + if expected_state is None: + logger.warning(f"Expected getter returned None for metric {idx}. Using result-only evaluation.") + metric_result = metric(result_state, **self.metric_options[idx]) + else: + metric_result = metric(result_state, expected_state, **self.metric_options[idx]) + else: + metric_result = metric(result_state, **self.metric_options[idx]) + except Exception as e: + logger.error(f"Error in metric evaluation for metric {idx}: {e}") + if self.metric_conj == 'and': + return 0 + else: + continue + + # Handle case where metric is async + if hasattr(metric_result, '__await__'): + metric_result = await metric_result + + if self.metric_conj == 'and' and float(metric_result) == 0.0: + return 0 + elif self.metric_conj == 'or' and float(metric_result) == 1.0: + return 1 + else: + results.append(metric_result) + + return sum(results) / len(results) if self.metric_conj == 'and' else max(results) + else: + # Single metric to evaluate whether the task is successfully completed + try: + result_state = self.result_getter(self, self.evaluator["result"]) + + # Handle case where getter returns None + if result_state is None: + logger.warning("Getter returned None. Cannot evaluate task.") + return 0 + + except FileNotFoundError: + logger.error("File not found!") + return 0 + except Exception as e: + logger.error(f"Error in result getter: {e}") + return 0 + + try: + if "expected" in self.evaluator and self.expected_getter and self.evaluator["expected"]: + expected_state = self.expected_getter(self, self.evaluator["expected"]) + + # Handle case where expected getter returns None + if expected_state is None: + logger.warning("Expected getter returned None. Using result-only evaluation.") + metric_result = self.metric(result_state, **self.metric_options) + else: + metric_result = self.metric(result_state, expected_state, **self.metric_options) + else: + metric_result = self.metric(result_state, **self.metric_options) + except Exception as e: + logger.error(f"Error in metric evaluation: {e}") + return 0 + + # Handle case where metric is async + if hasattr(metric_result, '__await__'): + metric_result = await metric_result + + return metric_result + + def render(self, mode='rgb_array'): + if mode == 'rgb_array': + return self.controller.get_screenshot() + else: + raise ValueError('Unsupported render mode: {}'.format(mode)) diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/README.md b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/README.md new file mode 100644 index 000000000..3100ed27d --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/README.md @@ -0,0 +1,224 @@ +# Evaluator Setup Details +Setup scaffolding for the evaluators in the desktop environment for those who want to know the details of the evaluator setup for customized evaluation and extension + +## Overall +Inside the virtual machine, disable the system crash report by: +``` +sudo vim /etc/default/apport +``` +and then change the `enabled` to `0`. + +## VSCode +todo + +## LibreOffice +For LibreOffice, please enter into the app first, and then enable the no pop-up when 'ctrl + s'. + +## LibreOffice Press +### Setting Up the python-pptx Library +```shell +pip install python-pptx +``` + +## LibreOffice Writer + +### Setting Up the python-docx and odfpy Library +```shell +pip install python-docx +pip install odfpy +``` + +## LibreOffice Calc + +### Required Libraries + +``` +openpyxl +pandas +lxml +xmltodict +``` + +### How to Generate CSV from XLSX + +```sh +libreoffice --convert-to "csv:Text - txt - csv (StarCalc):44,34,UTF8,,,,false,true,true,false,false,1" --out-dir /home/user /home/user/abc.xlsx +``` + +This command will generate `abc-Sheet1.csv` under `/home/user`. The last `1` in +the conversion options indicates the sheet number (starting from 1) to export. +Detailed usage should be referred to at [CSV Filter +Options](https://help.libreoffice.org/latest/ro/text/shared/guide/csv_params.html). + +Refer to `libreoffice_calc/21df9241-f8d7-4509-b7f1-37e501a823f7.json` for an +example. + +### About `compare_table` + +Evaluation to xlsx files mainly relies on `compare_table`. It accepts two file +names and a list of rules defined as `options`. Refer to +`libreoffice_calc/21df9241-f8d7-4509-b7f1-37e501a823f7.json` for an example. + +In each rule, there is a required field `type`. The supported types are defined +in `compare_table` function. The most common two are `sheet_data` and +`sheet_print`. `sheet_data` compares the internal cell values through pandoc, +while `sheet_print` compares the shown cell values through csv. A csv should be +generated and downloaded for `sheet_print`. See the previous section and +example in `libreoffice_calc/21df9241-f8d7-4509-b7f1-37e501a823f7.json`. + +Other fields in a rule are described for each evaluation type in +`compare_table` function. `sheet_idx0` (or `sheet_idx1`, `sheet_idx`) is a +common field to indicate which sheet is to extracted from the workbook. If an +integer i is given, then it extracts the i-th sheet from result xlsx (i starts +from 0). If a string is given, it should be preceded with "RI", "RN", "EI", or +"EN". "R" indicates to extract from result xlsx while "E" indicates to extract +from expected (golden) xlsx. "I" indicates a sheet number (starting from 0) and +"N" indicates a sheet name (usually, they're like "Sheet1", "Sheet2", ...). + +Some rules use a atructure like `{"method": "eq", "ref": "abc"}`. These rules +are checked through `utils._match_value_to_rule` function. Check it for the +implemented matching methods. + +## Chrome + +### Starting Chrome with Remote Debugging for Python + +To enable remote debugging in Chrome, which allows tools like Playwright for Python to connect to and control an existing Chrome instance, follow these steps: + +#### Manually Enabling Remote Debugging in Chrome + +1. **Locate the Chrome Shortcut**: + - Find the Chrome shortcut that you usually use to open the browser. This could be on your desktop, start menu, or taskbar. + +2. **Edit Shortcut Properties**: + - Right-click on the Chrome shortcut and select `Properties`. + +3. **Modify the Target Field**: + - In the `Target` field, add `--remote-debugging-port=9222` at the end of the path. Ensure there is a space between the path and the flag you add. + - It should look something like this: `"C:\Path\To\Chrome.exe" --remote-debugging-port=9222`. + +4. **Apply and Close**: + - Click `Apply` and then `OK` to close the dialog. + +5. **Start Chrome**: + - Use this modified shortcut to start Chrome. Chrome will now start with remote debugging enabled on port 9222. + +6. **Confirm Remote Debugging**: + - Open a browser and navigate to `http://localhost:9222`. If you see a webpage with information about active tabs, remote debugging is working. + +--- + +### Setting Up Playwright for Python + +Playwright for Python is a browser automation library to control Chromium, Firefox, and WebKit with a single API. + +#### Installing Playwright + +- Ensure you have Python installed on your system. If not, download and install it from the [Python official website](https://www.python.org/). + +- Install Playwright using pip (Python Package Installer). Open a command line or terminal and run: + + ```bash + pip install playwright + ``` + +- After installing Playwright, you need to run the install command to download the necessary browser binaries: + + ```bash + playwright install + ``` + +#### Writing a Playwright Script in Python + +- Create a Python file for your automation script. + +- Import the Playwright module at the beginning of your script: + + ```python + from playwright.sync_api import sync_playwright + ``` + +- You can now use Playwright's API to control browsers. + +#### Example Playwright Script + +Here is a simple example to open a page using Playwright: + +```python +from playwright.sync_api import sync_playwright + +def run(playwright): + browser = playwright.chromium.launch() + page = browser.new_page() + page.goto("http://example.com") + ## other actions... + browser.close() + +with sync_playwright() as playwright: + run(playwright) +``` + +- This script launches Chromium, opens a new page, navigates to `example.com`, and then closes the browser. + +#### Troubleshooting + +- If you encounter issues with Playwright, ensure that your Python environment is correctly set up and that you have installed Playwright and its dependencies correctly. +- For detailed documentation, visit the [Playwright for Python Documentation](https://playwright.dev/python/docs/intro). + + +## VLC Media Player + +### Bugs fix +One thing on Ubuntu need to do, enter into the `meida`>`convert/save`>select files>`convert/save` +Then enter the profile of `Audio - MP3`, change the profile for mp3, section audiocodec from "MP3" to "MPEG Audio" +Otherwise the mp3 file will be created but with 0 bytes. It's a bug of VLC. + +### Setting Up VLC's HTTP Interface + +To enable and use the HTTP interface in VLC Media Player for remote control and status checks, follow these steps: + +#### 1. Open VLC Preferences + +- Open VLC Media Player. +- Go to `Tools` > `Preferences` from the menu. + +#### 2. Show All Settings + +- In the Preferences window, at the bottom left corner, select `All` under `Show settings` to display advanced settings. + +#### 3. Enable Main Interfaces + +- In the advanced preferences, expand the `Interface` section. +- Click on `Main interfaces`. +- Check the box for `Web` to enable the HTTP interface. + +#### 4. Configure Lua HTTP + +- Expand the `Main interfaces` node and select `Lua`. +- Under `Lua HTTP`, set a password `password` in the `Lua HTTP` section. This password will be required to access the HTTP interface. + +#### 5. Save and Restart VLC + +- Click `Save` to apply the changes. +- Restart VLC Media Player for the changes to take effect. + +#### 6. Accessing the HTTP Interface + +- Open a web browser and go to `http://localhost:8080`. +- You will be prompted for a password. Enter the password you set in the Lua HTTP settings. +- Once logged in, you will have access to VLC's HTTP interface for remote control. + +#### Packages +```bash + +pip install opencv-python-headless Pillow imagehash +``` + +#### Troubleshooting + +- If you cannot access the HTTP interface, check if your firewall or security software is blocking the connection. +- Ensure VLC is running and the correct port (default is 8080) is being used. +- If the port is in use by another application, you may change the port number in VLC's settings. + +## GIMP +Click on the "Keep" of the image loading pop-up. diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/__init__.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/__init__.py new file mode 100644 index 000000000..c88feff0b --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/__init__.py @@ -0,0 +1,5 @@ +#from .table import compare_table + +#eval_funcs = { + #"compare_table(expected, actual)": compare_table +#} diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/__init__.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/__init__.py new file mode 100644 index 000000000..a035e2759 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/__init__.py @@ -0,0 +1,39 @@ +from .chrome import ( + get_default_search_engine, + get_cookie_data, + get_bookmarks, + get_open_tabs_info, + get_pdf_from_url, + get_shortcuts_on_desktop, + get_history, + get_page_info, + get_enabled_experiments, + get_chrome_language, + get_chrome_font_size, + get_profile_name, + get_number_of_search_results, + get_googledrive_file, + get_active_tab_info, + get_enable_do_not_track, + get_enable_enhanced_safety_browsing, + get_new_startup_page, + get_find_unpacked_extension_path, + get_data_delete_automacally, + get_active_tab_html_parse, + get_active_tab_url_parse, + get_gotoRecreationPage_and_get_html_content, + get_url_dashPart, + get_active_url_from_accessTree, + get_find_installed_extension_name, + get_info_from_website +) +from .file import get_cloud_file, get_vm_file, get_cache_file, get_content_from_vm_file +from .general import get_vm_command_line, get_vm_terminal_output, get_vm_command_error +from .gimp import get_gimp_config_file +from .impress import get_audio_in_slide, get_background_image_in_slide +from .info import get_vm_screen_size, get_vm_window_size, get_vm_wallpaper, get_list_directory +from .misc import get_rule, get_accessibility_tree, get_rule_relativeTime, get_time_diff_range +from .replay import get_replay +from .vlc import get_vlc_playing_info, get_vlc_config, get_default_video_player +from .vscode import get_vscode_config +from .calc import get_conference_city_in_order diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/calc.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/calc.py new file mode 100644 index 000000000..81e11752a --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/calc.py @@ -0,0 +1,15 @@ +import csv + + +# I want to write a function, reads a csv file, and get all the contents in the third column in the order of rows +def get_conference_city_in_order(env, config): + # read the csv file + csv_path = config['csv_path'] + print(f"Reading csv file from {csv_path}") + with open(csv_path, 'r') as f: + reader = csv.reader(f) + # skip the header row + next(reader) + # get the third column in the order of rows + conference_city_list = [row[2] for row in reader] + return conference_city_list diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/chrome.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/chrome.py new file mode 100644 index 000000000..b2203c8c9 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/chrome.py @@ -0,0 +1,1392 @@ +import asyncio +import json +import logging +import os +import platform +import sqlite3 +import time +from urllib.parse import unquote +from typing import Dict, Any, List +from urllib.parse import urlparse, parse_qs + +import lxml.etree +import requests +from lxml.cssselect import CSSSelector +from lxml.etree import _Element +from playwright.async_api import async_playwright, expect +from pydrive.auth import GoogleAuth +from pydrive.drive import GoogleDrive, GoogleDriveFileList, GoogleDriveFile + +_accessibility_ns_map = { + "st": "uri:deskat:state.at-spi.gnome.org", + "attr": "uri:deskat:attributes.at-spi.gnome.org", + "cp": "uri:deskat:component.at-spi.gnome.org", + "doc": "uri:deskat:document.at-spi.gnome.org", + "docattr": "uri:deskat:attributes.document.at-spi.gnome.org", + "txt": "uri:deskat:text.at-spi.gnome.org", + "val": "uri:deskat:value.at-spi.gnome.org", + "act": "uri:deskat:action.at-spi.gnome.org" +} + +logger = logging.getLogger("desktopenv.getters.chrome") + +""" +WARNING: +1. Functions from this script assume that no account is registered on Chrome, otherwise the default file path needs to be changed. +2. The functions are not tested on Windows and Mac, but they should work. +""" + + +async def get_info_from_website(env, config: Dict[Any, Any]) -> Any: + """ Get information from a website. Especially useful when the information may be updated through time. + Args: + env (Any): The environment object. + config (Dict[Any, Any]): The configuration dictionary. + - url (str): The URL of the website to visit + - infos (List[Dict[str, str]]): The list of information to be extracted from the website. Each dictionary contains: + - action (str): chosen from 'inner_text', 'attribute', 'click_and_inner_text', 'click_and_attribute', etc., concretely, + - inner_text: extract the inner text of the element specified by the selector + - attribute: extract the attribute of the element specified by the selector + - click_and_inner_text: click elements following the selector and then extract the inner text of the last element + - click_and_attribute: click elements following the selector and then extract the attribute of the last element + - selector (Union[str, List[str]]): The CSS selector(s) of the element(s) to be extracted. + - attribute (str): optional for 'attribute' and 'click_and_attribute', the attribute to be extracted. + - backups (Any): The backup information to be returned if the extraction fails. + """ + try: + host = env.vm_ip + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port + remote_debugging_url = f"http://{host}:{port}" + backend_url = f"http://{host}:{server_port}" + use_proxy = env.current_use_proxy + async with async_playwright() as p: + # connect to remote Chrome instance + try: + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + except Exception as e: + # If the connection fails (e.g., the agent close the browser instance), start a new browser instance + app = 'chromium' if 'arm' in platform.machine() else 'google-chrome' + command = [ + app, + "--remote-debugging-port=1337" + ] + if use_proxy: + command.append(f"--proxy-server=127.0.0.1:18888") + payload = json.dumps({"command": command, "shell": False}) + headers = {"Content-Type": "application/json"} + #requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) + requests.post(backend_url + "/setup" + "/launch", headers=headers, data=payload) + await asyncio.sleep(5) + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + + page = await browser.contexts[0].new_page() + await page.goto(config["url"]) + await page.wait_for_load_state('load') + infos = [] + for info_dict in config.get('infos', []): + if page.url != config["url"]: + await page.goto(config["url"]) + await page.wait_for_load_state('load') + action = info_dict.get('action', 'inner_text') + if action == "inner_text": + ele = await page.wait_for_selector(info_dict['selector'], state='attached', timeout=10000) + infos.append(await ele.inner_text()) + elif action == "attribute": + ele = await page.wait_for_selector(info_dict['selector'], state='attached', timeout=10000) + infos.append(await ele.get_attribute(info_dict['attribute'])) + elif action == 'click_and_inner_text': + for idx, sel in enumerate(info_dict['selector']): + if idx != len(info_dict['selector']) - 1: + link = await page.wait_for_selector(sel, state='attached', timeout=10000) + await link.click() + await page.wait_for_load_state('load') + else: + ele = await page.wait_for_selector(sel, state='attached', timeout=10000) + infos.append(await ele.inner_text()) + elif action == 'click_and_attribute': + for idx, sel in enumerate(info_dict['selector']): + if idx != len(info_dict['selector']) - 1: + link = await page.wait_for_selector(sel, state='attached', timeout=10000) + await link.click() + await page.wait_for_load_state('load') + else: + ele = await page.wait_for_selector(sel, state='attached') + infos.append(await ele.get_attribute(info_dict['attribute'])) + else: + raise NotImplementedError(f'The action {action} is not supported yet.') + return infos + except Exception as e: + logger.error(f'[ERROR]: failed to obtain information from the website: {config["url"]}. Use backup results instead.') + return config.get('backups', None) + + +# The following ones just need to load info from the files of software, no need to connect to the software +def get_default_search_engine(env, config: Dict[str, str]): + os_type = env.vm_platform + if os_type == 'Windows': + preference_file_path = env.controller.execute_python_command("""import os; print(os.path.join(os.getenv('LOCALAPPDATA'), + 'Google\\Chrome\\User Data\\Default\\Preferences'))""")['output'].strip() + elif os_type == 'Darwin': + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'Library/Application Support/Google/Chrome/Default/Preferences'))")[ + 'output'].strip() + elif os_type == 'Linux': + if "arm" in platform.machine(): + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'snap/chromium/common/chromium/Default/Preferences'))")[ + 'output'].strip() + else: + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), '.config/google-chrome/Default/Preferences'))")[ + 'output'].strip() + else: + raise Exception('Unsupported operating system') + + try: + content = env.controller.get_file(preference_file_path) + if content is None: + logger.warning("Failed to get Chrome preferences file. Returning default search engine.") + return "Google" + + data = json.loads(content) + + # The path within the JSON data to the default search engine might vary + search_engine = data.get('default_search_provider_data', {}).get('template_url_data', {}).get('short_name', + 'Google') + return search_engine + except Exception as e: + logger.error(f"Error: {e}") + return "Google" + + +def get_cookie_data(env, config: Dict[str, str]): + """ + Get the cookies from the Chrome browser. + Assume the cookies are stored in the default location, not encrypted and not large in size. + """ + os_type = env.vm_platform + if os_type == 'Windows': + chrome_cookie_file_path = env.controller.execute_python_command("""import os; print(os.path.join(os.getenv('LOCALAPPDATA'), + 'Google\\Chrome\\User Data\\Default\\Cookies'))""")['output'].strip() + elif os_type == 'Darwin': + chrome_cookie_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'Library/Application Support/Google/Chrome/Default/Cookies'))")[ + 'output'].strip() + elif os_type == 'Linux': + if "arm" in platform.machine(): + chrome_cookie_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'snap/chromium/common/chromium/Default/Cookies'))")[ + 'output'].strip() + else: + chrome_cookie_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), '.config/google-chrome/Default/Cookies'))")[ + 'output'].strip() + else: + raise Exception('Unsupported operating system') + + try: + content = env.controller.get_file(chrome_cookie_file_path) + if content is None: + logger.warning("Failed to get Chrome cookies file. Returning empty list.") + return [] + + _path = os.path.join(env.cache_dir, config["dest"]) + + with open(_path, "wb") as f: + f.write(content) + + conn = sqlite3.connect(_path, timeout=10.0) # 10 second timeout + cursor = conn.cursor() + + # Query to check for OpenAI cookies + cursor.execute("SELECT * FROM cookies") + cookies = cursor.fetchall() + conn.close() + return cookies + except Exception as e: + logger.error(f"Error: {e}") + return None + + +def get_history(env, config: Dict[str, str]): + os_type = env.vm_platform + if os_type == 'Windows': + chrome_history_path = env.controller.execute_python_command( + """import os; print(os.path.join(os.getenv('USERPROFILE'), "AppData", "Local", "Google", "Chrome", "User Data", "Default", "History"))""")[ + 'output'].strip() + elif os_type == 'Darwin': + chrome_history_path = env.controller.execute_python_command( + """import os; print(os.path.join(os.getenv('HOME'), "Library", "Application Support", "Google", "Chrome", "Default", "History"))""")[ + 'output'].strip() + elif os_type == 'Linux': + if "arm" in platform.machine(): + chrome_history_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'snap/chromium/common/chromium/Default/History'))")[ + 'output'].strip() + else: + chrome_history_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), '.config', 'google-chrome', 'Default', 'History'))")[ + 'output'].strip() + else: + raise Exception('Unsupported operating system') + + try: + content = env.controller.get_file(chrome_history_path) + if content is None: + logger.warning("Failed to get Chrome history file. Returning empty list.") + return [] + + _path = os.path.join(env.cache_dir, config["dest"]) + + with open(_path, "wb") as f: + f.write(content) + + conn = sqlite3.connect(_path, timeout=10.0) # 10 second timeout + cursor = conn.cursor() + + # Query to check for OpenAI cookies + cursor.execute("SELECT url, title, last_visit_time FROM urls") + history_items = cursor.fetchall() + conn.close() + return history_items + except Exception as e: + logger.error(f"Error: {e}") + return None + + +def get_enabled_experiments(env, config: Dict[str, str]): + os_type = env.vm_platform + if os_type == 'Windows': + preference_file_path = env.controller.execute_python_command("""import os; print(os.path.join(os.getenv('LOCALAPPDATA'), + 'Google\\Chrome\\User Data\\Local State'))""")[ + 'output'].strip() + elif os_type == 'Darwin': + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'Library/Application Support/Google/Chrome/Local State'))")[ + 'output'].strip() + elif os_type == 'Linux': + if "arm" in platform.machine(): + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'snap/chromium/common/chromium/Local State'))")[ + 'output'].strip() + else: + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), '.config/google-chrome/Local State'))")[ + 'output'].strip() + else: + raise Exception('Unsupported operating system') + + try: + content = env.controller.get_file(preference_file_path) + if content is None: + logger.warning("Failed to get Chrome Local State file. Returning empty experiments list.") + return [] + + data = json.loads(content) + + # The path within the JSON data to the default search engine might vary + enabled_labs_experiments = data.get('browser', {}).get('enabled_labs_experiments', []) + return enabled_labs_experiments + except Exception as e: + logger.error(f"Error: {e}") + return [] + + +def get_profile_name(env, config: Dict[str, str]): + """ + Get the username from the Chrome browser. + Assume the cookies are stored in the default location, not encrypted and not large in size. + """ + os_type = env.vm_platform + if os_type == 'Windows': + preference_file_path = env.controller.execute_python_command("""import os; print(os.path.join(os.getenv('LOCALAPPDATA'), + 'Google\\Chrome\\User Data\\Default\\Preferences'))""")['output'].strip() + elif os_type == 'Darwin': + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'Library/Application Support/Google/Chrome/Default/Preferences'))")[ + 'output'].strip() + elif os_type == 'Linux': + if "arm" in platform.machine(): + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'snap/chromium/common/chromium/Default/Preferences'))")[ + 'output'].strip() + else: + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), '.config/google-chrome/Default/Preferences'))")[ + 'output'].strip() + else: + raise Exception('Unsupported operating system') + + try: + content = env.controller.get_file(preference_file_path) + if content is None: + logger.warning("Failed to get Chrome preferences file. Returning None for profile name.") + return None + + data = json.loads(content) + + # The path within the JSON data to the default search engine might vary + profile_name = data.get('profile', {}).get('name', None) + return profile_name + except Exception as e: + logger.error(f"Error: {e}") + return None + + +def get_chrome_language(env, config: Dict[str, str]): + os_type = env.vm_platform + if os_type == 'Windows': + preference_file_path = env.controller.execute_python_command("""import os; print(os.path.join(os.getenv('LOCALAPPDATA'), + 'Google\\Chrome\\User Data\\Local State'))""")[ + 'output'].strip() + elif os_type == 'Darwin': + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'Library/Application Support/Google/Chrome/Local State'))")[ + 'output'].strip() + elif os_type == 'Linux': + if "arm" in platform.machine(): + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'snap/chromium/common/chromium/Local State'))")[ + 'output'].strip() + else: + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), '.config/google-chrome/Local State'))")[ + 'output'].strip() + else: + raise Exception('Unsupported operating system') + + try: + content = env.controller.get_file(preference_file_path) + data = json.loads(content) + + # The path within the JSON data to the default search engine might vary + enabled_labs_experiments = data.get('intl', {}).get('app_locale', "en-US") + return enabled_labs_experiments + except Exception as e: + logger.error(f"Error: {e}") + return "en-US" + + +def get_chrome_font_size(env, config: Dict[str, str]): + os_type = env.vm_platform + if os_type == 'Windows': + preference_file_path = env.controller.execute_python_command("""import os; print(os.path.join(os.getenv('LOCALAPPDATA'), + 'Google\\Chrome\\User Data\\Default\\Preferences'))""")[ + 'output'].strip() + elif os_type == 'Darwin': + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'Library/Application Support/Google/Chrome/Default/Preferences'))")[ + 'output'].strip() + elif os_type == 'Linux': + if "arm" in platform.machine(): + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'snap/chromium/common/chromium/Default/Preferences'))")[ + 'output'].strip() + else: + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), '.config/google-chrome/Default/Preferences'))")[ + 'output'].strip() + else: + raise Exception('Unsupported operating system') + + try: + content = env.controller.get_file(preference_file_path) + data = json.loads(content) + + # The path within the JSON data to the default search engine might vary + search_engine = data.get('webkit', {}).get('webprefs', { + "default_fixed_font_size": 13, + "default_font_size": 16, + "minimum_font_size": 13 + }) + return search_engine + except Exception as e: + logger.error(f"Error: {e}") + return { + "default_fixed_font_size": 13, + "default_font_size": 16 + } + + +def get_bookmarks(env, config: Dict[str, str]): + os_type = env.vm_platform + if os_type == 'Windows': + preference_file_path = env.controller.execute_python_command("""import os; print(os.path.join(os.getenv('LOCALAPPDATA'), + 'Google\\Chrome\\User Data\\Default\\Bookmarks'))""")['output'].strip() + elif os_type == 'Darwin': + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'Library/Application Support/Google/Chrome/Default/Bookmarks'))")[ + 'output'].strip() + elif os_type == 'Linux': + if "arm" in platform.machine(): + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'snap/chromium/common/chromium/Default/Bookmarks'))")[ + 'output'].strip() + else: + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), '.config/google-chrome/Default/Bookmarks'))")[ + 'output'].strip() + else: + raise Exception('Unsupported operating system') + + content = env.controller.get_file(preference_file_path) + if not content: + return [] + data = json.loads(content) + bookmarks = data.get('roots', {}) + return bookmarks + + +# todo: move this to the main.py +def get_extensions_installed_from_shop(env, config: Dict[str, str]): + """Find the Chrome extensions directory based on the operating system.""" + os_type = env.vm_platform + if os_type == 'Windows': + chrome_extension_dir = env.controller.execute_python_command( + """os.path.expanduser('~') + '\\AppData\\Local\\Google\\Chrome\\User Data\\Default\\Extensions\\'""")[ + 'output'].strip() + elif os_type == 'Darwin': # macOS + chrome_extension_dir = env.controller.execute_python_command( + """os.path.expanduser('~') + '/Library/Application Support/Google/Chrome/Default/Extensions/'""")[ + 'output'].strip() + elif os_type == 'Linux': + if "arm" in platform.machine(): + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'snap/chromium/common/chromium/Default/Extensions/'))")[ + 'output'].strip() + else: + chrome_extension_dir = env.controller.execute_python_command( + """os.path.expanduser('~') + '/.config/google-chrome/Default/Extensions/'""")['output'].strip() + else: + raise Exception('Unsupported operating system') + + manifests = [] + for extension_id in os.listdir(chrome_extension_dir): + extension_path = os.path.join(chrome_extension_dir, extension_id) + if os.path.isdir(extension_path): + # Iterate through version-named subdirectories + for version_dir in os.listdir(extension_path): + version_path = os.path.join(extension_path, version_dir) + manifest_path = os.path.join(version_path, 'manifest.json') + if os.path.isfile(manifest_path): + with open(manifest_path, 'r') as file: + try: + manifest = json.load(file) + manifests.append(manifest) + except json.JSONDecodeError: + logger.error(f"Error reading {manifest_path}") + return manifests + + +# The following ones require Playwright to be installed on the target machine, and the chrome needs to be pre-config on +# port info to allow remote debugging, see README.md for details + +async def get_page_info(env, config: Dict[str, str]): + host = env.vm_ip + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port + url = config["url"] + + remote_debugging_url = f"http://{host}:{port}" + async with async_playwright() as p: + # connect to remote Chrome instance + try: + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + except Exception as e: + # If the connection fails, start a new browser instance + platform.machine() + if "arm" in platform.machine(): + # start a new browser instance if the connection fails + payload = json.dumps({"command": [ + "chromium", + "--remote-debugging-port=1337" + ], "shell": False}) + else: + payload = json.dumps({"command": [ + "google-chrome", + "--remote-debugging-port=1337" + ], "shell": False}) + + headers = {"Content-Type": "application/json"} + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) + await asyncio.sleep(5) + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + + page = await browser.contexts[0].new_page() + await page.goto(url) + + try: + # Wait for the page to finish loading, this prevents the "execution context was destroyed" issue + await page.wait_for_load_state('load') # Wait for the 'load' event to complete + title = await page.title() + url = page.url + page_info = {'title': title, 'url': url, 'content': await page.content()} + except TimeoutError: + # If page loading times out, catch the exception and store the current information in the list + page_info = {'title': 'Load timeout', 'url': page.url, 'content': await page.content()} + except Exception as e: + # Catch other potential exceptions that might occur while reading the page title + print(f'Error: {e}') + page_info = {'title': 'Error encountered', 'url': page.url, 'content': await page.content()} + + await browser.close() + return page_info + + +async def get_open_tabs_info(env, config: Dict[str, str]): + host = env.vm_ip + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port + + remote_debugging_url = f"http://{host}:{port}" + async with async_playwright() as p: + # connect to remote Chrome instance + try: + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + except Exception as e: + # If the connection fails, start a new browser instance + platform.machine() + if "arm" in platform.machine(): + # start a new browser instance if the connection fails + payload = json.dumps({"command": [ + "chromium", + "--remote-debugging-port=1337" + ], "shell": False}) + else: + payload = json.dumps({"command": [ + "google-chrome", + "--remote-debugging-port=1337" + ], "shell": False}) + + headers = {"Content-Type": "application/json"} + requests.post(f"http://{host}:{server_port}/setup/launch", headers=headers, data=payload) + await asyncio.sleep(5) + try: + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + except Exception as e: + return [] + + tabs_info = [] + for context in browser.contexts: + for page in context.pages: + try: + # Wait for the page to finish loading, this prevents the "execution context was destroyed" issue + await page.wait_for_load_state('networkidle') # Wait for the 'load' event to complete + title = await page.title() + url = page.url + tabs_info.append({'title': title, 'url': url}) + except TimeoutError: + # If page loading times out, catch the exception and store the current information in the list + tabs_info.append({'title': 'Load timeout', 'url': page.url}) + except Exception as e: + # Catch other potential exceptions that might occur while reading the page title + print(f'Error: {e}') + tabs_info.append({'title': 'Error encountered', 'url': page.url}) + + await browser.close() + return tabs_info + + +def get_active_url_from_accessTree(env, config): + """ + Playwright cannot get the url of active tab directly, + so we need to use accessibility tree to get the active tab info. + This function is used to get the active tab url from the accessibility tree. + config: + Dict[str, str]{ + # we no longer need to specify the xpath or selectors, since we will use defalut value + # 'xpath': + # the same as in metrics.general.accessibility_tree. + # 'selectors': + # the same as in metrics.general.accessibility_tree. + 'goto_prefix': + the prefix you want to add to the beginning of the url to be opened, default is "https://", + (the url we get from accTree does not have prefix) + ...(other keys, not used in this function) + } + Return + url: str + """ + # Ensure the controller and its method are accessible and return a valid result + if hasattr(env, 'controller') and callable(getattr(env.controller, 'get_accessibility_tree', None)): + accessibility_tree = env.controller.get_accessibility_tree() + if accessibility_tree is None: + print("Failed to get the accessibility tree.") + return None + else: + print("Controller or method 'get_accessibility_tree' not found.") + return None + + logger.debug("AT@eval: %s", accessibility_tree) + + at = None + try: + at = lxml.etree.fromstring(accessibility_tree) + except ValueError as e: + logger.error(f"Error parsing accessibility tree: {e}") + return None + + # Determine the correct selector based on system architecture + selector = None + arch = platform.machine() + print(f"Your architecture is: {arch}") + + if "arm" in arch: + selector_string = "application[name=Chromium] entry[name=Address\\ and\\ search\\ bar]" + else: + selector_string = "application[name=Google\\ Chrome] entry[name=Address\\ and\\ search\\ bar]" + + try: + selector = CSSSelector(selector_string, namespaces=_accessibility_ns_map) + except Exception as e: + logger.error(f"Failed to parse the selector for active tab URL: {e}") + return None + + elements = selector(at) if selector else [] + if not elements: + print("No elements found.") + return None + elif not elements[-1].text: + print("No text found in the latest element.") + return None + + # Use a default prefix if 'goto_prefix' is not specified in the config + goto_prefix = config.get("goto_prefix", "https://") + + active_tab_url = f"{goto_prefix}{elements[0].text}" + print(f"Active tab url now: {active_tab_url}") + return active_tab_url + + +async def get_active_tab_info(env, config: Dict[str, str]): + """ + This function is used to get all info about active tab. + Warning! This function will reload the target-url page + If the tartget url has cache or cookie, this function may reload to another page. + If you have tested the url will not pop up to another page (check in incongnito mode yourself first), + you can use this function. + config: Dict[str, str]{ + # Keys used in get_active_url_from_accessTree: "xpath", "selectors" + } + """ + active_tab_url = get_active_url_from_accessTree(env, config) + if active_tab_url is None: + logger.error("Failed to get the url of active tab") + return None + host = env.vm_ip + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + + remote_debugging_url = f"http://{host}:{port}" + async with async_playwright() as p: + # connect to remote Chrome instance, since it is supposed to be the active one, we won't start a new one if failed + try: + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + except Exception as e: + return None + + active_tab_info = {} + # go to the target URL page + page = await browser.new_page() + try: + await page.goto(active_tab_url) + except: + logger.error("Failed to go to the target URL page") + return None + await page.wait_for_load_state('load') # Wait for the 'load' event to complete + active_tab_info = { + 'title': await page.title(), + 'url': page.url, + 'content': await page.content() # get the HTML content of the page + } + + await browser.close() + # print("active_tab_title: {}".format(active_tab_info.get('title', 'None'))) + # print("active_tab_url: {}".format(active_tab_info.get('url', 'None'))) + # print("active_tab_content: {}".format(active_tab_info.get('content', 'None'))) + return active_tab_info + + +async def get_pdf_from_url(env, config: Dict[str, str]) -> str: + """ + Download a PDF from a URL. + """ + _url = config["path"] + _path = os.path.join(env.cache_dir, config["dest"]) + + host = env.vm_ip + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port + + remote_debugging_url = f"http://{host}:{port}" + + async with async_playwright() as p: + try: + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + except Exception as e: + # If the connection fails, start a new browser instance + platform.machine() + if "arm" in platform.machine(): + # start a new browser instance if the connection fails + payload = json.dumps({"command": [ + "chromium", + "--remote-debugging-port=1337" + ], "shell": False}) + else: + payload = json.dumps({"command": [ + "google-chrome", + "--remote-debugging-port=1337" + ], "shell": False}) + + headers = {"Content-Type": "application/json"} + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) + await asyncio.sleep(5) + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + + page = await browser.new_page() + await page.goto(_url) + await page.pdf(path=_path) + await browser.close() + + return _path + + +# fixme: needs to be changed (maybe through post-processing) since it's not working +async def get_chrome_saved_address(env, config: Dict[str, str]): + host = env.vm_ip + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port + + remote_debugging_url = f"http://{host}:{port}" + async with async_playwright() as p: + # connect to remote Chrome instance + try: + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + except Exception as e: + # If the connection fails, start a new browser instance + platform.machine() + if "arm" in platform.machine(): + # start a new browser instance if the connection fails + payload = json.dumps({"command": [ + "chromium", + "--remote-debugging-port=1337" + ], "shell": False}) + else: + payload = json.dumps({"command": [ + "google-chrome", + "--remote-debugging-port=1337" + ], "shell": False}) + + headers = {"Content-Type": "application/json"} + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) + await asyncio.sleep(5) + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + + page = await browser.new_page() + + # Navigate to Chrome's settings page for autofill + await page.goto("chrome://settings/addresses") + + # Get the HTML content of the page + content = await page.content() + + await browser.close() + + return content + + +def get_shortcuts_on_desktop(env, config: Dict[str, str]): + # Find out the operating system + os_name = env.vm_platform + + # Depending on the OS, define the shortcut file extension + if os_name == 'Windows': + # Windows shortcuts are typically .url or .lnk files + shortcut_extension = '.lnk' + elif os_name == 'Darwin': + # macOS's shortcuts are .webloc files + shortcut_extension = '.webloc' + elif os_name == 'Linux': + # Linux (Ubuntu, etc.) shortcuts are typically .desktop files + shortcut_extension = '.desktop' + else: + logger.error(f"Unsupported operating system: {os_name}") + return [] + + # Get the path to the desktop folder + desktop_path = env.controller.get_vm_desktop_path() + desktop_directory_tree = env.controller.get_vm_directory_tree(desktop_path) + + shortcuts_paths = [file['name'] for file in desktop_directory_tree['children'] if + file['name'].endswith(shortcut_extension)] + + short_cuts = {} + + for shortcut_path in shortcuts_paths: + short_cuts[shortcut_path] = env.controller.get_file(env.controller.execute_python_command( + f"import os; print(os.path.join(os.path.expanduser('~'), 'Desktop', '{shortcut_path}'))")[ + 'output'].strip()).decode('utf-8') + + return short_cuts + + +async def get_number_of_search_results(env, config: Dict[str, str]): + # todo: move into the config file + url, result_selector = "https://google.com/search?q=query", '.search-result' + host = env.vm_ip + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port + + remote_debugging_url = f"http://{host}:{port}" + async with async_playwright() as p: + try: + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + except Exception as e: + # If the connection fails, start a new browser instance + platform.machine() + if "arm" in platform.machine(): + # start a new browser instance if the connection fails + payload = json.dumps({"command": [ + "chromium", + "--remote-debugging-port=1337" + ], "shell": False}) + else: + payload = json.dumps({"command": [ + "google-chrome", + "--remote-debugging-port=1337" + ], "shell": False}) + + headers = {"Content-Type": "application/json"} + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) + await asyncio.sleep(5) + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + page = await browser.new_page() + await page.goto(url) + search_results = await page.query_selector_all(result_selector) + actual_count = len(search_results) + await browser.close() + + return actual_count + + +def get_googledrive_file(env, config: Dict[str, Any]) -> str: + """ Get the desired file from Google Drive based on config, return the downloaded local filepath. + @args: keys in config dict + settings_file(str): target filepath to the settings file for Google Drive authentication, default is 'evaluation_examples/settings/googledrive/settings.yml' + query/path[_list](Union[str, List[str]]): the query or path [list] to the file(s) on Google Drive. To retrieve the file, we provide multiple key options to specify the filepath on drive in config dict: + 1) query: a list of queries to search the file, each query is a string that follows the format of Google Drive search query. The documentation is available here: (support more complex search but too complicated to use) + https://developers.google.com/drive/api/guides/search-files?hl=en + 2) path: a str list poingting to file path on googledrive, e.g., 'folder/subfolder/filename.txt' -> + config contain one key-value pair "path": ['folder', 'subfolder', 'filename.txt'] + 3) query_list: query extends to list to download multiple files + 4) path_list: path extends to list to download multiple files, e.g., + "path_list": [['folder', 'subfolder', 'filename1.txt'], ['folder', 'subfolder', 'filename2.txt']] + @return: + dest(Union[List[str], str]): target file name or list. If *_list is used in input config, dest should also be a list of the same length. Return the downloaded local filepath. + """ + settings_file = config.get('settings_file', 'evaluation_examples/settings/googledrive/settings.yml') + auth = GoogleAuth(settings_file=settings_file) + drive = GoogleDrive(auth) + + def get_single_file(_query, _path): + parent_id = 'root' + try: + for q in _query: + search = f'( {q} ) and "{parent_id}" in parents' + filelist: GoogleDriveFileList = drive.ListFile({'q': search}).GetList() + if len(filelist) == 0: # target file not found + return None + file: GoogleDriveFile = filelist[0] # HACK: if multiple candidates, just use the first one + parent_id = file['id'] + + file.GetContentFile(_path, mimetype=file['mimeType']) + except Exception as e: + logger.info('[ERROR]: Failed to download the file from Google Drive', e) + return None + return _path + + if 'query' in config: + return get_single_file(config['query'], os.path.join(env.cache_dir, config['dest'])) + elif 'path' in config: + query = [f"title = '{fp}' and mimeType = 'application/vnd.google-apps.folder' and trashed = false" if idx < len( + config['path']) - 1 + else f"title = '{fp}' and trashed = false" for idx, fp in enumerate(config['path'])] + return get_single_file(query, os.path.join(env.cache_dir, config['dest'])) + elif 'query_list' in config: + _path_list = [] + assert len(config['query_list']) == len(config['dest']) + for idx, query in enumerate(config['query_list']): + dest = config['dest'][idx] + _path_list.append(get_single_file(query, os.path.join(env.cache_dir, dest))) + return _path_list + else: # path_list in config + _path_list = [] + assert len(config['path_list']) == len(config['dest']) + for idx, path in enumerate(config['path_list']): + query = [ + f"title = '{fp}' and mimeType = 'application/vnd.google-apps.folder' and trashed = false" if jdx < len( + path) - 1 + else f"title = '{fp}' and trashed = false" for jdx, fp in enumerate(path)] + dest = config['dest'][idx] + _path_list.append(get_single_file(query, os.path.join(env.cache_dir, dest))) + return _path_list + + +def get_enable_do_not_track(env, config: Dict[str, str]): + os_type = env.vm_platform + if os_type == 'Windows': + preference_file_path = env.controller.execute_python_command("""import os; print(os.path.join(os.getenv('LOCALAPPDATA'), + 'Google\\Chrome\\User Data\\Default\\Preferences'))""")['output'].strip() + elif os_type == 'Darwin': + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'Library/Application Support/Google/Chrome/Default/Preferences'))")[ + 'output'].strip() + elif os_type == 'Linux': + if "arm" in platform.machine(): + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'snap/chromium/common/chromium/Default/Preferences'))")[ + 'output'].strip() + else: + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), '.config/google-chrome/Default/Preferences'))")[ + 'output'].strip() + + else: + raise Exception('Unsupported operating system') + + try: + content = env.controller.get_file(preference_file_path) + data = json.loads(content) + + if_enable_do_not_track = data.get('enable_do_not_track', {}) # bool + return "true" if if_enable_do_not_track else "false" + except Exception as e: + logger.error(f"Error: {e}") + return "false" + + +def get_enable_enhanced_safety_browsing(env, config: Dict[str, str]): + os_type = env.vm_platform + if os_type == 'Windows': + preference_file_path = env.controller.execute_python_command("""import os; print(os.path.join(os.getenv('LOCALAPPDATA'), + 'Google\\Chrome\\User Data\\Default\\Preferences'))""")['output'].strip() + elif os_type == 'Darwin': + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'Library/Application Support/Google/Chrome/Default/Preferences'))")[ + 'output'].strip() + elif os_type == 'Linux': + if "arm" in platform.machine(): + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'snap/chromium/common/chromium/Default/Preferences'))")[ + 'output'].strip() + else: + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), '.config/google-chrome/Default/Preferences'))")[ + 'output'].strip() + + else: + raise Exception('Unsupported operating system') + + try: + content = env.controller.get_file(preference_file_path) + data = json.loads(content) + + if_enable_do_not_track = data.get('safebrowsing', {}).get('enhanced', {}) # bool + return "true" if if_enable_do_not_track else "false" + except Exception as e: + logger.error(f"Error: {e}") + return "Google" + + +def get_new_startup_page(env, config: Dict[str, str]): + os_type = env.vm_platform + if os_type == 'Windows': + preference_file_path = env.controller.execute_python_command("""import os; print(os.path.join(os.getenv('LOCALAPPDATA'), + 'Google\\Chrome\\User Data\\Default\\Preferences'))""")['output'].strip() + elif os_type == 'Darwin': + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'Library/Application Support/Google/Chrome/Default/Preferences'))")[ + 'output'].strip() + elif os_type == 'Linux': + if "arm" in platform.machine(): + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'snap/chromium/common/chromium/Default/Preferences'))")[ + 'output'].strip() + else: + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), '.config/google-chrome/Default/Preferences'))")[ + 'output'].strip() + + else: + raise Exception('Unsupported operating system') + + try: + content = env.controller.get_file(preference_file_path) + data = json.loads(content) + + # if data has no key called 'session', it means the chrome is on a fresh-start mode, which is a true state; + # otherwise, try to find the code number in 'restored_on_startup' in 'session' + if "session" not in data.keys(): + return "true" + else: + if_enable_do_not_track = data.get('session', {}).get('restore_on_startup', {}) # int, need to be 5 + return "true" if if_enable_do_not_track == 5 else "false" + except Exception as e: + logger.error(f"Error: {e}") + return "Google" + + +def get_find_unpacked_extension_path(env, config: Dict[str, str]): + os_type = env.vm_platform + if os_type == 'Windows': + preference_file_path = env.controller.execute_python_command("""import os; print(os.path.join(os.getenv('LOCALAPPDATA'), + 'Google\\Chrome\\User Data\\Default\\Preferences'))""")['output'].strip() + elif os_type == 'Darwin': + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'Library/Application Support/Google/Chrome/Default/Preferences'))")[ + 'output'].strip() + elif os_type == 'Linux': + if "arm" in platform.machine(): + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'snap/chromium/common/chromium/Default/Preferences'))")[ + 'output'].strip() + else: + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), '.config/google-chrome/Default/Preferences'))")[ + 'output'].strip() + + else: + raise Exception('Unsupported operating system') + + try: + content = env.controller.get_file(preference_file_path) + data = json.loads(content) + # Preferences store all the path of installed extensions, return them all and let metrics try to find one matches the targeted extension path + all_extensions_path = [] + all_extensions = data.get('extensions', {}).get('settings', {}) + for id in all_extensions.keys(): + path = all_extensions[id]["path"] + all_extensions_path.append(path) + return all_extensions_path + except Exception as e: + logger.error(f"Error: {e}") + return "Google" + + +def get_find_installed_extension_name(env, config: Dict[str, str]): + os_type = env.vm_platform + if os_type == 'Windows': + preference_file_path = env.controller.execute_python_command("""import os; print(os.path.join(os.getenv('LOCALAPPDATA'), + 'Google\\Chrome\\User Data\\Default\\Preferences'))""")['output'].strip() + elif os_type == 'Darwin': + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'Library/Application Support/Google/Chrome/Default/Preferences'))")[ + 'output'].strip() + elif os_type == 'Linux': + if "arm" in platform.machine(): + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'snap/chromium/common/chromium/Default/Preferences'))")[ + 'output'].strip() + else: + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), '.config/google-chrome/Default/Preferences'))")[ + 'output'].strip() + + else: + raise Exception('Unsupported operating system') + + try: + content = env.controller.get_file(preference_file_path) + data = json.loads(content) + # Preferences store all the path of installed extensions, return them all and let metrics try to find one matches the targeted extension path + all_extensions_name = [] + all_extensions = data.get('extensions', {}).get('settings', {}) + for id in all_extensions.keys(): + name = all_extensions[id]["manifest"]["name"] + all_extensions_name.append(name) + return all_extensions_name + except Exception as e: + logger.error(f"Error: {e}") + return "Google" + + +def get_data_delete_automacally(env, config: Dict[str, str]): + """ + This function is used to open th "auto-delete" mode of chromium + """ + os_type = env.vm_platform + if os_type == 'Windows': + preference_file_path = env.controller.execute_python_command("""import os; print(os.path.join(os.getenv('LOCALAPPDATA'), + 'Google\\Chrome\\User Data\\Default\\Preferences'))""")['output'].strip() + elif os_type == 'Darwin': + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'Library/Application Support/Google/Chrome/Default/Preferences'))")[ + 'output'].strip() + elif os_type == 'Linux': + if "arm" in platform.machine(): + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), 'snap/chromium/common/chromium/Default/Preferences'))")[ + 'output'].strip() + else: + preference_file_path = env.controller.execute_python_command( + "import os; print(os.path.join(os.getenv('HOME'), '.config/google-chrome/Default/Preferences'))")[ + 'output'].strip() + else: + raise Exception('Unsupported operating system') + + try: + content = env.controller.get_file(preference_file_path) + data = json.loads(content) + data_delete_state = data["profile"].get("default_content_setting_values", None) + return "true" if data_delete_state is not None else "false" + except Exception as e: + logger.error(f"Error: {e}") + return "Google" + + +async def get_active_tab_html_parse(env, config: Dict[str, Any]): + """ + This function is used to get the specific element's text content from the active tab's html. + config: + Dict[str, str]{ + # Keys used in get_active_url_from_accessTree: "xpath", "selectors" + 'category': + choose from ["class", "label", "xpath", "input"], used to indicate how to find the element + 'labelObject': + only exists when category is "label", + a dict like { "labelSelector": "the key you want to store the text content of this label's ee=lement"} + 'class_singleObject': + only exists when category is "class", a dict with keys as the class name, + like { "class name" : "the key you want to store the text content of this element" } + 'class_multiObject': + only exists when category is "class", used for elements with same class name. + Two layer of dict, like + ( { + "class name": { + "rank in this class" : "the key you want to store the text content of this element" + ... + } + } ) + 'xpathObject': + only exists when category is "xpath", a dict with keys as the xpath, + like { "full xpath" : "the key you want to store the text content of this element" } + 'inputObject': + only exists when category is "input", + a dict with keys as the input element's xpath, like { "full xpath" : "the key you want to store the text content of this element" } + } + """ + active_tab_url = get_active_url_from_accessTree(env, config) + if not isinstance(active_tab_url, str): + logger.error("active_tab_url is not a string") + return None + host = env.vm_ip + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port + + remote_debugging_url = f"http://{host}:{port}" + async with async_playwright() as p: + # connect to remote Chrome instance + try: + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + except Exception as e: + # If the connection fails, start a new browser instance + platform.machine() + if "arm" in platform.machine(): + # start a new browser instance if the connection fails + payload = json.dumps({"command": [ + "chromium", + "--remote-debugging-port=1337" + ], "shell": False}) + else: + payload = json.dumps({"command": [ + "google-chrome", + "--remote-debugging-port=1337" + ], "shell": False}) + + headers = {"Content-Type": "application/json"} + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) + await asyncio.sleep(5) + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + target_page = None + for context in browser.contexts: + for page in context.pages: + await page.wait_for_load_state("networkidle") + # the accTree and playwright can get encoding(percent-encoding) characters, we need to convert them to normal characters + if unquote(page.url) == unquote(active_tab_url): + target_page = page + print("\33[32mtartget page url: ", target_page.url, "\33[0m") + print("\33[32mtartget page title: ", await target_page.title(), "\33[0m") + break + if target_page is None: + logger.error("Your tab is not the target tab.") + return {} + + return_json = {} + + async def safely_get_text_content(selector): + elements = await target_page.query_selector_all(selector) + return [await element.text_content().strip() for element in elements if element] + + if config["category"] == "class": + class_multiObject = config.get("class_multiObject", {}) + for class_name, object_dict in class_multiObject.items(): + elements_texts = await safely_get_text_content("." + class_name) + for order_key, key in object_dict.items(): + index = int(order_key) + if len(elements_texts) > index: + return_json[key] = elements_texts[index] + + class_singleObject = config.get("class_singleObject", {}) + for class_name, key in class_singleObject.items(): + element_text = await safely_get_text_content("." + class_name) + if element_text: + return_json[key] = element_text[0] + + elif config['category'] == "label": + # Assuming get_by_label is a custom function or part of the framework being used + labelObject = config.get("labelObject", {}) + for labelSelector, key in labelObject.items(): + text = await target_page.locator(f"text={labelSelector}").first.text_content().strip() + if text: + return_json[key] = text + + elif config["category"] == "xpath": + xpathObject = config.get("xpathObject", {}) + for xpath, key in xpathObject.items(): + elements = target_page.locator(f"xpath={xpath}") + if await elements.count() > 0: + return_json[key] = await elements.first.text_content().strip() + + elif config["category"] == "input": + inputObjects = config.get("inputObject", {}) + for xpath, key in inputObjects.items(): + inputs = target_page.locator(f"xpath={xpath}") + if await inputs.count() > 0: + return_json[key] = await inputs.first.input_value().strip() + + await browser.close() + return return_json + + +async def get_gotoRecreationPage_and_get_html_content(env, config: Dict[str, Any]): + """ + especially used for www.recreation.gov examples + """ + host = env.vm_ip + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port + + remote_debugging_url = f"http://{host}:{port}" + async with async_playwright() as p: + try: + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + except Exception as e: + # If the connection fails, start a new browser instance + platform.machine() + if "arm" in platform.machine(): + # start a new browser instance if the connection fails + payload = json.dumps({"command": [ + "chromium", + "--remote-debugging-port=1337" + ], "shell": False}) + else: + payload = json.dumps({"command": [ + "google-chrome", + "--remote-debugging-port=1337" + ], "shell": False}) + + headers = {"Content-Type": "application/json"} + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) + await asyncio.sleep(5) + browser = await p.chromium.connect_over_cdp(remote_debugging_url) + page = await browser.new_page() + await page.goto("https://www.recreation.gov/") + await page.fill("input#hero-search-input", "Albion Basin") + await page.click("button.nav-search-button") + print("after first click") + await asyncio.sleep(2) + # Assuming .search-result-highlight--success leads to a new page or requires page load + async with page.expect_popup() as popup_info: + await page.click(".search-result-highlight--success") + print("after second click") + newpage = popup_info.value + await newpage.wait_for_load_state() + print("go to newpage: ") + print(await newpage.title()) + await asyncio.sleep(2) + await newpage.click("button.next-available") + print("after third click") + + return_json = {} + return_json["expected"] = {} + # find the text of elements in html with specific class name + if config["selector"] == "class": + if "order" in config.keys(): + className = config["class"] + elements = await newpage.query_selector_all("." + className) + return_json["expected"][className] = await elements[int(config["order"])].text_content().strip() + else: + className = config["class"] + element = await newpage.query_selector("." + className) + return_json["expected"][className] = await element.text_content().strip() + await browser.close() + return return_json + + +def get_active_tab_url_parse(env, config: Dict[str, Any]): + """ + This function is used to parse the url according to config["parse_keys"]. + config: + 'parse_keys': must exist, + a list of keys to extract from the query parameters of the url. + 'replace': optional, + a dict, used to replace the original key with the new key. + ( { "original key": "new key" } ) + """ + active_tab_url = get_active_url_from_accessTree(env, config) + if active_tab_url is None: + return None + + # connect to remote Chrome instance + # parse in a hard-coded way to find the specific info about task + parsed_url = urlparse(active_tab_url) + # Extract the query parameters + query_params = parse_qs(parsed_url.query) + # Define the keys of interest + keys_of_interest = [key for key in config["parse_keys"]] + # Extract the parameters of interest + extracted_params = {key: query_params.get(key, [''])[0] for key in keys_of_interest} + if "replace" in config: + for key in config["replace"].keys(): + # change original key to new key, keep value unchange + value = extracted_params.pop(key) + extracted_params[config["replace"][key]] = value + return extracted_params + + +def get_url_dashPart(env, config: Dict[str, str]): + """ + This function is used to extract one of the dash-separated part of the URL. + config + 'partIndex': must exist, + the index of the dash-separated part to extract, starting from 0. + 'needDeleteId': optional, + a boolean, used to indicate whether to delete the "id" part ( an example: "/part-you-want?id=xxx" ) + 'returnType': must exist, + a string, used to indicate the return type, "string" or "json". + """ + active_tab_url = get_active_url_from_accessTree(env, config) + if active_tab_url is None: + return None + + # extract the last dash-separated part of the URL, and delete all the characters after "id" + dash_part = active_tab_url.split("/")[config["partIndex"]] + if config["needDeleteId"]: + dash_part = dash_part.split("?")[0] + # print("active_tab_title: {}".format(active_tab_info.get('title', 'None'))) + # print("active_tab_url: {}".format(active_tab_info.get('url', 'None'))) + # print("active_tab_content: {}".format(active_tab_info.get('content', 'None'))) + if config["returnType"] == "string": + return dash_part + elif config["returnType"] == "json": + return {config["key"]: dash_part} diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/file.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/file.py new file mode 100644 index 000000000..f4ab03a35 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/file.py @@ -0,0 +1,125 @@ +import os +from typing import Dict, List, Set +from typing import Optional, Any, Union +from datetime import datetime +import requests +import pandas as pd + + +def get_content_from_vm_file(env, config: Dict[str, Any]) -> Any: + """ + Config: + path (str): absolute path on the VM to fetch + """ + + path = config["path"] + file_path = get_vm_file(env, {"path": path, "dest": os.path.basename(path)}) + file_type, file_content = config['file_type'], config['file_content'] + if file_type == 'xlsx': + if file_content == 'last_row': + df = pd.read_excel(file_path) + last_row = df.iloc[-1] + last_row_as_list = last_row.astype(str).tolist() + return last_row_as_list + else: + raise NotImplementedError(f"File type {file_type} not supported") + + +def get_cloud_file(env, config: Dict[str, Any]) -> Union[str, List[str]]: + """ + Config: + path (str|List[str]): the url to download from + dest (str|List[str])): file name of the downloaded file + multi (bool) : optional. if path and dest are lists providing + information of multiple files. defaults to False + gives (List[int]): optional. defaults to [0]. which files are directly + returned to the metric. if len==1, str is returned; else, list is + returned. + """ + + if not config.get("multi", False): + paths: List[str] = [config["path"]] + dests: List[str] = [config["dest"]] + else: + paths: List[str] = config["path"] + dests: List[str] = config["dest"] + cache_paths: List[str] = [] + + gives: Set[int] = set(config.get("gives", [0])) + + for i, (p, d) in enumerate(zip(paths, dests)): + _path = os.path.join(env.cache_dir, d) + if i in gives: + cache_paths.append(_path) + + if os.path.exists(_path): + #return _path + continue + + url = p + response = requests.get(url, stream=True) + response.raise_for_status() + + with open(_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + + return cache_paths[0] if len(cache_paths)==1 else cache_paths + + +def get_vm_file(env, config: Dict[str, Any]) -> Union[Optional[str], List[Optional[str]]]: + """ + Config: + path (str): absolute path on the VM to fetch + dest (str): file name of the downloaded file + multi (bool) : optional. if path and dest are lists providing + information of multiple files. defaults to False + gives (List[int]): optional. defaults to [0]. which files are directly + returned to the metric. if len==1, str is returned; else, list is + returned. + only support for single file now: + time_suffix(bool): optional. defaults to False. if True, append the current time in required format. + time_format(str): optional. defaults to "%Y%m%d_%H%M%S". format of the time suffix. + """ + time_format = "%Y%m%d_%H%M%S" + if not config.get("multi", False): + paths: List[str] = [config["path"]] + dests: List[str] = [config["dest"]] + if config.get("time_suffix", False): + time_format = config.get("time_format", time_format) + # Insert time before file extension. + dests = [f"{os.path.splitext(d)[0]}_{datetime.now().strftime(time_format)}{os.path.splitext(d)[1]}" for d in dests] + else: + paths: List[str] = config["path"] + dests: List[str] = config["dest"] + + + cache_paths: List[str] = [] + + gives: Set[int] = set(config.get("gives", [0])) + + for i, (p, d) in enumerate(zip(paths, dests)): + _path = os.path.join(env.cache_dir, d) + file = env.controller.get_file(p) + if file is None: + if i in gives: + cache_paths.append(None) + continue + + if i in gives: + cache_paths.append(_path) + with open(_path, "wb") as f: + f.write(file) + return cache_paths[0] if len(cache_paths)==1 else cache_paths + + +def get_cache_file(env, config: Dict[str, str]) -> str: + """ + Config: + path (str): relative path in cache dir + """ + + _path = os.path.join(env.cache_dir, config["path"]) + assert os.path.exists(_path) + return _path diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/general.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/general.py new file mode 100644 index 000000000..2f5ed32c1 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/general.py @@ -0,0 +1,42 @@ +import logging +from typing import Dict +import requests + +logger = logging.getLogger("desktopenv.getters.general") + + +def get_vm_command_line(env, config: Dict[str, str]): + vm_ip = env.vm_ip + port = env.server_port + command = config["command"] + shell = config.get("shell", False) + + response = requests.post(f"http://{vm_ip}:{port}/execute", json={"command": command, "shell": shell}) + + print(response.json()) + + if response.status_code == 200: + return response.json()["output"] + else: + logger.error("Failed to get vm command line. Status code: %d", response.status_code) + return None + +def get_vm_command_error(env, config: Dict[str, str]): + vm_ip = env.vm_ip + port = env.server_port + command = config["command"] + shell = config.get("shell", False) + + response = requests.post(f"http://{vm_ip}:{port}/execute", json={"command": command, "shell": shell}) + + print(response.json()) + + if response.status_code == 200: + return response.json()["error"] + else: + logger.error("Failed to get vm command line error. Status code: %d", response.status_code) + return None + + +def get_vm_terminal_output(env, config: Dict[str, str]): + return env.controller.get_terminal_output() diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/gimp.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/gimp.py new file mode 100644 index 000000000..d055ba761 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/gimp.py @@ -0,0 +1,75 @@ +import logging +import os +from typing import Dict + +logger = logging.getLogger("desktopenv.getters.gimp") + + +def get_gimp_config_file(env, config: Dict[str, str]): + """ + Gets the config setting of GIMP. + """ + + os_type = env.vm_platform + print(os_type) + + if os_type == "Linux": + try: + # First check if the GIMP config directory exists + check_command = f"import os; print(os.path.exists(os.path.expanduser('~/.config/GIMP/2.10/{config['file_name']}')))" + exists_result = env.controller.execute_python_command(check_command) + + if exists_result.get('output', '').strip().lower() == 'false': + logger.warning(f"GIMP config file {config['file_name']} does not exist. Skipping GIMP config retrieval.") + # Create an empty placeholder file to prevent evaluation errors + _path = os.path.join(env.cache_dir, config["dest"]) + os.makedirs(os.path.dirname(_path), exist_ok=True) + with open(_path, "w") as f: + f.write("# GIMP config file not found - placeholder\n") + return _path + + config_path = \ + env.controller.execute_python_command(f"import os; print(" + f"os" + f".path.expanduser(" + f"'~/.config/GIMP/2.10/" + f"{config['file_name']}'))")[ + 'output'].strip() + except Exception as e: + logger.error(f"Failed to check GIMP config file existence: {e}") + # Create an empty placeholder file to prevent evaluation errors + _path = os.path.join(env.cache_dir, config["dest"]) + os.makedirs(os.path.dirname(_path), exist_ok=True) + with open(_path, "w") as f: + f.write("# GIMP config file check failed - placeholder\n") + return _path + # TODO: Add support for macOS and Windows + else: + raise Exception("Unsupported operating system", os_type) + + _path = os.path.join(env.cache_dir, config["dest"]) + + try: + content = env.controller.get_file(config_path) + + if not content: + logger.warning("Failed to get GIMP config file content. Creating placeholder.") + # Create an empty placeholder file to prevent evaluation errors + os.makedirs(os.path.dirname(_path), exist_ok=True) + with open(_path, "w") as f: + f.write("# GIMP config file content not available - placeholder\n") + return _path + + os.makedirs(os.path.dirname(_path), exist_ok=True) + with open(_path, "wb") as f: + f.write(content) + + return _path + + except Exception as e: + logger.error(f"Error retrieving GIMP config file: {e}") + # Create an empty placeholder file to prevent evaluation errors + os.makedirs(os.path.dirname(_path), exist_ok=True) + with open(_path, "w") as f: + f.write(f"# GIMP config file retrieval failed: {e} - placeholder\n") + return _path diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/impress.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/impress.py new file mode 100644 index 000000000..f6bc60f10 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/impress.py @@ -0,0 +1,126 @@ +import os +import tempfile +import xml.etree.ElementTree as ET +import zipfile +from typing import Dict + +from desktop_env.evaluators.getters.file import get_vm_file + + +def get_background_image_in_slide(env, config: Dict[str, str]): + ppt_file_path, slide_index, dest = config["ppt_file_path"], int(config["slide_index"]), config["dest"] + image_id, image_file_path = None, None + + ppt_file_localhost_path = get_vm_file(env, {"path": ppt_file_path, "dest": os.path.split(ppt_file_path)[-1]}) + + with zipfile.ZipFile(ppt_file_localhost_path, 'r') as myzip: + slide1_xml_file = 'ppt/slides/slide{}.xml'.format(slide_index + 1) + # firstly, check whether the background image is used in the slide + if slide1_xml_file not in myzip.namelist(): return None + with myzip.open(slide1_xml_file) as f: + # Parse the XML tree from the relationships file + tree = ET.parse(f) + root = tree.getroot() + bg_tag = "{http://schemas.openxmlformats.org/presentationml/2006/main}bgPr" + image_tag = "{http://schemas.openxmlformats.org/drawingml/2006/main}blip" + attr_tag = "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed" + for child in root.iter(bg_tag): + try: + for element in child.iter(image_tag): + image_id = element.attrib[attr_tag] + break + except: pass + if image_id is not None: break + else: return None + + # next, extract the background image from the slide + slide1_rels_file = 'ppt/slides/_rels/slide{}.xml.rels'.format(slide_index + 1) + if slide1_rels_file in myzip.namelist(): + with myzip.open(slide1_rels_file) as f: + # Parse the XML tree from the relationships file + tree = ET.parse(f) + root = tree.getroot() + # Define the namespace used in the relationships file + namespaces = {'r': 'http://schemas.openxmlformats.org/package/2006/relationships'} + # Look for all relationship elements that have a type attribute for image + for rel in root.findall('r:Relationship', namespaces): + # Check if the relationship is for an image file + if 'image' in rel.attrib['Type'] and rel.attrib['Id'] == image_id: + target = rel.attrib['Target'] + if target.startswith('..'): + # Resolve the relative path to get the correct path within the zip file + image_file_path = os.path.normpath(os.path.join('ppt/slides', target)) + # Replace backslashes with forward slashes for ZIP compatibility + image_file_path = image_file_path.replace('\\', '/') + tmpdirname = os.path.dirname(ppt_file_localhost_path) + myzip.extract(image_file_path, tmpdirname) + image_file_path = os.path.join(tmpdirname, image_file_path) + return image_file_path + else: # absolute path + assert target.startswith("file://"), target + image_file_path = target[7:] + break + if image_file_path is None: + return None + + else: + # Get the audio file from vm and return the file path in the host + return get_vm_file(env, {"path": image_file_path, "dest": dest}) + + +def get_audio_in_slide(env, config: Dict[str, str]): + ppt_file_path, slide_index, dest = config["ppt_file_path"], int(config["slide_index"]), config["dest"] + + # Open the .pptx file as a zip file, fixme: now we assume there is only one audio file in the slides + audio_file_path = None + + ppt_file_localhost_path = get_vm_file(env, {"path": ppt_file_path, "dest": os.path.split(ppt_file_path)[-1]}) + + with zipfile.ZipFile(ppt_file_localhost_path, 'r') as myzip: + # Find the relationships XML file for the first slide + slide1_rels_file = 'ppt/slides/_rels/slide{}.xml.rels'.format(slide_index + 1) + if slide1_rels_file in myzip.namelist(): + with myzip.open(slide1_rels_file) as f: + # Parse the XML tree from the relationships file + tree = ET.parse(f) + root = tree.getroot() + # Define the namespace used in the relationships file + namespaces = {'r': 'http://schemas.openxmlformats.org/package/2006/relationships'} + # Look for all relationship elements that have a type attribute for audio + for rel in root.findall('r:Relationship', namespaces): + # Check if the relationship is for an audio file + if 'audio' in rel.attrib['Type']: + # The audio can be embedded inside the file or linked to an external file + # Get the target attribute which contains the audio file path + target = rel.attrib['Target'] + + if target.startswith('..'): + # Resolve the relative path to get the correct path within the zip file + audio_file_path = os.path.normpath(os.path.join('ppt/slides', target)) + # Replace backslashes with forward slashes for ZIP compatibility + audio_file_path = audio_file_path.replace('\\', '/') + + # Create a temporary directory to extract the audio file + tmpdirname = os.path.dirname(ppt_file_localhost_path) + myzip.extract(audio_file_path, tmpdirname) + audio_file_path = os.path.join(tmpdirname, audio_file_path) + return audio_file_path + # with tempfile.TemporaryDirectory() as tmpdirname: + # # Extract the audio file + # myzip.extract(audio_file_path, tmpdirname) + # # Get the full path of the extracted audio file + # extracted_audio_path = os.path.join(tmpdirname, audio_file_path) + # # Return the extracted audio file path + # audio_file_path = extracted_audio_path + else: + # the audio file is external to the .pptx file + # Return the audio file path + assert target.startswith("file://"), target + audio_file_path = target[7:] + break + if audio_file_path is None: + return None + + else: + # Get the audio file from vm and return the file path in the host + return get_vm_file(env, {"path": audio_file_path, "dest": dest}) diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/info.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/info.py new file mode 100644 index 000000000..0c88fd20b --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/info.py @@ -0,0 +1,24 @@ +import os +from typing import Union + + +def get_vm_screen_size(env, config: dict) -> dict: + return env.controller.get_vm_screen_size() + + +def get_vm_window_size(env, config: dict) -> dict: + return env.controller.get_vm_window_size(app_class_name=config["app_class_name"]) + + +def get_vm_wallpaper(env, config: dict) -> Union[str, bytes]: + _path = os.path.join(env.cache_dir, config["dest"]) + + content = env.controller.get_vm_wallpaper() + with open(_path, "wb") as f: + f.write(content) + + return _path + + +def get_list_directory(env, config: dict) -> dict: + return env.controller.get_vm_directory_tree(config["path"]) diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/misc.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/misc.py new file mode 100644 index 000000000..8862438a6 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/misc.py @@ -0,0 +1,204 @@ +import logging +from typing import TypeVar, Dict +from datetime import datetime, timedelta + +logger = logging.getLogger("desktopenv.getters.misc") + +R = TypeVar("Rule") + +day_of_week_mapping = { + 0: 'Mon', + 1: 'Tue', + 2: 'Wed', + 3: 'Thu', + 4: 'Fri', + 5: 'Sat', + 6: 'Sun' +} + +month_mapping = { + 1: 'Jan', + 2: 'Feb', + 3: 'Mar', + 4: 'Apr', + 5: 'May', + 6: 'Jun', + 7: 'Jul', + 8: 'Aug', + 9: 'Sep', + 10: 'Oct', + 11: 'Nov', + 12: 'Dec' +} + +Month_Mapping_Full = { + 1: "January", + 2: "February", + 3: "March", + 4: "April", + 5: "May", + 6: "June", + 7: "July", + 8: "August", + 9: "September", + 10: "October", + 11: "November", + 12: "December" +} + +month_mapping_full = { + 1: 'january', + 2: 'february', + 3:'march', + 4: 'april', + 5:'may', + 6: 'june', + 7: 'july', + 8: 'august', + 9:'september', + 10: 'october', + 11: 'november', + 12: 'december' +} + +relativeTime_to_IntDay = { + "tomorrow": 1, + "5th next month": "special", + "10th next month": "special", + "11th next month": "special", + "this month": "special", + "this Saturday": "special", + "this Sunday": "special", + "next Monday": "special", + "next Friday": "special", + "first monday four months later": "special" +} + +def get_rule(env, config: Dict[str, R]) -> R: + """ + Returns the rule as-is. + """ + return config["rules"] + +def get_rule_relativeTime(env, config: Dict[str, R]) -> R: + """ + According to the rule definded in funciton "apply_rules_to_timeFormat", convert the relative time to absolute time. + config: + 'relativeTime': { + "from": must exist; indicates the relativeTime. + "to": optional; indicates the relativeTime. + } + If relativeTime only has key "from", then the key of time in "expected" dict must be "time". + If relativeTime has key "to", then the key of time in "expected" dict must be "from" and "to". + """ + relativeRules = config["rules"] + relativeTime = relativeRules["relativeTime"] # int, "+" means future, "-" means past + # get the date now + now = datetime.now() + # calculate the relative time + if "to" not in relativeTime.keys(): + start_relative_time = relativeTime["from"] + if relativeTime_to_IntDay[start_relative_time] != "special": + # relativeTime can be represented by actual int days + start_relative_time_IntDat = relativeTime_to_IntDay[start_relative_time] + timediff = timedelta(days=start_relative_time_IntDat) + absoluteDay = now + timediff + else: + # special case, you can add more special cases here + if start_relative_time == "5th next month": + next_year = now.year + 1 if now.month == 12 else now.year + next_month = now.month + 1 if now.month < 12 else 1 + next_day = 5 + absoluteDay = datetime(next_year, next_month, next_day) + elif start_relative_time == "10th next month": + next_year = now.year + 1 if now.month == 12 else now.year + next_month = now.month + 1 if now.month < 12 else 1 + next_day = 10 + absoluteDay = datetime(next_year, next_month, next_day) + elif start_relative_time == "this month": + absoluteDay = now + elif start_relative_time == "next Monday": + absoluteDay = now + timedelta(days=((6-now.weekday())+1)) + elif start_relative_time == "first monday four months later": + next_year = now.year + 1 if now.month >=9 else now.year + next_month = (now.month + 4)%12 + # get the first monday of the next_month + temp_date = datetime(next_year, next_month, 1) + absoluteDay = temp_date + timedelta(days=((6-temp_date.weekday())+1)%7) + regular_time = apply_rules_to_timeFormat(relativeRules["expected"]["time"], absoluteDay) + config["rules"]["expected"]["time"] = regular_time + + else: + from_time = relativeTime["from"] + to_time = relativeTime["to"] + # deal with from_time first + if relativeTime_to_IntDay[from_time] != "special": + from_time_IntDat = relativeTime_to_IntDay[from_time] + from_timediff = timedelta(days=from_time_IntDat) + from_absoluteDay = now + from_timediff + else: + if from_time == "this Saturday": + from_absoluteDay = now + timedelta(days=(5-now.weekday())) + elif from_time == "10th next month": + next_year = now.year + 1 if now.month == 12 else now.year + next_month = now.month + 1 if now.month < 12 else 1 + next_day = 10 + from_absoluteDay = datetime(next_year, next_month, next_day) + elif from_time == "next Monday": + from_absoluteDay = now + timedelta(days=((6-now.weekday())+1)) + else: + pass # more rules here + regular_from_time = apply_rules_to_timeFormat(relativeRules["expected"]["from"], from_absoluteDay) + config["rules"]["expected"]["from"] = regular_from_time + + # deal with to_time + if relativeTime_to_IntDay[to_time] != "special": + to_time_IntDat = relativeTime_to_IntDay[to_time] + to_timediff = timedelta(days=to_time_IntDat) + to_absoluteDay = now + to_timediff + else: + if to_time == "this Sunday": + to_absoluteDay = now + timedelta(days=(6-now.weekday())) + elif to_time == "11th next month": + next_year = now.year + 1 if now.month == 12 else now.year + next_month = now.month + 1 if now.month < 12 else 1 + next_day = 11 + to_absoluteDay = datetime(next_year, next_month, next_day) + elif to_time == "next Friday": + if now.weekday() < 4 and from_time in ["next Monday"]: + to_absoluteDay = now + timedelta(days=((4-now.weekday())+7)) + else: + to_absoluteDay = now + timedelta(days=((4-now.weekday()) if now.weekday() < 4 else (6-now.weekday()) + 5)) + else: + pass # more rules here + regular_to_time = apply_rules_to_timeFormat(relativeRules["expected"]["to"], to_absoluteDay) + config["rules"]["expected"]["to"] = regular_to_time + + return config["rules"] + + +def apply_rules_to_timeFormat(timeFormat: str, absoluteDay: datetime): + timeFormat = timeFormat.replace("{DoW}", day_of_week_mapping[absoluteDay.weekday()], 1) + timeFormat = timeFormat.replace("{Month}", month_mapping[absoluteDay.month], 1) + timeFormat = timeFormat.replace("{DayD}", str(absoluteDay.day), 1) + timeFormat = timeFormat.replace("{Year}", str(absoluteDay.year), 1) + timeFormat = timeFormat.replace("{Month0D}", "0"+str(absoluteDay.month) if absoluteDay.month < 10 else str(absoluteDay.month), 1) + timeFormat = timeFormat.replace("{month}", month_mapping_full[absoluteDay.month], 1) + timeFormat = timeFormat.replace("{MonthFull}", Month_Mapping_Full[absoluteDay.month], 1) + timeFormat = timeFormat.replace("{Day0D}", "0"+str(absoluteDay.day) if absoluteDay.day < 10 else str(absoluteDay.day), 1) + # you can add other replace rules here + + return timeFormat + + +def get_accessibility_tree(env, *args) -> str: + accessibility_tree: str = env.controller.get_accessibility_tree() + logger.debug("AT@eval: %s", accessibility_tree) + return accessibility_tree + +def get_time_diff_range(env, config) -> str: + try: + return config["diff_range_in_minutes"] + except: + logger.error("diff_range_in_minutes not found in config.") + return None \ No newline at end of file diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/replay.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/replay.py new file mode 100644 index 000000000..c85098630 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/replay.py @@ -0,0 +1,20 @@ +from typing import List, Dict, Any + + +def get_replay(env, trajectory: List[Dict[str, Any]]) -> None: + # fixme: need to be combined with the accessibility tree to activate the selection of the target window + def parse(action): + if action["type"] == "hotkey": + keys = "', '".join(action["param"]) + return f"pyautogui.hotkey('{keys}')" + + if action["type"] == "typewrite": + text = action["param"] + return f"pyautogui.typewrite('{text}')" + + if action["type"] == "press": + key = action["param"] + return f"pyautogui.press('{key}')" + + for action in trajectory: + env.controller.execute_python_command(parse(action)) diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/vlc.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/vlc.py new file mode 100644 index 000000000..2e81543e7 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/vlc.py @@ -0,0 +1,95 @@ +import logging +import os +from typing import Dict +from collections import Counter +from .general import get_vm_command_line +import requests + +logger = logging.getLogger("desktopenv.getters.vlc") + + +def get_vlc_playing_info(env, config: Dict[str, str]): + """ + Gets the current playing information from VLC's HTTP interface. + """ + + host = env.vm_ip + port = env.vlc_port + password = 'password' + + _path = os.path.join(env.cache_dir, config["dest"]) + url = f'http://{host}:{port}/requests/status.xml' + response = requests.get(url, auth=('', password)) + if response.status_code == 200: + content = response.content + else: + logger.error("Failed to get vlc status. Status code: %d", response.status_code) + return None + + with open(_path, "wb") as f: + f.write(content) + + return _path + + +def get_vlc_config(env, config: Dict[str, str]): + """ + Reads the VLC configuration file to check setting. + """ + + os_type = env.vm_platform + + # fixme: depends on how we config and install the vlc in virtual machine, need to be aligned and double-checked + if os_type == "Linux": + config_path = \ + env.controller.execute_python_command("import os; print(os.path.expanduser('~/.config/vlc/vlcrc'))")[ + 'output'].strip() + elif os_type == "Darwin": + config_path = env.controller.execute_python_command( + "import os; print(os.path.expanduser('~/Library/Preferences/org.videolan.vlc/vlcrc'))")['output'].strip() + elif os_type == "Windows": + config_path = env.controller.execute_python_command( + "import os; print(os.path.expanduser('~\\AppData\\Roaming\\vlc\\vlcrc'))")['output'].strip() + else: + raise Exception("Unsupported operating system", os_type) + + _path = os.path.join(env.cache_dir, config["dest"]) + content = env.controller.get_file(config_path) + + if content is None: + logger.warning("Failed to get VLC config file. Creating placeholder.") + # Create an empty placeholder file to prevent evaluation errors + os.makedirs(os.path.dirname(_path), exist_ok=True) + with open(_path, "w") as f: + f.write("# VLC config file content not available - placeholder\n") + return _path + + with open(_path, "wb") as f: + f.write(content) + + return _path + + +def get_default_video_player(env, config: dict): + """ Gets the default application for a category or file extension. + """ + + os_type = env.vm_platform + + if os_type == "Linux": + extensions = ['3gp', '3gp', '3gpp', '3gpp', '3gpp2', '3gpp2', 'avi', 'avi', 'divx', 'divx', 'dv', 'dv', 'fli', 'fli', 'flv', 'flv', 'mp2t', 'mp2t', 'mp4', 'mp4', 'mp4v-es', 'mp4v-es', 'mpeg', 'mpeg', 'mpeg-system', 'mpeg-system', 'msvideo', 'msvideo', 'ogg', 'ogg', 'quicktime', 'quicktime', 'vnd.divx', 'vnd.divx', 'vnd.mpegurl', 'vnd.mpegurl', 'vnd.rn-realvideo', 'vnd.rn-realvideo', 'webm', 'webm', 'x-anim', 'x-anim', 'x-avi', 'x-avi', 'x-flc', 'x-flc', 'x-fli', 'x-fli', 'x-flv', 'x-flv', 'x-m4v', 'x-m4v', 'x-matroska', 'x-matroska', 'x-mpeg', 'x-mpeg', 'x-mpeg-system', 'x-mpeg-system', 'x-mpeg2', 'x-mpeg2', 'x-ms-asf', 'x-ms-asf', 'x-ms-asf-plugin', 'x-ms-asf-plugin', 'x-ms-asx', 'x-ms-asx', 'x-ms-wm', 'x-ms-wm', 'x-ms-wmv', 'x-ms-wmv', 'x-ms-wmx', 'x-ms-wmx', 'x-ms-wvx', 'x-ms-wvx', 'x-msvideo', 'x-msvideo', 'x-nsv', 'x-nsv', 'x-ogm', 'x-ogm', 'x-ogm+ogg', 'x-theora', 'x-theora', 'x-theora+ogg', 'x-theora+ogg'] + apps = [] + for ext in extensions: + app = get_vm_command_line(env, {"command": ["xdg-mime", "query", "default", f"video/{ext}"]}) + if app: + apps.append(app) + if len(apps) == 0: + return 'unknown' + else: + return Counter(apps).most_common(1)[0][0] + elif os_type == "Darwin": + raise Exception("Unsupported operating system", os_type) + elif os_type == "Windows": + raise Exception("Unsupported operating system", os_type) + else: + raise Exception("Unsupported operating system", os_type) \ No newline at end of file diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/vscode.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/vscode.py new file mode 100644 index 000000000..bf8f3516c --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/getters/vscode.py @@ -0,0 +1,35 @@ +import logging +from typing import Any, Dict +import time +from .file import get_vm_file +from .replay import get_replay + +logger = logging.getLogger("desktopenv.getters.vscode") + + +def get_vscode_config(env, config: Dict[str, Any]) -> str: + os_type = env.vm_platform + vscode_extension_command = config["vscode_extension_command"] + + # fixme: depends on how we config and install the vscode in virtual machine, need to be aligned and double-checked + + if os_type == "MacOS": + trajectory = [ + {"type": "hotkey", "param": ["command", "shift", "p"]}, + {"type": "typewrite", "param": vscode_extension_command}, + {"type": "press", "param": "enter"} + ] + else: + trajectory = [ + {"type": "hotkey", "param": ["ctrl", "shift", "p"]}, + {"type": "typewrite", "param": vscode_extension_command}, + {"type": "press", "param": "enter"} + ] + + get_replay(env, trajectory) + time.sleep(1.0) + + return get_vm_file(env, { + "path": config["path"], + "dest": config["dest"] + }) diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/__init__.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/__init__.py new file mode 100644 index 000000000..79cd2488e --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/__init__.py @@ -0,0 +1,159 @@ +from .basic_os import ( + check_gnome_favorite_apps, + is_utc_0, + check_text_enlarged, + check_moved_jpgs, + is_in_vm_clickboard +) +from .chrome import ( + is_expected_tabs, + is_expected_bookmarks, + compare_pdfs, + compare_htmls, + compare_archive, + is_cookie_deleted, + is_shortcut_on_desktop, + check_font_size, + check_enabled_experiments, + check_history_deleted, + is_expected_search_query, + is_expected_active_tab, + is_expected_url_pattern_match, + is_added_to_steam_cart, + is_expected_installed_extensions, + compare_pdf_images +) +from .docs import ( + compare_font_names, + compare_subscript_contains, + has_page_numbers_in_footers, + compare_docx_lines, + evaluate_colored_words_in_tables, + check_highlighted_words, + evaluate_strike_through_last_paragraph, + evaluate_conversion, + evaluate_spacing, + check_italic_font_size_14, + evaluate_alignment, + get_unique_train_ids, + check_no_duplicates, + compare_init_lines, + find_default_font, + contains_page_break, + compare_docx_files, + compare_docx_tables, + compare_line_spacing, + compare_insert_equation, + compare_highlighted_text, + is_first_line_centered, + check_file_exists, + check_tabstops, + compare_contains_image, + compare_docx_files_and_ignore_new_lines, + compare_docx_images, + compare_image_text, + compare_references, + compare_unique_train_records +) +from .general import ( + check_csv, + check_accessibility_tree, + run_sqlite3, + check_json, + check_list, + exact_match, + match_in_list, + is_in_list, + fuzzy_match, + check_include_exclude, + check_direct_json_object, + compare_time_in_speedtest_results, + is_included_all_json_objects, + is_gold_text_included_in_pdf, + check_line_number, + file_contains, + compare_terminal_and_txt, + fuzzy_place_math, + compare_python_pure_text, + diff_text_file, + literal_match +) +from .gimp import ( + check_structure_sim_resized, + check_brightness_decrease_and_structure_sim, + check_contrast_increase_and_structure_sim, + check_saturation_increase_and_structure_sim, + check_image_size, + check_image_mirror, + check_palette_and_structure_sim, + check_textbox_on_leftside, + check_green_background, + check_file_exists_and_structure_sim, + check_triangle_position, + check_structure_sim, + check_config_status, + compare_image_list, + increase_saturation, + decrease_brightness, + check_file_exists, + compare_triangle_positions, + check_sharper, + check_image_file_size +) +from .libreoffice import check_libre_locale +from .others import compare_epub, check_mp3_meta +from .pdf import check_pdf_pages +from .slides import ( + check_presenter_console_disable, + check_image_stretch_and_center, + check_slide_numbers_color, + compare_pptx_files, + check_strikethrough, + check_slide_orientation_Portrait, + evaluate_presentation_fill_to_rgb_distance, + check_left_panel, + check_transition, + check_page_number_colors, + check_auto_saving_time +) +from .table import ( + compare_table, + compare_csv, + compare_conference_city_in_order +) +from .thunderbird import ( + check_thunderbird_prefs, + check_thunderbird_filter, + check_thunderbird_folder +) +from .vlc import ( + is_vlc_playing, + is_vlc_recordings_folder, + is_vlc_fullscreen, + compare_images, + compare_audios, + compare_videos, + check_qt_bgcone, + check_one_instance_when_started_from_file, + check_qt_minimal_view, + check_qt_max_volume, + check_qt_slider_colours, + check_global_key_play_pause +) +from .vscode import ( + compare_text_file, + compare_config, + compare_answer, + compare_result_files, + is_extension_installed, + check_json_settings, + check_json_keybindings, + check_python_file_by_test_suite, + check_python_file_by_gold_file, + check_html_background_image, + compare_zip_files +) + + +def infeasible(): + pass diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/basic_os.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/basic_os.py new file mode 100644 index 000000000..05e51ff64 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/basic_os.py @@ -0,0 +1,68 @@ +def check_gnome_favorite_apps(apps_str: str, rule): + # parse the string like "['thunderbird.desktop', 'vim.desktop', 'google-chrome.desktop']" + # to a list of strings + apps = eval(apps_str) + + expected_apps = rule["expected"] + + if len(apps) != len(expected_apps): + return 0 + + if set(apps) == set(expected_apps): + return 1 + else: + return 0 + + +def is_utc_0(timedatectl_output): + """ + Format as: + Local time: Thu 2024-01-25 12:56:06 WET + Universal time: Thu 2024-01-25 12:56:06 UTC + RTC time: Thu 2024-01-25 12:56:05 + Time zone: Atlantic/Faroe (WET, +0000) +System clock synchronized: yes + NTP service: inactive + RTC in local TZ: no + """ + + utc_line = timedatectl_output.split("\n")[3] + + if utc_line.endswith("+0000)"): + return 1 + else: + return 0 + + +def check_text_enlarged(scaling_factor_str): + scaling_factor = float(scaling_factor_str) + if scaling_factor > 1.0: + return 1 + else: + return 0 + + +def check_moved_jpgs(directory_list, rule): + expected_jpgs = rule["expected"] + moved_jpgs = [node['name'] for node in directory_list['children']] + + if len(moved_jpgs) != len(expected_jpgs): + return 0 + + if set(moved_jpgs) == set(expected_jpgs): + return 1 + else: + return 0 + + +def is_in_vm_clickboard(config, terminal_output): + print("terminal_output: ") + print(terminal_output) + print("config: ") + print(config) + expected_results = config["expected"] + # check if terminal_output has expected results + if not isinstance(expected_results, list): + return 1 if expected_results in terminal_output else 0 + else: + return 1 if all(result in terminal_output for result in expected_results) else 0 diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/chrome.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/chrome.py new file mode 100644 index 000000000..6bc61dae1 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/chrome.py @@ -0,0 +1,421 @@ +import logging +import os +import re +import shutil +from itertools import product +from typing import Any, Dict, List, Union + +import rapidfuzz.fuzz as fuzz +from bs4 import BeautifulSoup, Tag + +from skyrl_agent.tasks.osworld.desktop_env.evaluators.metrics.utils import are_lists_equal, compare_urls + +logger = logging.getLogger("desktopenv.metrics.chrome") + + +async def is_expected_active_tab(active_tab_info: Dict[str, str], rule: Dict[str, Any]) -> float: + """ + Checks if the expected active tab is open in Chrome. + """ + # Handle case where active_tab_info is a coroutine (async function result) + if hasattr(active_tab_info, '__await__'): + active_tab_info = await active_tab_info + + if not active_tab_info: + return 0. + + match_type = rule['type'] + + if match_type == "url": + expected_url = rule['url'] + if isinstance(active_tab_info, Dict): + actual_url = active_tab_info.get('url', None) + else: + actual_url = active_tab_info + print("expected_url: {}".format(expected_url)) + print("actual_url: {}".format(actual_url)) + return 1 if compare_urls(expected_url, actual_url) else 0 + else: + logger.error(f"Unknown type: {match_type}") + return 0 + + +# rules[expected] is a string-formatted regex +async def is_expected_url_pattern_match(result, rules) -> float: + """ + This function is used to search the expected pattern in the url using regex. + result is the return value of function "activte_tab_info" or return value of function "get_active_url_from_accessTree" + """ + # Handle case where result is a coroutine (async function result) + if hasattr(result, '__await__'): + result = await result + + if not result: + return 0. + + if type(result) == dict: + result_url = result["url"] + print("result url: {}".format(result_url)) + else: + result_url = result + # expect_regex = re.compile(rules["expected"]) + patterns = rules["expected"] + print("expected_regex: {}".format(patterns)) + for pattern in patterns: + match = re.search(pattern, result_url) + print(match) + if not match: + return 0. + return 1. + + +def is_expected_installed_extensions(installed_extensions, expected) -> float: + print("installed_extensions: ") + print(installed_extensions) + expected_extensions = expected["expected"] + + # whether the expected extensions are installed + set_expected_extensions = set(expected_extensions) + set_installed_extensions = set(installed_extensions) + + if set_expected_extensions.issubset(set_installed_extensions): + return 1. + else: + return 0. + + +async def is_expected_tabs(open_tabs: List[Dict[str, str]], rule: Dict[str, Any]) -> float: + """ + Checks if the expected tabs are open in Chrome. + """ + + # Handle case where open_tabs is a coroutine (async function result) + if hasattr(open_tabs, '__await__'): + open_tabs = await open_tabs + + match_type = rule['type'] + + if match_type == "url": + expected_urls = rule['urls'] + actual_urls = [tab['url'] for tab in open_tabs] + return 1 if are_lists_equal(expected_urls, actual_urls, compare_urls) else 0 + else: + logger.error(f"Unknown type: {match_type}") + return 0 + + +def is_expected_bookmarks(bookmarks: List[str], rule: Dict[str, Any]) -> float: + """ + Checks if the expected bookmarks are in Chrome. + """ + if not bookmarks: + return 0. + elif rule['type'] == "bookmark_bar_folders_names": + bookmark_bar_folders_names = [bookmark['name'] for bookmark in bookmarks['bookmark_bar']['children'] if + bookmark['type'] == 'folder'] + return 1. if set(bookmark_bar_folders_names) == set(rule['names']) else 0. + elif rule['type'] == "bookmark_bar_websites_urls": + bookmark_bar_websites_urls = [bookmark['url'] for bookmark in bookmarks['bookmark_bar']['children'] if + bookmark['type'] == 'url'] + return 1. if set(bookmark_bar_websites_urls) == set(rule['urls']) else 0. + elif rule['type'] == "liked_authors_websites_urls": + # Check if "liked authors" folder exists + liked_authors_folder = next((bookmark for bookmark in bookmarks['bookmark_bar']['children'] if + bookmark['type'] == 'folder' and bookmark['name'] == 'Liked Authors'), None) + if liked_authors_folder: + # Check if it contains the specified URLs + liked_authors_urls = [bookmark['url'] for bookmark in liked_authors_folder['children'] if + bookmark['type'] == 'url'] + + urls = rule['urls'] + + for idx, url in enumerate(urls): + if isinstance(url, str): + urls[idx] = [url] + + combinations = product(*urls) + + for combination in combinations: + if set(combination) == set(liked_authors_urls): + return 1. + return 0. + else: + return 0. + else: + raise TypeError(f"{rule['type']} not support yet!") + + +async def is_expected_search_query(active_tab_info: Dict[str, str], rules: Dict[str, Any]) -> float: + # Handle case where active_tab_info is a coroutine (async function result) + if hasattr(active_tab_info, '__await__'): + active_tab_info = await active_tab_info + + expected = rules['expect'] + pattern = expected['pattern'] + matched = re.search(pattern, active_tab_info['url']) + if matched: + return 1. + return 0. + + +def compare_pdfs(pdf1_path: Union[str, List[str]], pdf2_path: Union[str, List[str]]): + """ + Compare two PDF files. + """ + if type(pdf2_path) != list: + pdf1_path, pdf2_path = [pdf1_path], [pdf2_path] + + def extract_text_from_pdf(pdf_path): + """Extract text from each page of the PDF.""" + text = "" + with fitz.open(pdf_path) as pdf: + for page in pdf: + text += page.get_text() + return text.strip() + + score = 0. + for path1, path2 in zip(pdf1_path, pdf2_path): + try: + text1 = extract_text_from_pdf(path1) + text2 = extract_text_from_pdf(path2) + score += fuzz.ratio(text1, text2) / 100 + except Exception as e: + logger.info(f"[ERROR]: unexpected error occurred when comparing PDF files: {e}") + return score / len(pdf2_path) + + +import fitz +from PIL import Image +from borb.pdf import Document +from borb.pdf import PDF + +from pathlib import Path +import typing + + +def compare_pdf_images(pdf1_path: str, pdf2_path: str, **kwargs) -> float: + if not pdf1_path or not pdf2_path: + return 0. + + def extract_images_from_pdf(pdf_path): + pdf_document = fitz.open(pdf_path) + images = [] + + for page_number in range(pdf_document.page_count): + page = pdf_document[page_number] + pixmap = page.get_pixmap() + + img = Image.frombytes("RGB", [pixmap.width, pixmap.height], pixmap.samples) + + images.append(img) + + return images + + def fix_pdf(in_path: Path, out_path: Path) -> None: + doc: typing.Optional[Document] = None + with open(in_path, "rb") as fh: + doc = PDF.loads(fh) + with open(out_path, "wb") as fh: + PDF.dumps(fh, doc) + + fix_pdf(Path(pdf1_path), Path(pdf1_path)) + fix_pdf(Path(pdf2_path), Path(pdf2_path)) + + images1 = extract_images_from_pdf(pdf1_path) + images2 = extract_images_from_pdf(pdf2_path) + + if len(images1) != len(images2): + return 0. + + for img1, img2 in zip(images1, images2): + if img1.tobytes() != img2.tobytes(): + return 0. + + return 1. + + +def compare_archive(pred_path: str, gold_path: str, **kwargs) -> float: + """ + Compare two archives. Note that the files in the archives should be of the same type. + """ + file_path = kwargs.pop('file_path', '') + + if not pred_path: + return 0. + pred_folder = os.path.splitext(pred_path)[0] + '_pred' + gold_folder = os.path.splitext(gold_path)[0] + '_gold' + + if os.path.exists(pred_folder): # remove existing folder for new predictions + shutil.rmtree(pred_folder, ignore_errors=True) + os.makedirs(pred_folder) + shutil.unpack_archive(pred_path, pred_folder) + + if not os.path.exists(gold_folder): # use cache if exists + os.makedirs(gold_folder) + shutil.unpack_archive(gold_path, gold_folder) + + pred_files = sorted(os.listdir(os.path.join(pred_folder, file_path))) + gold_files = sorted(os.listdir(os.path.join(gold_folder, file_path))) + + if pred_files != gold_files: + return 0. + + def get_compare_function(): + file_type = kwargs.pop('file_type', 'text') + if file_type == 'text': + from .vscode import compare_text_file + return compare_text_file + elif file_type == 'pdf': + return compare_pdfs + elif file_type == 'docx': + from .docs import compare_docx_files + return compare_docx_files + elif file_type == 'ppt': + from .slides import compare_pptx_files + return compare_pptx_files + elif file_type == 'image': + from .vlc import compare_images + return compare_images + elif file_type == 'csv': + from .table import compare_csv + return compare_csv + elif file_type == 'table': + from .table import compare_table + return compare_table + elif file_type == 'audio': + from .vlc import compare_audios + return compare_audios + elif file_type == 'video': + from .vlc import compare_videos + return compare_videos + else: + raise ValueError('[ERROR]: not support file type: %s' % file_type) + + score = 0 + compare_function = get_compare_function() + for f1, f2 in zip(pred_files, gold_files): + fp1 = os.path.join(pred_folder, file_path, f1) + fp2 = os.path.join(gold_folder, file_path, f2) + score += compare_function(fp1, fp2, **kwargs) + return score / len(pred_files) + + +def compare_htmls(html_path1: str, html_path2: str) -> float: + """ + Compare two HTML files. + """ + with open(html_path1, 'r', encoding='utf-8') as inf: + soup1 = BeautifulSoup(inf, 'lxml') + with open(html_path2, 'r', encoding='utf-8') as inf: + soup2 = BeautifulSoup(inf, 'lxml') + + def compare_elements(elem1, elem2): + if not (isinstance(elem1, Tag) and isinstance(elem2, Tag)): + return elem1 == elem2 + if elem1.name != elem2.name: + return False + if elem1.text.strip() != elem2.text.strip(): + return False + if elem1.attrs != elem2.attrs: + return False + return True + + for elem1, elem2 in zip(soup1.recursiveChildGenerator(), soup2.recursiveChildGenerator()): + if not compare_elements(elem1, elem2): + return .0 + return 1. + + +def is_cookie_deleted(cookie_data, rule): + """ + Check if the cookie is deleted. + """ + + if rule['type'] == 'domains': + cookies_domains = [cookie[1] for cookie in cookie_data] + for domain in rule['domains']: + for cookies_domain in cookies_domains: + if compare_urls(domain, cookies_domain): + return 0. + return 1. + else: + raise TypeError(f"{rule['type']} not support yet!") + + +def is_shortcut_on_desktop(shortcuts: Dict[str, str], rule): + """ + Check if the shortcut is on the desktop. + """ + # fixme: if the name of the website changed in the future, this will not work; can be replaced with url + if rule['type'] == 'name': + for shortcut_path, shortcut_content in shortcuts.items(): + if "Name=" + rule['name'] + "\n" in shortcut_content: + return 1. + return 0. + elif rule['type'] == 'url': + raise TypeError(f"{rule['type']} not support yet!") + elif rule['type'] == 'id': + raise TypeError(f"{rule['type']} not support yet!") + else: + raise TypeError(f"{rule['type']} not support yet!") + + +def check_history_deleted(history_data, rule): + """ + Check if the history is deleted. + """ + + if rule['type'] == 'keywords': + history_domains = [history[0] for history in history_data] + for keyword in rule['keywords']: + for history_domain in history_domains: + if keyword in history_domain: + return 0. + return 1. + else: + raise TypeError(f"{rule['type']} not support yet!") + + +def check_enabled_experiments(enabled_experiments, rule): + """ + Check if the enabled experiments are as expected. + """ + enabled_experiments_names = [experiment.split("@")[0] for experiment in enabled_experiments] + + if rule['type'] == 'names': + return 1. if enabled_experiments_names == rule['names'] else 0. + else: + raise TypeError(f"{rule['type']} not support yet!") + + +def check_font_size(font_size, rule): + """ + Check if the font size is as expected. + """ + + default_font_size = font_size['default_font_size'] + if rule['type'] == 'value': + return 1. if default_font_size == rule['value'] else 0. + elif rule['type'] == 'range': + return 1. if rule['min'] < default_font_size < rule['max'] else 0. + else: + raise TypeError(f"{rule['type']} not support yet!") + + +async def is_added_to_steam_cart(active_tab_info, rule): + """ + Check if the item is added to the Steam cart. + """ + # Handle case where active_tab_info is a coroutine (async function result) + if hasattr(active_tab_info, '__await__'): + active_tab_info = await active_tab_info + + items = rule['items'] + + content = active_tab_info['content'] + + for item in items: + if item not in content: + return 0. + + return 1. diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/docs.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/docs.py new file mode 100644 index 000000000..908a38729 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/docs.py @@ -0,0 +1,961 @@ +import logging +import os +import re +import xml.etree.ElementTree as ET +import zipfile +from io import BytesIO +from typing import List, Dict, Any + +import easyocr +from PIL import Image +from docx import Document +from docx.enum.text import WD_PARAGRAPH_ALIGNMENT, WD_TAB_ALIGNMENT +from docx.shared import RGBColor +from odf.opendocument import load +from odf.text import P +from odf.text import Span +from rapidfuzz import fuzz +from skimage.color import deltaE_ciede2000 +from skimage.color import rgb2lab + +logger = logging.getLogger("desktopenv.metric.docs") + + +def find_default_font(config_file_path, rules): + """Find the default font in LibreOffice Writer.""" + default_font = None + expected_font = rules["font_name"] + + if not config_file_path: + return 0 + + try: + tree = ET.parse(config_file_path) + root = tree.getroot() + + # Define the XML namespace used in the file + namespace = {'oor': 'http://openoffice.org/2001/registry'} + + # Search for the node containing the default font setting for LibreOffice Writer + for elem in root.findall('.//item[@oor:path="/org.openoffice.Office.Writer/DefaultFont"]', namespace): + for prop in elem.findall('.//prop[@oor:name="Standard"]', namespace): + for value in prop.findall('value', namespace): + default_font = value.text + except Exception as e: + logger.error(f"Error: {e}") + + return 1 if default_font == expected_font else 0 + + +def contains_page_break(docx_file, rules): + if not docx_file: + return 0 + + try: + doc = Document(docx_file) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + try: + expected_page_break_count = rules["page_break_count"] + except Exception as e: + expected_page_break_count = None + + namespaces = {'w': 'http://schemas.openxmlformats.org/wordprocessingml/2006/main'} + + page_break_count = 0 + for paragraph in doc.paragraphs: + for run in paragraph.runs: + br_elems = run.element.findall('.//w:br', namespaces) + for br in br_elems: + if br is not None and '{http://schemas.openxmlformats.org/wordprocessingml/2006/main}type' in br.attrib and \ + br.attrib['{http://schemas.openxmlformats.org/wordprocessingml/2006/main}type'] == 'page': + page_break_count += 1 + + if expected_page_break_count is not None and page_break_count != expected_page_break_count: + return 0 + + if page_break_count > 0: + return 1 + else: + return 0 + +def compare_docx_files(file1, file2, **options): + ignore_blanks = options.get('ignore_blanks', True) + ignore_case = options.get('ignore_case', False) + ignore_order = options.get('ignore_order', False) + content_only = options.get('content_only', False) + + if not file1 or not file2: + return 0 + + def get_paragraph_texts_odt(document): + paragraphs = document.getElementsByType(P) + paragraph_texts = [] + for paragraph in paragraphs: + text_parts = [] + for node in paragraph.childNodes: + if node.nodeType == node.TEXT_NODE: + text_parts.append(node.data) + elif node.nodeType == node.ELEMENT_NODE and node.tagName == 'text:span': + # Assuming direct text content in , for simplicity + for child in node.childNodes: + if child.nodeType == child.TEXT_NODE: + text_parts.append(child.data) + paragraph_texts.append(''.join(text_parts)) + return paragraph_texts + + # Determine file types and load documents + if file1.endswith('.docx') and file2.endswith('.docx'): + try: + doc1 = Document(file1) + doc2 = Document(file2) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + doc1_paragraphs = [p.text for p in doc1.paragraphs] + doc2_paragraphs = [p.text for p in doc2.paragraphs] + if ignore_order: + doc1_paragraphs = sorted(doc1_paragraphs) + doc2_paragraphs = sorted(doc2_paragraphs) + elif file1.endswith('.odt') and file2.endswith('.odt'): + try: + doc1 = load(file1) + doc2 = load(file2) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + doc1_paragraphs = get_paragraph_texts_odt(doc1) + doc2_paragraphs = get_paragraph_texts_odt(doc2) + if ignore_order: + doc1_paragraphs = sorted(doc1_paragraphs) + doc2_paragraphs = sorted(doc2_paragraphs) + else: + # Unsupported file types or mismatch + print("Unsupported file types or mismatch between file types.") + return 0 + + if content_only: + # Compare the content of the documents + text1 = re.sub(r'\s+', ' ', '\n'.join(doc1_paragraphs)).strip() + text2 = re.sub(r'\s+', ' ', '\n'.join(doc2_paragraphs)).strip() + if ignore_case: + text1, text2 = text1.lower(), text2.lower() + similarity = fuzz.ratio(text1, text2) / 100.0 + return similarity + + # Process and compare documents + if ignore_blanks: + text1 = re.sub(r'\s+', ' ', '\n'.join(doc1_paragraphs)).strip() + text2 = re.sub(r'\s+', ' ', '\n'.join(doc2_paragraphs)).strip() + if ignore_case: + text1, text2 = text1.lower(), text2.lower() + if text1 != text2: + return 0 + else: + print("ignore_blanks=false") + if len(doc1_paragraphs) != len(doc2_paragraphs): + print(doc1_paragraphs) + print(doc2_paragraphs) + print(len(doc1_paragraphs)) + print(len(doc2_paragraphs)) + return 0 + print("in compare") + # Compare each paragraph + for p1, p2 in zip(doc1_paragraphs, doc2_paragraphs): + if ignore_case: + p1, p2 = p1.lower(), p2.lower() + if p1 != p2: + # show the difference + print("=== First Paragraph ===") + print(f"\033[92m{repr(p1)}\033[0m") # Green color for p1, repr() shows hidden chars + print("=== Second Paragraph ===") + print(f"\033[91m{repr(p2)}\033[0m") # Red color for p2, repr() shows hidden chars + print("=" * 50) # Clear boundary + return 0 + + return 1 + + +def compare_init_lines(file1, file2): + if not file1 or not file2: + return 0 + + try: + doc1 = Document(file1) + doc2 = Document(file2) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + doc1_paragraphs = [p.text for p in doc1.paragraphs] + doc2_paragraphs = [p.text for p in doc2.paragraphs] + + # Compare each paragraph + for p1, p2 in zip(doc1_paragraphs, doc2_paragraphs): + if p1 != p2: + # print(p1) + # print(p2) + return 0 + + return 1 + + +def compare_docx_tables(docx_file1, docx_file2): + if not docx_file1 or not docx_file2: + return 0 + + try: + doc1 = Document(docx_file1) + doc2 = Document(docx_file2) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + # get list of tables in docx + tables1 = doc1.tables + tables2 = doc2.tables + + if len(tables1) != len(tables2): + return 0 + + # Compare each table content + for table1, table2 in zip(tables1, tables2): + + if len(table1.rows) != len(table2.rows) or len(table1.columns) != len(table2.columns): + return 0 + + # Compare each cell + for i in range(len(table1.rows)): + for j in range(len(table1.columns)): + if table1.cell(i, j).text.strip() != table2.cell(i, j).text.strip(): + return 0 + + return 1 + + +def compare_docx_images(docx_file1, docx_file2): + if not docx_file1 or not docx_file2: + return 0 + + try: + doc1 = Document(docx_file1) + doc2 = Document(docx_file2) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + def extract_images(doc): + images = [] + for rel in doc.part.rels.values(): + if "image" in rel.reltype: + img_data = rel.target_part.blob + images.append(BytesIO(img_data)) + return images + + images1 = extract_images(doc1) + images2 = extract_images(doc2) + if len(images1) != len(images2): + return 0 + for img1, img2 in zip(images1, images2): + if Image.open(img1).tobytes() != Image.open(img2).tobytes(): + return 0 + return 1 + + +def compare_image_text(image_path, rule): + if not image_path: + return 0 + reader = easyocr.Reader(['en']) + result = reader.readtext(image_path) + extracted_text = ' '.join([entry[1] for entry in result]) + if rule['type'] == 'text': + return 1 if rule['text'] in extracted_text else 0 + else: + raise ValueError("Unsupported rule type") + + +def compare_line_spacing(docx_file1, docx_file2): + if not docx_file1 or not docx_file2: + return 0 + + if not compare_docx_files(docx_file1, docx_file2): + return 0 + + try: + doc1 = Document(docx_file1) + doc2 = Document(docx_file2) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + if len(doc1.paragraphs) != len(doc2.paragraphs): + return 0 + + # Compare each paragraph line spacing + for para1, para2 in zip(doc1.paragraphs, doc2.paragraphs): + + spacing1 = para1.paragraph_format.line_spacing + spacing2 = para2.paragraph_format.line_spacing + + if spacing1 != spacing2: + return 0 + + return 1 + + +def compare_insert_equation(docx_file1, docx_file2): + if not docx_file1 or not docx_file2: + return 0 + + if not compare_docx_files(docx_file1, docx_file2): + return 0 + + try: + doc1 = Document(docx_file1) + doc2 = Document(docx_file2) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + # Compare each paragraph if it contains equation + for para1, para2 in zip(doc1.paragraphs, doc2.paragraphs): + for run1, run2 in zip(para1.runs, para2.runs): + if run1.element.xpath('.//w:object') and run2.element.xpath('.//w:object'): + return 1 + return 0 + + +def compare_font_names(docx_file, rules: List[Dict[str, Any]]): + if not docx_file: + return 0 + + try: + doc = Document(docx_file) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + expected_font = rules["font_name"] + + for paragraph in doc.paragraphs: + for run in paragraph.runs: + font_name = run.font.name + if font_name != expected_font: + return 0 + return 1 + + +def compare_subscript_contains(docx_file1, docx_file2): + if not docx_file1 or not docx_file2: + return 0 + + try: + doc1 = Document(docx_file1) + doc2 = Document(docx_file2) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + for para1, para2 in zip(doc1.paragraphs, doc2.paragraphs): + for run1, run2 in zip(para1.runs, para2.runs): + # check if two paras both contain subscript + if run1.font.subscript and run2.font.subscript: + return 1 + return 0 + + +def has_page_numbers_in_footers(docx_file): + if not docx_file: + return 0 + + try: + doc = Document(docx_file) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + for section in doc.sections: + footer = section.footer + if footer is None: + return 0 + footer_text = footer.paragraphs[0].text if footer.paragraphs else '' + if not any(char.isdigit() for char in footer_text): + # if no digit in footer, then no page number + return 0 + return 1 + + +def is_first_line_centered(docx_file): + if not docx_file: + return 0 + + try: + doc = Document(docx_file) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + first_paragraph = doc.paragraphs[0] + + # check if the first line is center justified + return 1 if first_paragraph.paragraph_format.alignment == WD_PARAGRAPH_ALIGNMENT.CENTER else 0 + + +def check_file_exists(directory, filename): + if not directory or not filename: + return 0 + file_path = os.path.join(directory, filename) + return 1 if os.path.isfile(file_path) else 0 + + +def check_tabstops(docx_file1, docx_file2, **kwargs) -> float: + if not docx_file1 or not docx_file2: + return .0 + + try: + doc1: Document = Document(docx_file1) + doc2: Document = Document(docx_file2) + except Exception as e: + logger.error(f"Error: {e}") + return .0 + + para1 = [p for p in doc1.paragraphs if p.text.strip()] + para2 = [p for p in doc2.paragraphs if p.text.strip()] + if len(para1) != len(para2): return .0 + + if kwargs.get('word_number_split_by_tabstop', None) is not None: + number = kwargs['word_number_split_by_tabstop'] + index = kwargs.get('index', 0) + for p1 in para1: + splits = p1.text.split('\t') + if len(splits) == 0: return .0 + words = list(filter(lambda x: x.strip(), re.split(r'\s', splits[index]))) + if len(words) != number: return .0 + + section = doc2.sections[0] + paragraph_width = section.page_width - section.left_margin - section.right_margin + ignore_tabs = lambda x: x.alignment == WD_TAB_ALIGNMENT.CLEAR or ( + x.alignment == WD_TAB_ALIGNMENT.LEFT and x.position == 0) + minus = .0 + for p1, p2 in zip(para1, para2): + # filter CLEAR tabstop and default left-0 tabstop + tabs1 = [tst for tst in p1.paragraph_format.tab_stops if not ignore_tabs(tst)] + tabs2 = [tst for tst in p2.paragraph_format.tab_stops if not ignore_tabs(tst)] + if len(tabs1) != len(tabs2): return .0 + difference = .0 + for t1, t2 in zip(tabs1, tabs2): + if t1.alignment != t2.alignment: return .0 + difference += abs(t1.position - t2.position) + minus += difference / paragraph_width + score = 1 - (minus / len(para1)) + return score + + +def compare_contains_image(docx_file1, docx_file2): + if not docx_file1 or not docx_file2: + return 0 + + try: + doc1 = Document(docx_file1) + doc2 = Document(docx_file2) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + for para1, para2 in zip(doc1.paragraphs, doc2.paragraphs): + for run1, run2 in zip(para1.runs, para2.runs): + if ('graphicData' in run1._element.xml and 'graphicData' not in run2._element.xml) or ( + 'graphicData' not in run1._element.xml and 'graphicData' in run2._element.xml): + return 0 + return 1 + + +def evaluate_colored_words_in_tables(file_path1, file_path2, **kwargs): + if not file_path1 or not file_path2: + return 0 + + if not compare_docx_files(file_path1, file_path2): + return 0 + + try: + document = Document(file_path1) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + threshold = kwargs.get('threshold', 3.5) + + def _calculate_color_difference(rgb1, rgb2): + srgb1 = [rgb1[0] / 255.0, rgb1[1] / 255.0, rgb1[2] / 255.0] + srgb2 = [rgb2[0] / 255.0, rgb2[1] / 255.0, rgb2[2] / 255.0] + lab1, lab2 = rgb2lab(srgb1), rgb2lab(srgb2) + delta_e = deltaE_ciede2000(lab1, lab2) + return delta_e + + for table in document.tables: + # Iterate through rows and cells in the table + for row in table.rows: + for cell in row.cells: + for paragraph in cell.paragraphs: + for run in paragraph.runs: + word = run.text + if word: + first_letter = word[0].lower() + + if first_letter in 'aeiou' and _calculate_color_difference(run.font.color.rgb, + RGBColor(255, 0, 0)) > threshold: + return 0 # Vowel-colored words should be red + elif first_letter not in 'aeiou' and _calculate_color_difference(run.font.color.rgb, + RGBColor(0, 0, + 255)) > threshold: + return 0 # Non-vowel-colored words should be blue + + return 1 # All words in tables are correctly colored + + +def check_highlighted_words(file_path1, file_path2): + if not file_path1 or not file_path2: + return 0 + + if not compare_docx_files(file_path1, file_path2): + return 0 + + doc = load(file_path1) + highlighted = False + + for span in doc.getElementsByType(Span): + style_name = span.getAttribute('stylename') + if style_name: + for automatic_style in doc.automaticstyles.childNodes: + if automatic_style.getAttribute('name') == style_name: + for property in automatic_style.childNodes: + if property.getAttribute('backgroundcolor') == '#ffff00': + highlighted = True + break + if highlighted: + break + + return 0 if highlighted else 1 + + +def evaluate_strike_through_last_paragraph(file_path1, file_path2): + if not file_path1 or not file_path2: + return 0 + + if not compare_docx_files(file_path1, file_path2): + return 0 + + try: + document = Document(file_path1) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + # Get the last paragraph + last_paragraph = document.paragraphs[-1] + + # Check if any run in the last paragraph has strike-through formatting + for run in last_paragraph.runs: + if not run.font.strike: + return 0 # At least one word does not have strike-through formatting + + return 1 # All words in the last paragraph have strike-through formatting + + +def evaluate_conversion(file_path): + if not file_path: + return 0 + + try: + document = Document(file_path) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + for table in document.tables: + for row in table.rows: + for cell in row.cells: + for paragraph in cell.paragraphs: + for run in paragraph.runs: + if run.text.isupper(): + return 0 # Uppercase text should be converted to lowercase + + for paragraph in document.paragraphs: + for run in paragraph.runs: + if run.text.isupper(): + return 0 # Uppercase text should be converted to lowercase + + return 1 # All uppercase text has been successfully converted + + +def evaluate_spacing(file_path): + if not file_path: + return 0 + + try: + document = Document(file_path) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + # Check line spacing for introduction, body, and conclusion + introduction_spacing = document.paragraphs[0].paragraph_format.line_spacing + body_spacing = document.paragraphs[1].paragraph_format.line_spacing + conclusion_spacing = document.paragraphs[2].paragraph_format.line_spacing + if (introduction_spacing == 1.0 and body_spacing == 2.0 and conclusion_spacing == 1.5): + return 1 + else: + return 0 + + +def check_italic_font_size_14(path1, path2): + if not path1 or not path2: + return 0 + + if not compare_docx_files(path1, path2): + return 0 + + try: + document = Document(path1) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + for paragraph in document.paragraphs: + for run in paragraph.runs: + if run.italic: + # Check if font size is 14 + if run.font.size is None or run.font.size.pt != 14: + return 0 + return 1 + + +def evaluate_alignment(docx_path): + if not docx_path: + return 0 + + # Load the document + try: + doc = Document(docx_path) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + # Iterate through each paragraph in the document + for para in doc.paragraphs: + # Split the paragraph into individual sentences + sentences = para.text.split('.') + + for sentence in sentences: + # Split the sentence into words + words = sentence.strip().split() + + # Check if the sentence has at least three words + if len(words) < 3: + continue # Skip sentences with less than three words + + # The first three words should be separated from the rest + first_part = ' '.join(words[:3]) + second_part = ' '.join(words[3:]) + + # Check if the sentence structure matches the pattern: first part + large space/tab + second part + if not (first_part in sentence and second_part in sentence and sentence.find(first_part) < sentence.find( + second_part)): + return 0 # The sentence does not meet the alignment criteria + + return 1 # All sentences meet the alignment criteria + + +def get_unique_train_ids(initial_file): # fixed standard + if not initial_file: + return set(), 0 + + try: + doc = Document(initial_file) + except Exception as e: + logger.error(f"Error: {e}") + return set(), 0 + + train_ids = set() + processed_lines = 0 + + for para in doc.paragraphs: + line_parts = para.text.split(',') + if len(line_parts) == 4: + train_id = line_parts[1].strip() + if train_id not in train_ids: + train_ids.add(train_id) + processed_lines += 1 + + return train_ids, processed_lines + + +def check_no_duplicates(initial_file, processed_file): + if not initial_file or not processed_file: + return 0 + + # Open the document + train_ids_ini, ini_lines = get_unique_train_ids(initial_file) + + try: + doc_processed = Document(processed_file) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + train_ids_pro = set() + processed_lines = 0 # Counter for valid lines processed + + # processed + for para in doc_processed.paragraphs: + # Each line has the format: time_HH:MM:SS, train_id, station_id, platform_no + line_parts = para.text.split(',') + # Ensure the line has the correct format + if len(line_parts) == 4: + train_id = line_parts[1].strip() + # If train_id is already in the set, it's a duplicate + if train_id in train_ids_pro: + return 0 # Duplicate found + train_ids_pro.add(train_id) + processed_lines += 1 # Increment valid lines counter + + if train_ids_pro != train_ids_ini or processed_lines != ini_lines: + return 0 + + # No duplicates found and at least one valid line was processed + return 1 + + +def compare_docx_lines(file1, file2): + if not file1 or not file2: + return 0 + + # Read the text of the document, line by line + try: + doc1 = Document(file1) + doc2 = Document(file2) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + doc1_lines = [p.text.strip() for p in doc1.paragraphs if p.text.strip()] + doc2_lines = [p.text.strip() for p in doc2.paragraphs if p.text.strip()] + # print(doc1_lines) + # print(doc2_lines) + + # Convert the list of lines to sets and compare + if set(doc1_lines) == set(doc2_lines): + return 1 + else: + return 0 + + +def compare_docx_files_and_ignore_new_lines(file1, file2, **options): + ignore_blanks = options.get('ignore_blanks', True) + + if not file1 or not file2: + return 0 + + # Determine file types and load documents + if file1.endswith('.docx') and file2.endswith('.docx'): + try: + doc1 = Document(file1) + doc2 = Document(file2) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + # First, delete all the blank in paragraphs + doc1 = [p for p in doc1.paragraphs if p.text != ''] + doc2 = [p for p in doc2.paragraphs if p.text != ''] + doc1_paragraphs = [p.text for p in doc1] + doc2_paragraphs = [p.text for p in doc2] + else: + # Unsupported file types or mismatch + print("Unsupported file types or mismatch between file types.") + return 0 + + # Process and compare documents + if ignore_blanks: + text1 = re.sub(r'\s+', ' ', '\n'.join(doc1_paragraphs)).strip() + text2 = re.sub(r'\s+', ' ', '\n'.join(doc2_paragraphs)).strip() + if text1 != text2: + return 0 + else: + if len(doc1_paragraphs) != len(doc2_paragraphs): + return 0 + # Compare each paragraph + for p1, p2 in zip(doc1_paragraphs, doc2_paragraphs): + if p1 != p2: + return 0 + return 1 + + +# Docx file saved in the ubuntu cannot use this function to compare highlight, don't know why, deprecated +def compare_highlighted_text(file1, file2): + if not file1 or not file2: + return 0 + + def extract_highlighted_text(file_path): + highlighted_texts = [] + + # Open the .docx file as a zip file and read the document.xml + with zipfile.ZipFile(file_path, 'r') as docx: + with docx.open('word/document.xml') as document_xml: + tree = ET.parse(document_xml) + root = tree.getroot() + + # Define the namespaces + namespaces = { + 'w': 'http://schemas.openxmlformats.org/wordprocessingml/2006/main', + } + + # Find all runs with highlight property + for run in root.findall('.//w:r', namespaces): + highlight = run.find('.//w:highlight', namespaces) + if highlight is not None and highlight.get( + '{http://schemas.openxmlformats.org/wordprocessingml/2006/main}val') != 'none': + text = run.find('.//w:t', namespaces) + if text is not None: + highlighted_texts.append(text.text) + + return highlighted_texts + + # Read the highlighted text from both documents + doc1_highlighted = extract_highlighted_text(file1) + doc2_highlighted = extract_highlighted_text(file2) + + # Compare the sets of highlighted text to check if they are the same + if set(doc1_highlighted) == set(doc2_highlighted): + return 1 + else: + return 0 + + +def compare_references(file1, file2, **options): + if not file1 or not file2: + return 0 + + reference_indicator = options.get('reference_indicator', 'References') + reference_base_result = options.get('reference_base_result', 0.5) + + # Determine file types and load documents + if file1.endswith('.docx') and file2.endswith('.docx'): + try: + doc1 = Document(file1) + doc2 = Document(file2) + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + doc1_paragraphs = [p.text for p in doc1.paragraphs] + doc2_paragraphs = [p.text for p in doc2.paragraphs] + else: + # Unsupported file types or mismatch + print("Unsupported file types or mismatch between file types.") + return 0 + + # Find the references section in the paragraphs, find the idx of the last reference_indicator in the paragraph list + ref1_idx = doc1_paragraphs.index(reference_indicator) if reference_indicator in doc1_paragraphs else -1 + ref2_idx = doc2_paragraphs.index(reference_indicator) if reference_indicator in doc2_paragraphs else -1 + + if ref1_idx == -1 and ref2_idx == -1: + return 1 + + if ref1_idx == -1 or ref2_idx == -1: + return 0 + + # split the reference section into reference items, and remove the empty string items + ref1 = [p for p in doc1_paragraphs[ref1_idx + 1:] if p.strip()] + ref2 = [p for p in doc2_paragraphs[ref2_idx + 1:] if p.strip()] + + # Compare the references + + if len(ref1) != len(ref2): + return 0 + + total_similarity = 0 + for r1, r2 in zip(ref1, ref2): + # fuzzy match the references + similarity = fuzz.ratio(r1, r2) / 100.0 + total_similarity += similarity + + result = total_similarity / len(ref1) + + epsilon = 0.01 + + if result >= reference_base_result + epsilon: + return (result - reference_base_result) / (1 - reference_base_result) + else: + return 0 + + +def compare_unique_train_records(processed_file, expected_files, **kwargs): + """ + Compares the processed file with a list of expected files containing the + gold standard and the initial document. + expected_files[0] should be the gold standard file. + expected_files[1] should be the initial file. + """ + # Debug logging to understand what we're actually receiving + logger.info(f"DEBUG: processed_file type: {type(processed_file)}, value: {processed_file}") + logger.info(f"DEBUG: expected_files type: {type(expected_files)}, value: {expected_files}") + logger.info(f"DEBUG: kwargs: {kwargs}") + + if not processed_file or not isinstance(expected_files, list) or len(expected_files) < 2: + logger.error("Invalid arguments: processed_file and a list of 2 expected_files are required.") + return 0 + + gold_file = expected_files[0] + initial_file = expected_files[1] + + if not gold_file or not initial_file: + logger.error("Gold file or initial file path is missing from expected_files list.") + return 0 + + # Helper function to get lines and IDs from a file + def get_lines_and_ids_from_file(file_path): + try: + doc = Document(file_path) + lines = [p.text.strip() for p in doc.paragraphs if p.text.strip()] + train_ids = [line.split(',')[1].strip() for line in lines if len(line.split(',')) == 4] + return lines, train_ids + except Exception as e: + logger.error(f"Error opening or parsing file {file_path}: {e}") + return None, None + + # Get data from all three files + processed_lines, processed_train_ids = get_lines_and_ids_from_file(processed_file) + if processed_lines is None: return 0 + + gold_lines, gold_train_ids = get_lines_and_ids_from_file(gold_file) + if gold_lines is None: return 0 + + initial_lines, _ = get_lines_and_ids_from_file(initial_file) + if initial_lines is None: return 0 + initial_lines_set = set(initial_lines) + + # 1. Subset Check: Ensure every processed line was in the initial file + if not set(processed_lines).issubset(initial_lines_set): + logger.error("Processed file contains lines not present in the initial file.") + logger.error(f"Extra lines: {set(processed_lines) - initial_lines_set}") + return 0 + + # 2. Uniqueness Check: Check for duplicates within the processed file + if len(processed_train_ids) != len(set(processed_train_ids)): + logger.error("Duplicate train_ids found in the processed file.") + return 0 + + # 3. Correctness Check: Compare the set of train_ids + if set(processed_train_ids) != set(gold_train_ids): + logger.error("Set of train_ids does not match between processed file and gold file.") + return 0 + + # 4. Line count check + if len(processed_lines) != len(gold_lines): + logger.error("Number of lines does not match between processed file and gold file.") + return 0 + + return 1 diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/general.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/general.py new file mode 100644 index 000000000..a401b740c --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/general.py @@ -0,0 +1,506 @@ +import csv +import datetime +import difflib +import functools +import json +import logging +import operator +import os +import re +import sqlite3 +from numbers import Number +from typing import Callable, Any, Union +from typing import Dict, List, Pattern + +import lxml.etree +import pdfplumber +import yaml +from docx import Document +from lxml.cssselect import CSSSelector +from lxml.etree import _Element +from rapidfuzz import fuzz + +from desktop_env.evaluators.metrics.utils import _match_record, _match_value_to_rule + +logger = logging.getLogger("desktopenv.metric.general") + + +def check_include_exclude(result: str, rules: Dict[str, List[str]]) -> float: + if result is None: + return 0. + + print(result, rules) + include = rules.get("include", []) + exclude = rules.get("exclude", []) + if all(r in result for r in include) and all(r not in result for r in exclude): + return 1. + else: + return 0. + + +def exact_match(result, rules) -> float: + expect = rules["expected"] + print(result, expect) + + if result == expect: + return 1. + else: + return 0. + +def match_in_list(result, rules) -> float: + expect = rules["expected"] + print(result, expect) + + if result in expect: + return 1. + else: + return 0. + +def literal_match(result: Any, expected: Any, **options) -> float: + literal_type = options.get('type', 'str') + if literal_type == 'str': + ignore_case = options.get('ignore_case', False) + score = str(result) == str(expected) if not ignore_case else str(result).lower() == str(expected).lower() + return float(score) + elif literal_type == 'list': + if type(result) not in [list, tuple] or type(expected) not in [list, tuple] or len(result) != len(expected): + return .0 + ignore_case = options.get('ignore_case', False) + result = [str(s) for s in result] if not ignore_case else [str(s).lower() for s in result] + expected = [str(s) for s in expected] if not ignore_case else [str(s).lower() for s in expected] + return float(result == expected) + else: + raise NotImplementedError(f"Type {type} not supported") + + +def is_in_list(result, rules) -> float: + expect = rules["expected"] + if expect in result: + return 1. + else: + return 0. + + +def diff_text_file(result: str, expect: str) -> float: + if result is None: + return 0. + + with open(result) as f: + result_lines: List[str] = f.read().splitlines() + with open(expect) as f: + expected_lines: List[str] = f.read().splitlines() + return difflib.SequenceMatcher(a=result_lines, b=expected_lines).ratio() + + +def fuzzy_match(result, rules) -> float: + expect = rules["expected"] + + return fuzz.ratio(result, expect) / 100. + + +def fuzzy_place_math(result_file_path, rules) -> float: + if result_file_path is None: + return 0. + expect = rules["expected"] # a list of possible answers + # read list.docx, and get all texts out, overlook blank lines, remove blanks before and after each line + doc = Document(result_file_path) + words_list = [] + for para in doc.paragraphs: + words_list.extend(para.text.split()) + fuzzy_score_list = [] + for word in words_list: + max_score = 0 + for ans in expect: + score = fuzz.ratio(word, ans) / 100 + max_score = max(max_score, score) + fuzzy_score_list.append(max_score) + if len(fuzzy_score_list) != 3: + return 0. + return sum(fuzzy_score_list) / 3 + + +def check_csv(result: str, rules: Dict[str, List[Dict[str, str]]]) -> float: + """ + Args: + result (str): path to csv file + rules (Dict[str, List[Dict[str, str]]]): dict like + { + "expect": [{key: value}] + "unexpect": [{key: value}] + } + + Returns: + float + """ + + if result is None: + return 0. + + expect_metrics = [False] * len(rules.get("expect", [])) + unexpect_metric = True + with open(result) as f: + reader = csv.DictReader(f) + + for rcd in reader: + for i, r in enumerate(rules.get("expect", [])): + expect_metrics[i] = expect_metrics[i] or _match_record(r, rcd) + unexpect_metric = unexpect_metric and not any(_match_record(r, rcd) for r in rules.get("unexpect", [])) + return float(all(expect_metrics) and unexpect_metric) + + +def check_list(result: str, rules: Dict[str, List[str]]) -> float: + """ + Args: + result (str): path to list file + rules (Dict[str, List[str]]): dict like + { + "expect": list of str as regexes + "unexpect": list of str as regexes + } + + Returns: + float + """ + + if result is None: + return 0. + + expect_patterns: List[Pattern[str]] = [re.compile(ptt) for ptt in rules.get("expect", [])] + unexpect_patterns: List[Pattern[str]] = [re.compile(ptt) for ptt in rules.get("unexpect", [])] + + expect_metrics = [False] * len(expect_patterns) + unexpect_metric = True + with open(result) as f: + for l in f: + for i, r in enumerate(expect_patterns): + expect_metrics[i] = expect_metrics[i] or (r.search(l) is not None) + unexpect_metric = unexpect_metric and all(r.search(l) is None for r in unexpect_patterns) + return float(all(expect_metrics) and unexpect_metric) + + +_accessibility_ns_map = {"st": "uri:deskat:state.at-spi.gnome.org" + , "attr": "uri:deskat:attributes.at-spi.gnome.org" + , "cp": "uri:deskat:component.at-spi.gnome.org" + , "doc": "uri:deskat:document.at-spi.gnome.org" + , "docattr": "uri:deskat:attributes.document.at-spi.gnome.org" + , "txt": "uri:deskat:text.at-spi.gnome.org" + , "val": "uri:deskat:value.at-spi.gnome.org" + , "act": "uri:deskat:action.at-spi.gnome.org" + } + + +def check_accessibility_tree(result: str, rules: List[Dict[str, Any]]) -> float: + """ + Args: + result (str): XML of GNOME Accessibility Tree + rules (List[Dict[str, Any]]): list of dict like + { + "selectors": list of str as CSS selectors, will be connected by ", " + to form a composite selector. Only one from `selectors` and + `xpath` is needed. If both are present, `xpath` takes the + priority. + "xpath": str as xpath. Only one from `selectors` and `xpath` is + needed. If both are present, `xpath` takes the priority. + "text": str as the expected text content of the selected element. + "exact": bool specifying whether exact match or fuzzy match should + be performed. defaults to True. + } + + Returns: + float + """ + + at: _Element = lxml.etree.fromstring(result) + total_match_score = 1. + for r in rules: + if "xpath" in r: + elements: List[_Element] = at.xpath(r["xpath"], namespaces=_accessibility_ns_map) + elif "selectors" in r: + selector = CSSSelector(", ".join(r["selectors"]), namespaces=_accessibility_ns_map) + elements: List[_Element] = selector(at) + else: + raise ValueError("At least one of xpath and selectors is required") + + if len(elements) == 0: + logger.info("No elements: %s", r["xpath"] if "xpath" in r else r["selectors"]) + return 0. + + if "text" in r: + match_func: Callable[[str], Number] = functools.partial(operator.eq if r["exact"] \ + else (lambda a, b: fuzz.ratio(a, b) / 100.) + , r["text"] + ) + match_score: Number = 0 + for elm in elements: + match_score = max(match_score, match_func(elm.text or None)) + else: + match_score = 1. + total_match_score *= match_score + + return float(total_match_score) + + +# def check_existence(result: str, *args) -> float: +# return 1. - (result is None) + +def run_sqlite3(result: str, rules: Dict[str, Any]) -> float: + connection: sqlite3.Connection = sqlite3.connect(result) + cursor: sqlite3.Cursor = connection.execute(rules["sql"]) + return float(cursor.fetchone()[0] or 0) + + +def check_json(result: str, rules: Dict[str, List[Dict[str, Union[List[str], str]]]], is_yaml: bool = False) -> float: + """ + Args: + result (str): path to json file + rules (Dict[str, List[Dict[str, Union[List[str], str]]]]): dict like + { + "expect": [ + { + "key": list of str + "method": str + "ref": something + } + ], + "unexpect": float: + """ + One of the most commonly used function to evalute. + Compare two json objects directly. + """ + if isinstance(result, str): + # remove blanks before and after result + result = result.strip() + # replace all ' with " + result = result.replace("'", '"') + # load json object + result = json.loads(result) + if result is None: + return 0. + try: + expect_in_result = rules.get("expect_in_result", False) + if not expect_in_result: + expected_json = rules["expected"] + for key in expected_json.keys(): + expected_value = expected_json.get(key) + if expected_value != result.get(key): + return 0. + return 1.0 + else: + expected_json = rules["expected"] + + for key in expected_json.keys(): + if isinstance(expected_json.get(key), list): + flag = 0 + expected_value_list = expected_json.get(key) + for each_expected_value in expected_value_list: + if isinstance(result.get(key), list) and each_expected_value in result.get(key): + flag = 1 + break + if flag == 0: + return 0. + elif isinstance(expected_json.get(key), str): + if expected_json.get(key) not in result.get(key): + return 0. + else: + logger.debug("check_direct_json_object: expected value type not supported") + return 0. + return 1.0 + except: + logger.debug("check_direct_json_object: result is not a valid json object") + return 0. + + +def compare_time_in_speedtest_results(speedtest_result_path, time_diff): + if not speedtest_result_path: + return 0 + + # open the speedtest results file(csv) + date_col = None + try: + with open(speedtest_result_path, 'r') as f: + for i, line in enumerate(f): + if i == 1: + date = line.split(',')[1] + break + now_date_time = datetime.datetime.now().strftime('%H:%M') + date_time = date[-5:] + # compare the date time with the current date time, if time diff less than time_diff para, then return true + if not abs((datetime.datetime.strptime(date_time, '%H:%M') - datetime.datetime.strptime(now_date_time, + '%H:%M')).total_seconds()) / 60 < int( + time_diff): + return 0 + return 1 + except: + logger.debug("compare_time_in_speedtest_results: file not found or not readable") + return 0 + + +def is_included_all_json_objects(gold_file_path, result_file_path): + if not gold_file_path or not result_file_path: + return 0 + + print("gold_file_path: ") + print(gold_file_path) + print("result_file_path: ") + print(result_file_path) + # two json file, check if all the key-value pair in gold_file_path is included in result_file_path + with open(gold_file_path, 'r') as f: + gold_json = json.load(f) + with open(result_file_path, 'r') as fr: + result_json = json.load(fr) + for key in gold_json.keys(): + if key not in result_json.keys() or gold_json[key] != result_json[key]: + return 0 + return 1 + + +def is_gold_text_included_in_pdf(pdf_file_path, gold_text_path): + if not gold_text_path or not pdf_file_path: + return 0 + + print("gold_text_path: ") + print(gold_text_path) + print("pdf_file_path: ") + print(pdf_file_path) + # gold file is a json file, we need to check all the value in json are included in pdf file. + with open(gold_text_path, 'r') as f: + gold_json = json.load(f) + with pdfplumber.open(pdf_file_path) as pdf: + text = '' + for page in pdf.pages: + text += page.extract_text() + false_list = [] + for key in gold_json.keys(): + if gold_json[key] not in text: + false_list.append(key) + if len(false_list) > 0: + print("false_list: ") + print(false_list) + return 0 + else: + return 1 + + +def file_contains(file_path, config): + # file_path ends with .txt + if not file_path: + return 0. + try: + with open(file_path, 'r') as f: + file_text = f.read() + for text in config["expected"]: + if text not in file_text: + logger.debug(f"file_contains: {text} not found in {file_path}") + return 0. + except: + logger.debug("file_contains: file not found or not readable") + return 0. + return 1. + + +def check_line_number(file_path, line_number): + # check if file_path exists + if file_path is None or not os.path.isfile(file_path): + return 0. + timeRegex = "([01]\\d|2[0-3]):[0-5]\\d:([0-5]\\d|60)" + # check if the string that matches the timeRegex in this txt file equals to line_number["expected"] + try: + with open(file_path, 'r') as f: + line_count = 0 + for line in f: + if re.search(timeRegex, line): + line_count += 1 + # if line_count equals to line_number["expected"], return 1, else return 0 + return 1 if line_count == int(line_number["expected"]) else 0 + except: + logger.debug("check_line_number: file not found or not readable") + return 0. + + +def compare_terminal_and_txt(txt_file_path, terminal_output): + if not txt_file_path or not terminal_output: + return 0 + + # read txt file content + with open(txt_file_path, 'r') as f: + txt_file_content = f.read() + # compare terminal output with txt file content + return 1 if terminal_output == txt_file_content else 0 + + +def compare_python_pure_text(py_file_path, gold_file_path): + if not py_file_path or not gold_file_path: + return 0 + + # first, change the suffix of gold_file from .txt to .py + print("py_file_path: ") + print(py_file_path) + print("gold_file_path: ") + print(gold_file_path) + + # gold_file_path = gold_file_path.replace('.txt', '.py') + def remove_whitespace(text): + return ''.join(text.split()) + + with open(py_file_path, 'r') as file1: + content1 = file1.read() + with open(gold_file_path, 'r') as file2: + content2 = file2.read() + content1_no_whitespace = remove_whitespace(content1) + content2_no_whitespace = remove_whitespace(content2) + if content1_no_whitespace == content2_no_whitespace: + return 1 + else: + return 0 + +if __name__ == '__main__': + print(check_direct_json_object([], rules={ + "relativeTime": { + "from": "5th next month" + }, + "expected": { + "start": "SEA", + "end": "NYC", + "time": "{DoW}, {Month} {DayD}, {Year}", + "category": "Miles" + }})) diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/gimp.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/gimp.py new file mode 100644 index 000000000..c87453b4a --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/gimp.py @@ -0,0 +1,573 @@ +import os +from typing import List, Union +from skimage.metrics import structural_similarity as ssim +from PIL import Image, ImageChops, ImageStat + + +def compare_image_list(pred_img_path_list: Union[str, List[str]], + gold_img_path_list: Union[str, List[str]]) -> float: + """ Compare two image lists, only if all images are the same, return 1.0, otherwise return 0.0 + """ + if type(pred_img_path_list) != list: + pred_img_path_list = [pred_img_path_list] + gold_img_path_list = [gold_img_path_list] + for pred_img_path, gold_img_path in zip(pred_img_path_list, gold_img_path_list): + if not pred_img_path or not gold_img_path: + return 0.0 + pred_img = Image.open(pred_img_path) + gold_img = Image.open(gold_img_path) + diff = ImageChops.difference(pred_img, gold_img) + if diff.getbbox(): + return 0.0 + return 1.0 + + +def get_gimp_export_path(): + # Path to GIMP's configuration file. This example assumes GIMP version 2.10. + # You need to adjust the path according to the GIMP version and user's file system. + gimp_config_file = os.path.expanduser("~/.config/GIMP/2.10/gimprc") + + try: + # Open and read the configuration file + with open(gimp_config_file, 'r') as file: + for line in file: + # Search for the default export path setting + if "default-export-path" in line: + # Extract the current path from the line (assuming it's enclosed in quotes) + current_path = line.split('"')[1] + # Compare the current path with the expected path + return current_path + except FileNotFoundError: + # Handle the case where the configuration file is not found + print("GIMP configuration file not found") + return False + + +def check_file_exists(directory, filename): + file_path = os.path.join(directory, filename) + return 1 if os.path.isfile(file_path) else 0 + + +def increase_saturation(image1_path: str, image2_path: str) -> float: + def calculate_saturation(image): + # convert the image to HSV mode + hsv_image = image.convert("HSV") + + saturation_channel = hsv_image.split()[1] + + # calculate the mean saturation level + stat = ImageStat.Stat(saturation_channel) + mean_saturation = stat.mean[0] + + return mean_saturation + + image1 = Image.open(image1_path) + image2 = Image.open(image2_path) + + # calculate the saturation level of each image + saturation1 = calculate_saturation(image1) + saturation2 = calculate_saturation(image2) + + return 1 if saturation1 < saturation2 else 0 + + +def decrease_brightness(image1_path: str, image2_path: str) -> float: + def calculate_brightness(image): + # Convert the image to grayscale mode + grayscale_image = image.convert("L") + + # Get the image data + pixels = list(grayscale_image.getdata()) + + brightness = sum(pixels) / len(pixels) + return brightness + + image1 = Image.open(image1_path) + image2 = Image.open(image2_path) + + brightness1 = calculate_brightness(image1) + brightness2 = calculate_brightness(image2) + + return 1 if brightness1 > brightness2 else 0 + + +import cv2 +import numpy as np + + +def find_yellow_triangle(image): + # Convert the image to RGBA + rgba = cv2.cvtColor(image, cv2.COLOR_BGR2RGBA) + + # define range of yellow color in HSV + lower_yellow = np.array([0, 0, 0], dtype=np.uint8) + upper_yellow = np.array([255, 255, 255], dtype=np.uint8) + + # expand the dimensions of lower and upper yellow to match the image dimensions + lower_yellow = np.reshape(lower_yellow, (1, 1, 3)) + upper_yellow = np.reshape(upper_yellow, (1, 1, 3)) + # build a mask for the yellow color + mask = cv2.inRange(rgba, lower_yellow, upper_yellow) + + # search for contours in the mask + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + # choose the largest contour + max_contour = max(contours, key=cv2.contourArea) + + # calculate the center of the contour + M = cv2.moments(max_contour) + cx = int(M['m10'] / M['m00']) + cy = int(M['m01'] / M['m00']) + + return cx, cy + + +def compare_triangle_positions(image1, image2): + image1 = cv2.imread(image1, cv2.IMREAD_COLOR) + image2 = cv2.imread(image2, cv2.IMREAD_COLOR) + # find the center of the yellow triangle in each image + cx1, cy1 = find_yellow_triangle(image1) + cx2, cy2 = find_yellow_triangle(image2) + + # calculate the distance between the center of the triangle and the center of the image + center_distance1 = np.sqrt( + (cx1 - image1.shape[1] // 2) ** 2 + (cy1 - image1.shape[0] // 2) ** 2) + center_distance2 = np.sqrt( + (cx2 - image2.shape[1] // 2) ** 2 + (cy2 - image2.shape[0] // 2) ** 2) + + return 1 if center_distance1 > center_distance2 else 0 + + +# Functions for the GIMP evaluator +def calculate_brightness(image): + """Calculate the average brightness of an image""" + grayscale = image.convert('L') + stat = ImageStat.Stat(grayscale) + return stat.mean[0] + + +def normalize_brightness(image, target_brightness): + """Normalize the brightness of an image to a target brightness in [0, 1]""" + current_brightness = calculate_brightness(image) + factor = target_brightness / current_brightness + + # Apply a point transform to each pixel + def point_transform(x): + return min(255, max(0, int(x * factor))) + + return image.point(point_transform) + + +def measure_saturation(hsv_image): + """Measure the average saturation of an image""" + # Split into H, S, V channels + _, s, _ = hsv_image.split() + # Convert the saturation channel to a numpy array + s_array = np.array(s) + # Calculate the average saturation + avg_saturation = np.mean(s_array) + return avg_saturation + + +def calculate_contrast(image): + """Calculate the contrast of an image as the standard deviation of the pixel + values.""" + pixels = np.asarray(image, dtype=np.float32) + return np.std(pixels) + + +def calculate_image_sharpness(image_path): + # Load the image in grayscale + image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE) + # Apply the Laplacian operator + laplacian = cv2.Laplacian(image, cv2.CV_64F) + # Calculate the variance + variance = np.var(laplacian) + return variance + + +def structure_check_by_mse(img1, img2, threshold=0.03): + """Check if two images are approximately the same by MSE""" + mse = np.mean( + (np.array(img1, dtype=np.float32) / 255 + - np.array(img2, dtype=np.float32) / 255) ** 2) + structure_same = True if mse < threshold else False + print("MSE: ", mse) + return structure_same + + +def structure_check_by_ssim(img1, img2, threshold=0.9): + """Check if two images are approximately the same by SSIM""" + similarity = ssim(np.array(img1), np.array(img2), multichannel=True, channel_axis=-1) + print("SSIM: ", similarity) + return similarity >= threshold + + +def check_brightness_decrease_and_structure_sim(src_path, tgt_path): + """ + Check the brightness of src is lower than tgt and the structures are similar + gimp:7a4deb26-d57d-4ea9-9a73-630f66a7b568 + """ + if src_path is None or tgt_path is None: + return 0. + + img_src = Image.open(src_path) + img_tgt = Image.open(tgt_path) + + # Brightness comparison + brightness_src = calculate_brightness(img_src) + brightness_tgt = calculate_brightness(img_tgt) + brightness_reduced = brightness_tgt > brightness_src + + # Normalize and compare images + target_brightness = 128 + img_src_normalized = normalize_brightness(img_src, target_brightness) + img_tgt_normalized = normalize_brightness(img_tgt, target_brightness) + + structure_same = structure_check_by_mse(img_src_normalized, img_tgt_normalized) + if brightness_reduced and structure_same: + return 1. + else: + return 0. + + +def check_saturation_increase_and_structure_sim(src_path, tgt_path): + """ + Check the saturation of src is higher than tgt and the structures are similar + gimp:554785e9-4523-4e7a-b8e1-8016f565f56a + """ + if src_path is None or tgt_path is None: + return 0. + + img_src = Image.open(src_path) + hsv_img_src = img_src.convert('HSV') + img_tgt = Image.open(tgt_path) + hsv_img_tgt = img_tgt.convert('HSV') + + # Saturation comparison + src_saturation = measure_saturation(hsv_img_src) + tgt_saturation = measure_saturation(hsv_img_tgt) + + saturation_increased = tgt_saturation < src_saturation + + # Structure comparison + h1, s1, v1 = hsv_img_src.split() + h2, s2, v2 = hsv_img_tgt.split() + h_same = structure_check_by_ssim(h1, h2) + v_same = structure_check_by_ssim(v1, v2) + if h_same and v_same: + structure_same = True + else: + structure_same = False + + if saturation_increased and structure_same: + return 1. + else: + return 0. + + +def check_file_exists_and_structure_sim(src_path, tgt_path): + """ + Check if the image has been exported to the desktop + gimp:77b8ab4d-994f-43ac-8930-8ca087d7c4b4 + """ + if src_path is None or tgt_path is None: + return 0. + + # Check if the file exists + export_file_exists = os.path.isfile(src_path) + if not export_file_exists: + return 0. + + # Check whether the target image is the same as the source image + img_src = Image.open(src_path) + img_tgt = Image.open(tgt_path) + structure_same = structure_check_by_ssim(img_src, img_tgt) + + if structure_same: + return 1. + else: + return 0. + + +def check_triangle_position(tgt_path): + """ + Check if the triangle is in the middle of the image. + gimp:f4aec372-4fb0-4df5-a52b-79e0e2a5d6ce + """ + if tgt_path is None: + return 0. + + # Load the image + img = Image.open(tgt_path) + img_array = np.array(img) + + # We assume the triangle is a different color from the background + # Find the unique colors + unique_colors, counts = np.unique(img_array.reshape(-1, img_array.shape[2]), axis=0, + return_counts=True) + unique_colors_sorted = unique_colors[np.argsort(counts)] + + # Assuming the background is the most common color and the triangle is a different color + triangle_color = unique_colors_sorted[1] + + # Create a mask where the triangle pixels are True + triangle_mask = np.all(img_array == triangle_color, axis=2) + + # Get the coordinates of the triangle pixels + triangle_coords = np.argwhere(triangle_mask) + + # Calculate the centroid of the triangle + centroid = triangle_coords.mean(axis=0) + + # Check if the centroid is approximately in the middle of the image + image_center = np.array(img_array.shape[:2]) / 2 + + # We will consider the triangle to be in the middle if the centroid is within 5% of the image's center + tolerance = 0.05 * np.array(img_array.shape[:2]) + middle = np.all(np.abs(centroid - image_center) < tolerance) + + if bool(middle): + return 1. + else: + return 0. + + +def check_structure_sim(src_path, tgt_path): + """ + Check if the structure of the two images are similar + gimp:2a729ded-3296-423d-aec4-7dd55ed5fbb3 + """ + if src_path is None or tgt_path is None: + return 0. + + img_src = Image.open(src_path) + img_tgt = Image.open(tgt_path) + structure_same = structure_check_by_ssim(img_src, img_tgt) + if structure_same: + return 1. + else: + return 0. + + +def check_structure_sim_resized(src_path, tgt_path): + """ + Check if the structure of the two images are similar after resizing. + gimp:d16c99dc-2a1e-46f2-b350-d97c86c85c15 + """ + if src_path is None or tgt_path is None: + return 0. + + img_src = Image.open(src_path) + img_tgt = Image.open(tgt_path) + + # Resize the images to the same size + img_src = img_src.resize(img_tgt.size) + + # Check if the structure is similar + structure_same = structure_check_by_ssim(img_src, img_tgt) + return structure_same + + +def check_contrast_increase_and_structure_sim(src_path, tgt_path): + """ + Check if the src image has higher contrast than the tgt image and the structures are similar + gimp:f723c744-e62c-4ae6-98d1-750d3cd7d79d + """ + if src_path is None or tgt_path is None: + return 0. + + # Load images + source_image = Image.open(src_path) + target_image = Image.open(tgt_path) + + # Calculate contrast + source_contrast = calculate_contrast(source_image) + target_contrast = calculate_contrast(target_image) + higher_contrast = target_contrast < source_contrast + + # Check structure + structure_same = structure_check_by_ssim(source_image, target_image, threshold=0.65) + + if higher_contrast and structure_same: + return 1. + else: + return 0. + + +def check_config_status(actual_config_path, rule): + """ + Check if the GIMP status is as expected + """ + if actual_config_path is None: + return 0. + + with open(actual_config_path, 'r') as f: + content = f.readlines() + + for line in content: + if line.startswith('#') or line == '\n': + continue + items = line.strip().lstrip('(').rstrip(')\n').split() + if isinstance(rule["key"], str): + if items[0] == rule["key"] and items[-1] == rule["value"]: + return 1. + elif isinstance(rule["key"], list) and len(rule["key"]) == 2: + if items[0] == rule["key"][0] \ + and items[1] == rule["key"][1] \ + and items[-1] == rule["value"]: + return 1. + return 0. + + +def check_image_size(src_path, rule): + """ + Check if the size of the src image is correct + multi-apps:42f4d1c7-4521-4161-b646-0a8934e36081 + """ + if src_path is None: + return 0. + + # Load the image + img = Image.open(src_path) + + # Check the size + if rule.get("height", None) is not None: + height_same = img.size[1] == rule["height"] + else: + height_same = True + if rule.get("width", None) is not None: + width_same = img.size[0] == rule["width"] + else: + width_same = True + + if height_same and width_same: + return 1. + else: + return 0. + + +def check_palette_and_structure_sim(src_path, tgt_path): + """ + Check if the src image is palette-based and the structure of the two images are similar + gimp:06ca5602-62ca-47f6-ad4f-da151cde54cc + """ + if src_path is None or tgt_path is None: + return 0. + + # Check if the source image is palette-based + source_image = Image.open(src_path) + palette_based = source_image.mode == 'P' + + # Check structure + target_image = Image.open(tgt_path) + source_image = source_image.convert('RGB') + structure_same = structure_check_by_ssim(source_image, target_image) + if palette_based and structure_same: + return 1. + else: + return 0. + + +def check_textbox_on_leftside(src_path): + """ + Check if the textbox is on the left side of the image. + gimp:e2dd0213-26db-4349-abe5-d5667bfd725c + """ + if src_path is None: + return 0. + + source_image = Image.open(src_path) + gray_image = source_image.convert("L") + width, height = source_image.size + + # Find the bounds of the black text + left_most_dark_pixel = width # Start with the farthest possible left position + for y in range(height): + for x in range(width): + # If the pixel is dark, consider it as part of the text + if gray_image.getpixel((x, y)) < 128: # Arbitrary threshold for "dark" + left_most_dark_pixel = min(left_most_dark_pixel, x) + break # Stop after finding the first dark pixel in this row + + # Here we define "almost" on the left side as being within the left 5% of the image + if left_most_dark_pixel < width * 0.05: + return 1. + else: + return 0. + + +def check_image_mirror(src_path, tgt_path): + """ + Check if the image is mirrored + gimp:72f83cdc-bf76-4531-9a1b-eb893a13f8aa + """ + if src_path is None or tgt_path is None: + return 0. + + # Load images + source_image = Image.open(src_path) + target_image = Image.open(tgt_path) + + # Check if the image is mirrored + transposed_image = source_image.transpose(Image.FLIP_LEFT_RIGHT) + # Use 0.99 because the image may not be exactly mirrored by gimp + mirrored = structure_check_by_ssim(transposed_image, target_image, 0.99) + if mirrored: + return 1. + else: + return 0. + + +def check_green_background(src_path, tgt_path): + """ + Check if the background of the source image is green. + gimp:734d6579-c07d-47a8-9ae2-13339795476b + """ + if src_path is None or tgt_path is None: + return 0. + + # Load images + source_image = Image.open(src_path) + target_image = Image.open(tgt_path) + + source_pixels = np.array(source_image) + target_pixels = np.array(target_image) + + for x in range(target_image.width): + for y in range(target_image.height): + # Identify background pixel in target image (not black) + if tuple(target_pixels[x, y][:3]) != (0, 0, 0): + # Check if corresponding pixel in source image is green + # Here, "green" means more green than red or blue + r, g, b = source_pixels[x, y][:3] + if not (g > r and g > b): + return 0. + + return 1. + + +def check_sharper(src_path, tgt_path): + """ + Check if the source image is sharper than the target image. + multi-app:bb7db4c2-30b5-4be7-8dd7-b8c4ec7d3108 + """ + sharpness_src = calculate_image_sharpness(src_path) + sharpness_tgt = calculate_image_sharpness(tgt_path) + return 1.0 if sharpness_src > sharpness_tgt else 0.0 + + +def check_image_file_size(src_path, rule): + """ + Check if the size of the src image within 500KB + """ + if src_path is None: + return 0.0 + + # Check the size + file_size = os.path.getsize(src_path) + if file_size < rule["max_size"]: + return 1.0 + else: + return 0.0 diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/libreoffice.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/libreoffice.py new file mode 100644 index 000000000..1870c345b --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/libreoffice.py @@ -0,0 +1,28 @@ +import fnmatch +from typing import Dict, List + +import lxml.cssselect +import lxml.etree +from lxml.etree import _Element as Element + +_libconf_namespaces = [("oor", "http://openoffice.org/2001/registry")] +_libconf_ns_mapping = dict(_libconf_namespaces) +_setup_locale_selector = lxml.cssselect.CSSSelector('item[oor|path$=L10N]>prop[oor|name=ooSetupSystemLocale]>value', + namespaces=_libconf_ns_mapping) +_locale_selector = lxml.cssselect.CSSSelector('item[oor|path$=L10N]>prop[oor|name=ooLocale]>value', + namespaces=_libconf_ns_mapping) + + +def check_libre_locale(config_file: str, rules: Dict[str, List[str]]) -> float: + config: Element = lxml.etree.parse(config_file).getroot() + setup_locale_setting: List[Element] = _setup_locale_selector(config) + locale_setting: List[Element] = _locale_selector(config) + + setup_locale_setting: str = setup_locale_setting[0].text \ + if len(setup_locale_setting) > 0 \ + else locale_setting[0].text + + return float(any(fnmatch.fnmatchcase(setup_locale_setting, ptn) \ + for ptn in rules["locale_set"] + ) + ) diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/others.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/others.py new file mode 100644 index 000000000..ebb599414 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/others.py @@ -0,0 +1,90 @@ +import logging +import os +import os.path +import zipfile +from typing import List, Dict +from typing import Union, TypeVar + +import lxml.html +from lxml.html import HtmlElement +from mutagen.easyid3 import EasyID3 + +from .general import diff_text_file +from .utils import _match_value_to_rule + +logger = logging.getLogger("desktopenv.metric.others") + + +def process_epub(filename: str) -> List[str]: + file_list: List[str] = [] + + base_dir: str = filename + ".dir" + os.makedirs(base_dir, exist_ok=True) + + try: + with zipfile.ZipFile(filename, "r") as z_f: + with z_f.open("toc.ncx") as in_f \ + , open(os.path.join(base_dir, "toc.ncx"), "w") as out_f: + contents: str = in_f.read().decode() + contents = contents.splitlines() + for l in contents: + if "navPoint" not in l: + out_f.write(l + "\n") + file_list.append(os.path.join(base_dir, "toc.ncx")) + with z_f.open("content.opf") as in_f \ + , open(os.path.join(base_dir, "content.opf"), "w") as out_f: + contents: str = in_f.read().decode() + contents = contents.splitlines() + for l in contents: + if "dc:identifier" not in l: + out_f.write(l + "\n") + file_list.append(os.path.join(base_dir, "content.opf")) + for f_n in z_f.namelist(): + if f_n.endswith(".html"): + with z_f.open(f_n) as in_f \ + , open(os.path.join(base_dir, f_n), "w") as out_f: + html: HtmlElement = lxml.html.fromstring( + ''.join(filter(lambda ch: ch != "\n" and ch != "\r" + , in_f.read().decode() + ) + ).encode() + ) + out_f.write(lxml.html.tostring(html, pretty_print=True, encoding="unicode")) + file_list.append(os.path.join(base_dir, f_n)) + logger.debug("%s: %s", filename, file_list) + return list(sorted(file_list)) + except zipfile.BadZipFile: + return [] + + +def compare_epub(result: str, expected: str) -> float: + if result is None: + return 0. + result_files: List[str] = process_epub(result) + expected_files: List[str] = process_epub(expected) + + metric: float = 1. + for f1, f2 in zip(result_files, expected_files): + current_metric: float = diff_text_file(f1, f2) + logger.debug("%s vs %s: %f", f1, f2, current_metric) + metric *= current_metric + return metric + + +V = TypeVar("Value") + + +def check_mp3_meta(result: str, meta: Dict[str, Dict[str, Union[str, V]]]) -> bool: + # checks using _match_value_to_rule + if result is None: + return 0. + + id3_dict = EasyID3(result) + metric: bool = True + for k, r in meta.items(): + value = id3_dict.get(k, "") + if isinstance(value, list): + value: str = ",".join(value) + logger.debug("%s.%s: %s", result, k, value) + metric = metric and _match_value_to_rule(value, r) + return float(metric) diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/pdf.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/pdf.py new file mode 100644 index 000000000..ef5b38491 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/pdf.py @@ -0,0 +1,31 @@ +import operator +from typing import Any +from typing import Dict + +import fitz # PyMuPDF +from pypdf import PdfReader + + +def check_pdf_pages(pdf_file: str, rules: Dict[str, Any]) -> float: + if pdf_file is None: + return 0.0 + reader = PdfReader(pdf_file) + nb_pages: int = len(reader.pages) + return float(getattr(operator, rules["relation"])(nb_pages, rules["ref_value"])) + + +def extract_answers_from_pdf(pdf_file): + doc = fitz.open(pdf_file) + answers = [] + + for page in doc: + text = page.get_text() + lines = text.split('\n') + for line in lines: + if line.strip(): + parts = line.split('=') + if len(parts) > 1: + answer = parts[-1].strip() + answers.append(answer) + + return answers diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/slides.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/slides.py new file mode 100644 index 000000000..307752fd7 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/slides.py @@ -0,0 +1,582 @@ +import logging +import xml.etree.ElementTree as ET +import zipfile +from math import sqrt + +from pptx import Presentation +from pptx.util import Inches + +logger = logging.getLogger("desktopenv.metric.slides") + + +def check_presenter_console_disable(config_file_path): + try: + tree = ET.parse(config_file_path) + root = tree.getroot() + + namespaces = { + 'oor': 'http://openoffice.org/2001/registry' + } + + for item in root.findall( + ".//item[@oor:path='/org.openoffice.Office.Impress/Misc/Start']/prop[@oor:name='EnablePresenterScreen']", + namespaces): + # Check if the value of the configuration item indicates that the presenter console has been disabled + presenter_screen_enabled = item.find('value').text + if presenter_screen_enabled.lower() == 'false': + return 1. + else: + return 0. + return 0. + except Exception as e: + logger.error(f"Error: {e}") + return 0. + + +def check_image_stretch_and_center(modified_ppt, original_ppt): + # fixme: this func is overfit to this example libreoffice_impress + # Load the presentations + original_pres = Presentation(original_ppt) + modified_pres = Presentation(modified_ppt) + + # Get the first slide of each presentation + original_slide = original_pres.slides[0] + modified_slide = modified_pres.slides[0] + + # Get the image on the first slide of each presentation + original_slide_images = [shape for shape in original_slide.shapes if shape.shape_type == 13] + modified_slide_images = [shape for shape in modified_slide.shapes if shape.shape_type == 13] + + the_image = original_slide_images[0] + + the_modified_image = None + + # Get the images that modified in width and height + for modified_image in modified_slide_images: + if the_image.image.blob == modified_image.image.blob: + the_modified_image = modified_image + + if the_modified_image is None: + return 0. + + if (abs(the_modified_image.width - original_pres.slide_width) > Inches(0.5) or + abs(the_modified_image.height - original_pres.slide_height) > Inches(0.5) or + abs(the_modified_image.left - (original_pres.slide_width - the_modified_image.width) / 2) > Inches(0.5) or + abs(the_modified_image.top - (original_pres.slide_height - the_modified_image.height) / 2) > Inches(0.5)): + return 0. + + return 1. + + +def is_red_color(color): + # judge if the color is red + return color and color.rgb == (255, 0, 0) + + +def get_master_placeholder_color(prs): + # get the color of the placeholder + masters = prs.slide_masters + for idx, master in enumerate(masters): + for placeholder in master.placeholders: + if placeholder.has_text_frame and placeholder.text == "": + text_frame = placeholder.text_frame + + if text_frame.paragraphs: + first_paragraph = text_frame.paragraphs[0] + return first_paragraph.font.color + return None + + +def check_slide_numbers_color(pptx_file_path): + presentation = Presentation(pptx_file_path) + + for i, slide in enumerate(presentation.slides): + for shape in slide.shapes: + # check if the shape is a text box + if hasattr(shape, "text"): + if shape.text.isdigit(): + # "SlidePlaceholder" is the name of the placeholder in the master slide + page_number_text = shape.text + font_color = get_master_placeholder_color(presentation) + return 1 if font_color is not None and is_red_color(font_color) else 0 + + +# import numpy as np +# from PIL import Image +# from skimage.metrics import structural_similarity as ssim + +# def compare_images(image1_path, image2_path): +# # You would call this function with the paths to the two images you want to compare: +# # score = compare_images('path_to_image1', 'path_to_image2') +# # print("Similarity score:", score) + +# if not image1_path or not image2_path: +# return 0 + +# # Open the images and convert to grayscale +# image1 = Image.open(image1_path).convert('L') +# image2 = Image.open(image2_path).convert('L') + +# # Resize images to the smaller one's size for comparison +# image1_size = image1.size +# image2_size = image2.size +# new_size = min(image1_size, image2_size) + +# image1 = image1.resize(new_size, Image.Resampling.LANCZOS) +# image2 = image2.resize(new_size, Image.Resampling.LANCZOS) + +# # Convert images to numpy arrays +# image1_array = np.array(image1) +# image2_array = np.array(image2) + +# # Calculate SSIM between two images +# similarity_index = ssim(image1_array, image2_array) + +# return similarity_index + +def compare_pptx_files(file1_path, file2_path, **options): + # todo: not strictly match since not all information is compared because we cannot get the info through pptx + prs1 = Presentation(file1_path) + prs2 = Presentation(file2_path) + + examine_number_of_slides = options.get("examine_number_of_slides", True) + examine_shape = options.get("examine_shape", True) + examine_text = options.get("examine_text", True) + examine_indent = options.get("examine_indent", True) + examine_font_name = options.get("examine_font_name", True) + examine_font_size = options.get("examine_font_size", True) + examine_font_bold = options.get("examine_font_bold", True) + examine_font_italic = options.get("examine_font_italic", True) + examine_color_rgb = options.get("examine_color_rgb", True) + examine_font_underline = options.get("examine_font_underline", True) + examine_strike_through = options.get("examine_strike_through", True) + examine_alignment = options.get("examine_alignment", True) + examine_title_bottom_position = options.get("examine_title_bottom_position", False) + examine_table_bottom_position = options.get("examine_table_bottom_position", False) + examine_right_position = options.get("examine_right_position", False) + examine_top_position = options.get("examine_top_position", False) + examine_shape_for_shift_size = options.get("examine_shape_for_shift_size", False) + examine_image_size = options.get("examine_image_size", False) + examine_modify_height = options.get("examine_modify_height", False) + examine_bullets = options.get("examine_bullets", True) + examine_background_color = options.get("examine_background_color", True) + examine_note = options.get("examine_note", True) + + # compare the number of slides + if len(prs1.slides) != len(prs2.slides) and examine_number_of_slides: + return 0 + + slide_idx = 0 + # compare the content of each slide + for slide1, slide2 in zip(prs1.slides, prs2.slides): + slide_idx += 1 + + def get_slide_background_color(slide): + # background = slide.background + # if background.fill.background(): + # return background.fill.fore_color.rgb + # else: + # return None + fill = slide.background.fill + if fill.type == 1: + return fill.fore_color.rgb + elif fill.type == 5: + master_fill = slide.slide_layout.slide_master.background.fill + if master_fill.type == 1: + return master_fill.fore_color.rgb + else: + return None + else: + return None + + if get_slide_background_color(slide1) != get_slide_background_color(slide2) and examine_background_color: + return 0 + + def get_slide_notes(slide): + notes_slide = slide.notes_slide + if notes_slide: + return notes_slide.notes_text_frame.text + else: + return None + + if get_slide_notes(slide1).strip() != get_slide_notes(slide2).strip() and examine_note: + return 0 + + # check if the number of slides is the same + if len(slide1.shapes) != len(slide2.shapes): + return 0 + + # check if the shapes are the same + for shape1, shape2 in zip(slide1.shapes, slide2.shapes): + if examine_title_bottom_position: + if hasattr(shape1, "text") and hasattr(shape2, "text") and shape1.text == shape2.text: + if shape1.text == "Product Comparison" and (shape1.top <= shape2.top or shape1.top < 3600000): + return 0 + elif shape1.left != shape2.left or shape1.top != shape2.top or shape1.width != shape2.width or shape1.height != shape2.height: + return 0 + + if examine_table_bottom_position: + if slide_idx == 3 and shape1.shape_type == 19 and shape2.shape_type == 19: + if shape1.top <= shape2.top or shape1.top < 3600000: + return 0 + elif shape1.left != shape2.left or shape1.top != shape2.top or shape1.width != shape2.width or shape1.height != shape2.height: + return 0 + + if examine_right_position: + if slide_idx == 2 and not hasattr(shape1, "text") and not hasattr(shape2, "text"): + if shape1.left <= shape2.left or shape1.left < 4320000: + return 0 + + if examine_top_position: + if slide_idx == 2 and shape1.shape_type == 13 and shape2.shape_type == 13: + if shape1.top >= shape2.top or shape1.top > 1980000: + return 0 + elif shape1.left != shape2.left or shape1.top != shape2.top or shape1.width != shape2.width or shape1.height != shape2.height: + return 0 + + if examine_shape_for_shift_size: + if shape1.left != shape2.left or shape1.top != shape2.top or shape1.width != shape2.width or shape1.height != shape2.height: + if not (hasattr(shape1, "text") and hasattr(shape2, + "text") and shape1.text == shape2.text and shape1.text == "Elaborate on what you want to discuss."): + return 0 + + if ( + shape1.left != shape2.left or shape1.top != shape2.top or shape1.width != shape2.width or shape1.height != shape2.height) and examine_shape: + return 0 + + if examine_image_size: + if shape1.shape_type == 13 and shape2.shape_type == 13: + if shape1.width != shape2.width or shape1.height != shape2.height: + return 0 + elif shape1.left != shape2.left or shape1.top != shape2.top or shape1.width != shape2.width or shape1.height != shape2.height: + return 0 + + if examine_modify_height: + if not hasattr(shape1, "text") and not hasattr(shape2, + "text") or shape1.shape_type == 5 and shape2.shape_type == 5: + if shape1.height != shape2.height: + return 0 + elif shape1.left != shape2.left or shape1.top != shape2.top or shape1.width != shape2.width or shape1.height != shape2.height: + return 0 + + if hasattr(shape1, "text") and hasattr(shape2, "text"): + if shape1.text.strip() != shape2.text.strip() and examine_text: + return 0 + + # check if the paragraphs are the same + for para1, para2 in zip(shape1.text_frame.paragraphs, shape2.text_frame.paragraphs): + if para1.alignment != para2.alignment and examine_alignment: + return 0 + + # check if the runs are the same + if para1.text != para2.text and examine_text: + return 0 + + if para1.level != para2.level and examine_indent: + return 0 + + for run1, run2 in zip(para1.runs, para2.runs): + + # check if the font properties are the same + if run1.font.name != run2.font.name and examine_font_name: + return 0 + + if run1.font.size != run2.font.size and examine_font_size: + return 0 + + if run1.font.bold != run2.font.bold and examine_font_bold: + return 0 + + if run1.font.italic != run2.font.italic and examine_font_italic: + return 0 + + if hasattr(run1.font.color, "rgb") and hasattr(run2.font.color, "rgb"): + if run1.font.color.rgb != run2.font.color.rgb and examine_color_rgb: + return 0 + + if run1.font.underline != run2.font.underline and examine_font_underline: + return 0 + + if run1.font._element.attrib.get('strike', 'noStrike') != run2.font._element.attrib.get( + 'strike', 'noStrike') and examine_strike_through: + return 0 + + def _extract_bullets(xml_data): + root = ET.fromstring(xml_data) + + namespaces = { + 'a': 'http://schemas.openxmlformats.org/drawingml/2006/main', + 'p': 'http://schemas.openxmlformats.org/presentationml/2006/main', + } + + bullets = [] + + for paragraph in root.findall('.//a:p', namespaces): + pPr = paragraph.find('a:pPr', namespaces) + if pPr is not None: + lvl = pPr.get('lvl') + buChar = pPr.find('a:buChar', namespaces) + char = buChar.get('char') if buChar is not None else "No Bullet" + buClr = pPr.find('a:buClr/a:srgbClr', namespaces) + color = buClr.get('val') if buClr is not None else "No Color" + else: + lvl = "No Level" + char = "No Bullet" + color = "No Color" + + text = "".join(t.text for t in paragraph.findall('.//a:t', namespaces)) + + bullets.append((lvl, char, text, color)) + + return bullets + + if examine_bullets and _extract_bullets(run1.part.blob.decode('utf-8')) != _extract_bullets( + run2.part.blob.decode('utf-8')): + return 0 + + # fixme: Actually there are more properties to be compared, we can add them later via parsing the xml data + + return 1 + + +def check_strikethrough(pptx_path, rules): + # Load the presentation + presentation = Presentation(pptx_path) + + slide_index_s = rules["slide_index_s"] + shape_index_s = rules["shape_index_s"] + paragraph_index_s = rules["paragraph_index_s"] + + try: + for slide_index in slide_index_s: + # Get the slide + slide = presentation.slides[slide_index] + + for shape_index in shape_index_s: + # Get the text box + paragraphs = slide.shapes[shape_index].text_frame.paragraphs + + for paragraph_index in paragraph_index_s: + paragraph = paragraphs[paragraph_index] + run = paragraph.runs[0] + if 'strike' not in run.font._element.attrib: + return 0 + + + except Exception as e: + logger.error(f"Error: {e}") + return 0 + + return 1 + + +def check_slide_orientation_Portrait(pptx_path): + presentation = Presentation(pptx_path) + + slide_height = presentation.slide_height + slide_width = presentation.slide_width + + if slide_width < slide_height: + return 1 + return 0 + + +def evaluate_presentation_fill_to_rgb_distance(pptx_file, rules): + rgb = rules["rgb"] + + try: + original_rgb = rules["original_rgb"] + except: + original_rgb = None + + def get_rgb_from_color(color): + try: + if hasattr(color, "rgb"): + return color.rgb + else: + return None + except: + return None + + def slide_fill_distance_to_rgb(_slide, _rgb, _original_rgb): + fill = _slide.background.fill + if fill.type == 1: + color_rgb = get_rgb_from_color(fill.fore_color) + if color_rgb is None: + return 1 + r1, g1, b1 = color_rgb + r2, g2, b2 = _rgb + + if _original_rgb is not None: + r3, g3, b3 = _original_rgb + if r1 == r3 and g1 == g3 and b1 == b3: + return 1 + + return sqrt((r1 - r2) ** 2 + (g1 - g2) ** 2 + (b1 - b2) ** 2) / sqrt(255 ** 2 + 255 ** 2 + 255 ** 2) + elif fill.type == 5: + master_fill = _slide.slide_layout.slide_master.background.fill + if master_fill.type == 1: + color_rgb = get_rgb_from_color(master_fill.fore_color) + if color_rgb is None: + return 1 + r1, g1, b1 = color_rgb + else: + return 1 + r2, g2, b2 = _rgb + + if _original_rgb is not None: + r3, g3, b3 = _original_rgb + if r1 == r3 and g1 == g3 and b1 == b3: + return 1 + + return sqrt((r1 - r2) ** 2 + (g1 - g2) ** 2 + (b1 - b2) ** 2) / sqrt(255 ** 2 + 255 ** 2 + 255 ** 2) + + return 1 + + prs = Presentation(pptx_file) + similarity = 1 - sum(slide_fill_distance_to_rgb(slide, rgb, original_rgb) for slide in prs.slides) / len(prs.slides) + return similarity + + +def check_left_panel(accessibility_tree): + namespaces = { + 'st': 'uri:deskat:state.at-spi.gnome.org', + 'cp': 'uri:deskat:component.at-spi.gnome.org' + } + + root = ET.fromstring(accessibility_tree) + + for root_pane in root.iter('root-pane'): + for panel in root_pane.iter('panel'): + for split_pane in panel.iter('split-pane'): + # Get the left panel + if split_pane.attrib.get("{{{}}}parentcoord".format(namespaces['cp'])) == "(0, 0)": + # Get the visible attribute + visible = split_pane.attrib.get("{{{}}}visible".format(namespaces['st'])) + if visible: + # decide if it is left panel + return 1. + + return 0. + + +def check_transition(pptx_file, rules): + slide_idx = rules['slide_idx'] + transition_type = rules['transition_type'] + + # Use the zipfile module to open the .pptx file + with zipfile.ZipFile(pptx_file, 'r') as zip_ref: + # Get the slide XML file + slide_name = 'ppt/slides/slide{}.xml'.format(slide_idx + 1) + try: + zip_ref.getinfo(slide_name) + except KeyError: + # Slide does not exist + return 0. + + with zip_ref.open(slide_name) as slide_file: + # 解析XML + tree = ET.parse(slide_file) + root = tree.getroot() + + # XML namespace + namespaces = { + 'a': 'http://schemas.openxmlformats.org/drawingml/2006/main', + 'p': 'http://schemas.openxmlformats.org/presentationml/2006/main', + } + + # Search for the transition element + transition = root.find('.//p:transition', namespaces) + if transition is not None: + # Check if the transition is an expected transition + dissolve = transition.find('.//p:{}'.format(transition_type), namespaces) + if dissolve is not None: + return 1. + else: + return 0. + else: + return 0. + + +def check_page_number_colors(pptx_file, rules): + color = rules["color"] + + def is_red(rgb_str, threshold=50): + r, g, b = int(rgb_str[1:3], 16), int(rgb_str[3:5], 16), int(rgb_str[5:7], 16) + return r > g + threshold and r > b + threshold + + def is_blue(rgb_str, threshold=50): + r, g, b = int(rgb_str[1:3], 16), int(rgb_str[3:5], 16), int(rgb_str[5:7], 16) + return b > g + threshold and b > r + threshold + + def is_green(rgb_str, threshold=50): + r, g, b = int(rgb_str[1:3], 16), int(rgb_str[3:5], 16), int(rgb_str[5:7], 16) + return g > r + threshold and g > b + threshold + + def is_black(rgb_str, threshold=50): + r, g, b = int(rgb_str[1:3], 16), int(rgb_str[3:5], 16), int(rgb_str[5:7], 16) + return r < threshold and g < threshold and b < threshold + + with zipfile.ZipFile(pptx_file, 'r') as zip_ref: + slide_master_name = 'ppt/slideMasters/slideMaster1.xml' + with zip_ref.open(slide_master_name) as slide_master_file: + tree = ET.parse(slide_master_file) + root = tree.getroot() + + namespaces = { + 'a': 'http://schemas.openxmlformats.org/drawingml/2006/main', + 'p': 'http://schemas.openxmlformats.org/presentationml/2006/main', + } + + color_elems = root.findall('.//a:solidFill//a:srgbClr', namespaces) + slides_color_val = color_elems[-2].get('val') + + if slides_color_val is None: + return 0 + elif color == "red" and not is_red(slides_color_val): + return 0 + elif color == "blue" and not is_blue(slides_color_val): + return 0 + elif color == "green" and not is_green(slides_color_val): + return 0 + elif color == "black" and not is_black(slides_color_val): + return 0 + + return 1 + + +def check_auto_saving_time(pptx_file, rules): + minutes = rules["minutes"] + + # open and parse xml file + try: + tree = ET.parse(pptx_file) + root = tree.getroot() + + # Traverse the XML tree to find the autosave time setting + autosave_time = None + for item in root.findall(".//item"): + # Check the path attribute + path = item.get('{http://openoffice.org/2001/registry}path') + if path == "/org.openoffice.Office.Common/Save/Document": + # Once the correct item is found, look for the prop element with the name "AutoSaveTimeIntervall" + for prop in item.findall(".//prop"): + name = prop.get('{http://openoffice.org/2001/registry}name') + if name == "AutoSaveTimeIntervall": + # Extract the value of the autosave time interval + autosave_time = prop.find(".//value").text + break + + if autosave_time is None: + return 0 + else: + autosave_time = int(autosave_time) + if autosave_time == minutes: + return 1 + else: + return 0 + + except ET.ParseError as e: + logger.error(f"Error parsing XML: {e}") + except FileNotFoundError: + logger.error(f"File not found: {pptx_file}") diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/table.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/table.py new file mode 100644 index 000000000..9e888c7db --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/table.py @@ -0,0 +1,517 @@ +import functools +import itertools +import logging +import os.path +# import operator +from numbers import Number +from typing import Any, Union, cast, Callable, Iterable +from typing import Dict, List, Tuple, Set + +import openpyxl +import pandas as pd +from openpyxl import Workbook +from openpyxl.cell.cell import Cell +from openpyxl.utils import get_column_letter +from openpyxl.worksheet.cell_range import MultiCellRange +from openpyxl.worksheet.datavalidation import DataValidation +from openpyxl.worksheet.worksheet import Worksheet +from rapidfuzz import fuzz + +from desktop_env.evaluators.metrics.utils import _match_value_to_rule, _read_cell_style, read_cell_value +from desktop_env.evaluators.metrics.utils import load_charts, load_sparklines, load_rows_or_cols, load_xlsx_styles \ + , load_filters, load_pivot_tables + +# from openpyxl.utils import coordinate_to_tuple + +logger = logging.getLogger("desktopenv.metric.table") + +BOOK = Union[pd.ExcelFile, Workbook, str] + + +def _parse_sheet_idx(sheet_idx: Union[int, str] + , result: BOOK, expected: BOOK + , result_sheet_names: List[str] + , expected_sheet_names: List[str] + ) -> Tuple[BOOK, str]: + # function _parse_sheet_idx {{{ # + if isinstance(sheet_idx, int): + try: + index: str = result_sheet_names[sheet_idx] + except: + index = "" + book: BOOK = result + elif sheet_idx.startswith("RI"): + try: + index: str = result_sheet_names[int(sheet_idx[2:])] + except: + index = "" + book: BOOK = result + elif sheet_idx.startswith("RN"): + index: str = sheet_idx[2:] + book: BOOK = result + elif sheet_idx.startswith("EI"): + try: + index: str = expected_sheet_names[int(sheet_idx[2:])] + except: + index = "" + book: BOOK = expected + elif sheet_idx.startswith("EN"): + index: str = sheet_idx[2:] + book: BOOK = expected + else: + logger.error("Unrecognized sheet index") + raise ValueError("Unrecognized sheet index") + return book, index + # }}} function _parse_sheet_idx # + + +SHEET = Union[pd.DataFrame, Worksheet, List[str]] + + +def _load_sheet(book: BOOK, index: str) -> SHEET: + # function _load_sheet {{{ # + try: + if isinstance(book, str): + book: str = cast(str, book) + csv_name: str = "{:}-{:}.csv".format(os.path.splitext(book)[0], index) + + with open(csv_name) as f: + csv_lines: List[str] = list(itertools.dropwhile(lambda l: len(l) == 0 + , map(lambda l: l.strip() + , reversed(f.read().splitlines()) + ) + ) + ) + return csv_lines + if isinstance(book, pd.ExcelFile): + return pd.read_excel(book, index) + if isinstance(book, Workbook): + return book[index] + logger.error("Not supported workbook format") + raise NotImplementedError("Not supported workbook format") + except NotImplementedError as e: + raise e + except: + return None + # }}} function _load_sheet # + + +def compare_table(result: str, expected: str = None, **options) -> float: + # function compare_table {{{ # + """ + Args: + result (str): path to result xlsx + expected (str): path to golden xlsx + rules (List[Dict[str, Any]]): list of dict like + { + "type": str, + : anything + } + as sequential rules + + Returns: + float: the score + """ + + if result is None: + return 0. + + try: + xlworkbookr: Workbook = openpyxl.load_workbook(filename=result) + pdworkbookr = pd.ExcelFile(result) + except: + return 0. + worksheetr_names: List[str] = pdworkbookr.sheet_names + + if expected is not None: + xlworkbooke: Workbook = openpyxl.load_workbook(filename=expected) + pdworkbooke = pd.ExcelFile(expected) + worksheete_names: List[str] = pdworkbooke.sheet_names + else: + xlworkbooke: Workbook = None + pdworkbooke = None + worksheete_names: List[str] = None + + parse_idx: Callable[[Union[str, int], BOOK, BOOK], Tuple[BOOK, str]] = \ + functools.partial( + _parse_sheet_idx, + result_sheet_names=worksheetr_names, + expected_sheet_names=worksheete_names + ) + + passes = True + for r in options["rules"]: + if r["type"] == "sheet_name": + # Compare Sheet Names {{{ # + metric: bool = worksheetr_names == worksheete_names + logger.debug("Assertion: %s.sheet_names == %s.sheet_names - %s", result, expected, metric) + # }}} Compare Sheet Names # + + elif r["type"] == "sheet_data": + # Compare Sheet Data by Internal Value {{{ # + # sheet_idx0: 0 == "RI0" == "RNSheet1" | "EI0" == "ENSheet1" + # sheet_idx1: as sheet_idx0 + # precision: int as number of decimal digits, default to 4 + + error_limit: int = r.get("precision", 4) + sheet1: pd.DataFrame = _load_sheet(*parse_idx(r["sheet_idx0"], pdworkbookr, pdworkbooke)) + if sheet1 is None: + return 0. + sheet2: pd.DataFrame = _load_sheet(*parse_idx(r["sheet_idx1"], pdworkbookr, pdworkbooke)) + + sheet1 = sheet1.round(error_limit) + sheet2 = sheet2.round(error_limit) + metric: bool = sheet1.equals(sheet2) + logger.debug("Sheet1: \n%s", str(sheet1)) + logger.debug("Sheet2: \n%s", str(sheet2)) + try: + logger.debug("Sheet1 =v= Sheet2: \n%s", str(sheet1 == sheet2)) + except: + logger.debug("Sheet1 =/v= Sheet2") + logger.debug("Assertion: %s =v= %s - %s", r["sheet_idx0"], r["sheet_idx1"], metric) + # }}} Compare Sheet Data by Internal Value # + + elif r["type"] == "sheet_print": + # Compare Sheet Data by Printed Value {{{ # + # sheet_idx0: 0 == "RI0" == "RNSheet1" | "EI0" == "ENSheet1" + # sheet_idx1: as sheet_idx0 + # ignore_case: optional, defaults to False + + sheet1: List[str] = _load_sheet(*parse_idx(r["sheet_idx0"], result, expected)) + if sheet1 is None: + return 0. + sheet2: List[str] = _load_sheet(*parse_idx(r["sheet_idx1"], result, expected)) + if r.get("ignore_case", False): + sheet1 = [l.lower() for l in sheet1] + sheet2 = [l.lower() for l in sheet2] + metric: bool = sheet1 == sheet2 + logger.debug("Assertion: %s =p= %s - %s", r["sheet_idx0"], r["sheet_idx1"], metric) + # }}} Compare Sheet Data by Printed Value # + + elif r["type"] == "sheet_fuzzy": + # Fuzzy Match for Ranges {{{ # + # sheet_idx0: 0 == "RI0" == "RNSheet1" | "EI0" == "ENSheet1" + # sheet_idx1: as sheet_idx0 + # rules: list of dict, each dict is like + # { "range": ["A1:B6", "C2:E5"], + # "type": "includes" | "included_by" | "fuzzy_match" | "exact_match", # 0 includes 1, 0 includes_by 1 + # "threshold": 85, // for fuzzy match + # "ignore_case": true | false, + # "ignore_chars": " ()", # filtered out + # "trim_leadings": "+ ", # filtered by lstrip + # "trim_trailings": "", # filtered by rstrip + # "normalization": [["Rd", "Road"]], # filtered by replace + # } + + sheet1: Tuple[BOOK, str] = parse_idx(r["sheet_idx0"], result, expected) + sheet2: Tuple[BOOK, str] = parse_idx(r["sheet_idx1"], result, expected) + total_metric = True + for rl in r["rules"]: + for rng in MultiCellRange(rl["range"]): + for cdn in rng.cells: + coordinate: str = "{:}{:d}".format(get_column_letter(cdn[1]), cdn[0]) + value1: str = str(read_cell_value(*sheet1, coordinate)) + value2: str = str(read_cell_value(*sheet2, coordinate)) + logger.debug("%s: %s vs %s", cdn, value1, value2) + + for rplc in rl.get("normalization", []): + value1 = value1.replace(rplc[0], rplc[1]) + value2 = value2.replace(rplc[0], rplc[1]) + if "trim_leadings" in rl: + value1 = value1.lstrip(rl["trim_leadings"]) + value2 = value2.lstrip(rl["trim_leadings"]) + if "trim_trailings" in rl: + value1 = value1.rstrip(rl["trim_trailings"]) + value2 = value2.rstrip(rl["trim_trailings"]) + if "ignore_chars" in rl: + ignore_chars: Set[str] = set(rl["ignore_chars"]) + value1 = "".join(filter(lambda ch: ch not in ignore_chars, value1)) + value2 = "".join(filter(lambda ch: ch not in ignore_chars, value2)) + if rl.get("ignore_case", False): + value1 = value1.lower() + value2 = value2.lower() + + if rl["type"] == "includes": + metric: bool = value2 in value1 + elif rl["type"] == "included_by": + metric: bool = value1 in value2 + elif rl["type"] == "fuzzy_match": + metric: bool = fuzz.ratio(value1, value2) >= rl.get("threshold", 85.) + elif rl["type"] == "exact_match": + metric: bool = value1 == value2 + total_metric = total_metric and metric + + metric: bool = total_metric + logger.debug("Assertion: %s =~= %s - %s", r["sheet_idx0"], r["sheet_idx1"], metric) + # }}} Fuzzy Match for Ranges # + + elif r["type"] == "sparkline": + # Compare Sparklines {{{ # + # sheet_idx0: 0 == "RI0" == "RNSheet1" | "EI0" == "ENSheet1" + # sheet_idx1: as sheet_idx0 + + sparkline1: Dict[str, str] = load_sparklines(*parse_idx(r["sheet_idx0"], result, expected)) + sparkline2: Dict[str, str] = load_sparklines(*parse_idx(r["sheet_idx1"], result, expected)) + metric: bool = sparkline1 == sparkline2 + logger.debug("Assertion: %s.sp == %.sp - %s", r["sheet_idx0"], r["sheet_idx1"], metric) + # }}} Compare Sparklines # + + elif r["type"] == "chart": + # Compare Charts {{{ # + # sheet_idx0: 0 == "RI0" == "RNSheet1" | "EI0" == "ENSheet1" + # sheet_idx1: as sheet_idx0 + # chart_props: list of str, see utils.load_charts + + charts1: Dict[str, Any] = load_charts(*parse_idx(r["sheet_idx0"], xlworkbookr, xlworkbooke), **r) + charts2: Dict[str, Any] = load_charts(*parse_idx(r["sheet_idx1"], xlworkbookr, xlworkbooke), **r) + metric: bool = charts1 == charts2 + logger.debug("Assertion: %s[chart] == %s[chart] - %s", r["sheet_idx0"], r["sheet_idx1"], metric) + # }}} Compare Charts # + + elif r["type"] == "style": + # Compare Style (Also Conditional Formatiing) {{{ # + # sheet_idx0: 0 == "RI0" == "RNSheet1" | "EI0" == "ENSheet1" + # sheet_idx1: as sheet_idx0 + # props: list of str indicating concerned styles, see utils._read_cell_style + + sheet_idx1: Tuple[BOOK, str] = parse_idx(r["sheet_idx0"], xlworkbookr, xlworkbooke) + book_name1: str = parse_idx(r["sheet_idx0"], result, expected)[0] + styles1: Dict[str, List[Any]] = load_xlsx_styles(*sheet_idx1, book_name1, **r) + + sheet_idx2: Tuple[BOOK, str] = parse_idx(r["sheet_idx1"], xlworkbookr, xlworkbooke) + book_name2: str = parse_idx(r["sheet_idx1"], result, expected)[0] + styles2: Dict[str, List[Any]] = load_xlsx_styles(*sheet_idx2, book_name2, **r) + # number_formats1: List[str] = [c.number_format.lower() for col in sheet1.iter_cols() for c in col if c.value is not None and c.data_type=="n"] + # number_formats2: List[str] = [c.number_format.lower() for col in sheet2.iter_cols() for c in col if c.value is not None and c.data_type=="n"] + metric: bool = styles1 == styles2 + logger.debug("Assertion: %s.style == %s.style - %s", r["sheet_idx0"], r["sheet_idx1"], metric) + # }}} Compare Style (Also Conditional Formatiing) # + + elif r["type"] == "freeze": + # Compare Freezing {{{ # + # sheet_idx0: 0 == "RI0" == "RNSheet1" | "EI0" == "ENSheet1" + # sheet_idx1: as sheet_idx0 + + sheet1: Worksheet = _load_sheet(*parse_idx(r["sheet_idx0"], xlworkbookr, xlworkbooke)) + if sheet1 is None: + return 0. + sheet2: Worksheet = _load_sheet(*parse_idx(r["sheet_idx1"], xlworkbookr, xlworkbooke)) + metric: bool = sheet1.freeze_panes == sheet2.freeze_panes + logger.debug("Assertion: %s.freeze(%s) == %s.freeze(%s) - %s" + , r["sheet_idx0"], sheet1.freeze_panes + , r["sheet_idx1"], sheet2.freeze_panes + , metric + ) + # }}} Compare Freezing # + + elif r["type"] == "zoom": + # Check Zooming {{{ # + # sheet_idx: 0 == "RI0" == "RNSheet1" | "EI0" == "ENSheet1" + # method: str + # ref: value + + sheet: Worksheet = _load_sheet(*parse_idx(r["sheet_idx"], xlworkbookr, xlworkbooke)) + if sheet is None: + return 0. + zoom_scale: Number = sheet.sheet_view.zoomScale or 100. + metric: bool = _match_value_to_rule(zoom_scale, r) + logger.debug("Assertion: %s.zoom(%.1f) %s %.1f - %s", r["sheet_idx"], zoom_scale, r["method"], r["ref"], + metric) + # }}} Check Zooming # + + elif r["type"] == "data_validation": + # Check Data Validation {{{ # + # sheet_idx: 0 == "RI0" == "RNSheet1" | "EI0" == "ENSheet1" + # dv_props: list of dict like {attribute: {"method": str, "ref": anything}} + # available attributes: + # * ranges + # * type + # * formula1 + # * formula2 + # * operator + # * allowBlank + # * showDropDown + # * showInputMessage + # * showErrorMessage + # * error + # * errorTitle + # * errorStyle + # * prompt + # * promptTitle + # * imeMode + + sheet: Worksheet = _load_sheet(*parse_idx(r["sheet_idx"], xlworkbookr, xlworkbooke)) + if sheet is None: + return 0. + data_validators: List[DataValidation] = sheet.data_validations.dataValidation + + total_metric = len(data_validators) >= len(r["dv_props"]) + for dat_vldt in data_validators: + metric = False + for prpt in r["dv_props"]: + metric = metric or all(_match_value_to_rule(getattr(dat_vldt, attrbt) + , mr + ) \ + for attrbt, mr in prpt.items() + ) + if metric: + break + total_metric = total_metric and metric + if not total_metric: + break + + logger.debug("Assertion: %s.data_validation - %s", r["sheet_idx"], total_metric) + metric: bool = total_metric + # }}} Check Data Validation # + + elif r["type"] == "row_props": + # Check Row Properties {{{ # + # sheet_idx0: 0 == "RI0" == "RNSheet1" | "EI0" == "ENSheet1" + # sheet_idx1: as sheet_idx0 + # props: list of str, see utils.load_rows_or_cols + + rows1: Dict[str, Any] = load_rows_or_cols(*parse_idx(r["sheet_idx0"], xlworkbookr, xlworkbooke) + , obj="row" + , **r + ) + rows2: Dict[str, Any] = load_rows_or_cols(*parse_idx(r["sheet_idx1"], xlworkbookr, xlworkbooke) + , obj="row" + , **r + ) + logger.debug("Rows1: %s", repr(rows1)) + logger.debug("Rows2: %s", repr(rows2)) + metric: bool = rows1 == rows2 + logger.debug("Assertion: %s[rows] == %s[rows] - %s", r["sheet_idx0"], r["sheet_idx1"], metric) + # }}} Check Row Properties # + + elif r["type"] == "col_props": + # Check Row Properties {{{ # + # sheet_idx0: 0 == "RI0" == "RNSheet1" | "EI0" == "ENSheet1" + # sheet_idx1: as sheet_idx0 + # props: list of str, see utils.load_rows_or_cols + + cols1: Dict[str, Any] = load_rows_or_cols(*parse_idx(r["sheet_idx0"], xlworkbookr, xlworkbooke) + , obj="column" + , **r + ) + cols2: Dict[str, Any] = load_rows_or_cols(*parse_idx(r["sheet_idx1"], xlworkbookr, xlworkbooke) + , obj="column" + , **r + ) + metric: bool = cols1 == cols2 + logger.debug("Assertion: %s[cols] == %s[cols] - %s", r["sheet_idx0"], r["sheet_idx1"], metric) + # }}} Check Row Properties # + + elif r["type"] == "filter": + # Compare Filters {{{ # + # sheet_idx0: 0 == "RI0" == "RNSheet1" | "EI0" == "ENSheet1" + # sheet_idx1: as sheet_idx0 + + filters1: Dict[str, Any] = load_filters(*parse_idx(r["sheet_idx0"], xlworkbookr, xlworkbooke), **r) + filters2: Dict[str, Any] = load_filters(*parse_idx(r["sheet_idx1"], xlworkbookr, xlworkbooke), **r) + metric: bool = filters1 == filters2 + logger.debug("Assertion: %s[filter] == %s[filter] - %s", r["sheet_idx0"], r["sheet_idx1"], metric) + # }}} Compare Filters # + + elif r["type"] == "pivot_table": + # Compare Pivot Tables {{{ # + # sheet_idx0: 0 == "RI0" == "RNSheet1" | "EI0" == "ENSheet1" + # sheet_idx1: as sheet_idx0 + # pivot_props: list of str, see utils.load_pivot_tables + + pivots1: Dict[str, Any] = load_pivot_tables(*parse_idx(r["sheet_idx0"], xlworkbookr, xlworkbooke), **r) + pivots2: Dict[str, Any] = load_pivot_tables(*parse_idx(r["sheet_idx1"], xlworkbookr, xlworkbooke), **r) + metric: bool = pivots1 == pivots2 + logger.debug("Assertion: %s[pivot]==%s[pivot] - %s", r["sheet_idx0"], r["sheet_idx1"], metric) + # }}} Compare Pivot Tables # + + elif r["type"] == "check_cell": + # Check Cell Properties {{{ # + # sheet_idx: 0 == "RI0" == "RNSheet1" | "EI0" == "ENSheet1" + # coordinate: str, "E3" + # props: dict like {attribute: {"method": str, "ref": anything}} + # supported attributes: value & those supported by utils._read_cell_style + + sheet: Worksheet = _load_sheet(*parse_idx(r["sheet_idx"], xlworkbookr, xlworkbooke)) + if sheet is None: + return 0. + # data_frame: pd.DataFrame = _load_sheet(*parse_idx(r["sheet_idx"], pdworkbookr, pdworkbooke)) + cell: Cell = sheet[r["coordinate"]] + metric: bool = True + for prpt, rule in r["props"].items(): + if prpt == "value": + val = read_cell_value(*parse_idx(r["sheet_idx"], result, expected), r["coordinate"]) + else: + val = _read_cell_style(prpt, cell) + + metric = metric and _match_value_to_rule(val, rule) + + logger.debug("Assertion: %s[%s] :%s - %s" + , r["sheet_idx"], r["coordinate"] + , repr(r["props"]), metric + ) + # }}} Check Cell Properties # + + else: + raise NotImplementedError("Unimplemented sheet check: {:}".format(r["type"])) + + passes = passes and metric + if not passes: + break + + return float(passes) + # }}} function compare_table # + + +def compare_csv(result: str, expected: str, **options) -> float: + if result is None: + return 0. + + with open(result) as f: + result_lines: Iterable[str] = f.read().splitlines() + with open(expected) as f: + expected_lines: Iterable[str] = f.read().splitlines() + if not options.get("strict", True): + result_lines = map(str.strip, result_lines) + expected_lines = map(str.strip, expected_lines) + if options.get("ignore_case", False): + result_lines = map(str.lower, result_lines) + expected_lines = map(str.lower, expected_lines) + + metric: bool = list(result_lines) == list(expected_lines) + return float(metric) + + +def compare_conference_city_in_order(actual_city_list_path, expected_city): + expected_city_list = expected_city["expected"] + wb = openpyxl.load_workbook(actual_city_list_path) + sheet = wb.active + actual_city_list = [] + for row in sheet["C2:C22"]: + for cell in row: + actual_city_list.append(cell.value) + # expected_city is the city that we want to compare with the actual city list + # must in order index + # debug + try: + for i in range(len(actual_city_list)): + if isinstance(expected_city_list[i], str): + if expected_city_list[i] not in actual_city_list[i]: + logger.debug(f"Expected city {expected_city_list[i]}; Actual city {actual_city_list[i]}") + print(f"Expected city {expected_city_list[i]}; Actual city {actual_city_list[i]}") + return 0. + + + elif isinstance(expected_city_list[i], List): + if not any(possible_str in actual_city_list[i] for possible_str in expected_city_list[i]): + logger.debug(f"Expected city {expected_city_list[i]}; Actual city {actual_city_list[i]}") + print(f"Expected city {expected_city_list[i]}; Actual city {actual_city_list[i]}") + return 0. + + else: + raise TypeError("Expected city should be a string or a list of strings") + + except: + return 0. + + return 1. diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/thunderbird.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/thunderbird.py new file mode 100644 index 000000000..5b7aaa077 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/thunderbird.py @@ -0,0 +1,176 @@ +import json +import logging +import re +from typing import List, Pattern, Dict, Match +from typing import Union, Any, TypeVar, Callable + +from .utils import _match_record +from .utils import _match_value_to_rule as _match_pref + +logger = logging.getLogger("desktopenv.metric.thunderbird") + +V = TypeVar("Value") + +_pref_pattern: Pattern[str] = re.compile(r'^user_pref\("(?P(?:[^"]|\\")+)\", (?P.+)\);$'); + + +def check_thunderbird_prefs(result: str, rule: Dict[str, Dict[str, Dict[str, Any]]]): + """ + Args: + result (str): path to result file + rule (Dict[str, Dict[str, Dict[str, Any]]]): dict like + { + "expect": { + str: { + "method": str + "ref": something + } + } + "unexpect": { + str: { + "method": str + "ref": something + } + } + } + + Returns: + float + """ + + if result is None: + return 0. + + expect_rules = rule.get("expect", {}) + unexpect_rules = rule.get("unexpect", {}) + + expect_metrics = {k: False for k in expect_rules} + unexpect_metric = True + with open(result) as f: + for l in f: + match_: Match[str] = _pref_pattern.match(l.strip()) + if match_ is None: + continue + + key: str = match_.group("key") + # value: str = match_.group("val") + # if value in {"true", "false"}: + # value = value.title() + # value: V = eval(value) + value = json.loads(match_.group("val")) + if key in expect_rules: + logger.debug("K: %s, V: %s", key, repr(value)) + expect_metrics[key] = _match_pref(value, expect_rules[key]) + elif key in unexpect_rules: + unexpect_metric = unexpect_metric and not _match_pref(value, unexpect_rules[key]) + + return float(all(expect_metrics.values()) and unexpect_metric) + + +_value_processor: Callable[[str], str] = lambda val: val.replace("\\\"", "\"").replace("\\\\", "\\") +# _condition_pattern: Pattern[str] = re.compile(r'(?PAND|OR) \((?P[\w ]+),(?P[\w ' + '\'' + r']+),(?:"(?P(?:[^"]|\")+)"|(?P[^)]+))\)') +_condition_pattern: Pattern[str] = re.compile( + r'\b(?:AND|OR) \((?:[\w ]+),(?:[\w ' + '\'' + r']+),(?:"(?:(?:[^"]|\")+)"|(?:[^)]+))\)|\bALL\b') + + +def check_thunderbird_filter(result: str, rules: Dict[str, List[Dict[str, str]]]) -> float: + """ + Args: + result (str): path to filter def file + rules (Dict[str, List[Dict[str, str]]]): dict like + { + "expect": [{key: value}] + "unexpect": [{key: value}] + } + + Returns: + float + """ + + if result is None: + return 0. + + # read filter def file + # a filter: + # { + # "name": "Name", + # "enabled": "yes" | "no", + # "type": "17", + # "action": "Move to folder" | ..., + # "actionValue": ..., + # "condition": [...] + # } + filters: List[Dict[str, Union[str, List[str]]]] = [] + with open(result) as f: + for l in f: + if l.startswith("name="): + filter_: Dict[str, Union[str, List[str]]] = {} + filter_["name"] = _value_processor(l[6:-2]) + elif l.startswith("enabled="): + filter_["enabled"] = _value_processor(l[9:-2]) + elif l.startswith("type="): + filter_["type"] = _value_processor(l[6:-2]) + elif l.startswith("action="): + filter_["action"] = _value_processor(l[8:-2]) + elif l.startswith("actionValue="): + filter_["actionValue"] = _value_processor(l[13:-2]) + elif l.startswith("condition="): + condition_str: str = _value_processor(l[11:-2]) + logger.debug("FILTER CONDITION: %s", condition_str) + + conditions: List[str] = \ + _condition_pattern.findall(condition_str) + logger.debug("FILTER CONDITIONS: %s", repr(conditions)) + + filter_["condition"] = conditions + logger.debug("FILTER %s", repr(filter_)) + filters.append(filter_) + + expect_metrics = [False] * len(rules.get("expect", [])) + unexpect_metric = True + for flt in filters: + for i, r in enumerate(rules.get("expect", [])): + expect_metrics[i] = expect_metrics[i] or _match_record(r, flt) + unexpect_metric = unexpect_metric and not any(_match_record(r, flt) for r in rules.get("unexpect", [])) + return float(all(expect_metrics) and unexpect_metric) + + +def check_thunderbird_folder(result: Union[str, List[str]], reference: Union[str, List[str]], **kwargs) -> float: + """ + Check the file or file_list that each text file contains all messages in a folder in Thunderbird. Each message is started with `FROM - `. + **kwargs: + ignore_status (bool): for comparison, ignore the status (X-Mozilla-Status: 0000) of each message. default: False + ignore_keys (bool): for comparison, ignore the keys (X-Mozilla-Keys: label) of each message. default: False + remove_deleted (bool): ignore deleted messages which has status code 0008 or 0009. default: True + remove_duplicate (bool): remove duplicate messages. default: True + """ + + def normalize_msg(msg, options): + ignore_status = options.get('ignore_status', False) + ignore_keys = options.get('ignore_keys', False) + if ignore_status: + msg = re.sub(r'X-Mozilla-Status\d?:[\s\d]+', '', msg) + if ignore_keys: + msg = re.sub(r'(X-Mozilla-Keys:[^\n]*?)\n(MIME-Version)', r'\2', msg) + return msg.strip() + + def read_thunderbird_folder_file(path: str) -> str: + with open(path, 'r') as inf: + data = inf.read().strip() + messages = [] + for mail in data.split('FROM - '): + if mail.strip(): continue + if kwargs.get('remove_deleted', True) and re.search(r'X-Mozilla-Status: 000[89]', mail): continue + messages.append('FROM - ' + normalize_msg(mail, kwargs)) + if kwargs.get('remove_duplicate', True): + messages = set(messages) + return '\n'.join(sorted(messages)) + + if type(reference) != list: + result, reference = [result], [reference] + for pred, gold in zip(result, reference): + if pred is None: return .0 + mail1 = read_thunderbird_folder_file(pred) + mail2 = read_thunderbird_folder_file(gold) + if mail1 != mail2: return .0 + return 1.0 diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/utils.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/utils.py new file mode 100644 index 000000000..e512a2682 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/utils.py @@ -0,0 +1,702 @@ +import builtins +#import datetime +import functools +import itertools +import logging +import operator +import re +import zipfile +#import pandas as pd +from typing import Any, TypeVar, Union, Iterable, Optional, Callable +from typing import Dict, List, Set, Match, Tuple, Pattern +from urllib.parse import urlparse, urlunparse + +import formulas +import lxml.cssselect +import lxml.etree +import xmltodict +from lxml.etree import _Element +from openpyxl import Workbook +from openpyxl.cell.cell import Cell, MergedCell +from openpyxl.chart._chart import ChartBase +from openpyxl.formatting.formatting import ConditionalFormattingList +from openpyxl.pivot.cache import CacheSource as PivotCacheSource +from openpyxl.pivot.table import TableDefinition as PivotTableDefinition +from openpyxl.styles.differential import DifferentialStyle +from openpyxl.utils import coordinate_to_tuple, get_column_letter +from openpyxl.worksheet.cell_range import MultiCellRange, CellRange +from openpyxl.worksheet.dimensions import DimensionHolder +from openpyxl.worksheet.filters import AutoFilter, SortState +from openpyxl.worksheet.worksheet import Worksheet + +V = TypeVar("Value") + +logger = logging.getLogger("desktopenv.metrics.utils") + +_xlsx_namespaces = [("oo", "http://schemas.openxmlformats.org/spreadsheetml/2006/main") + , ("x14", "http://schemas.microsoft.com/office/spreadsheetml/2009/9/main") + , ("xm", "http://schemas.microsoft.com/office/excel/2006/main") + ] +_xlsx_ns_mapping = dict(_xlsx_namespaces) +_xlsx_ns_imapping = dict(map(lambda itm: (itm[1], itm[0]), _xlsx_namespaces)) +_xlsx_ns_imapping["http://schemas.openxmlformats.org/spreadsheetml/2006/main"] = None +_sheet_name_selector = lxml.cssselect.CSSSelector("oo|sheets>oo|sheet", namespaces=_xlsx_ns_mapping) +_sparklines_selector = lxml.cssselect.CSSSelector("x14|sparkline", namespaces=_xlsx_ns_mapping) + + +def load_sparklines(xlsx_file: str, sheet_name: str) -> Dict[str, str]: + # function load_sparklines {{{ # + """ + Args: + xlsx_file (str): path to xlsx + sheet_name (str): sheet name + + Returns: + List[Dict[str, str]]: sparkline definitions in form of + { + "F3": "Sheet1!C3:E3" + } + """ + + # read xlsx + try: + with zipfile.ZipFile(xlsx_file, "r") as z_f: + with z_f.open("xl/workbook.xml") as f: + workbook_database: _Element = lxml.etree.fromstring(f.read()) + sheets: List[_Element] = _sheet_name_selector(workbook_database) + sheet_names: Dict[str, str] = {sh.get("name"): sh.get("sheetId") for sh in sheets} + with z_f.open("xl/worksheets/sheet{:}.xml".format(sheet_names[sheet_name])) as f: + sheet: _Element = lxml.etree.fromstring(f.read()) + sparklines: List[_Element] = _sparklines_selector(sheet) + except zipfile.BadZipFile: + return {} + + sparklines_dict: Dict[str, str] = {} + for sp_l in sparklines: + sparkline_xml: str = lxml.etree.tostring(sp_l, encoding="unicode") + sparkline: Dict[str, Dict[str, str]] = xmltodict.parse(sparkline_xml + , process_namespaces=True + , namespaces=_xlsx_ns_imapping + ) + sparklines_dict[sparkline["x14:sparkline"]["xm:sqref"]] = sparkline["x14:sparkline"]["xm:f"] + return sparklines_dict + # }}} function load_sparklines # + + +# Available Chart Properties: +# title: str +# anchor: ["oneCell" | "twoCell" | "absolute", col0, row0, col1, row1] +# legend: "b" | "tr" | "l" | "r" | "t" +# width: number +# height: number +# type: "scatterChart" | "lineChart" | "barChart" +# direction: "bar" (hori) | "col" (vert) +# xtitle, ytitle, ztitle: str +def load_charts(xlsx_file: Workbook, sheet_name: str, **options) -> Dict[str, Any]: + # function load_charts {{{ # + """ + Args: + xlsx_file (Workbook): concerned excel book + sheet_name (str): sheet name + options (Dict[str, List[str]]): dict like {"chart_props": list of str} + giving the concerned chart properties + + Returns: + Dict[str, Any]: information of charts, dict like + { + : { + : anything + } + } + """ + + # workbook: Workbook = openpyxl.load_workbook(filename=xlsx_file) + try: + worksheet: Worksheet = xlsx_file[sheet_name] + except KeyError: + return {} + charts: List[ChartBase] = worksheet._charts + + chart_set: Dict[str, Any] = {} + chart_props: Set[str] = set(options["chart_props"]) if "chart_props" in options else set() + for ch in charts: + series: List[str] = [] + for ser in ch.series: + if hasattr(ser.val, "numRef") and hasattr(ser.val.numRef, "f"): + value_str: str = ser.val.numRef.f + elif hasattr(ser.val, "strRef") and hasattr(ser.val.strRef, "f"): + value_str: str = ser.val.strRef.f + else: + value_str: str = "" + if hasattr(ser.cat, "numRef") and hasattr(ser.cat.numRef, "f"): + categ_str: str = ser.cat.numRef.f + elif hasattr(ser.cat, "strRef") and hasattr(ser.cat.strRef, "f"): + categ_str: str = ser.cat.strRef.f + else: + categ_str: str = "" + series.append("{:},{:}".format(value_str, categ_str)) + series: str = ";".join(series) + + # TODO: maybe more aspects, like chart type + info: Dict[str, Any] = {} + + if "title" in chart_props: + try: + info["title"] = ch.title.tx.rich.p[0].r[0].t + except: + info["title"] = None + if "legend" in chart_props: + info["legend"] = ch.legend.position if ch.legend is not None else None + if "anchor" in chart_props: + info["anchor"] = [ch.anchor.editAs + , ch.anchor._from.col, ch.anchor.to.row + , ch.anchor.to.col, ch.anchor.to.row + ] + if "width" in chart_props: + info["width"] = ch.width + if "height" in chart_props: + info["height"] = ch.height + if "type" in chart_props: + info["type"] = ch.tagname + if "direction" in chart_props: + info["direction"] = ch.barDir + + if "xtitle" in chart_props: + try: + info["xtitle"] = ch.x_axis.title.tx.rich.p[0].r[0].t + except: + info["xtitle"] = None + if "ytitle" in chart_props: + try: + info["ytitle"] = ch.y_axis.title.tx.rich.p[0].r[0].t + except: + info["ytitle"] = None + if "ztitle" in chart_props: + try: + info["ztitle"] = ch.z_axis.title.tx.rich.p[0].r[0].t + except: + info["ztitle"] = None + chart_set[series] = info + logger.debug(".[%s].charts: %s", sheet_name, repr(chart_set)) + return chart_set + # }}} function load_charts # + + +# Available Pivot Properties: +# name: str +# show_total, show_empty_row, show_empty_col, show_headers: bool +# location: str +# selection: if the concrete item selection should be checked, a list of set of tuple like (bool, index) will be returned; list will be returned instead of set if "ordered" is specified +# filter: if the filter fields should be checked; fields indices will be return in `filter_fields` item +# col_fields: indices +# row_fields: indices +# data_fields: list of str representations. the str representation is like "index;name;subtotal_type;show_data_as"; name is optional and is only returned when `data_fields_name` is specified in `pivot_props` +def load_pivot_tables(xlsx_file: Workbook, sheet_name: str, **options) -> Dict[str, Any]: + # function load_pivot_tables {{{ # + """ + Args: + xlsx_file (Workbook): concerned excel book + sheet_name (str): sheet name + options (Dict[str, List[str]]): dict like {"pivot_props": list of str} + giving the concerned pivot properties + + Returns: + Dict[str, Any]: information of pivot tables, dict like + { + : { + : anything + } + } + """ + + try: + worksheet: Worksheet = xlsx_file[sheet_name] + except KeyError: + return {} + pivots: List[PivotTableDefinition] = worksheet._pivots + + pivot_set: Dict[str, Any] = {} + pivot_props: Set[str] = set(options.get("pivot_props", [])) + for pvt in pivots: + raw_selection: List[List[tuple[Optional[bool], int]]] = \ + [[(itm.h, itm.x) for itm in f.items if itm.x is not None] \ + for f in pvt.pivotFields + ] + raw__selection: List[List[tuple[Optional[bool], int]]] = list( + itertools.dropwhile(lambda r: len(r) == 0, raw_selection)) + left_bias = len(raw_selection) - len(raw__selection) + selection: List[List[tuple[Optional[bool], int]]] = list( + (itertools.dropwhile(lambda r: len(r) == 0, reversed(raw__selection))))[::-1] + right_bias = len(raw__selection) - len(selection) + cache_source: PivotCacheSource = pvt.cache.cacheSource + cell_range1: str + cell_range2: str + cell_range1, cell_range2 = cache_source.worksheetSource.ref.split(":") + cell_range1: Tuple[int, int] = coordinate_to_tuple(cell_range1) + cell_range1 = (cell_range1[0], cell_range1[1] + left_bias) + cell_range2: Tuple[int, int] = coordinate_to_tuple(cell_range2) + cell_range2 = (cell_range2[0], cell_range2[1] - right_bias) + source: str = "{:};{:}:{:};{:}".format(cache_source.type, cell_range1, cell_range2, + cache_source.worksheetSource.sheet) + + info: Dict[str, Any] = {} + if "name" in pivot_props: + info["name"] = pvt.name + + if "show_total" in pivot_props: + info["show_total"] = pvt.visualTotals + if "show_empty_row" in pivot_props: + info["show_empty_row"] = pvt.showEmptyRow + if "show_empty_col" in pivot_props: + info["show_empty_col"] = pvt.showEmptyCol + if "show_headers" in pivot_props: + info["show_headers"] = pvt.showHeaders + + if "location" in pivot_props: + info["location"] = pvt.location + if "filter" in pivot_props or "selection" in pivot_props: + info["selection"] = selection if "ordered" in pivot_props else list(set(r) for r in selection) + if "filter" in pivot_props: + info["filter_fields"] = set(f.fld for f in pvt.pageFields) + if "col_fields" in pivot_props: + info["col_fields"] = [f.x - left_bias for f in pvt.colFields] + if "row_fields" in pivot_props: + info["row_fields"] = [f.x - left_bias for f in pvt.rowFields] + if "data_fields" in pivot_props: + info["data_fields"] = [ + "{:d};{:};{:};{:}".format(f.fld - left_bias, f.name if "data_fields_name" in pivot_props else "" + , f.subtotal, f.showDataAs + ) \ + for f in pvt.dataFields + ] + + pivot_set[source] = info + logger.debug(".[%s].pivots: %s", sheet_name, repr(pivot_set)) + return pivot_set + # }}} function load_pivot_tables # + + +_shared_str_selector = lxml.cssselect.CSSSelector("oo|sst>oo|si", namespaces=_xlsx_ns_mapping) +_shared_str_value_selector = lxml.cssselect.CSSSelector("oo|t", namespaces=_xlsx_ns_mapping) + + +def read_cell_value(xlsx_file: str, sheet_name: str, coordinate: str) -> Any: + # read_cell_value {{{ # + try: + with zipfile.ZipFile(xlsx_file, "r") as z_f: + try: + with z_f.open("xl/sharedStrings.xml") as f: + shared_str_xml: _Element = lxml.etree.fromstring(f.read()) + str_elements: List[_Element] = _shared_str_selector(shared_str_xml) + shared_strs: List[str] = [ "".join(t.text for t in _shared_str_value_selector(elm))\ + for elm in str_elements + ] + except: + #logger.exception("Read shared strings error: %s", xlsx_file) + logger.debug("Read shared strings error: %s", xlsx_file) + shared_strs: List[str] = [] + + with z_f.open("xl/workbook.xml") as f: + workbook_database: _Element = lxml.etree.fromstring(f.read()) + sheets: List[_Element] = _sheet_name_selector(workbook_database) + sheet_names: Dict[str, str] = {sh.get("name"): sh.get("sheetId") for sh in sheets} + + with z_f.open("xl/worksheets/sheet{:}.xml".format(sheet_names[sheet_name])) as f: + sheet: _Element = lxml.etree.fromstring(f.read()) + cells: List[_Element] = \ + lxml.cssselect.CSSSelector('oo|row>oo|c[r="{:}"]'.format(coordinate) + , namespaces=_xlsx_ns_mapping + )(sheet) + if len(cells) == 0: + return None + cell: _Element = cells[0] + except zipfile.BadZipFile: + return None + + cell: Dict[str, str] = xmltodict.parse(lxml.etree.tostring(cell, encoding="unicode") + , process_namespaces=True + , namespaces=_xlsx_ns_imapping + ) + logger.debug("%s.shared_strings: %s", xlsx_file, repr(shared_strs)) + logger.debug("%s.%s[%s]: %s", xlsx_file, sheet_name, coordinate, repr(cell)) + try: + if "@t" not in cell["c"] or cell["c"]["@t"] == "n": + return float(cell["c"]["v"]) + if cell["c"]["@t"] == "s": + return shared_strs[int(cell["c"]["v"])] + if cell["c"]["@t"] == "str": + return cell["c"]["v"] + if cell["c"]["@t"] == "inlineStr": + return cell["c"]["is"]["t"] + except (KeyError, ValueError): + return None + # }}} read_cell_value # + + +# Supported Styles: +# number_format +# font_name - str +# font_family - float +# font_color - in aRGB, e.g., FF000000 is black +# font_bold - bool +# font_italic - bool +# font_underline - "single" | "double" | "singleAccounting" | "doubleAccounting" +# font_size - float +# fill_type - "patternFill" | "gradientFill" +# bgcolor - in aRGB, e.g., FFFF0000 is red; This property seems to be ambiguous with fgcolor in xlsx, strange +# fgcolor - in aRGB, e.g., FF00FFFF is yellow # Deprecated +# hyperlink - str +# merge - bool, if the cell is in a merged range and is not the first cell in the merged range +def _read_cell_style(style_name: str, cell: Union[Cell, MergedCell], diff_style: Optional[DifferentialStyle] = None) -> Any: + if style_name == "number_format": + return (cell.number_format if diff_style is None else diff_style.numFmt.formatCode) \ + if cell.value is not None and cell.data_type == "n" else None + elif style_name == "font_name": + return (diff_style or cell).font.name if cell.value is not None else None + elif style_name == "font_family": + return (diff_style or cell).font.family if cell.value is not None else None + elif style_name == "font_color": + return (diff_style or cell).font.color.rgb if cell.value is not None else None + elif style_name == "font_bold": + return (diff_style or cell).font.bold if cell.value is not None else None + elif style_name == "font_italic": + return (diff_style or cell).font.italic if cell.value is not None else None + elif style_name == "font_underline": + return (diff_style or cell).font.underline if cell.value is not None else None + elif style_name == "font_size": + return (diff_style or cell).font.size if cell.value is not None else None + elif style_name == "fill_type": + try: + return (diff_style or cell).fill.tagname + except: + return None + elif style_name == "bgcolor" or style_name == "fgcolor": + try: + #return (diff_style or cell).fill.bgColor.rgb + if diff_style is not None: + return diff_style.fill.bgColor.rgb + else: + return cell.fill.fgColor.rgb + except: + return None + #elif style_name == "fgcolor": + #try: + #return (diff_style or cell).fill.fgColor.rgb + #except: + #return None + elif style_name == "hyperlink": + return cell.hyperlink or "" if cell.value is not None else None + elif style_name == "merge": + return isinstance(cell, MergedCell) + else: + raise NotImplementedError("Unsupported Style: {:}".format(style_name)) + + +_absolute_range_pattern: Pattern[str] = re.compile(r"""\$(?P[A-Z]{1,3})\$(?P\d+) # coord1 + (?:: + \$(?P[A-Z]{1,3})\$(?P\d+) # coord2 + )? + """ + , re.X + ) + + +def load_xlsx_styles(xlsx_file: Workbook, sheet_name: str, book_name: str, **options) -> Dict[str, List[Any]]: + # function load_xlsx_styles {{{ # + """ + Args: + xlsx_file (Workbook): concerned excel book + sheet_name (str): sheet name + book_name (str): book name + options (Dict[str, List[str]): dick like {"props": list of str} giving + the concerned styles + + Returns: + Dict[str, List[Any]]: dict like + { + : list of anything indicating concerned + property values + } + """ + + try: + worksheet: Worksheet = xlsx_file[sheet_name] + except KeyError: + return {} + + style_dict: Dict[str, List[Any]] = {} + concerned_styles: List[str] = options.get("props", []) + + # Handles Cell Styles + for col in worksheet.iter_cols(): + for c in col: + style_list: List[Any] = [] + for st in concerned_styles: + style_list.append(_read_cell_style(st, c)) + style_dict[c.coordinate] = style_list + + # Handles Conditional Formatting + conditional_formattings: ConditionalFormattingList = worksheet.conditional_formatting + formula_parser = formulas.Parser() + for fmt in conditional_formattings: + for r in fmt.rules: + active_cells: List[Cell] = [] + if r.type == "expression": + condition: Callable[[str], bool] = formula_parser.ast("=" + r.formula[0])[1].compile() + logger.debug("Expression condition: %s", r.formula[0]) + + arguments: List[Any] = [] + absolute_range_match: List[Tuple[str, str, str, str]] = _absolute_range_pattern.findall(r.formula[0]) + for m in absolute_range_match: + logger.debug("Absolute ranges: %s", repr(m)) + if m[2] is None and m[3] is None: + arguments.append(read_cell_value(book_name, sheet_name, coordinate="{:}{:}".format(m[0], m[1]))) + else: + arguments.append([read_cell_value(book_name, sheet_name + , coordinate="{:}{:}".format(get_column_letter(c[1]) + , c[0] + ) + ) \ + for c in CellRange("{:}{:}:{:}{:}".format(m[0], m[1], m[2], m[3])).cells \ + ] + ) + logger.debug("Absolute range arguments: %s", repr(arguments)) + + nb_contiguous_nothings = 0 + for rge in fmt.cells: + for c in rge.cells: + cell: Cell = worksheet.cell(row=c[0], column=c[1]) + cell_value = read_cell_value(book_name, sheet_name + , coordinate="{:}{:d}".format(get_column_letter(c[1]) + , c[0] + ) + ) + if cell_value is None: + nb_contiguous_nothings += 1 + if nb_contiguous_nothings>50: + break + continue + elif condition(cell_value, *arguments): + logger.debug("Active Cell %s(%s) for %s", repr(cell), str(cell_value), r.formula[0]) + active_cells.append(cell) + else: + raise NotImplementedError("Not Implemented Condition Type: {:}".format(r.type)) + + for c in active_cells: + style_dict[c.coordinate] = [_read_cell_style(st, c, r.dxf) for st in concerned_styles] + + logger.debug(".[%s].styles: %s", sheet_name, repr(style_dict)) + return style_dict + # }}} function load_xlsx_styles # + + +# Available Row Properties: +# hidden +# collapsed +# height +# +# Available Column Properties: +# width +# auto_size +# hidden +# collapsed +# min +# max +def load_rows_or_cols(xlsx_file: Workbook, sheet_name: str, **options) \ + -> Dict[Union[int, str], Dict[str, Any]]: + # function load_rows_or_cols {{{ # + """ + Args: + xlsx_file (Workbook): concerned excel book + sheet_name (str): sheet name + options (Dict[str, List[str]]): dict like + {"obj": "row" | "column", "props": list of str} giving the concerned + row/column properties + + Returns: + Dict[Union[int, str], Dict[str, Any]]: row/column information + """ + + try: + worksheet: Worksheet = xlsx_file[sheet_name] + except KeyError: + return {} + objs: DimensionHolder = getattr(worksheet, "{:}_dimensions".format(options["obj"])) + + obj_set: Dict[int, Any] = {} + obj_props: Set[str] = set(options.get("props", [])) + for obj_no, obj_dms in objs.items(): + info_dict: Dict[str, Any] = {} + for prop in obj_props: + info_dict[prop] = getattr(obj_dms, prop) + obj_set[obj_no] = info_dict + return obj_set + # }}} function load_rows_or_cols # + + +def load_filters(xlsx_file: Workbook, sheet_name: str, **options) -> Dict[str, Any]: + # function load_filters {{{ # + try: + worksheet: Worksheet = xlsx_file[sheet_name] + except KeyError: + return {} + + filters: AutoFilter = worksheet.auto_filter + filter_dict: Dict[str, Any] = {} + filter_dict["ref"] = filters.ref + + # filterColumn + filter_column_set: List[Dict[str, Any]] = [] + for flt_clm in filters.filterColumn: + filter_column: Dict[str, Any] = {} + filter_column["col_id"] = flt_clm.colId + filter_column["hidden_button"] = flt_clm.hiddenButton + filter_column["show_button"] = flt_clm.showButton + if flt_clm.filters is not None: + filter_column["filters_blank"] = flt_clm.filters.blank + filter_column["filters"] = set(flt_clm.filters.filter) + if flt_clm.customFilters is not None: + filter_column["custom_filters_op"] = flt_clm.customFilters._and + filter_column["custom_filters"] = set((flt.operator + , flt.val + ) \ + for flt in flt_clm.customFilters.customFilter + ) + filter_column_set.append(filter_column) + filter_column_set = list(sorted(filter_column_set + , key=(lambda d: d["col_id"]) + ) + ) + filter_dict["filter_column"] = filter_column_set + + # sortState + sort_state: Optional[SortState] = filters.sortState + if sort_state is not None: + sort_state_dict: Dict[str, Any] = {} + sort_state_dict["sort"] = sort_state.columnSort + sort_state_dict["case"] = sort_state.caseSensitive + sort_state_dict["method"] = sort_state.sortMethod + sort_state_dict["ref"] = sort_state.ref + sort_state_dict["condition"] = list({"descending": cdt.descending + , "key": cdt.sortBy + , "ref": cdt.ref + , "custom_list": cdt.customList + , "dxf_id": cdt.dxfId + , "icon": cdt.iconSet + , "iconid": cdt.iconId + } \ + for cdt in sort_state.sortCondition + ) + filter_dict["sort_state"] = sort_state_dict + + return filter_dict + # }}} function load_filters # + + +def _match_record(pattern: Dict[str, Any], item: Dict[str, Any]) -> bool: + return all(k in item and item[k] == val for k, val in pattern.items()) + + +def _multicellrange_containsby(subset_candidate: MultiCellRange, superset_candidate: MultiCellRange) -> bool: + return all(r in superset_candidate for r in subset_candidate) + + +def _match_value_to_rule(value: V, rule: Dict[str, Union[str, V]]) -> bool: + """ + Args: + value (V): value to match + rule (Dict[str, Union[str, V]]): rule dict like + { + "method": str + "ref": V as ref value + } + + Returns: + bool + """ + + if rule["method"].startswith("re"): # re.FLAGs + flags: List[str] = rule["method"].split(".")[1:] + flags: Iterable[re.RegexFlag] = (getattr(re, fl) for fl in flags) + flag: re.RegexFlag = functools.reduce(operator.or_, flags, re.RegexFlag(0)) + logger.debug("REFLAG: %s", repr(flag)) + + match_: Optional[Match[str]] = re.search(rule["ref"], value, flag) + return match_ is not None + if rule["method"] in {"eq", "ne" + , "le", "lt" + , "ge", "gt" + }: + return getattr(operator, rule["method"])(value, rule["ref"]) + if rule["method"].startswith("approx"): # approx:THRESHOLD + threshold: float = float(rule["method"].split(":")[1]) + logger.debug("Approx: TH%f, REF%f, VAL%s", threshold, rule["ref"], repr(value)) + try: + value = float(value) + except (ValueError, TypeError): + return False + else: + return abs(value - rule["ref"]) <= threshold + if rule["method"] == "spreadsheet_range": + subset_limit = MultiCellRange(rule["ref"][0]) + superset_limit = MultiCellRange(rule["ref"][1]) + return _multicellrange_containsby(subset_limit, value) \ + and _multicellrange_containsby(value, superset_limit) + if rule["method"].startswith("range."): # e.g., range.te [0, 2] -> 0 < x <= 2 + left_et = rule["method"][6] + right_et = rule["method"][7] + return getattr(operator, "l" + left_et)(rule["ref"][0], value) \ + and getattr(operator, "l" + right_et)(value, rule["ref"][1]) + if rule["method"] in {"str_list_eq", "str_set_eq"}: + container_type_str: str = rule["method"][4:-3] + container_type = getattr(builtins, container_type_str) + + value: container_type = container_type(value.strip("\"'").split(",")) + ref: container_type = container_type(rule["ref"]) + return value == ref + raise NotImplementedError() + + +def are_lists_equal(list1, list2, comparison_func): + # First check if both lists have the same length + if len(list1) != len(list2): + return False + + # Now make sure each element in one list has an equal element in the other list + for item1 in list1: + # Use the supplied function to test for an equal item + if not any(comparison_func(item1, item2) for item2 in list2): + return False + + # If all items match, the lists are equal + return True + + +def compare_urls(url1, url2): + if url1 is None or url2 is None: + return url1 == url2 + + def normalize_url(url): + # Parse the URL + parsed_url = urlparse(url) + + # If no scheme is present, assume 'http' + scheme = parsed_url.scheme if parsed_url.scheme else 'http' + + # Lowercase the scheme and netloc, remove 'www.', and handle trailing slash + normalized_netloc = parsed_url.netloc.lower().replace("www.", "") + normalized_path = parsed_url.path if parsed_url.path != '/' else '' + + # Reassemble the URL with normalized components + normalized_parsed_url = parsed_url._replace(scheme=scheme.lower(), netloc=normalized_netloc, + path=normalized_path) + normalized_url = urlunparse(normalized_parsed_url) + + return normalized_url + + # Normalize both URLs for comparison + norm_url1 = normalize_url(url1) + norm_url2 = normalize_url(url2) + + # Compare the normalized URLs + return norm_url1 == norm_url2 diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/vlc.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/vlc.py new file mode 100644 index 000000000..bb8e5deea --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/vlc.py @@ -0,0 +1,524 @@ +import logging +import os +import subprocess +from typing import Dict +from xml.etree import ElementTree +from urllib.parse import urlparse + +import acoustid +import cv2 +import imagehash +import librosa +import numpy as np +from PIL import Image +from fastdtw import fastdtw +from scipy.spatial.distance import cosine +from skimage.metrics import structural_similarity as ssim + +logger = logging.getLogger("desktopenv.metrics.vlc") + + +def is_vlc_playing(actual_status_path: str, rule: Dict[str, str]) -> float: + """ + Checks if VLC is currently playing a file. + """ + with open(actual_status_path, 'rb') as file: + actual_status = file.read().decode('utf-8') + + tree = ElementTree.fromstring(actual_status) + status = tree.find('state').text + logger.info(f"VLC Status: {status}") + if status == 'playing': + if rule['type'] == 'file_name': + # Try multiple possible paths for file information in VLC XML + file_paths = [ + 'information/category[@name="meta"]/info[@name="filename"]', + 'information/category[@name="meta"]/info[@name="title"]', + 'information/category[@name="meta"]/info[@name="uri"]', + 'information/category[@name="meta"]/info[@name="location"]', + 'information/category[@name="meta"]/info[@name="name"]' + ] + + file_info = None + for path in file_paths: + element = tree.find(path) + if element is not None and element.text: + file_info = element.text + break + + if file_info: + expected_filename = rule['file_name'] + + # Method 1: Direct filename match (most precise) + actual_basename = os.path.basename(file_info) + if actual_basename == expected_filename: + return 1 + + # Method 2: Endswith match (for backward compatibility) + if file_info.endswith(expected_filename): + return 1 + + # Method 3: For paths, check if expected filename is in the path + if expected_filename in file_info: + # Additional check to avoid false positives + # Make sure it's actually the filename, not just part of a path + if file_info.endswith('/' + expected_filename) or file_info.endswith('\\' + expected_filename): + return 1 + + logger.warning(f"File name mismatch - Expected: {expected_filename}, Found: {file_info}") + return 0 + else: + logger.warning(f"Could not find file information in VLC status XML for rule: {rule}") + return 0 + elif rule['type'] == 'url': + # Try multiple possible paths for URL information in VLC XML + url_paths = [ + 'information/category[@name="meta"]/info[@name="url"]', + 'information/category[@name="meta"]/info[@name="URI"]', + 'information/category[@name="meta"]/info[@name="location"]', + 'information/category[@name="meta"]/info[@name="title"]', # Sometimes URL is in title for streams + 'information/category[@name="meta"]/info[@name="filename"]', # Sometimes URL is in filename for streams + 'information/category[@name="Stream 0"]/info[@name="Codec"]', # Try stream info + 'information/category[@name="Stream 0"]/info[@name="Type"]', + 'information/category[@name="Stream 0"]/info[@name="Language"]' + ] + + file_info = None + logger.debug(f"Looking for URL: {rule['url']}") + + for path in url_paths: + element = tree.find(path) + if element is not None and element.text: + file_info = element.text + logger.debug(f"Found URL info at '{path}': {file_info}") + break + + if file_info: + # For URL comparison, check if the rule URL is contained in the file_info + # This handles cases where VLC might show a longer or modified URL + expected_url = rule['url'] + + # Method 1: Direct URL match + if expected_url in file_info or file_info.endswith(expected_url): + return 1 + + # Method 2: For HLS streams, VLC often shows just the filename instead of full URL + # Check if the file_info matches the filename part of the expected URL + try: + expected_parsed = urlparse(expected_url) + expected_filename = os.path.basename(expected_parsed.path) + + # If VLC shows just the filename (common for HLS streams) + if file_info == expected_filename: + logger.info(f"URL filename match - Expected URL: {expected_url}, VLC shows filename: {file_info}") + return 1 + + # Method 3: Check if both are URLs from the same domain and similar path + if '://' in file_info: # file_info is also a URL + actual_parsed = urlparse(file_info) + # Same domain and similar path structure + if (expected_parsed.netloc == actual_parsed.netloc and + expected_parsed.path in actual_parsed.path): + return 1 + except Exception as e: + logger.debug(f"URL parsing error: {e}") + pass + + logger.warning(f"URL mismatch - Expected: {expected_url}, Found: {file_info}") + return 0 + else: + logger.warning(f"Could not find URL information in VLC status XML for rule: {rule}") + return 0 + else: + logger.error(f"Unknown type: {rule['type']}") + return 0 + else: + return 0 + + +# fixme: part of this function can be moved to getters +def is_vlc_recordings_folder(actual_config_path: str, rule: Dict[str, str]) -> float: + """ + Checks if VLC's recording folder is set to the expected value. + """ + with open(actual_config_path, 'rb') as file: + config_file = file.read().decode('utf-8') + + expected_recording_file_path = rule['recording_file_path'] + + try: + for line in config_file.split("\n"): + # Skip comments and empty lines + if line.startswith('#') or not line.strip(): + continue + # Check if the line contains the recording path setting + if 'input-record-path' in line: + # Extract the value of the recording path and remove surrounding whitespace + current_path = line.split('=')[-1].strip() + # Compare with the Desktop path + if current_path == expected_recording_file_path: + return 1 + else: + return 0 + # The configuration key was not found in the file + return 0 + except FileNotFoundError: + logger.error("VLC configuration file not found.") + return 0 + except Exception as e: + logger.error(f"An error occurred: {e}") + return 0 + + +def is_vlc_fullscreen(actual_window_size, screen_size): + if screen_size is None or actual_window_size is None: + # if the screen size is not available, means that the window is not fullscreen + return 0 + + if actual_window_size['width'] == screen_size['width'] and actual_window_size['height'] == screen_size['height']: + return 1 + else: + return 0 + + +def compare_images(image1_path, image2_path, **options): + # You would call this function with the paths to the two images you want to compare: + # score = compare_images('path_to_image1', 'path_to_image2') + # print("Similarity score:", score) + + if not image1_path or not image2_path: + return 0 + + base_score = options.get("reference_base_result", None) + + # Open the images and convert to grayscale + image1 = Image.open(image1_path).convert('L') + image2 = Image.open(image2_path).convert('L') + + # Resize images to the smaller one's size for comparison + image1_size = image1.size + image2_size = image2.size + new_size = min(image1_size, image2_size) + + image1 = image1.resize(new_size, Image.Resampling.LANCZOS) + image2 = image2.resize(new_size, Image.Resampling.LANCZOS) + + # Convert images to numpy arrays + image1_array = np.array(image1) + image2_array = np.array(image2) + + # Calculate SSIM between two images + similarity_index = ssim(image1_array, image2_array) + + epsilon = 0.01 + if base_score is not None: + if similarity_index >= base_score + epsilon: + return (similarity_index - base_score) / (1 - base_score) + else: + return 0 + else: + return similarity_index + +def compare_audios(audio_path_1, audio_path_2): + """ + Compare two audio files and return a similarity score in the range [0, 1]. + audio_path_1, audio_path_2: paths to the audio files to compare + """ + # similarity = compare_audios_simple('path_to_audio1.mp3', 'path_to_audio2.mp3') + # print(f'Similarity Score: {similarity}') + + if not audio_path_1 or not audio_path_2: + return 0 + + y1, y2 = None, None + try: + y1, sr1 = librosa.load(audio_path_1) + except Exception: + logger.warning(f"Could not load audio from {os.path.basename(audio_path_1)}. It might be empty or corrupt.") + + try: + y2, sr2 = librosa.load(audio_path_2) + except Exception: + logger.warning(f"Could not load audio from {os.path.basename(audio_path_2)}. It might be empty or corrupt.") + + # Handle cases where one or both audio files are empty or corrupt. + is_y1_bad = (y1 is None) or (y1.shape[0] == 0) + is_y2_bad = (y2 is None) or (y2.shape[0] == 0) + + if is_y1_bad and is_y2_bad: + logger.info("Both audio files are empty or corrupt. Considering them perfectly similar.") + return 1.0 + + if is_y1_bad or is_y2_bad: + logger.warning(f"One audio file is empty/corrupt, the other is not. Similarity is 0.") + return 0.0 + + try: + logger.info(f"Audio 1 ({os.path.basename(audio_path_1)}): sr={sr1}, len={len(y1)}") + logger.info(f"Audio 2 ({os.path.basename(audio_path_2)}): sr={sr2}, len={len(y2)}") + + # Extract MFCC features + mfcc1 = librosa.feature.mfcc(y=y1, sr=sr1) + mfcc2 = librosa.feature.mfcc(y=y2, sr=sr2) + except Exception as e: + logger.error(f"Error during MFCC extraction: {e}") + return 0.0 + + # Normalize the MFCC features + mfcc1 = librosa.util.normalize(mfcc1, axis=1) + mfcc2 = librosa.util.normalize(mfcc2, axis=1) + logger.info(f"MFCCs normalized.") + + # Define a lambda function to compute cosine distance + dist_func = lambda x, y: cosine(x, y) + + # Use the DTW algorithm to find the best alignment path + distance, path = fastdtw(mfcc1.T, mfcc2.T, dist=dist_func) + logger.info(f"DTW distance: {distance:.4f}, Path length: {len(path)}") + + # Normalize the DTW distance by the length of the alignment path. + if len(path) == 0: + normalized_distance = np.inf + else: + normalized_distance = distance / len(path) + logger.info(f"Normalized DTW distance: {normalized_distance:.4f}") + + # Convert the normalized distance to a similarity score using an exponential decay function. + similarity = np.exp(-normalized_distance) + + return similarity + + +def compare_audios_by_dl_model(audio_path_1, audio_path_2): + pass + + +def compare_videos(video_path1, video_path2, max_frames_to_check=100, threshold=5): + # Open both video files + cap1 = cv2.VideoCapture(video_path1) + cap2 = cv2.VideoCapture(video_path2) + + frames_checked = 0 + mismatch_count = 0 + + while frames_checked < max_frames_to_check: + # Read frames from both videos + ret1, frame1 = cap1.read() + ret2, frame2 = cap2.read() + + # If a video ends, then check if both ended to confirm they are of the same length + if not ret1 or not ret2: + return ret1 == ret2 + + # Convert frames to PIL Images + frame1 = Image.fromarray(cv2.cvtColor(frame1, cv2.COLOR_BGR2RGB)) + frame2 = Image.fromarray(cv2.cvtColor(frame2, cv2.COLOR_BGR2RGB)) + + # Compute the perceptual hash for each frame + hash1 = imagehash.phash(frame1) + hash2 = imagehash.phash(frame2) + + # Increment the frames checked + frames_checked += 1 + + # Compute the difference in the hashes + if hash1 - hash2 > threshold: + mismatch_count += 1 + # If there's a significant difference, the frames are not the same + if mismatch_count > threshold: + return 0. + + # If we reach here, the content appears to be the same + return 1. + + +def check_qt_bgcone(actual_config_path, rule): + with open(actual_config_path, 'rb') as file: + config_file = file.read().decode('utf-8') + + expected_qt_bgcone = rule['expected_qt_bgcone'] + if isinstance(expected_qt_bgcone, int): + expected_qt_bgcone = str(expected_qt_bgcone) + + try: + # The default value of qt_bgcone is 1, which means it is enabled + qt_bgcone = "1" + for line in config_file.split("\n"): + # Check if the line contains the recording path setting + if 'qt-bgcone=' in line: + # Extract the value of the recording path and remove surrounding whitespace + qt_bgcone = line.split('=')[-1].strip() + # The configuration key was not found in the file + + if qt_bgcone == expected_qt_bgcone: + return 1 + else: + return 0 + except FileNotFoundError: + logger.error("VLC configuration file not found.") + return 0 + except Exception as e: + logger.error(f"An error occurred: {e}") + return 0 + + +def check_qt_max_volume(actual_config_path, rule): + with open(actual_config_path, 'rb') as file: + config_file = file.read().decode('utf-8') + + expected_qt_max_volume = rule['expected_qt_max_volume'] + if isinstance(expected_qt_max_volume, int): + expected_qt_max_volume = str(expected_qt_max_volume) + + try: + qt_max_volume = "125" + for line in config_file.split("\n"): + if 'qt-max-volume=' in line: + qt_max_volume = line.split('=')[-1].strip() + # The configuration key was not found in the file + + if qt_max_volume == expected_qt_max_volume: + return 1 + else: + return 0 + except FileNotFoundError: + logger.error("VLC configuration file not found.") + return 0 + except Exception as e: + logger.error(f"An error occurred: {e}") + return 0 + + +def check_qt_minimal_view(actual_config_path, rule): + with open(actual_config_path, 'rb') as file: + config_file = file.read().decode('utf-8') + + expected_qt_minimal_view = rule['expected_qt_minimal_view'] + if isinstance(expected_qt_minimal_view, int): + expected_qt_minimal_view = str(expected_qt_minimal_view) + + try: + qt_minimal_view = "0" + for line in config_file.split("\n"): + if 'qt-minimal-view=' in line: + qt_minimal_view = line.split('=')[-1].strip() + + if qt_minimal_view == expected_qt_minimal_view: + return 1 + else: + return 0 + except FileNotFoundError: + logger.error("VLC configuration file not found.") + return 0 + except Exception as e: + logger.error(f"An error occurred: {e}") + return 0 + + +def check_qt_slider_colours(actual_config_path, rule): + with open(actual_config_path, 'rb') as file: + config_file = file.read().decode('utf-8') + + try: + qt_slider_colours = "153;210;153;20;210;20;255;199;15;245;39;29" + for line in config_file.split("\n"): + if 'qt-slider-colours' in line: + qt_slider_colours = line.split('=')[-1].strip() + # The configuration key was not found in the file + + if rule['type'] == 'match': + expected_qt_slider_colours = rule['expected_qt_slider_colours'] + if qt_slider_colours == expected_qt_slider_colours: + return 1 + else: + return 0 + elif rule['type'] == 'blackish': + def is_color_blackish(rgb_values, threshold=100): + # decide if the color is blackish + return all(value < threshold for value in rgb_values) + + def parse_qt_slider_colours(colours_string): + # parse the string of colours into a list of RGB tuples + values = [int(x) for x in colours_string.split(';')] + colors = list(zip(values[0::3], values[1::3], values[2::3])) + return colors + + colors = parse_qt_slider_colours(qt_slider_colours) + + # check if all colors are blackish + for color in colors: + if is_color_blackish(color): + pass + else: + return 0 + return 1 + + except FileNotFoundError: + logger.error("VLC configuration file not found.") + return 0 + except Exception as e: + logger.error(f"An error occurred: {e}") + return 0 + + +def check_global_key_play_pause(actual_config_path, rule): + """ + # Play/Pause (str) + #global-key-play-pause= + + # Play/Pause (str) + #key-play-pause=Space + """ + with open(actual_config_path, 'rb') as file: + config_file = file.read().decode('utf-8') + + expected_global_key_play_pause = rule['expected_global_key_play_pause'] + + if isinstance(expected_global_key_play_pause, int): + expected_global_key_play_pause = str(expected_global_key_play_pause) + + try: + global_key_play_pause = "0" + for line in config_file.split("\n"): + # Check if the line contains the recording path setting + if 'global-key-play-pause=' in line: + global_key_play_pause = "0" if line.split('=')[-1].strip() == "" else "1" + + if global_key_play_pause == expected_global_key_play_pause: + return 1 + else: + return 0 + except FileNotFoundError: + logger.error("VLC configuration file not found.") + return 0 + except Exception as e: + logger.error(f"An error occurred: {e}") + return 0 + + +def check_one_instance_when_started_from_file(actual_config_path, rule): + with open(actual_config_path, 'rb') as file: + config_file = file.read().decode('utf-8') + + expected_one_instance_when_started_from_file = rule['expected_one_instance_when_started_from_file'] + + if isinstance(expected_one_instance_when_started_from_file, int): + expected_one_instance_when_started_from_file = str(expected_one_instance_when_started_from_file) + + try: + one_instance_when_started_from_file = "1" + for line in config_file.split("\n"): + # Check if the line contains the recording path setting + if 'one-instance-when-started-from-file=' in line: + one_instance_when_started_from_file = line.split('=')[-1].strip() + + if one_instance_when_started_from_file == expected_one_instance_when_started_from_file: + return 1 + else: + return 0 + except FileNotFoundError: + logger.error("VLC configuration file not found.") + return 0 + except Exception as e: + logger.error(f"An error occurred: {e}") + return 0 diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/vscode.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/vscode.py new file mode 100644 index 000000000..82129f686 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/evaluators/metrics/vscode.py @@ -0,0 +1,283 @@ +import copy +import importlib.util +import json +import sys +import re +from typing import Dict + + +def check_json_keybindings(actual: str, expected: str, **options) -> float: + """ + Args: + actual (str): path to result text file + expected (str): expected dict{} + + Return: + float: the score + """ + + def direct_load_json(fp): + try: + with open(fp, 'r') as f: + data = json.load(f) + return data + except: + return None + + def skip_first_line_load_json(fp): + try: + with open(fp, 'r') as f: + f.readline() + data = json.load(f) + return data + except: + return None + + for func in [direct_load_json, skip_first_line_load_json]: + data = func(actual) + if data is not None and type(data) == list: + break + else: + return 0.0 + expected = expected['expected'] + if expected in data: + return 1.0 + else: + return 0.0 + + +def check_json_settings(actual: str, expected: str, **options) -> float: + """ + Args: + actual (str): path to result text file + expected (dict): expected dict{}, containing key "expect" + + Return: + float: the score + """ + if not actual: + return 0. + + try: + with open(actual, 'r') as f: + data = json.load(f) + except Exception as e: + return 0.0 + + expect = expected['expected'] + + # Check if all expected key-value pairs are in the actual data + for key, value in expect.items(): + if key not in data or data[key] != value: + return 0.0 + + return 1.0 + + +def compare_text_file(actual: str, expected: str, **options) -> float: + """ + Args: + actual (str): path to result text file + expected (str): path to gold text file + + Return: + float: the score + """ + if not actual: + return 0. + + with open(actual) as f1: + actual_text = f1.read() + with open(expected) as f2: + expected_text = f2.read() + + ignore_blanks = options.get('ignore_blanks', False) + if ignore_blanks: + actual_text = re.sub(r'[\t\n]', ' ', actual_text).strip() + actual_text = re.sub(r'\s+', ' ', actual_text) + expected_text = re.sub(r'[\t\n]', ' ', expected_text).strip() + expected_text = re.sub(r'\s+', ' ', expected_text) + + ignore_case = options.get('ignore_case', False) + if ignore_case: + actual_text = actual_text.lower() + expected_text = expected_text.lower() + + if actual_text == expected_text: + return 1.0 + return 0.0 + +import zipfile +from difflib import SequenceMatcher +import PyPDF2 + +def compare_pdf_content(content1, content2, text_similarity_threshold): + def extract_text_from_pdf(content): + with open("temp.pdf", "wb") as temp_pdf: + temp_pdf.write(content) + with open("temp.pdf", "rb") as temp_pdf: + pdf_reader = PyPDF2.PdfReader(temp_pdf) + text = '' + for page_num in range(len(pdf_reader.pages)): + page = pdf_reader.pages[page_num] + text += page.extract_text() + return text + + text1 = extract_text_from_pdf(content1) + text2 = extract_text_from_pdf(content2) + + similarity_ratio = SequenceMatcher(None, text1, text2).ratio() + + return similarity_ratio >= text_similarity_threshold + +def compare_zip_files(actual: str, expected: str, **options) -> float: + """ + Args: + actual (str): path to result zip file + expected (str): path to gold zip file + + Return: + float: the score + """ + if not actual: + return 0. + + with zipfile.ZipFile(actual, 'r') as zip_file1, zipfile.ZipFile(expected, 'r') as zip_file2: + file_list1 = set(zip_file1.namelist()) + file_list2 = set(zip_file2.namelist()) + + if file_list1 != file_list2: + return 0.0 + + for file_name in file_list1: + content1 = zip_file1.read(file_name) + content2 = zip_file2.read(file_name) + + if file_name.lower().endswith('.pdf'): + if compare_pdf_content(content1, content2, 0.95): + continue + else: + return 0.0 + elif content1 != content2: + return 0.0 + return 1.0 + + +def compare_config(actual: str, rules: Dict, **options) -> float: + if not actual: + return 0. + + with open(actual) as f1: + actual_text = f1.read() + + if actual_text == rules['expected']: + return 1.0 + return 0.0 + + +def compare_answer(actual: str, rules: Dict, **options) -> float: + """ + Args: + actual (str): result string + expected (str): gold string + + Return: + float: the score + """ + if not actual: + return 0. + + if actual == rules['expected']: + return 1.0 + + # TODO: can use text embedding to get non-zero return + return 0.0 + + +def is_extension_installed(actual: str, rules: Dict, **options): + if rules['type'] == 'contain': + if rules['expected'] in actual: + return 1.0 + return 0.0 + elif rules['type'] == 'not_contain': + if rules['expected'] not in actual: + return 1.0 + return 0.0 + else: + raise NotImplementedError + + +def check_python_file_by_test_suite(actual_files, test_file, **options) -> float: + """Check the python file by running the test suite in the given test file.""" + + test_function_name = options.get('test_function_name', 'test') + # Create a unique module name, it can be arbitrary but must be unique in the current runtime environment + module_name = 'dynamic_module' + + # Load the module from the given file path + spec = importlib.util.spec_from_file_location(module_name, test_file) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module # Add the loaded module to sys.modules + spec.loader.exec_module(module) # Execute the module to make its content available + + # Retrieve the function by name from the loaded module and execute it + test_function = getattr(module, test_function_name) + try: + if test_function(): + return 1.0 + else: + return 0.0 + except Exception as e: + return 0.0 + + +def check_python_file_by_gold_file(actual_files, gold_file: str, **options) -> float: + pass + + +def check_html_background_image(src_path: str, rule: Dict = None) -> float: + """ + Check if the background image is correctly set. + multi-app:bb7db4c2-30b5-4be7-8dd7-b8c4ec7d3108 + """ + if not src_path: + return 0.0 + + from bs4 import BeautifulSoup + with open(src_path, 'r') as f: + html_content = f.read() + soup = BeautifulSoup(html_content, 'html.parser') + styles = soup.find_all('style') + for style in styles: + if f'background-image: url(\'{rule["value"]}\')' in style.text: + return 1.0 + return 0.0 + + +def compare_result_files(src_path, tgt_path): + """ + Compare whether the content of two files are the same. + multi-app:7f35355e-02a6-45b5-b140-f0be698bcf85 + """ + if not src_path or not tgt_path: + return 0.0 + + with open(src_path, 'r') as f: + src_content = f.read().strip() + with open(tgt_path, 'r') as f: + tgt_content = f.read().strip() + try: + # Compare the content as numbers + tgt_content_num = float(tgt_content) + if tgt_content in src_content: + # If the content of tgt is in src, return 1.0 since output src might be + # a superset(language description+number) of tgt + return 1.0 + src_content_num = float(src_content) + if abs(src_content_num - tgt_content_num) < 1e-4: + return 1.0 + return 0.0 + except: + if src_content == tgt_content: + return 1.0 + return 0.0 diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/README.md b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/__init__.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/__init__.py new file mode 100644 index 000000000..eefa3e6e6 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/__init__.py @@ -0,0 +1,36 @@ +from .base import VMManager, Provider + + +def create_vm_manager_and_provider(provider_name: str, region: str, use_proxy: bool = False, env_id: int = 0): + """ + Factory function to get the Virtual Machine Manager and Provider instances based on the provided provider name. + + Args: + provider_name (str): The name of the provider (e.g., "aws", "vmware", etc.) + region (str): The region for the provider + use_proxy (bool): Whether to use proxy-enabled providers (currently only supported for AWS) + env_id (int): Environment ID for deterministic port allocation (used by Docker provider) + """ + provider_name = provider_name.lower().strip() + if provider_name == "vmware": + from .vmware.manager import VMwareVMManager + from .vmware.provider import VMwareProvider + return VMwareVMManager(), VMwareProvider(region) + elif provider_name == "virtualbox": + from .virtualbox.manager import VirtualBoxVMManager + from .virtualbox.provider import VirtualBoxProvider + return VirtualBoxVMManager(), VirtualBoxProvider(region) + elif provider_name in ["aws", "amazon web services"]: + from .aws.manager import AWSVMManager + from .aws.provider import AWSProvider + return AWSVMManager(), AWSProvider(region) + elif provider_name == "azure": + from .azure.manager import AzureVMManager + from .azure.provider import AzureProvider + return AzureVMManager(), AzureProvider(region) + elif provider_name == "docker": + from .docker.manager import DockerVMManager + from .docker.provider import DockerProvider + return DockerVMManager(), DockerProvider(region, env_id) + else: + raise NotImplementedError(f"{provider_name} not implemented!") diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/AWS_GUIDELINE.md b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/AWS_GUIDELINE.md new file mode 100644 index 000000000..76efa296b --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/AWS_GUIDELINE.md @@ -0,0 +1,52 @@ +# ☁ Configuration of AWS + +--- + +Welcome to the AWS VM Management documentation. Before you proceed with using the code to manage AWS services, please ensure the following variables are set correctly according to your AWS environment. + +## Configuration Variables +You need to assign values to several variables crucial for the operation of these scripts on AWS: + +- **`REGISTRY_PATH`**: Sets the file path for VM registration logging. + - Example: `'.aws_vms'` +- **`DEFAULT_REGION`**: Default AWS region where your instances will be launched. + - Example: `"us-east-1"` +- **`IMAGE_ID_MAP`**: Dictionary mapping regions to specific AMI IDs that should be used for instance creation. Here we already set the AMI id to the official OSWorld image of Ubuntu supported by us. + - Formatted as follows: + ```python + IMAGE_ID_MAP = { + "us-east-1": "ami-00674d875de9addc1" + # Add other regions and corresponding AMIs + } + ``` +- **`INSTANCE_TYPE`**: Specifies the type of EC2 instance to be launched. + - Example: `"t3.medium"` +- **`KEY_NAME`**: Specifies the name of the key pair to be used for the instances. + - Example: `"osworld_key"` +- **`NETWORK_INTERFACES`**: Configuration settings for network interfaces, which include subnet IDs, security group IDs, and public IP addressing. + - Example: + ```bash + + AWS_REGION=us-east-1 + AWS_SUBNET_ID=subnet-xxxx + AWS_SECURITY_GROUP_ID=sg-xxxx + ``` + + +### AWS CLI Configuration +Before using these scripts, you must configure your AWS CLI with your credentials. This can be done via the following commands: + +```bash +aws configure +``` +This command will prompt you for: +- AWS Access Key ID +- AWS Secret Access Key +- Default region name (Optional, you can press enter) + +Enter your credentials as required. This setup will allow you to interact with AWS services using the credentials provided. + +### Disclaimer +Use the provided scripts and configurations at your own risk. Ensure that you understand the AWS pricing model and potential costs associated with deploying instances, as using these scripts might result in charges on your AWS account. + +> **Note:** Ensure all AMI images used in `IMAGE_ID_MAP` are accessible and permissioned correctly for your AWS account, and that they are available in the specified region. diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/__init__.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/manager.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/manager.py new file mode 100644 index 000000000..15875d134 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/manager.py @@ -0,0 +1,271 @@ +import os +from filelock import FileLock +import boto3 +import psutil +import logging +import dotenv +import signal + + +INSTANCE_TYPE = "t3.xlarge" + +# Load environment variables from .env file +dotenv.load_dotenv() + +# Ensure the AWS region is set in the environment +if not os.getenv('AWS_REGION'): + raise EnvironmentError("AWS_REGION must be set in the environment variables.") + +# Ensure the AWS subnet and security group IDs are set in the environment +if not os.getenv('AWS_SUBNET_ID') or not os.getenv('AWS_SECURITY_GROUP_ID'): + raise EnvironmentError("AWS_SUBNET_ID and AWS_SECURITY_GROUP_ID must be set in the environment variables.") + +from desktop_env.providers.base import VMManager + +# Import proxy-related modules only when needed +try: + from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool + PROXY_SUPPORT_AVAILABLE = True +except ImportError: + PROXY_SUPPORT_AVAILABLE = False + +logger = logging.getLogger("desktopenv.providers.aws.AWSVMManager") +logger.setLevel(logging.INFO) + +DEFAULT_REGION = "us-east-1" +# todo: Add doc for the configuration of image, security group and network interface +# todo: public the AMI images +IMAGE_ID_MAP = { + "us-east-1": "ami-0cae20d2680c939d4", + "ap-east-1": "ami-0c092a5b8be4116f5", +} + + +def _allocate_vm(region=DEFAULT_REGION): + + if region not in IMAGE_ID_MAP: + raise ValueError(f"Region {region} is not supported. Supported regions are: {list(IMAGE_ID_MAP.keys())}") + + ec2_client = boto3.client('ec2', region_name=region) + instance_id = None + original_sigint_handler = signal.getsignal(signal.SIGINT) + original_sigterm_handler = signal.getsignal(signal.SIGTERM) + + def signal_handler(sig, frame): + if instance_id: + signal_name = "SIGINT" if sig == signal.SIGINT else "SIGTERM" + logger.warning(f"Received {signal_name} signal, terminating instance {instance_id}...") + try: + ec2_client.terminate_instances(InstanceIds=[instance_id]) + logger.info(f"Successfully terminated instance {instance_id} after {signal_name}.") + except Exception as cleanup_error: + logger.error(f"Failed to terminate instance {instance_id} after {signal_name}: {str(cleanup_error)}") + + # Restore original signal handlers + signal.signal(signal.SIGINT, original_sigint_handler) + signal.signal(signal.SIGTERM, original_sigterm_handler) + + # Raise appropriate exception based on signal type + if sig == signal.SIGINT: + raise KeyboardInterrupt + else: + # For SIGTERM, exit gracefully + import sys + sys.exit(0) + + try: + # Set up signal handlers for both SIGINT and SIGTERM + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + if not os.getenv('AWS_SECURITY_GROUP_ID'): + raise ValueError("AWS_SECURITY_GROUP_ID is not set in the environment variables.") + if not os.getenv('AWS_SUBNET_ID'): + raise ValueError("AWS_SUBNET_ID is not set in the environment variables.") + + run_instances_params = { + "MaxCount": 1, + "MinCount": 1, + "ImageId": IMAGE_ID_MAP[region], + "InstanceType": INSTANCE_TYPE, + "EbsOptimized": True, + "NetworkInterfaces": [ + { + "SubnetId": os.getenv('AWS_SUBNET_ID'), + "AssociatePublicIpAddress": True, + "DeviceIndex": 0, + "Groups": [ + os.getenv('AWS_SECURITY_GROUP_ID') + ] + } + ], + "BlockDeviceMappings": [ + { + "DeviceName": "/dev/sda1", + "Ebs": { + # "VolumeInitializationRate": 300 + "VolumeSize": 30, # Size in GB + "VolumeType": "gp3", # General Purpose SSD + "Throughput": 1000, + "Iops": 4000 # Adjust IOPS as needed + } + } + ] + } + + response = ec2_client.run_instances(**run_instances_params) + instance_id = response['Instances'][0]['InstanceId'] + + waiter = ec2_client.get_waiter('instance_running') + logger.info(f"Waiting for instance {instance_id} to be running...") + waiter.wait(InstanceIds=[instance_id]) + logger.info(f"Instance {instance_id} is ready.") + + # 获取并显示VNC访问地址 + try: + instance_details = ec2_client.describe_instances(InstanceIds=[instance_id]) + instance = instance_details['Reservations'][0]['Instances'][0] + public_ip = instance.get('PublicIpAddress', '') + if public_ip: + vnc_url = f"http://{public_ip}:5910/vnc.html" + logger.info("="*80) + logger.info(f"🖥️ VNC Web Access URL: {vnc_url}") + logger.info(f"📡 Public IP: {public_ip}") + logger.info(f"🆔 Instance ID: {instance_id}") + logger.info("="*80) + print(f"\n🌐 VNC访问地址: {vnc_url}") + print(f"📍 请在浏览器中打开上述地址进行远程桌面访问\n") + except Exception as e: + logger.warning(f"Failed to get VNC address for instance {instance_id}: {e}") + except KeyboardInterrupt: + logger.warning("VM allocation interrupted by user (SIGINT).") + if instance_id: + logger.info(f"Terminating instance {instance_id} due to interruption.") + ec2_client.terminate_instances(InstanceIds=[instance_id]) + raise + except Exception as e: + logger.error(f"Failed to allocate VM: {e}", exc_info=True) + if instance_id: + logger.info(f"Terminating instance {instance_id} due to an error.") + ec2_client.terminate_instances(InstanceIds=[instance_id]) + raise + finally: + # Restore original signal handlers + signal.signal(signal.SIGINT, original_sigint_handler) + signal.signal(signal.SIGTERM, original_sigterm_handler) + + return instance_id + + +def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None): + """Allocate a VM with proxy configuration""" + if not PROXY_SUPPORT_AVAILABLE: + logger.warning("Proxy support not available, falling back to regular VM allocation") + return _allocate_vm(region) + + from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy + + # Initialize proxy pool if needed + if proxy_config_file: + init_proxy_pool(proxy_config_file) + + # Get current proxy + proxy_pool = get_global_proxy_pool() + current_proxy = proxy_pool.get_next_proxy() + + if current_proxy: + logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}") + + # Create provider instance + provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file) + + # Create new instance + instance_id = provider.create_instance_with_proxy( + image_id=IMAGE_ID_MAP[region], + instance_type=INSTANCE_TYPE, + security_groups=[os.getenv('AWS_SECURITY_GROUP_ID')], + subnet_id=os.getenv('AWS_SUBNET_ID') + ) + + try: + ec2_client = boto3.client('ec2', region_name=region) + instance_details = ec2_client.describe_instances(InstanceIds=[instance_id]) + instance = instance_details['Reservations'][0]['Instances'][0] + public_ip = instance.get('PublicIpAddress', '') + if public_ip: + vnc_url = f"http://{public_ip}:5910/vnc.html" + logger.info("="*80) + logger.info(f"🖥️ VNC Web Access URL: {vnc_url}") + logger.info(f"📡 Public IP: {public_ip}") + logger.info(f"🆔 Instance ID: {instance_id}") + if current_proxy: + logger.info(f"🌐 Proxy: {current_proxy.host}:{current_proxy.port}") + logger.info("="*80) + print(f"\n🌐 VNC Web Access URL: {vnc_url}") + if current_proxy: + print(f"🔄 Current Proxy: {current_proxy.host}:{current_proxy.port}") + print(f"📍 Please open the above address in the browser for remote desktop access\n") + except Exception as e: + logger.warning(f"Failed to get VNC address for proxy instance {instance_id}: {e}") + + return instance_id + + +class AWSVMManager(VMManager): + """ + AWS VM Manager for managing virtual machines on AWS. + + AWS does not need to maintain a registry of VMs, as it can dynamically allocate and deallocate VMs. + This class supports both regular VM allocation and proxy-enabled VM allocation. + """ + def __init__(self, proxy_config_file=None, **kwargs): + self.proxy_config_file = proxy_config_file + # self.lock = FileLock(".aws_lck", timeout=60) + self.initialize_registry() + + # Initialize proxy pool if proxy configuration is provided + if proxy_config_file and PROXY_SUPPORT_AVAILABLE: + init_proxy_pool(proxy_config_file) + logger.info(f"Proxy pool initialized with config: {proxy_config_file}") + + def initialize_registry(self, **kwargs): + pass + + def add_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True, **kwargs): + pass + + def _add_vm(self, vm_path, region=DEFAULT_REGION): + pass + + def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True, **kwargs): + pass + + def _delete_vm(self, vm_path, region=DEFAULT_REGION): + pass + + def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True, **kwargs): + pass + + def _occupy_vm(self, vm_path, pid, region=DEFAULT_REGION): + pass + + def check_and_clean(self, lock_needed=True, **kwargs): + pass + + def _check_and_clean(self): + pass + + def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True, **kwargs): + pass + + def _list_free_vms(self, region=DEFAULT_REGION): + pass + + def get_vm_path(self, region=DEFAULT_REGION, **kwargs): + if self.proxy_config_file: + logger.info("Allocating a new VM with proxy configuration in region: {}".format(region)) + new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file) + else: + logger.info("Allocating a new VM in region: {}".format(region)) + new_vm_path = _allocate_vm(region) + return new_vm_path \ No newline at end of file diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/provider.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/provider.py new file mode 100644 index 000000000..d2c034e6d --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/provider.py @@ -0,0 +1,186 @@ +import boto3 +from botocore.exceptions import ClientError + +import logging + +from desktop_env.providers.base import Provider +from datetime import datetime +import time + + +logger = logging.getLogger("desktopenv.providers.aws.AWSProvider") +logger.setLevel(logging.INFO) + +WAIT_DELAY = 15 +MAX_ATTEMPTS = 10 + + +class AWSProvider(Provider): + + + def start_emulator(self, path_to_vm: str, headless: bool, *args, **kwargs): + logger.info("Starting AWS VM...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + # Check the current state of the instance + response = ec2_client.describe_instances(InstanceIds=[path_to_vm]) + state = response['Reservations'][0]['Instances'][0]['State']['Name'] + logger.info(f"Instance {path_to_vm} current state: {state}") + + if state == 'running': + # If the instance is already running, skip starting it + logger.info(f"Instance {path_to_vm} is already running. Skipping start.") + return + + if state == 'stopped': + # Start the instance if it's currently stopped + ec2_client.start_instances(InstanceIds=[path_to_vm]) + logger.info(f"Instance {path_to_vm} is starting...") + + # Wait until the instance reaches 'running' state + waiter = ec2_client.get_waiter('instance_running') + waiter.wait( + InstanceIds=[path_to_vm], + WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS} + ) + logger.info(f"Instance {path_to_vm} is now running.") + else: + # For all other states (terminated, pending, etc.), log a warning + logger.warning(f"Instance {path_to_vm} is in state '{state}' and cannot be started.") + + except ClientError as e: + logger.error(f"Failed to start the AWS VM {path_to_vm}: {str(e)}") + raise + + + def get_ip_address(self, path_to_vm: str) -> str: + logger.info("Getting AWS VM IP address...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + response = ec2_client.describe_instances(InstanceIds=[path_to_vm]) + for reservation in response['Reservations']: + for instance in reservation['Instances']: + private_ip_address = instance.get('PrivateIpAddress', '') + public_ip_address = instance.get('PublicIpAddress', '') + + if public_ip_address: + vnc_url = f"http://{public_ip_address}:5910/vnc.html" + logger.info("="*80) + logger.info(f"🖥️ VNC Web Access URL: {vnc_url}") + logger.info(f"📡 Public IP: {public_ip_address}") + logger.info(f"🏠 Private IP: {private_ip_address}") + logger.info("="*80) + print(f"\n🌐 VNC Web Access URL: {vnc_url}") + print(f"📍 Please open the above address in the browser for remote desktop access\n") + else: + logger.warning("No public IP address available for VNC access") + + return private_ip_address + return '' # Return an empty string if no IP address is found + except ClientError as e: + logger.error(f"Failed to retrieve IP address for the instance {path_to_vm}: {str(e)}") + raise + + def save_state(self, path_to_vm: str, snapshot_name: str): + logger.info("Saving AWS VM state...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + image_response = ec2_client.create_image(InstanceId=path_to_vm, Name=snapshot_name) + image_id = image_response['ImageId'] + logger.info(f"AMI {image_id} created successfully from instance {path_to_vm}.") + return image_id + except ClientError as e: + logger.error(f"Failed to create AMI from the instance {path_to_vm}: {str(e)}") + raise + + def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str): + logger.info(f"Reverting AWS VM to snapshot AMI: {snapshot_name}...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + # Step 1: Retrieve the original instance details + instance_details = ec2_client.describe_instances(InstanceIds=[path_to_vm]) + instance = instance_details['Reservations'][0]['Instances'][0] + security_groups = [sg['GroupId'] for sg in instance['SecurityGroups']] + subnet_id = instance['SubnetId'] + instance_type = instance['InstanceType'] + + # Step 2: Terminate the old instance + ec2_client.terminate_instances(InstanceIds=[path_to_vm]) + logger.info(f"Old instance {path_to_vm} has been terminated.") + + # Step 3: Launch a new instance from the snapshot(AMI) with performance optimization + logger.info(f"Launching a new instance from AMI {snapshot_name}...") + + run_instances_params = { + "MaxCount": 1, + "MinCount": 1, + "ImageId": snapshot_name, + "InstanceType": instance_type, + "EbsOptimized": True, + "NetworkInterfaces": [ + { + "SubnetId": subnet_id, + "AssociatePublicIpAddress": True, + "DeviceIndex": 0, + "Groups": security_groups + } + ], + "BlockDeviceMappings": [ + { + "DeviceName": "/dev/sda1", + "Ebs": { + # "VolumeInitializationRate": 300 + "VolumeSize": 30, # Size in GB + "VolumeType": "gp3", # General Purpose SSD + "Throughput": 1000, + "Iops": 4000 # Adjust IOPS as needed + } + } + ] + } + + new_instance = ec2_client.run_instances(**run_instances_params) + new_instance_id = new_instance['Instances'][0]['InstanceId'] + logger.info(f"New instance {new_instance_id} launched from AMI {snapshot_name}.") + logger.info(f"Waiting for instance {new_instance_id} to be running...") + ec2_client.get_waiter('instance_running').wait(InstanceIds=[new_instance_id]) + + logger.info(f"Instance {new_instance_id} is ready.") + + try: + instance_details = ec2_client.describe_instances(InstanceIds=[new_instance_id]) + instance = instance_details['Reservations'][0]['Instances'][0] + public_ip = instance.get('PublicIpAddress', '') + if public_ip: + vnc_url = f"http://{public_ip}:5910/vnc.html" + logger.info("="*80) + logger.info(f"🖥️ New Instance VNC Web Access URL: {vnc_url}") + logger.info(f"📡 Public IP: {public_ip}") + logger.info(f"🆔 New Instance ID: {new_instance_id}") + logger.info("="*80) + print(f"\n🌐 New Instance VNC Web Access URL: {vnc_url}") + print(f"📍 Please open the above address in the browser for remote desktop access\n") + except Exception as e: + logger.warning(f"Failed to get VNC address for new instance {new_instance_id}: {e}") + + return new_instance_id + + except ClientError as e: + logger.error(f"Failed to revert to snapshot {snapshot_name} for the instance {path_to_vm}: {str(e)}") + raise + + + def stop_emulator(self, path_to_vm, region=None): + logger.info(f"Stopping AWS VM {path_to_vm}...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + ec2_client.terminate_instances(InstanceIds=[path_to_vm]) + logger.info(f"Instance {path_to_vm} has been terminated.") + except ClientError as e: + logger.error(f"Failed to stop the AWS VM {path_to_vm}: {str(e)}") + raise diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/provider_with_proxy.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/provider_with_proxy.py new file mode 100644 index 000000000..2ffb7c050 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/provider_with_proxy.py @@ -0,0 +1,315 @@ +import boto3 +from botocore.exceptions import ClientError +import base64 +import logging +import json +from typing import Optional + +from desktop_env.providers.base import Provider +from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool, ProxyInfo + +logger = logging.getLogger("desktopenv.providers.aws.AWSProviderWithProxy") +logger.setLevel(logging.INFO) + +WAIT_DELAY = 15 +MAX_ATTEMPTS = 10 + + +class AWSProviderWithProxy(Provider): + + def __init__(self, region: str = None, proxy_config_file: str = None): + super().__init__(region) + self.current_proxy: Optional[ProxyInfo] = None + + # 初始化代理池 + if proxy_config_file: + init_proxy_pool(proxy_config_file) + logger.info(f"Initialized proxy pool from {proxy_config_file}") + + # 获取下一个可用代理 + self._rotate_proxy() + + def _rotate_proxy(self): + """轮换到下一个可用代理""" + proxy_pool = get_global_proxy_pool() + self.current_proxy = proxy_pool.get_next_proxy() + + if self.current_proxy: + logger.info(f"Switched to proxy: {self.current_proxy.host}:{self.current_proxy.port}") + else: + logger.warning("No proxy available, using direct connection") + + def _generate_proxy_user_data(self) -> str: + """生成包含代理配置的user data脚本""" + if not self.current_proxy: + return "" + + proxy_url = self._format_proxy_url(self.current_proxy) + + user_data_script = f"""#!/bin/bash +# Configure system proxy +echo 'export http_proxy={proxy_url}' >> /etc/environment +echo 'export https_proxy={proxy_url}' >> /etc/environment +echo 'export HTTP_PROXY={proxy_url}' >> /etc/environment +echo 'export HTTPS_PROXY={proxy_url}' >> /etc/environment + +# Configure apt proxy +cat > /etc/apt/apt.conf.d/95proxy << EOF +Acquire::http::Proxy "{proxy_url}"; +Acquire::https::Proxy "{proxy_url}"; +EOF + +# Configure chrome/chromium proxy +mkdir -p /etc/opt/chrome/policies/managed +cat > /etc/opt/chrome/policies/managed/proxy.json << EOF +{{ + "ProxyMode": "fixed_servers", + "ProxyServer": "{self.current_proxy.host}:{self.current_proxy.port}" +}} +EOF + +# Configure chromium proxy (Ubuntu default) +mkdir -p /etc/chromium/policies/managed +cat > /etc/chromium/policies/managed/proxy.json << EOF +{{ + "ProxyMode": "fixed_servers", + "ProxyServer": "{self.current_proxy.host}:{self.current_proxy.port}" +}} +EOF + +# Configure firefox proxy - support multiple possible paths +for firefox_dir in /etc/firefox/policies /usr/lib/firefox/distribution/policies /etc/firefox-esr/policies; do + if [ -d "$(dirname "$firefox_dir")" ]; then + mkdir -p "$firefox_dir" + cat > "$firefox_dir/policies.json" << EOF +{{ + "policies": {{ + "Proxy": {{ + "Mode": "manual", + "HTTPProxy": "{self.current_proxy.host}:{self.current_proxy.port}", + "HTTPSProxy": "{self.current_proxy.host}:{self.current_proxy.port}", + "UseHTTPProxyForAllProtocols": true + }} + }} +}} +EOF + break + fi +done + +# Reload environment variables +source /etc/environment + +# Log proxy configuration +echo "$(date): Configured proxy {self.current_proxy.host}:{self.current_proxy.port}" >> /var/log/proxy-setup.log +""" + + return base64.b64encode(user_data_script.encode()).decode() + + def _format_proxy_url(self, proxy: ProxyInfo) -> str: + """格式化代理URL""" + if proxy.username and proxy.password: + return f"{proxy.protocol}://{proxy.username}:{proxy.password}@{proxy.host}:{proxy.port}" + else: + return f"{proxy.protocol}://{proxy.host}:{proxy.port}" + + def start_emulator(self, path_to_vm: str, headless: bool, *args, **kwargs): + logger.info("Starting AWS VM with proxy configuration...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + # 如果实例已经存在,直接启动 + ec2_client.start_instances(InstanceIds=[path_to_vm]) + logger.info(f"Instance {path_to_vm} is starting...") + + # Wait for the instance to be in the 'running' state + waiter = ec2_client.get_waiter('instance_running') + waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS}) + logger.info(f"Instance {path_to_vm} is now running.") + + except ClientError as e: + logger.error(f"Failed to start the AWS VM {path_to_vm}: {str(e)}") + raise + + def create_instance_with_proxy(self, image_id: str, instance_type: str, + security_groups: list, subnet_id: str) -> str: + """创建带有代理配置的新实例""" + ec2_client = boto3.client('ec2', region_name=self.region) + + user_data = self._generate_proxy_user_data() + + run_instances_params = { + "MaxCount": 1, + "MinCount": 1, + "ImageId": image_id, + "InstanceType": instance_type, + "EbsOptimized": True, + "NetworkInterfaces": [ + { + "SubnetId": subnet_id, + "AssociatePublicIpAddress": True, + "DeviceIndex": 0, + "Groups": security_groups + } + ] + } + + if user_data: + run_instances_params["UserData"] = user_data + + try: + response = ec2_client.run_instances(**run_instances_params) + instance_id = response['Instances'][0]['InstanceId'] + + logger.info(f"Created new instance {instance_id} with proxy configuration") + + logger.info(f"Waiting for instance {instance_id} to be running...") + ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id]) + logger.info(f"Instance {instance_id} is ready.") + + try: + instance_details = ec2_client.describe_instances(InstanceIds=[instance_id]) + instance = instance_details['Reservations'][0]['Instances'][0] + public_ip = instance.get('PublicIpAddress', '') + if public_ip: + vnc_url = f"http://{public_ip}:5910/vnc.html" + logger.info("="*80) + logger.info(f"🖥️ VNC Web Access URL: {vnc_url}") + logger.info(f"📡 Public IP: {public_ip}") + logger.info(f"🆔 Instance ID: {instance_id}") + if self.current_proxy: + logger.info(f"🌐 Proxy: {self.current_proxy.host}:{self.current_proxy.port}") + logger.info("="*80) + print(f"\n🌐 VNC Web Access URL: {vnc_url}") + if self.current_proxy: + print(f"🔄 Current Proxy: {self.current_proxy.host}:{self.current_proxy.port}") + print(f"📍 Please open the above address in the browser for remote desktop access\n") + except Exception as e: + logger.warning(f"Failed to get VNC address for instance {instance_id}: {e}") + + return instance_id + + except ClientError as e: + logger.error(f"Failed to create instance with proxy: {str(e)}") + if self.current_proxy: + proxy_pool = get_global_proxy_pool() + proxy_pool.mark_proxy_failed(self.current_proxy) + self._rotate_proxy() + raise + + def get_ip_address(self, path_to_vm: str) -> str: + logger.info("Getting AWS VM IP address...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + response = ec2_client.describe_instances(InstanceIds=[path_to_vm]) + for reservation in response['Reservations']: + for instance in reservation['Instances']: + private_ip_address = instance.get('PrivateIpAddress', '') + public_ip_address = instance.get('PublicIpAddress', '') + + if public_ip_address: + vnc_url = f"http://{public_ip_address}:5910/vnc.html" + logger.info("="*80) + logger.info(f"🖥️ VNC Web Access URL: {vnc_url}") + logger.info(f"📡 Public IP: {public_ip_address}") + logger.info(f"🏠 Private IP: {private_ip_address}") + if self.current_proxy: + logger.info(f"🌐 Proxy: {self.current_proxy.host}:{self.current_proxy.port}") + logger.info("="*80) + print(f"\n🌐 VNC Web Access URL: {vnc_url}") + if self.current_proxy: + print(f"🔄 Current Proxy: {self.current_proxy.host}:{self.current_proxy.port}") + print(f"📍 Please open the above address in the browser for remote desktop access\n") + else: + logger.warning("No public IP address available for VNC access") + + return private_ip_address + return '' + except ClientError as e: + logger.error(f"Failed to retrieve IP address for the instance {path_to_vm}: {str(e)}") + raise + + def save_state(self, path_to_vm: str, snapshot_name: str): + logger.info("Saving AWS VM state...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + image_response = ec2_client.create_image(InstanceId=path_to_vm, Name=snapshot_name) + image_id = image_response['ImageId'] + logger.info(f"AMI {image_id} created successfully from instance {path_to_vm}.") + return image_id + except ClientError as e: + logger.error(f"Failed to create AMI from the instance {path_to_vm}: {str(e)}") + raise + + def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str): + logger.info(f"Reverting AWS VM to snapshot: {snapshot_name}...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + # Get original instance details for config. + instance_details = ec2_client.describe_instances(InstanceIds=[path_to_vm]) + instance = instance_details['Reservations'][0]['Instances'][0] + security_groups = [sg['GroupId'] for sg in instance['SecurityGroups']] + subnet_id = instance['SubnetId'] + instance_type = instance['InstanceType'] + + # Terminate the old instance. This is a non-blocking call. + logger.info(f"Initiating termination for old instance {path_to_vm}...") + ec2_client.terminate_instances(InstanceIds=[path_to_vm]) + logger.info(f"Old instance {path_to_vm} termination initiated.") + + # Rotate to a new proxy + self._rotate_proxy() + + # Create a new instance + new_instance_id = self.create_instance_with_proxy( + snapshot_name, instance_type, security_groups, subnet_id + ) + + # Note: VNC address is displayed within create_instance_with_proxy + logger.info(f"Successfully launched new instance {new_instance_id} for revert.") + + return new_instance_id + + except ClientError as e: + logger.error(f"Failed to revert to snapshot {snapshot_name} for the instance {path_to_vm}: {str(e)}") + raise + + def stop_emulator(self, path_to_vm, region=None): + logger.info(f"Stopping AWS VM {path_to_vm}...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + ec2_client.stop_instances(InstanceIds=[path_to_vm]) + waiter = ec2_client.get_waiter('instance_stopped') + waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS}) + logger.info(f"Instance {path_to_vm} has been stopped.") + except ClientError as e: + logger.error(f"Failed to stop the AWS VM {path_to_vm}: {str(e)}") + raise + + def get_current_proxy_info(self) -> Optional[dict]: + """获取当前代理信息""" + if self.current_proxy: + return { + 'host': self.current_proxy.host, + 'port': self.current_proxy.port, + 'protocol': self.current_proxy.protocol, + 'failed_count': self.current_proxy.failed_count + } + return None + + def force_rotate_proxy(self): + """强制轮换代理""" + logger.info("Force rotating proxy...") + if self.current_proxy: + proxy_pool = get_global_proxy_pool() + proxy_pool.mark_proxy_failed(self.current_proxy) + self._rotate_proxy() + + def get_proxy_stats(self) -> dict: + """获取代理池统计信息""" + proxy_pool = get_global_proxy_pool() + return proxy_pool.get_stats() \ No newline at end of file diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/proxy_pool.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/proxy_pool.py new file mode 100644 index 000000000..812df1854 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/aws/proxy_pool.py @@ -0,0 +1,193 @@ +import random +import requests +import logging +import time +from typing import List, Dict, Optional +from dataclasses import dataclass +from threading import Lock +import json + +logger = logging.getLogger("desktopenv.providers.aws.ProxyPool") +logger.setLevel(logging.INFO) + +@dataclass +class ProxyInfo: + host: str + port: int + username: Optional[str] = None + password: Optional[str] = None + protocol: str = "http" # http, https, socks5 + failed_count: int = 0 + last_used: float = 0 + is_active: bool = True + +class ProxyPool: + def __init__(self, config_file: str = None): + self.proxies: List[ProxyInfo] = [] + self.current_index = 0 + self.lock = Lock() + self.max_failures = 3 # 最大失败次数 + self.cooldown_time = 300 # 5分钟冷却时间 + + if config_file: + self.load_proxies_from_file(config_file) + + def load_proxies_from_file(self, config_file: str): + """从配置文件加载代理列表""" + try: + with open(config_file, 'r') as f: + proxy_configs = json.load(f) + + for config in proxy_configs: + proxy = ProxyInfo( + host=config['host'], + port=config['port'], + username=config.get('username'), + password=config.get('password'), + protocol=config.get('protocol', 'http') + ) + self.proxies.append(proxy) + + logger.info(f"Loaded {len(self.proxies)} proxies from {config_file}") + except Exception as e: + logger.error(f"Failed to load proxies from {config_file}: {e}") + + def add_proxy(self, host: str, port: int, username: str = None, + password: str = None, protocol: str = "http"): + """添加代理到池中""" + proxy = ProxyInfo(host=host, port=port, username=username, + password=password, protocol=protocol) + with self.lock: + self.proxies.append(proxy) + logger.info(f"Added proxy {host}:{port}") + + def get_next_proxy(self) -> Optional[ProxyInfo]: + """获取下一个可用的代理""" + with self.lock: + if not self.proxies: + return None + + # 过滤掉失败次数过多的代理 + active_proxies = [p for p in self.proxies if self._is_proxy_available(p)] + + if not active_proxies: + logger.warning("No active proxies available") + return None + + # 轮询选择代理 + proxy = active_proxies[self.current_index % len(active_proxies)] + self.current_index += 1 + proxy.last_used = time.time() + + return proxy + + def _is_proxy_available(self, proxy: ProxyInfo) -> bool: + """检查代理是否可用""" + if not proxy.is_active: + return False + + if proxy.failed_count >= self.max_failures: + # 检查是否过了冷却时间 + if time.time() - proxy.last_used < self.cooldown_time: + return False + else: + # 重置失败计数 + proxy.failed_count = 0 + + return True + + def mark_proxy_failed(self, proxy: ProxyInfo): + """标记代理失败""" + with self.lock: + proxy.failed_count += 1 + if proxy.failed_count >= self.max_failures: + logger.warning(f"Proxy {proxy.host}:{proxy.port} marked as failed " + f"(failures: {proxy.failed_count})") + + def mark_proxy_success(self, proxy: ProxyInfo): + """标记代理成功""" + with self.lock: + proxy.failed_count = 0 + + def test_proxy(self, proxy: ProxyInfo, test_url: str = "http://httpbin.org/ip", + timeout: int = 10) -> bool: + """测试代理是否正常工作""" + try: + proxy_url = self._format_proxy_url(proxy) + proxies = { + 'http': proxy_url, + 'https': proxy_url + } + + response = requests.get(test_url, proxies=proxies, timeout=timeout) + if response.status_code == 200: + self.mark_proxy_success(proxy) + return True + else: + self.mark_proxy_failed(proxy) + return False + + except Exception as e: + logger.debug(f"Proxy test failed for {proxy.host}:{proxy.port}: {e}") + self.mark_proxy_failed(proxy) + return False + + def _format_proxy_url(self, proxy: ProxyInfo) -> str: + """格式化代理URL""" + if proxy.username and proxy.password: + return f"{proxy.protocol}://{proxy.username}:{proxy.password}@{proxy.host}:{proxy.port}" + else: + return f"{proxy.protocol}://{proxy.host}:{proxy.port}" + + def get_proxy_dict(self, proxy: ProxyInfo) -> Dict[str, str]: + """获取requests库使用的代理字典""" + proxy_url = self._format_proxy_url(proxy) + return { + 'http': proxy_url, + 'https': proxy_url + } + + def test_all_proxies(self, test_url: str = "http://httpbin.org/ip"): + """测试所有代理""" + logger.info("Testing all proxies...") + working_count = 0 + + for proxy in self.proxies: + if self.test_proxy(proxy, test_url): + working_count += 1 + logger.info(f"✓ Proxy {proxy.host}:{proxy.port} is working") + else: + logger.warning(f"✗ Proxy {proxy.host}:{proxy.port} failed") + + logger.info(f"Proxy test completed: {working_count}/{len(self.proxies)} working") + return working_count + + def get_stats(self) -> Dict: + """获取代理池统计信息""" + with self.lock: + total = len(self.proxies) + active = len([p for p in self.proxies if self._is_proxy_available(p)]) + failed = len([p for p in self.proxies if p.failed_count >= self.max_failures]) + + return { + 'total': total, + 'active': active, + 'failed': failed, + 'success_rate': active / total if total > 0 else 0 + } + +# 全局代理池实例 +_proxy_pool = None + +def get_global_proxy_pool() -> ProxyPool: + """获取全局代理池实例""" + global _proxy_pool + if _proxy_pool is None: + _proxy_pool = ProxyPool() + return _proxy_pool + +def init_proxy_pool(config_file: str = None): + """初始化全局代理池""" + global _proxy_pool + _proxy_pool = ProxyPool(config_file) + return _proxy_pool \ No newline at end of file diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/azure/__init__.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/azure/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/azure/manager.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/azure/manager.py new file mode 100644 index 000000000..60765114a --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/azure/manager.py @@ -0,0 +1,85 @@ +import os +import threading +import boto3 +import psutil + +import logging + +from desktop_env.providers.base import VMManager + +logger = logging.getLogger("desktopenv.providers.azure.AzureVMManager") +logger.setLevel(logging.INFO) + +REGISTRY_PATH = '.azure_vms' + + +def _allocate_vm(region): + raise NotImplementedError + + +class AzureVMManager(VMManager): + def __init__(self, registry_path=REGISTRY_PATH): + self.registry_path = registry_path + self.lock = threading.Lock() + self.initialize_registry() + + def initialize_registry(self): + with self.lock: # Locking during initialization + if not os.path.exists(self.registry_path): + with open(self.registry_path, 'w') as file: + file.write('') + + def add_vm(self, vm_path, region): + with self.lock: + with open(self.registry_path, 'r') as file: + lines = file.readlines() + vm_path_at_vm_region = "{}@{}".format(vm_path, region) + new_lines = lines + [f'{vm_path_at_vm_region}|free\n'] + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def occupy_vm(self, vm_path, pid, region): + with self.lock: + new_lines = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + registered_vm_path, _ = line.strip().split('|') + if registered_vm_path == "{}@{}".format(vm_path, region): + new_lines.append(f'{registered_vm_path}|{pid}\n') + else: + new_lines.append(line) + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def check_and_clean(self): + raise NotImplementedError + + def list_free_vms(self, region): + with self.lock: # Lock when reading the registry + free_vms = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + vm_path_at_vm_region, pid_str = line.strip().split('|') + vm_path, vm_region = vm_path_at_vm_region.split("@") + if pid_str == "free" and vm_region == region: + free_vms.append((vm_path, pid_str)) + return free_vms + + def get_vm_path(self, region): + self.check_and_clean() + free_vms_paths = self.list_free_vms(region) + if len(free_vms_paths) == 0: + # No free virtual machine available, generate a new one + logger.info("No free virtual machine available. Generating a new one, which would take a while...☕") + new_vm_path = _allocate_vm(region) + self.add_vm(new_vm_path, region) + self.occupy_vm(new_vm_path, os.getpid(), region) + return new_vm_path + else: + # Choose the first free virtual machine + chosen_vm_path = free_vms_paths[0][0] + self.occupy_vm(chosen_vm_path, os.getpid(), region) + return chosen_vm_path + diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/azure/provider.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/azure/provider.py new file mode 100644 index 000000000..fb435039b --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/azure/provider.py @@ -0,0 +1,205 @@ +import os +import time +from azure.identity import DefaultAzureCredential +from azure.mgmt.compute import ComputeManagementClient +from azure.mgmt.network import NetworkManagementClient +from azure.core.exceptions import ResourceNotFoundError + +import logging + +from desktop_env.providers.base import Provider + +logger = logging.getLogger("desktopenv.providers.azure.AzureProvider") +logger.setLevel(logging.INFO) + +WAIT_DELAY = 15 +MAX_ATTEMPTS = 10 + +# To use the Azure provider, download azure-cli by https://learn.microsoft.com/en-us/cli/azure/install-azure-cli, +# use "az login" to log into you Azure account, +# and set environment variable "AZURE_SUBSCRIPTION_ID" to your subscription ID. +# Provide your resource group name and VM name in the format "RESOURCE_GROUP_NAME/VM_NAME" and pass as an argument for "-p". + +class AzureProvider(Provider): + def __init__(self, region: str = None): + super().__init__(region) + credential = DefaultAzureCredential() + try: + self.subscription_id = os.environ["AZURE_SUBSCRIPTION_ID"] + except: + logger.error("Azure subscription ID not found. Please set environment variable \"AZURE_SUBSCRIPTION_ID\".") + raise + self.compute_client = ComputeManagementClient(credential, self.subscription_id) + self.network_client = NetworkManagementClient(credential, self.subscription_id) + + def start_emulator(self, path_to_vm: str, headless: bool): + logger.info("Starting Azure VM...") + resource_group_name, vm_name = path_to_vm.split('/') + + vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name, expand='instanceView') + power_state = vm.instance_view.statuses[-1].code + if power_state == "PowerState/running": + logger.info("VM is already running.") + return + + try: + # Start the instance + for _ in range(MAX_ATTEMPTS): + async_vm_start = self.compute_client.virtual_machines.begin_start(resource_group_name, vm_name) + logger.info(f"VM {path_to_vm} is starting...") + # Wait for the instance to start + async_vm_start.wait(timeout=WAIT_DELAY) + vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name, expand='instanceView') + power_state = vm.instance_view.statuses[-1].code + if power_state == "PowerState/running": + logger.info(f"VM {path_to_vm} is already running.") + break + except Exception as e: + logger.error(f"Failed to start the Azure VM {path_to_vm}: {str(e)}") + raise + + def get_ip_address(self, path_to_vm: str) -> str: + logger.info("Getting Azure VM IP address...") + resource_group_name, vm_name = path_to_vm.split('/') + + vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name) + + for interface in vm.network_profile.network_interfaces: + name=" ".join(interface.id.split('/')[-1:]) + sub="".join(interface.id.split('/')[4]) + + try: + thing=self.network_client.network_interfaces.get(sub, name).ip_configurations + + network_card_id = thing[0].public_ip_address.id.split('/')[-1] + public_ip_address = self.network_client.public_ip_addresses.get(resource_group_name, network_card_id) + logger.info(f"VM IP address is {public_ip_address.ip_address}") + return public_ip_address.ip_address + + except Exception as e: + logger.error(f"Cannot get public IP for VM {path_to_vm}") + raise + + def save_state(self, path_to_vm: str, snapshot_name: str): + print("Saving Azure VM state...") + resource_group_name, vm_name = path_to_vm.split('/') + + vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name) + + try: + # Backup each disk attached to the VM + for disk in vm.storage_profile.data_disks + [vm.storage_profile.os_disk]: + # Create a snapshot of the disk + snapshot = { + 'location': vm.location, + 'creation_data': { + 'create_option': 'Copy', + 'source_uri': disk.managed_disk.id + } + } + async_snapshot_creation = self.compute_client.snapshots.begin_create_or_update(resource_group_name, snapshot_name, snapshot) + async_snapshot_creation.wait(timeout=WAIT_DELAY) + + logger.info(f"Successfully created snapshot {snapshot_name} for VM {path_to_vm}.") + except Exception as e: + logger.error(f"Failed to create snapshot {snapshot_name} of the Azure VM {path_to_vm}: {str(e)}") + raise + + def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str): + logger.info(f"Reverting VM to snapshot: {snapshot_name}...") + resource_group_name, vm_name = path_to_vm.split('/') + + vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name) + + # Stop the VM for disk creation + logger.info(f"Stopping VM: {vm_name}") + async_vm_stop = self.compute_client.virtual_machines.begin_deallocate(resource_group_name, vm_name) + async_vm_stop.wait(timeout=WAIT_DELAY) # Wait for the VM to stop + + try: + # Get the snapshot + snapshot = self.compute_client.snapshots.get(resource_group_name, snapshot_name) + + # Get the original disk information + original_disk_id = vm.storage_profile.os_disk.managed_disk.id + disk_name = original_disk_id.split('/')[-1] + if disk_name[-1] in ['0', '1']: + new_disk_name = disk_name[:-1] + str(int(disk_name[-1])^1) + else: + new_disk_name = disk_name + "0" + + # Delete the disk if it exists + self.compute_client.disks.begin_delete(resource_group_name, new_disk_name).wait(timeout=WAIT_DELAY) + + # Make sure the disk is deleted before proceeding to the next step + disk_deleted = False + polling_interval = 10 + attempts = 0 + while not disk_deleted and attempts < MAX_ATTEMPTS: + try: + self.compute_client.disks.get(resource_group_name, new_disk_name) + # If the above line does not raise an exception, the disk still exists + time.sleep(polling_interval) + attempts += 1 + except ResourceNotFoundError: + disk_deleted = True + + if not disk_deleted: + logger.error(f"Disk {new_disk_name} deletion timed out.") + raise + + # Create a new managed disk from the snapshot + snapshot = self.compute_client.snapshots.get(resource_group_name, snapshot_name) + disk_creation = { + 'location': snapshot.location, + 'creation_data': { + 'create_option': 'Copy', + 'source_resource_id': snapshot.id + }, + 'zones': vm.zones if vm.zones else None # Preserve the original disk's zone + } + async_disk_creation = self.compute_client.disks.begin_create_or_update(resource_group_name, new_disk_name, disk_creation) + restored_disk = async_disk_creation.result() # Wait for the disk creation to complete + + vm.storage_profile.os_disk = { + 'create_option': vm.storage_profile.os_disk.create_option, + 'managed_disk': { + 'id': restored_disk.id + } + } + + async_vm_creation = self.compute_client.virtual_machines.begin_create_or_update(resource_group_name, vm_name, vm) + async_vm_creation.wait(timeout=WAIT_DELAY) + + # Delete the original disk + self.compute_client.disks.begin_delete(resource_group_name, disk_name).wait() + + logger.info(f"Successfully reverted to snapshot {snapshot_name}.") + except Exception as e: + logger.error(f"Failed to revert the Azure VM {path_to_vm} to snapshot {snapshot_name}: {str(e)}") + raise + + def stop_emulator(self, path_to_vm, region=None): + logger.info(f"Stopping Azure VM {path_to_vm}...") + resource_group_name, vm_name = path_to_vm.split('/') + + vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name, expand='instanceView') + power_state = vm.instance_view.statuses[-1].code + if power_state == "PowerState/deallocated": + print("VM is already stopped.") + return + + try: + for _ in range(MAX_ATTEMPTS): + async_vm_deallocate = self.compute_client.virtual_machines.begin_deallocate(resource_group_name, vm_name) + logger.info(f"Stopping VM {path_to_vm}...") + # Wait for the instance to start + async_vm_deallocate.wait(timeout=WAIT_DELAY) + vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name, expand='instanceView') + power_state = vm.instance_view.statuses[-1].code + if power_state == "PowerState/deallocated": + logger.info(f"VM {path_to_vm} is already stopped.") + break + except Exception as e: + logger.error(f"Failed to stop the Azure VM {path_to_vm}: {str(e)}") + raise diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/base.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/base.py new file mode 100644 index 000000000..f4867e7e5 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/base.py @@ -0,0 +1,97 @@ +from abc import ABC, abstractmethod + + +class Provider(ABC): + def __init__(self, region: str = None): + """ + Region of the cloud service. + """ + self.region = region + + @abstractmethod + def start_emulator(self, path_to_vm: str, headless: bool): + """ + Method to start the emulator. + """ + pass + + @abstractmethod + def get_ip_address(self, path_to_vm: str) -> str: + """ + Method to get the private IP address of the VM. Private IP means inside the VPC. + """ + pass + + @abstractmethod + def save_state(self, path_to_vm: str, snapshot_name: str): + """ + Method to save the state of the VM. + """ + pass + + @abstractmethod + def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str) -> str: + """ + Method to revert the VM to a given snapshot. + """ + pass + + @abstractmethod + def stop_emulator(self, path_to_vm: str): + """ + Method to stop the emulator. + """ + pass + + +class VMManager(ABC): + checked_and_cleaned = False + + @abstractmethod + def initialize_registry(self, **kwargs): + """ + Initialize registry. + """ + pass + + @abstractmethod + def add_vm(self, vm_path, **kwargs): + """ + Add the path of new VM to the registration. + """ + pass + + @abstractmethod + def delete_vm(self, vm_path, **kwargs): + """ + Delete the registration of VM by path. + """ + pass + + @abstractmethod + def occupy_vm(self, vm_path, pid, **kwargs): + """ + Mark the path of VM occupied by the pid. + """ + pass + + @abstractmethod + def list_free_vms(self, **kwargs): + """ + List the paths of VM that are free to use allocated. + """ + pass + + @abstractmethod + def check_and_clean(self, **kwargs): + """ + Check the registration list, and remove the paths of VM that are not in use. + """ + pass + + @abstractmethod + def get_vm_path(self, **kwargs): + """ + Get a virtual machine that is not occupied, generate a new one if no free VM. + """ + pass diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/docker/DOCKER_GUIDELINE.md b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/docker/DOCKER_GUIDELINE.md new file mode 100644 index 000000000..b13752ba0 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/docker/DOCKER_GUIDELINE.md @@ -0,0 +1,29 @@ +# Configuration of Docker + +--- + +Welcome to the Docker VM Management documentation. + +## Prerequisite: Check if your machine supports KVM + +We recommend running the VM with KVM support. To check if your hosting platform supports KVM, run + +``` +egrep -c '(vmx|svm)' /proc/cpuinfo +``` + +on Linux. If the return value is greater than zero, the processor should be able to support KVM. + +> **Note**: macOS hosts generally do not support KVM. You are advised to use VMware if you would like to run OSWorld on macOS. + +## Install Docker + +If your hosting platform supports graphical user interface (GUI), you may refer to [Install Docker Desktop on Linux](https://docs.docker.com/desktop/install/linux/) or [Install Docker Desktop on Windows](https://docs.docker.com/desktop/install/windows-install/) based on your OS. Otherwise, you may [Install Docker Engine](https://docs.docker.com/engine/install/). + +## Running Experiments + +Add the following arguments when initializing `DesktopEnv`: +- `provider_name`: `docker` +- `os_type`: `Ubuntu` or `Windows`, depending on the OS of the VM + +Please allow for some time to download the virtual machine snapshot on your first run. diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/docker/manager.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/docker/manager.py new file mode 100644 index 000000000..6c1a73033 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/docker/manager.py @@ -0,0 +1,123 @@ +import os +import platform +import zipfile + +from time import sleep +import requests +from tqdm import tqdm + +import logging + +from skyrl_agent.tasks.osworld.desktop_env.providers.base import VMManager + +logger = logging.getLogger("desktopenv.providers.docker.DockerVMManager") +logger.setLevel(logging.INFO) + +MAX_RETRY_TIMES = 10 +RETRY_INTERVAL = 5 + +UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu.qcow2.zip" +WINDOWS_X86_URL = "https://huggingface.co/datasets/xlangai/windows_osworld/resolve/main/Windows-10-x64.qcow2.zip" +VMS_DIR = "./docker_vm_data" + +URL = UBUNTU_X86_URL +DOWNLOADED_FILE_NAME = URL.split('/')[-1] + +if platform.system() == 'Windows': + docker_path = r"C:\Program Files\Docker\Docker" + os.environ["PATH"] += os.pathsep + docker_path + + +def _download_vm(vms_dir: str): + global URL, DOWNLOADED_FILE_NAME + # Download the virtual machine image + logger.info("Downloading the virtual machine image...") + downloaded_size = 0 + + downloaded_file_name = DOWNLOADED_FILE_NAME + + os.makedirs(vms_dir, exist_ok=True) + + while True: + downloaded_file_path = os.path.join(vms_dir, downloaded_file_name) + headers = {} + if os.path.exists(downloaded_file_path): + downloaded_size = os.path.getsize(downloaded_file_path) + headers["Range"] = f"bytes={downloaded_size}-" + + with requests.get(URL, headers=headers, stream=True) as response: + if response.status_code == 416: + # This means the range was not satisfiable, possibly the file was fully downloaded + logger.info("Fully downloaded or the file size changed.") + break + + response.raise_for_status() + total_size = int(response.headers.get('content-length', 0)) + + with open(downloaded_file_path, "ab") as file, tqdm( + desc="Progress", + total=total_size, + unit='iB', + unit_scale=True, + unit_divisor=1024, + initial=downloaded_size, + ascii=True + ) as progress_bar: + try: + for data in response.iter_content(chunk_size=1024): + size = file.write(data) + progress_bar.update(size) + except (requests.exceptions.RequestException, IOError) as e: + logger.error(f"Download error: {e}") + sleep(RETRY_INTERVAL) + logger.error("Retrying...") + else: + logger.info("Download succeeds.") + break # Download completed successfully + + if downloaded_file_name.endswith(".zip"): + # Unzip the downloaded file + logger.info("Unzipping the downloaded file...☕️") + with zipfile.ZipFile(downloaded_file_path, 'r') as zip_ref: + zip_ref.extractall(vms_dir) + logger.info("Files have been successfully extracted to the directory: " + str(vms_dir)) + + +class DockerVMManager(VMManager): + def __init__(self, registry_path=""): + pass + + def add_vm(self, vm_path): + pass + + def check_and_clean(self): + pass + + def delete_vm(self, vm_path): + pass + + def initialize_registry(self): + pass + + def list_free_vms(self): + return os.path.join(VMS_DIR, DOWNLOADED_FILE_NAME) + + def occupy_vm(self, vm_path): + pass + + def get_vm_path(self, os_type, region): + global URL, DOWNLOADED_FILE_NAME + if os_type == "Ubuntu": + URL = UBUNTU_X86_URL + elif os_type == "Windows": + URL = WINDOWS_X86_URL + DOWNLOADED_FILE_NAME = URL.split('/')[-1] + + if DOWNLOADED_FILE_NAME.endswith(".zip"): + vm_name = DOWNLOADED_FILE_NAME[:-4] + else: + vm_name = DOWNLOADED_FILE_NAME + + if not os.path.exists(os.path.join(VMS_DIR, vm_name)): + _download_vm(VMS_DIR) + return os.path.join(VMS_DIR, vm_name) diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/docker/provider.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/docker/provider.py new file mode 100644 index 000000000..adbe83e57 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/docker/provider.py @@ -0,0 +1,294 @@ +import logging +import os +import platform +import time +import asyncio +import docker +import psutil +import requests +import aiohttp +from filelock import FileLock +from pathlib import Path + +from skyrl_agent.tasks.osworld.desktop_env.providers.base import Provider + +logger = logging.getLogger("desktopenv.providers.docker.DockerProvider") +logger.setLevel(logging.INFO) + +WAIT_TIME = 3 +RETRY_INTERVAL = 1 +LOCK_TIMEOUT = 120 # Increased timeout for concurrent Ray actors + + +class PortAllocationError(Exception): + pass + + +class DockerProvider(Provider): + # Class-level async lock to coordinate file lock access across instances + _async_lock = None + # Class-level set to track ports that are reserved but not yet bound + _reserved_ports = set() + + @classmethod + def _get_async_lock(cls): + """Lazy initialization of async lock to avoid event loop issues.""" + try: + # Check if we have a lock and if it's still valid for the current event loop + if cls._async_lock is not None: + # Try to access the lock's loop - this will raise RuntimeError if bound to dead loop + cls._async_lock._get_loop() + return cls._async_lock + except RuntimeError: + # Lock is bound to a dead event loop, create a new one + cls._async_lock = None + + # Create new lock if we don't have one or the old one was invalid + if cls._async_lock is None: + cls._async_lock = asyncio.Lock() + return cls._async_lock + + @classmethod + def cleanup_all_reserved_ports(cls): + """Clean up all reserved ports. Useful between training steps.""" + cls._reserved_ports.clear() + logger.info("Cleared all reserved ports") + + def __init__(self, region: str, env_id: int): + self.client = docker.from_env() + self.server_port = None + self.vnc_port = None + self.chromium_port = None + self.vlc_port = None + self.container = None + self.environment = {"DISK_SIZE": "32G", "RAM_SIZE": "4G", "CPU_CORES": "4"} # Modify if needed + self.env_id = env_id + temp_dir = Path(os.getenv('TEMP') if platform.system() == 'Windows' else '/tmp') + self.lock_file = temp_dir / "docker_port_allocation.lck" + self.lock_file.parent.mkdir(parents=True, exist_ok=True) + + def _get_deterministic_ports(self): + """Get deterministic ports based on environment ID.""" + # Base ports for each service + base_ports = { + 'vnc': 8006, + 'server': 5000, + 'chromium': 9222, + 'vlc': 8081 + } + + # Calculate ports by adding env_id * 100 to base ports + # This gives us ranges like: 8006, 8106, 8206, ... for VNC + port_offset = self.env_id + + ports = {} + for service, base_port in base_ports.items(): + ports[service] = base_port + port_offset + + logger.info(f"Environment {self.env_id} allocated ports: {ports}") + return ports + + def _wait_for_vm_ready(self, timeout: int = 300): + """Wait for VM to be ready by checking screenshot endpoint.""" + start_time = time.time() + def check_screenshot(): + try: + response = requests.get( + f"http://localhost:{self.server_port}/screenshot", + timeout=(10, 10) + ) + return response.status_code == 200 + except Exception: + return False + + while time.time() - start_time < timeout: + if check_screenshot(): + return True + logger.info("Checking if virtual machine is ready...") + time.sleep(RETRY_INTERVAL) + + raise TimeoutError("VM failed to become ready within timeout period") + + async def _wait_for_vm_ready_async(self, timeout: int = 300): + """Async version: Wait for VM to be ready by checking screenshot endpoint.""" + start_time = asyncio.get_event_loop().time() + + async def check_screenshot(): + try: + async with aiohttp.ClientSession() as session: + async with session.get( + f"http://localhost:{self.server_port}/screenshot", + timeout=aiohttp.ClientTimeout(total=10) + ) as response: + return response.status == 200 + except Exception: + return False + + while (asyncio.get_event_loop().time() - start_time) < timeout: + if await check_screenshot(): + return True + logger.info("Checking if virtual machine is ready...") + await asyncio.sleep(RETRY_INTERVAL) + + raise TimeoutError("VM failed to become ready within timeout period") + + async def start_emulator_async(self, path_to_vm: str, headless: bool, os_type: str): + """ + Async version of start_emulator with deterministic port allocation. + No locks needed since ports are deterministically assigned based on env_id. + """ + # Step 1: Allocate ports deterministically (no lock needed) + ports = self._get_deterministic_ports() + self.vnc_port = ports['vnc'] + self.server_port = ports['server'] + self.chromium_port = ports['chromium'] + self.vlc_port = ports['vlc'] + + # Step 2: Start container + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, self._start_container_sync, path_to_vm, headless, os_type) + + # Step 3: Wait for VM to be ready + await self._wait_for_vm_ready_async() + + def _allocate_ports_sync(self): + """Allocate ports deterministically - no lock needed.""" + ports = self._get_deterministic_ports() + self.vnc_port = ports['vnc'] + self.server_port = ports['server'] + self.chromium_port = ports['chromium'] + self.vlc_port = ports['vlc'] + + def _start_container_sync(self, path_to_vm: str, headless: bool, os_type: str): + """Start the Docker container - can run concurrently since ports are pre-allocated.""" + allocated_ports = [self.vnc_port, self.server_port, self.chromium_port, self.vlc_port] + try: + # Check if KVM is available + devices = [] + if os.path.exists("/dev/kvm"): + devices.append("/dev/kvm") + logger.info("KVM device found, using hardware acceleration") + else: + self.environment["KVM"] = "N" + logger.warning("KVM device not found, running without hardware acceleration (will be slower)") + + self.container = self.client.containers.run( + "happysixd/osworld-docker", + environment=self.environment, + cap_add=["NET_ADMIN"], + devices=devices, + volumes={ + os.path.abspath(path_to_vm): { + "bind": "/System.qcow2", + "mode": "ro" + } + }, + ports={ + 8006: self.vnc_port, + 5000: self.server_port, + 9222: self.chromium_port, + 8080: self.vlc_port + }, + detach=True + ) + + logger.info(f"Started container with ports - VNC: {self.vnc_port}, " + f"Server: {self.server_port}, Chrome: {self.chromium_port}, VLC: {self.vlc_port}") + + except Exception as e: + # Clean up if anything goes wrong + if self.container: + try: + self.container.stop() + self.container.remove() + except: + pass + raise e + + + def start_emulator(self, path_to_vm: str, headless: bool, os_type: str): + # Allocate ports deterministically (no lock needed) + ports = self._get_deterministic_ports() + self.vnc_port = ports['vnc'] + self.server_port = ports['server'] + self.chromium_port = ports['chromium'] + self.vlc_port = ports['vlc'] + + try: + # Check if KVM is available + devices = [] + if os.path.exists("/dev/kvm"): + devices.append("/dev/kvm") + logger.info("KVM device found, using hardware acceleration") + else: + self.environment["KVM"] = "N" + logger.warning("KVM device not found, running without hardware acceleration (will be slower)") + + self.container = self.client.containers.run( + "happysixd/osworld-docker", + environment=self.environment, + cap_add=["NET_ADMIN"], + devices=devices, + volumes={ + os.path.abspath(path_to_vm): { + "bind": "/System.qcow2", + "mode": "ro" + } + }, + ports={ + 8006: self.vnc_port, + 5000: self.server_port, + 9222: self.chromium_port, + 8080: self.vlc_port + }, + detach=True + ) + + logger.info(f"Started container with ports - VNC: {self.vnc_port}, " + f"Server: {self.server_port}, Chrome: {self.chromium_port}, VLC: {self.vlc_port}") + + # Wait for VM to be ready + self._wait_for_vm_ready() + + except Exception as e: + # Clean up if anything goes wrong + if self.container: + try: + self.container.stop() + self.container.remove() + except: + pass + raise e + + def get_ip_address(self, path_to_vm: str) -> str: + if not all([self.server_port, self.chromium_port, self.vnc_port, self.vlc_port]): + raise RuntimeError("VM not started - ports not allocated") + return f"localhost:{self.server_port}:{self.chromium_port}:{self.vnc_port}:{self.vlc_port}" + + def save_state(self, path_to_vm: str, snapshot_name: str): + raise NotImplementedError("Snapshots not available for Docker provider") + + def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str): + self.stop_emulator(path_to_vm) + + def stop_emulator(self, path_to_vm: str, region=None, *args, **kwargs): + # Note: region parameter is ignored for Docker provider + # but kept for interface consistency with other providers + if self.container: + logger.info("Stopping VM...") + # Store ports for cleanup before clearing them + ports_to_cleanup = [self.vnc_port, self.server_port, self.chromium_port, self.vlc_port] + try: + self.container.stop() + self.container.remove() + time.sleep(WAIT_TIME) + except Exception as e: + logger.error(f"Error stopping container: {e}") + finally: + self.container = None + self.server_port = None + self.vnc_port = None + self.chromium_port = None + self.vlc_port = None + + # No need to clean up reserved ports since we use deterministic allocation diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/gcp/__init__.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/gcp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/gcp/manager.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/gcp/manager.py new file mode 100644 index 000000000..e69de29bb diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/gcp/provider.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/gcp/provider.py new file mode 100644 index 000000000..e69de29bb diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/virtualbox/INSTALL_VITUALBOX.md b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/virtualbox/INSTALL_VITUALBOX.md new file mode 100644 index 000000000..e1d4ae4a0 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/virtualbox/INSTALL_VITUALBOX.md @@ -0,0 +1,11 @@ +## 💾 Installation of VirtualBox + + +1. Download the VirtualBox from the [official website](https://www.virtualbox.org/wiki/Downloads). Unfortunately, for Apple chips (M1 chips, M2 chips, etc.), VirtualBox is not supported. You can only use VMware Fusion instead. +2. Install VirtualBox. Just follow the instructions provided by the installer. +For Windows, you also need to append the installation path to the environment variable `PATH` for enabling the `VBoxManage` command. The default installation path is `C:\Program Files\Oracle\VirtualBox`. +3. Verify the successful installation by running the following: + ```bash + VBoxManage --version + ``` + If the installation along with the environment variable set is successful, you will see the version of VirtualBox installed on your system. diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/virtualbox/__init__.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/virtualbox/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/virtualbox/manager.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/virtualbox/manager.py new file mode 100644 index 000000000..b407ac83d --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/virtualbox/manager.py @@ -0,0 +1,461 @@ +import logging +import os +import platform +import shutil +import subprocess +import threading +import time +import zipfile + +import psutil +import requests +from filelock import FileLock +from tqdm import tqdm + +from desktop_env.providers.base import VMManager + +logger = logging.getLogger("desktopenv.providers.virtualbox.VirtualBoxVMManager") +logger.setLevel(logging.INFO) + +MAX_RETRY_TIMES = 10 +RETRY_INTERVAL = 5 +UBUNTU_ARM_URL = "NOT_AVAILABLE" +UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_x86_virtualbox/resolve/main/Ubuntu.zip" +DOWNLOADED_FILE_NAME = "Ubuntu.zip" +REGISTRY_PATH = '.virtualbox_vms' + +LOCK_FILE_NAME = '.virtualbox_lck' +VMS_DIR = "./virtualbox_vm_data" +update_lock = threading.Lock() + +if platform.system() == 'Windows': + vboxmanage_path = r"C:\Program Files\Oracle\VirtualBox" + os.environ["PATH"] += os.pathsep + vboxmanage_path + + +def generate_new_vm_name(vms_dir, os_type): + registry_idx = 0 + while True: + attempted_new_name = f"{os_type}{registry_idx}" + if os.path.exists( + os.path.join(vms_dir, attempted_new_name, attempted_new_name, attempted_new_name + ".vbox")): + registry_idx += 1 + else: + return attempted_new_name + + +def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu", bridged_adapter_name=None): + os.makedirs(vms_dir, exist_ok=True) + + def __download_and_unzip_vm(): + # Determine the platform and CPU architecture to decide the correct VM image to download + if platform.system() == 'Darwin': # macOS + url = UBUNTU_ARM_URL + raise Exception("MacOS host is not currently supported for VirtualBox.") + elif platform.machine().lower() in ['amd64', 'x86_64']: + url = UBUNTU_X86_URL + else: + raise Exception("Unsupported platform or architecture.") + + # Download the virtual machine image + logger.info("Downloading the virtual machine image...") + downloaded_size = 0 + + while True: + downloaded_file_path = os.path.join(vms_dir, downloaded_file_name) + headers = {} + if os.path.exists(downloaded_file_path): + downloaded_size = os.path.getsize(downloaded_file_path) + headers["Range"] = f"bytes={downloaded_size}-" + + with requests.get(url, headers=headers, stream=True) as response: + if response.status_code == 416: + # This means the range was not satisfiable, possibly the file was fully downloaded + logger.info("Fully downloaded or the file size changed.") + break + + response.raise_for_status() + total_size = int(response.headers.get('content-length', 0)) + + with open(downloaded_file_path, "ab") as file, tqdm( + desc="Progress", + total=total_size, + unit='iB', + unit_scale=True, + unit_divisor=1024, + initial=downloaded_size, + ascii=True + ) as progress_bar: + try: + for data in response.iter_content(chunk_size=1024): + size = file.write(data) + progress_bar.update(size) + except (requests.exceptions.RequestException, IOError) as e: + logger.error(f"Download error: {e}") + time.sleep(RETRY_INTERVAL) + logger.error("Retrying...") + else: + logger.info("Download succeeds.") + break # Download completed successfully + + # Unzip the downloaded file + logger.info("Unzipping the downloaded file...☕️") + with zipfile.ZipFile(downloaded_file_path, 'r') as zip_ref: + zip_ref.extractall(vms_dir) + logger.info("Files have been successfully extracted to the directory: " + vms_dir) + + def import_vm(vms_dir, target_vm_name, max_retries=1): + """Import the .ovf file into VirtualBox.""" + logger.info(f"Starting to import VM {target_vm_name}...") + command = ( + f"VBoxManage import {os.path.abspath(os.path.join(vms_dir, original_vm_name, original_vm_name + '.ovf'))} " + f"--vsys 0 " + f"--vmname {target_vm_name} " + f"--settingsfile {os.path.abspath(os.path.join(vms_dir, target_vm_name, target_vm_name + '.vbox'))} " + f"--basefolder {vms_dir} " + f"--unit 14 " + f"--disk {os.path.abspath(os.path.join(vms_dir, target_vm_name, target_vm_name + '_disk1.vmdk'))}") + + for attempt in range(max_retries): + result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8", + errors='ignore') + if result.returncode == 0: + logger.info("Successfully imported VM.") + return True + else: + if not result.stderr or "Error" in result.stderr: + logger.error(f"Attempt {attempt + 1} failed with specific error: {result.stderr}") + else: + logger.error(f"Attempt {attempt + 1} failed: {result.stderr}") + + if attempt == max_retries - 1: + logger.error("Maximum retry attempts reached, failed to import the virtual machine.") + return False + + def configure_vm_network(vm_name, interface_name=None): + # Config of bridged network + command = f'VBoxManage modifyvm "{vm_name}" --nic1 bridged' + result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8", + errors='ignore') + if not interface_name: + output = subprocess.check_output(f"VBoxManage list bridgedifs", shell=True, stderr=subprocess.STDOUT) + output = output.decode() + output = output.splitlines() + result = [] + for line in output: + entries = line.split() + if entries and entries[0] == "Name:": + name = ' '.join(entries[1:]) + if entries and entries[0] == "IPAddress:": + ip = entries[1] + result.append((name, ip)) + logger.info("Found the following network adapters, default to the first. If you want to change it, please set the argument -r to the name of the adapter.") + for i, (name, ip) in enumerate(result): + logger.info(f"{i+1}: {name} ({ip})") + interface_id = 1 + interface_name = result[interface_id-1][0] + command = f'vboxmanage modifyvm "{vm_name}" --bridgeadapter1 "{interface_name}"' + result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8", + errors='ignore') + if result.returncode == 0: + logger.info(f"Changed to bridge adapter {interface_name}.") + return True + else: + logger.error(f"Failed to change to bridge adapter {interface_name}: {result.stderr}") + return False + + # # Config of NAT network + # command = f"VBoxManage natnetwork add --netname natnet --network {nat_network} --dhcp on" + # result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8", + # errors='ignore') + # if result.returncode == 0: + # logger.info(f"Created NAT network {nat_network}.") + # else: + # logger.error(f"Failed to create NAT network {nat_network}") + # return False + # command = f"VBoxManage modifyvm {vm_name} --nic1 natnetwork" + # result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8", + # errors='ignore') + # command = f"VBoxManage modifyvm {vm_name} --natnet1 natnet" + # result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8", + # errors='ignore') + # if result.returncode == 0: + # logger.info("Switched VM to the NAT network.") + # else: + # logger.error("Failed to switch VM to the NAT network") + # return False + # logger.info("Start to configure port forwarding...") + # command = f"VBoxManage modifyvm {vm_name} --natpf1 'server,tcp,,5000,,5000'" + # result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8", + # errors='ignore') + # if result.returncode == 0: + # logger.info("Successfully created port forwarding rule.") + # return True + # logger.error("Failed to create port forwarding rule.") + # return False + + + vm_path = os.path.join(vms_dir, vm_name, vm_name + ".vbox") + + # Execute the function to download and unzip the VM, and update the vm metadata + if not os.path.exists(vm_path): + __download_and_unzip_vm() + import_vm(vms_dir, vm_name) + if not configure_vm_network(vm_name, bridged_adapter_name): + raise Exception("Failed to configure VM network!") + else: + logger.info(f"Virtual machine exists: {vm_path}") + + # Start the virtual machine + def start_vm(vm_name, max_retries=20): + command = f'VBoxManage startvm "{vm_name}" --type headless' + + for attempt in range(max_retries): + result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8") + if result.returncode == 0: + logger.info("Virtual machine started.") + return True + else: + if not result.stderr or "Error" in result.stderr: + logger.error(f"Attempt {attempt + 1} failed with specific error: {result.stderr}") + else: + logger.error(f"Attempt {attempt + 1} failed: {result.stderr}") + + if attempt == max_retries - 1: + logger.error("Maximum retry attempts reached, failed to start the virtual machine.") + return False + + if not start_vm(vm_name): + raise ValueError("Error encountered during installation, please rerun the code for retrying.") + + def get_vm_ip(vm_name): + command = f'VBoxManage guestproperty get "{vm_name}" /VirtualBox/GuestInfo/Net/0/V4/IP' + result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8") + if result.returncode == 0: + return result.stdout.strip().split()[1] + else: + logger.error(f"Get VM IP failed: {result.stderr}") + return None + + def change_resolution(vm_name, resolution=(1920, 1080, 32)): + command = f'VBoxManage controlvm "{vm_name}" setvideomodehint {" ".join(map(str, resolution))}' + result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8") + if result.returncode == 0: + return True + else: + return False + + # Function used to check whether the virtual machine is ready + def download_screenshot(vm_name): + ip = get_vm_ip(vm_name) + url = f"http://{ip}:5000/screenshot" + try: + # max trey times 1, max timeout 1 + response = requests.get(url, timeout=(10, 10)) + if response.status_code == 200: + return True + except Exception as e: + logger.error(f"Error: {e}") + logger.error(f"Type: {type(e).__name__}") + logger.error(f"Error detail: {str(e)}") + return False + + # Try downloading the screenshot until successful + while not download_screenshot(vm_name): + logger.info("Check whether the virtual machine is ready...") + time.sleep(RETRY_INTERVAL) + + if not change_resolution(vm_name): + logger.error(f"Change resolution failed.") + raise + + logger.info("Virtual machine is ready. Start to make a snapshot on the virtual machine. It would take a while...") + + def create_vm_snapshot(vm_name, max_retries=20): + logger.info("Saving VirtualBox VM state...") + command = f'VBoxManage snapshot "{vm_name}" take init_state' + + for attempt in range(max_retries): + result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8") + if result.returncode == 0: + logger.info("Snapshot created.") + return True + else: + if "Error" in result.stderr: + logger.error(f"Attempt {attempt + 1} failed with specific error: {result.stderr}") + else: + logger.error(f"Attempt {attempt + 1} failed: {result.stderr}") + + if attempt == max_retries - 1: + logger.error("Maximum retry attempts reached, failed to create snapshot.") + return False + + # Create a snapshot of the virtual machine + if create_vm_snapshot(vm_name, max_retries=MAX_RETRY_TIMES): + return vm_path + else: + raise ValueError("Error encountered during installation, please rerun the code for retrying.") + + +class VirtualBoxVMManager(VMManager): + def __init__(self, registry_path=REGISTRY_PATH): + self.registry_path = registry_path + self.lock = FileLock(LOCK_FILE_NAME, timeout=60) + self.initialize_registry() + + def initialize_registry(self): + with self.lock: # Locking during initialization + if not os.path.exists(self.registry_path): + with open(self.registry_path, 'w') as file: + file.write('') + + def add_vm(self, vm_path, lock_needed=True): + if lock_needed: + with self.lock: + self._add_vm(vm_path) + else: + self._add_vm(vm_path) + + def _add_vm(self, vm_path, region=None): + assert region in [None, 'local'], "For VirtualBox provider, the region should be neither None or 'local'." + with self.lock: + with open(self.registry_path, 'r') as file: + lines = file.readlines() + new_lines = lines + [f'{vm_path}|free\n'] + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def occupy_vm(self, vm_path, pid, lock_needed=True): + if lock_needed: + with self.lock: + self._occupy_vm(vm_path, pid) + else: + self._occupy_vm(vm_path, pid) + + def _occupy_vm(self, vm_path, pid, region=None): + assert region in [None, 'local'], "For VirtualBox provider, the region should be neither None or 'local'." + with self.lock: + new_lines = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + registered_vm_path, _ = line.strip().split('|') + if registered_vm_path == vm_path: + new_lines.append(f'{registered_vm_path}|{pid}\n') + else: + new_lines.append(line) + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def delete_vm(self, vm_path, lock_needed=True): + if lock_needed: + with self.lock: + self._delete_vm(vm_path) + else: + self._delete_vm(vm_path) + + def _delete_vm(self, vm_path): + raise NotImplementedError + + def check_and_clean(self, vms_dir, lock_needed=True): + if lock_needed: + with self.lock: + self._check_and_clean(vms_dir) + else: + self._check_and_clean(vms_dir) + + def _check_and_clean(self, vms_dir): + with self.lock: # Lock when cleaning up the registry and vms_dir + # Check and clean on the running vms, detect the released ones and mark then as 'free' + active_pids = {p.pid for p in psutil.process_iter()} + new_lines = [] + vm_paths = [] + + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + vm_path, pid_str = line.strip().split('|') + if not os.path.exists(vm_path): + logger.info(f"VM {vm_path} not found, releasing it.") + new_lines.append(f'{vm_path}|free\n') + continue + + vm_paths.append(vm_path) + if pid_str == "free": + new_lines.append(line) + continue + + if int(pid_str) in active_pids: + new_lines.append(line) + else: + new_lines.append(f'{vm_path}|free\n') + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + # Check and clean on the files inside vms_dir, delete the unregistered ones + os.makedirs(vms_dir, exist_ok=True) + vm_names = os.listdir(vms_dir) + for vm_name in vm_names: + # skip the downloaded .zip file + if vm_name == DOWNLOADED_FILE_NAME: + continue + # Skip the .DS_Store file on macOS + if vm_name == ".DS_Store": + continue + + flag = True + for vm_path in vm_paths: + if vm_name + ".vbox" in vm_path: + flag = False + if flag: + shutil.rmtree(os.path.join(vms_dir, vm_name)) + + def list_free_vms(self, lock_needed=True): + if lock_needed: + with self.lock: + return self._list_free_vms() + else: + return self._list_free_vms() + + def _list_free_vms(self): + with self.lock: # Lock when reading the registry + free_vms = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + vm_path, pid_str = line.strip().split('|') + if pid_str == "free": + free_vms.append((vm_path, pid_str)) + return free_vms + + def get_vm_path(self, os_type, region=None): + if os_type != "Ubuntu": + raise ValueError("Only support Ubuntu for now.") + + with self.lock: + if not VirtualBoxVMManager.checked_and_cleaned: + VirtualBoxVMManager.checked_and_cleaned = True + self._check_and_clean(vms_dir=VMS_DIR) + + allocation_needed = False + with self.lock: + free_vms_paths = self._list_free_vms() + if len(free_vms_paths) == 0: + # No free virtual machine available, generate a new one + allocation_needed = True + else: + # Choose the first free virtual machine + chosen_vm_path = free_vms_paths[0][0] + self._occupy_vm(chosen_vm_path, os.getpid()) + return chosen_vm_path + + if allocation_needed: + logger.info("No free virtual machine available. Generating a new one, which would take a while...☕") + new_vm_name = generate_new_vm_name(vms_dir=VMS_DIR, os_type=os_type) + new_vm_path = _install_vm(new_vm_name, vms_dir=VMS_DIR, + downloaded_file_name=DOWNLOADED_FILE_NAME, + bridged_adapter_name=region) + with self.lock: + self._add_vm(new_vm_path) + self._occupy_vm(new_vm_path, os.getpid()) + return new_vm_path diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/virtualbox/provider.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/virtualbox/provider.py new file mode 100644 index 000000000..81cda086e --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/virtualbox/provider.py @@ -0,0 +1,120 @@ +import logging +import platform +import subprocess +import time +import os +from desktop_env.providers.base import Provider +import xml.etree.ElementTree as ET + +logger = logging.getLogger("desktopenv.providers.virtualbox.VirtualBoxProvider") +logger.setLevel(logging.INFO) + +WAIT_TIME = 3 + +# Note: Windows will not add command VBoxManage to PATH by default. Please add the folder where VBoxManage executable is in (Default should be "C:\Program Files\Oracle\VirtualBox" for Windows) to PATH. + +class VirtualBoxProvider(Provider): + @staticmethod + def _execute_command(command: list): + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True, + encoding="utf-8") + if result.returncode != 0: + raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m") + return result.stdout.strip() + + @staticmethod + def _get_vm_uuid(path_to_vm: str): + try: + output = subprocess.check_output(f"VBoxManage list vms", shell=True, stderr=subprocess.STDOUT) + output = output.decode() + output = output.splitlines() + if path_to_vm.endswith('.vbox'): + # Load and parse the XML content from the file + tree = ET.parse(path_to_vm) + root = tree.getroot() + + # Find the element and retrieve its 'uuid' attribute + machine_element = root.find('.//{http://www.virtualbox.org/}Machine') + if machine_element is not None: + uuid = machine_element.get('uuid')[1:-1] + return uuid + else: + logger.error(f"UUID not found in file {path_to_vm}") + raise + elif any(line.split()[1] == "{" + path_to_vm + "}" for line in output): + logger.info(f"Got valid UUID {path_to_vm}.") + return path_to_vm + else: + for line in output: + if line.split()[0] == '"' + path_to_vm + '"': + uuid = line.split()[1][1:-1] + return uuid + logger.error(f"The path you provided does not match any of the \".vbox\" file, name, or UUID of VM.") + raise + except subprocess.CalledProcessError as e: + logger.error(f"Error executing command: {e.output.decode().strip()}") + + + def start_emulator(self, path_to_vm: str, headless: bool): + logger.info("Starting VirtualBox VM...") + + while True: + try: + uuid = VirtualBoxProvider._get_vm_uuid(path_to_vm) + output = subprocess.check_output(f"VBoxManage list runningvms", shell=True, stderr=subprocess.STDOUT) + output = output.decode() + output = output.splitlines() + + if any(line.split()[1] == "{" + uuid + "}" for line in output): + logger.info("VM is running.") + break + else: + logger.info("Starting VM...") + VirtualBoxProvider._execute_command(["VBoxManage", "startvm", uuid]) if not headless else \ + VirtualBoxProvider._execute_command( + ["VBoxManage", "startvm", uuid, "--type", "headless"]) + time.sleep(WAIT_TIME) + + except subprocess.CalledProcessError as e: + logger.error(f"Error executing command: {e.output.decode().strip()}") + + def get_ip_address(self, path_to_vm: str) -> str: + logger.info("Getting VirtualBox VM IP address...") + while True: + try: + uuid = VirtualBoxProvider._get_vm_uuid(path_to_vm) + output = VirtualBoxProvider._execute_command( + ["VBoxManage", "guestproperty", "get", uuid, "/VirtualBox/GuestInfo/Net/0/V4/IP"] + ) + result = output.split()[1] + if result != "value": + logger.info(f"VirtualBox VM IP address: {result}") + return result + else: + logger.error("VM IP address not found. Have you installed the guest additions?") + raise + except Exception as e: + logger.error(e) + time.sleep(WAIT_TIME) + logger.info("Retrying to get VirtualBox VM IP address...") + + def save_state(self, path_to_vm: str, snapshot_name: str): + logger.info("Saving VirtualBox VM state...") + uuid = VirtualBoxProvider._get_vm_uuid(path_to_vm) + VirtualBoxProvider._execute_command(["VBoxManage", "snapshot", uuid, "take", snapshot_name]) + time.sleep(WAIT_TIME) # Wait for the VM to save + + def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str): + logger.info(f"Reverting VirtualBox VM to snapshot: {snapshot_name}...") + uuid = VirtualBoxProvider._get_vm_uuid(path_to_vm) + VirtualBoxProvider._execute_command(["VBoxManage", "controlvm", uuid, "savestate"]) + time.sleep(WAIT_TIME) # Wait for the VM to stop + VirtualBoxProvider._execute_command(["VBoxManage", "snapshot", uuid, "restore", snapshot_name]) + time.sleep(WAIT_TIME) # Wait for the VM to revert + return path_to_vm + + def stop_emulator(self, path_to_vm: str): + logger.info("Stopping VirtualBox VM...") + uuid = VirtualBoxProvider._get_vm_uuid(path_to_vm) + VirtualBoxProvider._execute_command(["VBoxManage", "controlvm", uuid, "savestate"]) + time.sleep(WAIT_TIME) # Wait for the VM to stop diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/vmware/INSTALL_VMWARE.md b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/vmware/INSTALL_VMWARE.md new file mode 100644 index 000000000..596d9f4b0 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/vmware/INSTALL_VMWARE.md @@ -0,0 +1,23 @@ +## 💾 Installation of VMware Workstation Pro + +--- + +1. Download VMware Workstation Pro from the [official website](https://www.vmware.com/products/workstation-pro/workstation-pro-evaluation.html). The version we are using is 17.5.1. For systems with Apple chips, you should install [VMware Fusion](https://www.vmware.com/go/getfusion). + +2. Install VMware Workstation + - **[On Linux](https://docs.vmware.com/en/VMware-Workstation-Pro/17/com.vmware.ws.using.doc/GUID-1F5B1F14-A586-4A56-83FA-2E7D8333D5CA.html):** Run the following command in your terminal, where `xxxx-xxxxxxx` represents the version number and internal version number. + ``` + sudo sh VMware-Workstation-xxxx-xxxxxxx.architecture.bundle --console + ``` + + - **[On Windows](https://docs.vmware.com/en/VMware-Workstation-Pro/17/com.vmware.ws.using.doc/GUID-F5A7B3CB-9141-458B-A256-E0C3EA805AAA.html):** Ensure that you're logged in as either the Administrator user or as a user who belongs to the local Administrators group. If you're logging in to a domain, make sure your domain account has local administrator privileges. Proceed by double-clicking the `VMware-workstation-xxxx-xxxxxxx.exe` file. Be aware that you might need to reboot your host system to finalize the installation. + + - **[For systems with Apple chips](https://docs.vmware.com/en/VMware-Fusion/13/com.vmware.fusion.using.doc/GUID-ACC3A019-93D3-442C-A34E-F7755DF6733B.html):** Double-click the `VMware-Fusion-xxxx-xxxxxxx.dmg` file to open it. In the Finder window that appears, double-click the 'Install Fusion' icon. When prompted, enter your administrator username and password. + + > **Note:** You need to fill the activation key during the installation process when prompted. + +3. Verify the successful installation by running the following: + ``` + vmrun -T ws list + ``` + If the installation along with the environment variable set is successful, you will see the message showing the current running virtual machines. diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/vmware/__init__.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/vmware/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/vmware/manager.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/vmware/manager.py new file mode 100644 index 000000000..c3d89711f --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/vmware/manager.py @@ -0,0 +1,453 @@ +import os +import platform +import random +import re + +import threading +from filelock import FileLock +import uuid +import zipfile + +from time import sleep +import shutil +import psutil +import subprocess +import requests +from tqdm import tqdm + +import logging + +from skyrl_agent.tasks.osworld.desktop_env.providers.base import VMManager + +logger = logging.getLogger("desktopenv.providers.vmware.VMwareVMManager") +logger.setLevel(logging.INFO) + +MAX_RETRY_TIMES = 10 +RETRY_INTERVAL = 5 +UBUNTU_ARM_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu-arm.zip" +UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu-x86.zip" +WINDOWS_X86_URL = "https://huggingface.co/datasets/xlangai/windows_osworld/resolve/main/Windows-x86.zip" + +# Determine the platform and CPU architecture to decide the correct VM image to download +if platform.system() == 'Darwin': # macOS + # if os.uname().machine == 'arm64': # Apple Silicon + URL = UBUNTU_ARM_URL +# else: +# url = UBUNTU_X86_URL +elif platform.machine().lower() in ['amd64', 'x86_64']: + URL = UBUNTU_X86_URL +else: + raise Exception("Unsupported platform or architecture") + +DOWNLOADED_FILE_NAME = URL.split('/')[-1] +REGISTRY_PATH = '.vmware_vms' +LOCK_FILE_NAME = '.vmware_lck' +VMS_DIR = "./vmware_vm_data" +update_lock = threading.Lock() + +if platform.system() == 'Windows': + vboxmanage_path = r"C:\Program Files (x86)\VMware\VMware Workstation" + os.environ["PATH"] += os.pathsep + vboxmanage_path + +def generate_new_vm_name(vms_dir, os_type): + registry_idx = 0 + prefix = os_type + while True: + attempted_new_name = f"{prefix}{registry_idx}" + if os.path.exists( + os.path.join(vms_dir, attempted_new_name, attempted_new_name + ".vmx")): + registry_idx += 1 + else: + return attempted_new_name + + +def _update_vm(vmx_path, target_vm_name): + """Update the VMX file with the new VM name and other parameters, so that the VM can be started successfully without conflict with the original VM.""" + with update_lock: + dir_path, vmx_file = os.path.split(vmx_path) + + def _generate_mac_address(): + # VMware MAC address range starts with 00:0c:29 + mac = [0x00, 0x0c, 0x29, + random.randint(0x00, 0x7f), + random.randint(0x00, 0xff), + random.randint(0x00, 0xff)] + return ':'.join(map(lambda x: "%02x" % x, mac)) + + # Backup the original file + with open(vmx_path, 'r') as file: + original_content = file.read() + + # Generate new values + new_uuid_bios = str(uuid.uuid4()) + new_uuid_location = str(uuid.uuid4()) + new_mac_address = _generate_mac_address() + new_vmci_id = str(random.randint(-2147483648, 2147483647)) # Random 32-bit integer + + # Update the content + updated_content = re.sub(r'displayName = ".*?"', f'displayName = "{target_vm_name}"', original_content) + updated_content = re.sub(r'uuid.bios = ".*?"', f'uuid.bios = "{new_uuid_bios}"', updated_content) + updated_content = re.sub(r'uuid.location = ".*?"', f'uuid.location = "{new_uuid_location}"', updated_content) + updated_content = re.sub(r'ethernet0.generatedAddress = ".*?"', + f'ethernet0.generatedAddress = "{new_mac_address}"', + updated_content) + updated_content = re.sub(r'vmci0.id = ".*?"', f'vmci0.id = "{new_vmci_id}"', updated_content) + + # Write the updated content back to the file + with open(vmx_path, 'w') as file: + file.write(updated_content) + + logger.info(".vmx file updated successfully.") + + vmx_file_base_name = os.path.splitext(vmx_file)[0] + + files_to_rename = ['vmx', 'nvram', 'vmsd', 'vmxf'] + + for ext in files_to_rename: + original_file = os.path.join(dir_path, f"{vmx_file_base_name}.{ext}") + target_file = os.path.join(dir_path, f"{target_vm_name}.{ext}") + os.rename(original_file, target_file) + + # Update the dir_path to the target vm_name, only replace the last character + # Split the path into parts up to the last folder + path_parts = dir_path.rstrip(os.sep).split(os.sep) + path_parts[-1] = target_vm_name + target_dir_path = os.sep.join(path_parts) + os.rename(dir_path, target_dir_path) + + logger.info("VM files renamed successfully.") + + +def _install_vm(vm_name, vms_dir, downloaded_file_name, os_type, original_vm_name="Ubuntu"): + os.makedirs(vms_dir, exist_ok=True) + + def __download_and_unzip_vm(): + # Download the virtual machine image + logger.info("Downloading the virtual machine image...") + downloaded_size = 0 + + if os_type == "Ubuntu": + if platform.system() == 'Darwin': + URL = UBUNTU_ARM_URL + elif platform.machine().lower() in ['amd64', 'x86_64']: + URL = UBUNTU_X86_URL + elif os_type == "Windows": + if platform.machine().lower() in ['amd64', 'x86_64']: + URL = WINDOWS_X86_URL + DOWNLOADED_FILE_NAME = URL.split('/')[-1] + downloaded_file_name = DOWNLOADED_FILE_NAME + + while True: + downloaded_file_path = os.path.join(vms_dir, downloaded_file_name) + headers = {} + if os.path.exists(downloaded_file_path): + downloaded_size = os.path.getsize(downloaded_file_path) + headers["Range"] = f"bytes={downloaded_size}-" + + with requests.get(URL, headers=headers, stream=True) as response: + if response.status_code == 416: + # This means the range was not satisfiable, possibly the file was fully downloaded + logger.info("Fully downloaded or the file size changed.") + break + + response.raise_for_status() + total_size = int(response.headers.get('content-length', 0)) + + with open(downloaded_file_path, "ab") as file, tqdm( + desc="Progress", + total=total_size, + unit='iB', + unit_scale=True, + unit_divisor=1024, + initial=downloaded_size, + ascii=True + ) as progress_bar: + try: + for data in response.iter_content(chunk_size=1024): + size = file.write(data) + progress_bar.update(size) + except (requests.exceptions.RequestException, IOError) as e: + logger.error(f"Download error: {e}") + sleep(RETRY_INTERVAL) + logger.error("Retrying...") + else: + logger.info("Download succeeds.") + break # Download completed successfully + + # Unzip the downloaded file + logger.info("Unzipping the downloaded file...☕️") + with zipfile.ZipFile(downloaded_file_path, 'r') as zip_ref: + zip_ref.extractall(os.path.join(vms_dir, vm_name)) + logger.info("Files have been successfully extracted to the directory: " + str(os.path.join(vms_dir, vm_name))) + + vm_path = os.path.join(vms_dir, vm_name, vm_name + ".vmx") + + # Execute the function to download and unzip the VM, and update the vm metadata + if not os.path.exists(vm_path): + __download_and_unzip_vm() + _update_vm(os.path.join(vms_dir, vm_name, original_vm_name + ".vmx"), vm_name) + else: + logger.info(f"Virtual machine exists: {vm_path}") + + # Determine the platform of the host machine and decide the parameter for vmrun + def get_vmrun_type(): + if platform.system() == 'Windows' or platform.system() == 'Linux': + return '-T ws' + elif platform.system() == 'Darwin': # Darwin is the system name for macOS + return '-T fusion' + else: + raise Exception("Unsupported operating system") + + # Start the virtual machine + def start_vm(vm_path, max_retries=20): + command = f'vmrun {get_vmrun_type()} start "{vm_path}" nogui' + for attempt in range(max_retries): + result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8") + if result.returncode == 0: + logger.info("Virtual machine started.") + return True + else: + if "Error" in result.stderr: + logger.error(f"Attempt {attempt + 1} failed with specific error: {result.stderr}") + else: + logger.error(f"Attempt {attempt + 1} failed: {result.stderr}") + + if attempt == max_retries - 1: + logger.error("Maximum retry attempts reached, failed to start the virtual machine.") + return False + + if not start_vm(vm_path): + raise ValueError("Error encountered during installation, please rerun the code for retrying.") + + def get_vm_ip(vm_path, max_retries=20): + command = f'vmrun {get_vmrun_type()} getGuestIPAddress "{vm_path}" -wait' + for attempt in range(max_retries): + result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8") + if result.returncode == 0: + return result.stdout.strip() + else: + if "Error" in result.stderr: + logger.error(f"Attempt {attempt + 1} failed with specific error: {result.stderr}") + else: + logger.error(f"Attempt {attempt + 1} failed: {result.stderr}") + + if attempt == max_retries - 1: + logger.error("Maximum retry attempts reached, failed to get the IP of virtual machine.") + return None + + vm_ip = get_vm_ip(vm_path) + if not vm_ip: + raise ValueError("Error encountered during installation, please rerun the code for retrying.") + + # Function used to check whether the virtual machine is ready + def download_screenshot(ip): + url = f"http://{ip}:5000/screenshot" + try: + # max trey times 1, max timeout 1 + response = requests.get(url, timeout=(10, 10)) + if response.status_code == 200: + return True + except Exception as e: + logger.error(f"Error: {e}") + logger.error(f"Type: {type(e).__name__}") + logger.error(f"Error detail: {str(e)}") + sleep(RETRY_INTERVAL) + return False + + # Try downloading the screenshot until successful + while not download_screenshot(vm_ip): + # Try to get the IP again in case it has changed + vm_ip = get_vm_ip(vm_path) + logger.info("Check whether the virtual machine is ready...") + + logger.info("Virtual machine is ready. Start to make a snapshot on the virtual machine. It would take a while...") + + def create_vm_snapshot(vm_path, max_retries=20): + command = f'vmrun {get_vmrun_type()} snapshot "{vm_path}" "init_state"' + for attempt in range(max_retries): + result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8") + if result.returncode == 0: + logger.info("Snapshot created.") + return True + else: + if "Error" in result.stderr: + logger.error(f"Attempt {attempt + 1} failed with specific error: {result.stderr}") + else: + logger.error(f"Attempt {attempt + 1} failed: {result.stderr}") + + if attempt == max_retries - 1: + logger.error("Maximum retry attempts reached, failed to create snapshot.") + return False + + # Create a snapshot of the virtual machine + if create_vm_snapshot(vm_path, max_retries=MAX_RETRY_TIMES): + return vm_path + else: + raise ValueError("Error encountered during installation, please rerun the code for retrying.") + + +class VMwareVMManager(VMManager): + def __init__(self, registry_path=REGISTRY_PATH): + self.registry_path = registry_path + self.lock = FileLock(LOCK_FILE_NAME, timeout=60) + self.initialize_registry() + + def initialize_registry(self): + with self.lock: # Locking during initialization + if not os.path.exists(self.registry_path): + with open(self.registry_path, 'w') as file: + file.write('') + + def add_vm(self, vm_path, lock_needed=True): + if lock_needed: + with self.lock: + self._add_vm(vm_path) + else: + self._add_vm(vm_path) + + def _add_vm(self, vm_path, region=None): + assert region in [None, 'local'], "For VMware provider, the region should be neither None or 'local'." + with self.lock: + with open(self.registry_path, 'r') as file: + lines = file.readlines() + new_lines = lines + [f'{vm_path}|free\n'] + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def occupy_vm(self, vm_path, pid, lock_needed=True): + if lock_needed: + with self.lock: + self._occupy_vm(vm_path, pid) + else: + self._occupy_vm(vm_path, pid) + + def _occupy_vm(self, vm_path, pid, region=None): + assert region in [None, 'local'], "For VMware provider, the region should be neither None or 'local'." + with self.lock: + new_lines = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + registered_vm_path, _ = line.strip().split('|') + if registered_vm_path == vm_path: + new_lines.append(f'{registered_vm_path}|{pid}\n') + else: + new_lines.append(line) + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def delete_vm(self, vm_path, lock_needed=True): + if lock_needed: + with self.lock: + self._delete_vm(vm_path) + else: + self._delete_vm(vm_path) + + def _delete_vm(self, vm_path): + raise NotImplementedError + + def check_and_clean(self, vms_dir, lock_needed=True): + if lock_needed: + with self.lock: + self._check_and_clean(vms_dir) + else: + self._check_and_clean(vms_dir) + + def _check_and_clean(self, vms_dir): + with self.lock: # Lock when cleaning up the registry and vms_dir + # Check and clean on the running vms, detect the released ones and mark then as 'free' + active_pids = {p.pid for p in psutil.process_iter()} + new_lines = [] + vm_paths = [] + + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + vm_path, pid_str = line.strip().split('|') + if not os.path.exists(vm_path): + logger.info(f"VM {vm_path} not found, releasing it.") + new_lines.append(f'{vm_path}|free\n') + continue + + vm_paths.append(vm_path) + if pid_str == "free": + new_lines.append(line) + continue + + if int(pid_str) in active_pids: + new_lines.append(line) + else: + new_lines.append(f'{vm_path}|free\n') + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + # Check and clean on the files inside vms_dir, delete the unregistered ones + os.makedirs(vms_dir, exist_ok=True) + vm_names = os.listdir(vms_dir) + for vm_name in vm_names: + # skip the downloaded .zip file + if vm_name == DOWNLOADED_FILE_NAME: + continue + # Skip the .DS_Store file on macOS + if vm_name == ".DS_Store": + continue + + flag = True + for vm_path in vm_paths: + if vm_name + ".vmx" in vm_path: + flag = False + if flag: + shutil.rmtree(os.path.join(vms_dir, vm_name)) + + def list_free_vms(self, lock_needed=True): + if lock_needed: + with self.lock: + return self._list_free_vms() + else: + return self._list_free_vms() + + def _list_free_vms(self): + with self.lock: # Lock when reading the registry + free_vms = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + vm_path, pid_str = line.strip().split('|') + if pid_str == "free": + free_vms.append((vm_path, pid_str)) + return free_vms + + def get_vm_path(self, os_type, region=None): + with self.lock: + if not VMwareVMManager.checked_and_cleaned: + VMwareVMManager.checked_and_cleaned = True + self._check_and_clean(vms_dir=VMS_DIR) + + allocation_needed = False + with self.lock: + free_vms_paths = self._list_free_vms() + if len(free_vms_paths) == 0: + # No free virtual machine available, generate a new one + allocation_needed = True + else: + # Choose the first free virtual machine + chosen_vm_path = free_vms_paths[0][0] + self._occupy_vm(chosen_vm_path, os.getpid()) + return chosen_vm_path + + if allocation_needed: + logger.info("No free virtual machine available. Generating a new one, which would take a while...☕") + new_vm_name = generate_new_vm_name(vms_dir=VMS_DIR, os_type=os_type) + + original_vm_name = None + if os_type == "Ubuntu": + original_vm_name = "Ubuntu" + elif os_type == "Windows": + original_vm_name = "Windows 10 x64" + + new_vm_path = _install_vm(new_vm_name, vms_dir=VMS_DIR, + downloaded_file_name=DOWNLOADED_FILE_NAME, original_vm_name=original_vm_name, os_type=os_type) + with self.lock: + self._add_vm(new_vm_path) + self._occupy_vm(new_vm_path, os.getpid()) + return new_vm_path diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/vmware/provider.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/vmware/provider.py new file mode 100644 index 000000000..29733a86f --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/providers/vmware/provider.py @@ -0,0 +1,103 @@ +import logging +import os +import platform +import subprocess +import time + +from skyrl_agent.tasks.osworld.desktop_env.providers.base import Provider + +logger = logging.getLogger("desktopenv.providers.vmware.VMwareProvider") +logger.setLevel(logging.INFO) + +WAIT_TIME = 3 + + +def get_vmrun_type(return_list=False): + if platform.system() == 'Windows' or platform.system() == 'Linux': + if return_list: + return ['-T', 'ws'] + else: + return '-T ws' + elif platform.system() == 'Darwin': # Darwin is the system name for macOS + if return_list: + return ['-T', 'fusion'] + else: + return '-T fusion' + else: + raise Exception("Unsupported operating system") + + +class VMwareProvider(Provider): + @staticmethod + def _execute_command(command: list, return_output=False): + process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + encoding="utf-8" + ) + + if return_output: + output = process.communicate()[0].strip() + return output + else: + return None + + def start_emulator(self, path_to_vm: str, headless: bool, os_type: str): + print("Starting VMware VM...") + logger.info("Starting VMware VM...") + + while True: + try: + output = subprocess.check_output(f"vmrun {get_vmrun_type()} list", shell=True, stderr=subprocess.STDOUT) + output = output.decode() + output = output.splitlines() + normalized_path_to_vm = os.path.abspath(os.path.normpath(path_to_vm)) + + if any(os.path.abspath(os.path.normpath(line)) == normalized_path_to_vm for line in output): + logger.info("VM is running.") + break + else: + logger.info("Starting VM...") + _command = ["vmrun"] + get_vmrun_type(return_list=True) + ["start", path_to_vm] + if headless: + _command.append("nogui") + VMwareProvider._execute_command(_command) + time.sleep(WAIT_TIME) + + except subprocess.CalledProcessError as e: + logger.error(f"Error executing command: {e.output.decode().strip()}") + + def get_ip_address(self, path_to_vm: str) -> str: + logger.info("Getting VMware VM IP address...") + while True: + try: + output = VMwareProvider._execute_command( + ["vmrun"] + get_vmrun_type(return_list=True) + ["getGuestIPAddress", path_to_vm, "-wait"], + return_output=True + ) + logger.info(f"VMware VM IP address: {output}") + return output + except Exception as e: + logger.error(e) + time.sleep(WAIT_TIME) + logger.info("Retrying to get VMware VM IP address...") + + def save_state(self, path_to_vm: str, snapshot_name: str): + logger.info("Saving VMware VM state...") + VMwareProvider._execute_command( + ["vmrun"] + get_vmrun_type(return_list=True) + ["snapshot", path_to_vm, snapshot_name]) + time.sleep(WAIT_TIME) # Wait for the VM to save + + def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str): + logger.info(f"Reverting VMware VM to snapshot: {snapshot_name}...") + VMwareProvider._execute_command( + ["vmrun"] + get_vmrun_type(return_list=True) + ["revertToSnapshot", path_to_vm, snapshot_name]) + time.sleep(WAIT_TIME) # Wait for the VM to revert + return path_to_vm + + def stop_emulator(self, path_to_vm: str): + logger.info("Stopping VMware VM...") + VMwareProvider._execute_command(["vmrun"] + get_vmrun_type(return_list=True) + ["stop", path_to_vm]) + time.sleep(WAIT_TIME) # Wait for the VM to stop diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/server/README.md b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/server/README.md new file mode 100644 index 000000000..edf79e756 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/server/README.md @@ -0,0 +1,657 @@ +# Server setup + +This README is useful if you want to set up your own machine for the environment. This README is not yet finished. Please contact the author if you need any assistance. + +## Configuration Overview + +The following sections contain guidelines for configuring the system image to ensure benchmark examples can run properly. + +The main configuration requirements include: + +1. **Account Credentials**: +Our benchmark configurations are based on specific username and password settings (with username `user` and password `password`). +Please ensure these settings remain consistent or update the corresponding configuration files. + +2. **Service Setup**: +Our environment operates through a service that automatically starts at boot time, as shown in the figure below. The service needs to be properly configured and placed. +![](https://os-world.github.io/static/images/env.png) + +3. **Accessibility Tree Support**: +Benchmark examples rely on accessibility tree functionality. The necessary support packages need to be installed. + +4. **System Service Management**: +Certain system services that may cause interference need to be disabled, such as automatic updates and notification pop-ups. + +5. **Required Software Installation**: +Ensure all necessary software packages required by the benchmark examples are properly installed. + +6. **Software Configuration**: +Various software packages require specific configurations, such as disabling certain auto-save features or installing additional plugins. + +7. **Port Configuration**: +To monitor and control software states from the host machine, specific port configurations are needed for various applications. + +8. **Miscellaneous Settings**: +Additional system-specific settings need to be configured, such as desktop environment settings and display resolution. + +Detailed instructions for each of these requirements will be provided in the following sections. + + +## [Ubuntu](https://huggingface.co/datasets/xlangai/ubuntu_osworld) + +Make a new VM with the Ubuntu 20.04 LTS image. + +### How to install Ubuntu Desktop (package: ubuntu-desktop) with GNOME desktop environment on Ubuntu 22.04 system. + +```bash +sudo apt update +sudo apt install ubuntu-desktop +sudo systemctl set-default graphical.target +``` + +### Account Credentials + +Download the iso file from the [Ubuntu website](https://ubuntu.com/download/alternative-downloads) and install it in the VM. + +Using GUI: +The default username should be `user` and the password should be `password` when you are asked to set up the account. Give the user sudo permission. + +Using Command Line: +```bash +sudo adduser user +usermod -aG sudo user +``` + +### Installation and Auto-login Setup + +1. Download the iso file from the [Ubuntu website](https://ubuntu.com/download/alternative-downloads) and install it in the VM. +The default username should be `user` and the password should be `password` when you are asked to set up the account. + +2. To enable automatic login: + +Using GUI: +```bash +# Open Settings -> Users +# Click Unlock button and enter password +# Toggle "Automatic Login" to ON for user 'user' +``` + +Or using Command Line: +```bash +# Edit the custom configuration file +sudo nano /etc/gdm3/custom.conf + +# Under [daemon] section, add or modify these lines: +AutomaticLoginEnable=true +AutomaticLogin=user + +# Save the file and restart the system +sudo systemctl restart gdm3 +``` + +After setting up automatic login, the system will boot directly into the desktop environment without requiring password input, which enables seamless startup experience for automated testing environments. + +### VNC Configuration + +1. Install x11vnc +``` +sudo apt update +sudo apt install x11vnc +``` + +2. Install noVNC +``` +sudo snap install novnc +``` + +3. Create system services for x11vnc and novnc: +- Go to directory cd `/etc/systemd/user/` +- Write a file `novnc.service` with the following content: +``` +[Unit] +Description=noVNC Service +After=x11vnc.service network.target snap.novnc.daemon.service +Wants=x11vnc.service + +[Service] +Type=simple +ExecStart=/snap/bin/novnc --vnc localhost:5900 --listen 5910 +Restart=on-failure +RestartSec=3 +Environment=DISPLAY=:0 +Environment=XAUTHORITY=/home/user/.Xauthority +Environment=SNAP_COOKIE=/run/snap.cookie +Environment=SNAP_NAME=novnc +Environment=SNAP_REVISION=current + +[Install] +WantedBy=default.target +``` +Write a file `x11vnc.service` with the following content: +``` +[Unit] +Description=X11 VNC Server +After=display-manager.service network.target +Wants=display-manager.service + +[Service] +Type=simple +ExecStart=x11vnc -display :0 -rfbport 5900 -forever +Restart=on-failure +RestartSec=3 +Environment=DISPLAY=:0 +Environment=XAUTHORITY=/home/user/.Xauthority + +[Install] +WantedBy=default.target +``` + +4. Enable both services: +``` +systemctl --user daemon-reload +systemctl --user enable novnc.service +systemctl --user enable x11vnc.service +systemctl --user start x11vnc.service +systemctl --user start novnc.service +``` + +5. Allow VNC port: +Expose 5910 port in the firewall and any other security tools you are using. + +6. Access the VNC server: +Connect to the VNC via `http://[Instance IP]:5910/vnc.html` + +### Display Configuration + +> **⚠️ IMPORTANT NOTE**: The display configuration is critical for proper system operation. Incorrect settings can prevent the graphical environment from starting and potentially crash the X server. Make sure to follow these steps carefully and verify each configuration file. If you encounter any issues, check the X server logs at `/var/log/Xorg.0.log` for troubleshooting. Backup your existing X11 configuration before making any changes. + + +1. Install dummy video driver: +``` +sudo apt-get install xserver-xorg-video-dummy +``` + +Go to `/etc/X11/` and create a file named `xorg.conf` with the following content: +``` +Section "ServerLayout" + Identifier "X.org Configured" + Screen 0 "Screen0" 0 0 + InputDevice "Mouse0" "CorePointer" + InputDevice "Keyboard0" "CoreKeyboard" +EndSection + +Section "Files" + ModulePath "/usr/lib/xorg/modules" + FontPath "/usr/share/fonts/X11/misc" + FontPath "/usr/share/fonts/X11/cyrillic" + FontPath "/usr/share/fonts/X11/100dpi/:unscaled" + FontPath "/usr/share/fonts/X11/75dpi/:unscaled" + FontPath "/usr/share/fonts/X11/Type1" + FontPath "/usr/share/fonts/X11/100dpi" + FontPath "/usr/share/fonts/X11/75dpi" + FontPath "built-ins" +EndSection + +Section "Module" + Load "glx" +EndSection + +Section "InputDevice" + Identifier "Keyboard0" + Driver "kbd" +EndSection + +Section "InputDevice" + Identifier "Mouse0" + Driver "mouse" + Option "Protocol" "auto" + Option "Device" "/dev/input/mice" + Option "ZAxisMapping" "4 5 6 7" +EndSection + +Section "Monitor" + Identifier "Monitor0" + VendorName "Monitor Vendor" + ModelName "Monitor Model" + HorizSync 28.0-80.0 + VertRefresh 48.0-75.0 +EndSection + +Section "Device" + ### Available Driver options are:- + ### Values: : integer, : float, : "True"/"False", + ### : "String", : " Hz/kHz/MHz", + ### : "%" + ### [arg]: arg optional + #Option "SWcursor" # [] + #Option "kmsdev" # + #Option "ShadowFB" # [] + #Option "AccelMethod" # + #Option "PageFlip" # [] + #Option "ZaphodHeads" # + #Option "DoubleShadow" # [] + #Option "Atomic" # [] + #Option "VariableRefresh" # [] + #Option "UseGammaLUT" # [] + #Option "AsyncFlipSecondaries" # [] + Identifier "Card0" + Driver "modesetting" + BusID "PCI:0:30:0" + VideoRam 256000 +EndSection + +Section "Screen" + Identifier "Screen0" + Device "Device0" + Monitor "Monitor0" + DefaultDepth 24 + SubSection "Display" + Depth 24 + Modes "1920x1080" + EndSubSection +EndSection +``` + +2. In the same directory as the previous step, go to its sub-directory `xorg.conf.d` , create a file named `10-dummy.conf` with the following content: +``` +Section "Device" + Identifier "DummyDevice" + Driver "dummy" + VideoRam 32768 +EndSection + +Section "Monitor" + Identifier "DummyMonitor" + HorizSync 28.0-80.0 + VertRefresh 48.0-75.0 + Modeline "1920x1080" 172.80 1920 2048 2248 2576 1080 1083 1088 1120 +EndSection + +Section "Screen" + Identifier "DummyScreen" + Device "DummyDevice" + Monitor "DummyMonitor" + DefaultDepth 24 + SubSection "Display" + Depth 24 + Modes "1920x1080" + EndSubSection +EndSection +``` + +3. Reload the display manager: +``` +sudo systemctl restart display-manager +``` + +### Set up the OSWorld server service in VM +Upload the OSWorld server to the home directory (/home/user) of user (via scp or git clone). + +1. Copy the `main.py` and `pyxcursor.py` and to the `/home/user-name` where the `user-name` is your username of the ubuntu, here we make it `user` as default. If you customize the path of placing these files in this step, you should change the parameters in the service file we will mention later accordingly. + +2. First please set up the environment: +```shell +sudo apt install python3 +pip3 install -r requirements.txt +sudo apt-get install python3-tk python3-dev +sudo apt install gnome-screenshot +sudo apt install wmctrl +sudo apt install ffmpeg +sudo apt install socat +sudo apt install xclip +``` + +If you encounter an error about python not being found, run: +``` +sudo ln -s /usr/bin/python3 /usr/bin/python +``` + +if you customize the environment in this step, you should change the parameters in the service file we will mention later accordingly. + +3. Due to some configuration issues, you need to modify the `osworld_server.service` file: +1) In our released version, the X server is set to :1, but the default X server is actually :0. You need to modify the `osworld_server.service` file to change the `DISPLAY` variable from `:1` to `:0`. +Change the following line: +``` +Environment="DISPLAY=:1" +``` +to +``` +Environment="DISPLAY=:0" +``` +2) Need to add environment variables to enable DBUS to change wallpaper. +Change the following line: +``` +Environment="DISPLAY=:0" +``` +to +``` +Environment="DISPLAY=:0;DBUS_SESSION_BUS_ADDRESS=unix:path=/run/user/1000/bus" +``` + +4. Copy the `osworld_server.service` to the systemd configuration directory at `/etc/systemd/system/`: +```shell +sudo cp osworld_server.service /etc/systemd/system/ +``` + +Reload the systemd daemon to recognize the new service: +```shell +sudo systemctl daemon-reload +``` + +Enable the service to start on boot: +```shell +sudo systemctl enable osworld_server.service +``` + +Start the service: +```shell +sudo systemctl start osworld_server.service +``` + +Verify the service is running correctly: +```shell +sudo systemctl status osworld_server.service +``` + +You should see output indicating the service is active and running. If there are errors, review the logs with `journalctl -xe` for further troubleshooting. + +If you need to make adjustments to the service configuration, you can edit the `/etc/systemd/system/osworld_server.service` file: +```shell +sudo nano /etc/systemd/system/osworld_server.service +``` + +After making changes, reload the daemon and restart the service: +```shell +sudo systemctl daemon-reload +sudo systemctl enable osworld_server.service +sudo systemctl start osworld_server.service +``` + +### Accessibility Tree Support + +To support the accessibility tree functionality, you'll need to install pyastpi2 in your Ubuntu environment. This package enables access to accessibility information and tree structures. + +Installation steps: + +```bash +# Update package list and ensure pip is installed +sudo apt-get update +sudo apt-get install python3-pip + +# Install pyastpi2 using pip +pip3 install pyastpi2 +``` + +### Xorg Configuration + +Regarding the graphics display system, we need to ensure that Ubuntu displays images using the **Xorg** protocol instead of **Wayland**. Since **Wayland** is typically the default setting for Ubuntu, we will need to manually change the settings. + +1. Click the user menu in the upper right corner and select "Log Out" or "Sign Off." +2. On the login screen, click on the username. +3. Before entering the password, click the gear icon in the lower right or upper right corner of the screen (it may need to be displayed after clicking the username first). +4. Select "Ubuntu on Xorg" from the pop-up menu. + +You can run the following command to check if **Xorg** is being used: + +```bash +echo $XDG_SESSION_TYPE +``` + +### System Service Management (Optional) + +The automatic software update service can interfere with benchmark examples. To disable this service, you can refer to the https://www.makeuseof.com/disable-automatic-updates-in-ubuntu/ for the solution. + +You can check and manage system services using systemctl commands. For example, to verify if a service like unattended-upgrades is installed and running on your system: + +```bash +# Check service status +sudo systemctl status unattended-upgrades.service +``` + +If the output is `x11`, it means you have switched to **Xorg**. + +To disable a system service: +```bash +# Disable and stop the service +sudo systemctl disable unattended-upgrades +sudo systemctl stop unattended-upgrades +``` + +To verify service configurations, you can use apt-config: +```bash +# Check current configurations +apt-config dump APT::Periodic::Update-Package-Lists +apt-config dump APT::Periodic::Unattended-Upgrade +``` + + +### Software Installation + +#### Software Installation Source +Since for some examples like change the settings of certain software, we hardcode some paths in our evaluation file, which means you need to install the software to the specific path. Here we provide a list of software that you need to install and the certain source which default the path you should install them to. + +1. Chrome: If you are using ARM System, download the chromium using `sudo snap install chromium` and make sure your Chromium config files are under `~/snap/chromium`; otherwise, download the chrome from the [Chromium](https://www.chromium.org/Home) and make sure your Chromium config files are under `~/.config/google-chrome`. +2. LibreOffice: Go to [LibreOffice Website](https://www.libreoffice.org/), select "Download Libreoffice", select "older versions" in the bottom of the page, and download `7.3.7.2` version. +3. GIMP: Search "GIMP" in "Ubuntu Software" and install it. Our GIMP version is `2.10.30`. +4. VLC: Search "VLC" in "Ubuntu Software" and install it. Our VLC version is `3.0.16`. +5. VSCode: Go to [VSCode Website](https://code.visualstudio.com/download), download the `.deb` file, and install it. Our VSCode version is `1.91.1`. + +#### Additional Inner Software Installation +> **⚠️ IMPORTANT NOTE**: The software installation and configuration steps described in this section are crucial for maintaining consistent task execution and performance. Skipping or incorrectly configuring these components may lead to task failures or degraded performance. Please follow the installation instructions carefully and verify each component is properly set up before proceeding. + + +##### LibreOffice font installation +Some examples in LibreOffice Impress use non-default system fonts, and you need to download the corresponding **TTF files** and put them in the system fonts directory. +[Here](https://drive.usercontent.google.com/download?id=1UzmdsfUQRTnvCxkvWrKguwZM3G5eQk87&export=download&authuser=0&confirm=t&uuid=70b9fbb7-9585-4aa4-a2c0-a7d6126469a0&at=AEz70l4rdEjdxBpqkLyW9lcil6S5:1740142224052) we provides all the fonts downloaded, just download it, and unzip to the system fonts directory (which usually `usr/share/fonts/`). +```bash +unzip fonts.zip -d /usr/share/fonts/ +``` + +And then run the following command to refresh the font cache: +```bash +sudo fc-cache -fv +``` + +##### Customized Plugin Installation + +**VS Code plugin installation:** +To extract relevant internal information and configurations from the VS Code environment, we principally leverage the capabilities offered by the VS Code Extension API. Here's how to install the extension developed by ourselves: +```bash +1. Download the extension from: https://github.com/xlang-ai/OSWorld/blob/04a9df627c7033fab991806200877a655e895bfd/vscodeEvalExtension/eval-0.0.1.vsix +2. Open VS Code +3. Go to Extensions -> ... -> Install from VSIX... -> choose the downloaded eval-0.0.1.vsix file +``` + + +### Software Configuration +1. LibreOffice Default Format Settings: +```bash +# Open LibreOffice Writer/Calc/Impress +# Go to Tools -> Options -> Load/Save -> General +# Under "Default file format and ODF settings": +# Change "Document type" to "Text document" +# Set "Always save as" to "Word 2007-365 (.docx)" +# Change "Document type" to "Spreadsheet" +# Set "Always save as" to "Excel 2007-365 (.xlsx)" +# Change "Document type" to "Presentation" +# Set "Always save as" to "PowerPoint 2007-365 (.pptx)" +``` + +2. Chrome password requirement removal: +Chrome requests a password input when first opened after system startup, which can interfere with our experiments. Here's how to disable this feature: + +```bash +# Prevent Chrome from using keyring +mkdir -p ~/.local/share/keyrings +touch ~/.local/share/keyrings/login.keyring +``` + +Or you can use any ways to disable the keyring service, which will prevent Chrome from requesting a password input. + + +### Network Configuration + +#### Firewall Configuration + +In OSWorld, we need the following ports to be open: +``` +server_port = 5000 +chromium_port = 9222 +vnc_port = 8006 +vlc_port = 8080 +novnc_port = 5910 +``` + +Please open the corresponding ports in the firewall and any other security tools you are using. + +#### socat Installation + +Ensure `socat` is installed to enable port forwarding. + +```sh +sudo apt install socat +``` + +#### Network Configuration for Remote Control + +##### VLC Configuration +To enable remote control of VLC media player, follow these configuration steps: + +1. Enable HTTP interface: +```bash +# Open VLC +# Go to Tools -> Preferences +# Show Settings: All (bottom left) +# Navigate to Interface -> Main interfaces +# Check 'Web' option +``` + +2. Configure HTTP interface settings: +```bash +# Still in Preferences +# Navigate to Interface -> Main interfaces -> Lua +# Under Lua HTTP: +# - Set Password to 'password' +``` + +The following is the screenshot of the VLC configuration: +![vlc_configuration](https://os-world.github.io/static/images/vlc_configuration.png) +When VLC is open, the service will be running on port 8080. + +##### Chrome Configuration +To ensure Chrome uses consistent debugging ports even after being closed and reopened, follow these steps: + +1. Create or edit Chrome desktop entry: +```bash +sudo nano /usr/share/applications/google-chrome.desktop +``` + +2. Modify the Exec lines to include debugging port: +```bash +# Find lines starting with "Exec=" and add the following flags: +--remote-debugging-port=1337 --remote-debugging-address=0.0.0.0 +``` + +In cases where need Chrome, the 1337 will be forwarded to 9222 in the virtual machine via socat. + + +### Miscellaneous Settings + +#### Screen Resolution + +The required screen resolution for the virtual machine is 1920x1080 in OSWorld and we did make some hardcode related to this resolution in our configuration file in some examples, but only a few. +So please set the screen resolution to 1920x1080 in the virtual machine settings. + +#### Automatic Suspend + +To close automatic suspend, open Setting app and enter "Power" section. Switch "Screen Blank" to "Never" and "Automatic Suspend" to "Off". + +#### Additional Installation + +Activating the window manager control requires the installation of `wmctrl`: +```bash +sudo apt install wmctrl +``` +Otherwise, you cannot control the window manager in the virtual machine when running the experiments. Some cases will be effected. + +To enable recording in the virtual machine, you need to install `ffmpeg`: +```bash +sudo apt install ffmpeg +``` +Otherwise you cannot get the video recording of the virtual machine when running the experiments. + + +### Others Information + +#### About the Converted Accessibility Tree + +For several applications like Firefox or Thunderbird, you should first enable + +```sh +gsettings set org.gnome.desktop.interface toolkit-accessibility true +``` + +to see their accessibility tree. + +##### Example of AT + +An example of a node: + +```xml +
+ 歡迎使用新的 Outlook.com 帳戶 +
+``` + +An example of a tree: + +```xml + + + ... + + ... + +``` + +##### Useful attributes + +1. `name` - shows the name of application, title of window, or name of some + component +2. `attr:class` - somewhat the same role as `class` in HTML +3. `attr:id` - somewhat the same role as `id` in HTML +4. `cp:screencoord` - absolute coordinator on the screen +5. `cp:windowcoord` - relative coordinator in the window +6. `cp:size` - the size + +Also several states like `st:enabled` and `st:visible` can be indicated. A full +state list is available at +. + +##### How to use it in evaluation + +See example `thunderbird/12086550-11c0-466b-b367-1d9e75b3910e.json` and +function `check_accessibility_tree` in `metrics/general.py`. You can use CSS +selector or XPath to reference a target nodes. You can also check its text +contents. + +An example of a CSS selector: + +```css +application[name=Thunderbird] page-tab-list[attr|id="tabmail-tabs"]>page-tab[name="About Profiles"] +``` + +This selector will select the page tab of profile manager in Thunderbird (if open). + +For usage of CSS selector: . For usage of XPath: . + +##### Manual check + +You can use accerciser to check the accessibility tree on GNOME VM. + +```sh +sudo apt install accerciser +``` + +## [MacOS](https://huggingface.co/datasets/xlangai/macos_osworld) +Coming soon... diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/server/main.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/server/main.py new file mode 100644 index 000000000..606d0ac11 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/server/main.py @@ -0,0 +1,1570 @@ +import ctypes +import os +import platform +import shlex +import json +import subprocess, signal +import time +from pathlib import Path +from typing import Any, Optional, Sequence +from typing import List, Dict, Tuple, Literal +import concurrent.futures + +import Xlib +import lxml.etree +import pyautogui +import requests +import re +from PIL import Image, ImageGrab +from Xlib import display, X +from flask import Flask, request, jsonify, send_file, abort # , send_from_directory +from lxml.etree import _Element + +platform_name: str = platform.system() + +if platform_name == "Linux": + import pyatspi + from pyatspi import Accessible, StateType, STATE_SHOWING + from pyatspi import Action as ATAction + from pyatspi import Component # , Document + from pyatspi import Text as ATText + from pyatspi import Value as ATValue + + BaseWrapper = Any + +elif platform_name == "Windows": + from pywinauto import Desktop + from pywinauto.base_wrapper import BaseWrapper + import pywinauto.application + import win32ui, win32gui + + Accessible = Any + +elif platform_name == "Darwin": + import plistlib + + import AppKit + import ApplicationServices + import Foundation + import Quartz + import oa_atomacos + + Accessible = Any + BaseWrapper = Any + +else: + # Platform not supported + Accessible = None + BaseWrapper = Any + +from pyxcursor import Xcursor + +# todo: need to reformat and organize this whole file + +app = Flask(__name__) + +pyautogui.PAUSE = 0 +pyautogui.DARWIN_CATCH_UP_TIME = 0 + +TIMEOUT = 1800 # seconds + +logger = app.logger +recording_process = None # fixme: this is a temporary solution for recording, need to be changed to support multiple-process +recording_path = "/tmp/recording.mp4" + + +@app.route('/setup/execute', methods=['POST']) +@app.route('/execute', methods=['POST']) +def execute_command(): + data = request.json + # The 'command' key in the JSON request should contain the command to be executed. + shell = data.get('shell', False) + command = data.get('command', "" if shell else []) + + if isinstance(command, str) and not shell: + command = shlex.split(command) + + # Expand user directory + for i, arg in enumerate(command): + if arg.startswith("~/"): + command[i] = os.path.expanduser(arg) + + # Execute the command without any safety checks. + try: + if platform_name == "Windows": + flags = subprocess.CREATE_NO_WINDOW + else: + flags = 0 + result = subprocess.run( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=shell, + text=True, + timeout=120, + creationflags=flags, + ) + return jsonify({ + 'status': 'success', + 'output': result.stdout, + 'error': result.stderr, + 'returncode': result.returncode + }) + except Exception as e: + return jsonify({ + 'status': 'error', + 'message': str(e) + }), 500 + + +@app.route('/setup/execute_with_verification', methods=['POST']) +@app.route('/execute_with_verification', methods=['POST']) +def execute_command_with_verification(): + """Execute command and verify the result based on provided verification criteria""" + data = request.json + shell = data.get('shell', False) + command = data.get('command', "" if shell else []) + verification = data.get('verification', {}) + max_wait_time = data.get('max_wait_time', 10) # Maximum wait time in seconds + check_interval = data.get('check_interval', 1) # Check interval in seconds + + if isinstance(command, str) and not shell: + command = shlex.split(command) + + # Expand user directory + for i, arg in enumerate(command): + if arg.startswith("~/"): + command[i] = os.path.expanduser(arg) + + # Execute the main command + try: + if platform_name == "Windows": + flags = subprocess.CREATE_NO_WINDOW + else: + flags = 0 + result = subprocess.run( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=shell, + text=True, + timeout=120, + creationflags=flags, + ) + + # If no verification is needed, return immediately + if not verification: + return jsonify({ + 'status': 'success', + 'output': result.stdout, + 'error': result.stderr, + 'returncode': result.returncode + }) + + # Wait and verify the result + import time + start_time = time.time() + while time.time() - start_time < max_wait_time: + verification_passed = True + + # Check window existence if specified + if 'window_exists' in verification: + window_name = verification['window_exists'] + try: + if platform_name == 'Linux': + wmctrl_result = subprocess.run(['wmctrl', '-l'], + capture_output=True, text=True, check=True) + if window_name.lower() not in wmctrl_result.stdout.lower(): + verification_passed = False + elif platform_name in ['Windows', 'Darwin']: + import pygetwindow as gw + windows = gw.getWindowsWithTitle(window_name) + if not windows: + verification_passed = False + except Exception: + verification_passed = False + + # Check command execution if specified + if 'command_success' in verification: + verify_cmd = verification['command_success'] + try: + verify_result = subprocess.run(verify_cmd, shell=True, + capture_output=True, text=True, timeout=5) + if verify_result.returncode != 0: + verification_passed = False + except Exception: + verification_passed = False + + if verification_passed: + return jsonify({ + 'status': 'success', + 'output': result.stdout, + 'error': result.stderr, + 'returncode': result.returncode, + 'verification': 'passed', + 'wait_time': time.time() - start_time + }) + + time.sleep(check_interval) + + # Verification failed + return jsonify({ + 'status': 'verification_failed', + 'output': result.stdout, + 'error': result.stderr, + 'returncode': result.returncode, + 'verification': 'failed', + 'wait_time': max_wait_time + }), 500 + + except Exception as e: + return jsonify({ + 'status': 'error', + 'message': str(e) + }), 500 + + +def _get_machine_architecture() -> str: + """ Get the machine architecture, e.g., x86_64, arm64, aarch64, i386, etc. + """ + architecture = platform.machine().lower() + if architecture in ['amd32', 'amd64', 'x86', 'x86_64', 'x86-64', 'x64', 'i386', 'i686']: + return 'amd' + elif architecture in ['arm64', 'aarch64', 'aarch32']: + return 'arm' + else: + return 'unknown' + + +@app.route('/setup/launch', methods=["POST"]) +def launch_app(): + data = request.json + shell = data.get("shell", False) + command: List[str] = data.get("command", "" if shell else []) + + if isinstance(command, str) and not shell: + command = shlex.split(command) + + # Expand user directory + for i, arg in enumerate(command): + if arg.startswith("~/"): + command[i] = os.path.expanduser(arg) + + try: + if 'google-chrome' in command and _get_machine_architecture() == 'arm': + index = command.index('google-chrome') + command[index] = 'chromium' # arm64 chrome is not available yet, can only use chromium + subprocess.Popen(command, shell=shell) + return "{:} launched successfully".format(command if shell else " ".join(command)) + except Exception as e: + return jsonify({"status": "error", "message": str(e)}), 500 + + +@app.route('/screenshot', methods=['GET']) +def capture_screen_with_cursor(): + # fixme: when running on virtual machines, the cursor is not captured, don't know why + + file_path = os.path.join(os.path.dirname(__file__), "screenshots", "screenshot.png") + user_platform = platform.system() + + # Ensure the screenshots directory exists + os.makedirs(os.path.dirname(file_path), exist_ok=True) + + # fixme: This is a temporary fix for the cursor not being captured on Windows and Linux + if user_platform == "Windows": + def get_cursor(): + hcursor = win32gui.GetCursorInfo()[1] + hdc = win32ui.CreateDCFromHandle(win32gui.GetDC(0)) + hbmp = win32ui.CreateBitmap() + hbmp.CreateCompatibleBitmap(hdc, 36, 36) + hdc = hdc.CreateCompatibleDC() + hdc.SelectObject(hbmp) + hdc.DrawIcon((0,0), hcursor) + + bmpinfo = hbmp.GetInfo() + bmpstr = hbmp.GetBitmapBits(True) + cursor = Image.frombuffer('RGB', (bmpinfo['bmWidth'], bmpinfo['bmHeight']), bmpstr, 'raw', 'BGRX', 0, 1).convert("RGBA") + + win32gui.DestroyIcon(hcursor) + win32gui.DeleteObject(hbmp.GetHandle()) + hdc.DeleteDC() + + pixdata = cursor.load() + + width, height = cursor.size + for y in range(height): + for x in range(width): + if pixdata[x, y] == (0, 0, 0, 255): + pixdata[x, y] = (0, 0, 0, 0) + + hotspot = win32gui.GetIconInfo(hcursor)[1:3] + + return (cursor, hotspot) + + ratio = ctypes.windll.shcore.GetScaleFactorForDevice(0) / 100 + + img = ImageGrab.grab(bbox=None, include_layered_windows=True) + + try: + cursor, (hotspotx, hotspoty) = get_cursor() + + pos_win = win32gui.GetCursorPos() + pos = (round(pos_win[0]*ratio - hotspotx), round(pos_win[1]*ratio - hotspoty)) + + img.paste(cursor, pos, cursor) + except Exception as e: + logger.warning(f"Failed to capture cursor on Windows, screenshot will not have a cursor. Error: {e}") + + img.save(file_path) + elif user_platform == "Linux": + cursor_obj = Xcursor() + imgarray = cursor_obj.getCursorImageArrayFast() + cursor_img = Image.fromarray(imgarray) + screenshot = pyautogui.screenshot() + cursor_x, cursor_y = pyautogui.position() + screenshot.paste(cursor_img, (cursor_x, cursor_y), cursor_img) + screenshot.save(file_path) + elif user_platform == "Darwin": # (Mac OS) + # Use the screencapture utility to capture the screen with the cursor + subprocess.run(["screencapture", "-C", file_path]) + else: + logger.warning(f"The platform you're using ({user_platform}) is not currently supported") + + return send_file(file_path, mimetype='image/png') + + +def _has_active_terminal(desktop: Accessible) -> bool: + """ A quick check whether the terminal window is open and active. + """ + for app in desktop: + if app.getRoleName() == "application" and app.name == "gnome-terminal-server": + for frame in app: + if frame.getRoleName() == "frame" and frame.getState().contains(pyatspi.STATE_ACTIVE): + return True + return False + + +@app.route('/terminal', methods=['GET']) +def get_terminal_output(): + user_platform = platform.system() + output: Optional[str] = None + try: + if user_platform == "Linux": + desktop: Accessible = pyatspi.Registry.getDesktop(0) + if _has_active_terminal(desktop): + desktop_xml: _Element = _create_atspi_node(desktop) + # 1. the terminal window (frame of application is st:active) is open and active + # 2. the terminal tab (terminal status is st:focused) is focused + xpath = '//application[@name="gnome-terminal-server"]/frame[@st:active="true"]//terminal[@st:focused="true"]' + terminals: List[_Element] = desktop_xml.xpath(xpath, namespaces=_accessibility_ns_map_ubuntu) + output = terminals[0].text.rstrip() if len(terminals) == 1 else None + else: # windows and macos platform is not implemented currently + # raise NotImplementedError + return "Currently not implemented for platform {:}.".format(platform.platform()), 500 + return jsonify({"output": output, "status": "success"}) + except Exception as e: + logger.error("Failed to get terminal output. Error: %s", e) + return jsonify({"status": "error", "message": str(e)}), 500 + + +_accessibility_ns_map = { + "ubuntu": { + "st": "https://accessibility.ubuntu.example.org/ns/state", + "attr": "https://accessibility.ubuntu.example.org/ns/attributes", + "cp": "https://accessibility.ubuntu.example.org/ns/component", + "doc": "https://accessibility.ubuntu.example.org/ns/document", + "docattr": "https://accessibility.ubuntu.example.org/ns/document/attributes", + "txt": "https://accessibility.ubuntu.example.org/ns/text", + "val": "https://accessibility.ubuntu.example.org/ns/value", + "act": "https://accessibility.ubuntu.example.org/ns/action", + }, + "windows": { + "st": "https://accessibility.windows.example.org/ns/state", + "attr": "https://accessibility.windows.example.org/ns/attributes", + "cp": "https://accessibility.windows.example.org/ns/component", + "doc": "https://accessibility.windows.example.org/ns/document", + "docattr": "https://accessibility.windows.example.org/ns/document/attributes", + "txt": "https://accessibility.windows.example.org/ns/text", + "val": "https://accessibility.windows.example.org/ns/value", + "act": "https://accessibility.windows.example.org/ns/action", + "class": "https://accessibility.windows.example.org/ns/class" + }, + "macos": { + "st": "https://accessibility.macos.example.org/ns/state", + "attr": "https://accessibility.macos.example.org/ns/attributes", + "cp": "https://accessibility.macos.example.org/ns/component", + "doc": "https://accessibility.macos.example.org/ns/document", + "txt": "https://accessibility.macos.example.org/ns/text", + "val": "https://accessibility.macos.example.org/ns/value", + "act": "https://accessibility.macos.example.org/ns/action", + "role": "https://accessibility.macos.example.org/ns/role", + } + +} + +_accessibility_ns_map_ubuntu = _accessibility_ns_map['ubuntu'] +_accessibility_ns_map_windows = _accessibility_ns_map['windows'] +_accessibility_ns_map_macos = _accessibility_ns_map['macos'] + +# A11y tree getter for Ubuntu +libreoffice_version_tuple: Optional[Tuple[int, ...]] = None +MAX_DEPTH = 50 +MAX_WIDTH = 1024 +MAX_CALLS = 5000 + + +def _get_libreoffice_version() -> Tuple[int, ...]: + """Function to get the LibreOffice version as a tuple of integers.""" + result = subprocess.run("libreoffice --version", shell=True, text=True, stdout=subprocess.PIPE) + version_str = result.stdout.split()[1] # Assuming version is the second word in the command output + return tuple(map(int, version_str.split("."))) + + +def _create_atspi_node(node: Accessible, depth: int = 0, flag: Optional[str] = None) -> _Element: + node_name = node.name + attribute_dict: Dict[str, Any] = {"name": node_name} + + # States + states: List[StateType] = node.getState().get_states() + for st in states: + state_name: str = StateType._enum_lookup[st] + state_name: str = state_name.split("_", maxsplit=1)[1].lower() + if len(state_name) == 0: + continue + attribute_dict["{{{:}}}{:}".format(_accessibility_ns_map_ubuntu["st"], state_name)] = "true" + + # Attributes + attributes: Dict[str, str] = node.get_attributes() + for attribute_name, attribute_value in attributes.items(): + if len(attribute_name) == 0: + continue + attribute_dict["{{{:}}}{:}".format(_accessibility_ns_map_ubuntu["attr"], attribute_name)] = attribute_value + + # Component + if attribute_dict.get("{{{:}}}visible".format(_accessibility_ns_map_ubuntu["st"]), "false") == "true" \ + and attribute_dict.get("{{{:}}}showing".format(_accessibility_ns_map_ubuntu["st"]), "false") == "true": + try: + component: Component = node.queryComponent() + except NotImplementedError: + pass + else: + bbox: Sequence[int] = component.getExtents(pyatspi.XY_SCREEN) + attribute_dict["{{{:}}}screencoord".format(_accessibility_ns_map_ubuntu["cp"])] = \ + str(tuple(bbox[0:2])) + attribute_dict["{{{:}}}size".format(_accessibility_ns_map_ubuntu["cp"])] = str(tuple(bbox[2:])) + + text = "" + # Text + try: + text_obj: ATText = node.queryText() + # only text shown on current screen is available + # attribute_dict["txt:text"] = text_obj.getText(0, text_obj.characterCount) + text: str = text_obj.getText(0, text_obj.characterCount) + # if flag=="thunderbird": + # appeared in thunderbird (uFFFC) (not only in thunderbird), "Object + # Replacement Character" in Unicode, "used as placeholder in text for + # an otherwise unspecified object; uFFFD is another "Replacement + # Character", just in case + text = text.replace("\ufffc", "").replace("\ufffd", "") + except NotImplementedError: + pass + + # Image, Selection, Value, Action + try: + node.queryImage() + attribute_dict["image"] = "true" + except NotImplementedError: + pass + + try: + node.querySelection() + attribute_dict["selection"] = "true" + except NotImplementedError: + pass + + try: + value: ATValue = node.queryValue() + value_key = f"{{{_accessibility_ns_map_ubuntu['val']}}}" + + for attr_name, attr_func in [ + ("value", lambda: value.currentValue), + ("min", lambda: value.minimumValue), + ("max", lambda: value.maximumValue), + ("step", lambda: value.minimumIncrement) + ]: + try: + attribute_dict[f"{value_key}{attr_name}"] = str(attr_func()) + except: + pass + except NotImplementedError: + pass + + try: + action: ATAction = node.queryAction() + for i in range(action.nActions): + action_name: str = action.getName(i).replace(" ", "-") + attribute_dict[ + "{{{:}}}{:}_desc".format(_accessibility_ns_map_ubuntu["act"], action_name)] = action.getDescription( + i) + attribute_dict[ + "{{{:}}}{:}_kb".format(_accessibility_ns_map_ubuntu["act"], action_name)] = action.getKeyBinding(i) + except NotImplementedError: + pass + + # Add from here if we need more attributes in the future... + + raw_role_name: str = node.getRoleName().strip() + node_role_name = (raw_role_name or "unknown").replace(" ", "-") + + if not flag: + if raw_role_name == "document spreadsheet": + flag = "calc" + if raw_role_name == "application" and node.name == "Thunderbird": + flag = "thunderbird" + + xml_node = lxml.etree.Element( + node_role_name, + attrib=attribute_dict, + nsmap=_accessibility_ns_map_ubuntu + ) + + if len(text) > 0: + xml_node.text = text + + if depth == MAX_DEPTH: + logger.warning("Max depth reached") + return xml_node + + if flag == "calc" and node_role_name == "table": + # Maximum column: 1024 if ver<=7.3 else 16384 + # Maximum row: 104 8576 + # Maximun sheet: 1 0000 + + global libreoffice_version_tuple + MAXIMUN_COLUMN = 1024 if libreoffice_version_tuple < (7, 4) else 16384 + MAX_ROW = 104_8576 + + index_base = 0 + first_showing = False + column_base = None + for r in range(MAX_ROW): + for clm in range(column_base or 0, MAXIMUN_COLUMN): + child_node: Accessible = node[index_base + clm] + showing: bool = child_node.getState().contains(STATE_SHOWING) + if showing: + child_node: _Element = _create_atspi_node(child_node, depth + 1, flag) + if not first_showing: + column_base = clm + first_showing = True + xml_node.append(child_node) + elif first_showing and column_base is not None or clm >= 500: + break + if first_showing and clm == column_base or not first_showing and r >= 500: + break + index_base += MAXIMUN_COLUMN + return xml_node + else: + try: + for i, ch in enumerate(node): + if i == MAX_WIDTH: + logger.warning("Max width reached") + break + xml_node.append(_create_atspi_node(ch, depth + 1, flag)) + except: + logger.warning("Error occurred during children traversing. Has Ignored. Node: %s", + lxml.etree.tostring(xml_node, encoding="unicode")) + return xml_node + + +# A11y tree getter for Windows +def _create_pywinauto_node(node, nodes, depth: int = 0, flag: Optional[str] = None) -> _Element: + nodes = nodes or set() + if node in nodes: + return + nodes.add(node) + + attribute_dict: Dict[str, Any] = {"name": node.element_info.name} + + base_properties = {} + try: + base_properties.update( + node.get_properties()) # get all writable/not writable properties, but have bugs when landing on chrome and it's slower! + except: + logger.debug("Failed to call get_properties(), trying to get writable properites") + try: + _element_class = node.__class__ + + class TempElement(node.__class__): + writable_props = pywinauto.base_wrapper.BaseWrapper.writable_props + + # Instantiate the subclass + node.__class__ = TempElement + # Retrieve properties using get_properties() + properties = node.get_properties() + node.__class__ = _element_class + + base_properties.update(properties) # only get all writable properties + logger.debug("get writable properties") + except Exception as e: + logger.error(e) + pass + + # Count-cnt + for attr_name in ["control_count", "button_count", "item_count", "column_count"]: + try: + attribute_dict[f"{{{_accessibility_ns_map_windows['cnt']}}}{attr_name}"] = base_properties[ + attr_name].lower() + except: + pass + + # Columns-cols + try: + attribute_dict[f"{{{_accessibility_ns_map_windows['cols']}}}columns"] = base_properties["columns"].lower() + except: + pass + + # Id-id + for attr_name in ["control_id", "automation_id", "window_id"]: + try: + attribute_dict[f"{{{_accessibility_ns_map_windows['id']}}}{attr_name}"] = base_properties[attr_name].lower() + except: + pass + + # States + # 19 sec out of 20 + for attr_name, attr_func in [ + ("enabled", lambda: node.is_enabled()), + ("visible", lambda: node.is_visible()), + # ("active", lambda: node.is_active()), # occupied most of the time: 20s out of 21s for slack, 51.5s out of 54s for WeChat # maybe use for cutting branches + ("minimized", lambda: node.is_minimized()), + ("maximized", lambda: node.is_maximized()), + ("normal", lambda: node.is_normal()), + ("unicode", lambda: node.is_unicode()), + ("collapsed", lambda: node.is_collapsed()), + ("checkable", lambda: node.is_checkable()), + ("checked", lambda: node.is_checked()), + ("focused", lambda: node.is_focused()), + ("keyboard_focused", lambda: node.is_keyboard_focused()), + ("selected", lambda: node.is_selected()), + ("selection_required", lambda: node.is_selection_required()), + ("pressable", lambda: node.is_pressable()), + ("pressed", lambda: node.is_pressed()), + ("expanded", lambda: node.is_expanded()), + ("editable", lambda: node.is_editable()), + ("has_keyboard_focus", lambda: node.has_keyboard_focus()), + ("is_keyboard_focusable", lambda: node.is_keyboard_focusable()), + ]: + try: + attribute_dict[f"{{{_accessibility_ns_map_windows['st']}}}{attr_name}"] = str(attr_func()).lower() + except: + pass + + # Component + try: + rectangle = node.rectangle() + attribute_dict["{{{:}}}screencoord".format(_accessibility_ns_map_windows["cp"])] = \ + "({:d}, {:d})".format(rectangle.left, rectangle.top) + attribute_dict["{{{:}}}size".format(_accessibility_ns_map_windows["cp"])] = \ + "({:d}, {:d})".format(rectangle.width(), rectangle.height()) + + except Exception as e: + logger.error("Error accessing rectangle: ", e) + + # Text + text: str = node.window_text() + if text == attribute_dict["name"]: + text = "" + + # Selection + if hasattr(node, "select"): + attribute_dict["selection"] = "true" + + # Value + for attr_name, attr_funcs in [ + ("step", [lambda: node.get_step()]), + ("value", [lambda: node.value(), lambda: node.get_value(), lambda: node.get_position()]), + ("min", [lambda: node.min_value(), lambda: node.get_range_min()]), + ("max", [lambda: node.max_value(), lambda: node.get_range_max()]) + ]: + for attr_func in attr_funcs: + if hasattr(node, attr_func.__name__): + try: + attribute_dict[f"{{{_accessibility_ns_map_windows['val']}}}{attr_name}"] = str(attr_func()) + break # exit once the attribute is set successfully + except: + pass + + attribute_dict["{{{:}}}class".format(_accessibility_ns_map_windows["class"])] = str(type(node)) + + # class_name + for attr_name in ["class_name", "friendly_class_name"]: + try: + attribute_dict[f"{{{_accessibility_ns_map_windows['class']}}}{attr_name}"] = base_properties[ + attr_name].lower() + except: + pass + + node_role_name: str = node.class_name().lower().replace(" ", "-") + node_role_name = "".join( + map(lambda _ch: _ch if _ch.isidentifier() or _ch in {"-"} or _ch.isalnum() else "-", node_role_name)) + + if node_role_name.strip() == "": + node_role_name = "unknown" + if not node_role_name[0].isalpha(): + node_role_name = "tag" + node_role_name + + xml_node = lxml.etree.Element( + node_role_name, + attrib=attribute_dict, + nsmap=_accessibility_ns_map_windows + ) + + if text is not None and len(text) > 0 and text != attribute_dict["name"]: + xml_node.text = text + + if depth == MAX_DEPTH: + logger.warning("Max depth reached") + return xml_node + + # use multi thread to accelerate children fetching + children = node.children() + if children: + with concurrent.futures.ThreadPoolExecutor() as executor: + future_to_child = [executor.submit(_create_pywinauto_node, ch, nodes, depth + 1, flag) for ch in + children[:MAX_WIDTH]] + try: + xml_node.extend([future.result() for future in concurrent.futures.as_completed(future_to_child)]) + except Exception as e: + logger.error(f"Exception occurred: {e}") + return xml_node + + +# A11y tree getter for macOS + +def _create_axui_node(node, nodes: set = None, depth: int = 0, bbox: tuple = None): + nodes = nodes or set() + if node in nodes: + return + nodes.add(node) + + reserved_keys = { + "AXEnabled": "st", + "AXFocused": "st", + "AXFullScreen": "st", + "AXTitle": "attr", + "AXChildrenInNavigationOrder": "attr", + "AXChildren": "attr", + "AXFrame": "attr", + "AXRole": "role", + "AXHelp": "attr", + "AXRoleDescription": "role", + "AXSubrole": "role", + "AXURL": "attr", + "AXValue": "val", + "AXDescription": "attr", + "AXDOMIdentifier": "attr", + "AXSelected": "st", + "AXInvalid": "st", + "AXRows": "attr", + "AXColumns": "attr", + } + attribute_dict = {} + + if depth == 0: + bbox = ( + node["kCGWindowBounds"]["X"], + node["kCGWindowBounds"]["Y"], + node["kCGWindowBounds"]["X"] + node["kCGWindowBounds"]["Width"], + node["kCGWindowBounds"]["Y"] + node["kCGWindowBounds"]["Height"] + ) + app_ref = ApplicationServices.AXUIElementCreateApplication(node["kCGWindowOwnerPID"]) + + attribute_dict["name"] = node["kCGWindowOwnerName"] + if attribute_dict["name"] != "Dock": + error_code, app_wins_ref = ApplicationServices.AXUIElementCopyAttributeValue( + app_ref, "AXWindows", None) + if error_code: + logger.error("MacOS parsing %s encountered Error code: %d", app_ref, error_code) + else: + app_wins_ref = [app_ref] + node = app_wins_ref[0] + + error_code, attr_names = ApplicationServices.AXUIElementCopyAttributeNames(node, None) + + if error_code: + # -25202: AXError.invalidUIElement + # The accessibility object received in this event is invalid. + return + + value = None + + if "AXFrame" in attr_names: + error_code, attr_val = ApplicationServices.AXUIElementCopyAttributeValue(node, "AXFrame", None) + rep = repr(attr_val) + x_value = re.search(r"x:(-?[\d.]+)", rep) + y_value = re.search(r"y:(-?[\d.]+)", rep) + w_value = re.search(r"w:(-?[\d.]+)", rep) + h_value = re.search(r"h:(-?[\d.]+)", rep) + type_value = re.search(r"type\s?=\s?(\w+)", rep) + value = { + "x": float(x_value.group(1)) if x_value else None, + "y": float(y_value.group(1)) if y_value else None, + "w": float(w_value.group(1)) if w_value else None, + "h": float(h_value.group(1)) if h_value else None, + "type": type_value.group(1) if type_value else None, + } + + if not any(v is None for v in value.values()): + x_min = max(bbox[0], value["x"]) + x_max = min(bbox[2], value["x"] + value["w"]) + y_min = max(bbox[1], value["y"]) + y_max = min(bbox[3], value["y"] + value["h"]) + + if x_min > x_max or y_min > y_max: + # No intersection + return + + role = None + text = None + + for attr_name, ns_key in reserved_keys.items(): + if attr_name not in attr_names: + continue + + if value and attr_name == "AXFrame": + bb = value + if not any(v is None for v in bb.values()): + attribute_dict["{{{:}}}screencoord".format(_accessibility_ns_map_macos["cp"])] = \ + "({:d}, {:d})".format(int(bb["x"]), int(bb["y"])) + attribute_dict["{{{:}}}size".format(_accessibility_ns_map_macos["cp"])] = \ + "({:d}, {:d})".format(int(bb["w"]), int(bb["h"])) + continue + + error_code, attr_val = ApplicationServices.AXUIElementCopyAttributeValue(node, attr_name, None) + + full_attr_name = f"{{{_accessibility_ns_map_macos[ns_key]}}}{attr_name}" + + if attr_name == "AXValue" and not text: + text = str(attr_val) + continue + + if attr_name == "AXRoleDescription": + role = attr_val + continue + + # Set the attribute_dict + if not (isinstance(attr_val, ApplicationServices.AXUIElementRef) + or isinstance(attr_val, (AppKit.NSArray, list))): + if attr_val is not None: + attribute_dict[full_attr_name] = str(attr_val) + + node_role_name = role.lower().replace(" ", "_") if role else "unknown_role" + + xml_node = lxml.etree.Element( + node_role_name, + attrib=attribute_dict, + nsmap=_accessibility_ns_map_macos + ) + + if text is not None and len(text) > 0: + xml_node.text = text + + if depth == MAX_DEPTH: + logger.warning("Max depth reached") + return xml_node + + future_to_child = [] + + with concurrent.futures.ThreadPoolExecutor() as executor: + for attr_name, ns_key in reserved_keys.items(): + if attr_name not in attr_names: + continue + + error_code, attr_val = ApplicationServices.AXUIElementCopyAttributeValue(node, attr_name, None) + if isinstance(attr_val, ApplicationServices.AXUIElementRef): + future_to_child.append(executor.submit(_create_axui_node, attr_val, nodes, depth + 1, bbox)) + + elif isinstance(attr_val, (AppKit.NSArray, list)): + for child in attr_val: + future_to_child.append(executor.submit(_create_axui_node, child, nodes, depth + 1, bbox)) + + try: + for future in concurrent.futures.as_completed(future_to_child): + result = future.result() + if result is not None: + xml_node.append(result) + except Exception as e: + logger.error(f"Exception occurred: {e}") + + return xml_node + + +@app.route("/accessibility", methods=["GET"]) +def get_accessibility_tree(): + os_name: str = platform.system() + + # AT-SPI works for KDE as well + if os_name == "Linux": + global libreoffice_version_tuple + libreoffice_version_tuple = _get_libreoffice_version() + + desktop: Accessible = pyatspi.Registry.getDesktop(0) + xml_node = lxml.etree.Element("desktop-frame", nsmap=_accessibility_ns_map_ubuntu) + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [executor.submit(_create_atspi_node, app_node, 1) for app_node in desktop] + for future in concurrent.futures.as_completed(futures): + xml_tree = future.result() + xml_node.append(xml_tree) + return jsonify({"AT": lxml.etree.tostring(xml_node, encoding="unicode")}) + + elif os_name == "Windows": + # Attention: Windows a11y tree is implemented to be read through `pywinauto` module, however, + # two different backends `win32` and `uia` are supported and different results may be returned + desktop: Desktop = Desktop(backend="uia") + xml_node = lxml.etree.Element("desktop", nsmap=_accessibility_ns_map_windows) + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [executor.submit(_create_pywinauto_node, wnd, {}, 1) for wnd in desktop.windows()] + for future in concurrent.futures.as_completed(futures): + xml_tree = future.result() + xml_node.append(xml_tree) + return jsonify({"AT": lxml.etree.tostring(xml_node, encoding="unicode")}) + + elif os_name == "Darwin": + # TODO: Add Dock and MenuBar + xml_node = lxml.etree.Element("desktop", nsmap=_accessibility_ns_map_macos) + + with concurrent.futures.ThreadPoolExecutor() as executor: + foreground_windows = [ + win for win in Quartz.CGWindowListCopyWindowInfo( + (Quartz.kCGWindowListExcludeDesktopElements | + Quartz.kCGWindowListOptionOnScreenOnly), + Quartz.kCGNullWindowID + ) if win["kCGWindowLayer"] == 0 and win["kCGWindowOwnerName"] != "Window Server" + ] + dock_info = [ + win for win in Quartz.CGWindowListCopyWindowInfo( + Quartz.kCGWindowListOptionAll, + Quartz.kCGNullWindowID + ) if win.get("kCGWindowName", None) == "Dock" + ] + + futures = [ + executor.submit(_create_axui_node, wnd, None, 0) + for wnd in foreground_windows + dock_info + ] + + for future in concurrent.futures.as_completed(futures): + xml_tree = future.result() + if xml_tree is not None: + xml_node.append(xml_tree) + + return jsonify({"AT": lxml.etree.tostring(xml_node, encoding="unicode")}) + + else: + return "Currently not implemented for platform {:}.".format(platform.platform()), 500 + + +@app.route('/screen_size', methods=['POST']) +def get_screen_size(): + if platform_name == "Linux": + d = display.Display() + screen_width = d.screen().width_in_pixels + screen_height = d.screen().height_in_pixels + elif platform_name == "Windows": + user32 = ctypes.windll.user32 + screen_width: int = user32.GetSystemMetrics(0) + screen_height: int = user32.GetSystemMetrics(1) + return jsonify( + { + "width": screen_width, + "height": screen_height + } + ) + + +@app.route('/window_size', methods=['POST']) +def get_window_size(): + if 'app_class_name' in request.form: + app_class_name = request.form['app_class_name'] + else: + return jsonify({"error": "app_class_name is required"}), 400 + + d = display.Display() + root = d.screen().root + window_ids = root.get_full_property(d.intern_atom('_NET_CLIENT_LIST'), X.AnyPropertyType).value + + for window_id in window_ids: + try: + window = d.create_resource_object('window', window_id) + wm_class = window.get_wm_class() + + if wm_class is None: + continue + + if app_class_name.lower() in [name.lower() for name in wm_class]: + geom = window.get_geometry() + return jsonify( + { + "width": geom.width, + "height": geom.height + } + ) + except Xlib.error.XError: # Ignore windows that give an error + continue + return None + + +@app.route('/desktop_path', methods=['POST']) +def get_desktop_path(): + # Get the home directory in a platform-independent manner using pathlib + home_directory = str(Path.home()) + + # Determine the desktop path based on the operating system + desktop_path = { + "Windows": os.path.join(home_directory, "Desktop"), + "Darwin": os.path.join(home_directory, "Desktop"), # macOS + "Linux": os.path.join(home_directory, "Desktop") + }.get(platform.system(), None) + + # Check if the operating system is supported and the desktop path exists + if desktop_path and os.path.exists(desktop_path): + return jsonify(desktop_path=desktop_path) + else: + return jsonify(error="Unsupported operating system or desktop path not found"), 404 + + +@app.route('/wallpaper', methods=['POST']) +def get_wallpaper(): + def get_wallpaper_windows(): + SPI_GETDESKWALLPAPER = 0x73 + MAX_PATH = 260 + buffer = ctypes.create_unicode_buffer(MAX_PATH) + ctypes.windll.user32.SystemParametersInfoW(SPI_GETDESKWALLPAPER, MAX_PATH, buffer, 0) + return buffer.value + + def get_wallpaper_macos(): + script = """ + tell application "System Events" to tell every desktop to get picture + """ + process = subprocess.Popen(['osascript', '-e', script], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + output, error = process.communicate() + if error: + app.logger.error("Error: %s", error.decode('utf-8')) + return None + return output.strip().decode('utf-8') + + def get_wallpaper_linux(): + try: + output = subprocess.check_output( + ["gsettings", "get", "org.gnome.desktop.background", "picture-uri"], + stderr=subprocess.PIPE + ) + return output.decode('utf-8').strip().replace('file://', '').replace("'", "") + except subprocess.CalledProcessError as e: + app.logger.error("Error: %s", e) + return None + + os_name = platform.system() + wallpaper_path = None + if os_name == 'Windows': + wallpaper_path = get_wallpaper_windows() + elif os_name == 'Darwin': + wallpaper_path = get_wallpaper_macos() + elif os_name == 'Linux': + wallpaper_path = get_wallpaper_linux() + else: + app.logger.error(f"Unsupported OS: {os_name}") + abort(400, description="Unsupported OS") + + if wallpaper_path: + try: + # Ensure the filename is secure + return send_file(wallpaper_path, mimetype='image/png') + except Exception as e: + app.logger.error(f"An error occurred while serving the wallpaper file: {e}") + abort(500, description="Unable to serve the wallpaper file") + else: + abort(404, description="Wallpaper file not found") + + +@app.route('/list_directory', methods=['POST']) +def get_directory_tree(): + def _list_dir_contents(directory): + """ + List the contents of a directory recursively, building a tree structure. + + :param directory: The path of the directory to inspect. + :return: A nested dictionary with the contents of the directory. + """ + tree = {'type': 'directory', 'name': os.path.basename(directory), 'children': []} + try: + # List all files and directories in the current directory + for entry in os.listdir(directory): + full_path = os.path.join(directory, entry) + # If entry is a directory, recurse into it + if os.path.isdir(full_path): + tree['children'].append(_list_dir_contents(full_path)) + else: + tree['children'].append({'type': 'file', 'name': entry}) + except OSError as e: + # If the directory cannot be accessed, return the exception message + tree = {'error': str(e)} + return tree + + # Extract the 'path' parameter from the JSON request + data = request.get_json() + if 'path' not in data: + return jsonify(error="Missing 'path' parameter"), 400 + + start_path = data['path'] + # Ensure the provided path is a directory + if not os.path.isdir(start_path): + return jsonify(error="The provided path is not a directory"), 400 + + # Generate the directory tree starting from the provided path + directory_tree = _list_dir_contents(start_path) + return jsonify(directory_tree=directory_tree) + + +@app.route('/file', methods=['POST']) +def get_file(): + # Retrieve filename from the POST request + if 'file_path' in request.form: + file_path = os.path.expandvars(os.path.expanduser(request.form['file_path'])) + else: + return jsonify({"error": "file_path is required"}), 400 + + try: + # Check if the file exists and get its size + if not os.path.exists(file_path): + return jsonify({"error": "File not found"}), 404 + + file_size = os.path.getsize(file_path) + logger.info(f"Serving file: {file_path} ({file_size} bytes)") + + # Check if the file exists and send it to the user + return send_file(file_path, as_attachment=True) + except FileNotFoundError: + # If the file is not found, return a 404 error + return jsonify({"error": "File not found"}), 404 + except Exception as e: + logger.error(f"Error serving file {file_path}: {e}") + return jsonify({"error": f"Failed to serve file: {str(e)}"}), 500 + + +@app.route("/setup/upload", methods=["POST"]) +def upload_file(): + # Retrieve filename from the POST request + if 'file_path' in request.form and 'file_data' in request.files: + file_path = os.path.expandvars(os.path.expanduser(request.form['file_path'])) + file = request.files["file_data"] + + try: + # Ensure target directory exists + os.makedirs(os.path.dirname(file_path), exist_ok=True) + + # Save file and get size for verification + file.save(file_path) + uploaded_size = os.path.getsize(file_path) + + logger.info(f"File uploaded successfully: {file_path} ({uploaded_size} bytes)") + return f"File Uploaded: {uploaded_size} bytes" + + except Exception as e: + logger.error(f"Error uploading file to {file_path}: {e}") + # Clean up partial file if it exists + if os.path.exists(file_path): + try: + os.remove(file_path) + except: + pass + return jsonify({"error": f"Failed to upload file: {str(e)}"}), 500 + else: + return jsonify({"error": "file_path and file_data are required"}), 400 + + +@app.route('/platform', methods=['GET']) +def get_platform(): + return platform.system() + + +@app.route('/cursor_position', methods=['GET']) +def get_cursor_position(): + pos = pyautogui.position() + return jsonify(pos.x, pos.y) + +@app.route("/setup/change_wallpaper", methods=['POST']) +def change_wallpaper(): + data = request.json + path = data.get('path', None) + + if not path: + return "Path not supplied!", 400 + + path = Path(os.path.expandvars(os.path.expanduser(path))) + + if not path.exists(): + return f"File not found: {path}", 404 + + try: + user_platform = platform.system() + if user_platform == "Windows": + import ctypes + ctypes.windll.user32.SystemParametersInfoW(20, 0, str(path), 3) + elif user_platform == "Linux": + import subprocess + subprocess.run(["gsettings", "set", "org.gnome.desktop.background", "picture-uri", f"file://{path}"]) + elif user_platform == "Darwin": # (Mac OS) + import subprocess + subprocess.run( + ["osascript", "-e", f'tell application "Finder" to set desktop picture to POSIX file "{path}"']) + return "Wallpaper changed successfully" + except Exception as e: + return f"Failed to change wallpaper. Error: {e}", 500 + + +@app.route("/setup/download_file", methods=['POST']) +def download_file(): + data = request.json + url = data.get('url', None) + path = data.get('path', None) + + if not url or not path: + return "Path or URL not supplied!", 400 + + path = Path(os.path.expandvars(os.path.expanduser(path))) + path.parent.mkdir(parents=True, exist_ok=True) + + max_retries = 3 + error: Optional[Exception] = None + + for i in range(max_retries): + try: + logger.info(f"Download attempt {i+1}/{max_retries} for {url}") + response = requests.get(url, stream=True, timeout=300) + response.raise_for_status() + + # Get expected file size if available + total_size = int(response.headers.get('content-length', 0)) + if total_size > 0: + logger.info(f"Expected file size: {total_size / (1024*1024):.2f} MB") + + downloaded_size = 0 + with open(path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + downloaded_size += len(chunk) + if total_size > 0 and downloaded_size % (1024*1024) == 0: # Log every MB + progress = (downloaded_size / total_size) * 100 + logger.info(f"Download progress: {progress:.1f}%") + + # Verify download completeness + actual_size = os.path.getsize(path) + if total_size > 0 and actual_size != total_size: + raise Exception(f"Download incomplete. Expected {total_size} bytes, got {actual_size} bytes") + + logger.info(f"File downloaded successfully: {path} ({actual_size} bytes)") + return f"File downloaded successfully: {actual_size} bytes" + + except (requests.RequestException, Exception) as e: + error = e + logger.error(f"Failed to download {url}: {e}. Retrying... ({max_retries - i - 1} attempts left)") + # Clean up partial download + if path.exists(): + try: + path.unlink() + except: + pass + + return f"Failed to download {url}. No retries left. Error: {error}", 500 + + +@app.route("/setup/open_file", methods=['POST']) +def open_file(): + data = request.json + path = data.get('path', None) + + if not path: + return "Path not supplied!", 400 + + path_obj = Path(os.path.expandvars(os.path.expanduser(path))) + + # Check if it's a file path that exists + is_file_path = path_obj.exists() + + # If it's not a file path, treat it as an application name/command + if not is_file_path: + # Check if it's a valid command by trying to find it in PATH + import shutil + if not shutil.which(path): + return f"Application/file not found: {path}", 404 + + try: + if is_file_path: + # Handle file opening + if platform.system() == "Windows": + os.startfile(path_obj) + else: + open_cmd: str = "open" if platform.system() == "Darwin" else "xdg-open" + subprocess.Popen([open_cmd, str(path_obj)]) + file_name = path_obj.name + file_name_without_ext, _ = os.path.splitext(file_name) + else: + # Handle application launching + if platform.system() == "Windows": + subprocess.Popen([path]) + else: + subprocess.Popen([path]) + file_name = path + file_name_without_ext = path + + # Wait for the file/application to open + + start_time = time.time() + window_found = False + + while time.time() - start_time < TIMEOUT: + os_name = platform.system() + if os_name in ['Windows', 'Darwin']: + import pygetwindow as gw + # Check for window title containing file name or file name without extension + windows = gw.getWindowsWithTitle(file_name) + if not windows: + windows = gw.getWindowsWithTitle(file_name_without_ext) + + if windows: + # To be more specific, we can try to activate it + windows[0].activate() + window_found = True + break + elif os_name == 'Linux': + try: + # Using wmctrl to list windows and check if any window title contains the filename + result = subprocess.run(['wmctrl', '-l'], capture_output=True, text=True, check=True) + window_list = result.stdout.strip().split('\n') + if not result.stdout.strip(): + pass # No windows, just continue waiting + else: + for window in window_list: + if file_name in window or file_name_without_ext in window: + # a window is found, now activate it + window_id = window.split()[0] + subprocess.run(['wmctrl', '-i', '-a', window_id], check=True) + window_found = True + break + if window_found: + break + except (subprocess.CalledProcessError, FileNotFoundError): + # wmctrl might not be installed or the window manager isn't ready. + # We just log it once and let the main loop retry. + if 'wmctrl_failed_once' not in locals(): + logger.warning("wmctrl command is not ready, will keep retrying...") + wmctrl_failed_once = True + pass # Let the outer loop retry + + time.sleep(1) + + if window_found: + return "File opened and window activated successfully" + else: + return f"Failed to find window for {file_name} within {timeout} seconds.", 500 + + except Exception as e: + return f"Failed to open {path}. Error: {e}", 500 + + +@app.route("/setup/activate_window", methods=['POST']) +def activate_window(): + data = request.json + window_name = data.get('window_name', None) + if not window_name: + return "window_name required", 400 + strict: bool = data.get("strict", False) # compare case-sensitively and match the whole string + by_class_name: bool = data.get("by_class", False) + + os_name = platform.system() + + if os_name == 'Windows': + import pygetwindow as gw + if by_class_name: + return "Get window by class name is not supported on Windows currently.", 500 + windows: List[gw.Window] = gw.getWindowsWithTitle(window_name) + + window: Optional[gw.Window] = None + if len(windows) == 0: + return "Window {:} not found (empty results)".format(window_name), 404 + elif strict: + for wnd in windows: + if wnd.title == wnd: + window = wnd + if window is None: + return "Window {:} not found (strict mode).".format(window_name), 404 + else: + window = windows[0] + window.activate() + + elif os_name == 'Darwin': + import pygetwindow as gw + if by_class_name: + return "Get window by class name is not supported on macOS currently.", 500 + # Find the VS Code window + windows = gw.getWindowsWithTitle(window_name) + + window: Optional[gw.Window] = None + if len(windows) == 0: + return "Window {:} not found (empty results)".format(window_name), 404 + elif strict: + for wnd in windows: + if wnd.title == wnd: + window = wnd + if window is None: + return "Window {:} not found (strict mode).".format(window_name), 404 + else: + window = windows[0] + + # Un-minimize the window and then bring it to the front + window.unminimize() + window.activate() + + elif os_name == 'Linux': + # Attempt to activate VS Code window using wmctrl + subprocess.run(["wmctrl" + , "-{:}{:}a".format("x" if by_class_name else "" + , "F" if strict else "" + ) + , window_name + ] + ) + + else: + return f"Operating system {os_name} not supported.", 400 + + return "Window activated successfully", 200 + + +@app.route("/setup/close_window", methods=["POST"]) +def close_window(): + data = request.json + if "window_name" not in data: + return "window_name required", 400 + window_name: str = data["window_name"] + strict: bool = data.get("strict", False) # compare case-sensitively and match the whole string + by_class_name: bool = data.get("by_class", False) + + os_name: str = platform.system() + if os_name == "Windows": + import pygetwindow as gw + + if by_class_name: + return "Get window by class name is not supported on Windows currently.", 500 + windows: List[gw.Window] = gw.getWindowsWithTitle(window_name) + + window: Optional[gw.Window] = None + if len(windows) == 0: + return "Window {:} not found (empty results)".format(window_name), 404 + elif strict: + for wnd in windows: + if wnd.title == wnd: + window = wnd + if window is None: + return "Window {:} not found (strict mode).".format(window_name), 404 + else: + window = windows[0] + window.close() + elif os_name == "Linux": + subprocess.run(["wmctrl" + , "-{:}{:}c".format("x" if by_class_name else "" + , "F" if strict else "" + ) + , window_name + ] + ) + elif os_name == "Darwin": + import pygetwindow as gw + return "Currently not supported on macOS.", 500 + else: + return "Not supported platform {:}".format(os_name), 500 + + return "Window closed successfully.", 200 + + +@app.route('/start_recording', methods=['POST']) +def start_recording(): + global recording_process + if recording_process and recording_process.poll() is None: + return jsonify({'status': 'error', 'message': 'Recording is already in progress.'}), 400 + + # Clean up previous recording if it exists + if os.path.exists(recording_path): + try: + os.remove(recording_path) + except OSError as e: + logger.error(f"Error removing old recording file: {e}") + return jsonify({'status': 'error', 'message': f'Failed to remove old recording file: {e}'}), 500 + + d = display.Display() + screen_width = d.screen().width_in_pixels + screen_height = d.screen().height_in_pixels + + start_command = f"ffmpeg -y -f x11grab -draw_mouse 1 -s {screen_width}x{screen_height} -i :0.0 -c:v libx264 -r 30 {recording_path}" + + # Use stderr=PIPE to capture potential errors from ffmpeg + recording_process = subprocess.Popen(shlex.split(start_command), + stdout=subprocess.DEVNULL, + stderr=subprocess.PIPE, + text=True # To get stderr as string + ) + + # Wait a couple of seconds to see if ffmpeg starts successfully + try: + # Wait for 2 seconds. If ffmpeg exits within this time, it's an error. + recording_process.wait(timeout=2) + # If wait() returns, it means the process has terminated. + error_output = recording_process.stderr.read() + return jsonify({ + 'status': 'error', + 'message': f'Failed to start recording. ffmpeg terminated unexpectedly. Error: {error_output}' + }), 500 + except subprocess.TimeoutExpired: + # This is the expected outcome: the process is still running after 2 seconds. + return jsonify({'status': 'success', 'message': 'Started recording successfully.'}) + + +@app.route('/end_recording', methods=['POST']) +def end_recording(): + global recording_process + + if not recording_process or recording_process.poll() is not None: + recording_process = None # Clean up stale process object + return jsonify({'status': 'error', 'message': 'No recording in progress to stop.'}), 400 + + error_output = "" + try: + # Send SIGINT for a graceful shutdown, allowing ffmpeg to finalize the file. + recording_process.send_signal(signal.SIGINT) + # Wait for ffmpeg to terminate. communicate() gets output and waits. + _, error_output = recording_process.communicate(timeout=15) + except subprocess.TimeoutExpired: + logger.error("ffmpeg did not respond to SIGINT, killing the process.") + recording_process.kill() + # After killing, communicate to get any remaining output. + _, error_output = recording_process.communicate() + recording_process = None + return jsonify({ + 'status': 'error', + 'message': f'Recording process was unresponsive and had to be killed. Stderr: {error_output}' + }), 500 + + recording_process = None # Clear the process from global state + + # Check if the recording file was created and is not empty. + if os.path.exists(recording_path) and os.path.getsize(recording_path) > 0: + return send_file(recording_path, as_attachment=True) + else: + logger.error(f"Recording failed. The output file is missing or empty. ffmpeg stderr: {error_output}") + return abort(500, description=f"Recording failed. The output file is missing or empty. ffmpeg stderr: {error_output}") + + +if __name__ == '__main__': + app.run(debug=True, host="0.0.0.0") diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/server/osworld_server.service b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/server/osworld_server.service new file mode 100644 index 000000000..50dc881cd --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/server/osworld_server.service @@ -0,0 +1,16 @@ +[Unit] +Description=OSWorld Server +StartLimitIntervalSec=60 +StartLimitBurst=4 +After=network.target auditd.service + +[Service] +ExecStart=/usr/bin/python3 /home/user/main.py +User=user +WorkingDirectory=/home/user +Restart=on-failure +RestartSec=1 +Environment="DISPLAY=:1" + +[Install] +WantedBy=graphical.target diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/server/pyxcursor.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/server/pyxcursor.py new file mode 100644 index 000000000..0fe11def3 --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env/server/pyxcursor.py @@ -0,0 +1,146 @@ +import os +import ctypes +import ctypes.util +import numpy as np + +# A helper function to convert data from Xlib to byte array. +import struct, array + +# Define ctypes version of XFixesCursorImage structure. +PIXEL_DATA_PTR = ctypes.POINTER(ctypes.c_ulong) +Atom = ctypes.c_ulong + + +class XFixesCursorImage(ctypes.Structure): + """ + See /usr/include/X11/extensions/Xfixes.h + + typedef struct { + short x, y; + unsigned short width, height; + unsigned short xhot, yhot; + unsigned long cursor_serial; + unsigned long *pixels; + if XFIXES_MAJOR >= 2 + Atom atom; /* Version >= 2 only */ + const char *name; /* Version >= 2 only */ + endif + } XFixesCursorImage; + """ + _fields_ = [('x', ctypes.c_short), + ('y', ctypes.c_short), + ('width', ctypes.c_ushort), + ('height', ctypes.c_ushort), + ('xhot', ctypes.c_ushort), + ('yhot', ctypes.c_ushort), + ('cursor_serial', ctypes.c_ulong), + ('pixels', PIXEL_DATA_PTR), + ('atom', Atom), + ('name', ctypes.c_char_p)] + + +class Display(ctypes.Structure): + pass + + +class Xcursor: + display = None + + def __init__(self, display=None): + if not display: + try: + display = os.environ["DISPLAY"].encode("utf-8") + except KeyError: + raise Exception("$DISPLAY not set.") + + # XFixeslib = ctypes.CDLL('libXfixes.so') + XFixes = ctypes.util.find_library("Xfixes") + if not XFixes: + raise Exception("No XFixes library found.") + self.XFixeslib = ctypes.cdll.LoadLibrary(XFixes) + + # xlib = ctypes.CDLL('libX11.so.6') + x11 = ctypes.util.find_library("X11") + if not x11: + raise Exception("No X11 library found.") + self.xlib = ctypes.cdll.LoadLibrary(x11) + + # Define ctypes' version of XFixesGetCursorImage function + XFixesGetCursorImage = self.XFixeslib.XFixesGetCursorImage + XFixesGetCursorImage.restype = ctypes.POINTER(XFixesCursorImage) + XFixesGetCursorImage.argtypes = [ctypes.POINTER(Display)] + self.XFixesGetCursorImage = XFixesGetCursorImage + + XOpenDisplay = self.xlib.XOpenDisplay + XOpenDisplay.restype = ctypes.POINTER(Display) + XOpenDisplay.argtypes = [ctypes.c_char_p] + + if not self.display: + self.display = self.xlib.XOpenDisplay(display) # (display) or (None) + + def argbdata_to_pixdata(self, data, len): + if data == None or len < 1: return None + + # Create byte array + b = array.array('b', b'\x00' * 4 * len) + + offset, i = 0, 0 + while i < len: + argb = data[i] & 0xffffffff + rgba = (argb << 8) | (argb >> 24) + b1 = (rgba >> 24) & 0xff + b2 = (rgba >> 16) & 0xff + b3 = (rgba >> 8) & 0xff + b4 = rgba & 0xff + + struct.pack_into("=BBBB", b, offset, b1, b2, b3, b4) + offset = offset + 4 + i = i + 1 + + return b + + def getCursorImageData(self): + # Call the function. Read data of cursor/mouse-pointer. + cursor_data = self.XFixesGetCursorImage(self.display) + + if not (cursor_data and cursor_data[0]): + raise Exception("Cannot read XFixesGetCursorImage()") + + # Note: cursor_data is a pointer, take cursor_data[0] + return cursor_data[0] + + def getCursorImageArray(self): + data = self.getCursorImageData() + # x, y = data.x, data.y + height, width = data.height, data.width + + bytearr = self.argbdata_to_pixdata(data.pixels, height * width) + + imgarray = np.array(bytearr, dtype=np.uint8) + imgarray = imgarray.reshape(height, width, 4) + del bytearr + + return imgarray + + def getCursorImageArrayFast(self): + data = self.getCursorImageData() + # x, y = data.x, data.y + height, width = data.height, data.width + + bytearr = ctypes.cast(data.pixels, ctypes.POINTER(ctypes.c_ulong * height * width))[0] + imgarray = np.array(bytearray(bytearr)) + imgarray = imgarray.reshape(height, width, 8)[:, :, (0, 1, 2, 3)] + del bytearr + + return imgarray + + def saveImage(self, imgarray, text): + from PIL import Image + img = Image.fromarray(imgarray) + img.save(text) + + +if __name__ == "__main__": + cursor = Xcursor() + imgarray = cursor.getCursorImageArrayFast() + cursor.saveImage(imgarray, 'cursor_image.png') diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env_interface.py b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env_interface.py new file mode 100644 index 000000000..a52beb98b --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/desktop_env_interface.py @@ -0,0 +1,123 @@ +from skyrl_agent.tasks.osworld.desktop_env.desktop_env import DesktopEnv +import ray +import socket +import psutil + +@ray.remote(num_cpus=1, num_gpus=0) +class DesktopEnvRay: + def __init__(self, *args, **kwargs): + hostname = socket.gethostname() + print(f"DesktopEnvRay actor started on {hostname}") + + # Increase Docker port allocation timeout for Ray distributed execution + import os + os.environ['DESKTOP_ENV_LOCK_TIMEOUT'] = '120' + + # This needs to be the path to the VM image on the worker/CPU node + kwargs['path_to_vm'] = "/home/ubuntu/shuo-vir/OSWorld_llm_agentsynth/docker_vm_data/Ubuntu.qcow2" + + self.desktop_env = DesktopEnv(*args, **kwargs) + + # Explicitly define the methods that Ray needs to expose + async def _start_emulator_async(self): + return await self.desktop_env._start_emulator_async() + + def _start_emulator(self): + return self.desktop_env._start_emulator() + + def step(self, action, pause): + return self.desktop_env.step(action, pause=pause) + + async def step_async(self, action, pause): + return await self.desktop_env.step_async(action, pause=pause) + + async def reset(self, *args, **kwargs): + return await self.desktop_env.reset(*args, **kwargs) + + def _get_obs(self): + return self.desktop_env._get_obs() + + async def _get_obs_async(self): + return await self.desktop_env._get_obs_async() + + async def evaluate(self, *args, **kwargs): + return await self.desktop_env.evaluate(*args, **kwargs) + + def close(self): + return self.desktop_env.close() + + +class DesktopEnvInterface: + def __init__(self, desktop_env, cpu_node: bool = False): + self.desktop_env = desktop_env + self.cpu_node = cpu_node + + async def _start_emulator_async(self): + """Start emulator async - call .remote() if using CPU, otherwise call normally""" + if self.cpu_node: + # For Ray actors, we need to get the result of the remote call + result = self.desktop_env._start_emulator_async.remote() + return await result + else: + return await self.desktop_env._start_emulator_async() + + def _start_emulator(self): + """Start emulator - call .remote() if using CPU, otherwise call normally""" + if self.cpu_node: + return ray.get(self.desktop_env._start_emulator.remote()) + else: + return self.desktop_env._start_emulator() + + def step(self, action, pause): + """Step action - call .remote() if using CPU, otherwise call normally""" + if self.cpu_node: + return ray.get(self.desktop_env.step.remote(action, pause)) + else: + return self.desktop_env.step(action, pause=pause) + + async def step_async(self, action, pause): + """Step action async - call .remote() if using CPU, otherwise call normally""" + if self.cpu_node: + result = self.desktop_env.step_async.remote(action, pause=pause) + return await result + else: + return await self.desktop_env.step_async(action, pause=pause) + + async def reset(self, *args, **kwargs): + """Reset environment - call .remote() if using CPU, otherwise call normally""" + if self.cpu_node: + result = self.desktop_env.reset.remote(*args, **kwargs) + return await result + else: + return await self.desktop_env.reset(*args, **kwargs) + + def _get_obs(self): + """Get observation - call .remote() if using CPU, otherwise call normally""" + if self.cpu_node: + return ray.get(self.desktop_env._get_obs.remote()) + else: + return self.desktop_env._get_obs() + + async def _get_obs_async(self): + """Get observation async - call .remote() if using CPU, otherwise call normally""" + if self.cpu_node: + result = self.desktop_env._get_obs_async.remote() + return await result + else: + return await self.desktop_env._get_obs_async() + + async def evaluate(self, *args, **kwargs): + """Evaluate - call .remote() if using CPU, otherwise call normally""" + if self.cpu_node: + result = self.desktop_env.evaluate.remote(*args, **kwargs) + return await result + else: + return await self.desktop_env.evaluate(*args, **kwargs) + + def close(self): + """Close the desktop environment - call .remote() if using CPU, otherwise call normally""" + if self.cpu_node: + return ray.get(self.desktop_env.close.remote()) + else: + return self.desktop_env.close() + \ No newline at end of file diff --git a/skyrl-agent/skyrl_agent/tasks/osworld/osworld_task.py b/skyrl-agent/skyrl_agent/tasks/osworld/osworld_task.py new file mode 100644 index 000000000..771e7f86c --- /dev/null +++ b/skyrl-agent/skyrl_agent/tasks/osworld/osworld_task.py @@ -0,0 +1,265 @@ +from email import message +from skyrl_agent.tasks.base import BaseTask +from typing import Dict, Any +import time +from skyrl_agent.tools.osworld_tools import OSWorldActionTool +import asyncio +import json +from loguru import logger +from skyrl_agent.tasks.osworld.desktop_env_interface import DesktopEnvInterface, DesktopEnvRay +from skyrl_agent.tasks.osworld.desktop_env.desktop_env import DesktopEnv +from typing import List +from io import BytesIO +import base64 + + +SYS_PROMPT_IN_A11Y_OUT_CODE = """ +You are an agent which follow my instruction and perform desktop computer tasks as instructed. +You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard. +For each step, you will get an observation of the desktop by accessibility tree, which is based on AT-SPI library. And you will predict the action of the computer based on the accessibility tree. + +You can use the osworld_action tool to perform the action grounded to the observation. + +You are required to use `pyautogui` while using the osworld_action tool, but DO NOT use the `pyautogui.locateCenterOnScreen` function to locate the element you want to operate with since we have no image of the element you want to operate with. DO NOT USE `pyautogui.screenshot()` to make screenshot in the tool call. +Return one tool call to perform the action each time, be time efficient. +You need to to specify the coordinates by yourself based on current observation, but you should be careful to ensure that the coordinates are correct. + +Specially, it is also allowed to call the finish tool to finish the task: +When you think the task can not be done, call the finish tool with answer="FAIL" in the format of + +FAIL + +When you think the task is done, call the finish tool with answer="DONE" in the format of + +DONE + +DO NOT EASILY CALL THE FINISH TOOL, TRY YOUR BEST TO DO THE TASK; + +My computer's password is 'password', feel free to use it when you need sudo rights. +First give the current screenshot and previous things we did a short reflection, then CALL THE TOOLS I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE. +""".strip() + +MAX_RETRY_TIMES = 10 + +class OSWorldTask(BaseTask): + @classmethod + async def initialize_runtime(cls, instance: Dict[str, Any]): + runtime = instance.get('runtime') + cfg = instance.pop('cfg') + vision_is_active = cfg.vision_is_active + instance['vision_is_active'] = vision_is_active + + start_time = time.time() + + last_exception = None + for attempt in range(MAX_RETRY_TIMES): + try: + await runtime.reset(task_config=instance) + break + except Exception as e: + last_exception = e + print(f"Failed to reset the runtime, retrying... (attempt {attempt + 1}/{MAX_RETRY_TIMES})") + print(f"Error: {str(e)}") + if attempt < MAX_RETRY_TIMES - 1: # Don't sleep after last attempt + await asyncio.sleep(1) + else: + # If we exit the loop without breaking, all retries failed + print(f"All {MAX_RETRY_TIMES} retry attempts failed") + raise last_exception + + await asyncio.sleep(5) # wait for the environment to be ready + initial_obs = await runtime._get_obs_async() + + reset_time = time.time() - start_time + print(f"[Runtime Reset] Reset completed in {reset_time:.2f} seconds") + if vision_is_active: + return cls.pil_to_base64(initial_obs["screenshot"]) + if initial_obs["accessibility_tree"] is None: + raise ValueError("Accessibility tree is None") + initial_acc_tree = OSWorldActionTool.linearize_accessibility_tree(initial_obs["accessibility_tree"], "ubuntu") # fixme: refer to the platform of the runtime/agent + if initial_acc_tree: + initial_acc_tree = OSWorldActionTool.trim_accessibility_tree(initial_acc_tree, 10000) # fixme: add arguments for max tokens + cls.initial_acc_tree = initial_acc_tree + return initial_acc_tree + + @classmethod + async def get_instruction(cls, instance: Dict[str, Any]): + instance = cls.osworld_data_preprocess(instance) + initial_observation = await cls.initialize_runtime(instance) + instance['initial_observation'] = initial_observation + vision_is_active = instance.get('vision_is_active', False) + + key = "instruction" if "instruction" in instance else "prompt" + + assert 'initial_observation' in instance, "initial_observation is required" + initial_observation = instance['initial_observation'] + + if isinstance(instance[key], str): + instruction = instance[key] # instance here should be json.load(example.json) examples from osworld + system_message = SYS_PROMPT_IN_A11Y_OUT_CODE + "\nYou are asked to complete the following task: {}".format(instruction) + messages = [ + { + "role": "system", + "content": system_message + } + ] + elif isinstance(instance[key], list): + messages = instance[key] + instruction = messages[0]["content"] + system_message_str = SYS_PROMPT_IN_A11Y_OUT_CODE + "\nYou are asked to complete the following task: {}".format(instruction) + messages = [ + { + "role": "system", + "content": system_message_str + } + ] + if vision_is_active: + user_content = [ + {"type": "text", "text": "Here is the current desktop screenshot. What's the next step to help with the task?"}, + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{initial_observation}"}} + ] + else: + user_content = "Given the info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(initial_observation) + messages.append({ + "role": "user", + "content": user_content + }) + return messages + + @classmethod + def complete_runtime(cls, runtime: DesktopEnvRay, instance: Dict[str, Any]): + runtime.close() + + @classmethod + async def evaluate_result(cls, result: any, instance: any, data_source: str, instance_id: int, trajectory_id: int) -> float: + """ + result is returned by the self.runtime.evaluate() in the osworld_react_agent.py + """ + runtime = instance['runtime'] + if result != "DONE" and result != "FAIL": + await runtime.step_async("DONE", 0.2) + else: + await runtime.step_async(result, 0.2) + result = await runtime.evaluate() + return result + + @classmethod + def osworld_data_preprocess(cls, instance: Dict[str, Any]): + json_columns = ['config', 'evaluator', 'related_apps'] + for col in json_columns: + if col in instance: + try: + value = instance[col] + if isinstance(value, str): + if value and value != "": + instance[col] = json.loads(value) + elif col == 'config' or col == 'evaluator' or col == 'source' or col == 'related_apps': + instance[col] = [] + else: + instance[col] = None + except (json.JSONDecodeError, TypeError) as e: + logger.warning(f"Failed to parse JSON for column {col}: {e}") + logger.warning(f"Instance: {instance}") + instance[col] = [] + return instance + + @classmethod + async def initialize_shared_env(cls, cfg: Dict[str, Any]): + batch_size = 8 + total_agents = cfg.dispatcher.max_parallel_agents + shared_env = [] + print("creating shared env") + + for i in range(0, total_agents, batch_size): + # Calculate how many agents to create in this batch + current_batch_size = min(batch_size, total_agents - i) + + # Create tasks for current batch + desktop_env_tasks = [ + cls._create_and_start_desktop_env(i + j, cfg) + for j in range(current_batch_size) + ] + + # Execute current batch concurrently + batch_envs = await asyncio.gather(*desktop_env_tasks) + shared_env.extend(batch_envs) + + logger.info(f"Created batch {i//batch_size + 1}: {current_batch_size} DesktopEnv instances (Total: {len(shared_env)}/{total_agents})") + return shared_env + + @classmethod + async def _create_and_start_desktop_env(cls, env_id: int, cfg: Dict[str, Any]) -> DesktopEnvInterface: + """ + Async wrapper to create and start a DesktopEnv instance. + Uses Ray for distributed execution if enabled. + """ + print("creating and starting desktop env with id ", env_id) + if cfg.generator.use_cpu_node: # Add this flag to your config + # Create Ray remote actor + runtime = DesktopEnvRay.remote( + action_space="pyautogui", + provider_name="docker", + screen_size=(1920, 1080), + headless=True, + os_type="Ubuntu", + require_a11y_tree=not cfg.generator.vision_is_active, + env_id=env_id, + path_to_vm = cfg.generator.path_to_vm + ) + print(runtime) + # Wrap in interface first, then start the emulator + interface = DesktopEnvInterface(runtime, cpu_node=True) + await interface._start_emulator_async() + + return interface + else: + # Original local execution + runtime = DesktopEnv( + action_space="pyautogui", + provider_name="docker", + screen_size=(1920, 1080), + headless=True, + os_type="Ubuntu", + require_a11y_tree=not cfg.generator.vision_is_active, + env_id=env_id, + path_to_vm = cfg.generator.path_to_vm + ) + + await runtime._start_emulator_async() + + return DesktopEnvInterface(runtime, cpu_node=False) + + + @classmethod + async def close_shared_env(cls, shared_env: List[DesktopEnvInterface]): + if shared_env: + # Close all environments in parallel + close_tasks = [] + for env in shared_env: + # Run each close operation in a thread pool to avoid blocking + close_tasks.append(asyncio.to_thread(env.close)) + + # Wait for all close operations to complete + await asyncio.gather(*close_tasks, return_exceptions=True) + + # Clean up any remaining reserved ports after closing all environments + from skyrl_agent.tasks.osworld.desktop_env.providers.docker.provider import DockerProvider + DockerProvider.cleanup_all_reserved_ports() + while shared_env: + shared_env.pop() + + @classmethod + def get_task_dependent_context_management(cls, agent, traj_config): + assert getattr(traj_config.generator_cfg, "history_length", None) is not None, "history_length is required" + history_length = int(traj_config.generator_cfg.history_length) + assert history_length > 0, "history_length must be greater than 0" + + if (2 + history_length * 2) < len(agent.history.messages): + new_messages = [agent.history.messages[0]] + agent.history.messages[-history_length * 2 - 1:] + agent.history.initialize(new_messages) + + @classmethod + def pil_to_base64(cls, image): + buffer = BytesIO() + image.save(buffer, format="PNG") # 你可以改成 "JPEG" 等格式 + return base64.b64encode(buffer.getvalue()).decode("utf-8") \ No newline at end of file diff --git a/skyrl-agent/skyrl_agent/tools/osworld_tools.py b/skyrl-agent/skyrl_agent/tools/osworld_tools.py new file mode 100644 index 000000000..781776e7f --- /dev/null +++ b/skyrl-agent/skyrl_agent/tools/osworld_tools.py @@ -0,0 +1,229 @@ +from .base import BaseTool, register_tool +from typing import Union, Optional, Tuple +import xml.etree.ElementTree as ET +import re +import tiktoken +import base64 +from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError + +@register_tool("osworld_action") +class OSWorldActionTool(BaseTool): + name = "osworld_action" + description = "Execute desktop automation actions using pyautogui code snippets. Provide Python code that uses pyautogui functions like click(), typewrite(), press(), hotkey(), scroll(), moveTo(), dragTo(), etc." + parameters = { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code snippet using pyautogui functions. Examples: 'pyautogui.click(500, 300)', 'pyautogui.typewrite(\"Hello World\")', 'pyautogui.press(\"enter\")', 'pyautogui.hotkey(\"ctrl\", \"c\")', 'time.sleep(2)'" + } + }, + "required": ["code"] + } + + attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes" + attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes" + state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state" + state_ns_windows = "https://accessibility.windows.example.org/ns/state" + component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component" + component_ns_windows = "https://accessibility.windows.example.org/ns/component" + value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value" + value_ns_windows = "https://accessibility.windows.example.org/ns/value" + class_ns_windows = "https://accessibility.windows.example.org/ns/class" + + def call(self, params: Union[str, dict], runtime=None, **kwargs) -> str: + try: + params = self._verify_json_format_args(params) + except ValueError as e: + return {"error": f"Invalid parameters: {str(e)}"} + + code = params.get("code", "").strip() + + if not code: + return {"error": "Code parameter cannot be empty"} + + # Execute runtime.step with a 2-minute timeout + try: + obs, reward, done, info = runtime.step(code, 0.2) + except Exception as e: + return f"Action execution failed: {str(e)}, please try again, possibly with a different action." + + # Prefer accessibility tree when available; otherwise fall back to screenshot + acc_tree = obs.get("accessibility_tree") + if acc_tree: + linearized_accessibility_tree = OSWorldActionTool.linearize_accessibility_tree(acc_tree, "ubuntu") # fixme: refer to the platform of the runtime/agent + if linearized_accessibility_tree: + linearized_accessibility_tree = OSWorldActionTool.trim_accessibility_tree(linearized_accessibility_tree, 10000) # fixme: add arguments for max tokens + return "Given the info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(linearized_accessibility_tree) + + + ## TODO(ys): untested visual tool response + screenshot_bytes = obs.get("screenshot") + if screenshot_bytes: + encoded = base64.b64encode(screenshot_bytes).decode("utf-8") + data_url = f"data:image/png;base64,{encoded}" + return [ + {"type": "text", "text": "Here is the latest desktop screenshot after executing the action."}, + {"type": "image_url", "image_url": {"url": data_url}} + ] + return {"error": "No observation available from runtime (both accessibility_tree and screenshot are missing)."} + + @staticmethod + def parse_code_from_string(input_string): + input_string = "\n".join([line.strip() for line in input_string.split(';') if line.strip()]) + if input_string.strip() in ['WAIT', 'DONE', 'FAIL']: + return [input_string.strip()] + + # This regular expression will match both ```code``` and ```python code``` + # and capture the `code` part. It uses a non-greedy match for the content inside. + pattern = r"```(?:\w+\s+)?(.*?)```" + # Find all non-overlapping matches in the string + matches = re.findall(pattern, input_string, re.DOTALL) + + # The regex above captures the content inside the triple backticks. + # The `re.DOTALL` flag allows the dot `.` to match newline characters as well, + # so the code inside backticks can span multiple lines. + + # matches now contains all the captured code snippets + + codes = [] + + for match in matches: + match = match.strip() + commands = ['WAIT', 'DONE', 'FAIL'] # fixme: updates this part when we have more commands + + if match in commands: + codes.append(match.strip()) + elif match.split('\n')[-1] in commands: + if len(match.split('\n')) > 1: + codes.append("\n".join(match.split('\n')[:-1])) + codes.append(match.split('\n')[-1]) + else: + codes.append(match) + + return codes + + @staticmethod + def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"): + + if platform == "ubuntu": + _attributes_ns = OSWorldActionTool.attributes_ns_ubuntu + _state_ns = OSWorldActionTool.state_ns_ubuntu + _component_ns = OSWorldActionTool.component_ns_ubuntu + _value_ns = OSWorldActionTool.value_ns_ubuntu + elif platform == "windows": + _attributes_ns = OSWorldActionTool.attributes_ns_windows + _state_ns = OSWorldActionTool.state_ns_windows + _component_ns = OSWorldActionTool.component_ns_windows + _value_ns = OSWorldActionTool.value_ns_windows + else: + raise ValueError("Invalid platform, must be 'ubuntu' or 'windows'") + + filtered_nodes = OSWorldActionTool.filter_nodes(ET.fromstring(accessibility_tree), platform) + linearized_accessibility_tree = ["tag\tname\ttext\tclass\tdescription\tposition (top-left x&y)\tsize (w&h)"] + + # Linearize the accessibility tree nodes into a table format + for node in filtered_nodes: + if node.text: + text = ( + node.text if '"' not in node.text \ + else '"{:}"'.format(node.text.replace('"', '""')) + ) + + elif node.get("{{{:}}}class".format(OSWorldActionTool.class_ns_windows), "").endswith("EditWrapper") \ + and node.get("{{{:}}}value".format(_value_ns)): + node_text = node.get("{{{:}}}value".format(_value_ns), "") + text = (node_text if '"' not in node_text \ + else '"{:}"'.format(node_text.replace('"', '""')) + ) + else: + text = '""' + + linearized_accessibility_tree.append( + "{:}\t{:}\t{:}\t{:}\t{:}\t{:}\t{:}".format( + node.tag, node.get("name", ""), + text, + node.get("{{{:}}}class".format(_attributes_ns), "") if platform == "ubuntu" else node.get("{{{:}}}class".format(OSWorldActionTool.class_ns_windows), ""), + node.get("{{{:}}}description".format(_attributes_ns), ""), + node.get('{{{:}}}screencoord'.format(_component_ns), ""), + node.get('{{{:}}}size'.format(_component_ns), "") + ) + ) + + return "\n".join(linearized_accessibility_tree) + + @staticmethod + def filter_nodes(root: ET, platform="ubuntu", check_image=False): + filtered_nodes = [] + + for node in root.iter(): + if OSWorldActionTool.judge_node(node, platform, check_image): + filtered_nodes.append(node) + # print(ET.tostring(node, encoding="unicode")) + + return filtered_nodes + + @staticmethod + def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool: + if platform == "ubuntu": + _state_ns = OSWorldActionTool.state_ns_ubuntu + _component_ns = OSWorldActionTool.component_ns_ubuntu + elif platform == "windows": + _state_ns = OSWorldActionTool.state_ns_windows + _component_ns = OSWorldActionTool.component_ns_windows + else: + raise ValueError("Invalid platform, must be 'ubuntu' or 'windows'") + + keeps: bool = node.tag.startswith("document") \ + or node.tag.endswith("item") \ + or node.tag.endswith("button") \ + or node.tag.endswith("heading") \ + or node.tag.endswith("label") \ + or node.tag.endswith("scrollbar") \ + or node.tag.endswith("searchbox") \ + or node.tag.endswith("textbox") \ + or node.tag.endswith("link") \ + or node.tag.endswith("tabelement") \ + or node.tag.endswith("textfield") \ + or node.tag.endswith("textarea") \ + or node.tag.endswith("menu") \ + or node.tag in {"alert", "canvas", "check-box" + , "combo-box", "entry", "icon" + , "image", "paragraph", "scroll-bar" + , "section", "slider", "static" + , "table-cell", "terminal", "text" + , "netuiribbontab", "start", "trayclockwclass" + , "traydummysearchcontrol", "uiimage", "uiproperty" + , "uiribboncommandbar" + } + keeps = keeps and ( + platform == "ubuntu" + and node.get("{{{:}}}showing".format(_state_ns), "false") == "true" + and node.get("{{{:}}}visible".format(_state_ns), "false") == "true" + or platform == "windows" + and node.get("{{{:}}}visible".format(_state_ns), "false") == "true" + ) \ + and ( + node.get("{{{:}}}enabled".format(_state_ns), "false") == "true" + or node.get("{{{:}}}editable".format(_state_ns), "false") == "true" + or node.get("{{{:}}}expandable".format(_state_ns), "false") == "true" + or node.get("{{{:}}}checkable".format(_state_ns), "false") == "true" + ) \ + and ( + node.get("name", "") != "" or node.text is not None and len(node.text) > 0 \ + or check_image and node.get("image", "false") == "true" + ) + + coordinates: Tuple[int, int] = eval(node.get("{{{:}}}screencoord".format(_component_ns), "(-1, -1)")) + sizes: Tuple[int, int] = eval(node.get("{{{:}}}size".format(_component_ns), "(-1, -1)")) + keeps = keeps and coordinates[0] >= 0 and coordinates[1] >= 0 and sizes[0] > 0 and sizes[1] > 0 + return keeps + + @staticmethod + def trim_accessibility_tree(linearized_accessibility_tree, max_tokens): + enc = tiktoken.encoding_for_model("gpt-4") + tokens = enc.encode(linearized_accessibility_tree) + if len(tokens) > max_tokens: + linearized_accessibility_tree = enc.decode(tokens[:max_tokens]) + linearized_accessibility_tree += "[...]\n" + return linearized_accessibility_tree \ No newline at end of file