// A simple ModInt implementation.
//
// Represents an integer modulo the compile-time constant M. The stored
// residue `v` is always kept in the range [0, M).
template <int M>
struct ModInt {
  static_assert(M > 0, "ModInt requires a positive modulus");

  // Tag type selecting the constructor that skips normalization.
  struct skip_mod {};

  // Trusted constructor: the caller guarantees 0 <= v < M.
  ModInt(int v, skip_mod) : v(v) {}

  int v;  // Residue in [0, M).

  ModInt() : v(0) {}

  // Initialization: find remainder.
  // Equivalent to: v = int((x % M + M) % M)
  ModInt(long long x) {
    x %= M;
    if (x < 0) x += M;  // C++ `%` keeps the sign of the dividend.
    v = int(x);
  }

  // Addition.
  // Mathematically equivalent to: ModInt((l.v + r.v) % M)
  friend ModInt operator+(ModInt l, ModInt r) {
    // l.v + r.v can exceed INT_MAX when M > 2^30, so subtract the complement
    // of r instead: the intermediate value stays within [-M, M).
    int res = l.v - (M - r.v);
    if (res < 0) res += M;
    return ModInt(res, skip_mod{});
  }

  // Subtraction.
  // Equivalent to: ModInt((l.v - r.v + M) % M)
  friend ModInt operator-(ModInt l, ModInt r) {
    int res = l.v - r.v;
    if (res < 0) res += M;
    return ModInt(res, skip_mod{});
  }

  // Multiplication. The product is formed in 64 bits before reduction.
  friend ModInt operator*(ModInt l, ModInt r) {
    return ModInt(int(1LL * l.v * r.v % M), skip_mod{});
  }

  // Exponentiation by repeated squaring.
  // The exponent b must be non-negative; the loop consumes b bit by bit and
  // only terminates once b reaches zero.
  ModInt pow(long long b) const {
    ModInt res(1);
    ModInt po(*this);
    for (; b; b >>= 1) {
      if (b & 1) res = res * po;
      po = po * po;
    }
    return res;
  }
};
在素性测试与质因数分解中,经常会遇到模数在 `long long` 范围内的乘法取模运算.为了避免运算中的整型溢出问题,本节介绍一种可以处理模数在 `long long` 范围内,不需要使用 `__int128` 且复杂度为 $O(1)$ 的「快速乘」.本算法要求测评系统中,`long double` 至少表示为 80 位扩展精度浮点数[^1].
假设 $0 \le a, b < m$,要计算 $ab \bmod m$.注意到:
$$ab \bmod m = ab - \left\lfloor \frac{ab}{m} \right\rfloor m.$$
利用 unsigned long long 的自然溢出:
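$$ab \bmod m = ab - \left\lfloor \frac{ab}{m} \right\rfloor m = \left( ab - \left\lfloor \frac{ab}{m} \right\rfloor m \right) \bmod 2^{64}.$$
只要能算出商 $\left\lfloor \frac{ab}{m} \right\rfloor$,最右侧表达式中的乘法和减法运算都可以使用 `unsigned long long` 直接计算.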
接下来,只需要考虑如何计算 $\left\lfloor \frac{ab}{m} \right\rfloor$.解决方案是先使用 `long double` 算出 $\frac{a}{m}$ 再乘上 $b$.既然使用了 `long double`,就无疑会有精度误差.假设 `long double` 表示为 80 位扩展精度浮点数(即符号为 1 位,指数为 15 位,尾数为 64 位),那么 `long double` 最多能精确表示的有效位数为 $64$[^2].所以 $\frac{a}{m}$ 最差从第 65 位开始出错,误差范围[^3]为 $\left(-2^{-64}, 2^{-64}\right)$.乘上 $b$ 这个 64 位带符号整数,误差范围为 $(-0.5, 0.5)$.为了简化后续讨论,可以先加一个 $0.5$ 再取整,最后的误差范围是 $\{0, 1\}$.
最后,代入上式计算时,需要乘以 $-m$,所以最后的误差范围是 $\{0, -m\}$.因为 $m$ 在 `long long` 范围内,所以当结果 $r \in [0, m)$ 时,直接返回 $r$,否则返回 $r + m$.
代码实现如下:
参考实现
// Computes (a * b) mod m for moduli in the long long range without __int128.
// Expects 0 <= a, b < m. The quotient estimate relies on long double
// carrying a 64-bit significand (80-bit extended precision or better).
long long mul(long long a, long long b, long long m) {
  // Estimate floor(a * b / m) in floating point; after adding 0.5 and
  // truncating, the estimate is either exact or one too large.
  const unsigned long long quotient = static_cast<unsigned long long>(
      static_cast<long double>(a) / m * b + 0.5L);

  // Both products wrap modulo 2^64; their difference is still the true
  // remainder, possibly shifted down by m.
  const unsigned long long product = static_cast<unsigned long long>(a) * b;
  long long remainder = product - quotient * m;

  if (remainder < 0) {
    remainder += m;
  }
  return remainder;
}
// Modular multiplication of int32_t using Barrett reduction.
//
// The division by m is replaced with a multiplication by the precomputed
// constant r = floor(2^64 / m) followed by a shift and one conditional
// subtraction. Relies on the compiler extension __int128.
// NOTE(review): assumes m is positive — confirm against callers.
class Barrett {
  int32_t m;   // The modulus.
  uint64_t r;  // floor(2^64 / m), truncated to 64 bits.

 public:
  // (uint64_t)(-m) equals 2^64 - m, so (2^64 - m) / m + 1 yields
  // floor(2^64 / m) without needing a 65-bit dividend.
  Barrett(int32_t m) : m(m), r((uint64_t)(-m) / m + 1) {}

  // Barrett reduction: a % m.
  // Intended for non-negative a, e.g. the product of two residues below m.
  int32_t reduce(int64_t a) const {
    // Approximate quotient; it is either floor(a / m) or one less.
    int64_t q = (__int128)a * r >> 64;
    a -= q * m;
    // A single correction step covers the case where q was one too small.
    return a >= m ? a - m : a;
  }

  // Modular multiplication: (a * b) % m;
  // Assume that 0 <= a, b < m.
  int32_t mul(int32_t a, int32_t b) const { return reduce((int64_t)a * b); }
};
// Montgomery modular multiplication.
// The modulus m must be odd. The constant r is 2^32.
//
// Values are kept in "Montgomery space", where x is represented as
// x * r mod m. Use init() to enter that space and reduce() to leave it.
class Montgomery {
  int32_t m;     // The (odd) modulus.
  uint32_t mm;   // inv(m) mod r.
  uint32_t r2;   // r * r mod m, used by init().

 public:
  Montgomery(int32_t m) : m(m), mm(1), r2(-m) {
    // Compute mm as inv(m) mod r.
    // Each step doubles the number of correct low bits: 1 -> 2 -> ... -> 32.
    for (int i = 0; i < 5; ++i) {
      mm *= 2 - mm * m;
    }
    // Compute r2 as r * r mod m.
    // If allowed to use modular operation for uint64_t, simply use:
    //   r2 = (uint64_t)(-m) % m;
    // r2 starts as 2^32 - m, so the remainder below is r mod m.
    r2 %= m;
    // Double it to obtain 2r mod m (no overflow because r2 < m < 2^31).
    r2 <<= 1;
    if (r2 >= (uint32_t)m) r2 -= m;
    // Five Montgomery squarings turn the factor 2 into 2^32 = r:
    // 2r -> 4r -> 16r -> 256r -> 65536r -> r * r (all mod m).
    for (int i = 0; i < 5; ++i) {
      r2 = mul(r2, r2);
    }
  }

  // Montgomery reduction: x * inv(r) % m.
  // Also used to transform x from Montgomery space to the normal space.
  int32_t reduce(int64_t x) const {
    // u = x * inv(m) mod r, hence x - m * u is an exact multiple of r.
    uint32_t u = (uint32_t)x * mm;
    int32_t ans = (x - (int64_t)m * u) >> 32;
    // The quotient lies in (-m, m); fold negatives back into [0, m).
    return ans < 0 ? ans + m : ans;
  }

  // Multiplication in Montgomery space: x * y * inv(r) % m.
  int32_t mul(int32_t x, int32_t y) const { return reduce((int64_t)x * y); }

  // Transform x from the normal space to Montgomery space.
  int32_t init(int32_t x) const { return mul(x, r2); }
};
// Store 4L(a) for a = 2^d + 1, where L(a) is disc. log. base 388251981.
// The first two values are never used and thus set to zero.
// The base is chosen such that 4L(2^16+1) = 2^16.
// Entry d has its lowest set bit at position d, which is what lets
// exp_mod_2_32 clear the bits of its argument one position at a time.
constexpr uint32_t log_table[16] = {
    0x00000000, 0x00000000, 0xbba0267c, 0x49b9d1e8,
    0xf0026f90, 0xd6e17e20, 0xe78bf840, 0x039fe080,
    0xaf7f8100, 0x60fe0200, 0xd1f80400, 0x23e00800,
    0x47801000, 0x8e002000, 0x18004000, 0x20008000,
};

// Compute 4L(v), added onto the accumulator x.
// Expects v = 1 (mod 4); all arithmetic wraps modulo 2^32.
uint32_t log_mod_2_32(uint32_t x, uint32_t v) {
  for (int i = 2; i != 16; ++i) {
    if ((v >> i) & 1) {
      // Multiply v by (2^i + 1) to clear bit i, and compensate in the log.
      v += v << i;
      x -= log_table[i];
    }
  }
  // Now v = 1 + 2^16 * k, whose scaled logarithm is 2^16 * k = v ^ 1.
  x += v ^ 1;
  return x;
}

// Compute x*a for 4L(a) = v.
uint32_t exp_mod_2_32(uint32_t x, uint32_t v) {
  for (int i = 2; i != 16; ++i) {
    if ((v >> i) & 1) {
      // Peel the factor (2^i + 1) off a and fold it into x.
      x += x << i;
      v -= log_table[i];
    }
  }
  // The remaining v is a multiple of 2^16, so the leftover factor is v + 1.
  x *= v ^ 1;
  return x;
}

// Compute x*a^b for odd a.
uint32_t pow_odd_mod_2_32(uint32_t a, uint32_t b, uint32_t x) {
  if (a & 2) {
    // a = 3 (mod 4): work with -a, which is 1 (mod 4), and carry the sign
    // (-1)^b over to x.
    a = -a;
    if (b & 1) {
      x = -x;
    }
  }
  return exp_mod_2_32(x, log_mod_2_32(0, a) * b);
}

// Compute x*a^b mod 2^32.
uint32_t pow_mod_2_32(uint32_t a, uint32_t b, uint32_t x = 1) {
  // 0^0 is treated as 1, any other power of zero is 0.
  if (!a) return b == 0 ? x : 0;
  // Split a into 2^d times an odd part.
  auto d = __builtin_ctz(a);
  // The factor 2^(d*b) alone already vanishes modulo 2^32.
  if ((uint64_t)d * b >= 32) return 0;
  return pow_odd_mod_2_32(a >> d, b, x) << (d * b);
}
Barrett, Paul. “Implementing the Rivest Shamir and Adleman public key encryption algorithm on a standard digital signal processor.” In Conference on the Theory and Application of Cryptographic Techniques, pp. 311-323. Berlin, Heidelberg: Springer Berlin Heidelberg, 1986.
Becker, Hanno, Vincent Hwang, Matthias J. Kannwischer, Bo-Yin Yang, and Shang-Yi Yang. “Neon NTT: Faster Dilithium, Kyber, and Saber on Cortex-A72 and Apple M1.” IACR Transactions on Cryptographic Hardware and Embedded Systems (2022): 221-244.
Montgomery, Peter L. “Modular multiplication without trial division.” Mathematics of computation 44, no. 170 (1985): 519-521.