a few final tweaks for marin runs (#755)
dlwh authored Oct 5, 2024
1 parent 3bae9d3 · commit 9847728
Showing 5 changed files with 14 additions and 4 deletions.
docs/Configuration-Guide.md (1 change: 0 additions & 1 deletion)
@@ -111,7 +111,6 @@ The following table lists some of the parameters that you might want to change.
 | Parameter | Description | Default |
 |----------------|-------------------------------------------------------------------------------|---------|
 | `log_dir` | Where to save logs (python logger). `$run_id` will be appended | `logs/` |
-| `run_base_dir` | where to save run artifacts. not really used much. `$run_id` will be appended | `runs/` |



src/levanter/checkpoint.py (6 changes: 5 additions & 1 deletion)
@@ -549,8 +549,12 @@ class CheckpointerConfig:
         default_factory=lambda: [dict(every=10000)]
     )  # list of dicts with two keys: every and until

+    append_run_id_to_base_path: bool = True
+
     def expanded_path(self, run_id) -> str:
-        return os.path.expanduser(os.path.join(self.base_path, run_id))
+        if self.append_run_id_to_base_path:
+            return os.path.expanduser(os.path.join(self.base_path, run_id))
+        return os.path.expanduser(self.base_path)

     def create(self, run_id) -> Checkpointer:
         keeps = [CheckpointInterval(**k) for k in self.keep]
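
In practice, checkpoints still land under base_path/<run_id> by default, while setting append_run_id_to_base_path to false writes them to base_path as-is. A minimal standalone sketch of the resolution logic (this toy class stands in for the real CheckpointerConfig; the diff references self.base_path, but the "checkpoints/" default used here is an assumption):

    import os
    from dataclasses import dataclass

    @dataclass
    class CheckpointerConfigSketch:
        base_path: str = "checkpoints/"  # assumed default
        append_run_id_to_base_path: bool = True

        def expanded_path(self, run_id: str) -> str:
            # behavior added in this commit: only append run_id when asked
            if self.append_run_id_to_base_path:
                return os.path.expanduser(os.path.join(self.base_path, run_id))
            return os.path.expanduser(self.base_path)

    cfg = CheckpointerConfigSketch(base_path="gs://bucket/ckpts")
    print(cfg.expanded_path("run-abc"))  # gs://bucket/ckpts/run-abc
    cfg.append_run_id_to_base_path = False
    print(cfg.expanded_path("run-abc"))  # gs://bucket/ckpts
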
src/levanter/data/text.py (4 changes: 4 additions & 0 deletions)
@@ -577,6 +577,8 @@ def tagged_eval_sets(
 class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
     """This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls"""

+    cache_dir: Optional[str] = "cache/"
+
     def train_set(
         self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None
     ) -> AsyncDataset[np.ndarray]:
@@ -705,6 +707,8 @@ def _convert_id_to_token(self, index: int) -> str:
 class LMMixtureDatasetConfig(LMTaskConfig):
     """This class represents a mixture of datasets with their associated weights."""

+    cache_dir: Optional[str] = "cache/"
+
     # data source configs and weights
     configs: Dict[str, LMDatasetSourceConfig] = field(default_factory=dict)
     """ configuration of each dataset source (urls, hf dataset id, etc.) """
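
Both dataset config classes gain a cache_dir field defaulting to "cache/", so a run that sets nothing caches tokenized data under ./cache while still allowing per-run overrides. A hedged standalone sketch (the toy dataclass stands in for the real configs; the id field here is an assumption borrowed from LMDatasetSourceConfig, and the dataset name is just an example):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class LMDatasetConfigSketch:
        id: Optional[str] = None             # assumed HF dataset id field
        cache_dir: Optional[str] = "cache/"  # new default from this commit

    cfg = LMDatasetConfigSketch(id="dlwh/wikitext_103_detokenized")
    print(cfg.cache_dir)  # cache/ -- used unless a run overrides it
    gcs = LMDatasetConfigSketch(id="dlwh/wikitext_103_detokenized",
                                cache_dir="gs://my-bucket/tokenized/wikitext")
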
src/levanter/main/train_lm.py (6 changes: 5 additions & 1 deletion)
@@ -188,7 +188,11 @@ def main(config: TrainLmConfig):
         callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1
     )
     if config.hf_save_path is not None:
-        full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
+        # bit gross to reach this far into the config, but it's fine
+        if config.trainer.checkpointer.append_run_id_to_base_path:
+            full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
+        else:
+            full_save_path = config.hf_save_path

         trainer.add_hook(
             save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload or False),
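
The HF export path now follows the same switch as the checkpointer: run_id is appended only when append_run_id_to_base_path is left on. A small illustration of the resolved paths (the bucket path and run id are made-up stand-ins for config.hf_save_path and trainer.run_id):

    import os

    hf_save_path = "gs://my-bucket/hf"  # stand-in for config.hf_save_path
    run_id = "marin-run-1"              # stand-in for trainer.run_id

    for append in (True, False):  # config.trainer.checkpointer.append_run_id_to_base_path
        path = os.path.join(hf_save_path, run_id) if append else hf_save_path
        print(append, path)
    # True gs://my-bucket/hf/marin-run-1
    # False gs://my-bucket/hf
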
src/levanter/trainer.py (1 change: 0 additions & 1 deletion)
@@ -519,7 +519,6 @@ class TrainerConfig:

     wandb: Optional[tracker.wandb.WandbConfig] = None
     log_dir: Path = Path("logs/")
-    run_base_dir: Path = Path("runs/")
     id: Optional[str] = None  # run id. if None, will be set to a random string

     tracker: TrackerConfig | Tuple[TrackerConfig, ...] = field(default_factory=tracker.wandb.WandbConfig)
