Commit 1028bfe

fix docstring for JaxTrainer

Signed-off-by: Andrew Sy Kim <andrewsy@google.com>

1 parent: b8638c4

File tree: 1 file changed (+0, -15 lines)

python/ray/train/v2/jax/jax_trainer.py

Lines changed: 0 additions & 15 deletions
@@ -92,18 +92,6 @@ def main(argv: Sequence[str]):
         avoid driver-side TPU lock issues.

     Args:
-        train_loop_per_worker: The training function to execute.
-        train_loop_config: Configurations to pass into
-            ``train_loop_per_worker`` if it accepts an argument.
-        jax_config: Configuration for setting up the JAX backend.
-        scaling_config: Configuration for how to scale data parallel training
-            with SPMD.
-        dataset_config: Optional configuration for datasets.
-        run_config: Configuration for the execution of the training run.
-        datasets: Any Datasets to use for training. Use
-            the key "train" to denote which dataset is the training dataset.
-        resume_from_checkpoint: A checkpoint to resume training from.
-
         train_loop_per_worker: The training function to execute on each worker.
             This function can either take in zero arguments or a single ``Dict``
             argument which is set by defining ``train_loop_config``.
@@ -134,9 +122,6 @@ def main(argv: Sequence[str]):
         resume_from_checkpoint: A checkpoint to resume training from.
            This checkpoint can be accessed from within ``train_loop_per_worker``
            by calling ``ray.train.get_checkpoint()``.
-        metadata: Dict that should be made available via
-            `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
-            for checkpoints saved from this Trainer. Must be JSON-serializable.
     """

     def __init__(
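For orientation, here is a hedged sketch of how a `JaxTrainer` might be constructed using the parameters the (now deduplicated) docstring documents. The import path is inferred from the file location in this diff; the `ScalingConfig` import, the `num_workers` value, and the training-loop body are illustrative assumptions, not taken from the commit.

```python
# Hypothetical usage sketch based on the Args documented in the diff above.
# The import paths and ScalingConfig values are assumptions for illustration.
try:
    from ray.train import ScalingConfig
    from ray.train.v2.jax import JaxTrainer
    HAVE_RAY = True
except ImportError:  # Ray (or its v2 JAX support) is not installed.
    HAVE_RAY = False


def train_loop_per_worker(config: dict) -> None:
    """Placeholder loop; a real one would run JAX SPMD training steps."""
    for _ in range(config.get("num_steps", 10)):
        pass  # one training step per iteration


if HAVE_RAY:
    try:
        # Construction may require a JAX/TPU-capable environment.
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"num_steps": 100},
            scaling_config=ScalingConfig(num_workers=4),
        )
        # trainer.fit() would launch the distributed run.
    except Exception:
        trainer = None
```

The fix itself only removes the first, abbreviated copy of the `Args:` list, so the detailed per-argument descriptions that follow it remain the single source of truth.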
