合同算術用クラス (ModularArithmetic/modint.cpp)
- category: ModularArithmetic
-
View this file on GitHub
- Last commit date: 2020-04-23 19:40:18+09:00
Required by
- 階乗の高速計算 (ModularArithmetic/factorial.cpp)
- 補間多項式 (ModularArithmetic/interpolation.cpp)
- 多項式 (ModularArithmetic/polynomial.cpp)
Verified with
- test/aoj_0355.test.cpp
- test/aoj_1322.test.cpp
- test/aoj_2444.test.cpp
- test/aoj_3110.test.cpp
- test/aoj_DPL_5_A.test.cpp
- test/aoj_DPL_5_B.test.cpp
- test/aoj_DPL_5_C.test.cpp
- test/aoj_DPL_5_D.test.cpp
- test/aoj_DPL_5_E.test.cpp
- test/aoj_DPL_5_F.test.cpp
- test/aoj_DPL_5_G.test.cpp
- test/aoj_DPL_5_I.test.cpp
- test/aoj_DPL_5_J.test.cpp
- test/aoj_DPL_5_L.test.cpp
- test/yc_502.test.cpp
- test/yc_551.test.cpp
- test/yj_convolution_mod.test.cpp
- test/yj_convolution_mod_1000000007.test.cpp
- test/yj_convolution_mod_raw.test.cpp
- test/yj_inv_of_formal_power_series.test.cpp
- test/yj_log_of_formal_power_series.test.cpp
- test/yj_multipoint_evaluation.test.cpp
- test/yj_point_set_range_composite.test.cpp
- test/yj_polynomial_interpolation.test.cpp
- test/yj_queue_operate_all_composite.test.cpp
- test/yj_range_affine_range_sum.test.cpp
- test/yj_vertex_set_path_composite.test.cpp
Code
#ifndef H_modint
#define H_modint
/**
* @brief 合同算術用クラス
* @author えびちゃん
*/
#include <cstdint>
#include <limits>
#include <type_traits>
#include <utility>
/**
 * @brief Residue class of integers modulo a compile-time or run-time modulus.
 *
 * @tparam Modulo  When positive, this is the modulus.  When zero or negative,
 *                 the modulus is the run-time value installed through
 *                 set_modulo(); each distinct non-positive @p Modulo names an
 *                 independent run-time modulus slot.
 *
 * The stored value is kept in the range [0, modulus).  Every operation is
 * `constexpr`, so with a positive @p Modulo the class can be used in constant
 * expressions.
 *
 * @note Products are formed in `intmax_t`, so the modulus must be small enough
 *       that the product of two reduced values fits in `intmax_t`.
 * @note With a run-time modulus, set_modulo() has to be called before any
 *       value is constructed from an integer; changing the modulus afterwards
 *       does not re-reduce values that already exist.
 */
template <intmax_t Modulo>
class modint {
public:
  /**
   * Underlying storage type: `int` when the modulus is a positive compile-time
   * constant below `INT_MAX / 2` (so the unreduced sum of two stored values
   * still fits), `intmax_t` otherwise.
   */
  using value_type = typename std::conditional<
    (0 < Modulo && Modulo < std::numeric_limits<int>::max() / 2), int, intmax_t
  >::type;

private:
  static constexpr value_type S_cmod = Modulo;  // compile-time modulus (used when > 0)
  static value_type S_rmod;                     // run-time modulus (used otherwise)
  value_type M_value = 0;                       // representative in [0, modulus)

  /**
   * Extended Euclidean algorithm: returns x in [0, m) with
   * x * n == gcd(n, m) (mod m), i.e. the inverse of n when gcd(n, m) == 1.
   * Returns 0 when n == 0.
   *
   * Only the coefficient of n is tracked; the coefficient of m is never needed.
   * Plain assignments are used instead of std::swap, which is not constexpr
   * before C++20 and would make this function unusable at compile time.
   */
  static constexpr value_type S_inv(value_type n, value_type m) {
    value_type x = 0;  // coefficient of n that yields b
    value_type u = 1;  // coefficient of n that yields a
    value_type a = n;
    value_type b = m;
    while (a != 0) {
      value_type const q = b / a;
      value_type const next_u = x - q * u;
      x = u;
      u = next_u;
      value_type const next_a = b - q * a;
      b = a;
      a = next_a;
    }
    if ((x %= m) < 0) x += m;
    return x;
  }

  /** Reduces an arbitrary integer n to its representative in [0, m). */
  static constexpr value_type S_normalize(intmax_t n, value_type m) {
    if (n >= m) {
      n %= m;
    } else if (n < 0) {
      // % keeps the sign of the dividend, so shift a negative remainder up.
      if ((n %= m) < 0) n += m;
    }
    return static_cast<value_type>(n);
  }

public:
  /** Constructs the zero residue. */
  constexpr modint() = default;

  /** Constructs the residue of @p n (implicit on purpose, so `x + 1` works). */
  constexpr modint(intmax_t n): M_value(S_normalize(n, get_modulo())) {}

  /** Assigns the residue of @p n. */
  constexpr modint& operator =(intmax_t n) {
    M_value = S_normalize(n, get_modulo());
    return *this;
  }

  constexpr modint& operator +=(modint const& that) {
    // Both operands are below the modulus, so one conditional subtraction suffices.
    if ((M_value += that.M_value) >= get_modulo()) M_value -= get_modulo();
    return *this;
  }
  constexpr modint& operator -=(modint const& that) {
    if ((M_value -= that.M_value) < 0) M_value += get_modulo();
    return *this;
  }
  constexpr modint& operator *=(modint const& that) {
    // Widen before multiplying so an `int` value_type cannot overflow.
    intmax_t tmp = M_value;
    tmp *= that.M_value;
    M_value = static_cast<value_type>(tmp % get_modulo());
    return *this;
  }
  /**
   * Multiplies by the modular inverse of @p that.  The result is only a true
   * quotient when @p that is coprime to the modulus; dividing by zero yields
   * zero rather than reporting an error.
   */
  constexpr modint& operator /=(modint const& that) {
    intmax_t tmp = M_value;
    tmp *= S_inv(that.M_value, get_modulo());
    M_value = static_cast<value_type>(tmp % get_modulo());
    return *this;
  }

  constexpr modint& operator ++() {
    if (++M_value == get_modulo()) M_value = 0;
    return *this;
  }
  constexpr modint& operator --() {
    // Wrap from 0 to modulus - 1.
    if (M_value-- == 0) M_value = get_modulo() - 1;
    return *this;
  }
  constexpr modint operator ++(int) { modint tmp(*this); ++*this; return tmp; }
  constexpr modint operator --(int) { modint tmp(*this); --*this; return tmp; }

  friend constexpr modint operator +(modint lhs, modint const& rhs) { return lhs += rhs; }
  friend constexpr modint operator -(modint lhs, modint const& rhs) { return lhs -= rhs; }
  friend constexpr modint operator *(modint lhs, modint const& rhs) { return lhs *= rhs; }
  friend constexpr modint operator /(modint lhs, modint const& rhs) { return lhs /= rhs; }

  constexpr modint operator +() const { return *this; }
  constexpr modint operator -() const {
    if (M_value == 0) return *this;
    return modint(get_modulo() - M_value);
  }

  friend constexpr bool operator ==(modint const& lhs, modint const& rhs) {
    return lhs.M_value == rhs.M_value;
  }
  friend constexpr bool operator !=(modint const& lhs, modint const& rhs) {
    return !(lhs == rhs);
  }

  /** Returns the stored representative in [0, modulus). */
  constexpr value_type get() const { return M_value; }

  /** Returns the active modulus: `Modulo` when positive, else the run-time one. */
  static constexpr value_type get_modulo() { return ((S_cmod > 0)? S_cmod: S_rmod); }

  /**
   * Installs the run-time modulus.  Only available when `Modulo <= 0`.
   * The SFINAE parameter has the same type as `Modulo` (`intmax_t`) so that
   * tags outside the range of `int` do not trigger a narrowing error.
   */
  template <intmax_t M = Modulo, typename Tp = typename std::enable_if<(M <= 0)>::type>
  static Tp set_modulo(value_type m) { S_rmod = m; }
};

// Namespace-scope definition of the constexpr member (needed before C++17 when it is odr-used).
template <intmax_t N>
constexpr typename modint<N>::value_type modint<N>::S_cmod;
// Storage for the run-time modulus; zero until set_modulo() is called.
template <intmax_t N>
typename modint<N>::value_type modint<N>::S_rmod;
#endif /* !defined(H_modint) */
#line 1 "ModularArithmetic/modint.cpp"
/**
* @brief 合同算術用クラス
* @author えびちゃん
*/
#include <cstdint>
#include <limits>
#include <type_traits>
#include <utility>
/**
 * @brief Residue class of integers modulo a compile-time or run-time modulus.
 *
 * @tparam Modulo  When positive, this is the modulus.  When zero or negative,
 *                 the modulus is the run-time value installed through
 *                 set_modulo(); each distinct non-positive @p Modulo names an
 *                 independent run-time modulus slot.
 *
 * The stored value is kept in the range [0, modulus).  Every operation is
 * `constexpr`, so with a positive @p Modulo the class can be used in constant
 * expressions.
 *
 * @note Products are formed in `intmax_t`, so the modulus must be small enough
 *       that the product of two reduced values fits in `intmax_t`.
 * @note With a run-time modulus, set_modulo() has to be called before any
 *       value is constructed from an integer; changing the modulus afterwards
 *       does not re-reduce values that already exist.
 */
template <intmax_t Modulo>
class modint {
public:
  /**
   * Underlying storage type: `int` when the modulus is a positive compile-time
   * constant below `INT_MAX / 2` (so the unreduced sum of two stored values
   * still fits), `intmax_t` otherwise.
   */
  using value_type = typename std::conditional<
    (0 < Modulo && Modulo < std::numeric_limits<int>::max() / 2), int, intmax_t
  >::type;

private:
  static constexpr value_type S_cmod = Modulo;  // compile-time modulus (used when > 0)
  static value_type S_rmod;                     // run-time modulus (used otherwise)
  value_type M_value = 0;                       // representative in [0, modulus)

  /**
   * Extended Euclidean algorithm: returns x in [0, m) with
   * x * n == gcd(n, m) (mod m), i.e. the inverse of n when gcd(n, m) == 1.
   * Returns 0 when n == 0.
   *
   * Only the coefficient of n is tracked; the coefficient of m is never needed.
   * Plain assignments are used instead of std::swap, which is not constexpr
   * before C++20 and would make this function unusable at compile time.
   */
  static constexpr value_type S_inv(value_type n, value_type m) {
    value_type x = 0;  // coefficient of n that yields b
    value_type u = 1;  // coefficient of n that yields a
    value_type a = n;
    value_type b = m;
    while (a != 0) {
      value_type const q = b / a;
      value_type const next_u = x - q * u;
      x = u;
      u = next_u;
      value_type const next_a = b - q * a;
      b = a;
      a = next_a;
    }
    if ((x %= m) < 0) x += m;
    return x;
  }

  /** Reduces an arbitrary integer n to its representative in [0, m). */
  static constexpr value_type S_normalize(intmax_t n, value_type m) {
    if (n >= m) {
      n %= m;
    } else if (n < 0) {
      // % keeps the sign of the dividend, so shift a negative remainder up.
      if ((n %= m) < 0) n += m;
    }
    return static_cast<value_type>(n);
  }

public:
  /** Constructs the zero residue. */
  constexpr modint() = default;

  /** Constructs the residue of @p n (implicit on purpose, so `x + 1` works). */
  constexpr modint(intmax_t n): M_value(S_normalize(n, get_modulo())) {}

  /** Assigns the residue of @p n. */
  constexpr modint& operator =(intmax_t n) {
    M_value = S_normalize(n, get_modulo());
    return *this;
  }

  constexpr modint& operator +=(modint const& that) {
    // Both operands are below the modulus, so one conditional subtraction suffices.
    if ((M_value += that.M_value) >= get_modulo()) M_value -= get_modulo();
    return *this;
  }
  constexpr modint& operator -=(modint const& that) {
    if ((M_value -= that.M_value) < 0) M_value += get_modulo();
    return *this;
  }
  constexpr modint& operator *=(modint const& that) {
    // Widen before multiplying so an `int` value_type cannot overflow.
    intmax_t tmp = M_value;
    tmp *= that.M_value;
    M_value = static_cast<value_type>(tmp % get_modulo());
    return *this;
  }
  /**
   * Multiplies by the modular inverse of @p that.  The result is only a true
   * quotient when @p that is coprime to the modulus; dividing by zero yields
   * zero rather than reporting an error.
   */
  constexpr modint& operator /=(modint const& that) {
    intmax_t tmp = M_value;
    tmp *= S_inv(that.M_value, get_modulo());
    M_value = static_cast<value_type>(tmp % get_modulo());
    return *this;
  }

  constexpr modint& operator ++() {
    if (++M_value == get_modulo()) M_value = 0;
    return *this;
  }
  constexpr modint& operator --() {
    // Wrap from 0 to modulus - 1.
    if (M_value-- == 0) M_value = get_modulo() - 1;
    return *this;
  }
  constexpr modint operator ++(int) { modint tmp(*this); ++*this; return tmp; }
  constexpr modint operator --(int) { modint tmp(*this); --*this; return tmp; }

  friend constexpr modint operator +(modint lhs, modint const& rhs) { return lhs += rhs; }
  friend constexpr modint operator -(modint lhs, modint const& rhs) { return lhs -= rhs; }
  friend constexpr modint operator *(modint lhs, modint const& rhs) { return lhs *= rhs; }
  friend constexpr modint operator /(modint lhs, modint const& rhs) { return lhs /= rhs; }

  constexpr modint operator +() const { return *this; }
  constexpr modint operator -() const {
    if (M_value == 0) return *this;
    return modint(get_modulo() - M_value);
  }

  friend constexpr bool operator ==(modint const& lhs, modint const& rhs) {
    return lhs.M_value == rhs.M_value;
  }
  friend constexpr bool operator !=(modint const& lhs, modint const& rhs) {
    return !(lhs == rhs);
  }

  /** Returns the stored representative in [0, modulus). */
  constexpr value_type get() const { return M_value; }

  /** Returns the active modulus: `Modulo` when positive, else the run-time one. */
  static constexpr value_type get_modulo() { return ((S_cmod > 0)? S_cmod: S_rmod); }

  /**
   * Installs the run-time modulus.  Only available when `Modulo <= 0`.
   * The SFINAE parameter has the same type as `Modulo` (`intmax_t`) so that
   * tags outside the range of `int` do not trigger a narrowing error.
   */
  template <intmax_t M = Modulo, typename Tp = typename std::enable_if<(M <= 0)>::type>
  static Tp set_modulo(value_type m) { S_rmod = m; }
};

// Namespace-scope definition of the constexpr member (needed before C++17 when it is odr-used).
template <intmax_t N>
constexpr typename modint<N>::value_type modint<N>::S_cmod;
// Storage for the run-time modulus; zero until set_modulo() is called.
template <intmax_t N>
typename modint<N>::value_type modint<N>::S_rmod;