{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}

-- | Candidate implementations of dense matrix multiplication over unboxed
-- vectors, kept side by side so that they can be compared against each other.
--
-- A matrix is stored as one flat vector in row-major order: the element at
-- @(row, col)@ of an @h x w@ matrix lives at index @row * w + col@.
module BenchLib.Matrix
  ( Matrix (..),
    new,
    map,
    mulToCol,
    mulToColModInt,
    mul1,
    mul2,
    mul3_1,
    mul3_2,
    mul3_3,
    mul4_1,
    mul4_2,
    mul4_3,
    mulMod1,
    mulMod2,
    mulMod3,
    mulMod4,
    mulMod5,
    mulMod6,
    mulMod7,
    mulMint1,
    mulMint2,
    mulMint3,
    mulMint4,
  )
where

import AtCoder.Internal.Assert qualified as ACIA
import AtCoder.Internal.Barrett qualified as BT
import AtCoder.ModInt qualified as M
import Data.Foldable (for_)
import Data.Vector qualified as V
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import Data.Word (Word64)
import GHC.Exts (Proxy#, proxy#)
import GHC.Stack (HasCallStack)
import GHC.TypeNats (KnownNat, natVal, natVal')
import Prelude hiding (map)

-- | Dense matrix backed by a single unboxed vector in row-major order.
data Matrix a = Matrix
  { -- | Number of rows.
    hM :: {-# UNPACK #-} !Int,
    -- | Number of columns.
    wM :: {-# UNPACK #-} !Int,
    -- | The elements; 'new' enforces that the length is @hM * wM@.
    vecM :: !(VU.Vector a)
  }
  deriving (Show, Eq)

-- | Column vector.
type Col a = VU.Vector a

-- | Creates an @h x w@ matrix from a flat row-major vector. Calls 'error' when
-- the vector length is not @h * w@.
{-# INLINE new #-}
new :: (HasCallStack, VU.Unbox a) => Int -> Int -> VU.Vector a -> Matrix a
new h w vec
  | VU.length vec /= h * w = error "AtCoder.Extra.Matrix: size mismatch"
  | otherwise = Matrix h w vec

-- | Applies a function to every element, keeping the dimensions.
{-# INLINE map #-}
map :: (VU.Unbox a, VU.Unbox b) => (a -> b) -> Matrix a -> Matrix b
map f Matrix {..} = Matrix hM wM $ VU.map f vecM

-- | Multiplies a matrix by a column vector: each output element is the dot
-- product of one matrix row with @col@.
{-# INLINE mulToCol #-}
mulToCol :: (Num a, VU.Unbox a) => Matrix a -> Col a -> Col a
mulToCol Matrix {..} !col = VU.convert $ V.map (VU.sum . VU.zipWith (*) col) rows
  where
    !n = VU.length col
    !_ = ACIA.runtimeAssert (n == wM) "AtCoder.Extra.Matrix.mulToCol: size mismatch"
    -- Boxed vector of the `hM` rows, each a slice of length `wM`.
    rows = V.unfoldrExactN hM (VU.splitAt wM) vecM

-- | 'mulToCol' specialized to 'M.ModInt', multiplying through @BT.mulMod@
-- instead of the 'Num' instance of 'M.ModInt'.
{-# INLINE mulToColModInt #-}
mulToColModInt :: forall m. (KnownNat m) => Matrix (M.ModInt m) -> Col (M.ModInt m) -> Col (M.ModInt m)
mulToColModInt Matrix {..} !col = VU.convert $ V.map (VU.foldl' (+) (M.unsafeNew 0) . VU.zipWith mulMod col) rows
  where
    !_ = ACIA.runtimeAssert (VU.length col == wM) "AtCoder.Extra.Matrix.mulToColModInt: size mismatch"
    -- Built once per call from the type-level modulus.
    !bt = BT.new32 $ fromIntegral (natVal' (proxy# @m))
    rows = V.unfoldrExactN hM (VU.splitAt wM) vecM
    mulMod (M.ModInt x) (M.ModInt y) = M.unsafeNew . fromIntegral $ BT.mulMod bt (fromIntegral x) (fromIntegral y)

-- | Matrix product. Materializes the rows of the left operand and the columns
-- of the right operand as boxed vectors of slices, then takes dot products.
{-# INLINE mul1 #-}
mul1 :: (Num e, VU.Unbox e) => Matrix e -> Matrix e -> Matrix e
mul1 (Matrix h w vecA) (Matrix h' w' vecB) =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    f row col = VU.sum $ VU.zipWith (*) (rows1 VG.! row) (cols2 VG.! col)
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"
    rows1 = V.unfoldrExactN h (VU.splitAt w) vecA
    -- Column `col` of B: start at index `col` and step by the row width `w'`.
    cols2 = V.generate w' $ \col -> VU.unfoldrExactN h' (\i -> (VG.unsafeIndex vecB i, i + w')) col

-- | Matrix product. Like 'mul1', but indexes the right operand directly
-- instead of materializing its columns.
{-# INLINE mul2 #-}
mul2 :: (Num e, VU.Unbox e) => Matrix e -> Matrix e -> Matrix e
mul2 (Matrix h w vecA) (Matrix h' w' vecB) =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    f row col = VU.sum $ VU.imap (\iRow x -> x * VG.unsafeIndex vecB (col + iRow * w')) (rows1 VG.! row)
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"
    rows1 = V.unfoldrExactN h (VU.splitAt w) vecA

-- | Matrix product. Like 'mul2', but slices each row out of the flat vector
-- on demand instead of materializing a vector of rows.
{-# INLINE mul3_1 #-}
mul3_1 :: (Num e, VU.Unbox e) => Matrix e -> Matrix e -> Matrix e
mul3_1 (Matrix h w vecA) (Matrix h' w' vecB) =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    f row col = VU.sum $ VU.imap (\iRow x -> x * VG.unsafeIndex vecB (col + iRow * w')) (VU.unsafeSlice (w * row) w vecA)
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

-- | Matrix product. Same computation as 'mul3_1' with the dot product written
-- inline in the unfold step rather than as a @where@-bound helper.
{-# INLINE mul3_2 #-}
mul3_2 :: (Num e, VU.Unbox e) => Matrix e -> Matrix e -> Matrix e
mul3_2 (Matrix h w vecA) (Matrix h' w' vecB) =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !y = VU.sum $ VU.imap (\iRow x -> x * VG.unsafeIndex vecB (col + iRow * w')) (VU.unsafeSlice (w * row) w vecA)
           in if col + 1 >= w'
                then (y, (row + 1, 0))
                else (y, (row, col + 1))
      )
      (0, 0)
  where
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

-- | Matrix product. Same dot product as 'mul3_1', driven by 'VU.generate'
-- and recovering @(row, col)@ from the flat index with 'quotRem'.
{-# INLINE mul3_3 #-}
mul3_3 :: (Num e, VU.Unbox e) => Matrix e -> Matrix e -> Matrix e
mul3_3 (Matrix h w vecA) (Matrix h' w' vecB) = Matrix h w' $ VU.generate (h * w') $ \i ->
  let (!row, !col) = i `quotRem` w'
   in VU.sum $ VU.imap (\iRow x -> x * VG.unsafeIndex vecB (col + iRow * w')) (VU.unsafeSlice (w * row) w vecA)
  where
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

-- | Fastest: efficient iteration
--
-- Accumulates into a zero-initialized mutable buffer with loops ordered
-- @i@, @k@, @j@.
{-# INLINE mul4_1 #-}
mul4_1 :: forall e. (Num e, VU.Unbox e) => Matrix e -> Matrix e -> Matrix e
mul4_1 (Matrix h w vecA) (Matrix h' w' vecB) = Matrix h w' $ VU.create $ do
  c <- VUM.replicate (h * w') (0 :: e)
  for_ [0 .. h - 1] $ \i -> do
    for_ [0 .. w - 1] $ \k -> do
      for_ [0 .. w' - 1] $ \j -> do
        let !aik = VG.unsafeIndex vecA (i * w + k)
        let !bkj = VG.unsafeIndex vecB (k * w' + j)
        VGM.unsafeModify c (+ (aik * bkj)) (i * w' + j)
  pure c
  where
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

-- | Matrix product. 'mul4_1' with the row offsets (@i * w@, @i * w'@,
-- @k * w'@) computed once outside the loops that reuse them.
{-# INLINE mul4_2 #-}
mul4_2 :: forall e. (Num e, VU.Unbox e) => Matrix e -> Matrix e -> Matrix e
mul4_2 (Matrix h w vecA) (Matrix h' w' vecB) = Matrix h w' $ VU.create $ do
  c <- VUM.replicate (h * w') (0 :: e)
  for_ [0 .. h - 1] $ \i -> do
    let !iw = i * w
    let !iw' = i * w'
    for_ [0 .. w - 1] $ \k -> do
      let !kw' = k * w'
      for_ [0 .. w' - 1] $ \j -> do
        let !aik = VG.unsafeIndex vecA (iw + k)
        let !bkj = VG.unsafeIndex vecB (kw' + j)
        VGM.unsafeModify c (+ (aik * bkj)) (iw' + j)
  pure c
  where
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

-- | Matrix product. 'mul4_1' with the innermost loop written as an indexed
-- traversal over row @k@ of the right operand.
{-# INLINE mul4_3 #-}
mul4_3 :: forall e. (Num e, VU.Unbox e) => Matrix e -> Matrix e -> Matrix e
mul4_3 (Matrix h w vecA) (Matrix h' w' vecB) = Matrix h w' $ VU.create $ do
  c <- VUM.replicate (h * w') (0 :: e)
  for_ [0 .. h - 1] $ \i -> do
    for_ [0 .. w - 1] $ \k -> do
      VU.iforM_ (VU.unsafeSlice (k * w') w' vecB) $ \j bkj -> do
        let !aik = VG.unsafeIndex vecA (i * w + k)
        VGM.unsafeModify c (+ (aik * bkj)) (i * w' + j)
  pure c
  where
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

-- | Matrix product modulo @m@ on plain 'Int's, reducing with `mod` after
-- every multiplication and every addition.
{-# INLINE mulMod1 #-}
mulMod1 :: Int -> Matrix Int -> Matrix Int -> Matrix Int
mulMod1 !m (Matrix h w vecA) (Matrix h' w' vecB) =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    -- Seeded with 0 (rather than `foldl1'`) so that an empty inner dimension
    -- yields 0, as the `VU.sum`-based variants do, instead of failing on an
    -- empty vector. Every summand is already reduced, so the seed does not
    -- change the result for non-empty rows.
    f row col = VU.foldl' addMod 0 $ VU.imap (\iRow x -> mulMod x (VG.unsafeIndex vecB (col + (iRow * w')))) (VU.unsafeSlice (w * row) w vecA)
    addMod x y = (x + y) `mod` m
    mulMod x y = (x * y) `mod` m
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

-- | Matrix product modulo @m@. 'mulMod1' with the addition reduced by a
-- conditional subtraction instead of `mod`.
{-# INLINE mulMod2 #-}
mulMod2 :: Int -> Matrix Int -> Matrix Int -> Matrix Int
mulMod2 !m (Matrix h w vecA) (Matrix h' w' vecB) =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    -- Seeded with 0 (rather than `foldl1'`) so that an empty inner dimension
    -- yields 0 instead of failing on an empty vector.
    f row col = VU.foldl' addMod 0 $ VU.imap (\iRow x -> mulMod x (VG.unsafeIndex vecB (col + (iRow * w')))) (VU.unsafeSlice (w * row) w vecA)
    -- very slow
    addMod x y
      | x + y >= m = x + y - m
      | otherwise = x + y
    mulMod x y = (x * y) `mod` m
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

-- | Matrix product modulo @m@. Multiplies through @BT.mulMod@ and reduces
-- additions with `mod`.
--
-- NOTE(review): elements are passed to @BT.mulMod@ via 'fromIntegral';
-- presumably they must already lie in @[0, m)@ — confirm against the Barrett
-- module.
{-# INLINE mulMod3 #-}
mulMod3 :: Int -> Matrix Int -> Matrix Int -> Matrix Int
mulMod3 !m (Matrix h w vecA) (Matrix h' w' vecB) =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    !bt = BT.new32 $ fromIntegral m
    -- Seeded with 0 (rather than `foldl1'`) so that an empty inner dimension
    -- yields 0 instead of failing on an empty vector.
    f row col = VU.foldl' addMod 0 $ VU.imap (\iRow x -> mulMod x (VG.unsafeIndex vecB (col + (iRow * w')))) (VU.unsafeSlice (w * row) w vecA)
    addMod x y = (x + y) `mod` m
    mulMod x y = fromIntegral $ BT.mulMod bt (fromIntegral x) (fromIntegral y)
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

-- | Matrix product modulo @m@. 'mulMod3' with additions reduced by `rem`
-- instead of `mod`.
{-# INLINE mulMod4 #-}
mulMod4 :: Int -> Matrix Int -> Matrix Int -> Matrix Int
mulMod4 !m (Matrix h w vecA) (Matrix h' w' vecB) =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    !bt = BT.new32 $ fromIntegral m
    -- Seeded with 0 (rather than `foldl1'`) so that an empty inner dimension
    -- yields 0 instead of failing on an empty vector.
    f row col = VU.foldl' addMod 0 $ VU.imap (\iRow x -> mulMod x (VG.unsafeIndex vecB (col + (iRow * w')))) (VU.unsafeSlice (w * row) w vecA)
    addMod x y = (x + y) `rem` m
    mulMod x y = fromIntegral $ BT.mulMod bt (fromIntegral x) (fromIntegral y)
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

-- | Matrix product modulo @m@. Sums the @BT.mulMod@ products without
-- intermediate reduction and applies a single `rem` per output element.
{-# INLINE mulMod5 #-}
mulMod5 :: Int -> Matrix Int -> Matrix Int -> Matrix Int
mulMod5 !m (Matrix h w vecA) (Matrix h' w' vecB) =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    !bt = BT.new32 $ fromIntegral m
    -- NOTE: this is unsafe if the matrix is too large
    f row col = fromIntegral . (`rem` fromIntegral m) . VU.sum $ VU.imap (\iRow x -> BT.mulMod bt (fromIntegral x) (fromIntegral (VG.unsafeIndex vecB (col + (iRow * w'))))) (VU.unsafeSlice (w * row) w vecA)
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

-- | Matrix product modulo @m@. Multiplies through @BT.mulMod@ and reduces
-- additions by a conditional subtraction.
{-# INLINE mulMod6 #-}
mulMod6 :: Int -> Matrix Int -> Matrix Int -> Matrix Int
mulMod6 !m (Matrix h w vecA) (Matrix h' w' vecB) =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    !bt = BT.new32 $ fromIntegral m
    -- Seeded with 0 (rather than `foldl1'`) so that an empty inner dimension
    -- yields 0 instead of failing on an empty vector.
    f row col = VU.foldl' addMod 0 $ VU.imap (\iRow x -> mulMod x (VG.unsafeIndex vecB (col + (iRow * w')))) (VU.unsafeSlice (w * row) w vecA)
    addMod x y
      | x + y >= m = x + y - m
      | otherwise = x + y
    mulMod x y = fromIntegral $ BT.mulMod bt (fromIntegral x) (fromIntegral y)
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

-- | Matrix product modulo @m@. Loop structure of 'mul4_1' (mutable buffer,
-- @i@, @k@, @j@ order) with @BT.mulMod@ products and `rem`-reduced additions.
{-# INLINE mulMod7 #-}
mulMod7 :: Int -> Matrix Int -> Matrix Int -> Matrix Int
mulMod7 !m (Matrix h w vecA) (Matrix h' w' vecB) = Matrix h w' $ VU.create $ do
  c <- VUM.replicate (h * w') (0 :: Int)
  for_ [0 .. h - 1] $ \i -> do
    for_ [0 .. w - 1] $ \k -> do
      for_ [0 .. w' - 1] $ \j -> do
        let !aik = VG.unsafeIndex vecA (i * w + k)
        let !bkj = VG.unsafeIndex vecB (k * w' + j)
        VGM.unsafeModify c (addMod (mulMod_ aik bkj)) (i * w' + j)
  pure c
  where
    !bt = BT.new32 $ fromIntegral m
    addMod x y = (x + y) `rem` m
    mulMod_ x y = fromIntegral $ BT.mulMod bt (fromIntegral x) (fromIntegral y)
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mulMod: matrix size mismatch"

-- | Matrix product over 'M.ModInt' using only its 'Num' instance for both
-- multiplication and summation.
{-# INLINE mulMint1 #-}
mulMint1 :: forall a. (KnownNat a) => Matrix (M.ModInt a) -> Matrix (M.ModInt a) -> Matrix (M.ModInt a)
mulMint1 (Matrix h w vecA) (Matrix h' w' vecB) =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    f :: Int -> Int -> M.ModInt a
    f row col = VU.sum $ VU.imap (\iRow x -> mulMod x (VG.unsafeIndex vecB (col + (iRow * w')))) (VU.unsafeSlice (w * row) w vecA)
    mulMod :: M.ModInt a -> M.ModInt a -> M.ModInt a
    mulMod = (*)
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

-- | Matrix product over 'M.ModInt'. Multiplies through @BT.mulMod@ with a
-- @bt@ value built once per call, and sums with the 'Num' instance.
{-# INLINE mulMint2 #-}
mulMint2 :: forall a. (KnownNat a) => Matrix (M.ModInt a) -> Matrix (M.ModInt a) -> Matrix (M.ModInt a)
mulMint2 (Matrix h w vecA) (Matrix h' w' vecB) =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    !bt = BT.new32 $ fromIntegral (natVal' (proxy# @a))
    f :: Int -> Int -> M.ModInt a
    f row col = VU.sum $ VU.imap (\iRow x -> mulMod x (VG.unsafeIndex vecB (col + (iRow * w')))) (VU.unsafeSlice (w * row) w vecA)
    mulMod :: M.ModInt a -> M.ModInt a -> M.ModInt a
    mulMod (M.ModInt x) (M.ModInt y) = M.unsafeNew . fromIntegral $ BT.mulMod bt (fromIntegral x) (fromIntegral y)
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

-- | Matrix product over 'M.ModInt'. Sums the raw 'Word64' products from
-- @BT.mulMod@ and converts to 'M.ModInt' once per output element with
-- @M.new64@.
{-# INLINE mulMint3 #-}
mulMint3 :: forall a. (KnownNat a) => Matrix (M.ModInt a) -> Matrix (M.ModInt a) -> Matrix (M.ModInt a)
mulMint3 (Matrix h w vecA) (Matrix h' w' vecB) =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    !bt = BT.new32 $ fromIntegral (natVal' (proxy# @a))
    -- NOTE: this is unsafe if the matrix is too large
    f :: Int -> Int -> M.ModInt a
    f row col = M.new64 . VU.sum $ VU.imap (\iRow x -> mulMod x (VG.unsafeIndex vecB (col + (iRow * w')))) (VU.unsafeSlice (w * row) w vecA)
    mulMod :: M.ModInt a -> M.ModInt a -> Word64
    mulMod (M.ModInt x) (M.ModInt y) = BT.mulMod bt (fromIntegral x) (fromIntegral y)
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"

-- | Matrix product over 'M.ModInt' with the loop structure of 'mul4_1'.
--
-- Performs memory-efficient iteration, but requires more type conversions
-- and is slow.
{-# INLINE mulMint4 #-}
mulMint4 :: forall a. (KnownNat a) => Matrix (M.ModInt a) -> Matrix (M.ModInt a) -> Matrix (M.ModInt a)
mulMint4 (Matrix h w vecA) (Matrix h' w' vecB) = Matrix h w' $ VU.create $ do
  c <- VUM.replicate (h * w') (M.unsafeNew 0)
  for_ [0 .. h - 1] $ \i -> do
    for_ [0 .. w - 1] $ \k -> do
      for_ [0 .. w' - 1] $ \j -> do
        let !aik = VG.unsafeIndex vecA (i * w + k)
        let !bkj = VG.unsafeIndex vecB (k * w' + j)
        VGM.unsafeModify c (+ (mulMod_ aik bkj)) (i * w' + j)
  pure c
  where
    !bt = BT.new32 $ fromIntegral (natVal' (proxy# @a))
    mulMod_ :: M.ModInt a -> M.ModInt a -> M.ModInt a
    mulMod_ (M.ModInt x) (M.ModInt y) = M.unsafeNew . fromIntegral $ BT.mulMod bt (fromIntegral x) (fromIntegral y)
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mulMint: matrix size mismatch"