From 124c0b9843d5c817532687ea92cba15d75ea29b5 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Fri, 4 Oct 2024 13:13:11 -0700
Subject: [PATCH 1/4] allow mixture components to override cache_dir

---
 config/gpt2_nano_mixture.yaml |  1 +
 src/levanter/data/text.py     | 35 ++++++++++++++++++++++++++++-------
 2 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/config/gpt2_nano_mixture.yaml b/config/gpt2_nano_mixture.yaml
index 2939b9e5e..35b240787 100644
--- a/config/gpt2_nano_mixture.yaml
+++ b/config/gpt2_nano_mixture.yaml
@@ -5,6 +5,7 @@ data:
       id: dlwh/wikitext_103_detokenized
     w2:
       id: dlwh/wikitext_103_detokenized
+      cache_dir: wikitext2_cache
   train_weights:
     wikitext: 1.0
     w2: 1.0
diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 62dfb62ba..1d1c159c0 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -524,13 +524,18 @@ def fsspec_expand_glob(url):
     return urls
 
 
+@dataclass
+class LMDatasetMixtureComponentConfig(LMDatasetSourceConfig):
+    cache_dir: Optional[str] = None  # Optionally override the cache dir for this component
+
+
 @dataclass
 class LMTaskConfig(abc.ABC):
     tokenizer: str = "gpt2"
     vocab_size: Optional[int] = None  # if using the passthrough tokenizer, this is required
 
     # config related to caching
-    cache_dir: str = "cache/"
+    cache_dir: Optional[str] = "cache/"
     cache_options: CacheOptions = field(default_factory=CacheOptions)
     enforce_eos: bool = True  # whether to append eos even if the tokenizer doesn't
 
@@ -560,7 +565,7 @@ def validation_sets(
 
     @property
     @abc.abstractmethod
-    def sources(self) -> dict[str, LMDatasetSourceConfig]:
+    def sources(self) -> Mapping[str, LMDatasetSourceConfig]:
         pass
 
     def tagged_eval_sets(
@@ -605,7 +610,7 @@ def validation_sets(
         return {}
 
     @property
-    def sources(self) -> dict[str, LMDatasetSourceConfig]:
+    def sources(self) -> Mapping[str, LMDatasetSourceConfig]:
         return {"": self}
 
     @cached_property
@@ -634,6 +639,9 @@ def token_seq_dataset(
     def build_or_load_cache(
         self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True, logger_name: Optional[str] = None
    ) -> Optional[TreeCache[BatchEncoding]]:
+        if self.cache_dir is None:
+            raise ValueError("cache_dir cannot be None")
+
         split_cache_dir = os.path.join(self.cache_dir, split)
         name = logger_name or os.path.basename(self.cache_dir)
 
@@ -702,7 +710,7 @@ class LMMixtureDatasetConfig(LMTaskConfig):
     """This class represents a mixture of datasets with their associated weights."""
 
     # data source configs and weights
-    configs: Dict[str, LMDatasetSourceConfig] = field(default_factory=dict)
+    configs: Dict[str, LMDatasetMixtureComponentConfig] = field(default_factory=dict)
     """ configuration of each dataset source (urls, hf dataset id, etc.) """
     train_weights: Dict[str, float] = field(default_factory=dict)
     """ weights for each dataset source. They will be normalized to sum to 1. """
@@ -788,10 +796,23 @@ def build_caches(
             if weight == 0 and split == "train":
                 continue
 
-            source_config_dict = source_config.__dict__
+            source_config_dict = dict(**source_config.__dict__)
+            if "cache_dir" in source_config_dict:
+                del source_config_dict["cache_dir"]
+
+            if source_config.cache_dir is not None:
+                cache_dir = source_config.cache_dir
+            else:
+                if self.cache_dir is None:
+                    raise ValueError(
+                        "If the 'main' cache_dir is None, then all component cache_dirs must be non-None, but "
+                        f"{name}'s cache_dir is None."
+                    )
+
+                cache_dir = os.path.join(self.cache_dir, name)
 
             dataset = LMDatasetConfig(
-                cache_dir=os.path.join(self.cache_dir, name),
+                cache_dir=cache_dir,
                 **source_config_dict,
                 **task_config_dict,
             )
@@ -813,5 +834,5 @@ def build_caches(
         return caches
 
     @property
-    def sources(self) -> dict[str, LMDatasetSourceConfig]:
+    def sources(self) -> Mapping[str, LMDatasetSourceConfig]:
         return self.configs
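Reviewer note: the resolution rule patch 1 adds to build_caches can be summarized by the following minimal, self-contained sketch. The helper resolve_cache_dir is hypothetical, written only to make the fallback order explicit; it is not part of the diff.

import os
from typing import Optional


def resolve_cache_dir(main_cache_dir: Optional[str], component_cache_dir: Optional[str], name: str) -> str:
    # a component-level override always wins
    if component_cache_dir is not None:
        return component_cache_dir
    # otherwise fall back to <main cache_dir>/<component name>
    if main_cache_dir is None:
        raise ValueError(f"no cache_dir available for component {name!r}")
    return os.path.join(main_cache_dir, name)


# with the gpt2_nano_mixture.yaml above: w2 overrides, wikitext falls back
assert resolve_cache_dir("cache/", "wikitext2_cache", "w2") == "wikitext2_cache"
assert resolve_cache_dir("cache/", None, "wikitext") == "cache/wikitext"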
From 88354965fbd1fb76911a85fbf3a4172b6c71291c Mon Sep 17 00:00:00 2001
From: David Hall
Date: Fri, 4 Oct 2024 22:41:49 -0500
Subject: [PATCH 2/4] ignore type

---
 src/levanter/main/cache_dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/levanter/main/cache_dataset.py b/src/levanter/main/cache_dataset.py
index 2483e9214..caccc567c 100644
--- a/src/levanter/main/cache_dataset.py
+++ b/src/levanter/main/cache_dataset.py
@@ -31,7 +31,7 @@ def main(args: RayCachedLMDatasetConfig):
         print(f"Caching {split} to {args.cache_dir}.")
         # connect or start the actor
         batch_tokenizer = BatchTokenizer(tokenizer, enforce_eos=args.enforce_eos)
-        split_cache_dir = os.path.join(args.cache_dir, split)
+        split_cache_dir = os.path.join(args.cache_dir, split)  # type: ignore
 
         source = args.get_shard_source(split)
         if source is None:

From efe0f53fc254dbe8d3919148f6561ea9f82e70c5 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Fri, 4 Oct 2024 22:58:40 -0500
Subject: [PATCH 3/4] pr

---
 src/levanter/data/text.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 1d1c159c0..c926bc940 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -797,22 +797,18 @@ def build_caches(
                 continue
 
             source_config_dict = dict(**source_config.__dict__)
-            if "cache_dir" in source_config_dict:
-                del source_config_dict["cache_dir"]
 
-            if source_config.cache_dir is not None:
-                cache_dir = source_config.cache_dir
-            else:
+            if source_config.cache_dir is None:
+                # replace with the main cache dir/{name}
                 if self.cache_dir is None:
                     raise ValueError(
                         "If the 'main' cache_dir is None, then all component cache_dirs must be non-None, but "
                         f"{name}'s cache_dir is None."
                     )
-
                 cache_dir = os.path.join(self.cache_dir, name)
+                source_config_dict["cache_dir"] = cache_dir
 
             dataset = LMDatasetConfig(
-                cache_dir=cache_dir,
                 **source_config_dict,
                 **task_config_dict,
            )
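Reviewer note: patch 3 works because the resolved cache_dir now travels inside source_config_dict instead of being passed as an explicit keyword, so the del guard from patch 1 is no longer needed. A toy illustration of why that guard existed (all names here are illustrative, not from the diff):

def make(cache_dir, **kwargs):
    return {"cache_dir": cache_dir, **kwargs}


cfg = {"cache_dir": "wikitext2_cache", "id": "dlwh/wikitext_103_detokenized"}

# patch 1's shape: cache_dir is an explicit argument, so the key must be
# dropped from the dict first; otherwise make("cache/w2", **cfg) raises
# TypeError for passing cache_dir twice
explicit = dict(cfg)
del explicit["cache_dir"]
make("cache/w2", **explicit)

# patch 3's shape: resolve in place, then let one splat carry everything
resolved = dict(cfg)
if resolved["cache_dir"] is None:
    resolved["cache_dir"] = "cache/w2"
make(**resolved)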
From 90247db8b6e4911acb10968bfd6c2ac132536b22 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Fri, 4 Oct 2024 23:02:17 -0500
Subject: [PATCH 4/4] pr

---
 src/levanter/data/text.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index c926bc940..5e595b2a1 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -466,6 +466,7 @@ class LMDatasetSourceConfig:
 
     train_urls: List[str] = ()  # type: ignore
     validation_urls: List[str] = ()  # type:ignore
+    cache_dir: Optional[str] = None  # Optionally override the cache dir for this component
 
     def get_shard_source(self, split) -> Optional[ShardedDataSource[str]]:
         if self.id is not None:
@@ -524,11 +525,6 @@ def fsspec_expand_glob(url):
     return urls
 
 
-@dataclass
-class LMDatasetMixtureComponentConfig(LMDatasetSourceConfig):
-    cache_dir: Optional[str] = None  # Optionally override the cache dir for this component
-
-
 @dataclass
 class LMTaskConfig(abc.ABC):
     tokenizer: str = "gpt2"
@@ -710,7 +706,7 @@ class LMMixtureDatasetConfig(LMTaskConfig):
     """This class represents a mixture of datasets with their associated weights."""
 
     # data source configs and weights
-    configs: Dict[str, LMDatasetMixtureComponentConfig] = field(default_factory=dict)
+    configs: Dict[str, LMDatasetSourceConfig] = field(default_factory=dict)
     """ configuration of each dataset source (urls, hf dataset id, etc.) """
     train_weights: Dict[str, float] = field(default_factory=dict)
     """ weights for each dataset source. They will be normalized to sum to 1. """
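Reviewer note: after patch 4 the override field lives directly on LMDatasetSourceConfig, so the yaml change in patch 1 has a direct programmatic equivalent. A sketch, assuming the remaining fields of both dataclasses keep their defaults:

from levanter.data.text import LMDatasetSourceConfig, LMMixtureDatasetConfig

mixture = LMMixtureDatasetConfig(
    configs={
        # caches under <mixture cache_dir>/wikitext
        "wikitext": LMDatasetSourceConfig(id="dlwh/wikitext_103_detokenized"),
        # caches under its own wikitext2_cache instead
        "w2": LMDatasetSourceConfig(id="dlwh/wikitext_103_detokenized", cache_dir="wikitext2_cache"),
    },
    train_weights={"wikitext": 1.0, "w2": 1.0},
)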