{-|
Copyright  :  (C) 2019, Andrew Lelechenko
License    :  MIT
Maintainer :  QBayLogic B.V. <devops@qbaylogic.com>

This module contains code from: https://hackage.haskell.org/package/mod and has
the following license:

Copyright (c) 2019 Andrew Lelechenko

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
associated documentation files (the "Software"), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge, publish, distribute,
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or
substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-}

{-# LANGUAGE CPP #-}
{-# LANGUAGE UnboxedTuples #-}

module Clash.Sized.Internal.Mod where

#if MIN_VERSION_base(4,15,0)
import GHC.Exts (eqWord#, leWord#, word2Int#)
#else
import GHC.Exts ((==#))
#endif
import GHC.Exts
  ((<=#), geWord#, isTrue#, minusWord#, plusWord#, uncheckedShiftL#, xor#,
   timesWord2#, quotRemWord2#, and#, addWordC#)
#if MIN_VERSION_base(4,15,0)
import GHC.Num.BigNat
  (BigNat#, bigNatAdd, bigNatAddWord#, bigNatAnd, bigNatBit#, bigNatCompare,
   bigNatFromWord#, bigNatFromWord2#, bigNatMul, bigNatMulWord#, bigNatRem,
   bigNatSize#, bigNatSubUnsafe, bigNatSubWordUnsafe#, bigNatToWord#, bigNatXor)
import GHC.Num.Natural (Natural (..))
#else
import GHC.Natural (Natural (..))
import GHC.Integer.GMP.Internals
  (BigNat, Integer (..), bigNatToWord, compareBigNat, minusBigNat, minusBigNatWord,
   plusBigNat, plusBigNatWord, sizeofBigNat#, bitBigNat, wordToBigNat2,
   remBigNat, timesBigNat, timesBigNatWord, xorBigNat, wordToBigNat, andBigNat)
#endif

#include "MachDeps.h"

#if MIN_VERSION_base(4,15,0)
-- | modular subtraction
subMod :: Natural -> Natural -> Natural -> Natural
subMod :: Natural -> Natural -> Natural -> Natural
subMod (NS Word#
m#) (NS Word#
x#) (NS Word#
y#) =
  if Int# -> Bool
isTrue# (Word#
x# Word# -> Word# -> Int#
`geWord#` Word#
y#) then Word# -> Natural
NS Word#
z# else Word# -> Natural
NS (Word#
z# Word# -> Word# -> Word#
`plusWord#` Word#
m#)
  where
    z# :: Word#
z# = Word#
x# Word# -> Word# -> Word#
`minusWord#` Word#
y#
subMod NS{} Natural
_ Natural
_ = Natural
forall a. a
brokenInvariant
subMod (NB ByteArray#
m#) (NS Word#
x#) (NS Word#
y#) =
  if Int# -> Bool
isTrue# (Word#
x# Word# -> Word# -> Int#
`geWord#` Word#
y#)
    then Word# -> Natural
NS (Word#
x# Word# -> Word# -> Word#
`minusWord#` Word#
y#)
    else ByteArray# -> Natural
bigNatToNat (ByteArray#
m# ByteArray# -> Word# -> ByteArray#
`bigNatSubWordUnsafe#` (Word#
y# Word# -> Word# -> Word#
`minusWord#` Word#
x#))
subMod (NB ByteArray#
m#) (NS Word#
x#) (NB ByteArray#
y#) =
  ByteArray# -> Natural
bigNatToNat ((ByteArray#
m# ByteArray# -> ByteArray# -> ByteArray#
`bigNatSubUnsafe` ByteArray#
y#) ByteArray# -> Word# -> ByteArray#
`bigNatAddWord#` Word#
x#)
subMod NB{} (NB ByteArray#
x#) (NS Word#
y#) =
  ByteArray# -> Natural
bigNatToNat (ByteArray#
x# ByteArray# -> Word# -> ByteArray#
`bigNatSubWordUnsafe#` Word#
y#)
subMod (NB ByteArray#
m#) (NB ByteArray#
x#) (NB ByteArray#
y#) = case ByteArray#
x# ByteArray# -> ByteArray# -> Ordering
`bigNatCompare` ByteArray#
y# of
  Ordering
LT -> ByteArray# -> Natural
bigNatToNat ((ByteArray#
m# ByteArray# -> ByteArray# -> ByteArray#
`bigNatSubUnsafe` ByteArray#
y#) ByteArray# -> ByteArray# -> ByteArray#
`bigNatAdd` ByteArray#
x#)
  Ordering
EQ -> Word# -> Natural
NS Word#
0##
  Ordering
GT -> ByteArray# -> Natural
bigNatToNat (ByteArray#
x# ByteArray# -> ByteArray# -> ByteArray#
`bigNatSubUnsafe` ByteArray#
y#)

-- | modular addition
addMod :: Natural -> Natural -> Natural -> Natural
addMod :: Natural -> Natural -> Natural -> Natural
addMod (NS Word#
m#) (NS Word#
x#) (NS Word#
y#) =
  if Int# -> Bool
isTrue# Int#
c# Bool -> Bool -> Bool
|| Int# -> Bool
isTrue# (Word#
z# Word# -> Word# -> Int#
`geWord#` Word#
m#) then Word# -> Natural
NS (Word#
z# Word# -> Word# -> Word#
`minusWord#` Word#
m#) else Word# -> Natural
NS Word#
z#
  where
    !(# Word#
z#, Int#
c# #) = Word#
x# Word# -> Word# -> (# Word#, Int# #)
`addWordC#` Word#
y#
addMod NS{} Natural
_ Natural
_ = Natural
forall a. a
brokenInvariant
addMod (NB ByteArray#
m#) (NS Word#
x#) (NS Word#
y#) =
  if Int# -> Bool
isTrue# Int#
c# then ByteArray# -> ByteArray# -> Natural
subIfGe (Word# -> Word# -> ByteArray#
bigNatFromWord2# Word#
1## Word#
z#) ByteArray#
m# else Word# -> Natural
NS Word#
z#
  where
    !(# Word#
z#, Int#
c# #) = Word#
x# Word# -> Word# -> (# Word#, Int# #)
`addWordC#` Word#
y#
addMod (NB ByteArray#
m#) (NS Word#
x#) (NB ByteArray#
y#) = ByteArray# -> ByteArray# -> Natural
subIfGe (ByteArray#
y# ByteArray# -> Word# -> ByteArray#
`bigNatAddWord#` Word#
x#) ByteArray#
m#
addMod (NB ByteArray#
m#) (NB ByteArray#
x#) (NS Word#
y#) = ByteArray# -> ByteArray# -> Natural
subIfGe (ByteArray#
x# ByteArray# -> Word# -> ByteArray#
`bigNatAddWord#` Word#
y#) ByteArray#
m#
addMod (NB ByteArray#
m#) (NB ByteArray#
x#) (NB ByteArray#
y#) = ByteArray# -> ByteArray# -> Natural
subIfGe (ByteArray#
x# ByteArray# -> ByteArray# -> ByteArray#
`bigNatAdd`     ByteArray#
y#) ByteArray#
m#

-- | modular multiplication
mulMod :: Natural -> Natural -> Natural -> Natural
mulMod :: Natural -> Natural -> Natural -> Natural
mulMod (NS Word#
m#) (NS Word#
x#) (NS Word#
y#) = Word# -> Natural
NS Word#
r#
  where
    !(# Word#
z1#, Word#
z2# #) = Word# -> Word# -> (# Word#, Word# #)
timesWord2# Word#
x# Word#
y#
    !(# Word#
_, Word#
r# #) = Word# -> Word# -> Word# -> (# Word#, Word# #)
quotRemWord2# Word#
z1# Word#
z2# Word#
m#
mulMod NS{} Natural
_ Natural
_ = Natural
forall a. a
brokenInvariant
mulMod (NB ByteArray#
m#) (NS Word#
x#) (NS Word#
y#) =
  ByteArray# -> Natural
bigNatToNat (Word# -> Word# -> ByteArray#
bigNatFromWord2# Word#
z1# Word#
z2# ByteArray# -> ByteArray# -> ByteArray#
`bigNatRem` ByteArray#
m#)
  where
    !(# Word#
z1#, Word#
z2# #) = Word# -> Word# -> (# Word#, Word# #)
timesWord2# Word#
x# Word#
y#
mulMod (NB ByteArray#
m#) (NS Word#
x#) (NB ByteArray#
y#) =
  ByteArray# -> Natural
bigNatToNat ((ByteArray#
y# ByteArray# -> Word# -> ByteArray#
`bigNatMulWord#` Word#
x#) ByteArray# -> ByteArray# -> ByteArray#
`bigNatRem` ByteArray#
m#)
mulMod (NB ByteArray#
m#) (NB ByteArray#
x#) (NS Word#
y#) =
  ByteArray# -> Natural
bigNatToNat ((ByteArray#
x# ByteArray# -> Word# -> ByteArray#
`bigNatMulWord#` Word#
y#) ByteArray# -> ByteArray# -> ByteArray#
`bigNatRem` ByteArray#
m#)
mulMod (NB ByteArray#
m#) (NB ByteArray#
x#) (NB ByteArray#
y#) =
  ByteArray# -> Natural
bigNatToNat ((ByteArray#
x# ByteArray# -> ByteArray# -> ByteArray#
`bigNatMul` ByteArray#
y#) ByteArray# -> ByteArray# -> ByteArray#
`bigNatRem` ByteArray#
m#)

-- | modular multiplication for powers of 2, takes a mask instead of a
-- wrap-around point
mulMod2 :: Natural -> Natural -> Natural -> Natural
mulMod2 :: Natural -> Natural -> Natural -> Natural
mulMod2 (NS Word#
m#) (NS Word#
x#) (NS Word#
y#) = Word# -> Natural
NS (Word#
z2# Word# -> Word# -> Word#
`and#` Word#
m#)
  where
    !(# Word#
_, Word#
z2# #) = Word# -> Word# -> (# Word#, Word# #)
timesWord2# Word#
x# Word#
y#
mulMod2 NS{} Natural
_ Natural
_ = Natural
forall a. a
brokenInvariant
mulMod2 (NB ByteArray#
m#) (NS Word#
x#) (NS Word#
y#) =
  ByteArray# -> Natural
bigNatToNat (Word# -> Word# -> ByteArray#
bigNatFromWord2# Word#
z1# Word#
z2# ByteArray# -> ByteArray# -> ByteArray#
`bigNatAnd` ByteArray#
m#)
  where
    !(# Word#
z1#, Word#
z2# #) = Word# -> Word# -> (# Word#, Word# #)
timesWord2# Word#
x# Word#
y#
mulMod2 (NB ByteArray#
m#) (NS Word#
x#) (NB ByteArray#
y#) =
  ByteArray# -> Natural
bigNatToNat ((ByteArray#
y# ByteArray# -> Word# -> ByteArray#
`bigNatMulWord#` Word#
x#) ByteArray# -> ByteArray# -> ByteArray#
`bigNatAnd` ByteArray#
m#)
mulMod2 (NB ByteArray#
m#) (NB ByteArray#
x#) (NS Word#
y#) =
  ByteArray# -> Natural
bigNatToNat ((ByteArray#
x# ByteArray# -> Word# -> ByteArray#
`bigNatMulWord#` Word#
y#) ByteArray# -> ByteArray# -> ByteArray#
`bigNatAnd` ByteArray#
m#)
mulMod2 (NB ByteArray#
m#) (NB ByteArray#
x#) (NB ByteArray#
y#) =
  ByteArray# -> Natural
bigNatToNat ((ByteArray#
x# ByteArray# -> ByteArray# -> ByteArray#
`bigNatMul` ByteArray#
y#) ByteArray# -> ByteArray# -> ByteArray#
`bigNatAnd` ByteArray#
m#)

-- | modular negations
negateMod :: Natural -> Natural -> Natural
negateMod :: Natural -> Natural -> Natural
negateMod Natural
_ (NS Word#
0##) = Word# -> Natural
NS Word#
0##
negateMod (NS Word#
m#) (NS Word#
x#) = Word# -> Natural
NS (Word#
m# Word# -> Word# -> Word#
`minusWord#` Word#
x#)
negateMod NS{} Natural
_ = Natural
forall a. a
brokenInvariant
negateMod (NB ByteArray#
m#) (NS Word#
x#) = ByteArray# -> Natural
bigNatToNat (ByteArray#
m# ByteArray# -> Word# -> ByteArray#
`bigNatSubWordUnsafe#` Word#
x#)
negateMod (NB ByteArray#
m#) (NB ByteArray#
x#) = ByteArray# -> Natural
bigNatToNat (ByteArray#
m# ByteArray# -> ByteArray# -> ByteArray#
`bigNatSubUnsafe`      ByteArray#
x#)

-- | Given a size in bits, return a function that complements the bits in a
-- 'Natural' up to that size.
complementMod
  :: Natural
  -> (Natural -> Natural)
complementMod :: Natural -> Natural -> Natural
complementMod (NS Word#
sz#) =
  if Int# -> Bool
isTrue# (Word#
sz# Word# -> Word# -> Int#
`leWord#` WORD_SIZE_IN_BITS##) then
    let m# :: Word#
m# = if Int# -> Bool
isTrue# (Word#
sz# Word# -> Word# -> Int#
`eqWord#` WORD_SIZE_IN_BITS##) then
#if WORD_SIZE_IN_BITS == 64
                Word#
0xFFFFFFFFFFFFFFFF##
#elif WORD_SIZE_IN_BITS == 32
                0xFFFFFFFF##
#else
#error Unhandled value for WORD_SIZE_IN_BITS
#endif
             else
               (Word#
1## Word# -> Int# -> Word#
`uncheckedShiftL#` (Word# -> Int#
word2Int# Word#
sz#)) Word# -> Word# -> Word#
`minusWord#` Word#
1##
        go :: Natural -> Natural
go (NS Word#
x#) = Word# -> Natural
NS (Word#
x# Word# -> Word# -> Word#
`xor#` Word#
m#)
        go (NB ByteArray#
r#) = Word# -> Natural
NS (ByteArray# -> Word#
bigNatToWord# ByteArray#
r# Word# -> Word# -> Word#
`xor#` Word#
m#)
    in  Natural -> Natural
go
  else
    let m# :: ByteArray#
m# = Word# -> ByteArray#
bigNatBit# Word#
sz# ByteArray# -> Word# -> ByteArray#
`bigNatSubWordUnsafe#` Word#
1##

        go :: Natural -> Natural
go (NS Word#
x#) = ByteArray# -> Natural
bigNatToNat (ByteArray# -> ByteArray# -> ByteArray#
bigNatXor (Word# -> ByteArray#
bigNatFromWord# Word#
x#) ByteArray#
m#)
        go (NB ByteArray#
x#) = ByteArray# -> Natural
bigNatToNat (ByteArray# -> ByteArray# -> ByteArray#
bigNatXor ByteArray#
x# ByteArray#
m#)
    in  Natural -> Natural
go
complementMod Natural
_ = [Char] -> Natural -> Natural
forall a. HasCallStack => [Char] -> a
error [Char]
"size too large"

-- | Keep all the bits up to a certain size
maskMod
  :: Natural
  -> (Natural -> Natural)
maskMod :: Natural -> Natural -> Natural
maskMod (NS Word#
sz#) =
  if Int# -> Bool
isTrue# (Word#
sz# Word# -> Word# -> Int#
`leWord#` WORD_SIZE_IN_BITS##) then
    if Int# -> Bool
isTrue# (Word#
sz# Word# -> Word# -> Int#
`eqWord#` WORD_SIZE_IN_BITS##) then
       -- Mask equal to the word size
       let go :: Natural -> Natural
go (NB ByteArray#
x#) = Word# -> Natural
NS (ByteArray# -> Word#
bigNatToWord# ByteArray#
x#)
           go Natural
n          = Natural
n
       in  Natural -> Natural
go
    else
       let m# :: Word#
m# = (Word#
1## Word# -> Int# -> Word#
`uncheckedShiftL#` (Word# -> Int#
word2Int# Word#
sz#)) Word# -> Word# -> Word#
`minusWord#` Word#
1##

           go :: Natural -> Natural
go (NS Word#
x#) = Word# -> Natural
NS (Word#
x# Word# -> Word# -> Word#
`and#` Word#
m#)
           go (NB ByteArray#
x#) = Word# -> Natural
NS (ByteArray# -> Word#
bigNatToWord# ByteArray#
x# Word# -> Word# -> Word#
`and#` Word#
m#)
       in  Natural -> Natural
go
  else
    let m# :: ByteArray#
m# = Word# -> ByteArray#
bigNatBit# Word#
sz#

        -- faster than `bigNatAnd (m# `minuxBigNatWord` 1##)`
        go :: Natural -> Natural
go (NB ByteArray#
x#) = ByteArray# -> Natural
bigNatToNat (ByteArray# -> ByteArray# -> ByteArray#
bigNatRem ByteArray#
x# ByteArray#
m#)
        -- The mask is larger than the word size, so we can keep all the bits
        go Natural
x = Natural
x
    in  Natural -> Natural
go
maskMod Natural
_ = [Char] -> Natural -> Natural
forall a. HasCallStack => [Char] -> a
error [Char]
"size too large"

bigNatToNat :: BigNat# -> Natural
bigNatToNat :: ByteArray# -> Natural
bigNatToNat ByteArray#
r# =
  if Int# -> Bool
isTrue# (ByteArray# -> Int#
bigNatSize# ByteArray#
r# Int# -> Int# -> Int#
<=# Int#
1#) then
    Word# -> Natural
NS (ByteArray# -> Word#
bigNatToWord# ByteArray#
r#)
  else
    ByteArray# -> Natural
NB ByteArray#
r#

subIfGe :: BigNat# -> BigNat# -> Natural
subIfGe :: ByteArray# -> ByteArray# -> Natural
subIfGe ByteArray#
z# ByteArray#
m# = case ByteArray#
z# ByteArray# -> ByteArray# -> Ordering
`bigNatCompare` ByteArray#
m# of
  Ordering
LT -> ByteArray# -> Natural
NB ByteArray#
z#
  Ordering
EQ -> Word# -> Natural
NS Word#
0##
  Ordering
GT -> ByteArray# -> Natural
bigNatToNat (ByteArray#
z# ByteArray# -> ByteArray# -> ByteArray#
`bigNatSubUnsafe` ByteArray#
m#)
#else
-- | modular subtraction
subMod :: Natural -> Natural -> Natural -> Natural
subMod (NatS# m#) (NatS# x#) (NatS# y#) =
  if isTrue# (x# `geWord#` y#) then NatS# z# else NatS# (z# `plusWord#` m#)
  where
    z# = x# `minusWord#` y#
subMod NatS#{} _ _ = brokenInvariant
subMod (NatJ# m#) (NatS# x#) (NatS# y#) =
  if isTrue# (x# `geWord#` y#)
    then NatS# (x# `minusWord#` y#)
    else bigNatToNat $ m# `minusBigNatWord` (y# `minusWord#` x#)
subMod (NatJ# m#) (NatS# x#) (NatJ# y#) =
  bigNatToNat $ (m# `minusBigNat` y#) `plusBigNatWord` x#
subMod NatJ#{} (NatJ# x#) (NatS# y#) =
  bigNatToNat $ x# `minusBigNatWord` y#
subMod (NatJ# m#) (NatJ# x#) (NatJ# y#) = case x# `compareBigNat` y# of
  LT -> bigNatToNat $ (m# `minusBigNat` y#) `plusBigNat` x#
  EQ -> NatS# 0##
  GT -> bigNatToNat $ x# `minusBigNat` y#

-- | modular addition
addMod :: Natural -> Natural -> Natural -> Natural
addMod (NatS# m#) (NatS# x#) (NatS# y#) =
  if isTrue# c# || isTrue# (z# `geWord#` m#) then NatS# (z# `minusWord#` m#) else NatS# z#
  where
    !(# z#, c# #) = x# `addWordC#` y#
addMod NatS#{} _ _ = brokenInvariant
addMod (NatJ# m#) (NatS# x#) (NatS# y#) =
  if isTrue# c# then subIfGe (wordToBigNat2 1## z#) m# else NatS# z#
  where
    !(# z#, c# #) = x# `addWordC#` y#
addMod (NatJ# m#) (NatS# x#) (NatJ# y#) = subIfGe (y# `plusBigNatWord` x#) m#
addMod (NatJ# m#) (NatJ# x#) (NatS# y#) = subIfGe (x# `plusBigNatWord` y#) m#
addMod (NatJ# m#) (NatJ# x#) (NatJ# y#) = subIfGe (x# `plusBigNat`     y#) m#

-- | modular multiplication
mulMod :: Natural -> Natural -> Natural -> Natural
mulMod (NatS# m#) (NatS# x#) (NatS# y#) = NatS# r#
  where
    !(# z1#, z2# #) = timesWord2# x# y#
    !(# _, r# #) = quotRemWord2# z1# z2# m#
mulMod NatS#{} _ _ = brokenInvariant
mulMod (NatJ# m#) (NatS# x#) (NatS# y#) =
  bigNatToNat $ wordToBigNat2 z1# z2# `remBigNat` m#
  where
    !(# z1#, z2# #) = timesWord2# x# y#
mulMod (NatJ# m#) (NatS# x#) (NatJ# y#) =
  bigNatToNat $ (y# `timesBigNatWord` x#) `remBigNat` m#
mulMod (NatJ# m#) (NatJ# x#) (NatS# y#) =
  bigNatToNat $ (x# `timesBigNatWord` y#) `remBigNat` m#
mulMod (NatJ# m#) (NatJ# x#) (NatJ# y#) =
  bigNatToNat $ (x# `timesBigNat` y#) `remBigNat` m#

-- | modular multiplication for powers of 2, takes a mask instead of a
-- wrap-around point
mulMod2 :: Natural -> Natural -> Natural -> Natural
mulMod2 (NatS# m#) (NatS# x#) (NatS# y#) = NatS# (z2# `and#` m#)
  where
    !(# _, z2# #) = timesWord2# x# y#
mulMod2 NatS#{} _ _ = brokenInvariant
mulMod2 (NatJ# m#) (NatS# x#) (NatS# y#) =
  bigNatToNat $ wordToBigNat2 z1# z2# `andBigNat` m#
  where
    !(# z1#, z2# #) = timesWord2# x# y#
mulMod2 (NatJ# m#) (NatS# x#) (NatJ# y#) =
  bigNatToNat $ (y# `timesBigNatWord` x#) `andBigNat` m#
mulMod2 (NatJ# m#) (NatJ# x#) (NatS# y#) =
  bigNatToNat $ (x# `timesBigNatWord` y#) `andBigNat` m#
mulMod2 (NatJ# m#) (NatJ# x#) (NatJ# y#) =
  bigNatToNat $ (x# `timesBigNat` y#) `andBigNat` m#

-- | modular negations
negateMod :: Natural -> Natural -> Natural
negateMod _ (NatS# 0##) = NatS# 0##
negateMod (NatS# m#) (NatS# x#) = NatS# (m# `minusWord#` x#)
negateMod NatS#{} _ = brokenInvariant
negateMod (NatJ# m#) (NatS# x#) = bigNatToNat $ m# `minusBigNatWord` x#
negateMod (NatJ# m#) (NatJ# x#) = bigNatToNat $ m# `minusBigNat`     x#

-- | Given a size in bits, return a function that complements the bits in a
-- 'Natural' up to that size.
complementMod
  :: Integer
  -> (Natural -> Natural)
complementMod (S# sz#) =
  if isTrue# (sz# <=# WORD_SIZE_IN_BITS#) then
    let m# = if isTrue# (sz# ==# WORD_SIZE_IN_BITS#) then
#if WORD_SIZE_IN_BITS == 64
                0xFFFFFFFFFFFFFFFF##
#elif WORD_SIZE_IN_BITS == 32
                0xFFFFFFFF##
#else
#error Unhandled value for WORD_SIZE_IN_BITS
#endif
             else
               (1## `uncheckedShiftL#` sz#) `minusWord#` 1##
        go (NatS# x#) = NatS# (x# `xor#` m#)
        go (NatJ# r#) = NatS# (bigNatToWord r# `xor#` m#)
    in  go
  else
    let m# = bitBigNat sz# `minusBigNatWord` 1##

        go (NatS# x#) = bigNatToNat (xorBigNat (wordToBigNat x#) m#)
        go (NatJ# x#) = bigNatToNat (xorBigNat x# m#)
    in  go
complementMod _ = error "size too large"

-- | Keep all the bits up to a certain size
maskMod
  :: Integer
  -> (Natural -> Natural)
maskMod (S# sz#) =
  if isTrue# (sz# <=# WORD_SIZE_IN_BITS#) then
    if isTrue# (sz# ==# WORD_SIZE_IN_BITS#) then
       -- Mask equal to the word size
       let go (NatJ# x#) = NatS# (bigNatToWord x#)
           go n          = n
       in  go
    else
       let m# = (1## `uncheckedShiftL#` sz#) `minusWord#` 1##

           go (NatS# x#) = NatS# (x# `and#` m#)
           go (NatJ# x#) = NatS# (bigNatToWord x# `and#` m#)
       in  go
  else
    let m# = bitBigNat sz#

        -- faster than `andBigNat (m# `minuxBigNatWord` 1##)`
        go (NatJ# x#) = bigNatToNat (remBigNat x# m#)
        -- The mask is larger than the word size, so we can keep all the bits
        go x = x
    in  go
maskMod _ = error "size too large"

bigNatToNat :: BigNat -> Natural
bigNatToNat r# =
  if isTrue# (sizeofBigNat# r# ==# 1#) then
    NatS# (bigNatToWord r#)
  else
    NatJ# r#

subIfGe :: BigNat -> BigNat -> Natural
subIfGe z# m# = case z# `compareBigNat` m# of
  LT -> NatJ# z#
  EQ -> NatS# 0##
  GT -> bigNatToNat $ z# `minusBigNat` m#

#endif

brokenInvariant :: a
brokenInvariant :: forall a. a
brokenInvariant = [Char] -> a
forall a. HasCallStack => [Char] -> a
error [Char]
"argument is larger than modulo"