{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}

module BenchLib.Matrix
  ( Matrix (..),
    new,
    map,
    mulToCol,
    mulToColModInt,
    mul1,
    mul2,
    mul3_1,
    mul3_2,
    mul3_3,
    mulMod1,
    mulMod2,
    mulMod3,
    mulMod4,
    mulMod5,
    mulMint1,
    mulMint2,
    mulMint3,
  )
where

import AtCoder.Internal.Assert qualified as ACIA
import AtCoder.Internal.Barrett qualified as BT
import AtCoder.ModInt qualified as M
import Data.Vector qualified as V
import Data.Vector.Generic qualified as VG
import Data.Vector.Unboxed qualified as VU
import Data.Word (Word64)
import GHC.Exts (Proxy#, proxy#)
import GHC.Stack (HasCallStack)
import GHC.TypeNats (KnownNat, natVal, natVal')
import Prelude hiding (map)

data Matrix a = Matrix
  { hM :: {-# UNPACK #-} !Int,
    wM :: {-# UNPACK #-} !Int,
    vecM :: !(VU.Vector a)
  }
  deriving (Show, Eq)

type Col a = VU.Vector a

{-# INLINE new #-}
new :: (HasCallStack, VU.Unbox a) => Int -> Int -> VU.Vector a -> Matrix a
new h w vec
  | VU.length vec /= h * w = error "AtCoder.Extra.Matrix: size mismatch"
  | otherwise = Matrix h w vec

{-# INLINE map #-}
map :: (VU.Unbox a, VU.Unbox b) => (a -> b) -> Matrix a -> Matrix b
map f Matrix {..} = Matrix hM wM $ VU.map f vecM

{-# INLINE mulToCol #-}
mulToCol :: (Num a, VU.Unbox a) => Matrix a -> Col a -> Col a
mulToCol Matrix {..} !col = VU.convert $ V.map (VU.sum . VU.zipWith (*) col) rows
  where
    !n = VU.length col
    !_ = ACIA.runtimeAssert (n == wM) "AtCoder.Extra.Matrix.mulToCol: size mismatch"
    rows = V.unfoldrExactN hM (VU.splitAt wM) vecM

{-# INLINE mulToColModInt #-}
mulToColModInt :: forall m. (KnownNat m) => Matrix (M.ModInt m) -> Col (M.ModInt m) -> Col (M.ModInt m)
mulToColModInt Matrix {..} !col = VU.convert $ V.map (VU.foldl' (+) (M.unsafeNew 0) . VU.zipWith mulMod col) rows
  where
    !_ = ACIA.runtimeAssert (VU.length col == wM) "AtCoder.Extra.Matrix.mulToColModInt: size mismatch"
    !bt = BT.new32 $ fromIntegral (natVal' (proxy# @m))
    rows = V.unfoldrExactN hM (VU.splitAt wM) vecM
    mulMod (M.ModInt x) (M.ModInt y) = M.unsafeNew . fromIntegral $ BT.mulMod bt (fromIntegral x) (fromIntegral y)

{-# INLINE mul1 #-}
mul1 :: (Num e, VU.Unbox e) => Matrix e -> Matrix e -> Matrix e
mul1 !a !b =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    f row col = VU.sum $ VU.zipWith (*) (rows1 VG.! row) (cols2 VG.! col)
    h = hM a
    w = wM a
    vecA = vecM a
    h' = hM b
    w' = wM b
    vecB = vecM b
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"
    rows1 = V.unfoldrExactN h (VU.splitAt w) vecA
    cols2 = V.generate w' $ \col -> VU.unfoldrExactN h' (\i -> (VG.unsafeIndex vecB i, i + w')) col

{-# INLINE mul2 #-}
mul2 :: (Num e, VU.Unbox e) => Matrix e -> Matrix e -> Matrix e
mul2 !a !b =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    f row col = VU.sum $ VU.imap (\iRow x -> x * VG.unsafeIndex vecB (col + iRow * w')) (rows1 VG.! row)
    h = hM a
    w = wM a
    vecA = vecM a
    h' = hM b
    w' = wM b
    vecB = vecM b
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"
    rows1 = V.unfoldrExactN h (VU.splitAt w) vecA

{-# INLINE mul3_1 #-}
mul3_1 :: (Num e, VU.Unbox e) => Matrix e -> Matrix e -> Matrix e
mul3_1 !a !b =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    f row col = VU.sum $ VU.imap (\iRow x -> x * VG.unsafeIndex vecB (col + iRow * w')) (VU.unsafeSlice (w * row) w vecA)
    h = hM a
    w = wM a
    h' = hM b
    vecA = vecM a
    w' = wM b
    vecB = vecM b
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

{-# INLINE mul3_2 #-}
mul3_2 :: (Num e, VU.Unbox e) => Matrix e -> Matrix e -> Matrix e
mul3_2 !a !b =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !y = VU.sum $ VU.imap (\iRow x -> x * VG.unsafeIndex vecB (col + iRow * w')) (VU.unsafeSlice (w * row) w vecA)
           in if col + 1 >= w'
                then (y, (row + 1, 0))
                else (y, (row, col + 1))
      )
      (0, 0)
  where
    h = hM a
    w = wM a
    h' = hM b
    vecA = vecM a
    w' = wM b
    vecB = vecM b
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

{-# INLINE mul3_3 #-}
mul3_3 :: (Num e, VU.Unbox e) => Matrix e -> Matrix e -> Matrix e
mul3_3 !a !b = Matrix h w' $ VU.generate (h * w') $ \i ->
  let (!row, !col) = i `quotRem` w'
   in VU.sum $ VU.imap (\iRow x -> x * VG.unsafeIndex vecB (col + iRow * w')) (VU.unsafeSlice (w * row) w vecA)
  where
    h = hM a
    w = wM a
    h' = hM b
    vecA = vecM a
    w' = wM b
    vecB = vecM b
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

{-# INLINE mulMod1 #-}
mulMod1 :: Int -> Matrix Int -> Matrix Int -> Matrix Int
mulMod1 !m !a !b =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    f row col = VU.foldl1' addMod $ VU.imap (\iRow x -> mulMod x (VG.unsafeIndex vecB (col + (iRow * w')))) (VU.unsafeSlice (w * row) w vecA)
    addMod x y = (x + y) `mod` m
    mulMod x y = (x * y) `mod` m
    h = hM a
    w = wM a
    h' = hM b
    vecA = vecM a
    w' = wM b
    vecB = vecM b
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

{-# INLINE mulMod2 #-}
mulMod2 :: Int -> Matrix Int -> Matrix Int -> Matrix Int
mulMod2 !m !a !b =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    f row col = VU.foldl1' addMod $ VU.imap (\iRow x -> mulMod x (VG.unsafeIndex vecB (col + (iRow * w')))) (VU.unsafeSlice (w * row) w vecA)
    -- very slow
    addMod x y
      | x + y >= m = x + y - m
      | otherwise = x + y
    mulMod x y = (x * y) `mod` m
    h = hM a
    w = wM a
    h' = hM b
    vecA = vecM a
    w' = wM b
    vecB = vecM b
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

{-# INLINE mulMod3 #-}
mulMod3 :: Int -> Matrix Int -> Matrix Int -> Matrix Int
mulMod3 !m !a !b =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    !bt = BT.new32 $ fromIntegral m
    f row col = VU.foldl1' addMod $ VU.imap (\iRow x -> mulMod x (VG.unsafeIndex vecB (col + (iRow * w')))) (VU.unsafeSlice (w * row) w vecA)
    addMod x y = (x + y) `mod` m
    mulMod x y = fromIntegral $ BT.mulMod bt (fromIntegral x) (fromIntegral y)
    h = hM a
    w = wM a
    h' = hM b
    vecA = vecM a
    w' = wM b
    vecB = vecM b
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

{-# INLINE mulMod4 #-}
mulMod4 :: Int -> Matrix Int -> Matrix Int -> Matrix Int
mulMod4 !m !a !b =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    !bt = BT.new32 $ fromIntegral m
    f row col = VU.foldl1' addMod $ VU.imap (\iRow x -> mulMod x (VG.unsafeIndex vecB (col + (iRow * w')))) (VU.unsafeSlice (w * row) w vecA)
    addMod x y = (x + y) `rem` m
    mulMod x y = fromIntegral $ BT.mulMod bt (fromIntegral x) (fromIntegral y)
    h = hM a
    w = wM a
    h' = hM b
    vecA = vecM a
    w' = wM b
    vecB = vecM b
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

{-# INLINE mulMod5 #-}
mulMod5 :: Int -> Matrix Int -> Matrix Int -> Matrix Int
mulMod5 !m !a !b =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    !bt = BT.new32 $ fromIntegral m
    f row col = VU.foldl1' addMod $ VU.imap (\iRow x -> mulMod x (VG.unsafeIndex vecB (col + (iRow * w')))) (VU.unsafeSlice (w * row) w vecA)
    addMod x y
      | x + y >= m = x + y - m
      | otherwise = x + y
    mulMod x y = fromIntegral $ BT.mulMod bt (fromIntegral x) (fromIntegral y)
    h = hM a
    w = wM a
    h' = hM b
    vecA = vecM a
    w' = wM b
    vecB = vecM b
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

{-# INLINE mulMint1 #-}
mulMint1 :: forall a. (KnownNat a) => Matrix (M.ModInt a) -> Matrix (M.ModInt a) -> Matrix (M.ModInt a)
mulMint1 !a !b =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    f :: Int -> Int -> M.ModInt a
    f row col = VU.sum $ VU.imap (\iRow x -> mulMod x (VG.unsafeIndex vecB (col + (iRow * w')))) (VU.unsafeSlice (w * row) w vecA)
    mulMod :: M.ModInt a -> M.ModInt a -> M.ModInt a
    mulMod = (*)
    h = hM a
    w = wM a
    h' = hM b
    vecA = vecM a
    w' = wM b
    vecB = vecM b
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

{-# INLINE mulMint2 #-}
mulMint2 :: forall a. (KnownNat a) => Matrix (M.ModInt a) -> Matrix (M.ModInt a) -> Matrix (M.ModInt a)
mulMint2 !a !b =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    !bt = BT.new32 $ fromIntegral (natVal' (proxy# @a))
    f :: Int -> Int -> M.ModInt a
    f row col = VU.sum $ VU.imap (\iRow x -> mulMod x (VG.unsafeIndex vecB (col + (iRow * w')))) (VU.unsafeSlice (w * row) w vecA)
    mulMod :: M.ModInt a -> M.ModInt a -> M.ModInt a
    mulMod (M.ModInt x) (M.ModInt y) = M.unsafeNew . fromIntegral $ BT.mulMod bt (fromIntegral x) (fromIntegral y)
    h = hM a
    w = wM a
    h' = hM b
    vecA = vecM a
    w' = wM b
    vecB = vecM b
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

-- REMARK: This is very unsafe in that it can overflow (mod^2 * n)
{-# INLINE mulMint3 #-}
mulMint3 :: forall a. (KnownNat a) => Matrix (M.ModInt a) -> Matrix (M.ModInt a) -> Matrix (M.ModInt a)
mulMint3 !a !b =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    !bt = BT.new32 $ fromIntegral (natVal' (proxy# @a))
    f :: Int -> Int -> M.ModInt a
    f row col = M.new64 . VU.sum $ VU.imap (\iRow x -> mulMod x (VG.unsafeIndex vecB (col + (iRow * w')))) (VU.unsafeSlice (w * row) w vecA)
    mulMod :: M.ModInt a -> M.ModInt a -> Word64
    mulMod (M.ModInt x) (M.ModInt y) = BT.mulMod bt (fromIntegral x) (fromIntegral y)
    h = hM a
    w = wM a
    h' = hM b
    vecA = vecM a
    w' = wM b
    vecB = vecM b
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"