Skip to content

7. 控制流与动态性 (Control-flow & Dynamism)

现实世界的 AI 模型往往包含分支跳转、变长序列或稀疏数据,这构成了对静态编译优化的最大挑战。

本章关注如何在运行时具有不确定性的环境下,依然维持高效的融合执行。核心目标是将动态行为静态化规范化。编译器通过控制流的扁平化 (Flattening)谓词化 (Predication)来规避分支发散,通过符号化分析 (Symbolic Analysis)来处理未知的张量形状,并通过运行时的即时特化 (JIT Specialization),确保动态模型也能享受到静态融合带来的性能红利。

控制流优化[1,2]、谓词执行[3,4]、符号化分析[5]、动态形状处理[6]等技术为控制流优化提供了理论基础。

7.1 控制流扁平化与谓词融合 (Control-Flow Flattening & Predication)

7.1.1 Predication (Select-based Fusion)

背景

AI 处理器 (尤其是 GPU/NPU) 偏好 SIMD/SIMT 并行,极其忌讳分支跳转 (Branch Divergence)。当算子内部存在条件逻辑 (如 ReLU, Dropout 或分段函数) 时,编译器不生成 if-else 跳转指令,而是采用谓词化 (Predication):同时计算两个分支的结果,然后使用 Select 指令根据条件掩码选择最终值。

IF-conversion[1]、Predicated Execution[2,3]等经典技术为谓词执行提供了理论基础。这使得带有控制流的算子依然可以被融合到由 vectorlinalg 构成的密集计算循环中。

融合通常发生在 Basic Block 内部。如果存在 if 跳转,基本块会被打断,阻碍指令调度和流水线优化。谓词化消除了跳转,使整个逻辑变成一个大的基本块,不仅方便融合,还让 Loop Vectorizer 能够轻松对循环进行 SIMD 化。

技术原理

  1. 条件物化 (Condition Materialization): 首先计算条件表达式,生成一个 布尔掩码 (Boolean Mask) 或谓词寄存器。对于向量处理器,这通常是一个与数据宽度一致的掩码向量。
    • Example: mask = (input > 0)
  2. 推测性执行 (Speculative Execution / Compute Both): 无论条件如何,编译器会让硬件同时计算两个分支的结果。
    • True Path: res_true = compute_true_block (input)
    • False Path: res_false = compute_false_block (input)
    • 注意:这要求分支内的操作是 无副作用 (Side-effect free) 的 (例如纯算术运算)。如果包含内存写操作或异常抛出,则不能简单推测执行。
  3. 指令选择 (Select / Blend): 使用硬件提供的 SelectCMOV(Conditional Move) 或 Blend 指令,根据掩码选择最终结果。
    • Logic: result = (mask & res_true) | (~mask & res_false)
    • Semantics: result = select (mask, res_true, res_false)

代价模型与权衡:

编译器并非对所有分支都进行谓词化融合,通常基于以下权衡:

  • 收益:消除了分支预测失败的惩罚 (CPU) 或分支发散的串行化 (GPU);增加了指令级并行度 (ILP)。
  • 成本:执行了无用的计算 (被丢弃的那个分支)。
  • 决策阈值:仅当分支内的指令数量较少 (如 ReLU, Clip, Thresholding) 时,谓词化才是划算的。如果分支内包含矩阵乘法等重负载操作,推测执行的代价过大,此时应保留控制流。

硬件指令映射

  • x86 AVX-512: vpblendvb (根据掩码混合两个向量)。
  • ARM NEON/SVE: bsl (Bitwise Select) 或 sel 指令。
  • NVIDIA PTX: selp (Select based on predicate register) 或利用 Predicate Register (@p add.f32 ...) 控制单条指令执行。
  • Ascend NPU:使用Mask机制,利用硬件的谓词寄存器支持在运行时控制数据流动。

MLIR 实现:

cpp
// 原始逻辑:显式的控制流分支 (难以向量化,难以融合)
// scf.if %cond { ... } else { ... }

// 优化后:谓词化融合 (Predicated Fusion)
// 所有的计算路径都被平铺,适合 SIMD 执行
func.func @relu_fused(%arg0: tensor<128xf32>) -> tensor<128xf32> {
  %c0 = arith.constant 0.0 : f32
  
  %0 = linalg.generic ... ins(%arg0) ... {
    ^bb0(%in: f32):
      // 1. 生成谓词掩码 (Predicate Mask)
      %cond = arith.cmpf ogt, %in, %c0 : f32
      
      // 2. 同时"计算"两个分支 (一个是 %in,一个是 %c0)
      // 3. 使用 Select 指令融合
      %res = arith.select %cond, %in, %c0 : f32
      
      linalg.yield %res : f32
  }
  return %0
}

7.1.2 Implicit Mask Fusion (隐式掩码融合)

背景

在长序列 LLM (如 Context Length = 128k) 中,Attention Mask 是一个 的下三角矩阵。如果显式创建该 Tensor (Materialization),将消耗 个 bool/float 空间,导致瞬间显存溢出 (OOM)。 隐式掩码融合指编译器不分配物理内存存储 Mask,而是在 Attention Score 计算的 Kernel 内部,利用当前线程的坐标 (row, col) 即时计算掩码值 (On-the-fly Mask Generation)

技术原理

Index-to-Value Fusion (索引即数值): 将"读取内存"操作替换为"逻辑比较"操作。

这本质上是将静态数据结构转化为动态控制流 (谓词逻辑),彻底消除了 Mask 的内存占用。

MLIR 实现:

cpp
// 融合 Causal Mask 到 Attention Score 计算
// 场景:Score = Softmax(Q * K^T + Mask)
// 优化:不读取 Mask Tensor,直接利用 linalg.index 动态生成

%scores = linalg.generic {
  indexing_maps = [ ... ], 
  iterator_types = ["parallel", "parallel"]
} ins(%Q, %K : ...) outs(%Out : ...) {

  ^bb0(%q: f32, %k: f32, %out: f32):
    // 1. 计算点积 Q * K^T
    %dot = arith.mulf %q, %k : f32
    
    // 2. [Implicit Fusion] 获取当前计算元素的几何坐标
    %row_idx = linalg.index 0 : index
    %col_idx = linalg.index 1 : index
    
    // 3. 生成 Causal Mask 逻辑
    // 判定条件:row >= col ?
    %is_causal = arith.cmpi sge, %row_idx, %col_idx : index
    
    // 4. 选择掩码值 (Predication)
    // 这是一个纯寄存器操作,无内存读取
    %c0 = arith.constant 0.0 : f32
    %neg_inf = arith.constant -1.0e+4 : f32
    %mask_val = arith.select %is_causal, %c0, %neg_inf : f32
    
    // 5. 融合到结果
    %masked_score = arith.addf %dot, %mask_val : f32
    linalg.yield %masked_score : f32
}

7.2 动态形状融合 (Dynamic Shape Fusion)

7.2.1 Symbolic Shape Analysis (符号化形状分析)

背景

在动态批处理 (Dynamic Batching) 或 NLP 变长序列场景中,Tensor 的维度在编译期是未知的 (?)。传统的静态编译器遇到 ? 通常会放弃融合,退化为解释执行或生成大量胶水代码。

现代 AI 编译器通过符号化分析,在不知道具体数值的情况下,依然能够生成高效的融合算子。Polyhedral compilation[5]、Symbolic analysis[6]等研究为动态形状处理提供了理论基础。

技术原理

  1. 符号化建模 (Symbolic Modeling): 编译器将所有的动态维度视为代数符号 (Symbols,如 )。所有的索引计算不再基于常量,而是基于仿射表达式 (Affine Expressions),例如 (i, j) -> (i * s_1 + j)。这使得编译器可以推理出内存访问的线性关系,即使步长 是未知的。

  2. 约束满足性检查 (Constraint Solving): 为了安全地融合两个算子 (如 Add (A, B)),编译器必须证明 。对于动态形状,编译器执行SSA 值等价分析:如果两个维度的定义来源 (Definition Source) 相同 (例如都来自同一个 Input Argument 的第 0 维),则认为它们在运行时必然相等,允许融合。

  3. 形状实体化 (Shape Reification): 这是生成可执行代码的关键。编译器将形状计算逻辑从数据计算中剥离,生成一组独立的标量运算指令 (通常是 index 类型)。这些指令在运行时率先执行,计算出具体的循环边界,然后喂给融合后的 Kernel。

MLIR 实现:

cpp
// 动态形状融合示例
// 编译器不需要知道具体大小,只需要生成依赖运行时 Dim 的代码
func.func @dynamic_fusion(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> tensor<?x?xf32> {
  // 1. [Reify] 获取运行时形状 (Reify Shapes)
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %A, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %A, %c1 : tensor<?x?xf32>

  // 2. [Init] 使用动态尺寸创建输出 Tensor
  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>

  // 3. [Generic] 融合算子体
  // linalg.generic 天然支持动态形状,因为它的循环边界由输入 (%A) 决定
  %res = linalg.generic {
    indexing_maps = [
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>
    ],
    iterator_types = ["parallel", "parallel"]
  } ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%init : tensor<?x?xf32>) {
    ^bb0(%a: f32, %b: f32, %out: f32):
      %0 = arith.addf %a, %b : f32
      linalg.yield %0 : f32
  }
  return %res
}

7.3 稀疏性与不规则融合 (Sparsity & Irregular Fusion)

7.3.1 Sparse-Dense Fusion (稀疏-稠密融合)

背景

虽然"动态性"通常指形状变化,但 数据分布的动态性 (稀疏性) 也是关键挑战。当稀疏张量 (Sparse Tensor) 与稠密张量 (Dense Tensor) 进行运算时 (如 GNN 或推荐系统),融合策略不能遍历整个空间,而必须 跟随稀疏索引 (Sparse Indices) 进行迭代。

编译器利用 Co-iteration (协同迭代) 技术,只对非零元素的位置执行融合算子链,从而获得 而非 的性能。

TACO[7]、Sparse Tensor Compiler[8]等系统为稀疏计算的编译优化提供了重要参考。GNN[9]、推荐系统[10]等领域广泛应用了这些技术。

技术原理

为了让编译器能够像处理稠密张量一样处理稀疏张量,并实现自动化融合,现代编译器 (如 MLIR Sparse Compiler, TACO) 采用了以下核心技术:

  1. 基于维度的层级抽象 (Per-dimension Level Abstraction): 编译器不硬编码 "CSR" 或 "COO" 格式,而是将稀疏格式解构为维度属性的组合。

    • Dense: 该维度的所有坐标都存在 ( 访问)。
    • Compressed: 只存储非零元素的坐标 (需要查表访问,如 indices 数组)。
    • Singleton: 该维度只有一个元素 (如 COO 格式中的坐标)。 通过组合这些属性,编译器可以描述任意稀疏格式,并生成对应的遍历代码。
  2. 协同迭代与集合运算 (Co-iteration & Set Operations): 当融合两个稀疏张量 (或稀疏+稠密) 时,编译器必须生成能够同步遍历它们的循环逻辑。这本质上是集合论问题:

    • Sparse Sparse (Intersection): 只有两个张量在位置 都有值时才计算。编译器生成类似"双指针归并"的逻辑,跳过只要有一方为零的位置。
    • Sparse Sparse (Union): 只要有一方有值就计算。编译器生成复杂的 while 循环来处理指针的对齐和推进。
    • Sparse Dense: 编译器利用稀疏张量的索引去驱动对稠密张量的随机访问 (Gather)。
  3. 循环重构 (Loop Reconstruction): 编译器将高层的 linalg.generic 降级为底层的 while 循环和间接寻址 (Indirect Addressing)。

    • Dense Loop: for (i=0; i<N; i++)
    • Sparse Loop: pos = ptr[i]; while (pos < ptr[i+1]) { idx = indices[pos]; ... }

MLIR SparseTensor Dialect 实现

cpp
// 稀疏融合:Sparse Matrix * Dense Vector + Bias
// 关键点:循环结构不是 dense 的 for-loop,而是由稀疏张量的压缩格式驱动
#CSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed)
}>

func.func @sparse_fusion(%sp_mat: tensor<?x?xf32, #CSR>, 
                         %vec: tensor<?xf32>, 
                         %bias: tensor<?xf32>) -> tensor<?xf32> {
  
  // 融合后的操作:linalg.generic 看起来和稠密一样
  // 但编译器后端会将其 Lowering 为"遍历非零元素索引"的复杂循环
  %res = linalg.generic {
    indexing_maps = ...
    iterator_types = ...
  } ins(%sp_mat, %vec, %bias) outs(...) {
    ^bb0(%a: f32, %b: f32, %c: f32, %out: f32):
      %prod = arith.mulf %a, %b : f32
      %sum = arith.addf %prod, %c : f32
      linalg.yield %sum : f32
  }
  return %res
}

注:扩展阅读 MLIR的SparseTensor方言是如何分析矩阵的稀疏性的?


7.4 运行时特化 (Runtime Specialization)

7.4.1 Just-In-Time (JIT) Specialization

背景

在处理动态形状时,生成一个通用的 Kernel (处理任意 ?) 通常比针对特定形状 (如 1024) 生成的 Kernel 性能差 (因为无法做常量折叠、向量化对齐或循环展开)。

运行时特化策略指:编译器保留一份通用的 IR 模板,在运行时检测到具体的形状参数 (如 Batch=1Batch=32) 时,即时触发编译,生成 完全静态化 (Static) 的高性能融合 Kernel。

JIT compilation[11]、Trace-based compilation[12]、动态形状优化[13]等技术为运行时特化提供了基础。TVM[11]、PyTorch JIT[14]等系统实现了高效的JIT特化。

技术原理

编译器通过运行时特化提升性能,主要依赖以下三个底层机制:

  1. 常量提升与传播 (Constant Promotion & Propagation): 在特化路径 (Specialized Path) 中,原本在运行时才能确定的变量 (如 Batch Size = 1),被强制视为编译期常量。 这使得编译器可以执行常量折叠 (Constant Folding),预先计算出所有与形状相关的偏移量 (Offsets) 和步长 (Strides),将复杂的地址计算简化为立即数加法。

  2. 循环完全展开与向量化 (Loop Unrolling & Vectorization): 这是特化带来的最大收益。

    • 动态循环:编译器必须生成循环头、循环体、循环尾 (处理余数) 以及边界检查代码,且不敢激进使用寄存器 (因为不知道循环次数是否足以填满流水线)。
    • 静态特化循环:当 已知且较小 (如 ) 时,编译器可以完全消除循环结构,直接生成 条线性的 FMA 指令。这不仅消除了跳转开销,还允许编译器进行精确的寄存器分配 (Register Allocation),实现 100% 的 ALU 利用率。
  3. 代码版本管理与分发 (Versioning & Dispatch): 编译器采用 Guard-based Dispatch 策略。它不试图生成一个"万能优化的内核",而是生成多个版本:

    • Fast Path: 针对热点形状 (如 B=1, Seq=128) 的极致优化无分支代码。
    • Generic Path: 带有完整循环开销的兜底代码。 运行时的开销仅仅是几个整数比较指令 (Guard Check),换来的是核心计算路径的数倍性能提升。

MLIR 实现:

下面的示例展示了如何处理动态 Batch 的矩阵乘法。编译器为最常见的 Batch=1 生成了特化路径。

cpp
func.func @matmul_dispatch(%A: tensor<?x1024xf32>, %B: tensor<1024x1024xf32>) -> tensor<?x1024xf32> {
  
  // 1. [Guard] 获取运行时维度并进行检查
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  
  %batch_size = tensor.dim %A, %c0 : tensor<?x1024xf32>
  %is_batch_1 = arith.cmpi eq, %batch_size, %c1 : index

  // 2. [Dispatch] 使用 scf.if 进行分发
  %result = scf.if %is_batch_1 -> (tensor<?x1024xf32>) {
    
    // === 特化路径 (Specialized Path: Batch=1) ===
    // 关键点:使用 tensor.cast 将动态 ? 转换为静态 1
    // 这告诉编译器:在这个块内,A 的形状是确定的 1x1024
    %A_static = tensor.cast %A : tensor<?x1024xf32> to tensor<1x1024xf32>
    %init_static = tensor.empty() : tensor<1x1024xf32>

    // 此时 linalg.matmul 看到的是静态形状
    // 后端可以将其从循环 Lowering 为一串扁平的 FMA 指令 (全展开)
    %res_static = linalg.matmul ins(%A_static, %B : ...) outs(%init_static : ...)
    
    // 转换回动态类型以统一返回接口
    %res_cast = tensor.cast %res_static : tensor<1x1024xf32> to tensor<?x1024xf32>
    scf.yield %res_cast : tensor<?x1024xf32>

  } else {
    
    // === 通用/回退路径 (Fallback Path: Dynamic Batch) ===
    // 这里的 ? 仍然是动态的,生成带有循环开销的通用代码
    %init_dynamic = tensor.empty(%batch_size) : tensor<?x1024xf32>
    %res_dynamic = linalg.matmul ins(%A, %B : ...) outs(%init_dynamic : ...)
    
    scf.yield %res_dynamic : tensor<?x1024xf32>
  }

  return %result : tensor<?x1024xf32>
}

性能收益解析

  • Static Path (Batch=1):由于形状已知,后端编译器可以消除对 Batch 维度的循环,直接生成向量化的点积指令 (如 tensor<1x1024> tensor<1024x1024> 变为 1024 次向量乘加)。
  • Dynamic Path:必须保留外层循环 for (i=0; i<N; ++i),并处理余数循环,开销显著。

参考文献 (References)

控制流优化与谓词执行 (Control Flow & Predication)

[1] R. Allen, J. Cocke, K. Kennedy. "Reduction of Operator Strength." ACM Computing Surveys, 1981. [Classic]

[2] J. A. Fisher. "Predicated Execution for High Performance Computing." IEEE Computer, 1996. [Link]

[3] J. R. Allen, K. Kennedy. "Automatic Translation of Fortran Programs for Vectorized Parallelism." IEEE TPDS, 1987. [Link]

[4] (NVIDIA). "CUDA C++ Programming Guide - Warp Vote and Branch Predication." NVIDIA Documentation. [Official]

[5] (MLIR). "MLIR: Multi-Level Intermediate Representation." MLIR Documentation. [Official]

动态形状与符号化分析 (Dynamic Shape & Symbolic Analysis)

[6] (Google). "TensorFlow Graph Execution: Dynamic Shapes and Control Flow." TensorFlow Documentation. [Official]

[7] (PyTorch). "Dynamic Shapes in PyTorch." PyTorch Documentation. [Official]

[8] (JAX). "JAX: Autodifferentiable Array Programming." GitHub. [Official]

[9] (OneFlow). "OneFlow: Dynamic Shape Support." GitHub. [Link]

[10] (XLA). "XLA: Optimizing for Dynamic Shapes." XLA Documentation. [Official]

稀疏计算与不规则融合 (Sparse Computation)

[11] F. Kjolstad, et al. "TACO: A General-Purpose Sparse Tensor Compiler." ACM PPoPP, 2022. [Link]

[12] (MLIR). "MLIR Sparse Tensor Dialect." MLIR Documentation. [Official]

[13] W. L. Wang, et al. "Deep Graph Library: A Graph Neural Network System." IEEE TKDE, 2022. [Link]

[14] R. Ying, et al. "PyTorch Geometric: Deep Learning on Graphs." NeurIPS, 2020. [Link]

[15] (LightGBM). "LightGBM: A Highly Efficient Gradient Boosting Decision Tree." NeurIPS, 2017. [Link]

JIT编译与运行时特化

[16] T. Chen, et al. "TVM: An Automated End-to-End Optimizing Compiler for Deep Learning." USENIX OSDI, 2018. [Link]

[17] (PyTorch). "PyTorch JIT: Just-In-Time Compilation." PyTorch Documentation. [Official]

[18] (TensorFlow). "TensorFlow XLA: Optimizing for Production." TensorFlow Documentation. [Official]

[19] (Numba). "Numba: JIT Compilation for Python." GitHub. [Link]

[20] (Julia). "Julia: A Fresh Approach to Technical Computing." MIT License. [Link]

Released under the CC BY-NC-ND 4.0 License.