1. 依赖拓扑 (Dependency Topology)
作为融合优化的起点,本章立足于计算图 (Computational Graph) 的宏观视角。在这一层级,编译器不关注算子内部的具体实现,而是聚焦于算子节点之间的数据依赖 (Data Dependency) 与连接关系。
深度学习编译器(如TVM[1]、Glow[5]、XLA[10])的核心目标是通过重组图结构,识别并消除冗余的内存读写操作。通过分析生产者-消费者的垂直链条、多分支的水平结构以及特定的子图模式,编译器决定将哪些独立的算子节点"坍缩"为一个内核,从而在算法层面最小化数据搬运的开销。
这种图级优化对于现代深度学习模型至关重要,因为:
- 内存墙问题:数据移动能耗远高于计算能耗(约100倍差异)[1]
- Kernel Launch开销:频繁的CPU-GPU同步会显著降低性能[5]
- Cache局部性:融合可提高数据重用率,减少全局内存访问[2]
AI编译器领域的奠基性工作(如TVM[1]、Tensor Comprehensions[4])已证明,图级融合优化可带来2-10倍的性能提升。
1.1 垂直融合 (Vertical Fusion)
垂直融合沿着数据流的生产者-消费者链条进行优化,是最经典的融合模式[1,3]。Apollo[11]等研究表明,基于生产者-消费者关系的融合是深度学习编译器的核心优化技术之一。
1.1.1 Producer-Consumer Fusion (生产者-消费者融合)
背景
在深度学习模型中,算子之间往往存在严格的数据依赖关系:后一个算子需要前一个算子的输出作为输入。这种依赖链形成了 "生产者-消费者" (Producer-Consumer) 关系[12]。
传统执行模式下,每个算子独立编译成 kernel,物化 (Materialization,中间结果) 需要写入全局内存 (Global Memory),然后下一个 kernel 再重新读取。这种"物化"过程带来了显著的开销:
- 内存带宽压力:每个中间张量都需要写回主存再读出
- Cache利用率低:物化可能驱逐有用的Cache内容
- Kernel启动开销:每个算子单独启动一个Kernel
TVM[1]的研究表明,消除中间物化可减少50-80%的内存访问量。
核心思想
消除中间物化 (Intermediate Materialization Elimination):将生产者的计算内联到消费者中,使中间结果仅存活于寄存器/L1 Cache,避免写回全局内存[1,3]。
// 融合前
C = Add(A, B) // 写入全局内存
D = ReLU(C) // 从全局内存读取C
// 融合后
C = ReLU(Add(A, B)) // 物化仅存活于寄存器应用场景
| 场景 | 特征 | 示例 |
|---|---|---|
| 逐元素操作链 | 多个 Element-wise 操作串联 | Add → ReLU → Mul → Sigmoid |
| 卷积后激活 | 卷积 + 激活函数 | Conv2d → BiasAdd → ReLU |
| 归约后处理 | 归约操作 + 后处理 | ReduceMax → Subtract (LogSoftmax) |
| 矩阵乘法融合 | GEMM + bias/activation | MatMul → BiasAdd → Gelu |
技术原理
Producer-Consumer Fusion 的决策依赖于以下几个关键因素[1,3,6,11]:
依赖链分析 (Dependency Chain Analysis)
- 构建计算图的 依赖图 (Dependency Graph)
- 识别 合法融合候选 (Legal Fusion Candidates)
- 检测 循环依赖 (Cyclic Dependency)
内存代价评估 (Memory Cost Estimation)
DNNFusion[6]提出了形式化的代价模型来评估融合收益。
计算兼容性检查 (Computational Compatibility)
- 迭代空间一致性 (Iteration Space Alignment)
- 仿射变换兼容性 (Affine Transform Compatibility)
- 归约语义保持 (Reduction Semantic Preservation)
MLIR 实现方案
MLIR 通过多级 IR (Linalg、Affine、SCF) 和模式重写框架实现 Producer-Consumer Fusion[1,4]。MLIR的融合策略借鉴了TVM[1]和Tensor Comprehensions[4]的设计思想。
核心 Dialect 与 Pass:
| Dialect/Pass | 作用 | 关键数据结构 |
|---|---|---|
linalg | 结构化算子定义 | linalg.generic, linalg.matmul |
-linalg-fuse-elementwise-ops | Element-wise 融合 | ElementwiseOpFusionPass |
-affine-loop-fusion | 循环级融合 | AffineLoopFusion |
scf.for | 结构化控制流 | scf::ForOp |
实现架构:
┌─────────────────────────────────────────────────────────┐
│ MLIR Fusion Pipeline │
├─────────────────────────────────────────────────────────┤
│ 1. 依赖分析 (Dependency Analysis) │
│ └── DominanceInfo, PostDominanceInfo │
├─────────────────────────────────────────────────────────┤
│ 2. 融合候选识别 (Fusion Candidate Identification) │
│ └── linalg.generic 的 producer-consumer 匹配 │
├─────────────────────────────────────────────────────────┤
│ 3. 合法性检查 (Legality Check) │
│ ├── 迭代空间兼容性 (Iteration Space Compatibility) │
│ ├── Side-effect 检查 │
│ └── 形状推断 (Shape Inference) │
├─────────────────────────────────────────────────────────┤
│ 4. 融合变换 (Fusion Transform) │
│ └── TileAndFuse, FuseIntoContainingOp │
├─────────────────────────────────────────────────────────┤
│ 5. 代码生成 (Code Generation) │
│ └── LLVM IR / SPIR-V / CUDA │
└─────────────────────────────────────────────────────────┘MLIR 示例:Add + ReLU 融合
融合前的计算图:
// 生产者:逐元素加法
%add = tensor.empty() : tensor<128x128xf32>
%0 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
outs(%add : tensor<128x128xf32>) {
^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
%sum = arith.addf %arg0, %arg1 : f32
linalg.yield %sum : f32
}
// 消费者:ReLU 激活
%relu = tensor.empty() : tensor<128x128xf32>
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} ins(%0 : tensor<128x128xf32>)
outs(%relu : tensor<128x128xf32>) {
^bb0(%arg0: f32, %arg1: f32):
%zero = arith.constant 0.0 : f32
%cmp = arith.cmpf ogt, %arg0, %zero : f32
%result = arith.select %cmp, %arg0, %zero : f32
linalg.yield %result : f32
}融合后的单算子:
// Producer-Consumer Fusion: Add + ReLU
%output = tensor.empty() : tensor<128x128xf32>
%2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, // A
affine_map<(d0, d1) -> (d0, d1)>, // B
affine_map<(d0, d1) -> (d0, d1)>], // Output
iterator_types = ["parallel", "parallel"]
} ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
outs(%output : tensor<128x128xf32>) {
^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
// 原始 Add 操作
%sum = arith.addf %arg0, %arg1 : f32
// 原始 ReLU 操作(物化 %sum 无需物化)
%zero = arith.constant 0.0 : f32
%cmp = arith.cmpf ogt, %sum, %zero : f32
%result = arith.select %cmp, %sum, %zero : f32
linalg.yield %result : f32
}Pass Pipeline调用:
# 使用 MLIR opt 工具执行融合
mlir-opt input.mlir \
--linalg-fuse-elementwise-ops \
--convert-linalg-to-loops \
--convert-scf-to-cf \
--convert-cf-to-llvm1.1.2 Element-wise Chain Fusion (逐元素操作链融合)
背景
这是垂直融合的一个特化场景,针对 逐元素操作 (Element-wise Operation) 的连续链式结构。这类操作的迭代空间完全一致,且无归约依赖。
EFFNet[13]和Attentional Feature Fusion[5]等研究表明,逐元素操作链在Transformer和CNN中广泛存在。例如 Transformer 中的 FFN:Linear → Gelu → Dropout → Linear。
技术原理
- 迭代空间对齐 (Identical Iteration Space):所有操作共享相同的循环结构。
- 索引映射简化 (Simple Indexing Maps):通常简化为恒等映射 (identity) 或广播 (broadcast)。
- Kernel Launch 消除:N 个 Kernel 合并为 1 个。
TVM[1]和Tensor Comprehensions[4]都实现了高效的逐元素链融合优化。
MLIR 实现关键 Pass
// mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
struct ElementwiseOpFusionPattern : public OpRewritePattern<linalg::GenericOp> {
LogicalResult matchAndRewrite(linalg::GenericOp op,
PatternRewriter &rewriter) const override {
// 1. 检查是否为 element-wise 操作
if (!isElementwise(op)) return failure();
// 2. 查找 producer(当前操作的所有输入)
for (Value operand : op.getInputs()) {
if (auto producer = operand.getDefiningOp<linalg::GenericOp>()) {
// 3. 验证融合合法性
if (isValidFusionCandidate(producer, op)) {
// 4. 执行融合
return fuseOps(producer, op, rewriter);
}
}
}
return failure();
}
private:
bool isElementwise(linalg::GenericOp op) const {
// 检查所有 iterator_types 是否为 "parallel"
// 检查 indexing_maps 是否为 permutation(无归约维度)
}
bool isValidFusionCandidate(linalg::GenericOp producer,
linalg::GenericOp consumer) const {
// 验证迭代空间兼容性
// 验证没有 side-effect
// 验证形状一致性
}
};MLIR 示例:融合 LayerNorm 的逐元素操作部分
LayerNorm 的后半部分 (Normalize, Scale, Shift) 是典型的逐元素链。
// 原始分解实现(多个独立的 linalg.generic)
func.func @layer_norm_unfused(%input: tensor<BxMxNxf32>,
%weight: tensor<Nxf32>,
%bias: tensor<Nxf32>) -> tensor<BxMxNxf32> {
// Step 1: 计算 mean (reduction)
%mean = linalg.generic ... { arith.addf, ... }
// Step 2: 计算 variance (element-wise + reduction)
%variance = linalg.generic ... { arith.subf, arith.mulf, ... }
// Step 3: normalize (element-wise)
%normalized = linalg.generic ... {
^bb0(%x, %m, %v, %w, %b):
%eps = arith.constant 0.00001 : f32
%std = arith.sqrt(arith.addf %v, %eps) : f32
%norm = arith.divf arith.subf(%x, %m), %std
%scaled = arith.mulf %norm, %w
%result = arith.addf %scaled, %b
linalg.yield %result
}
return %normalized : tensor<BxMxNxf32>
}
// 融合后
func.func @layer_norm_fused(%input: tensor<BxMxNxf32>,
%weight: tensor<Nxf32>,
%bias: tensor<Nxf32>) -> tensor<BxMxNxf32> {
// 假设 Mean 和 Variance 已计算完成
// 融合:normalize + scale + shift 三个 element-wise 操作
%output = linalg.generic {
indexing_maps = [
affine_map<(b, m, n) -> (b, m, n)>, // input
affine_map<(b, m, n) -> (b, m)>, // mean (broadcast)
affine_map<(b, m, n) -> (b, m)>, // variance (broadcast)
affine_map<(b, m, n) -> (n)>, // weight (broadcast)
affine_map<(b, m, n) -> (n)>, // bias (broadcast)
affine_map<(b, m, n) -> (b, m, n)> // output
],
iterator_types = ["parallel", "parallel", "parallel"]
} ins(%input, %mean, %variance, %weight, %bias
: tensor<BxMxNxf32>, tensor<BxMxf32>, tensor<BxMxf32>,
tensor<Nxf32>, tensor<Nxf32>)
outs(%init : tensor<BxMxNxf32>) {
^bb0(%x: f32, %m: f32, %v: f32, %w: f32, %b: f32):
// 融合的计算:一次遍历完成所有 element-wise 操作
%eps = arith.constant 0.00001 : f32
%std = arith.sqrt(arith.addf %v, %eps) : f32
%norm = arith.divf arith.subf(%x, %m), %std : f32
%scaled = arith.mulf %norm, %w : f32
%result = arith.addf %scaled, %b : f32
linalg.yield %result : f32
}
return %output : tensor<BxMxNxf32>
}1.1.3 Reduce-to-Elementwise Fusion (归约-逐元素融合)
背景
标准的垂直融合通常在 归约 (Reduction) 操作处断开。例如 Softmax 或 LayerNorm,传统实现需要多次遍历内存 (Pass 1: 求和/均值 Pass 2: 归一化)。
归约-逐元素融合 旨在打破归约操作的同步屏障,通过数学算法的重写 (如 Welford 或 Online Softmax),将多次内存扫描合并为一次扫描 (One-pass)[7]。
技术原理
Welford 算法 (For LayerNorm):
- 传统:先遍历一次求 Mean,再遍历一次求 Variance,最后遍历求 Output。
- 融合:在单次循环中同时维护 Mean 和 Variance 的迭代更新公式,一次遍历即可得到最终统计量并应用 Normalization。
Online Softmax / Safe Softmax:
- 传统: -> -> (3 Pass)。
- 融合:利用数学性质 ,在一次遍历中动态更新全局 Max 和 Sum,无需预先扫描最大值[7]。
这种融合技术在FlashAttention[7]中得到关键应用,是其实现高效Attention计算的核心技术之一。
应用场景
| 算子 | 涉及算法 | 收益 |
|---|---|---|
| LayerNorm / RMSNorm | Welford Algorithm | 减少 1-2 次全局内存读写 |
| Softmax | Online Softmax | 减少 2 次全局内存读写 |
| Cross Entropy Loss | Log-Sum-Exp Trick | 提升数值稳定性与性能 |
这种融合无法通过简单的 linalg.fuse 实现,通常需要使用 scf.for 携带状态 (iter_args) 来表达复杂的更新逻辑。
MLIR 示例
// Online Softmax 逻辑结构示例
// scf.for 不仅计算,还通过 iter_args 携带动态更新的 max 和 sum
%final_max, %final_sum = scf.for %i = 0 to %N
iter_args(%curr_max = %neg_inf, %curr_sum = %c0) -> (f32, f32) {
%val = load %input[%i]
// 1. 更新 Max
%new_max = arith.maxf %curr_max, %val
// 2. 计算修正因子:exp(old_max - new_max)
%correction = arith.exp (%curr_max - %new_max)
// 3. 修正旧的 Sum 并加上新的项
%term = arith.exp (%val - %new_max)
%new_sum = arith.addf (arith.mulf %curr_sum, %correction), %term
scf.yield %new_max, %new_sum
}1.2 水平融合 (Horizontal Fusion)
水平融合针对多个算子共享同一输入的场景,也称为Sibling Fusion或Multi-output Fusion[14,15]。Data Movement is All You Need[7]等研究表明,这类融合对于Transformer等模型优化尤为关键。
1.2.1 Multi-output Fusion (多输出融合)
背景
Multi-output Fusion (也称 Sibling Fusion) 针对多个算子共享同一输入的场景。如果每个消费者独立执行,共享输入会被多次加载。水平融合将这些计算合并,实现"一次读取,多次计算"。
在Transformer模型中,QKV Projection[14,15]是最典型的应用场景,将三个独立的矩阵乘法融合为一个大矩阵乘法,可显著减少内存访问[7]。
应用场景
| 场景 | 描述 | 性能收益来源 |
|---|---|---|
| Attention QKV | 同时计算 Q、K、V | 内存布局优化 ([3, H, ...] vs [H, 3, ...]) |
| Gate Projection | GLU 变体中的 Gate 和 Value 分支 | 合并矩阵乘法 |
| Multi-branch | Inception 模块 / 多 Loss 计算 | 共享卷积输入 |
技术原理
Multi-output Fusion 的核心决策因素:
输入共享度 (Input Sharing Degree)
Sharing_Score = |Shared_Inputs| / |Total_Inputs| 融合收益 ∝ Sharing_Score × Input_Size寄存器压力 (Register Pressure)
融合前:每个 op 使用 R 寄存器 融合后:同时存储所有物化需要 R × N 寄存器 寄存器溢出会抵消融合收益并行度权衡 (Parallelism Trade-off)
融合前:N 个 kernel 并行执行(GPU 上可并发) 融合后:1 个 kernel,但每个 thread 计算所有输出
MLIR 实现方案
| 方案 | Dialect | Pass | 适用场景 |
|---|---|---|---|
| Linalg 融合 | linalg | -linalg-fuse-elementwise-ops | 结构化算子 |
| Affine 循环融合 | affine | -affine-loop-fusion | 低层循环优化 |
| IREE Flow 融合 | flow | iree-codegen | 端到端编译 |
Linalg 融合机制
// mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
// 融合策略:识别共享输入的多个 generic op
struct MultiOutputFusionStrategy {
// 1. 构建输入使用图(Input Use Graph)
struct UseGraph {
DenseMap<Value, SmallVector<Operation*>> inputToUsers;
};
// 2. 识别可融合的兄弟操作组
SmallVector<SmallVector<Operation*>> identifyFusionGroups(Operation *root) {
// 找出共享相同输入的操作集合
// 验证迭代空间兼容性
// 评估寄存器压力
}
// 3. 生成融合后的 linalg.generic
linalg::GenericOp buildFusedOp(ArrayRef<Operation*> ops) {
// 合并 indexing_maps
// 合并 iterator_types
// 合并 region(计算逻辑)
}
};MLIR 示例:Attention QKV 融合
融合策略:将三个独立的矩阵乘法 合并为一个大的矩阵乘法 。
// 融合前
func.func @qkv_projection_unfused(
%X: tensor<Seq x Hiddenxf32>, // [Seq, Hidden]
%W_q: tensor<Hidden x (Heads x Head_Dim)xf32>,
%W_k: tensor<Hidden x (Heads x Head_Dim)xf32>,
%W_v: tensor<Hidden x (Heads x Head_Dim)xf32>)
-> (tensor<Seq x (Heads x Head_Dim)xf32>,
tensor<Seq x (Heads x Head_Dim)xf32>,
tensor<Seq x (Heads x Head_Dim)xf32>) {
// Q = X @ W_q
%Q = linalg.matmul ins(%X, %W_q : ...) outs(...)
// K = X @ W_k
%K = linalg.matmul ins(%X, %W_k : ...) outs(...)
// V = X @ W_v
%V = linalg.matmul ins(%X, %W_v : ...) outs(...)
return %Q, %K, %V
}
// 融合后(Concat-then-Matmul 策略)
func.func @qkv_projection_fused(
%X: tensor<Seq x Hiddenxf32>,
%W_qkv: tensor<Hidden x (3 x Heads x Head_Dim)xf32>) // 权重拼接
-> tensor<Seq x 3 x Heads x Head_Dimxf32> {
// 单次矩阵乘法:[Seq, Hidden] @ [Hidden, 3*Heads*Head_Dim]
// 输出形状:[Seq, 3, Heads, Head_Dim] = [Seq, Q/K/V, Heads, Head_Dim]
%QKV = linalg.generic {
indexing_maps = [
affine_map<(s, h) -> (s, h)>, // X
affine_map<(s, h) -> (h, 3, n, d)>, // W_qkv
affine_map<(s, h) -> (s, 3, n, d)> // QKV output
],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
} ins(%X, %W_qkv : ...)
outs(%init : tensor<Seq x 3 x Heads x Head_Dimxf32>) {
^bb0(%x: f32, %w: f32):
// GEMM 计算
%prod = arith.mulf %x, %w : f32
linalg.yield %prod : f32
}
// 后续可通过 tensor.extract_slice 分离 Q、K、V
return %QKV : tensor<Seq x 3 x Heads x Head_Dimxf32>
}IREE Flow 方言实现
注:扩展阅读 IREE的Flow方言如何高效实现Attention的QKV计算?
1.2.2 Batch/SIMD Fusion
背景
将批次 (Batch) 维度作为并行维度,利用 SIMT (GPU) 或 SIMD (CPU) 指令并行处理多个独立样本[1]。这种融合策略在批量推理场景中尤为重要。
应用场景
- 批量推理:将多个请求合并为一个 Batch 执行,提升 GPU 利用率[1]。
- Ascend NPU:利用多核并行处理 Batch 维度。
- CPU 向量化:将 Batch 映射到 AVX-512 向量通道。
TVM[1]和ROLLER[2]都实现了高效的Batch/SIMD融合优化策略。
1.3 模式融合 (Pattern Fusion)
模式融合针对深度学习中频繁出现的固定子图模式,通过模式匹配直接替换为高度优化的实现[1,16,17]。SpaceFusion[16]、Neptune[17]、PyPM[18]等最新研究提出了系统化的模式匹配与融合框架。
1.3.1 Special Pattern Fusion (特定模式融合)
背景
针对深度学习中频繁出现的固定子图模式,直接替换为高度优化的实现。TVM[1]、Glow[5]等编译器都内置了丰富的模式匹配规则。
常见模式
| 模式名称 | 操作序列 | 优化机会 |
|---|---|---|
| Conv-BN-ReLU | Conv2d → BatchNorm → ReLU | 编译期将 BN 参数折叠进 Conv 权重 |
| MatMul-Bias-Act | MatMul → BiasAdd → Activation | Epilogue 融合 (详见 6.1.2) |
| Softmax | Max → Subtract → Exp → Sum → Div | Online Softmax (详见 1.1.3) |
| LayerNorm | Mean → Variance → Normalize | Welford One-pass 算法 |
MLIR 示例
使用 PatternRewriter 进行图匹配与重写:
// 使用 PatternRewriter 进行模式匹配融合
struct ConvBNReLUPattern : public OpRewritePattern<tensor::PackOp> {
LogicalResult matchAndRewrite(tensor::PackOp packOp,
PatternRewriter &rewriter) const override {
// 1. 匹配模式:Conv → BatchNorm → ReLU
auto conv = packOp.getInput().getDefiningOp<linalg::Conv2DNhwcHwcfOp>();
auto bn = /* 获取 BatchNorm op */;
auto relu = /* 获取 ReLU op */;
if (!conv || !bn || !relu) return failure();
// 2. 转换 BatchNorm 参数到卷积权重
auto fusedWeights = fuseBatchNormIntoConv(conv, bn);
// 3. 生成融合后的 Conv + ReLU
auto fusedOp = rewriter.create<linalg::Conv2DNhwcHwcfOp>(
conv.getLoc(), conv.getInput(), fusedWeights, ...);
// 4. 添加 ReLU 到 region
addReLUToRegion(fusedOp);
return success();
}
};1.3.2 Transformer Block Fusion (Transformer 块级融合)
背景
在 LLM 时代,传统的算子级融合 (如 MatMul+Bias+Act) 已不足以应对长序列带来的性能挑战。Block Fusion 旨在打破算子边界,将 Transformer 层中的关键子图 (如 Attention 或 FFN) 作为一个整体进行优化,以最大化 SRAM 利用率[7,14]。
特别是 Multi-Head Attention (MHA),其标准实现 (MatMul (Q,K) -> Softmax -> MatMul (S,V)) 存在两大核心痛点:
- 显存爆炸:中间生成的 Attention Score 矩阵形状为 ,随序列长度呈平方级增长。
- Memory Wall (内存墙):该巨大矩阵的 HBM 读写延迟远超计算耗时,导致严重的 IO-bound。
因此,现代 Block Fusion (如 FlashAttention) 本质上是一种 IO 感知的算法级融合 (IO-aware Algorithmic Fusion)[7]。它利用 Tiling (分块) 和 Recomputation (重计算) 技术,将所有计算限制在片上 SRAM 中进行,完全消除了 中间矩阵对全局内存的访问。
典型模式
FlashAttention (v1/v2/v3)[7,8,9]:不仅融合了 Softmax,还进一步将 Dropout、Mask Generation 甚至 RoPE (Rotary Embedding) 融合到 Attention 的 Forward/Backward 循环中。FlashAttention系列论文(NeurIPS 2022, ICLR 2023)是该领域的里程碑工作,引用量超过4000次。
SwiGLU Fusion (FFN):LLM (如 LLaMA) 常用的 FFN 包含三个矩阵乘。融合策略是将两个并行的 MatMul (Gate proj 和 Up proj) 合并计算,并在寄存器中直接完成
SiLU和Element-wise Mul,避免中间宽矩阵写回显存[7]。
核心原理
通过 Tiling (分块) 和 Recomputation (重计算),将所有计算限制在片上 SRAM 中进行,完全消除 的物化[7]。FlashAttention[7]证明了这种IO-aware融合策略可将Attention计算速度提升2-4倍,同时显存占用减少一个数量级。
// 传统实现 (Standard Attention):
1. S = Q @ K^T (Write S to HBM, size N^2)
2. P = Softmax(S) (Read S, Write P to HBM, size N^2)
3. O = P @ V (Read P, Write O to HBM)
// 融合后 (FlashAttention / Memory-Efficient Attention):
Block-wise loop:
Load block of Q, K, V into SRAM
Compute block of S = Q_i @ K_j^T (on SRAM)
Compute block of P = Softmax(S) (on SRAM, using Online Softmax)
Compute block of O += P @ V_j (on SRAM, accumulate to Output)
(中间矩阵 S 和 P 从未离开过片上 SRAM)MLIR 示例:Tiled Attention Logic
在 MLIR 中,这表现为带有 iter_args 的双层嵌套循环:
func.func @flash_attention_tiled(%Q: tensor<...>, %K: tensor<...>, %V: tensor<...>) {
// 外层循环:遍历 Query 分块
%res = scf.for %i = 0 to %SeqLen step %Br
iter_args(%O_acc = %init_O, %m_acc = %init_m, %l_acc = %init_l) {
// 加载 Q 到 SRAM (逻辑上)
%Qi = tensor.extract_slice %Q[...]
// 内层循环:遍历 Key/Value 分块
%O_row, ... = scf.for %j = 0 to %SeqLen step %Bc
iter_args(%O_curr = %O_acc, ...) {
// 1. Compute Scores (Q @ K.T)
%S_ij = linalg.matmul ins(%Qi, %Kj) ...
// 2. Online Softmax Logic (更新 max 和 sum)
%m_new = arith.maxf %m_curr, %local_max
%P_ij = ... // exp(S_ij - m_new)
// 3. Accumulate Output (P @ V)
%O_new = linalg.matmul ins(%P_ij, %Vj) ...
scf.yield %O_new, %m_new, ...
}
scf.yield %O_row, ...
}
return %res
}MLIR 示例:SwiGLU Fusion
// SwiGLU 融合示意:Gate_Proj 和 Up_Proj 的后处理融合
// 假设 %gate_out 和 %up_out 是两个 MatMul 的输出 (或者通过 Grouped MatMul 产生)
// 这里的融合消除了两个巨大的中间 Tensor 的 HBM 写回
func.func @swiglu_epilogue_fused(%gate_buf: tensor<?x?xf32>, %up_buf: tensor<?x?xf32>)
-> tensor<?x?xf32> {
%res = linalg.generic {
indexing_maps = [
affine_map<(d0, d1) -> (d0, d1)>, // Gate input
affine_map<(d0, d1) -> (d0, d1)>, // Up input
affine_map<(d0, d1) -> (d0, d1)> // Output
],
iterator_types = ["parallel", "parallel"]
} ins(%gate_buf, %up_buf : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%init : tensor<?x?xf32>) {
^bb0(%g: f32, %u: f32, %out: f32):
// 1. Swish/SiLU 激活: x * sigmoid(x)
%sigmoid = math.sigmoid %g : f32
%act = arith.mulf %g, %sigmoid : f32
// 2. Gated Multiplication: act * up
%res = arith.mulf %act, %u : f32
linalg.yield %res : f32
}
return %res
}注:扩展阅读 MLIR如何高效实现Attention?
FlashAttention[7,8,9]系列工作是该领域的里程碑,推动了Attention计算的性能边界。Data Movement is All You Need[15]等研究进一步验证了这类融合优化的有效性。
1.3.3 Optimizer Fusion (优化器融合)
背景
在深度学习训练中,反向传播结束后需要执行参数更新 (Optimizer Step)。以主流的 AdamW 优化器为例,它包含一系列密集的 Element-wise 操作:计算一阶矩 (Momentum)、二阶矩 (Variance)、权重衰减 (Weight Decay) 以及最终的参数更新。
如果这些操作作为独立的 Kernel 执行,将导致严重的内存带宽浪费:参数和状态量需要被反复读写 HBM (High Bandwidth Memory),而计算量却很小 (Memory-bound)[1,5]。TVM[1]、Glow[5]等编译器都实现了优化器融合来减少训练时的内存访问。
技术原理
Kernel Fusion (Multi-input/Multi-output): 编译器将优化器的整个计算逻辑融合为一个单一的 Kernel。该 Kernel 一次性从内存读取 Parameter, Gradient, Momentum, Variance,在寄存器中完成所有数学运算,然后一次性写回更新后的值。
MLIR 实现:Fused AdamW
// 融合后的 AdamW Update Step
// 输入:Params, Grads, Exp_Avg (m), Exp_Avg_Sq (v)
// 输出:更新后的 Params, m, v
func.func @fused_adamw(%param: tensor<?xf32>, %grad: tensor<?xf32>,
%m: tensor<?xf32>, %v: tensor<?xf32>,
%lr: f32, %beta1: f32, %beta2: f32, %eps: f32, %decay: f32)
-> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
// 使用 linalg.generic 定义多输入多输出的融合算子
%res:3 = linalg.generic {
indexing_maps = [
affine_map<(d0) -> (d0)>, // param
affine_map<(d0) -> (d0)>, // grad
affine_map<(d0) -> (d0)>, // m
affine_map<(d0) -> (d0)>, // v
affine_map<(d0) -> (d0)>, // out_param
affine_map<(d0) -> (d0)>, // out_m
affine_map<(d0) -> (d0)> // out_v
],
iterator_types = ["parallel"]
} ins(%param, %grad, %m, %v : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
outs(%param, %m, %v : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
^bb0(%p: f32, %g: f32, %m_in: f32, %v_in: f32, ...):
// 1. Weight Decay: g = g + p * decay
%p_decay = arith.mulf %p, %decay : f32
%g_prime = arith.addf %g, %p_decay : f32
// 2. Update Momentum (m): m = beta1 * m + (1-beta1) * g
%m_beta = arith.mulf %m_in, %beta1 : f32
%one_minus_beta1 = arith.constant ... : f32
%g_term = arith.mulf %g_prime, %one_minus_beta1 : f32
%m_out = arith.addf %m_beta, %g_term : f32
// 3. Update Variance (v): v = beta2 * v + (1-beta2) * g^2
// ... (省略部分计算细节) ...
%v_out = ...
// 4. Update Parameter
// p = p - lr * m / (sqrt(v) + eps)
%p_new = ...
// 5. 同时返回三个更新后的值
linalg.yield %p_new, %m_out, %v_out : f32, f32, f32
}
return %res:3
}参考文献 (References)
综述与基础工作
[1] Tianqi Chen, Thierry Moreau, Ziheng Jiang, et al. "TVM: An Automated End-to-End Optimizing Compiler for Deep Learning." USENIX OSDI, 2018. [Link]
[2] H. Zhu, et al. "ROLLER: Fast and Efficient Tensor Compilation for Deep Learning." USENIX OSDI, 2022. [Link]
[3] Y. Shi, et al. "Scheduling Deep Learning Memory Access via Tile-graph." USENIX OSDI, 2023. [Link]
[4] Nicolas Vasilache, Oleksandr Zinenko, et al. "Tensor Comprehensions: Framework-Agnostic High-Performance Machine Learning Kernels." arXiv, 2018. [Link]
[5] N. Rotem, et al. "Glow: Graph Lowering Compiler Techniques for Neural Networks." arXiv, 2018. [Link]
[6] Wei Niu, Jiexiong Guan, Yanzhi Wang, et al. "DNNFusion: Accelerating Deep Neural Networks Execution with Advanced Operator Fusion." ACM PLDI, 2021. [Link]
[7] Tri Dao, Daniel Y. Fu, et al. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS, 2022. [Link]
垂直融合 (Vertical Fusion)
[8] Tri Dao. "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." ICLR, 2023. [Link]
[9] Tri Dao. "FlashAttention-3: Fast and Accurate Attention with Asynchrony and TMA." arXiv, 2024. [Link]
[10] Daniel Snider. "Operator Fusion in XLA: Analysis and Evaluation." arXiv, 2023. [Link]
[11] J. Zhao, et al. "Apollo: Automatic Partition-based Operator Fusion through Just-in-Time Compilation." MLSys, 2022. [Link]
[12] (IEEE, 2021). "A Model for Layer Fusion in Convolutional Neural Networks." IEEE. [Link]
水平融合 (Horizontal Fusion)
[13] (Signal Processing, 2023). "EFFNet: Element-wise Feature Fusion Network for Defect Detection." Signal Processing. [Link]
[14] L. Zhang, et al. "End-to-End Transformer Acceleration via Graph Fusion." ACM, 2025. [New]
[15] A. Ivanov, et al. "Data Movement is All You Need: A Case Study on Optimizing Transformers." MLSys, 2021. [Link]
模式融合 (Pattern Fusion)
[16] Y. Zhao, et al. "Neptune: Advanced ML Operator Fusion for Locality and Efficiency." arXiv, 2025. [New]
[17] (ACM, 2025). "PyPM: Pattern Matching in AI Compilers and Its Formalization." ACM, March 2025. [New]
[18] (ACM, 2025). "SpaceFusion: Advanced Deep Learning Operator Fusion." ACM, 2025. [New]
其他相关工作
[19] H. Xu, et al. "Optimized Spatial Architecture Mapping Flow for Transformers." arXiv, 2024. [Link]
[20] (arXiv, 2025). "Blockbuster: Block-level AI Operator Fusion." arXiv, 2025. [New]
[21] Y. Sun, et al. "The Deep Learning Compiler: A Comprehensive Survey." arXiv, 2020. [Link]