{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Poly.Internal.Dense.DFT
( dft
, inverseDft
) where
import Prelude hiding (recip, fromIntegral)
import Control.Monad.ST
import Data.Bits hiding (shift)
import Data.Foldable
import Data.Semiring (Semiring(..), Ring(..), minus, fromIntegral)
import Data.Field (Field, recip)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as MG
-- | Discrete Fourier transform
-- \( y_k = \sum_{j=0}^{N-1} x_j \sqrt[N]{1}^{jk} \),
-- computed with the recursive radix-2 (decimation-in-time) scheme,
-- i.e. \( O(N \log N) \) ring operations.
--
-- Only lengths \( N = 2^n \) are accepted; any other length, including
-- @0@, is rejected with 'error'. The root is not validated: if it is not a
-- primitive \( N \)-th root of unity the result is not the DFT.
dft
  :: forall a v. (Ring a, G.Vector v a)
  => a   -- ^ primitive root \( \sqrt[N]{1} \) (not checked)
  -> v a -- ^ input \( \{ x_k \}_{k=0}^{N-1} \), \( N \) a power of two
  -> v a -- ^ output \( \{ y_k \}_{k=0}^{N-1} \)
dft primRoot xs
  | popCount nn /= 1 = error "dft: only vectors of length 2^n are supported"
  | otherwise = go 0 0
  where
    -- Input length N and n = log2 N (meaningful once N = 2^n is established).
    nn :: Int
    nn = G.length xs

    n :: Int
    n = countTrailingZeros nn

    -- Twiddle factors primRoot^0 .. primRoot^(N/2 - 1), shared by every
    -- recursive call. The 'seq' forces each element before the next one is
    -- built from it, so elements do not reference a chain of pending
    -- products. For N = 1 this binding is never demanded ('go 0 0' hits the
    -- base case at once), so the shift by (n - 1) = -1 is never evaluated.
    roots :: v a
    roots = G.iterateN
      (1 `unsafeShiftL` (n - 1))
      (\x -> x `seq` (x `times` primRoot))
      one

    -- @go offset shift@ is the DFT of the length-@2^(n - shift)@ subsequence
    -- @xs[offset], xs[offset + 2^shift], xs[offset + 2 * 2^shift], ...@.
    go :: Int -> Int -> v a
    go !offset !shift
      | shift >= n = G.unsafeSlice offset 1 xs
      | otherwise = runST $ do
        let halfLen :: Int
            halfLen = 1 `unsafeShiftL` (n - shift - 1)
            -- Transforms of the even- and odd-position halves of this
            -- subsequence.
            ys0, ys1 :: v a
            ys0 = go offset (shift + 1)
            ys1 = go (offset + (1 `unsafeShiftL` shift)) (shift + 1)
        ys <- MG.new (halfLen `unsafeShiftL` 1)

        -- Butterfly for k = 0: the twiddle factor is primRoot^0 = one,
        -- so the multiplication is skipped.
        let y00 = G.unsafeIndex ys0 0
            y10 = G.unsafeIndex ys1 0
        MG.unsafeWrite ys 0       $! y00 `plus`  y10
        MG.unsafeWrite ys halfLen $! y00 `minus` y10

        -- Remaining butterflies. At this level the root in use is
        -- primRoot^(2^shift), so its k-th power is roots ! (k * 2^shift).
        forM_ [1 .. halfLen - 1] $ \k -> do
          let y0 = G.unsafeIndex ys0 k
              y1 = G.unsafeIndex ys1 k `times`
                   G.unsafeIndex roots (k `unsafeShiftL` shift)
          MG.unsafeWrite ys k             $! y0 `plus`  y1
          MG.unsafeWrite ys (k + halfLen) $! y0 `minus` y1
        G.unsafeFreeze ys
{-# INLINABLE dft #-}
-- | Inverse of 'dft':
-- \( x_k = N^{-1} \sum_{j=0}^{N-1} y_j \sqrt[N]{1}^{-jk} \).
--
-- Implemented as 'dft' with the reciprocal root, followed by scaling every
-- entry by \( N^{-1} \). It has the same length restriction as 'dft'
-- (\( N = 2^n \), otherwise 'error').
--
-- NOTE(review): \( N^{-1} \) is computed with 'recip' and is not checked.
-- The result is only meaningful when \( N \) is invertible in the field.
inverseDft
  :: forall a v. (Field a, G.Vector v a)
  => a   -- ^ the primitive root \( \sqrt[N]{1} \) that was passed to 'dft'
  -> v a -- ^ transformed values \( \{ y_k \}_{k=0}^{N-1} \)
  -> v a -- ^ recovered values \( \{ x_k \}_{k=0}^{N-1} \)
inverseDft primRoot ys = G.map (`times` invN) $ dft (recip primRoot) ys
  where
    -- N^{-1}, where N is the input length converted into the field.
    invN :: a
    invN = recip $ fromIntegral $ G.length ys
{-# INLINABLE inverseDft #-}