
Commit

refine collate_fn
HydrogenSulfate committed Sep 26, 2023
1 parent 8fee479 commit f242756
Showing 2 changed files with 19 additions and 18 deletions.
9 changes: 7 additions & 2 deletions ppsci/data/__init__.py
@@ -83,7 +83,7 @@ def build_dataloader(_dataset, cfg):
     # build collate_fn if specified
     batch_transforms_cfg = cfg.pop("batch_transforms", None)
 
-    collate_fn = batch_transform.default_collate_fn
+    collate_fn = None
     if isinstance(batch_transforms_cfg, (list, tuple)):
         collate_fn = batch_transform.build_batch_transforms(batch_transforms_cfg)

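With this hunk, collate_fn stays None (letting the DataLoader's default batching apply) unless a batch_transforms list is given in the config, in which case build_batch_transforms supplies the collate function. As a rough, illustrative sketch of how such a composed collate function can be structured (the helper name and transform interface below are assumptions, not ppsci's actual implementation):

from typing import Any, Callable, List

import numpy as np


def build_collate_fn_sketch(batch_transforms: List[Callable]) -> Callable:
    """Compose batch-level transforms into a collate_fn (illustrative sketch only)."""

    def collate_fn(batch: List[Any]) -> Any:
        # Apply each configured transform to the raw list of samples first,
        # then stack the transformed samples into a single batch array.
        for transform in batch_transforms:
            batch = transform(batch)
        return np.stack(batch, axis=0)

    return collate_fn


# Usage with one hypothetical transform that scales every sample in the batch.
scale = lambda batch: [sample * 2.0 for sample in batch]
fn = build_collate_fn_sketch([scale])
print(fn([np.array([1.0, 2.0]), np.array([3.0, 4.0])]).shape)  # (2, 2)
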
@@ -98,8 +98,13 @@ def build_dataloader(_dataset, cfg):
     # build dataloader
     if getattr(_dataset, "use_pgl", False):
         # Use special dataloader from "Paddle Graph Learning" toolkit.
-        from pgl.utils import data as pgl_data
+        try:
+            from pgl.utils import data as pgl_data
+        except ModuleNotFoundError:
+            logger.error("Please install pgl with `pip install pgl`.")
+            raise ModuleNotFoundError("pgl is not installed.")
 
+        collate_fn = batch_transform.default_collate_fn
         dataloader_ = pgl_data.Dataloader(
             dataset=_dataset,
             batch_size=cfg["batch_size"],
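
The second hunk wraps the pgl import so that environments without the optional "Paddle Graph Learning" toolkit only fail when a graph dataset is actually requested, and it reinstates default_collate_fn for that path since the generic collate_fn now defaults to None. A minimal standalone sketch of the same lazy optional-import pattern (function name is illustrative):

import logging

logger = logging.getLogger(__name__)


def require_pgl_data():
    """Import pgl's data utilities lazily, failing with an actionable hint."""
    try:
        from pgl.utils import data as pgl_data
    except ModuleNotFoundError:
        # Surface an install hint in the logs before propagating the error.
        logger.error("Please install pgl with `pip install pgl`.")
        raise ModuleNotFoundError("pgl is not installed.")
    return pgl_data
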
28 changes: 12 additions & 16 deletions ppsci/data/process/batch_transform/__init__.py
@@ -46,20 +46,6 @@ def default_collate_fn(batch: List[Any]) -> Any:
     sample = batch[0]
     if sample is None:
         return None
-    elif isinstance(sample, pgl.Graph):
-        graph = pgl.Graph(
-            num_nodes=sample.num_nodes,
-            edges=sample.edges,
-        )
-        graph.x = paddle.concat([g.x for g in batch])
-        graph.y = paddle.concat([g.y for g in batch])
-        graph.edge_index = paddle.concat([g.edge_index for g in batch], axis=1)
-        graph.edge_attr = paddle.concat([g.edge_attr for g in batch])
-        graph.pos = paddle.concat([g.pos for g in batch])
-        graph.shape = [
-            len(batch),
-        ]
-        return graph
     elif isinstance(sample, np.ndarray):
         batch = np.stack(batch, axis=0)
         return batch
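
For context on the branches that remain at the top of default_collate_fn, the numpy path simply stacks samples along a new leading axis. A quick usage example, assuming ppsci is installed:

import numpy as np

from ppsci.data.process.batch_transform import default_collate_fn

# Two samples of shape (3,) collate into a single array of shape (2, 3).
batch = [np.array([1.0, 2.0, 3.0]), np.array([4.0, 5.0, 6.0])]
print(default_collate_fn(batch).shape)  # (2, 3)
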
@@ -77,10 +63,20 @@ def default_collate_fn(batch: List[Any]) -> Any:
         if not all(len(sample) == sample_fields_num for sample in iter(batch)):
             raise RuntimeError("fileds number not same among samples in a batch")
         return [default_collate_fn(fields) for fields in zip(*batch)]
+    elif str(type(sample)) == "<class 'pgl.graph.Graph'>":
+        # use str(type()) instead of isinstance() in case of pgl is not installed.
+        graph = pgl.Graph(num_nodes=sample.num_nodes, edges=sample.edges)
+        graph.x = paddle.concat([g.x for g in batch])
+        graph.y = paddle.concat([g.y for g in batch])
+        graph.edge_index = paddle.concat([g.edge_index for g in batch], axis=1)
+        graph.edge_attr = paddle.concat([g.edge_attr for g in batch])
+        graph.pos = paddle.concat([g.pos for g in batch])
+        graph.shape = [len(batch)]
+        return graph
 
     raise TypeError(
-        "batch data can only contains: tensor, numpy.ndarray, "
-        f"dict, list, number, None, but got {type(sample)}"
+        "batch data can only contains: Tensor, numpy.ndarray, "
+        f"dict, list, number, None, pgl.Graph, but got {type(sample)}"
     )


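The pgl.Graph branch moves below the generic branches and now recognizes graphs by comparing the type's fully qualified name as a string, so default_collate_fn no longer needs pgl importable just to evaluate its type checks. A tiny standalone sketch of that trick (helper name made up for illustration):

from typing import Any


def is_pgl_graph(sample: Any) -> bool:
    """Detect a pgl.Graph without importing pgl (sketch of the str(type()) trick)."""
    # str(type(x)) yields the fully qualified class name, so this comparison
    # works even in environments where pgl is not installed.
    return str(type(sample)) == "<class 'pgl.graph.Graph'>"


print(is_pgl_graph(42))          # False, and no pgl import was required
print(is_pgl_graph("a sample"))  # False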
