{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
{-# OPTIONS_GHC -Wno-orphans #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
module Harpie.NumHask
(
ident,
undiag,
mult,
invtri,
inverse,
chol,
)
where
import Data.Functor.Rep
import Fcf qualified
import GHC.TypeNats
import Harpie.Fixed as F hiding (chol, ident, inverse, invtri, mult, undiag)
import Harpie.Shape (DeleteDims, Eval, GetDims, KnownNats, Rank, type (++))
import Harpie.Shape qualified as S
import NumHask.Prelude as P hiding (Min, cycle, diff, drop, empty, find, length, repeat, sequence, take, zipWith)
-- | Element-wise addition of same-shaped arrays; 'zero' is the array whose
-- every element is 'zero'.
instance (Additive a, KnownNats s) => Additive (Array s a) where
  x + y = liftR2 (+) x y
  zero = pureRep zero
-- | Element-wise negation. Only 'negate' is defined here; subtraction is
-- left to the class default.
instance (Subtractive a, KnownNats s) => Subtractive (Array s a) where
  negate arr = fmapRep negate arr
-- | Scale an array by a scalar of its element type. Each element @x@ becomes
-- @k * x@ (the scalar is multiplied on the left).
instance (Multiplicative a) => MultiplicativeAction (Array s a) where
  type Scalar (Array s a) = a
  arr |* k = fmap (k *) arr
-- | Add a scalar of the element type to every element; each element @x@
-- becomes @k + x@.
instance (Additive a) => AdditiveAction (Array s a) where
  type AdditiveScalar (Array s a) = a
  arr |+ k = fmap (k +) arr
-- | Subtract a scalar from every element; each element @x@ becomes @x - k@.
instance (Subtractive a) => SubtractiveAction (Array s a) where
  arr |- k = fmap (\e -> e - k) arr
-- | Divide every element by a scalar; each element @x@ becomes @x / k@.
instance (Divisive a) => DivisiveAction (Array s a) where
  arr |/ k = fmap (/ k) arr
-- | Element-wise join of two same-shaped arrays.
instance
  ( KnownNats s,
    JoinSemiLattice a
  ) =>
  JoinSemiLattice (Array s a)
  where
  x \/ y = liftR2 (\/) x y
-- | Element-wise meet of two same-shaped arrays.
instance
  ( KnownNats s,
    MeetSemiLattice a
  ) =>
  MeetSemiLattice (Array s a)
  where
  x /\ y = liftR2 (/\) x y
-- | The array's epsilon is the element type's 'epsilon' at every position.
instance
  ( KnownNats s,
    Subtractive a,
    Epsilon a
  ) =>
  Epsilon (Array s a)
  where
  epsilon = konst epsilon
-- | Integer literals at the scalar (rank-0) shape: convert to the element
-- type, then wrap with 'toScalar'.
instance (FromInteger a) => FromInteger (Array ('[] :: [Nat]) a) where
  fromInteger = toScalar . fromInteger
-- | Rational literals at the scalar (rank-0) shape: convert to the element
-- type, then wrap with 'toScalar'.
instance (FromRational a) => FromRational (Array ('[] :: [Nat]) a) where
  fromRational = toScalar . fromRational
-- | Identity-style array of any shape.
--
-- A position holds 'one' when 'S.isDiag' accepts its index list
-- (presumably: all index components equal), and 'zero' otherwise.
ident :: (KnownNats s, Additive a, Multiplicative a) => Array s a
ident = tabulate (\i -> bool zero one (S.isDiag (S.fromFins i)))
-- | Spread an array onto the diagonal of an array of shape @s ++ s@.
--
-- A position whose index list satisfies 'S.isDiag' takes the element of @a@
-- at the position's first index component; every other position is 'zero'.
--
-- NOTE(review): the index handed to @a@ always has exactly one component, so
-- this only lines up for rank-1 @s@ (e.g. the result of 'diag').
undiag ::
  forall s' a s.
  ( KnownNats s,
    KnownNats s',
    s' ~ Eval ((++) s s),
    Additive a
  ) =>
  Array s a ->
  Array s' a
undiag a = tabulate entry
  where
    entry xs =
      let is = S.fromFins xs
       in bool zero (index a (S.UnsafeFins [S.getDim 0 is])) (S.isDiag is)
-- | Generalised matrix multiplication.
--
-- Contracts the last dimension of the first array (@ds0@) against the first
-- dimension of the second (@ds1@). Matching elements are multiplied with
-- '(*)' and each contracted slice is reduced with 'sum'. The result shape is
-- the remaining dimensions of the first array followed by those of the second.
mult ::
  forall a ds0 ds1 s0 s1 so0 so1 st si.
  ( Ring a,
    KnownNats s0,
    KnownNats s1,
    KnownNats ds0,
    KnownNats ds1,
    KnownNats so0,
    KnownNats so1,
    KnownNats st,
    KnownNats si,
    so0 ~ Eval (DeleteDims ds0 s0),
    so1 ~ Eval (DeleteDims ds1 s1),
    si ~ Eval (GetDims ds0 s0),
    si ~ Eval (GetDims ds1 s1),
    st ~ Eval ((++) so0 so1),
    ds0 ~ '[Eval ((Fcf.-) (Eval (Rank s0)) 1)],
    ds1 ~ '[0]
  ) =>
  Array s0 a ->
  Array s1 a ->
  Array st a
mult x y = dot sum (*) x y
-- | Square matrices form a multiplicative monoid: multiplication is 'mult'
-- (matrix product) and the unit is 'ident'.
instance
  ( Multiplicative a,
    P.Distributive a,
    Subtractive a,
    KnownNat m
  ) =>
  Multiplicative (Matrix m m a)
  where
  x * y = mult x y
  one = ident
-- | Square-matrix reciprocal, i.e. the matrix inverse.
--
-- Delegates to 'inverse', which works via the Cholesky factor. Like
-- 'inverse', it is only meaningful for matrices that admit a Cholesky
-- factorisation (symmetric positive-definite). This instance previously
-- duplicated 'inverse''s body verbatim and recomputed @chol a@ twice.
instance
  ( Multiplicative a,
    P.Distributive a,
    Subtractive a,
    Eq a,
    ExpField a,
    KnownNat m
  ) =>
  Divisive (Matrix m m a)
  where
  recip = inverse
-- | Inverse of a square matrix, computed from its Cholesky factor.
--
-- With @c = chol a@ (so @a = c * transpose c@), the inverse is
-- @invtri (transpose c) `mult` invtri c@.
--
-- Only meaningful when @a@ admits a Cholesky factorisation (symmetric
-- positive-definite). Otherwise 'chol' takes 'sqrt' of a non-positive pivot.
--
-- The factor is bound once and shared: computing 'chol' is the dominant cost,
-- and the previous form evaluated it twice.
inverse :: (Eq a, ExpField a, KnownNat m) => Matrix m m a -> Matrix m m a
inverse a = mult (invtri (transpose c)) (invtri c)
  where
    c = chol a
-- | Inverse of a triangular matrix with a nonzero diagonal.
--
-- Writes the input as D + N, where D is the diagonal and N the off-diagonal
-- part. It then returns
--
--   (sum of R^k for k from 'iota' n) * D^-1,   where R = negate (D^-1 * N).
--
-- 'iota' is presumably 0 .. n-1. When N is strictly triangular, R is
-- nilpotent (R^n = 0), so this truncated series is the exact inverse in exact
-- arithmetic.
invtri :: forall a n. (KnownNat n, ExpField a, Eq a) => Matrix n n a -> Matrix n n a
invtri a = series * dInv
  where
    -- reciprocal of each diagonal entry, laid back out as a diagonal matrix
    dInv = undiag (fmap recip (diag a))
    -- the input with its diagonal removed
    offDiag = a - undiag (diag a)
    ratio = negate (dInv * offDiag)
    series = sum (fmap (ratio ^) (iota @n))
-- | Cholesky factorisation of a square matrix.
--
-- Returns the lower-triangular l with a = l * transpose l. This assumes a is
-- symmetric positive-definite; that is not checked, and a non-positive pivot
-- yields whatever 'sqrt' gives for such an argument (NaN for 'Double').
--
-- Entries follow the column-wise recurrence:
--
-- * diagonal: l[j,j] = sqrt (a[j,j] - sum over k < j of l[j,k]^2)
-- * below the diagonal (i > j): l[i,j] = (one divided by l[j,j]) * (a[i,j] - sum over k < j of l[i,k] * l[j,k])
-- * above the diagonal (i < j): exactly 'zero'
--
-- Only the diagonal and lower triangle of a are read. The upper triangle is
-- set to 'zero' explicitly. Computing it with the below-diagonal formula only
-- cancels to zero in exact arithmetic for symmetric input, which leaves
-- rounding noise that breaks the strict triangularity 'invtri' relies on. It
-- also wastes the sums for half of the entries.
--
-- l is defined in terms of itself. Each entry refers only to the diagonal of
-- its own column or to entries in earlier columns (k < j), so the recursion
-- is well-founded provided the array built by 'unsafeTabulate' stores its
-- elements lazily. NOTE(review): the previous definition relied on this too.
chol :: (KnownNat m, ExpField a) => Matrix m m a -> Matrix m m a
chol a =
  let l =
        unsafeTabulate
          ( \[i, j] ->
              let -- column indices strictly before column j
                  ks = [0 .. (j - 1)] :: [Int]
               in bool
                    ( bool
                        zero
                        ( (one / unsafeIndex l [j, j])
                            * ( unsafeIndex a [i, j]
                                  - sum ((\k -> unsafeIndex l [i, k] * unsafeIndex l [j, k]) <$> ks)
                              )
                        )
                        (i > j)
                    )
                    ( sqrt
                        ( unsafeIndex a [j, j]
                            - sum ((\k -> unsafeIndex l [j, k] ^ (2 :: Int)) <$> ks)
                        )
                    )
                    (i == j)
          )
   in l