{-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

module Main (main) where

import Options.Applicative

import Data.Bifunctor (first)
import Data.Bits (bit)
import Data.ByteString.Builder (Builder)
import Data.ByteString.Builder qualified as Builder
import Data.Char (isAlphaNum, isSpace, isUpper, toUpper)
import Data.List qualified as List
import Data.Map.Strict qualified as Map
import Data.SpirV.Headers.Enum qualified as Enum
import Data.Text (Text)
import Data.Text qualified as Text
import Data.Text.Encoding (encodeUtf8Builder)
import System.Directory (createDirectoryIfMissing)
import System.FilePath (joinPath, (</>), (<.>))

-- | Parse the command line and generate the enum modules.
--
-- 'helper' wires up @--help@/@-h@; without it optparse-applicative rejects
-- the flag as an unknown option.
main :: IO ()
main = execParser opts >>= run
  where
    opts = info (helper <*> optionsP) $ mconcat
      [ fullDesc
      , progDesc "Generate Haskell modules for the enumerations in a SPIR-V headers JSON file (spirv.json)."
      ]

-- | Decode the SPIR-V headers JSON at @input@ and write one module per
-- enumeration below @output@, named under @modulePrefix@.
--
-- A decoding error aborts the program via 'fail'.
run :: Options -> IO ()
run Options{input, output, modulePrefix} = do
  decoded <- Enum.decode input
  Enum.Spv{enum = enums} <- either fail pure decoded
  createDirectoryIfMissing True output
  mapM_ (genEnumModule output modulePrefix) enums

-- | Render one SPIR-V enumeration as a Haskell module and write it to disk.
--
-- The module is named @<path>.<name>@ and is written to
-- @<base>/<path...>/<name>.hs@. Missing directories are created, and the
-- file path is printed to stdout before writing.
--
-- * 'Enum.Value' enumerations become @newtype <name> = <name> Word32@ with a
--   generated 'Show' instance and one pattern synonym per value.
--
-- * 'Enum.Bit' enumerations become @newtype <name>Bits = <name>Bits Word32@
--   (aliased by @type <name> = <name>Bits@). They get 'Semigroup'/'Monoid'
--   instances that combine flags with @.|.@, and one pattern synonym per
--   flag. The JSON value is used as the flag's bit index.
genEnumModule :: FilePath -> [Text] -> Enum.Enum -> IO ()
genEnumModule base path Enum.Enum{name, type_, values} = do
  createDirectoryIfMissing True dirPath
  putStrLn filePath
  Builder.writeFile filePath body
  where
    filePath = dirPath </> Text.unpack name <.> "hs"
    dirPath = base </> joinPath (map Text.unpack path)
    modulePath = Text.intercalate "." $ path <> [name]

    -- (pattern name, value) pairs ordered by value. 'List.sortOn' is stable,
    -- so names sharing a value stay in the ascending key order produced by
    -- 'Map.assocs'.
    items = map (first pascalCase) . List.sortOn snd $ Map.assocs values
      where
        -- Pattern synonyms must start with an upper-case letter. 'charUtf8'
        -- (rather than 'char7') keeps a non-ASCII initial intact instead of
        -- truncating it to 7 bits.
        pascalCase t = case Text.uncons t of
          Just (c, t') | c /= c' -> Builder.charUtf8 c' <> encodeUtf8Builder t'
            where
              c' = toUpper c
          _ -> encodeUtf8Builder t

    -- Every line, including the last, is terminated with "\n" so the
    -- generated file ends with a newline.
    body :: Builder
    body = foldMap (<> "\n") $ concat
      [ prologue
      , imports
      , [""]
      , typeDecl
      ]
    prologue :: [Builder]
    prologue =
      [ "{-# LANGUAGE GeneralizedNewtypeDeriving #-}"
      , "{-# LANGUAGE DerivingStrategies #-}"
      , "{-# LANGUAGE PatternSynonyms #-}"
      , "{-# LANGUAGE TypeSynonymInstances #-}"
      , ""
      , "module " <> encodeUtf8Builder modulePath <> " where"
      , ""
      ]
    -- Data.Bits is only needed by the Bits newtype deriving and Semigroup.
    imports :: [Builder]
    imports = concat
      [ [ "import Data.Bits (Bits, FiniteBits, (.|.))"
        | type_ == Enum.Bit
        ]
      , [ "import Data.Word (Word32)"
        , "import Foreign.Storable (Storable)"
        ]
      ]
    typeDecl :: [Builder]
    typeDecl = case type_ of
      Enum.Value ->
        [ "newtype " <> valueName <> " = " <> valueName <> " Word32"
        , "  deriving newtype (Eq, Ord, Storable)"
        ] <> instanceShow <> patterns
        where
          valueName = encodeUtf8Builder name
          instanceShow =
            [ ""
            , "instance Show " <> valueName <> " where"
            , "  showsPrec p (" <> valueName <> " v) = case v of"
            ] <> showCase <> showUnknown
          -- One case arm per distinct value, in ascending value order.
          -- Aliases sharing a value would make later arms overlap, so
          -- only the first name in 'items' order is shown for each value.
          showCase = do
            (v, k) <- Map.toAscList firstNameByValue
            pure $ "    " <> Builder.word32Dec v <> " -> showString \"" <> k <> "\""
          -- Keyed on the value. On a collision the earlier entry is kept,
          -- because 'Map.fromListWith' passes (new, old) to the combiner.
          firstNameByValue = Map.fromListWith (\_new old -> old) [(v, k) | (k, v) <- items]
          -- Fallback for values without a pattern. The argument precedence
          -- is 11, the same as in derived Show instances.
          showUnknown =
            [ "    x -> showParen (p > 10) $ showString \"" <> valueName <> " \" . showsPrec 11 x"
            ]
          patterns :: [Builder]
          patterns = do
            (k, v) <- items
            [ ""
              , "pattern " <> k <> " :: " <> valueName
              , "pattern " <> k <> " = " <> valueName <> " " <> Builder.word32Dec v
              ]
      Enum.Bit ->
        [ "type " <> flagName <> " = " <> bitsName
        , ""
        , "newtype " <> bitsName <> " = " <> bitsName <> " Word32"
        , "  deriving newtype (Eq, Ord, Storable, Bits, FiniteBits)"
        ] <> instanceShow <> instanceSemigroup <> instanceMonoid <> patterns
        where
          flagName = encodeUtf8Builder name
          bitsName = encodeUtf8Builder name <> "Bits"
          -- NOTE(review): no Show instance is generated for bit types yet,
          -- so the generated @<name>Bits@ type has no Show.
          instanceShow =
            []
          instanceSemigroup =
            [ ""
            , "instance Semigroup " <> flagName <> " where"
            , "  (" <> bitsName <> " a) <> (" <> bitsName <> " b) = " <> bitsName <> " (a .|. b)"
            ]
          instanceMonoid =
            [ ""
            , "instance Monoid " <> flagName <> " where"
            , "  mempty = " <> bitsName <> " 0"
            ]
          -- The JSON value is a bit index: the pattern is the mask with
          -- that single bit set, as an 8-digit hex literal.
          patterns :: [Builder]
          patterns = do
            (k, v) <- items
            [ ""
              , "pattern " <> k <> " :: " <> bitsName
              , "pattern " <> k <> " = " <> bitsName <> " 0x" <> Builder.word32HexFixed (bit $ fromIntegral v)
              ]

-- | Generator configuration, filled in from the command line.
data Options = Options
  { input :: FilePath
    -- ^ Path of the @spirv.json@ file to decode.
  , output :: FilePath
    -- ^ Root directory under which generated module files are written.
  , modulePrefix :: [Text]
    -- ^ Module-name segments prepended to every generated module. They are
    -- also used as the directory path below 'output'.
  }

-- | Fallback values for flags not given on the command line. Both paths are
-- relative to the working directory, and modules go under @Data.SpirV.Enum@.
defaultOptions :: Options
defaultOptions =
  Options
    { modulePrefix = ["Data", "SpirV", "Enum"]
    , output = "spirv-enum/src"
    , input = "SPIRV-Headers/include/spirv/unified1/spirv.json"
    }

-- | Command-line parser for 'Options'. Each flag falls back to the
-- matching field of 'defaultOptions', which is shown in the help text.
optionsP :: Parser Options
optionsP = Options <$> inputP <*> outputP <*> prefixP
  where
    inputP = strOption $
      long "input"
        <> short 'i'
        <> metavar "FILE"
        <> help "Path to spirv.json file."
        <> value (input defaultOptions)
        <> showDefault
    outputP = strOption $
      long "output"
        <> short 'o'
        <> metavar "DIR"
        <> help "Root path for generated modules."
        <> value (output defaultOptions)
        <> showDefault
    -- Parsed and validated by 'modulePrefixP'; the default is shown dotted.
    prefixP = option (eitherReader (modulePrefixP . Text.pack)) $
      long "module-prefix"
        <> short 'm'
        <> metavar "HASKELL.MODULE.PATH"
        <> help "Name and directory prefix for generated modules."
        <> value (modulePrefix defaultOptions)
        <> showDefaultWith (Text.unpack . Text.intercalate ".")

-- | Parse a dotted Haskell module prefix such as @Data.SpirV.Enum@ into
-- its segments.
--
-- The following are rejected:
--
-- * empty input;
-- * slashes;
-- * any whitespace;
-- * empty segments, including those from a leading or trailing dot;
-- * segments that are not valid module-name components. A valid component
--   is an upper-case letter followed by letters, digits, @_@ or @'@.
modulePrefixP :: Text -> Either String [Text]
modulePrefixP = \case
  "" -> Left "Empty path"
  s | ".." `Text.isInfixOf` s -> Left "Empty module segment"
  s | ('/' `Text.elem` s) || ('\\' `Text.elem` s) -> Left "Module paths are dotted, not slashed."
  -- Check every character: a words-based check misses leading/trailing
  -- blanks, which would otherwise end up inside a segment.
  s | Text.any isSpace s -> Left "Module paths can't contain spaces."
  s -> traverse segmentP $ Text.split (== '.') s
  where
    segmentP seg = case Text.uncons seg of
      -- A leading or trailing dot yields an empty segment.
      Nothing -> Left "Empty module segment"
      Just (c, rest)
        | isUpper c && Text.all isSegmentChar rest -> Right seg
        | otherwise -> Left $ "Invalid module segment: " <> Text.unpack seg
    isSegmentChar c = isAlphaNum c || c == '_' || c == '\''