{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Main where

import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Vector as Vector

import qualified LLVM.Extra.Iterator as Iter
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Arithmetic as A
import LLVM.Extra.Storable (arrayLoop, store)
import LLVM.Extra.Control (ret)

import qualified LLVM.ExecutionEngine as EE
import qualified LLVM.Core as LLVM
import LLVM.ExecutionEngine (simpleFunction, )
import LLVM.Core
         (Value, valueOf, value, constOf, undef, zero, add, sub, mul, frem,
          createFunction, Function, Linkage(ExternalLinkage),
          CodeGenModule, CodeGenFunction,
          Vector, extractelement, insertelement, shufflevector, )
import qualified System.IO as IO

import Type.Data.Num.Decimal(D4, )
import Data.Word (Word32, )
import qualified Foreign.Storable as St
import Foreign.Marshal.Array (allocaArray, )
import Foreign.Ptr (FunPtr, Ptr, )

import qualified Data.Empty as Empty
import Data.NonEmpty ((!:), )

import Control.Monad.Trans.State (StateT(StateT), runStateT)
import Control.Monad.HT ((<=<))
import Control.Monad (liftM2)
import Control.Applicative (liftA2)



{- |
Constant LLVM vector of four 'Float' lanes.
Used in this file to pin the type of polymorphic constants
such as 'zero' and 'undef' before turning them into a 'Value'.
-}
type Vec = LLVM.ConstValue (Vector D4 Float)

{- |
Build a vector value whose four lanes all hold the same number.
The vector is emitted as a single constant, no instructions are generated.
-}
constVec ::
   Float -> CodeGenFunction r (Value (Vector D4 Float))
constVec lane =
   return (valueOf (LLVM.consVector lane lane lane lane))

{- |
Alternative to 'constVec' that starts from an undefined vector
and writes the scalar into the lanes 0, 1, 2, 3 (in that order)
with one 'insertelement' instruction per lane.
-}
constVecInsert ::
   Float -> CodeGenFunction r (Value (Vector D4 Float))
constVecInsert x' =
   insertAt 0 (value (undef :: Vec))
      >>= insertAt 1
      >>= insertAt 2
      >>= insertAt 3
   where
      -- the scalar that ends up in every lane
      scalar = valueOf x'
      -- write the scalar into lane n of the given vector
      insertAt n vec = insertelement vec scalar (valueOf n)

{- |
Lane-wise remainder of the argument with respect to 1, computed with 'frem'.

This implementation cannot make use of vector operations,
because 'frem' is only available in the FPU.
-}
fractionVector0 ::
   Value (Vector D4 Float) -> CodeGenFunction r (Value (Vector D4 Float))
fractionVector0 x = do
   ones <- constVec 1
   frem x ones


{-
This call

    fill (fromIntegral len) ptr
       (LLVM.consVector 0.01003 0.01001 0.00999 0.00997) >>

would not work, because Vector is not of type Generic.
-}
{- |
LLVM function taking an element count, a pointer to a 'Float' buffer
and a vector of four per-lane phase increments.

For every buffer element it computes @1 - 2*phase@ in all four lanes,
stores the average of the four lanes and then advances the phase
by @freq@, passing the sum through 'Vector.fraction'
(presumably the lane-wise fractional part - confirm in llvm-extra).
The phase starts at zero in every lane.

The function returns lane 0 of the phase vector delivered by 'arrayLoop'.

The mixing of the four lanes is delegated to 'mixScalar'
instead of repeating the same extract\/add\/scale sequence here.
-}
mChorusVectorArg ::
  CodeGenModule (Function (Word32 -> Ptr Float -> Vector D4 Float -> IO Float))
mChorusVectorArg =
  createFunction ExternalLinkage $ \ size ptr freq -> do
    const1 <- constVec 1
    const2 <- constVec 2
    s <- arrayLoop size ptr (value (zero :: Vec)) $ \ ptri phase -> do
      -- 1 - 2*phase in every lane
      y <- sub const1 =<< mul const2 phase
      -- average of the four lanes, written to the current element
      flip store ptri =<< mixScalar y
      -- next loop state
      Vector.fraction =<< add phase freq
    ss <- extractelement s (valueOf 0)
    ret (ss :: Value Float)


{-
Differing vector sizes are allowed according to the documentation,
but are not supported by the C++ library of LLVM-2.5,
so the following variant is disabled:

mixReduceSize :: Value (Vector D4 Float) -> CodeGenFunction r (Value Float)
mixReduceSize y = do
    y01 <- shufflevector y (value undef) (LLVM.constVector [constOf 0, constOf 1])
    y23 <- shufflevector y (value undef) (LLVM.constVector [constOf 2, constOf 3])
    z <- add
       (y01 :: Value (Vector D2 Float))
       (y23 :: Value (Vector D2 Float))
    s0 <- extractelement z (valueOf 0)
    s1 <- extractelement z (valueOf 1)
    A.mul (valueOf 0.25) =<< add s0 s1
-}

{- |
Average of the four lanes of a vector, computed on scalars:
every lane is extracted separately, the lanes are summed pairwise
as @(y0+y1) + (y2+y3)@ and the sum is scaled by 0.25.
-}
mixScalar :: Value (Vector D4 Float) -> CodeGenFunction r (Value Float)
mixScalar y = do
    lane0 <- extractelement y (valueOf 0)
    lane1 <- extractelement y (valueOf 1)
    lane2 <- extractelement y (valueOf 2)
    lane3 <- extractelement y (valueOf 3)
    lower <- A.add lane0 lane1
    upper <- A.add lane2 lane3
    total <- A.add lower upper
    A.mul (valueOf 0.25) total

{- |
Average of the four lanes of a vector using one vector addition.

Here we consistently use vectors of size 4.
Lanes 2 and 3 are shuffled down to lanes 0 and 1
and added to the original vector,
then the two lower lanes of the sum are added and scaled by 0.25.
Since the upper two entries of the shuffle mask are declared undefined,
the code is efficient.
-}
mixGeneric :: Value (Vector D4 Float) -> CodeGenFunction r (Value Float)
mixGeneric y = do
    let upperToLowerMask =
           LLVM.constVector $
              constOf 2 !: constOf 3 !: undef !: undef !: Empty.Cons
    -- that is translated to movhlps
    upper <- shufflevector y (value undef) upperToLowerMask
    folded <- A.add y upper
    lane0 <- extractelement folded (valueOf 0)
    lane1 <- extractelement folded (valueOf 1)
    total <- A.add lane0 lane1
    A.mul (valueOf 0.25) total


{- |
LLVM function taking an element count, a pointer to a 'Float' buffer
and four scalar phase increments.

The increments are assembled into one vector.
For every buffer element the function computes @1 + (-2)*phase@
in all four lanes, stores the lane average obtained from 'mixGeneric'
and advances the phase by the increment vector,
passing the sum through 'Vector.fraction'.
The phase starts at zero in every lane.

The function returns lane 0 of the phase vector delivered by 'arrayLoop'.
-}
mChorusVector ::
  CodeGenModule
    (Function
      (Word32 -> Ptr Float -> Float -> Float -> Float -> Float -> IO Float))
mChorusVector =
  createFunction ExternalLinkage $ \ size ptr f0 f1 f2 f3 -> do
    freq <- Vector.assemble [f0,f1,f2,f3]
    offset <- constVec 1
    slope <- constVec (-2)
    finalPhase <-
      arrayLoop size ptr (value (zero :: Vec)) $ \ ptri phase -> do
        scaled <- mul slope phase
        wave <- add offset scaled
        sample <- mixGeneric wave
        store sample ptri
        advanced <- A.add phase freq
        Vector.fraction advanced
    firstLane <- extractelement finalPhase (valueOf 0)
    ret firstLane

{- |
Variant of 'mChorusVector' written with the iterator combinators
from "LLVM.Extra.Iterator" instead of 'arrayLoop'.

An iterator over the element pointers of the buffer is combined with
an iterator over the phase vectors
(starting at zero, each step adds the increment vector
and applies 'Vector.fraction').
For at most @size@ pairs the lane average of @1 + (-2)*phase@
is stored at the current pointer.

In contrast to 'mChorusVector' the function always returns zero.
-}
mChorusVectorIterator ::
  CodeGenModule
    (Function
      (Word32 -> Ptr Float -> Float -> Float -> Float -> Float -> IO Float))
mChorusVectorIterator =
  createFunction ExternalLinkage $ \ size ptr f0 f1 f2 f3 -> do
    freq <- Vector.assemble [f0,f1,f2,f3]
    offset <- constVec 1
    slope <- constVec (-2)
    let -- compute one output sample and write it to the given element
        renderAt ptri phase = do
          scaled <- mul slope phase
          wave <- add offset scaled
          sample <- mixGeneric wave
          store sample ptri
        -- phase update from one element to the next
        nextPhase = Vector.fraction <=< A.add freq
        phases = Iter.iterate nextPhase (value (zero :: Vec))
        slots = Iter.storableArrayPtrs ptr
    Iter.mapM_ id (Iter.take size (liftA2 renderAt slots phases))
    ret (value zero :: Value Float)


{- |
Map a phase @t@ to the value @1 - 2*t@.
-}
waveSaw :: Value Float -> CodeGenFunction r (Value Float)
waveSaw t = do
  doubled <- A.mul (valueOf 2) t
  A.sub (valueOf 1) doubled

{- |
One oscillator step:
returns the 'waveSaw' value at the current phase
together with the phase for the next step as computed by 'SoV.incPhase'.
The code for the wave value is emitted before the code for the phase update.
-}
osciSaw ::
  Value Float -> Value Float -> CodeGenFunction r (Value Float, Value Float)
osciSaw freq phase = do
  sample <- waveSaw phase
  nextPhase <- SoV.incPhase freq phase
  return (sample, nextPhase)

{- |
Scalar counterpart of 'mChorusVector':
four independent 'osciSaw' oscillators with the increments
@f0@ .. @f3@ are run for every buffer element.
Their outputs are summed as @(y0+y1) + (y2+y3)@, scaled by 0.25
and stored in the current element.
The four phases are kept as a nested pair of pairs
and start at 'Tuple.zero'.

The function returns the first phase of the state
delivered by 'arrayLoop'.
-}
mChorus ::
  CodeGenModule
    (Function
      (Word32 -> Ptr Float -> Float -> Float -> Float -> Float -> IO Float))
mChorus =
  createFunction ExternalLinkage $ \ size ptr f0 f1 f2 f3 -> do
    ((endPhase0, _), _) <- arrayLoop size ptr Tuple.zero $
         \ ptri ((p0, p1), (p2, p3)) -> do
      (a, p0') <- osciSaw f0 p0
      (b, p1') <- osciSaw f1 p1
      (c, p2') <- osciSaw f2 p2
      (d, p3') <- osciSaw f3 p3
      lower <- A.add a b
      upper <- A.add c d
      flip store ptri =<< A.mul (valueOf 0.25) =<< A.add lower upper
      return ((p0', p1'), (p2', p3'))
    ret (endPhase0 :: Value Float)


{- |
'osciSaw' with a fixed increment, wrapped as a state action:
the state is the phase, the result is the oscillator output.
-}
sawOsciAction ::
  Value Float ->
  StateT (Value Float) (CodeGenFunction r) (Value Float)
sawOsciAction freq =
  StateT (\phase -> osciSaw freq phase)

{-
(***) :: StateT s m a -> StateT t m b -> StateT (s,t) m (a,b)
(***) sta stb =
  StateT $ \(s0,t0) ->
  do (a,s1) <- runStateT sta s0
     (b,t1) <- runStateT stb t0
     return ((a,b), (s1,t1))
-}

{- |
Run two state actions side by side on the two halves of a paired state
and add their results.
The left action is run before the right one.
-}
(=+=) ::
  StateT s (CodeGenFunction r) (Value Float) ->
  StateT t (CodeGenFunction r) (Value Float) ->
  StateT (s,t) (CodeGenFunction r) (Value Float)
sta =+= stb = StateT step
  where
    step (s0, t0) = do
      (a, s1) <- runStateT sta s0
      (b, t1) <- runStateT stb t0
      total <- add a b
      return (total, (s1, t1))

{- |
Variant of 'mChorus' where the four oscillators are
combined with '=+=' into one state action over a nested pair of phases.
For every buffer element the combined action is run once,
its result is scaled by 0.25 and stored,
and the updated phases become the next loop state.
The phases start at 'Tuple.zero'.

The function returns the first phase of the state
delivered by 'arrayLoop'.
-}
mChorusMonadic ::
  CodeGenModule
    (Function
      (Word32 -> Ptr Float -> Float -> Float -> Float -> Float -> IO Float))
mChorusMonadic =
  createFunction ExternalLinkage $ \ size ptr f0 f1 f2 f3 -> do
    let voices =
          (sawOsciAction f0 =+= sawOsciAction f1) =+=
          (sawOsciAction f2 =+= sawOsciAction f3)
    ((endPhase0, _), _) <- arrayLoop size ptr Tuple.zero $
         \ ptri phases -> do
      (mixed, phases') <- runStateT voices phases
      sample <- A.mul (valueOf 0.25) mixed
      store sample ptri
      return phases'
    ret endPhase0


{- |
Type of a function that turns a pointer to foreign code
into a callable Haskell function of type @func@.
The @foreign import ... "dynamic"@ declarations in this file have this type.
-}
type Importer func = FunPtr func -> func

{- |
Define the given function in a fresh LLVM module
(with the target set to the host triple),
obtain a callable Haskell function for it
from the execution engine via the given importer,
and write the module as bitcode to the file @array.bc@.

The bitcode file is written after the execution engine has been run.
-}
generateFunction ::
  EE.ExecutionFunction f =>
  Importer f -> CodeGenModule (Function f) -> IO f
generateFunction imprt code = do
  modul <- LLVM.newModule
  func <- LLVM.defineModule modul (LLVM.setTarget LLVM.hostTriple >> code)
  callable <-
    EE.runEngineAccessWithModule modul (EE.getExecutionFunction imprt func)
  LLVM.writeBitcodeToFile "array.bc" modul
  return callable


{- |
Dynamic FFI import that converts a function pointer
into a callable function with the signature shared by
the chorus generators that take four scalar increments
(element count, buffer, four increments).
Passed to 'generateFunction' by 'renderChorus'.
-}
foreign import ccall safe "dynamic" derefChorusPtr ::
  Importer
    (Word32 -> Ptr Float -> Float -> Float -> Float -> Float -> IO Float)

{- |
Compile 'mChorusVectorIterator', run it on a temporary buffer
of 10000000 floats with four slightly different increments
and write the raw buffer content to the file @speedtest.f32@.
The 'Float' returned by the compiled function is discarded.
-}
renderChorus :: IO ()
renderChorus = do
  fill <- generateFunction derefChorusPtr mChorusVectorIterator
  IO.withFile "speedtest.f32" IO.WriteMode $ \h ->
    allocaArray len $ \ ptr -> do
      _ <- fill (fromIntegral len) ptr 0.01003 0.01001 0.00999 0.00997
      IO.hPutBuf h ptr (len * St.sizeOf (undefined :: Float))
  where
    -- number of samples to render
    len :: Int
    len = 10000000


{- |
LLVM function that runs a single 'osciSaw' oscillator
with the given increment over the buffer:
every element receives the oscillator output for the current phase.
The phase starts at 0;
the function returns the phase delivered by 'arrayLoop'.
-}
mSaw :: CodeGenModule (Function (Word32 -> Ptr Float -> Float -> IO Float))
mSaw =
  createFunction ExternalLinkage $ \ size ptr freq -> do
    finalPhase <- arrayLoop size ptr (valueOf 0) $ \ ptri phase -> do
      (sample, nextPhase) <- osciSaw freq phase
      store sample ptri >> return nextPhase
    ret (finalPhase :: Value Float)

{- |
Dynamic FFI import that converts a function pointer
into a callable function taking an element count, a buffer
and a single 'Float' parameter.
Passed to 'generateFunction' by 'renderSaw'.
-}
foreign import ccall safe "dynamic" derefSawPtr ::
  Importer (Word32 -> Ptr Float -> Float -> IO Float)

{- |
Compile 'mSaw', run it with increment 0.01 on a temporary buffer
of 10000000 floats and write the raw buffer content
to the file @speedtest.f32@.
The 'Float' returned by the compiled function is discarded.
-}
renderSaw :: IO ()
renderSaw = do
  fill <- generateFunction derefSawPtr mSaw
  IO.withFile "speedtest.f32" IO.WriteMode $ \h ->
    allocaArray len $ \ ptr -> do
      _ <- fill (fromIntegral len) ptr 0.01
      IO.hPutBuf h ptr (len * St.sizeOf (undefined :: Float))
  where
    -- number of samples to render
    len :: Int
    len = 10000000


{- |
LLVM function that writes a linear ramp into the buffer:
the running value starts at 0, is stored in the current element
and is then increased by @slope@.
The function returns the value delivered by 'arrayLoop'.
-}
mRamp :: CodeGenModule (Function (Word32 -> Ptr Float -> Float -> IO Float))
mRamp =
  createFunction ExternalLinkage $ \ size ptr slope -> do
    finalLevel <- arrayLoop size ptr (valueOf 0) $ \ ptri level ->
      store level ptri >> add slope level
    ret (finalLevel :: Value Float)

{- |
Compile 'mRamp' with 'simpleFunction', run it on a temporary buffer
of 10000000 floats with a slope of one over the buffer length
and write the raw buffer content to the file @speedtest.f32@.
The 'Float' returned by the compiled function is discarded.
-}
renderRamp :: IO ()
renderRamp = do
  fill <- simpleFunction mRamp
  IO.withFile "speedtest.f32" IO.WriteMode $ \h ->
    allocaArray len $ \ ptr -> do
      _ <- fill (fromIntegral len) ptr (recip (fromIntegral len))
      IO.hPutBuf h ptr (len * St.sizeOf (undefined :: Float))
  where
    -- number of samples to render
    len :: Int
    len = 10000000

{- |
Initialize the native LLVM target and run the chorus rendering.
-}
main :: IO ()
main =
   LLVM.initializeNativeTarget >> renderChorus