{-# LANGUAGE CPP #-}

module Database.PostgreSQL.PQTypes.Checks.Util
  ( ValidationResult
  , validationError
  , validationInfo
  , mapValidationResult
  , validationErrorsToInfos
  , resultCheck
  , resultHasErrors
  , topMessage
  , tblNameText
  , tblNameString
  , checkEquality
  , checkNames
  , checkPKPresence
  , objectHasLess
  , objectHasMore
  , arrListTable
  , checkOverlappingIndexesQuery
  ) where

import Control.Exception (ErrorCall (..))
import Control.Monad.Catch
#if !MIN_VERSION_base(4,11,0)
import Data.Monoid
#endif
import Data.List qualified as L
import Data.Monoid.Utils
import Data.Semigroup qualified as SG
import Data.Text (Text)
import Data.Text qualified as T
import Log
import TextShow

import Database.PostgreSQL.PQTypes
import Database.PostgreSQL.PQTypes.Model

-- | Outcome of a validation step: a (potentially empty) list of
-- informational messages together with a (potentially empty) list of
-- error messages. Results are combined with the 'Semigroup' / 'Monoid'
-- instances, which concatenate the two lists field by field.
data ValidationResult = ValidationResult
  { vrInfos :: [Text]
  -- ^ Informational messages.
  , vrErrors :: [Text]
  -- ^ Error messages.
  }
  deriving (Eq, Show)

-- | A 'ValidationResult' holding exactly one error message and no infos.
validationError :: Text -> ValidationResult
validationError err = ValidationResult {vrInfos = [], vrErrors = [err]}

-- | A 'ValidationResult' holding exactly one info message and no errors.
validationInfo :: Text -> ValidationResult
validationInfo msg = ValidationResult {vrInfos = [msg], vrErrors = []}

-- | Downgrade all error messages in a 'ValidationResult' to info messages.
-- The former errors are appended after the existing infos, and the
-- resulting value carries no errors.
validationErrorsToInfos :: ValidationResult -> ValidationResult
validationErrorsToInfos (ValidationResult infos errs) =
  ValidationResult {vrInfos = infos ++ errs, vrErrors = []}

-- | Transform the two message lists of a 'ValidationResult' independently:
-- the first function is applied to the infos, the second to the errors.
mapValidationResult
  :: ([Text] -> [Text]) -> ([Text] -> [Text]) -> ValidationResult -> ValidationResult
mapValidationResult mapInfos mapErrs (ValidationResult infos errs) =
  ValidationResult (mapInfos infos) (mapErrs errs)

-- | Field-wise concatenation: infos with infos and errors with errors,
-- the left operand's messages first.
instance SG.Semigroup ValidationResult where
  (ValidationResult lInfos lErrs) <> (ValidationResult rInfos rErrs) =
    ValidationResult
      { vrInfos = lInfos ++ rInfos
      , vrErrors = lErrs ++ rErrs
      }

-- | The identity carries neither infos nor errors; 'mappend' is the
-- 'Semigroup' operation (kept in canonical form).
instance Monoid ValidationResult where
  mempty = ValidationResult {vrInfos = [], vrErrors = []}
  mappend = (SG.<>)

-- | Prepend a header line of the form
-- @There are problems with the <objtype> '<objname>'@ to the errors of a
-- 'ValidationResult'. A result with no errors is returned unchanged, and
-- the infos are never touched.
topMessage :: Text -> Text -> ValidationResult -> ValidationResult
topMessage objtype objname vr
  | null (vrErrors vr) = vr
  | otherwise = vr {vrErrors = header : vrErrors vr}
  where
    header =
      T.unwords
        [ "There are problems with the"
        , objtype
        , "'" <> objname <> "'"
        ]

-- | Whether the 'ValidationResult' contains at least one error message.
-- Info messages do not count.
resultHasErrors :: ValidationResult -> Bool
resultHasErrors = not . null . vrErrors

-- | Log all messages in a 'ValidationResult', and fail if any of them
-- were errors.
--
-- Infos are logged with 'logInfo_'. If there are errors, each one is
-- logged with 'logAttention_'. An 'ErrorCall' exception with the message
-- @"resultCheck: validation failed"@ is then thrown via 'throwM'. The
-- failure therefore goes through the monad's 'MonadThrow' instance, rather
-- than being raised as an imprecise exception by 'error'.
resultCheck
  :: (MonadLog m, MonadThrow m)
  => ValidationResult
  -> m ()
resultCheck ValidationResult {..} = do
  mapM_ logInfo_ vrInfos
  case vrErrors of
    [] -> pure ()
    msgs -> do
      mapM_ logAttention_ msgs
      -- Same exception type as the former 'error' call, so existing
      -- handlers for 'ErrorCall' keep working.
      throwM (ErrorCall "resultCheck: validation failed")

----------------------------------------

-- | The name of a table ('tblName') rendered as 'Text'.
tblNameText :: Table -> Text
tblNameText table = unRawSQL (tblName table)

-- | The name of a table rendered as a 'String' (via 'tblNameText').
tblNameString :: Table -> String
tblNameString table = T.unpack (tblNameText table)

-- | Compare the properties of a table's definition (@defs@) with those
-- found for the table (@props@) using list difference ('L.\\', which
-- removes one occurrence per element).
--
-- Returns 'mempty' when neither side has elements missing from the other.
-- Otherwise it returns a single error that lists, as bullet points, the
-- elements found only on the table side and only on the definition side.
-- @pname@ names the kind of property and is spliced verbatim into the
-- message.
checkEquality :: (Eq t, Show t) => Text -> [t] -> [t] -> ValidationResult
checkEquality pname defs props
  | null defOnly && null dbOnly = mempty
  | otherwise =
      validationError $
        mconcat
          [ "Table and its definition have diverged and have "
          , showt (length dbOnly)
          , " and "
          , showt (length defOnly)
          , " different "
          , pname
          , " each, respectively:\n"
          , "  ● table:"
          , bulleted dbOnly
          , "\n  ● definition:"
          , bulleted defOnly
          ]
  where
    -- Elements present only in the definition / only on the table side.
    defOnly = defs L.\\ props
    dbOnly = props L.\\ defs

    -- One indented bullet line per element, each preceded by a newline.
    bulleted :: Show a => [a] -> Text
    bulleted = foldMap (\x -> "\n    ○ " <> T.pack (show x))

-- | Check that every property carries the name it is expected to have.
--
-- Each pair holds a property and the name it was given. The expected name
-- is computed with @prop_name@. Every mismatching pair contributes one
-- error; matching pairs contribute nothing.
checkNames :: Show t => (t -> RawSQL ()) -> [(t, RawSQL ())] -> ValidationResult
checkNames prop_name = foldMap check
  where
    check (prop, given)
      | expected == given = mempty
      | otherwise =
          validationError $
            T.concat
              [ "Property "
              , T.pack (show prop)
              , " has invalid name (expected: "
              , unRawSQL expected
              , ", given: "
              , unRawSQL given
              , ")."
              ]
      where
        expected = prop_name prop

-- | Check presence of a primary key on the named table. All cases are
-- covered, so this can be used standalone. Note, however, that the cases
-- where the table's source definition and the table in the database differ
-- in this respect are also covered by @checkEquality@.
--
-- An error is reported if the source definition lacks a primary key
-- ("no source definition"), if the table in the database lacks one
-- ("no table definition"), or both. When both are present the result is
-- 'mempty'.
checkPKPresence
  :: RawSQL ()
  -- ^ The name of the table to check for presence of primary key
  -> Maybe PrimaryKey
  -- ^ A possible primary key gotten from the table data structure
  -> Maybe (PrimaryKey, RawSQL ())
  -- ^ A possible primary key as retrieved from database along
  -- with its name
  -> ValidationResult
checkPKPresence tableName mdef mpk =
  case (mdef, mpk) of
    (Nothing, Nothing) -> valRes [noSrc, noTbl]
    (Nothing, Just _) -> valRes [noSrc]
    (Just _, Nothing) -> valRes [noTbl]
    _ -> mempty
  where
    noSrc :: Text
    noSrc = "no source definition"

    noTbl :: Text
    noTbl = "no table definition"

    -- Produces e.g. "Table foo has no primary key defined (no source definition)".
    -- The only space before the parenthesis comes from " (", so the message
    -- no longer contains a double space there.
    valRes :: [Text] -> ValidationResult
    valRes msgs =
      validationError . mconcat $
        [ "Table "
        , unRawSQL tableName
        , " has no primary key defined"
        , " (" <> mintercalate ", " msgs <> ")"
        ]

-- | Message for an object that, in the database, has fewer properties than
-- its definition. @otype@ names the object and @ptype@ the kind of property.
-- @missing@ is rendered with 'show', giving e.g.
-- @<otype> in the database has *less* <ptype> than its definition (missing: <missing>)@.
objectHasLess :: Show t => Text -> Text -> t -> Text
objectHasLess otype ptype missing =
  T.unwords
    [ otype
    , "in the database has *less*"
    , ptype
    , "than its definition (missing:"
    , T.pack (show missing)
    ]
    <> ")"

-- | Message for an object that, in the database, has more properties than
-- its definition. @otype@ names the object and @ptype@ the kind of property.
-- @extra@ is rendered with 'show', giving e.g.
-- @<otype> in the database has *more* <ptype> than its definition (extra: <extra>)@.
objectHasMore :: Show t => Text -> Text -> t -> Text
objectHasMore otype ptype extra =
  T.unwords
    [ otype
    , "in the database has *more*"
    , ptype
    , "than its definition (extra:"
    , T.pack (show extra)
    ]
    <> ")"

-- | Prefix for messages concerning one table: @" -> <tableName>: "@.
arrListTable :: RawSQL () -> Text
arrListTable tableName = T.concat [" -> ", unRawSQL tableName, ": "]

-- | Build a query that lists pairs of indexes on the given table where one
-- index (column @contained@) is made redundant by another (column
-- @contains@). Both columns hold index definitions as returned by
-- @pg_get_indexdef@.
--
-- An index is reported against another index on the same table when all
-- of the following hold:
--
--   * its column list equals the other's, or is a prefix of it;
--   * its name does not start with @local_@ (case-insensitive);
--   * both have the same @WHERE@ predicate, or neither has one; and
--   * either it is not unique, or both are unique over identical columns.
--
-- Expression indexes (@indexprs IS NOT NULL@) are ignored. The table name
-- is spliced into the query verbatim via 'raw', so it must come from
-- trusted code rather than user input.
checkOverlappingIndexesQuery :: RawSQL () -> SQL
checkOverlappingIndexesQuery tableName =
  smconcat
    [ "WITH"
    , -- get predicates (WHERE clause) definition in text format (ugly but the parsed version
      -- can differ even if the predicate is the same), ignore functional indexes at the same time
      -- as that would make this query very ugly
      "     indexdata1 AS (SELECT *"
    , "                         , ((regexp_match(pg_get_indexdef(indexrelid)"
    , "                                        , 'WHERE (.*)$')))[1] AS preddef"
    , "                    FROM pg_index"
    , "                    WHERE indexprs IS NULL"
    , "                    AND indrelid = '" <> raw tableName <> "'::regclass)"
    , -- add the rest of metadata and do the join
      "   , indexdata2 AS (SELECT t1.*"
    , "                         , pg_get_indexdef(t1.indexrelid) AS contained"
    , "                         , pg_get_indexdef(t2.indexrelid) AS contains"
    , "                         , array_to_string(t1.indkey, '+') AS colindex"
    , "                         , array_to_string(t2.indkey, '+') AS colotherindex"
    , "                         , t2.indexrelid AS other_index"
    , "                         , t2.indisunique AS other_indisunique"
    , "                         , t2.preddef AS other_preddef"
    , -- cross join all indexes on the same table to try all combination (except oneself)
      "                    FROM indexdata1 AS t1"
    , "                         INNER JOIN indexdata1 AS t2 ON t1.indrelid = t2.indrelid"
    , "                                                    AND t1.indexrelid <> t2.indexrelid)"
    , "  SELECT contained"
    , "       , contains"
    , "  FROM indexdata2"
    , " JOIN pg_class c ON (c.oid = indexdata2.indexrelid)"
    , -- The indexes are the same or the "other" is larger than us
      "  WHERE (colotherindex = colindex"
    , "      OR colotherindex LIKE colindex || '+%')"
    , -- and this is not a local index. '_' is a single-character wildcard in
      -- (I)LIKE, so it is escaped ('!' as escape character) to match only a
      -- literal underscore; otherwise names such as "localities_..." would
      -- also be skipped.
      "    AND NOT c.relname ILIKE 'local!_%' ESCAPE '!'"
    , -- and we have the same predicate
      "    AND other_preddef IS NOT DISTINCT FROM preddef"
    , -- and either we are not unique (so the other index, whose columns equal
      -- or extend ours, covers us), or both of us are unique over exactly the
      -- same columns. The alternative is wrapped in its own parentheses
      -- because AND binds tighter than OR in SQL. Without them, the unique
      -- branch would bypass the local-index and same-predicate checks above.
      "    AND ((NOT indisunique)"
    , "         OR (    indisunique"
    , "             AND other_indisunique"
    , "             AND colindex = colotherindex));"
    ]