
Commit 9a7ec49

【Inference Optimize】Support setting environment variables to enable stream_k (#74317)
1 parent 28be650 commit 9a7ec49

File tree: 5 files changed, +160 -5 lines changed

paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h

Lines changed: 2 additions & 2 deletions
@@ -75,10 +75,10 @@ enum class SplitKStyle {
   SPLIT_K_SERIAL,
   // SPLIT_K_PARALLEL // Not supported yet
 };
-
+// NOTE: (changwenbin) split_k_serial is turned on by default here.
 struct CutlassGemmConfig {
   CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
-  SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
+  SplitKStyle split_k_style = SplitKStyle::SPLIT_K_SERIAL;
   int split_k_factor = -1;
   int stages = -1;
 };

paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h

Lines changed: 10 additions & 0 deletions
@@ -76,6 +76,16 @@ struct GemmFpAIntB {
   using LayoutC = typename Mma::LayoutC;
   using ElementScale = typename Mma::IteratorA::Element;

+  // NOTE: (changwenbin) Currently only A row major and B column major are
+  // supported. Other cases have not been verified yet.
+
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
+  static_assert(
+      platform::is_same<LayoutA, layout::RowMajor>::value &&
+          platform::is_same<LayoutB, layout::ColumnMajor>::value,
+      "A must be row major and B must be col major in cuda_arch >= sm75");
+#endif
+
   static ComplexTransform const kTransformA = Mma::kTransformA;
   static ComplexTransform const kTransformB = Mma::kTransformA;

paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm_split_k.h

Lines changed: 6 additions & 2 deletions
@@ -259,7 +259,9 @@ struct GemmFpAIntBSplitK {
             ? device_sms
             : fast_min(args.avail_sms, device_sms);

-    // Initialize the block mapping structure
+    static_assert(WarpCount::kK == 1, "WarpCount::kK should always == 1");
+    // NOTE: (changwenbin) Adapt cutlass upgraded to version 3.8.0
+    // Initialize the block mapping structure
     block_mapping = ThreadblockSwizzle(
         args.mode,
         args.problem_size,
@@ -271,7 +273,9 @@
         cutlass::sizeof_bits<ElementA>::value,
         cutlass::sizeof_bits<ElementB>::value,
         cutlass::sizeof_bits<ElementC>::value,
-        ThreadblockShape::kK / (WarpCount::kK * InstructionShape::kK));
+        ThreadblockShape::kK /
+            (WarpCount::kK *
+             InstructionShape::kK));  // epilogue_acc_fragments_
   }

   /// Returns the workspace size (in bytes) needed for these parameters

paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h

Lines changed: 9 additions & 1 deletion
@@ -120,9 +120,17 @@ static std::vector<CutlassGemmConfig> get_candidate_configs(
   if (is_moe) {
     max_stages = 5;
   }
+  // NOTE: (changwenbin)
+  // Support enabling stream_k by setting the environment
+  // variable `export CUTLASS_GEMM_STREAM_K=1`.
+  SplitKStyle env_split_k = SplitKStyle::NO_SPLIT_K;
+  const char* env_stream_k = std::getenv("CUTLASS_GEMM_STREAM_K");
+  if (env_stream_k != nullptr) {
+    env_split_k = SplitKStyle::SPLIT_K_SERIAL;
+  }
   for (const auto& tile_config : tiles) {
     for (int stages = min_stages; stages <= max_stages; ++stages) {
-      CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages};
+      CutlassGemmConfig config{tile_config, env_split_k, 1, stages};
       candidate_configs.push_back(config);
     }
   }
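With the heuristic change above, serial split-k (stream_k) candidate configs are only generated when CUTLASS_GEMM_STREAM_K is present in the environment. A minimal usage sketch, assuming the variable is set before the first weight-only GEMM runs and that paddle.nn.quant is imported as Q, mirroring the new test file below; shapes, dtype, and algo here are illustrative only:

import math
import os

# Assumption: std::getenv reads the process environment when the candidate
# configs are generated, so the variable must be set before the first call.
os.environ["CUTLASS_GEMM_STREAM_K"] = "1"

import paddle
import paddle.nn.quant as Q  # alias as in the test file

x = paddle.rand(shape=(128, 8192), dtype='float16') / math.sqrt(8192)
weight = paddle.rand(shape=(8192, 8192), dtype='float16') / math.sqrt(8192)

quant_weight, quant_scale = Q.weight_quantize(x=weight.cuda(), algo='weight_only_int8')
out = Q.weight_only_linear(
    x=x,
    weight=quant_weight,
    weight_scale=quant_scale,
    weight_dtype='int8',
)

Equivalently, exporting CUTLASS_GEMM_STREAM_K=1 in the shell before launching the process has the same effect; when the variable is unset, candidate configs keep NO_SPLIT_K as before.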

test/quantization/test_weight_only_linear.py

Lines changed: 133 additions & 0 deletions
@@ -925,5 +925,138 @@ def test_weightonly_linear_backward(
         )


+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11020
+    or paddle.device.cuda.get_device_capability()[0] < 8,
+    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+)
+class WeightOnlyLinear_stream_k_TestCase(unittest.TestCase):
+
+    def test_weightonly_linear_backward_int4(self):
+        def test_weightonly_linear_backward(
+            self, algo='weight_only_int4', weight_dtype='int4'
+        ):
+            x = (
+                paddle.rand(shape=(128, 8192), dtype='float16')
+                * 1
+                / math.sqrt(8192)
+            )
+            x.stop_gradient = False
+            quant_x = copy.deepcopy(x)
+            quant_x.stop_gradient = False
+            weight = (
+                paddle.rand(shape=(8192, 8192), dtype='float16')
+                * 1
+                / math.sqrt(8192)
+            )
+
+            quant_weight, quant_scale = Q.weight_quantize(
+                x=weight.cuda(), algo=algo
+            )
+
+            quant_out = Q.weight_only_linear(
+                x=quant_x,
+                weight=quant_weight,
+                weight_scale=quant_scale,
+                weight_dtype=weight_dtype,
+            )
+
+        test_weightonly_linear_backward(self)
+
+    def test_weightonly_linear_backward_int4_bf16(self):
+        def test_weightonly_linear_backward(
+            self, algo='weight_only_int4', weight_dtype='int4'
+        ):
+            x = (
+                paddle.rand(shape=(128, 8192), dtype='bfloat16')
+                * 1
+                / math.sqrt(8192)
+            )
+            x.stop_gradient = False
+            quant_x = copy.deepcopy(x)
+            quant_x.stop_gradient = False
+            weight = (
+                paddle.rand(shape=(8192, 8192), dtype='bfloat16')
+                * 1
+                / math.sqrt(8192)
+            )
+
+            quant_weight, quant_scale = Q.weight_quantize(
+                x=weight.cuda(), algo=algo
+            )
+
+            quant_out = Q.weight_only_linear(
+                x=quant_x,
+                weight=quant_weight,
+                weight_scale=quant_scale,
+                weight_dtype=weight_dtype,
+            )
+
+        test_weightonly_linear_backward(self)
+
+    def test_weightonly_linear_backward_int8(self):
+        def test_weightonly_linear_backward(
+            self, algo='weight_only_int8', weight_dtype='int8'
+        ):
+            x = (
+                paddle.rand(shape=(128, 8192), dtype='float16')
+                * 1
+                / math.sqrt(8192)
+            )
+            x.stop_gradient = False
+            quant_x = copy.deepcopy(x)
+            quant_x.stop_gradient = False
+            weight = (
+                paddle.rand(shape=(8192, 8192), dtype='float16')
+                * 1
+                / math.sqrt(8192)
+            )
+
+            quant_weight, quant_scale = Q.weight_quantize(
+                x=weight.cuda(), algo=algo
+            )
+
+            quant_out = Q.weight_only_linear(
+                x=quant_x,
+                weight=quant_weight,
+                weight_scale=quant_scale,
+                weight_dtype=weight_dtype,
+            )
+
+        test_weightonly_linear_backward(self)
+
+    def test_weightonly_linear_backward_int8_bf16(self):
+        def test_weightonly_linear_backward(
+            self, algo='weight_only_int8', weight_dtype='int8'
+        ):
+            x = (
+                paddle.rand(shape=(128, 8192), dtype='bfloat16')
+                * 1
+                / math.sqrt(8192)
+            )
+            x.stop_gradient = False
+            quant_x = copy.deepcopy(x)
+            quant_x.stop_gradient = False
+            weight = (
+                paddle.rand(shape=(8192, 8192), dtype='bfloat16')
+                * 1
+                / math.sqrt(8192)
+            )
+
+            quant_weight, quant_scale = Q.weight_quantize(
+                x=weight.cuda(), algo=algo
+            )
+
+            quant_out = Q.weight_only_linear(
+                x=quant_x,
+                weight=quant_weight,
+                weight_scale=quant_scale,
+                weight_dtype=weight_dtype,
+            )
+
+        test_weightonly_linear_backward(self)
+
+
 if __name__ == '__main__':
     unittest.main()
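The new WeightOnlyLinear_stream_k_TestCase exercises the serial split-k path only when the environment variable is exported before the process starts the GEMMs. A sketch of driving just this class with the flag set, assuming the test module can be imported from the repository checkout (otherwise add test/quantization to sys.path first):

import os
import sys
import unittest

os.environ["CUTLASS_GEMM_STREAM_K"] = "1"  # must precede the first weight-only GEMM

# Hypothetical path handling; adjust to your checkout layout.
sys.path.insert(0, "test/quantization")
from test_weight_only_linear import WeightOnlyLinear_stream_k_TestCase

suite = unittest.TestLoader().loadTestsFromTestCase(WeightOnlyLinear_stream_k_TestCase)
unittest.TextTestRunner(verbosity=2).run(suite)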
