{-# LANGUAGE ScopedTypeVariables #-}

module BinaryDerive where

import Data.Generics
import Data.List

-- | Print the generated @Binary@ instance for the argument's type to stdout,
-- one 'putStrLn' per line of generated text.  The argument is used only for
-- its type: the value handed to 'derive' is a fresh 'undefined' of that type.
deriveM :: (Typeable a, Data a) => a -> IO ()
deriveM (_ :: t) =
    mapM_ putStrLn $ lines $ derive (undefined :: t)

-- | Generate the source text of a @Binary@ instance for the type of the
-- argument.  The argument is meant to be used only for its type: it is fed
-- to 'typeOf', 'dataTypeOf' and 'asTypeOf' ('deriveM' passes 'undefined').
--
-- Constructors are numbered from 0 in 'dataTypeConstrs' order.  When the
-- type has more than one constructor, the generated 'put' writes that
-- number with @putWord8@ before the fields, and 'get' reads it back with
-- @getWord8@ and dispatches on it (unknown tags @fail@).  A single
-- constructor writes only its fields.  Each type parameter becomes a
-- variable (@a@, @b@, ...) carrying a @Binary@ constraint.
--
-- NOTE(review): tags are emitted with @putWord8@, so a type with more than
-- 256 constructors would produce out-of-range tags; not handled here.
derive :: (Typeable a, Data a) => a -> String
derive x =
    "instance " ++ context ++ "Binary " ++ inst ++ " where\n"
    ++ concat putDefs ++ getDefs
  where
    -- "(Binary a, Binary b) => " for parameterised types, "" otherwise.
    context
        | nTypeChildren > 0 =
            parenIfParams (intercalate ", " (map ("Binary " ++) typeLetters))
                ++ " => "
        | otherwise = ""
    -- The instance head, e.g. "(Either a b)" or "Bool".
    inst = parenIfParams $ tyConName typeName ++ concatMap (' ' :) typeLetters
    parenIfParams s = if nTypeChildren > 0 then parens s else s
    parenIfFields ps s = if ps /= 0 then parens s else s
    parens s = "(" ++ s ++ ")"
    nTypeChildren = length typeChildren
    typeLetters = take nTypeChildren varNames
    -- Unbounded supply of distinct variable names: a..z, then a1..z1,
    -- a2..z2, ...  A finite a..z list would silently yield too few names
    -- for types with more than 26 fields or type parameters.
    varNames = [[c] | c <- ['a' .. 'z']]
        ++ [c : show i | i <- [(1 :: Int) ..], c <- ['a' .. 'z']]
    (typeName, typeChildren) = splitTyConApp (typeOf x)
    -- Whether a tag byte is needed to tell constructors apart.
    multi = length constrs > 1
    constrs :: [(Int, (String, Int))]
    constrs = zip [0 ..] $ map gen $ dataTypeConstrs (dataTypeOf x)
    -- (prefix-usable name, field count) of a constructor.  The field count
    -- is the length of 'gmapQ' over a placeholder built by 'fromConstr';
    -- only the list's length is used, so the 'undefined' query function is
    -- never applied.
    gen con = ( prefixName (showConstr con)
              , length $ gmapQ undefined $ fromConstr con `asTypeOf` x
              )
    -- 'showConstr' gives operator constructors bare (e.g. ":|"); they must
    -- be parenthesised to appear in prefix position in generated code.
    prefixName name@(':' : _) = parens name
    prefixName name = name
    putDefs
        -- A type without constructors has no values to write.
        | null constrs = ["  put _ = return ()\n"]
        | otherwise = map ((++ "\n") . putDef) constrs
    putDef (n, (name, ps)) =
        "  put " ++ parenIfFields ps pattern ++ " = "
        ++ concat [ "putWord8 " ++ show n | multi ]
        ++ concat [ " >> "                | multi && ps > 0 ]
        ++ concat [ "return ()"           | not multi && ps == 0 ]
        ++ intercalate " >> " (map ("put " ++) vars)
      where
        vars = take ps varNames
        pattern = name ++ concatMap (' ' :) vars
    getDefs
        | null constrs = "  get = fail \"no decoding\"\n"
        | multi =
            "  get = do\n    tag_ <- getWord8\n    case tag_ of\n"
            ++ concatMap ((++ "\n") . getDef) constrs
            ++ "      _ -> fail \"no decoding\""
        | otherwise = "  get =" ++ concatMap ((++ "\n") . getDef) constrs
    -- One decoder: an optional "n ->" case label, then one "get >>= \v ->"
    -- per field, then "return" of the reassembled constructor.
    getDef (n, (name, ps)) =
        concat [ "      " ++ show n ++ " ->" | multi ]
        ++ concatMap (\v -> " get >>= \\" ++ v ++ " ->") vars
        ++ " return "
        ++ parenIfFields ps (name ++ concatMap (' ' :) vars)
      where
        vars = take ps varNames