{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{- |
Framework for extracting the sizes of sub-shapes, as needed by 'Array.unsafeCreateWithSizes'.
-}
module Data.Array.Comfort.Shape.SubSize (
   T(Cons, measure),
   auto,
   atom,
   Sub(Sub),
   sub,
   pair,
   triple,
   append,

   C(ToShape),
   Atom(Atom),
   evaluate,
   ) where

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape ((::+)((::+)))


{- |
A measurement of shapes of type @sh@.

'measure' maps a shape to a pair.
The first component is the size of the whole shape.
The second component is a structure of type @nsize@
holding the sizes collected from selected parts of the shape.
-}
newtype T sh nsize =
   Cons {
      measure :: sh -> (Int, nsize)
   }

{- |
The measurement derived from the 'C' instance of @nsize@.

The measured shape type is 'ToShape' @nsize@, and 'measure' is 'evaluate'.
-}
auto :: (C nsize) => T (ToShape nsize) nsize
auto = Cons evaluate

{- |
Measure a shape as one indivisible unit.

Both the total size and the recorded sub-size are 'Shape.size' of the shape.
-}
atom :: (Shape.C sh) => T sh Int
atom = Cons (twice . Shape.size)
   where
      -- the single computed size is shared by both components
      twice n = (n, n)

{- |
The size of a sub-shape, together with the sizes recorded inside that sub-shape.
-}
data Sub nsize =
   Sub
      Int   -- ^ size of the sub-shape
      nsize -- ^ sizes collected from within the sub-shape

{- |
Keep the total size of the measured shape as an additional sub-size.

The total size reported by the inner measurement is passed through unchanged.
It is also stored in a 'Sub' next to the inner sub-sizes.
-}
sub :: T sh nsize -> T sh (Sub nsize)
sub (Cons measureInner) = Cons (tag . measureInner)
   where
      -- irrefutable pattern: as lazy as the original let-binding
      tag ~(n, sizes) = (n, Sub n sizes)

{- |
Measure a pair shape component-wise.

The total size is the product of the component sizes.
The sub-sizes of both components are returned as a pair.
-}
pair ::
   T sh0 nsize0 ->
   T sh1 nsize1 ->
   T (sh0,sh1) (nsize0,nsize1)
pair (Cons measure0) (Cons measure1) =
   Cons $ \(sh0, sh1) ->
      combine (measure0 sh0) (measure1 sh1)
   where
      -- irrefutable patterns keep the component results lazy
      combine ~(size0, sizes0) ~(size1, sizes1) =
         (size0 * size1, (sizes0, sizes1))

{- |
Measure a triple shape component-wise.

The total size is the product of the three component sizes.
The sub-sizes of all components are returned as a triple.
-}
triple ::
   T sh0 nsize0 ->
   T sh1 nsize1 ->
   T sh2 nsize2 ->
   T (sh0,sh1,sh2) (nsize0,nsize1,nsize2)
triple (Cons measure0) (Cons measure1) (Cons measure2) =
   Cons $ \(sh0, sh1, sh2) ->
      combine (measure0 sh0) (measure1 sh1) (measure2 sh2)
   where
      -- irrefutable patterns keep the component results lazy
      combine ~(size0, sizes0) ~(size1, sizes1) ~(size2, sizes2) =
         (size0 * size1 * size2, (sizes0, sizes1, sizes2))

{- |
Measure an appended shape @sh0::+sh1@ part-wise.

The total size is the sum of the two part sizes.
The sub-sizes of both parts are combined with '::+'.
-}
append ::
   T sh0 nsize0 ->
   T sh1 nsize1 ->
   T (sh0::+sh1) (nsize0::+nsize1)
append (Cons measure0) (Cons measure1) =
   Cons $ \(sh0::+sh1) ->
      combine (measure0 sh0) (measure1 sh1)
   where
      -- irrefutable patterns keep the part results lazy
      combine ~(size0, sizes0) ~(size1, sizes1) =
         (size0 + size1, sizes0::+sizes1)




{- |
Size structures that can be computed from a shape without further guidance.

An instance fixes the corresponding shape type via 'ToShape'.
It also tells how to obtain all sizes from such a shape.
-}
class C nsize where
   {- | The shape type that the sizes in @nsize@ are taken from. -}
   type ToShape nsize

   {- |
   Compute the sizes of a shape and some sub-shapes.
   The first component of the result is the size of the whole shape.
   -}
   evaluate :: ToShape nsize -> (Int, nsize)

{- |
A leaf of a size structure that stores one size.

The parameter @sh@ does not occur in the representation.
It only selects the shape type through the 'C' instance.
-}
newtype Atom sh =
   Atom Int

{- |
An 'Atom' measures its shape with 'Shape.size' and stores that size.
-}
instance (Shape.C sh) => C (Atom sh) where
   type ToShape (Atom sh) = sh
   evaluate = withAtom . Shape.size
      where
         -- the computed size is shared by the total and the stored atom
         withAtom n = (n, Atom n)

{- |
A 'Sub' measures the same shape as its content.
It additionally records the content's total size, via 'sub'.
-}
instance (C nsize) => C (Sub nsize) where
   type ToShape (Sub nsize) = ToShape nsize
   evaluate = measure (sub auto)

{- |
Pairs of size structures measure pair shapes component-wise, via 'pair'.
-}
instance (C nsize0, C nsize1) => C (nsize0, nsize1) where
   type ToShape (nsize0, nsize1) = (ToShape nsize0, ToShape nsize1)
   evaluate = measure (pair auto auto)

{- |
Triples of size structures measure triple shapes component-wise, via 'triple'.
-}
instance (C nsize0, C nsize1, C nsize2) => C (nsize0, nsize1, nsize2) where
   type ToShape (nsize0, nsize1, nsize2) =
      (ToShape nsize0, ToShape nsize1, ToShape nsize2)
   evaluate = measure (triple auto auto auto)

{- |
Appended size structures measure appended shapes part-wise, via 'append'.
-}
instance (C nsize0, C nsize1) => C (nsize0::+nsize1) where
   type ToShape (nsize0::+nsize1) = ToShape nsize0 ::+ ToShape nsize1
   evaluate = measure (append auto auto)