-- |
-- Module:      Data.Poly.Internal.Dense.DFT
-- Copyright:   (c) 2020 Andrew Lelechenko
-- Licence:     BSD3
-- Maintainer:  Andrew Lelechenko <andrew.lelechenko@gmail.com>
--
-- Discrete Fourier transform.
--

{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Data.Poly.Internal.Dense.DFT
  ( dft
  , inverseDft
  ) where

import Prelude hiding (recip, fromIntegral)
import Control.Monad.ST
import Data.Bits hiding (shift)
import Data.Foldable
import Data.Semiring (Semiring(..), Ring(..), minus, fromIntegral)
import Data.Field (Field, recip)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as MG

-- | <https://en.wikipedia.org/wiki/Fast_Fourier_transform Discrete Fourier transform>
-- \( y_k = \sum_{j=0}^{N-1} x_j \sqrt[N]{1}^{jk} \).
--
-- Evaluated by a recursive radix-2 split: each step transforms the
-- even- and odd-indexed halves of the current subsequence separately
-- and recombines them with butterflies
-- \( y_k = e_k + \omega^k o_k \), \( y_{k + N/2} = e_k - \omega^k o_k \).
-- An empty input yields an empty output; any other length that is not
-- a power of two is rejected with 'error'.
--
-- @since 0.5.0.0
dft
  :: (Ring a, G.Vector v a)
  => a   -- ^ primitive root \( \sqrt[N]{1} \), otherwise behaviour is undefined
  -> v a -- ^ \( \{ x_k \}_{k=0}^{N-1} \) (currently only \( N = 0 \) and \( N = 2^n \) are supported)
  -> v a -- ^ \( \{ y_k \}_{k=0}^{N-1} \)
dft primRoot (xs :: v a)
  -- The transform of an empty sequence is empty: nothing to compute.
  | nn == 0          = xs
  | popCount nn /= 1 = error "dft: only vectors of length 2^n are supported"
  | otherwise        = go 0 0
  where
    nn :: Int
    nn = G.length xs

    -- On the 'go' branch nn == 2^n.
    n :: Int
    n = countTrailingZeros nn

    -- Twiddle factors primRoot^0, primRoot^1, ..., primRoot^(nn/2 - 1).
    -- The length is written as nn / 2 rather than 2^(n - 1) so that no
    -- negative shift amount is ever formed: for nn == 1 we have n == 0,
    -- and the table is then empty and never indexed.
    roots :: v a
    roots = G.iterateN
      (nn `unsafeShiftR` 1)
      (\x -> x `seq` (x `times` primRoot))
      one

    -- go offset shift returns the DFT of the strided subsequence
    -- xs[offset + j * 2^shift] for j = 0 .. 2^(n - shift) - 1,
    -- whose root of unity is primRoot^(2^shift).
    go :: Int -> Int -> v a
    go !offset !shift
      | shift >= n = G.unsafeSlice offset 1 xs
      | otherwise = runST $ do
        let halfLen = 1 `unsafeShiftL` (n - shift - 1)
            -- Transforms of the even- and odd-indexed elements
            -- of the current subsequence.
            ys0 = go offset (shift + 1)
            ys1 = go (offset + 1 `unsafeShiftL` shift) (shift + 1)
        ys <- MG.new (halfLen `unsafeShiftL` 1)

        -- This corresponds to k = 0 in the loop below.
        -- It improves performance by avoiding multiplication
        -- by roots ! 0, which is one.
        let y00 = G.unsafeIndex ys0 0
            y10 = G.unsafeIndex ys1 0
        MG.unsafeWrite ys 0       $! y00 `plus`  y10
        MG.unsafeWrite ys halfLen $! y00 `minus` y10

        -- Butterflies; roots ! (k * 2^shift) == (primRoot^(2^shift))^k.
        forM_ [1 .. halfLen - 1] $ \k -> do
          let y0 = G.unsafeIndex ys0 k
              y1 = G.unsafeIndex ys1 k `times`
                   G.unsafeIndex roots (k `unsafeShiftL` shift)
          MG.unsafeWrite ys k             $! y0 `plus`  y1
          MG.unsafeWrite ys (k + halfLen) $! y0 `minus` y1
        G.unsafeFreeze ys
{-# INLINABLE dft #-}

-- | Inverse <https://en.wikipedia.org/wiki/Fast_Fourier_transform discrete Fourier transform>
-- \( x_k = {1\over N} \sum_{j=0}^{N-1} y_j \sqrt[N]{1}^{-jk} \).
--
-- Computed as a forward 'dft' with the reciprocal root, after which
-- every entry is scaled by \( 1/N \).
--
-- @since 0.5.0.0
inverseDft
  :: (Field a, G.Vector v a)
  => a   -- ^ primitive root \( \sqrt[N]{1} \), otherwise behaviour is undefined
  -> v a -- ^ \( \{ y_k \}_{k=0}^{N-1} \) (currently only  \( N = 2^n \) is supported)
  -> v a -- ^ \( \{ x_k \}_{k=0}^{N-1} \)
inverseDft primRoot ys = G.map (\y -> y `times` invN) unscaled
  where
    -- Forward transform with primRoot^(-1): entry k equals N * x_k.
    unscaled = dft (recip primRoot) ys
    -- The normalising factor 1 / N, taken in the field itself.
    -- NOTE(review): if N is zero in the field (characteristic dividing N),
    -- this is recip of zero; what happens then depends on the Field instance.
    invN = recip (fromIntegral (G.length ys))
{-# INLINABLE inverseDft #-}