gmpy_math.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import os
  17. import random
  18. import gmpy2
  19. POWMOD_GMP_SIZE = pow(2, 64)
  20. def powmod(a, b, c):
  21. """
  22. return int: (a ** b) % c
  23. """
  24. if a == 1:
  25. return 1
  26. if max(a, b, c) < POWMOD_GMP_SIZE:
  27. return pow(a, b, c)
  28. else:
  29. return int(gmpy2.powmod(a, b, c))
  30. def crt_coefficient(p, q):
  31. """
  32. return crt coefficient
  33. """
  34. tq = gmpy2.invert(p, q)
  35. tp = gmpy2.invert(q, p)
  36. return tp * q, tq * p
  37. def powmod_crt(x, d, n, p, q, cp, cq):
  38. """
  39. return int: (a ** b) % n
  40. """
  41. rp = gmpy2.powmod(x, d % (p - 1), p)
  42. rq = gmpy2.powmod(x, d % (q - 1), q)
  43. return int((rp * cp + rq * cq) % n)
  44. def invert(a, b):
  45. """return int: x, where a * x == 1 mod b"""
  46. x = int(gmpy2.invert(a, b))
  47. if x == 0:
  48. raise ZeroDivisionError("invert(a, b) no inverse exists")
  49. return x
  50. def getprimeover(n):
  51. """return a random n-bit prime number"""
  52. r = gmpy2.mpz(random.SystemRandom().getrandbits(n))
  53. r = gmpy2.bit_set(r, n - 1)
  54. return int(gmpy2.next_prime(r))
  55. def isqrt(n):
  56. """ return the integer square root of N """
  57. return int(gmpy2.isqrt(n))
  58. def is_prime(n):
  59. """
  60. true if n is probably a prime, false otherwise
  61. :param n:
  62. :return:
  63. """
  64. return gmpy2.is_prime(int(n))
  65. def legendre(a, p):
  66. return pow(a, (p - 1) // 2, p)
  67. def tonelli(n, p):
  68. # assert legendre(n, p) == 1, "not a square (mod p)"
  69. q = p - 1
  70. s = 0
  71. while q % 2 == 0:
  72. q //= 2
  73. s += 1
  74. if s == 1:
  75. return pow(n, (p + 1) // 4, p)
  76. for z in range(2, p):
  77. if p - 1 == legendre(z, p):
  78. break
  79. c = pow(z, q, p)
  80. r = pow(n, (q + 1) // 2, p)
  81. t = pow(n, q, p)
  82. m = s
  83. while (t - 1) % p != 0:
  84. t2 = (t * t) % p
  85. for i in range(1, m):
  86. if (t2 - 1) % p == 0:
  87. break
  88. t2 = (t2 * t2) % p
  89. b = pow(c, 1 << (m - i - 1), p)
  90. r = (r * b) % p
  91. c = (b * b) % p
  92. t = (t * c) % p
  93. m = i
  94. return r
  95. def gcd(a, b):
  96. return int(gmpy2.gcd(a, b))
  97. def next_prime(n):
  98. return int(gmpy2.next_prime(n))
  99. def mpz(n):
  100. return gmpy2.mpz(n)