{- |
Module      : Language.Egison.Type.TensorMapInsertion
Licence     : MIT

This module implements automatic tensorMap insertion for Phase 8 of the Egison compiler.
This is the first step of TypedDesugar, before type class expansion.
When a function expects a scalar type (e.g., Integer) but receives a Tensor type,
this module automatically inserts tensorMap to apply the function element-wise.

Two insertion modes:
1. Direct application: When argument is Tensor and parameter expects scalar,
   wrap the application with tensorMap.
2. Higher-order functions (simplified approach): When a binary function with
   constrained/scalar parameter types is passed as an argument, always wrap
   it with tensorMap2. This handles cases like `foldl1 (+) xs` where elements
   of xs might be Tensors at runtime.

According to tensor-map-insertion-simple.md:
- When a binary function with scalar parameter types is passed as an argument, always wrap it with tensorMap2
- tensorMap/tensorMap2 act as identity for scalar values, so wrapping is safe regardless of whether the actual argument is a tensor or scalar

Example:
  def f (x : Integer) : Integer := x
  def t1 := [| 1, 2 |]
  f t1  --=>  tensorMap (\t1e -> f t1e) t1

  def sum {Num a} (xs: [a]) : a := foldl1 (+) xs
  --=>  def sum {Num a} (xs: [a]) : a := foldl1 (tensorMap2 (+)) xs
-}

module Language.Egison.Type.TensorMapInsertion
  ( insertTensorMaps
  ) where

import           Data.List                  (nub)
import           Language.Egison.Data       (EvalM)
import           Language.Egison.EvalState  (MonadEval(..))
import           Language.Egison.IExpr      (TIExpr(..), TIExprNode(..),
                                             Var(..), tiExprType, tiScheme, tiExprNode)
import           Language.Egison.Type.Env   (ClassEnv)
import           Language.Egison.Type.Tensor ()
import           Language.Egison.Type.Types (Type(..), TypeScheme(..), Constraint(..), TyVar(..))
import           Language.Egison.Type.Unify as Unify (unifyStrictWithConstraints)

--------------------------------------------------------------------------------
-- * TensorMap Insertion Decision Logic
--------------------------------------------------------------------------------

-- | Decide whether an application argument has to be lifted with @tensorMap@
-- (see type-tensor-simple.md).
--
-- Lifting is required exactly when both of these hold:
--
-- 1. the declared parameter type is scalar, i.e. it does NOT unify with
--    @Tensor a@ under the constraints in scope ('isPotentialScalarType');
-- 2. the actual argument type DOES unify with @Tensor a@.
--
-- Arguments:
--   ClassEnv     : class environment consulted during constrained unification.
--   [Constraint] : type class constraints in scope (e.g. @Num a@, @Eq a@).
--   Type         : type of the argument being supplied.
--   Type         : declared type of the corresponding function parameter.
shouldInsertTensorMap :: ClassEnv -> [Constraint] -> Type -> Type -> Bool
shouldInsertTensorMap classEnv constraints argType paramType =
  paramIsScalar && argIsTensor
  where
    paramIsScalar :: Bool
    paramIsScalar = isPotentialScalarType classEnv constraints paramType

    -- Only the success or failure of unification matters; the substitution
    -- is thrown away, so a fixed placeholder element variable suffices.
    tensorPattern :: Type
    tensorPattern = TTensor (TVar (TyVar "a_arg_check"))

    argIsTensor :: Bool
    argIsTensor =
      either (const False) (const True)
        (Unify.unifyStrictWithConstraints classEnv constraints argType tensorPattern)


-- | Remove one layer of tensor lifting from a (curried) function type.
--
-- Each parameter of the form @Tensor p@ becomes @p@. The final result is
-- reached by recursing down the arrow chain; if it is @Tensor r@, it becomes @r@.
--
-- > Tensor a -> Tensor b -> Tensor c   ==>   a -> b -> c
--
-- Non-tensor parameters are kept as they are, and any other type is
-- returned unchanged.
unliftFunctionType :: Type -> Type
unliftFunctionType ty = case ty of
  TFun param rest -> TFun (stripTensor param) (unliftFunctionType rest)
  TTensor inner   -> inner
  _               -> ty
  where
    -- Peel a single Tensor wrapper off a parameter, if present.
    stripTensor :: Type -> Type
    stripTensor (TTensor p) = p
    stripTensor p           = p

-- | Look up the parameter type at a zero-based position of a curried
-- function type.
--
-- > getParamType (a -> b -> c) 0  ==  Just a
-- > getParamType (a -> b -> c) 1  ==  Just b
-- > getParamType (a -> b -> c) 2  ==  Nothing
--
-- Negative positions, and positions past the last parameter, give 'Nothing'.
getParamType :: Type -> Int -> Maybe Type
getParamType ty idx
  | idx < 0   = Nothing
  | otherwise = case drop idx (paramsOf ty) of
      (p : _) -> Just p
      []      -> Nothing
  where
    -- Parameter types along the arrow chain, left to right.
    paramsOf :: Type -> [Type]
    paramsOf (TFun p rest) = p : paramsOf rest
    paramsOf _             = []

-- | The type that remains after supplying one argument to a function:
-- @(a -> b -> c)@ becomes @(b -> c)@.
--
-- A non-function type has no argument left to consume and is returned as is.
applyOneArgType :: Type -> Type
applyOneArgType ty = case ty of
  TFun _ rest -> rest
  _           -> ty

--------------------------------------------------------------------------------
-- * Simplified Approach: Always wrap binary functions with tensorMap2
--------------------------------------------------------------------------------

-- | Classify a type as scalar, meaning it cannot be viewed as a @Tensor a@.
--
-- The test runs 'Unify.unifyStrictWithConstraints' of the type against
-- @Tensor a@ for a placeholder @a@:
--
-- * unification fails    -> the type is scalar     ('True')
-- * unification succeeds -> the type may be a tensor ('False')
--
-- Intended classification, as described in the design notes (the actual
-- outcome is decided by 'Unify.unifyStrictWithConstraints'):
--
-- * @{Num t0} t0@       : @Tensor a@ has no @Num@ instance         -> 'True'
-- * @Tensor t0@         : unifies with @Tensor a@                   -> 'False'
-- * @Integer@           : concrete mismatch with @Tensor a@         -> 'True'
-- * unconstrained @a@   : can be instantiated to @Tensor b@         -> 'False'
isPotentialScalarType :: ClassEnv -> [Constraint] -> Type -> Bool
isPotentialScalarType classEnv constraints ty =
  case Unify.unifyStrictWithConstraints classEnv constraints ty tensorPattern of
    Left _  -> True   -- no way to see ty as a tensor: scalar
    Right _ -> False  -- ty can be a tensor: not scalar
  where
    -- Placeholder element type; the resulting substitution is discarded.
    tensorPattern :: Type
    tensorPattern = TTensor (TVar (TyVar "a_scalar_check"))

-- | Decide whether a function-typed expression should be wrapped with
-- @tensorMap2@.
--
-- The type must have (at least) two parameters, @p1 -> p2 -> r@, where @r@ may
-- itself be a function type. Both @p1@ and @p2@ must also be scalar according to
-- 'isPotentialScalarType'.
--
-- For example:
--
-- * @(+) : {Num a} a -> a -> a@                      -> scalar params, wrap
-- * @(.) : {Num a} Tensor a -> Tensor a -> Tensor a@ -> tensor params, do not wrap
shouldWrapWithTensorMap2 :: ClassEnv -> [Constraint] -> Type -> Bool
shouldWrapWithTensorMap2 classEnv constraints (TFun p1 (TFun p2 _)) =
  all (isPotentialScalarType classEnv constraints) [p1, p2]
shouldWrapWithTensorMap2 _ _ _ = False

-- | Wrap a function expression of type @p1 -> p2 -> r@ in a lambda that applies
-- it element-wise through @tensorMap2@:
--
-- > f   ==>   \tmap2_arg1 tmap2_arg2 -> tensorMap2 f tmap2_arg1 tmap2_arg2
--
-- The generated lambda is typed @Tensor p1 -> Tensor p2 -> Tensor r@. Its
-- binders and its body are typed with the matching @Tensor@ types. Every
-- generated scheme is monomorphic: no type variables and no constraints.
--
-- An expression whose type does not have two parameters is returned unchanged.
-- The constraint list is currently unused.
--
-- NOTE(review): the binder names are fixed. A user variable literally named
-- @tmap2_arg1@ or @tmap2_arg2@ that is free in the wrapped expression would
-- be captured.
wrapWithTensorMap2 :: [Constraint] -> TIExpr -> TIExpr
wrapWithTensorMap2 _constraints funcExpr =
  case tiExprType funcExpr of
    TFun param1 (TFun param2 result) ->
      let firstName, secondName :: String
          firstName  = "tmap2_arg1"
          secondName = "tmap2_arg2"

          -- A scheme with no quantified variables and no constraints.
          mono :: Type -> TypeScheme
          mono = Forall [] []

          -- Reference to a lambda binder that receives a tensor.
          tensorRef :: String -> Type -> TIExpr
          tensorRef name elemTy = TIExpr (mono (TTensor elemTy)) (TIVarExpr name)

          mappedBody :: TIExpr
          mappedBody =
            TIExpr (mono (TTensor result))
                   (TITensorMap2Expr funcExpr
                                     (tensorRef firstName param1)
                                     (tensorRef secondName param2))

          liftedType :: Type
          liftedType =
            TFun (TTensor param1) (TFun (TTensor param2) (TTensor result))
      in TIExpr (mono liftedType)
                (TILambdaExpr Nothing [Var firstName [], Var secondName []] mappedBody)
    _ -> funcExpr

-- | Recognize the shape produced by the tensorMap2 wrappers: a two-parameter
-- lambda whose body node is a @tensorMap2@ expression.
--
-- Any other node, including lambdas with a different number of parameters,
-- yields 'False'.
isAlreadyWrappedWithTensorMap2 :: TIExprNode -> Bool
isAlreadyWrappedWithTensorMap2 node = case node of
  TILambdaExpr _ [_, _] body
    | TITensorMap2Expr {} <- tiExprNode body -> True
  _ -> False

--------------------------------------------------------------------------------
-- * TensorMap Insertion Implementation
--------------------------------------------------------------------------------

-- | Entry point of the pass: insert @tensorMap@ / @tensorMap2@ where needed in
-- a typed expression.
--
-- The class environment is read from the evaluation state. The expression's own
-- type scheme provides the constraints that are in scope at the top level.
insertTensorMaps :: TIExpr -> EvalM TIExpr
insertTensorMaps tiExpr =
  getClassEnv >>= \classEnv ->
    insertTensorMapsInExpr classEnv (tiScheme tiExpr) tiExpr

-- | Wrap a function-valued argument with @tensorMap2@ when
-- 'shouldWrapWithTensorMap2' accepts its type (tensor-map-insertion-simple.md).
--
-- Cases, in order:
--
-- * already wrapped ('isAlreadyWrappedWithTensorMap2')       -> unchanged
-- * two-parameter lambda with a wrappable type, for example an
--   eta-expanded method such as @\\e1 e2 -> dict_("plus") e1 e2@
--                                                           -> body rewritten
--   via 'wrapLambdaBodyWithTensorMap2'
-- * any other lambda                                        -> unchanged
-- * an application (it is being applied, not passed)        -> unchanged
-- * anything else (e.g. a variable) with a wrappable type   -> 'wrapWithTensorMap2'
-- * otherwise                                               -> unchanged
wrapBinaryFunctionIfNeeded :: ClassEnv -> [Constraint] -> TIExpr -> TIExpr
wrapBinaryFunctionIfNeeded classEnv constraints tiExpr
  | isAlreadyWrappedWithTensorMap2 node = tiExpr
  | otherwise = case node of
      TILambdaExpr mVar [var1, var2] body
        | wantsWrap -> wrapLambdaBodyWithTensorMap2 constraints mVar var1 var2 body tiExpr
      TILambdaExpr {} -> tiExpr
      TIApplyExpr {}  -> tiExpr
      _ | wantsWrap   -> wrapWithTensorMap2 constraints tiExpr
        | otherwise   -> tiExpr
  where
    node :: TIExprNode
    node = tiExprNode tiExpr

    wantsWrap :: Bool
    wantsWrap = shouldWrapWithTensorMap2 classEnv constraints (tiExprType tiExpr)

-- | Rewrite a two-parameter lambda so that its body's binary application goes
-- through @tensorMap2@:
--
-- > \x y -> f a b   ==>   \x y -> tensorMap2 f a b
--
-- Parameters:
--   constraints  : constraints in scope at the argument position. They are
--                  merged, without duplicates, into the rebuilt lambda's scheme.
--   mVar         : the lambda's optional name, kept unchanged.
--   var1, var2   : the lambda's two parameters, kept unchanged.
--   body         : the lambda body to inspect.
--   originalExpr : the whole lambda. Its scheme supplies the type variables,
--                  constraints and type of the rebuilt lambda.
--
-- Results:
--   * body is an application with exactly two arguments -> rebuilt lambda whose
--     body is @tensorMap2 f a b@, typed with the original body's type;
--   * body is already a @tensorMap2@ expression          -> 'originalExpr';
--   * anything else                                      -> the whole lambda is
--     wrapped via 'wrapWithTensorMap2'.
wrapLambdaBodyWithTensorMap2 :: [Constraint] -> Maybe Var -> Var -> Var -> TIExpr -> TIExpr -> TIExpr
wrapLambdaBodyWithTensorMap2 constraints mVar var1 var2 body originalExpr =
  case tiExprNode body of
    -- Saturated binary application: match the two arguments directly
    -- instead of indexing a list with (!!).
    TIApplyExpr func [arg1, arg2] ->
      let newBody :: TIExpr
          newBody = TIExpr (Forall [] [] (tiExprType body))
                           (TITensorMap2Expr func arg1 arg2)
          (Forall tvs cs lambdaType) = tiScheme originalExpr
          -- The caller in insertTensorMapsInExpr passes constraints that
          -- already include this lambda's own scheme constraints (cs).
          -- Plain (++) listed each of them twice, so duplicates are dropped.
          newLambdaScheme :: TypeScheme
          newLambdaScheme = Forall tvs (nub (constraints ++ cs)) lambdaType
      in TIExpr newLambdaScheme (TILambdaExpr mVar [var1, var2] newBody)
    -- Body is already a tensorMap2: nothing to do.
    TITensorMap2Expr {} -> originalExpr
    -- Any other body: wrap the whole lambda instead.
    _ -> wrapWithTensorMap2 constraints originalExpr

-- | Insert tensorMap in a TIExpr with type scheme information
insertTensorMapsInExpr :: ClassEnv -> TypeScheme -> TIExpr -> EvalM TIExpr
insertTensorMapsInExpr :: ClassEnv -> TypeScheme -> TIExpr -> EvalM TIExpr
insertTensorMapsInExpr ClassEnv
classEnv TypeScheme
scheme TIExpr
tiExpr = do
  let (Forall [TyVar]
_vars [Constraint]
constraints Type
_ty) = TypeScheme
scheme
  TIExprNode
expandedNode <- ClassEnv -> [Constraint] -> TIExprNode -> EvalM TIExprNode
insertInNode ClassEnv
classEnv [Constraint]
constraints (TIExpr -> TIExprNode
tiExprNode TIExpr
tiExpr)
  -- Note: We don't wrap at this level. Wrapping only happens for function arguments
  -- in TIApplyExpr to avoid wrapping definitions like `def (*') := i.*`
  TIExpr -> EvalM TIExpr
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExpr -> EvalM TIExpr) -> TIExpr -> EvalM TIExpr
forall a b. (a -> b) -> a -> b
$ TypeScheme -> TIExprNode -> TIExpr
TIExpr TypeScheme
scheme TIExprNode
expandedNode
  where
    -- Process a TIExprNode
    insertInNode :: ClassEnv -> [Constraint] -> TIExprNode -> EvalM TIExprNode
    insertInNode :: ClassEnv -> [Constraint] -> TIExprNode -> EvalM TIExprNode
insertInNode ClassEnv
env [Constraint]
cs TIExprNode
node = case TIExprNode
node of
      -- Constants and variables: no change needed
      TIConstantExpr ConstantExpr
c -> TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ ConstantExpr -> TIExprNode
TIConstantExpr ConstantExpr
c
      TIVarExpr String
name -> TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ String -> TIExprNode
TIVarExpr String
name
      
      -- Lambda expressions: process body
      TILambdaExpr Maybe Var
mVar [Var]
params TIExpr
body -> do
        let (Forall [TyVar]
_ [Constraint]
bodyConstraints Type
_) = TIExpr -> TypeScheme
tiScheme TIExpr
body
            allConstraints :: [Constraint]
allConstraints = [Constraint]
cs [Constraint] -> [Constraint] -> [Constraint]
forall a. [a] -> [a] -> [a]
++ [Constraint]
bodyConstraints
        TIExpr
body' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
allConstraints TIExpr
body
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ Maybe Var -> [Var] -> TIExpr -> TIExprNode
TILambdaExpr Maybe Var
mVar [Var]
params TIExpr
body'
      
      -- Function application: check if tensorMap is needed
      TIApplyExpr TIExpr
func [TIExpr]
args -> do
        -- First, recursively process function and arguments
        TIExpr
func' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
func
        [TIExpr]
args' <- (TIExpr -> EvalM TIExpr)
-> [TIExpr]
-> StateT EvalState (ExceptT EgisonError RuntimeM) [TIExpr]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs) [TIExpr]
args

        -- Apply simplified approach: wrap binary function arguments with tensorMap2
        -- This handles cases like `foldl (+) 0 xs` where (+) needs to be wrapped because (+) is a binary function that takes two scalar arguments
        -- But `foldl1 (.) [t1, t2]` should not be wrapped with tensorMap2 because (.) is a binary function that takes two tensor arguments
        -- IMPORTANT: Include each argument's own constraints when deciding if it needs wrapping
        let (Forall [TyVar]
_ [Constraint]
funcConstraints Type
_) = TIExpr -> TypeScheme
tiScheme TIExpr
func'
            baseConstraints :: [Constraint]
baseConstraints = [Constraint]
cs [Constraint] -> [Constraint] -> [Constraint]
forall a. [a] -> [a] -> [a]
++ [Constraint]
funcConstraints
            -- For each argument, merge base constraints with the argument's own constraints
            wrapArg :: TIExpr -> TIExpr
wrapArg TIExpr
arg =
              let (Forall [TyVar]
_ [Constraint]
argConstraints Type
_) = TIExpr -> TypeScheme
tiScheme TIExpr
arg
                  argAllConstraints :: [Constraint]
argAllConstraints = [Constraint] -> [Constraint]
forall a. Eq a => [a] -> [a]
nub ([Constraint]
baseConstraints [Constraint] -> [Constraint] -> [Constraint]
forall a. [a] -> [a] -> [a]
++ [Constraint]
argConstraints)
              in ClassEnv -> [Constraint] -> TIExpr -> TIExpr
wrapBinaryFunctionIfNeeded ClassEnv
env [Constraint]
argAllConstraints TIExpr
arg
            args'' :: [TIExpr]
args'' = (TIExpr -> TIExpr) -> [TIExpr] -> [TIExpr]
forall a b. (a -> b) -> [a] -> [b]
map TIExpr -> TIExpr
wrapArg [TIExpr]
args'

        -- Use the INFERRED function type (after type inference)
        -- This ensures we use concrete types like Integer instead of type variables like a
        -- For example, (+) has inferred type {Num Integer} Integer -> Integer -> Integer
        -- instead of the polymorphic type {Num a} a -> a -> a
        let funcType :: Type
funcType = TIExpr -> Type
tiExprType TIExpr
func'
            argTypes :: [Type]
argTypes = (TIExpr -> Type) -> [TIExpr] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map TIExpr -> Type
tiExprType [TIExpr]
args''

        -- Normal processing: check if tensorMap is needed based on parameter types
        Maybe TIExprNode
result <- ClassEnv
-> [Constraint]
-> TIExpr
-> Type
-> [TIExpr]
-> [Type]
-> EvalM (Maybe TIExprNode)
wrapWithTensorMapIfNeeded ClassEnv
env [Constraint]
baseConstraints TIExpr
func' Type
funcType [TIExpr]
args'' [Type]
argTypes
        case Maybe TIExprNode
result of
          Just TIExprNode
wrappedNode -> TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return TIExprNode
wrappedNode
          Maybe TIExprNode
Nothing -> TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> [TIExpr] -> TIExprNode
TIApplyExpr TIExpr
func' [TIExpr]
args''
      
      -- Collections
      TITupleExpr [TIExpr]
exprs -> do
        [TIExpr]
exprs' <- (TIExpr -> EvalM TIExpr)
-> [TIExpr]
-> StateT EvalState (ExceptT EgisonError RuntimeM) [TIExpr]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs) [TIExpr]
exprs
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ [TIExpr] -> TIExprNode
TITupleExpr [TIExpr]
exprs'
      
      TICollectionExpr [TIExpr]
exprs -> do
        [TIExpr]
exprs' <- (TIExpr -> EvalM TIExpr)
-> [TIExpr]
-> StateT EvalState (ExceptT EgisonError RuntimeM) [TIExpr]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs) [TIExpr]
exprs
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ [TIExpr] -> TIExprNode
TICollectionExpr [TIExpr]
exprs'
      
      TIConsExpr TIExpr
h TIExpr
t -> do
        TIExpr
h' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
h
        TIExpr
t' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
t
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> TIExpr -> TIExprNode
TIConsExpr TIExpr
h' TIExpr
t'
      
      TIJoinExpr TIExpr
l TIExpr
r -> do
        TIExpr
l' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
l
        TIExpr
r' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
r
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> TIExpr -> TIExprNode
TIJoinExpr TIExpr
l' TIExpr
r'
      
      TIHashExpr [(TIExpr, TIExpr)]
pairs -> do
        [(TIExpr, TIExpr)]
pairs' <- ((TIExpr, TIExpr)
 -> StateT
      EvalState (ExceptT EgisonError RuntimeM) (TIExpr, TIExpr))
-> [(TIExpr, TIExpr)]
-> StateT
     EvalState (ExceptT EgisonError RuntimeM) [(TIExpr, TIExpr)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\(TIExpr
k, TIExpr
v) -> do
          TIExpr
k' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
k
          TIExpr
v' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
v
          (TIExpr, TIExpr)
-> StateT EvalState (ExceptT EgisonError RuntimeM) (TIExpr, TIExpr)
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExpr
k', TIExpr
v')) [(TIExpr, TIExpr)]
pairs
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ [(TIExpr, TIExpr)] -> TIExprNode
TIHashExpr [(TIExpr, TIExpr)]
pairs'
      
      TIVectorExpr [TIExpr]
exprs -> do
        [TIExpr]
exprs' <- (TIExpr -> EvalM TIExpr)
-> [TIExpr]
-> StateT EvalState (ExceptT EgisonError RuntimeM) [TIExpr]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs) [TIExpr]
exprs
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ [TIExpr] -> TIExprNode
TIVectorExpr [TIExpr]
exprs'
      
      -- Control flow
      TIIfExpr TIExpr
cond TIExpr
thenExpr TIExpr
elseExpr -> do
        TIExpr
cond' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
cond
        TIExpr
thenExpr' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
thenExpr
        TIExpr
elseExpr' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
elseExpr
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> TIExpr -> TIExpr -> TIExprNode
TIIfExpr TIExpr
cond' TIExpr
thenExpr' TIExpr
elseExpr'
      
      -- Let bindings
      TILetExpr [TIBindingExpr]
bindings TIExpr
body -> do
        [TIBindingExpr]
bindings' <- (TIBindingExpr
 -> StateT EvalState (ExceptT EgisonError RuntimeM) TIBindingExpr)
-> [TIBindingExpr]
-> StateT EvalState (ExceptT EgisonError RuntimeM) [TIBindingExpr]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\(IPrimitiveDataPattern
v, TIExpr
e) -> do
          TIExpr
e' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
e
          TIBindingExpr
-> StateT EvalState (ExceptT EgisonError RuntimeM) TIBindingExpr
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (IPrimitiveDataPattern
v, TIExpr
e')) [TIBindingExpr]
bindings
        TIExpr
body' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
body
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ [TIBindingExpr] -> TIExpr -> TIExprNode
TILetExpr [TIBindingExpr]
bindings' TIExpr
body'
      
      TILetRecExpr [TIBindingExpr]
bindings TIExpr
body -> do
        [TIBindingExpr]
bindings' <- (TIBindingExpr
 -> StateT EvalState (ExceptT EgisonError RuntimeM) TIBindingExpr)
-> [TIBindingExpr]
-> StateT EvalState (ExceptT EgisonError RuntimeM) [TIBindingExpr]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\(IPrimitiveDataPattern
v, TIExpr
e) -> do
          TIExpr
e' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
e
          TIBindingExpr
-> StateT EvalState (ExceptT EgisonError RuntimeM) TIBindingExpr
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (IPrimitiveDataPattern
v, TIExpr
e')) [TIBindingExpr]
bindings
        TIExpr
body' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
body
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ [TIBindingExpr] -> TIExpr -> TIExprNode
TILetRecExpr [TIBindingExpr]
bindings' TIExpr
body'
      
      TISeqExpr TIExpr
e1 TIExpr
e2 -> do
        TIExpr
e1' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
e1
        TIExpr
e2' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
e2
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> TIExpr -> TIExprNode
TISeqExpr TIExpr
e1' TIExpr
e2'
      
      -- Pattern matching
      TIMatchExpr PMMode
mode TIExpr
target TIExpr
matcher [TIMatchClause]
clauses -> do
        TIExpr
target' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
target
        TIExpr
matcher' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
matcher
        [TIMatchClause]
clauses' <- (TIMatchClause
 -> StateT EvalState (ExceptT EgisonError RuntimeM) TIMatchClause)
-> [TIMatchClause]
-> StateT EvalState (ExceptT EgisonError RuntimeM) [TIMatchClause]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\(TIPattern
pat, TIExpr
body) -> do
          TIExpr
body' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
body
          TIMatchClause
-> StateT EvalState (ExceptT EgisonError RuntimeM) TIMatchClause
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
pat, TIExpr
body')) [TIMatchClause]
clauses
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ PMMode -> TIExpr -> TIExpr -> [TIMatchClause] -> TIExprNode
TIMatchExpr PMMode
mode TIExpr
target' TIExpr
matcher' [TIMatchClause]
clauses'
      
      TIMatchAllExpr PMMode
mode TIExpr
target TIExpr
matcher [TIMatchClause]
clauses -> do
        TIExpr
target' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
target
        TIExpr
matcher' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
matcher
        [TIMatchClause]
clauses' <- (TIMatchClause
 -> StateT EvalState (ExceptT EgisonError RuntimeM) TIMatchClause)
-> [TIMatchClause]
-> StateT EvalState (ExceptT EgisonError RuntimeM) [TIMatchClause]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\(TIPattern
pat, TIExpr
body) -> do
          TIExpr
body' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
body
          TIMatchClause
-> StateT EvalState (ExceptT EgisonError RuntimeM) TIMatchClause
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
pat, TIExpr
body')) [TIMatchClause]
clauses
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ PMMode -> TIExpr -> TIExpr -> [TIMatchClause] -> TIExprNode
TIMatchAllExpr PMMode
mode TIExpr
target' TIExpr
matcher' [TIMatchClause]
clauses'
      
      -- More lambda-like constructs
      TIMemoizedLambdaExpr [String]
vars TIExpr
body -> do
        TIExpr
body' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
body
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ [String] -> TIExpr -> TIExprNode
TIMemoizedLambdaExpr [String]
vars TIExpr
body'
      
      TICambdaExpr String
var TIExpr
body -> do
        TIExpr
body' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
body
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ String -> TIExpr -> TIExprNode
TICambdaExpr String
var TIExpr
body'
      
      TIWithSymbolsExpr [String]
syms TIExpr
body -> do
        TIExpr
body' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
body
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ [String] -> TIExpr -> TIExprNode
TIWithSymbolsExpr [String]
syms TIExpr
body'
      
      TIDoExpr [TIBindingExpr]
bindings TIExpr
body -> do
        [TIBindingExpr]
bindings' <- (TIBindingExpr
 -> StateT EvalState (ExceptT EgisonError RuntimeM) TIBindingExpr)
-> [TIBindingExpr]
-> StateT EvalState (ExceptT EgisonError RuntimeM) [TIBindingExpr]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\(IPrimitiveDataPattern
v, TIExpr
e) -> do
          TIExpr
e' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
e
          TIBindingExpr
-> StateT EvalState (ExceptT EgisonError RuntimeM) TIBindingExpr
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (IPrimitiveDataPattern
v, TIExpr
e')) [TIBindingExpr]
bindings
        TIExpr
body' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
body
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ [TIBindingExpr] -> TIExpr -> TIExprNode
TIDoExpr [TIBindingExpr]
bindings' TIExpr
body'
      
      -- Tensor operations
      TITensorMapExpr TIExpr
func TIExpr
tensor -> do
        TIExpr
func' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
func
        TIExpr
tensor' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
tensor
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> TIExpr -> TIExprNode
TITensorMapExpr TIExpr
func' TIExpr
tensor'
      
      TITensorMap2Expr TIExpr
func TIExpr
t1 TIExpr
t2 -> do
        TIExpr
func' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
func
        TIExpr
t1' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
t1
        TIExpr
t2' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
t2
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> TIExpr -> TIExpr -> TIExprNode
TITensorMap2Expr TIExpr
func' TIExpr
t1' TIExpr
t2'

      TITensorMap2WedgeExpr TIExpr
func TIExpr
t1 TIExpr
t2 -> do
        TIExpr
func' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
func
        TIExpr
t1' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
t1
        TIExpr
t2' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
t2
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> TIExpr -> TIExpr -> TIExprNode
TITensorMap2WedgeExpr TIExpr
func' TIExpr
t1' TIExpr
t2'

      TIGenerateTensorExpr TIExpr
func TIExpr
shape -> do
        TIExpr
func' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
func
        TIExpr
shape' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
shape
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> TIExpr -> TIExprNode
TIGenerateTensorExpr TIExpr
func' TIExpr
shape'
      
      TITensorExpr TIExpr
shape TIExpr
elems -> do
        TIExpr
shape' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
shape
        TIExpr
elems' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
elems
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> TIExpr -> TIExprNode
TITensorExpr TIExpr
shape' TIExpr
elems'
      
      TITensorContractExpr TIExpr
tensor -> do
        TIExpr
tensor' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
tensor
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> TIExprNode
TITensorContractExpr TIExpr
tensor'
      
      TITransposeExpr TIExpr
perm TIExpr
tensor -> do
        TIExpr
perm' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
perm
        TIExpr
tensor' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
tensor
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> TIExpr -> TIExprNode
TITransposeExpr TIExpr
perm' TIExpr
tensor'
      
      TIFlipIndicesExpr TIExpr
tensor -> do
        TIExpr
tensor' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
tensor
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> TIExprNode
TIFlipIndicesExpr TIExpr
tensor'
      
      -- Quote expressions
      TIQuoteExpr TIExpr
e -> do
        TIExpr
e' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
e
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> TIExprNode
TIQuoteExpr TIExpr
e'
      
      TIQuoteSymbolExpr TIExpr
e -> do
        TIExpr
e' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
e
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> TIExprNode
TIQuoteSymbolExpr TIExpr
e'
      
      -- Indexed expressions
      TISubrefsExpr Bool
b TIExpr
base TIExpr
ref -> do
        TIExpr
base' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
base
        TIExpr
ref' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
ref
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ Bool -> TIExpr -> TIExpr -> TIExprNode
TISubrefsExpr Bool
b TIExpr
base' TIExpr
ref'
      
      TISuprefsExpr Bool
b TIExpr
base TIExpr
ref -> do
        TIExpr
base' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
base
        TIExpr
ref' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
ref
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ Bool -> TIExpr -> TIExpr -> TIExprNode
TISuprefsExpr Bool
b TIExpr
base' TIExpr
ref'
      
      TIUserrefsExpr Bool
b TIExpr
base TIExpr
ref -> do
        TIExpr
base' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
base
        TIExpr
ref' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
ref
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ Bool -> TIExpr -> TIExpr -> TIExprNode
TIUserrefsExpr Bool
b TIExpr
base' TIExpr
ref'
      
      -- Other cases
      TIInductiveDataExpr String
name [TIExpr]
exprs -> do
        [TIExpr]
exprs' <- (TIExpr -> EvalM TIExpr)
-> [TIExpr]
-> StateT EvalState (ExceptT EgisonError RuntimeM) [TIExpr]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs) [TIExpr]
exprs
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ String -> [TIExpr] -> TIExprNode
TIInductiveDataExpr String
name [TIExpr]
exprs'
      
      TIMatcherExpr [TIPatternDef]
patDefs -> TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ [TIPatternDef] -> TIExprNode
TIMatcherExpr [TIPatternDef]
patDefs
      
      TIIndexedExpr Bool
override TIExpr
base [Index TIExpr]
indices -> do
        TIExpr
base' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
base
        [Index TIExpr]
indices' <- (Index TIExpr
 -> StateT EvalState (ExceptT EgisonError RuntimeM) (Index TIExpr))
-> [Index TIExpr]
-> StateT EvalState (ExceptT EgisonError RuntimeM) [Index TIExpr]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ((TIExpr -> EvalM TIExpr)
-> Index TIExpr
-> StateT EvalState (ExceptT EgisonError RuntimeM) (Index TIExpr)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Index a -> f (Index b)
traverse (\TIExpr
tiexpr -> ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
tiexpr)) [Index TIExpr]
indices
        TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ Bool -> TIExpr -> [Index TIExpr] -> TIExprNode
TIIndexedExpr Bool
override TIExpr
base' [Index TIExpr]
indices'
      
      TIWedgeApplyExpr TIExpr
func [TIExpr]
args -> do
        TIExpr
func' <- ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs TIExpr
func
        [TIExpr]
args' <- (TIExpr -> EvalM TIExpr)
-> [TIExpr]
-> StateT EvalState (ExceptT EgisonError RuntimeM) [TIExpr]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints ClassEnv
env [Constraint]
cs) [TIExpr]
args

        -- Check if the function's parameter types are NOT Tensor types
        -- If so, insert tensorMap2Wedge; otherwise, keep WedgeApply
        let funcType :: Type
funcType = TIExpr -> Type
tiExprType TIExpr
func'
            -- Check if this is a binary function with non-Tensor parameters
            -- A type is non-Tensor if it's not TTensor _ (could be TVar, TBase, etc.)
            isNonTensorType :: Type -> Bool
isNonTensorType Type
ty = case Type
ty of
              TTensor Type
_ -> Bool
False
              Type
_ -> Bool
True
            isScalarFunction :: Bool
isScalarFunction = case Type
funcType of
              TFun Type
param1 (TFun Type
param2 Type
_result) ->
                Type -> Bool
isNonTensorType Type
param1 Bool -> Bool -> Bool
&& Type -> Bool
isNonTensorType Type
param2
              Type
_ -> Bool
False

        if Bool
isScalarFunction Bool -> Bool -> Bool
&& [TIExpr] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TIExpr]
args' Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
2
          then do
            -- Insert tensorMap2Wedge for binary scalar functions
            let [TIExpr
arg1, TIExpr
arg2] = [TIExpr]
args'
                -- Preserve the function's original scheme with its constraints
                (Forall [TyVar]
tvs [Constraint]
funcConstraints Type
_) = TIExpr -> TypeScheme
tiScheme TIExpr
func'
                -- Unlift the function type to get the scalar version
                unliftedFuncType :: Type
unliftedFuncType = Type -> Type
unliftFunctionType Type
funcType
                unliftedFunc :: TIExpr
unliftedFunc = TypeScheme -> TIExprNode -> TIExpr
TIExpr ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [TyVar]
tvs [Constraint]
funcConstraints Type
unliftedFuncType) (TIExpr -> TIExprNode
tiExprNode TIExpr
func')
                -- Get the result type after applying to tensor arguments
                resultType :: Type
resultType = case Type
funcType of
                  TFun Type
_ (TFun Type
_ Type
res) -> Type -> Type
TTensor Type
res  -- Lifting scalar result to Tensor
                  Type
_ -> Type
funcType  -- Fallback
                tensorMap2WedgeScheme :: TypeScheme
tensorMap2WedgeScheme = [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [Constraint]
cs Type
resultType
            TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> TIExpr -> TIExpr -> TIExprNode
TITensorMap2WedgeExpr TIExpr
unliftedFunc TIExpr
arg1 TIExpr
arg2
          else
            -- Keep WedgeApply for tensor functions or non-binary functions
            TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> [TIExpr] -> TIExprNode
TIWedgeApplyExpr TIExpr
func' [TIExpr]
args'
      
      TIFunctionExpr [String]
names -> TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ [String] -> TIExprNode
TIFunctionExpr [String]
names

-- | Insert tensorMaps into a 'TIExpr' while taking the constraints of the
-- enclosing scope into account.
--
-- The context constraints are placed in front of the expression's own scheme
-- constraints, duplicates are removed with 'nub', and the resulting scheme
-- (same type variables and type as the expression) is handed to
-- 'insertTensorMapsInExpr'.
--
-- Why the merge matters: for polymorphic functions the relevant constraint
-- (e.g. @{Num t0}@) may come from the enclosing definition rather than from
-- the expression's own scheme.
insertTensorMapsWithConstraints :: ClassEnv -> [Constraint] -> TIExpr -> EvalM TIExpr
insertTensorMapsWithConstraints env contextConstraints expr =
  insertTensorMapsInExpr env combinedScheme expr
  where
    Forall quantified ownConstraints exprTy = tiScheme expr
    -- Context constraints first, then the expression's own; keep first occurrences.
    combinedScheme =
      Forall quantified (nub (contextConstraints ++ ownConstraints)) exprTy

-- | Decide whether a function application needs tensorMap wrapping and, if so,
-- build the wrapped node.
--
-- Each argument type is paired with its position and compared (via
-- 'shouldInsertTensorMap') against the parameter type that 'getParamType'
-- reports for that position in @funcType@. A position with no parameter type
-- never triggers wrapping.
--
-- Returns @Just node@ (built by 'wrapWithTensorMapRecursive' over all
-- arguments) when at least one argument needs wrapping, and 'Nothing'
-- otherwise. Only @argTypes@ drive the decision; @args@ is passed through to
-- the recursive builder.
wrapWithTensorMapIfNeeded :: ClassEnv -> [Constraint] -> TIExpr -> Type -> [TIExpr] -> [Type] -> EvalM (Maybe TIExprNode)
wrapWithTensorMapIfNeeded classEnv constraints func funcType args argTypes
  | any argNeedsMap (zip [0 ..] argTypes) =
      Just <$> wrapWithTensorMapRecursive classEnv constraints func funcType args argTypes
  | otherwise =
      return Nothing
  where
    -- An argument needs tensorMap only if the function has a parameter at
    -- that position and 'shouldInsertTensorMap' approves the pairing.
    argNeedsMap (position, argTy) =
      maybe False
            (shouldInsertTensorMap classEnv constraints argTy)
            (getParamType funcType position)

-- | Recursively wrap function application with tensorMap where needed
-- Process arguments from left to right, building tensorMap2 for consecutive tensor arguments
wrapWithTensorMapRecursive :: 
    ClassEnv
    -> [Constraint]
    -> TIExpr          -- Current function expression (possibly partially applied)
    -> Type            -- Current function type
    -> [TIExpr]        -- Remaining argument expressions  
    -> [Type]          -- Remaining argument types
    -> EvalM TIExprNode
wrapWithTensorMapRecursive :: ClassEnv
-> [Constraint]
-> TIExpr
-> Type
-> [TIExpr]
-> [Type]
-> EvalM TIExprNode
wrapWithTensorMapRecursive ClassEnv
_classEnv [Constraint]
_constraints TIExpr
currentFunc Type
_currentType [] [] = do
  -- All arguments processed - return the application
  TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> TIExprNode
tiExprNode TIExpr
currentFunc

wrapWithTensorMapRecursive ClassEnv
classEnv [Constraint]
constraints TIExpr
currentFunc Type
currentType (TIExpr
arg1:[TIExpr]
restArgs) (Type
argType1:[Type]
restArgTypes) = do
  -- Get the expected parameter type for first argument
  case Type -> Int -> Maybe Type
getParamType Type
currentType Int
0 of
    Maybe Type
Nothing -> TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> [TIExpr] -> TIExprNode
TIApplyExpr TIExpr
currentFunc (TIExpr
arg1 TIExpr -> [TIExpr] -> [TIExpr]
forall a. a -> [a] -> [a]
: [TIExpr]
restArgs)
    Just Type
paramType1 -> do
      let needsTensorMap1 :: Bool
needsTensorMap1 = ClassEnv -> [Constraint] -> Type -> Type -> Bool
shouldInsertTensorMap ClassEnv
classEnv [Constraint]
constraints Type
argType1 Type
paramType1
      
      if Bool
needsTensorMap1
        then do
          -- Check if we have a second argument that also needs tensorMap
          -- If so, use tensorMap2 instead of nested tensorMap
          case ([TIExpr]
restArgs, [Type]
restArgTypes) of
            (TIExpr
arg2:[TIExpr]
restArgs', Type
argType2:[Type]
restArgTypes') -> do
              let innerType :: Type
innerType = Type -> Type
applyOneArgType Type
currentType
              case Type -> Int -> Maybe Type
getParamType Type
innerType Int
0 of
                Just Type
paramType2 | ClassEnv -> [Constraint] -> Type -> Type -> Bool
shouldInsertTensorMap ClassEnv
classEnv [Constraint]
constraints Type
argType2 Type
paramType2 -> do
                  -- Both first and second arguments need tensorMap → use tensorMap2
                  let varName1 :: String
varName1 = String
"tmapVar" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show ([TIExpr] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TIExpr]
restArgs)
                      varName2 :: String
varName2 = String
"tmapVar" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show ([TIExpr] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TIExpr]
restArgs')
                      var1 :: Var
var1 = String -> [Index (Maybe Var)] -> Var
Var String
varName1 []
                      var2 :: Var
var2 = String -> [Index (Maybe Var)] -> Var
Var String
varName2 []

                      -- Extract element types from tensors
                      elemType1 :: Type
elemType1 = case Type
argType1 of
                                    TTensor Type
t -> Type
t
                                    Type
_ -> Type
argType1
                      elemType2 :: Type
elemType2 = case Type
argType2 of
                                    TTensor Type
t -> Type
t
                                    Type
_ -> Type
argType2

                      varScheme1 :: TypeScheme
varScheme1 = [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
elemType1
                      varScheme2 :: TypeScheme
varScheme2 = [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
elemType2
                      varTIExpr1 :: TIExpr
varTIExpr1 = TypeScheme -> TIExprNode -> TIExpr
TIExpr TypeScheme
varScheme1 (String -> TIExprNode
TIVarExpr String
varName1)
                      varTIExpr2 :: TIExpr
varTIExpr2 = TypeScheme -> TIExprNode -> TIExpr
TIExpr TypeScheme
varScheme2 (String -> TIExprNode
TIVarExpr String
varName2)

                      -- Unlift the function type for use inside tensorMap
                      -- IMPORTANT: Use the instantiated type from currentFunc, not the polymorphic currentType
                      -- This ensures we use the unified type variable (e.g., t0) instead of fresh variables (e.g., a)
                      instantiatedFuncType :: Type
instantiatedFuncType = TIExpr -> Type
tiExprType TIExpr
currentFunc
                      unliftedFuncType :: Type
unliftedFuncType = Type -> Type
unliftFunctionType Type
instantiatedFuncType
                      funcScheme :: TypeScheme
funcScheme = TIExpr -> TypeScheme
tiScheme TIExpr
currentFunc
                      (Forall [TyVar]
tvs [Constraint]
funcConstraints Type
_) = TypeScheme
funcScheme
                      unliftedFuncScheme :: TypeScheme
unliftedFuncScheme = [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [TyVar]
tvs [Constraint]
funcConstraints Type
unliftedFuncType
                      unliftedFunc :: TIExpr
unliftedFunc = TypeScheme -> TIExprNode -> TIExpr
TIExpr TypeScheme
unliftedFuncScheme (TIExpr -> TIExprNode
tiExprNode TIExpr
currentFunc)

                      -- Build inner expression with both variables applied
                      innerType2 :: Type
innerType2 = Type -> Type
applyOneArgType (Type -> Type
applyOneArgType Type
unliftedFuncType)
                      -- After applying both arguments, this is a fully-applied result - no constraints needed
                      innerFuncScheme :: TypeScheme
innerFuncScheme = [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
innerType2
                      innerFuncTI :: TIExpr
innerFuncTI = TypeScheme -> TIExprNode -> TIExpr
TIExpr TypeScheme
innerFuncScheme (TIExpr -> [TIExpr] -> TIExprNode
TIApplyExpr TIExpr
unliftedFunc [TIExpr
varTIExpr1, TIExpr
varTIExpr2])

                  -- Process remaining arguments after consuming two
                  TIExprNode
innerNode <- ClassEnv
-> [Constraint]
-> TIExpr
-> Type
-> [TIExpr]
-> [Type]
-> EvalM TIExprNode
wrapWithTensorMapRecursive ClassEnv
classEnv [Constraint]
constraints TIExpr
innerFuncTI Type
innerType2 [TIExpr]
restArgs' [Type]
restArgTypes'
                  let innerTIExpr :: TIExpr
innerTIExpr = TypeScheme -> TIExprNode -> TIExpr
TIExpr TypeScheme
innerFuncScheme TIExprNode
innerNode
                      finalType :: Type
finalType = TIExpr -> Type
tiExprType TIExpr
innerTIExpr

                  -- Build lambda: \varName1 varName2 -> innerTIExpr
                  -- Lambda has no constraints - it's just a wrapper that receives scalars
                  let lambdaType :: Type
lambdaType = Type -> Type -> Type
TFun Type
elemType1 (Type -> Type -> Type
TFun Type
elemType2 Type
finalType)
                      lambdaScheme :: TypeScheme
lambdaScheme = [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
lambdaType
                      lambdaTI :: TIExpr
lambdaTI = TypeScheme -> TIExprNode -> TIExpr
TIExpr TypeScheme
lambdaScheme (Maybe Var -> [Var] -> TIExpr -> TIExprNode
TILambdaExpr Maybe Var
forall a. Maybe a
Nothing [Var
var1, Var
var2] TIExpr
innerTIExpr)

                  TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> TIExpr -> TIExpr -> TIExprNode
TITensorMap2Expr TIExpr
lambdaTI TIExpr
arg1 TIExpr
arg2
                
                Maybe Type
_ -> do
                  -- Only first argument needs tensorMap → use regular tensorMap
                  ClassEnv
-> [Constraint]
-> TIExpr
-> Type
-> TIExpr
-> Type
-> [TIExpr]
-> [Type]
-> EvalM TIExprNode
insertSingleTensorMap ClassEnv
classEnv [Constraint]
constraints TIExpr
currentFunc Type
currentType TIExpr
arg1 Type
argType1 [TIExpr]
restArgs [Type]
restArgTypes
            
            ([TIExpr], [Type])
_ -> do
              -- No more arguments or types → use regular tensorMap for first argument
              ClassEnv
-> [Constraint]
-> TIExpr
-> Type
-> TIExpr
-> Type
-> [TIExpr]
-> [Type]
-> EvalM TIExprNode
insertSingleTensorMap ClassEnv
classEnv [Constraint]
constraints TIExpr
currentFunc Type
currentType TIExpr
arg1 Type
argType1 [TIExpr]
restArgs [Type]
restArgTypes
        
        else do
          -- First argument doesn't need tensorMap, apply normally and continue
          let appliedType :: Type
appliedType = Type -> Type
applyOneArgType Type
currentType
              appliedScheme :: TypeScheme
appliedScheme = [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [Constraint]
constraints Type
appliedType
              appliedTI :: TIExpr
appliedTI = TypeScheme -> TIExprNode -> TIExpr
TIExpr TypeScheme
appliedScheme (TIExpr -> [TIExpr] -> TIExprNode
TIApplyExpr TIExpr
currentFunc [TIExpr
arg1])
          
          -- Process remaining arguments (recursive call)
          ClassEnv
-> [Constraint]
-> TIExpr
-> Type
-> [TIExpr]
-> [Type]
-> EvalM TIExprNode
wrapWithTensorMapRecursive ClassEnv
classEnv [Constraint]
constraints TIExpr
appliedTI Type
appliedType [TIExpr]
restArgs [Type]
restArgTypes

wrapWithTensorMapRecursive ClassEnv
_classEnv [Constraint]
_constraints TIExpr
currentFunc Type
_currentType [TIExpr]
_args [Type]
_argTypes = 
  TIExprNode -> EvalM TIExprNode
forall a. a -> StateT EvalState (ExceptT EgisonError RuntimeM) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExprNode -> EvalM TIExprNode) -> TIExprNode -> EvalM TIExprNode
forall a b. (a -> b) -> a -> b
$ TIExpr -> [TIExpr] -> TIExprNode
TIApplyExpr TIExpr
currentFunc []

-- | Helper function to insert a single tensorMap (used when tensorMap2 is not
-- applicable).
--
-- Given a function expression @currentFunc@ and an argument @arg@, builds the
-- node
--
-- > tensorMap (\tmapVarN -> currentFunc' tmapVarN <restArgs...>) arg
--
-- where @currentFunc'@ is @currentFunc@ re-typed via 'unliftFunctionType', @N@
-- is the number of remaining arguments, and the remaining arguments are
-- processed by 'wrapWithTensorMapRecursive' inside the lambda body.
insertSingleTensorMap ::
    ClassEnv
    -> [Constraint]    -- Constraints passed through to 'wrapWithTensorMapRecursive'
    -> TIExpr          -- Current function expression
    -> Type            -- Current function type (unused: the instantiated type of the expression is used instead)
    -> TIExpr          -- Tensor argument
    -> Type            -- Tensor argument type
    -> [TIExpr]        -- Remaining arguments
    -> [Type]          -- Remaining argument types
    -> EvalM TIExprNode
insertSingleTensorMap classEnv constraints currentFunc _currentType arg argType restArgs restArgTypes = do
  let -- Name of the lambda-bound element variable. The suffix is the number of
      -- arguments still to process, presumably so that the nested lambdas built
      -- for one application chain get distinct names.
      -- NOTE(review): not hygienic against a user variable that happens to be
      -- named e.g. "tmapVar1"; a fresh-name supply would be safer.
      varName :: String
      varName = "tmapVar" ++ show (length restArgs)

      var :: Var
      var = Var varName []

      -- Element type of the tensor argument; a non-tensor type is used as-is.
      elemType :: Type
      elemType = case argType of
        TTensor t -> t
        _         -> argType

      varScheme :: TypeScheme
      varScheme = Forall [] [] elemType

      varTIExpr :: TIExpr
      varTIExpr = TIExpr varScheme (TIVarExpr varName)

      -- Unlift the function type for use inside tensorMap.
      -- IMPORTANT: use the instantiated type from currentFunc, not the
      -- polymorphic currentType. This keeps the unified type variable
      -- (e.g. t0) instead of fresh variables (e.g. a).
      instantiatedFuncType :: Type
      instantiatedFuncType = tiExprType currentFunc

      unliftedFuncType :: Type
      unliftedFuncType = unliftFunctionType instantiatedFuncType

      -- Keep the original quantified variables and constraints on the
      -- unlifted function; only its type changes.
      Forall tvs funcConstraints _ = tiScheme currentFunc

      unliftedFuncScheme :: TypeScheme
      unliftedFuncScheme = Forall tvs funcConstraints unliftedFuncType

      unliftedFunc :: TIExpr
      unliftedFunc = TIExpr unliftedFuncScheme (tiExprNode currentFunc)

      -- Type after applying the unlifted function to the element variable.
      innerType :: Type
      innerType = applyOneArgType unliftedFuncType

      -- Keep constraints only while the result is still a function (partial
      -- application); a fully-applied value carries none.
      innerConstraints :: [Constraint]
      innerConstraints = case innerType of
        TFun _ _ -> funcConstraints
        _        -> []

      innerFuncScheme :: TypeScheme
      innerFuncScheme = Forall [] innerConstraints innerType

      innerFuncTI :: TIExpr
      innerFuncTI = TIExpr innerFuncScheme (TIApplyExpr unliftedFunc [varTIExpr])

  -- Process the remaining arguments inside the lambda body.
  innerNode <- wrapWithTensorMapRecursive classEnv constraints innerFuncTI innerType restArgs restArgTypes

  let innerTIExpr :: TIExpr
      innerTIExpr = TIExpr innerFuncScheme innerNode

      finalType :: Type
      finalType = tiExprType innerTIExpr

      -- Build the lambda \varName -> innerTIExpr. The lambda itself carries
      -- no constraints: it only receives one element of the tensor.
      lambdaType :: Type
      lambdaType = TFun elemType finalType

      lambdaScheme :: TypeScheme
      lambdaScheme = Forall [] [] lambdaType

      lambdaTI :: TIExpr
      lambdaTI = TIExpr lambdaScheme (TILambdaExpr Nothing [var] innerTIExpr)

  return $ TITensorMapExpr lambdaTI arg