allow mixture components to override cache_dir #754

Merged: 4 commits, Oct 5, 2024
Changes from 2 commits
1 change: 1 addition & 0 deletions config/gpt2_nano_mixture.yaml
@@ -5,6 +5,7 @@ data:
       id: dlwh/wikitext_103_detokenized
     w2:
       id: dlwh/wikitext_103_detokenized
+      cache_dir: wikitext2_cache
   train_weights:
     wikitext: 1.0
     w2: 1.0
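
For reference, the same override can be written when constructing the mixture config directly in Python. A minimal sketch, assuming the classes added/changed in src/levanter/data/text.py below and that the hf dataset id is passed via the source config's id field:

from levanter.data.text import LMDatasetMixtureComponentConfig, LMMixtureDatasetConfig

config = LMMixtureDatasetConfig(
    cache_dir="cache/",  # mixture-level default cache dir
    configs={
        # uses the mixture-level cache dir (resolved to "cache/wikitext")
        "wikitext": LMDatasetMixtureComponentConfig(id="dlwh/wikitext_103_detokenized"),
        # overrides the cache dir for this component only, as in the YAML above
        "w2": LMDatasetMixtureComponentConfig(
            id="dlwh/wikitext_103_detokenized",
            cache_dir="wikitext2_cache",
        ),
    },
    train_weights={"wikitext": 1.0, "w2": 1.0},
)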
35 changes: 28 additions & 7 deletions src/levanter/data/text.py
@@ -524,13 +524,18 @@ def fsspec_expand_glob(url):
     return urls


+@dataclass
+class LMDatasetMixtureComponentConfig(LMDatasetSourceConfig):

[Reviewer] The only thing different between the base and derived class is cache_dir? Seems weird since this on the surface doesn't seem to relate to mixtures...

[Member Author] yeah i can just put it in the base i guess

+    cache_dir: Optional[str] = None  # Optionally override the cache dir for this component


 @dataclass
 class LMTaskConfig(abc.ABC):
     tokenizer: str = "gpt2"
     vocab_size: Optional[int] = None  # if using the passthrough tokenizer, this is required

     # config related to caching
-    cache_dir: str = "cache/"
+    cache_dir: Optional[str] = "cache/"
     cache_options: CacheOptions = field(default_factory=CacheOptions)
     enforce_eos: bool = True  # whether to append eos even if the tokenizer doesn't
@@ -560,7 +565,7 @@ def validation_sets(

     @property
     @abc.abstractmethod
-    def sources(self) -> dict[str, LMDatasetSourceConfig]:
+    def sources(self) -> Mapping[str, LMDatasetSourceConfig]:
         pass

     def tagged_eval_sets(
@@ -605,7 +610,7 @@ def validation_sets(
         return {}

     @property
-    def sources(self) -> dict[str, LMDatasetSourceConfig]:
+    def sources(self) -> Mapping[str, LMDatasetSourceConfig]:
         return {"": self}

     @cached_property
@@ -634,6 +639,9 @@ def token_seq_dataset(
     def build_or_load_cache(
         self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True, logger_name: Optional[str] = None
     ) -> Optional[TreeCache[BatchEncoding]]:
+        if self.cache_dir is None:
+            raise ValueError("cache_dir cannot be None")
+
         split_cache_dir = os.path.join(self.cache_dir, split)
         name = logger_name or os.path.basename(self.cache_dir)

@@ -702,7 +710,7 @@ class LMMixtureDatasetConfig(LMTaskConfig):
     """This class represents a mixture of datasets with their associated weights."""

     # data source configs and weights
-    configs: Dict[str, LMDatasetSourceConfig] = field(default_factory=dict)
+    configs: Dict[str, LMDatasetMixtureComponentConfig] = field(default_factory=dict)
     """ configuration of each dataset source (urls, hf dataset id, etc.) """
     train_weights: Dict[str, float] = field(default_factory=dict)
     """ weights for each dataset source. They will be normalized to sum to 1. """
@@ -788,10 +796,23 @@ def build_caches(
             if weight == 0 and split == "train":
                 continue

-            source_config_dict = source_config.__dict__
+            source_config_dict = dict(**source_config.__dict__)

[Reviewer] Add comment on why we need to delete it - to make the LMDatasetMixtureComponentConfig act like a LMDatasetSourceConfig?

[Member Author] more or less. i also cleaned this part up a bit

if "cache_dir" in source_config_dict:
del source_config_dict["cache_dir"]

if source_config.cache_dir is not None:
cache_dir = source_config.cache_dir
else:
if self.cache_dir is None:
raise ValueError(
"If the 'main' cache_dir is None, then all component cache_dirs must be non-None, but"
f"{name}'s cache_dir is None."
)

cache_dir = os.path.join(self.cache_dir, name)

dataset = LMDatasetConfig(
cache_dir=os.path.join(self.cache_dir, name),
cache_dir=cache_dir,
**source_config_dict,
**task_config_dict,
)
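
The precedence introduced in this hunk, pulled out as a standalone helper for illustration only (resolve_cache_dir is not part of the PR): a component-level cache_dir wins; otherwise the mixture-level cache_dir is joined with the component name; if both are None it is an error.

import os
from typing import Optional

def resolve_cache_dir(component_cache_dir: Optional[str], mixture_cache_dir: Optional[str], name: str) -> str:
    # a component-level override takes priority
    if component_cache_dir is not None:
        return component_cache_dir
    # otherwise fall back to <mixture cache_dir>/<component name>
    if mixture_cache_dir is None:
        raise ValueError(f"cache_dir is None for both the mixture and component {name!r}")
    return os.path.join(mixture_cache_dir, name)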
@@ -813,5 +834,5 @@ def build_caches(
         return caches

     @property
-    def sources(self) -> dict[str, LMDatasetSourceConfig]:
+    def sources(self) -> Mapping[str, LMDatasetSourceConfig]:
         return self.configs
2 changes: 1 addition & 1 deletion src/levanter/main/cache_dataset.py
@@ -31,7 +31,7 @@ def main(args: RayCachedLMDatasetConfig):
         print(f"Caching {split} to {args.cache_dir}.")
         # connect or start the actor
         batch_tokenizer = BatchTokenizer(tokenizer, enforce_eos=args.enforce_eos)
-        split_cache_dir = os.path.join(args.cache_dir, split)
+        split_cache_dir = os.path.join(args.cache_dir, split)  # type: ignore
         source = args.get_shard_source(split)

         if source is None: