Skip to content

Commit

Permalink
feat(opendataset): add dataloader for SVHN dataset
Browse files Browse the repository at this point in the history
PR Closed: Graviti-AI#896
  • Loading branch information
graczhual committed Aug 10, 2021
1 parent 360cf1a commit 2810d19
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tensorbay/opendataset/SVHN/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,24 @@ def SVHN(path: str) -> Dataset:
except ModuleNotFoundError as error:
raise ModuleImportError(error.name) from error # type: ignore[arg-type]

root_path = os.path.abspath(os.path.expanduser(path))
root_path = os.path.join(os.path.abspath(os.path.expanduser(path)), "FullNumbers")
dataset = Dataset(DATASET_NAME)
dataset.load_catalog(os.path.join(os.path.dirname(__file__), "catalog.json"))

for segment_name in _SEGMENTS:
segment = dataset.create_segment(segment_name)
mat = File(os.path.join(root_path, "FullNumbers", segment_name, "digitStruct.mat"))
file_path = os.path.join(root_path, segment_name)
mat = File(os.path.join(file_path, "digitStruct.mat"))
names = mat["digitStruct"]["name"]
bboxes = mat["digitStruct"]["bbox"]
for name, bbox in zip(names, bboxes):
segment.append(_get_data(mat, name, bbox))
segment.append(_get_data(mat, name, bbox, file_path))
return dataset


def _get_data(mat: Any, name: Any, bbox: Any) -> Data:
def _get_data(mat: Any, name: Any, bbox: Any, file_path: str) -> Data:
image_path = "".join(chr(v[0]) for v in mat[name[0]])
data = Data(image_path, target_remote_path=image_path.zfill(10))
data = Data(os.path.join(file_path, image_path), target_remote_path=image_path.zfill(10))
data.label.box2d = []
mat_bbox = mat[bbox[0]]

Expand Down

0 comments on commit 2810d19

Please sign in to comment.