diff --git a/src/forge/actors/_torchstore_utils.py b/src/forge/actors/_torchstore_utils.py
index bc0d55c3b..2d14f7f30 100644
--- a/src/forge/actors/_torchstore_utils.py
+++ b/src/forge/actors/_torchstore_utils.py
@@ -10,6 +10,7 @@
 import torch
 import torch.distributed.checkpoint as dcp
 from torch.distributed.checkpoint.metadata import Metadata as DcpMeta
+from torchstore.transport.buffers import rdma_available
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -69,3 +70,8 @@ def extract_param_name(key: str) -> str:
 
 def get_dcp_whole_state_dict_key(policy_version: int) -> str:
     return f"{get_param_prefix(policy_version)}{KEY_DELIM}{DCP_WHOLE_STATE_TAG}"
+
+
+def rdma_enabled() -> bool:
+    """Return whether TorchStore thinks we're using RDMA."""
+    return rdma_available()
diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py
index 0dc385cc0..8f14b1d39 100644
--- a/src/forge/actors/generator.py
+++ b/src/forge/actors/generator.py
@@ -46,6 +46,7 @@
     get_param_key,
     get_param_prefix,
     load_tensor_from_dcp,
+    rdma_available,
 )
 
 from forge.controller import (
@@ -56,7 +57,6 @@
 )
 from forge.data_models.completion import Completion
 from forge.data_models.prompt import to_prompt
-from forge.env import TORCHSTORE_USE_RDMA
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 from forge.types import ProcessConfig
@@ -112,7 +112,7 @@ def __post_init__(self):
         self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
 
         if self.use_dcp_for_weight_sync is None:
-            self.use_dcp_for_weight_sync = not TORCHSTORE_USE_RDMA.get_value()
+            self.use_dcp_for_weight_sync = not rdma_available()
         logger.debug(f"{self.use_dcp_for_weight_sync=}")
 
     @endpoint
@@ -492,14 +492,16 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
         await stop_proc_mesh(actor._generator_proc)
 
     @endpoint
-    async def save_model_params(self):
-        """Used for debugging purpose. Save model parameters before weight update."""
-        await self.worker.save_model_params.call()
+    async def _test_save_model_params(self):
+        """Save model parameters before weight update; used for testing purposes only."""
+        logger.info("[Generator] save model parameters for testing.")
+        await self.worker._test_save_model_params.call()
 
     @endpoint
-    async def validate_model_params(self, validate_fn):
-        """Used for debugging purpose. Validate saved params using validate_fn."""
-        return await self.worker.validate_model_params.call(validate_fn)
+    async def _test_validate_model_params(self, validate_fn):
+        """Validate updated model params using validate_fn."""
+        logger.info("[Generator] start validating model parameters.")
+        return await self.worker._test_validate_model_params.call(validate_fn)
 
 
 @dataclass
@@ -512,6 +514,8 @@ class GeneratorWorker(ForgeActor):
     """
 
     vllm_config: VllmConfig
+    # TODO: Remove this test-only param
+    _test_prev_params = {}
 
     @endpoint
     async def setup(self):
@@ -601,19 +605,20 @@ async def update_weights(self, version: int) -> None:
         t.stop()
 
     @endpoint
-    async def save_model_params(self):
-        """Used for debugging purposes. Save model parameters before weight update."""
-        self._debug_saved_params = {}
+    async def _test_save_model_params(self):
+        """Save model parameters before weight update; used for testing purposes only."""
+        logger.info("[GeneratorWorker] save model parameters for testing.")
         for name, param in self.worker.model_runner.model.named_parameters():
-            self._debug_saved_params[name] = param.detach().cpu()
+            self._test_prev_params[name] = param.detach().cpu()
         logger.info(
             "[GeneratorWorker] finished saving model parameters, len = %d",
-            len(self._debug_saved_params),
+            len(self._test_prev_params),
         )
 
     @endpoint
-    async def validate_model_params(self, validate_fn):
-        """Used for debugging purposes. Validate saved params using validate_fn."""
+    async def _test_validate_model_params(self, validate_fn):
+        """Validate updated model params using validate_fn."""
+        logger.info("[GeneratorWorker] start validating model parameters.")
         return validate_fn(
-            self._debug_saved_params, self.worker.model_runner.model, logger
+            self._test_prev_params, self.worker.model_runner.model, logger
         )
diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index 71049bc52..c98c836fa 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -41,11 +41,11 @@
     DcpHandle,
     get_dcp_whole_state_dict_key,
     get_param_key,
+    rdma_available,
 )
 
 from forge.controller import ForgeActor
 from forge.data.utils import batch_to_device
-from forge.env import TORCHSTORE_USE_RDMA
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 
@@ -131,9 +131,7 @@ class RLTrainer(ForgeActor):
     # Non JobConfig-related fields
     loss: Callable = lambda logits, **targets: logits
     state_dict_key: str = "model_state_dict"
-    use_dcp: bool = (
-        TORCHSTORE_USE_RDMA.get_value() == 0
-    )  # torchstore currently only accepts 0 or 1
+    use_dcp: bool = not rdma_available()
     dcp_path: str = "forge_dcp_tmp"
 
     def __post_init__(self):
diff --git a/src/forge/env.py b/src/forge/env.py
index 1478909da..b698b8013 100644
--- a/src/forge/env.py
+++ b/src/forge/env.py
@@ -101,7 +101,7 @@ def get_value(self) -> Any:
 
 TORCHSTORE_USE_RDMA = EnvVar(
     name="TORCHSTORE_RDMA_ENABLED",
-    default=0,
+    default=1,
     description="Whether or not to use RDMA in TorchStore.",
 )
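Reviewer note, not part of the patch: the new `_test_validate_model_params` endpoint invokes the caller-supplied `validate_fn` as `validate_fn(prev_params, model, logger)`, where `prev_params` maps parameter names to the CPU copies captured by `_test_save_model_params`. A minimal sketch of a compatible `validate_fn` (the function name and the changed-parameter check are illustrative, not taken from this diff):

    import torch

    def assert_params_changed(prev_params, model, logger):
        """Hypothetical validate_fn: confirm at least one parameter changed after the update."""
        changed = 0
        for name, param in model.named_parameters():
            prev = prev_params.get(name)
            if prev is not None and not torch.equal(prev, param.detach().cpu()):
                changed += 1
        logger.info("validate_fn: %d parameters changed after weight update", changed)
        assert changed > 0, "expected at least one parameter to change after weight update"
        return changed

In a test, something like `await generator._test_validate_model_params.call(assert_params_changed)` after a weight update would exercise this path (exact invocation depends on how the Generator handle is obtained), mirroring how the endpoint forwards to GeneratorWorker above.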