module Protocols.Vec (
vecCircuits,
append,
append3,
split,
split3,
zip,
zip3,
zip4,
zip5,
unzip,
unzip3,
unzip4,
unzip5,
concat,
unconcat,
) where
import Data.Tuple
import Prelude ()
import Clash.Prelude hiding (
concat,
split,
unconcat,
unzip,
unzip3,
unzip4,
unzip5,
zip,
zip3,
zip4,
zip5,
)
import Clash.Prelude qualified as C
import Protocols.Internal (applyC)
import Protocols.Plugin
-- | Run a vector of 'Circuit's side by side as one 'Circuit' whose ports are
-- vectors.
--
-- Element @i@ of every input and output vector is wired to the @i@-th circuit
-- of @fs@; the circuits are not connected to each other.
vecCircuits :: (C.KnownNat n) => C.Vec n (Circuit a b) -> Circuit (C.Vec n a) (C.Vec n b)
vecCircuits fs = Circuit runAll
 where
  -- The port pair is taken apart lazily, just as @uncurry C.zip@ would.
  runAll ~(fwds, bwds) = C.unzip (C.zipWith runOne fs (C.zip fwds bwds))

  -- Unwrap a single circuit into its underlying port function.
  runOne (Circuit g) = g
-- | Concatenate two vectors of circuits into a single vector of circuits.
--
-- Forward data is joined with '++' (elements of the first vector come
-- first); backward data is split again after @n0@ elements with 'splitAtI'.
append ::
  (C.KnownNat n0) =>
  Circuit (C.Vec n0 circuit, C.Vec n1 circuit) (C.Vec (n0 + n1) circuit)
append = applyC joinFwd splitAtI
 where
  -- Lazy match, equivalent to @uncurry (++)@.
  joinFwd ~(front, back) = front ++ back
-- | Concatenate three vectors of circuits into a single vector of circuits.
--
-- Forward data is joined in order with 'append3Vec'; backward data is split
-- back into its three parts with 'split3Vec'.
append3 ::
  (C.KnownNat n0, C.KnownNat n1, C.KnownNat n2) =>
  Circuit
    (C.Vec n0 circuit, C.Vec n1 circuit, C.Vec n2 circuit)
    (C.Vec (n0 + n1 + n2) circuit)
append3 = applyC (uncurry3 append3Vec) split3Vec
-- | Split a vector of circuits into two vectors of circuits.
--
-- Forward data is cut after @n0@ elements with 'splitAtI'; backward data from
-- both halves is joined back together with '++'.
split ::
  (C.KnownNat n0) =>
  Circuit (C.Vec (n0 + n1) circuit) (C.Vec n0 circuit, C.Vec n1 circuit)
split = applyC splitAtI joinBwd
 where
  -- Projections keep the pair lazy, exactly like @uncurry (++)@.
  joinBwd pair = fst pair ++ snd pair
-- | Split a vector of circuits into three vectors of circuits.
--
-- Forward data is cut into lengths @n0@, @n1@ and @n2@ with 'split3Vec';
-- backward data from the three parts is re-joined with 'append3Vec'.
split3 ::
  (C.KnownNat n0, C.KnownNat n1, C.KnownNat n2) =>
  Circuit
    (C.Vec (n0 + n1 + n2) circuit)
    (C.Vec n0 circuit, C.Vec n1 circuit, C.Vec n2 circuit)
split3 = applyC split3Vec joinBwd
 where
  joinBwd parts = uncurry3 append3Vec parts
-- | Combine two vectors of protocols into one vector of pairs.
--
-- Forward data is zipped with 'C.zip'; backward data is unzipped with
-- 'C.unzip'.
zip ::
  (C.KnownNat n) =>
  Circuit (C.Vec n a, C.Vec n b) (C.Vec n (a, b))
zip = applyC (\pair -> C.zip (fst pair) (snd pair)) C.unzip
-- | Combine three vectors of protocols into one vector of triples.
--
-- Forward data is zipped with 'C.zip3'; backward data is unzipped with
-- 'C.unzip3'.
zip3 ::
  (C.KnownNat n) =>
  Circuit (C.Vec n a, C.Vec n b, C.Vec n c) (C.Vec n (a, b, c))
zip3 = applyC zipFwd C.unzip3
 where
  zipFwd = uncurry3 C.zip3
-- | Combine four vectors of protocols into one vector of 4-tuples.
--
-- Forward data is zipped with 'C.zip4'; backward data is unzipped with
-- 'C.unzip4'.
zip4 ::
  (C.KnownNat n) =>
  Circuit (C.Vec n a, C.Vec n b, C.Vec n c, C.Vec n d) (C.Vec n (a, b, c, d))
-- The incoming tuple is matched irrefutably (note the space in @\\ ~@), so it
-- is only taken apart once a component is demanded. This matches the lazy
-- tuple handling used elsewhere in this module; a strict match could force
-- the tuple early when circuits are connected in a feedback loop.
zip4 = applyC (\ ~(ws, xs, ys, zs) -> C.zip4 ws xs ys zs) C.unzip4
-- | Combine five vectors of protocols into one vector of 5-tuples.
--
-- Forward data is zipped with 'C.zip5'; backward data is unzipped with
-- 'C.unzip5'.
zip5 ::
  (C.KnownNat n) =>
  Circuit (C.Vec n a, C.Vec n b, C.Vec n c, C.Vec n d, C.Vec n e) (C.Vec n (a, b, c, d, e))
-- Irrefutable match (note the space in @\\ ~@): the incoming tuple is not
-- forced until one of its components is needed. This keeps the tuple handling
-- consistent with the lazy matches used elsewhere in this module.
zip5 = applyC (\ ~(vs, ws, xs, ys, zs) -> C.zip5 vs ws xs ys zs) C.unzip5
-- | Split one vector of pairs into two vectors of protocols.
--
-- Forward data is unzipped with 'C.unzip'; backward data is zipped with
-- 'C.zip'.
unzip ::
  (C.KnownNat n) =>
  Circuit (C.Vec n (a, b)) (C.Vec n a, C.Vec n b)
unzip = applyC C.unzip zipBwd
 where
  -- Projections keep the pair lazy, exactly like @uncurry C.zip@.
  zipBwd pair = C.zip (fst pair) (snd pair)
-- | Split one vector of triples into three vectors of protocols.
--
-- Forward data is unzipped with 'C.unzip3'; backward data is zipped with
-- 'C.zip3'.
unzip3 ::
  (C.KnownNat n) =>
  Circuit (C.Vec n (a, b, c)) (C.Vec n a, C.Vec n b, C.Vec n c)
unzip3 = applyC C.unzip3 (\triple -> uncurry3 C.zip3 triple)
-- | Split one vector of 4-tuples into four vectors of protocols.
--
-- Forward data is unzipped with 'C.unzip4'; backward data is zipped with
-- 'C.zip4'.
unzip4 ::
  (C.KnownNat n) =>
  Circuit (C.Vec n (a, b, c, d)) (C.Vec n a, C.Vec n b, C.Vec n c, C.Vec n d)
unzip4 = applyC C.unzip4 zipBwd
 where
  -- Irrefutable pattern: the tuple is only taken apart when a component is
  -- actually used.
  zipBwd ~(ws, xs, ys, zs) = C.zip4 ws xs ys zs
-- | Split one vector of 5-tuples into five vectors of protocols.
--
-- Forward data is unzipped with 'C.unzip5'; backward data is zipped with
-- 'C.zip5'.
unzip5 ::
  (C.KnownNat n) =>
  Circuit (C.Vec n (a, b, c, d, e)) (C.Vec n a, C.Vec n b, C.Vec n c, C.Vec n d, C.Vec n e)
unzip5 = applyC C.unzip5 zipBwd
 where
  -- Irrefutable pattern: the tuple is only taken apart when a component is
  -- actually used.
  zipBwd ~(vs, ws, xs, ys, zs) = C.zip5 vs ws xs ys zs
-- | Flatten a vector of vectors of circuits into a single vector of circuits.
--
-- Forward data is flattened with 'C.concat'; backward data is regrouped into
-- @n0@ chunks of @n1@ elements with 'C.unconcat'.
concat ::
  (C.KnownNat n0, C.KnownNat n1) =>
  Circuit (C.Vec n0 (C.Vec n1 circuit)) (C.Vec (n0 * n1) circuit)
concat = applyC flatten regroup
 where
  flatten = C.concat
  regroup v = C.unconcat SNat v
-- | Regroup a flat vector of circuits into @n@ vectors of @m@ circuits each.
--
-- The 'SNat' argument selects the inner length @m@. Forward data is chunked
-- with 'C.unconcat'; backward data is flattened again with 'C.concat'.
unconcat ::
  (C.KnownNat n, C.KnownNat m) =>
  SNat m ->
  Circuit (C.Vec (n * m) circuit) (C.Vec n (C.Vec m circuit))
unconcat chunk@SNat = applyC (C.unconcat chunk) C.concat
-- | Apply a three-argument function to the components of a triple.
--
-- The triple is matched irrefutably, the same way 'uncurry' from
-- "Data.Tuple" handles pairs. The function is therefore called without first
-- forcing the triple, and each component is only extracted when @f@ demands
-- it.
--
-- This helper feeds circuit forward/backward mappings ('zip3', 'unzip3',
-- 'append3', 'split3'). A strict match there would force the tuple before
-- any component is needed, which can make feedback between circuits diverge.
uncurry3 :: (a -> b -> c -> d) -> (a, b, c) -> d
uncurry3 f ~(a, b, c) = f a b c
-- | Concatenate three vectors in order: all of @v0@, then @v1@, then @v2@.
--
-- NOTE(review): the body only uses '++'; the 'KnownNat' constraints look
-- redundant here — confirm before dropping them.
append3Vec ::
  (KnownNat n0, KnownNat n1, KnownNat n2) =>
  C.Vec n0 a ->
  C.Vec n1 a ->
  C.Vec n2 a ->
  C.Vec (n0 + n1 + n2) a
-- Grouped to the left so the result type is literally @(n0 + n1) + n2@,
-- matching the (left-associative) type-level sum in the signature.
append3Vec v0 v1 v2 = (v0 ++ v1) ++ v2
-- | Split a vector into consecutive parts of lengths @n0@, @n1@ and @n2@.
--
-- The first @n0@ elements are cut off first; the remainder is then split
-- after @n1@ elements.
split3Vec ::
  (KnownNat n0, KnownNat n1, KnownNat n2) =>
  C.Vec (n0 + n1 + n2) a ->
  (C.Vec n0 a, C.Vec n1 a, C.Vec n2 a)
split3Vec v =
  -- Lazy pattern bindings: the remainder is only split when @v1@ or @v2@ is
  -- demanded.
  let (v0, rest) = splitAtI v
      (v1, v2) = splitAtI rest
   in (v0, v1, v2)