Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

a few final tweaks for marin runs #755

Merged
merged 2 commits on Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/Configuration-Guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ The following table lists some of the parameters that you might want to change.
| Parameter | Description | Default |
|----------------|-------------------------------------------------------------------------------|---------|
| `log_dir` | Where to save logs (Python logger). `$run_id` will be appended | `logs/` |
| `run_base_dir` | where to save run artifacts. not really used much. `$run_id` will be appended | `runs/` |



Expand Down
6 changes: 5 additions & 1 deletion src/levanter/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,8 +549,12 @@ class CheckpointerConfig:
default_factory=lambda: [dict(every=10000)]
) # list of dicts with two keys: every and until

append_run_id_to_base_path: bool = True

def expanded_path(self, run_id) -> str:
return os.path.expanduser(os.path.join(self.base_path, run_id))
if self.append_run_id_to_base_path:
return os.path.expanduser(os.path.join(self.base_path, run_id))
return os.path.expanduser(self.base_path)

def create(self, run_id) -> Checkpointer:
keeps = [CheckpointInterval(**k) for k in self.keep]
Expand Down
4 changes: 4 additions & 0 deletions src/levanter/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,8 @@ def tagged_eval_sets(
class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
"""This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls"""

cache_dir: Optional[str] = "cache/"

def train_set(
self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None
) -> AsyncDataset[np.ndarray]:
Expand Down Expand Up @@ -705,6 +707,8 @@ def _convert_id_to_token(self, index: int) -> str:
class LMMixtureDatasetConfig(LMTaskConfig):
"""This class represents a mixture of datasets with their associated weights."""

cache_dir: Optional[str] = "cache/"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to move this into a base class?


# data source configs and weights
configs: Dict[str, LMDatasetSourceConfig] = field(default_factory=dict)
""" configuration of each dataset source (urls, hf dataset id, etc.) """
Expand Down
6 changes: 5 additions & 1 deletion src/levanter/main/train_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,11 @@ def main(config: TrainLmConfig):
callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1
)
if config.hf_save_path is not None:
full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
# bit gross to reach this far into the config, but it's fine
if config.trainer.checkpointer.append_run_id_to_base_path:
full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
else:
full_save_path = config.hf_save_path

trainer.add_hook(
save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload or False),
Expand Down
1 change: 0 additions & 1 deletion src/levanter/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,6 @@ class TrainerConfig:

wandb: Optional[tracker.wandb.WandbConfig] = None
log_dir: Path = Path("logs/")
run_base_dir: Path = Path("runs/")
id: Optional[str] = None # run id. if None, will be set to a random string

tracker: TrackerConfig | Tuple[TrackerConfig, ...] = field(default_factory=tracker.wandb.WandbConfig)
Expand Down
Loading