{-# LANGUAGE TemplateHaskell #-}

module TH (test) where

import Control.Monad

import Language.Haskell.TH
import Language.Haskell.TH.TypeInterpreter

import Types

-- | Assert that an interpreted type expression equals the expected one.
--
-- Does nothing when the two expressions are equal. Otherwise the label,
-- the expected expression and the actual expression are reported as
-- compile-time errors, and the splice is aborted with 'fail'.
runCase :: String -> TypeExp -> TypeExp -> Q ()
runCase label actual expected =
    unless (expected == actual) $ do
        mapM_ reportError
            [ "Mismatch on " ++ label ++ "!"
            , "Expected: " ++ show expected
            , "Actual:   " ++ show actual
            ]
        fail "Case failed"

-- | Assert that an interpreted type expression has an expected shape.
--
-- The matcher returns 'Just' to accept the expression and 'Nothing' to
-- reject it. On rejection, the label and the offending expression are
-- reported as compile-time errors, and the splice is aborted with 'fail'.
matchCase :: String -> TypeExp -> (TypeExp -> Maybe a) -> Q ()
matchCase label input matcher =
    -- The local is named 'input' rather than 'exp' so it does not shadow
    -- 'Prelude.exp' (a -Wname-shadowing warning under -Wall).
    case matcher input of
        Just _  -> pure ()
        Nothing -> do
            reportError ("Mismatch on " ++ label ++ "!")
            reportError ("Got: " ++ show input)
            fail "Case failed"

-- | A name built from the empty string.
--
-- It is used in expected values for type synonyms (syn8, syn9) where the
-- bound parameter name appears. NOTE(review): this presumably relies on
-- the 'Eq' instance of 'TypeExp' ignoring concrete binder names; confirm
-- against the interpreter library.
anyName :: Name
anyName = mkName emptyIdent
  where
    emptyIdent = ""

-- | Build a curried function type from parameter types and a result type.
--
-- Each arrow is represented as the @(->)@ type constructor applied to the
-- parameter and the rest of the type. For example, @functionType [a, b] r@
-- corresponds to @a -> b -> r@. An empty parameter list returns the result
-- type unchanged.
functionType :: [TypeExp] -> TypeExp -> TypeExp
functionType params result = foldr arrow result params
  where
    arrow param rest = Apply (Apply (Atom (Name ''(->))) param) rest

-- | Build a promoted (type-level) list from a list of type expressions.
--
-- Each element is consed on with the promoted @(:)@ data constructor. The
-- list is terminated with the promoted @[]@ data constructor.
listType :: [TypeExp] -> TypeExp
listType = foldr cons nil
  where
    nil = Atom (PromotedName '[])
    cons x rest = Apply (Apply (Atom (PromotedName '(:))) x) rest

-- | Compile-time test suite for the type interpreter.
--
-- Interprets a series of names with 'fromName': type constructors, type
-- classes, type synonyms and type families, the latter presumably
-- declared in "Types". Each result is compared against a hand-written
-- expected 'TypeExp'. Any mismatch is reported and compilation is
-- aborted. On success the splice evaluates to the unit expression @()@.
test :: Q Exp
test = do
    con1 <- fromName ''Int
    con2 <- fromName ''Maybe

    primCon1 <- fromName ''[]
    primCon2 <- fromName ''(->)

    class1 <- fromName ''Num
    class2 <- fromName ''Class1

    syn1 <- fromName ''Syn1
    syn2 <- fromName ''Syn2
    syn3 <- fromName ''Syn3
    syn4 <- fromName ''Syn4
    syn5 <- fromName ''Syn5
    syn6 <- fromName ''Syn6
    syn7 <- fromName ''Syn7
    syn8 <- fromName ''Syn8
    syn9 <- fromName ''Syn9
    syn10 <- fromName ''Syn10
    syn11 <- fromName ''Syn11
    syn12 <- fromName ''Syn12
    syn13 <- fromName ''Syn13
    syn14 <- fromName ''Syn14
    syn15 <- fromName ''Syn15
    syn16 <- fromName ''Syn16
    syn17 <- fromName ''Syn17

    fam1 <- fromName ''Fam1
    fam2 <- fromName ''Fam2
    fam3 <- fromName ''Fam3
    fam4 <- fromName ''Fam4
    fam5 <- fromName ''Fam5
    fam6 <- fromName ''Fam6
    fam7 <- fromName ''Fam7
    fam8 <- fromName ''Fam8

    -- NOTE(review): Fam8 is looked up a second time and the result is
    -- discarded. This presumably checks that a repeated lookup also
    -- succeeds; confirm the intent, or remove the line.
    _ <- fromName ''Fam8

    -- Type constructors
    runCase "con1" con1 (Atom (Name ''Int))
    runCase "con2" con2 (Atom (Name ''Maybe))

    -- Primitive type constructors
    runCase "primCon1" primCon1 (Atom (Name ''[]))
    runCase "primCon2" primCon2 (Atom (Name ''(->)))

    -- Type classes
    runCase "class1" class1 (Atom (Name ''Num))
    runCase "class2" class2 (Atom (Name ''Class1))

    -- Type synonyms
    runCase "syn1" syn1 (Atom (Integer 1337))
    runCase "syn2" syn2 (Atom (String "Hello"))
    runCase "syn3" syn3 (Atom (Name ''Char))
    runCase "syn4" syn4 (Atom (PromotedName 'Nothing))
    runCase "syn5" syn5 syn2
    runCase "syn6" syn6 (Apply (Atom (Name ''[])) (Atom (Name ''Char)))
    runCase "syn7" syn7 (Apply (Apply (Atom (Name ''(,))) syn6) syn3)
    runCase "syn8" syn8 (Synonym anyName (Variable anyName))
    runCase "syn9" syn9 (Synonym anyName syn7)
    runCase "syn10" syn10 (Atom (Name ''Maybe))
    -- The shape checks below bind with a failable pattern in 'Maybe'. An
    -- unexpected shape therefore yields 'Nothing', so 'matchCase' reports
    -- a proper mismatch. An irrefutable 'let' pattern would instead crash
    -- the splice with a pattern-match exception.
    matchCase "syn11" syn11 $ \ input -> do
        Synonym p (Apply (Apply (Atom (Name f)) (Variable n)) (Variable m)) <- Just input
        guard (f == ''Either && p == n && n == m)
    runCase "syn12" syn12 syn3
    runCase "syn13" syn13 syn2
    runCase "syn14" syn14 (Apply (Apply fam4 (Atom (Name ''Int))) (Atom (Name ''Word)))
    runCase "syn15" syn15 (Apply (Apply fam7 (Atom (String "World"))) (Atom (Integer 42)))
    matchCase "syn16" syn16 $ \ input -> do
        Synonym p (Apply (Apply f (Variable n)) b) <- Just input
        guard (f == fam4 && p == n && b == syn6)
    runCase "syn17" syn17 (Function [TypeEquation [syn1] syn4])

    -- Type families
    runCase "fam1" fam1 (Function [])
    runCase "fam2" fam2 syn7
    runCase "fam3" fam3 (Function [])
    runCase "fam4" fam4 (Function [TypeEquation [syn3, syn7] syn6, TypeEquation [syn7, syn6] syn3])
    runCase "fam5" fam5 (Function [])
    runCase "fam6" fam6 syn2
    runCase "fam7" fam7 (Function [TypeEquation [syn2, syn1] syn4])

    -- Reduction: Fam8 applied to (Int -> Char -> String) is expected to
    -- reduce to the promoted list of the function's parameter types.
    let params1  = [Atom (Name ''Int), Atom (Name ''Char)]
        funType1 = functionType params1 (Atom (Name ''String))
    runCase "fam8" (reduce (Apply fam8 funType1)) (listType params1)

    pure (ConE '())

{-
    Tests that are still missing:

    + Type families
        TODO: decide which additional cases to cover.

    + Applications
        TODO: decide which additional cases to cover.
-}