Addition of significant docstrings #142

Open · wants to merge 9 commits into base: main

Changes from all commits
121 changes: 120 additions & 1 deletion common/batch.py
@@ -12,34 +12,96 @@


class BatchBase(Pipelineable, abc.ABC):
"""
A base class for batches used in pipelines.

Attributes:
None

"""
@abc.abstractmethod
def as_dict(self) -> Dict:
"""
Convert the batch into a dictionary representation.

Returns:
Dict: A dictionary representation of the batch.

Raises:
NotImplementedError: If the method is not implemented in a subclass.

"""
raise NotImplementedError

def to(self, device: torch.device, non_blocking: bool = False):
"""
Move the batch to the specified device.

Args:
device (torch.device): The target device.
non_blocking (bool, optional): Whether to use non-blocking transfers. Defaults to False.

Returns:
BatchBase: A new batch on the target device.

"""
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
return self.__class__(**args)

def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
"""
Record a CUDA stream for all tensors in the batch.

Args:
stream (torch.cuda.streams.Stream): The CUDA stream to record.

Returns:
None

"""
for feature_value in self.as_dict().values():
feature_value.record_stream(stream)
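
Where `record_stream` earns its keep is cross-stream prefetching. A minimal sketch, assuming a CUDA device and a `batch` of CPU tensors (neither appears in this diff):

```python
import torch

# Copy the batch on a side stream, then mark its tensors as in use on the
# default stream so the caching allocator does not recycle their memory
# while the copy is still in flight.
copy_stream = torch.cuda.Stream()
with torch.cuda.stream(copy_stream):
    gpu_batch = batch.to(torch.device("cuda"), non_blocking=True)

torch.cuda.current_stream().wait_stream(copy_stream)
gpu_batch.record_stream(torch.cuda.current_stream())
```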

def pin_memory(self):
"""
Pin memory for all tensors in the batch.

Returns:
BatchBase: A new batch with pinned memory.

"""
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.pin_memory()
return self.__class__(**args)
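
The two transfer helpers compose into the usual host-to-device pattern. A hedged sketch, with `cpu_batch` as an illustrative stand-in for any `BatchBase` instance:

```python
import torch

# Pinned (page-locked) memory is what makes non_blocking=True effective:
pinned_batch = cpu_batch.pin_memory()
gpu_batch = pinned_batch.to(torch.device("cuda"), non_blocking=True)
```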

def __repr__(self) -> str:
"""
Generate a string representation of the batch.

Returns:
str: A string representation of the batch.

"""
def obj2str(v):
return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"

return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])

@property
def batch_size(self) -> int:
"""
Get the batch size from the tensors in the batch.

Returns:
int: The batch size.

Raises:
Exception: If the batch size cannot be determined from the tensors.

"""
for tensor in self.as_dict().values():
if tensor is None:
continue
@@ -51,11 +113,32 @@ def batch_size(self) -> int:

@dataclass
class DataclassBatch(BatchBase):
"""
A batch class that uses dataclasses to define its fields.

Attributes:
None

"""
@classmethod
def feature_names(cls):
"""
Get the feature names of the dataclass.

Returns:
List[str]: A list of feature names.

"""
return list(cls.__dataclass_fields__.keys())
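
For illustration, a hypothetical subclass (not part of this diff) and what `feature_names` would report:

```python
import dataclasses
import torch

@dataclasses.dataclass
class RankingBatch(DataclassBatch):  # hypothetical example subclass
    features: torch.Tensor
    labels: torch.Tensor

print(RankingBatch.feature_names())  # ['features', 'labels']
```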

def as_dict(self):
"""
Convert the dataclass batch into a dictionary representation.

Returns:
Dict: A dictionary representation of the batch.

"""
return {
feature_name: getattr(self, feature_name)
for feature_name in self.feature_names()
@@ -64,7 +147,18 @@ def as_dict(self):

@staticmethod
def from_schema(name: str, schema):
"""Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
"""
Instantiate a custom batch subclass if all columns can be represented as a torch.Tensor.

Args:
name (str): The name of the custom batch class.
schema: The schema or structure of the batch.

Returns:
Type[DataclassBatch]: A custom batch class.

"""

return dataclasses.make_dataclass(
cls_name=name,
fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names],
@@ -73,6 +167,17 @@ def from_schema(name: str, schema):

@staticmethod
def from_fields(name: str, fields: dict):
"""
Create a custom batch subclass from a set of fields.

Args:
name (str): The name of the custom batch class.
fields (dict): A dictionary specifying the fields and their types.

Returns:
Type[DataclassBatch]: A custom batch class.

"""
return dataclasses.make_dataclass(
cls_name=name,
fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
@@ -81,5 +186,19 @@ def from_fields(name: str, fields: dict):


class DictionaryBatch(BatchBase, dict):
"""
A batch class that represents data as a dictionary.

Attributes:
None

"""
def as_dict(self) -> Dict:
"""
Convert the dictionary batch into a dictionary representation.

Returns:
Dict: A dictionary representation of the batch.

"""
return self
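
Putting the pieces above together, a hedged end-to-end sketch of `DataclassBatch.from_fields`; the field names and shapes are invented, and it assumes the generated class inherits from `DataclassBatch`, as the truncated `make_dataclass` calls suggest:

```python
import torch

# Build a batch class dynamically, then exercise the BatchBase helpers.
ExampleBatch = DataclassBatch.from_fields(
    "ExampleBatch",
    {"continuous": torch.Tensor, "labels": torch.Tensor},
)
batch = ExampleBatch(
    continuous=torch.randn(32, 8),
    labels=torch.zeros(32, 1),
)

print(batch.batch_size)  # 32, assuming the property reads the leading dimension
print(batch)             # one "name: size" line per field, via __repr__
```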
89 changes: 71 additions & 18 deletions common/checkpointing/snapshot.py
@@ -13,23 +13,30 @@


class Snapshot:
"""Checkpoints using torchsnapshot.

Also saves step to be updated by the training loop.

"""
Checkpoints using torchsnapshot. Also saves step to be updated by the training loop.
"""

def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
"""
Initializes a Snapshot object.

Args:
save_dir (str): Directory where checkpoints will be saved.
state (Dict[str, Any]): State dictionary containing checkpoint information.
"""
self.save_dir = save_dir
self.state = state
self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)

@property
def step(self):
"""Get the current training step."""
return self.state["extra_state"]["step"]

@step.setter
def step(self, step: int) -> None:
"""Set the current training step."""
self.state["extra_state"]["step"] = step

@property
@@ -41,7 +48,15 @@ def walltime(self, walltime: float) -> None:
self.state["extra_state"]["walltime"] = walltime

def save(self, global_step: int) -> "PendingSnapshot":
"""Saves checkpoint with given global_step."""
"""
Saves a checkpoint with a given global step.

Args:
global_step (int): The global step to associate with the checkpoint.

Returns:
PendingSnapshot: A pending snapshot object.
"""
path = os.path.join(self.save_dir, str(global_step))
logging.info(f"Saving snapshot global_step {global_step} to {path}.")
start_time = time.time()
@@ -58,7 +73,12 @@ def save(self, global_step: int) -> "PendingSnapshot":
return snapshot
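
A sketch of the intended checkpoint cadence; the model, optimizer, loop, and path are assumptions, not part of this diff:

```python
import torch

# Placeholder training objects; torchsnapshot can snapshot anything with
# state_dict()/load_state_dict(), such as modules and optimizers.
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

snapshot = Snapshot(
    save_dir="/tmp/checkpoints",
    state={"model": model, "optimizer": optimizer},
)

for step in range(1_000):
    # ... forward / backward / optimizer.step() ...
    if step % 100 == 0:
        snapshot.step = step  # keep the step in extra_state current
        pending = snapshot.save(global_step=step)  # returns a PendingSnapshot
```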

def restore(self, checkpoint: str) -> None:
"""Restores a given checkpoint."""
"""
Restores a given checkpoint.

Args:
checkpoint (str): Path to the checkpoint to restore.
"""
snapshot = torchsnapshot.Snapshot(path=checkpoint)
logging.info(f"Restoring snapshot from {snapshot.path}.")
start_time = time.time()
@@ -83,12 +103,17 @@ def get_torch_snapshot(
global_step: Optional[int] = None,
missing_ok: bool = False,
) -> torchsnapshot.Snapshot:
"""Get torch stateless snapshot, without actually loading it.
Args:
snapshot_path: path to the model snapshot
global_step: restores from this checkpoint if specified.
missing_ok: if True and checkpoints do not exist, returns without restoration.
"""
Get a torch stateless snapshot, without actually loading it.

Args:
snapshot_path (str): Path to the model snapshot.
global_step (int, optional): Restores from this checkpoint if specified.
missing_ok (bool): If True and checkpoints do not exist, returns without restoration.

Returns:
torchsnapshot.Snapshot: A torch snapshot object.
"""
path = get_checkpoint(snapshot_path, global_step, missing_ok)
logging.info(f"Loading snapshot from {path}.")
return torchsnapshot.Snapshot(path=path)
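
For example (the path is illustrative), the stateless snapshot can be inspected without restoring any state:

```python
# Open the checkpoint and list what it contains via its manifest.
snap = get_torch_snapshot("/tmp/checkpoints")
for entry_path in snap.get_manifest():
    print(entry_path)
```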
@@ -100,13 +125,14 @@ def load_snapshot_to_weight(
snapshot_emb_name: str,
weight_tensor,
) -> None:
"""Loads pretrained embedding from the snapshot to the model.
Utilise partial lodaing meachanism from torchsnapshot.
Args:
embedding_snapshot: Path to the snapshot containing pretrained embeddings (EBC).
snapshot_emb_name: Name of the layer in the *snapshot* model, containing the EBC.
weight_tensor: embeddings tensor of *current* model, where the embeddings will be loaded.
"""
Loads pretrained embedding from the snapshot to the model.

Args:
embedding_snapshot (torchsnapshot.Snapshot): Path to the snapshot containing pretrained embeddings (EBC).
snapshot_emb_name (str): Name of the layer in the snapshot model containing the EBC.
weight_tensor: Embeddings tensor of the current model where the embeddings will be loaded.
"""
start_time = time.time()
manifest = embedding_snapshot.get_manifest()
for path in manifest.keys():
@@ -209,7 +235,22 @@ def get_checkpoint(


def get_checkpoints(save_dir: str) -> List[str]:
"""Gets all checkpoints that have been fully written."""
"""
Get a list of fully written checkpoints in the specified directory.

This function retrieves a list of fully written checkpoints in the given directory.
Checkpoints that are considered fully written include those that have a
corresponding snapshot metadata file.

Args:
save_dir (str): The directory where checkpoints are stored.

Returns:
List[str]: A list of fully written checkpoint paths.

Note:
Checkpoints are sorted by their numeric filenames in ascending order.
"""
checkpoints = []
fs = infer_fs(save_dir)
if fs.exists(save_dir):
@@ -232,6 +273,18 @@ def wait_for_evaluators(
global_step: int,
timeout: int,
) -> None:
"""
Waits for all evaluators to finish and checks for their completion status.

Args:
save_dir (str): Directory where checkpoints are saved.
partition_names (List[str]): List of partition names to check for completion.
global_step (int): The global step for which to wait for evaluators.
timeout (int): Maximum time in seconds to wait for evaluators to finish.

Returns:
None: Nothing is returned; progress and results are logged.
"""
logging.info("Waiting for all evaluators to finish.")
start_time = time.time()

21 changes: 21 additions & 0 deletions common/device.py
@@ -5,6 +5,15 @@


def maybe_setup_tensorflow():
"""
Try to import TensorFlow and disable GPU devices if TensorFlow is available.

This function checks if TensorFlow is installed and, if so, disables GPU devices used by TensorFlow to avoid conflicts with PyTorch.

Returns:
None

"""
try:
import tensorflow as tf
except ImportError:
@@ -14,6 +23,18 @@ def maybe_setup_tensorflow():


def setup_and_get_device(tf_ok: bool = True) -> torch.device:
"""
Set up the distributed environment and get the appropriate torch device.

This function sets up the distributed environment using PyTorch's `dist.init_process_group` and retrieves the appropriate torch device based on GPU availability and local rank.

Args:
tf_ok (bool, optional): Whether to run `maybe_setup_tensorflow` to disable TensorFlow GPU devices. Defaults to True.

Returns:
torch.device: The torch device for the current process.

"""
if tf_ok:
maybe_setup_tensorflow()

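
A usage sketch for `setup_and_get_device`; the placeholder model and the torchrun launch are assumptions, and the truncated body above presumably selects the GPU matching LOCAL_RANK when one is available:

```python
import torch

# Typically launched with torchrun, which sets LOCAL_RANK for each process;
# the function then initializes the process group and picks the device.
device = setup_and_get_device(tf_ok=True)
model = torch.nn.Linear(8, 1).to(device)  # placeholder model
```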