Collect segOpSizes per statement instead of per expression.

The local traversal now matches on whole Let statements (onStm) rather than
on expressions alone (onExp). Replicate no longer records the dimensions of
its shape argument; Replicate, Iota, Copy and Manifest each record
arrayDims of the type of the single pattern element they bind.

diff --git a/src/Futhark/CodeGen/ImpGen/GPU/Group.hs b/src/Futhark/CodeGen/ImpGen/GPU/Group.hs
index dbef2ec170..a337ab9732 100644
--- a/src/Futhark/CodeGen/ImpGen/GPU/Group.hs
+++ b/src/Futhark/CodeGen/ImpGen/GPU/Group.hs
@@ -684,19 +684,25 @@ data Precomputed = Precomputed
 segOpSizes :: Stms GPUMem -> SegOpSizes
 segOpSizes = onStms
   where
-    onStms = foldMap (onExp . stmExp)
-    onExp (Op (Inner (SegOp op))) =
+    onStms = foldMap onStm
+    onStm (Let _ _ (Op (Inner (SegOp op)))) =
       case segVirt $ segLevel op of
         SegNoVirtFull seq_dims ->
           S.singleton $ map snd $ snd $ partitionSeqDims seq_dims $ segSpace op
         _ -> S.singleton $ map snd $ unSegSpace $ segSpace op
-    onExp (BasicOp (Replicate shape _)) =
-      S.singleton $ shapeDims shape
-    onExp (Match _ cases defbody _) =
+    onStm (Let (Pat [pe]) _ (BasicOp (Replicate {}))) =
+      S.singleton $ arrayDims $ patElemType pe
+    onStm (Let (Pat [pe]) _ (BasicOp (Iota {}))) =
+      S.singleton $ arrayDims $ patElemType pe
+    onStm (Let (Pat [pe]) _ (BasicOp (Copy {}))) =
+      S.singleton $ arrayDims $ patElemType pe
+    onStm (Let (Pat [pe]) _ (BasicOp (Manifest {}))) =
+      S.singleton $ arrayDims $ patElemType pe
+    onStm (Let _ _ (Match _ cases defbody _)) =
       foldMap (onStms . bodyStms . caseBody) cases <> onStms (bodyStms defbody)
-    onExp (DoLoop _ _ body) =
+    onStm (Let _ _ (DoLoop _ _ body)) =
       onStms (bodyStms body)
-    onExp _ = mempty
+    onStm _ = mempty
 
 -- | Precompute various constants and useful information.
 precomputeConstants :: Count GroupSize (Imp.TExp Int64) -> Stms GPUMem -> CallKernelGen Precomputed