diff --git a/docs/code/use_dataset_in_paddlepaddle.py b/docs/code/use_dataset_in_paddlepaddle.py index 709aea2f5..22d0701eb 100644 --- a/docs/code/use_dataset_in_paddlepaddle.py +++ b/docs/code/use_dataset_in_paddlepaddle.py @@ -19,12 +19,12 @@ from tensorbay.dataset import Dataset as TensorBayDataset -class MNISTSegment(Dataset): - """class for wrapping a MNIST segment.""" +class DogsVsCatsSegment(Dataset): + """class for wrapping a DogsVsCats segment.""" def __init__(self, gas, segment_name, transform): super().__init__() - self.dataset = TensorBayDataset("MNIST", gas) + self.dataset = TensorBayDataset("DogsVsCats", gas) self.segment = self.dataset[segment_name] self.category_to_index = self.dataset.catalog.classification.get_category_to_index() self.transform = transform @@ -38,10 +38,9 @@ def __getitem__(self, idx): image_tensor = self.transform(Image.open(fp)) return image_tensor, self.category_to_index[data.label.classification.category] + # """""" -"""""" - """Build a dataloader and run it""" ACCESS_KEY = "Accesskey-*****" @@ -49,7 +48,7 @@ def __getitem__(self, idx): normalization = transforms.Normalize(mean=[0.485], std=[0.229]) my_transforms = transforms.Compose([to_tensor, normalization]) -train_segment = MNISTSegment(GAS(ACCESS_KEY), segment_name="train", transform=my_transforms) +train_segment = DogsVsCatsSegment(GAS(ACCESS_KEY), segment_name="train", transform=my_transforms) train_dataloader = DataLoader(train_segment, batch_size=4, shuffle=True, num_workers=0) for index, (image, label) in enumerate(train_dataloader): diff --git a/docs/source/integrations/paddlepaddle.rst b/docs/source/integrations/paddlepaddle.rst index 85cb94cd2..5a926ae8f 100644 --- a/docs/source/integrations/paddlepaddle.rst +++ b/docs/source/integrations/paddlepaddle.rst @@ -3,7 +3,7 @@ ############### This topic describes how to integrate TensorBay dataset with PaddlePaddle Pipeline -using the `MNIST Dataset `_ as an example. +using the `DogsVsCats Dataset `_ as an example. The typical method to integrate TensorBay dataset with PaddlePaddle is to build a "Segment" class derived from ``paddle.io.Dataset``. @@ -11,11 +11,11 @@ derived from ``paddle.io.Dataset``. .. literalinclude:: ../../../docs/code/use_dataset_in_paddlepaddle.py :language: python :start-after: """Build a Segment class""" - :end-before: """""" + :end-before: # """""" Using the following code to create a PaddlePaddle dataloader and run it: .. literalinclude:: ../../../docs/code/use_dataset_in_paddlepaddle.py :language: python :start-after: """Build a dataloader and run it""" - :end-before: """""" \ No newline at end of file + :end-before: """"""