{- |
This module provides normalized versions of the multi-dimensional transforms in @fftw@.

All of the transforms are normalized so that

 - Each transform is unitary, i.e., preserves the inner product and the sum-of-squares norm of its input.

 - Each backwards transform is the inverse of the corresponding forwards transform.

(Both conditions only hold approximately, due to floating point precision.)

For more information on the underlying transforms, see
<http://www.fftw.org/fftw3_doc/What-FFTW-Really-Computes.html>.

@since 0.2
-}

module Numeric.FFT.Vector.Unitary.Multi
  (
        -- * Creating and executing 'Plan's
        run,
        plan,
        execute,
        -- * Complex-to-complex transforms
        dft,
        idft,
        -- * Real-to-complex transforms
        dftR2C,
        dftC2R,
  ) where

import Control.Exception (assert)
import Control.Monad (forM_)
import Numeric.FFT.Vector.Base
import qualified Numeric.FFT.Vector.Unnormalized.Multi as U
import Data.Complex
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as MS
import Control.Monad.Primitive(RealWorld)

-- | A unitary multi-dimensional discrete Fourier transform.  The output and
-- input shapes are the same.
--
-- In one dimension, with @n@ elements:
--
-- @y_k = (1\/sqrt n) sum_(j=0)^(n-1) x_j e^(-2pi i j k\/n)@
--
-- In general the output of 'U.dft' is multiplied by @1\/sqrt N@, where @N@
-- is the product of all dimension sizes.
dft :: TransformND (Complex Double) (Complex Double)
dft = U.dft
    { normalizationND = \ns ->
        -- Scale by 1/sqrt(total element count) so the transform is unitary.
        constMultOutput (1 / sqrt (fromIntegral (VS.product ns)))
    }

-- | A unitary multi-dimensional inverse discrete Fourier transform.  The
-- output and input shapes are the same, and this is the inverse of 'dft'.
--
-- In one dimension, with @n@ elements:
--
-- @y_k = (1\/sqrt n) sum_(j=0)^(n-1) x_j e^(2pi i j k\/n)@
--
-- In general the output of 'U.idft' is multiplied by @1\/sqrt N@, where @N@
-- is the product of all dimension sizes.
idft :: TransformND (Complex Double) (Complex Double)
idft = U.idft
    { normalizationND = \ns ->
        -- Same symmetric 1/sqrt(N) factor as 'dft', so idft . dft == id.
        constMultOutput (1 / sqrt (fromIntegral (VS.product ns)))
    }

-- | A unitary forward discrete Fourier transform of real data.  If the input
-- has dimensions @ns@, the output has the same dimensions except the last:
-- with @n = VS.last ns@, the last output dimension is @n \`div\` 2 + 1@.
--
-- The half-spectrum produced by 'U.dftR2C' is scaled by
-- @1\/sqrt (product ns)@.  Within each row along the last dimension, every
-- bin except the zero-frequency bin (and, when @n@ is even, the Nyquist bin)
-- is additionally multiplied by @sqrt 2@.  This makes the sum of squares of
-- the half-spectrum match that of the full conjugate-symmetric spectrum.
dftR2C :: TransformND Double (Complex Double)
dftR2C = U.dftR2C
    { normalizationND = \ns ->
        modifyOutput $
            complexR2CScaling (sqrt 2) ns
                -- Row length along the last (halved) dimension.
                (outputSizeND U.dftR2C (VS.last ns))
    }

-- | A normalized backward discrete Fourier transform producing real data.  It
-- is the left inverse of 'dftR2C'.  (Specifically,
-- @run dftC2R . run dftR2C == id@.)
--
-- This 'TransformND' behaves differently than the others:
--
--  - For (real) dimensions @ns@, the /output/ has dimensions @ns@.  The
--    /input/ has the same dimensions except the last: with
--    @n = VS.last ns@, the last input dimension is @n \`div\` 2 + 1@.
--
--  - For one-dimensional input, if @length v == n@, then
--    @length (run dftC2R v) == 2*(n-1)@.
--
-- Before 'U.dftC2R' runs, its input is scaled by @1\/sqrt (product ns)@.
-- Every bin except the zero-frequency bin (and, when @n@ is even, the
-- Nyquist bin) is additionally multiplied by @sqrt 0.5@.  This undoes the
-- extra @sqrt 2@ applied by 'dftR2C'.
dftC2R :: TransformND (Complex Double) Double
dftC2R = U.dftC2R
    { normalizationND = \ns ->
        modifyInput $
            complexR2CScaling (sqrt 0.5) ns
                -- Row length of the complex input along the last dimension.
                (inputSizeND U.dftC2R (VS.last ns))
    }

-- | Rescale, in place, the half-spectrum buffer of a multi-dimensional
-- real-data transform so that the transform becomes unitary.
--
-- The buffer @a@ is treated as @VS.product (VS.init ns)@ consecutive rows of
-- @len@ complex values each, one row per index of the leading dimensions.
-- The first element of every row is multiplied by
-- @s1 = 1 \/ sqrt (VS.product ns)@; when @VS.last ns@ is even the last
-- element of every row is multiplied by @s1@ as well.  All remaining
-- (interior) elements are multiplied by @t * s1@.
--
-- Bounds are not checked at run time.  The length check is an 'assert',
-- which GHC ignores when optimising (@-O@) unless @-fno-ignore-asserts@ is
-- given.
complexR2CScaling
    :: Double
    -- ^ extra factor @t@ applied to the interior bins of each row
    -> VS.Vector Int
    -- ^ logical (real-space) dimensions @ns@
    -> Int
    -- ^ row length @len@ along the last, halved dimension
    -> MS.MVector RealWorld (Complex Double)
    -- ^ buffer of @VS.product (VS.init ns) * len@ elements, modified in place
    -> IO ()
complexR2CScaling !t !ns !len !a =
    assert (MS.length a == rows * len) $ do
        let !s1 = sqrt (1 / fromIntegral (VS.product ns))
            !s2 = t * s1
            lastIsOdd = odd (VS.last ns)
        -- Justification for the unchecked accesses:
        --
        -- The callers in this module pass, as len, the halved size reported
        -- by 'U.dftR2C' / 'U.dftC2R'.  Per their documentation this is
        -- n `div` 2 + 1, where n = VS.last ns.
        --
        -- For n >= 1 that gives len >= 1, so the first element of a row
        -- exists.  For odd n, the interior slice then has len - 1 >= 0
        -- elements.
        --
        -- For even n >= 2, len >= 2.  The last element is therefore distinct
        -- from the first, and the interior slice has len - 2 >= 0 elements.
        forM_ [0 .. rows - 1] $ \idx -> do
            let !start = idx * len
            unsafeModify a start $ scaleByD s1
            if lastIsOdd
                then multC s2 (MS.unsafeSlice (start + 1) (len - 1) a)
                else do
                    -- Even n: the final bin is the Nyquist bin, scaled like
                    -- the zero-frequency bin.
                    unsafeModify a (start + len - 1) $ scaleByD s1
                    multC s2 (MS.unsafeSlice (start + 1) (len - 2) a)
  where
    -- One row per combination of indices in the leading dimensions.
    rows = VS.product (VS.init ns)