1、问题背景
torch_npu.fused_linear_online_max_sum接口在批量执行多个用例的过程中,存在几条用例精度偶现不达标的现象,且每次失败的用例位置可能不一致,接口的多个输出中仅predicted_logits_local输出参数存在精度问题(有精度问题的输出值数值较大,且与cpu对比精度相对误差几十~几千倍不等),需要这一概率性精度失败的原因。
2、定位过程
由于在批跑用例时是存在概率性精度问题,导致这种现象最有可能的因素就是内存踩踏或者同步问题,因此从这两个方面进行排查。
2.1、注释掉其他输出的写入,排除内存踩踏影响
首先是被其他参数的内存踩踏问题,算子本身存在多个输出tensor,有可能predicted_logits_local输出tensor所占的空间被其他输出数据踩踏导致数值输出不正确。笔者通过将其他输出的写入(即搬出到GM)注释掉,保证只有predicted_logits_local进行搬出,同时依据指令排查对应的地址、拷贝元素,判定指令无越界可能。但经过测试,问题依然存在,非内存踩踏导致。
2.2、predicted_logits_local的计算逻辑
进一步排查是否为同步问题,笔者将predicted_logits_local所涉及到的所有操作包含搬入、搬出、计算过程简化了一下,计算逻辑主要有下面红框中从上往下①~④四个部分,其中绿框部分①InitOutputAndWorkspace、③CVProcess、④AllReduceProces过程都涉及到了拷贝结果到predicted_logits_local上的操作,我们根据这一过程分步进行定位。
2.2.1、初始化的尝试
从①InitOutputAndWorkspace初始化部分尝试定位
(1)仅保留predicted_logits_local初始化0的操作,注释掉InitOutputAndWorkspace后续predicted_logits_local的搬出操作,此时批跑predicted_logits_local结果都是0符合预期。
(2)初始化时不初始化为0,而是将predicted_logits_local初始化为shape大小的固定元素值(使多个用例初始化成不同的固定值,便于定位是否有异常),同样注释掉InitOutputAndWorkspace后续predicted_logits_local的搬出操作,批跑后结果都为固定元素值,也符合预期。
基于上述两点排除初始化的问题:初始化值正确,并且也和2.1相互印证确实没有被其他参数搬出过程踩踏。
2.2.2、CV融合计算过程中的尝试
(1)注释掉④AllReduceProces中的predicted_logits_local搬出操作,即predicted_logits_local仅输出③CVProcess中的临时计算结果,在输入数据相同情况下多次批跑后仍存在精度不一致的现象,确认问题由③CVProcess这一部分引入。
(2)观察③CVProcess中关于predicted_logits_local的计算,依赖于②TargetProcess中的计算结果maskedTarget和③CVProcess中的Cube计算结果vocabParallelLogitsOutOptional,(1)中已经判断出③CVProcess计算后predicted_logits_local搬出结果存在问题,因此首先怀疑是③中predicted_logits_local计算过程中可能存在同步问题。
笔者尝试在predicted_logits_local计算过程中的所有指令间都插入PipeBarrier<PIPE_ALL>(),发现依然存在概率性精度不达标现象,因此可以排除此处同步问题。
(3)在(2)的基础上,笔者注释掉predicted_logits_local计算逻辑,直接将其依赖参数vocabParallelLogitsOutOptional、maskedTarget分别拷贝到predicted_logits_local中输出,保持相同输入数据多次批跑,发现将maskedTarget拷贝出predicted_logits_local的时候,predicted_logits_local出现前后精度不一致的的现象,这说明predicted_logits_local计算时获取到的maskedTarget数据问题。
maskedTarget自身最终输出没有精度问题,但在predicted_logits_local计算时获取到的maskedTarget数值不正确,基于这一现象,问题确认为是predicted_logits_local计算时maskedTarget数据未计算完成导致计算错误。由于maskedTarget数据是在②TargetProcess中vec核上计算得到,而③CVProcess中predicted_logits_local数据计算涉及matmul过程,而matmul使用到了cube核计算,那么问题基本就缩小到了核间未同步导致。在vec核cube计算之间插入syncall(false)核间同步,再次批跑,概率性精度失败问题消失。
3、问题根因
predicted_logits_local依赖maskedTarget结算结果,maskedTarget在vec核上还未完全计算完情况下便使用cube核进行了相关matmul计算,导致predicted_logits_local计算过程中读取到了脏数据,每次批跑用例出现概率性失败的现象发生。
4、解决方案
在TargetProcess和CVProcess中间插入全核同步,保证maskedTarget计算完成后再进行predicted_logits_local相关计算。重新批跑执行用例,此时结果都pass。
5、经验总结
(1)同步问题不仅存在单核内计算过程中,当kernel计算过程中涉及到cube和vector的协同(此次精度失败的点)或者多个核之间数据依赖(通常在初始化操作或者某个核计算的结果需要提供给其他核使用)时,也需要关注核与核之间的同步。
(2)对于概率性失败问题,可以罗列出问题参数涉及的运算过程,从初始化到结果搬出过程逐步注释掉后续的一部分计算过程定位到异常产生的地方,针对异常的发生的位置进行数据依赖、数据空间复用、数据指针越界等检查,确保数据计算前置依赖已经完成,空间之间不会相互踩踏,指针不会越界影响到其他参数。