From 64f9c7929daee6a67ed46440fae0e25798937f83 Mon Sep 17 00:00:00 2001 From: Manveer Date: Tue, 3 Sep 2024 17:48:59 -0700 Subject: [PATCH] Log progress/status to Prime Intellect --- open_diloco/hivemind_diloco.py | 40 ++++++++- open_diloco/pi_progress_logger.py | 144 ++++++++++++++++++++++++++++++ open_diloco/train_fsdp.py | 19 +++- open_diloco/utils.py | 4 +- 4 files changed, 203 insertions(+), 4 deletions(-) create mode 100644 open_diloco/pi_progress_logger.py diff --git a/open_diloco/hivemind_diloco.py b/open_diloco/hivemind_diloco.py index 308608b..bfed5fb 100644 --- a/open_diloco/hivemind_diloco.py +++ b/open_diloco/hivemind_diloco.py @@ -29,6 +29,7 @@ from hivemind.optim.optimizer import logger from hivemind.optim.progress_tracker import LocalTrainingProgress +from open_diloco.pi_progress_logger import log_progress_to_pi from open_diloco.utils import found_inf_grad @@ -206,10 +207,31 @@ def local_step(self) -> int: @property def real_step(self) -> int: return self.local_step + self.local_progress.epoch * self.batch_size + + def report_local_progress(self, local_epoch: int, samples_accumulated: int, loss: Optional[float] = None): + """ + Update the number of locally accumulated samples and notify to other peers about this. + This just calls the parent method, but additionally logs the status to Prime Intellect. + """ + super().report_local_progress(local_epoch, samples_accumulated) # this updates self.local_progress + if not self.client_mode: + log_progress_to_pi( + { + "update_type": "local", + "epoch": self.local_progress.epoch, + "local_step": self.local_step, + "total_local_step": self.real_step, + "local_loss": loss, + "samples_accumulated": self.local_progress.samples_accumulated, + "samples_per_second": self.local_progress.samples_per_second, + "time": self.local_progress.time, + } + ) def _parse_swarm_progress_data(self, metadata: TrainingProgressSchema) -> GlobalTrainingProgress: """Read performance statistics reported by peers, estimate progress towards next batch - This function is copy paste from hivemind. Only difference is that if fix the ETA estimation. + This function is copy paste from hivemind. Only difference is that if fix the ETA estimation, + and it reports progress to Prime Intellect. """ current_time = get_dht_time() @@ -271,7 +293,7 @@ def _parse_swarm_progress_data(self, metadata: TrainingProgressSchema) -> Global f"{self.prefix} has taken {self.local_step} local steps. Peers: {num_peers}, epoch: {self.local_progress.epoch}, steps: {self.real_step}. ETA: {estimated_time_to_next_epoch:.2f}", ) - return GlobalTrainingProgress( + global_progress = GlobalTrainingProgress( global_epoch, total_samples_accumulated, target_batch_size=self.target_batch_size, @@ -280,6 +302,20 @@ def _parse_swarm_progress_data(self, metadata: TrainingProgressSchema) -> Global eta_next_epoch=current_time + estimated_time_to_next_epoch, next_fetch_time=current_time + time_to_next_fetch, ) + log_progress_to_pi( + { + "update_type": "global", + "epoch": global_progress.epoch, + "total_samples_accumulated": global_progress.samples_accumulated, + "target_batch_size": global_progress.target_batch_size, + "num_peers": global_progress.num_peers, + "num_clients": global_progress.num_clients, + "eta_next_epoch": global_progress.eta_next_epoch, + "next_fetch_time": global_progress.next_fetch_time, + "time": current_time, + } + ) + return global_progress class AllReduceStrategy(Enum): diff --git a/open_diloco/pi_progress_logger.py b/open_diloco/pi_progress_logger.py new file mode 100644 index 0000000..7ee9108 --- /dev/null +++ b/open_diloco/pi_progress_logger.py @@ -0,0 +1,144 @@ +import requests +from enum import Enum +from typing import Any +from multiaddr import Multiaddr +from hivemind.optim.optimizer import logger +import json +import base64 + +class PrimeIntellectProgressLogger: + """ + Logs the status of nodes, and training progress to Prime Intellect's API. + """ + + def __init__(self, peer_id, project, config, maddrs, *args, **kwargs): + self.peer_id = str(peer_id) + self.project = project + self.config = self._serialize_payload(config) + self.data = [] + self.batch_size = 10 + self.base_url = "https://protocol-api.primeintellect.ai/training_runs" + + self.maddrs = [str(maddr) for maddr in maddrs] + self.run_id = self._initialize_run() + + def _serialize_payload(self, data): + def serialize_custom(obj): + if isinstance(obj, Enum): + return obj.name + elif isinstance(obj, Multiaddr): + return str(obj) + elif isinstance(obj, bytes): + return base64.b64encode(obj).decode('utf-8') + raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") + + return json.loads(json.dumps(data, default=serialize_custom)) + + def _initialize_run(self): + headers = { + "Content-Type": "application/json" + } + payload = { + "project": self.project, + "config": self.config, + "peer_maddrs": self.maddrs, + "peer_id": self.peer_id + } + api = f"{self.base_url}/init" + try: + response = requests.post(api, json=payload, headers=headers) + response.raise_for_status() + response_data = response.json() + run_id = response_data.get('run_id') + if run_id: + logger.info(f"Successfully initialized run on Prime Intellect API. Run ID: {run_id}") + return run_id + else: + raise ValueError("No run ID returned from Prime Intellect API") + except requests.RequestException as e: + logger.error(f"Failed to initialize run on Prime Intellect API: {e}") + return None + + def _remove_duplicates(self): + seen = set() + unique_logs = [] + for log in self.data: + log_tuple = tuple(sorted(log.items())) + if log_tuple not in seen: + unique_logs.append(log) + seen.add(log_tuple) + self.data = unique_logs + + def log(self, data: dict[str, Any]): + serialized_data = self._serialize_payload(data) + # Add peer_id to log data, so that logs can be associated with the correct node + serialized_data['peer_id'] = self.peer_id + self.data.append(serialized_data) + if len(self.data) >= self.batch_size: + self._remove_duplicates() # Remove duplicates before sending + self._send_batch() + + def _send_batch(self): + # Remove duplicates before sending + self._remove_duplicates() + + # Send batch of logs to Prime Intellect's API endpoint + batch = self.data[:self.batch_size] + headers = { + "Content-Type": "application/json" + } + payload = { + "run_id": self.run_id, + "logs": batch + } + api = f"{self.base_url}/logs" + try: + response = requests.post(api, json=payload, headers=headers) + response.raise_for_status() + logger.debug(f"Successfully sent batch of {len(batch)} logs to Prime Intellect API") + except requests.RequestException as e: + logger.warning(f"Failed to send logs to Prime Intellect API: {e}") + + self.data = self.data[self.batch_size:] + + def _finish(self): + headers = { + "Content-Type": "application/json" + } + api = f"{self.base_url}/{self.run_id}/finish" + try: + response = requests.post(api, headers=headers) + response.raise_for_status() + logger.debug(f"Successfully called finish endpoint for run ID: {self.run_id}") + except requests.RequestException as e: + logger.warning(f"Failed to call finish endpoint: {e}") + + def finish(self): + # Remove duplicates before sending any remaining logs + self._remove_duplicates() + + # Send any remaining logs + while self.data: + self._send_batch() + + self._finish() + +_progress_logger = None + +def init_pi_progress_logger(peer_id, project, config, *args, **kwargs): + global _progress_logger + _progress_logger = PrimeIntellectProgressLogger(peer_id, project, config, *args, **kwargs) + +def get_pi_progress_logger(): + global _progress_logger + if _progress_logger is None: + raise ValueError("Status logger has not been initialized. Please call init_status_logger first.") + return _progress_logger + +def log_progress_to_pi(data: dict[str, Any]): + logger = get_pi_progress_logger() + logger.log(data) + +def finish_pi_progress_logger(): + logger = get_pi_progress_logger() + logger.finish() diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 4d5ef3e..91d0d13 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -47,6 +47,11 @@ ) from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer from open_diloco.utils import WandbLogger, DummyLogger +from open_diloco.pi_progress_logger import ( + init_pi_progress_logger, + finish_pi_progress_logger, + log_progress_to_pi, +) from hivemind.dht.dht import DHT from hivemind.utils.networking import log_visible_maddrs @@ -121,6 +126,7 @@ class Config(BaseConfig): # Checkpointing and logging project: str = "hivemind_debug" metric_logger_type: Literal["wandb", "dummy"] = "wandb" + status_logger_type: Literal["prime", "dummy"] = "prime" log_activations_steps: int | None = None ckpt: CkptConfig = CkptConfig() # Hivemind @@ -209,7 +215,9 @@ def train(config: Config): host_maddrs=config.hv.host_maddrs, announce_maddrs=config.hv.announce_maddrs, ) - log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=False) + maddrs = dht.get_visible_maddrs(latest=True) + log_visible_maddrs(maddrs, only_p2p=False) + init_pi_progress_logger(dht.peer_id, config.project, config.model_dump(), maddrs=maddrs) if local_rank == 0: check_checkpoint_path_access(config.ckpt.path, rank, config.hv.world_rank if config.hv else None) @@ -459,6 +467,14 @@ def scheduler_fn(opt): current_time = time.time() metric_logger.log(metrics) + # log progress to prime intellect + pi_update = metrics.copy() + pi_update["update_type"] = "global" + pi_update["epoch"] = real_step + pi_update["time"] = time.time() + # lowercase all keys for sending to pi + pi_update = {k.lower(): v for k, v in pi_update.items()} + log_progress_to_pi(pi_update) if config.hv is None: log( @@ -513,6 +529,7 @@ def scheduler_fn(opt): log("Training completed.") if rank == 0: + finish_pi_progress_logger() metric_logger.finish() diff --git a/open_diloco/utils.py b/open_diloco/utils.py index 0ced218..2dc3ee5 100644 --- a/open_diloco/utils.py +++ b/open_diloco/utils.py @@ -8,6 +8,7 @@ from torch.distributed.fsdp import ShardingStrategy from torch.utils.data import IterableDataset import wandb +import os _WRAPPED_NAME_TO_REMOVE = ["_forward_module.", "_fsdp_wrapped_module.", "_orig_mod."] @@ -192,7 +193,8 @@ class DummyLogger: def __init__(self, project, config, *args, **kwargs): self.project = project self.config = config - open(project, "a").close() # Create an empty file at the project path + filename = kwargs.get('filename', 'default.log') + open(os.path.join(project, filename), "w").close() # Create an empty file at the project path self.data = []