Commit 0df6f81

Yifu Wang authored and facebook-github-bot committed

Polish API documentation (#123)

Summary: As title.

Pull Request resolved: #123
Reviewed By: ananthsub
Differential Revision: D40815842
Pulled By: yifuwang
fbshipit-source-id: ed10e82e0d79cec7f0a5ea307f400d69b3e09752

1 parent c728806 commit 0df6f81

File tree

8 files changed: +116 -196 lines changed

docs/source/api_reference.rst
Lines changed: 4 additions & 1 deletion

@@ -3,4 +3,7 @@ API Reference
 
 .. autoclass:: torchsnapshot.Snapshot
    :members:
-   :undoc-members:
+
+.. autoclass:: torchsnapshot.StateDict
+
+.. autoclass:: torchsnapshot.RNGState

docs/source/conf.py
Lines changed: 1 addition & 0 deletions

@@ -82,3 +82,4 @@
 }
 
 add_module_names = False
+autodoc_member_order = "bysource"

docs/source/getting_started.rst
Lines changed: 2 additions & 3 deletions

@@ -99,7 +99,7 @@
 Taking a Snapshot Asynchronously
 --------------------------------
 
-When host memory is abundant, users can leverage it with :func:`Snapshot.async_take() <torchsnapshot.Snapshot.async_take>` to allow training to resume before all storage I/O completes. :func:`Snapshot.async_take() <torchsnapshot.Snapshot.async_take>` return as soon as it stages the snapshot content in host RAM and schedules storage I/O in background. This can drastically reduce the time blocked for checkpointing especially when the underly storage is slow.
+When host memory is abundant, users can leverage it with :func:`Snapshot.async_take() <torchsnapshot.Snapshot.async_take>` to allow training to resume before all storage I/O completes. :func:`Snapshot.async_take() <torchsnapshot.Snapshot.async_take>` returns as soon as it stages the snapshot content in host RAM and schedules storage I/O in the background. This can drastically reduce the time blocked for checkpointing, especially when the underlying storage is slow.
 
 
 .. code-block:: Python
@@ -124,8 +124,7 @@
 Reproducibility
 ---------------
 
-TorchSnapshot provides a utility called :class:`RNGState <torchsnapshot.rng_state.RNGState>` to help users manage reproducibility. If an :class:`RNGState <torchsnapshot.rng_state.RNGState>` object is captured in the application state, TorchSnapshot ensures that the global RNG state is set to the same values after taking the snapshot and after restoring from the snapshot.
-
+TorchSnapshot provides a utility called :class:`RNGState <torchsnapshot.rng_state.RNGState>` to help users manage reproducibility. If an :class:`RNGState <torchsnapshot.rng_state.RNGState>` object is captured in the application state, TorchSnapshot ensures that the global RNG state is set to the same values after restoring from the snapshot as it was after taking the snapshot.
 
 .. code-block:: Python
 
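The stage-then-write behavior described above (copy the state into host memory, return early, finish storage I/O in the background) can be sketched in plain Python. This is an illustrative analogue, not torchsnapshot's implementation; ``PendingWrite``, ``async_save``, and the dict-backed ``storage`` are hypothetical names for the sketch.

```python
import copy
import threading


class PendingWrite:
    """Handle for a background write, analogous in spirit to a pending snapshot."""

    def __init__(self, state, write_fn):
        # Stage a deep copy synchronously, so later mutations of the live
        # state cannot affect what gets persisted.
        self._staged = copy.deepcopy(state)
        self._thread = threading.Thread(target=write_fn, args=(self._staged,))
        self._thread.start()

    def done(self):
        # True once the background write has finished.
        return not self._thread.is_alive()

    def wait(self):
        self._thread.join()


def async_save(state, storage):
    """Return immediately after staging; storage I/O runs in the background."""
    return PendingWrite(state, lambda staged: storage.update(staged))


storage = {}
state = {"step": 1}
pending = async_save(state, storage)
state["step"] = 2          # training resumes; the staged copy is unaffected
pending.wait()
print(storage["step"])     # prints 1: the snapshot reflects state at staging time
```

Note that, as in the documented API, waiting on the handle is optional for correctness of the write itself; here ``wait()`` only makes completion observable.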
docs/source/index.rst
Lines changed: 0 additions & 1 deletion

@@ -15,7 +15,6 @@ TorchSnapshot API
 
    getting_started.rst
    api_reference.rst
-   utilities.rst
 
 Examples
 --------

docs/source/utilities.rst

Lines changed: 0 additions & 17 deletions
This file was deleted.

torchsnapshot/rng_state.py
Lines changed: 19 additions & 12 deletions

@@ -12,25 +12,32 @@
 
 class RNGState:
     """
-    When captured in app state, it is guaranteed that rng states will be the
-    same after ``Snapshot.take`` and ``Snapshot.restore``.
+    A special stateful object for saving and restoring global RNG state.
 
-    ::
+    When captured in the application state, it is guaranteed that the global
+    RNG state is set to the same values after restoring from the snapshot as it
+    was after taking the snapshot.
+
+    Example:
 
-        app_state = {
-            "rng_state": RNGState(),
-        }
-        snapshot = Snapshot.take("foo/bar", app_state, backend=...)
-        after_take = torch.rand(1)
+    ::
 
-        snapshot.restore(app_state)
-        after_restore = torch.rand(1)
+        >>> Snapshot.take(
+        >>>     path="foo/bar",
+        >>>     app_state={"rng_state": RNGState()},
+        >>> )
+        >>> after_take = torch.rand(1)
 
-        torch.testing.assert_close(after_take, after_restore)
+        >>> # In the same process or in another process
+        >>> snapshot = Snapshot(path="foo/bar")
+        >>> snapshot.restore(app_state)
+        >>> after_restore = torch.rand(1)
 
-    TODO augment this to capture rng states other than torch.get_rng_state().
+        >>> torch.testing.assert_close(after_take, after_restore)
     """
 
+    # TODO: augment this to capture rng states other than torch.get_rng_state()
+
     def state_dict(self) -> Dict[str, torch.Tensor]:
         return {"rng_state": torch.get_rng_state()}

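The guarantee in the docstring above (identical draws after taking and after restoring) can be illustrated without torch, using Python's built-in ``random`` module to stand in for the global torch RNG. A minimal sketch, not library code:

```python
import random


class RNGState:
    """Minimal stateful object following the same pattern: capture the global
    RNG state in state_dict() and reinstate it in load_state_dict()."""

    def state_dict(self):
        return {"rng_state": random.getstate()}

    def load_state_dict(self, state_dict):
        random.setstate(state_dict["rng_state"])


rng = RNGState()
saved = rng.state_dict()        # "take": the global RNG state is captured
after_take = random.random()

# ... later, possibly in a fresh process after reloading `saved` ...
rng.load_state_dict(saved)      # "restore": the global RNG state is reinstated
after_restore = random.random()

assert after_take == after_restore  # the first draw after restore matches
```

Because the state is captured as part of the snapshot and reinstated on restore, training runs that resume from the snapshot draw the same random sequence as the original run did after taking it.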
torchsnapshot/snapshot.py
Lines changed: 85 additions & 143 deletions

@@ -65,82 +65,20 @@
 
 class Snapshot:
     """
-    Snapshot represents the persisted program state at one point in time.
+    Create a reference to an existing snapshot.
 
-    Basic usage:
-    ::
+    Args:
+        path (str): The path to the snapshot. This should be the same as the
+            ``path`` argument used for :func:`Snapshot.take` when the snapshot
+            was taken.
 
-        # Define the program state
-        app_state = {"model": model, "optimizer": optimizer"}
+        pg (ProcessGroup, optional): The process group for the participants of
+            :meth:`Snapshot.restore`. If none, the default process group will be
+            used.
 
-        # At an appropriate time, persist the program state as a snapshot
-        snapshot = Snapshot.take(path=path, app_state=app_state)
-
-        # On resuming, restore the program state from a snapshot
-        snapshot.restore(app_state)
-
-    Overview:
-
-        At high level, torchsnapshot saves each value in state dicts as a
-        file/object in the corresponding storage system. It also saves a manifest
-        describing the persisted values and the structure of the original state
-        dict.
-
-        Comparing with :py:func:`torch.save` and :py:func:`torch.load`, torchsnapshot:
-
-        - Enables efficient random access of persisted model weights.
-
-        - Accelerates persistence by parallelizing writes.
-
-            - For replicated values, persistence is parallelized across ranks.
-
-        - Enables flexible yet robust elasticity (changing world size on
-          restore).
-
-    Elasticity:
-
-        Elasticity is implemented via correctly making persisted values
-        available to a newly joined rank, and having it correctly restores the
-        corresponding runtime objects from the persisted values.
-
-        For the purpose of elasticity, all persisted values fall into one of
-        the categories in [per-rank, replicated, sharded].
-
-        per-rank:
-
-            By default, all non-sharded values are treated as per-rank.
-
-            On save, the value is only saved by the owning rank.
-
-            On load, the value is only made available to the same rank.
-
-        replicated:
-
-            A user can suggest any non-sharded value as replicated via glob
-            patterns.
-
-            On save, the value is only saved once (can be by any rank).
-
-            On load, the value is made available to all ranks, including newly
-            joined ranks.
-
-        sharded:
-
-            Specific types are always treated as sharded (e.g. ShardedTensor).
-
-            On save, all shard-owning ranks save their shards.
-
-            On load, all shards are made available to all ranks, including
-            newly joined rank. All ranks can read from all shards for
-            restoring the runtime object from persisted values.
-            (ShardedTensor resharding is powered by torch.dist.checkpoint).
-
-        If all values within a snapshot are either replicated or sharded, the
-        snapshot is automatically reshard-able.
-
-        If a snapshot contains per-rank values, it cannot be resharded unless
-        the per-rank values are explicitly coerced to replicated on load.
+        storage_options (Dict[str, Any], optional): Additional keyword options
+            for the storage plugin to use. See each storage plugin's
+            documentation for customizations.
     """
 
     def __init__(
@@ -149,18 +87,6 @@ def __init__(
         pg: Optional[dist.ProcessGroup] = None,
         storage_options: Optional[Dict[str, Any]] = None,
     ) -> None:
-        """
-        Initializes the reference to an existing snapshot.
-
-        Args:
-            path: The location of the snapshot.
-            pg: The process group for the processes restoring from the snapshot.
-                When unspecified:
-                - If distributed is initialized, the global process group will be used.
-                - If distributed is not initialized, single process is assumed.
-            storage_options: Additional keyword options for the StoragePlugin to use.
-                See each StoragePlugin's documentation for customizations.
-        """
         self.path: str = path
         self.pg: Optional[dist.ProcessGroup] = pg
         self._metadata: Optional[SnapshotMetadata] = None
@@ -179,20 +105,43 @@ def take(
         ] = None,
     ) -> "Snapshot":
         """
-        Take a snapshot from the program state.
+        Takes a snapshot of the application state.
 
         Args:
-            app_state: The program state to take the snapshot from.
-            path: The location to save the snapshot.
-            pg: The process group for the processes taking the snapshot.
-                When unspecified:
-                - If distributed is initialized, the global process group will be used.
-                - If distributed is not initialized, single process is assumed.
-            replicated: A list of glob patterns for hinting the matching paths
-                as replicated. Note that patterns not specified by all ranks
-                are ignored.
-            storage_options: Additional keyword options for the StoragePlugin to use.
-                See each StoragePlugin's documentation for customizations.
+            app_state (Dict[str, Stateful]): The application state to persist.
+                It takes the form of a dictionary, with the keys being
+                user-defined strings and the values being stateful objects.
+                Stateful objects are objects that expose ``.state_dict()`` and
+                ``.load_state_dict()`` methods. Common PyTorch objects such as
+                :class:`torch.nn.Module`, :class:`torch.optim.Optimizer`, and
+                LR schedulers all qualify as stateful objects.
+
+            path (str): The location to save the snapshot. ``path`` can have a
+                URI prefix (e.g. ``s3://``) that specifies a storage backend.
+                If no URI prefix is supplied, ``path`` is assumed to be a file
+                system location. For distributed snapshot, if ``path`` is
+                inconsistent across participating ranks, the value specified by
+                rank 0 will be used. For multi-host snapshot, ``path`` needs to
+                be a location accessible by all hosts.
+
+                .. note:: ``path`` must **not** point to an existing snapshot.
+
+            pg (ProcessGroup, optional): The process group for the participants
+                of :meth:`Snapshot.take`. If none, the default process group
+                will be used.
+
+            replicated (List[str], optional): Glob patterns for marking
+                checkpoint content as replicated. Matching objects will be
+                deduped and load-balanced across ranks.
+
+                .. note:: The replication property is automatically inferred
+                    for ``DistributedDataParallel``. Only specify this argument
+                    if your model has fully replicated states but does not use
+                    ``DistributedDataParallel``.
+
+            storage_options (Dict[str, Any], optional): Additional keyword
+                options for the storage plugin to use. See each storage plugin's
+                documentation for customizations.
 
         Returns:
             The newly taken snapshot.
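The new ``app_state`` docstring above defines a stateful object as anything exposing ``.state_dict()`` and ``.load_state_dict()``. The protocol can be hand-rolled for custom state; the ``Counter`` class below is a hypothetical minimal example, not part of the library:

```python
class Counter:
    """A minimal stateful object: it exposes state_dict()/load_state_dict(),
    the same protocol torch.nn.Module and torch.optim.Optimizer follow."""

    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1

    def state_dict(self):
        # Everything needed to reconstruct this object's state.
        return {"steps": self.steps}

    def load_state_dict(self, state_dict):
        self.steps = state_dict["steps"]


# Anything following this protocol can be a value in app_state, e.g.
# app_state = {"model": model, "optimizer": optimizer, "progress": Counter()}
a = Counter()
a.step(); a.step()
b = Counter()
b.load_state_dict(a.state_dict())   # round-trip through the state dict
print(b.steps)  # prints 2
```

This is why modules, optimizers, and LR schedulers "all qualify": each already implements exactly this pair of methods.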
@@ -252,31 +201,23 @@ def async_take(
         ] = None,
     ) -> "PendingSnapshot":
         """
-        Asynchronously take a snapshot from the program state.
+        Asynchronously takes a snapshot of the application state.
 
-        This method creates a consistent snapshot of the app state (i.e.
-        changes to the app state after this method returns have no effect on
-        the snapshot). The asynchronicity is a result of performing storage I/O
-        in the background.
+        This function is identical to :func:`Snapshot.take`, except that it
+        returns early and performs as many I/O operations in the background as
+        possible, allowing training to resume early.
 
         Args:
-            app_state: The program state to take the snapshot from.
-            path: The location to save the snapshot.
-            pg: The process group for the processes taking the snapshot.
-                When unspecified:
-                - If distributed is initialized, the global process group will be used.
-                - If distributed is not initialized, single process is assumed.
-            replicated: A list of glob patterns for hinting the matching paths
-                as replicated. Note that patterns not specified by all ranks
-                are ignored.
-            storage_options: Additional keyword options for the StoragePlugin to use.
-                See each StoragePlugin's documentation for customizations.
+            app_state (Dict[str, Stateful]): Same as the ``app_state`` argument of :func:`Snapshot.take`.
+            path (str): Same as the ``path`` argument of :func:`Snapshot.take`.
+            pg (ProcessGroup, optional): Same as the ``pg`` argument of :func:`Snapshot.take`.
+            replicated (List[str], optional): Same as the ``replicated`` argument of :func:`Snapshot.take`.
+            storage_options (Dict[str, Any], optional): Same as the ``storage_options`` argument of :func:`Snapshot.take`.
 
         Returns:
-            A handle with which the newly taken snapshot can be obtained via
-            `.wait()`. Note that waiting on the handle is optional. The
-            snapshot will be committed regardless of whether `.wait()` is
-            invoked.
+            A handle to the pending snapshot. The handle exposes a
+            ``.done()`` method for querying the progress and a ``.wait()``
+            method for waiting for the snapshot's completion.
         """
         torch._C._log_api_usage_once("torchsnapshot.Snapshot.async_take")
         cls._validate_app_state(app_state)
@@ -436,11 +377,13 @@ def _take_impl(
 
     def restore(self, app_state: AppState) -> None:
         """
-        Restores the program state from the snapshot.
+        Restores the application state from the snapshot.
 
         Args:
-            app_state: The program state to restore from the snapshot.
-
+            app_state (Dict[str, Stateful]): The application state to restore.
+                ``app_state`` needs to be either identical to or a subset of the
+                ``app_state`` used for :func:`Snapshot.take` when the snapshot
+                was taken.
         """
         torch._C._log_api_usage_once("torchsnapshot.Snapshot.restore")
         self._validate_app_state(app_state)
@@ -505,31 +448,25 @@ def read_object(
         memory_budget_bytes: Optional[int] = None,
     ) -> T:
         """
-        Read a persisted object from the snapshot's content.
-
-        The persisted object to read is specified by its path in the snapshot
-        metadata. Available paths can be obtained via `snapshot.get_manifest()`.
+        Reads an object from the snapshot's content.
 
-        A path in snapshot metadata follows the following format:
-
-            ``RANK/STATEFUL_NAME/STATE_DICT_KEY[/NESTED_CONTAINER_KEY...]``
+        Args:
+            path (str): The path to the target object within the snapshot.
+                ``path`` is equivalent to the target object's key in the
+                snapshot manifest and can be obtained via
+                :meth:`Snapshot.get_manifest`.
 
-        The rank only matters when the persisted object is "per-rank".
-        Arbitrary rank can be used when the persisted object is "replicated" or
-        "sharded".
+            obj_out (Any, optional): When specified, load the object in-place
+                into ``obj_out`` if in-place load is supported for the object's
+                type. Otherwise, ``obj_out`` is ignored.
 
-        If the persisted object is a sharded tensor, `obj_out` must be
-        supplied. The supplied tensor can be either a tensor or sharded tensor.
-        `read_object` will correctly populate `obj_out`'s data according to
-        sharding spec.
+                .. note::
+                    When the target object is a ``ShardedTensor``, ``obj_out``
+                    must be specified.
 
-        Args:
-            path: The path to the persisted object.
-            obj_out: If specified and the object type supports in-place load,
-                `read_object` will directly read the persisted object into
-                `obj_out`'s buffer.
-            memory_budget_bytes: When specified, the read operation will keep
-                the temporary memory buffer size below this threshold.
+            memory_budget_bytes (int, optional): When specified, the read
+                operation will keep the temporary memory buffer size below this
+                threshold.
 
         Returns:
             The object read from the snapshot's content.
@@ -595,10 +532,15 @@ def read_object(
 
     def get_manifest(self) -> Dict[str, Entry]:
         """
-        Returns the snapshot's manifest.
+        Returns the snapshot manifest.
+
+        Each entry in the dictionary corresponds to an object in the snapshot,
+        with the keys being the logical paths to the objects and the values
+        being the metadata describing the object. For distributed snapshots,
+        the manifest contains entries for objects saved by all ranks.
 
         Returns:
-            The snapshot's manifest.
+            The snapshot manifest.
         """
         return copy.deepcopy(self.metadata.manifest)

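The manifest described above is a flat dictionary keyed by logical paths, which per the old ``read_object`` docstring take the form ``RANK/STATEFUL_NAME/STATE_DICT_KEY[/NESTED_CONTAINER_KEY...]``. Selecting the entries belonging to one stateful object is then plain dictionary filtering. The manifest contents and the ``entries_for`` helper below are made up for illustration; real entries are ``Entry`` metadata objects, not dicts:

```python
# A made-up manifest shaped like the logical path format described above:
# RANK/STATEFUL_NAME/STATE_DICT_KEY
manifest = {
    "0/model/weight": {"type": "Tensor"},
    "0/model/bias": {"type": "Tensor"},
    "0/optimizer/state": {"type": "object"},
    "1/model/weight": {"type": "Tensor"},
}


def entries_for(manifest, rank, stateful_name):
    """Select the manifest entries for one stateful object on one rank."""
    prefix = f"{rank}/{stateful_name}/"
    return {path: entry for path, entry in manifest.items() if path.startswith(prefix)}


print(sorted(entries_for(manifest, 0, "model")))
# prints ['0/model/bias', '0/model/weight']
```

Paths obtained this way are exactly what ``read_object`` expects as its ``path`` argument.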