Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not run AccelerateMatmul on pre-Volta GPUs #1505

Merged
merged 3 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class BlockedToMMA : public mlir::RewritePattern {
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
if (computeCapability < 70)
return failure();
auto dotOp = cast<triton::DotOp>(op);
// TODO: Check data-types and SM compatibility
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
Expand Down
11 changes: 11 additions & 0 deletions python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from functools import wraps
from typing import List, Optional, Sequence, Tuple, TypeVar

import torch

import triton
from . import core as tl
from triton._C.libtriton.triton import ir

Expand Down Expand Up @@ -1180,6 +1183,14 @@ def dot(lhs: tl.tensor,
allow_tf32: bool,
out_dtype: tl.dtype,
builder: ir.builder) -> tl.tensor:
if torch.version.hip is None:
device = triton.runtime.jit.get_current_device()
capability = triton.runtime.jit.get_device_capability(device)
capability = capability[0] * 10 + capability[1]
if capability < 70:
assert (
not rhs.dtype.is_fp16() and not rhs.dtype.is_fp8()
), "Float8 and Float16 types are not supported for compute capability < 70 (use Float32 or above)"
assert lhs.type.is_block() and rhs.type.is_block()
assert lhs.dtype == rhs.dtype, "lhs and rhs must have the same dtype!"
assert len(lhs.shape) == 2 and len(rhs.shape) == 2
Expand Down