Skip to content

Commit

Permalink
Merge pull request #49 from arcee-ai/fix-dataset-load
Browse files (browse the repository at this point in the history)
Fix dataset load
  • Loading branch information
Ben-Epstein authored Sep 21, 2023
2 parents 6f5db4e + 1536754 commit f938906
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion dalm/__init__.py
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
__version__ = "0.0.2"
__version__ = "0.0.3"
11 changes: 8 additions & 3 deletions dalm/datasets/qa_gen/question_answer_generation.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -142,15 +142,20 @@ def generate_qa_from_dataset(
def _load_dataset_from_path(dataset_path: str) -> Dataset:
if dataset_path.endswith(".csv"):
dataset = Dataset.from_csv(dataset_path)
elif not os.path.splitext(dataset_path):
elif not os.path.splitext(dataset_path)[-1]:
if os.path.isdir(dataset_path):
dataset = datasets.load_from_disk(dataset_path)
else:
dataset = datasets.load_dataset(dataset_path)
key = next(iter(dataset))
if isinstance(dataset, DatasetDict):
if "train" in dataset:
key = "train"
elif "training" in dataset:
key = "training"
else:
key = next(iter(dataset))
warnings.warn(f"Found multiple keys in dataset. Generating qa for split {key}", stacklevel=0)
dataset = dataset[key]
dataset = dataset[key]
else:
raise ValueError(
"dataset-path must be one of csv, dataset directory "
Expand Down

0 comments on commit f938906

Please sign in to comment.