{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedStrings #-}

-- |
-- SPDX-License-Identifier: BSD-3-Clause
--
-- Arithmetic and Comparison commands
module Swarm.Game.Step.Arithmetic where

import Control.Carrier.State.Lazy
import Control.Effect.Error
import Control.Monad (zipWithM)
import Data.Function (on)
import Data.Map qualified as M
import Data.Text qualified as T
import Swarm.Game.Exception
import Swarm.Game.Step.Util
import Swarm.Language.Syntax
import Swarm.Language.Value
import Witch (From (from))
import Prelude hiding (lookup)

------------------------------------------------------------
-- Comparison
------------------------------------------------------------

-- | Evaluate the application of a comparison operator to two values.
--
--   Supported operators are 'Eq', 'Neq', 'Lt', 'Gt', 'Leq' and 'Geq'.
--   'compareValues' orders the two values, and any error it throws
--   propagates unchanged.  For any other 'Const' this throws a 'Fatal'
--   exception without comparing the values at all.
evalCmp :: Has (Throw Exn) sig m => Const -> Value -> Value -> m Bool
evalCmp c v1 v2 = case orderingTest c of
  Just accepts -> accepts <$> compareValues v1 v2
  Nothing ->
    throwError . Fatal $ "evalCmp called on bad constant " <> from (show c)
 where
  -- The set of 'compare' results for which the operator holds, or
  -- 'Nothing' if the constant is not a comparison operator.
  orderingTest = \case
    Eq -> Just (== EQ)
    Neq -> Just (/= EQ)
    Lt -> Just (== LT)
    Gt -> Just (== GT)
    Leq -> Just (/= GT)
    Geq -> Just (/= LT)
    _ -> Nothing

-- | Compare two values, producing an 'Ordering'.
--
--   * Base values ('VUnit', 'VInt', 'VText', 'VDir', 'VBool', 'VRobot',
--     'VKey') use the 'Ord' instance of their payload.
--   * Sums ('VInj') are ordered by their 'Bool' tag first.  Their payloads
--     are compared only when the tags agree.
--   * Pairs ('VPair') are ordered lexicographically.  Both components are
--     always compared, so an error from the second component is raised
--     even when the first component already decides the result.
--   * Records ('VRcd') are ordered lexicographically by their field values,
--     taken in ascending key order ('M.elems').  NOTE(review): 'zipWithM'
--     stops at the shorter field list.  This presumably relies on the type
--     checker guaranteeing that both records have the same fields.
--
--   If the two values have different shapes, 'incompatCmp' throws.  If the
--   first value is a closure, reference, exception or another constructor
--   without an ordering, 'incomparable' throws.
compareValues :: Has (Throw Exn) sig m => Value -> Value -> m Ordering
compareValues v1 v2 = case v1 of
  VUnit -> case v2 of
    VUnit -> pure EQ
    _ -> mismatch
  VInt n1 -> case v2 of
    VInt n2 -> pure (compare n1 n2)
    _ -> mismatch
  VText t1 -> case v2 of
    VText t2 -> pure (compare t1 t2)
    _ -> mismatch
  VDir d1 -> case v2 of
    VDir d2 -> pure (compare d1 d2)
    _ -> mismatch
  VBool b1 -> case v2 of
    VBool b2 -> pure (compare b1 b2)
    _ -> mismatch
  VRobot r1 -> case v2 of
    VRobot r2 -> pure (compare r1 r2)
    _ -> mismatch
  VInj tag1 payload1 -> case v2 of
    VInj tag2 payload2
      | tag1 == tag2 -> compareValues payload1 payload2
      | otherwise -> pure (compare tag1 tag2)
    _ -> mismatch
  VPair fst1 snd1 -> case v2 of
    VPair fst2 snd2 -> do
      fstOrd <- compareValues fst1 fst2
      sndOrd <- compareValues snd1 snd2
      pure (fstOrd <> sndOrd)
    _ -> mismatch
  VRcd fields1 -> case v2 of
    VRcd fields2 -> do
      fieldOrds <- (zipWithM compareValues `on` M.elems) fields1 fields2
      pure (mconcat fieldOrds)
    _ -> mismatch
  VKey k1 -> case v2 of
    VKey k2 -> pure (compare k1 k2)
    _ -> mismatch
  VClo {} -> unordered
  VCApp {} -> unordered
  VBind {} -> unordered
  VDelay {} -> unordered
  VRef {} -> unordered
  VIndir {} -> unordered
  VRequirements {} -> unordered
  VSuspend {} -> unordered
  VExc {} -> unordered
  VBlackhole {} -> unordered
  VType {} -> unordered
 where
  -- Shared failure paths for the alternatives above.
  mismatch = incompatCmp v1 v2
  unordered = incomparable v1 v2

-- | Report a comparison between values of two different types.
--
--   This should be unreachable, because the type system is expected to
--   reject such comparisons.  Reaching it therefore indicates an internal
--   error, and it is reported as a 'Fatal' exception.
incompatCmp :: Has (Throw Exn) sig m => Value -> Value -> m a
incompatCmp v1 v2 =
  throwError $
    Fatal $
      -- No trailing spaces in the pieces: 'T.unwords' inserts exactly one
      -- space between them.  The old "of " literal produced "of  <v1>".
      T.unwords ["Incompatible comparison of", prettyValue v1, "and", prettyValue v2]

-- | Report a comparison involving values whose type has no ordering
--   (for example closures).  The failure is raised through 'cmdExn'.
--
--   NOTE(review): the exception is always attributed to 'Lt', whichever
--   comparison operator was actually applied.
incomparable :: Has (Throw Exn) sig m => Value -> Value -> m a
incomparable v1 v2 = throwError (cmdExn Lt details)
 where
  details = ["Comparison is undefined for ", prettyValue v1, "and", prettyValue v2]

------------------------------------------------------------
-- Arithmetic
------------------------------------------------------------

-- | Evaluate the application of an arithmetic operator to two integers.
--
--   'Add', 'Sub' and 'Mul' always succeed.  'Div' and 'Exp' delegate to
--   'safeDiv' and 'safeExp', which throw on division by zero and on a
--   negative exponent respectively.  Any other 'Const' means the library
--   misused this function, and a 'Fatal' exception is thrown.
evalArith :: Has (Throw Exn) sig m => Const -> Integer -> Integer -> m Integer
evalArith c x y = case c of
  Add -> pure (x + y)
  Sub -> pure (x - y)
  Mul -> pure (x * y)
  Div -> safeDiv x y
  Exp -> safeExp x y
  _ ->
    throwError . Fatal $
      T.append "evalArith called on bad constant " (from (show c))

-- | Integer division.  When the divisor is zero, it throws a 'Div'
--   command exception ("Division by zero") instead of crashing.
--   The quotient is computed with Haskell's 'div', so it rounds toward
--   negative infinity.
safeDiv :: Has (Throw Exn) sig m => Integer -> Integer -> m Integer
safeDiv dividend divisor
  | divisor == 0 = throwError $ cmdExn Div ["Division by zero"]
  | otherwise = pure (dividend `div` divisor)

-- | Integer exponentiation.  When the exponent is negative, it throws an
--   'Exp' command exception ("Negative exponent") instead of crashing,
--   since '^' is undefined for negative powers.
safeExp :: Has (Throw Exn) sig m => Integer -> Integer -> m Integer
safeExp base power
  | power >= 0 = pure (base ^ power)
  | otherwise = throwError $ cmdExn Exp ["Negative exponent"]