{-# OPTIONS_GHC -Wno-noncanonical-monad-instances #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase                 #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE StrictData                 #-}
{-# LANGUAGE TupleSections              #-}
{-# LANGUAGE UndecidableInstances       #-}

-- | Monad class for obtaining combined keys, optionally caching them.
module Tox.Crypto.Core.Keyed where

import           Control.Monad.IO.Class           (MonadIO)
import           Control.Monad.Random             (RandT)
import           Control.Monad.Reader             (ReaderT)
import           Control.Monad.RWS                (RWST)
import           Control.Monad.State              (MonadState, StateT (..),
                                                   evalStateT, gets, modify,
                                                   runStateT, state)
import           Control.Monad.Trans              (MonadTrans, lift)
import           Control.Monad.Writer             (MonadWriter, WriterT)

import           Data.Map                         (Map)
import qualified Data.Map                         as Map
import           Tox.Core.Timed                   (Timed)
import qualified Tox.Crypto.Core.CombinedKey      as CombinedKey
import           Tox.Crypto.Core.Key              (CombinedKey, PublicKey,
                                                   SecretKey)
import           Tox.Crypto.Core.MonadRandomBytes (MonadRandomBytes)

-- | Monads that can produce the 'CombinedKey' for a secret/public key pair.
-- Instances either compute it on every call ('NullKeyed') or look it up in a
-- cache before computing it ('KeyedT').
class (Monad m, Applicative m) => Keyed m where
  -- | Obtain the combined key for the given key pair.
  getCombinedKey
    :: SecretKey      -- ^ the secret half of the pair
    -> PublicKey      -- ^ the public half of the pair
    -> m CombinedKey

-- | Delegates to the underlying monad and lifts the result through 'ReaderT'.
instance Keyed m => Keyed (ReaderT r m) where
  getCombinedKey sk pk = lift (getCombinedKey sk pk)
-- | Delegates to the underlying monad and lifts the result through 'WriterT'.
instance (Monoid w, Keyed m) => Keyed (WriterT w m) where
  getCombinedKey sk pk = lift (getCombinedKey sk pk)
-- | Delegates to the underlying monad and lifts the result through 'StateT'.
instance Keyed m => Keyed (StateT s m) where
  getCombinedKey sk pk = lift (getCombinedKey sk pk)
-- | Delegates to the underlying monad and lifts the result through 'RWST'.
instance (Monoid w, Keyed m) => Keyed (RWST r w s m) where
  getCombinedKey sk pk = lift (getCombinedKey sk pk)
-- | Delegates to the underlying monad and lifts the result through 'RandT'.
instance Keyed m => Keyed (RandT s m) where
  getCombinedKey sk pk = lift (getCombinedKey sk pk)

-- | Trivial instance: an identity-like monad that does no caching of keys.
--
-- 'runNullKeyed' unwraps the final value.
newtype NullKeyed a = NullKeyed { runNullKeyed :: a }

instance Functor NullKeyed where
  fmap f (NullKeyed x) = NullKeyed (f x)

instance Applicative NullKeyed where
  pure = NullKeyed
  NullKeyed f <*> NullKeyed x = NullKeyed (f x)

-- 'return' is intentionally left undefined here so that it defaults to
-- 'pure'; defining it separately is the non-canonical form that
-- -Wnoncanonical-monad-instances warns about.
instance Monad NullKeyed where
  NullKeyed x >>= f = f x
-- | Calls 'CombinedKey.precompute' on every request; nothing is cached.
instance Keyed NullKeyed where
  getCombinedKey sk pk = NullKeyed (CombinedKey.precompute sk pk)

-- | Cache of combined keys, indexed by the (secret key, public key) pair
-- each one was computed from.
type KeyRing = Map.Map (SecretKey, PublicKey) CombinedKey

-- | Caches computations of combined keys in a 'KeyRing' threaded through an
-- internal 'StateT'. Makes no attempt to delete old keys; nothing in this
-- module removes entries from the ring.
--
-- All listed classes are obtained by newtype deriving from the wrapped
-- @StateT KeyRing m@.
newtype KeyedT m a = KeyedT (StateT KeyRing m a)
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadTrans
    , MonadIO
    , MonadWriter w
    , MonadRandomBytes
    , Timed
    )

-- | Run a 'KeyedT' computation starting from the given key ring, returning
-- the result together with the final key ring.
runKeyedT :: Monad m => KeyedT m a -> KeyRing -> m (a, KeyRing)
runKeyedT (KeyedT action) initialRing = runStateT action initialRing

-- | Run a 'KeyedT' computation starting from the given key ring and return
-- only the result, discarding the final key ring.
evalKeyedT :: Monad m => KeyedT m a -> KeyRing -> m a
evalKeyedT (KeyedT action) initialRing = evalStateT action initialRing

-- | Delegates 'MonadState' to the underlying monad, so state operations
-- inside a 'KeyedT' act on the caller's state rather than the 'KeyRing'.
instance (MonadState s m, Applicative m) => MonadState s (KeyedT m) where
  state transition = KeyedT (StateT keepRing)
    where
      -- Run the transition in the underlying monad and hand the key ring
      -- back unchanged alongside its result.
      keepRing ring = fmap (\result -> (result, ring)) (state transition)

-- | Looks the key pair up in the 'KeyRing' first. On a miss, computes the
-- combined key with 'CombinedKey.precompute', stores it under that pair, and
-- returns it.
instance (Monad m, Applicative m) => Keyed (KeyedT m) where
  getCombinedKey secretKey publicKey = KeyedT $ do
      cached <- gets (Map.lookup ringKey)
      case cached of
        Just shared -> return shared
        Nothing -> do
          let shared = CombinedKey.precompute secretKey publicKey
          modify (Map.insert ringKey shared)
          return shared
    where
      ringKey = (secretKey, publicKey)