allow mixture components to override cache_dir #754

Merged · 4 commits · Oct 5, 2024
1 change: 1 addition & 0 deletions config/gpt2_nano_mixture.yaml
```diff
@@ -5,6 +5,7 @@ data:
       id: dlwh/wikitext_103_detokenized
     w2:
       id: dlwh/wikitext_103_detokenized
+      cache_dir: wikitext2_cache
   train_weights:
     wikitext: 1.0
     w2: 1.0
```
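A sketch of the resolution rule this PR introduces (`resolve_cache_dir` is a hypothetical helper for illustration, not a function in the PR): each component caches under `{cache_dir}/{name}` unless it supplies its own `cache_dir`, so in the config above `wikitext` still caches under the mixture's top-level cache tree while `w2` goes to `wikitext2_cache`:

```python
import os
from typing import Optional

def resolve_cache_dir(main_cache_dir: Optional[str], name: str, override: Optional[str]) -> str:
    """Hypothetical helper mirroring the PR's fallback rule."""
    if override is not None:
        return override  # the per-component cache_dir wins
    if main_cache_dir is None:
        raise ValueError(f"no cache_dir available for component {name!r}")
    return os.path.join(main_cache_dir, name)  # default: main cache dir / component name

# With the default top-level cache_dir of "cache/":
assert resolve_cache_dir("cache/", "wikitext", None) == "cache/wikitext"
assert resolve_cache_dir("cache/", "w2", "wikitext2_cache") == "wikitext2_cache"
```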
25 changes: 19 additions & 6 deletions src/levanter/data/text.py
```diff
@@ -466,6 +466,7 @@ class LMDatasetSourceConfig:

     train_urls: List[str] = ()  # type: ignore
     validation_urls: List[str] = ()  # type:ignore
+    cache_dir: Optional[str] = None  # Optionally override the cache dir for this component

     def get_shard_source(self, split) -> Optional[ShardedDataSource[str]]:
         if self.id is not None:
```
```diff
@@ -530,7 +531,7 @@ class LMTaskConfig(abc.ABC):
     vocab_size: Optional[int] = None  # if using the passthrough tokenizer, this is required

     # config related to caching
-    cache_dir: str = "cache/"
+    cache_dir: Optional[str] = "cache/"
     cache_options: CacheOptions = field(default_factory=CacheOptions)
     enforce_eos: bool = True  # whether to append eos even if the tokenizer doesn't
```
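Widening `cache_dir` to `Optional[str]` keeps the old default; it only makes an explicit `None` legal, which a mixture can use when every component supplies its own directory. A toy sketch of the semantics (`TaskConfig` is a stand-in name, not the real class):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class TaskConfig:  # stand-in for LMTaskConfig's caching fields
    cache_dir: Optional[str] = "cache/"

assert TaskConfig().cache_dir == "cache/"            # default behavior is unchanged
assert TaskConfig(cache_dir=None).cache_dir is None  # opting out is now well-typed
```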

```diff
@@ -560,7 +561,7 @@ def validation_sets(

     @property
     @abc.abstractmethod
-    def sources(self) -> dict[str, LMDatasetSourceConfig]:
+    def sources(self) -> Mapping[str, LMDatasetSourceConfig]:
         pass

     def tagged_eval_sets(
```
```diff
@@ -605,7 +606,7 @@ def validation_sets(
         return {}

     @property
-    def sources(self) -> dict[str, LMDatasetSourceConfig]:
+    def sources(self) -> Mapping[str, LMDatasetSourceConfig]:
         return {"": self}

     @cached_property
```
```diff
@@ -634,6 +635,9 @@ def token_seq_dataset(
     def build_or_load_cache(
         self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True, logger_name: Optional[str] = None
     ) -> Optional[TreeCache[BatchEncoding]]:
+        if self.cache_dir is None:
+            raise ValueError("cache_dir cannot be None")
+
         split_cache_dir = os.path.join(self.cache_dir, split)
         name = logger_name or os.path.basename(self.cache_dir)
```
```diff
@@ -788,10 +792,19 @@ def build_caches(
             if weight == 0 and split == "train":
                 continue

-            source_config_dict = source_config.__dict__
+            source_config_dict = dict(**source_config.__dict__)
```
Review comment: Add comment on why we need to delete it - to make the LMDatasetMixtureComponentConfig act like a LMDatasetSourceConfig?

Reply (Member, Author): more or less. i also cleaned this part up a bit
```diff
+            if source_config.cache_dir is None:
+                # replace with the main cache dir/{name}
+                if self.cache_dir is None:
+                    raise ValueError(
+                        "If the 'main' cache_dir is None, then all component cache_dirs must be non-None, but"
+                        f"{name}'s cache_dir is None."
+                    )
+                cache_dir = os.path.join(self.cache_dir, name)
+                source_config_dict["cache_dir"] = cache_dir

             dataset = LMDatasetConfig(
-                cache_dir=os.path.join(self.cache_dir, name),
                 **source_config_dict,
                 **task_config_dict,
             )
```
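One detail worth noting in the hunk above is the switch from `source_config.__dict__` to `dict(**source_config.__dict__)`. The former aliases the config object's attribute storage, so writing `source_config_dict["cache_dir"]` would mutate the original component config; the copy keeps the override local. A minimal demonstration (`Source` is a toy stand-in, not the real class):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Source:  # toy stand-in for LMDatasetSourceConfig
    cache_dir: Optional[str] = None

src = Source()
alias = src.__dict__                # the old code: a live view, not a copy
alias["cache_dir"] = "cache/w2"
assert src.cache_dir == "cache/w2"  # the config object itself was mutated

src = Source()
copy = dict(**src.__dict__)         # the new code: a shallow copy
copy["cache_dir"] = "cache/w2"
assert src.cache_dir is None        # the original config is left untouched
```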
```diff
@@ -813,5 +826,5 @@ def build_caches(
         return caches

     @property
-    def sources(self) -> dict[str, LMDatasetSourceConfig]:
+    def sources(self) -> Mapping[str, LMDatasetSourceConfig]:
         return self.configs
```
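The repeated `dict` → `Mapping` change in the `sources` signatures is more than cosmetic: `dict` is invariant in its value type, while the read-only `Mapping` is covariant, so an implementation can return a dict whose values are a more specific config class (for example the LMDatasetMixtureComponentConfig mentioned in the review thread). A toy illustration of what each annotation permits (`Base`/`Sub` are illustrative names):

```python
from typing import Mapping

class Base: ...       # think LMDatasetSourceConfig
class Sub(Base): ...  # think a more specific per-component config

configs: dict[str, Sub] = {"w2": Sub()}

def sources_as_dict() -> dict[str, Base]:
    return configs  # rejected by type checkers: dict is invariant in its values

def sources_as_mapping() -> Mapping[str, Base]:
    return configs  # accepted: Mapping is covariant in its values
```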
2 changes: 1 addition & 1 deletion src/levanter/main/cache_dataset.py
```diff
@@ -31,7 +31,7 @@ def main(args: RayCachedLMDatasetConfig):
         print(f"Caching {split} to {args.cache_dir}.")
         # connect or start the actor
         batch_tokenizer = BatchTokenizer(tokenizer, enforce_eos=args.enforce_eos)
-        split_cache_dir = os.path.join(args.cache_dir, split)
+        split_cache_dir = os.path.join(args.cache_dir, split)  # type: ignore
         source = args.get_shard_source(split)

         if source is None:
```
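The `# type: ignore` is needed because `args.cache_dir` is now `Optional[str]` and `os.path.join` requires a non-None path. The `text.py` change above handles the same tension with a runtime guard, whose raise also narrows the type for static checkers; an equivalent alternative here (a sketch with an assumed helper name, not what the PR does) would be an explicit assert:

```python
import os
from typing import Optional

def split_dir(cache_dir: Optional[str], split: str) -> str:
    # Narrowing assert: after it, checkers treat cache_dir as str,
    # so the join needs no "# type: ignore".
    assert cache_dir is not None, "caching requires a concrete cache_dir"
    return os.path.join(cache_dir, split)
```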