module Tests.Internal.Math (tests) where

import AtCoder.Internal.Assert qualified as ACIA
import AtCoder.Internal.Barrett qualified as ACIBT
import AtCoder.Internal.Math qualified as ACIM
import Control.Monad (foldM, unless, when)
import Data.Foldable
import Data.Int (Int32)
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import Data.WideWord.Word128
import Data.Word (Word32, Word64)
import Test.Tasty
import Test.Tasty.HUnit

-- | Reference primality test by trial division, used as an oracle for
-- 'ACIM.isPrime'.
--
-- Tries every divisor @d >= 2@ while @d * d <= n@. @0@ and @1@ are not
-- prime; a negative argument is rejected with 'error'.
isPrimeNaive :: Int -> Bool
isPrimeNaive n
  | n < 0 = error "given negative value"
  | n < 2 = False
  | otherwise = noDivisorFrom 2
  where
    noDivisorFrom d
      | d * d > n = True
      | n `mod` d == 0 = False
      | otherwise = noDivisorFrom (d + 1)

-- | Naive check that @g@ is a primitive root modulo @m@ (intended for prime @m@).
--
-- Computes the powers @g^1, g^2, .., g^(m-2)@ modulo @m@ and answers 'False'
-- as soon as one of them equals 1. Otherwise the final power @g^(m-1) mod m@
-- is checked to be 1 with 'ACIA.runtimeAssert', and the answer is 'True'.
-- Takes @O(m)@ multiplications.
--
-- Calls 'error' unless @1 <= g < m@. Products are formed in 'Word64', so
-- NOTE(review): this presumably assumes @m < 2^32@; it is not checked here.
isPrimitiveRootNaive :: (HasCallStack) => Int -> Int -> Bool
isPrimitiveRootNaive m g
  | g < 1 || g >= m = error "invalid input"
  | otherwise = go 1 1
  where
    gW, mW :: Word64
    gW = fromIntegral g
    mW = fromIntegral m

    -- Invariant: acc == g^(k-1) mod m; acc' is g^k mod m.
    go :: Int -> Word64 -> Bool
    go k acc
      | k <= m - 2 = acc' /= 1 && go (k + 1) acc'
      | otherwise =
          ACIA.runtimeAssert (acc' == 1) "isPrimitiveRootNaive" `seq` True
      where
        acc' = acc * gW `mod` mW

-- | Exhaustively compares 'ACIBT.mulMod' with @(a * b) `mod` m@ for every
-- modulus @1 <= m <= 100@ and every pair of residues @0 <= a, b < m@.
unit_barrett :: TestTree
unit_barrett = testCase "barrett" $ do
  for_ [1 .. 100 :: Word64] $ \m -> do
    let bt = ACIBT.new64 m
        residues = [0 .. m - 1]
    for_ [(a, b) | a <- residues, b <- residues] $ \(a, b) ->
      (a * b) `mod` m @=? ACIBT.mulMod bt a b

  -- The degenerate modulus 1 once more, as a standalone check.
  0 @=? ACIBT.mulMod (ACIBT.new64 1) 0 0

-- | Checks Barrett multiplication ('ACIBT.new32' / 'ACIBT.mulMod') for the 21
-- moduli @modUpper, modUpper - 1, .., modUpper - 20@ against plain 'Word64'
-- arithmetic.
--
-- For each modulus @m@ the operands are, for @i <- [0 .. 9]@: @i@, @m - i@,
-- @m `div` 2 + i@ and @m `div` 2 - i@ (computed in 'Word32'). Each operand
-- is below @2^32@, so every reference product below fits in 'Word64'.
--
-- Requires @modUpper >= 21@; smaller values are rejected with a failed
-- assertion instead of silently testing nothing.
testBarrettWithModulo :: Word32 -> Assertion
testBarrettWithModulo modUpper = do
  -- A smaller bound would make the descending Word32 range below wrap around
  -- (yielding an empty list) or include the invalid modulus 0.
  assertBool "testBarrettWithModulo: modUpper must be at least 21" $ modUpper >= 21
  for_ [modUpper, modUpper - 1 .. modUpper - 20] $ \modulo32 -> do
    let modulo64 :: Word64 = fromIntegral modulo32
    let bt = ACIBT.new32 modulo32
    -- Operands, four per i: i, m - i, m `div` 2 + i, m `div` 2 - i.
    let v = VU.create $ do
          vec <- VUM.unsafeNew @_ @Word32 40
          for_ [0 .. 10 - 1 :: Word32] $ \i -> do
            let base = 4 * fromIntegral i
            VGM.write vec (base + 0) i
            VGM.write vec (base + 1) $ modulo32 - i
            VGM.write vec (base + 2) $ modulo32 `div` 2 + i
            VGM.write vec (base + 3) $ modulo32 `div` 2 - i
          pure vec
    VU.forM_ v $ \a -> do
      let a64 :: Word64 = fromIntegral a
      -- a^3 mod m, reducing after each multiplication so it stays in Word64.
      -- Compared directly as Word64 (no detour through a defaulted Integer).
      let expected = (a64 * a64) `mod` modulo64 * a64 `mod` modulo64
      expected @=? ACIBT.mulMod bt a64 (ACIBT.mulMod bt a64 a64)
      VU.forM_ v $ \b -> do
        let b64 :: Word64 = fromIntegral b
        (a64 * b64) `mod` modulo64 @=? ACIBT.mulMod bt a64 b64

-- | Barrett multiplication for moduli at and just below @maxBound :: Int32@,
-- delegated to 'testBarrettWithModulo'.
unit_barrettIntBorder :: TestTree
unit_barrettIntBorder = testCase "barrettIntBorder" $ do
  let modUpper :: Word32 = fromIntegral $ maxBound @Int32
  testBarrettWithModulo modUpper

-- | Barrett multiplication for moduli at and just below @maxBound :: Word32@,
-- delegated to 'testBarrettWithModulo'.
unit_barrettWord32Border :: TestTree
unit_barrettWord32Border = testCase "barrettWord32Border" $ do
  let modUpper = maxBound @Word32
  testBarrettWithModulo modUpper

-- | Checks 'ACIM.isPrime' on a few hand-picked values, then against
-- 'isPrimeNaive' on @[0 .. 10000]@ and on the 10001 largest 'Int32' values.
unit_isPrime :: TestTree
unit_isPrime = testCase "isPrime" $ do
  for_ knownCases $ \(n, expected) ->
    expected @=? ACIM.isPrime n
  for_ (smallRange ++ int32Top) $ \x ->
    isPrimeNaive x @=? ACIM.isPrime x
  where
    -- Two composites with small factors and the neighbourhood of 10^9 + 8.
    knownCases :: [(Int, Bool)]
    knownCases =
      [ (121, False),
        (11 * 13, False),
        (1_000_000_007, True),
        (1_000_000_008, False),
        (1_000_000_009, True)
      ]

    smallRange :: [Int]
    smallRange = [0 .. 10000]

    int32Top :: [Int]
    int32Top = [fromIntegral (maxBound @Int32 - k) | k <- [0 .. 10000]]

-- SafeMod: no tests yet.

-- | Checks 'ACIM.invGcd' on every pair of boundary-heavy candidates.
--
-- For each pair @(a, b)@ with @b > 0@, let @a2 = a `mod` b@ and
-- @g = gcd a2 b@. The result @(eg1, eg2) = invGcd a b@ must satisfy
-- @eg1 == g@, @0 <= eg2 <= b `div` eg1@ and @eg2 * a2 == g (mod b)@. The
-- last product is evaluated in 'Word128' so it cannot overflow.
unit_invGcdBound :: TestTree
unit_invGcdBound = testCase "invGcdBound" $ do
  VU.forM_ candidates $ \a ->
    VU.forM_ candidates $ \b ->
      -- Non-positive b are skipped.
      unless (b <= 0) $ checkPair a b
  where
    checkPair :: Int -> Int -> Assertion
    checkPair a b = do
      let a2 = a `mod` b
          g = gcd a2 b
          (!eg1, !eg2) = ACIM.invGcd a b
      g @=? eg1
      assertBool "<=" $ 0 <= eg2
      -- FIXME: not working correctly
      assertBool "<=" $ eg2 <= b `div` eg1
      let expected, actual :: Word128
          expected = fromIntegral (g `mod` b)
          actual = fromIntegral eg2 * fromIntegral a2 `mod` fromIntegral b
      expected @=? actual

    -- TODO: use `maxBound @Int` for Int variant of invGcd next time
    candidates :: VU.Vector Int
    candidates = VU.fromList $ concatMap nearBounds [0 .. 10] ++ fixedValues

    -- Twelve values per offset i: around 0, around the Int bounds, and around
    -- the bounds divided by 2 and by 3 (Haskell `div`, rounding toward -inf).
    nearBounds :: Int -> [Int]
    nearBounds i =
      [ i,
        -i,
        minBound + i,
        maxBound - i,
        minBound `div` 2 + i,
        minBound `div` 2 - i,
        maxBound `div` 2 + i,
        maxBound `div` 2 - i,
        minBound `div` 3 + i,
        minBound `div` 3 - i,
        maxBound `div` 3 + i,
        maxBound `div` 3 - i
      ]

    fixedValues :: [Int]
    fixedValues =
      [ 998244353,
        1_000_000_007,
        1_000_000_009,
        -998244353,
        -1_000_000_007,
        -1_000_000_009
      ]

-- | For every prime @m <= 10000@, checks that @g = 'ACIM.primitiveRoot' m@
-- lies in @[1, m)@, that @g^k mod m /= 1@ for @1 <= k <= m - 2@, and that
-- @g^(m-1) mod m == 1@.
unit_primitiveRootNaive :: TestTree
unit_primitiveRootNaive = testCase "primitiveRootNaive" $ do
  for_ [2 .. 10000] $ \m ->
    when (ACIM.isPrime m) $ checkRoot m
  where
    checkRoot :: Int -> Assertion
    checkRoot m = do
      let g = ACIM.primitiveRoot m
          -- One more factor of g, reduced mod m; Word64 keeps it overflow-free.
          mulG :: Word64 -> Word64
          mulG acc = acc * fromIntegral g `mod` fromIntegral m
          -- Advance to the next power and require it to differ from 1.
          step :: Word64 -> Int -> IO Word64
          step acc _ = do
            let !acc' = mulG acc
            assertBool "/=" $ 1 /= acc'
            pure acc'
      assertBool "<=" $ 1 <= g
      assertBool "<" $ g < m
      -- After the fold the accumulator holds g^(m-2) mod m.
      beforeLast <- foldM step 1 [1 .. m - 2]
      let !final = mulG beforeLast
      1 @=? final

-- REMARK: too heavy
-- | Validates 'ACIM.primitiveRoot' with 'isPrimitiveRootNaive' on a fixed
-- list of moduli. Each naive check is linear in the modulus, hence the
-- remark above.
unit_primitiveRootTemplate :: TestTree
unit_primitiveRootTemplate =
  testCase "primitiveRootTemplate" $ traverse_ check moduli
  where
    check :: Int -> Assertion
    check p = assertBool "" $ isPrimitiveRootNaive p (ACIM.primitiveRoot p)

    moduli :: [Int]
    moduli =
      [ 2,
        3,
        5,
        7,
        11,
        998244353,
        1000000007,
        469762049,
        167772161,
        754974721,
        324013369,
        831143041,
        1685283601
      ]

-- | Validates 'ACIM.primitiveRoot' with 'isPrimitiveRootNaive' for every
-- prime among the 1000 largest 'Int32' values.
unit_primitiveRootTest :: TestTree
unit_primitiveRootTest =
  testCase "primitiveRootTest" $
    for_ primesNearTop $ \p ->
      assertBool "" $ isPrimitiveRootNaive p (ACIM.primitiveRoot p)
  where
    primesNearTop :: [Int]
    primesNearTop =
      filter ACIM.isPrime [fromIntegral (maxBound @Int32 - k) | k <- [0 .. 999]]

-- | Test cases run by this module, in order.
--
-- 'unit_primitiveRootTemplate' and 'unit_primitiveRootTest' are defined
-- above but intentionally left out because they take too much time.
tests :: [TestTree]
tests = barrettTests ++ [unit_isPrime, unit_invGcdBound, unit_primitiveRootNaive]
  where
    barrettTests = [unit_barrett, unit_barrettIntBorder, unit_barrettWord32Border]