{-# LANGUAGE OverloadedStrings #-}

module UnionPoly (unionPolyTests) where

import G2.Language
import qualified G2.Language.ExprEnv as E
import G2.Liquid.Inference.UnionPoly

import qualified Data.Text as T

import Test.Tasty
import Test.Tasty.HUnit

-- | The @UnionPoly@ test group.
--
-- Each entry pairs a test label and a failure message with a precomputed
-- 'Bool'; the case passes exactly when that 'Bool' is 'True'.
unionPolyTests :: TestTree
unionPolyTests = testGroup "UnionPoly" (map toCase checks)
  where
    -- Turn one (label, failure message, outcome) row into an HUnit case.
    toCase (label, failMsg, outcome) = testCase label (assertBool failMsg outcome)

    checks :: [(String, String, Bool)]
    checks =
        [ ("Let unification", "Polymorphic unification failed with lets", letTest)
        , ("Lambda unification 1", "Polymorphic unification failed with lambdas", lambdaTest1)
        , ("Lambda unification 2", "Polymorphic unification failed with lambdas", lambdaTest2)
        ]

-- | Runs 'sharedTyConsEE' over the environment built by 'letExprEnv' and
-- inspects the entry that 'lookupUT' returns for @g@.
--
-- The result is 'True' only when that entry is a function type whose argument
-- is a type variable and whose result type is equal to that same argument.
-- Any other shape, or a missing entry, yields 'False'.
letTest :: Bool
letTest = argEqualsResult (lookupUT gName unified)
  where
    fName = defName "f"
    gName = defName "g"

    unified = sharedTyConsEE [fName, gName] (letExprEnv fName gName)

    -- Accept only @tv -> tv@ where @tv@ is a type variable.
    argEqualsResult (Just (TyFun argTy@(TyVar _) resTy)) = argTy == resTy
    argEqualsResult _ = False

-- | Builds the two-binding expression environment used by 'letTest'.
--
-- * The first name is bound to @let r = g in call \@Int r@, where @call@ is a
--   variable of type @forall a. (a -> a) -> a@ that is /not/ itself added to
--   the environment, and both @r@ and @g@ are typed @Int -> Int@.
-- * The second name is bound to a lambda over @x :: Int@ whose body is the
--   variable @y :: Int@. Note that @y@ is a different name from the lambda's
--   binder, so the body does not refer to the argument.
letExprEnv :: Name -> Name -> ExprEnv
letExprEnv f g = E.fromList [(f, fBody), (g, gBody)]
  where
    -- The quantified type variable @a@ and the concrete type @Int@.
    tvId = Id (defName "a") TYPE
    tvTy = TyVar tvId
    intTy = TyCon (defName "Int") TYPE
    intToInt = TyFun intTy intTy

    gVar = Var (Id g intToInt)
    rBinder = Id (defName "r") intToInt
    callVar =
        Var (Id (defName "call") (TyForAll tvId (TyFun (TyFun tvTy tvTy) tvTy)))

    fBody =
        Let [(rBinder, gVar)]
            (App (App callVar (Type intTy)) (Var rBinder))

    xBinder = Id (defName "x") intTy
    yVar = Var (Id (defName "y") intTy)
    gBody = Lam TermL xBinder yVar

-- | Runs 'sharedTyConsEE' over the environment built by 'lambdaExprEnv1' and
-- inspects the entry that 'lookupUT' returns for @h@.
--
-- The result is 'True' only when that entry is a function type whose argument
-- is a type variable and whose result type is equal to that same argument.
-- Any other shape, or a missing entry, yields 'False'.
lambdaTest1 :: Bool
lambdaTest1 = argEqualsResult (lookupUT hName unified)
  where
    fName = defName "f"
    gName = defName "g"
    hName = defName "h"

    unified =
        sharedTyConsEE [fName, gName, hName] (lambdaExprEnv1 fName gName hName)

    -- Accept only @tv -> tv@ where @tv@ is a type variable.
    argEqualsResult (Just (TyFun argTy@(TyVar _) resTy)) = argTy == resTy
    argEqualsResult _ = False

-- | Builds the three-binding expression environment used by 'lambdaTest1'.
--
-- * @g :: forall a. (a -> a) -> a@ is bound to @\\\@a j -> g \@a j@, i.e. it
--   calls itself with its own arguments.
-- * @h :: Int -> Int@ is bound to the identity lambda @\\x -> x@.
-- * @f@ is bound to @g \@Int (\\x -> h x)@, so @h@ reaches @g@ only through a
--   wrapping lambda.
lambdaExprEnv1 :: Name -> Name -> Name -> ExprEnv
lambdaExprEnv1 f g h = E.fromList [(f, fBody), (g, gBody), (h, hBody)]
  where
    -- The quantified type variable @a@ and the concrete type @Int@.
    tvId = Id (defName "a") TYPE
    tvTy = TyVar tvId
    intTy = TyCon (defName "Int") TYPE

    gVar = Var (Id g (TyForAll tvId (TyFun (TyFun tvTy tvTy) tvTy)))
    hVar = Var (Id h (TyFun intTy intTy))

    jBinder = Id (defName "j") (TyFun tvTy tvTy)
    xBinder = Id (defName "x") intTy

    -- Left-nested application: @applyTo e [a, b]@ is @App (App e a) b@.
    applyTo = foldl App

    gBody =
        Lam TypeL tvId
            (Lam TermL jBinder (applyTo gVar [Type tvTy, Var jBinder]))
    hBody = Lam TermL xBinder (Var xBinder)
    fBody =
        applyTo gVar
            [ Type intTy
            , Lam TermL xBinder (App hVar (Var xBinder))
            ]

-- | Runs 'sharedTyConsEE' over the environment built by 'lambdaExprEnv2' and
-- inspects the entry that 'lookupUT' returns for @h@.
--
-- The result is 'True' only when that entry is a function type whose argument
-- is a type variable and whose result type is equal to that same argument.
-- Any other shape, or a missing entry, yields 'False'.
lambdaTest2 :: Bool
lambdaTest2 = argEqualsResult (lookupUT hName unified)
  where
    fName = defName "f"
    gName = defName "g"
    hName = defName "h"

    unified =
        sharedTyConsEE [fName, gName, hName] (lambdaExprEnv2 fName gName hName)

    -- Accept only @tv -> tv@ where @tv@ is a type variable.
    argEqualsResult (Just (TyFun argTy@(TyVar _) resTy)) = argTy == resTy
    argEqualsResult _ = False

-- | Builds the three-binding expression environment used by 'lambdaTest2'.
--
-- * @g :: forall a. (Int -> Int) -> (a -> a) -> a@ is bound to
--   @\\\@a j k -> g \@a j k@, i.e. it calls itself with its own arguments.
-- * @h :: Int -> Int@ is bound to the identity lambda @\\x -> x@.
-- * @f@ is bound to @let bind = \\\@a y -> y in g \@Int (bind \@Int) h@, where
--   @bind :: forall a. a -> a@. So @h@ is passed in the polymorphic
--   @(a -> a)@ position and the let-bound @bind@ in the @(Int -> Int)@ one.
lambdaExprEnv2 :: Name -> Name -> Name -> ExprEnv
lambdaExprEnv2 f g h = E.fromList [(f, fBody), (g, gBody), (h, hBody)]
  where
    -- The quantified type variable @a@ and the concrete type @Int@.
    tvId = Id (defName "a") TYPE
    tvTy = TyVar tvId
    intTy = TyCon (defName "Int") TYPE

    intToInt = TyFun intTy intTy
    tvToTv = TyFun tvTy tvTy

    gVar = Var (Id g (TyForAll tvId (TyFun intToInt (TyFun tvToTv tvTy))))
    hVar = Var (Id h intToInt)

    jBinder = Id (defName "j") intToInt
    kBinder = Id (defName "k") tvToTv
    xBinder = Id (defName "x") intTy
    yBinder = Id (defName "y") tvTy

    bindBinder = Id (defName "bind") (TyForAll tvId tvToTv)

    -- Left-nested application: @applyTo e [a, b]@ is @App (App e a) b@.
    applyTo = foldl App

    gBody =
        Lam TypeL tvId
            (Lam TermL jBinder
                (Lam TermL kBinder
                    (applyTo gVar [Type tvTy, Var jBinder, Var kBinder])))
    hBody = Lam TermL xBinder (Var xBinder)
    fBody =
        Let [(bindBinder, Lam TypeL tvId (Lam TermL yBinder (Var yBinder)))]
            (applyTo gVar
                [ Type intTy
                , App (Var bindBinder) (Type intTy)
                , hVar
                ])

-- | Wraps the given text in a 'Name', filling the remaining constructor
-- fields with 'Nothing', @0@ and 'Nothing'.
--
-- NOTE(review): those fields are presumably an optional module qualifier, a
-- unique number and an optional source location; confirm against the 'Name'
-- definition in "G2.Language".
defName :: T.Text -> Name
defName occ = Name occ Nothing 0 Nothing