diff --git a/ppsci/data/__init__.py b/ppsci/data/__init__.py
index 0faea3beb..5e2b36057 100644
--- a/ppsci/data/__init__.py
+++ b/ppsci/data/__init__.py
@@ -83,7 +83,7 @@ def build_dataloader(_dataset, cfg):
     # build collate_fn if specified
     batch_transforms_cfg = cfg.pop("batch_transforms", None)
 
-    collate_fn = batch_transform.default_collate_fn
+    collate_fn = None
     if isinstance(batch_transforms_cfg, (list, tuple)):
         collate_fn = batch_transform.build_batch_transforms(batch_transforms_cfg)
 
@@ -98,8 +98,13 @@ def build_dataloader(_dataset, cfg):
     # build dataloader
     if getattr(_dataset, "use_pgl", False):
         # Use special dataloader from "Paddle Graph Learning" toolkit.
-        from pgl.utils import data as pgl_data
+        try:
+            from pgl.utils import data as pgl_data
+        except ModuleNotFoundError:
+            logger.error("Please install pgl with `pip install pgl`.")
+            raise ModuleNotFoundError("pgl is not installed.")
 
+        collate_fn = batch_transform.default_collate_fn
         dataloader_ = pgl_data.Dataloader(
             dataset=_dataset,
             batch_size=cfg["batch_size"],
diff --git a/ppsci/data/process/batch_transform/__init__.py b/ppsci/data/process/batch_transform/__init__.py
index badee27b7..e45279bda 100644
--- a/ppsci/data/process/batch_transform/__init__.py
+++ b/ppsci/data/process/batch_transform/__init__.py
@@ -46,20 +46,6 @@ def default_collate_fn(batch: List[Any]) -> Any:
     sample = batch[0]
     if sample is None:
         return None
-    elif isinstance(sample, pgl.Graph):
-        graph = pgl.Graph(
-            num_nodes=sample.num_nodes,
-            edges=sample.edges,
-        )
-        graph.x = paddle.concat([g.x for g in batch])
-        graph.y = paddle.concat([g.y for g in batch])
-        graph.edge_index = paddle.concat([g.edge_index for g in batch], axis=1)
-        graph.edge_attr = paddle.concat([g.edge_attr for g in batch])
-        graph.pos = paddle.concat([g.pos for g in batch])
-        graph.shape = [
-            len(batch),
-        ]
-        return graph
     elif isinstance(sample, np.ndarray):
         batch = np.stack(batch, axis=0)
         return batch
@@ -77,10 +63,20 @@ def default_collate_fn(batch: List[Any]) -> Any:
         if not all(len(sample) == sample_fields_num for sample in iter(batch)):
            raise RuntimeError("fileds number not same among samples in a batch")
         return [default_collate_fn(fields) for fields in zip(*batch)]
+    elif str(type(sample)) == "<class 'pgl.graph.Graph'>":
+        # use str(type()) instead of isinstance() in case pgl is not installed.
+        graph = pgl.Graph(num_nodes=sample.num_nodes, edges=sample.edges)
+        graph.x = paddle.concat([g.x for g in batch])
+        graph.y = paddle.concat([g.y for g in batch])
+        graph.edge_index = paddle.concat([g.edge_index for g in batch], axis=1)
+        graph.edge_attr = paddle.concat([g.edge_attr for g in batch])
+        graph.pos = paddle.concat([g.pos for g in batch])
+        graph.shape = [len(batch)]
+        return graph
 
     raise TypeError(
-        "batch data can only contains: tensor, numpy.ndarray, "
-        f"dict, list, number, None, but got {type(sample)}"
+        "batch data can only contain: Tensor, numpy.ndarray, "
+        f"dict, list, number, None, pgl.Graph, but got {type(sample)}"
     )
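
For context, here is a minimal standalone sketch of the optional-dependency pattern this patch relies on: the graph check compares the class name as a string, so reaching that branch does not itself require `pgl` to be importable, and `pgl` is only imported lazily when graph samples actually appear. The helper names `_is_pgl_graph` and `collate_graphs` are illustrative only and are not part of the patch.

```python
from typing import Any, List


def _is_pgl_graph(sample: Any) -> bool:
    # Compare the class name as a string so this check is safe even when
    # pgl is not installed (isinstance() would require importing pgl.Graph).
    return str(type(sample)) == "<class 'pgl.graph.Graph'>"


def collate_graphs(batch: List[Any]):
    """Hypothetical helper mirroring the pgl.Graph branch of default_collate_fn."""
    sample = batch[0]
    if not _is_pgl_graph(sample):
        raise TypeError(f"expected pgl.Graph samples, but got {type(sample)}")

    # Import lazily and fail with an actionable message if pgl is missing,
    # analogous to the try/except added in build_dataloader above.
    try:
        import paddle
        import pgl
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError("Please install pgl with `pip install pgl`.") from e

    # Build the batched graph as in the patch: graph structure taken from the
    # first sample, node/edge features concatenated across the whole batch.
    graph = pgl.Graph(num_nodes=sample.num_nodes, edges=sample.edges)
    graph.x = paddle.concat([g.x for g in batch])
    graph.y = paddle.concat([g.y for g in batch])
    graph.edge_index = paddle.concat([g.edge_index for g in batch], axis=1)
    graph.edge_attr = paddle.concat([g.edge_attr for g in batch])
    graph.pos = paddle.concat([g.pos for g in batch])
    graph.shape = [len(batch)]
    return graph
```

The string comparison trades a little robustness (it will not match subclasses of `pgl.graph.Graph`) for keeping `pgl` an optional dependency, which is also why the patch moves the `pgl.utils.data` import in `build_dataloader` behind a `try`/`except`.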