CANN/xla-npu GatherV2Op实现总结
GatherV2Op 实现总结【免费下载链接】xla-npuXLA-NPU 是一个面向华为昇腾NPU硬件的 XLA后端实现。本项目通过接入OpenXLA/XLA开源项目将XLA开源生态与华为 CANN软件栈集成对接JAX框架。JAX框架运行时可以直接加载XLA-NPU使得基于JAX框架开发的模型可以运行在昇腾NPU上提供推理场景图编译加速能力。项目地址: https://gitcode.com/cann/xla-npu完成的工作1. 在 mair_ops.td 中添加 GatherV2Op 定义文件mair_ops.tddef Air_GatherV2Op : Air_OpGatherV2, [Pure] { let summary GatherV2 operation; let description [{ Gathers slices from params axis according to indices. Similar to TensorFlows tf.gather and PyTorchs torch.gather. }]; let arguments (ins Air_Tensor:$x, Air_Tensor:$indices, Air_Tensor:$axis, OptionalAir_Tensor:$batch_dims ); let results (outs Air_Tensor); }参数说明x: 输入张量indices: 索引张量axis: 指定收集的轴标量张量batch_dims: 可选的 batch 维度数2. 实现 ConvertGatherOp 转换逻辑文件mair_passes.cc转换策略针对简单的单轴 gather 情况满足以下条件start_index_map只有一个元素collapsed_slice_dims只有一个元素collapsed_slice_dims[0] start_index_map[0]slice_sizes在 collapsed 维度上为 1转换步骤提取维度信息auto offsetDims dimensionNumbers.getOffsetDims(); auto collapsedSliceDims dimensionNumbers.getCollapsedSliceDims(); auto startIndexMap dimensionNumbers.getStartIndexMap(); auto indexVectorDim dimensionNumbers.getIndexVectorDim();创建 axis 常量int64_t axis startIndexMap[0]; auto axisConst rewriter.createConstantOp(op.getLoc(), axisType, axisAttr);处理 indices可能需要 squeeze// 如果 index_vector_dim 是最后一个维度且大小为 1则 squeeze if (indexVectorDim indicesRank - 1 indicesShape.back() 1) { // Squeeze 最后一维 indices rewriter.createReshapeOp(...); }创建 GatherV2Opauto gatherResult rewriter.createGatherV2Op( op.getLoc(), op.getType(), operand, indices, axisConst.getResult(), nullptr);调试日志打印转换前的 MLIR打印维度信息operand shape, indices shape, offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim, slice_sizes打印转换后的 MLIR3. 在 export_graphdef.cc 中添加 GatherV2 导出文件export_graphdef.cc添加的映射{GatherV2, {x, indices, axis, batch_dims}},说明MLIR Op 到 GE Op 的转换是通用的通过processOperationInputs函数使用inputNameMap来映射输入名称添加这个映射后GatherV2Op 会自动被正确导出为 Ascend 的 GatherV2 操作转换示例用户的例子输入%13 stablehlo.gather(%arg0, %5) { dimension_numbers #stablehlo.gather offset_dims [2], collapsed_slice_dims [0], start_index_map [0], index_vector_dim 2 , slice_sizes arrayi64: 1, 896 } : (tensor151936x896xf32, tensor1x8x1xi32) - tensor1x8x896xf32转换步骤分析维度信息operand:[151936, 896]indices:[1, 8, 1]start_index_map:[0]→ axis0collapsed_slice_dims:[0]index_vector_dim: 2最后一个维度slice_sizes:[1, 896]Squeeze indices[1, 8, 1]→[1, 8]去掉最后一个维度创建 GatherV2%axis mair.Const() {value 0 : i64} : () - tensori64 %result mair.GatherV2(%arg0, %indices_squeezed, %axis) : (tensor151936x896xf32, tensor1x8xi32, tensori64) - tensor1x8x896xf32语义解释# 伪代码 result zeros([1, 8, 896]) for i in range(1): for j in range(8): index indices[i, j] # 标量索引 result[i, j, :] operand[index, :] # 收集切片 [896]限制和未来工作当前限制只支持简单的单轴 gatherstart_index_map只有一个元素collapsed_slice_dims只有一个元素不支持多轴 gather不支持复杂的 gather 场景多个 batch 维度多个 collapsed 维度复杂的维度映射未来工作支持更复杂的 gather 场景使用 GatherNd 处理多轴 gather组合多个操作处理复杂的维度映射优化性能减少 Reshape 操作直接使用更高效的 Ascend 操作添加更多测试用例多轴 gather不同维度组合边界情况测试建议建议创建以下测试用例简单的单轴 gather当前支持stablehlo.gather with start_index_map [0], collapsed_slice_dims [0]不同轴的 gatherstablehlo.gather with start_index_map [1], collapsed_slice_dims [1]多索引维度stablehlo.gather with indices shape [batch, M, N, 1]边界情况空索引单元素索引大批量索引总结成功实现了 GatherV2Op 的定义和转换逻辑支持简单的单轴 gather 场景。对于用户的例子转换逻辑可以正确处理✓ 识别单轴 gather✓ 创建 axis 常量✓ Squeeze indices✓ 创建 GatherV2Op✓ 添加详细的调试日志对于更复杂的 gather 场景需要进一步扩展实现。【免费下载链接】xla-npuXLA-NPU 是一个面向华为昇腾NPU硬件的 XLA后端实现。本项目通过接入OpenXLA/XLA开源项目将XLA开源生态与华为 CANN软件栈集成对接JAX框架。JAX框架运行时可以直接加载XLA-NPU使得基于JAX框架开发的模型可以运行在昇腾NPU上提供推理场景图编译加速能力。项目地址: https://gitcode.com/cann/xla-npu创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考