module SSH.Sender where
import Control.Concurrent.Chan (Chan, readChan)
import Control.Monad (replicateM)
import Data.Word (Word8, Word32)
import System.IO (Handle, hFlush)
import System.Random (randomRIO)
import qualified Codec.Crypto.SimpleAES as A
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import SSH.Debug (dump)
import SSH.Crypto (Cipher(..), HMAC(..), CipherType(..), CipherMode(..), fromBlocks, toBlocks)
import SSH.Packet (Packet, doPacket, raw, byte, long)
import SSH.Util (strictLBS)
-- | State carried from one iteration of the 'sender' loop to the next.
data SenderState
  -- | Initial state: no cipher, key or MAC has been supplied yet, so every
  -- packet is written as-is.
  = NoKeys
      { senderThem   :: Handle  -- ^ handle the framed packets are written to and flushed on
      , senderOutSeq :: Word32  -- ^ incremented by one after every packet written
      }
  -- | State entered when a 'Prepare' message is handled. Packets are only
  -- encrypted once 'senderEncrypting' is 'True'.
  | GotKeys
      { senderThem       :: Handle         -- ^ handle the framed packets are written to and flushed on
      , senderOutSeq     :: Word32         -- ^ incremented after every packet; also prefixed to the MAC input
      , senderEncrypting :: Bool           -- ^ 'False' right after 'Prepare', 'True' after 'StartEncrypting'
      , senderCipher     :: Cipher         -- ^ cipher handed to @encrypt@; its block size drives the padding
      , senderKey        :: BS.ByteString  -- ^ key handed to @encrypt@
      , senderVector     :: BS.ByteString  -- ^ vector handed to @encrypt@; replaced after each encrypted packet
      , senderHMAC       :: HMAC           -- ^ its second component produces the bytes appended after the ciphertext
      }
-- | Requests understood by the 'sender' loop, delivered over its channel.
data SenderMessage
  = Prepare Cipher BS.ByteString BS.ByteString HMAC
    -- ^ Supply cipher, key, initial vector and HMAC (in that order). The
    -- loop moves to the 'GotKeys' state with encryption still switched off.
  | StartEncrypting
    -- ^ Switch encryption on for all subsequent 'Send' messages.
  | Send LBS.ByteString
    -- ^ Frame the given payload with length and padding and write it out.
  | Stop
    -- ^ Terminate the loop.
-- | Type constructors in which a 'SenderMessage' can be issued.
class Sender a where
  -- | Issue a single message.
  send :: SenderMessage -> a ()

  -- | Build a packet with 'doPacket' and issue the result as a 'Send'
  -- message. Instances may override this; the default goes through 'send'.
  sendPacket :: Packet () -> a ()
  sendPacket packet = send (Send (doPacket packet))
-- | Run the outgoing side of a connection.
--
-- Reads one 'SenderMessage' at a time from @ms@ and acts on it, threading
-- the current 'SenderState' (@ss@) through the recursion:
--
-- * 'Stop' ends the loop.
--
-- * 'Prepare' moves to a 'GotKeys' state with encryption still off, keeping
--   the current handle and sequence number.
--
-- * 'StartEncrypting' switches encryption on.
--
-- * 'Send' frames the payload (length, padding length, payload, random
--   padding), writes it to the handle, flushes, and increments the sequence
--   number. Once encryption is on, the frame is encrypted and followed by a
--   MAC computed over the sequence number and the unencrypted frame.
sender :: Chan SenderMessage -> SenderState -> IO ()
sender ms ss = do
    m <- readChan ms
    case m of
        Stop -> return ()

        Prepare cipher key iv hmac -> do
            -- NOTE(review): this hands the key and IV to the debug dump;
            -- confirm 'dump' is inert outside debugging builds.
            dump ("initiating encryption", key, iv)
            sender ms (GotKeys (senderThem ss) (senderOutSeq ss) False cipher key iv hmac)

        StartEncrypting -> do
            dump ("starting encryption")
            -- NOTE(review): 'senderEncrypting' is a field of 'GotKeys' only,
            -- so this update fails at runtime if no 'Prepare' came first.
            sender ms (ss { senderEncrypting = True })

        Send msg -> do
            -- Padding bytes are drawn at random, one 'randomRIO' per byte.
            pad <- fmap (LBS.pack . map fromIntegral) $
                replicateM (fromIntegral $ paddingLen msg) (randomRIO (0, 255 :: Int))

            let f = full msg pad

            case ss of
                GotKeys h os True cipher key iv (HMAC _ mac) -> do
                    dump ("sending encrypted", os, f)

                    -- 'encrypt' also yields the vector to use for the next
                    -- packet, which is stored back into the state below.
                    let (encrypted, newVector) = encrypt cipher key iv f

                    LBS.hPut h . LBS.concat $
                        [ encrypted
                        -- MAC input: sequence number followed by the
                        -- unencrypted frame.
                        , mac . doPacket $ long os >> raw f
                        ]
                    hFlush h

                    sender ms $ ss
                        { senderOutSeq = senderOutSeq ss + 1
                        , senderVector = newVector
                        }

                -- 'NoKeys', or 'GotKeys' with encryption not yet started.
                _ -> do
                    dump ("sending unencrypted", senderOutSeq ss, f)
                    LBS.hPut (senderThem ss) f
                    hFlush (senderThem ss)
                    sender ms (ss { senderOutSeq = senderOutSeq ss + 1 })
  where
    -- Alignment unit for framing: the cipher's block size when keys are
    -- present and it exceeds 8, otherwise 8.
    blockSize =
        case ss of
            GotKeys { senderCipher = Cipher _ _ bs _ }
                | bs > 8 -> bs
            _ -> 8

    -- The complete frame: length, padding length, payload, padding.
    full msg pad = doPacket $ do
        long (len msg)
        byte (paddingLen msg)
        raw msg
        raw pad

    -- Value of the length field: the padding-length byte, the payload and
    -- the padding (the length field itself is not counted).
    len :: LBS.ByteString -> Word32
    len msg = 1 + fromIntegral (LBS.length msg) + fromIntegral (paddingLen msg)

    -- Bytes required to bring the frame up to the next multiple of
    -- 'blockSize'. The 5 accounts for the header written by 'long' and
    -- 'byte' in 'full' (presumably 4 + 1 bytes). The result lies in
    -- 1..blockSize: an already aligned frame gets a whole extra block.
    paddingNeeded :: LBS.ByteString -> Word8
    paddingNeeded msg =
        fromIntegral blockSize
            - fromIntegral ((5 + LBS.length msg) `mod` fromIntegral blockSize)

    -- Actual padding length: never fewer than 4 bytes; when the minimum
    -- alignment padding is shorter, one more block is added.
    paddingLen :: LBS.ByteString -> Word8
    paddingLen msg =
        if paddingNeeded msg < 4
            then paddingNeeded msg + fromIntegral blockSize
            else paddingNeeded msg
-- | Encrypt a message and compute the vector for the next call.
--
-- Arguments are the cipher description, the key, the current vector and the
-- plaintext. The result pairs the ciphertext with the next vector, which is
-- the last ciphertext block. If encryption produced no blocks at all, the
-- incoming vector is returned unchanged, since nothing was chained.
--
-- Only the @Cipher AES CBC@ shape is handled here; any other cipher value
-- falls through to a pattern-match failure.
encrypt :: Cipher -> BS.ByteString -> BS.ByteString -> LBS.ByteString -> (LBS.ByteString, BS.ByteString)
encrypt (Cipher AES CBC bs _) key vector m =
    ( fromBlocks encrypted
    , nextVector
    )
  where
    -- Ciphertext split into pieces of the cipher's block size.
    encrypted = toBlocks bs $ A.crypt A.CBC key vector A.Encrypt m

    -- 'last' is safe in the second branch: the list is known non-empty.
    nextVector =
        case encrypted of
            [] -> vector
            _  -> strictLBS (last encrypted)