Skip to content

1. 依赖拓扑 (Dependency Topology)

作为融合优化的起点,本章立足于计算图 (Computational Graph) 的宏观视角。在这一层级,编译器不关注算子内部的具体实现,而是聚焦于算子节点之间的数据依赖 (Data Dependency) 与连接关系。

深度学习编译器(如TVM[1]、Glow[5]、XLA[10])的核心目标是通过重组图结构,识别并消除冗余的内存读写操作。通过分析生产者-消费者的垂直链条多分支的水平结构以及特定的子图模式,编译器决定将哪些独立的算子节点"坍缩"为一个内核,从而在算法层面最小化数据搬运的开销。

这种图级优化对于现代深度学习模型至关重要,因为:

  • 内存墙问题:数据移动能耗远高于计算能耗(约100倍差异)[1]
  • Kernel Launch开销:频繁的CPU-GPU同步会显著降低性能[5]
  • Cache局部性:融合可提高数据重用率,减少全局内存访问[2]

AI编译器领域的奠基性工作(如TVM[1]、Tensor Comprehensions[4])已证明,图级融合优化可带来2-10倍的性能提升。

1.1 垂直融合 (Vertical Fusion)

垂直融合沿着数据流的生产者-消费者链条进行优化,是最经典的融合模式[1,3]。Apollo[11]等研究表明,基于生产者-消费者关系的融合是深度学习编译器的核心优化技术之一。

1.1.1 Producer-Consumer Fusion (生产者-消费者融合)

背景

在深度学习模型中,算子之间往往存在严格的数据依赖关系:后一个算子需要前一个算子的输出作为输入。这种依赖链形成了 "生产者-消费者" (Producer-Consumer) 关系[12]。

传统执行模式下,每个算子独立编译成 kernel,物化 (Materialization,中间结果) 需要写入全局内存 (Global Memory),然后下一个 kernel 再重新读取。这种"物化"过程带来了显著的开销:

  • 内存带宽压力:每个中间张量都需要写回主存再读出
  • Cache利用率低:物化可能驱逐有用的Cache内容
  • Kernel启动开销:每个算子单独启动一个Kernel

TVM[1]的研究表明,消除中间物化可减少50-80%的内存访问量。

核心思想

消除中间物化 (Intermediate Materialization Elimination):将生产者的计算内联到消费者中,使中间结果仅存活于寄存器/L1 Cache,避免写回全局内存[1,3]。

cpp
// 融合前
C = Add(A, B)        // 写入全局内存
D = ReLU(C)          // 从全局内存读取C 

// 融合后
C = ReLU(Add(A, B))  // 物化仅存活于寄存器

应用场景

场景特征示例
逐元素操作链多个 Element-wise 操作串联Add → ReLU → Mul → Sigmoid
卷积后激活卷积 + 激活函数Conv2d → BiasAdd → ReLU
归约后处理归约操作 + 后处理ReduceMax → Subtract (LogSoftmax)
矩阵乘法融合GEMM + bias/activationMatMul → BiasAdd → Gelu

技术原理

Producer-Consumer Fusion 的决策依赖于以下几个关键因素[1,3,6,11]:

  1. 依赖链分析 (Dependency Chain Analysis)

    • 构建计算图的 依赖图 (Dependency Graph)
    • 识别 合法融合候选 (Legal Fusion Candidates)
    • 检测 循环依赖 (Cyclic Dependency)
  2. 内存代价评估 (Memory Cost Estimation)

    DNNFusion[6]提出了形式化的代价模型来评估融合收益。

  3. 计算兼容性检查 (Computational Compatibility)

    • 迭代空间一致性 (Iteration Space Alignment)
    • 仿射变换兼容性 (Affine Transform Compatibility)
    • 归约语义保持 (Reduction Semantic Preservation)

MLIR 实现方案

MLIR 通过多级 IR (Linalg、Affine、SCF) 和模式重写框架实现 Producer-Consumer Fusion[1,4]。MLIR的融合策略借鉴了TVM[1]和Tensor Comprehensions[4]的设计思想。

核心 Dialect 与 Pass:

Dialect/Pass作用关键数据结构
linalg结构化算子定义linalg.generic, linalg.matmul
-linalg-fuse-elementwise-opsElement-wise 融合ElementwiseOpFusionPass
-affine-loop-fusion循环级融合AffineLoopFusion
scf.for结构化控制流scf::ForOp

实现架构:

┌─────────────────────────────────────────────────────────┐
│                  MLIR Fusion Pipeline                    │
├─────────────────────────────────────────────────────────┤
│  1. 依赖分析 (Dependency Analysis)                       │
│     └── DominanceInfo, PostDominanceInfo                │
├─────────────────────────────────────────────────────────┤
│  2. 融合候选识别 (Fusion Candidate Identification)       │
│     └── linalg.generic 的 producer-consumer 匹配        │
├─────────────────────────────────────────────────────────┤
│  3. 合法性检查 (Legality Check)                          │
│     ├── 迭代空间兼容性 (Iteration Space Compatibility)   │
│     ├── Side-effect 检查                                 │
│     └── 形状推断 (Shape Inference)                       │
├─────────────────────────────────────────────────────────┤
│  4. 融合变换 (Fusion Transform)                          │
│     └── TileAndFuse, FuseIntoContainingOp              │
├─────────────────────────────────────────────────────────┤
│  5. 代码生成 (Code Generation)                          │
│     └── LLVM IR / SPIR-V / CUDA                         │
└─────────────────────────────────────────────────────────┘

MLIR 示例:Add + ReLU 融合

融合前的计算图:

cpp
// 生产者:逐元素加法
%add = tensor.empty() : tensor<128x128xf32>
%0 = linalg.generic {
  indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                   affine_map<(d0, d1) -> (d0, d1)>,
                   affine_map<(d0, d1) -> (d0, d1)>],
  iterator_types = ["parallel", "parallel"]
} ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
    outs(%add : tensor<128x128xf32>) {
  ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
    %sum = arith.addf %arg0, %arg1 : f32
    linalg.yield %sum : f32
}

// 消费者:ReLU 激活
%relu = tensor.empty() : tensor<128x128xf32>
%1 = linalg.generic {
  indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                   affine_map<(d0, d1) -> (d0, d1)>],
  iterator_types = ["parallel", "parallel"]
} ins(%0 : tensor<128x128xf32>)
    outs(%relu : tensor<128x128xf32>) {
  ^bb0(%arg0: f32, %arg1: f32):
    %zero = arith.constant 0.0 : f32
    %cmp = arith.cmpf ogt, %arg0, %zero : f32
    %result = arith.select %cmp, %arg0, %zero : f32
    linalg.yield %result : f32
}

融合后的单算子:

cpp
// Producer-Consumer Fusion: Add + ReLU
%output = tensor.empty() : tensor<128x128xf32>
%2 = linalg.generic {
  indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,   // A
                   affine_map<(d0, d1) -> (d0, d1)>,   // B
                   affine_map<(d0, d1) -> (d0, d1)>],  // Output
  iterator_types = ["parallel", "parallel"]
} ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
    outs(%output : tensor<128x128xf32>) {
  ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
    // 原始 Add 操作
    %sum = arith.addf %arg0, %arg1 : f32
    // 原始 ReLU 操作(物化 %sum 无需物化)
    %zero = arith.constant 0.0 : f32
    %cmp = arith.cmpf ogt, %sum, %zero : f32
    %result = arith.select %cmp, %sum, %zero : f32
    linalg.yield %result : f32
}

Pass Pipeline调用:

bash
# 使用 MLIR opt 工具执行融合
mlir-opt input.mlir \
  --linalg-fuse-elementwise-ops \
  --convert-linalg-to-loops \
  --convert-scf-to-cf \
  --convert-cf-to-llvm

1.1.2 Element-wise Chain Fusion (逐元素操作链融合)

背景

这是垂直融合的一个特化场景,针对 逐元素操作 (Element-wise Operation) 的连续链式结构。这类操作的迭代空间完全一致,且无归约依赖。

EFFNet[13]和Attentional Feature Fusion[5]等研究表明,逐元素操作链在Transformer和CNN中广泛存在。例如 Transformer 中的 FFN:Linear → Gelu → Dropout → Linear

技术原理

  • 迭代空间对齐 (Identical Iteration Space):所有操作共享相同的循环结构。
  • 索引映射简化 (Simple Indexing Maps):通常简化为恒等映射 (identity) 或广播 (broadcast)。
  • Kernel Launch 消除:N 个 Kernel 合并为 1 个。

TVM[1]和Tensor Comprehensions[4]都实现了高效的逐元素链融合优化。

MLIR 实现关键 Pass

cpp
// mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
struct ElementwiseOpFusionPattern : public OpRewritePattern<linalg::GenericOp> {
  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // 1. 检查是否为 element-wise 操作
    if (!isElementwise(op)) return failure();

    // 2. 查找 producer(当前操作的所有输入)
    for (Value operand : op.getInputs()) {
      if (auto producer = operand.getDefiningOp<linalg::GenericOp>()) {
        // 3. 验证融合合法性
        if (isValidFusionCandidate(producer, op)) {
          // 4. 执行融合
          return fuseOps(producer, op, rewriter);
        }
      }
    }
    return failure();
  }

private:
  bool isElementwise(linalg::GenericOp op) const {
    // 检查所有 iterator_types 是否为 "parallel"
    // 检查 indexing_maps 是否为 permutation(无归约维度)
  }

  bool isValidFusionCandidate(linalg::GenericOp producer,
                              linalg::GenericOp consumer) const {
    // 验证迭代空间兼容性
    // 验证没有 side-effect
    // 验证形状一致性
  }
};

MLIR 示例:融合 LayerNorm 的逐元素操作部分

LayerNorm 的后半部分 (Normalize, Scale, Shift) 是典型的逐元素链。

cpp
// 原始分解实现(多个独立的 linalg.generic)
func.func @layer_norm_unfused(%input: tensor<BxMxNxf32>,
                              %weight: tensor<Nxf32>,
                              %bias: tensor<Nxf32>) -> tensor<BxMxNxf32> {
  // Step 1: 计算 mean (reduction)
  %mean = linalg.generic ... { arith.addf, ... }

  // Step 2: 计算 variance (element-wise + reduction)
  %variance = linalg.generic ... { arith.subf, arith.mulf, ... }

  // Step 3: normalize (element-wise)
  %normalized = linalg.generic ... {
    ^bb0(%x, %m, %v, %w, %b):
      %eps = arith.constant 0.00001 : f32
      %std = arith.sqrt(arith.addf %v, %eps) : f32
      %norm = arith.divf arith.subf(%x, %m), %std
      %scaled = arith.mulf %norm, %w
      %result = arith.addf %scaled, %b
      linalg.yield %result
  }

  return %normalized : tensor<BxMxNxf32>
}

// 融合后
func.func @layer_norm_fused(%input: tensor<BxMxNxf32>,
                            %weight: tensor<Nxf32>,
                            %bias: tensor<Nxf32>) -> tensor<BxMxNxf32> {
  // 假设 Mean 和 Variance 已计算完成

  // 融合:normalize + scale + shift 三个 element-wise 操作
  %output = linalg.generic {
    indexing_maps = [
      affine_map<(b, m, n) -> (b, m, n)>,  // input
      affine_map<(b, m, n) -> (b, m)>,     // mean (broadcast)
      affine_map<(b, m, n) -> (b, m)>,     // variance (broadcast)
      affine_map<(b, m, n) -> (n)>,        // weight (broadcast)
      affine_map<(b, m, n) -> (n)>,        // bias (broadcast)
      affine_map<(b, m, n) -> (b, m, n)>   // output
    ],
    iterator_types = ["parallel", "parallel", "parallel"]
  } ins(%input, %mean, %variance, %weight, %bias
        : tensor<BxMxNxf32>, tensor<BxMxf32>, tensor<BxMxf32>,
           tensor<Nxf32>, tensor<Nxf32>)
      outs(%init : tensor<BxMxNxf32>) {
  ^bb0(%x: f32, %m: f32, %v: f32, %w: f32, %b: f32):
    // 融合的计算:一次遍历完成所有 element-wise 操作
    %eps = arith.constant 0.00001 : f32
    %std = arith.sqrt(arith.addf %v, %eps) : f32
    %norm = arith.divf arith.subf(%x, %m), %std : f32
    %scaled = arith.mulf %norm, %w : f32
    %result = arith.addf %scaled, %b : f32
    linalg.yield %result : f32
  }

  return %output : tensor<BxMxNxf32>
}

1.1.3 Reduce-to-Elementwise Fusion (归约-逐元素融合)

背景

标准的垂直融合通常在 归约 (Reduction) 操作处断开。例如 Softmax 或 LayerNorm,传统实现需要多次遍历内存 (Pass 1: 求和/均值 Pass 2: 归一化)。

归约-逐元素融合 旨在打破归约操作的同步屏障,通过数学算法的重写 (如 Welford 或 Online Softmax),将多次内存扫描合并为一次扫描 (One-pass)[7]。

技术原理

  1. Welford 算法 (For LayerNorm)

    • 传统:先遍历一次求 Mean,再遍历一次求 Variance,最后遍历求 Output。
    • 融合:在单次循环中同时维护 Mean 和 Variance 的迭代更新公式,一次遍历即可得到最终统计量并应用 Normalization。
  2. Online Softmax / Safe Softmax

    • 传统 -> -> (3 Pass)。
    • 融合:利用数学性质 ,在一次遍历中动态更新全局 Max 和 Sum,无需预先扫描最大值[7]。

这种融合技术在FlashAttention[7]中得到关键应用,是其实现高效Attention计算的核心技术之一。

应用场景

算子涉及算法收益
LayerNorm / RMSNormWelford Algorithm减少 1-2 次全局内存读写
SoftmaxOnline Softmax减少 2 次全局内存读写
Cross Entropy LossLog-Sum-Exp Trick提升数值稳定性与性能

这种融合无法通过简单的 linalg.fuse 实现,通常需要使用 scf.for 携带状态 (iter_args) 来表达复杂的更新逻辑。

MLIR 示例

cpp
// Online Softmax 逻辑结构示例
// scf.for 不仅计算,还通过 iter_args 携带动态更新的 max 和 sum
%final_max, %final_sum = scf.for %i = 0 to %N 
  iter_args(%curr_max = %neg_inf, %curr_sum = %c0) -> (f32, f32) {
  
  %val = load %input[%i]
  // 1. 更新 Max
  %new_max = arith.maxf %curr_max, %val
  // 2. 计算修正因子:exp(old_max - new_max)
  %correction = arith.exp (%curr_max - %new_max)
  // 3. 修正旧的 Sum 并加上新的项
  %term = arith.exp (%val - %new_max)
  %new_sum = arith.addf (arith.mulf %curr_sum, %correction), %term
  
  scf.yield %new_max, %new_sum
}

1.2 水平融合 (Horizontal Fusion)

水平融合针对多个算子共享同一输入的场景,也称为Sibling Fusion或Multi-output Fusion[14,15]。Data Movement is All You Need[7]等研究表明,这类融合对于Transformer等模型优化尤为关键。

1.2.1 Multi-output Fusion (多输出融合)

背景

Multi-output Fusion (也称 Sibling Fusion) 针对多个算子共享同一输入的场景。如果每个消费者独立执行,共享输入会被多次加载。水平融合将这些计算合并,实现"一次读取,多次计算"。

在Transformer模型中,QKV Projection[14,15]是最典型的应用场景,将三个独立的矩阵乘法融合为一个大矩阵乘法,可显著减少内存访问[7]。

应用场景

场景描述性能收益来源
Attention QKV同时计算 Q、K、V内存布局优化 ([3, H, ...] vs [H, 3, ...])
Gate ProjectionGLU 变体中的 Gate 和 Value 分支合并矩阵乘法
Multi-branchInception 模块 / 多 Loss 计算共享卷积输入

技术原理

Multi-output Fusion 的核心决策因素:

  1. 输入共享度 (Input Sharing Degree)

    Sharing_Score = |Shared_Inputs| / |Total_Inputs|
    融合收益 ∝ Sharing_Score × Input_Size
  2. 寄存器压力 (Register Pressure)

    融合前:每个 op 使用 R 寄存器
    融合后:同时存储所有物化需要 R × N 寄存器
    寄存器溢出会抵消融合收益
  3. 并行度权衡 (Parallelism Trade-off)

    融合前:N 个 kernel 并行执行(GPU 上可并发)
    融合后:1 个 kernel,但每个 thread 计算所有输出

MLIR 实现方案

方案DialectPass适用场景
Linalg 融合linalg-linalg-fuse-elementwise-ops结构化算子
Affine 循环融合affine-affine-loop-fusion低层循环优化
IREE Flow 融合flowiree-codegen端到端编译

Linalg 融合机制

cpp
// mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
// 融合策略:识别共享输入的多个 generic op

struct MultiOutputFusionStrategy {
  // 1. 构建输入使用图(Input Use Graph)
  struct UseGraph {
    DenseMap<Value, SmallVector<Operation*>> inputToUsers;
  };

  // 2. 识别可融合的兄弟操作组
  SmallVector<SmallVector<Operation*>> identifyFusionGroups(Operation *root) {
    // 找出共享相同输入的操作集合
    // 验证迭代空间兼容性
    // 评估寄存器压力
  }

  // 3. 生成融合后的 linalg.generic
  linalg::GenericOp buildFusedOp(ArrayRef<Operation*> ops) {
    // 合并 indexing_maps
    // 合并 iterator_types
    // 合并 region(计算逻辑)
  }
};

MLIR 示例:Attention QKV 融合

融合策略:将三个独立的矩阵乘法 合并为一个大的矩阵乘法

cpp
// 融合前
func.func @qkv_projection_unfused(
    %X: tensor<Seq x Hiddenxf32>,        // [Seq, Hidden]
    %W_q: tensor<Hidden x (Heads x Head_Dim)xf32>,
    %W_k: tensor<Hidden x (Heads x Head_Dim)xf32>,
    %W_v: tensor<Hidden x (Heads x Head_Dim)xf32>)
    -> (tensor<Seq x (Heads x Head_Dim)xf32>,
        tensor<Seq x (Heads x Head_Dim)xf32>,
        tensor<Seq x (Heads x Head_Dim)xf32>) {

  // Q = X @ W_q
  %Q = linalg.matmul ins(%X, %W_q : ...) outs(...)
  // K = X @ W_k
  %K = linalg.matmul ins(%X, %W_k : ...) outs(...)
  // V = X @ W_v
  %V = linalg.matmul ins(%X, %W_v : ...) outs(...)

  return %Q, %K, %V
}

// 融合后(Concat-then-Matmul 策略)
func.func @qkv_projection_fused(
    %X: tensor<Seq x Hiddenxf32>,
    %W_qkv: tensor<Hidden x (3 x Heads x Head_Dim)xf32>)  // 权重拼接
    -> tensor<Seq x 3 x Heads x Head_Dimxf32> {

  // 单次矩阵乘法:[Seq, Hidden] @ [Hidden, 3*Heads*Head_Dim]
  // 输出形状:[Seq, 3, Heads, Head_Dim] = [Seq, Q/K/V, Heads, Head_Dim]
  %QKV = linalg.generic {
    indexing_maps = [
      affine_map<(s, h) -> (s, h)>,           // X
      affine_map<(s, h) -> (h, 3, n, d)>,     // W_qkv
      affine_map<(s, h) -> (s, 3, n, d)>      // QKV output
    ],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
  } ins(%X, %W_qkv : ...)
    outs(%init : tensor<Seq x 3 x Heads x Head_Dimxf32>) {
  ^bb0(%x: f32, %w: f32):
    // GEMM 计算
    %prod = arith.mulf %x, %w : f32
    linalg.yield %prod : f32
  }

  // 后续可通过 tensor.extract_slice 分离 Q、K、V
  return %QKV : tensor<Seq x 3 x Heads x Head_Dimxf32>
}

IREE Flow 方言实现

注:扩展阅读 IREE的Flow方言如何高效实现Attention的QKV计算?


1.2.2 Batch/SIMD Fusion

背景

将批次 (Batch) 维度作为并行维度,利用 SIMT (GPU) 或 SIMD (CPU) 指令并行处理多个独立样本[1]。这种融合策略在批量推理场景中尤为重要。

应用场景

  • 批量推理:将多个请求合并为一个 Batch 执行,提升 GPU 利用率[1]。
  • Ascend NPU:利用多核并行处理 Batch 维度。
  • CPU 向量化:将 Batch 映射到 AVX-512 向量通道。

TVM[1]和ROLLER[2]都实现了高效的Batch/SIMD融合优化策略。


1.3 模式融合 (Pattern Fusion)

模式融合针对深度学习中频繁出现的固定子图模式,通过模式匹配直接替换为高度优化的实现[1,16,17]。SpaceFusion[16]、Neptune[17]、PyPM[18]等最新研究提出了系统化的模式匹配与融合框架。

1.3.1 Special Pattern Fusion (特定模式融合)

背景

针对深度学习中频繁出现的固定子图模式,直接替换为高度优化的实现。TVM[1]、Glow[5]等编译器都内置了丰富的模式匹配规则。

常见模式

模式名称操作序列优化机会
Conv-BN-ReLUConv2d → BatchNorm → ReLU编译期将 BN 参数折叠进 Conv 权重
MatMul-Bias-ActMatMul → BiasAdd → ActivationEpilogue 融合 (详见 6.1.2)
SoftmaxMax → Subtract → Exp → Sum → DivOnline Softmax (详见 1.1.3)
LayerNormMean → Variance → NormalizeWelford One-pass 算法

MLIR 示例

使用 PatternRewriter 进行图匹配与重写:

cpp
// 使用 PatternRewriter 进行模式匹配融合
struct ConvBNReLUPattern : public OpRewritePattern<tensor::PackOp> {
  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    // 1. 匹配模式:Conv → BatchNorm → ReLU
    auto conv = packOp.getInput().getDefiningOp<linalg::Conv2DNhwcHwcfOp>();
    auto bn = /* 获取 BatchNorm op */;
    auto relu = /* 获取 ReLU op */;

    if (!conv || !bn || !relu) return failure();

    // 2. 转换 BatchNorm 参数到卷积权重
    auto fusedWeights = fuseBatchNormIntoConv(conv, bn);

    // 3. 生成融合后的 Conv + ReLU
    auto fusedOp = rewriter.create<linalg::Conv2DNhwcHwcfOp>(
        conv.getLoc(), conv.getInput(), fusedWeights, ...);

    // 4. 添加 ReLU 到 region
    addReLUToRegion(fusedOp);

    return success();
  }
};

1.3.2 Transformer Block Fusion (Transformer 块级融合)

背景

在 LLM 时代,传统的算子级融合 (如 MatMul+Bias+Act) 已不足以应对长序列带来的性能挑战。Block Fusion 旨在打破算子边界,将 Transformer 层中的关键子图 (如 Attention 或 FFN) 作为一个整体进行优化,以最大化 SRAM 利用率[7,14]。

特别是 Multi-Head Attention (MHA),其标准实现 (MatMul (Q,K) -> Softmax -> MatMul (S,V)) 存在两大核心痛点:

  1. 显存爆炸:中间生成的 Attention Score 矩阵形状为 ,随序列长度呈平方级增长。
  2. Memory Wall (内存墙):该巨大矩阵的 HBM 读写延迟远超计算耗时,导致严重的 IO-bound。

因此,现代 Block Fusion (如 FlashAttention) 本质上是一种 IO 感知的算法级融合 (IO-aware Algorithmic Fusion)[7]。它利用 Tiling (分块)Recomputation (重计算) 技术,将所有计算限制在片上 SRAM 中进行,完全消除了 中间矩阵对全局内存的访问。

典型模式

  1. FlashAttention (v1/v2/v3)[7,8,9]:不仅融合了 Softmax,还进一步将 DropoutMask Generation 甚至 RoPE (Rotary Embedding) 融合到 Attention 的 Forward/Backward 循环中。FlashAttention系列论文(NeurIPS 2022, ICLR 2023)是该领域的里程碑工作,引用量超过4000次。

  2. SwiGLU Fusion (FFN):LLM (如 LLaMA) 常用的 FFN 包含三个矩阵乘。融合策略是将两个并行的 MatMul (Gate proj 和 Up proj) 合并计算,并在寄存器中直接完成 SiLUElement-wise Mul,避免中间宽矩阵写回显存[7]。

核心原理

通过 Tiling (分块)Recomputation (重计算),将所有计算限制在片上 SRAM 中进行,完全消除 的物化[7]。FlashAttention[7]证明了这种IO-aware融合策略可将Attention计算速度提升2-4倍,同时显存占用减少一个数量级。

cpp
// 传统实现 (Standard Attention):
1. S = Q @ K^T          (Write S to HBM, size N^2)
2. P = Softmax(S)       (Read S, Write P to HBM, size N^2)
3. O = P @ V            (Read P, Write O to HBM)

// 融合后 (FlashAttention / Memory-Efficient Attention):
Block-wise loop:
  Load block of Q, K, V into SRAM
  Compute block of S = Q_i @ K_j^T (on SRAM)
  Compute block of P = Softmax(S)  (on SRAM, using Online Softmax)
  Compute block of O += P @ V_j    (on SRAM, accumulate to Output)
  (中间矩阵 S 和 P 从未离开过片上 SRAM)

MLIR 示例:Tiled Attention Logic

在 MLIR 中,这表现为带有 iter_args 的双层嵌套循环:

cpp
func.func @flash_attention_tiled(%Q: tensor<...>, %K: tensor<...>, %V: tensor<...>) {
  // 外层循环:遍历 Query 分块
  %res = scf.for %i = 0 to %SeqLen step %Br 
    iter_args(%O_acc = %init_O, %m_acc = %init_m, %l_acc = %init_l) {
    
    // 加载 Q 到 SRAM (逻辑上)
    %Qi = tensor.extract_slice %Q[...] 

    // 内层循环:遍历 Key/Value 分块
    %O_row, ... = scf.for %j = 0 to %SeqLen step %Bc 
      iter_args(%O_curr = %O_acc, ...) {
      
      // 1. Compute Scores (Q @ K.T)
      %S_ij = linalg.matmul ins(%Qi, %Kj) ...
      
      // 2. Online Softmax Logic (更新 max 和 sum)
      %m_new = arith.maxf %m_curr, %local_max
      %P_ij = ... // exp(S_ij - m_new)
      
      // 3. Accumulate Output (P @ V)
      %O_new = linalg.matmul ins(%P_ij, %Vj) ...
      
      scf.yield %O_new, %m_new, ...
    }
    scf.yield %O_row, ...
  }
  return %res
}

MLIR 示例:SwiGLU Fusion

cpp
// SwiGLU 融合示意:Gate_Proj 和 Up_Proj 的后处理融合
// 假设 %gate_out 和 %up_out 是两个 MatMul 的输出 (或者通过 Grouped MatMul 产生)
// 这里的融合消除了两个巨大的中间 Tensor 的 HBM 写回

func.func @swiglu_epilogue_fused(%gate_buf: tensor<?x?xf32>, %up_buf: tensor<?x?xf32>) 
    -> tensor<?x?xf32> {
  
  %res = linalg.generic {
    indexing_maps = [
      affine_map<(d0, d1) -> (d0, d1)>, // Gate input
      affine_map<(d0, d1) -> (d0, d1)>, // Up input
      affine_map<(d0, d1) -> (d0, d1)>  // Output
    ],
    iterator_types = ["parallel", "parallel"]
  } ins(%gate_buf, %up_buf : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%init : tensor<?x?xf32>) {
    
    ^bb0(%g: f32, %u: f32, %out: f32):
      // 1. Swish/SiLU 激活: x * sigmoid(x)
      %sigmoid = math.sigmoid %g : f32
      %act = arith.mulf %g, %sigmoid : f32
      
      // 2. Gated Multiplication: act * up
      %res = arith.mulf %act, %u : f32
      
      linalg.yield %res : f32
  }
  return %res
}

注:扩展阅读 MLIR如何高效实现Attention?

FlashAttention[7,8,9]系列工作是该领域的里程碑,推动了Attention计算的性能边界。Data Movement is All You Need[15]等研究进一步验证了这类融合优化的有效性。

1.3.3 Optimizer Fusion (优化器融合)

背景

在深度学习训练中,反向传播结束后需要执行参数更新 (Optimizer Step)。以主流的 AdamW 优化器为例,它包含一系列密集的 Element-wise 操作:计算一阶矩 (Momentum)、二阶矩 (Variance)、权重衰减 (Weight Decay) 以及最终的参数更新。

如果这些操作作为独立的 Kernel 执行,将导致严重的内存带宽浪费:参数和状态量需要被反复读写 HBM (High Bandwidth Memory),而计算量却很小 (Memory-bound)[1,5]。TVM[1]、Glow[5]等编译器都实现了优化器融合来减少训练时的内存访问。

技术原理

Kernel Fusion (Multi-input/Multi-output): 编译器将优化器的整个计算逻辑融合为一个单一的 Kernel。该 Kernel 一次性从内存读取 Parameter, Gradient, Momentum, Variance,在寄存器中完成所有数学运算,然后一次性写回更新后的值。

MLIR 实现:Fused AdamW

cpp
// 融合后的 AdamW Update Step
// 输入:Params, Grads, Exp_Avg (m), Exp_Avg_Sq (v)
// 输出:更新后的 Params, m, v
func.func @fused_adamw(%param: tensor<?xf32>, %grad: tensor<?xf32>, 
                       %m: tensor<?xf32>, %v: tensor<?xf32>,
                       %lr: f32, %beta1: f32, %beta2: f32, %eps: f32, %decay: f32) 
                       -> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
  
  // 使用 linalg.generic 定义多输入多输出的融合算子
  %res:3 = linalg.generic {
    indexing_maps = [
      affine_map<(d0) -> (d0)>, // param
      affine_map<(d0) -> (d0)>, // grad
      affine_map<(d0) -> (d0)>, // m
      affine_map<(d0) -> (d0)>, // v
      affine_map<(d0) -> (d0)>, // out_param
      affine_map<(d0) -> (d0)>, // out_m
      affine_map<(d0) -> (d0)>  // out_v
    ],
    iterator_types = ["parallel"]
  } ins(%param, %grad, %m, %v : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) 
    outs(%param, %m, %v : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
    
    ^bb0(%p: f32, %g: f32, %m_in: f32, %v_in: f32, ...):
      // 1. Weight Decay: g = g + p * decay
      %p_decay = arith.mulf %p, %decay : f32
      %g_prime = arith.addf %g, %p_decay : f32

      // 2. Update Momentum (m): m = beta1 * m + (1-beta1) * g
      %m_beta = arith.mulf %m_in, %beta1 : f32
      %one_minus_beta1 = arith.constant ... : f32
      %g_term = arith.mulf %g_prime, %one_minus_beta1 : f32
      %m_out = arith.addf %m_beta, %g_term : f32

      // 3. Update Variance (v): v = beta2 * v + (1-beta2) * g^2
      // ... (省略部分计算细节) ...
      %v_out = ... 

      // 4. Update Parameter
      // p = p - lr * m / (sqrt(v) + eps)
      %p_new = ...

      // 5. 同时返回三个更新后的值
      linalg.yield %p_new, %m_out, %v_out : f32, f32, f32
  }
  return %res:3
}

参考文献 (References)

综述与基础工作

[1] Tianqi Chen, Thierry Moreau, Ziheng Jiang, et al. "TVM: An Automated End-to-End Optimizing Compiler for Deep Learning." USENIX OSDI, 2018. [Link]

[2] H. Zhu, et al. "ROLLER: Fast and Efficient Tensor Compilation for Deep Learning." USENIX OSDI, 2022. [Link]

[3] Y. Shi, et al. "Scheduling Deep Learning Memory Access via Tile-graph." USENIX OSDI, 2023. [Link]

[4] Nicolas Vasilache, Oleksandr Zinenko, et al. "Tensor Comprehensions: Framework-Agnostic High-Performance Machine Learning Kernels." arXiv, 2018. [Link]

[5] N. Rotem, et al. "Glow: Graph Lowering Compiler Techniques for Neural Networks." arXiv, 2018. [Link]

[6] Wei Niu, Jiexiong Guan, Yanzhi Wang, et al. "DNNFusion: Accelerating Deep Neural Networks Execution with Advanced Operator Fusion." ACM PLDI, 2021. [Link]

[7] Tri Dao, Daniel Y. Fu, et al. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS, 2022. [Link]

垂直融合 (Vertical Fusion)

[8] Tri Dao. "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." ICLR, 2023. [Link]

[9] Tri Dao. "FlashAttention-3: Fast and Accurate Attention with Asynchrony and TMA." arXiv, 2024. [Link]

[10] Daniel Snider. "Operator Fusion in XLA: Analysis and Evaluation." arXiv, 2023. [Link]

[11] J. Zhao, et al. "Apollo: Automatic Partition-based Operator Fusion through Just-in-Time Compilation." MLSys, 2022. [Link]

[12] (IEEE, 2021). "A Model for Layer Fusion in Convolutional Neural Networks." IEEE. [Link]

水平融合 (Horizontal Fusion)

[13] (Signal Processing, 2023). "EFFNet: Element-wise Feature Fusion Network for Defect Detection." Signal Processing. [Link]

[14] L. Zhang, et al. "End-to-End Transformer Acceleration via Graph Fusion." ACM, 2025. [New]

[15] A. Ivanov, et al. "Data Movement is All You Need: A Case Study on Optimizing Transformers." MLSys, 2021. [Link]

模式融合 (Pattern Fusion)

[16] Y. Zhao, et al. "Neptune: Advanced ML Operator Fusion for Locality and Efficiency." arXiv, 2025. [New]

[17] (ACM, 2025). "PyPM: Pattern Matching in AI Compilers and Its Formalization." ACM, March 2025. [New]

[18] (ACM, 2025). "SpaceFusion: Advanced Deep Learning Operator Fusion." ACM, 2025. [New]

其他相关工作

[19] H. Xu, et al. "Optimized Spatial Architecture Mapping Flow for Transformers." arXiv, 2024. [Link]

[20] (arXiv, 2025). "Blockbuster: Block-level AI Operator Fusion." arXiv, 2025. [New]

[21] Y. Sun, et al. "The Deep Learning Compiler: A Comprehensive Survey." arXiv, 2020. [Link]

Released under the CC BY-NC-ND 4.0 License.