-- | Utility functions for working with 'Vec's of t'Circuit's.
module Protocols.Vec (
  vecCircuits,
  append,
  append3,
  split,
  split3,
  zip,
  zip3,
  zip4,
  zip5,
  unzip,
  unzip3,
  unzip4,
  unzip5,
  concat,
  unconcat,
) where

-- base
import Data.Tuple
import Prelude ()

-- clash-prelude
import Clash.Prelude hiding (
  concat,
  split,
  unconcat,
  unzip,
  unzip3,
  unzip4,
  unzip5,
  zip,
  zip3,
  zip4,
  zip5,
 )
import Clash.Prelude qualified as C

-- clash-protocols-base
import Protocols.Internal (applyC)
import Protocols.Plugin

{- | Combine a 'Vec' of t'Circuit's into a single t'Circuit' whose input and
output are 'Vec's. The circuits run in parallel: the @i@th circuit receives the
@i@th forward and backward inputs, and its two results are placed at index @i@
of the corresponding outputs.

A general inverse of 'vecCircuits' cannot exist, because we cannot guarantee
that the @n@th output circuit depends only on the @n@th input circuit.
-}
vecCircuits :: (C.KnownNat n) => C.Vec n (Circuit a b) -> Circuit (C.Vec n a) (C.Vec n b)
vecCircuits circuits = Circuit go
 where
  -- The pair is matched lazily so neither half is forced before the zipped
  -- vector actually demands it.
  go ~(fwds, bwds) = C.unzip (circuitFn <$> circuits <*> C.zip fwds bwds)

  -- Unwrap a circuit to its underlying function.
  circuitFn (Circuit g) = g

-- | Concatenate two vectors of the same kind of circuit into one vector of
-- circuits. Data coming back through the combined vector is split at index
-- @n0@ into the two original parts.
append ::
  (C.KnownNat n0) =>
  Circuit (C.Vec n0 circuit, C.Vec n1 circuit) (C.Vec (n0 + n1) circuit)
append = applyC (\ ~(front, back) -> front ++ back) splitAtI

-- | Append three separate vectors of the same circuits into one vector of
-- circuits. In the tuple-to-vector direction the three vectors are
-- concatenated with 'append3Vec'. In the opposite direction the combined
-- vector is split back into its three parts with 'split3Vec'.
append3 ::
  (C.KnownNat n0, C.KnownNat n1, C.KnownNat n2) =>
  Circuit
    (C.Vec n0 circuit, C.Vec n1 circuit, C.Vec n2 circuit)
    (C.Vec (n0 + n1 + n2) circuit)
append3 = applyC (uncurry3 append3Vec) split3Vec

-- | Split a vector of circuits into two vectors of circuits: the first @n0@
-- circuits and the remaining @n1@. Data flowing back through the two parts is
-- concatenated again.
split ::
  (C.KnownNat n0) =>
  Circuit (C.Vec (n0 + n1) circuit) (C.Vec n0 circuit, C.Vec n1 circuit)
split = applyC splitAtI (\ ~(front, back) -> front ++ back)

-- | Split a vector of circuits into three vectors of circuits of lengths
-- @n0@, @n1@ and @n2@ (in that order). Data flowing back through the three
-- parts is concatenated into one vector again.
split3 ::
  (C.KnownNat n0, C.KnownNat n1, C.KnownNat n2) =>
  Circuit
    (C.Vec (n0 + n1 + n2) circuit)
    (C.Vec n0 circuit, C.Vec n1 circuit, C.Vec n2 circuit)
split3 = applyC splitFwd joinBwd
 where
  splitFwd flat = split3Vec flat
  joinBwd parts = uncurry3 append3Vec parts

{- | Pair up two vectors of circuits element-wise, producing a vector of
2-tuples of circuits. Both input vectors share the length @n@, so vectors of
different lengths are rejected by the type checker.
-}
zip ::
  (C.KnownNat n) =>
  Circuit (C.Vec n a, C.Vec n b) (C.Vec n (a, b))
zip = applyC (\ ~(xs, ys) -> C.zip xs ys) C.unzip

{- | Combine three vectors of circuits element-wise into a vector of 3-tuples
of circuits. All three input vectors share the length @n@, so vectors of
different lengths are rejected by the type checker.
-}
zip3 ::
  (C.KnownNat n) =>
  Circuit (C.Vec n a, C.Vec n b, C.Vec n c) (C.Vec n (a, b, c))
zip3 = applyC zipFwd unzipBwd
 where
  zipFwd vecs = uncurry3 C.zip3 vecs
  unzipBwd tuples = C.unzip3 tuples

{- | Combine four vectors of circuits element-wise into a vector of 4-tuples
of circuits. All four input vectors share the length @n@, so vectors of
different lengths are rejected by the type checker.
-}
zip4 ::
  (C.KnownNat n) =>
  Circuit (C.Vec n a, C.Vec n b, C.Vec n c, C.Vec n d) (C.Vec n (a, b, c, d))
-- Irrefutable (~) pattern: do not force the input tuple until one of its
-- components is demanded. This matches base's 'uncurry' used by 'zip'.
zip4 = applyC (\ ~(a, b, c, d) -> C.zip4 a b c d) C.unzip4

{- | Combine five vectors of circuits element-wise into a vector of 5-tuples
of circuits. All five input vectors share the length @n@, so vectors of
different lengths are rejected by the type checker.
-}
zip5 ::
  (C.KnownNat n) =>
  Circuit (C.Vec n a, C.Vec n b, C.Vec n c, C.Vec n d, C.Vec n e) (C.Vec n (a, b, c, d, e))
-- Irrefutable (~) pattern: do not force the input tuple until one of its
-- components is demanded. This matches base's 'uncurry' used by 'zip'.
zip5 = applyC (\ ~(a, b, c, d, e) -> C.zip5 a b c d e) C.unzip5

-- | Split a vector of 2-tuples of circuits into two vectors of circuits, one
-- per tuple position. Data flowing back is zipped together again.
unzip ::
  (C.KnownNat n) =>
  Circuit (C.Vec n (a, b)) (C.Vec n a, C.Vec n b)
unzip = applyC C.unzip (\ ~(xs, ys) -> C.zip xs ys)

-- | Split a vector of 3-tuples of circuits into three vectors of circuits, one
-- per tuple position. Data flowing back is zipped together again.
unzip3 ::
  (C.KnownNat n) =>
  Circuit (C.Vec n (a, b, c)) (C.Vec n a, C.Vec n b, C.Vec n c)
unzip3 = applyC unzipFwd rezipBwd
 where
  unzipFwd tuples = C.unzip3 tuples
  rezipBwd parts = uncurry3 C.zip3 parts

-- | Split a vector of 4-tuples of circuits into four vectors of circuits, one
-- per tuple position. Data flowing back is zipped together again. The
-- returning tuple is matched lazily (~), so it is only forced once a
-- component is needed.
unzip4 ::
  (C.KnownNat n) =>
  Circuit (C.Vec n (a, b, c, d)) (C.Vec n a, C.Vec n b, C.Vec n c, C.Vec n d)
unzip4 = applyC C.unzip4 (\ ~(ws, xs, ys, zs) -> C.zip4 ws xs ys zs)

-- | Split a vector of 5-tuples of circuits into five vectors of circuits, one
-- per tuple position. Data flowing back is zipped together again. The
-- returning tuple is matched lazily (~), so it is only forced once a
-- component is needed.
unzip5 ::
  (C.KnownNat n) =>
  Circuit (C.Vec n (a, b, c, d, e)) (C.Vec n a, C.Vec n b, C.Vec n c, C.Vec n d, C.Vec n e)
unzip5 = applyC C.unzip5 (\ ~(vs, ws, xs, ys, zs) -> C.zip5 vs ws xs ys zs)

-- | Flatten a vector of @n0@ vectors, each holding @n1@ circuits, into a
-- single vector of @n0 * n1@ circuits. Data flowing back is regrouped into
-- chunks of @n1@.
concat ::
  (C.KnownNat n0, C.KnownNat n1) =>
  Circuit (C.Vec n0 (C.Vec n1 circuit)) (C.Vec (n0 * n1) circuit)
concat = applyC C.concat C.unconcatI

-- | Regroup a flat vector of @n * m@ circuits into @n@ vectors of @m@
-- circuits each. The chunk size @m@ is given by the 'SNat' argument. Data
-- flowing back is flattened again.
unconcat ::
  (C.KnownNat n, C.KnownNat m) =>
  SNat m ->
  Circuit (C.Vec (n * m) circuit) (C.Vec n (C.Vec m circuit))
unconcat SNat = applyC C.unconcatI C.concat

-- Internal utilities

-- | Uncurry a function with three arguments into a function that takes a
-- 3-tuple as argument.
--
-- The tuple is matched with an irrefutable (~) pattern, like 'uncurry' from
-- base, so it is only evaluated once one of its components is demanded. A
-- strict match would force the tuple as soon as the result is needed, which
-- can diverge when the tuple is itself produced lazily.
uncurry3 :: (a -> b -> c -> d) -> (a, b, c) -> d
uncurry3 f ~(a, b, c) = f a b c

-- | Concatenate three vectors into one. The elements of the first vector come
-- first and those of the last vector come last. This is the inverse of
-- 'split3Vec'.
append3Vec ::
  (KnownNat n0, KnownNat n1, KnownNat n2) =>
  C.Vec n0 a ->
  C.Vec n1 a ->
  C.Vec n2 a ->
  C.Vec (n0 + n1 + n2) a
append3Vec front middle back = front ++ (middle ++ back)

-- | Split a vector of length @n0 + n1 + n2@ into three consecutive pieces:
-- the first @n0@ elements, the next @n1@ elements, and the remaining @n2@
-- elements. The pieces may have different lengths. This is the inverse of
-- 'append3Vec'.
split3Vec ::
  (KnownNat n0, KnownNat n1, KnownNat n2) =>
  C.Vec (n0 + n1 + n2) a ->
  (C.Vec n0 a, C.Vec n1 a, C.Vec n2 a)
split3Vec v = (v0, v1, v2)
 where
  -- Peel off the first n0 elements, then split the remainder into n1 and n2.
  -- Separate bindings mean demanding v0 does not force the second split.
  (v0, rest) = splitAtI v
  (v1, v2) = splitAtI rest