Commit

Log progress/status to Prime Intellect
manveerxyz committed Sep 10, 2024
1 parent ad3a344 commit c54f87b
Showing 4 changed files with 204 additions and 4 deletions.
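
This commit wires a Prime Intellect progress logger into OpenDiLoCo: a new pi_progress_logger module posts run metadata and training-progress updates to Prime Intellect's API, and both the hivemind progress tracker and the FSDP training script call into it. Below is a minimal sketch of the intended call sequence, based on the diff that follows; the peer ID, multiaddrs, and config values are placeholders, and in train_fsdp.py they come from the hivemind DHT and the training Config object.

import time

from open_diloco.pi_progress_logger import (
    init_pi_progress_logger,
    log_progress_to_pi,
    finish_pi_progress_logger,
)

# Once at startup, right after the DHT is up: registers the run and obtains a run_id.
init_pi_progress_logger(
    peer_id="QmPeerIdPlaceholder",        # placeholder; train_fsdp.py passes dht.peer_id
    project="hivemind_debug",
    config={"lr": 4e-4},                  # placeholder; train_fsdp.py passes config.model_dump()
    maddrs=["/ip4/127.0.0.1/tcp/30001"],  # placeholder; train_fsdp.py passes dht.get_visible_maddrs(latest=True)
)

# During training, whenever metrics are available (entries are batched and POSTed in groups of 10).
log_progress_to_pi({"update_type": "global", "epoch": 100, "loss": 2.7, "time": time.time()})

# Once at the end of the run (rank 0 only in train_fsdp.py): flushes remaining logs and closes the run.
finish_pi_progress_logger()
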
40 changes: 38 additions & 2 deletions open_diloco/hivemind_diloco.py
@@ -29,6 +29,7 @@
from hivemind.optim.optimizer import logger
from hivemind.optim.progress_tracker import LocalTrainingProgress

from open_diloco.pi_progress_logger import log_progress_to_pi
from open_diloco.utils import found_inf_grad


@@ -206,10 +207,31 @@ def local_step(self) -> int:
    @property
    def real_step(self) -> int:
        return self.local_step + self.local_progress.epoch * self.batch_size

    def report_local_progress(self, local_epoch: int, samples_accumulated: int, loss: Optional[float] = None):
        """
        Update the number of locally accumulated samples and notify other peers about it.
        This just calls the parent method, but additionally logs the status to Prime Intellect.
        """
        super().report_local_progress(local_epoch, samples_accumulated)  # this updates self.local_progress
        if not self.client_mode:
            log_progress_to_pi(
                {
                    "update_type": "local",
                    "epoch": self.local_progress.epoch,
                    "local_step": self.local_step,
                    "total_local_step": self.real_step,
                    "local_loss": loss,
                    "samples_accumulated": self.local_progress.samples_accumulated,
                    "samples_per_second": self.local_progress.samples_per_second,
                    "time": self.local_progress.time,
                }
            )

    def _parse_swarm_progress_data(self, metadata: TrainingProgressSchema) -> GlobalTrainingProgress:
        """Read performance statistics reported by peers and estimate progress towards the next batch.
        This function is copy-pasted from hivemind. The only difference is that it fixes the ETA estimation.
        This function is copy-pasted from hivemind. The only differences are that it fixes the ETA estimation
        and that it reports progress to Prime Intellect.
        """
        current_time = get_dht_time()

@@ -271,7 +293,7 @@ def _parse_swarm_progress_data(self, metadata: TrainingProgressSchema) -> Global
            f"{self.prefix} has taken {self.local_step} local steps. Peers: {num_peers}, epoch: {self.local_progress.epoch}, steps: {self.real_step}. ETA: {estimated_time_to_next_epoch:.2f}",
        )

        return GlobalTrainingProgress(
        global_progress = GlobalTrainingProgress(
            global_epoch,
            total_samples_accumulated,
            target_batch_size=self.target_batch_size,
@@ -280,6 +302,20 @@ def _parse_swarm_progress_data(self, metadata: TrainingProgressSchema) -> Global
            eta_next_epoch=current_time + estimated_time_to_next_epoch,
            next_fetch_time=current_time + time_to_next_fetch,
        )
        log_progress_to_pi(
            {
                "update_type": "global",
                "epoch": global_progress.epoch,
                "total_samples_accumulated": global_progress.samples_accumulated,
                "target_batch_size": global_progress.target_batch_size,
                "num_peers": global_progress.num_peers,
                "num_clients": global_progress.num_clients,
                "eta_next_epoch": global_progress.eta_next_epoch,
                "next_fetch_time": global_progress.next_fetch_time,
                "time": current_time,
            }
        )
        return global_progress


class AllReduceStrategy(Enum):
144 changes: 144 additions & 0 deletions open_diloco/pi_progress_logger.py
@@ -0,0 +1,144 @@
import os
import requests
from enum import Enum
from typing import Any
from multiaddr import Multiaddr
import json
import base64

class PrimeIntellectProgressLogger:
    """
    Logs the status of nodes, and training progress to Prime Intellect's API.
    """

    def __init__(self, peer_id, project, config, maddrs, *args, **kwargs):
        self.peer_id = str(peer_id)
        self.project = project
        self.config = self._serialize_payload(config)
        self.data = []
        self.batch_size = 10
        self.base_url = "https://protocol-api.primeintellect.ai/training_runs"

        self.maddrs = [str(maddr) for maddr in maddrs]
        self.run_id = self._initialize_run()

    def _serialize_payload(self, data):
        def serialize_custom(obj):
            if isinstance(obj, Enum):
                return obj.name
            elif isinstance(obj, Multiaddr):
                return str(obj)
            elif isinstance(obj, bytes):
                return base64.b64encode(obj).decode('utf-8')
            raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")

        return json.loads(json.dumps(data, default=serialize_custom))

    def _initialize_run(self):
        headers = {
            "Content-Type": "application/json"
        }
        payload = {
            "project": self.project,
            "config": self.config,
            "peer_maddrs": self.maddrs,
            "peer_id": self.peer_id
        }
        api = f"{self.base_url}/init"
        try:
            response = requests.post(api, json=payload, headers=headers)
            response.raise_for_status()
            response_data = response.json()
            run_id = response_data.get('run_id')
            if run_id:
                print(f"Successfully initialized Prime Protocol API. Run ID: {run_id}")
                return run_id
            else:
                raise ValueError("No run ID returned from Prime Protocol API")
        except requests.RequestException as e:
            print(f"Failed to initialize Prime Protocol API: {e}")
            return None

    def _remove_duplicates(self):
        seen = set()
        unique_logs = []
        for log in self.data:
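            # NOTE: this dedup assumes every value in a log entry is hashable (numbers/strings);
            # a nested dict or list would make tuple(sorted(log.items())) raise a TypeError.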
            log_tuple = tuple(sorted(log.items()))
            if log_tuple not in seen:
                unique_logs.append(log)
                seen.add(log_tuple)
        self.data = unique_logs

    def log(self, data: dict[str, Any]):
        serialized_data = self._serialize_payload(data)
        # Add peer_id to log data, so that logs can be associated with the correct node
        serialized_data['peer_id'] = self.peer_id
        self.data.append(serialized_data)
        if len(self.data) >= self.batch_size:
            self._remove_duplicates()  # Remove duplicates before sending
            self._send_batch()

    def _send_batch(self):
        # Remove duplicates before sending
        self._remove_duplicates()

        # Send batch of logs to Prime Intellect's API endpoint
        batch = self.data[:self.batch_size]
        headers = {
            "Content-Type": "application/json"
        }
        payload = {
            "run_id": self.run_id,
            "logs": batch
        }
        api = f"{self.base_url}/logs"
        try:
            response = requests.post(api, json=payload, headers=headers)
            response.raise_for_status()
            print(f"Successfully sent batch of {len(batch)} logs to Prime Protocol API")
        except requests.RequestException as e:
            print(f"Failed to send logs to Prime Protocol API: {e}")

        self.data = self.data[self.batch_size:]

    def _finish(self):
        headers = {
            "Content-Type": "application/json"
        }
        api = f"{self.base_url}/{self.run_id}/finish"
        try:
            response = requests.post(api, headers=headers)
            response.raise_for_status()
            print(f"Successfully called finish endpoint for run ID: {self.run_id}")
        except requests.RequestException as e:
            print(f"Failed to call finish endpoint: {e}")

    def finish(self):
        # Remove duplicates before sending any remaining logs
        self._remove_duplicates()

        # Send any remaining logs
        while self.data:
            self._send_batch()

        self._finish()

_progress_logger = None

def init_pi_progress_logger(peer_id, project, config, *args, **kwargs):
    global _progress_logger
    _progress_logger = PrimeIntellectProgressLogger(peer_id, project, config, *args, **kwargs)

def get_pi_progress_logger():
    global _progress_logger
    if _progress_logger is None:
        raise ValueError("Progress logger has not been initialized. Please call init_pi_progress_logger first.")
    return _progress_logger

def log_progress_to_pi(data: dict[str, Any]):
    logger = get_pi_progress_logger()
    logger.log(data)

def finish_pi_progress_logger():
    logger = get_pi_progress_logger()
    logger.finish()
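
For reference, the requests this logger sends look roughly as follows. This is inferred from _initialize_run, _send_batch, and _finish above; the field names come from the code, while the concrete values are illustrative only.

# POST {base_url}/init — registers the run; the response is expected to contain a run_id.
init_payload = {
    "project": "hivemind_debug",
    "config": {},                      # JSON-serialized training config (Enums/Multiaddrs/bytes stringified)
    "peer_maddrs": ["/ip4/127.0.0.1/tcp/30001/p2p/QmPeer"],
    "peer_id": "QmPeer",
}

# POST {base_url}/logs — sent whenever 10 entries have accumulated (and again on finish()).
logs_payload = {
    "run_id": "run_123",
    "logs": [
        {"update_type": "local", "epoch": 2, "local_step": 17, "local_loss": 2.31, "peer_id": "QmPeer"},
        {"update_type": "global", "epoch": 2, "num_peers": 4, "total_samples_accumulated": 8192, "peer_id": "QmPeer"},
    ],
}

# POST {base_url}/{run_id}/finish — no body; marks the run as finished.
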
19 changes: 18 additions & 1 deletion open_diloco/train_fsdp.py
@@ -47,6 +47,11 @@
)
from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer
from open_diloco.utils import WandbLogger, DummyLogger
from open_diloco.pi_progress_logger import (
    init_pi_progress_logger,
    finish_pi_progress_logger,
    log_progress_to_pi,
)

from hivemind.dht.dht import DHT
from hivemind.utils.networking import log_visible_maddrs
@@ -121,6 +126,7 @@ class Config(BaseConfig):
    # Checkpointing and logging
    project: str = "hivemind_debug"
    metric_logger_type: Literal["wandb", "dummy"] = "wandb"
    status_logger_type: Literal["prime", "dummy"] = "prime"
    log_activations_steps: int | None = None
    ckpt: CkptConfig = CkptConfig()
    # Hivemind
@@ -209,7 +215,9 @@ def train(config: Config):
            host_maddrs=config.hv.host_maddrs,
            announce_maddrs=config.hv.announce_maddrs,
        )
        log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=False)
        maddrs = dht.get_visible_maddrs(latest=True)
        log_visible_maddrs(maddrs, only_p2p=False)
        init_pi_progress_logger(dht.peer_id, config.project, config.model_dump(), maddrs=maddrs)

    if local_rank == 0:
        check_checkpoint_path_access(config.ckpt.path, rank, config.hv.world_rank if config.hv else None)
@@ -459,6 +467,14 @@ def scheduler_fn(opt):
                current_time = time.time()

                metric_logger.log(metrics)
                # log progress to prime intellect
                pi_update = metrics.copy()
                pi_update["update_type"] = "global"
                pi_update["epoch"] = real_step
                pi_update["time"] = time.time()
                # lowercase all keys for sending to pi
                pi_update = {k.lower(): v for k, v in pi_update.items()}
                log_progress_to_pi(pi_update)

                if config.hv is None:
                    log(
@@ -513,6 +529,7 @@ def scheduler_fn(opt):

    log("Training completed.")
    if rank == 0:
        finish_pi_progress_logger()
        metric_logger.finish()


5 changes: 4 additions & 1 deletion open_diloco/utils.py
@@ -8,6 +8,8 @@
from torch.distributed.fsdp import ShardingStrategy
from torch.utils.data import IterableDataset
import wandb
import os
import requests


_WRAPPED_NAME_TO_REMOVE = ["_forward_module.", "_fsdp_wrapped_module.", "_orig_mod."]
@@ -192,7 +194,8 @@ class DummyLogger:
    def __init__(self, project, config, *args, **kwargs):
        self.project = project
        self.config = config
        open(project, "a").close()  # Create an empty file at the project path
        filename = kwargs.get('filename', 'default.log')
        open(os.path.join(project, filename), "w").close()  # Create an empty file inside the project directory

        self.data = []

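
A hypothetical usage sketch for the updated DummyLogger (the directory name and filename below are made up): project is now treated as a directory and the constructor writes an empty placeholder file at os.path.join(project, filename), so that directory must already exist.

from open_diloco.utils import DummyLogger

# "logs" is assumed to be an existing directory; "dummy_run.log" uses the new optional filename kwarg.
metric_logger = DummyLogger("logs", {"run": "debug"}, filename="dummy_run.log")
metric_logger.log({"loss": 1.23})   # train_fsdp.py calls .log(metrics) on whichever logger type is configured
metric_logger.finish()              # and .finish() at the end of training
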
