Montgomery Reduction 算法流程与实际实现

下面默认对模数$m$取模。由于这篇文章的重点是实现(其实就是我自己存一下板子),因此没有证明。
使用注意: Montgomery Reduction 相较于 Barrett Reduction 来说,不需要使用__int128
但是有着更高的封装程度,因为涉及到普通数与Montgomery Reduction运算中间量的转化
另外,常见的Montgomery Reduction 在编程竞赛中的应用 要求模数为奇数
但是在Min25博客上来看,Montgomery似乎有着更高的效率
Montgomery Reduction算法思想简介 在计算取模运算的过程中,将每一个元素$T$都乘上一个特定的值$R(R>m,\gcd(R,m)=1)$
用特殊的方法处理相乘时除掉一个$R$的过程,从而避免取模运算
在使用的模数为常量时,编译器通常会自动加入Barrett reduction的优化,因此实际上这个算法对于动态模数的情形更为适用
(你自己真不一定写得过STL,但是确实可以比STL快)
编程上的应用简介

对于$m$为奇数的情况,取$R=2^{32}$,用自然溢出来代替对$R$取模、用位移来代替除以$R$,从而加速运算
我们还需要令$m' = -m^{-1} \bmod R$,有如下结论:
对于某一个数$T$($0 \leq T < mR$),若令$U = Tm' \bmod R$,则$\frac{T+Um}{R}$为整数,且$\frac{T+Um}{R} \equiv TR^{-1} \pmod m$;又因为$0 \leq \frac{T+Um}{R} < 2m$,至多再减去一次$m$即可得到$TR^{-1} \bmod m$
那么我们在计算$TR^{-1} \bmod m$时,实际上只需要计算$\frac{T+Um}{R}$:预处理$m'$,利用32位无符号整数的自然溢出计算$U = Tm' \bmod R$,再用位运算右移32位计算$\frac{T+Um}{R}$
实际使用时的实现,可以用一个类实现以下方法
在实现时需要尤其注意不要出现溢出
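1.预处理$m^{-1} \bmod R$:用牛顿迭代(Hensel 提升)$\mathrm{inv} \leftarrow \mathrm{inv}\cdot(2-\mathrm{inv}\cdot m)$,初值取$\mathrm{inv}=m$。对奇数$m$有$m^2 \equiv 1 \pmod 8$,所以初值的低3位已经正确;每迭代一次,正确的位数翻倍,4次之后至少有48位,超过32位。需要$m'=-m^{-1}$时取相反数即可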
using u32 = unsigned;
using i32 = int;
using u64 = unsigned long long;
using i64 = long long;

// The (dynamic) modulus; it must be odd for m^{-1} mod 2^32 to exist.
u32 m;

// Step 1: returns m^{-1} mod 2^32 for odd m by Newton/Hensel iteration.
// The start value inv = m is correct to 3 bits because m*m == 1 (mod 8) for odd m.
// Each step inv *= 2 - inv*m doubles the number of correct low bits, so
// 4 steps give 48 >= 32 bits. The arithmetic wraps mod 2^32 on purpose.
u32 getinv() {
    u32 inv = m;
    for (int i = 0; i < 4; ++i) inv *= 2 - inv * m;
    return inv;  // without this return, the caller would read an indeterminate value (UB)
}
2.reduce方法 1 2 3 4 5 u32 reduce (u64 x) { u32 y = u32 (x >> 32 ) - u32 ((u64 (u32 (x)*inv)*m) >> 32 ); return i32 (y) < 0 ? y + m : y; }
3.普通数转 Montgomery 形式

我们要计算$x\rightarrow xR \bmod m$(其中$R=2^{32}$),但是如果直接用取模计算就失去了意义。。。
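方法是预处理$R^2 \bmod m$(代码里写作`-u64(m)%m`:$-m$作为64位无符号数等于$2^{64}-m \equiv 2^{64} = R^2 \pmod m$),把$x$乘上它之后reduce一次,即得$xR^2\cdot R^{-1} \equiv xR \pmod m$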
1 2 3 4 u32 R2=-u64 (m)%m; u32 intToMont (i32 x) { return reduce (u64 (x)*R2); }
// Step 4: Montgomery-form arithmetic.
// Operands are assumed to lie in [0, m) with m < 2^31, so a result that wrapped
// around below zero is detected through the sign bit of its i32 reinterpretation.

// (x + y) mod m: subtract m unconditionally, then add it back if that went negative.
u32 Add(u32 x, u32 y) {
    x += y - m;
    return i32(x) < 0 ? x + m : x;  // cast before comparing: an unsigned x is never < 0
}

// (x - y) mod m.
u32 Dec(u32 x, u32 y) {
    x -= y;
    return i32(x) < 0 ? x + m : x;
}

// Montgomery product: (xR)(yR)R^{-1} == xyR (mod m).
u32 Mul(u32 x, u32 y) {
    return reduce(u64(x) * y);
}
5.Montomery Reduction转普通数 1 2 3 i32 get (u32 x) { return reduce (x); }
封装之后,得到板子一号,这个是动态模数的。。。
实现上可能的误区:
为什么不用-inv?避免加法,原因是加法取模要和m比较
同样的,下面的i32(y)<0语句可以被替换为y>=m(负数溢出),看似减少一次类型转换,但是实际上0作为常量比较快得多
加法运算时也是类似的原因,x>=m的比较实在太慢,因此强制减去一个m,然后和0比
using u32 = uint32_t;
using i32 = int32_t;
using u64 = uint64_t;
using i64 = int64_t;

// Dynamic-modulus state, set by Init():
//   m   - odd modulus; must satisfy m < 2^31 (the sign-bit tests below rely on it)
//   inv - +m^{-1} mod 2^32, as required by the subtractive Mont::reduce
//   r2  - 2^64 mod m, used to enter Montgomery form
//   P   - not referenced in this snippet; kept for compatibility
static u32 m, inv, r2, P;

// Returns m^{-1} mod 2^32 for odd m by Newton/Hensel iteration. The start
// value inv = m is already correct to 3 bits (m*m == 1 mod 8), and each
// step doubles the number of correct bits.
u32 getinv() {
    u32 inv = m;
    for (int i = 0; i < 4; ++i) inv *= 2 - inv * m;
    return inv;
}

// Integer mod m kept in Montgomery form: stores x * 2^32 mod m, always in [0, m).
struct Mont {
private:
    u32 x;

    // Maps any i32 into [0, m), so that its product with r2 stays below m * 2^32.
    static u32 normalize(i32 v) {
        const i64 r = i64(v) % i64(m);
        return u32(r < 0 ? r + i64(m) : r);
    }

public:
    // Subtractive Montgomery reduction: returns x * 2^-32 mod m in [0, m).
    // With q = lo(x) * m^{-1}, x and q*m agree in their low 32 bits, so
    // (x - q*m) / 2^32 == hi(x) - hi(q*m), which lies in (-m, m).
    // Requires x < m * 2^32 and inv == +m^{-1} (NOT -m^{-1}).
    static u32 reduce(u64 x) {
        u32 y = u32(x >> 32) - u32((u64(u32(x) * inv) * m) >> 32);
        return i32(y) < 0 ? y + m : y;
    }

    Mont() {}
    // Enters Montgomery form: reduce(v * 2^64) == v * 2^32 (mod m).
    // A negative v is reduced mod m first instead of being sign-extended.
    Mont(i32 v) : x(reduce(u64(normalize(v)) * r2)) {}

    Mont& operator+=(const Mont& rhs) {
        x += rhs.x - m;          // the true value lies in [-m, m)
        if (i32(x) < 0) x += m;  // bring it back into [0, m)
        return *this;
    }
    Mont& operator-=(const Mont& rhs) {
        x -= rhs.x;              // the true value lies in (-m, m)
        if (i32(x) < 0) x += m;
        return *this;
    }
    Mont& operator*=(const Mont& rhs) {
        x = reduce(u64(x) * rhs.x);
        return *this;
    }
    friend Mont operator+(Mont x, const Mont& y) { return x += y; }
    friend Mont operator-(Mont x, const Mont& y) { return x -= y; }
    friend Mont operator*(Mont x, const Mont& y) { return x *= y; }

    // Leaves Montgomery form: returns the ordinary value in [0, m).
    i32 get() const { return reduce(x); }
};

// Sets the modulus (odd, < 2^31) and the derived Montgomery constants.
void Init(int m) {
    ::m = m;
    inv = getinv();    // +m^{-1}: Mont::reduce subtracts q*m, so -m^{-1} would give wrong results
    r2 = -u64(m) % m;  // 2^64 - m == 2^64 (mod m)
}
动态模数的方法,计算$5\cdot 10^7!\mod 998244353$在duck.ac上评测结果,时间单位是微秒$\mu s$
Naive Mod : 213689172 Time: 518352
My Montgomery : 213689172 Time: 192195
这个是我自己写的静态模数的,因为模数是静态的,所以不需要一定和0比较大小
// Integer modulo the compile-time odd modulus m, kept in Montgomery form
// (x * 2^32 mod m) and always fully reduced into [0, m).
template <uint32_t m> struct Mont {
private:
    using u32 = uint32_t;
    using i32 = int32_t;
    using u64 = uint64_t;
    using i64 = int64_t;

    static_assert((m & 1) == 1, "Montgomery reduction needs an odd modulus");
    static_assert(m < (1u << 31), "the wrap checks in += and -= need m < 2^31");

    // m^{-1} mod 2^32 by Newton/Hensel iteration (3 -> 6 -> 12 -> 24 -> 48 correct bits).
    static constexpr u32 getinv() {
        u32 inv = m;
        for (int i = 0; i < 4; ++i) inv *= 2 - inv * m;
        return inv;
    }
    // inv = -m^{-1} mod 2^32 (for the additive reduce), r2 = 2^64 mod m.
    static constexpr u32 inv = -getinv(), r2 = -u64(m) % m;

    // Maps any i32 into [0, m), so negative inputs are not sign-extended.
    static constexpr u32 normalize(i32 v) {
        const i64 r = i64(v) % i64(m);
        return u32(r < 0 ? r + i64(m) : r);
    }

    u32 x;

public:
    // Additive Montgomery reduction for x < m * 2^32: q = lo(x) * (-m^{-1})
    // makes x + q*m divisible by 2^32. The quotient is < 2m, so one
    // conditional subtraction yields x * 2^-32 mod m in [0, m).
    static constexpr u32 reduce(u64 x) {
        u32 y = (x + u64(u32(x) * inv) * m) >> 32;
        return y >= m ? y - m : y;
    }

    Mont() {}
    // Enters Montgomery form: reduce(v * 2^64) == v * 2^32 (mod m).
    constexpr Mont(i32 v) : x(reduce(u64(normalize(v)) * r2)) {}

    // m is a compile-time constant, so the wrap check compares against m directly:
    // an underflowed u32 is >= m, and adding m brings it back into [0, m).
    constexpr Mont& operator+=(const Mont& rhs) { return x += rhs.x - m, x >= m && (x += m), *this; }
    constexpr Mont& operator-=(const Mont& rhs) { return x -= rhs.x, x >= m && (x += m), *this; }
    constexpr Mont& operator*=(const Mont& rhs) { return x = reduce(u64(x) * rhs.x), *this; }
    constexpr friend Mont operator+(Mont x, const Mont& y) { return x += y; }
    constexpr friend Mont operator-(Mont x, const Mont& y) { return x -= y; }
    constexpr friend Mont operator*(Mont x, const Mont& y) { return x *= y; }

    // Ordinary value in [0, m).
    constexpr i32 get() const { return reduce(x); }
};
这个是摘自LOJ多项式乘法 hly1204的提交记录
个人解读:实际上存储的是$xR \bmod m$在$[0,2m)$中的某个代表元(不一定小于$m$),从而省去了每次reduce之后的条件减法,只在`get()`输出时才规约到$[0,m)$
// Montgomery modular integer for a compile-time odd modulus P < 2^30
// (from hly1204's LOJ submission). Each value is stored lazily as some
// representative of x * 2^32 mod P in [0, 2P); get() fully reduces it.
template <std::uint32_t P> struct MontgomeryModInt32 {
public:
    using i32 = std::int32_t;
    using u32 = std::uint32_t;
    using i64 = std::int64_t;
    using u64 = std::uint64_t;

private:
    u32 v;

    // P^{-1} mod 2^32 by Newton/Hensel iteration.
    static constexpr u32 get_r() {
        u32 iv = P;
        for (u32 i = 0; i != 4; ++i) iv *= 2 - P * iv;
        return iv;
    }

    // r = -P^{-1} mod 2^32 (for the additive reduce), r2 = 2^64 mod P.
    static constexpr u32 r = -get_r(), r2 = -u64(P) % P;

    static_assert((P & 1) == 1);
    static_assert(r * P == -1);
    static_assert(P < (1 << 30));

public:
    // x^y mod P with plain % arithmetic. The exponent is first reduced mod P - 1
    // (Fermat's little theorem), which assumes P is prime.
    static constexpr u32 pow_mod(u32 x, u64 y) {
        y %= P - 1;  // y is unsigned, so no negative fix-up is needed
        u32 res = 1;
        for (; y != 0; y >>= 1, x = u64(x) * x % P)
            if (y & 1) res = u64(res) * x % P;
        return res;
    }

    // Smallest primitive root of P. Factor phi = P - 1, then return the first g
    // such that g^(phi/q) != 1 for every prime factor q; returns 0 if there is none.
    static constexpr u32 get_pr() {
        u32 tmp[32] = {}, cnt = 0;
        const u64 phi = P - 1;
        u64 m = phi;
        for (u64 i = 2; i * i <= m; ++i) {
            if (m % i == 0) {
                tmp[cnt++] = i;
                while (m % i == 0) m /= i;
            }
        }
        if (m > 1) tmp[cnt++] = m;
        for (u64 res = 2; res <= phi; ++res) {
            bool flag = true;
            for (u32 i = 0; i != cnt && flag; ++i) flag &= pow_mod(res, phi / tmp[i]) != 1;
            if (flag) return res;
        }
        return 0;
    }

    MontgomeryModInt32() = default;
    ~MontgomeryModInt32() = default;
    // Enters Montgomery form: reduce(v * (2^64 mod P)) == v * 2^32 (mod P).
    constexpr MontgomeryModInt32(u32 v) : v(reduce(u64(v) * r2)) {}
    constexpr MontgomeryModInt32(const MontgomeryModInt32 &rhs) : v(rhs.v) {}

    // Additive Montgomery reduction without the final conditional subtraction.
    // The result is congruent to x * 2^-32 (mod P) and, for x < 4P^2, lies in [0, 2P).
    static constexpr u32 reduce(u64 x) { return (x + u64(u32(x) * r) * P) >> 32; }

    // Ordinary value in [0, P): reduce(v) is at most P, so one branchless
    // conditional subtraction is enough.
    constexpr u32 get() const {
        u32 res = reduce(v);
        return res - (P & -(res >= P));
    }
    explicit constexpr operator u32() const { return get(); }
    explicit constexpr operator i32() const { return i32(get()); }

    constexpr MontgomeryModInt32 &operator=(const MontgomeryModInt32 &rhs) { return v = rhs.v, *this; }

    // Negation inside [0, 2P): 2P - v, except that 0 stays 0.
    constexpr MontgomeryModInt32 operator-() const {
        MontgomeryModInt32 res;
        return res.v = (P << 1 & -(v != 0)) - v, res;
    }

    // Multiplicative inverse via pow(-1) == pow(P - 2); assumes P is prime and *this != 0.
    constexpr MontgomeryModInt32 inv() const { return pow(-1); }

    // Branchless add/sub inside [0, 2P): -(cond) is an all-ones or all-zero mask
    // that adds 2P back only when the intermediate went negative.
    constexpr MontgomeryModInt32 &operator+=(const MontgomeryModInt32 &rhs) {
        return v += rhs.v - (P << 1), v += P << 1 & -(i32(v) < 0), *this;
    }
    constexpr MontgomeryModInt32 &operator-=(const MontgomeryModInt32 &rhs) {
        return v -= rhs.v, v += P << 1 & -(i32(v) < 0), *this;
    }
    constexpr MontgomeryModInt32 &operator*=(const MontgomeryModInt32 &rhs) {
        return v = reduce(u64(v) * rhs.v), *this;
    }
    constexpr MontgomeryModInt32 &operator/=(const MontgomeryModInt32 &rhs) {
        return this->operator*=(rhs.inv());
    }

    friend MontgomeryModInt32 operator+(const MontgomeryModInt32 &lhs, const MontgomeryModInt32 &rhs) {
        return MontgomeryModInt32(lhs) += rhs;
    }
    friend MontgomeryModInt32 operator-(const MontgomeryModInt32 &lhs, const MontgomeryModInt32 &rhs) {
        return MontgomeryModInt32(lhs) -= rhs;
    }
    friend MontgomeryModInt32 operator*(const MontgomeryModInt32 &lhs, const MontgomeryModInt32 &rhs) {
        return MontgomeryModInt32(lhs) *= rhs;
    }
    friend MontgomeryModInt32 operator/(const MontgomeryModInt32 &lhs, const MontgomeryModInt32 &rhs) {
        return MontgomeryModInt32(lhs) /= rhs;
    }

    // Reads an ordinary integer and converts it into Montgomery form.
    friend std::istream &operator>>(std::istream &is, MontgomeryModInt32 &rhs) {
        return is >> rhs.v, rhs.v = reduce(u64(rhs.v) * r2), is;
    }
    friend std::ostream &operator<<(std::ostream &os, const MontgomeryModInt32 &rhs) {
        return os << rhs.get();
    }

    // *this ^ y. A negative y is mapped into [0, P - 1) (Fermat), which assumes P is prime.
    constexpr MontgomeryModInt32 pow(i64 y) const {
        if ((y %= P - 1) < 0) y += P - 1;
        MontgomeryModInt32 res(1), x(*this);
        for (; y != 0; y >>= 1, x *= x)
            if (y & 1) res *= x;
        return res;
    }
};
这个是计算$5\cdot 10^7!\mod 998244353$在duck.ac上的测试结果
Naive Mod : 213689172 Time: 180649
My Montgomery : 213689172 Time: 178217
His Montgomery : 213689172 Time: 152847
这个是计算$7\cdot 10^7!\mod 998244353$在duck.ac上的测试结果
Naive Mod : 939830261 Time: 252908
My Montgomery : 939830261 Time: 249476
His Montgomery : 939830261 Time: 213986
还可以看Min25博客里下面的ModInt64版本
传送门
下面自己实现的$\mod 2m$版本,差不多也是最终版本了,跑起来和hly1204差不多
静态版本
// Integer modulo the compile-time odd modulus m, kept in Montgomery form and
// stored lazily as some representative of x * 2^32 mod m in [0, 2m).
// Only get() fully reduces. m < 2^30 is needed so that the product of two
// stored values reduces back into [0, 2m) without 64-bit overflow.
template <uint32_t m> struct Mont2 {
private:
    using u32 = uint32_t;
    using i32 = int32_t;
    using u64 = uint64_t;
    using i64 = int64_t;

    static_assert((m & 1) == 1, "Montgomery reduction needs an odd modulus");
    static_assert(m < (1u << 30), "the [0, 2m) representation needs m < 2^30");

    static constexpr u32 m2 = m << 1;

    // m^{-1} mod 2^32 by Newton/Hensel iteration (3 -> 6 -> 12 -> 24 -> 48 correct bits).
    static constexpr u32 getinv() {
        u32 inv = m;
        for (int i = 0; i < 4; ++i) inv *= 2 - inv * m;
        return inv;
    }
    // inv = -m^{-1} mod 2^32 (for the additive reduce), r2 = 2^64 mod m.
    static constexpr u32 inv = -getinv(), r2 = -u64(m) % m;

    // Maps any i32 into [0, m), so negative inputs are not sign-extended.
    static constexpr u32 normalize(i32 v) {
        const i64 r = i64(v) % i64(m);
        return u32(r < 0 ? r + i64(m) : r);
    }

    u32 x;

public:
    // Additive Montgomery reduction without the final subtraction. For x < 4m^2
    // the result is congruent to x * 2^-32 (mod m) and lies in [0, 2m).
    static constexpr u32 reduce(u64 x) { return (x + u64(u32(x) * inv) * m) >> 32; }

    Mont2() {}
    // Enters Montgomery form: reduce(v * 2^64) == v * 2^32 (mod m).
    constexpr Mont2(i32 v) : x(reduce(u64(normalize(v)) * r2)) {}

    // Values live in [0, 2m). An underflowed u32 is >= 2m, and adding 2m brings it back.
    constexpr Mont2& operator+=(const Mont2& rhs) { return x += rhs.x - m2, x >= m2 && (x += m2), *this; }
    constexpr Mont2& operator-=(const Mont2& rhs) { return x -= rhs.x, x >= m2 && (x += m2), *this; }
    constexpr Mont2& operator*=(const Mont2& rhs) { return x = reduce(u64(x) * rhs.x), *this; }
    constexpr friend Mont2 operator+(Mont2 x, const Mont2& y) { return x += y; }
    constexpr friend Mont2 operator-(Mont2 x, const Mont2& y) { return x -= y; }
    constexpr friend Mont2 operator*(Mont2 x, const Mont2& y) { return x *= y; }

    // Ordinary value in [0, m): reduce(x) is at most m here, so one subtraction suffices.
    constexpr i32 get() const {
        u32 res = reduce(x);
        return res >= m ? res - m : res;
    }
};
板子各有优劣.jpg
另外这是Int_To_Montgomery加法的速度,$7\cdot 10^7$次加法与类型转换
Naive : 305907824 80074
My Montgomery : 305907824 109479
My Montgomery2 : 305907824 99896
His Montgomery : 305907824 117449
动态版本
using u32 = uint32_t;
using i32 = int32_t;
using u64 = uint64_t;
using i64 = int64_t;

// Dynamic-modulus state, set by Init():
//   m, m2 - the odd modulus (must satisfy m < 2^30) and 2m
//   inv   - -m^{-1} mod 2^32, as required by the additive Mont::reduce
//   r2    - 2^64 mod m, used to enter Montgomery form
//   P     - not referenced in this snippet; kept for compatibility
static u32 m, m2, inv, r2, P;

// Returns m^{-1} mod 2^32 for odd m by Newton/Hensel iteration.
u32 getinv() {
    u32 inv = m;
    for (int i = 0; i < 4; ++i) inv *= 2 - inv * m;
    return inv;
}

// Integer mod m in Montgomery form, stored lazily as some representative of
// x * 2^32 mod m in [0, 2m). Only get() fully reduces.
struct Mont {
private:
    u32 x;

    // Maps any i32 into [0, m), so negative inputs are not sign-extended.
    static u32 normalize(i32 v) {
        const i64 r = i64(v) % i64(m);
        return u32(r < 0 ? r + i64(m) : r);
    }

public:
    // Additive Montgomery reduction. For x < 4m^2 it returns a value in [0, 2m)
    // that is congruent to x * 2^-32 (mod m). That value is always below 2^31,
    // so no sign fix-up is needed (this matches the static Mont2).
    static u32 reduce(u64 x) { return (x + u64(u32(x) * inv) * m) >> 32; }

    Mont() {}
    // Enters Montgomery form: reduce(v * 2^64) == v * 2^32 (mod m).
    Mont(i32 v) : x(reduce(u64(normalize(v)) * r2)) {}

    Mont& operator+=(const Mont& rhs) {
        x += rhs.x - m2;          // the true value lies in [-2m, 2m)
        if (i32(x) < 0) x += m2;  // bring it back into [0, 2m)
        return *this;
    }
    Mont& operator-=(const Mont& rhs) {
        x -= rhs.x;               // the true value lies in (-2m, 2m)
        if (i32(x) < 0) x += m2;
        return *this;
    }
    Mont& operator*=(const Mont& rhs) {
        x = reduce(u64(x) * rhs.x);
        return *this;
    }
    friend Mont operator+(Mont x, const Mont& y) { return x += y; }
    friend Mont operator-(Mont x, const Mont& y) { return x -= y; }
    friend Mont operator*(Mont x, const Mont& y) { return x *= y; }

    // Ordinary value in [0, m): reduce(x) is at most m here, so one subtraction suffices.
    i32 get() const {
        u32 res = reduce(x);
        return res >= m ? res - m : res;
    }
};

// Sets the modulus (odd, < 2^30) and the derived Montgomery constants.
void Init(int m) {
    ::m = m;
    m2 = ::m * 2;      // computed in unsigned arithmetic, so there is no signed overflow
    inv = -getinv();   // -m^{-1}: Mont::reduce adds q*m
    r2 = -u64(m) % m;  // 2^64 - m == 2^64 (mod m)
}
这个动态模板计算$5\cdot 10^7!\mod 998244353$
Naive Mod : 213689172 494061 (稍微修改了一下暴力的细节。。)
My Montgomery2 : 213689172 152849
不得不说duck.ac真的很nb