{- |
Module      : Language.Egison.Type.Infer
Licence     : MIT

This module provides type inference for IExpr (Internal Expression).
This is the unified type inference module for Phase 5-6 of the Egison compiler:
  IExpr (Desugared, no types) → (Type, Subst)

This module consolidates all type inference functionality, including:
  - Hindley-Milner type inference
  - Type class constraint collection
  - Infer monad and state management
  - All helper functions

Note: This module only performs type inference and returns Type information.
The typed AST (TIExpr) is created in a separate phase by combining IExpr with Type.

Previous modules (Infer.hs for Expr, TypeInfer.hs for Expr→TypedExpr) are deprecated.
-}

module Language.Egison.Type.Infer
  ( -- * Type inference
    inferIExpr
  , inferITopExpr
  , inferITopExprs
    -- * Infer monad
  , Infer
  , InferState(..)
  , InferConfig(..)
  , initialInferState
  , initialInferStateWithConfig
  , defaultInferConfig
  , permissiveInferConfig
  , runInfer
  , runInferWithWarnings
  , runInferWithWarningsAndState
    -- * Running inference
  , runInferI
  , runInferIWithEnv
    -- * Helper functions
  , freshVar
  , getEnv
  , setEnv
  , withEnv
  , lookupVar
  , unifyTypes
  , generalize
  , inferConstant
  , addWarning
  , clearWarnings
  ) where

import           Control.Monad              (foldM, zipWithM)
import           Control.Monad.Except       (ExceptT, runExceptT, throwError)
import           Control.Monad.State.Strict (StateT, evalStateT, runStateT, get, gets, modify, put)
import           Data.List                  (isPrefixOf, nub, partition)
import           Data.Maybe                  (catMaybes)
import qualified Data.Map.Strict             as Map
import qualified Data.Set                    as Set
import           Language.Egison.AST        (ConstantExpr (..), PrimitivePatPattern (..))
import           Language.Egison.IExpr      (IExpr (..), ITopExpr (..), TITopExpr (..)
                                            , TIExpr (..), TIExprNode (..)
                                            , IBindingExpr, TIBindingExpr
                                            , IMatchClause, TIMatchClause, IPatternDef, TIPatternDef
                                            , IPattern (..), ILoopRange (..)
                                            , TIPattern (..), TIPatternNode (..), TILoopRange (..)
                                            , IPrimitiveDataPattern, PDPatternBase (..)
                                            , extractNameFromVar, Var (..), Index (..), stringToVar
                                            , tiExprType)
import           Language.Egison.Pretty     (prettyStr)
import           Language.Egison.Type.Env
import qualified Language.Egison.Type.Error as TE
import           Language.Egison.Type.Error (TypeError(..), TypeErrorContext(..), TypeWarning(..),
                                              emptyContext, withExpr)
import           Language.Egison.Type.Subst (Subst(..), applySubst, applySubstConstraint,
                                              applySubstScheme, composeSubst, emptySubst)
import           Language.Egison.Type.Tensor (normalizeTensorType)
import           Language.Egison.Type.Types
import qualified Language.Egison.Type.Types as Types
import           Language.Egison.Type.Unify as TU
import qualified Language.Egison.Type.Unify as Unify
import           Language.Egison.Type.Instance (findMatchingInstanceForType)

--------------------------------------------------------------------------------
-- * Infer Monad and State
--------------------------------------------------------------------------------

-- | Inference configuration.
--
-- NOTE(review): 'addWarning' in this module appends warnings without
-- consulting 'cfgCollectWarnings'; confirm whether the flag is read elsewhere.
data InferConfig = InferConfig
  { cfgPermissive      :: Bool  -- ^ Treat unbound variables as warnings, not errors
  , cfgCollectWarnings :: Bool  -- ^ Collect warnings during inference
  }

-- | Renders a configuration as
-- @InferConfig { cfgPermissive = True, cfgCollectWarnings = False }@
-- (note the spaces just inside the braces).
instance Show InferConfig where
  show cfg = concat
    [ "InferConfig { cfgPermissive = "
    , show (cfgPermissive cfg)
    , ", cfgCollectWarnings = "
    , show (cfgCollectWarnings cfg)
    , " }"
    ]

-- | Default configuration (strict mode): 'cfgPermissive' is off, so an
-- unbound variable makes 'lookupVar' throw 'UnboundVariable'; the
-- 'cfgCollectWarnings' flag is also off.
defaultInferConfig :: InferConfig
defaultInferConfig =
  InferConfig
    { cfgPermissive      = False
    , cfgCollectWarnings = False
    }

-- | Permissive configuration (for gradual adoption): both flags are on.
-- Unbound variables produce an 'UnboundVariableWarning' and a fresh type
-- variable instead of an error.
permissiveInferConfig :: InferConfig
permissiveInferConfig =
  InferConfig
    { cfgCollectWarnings = True
    , cfgPermissive      = True
    }

-- | State threaded through type inference.
--
-- 'inferCounter' is a strict field. It is advanced on every 'freshVar' and
-- on every scheme instantiation via @st { inferCounter = ... }@ updates. A
-- lazy field would let a chain of unevaluated increments pile up until some
-- type-variable name is finally forced.
data InferState = InferState
  { inferCounter        :: !Int                 -- ^ Fresh variable counter
  , inferEnv            :: TypeEnv              -- ^ Current type environment
  , inferWarnings       :: [TypeWarning]        -- ^ Collected warnings ('addWarning' prepends, so most recent first)
  , inferConfig         :: InferConfig          -- ^ Configuration
  , inferClassEnv       :: ClassEnv             -- ^ Type class environment
  , inferPatternEnv     :: PatternTypeEnv       -- ^ Pattern constructor environment (merged)
  , inferPatternFuncEnv :: PatternTypeEnv       -- ^ Pattern function environment (for disambiguation)
  , inferConstraints    :: [Constraint]         -- ^ Accumulated type class constraints
  , declaredSymbols     :: Map.Map String Type  -- ^ Declared symbols with their types
  } deriving (Show)

-- | Initial inference state: the 'initialInferStateWithConfig' state built
-- with the strict 'defaultInferConfig'.
initialInferState :: InferState
initialInferState = initialInferStateWithConfig defaultInferConfig

-- | Create an initial state with the given configuration.
--
-- Starts with the counter at 0, no warnings and no constraints, empty type,
-- class and pattern environments, and no declared symbols.
initialInferStateWithConfig :: InferConfig -> InferState
initialInferStateWithConfig cfg =
  InferState
    { inferCounter        = 0
    , inferEnv            = emptyEnv
    , inferWarnings       = []
    , inferConfig         = cfg
    , inferClassEnv       = emptyClassEnv
    , inferPatternEnv     = emptyPatternEnv
    , inferPatternFuncEnv = emptyPatternEnv
    , inferConstraints    = []
    , declaredSymbols     = Map.empty
    }

-- | Inference monad: errors short-circuit via 'ExceptT'; the 'InferState'
-- is carried by a strict 'StateT' over 'IO' (IO kept for potential future
-- extensions).
type Infer = ExceptT TypeError (StateT InferState IO)

-- | Run type inference from the given state and discard the final state.
runInfer :: Infer a -> InferState -> IO (Either TypeError a)
runInfer = evalStateT . runExceptT

-- | Run type inference and also return the warnings held in the final state.
-- The warnings are returned even when inference fails.
runInferWithWarnings :: Infer a -> InferState -> IO (Either TypeError a, [TypeWarning])
runInferWithWarnings m st =
  -- 'fmap' on the (result, state) pair rewrites only the state component.
  fmap inferWarnings <$> runStateT (runExceptT m) st

-- | Run inference and return the result, the final state's warnings, and
-- the final state itself.
runInferWithWarningsAndState :: Infer a -> InferState -> IO (Either TypeError a, [TypeWarning], InferState)
runInferWithWarningsAndState m st =
  (\(outcome, final) -> (outcome, inferWarnings final, final))
    <$> runStateT (runExceptT m) st

--------------------------------------------------------------------------------
-- * Helper Functions
--------------------------------------------------------------------------------

-- | Add a warning by prepending it to 'inferWarnings', so the list ends up
-- most-recent-first.
addWarning :: TypeWarning -> Infer ()
addWarning w = modify prepend
  where
    prepend st = st { inferWarnings = w : inferWarnings st }

-- | Drop every warning accumulated so far.
clearWarnings :: Infer ()
clearWarnings = modify dropAll
  where
    dropAll st = st { inferWarnings = [] }

-- | Add type class constraints, with deduplication.
--
-- A constraint is appended only if it is not already in 'inferConstraints'.
-- Duplicates inside the incoming batch are also collapsed, so each new
-- constraint is added at most once. Existing order is kept, and new
-- constraints are appended in first-occurrence order.
addConstraints :: [Constraint] -> Infer ()
addConstraints cs = modify $ \st ->
  let existing = inferConstraints st
      -- Filter against stored constraints, then dedupe within the batch itself;
      -- previously a batch such as [c, c] appended c twice.
      newConstraints = nub (filter (`notElem` existing) cs)
  in st { inferConstraints = existing ++ newConstraints }

-- | Read the type class constraints accumulated so far.
getConstraints :: Infer [Constraint]
getConstraints = gets inferConstraints

-- | Forget every accumulated type class constraint.
clearConstraints :: Infer ()
clearConstraints = modify resetConstraints
  where
    resetConstraints st = st { inferConstraints = [] }

-- | Run an action with local constraint tracking.
--
-- The action starts with an empty constraint list. Whatever it accumulates is
-- returned alongside its result, and the constraints present before the call
-- are then put back.
withLocalConstraints :: Infer a -> Infer (a, [Constraint])
withLocalConstraints action = do
  outer <- gets inferConstraints
  clearConstraints
  value <- action
  local <- gets inferConstraints
  modify (\st -> st { inferConstraints = outer })
  pure (value, local)

-- | Whether the current configuration is permissive, i.e. whether unbound
-- variables become warnings rather than errors.
isPermissive :: Infer Bool
isPermissive = gets (cfgPermissive . inferConfig)

-- | Generate a fresh type variable named @prefix ++ show n@, where @n@ is the
-- current counter value, and then advance the counter by one.
freshVar :: String -> Infer Type
freshVar prefix = do
  n <- gets inferCounter
  modify $ \st -> st { inferCounter = n + 1 }
  pure . TVar . TyVar $ prefix ++ show n

-- | Read the current type environment.
getEnv :: Infer TypeEnv
getEnv = gets inferEnv

-- | Replace the type environment wholesale.
setEnv :: TypeEnv -> Infer ()
setEnv env = modify replace
  where
    replace st = st { inferEnv = env }

-- | Read the current (merged) pattern constructor environment.
getPatternEnv :: Infer PatternTypeEnv
getPatternEnv = gets inferPatternEnv

-- | Replace the pattern constructor environment.
setPatternEnv :: PatternTypeEnv -> Infer ()
setPatternEnv penv = modify replace
  where
    replace st = st { inferPatternEnv = penv }

-- | Read the current pattern function environment (used for disambiguation).
getPatternFuncEnv :: Infer PatternTypeEnv
getPatternFuncEnv = gets inferPatternFuncEnv

-- | Replace the pattern function environment.
setPatternFuncEnv :: PatternTypeEnv -> Infer ()
setPatternFuncEnv penv = modify replace
  where
    replace st = st { inferPatternFuncEnv = penv }

-- | Read the current type class environment.
getClassEnv :: Infer ClassEnv
getClassEnv = gets inferClassEnv

-- Constraint resolution against the available instances is done by
-- 'resolveConstraintWithInstances' (defined below): when a constraint's type
-- is a Tensor type with no instance of its own, the element type's
-- constraint is used instead.
-- | Resolve constraints in a 'TIExpr' recursively.
--
-- Every constraint in the expression's type scheme is passed through
-- 'resolveConstraintWithInstances'. The quantified variables and the type
-- itself are left untouched, and the node's children are handled by
-- 'resolveConstraintsInNode'.
resolveConstraintsInTIExpr :: ClassEnv -> Subst -> TIExpr -> TIExpr
resolveConstraintsInTIExpr classEnv subst (TIExpr (Forall vars constraints ty) node) =
  TIExpr (Forall vars (map resolveOne constraints) ty)
         (resolveConstraintsInNode classEnv subst node)
  where
    resolveOne = resolveConstraintWithInstances classEnv subst

-- | Resolve constraints in a 'TIExprNode' recursively.
--
-- For each node kind listed below, every child 'TIExpr' is passed through
-- 'resolveConstraintsInTIExpr'. Patterns in let bindings and match clauses
-- are kept unchanged. Any node kind not listed (the final wildcard) is
-- returned as-is, and its children are NOT visited.
resolveConstraintsInNode :: ClassEnv -> Subst -> TIExprNode -> TIExprNode
resolveConstraintsInNode classEnv subst node = case node of
  TIConstantExpr c -> TIConstantExpr c
  TIVarExpr name -> TIVarExpr name
  TILambdaExpr mVar params body ->
    TILambdaExpr mVar params (recur body)
  TIApplyExpr func args ->
    TIApplyExpr (recur func) (map recur args)
  TITupleExpr exprs ->
    TITupleExpr (map recur exprs)
  TICollectionExpr exprs ->
    TICollectionExpr (map recur exprs)
  TIIfExpr cond thenExpr elseExpr ->
    TIIfExpr (recur cond) (recur thenExpr) (recur elseExpr)
  -- For (pattern, expr) pairs, 'fmap' maps only the expression half.
  TILetExpr bindings body ->
    TILetExpr (map (fmap recur) bindings) (recur body)
  TILetRecExpr bindings body ->
    TILetRecExpr (map (fmap recur) bindings) (recur body)
  TIIndexedExpr override expr indices ->
    TIIndexedExpr override (recur expr) (map (fmap recur) indices)
  TIGenerateTensorExpr func shape ->
    TIGenerateTensorExpr (recur func) (recur shape)
  TITensorExpr shape elems ->
    TITensorExpr (recur shape) (recur elems)
  TITensorContractExpr tensor ->
    TITensorContractExpr (recur tensor)
  TITensorMapExpr func tensor ->
    TITensorMapExpr (recur func) (recur tensor)
  TITensorMap2Expr func t1 t2 ->
    TITensorMap2Expr (recur func) (recur t1) (recur t2)
  TIMatchExpr mode target matcher clauses ->
    TIMatchExpr mode (recur target) (recur matcher) (map (fmap recur) clauses)
  _ -> node
  where
    -- Resolve one child expression with the same class env and substitution.
    recur = resolveConstraintsInTIExpr classEnv subst

-- | Apply the substitution to a constraint's type. For tensor types, also
-- choose which type the constraint should be stated on.
--
-- * If the substituted type is @TTensor elemType@ and
--   'findMatchingInstanceForType' finds an instance of the class for the
--   tensor type itself, the constraint stays on the tensor type.
-- * Otherwise the constraint is moved to @elemType@, even if no instance
--   exists for the element type either. The original comments note that such
--   errors are detected in a later phase, and presumably rely on tensorMap
--   applying the operation element-wise.
-- * Non-tensor types just get the substitution applied.
resolveConstraintWithInstances :: ClassEnv -> Subst -> Constraint -> Constraint
resolveConstraintWithInstances classEnv subst (Constraint className rawType) =
  case applySubst subst rawType of
    tensorTy@(TTensor elemType)
      | Just _ <- findMatchingInstanceForType tensorTy classInstances ->
          Constraint className tensorTy
      | otherwise ->
          Constraint className elemType
    otherTy ->
      Constraint className otherTy
  where
    classInstances = lookupInstances className classEnv

-- | Extend the type environment with the given bindings for the duration of
-- the action, then restore the previous environment.
withEnv :: [(String, TypeScheme)] -> Infer a -> Infer a
withEnv bindings action = do
  saved <- getEnv
  let entries = [ (stringToVar name, scheme) | (name, scheme) <- bindings ]
  setEnv (extendEnvMany entries saved)
  outcome <- action
  setEnv saved
  pure outcome

-- | Look up a variable's type.
--
-- This has the same effects as 'lookupVarWithConstraints': a scheme found in
-- the environment is instantiated, the counter is advanced and its
-- constraints are recorded. Otherwise the lookup falls back to declared
-- symbols, then to an unbound-variable warning (permissive mode) or an
-- 'UnboundVariable' error. Only the type is returned.
lookupVar :: String -> Infer Type
lookupVar name = fst <$> lookupVarWithConstraints name

-- | Look up a variable and return its type together with the constraints
-- introduced by instantiation.
--
-- Resolution order:
--
-- 1. The type environment. The scheme is instantiated starting from the
--    current counter, the counter is advanced, and the resulting constraints
--    are recorded via 'addConstraints' as well as returned.
-- 2. 'declaredSymbols'. The declared type is returned with no constraints
--    and no warning.
-- 3. Otherwise, in permissive mode an 'UnboundVariableWarning' is recorded
--    and a fresh @unbound@ type variable is returned. In strict mode an
--    'UnboundVariable' error is thrown.
lookupVarWithConstraints :: String -> Infer (Type, [Constraint])
lookupVarWithConstraints name = do
  env <- getEnv
  maybe fromDeclaredOrUnbound fromScheme (lookupEnv (stringToVar name) env)
  where
    -- Instantiate a scheme found in the environment.
    fromScheme scheme = do
      counter <- gets inferCounter
      let (constraints, ty, counter') = instantiate scheme counter
      modify $ \s -> s { inferCounter = counter' }
      -- Track constraints for type class resolution
      addConstraints constraints
      pure (ty, constraints)

    -- Not in the environment: try declared symbols, then report unbound.
    fromDeclaredOrUnbound = do
      declared <- gets declaredSymbols
      case Map.lookup name declared of
        Just ty -> pure (ty, [])
        Nothing -> do
          permissive <- isPermissive
          if permissive
            then do
              addWarning (UnboundVariableWarning name emptyContext)
              placeholder <- freshVar "unbound"
              pure (placeholder, [])
            else throwError (UnboundVariable name emptyContext)

-- | Unify two types without any extra error context.
-- Thin wrapper around 'unifyTypesWithContext' using 'emptyContext'.
unifyTypes :: Type -> Type -> Infer Subst
unifyTypes lhs rhs = unifyTypesWithContext lhs rhs emptyContext

-- | Unify two types, reporting failures with the given error context.
--
-- Unification uses the constraints accumulated in the Infer monad
-- ('getConstraints') so that it is constraint-aware (e.g. ensuring
-- {Num a} a doesn't unify with Tensor b).
--
-- Delegates to 'unifyTypesWithConstraints'. The class-environment lookup and
-- the UnifyError -> TypeError translation are therefore defined in a single
-- place instead of being duplicated here.
unifyTypesWithContext :: Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext t1 t2 ctx = do
  constraints <- getConstraints
  unifyTypesWithConstraints constraints t1 t2 ctx

-- | Unify two types, allowing @Tensor a@ to unify with @a@.
--
-- Intended only for top-level definitions with type annotations. According
-- to type-tensor-simple.md: "Only for top-level tensor definitions, if Tensor
-- a is unified with a, it becomes a."
--
-- Failures from 'TU.unifyWithTopLevel' are rethrown as 'TypeError's carrying
-- @ctx@.
unifyTypesWithTopLevel :: Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithTopLevel t1 t2 ctx =
  either (throwError . toTypeError) return (TU.unifyWithTopLevel t1 t2)
  where
    toTypeError (TU.OccursCheck v t)  = OccursCheckError v t ctx
    toTypeError (TU.TypeMismatch a b) = UnificationError a b ctx

-- | Unify two types using an explicit list of class constraints.
--
-- This matters when type variables carry constraints (e.g. {Num t0}); the
-- constraints affect how Tensor types are unified. The class environment is
-- read from the Infer state. The Bool flag returned by
-- 'TU.unifyWithConstraints' is discarded, and failures are rethrown as
-- 'TypeError's carrying @ctx@.
unifyTypesWithConstraints :: [Constraint] -> Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithConstraints constraints t1 t2 ctx = do
  classEnv <- getClassEnv
  either (throwError . toTypeError) (return . fst)
         (TU.unifyWithConstraints classEnv constraints t1 t2)
  where
    toTypeError (TU.OccursCheck v t)  = OccursCheckError v t ctx
    toTypeError (TU.TypeMismatch a b) = UnificationError a b ctx

-- | Infer the type of a literal constant.
--
-- Scalar literals have fixed ground types. @something@ is a polymorphic
-- matcher, @Matcher a@ for a fresh @a@. @undefined@ gets a fresh type
-- variable, so it can stand in for any type.
inferConstant :: ConstantExpr -> Infer Type
inferConstant (CharExpr _)    = return TChar
inferConstant (StringExpr _)  = return TString
inferConstant (BoolExpr _)    = return TBool
inferConstant (IntegerExpr _) = return TInt
inferConstant (FloatExpr _)   = return TFloat
inferConstant SomethingExpr   = TMatcher <$> freshVar "a"
inferConstant UndefinedExpr   = freshVar "undefined"

--------------------------------------------------------------------------------
-- * Type Inference for IExpr
--------------------------------------------------------------------------------

-- | Wrap a node in a monomorphic, constraint-free scheme (@Forall [] [] ty@).
mkTIExpr :: Type -> TIExprNode -> TIExpr
mkTIExpr = TIExpr . Forall [] []

-- | Simplify Tensor constraints.
--
-- For each constraint @C ty@:
--
--   * if @C ty@ has an instance, or @ty@ is not a Tensor, the constraint is
--     kept as-is;
--   * otherwise Tensor layers are peeled one at a time, and the constraint is
--     rewritten to the outermost layer that has an instance (e.g.
--     @C (Tensor a)@ becomes @C a@ when only @C a@ has an instance);
--   * if no layer has an instance, the original constraint is kept.
--
-- This enables correct type class expansion for higher-order functions with
-- Tensor arguments.
--
-- Previously the "recursive" unwrap was only reached when the inner type
-- already had an instance, so it could never peel more than one layer.
simplifyTensorConstraints :: ClassEnv -> [Constraint] -> [Constraint]
simplifyTensorConstraints classEnv = map simplifyConstraint
  where
    -- Whether class @cls@ has an instance matching @ty@.
    hasInstance :: String -> Type -> Bool
    hasInstance cls ty =
      case findMatchingInstanceForType ty (lookupInstances cls classEnv) of
        Just _  -> True
        Nothing -> False

    simplifyConstraint :: Constraint -> Constraint
    simplifyConstraint (Constraint cls ty) =
      case outermostInstanceLayer cls ty of
        Just ty' -> Constraint cls ty'
        Nothing  -> Constraint cls ty   -- No layer has an instance: keep original

    -- Walk inward through Tensor layers and return the first (outermost)
    -- type that has an instance of @cls@, if any.
    outermostInstanceLayer :: String -> Type -> Maybe Type
    outermostInstanceLayer cls ty
      | hasInstance cls ty  = Just ty
      | TTensor inner <- ty = outermostInstanceLayer cls inner
      | otherwise           = Nothing

-- | Apply 'simplifyTensorConstraints' to the constraints of a type scheme.
--
-- The quantified variables and the body type pass through untouched; this
-- function never adds or removes quantification.
simplifyTensorConstraintsInScheme :: ClassEnv -> TypeScheme -> TypeScheme
simplifyTensorConstraintsInScheme classEnv (Forall tvs cs ty) =
  Forall tvs simplified ty
  where
    simplified = simplifyTensorConstraints classEnv cs

-- | Apply 'simplifyTensorConstraintsInScheme' to the top-level scheme of a
-- 'TIExpr'. The node (and any sub-expressions inside it) is left unchanged.
simplifyTensorConstraintsInTIExpr :: ClassEnv -> TIExpr -> TIExpr
simplifyTensorConstraintsInTIExpr classEnv (TIExpr scheme node) =
  TIExpr scheme' node
  where
    scheme' = simplifyTensorConstraintsInScheme classEnv scheme

-- | Apply a substitution to a type scheme, taking class instances into
-- account.
--
-- Steps:
--
--   1. Bindings for the scheme's quantified variables are dropped, so bound
--      variables are never substituted.
--   2. Each remaining binding @v := Tensor ...@ is checked against the
--      scheme's constraints. If some constraint @C ty@ mentions @v@ in @ty@
--      and class @C@ has no instance matching the Tensor replacement, the
--      binding is rewritten to the innermost non-Tensor type (all Tensor
--      layers peeled).
--   3. The adjusted substitution is applied to both the constraints and the
--      body type.
--
-- Example: for @{Num t0} t0 -> t0@ with @t0 := Tensor t1@, if
-- @Num (Tensor t1)@ has no instance, the substitution used is @t0 := t1@.
applySubstSchemeWithClassEnv :: ClassEnv -> Subst -> TypeScheme -> TypeScheme
applySubstSchemeWithClassEnv classEnv (Subst bindings) (Forall vs cs t) =
  Forall vs (map (applySubstConstraint adjusted) cs) (applySubst adjusted t)
  where
    -- Quantified variables are shadowed by the scheme: never substitute them.
    freeBindings :: Map.Map TyVar Type
    freeBindings = foldr Map.delete bindings vs

    adjusted :: Subst
    adjusted = Subst (Map.mapWithKey adjustBinding freeBindings)

    -- A Tensor replacement is peeled as soon as any constraint on the
    -- variable rejects the Tensor type. Constraints that accept it never undo
    -- an unwrap demanded by another constraint.
    adjustBinding :: TyVar -> Type -> Type
    adjustBinding var replacement@(TTensor _)
      | any (rejectsTensor var replacement) cs = peelTensors replacement
    adjustBinding _ replacement = replacement

    -- True when the constraint mentions @var@ and its class has no instance
    -- for the Tensor replacement.
    rejectsTensor :: TyVar -> Type -> Constraint -> Bool
    rejectsTensor var replacement (Constraint cls constraintType) =
      Set.member var (freeTyVars constraintType)
        && not (classHasInstance cls replacement)

    classHasInstance :: String -> Type -> Bool
    classHasInstance cls ty =
      case findMatchingInstanceForType ty (lookupInstances cls classEnv) of
        Just _  -> True
        Nothing -> False

    -- Strip every Tensor layer to reach the innermost element type.
    peelTensors :: Type -> Type
    peelTensors (TTensor inner) = peelTensors inner
    peelTensors ty              = ty

-- | Apply a substitution to a 'TIExpr': its own scheme via 'applySubstScheme'
-- and its node (and thus all sub-expressions) via 'applySubstToTIExprNode'.
applySubstToTIExpr :: Subst -> TIExpr -> TIExpr
applySubstToTIExpr s (TIExpr scheme node) =
  TIExpr (applySubstScheme s scheme) (applySubstToTIExprNode s node)

-- | Class-environment-aware counterpart of 'applySubstToTIExpr'.
--
-- The scheme is updated with 'applySubstSchemeWithClassEnv' and the node with
-- 'applySubstToTIExprNodeWithClassEnv'. Tensor bindings may therefore be
-- unwrapped when a constraint's class has no Tensor instance. For example,
-- with @{Num t0} t0 -> t0@ and @t0 := Tensor t1@, if @Num (Tensor t1)@ has no
-- instance the substitution becomes @t0 := t1@.
applySubstToTIExprWithClassEnv :: ClassEnv -> Subst -> TIExpr -> TIExpr
applySubstToTIExprWithClassEnv classEnv s (TIExpr scheme node) =
  TIExpr scheme' node'
  where
    scheme' = applySubstSchemeWithClassEnv classEnv s scheme
    node'   = applySubstToTIExprNodeWithClassEnv classEnv s node

-- | Monadic wrapper for 'applySubstToTIExprWithClassEnv' that reads the
-- 'ClassEnv' from the Infer state. Use it during inference when a
-- substitution must be applied with constraint awareness.
applySubstToTIExprM :: Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM s tiExpr =
  (\classEnv -> applySubstToTIExprWithClassEnv classEnv s tiExpr) <$> getClassEnv

-- | Apply a substitution to a 'Type' with constraint awareness.
--
-- Reads the 'ClassEnv' and the constraints accumulated in the inference
-- state ('inferConstraints'). The substitution is then adjusted exactly as
-- 'applySubstSchemeWithClassEnv' does: a binding @v := Tensor ...@ is peeled
-- to its innermost non-Tensor type when some accumulated constraint mentions
-- @v@ and that constraint's class has no instance for the Tensor type.
--
-- Implemented by delegating to 'applySubstSchemeWithClassEnv' with an
-- unquantified scheme. Previously this function carried a verbatim copy of
-- that adjustment logic, and an unused @s\@@ binding.
applySubstWithConstraintsM :: Subst -> Type -> Infer Type
applySubstWithConstraintsM s t = do
  classEnv <- getClassEnv
  constraints <- gets inferConstraints
  -- With no quantified variables nothing is removed from the substitution,
  -- so the body of the result is exactly the adjusted substitution applied
  -- to @t@.
  let Forall _ _ t' = applySubstSchemeWithClassEnv classEnv s (Forall [] constraints t)
  return t'

-- | Apply a substitution to a TIExprNode recursively
applySubstToTIExprNode :: Subst -> TIExprNode -> TIExprNode
applySubstToTIExprNode :: Subst -> TIExprNode -> TIExprNode
applySubstToTIExprNode Subst
s TIExprNode
node = case TIExprNode
node of
  TIConstantExpr ConstantExpr
c -> ConstantExpr -> TIExprNode
TIConstantExpr ConstantExpr
c
  TIVarExpr String
name -> String -> TIExprNode
TIVarExpr String
name
  
  TILambdaExpr Maybe Var
mVar [Var]
params TIExpr
body ->
    Maybe Var -> [Var] -> TIExpr -> TIExprNode
TILambdaExpr Maybe Var
mVar [Var]
params (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
body)
  
  TIApplyExpr TIExpr
func [TIExpr]
args ->
    TIExpr -> [TIExpr] -> TIExprNode
TIApplyExpr (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
func) ((TIExpr -> TIExpr) -> [TIExpr] -> [TIExpr]
forall a b. (a -> b) -> [a] -> [b]
map (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s) [TIExpr]
args)
  
  TITupleExpr [TIExpr]
exprs ->
    [TIExpr] -> TIExprNode
TITupleExpr ((TIExpr -> TIExpr) -> [TIExpr] -> [TIExpr]
forall a b. (a -> b) -> [a] -> [b]
map (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s) [TIExpr]
exprs)
  
  TICollectionExpr [TIExpr]
exprs ->
    [TIExpr] -> TIExprNode
TICollectionExpr ((TIExpr -> TIExpr) -> [TIExpr] -> [TIExpr]
forall a b. (a -> b) -> [a] -> [b]
map (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s) [TIExpr]
exprs)
  
  TIConsExpr TIExpr
h TIExpr
t ->
    TIExpr -> TIExpr -> TIExprNode
TIConsExpr (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
h) (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
t)
  
  TIJoinExpr TIExpr
l TIExpr
r ->
    TIExpr -> TIExpr -> TIExprNode
TIJoinExpr (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
l) (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
r)
  
  TIIfExpr TIExpr
cond TIExpr
thenE TIExpr
elseE ->
    TIExpr -> TIExpr -> TIExpr -> TIExprNode
TIIfExpr (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
cond) (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
thenE) (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
elseE)
  
  TILetExpr [TIBindingExpr]
bindings TIExpr
body ->
    [TIBindingExpr] -> TIExpr -> TIExprNode
TILetExpr ((TIBindingExpr -> TIBindingExpr)
-> [TIBindingExpr] -> [TIBindingExpr]
forall a b. (a -> b) -> [a] -> [b]
map (\(IPrimitiveDataPattern
pat, TIExpr
expr) -> (IPrimitiveDataPattern
pat, Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
expr)) [TIBindingExpr]
bindings)
              (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
body)
  
  TILetRecExpr [TIBindingExpr]
bindings TIExpr
body ->
    [TIBindingExpr] -> TIExpr -> TIExprNode
TILetRecExpr ((TIBindingExpr -> TIBindingExpr)
-> [TIBindingExpr] -> [TIBindingExpr]
forall a b. (a -> b) -> [a] -> [b]
map (\(IPrimitiveDataPattern
pat, TIExpr
expr) -> (IPrimitiveDataPattern
pat, Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
expr)) [TIBindingExpr]
bindings)
                 (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
body)
  
  TISeqExpr TIExpr
e1 TIExpr
e2 ->
    TIExpr -> TIExpr -> TIExprNode
TISeqExpr (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
e1) (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
e2)
  
  TIInductiveDataExpr String
name [TIExpr]
exprs ->
    String -> [TIExpr] -> TIExprNode
TIInductiveDataExpr String
name ((TIExpr -> TIExpr) -> [TIExpr] -> [TIExpr]
forall a b. (a -> b) -> [a] -> [b]
map (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s) [TIExpr]
exprs)
  
  TIMatcherExpr [TIPatternDef]
patDefs ->
    [TIPatternDef] -> TIExprNode
TIMatcherExpr ((TIPatternDef -> TIPatternDef) -> [TIPatternDef] -> [TIPatternDef]
forall a b. (a -> b) -> [a] -> [b]
map (\(PrimitivePatPattern
pat, TIExpr
expr, [TIBindingExpr]
bindings) -> (PrimitivePatPattern
pat, Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
expr, [TIBindingExpr]
bindings)) [TIPatternDef]
patDefs)
  
  TIMatchExpr PMMode
mode TIExpr
target TIExpr
matcher [TIMatchClause]
clauses ->
    PMMode -> TIExpr -> TIExpr -> [TIMatchClause] -> TIExprNode
TIMatchExpr PMMode
mode 
                (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
target)
                (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
matcher)
                ((TIMatchClause -> TIMatchClause)
-> [TIMatchClause] -> [TIMatchClause]
forall a b. (a -> b) -> [a] -> [b]
map (\(TIPattern
pat, TIExpr
body) -> (TIPattern
pat, Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
body)) [TIMatchClause]
clauses)
  
  TIMatchAllExpr PMMode
mode TIExpr
target TIExpr
matcher [TIMatchClause]
clauses ->
    PMMode -> TIExpr -> TIExpr -> [TIMatchClause] -> TIExprNode
TIMatchAllExpr PMMode
mode
                   (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
target)
                   (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
matcher)
                   ((TIMatchClause -> TIMatchClause)
-> [TIMatchClause] -> [TIMatchClause]
forall a b. (a -> b) -> [a] -> [b]
map (\(TIPattern
pat, TIExpr
body) -> (TIPattern
pat, Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
body)) [TIMatchClause]
clauses)
  
  TIMemoizedLambdaExpr [String]
params TIExpr
body ->
    [String] -> TIExpr -> TIExprNode
TIMemoizedLambdaExpr [String]
params (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
body)
  
  TIDoExpr [TIBindingExpr]
bindings TIExpr
body ->
    [TIBindingExpr] -> TIExpr -> TIExprNode
TIDoExpr ((TIBindingExpr -> TIBindingExpr)
-> [TIBindingExpr] -> [TIBindingExpr]
forall a b. (a -> b) -> [a] -> [b]
map (\(IPrimitiveDataPattern
pat, TIExpr
expr) -> (IPrimitiveDataPattern
pat, Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
expr)) [TIBindingExpr]
bindings)
             (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
body)
  
  TICambdaExpr String
var TIExpr
body ->
    String -> TIExpr -> TIExprNode
TICambdaExpr String
var (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
body)
  
  TIWithSymbolsExpr [String]
syms TIExpr
body ->
    [String] -> TIExpr -> TIExprNode
TIWithSymbolsExpr [String]
syms (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
body)
  
  TIQuoteExpr TIExpr
e ->
    TIExpr -> TIExprNode
TIQuoteExpr (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
e)
  
  TIQuoteSymbolExpr TIExpr
e ->
    TIExpr -> TIExprNode
TIQuoteSymbolExpr (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
e)
  
  TIIndexedExpr Bool
override TIExpr
base [Index TIExpr]
indices ->
    Bool -> TIExpr -> [Index TIExpr] -> TIExprNode
TIIndexedExpr Bool
override (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
base) ((TIExpr -> TIExpr) -> Index TIExpr -> Index TIExpr
forall a b. (a -> b) -> Index a -> Index b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s) (Index TIExpr -> Index TIExpr) -> [Index TIExpr] -> [Index TIExpr]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Index TIExpr]
indices)
  
  TISubrefsExpr Bool
override TIExpr
base TIExpr
ref ->
    Bool -> TIExpr -> TIExpr -> TIExprNode
TISubrefsExpr Bool
override (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
base) (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
ref)
  
  TISuprefsExpr Bool
override TIExpr
base TIExpr
ref ->
    Bool -> TIExpr -> TIExpr -> TIExprNode
TISuprefsExpr Bool
override (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
base) (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
ref)
  
  TIUserrefsExpr Bool
override TIExpr
base TIExpr
ref ->
    Bool -> TIExpr -> TIExpr -> TIExprNode
TIUserrefsExpr Bool
override (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
base) (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
ref)
  
  TIWedgeApplyExpr TIExpr
func [TIExpr]
args ->
    TIExpr -> [TIExpr] -> TIExprNode
TIWedgeApplyExpr (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
func) ((TIExpr -> TIExpr) -> [TIExpr] -> [TIExpr]
forall a b. (a -> b) -> [a] -> [b]
map (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s) [TIExpr]
args)
  
  TIFunctionExpr [String]
names ->
    [String] -> TIExprNode
TIFunctionExpr [String]
names
  
  TIVectorExpr [TIExpr]
exprs ->
    [TIExpr] -> TIExprNode
TIVectorExpr ((TIExpr -> TIExpr) -> [TIExpr] -> [TIExpr]
forall a b. (a -> b) -> [a] -> [b]
map (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s) [TIExpr]
exprs)
  
  TIHashExpr [(TIExpr, TIExpr)]
pairs ->
    [(TIExpr, TIExpr)] -> TIExprNode
TIHashExpr (((TIExpr, TIExpr) -> (TIExpr, TIExpr))
-> [(TIExpr, TIExpr)] -> [(TIExpr, TIExpr)]
forall a b. (a -> b) -> [a] -> [b]
map (\(TIExpr
k, TIExpr
v) -> (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
k, Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
v)) [(TIExpr, TIExpr)]
pairs)
  
  TIGenerateTensorExpr TIExpr
func TIExpr
shape ->
    TIExpr -> TIExpr -> TIExprNode
TIGenerateTensorExpr (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
func) (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
shape)
  
  TITensorExpr TIExpr
shape TIExpr
elems ->
    TIExpr -> TIExpr -> TIExprNode
TITensorExpr (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
shape) (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
elems)
  
  TITransposeExpr TIExpr
perm TIExpr
tensor ->
    TIExpr -> TIExpr -> TIExprNode
TITransposeExpr (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
perm) (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
tensor)
  
  TIFlipIndicesExpr TIExpr
tensor ->
    TIExpr -> TIExprNode
TIFlipIndicesExpr (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
tensor)
  
  TITensorMapExpr TIExpr
func TIExpr
tensor ->
    TIExpr -> TIExpr -> TIExprNode
TITensorMapExpr (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
func) (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
tensor)
  
  TITensorMap2Expr TIExpr
func TIExpr
t1 TIExpr
t2 ->
    TIExpr -> TIExpr -> TIExpr -> TIExprNode
TITensorMap2Expr (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
func) (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
t1) (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
t2)
  
  TITensorContractExpr TIExpr
tensor ->
    TIExpr -> TIExprNode
TITensorContractExpr (Subst -> TIExpr -> TIExpr
applySubstToTIExpr Subst
s TIExpr
tensor)

-- | Apply a substitution to a 'TIExprNode', recursing into every child
-- 'TIExpr' with 'applySubstToTIExprWithClassEnv' so that the 'ClassEnv' is
-- threaded through the whole subtree.
--
-- Leaf nodes ('TIConstantExpr', 'TIVarExpr', 'TIFunctionExpr') are returned
-- unchanged. Binder names, patterns ('TIPattern', 'IPrimitiveDataPattern',
-- 'PrimitivePatPattern'), 'PMMode' values and @override@ flags are copied
-- as-is; only the 'TIExpr' components of a node are rewritten.
applySubstToTIExprNodeWithClassEnv :: ClassEnv -> Subst -> TIExprNode -> TIExprNode
applySubstToTIExprNodeWithClassEnv env s node = case node of
  TIConstantExpr c -> TIConstantExpr c
  TIVarExpr name -> TIVarExpr name

  TILambdaExpr mVar params body -> TILambdaExpr mVar params (go body)
  TIApplyExpr func args -> TIApplyExpr (go func) (map go args)
  TITupleExpr exprs -> TITupleExpr (map go exprs)
  TICollectionExpr exprs -> TICollectionExpr (map go exprs)
  TIConsExpr h t -> TIConsExpr (go h) (go t)
  TIJoinExpr l r -> TIJoinExpr (go l) (go r)
  TIIfExpr cond thenE elseE -> TIIfExpr (go cond) (go thenE) (go elseE)

  TILetExpr bindings body -> TILetExpr (map goPair bindings) (go body)
  TILetRecExpr bindings body -> TILetRecExpr (map goPair bindings) (go body)
  TISeqExpr e1 e2 -> TISeqExpr (go e1) (go e2)
  TIInductiveDataExpr name exprs -> TIInductiveDataExpr name (map go exprs)

  TIMatcherExpr patDefs -> TIMatcherExpr (map goPatDef patDefs)
  TIMatchExpr mode target matcher clauses ->
    TIMatchExpr mode (go target) (go matcher) (map goPair clauses)
  TIMatchAllExpr mode target matcher clauses ->
    TIMatchAllExpr mode (go target) (go matcher) (map goPair clauses)

  TIMemoizedLambdaExpr params body -> TIMemoizedLambdaExpr params (go body)
  TIDoExpr bindings body -> TIDoExpr (map goPair bindings) (go body)
  TICambdaExpr var body -> TICambdaExpr var (go body)
  TIWithSymbolsExpr syms body -> TIWithSymbolsExpr syms (go body)
  TIQuoteExpr e -> TIQuoteExpr (go e)
  TIQuoteSymbolExpr e -> TIQuoteSymbolExpr (go e)

  TIIndexedExpr override base indices ->
    TIIndexedExpr override (go base) (map (fmap go) indices)
  TISubrefsExpr override base ref -> TISubrefsExpr override (go base) (go ref)
  TISuprefsExpr override base ref -> TISuprefsExpr override (go base) (go ref)
  TIUserrefsExpr override base ref -> TIUserrefsExpr override (go base) (go ref)

  TIWedgeApplyExpr func args -> TIWedgeApplyExpr (go func) (map go args)
  TIFunctionExpr names -> TIFunctionExpr names
  TIVectorExpr exprs -> TIVectorExpr (map go exprs)
  TIHashExpr pairs -> TIHashExpr (map (\(k, v) -> (go k, go v)) pairs)

  TIGenerateTensorExpr func shape -> TIGenerateTensorExpr (go func) (go shape)
  TITensorExpr shape elems -> TITensorExpr (go shape) (go elems)
  TITransposeExpr perm tensor -> TITransposeExpr (go perm) (go tensor)
  TIFlipIndicesExpr tensor -> TIFlipIndicesExpr (go tensor)
  TITensorMapExpr func tensor -> TITensorMapExpr (go func) (go tensor)
  TITensorMap2Expr func t1 t2 -> TITensorMap2Expr (go func) (go t1) (go t2)
  TITensorContractExpr tensor -> TITensorContractExpr (go tensor)
  where
    -- The substitution applied to one child expression.
    go :: TIExpr -> TIExpr
    go = applySubstToTIExprWithClassEnv env s

    -- Rewrite only the expression half of a (pattern, expression) pair.
    -- Shared by let/letrec/do bindings, matcher data clauses and match clauses.
    goPair :: (p, TIExpr) -> (p, TIExpr)
    goPair (pat, e) = (pat, go e)

    -- A matcher pattern definition carries an expression plus a list of
    -- (IPrimitiveDataPattern, TIExpr) clauses. Those clauses have the same
    -- shape as let bindings, so their bodies are substituted as well. Without
    -- this, type variables inside them would stay unresolved. Those bodies can
    -- contain class-method calls that need concrete types.
    goPatDef :: (q, TIExpr, [(p, TIExpr)]) -> (q, TIExpr, [(p, TIExpr)])
    goPatDef (pat, e, dataClauses) = (pat, go e, map goPair dataClauses)

-- | Infer the type of an 'IExpr' starting from an empty error context.
--
-- The result is the typed expression ('TIExpr', with its type embedded),
-- together with the substitution produced during inference. This is a
-- convenience entry point; see 'inferIExprWithContext' to supply context
-- for error messages.
inferIExpr :: IExpr -> Infer (TIExpr, Subst)
inferIExpr = flip inferIExprWithContext emptyContext

-- | Infer type for IExpr with context information
-- NEW: Returns TIExpr (typed expression) with type information embedded
inferIExprWithContext :: IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext :: IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
expr TypeErrorContext
ctx = case IExpr
expr of
  -- Constants
  IConstantExpr ConstantExpr
c -> do
    Type
ty <- ConstantExpr -> Infer Type
inferConstant ConstantExpr
c
    let scheme :: TypeScheme
scheme = [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
ty
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TypeScheme -> TIExprNode -> TIExpr
TIExpr TypeScheme
scheme (ConstantExpr -> TIExprNode
TIConstantExpr ConstantExpr
c), Subst
emptySubst)
  
  -- Variables
  IVarExpr String
name -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    -- Variables starting with ":::" are treated as Any type without warning
    if String
":::" String -> String -> Bool
forall a. Eq a => [a] -> [a] -> Bool
`isPrefixOf` String
name
      then do
        let scheme :: TypeScheme
scheme = [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
TAny
        (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TypeScheme -> TIExprNode -> TIExpr
TIExpr TypeScheme
scheme (String -> TIExprNode
TIVarExpr String
name), Subst
emptySubst)
      else do
        (Type
ty, [Constraint]
constraints) <- String -> Infer (Type, [Constraint])
lookupVarWithConstraints String
name
        let scheme :: TypeScheme
scheme = [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [Constraint]
constraints Type
ty
        (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TypeScheme -> TIExprNode -> TIExpr
TIExpr TypeScheme
scheme (String -> TIExprNode
TIVarExpr String
name), Subst
emptySubst)
  
  -- Tuples
  ITupleExpr [IExpr]
elems -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    case [IExpr]
elems of
      [] -> do
        -- Empty tuple: unit type ()
        let scheme :: TypeScheme
scheme = [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] ([Type] -> Type
TTuple [])
        (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TypeScheme -> TIExprNode -> TIExpr
TIExpr TypeScheme
scheme ([TIExpr] -> TIExprNode
TITupleExpr []), Subst
emptySubst)
      [IExpr
single] -> do
        -- Single element tuple: same as the element itself (parentheses are just grouping)
        IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
single TypeErrorContext
exprCtx
      [IExpr]
_ -> do
        [(TIExpr, Subst)]
results <- (IExpr -> Infer (TIExpr, Subst))
-> [IExpr]
-> ExceptT TypeError (StateT InferState IO) [(TIExpr, Subst)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\IExpr
e -> IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
e TypeErrorContext
exprCtx) [IExpr]
elems
        let elemTIExprs :: [TIExpr]
elemTIExprs = ((TIExpr, Subst) -> TIExpr) -> [(TIExpr, Subst)] -> [TIExpr]
forall a b. (a -> b) -> [a] -> [b]
map (TIExpr, Subst) -> TIExpr
forall a b. (a, b) -> a
fst [(TIExpr, Subst)]
results
            elemTypes :: [Type]
elemTypes = ((TIExpr, Subst) -> Type) -> [(TIExpr, Subst)] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map (TIExpr -> Type
tiExprType (TIExpr -> Type)
-> ((TIExpr, Subst) -> TIExpr) -> (TIExpr, Subst) -> Type
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (TIExpr, Subst) -> TIExpr
forall a b. (a, b) -> a
fst) [(TIExpr, Subst)]
results
            s :: Subst
s = (Subst -> Subst -> Subst) -> Subst -> [Subst] -> Subst
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr Subst -> Subst -> Subst
composeSubst Subst
emptySubst (((TIExpr, Subst) -> Subst) -> [(TIExpr, Subst)] -> [Subst]
forall a b. (a -> b) -> [a] -> [b]
map (TIExpr, Subst) -> Subst
forall a b. (a, b) -> b
snd [(TIExpr, Subst)]
results)
        
        -- Check if all elements are Matcher types
        -- If so, return Matcher (Tuple ...) instead of (Matcher ..., Matcher ...)
        [Type]
appliedElemTypes <- (Type -> Infer Type)
-> [Type] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s) [Type]
elemTypes
        let matcherTypes :: [Type]
matcherTypes = [Maybe Type] -> [Type]
forall a. [Maybe a] -> [a]
catMaybes ((Type -> Maybe Type) -> [Type] -> [Maybe Type]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Maybe Type
extractMatcherType [Type]
appliedElemTypes)
        
        if [Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
matcherTypes Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
appliedElemTypes Bool -> Bool -> Bool
&& Bool -> Bool
not ([Type] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Type]
appliedElemTypes)
          then do
            -- All elements are matchers: return Matcher (Tuple ...)
            let tupleType :: Type
tupleType = [Type] -> Type
TTuple [Type]
matcherTypes
                resultType :: Type
resultType = Type -> Type
TMatcher Type
tupleType
                scheme :: TypeScheme
scheme = [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
resultType
            (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TypeScheme -> TIExprNode -> TIExpr
TIExpr TypeScheme
scheme ([TIExpr] -> TIExprNode
TITupleExpr [TIExpr]
elemTIExprs), Subst
s)
          else do
            -- Not all elements are matchers: return regular tuple
            let resultType :: Type
resultType = [Type] -> Type
TTuple [Type]
appliedElemTypes
                scheme :: TypeScheme
scheme = [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
resultType
            (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TypeScheme -> TIExprNode -> TIExpr
TIExpr TypeScheme
scheme ([TIExpr] -> TIExprNode
TITupleExpr [TIExpr]
elemTIExprs), Subst
s)
        where
          -- Extract the inner type from Matcher a -> Just a, otherwise Nothing
          extractMatcherType :: Type -> Maybe Type
          extractMatcherType :: Type -> Maybe Type
extractMatcherType (TMatcher Type
t) = Type -> Maybe Type
forall a. a -> Maybe a
Just Type
t
          extractMatcherType Type
_ = Maybe Type
forall a. Maybe a
Nothing
  
  -- Collections (Lists)
  ICollectionExpr [IExpr]
elems -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    Type
elemType <- String -> Infer Type
freshVar String
"elem"
    ([TIExpr]
elemTIExprs, Subst
s) <- (([TIExpr], Subst)
 -> IExpr
 -> ExceptT TypeError (StateT InferState IO) ([TIExpr], Subst))
-> ([TIExpr], Subst)
-> [IExpr]
-> ExceptT TypeError (StateT InferState IO) ([TIExpr], Subst)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Type
-> TypeErrorContext
-> ([TIExpr], Subst)
-> IExpr
-> ExceptT TypeError (StateT InferState IO) ([TIExpr], Subst)
inferListElem Type
elemType TypeErrorContext
exprCtx) ([], Subst
emptySubst) [IExpr]
elems
    Type
elemType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
elemType
    let resultType :: Type
resultType = Type -> Type
TCollection Type
elemType'
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
resultType ([TIExpr] -> TIExprNode
TICollectionExpr ([TIExpr] -> [TIExpr]
forall a. [a] -> [a]
reverse [TIExpr]
elemTIExprs)), Subst
s)
    where
      inferListElem :: Type
-> TypeErrorContext
-> ([TIExpr], Subst)
-> IExpr
-> ExceptT TypeError (StateT InferState IO) ([TIExpr], Subst)
inferListElem Type
eType TypeErrorContext
exprCtx ([TIExpr]
accExprs, Subst
s) IExpr
e = do
        (TIExpr
tiExpr, Subst
s') <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
e TypeErrorContext
exprCtx
        let t :: Type
t = TIExpr -> Type
tiExprType TIExpr
tiExpr
        Type
eType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
eType
        Subst
s'' <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
eType' Type
t TypeErrorContext
exprCtx
        ([TIExpr], Subst)
-> ExceptT TypeError (StateT InferState IO) ([TIExpr], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIExpr
tiExpr TIExpr -> [TIExpr] -> [TIExpr]
forall a. a -> [a] -> [a]
: [TIExpr]
accExprs, Subst -> Subst -> Subst
composeSubst Subst
s'' (Subst -> Subst -> Subst
composeSubst Subst
s' Subst
s))

  -- Cons
  IConsExpr IExpr
headExpr IExpr
tailExpr -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    (TIExpr
headTI, Subst
s1) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
headExpr TypeErrorContext
exprCtx
    (TIExpr
tailTI, Subst
s2) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
tailExpr TypeErrorContext
exprCtx
    let headType :: Type
headType = TIExpr -> Type
tiExprType TIExpr
headTI
        tailType :: Type
tailType = TIExpr -> Type
tiExprType TIExpr
tailTI
        s12 :: Subst
s12 = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
    Type
headType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s12 Type
headType
    Type
tailType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s12 Type
tailType
    Subst
s3 <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext (Type -> Type
TCollection Type
headType') Type
tailType' TypeErrorContext
exprCtx
    let finalS :: Subst
finalS = Subst -> Subst -> Subst
composeSubst Subst
s3 Subst
s12
    Type
resultType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
finalS Type
tailType
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
resultType (TIExpr -> TIExpr -> TIExprNode
TIConsExpr TIExpr
headTI TIExpr
tailTI), Subst
finalS)
  
  -- Join (list concatenation)
  IJoinExpr IExpr
leftExpr IExpr
rightExpr -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    (TIExpr
leftTI, Subst
s1) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
leftExpr TypeErrorContext
exprCtx
    (TIExpr
rightTI, Subst
s2) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
rightExpr TypeErrorContext
exprCtx
    let leftType :: Type
leftType = TIExpr -> Type
tiExprType TIExpr
leftTI
        rightType :: Type
rightType = TIExpr -> Type
tiExprType TIExpr
rightTI
        s12 :: Subst
s12 = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
    Type
leftType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s12 Type
leftType
    Type
rightType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s12 Type
rightType
    Subst
s3 <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
leftType' Type
rightType' TypeErrorContext
exprCtx
    let finalS :: Subst
finalS = Subst -> Subst -> Subst
composeSubst Subst
s3 Subst
s12
    Type
resultType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
finalS Type
leftType
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
resultType (TIExpr -> TIExpr -> TIExprNode
TIJoinExpr TIExpr
leftTI TIExpr
rightTI), Subst
finalS)
  
  -- Hash (Map)
  IHashExpr [(IExpr, IExpr)]
pairs -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    Type
keyType <- String -> Infer Type
freshVar String
"hashKey"
    Type
valType <- String -> Infer Type
freshVar String
"hashVal"
    ([(TIExpr, TIExpr)]
pairTIs, Subst
s) <- (([(TIExpr, TIExpr)], Subst)
 -> (IExpr, IExpr)
 -> ExceptT
      TypeError (StateT InferState IO) ([(TIExpr, TIExpr)], Subst))
-> ([(TIExpr, TIExpr)], Subst)
-> [(IExpr, IExpr)]
-> ExceptT
     TypeError (StateT InferState IO) ([(TIExpr, TIExpr)], Subst)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Type
-> Type
-> TypeErrorContext
-> ([(TIExpr, TIExpr)], Subst)
-> (IExpr, IExpr)
-> ExceptT
     TypeError (StateT InferState IO) ([(TIExpr, TIExpr)], Subst)
inferHashPair Type
keyType Type
valType TypeErrorContext
exprCtx) ([], Subst
emptySubst) [(IExpr, IExpr)]
pairs
    Type
keyType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
keyType
    Type
valType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
valType
    let resultType :: Type
resultType = Type -> Type -> Type
THash Type
keyType' Type
valType'
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
resultType ([(TIExpr, TIExpr)] -> TIExprNode
TIHashExpr ([(TIExpr, TIExpr)] -> [(TIExpr, TIExpr)]
forall a. [a] -> [a]
reverse [(TIExpr, TIExpr)]
pairTIs)), Subst
s)
    where
      -- Fold step for hash literals: infer one (key, value) pair and unify its
      -- key/value types with the shared hash key/value types.
      --
      -- The accumulated substitution s' from the pairs processed so far is
      -- applied to kType/vType before unification. Without this, a key or value
      -- whose type conflicts with an earlier pair would not be reported, because
      -- the earlier binding of kType/vType lives only in s'.
      -- inferListElem uses the same approach for vector elements.
      -- Pairs are accumulated in reverse order; the caller reverses them.
      inferHashPair :: Type -> Type -> TypeErrorContext
                    -> ([(TIExpr, TIExpr)], Subst)
                    -> (IExpr, IExpr)
                    -> Infer ([(TIExpr, TIExpr)], Subst)
      inferHashPair kType vType exprCtx (accPairs, s') (k, v) = do
        (kTI, s1) <- inferIExprWithContext k exprCtx
        (vTI, s2) <- inferIExprWithContext v exprCtx
        let kt = tiExprType kTI
            vt = tiExprType vTI
            -- Everything known so far: earlier pairs, then this key, then this value.
            sSoFar = composeSubst s2 (composeSubst s1 s')
        kType' <- applySubstWithConstraintsM sSoFar kType
        s3 <- unifyTypesWithContext kType' kt exprCtx
        let sAfterKey = composeSubst s3 sSoFar
        vType' <- applySubstWithConstraintsM sAfterKey vType
        s4 <- unifyTypesWithContext vType' vt exprCtx
        return ((kTI, vTI) : accPairs, composeSubst s4 sAfterKey)
  
  -- Vector (Tensor)
  IVectorExpr elems -> do
    let exprCtx = withExpr (prettyStr expr) ctx
    -- All elements share one element type, introduced as a fresh variable.
    elemType <- freshVar "vecElem"
    (revElemTIs, s) <- foldM (inferListElem elemType exprCtx) ([], emptySubst) elems
    -- A vector literal is a tensor of its (substituted) element type.
    resultType <- normalizeTensorType . TTensor <$> applySubstWithConstraintsM s elemType
    return (mkTIExpr resultType (TIVectorExpr (reverse revElemTIs)), s)
    where
      -- Fold step: infer one element, unify its type with the expected element
      -- type (under the substitution accumulated so far), and prepend the typed
      -- element (the caller reverses the list afterwards).
      inferListElem :: Type -> TypeErrorContext -> ([TIExpr], Subst) -> IExpr
                    -> Infer ([TIExpr], Subst)
      inferListElem expectedTy errCtx (accTIs, accS) element = do
        (elemTI, elemS) <- inferIExprWithContext element errCtx
        expectedTy' <- applySubstWithConstraintsM accS expectedTy
        unifyS <- unifyTypesWithContext expectedTy' (tiExprType elemTI) errCtx
        return (elemTI : accTIs, unifyS `composeSubst` (elemS `composeSubst` accS))

  -- Lambda
  ILambdaExpr mVar params body -> do
    let exprCtx = withExpr (prettyStr expr) ctx
    -- One fresh, monomorphic type variable per parameter.
    argTypes <- traverse (const (freshVar "arg")) params
    let paramSchemes =
          [ (extractNameFromVar param, Forall [] [] argTy)
          | (param, argTy) <- zip params argTypes
          ]
    -- Infer the body with the parameters in scope.
    (bodyTIExpr, s) <- withEnv paramSchemes (inferIExprWithContext body exprCtx)
    -- Resolve the parameter types with what the body taught us, then build
    -- the curried function type  a1 -> a2 -> ... -> body.
    finalArgTypes <- traverse (applySubstWithConstraintsM s) argTypes
    let funType = foldr TFun (tiExprType bodyTIExpr) finalArgTypes
    return (mkTIExpr funType (TILambdaExpr mVar params bodyTIExpr), s)
  
  -- Function Application
  IApplyExpr func args -> do
    let exprCtx = withExpr (prettyStr expr) ctx
    -- Infer the function first, then hand its type and the arguments to the
    -- shared application checker.
    (funcTI, s1) <- inferIExprWithContext func exprCtx
    inferIApplicationWithContext funcTI (tiExprType funcTI) args s1 exprCtx

  -- Wedge apply expression (exterior product)
  IWedgeApplyExpr func args -> do
    let exprCtx = withExpr (prettyStr expr) ctx
    (funcTI, s1) <- inferIExprWithContext func exprCtx
    -- Type-check exactly like an ordinary application ...
    (resultTI, finalS) <-
      inferIApplicationWithContext funcTI (tiExprType funcTI) args s1 exprCtx
    -- ... then relabel the resulting application node as a wedge application,
    -- keeping its type scheme, so the exterior-product semantics survive.
    let wedged = case tiExprNode resultTI of
          TIApplyExpr funcTI' argTIs' ->
            TIExpr (tiScheme resultTI) (TIWedgeApplyExpr funcTI' argTIs')
          _ -> resultTI
    return (wedged, finalS)

  -- If expression
  IIfExpr cond thenExpr elseExpr -> do
    let exprCtx = withExpr (prettyStr expr) ctx
    -- The condition has to be a Bool.
    (condTI, sCond) <- inferIExprWithContext cond exprCtx
    sBool <- unifyTypesWithContext (tiExprType condTI) TBool exprCtx
    -- Infer both branches, then require them to agree on one type.
    (thenTI, sThen) <- inferIExprWithContext thenExpr exprCtx
    (elseTI, sElse) <- inferIExprWithContext elseExpr exprCtx
    let elseType = tiExprType elseTI
    thenType' <- applySubstWithConstraintsM sElse (tiExprType thenTI)
    sBranches <- unifyTypesWithContext thenType' elseType exprCtx
    let finalS = foldr composeSubst emptySubst
                   [sBranches, sElse, sThen, composeSubst sBool sCond]
    -- The whole expression has the (resolved) common branch type.
    resultType <- applySubstWithConstraintsM finalS elseType
    return (mkTIExpr resultType (TIIfExpr condTI thenTI elseTI), finalS)
  
  -- Let expression
  ILetExpr bindings body -> do
    let exprCtx = withExpr (prettyStr expr) ctx
    -- Type the (non-recursive) bindings against the current environment.
    (bindingTIs, extendedEnv, sBinds) <-
      getEnv >>= \env -> inferIBindingsWithContext bindings env emptySubst exprCtx
    -- Type the body with the new bindings in scope.
    (bodyTI, sBody) <- withEnv extendedEnv (inferIExprWithContext body exprCtx)
    let finalS = sBody `composeSubst` sBinds
    resultType <- applySubstWithConstraintsM finalS (tiExprType bodyTI)
    pure (mkTIExpr resultType (TILetExpr bindingTIs bodyTI), finalS)
  
  -- LetRec expression
  ILetRecExpr bindings body -> do
    let exprCtx = withExpr (prettyStr expr) ctx
    -- Type the mutually recursive bindings against the current environment.
    (bindingTIs, extendedEnv, sBinds) <-
      getEnv >>= \env -> inferIRecBindingsWithContext bindings env emptySubst exprCtx
    -- Type the body with the recursive bindings in scope.
    (bodyTI, sBody) <- withEnv extendedEnv (inferIExprWithContext body exprCtx)
    let finalS = sBody `composeSubst` sBinds
    resultType <- applySubstWithConstraintsM finalS (tiExprType bodyTI)
    pure (mkTIExpr resultType (TILetRecExpr bindingTIs bodyTI), finalS)
  
  -- Sequence expression
  ISeqExpr expr1 expr2 -> do
    let exprCtx = withExpr (prettyStr expr) ctx
    (firstTI, sFirst) <- inferIExprWithContext expr1 exprCtx
    (secondTI, sSecond) <- inferIExprWithContext expr2 exprCtx
    -- A sequence takes the type of its second component.
    pure ( mkTIExpr (tiExprType secondTI) (TISeqExpr firstTI secondTI)
         , composeSubst sSecond sFirst
         )
  
  -- Inductive Data Constructor
  IInductiveDataExpr name args -> do
    -- Constructors are looked up in the type environment by name.
    env <- getEnv
    case lookupEnv (stringToVar name) env of
      Just scheme -> do
        -- Instantiate the constructor's scheme with fresh type variables,
        -- advancing the shared counter (instance constraints are discarded here).
        counter <- gets inferCounter
        let (_constraints, constructorType, newCounter) = instantiate scheme counter
        modify $ \s -> s { inferCounter = newCounter }
        -- Treat the constructor as a function applied to its arguments.
        inferIApplication name constructorType args emptySubst
      Nothing -> do
        -- Constructor not found in the environment.
        let exprCtx = withExpr (prettyStr expr) ctx
        permissive <- isPermissive
        if permissive
          then do
            -- Permissive mode: warn and give the node a fresh result type.
            -- The arguments are still inferred so that they are type-checked
            -- and kept in the typed AST, instead of being dropped (the node
            -- used to be built with an empty argument list).
            addWarning $ UnboundVariableWarning name exprCtx
            resultType <- freshVar "ctor"
            let inferArg (accTIs, accS) arg = do
                  (argTI, argS) <- inferIExprWithContext arg exprCtx
                  return (argTI : accTIs, composeSubst argS accS)
            (revArgTIs, argsS) <- foldM inferArg ([], emptySubst) args
            return ( mkTIExpr resultType (TIInductiveDataExpr name (reverse revArgTIs))
                   , argsS
                   )
          else throwError $ UnboundVariable name exprCtx
  
  -- Matchers (return Matcher type)
  IMatcherExpr [IPatternDef]
patDefs -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    -- Infer type of each pattern definition (matcher clause)
    -- Each clause has: (PrimitivePatPattern, nextMatcherExpr, [(primitiveDataPat, targetExpr)])
    [(TIPatternDef, (Type, [Subst]))]
results <- (IPatternDef
 -> ExceptT
      TypeError (StateT InferState IO) (TIPatternDef, (Type, [Subst])))
-> [IPatternDef]
-> ExceptT
     TypeError (StateT InferState IO) [(TIPatternDef, (Type, [Subst]))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (TypeErrorContext
-> IPatternDef
-> ExceptT
     TypeError (StateT InferState IO) (TIPatternDef, (Type, [Subst]))
inferPatternDef TypeErrorContext
exprCtx) [IPatternDef]
patDefs
    
    -- Collect TIPatternDefs and substitutions
    let tiPatDefs :: [TIPatternDef]
tiPatDefs = ((TIPatternDef, (Type, [Subst])) -> TIPatternDef)
-> [(TIPatternDef, (Type, [Subst]))] -> [TIPatternDef]
forall a b. (a -> b) -> [a] -> [b]
map (TIPatternDef, (Type, [Subst])) -> TIPatternDef
forall a b. (a, b) -> a
fst [(TIPatternDef, (Type, [Subst]))]
results
        substs :: [Subst]
substs = ((TIPatternDef, (Type, [Subst])) -> [Subst])
-> [(TIPatternDef, (Type, [Subst]))] -> [Subst]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap ((Type, [Subst]) -> [Subst]
forall a b. (a, b) -> b
snd ((Type, [Subst]) -> [Subst])
-> ((TIPatternDef, (Type, [Subst])) -> (Type, [Subst]))
-> (TIPatternDef, (Type, [Subst]))
-> [Subst]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (TIPatternDef, (Type, [Subst])) -> (Type, [Subst])
forall a b. (a, b) -> b
snd) [(TIPatternDef, (Type, [Subst]))]
results  -- Extract [Subst] from (TIPatternDef, (Type, [Subst]))
        finalSubst :: Subst
finalSubst = (Subst -> Subst -> Subst) -> Subst -> [Subst] -> Subst
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr Subst -> Subst -> Subst
composeSubst Subst
emptySubst [Subst]
substs
    
    -- All clauses should agree on the matched type
    -- Unify all matched types from each pattern definition
    [Type]
matchedTypes <- ((TIPatternDef, (Type, [Subst])) -> Infer Type)
-> [(TIPatternDef, (Type, [Subst]))]
-> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\(TIPatternDef
_, (Type
ty, [Subst]
_)) -> Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
finalSubst Type
ty) [(TIPatternDef, (Type, [Subst]))]
results
    (Type
matchedTy, Subst
s_matched) <- case [Type]
matchedTypes of
      [] -> do
        Type
ty <- String -> Infer Type
freshVar String
"matched"
        (Type, Subst)
-> ExceptT TypeError (StateT InferState IO) (Type, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
ty, Subst
emptySubst)
      (Type
firstTy:[Type]
restTys) -> do
        -- Unify all matched types
        Subst
s <- (Subst -> Type -> Infer Subst) -> Subst -> [Type] -> Infer Subst
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (\Subst
accS Type
ty -> do
            Type
firstTy' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
accS Type
firstTy
            Type
ty' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
accS Type
ty
            Subst
s' <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
firstTy' Type
ty' TypeErrorContext
exprCtx
            Subst -> Infer Subst
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Subst -> Infer Subst) -> Subst -> Infer Subst
forall a b. (a -> b) -> a -> b
$ Subst -> Subst -> Subst
composeSubst Subst
s' Subst
accS
          ) Subst
emptySubst [Type]
restTys
        Type
resultTy <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
firstTy
        (Type, Subst)
-> ExceptT TypeError (StateT InferState IO) (Type, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
resultTy, Subst
s)
    
    let allSubst :: Subst
allSubst = Subst -> Subst -> Subst
composeSubst Subst
s_matched Subst
finalSubst
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr (Type -> Type
TMatcher Type
matchedTy) ([TIPatternDef] -> TIExprNode
TIMatcherExpr [TIPatternDef]
tiPatDefs), Subst
allSubst)
    where
      -- Infer a single pattern definition (matcher clause):
      --   (primitive-pattern pattern, next-matcher expression, data clauses).
      -- Returns the typed clause together with (matched type, [substitutions]).
      inferPatternDef :: TypeErrorContext -> IPatternDef -> Infer (TIPatternDef, (Type, [Subst]))
      inferPatternDef ctx (ppPat, nextMatcherExpr, dataClauses) = do
        -- Infer the type of next matcher expression.
        -- It should be a Matcher type (possibly Matcher of tuple, like Matcher (a, b)).
        -- Note: (integer, integer) is inferred as Matcher (Integer, Integer),
        -- not (Matcher Integer, Matcher Integer).
        (nextMatcherTI, s1) <- inferIExprWithContext nextMatcherExpr ctx
        let nextMatcherType = tiExprType nextMatcherTI

        -- nextMatcherType must be a Matcher type.
        -- Unify with Matcher a to constrain it and detect errors early.
        matcherInnerTy <- freshVar "matcherInner"
        nextMatcherType' <- applySubstWithConstraintsM s1 nextMatcherType
        s1' <- unifyTypesWithContext nextMatcherType' (TMatcher matcherInnerTy) ctx
        -- s1' was computed against nextMatcherType', so apply it on top of that
        -- (previously it was applied to the un-substituted nextMatcherType).
        nextMatcherType'' <- applySubstWithConstraintsM s1' nextMatcherType'

        -- Infer the PrimitivePatPattern: matched type, pattern hole types and
        -- the variable bindings it introduces.
        (matchedType, patternHoleTypes, ppBindings, s_pp) <- inferPrimitivePatPattern ppPat ctx
        -- Keep s1 in the chain so that what was learned while inferring the
        -- next-matcher expression is propagated to the caller (it used to be
        -- dropped here).
        let s1'' = composeSubst s_pp (composeSubst s1' s1)
        matchedType' <- applySubstWithConstraintsM s1'' matchedType
        -- Apply the substitution to the variable bindings.
        let ppBindings' = [ (var, applySubstScheme s1'' scheme) | (var, scheme) <- ppBindings ]

        -- Apply the substitution to pattern hole types (kept as inner types).
        patternHoleTypes' <- mapM (applySubstWithConstraintsM s1'') patternHoleTypes

        -- Extract inner type(s) from the next matcher type. With multiple
        -- pattern holes they are combined as a tuple, matching ITupleExpr.
        nextMatcherInnerTypes <-
          extractInnerTypesFromMatcher nextMatcherType'' (length patternHoleTypes') ctx

        -- Unify pattern hole types (inner types) with next matcher inner types.
        s_unify <- checkPatternHoleConsistency patternHoleTypes' nextMatcherInnerTypes ctx
        let s1''' = composeSubst s_unify s1''

        -- Check the data clauses with the pp variables in scope.
        -- Each data clause: (primitiveDataPattern, targetListExpr).
        dataClauseResults <- withEnv ppBindings' $
          mapM (inferDataClauseWithCheck ctx nextMatcherInnerTypes matchedType') dataClauses
        let s2 = foldr composeSubst emptySubst dataClauseResults

        -- Build the typed data clauses: infer each primitive data pattern to
        -- get its bindings, then infer the target with those bindings in scope.
        -- NOTE(review): targets are inferred again here after
        -- inferDataClauseWithCheck, and the substitutions of this pass are
        -- discarded; consider returning the typed clauses from the first pass.
        let typeDataClause (pdPat, targetExpr) = do
              (_, pdBindings, _) <- inferPrimitiveDataPattern pdPat matchedType' ctx
              (targetTI, _) <- withEnv pdBindings (inferIExprWithContext targetExpr ctx)
              return (pdPat, targetTI)
        dataClauseTIs <- withEnv ppBindings' (mapM typeDataClause dataClauses)

        let tiPatDef = (ppPat, nextMatcherTI, dataClauseTIs)
        return (tiPatDef, (matchedType', [s1''', s2]))
      
      -- Infer the type of a PrimitivePatPattern.
      --
      -- Returns (matched type, pattern-hole types, variable bindings, substitution):
      --   * pattern-hole types are the bare inner types (no TMatcher wrapper);
      --     the caller wraps them with TMatcher when unifying with the next
      --     matcher types;
      --   * variable bindings come from value patterns (#$val);
      --   * pattern-hole types are determined by the pattern constructor,
      --     not by external context.
      inferPrimitivePatPattern :: PrimitivePatPattern -> TypeErrorContext -> Infer (Type, [Type], [(String, TypeScheme)], Subst)
      inferPrimitivePatPattern ppPat ctx = case ppPat of
        PPWildCard -> do
          -- Wildcard pattern: no pattern holes, no bindings
          matchedTy <- freshVar "matched"
          return (matchedTy, [], [], emptySubst)

        PPPatVar -> do
          -- Pattern variable ($): exactly one pattern hole, typed as the matched type
          matchedTy <- freshVar "matched"
          return (matchedTy, [matchedTy], [], emptySubst)

        PPValuePat var -> do
          -- Value pattern (#$val): no pattern holes; binds var (monomorphically)
          -- to the matched type
          matchedTy <- freshVar "matched"
          return (matchedTy, [], [(var, Forall [] [] matchedTy)], emptySubst)

        PPTuplePat ppPats -> do
          -- Tuple pattern ($p1, $p2, ...): the matched type is the tuple of the
          -- sub-patterns' matched types
          (matchedTypes, holes, bindings, s) <- inferPrimitivePatSubPatterns ppPats ctx
          matchedTypes' <- mapM (applySubstWithConstraintsM s) matchedTypes
          holes' <- mapM (applySubstWithConstraintsM s) holes
          return (TTuple matchedTypes', holes', bindings, s)

        PPInductivePat name ppPats -> do
          patternEnv <- getPatternEnv
          case lookupPatternEnv name patternEnv of
            Just scheme -> do
              -- Declared pattern constructor: instantiate its type scheme.
              -- The scheme's constraints are discarded here.
              st <- get
              let (_constraints, ctorType, newCounter) = instantiate scheme (inferCounter st)
              modify $ \st' -> st' { inferCounter = newCounter }

              -- ctorType has the shape arg1 -> arg2 -> ... -> resultType
              let (argTypes, resultType) = extractFunctionArgs ctorType
                  arity = length ppPats

              if length argTypes /= arity
                then throwError $ TE.TypeMismatch
                       (foldr TFun resultType (replicate arity (TVar (TyVar "a"))))
                       ctorType
                       ("Pattern constructor " ++ name ++ " expects " ++ show (length argTypes)
                        ++ " arguments, but got " ++ show arity)
                       ctx
                else do
                  (matchedTypes, holes, bindings, s) <- inferPrimitivePatSubPatterns ppPats ctx

                  -- Each sub-pattern's matched type must agree with the declared
                  -- argument type (with a TMatcher wrapper stripped, if present).
                  let stripMatcher (TMatcher inner) = inner
                      stripMatcher ty = ty
                      expectedMatchedTypes = map stripMatcher argTypes
                      unifyArg accS (inferredTy, expectedTy) = do
                        inferredTy' <- applySubstWithConstraintsM accS inferredTy
                        expectedTy' <- applySubstWithConstraintsM accS expectedTy
                        sArg <- unifyTypesWithContext inferredTy' expectedTy' ctx
                        return (composeSubst sArg accS)
                  s' <- foldM unifyArg s (zip matchedTypes expectedMatchedTypes)

                  resultType' <- applySubstWithConstraintsM s' resultType
                  holes' <- mapM (applySubstWithConstraintsM s') holes
                  return (resultType', holes', bindings, s')

            Nothing -> do
              -- Not in the pattern environment (backward compatibility): the
              -- matched type is TInductive name [sub-pattern matched types]
              (matchedTypes, holes, bindings, s) <- inferPrimitivePatSubPatterns ppPats ctx
              matchedTypes' <- mapM (applySubstWithConstraintsM s) matchedTypes
              holes' <- mapM (applySubstWithConstraintsM s) holes
              return (TInductive name matchedTypes', holes', bindings, s)

      -- Infer every sub-pattern of a compound PrimitivePatPattern and merge the
      -- results. Returns the matched type of each sub-pattern (in order), the
      -- concatenated pattern-hole types and bindings, and all sub-pattern
      -- substitutions composed with foldr composeSubst emptySubst.
      -- The returned types are not yet substituted; callers apply the subst.
      inferPrimitivePatSubPatterns :: [PrimitivePatPattern] -> TypeErrorContext -> Infer ([Type], [Type], [(String, TypeScheme)], Subst)
      inferPrimitivePatSubPatterns ppPats ctx = do
        results <- mapM (\pp -> inferPrimitivePatPattern pp ctx) ppPats
        let matchedTypes = [mt | (mt, _, _, _) <- results]
            holes = concat [phs | (_, phs, _, _) <- results]
            bindings = concat [bs | (_, _, bs, _) <- results]
            combined = foldr composeSubst emptySubst [sub | (_, _, _, sub) <- results]
        return (matchedTypes, holes, bindings, combined)
      
      -- Split a curried function type into its argument types and its final
      -- result type, e.g. a -> b -> c -> d  ==>  ([a, b, c], d).
      -- A type that is not a TFun yields ([], itself).
      extractFunctionArgs :: Type -> ([Type], Type)
      extractFunctionArgs = go []
        where
          go acc (TFun argTy restTy) = go (argTy : acc) restTy
          go acc finalTy = (reverse acc, finalTy)
      
      -- Check that pattern-hole types agree with the next matchers' inner types.
      --
      -- The two lists must have the same length; otherwise a TypeMismatch is
      -- thrown. Each pair is then unified left to right, with the substitution
      -- accumulated so far applied to both sides before unifying. Returns the
      -- composed substitution (emptySubst when both lists are empty).
      checkPatternHoleConsistency :: [Type] -> [Type] -> TypeErrorContext -> Infer Subst
      checkPatternHoleConsistency patternHoles nextMatchers ctx
        | holeCount /= matcherCount =
            throwError $ TE.TypeMismatch (TTuple nextMatchers) (TTuple patternHoles) countMsg ctx
        | otherwise =
            foldM unifyPair emptySubst (zip patternHoles nextMatchers)
        where
          holeCount = length patternHoles
          matcherCount = length nextMatchers
          countMsg = "Inconsistent number of pattern holes (" ++ show holeCount
                     ++ ") and next matchers (" ++ show matcherCount ++ ")"
          unifyPair :: Subst -> (Type, Type) -> Infer Subst
          unifyPair accS (holeTy, matcherTy) = do
            holeTy' <- applySubstWithConstraintsM accS holeTy
            matcherTy' <- applySubstWithConstraintsM accS matcherTy
            sPair <- unifyTypesWithContext holeTy' matcherTy' ctx
            return (composeSubst sPair accS)
      
      -- Extract the per-hole inner types from the type of the next-matcher
      -- expression, given the number of pattern holes.
      --
      -- numHoles = 0:
      --   Matcher a                      -> [a]  (kept so that a surplus matcher
      --                                          is reported by checkPatternHoleConsistency)
      --   (Matcher a, Matcher b, ...)    -> [a, b, ...]
      --   anything else                  -> []
      -- numHoles = 1:
      --   Matcher a                      -> [a]
      --   (Matcher a, Matcher b, ...)    -> [(a, b, ...)]  (a tuple of matchers that
      --                                                    was not merged into one Matcher)
      --   anything else                  -> TypeMismatch
      -- numHoles = n (any other value):
      --   Matcher (a1, ..., an)          -> [a1, ..., an]
      --   (Matcher a1, ..., Matcher an)  -> [a1, ..., an]
      --   wrong arity or any other shape -> TypeMismatch
      extractInnerTypesFromMatcher :: Type -> Int -> TypeErrorContext -> Infer [Type]
      extractInnerTypesFromMatcher matcherType numHoles ctx =
        case numHoles of
          0 -> case matcherType of
            TMatcher innerType -> return [innerType]
            TTuple types       -> return (maybe [] id (matcherInners types))
            _                  -> return []
          1 -> case matcherType of
            TMatcher innerType -> return [innerType]
            TTuple types       -> case matcherInners types of
              Just inners -> return [TTuple inners]
              Nothing     -> mismatch singleMatcher "Expected Matcher type or tuple of Matcher types"
            _ -> mismatch singleMatcher "Expected Matcher type"
          n -> case matcherType of
            TMatcher (TTuple innerTypes)
              | length innerTypes == n -> return innerTypes
              | otherwise ->
                  mismatch tupleMatcher
                    ("Expected Matcher with tuple of " ++ show n
                      ++ " elements, but got " ++ show (length innerTypes))
            TTuple types -> case matcherInners types of
              Just inners | length inners == n -> return inners
              _ -> mismatch tupleMatcher "Expected tuple of Matcher types with correct count"
            _ -> mismatch tupleMatcher ("Expected Matcher of tuple with " ++ show n ++ " elements")
        where
          -- Just the inner types when every component is a Matcher, else Nothing
          matcherInners :: [Type] -> Maybe [Type]
          matcherInners = mapM extractMatcherInner
          anyVar = TVar (TyVar "a")
          singleMatcher = TMatcher anyVar
          tupleMatcher = TMatcher (TTuple (replicate numHoles anyVar))
          mismatch :: Type -> String -> Infer [Type]
          mismatch expected msg = throwError $ TE.TypeMismatch expected matcherType msg ctx
      
      -- Return the inner type of a Matcher type (Matcher a -> Just a) and
      -- Nothing for every other type.
      extractMatcherInner :: Type -> Maybe Type
      extractMatcherInner ty = case ty of
        TMatcher inner -> Just inner
        _              -> Nothing
      
      -- Infer a data clause (primitive data pattern, target expression) and
      -- check it against the expected types.
      --
      --   ctx                   - error context used for every unification
      --   nextMatcherInnerTypes - inner types of the next matchers, already
      --                           stripped of their TMatcher wrapper
      --   matchedType           - type of the values the data pattern matches
      --
      -- The variables bound by the data pattern are in scope while the target
      -- expression is inferred. The target expression must have type
      -- TCollection elemTy, where elemTy is
      --   ()                when there are no pattern holes,
      --   the inner type    when there is exactly one,
      --   (t1, ..., tn)     when there are several.
      -- Returns the composition of all substitutions produced.
      inferDataClauseWithCheck :: TypeErrorContext -> [Type] -> Type -> (IPrimitiveDataPattern, IExpr) -> Infer Subst
      inferDataClauseWithCheck ctx nextMatcherInnerTypes matchedType (pdPat, targetExpr) = do
        -- Element type of the collection returned by the target expression
        let targetType = case nextMatcherInnerTypes of
              []       -> TTuple []
              [single] -> single
              multiple -> TTuple multiple

        -- The data pattern must match values of matchedType
        (pdTargetType, bindings, s_pd) <- inferPrimitiveDataPattern pdPat matchedType ctx
        pdTargetType' <- applySubstWithConstraintsM s_pd pdTargetType
        matchedType' <- applySubstWithConstraintsM s_pd matchedType
        s_match <- unifyTypesWithContext pdTargetType' matchedType' ctx
        let s_pd' = composeSubst s_match s_pd

        -- Apply s_pd' to the pattern variables' types before putting them in
        -- scope. Otherwise the target expression is inferred against stale type
        -- variables and may bind them inconsistently with s_pd' (as in
        -- Algorithm W, the current substitution is applied before extending the
        -- environment). Only monomorphic schemes are refined, so quantified
        -- variables and constraints are never rewritten.
        let refineBinding (var, Forall [] [] ty) = do
              ty' <- applySubstWithConstraintsM s_pd' ty
              return (var, Forall [] [] ty')
            refineBinding binding = return binding
        bindings' <- mapM refineBinding bindings

        -- Infer the target expression with the pattern variables in scope
        (targetTI, s1) <- withEnv bindings' $ inferIExprWithContext targetExpr ctx
        let s_combined = composeSubst s1 s_pd'

        -- The target expression must return TCollection targetType
        targetType' <- applySubstWithConstraintsM s_combined targetType
        let expectedType = TCollection targetType'
        exprType' <- applySubstWithConstraintsM s_combined (tiExprType targetTI)
        s2 <- unifyTypesWithContext exprType' expectedType ctx
        return $ composeSubst s2 s_combined
      
      -- True exactly when the primitive data pattern is a pattern variable (PDPatVar).
      isPDPatVar :: IPrimitiveDataPattern -> Bool
      isPDPatVar pdPat = case pdPat of
        PDPatVar _ -> True
        _          -> False
      
      -- Infer PrimitiveDataPattern type
      -- Returns (inferred target type, variable bindings, substitution)
      -- This is similar to pattern matching in Haskell for algebraic data types
      inferPrimitiveDataPattern :: IPrimitiveDataPattern -> Type -> TypeErrorContext -> Infer (Type, [(String, TypeScheme)], Subst)
      inferPrimitiveDataPattern :: IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
pdPat Type
expectedType TypeErrorContext
ctx = case IPrimitiveDataPattern
pdPat of
        IPrimitiveDataPattern
PDWildCard -> do
          -- Wildcard: matches any type, no bindings
          (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType, [], Subst
emptySubst)
        
        PDPatVar Var
var -> do
          -- Pattern variable: binds to the expected type
          let varName :: String
varName = Var -> String
extractNameFromVar Var
var
          (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType, [(String
varName, [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
expectedType)], Subst
emptySubst)
        
        PDConstantPat ConstantExpr
c -> do
          -- Constant pattern: must match the constant's type
          Type
constTy <- ConstantExpr -> Infer Type
inferConstant ConstantExpr
c
          Subst
s <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
constTy Type
expectedType TypeErrorContext
ctx
          Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
          (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType', [], Subst
s)
        
        PDTuplePat [IPrimitiveDataPattern]
pats -> do
          -- Tuple pattern: expected type should be a tuple
          case Type
expectedType of
            TTuple [Type]
types | [Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
types Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [IPrimitiveDataPattern] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [IPrimitiveDataPattern]
pats -> do
              -- Types match: infer each sub-pattern
              [(Type, [(String, TypeScheme)], Subst)]
results <- (IPrimitiveDataPattern
 -> Type -> Infer (Type, [(String, TypeScheme)], Subst))
-> [IPrimitiveDataPattern]
-> [Type]
-> ExceptT
     TypeError
     (StateT InferState IO)
     [(Type, [(String, TypeScheme)], Subst)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (\IPrimitiveDataPattern
p Type
t -> IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
p Type
t TypeErrorContext
ctx) [IPrimitiveDataPattern]
pats [Type]
types
              let ([Type]
_, [[(String, TypeScheme)]]
bindingsList, [Subst]
substs) = [(Type, [(String, TypeScheme)], Subst)]
-> ([Type], [[(String, TypeScheme)]], [Subst])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(Type, [(String, TypeScheme)], Subst)]
results
                  allBindings :: [(String, TypeScheme)]
allBindings = [[(String, TypeScheme)]] -> [(String, TypeScheme)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[(String, TypeScheme)]]
bindingsList
                  s :: Subst
s = (Subst -> Subst -> Subst) -> Subst -> [Subst] -> Subst
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr Subst -> Subst -> Subst
composeSubst Subst
emptySubst [Subst]
substs
              Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
              (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType', [(String, TypeScheme)]
allBindings, Subst
s)
            
            TVar TyVar
_ -> do
              -- Expected type is a type variable: create fresh types for each element
              [Type]
elemTypes <- (IPrimitiveDataPattern -> Infer Type)
-> [IPrimitiveDataPattern]
-> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\IPrimitiveDataPattern
_ -> String -> Infer Type
freshVar String
"elem") [IPrimitiveDataPattern]
pats
              let tupleTy :: Type
tupleTy = [Type] -> Type
TTuple [Type]
elemTypes
              Subst
s <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
expectedType Type
tupleTy TypeErrorContext
ctx

              -- Recursively infer each sub-pattern
              [Type]
elemTypes' <- (Type -> Infer Type)
-> [Type] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s) [Type]
elemTypes
              [(Type, [(String, TypeScheme)], Subst)]
results <- (IPrimitiveDataPattern
 -> Type -> Infer (Type, [(String, TypeScheme)], Subst))
-> [IPrimitiveDataPattern]
-> [Type]
-> ExceptT
     TypeError
     (StateT InferState IO)
     [(Type, [(String, TypeScheme)], Subst)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (\IPrimitiveDataPattern
p Type
t -> IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
p Type
t TypeErrorContext
ctx) [IPrimitiveDataPattern]
pats [Type]
elemTypes'
              let ([Type]
_, [[(String, TypeScheme)]]
bindingsList, [Subst]
substs) = [(Type, [(String, TypeScheme)], Subst)]
-> ([Type], [[(String, TypeScheme)]], [Subst])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(Type, [(String, TypeScheme)], Subst)]
results
                  allBindings :: [(String, TypeScheme)]
allBindings = [[(String, TypeScheme)]] -> [(String, TypeScheme)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[(String, TypeScheme)]]
bindingsList
                  s' :: Subst
s' = (Subst -> Subst -> Subst) -> Subst -> [Subst] -> Subst
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr Subst -> Subst -> Subst
composeSubst Subst
s [Subst]
substs
              Type
tupleTy' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s' Type
tupleTy
              (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
tupleTy', [(String, TypeScheme)]
allBindings, Subst
s')
            
            Type
_ -> do
              -- Type mismatch
              TypeError -> Infer (Type, [(String, TypeScheme)], Subst)
forall a. TypeError -> ExceptT TypeError (StateT InferState IO) a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (TypeError -> Infer (Type, [(String, TypeScheme)], Subst))
-> TypeError -> Infer (Type, [(String, TypeScheme)], Subst)
forall a b. (a -> b) -> a -> b
$ Type -> Type -> String -> TypeErrorContext -> TypeError
TE.TypeMismatch
                ([Type] -> Type
TTuple (Int -> Type -> [Type]
forall a. Int -> a -> [a]
replicate ([IPrimitiveDataPattern] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [IPrimitiveDataPattern]
pats) (TyVar -> Type
TVar (String -> TyVar
TyVar String
"a"))))
                Type
expectedType
                String
"Tuple pattern but target is not a tuple type"
                TypeErrorContext
ctx
        
        IPrimitiveDataPattern
PDEmptyPat -> do
          -- Empty collection pattern: expected type should be [a] for some a
          Type
elemTy <- String -> Infer Type
freshVar String
"elem"
          Subst
s <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
expectedType (Type -> Type
TCollection Type
elemTy) TypeErrorContext
ctx
          Type
collTy <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s (Type -> Type
TCollection Type
elemTy)
          (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
collTy, [], Subst
s)
        
        PDConsPat IPrimitiveDataPattern
p1 IPrimitiveDataPattern
p2 -> do
          -- Cons pattern: expected type should be [a] for some a
          case Type
expectedType of
            TCollection Type
elemType -> do
              -- Infer head pattern with element type
              (Type
_, [(String, TypeScheme)]
bindings1, Subst
s1) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
p1 Type
elemType TypeErrorContext
ctx
              -- Infer tail pattern with collection type
              Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
expectedType
              (Type
_, [(String, TypeScheme)]
bindings2, Subst
s2) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
p2 Type
expectedType' TypeErrorContext
ctx
              let s :: Subst
s = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
              Type
expectedType'' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
              (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType'', [(String, TypeScheme)]
bindings1 [(String, TypeScheme)]
-> [(String, TypeScheme)] -> [(String, TypeScheme)]
forall a. [a] -> [a] -> [a]
++ [(String, TypeScheme)]
bindings2, Subst
s)
            
            TVar TyVar
_ -> do
              -- Expected type is a type variable: constrain it to be a collection
              Type
elemTy <- String -> Infer Type
freshVar String
"elem"
              Subst
s <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
expectedType (Type -> Type
TCollection Type
elemTy) TypeErrorContext
ctx
              Type
collTy <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s (Type -> Type
TCollection Type
elemTy)
              Type
elemTy' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
elemTy
              (Type
_, [(String, TypeScheme)]
bindings1, Subst
s1) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
p1 Type
elemTy' TypeErrorContext
ctx
              Type
collTy' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
collTy
              (Type
_, [(String, TypeScheme)]
bindings2, Subst
s2) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
p2 Type
collTy' TypeErrorContext
ctx
              let s' :: Subst
s' = Subst -> Subst -> Subst
composeSubst Subst
s2 (Subst -> Subst -> Subst
composeSubst Subst
s1 Subst
s)
              Type
collTy'' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s' Type
collTy
              (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
collTy'', [(String, TypeScheme)]
bindings1 [(String, TypeScheme)]
-> [(String, TypeScheme)] -> [(String, TypeScheme)]
forall a. [a] -> [a] -> [a]
++ [(String, TypeScheme)]
bindings2, Subst
s')
            
            Type
_ -> do
              TypeError -> Infer (Type, [(String, TypeScheme)], Subst)
forall a. TypeError -> ExceptT TypeError (StateT InferState IO) a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (TypeError -> Infer (Type, [(String, TypeScheme)], Subst))
-> TypeError -> Infer (Type, [(String, TypeScheme)], Subst)
forall a b. (a -> b) -> a -> b
$ Type -> Type -> String -> TypeErrorContext -> TypeError
TE.TypeMismatch
                (Type -> Type
TCollection (TyVar -> Type
TVar (String -> TyVar
TyVar String
"a")))
                Type
expectedType
                String
"Cons pattern but target is not a collection type"
                TypeErrorContext
ctx
        
        PDSnocPat IPrimitiveDataPattern
p1 IPrimitiveDataPattern
p2 -> do
          -- Snoc pattern: similar to cons but reversed
          case Type
expectedType of
            TCollection Type
elemType -> do
              (Type
_, [(String, TypeScheme)]
bindings1, Subst
s1) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
p1 Type
expectedType TypeErrorContext
ctx
              Type
elemType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
elemType
              (Type
_, [(String, TypeScheme)]
bindings2, Subst
s2) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
p2 Type
elemType' TypeErrorContext
ctx
              let s :: Subst
s = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
              Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
              (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType', [(String, TypeScheme)]
bindings1 [(String, TypeScheme)]
-> [(String, TypeScheme)] -> [(String, TypeScheme)]
forall a. [a] -> [a] -> [a]
++ [(String, TypeScheme)]
bindings2, Subst
s)
            
            TVar TyVar
_ -> do
              Type
elemTy <- String -> Infer Type
freshVar String
"elem"
              Subst
s <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
expectedType (Type -> Type
TCollection Type
elemTy) TypeErrorContext
ctx
              Type
collTy <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s (Type -> Type
TCollection Type
elemTy)
              Type
elemTy' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
elemTy
              (Type
_, [(String, TypeScheme)]
bindings1, Subst
s1) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
p1 Type
collTy TypeErrorContext
ctx
              Type
elemTy'' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
elemTy'
              (Type
_, [(String, TypeScheme)]
bindings2, Subst
s2) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
p2 Type
elemTy'' TypeErrorContext
ctx
              let s' :: Subst
s' = Subst -> Subst -> Subst
composeSubst Subst
s2 (Subst -> Subst -> Subst
composeSubst Subst
s1 Subst
s)
              Type
collTy' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s' Type
collTy
              (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
collTy', [(String, TypeScheme)]
bindings1 [(String, TypeScheme)]
-> [(String, TypeScheme)] -> [(String, TypeScheme)]
forall a. [a] -> [a] -> [a]
++ [(String, TypeScheme)]
bindings2, Subst
s')
            
            Type
_ -> do
              TypeError -> Infer (Type, [(String, TypeScheme)], Subst)
forall a. TypeError -> ExceptT TypeError (StateT InferState IO) a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (TypeError -> Infer (Type, [(String, TypeScheme)], Subst))
-> TypeError -> Infer (Type, [(String, TypeScheme)], Subst)
forall a b. (a -> b) -> a -> b
$ Type -> Type -> String -> TypeErrorContext -> TypeError
TE.TypeMismatch
                (Type -> Type
TCollection (TyVar -> Type
TVar (String -> TyVar
TyVar String
"a")))
                Type
expectedType
                String
"Snoc pattern but target is not a collection type"
                TypeErrorContext
ctx
        
        PDInductivePat String
name [IPrimitiveDataPattern]
pats -> do
          -- Inductive pattern: look up data constructor type from environment
          TypeEnv
env <- Infer TypeEnv
getEnv
          case Var -> TypeEnv -> Maybe TypeScheme
lookupEnv (String -> Var
stringToVar String
name) TypeEnv
env of
            Just TypeScheme
scheme -> do
              -- Found in environment: use the declared type
              InferState
st <- ExceptT TypeError (StateT InferState IO) InferState
forall s (m :: * -> *). MonadState s m => m s
get
              let ([Constraint]
_constraints, Type
ctorType, Int
newCounter) = TypeScheme -> Int -> ([Constraint], Type, Int)
instantiate TypeScheme
scheme (InferState -> Int
inferCounter InferState
st)
              (InferState -> InferState) -> Infer ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((InferState -> InferState) -> Infer ())
-> (InferState -> InferState) -> Infer ()
forall a b. (a -> b) -> a -> b
$ \InferState
s -> InferState
s { inferCounter = newCounter }
              
              -- Data constructor type: arg1 -> arg2 -> ... -> resultType
              let ([Type]
argTypes, Type
resultType) = Type -> ([Type], Type)
extractFunctionArgs Type
ctorType
              
              -- Check argument count matches
              if [Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
argTypes Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
/= [IPrimitiveDataPattern] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [IPrimitiveDataPattern]
pats
                then TypeError -> Infer (Type, [(String, TypeScheme)], Subst)
forall a. TypeError -> ExceptT TypeError (StateT InferState IO) a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (TypeError -> Infer (Type, [(String, TypeScheme)], Subst))
-> TypeError -> Infer (Type, [(String, TypeScheme)], Subst)
forall a b. (a -> b) -> a -> b
$ Type -> Type -> String -> TypeErrorContext -> TypeError
TE.TypeMismatch
                       ((Type -> Type -> Type) -> Type -> [Type] -> Type
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr Type -> Type -> Type
TFun Type
resultType (Int -> Type -> [Type]
forall a. Int -> a -> [a]
replicate ([IPrimitiveDataPattern] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [IPrimitiveDataPattern]
pats) (TyVar -> Type
TVar (String -> TyVar
TyVar String
"a"))))
                       Type
ctorType
                       (String
"Data constructor " String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
name String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" expects " String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show ([Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
argTypes) 
                        String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" arguments, but got " String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show ([IPrimitiveDataPattern] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [IPrimitiveDataPattern]
pats))
                       TypeErrorContext
ctx
                else do
                  -- Unify result type with expected type
                  Subst
s0 <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
resultType Type
expectedType TypeErrorContext
ctx
                  Type
resultType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s0 Type
resultType
                  [Type]
argTypes' <- (Type -> Infer Type)
-> [Type] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s0) [Type]
argTypes

                  -- Recursively infer each sub-pattern
                  [(Type, [(String, TypeScheme)], Subst)]
results <- (IPrimitiveDataPattern
 -> Type -> Infer (Type, [(String, TypeScheme)], Subst))
-> [IPrimitiveDataPattern]
-> [Type]
-> ExceptT
     TypeError
     (StateT InferState IO)
     [(Type, [(String, TypeScheme)], Subst)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (\IPrimitiveDataPattern
p Type
argTy -> IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
p Type
argTy TypeErrorContext
ctx) [IPrimitiveDataPattern]
pats [Type]
argTypes'
                  let ([Type]
_, [[(String, TypeScheme)]]
bindingsList, [Subst]
substs) = [(Type, [(String, TypeScheme)], Subst)]
-> ([Type], [[(String, TypeScheme)]], [Subst])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(Type, [(String, TypeScheme)], Subst)]
results
                      allBindings :: [(String, TypeScheme)]
allBindings = [[(String, TypeScheme)]] -> [(String, TypeScheme)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[(String, TypeScheme)]]
bindingsList
                      s :: Subst
s = (Subst -> Subst -> Subst) -> Subst -> [Subst] -> Subst
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr Subst -> Subst -> Subst
composeSubst Subst
s0 [Subst]
substs

                  -- Return the result type, not expected type
                  Type
resultType'' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
resultType'
                  (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
resultType'', [(String, TypeScheme)]
allBindings, Subst
s)
            
            Maybe TypeScheme
Nothing -> do
              -- Not found in environment: use generic inference
              [Type]
argTypes <- (IPrimitiveDataPattern -> Infer Type)
-> [IPrimitiveDataPattern]
-> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\IPrimitiveDataPattern
_ -> String -> Infer Type
freshVar String
"arg") [IPrimitiveDataPattern]
pats
              let resultType :: Type
resultType = String -> [Type] -> Type
TInductive String
name [Type]
argTypes

              Subst
s0 <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
resultType Type
expectedType TypeErrorContext
ctx
              Type
resultType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s0 Type
resultType

              [Type]
argTypes' <- (Type -> Infer Type)
-> [Type] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s0) [Type]
argTypes
              [(Type, [(String, TypeScheme)], Subst)]
results <- (IPrimitiveDataPattern
 -> Type -> Infer (Type, [(String, TypeScheme)], Subst))
-> [IPrimitiveDataPattern]
-> [Type]
-> ExceptT
     TypeError
     (StateT InferState IO)
     [(Type, [(String, TypeScheme)], Subst)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (\IPrimitiveDataPattern
p Type
argTy -> IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
p Type
argTy TypeErrorContext
ctx) [IPrimitiveDataPattern]
pats [Type]
argTypes'
              let ([Type]
_, [[(String, TypeScheme)]]
bindingsList, [Subst]
substs) = [(Type, [(String, TypeScheme)], Subst)]
-> ([Type], [[(String, TypeScheme)]], [Subst])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(Type, [(String, TypeScheme)], Subst)]
results
                  allBindings :: [(String, TypeScheme)]
allBindings = [[(String, TypeScheme)]] -> [(String, TypeScheme)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[(String, TypeScheme)]]
bindingsList
                  s :: Subst
s = (Subst -> Subst -> Subst) -> Subst -> [Subst] -> Subst
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr Subst -> Subst -> Subst
composeSubst Subst
s0 [Subst]
substs

              Type
resultType'' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
resultType'
              (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
resultType'', [(String, TypeScheme)]
allBindings, Subst
s)
        
        -- ScalarData (MathExpr) primitive patterns
        PDDivPat IPrimitiveDataPattern
patNum IPrimitiveDataPattern
patDen -> do
          -- Div: MathExpr -> PolyExpr, PolyExpr
          -- However, if pattern is a pattern variable, it gets MathExpr (auto-conversion)
          let polyExprTy :: Type
polyExprTy = Type
TPolyExpr
              mathExprTy :: Type
mathExprTy = Type
TMathExpr
              numTy :: Type
numTy = if IPrimitiveDataPattern -> Bool
isPDPatVar IPrimitiveDataPattern
patNum then Type
mathExprTy else Type
polyExprTy
              denTy :: Type
denTy = if IPrimitiveDataPattern -> Bool
isPDPatVar IPrimitiveDataPattern
patDen then Type
mathExprTy else Type
polyExprTy
          (Type
_, [(String, TypeScheme)]
bindings1, Subst
s1) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patNum Type
numTy TypeErrorContext
ctx
          Type
denTy' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
denTy
          (Type
_, [(String, TypeScheme)]
bindings2, Subst
s2) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patDen Type
denTy' TypeErrorContext
ctx
          let s :: Subst
s = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
          Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
          (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType', [(String, TypeScheme)]
bindings1 [(String, TypeScheme)]
-> [(String, TypeScheme)] -> [(String, TypeScheme)]
forall a. [a] -> [a] -> [a]
++ [(String, TypeScheme)]
bindings2, Subst
s)
        
        PDPlusPat IPrimitiveDataPattern
patTerms -> do
          -- Plus: PolyExpr -> [TermExpr]
          -- If pattern variable, it gets [MathExpr]
          let termExprTy :: Type
termExprTy = Type
TTermExpr
              mathExprTy :: Type
mathExprTy = Type
TMathExpr
              termsTy :: Type
termsTy = if IPrimitiveDataPattern -> Bool
isPDPatVar IPrimitiveDataPattern
patTerms then Type -> Type
TCollection Type
mathExprTy else Type -> Type
TCollection Type
termExprTy
          (Type
_, [(String, TypeScheme)]
bindings, Subst
s) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patTerms Type
termsTy TypeErrorContext
ctx
          Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
          (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType', [(String, TypeScheme)]
bindings, Subst
s)
        
        PDTermPat IPrimitiveDataPattern
patCoeff IPrimitiveDataPattern
patMonomials -> do
          -- Term: TermExpr -> Integer, [(SymbolExpr, Integer)]
          -- If patMonomials is pattern variable, it gets [(MathExpr, Integer)]
          let symbolExprTy :: Type
symbolExprTy = Type
TSymbolExpr
              mathExprTy :: Type
mathExprTy = Type
TMathExpr
              monomialsElemTy :: Type
monomialsElemTy = if IPrimitiveDataPattern -> Bool
isPDPatVar IPrimitiveDataPattern
patMonomials
                                then [Type] -> Type
TTuple [Type
mathExprTy, Type
TInt]
                                else [Type] -> Type
TTuple [Type
symbolExprTy, Type
TInt]
          (Type
_, [(String, TypeScheme)]
bindings1, Subst
s1) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patCoeff Type
TInt TypeErrorContext
ctx
          Type
monomialsCollTy <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 (Type -> Type
TCollection Type
monomialsElemTy)
          (Type
_, [(String, TypeScheme)]
bindings2, Subst
s2) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patMonomials Type
monomialsCollTy TypeErrorContext
ctx
          let s :: Subst
s = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
          Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
          (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType', [(String, TypeScheme)]
bindings1 [(String, TypeScheme)]
-> [(String, TypeScheme)] -> [(String, TypeScheme)]
forall a. [a] -> [a] -> [a]
++ [(String, TypeScheme)]
bindings2, Subst
s)
        
        PDSymbolPat IPrimitiveDataPattern
patName IPrimitiveDataPattern
patIndices -> do
          -- Symbol: SymbolExpr -> String, [IndexExpr]
          -- patName and patIndices types don't change for pattern variables
          let indexExprTy :: Type
indexExprTy = Type
TIndexExpr
          (Type
_, [(String, TypeScheme)]
bindings1, Subst
s1) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patName Type
TString TypeErrorContext
ctx
          Type
indicesCollTy <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 (Type -> Type
TCollection Type
indexExprTy)
          (Type
_, [(String, TypeScheme)]
bindings2, Subst
s2) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patIndices Type
indicesCollTy TypeErrorContext
ctx
          let s :: Subst
s = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
          Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
          (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType', [(String, TypeScheme)]
bindings1 [(String, TypeScheme)]
-> [(String, TypeScheme)] -> [(String, TypeScheme)]
forall a. [a] -> [a] -> [a]
++ [(String, TypeScheme)]
bindings2, Subst
s)
        
        PDApply1Pat IPrimitiveDataPattern
patFn IPrimitiveDataPattern
patArg -> do
          -- Apply1: SymbolExpr -> (MathExpr -> MathExpr), MathExpr
          let mathExprTy :: Type
mathExprTy = Type
TMathExpr
              fnTy :: Type
fnTy = Type -> Type -> Type
TFun Type
mathExprTy Type
mathExprTy
          (Type
_, [(String, TypeScheme)]
bindings1, Subst
s1) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patFn Type
fnTy TypeErrorContext
ctx
          Type
mathExprTy' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
mathExprTy
          (Type
_, [(String, TypeScheme)]
bindings2, Subst
s2) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patArg Type
mathExprTy' TypeErrorContext
ctx
          let s :: Subst
s = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
          Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
          (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType', [(String, TypeScheme)]
bindings1 [(String, TypeScheme)]
-> [(String, TypeScheme)] -> [(String, TypeScheme)]
forall a. [a] -> [a] -> [a]
++ [(String, TypeScheme)]
bindings2, Subst
s)
        
        PDApply2Pat IPrimitiveDataPattern
patFn IPrimitiveDataPattern
patArg1 IPrimitiveDataPattern
patArg2 -> do
          let mathExprTy :: Type
mathExprTy = Type
TMathExpr
              fnTy :: Type
fnTy = Type -> Type -> Type
TFun Type
mathExprTy (Type -> Type -> Type
TFun Type
mathExprTy Type
mathExprTy)
          (Type
_, [(String, TypeScheme)]
bindings1, Subst
s1) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patFn Type
fnTy TypeErrorContext
ctx
          Type
mathExprTy1 <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
mathExprTy
          (Type
_, [(String, TypeScheme)]
bindings2, Subst
s2) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patArg1 Type
mathExprTy1 TypeErrorContext
ctx
          Type
mathExprTy2 <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s2 Type
mathExprTy
          (Type
_, [(String, TypeScheme)]
bindings3, Subst
s3) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patArg2 Type
mathExprTy2 TypeErrorContext
ctx
          let s :: Subst
s = Subst -> Subst -> Subst
composeSubst Subst
s3 (Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1)
          Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
          (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType', [(String, TypeScheme)]
bindings1 [(String, TypeScheme)]
-> [(String, TypeScheme)] -> [(String, TypeScheme)]
forall a. [a] -> [a] -> [a]
++ [(String, TypeScheme)]
bindings2 [(String, TypeScheme)]
-> [(String, TypeScheme)] -> [(String, TypeScheme)]
forall a. [a] -> [a] -> [a]
++ [(String, TypeScheme)]
bindings3, Subst
s)
        
        PDApply3Pat IPrimitiveDataPattern
patFn IPrimitiveDataPattern
patArg1 IPrimitiveDataPattern
patArg2 IPrimitiveDataPattern
patArg3 -> do
          let mathExprTy :: Type
mathExprTy = Type
TMathExpr
              fnTy :: Type
fnTy = Type -> Type -> Type
TFun Type
mathExprTy (Type -> Type -> Type
TFun Type
mathExprTy (Type -> Type -> Type
TFun Type
mathExprTy Type
mathExprTy))
          (Type
_, [(String, TypeScheme)]
bindings1, Subst
s1) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patFn Type
fnTy TypeErrorContext
ctx
          Type
mathExprTy1 <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
mathExprTy
          (Type
_, [(String, TypeScheme)]
bindings2, Subst
s2) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patArg1 Type
mathExprTy1 TypeErrorContext
ctx
          Type
mathExprTy2 <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s2 Type
mathExprTy
          (Type
_, [(String, TypeScheme)]
bindings3, Subst
s3) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patArg2 Type
mathExprTy2 TypeErrorContext
ctx
          Type
mathExprTy3 <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s3 Type
mathExprTy
          (Type
_, [(String, TypeScheme)]
bindings4, Subst
s4) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patArg3 Type
mathExprTy3 TypeErrorContext
ctx
          let s :: Subst
s = Subst -> Subst -> Subst
composeSubst Subst
s4 (Subst -> Subst -> Subst
composeSubst Subst
s3 (Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1))
          Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
          (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType', [(String, TypeScheme)]
bindings1 [(String, TypeScheme)]
-> [(String, TypeScheme)] -> [(String, TypeScheme)]
forall a. [a] -> [a] -> [a]
++ [(String, TypeScheme)]
bindings2 [(String, TypeScheme)]
-> [(String, TypeScheme)] -> [(String, TypeScheme)]
forall a. [a] -> [a] -> [a]
++ [(String, TypeScheme)]
bindings3 [(String, TypeScheme)]
-> [(String, TypeScheme)] -> [(String, TypeScheme)]
forall a. [a] -> [a] -> [a]
++ [(String, TypeScheme)]
bindings4, Subst
s)
        
        PDApply4Pat IPrimitiveDataPattern
patFn IPrimitiveDataPattern
patArg1 IPrimitiveDataPattern
patArg2 IPrimitiveDataPattern
patArg3 IPrimitiveDataPattern
patArg4 -> do
          let mathExprTy :: Type
mathExprTy = Type
TMathExpr
              fnTy :: Type
fnTy = Type -> Type -> Type
TFun Type
mathExprTy (Type -> Type -> Type
TFun Type
mathExprTy (Type -> Type -> Type
TFun Type
mathExprTy (Type -> Type -> Type
TFun Type
mathExprTy Type
mathExprTy)))
          (Type
_, [(String, TypeScheme)]
bindings1, Subst
s1) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patFn Type
fnTy TypeErrorContext
ctx
          Type
mathExprTy1 <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
mathExprTy
          (Type
_, [(String, TypeScheme)]
bindings2, Subst
s2) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patArg1 Type
mathExprTy1 TypeErrorContext
ctx
          Type
mathExprTy2 <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s2 Type
mathExprTy
          (Type
_, [(String, TypeScheme)]
bindings3, Subst
s3) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patArg2 Type
mathExprTy2 TypeErrorContext
ctx
          Type
mathExprTy3 <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s3 Type
mathExprTy
          (Type
_, [(String, TypeScheme)]
bindings4, Subst
s4) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patArg3 Type
mathExprTy3 TypeErrorContext
ctx
          Type
mathExprTy4 <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s4 Type
mathExprTy
          (Type
_, [(String, TypeScheme)]
bindings5, Subst
s5) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patArg4 Type
mathExprTy4 TypeErrorContext
ctx
          let s :: Subst
s = Subst -> Subst -> Subst
composeSubst Subst
s5 (Subst -> Subst -> Subst
composeSubst Subst
s4 (Subst -> Subst -> Subst
composeSubst Subst
s3 (Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1)))
          Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
          (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType', [(String, TypeScheme)]
bindings1 [(String, TypeScheme)]
-> [(String, TypeScheme)] -> [(String, TypeScheme)]
forall a. [a] -> [a] -> [a]
++ [(String, TypeScheme)]
bindings2 [(String, TypeScheme)]
-> [(String, TypeScheme)] -> [(String, TypeScheme)]
forall a. [a] -> [a] -> [a]
++ [(String, TypeScheme)]
bindings3 [(String, TypeScheme)]
-> [(String, TypeScheme)] -> [(String, TypeScheme)]
forall a. [a] -> [a] -> [a]
++ [(String, TypeScheme)]
bindings4 [(String, TypeScheme)]
-> [(String, TypeScheme)] -> [(String, TypeScheme)]
forall a. [a] -> [a] -> [a]
++ [(String, TypeScheme)]
bindings5, Subst
s)
        
        PDQuotePat IPrimitiveDataPattern
patExpr -> do
          -- Quote: SymbolExpr -> MathExpr
          let mathExprTy :: Type
mathExprTy = Type
TMathExpr
          (Type
_, [(String, TypeScheme)]
bindings, Subst
s) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patExpr Type
mathExprTy TypeErrorContext
ctx
          Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
          (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType', [(String, TypeScheme)]
bindings, Subst
s)
        
        PDFunctionPat IPrimitiveDataPattern
patName IPrimitiveDataPattern
patArgs -> do
          -- Function: SymbolExpr -> MathExpr, [MathExpr]
          let mathExprTy :: Type
mathExprTy = Type
TMathExpr
          (Type
_, [(String, TypeScheme)]
bindings1, Subst
s1) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patName Type
mathExprTy TypeErrorContext
ctx
          Type
argsCollTy <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 (Type -> Type
TCollection Type
mathExprTy)
          (Type
_, [(String, TypeScheme)]
bindings2, Subst
s2) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patArgs Type
argsCollTy TypeErrorContext
ctx
          Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s2 Type
expectedType
          (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType', [(String, TypeScheme)]
bindings1 [(String, TypeScheme)]
-> [(String, TypeScheme)] -> [(String, TypeScheme)]
forall a. [a] -> [a] -> [a]
++ [(String, TypeScheme)]
bindings2, Subst
s2)
        
        PDSubPat IPrimitiveDataPattern
patExpr -> do
          -- Sub: IndexExpr -> MathExpr
          let mathExprTy :: Type
mathExprTy = Type
TMathExpr
          (Type
_, [(String, TypeScheme)]
bindings, Subst
s) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patExpr Type
mathExprTy TypeErrorContext
ctx
          Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
          (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType', [(String, TypeScheme)]
bindings, Subst
s)

        PDSupPat IPrimitiveDataPattern
patExpr -> do
          -- Sup: IndexExpr -> MathExpr
          let mathExprTy :: Type
mathExprTy = Type
TMathExpr
          (Type
_, [(String, TypeScheme)]
bindings, Subst
s) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patExpr Type
mathExprTy TypeErrorContext
ctx
          Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
          (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType', [(String, TypeScheme)]
bindings, Subst
s)
        
        PDUserPat IPrimitiveDataPattern
patExpr -> do
          -- User: IndexExpr -> MathExpr
          let mathExprTy :: Type
mathExprTy = Type
TMathExpr
          (Type
_, [(String, TypeScheme)]
bindings, Subst
s) <- IPrimitiveDataPattern
-> Type
-> TypeErrorContext
-> Infer (Type, [(String, TypeScheme)], Subst)
inferPrimitiveDataPattern IPrimitiveDataPattern
patExpr Type
mathExprTy TypeErrorContext
ctx
          Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
          (Type, [(String, TypeScheme)], Subst)
-> Infer (Type, [(String, TypeScheme)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
expectedType', [(String, TypeScheme)]
bindings, Subst
s)
  
  -- Match expressions (pattern matching)
  IMatchExpr PMMode
mode IExpr
target IExpr
matcher [IMatchClause]
clauses -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    (TIExpr
targetTI, Subst
s1) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
target TypeErrorContext
exprCtx
    (TIExpr
matcherTI, Subst
s2) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
matcher TypeErrorContext
exprCtx
    let targetType :: Type
targetType = TIExpr -> Type
tiExprType TIExpr
targetTI
        matcherType :: Type
matcherType = TIExpr -> Type
tiExprType TIExpr
matcherTI

    -- Matcher should be TMatcher a or (TMatcher a, TMatcher b, ...) which becomes TMatcher (a, b, ...)
    let s12 :: Subst
s12 = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
    Type
appliedMatcherType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s12 Type
matcherType

    -- Normalize matcher type: if it's a tuple, ensure each element is a Matcher
    (Type
_normalizedMatcherType, Type
matchedInnerType, Subst
s3) <- case Type
appliedMatcherType of
      TTuple [Type]
elemTypes -> do
        -- Each tuple element should be Matcher ai
        [Type]
matchedInnerTypes <- (Type -> Infer Type)
-> [Type] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\Type
_ -> String -> Infer Type
freshVar String
"matched") [Type]
elemTypes
        Subst
s_elems <- (Subst -> (Type, Type) -> Infer Subst)
-> Subst -> [(Type, Type)] -> Infer Subst
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (\Subst
accS (Type
elemTy, Type
innerTy) -> do
          Type
appliedElemTy <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
accS Type
elemTy
          Type
appliedInnerTy <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
accS Type
innerTy
          Subst
s' <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
appliedElemTy (Type -> Type
TMatcher Type
appliedInnerTy) TypeErrorContext
exprCtx
          Subst -> Infer Subst
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Subst -> Infer Subst) -> Subst -> Infer Subst
forall a b. (a -> b) -> a -> b
$ Subst -> Subst -> Subst
composeSubst Subst
s' Subst
accS
          ) Subst
emptySubst ([Type] -> [Type] -> [(Type, Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Type]
elemTypes [Type]
matchedInnerTypes)
        -- The tuple as a whole becomes Matcher (a1, a2, ...)
        [Type]
finalInnerTypes <- (Type -> Infer Type)
-> [Type] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s_elems) [Type]
matchedInnerTypes
        let tupleInnerType :: Type
tupleInnerType = [Type] -> Type
TTuple [Type]
finalInnerTypes
        (Type, Type, Subst)
-> ExceptT TypeError (StateT InferState IO) (Type, Type, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> Type
TMatcher Type
tupleInnerType, Type
tupleInnerType, Subst
s_elems)
      Type
_ -> do
        -- Single matcher: TMatcher a
        Type
matchedTy <- String -> Infer Type
freshVar String
"matched"
        Subst
s' <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
appliedMatcherType (Type -> Type
TMatcher Type
matchedTy) TypeErrorContext
exprCtx
        Type
finalMatchedTy <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s' Type
matchedTy
        (Type, Type, Subst)
-> ExceptT TypeError (StateT InferState IO) (Type, Type, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> Type
TMatcher Type
finalMatchedTy, Type
finalMatchedTy, Subst
s')

    let s123 :: Subst
s123 = Subst -> Subst -> Subst
composeSubst Subst
s3 Subst
s12
    Type
targetType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s123 Type
targetType
    Type
matchedInnerType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s123 Type
matchedInnerType
    Subst
s4 <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
targetType' Type
matchedInnerType' TypeErrorContext
exprCtx
    
    -- Infer match clauses result type
    let s1234 :: Subst
s1234 = Subst -> Subst -> Subst
composeSubst Subst
s4 Subst
s123
    case [IMatchClause]
clauses of
      [] -> do
        -- No clauses: this should not happen, but handle gracefully
        Type
resultTy <- String -> Infer Type
freshVar String
"matchResult"
        TIExpr
targetTI' <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
s1234 TIExpr
targetTI
        TIExpr
matcherTI' <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
s1234 TIExpr
matcherTI
        Type
resultTy' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1234 Type
resultTy
        (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
resultTy' (PMMode -> TIExpr -> TIExpr -> [TIMatchClause] -> TIExprNode
TIMatchExpr PMMode
mode TIExpr
targetTI' TIExpr
matcherTI' []), Subst
s1234)
      [IMatchClause]
_ -> do
        -- Infer type of each clause and unify them
        Type
matchedInnerType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1234 Type
matchedInnerType
        (Type
resultTy, [TIMatchClause]
clauseTIs, Subst
clauseSubst) <- TypeErrorContext
-> Type
-> [IMatchClause]
-> Subst
-> Infer (Type, [TIMatchClause], Subst)
inferMatchClauses TypeErrorContext
exprCtx Type
matchedInnerType' [IMatchClause]
clauses Subst
s1234
        let finalS :: Subst
finalS = Subst -> Subst -> Subst
composeSubst Subst
clauseSubst Subst
s1234
        TIExpr
targetTI' <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
finalS TIExpr
targetTI
        TIExpr
matcherTI' <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
finalS TIExpr
matcherTI
        Type
resultTy' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
finalS Type
resultTy
        (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
resultTy' (PMMode -> TIExpr -> TIExpr -> [TIMatchClause] -> TIExprNode
TIMatchExpr PMMode
mode TIExpr
targetTI' TIExpr
matcherTI' [TIMatchClause]
clauseTIs), Subst
finalS)
  
  -- MatchAll expressions
  IMatchAllExpr PMMode
mode IExpr
target IExpr
matcher [IMatchClause]
clauses -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    (TIExpr
targetTI, Subst
s1) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
target TypeErrorContext
exprCtx
    (TIExpr
matcherTI, Subst
s2) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
matcher TypeErrorContext
exprCtx
    let targetType :: Type
targetType = TIExpr -> Type
tiExprType TIExpr
targetTI
        matcherType :: Type
matcherType = TIExpr -> Type
tiExprType TIExpr
matcherTI
    
    -- Matcher should be TMatcher a or (TMatcher a, TMatcher b, ...) which becomes TMatcher (a, b, ...)
    let s12 :: Subst
s12 = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
    Type
appliedMatcherType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s12 Type
matcherType
    
    -- Normalize matcher type: if it's a tuple, ensure each element is a Matcher
    (Type
_normalizedMatcherType, Type
matchedInnerType, Subst
s3) <- case Type
appliedMatcherType of
      TTuple [Type]
elemTypes -> do
        -- Each tuple element should be Matcher ai
        [Type]
matchedInnerTypes <- (Type -> Infer Type)
-> [Type] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\Type
_ -> String -> Infer Type
freshVar String
"matched") [Type]
elemTypes
        Subst
s_elems <- (Subst -> (Type, Type) -> Infer Subst)
-> Subst -> [(Type, Type)] -> Infer Subst
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (\Subst
accS (Type
elemTy, Type
innerTy) -> do
          Type
appliedElemTy <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
accS Type
elemTy
          Type
appliedInnerTy <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
accS Type
innerTy
          Subst
s' <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
appliedElemTy (Type -> Type
TMatcher Type
appliedInnerTy) TypeErrorContext
exprCtx
          Subst -> Infer Subst
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Subst -> Infer Subst) -> Subst -> Infer Subst
forall a b. (a -> b) -> a -> b
$ Subst -> Subst -> Subst
composeSubst Subst
s' Subst
accS
          ) Subst
emptySubst ([Type] -> [Type] -> [(Type, Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Type]
elemTypes [Type]
matchedInnerTypes)
        -- The tuple as a whole becomes Matcher (a1, a2, ...)
        [Type]
finalInnerTypes <- (Type -> Infer Type)
-> [Type] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s_elems) [Type]
matchedInnerTypes
        let tupleInnerType :: Type
tupleInnerType = [Type] -> Type
TTuple [Type]
finalInnerTypes
        (Type, Type, Subst)
-> ExceptT TypeError (StateT InferState IO) (Type, Type, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> Type
TMatcher Type
tupleInnerType, Type
tupleInnerType, Subst
s_elems)
      Type
_ -> do
        -- Single matcher: TMatcher a
        Type
matchedTy <- String -> Infer Type
freshVar String
"matched"
        Subst
s' <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
appliedMatcherType (Type -> Type
TMatcher Type
matchedTy) TypeErrorContext
exprCtx
        Type
finalMatchedTy <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s' Type
matchedTy
        (Type, Type, Subst)
-> ExceptT TypeError (StateT InferState IO) (Type, Type, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> Type
TMatcher Type
finalMatchedTy, Type
finalMatchedTy, Subst
s')

    let s123 :: Subst
s123 = Subst -> Subst -> Subst
composeSubst Subst
s3 Subst
s12
    Type
targetType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s123 Type
targetType
    Type
matchedInnerType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s123 Type
matchedInnerType
    Subst
s4 <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
targetType' Type
matchedInnerType' TypeErrorContext
exprCtx
    
    -- MatchAll returns a collection of results from match clauses
    let s1234 :: Subst
s1234 = Subst -> Subst -> Subst
composeSubst Subst
s4 Subst
s123
    case [IMatchClause]
clauses of
      [] -> do
        -- No clauses: return empty collection type
        Type
resultElemTy <- String -> Infer Type
freshVar String
"matchAllElem"
        TIExpr
targetTI' <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
s1234 TIExpr
targetTI
        TIExpr
matcherTI' <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
s1234 TIExpr
matcherTI
        Type
resultElemTy' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1234 Type
resultElemTy
        (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr (Type -> Type
TCollection Type
resultElemTy') (PMMode -> TIExpr -> TIExpr -> [TIMatchClause] -> TIExprNode
TIMatchAllExpr PMMode
mode TIExpr
targetTI' TIExpr
matcherTI' []), Subst
s1234)
      [IMatchClause]
_ -> do
        -- Infer type of each clause (they should all have the same type)
        Type
matchedInnerType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1234 Type
matchedInnerType
        (Type
resultElemTy, [TIMatchClause]
clauseTIs, Subst
clauseSubst) <- TypeErrorContext
-> Type
-> [IMatchClause]
-> Subst
-> Infer (Type, [TIMatchClause], Subst)
inferMatchClauses TypeErrorContext
exprCtx Type
matchedInnerType' [IMatchClause]
clauses Subst
s1234
        let finalS :: Subst
finalS = Subst -> Subst -> Subst
composeSubst Subst
clauseSubst Subst
s1234
        TIExpr
targetTI' <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
finalS TIExpr
targetTI
        TIExpr
matcherTI' <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
finalS TIExpr
matcherTI
        Type
resultElemTy' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
finalS Type
resultElemTy
        (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr (Type -> Type
TCollection Type
resultElemTy') (PMMode -> TIExpr -> TIExpr -> [TIMatchClause] -> TIExprNode
TIMatchAllExpr PMMode
mode TIExpr
targetTI' TIExpr
matcherTI' [TIMatchClause]
clauseTIs), Subst
finalS)
  
  -- Memoized Lambda
  IMemoizedLambdaExpr [String]
args IExpr
body -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    [Type]
argTypes <- (String -> Infer Type)
-> [String] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\String
_ -> String -> Infer Type
freshVar String
"memoArg") [String]
args
    let bindings :: [(String, Type)]
bindings = [String] -> [Type] -> [(String, Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [String]
args [Type]
argTypes  -- [(String, Type)]
        schemes :: [(String, TypeScheme)]
schemes = ((String, Type) -> (String, TypeScheme))
-> [(String, Type)] -> [(String, TypeScheme)]
forall a b. (a -> b) -> [a] -> [b]
map (\(String
name, Type
t) -> (String
name, [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
t)) [(String, Type)]
bindings
    (TIExpr
bodyTI, Subst
s) <- [(String, TypeScheme)]
-> Infer (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. [(String, TypeScheme)] -> Infer a -> Infer a
withEnv [(String, TypeScheme)]
schemes (Infer (TIExpr, Subst) -> Infer (TIExpr, Subst))
-> Infer (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a b. (a -> b) -> a -> b
$ IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
body TypeErrorContext
exprCtx
    let bodyType :: Type
bodyType = TIExpr -> Type
tiExprType TIExpr
bodyTI
    [Type]
finalArgTypes <- (Type -> Infer Type)
-> [Type] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s) [Type]
argTypes
    let funType :: Type
funType = (Type -> Type -> Type) -> Type -> [Type] -> Type
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr Type -> Type -> Type
TFun Type
bodyType [Type]
finalArgTypes
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
funType ([String] -> TIExpr -> TIExprNode
TIMemoizedLambdaExpr [String]
args TIExpr
bodyTI), Subst
s)
  
  -- Do expression
  IDoExpr [IBindingExpr]
bindings IExpr
body -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    -- Infer IO monad bindings: each binding should be of type IO a
    TypeEnv
env <- Infer TypeEnv
getEnv
    ([TIBindingExpr]
bindingTIs, [(String, TypeScheme)]
bindingSchemes, Subst
s1) <- [IBindingExpr]
-> TypeEnv
-> Subst
-> TypeErrorContext
-> Infer ([TIBindingExpr], [(String, TypeScheme)], Subst)
inferIOBindingsWithContext [IBindingExpr]
bindings TypeEnv
env Subst
emptySubst TypeErrorContext
exprCtx
    (TIExpr
bodyTI, Subst
s2) <- [(String, TypeScheme)]
-> Infer (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. [(String, TypeScheme)] -> Infer a -> Infer a
withEnv [(String, TypeScheme)]
bindingSchemes (Infer (TIExpr, Subst) -> Infer (TIExpr, Subst))
-> Infer (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a b. (a -> b) -> a -> b
$ IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
body TypeErrorContext
exprCtx
    let bodyType :: Type
bodyType = TIExpr -> Type
tiExprType TIExpr
bodyTI
        finalS :: Subst
finalS = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
        
    -- Verify that body type is IO a
    Type
bodyResultType <- String -> Infer Type
freshVar String
"ioResult"
    Type
bodyType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
finalS Type
bodyType
    Subst
s3 <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
bodyType' (Type -> Type
TIO Type
bodyResultType) TypeErrorContext
exprCtx
    Type
resultType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s3 (Type -> Type
TIO Type
bodyResultType)
    let finalS' :: Subst
finalS' = Subst -> Subst -> Subst
composeSubst Subst
s3 Subst
finalS
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
resultType ([TIBindingExpr] -> TIExpr -> TIExprNode
TIDoExpr [TIBindingExpr]
bindingTIs TIExpr
bodyTI), Subst
finalS')
  
  -- Cambda (pattern matching lambda)
  ICambdaExpr String
var IExpr
body -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    Type
argType <- String -> Infer Type
freshVar String
"cambdaArg"
    (TIExpr
bodyTI, Subst
s) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
body TypeErrorContext
exprCtx
    let bodyType :: Type
bodyType = TIExpr -> Type
tiExprType TIExpr
bodyTI
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr (Type -> Type -> Type
TFun Type
argType Type
bodyType) (String -> TIExpr -> TIExprNode
TICambdaExpr String
var TIExpr
bodyTI), Subst
s)
  
  -- With symbols
  IWithSymbolsExpr [String]
syms IExpr
body -> do
    -- Add symbols to type environment as MathExpr (TMathExpr = TInt)
    -- Symbols introduced by withSymbols are mathematical symbols
    let symbolBindings :: [(String, TypeScheme)]
symbolBindings = [(String
sym, [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
TMathExpr) | String
sym <- [String]
syms]
    (TIExpr
bodyTI, Subst
s) <- [(String, TypeScheme)]
-> Infer (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. [(String, TypeScheme)] -> Infer a -> Infer a
withEnv [(String, TypeScheme)]
symbolBindings (Infer (TIExpr, Subst) -> Infer (TIExpr, Subst))
-> Infer (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a b. (a -> b) -> a -> b
$ IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
body TypeErrorContext
ctx
    let bodyType :: Type
bodyType = TIExpr -> Type
tiExprType TIExpr
bodyTI
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
bodyType ([String] -> TIExpr -> TIExprNode
TIWithSymbolsExpr [String]
syms TIExpr
bodyTI), Subst
s)
  
  -- Quote expressions (symbolic math)
  IQuoteExpr IExpr
e -> do
    (TIExpr
eTI, Subst
s) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
e TypeErrorContext
ctx
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
TInt (TIExpr -> TIExprNode
TIQuoteExpr TIExpr
eTI), Subst
s)
  IQuoteSymbolExpr IExpr
e -> do
    (TIExpr
eTI, Subst
s) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
e TypeErrorContext
ctx
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr (TIExpr -> Type
tiExprType TIExpr
eTI) (TIExpr -> TIExprNode
TIQuoteSymbolExpr TIExpr
eTI), Subst
s)
  
  -- Indexed expression (tensor indexing)
  IIndexedExpr Bool
override IExpr
baseExpr [Index IExpr]
indices -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    -- Special handling for IVarExpr: lookup with Var including index info
    -- Use the same strategy as refVar in Data.hs (Core.hs:235)
    (TIExpr
baseTI, Subst
s) <- case IExpr
baseExpr of
      IVarExpr String
varName -> do
        -- Convert indices to index types (structure only, no content)
        -- Like: map (fmap (const Nothing)) indices in Core.hs
        let indexTypes :: [Index (Maybe a)]
indexTypes = (Index IExpr -> Index (Maybe a))
-> [Index IExpr] -> [Index (Maybe a)]
forall a b. (a -> b) -> [a] -> [b]
map ((IExpr -> Maybe a) -> Index IExpr -> Index (Maybe a)
forall a b. (a -> b) -> Index a -> Index b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Maybe a -> IExpr -> Maybe a
forall a b. a -> b -> a
const Maybe a
forall a. Maybe a
Nothing)) [Index IExpr]
indices
            varWithIndices :: Var
varWithIndices = String -> [Index (Maybe Var)] -> Var
Var String
varName [Index (Maybe Var)]
forall {a}. [Index (Maybe a)]
indexTypes
        TypeEnv
env <- Infer TypeEnv
getEnv
        -- lookupEnv will try: Var "e" [Sub Nothing, Sub Nothing]
        --                 -> Var "e" [Sub Nothing]
        --                 -> Var "e" []
        case Var -> TypeEnv -> Maybe TypeScheme
lookupEnv Var
varWithIndices TypeEnv
env of
          Just TypeScheme
scheme -> do
            InferState
st <- ExceptT TypeError (StateT InferState IO) InferState
forall s (m :: * -> *). MonadState s m => m s
get
            let ([Constraint]
constraints, Type
t, Int
newCounter) = TypeScheme -> Int -> ([Constraint], Type, Int)
instantiate TypeScheme
scheme (InferState -> Int
inferCounter InferState
st)
            (InferState -> InferState) -> Infer ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((InferState -> InferState) -> Infer ())
-> (InferState -> InferState) -> Infer ()
forall a b. (a -> b) -> a -> b
$ \InferState
s' -> InferState
s' { inferCounter = newCounter }
            [Constraint] -> Infer ()
addConstraints [Constraint]
constraints
            (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TypeScheme -> TIExprNode -> TIExpr
TIExpr ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [Constraint]
constraints Type
t) (String -> TIExprNode
TIVarExpr String
varName), Subst
emptySubst)
          Maybe TypeScheme
Nothing -> do
            -- No variable found in type environment - fall back to normal inference
            -- This is necessary for lambda parameters, let-bound variables, etc.
            IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
baseExpr TypeErrorContext
exprCtx
      IExpr
_ -> IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
baseExpr TypeErrorContext
exprCtx
    let baseType :: Type
baseType = TIExpr -> Type
tiExprType TIExpr
baseTI
    -- Infer indices as TIExpr
    [Index TIExpr]
indicesTI <- (Index IExpr
 -> ExceptT TypeError (StateT InferState IO) (Index TIExpr))
-> [Index IExpr]
-> ExceptT TypeError (StateT InferState IO) [Index TIExpr]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ((IExpr -> Infer TIExpr)
-> Index IExpr
-> ExceptT TypeError (StateT InferState IO) (Index TIExpr)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Index a -> f (Index b)
traverse (\IExpr
idxExpr -> do
      (TIExpr
idxTI, Subst
_) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
idxExpr TypeErrorContext
exprCtx
      TIExpr -> Infer TIExpr
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return TIExpr
idxTI)) [Index IExpr]
indices
    -- Check if all indices are concrete (constants) or symbolic (variables)
    let isSymbolicIndex :: Index TIExpr -> Bool
isSymbolicIndex Index TIExpr
idx = case Index TIExpr
idx of
          Sub (TIExpr TypeScheme
_ (TIVarExpr String
_)) -> Bool
True
          Sup (TIExpr TypeScheme
_ (TIVarExpr String
_)) -> Bool
True
          SupSub (TIExpr TypeScheme
_ (TIVarExpr String
_)) -> Bool
True
          User (TIExpr TypeScheme
_ (TIVarExpr String
_)) -> Bool
True
          Index TIExpr
_ -> Bool
False
        hasSymbolicIndex :: Bool
hasSymbolicIndex = (Index TIExpr -> Bool) -> [Index TIExpr] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any Index TIExpr -> Bool
isSymbolicIndex [Index TIExpr]
indicesTI
    -- For tensors with symbolic indices, keep the tensor type
    -- For concrete indices (numeric), return element type
    let resultType :: Type
resultType = case Type
baseType of
          TTensor Type
elemType -> 
            if Bool
hasSymbolicIndex
              then Type -> Type
TTensor Type
elemType  -- Symbolic index: keep tensor type
              else Type
elemType           -- Concrete index: element access
          TCollection Type
elemType -> Type
elemType
          THash Type
_keyType Type
valType -> Type
valType  -- Hash access returns value type
          Type
_ -> Type
baseType  -- Fallback: return base type
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
resultType (Bool -> TIExpr -> [Index TIExpr] -> TIExprNode
TIIndexedExpr Bool
override TIExpr
baseTI [Index TIExpr]
indicesTI), Subst
s)
  
  -- Subrefs expression (subscript references)
  ISubrefsExpr Bool
override IExpr
baseExpr IExpr
refExpr -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    (TIExpr
baseTI, Subst
s1) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
baseExpr TypeErrorContext
exprCtx
    (TIExpr
refTI, Subst
s2) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
refExpr TypeErrorContext
exprCtx
    let baseType :: Type
baseType = TIExpr -> Type
tiExprType TIExpr
baseTI
        finalS :: Subst
finalS = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
        -- Subrefs requires base to be a Tensor type
        -- Force base type to be Tensor if not already
        tensorBaseType :: Type
tensorBaseType = case Type
baseType of
          TTensor Type
elemType -> Type -> Type
TTensor Type
elemType  -- Already Tensor
          Type
otherType -> Type -> Type
TTensor Type
otherType  -- Wrap non-Tensor in Tensor
        -- Result is also a Tensor type
        resultType :: Type
resultType = Type
tensorBaseType
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
resultType (Bool -> TIExpr -> TIExpr -> TIExprNode
TISubrefsExpr Bool
override TIExpr
baseTI TIExpr
refTI), Subst
finalS)
  
  -- Suprefs expression (superscript references)
  ISuprefsExpr Bool
override IExpr
baseExpr IExpr
refExpr -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    (TIExpr
baseTI, Subst
s1) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
baseExpr TypeErrorContext
exprCtx
    (TIExpr
refTI, Subst
s2) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
refExpr TypeErrorContext
exprCtx
    let baseType :: Type
baseType = TIExpr -> Type
tiExprType TIExpr
baseTI
        finalS :: Subst
finalS = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
        -- Suprefs requires base to be a Tensor type
        -- Force base type to be Tensor if not already
        tensorBaseType :: Type
tensorBaseType = case Type
baseType of
          TTensor Type
elemType -> Type -> Type
TTensor Type
elemType  -- Already Tensor
          Type
otherType -> Type -> Type
TTensor Type
otherType  -- Wrap non-Tensor in Tensor
        -- Result is also a Tensor type
        resultType :: Type
resultType = Type
tensorBaseType
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
resultType (Bool -> TIExpr -> TIExpr -> TIExprNode
TISuprefsExpr Bool
override TIExpr
baseTI TIExpr
refTI), Subst
finalS)
  
  -- Userrefs expression (user-defined references)
  IUserrefsExpr Bool
override IExpr
baseExpr IExpr
refExpr -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    (TIExpr
baseTI, Subst
s1) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
baseExpr TypeErrorContext
exprCtx
    (TIExpr
refTI, Subst
s2) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
refExpr TypeErrorContext
exprCtx
    let baseType :: Type
baseType = TIExpr -> Type
tiExprType TIExpr
baseTI
        finalS :: Subst
finalS = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
    -- TODO: Properly handle user-defined references
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
baseType (Bool -> TIExpr -> TIExpr -> TIExprNode
TIUserrefsExpr Bool
override TIExpr
baseTI TIExpr
refTI), Subst
finalS)

  -- Generate tensor expression
  IGenerateTensorExpr IExpr
funcExpr IExpr
shapeExpr -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    (TIExpr
funcTI, Subst
s1) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
funcExpr TypeErrorContext
exprCtx
    (TIExpr
shapeTI, Subst
s2) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
shapeExpr TypeErrorContext
exprCtx
    let funcType :: Type
funcType = TIExpr -> Type
tiExprType TIExpr
funcTI
    -- Extract element type from function result
    Type
elemType <- case Type
funcType of
      TFun Type
_ Type
resultType -> Type -> Infer Type
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return Type
resultType
      Type
_ -> String -> Infer Type
freshVar String
"tensorElem"
    let finalS :: Subst
finalS = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
    Type
elemType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
finalS Type
elemType
    let resultType :: Type
resultType = Type -> Type
normalizeTensorType (Type -> Type
TTensor Type
elemType')
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
resultType (TIExpr -> TIExpr -> TIExprNode
TIGenerateTensorExpr TIExpr
funcTI TIExpr
shapeTI), Subst
finalS)
  
  -- Tensor expression
  ITensorExpr IExpr
shapeExpr IExpr
elemsExpr -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    (TIExpr
shapeTI, Subst
s1) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
shapeExpr TypeErrorContext
exprCtx
    (TIExpr
elemsTI, Subst
s2) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
elemsExpr TypeErrorContext
exprCtx
    let elemsType :: Type
elemsType = TIExpr -> Type
tiExprType TIExpr
elemsTI
    -- Extract element type
    Type
elemType <- case Type
elemsType of
      TCollection Type
t -> Type -> Infer Type
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return Type
t
      Type
_ -> String -> Infer Type
freshVar String
"tensorElem"
    let finalS :: Subst
finalS = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
    Type
elemType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
finalS Type
elemType
    let resultType :: Type
resultType = Type -> Type
normalizeTensorType (Type -> Type
TTensor Type
elemType')
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
resultType (TIExpr -> TIExpr -> TIExprNode
TITensorExpr TIExpr
shapeTI TIExpr
elemsTI), Subst
finalS)
  
  -- Tensor contract expression
  ITensorContractExpr IExpr
tensorExpr -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    (TIExpr
tensorTI, Subst
s1) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
tensorExpr TypeErrorContext
exprCtx
    let tensorType :: Type
tensorType = TIExpr -> Type
tiExprType TIExpr
tensorTI
    
    -- contract : Tensor a -> [Tensor a]
    -- Ensure the argument is a Tensor type by unifying with TTensor elemType
    Type
elemType <- String -> Infer Type
freshVar String
"contractElem"
    Type
tensorType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
tensorType
    Subst
s2 <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
tensorType' (Type -> Type
TTensor Type
elemType) TypeErrorContext
exprCtx

    let finalS :: Subst
finalS = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
    Type
finalElemType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
finalS Type
elemType
    let resultType :: Type
resultType = Type -> Type
TCollection (Type -> Type
TTensor Type
finalElemType)
    TIExpr
updatedTensorTI <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
finalS TIExpr
tensorTI

    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
resultType (TIExpr -> TIExprNode
TITensorContractExpr TIExpr
updatedTensorTI), Subst
finalS)
  
  -- Tensor map expression
  ITensorMapExpr IExpr
func IExpr
tensorExpr -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    (TIExpr
funcTI, Subst
s1) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
func TypeErrorContext
exprCtx
    (TIExpr
tensorTI, Subst
s2) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
tensorExpr TypeErrorContext
exprCtx
    let funcType :: Type
funcType = TIExpr -> Type
tiExprType TIExpr
funcTI
        tensorType :: Type
tensorType = TIExpr -> Type
tiExprType TIExpr
tensorTI
        s12 :: Subst
s12 = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
    -- Function maps elements: a -> b, tensor is Tensor a, result is Tensor b
    case Type
tensorType of
      TTensor Type
elemType -> do
        Type
resultElemType <- String -> Infer Type
freshVar String
"tmapElem"
        Type
funcType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s12 Type
funcType
        Subst
s3 <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
funcType' (Type -> Type -> Type
TFun Type
elemType Type
resultElemType) TypeErrorContext
exprCtx
        let finalS :: Subst
finalS = Subst -> Subst -> Subst
composeSubst Subst
s3 Subst
s12
        Type
resultElemType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
finalS Type
resultElemType
        let resultType :: Type
resultType = Type -> Type
normalizeTensorType (Type -> Type
TTensor Type
resultElemType')
        TIExpr
updatedFuncTI <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
finalS TIExpr
funcTI
        TIExpr
updatedTensorTI <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
finalS TIExpr
tensorTI
        (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
resultType (TIExpr -> TIExpr -> TIExprNode
TITensorMapExpr TIExpr
updatedFuncTI TIExpr
updatedTensorTI), Subst
finalS)
      Type
_ -> do
        TIExpr
updatedFuncTI <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
s12 TIExpr
funcTI
        TIExpr
updatedTensorTI <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
s12 TIExpr
tensorTI
        (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
tensorType (TIExpr -> TIExpr -> TIExprNode
TITensorMapExpr TIExpr
updatedFuncTI TIExpr
updatedTensorTI), Subst
s12)
  
  -- Tensor map2 expression (binary map)
  ITensorMap2Expr IExpr
func IExpr
tensor1 IExpr
tensor2 -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    (TIExpr
funcTI, Subst
s1) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
func TypeErrorContext
exprCtx
    (TIExpr
tensor1TI, Subst
s2) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
tensor1 TypeErrorContext
exprCtx
    (TIExpr
tensor2TI, Subst
s3) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
tensor2 TypeErrorContext
exprCtx
    let funcType :: Type
funcType = TIExpr -> Type
tiExprType TIExpr
funcTI
        t1Type :: Type
t1Type = TIExpr -> Type
tiExprType TIExpr
tensor1TI
        t2Type :: Type
t2Type = TIExpr -> Type
tiExprType TIExpr
tensor2TI
        s123 :: Subst
s123 = (Subst -> Subst -> Subst) -> Subst -> [Subst] -> Subst
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr Subst -> Subst -> Subst
composeSubst Subst
emptySubst [Subst
s3, Subst
s2, Subst
s1]
    -- Function: a -> b -> c, tensors are Tensor a and Tensor b, result is Tensor c
    case (Type
t1Type, Type
t2Type) of
      (TTensor Type
elem1, TTensor Type
elem2) -> do
        Type
resultElemType <- String -> Infer Type
freshVar String
"tmap2Elem"
        Type
funcType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s123 Type
funcType
        Subst
s4 <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
funcType'
                (Type -> Type -> Type
TFun Type
elem1 (Type -> Type -> Type
TFun Type
elem2 Type
resultElemType)) TypeErrorContext
exprCtx
        let finalS :: Subst
finalS = Subst -> Subst -> Subst
composeSubst Subst
s4 Subst
s123
        Type
resultElemType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
finalS Type
resultElemType
        let resultType :: Type
resultType = Type -> Type
normalizeTensorType (Type -> Type
TTensor Type
resultElemType')
        TIExpr
updatedFuncTI <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
finalS TIExpr
funcTI
        TIExpr
updatedTensor1TI <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
finalS TIExpr
tensor1TI
        TIExpr
updatedTensor2TI <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
finalS TIExpr
tensor2TI
        (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
resultType (TIExpr -> TIExpr -> TIExpr -> TIExprNode
TITensorMap2Expr TIExpr
updatedFuncTI TIExpr
updatedTensor1TI TIExpr
updatedTensor2TI), Subst
finalS)
      (Type, Type)
_ -> do
        TIExpr
updatedFuncTI <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
s123 TIExpr
funcTI
        TIExpr
updatedTensor1TI <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
s123 TIExpr
tensor1TI
        TIExpr
updatedTensor2TI <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
s123 TIExpr
tensor2TI
        (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
t1Type (TIExpr -> TIExpr -> TIExpr -> TIExprNode
TITensorMap2Expr TIExpr
updatedFuncTI TIExpr
updatedTensor1TI TIExpr
updatedTensor2TI), Subst
s123)
  
  -- Transpose expression
  -- ITransposeExpr takes (permutation, tensor) to match tTranspose signature
  ITransposeExpr IExpr
permExpr IExpr
tensorExpr -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    (TIExpr
permTI, Subst
s) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
permExpr TypeErrorContext
exprCtx
    let permType :: Type
permType = TIExpr -> Type
tiExprType TIExpr
permTI
    -- Unify permutation type with [MathExpr]
    Type
permType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
permType
    Subst
s2 <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
permType' (Type -> Type
TCollection Type
TMathExpr) TypeErrorContext
exprCtx
    (TIExpr
tensorTI, Subst
s3) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
tensorExpr TypeErrorContext
exprCtx
    let finalS :: Subst
finalS = Subst -> Subst -> Subst
composeSubst Subst
s3 (Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s)
    TIExpr
updatedPermTI <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
finalS TIExpr
permTI
    TIExpr
updatedTensorTI <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
finalS TIExpr
tensorTI
    let tensorType :: Type
tensorType = TIExpr -> Type
tiExprType TIExpr
updatedTensorTI
    -- Transpose preserves tensor type
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr (Type -> Type
normalizeTensorType Type
tensorType) (TIExpr -> TIExpr -> TIExprNode
TITransposeExpr TIExpr
updatedPermTI TIExpr
updatedTensorTI), Subst
finalS)

  -- Flip indices expression
  IFlipIndicesExpr IExpr
tensorExpr -> do
    let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
ctx
    (TIExpr
tensorTI, Subst
s) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
tensorExpr TypeErrorContext
exprCtx
    TIExpr
updatedTensorTI <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
s TIExpr
tensorTI
    let tensorType :: Type
tensorType = TIExpr -> Type
tiExprType TIExpr
updatedTensorTI
    -- Flipping indices preserves tensor type
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr (Type -> Type
normalizeTensorType Type
tensorType) (TIExpr -> TIExprNode
TIFlipIndicesExpr TIExpr
updatedTensorTI), Subst
s)
  
  -- Function symbol expression
  IFunctionExpr [String]
names -> do
    -- Function symbols are mathematical function symbols (e.g., f(x,y))
    -- They are represented as MathExpr type
    (TIExpr, Subst) -> Infer (TIExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> TIExprNode -> TIExpr
mkTIExpr Type
TMathExpr ([String] -> TIExprNode
TIFunctionExpr [String]
names), Subst
emptySubst)

-- | Infer the result type shared by a list of match clauses.
--
-- Every clause body must have the same type: each clause after the first is
-- unified against the type accumulated from the clauses before it.  Returns
-- that common type, the typed clauses in their original order, and the
-- accumulated substitution (which starts from @initSubst@).
--
-- An empty clause list yields a fresh type variable and @initSubst@ unchanged.
inferMatchClauses :: TypeErrorContext -> Type -> [IMatchClause] -> Subst -> Infer (Type, [TIMatchClause], Subst)
inferMatchClauses ctx matchedType clauses initSubst =
  case clauses of
    [] -> do
      -- No clauses (should not happen); give the result an unconstrained type.
      ty <- freshVar "clauseResult"
      return (ty, [], initSubst)
    (firstClause : restClauses) -> do
      -- Apply the incoming substitution to the scrutinee type before inferring
      -- the first pattern, exactly as 'inferAndUnifyClause' does for every
      -- later clause, so the first pattern never sees already-solved variables.
      matchedType' <- applySubstWithConstraintsM initSubst matchedType
      (firstTI, firstType, s1) <- inferMatchClause ctx matchedType' firstClause initSubst
      -- Infer the remaining clauses, unifying each body type with the first.
      -- Typed clauses are accumulated in reverse and flipped at the end.
      (finalType, clauseTIs, finalSubst) <-
        foldM inferAndUnifyClause (firstType, [firstTI], s1) restClauses
      return (finalType, reverse clauseTIs, finalSubst)
  where
    -- Infer one further clause and unify its body type with the type expected
    -- from the clauses seen so far, threading the accumulated substitution.
    inferAndUnifyClause :: (Type, [TIMatchClause], Subst) -> IMatchClause -> Infer (Type, [TIMatchClause], Subst)
    inferAndUnifyClause (expectedType, accClauses, accSubst) clause = do
      matchedTy' <- applySubstWithConstraintsM accSubst matchedType
      (clauseTI, clauseType, s1) <- inferMatchClause ctx matchedTy' clause accSubst
      expectedType' <- applySubstWithConstraintsM s1 expectedType
      s2 <- unifyTypesWithContext expectedType' clauseType ctx
      let finalS = composeSubst s2 (composeSubst s1 accSubst)
      finalExpectedType <- applySubstWithConstraintsM finalS expectedType
      return (finalExpectedType, clauseTI : accClauses, finalS)

-- | Infer the type of a single match clause @(pattern, body)@.
--
-- The pattern is checked against @matchedType@; the variables it binds are
-- put in scope with unquantified schemes while the body is inferred.  Returns
-- the typed clause, the body type with the combined substitution applied, and
-- that combined substitution (body substitution composed over the pattern
-- substitution composed over @initSubst@).
inferMatchClause :: TypeErrorContext -> Type -> IMatchClause -> Subst -> Infer (TIMatchClause, Type, Subst)
inferMatchClause ctx matchedType (pat, bodyExpr) initSubst = do
  -- Pattern inference uses pattern-constructor / pattern-function type info
  -- and yields the typed pattern, its variable bindings and a substitution.
  (tiPat, patVars, patSubst) <- inferIPattern pat matchedType ctx
  let afterPat = composeSubst patSubst initSubst
      -- Pattern variables are not generalized: empty quantifier and constraint lists.
      -- NOTE(review): these types enter the environment without 'afterPat'
      -- applied; confirm 'inferIPattern' already returns substituted types.
      monoSchemes = map (\(name, ty) -> (name, Forall [] [] ty)) patVars
  (bodyTI, bodySubst) <- withEnv monoSchemes (inferIExprWithContext bodyExpr ctx)
  let combined = composeSubst bodySubst afterPat
  bodyTy <- applySubstWithConstraintsM combined (tiExprType bodyTI)
  pure ((tiPat, bodyTI), bodyTy, combined)

-- | Infer a sequence of sub-patterns against their expected types, left to right.
--
-- While each pattern is inferred, the variables bound by the patterns to its
-- left are in scope as monomorphic schemes. This is what lets non-linear
-- patterns such as @($p, #(p + 1))@ type-check.
--
-- Arguments: the patterns, their expected types, the bindings accumulated so
-- far, the substitution accumulated so far, and the error context.
--
-- Returns the typed patterns in input order, every accumulated binding
-- (earlier ones first, refined by the substitutions of later patterns), and
-- the final composed substitution.
--
-- The two lists are expected to have equal length; the cases of
-- 'inferIPattern' that call this ensure it. If they differ, processing simply
-- stops at the end of the shorter list and the surplus elements are ignored.
inferPatternsLeftToRight
  :: [IPattern]
  -> [Type]
  -> [(String, Type)]
  -> Subst
  -> TypeErrorContext
  -> Infer ([TIPattern], [(String, Type)], Subst)
inferPatternsLeftToRight (pat : pats) (ty : tys) bound subst ctx = do
  -- Bring every variable bound to the left into scope for this pattern.
  let inScope = map (\(name, boundTy) -> (name, Forall [] [] boundTy)) bound
  ty' <- applySubstWithConstraintsM subst ty
  (tiPat, fresh, patSubst) <- withEnv inScope (inferIPattern pat ty' ctx)
  -- Refine the earlier bindings with what this pattern's substitution
  -- learned, then append the newly bound variables after them.
  refined <- mapM (\(name, boundTy) -> (,) name <$> applySubstWithConstraintsM patSubst boundTy) bound
  (tiRest, allBound, finalSubst) <-
    inferPatternsLeftToRight pats tys (refined ++ fresh) (composeSubst patSubst subst) ctx
  pure (tiPat : tiRest, allBound, finalSubst)
inferPatternsLeftToRight _ _ bound subst _ =
  -- Both lists are exhausted, or their lengths differ: stop here.
  pure ([], bound, subst)

-- | Infer IPattern type and extract pattern variable bindings
-- Returns (TIPattern, bindings, substitution)
-- bindings: [(variable name, type)]
inferIPattern :: IPattern -> Type -> TypeErrorContext -> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern :: IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern IPattern
pat Type
expectedType TypeErrorContext
ctx = case IPattern
pat of
  IPattern
IWildCard -> do
    -- Wildcard: no bindings
    let tipat :: TIPattern
tipat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
expectedType) TIPatternNode
TIWildCard
    (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tipat, [], Subst
emptySubst)
  
  IPatVar String
name -> do
    -- Pattern variable: bind to expected type
    let tipat :: TIPattern
tipat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
expectedType) (String -> TIPatternNode
TIPatVar String
name)
    (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tipat, [(String
name, Type
expectedType)], Subst
emptySubst)
  
  IValuePat IExpr
expr -> do
    -- Value pattern: infer expression type and unify with expected type
    (TIExpr
exprTI, Subst
s) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
expr TypeErrorContext
ctx
    let exprType :: Type
exprType = TIExpr -> Type
tiExprType TIExpr
exprTI
    Type
exprType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
exprType
    Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
    Subst
s' <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
exprType' Type
expectedType' TypeErrorContext
ctx
    let finalS :: Subst
finalS = Subst -> Subst -> Subst
composeSubst Subst
s' Subst
s
    TIExpr
exprTI' <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
finalS TIExpr
exprTI
    Type
finalType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
finalS Type
expectedType
    let tipat :: TIPattern
tipat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
finalType) (TIExpr -> TIPatternNode
TIValuePat TIExpr
exprTI')
    (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tipat, [], Subst
finalS)

  IPredPat IExpr
expr -> do
    -- Predicate pattern: infer predicate expression
    -- Expected type for predicate is: expectedType -> Bool
    let predicateType :: Type
predicateType = Type -> Type -> Type
TFun Type
expectedType Type
TBool
    (TIExpr
exprTI, Subst
s) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
expr TypeErrorContext
ctx
    -- Unify with expected predicate type to concretize type variables
    Type
exprType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s (TIExpr -> Type
tiExprType TIExpr
exprTI)
    Type
predicateType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
predicateType
    Subst
s' <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
exprType' Type
predicateType' TypeErrorContext
ctx
    let finalS :: Subst
finalS = Subst -> Subst -> Subst
composeSubst Subst
s' Subst
s
    TIExpr
exprTI' <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
finalS TIExpr
exprTI
    Type
finalType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
finalS Type
expectedType
    let tipat :: TIPattern
tipat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
finalType) (TIExpr -> TIPatternNode
TIPredPat TIExpr
exprTI')
    (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tipat, [], Subst
finalS)
  
  ITuplePat [IPattern]
pats -> do
    -- Tuple pattern: decompose expected type
    case Type
expectedType of
      TTuple [Type]
types | [Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
types Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [IPattern] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [IPattern]
pats -> do
        -- Types match: infer each sub-pattern left-to-right
        -- Left patterns' bindings are available for right patterns (for non-linear patterns)
        ([TIPattern]
tipats, [(String, Type)]
allBindings, Subst
s) <- [IPattern]
-> [Type]
-> [(String, Type)]
-> Subst
-> TypeErrorContext
-> Infer ([TIPattern], [(String, Type)], Subst)
inferPatternsLeftToRight [IPattern]
pats [Type]
types [] Subst
emptySubst TypeErrorContext
ctx
        Type
finalType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
        let tipat :: TIPattern
tipat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
finalType) ([TIPattern] -> TIPatternNode
TITuplePat [TIPattern]
tipats)
        (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tipat, [(String, Type)]
allBindings, Subst
s)
      
      TVar TyVar
_ -> do
        -- Expected type is a type variable: create tuple type
        [Type]
elemTypes <- (IPattern -> Infer Type)
-> [IPattern] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\IPattern
_ -> String -> Infer Type
freshVar String
"elem") [IPattern]
pats
        let tupleTy :: Type
tupleTy = [Type] -> Type
TTuple [Type]
elemTypes
        Subst
s <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
expectedType Type
tupleTy TypeErrorContext
ctx

        -- Recursively infer each sub-pattern left-to-right
        [Type]
elemTypes' <- (Type -> Infer Type)
-> [Type] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s) [Type]
elemTypes
        ([TIPattern]
tipats, [(String, Type)]
allBindings, Subst
s') <- [IPattern]
-> [Type]
-> [(String, Type)]
-> Subst
-> TypeErrorContext
-> Infer ([TIPattern], [(String, Type)], Subst)
inferPatternsLeftToRight [IPattern]
pats [Type]
elemTypes' [] Subst
s TypeErrorContext
ctx
        Type
finalType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s' Type
expectedType
        let tipat :: TIPattern
tipat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
finalType) ([TIPattern] -> TIPatternNode
TITuplePat [TIPattern]
tipats)
        (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tipat, [(String, Type)]
allBindings, Subst
s')
      
      Type
_ -> do
        -- Type mismatch
        TypeError -> Infer (TIPattern, [(String, Type)], Subst)
forall a. TypeError -> ExceptT TypeError (StateT InferState IO) a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (TypeError -> Infer (TIPattern, [(String, Type)], Subst))
-> TypeError -> Infer (TIPattern, [(String, Type)], Subst)
forall a b. (a -> b) -> a -> b
$ Type -> Type -> String -> TypeErrorContext -> TypeError
TE.TypeMismatch
          ([Type] -> Type
TTuple (Int -> Type -> [Type]
forall a. Int -> a -> [a]
replicate ([IPattern] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [IPattern]
pats) (TyVar -> Type
TVar (String -> TyVar
TyVar String
"a"))))
          Type
expectedType
          String
"Tuple pattern but matched type is not a tuple"
          TypeErrorContext
ctx
  
  IInductivePat String
name [IPattern]
pats -> do
    -- Inductive pattern: look up pattern constructor type from pattern environment
    PatternTypeEnv
patternEnv <- Infer PatternTypeEnv
getPatternEnv
    case String -> PatternTypeEnv -> Maybe TypeScheme
lookupPatternEnv String
name PatternTypeEnv
patternEnv of
      Just TypeScheme
scheme -> do
        -- Found in pattern environment: use the declared type
        InferState
st <- ExceptT TypeError (StateT InferState IO) InferState
forall s (m :: * -> *). MonadState s m => m s
get
        let ([Constraint]
_constraints, Type
ctorType, Int
newCounter) = TypeScheme -> Int -> ([Constraint], Type, Int)
instantiate TypeScheme
scheme (InferState -> Int
inferCounter InferState
st)
        (InferState -> InferState) -> Infer ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((InferState -> InferState) -> Infer ())
-> (InferState -> InferState) -> Infer ()
forall a b. (a -> b) -> a -> b
$ \InferState
s -> InferState
s { inferCounter = newCounter }
        
        -- Pattern constructor type: arg1 -> arg2 -> ... -> resultType
        let ([Type]
argTypes, Type
resultType) = Type -> ([Type], Type)
extractFunctionArgs Type
ctorType
        
        -- Check argument count matches
        if [Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
argTypes Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
/= [IPattern] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [IPattern]
pats
          then TypeError -> Infer (TIPattern, [(String, Type)], Subst)
forall a. TypeError -> ExceptT TypeError (StateT InferState IO) a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (TypeError -> Infer (TIPattern, [(String, Type)], Subst))
-> TypeError -> Infer (TIPattern, [(String, Type)], Subst)
forall a b. (a -> b) -> a -> b
$ Type -> Type -> String -> TypeErrorContext -> TypeError
TE.TypeMismatch
                 ((Type -> Type -> Type) -> Type -> [Type] -> Type
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr Type -> Type -> Type
TFun Type
resultType (Int -> Type -> [Type]
forall a. Int -> a -> [a]
replicate ([IPattern] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [IPattern]
pats) (TyVar -> Type
TVar (String -> TyVar
TyVar String
"a"))))
                 Type
ctorType
                 (String
"Pattern constructor " String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
name String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" expects " String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show ([Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
argTypes) 
                  String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" arguments, but got " String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show ([IPattern] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [IPattern]
pats))
                 TypeErrorContext
ctx
          else do
            -- Unify result type with expected type
            Subst
s0 <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
resultType Type
expectedType TypeErrorContext
ctx
            [Type]
argTypes' <- (Type -> Infer Type)
-> [Type] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s0) [Type]
argTypes

            -- Recursively infer each sub-pattern left-to-right
            -- Left patterns' bindings are available for right patterns
            ([TIPattern]
tipats, [(String, Type)]
allBindings, Subst
s) <- [IPattern]
-> [Type]
-> [(String, Type)]
-> Subst
-> TypeErrorContext
-> Infer ([TIPattern], [(String, Type)], Subst)
inferPatternsLeftToRight [IPattern]
pats [Type]
argTypes' [] Subst
s0 TypeErrorContext
ctx
            Type
finalType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
            let tipat :: TIPattern
tipat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
finalType) (String -> [TIPattern] -> TIPatternNode
TIInductivePat String
name [TIPattern]
tipats)
            (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tipat, [(String, Type)]
allBindings, Subst
s)
      
      Maybe TypeScheme
Nothing -> do
        -- Not found in pattern environment: try data constructor from value environment
        -- This handles data constructors used as patterns
        TypeEnv
env <- Infer TypeEnv
getEnv
        case Var -> TypeEnv -> Maybe TypeScheme
lookupEnv (String -> Var
stringToVar String
name) TypeEnv
env of
          Just TypeScheme
scheme -> do
            InferState
st <- ExceptT TypeError (StateT InferState IO) InferState
forall s (m :: * -> *). MonadState s m => m s
get
            let ([Constraint]
_constraints, Type
ctorType, Int
newCounter) = TypeScheme -> Int -> ([Constraint], Type, Int)
instantiate TypeScheme
scheme (InferState -> Int
inferCounter InferState
st)
            (InferState -> InferState) -> Infer ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((InferState -> InferState) -> Infer ())
-> (InferState -> InferState) -> Infer ()
forall a b. (a -> b) -> a -> b
$ \InferState
s -> InferState
s { inferCounter = newCounter }
            
            let ([Type]
argTypes, Type
resultType) = Type -> ([Type], Type)
extractFunctionArgs Type
ctorType
            
            if [Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
argTypes Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
/= [IPattern] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [IPattern]
pats
              then TypeError -> Infer (TIPattern, [(String, Type)], Subst)
forall a. TypeError -> ExceptT TypeError (StateT InferState IO) a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (TypeError -> Infer (TIPattern, [(String, Type)], Subst))
-> TypeError -> Infer (TIPattern, [(String, Type)], Subst)
forall a b. (a -> b) -> a -> b
$ Type -> Type -> String -> TypeErrorContext -> TypeError
TE.TypeMismatch
                     ((Type -> Type -> Type) -> Type -> [Type] -> Type
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr Type -> Type -> Type
TFun Type
resultType (Int -> Type -> [Type]
forall a. Int -> a -> [a]
replicate ([IPattern] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [IPattern]
pats) (TyVar -> Type
TVar (String -> TyVar
TyVar String
"a"))))
                     Type
ctorType
                     (String
"Constructor " String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
name String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" expects " String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show ([Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
argTypes) 
                      String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" arguments, but got " String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show ([IPattern] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [IPattern]
pats))
                     TypeErrorContext
ctx
              else do
                Subst
s0 <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
resultType Type
expectedType TypeErrorContext
ctx
                [Type]
argTypes' <- (Type -> Infer Type)
-> [Type] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s0) [Type]
argTypes

                -- Recursively infer each sub-pattern left-to-right
                ([TIPattern]
tipats, [(String, Type)]
allBindings, Subst
s) <- [IPattern]
-> [Type]
-> [(String, Type)]
-> Subst
-> TypeErrorContext
-> Infer ([TIPattern], [(String, Type)], Subst)
inferPatternsLeftToRight [IPattern]
pats [Type]
argTypes' [] Subst
s0 TypeErrorContext
ctx
                Type
finalType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
                let tipat :: TIPattern
tipat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
finalType) (String -> [TIPattern] -> TIPatternNode
TIInductivePat String
name [TIPattern]
tipats)
                (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tipat, [(String, Type)]
allBindings, Subst
s)
          
          Maybe TypeScheme
Nothing -> do
            -- Not found: generic inference
            [Type]
argTypes <- (IPattern -> Infer Type)
-> [IPattern] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\IPattern
_ -> String -> Infer Type
freshVar String
"arg") [IPattern]
pats
            let resultType :: Type
resultType = String -> [Type] -> Type
TInductive String
name [Type]
argTypes

            Subst
s0 <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
resultType Type
expectedType TypeErrorContext
ctx
            [Type]
argTypes' <- (Type -> Infer Type)
-> [Type] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s0) [Type]
argTypes

            -- Recursively infer each sub-pattern left-to-right
            ([TIPattern]
tipats, [(String, Type)]
allBindings, Subst
s) <- [IPattern]
-> [Type]
-> [(String, Type)]
-> Subst
-> TypeErrorContext
-> Infer ([TIPattern], [(String, Type)], Subst)
inferPatternsLeftToRight [IPattern]
pats [Type]
argTypes' [] Subst
s0 TypeErrorContext
ctx
            Type
finalType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
            let tipat :: TIPattern
tipat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
finalType) (String -> [TIPattern] -> TIPatternNode
TIInductivePat String
name [TIPattern]
tipats)
            (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tipat, [(String, Type)]
allBindings, Subst
s)
  
  IIndexedPat IPattern
p [IExpr]
indices -> do
    -- Indexed pattern: infer base pattern and index expressions
    -- For $x_i pattern, x should have type Hash keyType expectedType
    -- where expectedType is the type of the indexed result
    
    -- First, infer the index expressions to determine their types
    [Type]
indexTypes <- (IExpr -> Infer Type)
-> [IExpr] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\IExpr
_ -> String -> Infer Type
freshVar String
"idx") [IExpr]
indices
    ([TIExpr]
indexTIs, Subst
s1) <- (([TIExpr], Subst)
 -> (IExpr, Type)
 -> ExceptT TypeError (StateT InferState IO) ([TIExpr], Subst))
-> ([TIExpr], Subst)
-> [(IExpr, Type)]
-> ExceptT TypeError (StateT InferState IO) ([TIExpr], Subst)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (\([TIExpr]
accTIs, Subst
accS) (IExpr
idx, Type
idxType) -> do
      (TIExpr
idxTI, Subst
idxS) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
idx TypeErrorContext
ctx
      let actualIdxType :: Type
actualIdxType = TIExpr -> Type
tiExprType TIExpr
idxTI
      Type
actualIdxType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
idxS Type
actualIdxType
      Type
idxType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
idxS Type
idxType
      Subst
s' <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithContext Type
actualIdxType' Type
idxType' TypeErrorContext
ctx
      let finalS :: Subst
finalS = Subst -> Subst -> Subst
composeSubst Subst
s' (Subst -> Subst -> Subst
composeSubst Subst
idxS Subst
accS)
      ([TIExpr], Subst)
-> ExceptT TypeError (StateT InferState IO) ([TIExpr], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return ([TIExpr]
accTIs [TIExpr] -> [TIExpr] -> [TIExpr]
forall a. [a] -> [a] -> [a]
++ [TIExpr
idxTI], Subst
finalS)) ([], Subst
emptySubst) ([IExpr] -> [Type] -> [(IExpr, Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [IExpr]
indices [Type]
indexTypes)

    -- Construct the base type: Hash indexType expectedType
    -- For simplicity, assume single index access and use THash
    Type
indexType <- case [Type]
indexTypes of
                   [Type
t] -> Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
t
                   [Type]
_ -> Type -> Infer Type
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return Type
TInt  -- Multiple indices: fallback to Int
    let baseType :: Type
baseType = Type -> Type -> Type
THash Type
indexType Type
expectedType

    -- Infer base pattern with Hash type
    Type
baseType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
baseType
    (TIPattern
tipat, [(String, Type)]
bindings, Subst
s2) <- IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern IPattern
p Type
baseType' TypeErrorContext
ctx

    let finalS :: Subst
finalS = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
    Type
finalType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
finalS Type
expectedType
    let tiIndexedPat :: TIPattern
tiIndexedPat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
finalType) (TIPattern -> [TIExpr] -> TIPatternNode
TIIndexedPat TIPattern
tipat [TIExpr]
indexTIs)
    (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tiIndexedPat, [(String, Type)]
bindings, Subst
finalS)
  
  ILetPat [IBindingExpr]
bindings IPattern
p -> do
    -- Let pattern: infer bindings and then the pattern
    -- Infer bindings first
    TypeEnv
env <- Infer TypeEnv
getEnv
    ([TIBindingExpr]
bindingTIs, [(String, TypeScheme)]
bindingSchemes, Subst
s1) <- [IBindingExpr]
-> TypeEnv
-> Subst
-> TypeErrorContext
-> Infer ([TIBindingExpr], [(String, TypeScheme)], Subst)
inferIBindingsWithContext [IBindingExpr]
bindings TypeEnv
env Subst
emptySubst TypeErrorContext
ctx

    -- Infer pattern with bindings in scope
    Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
expectedType
    (TIPattern
tipat, [(String, Type)]
patBindings, Subst
s2) <- [(String, TypeScheme)]
-> Infer (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. [(String, TypeScheme)] -> Infer a -> Infer a
withEnv [(String, TypeScheme)]
bindingSchemes (Infer (TIPattern, [(String, Type)], Subst)
 -> Infer (TIPattern, [(String, Type)], Subst))
-> Infer (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a b. (a -> b) -> a -> b
$ IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern IPattern
p Type
expectedType' TypeErrorContext
ctx

    let s :: Subst
s = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
    Type
finalType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
    let tiLetPat :: TIPattern
tiLetPat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
finalType) ([TIBindingExpr] -> TIPattern -> TIPatternNode
TILetPat [TIBindingExpr]
bindingTIs TIPattern
tipat)
    -- Let bindings are not exported, only pattern bindings
    (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tiLetPat, [(String, Type)]
patBindings, Subst
s)
  
  INotPat IPattern
p -> do
    -- Not pattern: infer the sub-pattern but don't use its bindings
    (TIPattern
tipat, [(String, Type)]
_, Subst
s) <- IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern IPattern
p Type
expectedType TypeErrorContext
ctx
    Type
finalType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
    let tiNotPat :: TIPattern
tiNotPat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
finalType) (TIPattern -> TIPatternNode
TINotPat TIPattern
tipat)
    (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tiNotPat, [], Subst
s)
  
  IAndPat IPattern
p1 IPattern
p2 -> do
    -- And pattern: both patterns must match the same type
    -- Left bindings should be available to right pattern
    (TIPattern
tipat1, [(String, Type)]
bindings1, Subst
s1) <- IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern IPattern
p1 Type
expectedType TypeErrorContext
ctx
    let schemes1 :: [(String, TypeScheme)]
schemes1 = [(String
var, [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
ty) | (String
var, Type
ty) <- [(String, Type)]
bindings1]
    Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
expectedType
    (TIPattern
tipat2, [(String, Type)]
bindings2, Subst
s2) <- [(String, TypeScheme)]
-> Infer (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. [(String, TypeScheme)] -> Infer a -> Infer a
withEnv [(String, TypeScheme)]
schemes1 (Infer (TIPattern, [(String, Type)], Subst)
 -> Infer (TIPattern, [(String, Type)], Subst))
-> Infer (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a b. (a -> b) -> a -> b
$ IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern IPattern
p2 Type
expectedType' TypeErrorContext
ctx
    let s :: Subst
s = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
    -- Apply substitution to left bindings
    [(String, Type)]
bindings1'' <- ((String, Type)
 -> ExceptT TypeError (StateT InferState IO) (String, Type))
-> [(String, Type)]
-> ExceptT TypeError (StateT InferState IO) [(String, Type)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\(String
v, Type
ty) -> do
        Type
ty' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s2 Type
ty
        (String, Type)
-> ExceptT TypeError (StateT InferState IO) (String, Type)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (String
v, Type
ty')) [(String, Type)]
bindings1
    Type
finalType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
    let bindings1' :: [(String, Type)]
bindings1' = [(String, Type)]
bindings1''
        tiAndPat :: TIPattern
tiAndPat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
finalType) (TIPattern -> TIPattern -> TIPatternNode
TIAndPat TIPattern
tipat1 TIPattern
tipat2)
    (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tiAndPat, [(String, Type)]
bindings1' [(String, Type)] -> [(String, Type)] -> [(String, Type)]
forall a. [a] -> [a] -> [a]
++ [(String, Type)]
bindings2, Subst
s)
  
  IOrPat IPattern
p1 IPattern
p2 -> do
    -- Or pattern: both patterns must match the same type
    -- Left bindings should be available to right pattern for non-linear patterns
    (TIPattern
tipat1, [(String, Type)]
bindings1, Subst
s1) <- IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern IPattern
p1 Type
expectedType TypeErrorContext
ctx
    let schemes1 :: [(String, TypeScheme)]
schemes1 = [(String
var, [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
ty) | (String
var, Type
ty) <- [(String, Type)]
bindings1]
    Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
expectedType
    (TIPattern
tipat2, [(String, Type)]
bindings2, Subst
s2) <- [(String, TypeScheme)]
-> Infer (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. [(String, TypeScheme)] -> Infer a -> Infer a
withEnv [(String, TypeScheme)]
schemes1 (Infer (TIPattern, [(String, Type)], Subst)
 -> Infer (TIPattern, [(String, Type)], Subst))
-> Infer (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a b. (a -> b) -> a -> b
$ IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern IPattern
p2 Type
expectedType' TypeErrorContext
ctx
    let s :: Subst
s = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
    -- Apply substitution to left bindings
    [(String, Type)]
bindings1'' <- ((String, Type)
 -> ExceptT TypeError (StateT InferState IO) (String, Type))
-> [(String, Type)]
-> ExceptT TypeError (StateT InferState IO) [(String, Type)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\(String
v, Type
ty) -> do
        Type
ty' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s2 Type
ty
        (String, Type)
-> ExceptT TypeError (StateT InferState IO) (String, Type)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (String
v, Type
ty')) [(String, Type)]
bindings1
    Type
finalType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
    let bindings1' :: [(String, Type)]
bindings1' = [(String, Type)]
bindings1''
        tiOrPat :: TIPattern
tiOrPat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
finalType) (TIPattern -> TIPattern -> TIPatternNode
TIOrPat TIPattern
tipat1 TIPattern
tipat2)
    -- For or patterns, ideally both branches should have same variables
    -- For now, we take union of bindings
    (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tiOrPat, [(String, Type)]
bindings1' [(String, Type)] -> [(String, Type)] -> [(String, Type)]
forall a. [a] -> [a] -> [a]
++ [(String, Type)]
bindings2, Subst
s)
  
  IForallPat IPattern
p1 IPattern
p2 -> do
    -- Forall pattern: similar to and pattern
    -- Left bindings should be available to right pattern
    (TIPattern
tipat1, [(String, Type)]
bindings1, Subst
s1) <- IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern IPattern
p1 Type
expectedType TypeErrorContext
ctx
    let schemes1 :: [(String, TypeScheme)]
schemes1 = [(String
var, [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
ty) | (String
var, Type
ty) <- [(String, Type)]
bindings1]
    Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
expectedType
    (TIPattern
tipat2, [(String, Type)]
bindings2, Subst
s2) <- [(String, TypeScheme)]
-> Infer (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. [(String, TypeScheme)] -> Infer a -> Infer a
withEnv [(String, TypeScheme)]
schemes1 (Infer (TIPattern, [(String, Type)], Subst)
 -> Infer (TIPattern, [(String, Type)], Subst))
-> Infer (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a b. (a -> b) -> a -> b
$ IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern IPattern
p2 Type
expectedType' TypeErrorContext
ctx
    let s :: Subst
s = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
    -- Apply substitution to left bindings
    [(String, Type)]
bindings1'' <- ((String, Type)
 -> ExceptT TypeError (StateT InferState IO) (String, Type))
-> [(String, Type)]
-> ExceptT TypeError (StateT InferState IO) [(String, Type)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\(String
v, Type
ty) -> do
        Type
ty' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s2 Type
ty
        (String, Type)
-> ExceptT TypeError (StateT InferState IO) (String, Type)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (String
v, Type
ty')) [(String, Type)]
bindings1
    Type
finalType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
    let bindings1' :: [(String, Type)]
bindings1' = [(String, Type)]
bindings1''
        tiForallPat :: TIPattern
tiForallPat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
finalType) (TIPattern -> TIPattern -> TIPatternNode
TIForallPat TIPattern
tipat1 TIPattern
tipat2)
    (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tiForallPat, [(String, Type)]
bindings1' [(String, Type)] -> [(String, Type)] -> [(String, Type)]
forall a. [a] -> [a] -> [a]
++ [(String, Type)]
bindings2, Subst
s)
  
  ILoopPat String
var ILoopRange
range IPattern
p1 IPattern
p2 -> do
    -- Loop pattern: $var is the loop variable (Integer), range contains pattern
    -- First, infer the range pattern (third element of ILoopRange)
    let ILoopRange IExpr
startExpr IExpr
endExpr IPattern
rangePattern = ILoopRange
range
    (TIPattern
tiRangePat, [(String, Type)]
rangeBindings, Subst
s_range) <- IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern IPattern
rangePattern Type
TInt TypeErrorContext
ctx
    
    -- Infer start and end expressions
    (TIExpr
startTI, Subst
s_start) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
startExpr TypeErrorContext
ctx
    (TIExpr
endTI, Subst
s_end) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
endExpr TypeErrorContext
ctx
    let tiLoopRange :: TILoopRange
tiLoopRange = TIExpr -> TIExpr -> TIPattern -> TILoopRange
TILoopRange TIExpr
startTI TIExpr
endTI TIPattern
tiRangePat
    
    -- Add loop variable binding (always Integer for loop index)
    let loopVarBinding :: (String, Type)
loopVarBinding = (String
var, Type
TInt)
        initialBindings :: [(String, Type)]
initialBindings = (String, Type)
loopVarBinding (String, Type) -> [(String, Type)] -> [(String, Type)]
forall a. a -> [a] -> [a]
: [(String, Type)]
rangeBindings
        schemes0 :: [(String, TypeScheme)]
schemes0 = [(String
v, [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
ty) | (String
v, Type
ty) <- [(String, Type)]
initialBindings]
        s_combined :: Subst
s_combined = (Subst -> Subst -> Subst) -> Subst -> [Subst] -> Subst
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr Subst -> Subst -> Subst
composeSubst Subst
emptySubst [Subst
s_end, Subst
s_start, Subst
s_range]

    -- Infer p1 with loop variable and range bindings in scope
    Type
expectedType1 <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s_combined Type
expectedType
    (TIPattern
tipat1, [(String, Type)]
bindings1, Subst
s1) <- [(String, TypeScheme)]
-> Infer (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. [(String, TypeScheme)] -> Infer a -> Infer a
withEnv [(String, TypeScheme)]
schemes0 (Infer (TIPattern, [(String, Type)], Subst)
 -> Infer (TIPattern, [(String, Type)], Subst))
-> Infer (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a b. (a -> b) -> a -> b
$ IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern IPattern
p1 Type
expectedType1 TypeErrorContext
ctx

    -- Infer p2 with all previous bindings in scope
    [(String, Type)]
allPrevBindings' <- ((String, Type)
 -> ExceptT TypeError (StateT InferState IO) (String, Type))
-> [(String, Type)]
-> ExceptT TypeError (StateT InferState IO) [(String, Type)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\(String
v, Type
ty) -> do
        Type
ty' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
ty
        (String, Type)
-> ExceptT TypeError (StateT InferState IO) (String, Type)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (String
v, Type
ty')) [(String, Type)]
initialBindings
    let allPrevBindings :: [(String, Type)]
allPrevBindings = [(String, Type)]
allPrevBindings' [(String, Type)] -> [(String, Type)] -> [(String, Type)]
forall a. [a] -> [a] -> [a]
++ [(String, Type)]
bindings1
        schemes1 :: [(String, TypeScheme)]
schemes1 = [(String
v, [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
ty) | (String
v, Type
ty) <- [(String, Type)]
allPrevBindings]
    Type
expectedType2 <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
expectedType
    (TIPattern
tipat2, [(String, Type)]
bindings2, Subst
s2) <- [(String, TypeScheme)]
-> Infer (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. [(String, TypeScheme)] -> Infer a -> Infer a
withEnv [(String, TypeScheme)]
schemes1 (Infer (TIPattern, [(String, Type)], Subst)
 -> Infer (TIPattern, [(String, Type)], Subst))
-> Infer (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a b. (a -> b) -> a -> b
$ IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern IPattern
p2 Type
expectedType2 TypeErrorContext
ctx
    
    let s :: Subst
s = (Subst -> Subst -> Subst) -> Subst -> [Subst] -> Subst
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr Subst -> Subst -> Subst
composeSubst Subst
emptySubst [Subst
s2, Subst
s1, Subst
s_combined]
    -- Apply final substitution to all bindings
    [(String, Type)]
finalBindings' <- ((String, Type)
 -> ExceptT TypeError (StateT InferState IO) (String, Type))
-> [(String, Type)]
-> ExceptT TypeError (StateT InferState IO) [(String, Type)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\(String
v, Type
ty) -> do
        Type
ty' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
ty
        (String, Type)
-> ExceptT TypeError (StateT InferState IO) (String, Type)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (String
v, Type
ty')) ((String, Type)
loopVarBinding (String, Type) -> [(String, Type)] -> [(String, Type)]
forall a. a -> [a] -> [a]
: [(String, Type)]
rangeBindings [(String, Type)] -> [(String, Type)] -> [(String, Type)]
forall a. [a] -> [a] -> [a]
++ [(String, Type)]
bindings1 [(String, Type)] -> [(String, Type)] -> [(String, Type)]
forall a. [a] -> [a] -> [a]
++ [(String, Type)]
bindings2)
    Type
finalType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
    let finalBindings :: [(String, Type)]
finalBindings = [(String, Type)]
finalBindings'
        tiLoopPat :: TIPattern
tiLoopPat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
finalType) (String -> TILoopRange -> TIPattern -> TIPattern -> TIPatternNode
TILoopPat String
var TILoopRange
tiLoopRange TIPattern
tipat1 TIPattern
tipat2)

    (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tiLoopPat, [(String, Type)]
finalBindings, Subst
s)
  
  IPattern
IContPat -> do
    -- Continuation pattern: no bindings
    let tipat :: TIPattern
tipat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
expectedType) TIPatternNode
TIContPat
    (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tipat, [], Subst
emptySubst)
  
  IPApplyPat IExpr
funcExpr [IPattern]
argPats -> do
    -- Pattern application: infer pattern function type
    (TIExpr
funcTI, Subst
s1) <- IExpr -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIExprWithContext IExpr
funcExpr TypeErrorContext
ctx
    
    -- Pattern function should return a pattern that matches expectedType
    -- Infer argument patterns left-to-right with fresh types
    [Type]
argTypes <- (IPattern -> Infer Type)
-> [IPattern] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\IPattern
_ -> String -> Infer Type
freshVar String
"parg") [IPattern]
argPats
    ([TIPattern]
tipats, [(String, Type)]
allBindings, Subst
s2) <- [IPattern]
-> [Type]
-> [(String, Type)]
-> Subst
-> TypeErrorContext
-> Infer ([TIPattern], [(String, Type)], Subst)
inferPatternsLeftToRight [IPattern]
argPats [Type]
argTypes [] Subst
s1 TypeErrorContext
ctx

    Type
finalType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s2 Type
expectedType
    let tipat :: TIPattern
tipat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
finalType) (TIExpr -> [TIPattern] -> TIPatternNode
TIPApplyPat TIExpr
funcTI [TIPattern]
tipats)
    (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tipat, [(String, Type)]
allBindings, Subst
s2)
  
  IVarPat String
name -> do
    -- Variable pattern (with ~): bind to expected type
    let tipat :: TIPattern
tipat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
expectedType) (String -> TIPatternNode
TIVarPat String
name)
    (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tipat, [(String
name, Type
expectedType)], Subst
emptySubst)
  
  IInductiveOrPApplyPat String
name [IPattern]
pats -> do
    -- Could be either inductive pattern or pattern application
    -- Check pattern function environment to distinguish
    -- Pattern functions are ONLY in patternFuncEnv, pattern constructors are NOT
    PatternTypeEnv
patternFuncEnv <- Infer PatternTypeEnv
getPatternFuncEnv
    case String -> PatternTypeEnv -> Maybe TypeScheme
lookupPatternEnv String
name PatternTypeEnv
patternFuncEnv of
      Just TypeScheme
_ -> do
        -- It's a pattern function: treat as pattern application
        (TIPattern
tipat, [(String, Type)]
bindings, Subst
s) <- IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern (IExpr -> [IPattern] -> IPattern
IPApplyPat (String -> IExpr
IVarExpr String
name) [IPattern]
pats) Type
expectedType TypeErrorContext
ctx
        (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tipat, [(String, Type)]
bindings, Subst
s)
      Maybe TypeScheme
Nothing -> do
        -- It's an inductive pattern constructor (or not found, will be handled later)
        (TIPattern
tipat, [(String, Type)]
bindings, Subst
s) <- IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern (String -> [IPattern] -> IPattern
IInductivePat String
name [IPattern]
pats) Type
expectedType TypeErrorContext
ctx
        -- Wrap it as InductiveOrPApplyPat (if it's actually an inductive pattern)
        case TIPattern -> TIPatternNode
tipPatternNode TIPattern
tipat of
          TIInductivePat String
_ [TIPattern]
tipats -> do
            let scheme :: TypeScheme
scheme = TIPattern -> TypeScheme
tipScheme TIPattern
tipat
                tiInductiveOrPApplyPat :: TIPattern
tiInductiveOrPApplyPat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern TypeScheme
scheme (String -> [TIPattern] -> TIPatternNode
TIInductiveOrPApplyPat String
name [TIPattern]
tipats)
            (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tiInductiveOrPApplyPat, [(String, Type)]
bindings, Subst
s)
          TIPatternNode
_ -> 
            -- Not an inductive pattern (e.g., already processed as pattern application)
            (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tipat, [(String, Type)]
bindings, Subst
s)
  
  IPattern
ISeqNilPat -> do
    -- Sequence nil: no bindings
    let tipat :: TIPattern
tipat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
expectedType) TIPatternNode
TISeqNilPat
    (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tipat, [], Subst
emptySubst)
  
  ISeqConsPat IPattern
p1 IPattern
p2 -> do
    -- Sequence cons: infer both patterns
    -- Left bindings should be available to right pattern
    (TIPattern
tipat1, [(String, Type)]
bindings1, Subst
s1) <- IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern IPattern
p1 Type
expectedType TypeErrorContext
ctx
    let schemes1 :: [(String, TypeScheme)]
schemes1 = [(String
var, [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
ty) | (String
var, Type
ty) <- [(String, Type)]
bindings1]
    Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s1 Type
expectedType
    (TIPattern
tipat2, [(String, Type)]
bindings2, Subst
s2) <- [(String, TypeScheme)]
-> Infer (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. [(String, TypeScheme)] -> Infer a -> Infer a
withEnv [(String, TypeScheme)]
schemes1 (Infer (TIPattern, [(String, Type)], Subst)
 -> Infer (TIPattern, [(String, Type)], Subst))
-> Infer (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a b. (a -> b) -> a -> b
$ IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern IPattern
p2 Type
expectedType' TypeErrorContext
ctx
    let s :: Subst
s = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
    -- Apply substitution to left bindings
    [(String, Type)]
bindings1'' <- ((String, Type)
 -> ExceptT TypeError (StateT InferState IO) (String, Type))
-> [(String, Type)]
-> ExceptT TypeError (StateT InferState IO) [(String, Type)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\(String
v, Type
ty) -> do
        Type
ty' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s2 Type
ty
        (String, Type)
-> ExceptT TypeError (StateT InferState IO) (String, Type)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (String
v, Type
ty')) [(String, Type)]
bindings1
    Type
finalType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
    let bindings1' :: [(String, Type)]
bindings1' = [(String, Type)]
bindings1''
        tipat :: TIPattern
tipat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
finalType) (TIPattern -> TIPattern -> TIPatternNode
TISeqConsPat TIPattern
tipat1 TIPattern
tipat2)
    (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tipat, [(String, Type)]
bindings1' [(String, Type)] -> [(String, Type)] -> [(String, Type)]
forall a. [a] -> [a] -> [a]
++ [(String, Type)]
bindings2, Subst
s)
  
  IPattern
ILaterPatVar -> do
    -- Later pattern variable: no immediate binding
    let tipat :: TIPattern
tipat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
expectedType) TIPatternNode
TILaterPatVar
    (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tipat, [], Subst
emptySubst)
  
  IDApplyPat IPattern
p [IPattern]
pats -> do
    -- D-apply pattern: infer base pattern and argument patterns
    -- Base pattern bindings should be available to argument patterns
    (TIPattern
tipat, [(String, Type)]
bindings1, Subst
s1) <- IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern IPattern
p Type
expectedType TypeErrorContext
ctx
    
    -- Infer argument patterns left-to-right with base pattern bindings in scope
    [Type]
argTypes <- (IPattern -> Infer Type)
-> [IPattern] -> ExceptT TypeError (StateT InferState IO) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\IPattern
_ -> String -> Infer Type
freshVar String
"darg") [IPattern]
pats
    let schemes1 :: [(String, TypeScheme)]
schemes1 = [(String
var, [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
ty) | (String
var, Type
ty) <- [(String, Type)]
bindings1]
    ([TIPattern]
tipats, [(String, Type)]
argBindings, Subst
s2) <- [(String, TypeScheme)]
-> Infer ([TIPattern], [(String, Type)], Subst)
-> Infer ([TIPattern], [(String, Type)], Subst)
forall a. [(String, TypeScheme)] -> Infer a -> Infer a
withEnv [(String, TypeScheme)]
schemes1 (Infer ([TIPattern], [(String, Type)], Subst)
 -> Infer ([TIPattern], [(String, Type)], Subst))
-> Infer ([TIPattern], [(String, Type)], Subst)
-> Infer ([TIPattern], [(String, Type)], Subst)
forall a b. (a -> b) -> a -> b
$ [IPattern]
-> [Type]
-> [(String, Type)]
-> Subst
-> TypeErrorContext
-> Infer ([TIPattern], [(String, Type)], Subst)
inferPatternsLeftToRight [IPattern]
pats [Type]
argTypes [] Subst
s1 TypeErrorContext
ctx
    
    let s :: Subst
s = Subst -> Subst -> Subst
composeSubst Subst
s2 Subst
s1
    -- Apply substitution to base bindings
    [(String, Type)]
bindings1'' <- ((String, Type)
 -> ExceptT TypeError (StateT InferState IO) (String, Type))
-> [(String, Type)]
-> ExceptT TypeError (StateT InferState IO) [(String, Type)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\(String
v, Type
ty) -> do
        Type
ty' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s2 Type
ty
        (String, Type)
-> ExceptT TypeError (StateT InferState IO) (String, Type)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (String
v, Type
ty')) [(String, Type)]
bindings1
    Type
finalType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
s Type
expectedType
    let bindings1' :: [(String, Type)]
bindings1' = [(String, Type)]
bindings1''
        tiDApplyPat :: TIPattern
tiDApplyPat = TypeScheme -> TIPatternNode -> TIPattern
TIPattern ([TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
finalType) (TIPattern -> [TIPattern] -> TIPatternNode
TIDApplyPat TIPattern
tipat [TIPattern]
tipats)
    (TIPattern, [(String, Type)], Subst)
-> Infer (TIPattern, [(String, Type)], Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TIPattern
tiDApplyPat, [(String, Type)]
bindings1' [(String, Type)] -> [(String, Type)] -> [(String, Type)]
forall a. [a] -> [a] -> [a]
++ [(String, Type)]
argBindings, Subst
s)
  where
    -- Extract function argument types and result type
    -- e.g., a -> b -> c -> d  =>  ([a, b, c], d)
    extractFunctionArgs :: Type -> ([Type], Type)
    extractFunctionArgs :: Type -> ([Type], Type)
extractFunctionArgs (TFun Type
arg Type
rest) = 
      let ([Type]
args, Type
result) = Type -> ([Type], Type)
extractFunctionArgs Type
rest
      in (Type
arg Type -> [Type] -> [Type]
forall a. a -> [a] -> [a]
: [Type]
args, Type
result)
    extractFunctionArgs Type
t = ([], Type
t)

-- | Infer the type of an application of a /named/ function to arguments.
--
-- Thin wrapper around 'inferIApplicationWithContext': the function name is
-- turned into a 'TIVarExpr' node annotated with @funcType@, and inference is
-- delegated with an empty error context ('emptyContext').
--
-- Returns a typed 'TIExpr' (rather than the older @(IExpr, Type, Subst)@
-- triple) together with the substitution produced by the delegate, which
-- starts from @initSubst@.
inferIApplication
  :: String   -- ^ name of the function being applied
  -> Type     -- ^ type of the function, as known to the caller
  -> [IExpr]  -- ^ argument expressions
  -> Subst    -- ^ substitution accumulated so far
  -> Infer (TIExpr, Subst)
inferIApplication funcName funcType args initSubst =
  inferIApplicationWithContext funcTI funcType args initSubst emptyContext
  where
    -- The callee is a plain variable reference typed with the supplied type.
    funcTI = mkTIExpr funcType (TIVarExpr funcName)

-- TensorMap insertion logic has been moved to Language.Egison.Type.TensorMapInsertion
-- This keeps type inference focused on type checking only

-- | Infer the type of an application @f a1 ... an@ (helper with error context).
--
-- Parameters:
--
--   * @funcTIExpr@ is the already-typed head.
--   * @funcType@ is the type of the head.
--   * @args@ are the untyped arguments.
--   * @initSubst@ is the incoming substitution.
--   * @ctx@ is attached to any 'UnificationError'.
--
-- Returns the typed 'TIApplyExpr' node and the final substitution.  This
-- function only performs inference and unification.  TensorMap insertion
-- happens in a later phase (the TensorMapInsertion module).
--
-- Argument types are unified with parameter types in two passes.  Non-function
-- arguments go first, so that data (e.g. a list) can constrain type variables
-- before callback types are checked.  This is what lets @foldl (+) 0 [t1, t2]@
-- infer @a = Tensor Integer@ from the list before the callback is matched.
-- Callback arguments additionally contribute the constraints of their own type
-- scheme (e.g. @{Num t}@ from @(+)@).
--
-- If any unification reports that a Tensor was unwrapped, the result type is
-- wrapped in 'TTensor' unless it already is a tensor type.  An example is
-- @sum : {Num a} [a] -> a@ applied to @[Tensor Integer]@, where @a@ unifies
-- with @Tensor Integer@ but is unwrapped to @Integer@.
--
-- The head's type may fail to unify with an arrow type.  If it is 'TMathExpr',
-- the application is accepted with result type 'TMathExpr'.  This covers
-- FunctionData application, e.g. @f 0@ where @f := function (x)@.  Otherwise a
-- 'UnificationError' is thrown.
inferIApplicationWithContext :: TIExpr -> Type -> [IExpr] -> Subst -> TypeErrorContext -> Infer (TIExpr, Subst)
inferIApplicationWithContext funcTIExpr funcType args initSubst ctx = do
  -- Type each argument; their substitutions are layered over initSubst.
  typedArgs <- mapM (`inferIExprWithContext` ctx) args
  let argTIs   = map fst typedArgs
      argSubst = foldr composeSubst initSubst (map snd typedArgs)

  -- Fresh variables param1..paramN, then result.  Allocation order matters
  -- for the generated names.
  paramVars  <- mapM (freshVar . ("param" ++) . show) [1 .. length args]
  resultType <- freshVar "result"
  let expectedFuncType = foldr TFun resultType paramVars
  appliedFuncType <- applySubstWithConstraintsM argSubst funcType

  classEnv <- getClassEnv
  -- Constraints of the enclosing scopes (e.g. {Num a} from an outer
  -- definition) take part in unification together with the head's own ones.
  contextConstraints <- getConstraints
  let schemeConstraints (Forall _ cs _) = cs
      funcConstraints = schemeConstraints (tiScheme funcTIExpr)
      constraints     = funcConstraints ++ contextConstraints

      -- Unify one argument type with its parameter type under the accumulated
      -- substitution.  extraCs are appended to the constraint set.  The Bool
      -- accumulates the "Tensor unwrapped" flag reported by unification.
      unifyArg :: [Constraint] -> (Subst, Bool) -> (Type, Type) -> Infer (Subst, Bool)
      unifyArg extraCs (sub, unwrapped) (argTy, paramTy) = do
        argTy'   <- applySubstWithConstraintsM sub argTy
        paramTy' <- applySubstWithConstraintsM sub paramTy
        -- Substitute into constraints too, so constraint checking sees
        -- the current bindings.
        let cs = map (applySubstConstraint sub) (constraints ++ extraCs)
        case Unify.unifyWithConstraints classEnv cs argTy' paramTy' of
          Right (sub', flag) -> return (composeSubst sub' sub, unwrapped || flag)
          Left _             -> throwError $ UnificationError argTy' paramTy' ctx

  case Unify.unifyWithConstraints classEnv constraints appliedFuncType expectedFuncType of
    Right (s1, flag1) -> do
      paramTys <- mapM (applySubstWithConstraintsM s1) paramVars
      let isCallback (argTI, _) = case tiExprType argTI of
            TFun _ _ -> True
            _        -> False
          (callbacks, dataArgs) = partition isCallback (zip argTIs paramTys)

      -- Pass 1: non-function (data) arguments.
      (s2, flag2) <- foldM (\acc (argTI, paramTy) -> unifyArg [] acc (tiExprType argTI, paramTy))
                           (s1, flag1) dataArgs
      -- Pass 2: callbacks, adding the constraints from their own schemes.
      (s3, flag3) <- foldM (\acc (argTI, paramTy) ->
                              unifyArg (schemeConstraints (tiScheme argTI)) acc (tiExprType argTI, paramTy))
                           (s2, flag2) callbacks

      let finalS = composeSubst s3 argSubst
      baseResultType <- applySubstWithConstraintsM finalS resultType
      let finalType
            | flag3 && not (Types.isTensorType baseResultType) = TTensor baseResultType
            | otherwise                                        = baseResultType
          isTypeVarConstraint (Constraint _ (TVar _)) = True
          isTypeVarConstraint _                       = False
          -- Partial applications keep the head's own constraints, after
          -- substitution, Tensor simplification and deduplication, and only
          -- those on bare type variables.  Fully applied values need none.
          -- contextConstraints belong to outer scopes and are not propagated.
          resultConstraints = case finalType of
            TFun _ _ -> filter isTypeVarConstraint
                          . nub
                          . simplifyTensorConstraints classEnv
                          $ map (applySubstConstraint finalS) funcConstraints
            _        -> []
          -- Class-env-aware substitution can adjust bindings: e.g. t0 -> Tensor t1
          -- becomes t0 -> t1 when Num (Tensor t1) has no instance.
          retype = applySubstToTIExprWithClassEnv classEnv finalS
      return (TIExpr (Forall [] resultConstraints finalType) (TIApplyExpr (retype funcTIExpr) (map retype argTIs)), finalS)

    Left _ -> case appliedFuncType of
      TMathExpr -> do
        let retype = applySubstToTIExprWithClassEnv classEnv argSubst
        return (TIExpr (Forall [] [] TMathExpr) (TIApplyExpr (retype funcTIExpr) (map retype argTIs)), argSubst)
      _ -> throwError $ UnificationError appliedFuncType expectedFuncType ctx
-- Inference for non-recursive let bindings is done by inferIBindingsWithContext (below).

-- | Infer the bindings of a @do@ block (IO bindings), left to right.
--
-- For each binding @(pat, expr)@:
--
--   * the expression must have type @IO inner@;
--   * @inner@ is unified with the type that the pattern @pat@ expects;
--   * the variables bound by @pat@ (see 'extractIBindingsFromPattern') are in
--     scope while the remaining bindings are inferred.
--
-- Returns three things:
--
--   * the typed bindings ('TIBindingExpr');
--   * the @(name, scheme)@ pairs introduced by all bindings;
--   * the final substitution, threaded from @s@.
--
-- @env@ is only forwarded to the recursive call.  Unification failures are
-- reported via 'unifyTypesWithContext' with @ctx@.
inferIOBindingsWithContext :: [IBindingExpr] -> TypeEnv -> Subst -> TypeErrorContext -> Infer ([TIBindingExpr], [(String, TypeScheme)], Subst)
inferIOBindingsWithContext [] _env s _ctx = return ([], [], s)
inferIOBindingsWithContext ((pat, expr) : bs) env s ctx = do
  -- Infer the type of the expression
  (exprTI, s1) <- inferIExprWithContext expr ctx
  let exprType = tiExprType exprTI

  -- The expression should be of type IO a
  innerType <- freshVar "ioInner"
  exprType' <- applySubstWithConstraintsM s1 exprType
  s2 <- unifyTypesWithContext exprType' (TIO innerType) ctx
  let s12 = composeSubst s2 s1
  actualInnerType <- applySubstWithConstraintsM s12 innerType

  -- Create the type expected by the pattern and unify it with the inner type
  (patternType, s3) <- inferPatternType pat
  let s123 = composeSubst s3 s12
  actualInnerType' <- applySubstWithConstraintsM s123 actualInnerType
  patternType'     <- applySubstWithConstraintsM s123 patternType
  s4 <- unifyTypesWithContext actualInnerType' patternType' ctx

  -- Apply all substitutions and extract the pattern's bindings from the inner type
  let finalS = composeSubst s4 s123
  finalInnerType <- applySubstWithConstraintsM finalS actualInnerType
  let bindings = extractIBindingsFromPattern pat finalInnerType
      s'       = composeSubst finalS s

  (restBindingTIs, restBindings, s2') <-
    withEnv bindings $ inferIOBindingsWithContext bs env s' ctx
  return ((pat, exprTI) : restBindingTIs, bindings ++ restBindings, s2')
  where
    -- Type that a primitive data pattern expects.  The returned substitution
    -- is always empty for the patterns handled here.
    inferPatternType :: IPrimitiveDataPattern -> Infer (Type, Subst)
    inferPatternType PDWildCard                 = freshPatType "wild"
    inferPatternType (PDPatVar _)               = freshPatType "patvar"
    inferPatternType (PDTuplePat pats)          = do
      (tys, sub) <- inferSubPatterns pats
      return (TTuple tys, sub)
    -- Use a fresh element type.  The former fixed name (TyVar "a") could be
    -- unified with an unrelated type variable that happens to share the name.
    inferPatternType PDEmptyPat                 = freshCollection
    inferPatternType (PDConsPat _ _)            = freshCollection
    inferPatternType (PDSnocPat _ _)            = freshCollection
    inferPatternType (PDInductivePat name pats) = do
      (tys, sub) <- inferSubPatterns pats
      return (TInductive name tys, sub)
    inferPatternType (PDConstantPat c)          = do
      ty <- inferConstant c
      return (ty, emptySubst)
    -- ScalarData primitive patterns
    inferPatternType (PDDivPat _ _)             = known TMathExpr
    inferPatternType (PDPlusPat _)              = known TPolyExpr
    inferPatternType (PDTermPat _ _)            = known TTermExpr
    inferPatternType (PDSymbolPat _ _)          = known TSymbolExpr
    inferPatternType (PDApply1Pat _ _)          = known TSymbolExpr
    inferPatternType (PDApply2Pat _ _ _)        = known TSymbolExpr
    inferPatternType (PDApply3Pat _ _ _ _)      = known TSymbolExpr
    inferPatternType (PDApply4Pat _ _ _ _ _)    = known TSymbolExpr
    inferPatternType (PDQuotePat _)             = known TSymbolExpr
    inferPatternType (PDFunctionPat _ _)        = known TSymbolExpr
    inferPatternType (PDSubPat _)               = known TIndexExpr
    inferPatternType (PDSupPat _)               = known TIndexExpr
    inferPatternType (PDUserPat _)              = known TIndexExpr

    -- A pattern whose type is an unconstrained fresh variable.
    freshPatType :: String -> Infer (Type, Subst)
    freshPatType hint = do
      t <- freshVar hint
      return (t, emptySubst)

    -- A collection pattern with a fresh element type.
    freshCollection :: Infer (Type, Subst)
    freshCollection = do
      elemType <- freshVar "elem"
      return (TCollection elemType, emptySubst)

    -- A pattern with a fixed, known type.
    known :: Type -> Infer (Type, Subst)
    known ty = return (ty, emptySubst)

    -- Infer sub-pattern types and compose their substitutions.
    inferSubPatterns :: [IPrimitiveDataPattern] -> Infer ([Type], Subst)
    inferSubPatterns pats = do
      results <- mapM inferPatternType pats
      return (map fst results, foldr composeSubst emptySubst (map snd results))

-- | Apply a substitution repeatedly until the type stops changing.  This
-- fully resolves variables whose images mention other substituted variables.
--
-- For example, with @s = {t1 -> (Integer, t2), t2 -> [Integer]}@,
-- @applySubstRecursively s t1@ yields @(Integer, [Integer])@ rather than
-- @(Integer, t2)@.
--
-- At most 'maxRounds' applications are performed.  If no fixed point is
-- reached by then, the type produced by the last application is returned.
applySubstRecursively :: Subst -> Type -> Infer Type
applySubstRecursively s = loop maxRounds
  where
    -- Upper bound on substitution passes (previously 10).
    maxRounds :: Int
    maxRounds = 5

    -- Apply once; stop when unchanged or when the budget is exhausted.
    loop :: Int -> Type -> Infer Type
    loop 0 ty = return ty
    loop n ty = do
      ty' <- applySubstWithConstraintsM s ty
      if ty' == ty
        then return ty
        else loop (n - 1) ty'

-- | Infer types for a sequence of non-recursive @let@ bindings.
--
-- Bindings are processed left to right. For each @(pat, expr)@ the
-- expression is inferred, the type the primitive data pattern expects is
-- unified with it, and the variables bound by the pattern are brought into
-- scope (via 'withEnv') while the remaining bindings are inferred, so later
-- bindings can refer to earlier ones.
--
-- Returns the typed bindings (each pattern paired with its 'TIExpr'), the
-- @(name, scheme)@ pairs for every variable bound by the patterns (from
-- 'extractIBindingsFromPattern'), and the substitution composed with the
-- incoming @s@.
inferIBindingsWithContext :: [IBindingExpr] -> TypeEnv -> Subst -> TypeErrorContext -> Infer ([TIBindingExpr], [(String, TypeScheme)], Subst)
inferIBindingsWithContext [] _env s _ctx = return ([], [], s)
inferIBindingsWithContext ((pat, expr) : bs) env s ctx = do
  -- Infer the type of the bound expression
  (exprTI, s1) <- inferIExprWithContext expr ctx
  let exprType = tiExprType exprTI

  -- Unify the expression type with the shape the pattern expects; this
  -- helps resolve type variables in the expression type
  (patternType, s2) <- inferPatternType pat
  let s12 = composeSubst s2 s1
  exprType' <- applySubstWithConstraintsM s12 exprType
  patternType' <- applySubstWithConstraintsM s12 patternType
  s3 <- unifyTypesWithContext exprType' patternType' ctx

  -- Apply all substitutions recursively until a fixed point so nested type
  -- variables are fully resolved (e.g. for sortWithSign)
  let finalS = composeSubst s3 s12
  finalExprType <- applySubstRecursively finalS exprType
  let bindings = extractIBindingsFromPattern pat finalExprType
      s' = composeSubst finalS s

  -- The remaining bindings see the variables introduced by this one
  (restBindingTIs, restBindings, sRest) <-
    withEnv bindings $ inferIBindingsWithContext bs env s' ctx
  return ((pat, exprTI) : restBindingTIs, bindings ++ restBindings, sRest)
  where
    -- | The type a primitive data pattern expects to match against, plus
    -- the substitution produced while computing it.
    inferPatternType :: IPrimitiveDataPattern -> Infer (Type, Subst)
    inferPatternType PDWildCard = do
      t <- freshVar "wild"
      return (t, emptySubst)
    inferPatternType (PDPatVar _) = do
      t <- freshVar "patvar"
      return (t, emptySubst)
    inferPatternType (PDTuplePat pats) = do
      (types, subSubst) <- inferSubPatterns pats
      return (TTuple types, subSubst)
    inferPatternType PDEmptyPat = do
      -- Fresh element variable per empty pattern: a fixed name such as "a"
      -- would be shared by every empty pattern and could capture an
      -- unrelated type variable of the same name.
      elemType <- freshVar "elem"
      return (TCollection elemType, emptySubst)
    inferPatternType (PDConsPat _ _) = do
      elemType <- freshVar "elem"
      return (TCollection elemType, emptySubst)
    inferPatternType (PDSnocPat _ _) = do
      elemType <- freshVar "elem"
      return (TCollection elemType, emptySubst)
    inferPatternType (PDInductivePat name pats) = do
      (types, subSubst) <- inferSubPatterns pats
      return (TInductive name types, subSubst)
    inferPatternType (PDConstantPat c) = do
      ty <- inferConstant c
      return (ty, emptySubst)
    -- ScalarData primitive patterns
    inferPatternType (PDDivPat _ _) = return (TMathExpr, emptySubst)
    inferPatternType (PDPlusPat _) = return (TPolyExpr, emptySubst)
    inferPatternType (PDTermPat _ _) = return (TTermExpr, emptySubst)
    inferPatternType (PDSymbolPat _ _) = return (TSymbolExpr, emptySubst)
    inferPatternType (PDApply1Pat _ _) = return (TSymbolExpr, emptySubst)
    inferPatternType (PDApply2Pat _ _ _) = return (TSymbolExpr, emptySubst)
    inferPatternType (PDApply3Pat _ _ _ _) = return (TSymbolExpr, emptySubst)
    inferPatternType (PDApply4Pat _ _ _ _ _) = return (TSymbolExpr, emptySubst)
    inferPatternType (PDQuotePat _) = return (TSymbolExpr, emptySubst)
    inferPatternType (PDFunctionPat _ _) = return (TSymbolExpr, emptySubst)
    inferPatternType (PDSubPat _) = return (TIndexExpr, emptySubst)
    inferPatternType (PDSupPat _) = return (TIndexExpr, emptySubst)
    inferPatternType (PDUserPat _) = return (TIndexExpr, emptySubst)

    -- | Infer every sub-pattern; returns their types in order together with
    -- the composition of their substitutions.
    inferSubPatterns :: [IPrimitiveDataPattern] -> Infer ([Type], Subst)
    inferSubPatterns pats = do
      results <- mapM inferPatternType pats
      return (map fst results, foldr (composeSubst . snd) emptySubst results)

-- | Infer types for a group of recursive (@letrec@) bindings.
--
-- Because the bindings may refer to each other (and to themselves), every
-- pattern is first given a placeholder type built from fresh type variables.
-- The variables bound by those placeholders are put in scope with 'withEnv'
-- while every right-hand side is inferred. Each placeholder is then unified
-- with the inferred type of its expression. Finally, the bound variables are
-- re-extracted from the fully substituted expression types.
--
-- Returns the typed bindings (patterns paired with their 'TIExpr'), the
-- @(name, scheme)@ pairs for the bound variables, and the final substitution
-- (composed with the incoming @s@).
inferIRecBindingsWithContext :: [IBindingExpr] -> TypeEnv -> Subst -> TypeErrorContext -> Infer ([TIBindingExpr], [(String, TypeScheme)], Subst)
inferIRecBindingsWithContext bindings _env s ctx = do
  -- Placeholder types for every pattern, built from fresh type variables
  placeholders <- mapM inferPatternType patterns
  let placeholderTypes = map fst placeholders
      s0 = foldr (composeSubst . snd) s placeholders
      placeholderBindings =
        concat (zipWith extractIBindingsFromPattern patterns placeholderTypes)

  -- Infer every right-hand side with all placeholder bindings in scope
  results <- withEnv placeholderBindings $
    mapM (\(_, expr) -> inferIExprWithContext expr ctx) bindings
  let exprTIs = map fst results
      exprTypes = map tiExprType exprTIs
      s1 = foldr (composeSubst . snd) s0 results

  -- Unify each placeholder with its expression's type. The substitution is
  -- threaded through the fold so every unification sees the results of the
  -- previous ones. Unifying each pair independently against s1 and composing
  -- afterwards can silently accept conflicting bindings for a type variable
  -- shared between two bindings.
  finalS <- foldM unifyPlaceholder s1 (zip placeholderTypes exprTypes)

  -- Re-extract bindings with fully resolved types
  exprTypes' <- mapM (applySubstRecursively finalS) exprTypes
  let finalBindings =
        concat (zipWith extractIBindingsFromPattern patterns exprTypes')
  return (zip patterns exprTIs, finalBindings, finalS)
  where
    -- Patterns of the binding group, in order
    patterns :: [IPrimitiveDataPattern]
    patterns = map fst bindings

    -- | Unify one (placeholder, expression) type pair under the substitution
    -- accumulated so far and extend that substitution with the result.
    unifyPlaceholder :: Subst -> (Type, Type) -> Infer Subst
    unifyPlaceholder acc (placeholderTy, exprTy) = do
      placeholderTy' <- applySubstWithConstraintsM acc placeholderTy
      exprTy' <- applySubstWithConstraintsM acc exprTy
      u <- unifyTypesWithContext exprTy' placeholderTy' ctx
      return (composeSubst u acc)

    -- | Placeholder type a pattern expects to match against, plus the
    -- substitution produced while computing it.
    inferPatternType :: IPrimitiveDataPattern -> Infer (Type, Subst)
    inferPatternType PDWildCard = do
      t <- freshVar "wild"
      return (t, emptySubst)
    inferPatternType (PDPatVar _) = do
      t <- freshVar "rec"
      return (t, emptySubst)
    inferPatternType (PDTuplePat pats) = do
      (types, subSubst) <- inferSubPatterns pats
      return (TTuple types, subSubst)
    inferPatternType PDEmptyPat = do
      -- Fresh element variable per empty pattern: a fixed name such as "a"
      -- would be shared by every empty pattern and could capture an
      -- unrelated type variable of the same name.
      elemType <- freshVar "elem"
      return (TCollection elemType, emptySubst)
    inferPatternType (PDConsPat _ _) = do
      elemType <- freshVar "elem"
      return (TCollection elemType, emptySubst)
    inferPatternType (PDSnocPat _ _) = do
      elemType <- freshVar "elem"
      return (TCollection elemType, emptySubst)
    inferPatternType (PDInductivePat name pats) = do
      (types, subSubst) <- inferSubPatterns pats
      return (TInductive name types, subSubst)
    inferPatternType (PDConstantPat c) = do
      ty <- inferConstant c
      return (ty, emptySubst)
    -- Every other pattern gets an unconstrained placeholder; add more
    -- specific cases here as needed.
    inferPatternType _ = do
      t <- freshVar "rec"
      return (t, emptySubst)

    -- | Infer every sub-pattern; returns their types in order together with
    -- the composition of their substitutions.
    inferSubPatterns :: [IPrimitiveDataPattern] -> Infer ([Type], Subst)
    inferSubPatterns pats = do
      results <- mapM inferPatternType pats
      return (map fst results, foldr (composeSubst . snd) emptySubst results)

-- | Whether a primitive data pattern is a bare pattern variable
-- ('PDPatVar'); every other pattern form yields 'False'.
isPatVarPat :: IPrimitiveDataPattern -> Bool
isPatVarPat pat = case pat of
  PDPatVar _ -> True
  _          -> False

-- | Extract bindings from a pattern.
-- This function extracts the variable bindings of a primitive data pattern,
-- given the type that the pattern should match against.
-- | Collect the variable bindings introduced by a primitive data pattern.
--
-- Given the type the pattern is matched against, every pattern variable is
-- paired with a monomorphic scheme (@Forall [] [] t@) for the type it will be
-- bound to. Bindings are returned in left-to-right pattern order.
--
-- * 'PDPatVar' binds the variable to the target type itself.
-- * 'PDConsPat' gives its first sub-pattern the element type and its second
--   the whole collection type; 'PDSnocPat' does the reverse. Both bind
--   nothing when the target type is not a 'TCollection'.
-- * 'PDTuplePat' splits a 'TTuple' of matching arity element-wise. Otherwise
--   every component gets the whole target type. This is imprecise, but it
--   keeps the variables in scope.
-- * 'PDInductivePat' passes the whole target type to every sub-pattern.
-- * The ScalarData patterns ignore the target type and use fixed math types.
--   'PDDivPat', 'PDPlusPat' and 'PDTermPat' pick a wider type ('TMathExpr')
--   when the sub-pattern is a bare variable (see 'isPatVarPat').
-- * Any other pattern (wildcard, empty, constants, ...) binds nothing.
extractIBindingsFromPattern :: IPrimitiveDataPattern -> Type -> [(String, TypeScheme)]
extractIBindingsFromPattern pat ty = case pat of
  PDWildCard -> []
  PDEmptyPat -> []
  PDPatVar var -> [(extractNameFromVar var, Forall [] [] ty)]
  PDInductivePat _ subPats -> bindAllAt ty subPats
  PDTuplePat subPats -> case ty of
    TTuple elemTys
      | length subPats == length elemTys ->
          -- Arity matches: pair each component pattern with its element type.
          concat (zipWith extractIBindingsFromPattern subPats elemTys)
    _ ->
      -- Not a resolved tuple of the right arity (type variable or mismatch):
      -- give every component the full type so the variables stay in scope;
      -- the element types are left to later unification.
      bindAllAt ty subPats
  PDConsPat elemPat collPat -> case ty of
    TCollection elemTy -> bindSub elemPat elemTy ++ bindSub collPat ty
    _                  -> []
  PDSnocPat collPat elemPat -> case ty of
    TCollection elemTy -> bindSub collPat ty ++ bindSub elemPat elemTy
    _                  -> []
  -- ScalarData primitive patterns
  PDDivPat numPat denPat ->
    bindSub numPat (varOrStruct TMathExpr TPolyExpr numPat)
      ++ bindSub denPat (varOrStruct TMathExpr TPolyExpr denPat)
  PDPlusPat termsPat ->
    bindSub termsPat (TCollection (varOrStruct TMathExpr TTermExpr termsPat))
  PDTermPat coeffPat factorsPat ->
    bindSub coeffPat TInt
      ++ bindSub factorsPat
           (TCollection (TTuple [varOrStruct TMathExpr TSymbolExpr factorsPat, TInt]))
  PDSymbolPat namePat indicesPat ->
    bindSub namePat TString ++ bindSub indicesPat (TCollection TIndexExpr)
  PDApply1Pat fnPat a1 -> bindApply fnPat [a1]
  PDApply2Pat fnPat a1 a2 -> bindApply fnPat [a1, a2]
  PDApply3Pat fnPat a1 a2 a3 -> bindApply fnPat [a1, a2, a3]
  PDApply4Pat fnPat a1 a2 a3 a4 -> bindApply fnPat [a1, a2, a3, a4]
  PDQuotePat inner -> bindSub inner TMathExpr
  PDFunctionPat namePat argsPat ->
    bindSub namePat TMathExpr ++ bindSub argsPat (TCollection TMathExpr)
  PDSubPat inner -> bindSub inner TMathExpr
  PDSupPat inner -> bindSub inner TMathExpr
  PDUserPat inner -> bindSub inner TMathExpr
  _ -> []
  where
    -- Recurse into a sub-pattern with the type it is matched against.
    bindSub = extractIBindingsFromPattern

    -- Recurse into several sub-patterns that all share one type.
    bindAllAt t = concatMap (`bindSub` t)

    -- A bare variable gets the first type; any structured pattern the second.
    varOrStruct varTy structTy p = if isPatVarPat p then varTy else structTy

    -- Curried function type MathExpr -> ... -> MathExpr taking 'arity' arguments.
    mathFunTy arity = foldr TFun TMathExpr (replicate arity TMathExpr)

    -- Shared shape of the PDApplyNPat patterns: the head pattern gets an
    -- N-ary math function type, each argument pattern gets 'TMathExpr'.
    bindApply fnPat argPats =
      bindSub fnPat (mathFunTy (length argPats)) ++ bindAllAt TMathExpr argPats

-- | Infer top-level IExpr and return TITopExpr directly
inferITopExpr :: ITopExpr -> Infer (Maybe TITopExpr, Subst)
inferITopExpr :: ITopExpr -> Infer (Maybe TITopExpr, Subst)
inferITopExpr ITopExpr
topExpr = case ITopExpr
topExpr of
  IDefine Var
var IExpr
expr -> do
    String
varName <- String -> ExceptT TypeError (StateT InferState IO) String
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (String -> ExceptT TypeError (StateT InferState IO) String)
-> String -> ExceptT TypeError (StateT InferState IO) String
forall a b. (a -> b) -> a -> b
$ Var -> String
extractNameFromVar Var
var
    TypeEnv
env <- Infer TypeEnv
getEnv
    -- Check if there's an explicit type signature in the environment
    -- (added by EnvBuilder from DefineWithType)
    case Var -> TypeEnv -> Maybe TypeScheme
lookupEnv Var
var TypeEnv
env of
      Just TypeScheme
existingScheme -> do
        -- There's an explicit type signature: check that the inferred type matches
        InferState
st <- ExceptT TypeError (StateT InferState IO) InferState
forall s (m :: * -> *). MonadState s m => m s
get
        let ([Constraint]
instConstraints, Type
expectedType, Int
newCounter) = TypeScheme -> Int -> ([Constraint], Type, Int)
instantiate TypeScheme
existingScheme (InferState -> Int
inferCounter InferState
st)
        (InferState -> InferState) -> Infer ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((InferState -> InferState) -> Infer ())
-> (InferState -> InferState) -> Infer ()
forall a b. (a -> b) -> a -> b
$ \InferState
s -> InferState
s { inferCounter = newCounter }
        -- Add instantiated constraints to the inference context
        -- This is crucial for constraint-aware unification inside the definition body
        -- e.g., when (.) has {Num a}, this constraint must be visible when type-checking t1 * t2
        Infer ()
clearConstraints  -- Start fresh
        [Constraint] -> Infer ()
addConstraints [Constraint]
instConstraints

        -- Infer the expression type
        (TIExpr
exprTI, Subst
subst1) <- IExpr -> Infer (TIExpr, Subst)
inferIExpr IExpr
expr
        let exprType :: Type
exprType = TIExpr -> Type
tiExprType TIExpr
exprTI

        -- Unify inferred type with expected type using constraint-aware unification
        -- This is crucial for cases like (.) where type variables have constraints
        -- The constraints from the type signature affect how Tensor types are unified
        let exprCtx :: TypeErrorContext
exprCtx = String -> TypeErrorContext -> TypeErrorContext
withExpr (IExpr -> String
forall a. Pretty a => a -> String
prettyStr IExpr
expr) TypeErrorContext
emptyContext
            -- Apply substitution to constraints to get current state
            currentConstraints :: [Constraint]
currentConstraints = (Constraint -> Constraint) -> [Constraint] -> [Constraint]
forall a b. (a -> b) -> [a] -> [b]
map (Subst -> Constraint -> Constraint
applySubstConstraint Subst
subst1) [Constraint]
instConstraints
        Type
exprType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
subst1 Type
exprType
        Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
subst1 Type
expectedType
        Subst
subst2 <- [Constraint] -> Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithConstraints [Constraint]
currentConstraints Type
exprType' Type
expectedType' TypeErrorContext
exprCtx
        let finalSubst :: Subst
finalSubst = Subst -> Subst -> Subst
composeSubst Subst
subst2 Subst
subst1

        -- Apply final substitution to exprTI to resolve all type variables
        -- IMPORTANT: Use applySubstToTIExprM to adjust substitution based on constraints
        TIExpr
exprTI' <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
finalSubst TIExpr
exprTI

        -- Resolve constraints in exprTI' (Tensor t0 -> t0)
        ClassEnv
classEnv <- Infer ClassEnv
getClassEnv
        let exprTI'' :: TIExpr
exprTI'' = ClassEnv -> Subst -> TIExpr -> TIExpr
resolveConstraintsInTIExpr ClassEnv
classEnv Subst
finalSubst TIExpr
exprTI'
        
        -- Reconstruct type scheme from exprTI'' to match actual type variables
        -- Use instantiated constraints and apply final substitution
        -- When there's an explicit type annotation, use the expected type
        -- (with substitutions applied) as the final type, not the inferred type.
        -- This ensures that Tensor types are preserved when explicitly annotated.
        Type
finalType <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
finalSubst Type
expectedType
        let constraints' :: [Constraint]
constraints' = (Constraint -> Constraint) -> [Constraint] -> [Constraint]
forall a b. (a -> b) -> [a] -> [b]
map (Subst -> Constraint -> Constraint
applySubstConstraint Subst
finalSubst) [Constraint]
instConstraints
            envFreeVars :: Set TyVar
envFreeVars = TypeEnv -> Set TyVar
freeVarsInEnv TypeEnv
env
            typeFreeVars :: Set TyVar
typeFreeVars = Type -> Set TyVar
freeTyVars Type
finalType
            genVars :: [TyVar]
genVars = Set TyVar -> [TyVar]
forall a. Set a -> [a]
Set.toList (Set TyVar -> [TyVar]) -> Set TyVar -> [TyVar]
forall a b. (a -> b) -> a -> b
$ Set TyVar
typeFreeVars Set TyVar -> Set TyVar -> Set TyVar
forall a. Ord a => Set a -> Set a -> Set a
`Set.difference` Set TyVar
envFreeVars
            updatedScheme :: TypeScheme
updatedScheme = [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [TyVar]
genVars [Constraint]
constraints' Type
finalType
        
        -- Keep the updated scheme (with actual type variables) in the environment
        (Maybe TITopExpr, Subst) -> Infer (Maybe TITopExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TITopExpr -> Maybe TITopExpr
forall a. a -> Maybe a
Just (TypeScheme -> Var -> TIExpr -> TITopExpr
TIDefine TypeScheme
updatedScheme Var
var TIExpr
exprTI''), Subst
finalSubst)
      
      Maybe TypeScheme
Nothing -> do
        -- No explicit type signature: infer and generalize as before
        Infer ()
clearConstraints  -- Start with fresh constraints for this expression
        (TIExpr
exprTI, Subst
subst) <- IExpr -> Infer (TIExpr, Subst)
inferIExpr IExpr
expr
        let exprType :: Type
exprType = TIExpr -> Type
tiExprType TIExpr
exprTI
        [Constraint]
constraints <- Infer [Constraint]
getConstraints  -- Collect constraints from type inference
        
        -- Resolve constraints based on available instances
        ClassEnv
classEnv <- Infer ClassEnv
getClassEnv
        let updatedConstraints :: [Constraint]
updatedConstraints = (Constraint -> Constraint) -> [Constraint] -> [Constraint]
forall a b. (a -> b) -> [a] -> [b]
map (ClassEnv -> Subst -> Constraint -> Constraint
resolveConstraintWithInstances ClassEnv
classEnv Subst
subst) [Constraint]
constraints
            -- Filter out constraints on concrete types (non-type-variables)
            -- Concrete constraints don't need to be generalized since the type is already determined
            isTypeVarConstraint :: Constraint -> Bool
isTypeVarConstraint (Constraint String
_ (TVar TyVar
_)) = Bool
True
            isTypeVarConstraint Constraint
_ = Bool
False
            -- Deduplicate constraints (e.g., {Num a, Num a} -> {Num a})
            generalizedConstraints :: [Constraint]
generalizedConstraints = [Constraint] -> [Constraint]
forall a. Eq a => [a] -> [a]
nub ([Constraint] -> [Constraint]) -> [Constraint] -> [Constraint]
forall a b. (a -> b) -> a -> b
$ (Constraint -> Bool) -> [Constraint] -> [Constraint]
forall a. (a -> Bool) -> [a] -> [a]
filter Constraint -> Bool
isTypeVarConstraint [Constraint]
updatedConstraints

        -- Generalize with filtered constraints (only type variables)
        let envFreeVars :: Set TyVar
envFreeVars = TypeEnv -> Set TyVar
freeVarsInEnv TypeEnv
env
            typeFreeVars :: Set TyVar
typeFreeVars = Type -> Set TyVar
freeTyVars Type
exprType
            genVars :: [TyVar]
genVars = Set TyVar -> [TyVar]
forall a. Set a -> [a]
Set.toList (Set TyVar -> [TyVar]) -> Set TyVar -> [TyVar]
forall a b. (a -> b) -> a -> b
$ Set TyVar
typeFreeVars Set TyVar -> Set TyVar -> Set TyVar
forall a. Ord a => Set a -> Set a -> Set a
`Set.difference` Set TyVar
envFreeVars
            scheme :: TypeScheme
scheme = [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [TyVar]
genVars [Constraint]
generalizedConstraints Type
exprType
        
        -- Add to environment using the Var directly (preserves index info)
        (InferState -> InferState) -> Infer ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((InferState -> InferState) -> Infer ())
-> (InferState -> InferState) -> Infer ()
forall a b. (a -> b) -> a -> b
$ \InferState
s -> InferState
s { inferEnv = extendEnv var scheme (inferEnv s) }
        
        (Maybe TITopExpr, Subst) -> Infer (Maybe TITopExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TITopExpr -> Maybe TITopExpr
forall a. a -> Maybe a
Just (TypeScheme -> Var -> TIExpr -> TITopExpr
TIDefine TypeScheme
scheme Var
var TIExpr
exprTI), Subst
subst)
  
  ITest IExpr
expr -> do
    Infer ()
clearConstraints  -- Start with fresh constraints
    (TIExpr
exprTI, Subst
subst) <- IExpr -> Infer (TIExpr, Subst)
inferIExpr IExpr
expr
    -- Constraints are now in state, will be retrieved by Eval.hs
    (Maybe TITopExpr, Subst) -> Infer (Maybe TITopExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TITopExpr -> Maybe TITopExpr
forall a. a -> Maybe a
Just (TIExpr -> TITopExpr
TITest TIExpr
exprTI), Subst
subst)
  
  IExecute IExpr
expr -> do
    Infer ()
clearConstraints  -- Start with fresh constraints
    (TIExpr
exprTI, Subst
subst) <- IExpr -> Infer (TIExpr, Subst)
inferIExpr IExpr
expr
    -- Constraints are now in state, will be retrieved by Eval.hs
    (Maybe TITopExpr, Subst) -> Infer (Maybe TITopExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TITopExpr -> Maybe TITopExpr
forall a. a -> Maybe a
Just (TIExpr -> TITopExpr
TIExecute TIExpr
exprTI), Subst
subst)
  
  ILoadFile String
_path -> (Maybe TITopExpr, Subst) -> Infer (Maybe TITopExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe TITopExpr
forall a. Maybe a
Nothing, Subst
emptySubst)
  ILoad String
_lib -> (Maybe TITopExpr, Subst) -> Infer (Maybe TITopExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe TITopExpr
forall a. Maybe a
Nothing, Subst
emptySubst)

  IDefineMany [(Var, IExpr)]
bindings -> do
    -- Process each binding in the list
    TypeEnv
env <- Infer TypeEnv
getEnv
    [((Var, TIExpr), Subst)]
results <- ((Var, IExpr)
 -> ExceptT TypeError (StateT InferState IO) ((Var, TIExpr), Subst))
-> [(Var, IExpr)]
-> ExceptT
     TypeError (StateT InferState IO) [((Var, TIExpr), Subst)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (TypeEnv
-> (Var, IExpr)
-> ExceptT TypeError (StateT InferState IO) ((Var, TIExpr), Subst)
inferBinding TypeEnv
env) [(Var, IExpr)]
bindings
    let bindingsTI :: [(Var, TIExpr)]
bindingsTI = (((Var, TIExpr), Subst) -> (Var, TIExpr))
-> [((Var, TIExpr), Subst)] -> [(Var, TIExpr)]
forall a b. (a -> b) -> [a] -> [b]
map ((Var, TIExpr), Subst) -> (Var, TIExpr)
forall a b. (a, b) -> a
fst [((Var, TIExpr), Subst)]
results
        substs :: [Subst]
substs = (((Var, TIExpr), Subst) -> Subst)
-> [((Var, TIExpr), Subst)] -> [Subst]
forall a b. (a -> b) -> [a] -> [b]
map ((Var, TIExpr), Subst) -> Subst
forall a b. (a, b) -> b
snd [((Var, TIExpr), Subst)]
results
        combinedSubst :: Subst
combinedSubst = (Subst -> Subst -> Subst) -> Subst -> [Subst] -> Subst
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr Subst -> Subst -> Subst
composeSubst Subst
emptySubst [Subst]
substs
    (Maybe TITopExpr, Subst) -> Infer (Maybe TITopExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TITopExpr -> Maybe TITopExpr
forall a. a -> Maybe a
Just ([(Var, TIExpr)] -> TITopExpr
TIDefineMany [(Var, TIExpr)]
bindingsTI), Subst
combinedSubst)
    where
      inferBinding :: TypeEnv
-> (Var, IExpr)
-> ExceptT TypeError (StateT InferState IO) ((Var, TIExpr), Subst)
inferBinding TypeEnv
env (Var
var, IExpr
expr) = do
        let varName :: String
varName = Var -> String
extractNameFromVar Var
var
        -- Check if there's an existing type signature
        case Var -> TypeEnv -> Maybe TypeScheme
lookupEnv Var
var TypeEnv
env of
          Just TypeScheme
existingScheme -> do
            -- With type signature: check type
            InferState
st <- ExceptT TypeError (StateT InferState IO) InferState
forall s (m :: * -> *). MonadState s m => m s
get
            let ([Constraint]
_, Type
expectedType, Int
newCounter) = TypeScheme -> Int -> ([Constraint], Type, Int)
instantiate TypeScheme
existingScheme (InferState -> Int
inferCounter InferState
st)
            (InferState -> InferState) -> Infer ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((InferState -> InferState) -> Infer ())
-> (InferState -> InferState) -> Infer ()
forall a b. (a -> b) -> a -> b
$ \InferState
s -> InferState
s { inferCounter = newCounter }
            
            Infer ()
clearConstraints
            (TIExpr
exprTI, Subst
subst1) <- IExpr -> Infer (TIExpr, Subst)
inferIExpr IExpr
expr
            let exprType :: Type
exprType = TIExpr -> Type
tiExprType TIExpr
exprTI
            Type
exprType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
subst1 Type
exprType
            Type
expectedType' <- Subst -> Type -> Infer Type
applySubstWithConstraintsM Subst
subst1 Type
expectedType
            Subst
subst2 <- Type -> Type -> TypeErrorContext -> Infer Subst
unifyTypesWithTopLevel Type
exprType' Type
expectedType' TypeErrorContext
emptyContext
            let finalSubst :: Subst
finalSubst = Subst -> Subst -> Subst
composeSubst Subst
subst2 Subst
subst1
            TIExpr
exprTI' <- Subst -> TIExpr -> Infer TIExpr
applySubstToTIExprM Subst
finalSubst TIExpr
exprTI
            ((Var, TIExpr), Subst)
-> ExceptT TypeError (StateT InferState IO) ((Var, TIExpr), Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return ((Var
var, TIExpr
exprTI'), Subst
finalSubst)
          
          Maybe TypeScheme
Nothing -> do
            -- Without type signature: infer and generalize
            Infer ()
clearConstraints
            (TIExpr
exprTI, Subst
subst) <- IExpr -> Infer (TIExpr, Subst)
inferIExpr IExpr
expr
            let exprType :: Type
exprType = TIExpr -> Type
tiExprType TIExpr
exprTI
            [Constraint]
constraints <- Infer [Constraint]
getConstraints
            
            -- Resolve constraints based on available instances
            ClassEnv
classEnv <- Infer ClassEnv
getClassEnv
            let updatedConstraints :: [Constraint]
updatedConstraints = (Constraint -> Constraint) -> [Constraint] -> [Constraint]
forall a b. (a -> b) -> [a] -> [b]
map (ClassEnv -> Subst -> Constraint -> Constraint
resolveConstraintWithInstances ClassEnv
classEnv Subst
subst) [Constraint]
constraints
                -- Filter out constraints on concrete types (non-type-variables)
                isTypeVarConstraint :: Constraint -> Bool
isTypeVarConstraint (Constraint String
_ (TVar TyVar
_)) = Bool
True
                isTypeVarConstraint Constraint
_ = Bool
False
                -- Deduplicate constraints (e.g., {Num a, Num a} -> {Num a})
                generalizedConstraints :: [Constraint]
generalizedConstraints = [Constraint] -> [Constraint]
forall a. Eq a => [a] -> [a]
nub ([Constraint] -> [Constraint]) -> [Constraint] -> [Constraint]
forall a b. (a -> b) -> a -> b
$ (Constraint -> Bool) -> [Constraint] -> [Constraint]
forall a. (a -> Bool) -> [a] -> [a]
filter Constraint -> Bool
isTypeVarConstraint [Constraint]
updatedConstraints

            -- Generalize the type
            let envFreeVars :: Set TyVar
envFreeVars = TypeEnv -> Set TyVar
freeVarsInEnv TypeEnv
env
                typeFreeVars :: Set TyVar
typeFreeVars = Type -> Set TyVar
freeTyVars Type
exprType
                genVars :: [TyVar]
genVars = Set TyVar -> [TyVar]
forall a. Set a -> [a]
Set.toList (Set TyVar -> [TyVar]) -> Set TyVar -> [TyVar]
forall a b. (a -> b) -> a -> b
$ Set TyVar
typeFreeVars Set TyVar -> Set TyVar -> Set TyVar
forall a. Ord a => Set a -> Set a -> Set a
`Set.difference` Set TyVar
envFreeVars
                scheme :: TypeScheme
scheme = [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [TyVar]
genVars [Constraint]
generalizedConstraints Type
exprType
            
            -- Add to environment for subsequent bindings using Var directly
            (InferState -> InferState) -> Infer ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((InferState -> InferState) -> Infer ())
-> (InferState -> InferState) -> Infer ()
forall a b. (a -> b) -> a -> b
$ \InferState
s -> InferState
s { inferEnv = extendEnv var scheme (inferEnv s) }
            
            ((Var, TIExpr), Subst)
-> ExceptT TypeError (StateT InferState IO) ((Var, TIExpr), Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return ((Var
var, TIExpr
exprTI), Subst
subst)
  
  IPatternFunctionDecl String
name [TyVar]
tyVars [(String, Type)]
params Type
retType IPattern
body -> do
    -- Pattern function type checking:
    -- 1. Add parameters to environment for type checking
    -- 2. Infer body pattern with expected return type
    -- 3. Create type scheme with type parameters
    
    Infer ()
clearConstraints  -- Start fresh
    
    -- Add parameters to environment for type checking the body
    -- Note: Parameter types don't need Pattern wrapper (design/pattern.md)
    let paramBindings :: [(String, TypeScheme)]
paramBindings = ((String, Type) -> (String, TypeScheme))
-> [(String, Type)] -> [(String, TypeScheme)]
forall a b. (a -> b) -> [a] -> [b]
map (\(String
pname, Type
pty) -> (String
pname, [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
pty)) [(String, Type)]
params
    [(String, TypeScheme)]
-> Infer (Maybe TITopExpr, Subst) -> Infer (Maybe TITopExpr, Subst)
forall a. [(String, TypeScheme)] -> Infer a -> Infer a
withEnv [(String, TypeScheme)]
paramBindings (Infer (Maybe TITopExpr, Subst) -> Infer (Maybe TITopExpr, Subst))
-> Infer (Maybe TITopExpr, Subst) -> Infer (Maybe TITopExpr, Subst)
forall a b. (a -> b) -> a -> b
$ do
      -- Infer body pattern with expected return type
      let ctx :: TypeErrorContext
ctx = TypeErrorContext 
                  { errorLocation :: Maybe SourceLocation
errorLocation = Maybe SourceLocation
forall a. Maybe a
Nothing
                  , errorExpr :: Maybe String
errorExpr = String -> Maybe String
forall a. a -> Maybe a
Just (String
"Pattern function: " String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
name)
                  , errorContext :: Maybe String
errorContext = String -> Maybe String
forall a. a -> Maybe a
Just (String
"Expected type: " String -> ShowS
forall a. [a] -> [a] -> [a]
++ Type -> String
forall a. Show a => a -> String
show Type
retType)
                  }
      (TIPattern
tiBody, [(String, Type)]
_bodyBindings, Subst
subst) <- IPattern
-> Type
-> TypeErrorContext
-> Infer (TIPattern, [(String, Type)], Subst)
inferIPattern IPattern
body Type
retType TypeErrorContext
ctx
      
      -- Note: Pattern variables that reference parameters (using ~param) will appear in bodyBindings
      -- but they are NOT conflicts - they are references to the parameters themselves.
      -- Only NEW variable bindings (using $var) would be actual conflicts.
      -- Since the pattern body uses ~p1 and ~p2 (pattern variable references), 
      -- not $p1 and $p2 (new bindings), we don't need to check for conflicts here.
      -- The existing semantics already handle this correctly during pattern matching.
      
      -- Create type scheme with type parameters
      -- Pattern function type: param1 -> param2 -> ... -> retType
      let paramTypes :: [Type]
paramTypes = ((String, Type) -> Type) -> [(String, Type)] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map (String, Type) -> Type
forall a b. (a, b) -> b
snd [(String, Type)]
params
          funcType :: Type
funcType = (Type -> Type -> Type) -> Type -> [Type] -> Type
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr Type -> Type -> Type
TFun Type
retType [Type]
paramTypes
          typeScheme :: TypeScheme
typeScheme = [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [TyVar]
tyVars [] Type
funcType
      
      -- Add pattern function to both inferPatternFuncEnv and inferEnv
      -- This allows the type checker to recognize it in subsequent declarations
      (InferState -> InferState) -> Infer ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((InferState -> InferState) -> Infer ())
-> (InferState -> InferState) -> Infer ()
forall a b. (a -> b) -> a -> b
$ \InferState
s -> InferState
s { 
        inferPatternFuncEnv = extendPatternEnv name typeScheme (inferPatternFuncEnv s),
        inferEnv = extendEnv (stringToVar name) typeScheme (inferEnv s)
      }
      
      (Maybe TITopExpr, Subst) -> Infer (Maybe TITopExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TITopExpr -> Maybe TITopExpr
forall a. a -> Maybe a
Just (String
-> TypeScheme -> [(String, Type)] -> Type -> TIPattern -> TITopExpr
TIPatternFunctionDecl String
name TypeScheme
typeScheme [(String, Type)]
params Type
retType TIPattern
tiBody), Subst
subst)
  
  IDeclareSymbol [String]
names Maybe Type
mType -> do
    -- Register declared symbols with their types
    let ty :: Type
ty = case Maybe Type
mType of
               Just Type
t  -> Type
t
               Maybe Type
Nothing -> Type
TInt  -- Default to Integer (MathExpr)
    -- Add symbols to declared symbols map
    (InferState -> InferState) -> Infer ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((InferState -> InferState) -> Infer ())
-> (InferState -> InferState) -> Infer ()
forall a b. (a -> b) -> a -> b
$ \InferState
s -> InferState
s { declaredSymbols = 
                        foldr (\String
name Map String Type
m -> String -> Type -> Map String Type -> Map String Type
forall k a. Ord k => k -> a -> Map k a -> Map k a
Map.insert String
name Type
ty Map String Type
m) 
                              (declaredSymbols s) 
                              names }
    -- Also add to type environment so they can be used in subsequent expressions
    let scheme :: TypeScheme
scheme = [TyVar] -> [Constraint] -> Type -> TypeScheme
Forall [] [] Type
ty
    (InferState -> InferState) -> Infer ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((InferState -> InferState) -> Infer ())
-> (InferState -> InferState) -> Infer ()
forall a b. (a -> b) -> a -> b
$ \InferState
s -> InferState
s { inferEnv = 
                        foldr (\String
name TypeEnv
e -> Var -> TypeScheme -> TypeEnv -> TypeEnv
extendEnv (String -> Var
stringToVar String
name) TypeScheme
scheme TypeEnv
e) 
                              (inferEnv s) 
                              names }
    -- Return the typed declaration
    (Maybe TITopExpr, Subst) -> Infer (Maybe TITopExpr, Subst)
forall a. a -> ExceptT TypeError (StateT InferState IO) a
forall (m :: * -> *) a. Monad m => a -> m a
return (TITopExpr -> Maybe TITopExpr
forall a. a -> Maybe a
Just ([String] -> Type -> TITopExpr
TIDeclareSymbol [String]
names Type
ty), Subst
emptySubst)

-- | Infer a sequence of top-level IExprs, left to right.
--
-- Declarations are inferred in order inside the 'Infer' monad. State updates
-- made by an earlier declaration are therefore visible to the later ones
-- (e.g. symbols or pattern functions registered in 'inferEnv').
-- The first failing declaration aborts the whole sequence.
--
-- Returns one result per input, in input order, paired with the composition
-- of all per-declaration substitutions.
inferITopExprs :: [ITopExpr] -> Infer ([Maybe TITopExpr], Subst)
inferITopExprs topExprs = do
  results <- mapM inferITopExpr topExprs
  let (typedTops, substs) = unzip results
      -- Right fold with the accumulated tail as the first argument, giving
      -- composeSubst (composeSubst (... sN ...) s2) s1.
      combined = foldr (flip composeSubst) emptySubst substs
  return (typedTops, combined)

--------------------------------------------------------------------------------
-- * Running Inference
--------------------------------------------------------------------------------

-- | Run type inference on a single 'IExpr'.
--
-- The starting state comes from 'initialInferStateWithConfig' applied to
-- @cfg@, with its 'inferEnv' replaced by @env@.
--
-- On success the result is a triple:
--
-- * the type of the inferred typed expression (via 'tiExprType'),
-- * the substitution produced by inference,
-- * the warnings collected during the run.
--
-- On failure the 'TypeError' is returned and the warnings are dropped.
runInferI :: InferConfig -> TypeEnv -> IExpr -> IO (Either TypeError (Type, Subst, [TypeWarning]))
runInferI cfg env expr = do
  let startState = (initialInferStateWithConfig cfg) { inferEnv = env }
  (outcome, warns) <- runInferWithWarnings (inferIExpr expr) startState
  let finish (typed, sub) = (tiExprType typed, sub, warns)
  return (fmap finish outcome)

-- | Run type inference on a single 'IExpr' and also return the final
-- type environment.
--
-- The starting state comes from 'initialInferStateWithConfig' applied to
-- @cfg@, with its 'inferEnv' replaced by @env@.
--
-- On success the result is a 4-tuple:
--
-- * the type of the inferred typed expression (via 'tiExprType'),
-- * the substitution produced by inference,
-- * the 'inferEnv' of the state after inference finished,
-- * the warnings collected during the run.
--
-- On failure the 'TypeError' is returned; the warnings and final state are
-- dropped.
runInferIWithEnv :: InferConfig -> TypeEnv -> IExpr -> IO (Either TypeError (Type, Subst, TypeEnv, [TypeWarning]))
runInferIWithEnv cfg env expr = do
  let startState = (initialInferStateWithConfig cfg) { inferEnv = env }
  (outcome, warns, endState) <-
    runInferWithWarningsAndState (inferIExpr expr) startState
  let finish (typed, sub) = (tiExprType typed, sub, inferEnv endState, warns)
  return (fmap finish outcome)