Skip to content

Commit a13977c

Browse files
committed
fix: avoid reloading datasets
1 parent 0d85553 commit a13977c

File tree

4 files changed

+14
-10
lines changed

4 files changed

+14
-10
lines changed

Diff for: sdk/diffgram/core/directory.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -152,20 +152,22 @@ def to_pytorch(self, transform = None):
152152
Transforms the file list inside the dataset into a pytorch dataset.
153153
:return:
154154
"""
155-
file_id_list = self.all_file_ids()
155+
file_id_list = self.file_id_list
156156
pytorch_dataset = DiffgramPytorchDataset(
157157
project = self.client,
158158
diffgram_file_id_list = file_id_list,
159-
transform = transform
159+
transform = transform,
160+
validate_ids = False
160161

161162
)
162163
return pytorch_dataset
163164

164165
def to_tensorflow(self):
165-
file_id_list = self.all_file_ids()
166+
file_id_list = self.file_id_list
166167
diffgram_tensorflow_dataset = DiffgramTensorflowDataset(
167168
project = self.client,
168-
diffgram_file_id_list = file_id_list
169+
diffgram_file_id_list = file_id_list,
170+
validate_ids = False
169171
)
170172
return diffgram_tensorflow_dataset
171173

Diff for: sdk/diffgram/core/sliced_directory.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ def to_pytorch(self, transform = None):
5050
pytorch_dataset = DiffgramPytorchDataset(
5151
project = self.client,
5252
diffgram_file_id_list = self.file_id_list,
53-
transform = transform
53+
transform = transform,
54+
validate_ids = False
5455

5556
)
5657
return pytorch_dataset
@@ -59,6 +60,7 @@ def to_tensorflow(self):
5960
file_id_list = self.all_file_ids()
6061
diffgram_tensorflow_dataset = DiffgramTensorflowDataset(
6162
project = self.client,
62-
diffgram_file_id_list = file_id_list
63+
diffgram_file_id_list = file_id_list,
64+
validate_ids = False
6365
)
6466
return diffgram_tensorflow_dataset

Diff for: sdk/diffgram/pytorch_diffgram/diffgram_pytorch_dataset.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55

66
class DiffgramPytorchDataset(DiffgramDatasetIterator, Dataset):
77

8-
def __init__(self, project, diffgram_file_id_list = None, transform = None):
8+
def __init__(self, project, diffgram_file_id_list = None, transform = None, validate_ids = True):
99
"""
1010
1111
:param project (sdk.core.core.Project): A Project object from the Diffgram SDK
1212
:param diffgram_file_list (list): An arbitrary number of file ID's from Diffgram.
1313
:param transform (callable, optional): Optional transforms to be applied on a sample
1414
"""
15-
super(DiffgramPytorchDataset, self).__init__(project, diffgram_file_id_list)
15+
super(DiffgramPytorchDataset, self).__init__(project, diffgram_file_id_list, validate_ids)
1616

1717
self.diffgram_file_id_list = diffgram_file_id_list
1818

Diff for: sdk/diffgram/tensorflow_diffgram/diffgram_tensorflow_dataset.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111

1212
class DiffgramTensorflowDataset(DiffgramDatasetIterator):
1313

14-
def __init__(self, project, diffgram_file_id_list):
14+
def __init__(self, project, diffgram_file_id_list, validate_ids = True):
1515
"""
1616
1717
:param project (sdk.core.core.Project): A Project object from the Diffgram SDK
1818
:param diffgram_file_list (list): An arbitrary number of file ID's from Diffgram.
1919
:param transform (callable, optional): Optional transforms to be applied on a sample
2020
"""
21-
super(DiffgramTensorflowDataset, self).__init__(project, diffgram_file_id_list)
21+
super(DiffgramTensorflowDataset, self).__init__(project, diffgram_file_id_list, validate_ids)
2222

2323
self.diffgram_file_id_list = diffgram_file_id_list
2424

0 commit comments

Comments
 (0)