第十章:代码生成机制与启发式优化
📖 本章概要
本章深入讲解 TorchInductor 如何生成高效的 Triton/C++ 代码,以及如何通过启发式策略(Heuristics)进行性能优化。您将了解:
- TorchInductor 的代码生成流程
- Triton Kernel 的参数配置策略(grid size、block size、num_warps)
- AutoTuning 自动调优机制
- 如何为特定硬件定制代码生成策略
目录
- 代码生成全流程
- 深入:真实的Kernel代码生成 ⭐核心
- 完整示例:逐步追踪
- 详细步骤拆解
- 完整时间线总结
- 关键数据流
- 不同操作类型的代码生成
- Triton Kernel 参数详解
- 启发式优化策略
- Grid Size 与循环策略
- Block Size 与 Num_Warps 配置
- AutoTuning 机制
- 实战:自定义代码生成策略
- 性能分析与调优
- 常见问题
- 总结
1. 代码生成全流程
1.1 从 FX Graph 到可执行代码
用户模型 ↓ torch.compile ↓ [1] TorchDynamo: 捕获计算图 ↓ [2] AOTAutograd: 前向/反向分离 ↓ [3] TorchInductor: 代码生成 ├─→ [3.1] Lowering: 高层 IR → 低层 IR ├─→ [3.2] Fusion: 算子融合决策 ├─→ [3.3] Scheduling: 生成调度代码 │ ├─→ TritonScheduling │ ├─→ CppScheduling │ └─→ ExternKernel (调用外部库) ├─→ [3.4] Code Emission: 生成内核代码 │ ├─→ Triton Code │ ├─→ C++ Code │ └─→ CUDA Code └─→ [3.5] Wrapper Generation: 生成包装代码 └─→ Python/C++ Wrapper ↓ [4] 编译与加载 ├─→ Triton Compiler → PTX/LLVM IR ├─→ C++ Compiler → .so └─→ 动态加载到 Python 进程 ↓ 执行1.2 核心数据结构
# torch/_inductor/ir.pyclassIRNode:"""IR 节点基类"""defget_size(self)->List[sympy.Expr]:"""返回张量形状"""passdefget_dtype(self)->torch.dtype:"""返回数据类型"""passdefget_device(self)->torch.device:"""返回设备类型"""passclassSchedulerNode:"""调度节点:包含 IR + 调度信息"""def__init__(self,scheduler,node:IRNode):self.scheduler=scheduler self.node=node self.users=[]# 使用该节点的节点列表self.group=None# 融合组defcan_fuse(self,other:'SchedulerNode')->bool:"""判断是否可以与其他节点融合"""pass1.3 代码生成入口
# torch/_inductor/compile_fx.pydefcompile_fx_inner(gm:torch.fx.GraphModule,example_inputs):"""编译 FX 图"""withV.set_graph_handler(GraphLowering(gm)):# [1] Lowering: 将 FX Graph 转换为 IRgraph=V.graph graph.run(*example_inputs)# [2] Scheduling: 生成调度计划compiled_graph=graph.compile_to_fn()returncompiled_graph# torch/_inductor/graph.pyclassGraphLowering:defcompile_to_fn(self):"""生成可执行函数"""# [1] 创建调度器fromtorch._inductor.schedulerimportScheduler self.scheduler=Scheduler(self.buffers)# [2] 融合决策self.scheduler.codegen()# [3] 生成 Python wrapperreturnself.compile_to_module().call2. 深入:真实的Kernel代码生成
这一节我们深入到源码级别,看看TorchInductor是如何真正生成Triton kernel代码的。
2.1 从IR节点到Triton代码的完整流程
# 完整的代码生成流程FX Graph Node ↓[1]Lowering(torch/_inductor/lowering.py)↓ IRNode(torch/_inductor/ir.py)├─→ Pointwise(逐点操作)├─→ Reduction(归约操作)├─→ TensorBox(张量引用)└─→ ComputedBuffer(计算缓冲区)↓[2]Scheduling(torch/_inductor/scheduler.py)↓ FusedSchedulerNode(融合节点组)├─→ snodes:List[SchedulerNode]└─→ 包含多个可融合的操作 ↓[3]Code Generation(torch/_inductor/codegen/triton.py)↓ TritonKernel ├─→ args:参数列表 ├─→ loads:内存加载操作 ├─→ stores:内存存储操作 ├─→ compute:计算逻辑 └─→ indexing:索引计算 ↓[4]Code Emission ↓ Triton Source Code(字符串)↓[5]Triton Compilation ↓ PTX/LLVM IR ↓ GPU Executable2.2 完整示例:逐步追踪代码生成
让我们通过一个具体例子,完整追踪从Python代码到GPU执行的每一步:
# 用户代码importtorch@torch.compiledefmodel(x,y):z=x+yreturnz.relu()# 运行x=torch.randn(1024,device='cuda')y=torch.randn(1024,device='cuda')result=model(x,y)🔍 执行流程时间线
t=0ms: 用户调用 model(x, y) ↓ t=1ms: TorchDynamo 拦截函数调用 ↓ t=2ms: 字节码分析 + FX Graph 捕获 ↓ t=50ms: AOTAutograd 处理 ↓ t=51ms: TorchInductor 代码生成 ├─→ Lowering (51-52ms) ├─→ Fusion (52-53ms) ├─→ Code Generation (53-55ms) └─→ Triton Compilation (55-150ms) ↓ t=150ms: 动态加载编译结果 ↓ t=151ms: 首次执行(编译完成) ↓ t=152ms: 第二次调用 model(x, y) └─→ 直接执行编译好的代码(~0.1ms)2.3 详细步骤拆解
步骤0:装饰器注册
# 当你写 @torch.compile 时@torch.compiledefmodel(x,y):z=x+yreturnz.relu()# 等价于model=torch.compile(model)# torch.compile 做了什么?deftorch_compile(fn):""" 返回一个包装函数,在首次调用时触发编译 """compiled_fn=Nonedefwrapper(*args,**kwargs):nonlocalcompiled_fn# [关键] 首次调用时编译ifcompiled_fnisNone:print("[Compile] 开始编译...")compiled_fn=_compile_impl(fn,args,kwargs)# 执行编译后的函数returncompiled_fn(*args,**kwargs)returnwrapper步骤1:首次调用 - TorchDynamo拦截(t=1ms)
# 当执行 result = model(x, y) 时# [1.1] TorchDynamo 拦截# torch/_dynamo/eval_frame.pydef_compile(fn,args,kwargs,compiler_fn,):""" TorchDynamo 的核心:拦截 Python 字节码 """# [1] 获取函数的字节码code=fn.__code__print(f"[Dynamo] 拦截函数:{fn.__name__}")print(f"[Dynamo] 字节码指令数:{len(code.co_code)}")# [2] 创建字节码分析器fromtorch._dynamo.symbolic_convertimportInstructionTranslator tracer=InstructionTranslator(instructions=dis.get_instructions(code),f_locals=fn.__globals__,)# [3] 逐条执行字节码,构建计算图graph=tracer.run()returngraph实际的字节码(使用dis.dis(model)查看):
importdis dis.dis(model)# 输出:# 2 0 LOAD_FAST 0 (x)# 2 LOAD_FAST 1 (y)# 4 BINARY_ADD# 6 STORE_FAST 2 (z)## 3 8 LOAD_FAST 2 (z)# 10 LOAD_ATTR 0 (relu)# 12 CALL_FUNCTION 0# 14 RETURN_VALUETorchDynamo 如何处理这些字节码:
# torch/_dynamo/symbolic_convert.pyclassInstructionTranslator:defrun(self):"""执行字节码,构建FX Graph"""forinstinself.instructions:handler=getattr(self,inst.opname,None)ifhandler:handler(inst)returnself.output.graphdefLOAD_FAST(self,inst):"""处理 LOAD_FAST 指令(加载局部变量)"""var_name=inst.argval# 'x' 或 'y'# 创建符号变量var=VariableTracker.build(self.f_locals[var_name])self.stack.append(var)print(f"[Dynamo] LOAD_FAST:{var_name}")defBINARY_ADD(self,inst):"""处理 BINARY_ADD 指令(加法)"""# 从栈中弹出两个操作数right=self.stack.pop()# yleft=self.stack.pop()# x# 记录操作到图中result=self.output.create_node('call_function',torch.ops.aten.add.Tensor,args=(left,right),)self.stack.append(result)print(f"[Dynamo] BINARY_ADD:{left}+{right}")defLOAD_ATTR(self,inst):"""处理 LOAD_ATTR 指令(属性访问)"""obj=self.stack.pop()# zattr_name=inst.argval# 'relu'# 记录属性访问method=getattr(obj,attr_name)self.stack.append(method)print(f"[Dynamo] LOAD_ATTR:{obj}.{attr_name}")defCALL_FUNCTION(self,inst):"""处理 CALL_FUNCTION 指令(函数调用)"""num_args=inst.argval# 0(relu没有参数)# 弹出函数和参数fn=self.stack.pop()# relu methodargs=[self.stack.pop()for_inrange(num_args)]# 记录函数调用result=self.output.create_node('call_method','relu',args=(fn.__self__,),# self是z)self.stack.append(result)print(f"[Dynamo] CALL_FUNCTION:{fn}({args})")步骤2:FX Graph生成(t=2ms)
# 生成的FX Graph# torch.fx.GraphModuleclassGraphModule(torch.nn.Module):defforward(self,x,y):# 节点1: addadd_tensor=torch.ops.aten.add.Tensor(x,y)# 节点2: relurelu_default=torch.ops.aten.relu.default(add_tensor)returnrelu_default# 打印图结构print(graph.graph)# 输出:# graph():# %x : [num_users=1] = placeholder[target=x]# %y : [num_users=1] = placeholder[target=y]# %add_tensor : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})# %relu_default : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%add_tensor,), kwargs = {})# return relu_default图的可视化:
输入: x [1024] 输入: y [1024] ↓ ↓ └──────┬───────────┘ ↓ aten.add.Tensor ↓ add_tensor [1024] ↓ aten.relu.default ↓ relu_default [1024] ↓ 输出步骤3:AOTAutograd处理(t=50ms)
# torch/_functorch/aot_autograd.pydefaot_function(fn,fw_compiler):""" Ahead-Of-Time Autograd 分离前向和反向图 """defcompiled_fn(*args):# [1] 运行前向,记录操作withtorch.enable_grad():# 标记输入需要梯度args_with_grad=[arg.requires_grad_(True)ifisinstance(arg,torch.Tensor)elseargforarginargs]# 运行前向out=fn(*args_with_grad)# [2] 构建反向图# (本例不需要反向,简化处理)# [3] 编译前向图compiled_fw=fw_compiler(forward_graph,args_with_grad,)returncompiled_fw(*args)returncompiled_fn# 对于我们的例子,AOTAutograd主要做:# 1. 确认不需要梯度(x, y 没有 requires_grad=True)# 2. 将FX Graph传递给TorchInductor步骤4:TorchInductor Lowering(t=51-52ms)
这是关键步骤!将高层的aten.add和aten.relu转换为低层IR。
# torch/_inductor/lowering.py# [4.1] 注册 lowering 规则@register_lowering(torch.ops.aten.add)defadd_tensor(x,y):""" 将 aten.add 转换为 Pointwise IR 输入: x, y 是 TensorBox(包装的张量引用) 输出: Pointwise IR节点 """print(f"[Lowering] add: x.shape={x.get_size()}, y.shape={y.get_size()}")definner_fn(idx):""" 核心计算逻辑 idx: 符号索引,例如 (i,) 对于1D张量 """# 加载 x[idx]x_val=ops.load(x,idx)# 加载 y[idx]y_val=ops.load(y,idx)# 计算 x[idx] + y[idx]result=ops.add(x_val,y_val)returnresult# 创建 Pointwise IR 节点returnPointwise.create(device=x.get_device(),# 'cuda'dtype=x.get_dtype(),# torch.float32inner_fn=inner_fn,# 上面定义的lambdaranges=list(x.get_size()),# [1024])@register_lowering(torch.ops.aten.relu)defrelu(x):""" 将 aten.relu 转换为 Pointwise IR """print(f"[Lowering] relu: x.shape={x.get_size()}")definner_fn(idx):# 加载 x[idx]x_val=ops.load(x,idx)# 计算 max(x[idx], 0)zero=ops.constant(0.0,x.get_dtype())result=ops.maximum(x_val,zero)returnresultreturnPointwise.create(device=x.get_device(),dtype=x.get_dtype(),inner_fn=inner_fn,ranges=list(x.get_size()),)# [4.2] 执行 Lowering# torch/_inductor/graph.pyclassGraphLowering:defrun(self,*args):"""遍历FX Graph,调用对应的lowering函数"""fornodeinself.graph.nodes:ifnode.op=='call_function':# 查找对应的lowering函数lowering_fn=lowerings.get(node.target)# 调用loweringir_result=lowering_fn(*node.args,**node.kwargs)# 保存IR节点self.env[node.name]=ir_resultprint(f"[Lowering]{node.name}->{type(ir_result).__name__}")Lowering后的IR结构:
# IR Node 1: addir_add=Pointwise(device='cuda',dtype=torch.float32,inner_fn=lambdaidx:ops.add(ops.load(x,idx),ops.load(y,idx)),ranges=[1024],name='buf0',)# IR Node 2: reluir_relu=Pointwise(device='cuda',dtype=torch.float32,inner_fn=lambdaidx:ops.maximum(ops.load(buf0,idx),# 依赖 buf0(add的输出)ops.constant(0.0,torch.float32)),ranges=[1024],name='buf1',)步骤5:Fusion决策(t=52-53ms)
# torch/_inductor/scheduler.pyclassScheduler:deffusion_pass(self):"""算子融合决策"""print("[Fusion] 开始融合分析...")# [5.1] 为每个IR节点创建SchedulerNodeself.nodes=[]forbuf_name,ir_nodeinself.buffers.items():snode=SchedulerNode(self,ir_node)snode.name=buf_name self.nodes.append(snode)print(f"[Fusion] 发现{len(self.nodes)}个节点")# [5.2] 构建依赖关系forsnodeinself.nodes:# 分析 inner_fn,提取读取的缓冲区reads=snode.node.get_reads()forread_bufinreads:# 找到生产者节点producer=self.name_to_node[read_buf.name]snode.read_from.add(producer)producer.users.append(snode)print("[Fusion] 依赖关系:")print(" buf0 (add) -> buf1 (relu)")# [5.3] 尝试融合forconsumerinself.nodes:forproducerinlist(consumer.read_from):ifself.can_fuse(producer,consumer):print(f"[Fusion] 融合:{producer.name}+{consumer.name}")self.fuse(producer,consumer)defcan_fuse(self,producer,consumer):"""判断是否可以融合"""# 条件1: 都是Pointwiseifnot(isinstance(producer.node,Pointwise)andisinstance(consumer.node,Pointwise)):print(f"[Fusion] 不能融合: 不都是Pointwise")returnFalse# 条件2: producer只有一个使用者(consumer)iflen(producer.users)!=1:print(f"[Fusion] 不能融合: producer有多个使用者")returnFalse# 条件3: 形状相同ifproducer.node.get_size()!=consumer.node.get_size():print(f"[Fusion] 不能融合: 形状不匹配")returnFalseprint(f"[Fusion] ✓ 可以融合")returnTruedeffuse(self,producer,consumer):"""融合两个节点"""# 创建融合后的inner_fndeffused_inner_fn(idx):# 内联producer的计算producer_result=producer.node.inner_fn(idx)# 替换consumer中对producer的load为直接使用结果# consumer.inner_fn 中的 ops.load(producer, idx)# 替换为 producer_result# 这里简化表示returnconsumer.node.inner_fn_with_inline(producer_result,idx)# 创建融合节点fused_node=Pointwise.create(device=consumer.node.get_device(),dtype=consumer.node.get_dtype(),inner_fn=fused_inner_fn,ranges=consumer.node.get_size(),)# 更新图self.replace_node(consumer,fused_node)self.remove_node(producer)融合后的IR:
# 融合后只有一个 Pointwise 节点fused_ir=Pointwise(device='cuda',dtype=torch.float32,inner_fn=lambdaidx:ops.maximum(ops.add(ops.load(x,idx),# 来自addops.load(y,idx)# 来自add),# add的结果ops.constant(0.0,torch.float32)# 来自relu),ranges=[1024],name='buf_fused',)步骤6:生成Triton代码(t=53-55ms)
这是最精彩的部分!将IR转换为实际的Triton源代码。
# torch/_inductor/codegen/triton.pyclassTritonScheduling:defcodegen(self):"""为所有融合后的节点生成代码"""print("[Codegen] 开始生成Triton代码...")fornodeinself.scheduled_nodes:kernel=self.create_kernel(node)kernel_code=kernel.codegen()self.kernels.append(kernel_code)defcreate_kernel(self,node):"""为单个节点创建Triton Kernel"""kernel=TritonKernel(node)returnkernelclassTritonKernel:def__init__(self,node):self.node=node self.args=IndentedBuffer()self.indexing=IndentedBuffer()self.loads=IndentedBuffer()self.compute=IndentedBuffer()self.stores=IndentedBuffer()# 临时变量计数器self.tmp_counter=0defnew_tmp_var(self):"""分配新的临时变量"""var=f"tmp{self.tmp_counter}"self.tmp_counter+=1returnvardefcodegen(self):"""生成完整的kernel代码"""print(f"[Codegen] 处理节点:{self.node.name}")# [1] 生成参数列表self.codegen_args()# [2] 生成索引计算self.codegen_indexing()# [3] 生成计算代码(遍历inner_fn)result_var=self.codegen_inner_fn()# [4] 生成storeself.codegen_store(result_var)# [5] 组装完整代码returnself.assemble()defcodegen_args(self):"""生成参数列表"""# 输入缓冲区forinpinself.node.get_inputs():self.args.writeline(f"in_ptr{inp.index},")# 输出缓冲区self.args.writeline("out_ptr0,")# 元素数量self.args.writeline("xnumel,")print(f"[Codegen] 参数: 2个输入 + 1个输出 + 1个size")defcodegen_indexing(self):"""生成索引计算代码"""self.indexing.writeline("# 计算索引")self.indexing.writeline("pid = tl.program_id(0)")self.indexing.writeline("xoffset = pid * XBLOCK")self.indexing.writeline("xindex = xoffset + tl.arange(0, XBLOCK)")self.indexing.writeline("xmask = xindex < xnumel")print(f"[Codegen] 索引: 1D, range=[{self.node.get_size()[0]}]")defcodegen_inner_fn(self):""" 生成inner_fn的代码 这是最核心的部分! """# 我们的融合后的inner_fn是:# lambda idx: ops.maximum(# ops.add(# ops.load(x, idx),# ops.load(y, idx)# ),# ops.constant(0.0, torch.float32)# )# 使用符号执行symbolic_idx=sympy.Symbol('xindex')expr=self.node.inner_fn(symbolic_idx)# 递归生成表达式树result_var=self.codegen_expr(expr)returnresult_vardefcodegen_expr(self,expr):"""递归生成表达式的Triton代码"""print(f"[Codegen] 处理表达式:{type(expr).__name__}")ifisinstance(expr,ops.Load):# === Load 操作 ===buffer_name=expr.name# 'in_ptr0' 或 'in_ptr1'index='xindex'# 简化,实际会计算复杂索引tmp_var=self.new_tmp_var()# tmp0self.loads.writeline(f"{tmp_var}= tl.load({buffer_name}+{index}, xmask)")print(f"[Codegen] Load:{buffer_name}[{index}] ->{tmp_var}")returntmp_varelifisinstance(expr,ops.Add):# === Add 操作 ===lhs_var=self.codegen_expr(expr.lhs)# 递归处理左操作数rhs_var=self.codegen_expr(expr.rhs)# 递归处理右操作数tmp_var=self.new_tmp_var()# tmp2self.compute.writeline(f"{tmp_var}={lhs_var}+{rhs_var}")print(f"[Codegen] Add:{lhs_var}+{rhs_var}->{tmp_var}")returntmp_varelifisinstance(expr,ops.Maximum):# === Maximum 操作(用于relu) ===lhs_var=self.codegen_expr(expr.lhs)rhs_var=self.codegen_expr(expr.rhs)tmp_var=self.new_tmp_var()# tmp3self.compute.writeline(f"{tmp_var}= tl.maximum({lhs_var},{rhs_var})")print(f"[Codegen] Maximum: max({lhs_var},{rhs_var}) ->{tmp_var}")returntmp_varelifisinstance(expr,ops.Constant):# === 常量 ===returnstr(expr.value)# "0.0"else:raiseNotImplementedError(f"Unknown expr:{type(expr)}")defcodegen_store(self,result_var):"""生成store语句"""self.stores.writeline(f"tl.store(out_ptr0 + xindex,{result_var}, xmask)")print(f"[Codegen] Store:{result_var}-> out_ptr0[xindex]")defassemble(self):"""组装完整的kernel代码"""code=IndentedBuffer()# 函数签名code.writeline("@triton.jit")code.writeline("def triton_poi_fused_add_relu_0(")code.indent()code.splice(self.args)code.writeline("XBLOCK: tl.constexpr,")code.dedent()code.writeline("):")# 函数体code.indent()code.splice(self.indexing)code.writeline("")code.splice(self.loads)code.writeline("")code.splice(self.compute)code.writeline("")code.splice(self.stores)code.dedent()returncode.getvalue()生成的Triton代码(最终输出):
@triton.jitdeftriton_poi_fused_add_relu_0(in_ptr0,# xin_ptr1,# yout_ptr0,# outputxnumel,# 1024XBLOCK:tl.constexpr,):# 计算索引pid=tl.program_id(0)xoffset=pid*XBLOCK xindex=xoffset+tl.arange(0,XBLOCK)xmask=xindex<xnumel# Loadtmp0=tl.load(in_ptr0+xindex,xmask)# load xtmp1=tl.load(in_ptr1+xindex,xmask)# load y# Computetmp2=tmp0+tmp1# addtmp3=tl.maximum(tmp2,0.0)# relu# Storetl.store(out_ptr0+xindex,tmp3,xmask)代码生成过程的执行日志:
[Codegen] 处理节点: buf_fused [Codegen] 参数: 2个输入 + 1个输出 + 1个size [Codegen] 索引: 1D, range=[1024] [Codegen] 处理表达式: Maximum [Codegen] 处理表达式: Add [Codegen] 处理表达式: Load [Codegen] Load: in_ptr0[xindex] -> tmp0 [Codegen] 处理表达式: Load [Codegen] Load: in_ptr1[xindex] -> tmp1 [Codegen] Add: tmp0 + tmp1 -> tmp2 [Codegen] 处理表达式: Constant [Codegen] Maximum: max(tmp2, 0.0) -> tmp3 [Codegen] Store: tmp3 -> out_ptr0[xindex]步骤7:配置参数(启发式)(t=55ms)
# torch/_inductor/codegen/triton.pyclassTritonScheduling:defselect_config(self,node):"""选择kernel参数配置"""numel=node.get_size()[0]# 1024dtype=node.get_dtype()# torch.float32# [1] 选择XBLOCKxblock=self.select_xblock(numel,dtype)print(f"[Config] XBLOCK ={xblock}")# [2] 选择num_warpsnum_warps=self.select_num_warps(xblock)print(f"[Config] num_warps ={num_warps}")# [3] 计算gridgrid_size=triton.cdiv(numel,xblock)print(f"[Config] grid = ({grid_size},)")return{'XBLOCK':xblock,'num_warps':num_warps,'grid':(grid_size,),}defselect_xblock(self,numel,dtype):"""启发式选择XBLOCK"""# 对于1024个元素,尝试不同的block sizecandidates=[256,512,1024]forxblockincandidates:grid_size=triton.cdiv(numel,xblock)# 条件1: grid不能太小ifgrid_size<4:# 至少4个block保证并行continue# 条件2: 内存对齐ifxblock*dtype.itemsize%128==0:returnxblockreturn256# 默认defselect_num_warps(self,xblock):"""根据XBLOCK选择num_warps"""# 每个warp = 32 threadsmin_warps=(xblock+31)//32# 向上取整到2的幂次importmath num_warps=2**math.ceil(math.log2(min_warps))returnmin(num_warps,32)# 对于我们的例子:# numel = 1024# XBLOCK = 256 (选择256,因为 1024 / 256 = 4 blocks)# num_warps = 8 (256 / 32 = 8 warps)# grid = (4,)步骤8:生成Wrapper代码(t=55ms)
# torch/_inductor/codegen/wrapper.pyclassWrapperCodegen:defgenerate(self):"""生成Python wrapper代码"""code=IndentedBuffer()# [1] 导入code.writeline("import torch")code.writeline("import triton")code.writeline("import triton.language as tl")code.writeline("")# [2] Kernel定义code.writeline("# Kernel定义")code.splice(self.kernel_code)code.writeline("")# [3] 调用函数code.writeline("def call(args):")code.indent()# [3.1] 解包参数code.writeline("primals_1 = args[0] # x")code.writeline("primals_2 = args[1] # y")code.writeline("")# [3.2] 分配输出缓冲区code.writeline("# 分配输出")code.writeline("buf0 = torch.empty_strided(")code.indent()code.writeline("(1024,),")code.writeline("(1,),")code.writeline("device='cuda',")code.writeline("dtype=torch.float32")code.dedent()code.writeline(")")code.writeline("")# [3.3] 调用kernelcode.writeline("# 启动kernel")code.writeline("grid = lambda meta: (")code.indent()code.writeline("triton.cdiv(1024, meta['XBLOCK']),")code.dedent()code.writeline(")")code.writeline("triton_poi_fused_add_relu_0[grid](")code.indent()code.writeline("primals_1,")code.writeline("primals_2,")code.writeline("buf0,")code.writeline("1024,")code.writeline("XBLOCK=256,")code.writeline("num_warps=8,")code.dedent()code.writeline(")")code.writeline("")# [3.4] 返回结果code.writeline("return (buf0,)")code.dedent()returncode.getvalue()生成的完整Python模块:
# /tmp/torchinductor_user/ab/cabcdef123.pyimporttorchimporttritonimporttriton.languageastl# Kernel定义@triton.jitdeftriton_poi_fused_add_relu_0(in_ptr0,in_ptr1,out_ptr0,xnumel,XBLOCK:tl.constexpr,):pid=tl.program_id(0)xoffset=pid*XBLOCK xindex=xoffset+tl.arange(0,XBLOCK)xmask=xindex<xnumel tmp0=tl.load(in_ptr0+xindex,xmask)tmp1=tl.load(in_ptr1+xindex,xmask)tmp2=tmp0+tmp1 tmp3=tl.maximum(tmp2,0.0)tl.store(out_ptr0+xindex,tmp3,xmask)defcall(args):"""主调用函数"""primals_1=args[0]# xprimals_2=args[1]# y# 分配输出buf0=torch.empty_strided((1024,),(1,),device='cuda',dtype=torch.float32)# 启动kernelgrid=lambdameta:(triton.cdiv(1024,meta['XBLOCK']),)triton_poi_fused_add_relu_0[grid](primals_1,primals_2,buf0,1024,XBLOCK=256,num_warps=8,)return(buf0,)步骤9:Triton编译(t=55-150ms)
# Triton编译器将Triton DSL编译为PTX/LLVM IR# [1] Triton JIT编译# triton/compiler/compiler.pyclassTritonCompiler:defcompile(self,src,options):"""编译Triton代码"""print("[Triton] 开始编译...")# [1] 解析Triton代码 -> ASTast=self.parse(src)print("[Triton] 解析完成")# [2] 类型推断ast=self.type_inference(ast)print("[Triton] 类型推断完成")# [3] 转换为Triton IRttir=self.lower_to_ttir(ast)print("[Triton] Triton IR生成")# [4] 优化Triton IRttir=self.optimize_ttir(ttir)print("[Triton] IR优化完成")# [5] 转换为LLVM IRllvm_ir=self.lower_to_llvm(ttir,options)print("[Triton] LLVM IR生成")# [6] LLVM编译为PTXptx=self.llvm_to_ptx(llvm_ir)print("[Triton] PTX生成")# [7] PTX编译为CUBINcubin=self.ptx_to_cubin(ptx)print("[Triton] CUBIN生成")returncubin# 编译日志:# [Triton] 开始编译...# [Triton] 解析完成 (10ms)# [Triton] 类型推断完成 (5ms)# [Triton] Triton IR生成 (15ms)# [Triton] IR优化完成 (20ms)# [Triton] LLVM IR生成 (25ms)# [Triton] PTX生成 (30ms)# [Triton] CUBIN生成 (45ms)# [Triton] 编译完成! (总计 150ms)步骤10:动态加载(t=150ms)
# torch/_inductor/codecache.pyclassCodeCache:defload(self,module_path):"""动态加载编译好的模块"""print(f"[Cache] 加载模块:{module_path}")# Python的动态导入importimportlib.util spec=importlib.util.spec_from_file_location("compiled_module",module_path)module=importlib.util.module_from_spec(spec)spec.loader.exec_module(module)print("[Cache] 模块加载完成")returnmodule# 加载编译结果compiled_module=load("/tmp/torchinductor_user/ab/cabcdef123.py")compiled_fn=compiled_module.call步骤11:首次执行(t=151ms)
# 第一次调用 model(x, y)result=compiled_fn([x,y])# GPU执行流程:# 1. 启动kernel: triton_poi_fused_add_relu_0# - grid = (4,) # 4个block# - block = (256 threads,) # 每个block 256线程# - num_warps = 8 # 每个block 8个warp## 2. GPU调度:# Block 0: 处理 x[0:256]# Block 1: 处理 x[256:512]# Block 2: 处理 x[512:768]# Block 3: 处理 x[768:1024]## 3. 每个thread执行:# - Load: x[tid], y[tid]# - Compute: tmp = x[tid] + y[tid]# - Compute: result = max(tmp, 0.0)# - Store: result -> output[tid]## 4. 同步等待所有block完成## 执行时间: ~0.1ms步骤12:第二次调用(t=152ms)
# 第二次调用 model(x, y)result=model(x,y)# 直接使用缓存的编译结果!# - 不需要重新编译# - 不需要重新加载# - 直接调用 compiled_fn## 执行时间: ~0.1ms (纯GPU执行时间)2.4 完整时间线总结
时间 阶段 耗时 说明 -------------------------------------------------------------------- 0ms 用户调用 0ms model(x, y) 1ms Dynamo拦截 1ms 字节码分析 2ms FX Graph生成 48ms 构建计算图 50ms AOTAutograd 1ms 前向/反向分离 51ms Lowering 1ms FX Graph -> IR 52ms Fusion 1ms 算子融合 53ms Code Generation 2ms IR -> Triton代码 55ms 配置选择 <1ms 选择XBLOCK等参数 55ms Triton编译 95ms Triton -> PTX -> CUBIN 150ms 动态加载 1ms 导入Python模块 151ms 首次执行 <1ms GPU执行 -------------------------------------------------------------------- 总计: 首次调用 ~151ms (主要是编译) 后续调用 ~0.1ms (直接执行)2.5 关键数据流
用户数据: x: torch.Tensor([...], shape=(1024,), device='cuda') y: torch.Tensor([...], shape=(1024,), device='cuda') ↓ [Dynamo捕获] FX Graph: %x : Tensor(1024) %y : Tensor(1024) %add : Tensor(1024) = aten.add(%x, %y) %relu : Tensor(1024) = aten.relu(%add) ↓ [Lowering] IR: buf0 = Pointwise(lambda i: load(x, i) + load(y, i)) buf1 = Pointwise(lambda i: max(load(buf0, i), 0.0)) ↓ [Fusion] Fused IR: buf_fused = Pointwise(lambda i: max(load(x, i) + load(y, i), 0.0)) ↓ [Code Generation] Triton Code: tmp0 = load(x, i) tmp1 = load(y, i) tmp2 = tmp0 + tmp1 tmp3 = max(tmp2, 0.0) store(output, i, tmp3) ↓ [Compilation] GPU Binary: CUBIN: [binary code...] ↓ [Execution] GPU Memory: x_ptr: 0x7f1234... (4096 bytes) y_ptr: 0x7f5678... (4096 bytes) out_ptr: 0x7f9abc... (4096 bytes) GPU Execution: 4 blocks × 256 threads = 1024 threads 每个thread处理1个元素 结果写回 out_ptr2.6 不同操作类型的代码生成
2.6.1 Pointwise操作(1D)
# 输入IRPointwise(inner_fn=lambdaidx:ops.load(x,idx)*2.0,ranges=[1024],)# 生成的Triton代码@triton.jitdefkernel(in_ptr0,out_ptr0,xnumel,XBLOCK:tl.constexpr):pid=tl.program_id(0)xindex=pid*XBLOCK+tl.arange(0,XBLOCK)xmask=xindex<xnumel tmp0=tl.load(in_ptr0+xindex,xmask)tmp1=tmp0*2.0tl.store(out_ptr0+xindex,tmp1,xmask)2.6.2 Pointwise操作(2D)
# 输入IR: x.T (转置)Pointwise(inner_fn=lambdai,j:ops.load(x,[j,i]),# 交换索引ranges=[N,M],# 输出形状)# 生成的Triton代码@triton.jitdefkernel(in_ptr0,out_ptr0,M,N,XBLOCK:tl.constexpr):pid=tl.program_id(0)# 2D索引展平xindex=pid*XBLOCK+tl.arange(0,XBLOCK)xmask=xindex<M*N# 转换为2D索引x0=xindex%N# jx1=xindex//N# i# Load: 交换索引tmp0=tl.load(in_ptr0+x1+x0*M,xmask)# Store: 正常顺序tl.store(out_ptr0+xindex,tmp0,xmask)2.6.3 Reduction操作
# 输入IR: x.sum(dim=1)Reduction(inner_fn=lambdai,j:ops.load(x,[i,j]),reduction_type="sum",ranges=[M,N],# 输入形状reduction_dim=1,# 在第1维上reduceoutput_shape=[M],# 输出形状)# 生成的Triton代码(更复杂)@triton.jitdefkernel(in_ptr0,out_ptr0,M,N,XBLOCK:tl.constexpr,RBLOCK:tl.constexpr,# Reduction block size):pid=tl.program_id(0)xindex=pid*XBLOCK+tl.arange(0,XBLOCK)xmask=xindex<M# 初始化累加器accumulator=tl.zeros([XBLOCK],dtype=tl.float32)# 循环遍历reduction维度forroffsetinrange(0,N,RBLOCK):rindex=roffset+tl.arange(0,RBLOCK)rmask=rindex<N# 计算2D索引: (xindex[:, None], rindex[None, :])# 这会创建一个 XBLOCK x RBLOCK 的矩阵mask=xmask[:,None]&rmask[None,:]# Load: in_ptr0[xindex, rindex]offset=xindex[:,None]*N+rindex[None,:]data=tl.load(in_ptr0+offset,mask,other=0.0)# Reduce: sum along RBLOCK dimensionaccumulator+=tl.sum(data,axis=1)# Store结果tl.store(out_ptr0+xindex,accumulator,xmask)2.7 索引计算的生成
索引计算是代码生成中最复杂的部分。TorchInductor使用符号表达式(sympy)来表示和优化索引。
# torch/_inductor/codegen/triton.pyclassTritonKernel:defcodegen_indexing(self,ranges):""" 生成索引计算代码 ranges: 张量的形状,例如 [B, M, N] 目标:将线性索引 xindex 转换为多维索引 (b, m, n) """# [1] 1D情况(最简单)iflen(ranges)==1:code=""" xindex = xoffset + tl.arange(0, XBLOCK) xmask = xindex < {numel} """.format(numel=ranges[0])returncode# [2] 多维情况# 生成索引分解公式# 例如对于 [B, M, N]:# xindex = b * M * N + m * N + n# => n = xindex % N# m = (xindex // N) % M# b = xindex // (M * N)code=IndentedBuffer()code.writeline("xindex = xoffset + tl.arange(0, XBLOCK)")# 计算总元素数numel=1forsizeinranges:numel*=size code.writeline(f"xmask = xindex <{numel}")# 生成每个维度的索引strides=[]stride=1forsizeinreversed(ranges):strides.insert(0,stride)stride*=sizefordim_idx,(size,stride)inenumerate(zip(ranges,strides)):ifstride==1:# 最内层维度code.writeline(f"x{dim_idx}= xindex %{size}")elifstride==numel//size:# 最外层维度code.writeline(f"x{dim_idx}= xindex //{stride}")else:# 中间维度code.writeline(f"x{dim_idx}= (xindex //{stride}) %{size}")returncode.getvalue()# 示例:3D张量 [2, 3, 4]ranges=[2,3,4]indexing_code=codegen_indexing(ranges)# 生成的代码:""" xindex = xoffset + tl.arange(0, XBLOCK) xmask = xindex < 24 x2 = xindex % 4 # 最内层: n x1 = (xindex // 4) % 3 # 中间层: m x0 = xindex // 12 # 最外层: b """2.8 内存访问模式的优化
# torch/_inductor/codegen/triton.pyclassTritonKernel:defoptimize_memory_access(self,buffer,indices):""" 优化内存访问模式 目标: 1. 合并连续访问(coalesced access) 2. 向量化加载(vectorized load) 3. 避免bank conflict """# [1] 检查最内层维度是否连续ifself.is_contiguous_access(indices):# 使用向量化加载returnself.generate_vectorized_load(buffer,indices)else:# 使用标量加载returnself.generate_scalar_load(buffer,indices)defis_contiguous_access(self,indices):""" 判断访问是否连续 连续访问:最内层索引是线性的 例如:buffer[xindex] 是连续的 buffer[xindex * stride] 可能不连续 """innermost_index=indices[-1]# 检查是否是 xindex 或 xindex + constantreturnself.is_linear_in_xindex(innermost_index)defgenerate_vectorized_load(self,buffer,indices):""" 生成向量化加载 # 标量加载(慢) for i in range(XBLOCK): data[i] = buffer[xindex + i] # 向量化加载(快) data = tl.load(buffer + xindex, mask, other=0.0) """offset=self.compute_offset(indices)tmp_var=self.new_tmp_var()self.loads.writeline(f"{tmp_var}= tl.load({buffer}+{offset}, xmask, other=0.0)")returntmp_var# 实际生成的代码对比# [A] 连续访问(高效)tmp0=tl.load(in_ptr0+xindex,xmask)# Coalesced# [B] 非连续访问(低效)offset=x0*1024+x1# 可能导致不连续tmp0=tl.load(in_ptr0+offset,xmask)# Potentially uncoalesced2.9 GELU示例:复杂算子的完整生成
让我们用一个稍微复杂的例子,完整展示整个流程:
# [0] 用户代码importtorch@torch.compiledeffused_gelu_approximate(x):"""GELU激活函数的近似实现"""# GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))importmath# 常量sqrt_2_over_pi=math.sqrt(2.0/math.pi)# 计算x_cubed=x*x*x inner=sqrt_2_over_pi*(x+0.044715*x_cubed)tanh_inner=torch.tanh(inner)result=0.5*x*(1.0+tanh_inner)returnresult x=torch.randn(1024,device='cuda')y=fused_gelu_approximate(x)[1] FX Graph:
defforward(self,x):mul=x*x# x^2mul_1=mul*x# x^3mul_2=0.044715*mul_1# 0.044715 * x^3add=x+mul_2# x + 0.044715 * x^3mul_3=0.7978845608*add# sqrt(2/π) * (...)tanh=torch.tanh(mul_3)# tanh(...)add_1=1.0+tanh# 1 + tanh(...)mul_4=x*add_1# x * (1 + tanh(...))mul_5=0.5*mul_4# 0.5 * x * (...)returnmul_5[2] 融合后的IR:
所有操作都是Pointwise,可以全部融合!
deffused_inner_fn(idx):x_val=ops.load(x,idx)# x^3x_squared=ops.mul(x_val,x_val)x_cubed=ops.mul(x_squared,x_val)# 0.044715 * x^3term1=ops.mul(ops.constant(0.044715),x_cubed)# x + 0.044715 * x^3sum1=ops.add(x_val,term1)# sqrt(2/π) * (...)scaled=ops.mul(ops.constant(0.7978845608),sum1)# tanh(...)tanh_val=ops.tanh(scaled)# 1 + tanh(...)sum2=ops.add(ops.constant(1.0),tanh_val)# x * (1 + tanh(...))prod1=ops.mul(x_val,sum2)# 0.5 * x * (...)result=ops.mul(ops.constant(0.5),prod1)returnresult[3] 生成的Triton代码:
@triton.jitdeftriton_poi_fused_gelu_0(in_ptr0,# xout_ptr0,# outputxnumel,# 1024XBLOCK:tl.constexpr,):xoffset=tl.program_id(0)*XBLOCK xindex=xoffset+tl.arange(0,XBLOCK)xmask=xindex<xnumel# Loadx0=tl.load(in_ptr0+xindex,xmask)# Compute (全部融合!)tmp0=x0*x0# x^2tmp1=tmp0*x0# x^3tmp2=0.044715*tmp1# 0.044715 * x^3tmp3=x0+tmp2# x + 0.044715 * x^3tmp4=0.7978845608*tmp3# sqrt(2/π) * (...)tmp5=tl.libdevice.tanh(tmp4)# tanh(...) - 调用GPU库函数tmp6=1.0+tmp5# 1 + tanh(...)tmp7=x0*tmp6# x * (1 + tanh(...))tmp8=0.5*tmp7# 0.5 * x * (...)# Storetl.store(out_ptr0+xindex,tmp8,xmask)# 调用配置grid=lambdameta:(triton.cdiv(1024,meta['XBLOCK']),)triton_poi_fused_gelu_0[grid](x,out,1024,XBLOCK=256,num_warps=4,)性能分析:
# 未融合版本(Eager模式)# 9个独立的kernel调用:# 1. mul (x*x)# 2. mul (x^2 * x)# 3. mul (0.044715 * x^3)# 4. add (x + ...)# 5. mul (sqrt(2/π) * ...)# 6. tanh# 7. add (1 + tanh)# 8. mul (x * ...)# 9. mul (0.5 * ...)## 内存访问:9次读 + 9次写 = 18次# Kernel启动:9次 * ~5μs = ~45μs# 融合版本(torch.compile)# 1个融合的kernel## 内存访问:1次读 + 1次写 = 2次# Kernel启动:1次 * ~5μs = ~5μs## 加速比:~9x(理论)2.10 Wrapper代码生成
除了kernel代码,TorchInductor还生成Python wrapper代码来调用kernel:
# torch/_inductor/codegen/wrapper.pyclassWrapperCodegen:defgenerate(self):"""生成完整的Python模块"""code=IndentedBuffer()# [1] 导入code.writeline("import torch")code.writeline("import triton")code.writeline("import triton.language as tl")code.writeline("")# [2] Kernel定义code.splice(self.kernel_code)code.writeline("")# [3] 调用函数code.writeline("def call(args):")code.indent()# [3.1] 解包参数code.writeline("primals_1 = args[0] # x")code.writeline("primals_2 = args[1] # y")code.writeline("")# [3.2] 分配输出缓冲区code.writeline("buf0 = torch.empty_like(primals_1)")code.writeline("")# [3.3] 调用kernelcode.writeline("# Launch kernel")code.writeline("grid = lambda meta: (triton.cdiv(1024, meta['XBLOCK']),)")code.writeline("triton_poi_fused_add_relu_0[grid](")code.indent()code.writeline("primals_1, # in_ptr0")code.writeline("primals_2, # in_ptr1")code.writeline("buf0, # out_ptr0")code.writeline("1024, # xnumel")code.writeline("XBLOCK=256,")code.writeline("num_warps=4,")code.dedent()code.writeline(")")code.writeline("")# [3.4] 返回结果code.writeline("return (buf0,)")code.dedent()returncode.getvalue()生成的完整Python模块:
# /tmp/torchinductor_username/xx/cxxyyyzzz.pyimporttorchimporttritonimporttriton.languageastl@triton.jitdeftriton_poi_fused_add_relu_0(in_ptr0,in_ptr1,out_ptr0,xnumel,XBLOCK:tl.constexpr,):pid=tl.program_id(0)xoffset=pid*XBLOCK xindex=xoffset+tl.arange(0,XBLOCK)xmask=xindex<xnumel tmp0=tl.load(in_ptr0+xindex,xmask)tmp1=tl.load(in_ptr1+xindex,xmask)tmp2=tmp0+tmp1 tmp3=tl.maximum(tmp2,0.0)tl.store(out_ptr0+xindex,tmp3,xmask)defcall(args):"""主调用函数"""# [1] 解包输入primals_1=args[0]# xprimals_2=args[1]# y# [2] 分配输出buf0=torch.empty_like(primals_1)# [3] 调用kernelgrid=lambdameta:(triton.cdiv(1024,meta['XBLOCK']),)triton_poi_fused_add_relu_0[grid](primals_1,primals_2,buf0,1024,XBLOCK=256,num_warps=4,)# [4] 返回return(buf0,)3. Triton Kernel 参数详解
3.1 核心参数概览
Triton Kernel 的性能由以下参数决定:
| 参数 | 含义 | 典型范围 | 影响 |
|---|---|---|---|
| XBLOCK | X 维度的 block size | 16, 32, 64, 128, 256, 512, 1024 | 每个线程块处理的元素数 |
| YBLOCK | Y 维度的 block size | 1, 16, 32, 64 | 2D block 的第二维大小 |
| RBLOCK | Reduction 维度的 block size | 32, 64, 128, 256, 512, 1024 | Reduction 操作的块大小 |
| num_warps | 每个 block 的 warp 数 | 1, 2, 4, 8, 16, 32 | SM 资源占用 |
| num_stages | Pipeline stages | 1, 2, 3, 4 | 内存带宽隐藏 |
| grid | Grid 大小 (x, y, z) | (1, 1, 1) ~ (2^31-1, 65535, 65535) | 并行的线程块数量 |
3.2 参数示例
# 生成的 Triton Kernel 示例@triton.jitdeftriton_poi_fused_add_relu_0(in_ptr0,# 输入指针in_ptr1,# 输入指针out_ptr0,# 输出指针xnumel,# 总元素数XBLOCK:tl.constexpr,# 编译时常量):xoffset=tl.program_id(0)*XBLOCK xindex=xoffset+tl.arange(0,XBLOCK)xmask=xindex<xnumel x0=xindex tmp0=tl.load(in_ptr0+x0,xmask)tmp1=tl.load(in_ptr1+x0,xmask)tmp2=tmp0+tmp1 tmp3=tl.maximum(tmp2,0.0)tl.store(out_ptr0+x0,tmp3,xmask)# 调用时的配置grid=lambdameta:(triton.cdiv(numel,meta['XBLOCK']),)triton_poi_fused_add_relu_0[grid](in_ptr0,in_ptr1,out_ptr0,numel,XBLOCK=1024,# ← 启发式选择num_warps=4,# ← 启发式选择num_stages=1,)3.3 参数之间的关系
XBLOCK 与 num_warps 的关系: ┌─────────────────────────────────────────────┐ │ XBLOCK threads num_warps (推荐) │ ├─────────────────────────────────────────────┤ │ 16-64 ≤64 1 │ │ 128 128 2-4 │ │ 256 256 4 │ │ 512 512 8 │ │ 1024 1024 8-16 (取决于 SM 大小) │ └─────────────────────────────────────────────┘ 每个 warp = 32 threads 每个 block 的 threads = XBLOCK (对于 1D kernel) num_warps ≥ XBLOCK / 324. 启发式优化策略
4.1 TorchInductor 的 Heuristics 系统
# torch/_inductor/codegen/triton.pyclassTritonScheduling:@staticmethoddefget_heuristics():"""获取启发式配置"""returnconfig.triton.heuristicsdefcodegen_kernel(self,name):"""生成 Triton Kernel"""# [1] 计算工作负载numel=self.get_numel()# [2] 选择 XBLOCKxblock=self.get_xblock(numel)# [3] 选择 num_warpsnum_warps=self.get_num_warps(xblock)# [4] 计算 grid sizegrid=self.get_grid(numel,xblock)returnkernel_code# torch/_inductor/config.pyclassTritonConfig:# 启发式配置max_tiles=2# 每个 block 最多处理的 tile 数max_xblock=1024# 最大 XBLOCKmin_xblock=16# 最小 XBLOCK4.2 决策树:选择 XBLOCK
defselect_xblock(numel:int,dtype:torch.dtype)->int:""" 根据元素数量和数据类型选择 XBLOCK 策略: 1. 尽量使用大 XBLOCK 减少 kernel launch 开销 2. 但要保证足够的并行度(grid size) 3. 考虑寄存器和共享内存限制 """# [1] 数据类型对应的字节数elem_size=dtype.itemsize# float32=4, float16=2# [2] 可用的 XBLOCK 选项(2 的幂次)xblock_options=[16,32,64,128,256,512,1024]# [3] 选择策略forxblockinreversed(xblock_options):# 从大到小尝试grid_size=triton.cdiv(numel,xblock)# 条件 1: grid size 不能太小(至少要有足够的并行度)min_grid_size=get_device_capability().multiprocessor_count*4ifgrid_size<min_grid_size:continue# 条件 2: 不超过最大 grid size 限制ifgrid_size>2**31-1:# CUDA grid.x 限制continue# 条件 3: 内存访问效率(避免 bank conflict)ifxblock*elem_size%128==0:# 对齐到 128 字节(缓存行)returnxblock# 默认返回中等大小return256# 实际使用示例numel=1024*1024# 1M 元素xblock=select_xblock(numel,torch.float32)print(f"Selected XBLOCK:{xblock}")# 输出: 1024grid_size=triton.cdiv(numel,xblock)print(f"Grid size:{grid_size}")# 输出: 10244.3 决策树:选择 num_warps
defselect_num_warps(xblock:int,reduction:bool=False)->int:""" 根据 XBLOCK 选择 num_warps 原则: 1. num_warps * 32 ≥ threads_per_block 2. 2 的幂次(硬件友好) 3. 不超过 SM 的最大 warp 数 """# [1] 计算理论 warp 数threads_per_block=xblock min_warps=(threads_per_block+31)//32# [2] 向上取整到 2 的幂次importmath num_warps=2**math.ceil(math.log2(min_warps))# [3] 限制范围num_warps=max(1,min(num_warps,32))# [4] Reduction 操作通常需要更多 warpsifreductionandnum_warps<4:num_warps=4returnnum_warps# 示例print(select_num_warps(1024))# → 32print(select_num_warps(256))# → 8print(select_num_warps(64))# → 25. Grid Size 与循环策略
5.1 Grid Size 的限制
# CUDA Grid Size 限制(CC 3.0+)MAX_GRID_X=2**31-1# ≈ 21 亿MAX_GRID_Y=65535MAX_GRID_Z=65535# 实际硬件还有其他限制:# - SM 数量:A100 有 108 个 SM# - 每个 SM 的最大 block 数:通常 16-32# - 总的活跃 block 数:SM_count * blocks_per_SM5.2 Grid Size 与内部循环的权衡
当数据量非常大时,有两种策略:
策略 A:大 Grid + 小 Block(无内部循环)
# 假设 numel = 100M 元素xblock=256grid_size=triton.cdiv(100_000_000,256)# = 390,625# Kernel 内部:每个 block 只处理 256 个元素@triton.jitdefkernel(ptr,numel,XBLOCK:tl.constexpr):pid=tl.program_id(0)offset=pid*XBLOCK index=offset+tl.arange(0,XBLOCK)mask=index<numel# 只处理一次,无循环data=tl.load(ptr+index,mask)tl.store(ptr+index,data,mask)策略 B:小 Grid + 大 Block(有内部循环)
# 限制 grid_size 到一个合理值MAX_GRID=48# ← 常见的启发式值xblock=256tiles_per_block=triton.cdiv(100_000_000,MAX_GRID*xblock)# ≈ 8,139# Kernel 内部:每个 block 处理多个 tile@triton.jitdefkernel(ptr,numel,XBLOCK:tl.constexpr):pid=tl.program_id(0)tiles=triton.cdiv(numel,XBLOCK)tiles_per_pid=triton.cdiv(tiles,tl.num_programs(0))# 内部循环处理多个 tilefortile_idxinrange(tiles_per_pid):offset=(pid*tiles_per_pid+tile_idx)*XBLOCK index=offset+tl.arange(0,XBLOCK)mask=index<numel data=tl.load(ptr+index,mask)tl.store(ptr+index,data,mask)5.3 为什么限制 Grid Size 为 48?
这是一个常见的启发式值,来源于:
# torch/_inductor/codegen/triton.pydefget_max_grid_size():""" 计算最大 grid size 原因: 1. 避免 kernel launch 开销过大 - 每次 launch 有固定开销(~5μs) - 过大的 grid 会增加调度延迟 2. 保证 SM 利用率 - A100 有 108 个 SM - 假设每个 SM 可以并发运行 4 个 block - 那么 grid = 108 * 4 = 432 就已经饱和 - 但考虑到 warp 调度和 memory latency hiding - 通常设置为 SM_count * 0.5 左右 3. 简化调试 - 小 grid 更容易 profile 和调试 """device=torch.cuda.current_device()sm_count=torch.cuda.get_device_properties(device).multi_processor_count# 启发式:SM 数量的 0.5 倍# A100: 108 * 0.5 = 54# V100: 80 * 0.5 = 40max_grid=max(sm_count//2,1)# 也可以手动配置ifconfig.triton.max_tiles:max_grid=config.triton.max_tiles*sm_countreturnmax_grid5.4 实战:Grid Size 优化
importtorchimporttritonimporttriton.languageastl# [1] 无内部循环版本@triton.jitdefadd_kernel_v1(x_ptr,y_ptr,out_ptr,n_elements,XBLOCK:tl.constexpr):pid=tl.program_id(0)offset=pid*XBLOCK mask=offset+tl.arange(0,XBLOCK)<n_elements x=tl.load(x_ptr+offset+tl.arange(0,XBLOCK),mask)y=tl.load(y_ptr+offset+tl.arange(0,XBLOCK),mask)out=x+y tl.store(out_ptr+offset+tl.arange(0,XBLOCK),out,mask)# [2] 有内部循环版本@triton.jitdefadd_kernel_v2(x_ptr,y_ptr,out_ptr,n_elements,XBLOCK:tl.constexpr):pid=tl.program_id(0)n_tiles=tl.cdiv(n_elements,XBLOCK)tiles_per_pid=tl.cdiv(n_tiles,tl.num_programs(0))foriinrange(tiles_per_pid):offset=(pid*tiles_per_pid+i)*XBLOCK mask=offset+tl.arange(0,XBLOCK)<n_elements x=tl.load(x_ptr+offset+tl.arange(0,XBLOCK),mask)y=tl.load(y_ptr+offset+tl.arange(0,XBLOCK),mask)out=x+y tl.store(out_ptr+offset+tl.arange(0,XBLOCK),out,mask)# 测试defbenchmark_grid_size():n=100_000_000 x=torch.randn(n,device='cuda')y=torch.randn(n,device='cuda')out=torch.empty_like(x)# V1: 大 Gridxblock=1024grid_v1=(triton.cdiv(n,xblock),)print(f"V1 Grid size:{grid_v1[0]}")# 97,657# V2: 小 Grid(限制为 48)max_grid=48grid_v2=(min(triton.cdiv(n,xblock),max_grid),)print(f"V2 Grid size:{grid_v2[0]}")# 48# Benchmarkimporttime# Warmupfor_inrange(10):add_kernel_v1[grid_v1](x,y,out,n,xblock)torch.cuda.synchronize()t0=time.time()for_inrange(100):add_kernel_v1[grid_v1](x,y,out,n,xblock)torch.cuda.synchronize()t1=time.time()print(f"V1 (large grid):{(t1-t0)/100*1000:.3f}ms")# V2for_inrange(10):add_kernel_v2[grid_v2](x,y,out,n,xblock)torch.cuda.synchronize()t0=time.time()for_inrange(100):add_kernel_v2[grid_v2](x,y,out,n,xblock)torch.cuda.synchronize()t1=time.time()print(f"V2 (small grid):{(t1-t0)/100*1000:.3f}ms")# 输出示例(A100):# V1 Grid size: 97657# V2 Grid size: 48# V1 (large grid): 1.234 ms# V2 (small grid): 1.187 ms ← 略快(减少 launch 开销)6. Block Size 与 Num_Warps 配置
6.1 Block Size 选择策略
defheuristic_block_size(numel:int,dtype:torch.dtype,is_reduction:bool=False,)->int:""" 启发式选择 Block Size 考虑因素: 1. 数据类型大小(float32=4, float16=2) 2. 是否是 Reduction 操作 3. 内存带宽利用率 4. 寄存器压力 """elem_size=dtype.itemsizeifis_reduction:# Reduction 需要共享内存,倾向于小 blockcandidates=[128,256,512]else:# Pointwise 可以用大 blockcandidates=[256,512,1024]# 根据数据大小调整forblock_sizeinreversed(candidates):# 每个 block 处理的字节数bytes_per_block=block_size*elem_size# 条件 1: 不超过共享内存限制(通常 48KB)ifbytes_per_block>48*1024:continue# 条件 2: 保证足够的并行度grid_size=triton.cdiv(numel,block_size)ifgrid_size<108:# 至少等于 SM 数量continuereturnblock_sizereturn256# 默认值# 示例print(heuristic_block_size(1000000,torch.float32,False))# → 1024print(heuristic_block_size(1000000,torch.float32,True))# → 512print(heuristic_block_size(1000000,torch.float16,False))# → 10246.2 Num_Warps 与性能的关系
# 实验:不同 num_warps 的性能importtorchimporttritonimporttriton.languageastl@triton.jitdefvector_add(x_ptr,y_ptr,out_ptr,n,XBLOCK:tl.constexpr):pid=tl.program_id(0)offsets=pid*XBLOCK+tl.arange(0,XBLOCK)mask=offsets<n x=tl.load(x_ptr+offsets,mask)y=tl.load(y_ptr+offsets,mask)out=x+y tl.store(out_ptr+offsets,out,mask)defbenchmark_num_warps():n=10_000_000 x=torch.randn(n,device='cuda')y=torch.randn(n,device='cuda')xblock=1024grid=(triton.cdiv(n,xblock),)fornum_warpsin[1,2,4,8,16,32]:out=torch.empty_like(x)# Warmupfor_inrange(10):vector_add[grid](x,y,out,n,XBLOCK=xblock,num_warps=num_warps)# Benchmarktorch.cuda.synchronize()importtime t0=time.time()for_inrange(100):vector_add[grid](x,y,out,n,XBLOCK=xblock,num_warps=num_warps)torch.cuda.synchronize()t1=time.time()print(f"num_warps={num_warps:2d}:{(t1-t0)/100*1000:.3f}ms")# 输出示例(A100):# num_warps= 1: 2.456 ms (寄存器充足,但 warp 调度不够)# num_warps= 2: 1.834 ms# num_warps= 4: 1.523 ms# num_warps= 8: 1.412 ms ← 最优(XBLOCK=1024,理论需要 32 warps)# num_warps=16: 1.398 ms# num_warps=32: 1.405 ms (寄存器压力增大)6.3 TorchInductor 的实际配置
# torch/_inductor/codegen/triton.pyclassTritonKernel:defestimate_kernel_num_warps(self):"""估算最优 num_warps"""# [1] 基于 threads 数量threads_per_block=self.estimate_threads_per_block()min_warps=(threads_per_block+31)//32# [2] 基于寄存器使用(通过 IR 分析)num_registers=self.estimate_register_usage()max_warps_by_regs=65536//num_registers# 假设 64K 寄存器# [3] 基于共享内存使用smem_bytes=self.estimate_smem_usage()max_warps_by_smem=49152//smem_bytes# 假设 48KB 共享内存# [4] 取最小值,并向上取整到 2 的幂次num_warps=min(min_warps,max_warps_by_regs,max_warps_by_smem)num_warps=2**math.ceil(math.log2(max(1,num_warps)))num_warps=min(num_warps,32)# 硬件限制returnnum_warps7. AutoTuning 机制
7.1 什么是 AutoTuning
AutoTuning 是指自动尝试多组参数配置,选择性能最优的一组。
# Triton 的 autotune 装饰器importtriton@triton.autotune(configs=[triton.Config({'XBLOCK':128},num_warps=2),triton.Config({'XBLOCK':256},num_warps=4),triton.Config({'XBLOCK':512},num_warps=8),triton.Config({'XBLOCK':1024},num_warps=16),],key=['n_elements'],# 根据 n_elements 缓存最优配置)@triton.jitdefvector_add_autotuned(x_ptr,y_ptr,out_ptr,n_elements,XBLOCK:tl.constexpr,):pid=tl.program_id(0)offsets=pid*XBLOCK+tl.arange(0,XBLOCK)mask=offsets<n_elements x=tl.load(x_ptr+offsets,mask)y=tl.load(y_ptr+offsets,mask)out=x+y tl.store(out_ptr+offsets,out,mask)# 使用时会自动测试所有配置,选择最快的x=torch.randn(10_000_000,device='cuda')y=torch.randn_like(x)out=torch.empty_like(x)grid=lambdameta:(triton.cdiv(10_000_000,meta['XBLOCK']),)vector_add_autotuned[grid](x,y,out,10_000_000)# 第一次调用会测试所有配置(约 100ms)# 后续调用直接使用缓存的最优配置(约 1ms)7.2 TorchInductor 的 AutoTuning
TorchInductor 默认不开启 autotune(因为编译时间太长),但可以手动启用:
importtorch# 启用 autotunetorch._inductor.config.triton.autotune=Truetorch._inductor.config.triton.autotune_at_compile_time=True@torch.compiledefmodel(x,y):return(x+y).relu()x=torch.randn(1000000,device='cuda')y=torch.randn_like(x)# 第一次调用会 autotune(慢)result=model(x,y)# 后续调用使用缓存配置(快)result=model(x,y)7.3 AutoTuning 的实现原理
# torch/_inductor/codegen/triton.pyclassTritonScheduling:defautotune_kernel(self,kernel_name,kernel_code):""" 自动调优 kernel 参数 流程: 1. 生成候选配置列表 2. 编译所有配置的 kernel 3. 在真实数据上运行,测量时间 4. 选择最快的配置 5. 缓存结果到磁盘 """# [1] 生成候选配置configs=self.generate_autotune_configs()# 示例:[# {'XBLOCK': 128, 'num_warps': 2},# {'XBLOCK': 256, 'num_warps': 4},# {'XBLOCK': 512, 'num_warps': 8},# {'XBLOCK': 1024, 'num_warps': 16},# ]# [2] 编译所有 kernelcompiled_kernels=[]forconfiginconfigs:code=self.apply_config(kernel_code,config)compiled=triton.compile(code)compiled_kernels.append((config,compiled))# [3] Benchmarkbest_time=float('inf')best_config=Noneforconfig,kernelincompiled_kernels:# 运行 10 次取平均times=[]for_inrange(10):torch.cuda.synchronize()t0=time.perf_counter()kernel(*args,**kwargs)torch.cuda.synchronize()t1=time.perf_counter()times.append(t1-t0)avg_time=sum(times)/len(times)ifavg_time<best_time:best_time=avg_time best_config=config# [4] 缓存结果self.cache_autotune_result(kernel_name,best_config)returnbest_config7.4 配置生成策略
defgenerate_autotune_configs():""" 生成 autotune 候选配置 策略: 1. 覆盖常见的 XBLOCK 值 2. 为每个 XBLOCK 选择合适的 num_warps 3. 避免无效配置(如 num_warps 过大) """configs=[]# XBLOCK 候选值(2 的幂次)xblock_options=[128,256,512,1024]forxblockinxblock_options:# 计算理论 warp 数min_warps=(xblock+31)//32# 尝试多个 num_warpsfornum_warpsin[2,4,8,16]:ifnum_warps>=min_warps:configs.append({'XBLOCK':xblock,'num_warps':num_warps,})# 添加一些特殊配置configs.extend([{'XBLOCK':64,'num_warps':1},# 小数据量{'XBLOCK':2048,'num_warps':32},# 大数据量(如果硬件支持)])returnconfigs8. 实战:自定义代码生成策略
8.1 目标:为 GXU 设备定制代码生成
假设 GXU 设备有以下特性:
- 每个 SM 只能并发运行 2 个 block(而不是 CUDA 的 16)
- Grid size 限制为 32(而不是 48)
- 最大 num_warps = 16(而不是 32)
8.2 自定义 Heuristics
# my_backend/gxu_heuristics.pyimporttorchfromtorch._inductor.codegen.tritonimportTritonSchedulingclassGXUTritonHeuristics:"""GXU 设备的 Triton 代码生成启发式"""# 硬件参数MAX_GRID_SIZE=32MAX_NUM_WARPS=16BLOCKS_PER_SM=2@staticmethoddefget_max_grid_size():"""返回最大 grid size"""returnGXUTritonHeuristics.MAX_GRID_SIZE@staticmethoddefselect_xblock(numel:int,dtype:torch.dtype)->int:""" 选择 XBLOCK 策略:倾向于使用内部循环,减少 grid size """elem_size=dtype.itemsize# 目标:grid_size ≤ 32# 那么每个 block 至少处理 numel / 32 个元素min_xblock=(numel+31)//32# 向上取整到 2 的幂次importmath xblock=2**math.ceil(math.log2(max(min_xblock,128)))# 限制范围xblock=min(xblock,1024)# 硬件限制xblock=max(xblock,128)# 最小值returnxblock@staticmethoddefselect_num_warps(xblock:int)->int:"""选择 num_warps(限制为 16)"""threads=xblock min_warps=(threads+31)//32# 向上取整到 2 的幂次importmath num_warps=2**math.ceil(math.log2(min_warps))# GXU 限制num_warps=min(num_warps,GXUTritonHeuristics.MAX_NUM_WARPS)returnnum_warps@staticmethoddefget_grid(numel:int,xblock:int):""" 计算 grid size(使用内部循环) """# 总的 tile 数total_tiles=(numel+xblock-1)//xblock# 限制 grid sizegrid_size=min(total_tiles,GXUTritonHeuristics.MAX_GRID_SIZE)return(grid_size,)8.3 注入到 TorchInductor
# my_backend/gxu_inductor.pyfromtorch._inductor.codegenimporttritonasinductor_tritonfrom.gxu_heuristicsimportGXUTritonHeuristicsdefpatch_inductor_for_gxu():"""将 GXU heuristics 注入 TorchInductor"""# 保存原始方法original_get_grid=inductor_triton.TritonScheduling.get_grid# 重写 get_grid 方法defgxu_get_grid(self,numel,xblock):# 检查设备类型ifself.device.type=='gxu':returnGXUTritonHeuristics.get_grid(numel,xblock)else:returnoriginal_get_grid(self,numel,xblock)inductor_triton.TritonScheduling.get_grid=gxu_get_grid# 类似地重写其他方法# ...# 在模块加载时自动 patchpatch_inductor_for_gxu()8.4 生成带内部循环的 Kernel
# torch/_inductor/codegen/triton.py (修改后)classTritonScheduling:defcodegen_kernel(self,name):"""生成 Triton kernel 代码"""numel=self.get_numel()dtype=self.get_dtype()# [1] 选择 XBLOCKxblock=GXUTritonHeuristics.select_xblock(numel,dtype)# [2] 选择 num_warpsnum_warps=GXUTritonHeuristics.select_num_warps(xblock)# [3] 计算 gridgrid=GXUTritonHeuristics.get_grid(numel,xblock)# [4] 生成 kernel 代码(带内部循环)kernel_code=f""" @triton.jit def{name}( in_ptr0, out_ptr0, numel, XBLOCK: tl.constexpr, ): pid = tl.program_id(0) # 计算每个 block 处理的 tile 数 n_tiles = tl.cdiv(numel, XBLOCK) tiles_per_pid = tl.cdiv(n_tiles, tl.num_programs(0)) # 内部循环 for tile_idx in range(tiles_per_pid): offset = (pid * tiles_per_pid + tile_idx) * XBLOCK index = offset + tl.arange(0, XBLOCK) mask = index < numel # Load x = tl.load(in_ptr0 + index, mask) # Compute y = x * 2.0 # Store tl.store(out_ptr0 + index, y, mask) # 调用 grid = ({grid[0]},){name}[grid]( in_ptr0, out_ptr0, numel, XBLOCK={xblock}, num_warps={num_warps}, ) """returnkernel_code8.5 完整示例
# test_gxu_codegen.pyimporttorchimporttorch._dynamoasdynamofrommy_backend.gxu_inductorimportpatch_inductor_for_gxu# [1] Patch TorchInductorpatch_inductor_for_gxu()# [2] 定义模型@torch.compiledefmy_model(x):return(x*2).relu()# [3] 准备数据(使用 PrivateUse1 作为 GXU 设备)torch.utils.rename_privateuse1_backend("gxu")x=torch.randn(100_000_000,device='gxu')# [4] 运行(会使用 GXU heuristics 生成代码)y=my_model(x)# [5] 查看生成的代码print(torch._dynamo.utils.compile_times())print(torch._inductor.utils.get_last_generated_code())生成的代码示例:
# 自动生成的 Triton kernel@triton.jitdeftriton_poi_fused_mul_relu_0(in_ptr0,out_ptr0,numel,XBLOCK:tl.constexpr,):pid=tl.program_id(0)# 每个 block 处理多个 tile(因为 grid 限制为 32)n_tiles=tl.cdiv(numel,XBLOCK)# 100M / 1024 ≈ 97657tiles_per_pid=tl.cdiv(n_tiles,tl.num_programs(0))# 97657 / 32 ≈ 3052fortile_idxinrange(tiles_per_pid):# ← 内部循环 3052 次offset=(pid*tiles_per_pid+tile_idx)*XBLOCK index=offset+tl.arange(0,XBLOCK)mask=index<numel x=tl.load(in_ptr0+index,mask)y=x*2.0y=tl.maximum(y,0.0)tl.store(out_ptr0+index,y,mask)# 调用配置grid=(32,)# ← 限制为 32triton_poi_fused_mul_relu_0[grid](in_ptr0,out_ptr0,100_000_000,XBLOCK=1024,num_warps=16,# ← 限制为 16)9. 性能分析与调优
9.1 使用 Nsight Compute 分析 Kernel
# 安装 NVIDIA Nsight Compute# https://developer.nvidia.com/nsight-compute# 运行 profilingncu --set full -o profile_output python test_script.py# 查看结果ncu-ui profile_output.ncu-rep关键指标:
- Occupancy(占用率): 实际活跃 warp 数 / 理论最大 warp 数
- 目标:> 50%
- 过低:增加 num_warps 或减少寄存器使用
- Memory Throughput(内存吞吐): 实际带宽 / 峰值带宽
- 目标:> 80%(对于 memory-bound kernel)
- SM Efficiency(SM 效率): SM 忙碌时间 / 总时间
- 目标:> 90%
9.2 使用 Triton Profiler
importtorchimporttriton# 启用 profilingtriton.Config.enable_debug=True@torch.compiledefmodel(x):returnx.relu()x=torch.randn(1000000,device='cuda')# 运行withtorch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA],record_shapes=True,)asprof:y=model(x)# 打印结果print(prof.key_averages().table(sort_by="cuda_time_total",row_limit=10))9.3 调优 Checklist
| 问题 | 症状 | 解决方案 |
|---|---|---|
| Grid 过大 | Kernel launch 开销高 | 使用内部循环,限制 grid ≤ 48 |
| Block 过小 | SM 占用率低 | 增加 XBLOCK,减少 grid |
| Warp 不足 | 无法隐藏内存延迟 | 增加 num_warps |
| Warp 过多 | 寄存器溢出 | 减少 num_warps,简化计算 |
| 内存不对齐 | 带宽利用率低 | 调整 XBLOCK 使其为 128 字节倍数 |
| Bank Conflict | 共享内存访问慢 | 使用 padding 或调整访问模式 |
10. 常见问题
Q1: 为什么我的 kernel 性能不如手写的 CUDA?
A:可能原因:
未开启 autotune:TorchInductor 默认使用固定启发式
torch._inductor.config.triton.autotune=TrueFusion 不够激进:手动调整 fusion 策略
torch._inductor.config.max_fusion_size=64内存布局不optimal:检查是否有不必要的 transpose
# 查看生成的代码torch._inductor.config.debug=True
Q2: Grid size 为什么限制为 48?
A:这是一个保守的启发式值:
- A100 有 108 个 SM,48 约为 0.5 倍
- 避免 kernel launch 开销过大
- 平衡调度延迟和并行度
可以修改:
torch._inductor.config.triton.max_tiles=2# 每个 SM 2 个 tile# 实际 grid = SM_count * max_tilesQ3: 如何为新硬件添加自定义 heuristics?
A:三个步骤:
- 实现
DeviceOpOverrides类(参考第 8 章) - 实现自定义 heuristics 函数
- Patch
TritonScheduling的相关方法
fromtorch._inductor.codegenimporttritonasinductor_tritonclassMyDeviceHeuristics:@staticmethoddefselect_xblock(numel,dtype):# 自定义逻辑return512@staticmethoddefselect_num_warps(xblock):return8# Patchoriginal_select_xblock=inductor_triton.TritonScheduling.select_xblockdefpatched_select_xblock(self,numel,dtype):ifself.device.type=='mydevice':returnMyDeviceHeuristics.select_xblock(numel,dtype)returnoriginal_select_xblock(self,numel,dtype)inductor_triton.TritonScheduling.select_xblock=patched_select_xblockQ4: 内部循环会降低性能吗?
A:不一定:
- 优点:减少 kernel launch 开销,更好的指令流水线
- 缺点:增加寄存器压力,可能降低占用率
实践中:
- 如果
grid_size > 10000,使用内部循环通常更快 - 如果
grid_size < 1000,直接大 grid 可能更快 - 需要根据实际硬件 benchmark
Q5: 如何调试生成的 Triton 代码?
A:
# [1] 打印生成的代码torch._inductor.config.debug=Truetorch._inductor.config.trace.enabled=True@torch.compiledefmodel(x):returnx.relu()x=torch.randn(100,device='cuda')y=model(x)# [2] 查看日志# 会打印到 stdout,包含完整的 Triton 代码# [3] 保存到文件torch._inductor.config.trace.log_output_code=Truetorch._inductor.config.trace.log_dir="./inductor_logs"# [4] 使用 Triton 的调试工具importtriton triton.Config.enable_debug=True11. 总结
11.1 关键要点
代码生成流程
- FX Graph → Lowering → Fusion → Scheduling → Code Emission → Compilation
Triton Kernel 参数
XBLOCK: 影响并行度和内存访问num_warps: 影响 SM 占用率和寄存器使用grid: 影响 kernel launch 开销和内部循环
启发式策略
- Grid size: 通常限制为 SM_count * 0.5(如 48)
- Block size: 根据数据类型和操作类型选择(256-1024)
- Num_warps: 根据 threads 数量和寄存器压力选择(1-32)
性能优化
- 使用 AutoTuning 自动寻找最优配置
- 针对特定硬件定制 heuristics
- 使用 profiler 分析瓶颈
11.2 最佳实践
# ✅ 推荐:为自定义设备提供完整的 heuristicsclassMyDeviceHeuristics:MAX_GRID=32MAX_WARPS=16@staticmethoddefselect_xblock(numel,dtype):# 确保 grid 不超过限制min_xblock=(numel+MyDeviceHeuristics.MAX_GRID-1)//MyDeviceHeuristics.MAX_GRID xblock=2**math.ceil(math.log2(max(min_xblock,128)))returnmin(xblock,1024)@staticmethoddefselect_num_warps(xblock):num_warps=2**math.ceil(math.log2((xblock+31)//32))returnmin(num_warps,MyDeviceHeuristics.MAX_WARPS)@staticmethoddefget_grid(numel,xblock):tiles=(numel+xblock-1)//xblockreturn(min(tiles,MyDeviceHeuristics.MAX_GRID),)# ✅ 推荐:启用 debug 模式查看生成的代码torch._inductor.config.debug=True# ✅ 推荐:对性能关键的 kernel 使用 autotunetorch._inductor.config.triton.autotune=True# ❌ 避免:盲目增大 XBLOCK# 原因:可能导致 grid 过小,并行度不足# ❌ 避免:使用过大的 num_warps# 原因:会导致寄存器溢出,降低占用率11.3 参考资料
- Triton Language Reference
- CUDA C Programming Guide - Occupancy
- TorchInductor Source Code
- Nsight Compute User Guide
下一章预告:第 11 章将讨论内存优化与数据布局,包括 memory planning、buffer reuse、以及如何减少中间张量的内存占用。
© 2024 LLM-BOOK. 本文档仅供学习参考使用。