Feat: Log progress/status to Prime Intellect PRI-442 #28

Closed
40 changes: 38 additions & 2 deletions open_diloco/hivemind_diloco.py
@@ -29,6 +29,7 @@
from hivemind.optim.optimizer import logger
from hivemind.optim.progress_tracker import LocalTrainingProgress

from open_diloco.pi_progress_logger import log_progress_to_pi
from open_diloco.utils import found_inf_grad


@@ -206,10 +207,31 @@ def local_step(self) -> int:
@property
def real_step(self) -> int:
return self.local_step + self.local_progress.epoch * self.batch_size

def report_local_progress(self, local_epoch: int, samples_accumulated: int, loss: Optional[float] = None):
"""
Update the number of locally accumulated samples and notify other peers about this.
This calls the parent method and additionally logs the status to Prime Intellect.
"""
super().report_local_progress(local_epoch, samples_accumulated) # this updates self.local_progress
if not self.client_mode:
log_progress_to_pi(
{
"update_type": "local",
"epoch": self.local_progress.epoch,
"local_step": self.local_step,
"total_local_step": self.real_step,
"local_loss": loss,
"samples_accumulated": self.local_progress.samples_accumulated,
"samples_per_second": self.local_progress.samples_per_second,
"time": self.local_progress.time,
}
)

def _parse_swarm_progress_data(self, metadata: TrainingProgressSchema) -> GlobalTrainingProgress:
"""Read performance statistics reported by peers, estimate progress towards next batch
This function is copy paste from hivemind. Only difference is that if fix the ETA estimation.
This function is copied from hivemind. The only differences are that it fixes the ETA estimation
and reports progress to Prime Intellect.
"""
current_time = get_dht_time()

@@ -271,7 +293,7 @@ def _parse_swarm_progress_data(self, metadata: TrainingProgressSchema) -> GlobalTrainingProgress:
f"{self.prefix} has taken {self.local_step} local steps. Peers: {num_peers}, epoch: {self.local_progress.epoch}, steps: {self.real_step}. ETA: {estimated_time_to_next_epoch:.2f}",
)

return GlobalTrainingProgress(
global_progress = GlobalTrainingProgress(
global_epoch,
total_samples_accumulated,
target_batch_size=self.target_batch_size,
@@ -280,6 +302,20 @@
eta_next_epoch=current_time + estimated_time_to_next_epoch,
next_fetch_time=current_time + time_to_next_fetch,
)
log_progress_to_pi(
{
"update_type": "global",
"epoch": global_progress.epoch,
"total_samples_accumulated": global_progress.samples_accumulated,
"target_batch_size": global_progress.target_batch_size,
"num_peers": global_progress.num_peers,
"num_clients": global_progress.num_clients,
"eta_next_epoch": global_progress.eta_next_epoch,
"next_fetch_time": global_progress.next_fetch_time,
"time": current_time,
}
)
return global_progress


class AllReduceStrategy(Enum):
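For reference, a minimal sketch of the two payload shapes these hooks hand to log_progress_to_pi. The field names mirror the calls in the diff above; the concrete values, and the batch size of 512, are invented for illustration:

```python
# Illustrative only -- field names mirror the diff above, values are invented.
# Assumes batch_size = 512, so real_step = local_step + epoch * batch_size.
local_update = {
    "update_type": "local",
    "epoch": 3,
    "local_step": 120,
    "total_local_step": 1656,        # 120 + 3 * 512
    "local_loss": 2.41,
    "samples_accumulated": 448,
    "samples_per_second": 87.5,
    "time": 1718000000.0,            # timestamp taken from local_progress
}

global_update = {
    "update_type": "global",
    "epoch": 3,
    "total_samples_accumulated": 4096,
    "target_batch_size": 8192,
    "num_peers": 8,
    "num_clients": 1,
    "eta_next_epoch": 1718000047.0,  # current_time + estimated_time_to_next_epoch
    "next_fetch_time": 1718000005.0, # current_time + time_to_next_fetch
    "time": 1718000000.0,
}
```

Both dicts are buffered and eventually POSTed by the logger introduced in the new module below.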
144 changes: 144 additions & 0 deletions open_diloco/pi_progress_logger.py
@@ -0,0 +1,144 @@
import requests
from enum import Enum
from typing import Any
from multiaddr import Multiaddr
from hivemind.optim.optimizer import logger
import json
import base64

class PrimeIntellectProgressLogger:
"""
Logs node status and training progress to Prime Intellect's API.
"""

def __init__(self, peer_id, project, config, maddrs, *args, **kwargs):
self.peer_id = str(peer_id)
self.project = project
self.config = self._serialize_payload(config)
self.data = []
self.batch_size = 10
self.base_url = "https://protocol-api.primeintellect.ai/training_runs"

self.maddrs = [str(maddr) for maddr in maddrs]
self.run_id = self._initialize_run()

def _serialize_payload(self, data):
def serialize_custom(obj):
if isinstance(obj, Enum):
return obj.name
elif isinstance(obj, Multiaddr):
return str(obj)
elif isinstance(obj, bytes):
return base64.b64encode(obj).decode('utf-8')
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")

return json.loads(json.dumps(data, default=serialize_custom))

def _initialize_run(self):
headers = {
"Content-Type": "application/json"
}
payload = {
"project": self.project,
"config": self.config,
"peer_maddrs": self.maddrs,
"peer_id": self.peer_id
}
api = f"{self.base_url}/init"
try:
response = requests.post(api, json=payload, headers=headers)
response.raise_for_status()
response_data = response.json()
run_id = response_data.get('run_id')
if run_id:
logger.info(f"Successfully initialized run on Prime Intellect API. Run ID: {run_id}")
return run_id
else:
raise ValueError("No run ID returned from Prime Intellect API")
except requests.RequestException as e:
logger.error(f"Failed to initialize run on Prime Intellect API: {e}")
return None

def _remove_duplicates(self):
seen = set()
unique_logs = []
for log in self.data:
log_tuple = tuple(sorted(log.items()))
if log_tuple not in seen:
unique_logs.append(log)
seen.add(log_tuple)
self.data = unique_logs

def log(self, data: dict[str, Any]):
serialized_data = self._serialize_payload(data)
# Add peer_id to log data, so that logs can be associated with the correct node
serialized_data['peer_id'] = self.peer_id
self.data.append(serialized_data)
if len(self.data) >= self.batch_size:
self._remove_duplicates() # Remove duplicates before sending
self._send_batch()

def _send_batch(self):
# Remove duplicates before sending
self._remove_duplicates()

# Send batch of logs to Prime Intellect's API endpoint
batch = self.data[:self.batch_size]
headers = {
"Content-Type": "application/json"
}
payload = {
"run_id": self.run_id,
"logs": batch
}
api = f"{self.base_url}/logs"
try:
response = requests.post(api, json=payload, headers=headers)
response.raise_for_status()
logger.debug(f"Successfully sent batch of {len(batch)} logs to Prime Intellect API")
except requests.RequestException as e:
logger.warning(f"Failed to send logs to Prime Intellect API: {e}")

self.data = self.data[self.batch_size:]

def _finish(self):
headers = {
"Content-Type": "application/json"
}
api = f"{self.base_url}/{self.run_id}/finish"
try:
response = requests.post(api, headers=headers)
response.raise_for_status()
logger.debug(f"Successfully called finish endpoint for run ID: {self.run_id}")
except requests.RequestException as e:
logger.warning(f"Failed to call finish endpoint: {e}")

def finish(self):
# Remove duplicates before sending any remaining logs
self._remove_duplicates()

# Send any remaining logs
while self.data:
self._send_batch()

self._finish()

_progress_logger = None

def init_pi_progress_logger(peer_id, project, config, *args, **kwargs):
global _progress_logger
_progress_logger = PrimeIntellectProgressLogger(peer_id, project, config, *args, **kwargs)

def get_pi_progress_logger():
global _progress_logger
if _progress_logger is None:
raise ValueError("Status logger has not been initialized. Please call init_status_logger first.")
return _progress_logger

def log_progress_to_pi(data: dict[str, Any]):
logger = get_pi_progress_logger()
logger.log(data)

def finish_pi_progress_logger():
logger = get_pi_progress_logger()
logger.finish()
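A minimal usage sketch of the module-level helpers defined above. The peer id, multiaddrs and config values are placeholders (in the training script they come from the hivemind DHT and the pydantic config, as the train_fsdp.py diff below shows), and note that the constructor issues a real POST to the protocol API's /init endpoint:

```python
from open_diloco.pi_progress_logger import (
    init_pi_progress_logger,
    log_progress_to_pi,
    finish_pi_progress_logger,
)

# Placeholder values -- in train_fsdp.py these come from dht.peer_id,
# dht.get_visible_maddrs(latest=True) and config.model_dump().
init_pi_progress_logger(
    peer_id="QmExamplePeerId",
    project="hivemind_debug",
    config={"lr": 4e-4, "total_batch_size": 512},
    maddrs=["/ip4/127.0.0.1/tcp/30001"],
)

# Each call appends to an in-memory buffer; once 10 unique entries have
# accumulated, they are POSTed to {base_url}/logs with the run_id from /init.
log_progress_to_pi({"update_type": "local", "epoch": 0, "local_step": 1})

# Flushes any remaining buffered logs and calls the /{run_id}/finish endpoint.
finish_pi_progress_logger()
```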
19 changes: 18 additions & 1 deletion open_diloco/train_fsdp.py
@@ -47,6 +47,11 @@
)
from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer
from open_diloco.utils import WandbLogger, DummyLogger
from open_diloco.pi_progress_logger import (
init_pi_progress_logger,
finish_pi_progress_logger,
log_progress_to_pi,
)

from hivemind.dht.dht import DHT
from hivemind.utils.networking import log_visible_maddrs
@@ -121,6 +126,7 @@ class Config(BaseConfig):
# Checkpointing and logging
project: str = "hivemind_debug"
metric_logger_type: Literal["wandb", "dummy"] = "wandb"
status_logger_type: Literal["prime", "dummy"] = "prime"
log_activations_steps: int | None = None
ckpt: CkptConfig = CkptConfig()
# Hivemind
@@ -209,7 +215,9 @@ def train(config: Config):
host_maddrs=config.hv.host_maddrs,
announce_maddrs=config.hv.announce_maddrs,
)
log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=False)
maddrs = dht.get_visible_maddrs(latest=True)
log_visible_maddrs(maddrs, only_p2p=False)
init_pi_progress_logger(dht.peer_id, config.project, config.model_dump(), maddrs=maddrs)

if local_rank == 0:
check_checkpoint_path_access(config.ckpt.path, rank, config.hv.world_rank if config.hv else None)
@@ -459,6 +467,14 @@ def scheduler_fn(opt):
current_time = time.time()

metric_logger.log(metrics)
# Log progress to Prime Intellect
pi_update = metrics.copy()
pi_update["update_type"] = "global"
pi_update["epoch"] = real_step
pi_update["time"] = time.time()
# Lowercase all keys before sending to Prime Intellect
pi_update = {k.lower(): v for k, v in pi_update.items()}
log_progress_to_pi(pi_update)

if config.hv is None:
log(
@@ -513,6 +529,7 @@ def scheduler_fn(opt):

log("Training completed.")
if rank == 0:
finish_pi_progress_logger()
metric_logger.finish()


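To make the metrics transformation above concrete, here is a small sketch with a hypothetical metrics dict; the metric names are placeholders, as the real dict is assembled earlier in train_fsdp.py:

```python
import time

# Hypothetical metrics -- the real dict is built earlier in the training loop.
metrics = {"Loss": 2.41, "Perplexity": 11.1, "Tokens_per_second": 15300}
real_step = 1656

pi_update = metrics.copy()
pi_update["update_type"] = "global"
pi_update["epoch"] = real_step
pi_update["time"] = time.time()
pi_update = {k.lower(): v for k, v in pi_update.items()}
# pi_update == {"loss": 2.41, "perplexity": 11.1, "tokens_per_second": 15300,
#               "update_type": "global", "epoch": 1656, "time": <unix timestamp>}
```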
4 changes: 3 additions & 1 deletion open_diloco/utils.py
@@ -8,6 +8,7 @@
from torch.distributed.fsdp import ShardingStrategy
from torch.utils.data import IterableDataset
import wandb
import os


_WRAPPED_NAME_TO_REMOVE = ["_forward_module.", "_fsdp_wrapped_module.", "_orig_mod."]
@@ -192,7 +193,8 @@ class DummyLogger:
def __init__(self, project, config, *args, **kwargs):
self.project = project
self.config = config
open(project, "a").close() # Create an empty file at the project path
filename = kwargs.get('filename', 'default.log')
open(os.path.join(project, filename), "w").close() # Create an empty log file inside the project directory

self.data = []

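A brief sketch of how the revised DummyLogger constructor behaves; the directory and filename here are made up, and filename is an optional keyword argument that falls back to 'default.log':

```python
from open_diloco.utils import DummyLogger

# "logs/" must already exist -- the constructor only creates a file inside it.
dummy = DummyLogger(project="logs", config={"lr": 4e-4}, filename="run0.log")
# Creates (or truncates) logs/run0.log and initialises an empty dummy.data buffer.
```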