diff --git a/python/oneflow/test/graph/test_alexnet_auto_parallel.py b/python/oneflow/test/graph/test_alexnet_auto_parallel.py index f33b4e905dc..2d9394221e7 100644 --- a/python/oneflow/test/graph/test_alexnet_auto_parallel.py +++ b/python/oneflow/test/graph/test_alexnet_auto_parallel.py @@ -100,6 +100,10 @@ class AlexNetEvalGraph(flow.nn.Graph): def __init__(self): super().__init__() self.alexnet = alexnet_module + self.config.enable_auto_parallel(True) + self.config.enable_auto_parallel_prune_parallel_cast_ops(True) + self.config.enable_auto_parallel_mainstem_algo(True) + self.config.enable_auto_parallel_sbp_collector(True) def build(self, image): with flow.no_grad():