{-# LANGUAGE RebindableSyntax #-}
module SignalProcessingSpecificLLVM where

import qualified Parameters as Params
import Parameters (Freq(Freq), )

import qualified SpectralDistribution as SD
import qualified SignalProcessingMethods as Methods
import qualified SignalProcessingLLVM as SPLLVM
import qualified SignalProcessing as SP
import qualified Signal
import qualified Rate
import SignalProcessingMethods (Triple, )
import SignalProcessing (fanout3, )

import qualified Synthesizer.LLVM.CausalParameterized.Process as CausalP
import qualified Synthesizer.LLVM.Parameterized.Signal as SigP
import qualified Synthesizer.LLVM.Filter.Universal as UniFilter
import qualified Synthesizer.LLVM.Causal.Process as Causal
import qualified Synthesizer.LLVM.Fold as Fold
import qualified Synthesizer.LLVM.Frame.Binary as Bin
import qualified Synthesizer.LLVM.Parameter as Param
import Synthesizer.LLVM.Causal.Process (($*), ($<), )

import qualified Synthesizer.Generic.Signal as SigG
import qualified Synthesizer.Plain.Filter.Recursive.Universal as UniFilt
import Synthesizer.Plain.Filter.Recursive (Pole(Pole), )

import qualified Sound.SoxLib as SoxLib
import qualified Data.StorableVector.Lazy as SVL

import qualified LLVM.Extra.Arithmetic as A
import LLVM.Core (Value, )

import Control.Arrow (arr, (&&&), (<<<), (^<<), )
import Control.Applicative (pure, liftA2, (<$>), )
import Data.Tuple.HT (mapSnd, fst3, snd3, thd3, uncurry3, )

import NumericPrelude.Numeric hiding (sum1)
import NumericPrelude.Base

import Data.Word (Word32, )


-- | Parameterized causal process whose parameter is the pair of a
-- 'SoxLib.Rate' and a caller-chosen payload @p@.
-- Within this module the rate component is read by 'freq' and supplied
-- from the signal's own rate by the entries of 'methods'.
type Causal p = CausalP.T (SoxLib.Rate, p)

-- | Parameter accessor over the same @(SoxLib.Rate, p)@ pair
-- that 'Causal' processes are parameterized by.
type Param p = Param.T (SoxLib.Rate, p)

-- | Filter named for hum removal: the high-pass output of a universal filter.
--
-- The filter parameters are constant over the whole signal.
-- They are computed from @Pole 1@ combined with @Freq 800@,
-- where the frequency is first converted by 'freq'
-- (which consults the sample rate stored in the process parameter).
dehum :: Causal p (Value Float) (Value Float)
dehum =
   UniFilter.highpass
   ^<<
   UniFilter.causal $< SigP.constant filterParam
   where
      cutoff = freq (pure (Freq 800))
      filterParam = fmap (UniFilt.parameter . Pole 1) cutoff

-- | Filter named for extracting rumble: the low-pass output of a
-- universal filter.
--
-- The filter parameters are constant over the whole signal.
-- They are computed from @Pole 5@ combined with @Freq 220@,
-- where the frequency is first converted by 'freq'
-- (which consults the sample rate stored in the process parameter).
rumble :: Causal p (Value Float) (Value Float)
rumble =
   UniFilter.lowpass
   ^<<
   UniFilter.causal $< SigP.constant filterParam
   where
      cutoff = freq (pure (Freq 220))
      filterParam = fmap (UniFilt.parameter . Pole 5) cutoff

-- | Band-pass output of a universal filter with constant parameters.
--
-- The first argument is passed as the first field of 'Pole',
-- the second argument is the band frequency, converted by 'freq'
-- before it becomes the second field of 'Pole'.
bandpass ::
   Param p Float -> Param p Freq -> Causal p (Value Float) (Value Float)
bandpass q f =
   UniFilter.bandpass
   ^<<
   UniFilter.causal $< SigP.constant filterParam
   where
      pole = liftA2 Pole q (freq f)
      filterParam = fmap UniFilt.parameter pole

-- | Convert a frequency parameter to the 'Float' that the filter parameter
-- computations in this module consume.
--
-- The conversion is delegated to 'Params.freq', which additionally receives
-- the sample rate taken from the first component of the process parameter.
freq :: Param p Freq -> Param p Float
freq f = liftA2 Params.freq sampleRate f
   where
      -- the rate lives in the first slot of the @(SoxLib.Rate, p)@ pair
      sampleRate = arr (Rate.Sample . fst)


-- | Signal-level variant of 'downSampleMaxAbsFrac':
-- fold the input chunk-wise with 'Fold.maxAbs', where the chunk sizes are
-- themselves a signal obtained from 'downSampleChunkSizes' applied to the
-- @sizeFrac@ parameter.
--
-- NOTE(review): the leading underscore presumably marks this definition
-- as intentionally unused - confirm before removing it.
_downSampleMaxAbsFrac ::
   Param.T p Double -> SigP.T p (Value Float) -> SigP.T p (Value Float)
_downSampleMaxAbsFrac sizeFrac xs =
   Causal.foldChunksPartial Fold.maxAbs xs $* chunkSizes
   where
      chunkSizes =
         SigP.fromStorableVectorLazy (fmap downSampleChunkSizes sizeFrac)

-- | Reduce a signal to one value per chunk using 'Fold.maxAbs'.
--
-- The chunk sizes are generated by 'downSampleChunkSizes' from the 'Double'
-- argument; the last argument is the parameter handed to the signal.
--
-- The chunked fold is bound outside the returned lambda on purpose:
-- a partial application to the signal alone can be reused
-- for many size fractions and parameters.
downSampleMaxAbsFrac ::
   SigP.T p (Value Float) -> Double -> p -> SVL.Vector Float
downSampleMaxAbsFrac xs =
   \sizeFrac p -> chunkFold p (downSampleChunkSizes sizeFrac)
   where
      chunkFold =
         CausalP.applyStorableChunky (Causal.foldChunksPartial Fold.maxAbs xs)

-- | Chunk sizes for down-sampling as a lazy storable vector.
--
-- The sizes are produced by 'SP.downSampleChunkSizes',
-- converted element-wise with 'fromIntegral' and
-- materialized with the default lazy chunk size.
downSampleChunkSizes :: Double -> SVL.Vector Word32
downSampleChunkSizes sizeFrac =
   SigG.fromState SigG.defaultLazySize $
   fmap fromIntegral $
   SP.downSampleChunkSizes sizeFrac

-- | Feed one input signal into two groups of three processes each
-- and deliver both result triples side by side.
--
-- First triple: three 'bandpass' filters with first argument 10,
-- one for each component of the frequency triple.
-- Second triple: the outputs of 'SPLLVM.zerothMoment',
-- 'SPLLVM.firstMoment' and 'SPLLVM.secondMoment'.
bandsDerivativesProc ::
   Param.T (SoxLib.Rate, p) (Triple Freq) ->
   CausalP.T
      (SoxLib.Rate, p)
      (Value Float)
      (Triple (Value Float), Triple (Value Float))
bandsDerivativesProc bandFreqs =
   bands &&& moments
   where
      -- band-pass tuned to the frequency picked out of the triple
      band select = bandpass 10 (fmap select bandFreqs)
      bands = fanout3 (band fst3) (band snd3) (band thd3)
      moments =
         fanout3 SPLLVM.zerothMoment SPLLVM.firstMoment SPLLVM.secondMoment

-- | Run 'bandsDerivativesProc' on a sampled signal.
--
-- The process parameter is the pair of the signal's unpacked rate and the
-- frequency triple; the process reads the triple from the second component.
-- The applied process is bound outside the returned lambda,
-- so it is not rebuilt for every signal passed to the same partial
-- application.
bandsDerivatives ::
   Triple Freq -> Signal.Sampled Float ->
   SVL.Vector (Triple Float, Triple Float)
bandsDerivatives bandFreqs =
   \(Signal.Cons rate sig) -> applied (Rate.unpack rate, bandFreqs) sig
   where
      applied = CausalP.applyStorableChunky (bandsDerivativesProc (arr snd))

-- | Pack a list of interval lengths into a lazy storable vector,
-- converting every length with 'fromIntegral' and
-- using the default lazy chunk size.
intervalSizes :: [Int] -> SVL.Vector Word32
intervalSizes sizes =
   SigG.fromList SigG.defaultLazySize (map fromIntegral sizes)

-- | Sum the absolute values of all six channels of a signal
-- separately over consecutive intervals.
--
-- The list gives the interval lengths (converted by 'intervalSizes'),
-- the last argument is the parameter handed to the signal.
-- One pair of triples is returned per folded interval.
--
-- The chunked fold is bound outside the returned lambda,
-- so a partial application to the signal alone can be reused.
allSums ::
   SigP.T p (Triple (Value Float), Triple (Value Float)) ->
   [Int] -> p -> [(Triple Float, Triple Float)]
allSums xs =
   \sizes p -> SVL.unpack (chunkFold p (intervalSizes sizes))
   where
      sumAbs = Fold.premap A.abs Fold.sum
      sumAbs3 = Fold.triple sumAbs sumAbs sumAbs
      chunkFold =
         CausalP.applyStorableChunky
            (Causal.foldChunksPartial (Fold.pair sumAbs3 sumAbs3) xs)


-- | Pair every band frequency with the sum belonging to it,
-- pass the three pairs to 'SP.centroidVariance3' and
-- take the square root of the second component of its result
-- (presumably turning a variance into a standard deviation,
-- judging by the helper's name).
spectralBandDistr :: Triple Freq -> Triple Float -> (Float, Float)
spectralBandDistr (Freq f0, Freq f1, Freq f2) (s0, s1, s2) =
   mapSnd sqrt (SP.centroidVariance3 (f0, s0) (f1, s1) (f2, s2))

-- | Derive the features of one interval from its two sum triples.
--
-- The first triple goes through 'spectralBandDistr' together with the
-- band frequencies, the second triple supplies the three arguments
-- of 'SD.spectralDistribution1'.
bandParameters ::
   Triple Freq -> (Triple Float, Triple Float) -> ((Float, Float), SD.T Float)
bandParameters bandFreqs (bandSums, diffSums) =
   (bandDistr, diffDistr)
   where
      bandDistr = spectralBandDistr bandFreqs bandSums
      diffDistr = uncurry3 SD.spectralDistribution1 diffSums


-- | The 'Methods.T' record implemented with the LLVM-based processes
-- of this module.
--
-- Every field consists of a processing function bound in the @where@ clause
-- and a thin lambda that takes the signal apart and assembles the process
-- parameter @(rate, payload)@ from it.
-- The processing functions are bound outside the lambdas,
-- so they are shared by all calls of the respective field.
-- Each field has its own binding since their parameter types differ.
methods :: Methods.T
methods =
   Methods.Cons {
      Methods.dehum =
         \(Signal.Cons rate xs) ->
            Signal.Cons rate (runDehum (Rate.unpack rate, ()) xs),

      Methods.rumble =
         \(Signal.Cons rate xs) ->
            Signal.Cons rate (runRumble (Rate.unpack rate, ()) xs),

      -- NOTE(review): the size fraction is computed by plain division here
      -- whereas 'Methods.bandpassDownSample' calls 'Rate.ratio'.
      -- The types of the two @featRate@ arguments are not visible
      -- in this module - confirm whether both should use 'Rate.ratio'.
      Methods.downSampleAbs =
         \featRate (Signal.Cons rate xs) ->
            runDownSampleAbs
               (Rate.unpack rate / featRate) (Rate.unpack rate, xs),

      Methods.bandpassDownSample =
         \featRate f (Signal.Cons rate xs) ->
            runBandpassDownSample
               (Rate.ratio rate featRate) (Rate.unpack rate, (f, xs)),

      Methods.bandParameters =
         \bandFreqs (Signal.Cons rate xs) sizes ->
            map (bandParameters bandFreqs)
               (runBandParameters sizes (Rate.unpack rate, (bandFreqs, xs)))
   }
   where
      runDehum =
         CausalP.applyStorableChunky (dehum <<< Causal.map Bin.toCanonical)

      runRumble =
         CausalP.applyStorableChunky (rumble <<< Causal.map Bin.toCanonical)

      -- payload is the signal itself
      runDownSampleAbs =
         downSampleMaxAbsFrac (SigP.fromStorableVectorLazy (arr snd))

      -- payload is the pair (band frequency, signal)
      runBandpassDownSample =
         downSampleMaxAbsFrac
            (CausalP.apply
               (bandpass 10 (arr (fst . snd)) <<< Causal.map Bin.toCanonical)
               (SigP.fromStorableVectorLazy (arr (snd . snd))))

      -- payload is the pair (band frequencies, signal)
      runBandParameters =
         allSums
            (CausalP.apply
               (bandsDerivativesProc (arr (fst . snd)))
               (SigP.fromStorableVectorLazy (arr (snd . snd))))