{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fconstraint-solver-iterations=20 #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

{- |
Experimental protocol-agnostic simulation support.

This module contains the generic simulation classes and helpers that used to be
part of the default "Protocols" import surface. New tests should prefer
protocol-specific drivers and checkers over adding more assumptions here.
-}
module Protocols.Experimental.Simulate (
  module Protocols.Experimental.Simulate.Types,
  simulateC,
  simulateCS,
  simulateCSE,
  simulateCircuit,
  def,
) where

import Clash.Explicit.Prelude qualified as CE
import Clash.Prelude (type (*), type (+))
import Clash.Prelude qualified as C
import Data.Default (Default (def))
import Data.Proxy
import Data.Tuple (swap)

import Protocols.Experimental.Simulate.Types
import Protocols.Internal
import Protocols.Internal.TH (
  backPressureTupleInstances,
  drivableTupleInstances,
  simulateTupleInstances,
 )
import Protocols.Plugin.Cpp (maxTupleSize)

{- $setup
>>> import Protocols
>>> import Protocols.Experimental.Df
>>> import Protocols.Experimental.Simulate
-}

-- | The unit protocol has a unit backward type, so the acknowledgement list is
-- ignored entirely.
instance Backpressure () where
  boolsToBwd _proxy _acks = ()

-- | Both components of a pair receive the very same acknowledgement pattern.
instance (Backpressure a, Backpressure b) => Backpressure (a, b) where
  boolsToBwd _ acks =
    let
      bwdA = boolsToBwd (Proxy @a) acks
      bwdB = boolsToBwd (Proxy @b) acks
     in
      (bwdA, bwdB)

-- Template Haskell splice; judging by its arguments it presumably generates
-- the analogous instances for tuples of arity 3 up to 'maxTupleSize'.
backPressureTupleInstances 3 maxTupleSize

-- | Every element of a vector receives the same acknowledgement pattern.
instance (C.KnownNat n, Backpressure a) => Backpressure (C.Vec n a) where
  boolsToBwd _ acks = C.repeat elemBwd
   where
    elemBwd = boolsToBwd (Proxy @a) acks

-- | A 'CSignal' has a unit backward value, so there is nothing to apply the
-- acknowledgements to.
instance Backpressure (CSignal dom a) where
  boolsToBwd _proxy _acks = ()

-- | The unit protocol has zero stall channels. Every conversion is the
-- identity, and stalling yields the identity circuit 'idC'.
instance Simulate () where
  type SimulateFwdType () = ()
  type SimulateBwdType () = ()
  type SimulateChannels () = 0

  simToSigFwd _ fwd = fwd
  simToSigBwd _ bwd = bwd
  sigToSimFwd _ fwd = fwd
  sigToSimBwd _ bwd = bwd

  stallC _conf _stalls = idC

-- | Driving the unit protocol is trivial. The driver is 'idC' and sampling
-- always yields @()@.
instance Drivable () where
  type ExpectType () = ()

  toSimulateType Proxy () = ()
  fromSimulateType Proxy () = ()

  driveC _conf _inputs = idC
  sampleC _conf _circuit = ()

-- | A pair of protocols is simulated component-wise. The stall vector is split
-- so that the first @SimulateChannels a@ entries go to the left component and
-- the remaining @SimulateChannels b@ entries go to the right one.
instance (Simulate a, Simulate b) => Simulate (a, b) where
  type SimulateFwdType (a, b) = (SimulateFwdType a, SimulateFwdType b)
  type SimulateBwdType (a, b) = (SimulateBwdType a, SimulateBwdType b)
  type SimulateChannels (a, b) = SimulateChannels a + SimulateChannels b

  -- Project with 'fst' / 'snd' instead of matching on the pair, so the
  -- incoming pair is only forced once one of its components is demanded.
  simToSigFwd Proxy fwds = (simToSigFwd (Proxy @a) (fst fwds), simToSigFwd (Proxy @b) (snd fwds))
  simToSigBwd Proxy bwds = (simToSigBwd (Proxy @a) (fst bwds), simToSigBwd (Proxy @b) (snd bwds))
  sigToSimFwd Proxy fwdSigs = (sigToSimFwd (Proxy @a) (fst fwdSigs), sigToSimFwd (Proxy @b) (snd fwdSigs))
  sigToSimBwd Proxy bwdSigs = (sigToSimBwd (Proxy @a) (fst bwdSigs), sigToSimBwd (Proxy @b) (snd bwdSigs))

  stallC conf stalls = Circuit go
   where
    (stallsL, stallsR) = C.splitAtI @(SimulateChannels a) @(SimulateChannels b) stalls
    Circuit stalledL = stallC @a conf stallsL
    Circuit stalledR = stallC @b conf stallsR
    -- Irrefutable match: the input pair is only taken apart on demand.
    go ~((fwdL0, fwdR0), (bwdL0, bwdR0)) =
      let
        (fwdL1, bwdL1) = stalledL (fwdL0, bwdL0)
        (fwdR1, bwdR1) = stalledR (fwdR0, bwdR0)
       in
        ((fwdL1, fwdR1), (bwdL1, bwdR1))

-- Template Haskell splice; judging by its arguments it presumably generates
-- the analogous instances for tuples of arity 3 up to 'maxTupleSize'.
simulateTupleInstances 3 maxTupleSize

-- | A pair is driven and sampled component-wise. While sampling, both
-- components see the same acknowledgement pattern: deasserted for
-- 'resetCycles' cycles and asserted from then on.
instance (Drivable a, Drivable b) => Drivable (a, b) where
  type ExpectType (a, b) = (ExpectType a, ExpectType b)

  toSimulateType Proxy (x, y) =
    (toSimulateType (Proxy @a) x, toSimulateType (Proxy @b) y)

  fromSimulateType Proxy (x, y) =
    (fromSimulateType (Proxy @a) x, fromSimulateType (Proxy @b) y)

  driveC conf (fwdA, fwdB) = Circuit go
   where
    Circuit circA = driveC @a conf fwdA
    Circuit circB = driveC @b conf fwdB
    go (_, ~(bwdA, bwdB)) = ((), (snd (circA ((), bwdA)), snd (circB ((), bwdB))))

  sampleC conf (Circuit f) =
    ( sampleC @a conf (Circuit (const ((), fwdA)))
    , sampleC @b conf (Circuit (const ((), fwdB)))
    )
   where
    acks = replicate (resetCycles conf) False <> repeat True
    (_, (fwdA, fwdB)) = f ((), (boolsToBwd (Proxy @a) acks, boolsToBwd (Proxy @b) acks))

-- Template Haskell splice; judging by its arguments it presumably generates
-- the analogous instances for tuples of arity 3 up to 'maxTupleSize'.
drivableTupleInstances 3 maxTupleSize

-- | Reversing a protocol exchanges its forward and backward directions. Every
-- method delegates to the instance for @a@ with the directions swapped.
instance (Simulate a) => Simulate (Reverse a) where
  type SimulateFwdType (Reverse a) = SimulateBwdType a
  type SimulateBwdType (Reverse a) = SimulateFwdType a
  type SimulateChannels (Reverse a) = SimulateChannels a

  simToSigFwd Proxy bwds = simToSigBwd (Proxy @a) bwds
  simToSigBwd Proxy fwds = simToSigFwd (Proxy @a) fwds
  sigToSimFwd Proxy bwdSig = sigToSimBwd (Proxy @a) bwdSig
  sigToSimBwd Proxy fwdSig = sigToSimFwd (Proxy @a) fwdSig

  stallC conf stalls = Circuit flipped
   where
    Circuit stalled = stallC @a conf stalls
    -- The inner stall circuit expects @a@'s channel order. Swap the channels
    -- on the way in and swap its result back on the way out.
    flipped (fwd, bwd) = swap (stalled (bwd, fwd))

-- | A vector of protocols is simulated element-wise. The stall vector is
-- regrouped into @n@ chunks of @SimulateChannels a@ entries, one chunk per
-- element.
instance (CE.KnownNat n, Simulate a) => Simulate (C.Vec n a) where
  type SimulateFwdType (C.Vec n a) = C.Vec n (SimulateFwdType a)
  type SimulateBwdType (C.Vec n a) = C.Vec n (SimulateBwdType a)
  type SimulateChannels (C.Vec n a) = n * SimulateChannels a

  simToSigFwd Proxy fwds = C.map (simToSigFwd (Proxy @a)) fwds
  simToSigBwd Proxy bwds = C.map (simToSigBwd (Proxy @a)) bwds
  sigToSimFwd Proxy fwdSigs = C.map (sigToSimFwd (Proxy @a)) fwdSigs
  sigToSimBwd Proxy bwdSigs = C.map (sigToSimBwd (Proxy @a)) bwdSigs

  stallC conf stalls0 = Circuit go
   where
    perElemStalls = C.unconcatI @n @(SimulateChannels a) stalls0
    -- One stalling signal function per vector element.
    stalled = C.map (toSignals . stallC @a conf) perElemStalls
    go (fwds, bwds) = C.unzip (C.zipWith ($) stalled (C.zip fwds bwds))

-- | A vector is driven and sampled element-wise. While sampling, every element
-- sees the same acknowledgement pattern: deasserted for 'resetCycles' cycles
-- and asserted from then on.
instance (C.KnownNat n, Drivable a) => Drivable (C.Vec n a) where
  type ExpectType (C.Vec n a) = C.Vec n (ExpectType a)

  toSimulateType Proxy expected = C.map (toSimulateType (Proxy @a)) expected
  fromSimulateType Proxy simulated = C.map (fromSimulateType (Proxy @a)) simulated

  driveC conf fwds = Circuit go
   where
    -- Forward output of the driver for one element, given that element's
    -- backward input.
    drive fwd bwd = snd (toSignals (driveC @a conf fwd) ((), bwd))
    go (_, bwds) = ((), C.zipWith drive fwds bwds)

  sampleC conf (Circuit f) = C.map sampleOne fwds
   where
    acks = replicate (resetCycles conf) False <> repeat True
    (_, fwds) = f ((), C.repeat (boolsToBwd (Proxy @a) acks))
    sampleOne fwd = sampleC @a conf (Circuit (const ((), fwd)))

-- | A 'CSignal' is simulated as the list of its samples. Its backward
-- simulation type is @()@ and it occupies a single stall channel. Stalling
-- leaves it unchanged ('idC').
instance (C.KnownDomain dom) => Simulate (CSignal dom a) where
  type SimulateFwdType (CSignal dom a) = [a]
  type SimulateBwdType (CSignal dom a) = ()
  type SimulateChannels (CSignal dom a) = 1

  simToSigFwd Proxy samples = C.fromList_lazy samples
  simToSigBwd Proxy () = ()
  sigToSimFwd Proxy fwdSig = C.sample_lazy fwdSig
  sigToSimBwd Proxy _bwd = ()

  stallC _conf _stalls = idC

-- | Drives a 'CSignal' from a non-empty list of samples and samples it back
-- into a list.
--
-- The driven signal holds the first sample for 'resetCycles' cycles, then
-- plays the given samples, then repeats the first sample forever.
-- NOTE(review): padding after the input repeats the /first/ sample rather than
-- the last one. Confirm this is intended.
--
-- Sampling takes 'timeoutAfter' samples and, when 'ignoreReset' is set, drops
-- the first 'resetCycles' of them.
instance (C.NFDataX a, C.ShowX a, Show a, C.KnownDomain dom) => Drivable (CSignal dom a) where
  type ExpectType (CSignal dom a) = [a]

  toSimulateType Proxy = id
  fromSimulateType Proxy = id

  driveC _conf [] = error "CSignal.driveC: Can't drive with empty list"
  driveC SimulationConfig{resetCycles} inputs@(firstSample : _) = Circuit (const ((), padded))
   where
    resetPrefix = replicate resetCycles firstSample
    padded = C.fromList_lazy (resetPrefix <> inputs <> repeat firstSample)

  sampleC SimulationConfig{resetCycles, ignoreReset, timeoutAfter} (Circuit f) =
    dropReset (CE.sampleN_lazy timeoutAfter fwd)
   where
    (_, fwd) = f ((), ())
    dropReset
      | ignoreReset = drop resetCycles
      | otherwise = id

{- | Simulate a circuit. Includes samples while reset is asserted.
Not synthesizable.

The input is turned into a driver circuit with 'driveC' and composed in front
of the circuit under test. The output is then read back with 'sampleC'.

To figure out what input you need to supply, either solve the type family
'SimulateFwdType' manually, or let the repl do the work for you! Example:

>>> :kind! (forall dom a. SimulateFwdType (Df dom a))
...
= [Maybe a]

This would mean a @Circuit (Df dom a) (Df dom b)@ would need
@[Maybe a]@ as the last argument of 'simulateC' and would result in
@[Maybe b]@. Note that for this particular type you can neither supply
stalls nor introduce backpressure. If you want to do this use
'Protocols.Experimental.Df.stall'.
-}
simulateC ::
  forall a b.
  (Drivable a, Drivable b) =>
  -- | Circuit to simulate
  Circuit a b ->
  {- | Simulation configuration. Note that some options only apply to 'sampleC'
  and some only to 'driveC'.
  -}
  SimulationConfig ->
  -- | Circuit input
  SimulateFwdType a ->
  -- | Circuit output
  SimulateFwdType b
simulateC circuit config inputs = sampleC config driven
 where
  -- The circuit under test, fed by a driver that plays the given inputs.
  driven = driveC config inputs |> circuit

{- | Like 'simulateC', but does not allow the caller to control or observe
backpressure. Furthermore, it ignores all data produced while the reset is
asserted.

Concretely, it runs 'simulateC' with the default 'SimulationConfig' with
'ignoreReset' set. The input is converted with 'toSimulateType' and the
output with 'fromSimulateType'.

Example:

>>> import qualified Protocols.Experimental.Df as Df
>>> take 2 (simulateCS (Df.catMaybes @C.System @Int) [Nothing, Just 1, Nothing, Just 3])
[1,3]
-}
simulateCS ::
  forall a b.
  (Drivable a, Drivable b) =>
  -- | Circuit to simulate
  Circuit a b ->
  -- | Circuit input
  ExpectType a ->
  -- | Circuit output
  ExpectType b
simulateCS c =
  fromSimulateType (Proxy @b)
    . simulateC c def{ignoreReset = True}
    . toSimulateType (Proxy @a)

{- | Like 'simulateCS', but takes a circuit expecting a clock, reset, and enable.

The clock comes from 'C.clockGen' and the enable from 'C.enableGen'. The
active-high reset is asserted for the first @'resetCycles' 'def'@ cycles, the
same reset length as in the configuration 'simulateCS' uses. After that it
stays deasserted.
-}
simulateCSE ::
  forall dom a b.
  (Drivable a, Drivable b, C.KnownDomain dom) =>
  -- | Circuit to simulate
  (C.Clock dom -> C.Reset dom -> C.Enable dom -> Circuit a b) ->
  -- | Circuit input
  ExpectType a ->
  -- | Circuit output
  ExpectType b
simulateCSE c = simulateCS (c C.clockGen rst C.enableGen)
 where
  rst =
    C.unsafeFromActiveHigh
      (C.fromList (replicate (resetCycles def) True <> repeat False))

{- | Run the signal function of a t'Circuit' directly on simulation-level values.

The forward input for @a@ and the backward input for @b@ are converted to
signals with 'simToSigFwd' and 'simToSigBwd'. They are fed to the circuit via
'toSignals'. The resulting backward signals of @a@ and forward signals of @b@
are converted back with 'sigToSimBwd' and 'sigToSimFwd'.
-}
simulateCircuit ::
  forall a b.
  (Simulate a, Simulate b) =>
  SimulateFwdType a ->
  SimulateBwdType b ->
  Circuit a b ->
  (SimulateBwdType a, SimulateFwdType b)
simulateCircuit fwds bwds circ =
  let
    fwdIn = simToSigFwd (Proxy @a) fwds
    bwdIn = simToSigBwd (Proxy @b) bwds
    (bwdSig, fwdSig) = toSignals circ (fwdIn, bwdIn)
   in
    (sigToSimBwd (Proxy @a) bwdSig, sigToSimFwd (Proxy @b) fwdSig)