{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ViewPatterns #-}

module Sel.Internal where

import Control.Monad (when)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (lift)
import Data.Base16.Types qualified as Base16
import Data.ByteString (StrictByteString)
import Data.ByteString.Base16 qualified as Base16
import Data.ByteString.Internal (memcmp)
import Data.ByteString.Internal qualified as ByteString
import Data.Coerce (coerce)
import Data.Kind (Type)
import Foreign (ForeignPtr, Ptr)
import Foreign qualified
import Foreign.C (CSize, CUChar, throwErrno)
import Foreign.C.Types (CChar)
import LibSodium.Bindings.Comparison (sodiumCompare, sodiumMemcmp)
import LibSodium.Bindings.SecureMemory (finalizerSodiumFree, sodiumFree, sodiumMalloc)
import System.IO.Unsafe (unsafeDupablePerformIO)

import Sel.Internal.Scoped
import Sel.Internal.Scoped.Foreign

-- | Compare the contents of two byte arrays for equality in constant time.
--
-- Both arrays are read for @size@ bytes; the result is 'True' when
-- @sodium_memcmp@ reports @0@.
--
-- The comparison is forced while both pointers are still in scope:
-- 'sodiumMemcmp' is a pure binding, so returning its result lazily would
-- let the foreign call run after the scope that keeps the 'ForeignPtr's
-- alive has ended.
--
-- /See:/ [Constant-time test for equality](https://doc.libsodium.org/helpers#constant-time-test-for-equality)
--
-- @since 0.0.3.0
foreignPtrEqConstantTime :: ForeignPtr CUChar -> ForeignPtr CUChar -> CSize -> Bool
foreignPtrEqConstantTime p q size =
  unsafeDupablePerformIO . useM $
    compareWithin <$> foreignPtr p <*> foreignPtr q
  where
    -- Evaluate the foreign comparison to a 'Bool' before leaving the
    -- scoped region, so no thunk referring to the raw pointers escapes.
    compareWithin :: Ptr CUChar -> Ptr CUChar -> IO Bool
    compareWithin pPtr qPtr = pure $! sodiumMemcmp pPtr qPtr size == 0

-- | Order two byte arrays by comparing their first @size@ bytes with
-- @memcmp@.
--
-- ⚠️ This comparison is not constant-time. It is vulnerable to timing
-- attacks and should be avoided for secret data.
--
-- @since 0.0.1.0
foreignPtrOrd :: ForeignPtr CUChar -> ForeignPtr CUChar -> CSize -> Ordering
foreignPtrOrd p q size = unsafeDupablePerformIO $ do
  outcome <- useM $ do
    lhs <- foreignPtr (coerce p)
    rhs <- foreignPtr (coerce q)
    pure (memcmp lhs rhs byteCount)
  -- memcmp's sign convention maps directly onto 'Ordering' relative to 0.
  pure (outcome `compare` 0)
  where
    byteCount :: Int
    byteCount = fromIntegral size

-- | Order two byte arrays of length @size@ in constant time, using
-- @sodium_compare@.
--
-- The result is the 'Ordering' of @sodium_compare@'s return value relative
-- to @0@.
--
-- NOTE(review): libsodium documents @sodium_compare@ as comparing numbers
-- encoded in little-endian format, so the resulting order may differ from
-- the byte-wise order produced by 'foreignPtrOrd' — confirm before relying
-- on the two agreeing.
--
-- /See:/ [Comparing large numbers](https://libsodium.gitbook.io/doc/helpers#comparing-large-numbers)
--
-- @since 0.0.3.0
foreignPtrOrdConstantTime :: ForeignPtr CUChar -> ForeignPtr CUChar -> CSize -> Ordering
foreignPtrOrdConstantTime p q size =
  unsafeDupablePerformIO (toOrdering <$> useM scopedComparison)
  where
    -- The foreign call is built inside the scope that provides both pointers.
    scopedComparison =
      (\lhs rhs -> sodiumCompare lhs rhs size)
        <$> foreignPtr p
        <*> foreignPtr q

    toOrdering result = result `compare` 0

-- | Test whether the first @size@ bytes of two byte arrays are equal,
-- i.e. whether 'foreignPtrOrd' yields 'EQ' for them.
--
-- ⚠️ This comparison is not constant-time. It is vulnerable to timing
-- attacks and should be avoided for secret data.
--
-- @since 0.0.1.0
foreignPtrEq :: ForeignPtr CUChar -> ForeignPtr CUChar -> CSize -> Bool
foreignPtrEq p q size =
  case foreignPtrOrd p q size of
    EQ -> True
    _ -> False

-- | Render the first @size@ bytes behind a @'ForeignPtr' a@ as a
-- hexadecimal (Base16) 'String'.
--
-- The pointer is viewed as a 'ByteString' of the given length starting at
-- offset 0, which is then Base16-encoded and unpacked into characters.
--
-- @since 0.0.1.0
foreignPtrShow :: ForeignPtr a -> CSize -> String
foreignPtrShow fptr size = ByteString.unpackChars hexBytes
  where
    rawBytes =
      ByteString.fromForeignPtr
        (Foreign.castForeignPtr fptr)
        0
        (fromIntegral @CSize @Int size)

    hexBytes = Base16.extractBase16 (Base16.encodeBase16' rawBytes)

-- | Copy @size@ bytes of a byte array into freshly allocated @libsodium@
-- memory (see 'sodiumPointer').
--
-- The length of the input is not checked against @size@. An input longer
-- than @size@ is truncated. An input shorter than @size@ is still read for
-- @size@ bytes; the original documentation states that an unchecked
-- exception may be thrown in that case.
--
-- @since 0.0.3.0
unsafeCopyToSodiumPointer :: CSize -> StrictByteString -> IO (ForeignPtr CUChar)
unsafeCopyToSodiumPointer size s = use $ do
  source <- unsafeCString s
  lift (sodiumPointer size (fillFrom source))
  where
    byteCount :: Int
    byteCount = fromIntegral size

    -- Populate the secure allocation from the input's C string view.
    fillFrom :: Ptr CChar -> Ptr CUChar -> IO ()
    fillFrom source destination =
      Foreign.copyArray (Foreign.castPtr destination) source byteCount

-- | Allocate secure memory and populate it with the provided action.
--
-- Memory is allocated with 'LibSodium.Bindings.SecureMemory.sodiumMalloc'.
-- If the allocation returns a null pointer, an 'IOError' is raised via
-- 'throwErrno'.
--
-- The returned 'ForeignPtr' carries 'finalizerSodiumFree' as its finalizer.
-- The finalizer is attached /before/ the action runs, so the allocation is
-- still released if the action throws an exception.
--
-- @since 0.0.3.0
sodiumPointer :: CSize -> (Ptr CUChar -> IO ()) -> IO (ForeignPtr CUChar)
sodiumPointer size action = do
  ptr <- sodiumMalloc size
  when (ptr == Foreign.nullPtr) $
    throwErrno "sodium_malloc"
  -- Hand ownership to the garbage collector first; running the action on the
  -- raw pointer beforehand would leak the region if the action failed.
  fptr <- Foreign.newForeignPtr finalizerSodiumFree ptr
  Foreign.withForeignPtr fptr action
  pure fptr

-- | Securely allocate an amount of memory with 'sodiumMalloc' and pass
-- a pointer to the region to the provided action.
-- The region is deallocated with 'sodiumFree' once the action has returned.
--
-- If the allocation returns a null pointer, an 'IOError' is raised via
-- 'throwErrno' and the action is not run.
--
-- The pointer must not be used outside of the action: the region it refers
-- to has been freed by the time 'allocateWith' returns.
--
-- Note that only 'MonadIO' is available here, so no exception handler is
-- installed: if the action throws, 'sodiumFree' is not called by this
-- function.
allocateWith
  :: forall (a :: Type) (b :: Type) (m :: Type -> Type)
   . MonadIO m
  => CSize
  -- ^ Amount of memory to allocate
  -> (Ptr a -> m b)
  -- ^ Action to perform on the memory
  -> m b
allocateWith size action = do
  !ptr <- liftIO $ do
    allocated <- sodiumMalloc size
    -- Mirror 'sodiumPointer': never hand a null pointer to the action.
    when (allocated == Foreign.nullPtr) $
      throwErrno "sodium_malloc"
    pure allocated
  -- Force the result before freeing so it cannot lazily depend on the region.
  !result <- action ptr
  liftIO $ sodiumFree ptr
  pure result