{-# LANGUAGE BangPatterns, BlockArguments #-}
{- |
Module      : Data.RME.Vector
Copyright   : Galois, Inc. 2016
License     : BSD3
Maintainer  : huffman@galois.com
Stability   : experimental
Portability : portable

Operations on big-endian vectors of RME formulas.
-}

module Data.RME.Vector
  ( RMEV
  , eq, ule, ult, sle, slt
  , neg, add, sub, mul
  , udiv, urem, sdiv, srem
  , pmul, pmod, pdiv
  , shl, ashr, lshr, ror, rol
  , integer
  , popcount
  , countLeadingZeros
  , countTrailingZeros
  ) where

import Data.RME.Base (RME)
import qualified Data.RME.Base as RME

import qualified Data.Bits as Bits
import Data.Vector (Vector)
import qualified Data.Vector as V

-- | A bitvector of RME formulas. Vectors in this module are big-endian:
-- index 0 holds the most-significant bit.
type RMEV = Vector RME

-- | Constant bitvector literal of the given width, in big-endian order:
-- index @k@ holds bit @width - 1 - k@ of @x@. Bits of @x@ at or above
-- @width@ are discarded.
integer :: Int -> Integer -> RMEV
integer width x = V.generate width bitAt
  where
    bitAt :: Int -> RME
    bitAt k = RME.constant (Bits.testBit x (width - 1 - k))

-- | Bitvector equality: the conjunction of bitwise 'RME.iff' over aligned
-- positions. Positions beyond the end of the shorter vector are ignored.
eq :: RMEV -> RMEV -> RME
eq x y = V.foldr sameBit RME.true (V.zip x y)
  where
    sameBit (a, b) rest = RME.conj (RME.iff a b) rest

-- | Unsigned less-than-or-equal on big-endian bitvectors.
--
-- The fold runs from the least-significant end. For each bit pair, the
-- result is decided by that pair when the bits differ, and by the result
-- for the lower-order bits when they agree. Only the first
-- @min (length xv) (length yv)@ positions are compared.
ule :: RMEV -> RMEV -> RME
ule xv yv = V.foldr step RME.true (V.zip xv yv)
  where
    -- 'lower' is the comparison result for the less-significant bits.
    step :: (RME, RME) -> RME -> RME
    step (a, b) lower =
      RME.xor (RME.conj b lower) (RME.conj (RME.compl a) (RME.xor b lower))

-- | Unsigned less-than, computed as the negation of @y <= x@.
ult :: RMEV -> RMEV -> RME
ult x y = RME.compl (y `ule` x)

-- | Complement the most-significant (sign) bit. This maps the signed
-- two's-complement ordering onto the unsigned ordering. An empty vector
-- is returned unchanged.
swap_sign :: RMEV -> RMEV
swap_sign x = V.imap flipMsb x
  where
    flipMsb :: Int -> RME -> RME
    flipMsb 0 b = RME.compl b
    flipMsb _ b = b

-- | Signed less-than-or-equal on two's-complement bitvectors, reduced to
-- 'ule' by flipping the sign bit of both operands.
sle :: RMEV -> RMEV -> RME
sle x y = ule x' y'
  where
    x' = swap_sign x
    y' = swap_sign y

-- | Signed less-than on two's-complement bitvectors, reduced to 'ult' by
-- flipping the sign bit of both operands.
slt :: RMEV -> RMEV -> RME
slt x y = swap_sign x `ult` swap_sign y

-- | Big-endian bitvector increment. Returns the carry out of the
-- most-significant position together with the incremented bits.
-- The carry into the least-significant position is 'RME.true'.
increment :: [RME] -> (RME, [RME])
increment = foldr step (RME.true, [])
  where
    -- The lazy pattern keeps the result as lazy as explicit recursion
    -- with a where-bound tuple would be.
    step :: RME -> (RME, [RME]) -> (RME, [RME])
    step bit ~(carryIn, rest) = (RME.conj bit carryIn, RME.xor bit carryIn : rest)

-- | Two's-complement negation: complement every bit, then add one.
neg :: RMEV -> RMEV
neg = V.fromList . snd . increment . map RME.compl . V.toList

-- | One-bit full adder. Returns @(carry-out, sum)@ for inputs @a@, @b@
-- and carry-in @c@.
full_adder :: RME -> RME -> RME -> (RME, RME)
full_adder a b c = (carryOut, RME.xor halfSum c)
  where
    halfSum :: RME
    halfSum = RME.xor a b

    -- The two carry sources (a AND b) and (halfSum AND c) can never both
    -- hold, so xor-ing them behaves like or-ing them.
    carryOut :: RME
    carryOut = RME.xor (RME.conj a b) (RME.conj halfSum c)

-- | Big-endian ripple-carry adder. Adds two bit lists with carry-in @cin@
-- and returns @(carry-out, sum bits)@. Bits are paired from the head
-- (most-significant end), so the sum has the length of the shorter list.
-- The carry-in enters at the last paired position.
ripple_carry_adder :: [RME] -> [RME] -> RME -> (RME, [RME])
ripple_carry_adder xs0 ys0 cin = go xs0 ys0
  where
    go :: [RME] -> [RME] -> (RME, [RME])
    go (x : xs) (y : ys) =
      let (carryFromLow, lowSum) = go xs ys
          (carryOut, s)          = full_adder x y carryFromLow
      in (carryOut, s : lowSum)
    go _ _ = (cin, [])

-- | Two's-complement (modular) bitvector addition. The final carry is
-- discarded.
add :: RMEV -> RMEV -> RMEV
add x y = V.fromList sumBits
  where
    (_carryOut, sumBits) = ripple_carry_adder (V.toList x) (V.toList y) RME.false

-- | Two's-complement bitvector subtraction, computed as
-- @x + complement y + 1@. The final carry is discarded.
sub :: RMEV -> RMEV -> RMEV
sub x y = V.fromList diffBits
  where
    negatedY = map RME.compl (V.toList y)
    (_carryOut, diffBits) = ripple_carry_adder (V.toList x) negatedY RME.true

-- | Two's-complement (equivalently, modular unsigned) bitvector
-- multiplication. The result has the width of @x@.
--
-- Shift-and-add over the bits of @y@, from most- to least-significant:
-- at each step the accumulator is shifted left by one position, and @x@
-- is added to it when the current multiplier bit is set.
mul :: RMEV -> RMEV -> RMEV
mul x y = V.foldl' step zero y
  where
    zero :: RMEV
    zero = V.replicate (V.length x) RME.false

    -- A strict left fold avoids building a chain of unevaluated
    -- accumulators, one per multiplier bit.
    step :: RMEV -> RME -> RMEV
    step acc c = V.zipWith (RME.mux c) (add shifted x) shifted
      where
        -- Big-endian shift left by one: drop the MSB, append a zero LSB.
        shifted :: RMEV
        shifted = V.drop 1 (acc `V.snoc` RME.false)

-- | Unsigned bitvector division: the quotient component of 'udivrem'.
udiv :: RMEV -> RMEV -> RMEV
udiv x y = quotient
  where
    (quotient, _) = udivrem x y

-- | Unsigned bitvector remainder: the remainder component of 'udivrem'.
urem :: RMEV -> RMEV -> RMEV
urem x y = remainder
  where
    (_, remainder) = udivrem x y

-- | Signed bitvector division: the quotient component of 'sdivrem'.
sdiv :: RMEV -> RMEV -> RMEV
sdiv x y = quotient
  where
    (quotient, _) = sdivrem x y

-- | Signed bitvector remainder: the remainder component of 'sdivrem'.
srem :: RMEV -> RMEV -> RMEV
srem x y = remainder
  where
    (_, remainder) = sdivrem x y

-- | Unsigned restoring division. Returns @(quotient, remainder)@, both with
-- the width of the dividend. Assumes the divisor has the same width as the
-- dividend (NOTE(review): not checked here).
--
-- When the divisor is zero, every trial subtraction succeeds. The quotient
-- is then all ones and the remainder equals the dividend.
udivrem :: RMEV -> RMEV -> (RMEV, RMEV)
udivrem dividend divisor
  -- With zero-width operands, 'shiftL1' would call 'V.tail' on an empty
  -- vector and crash. Handle that case directly.
  | n == 0    = (V.empty, V.empty)
  | otherwise = divStep 0 RME.false initial
  where
    n :: Int
    n = V.length dividend

    -- Given an n-bit dividend and divisor, 'initial' is the starting value
    -- of the 2n-bit "remainder register", which carries both the quotient
    -- and the remainder. The left half is the partial remainder. The right
    -- half starts as the dividend, whose bits are shifted into the
    -- remainder while quotient bits are shifted in from the right.
    initial :: RMEV
    initial = integer n 0 V.++ dividend

    -- @divStep i p rr@: @i@ steps have been done, and @p@ is the quotient
    -- bit produced by the previous step, not yet shifted into @rr@.
    divStep :: Int -> RME -> RMEV -> (RMEV, RMEV)
    divStep i p rr
      | i == n = (q `shiftL1` p, r)
      where (r, q) = V.splitAt n rr
    divStep i p rr =
      divStep (i + 1) b (V.zipWith (RME.mux b) (V.fromList s V.++ q) rs)
      where
        rs :: RMEV
        rs = rr `shiftL1` p
        (r, q) = V.splitAt n rs
        -- Subtract the divisor from the left half of the "remainder
        -- register". The carry-out b is set iff no borrow occurred, i.e.
        -- r >= divisor. In that case the difference is kept and b becomes
        -- the next quotient bit.
        (b, s) = ripple_carry_adder (V.toList r) (map RME.compl (V.toList divisor)) RME.true

    -- Shift the whole vector left by one, inserting @e@ at the
    -- least-significant end.
    shiftL1 :: RMEV -> RME -> RMEV
    shiftL1 v e = V.tail v `V.snoc` e

-- | Signed division and remainder, returning @(quotient, remainder)@.
--
-- Performs 'udivrem' on the absolute values of the operands. It then
-- negates the quotient if the operand signs differ, and gives the
-- remainder the sign of the dividend.
sdivrem :: RMEV -> RMEV -> (RMEV, RMEV)
sdivrem dividend divisor = (negateIf signsDiffer q, negateIf dividendNeg r)
  where
    dividendNeg, divisorNeg, signsDiffer :: RME
    dividendNeg = V.head dividend
    divisorNeg  = V.head divisor
    signsDiffer = RME.xor dividendNeg divisorNeg

    -- Two's-complement negate @v@ exactly when @c@ holds.
    negateIf :: RME -> RMEV -> RMEV
    negateIf c v = V.zipWith (RME.mux c) (neg v) v

    (q, r) = udivrem (negateIf dividendNeg dividend) (negateIf divisorNeg divisor)

-- | Population count: the number of set bits, returned as a big-endian
-- bitvector with the same width as the input.
popcount :: RMEV -> RMEV
popcount bits
  | l == 0    = V.empty
  | otherwise = V.replicate (l - w - 1) RME.false <> pcnt
  where
    l :: Int
    l = V.length bits

    -- w = floor (log2 l). The count is at most l < 2^(w+1), so w+1 bits
    -- suffice. ('Bits.countTrailingZeros' equals floor (log2 l) only when
    -- l is a power of two. For other widths it under-sized the counter
    -- and the count overflowed.)
    w :: Int
    w = Bits.finiteBitSize l - 1 - Bits.countLeadingZeros l

    zs :: RMEV
    zs = V.replicate w RME.false

    -- Each input bit zero-extended to w+1 bits.
    xs :: [RMEV]
    xs = [ zs `V.snoc` b | b <- V.toList bits ]

    -- Sum of all extended bits; length is w+1.
    pcnt :: RMEV
    pcnt = foldr1 add xs

-- | Count trailing zeros: the leading-zero count of the bit-reversed input.
-- The result has the same width as the input.
countTrailingZeros :: RMEV -> RMEV
countTrailingZeros = countLeadingZeros . V.reverse

-- | Count leading zeros: the number of bits before the first set bit,
-- returned as a big-endian bitvector with the same width as the input.
-- An all-zero input yields its width. The big-endian convention makes
-- leading zeros the natural direction to count.
countLeadingZeros :: RMEV -> RMEV
countLeadingZeros bits
  | l == 0    = V.empty
  | otherwise = V.replicate (l - w - 1) RME.false <> go 0 (V.toList bits)
  where
    l :: Int
    l = V.length bits

    -- w = floor (log2 l). The count is at most l < 2^(w+1), so w+1 bits
    -- suffice. ('Bits.countTrailingZeros' equals floor (log2 l) only for
    -- powers of two. For other widths it truncated the count.)
    w :: Int
    w = Bits.finiteBitSize l - 1 - Bits.countLeadingZeros l

    -- @go i bs@: @i@ zeros have been seen so far. The first set bit in
    -- @bs@ selects @i@; if there is none, the total count is returned.
    go :: Integer -> [RME] -> RMEV
    go !i []       = integer (w + 1) i
    go !i (b : bs) = V.zipWith (RME.mux b) (integer (w + 1) i) (go (i + 1) bs)

-- | Polynomial (carry-less) multiplication over GF(2). The algorithm is the
-- same whichever endianness convention is used. The result length is
-- @max 0 (m+n-1)@, where @m@ and @n@ are the lengths of the inputs.
pmul :: RMEV -> RMEV -> RMEV
pmul x y = V.generate (max 0 (m + n - 1)) coeff
  where
    m, n :: Int
    m = V.length x
    n = V.length y

    -- Coefficient k is the xor of x[i] AND y[k-i] over every i for which
    -- both indices are in range.
    coeff :: Int -> RME
    coeff k =
      foldr RME.xor RME.false
        [ RME.conj (x V.! i) (y V.! (k - i))
        | i <- [max 0 (k - n + 1) .. min k (m - 1)]
        ]

-- | Polynomial reduction modulo a possibly symbolic modulus @y@. Both are
-- big-endian coefficient vectors. The result has one bit fewer than @y@.
-- If no bit of @y@ is set, the result is all zeros.
--
-- This is optimized for the common case of a concrete modulus. Leading
-- coefficients that are the constant 'RME.false' are skipped, and the
-- first constant 'RME.true' is used as the leading term with no mux.
-- Only symbolic leading coefficients produce a mux.
pmod :: RMEV -> RMEV -> RMEV
pmod x y = findLeading (V.toList y)
  where
    resultWidth :: Int
    resultWidth = V.length y - 1

    findLeading :: [RME] -> RMEV
    findLeading [] = V.replicate resultWidth RME.false -- division by zero
    findLeading (c : cs)
      | c == RME.true  = reduceWith cs
      | c == RME.false = findLeading cs
      | otherwise      = V.zipWith (RME.mux c) (reduceWith cs) (findLeading cs)

    -- Reduce x modulo the polynomial made of an implicit leading 1
    -- followed by the big-endian coefficients @low@.
    reduceWith :: [RME] -> RMEV
    reduceWith low = padTo resultWidth (V.fromList (loop (V.length x - 1) one zero))
      where
        deg :: Int
        deg = length low

        -- Left-pad with zeros up to the requested width.
        padTo :: Int -> RMEV -> RMEV
        padTo r v = V.replicate (r - V.length v) RME.false V.++ v

        -- Big-endian residues of the constants 1 and 0.
        one, zero :: [RME]
        one  = replicate (deg - 1) RME.false ++ [RME.true]
        zero = replicate deg RME.false

        -- Multiply a residue by the indeterminate and reduce it: shift left
        -- and, when the bit shifted out was set, xor in @low@.
        timesX :: [RME] -> [RME]
        timesX [] = []
        timesX (top : rest) =
          zipWith RME.xor (map (RME.conj top) low) (rest ++ [RME.false])

        -- Walk x from its least-significant coefficient (the last index)
        -- toward index 0. @pow@ is the residue of the current power of the
        -- indeterminate; @acc@ accumulates the result.
        loop :: Int -> [RME] -> [RME] -> [RME]
        loop i pow acc
          | i < 0     = acc
          | otherwise =
              let term = map (RME.conj (x V.! i)) pow
              in loop (i - 1) (timesX pow) (zipWith RME.xor term acc)

-- | Polynomial division: the quotient component of 'pdivmod'. The result
-- has the same length as the first argument.
pdiv :: RMEV -> RMEV -> RMEV
pdiv x y = quotient
  where
    (quotient, _) = pdivmod x y

-- | Polynomial division with remainder. Returns @(quotient, remainder)@;
-- the quotient has the length of @x@ and the remainder has one bit fewer
-- than @y@. These are the result lengths Cryptol uses. If no bit of @y@
-- is set, the result is @x@ paired with an all-zero remainder.
--
-- TODO: this should probably be disentangled to compute only the
-- quotient, since 'pmod' is a separate polynomial modulus algorithm.
pdivmod :: RMEV -> RMEV -> (RMEV, RMEV)
pdivmod x y = scanLeading (V.toList y)
  where
    remWidth :: Int
    remWidth = V.length y - 1

    -- Locate the leading 1 of the divisor. Each coefficient selects between
    -- "the leading term is here" and "keep scanning".
    scanLeading :: [RME] -> (RMEV, RMEV)
    scanLeading coeffs = case coeffs of
      []       -> (x, V.replicate remWidth RME.false) -- division by zero
      (c : cs) -> muxPair c (divideBy cs) (scanLeading cs)

    -- Divide by the polynomial with an implicit leading 1 followed by
    -- @mask@. Both results are padded to their Cryptol lengths.
    divideBy :: [RME] -> (RMEV, RMEV)
    divideBy mask = (V.fromList (qPad ++ qs), V.fromList (rPad ++ rs))
      where
        (qs, rs) = pdivmod_helper (V.toList x) mask
        qPad = map (const RME.false) rs
        rPad = replicate (remWidth - length rs) RME.false

    -- Pairwise mux that short-circuits on constant selectors.
    muxPair :: RME -> (RMEV, RMEV) -> (RMEV, RMEV) -> (RMEV, RMEV)
    muxPair c a b
      | c == RME.true  = a
      | c == RME.false = b
      | otherwise      = (sel fst, sel snd)
      where
        sel proj = V.zipWith (RME.mux c) (proj a) (proj b)

-- | Divide @ds@ by the polynomial @(1 : mask)@, i.e. an implicit leading 1
-- followed by @mask@, returning @(quotient, remainder)@. All arguments and
-- results are big-endian. The remainder has the same length as @mask@
-- (limited by the length of @ds@). The quotient and remainder together
-- have the length of @ds@.
pdivmod_helper :: [RME] -> [RME] -> ([RME], [RME])
pdivmod_helper ds mask = go (length ds - length mask) ds
  where
    -- @n@ is the number of quotient bits still to be produced. While
    -- @n > 0@, the list holds at least @length mask + 1@ bits, so the
    -- 'error' cases below cannot occur.
    go :: Int -> [RME] -> ([RME], [RME])
    go n cs | n <= 0 = ([], cs)
    go _ []          = error "Data.RME.Vector.pdivmod_helper: impossible"
    go n (c : cs)    = (c : qs, rs)
      where
        -- The leading coefficient is the next quotient bit. When it is set,
        -- subtract (xor) the divisor from the bits that follow.
        cs' :: [RME]
        cs' = mux_add c cs mask
        (qs, rs) = go (n - 1) cs'

    -- When @c@ holds, xor @ys@ into the prefix of @xs@. Bits of @xs@ past
    -- the end of @ys@ pass through unchanged.
    mux_add :: RME -> [RME] -> [RME] -> [RME]
    mux_add c (x : xs) (y : ys) = RME.mux c (RME.xor x y) x : mux_add c xs ys
    mux_add _ []       (_ : _ ) = error "Data.RME.Vector.pdivmod_helper: impossible"
    mux_add _ xs       []       = xs

-- | Helper for building shift and rotate operations.
--
-- @bitOp f x y@ builds a vector with the width of @x@. Output index @i@ is
-- @f x i k@, where @k@ is the unsigned value of the big-endian vector @y@.
-- The value of @k@ is selected through a mux tree over the bits of @y@.
-- The recursion branches on every bit of @y@, so each output bit's mux tree
-- can have up to @2^(length y)@ leaves.
bitOp :: (RMEV -> Integer -> Integer -> RME) -> RMEV -> RMEV -> RMEV
bitOp f x y = V.generate (V.length x) (\i -> select (toInteger i) 0 yBits)
  where
    yBits :: [RME]
    yBits = V.toList y

    -- @acc@ is the value of the bits of @y@ consumed so far.
    select :: Integer -> Integer -> [RME] -> RME
    select i acc []       = f x i acc
    select i acc (b : bs) = RME.mux b (select i (2 * acc + 1) bs) (select i (2 * acc) bs)

-- | Bitwise logical left shift. Shifts the bits of the first bitvector by
-- the unsigned value of the second, filling with false bits. Shift amounts
-- at or above the width yield all false bits.
shl :: RMEV -> RMEV -> RMEV
shl = bitOp \x i j ->
  let src = i + j
  in if src >= toInteger (V.length x)
       then RME.false
       else x V.! fromInteger src

-- | Arithmetic right shift. Shifts the bits of the first bitvector by the
-- unsigned value of the second, filling with copies of the first bit
-- (treated as the sign bit).
ashr :: RMEV -> RMEV -> RMEV
ashr = bitOp \x i j ->
  case compare i j of
    LT -> V.head x
    _  -> x V.! fromInteger (i - j)

-- | Bitwise logical right shift. Shifts the bits of the first bitvector by
-- the unsigned value of the second, filling with false bits.
lshr :: RMEV -> RMEV -> RMEV
lshr = bitOp \x i j ->
  if j > i
    then RME.false
    else x V.! fromInteger (i - j)

-- | Bitwise left rotation. Rotates the bits of the first bitvector by the
-- unsigned value of the second, taken modulo the width.
rol :: RMEV -> RMEV -> RMEV
rol = bitOp \x i j ->
  let width = toInteger (V.length x)
      src   = (i + j) `mod` width
  in x V.! fromInteger src

-- | Bitwise right rotation. Rotates the bits of the first bitvector by the
-- unsigned value of the second, taken modulo the width.
ror :: RMEV -> RMEV -> RMEV
ror = bitOp \x i j ->
  let width = toInteger (V.length x)
      src   = (i - j) `mod` width
  in x V.! fromInteger src