Skip to content

Commit 7c3d393

Browse files
xinyuangui2, justinvyu
authored and committed
[train][V2] Implement Result::from_path in v2 (ray-project#58216)
## Description

This change implements `Result::from_path` in Ray Train v2. The method reconstructs a `Result` object from the checkpoints of a previous training run. The implementation leverages `CheckpointManager` and follows the logic in https://github.com/ray-project/ray/blob/master/python/ray/train/v2/_internal/execution/controller/controller.py#L512-L540

---------

Signed-off-by: xgui <xgui@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Co-authored-by: Justin Yu <justinvyu@anyscale.com>
1 parent 7cd04a3 commit 7c3d393

File tree

6 files changed

+331
-9
lines changed

6 files changed

+331
-9
lines changed

python/ray/train/v2/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ py_test(
343343

344344
py_test(
345345
name = "test_result",
346-
size = "small",
346+
size = "medium",
347347
srcs = ["tests/test_result.py"],
348348
env = {"RAY_TRAIN_V2_ENABLED": "1"},
349349
tags = [

python/ray/train/v2/_internal/execution/callback.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from ray.train.v2._internal.execution.training_report import _TrainingReport
55
from ray.train.v2.api.callback import RayTrainCallback
66
from ray.train.v2.api.config import ScalingConfig
7-
from ray.train.v2.api.result import Result
87
from ray.util.annotations import DeveloperAPI
98

109
if TYPE_CHECKING:
@@ -20,6 +19,7 @@
2019
WorkerGroupContext,
2120
WorkerGroupPollStatus,
2221
)
22+
from ray.train.v2.api.result import Result
2323

2424

2525
@DeveloperAPI
@@ -128,7 +128,7 @@ def before_controller_execute_resize_decision(
128128
"""Called before the controller executes a resize decision."""
129129
pass
130130

131-
def after_controller_finish(self, result: Result):
131+
def after_controller_finish(self, result: "Result"):
132132
"""Called after the training run completes, providing access to the final result.
133133
134134
Args:

python/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ def __init__(
9090
self._condition = asyncio.Condition()
9191
super().__init__(checkpoint_config)
9292
# If the snapshot is found, the checkpoint manager will restore its state.
93+
# TODO(xgui): CheckpointManager is used to save or restore the checkpoint manager state.
94+
# We should sanity check if we should see old state in the storage folder.
9395
self._maybe_load_state_from_storage()
9496

9597
def register_checkpoint(

python/ray/train/v2/api/result.py

Lines changed: 97 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,21 @@
33
from dataclasses import dataclass
44
from typing import Any, Dict, List, Optional, Tuple, Union
55

6+
import pandas as pd
67
import pyarrow
78

89
import ray
910
from ray.air.result import Result as ResultV1
11+
from ray.train import Checkpoint, CheckpointConfig
12+
from ray.train.v2._internal.constants import CHECKPOINT_MANAGER_SNAPSHOT_FILENAME
13+
from ray.train.v2._internal.execution.checkpoint.checkpoint_manager import (
14+
CheckpointManager,
15+
)
16+
from ray.train.v2._internal.execution.storage import (
17+
StorageContext,
18+
_exists_at_fs_path,
19+
get_fs_and_path,
20+
)
1021
from ray.train.v2.api.exceptions import TrainingFailedError
1122
from ray.util.annotations import Deprecated, PublicAPI
1223

@@ -15,11 +26,9 @@
1526

1627
@dataclass
1728
class Result(ResultV1):
18-
checkpoint: Optional["ray.train.Checkpoint"]
29+
checkpoint: Optional[Checkpoint]
1930
error: Optional[TrainingFailedError]
20-
best_checkpoints: Optional[
21-
List[Tuple["ray.train.Checkpoint", Dict[str, Any]]]
22-
] = None
31+
best_checkpoints: Optional[List[Tuple[Checkpoint, Dict[str, Any]]]] = None
2332

2433
@PublicAPI(stability="alpha")
2534
def get_best_checkpoint(
@@ -33,7 +42,90 @@ def from_path(
3342
path: Union[str, os.PathLike],
3443
storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
3544
) -> "Result":
36-
raise NotImplementedError("`Result.from_path` is not implemented yet.")
45+
"""Restore a training result from a previously saved training run path.
46+
47+
Args:
48+
path: Path to the run output directory
49+
storage_filesystem: Optional filesystem to use for accessing the path
50+
51+
Returns:
52+
Result object with restored checkpoints and metrics
53+
"""
54+
fs, fs_path = get_fs_and_path(str(path), storage_filesystem)
55+
56+
# Validate that the experiment directory exists
57+
if not _exists_at_fs_path(fs, fs_path):
58+
raise RuntimeError(f"Experiment folder {fs_path} doesn't exist.")
59+
60+
# Remove trailing slashes to handle paths correctly
61+
# os.path.basename() returns empty string for paths with trailing slashes
62+
fs_path = fs_path.rstrip("/")
63+
storage_path, experiment_dir_name = os.path.dirname(fs_path), os.path.basename(
64+
fs_path
65+
)
66+
67+
storage_context = StorageContext(
68+
storage_path=storage_path,
69+
experiment_dir_name=experiment_dir_name,
70+
storage_filesystem=fs,
71+
)
72+
73+
# Validate that the checkpoint manager snapshot file exists
74+
if not _exists_at_fs_path(
75+
storage_context.storage_filesystem,
76+
storage_context.checkpoint_manager_snapshot_path,
77+
):
78+
raise RuntimeError(
79+
f"Failed to restore the Result object: "
80+
f"{CHECKPOINT_MANAGER_SNAPSHOT_FILENAME} doesn't exist in the "
81+
f"experiment folder. Make sure that this is an output directory created by a Ray Train run."
82+
)
83+
84+
checkpoint_manager = CheckpointManager(
85+
storage_context=storage_context,
86+
checkpoint_config=CheckpointConfig(),
87+
)
88+
89+
# When we build a Result object from checkpoints, the error is not loaded.
90+
return cls._from_checkpoint_manager(
91+
checkpoint_manager=checkpoint_manager,
92+
storage_context=storage_context,
93+
)
94+
95+
@classmethod
96+
def _from_checkpoint_manager(
97+
cls,
98+
checkpoint_manager: CheckpointManager,
99+
storage_context: StorageContext,
100+
error: Optional[TrainingFailedError] = None,
101+
) -> "Result":
102+
"""Create a Result object from a CheckpointManager."""
103+
latest_checkpoint_result = checkpoint_manager.latest_checkpoint_result
104+
if latest_checkpoint_result:
105+
latest_metrics = latest_checkpoint_result.metrics
106+
latest_checkpoint = latest_checkpoint_result.checkpoint
107+
else:
108+
latest_metrics = None
109+
latest_checkpoint = None
110+
best_checkpoints = [
111+
(r.checkpoint, r.metrics)
112+
for r in checkpoint_manager.best_checkpoint_results
113+
]
114+
115+
# Provide the history of metrics attached to checkpoints as a dataframe.
116+
metrics_dataframe = None
117+
if best_checkpoints:
118+
metrics_dataframe = pd.DataFrame([m for _, m in best_checkpoints])
119+
120+
return Result(
121+
metrics=latest_metrics,
122+
checkpoint=latest_checkpoint,
123+
error=error,
124+
path=storage_context.experiment_fs_path,
125+
best_checkpoints=best_checkpoints,
126+
metrics_dataframe=metrics_dataframe,
127+
_storage_filesystem=storage_context.storage_filesystem,
128+
)
37129

38130
@property
39131
@Deprecated

python/ray/train/v2/tests/conftest.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import logging
22

3+
import boto3
34
import pytest
45

56
import ray
67
from ray import runtime_context
8+
from ray._common.test_utils import simulate_s3_bucket
79
from ray.cluster_utils import Cluster
810
from ray.train.v2._internal.constants import (
911
ENABLE_STATE_ACTOR_RECONCILIATION_ENV_VAR,
@@ -80,3 +82,25 @@ def mock_current_actor(self):
8082
)
8183

8284
yield
85+
86+
87+
@pytest.fixture
88+
def mock_s3_bucket_uri():
89+
from ray.air._internal.uri_utils import URI
90+
91+
port = 5002
92+
region = "us-west-2"
93+
with simulate_s3_bucket(port=port, region=region) as s3_uri:
94+
s3 = boto3.client(
95+
"s3", region_name=region, endpoint_url=f"http://localhost:{port}"
96+
)
97+
# Bucket name will be autogenerated/unique per test
98+
bucket_name = URI(s3_uri).name
99+
s3.create_bucket(
100+
Bucket=bucket_name,
101+
CreateBucketConfiguration={"LocationConstraint": region},
102+
)
103+
# Disable server HTTP request logging
104+
logging.getLogger("werkzeug").setLevel(logging.WARNING)
105+
yield s3_uri
106+
logging.getLogger("werkzeug").setLevel(logging.INFO)

0 commit comments

Comments
 (0)