前阵子在 GTC 上跟不少人聊,发现大家普遍有一个共识:想写出性能最优的 kernel,要么学 cute DSL,要么上 CUTLASS。Triton 更多被看作是一个”方便但性能一般”的选项。
但在实际开发中,我发现事情没那么绝对。至少在一些相对简单的 kernel 场景下,Triton 已经能提供非常出色的算子——甚至略优于手搓的 CUDA 或 cute DSL 的实现。本文我会通过几个具体的例子和实测数据来说明这一点。
LayerNorm
第一个例子是两个最常见的 Normalize 算子:LayerNorm 和 RMSNorm(定义见 PyTorch LayerNorm、PyTorch RMSNorm)。这两个操作还常与后续算子融合(Fused)。
RMSNorm
我们先来看不做任何 Fusion 的情况。
这方面,开源社区已经有不少高质量的实现可供参考:quack 用 cute DSL 实现了一版,Flash Attention 用 Triton 实现了一版,FlashInfer 用 CUDA 实现了一版。我认为这些实现基本代表了当前开源中的最高水平。
| 不过有一点值得补充:Flash Attention 虽然已经用 Triton 写了 LayerNorm,但它的实现为了兼容 backward pass,牺牲了一些前向性能,并不是一个纯粹的前向 LayerNorm。换句话说,Flash Attention 中的版本并不代表 Triton 所能达到的性能上限。因此,我还是选择自己重新实现了一版(实现代码 | benchmark 脚本)。 |
以下分别是在 H100 和 B200 上的实测结果,实验结果用 python bench_layernorm.py 即可复现。输入 shape 为 [12800, N],测试的是长序列(Diffusion 场景)下不同 hidden size 的吞吐(单位 GB/s):
H100(镜像:lmsysorg/sglang:dev)
| N | Flash Attention (Triton) | Ours (Triton) | FlashInfer (CUDA) | QuACK (cute DSL) |
|---|---|---|---|---|
| 128 | 73.78 | 783.47 | 305.77 | 797.19 |
| 3840 | 2253.13 | 2750.58 | 2524.30 | 2734.40 |
| 4096 | 2463.13 | 2781.84 | 2580.17 | 2756.49 |
| 6144 | 2854.33 | 2865.27 | 2150.57 | 2849.27 |
| 8192 | 2898.92 | 2911.14 | 2488.05 | 2899.73 |
B200(镜像:lmsysorg/sglang:glm5-grace-blackwell)
| N | Flash Attention (Triton) | Ours (Triton) | FlashInfer (CUDA) | QuACK (cute DSL) |
|---|---|---|---|---|
| 128 | 210.20 | 797.81 | 320.19 | 799.94 |
| 3840 | 5118.03 | 5485.12 | 3199.11 | 5049.26 |
| 4096 | 5391.03 | 5529.05 | 3299.84 | 5154.21 |
| 6144 | 5088.31 | 5482.43 | 2819.55 | 5486.63 |
| 8192 | 5868.90 | 6114.58 | 3193.40 | 5847.96 |
可以看到,无论在 H100 还是 B200 上,我们用 Triton 实现的版本在各个 hidden size 下都与 QuACK(cute DSL)基本持平,甚至略优;而 FlashInfer(CUDA)和 Flash Attention(Triton)的实现则在部分场景下有明显差距。在 B200 上这一趋势更加明显。
Fused RMSNorm + Scale & Shift
再来看 Diffusion 模型中常见的场景:y = (scale + 1) · norm(x) + shift。我在前面的 Triton kernel 基础上做了简单修改,与 SGLang 中 cute DSL 的实现对比,性能略优。输入 shape 依然为 [12800, N]:
H100
| N | PyTorch | Ours (Triton) | SGLang (cute DSL) |
|---|---|---|---|
| 1024 | 329.67 | 2490.82 | 2457.02 |
| 3072 | 362.30 | 2871.96 | 2848.10 |
| 4096 | 369.27 | 2948.85 | 2920.12 |
| 6144 | 375.90 | 2989.69 | 2913.16 |
| 8192 | 380.79 | 3037.77 | 3030.36 |
B200
| N | PyTorch | Ours (Triton) | SGLang (cute DSL) |
|---|---|---|---|
| 1024 | 526.68 | 4459.59 | 4098.89 |
| 3072 | 545.95 | 5137.64 | 5060.57 |
| 4096 | 562.53 | 5847.35 | 5360.87 |
| 6144 | 579.48 | 4759.59 | 4753.59 |
| 8192 | 588.68 | 5693.42 | 5560.85 |
需要说明的是,我测试的场景还比较有限,而且与 cute DSL 版本的性能差异其实非常小,因此不能就此得出”Triton 一定优于 cute DSL”的结论。但这已经足以支撑本文的核心观点:在部分 Memory Bound 的场景下,Triton 完全可以逼近甚至匹敌专家手写的 CUDA / cute DSL 代码。
考虑到 cute DSL 的学习曲线相当陡峭,用 Triton 来实现高度定制化的 kernel 其实是一个非常务实的选择。整个 kernel 加 benchmark 脚本我花了大约两天时间,其中一大半花在了调研各种开源实现和横向对比上,真正写 Triton 代码的时间并不多。也欢迎大家在更多场景下测试验证。
QK LayerNorm
这里额外对比一个特殊场景:QK LayerNorm。输入 shape 为 [B, H, N],其中 N 通常为 128 或 64,H 为 num heads,weight shape 为 [N]。
这个场景的特殊之处在于,输入并不总是 contiguous 的。考虑 qkv = self.to_qkv(x); q, k, v = torch.chunk(qkv, dim=-1) 这种典型用法,只有 [H, N] 维度是连续的,B 维度则不是。而前面实现的 LayerNorm kernel 要求输入连续。此外,从前面的结果也可以看到,N = 128 时直接调用 FlashInfer 的效果其实很差。因此,SGLang 选择针对这个场景单独写一个 kernel,这个做法是合理的。我也针对这个场景专门写了一版(实现代码 |
benchmark 脚本)。 |
下面分别对比三种实现方式:
- layer_norm (my) — 直接复用通用的 LayerNorm kernel 来做 Q Norm。
- qk_norm (my) — 专门为 Q Norm 场景写的 kernel,可以处理非连续的输入。
- fused_qk (my) — 一个 kernel 同时处理 Q Norm 和 K Norm,函数签名与当前 SGLang 的版本一致,唯一区别是非 in-place 的。我个人认为在 Diffusion 场景下不差这点显存,in-place 并非必要。
实验结果用 python bench_qk_layernorm.py 即可复现。输入 shape 为 [B, 24, 128](对应 Qwen-Image 模型的配置),分别测试 contiguous 和 non-contiguous 两种情况:
H100 — Non-contiguous 输入(throughput, GB/s)
| B | PyTorch | layer_norm (my) | qk_norm (my) | fused_qk (my) | SGLang (cute DSL) |
|---|---|---|---|---|---|
| 1024 | 3970.29 | 1097.14 | 3688.93 | 4268.97 | 3734.56 |
| 4096 | 2397.25 | 1093.51 | 2841.04 | 3089.91 | 3054.67 |
| 12800 | 2678.92 | 1042.60 | 2873.86 | 2970.73 | 2795.14 |
H100 — Contiguous 输入(throughput, GB/s)
| B | PyTorch | layer_norm (my) | qk_norm (my) | fused_qk (my) | SGLang (cute DSL) |
|---|---|---|---|---|---|
| 1024 | 3824.63 | 2891.10 | 3652.43 | 4248.65 | 3736.72 |
| 4096 | 2642.82 | 2595.48 | 2654.70 | 2801.24 | 2656.40 |
| 12800 | 2908.33 | 2868.07 | 2907.03 | 2954.84 | 2789.19 |
B200 — Non-contiguous 输入(throughput, GB/s)
| B | PyTorch | layer_norm (my) | qk_norm (my) | fused_qk (my) | SGLang (cute DSL) |
|---|---|---|---|---|---|
| 1024 | 4224.25 | 1384.62 | 4368.59 | 5568.24 | 4104.05 |
| 4096 | 3395.21 | 1466.98 | 5240.74 | 6323.57 | 4921.51 |
| 12800 | 3576.24 | 1447.01 | 5293.08 | 5884.29 | 4605.01 |
B200 — Contiguous 输入(throughput, GB/s)
| B | PyTorch | layer_norm (my) | qk_norm (my) | fused_qk (my) | SGLang (cute DSL) |
|---|---|---|---|---|---|
| 1024 | 4900.05 | 4222.25 | 4379.33 | 5586.45 | 4090.27 |
| 4096 | 6574.06 | 6083.06 | 4868.56 | 5766.84 | 4914.45 |
| 12800 | 6330.76 | 6233.05 | 5135.25 | 5605.78 | 4449.87 |
几个值得注意的点:通用 LayerNorm kernel 在 non-contiguous 场景下性能骤降(从 ~2800 跌至 ~1000 GB/s),这是因为引入了 tensor.contiguous() 这个额外的 copy 操作。而我们的 Fused Triton 版本(fused_qk)在两种输入模式下均表现得很不错。在 B200 上优势更为突出。
对更简单的 Kernel,请尝试 torch.compile
我还额外实现了一个针对 y = w * x + b 的 kernel(实现代码 |
benchmark 脚本)。SGLang 中其实也用 Triton 实现了一个 scale_shift kernel,我是写完之后才发现的——不过意外地发现它的性能不如我的版本。但这些都不重要了,因为两者都不如 torch.compile 生成的 kernel。 |
按理说,这时候我应该把 torch.compile 生成的 Triton kernel 拿出来研究一下,看看它做了什么优化。但我毕竟不是专职的 kernel engineer,感兴趣的同学可以自行探索。我还额外试了 torch.cat,同样发现 torch.compile 能带来明显的加速。
这其实解释了一个常见的困惑:为什么 torch.compile(model) 有时能加速,有时却没效果?
原因在于:对于 LayerNorm 这类本身已经高度优化的 kernel,torch.compile 并不会带来提升,甚至偶尔会更慢;但对于 y = w * x + b 或 torch.cat 这类极其简单的操作,torch.compile 生成的 fused kernel 效果很好。明白了这一点,以后就没必要对整个 model 无脑 torch.compile 了——那样 CPU overhead 太大。更好的做法是:只对这几个能从 torch compile 中受益的小 kernel 单独 compile。
以下是 y = w * x + b 的实测结果(throughput, GB/s),输入 shape 为 [1, S, 5120]:
Mult and Add: y = w * x + b
H100
| S | PyTorch | Ours (Triton) | SGLang (Triton) |
|---|---|---|---|
| 512 | 1436.44 | 1334.07 | 1254.17 |
| 4096 | 2158.38 | 2098.66 | 1836.13 |
| 6144 | 2219.55 | 2174.83 | 1881.81 |
| 8192 | 2252.52 | 2218.57 | 1900.80 |
B200
| S | PyTorch | Ours (Triton) | SGLang (Triton) |
|---|---|---|---|
| 512 | 2188.33 | 1725.52 | 1956.76 |
| 4096 | 4419.27 | 3858.74 | 3801.70 |
| 6144 | 4688.39 | 4353.55 | 4125.89 |
| 8192 | 4795.70 | 4520.54 | 4252.69 |
我们的 Triton 版本在 H100 上优于 SGLang 的 Triton 实现——一方面是 SGLang 里的 block size 估计没有专门调优过,另一方面那个 kernel 的目标场景也没有我这么 specific。无论如何,两者都不如 torch.compile 生成的 fused kernel。
Summary
最后需要强调的是,本文并不是在说”会 Triton 就够了”。对于 Attention 和 GEMM 这两个至关重要的算子,Triton 目前肯定是不够的——这部分内容会在后续文章中展开(没错,已经开始给自己挖坑了)。
本文也没有拉踩任何项目的意思。SGLang、Flash Attention、QuACK 都是非常优秀的开源库,而我这种专门针对一个极其具体的场景去手写 kernel 的做法,本身就带有一定的”不公平优势”,测试 case 也远称不上充分。
但这篇文章想反驳的是一个常见的偏见:“Triton 就是不如 cute DSL 或 CUDA”。至少从我的实验来看,用 Triton 写的这些 kernel 在 Diffusion 场景下具备一定的通用性,性能也不差——在多数测试场景下甚至略优于各个 baseline。