{-# LANGUAGE UndecidableInstances #-}

module Telescope.Data.DataCube where

import Data.Kind
import Data.Massiv.Array as M hiding (mapM)
import Data.Proxy
import GHC.TypeLits (natVal)
import Telescope.Data.Array (AxesIndex (..))
import Telescope.Data.Axes (Axes, Major (Row))


-- Results ------------------------------------------------------------------------------

-- | A delayed ('D') massiv array of 'Double's tagged with a type-level list
-- of axis labels @as@. The index type of the wrapped array is computed from
-- the label list by 'IndexOf'.
newtype DataCube (as :: [Type]) = DataCube
  { array :: Array D (IndexOf as) Double
  -- ^ The underlying array
  }


-- | Two cubes are equal exactly when their underlying arrays compare equal.
instance (Index (IndexOf as)) => Eq (DataCube as) where
  DataCube arr == DataCube arr2 = arr == arr2


-- | Shows a cube by delegating to the 'Show' instance of the wrapped array;
-- the 'DataCube' constructor itself is not rendered.
instance (Ragged L (IndexOf as) Double) => Show (DataCube as) where
  show (DataCube a) = show a


-- | Associates a type-level list of axis labels with the index type used to
-- address a 'DataCube' carrying those axes.
class HasIndex (as :: [Type]) where
  -- | The index type for the axis list @as@.
  type IndexOf as :: Type


-- Instances for axis lists of length 0 through 5. A list of n labels maps to
-- the index type @Ixn@; the label types themselves are ignored, only the
-- length of the list matters. Lists longer than 5 have no instance.
instance HasIndex '[] where
  type IndexOf '[] = Ix0
instance HasIndex '[a] where
  type IndexOf '[a] = Ix1
instance HasIndex '[a, b] where
  type IndexOf '[a, b] = Ix2
instance HasIndex '[a, b, c] where
  type IndexOf '[a, b, c] = Ix3
instance HasIndex '[a, b, c, d] where
  type IndexOf '[a, b, c, d] = Ix4
instance HasIndex '[a, b, c, d, e] where
  type IndexOf '[a, b, c, d, e] = Ix5


-- | Break a cube into a list of lower-rank cubes, one per slice along its
-- first-listed axis @a@. Implemented with 'foldOuterSlice', accumulating in
-- the list monoid.
outerList
  :: forall a as
   . (Lower (IndexOf (a : as)) ~ IndexOf as, Index (IndexOf as), Index (IndexOf (a : as)))
  => DataCube (a : as)
  -> [DataCube as]
outerList (DataCube a) = foldOuterSlice row a
 where
  -- Wrap each slice in a singleton list so the monoidal fold concatenates them.
  row :: Array D (IndexOf as) Double -> [DataCube as]
  row r = [DataCube r]


-- | Swap the two first-listed (major) axes of a cube: @a : b : xs@ becomes
-- @b : a : xs@. The array itself is rearranged with massiv's 'transposeInner'.
transposeMajor
  :: (IndexOf (a : b : xs) ~ IndexOf (b : a : xs), Index (Lower (IndexOf (b : a : xs))), Index (IndexOf (b : a : xs)))
  => DataCube (a : b : xs)
  -> DataCube (b : a : xs)
transposeMajor (DataCube arr) = DataCube $ transposeInner arr


-- | Swap the two last-listed (minor) axes of a 4-axis cube: @[a, b, c, d]@
-- becomes @[a, b, d, c]@. The array itself is rearranged with massiv's
-- 'transposeOuter'.
transposeMinor4
  :: DataCube [a, b, c, d]
  -> DataCube [a, b, d, c]
transposeMinor4 (DataCube arr) = DataCube $ transposeOuter arr


-- | Swap the two last-listed (minor) axes of a 3-axis cube: @[a, b, c]@
-- becomes @[a, c, b]@. The array itself is rearranged with massiv's
-- 'transposeOuter'.
transposeMinor3
  :: DataCube [a, b, c]
  -> DataCube [a, c, b]
transposeMinor3 (DataCube arr) = DataCube $ transposeOuter arr


-- | Slice along the 1st major dimension: fix the first-listed axis @a@ at the
-- given position and return the cube over the remaining axes @xs@.
--
-- Uses massiv's '!>' operator, which carries a @HasCallStack@ constraint.
-- NOTE(review): presumably it raises an error for an out-of-range position —
-- confirm against the massiv documentation.
sliceM0
  :: ( Lower (IndexOf (a : xs)) ~ IndexOf xs
     , Index (IndexOf xs)
     , Index (IndexOf (a : xs))
     )
  => Int
  -> DataCube (a : xs)
  -> DataCube xs
sliceM0 a (DataCube arr) = DataCube (arr !> a)


-- | Slice along the 2nd major dimension: fix the second-listed axis @b@ at
-- the given position and return the cube over the remaining axes @a : xs@.
--
-- Uses massiv's '<!>' operator, which carries a @HasCallStack@ constraint.
-- NOTE(review): presumably it raises an error for an out-of-range position —
-- confirm against the massiv documentation.
sliceM1
  :: forall a b xs
   . ( Lower (IndexOf (a : b : xs)) ~ IndexOf (a : xs)
     , Index (IndexOf (a : xs))
     , Index (IndexOf (a : b : xs))
     )
  => Int
  -> DataCube (a : b : xs)
  -> DataCube (a : xs)
sliceM1 b (DataCube arr) =
  -- Rank of the input cube, read off the type-level 'Dimensions' of its index.
  let dims = fromIntegral $ natVal @(Dimensions (IndexOf (a : b : xs))) Proxy
   in -- One below the full rank selects the second-listed axis.
      DataCube $ arr <!> (Dim (dims - 1), b)


-- | Slice along the 3rd major dimension: fix the third-listed axis @c@ at the
-- given position and return the cube over the remaining axes @a : b : xs@.
--
-- Uses massiv's '<!>' operator, which carries a @HasCallStack@ constraint.
-- NOTE(review): presumably it raises an error for an out-of-range position —
-- confirm against the massiv documentation.
sliceM2
  :: forall a b c xs
   . ( Lower (IndexOf (a : b : c : xs)) ~ IndexOf (a : b : xs)
     , Index (IndexOf (a : b : xs))
     , Index (IndexOf (a : b : c : xs))
     )
  => Int
  -> DataCube (a : b : c : xs)
  -> DataCube (a : b : xs)
sliceM2 c (DataCube arr) =
  -- Rank of the input cube, read off the type-level 'Dimensions' of its index.
  let dims = fromIntegral $ natVal @(Dimensions (IndexOf (a : b : c : xs))) Proxy
   in -- Two below the full rank selects the third-listed axis.
      DataCube $ arr <!> (Dim (dims - 2), c)


-- | Split a cube in two at the given position along its first-listed axis
-- @a@. Both halves keep the full axis list.
--
-- Delegates to massiv's 'M.splitAtM'; any failure it reports is surfaced
-- through the 'MonadThrow' context @m@.
-- NOTE(review): presumably the first half holds positions below the split
-- point and the second half the rest — confirm against the massiv docs.
splitM0
  :: forall a xs m
   . ( Index (IndexOf (a : xs))
     , MonadThrow m
     )
  => Int
  -> DataCube (a : xs)
  -> m (DataCube (a : xs), DataCube (a : xs))
splitM0 a (DataCube arr) = do
  -- The full rank of the cube selects the first-listed axis.
  let dims = fromIntegral $ natVal @(Dimensions (IndexOf (a : xs))) Proxy
  (arr1, arr2) <- M.splitAtM (Dim dims) a arr
  pure (DataCube arr1, DataCube arr2)


-- | Split a cube in two at the given position along its second-listed axis
-- @b@. Both halves keep the full axis list.
--
-- Delegates to massiv's 'M.splitAtM'; any failure it reports is surfaced
-- through the 'MonadThrow' context @m@.
-- NOTE(review): presumably the first half holds positions below the split
-- point and the second half the rest — confirm against the massiv docs.
splitM1
  :: forall a b xs m
   . ( Index (IndexOf (a : xs))
     , Index (IndexOf (a : b : xs))
     , MonadThrow m
     )
  => Int
  -> DataCube (a : b : xs)
  -> m (DataCube (a : b : xs), DataCube (a : b : xs))
splitM1 b (DataCube arr) = do
  -- Deliberately measures the rank of the index for @a : xs@ (the axis list
  -- with @b@ removed) rather than of the full cube. Assuming that index has
  -- one dimension fewer than the cube's own, this is the dimension number of
  -- the second-listed axis. This is why the signature requires
  -- @Index (IndexOf (a : xs))@.
  let dims = fromIntegral $ natVal @(Dimensions (IndexOf (a : xs))) Proxy
  (arr1, arr2) <- M.splitAtM (Dim dims) b arr
  pure (DataCube arr1, DataCube arr2)


-- | Describe a cube's extent as row-major 'Axes': the size of the wrapped
-- array is unwrapped from 'Sz' and converted with 'indexAxes'.
dataCubeAxes :: (Index (IndexOf as), AxesIndex (IndexOf as)) => DataCube as -> Axes Row
dataCubeAxes (DataCube arr) =
  let Sz ix = M.size arr
   in indexAxes ix