{-# LANGUAGE UndecidableInstances #-}
-- | Additional classes that help in comparing values in tests.
module Shared
  ( lowercase, HasShape (shapeL), Linearizable (linearize)
  ) where

import Prelude

import Data.Char qualified
import Data.Foldable qualified
import Data.Int (Int64)
import Data.Vector.Storable qualified as VS
import Foreign.C (CInt)
import GHC.Exts (IsList (..))
import GHC.TypeLits (KnownNat)

import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Shaped.Shape

import HordeAd.Core.CarriersConcrete
import HordeAd.Core.Types

-- | Convert every character of a string to lower case with
-- 'Data.Char.toLower'.
--
-- The mapping is applied character by character, so the result always
-- has the same length as the input.
lowercase :: String -> String
lowercase = map Data.Char.toLower

-- | Values that have a shape: a list of dimension sizes.
--
-- Scalars report the empty shape @[]@.
class HasShape v where
  -- | The list of dimension sizes of the argument.
  shapeL :: v -> [Int]

-- | A ranked array reports its runtime shape ('Nested.rshape'),
-- converted to a plain list of dimension sizes via 'IsList'.
instance (KnownNat n, Nested.PrimElt a) => HasShape (Nested.Ranked n a) where
  shapeL = toList . Nested.rshape

-- | A shaped array's shape is fixed by its type index @sh@. It is therefore
-- read from the type via 'knownShS', and the array value itself is ignored.
instance KnownShS sh => HasShape (Nested.Shaped sh a) where
  shapeL _ = toList (knownShS @sh)

-- | A 'Concrete' value has the shape of its underlying representation,
-- obtained with 'unConcrete'.
instance HasShape (RepConcrete y) => HasShape (Concrete y) where
  shapeL = shapeL . unConcrete

-- | A 'Double' is a scalar, so its shape is empty.
instance HasShape Double where
  shapeL _ = []

-- | A 'Float' is a scalar, so its shape is empty.
instance HasShape Float where
  shapeL _ = []

-- | An 'Int64' is a scalar, so its shape is empty.
instance HasShape Int64 where
  shapeL _ = []

-- | A 'CInt' is a scalar, so its shape is empty.
instance HasShape CInt where
  shapeL _ = []

-- | A 'Z1' value is reported as a one-dimensional shape of size 0, i.e. it
-- holds no elements. This agrees with its 'Linearizable' instance, which
-- produces the empty list.
instance HasShape Z1 where
  shapeL _ = [0]

-- | Fallback for any 'Foldable' container: it is treated as one-dimensional,
-- with its element count as the single dimension.
--
-- Elements are not inspected, so nested containers report only the outer
-- length, e.g. @shapeL [[1, 2], [3]] == [2]@. The instance is
-- @OVERLAPPABLE@ so that more specific instances take precedence.
instance {-# OVERLAPPABLE #-} (Foldable t) => HasShape (t a) where
  shapeL xs = [length xs]

-- | Values that can be linearized, i.e. flattened into a list of elements.
--
-- The functional dependency @c -> e@ says that the container type alone
-- determines the element type.
class Linearizable c e | c -> e where
  -- | All elements of the argument, as a list.
  linearize :: c -> [e]

-- | The elements of a ranked array, read out of the flat storable vector
-- returned by 'Nested.rtoVector'.
instance (VS.Storable a, Nested.PrimElt a)
         => Linearizable (Nested.Ranked n a) a where
  linearize = VS.toList . Nested.rtoVector

-- | The elements of a shaped array, read out of the flat storable vector
-- returned by 'Nested.stoVector'.
instance (VS.Storable a, Nested.PrimElt a)
         => Linearizable (Nested.Shaped sh a) a where
  linearize = VS.toList . Nested.stoVector

-- | A 'Concrete' value linearizes to the elements of its underlying
-- representation, obtained with 'unConcrete'.
instance Linearizable (RepConcrete y) a
         => Linearizable (Concrete y) a where
  linearize = linearize . unConcrete

-- | A scalar 'Double' linearizes to a singleton list.
instance Linearizable Double Double where
  linearize x = [x]

-- | A scalar 'Float' linearizes to a singleton list.
instance Linearizable Float Float where
  linearize x = [x]

-- | A scalar 'Int64' linearizes to a singleton list.
instance Linearizable Int64 Int64 where
  linearize x = [x]

-- | A scalar 'CInt' linearizes to a singleton list.
instance Linearizable CInt CInt where
  linearize x = [x]

-- | A 'Z1' value has no elements, so it linearizes to the empty list.
-- This agrees with its 'HasShape' instance, which reports shape @[0]@.
instance Linearizable Z1 Z1 where
  linearize _ = []

-- | Fallback for any 'Foldable' container: its elements in
-- 'Data.Foldable.toList' order. The instance is @OVERLAPPABLE@ so that more
-- specific instances take precedence.
instance {-# OVERLAPPABLE #-} (Foldable t) => Linearizable (t a) a where
  -- Qualified on purpose: the unqualified 'toList' in this module is the
  -- 'IsList' method imported from "GHC.Exts".
  linearize = Data.Foldable.toList