您现在的位置是:首页 > 其他 > NTT 的 C/C++ 实现
NTT 的 C/C++ 实现
简介:NTT 的 C/C++ 实现
NTT (C ref)
ntt_ref.h
#ifndef NTT_H
#define NTT_H
/* Fixed-width aliases used throughout the NTT code.
 * NOTE: int8 is plain `char`, whose signedness is implementation-defined
 * (unsigned on e.g. ARM); do not rely on int8 holding negative values. */
typedef char int8;
typedef short int16;
typedef int int32;
typedef long long int64;
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint32;
typedef unsigned long long uint64;
//################################### Parameters ###################################
#define NTT_NEG 1 // 0: cyclic NTT. 1: negacyclic NTT.
#define NTT_Q 12289
#define NTT_N 1024
#define NTT_ROUND 10 // number of butterfly layers
#define NTT_ORDER (1<<(NTT_ROUND+1)) // order of the root of unity the tables are built from
#define NTT_BASELEN (NTT_N>>NTT_ROUND) // coefficients per block after the last layer
#define NTT_ZETA 7
//################################### Fast modular reduction ###################################
#define MONT_L 16
#define MONT_R (1LL<<MONT_L)
#define MONT 4091 // MONT_R mod q
#define NEGQINV 12287 // -q^-1 mod MONT_R
#define BARR_R (1LL<<32)
#define BARR 349497 // round(2^32/q)
/* Montgomery reduction: returns a value congruent to a * MONT_R^{-1} (mod q).
 * m = (a * NEGQINV) mod MONT_R is taken in [0, MONT_R), so a + m*q is an exact
 * multiple of MONT_R and the result lies in [a/MONT_R, a/MONT_R + q).
 * Using -q^{-1} (rather than q^{-1}, which gives results in [-q, q]) keeps the
 * result non-negative for a >= 0, as the NTT's conditional-subtract logic expects.
 * The low 16 bits are taken with a mask on the int64 product; the previous
 * (int16) narrowing conversion of an out-of-range value was implementation-defined.
 * NOTE: NewHope (also q = 12289) uses an 18-bit MONT_R; the original author
 * questioned whether 16 bits can overflow here. `a` is evaluated twice. */
#define montgomery_reduce(a) (((a) + (((int64)(a)*NEGQINV)&(MONT_R-1))*NTT_Q)>>MONT_L)
/* Barrett reduction: returns a value congruent to a (mod q); not necessarily in [0, q).
 * `a` is evaluated twice. */
#define barrett_reduce(a) ((a)-((BARR*(int64)(a))>>32)*NTT_Q)
//################################### Function declarations ###################################
void get_ntt_param(int32 q, int32 n, int32 r);
void ntt(int16* f);
void intt(int16* f);
void nttmul(int16* r, const int16* a, const int16* b);
int32 print_bytes(int8* arr, int32 len);
int32 print_coeffs(int16* arr, int32 len);
#endif
ntt_ref.c
#include <stdio.h>
#include <stdlib.h>
#include "ntt_ref.h"
//################################### 参数设置 ###################################
/* Precomputed tables. The published source left these initializers empty; they
 * must be filled with the output of get_ntt_param(NTT_Q, NTT_N, NTT_ROUND)
 * before ntt()/intt()/nttmul() produce a meaningful transform.
 * `{ 0 }` is used because an empty initializer list `{ }` is only valid from C23
 * (earlier it is a compiler extension). */
int16 zetas[NTT_ORDER + 1] = { 0 };
int16 zetas_mont[NTT_ORDER + 1] = { 0 };
int16 bitrev_list[NTT_ORDER] = { 0 };
// factor = 2^{-NTT_ROUND} mod q, factor_mont = factor * MONT_R mod q (printed by get_intt_factor).
int32 factor = 12277, factor_mont = 64;
//################################### 通用函数 ###################################
/* Bit reversal: returns the lowest l bits of b in reverse order
 * (e.g. brv(1, 3) == 4). Bits of b above position l-1 are ignored. */
int32 brv(int32 b, int32 l)
{
    int32 rev = 0;
    int32 rest = b;
    for (int32 left = l; left > 0; left--) {
        rev = (rev << 1) | (rest & 1);
        rest >>= 1;
    }
    return rev;
}
/* Modular exponentiation by repeated squaring: returns a^b mod q in [0, q).
 * Preconditions: q >= 1 and q < 2^31.5, so that (q-1)^2 fits in int64.
 * The base is reduced into [0, q) first, so any int64 a is accepted; without
 * that step a*a overflows (undefined behavior) for |a| above ~3.04e9.
 * A non-positive exponent returns 1 % q; a negative b used to loop forever. */
int64 fast_pow(int64 a, int64 b, int64 q)
{
    int64 result = 1 % q; // 1 % q makes q == 1 consistent (every value is 0 mod 1)
    a %= q;
    if (a < 0)
        a += q;
    while (b > 0) {
        if (b & 1)
            result = (result * a) % q;
        a = (a * a) % q;
        b >>= 1;
    }
    return result;
}
/* Extended Euclidean algorithm: returns g = gcd(a, b) and stores Bezout
 * coefficients so that a*(*x) + b*(*y) == g.
 * Iterative form: it walks the same quotient sequence as the textbook
 * recursion, so it produces the same (x, y) pair (e.g. exgcd(3,5) -> x=2, y=-1). */
int64 exgcd(int64* x, int64* y, int64 a, int64 b)
{
    int64 r0 = a, r1 = b;
    int64 s0 = 1, s1 = 0;
    int64 t0 = 0, t1 = 1;
    while (r1 != 0) {
        int64 quot = r0 / r1;
        int64 next;
        next = r0 - quot * r1; r0 = r1; r1 = next;
        next = s0 - quot * s1; s0 = s1; s1 = next;
        next = t0 - quot * t1; t0 = t1; t1 = next;
    }
    *x = s0;
    *y = t0;
    return r0;
}
/* Print arr[0..len) as "[ a, b, ... ]" (no trailing newline). Always returns 0.
 * len <= 0 prints "[ ]" instead of reading arr[0] out of bounds. */
int32 print_bytes(int8* arr, int32 len)
{
    if (len <= 0) {
        printf("[ ]");
        return 0;
    }
    printf("[ %d", arr[0]);
    for (int32 i = 1; i < len; i++)
        printf(", %d", arr[i]);
    printf(" ]");
    return 0;
}
/* Print the coefficients arr[0..len) as "[ a, b, ... ]" (no trailing newline).
 * Always returns 0. len <= 0 prints "[ ]" instead of reading arr[0] out of bounds. */
int32 print_coeffs(int16* arr, int32 len)
{
    if (len <= 0) {
        printf("[ ]");
        return 0;
    }
    printf("[ %d", arr[0]);
    for (int32 i = 1; i < len; i++)
        printf(", %d", arr[i]);
    printf(" ]");
    return 0;
}
//############################## 预计算参数 ##############################
/* Return the smallest w in [2, q) with w^ord == 1 and w^(ord/2) != 1 (mod q),
 * printing it. When ord is a power of two this means w has multiplicative
 * order exactly ord. Returns 0 when no such w exists.
 * (The format string had its "\n" escape broken across two source lines.) */
int64 find_root(int64 q, int64 ord)
{
    for (int64 w = 2; w < q; w++) {
        if (fast_pow(w, ord, q) == 1 && fast_pow(w, ord >> 1, q) != 1) {
            printf("%lld-th root = %lld\n", ord, w);
            return w;
        }
    }
    return 0;
}
/* Fill zetas[0..ord] with zeta^i mod q and print them as a C initializer
 * ("zetas = { ... };"). zetas must have room for ord + 1 entries.
 * (The closing format string had its "\n" escape broken across two lines.) */
void get_zetas(int32* zetas, int32 zeta, int32 q, int32 ord)
{
    int64 wi = 1;
    int64 w = zeta;
    zetas[0] = 1;
    printf("zetas = { %d", zetas[0]);
    for (int64 i = 1; i <= ord; i++) {
        wi = (wi * w) % q;
        zetas[i] = (int32)wi;
        printf(", %lld", wi);
    }
    printf(" };\n");
}
/* Fill zetas_mont[i] = mont * zetas[i] mod q for i in [0, ord] and print them
 * as a C initializer. With mont == MONT_R mod q these are the Montgomery-form
 * twiddles. Both arrays must hold ord + 1 entries.
 * (The closing format string had its "\n" escape broken across two lines.) */
void get_zetas_mont(int32* zetas_mont, int32* zetas, int64 q, int64 ord, int64 mont)
{
    int64 wi_pre = mont * zetas[0] % q;
    zetas_mont[0] = (int32)wi_pre;
    printf("zetas_mont = { %lld", wi_pre);
    for (int64 i = 1; i <= ord; i++) {
        wi_pre = mont * zetas[i] % q;
        zetas_mont[i] = (int32)wi_pre;
        printf(", %lld", wi_pre);
    }
    printf(" };\n");
}
/* Print the table brv(i, bits) for i in [0, 2^bits) as a C initializer
 * ("bitrev_list = { ... };").
 * (The closing format string had its "\n" escape broken across two lines.) */
void get_brv_table(int32 bits)
{
    printf("bitrev_list = { 0");
    int32 len = (int32)(1LL << bits);
    for (int32 i = 1; i < len; i++)
        printf(", %d", brv(i, bits));
    printf(" };\n");
}
/* Print the inverse-NTT scaling constants:
 *   factor      = (2^r)^{-1} mod q
 *   factor_mont = factor * mont mod q
 * Reports an error instead when 2^r is not invertible mod q (even q); the gcd
 * was previously computed but ignored.
 * (The format string had its "\n" escape broken across two lines.) */
void get_intt_factor(int64 q, int64 r, int64 mont)
{
    int64 factor, pinv;
    int64 gcd = exgcd(&factor, &pinv, 1LL << r, q);
    if (gcd != 1) {
        printf("gcd(2^r, q) != 1, 2^r has no inverse mod q\n");
        return;
    }
    factor = factor < 0 ? factor + q : factor; // bring the Bezout coefficient into [0, q)
    int64 factor_mont = (factor * mont) % q;
    printf("factor = %lld, factor_mont = %lld\n", factor, factor_mont);
}
/* Print every precomputed constant the reference NTT needs for modulus q,
 * n coefficients and r layers:
 *   - the Montgomery constants;
 *   - the Barrett constant;
 *   - the zetas / zetas_mont / bitrev_list tables, as C initializers;
 *   - the inverse-NTT scaling factors.
 * Returns early with a message if:
 *   - q is not coprime to MONT_R;
 *   - no primitive 2^(r+1)-th root of unity mod q is found;
 *   - an allocation fails.
 * (Every format string here had its "\n" escapes broken across lines.) */
void get_ntt_param(int32 q, int32 n, int32 r)
{
    printf("/******************************* get ntt params *******************************/\n");
    printf("NTT_Q = %d, NTT_N = %d, NTT_ROUND = %d, MONT_R = %lld, BARR_R = %lld\n", q, n, r, MONT_R, BARR_R);
    int64 d, x, y;
    int64 mont = MONT_R % q;
    d = exgcd(&x, &y, q, MONT_R);
    if (d != 1) {
        printf("gcd(NTT_Q, MONT_R) != 1\n");
        return;
    }
    printf("mont = %lld mod q\nqinv = %lld mod R\n", mont, x);
    printf("barret = 2^32/q = %lld\n", (BARR_R + (q >> 1)) / q);
    int32 order = 1 << (r + 1);
    int64 Zeta = find_root(q, order);
    if (Zeta == 0) {
        printf("no primitive %d-th root of unity mod %d\n", order, q);
        return;
    }
    int32* Zetas = malloc((size_t)(order + 1) * sizeof *Zetas);
    int32* Zetas_mont = malloc((size_t)(order + 1) * sizeof *Zetas_mont);
    if (Zetas == NULL || Zetas_mont == NULL) {
        printf("out of memory\n");
        free(Zetas);
        free(Zetas_mont);
        return;
    }
    get_zetas(Zetas, (int32)Zeta, q, order);
    get_zetas_mont(Zetas_mont, Zetas, q, order, mont);
    get_brv_table(r + 1);
    get_intt_factor(q, r, mont);
    printf("//******************************* get ntt params *******************************//\n");
    free(Zetas);
    free(Zetas_mont);
}
//################################### NTT变换 ###################################
/*
 * In-place forward NTT of f[0..NTT_N) using the global zetas_mont / bitrev_list
 * tables. Layers are processed two at a time (radix-4); when NTT_ROUND is odd a
 * single radix-2 layer runs first. The final reduction to [0,q) at the end is
 * commented out, so outputs are not fully reduced.
 *
 * The root for block i of a layer with Blocknum blocks is
 * zetas_mont[bitrev_list[Blocknum * NTT_NEG + i] >> 1]; the radix-2 layer
 * (Blocknum == 1) now uses the same rule. It used to read
 * zetas_mont[Blocknum * NTT_NEG], i.e. the first power of the root, which is
 * the wrong twiddle for the negacyclic split whenever NTT_ROUND >= 3.
 *
 * NOTE(review): montgomery_reduce(a) returns values in [a/2^16, a/2^16 + q),
 * which can exceed q, so "X + NTT_Q - WY" can dip slightly below 0. The [0,2q)
 * range comments below come from the original author; confirm them before
 * relying on non-negative outputs.
 */
void ntt(int16* f) {
    int32 Blocknum = 1;
    int32 Blocksize = NTT_N;
    int32 Round = 0;
    /*
    Radix-2 (only when NTT_ROUND is odd)
    X = X + WY
    Y = X - WY
    */
    if ((NTT_ROUND & 1) == 1) {
        int32 offset = Blocksize >> 1;
        int32 X, WY;
        int32 zeta_mont = zetas_mont[bitrev_list[Blocknum * NTT_NEG] >> 1];
        int16* pf = f;
        for (int32 k = 0; k < offset; k++) {
            X = pf[k];
            X -= ((NTT_Q - X - 1) >> 31) & NTT_Q; // [0,2q) -> [0,q), as the radix-4 loop does
            WY = pf[k + offset] * zeta_mont;
            WY = montgomery_reduce(WY);
            pf[k] = X + WY;
            pf[k + offset] = X + NTT_Q - WY;
        }
        Blocknum <<= 1;
        Blocksize >>= 1;
        Round++;
    }
    /*
    Radix-4 (Harvey butterflies), input/output range [0,2q)
    X1 = (X1 + W*Y1) + W0*(X2 + W*Y2), range [0,4q)
    X2 = (X1 + W*Y1) - W0*(X2 + W*Y2), range [0,4q)
    Y1 = (X1 - W*Y1) + W1*(X2 - W*Y2), range [0,4q)
    Y2 = (X1 - W*Y1) - W1*(X2 - W*Y2), range [0,4q)
    First X1 is reduced to [0,q), then (X1 + W*Y1) and (X1 - W*Y1) are reduced
    to [0,q): three conditional subtractions in total.
    */
    for (; Round < NTT_ROUND; Round += 2, Blocksize >>= 2, Blocknum <<= 2) {
        int32 offset = Blocksize >> 2;
        int32 X1, X2, Y1, Y2, WY;
        int32 zeta_mont, zeta1_mont, zeta2_mont;
        for (int32 i = 0; i < Blocknum; i++) {
            int16* pf = f + i * Blocksize;
            // Roots: block i of this layer, then blocks 2i and 2i+1 of the next layer.
            zeta_mont = zetas_mont[bitrev_list[Blocknum * NTT_NEG + i] >> 1];
            zeta1_mont = zetas_mont[bitrev_list[2 * Blocknum * NTT_NEG + i * 2] >> 1];
            zeta2_mont = zetas_mont[bitrev_list[2 * Blocknum * NTT_NEG + i * 2 + 1] >> 1];
            for (int k = 0; k < offset; k++) {
                X1 = pf[k];
                X2 = pf[k + offset];
                Y1 = pf[k + offset * 2];
                Y2 = pf[k + offset * 3];
                X1 -= ((NTT_Q - X1 - 1) >> 31) & NTT_Q;
                WY = montgomery_reduce(Y1 * zeta_mont);
                Y1 = X1 + NTT_Q - WY;
                X1 += WY;
                X1 -= ((NTT_Q - X1 - 1) >> 31) & NTT_Q;
                Y1 -= ((NTT_Q - Y1 - 1) >> 31) & NTT_Q;
                WY = montgomery_reduce(Y2 * zeta_mont);
                Y2 = X2 + NTT_Q - WY;
                X2 += WY;
                WY = montgomery_reduce(X2 * zeta1_mont);
                X2 = X1 + NTT_Q - WY;
                X1 += WY;
                WY = montgomery_reduce(Y2 * zeta2_mont);
                Y2 = Y1 + NTT_Q - WY;
                Y1 += WY;
                pf[k] = X1;
                pf[k + offset] = X2;
                pf[k + offset * 2] = Y1;
                pf[k + offset * 3] = Y2;
            }
        }
    }
    //for (int32 k = 0; k < NTT_N; k++) {
    //    f[k] -= ((NTT_Q - f[k] - 1) >> 31) & NTT_Q; // reduce from [0,2q) to [0,q)
    //}
}
/*
 * In-place inverse NTT of f[0..NTT_N).
 *  - Undoes ntt() layer by layer with butterflies (X + Y, (X - Y) * w^{-e}).
 *    The inverse roots are taken from zetas_mont[NTT_ORDER - e].
 *  - Then multiplies every coefficient by factor_mont = 2^{-NTT_ROUND} * R mod q.
 *  - Finally applies one conditional subtraction of q.
 *
 * Radix-2 fix (NTT_ROUND odd): after the radix-4 loop, Blocksize == 2*NTT_N
 * and Blocknum == 0. The old code derived offset == NTT_N from them, so it read
 * and wrote past the end of f, and it used the root for Blocknum == 0. The layer
 * it undoes is the first forward layer (Blocknum == 1, offset NTT_N/2), so both
 * values are now set explicitly.
 *
 * NOTE(review): (X1 - X2) * zeta1_mont and similar products can be negative,
 * so montgomery_reduce may return negative values. X2 + Y2 is also stored back
 * into int16 without reduction. The range comments below come from the original
 * author; re-check them before relying on outputs lying in [0,q).
 */
void intt(int16* f) {
    int32 Blocknum = 1 << NTT_ROUND;
    int32 Blocksize = NTT_N >> NTT_ROUND;
    int32 Round = NTT_ROUND;
    int32 Qtimes2 = NTT_Q * 2;
    Blocksize <<= 2;
    Blocknum >>= 2;
    /*
    Radix-4 (Harvey), input/output range [0,2q)
    X1 = (X1 + X2) + (Y1 + Y2), range [0,8q)
    X2 = IW0*(X1 - X2) + IW1*(Y1 - Y2), range [0,2q)
    Y1 = IW*((X1 + X2) - (Y1 + Y2)), range [0,q)
    Y2 = IW*(IW0*(X1 - X2) - IW1*(Y1 - Y2)), range [0,q)
    First (X1 + X2) and (Y1 + Y2) are reduced to [0,2q), then
    (X1 + X2) + (Y1 + Y2) is reduced to [0,2q): three conditional subtractions.
    */
    for (; Round > 1; Round -= 2, Blocksize <<= 2, Blocknum >>= 2) {
        int32 offset = Blocksize >> 2;
        int32 X1, X2, Y1, Y2, T;
        int32 zeta_mont, zeta1_mont, zeta2_mont;
        for (int32 i = 0; i < Blocknum; i++) {
            int16* pf = f + i * Blocksize;
            // Inverse roots: block i of this layer, then blocks 2i and 2i+1 of the next layer.
            zeta_mont = zetas_mont[NTT_ORDER - (bitrev_list[Blocknum * NTT_NEG + i] >> 1)];
            zeta1_mont = zetas_mont[NTT_ORDER - (bitrev_list[2 * Blocknum * NTT_NEG + i * 2] >> 1)];
            zeta2_mont = zetas_mont[NTT_ORDER - (bitrev_list[2 * Blocknum * NTT_NEG + i * 2 + 1] >> 1)];
            for (int k = 0; k < offset; k++) {
                X1 = pf[k];
                X2 = pf[k + offset];
                Y1 = pf[k + offset * 2];
                Y2 = pf[k + offset * 3];
                T = (X1 - X2) * zeta1_mont;
                X1 += X2;
                X2 = montgomery_reduce(T);
                X1 -= ((Qtimes2 - X1 - 1) >> 31) & Qtimes2; // conditional subtract 2q
                T = (Y1 - Y2) * zeta2_mont;
                Y1 += Y2;
                Y2 = montgomery_reduce(T);
                Y1 -= ((Qtimes2 - Y1 - 1) >> 31) & Qtimes2; // conditional subtract 2q
                T = (X1 - Y1) * zeta_mont;
                X1 += Y1;
                Y1 = montgomery_reduce(T);
                X1 -= ((Qtimes2 - X1 - 1) >> 31) & Qtimes2; // conditional subtract 2q
                T = (X2 - Y2) * zeta_mont;
                X2 += Y2;
                Y2 = montgomery_reduce(T);
                pf[k] = X1;
                pf[k + offset] = X2;
                pf[k + offset * 2] = Y1;
                pf[k + offset * 3] = Y2;
            }
        }
    }
    /*
    Radix-2 (only when NTT_ROUND is odd): undoes the first forward layer.
    X = X + Y
    Y = IW*(X - Y)
    */
    if ((NTT_ROUND & 1) == 1) {
        int32 offset = NTT_N >> 1;
        int32 X, Y, T;
        int32 zeta_mont = zetas_mont[NTT_ORDER - (bitrev_list[1 * NTT_NEG] >> 1)];
        int16* pf = f;
        for (int32 k = 0; k < offset; k++) {
            X = pf[k];
            Y = pf[k + offset];
            T = (X - Y) * zeta_mont;
            X += Y;
            X -= ((Qtimes2 - X - 1) >> 31) & Qtimes2; // conditional subtract 2q, as in the radix-4 loop
            pf[k] = X;
            pf[k + offset] = montgomery_reduce(T);
        }
    }
    // Scale by 2^{-NTT_ROUND} (in Montgomery form), then subtract q once if the value is >= q.
    for (int32 k = 0; k < NTT_N; k++) {
        int32 X = f[k] * factor_mont;
        X = montgomery_reduce(X);
        f[k] = X - (((NTT_Q - X - 1) >> 31) & NTT_Q);
    }
}
/* Multiply two NTT_BASELEN-coefficient polynomials modulo (x^NTT_BASELEN - zeta):
 * terms of degree NTT_BASELEN + i wrap around to degree i, multiplied by zeta.
 * Each output coefficient is Barrett-reduced (congruent mod q, not necessarily in [0,q)).
 * Declared `static inline`: a plain `inline` definition in C99/C11 provides no
 * external definition, so a call the compiler does not inline (e.g. at -O0)
 * fails to link. */
static inline void basemul(int16* r, const int16* a, const int16* b, int16 zeta)
{
    int32 res; // wide accumulator so the modular reduction can be deferred
    int32 s;
    for (int32 i = 0; i < NTT_BASELEN; i++)
    {
        res = 0;
        s = NTT_BASELEN + i;
        for (int32 j = 0; j <= i; j++)
            res += b[j] * a[i - j];
        for (int32 j = i + 1; j < NTT_BASELEN; j++)
            res += zeta * barrett_reduce(b[j] * a[s - j]); // x^{B+i} = zeta * x^i
        r[i] = barrett_reduce(res);
    }
}
/* Pointwise product of two transformed polynomials: r = a * b, block by block.
 * After NTT_ROUND layers there are 2^NTT_ROUND blocks of NTT_BASELEN
 * coefficients. For NTT_BASELEN > 1, block blk is multiplied modulo
 * x^NTT_BASELEN - zetas[bitrev_list[2^NTT_ROUND * NTT_NEG + blk]]. */
void nttmul(int16* r, const int16* a, const int16* b)
{
    const int32 blocks = 1 << NTT_ROUND;
    for (int32 blk = 0; blk < blocks; blk++) {
        const int32 base = blk * NTT_BASELEN;
#if (NTT_BASELEN == 1)
        int32 prod = a[base] * b[base];
        r[base] = barrett_reduce(prod);
#elif (NTT_BASELEN == 2)
        int32 zeta = zetas[bitrev_list[blocks * NTT_NEG + blk]];
        int32 even = a[base] * b[base] + zeta * barrett_reduce(a[base + 1] * b[base + 1]);
        int32 odd = a[base] * b[base + 1] + a[base + 1] * b[base];
        r[base] = barrett_reduce(even);
        r[base + 1] = barrett_reduce(odd);
#else
        int32 zeta = zetas[bitrev_list[blocks * NTT_NEG + blk]];
        basemul(r + base, a + base, b + base, zeta);
#endif
    }
}
NTT (AVX2)
ntt_avx2.h
#ifndef NTT_H
#define NTT_H
typedef char int8;
typedef short int16;
typedef int int32;
typedef long long int64;
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint32;
typedef unsigned long long uint64;
//################################### 参数设置 ###################################
#define NTT_NEG 1 //0:循环NTT。1:反循环NTT。
#define NTT_Q 12289
#define NTT_N 1024
#define NTT_ROUND 10
#define NTT_ORDER (1<<(NTT_ROUND+1))
#define NTT_BASELEN (NTT_N>>NTT_ROUND)
#define NTT_ZETA 7
//################################### 快速模约减 ###################################
#define MONT_L 16
#define MONT_R (1LL<<MONT_L)
#define MONT 4091 // MONT_R mod q
#define QINV -12287 // q^-1 mod MONT_R
#define NEGQINV 12287 // -q^-1 mod MONT_R
#define BARR_epi16 5 // round(2^16/q)
#define BARR_epi32 349497 // round(2^32/q)
// 蒙特马利模约简,计算 a*R^{-1} mod q
//#define montgomery_reduce(a) (((a) - (int32)((int16)((int64)(a)*QINV))*NTT_Q)>>MONT_L)
//当qinv负数时结果[-q, q],与NTT中逻辑冲突(要求正数),需要变一下号
//Newhope中q=12289,它的MONT_R选为18位(16位时数据溢出?没有吧!)
#define montgomery_reduce(a) (((a) + (((int64)(a)*NEGQINV)&(MONT_R-1))*NTT_Q)>>MONT_L)
// 巴雷特模约简,计算 a mod q
#define barrett_reduce(a) ((a)-((BARR_epi32*(int64)(a))>>32)*NTT_Q)
//################################### 函数定义 ###################################
void get_ntt_param(int32 q, int32 n, int32 r);
void ntt(int16* f);
void intt(int16* f, int8 mont);
void nttmul(int16* r, const int16* a, const int16* b, int8 mont);
int32 print_bytes(int8* arr, int32 len);
int32 print_coeffs(int16* arr, int32 len);
#endif
ntt_avx2.c
#include <stdio.h>
#include <stdlib.h>
#include <xmmintrin.h> // __m128
#include <immintrin.h> // __m256
//#include <zmmintrin.h> // __m512
#include "ntt_avx2.h"
//################################### 参数设置 ###################################
/* Precomputed tables. The published source left these initializers empty; they
 * must be filled with the output of get_ntt_param(NTT_Q, NTT_N, NTT_ROUND).
 * `{ 0 }` is used because an empty initializer list `{ }` is only valid from C23. */
const int16 zetas[NTT_ORDER + 1] = { 0 };
const int16 zetas_mont[NTT_ORDER + 1] = { 0 };
const int16 bitrev_list[NTT_ORDER] = { 0 };
// 2^{-r} mod q, R * 2^{-r} mod q, R^2 * 2^{-r} mod q (see get_intt_factor).
const int32 factor = 12277, factor_mont = 64, factor_mont2 = 3755;
//################################### 通用函数 ###################################
/* Bit reversal: returns the lowest l bits of b in reverse order
 * (e.g. brv(1, 3) == 4). Bits of b above position l-1 are ignored. */
int32 brv(int32 b, int32 l)
{
    int32 rev = 0;
    int32 rest = b;
    for (int32 left = l; left > 0; left--) {
        rev = (rev << 1) | (rest & 1);
        rest >>= 1;
    }
    return rev;
}
/* Modular exponentiation by repeated squaring: returns a^b mod q in [0, q).
 * Preconditions: q >= 1 and q < 2^31.5, so that (q-1)^2 fits in int64.
 * The base is reduced into [0, q) first, so any int64 a is accepted; without
 * that step a*a overflows (undefined behavior) for |a| above ~3.04e9.
 * A non-positive exponent returns 1 % q; a negative b used to loop forever. */
int64 fast_pow(int64 a, int64 b, int64 q)
{
    int64 result = 1 % q; // 1 % q makes q == 1 consistent (every value is 0 mod 1)
    a %= q;
    if (a < 0)
        a += q;
    while (b > 0) {
        if (b & 1)
            result = (result * a) % q;
        a = (a * a) % q;
        b >>= 1;
    }
    return result;
}
/* Extended Euclidean algorithm: returns g = gcd(a, b) and stores Bezout
 * coefficients so that a*(*x) + b*(*y) == g.
 * Iterative form: it walks the same quotient sequence as the textbook
 * recursion, so it produces the same (x, y) pair. */
int64 exgcd(int64* x, int64* y, int64 a, int64 b)
{
    int64 r0 = a, r1 = b;
    int64 s0 = 1, s1 = 0;
    int64 t0 = 0, t1 = 1;
    while (r1 != 0) {
        int64 quot = r0 / r1;
        int64 next;
        next = r0 - quot * r1; r0 = r1; r1 = next;
        next = s0 - quot * s1; s0 = s1; s1 = next;
        next = t0 - quot * t1; t0 = t1; t1 = next;
    }
    *x = s0;
    *y = t0;
    return r0;
}
/* Print arr[0..len) as "[ a, b, ... ]" (no trailing newline). Always returns 0.
 * len <= 0 prints "[ ]" instead of reading arr[0] out of bounds. */
int32 print_bytes(int8* arr, int32 len)
{
    if (len <= 0) {
        printf("[ ]");
        return 0;
    }
    printf("[ %d", arr[0]);
    for (int32 i = 1; i < len; i++)
        printf(", %d", arr[i]);
    printf(" ]");
    return 0;
}
/* Print the coefficients arr[0..len) as "[ a, b, ... ]" (no trailing newline).
 * Always returns 0. len <= 0 prints "[ ]" instead of reading arr[0] out of bounds. */
int32 print_coeffs(int16* arr, int32 len)
{
    if (len <= 0) {
        printf("[ ]");
        return 0;
    }
    printf("[ %d", arr[0]);
    for (int32 i = 1; i < len; i++)
        printf(", %d", arr[i]);
    printf(" ]");
    return 0;
}
//############################## 预计算参数 ##############################
/* Return the smallest w in [2, q) with w^ord == 1 and w^(ord/2) != 1 (mod q),
 * printing it. When ord is a power of two this means w has multiplicative
 * order exactly ord. Returns 0 when no such w exists.
 * (The format string had its "\n" escape broken across two source lines.) */
int64 find_root(int64 q, int64 ord)
{
    for (int64 w = 2; w < q; w++) {
        if (fast_pow(w, ord, q) == 1 && fast_pow(w, ord >> 1, q) != 1) {
            printf("%lld-th root = %lld\n", ord, w);
            return w;
        }
    }
    return 0;
}
/* Fill zetas[0..ord] with zeta^i mod q and print them as a C initializer
 * ("zetas = { ... };"). zetas must have room for ord + 1 entries.
 * (The closing format string had its "\n" escape broken across two lines.) */
void get_zetas(int32* zetas, int32 zeta, int32 q, int32 ord)
{
    int64 wi = 1;
    int64 w = zeta;
    zetas[0] = 1;
    printf("zetas = { %d", zetas[0]);
    for (int64 i = 1; i <= ord; i++) {
        wi = (wi * w) % q;
        zetas[i] = (int32)wi;
        printf(", %lld", wi);
    }
    printf(" };\n");
}
/* Fill zetas_mont[i] = mont * zetas[i] mod q for i in [0, ord] and print them
 * as a C initializer. With mont == MONT_R mod q these are the Montgomery-form
 * twiddles. Both arrays must hold ord + 1 entries.
 * (The closing format string had its "\n" escape broken across two lines.) */
void get_zetas_mont(int32* zetas_mont, int32* zetas, int64 q, int64 ord, int64 mont)
{
    int64 wi_pre = mont * zetas[0] % q;
    zetas_mont[0] = (int32)wi_pre;
    printf("zetas_mont = { %lld", wi_pre);
    for (int64 i = 1; i <= ord; i++) {
        wi_pre = mont * zetas[i] % q;
        zetas_mont[i] = (int32)wi_pre;
        printf(", %lld", wi_pre);
    }
    printf(" };\n");
}
/* Print the table brv(i, bits) for i in [0, 2^bits) as a C initializer
 * ("bitrev_list = { ... };").
 * (The closing format string had its "\n" escape broken across two lines.) */
void get_brv_table(int32 bits)
{
    printf("bitrev_list = { 0");
    int32 len = (int32)(1LL << bits);
    for (int32 i = 1; i < len; i++)
        printf(", %d", brv(i, bits));
    printf(" };\n");
}
/* Print the inverse-NTT scaling constants:
 *   factor       = 1/2^r       mod q
 *   factor_mont  = R/2^r       mod q
 *   factor_mont2 = R^2/2^r     mod q
 * Reports an error instead when 2^r is not invertible mod q; the gcd was
 * previously computed but ignored.
 * (The format string had its "\n" escape broken across two lines.) */
void get_intt_factor(int64 q, int64 r, int64 mont)
{
    int64 factor, pinv;
    int64 gcd = exgcd(&factor, &pinv, 1LL << r, q);
    if (gcd != 1) {
        printf("gcd(2^r, q) != 1, 2^r has no inverse mod q\n");
        return;
    }
    factor = factor < 0 ? factor + q : factor; // bring the Bezout coefficient into [0, q)
    int64 factor_mont = (factor * mont) % q;
    int64 factor_mont2 = (factor_mont * mont) % q;
    printf("factor = %lld, factor_mont = %lld, factor_mont2 = %lld\n", factor, factor_mont, factor_mont2);
}
/* Print every precomputed constant the AVX2 NTT needs for modulus q,
 * n coefficients and r layers:
 *   - the Montgomery constants;
 *   - the Barrett constants;
 *   - the zetas / zetas_mont / bitrev_list tables, as C initializers;
 *   - the inverse-NTT scaling factors.
 * Returns early with a message if:
 *   - q is not coprime to MONT_R;
 *   - no primitive 2^(r+1)-th root of unity mod q is found;
 *   - an allocation fails.
 * Barrett fix: the constants used to be printed as (16 + q/2)/q and
 * (32 + q/2)/q, both 0 for q = 12289. They are now round(2^16/q) and
 * round(2^32/q), printed with %lld because they are computed in long long.
 * (Every format string here had its "\n" escapes broken across lines.) */
void get_ntt_param(int32 q, int32 n, int32 r)
{
    printf("/******************************* get ntt params *******************************/\n");
    printf("NTT_Q = %d, NTT_N = %d, NTT_ROUND = %d\n", q, n, r);
    int64 d, x, y;
    int64 mont = MONT_R % q;
    d = exgcd(&x, &y, q, MONT_R);
    if (d != 1) {
        printf("gcd(NTT_Q, MONT_R) != 1\n");
        return;
    }
    printf("MONT_R = %lld\nMONT = %lld mod q\nQINV = %lld mod R\n", MONT_R, mont, x);
    printf("BARR_R = 2^16, BARR_epi16 = 2^16/q = %lld\nBARR_R = 2^32, BARR_epi32 = 2^32/q = %lld\n",
           ((1LL << 16) + (q >> 1)) / q, ((1LL << 32) + (q >> 1)) / q);
    int32 order = 1 << (r + 1);
    int64 Zeta = find_root(q, order);
    if (Zeta == 0) {
        printf("no primitive %d-th root of unity mod %d\n", order, q);
        return;
    }
    int32* Zetas = malloc((size_t)(order + 1) * sizeof *Zetas);
    int32* Zetas_mont = malloc((size_t)(order + 1) * sizeof *Zetas_mont);
    if (Zetas == NULL || Zetas_mont == NULL) {
        printf("out of memory\n");
        free(Zetas);
        free(Zetas_mont);
        return;
    }
    get_zetas(Zetas, (int32)Zeta, q, order);
    get_zetas_mont(Zetas_mont, Zetas, q, order, mont);
    get_brv_table(r + 1);
    get_intt_factor(q, r, mont);
    printf("//******************************* get ntt params *******************************//\n");
    free(Zetas);
    free(Zetas_mont);
}
//################################### Load/Store辅助函数 ###################################
/* Former global scratch register for Half(). It is kept so existing references
 * still link, but Half() now uses a block-local temporary, so the helpers below
 * no longer share mutable global state.
 * Besides restoring the line continuations lost from the published text:
 *   - every multi-statement macro is wrapped in do { } while (0), so it acts
 *     as a single statement;
 *   - the shuffle tables are read with unaligned loads, because a plain
 *     int8[32] array has no 32-byte alignment guarantee. */
__m256i NTT_TMP;
/* Half(X,Y): X <- [X.lo128, Y.lo128], Y <- [X.hi128, Y.hi128]. */
#define Half(X,Y) do { \
    __m256i half_tmp_ = _mm256_permute2x128_si256((X), (Y), 0x31); \
    (X) = _mm256_permute2x128_si256((X), (Y), 0x20); \
    (Y) = half_tmp_; \
} while (0)
/* Perm(X,Y): reorder the 64-bit lanes of both registers as (0,2,1,3); 0xD8 == 0b11011000. */
#define Perm(X,Y) do { \
    (X) = _mm256_permute4x64_epi64((X), 0xD8); \
    (Y) = _mm256_permute4x64_epi64((Y), 0xD8); \
} while (0)
/* Coll_32(X,Y): within each 128-bit lane reorder the 32-bit elements as (0,2,1,3). */
#define Coll_32(X,Y) do { \
    (X) = _mm256_shuffle_epi32((X), 0xD8); \
    (Y) = _mm256_shuffle_epi32((Y), 0xD8); \
} while (0)
/* Byte shuffle: within each 128-bit lane, even 16-bit words first, then odd words. */
const int8 CollIndex[32] = {
    0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15,
    0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15
};
#define Coll_16(X,Y) do { \
    const __m256i coll_idx_ = _mm256_loadu_si256((const __m256i*)CollIndex); \
    (X) = _mm256_shuffle_epi8((X), coll_idx_); \
    (Y) = _mm256_shuffle_epi8((Y), coll_idx_); \
} while (0)
/* Inverse of CollIndex: interleave the even/odd 16-bit words back. */
const int8 CollIndex_inv[32] = {
    0,1,8,9,2,3,10,11,4,5,12,13,6,7,14,15,
    0,1,8,9,2,3,10,11,4,5,12,13,6,7,14,15
};
#define Coll_16_inv(X,Y) do { \
    const __m256i coll_idx_ = _mm256_loadu_si256((const __m256i*)CollIndex_inv); \
    (X) = _mm256_shuffle_epi8((X), coll_idx_); \
    (Y) = _mm256_shuffle_epi8((Y), coll_idx_); \
} while (0)
/* offsetK(X,Y) arranges X and Y so that lane j of X pairs with lane j of Y for
 * butterflies whose partners are K coefficients apart; offsetK_inv undoes it. */
#define offset8(X,Y) Half(X, Y)
#define offset8_inv(X,Y) Half(X, Y)
#define offset4(X,Y) do { Perm(X, Y); Half(X, Y); } while (0)
#define offset4_inv(X,Y) do { Half(X, Y); Perm(X, Y); } while (0)
#define offset2(X,Y) do { Coll_32(X, Y); Perm(X, Y); Half(X, Y); } while (0)
#define offset2_inv(X,Y) do { Half(X, Y); Perm(X, Y); Coll_32(X, Y); } while (0)
#define offset1(X,Y) do { Coll_16(X, Y); Perm(X, Y); Half(X, Y); } while (0)
#define offset1_inv(X,Y) do { Half(X, Y); Perm(X, Y); Coll_16_inv(X, Y); } while (0)
//################################### 快速模约减 ###################################
/*
* Montgomery multiplication on 16 signed 16-bit lanes, R = 2^16:
*   montgomery_mul(a, zeta) = (a * zeta_mont - ((a * zeta_mont * qinv) mod R) * q) >> t
* with zeta_mont = zeta*R mod q, t = 16, R = 2^16.
* Returns a value congruent to a * w * R^{-1} (mod q); negative intermediate
* results are shifted up by q before returning.
*/
__m256i montgomery_reduce_epi16(__m256i a, __m256i w) {
__m256i q = _mm256_set1_epi16(NTT_Q);
__m256i qinv = _mm256_set1_epi16(QINV);
// Correctness constraint (per the original author): (R - 1)*q + a*w < 2^32
__m256i hi = _mm256_mulhi_epi16(a, w); // signed high half of a*w; signed and unsigned products agree mod 2^32 (same bit pattern)
__m256i lo = _mm256_mullo_epi16(a, w); // low half of a*w, to be read as unsigned: a*w == hi*65536 + lo
/*
As long as nothing overflows, the low half is the same for signed and unsigned multiplication.
The high half is not: _mm256_mulhi_epi16 != _mm256_mulhi_epu16, because they differ in whether
the upper bits are filled with 0s or 1s.
*/
__m256i tmp = _mm256_mullo_epi16(lo, qinv); // m = lo * q^{-1} mod R (R = 2^16), valid whether lo and qinv are negative or positive as int16
tmp = _mm256_mulhi_epu16(tmp, q); // unsigned multiply: epu treats m as unsigned so the high half is not sign-filled with 1s
/*
a*w and m*q have identical low 16 bits, so subtracting only the high halves needs no carry/borrow.
If hi is negative, hi - tmp (tmp unsigned) stays negative.
If hi is non-negative, hi - tmp may be positive or negative.
*/
hi = _mm256_sub_epi16(hi, tmp);
// The reduced result must not be negative (the Harvey butterflies require it), so add q to negative lanes.
tmp = _mm256_srai_epi16(hi, 15); // arithmetic shift: all-ones mask for negative lanes
tmp = _mm256_and_si256(tmp, q);
hi = _mm256_add_epi16(hi, tmp);
return hi;
}
/*
 * Barrett reduction on 16 signed 16-bit lanes: a - ((m*a) >> 16) * q with
 * m = BARR_epi16 (about 2^16/q), R = 2^16. The quotient estimate is coarse, so
 * the result is only congruent to a (mod q); the original author notes the
 * range wanders roughly over [-q, 2q).
 */
__m256i barrett_reduce_epi16(__m256i a) {
    const __m256i modulus = _mm256_set1_epi16(NTT_Q);
    const __m256i quot = _mm256_mulhi_epi16(a, _mm256_set1_epi16(BARR_epi16)); // signed
    return _mm256_sub_epi16(a, _mm256_mullo_epi16(quot, modulus));
}
/*
* Barrett reduction on 8 signed 32-bit lanes: a - ((m*a)>>t) * q
* with m = BARR_epi32 (about R/q), t = 32, R = 2^32.
* _mm256_mul_epi32 only multiplies the even 32-bit lanes (into 64-bit products),
* so the even and odd lanes are handled in two passes and then recombined.
* The result is congruent to a (mod q) but not fully reduced.
*/
__m256i barrett_reduce_epi32(__m256i a) {
__m256i q = _mm256_set1_epi32(NTT_Q);
__m256i m = _mm256_set1_epi32(BARR_epi32);
__m256i tmp1 = _mm256_mul_epi32(a, m); // 64-bit products of the even lanes
tmp1 = _mm256_srli_epi64(tmp1, 32); // an arithmetic shift would be natural; the logical one zeroes the upper half for the OR below, and the low 32 bits (the signed quotient) are the same either way
__m256i tmp2 = _mm256_shuffle_epi32(a, 0b10110001); // swap adjacent lanes: odd lanes move to even positions
tmp2 = _mm256_mul_epi32(tmp2, m);
tmp2 = _mm256_srli_epi64(tmp2, 32);
tmp2 = _mm256_shuffle_epi32(tmp2, 0b10110001); // move the odd-lane quotients back to odd positions
tmp1 = _mm256_or_si256(tmp1, tmp2); // recombine into 8 epi32 quotients
tmp1 = _mm256_mullo_epi32(tmp1, q);
tmp1 = _mm256_sub_epi32(a, tmp1);
return tmp1; // per the original author, the range wanders roughly over [-q, 2q)
}
/* Lane-wise, branch-free: a < 0 ? a + q : a for 16 signed 16-bit lanes. */
__m256i iflt0_addq(__m256i a) {
    const __m256i neg_mask = _mm256_srai_epi16(a, 15); // all ones in negative lanes
    return _mm256_add_epi16(a, _mm256_and_si256(neg_mask, _mm256_set1_epi16(NTT_Q)));
}
/* Lane-wise, branch-free: a >= q ? a - q : a, decided by the sign of (q - 1 - a).
 * Used to map lanes from [0,2q) into [0,q). */
__m256i ifgeq_subq(__m256i a) {
    const __m256i modulus = _mm256_set1_epi16(NTT_Q);
    const __m256i ge_mask = _mm256_srai_epi16(_mm256_sub_epi16(_mm256_set1_epi16(NTT_Q - 1), a), 15);
    return _mm256_sub_epi16(a, _mm256_and_si256(ge_mask, modulus));
}
/* Lane-wise, branch-free: a >= 2q ? a - 2q : a, decided by the sign of (2q - 1 - a). */
__m256i ifge2q_sub2q(__m256i a) {
    const __m256i twoq = _mm256_set1_epi16(2 * NTT_Q);
    const __m256i ge_mask = _mm256_srai_epi16(_mm256_sub_epi16(_mm256_set1_epi16(2 * NTT_Q - 1), a), 15);
    return _mm256_sub_epi16(a, _mm256_and_si256(ge_mask, twoq));
}
//################################### NTT变换 ###################################
/*
 * In-place forward NTT of f[0..NTT_N) with AVX2, 16 coefficients per YMM register.
 *
 * Layers are processed two at a time (radix-4); when NTT_ROUND is odd a single
 * radix-2 layer runs first. Each radix-4 round dispatches on Blocksize:
 *   - Blocksize >= 64: plain vertical butterflies across four registers;
 *   - Blocksize 32/16/8/4: the offsetK / offsetK_inv lane shuffles line up each
 *     butterfly pair in the same lane of two registers.
 * A final pass maps every coefficient from [0,2q) to [0,q).
 * Needs the global zetas_mont / bitrev_list tables to be filled in.
 *
 * Radix-2 fix: the root for that layer (Blocknum == 1) is now
 * zetas_mont[bitrev_list[NTT_NEG] >> 1], the same rule the radix-4 paths use.
 * Previously the bitrev index itself was broadcast as if it were the twiddle.
 * Its input is also reduced to [0,q) first, as in the radix-4 paths.
 */
void ntt(int16* f) {
    int32 Blocknum = 1;
    int32 Blocksize = NTT_N;
    int32 Round = 0;
    __m256i T, Q = _mm256_set1_epi16(NTT_Q);
    /*
    Radix-2 (only when NTT_ROUND is odd)
    X = X + WY
    Y = X - WY
    */
    if ((NTT_ROUND & 1) == 1) {
        int32 offset = Blocksize >> 1;
        __m256i W = _mm256_set1_epi16(zetas_mont[bitrev_list[Blocknum * NTT_NEG] >> 1]);
        for (int32 k = 0; k < offset; k += 16) {
            __m256i X = _mm256_loadu_si256((__m256i*)(f + k));
            __m256i Y = _mm256_loadu_si256((__m256i*)(f + k + offset));
            X = ifgeq_subq(X); // [0,2q) -> [0,q) so that X + T stays below 2q
            T = montgomery_reduce_epi16(Y, W);
            Y = _mm256_add_epi16(X, Q);
            Y = _mm256_sub_epi16(Y, T);
            X = _mm256_add_epi16(X, T);
            _mm256_storeu_si256((__m256i*)(f + k), X);
            _mm256_storeu_si256((__m256i*)(f + k + offset), Y);
        }
        Blocknum <<= 1;
        Blocksize >>= 1;
        Round++;
    }
    /*
    Radix-4 (Harvey butterflies), input/output range [0,2q)
    X1 = (X1 + W*Y1) + W0*(X2 + W*Y2), range [0,4q)
    X2 = (X1 + W*Y1) - W0*(X2 + W*Y2), range [0,4q)
    Y1 = (X1 - W*Y1) + W1*(X2 - W*Y2), range [0,4q)
    Y2 = (X1 - W*Y1) - W1*(X2 - W*Y2), range [0,4q)
    First X1 is reduced to [0,q), then (X1 + W*Y1) and (X1 - W*Y1) are reduced
    to [0,q): three conditional subtractions in total.
    */
    for (; Round < NTT_ROUND; Round += 2, Blocksize >>= 2, Blocknum <<= 2) {
        if (Blocksize >= 64)
            goto Block64;
        else
            switch (Blocksize)
            {
            case 32: goto Block32;
            case 16: goto Block16;
            case 8: goto Block8;
            case 4: goto Block4;
            default:
                goto Error; // only power-of-two NTT_N is supported
            }
    Block64: // block size divisible by 64: 4 YMM registers, one block at a time
        for (int32 i = 0; i < Blocknum; i++) {
            int32 offset = Blocksize >> 2;
            int32 num = offset >> 4; // 16 coefficients per YMM
            int16* pf = f + i * Blocksize;
            // Roots: block i of this layer, then blocks 2i and 2i+1 of the next layer.
            int32 ind = Blocknum * NTT_NEG + i;
            __m256i W = _mm256_set1_epi16(zetas_mont[bitrev_list[ind] >> 1]);
            __m256i W0 = _mm256_set1_epi16(zetas_mont[bitrev_list[ind * 2] >> 1]);
            __m256i W1 = _mm256_set1_epi16(zetas_mont[bitrev_list[ind * 2 + 1] >> 1]);
            for (int32 k = 0; k < num; k++) {
                __m256i X1 = _mm256_loadu_si256((__m256i*)(pf + k * 16));
                __m256i X2 = _mm256_loadu_si256((__m256i*)(pf + k * 16 + offset));
                __m256i Y1 = _mm256_loadu_si256((__m256i*)(pf + k * 16 + offset * 2));
                __m256i Y2 = _mm256_loadu_si256((__m256i*)(pf + k * 16 + offset * 3));
                X1 = ifgeq_subq(X1);
                T = montgomery_reduce_epi16(Y1, W);
                Y1 = _mm256_add_epi16(X1, Q);
                Y1 = _mm256_sub_epi16(Y1, T);
                X1 = _mm256_add_epi16(X1, T);
                T = montgomery_reduce_epi16(Y2, W);
                Y2 = _mm256_add_epi16(X2, Q);
                Y2 = _mm256_sub_epi16(Y2, T);
                X2 = _mm256_add_epi16(X2, T);
                X1 = ifgeq_subq(X1);
                T = montgomery_reduce_epi16(X2, W0);
                X2 = _mm256_add_epi16(X1, Q);
                X2 = _mm256_sub_epi16(X2, T);
                X1 = _mm256_add_epi16(X1, T);
                Y1 = ifgeq_subq(Y1);
                T = montgomery_reduce_epi16(Y2, W1);
                Y2 = _mm256_add_epi16(Y1, Q);
                Y2 = _mm256_sub_epi16(Y2, T);
                Y1 = _mm256_add_epi16(Y1, T);
                _mm256_storeu_si256((__m256i*)(pf + k * 16), X1);
                _mm256_storeu_si256((__m256i*)(pf + k * 16 + offset), X2);
                _mm256_storeu_si256((__m256i*)(pf + k * 16 + offset * 2), Y1);
                _mm256_storeu_si256((__m256i*)(pf + k * 16 + offset * 3), Y2);
            }
        }
        continue;
    Block32: // block size 32: 2 YMM registers, one block at a time
        for (int32 i = 0; i < Blocknum; i++) {
            int16* pf = f + i * Blocksize;
            __m256i X = _mm256_loadu_si256((__m256i*)(pf));
            __m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));
            int32 ind = Blocknum * NTT_NEG + i;
            int16 w = zetas_mont[bitrev_list[ind] >> 1];
            __m256i W = _mm256_set1_epi16(w);
            X = ifgeq_subq(X);
            T = montgomery_reduce_epi16(Y, W);
            Y = _mm256_add_epi16(X, Q);
            Y = _mm256_sub_epi16(Y, T);
            X = _mm256_add_epi16(X, T);
            ind <<= 1;
            int16 w0 = zetas_mont[bitrev_list[ind] >> 1];
            int16 w1 = zetas_mont[bitrev_list[ind + 1] >> 1];
            W = _mm256_setr_epi16(w0, w0, w0, w0, w0, w0, w0, w0, w1, w1, w1, w1, w1, w1, w1, w1);
            offset8(X, Y);
            X = ifgeq_subq(X);
            T = montgomery_reduce_epi16(Y, W);
            Y = _mm256_add_epi16(X, Q);
            Y = _mm256_sub_epi16(Y, T);
            X = _mm256_add_epi16(X, T);
            offset8_inv(X, Y);
            _mm256_storeu_si256((__m256i*)(pf), X);
            _mm256_storeu_si256((__m256i*)(pf + 16), Y);
        }
        continue;
    Block16: // block size 16: 2 YMM registers, two blocks at a time
        for (int32 i = 0; i < Blocknum; i += 2) {
            int16* pf = f + i * Blocksize;
            __m256i X = _mm256_loadu_si256((__m256i*)(pf));
            __m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));
            int32 ind = Blocknum * NTT_NEG + i;
            int16 w0 = zetas_mont[bitrev_list[ind] >> 1];
            int16 w1 = zetas_mont[bitrev_list[ind + 1] >> 1];
            __m256i W = _mm256_setr_epi16(w0, w0, w0, w0, w0, w0, w0, w0, w1, w1, w1, w1, w1, w1, w1, w1);
            offset8(X, Y);
            X = ifgeq_subq(X);
            T = montgomery_reduce_epi16(Y, W);
            Y = _mm256_add_epi16(X, Q);
            Y = _mm256_sub_epi16(Y, T);
            X = _mm256_add_epi16(X, T);
            offset8_inv(X, Y);
            ind <<= 1;
            int16 w00 = zetas_mont[bitrev_list[ind] >> 1];
            int16 w01 = zetas_mont[bitrev_list[ind + 1] >> 1];
            int16 w10 = zetas_mont[bitrev_list[ind + 2] >> 1];
            int16 w11 = zetas_mont[bitrev_list[ind + 3] >> 1];
            W = _mm256_setr_epi16(w00, w00, w00, w00, w01, w01, w01, w01, w10, w10, w10, w10, w11, w11, w11, w11);
            offset4(X, Y);
            X = ifgeq_subq(X);
            T = montgomery_reduce_epi16(Y, W);
            Y = _mm256_add_epi16(X, Q);
            Y = _mm256_sub_epi16(Y, T);
            X = _mm256_add_epi16(X, T);
            offset4_inv(X, Y);
            _mm256_storeu_si256((__m256i*)(pf), X);
            _mm256_storeu_si256((__m256i*)(pf + 16), Y);
        }
        continue;
    Block8: // block size 8: 2 YMM registers, four blocks at a time
        for (int32 i = 0; i < Blocknum; i += 4) {
            int16* pf = f + i * Blocksize;
            __m256i X = _mm256_loadu_si256((__m256i*)(pf));
            __m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));
            int32 ind = Blocknum * NTT_NEG + i;
            int16 w0 = zetas_mont[bitrev_list[ind] >> 1];
            int16 w1 = zetas_mont[bitrev_list[ind + 1] >> 1];
            int16 w2 = zetas_mont[bitrev_list[ind + 2] >> 1];
            int16 w3 = zetas_mont[bitrev_list[ind + 3] >> 1];
            __m256i W = _mm256_setr_epi16(w0, w0, w0, w0, w1, w1, w1, w1, w2, w2, w2, w2, w3, w3, w3, w3);
            offset4(X, Y);
            X = ifgeq_subq(X);
            T = montgomery_reduce_epi16(Y, W);
            Y = _mm256_add_epi16(X, Q);
            Y = _mm256_sub_epi16(Y, T);
            X = _mm256_add_epi16(X, T);
            offset4_inv(X, Y);
            ind <<= 1;
            int16 w00 = zetas_mont[bitrev_list[ind] >> 1];
            int16 w01 = zetas_mont[bitrev_list[ind + 1] >> 1];
            int16 w10 = zetas_mont[bitrev_list[ind + 2] >> 1];
            int16 w11 = zetas_mont[bitrev_list[ind + 3] >> 1];
            int16 w20 = zetas_mont[bitrev_list[ind + 4] >> 1];
            int16 w21 = zetas_mont[bitrev_list[ind + 5] >> 1];
            int16 w30 = zetas_mont[bitrev_list[ind + 6] >> 1];
            int16 w31 = zetas_mont[bitrev_list[ind + 7] >> 1];
            W = _mm256_setr_epi16(w00, w00, w01, w01, w10, w10, w11, w11, w20, w20, w21, w21, w30, w30, w31, w31);
            offset2(X, Y);
            X = ifgeq_subq(X);
            T = montgomery_reduce_epi16(Y, W);
            Y = _mm256_add_epi16(X, Q);
            Y = _mm256_sub_epi16(Y, T);
            X = _mm256_add_epi16(X, T);
            offset2_inv(X, Y);
            _mm256_storeu_si256((__m256i*)(pf), X);
            _mm256_storeu_si256((__m256i*)(pf + 16), Y);
        }
        continue;
    Block4: // block size 4: 2 YMM registers, eight blocks at a time
        for (int32 i = 0; i < Blocknum; i += 8) {
            int16* pf = f + i * Blocksize;
            __m256i X = _mm256_loadu_si256((__m256i*)(pf));
            __m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));
            int32 ind = Blocknum * NTT_NEG + i;
            int16 w0 = zetas_mont[bitrev_list[ind] >> 1];
            int16 w1 = zetas_mont[bitrev_list[ind + 1] >> 1];
            int16 w2 = zetas_mont[bitrev_list[ind + 2] >> 1];
            int16 w3 = zetas_mont[bitrev_list[ind + 3] >> 1];
            int16 w4 = zetas_mont[bitrev_list[ind + 4] >> 1];
            int16 w5 = zetas_mont[bitrev_list[ind + 5] >> 1];
            int16 w6 = zetas_mont[bitrev_list[ind + 6] >> 1];
            int16 w7 = zetas_mont[bitrev_list[ind + 7] >> 1];
            __m256i W = _mm256_setr_epi16(w0, w0, w1, w1, w2, w2, w3, w3, w4, w4, w5, w5, w6, w6, w7, w7);
            offset2(X, Y);
            X = ifgeq_subq(X);
            T = montgomery_reduce_epi16(Y, W);
            Y = _mm256_add_epi16(X, Q);
            Y = _mm256_sub_epi16(Y, T);
            X = _mm256_add_epi16(X, T);
            offset2_inv(X, Y);
            ind <<= 1;
            int16 w00 = zetas_mont[bitrev_list[ind] >> 1];
            int16 w01 = zetas_mont[bitrev_list[ind + 1] >> 1];
            int16 w10 = zetas_mont[bitrev_list[ind + 2] >> 1];
            int16 w11 = zetas_mont[bitrev_list[ind + 3] >> 1];
            int16 w20 = zetas_mont[bitrev_list[ind + 4] >> 1];
            int16 w21 = zetas_mont[bitrev_list[ind + 5] >> 1];
            int16 w30 = zetas_mont[bitrev_list[ind + 6] >> 1];
            int16 w31 = zetas_mont[bitrev_list[ind + 7] >> 1];
            int16 w40 = zetas_mont[bitrev_list[ind + 8] >> 1];
            int16 w41 = zetas_mont[bitrev_list[ind + 9] >> 1];
            int16 w50 = zetas_mont[bitrev_list[ind + 10] >> 1];
            int16 w51 = zetas_mont[bitrev_list[ind + 11] >> 1];
            int16 w60 = zetas_mont[bitrev_list[ind + 12] >> 1];
            int16 w61 = zetas_mont[bitrev_list[ind + 13] >> 1];
            int16 w70 = zetas_mont[bitrev_list[ind + 14] >> 1];
            int16 w71 = zetas_mont[bitrev_list[ind + 15] >> 1];
            W = _mm256_setr_epi16(w00, w01, w10, w11, w20, w21, w30, w31, w40, w41, w50, w51, w60, w61, w70, w71);
            offset1(X, Y);
            X = ifgeq_subq(X);
            T = montgomery_reduce_epi16(Y, W);
            Y = _mm256_add_epi16(X, Q);
            Y = _mm256_sub_epi16(Y, T);
            X = _mm256_add_epi16(X, T);
            offset1_inv(X, Y);
            _mm256_storeu_si256((__m256i*)(pf), X);
            _mm256_storeu_si256((__m256i*)(pf + 16), Y);
        }
        continue;
    Error: // unsupported block size
        printf("Blocksize isn't power of 2.\n");
    }
    for (int32 k = 0; k < NTT_N; k += 16) {
        __m256i X = _mm256_loadu_si256((__m256i*)(f + k)); // reduce from [0,2q) to [0,q)
        X = ifgeq_subq(X);
        _mm256_storeu_si256((__m256i*)(f + k), X);
    }
}
void intt(int16* f, int8 mont) {
int32 Blocknum = 1 << NTT_ROUND;
int32 Blocksize = NTT_N >> NTT_ROUND;
int32 Round = NTT_ROUND;
int32 Qtimes2 = NTT_Q * 2;
Blocksize <<= 2;
Blocknum >>= 2;
__m256i T, Q = _mm256_set1_epi16(NTT_Q);
/*
Radix-4
Harvey,输入输出范围[0,2q)
X1 = (X1 + X2) + (Y1 + Y2),范围[0,8q)
X2 = IW0*(X1 - X2) + IW1*(Y1 - Y2),范围[0,2q)
Y1 = IW*((X1 + X2) + (Y1 + Y2)),范围[0,q)
Y2 = IW*(IW0*(X1 - X2) + IW1*(Y1 - Y2)),范围[0,q)
先约束(X1 + X2)和(Y1 + Y2)范围[0,2q),接着约束(X1 + X2) + (Y1 + Y2)范围[0,2q),共三次模约减
*/
for (; Round > 1; Round -= 2, Blocksize <<= 2, Blocknum >>= 2) {
if (Blocksize >= 64)
goto Block64;
else
switch (Blocksize)
{
case 32: goto Block32;
case 16: goto Block16;
case 8: goto Block8;
case 4: goto Block4;
default:
goto Error; //本代码仅处理:NTT_N 是2的幂次
}
Block64: //处理分块大小整除64的情况,使用4个YMM,处理1个块
for (int32 i = 0; i < Blocknum; i++) {
int32 offset = Blocksize >> 2;
int32 num = offset >> 4; //16个系数1个YMM
int16* pf = f + i * Blocksize;
/*
j=0是原始数组,第j次迭代中,j-1层第i个分块使用的单位根,
w_{2^{j}}^{brv_{j}(2i)} = w_{2^{r}}^{2^{r-j}*brv_{j}(2i)}
brv_{j}(2i) = brv_{r}/(r-j+1)
因此 w_{2^{j}}^{brv_{j}(2i)} = w_{2^{r}}^{brv_{r}(i)/2}
*/
int32 ind = Blocknum * NTT_NEG + i;
__m256i W = _mm256_set1_epi16(zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)]); //Round层第i块
__m256i W0 = _mm256_set1_epi16(zetas_mont[NTT_ORDER - (bitrev_list[ind * 2] >> 1)]); //Round+1层第2i块
__m256i W1 = _mm256_set1_epi16(zetas_mont[NTT_ORDER - (bitrev_list[ind * 2 + 1] >> 1)]); //Round+1层第2i+1块
for (int32 k = 0; k < num; k++) {
__m256i X1 = _mm256_loadu_si256((__m256i*)(pf + k * 16));
__m256i X2 = _mm256_loadu_si256((__m256i*)(pf + k * 16 + offset));
__m256i Y1 = _mm256_loadu_si256((__m256i*)(pf + k * 16 + offset * 2));
__m256i Y2 = _mm256_loadu_si256((__m256i*)(pf + k * 16 + offset * 3));
T = _mm256_sub_epi16(X1, X2);
X1 = _mm256_add_epi16(X1, X2);
X2 = montgomery_reduce_epi16(T, W0);
X1 = ifgeq_subq(X1);
T = _mm256_sub_epi16(Y1, Y2);
Y1 = _mm256_add_epi16(Y1, Y2);
Y2 = montgomery_reduce_epi16(T, W1);
Y1 = ifgeq_subq(Y1);
T = _mm256_sub_epi16(X1, Y1);
X1 = _mm256_add_epi16(X1, Y1);
Y1 = montgomery_reduce_epi16(T, W);
X1 = ifgeq_subq(X1);
T = _mm256_sub_epi16(X2, Y2);
X2 = _mm256_add_epi16(X2, Y2);
Y2 = montgomery_reduce_epi16(T, W);
X2 = ifgeq_subq(X2);
_mm256_storeu_si256((__m256i*)(pf + k * 16), X1);
_mm256_storeu_si256((__m256i*)(pf + k * 16 + offset), X2);
_mm256_storeu_si256((__m256i*)(pf + k * 16 + offset * 2), Y1);
_mm256_storeu_si256((__m256i*)(pf + k * 16 + offset * 3), Y2);
}
}
continue;
Block32: //处理分块大小为32的情况,使用2个YMM,处理1个块
for (int32 i = 0; i < Blocknum; i++) {
int16* pf = f + i * Blocksize;
__m256i X = _mm256_loadu_si256((__m256i*)(pf));
__m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));
int32 ind = (Blocknum * NTT_NEG + i) * 2;
int16 w0 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];
int16 w1 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];
__m256i W = _mm256_setr_epi16(w0, w0, w0, w0, w0, w0, w0, w0, w1, w1, w1, w1, w1, w1, w1, w1);
offset8(X, Y);
T = _mm256_sub_epi16(X, Y);
X = _mm256_add_epi16(X, Y);
Y = montgomery_reduce_epi16(T, W);
X = ifgeq_subq(X);
offset8_inv(X, Y);
ind >>= 1;
int16 w = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];
W = _mm256_set1_epi16(w);
T = _mm256_sub_epi16(X, Y);
X = _mm256_add_epi16(X, Y);
Y = montgomery_reduce_epi16(T, W);
X = ifgeq_subq(X);
_mm256_storeu_si256((__m256i*)(pf), X);
_mm256_storeu_si256((__m256i*)(pf + 16), Y);
}
continue;
Block16: //处理分块大小为16的情况,使用2个YMM,处理2个块
for (int32 i = 0; i < Blocknum; i += 2) {
int16* pf = f + i * Blocksize;
__m256i X = _mm256_loadu_si256((__m256i*)(pf));
__m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));
int32 ind = (Blocknum * NTT_NEG + i) * 2;
int16 w00 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];
int16 w01 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];
int16 w10 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 2] >> 1)];
int16 w11 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 3] >> 1)];
__m256i W = _mm256_setr_epi16(w00, w00, w00, w00, w01, w01, w01, w01, w10, w10, w10, w10, w11, w11, w11, w11);
offset4(X, Y);
T = _mm256_sub_epi16(X, Y);
X = _mm256_add_epi16(X, Y);
Y = montgomery_reduce_epi16(T, W);
X = ifgeq_subq(X);
offset4_inv(X, Y);
ind >>= 1;
int16 w0 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];
int16 w1 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];
W = _mm256_setr_epi16(w0, w0, w0, w0, w0, w0, w0, w0, w1, w1, w1, w1, w1, w1, w1, w1);
offset8(X, Y);
T = _mm256_sub_epi16(X, Y);
X = _mm256_add_epi16(X, Y);
Y = montgomery_reduce_epi16(T, W);
X = ifgeq_subq(X);
offset8_inv(X, Y);
_mm256_storeu_si256((__m256i*)(pf), X);
_mm256_storeu_si256((__m256i*)(pf + 16), Y);
}
continue;
Block8: //处理分块大小为8的情况,使用2个YMM,处理4个块
for (int32 i = 0; i < Blocknum; i += 4) {
int16* pf = f + i * Blocksize;
__m256i X = _mm256_loadu_si256((__m256i*)(pf));
__m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));
int32 ind = (Blocknum * NTT_NEG + i) * 2;
int16 w00 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];
int16 w01 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];
int16 w10 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 2] >> 1)];
int16 w11 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 3] >> 1)];
int16 w20 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 4] >> 1)];
int16 w21 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 5] >> 1)];
int16 w30 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 6] >> 1)];
int16 w31 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 7] >> 1)];
__m256i W = _mm256_setr_epi16(w00, w00, w01, w01, w10, w10, w11, w11, w20, w20, w21, w21, w30, w30, w31, w31);
offset2(X, Y);
T = _mm256_sub_epi16(X, Y);
X = _mm256_add_epi16(X, Y);
Y = montgomery_reduce_epi16(T, W);
X = ifgeq_subq(X);
offset2_inv(X, Y);
ind >>= 1;
int16 w0 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];
int16 w1 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];
int16 w2 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 2] >> 1)];
int16 w3 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 3] >> 1)];
W = _mm256_setr_epi16(w0, w0, w0, w0, w1, w1, w1, w1, w2, w2, w2, w2, w3, w3, w3, w3);
offset4(X, Y);
T = _mm256_sub_epi16(X, Y);
X = _mm256_add_epi16(X, Y);
Y = montgomery_reduce_epi16(T, W);
X = ifgeq_subq(X);
offset4_inv(X, Y);
_mm256_storeu_si256((__m256i*)(pf), X);
_mm256_storeu_si256((__m256i*)(pf + 16), Y);
}
continue;
Block4: //处理分块大小为4的情况,使用2个YMM,处理8个块
for (int32 i = 0; i < Blocknum; i += 8) {
int16* pf = f + i * Blocksize;
__m256i X = _mm256_loadu_si256((__m256i*)(pf));
__m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));
int32 ind = (Blocknum * NTT_NEG + i) * 2;
int16 w00 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];
int16 w01 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];
int16 w10 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 2] >> 1)];
int16 w11 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 3] >> 1)];
int16 w20 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 4] >> 1)];
int16 w21 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 5] >> 1)];
int16 w30 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 6] >> 1)];
int16 w31 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 7] >> 1)];
int16 w40 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 8] >> 1)];
int16 w41 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 9] >> 1)];
int16 w50 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 10] >> 1)];
int16 w51 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 11] >> 1)];
int16 w60 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 12] >> 1)];
int16 w61 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 13] >> 1)];
int16 w70 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 14] >> 1)];
int16 w71 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 15] >> 1)];
__m256i W = _mm256_setr_epi16(w00, w01, w10, w11, w20, w21, w30, w31, w40, w41, w50, w51, w60, w61, w70, w71);
offset1(X, Y);
T = _mm256_sub_epi16(X, Y);
X = _mm256_add_epi16(X, Y);
Y = montgomery_reduce_epi16(T, W);
X = ifgeq_subq(X);
offset1_inv(X, Y);
ind >>= 1;
int16 w0 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];
int16 w1 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];
int16 w2 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 2] >> 1)];
int16 w3 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 3] >> 1)];
int16 w4 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 4] >> 1)];
int16 w5 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 5] >> 1)];
int16 w6 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 6] >> 1)];
int16 w7 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 7] >> 1)];
W = _mm256_setr_epi16(w0, w0, w1, w1, w2, w2, w3, w3, w4, w4, w5, w5, w6, w6, w7, w7);
offset2(X, Y);
T = _mm256_sub_epi16(X, Y);
X = _mm256_add_epi16(X, Y);
Y = montgomery_reduce_epi16(T, W);
X = ifgeq_subq(X);
offset2_inv(X, Y);
_mm256_storeu_si256((__m256i*)(pf), X);
_mm256_storeu_si256((__m256i*)(pf + 16), Y);
}
continue;
Error: //捕获块大小错误
printf("Blocksize isn't power of 2.
");
}
/*
Radix-2
X = X + Y
Y = IW*(X - Y)
*/
if ((NTT_ROUND & 1) == 1) {
int32 offset = Blocksize >> 1;
int32 num = offset >> 4; //16个系数1个YMM
__m256i W = _mm256_set1_epi16(zetas_mont[NTT_ORDER - (bitrev_list[Blocknum * NTT_NEG] >> 1)]);
for (int32 k = 0; k < offset; k += 16) {
__m256i X = _mm256_loadu_si256((__m256i*)(f + k));
__m256i Y = _mm256_loadu_si256((__m256i*)(f + k + offset));
T = _mm256_sub_epi16(X, Y);
X = _mm256_add_epi16(X, Y);
Y = montgomery_reduce_epi16(T, W);
_mm256_storeu_si256((__m256i*)(f + k), X);
_mm256_storeu_si256((__m256i*)(f + k + offset), Y);
}
}
//逆变换因子
__m256i F = _mm256_set1_epi16(factor_mont);
if(mont != 0)
F = _mm256_set1_epi16(factor_mont2); //执行了 montgomery 版本的 nttmul,需额外再乘一个 mont = R mod q
for (int32 k = 0; k < NTT_N; k += 16) {
__m256i X = _mm256_loadu_si256((__m256i*)(f + k));
X = montgomery_reduce_epi16(X, F);
_mm256_storeu_si256((__m256i*)(f + k), X);
}
}
/*
 * Schoolbook product of two length-NTT_BASELEN base polynomials modulo
 * (x^NTT_BASELEN - zeta), using Montgomery reduction.
 * zeta_mont is zeta in Montgomery form; the result is r = a*b*R^{-1}
 * (R = MONT_R), which intt(..., mont != 0) later compensates for.
 *
 * Declared static: a plain (non-static, non-extern) C99 `inline` definition
 * provides no external symbol, so a call the compiler does not inline (e.g. at
 * -O0) fails to link.
 *
 * r may alias a or b: every coefficient reads all of a and b, so results are
 * collected in a local buffer and copied out only at the end.
 */
static inline void basemul_mont(int16* r, const int16* a, const int16* b, int16 zeta_mont)
{
    int16 out[NTT_BASELEN];
    for (int32 i = 0; i < NTT_BASELEN; i++)
    {
        int32 res = 0; /* wide accumulator: defer the modular reduction */
        int32 s = NTT_BASELEN + i;
        for (int32 j = 0; j <= i; j++)
            res += b[j] * a[i - j];
        /* Terms of degree >= NTT_BASELEN wrap around multiplied by zeta. */
        for (int32 j = i + 1; j < NTT_BASELEN; j++)
            res += montgomery_reduce(b[j] * zeta_mont) * a[s - j];
        out[i] = montgomery_reduce(res); /* r = a*b/R */
    }
    for (int32 i = 0; i < NTT_BASELEN; i++)
        r[i] = out[i];
}
/*
 * Base-case multiplication in the NTT domain with Montgomery arithmetic.
 * The NTT_N coefficients form 2^NTT_ROUND base polynomials of NTT_BASELEN
 * coefficients each. The products carry an extra R^{-1} (R = MONT_R), which
 * intt(..., mont != 0) compensates for.
 */
void nttmul_mont(int16* r, const int16* a, const int16* b)
{
    // 2^{r-1} base polynomials of length n/2^{r-1}, NTT_ROUND = r-1
    int32 num = 1 << NTT_ROUND;
#if (NTT_BASELEN == 1) // AVX2: 16 coefficients per YMM
    for (int32 i = 0; i < num; i += 16) {
        /* The load/store intrinsics take __m256i pointers. Passing int16*
           directly is an incompatible-pointer-type constraint violation
           (an error by default since GCC 14), so cast explicitly. */
        __m256i X = _mm256_loadu_si256((const __m256i*)a);
        __m256i Y = _mm256_loadu_si256((const __m256i*)b);
        X = montgomery_reduce_epi16(X, Y);
        _mm256_storeu_si256((__m256i*)r, X);
        r += 16;
        a += 16;
        b += 16;
    }
#elif (NTT_BASELEN == 2) // scalar: products modulo (x^2 - zeta)
    for (int32 i = 0; i < num; i++) {
        // Root of the last layer used by base polynomial 2^{r-1}+i:
        // w_{2^r}^{brv_r(2^{r-1}+i)}, NTT_ROUND = r-1
        int32 zeta = zetas_mont[bitrev_list[num * NTT_NEG + i]];
        int32 tmp0 = a[0] * b[0] + montgomery_reduce(zeta * a[1]) * b[1];
        int32 tmp1 = a[0] * b[1] + a[1] * b[0];
        // Both results are computed before writing, so r may alias a or b.
        r[0] = montgomery_reduce(tmp0);
        r[1] = montgomery_reduce(tmp1);
        r += NTT_BASELEN;
        a += NTT_BASELEN;
        b += NTT_BASELEN;
    }
#else // scalar: general schoolbook base multiplication
    for (int32 i = 0; i < num; i++) {
        int32 zeta = zetas_mont[bitrev_list[num * NTT_NEG + i]];
        basemul_mont(r, a, b, zeta);
        r += NTT_BASELEN;
        a += NTT_BASELEN;
        b += NTT_BASELEN;
    }
#endif
}
/*
 * Schoolbook product of two length-NTT_BASELEN base polynomials modulo
 * (x^NTT_BASELEN - zeta). Coefficients are reduced with barrett_reduce.
 *
 * Declared static: a plain (non-static, non-extern) C99 `inline` definition
 * provides no external symbol, so a call the compiler does not inline (e.g. at
 * -O0) fails to link.
 *
 * r may alias a or b: every coefficient reads all of a and b, so results are
 * collected in a local buffer and copied out only at the end.
 */
static inline void basemul(int16* r, const int16* a, const int16* b, int16 zeta)
{
    int16 out[NTT_BASELEN];
    for (int32 i = 0; i < NTT_BASELEN; i++)
    {
        int32 res = 0; /* wide accumulator: defer the modular reduction */
        int32 s = NTT_BASELEN + i;
        for (int32 j = 0; j <= i; j++)
            res += b[j] * a[i - j];
        /* Terms of degree >= NTT_BASELEN wrap around multiplied by zeta. */
        for (int32 j = i + 1; j < NTT_BASELEN; j++)
            res += zeta * barrett_reduce(b[j] * a[s - j]);
        out[i] = barrett_reduce(res);
    }
    for (int32 i = 0; i < NTT_BASELEN; i++)
        r[i] = out[i];
}
/*
 * Base-case multiplication of two NTT-domain polynomials, r = a (*) b.
 * The NTT_N coefficients form 2^NTT_ROUND base polynomials of NTT_BASELEN
 * coefficients. Base polynomial blk is multiplied modulo
 * (x^NTT_BASELEN - zeta_blk), with zeta_blk = zetas[bitrev_list[2^NTT_ROUND * NTT_NEG + blk]].
 *
 * mont != 0 : delegate to nttmul_mont (Montgomery arithmetic).
 * mont == 0 : plain products with Barrett reduction.
 */
void nttmul(int16* r, const int16* a, const int16* b, int8 mont)
{
    if (mont != 0) {
        nttmul_mont(r, a, b);
        return;
    }

    /* 2^{r-1} base polynomials, NTT_ROUND = r-1 */
    const int32 count = 1 << NTT_ROUND;
    for (int32 blk = 0; blk < count; blk++) {
#if (NTT_BASELEN == 1)
        const int32 prod = a[0] * b[0];
        r[0] = barrett_reduce(prod);
#elif (NTT_BASELEN == 2)
        /* Root of the last layer for base polynomial 2^{r-1}+blk:
           w_{2^r}^{brv_r(2^{r-1}+blk)}, NTT_ROUND = r-1 */
        const int32 zeta = zetas[bitrev_list[count * NTT_NEG + blk]];
        const int32 even = a[0] * b[0] + zeta * barrett_reduce(a[1] * b[1]);
        const int32 odd = a[0] * b[1] + a[1] * b[0];
        r[0] = barrett_reduce(even);
        r[1] = barrett_reduce(odd);
#else
        /* Same root as above, passed to the general schoolbook kernel. */
        const int32 zeta = zetas[bitrev_list[count * NTT_NEG + blk]];
        basemul(r, a, b, zeta);
#endif
        r += NTT_BASELEN;
        a += NTT_BASELEN;
        b += NTT_BASELEN;
    }
}
Test
cputimer.h
#ifndef CPUTIMER
#define CPUTIMER
#if defined(__linux__)
// Linux系统
#include <unistd.h>
#elif defined(_WIN32)
// Windows系统
#include <intrin.h>
#include <windows.h>
#endif
/*
 * Suspend the calling thread for about `time` milliseconds.
 * Linux uses usleep (microseconds) and Windows uses Sleep (milliseconds).
 * On any other platform the call does nothing.
 */
void sleepms(int time) {
#if defined(__linux__)
    const int micros = time * 1000; /* usleep expects microseconds */
    usleep(micros);
#elif defined(_WIN32)
    Sleep(time); /* Sleep already takes milliseconds */
#endif
}
/*
 * Read the x86 time-stamp counter and return it as a 64-bit cycle count.
 * Windows builds use the __rdtsc intrinsic (from <intrin.h>). Other builds use
 * GNU inline assembly: RDTSC puts the low 32 bits in EAX and the high 32 bits
 * in EDX. Either way this only builds for x86/x86-64 targets.
 */
unsigned long long cputimer() {
#if _WIN32
    return __rdtsc();
#else
    unsigned int low, high;
    __asm__ volatile ("rdtsc" : "=a" (low), "=d" (high));
    const unsigned long long upper = (unsigned long long)high << 32;
    return upper | low;
#endif
}
//unsigned long long cputimer(); // 独立汇编代码
/*
align 16
_cputimer:
rdtsc
shl rdx, 32
or rax, rdx
ret
*/
/* Cycles per second as measured by the most recent GetFrequency() call. */
unsigned long long CPUFrequency;
/*
 * Estimate the time-stamp-counter frequency: count the cycles that elapse
 * across a ~1000 ms sleep. The estimate is stored in CPUFrequency and also
 * returned.
 */
unsigned long long GetFrequency() {
    const unsigned long long before = cputimer();
    sleepms(1000);
    CPUFrequency = cputimer() - before;
    return CPUFrequency;
}
/* Print a newline. */
#define pn printf("\n")

/* cputimer() readings at the start and end of the most recent timed run. */
unsigned long long TM_start, TM_end;

/*
 * Time one execution of `code` and print the elapsed cycles and seconds
 * (divides by CPUFrequency, so call GetFrequency() first). Expands to several
 * statements: brace it when used as the body of if/for/while.
 */
#define Timer(code) TM_start = cputimer(); code; TM_end = cputimer(); \
    printf("time = %llu cycles (%f s)\n", TM_end - TM_start, (double)(TM_end - TM_start) / CPUFrequency);

/* Per-run cycle counts recorded by Loop; holds at most 10000 samples. */
unsigned long long TM_mem[10000];

/*
 * Run `code` `loop` times (loop must not exceed 10000, the size of TM_mem).
 * Each run's cycle count goes into TM_mem, then Analyis_TM prints statistics.
 * The loop counter is named `i` and is visible to `code`.
 */
#define Loop(loop, code) for (int i = 0; i < (loop); i++) { \
    TM_start = cputimer(); code; TM_end = cputimer(); TM_mem[i] = TM_end - TM_start; } \
    Analyis_TM(loop);
/*
 * Sort arr[begin..end] (both bounds inclusive) in ascending order.
 *
 * Three-way (Dutch national flag) quicksort with a middle pivot. Elements equal
 * to the pivot are grouped in the middle and excluded from further work, so
 * the many duplicate cycle counts typical of timing samples cannot degrade it
 * to O(n^2). Recursing only into the smaller partition and looping on the
 * larger one bounds the recursion depth to O(log n).
 *
 * NOTE: identifiers starting with "__" are reserved in C; the name is kept only
 * for compatibility with existing callers.
 */
void __quick_sort(unsigned long long* arr, int begin, int end)
{
    while (begin < end) {
        unsigned long long pivot = arr[begin + (end - begin) / 2];
        int lt = begin, i = begin, gt = end;
        /* Invariant: arr[begin..lt-1] < pivot, arr[lt..i-1] == pivot,
           arr[gt+1..end] > pivot; arr[i..gt] is still unclassified. */
        while (i <= gt) {
            if (arr[i] < pivot) {
                unsigned long long tmp = arr[lt];
                arr[lt] = arr[i];
                arr[i] = tmp;
                lt++;
                i++;
            } else if (arr[i] > pivot) {
                unsigned long long tmp = arr[gt];
                arr[gt] = arr[i];
                arr[i] = tmp;
                gt--;
            } else {
                i++;
            }
        }
        /* Recurse into the smaller side, iterate on the larger one. */
        if (lt - begin < end - gt) {
            __quick_sort(arr, begin, lt - 1);
            begin = gt + 1;
        } else {
            __quick_sort(arr, gt + 1, end);
            end = lt - 1;
        }
    }
}
/* Sort the first `size` elements of arr in ascending order (no-op for size <= 1). */
void quick_sort(unsigned long long* arr, int size)
{
    __quick_sort(arr, 0, size - 1);
}
void Analyis_TM(int loop) //分析代码性能
{
unsigned long long min, max, med, aver = 0;
quick_sort(TM_mem, loop);
min = TM_mem[0];
max = TM_mem[loop-1];
med = TM_mem[loop >> 1];
for (int i = 0; i < loop; i++) {
aver += TM_mem[i];
}
aver /= loop;
printf("Time:
Minimum %10lld cycles,%10.6f ms
Maximum %10lld cycles,%10.6f ms
Median %10lld cycles,%10.6f ms
Average %10lld cycles,%10.6f ms
",
min, (double)min / CPUFrequency * 1000, max, (double)max / CPUFrequency * 1000, med, (double)med / CPUFrequency * 1000, aver, (double)aver / CPUFrequency * 1000);
}
#endif
Result
在 WSL 下用 gcc 编译:
gcc ntt_avx2.c test_ntt_avx2.c -o test_ntt_avx2 -O3 -fopt-info-vec-optimized -mavx2
执行 ./test_ntt_avx2 的结果为:
CPU Frequency = 2918844449
Time: ntt
Minimum 2588 cycles, 0.000887 ms
Maximum 39908 cycles, 0.013672 ms
Median 2636 cycles, 0.000903 ms
Average 2881 cycles, 0.000987 ms
Time: intt
Minimum 2546 cycles, 0.000872 ms
Maximum 29334 cycles, 0.010050 ms
Median 2586 cycles, 0.000886 ms
Average 2650 cycles, 0.000908 ms
Time: nttmul
Minimum 1828 cycles, 0.000626 ms
Maximum 91908 cycles, 0.031488 ms
Median 1838 cycles, 0.000630 ms
Average 1865 cycles, 0.000639 ms
Time: nttmul_mont
Minimum 164 cycles, 0.000056 ms
Maximum 7684 cycles, 0.002632 ms
Median 186 cycles, 0.000064 ms
Average 199 cycles, 0.000068 ms
风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。