{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NamedFieldPuns #-}
module AtCoder.Extra.Math.Montgomery64
(
Montgomery64,
new,
fromVal,
umod,
encode,
decode,
reduce,
addMod,
subMod,
mulMod,
powMod,
eq,
)
where
import AtCoder.Internal.Assert qualified as ACIA
import Data.Bits (bit, (!>>.))
import Data.WideWord.Word128 (Word128 (..))
import Data.Word (Word64)
import GHC.Exts (Proxy#)
import GHC.Stack (HasCallStack)
import GHC.TypeNats (KnownNat, natVal')
-- | Precomputed constants for Montgomery arithmetic modulo an odd modulus
-- \(m\), with \(R = 2^{64}\). Values are built by 'fromVal' (or 'new').
data Montgomery64 = Montgomery64
  { -- | The modulus \(m\).
    mM64 :: {-# UNPACK #-} !Word64,
    -- | \(-m^{-1} \bmod 2^{64}\), i.e. the value @r@ with @r * m == -1@
    -- in wrapping 'Word64' arithmetic.
    rM64 :: {-# UNPACK #-} !Word64,
    -- | \(2^{128} \bmod m\), i.e. \(R^2 \bmod m\); used to enter
    -- Montgomery form.
    n2M64 :: {-# UNPACK #-} !Word64
  }
  deriving (Eq, Show)
-- | Creates a 'Montgomery64' context for the type-level modulus @a@.
--
-- The modulus must be odd and at most \(2^{62}\); oddness and the bound on
-- the converted value are checked again by 'fromVal'.
--
-- The bound is checked here on the untruncated 'Natural' first. Otherwise
-- 'fromIntegral' would silently wrap a modulus of 64 bits or more (e.g.
-- \(2^{64} + 1\) would become \(1\)) before 'fromVal' ever saw it.
{-# NOINLINE new #-}
new :: forall a. (KnownNat a) => Proxy# a -> Montgomery64
new p = fromVal $! fromIntegral n
  where
    !n = natVal' p
    -- Strict where-binding: evaluated before the right-hand side above.
    !_ = ACIA.runtimeAssert (n <= bit 62) $ "AtCoder.Extra.Math.Montgomery64.new: given modulus is greater than 2^62: " ++ show n
-- | Creates a 'Montgomery64' context from a runtime modulus @m@.
--
-- ==== Constraints
-- - @m@ is odd and \(m \le 2^{62}\).
--
-- Computed fields:
--
-- - @rM64 = -m^{-1} mod 2^64@ (self-checked: @r * m == -1@).
-- - @n2M64 = 2^128 mod m@.
{-# INLINE fromVal #-}
fromVal :: Word64 -> Montgomery64
fromVal m =
  let !m128 = fromIntegral m :: Word128
      -- Negating in Word128 wraps to 2^128 - m, so this is 2^128 mod m.
      !n2 = word128Lo64 $ (-m128) `mod` m128
      !r = getR m 0
      !_ = ACIA.runtimeAssert (r * m == -1) "AtCoder.Extra.Math.Montgomery64.fromVal: internal implementation error"
   in Montgomery64 m r n2
  where
    -- Strict where-binding: evaluated before the `let` body above, so the
    -- modulus is validated before the `mod` by m128.
    !_ = ACIA.runtimeAssert (odd m && m <= bit 62) $ "AtCoder.Extra.Math.Montgomery64.fromVal: not given odd modulus value that is less than or equal to 2^62: " ++ show m
    -- Newton's iteration for m^{-1} mod 2^64, starting from acc = m.
    -- For odd m, m * m == 1 (mod 8), so the start is correct to 3 bits.
    -- Each step doubles the correct bits: 3, 6, 12, 24, 48, 96 >= 64 after
    -- 5 steps. The result is negated to give -m^{-1}.
    getR :: Word64 -> Int -> Word64
    getR !acc i
      | i >= 5 = -acc
      | otherwise = getR (acc * (2 - m * acc)) (i + 1)
-- | Returns the modulus \(m\) held by the context.
{-# INLINE umod #-}
umod :: Montgomery64 -> Word64
umod = mM64
-- | Converts an ordinary value @x@ into Montgomery form, \(x R \bmod m\)
-- with \(R = 2^{64}\).
--
-- It multiplies by \(R^2 \bmod m\) and applies one 'reduce'. The result is
-- not fully reduced: it lies in \([0, 2m)\).
{-# INLINE encode #-}
encode :: Montgomery64 -> Word64 -> Word64
encode mont x = reduce mont $! wideX * wideN2
  where
    wideX = fromIntegral x :: Word128
    wideN2 = fromIntegral (n2M64 mont) :: Word128
-- | Converts a Montgomery-form value back to an ordinary residue in
-- \([0, m)\).
--
-- One 'reduce' multiplies by \(R^{-1}\). Its result can still equal \(m\),
-- so a single conditional subtraction finishes the reduction.
{-# INLINE decode #-}
decode :: Montgomery64 -> Word64 -> Word64
decode mont x
  | y < m = y
  | otherwise = y - m
  where
    !m = mM64 mont
    !y = reduce mont (fromIntegral x)
-- | Montgomery reduction: returns \(x R^{-1} \bmod m\), with \(R = 2^{64}\),
-- possibly not fully reduced.
--
-- Let @q = lo64(x) * rM64@, computed mod 2^64. Because @rM64 = -m^{-1}@,
-- @x + q * m@ is a multiple of 2^64, and its high 64 bits are the result.
-- NOTE(review): the sum must fit in 'Word128'. With \(m \le 2^{62}\), the
-- product @q * m@ is below 2^126, so any \(x < 2^{127}\) is safe.
{-# INLINE reduce #-}
reduce :: Montgomery64 -> Word128 -> Word64
reduce Montgomery64 {mM64 = m, rM64 = r} x =
  let !q = word128Lo64 x * r
      !t = x + fromIntegral q * fromIntegral m
   in word128Hi64 t
-- | Adds @a@ and @b@, then subtracts the bound @m@ once if the sum reaches
-- it. For \(a, b \in [0, m)\) the result lies in \([0, m)\).
--
-- NOTE(review): which bound callers pass (\(m\) or e.g. \(2m\) for lazily
-- reduced Montgomery values) is not visible here; confirm at call sites.
-- The sum @a + b@ must not overflow 'Word64'.
{-# INLINE addMod #-}
addMod :: Word64 -> Word64 -> Word64 -> Word64
addMod m a b =
  let !s = a + b
   in if s < m then s else s - m
-- | Subtracts @b@ from @a@ modulo the bound @m@. When @a < b@, the bound is
-- added back. Wrapping 'Word64' arithmetic makes @a - b + m@ exact in that
-- case. For \(a, b \in [0, m)\) the result lies in \([0, m)\).
{-# INLINE subMod #-}
subMod :: Word64 -> Word64 -> Word64 -> Word64
subMod m a b
  | a < b = a - b + m
  | otherwise = a - b
-- | Multiplies two Montgomery-form values. The full 128-bit product goes
-- through one 'reduce', so \(aR \cdot bR\) becomes \(abR\). The result is
-- possibly not fully reduced.
{-# INLINE mulMod #-}
mulMod :: Montgomery64 -> Word64 -> Word64 -> Word64
mulMod mont a b = reduce mont $! wide a * wide b
  where
    wide :: Word64 -> Word128
    wide = fromIntegral
-- | Raises @x0@ to the power @n0@ by binary exponentiation using 'mulMod'.
--
-- The accumulator starts at @'encode' mont 1@ and only 'mulMod' is applied,
-- so the result is in Montgomery form. @x0@ is used as-is; presumably it is
-- already Montgomery-encoded (confirm with callers).
--
-- ==== Constraints
-- - \(0 \le n_0\)
{-# INLINE powMod #-}
powMod :: (HasCallStack) => Montgomery64 -> Word64 -> Int -> Word64
powMod mont x0 n0 = inner n0 (encode mont 1) x0
  where
    -- The message previously ended with `++ show "`"`, which appended the
    -- literal characters "`" (quotes included) to every report.
    !_ = ACIA.runtimeAssert (0 <= n0) $ "AtCoder.Extra.Math.Montgomery64.powMod: given negative exponential `n`: " ++ show n0
    -- Square-and-multiply over the bits of n, from least significant up.
    -- r: product of the powers selected so far; y: x0^(2^k) at bit k.
    inner :: Int -> Word64 -> Word64 -> Word64
    inner !n !r !y
      | n == 0 = r
      | otherwise =
          let !r' = if odd n then mulMod mont r y else r
              !y' = mulMod mont y y
           in inner (n !>>. 1) r' y'
-- | Tests whether two values are congruent modulo @m@.
--
-- Each value is brought into \([0, m)\) by at most one subtraction of @m@.
-- This assumes both inputs lie in \([0, 2m)\), e.g. lazily reduced
-- Montgomery values.
{-# INLINE eq #-}
eq :: Word64 -> Word64 -> Word64 -> Bool
eq m a b = normalize a == normalize b
  where
    normalize :: Word64 -> Word64
    normalize x
      | x >= m = x - m
      | otherwise = x