{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}

-- | A simple HxW matrix backed by a vector, mainly for binary exponentiation.
--
-- The matrix is a left semigroup action: \(m_2 (m_1 v) = (m_2 \circ m_1) v\).
--
-- @since 1.1.0.0
module AtCoder.Extra.Semigroup.Matrix
  ( -- * Matrix
    Matrix (..),

    -- * Constructors
    new,
    square,
    zero,
    ident,
    diag,

    -- * Mapping
    map,

    -- * Multiplications
    mulToCol,
    mul,
    mulMod,
    mulMint,

    -- * Powers
    pow,
    powMod,
    powMint,

    -- * Rank
    rank,

    -- * Inverse
    inv,
    invRaw,

    -- * Determinant
    detMod,
    detMint,
  )
where

import AtCoder.Extra.Math qualified as ACEM
import AtCoder.Internal.Assert qualified as ACIA
import AtCoder.Internal.Barrett qualified as BT
import AtCoder.ModInt qualified as M
import Control.Monad (when)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Monad.ST (runST)
import Data.Foldable (for_)
import Data.Semigroup (Semigroup (..))
import Data.Vector qualified as V
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Mutable qualified as VM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import Data.Word (Word64)
import GHC.Exts (proxy#)
import GHC.Stack (HasCallStack)
import GHC.TypeNats (KnownNat, natVal')
import Prelude hiding (map)

-- | A dense HxW matrix whose entries are stored in one flat unboxed vector, mainly for binary
-- exponentiation.
--
-- The matrix is a left semigroup action: \(m_2 (m_1 v) = (m_2 \circ m_1) v\).
--
-- @since 1.1.0.0
data Matrix a = Matrix
  { -- | Height of the matrix (number of rows).
    --
    -- @since 1.1.0.0
    hM :: {-# UNPACK #-} !Int,
    -- | Width of the matrix (number of columns).
    --
    -- @since 1.1.0.0
    wM :: {-# UNPACK #-} !Int,
    -- | Entries in row-major order: entry \((i, j)\) lives at index @i * wM + j@. The raw
    -- constructor does not check that the length is @hM * wM@; `new` does.
    --
    -- @since 1.1.0.0
    vecM :: !(VU.Vector a)
  }
  deriving
    ( -- | @since 1.1.0.0
      Eq,
      -- | @since 1.1.0.0
      Show
    )

-- | Type alias of a column vector: a plain unboxed vector whose \(i\)-th element is read as the
-- entry in row \(i\) of a single column.
--
-- @since 1.1.0.0
type Col a = VU.Vector a

-- | \(O(hw)\) Creates an HxW matrix from a vector of entries given in row-major order.
--
-- Calls `error` if either dimension is negative, or if the vector does not contain exactly
-- \(h \times w\) elements.
--
-- @since 1.1.0.0
{-# INLINE new #-}
new :: (HasCallStack, VU.Unbox a) => Int -> Int -> VU.Vector a -> Matrix a
new h w vec
  -- The product of two negative dimensions is non-negative, so the length check below cannot
  -- reject them on its own.
  | h < 0 || w < 0 = error "AtCoder.Extra.Matrix.new: negative size"
  | VU.length vec /= h * w = error "AtCoder.Extra.Matrix: size mismatch"
  | otherwise = Matrix h w vec

-- | \(O(n^2)\) Creates an NxN square matrix. This is `new` with the same value used for both the
-- height and the width, so the same size check applies.
--
-- @since 1.1.1.0
{-# INLINE square #-}
square :: (HasCallStack, VU.Unbox a) => Int -> VU.Vector a -> Matrix a
square n vec = new n n vec

-- | \(O(n^2)\) Creates an NxN matrix in which every entry is @0@.
--
-- @since 1.1.0.0
{-# INLINE zero #-}
zero :: (VU.Unbox a, Num a) => Int -> Matrix a
zero n = Matrix n n zeros
  where
    zeros = VU.replicate (n * n) 0

-- | \(O(n^2)\) Creates an NxN identity matrix: @1@ on the main diagonal and @0@ elsewhere.
--
-- @since 1.1.0.0
{-# INLINE ident #-}
ident :: (VU.Unbox a, Num a) => Int -> Matrix a
ident n =
  Matrix n n $
    VU.modify
      -- Diagonal entry (i, i) sits at flat index i + n * i = i * (n + 1).
      (\mv -> for_ [0 .. n - 1] $ \i -> VGM.write mv (i * (n + 1)) 1)
      (VU.replicate (n * n) 0)

-- | \(O(n^2)\) Creates an NxN diagonal matrix, where N is the length of the input vector. The
-- \(i\)-th input element becomes entry \((i, i)\); every other entry is @0@.
--
-- @since 1.1.0.0
{-# INLINE diag #-}
diag :: (VU.Unbox a, Num a) => VU.Vector a -> Matrix a
diag xs =
  Matrix size size $
    VU.modify
      -- Diagonal entry (i, i) sits at flat index i + size * i = i * (size + 1).
      (\mv -> VU.imapM_ (\i x -> VGM.write mv (i * (size + 1)) x) xs)
      (VU.replicate (size * size) 0)
  where
    size = VU.length xs

-- | \(O(n^2)\) Applies a function to every entry of the `Matrix`, keeping its height and width.
--
-- @since 1.1.0.0
{-# INLINE map #-}
map :: (VU.Unbox a, VU.Unbox b) => (a -> b) -> Matrix a -> Matrix b
map f (Matrix h w entries) = Matrix h w (VU.map f entries)

-- | \(O(hw)\) Multiplies an HxW matrix by a Wx1 column vector, producing an Hx1 column vector.
--
-- The length of the column is checked against the matrix width with a runtime assertion.
--
-- @since 1.1.0.0
{-# INLINE mulToCol #-}
mulToCol :: (Num a, VU.Unbox a) => Matrix a -> Col a -> Col a
mulToCol (Matrix h w entries) !col = VU.generate h dotWithRow
  where
    !_ = ACIA.runtimeAssert (VU.length col == w) "AtCoder.Extra.Matrix.mulToCol: size mismatch"
    -- Row i occupies the w entries starting at flat index i * w.
    dotWithRow i = VU.sum . VU.zipWith (*) col . VU.take w $ VU.drop (i * w) entries

-- | \(O(h_1 K w_2)\) Multiplies an H1xK matrix by a KxW2 matrix, giving an H1xW2 matrix.
--
-- The inner dimensions are checked with a runtime assertion. Entries are combined with the plain
-- `Num` operations of the element type; no modular reduction is applied here.
--
-- @since 1.1.0.0
{-# INLINE mul #-}
mul :: forall e. (Num e, VU.Unbox e) => Matrix e -> Matrix e -> Matrix e
mul (Matrix h w vecA) (Matrix h' w' vecB) = Matrix h w' $ VU.create $ do
  out <- VUM.replicate (h * w') (0 :: e)
  -- i-k-j loop order: the innermost loop walks row k of B and row i of the output contiguously.
  for_ [0 .. h - 1] $ \i ->
    for_ [0 .. w - 1] $ \k ->
      for_ [0 .. w' - 1] $ \j -> do
        let !lhs = VG.unsafeIndex vecA (i * w + k)
            !rhs = VG.unsafeIndex vecB (k * w' + j)
        VGM.unsafeModify out (+ (lhs * rhs)) (i * w' + j)
  pure out
  where
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

-- | \(O(h_1 w_2 K)\) Multiplies an H1xK matrix by a KxW2 matrix, taking the mod. The first
-- argument is the modulus.
--
-- The inner dimensions are checked with a runtime assertion. Entries are converted to `Word64`
-- with `fromIntegral` before being multiplied, so they are presumably expected to already be
-- reduced into \([0, m)\) — confirm against the Barrett helper's contract.
--
-- @since 1.1.0.0
{-# INLINE mulMod #-}
mulMod :: Int -> Matrix Int -> Matrix Int -> Matrix Int
mulMod !m (Matrix h w vecA) (Matrix h' w' vecB) = Matrix h w' (VU.generate (h * w') cell)
  where
    !bt = BT.new32 (fromIntegral m)
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mulMod: matrix size mismatch"
    modulus :: Word64
    modulus = fromIntegral m
    -- NOTE: this is unsafe if the matrix is too large (the products of one row/column pair are
    -- summed in a `Word64` and reduced only once at the end).
    cell :: Int -> Int
    cell flat = fromIntegral (dot `rem` modulus)
      where
        -- Never evaluated when w' == 0, because the output then has no cells.
        (row, col) = flat `quotRem` w'
        rowA = VU.unsafeSlice (w * row) w vecA
        dot :: Word64
        dot = VU.ifoldl' step 0 rowA
        step !acc k a =
          acc + BT.mulMod bt (fromIntegral a) (fromIntegral (VG.unsafeIndex vecB (col + k * w')))

-- | \(O(h_1 w_2 K)\) `mul` specialized to `M.ModInt`. The Barrett context is built from the
-- type-level modulus and handed to the shared implementation.
--
-- @since 1.1.0.0
{-# INLINE mulMint #-}
mulMint :: forall a. (KnownNat a) => Matrix (M.ModInt a) -> Matrix (M.ModInt a) -> Matrix (M.ModInt a)
mulMint = mulMintImpl $! BT.new32 (fromIntegral (natVal' (proxy# @a)))

-- | (Internal) Matrix product over `M.ModInt` using a caller-supplied Barrett context, so that
-- repeated multiplications (as in `powMint`) can share one context.
{-# INLINE mulMintImpl #-}
mulMintImpl :: forall a. (KnownNat a) => BT.Barrett -> Matrix (M.ModInt a) -> Matrix (M.ModInt a) -> Matrix (M.ModInt a)
mulMintImpl !bt (Matrix h w vecA) (Matrix h' w' vecB) = Matrix h w' (VU.generate (h * w') cell)
  where
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mulMintImpl: matrix size mismatch"
    raw :: M.ModInt a -> Word64
    raw = fromIntegral . M.unModInt
    -- NOTE: this is unsafe if the matrix is too large (the products of one row/column pair are
    -- summed in a `Word64` and reduced only once, by `M.new64`).
    cell :: Int -> M.ModInt a
    cell flat = M.new64 (VU.ifoldl' step 0 (VU.unsafeSlice (w * row) w vecA))
      where
        -- Never evaluated when w' == 0, because the output then has no cells.
        (row, col) = flat `quotRem` w'
        step :: Word64 -> Int -> M.ModInt a -> Word64
        step !acc k x = acc + BT.mulMod bt (raw x) (raw (VG.unsafeIndex vecB (col + k * w')))

-- | Returns \(M^k\) for a square NxN matrix, using `mul` for each product. Each product costs
-- \(O(n^3)\); the number of products is decided by `ACEM.power` (presumably \(O(\log k)\) via
-- binary exponentiation — confirm against "AtCoder.Extra.Math").
--
-- A runtime assertion checks that the matrix is square. A negative exponent is an `error`, and
-- \(k = 0\) yields the identity matrix of the same size.
--
-- @since 1.1.0.0
{-# INLINE pow #-}
pow :: Int -> Matrix Int -> Matrix Int
pow k mat = case compare k 0 of
  LT -> error "AtCoder.Extra.Matrix.pow: the exponential must be non-negative"
  EQ -> ident (hM mat)
  GT -> ACEM.power mul k mat
  where
    !_ = ACIA.runtimeAssert (hM mat == wM mat) "AtCoder.Extra.Matrix.pow: matrix size mismatch"

-- | Returns \(M^k\) for a square NxN matrix, taking the mod. The first argument is the modulus
-- and the second is the exponent. Each product is a `mulMod` costing \(O(n^3)\); the number of
-- products is decided by `ACEM.power` (presumably \(O(\log k)\) via binary exponentiation —
-- confirm against "AtCoder.Extra.Math").
--
-- A runtime assertion checks that the matrix is square. A negative exponent is an `error`, and
-- \(k = 0\) yields the identity matrix of the same size.
--
-- @since 1.1.0.0
{-# INLINE powMod #-}
powMod :: Int -> Int -> Matrix Int -> Matrix Int
powMod m k mat = case compare k 0 of
  LT -> error "AtCoder.Extra.Matrix.powMod: the exponential must be non-negative"
  EQ -> ident (hM mat)
  GT -> ACEM.power (mulMod m) k mat
  where
    !_ = ACIA.runtimeAssert (hM mat == wM mat) "AtCoder.Extra.Matrix.powMod: matrix size mismatch"

-- | Returns \(M^k\) for a square NxN matrix, specialized to `M.ModInt`. One Barrett context is
-- built from the type-level modulus and reused for every product. Each product costs \(O(n^3)\);
-- the number of products is decided by `ACEM.power` (presumably \(O(\log k)\) via binary
-- exponentiation — confirm against "AtCoder.Extra.Math").
--
-- A runtime assertion checks that the matrix is square. A negative exponent is an `error`, and
-- \(k = 0\) yields the identity matrix of the same size.
--
-- @since 1.1.0.0
{-# INLINE powMint #-}
powMint :: forall m. (KnownNat m) => Int -> Matrix (M.ModInt m) -> Matrix (M.ModInt m)
powMint k mat = case compare k 0 of
  LT -> error "AtCoder.Extra.Matrix.powMint: the exponential must be non-negative"
  EQ -> ident (hM mat)
  GT -> ACEM.power (mulMintImpl bt) k mat
  where
    !_ = ACIA.runtimeAssert (hM mat == wM mat) "AtCoder.Extra.Matrix.powMint: matrix size mismatch"
    !bt = BT.new32 (fromIntegral (natVal' (proxy# @m)))

-- | (Internal) Reads entry \((i, j)\) from a boxed mutable vector of unboxed mutable rows: the
-- row handle is fetched first, then the element inside it. Both reads use `VGM.unsafeRead`, so
-- the caller is responsible for keeping the indices in range.
{-# INLINE read2d #-}
read2d ::
  (PrimMonad m, VU.Unbox a) =>
  VM.MVector (PrimState m) (VUM.MVector (PrimState m) a) ->
  Int ->
  Int ->
  m a
read2d rows i j = VGM.unsafeRead rows i >>= \rowI -> VGM.unsafeRead rowI j

-- | \(O(hw \min(h, w))\) Returns the rank of the matrix.
--
-- @since 1.1.1.0
{-# INLINE rank #-}
rank :: (Fractional a, Eq a, VU.Unbox a) => Matrix a -> Int
rank :: forall a. (Fractional a, Eq a, Unbox a) => Matrix a -> Int
rank (Matrix Int
h Int
w Vector a
vec) = (forall s. ST s Int) -> Int
forall a. (forall s. ST s a) -> a
runST ((forall s. ST s Int) -> Int) -> (forall s. ST s Int) -> Int
forall a b. (a -> b) -> a -> b
$ do
  MVector s a
vm <- Vector a -> ST s (MVector (PrimState (ST s)) a)
forall a (m :: * -> *).
(Unbox a, PrimMonad m) =>
Vector a -> m (MVector (PrimState m) a)
VU.thaw Vector a
vec
  MVector (PrimState (ST s)) (MVector (PrimState (ST s)) a)
view <- Vector (MVector s a)
-> ST s (MVector (PrimState (ST s)) (MVector s a))
forall (m :: * -> *) a.
PrimMonad m =>
Vector a -> m (MVector (PrimState m) a)
V.thaw (Vector (MVector s a)
 -> ST s (MVector (PrimState (ST s)) (MVector s a)))
-> Vector (MVector s a)
-> ST s (MVector (PrimState (ST s)) (MVector s a))
forall a b. (a -> b) -> a -> b
$ Int
-> (MVector s a -> (MVector s a, MVector s a))
-> MVector s a
-> Vector (MVector s a)
forall b a. Int -> (b -> (a, b)) -> b -> Vector a
V.unfoldrExactN Int
h (Int -> MVector s a -> (MVector s a, MVector s a)
forall a s.
Unbox a =>
Int -> MVector s a -> (MVector s a, MVector s a)
VUM.splitAt Int
w) MVector s a
vm
  let inner :: Int -> Int -> ST s Int
inner Int
rk Int
j
        | Int
rk Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
h Bool -> Bool -> Bool
|| Int
j Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
w = Int -> ST s Int
forall a. a -> ST s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Int
rk
        | Bool
otherwise = do
            a
xrj <- MVector (PrimState (ST s)) (MVector (PrimState (ST s)) a)
-> Int -> Int -> ST s a
forall (m :: * -> *) a.
(PrimMonad m, Unbox a) =>
MVector (PrimState m) (MVector (PrimState m) a)
-> Int -> Int -> m a
read2d MVector (PrimState (ST s)) (MVector (PrimState (ST s)) a)
view Int
rk Int
j
            Bool -> ST s () -> ST s ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (a
xrj a -> a -> Bool
forall a. Eq a => a -> a -> Bool
== a
0) (ST s () -> ST s ()) -> ST s () -> ST s ()
forall a b. (a -> b) -> a -> b
$ do
              let runSwap :: Int -> ST s ()
runSwap Int
i
                    | Int
i Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
h = () -> ST s ()
forall a. a -> ST s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
                    | Bool
otherwise = do
                        a
xij <- MVector (PrimState (ST s)) (MVector (PrimState (ST s)) a)
-> Int -> Int -> ST s a
forall (m :: * -> *) a.
(PrimMonad m, Unbox a) =>
MVector (PrimState m) (MVector (PrimState m) a)
-> Int -> Int -> m a
read2d MVector (PrimState (ST s)) (MVector (PrimState (ST s)) a)
view Int
i Int
j
                        if a
xij a -> a -> Bool
forall a. Eq a => a -> a -> Bool
/= a
0
                          then MVector (PrimState (ST s)) (MVector (PrimState (ST s)) a)
-> Int -> Int -> ST s ()
forall (m :: * -> *) (v :: * -> * -> *) a.
(PrimMonad m, MVector v a) =>
v (PrimState m) a -> Int -> Int -> m ()
VGM.unsafeSwap MVector (PrimState (ST s)) (MVector (PrimState (ST s)) a)
view Int
rk Int
i
                          else Int -> ST s ()
runSwap (Int
i Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
              Int -> ST s ()
runSwap (Int
rk Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
            a
xrj' <- MVector (PrimState (ST s)) (MVector (PrimState (ST s)) a)
-> Int -> Int -> ST s a
forall (m :: * -> *) a.
(PrimMonad m, Unbox a) =>
MVector (PrimState m) (MVector (PrimState m) a)
-> Int -> Int -> m a
read2d MVector (PrimState (ST s)) (MVector (PrimState (ST s)) a)
view Int
rk Int
j
            if a
xrj' a -> a -> Bool
forall a. Eq a => a -> a -> Bool
== a
0
              then Int -> Int -> ST s Int
inner Int
rk (Int
j Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
              else do
                let c :: a
c = a
1 a -> a -> a
forall a. Fractional a => a -> a -> a
/ a
xrj'
                MVector (PrimState (ST s)) a
rowRk <- MVector (PrimState (ST s)) (MVector (PrimState (ST s)) a)
-> Int -> ST s (MVector (PrimState (ST s)) a)
forall (m :: * -> *) (v :: * -> * -> *) a.
(HasCallStack, PrimMonad m, MVector v a) =>
v (PrimState m) a -> Int -> m a
VGM.read MVector (PrimState (ST s)) (MVector (PrimState (ST s)) a)
view Int
rk
                -- for_ [j .. w - 1] $ \k -> do
                MVector (PrimState (ST s)) a -> (Int -> a -> ST s ()) -> ST s ()
forall (m :: * -> *) (v :: * -> * -> *) a b.
(PrimMonad m, MVector v a) =>
v (PrimState m) a -> (Int -> a -> m b) -> m ()
VGM.iforM_ (Int -> MVector (PrimState (ST s)) a -> MVector (PrimState (ST s)) a
forall (v :: * -> * -> *) a s. MVector v a => Int -> v s a -> v s a
VGM.unsafeDrop Int
j MVector (PrimState (ST s)) a
rowRk) ((Int -> a -> ST s ()) -> ST s ())
-> (Int -> a -> ST s ()) -> ST s ()
forall a b. (a -> b) -> a -> b
$ \Int
k_ a
x -> do
                  MVector (PrimState (ST s)) a -> Int -> a -> ST s ()
forall (m :: * -> *) (v :: * -> * -> *) a.
(PrimMonad m, MVector v a) =>
v (PrimState m) a -> Int -> a -> m ()
VGM.unsafeWrite MVector (PrimState (ST s)) a
rowRk (Int
k_ Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
j) (a -> ST s ()) -> a -> ST s ()
forall a b. (a -> b) -> a -> b
$! a
c a -> a -> a
forall a. Num a => a -> a -> a
* a
x
                [Int] -> (Int -> ST s ()) -> ST s ()
forall (t :: * -> *) (f :: * -> *) a b.
(Foldable t, Applicative f) =>
t a -> (a -> f b) -> f ()
for_ [Int
rk Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1 .. Int
h Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1] ((Int -> ST s ()) -> ST s ()) -> (Int -> ST s ()) -> ST s ()
forall a b. (a -> b) -> a -> b
$ \Int
i -> do
                  a
c_ <- MVector (PrimState (ST s)) (MVector (PrimState (ST s)) a)
-> Int -> Int -> ST s a
forall (m :: * -> *) a.
(PrimMonad m, Unbox a) =>
MVector (PrimState m) (MVector (PrimState m) a)
-> Int -> Int -> m a
read2d MVector (PrimState (ST s)) (MVector (PrimState (ST s)) a)
view Int
i Int
j
                  MVector (PrimState (ST s)) a
rowI <- MVector (PrimState (ST s)) (MVector (PrimState (ST s)) a)
-> Int -> ST s (MVector (PrimState (ST s)) a)
forall (m :: * -> *) (v :: * -> * -> *) a.
(HasCallStack, PrimMonad m, MVector v a) =>
v (PrimState m) a -> Int -> m a
VGM.read MVector (PrimState (ST s)) (MVector (PrimState (ST s)) a)
view Int
i
                  -- for_ [j .. w - 1] $ \k -> do
                  MVector (PrimState (ST s)) a -> (Int -> a -> ST s ()) -> ST s ()
forall (m :: * -> *) (v :: * -> * -> *) a b.
(PrimMonad m, MVector v a) =>
v (PrimState m) a -> (Int -> a -> m b) -> m ()
VGM.iforM_ (Int -> MVector (PrimState (ST s)) a -> MVector (PrimState (ST s)) a
forall (v :: * -> * -> *) a s. MVector v a => Int -> v s a -> v s a
VGM.unsafeDrop Int
j MVector (PrimState (ST s)) a
rowRk) ((Int -> a -> ST s ()) -> ST s ())
-> (Int -> a -> ST s ()) -> ST s ()
forall a b. (a -> b) -> a -> b
$ \Int
k_ a
ark -> do
                    MVector (PrimState (ST s)) a -> (a -> a) -> Int -> ST s ()
forall (m :: * -> *) (v :: * -> * -> *) a.
(PrimMonad m, MVector v a) =>
v (PrimState m) a -> (a -> a) -> Int -> m ()
VGM.unsafeModify MVector (PrimState (ST s)) a
rowI (a -> a -> a
forall a. Num a => a -> a -> a
subtract (a
ark a -> a -> a
forall a. Num a => a -> a -> a
* a
c_)) (Int
k_ Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
j)
                Int -> Int -> ST s Int
inner (Int
rk Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) (Int
j Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
  Int -> Int -> ST s Int
inner Int
0 Int
0

-- TODO: add HasCallStack and compare their speeds

-- | \(O(n^3)\) Returns @(det, invMatrix)@ or `Nothing` if the matrix does not have inverse (the
-- determinant is zero).
--
-- This is a thin wrapper around `invRaw`: the rows returned by `invRaw` are concatenated, in
-- order, into a single unboxed vector and wrapped as an \(n \times n\) `Matrix`.
--
-- ==== Constraints
-- - The input must be a square matrix.
--
-- @since 1.1.1.0
{-# INLINE inv #-}
inv :: forall a. (Fractional a, Eq a, VU.Unbox a) => Matrix a -> Maybe (a, Matrix a)
inv mat@(Matrix n _ _) = do
  -- `invRaw` performs the squareness assertion and the actual elimination.
  (!det, !rows) <- invRaw mat
  -- Flatten the per-row vectors back into the flat storage used by `Matrix`.
  let !flat = VU.concat $ V.toList rows
  pure (det, Matrix n n flat)

-- | \(O(n^3)\) Returns @(det, invRows)@ or `Nothing` if the matrix does not have inverse (the
-- determinant is zero). Unlike `inv`, the inverse matrix is returned as a boxed vector of unboxed
-- row vectors.
--
-- The inverse is computed with Gauss–Jordan elimination: the input is reduced to the identity
-- matrix while the same row operations are applied to an identity matrix, which ends up holding
-- the inverse. The determinant is accumulated from the pivots and the row swaps.
--
-- ==== Constraints
-- - The input must be a square matrix.
--
-- @since 1.1.1.0
{-# INLINE invRaw #-}
invRaw :: forall a. (Fractional a, Eq a, VU.Unbox a) => Matrix a -> Maybe (a, V.Vector (VU.Vector a))
invRaw (Matrix h w vec) = runST $ do
  -- Mutable copy of the input, viewed as a boxed vector of @n@ row slices. Swapping two rows is
  -- then just a swap of two slices in the view.
  vecA <- VU.thaw vec
  viewA <- V.thaw $ V.unfoldrExactN n (VUM.splitAt n) vecA

  -- Identity matrix that receives every row operation applied to A.
  vecB <- VUM.replicate (n * n) (0 :: a)
  for_ [0 .. n - 1] $ \i -> do
    VGM.unsafeWrite vecB (n * i + i) (1 :: a)
  viewB <- V.thaw $ V.unfoldrExactN n (VUM.splitAt n) vecB

  -- @inner i det@ processes the @i@-th column; @det@ is the determinant accumulated so far.
  let inner i !det
        | i >= n = do
            -- All columns are processed: B now holds the inverse.
            viewB' <- V.mapM VU.unsafeFreeze =<< V.unsafeFreeze viewB
            pure $ Just (det, viewB')
        | otherwise = do
            -- Look for the first row @k >= i@ with a non-zero entry in column @i@ and move it
            -- to row @i@. A real swap flips the sign of the determinant.
            let swapLoop k !det_
                  | k >= n = pure det_
                  | otherwise = do
                      aki <- read2d viewA k i
                      if aki /= 0
                        then do
                          if k /= i
                            then do
                              VGM.unsafeSwap viewA i k
                              VGM.unsafeSwap viewB i k
                              pure (-det_)
                            else pure det_
                        else do
                          swapLoop (k + 1) det_
            det' <- swapLoop i det
            aii <- read2d viewA i i
            if aii == 0
              then -- No pivot in this column: the matrix is singular.
                pure Nothing
              else do
                let !c = (1 :: a) / aii
                let !det'' = det' * aii
                rowAI <- VGM.unsafeRead viewA i
                rowBI <- VGM.unsafeRead viewB i
                -- Scale the pivot row so that the pivot becomes one. Only the columns from @i@
                -- onwards are visited in A; row B is scaled entirely.
                VUM.iforM_ (VUM.unsafeDrop i rowAI) $ \j_ x -> do
                  VGM.unsafeWrite rowAI (j_ + i) $! x * c
                VUM.iforM_ rowBI $ \j x -> do
                  VGM.unsafeWrite rowBI j $! x * c
                -- Eliminate column @i@ from every other row, both above and below the pivot.
                for_ [0 .. n - 1] $ \k -> do
                  when (i /= k) $ do
                    -- The multiplier must be read before row @k@ of A is modified.
                    c_ <- read2d viewA k i
                    rowAK <- VGM.unsafeRead viewA k
                    rowBK <- VGM.unsafeRead viewB k
                    VGM.iforM_ (VGM.unsafeDrop i rowAI) $ \j_ aij -> do
                      VGM.unsafeModify rowAK (subtract (aij * c_)) (j_ + i)
                    VGM.iforM_ rowBI $ \j bij -> do
                      VGM.unsafeModify rowBK (subtract (bij * c_)) j
                inner (i + 1) det''

  inner 0 (1 :: a)
  where
    -- Evaluated strictly so that a non-square input is rejected before any work is done.
    !_ = ACIA.runtimeAssert (h == w) $ "AtCoder.Extra.Semigroup.Matrix.inv: given non-square matrix of size " ++ show (h, w)
    n = h

-- | \(O(n^3)\) (apart from the extra Euclidean reduction steps, which depend on \(m\)) Returns the
-- determinant of the matrix modulo \(m\). The modulus does not have to be prime: rows are reduced
-- with a Euclidean-style loop that only uses integer division, so no modular inverse is needed.
--
-- ==== Constraints
-- - The input must be a square matrix.
-- - NOTE(review): the entries are presumably expected to be in \([0, m)\); the diagonal is passed
--   to the Barrett multiplication without any prior reduction. Confirm against callers.
-- - NOTE(review): \(m\) is converted with `BT.new32`, so it presumably has to fit in 32 bits.
--
-- @since 1.1.1.0
{-# INLINE detMod #-}
detMod :: Int -> Matrix Int -> Int
detMod m (Matrix h w vecA) = runST $ do
  -- Mutable copy of the input, viewed as a boxed vector of @n@ row slices so that swapping rows
  -- is a swap of two slices.
  vm <- VU.thaw vecA
  view <- V.thaw $ V.unfoldrExactN n (VUM.splitAt n) vm

  -- @inner i det@ processes the @i@-th column. @det@ tracks the sign of the determinant: every
  -- row swap replaces it with @m - det@, i.e. negates it modulo @m@.
  let inner i (!det :: Int)
        | i >= n = pure det
        | otherwise = do
            -- Move the first row @j >= i@ with a non-zero entry in column @i@ to row @i@.
            let swapLoop j !det_
                  | j >= n = pure det_
                  | otherwise = do
                      aji <- read2d view j i
                      if aji == 0
                        then swapLoop (j + 1) det_
                        else do
                          if i /= j
                            then do
                              VGM.unsafeSwap view i j
                              pure $! m - det_
                            else pure det_
            det' <- swapLoop i det
            -- Clear column @i@ in every row @j > i@.
            det'' <-
              VU.foldM'
                ( \ !acc j -> do
                    -- Euclidean step on rows @i@ and @j@: while the diagonal entry is non-zero,
                    -- subtract a multiple of row @i@ from row @j@ and swap the two rows.
                    let visitDiag !det_ = do
                          aii <- read2d view i i
                          if aii == 0
                            then pure det_
                            else do
                              aji <- read2d view j i
                              -- Adding @c@ times row @i@ subtracts @aji `div` aii@ times row
                              -- @i@ modulo @m@.
                              let !c = m - aji `div` aii
                              -- The rows are re-read on every step because the view was swapped.
                              rowI <- VGM.unsafeRead view i
                              rowJ <- VGM.unsafeRead view j
                              -- NOTE: it's a reverse loop!
                              VGM.ifoldrM'
                                ( \k_ aik () -> do
                                    VGM.unsafeModify rowJ ((`mod` m) . (+ aik * c)) (k_ + i)
                                )
                                ()
                                (VGM.unsafeDrop i rowI)
                              VGM.unsafeSwap view i j
                              visitDiag (m - det_)
                    acc' <- visitDiag acc
                    -- The loop ends with a zero diagonal entry in row @i@; swap once more so
                    -- that the zero entry ends up in row @j@.
                    VGM.unsafeSwap view i j
                    pure $! m - acc'
                )
                det'
                (VU.generate (n - (i + 1)) (+ (i + 1)))

            inner (i + 1) det''

  sign <- inner 0 (1 :: Int)
  -- Multiply the sign by the product of the diagonal entries.
  prod <-
    VU.foldM'
      ( \(!acc :: Word64) i -> do
          aii <- read2d view i i
          pure $! BT.mulMod bt acc $! fromIntegral aii
      )
      (fromIntegral sign)
      (VU.generate n id)
  -- Final reduction: without it, the initial sign 1 is returned unreduced when nothing is
  -- multiplied in (a 0x0 matrix with @m = 1@).
  pure $! fromIntegral prod `mod` m
  where
    -- Evaluated strictly so that a non-square input is rejected before any work is done.
    !_ = ACIA.runtimeAssert (h == w) $ "AtCoder.Extra.Semigroup.Matrix.detMod: given non-square matrix of size " ++ show (h, w)
    !n = h
    !bt = BT.new32 $ fromIntegral m

-- | Returns the determinant of a matrix of `M.ModInt`. Every entry is converted to its `Int`
-- value with `M.val`, the determinant is computed by `detMod` with the type-level modulus, and
-- the result is wrapped back with `M.new`. The complexity is that of `detMod`.
--
-- ==== Constraints
-- - The input must be a square matrix.
--
-- @since 1.1.1.0
{-# INLINE detMint #-}
detMint :: forall a. (KnownNat a) => Matrix (M.ModInt a) -> M.ModInt a
detMint matA = M.new . detMod m $ map M.val matA
  where
    -- The modulus carried by the type-level natural @a@.
    !m = fromIntegral (natVal' (proxy# @a))

-- | @since 1.1.0.0
instance (Num a, VU.Unbox a) => Semigroup (Matrix a) where
  -- The semigroup operation is matrix multiplication (`mul`).
  {-# INLINE (<>) #-}
  (<>) = mul

  -- Repeated multiplication via `ACEM.power`; the exponent is converted to `Int` first.
  -- Prefer `powMod` or `powMint`, which are specialized and much more efficient variants.
  {-# INLINE stimes #-}
  stimes = ACEM.power (<>) . fromIntegral