Commit 7dd4b1d

Author: Yifu Wang
Polish API documentation

1 parent: c8c5430

File tree

8 files changed: +115 −194 lines


docs/source/api_reference.rst

Lines changed: 4 additions & 1 deletion
@@ -3,4 +3,7 @@ API Reference
 
 .. autoclass:: torchsnapshot.Snapshot
    :members:
-   :undoc-members:
+
+.. autoclass:: torchsnapshot.StateDict
+
+.. autoclass:: torchsnapshot.RNGState

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
@@ -82,3 +82,4 @@
 }
 
 add_module_names = False
+autodoc_member_order = 'bysource'

docs/source/getting_started.rst

Lines changed: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ Objects within a snapshot can be efficiently accessed without fetching the entire
 Taking a Snapshot Asynchronously
 --------------------------------
 
-When host memory is abundant, users can leverage it with :func:`Snapshot.async_take() <torchsnapshot.Snapshot.async_take>` to allow training to resume before all storage I/O completes. :func:`Snapshot.async_take() <torchsnapshot.Snapshot.async_take>` return as soon as it stages the snapshot content in host RAM and schedules storage I/O in background. This can drastically reduce the time blocked for checkpointing especially when the underly storage is slow.
+When host memory is abundant, users can leverage it with :func:`Snapshot.async_take() <torchsnapshot.Snapshot.async_take>` to allow training to resume before all storage I/O completes. :func:`Snapshot.async_take() <torchsnapshot.Snapshot.async_take>` returns as soon as it stages the snapshot content in host RAM and schedules storage I/O in the background. This can drastically reduce the time blocked on checkpointing, especially when the underlying storage is slow.
 
 
 .. code-block:: Python
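The key property described above is that the snapshot content is staged in host RAM before the call returns, so later mutations to the application state cannot leak into the snapshot. A minimal sketch of that staging idea using `copy.deepcopy` and a background thread (illustrative only; torchsnapshot's actual staging pipeline is far more involved, and `async_take_sketch` is a made-up name):

```python
import copy
import threading

def async_take_sketch(app_state, sink):
    """Stage a consistent in-memory copy, then write it out in the
    background; return as soon as staging completes."""
    staged = copy.deepcopy(app_state)  # the consistency point
    t = threading.Thread(target=lambda: sink.update(staged))
    t.start()
    return t  # caller may join() later, akin to waiting on the handle

state = {"step": 0}
storage = {}
handle = async_take_sketch(state, storage)

state["step"] = 99  # mutation after return: not part of the snapshot
handle.join()
assert storage == {"step": 0}
```

This is why training may resume immediately: only the in-memory copy, not the live state, is consumed by the background I/O.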

docs/source/index.rst

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@ TorchSnapshot API
 
    getting_started.rst
    api_reference.rst
-   utilities.rst
 
 Examples
 --------

docs/source/utilities.rst

Lines changed: 0 additions & 17 deletions
This file was deleted.

torchsnapshot/rng_state.py

Lines changed: 19 additions & 12 deletions
@@ -12,25 +12,32 @@
 
 class RNGState:
     """
-    When captured in app state, it is guaranteed that rng states will be the
-    same after ``Snapshot.take`` and ``Snapshot.restore``.
+    A special stateful object for saving and restoring global RNG state.
 
-    ::
+    When captured in the application state, it is guaranteed that the global
+    RNG state is set to the same values after taking the snapshot and after
+    restoring from the snapshot.
+
+    Example:
 
-        app_state = {
-            "rng_state": RNGState(),
-        }
-        snapshot = Snapshot.take("foo/bar", app_state, backend=...)
-        after_take = torch.rand(1)
+    ::
 
-        snapshot.restore(app_state)
-        after_restore = torch.rand(1)
+        >>> Snapshot.take(
+        >>>     path="foo/bar",
+        >>>     app_state={"rng_state": RNGState()},
+        >>> )
+        >>> after_take = torch.rand(1)
 
-        torch.testing.assert_close(after_take, after_restore)
+        >>> # In the same process or in another process
+        >>> snapshot = Snapshot(path="foo/bar")
+        >>> snapshot.restore(app_state)
+        >>> after_restore = torch.rand(1)
 
-    TODO augment this to capture rng states other than torch.get_rng_state().
+        >>> torch.testing.assert_close(after_take, after_restore)
     """
 
+    # TODO: augment this to capture rng states other than torch.get_rng_state()
+
     def state_dict(self) -> Dict[str, torch.Tensor]:
         return {"rng_state": torch.get_rng_state()}
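The guarantee in this docstring can be demonstrated without torch by using Python's built-in `random` module as a stand-in for `torch.get_rng_state()` (a sketch only; the real `RNGState` captures torch's global RNG):

```python
import random

class RNGStateSketch:
    """Stand-in for torchsnapshot.RNGState, backed by Python's
    built-in random module instead of torch.get_rng_state()."""

    def state_dict(self):
        # Capture the global RNG state at snapshot time.
        return {"rng_state": random.getstate()}

    def load_state_dict(self, state_dict):
        # Reinstate the captured global RNG state on restore.
        random.setstate(state_dict["rng_state"])

app_state = {"rng_state": RNGStateSketch()}

# "Take": materialize the state dicts (what Snapshot.take persists).
saved = {key: obj.state_dict() for key, obj in app_state.items()}
after_take = random.random()

# ... training continues, the RNG advances ...
random.random()

# "Restore": load the persisted state dicts back.
for key, obj in app_state.items():
    obj.load_state_dict(saved[key])
after_restore = random.random()

assert after_take == after_restore  # identical draws after take and restore
```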

torchsnapshot/snapshot.py

Lines changed: 85 additions & 143 deletions
@@ -65,82 +65,20 @@
 
 class Snapshot:
     """
-    Snapshot represents the persisted program state at one point in time.
+    Create a reference to an existing snapshot.
 
-    Basic usage:
-    ::
+    Args:
+        path (str): The path to the snapshot. This should be the same as the
+            ``path`` argument used for :func:`Snapshot.take` when the snapshot
+            was taken.
 
-        # Define the program state
-        app_state = {"model": model, "optimizer": optimizer"}
+        pg (ProcessGroup, optional): The process group for the participants of
+            :meth:`Snapshot.restore`. If none, the default process group will be
+            used.
 
-        # At an appropriate time, persist the program state as a snapshot
-        snapshot = Snapshot.take(path=path, app_state=app_state)
-
-        # On resuming, restore the program state from a snapshot
-        snapshot.restore(app_state)
-
-    Overview:
-
-        At high level, torchsnapshot saves each value in state dicts as a
-        file/object in the corresponding storage system. It also saves a manifest
-        describing the persisted values and the structure of the original state
-        dict.
-
-        Comparing with :py:func:`torch.save` and :py:func:`torch.load`, torchsnapshot:
-
-        - Enables efficient random access of persisted model weights.
-
-        - Accelerates persistence by parallelizing writes.
-
-            - For replicated values, persistence is parallelized across ranks.
-
-        - Enables flexible yet robust elasticity (changing world size on
-          restore).
-
-    Elasticity:
-
-        Elasticity is implemented via correctly making persisted values
-        available to a newly joined rank, and having it correctly restores the
-        corresponding runtime objects from the persisted values.
-
-        For the purpose of elasticity, all persisted values fall into one of
-        the categories in [per-rank, replicated, sharded].
-
-        per-rank:
-
-            By default, all non-sharded values are treated as per-rank.
-
-            On save, the value is only saved by the owning rank.
-
-            On load, the value is only made available to the same rank.
-
-        replicated:
-
-            A user can suggest any non-sharded value as replicated via glob
-            patterns.
-
-            On save, the value is only saved once (can be by any rank).
-
-            On load, the value is made available to all ranks, including newly
-            joined ranks.
-
-        sharded:
-
-            Specific types are always treated as sharded (e.g. ShardedTensor).
-
-            On save, all shard-owning ranks save their shards.
-
-            On load, all shards are made available to all ranks, including
-            newly joined rank. All ranks can read from all shards for
-            restoring the runtime object from persisted values.
-            (ShardedTensor resharding is powered by torch.dist.checkpoint).
-
-        If all values within a snapshot are either replicated or sharded, the
-        snapshot is automatically reshard-able.
-
-        If a snapshot contains per-rank values, it cannot be resharded unless
-        the per-rank values are explicitly coerced to replicated on load.
+        storage_options (Dict[str, Any], optional): Additional keyword options
+            for the storage plugin to use. See each storage plugin's documentation
+            for customizations.
     """
 
     def __init__(
@@ -149,18 +87,6 @@ def __init__(
         pg: Optional[dist.ProcessGroup] = None,
         storage_options: Optional[Dict[str, Any]] = None,
     ) -> None:
-        """
-        Initializes the reference to an existing snapshot.
-
-        Args:
-            path: The location of the snapshot.
-            pg: The process group for the processes restoring from the snapshot.
-                When unspecified:
-                - If distributed is initialized, the global process group will be used.
-                - If distributed is not initialized, single process is assumed.
-            storage_options: Additional keyword options for the StoragePlugin to use.
-                See each StoragePlugin's documentation for customizations.
-        """
         self.path: str = path
         self.pg: Optional[dist.ProcessGroup] = pg
         self._metadata: Optional[SnapshotMetadata] = None
@@ -179,20 +105,43 @@ def take(
         ] = None,
     ) -> "Snapshot":
         """
-        Take a snapshot from the program state.
+        Takes a snapshot of the application state.
 
         Args:
-            app_state: The program state to take the snapshot from.
-            path: The location to save the snapshot.
-            pg: The process group for the processes taking the snapshot.
-                When unspecified:
-                - If distributed is initialized, the global process group will be used.
-                - If distributed is not initialized, single process is assumed.
-            replicated: A list of glob patterns for hinting the matching paths
-                as replicated. Note that patterns not specified by all ranks
-                are ignored.
-            storage_options: Additional keyword options for the StoragePlugin to use.
-                See each StoragePlugin's documentation for customizations.
+            app_state (Dict[str, Stateful]): The application state to persist.
+                It takes the form of a dictionary, with the keys being
+                user-defined strings and the values being stateful objects.
+                Stateful objects are objects that expose ``.state_dict()`` and
+                ``.load_state_dict()`` methods. Common PyTorch objects such as
+                :class:`torch.nn.Module`, :class:`torch.optim.Optimizer`, and
+                LR schedulers all qualify as stateful objects.
+
+            path (str): The location to save the snapshot. ``path`` can have a
+                URI prefix (e.g. ``s3://``) that specifies a storage backend.
+                If no URI prefix is supplied, ``path`` is assumed to be a file
+                system location. For distributed snapshot, if ``path`` is
+                inconsistent across participating ranks, the value specified by
+                rank 0 will be used. For multi-host snapshot, ``path`` needs to
+                be a location accessible by all hosts.
+
+                .. note:: ``path`` must **not** point to an existing snapshot.
+
+            pg (ProcessGroup, optional): The process group for the participants
+                of :meth:`Snapshot.take`. If none, the default process group will
+                be used.
+
+            replicated (List[str], optional): Glob patterns for marking
+                checkpoint content as replicated. Matching objects will be deduped
+                and load-balanced across ranks.
+
+                .. note:: The replication property is automatically inferred
+                    for ``DistributedDataParallel``. Only specify this argument
+                    if your model has fully replicated states but does not use
+                    ``DistributedDataParallel``.
+
+            storage_options (Dict[str, Any], optional): Additional keyword
+                options for the storage plugin to use. See each storage plugin's
+                documentation for customizations.
 
         Returns:
             The newly taken snapshot.
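The stateful contract described in the new ``app_state`` docs is just the presence of ``.state_dict()`` and ``.load_state_dict()``. A minimal illustration in plain Python (the ``Counter`` class is a made-up example, not part of torchsnapshot):

```python
class Counter:
    """A hypothetical stateful object: anything exposing state_dict()
    and load_state_dict() can be placed in app_state."""

    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1

    def state_dict(self):
        return {"steps": self.steps}

    def load_state_dict(self, state_dict):
        self.steps = state_dict["steps"]

# app_state maps user-defined keys to stateful objects,
# mirroring {"model": model, "optimizer": optimizer}.
counter = Counter()
counter.step()
counter.step()
app_state = {"progress": counter}

# Roughly what Snapshot.take persists for each stateful object:
persisted = {key: obj.state_dict() for key, obj in app_state.items()}
assert persisted == {"progress": {"steps": 2}}

# Roughly what Snapshot.restore does in a fresh process:
fresh = Counter()
fresh.load_state_dict(persisted["progress"])
assert fresh.steps == 2
```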
@@ -252,31 +201,23 @@ def async_take(
         ] = None,
     ) -> "PendingSnapshot":
         """
-        Asynchronously take a snapshot from the program state.
+        Asynchronously takes a snapshot from the application state.
 
-        This method creates a consistent snapshot of the app state (i.e.
-        changes to the app state after this method returns have no effect on
-        the snapshot). The asynchronicity is a result of performing storage I/O
-        in the background.
+        This function is identical to :func:`Snapshot.take`, except that it
+        returns early and performs as many I/O operations in the background as
+        possible, allowing training to resume early.
 
         Args:
-            app_state: The program state to take the snapshot from.
-            path: The location to save the snapshot.
-            pg: The process group for the processes taking the snapshot.
-                When unspecified:
-                - If distributed is initialized, the global process group will be used.
-                - If distributed is not initialized, single process is assumed.
-            replicated: A list of glob patterns for hinting the matching paths
-                as replicated. Note that patterns not specified by all ranks
-                are ignored.
-            storage_options: Additional keyword options for the StoragePlugin to use.
-                See each StoragePlugin's documentation for customizations.
+            app_state (Dict[str, Stateful]): Same as the ``app_state`` argument of :func:`Snapshot.take`.
+            path (str): Same as the ``path`` argument of :func:`Snapshot.take`.
+            pg (ProcessGroup, optional): Same as the ``pg`` argument of :func:`Snapshot.take`.
+            replicated (List[str], optional): Same as the ``replicated`` argument of :func:`Snapshot.take`.
+            storage_options (Dict[str, Any], optional): Same as the ``storage_options`` argument of :func:`Snapshot.take`.
 
         Returns:
-            A handle with which the newly taken snapshot can be obtained via
-            `.wait()`. Note that waiting on the handle is optional. The
-            snapshot will be committed regardless of whether `.wait()` is
-            invoked.
+            A handle to the pending snapshot. The handle exposes a
+            ``.done()`` method for querying the progress and a ``.wait()``
+            method for waiting for the snapshot's completion.
         """
         torch._C._log_api_usage_once("torchsnapshot.Snapshot.async_take")
         cls._validate_app_state(app_state)
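The ``.done()``/``.wait()`` semantics of the returned handle resemble a future. A rough sketch of that pattern with `concurrent.futures` (illustrative only; this is not torchsnapshot's ``PendingSnapshot`` implementation, and ``slow_write`` is a made-up stand-in for storage I/O):

```python
import time
from concurrent.futures import ThreadPoolExecutor

class PendingSnapshotSketch:
    """Illustrative stand-in for the handle returned by async_take:
    storage I/O runs in the background; done()/wait() query and join it."""

    def __init__(self, staged_content, write_fn):
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._future = self._executor.submit(write_fn, staged_content)

    def done(self):
        # Non-blocking progress query.
        return self._future.done()

    def wait(self):
        # Block until the background write completes, then return its result.
        result = self._future.result()
        self._executor.shutdown()
        return result

def slow_write(content):
    time.sleep(0.1)  # simulate storage I/O
    return f"committed {len(content)} entries"

pending = PendingSnapshotSketch({"model": b"...", "optimizer": b"..."}, slow_write)
# Training can resume here while I/O proceeds in the background.
assert pending.wait() == "committed 2 entries"
assert pending.done()
```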
@@ -436,11 +377,13 @@ def _take_impl(
 
     def restore(self, app_state: AppState) -> None:
         """
-        Restores the program state from the snapshot.
+        Restores the application state from the snapshot.
 
         Args:
-            app_state: The program state to restore from the snapshot.
-
+            app_state (Dict[str, Stateful]): The application state to restore.
+                ``app_state`` needs to be either identical to or a subset of the
+                ``app_state`` used for :func:`Snapshot.take` when the snapshot was
+                taken.
         """
         torch._C._log_api_usage_once("torchsnapshot.Snapshot.restore")
         self._validate_app_state(app_state)
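The "identical to or a subset of" rule can be sketched with plain dicts (``restore_subset`` and ``Box`` are hypothetical helpers for illustration, not torchsnapshot's API):

```python
def restore_subset(persisted, app_state):
    """Hypothetical helper: restore only the stateful objects present
    in app_state, which must be a subset of what was persisted."""
    for key, obj in app_state.items():
        if key not in persisted:
            raise KeyError(f"{key!r} was not part of the taken snapshot")
        obj.load_state_dict(persisted[key])

class Box:
    """Trivial stateful object holding a single value."""
    def __init__(self, value=None):
        self.value = value
    def state_dict(self):
        return {"value": self.value}
    def load_state_dict(self, sd):
        self.value = sd["value"]

# Snapshot taken with two entries...
persisted = {"model": {"value": 1}, "optimizer": {"value": 2}}

# ...restored with a subset: only the model is loaded.
model = Box()
restore_subset(persisted, {"model": model})
assert model.value == 1
```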
@@ -505,31 +448,25 @@ def read_object(
         memory_budget_bytes: Optional[int] = None,
     ) -> T:
         """
-        Read a persisted object from the snapshot's content.
-
-        The persisted object to read is specified by its path in the snapshot
-        metadata. Available paths can be obtained via `snapshot.get_manifest()`.
+        Reads an object from the snapshot's content.
 
-        A path in snapshot metadata follows the following format:
-
-            ``RANK/STATEFUL_NAME/STATE_DICT_KEY[/NESTED_CONTAINER_KEY...]``
+        Args:
+            path (str): The path to the target object within the snapshot.
+                ``path`` is equivalent to the target object's key in the
+                snapshot manifest and can be obtained via
+                :meth:`Snapshot.get_manifest`.
 
-        The rank only matters when the persisted object is "per-rank".
-        Arbitrary rank can be used when the persisted object is "replicated" or
-        "sharded".
+            obj_out (Any, optional): When specified, load the object in-place
+                into ``obj_out`` if in-place load is supported for the object's
+                type. Otherwise, ``obj_out`` is ignored.
 
-        If the persisted object is a sharded tensor, `obj_out` must be
-        supplied. The supplied tensor can be either a tensor or sharded tensor.
-        `read_object` will correctly populate `obj_out`'s data according to
-        sharding spec.
+                .. note::
+                    When the target object is a ``ShardedTensor``, ``obj_out``
+                    must be specified.
 
-        Args:
-            path: The path to the persisted object.
-            obj_out: If specified and the object type supports in-place load,
-                `read_object` will directly read the persisted object into
-                `obj_out`'s buffer.
-            memory_budget_bytes: When specified, the read operation will keep
-                the temporary memory buffer size below this threshold.
+            memory_budget_bytes (int, optional): When specified, the read
+                operation will keep the temporary memory buffer size below this
+                threshold.
 
         Returns:
             The object read from the snapshot's content.
@@ -595,10 +532,15 @@ def read_object(
 
     def get_manifest(self) -> Dict[str, Entry]:
         """
-        Returns the snapshot's manifest.
+        Returns the snapshot manifest.
+
+        Each entry in the dictionary corresponds to an object in the snapshot,
+        with the keys being the logical paths to the objects and the values
+        being the metadata describing the object. For distributed snapshots,
+        the manifest contains entries for objects saved by all ranks.
 
         Returns:
-            The snapshot's manifest.
+            The snapshot manifest.
         """
         return copy.deepcopy(self.metadata.manifest)
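Manifest keys are logical paths; the removed ``read_object`` docstring noted the format ``RANK/STATEFUL_NAME/STATE_DICT_KEY[/NESTED_CONTAINER_KEY...]``. A sketch of filtering such a manifest with `fnmatch`-style globs, the kind of pattern also used for the ``replicated`` argument (the manifest entries below are made up for illustration):

```python
from fnmatch import fnmatch

# Hypothetical manifest: logical paths -> metadata, following the
# RANK/STATEFUL_NAME/STATE_DICT_KEY format noted in the docs.
manifest = {
    "0/model/fc.weight": {"type": "Tensor"},
    "0/model/fc.bias": {"type": "Tensor"},
    "0/optimizer/state": {"type": "object"},
    "1/model/fc.weight": {"type": "Tensor"},
}

def select(manifest, pattern):
    """Return the sorted manifest keys matching a glob pattern.
    Note fnmatch's '*' also matches '/' (unlike shell globbing)."""
    return sorted(k for k in manifest if fnmatch(k, pattern))

assert select(manifest, "0/model/*") == ["0/model/fc.bias", "0/model/fc.weight"]
assert select(manifest, "*/model/fc.weight") == [
    "0/model/fc.weight",
    "1/model/fc.weight",
]
```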
