CANN/cann-bench UnsortedSegmentSum 算子 API 描述
UnsortedSegmentSum 算子 API 描述【免费下载链接】cann-bench评测AI在处理CANN领域代码任务的能力涵盖算子生成、算子优化等领域支撑模型选型、训练效果评估统一量化评估标准识别Agent能力短板构建CANN领域评测平台推动AI能力在CANN领域的持续演进。项目地址: https://gitcode.com/cann/cann-bench1. 算子简介沿 segment_ids 指定的段对数据进行求和。主要应用场景图神经网络中的节点特征聚合按邻居分段求和点云处理中的体素化聚合稀疏特征的按组求和与池化嵌入表梯度的按 ID 累加算子特征难度等级L2ScatterUpdate双输入单输出根据 segment_ids 将 data 中的元素按段分组求和2. 算子定义数学公式$$ y[i] \sum_{j: \text{segment_ids}[j] i} \text{data}[j] $$对于每个段 $i \in [0, \text{num_segments})$将所有 segment_ids 等于 $i$ 的 data 元素在第 0 维上求和。若某段没有对应的元素则输出为零。3. 接口规范算子原型cann_bench.unsorted_segment_sum(Tensor data, Tensor segment_ids, int num_segments) - Tensor y输入参数说明参数类型默认值描述dataTensor必选输入数据张量segment_idsTensor必选段 ID 张量值在 [0, num_segments) 范围内num_segmentsint必选段数量输出参数Shapedtype描述y(num_segments, *data.shape[1:])与 data 相同输出张量段求和结果数据类型data dtypesegment_ids dtype输出 dtypefloat16int32 / int64float16float32int32 / int64float32int32int32 / int64int32int64int32 / int64int64规则与约束segment_ids 的形状必须与 data 的第 0 维大小一致或与 data 形状完全一致多维场景segment_ids 中的值必须在 [0, num_segments) 范围内输出的第 0 维大小为 num_segments其余维度与 data 的后续维度一致若某个段 ID 在 segment_ids 中未出现对应输出段为全零segment_ids 的 dtype 必须为 int32 或 int64num_segments 必须为正整数4. 精度要求采用生态算子精度标准进行验证。误差指标平均相对误差MERE采样点中相对误差平均值$$ \text{MERE} \text{avg}(\frac{\text{abs}(actual - golden)}{\text{abs}(golden)\text{1e-7}}) $$最大相对误差MARE采样点中相对误差最大值$$ \text{MARE} \max(\frac{\text{abs}(actual - golden)}{\text{abs}(golden)\text{1e-7}}) $$通过标准数据类型FLOAT16BFLOAT16FLOAT32HiFLOAT32FLOAT8 E4M3FLOAT8 E5M2通过阈值(Threshold)2^-102^-72^-132^-112^-32^-2当平均相对误差 MERE Threshold最大相对误差 MARE 10 * Threshold 时判定为通过。5. 标准 Golden 代码import torch UnsortedSegmentSum算子Torch Golden参考实现 沿segment_ids指定的段对数据进行求和 公式: y[i] sum(data[j]) where segment_ids[j] i def unsorted_segment_sum( data: torch.Tensor, segment_ids: torch.Tensor, num_segments: int ) - torch.Tensor: 沿segment_ids指定的段对数据进行求和 公式: y[i] sum(data[j]) where segment_ids[j] i 对于 FP16/BF16 输入使用 FP32 进行内部累加以保证精度 其他类型保持原样 Args: data: 输入数据张量 segment_ids: 段ID张量 num_segments: 段数量 Returns: 输出张量段求和结果 output_shape (num_segments,) data.shape[1:] # FP16/BF16 输入升精度到 FP32 进行累加以保证精度 if data.dtype in (torch.float16, torch.bfloat16): y_fp32 torch.zeros(output_shape, dtypetorch.float32, devicedata.device) data_fp32 data.to(torch.float32) y_fp32.index_add_(0, segment_ids, data_fp32) y y_fp32.to(data.dtype) else: y torch.zeros(output_shape, dtypedata.dtype, devicedata.device) y.index_add_(0, segment_ids, data) return y6. 额外信息算子调用示例import torch import cann_bench data torch.randn(1048576, dtypetorch.float16, devicenpu) segment_ids torch.randint(0, 1024, (1048576,), dtypetorch.int32, devicenpu) y cann_bench.unsorted_segment_sum(data, segment_ids, num_segments1024) # 2D 数据按段求和 data torch.randn(1024, 1024, dtypetorch.float32, devicenpu) segment_ids torch.randint(0, 256, (1024,), dtypetorch.int32, devicenpu) y cann_bench.unsorted_segment_sum(data, segment_ids, num_segments256) # int32 数据类型 data torch.randint(-1000, 1000, (2048, 512), dtypetorch.int32, devicenpu) segment_ids torch.randint(0, 512, (2048,), dtypetorch.int32, devicenpu) y cann_bench.unsorted_segment_sum(data, segment_ids, num_segments512)【免费下载链接】cann-bench评测AI在处理CANN领域代码任务的能力涵盖算子生成、算子优化等领域支撑模型选型、训练效果评估统一量化评估标准识别Agent能力短板构建CANN领域评测平台推动AI能力在CANN领域的持续演进。项目地址: https://gitcode.com/cann/cann-bench创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考