CANN asc-devkit矩阵乘加API
asc_mmad
【免费下载链接】asc-devkit：本项目是CANN推出的昇腾AI处理器专用的算子程序开发语言，原生支持C和C++标准规范，主要由类库和语言扩展层构成，提供多层级API，满足多维场景算子开发诉求。项目地址: https://gitcode.com/cann/asc-devkit
产品支持情况
产品 | 是否支持
Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √
Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √
Ascend 950PR/Ascend 950DT | √
功能说明
完成矩阵乘加操作。计算公式如下：
$$ c_{matrix} = (a_{matrix} * b_{matrix}) + c_{matrix} $$
函数原型
常规计算
__aicore__ inline void asc_mmad_s4(__cc__ int32_t* c_matrix, __ca__ int4b_t* a_matrix, __cb__ int4b_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool k_direction_align, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad_s4(__cc__ int32_t* c_matrix, __ca__ int4b_t* a_matrix, __cb__ int4b_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t feat_offset, uint8_t smask_offset, uint8_t unit_flag, bool k_direction_align, bool is_weight_offset, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad(__cc__ int32_t* c_matrix, __ca__ int8_t* a_matrix, __cb__ int8_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool k_direction_align, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad(__cc__ int32_t* c_matrix, __ca__ int8_t* a_matrix, __cb__ int8_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t feat_offset, uint8_t smask_offset, uint8_t unit_flag, bool k_direction_align, bool is_weight_offset, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad(__cc__ float* c_matrix, __ca__ half* a_matrix, __cb__ half* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool k_direction_align, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad(__cc__ float* c_matrix, __ca__ half* a_matrix, __cb__ half* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t feat_offset, uint8_t smask_offset, uint8_t unit_flag, bool k_direction_align, 
bool is_weight_offset, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad(__cc__ float* c_matrix, __ca__ float* a_matrix, __cb__ float* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool k_direction_align, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad(__cc__ float* c_matrix, __ca__ float* a_matrix, __cb__ float* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t feat_offset, uint8_t smask_offset, uint8_t unit_flag, bool k_direction_align, bool is_weight_offset, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad(__cc__ float* c_matrix, __ca__ bfloat16_t* a_matrix, __cb__ bfloat16_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool k_direction_align, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad(__cc__ float* c_matrix, __ca__ bfloat16_t* a_matrix, __cb__ bfloat16_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t feat_offset, uint8_t smask_offset, uint8_t unit_flag, bool k_direction_align, bool is_weight_offset, bool c_matrix_source, bool c_matrix_init_val);// 如下原型仅支持Ascend 950PR/Ascend 950DT __aicore__ inline void asc_mmad(__cc__ float* c_matrix, __ca__ bfloat16_t* a_matrix, __cb__ bfloat16_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad(__cc__ float* c_matrix, __ca__ float8_e4m3_t* a_matrix, __cb__ float8_e4m3_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad(__cc__ float* c_matrix, __ca__ float8_e4m3_t* a_matrix, __cb__ float8_e5m2_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, 
bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad(__cc__ float* c_matrix, __ca__ float8_e5m2_t* a_matrix, __cb__ float8_e4m3_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad(__cc__ float* c_matrix, __ca__ float8_e5m2_t* a_matrix, __cb__ float8_e5m2_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad(__cc__ float* c_matrix, __ca__ half* a_matrix, __cb__ half* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad(__cc__ float* c_matrix, __ca__ float* a_matrix, __cb__ float* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad(__cc__ int32_t* c_matrix, __ca__ int8_t* a_matrix, __cb__ int8_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val);同步计算__aicore__ inline void asc_mmad_s4_sync(__cc__ int32_t* c_matrix, __ca__ int4b_t* a_matrix, __cb__ int4b_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool k_direction_align, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad_s4_sync(__cc__ int32_t* c_matrix, __ca__ int4b_t* a_matrix, __cb__ int4b_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t feat_offset, uint8_t smask_offset, uint8_t unit_flag, bool k_direction_align, bool is_weight_offset, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad_sync(__cc__ 
int32_t* c_matrix, __ca__ int8_t* a_matrix, __cb__ int8_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool k_direction_align, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad_sync(__cc__ int32_t* c_matrix, __ca__ int8_t* a_matrix, __cb__ int8_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t feat_offset, uint8_t smask_offset, uint8_t unit_flag, bool k_direction_align, bool is_weight_offset, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad_sync(__cc__ float* c_matrix, __ca__ half* a_matrix, __cb__ half* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool k_direction_align, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad_sync(__cc__ float* c_matrix, __ca__ half* a_matrix, __cb__ half* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t feat_offset, uint8_t smask_offset, uint8_t unit_flag, bool k_direction_align, bool is_weight_offset, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad_sync(__cc__ float* c_matrix, __ca__ float* a_matrix, __cb__ float* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool k_direction_align, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad_sync(__cc__ float* c_matrix, __ca__ float* a_matrix, __cb__ float* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t feat_offset, uint8_t smask_offset, uint8_t unit_flag, bool k_direction_align, bool is_weight_offset, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad_sync(__cc__ float* c_matrix, __ca__ bfloat16_t* a_matrix, __cb__ bfloat16_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool k_direction_align, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline 
void asc_mmad_sync(__cc__ float* c_matrix, __ca__ bfloat16_t* a_matrix, __cb__ bfloat16_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t feat_offset, uint8_t smask_offset, uint8_t unit_flag, bool k_direction_align, bool is_weight_offset, bool c_matrix_source, bool c_matrix_init_val);// 如下原型仅支持Ascend 950PR/Ascend 950DT __aicore__ inline void asc_mmad_sync(__cc__ float* c_matrix, __ca__ bfloat16_t* a_matrix, __cb__ bfloat16_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad_sync(__cc__ float* c_matrix, __ca__ float8_e4m3_t* a_matrix, __cb__ float8_e4m3_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad_sync(__cc__ float* c_matrix, __ca__ float8_e4m3_t* a_matrix, __cb__ float8_e5m2_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad_sync(__cc__ float* c_matrix, __ca__ float8_e5m2_t* a_matrix, __cb__ float8_e4m3_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad_sync(__cc__ float* c_matrix, __ca__ float8_e5m2_t* a_matrix, __cb__ float8_e5m2_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad_sync(__cc__ float* c_matrix, __ca__ half* a_matrix, __cb__ half* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad_sync(__cc__ float* 
c_matrix, __ca__ float* a_matrix, __cb__ float* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val); __aicore__ inline void asc_mmad_sync(__cc__ int32_t* c_matrix, __ca__ int8_t* a_matrix, __cb__ int8_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val);
参数说明
参数名 | 输入/输出 | 描述
c_matrix | 输出 | 目的操作数，结果矩阵。
a_matrix | 输入 | 源操作数，左矩阵A。
b_matrix | 输入 | 源操作数，右矩阵B。
left_height | 输入 | 左矩阵height，取值范围为[0,4095]。
n_dim | 输入 | 左矩阵width、右矩阵height，取值范围为[0,4095]。
right_width | 输入 | 右矩阵width，取值范围为[0,4095]。
feat_offset | 输入 | 保留参数。
smask_offset | 输入 | 权重矩阵的偏移位。
unit_flag | 输入 | unit_flag是一种asc_mmad接口和Fixpipe指令细粒度的并行，使能该功能后，硬件每计算完一个分形，计算结果就会被搬出，该功能不适用于L0C Buffer累加的场景。取值说明如下：• 0：关闭unit_flag；• 1：保留值；• 2：使能unit_flag，硬件执行完指令后不会关闭unit_flag功能；• 3：使能unit_flag，硬件执行完指令后会关闭unit_flag功能。使能该功能时，矩阵计算的unit_flag在最后一个分形设置为3，其余分形计算设置为2即可。
k_direction_align | 输入 | 当源操作数和目的操作数为float时，L0A和L0B中的矩阵在right_width方向上按ceil(right_width/16)*16方式都对齐到48：对于right_width=44，L0A/L0B中的所有12个分形都会被读取到Cube中；而对于right_width=36，只有L0A/L0B中的10个分形会被读取到Cube中。
is_weight_offset | 输入 | 使能weight matrix offset。
c_matrix_source | 输入 | 配置C矩阵初始值是否来源于C2（存放Bias的硬件缓存区）。取值说明如下：• true：来源于C2。• false：来源于CO1(L0C)。
c_matrix_init_val | 输入 | 配置C矩阵初始值是否为0。取值说明如下：• true：C矩阵初始值为0。• false：C矩阵初始值通过c_matrix_source参数进行配置。
disable_gemv | 输入 | 是否关闭GEMV模式，false表示开启GEMV模式，true表示关闭GEMV模式。GEMV(General Matrix-Vector Multiplication)表示实现矩阵和向量的乘积。当left_height=1时，开启GEMV后，从L0A Buffer读取数据时将以ND格式进行读取，而不会将其视为ZZ格式。
返回值说明
无
流水类型
PIPE_M
约束说明
当left_height、right_width、n_dim中的任意一个值为0时，该指令不会被执行。
操作数地址对齐约束请参考通用地址对齐约束。
调用示例
// total_length指参与搬运的数据总长度
constexpr uint64_t total_length = 128;
// 以下三个参数分别对应矩阵c,a,b的地址
__cc__ int32_t c_matrix[total_length];
__ca__ int8_t a_matrix[total_length];
__cb__ int8_t b_matrix[total_length];
// 其余入参均以默认数值传入
uint16_t left_height = 16;
uint16_t n_dim = 16;
uint16_t right_width = 16;
uint8_t unit_flag = 0;
bool disable_gemv = false;
bool c_matrix_source = false;
bool c_matrix_init_val = true;
// 函数调用
asc_mmad_sync(c_matrix, a_matrix, b_matrix, left_height, n_dim, right_width, unit_flag, disable_gemv, c_matrix_source, c_matrix_init_val);
【免费下载链接】asc-devkit：本项目是CANN推出的昇腾AI处理器专用的算子程序开发语言，原生支持C和C++标准规范，主要由类库和语言扩展层构成，提供多层级API，满足多维场景算子开发诉求。项目地址: https://gitcode.com/cann/asc-devkit
创作声明：本文部分内容由AI辅助生成(AIGC)，仅供参考。