{-# LANGUAGE RecordWildCards #-}

module Network.DNS.Memo where

import qualified Control.Reaper as R
import qualified Data.ByteString as B
import qualified Data.CaseInsensitive as CI
import Data.Hourglass (Elapsed)
import Data.OrdPSQ (OrdPSQ)
import qualified Data.OrdPSQ as PSQ
import Time.System (timeCurrent)

import Network.DNS.Imports
import Network.DNS.Types.Internal

-- | A DNS message section; only the answer and authority sections are
-- represented.
--
-- NOTE(review): not referenced in the visible part of this module —
-- confirm where it is used.
data Section = Answer | Authority deriving (Eq, Ord, Show)

-- | Cache key: a domain name paired with the queried record type.
type Key = (ByteString, TYPE)
-- | Queue priority of an entry: the time at or after which 'prune' drops it.
type Prio = Elapsed

-- | Cached value: either the error a lookup produced or its resource data.
type Entry = Either DNSError [RData]

-- | The cache store: a priority search queue keyed by 'Key' and ordered
-- by expiry time ('Prio').
type DB = OrdPSQ Key Prio Entry

-- | A reaper whose workload is a 'DB'; items are added to it as
-- @(key, expiry, entry)@ triples.
type Cache = R.Reaper DB (Key, Prio, Entry)

-- | Create an empty cache whose reaper runs 'prune' periodically.
--
-- The argument is the interval between pruning passes, in seconds.
-- auto-update's 'R.reaperDelay' is expressed in microseconds, hence the
-- multiplication below.
newCache :: Int -> IO Cache
newCache delay = R.mkReaper R.defaultReaperSettings
    { R.reaperEmpty  = PSQ.empty
      -- Adding an item inserts (or replaces) its key in the queue.
    , R.reaperCons   = \(k, tim, v) psq -> PSQ.insert k tim v psq
    , R.reaperAction = prune
    , R.reaperDelay  = delay * 1000000
    , R.reaperNull   = PSQ.null
    }

-- | Look up a key in the cache's current workload.
--
-- Returns the entry's priority (expiry time) together with the entry, or
-- 'Nothing' when the key is absent. Expiry is not checked here; expired
-- entries disappear only when 'prune' runs.
lookupCache :: Key -> Cache -> IO (Maybe (Prio, Entry))
lookupCache key reaper = PSQ.lookup key <$> R.reaperRead reaper

-- | Add an entry to the cache with the given priority (expiry time).
--
-- The domain name and every 'ByteString' in the record data are copied
-- with 'B.copy' (via 'copy'). Each stored value then gets its own buffer
-- instead of keeping alive a larger buffer it may have been sliced from.
-- Error entries are stored unchanged.
insertCache :: Key -> Prio -> Entry -> Cache -> IO ()
insertCache (dom, typ) tim ent0 reaper = R.reaperAdd reaper (key, tim, ent)
  where
    key :: Key
    key = (B.copy dom, typ)

    ent :: Entry
    ent = map copy <$> ent0

-- | Reaper action that removes expired entries.
--
-- Every entry in @oldpsq@ whose priority (expiry time) is at or before
-- the current time is dropped.
--
-- In principle, 'PSQ.atMostView' alone would be enough for pruning.
-- However, auto-update's reaper is designed around list-based workloads,
-- which have no @atMost@-style operation. The action therefore receives
-- the workload and must return a function, which the reaper applies to the
-- items accumulated while the action ran. That function re-inserts those
-- newer entries into the pruned queue, which makes this approach somewhat
-- redundant.
prune :: DB -> IO (DB -> DB)
prune oldpsq = do
    now <- timeCurrent
    let (_expired, pruned) = PSQ.atMostView now oldpsq
    -- 'PSQ.insert' replaces an existing binding, so entries from the newer
    -- queue take precedence over pruned ones with the same key.
    pure $ \newpsq -> PSQ.fold' PSQ.insert pruned newpsq

-- | Copy every 'ByteString' payload of a resource record with 'B.copy'.
--
-- Constructors that hold no 'ByteString' ('RD_A', 'RD_AAAA') are returned
-- as-is. Non-'ByteString' fields (numbers, 'TYPE' lists) are reused.
-- Every constructor is matched explicitly, with no wildcard, so the
-- compiler reports any constructor added to 'RData' later.
copy :: RData -> RData
copy r@(RD_A _)                 = r
copy (RD_NS dom)                = RD_NS $ B.copy dom
copy (RD_CNAME dom)             = RD_CNAME $ B.copy dom
copy (RD_SOA mn mr a b c d e)   = RD_SOA (B.copy mn) (B.copy mr) a b c d e
copy (RD_PTR dom)               = RD_PTR $ B.copy dom
copy (RD_NULL bytes)            = RD_NULL $ B.copy bytes
copy (RD_MX prf dom)            = RD_MX prf $ B.copy dom
copy (RD_TXT txt)               = RD_TXT $ B.copy txt
copy (RD_RP mbox dname)         = RD_RP (B.copy mbox) (B.copy dname)
copy r@(RD_AAAA _)              = r
copy (RD_SRV a b c dom)         = RD_SRV a b c $ B.copy dom
copy (RD_DNAME dom)             = RD_DNAME $ B.copy dom
copy (RD_OPT od)                = RD_OPT $ map copyOData od
copy (RD_DS t a dt dv)          = RD_DS t a dt $ B.copy dv
copy (RD_CDS t a dt dv)         = RD_CDS t a dt $ B.copy dv
copy (RD_NSEC dom ts)           = RD_NSEC (B.copy dom) ts
copy (RD_DNSKEY f p a k)        = RD_DNSKEY f p a $ B.copy k
copy (RD_CDNSKEY f p a k)       = RD_CDNSKEY f p a $ B.copy k
copy (RD_TLSA a b c dgst)       = RD_TLSA a b c $ B.copy dgst
copy (RD_NSEC3 a b c s h t)     = RD_NSEC3 a b c (B.copy s) (B.copy h) t
copy (RD_NSEC3PARAM a b c salt) = RD_NSEC3PARAM a b c $ B.copy salt
copy (RD_RRSIG sig)             = RD_RRSIG $ copysig sig
  where
    -- Only the zone name and the signature bytes are 'ByteString's;
    -- every other field is kept by the record update.
    copysig :: RD_RRSIG -> RD_RRSIG
    copysig s@RDREP_RRSIG{..} =
        s { rrsigZone  = B.copy rrsigZone
          , rrsigValue = B.copy rrsigValue
          }
-- Rebuild the case-insensitive tag from a copy of its original text.
copy (RD_CAA f t v)             =
    RD_CAA f (CI.mk (B.copy (CI.original t))) (B.copy v)
copy (UnknownRData is)          = UnknownRData $ B.copy is

-- | Copy the 'ByteString' payload of an EDNS option with 'B.copy'.
copyOData :: OData -> OData
copyOData (OD_ECSgeneric family srcBits scpBits bs) =
    OD_ECSgeneric family srcBits scpBits $ B.copy bs
copyOData (OD_NSID nsid)     = OD_NSID $ B.copy nsid
copyOData (UnknownOData c b) = UnknownOData c $ B.copy b

-- The remaining options need no copying. They are still matched one by
-- one instead of with a wildcard, so that the compiler warns about an
-- incomplete pattern match if more option types are added in the future.
copyOData o@OD_ClientSubnet{} = o
copyOData o@OD_DAU{}          = o
copyOData o@OD_DHU{}          = o
copyOData o@OD_N3U{}          = o