{-# Language GADTs #-}
{-# Language NamedFieldPuns #-}
module EVM.Stepper
  ( Action (..)
  , Failure (..)
  , Stepper
  , exec
  , execFully
  , execFullyOrFail
  , decode
  , fail
  , wait
  , evm
  , note
  , entering
  , enter
  )
where
import Prelude hiding (fail)
import Control.Monad.Operational (Program, singleton)
import Data.Binary.Get (runGetOrFail)
import Data.Text (Text)
import EVM (EVM, VMResult (VMFailure, VMSuccess), Error (Query), Query)
import qualified EVM
import EVM.ABI (AbiType, AbiValue, getAbi)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Lazy as LazyByteString
-- | Instructions of the 'Stepper' program language. The type index of each
-- constructor is the value that instruction hands back to the rest of the
-- program.
data Action a where

  -- | Execute the VM, yielding a 'VMResult'. A @'VMFailure' ('Query' q)@
  -- result means the VM stopped on query @q@ ('execFully' waits on it and
  -- then executes again).
  Exec ::            Action VMResult

  -- | Fail with the given 'Failure'. The result type is unconstrained, which
  -- suggests interpreters do not resume after it — NOTE(review): confirm in
  -- the interpreter.
  Fail :: Failure -> Action a

  -- | Wait on an outstanding 'Query'; presumably the interpreter resolves the
  -- query before the program continues — TODO confirm.
  Wait :: Query   -> Action ()

  -- | Run an arbitrary 'EVM' action and yield its result.
  EVM  :: EVM a   -> Action a

  -- | Attach a free-form text note; what is done with it is left to the
  -- interpreter.
  Note :: Text    -> Action ()
-- | Reasons a stepper program can fail; carried by the 'Fail' instruction.
data Failure
  = ContractNotFound
    -- ^ Not raised in this module; presumably used by callers when a target
    -- contract is missing — TODO confirm.
  | DecodingError
    -- ^ Bytes could not be decoded as the expected ABI value (see 'decode').
  | VMFailed Error
    -- ^ The VM stopped with a non-query 'Error' (see 'execFullyOrFail').
  deriving Show
-- | A stepper is an operational-monad program over 'Action' instructions.
-- How each instruction is carried out is decided by an interpreter that is
-- not part of this module.
--
-- Eta-reduced so the synonym may be used unsaturated (for example as the
-- base monad of a transformer); @Stepper a@ keeps meaning exactly what it
-- did before.
type Stepper = Program Action
-- | Emit a single 'Exec' instruction. The result is whatever 'VMResult' the
-- interpreter reports, possibly a @'VMFailure' ('Query' _)@ that still needs
-- answering (see 'execFully').
exec :: Stepper VMResult
exec = singleton Exec
-- | Emit a 'Fail' instruction carrying the given 'Failure'. It is polymorphic
-- in its result so it can be used in any 'Stepper' context.
fail :: Failure -> Stepper a
fail failure = singleton (Fail failure)
-- | Emit a 'Wait' instruction for the given outstanding 'Query'.
wait :: Query -> Stepper ()
wait query = singleton (Wait query)
-- | Lift an 'EVM' action into the stepper as a single 'EVM' instruction.
evm :: EVM a -> Stepper a
evm action = singleton (EVM action)
-- | Emit a 'Note' instruction carrying a free-form text message.
note :: Text -> Stepper ()
note message = singleton (Note message)
-- | Keep executing the VM until it produces a final result.
--
-- Whenever an 'exec' step stops on @'VMFailure' ('Query' q)@, we 'wait' on
-- @q@ and then start executing again. Any other failure is returned as
-- 'Left'; a successful run returns its output bytes as 'Right'.
execFully :: Stepper (Either Error ByteString)
execFully = do
  result <- exec
  case result of
    VMFailure (Query pending) -> do
      -- Execution paused on an outstanding query: wait on it, then resume.
      wait pending
      execFully
    VMFailure err -> pure (Left err)
    VMSuccess output -> pure (Right output)
-- | Like 'execFully', but a VM error is raised as a 'VMFailed' stepper
-- failure rather than returned; on success the output bytes are returned.
execFullyOrFail :: Stepper ByteString
execFullyOrFail = do
  outcome <- execFully
  case outcome of
    Left err -> fail (VMFailed err)
    Right output -> pure output
-- | Decode @bytes@ as a single ABI value of type @abiType@.
--
-- Fails with 'DecodingError' if the 'getAbi' parser rejects the input, or if
-- any bytes remain unconsumed after the value has been read.
decode :: AbiType -> ByteString -> Stepper AbiValue
decode abiType bytes =
  case runGetOrFail (getAbi abiType) (LazyByteString.fromStrict bytes) of
    Right (leftover, _offset, value)
      | LazyByteString.null leftover -> pure value
    _ -> fail DecodingError
-- | Run @stepper@ inside a trace entry labelled @t@: push an 'EVM.EntryTrace'
-- first, run the stepper, then pop the trace, returning the stepper's result.
--
-- NOTE(review): if @stepper@ ends in 'fail', whether the pop still happens
-- depends on how the interpreter handles 'Fail' — confirm this is intended.
entering :: Text -> Stepper a -> Stepper a
entering t stepper =
  evm (EVM.pushTrace (EVM.EntryTrace t))
    *> stepper
    <* evm EVM.popTrace
-- | Push an 'EVM.EntryTrace' labelled with the given text. Unlike 'entering',
-- this does not pop the trace afterwards.
enter :: Text -> Stepper ()
enter = evm . EVM.pushTrace . EVM.EntryTrace