{-# language BangPatterns #-}
{-# language BlockArguments #-}
{-# language PatternSynonyms #-}
{-# language DataKinds #-}
{-# language ExplicitNamespaces #-}
{-# language GADTs #-}
{-# language KindSignatures #-}
{-# language MagicHash #-}
{-# language UnliftedNewtypes #-}
{-# language RankNTypes #-}
{-# language ScopedTypeVariables #-}
{-# language TypeApplications #-}
{-# language TypeOperators #-}
{-# language UnboxedTuples #-}
{-# language UnboxedSums #-}

module ZipVector
  ( zip
  , unzip
  ) where

import Prelude hiding (map,zip,unzip)
import Arithmetic.Types (Nat#)
import Control.Monad.ST (runST)
import Data.Either.Void (pattern LeftVoid#,pattern RightVoid#)

import qualified VectorA as A
import qualified VectorB as B
import qualified VectorC as C
import qualified Arithmetic.Fin as Fin
import qualified Arithmetic.Nat as Nat

-- | Combine two length-@n@ vectors element by element with @f@, producing a
-- length-@n@ result.
--
-- Both input vectors are forced to WHNF on entry (bang patterns). When @n@ is
-- zero, the result is 'C.empty' retyped to length @n@ using the equality proof
-- returned by 'Nat.testZero#'. Otherwise a mutable vector is allocated, each
-- index visited by 'Fin.ascendM_#' (presumably all of @0 .. n-1@) is written
-- with the combined element, and the buffer is frozen.
zip :: (a -> b -> c) -> Nat# n -> A.Vector n a -> B.Vector n b -> C.Vector n c
{-# inline zip #-}
zip f n !va !vb = case Nat.testZero# n of
  -- Zero length: no elements to combine; reuse the empty vector at length n.
  LeftVoid# zeq -> C.substitute zeq C.empty
  -- Positive length: index 0 exists, so it can supply the fill value that
  -- 'C.initialized' takes as an argument. The loop overwrites that slot too,
  -- so f may end up being evaluated twice for index 0.
  RightVoid# zlt -> runST $ do
    dst <- C.initialized n (combineAt (Fin.construct# zlt Nat.N0#))
    Fin.ascendM_# n (\fin -> C.write dst fin (combineAt fin))
    -- dst is not used after this point and never escapes runST, which is
    -- what makes the unsafe freeze acceptable.
    C.unsafeFreeze dst
  where
    -- The output element for a given index: f applied to both inputs there.
    combineAt fin = f (A.index va fin) (B.index vb fin)

-- | Split a length-@n@ vector into two length-@n@ vectors. @f@ maps each
-- element to an unboxed pair: the first components go into the 'B' vector
-- and the second components go into the 'C' vector.
--
-- The input vector is forced to WHNF on entry. When @n@ is zero, both results
-- are the empty vectors retyped to length @n@ via the proof from
-- 'Nat.testZero#'.
--
-- Otherwise both results are projections of a single lazily bound pair built
-- by one 'runST' computation. Neither component is computed until one of them
-- is demanded, and forcing either one runs the whole computation.
unzip :: (a -> (# b, c #)) -> Nat# n -> A.Vector n a -> (# B.Vector n b, C.Vector n c #)
{-# inline unzip #-}
unzip f n !va = case Nat.testZero# n of
  LeftVoid# zeq ->
    (# B.substitute zeq B.empty, C.substitute zeq C.empty #)
  RightVoid# zlt ->
    -- Lazily bound: 'built' is only evaluated when fst/snd is forced.
    let built = runST
          ( -- Element 0 exists (0 < n). Its image under f supplies the fill
            -- values 'B.initialized' and 'C.initialized' take. The loop below
            -- calls f at index 0 again and overwrites those slots.
            case f (A.index va (Fin.construct# zlt Nat.N0#)) of
              (# seedB, seedC #) -> do
                mutB <- B.initialized n seedB
                mutC <- C.initialized n seedC
                Fin.ascendM_# n $ \ix -> case f (A.index va ix) of
                  (# b, c #) -> B.write mutB ix b >> C.write mutC ix c
                -- The mutable buffers are not used after this point and never
                -- escape runST, so the unsafe freezes are acceptable.
                frozenB <- B.unsafeFreeze mutB
                frozenC <- C.unsafeFreeze mutC
                pure (frozenB, frozenC)
          )
     in (# fst built, snd built #)