From 86404e4b4fadd07a687d8ee51e4a4047358b4f1b Mon Sep 17 00:00:00 2001
From: Christoph Boeddeker
Date: Thu, 28 Nov 2024 09:36:15 +0100
Subject: [PATCH] add from_file

---
 lazy_dataset/core.py | 49 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 49 insertions(+)

diff --git a/lazy_dataset/core.py b/lazy_dataset/core.py
index ed8a59f..b9ab2d2 100644
--- a/lazy_dataset/core.py
+++ b/lazy_dataset/core.py
@@ -112,6 +112,9 @@ def new(
     elif isinstance(examples, Dataset):
         dataset = from_dataset(
             examples, immutable_warranty=immutable_warranty, name=name)
+    elif isinstance(examples, (str, Path)):
+        dataset = from_file(
+            examples, immutable_warranty=immutable_warranty, name=name)
     else:
         raise TypeError(type(examples), examples)
     return dataset
@@ -140,6 +143,52 @@ def from_list(
     return ListDataset(examples, name=name).map(deserialize)
 
 
+def from_file(
+        examples: 'str | Path',
+        immutable_warranty: str = 'pickle',
+        name: str = None,
+):
+    """
+    Create a dataset from the examples stored in a JSON or YAML file.
+
+    Args:
+        examples: Path to a file with the suffix `.json`, `.yaml` or
+            `.yml`. The file content must be a list or a dict.
+        immutable_warranty: Forwarded to `new`. With `None`, the loaded
+            examples are wrapped in a `ListDataset` directly.
+        name: Name that is forwarded to the created dataset.
+
+    Returns:
+        The dataset that was created from the loaded examples.
+
+    Raises:
+        NotImplementedError: If the file suffix is not supported.
+    """
+    assert isinstance(examples, (str, Path)), examples
+    examples = Path(examples)
+    if examples.suffix == '.json':
+        import json
+        with open(examples) as fd:
+            examples = json.load(fd)
+    elif examples.suffix in ('.yaml', '.yml'):
+        import yaml
+        with open(examples) as fd:
+            # safe_load: `yaml.load` without a Loader raises a TypeError
+            # with PyYAML >= 6 and may construct arbitrary Python objects
+            # from the file with older releases.
+            examples = yaml.safe_load(fd)
+    else:
+        raise NotImplementedError(examples.suffix, examples)
+
+    assert isinstance(examples, (tuple, list, dict)), (type(examples), examples)
+
+    if immutable_warranty is None:
+        # NOTE(review): A dict is also wrapped in a ListDataset here;
+        # confirm that this is intended.
+        return ListDataset(examples, name=name)
+    return new(examples, immutable_warranty=immutable_warranty, name=name)
+
+
 def from_dataset(
         examples: 'Dataset',
         immutable_warranty: str = 'pickle',