Skip to content

Commit

Permalink
allow mixture components to override cache_dir (#754)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Oct 5, 2024
1 parent b41838f commit 3bae9d3
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 7 deletions.
1 change: 1 addition & 0 deletions config/gpt2_nano_mixture.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ data:
id: dlwh/wikitext_103_detokenized
w2:
id: dlwh/wikitext_103_detokenized
cache_dir: wikitext2_cache
train_weights:
wikitext: 1.0
w2: 1.0
Expand Down
25 changes: 19 additions & 6 deletions src/levanter/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ class LMDatasetSourceConfig:

train_urls: List[str] = () # type: ignore
validation_urls: List[str] = () # type: ignore
cache_dir: Optional[str] = None # Optionally override the cache dir for this component

def get_shard_source(self, split) -> Optional[ShardedDataSource[str]]:
if self.id is not None:
Expand Down Expand Up @@ -530,7 +531,7 @@ class LMTaskConfig(abc.ABC):
vocab_size: Optional[int] = None # if using the passthrough tokenizer, this is required

# config related to caching
cache_dir: str = "cache/"
cache_dir: Optional[str] = "cache/"
cache_options: CacheOptions = field(default_factory=CacheOptions)
enforce_eos: bool = True # whether to append eos even if the tokenizer doesn't

Expand Down Expand Up @@ -560,7 +561,7 @@ def validation_sets(

@property
@abc.abstractmethod
def sources(self) -> dict[str, LMDatasetSourceConfig]:
def sources(self) -> Mapping[str, LMDatasetSourceConfig]:
pass

def tagged_eval_sets(
Expand Down Expand Up @@ -605,7 +606,7 @@ def validation_sets(
return {}

@property
def sources(self) -> dict[str, LMDatasetSourceConfig]:
def sources(self) -> Mapping[str, LMDatasetSourceConfig]:
return {"": self}

@cached_property
Expand Down Expand Up @@ -634,6 +635,9 @@ def token_seq_dataset(
def build_or_load_cache(
self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True, logger_name: Optional[str] = None
) -> Optional[TreeCache[BatchEncoding]]:
if self.cache_dir is None:
raise ValueError("cache_dir cannot be None")

split_cache_dir = os.path.join(self.cache_dir, split)
name = logger_name or os.path.basename(self.cache_dir)

Expand Down Expand Up @@ -788,10 +792,19 @@ def build_caches(
if weight == 0 and split == "train":
continue

source_config_dict = source_config.__dict__
source_config_dict = dict(**source_config.__dict__)

if source_config.cache_dir is None:
# replace with the main cache dir/{name}
if self.cache_dir is None:
raise ValueError(
"If the 'main' cache_dir is None, then all component cache_dirs must be non-None, but "
f"{name}'s cache_dir is None."
)
cache_dir = os.path.join(self.cache_dir, name)
source_config_dict["cache_dir"] = cache_dir

dataset = LMDatasetConfig(
cache_dir=os.path.join(self.cache_dir, name),
**source_config_dict,
**task_config_dict,
)
Expand All @@ -813,5 +826,5 @@ def build_caches(
return caches

@property
def sources(self) -> dict[str, LMDatasetSourceConfig]:
def sources(self) -> Mapping[str, LMDatasetSourceConfig]:
return self.configs
2 changes: 1 addition & 1 deletion src/levanter/main/cache_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def main(args: RayCachedLMDatasetConfig):
print(f"Caching {split} to {args.cache_dir}.")
# connect or start the actor
batch_tokenizer = BatchTokenizer(tokenizer, enforce_eos=args.enforce_eos)
split_cache_dir = os.path.join(args.cache_dir, split)
split_cache_dir = os.path.join(args.cache_dir, split) # type: ignore
source = args.get_shard_source(split)

if source is None:
Expand Down

0 comments on commit 3bae9d3

Please sign in to comment.