From 3bae9d3e81f72b145a7e7764926f4843e3bf6336 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Fri, 4 Oct 2024 23:20:40 -0500
Subject: [PATCH] allow mixture components to override cache_dir (#754)

---
 config/gpt2_nano_mixture.yaml       |  1 +
 src/levanter/data/text.py           | 25 +++++++++++++++++++------
 src/levanter/main/cache_dataset.py  |  2 +-
 3 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/config/gpt2_nano_mixture.yaml b/config/gpt2_nano_mixture.yaml
index 2939b9e5e..35b240787 100644
--- a/config/gpt2_nano_mixture.yaml
+++ b/config/gpt2_nano_mixture.yaml
@@ -5,6 +5,7 @@ data:
       id: dlwh/wikitext_103_detokenized
     w2:
       id: dlwh/wikitext_103_detokenized
+      cache_dir: wikitext2_cache
   train_weights:
     wikitext: 1.0
     w2: 1.0
diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 62dfb62ba..5e595b2a1 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -466,6 +466,7 @@ class LMDatasetSourceConfig:

     train_urls: List[str] = ()  # type: ignore
     validation_urls: List[str] = ()  # type:ignore
+    cache_dir: Optional[str] = None  # Optionally override the cache dir for this component

     def get_shard_source(self, split) -> Optional[ShardedDataSource[str]]:
         if self.id is not None:
@@ -530,7 +531,7 @@ class LMTaskConfig(abc.ABC):
     vocab_size: Optional[int] = None  # if using the passthrough tokenizer, this is required

     # config related to caching
-    cache_dir: str = "cache/"
+    cache_dir: Optional[str] = "cache/"
     cache_options: CacheOptions = field(default_factory=CacheOptions)
     enforce_eos: bool = True  # whether to append eos even if the tokenizer doesn't

@@ -560,7 +561,7 @@ def validation_sets(

     @property
     @abc.abstractmethod
-    def sources(self) -> dict[str, LMDatasetSourceConfig]:
+    def sources(self) -> Mapping[str, LMDatasetSourceConfig]:
         pass

     def tagged_eval_sets(
@@ -605,7 +606,7 @@ def validation_sets(
         return {}

     @property
-    def sources(self) -> dict[str, LMDatasetSourceConfig]:
+    def sources(self) -> Mapping[str, LMDatasetSourceConfig]:
         return {"": self}

     @cached_property
@@ -634,6 +635,9 @@ def token_seq_dataset(
     def build_or_load_cache(
         self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True, logger_name: Optional[str] = None
     ) -> Optional[TreeCache[BatchEncoding]]:
+        if self.cache_dir is None:
+            raise ValueError("cache_dir cannot be None")
+
         split_cache_dir = os.path.join(self.cache_dir, split)
         name = logger_name or os.path.basename(self.cache_dir)

@@ -788,10 +792,19 @@ def build_caches(
             if weight == 0 and split == "train":
                 continue

-            source_config_dict = source_config.__dict__
+            source_config_dict = dict(**source_config.__dict__)
+
+            if source_config.cache_dir is None:
+                # replace with the main cache dir/{name}
+                if self.cache_dir is None:
+                    raise ValueError(
+                        "If the 'main' cache_dir is None, then all component cache_dirs must be non-None, but"
+                        f"{name}'s cache_dir is None."
+                    )
+                cache_dir = os.path.join(self.cache_dir, name)
+                source_config_dict["cache_dir"] = cache_dir

             dataset = LMDatasetConfig(
-                cache_dir=os.path.join(self.cache_dir, name),
                 **source_config_dict,
                 **task_config_dict,
             )
@@ -813,5 +826,5 @@ def build_caches(
         return caches

     @property
-    def sources(self) -> dict[str, LMDatasetSourceConfig]:
+    def sources(self) -> Mapping[str, LMDatasetSourceConfig]:
         return self.configs
diff --git a/src/levanter/main/cache_dataset.py b/src/levanter/main/cache_dataset.py
index 2483e9214..caccc567c 100644
--- a/src/levanter/main/cache_dataset.py
+++ b/src/levanter/main/cache_dataset.py
@@ -31,7 +31,7 @@ def main(args: RayCachedLMDatasetConfig):
         print(f"Caching {split} to {args.cache_dir}.")
         # connect or start the actor
         batch_tokenizer = BatchTokenizer(tokenizer, enforce_eos=args.enforce_eos)
-        split_cache_dir = os.path.join(args.cache_dir, split)
+        split_cache_dir = os.path.join(args.cache_dir, split)  # type: ignore
         source = args.get_shard_source(split)

         if source is None:
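Usage note (not part of the patch): a minimal sketch of how the new per-component cache_dir interacts with the top-level cache_dir in a mixture config, following the layout of config/gpt2_nano_mixture.yaml above. The gs:// path and the weights are hypothetical placeholders; only the w2 override mirrors the actual config change in this patch. A component that does not set cache_dir falls back to <top-level cache_dir>/<component name>, and if the top-level cache_dir is null, every component must set its own or build_caches raises a ValueError.

    data:
      cache_dir: "gs://my-bucket/tokenized"    # hypothetical top-level cache dir
      configs:
        wikitext:
          id: dlwh/wikitext_103_detokenized    # no override: caches under gs://my-bucket/tokenized/wikitext
        w2:
          id: dlwh/wikitext_103_detokenized
          cache_dir: wikitext2_cache           # per-component override, as added in this patch
      train_weights:
        wikitext: 1.0
        w2: 1.0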