{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnliftedNewtypes #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module: Data.Choice
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- Primitives for constant-time choice.
--
-- The 'Choice' type encodes truthy and falsy values as unboxed 'Word#'
-- bit masks.
--
-- Use the standard logical primitives ('or', 'and', 'xor', 'not', 'ne', 'eq')
-- to manipulate in-flight 'Choice' values. Use one of the selection
-- functions to use a 'Choice' to select a value in constant time,
-- or 'decide' to reduce a 'Choice' to a 'Bool' at the /end/ of a
-- sensitive computation.

module Data.Choice (
  -- * Choice
    Choice
  , decide
  , true#
  , false#
  , to_word#

  -- * Construction
  , from_full_mask#
  , from_bit#
  , from_word_nonzero#
  , from_word_eq#
  , from_word_le#
  , from_word_lt#
  , from_word_gt#

  -- * Manipulation
  , or
  , and
  , xor
  , not
  , ne
  , eq

  -- * Constant-time Selection
  , select_word#
  , select_wide#
  , select_wider#

  -- * Constant-time Equality
  , eq_word#
  , eq_wide#
  , eq_wider#
  ) where

import qualified Data.Bits as B
import GHC.Exts (Word#, Int(..), Word(..))
import qualified GHC.Exts as Exts
import Prelude hiding (and, not, or)

-- utilities ------------------------------------------------------------------

-- Two machine words as an unboxed pair; 'lo#' fills the first slot and
-- 'hi#' the second.
type Limb2 = (# Word#, Word# #)

-- Four machine words as an unboxed 4-tuple (used by 'select_wider#' and
-- 'eq_wider#').
type Limb4 = (# Word#, Word#, Word#, Word# #)

-- Two's-complement negation modulo 2^wordsize: computes 0 - w, which
-- wraps (e.g. 1## becomes the all-ones word, 0## stays 0##).
neg_w# :: Word# -> Word#
neg_w# w = Exts.minusWord# 0## w
{-# INLINE neg_w# #-}

-- Place a word in the second slot of a 'Limb2', zeroing the first.
hi# :: Word# -> Limb2
hi# = \hiw -> (# 0##, hiw #)
{-# INLINE hi# #-}

-- Place a word in the first slot of a 'Limb2', zeroing the second.
lo# :: Word# -> Limb2
lo# = \low -> (# low, 0## #)
{-# INLINE lo# #-}

-- Limb-wise bitwise OR of two 'Limb2' values.
or_w# :: Limb2 -> Limb2 -> Limb2
or_w# (# x0, x1 #) (# y0, y1 #) =
  let !r0 = x0 `Exts.or#` y0
      !r1 = x1 `Exts.or#` y1
  in  (# r0, r1 #)
{-# INLINE or_w# #-}

-- Limb-wise bitwise AND of two 'Limb2' values.
and_w# :: Limb2 -> Limb2 -> Limb2
and_w# (# x0, x1 #) (# y0, y1 #) =
  let !r0 = x0 `Exts.and#` y0
      !r1 = x1 `Exts.and#` y1
  in  (# r0, r1 #)
{-# INLINE and_w# #-}

-- Limb-wise bitwise XOR of two 'Limb2' values.
xor_w# :: Limb2 -> Limb2 -> Limb2
xor_w# (# x0, x1 #) (# y0, y1 #) =
  let !r0 = x0 `Exts.xor#` y0
      !r1 = x1 `Exts.xor#` y1
  in  (# r0, r1 #)
{-# INLINE xor_w# #-}

-- choice ---------------------------------------------------------------------

-- | Constant-time choice, encoded as a mask.
--
--   The falsy 'Choice' is the all-zeros word and the truthy 'Choice'
--   is the all-ones word (see 'false#' and 'true#'); 'decide' treats
--   any nonzero mask as truthy.
--
--   Note that 'Choice' is defined as an unlifted newtype, and so a
--   'Choice' value cannot be bound at the top level. You should work
--   with it locally in the context of a computation.
--
--   Use a 'Choice' with one of the selection functions to select a
--   value in constant time, or 'decide' to reduce it to a 'Bool' at
--   the /end/ of a sensitive computation.
--
--   >>> decide (or (false# ()) (true# ()))
--   True
newtype Choice = Choice Word#

-- | The falsy 'Choice': the all-zeros mask.
--
--   The unit argument is never inspected; it only exists because an
--   unlifted 'Choice' cannot be a top-level constant.
--
--   >>> decide (false# ())
--   False
false# :: () -> Choice
false# = \_ -> Choice 0##
{-# INLINE false# #-}

-- | The truthy 'Choice': the all-ones mask.
--
--   The unit argument is never inspected; it only exists because an
--   unlifted 'Choice' cannot be a top-level constant.
--
--   >>> decide (true# ())
--   True
true# :: () -> Choice
true# _ = Choice (Exts.not# 0##)
{-# INLINE true# #-}

-- | Reduce a 'Choice' to a 'Bool'; any nonzero mask is 'True'.
--
--   The reduction itself is branch-free, but whatever branches on the
--   resulting 'Bool' afterwards may well not be, so call 'decide' only
--   at the /end/ of a computation, once all security-sensitive work is
--   done.
--
--   >>> decide (true# ())
--   True
decide :: Choice -> Bool
decide (Choice c) = Exts.isTrue# (c `Exts.gtWord#` 0##)
{-# INLINE decide #-}

-- | Convert a 'Choice' to an unboxed 'Word#' bit.
--
--   The result is the lowest bit of the underlying mask: 1## for the
--   truthy mask and 0## for the falsy one. This is the inverse of
--   'from_bit#' for well-formed masks.
--
--   >>> import qualified GHC.Exts as Exts
--   >>> Exts.isTrue# (Exts.eqWord# 0## (to_word# (false# ())))
--   True
to_word# :: Choice -> Word#
to_word# choice = case choice of
  Choice mask -> 1## `Exts.and#` mask
{-# INLINE to_word# #-}

-- construction ---------------------------------------------------------------

-- | Wrap an unboxed full-word mask (0## or all ones) as a 'Choice'.
--
--   The input is /not/ checked to be a full-word mask.
--
--   >>> decide (from_full_mask# 0##)
--   False
--   >>> decide (from_full_mask# 0xFFFFFFFFFFFFFFFF##)
--   True
from_full_mask# :: Word# -> Choice
from_full_mask# = Choice
{-# INLINE from_full_mask# #-}

-- | Build a 'Choice' from an unboxed bit (0## or 1##).
--
--   The bit is widened to a full mask by wrapping negation, so 0##
--   stays 0## and 1## becomes all ones. The input is /not/ checked to
--   be a bit.
--
--   >>> decide (from_bit# 1##)
--   True
from_bit# :: Word# -> Choice
from_bit# b =
  let !mask = neg_w# b
  in  Choice mask
{-# INLINE from_bit# #-}

-- | Build a 'Choice' that is truthy iff the unboxed word is nonzero.
--
--   A zero input yields the falsy 'Choice'.
--
--   >>> decide (from_word_nonzero# 2##)
--   True
--   >>> decide (from_word_nonzero# 0##)
--   False
from_word_nonzero# :: Word# -> Choice
from_word_nonzero# w =
  let !top = case B.finiteBitSize (0 :: Word) of I# bits -> bits Exts.-# 1#
      -- for nonzero w, at least one of w and -w has its top bit set;
      -- for w == 0 both are zero
      !spread = neg_w# w `Exts.or#` w
      !msb    = Exts.uncheckedShiftRL# spread top
  in  from_bit# msb
{-# INLINE from_word_nonzero# #-}

-- | Build a 'Choice' that is truthy iff the two unboxed words are equal.
--
--   >>> decide (from_word_eq# 0## 1##)
--   False
--   >>> decide (from_word_eq# 1## 1##)
--   True
from_word_eq# :: Word# -> Word# -> Choice
from_word_eq# x y = not (from_word_nonzero# (x `Exts.xor#` y))
{-# INLINE from_word_eq# #-}

-- | Build a 'Choice' that is truthy iff @x <= y@ (unsigned), without
--   branching.
--
--   >>> decide (from_word_le# 0## 1##)
--   True
--   >>> decide (from_word_le# 1## 1##)
--   True
from_word_le# :: Word# -> Word# -> Choice
from_word_le# x y =
  let !top  = case B.finiteBitSize (0 :: Word) of I# bits -> bits Exts.-# 1#
      !nx   = Exts.not# x
      -- branch-free unsigned comparison predicate: the top bit of
      -- (lhs .&. rhs) is set exactly when x <= y
      !lhs  = nx `Exts.or#` y
      !rhs  = Exts.xor# x y `Exts.or#` Exts.not# (Exts.minusWord# y x)
      !flag = Exts.uncheckedShiftRL# (lhs `Exts.and#` rhs) top
  in  from_bit# flag
{-# INLINE from_word_le# #-}

-- | Build a 'Choice' that is truthy iff @x < y@ (unsigned), without
--   branching.
--
--   >>> decide (from_word_lt# 0## 1##)
--   True
--   >>> decide (from_word_lt# 1## 1##)
--   False
from_word_lt# :: Word# -> Word# -> Choice
from_word_lt# x y =
  let !top  = case B.finiteBitSize (0 :: Word) of I# bits -> bits Exts.-# 1#
      !nx   = Exts.not# x
      -- branch-free unsigned comparison predicate: the top bit of
      -- (lhs .|. rhs) is set exactly when x < y
      !lhs  = nx `Exts.and#` y
      !rhs  = (nx `Exts.or#` y) `Exts.and#` Exts.minusWord# x y
      !flag = Exts.uncheckedShiftRL# (lhs `Exts.or#` rhs) top
  in  from_bit# flag
{-# INLINE from_word_lt# #-}

-- | Build a 'Choice' that is truthy iff @x > y@ (unsigned); this is
--   'from_word_lt#' with its arguments swapped.
--
--   >>> decide (from_word_gt# 0## 1##)
--   False
--   >>> decide (from_word_gt# 1## 1##)
--   False
--   >>> decide (from_word_gt# 1## 0##)
--   True
from_word_gt# :: Word# -> Word# -> Choice
from_word_gt# lhs rhs = from_word_lt# rhs lhs
{-# INLINE from_word_gt# #-}

-- manipulation ---------------------------------------------------------------

-- | Logical negation of a 'Choice' (bitwise complement of its mask).
--
--   >>> decide (not (true# ()))
--   False
--   >>> decide (not (false# ()))
--   True
not :: Choice -> Choice
not choice = case choice of
  Choice mask -> Choice (Exts.not# mask)
{-# INLINE not #-}

-- | Logical disjunction of two 'Choice' values (bitwise OR of masks).
--
--   >>> decide (or (true# ()) (false# ()))
--   True
or :: Choice -> Choice -> Choice
or (Choice lhs) (Choice rhs) = Choice (lhs `Exts.or#` rhs)
{-# INLINE or #-}

-- | Logical conjunction of two 'Choice' values (bitwise AND of masks).
--
--   >>> decide (and (true# ()) (false# ()))
--   False
and :: Choice -> Choice -> Choice
and (Choice lhs) (Choice rhs) = Choice (lhs `Exts.and#` rhs)
{-# INLINE and #-}

-- | Exclusive disjunction (logical inequality) of two 'Choice' values,
--   computed as the bitwise XOR of their masks.
--
--   >>> decide (xor (true# ()) (false# ()))
--   True
xor :: Choice -> Choice -> Choice
xor (Choice lhs) (Choice rhs) = Choice (lhs `Exts.xor#` rhs)
{-# INLINE xor #-}

-- | Logical inequality of two 'Choice' values; an alias for 'xor'.
--
--   >>> decide (ne (true# ()) (false# ()))
--   True
ne :: Choice -> Choice -> Choice
ne = xor
{-# INLINE ne #-}

-- | Logical equality of two 'Choice' values: the negation of their
--   exclusive disjunction.
--
--   >>> decide (eq (true# ()) (false# ()))
--   False
eq :: Choice -> Choice -> Choice
eq lhs rhs = not (xor lhs rhs)
{-# INLINE eq #-}

-- constant-time selection ----------------------------------------------------

-- | Select an unboxed word without branching: the first argument when
--   the 'Choice' is falsy, the second when it is truthy.
--
--   >>> W# (select_word# 0## 1## (true# ()))
--   1
--   >>> W# (select_word# 0## 1## (false# ()))
--   0
select_word# :: Word# -> Word# -> Choice -> Word#
select_word# a b (Choice c) =
  let -- bits where a and b differ, kept only if the mask is all ones
      !delta  = a `Exts.xor#` b
      !picked = c `Exts.and#` delta
  in  a `Exts.xor#` picked
{-# INLINE select_word# #-}

-- | Select an unboxed two-limb word without branching: the first
--   argument when the 'Choice' is falsy, the second when it is truthy.
select_wide#
  :: Limb2
  -> Limb2
  -> Choice
  -> Limb2
select_wide# a b (Choice w) =
  let -- replicate the choice mask into both limbs
      !mask  = or_w# (lo# w) (hi# w)
      !delta = xor_w# a b
  in  xor_w# a (and_w# delta mask)
{-# INLINE select_wide# #-}

-- | Select an unboxed four-limb word without branching: the first
--   argument when the 'Choice' is falsy, the second when it is truthy.
--   Each limb is selected independently with the same mask.
select_wider#
  :: Limb4
  -> Limb4
  -> Choice
  -> Limb4
select_wider# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #) (Choice w) =
    (# pick a0 b0, pick a1 b1, pick a2 b2, pick a3 b3 #)
  where
    -- per-limb masked select: x when w is 0##, y when w is all ones
    pick :: Word# -> Word# -> Word#
    pick x y = x `Exts.xor#` (w `Exts.and#` (x `Exts.xor#` y))
{-# INLINE select_wider# #-}

-- constant-time equality -----------------------------------------------------

-- | Compare unboxed words for equality in constant time.
--
--   The result is a full-word mask like every other 'Choice', so it
--   composes correctly with 'not', 'and', etc. and with the selection
--   functions.
--
--   >>> decide (eq_word# 0## 1##)
--   False
--   >>> decide (eq_word# 1## 1##)
--   True
--   >>> decide (not (eq_word# 1## 1##))
--   False
eq_word# :: Word# -> Word# -> Choice
eq_word# a b =
  let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
      !x = Exts.xor# a b
      -- 1## iff x is nonzero (i.e. a /= b), 0## otherwise
      !y = Exts.uncheckedShiftRL# (Exts.or# x (neg_w# x)) s
      -- flip to "1## iff equal", then widen the bit to a full mask.
      -- Wrapping the raw bit in 'Choice' would give 1## instead of all
      -- ones, so 'not' would stay truthy and selection would only take
      -- the low bit.
  in  from_bit# (Exts.xor# y 1##)
{-# INLINE eq_word# #-}

-- | Compare unboxed two-limb words for equality in constant time.
--
--   The result is a full-word mask like every other 'Choice', so it
--   composes correctly with 'not' and the selection functions.
--
--   >>> decide (eq_wide# (# 0##, 0## #) (# 0##, 0## #))
--   True
--   >>> decide (eq_wide# (# 0##, 1## #) (# 0##, 0## #))
--   False
eq_wide#
  :: Limb2
  -> Limb2
  -> Choice
eq_wide# (# a0, a1 #) (# b0, b1 #) =
  let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
      -- nonzero iff any limb differs
      !x = Exts.or# (Exts.xor# a0 b0) (Exts.xor# a1 b1)
      -- 1## iff x is nonzero, 0## otherwise
      !y = Exts.uncheckedShiftRL# (Exts.or# x (neg_w# x)) s
      -- flip to "1## iff equal" and widen the bit to a full mask
  in  from_bit# (Exts.xor# y 1##)
{-# INLINE eq_wide# #-}

-- | Compare unboxed four-limb words for equality in constant time.
--
--   The result is a full-word mask like every other 'Choice', so it
--   composes correctly with 'not' and the selection functions.
--
--   >>> let zero = (# 0##, 0##, 0##, 0## #) in decide (eq_wider# zero zero)
--   True
eq_wider#
  :: Limb4
  -> Limb4
  -> Choice
eq_wider# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #) =
  let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
      -- nonzero iff any limb differs
      !x = Exts.or# (Exts.or# (Exts.xor# a0 b0) (Exts.xor# a1 b1))
                    (Exts.or# (Exts.xor# a2 b2) (Exts.xor# a3 b3))
      -- 1## iff x is nonzero, 0## otherwise
      !y = Exts.uncheckedShiftRL# (Exts.or# x (neg_w# x)) s
      -- flip to "1## iff equal" and widen the bit to a full mask
  in  from_bit# (Exts.xor# y 1##)
{-# INLINE eq_wider# #-}