Skip to content

Commit

Permalink
.index() instead of .index[]
Browse files Browse the repository at this point in the history
  • Loading branch information
zacool64 committed Oct 29, 2024
1 parent 981c9a4 commit 9c98ed2
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions torch_geometric/datasets/web_qsp_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,6 @@ class WebQSPDataset(InMemoryDataset):
use_pcst (bool, optional): Whether to preprocess the dataset's graph
with PCST or return the full graphs. (default: :obj:`True`)
"""
split_vals = ["train", "val", "test"]

def __init__(
self,
root: str,
Expand All @@ -189,7 +187,7 @@ def __init__(
self.force_reload = force_reload
super().__init__(root, force_reload=force_reload)

if split not in set(self.split_vals):
if split not in set(self.raw_file_names):
raise ValueError(f"Invalid 'split' argument (got {split})")

self._load_raw_data()
Expand All @@ -212,7 +210,7 @@ def _check_dependencies(self) -> None:

@property
def raw_file_names(self) -> List[str]:
return self.split_vals
return ["train", "val", "test"]

@property
def processed_file_names(self) -> List[str]:
Expand All @@ -224,7 +222,7 @@ def processed_file_names(self) -> List[str]:
"pre_transform.pt",
"large_graph_indexer",
]
split_file = file_lst.pop(self.split_vals.index(self.split))
split_file = file_lst.pop(self.raw_file_names.index(self.split))
file_lst.insert(0, split_file)
return file_lst

Expand All @@ -239,7 +237,7 @@ def _load_raw_data(self) -> None:
import datasets
if not hasattr(self, "raw_dataset"):
self.raw_dataset = datasets.load_from_disk(
self.raw_paths[self.split_vals.index[self.split]])
self.raw_paths[self.raw_file_names.index(self.split)])

if self.limit >= 0:
self.raw_dataset = self.raw_dataset.select(
Expand Down

0 comments on commit 9c98ed2

Please sign in to comment.