Skip to content

Commit fb47d9b

Browse files
committed
format
1 parent 13c8499 commit fb47d9b

File tree

2 files changed

+18
-17
lines changed

2 files changed

+18
-17
lines changed

src/forge/actors/policy.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,6 @@
1717
import torch
1818
import torch.distributed.checkpoint as dcp
1919
import torchstore as ts
20-
21-
from forge.actors._torchstore_utils import (
22-
extract_param_name,
23-
get_dcp_whole_state_dict_key,
24-
get_param_key,
25-
get_param_prefix,
26-
load_tensor_from_dcp,
27-
)
28-
29-
from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
30-
from forge.data.sharding import VLLMSharding
31-
from forge.data_models.completion import Completion
32-
from forge.data_models.prompt import to_prompt
33-
from forge.interfaces import Policy as PolicyInterface
34-
from forge.observability.metrics import record_metric, Reduce
35-
from forge.observability.perf_tracker import Tracer
36-
from forge.types import ProcessConfig
3720
from monarch.actor import current_rank, endpoint, ProcMesh
3821
from torchstore.state_dict_utils import DELIM
3922
from vllm.config import VllmConfig
@@ -58,6 +41,23 @@
5841
from vllm.v1.structured_output import StructuredOutputManager
5942
from vllm.worker.worker_base import WorkerWrapperBase
6043

44+
from forge.actors._torchstore_utils import (
45+
extract_param_name,
46+
get_dcp_whole_state_dict_key,
47+
get_param_key,
48+
get_param_prefix,
49+
load_tensor_from_dcp,
50+
)
51+
52+
from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
53+
from forge.data.sharding import VLLMSharding
54+
from forge.data_models.completion import Completion
55+
from forge.data_models.prompt import to_prompt
56+
from forge.interfaces import Policy as PolicyInterface
57+
from forge.observability.metrics import record_metric, Reduce
58+
from forge.observability.perf_tracker import Tracer
59+
from forge.types import ProcessConfig
60+
6161
logger = logging.getLogger(__name__)
6262
logger.setLevel(logging.INFO)
6363

tests/integration_tests/test_policy_update.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def _load_config(config_path: str) -> DictConfig:
7777
cfg = resolve_hf_hub_paths(cfg)
7878
return cfg
7979

80+
8081
def _test_validate_params_unchanged(
8182
prev_params, curr_model, logger
8283
) -> Exception | None:

0 commit comments

Comments
 (0)