{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# OPTIONS_HADDOCK hide #-}
module AtCoder.Internal.Convolution
(
FftInfo,
newInfo,
butterfly,
butterflyInv,
convolutionNaive,
convolutionFft,
)
where
import AtCoder.Internal.Bit qualified as ACIB
import AtCoder.ModInt qualified as AM
import Control.Monad (when)
import Control.Monad.Fix (fix)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.Bits (bit, complement, countTrailingZeros, (.<<.), (.>>.))
import Data.Foldable
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import Data.Word (Word64)
import GHC.Exts (proxy#)
import GHC.TypeNats (natVal')
-- | Precomputed powers of a root of unity modulo @p@, built once by 'newInfo'
-- and read by the number-theoretic-transform passes ('butterfly',
-- 'butterflyInv').
data FftInfo p = FftInfo
  { -- | @rootFft VG.! i == g^((p - 1) / 2^i)@ for @i = 0 .. rank2@, where @g@ is
    -- 'AM.primitiveRootModulus' and @rank2@ is the number of trailing zero bits
    -- of @p - 1@. Each entry is the square of the next one.
    rootFft :: !(VU.Vector (AM.ModInt p)),
    -- | Element-wise inverses of 'rootFft'.
    iRootFft :: !(VU.Vector (AM.ModInt p)),
    -- | Twiddle-factor increments for the radix-2 pass of 'butterfly'
    -- (@rank2 - 1@ entries, empty when @rank2 < 2@).
    rate2Fft :: !(VU.Vector (AM.ModInt p)),
    -- | Counterpart of 'rate2Fft' built with roots and inverse roots swapped.
    iRate2Fft :: !(VU.Vector (AM.ModInt p)),
    -- | Twiddle-factor increments for the radix-4 pass of 'butterfly'
    -- (@rank2 - 2@ entries, empty when @rank2 < 3@).
    rate3Fft :: !(VU.Vector (AM.ModInt p)),
    -- | Counterpart of 'rate3Fft' built with roots and inverse roots swapped.
    iRate3Fft :: !(VU.Vector (AM.ModInt p))
  }
  deriving (Eq, Show)
{-# INLINE newInfo #-}
-- | Builds the 'FftInfo' tables for the modulus @p@.
--
-- Let @rank2@ be the number of trailing zero bits of @p - 1@ and @g@ the value
-- returned by 'AM.primitiveRootModulus'. Then:
--
-- * @rootFft ! rank2 = g^((p - 1) >> rank2)@, and every lower entry is the
--   square of the one above it; @iRootFft@ holds the inverses.
-- * For @k = 2@ (@rate2Fft@) and @k = 3@ (@rate3Fft@), for @i = 0 .. rank2 - k@:
--   @rates ! i = rootFft ! (i + k) * product [iRootFft ! (j + k) | j <- [0 .. i - 1]]@.
--   The inverse tables use the same formula with @rootFft@ / @iRootFft@ swapped.
newInfo :: forall m p. (PrimMonad m, AM.Modulus p) => m (FftInfo p)
newInfo = do
  let !g = AM.primitiveRootModulus (proxy# @p)
      !modulus = fromIntegral (natVal' (proxy# @p)) :: Int
      -- Largest e such that 2^e divides (modulus - 1): the longest supported
      -- power-of-two transform length is 2^rank2.
      !rank2 = countTrailingZeros (modulus - 1)

  root <- VUM.unsafeNew (rank2 + 1)
  iRoot <- VUM.unsafeNew (rank2 + 1)

  -- Seed the top entry, then fill downwards by squaring.
  VGM.write root rank2 $ AM.pow (AM.new g) ((modulus - 1) .>>. rank2)
  VGM.write iRoot rank2 . AM.inv =<< VGM.read root rank2
  for_ [rank2 - 1, rank2 - 2 .. 0] $ \i -> do
    r <- VGM.read root (i + 1)
    ir <- VGM.read iRoot (i + 1)
    VGM.write root i $! r * r
    VGM.write iRoot i $! ir * ir

  -- Builds the (rates, inverse rates) pair that reads the root tables at
  -- offset k. The table length is max 0 (rank2 - k + 1), so it is empty when
  -- rank2 < k - 1.
  let fillRates :: Int -> m (VU.Vector (AM.ModInt p), VU.Vector (AM.ModInt p))
      fillRates k = do
        let !len = max 0 (rank2 - k + 1)
        rates <- VUM.unsafeNew len
        iRates <- VUM.unsafeNew len
        VU.foldM'_
          ( \(!prod, !iProd) i -> do
              r <- VGM.read root (i + k)
              ir <- VGM.read iRoot (i + k)
              VGM.write rates i $! r * prod
              VGM.write iRates i $! ir * iProd
              pure (prod * ir, iProd * r)
          )
          (1, 1)
          (VU.generate len id)
        (,) <$> VU.unsafeFreeze rates <*> VU.unsafeFreeze iRates

  (rate2Fft, iRate2Fft) <- fillRates 2
  (rate3Fft, iRate3Fft) <- fillRates 3
  -- Freeze only after fillRates has finished reading the mutable root tables.
  rootFft <- VU.unsafeFreeze root
  iRootFft <- VU.unsafeFreeze iRoot
  pure FftInfo {..}
{-# INLINE butterfly #-}
-- | In-place forward transform of @a@ modulo @p@ using the tables in 'FftInfo'.
--
-- The number of levels is @h = countTrailingZeros (VUM.length a)@, so the
-- length is presumably required to be a power of two (TODO confirm callers
-- guarantee this). Levels are processed two at a time by a radix-4 pass; a
-- single radix-2 pass handles the last level when an odd number remain. The
-- structure appears to mirror AtCoder Library's @butterfly@.
butterfly ::
  forall m p.
  (PrimMonad m, AM.Modulus p) =>
  FftInfo p ->
  VUM.MVector (PrimState m) (AM.ModInt p) ->
  m ()
butterfly FftInfo {..} a = do
  let !h = countTrailingZeros (VUM.length a)
      !modulus = fromIntegral (natVal' (proxy# @p)) :: Word64
      -- Loop-invariant bound on a product of two residues. Adding it (or twice
      -- it) before a subtraction keeps the unsigned Word64 arithmetic
      -- non-negative. NOTE(review): sums reach about 3 * mod2, which fits in
      -- 64 bits only for moduli below roughly 2^31 — confirm AM.Modulus
      -- guarantees that.
      !mod2 = modulus * modulus

  let -- One radix-2 level starting at level @len@: combines element pairs
      -- that are @stride@ apart.
      radix2 :: Int -> m ()
      radix2 len = do
        let !stride = bit (h - len - 1)
        VU.foldM'_
          ( \ !rot s -> do
              let !offset = s .<<. (h - len)
              for_ [0 .. stride - 1] $ \i -> do
                l <- VGM.read a (i + offset)
                r <- (* rot) <$> VGM.read a (i + offset + stride)
                VGM.write a (i + offset) $! l + r
                VGM.write a (i + offset + stride) $! l - r
              -- Advance the twiddle factor unless this was the last block.
              pure $
                if s + 1 /= bit len
                  then rot * rate2Fft VG.! countTrailingZeros (complement s)
                  else rot
          )
          (AM.new32 @p 1)
          (VU.generate (bit len) id)

      -- Two levels at once starting at level @len@: combines element quadruples
      -- that are @stride@ apart, working on raw Word64 values and reducing once
      -- per output via AM.new64.
      radix4 :: Int -> m ()
      radix4 len = do
        let !stride = bit (h - len - 2)
            !imag = AM.val64 (rootFft VG.! 2)
        VU.foldM'_
          ( \ !rot s -> do
              let !rotSq = rot * rot
                  !rot1 = AM.val64 rot
                  !rot2 = AM.val64 rotSq
                  !rot3 = AM.val64 (rotSq * rot)
                  !offset = s .<<. (h - len)
              for_ [0 .. stride - 1] $ \i -> do
                let !i0 = i + offset
                    !i1 = i0 + stride
                    !i2 = i0 + 2 * stride
                    !i3 = i0 + 3 * stride
                !a0 <- AM.val64 <$> VGM.read a i0
                !a1 <- (* rot1) . AM.val64 <$> VGM.read a i1
                !a2 <- (* rot2) . AM.val64 <$> VGM.read a i2
                !a3 <- (* rot3) . AM.val64 <$> VGM.read a i3
                let !a1na3imag = (((a1 + mod2 - a3) `mod` modulus) * imag) `mod` modulus
                    !na2 = mod2 - a2
                VGM.write a i0 . AM.new64 $! a0 + a2 + a1 + a3
                VGM.write a i1 . AM.new64 $! a0 + a2 + (2 * mod2 - (a1 + a3))
                VGM.write a i2 . AM.new64 $! a0 + na2 + a1na3imag
                VGM.write a i3 . AM.new64 $! a0 + na2 + (mod2 - a1na3imag)
              -- Advance the twiddle factor unless this was the last block.
              pure $
                if s + 1 /= bit len
                  then rot * rate3Fft VG.! countTrailingZeros (complement s)
                  else rot
          )
          (AM.new32 @p 1)
          (VU.generate (bit len) id)

  flip fix 0 $ \loop len ->
    when (len < h) $
      if h - len == 1
        then radix2 len >> loop (len + 1)
        else radix4 len >> loop (len + 2)
{-# INLINE butterflyInv #-}
-- | In-place inverse butterfly (inverse number-theoretic transform) over @a@.
--
-- The output is /not/ divided by the length of @a@; callers (e.g.
-- 'convolutionFft') multiply by the inverse of the length afterwards.
-- Levels are processed from @h = countTrailingZeros (length a)@ down to zero:
-- a single radix-2 step when exactly one level remains, radix-4 steps (two
-- levels at a time) otherwise.
--
-- The length of @a@ is assumed to be a power of two; this is not checked here.
-- An empty vector is left untouched. Without that guard,
-- @countTrailingZeros 0 == 64@ would make the first step try to build a
-- @2^62@-element index vector.
--
-- NOTE(review): the 'Word64' arithmetic in the radix-4 step assumes the modulus
-- is small enough (the usual NTT primes are below @2^30@) that intermediate
-- products such as @4m * m@ cannot overflow — confirm for new moduli.
butterflyInv ::
  forall m p.
  (PrimMonad m, AM.Modulus p) =>
  FftInfo p ->
  VUM.MVector (PrimState m) (AM.ModInt p) ->
  m ()
butterflyInv FftInfo {..} a = when (n /= 0) $ do
  let h = countTrailingZeros n
  let !m :: Word64 = fromIntegral $ natVal' (proxy# @p)
  let !mInt :: Int = fromIntegral $ natVal' (proxy# @p)

  flip fix h $ \loop len -> do
    when (len /= 0) $ do
      if len == 1
        then do
          -- Radix-2 step: combine the pairs (i + offset, i + offset + p).
          let p = bit $ h - len
          VU.foldM'_
            ( \ !irot s -> do
                let !offset = s .<<. (h - len + 1)
                for_ [0 .. p - 1] $ \i -> do
                  l <- VGM.read a $ i + offset
                  r <- VGM.read a $ i + offset + p
                  VGM.write a (i + offset) $! l + r
                  -- Adding the modulus keeps (l - r) non-negative before scaling.
                  VGM.write a (i + offset + p) . AM.new $! (mInt + AM.val l - AM.val r) * AM.val irot
                -- Advance the twiddle factor; not needed after the last block.
                if s + 1 /= bit (len - 1)
                  then pure . (irot *) $ iRate2Fft VG.! countTrailingZeros (complement s)
                  else pure irot
            )
            (AM.new32 @p 1)
            (VU.generate (bit (len - 1)) id)
          loop $ len - 1
        else do
          -- Radix-4 step: combine the quadruples (i + offset + k * p), k = 0..3.
          let p = bit $ h - len
          let !iimag :: Word64 = AM.val64 $ iRootFft VG.! 2
          VU.foldM'_
            ( \ !irot s -> do
                let !irot1 = AM.val64 irot
                let !irot2_ = irot * irot
                let !irot2 = AM.val64 irot2_
                let !irot3 = AM.val64 $ irot2_ * irot
                let !offset = s .<<. (h - len + 2)
                for_ [0 .. p - 1] $ \i -> do
                  !a0 <- AM.val64 <$> VGM.read a (i + offset + 0 * p)
                  !a1 <- AM.val64 <$> VGM.read a (i + offset + 1 * p)
                  !a2 <- AM.val64 <$> VGM.read a (i + offset + 2 * p)
                  !a3 <- AM.val64 <$> VGM.read a (i + offset + 3 * p)
                  -- (a2 - a3) * iimag, reduced; the added @m@ avoids underflow.
                  let !a2na3iimag = (m + a2 - a3) * iimag `mod` m
                  -- Terms of the form (m - x) stand in for (-x) in unsigned arithmetic.
                  VGM.write a (i + offset) . AM.new64 $! a0 + a1 + a2 + a3
                  VGM.write a (i + offset + 1 * p) . AM.new64 $! (a0 + (m - a1) + a2na3iimag) * irot1
                  VGM.write a (i + offset + 2 * p) . AM.new64 $! (a0 + a1 + (m - a2) + (m - a3)) * irot2
                  VGM.write a (i + offset + 3 * p) . AM.new64 $! (a0 + (m - a1) + (m - a2na3iimag)) * irot3
                -- Advance the twiddle factor; not needed after the last block.
                if s + 1 /= bit (len - 2)
                  then pure . (irot *) $ iRate3Fft VG.! countTrailingZeros (complement s)
                  else pure irot
            )
            (AM.unsafeNew @p 1)
            (VU.generate (bit (len - 2)) id)
          loop $ len - 2
  where
    n = VUM.length a
{-# INLINE convolutionNaive #-}
-- | Schoolbook \(O(nm)\) convolution.
--
-- Element @k@ of the result is the sum of @a VG.! i * b VG.! j@ over all
-- @i + j == k@. For non-empty inputs the result has length @n + m - 1@.
--
-- If either input is empty, the result is empty. Without this guard, an empty
-- @a@ with a non-empty @b@ (or vice versa) would return
-- @max n m - 1@ zeros instead.
convolutionNaive ::
  forall p.
  (AM.Modulus p) =>
  VU.Vector (AM.ModInt p) ->
  VU.Vector (AM.ModInt p) ->
  VU.Vector (AM.ModInt p)
convolutionNaive a b
  | n == 0 || m == 0 = VU.empty
  | otherwise = VU.create $ do
      ans <- VGM.replicate (n + m - 1) 0
      -- Accumulate the single product term a[i] * b[j] into slot i + j.
      let addTerm i j = VGM.modify ans (+ a VG.! i * b VG.! j) (i + j)
      -- Iterate so that the shorter input drives the inner loop.
      if n < m
        then for_ [0 .. m - 1] $ \j -> for_ [0 .. n - 1] $ \i -> addTerm i j
        else for_ [0 .. n - 1] $ \i -> for_ [0 .. m - 1] $ \j -> addTerm i j
      pure ans
  where
    n = VU.length a
    m = VU.length b
{-# INLINE convolutionFft #-}
-- | Convolution via the number-theoretic transform.
--
-- The steps are:
--
-- 1. Zero-pad both inputs to @z = bitCeil (n + m - 1)@ and transform each with
--    'butterfly'.
-- 2. Multiply the transforms pointwise.
-- 3. Transform back with 'butterflyInv'.
-- 4. Keep the first @n + m - 1@ entries and scale them by @1 / z@, the
--    normalization that 'butterflyInv' leaves out.
--
-- If either input is empty, the result is empty. Without this guard, an empty
-- input would yield spurious zeros, or pass a negative argument to @bitCeil@
-- when both inputs are empty.
convolutionFft ::
  forall p.
  (AM.Modulus p) =>
  VU.Vector (AM.ModInt p) ->
  VU.Vector (AM.ModInt p) ->
  VU.Vector (AM.ModInt p)
convolutionFft a_ b_
  | n == 0 || m == 0 = VU.empty
  | otherwise = VU.create $ do
      let z = ACIB.bitCeil (n + m - 1)
      info <- newInfo @_ @p

      -- Zero-padded copies of the inputs, transformed in place.
      a <- VUM.replicate z 0
      VU.copy (VUM.take n a) a_
      butterfly info a

      b <- VUM.replicate z 0
      VU.copy (VUM.take m b) b_
      butterfly info b

      -- Pointwise product in the transformed domain, stored into @a@.
      VUM.iforM_ b $ \i bi -> VGM.modify a (* bi) i
      butterflyInv info a

      -- Drop the padding and apply the 1 / z normalization.
      let a' = VUM.take (n + m - 1) a
      let !iz = AM.inv $ AM.new z
      for_ [0 .. n + m - 2] $ VGM.modify a' (* iz)
      pure a'
  where
    n = VU.length a_
    m = VU.length b_