6565
6666class Snapshot :
6767 """
68- Snapshot represents the persisted program state at one point in time .
68+ Create a reference to an existing snapshot .
6969
70- Basic usage:
71- ::
70+ Args:
71+ path (str): The path to the snapshot. This should be the same as the
72+ ``path`` argument used for :func:`Snapshot.take` when the snapshot
73+ was taken.
7274
73- # Define the program state
74- app_state = {"model": model, "optimizer": optimizer"}
75+ pg (ProcessGroup, optional): The process group for the participants of
76+ :meth:`Snapshot.restore`. If none, the default process group will be
77+ used.
7578
76- # At an appropriate time, persist the program state as a snapshot
77- snapshot = Snapshot.take(path=path, app_state=app_state)
78-
79- # On resuming, restore the program state from a snapshot
80- snapshot.restore(app_state)
81-
82- Overview:
83-
84- At high level, torchsnapshot saves each value in state dicts as a
85- file/object in the corresponding storage system. It also saves a manifest
86- describing the persisted values and the structure of the original state
87- dict.
88-
89- Comparing with :py:func:`torch.save` and :py:func:`torch.load`, torchsnapshot:
90-
91- - Enables efficient random access of persisted model weights.
92-
93- - Accelerates persistence by parallelizing writes.
94-
95- - For replicated values, persistence is parallelized across ranks.
96-
97- - Enables flexible yet robust elasticity (changing world size on
98- restore).
99-
100-
101- Elasticity:
102-
103- Elasticity is implemented via correctly making persisted values
104- available to a newly joined rank, and having it correctly restores the
105- corresponding runtime objects from the persisted values.
106-
107- For the purpose of elasticity, all persisted values fall into one of
108- the categories in [per-rank, replicated, sharded].
109-
110- per-rank:
111-
112- By default, all non-sharded values are treated as per-rank.
113-
114- On save, the value is only saved by the owning rank.
115-
116- On load, the value is only made available to the same rank.
117-
118- replicated:
119-
120- A user can suggest any non-sharded value as replicated via glob
121- patterns.
122-
123- On save, the value is only saved once (can be by any rank).
124-
125- On load, the value is made available to all ranks, including newly
126- joined ranks.
127-
128- sharded:
129-
130- Specific types are always treated as sharded (e.g. ShardedTensor).
131-
132- On save, all shard-owning ranks save their shards.
133-
134- On load, all shards are made available to all ranks, including
135- newly joined rank. All ranks can read from all shards for
136- restoring the runtime object from persisted values.
137- (ShardedTensor resharding is powered by torch.dist.checkpoint).
138-
139- If all values within a snapshot are either replicated or sharded, the
140- snapshot is automatically reshard-able.
141-
142- If a snapshot contains per-rank values, it cannot be resharded unless
143- the per-rank values are explicitly coerced to replicated on load.
79+ storage_options (Dict[str, Any], optional): Additional keyword options
80+ for the storage plugin to use. See each storage plugin's documentation
81+ for customizations.
14482 """
14583
14684 def __init__ (
@@ -149,18 +87,6 @@ def __init__(
14987 pg : Optional [dist .ProcessGroup ] = None ,
15088 storage_options : Optional [Dict [str , Any ]] = None ,
15189 ) -> None :
152- """
153- Initializes the reference to an existing snapshot.
154-
155- Args:
156- path: The location of the snapshot.
157- pg: The process group for the processes restoring from the snapshot.
158- When unspecified:
159- - If distributed is initialized, the global process group will be used.
160- - If distributed is not initialized, single process is assumed.
161- storage_options: Additional keyword options for the StoragePlugin to use.
162- See each StoragePlugin's documentation for customizations.
163- """
16490 self .path : str = path
16591 self .pg : Optional [dist .ProcessGroup ] = pg
16692 self ._metadata : Optional [SnapshotMetadata ] = None
@@ -179,20 +105,43 @@ def take(
179105 ] = None ,
180106 ) -> "Snapshot" :
181107 """
182- Take a snapshot from the program state.
108+ Takes a snapshot of the application state.
183109
184110 Args:
185- app_state: The program state to take the snapshot from.
186- path: The location to save the snapshot.
187- pg: The process group for the processes taking the snapshot.
188- When unspecified:
189- - If distributed is initialized, the global process group will be used.
190- - If distributed is not initialized, single process is assumed.
191- replicated: A list of glob patterns for hinting the matching paths
192- as replicated. Note that patterns not specified by all ranks
193- are ignored.
194- storage_options: Additional keyword options for the StoragePlugin to use.
195- See each StoragePlugin's documentation for customizations.
111+ app_state (Dict[str, Stateful]): The application state to persist.
112+ It takes the form of a dictionary, with the keys being
113+ user-defined strings and the values being stateful objects.
114+ Stateful objects are objects that exposes ``.state_dict()`` and
115+ ``.load_state_dict()`` methods. Common PyTorch objects such as
116+ :class:`torch.nn.Module`, :class:`torch.optim.Optimizer`, and
117+ LR schedulers all qualify as stateful objects.
118+
119+ path (str): The location to save the snapshot. ``path`` can have a
120+ URI prefix (e.g. ``s3://``) that specifies a storage backend.
121+ If no URI prefix is supplied, ``path`` is assumed to be a file
122+ system location. For distributed snapshot, if ``path`` is
123+ inconsistent across participating ranks, the value specified by
124+ rank 0 will be used. For multi-host snapshot, ``path`` needs to
125+ be a location accessible by all hosts.
126+
127+ .. note:: ``path`` must **not** point to an existing snapshot.
128+
129+ pg (ProcessGroup, optional): The process group for the participants
130+ of :meth:`Snapshot.take`. If none, the default process group will
131+ be used.
132+
133+ replicated (List[str], optional): Glob patterns for marking
134+ checkpoint content as replicated. Matching objects will be deduped
135+ and load-balanced across ranks.
136+
137+ .. note:: The replication property is automatically inferred
138+ for ``DistributedDataParallel``. Only specify this argument
139+ if your model has fully replicated states but does not use
140+ ``DistributedDataParallel``.
141+
142+ storage_options (Dict[str, Any], optional): Additional keyword
143+ options for the storage plugin to use. See each storage plugin's
144+ documentation for customizations.
196145
197146 Returns:
198147 The newly taken snapshot.
@@ -252,31 +201,23 @@ def async_take(
252201 ] = None ,
253202 ) -> "PendingSnapshot" :
254203 """
255- Asynchronously take a snapshot from the program state.
204+ Asynchronously takes a snapshot from the application state.
256205
257- This method creates a consistent snapshot of the app state (i.e.
258- changes to the app state after this method returns have no effect on
259- the snapshot). The asynchronicity is a result of performing storage I/O
260- in the background.
206+ This function is identical to :func:`Snapshot.take`, except that it
207+ returns early and performs as much I/O operations in the background as
208+ possible, allowing training to resume early.
261209
262210 Args:
263- app_state: The program state to take the snapshot from.
264- path: The location to save the snapshot.
265- pg: The process group for the processes taking the snapshot.
266- When unspecified:
267- - If distributed is initialized, the global process group will be used.
268- - If distributed is not initialized, single process is assumed.
269- replicated: A list of glob patterns for hinting the matching paths
270- as replicated. Note that patterns not specified by all ranks
271- are ignored.
272- storage_options: Additional keyword options for the StoragePlugin to use.
273- See each StoragePlugin's documentation for customizations.
211+ app_state (Dict[str, Stateful]): Same as the ``app_state`` argument of :func:`Snapshot.take`.
212+ path (str): Same as the ``path`` argument of :func:`Snapshot.take`.
213+ pg (ProcessGroup, optional): Same as the ``pg`` argument of :func:`Snapshot.take`.
214+ replicated (List[str], optional): Same as the ``replicated`` argument of :func:`Snapshot.take`.
215+ storage_options (Dict[str, Any], optional): Same as the ``storage_options`` argument of :func:`Snapshot.take`.
274216
275217 Returns:
276- A handle with which the newly taken snapshot can be obtained via
277- `.wait()`. Note that waiting on the handle is optional. The
278- snapshot will be committed regardless of whether `.wait()` is
279- invoked.
218+ A handle to the pending snapshot. The handle has exposes a
219+ ``.done()`` method for querying the progress and a ``.wait()``
220+ method for waiting for the snapshot's completion.
280221 """
281222 torch ._C ._log_api_usage_once ("torchsnapshot.Snapshot.async_take" )
282223 cls ._validate_app_state (app_state )
@@ -436,11 +377,13 @@ def _take_impl(
436377
437378 def restore (self , app_state : AppState ) -> None :
438379 """
439- Restores the program state from the snapshot.
380+ Restores the application state from the snapshot.
440381
441382 Args:
442- app_state: The program state to restore from the snapshot.
443-
383+ app_state (Dict[str, Stateful]): The application state to restore.
384+ ``app_state`` needs to be either identical to or a subset of the
385+ ``app_state`` used for :func:`Snapshot.take` when the snapshot was
386+ taken.
444387 """
445388 torch ._C ._log_api_usage_once ("torchsnapshot.Snapshot.restore" )
446389 self ._validate_app_state (app_state )
@@ -505,31 +448,25 @@ def read_object(
505448 memory_budget_bytes : Optional [int ] = None ,
506449 ) -> T :
507450 """
508- Read a persisted object from the snapshot's content.
509-
510- The persisted object to read is specified by its path in the snapshot
511- metadata. Available paths can be obtained via `snapshot.get_manifest()`.
451+ Reads an object from the snapshot's content.
512452
513- A path in snapshot metadata follows the following format:
514-
515- ``RANK/STATEFUL_NAME/STATE_DICT_KEY[/NESTED_CONTAINER_KEY...]``
453+ Args:
454+ path (str): The path to the target object within the snapshot.
455+ ``path`` is equivalent to the target object's key in the
456+ snapshot manifest and can be obtained via
457+ :meth:`Snapshot.get_manifest`.
516458
517- The rank only matters when the persisted object is "per-rank".
518- Arbitrary rank can be used when the persisted object is "replicated" or
519- "sharded" .
459+ obj_out (Any, optional): When specified, load the object in-place
460+ into ``obj_out`` if in-place load is supported for the object's
461+ type. Otherwise, ``obj_out`` is ignored .
520462
521- If the persisted object is a sharded tensor, `obj_out` must be
522- supplied. The supplied tensor can be either a tensor or sharded tensor.
523- `read_object` will correctly populate `obj_out`'s data according to
524- sharding spec.
463+ .. note::
464+ When the target object is a ``ShardedTensor``, ``obj_out``
465+ must be specified.
525466
526- Args:
527- path: The path to the persisted object.
528- obj_out: If specified and the object type supports in-place load,
529- `read_object` will directly read the persisted object into
530- `obj_out`'s buffer.
531- memory_budget_bytes: When specified, the read operation will keep
532- the temporary memory buffer size below this threshold.
467+ memory_budget_bytes (int, optional): When specified, the read
468+ operation will keep the temporary memory buffer size below this
469+ threshold.
533470
534471 Returns:
535472 The object read from the snapshot's content.
@@ -595,10 +532,15 @@ def read_object(
595532
596533 def get_manifest (self ) -> Dict [str , Entry ]:
597534 """
598- Returns the snapshot's manifest.
535+ Returns the snapshot manifest.
536+
537+ Each entry in the dictionary corresponds to an object in the snapshot,
538+ with the keys being the logical paths to the objects and the values
539+ being the metadata describing the object. For distributed snapshots,
540+ the manifest contain entries for objects saved by all ranks.
599541
600542 Returns:
601- The snapshot's manifest.
543+ The snapshot manifest.
602544 """
603545 return copy .deepcopy (self .metadata .manifest )
604546
0 commit comments