pypto.experimental.transposed_batchmatmul
【免费下载链接】pyptoPyPTO(发音: pai p-t-o):Parallel Tensor/Tile Operation编程范式。项目地址: https://gitcode.com/cann/pypto
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
功能说明
该接口为定制接口,约束较多。不保证稳定性。
该算子执行转置批量矩阵乘法。具体操作为:
- 将输入张量
tensor_a从形状 (M, B, K) 转置为 (B, M, K)。 - 执行批量矩阵乘法,将转置后的
tensor_a(B, M, K) 与tensor_b(B, K, N) 相乘,得到中间结果 (B, M, N)。 - 将中间结果转置回形状 (M, B, N) 作为最终输出。
函数原型
transposed_batchmatmul(tensor_a: Tensor, tensor_b: Tensor, out_dtype: dtype) -> Tensor参数说明
| 参数名 | 输入/输出 | 说明 |
|---|---|---|
| tensor_a | 输入 | 左侧输入张量。 支持的数据类型为:DT_FP16, DT_BF16。 不支持空Tensor,支持三维。 形状必须为 (M, B, K)。 |
| tensor_b | 输入 | 右侧输入张量。 支持的数据类型为:DT_FP16, DT_BF16。 不支持空Tensor,支持三维。 形状必须为 (B, K, N)。 |
| out_dtype | 输入 | 输出张量的数据类型。 支持的数据类型为:DT_FP16, DT_BF16。 |
返回值说明
返回输出 Tensor,Tensor 的数据类型由out_dtype指定,形状为 (M, B, N)。
调用示例
import pypto # 创建输入张量 a = pypto.tensor((16, 2, 32), pypto.DT_FP16, "tensor_a") b = pypto.tensor((2, 32, 64), pypto.DT_FP16, "tensor_b") # 调用算子 c = pypto.experimental.transposed_batchmatmul(a, b, pypto.DT_FP16) # 输出张量 c 的形状为 (16, 2, 64)【免费下载链接】pyptoPyPTO(发音: pai p-t-o):Parallel Tensor/Tile Operation编程范式。项目地址: https://gitcode.com/cann/pypto
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考