{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NamedFieldPuns #-}

-- | Fast modular multiplication for `Word64` using Montgomery multiplication. If the modulus value
-- is known to fit in 32 bits, use the @AtCoder.Internal.Barrett@ module instead.
--
-- @since 1.2.6.0
module AtCoder.Extra.Math.Montgomery64
  ( -- * Montgomery64
    Montgomery64,

    -- * Constructor
    new,
    fromVal,

    -- * Accessor
    umod,

    -- * Montgomery form encoding
    encode,
    decode,
    reduce,

    -- * Calculations
    addMod,
    subMod,
    mulMod,
    powMod,
    eq,
  )
where

import AtCoder.Internal.Assert qualified as ACIA
import Data.Bits (bit, (!>>.))
import Data.WideWord.Word128 (Word128 (..))
import Data.Word (Word64)
import GHC.Exts (Proxy#)
import GHC.Stack (HasCallStack)
import GHC.TypeNats (KnownNat, natVal')

-- TODO: provide with newtype for Montgomery form?

-- | Precomputed constants for fast modular multiplication on `Word64` using Montgomery
-- multiplication.
--
-- @since 1.2.6.0
data Montgomery64 = Montgomery64
  { -- | The modulus \(m\).
    mM64 :: {-# UNPACK #-} !Word64,
    -- | The value \(r\) satisfying \(r \cdot m \equiv -1 \pmod{2^{64}}\).
    rM64 :: {-# UNPACK #-} !Word64,
    -- | \(2^{128} \bmod m\); `encode` multiplies by this value before reducing.
    n2M64 :: {-# UNPACK #-} !Word64
  }
  deriving
    ( -- | @since 1.2.6.0
      Eq,
      -- | @since 1.2.6.0
      Show
    )

-- TODO: add unsafePerformIO?
-- TODO: remove NOINLINE?

-- | \(O(1)\) Static, shared storage of `Montgomery64`. The modulus \(m\) is taken from the
-- type-level natural @a@ and handed to `fromVal`.
--
-- The natural is converted with `fromIntegral` before `fromVal` validates it, so a modulus that
-- does not fit in `Word64` is not rejected by this function itself.
--
-- ==== Constraints
-- - \(m \le 2^{62}\)
-- - \(m\) is odd
--
-- @since 1.2.6.0
{-# NOINLINE new #-}
new :: forall a. (KnownNat a) => Proxy# a -> Montgomery64
-- FIXME: test allocated once
new p = fromVal . fromIntegral $! natVal' p

-- | \(O(1)\) Creates a `Montgomery64` for a modulus value \(m\) of type `Word64`.
--
-- The modulus is validated with a runtime assertion before any constant is computed.
--
-- ==== Constraints
-- - \(m \le 2^{62}\)
-- - \(m\) is odd
--
-- @since 1.2.6.0
{-# INLINE fromVal #-}
fromVal :: Word64 -> Montgomery64
fromVal m =
  let !m128 :: Word128 = fromIntegral m
      -- Negation wraps in `Word128`, so this is (2^128 - m) mod m = 2^128 mod m.
      !n2 = word128Lo64 $ (-m128) `mod` m128
      !r = getR m 0
      -- Sanity check: r must be the negated inverse of m modulo 2^64.
      !_ = ACIA.runtimeAssert (r * m == -1) "AtCoder.Extra.Math.Montgomery64.fromVal: internal implementation error"
   in Montgomery64 m r n2
  where
    !_ = ACIA.runtimeAssert (odd m && m <= bit 62) $ "AtCoder.Extra.Math.Montgomery64.fromVal: not given odd modulus value that is less than or equal to 2^62: " ++ show m
    -- Newton iteration for the inverse of m modulo 2^64, starting from m itself and running five
    -- refinement steps; the final accumulator is negated.
    getR :: Word64 -> Int -> Word64
    getR !acc i
      | i >= 5 = -acc
      | otherwise = getR (acc * (2 - m * acc)) (i + 1)

-- | \(O(1)\) Retrieves the modulus \(m\).
--
-- @since 1.2.6.0
{-# INLINE umod #-}
umod :: Montgomery64 -> Word64
umod Montgomery64 {mM64} = mM64

-- | \(O(1)\) Converts the given `Word64` to Montgomery form by multiplying it with the stored
-- \(2^{128} \bmod m\) constant in 128-bit arithmetic and applying `reduce`.
--
-- @since 1.2.6.0
{-# INLINE encode #-}
encode :: Montgomery64 -> Word64 -> Word64
encode mont@Montgomery64 {n2M64} x = reduce mont $! fromIntegral x * fromIntegral n2M64

-- | \(O(1)\) Retrieves the value from a Montgomery form of value. The reduced value is brought
-- below \(m\) with one conditional subtraction.
--
-- @since 1.2.6.0
{-# INLINE decode #-}
decode :: Montgomery64 -> Word64 -> Word64
decode mont@Montgomery64 {mM64} x =
  let !res = reduce mont $! fromIntegral x
   in if res >= mM64 then res - mM64 else res

-- | \(O(1)\) Takes the mod in Montgomery form (Montgomery reduction).
--
-- Returns the high 64 bits of \(x + ((x_{lo} \cdot r) \bmod 2^{64}) \cdot m\), where \(x_{lo}\) is
-- the low 64 bits of \(x\). The result is not necessarily smaller than \(m\); `decode` applies a
-- final conditional subtraction.
--
-- @since 1.2.6.0
{-# INLINE reduce #-}
reduce :: Montgomery64 -> Word128 -> Word64
reduce Montgomery64 {mM64, rM64} x =
  word128Hi64 $!
    -- The inner product is computed in `Word64` (wrapping), then widened before multiplying by m.
    (x + fromIntegral (word128Lo64 x * rM64) * fromIntegral mM64)

-- | \(O(1)\) Calculates \(a + b \bmod m\) in the Montgomery form.
--
-- The first argument is the modulus \(m\). The sum is computed in `Word64` and \(m\) is
-- subtracted at most once, so the result is below \(m\) only when \(a + b < 2m\).
{-# INLINE addMod #-}
addMod :: Word64 -> Word64 -> Word64 -> Word64
addMod m a b
    | x' >= m = x' - m
    | otherwise = x'
  where
    !x' = a + b

-- | \(O(1)\) Calculates \(a - b \bmod m\) in the Montgomery form.
--
-- The first argument is the modulus \(m\). When \(a < b\) the wrapped `Word64` difference is
-- corrected by adding \(m\) once.
{-# INLINE subMod #-}
subMod :: Word64 -> Word64 -> Word64 -> Word64
subMod m a b
    | a >= b = a - b
    | otherwise = a - b + m

-- | \(O(1)\) Calculates \(a \cdot b \bmod m\) in the Montgomery form: the operands are
-- multiplied in 128-bit arithmetic and the product is passed to `reduce`.
--
-- @since 1.2.6.0
{-# INLINE mulMod #-}
mulMod :: Montgomery64 -> Word64 -> Word64 -> Word64
mulMod mont a b = reduce mont $! fromIntegral a * fromIntegral b

-- | \(O(w)\) Calculates \(a^n \bmod m\) in the Montgomery form.
--
-- The base is expected in Montgomery form; the accumulator starts from @encode mont 1@ and is
-- combined with `mulMod`, so the result is in Montgomery form as well. The exponent is checked
-- to be non-negative with a runtime assertion.
--
-- @since 1.2.6.0
{-# INLINE powMod #-}
powMod :: (HasCallStack) => Montgomery64 -> Word64 -> Int -> Word64
powMod mont x0 n0 = inner n0 (encode mont 1) x0
  where
    !_ = ACIA.runtimeAssert (0 <= n0) $ "AtCoder.Extra.Math.Montgomery64.powMod: given negative exponential `n`: " ++ show n0
    -- Binary exponentiation: @r@ is the accumulated result, @y@ the current power of the base.
    inner :: Int -> Word64 -> Word64 -> Word64
    inner !n !r !y
      | n == 0 = r
      | otherwise =
          let !r' = if odd n then mulMod mont r y else r
              !y' = mulMod mont y y
           in inner (n !>>. 1) r' y'

-- | \(O(1)\) Compares two values of Montgomery form and returns whether they represent the same
-- value.
--
-- The first argument is the modulus \(m\). Each operand that is not below \(m\) has \(m\)
-- subtracted once before the comparison.
--
-- @since 1.2.6.0
{-# INLINE eq #-}
eq :: Word64 -> Word64 -> Word64 -> Bool
eq mM64 a b = a' == b'
  where
    !a' = if a < mM64 then a else a - mM64
    !b' = if b < mM64 then b else b - mM64