{-# LANGUAGE CPP #-}

module Test.Sandwich.Contexts.UnixSocketPath (
  withUnixSocketDirectory
  , maxUnixSocketLength
  ) where

import Control.Monad.IO.Unlift
import Relude
import System.IO.Error (IOError)
import Test.Sandwich.Expectations (expectationFailure)
import UnliftIO.Directory
import UnliftIO.Exception
import UnliftIO.Temporary


-- | The longest allowed path for a Unix socket on the current system,
-- excluding the terminating NUL of @sun_path@.
--
-- On Windows no limit is enforced here ('maxBound').
--
-- The macros tested below are GHC's standard @<os>_HOST_OS@ CPP macros. The
-- spelling is case-sensitive: @darwin_HOST_OS@, not @darwin_host_os@; an
-- undefined macro silently evaluates to 0 and would select the Linux value.
maxUnixSocketLength :: Int
#if defined(mingw32_HOST_OS)
maxUnixSocketLength = maxBound
#elif defined(darwin_HOST_OS)
maxUnixSocketLength = 103 -- macOS: 104 with null terminator
#else
maxUnixSocketLength = 107 -- Linux: 108 with null terminator
#endif

-- | Create a temporary directory in which a Unix socket can be safely created,
-- bearing in mind the longest allowed Unix socket path on the system.
--
-- A directory is first created under the system temporary directory (see
-- 'withSystemTempDirectory'). If @length dir + headroom@ would exceed
-- 'maxUnixSocketLength', that directory is cleaned up. The callback then runs
-- in a fresh directory under @\/tmp@ instead, and the test fails if no short
-- enough path can be made there either.
--
-- Because the check is @length dir + headroom <= maxUnixSocketLength@, the
-- headroom should account for the path separator as well as the socket's
-- file name.
--
-- NOTE(review): 'length' counts 'Char's, not encoded bytes. A temporary path
-- containing non-ASCII characters could therefore still exceed the OS limit.
withUnixSocketDirectory
  :: (MonadUnliftIO m)
  => String
  -- ^ Name template, as passed to 'withSystemTempDirectory'
  -> Int
  -- ^ Amount of headroom to leave for a file name in this directory,
  -- before hitting the 'maxUnixSocketLength'
  -> (FilePath -> m a)
  -- ^ Callback, given the path of the created directory
  -> m a
withUnixSocketDirectory nameTemplate headroom action = do
  -- Only run the callback inside the system temp dir when the path is short
  -- enough. Otherwise return 'Nothing', so that directory is removed before
  -- we fall back rather than lingering on disk for the whole callback.
  result <- withSystemTempDirectory nameTemplate $ \dir ->
    if length dir + headroom <= maxUnixSocketLength
      then Just <$> action dir
      else pure Nothing
  maybe (withShortTempDir nameTemplate headroom action) pure result

-- | Fallback used when the system temporary directory yields a path that is
-- too long: create the temporary directory directly under @\/tmp@, presumably
-- a much shorter prefix.
--
-- Fails the current test via 'expectationFailure' in any of these cases:
--
-- * @\/tmp@ does not exist;
-- * @\/tmp@ is not writable;
-- * the path created under @\/tmp@ plus @headroom@ still exceeds
--   'maxUnixSocketLength'.
withShortTempDir :: (
  MonadUnliftIO m
  )
  => String
  -- ^ Name template, as passed to 'withTempDirectory'
  -> Int
  -- ^ Headroom to leave for a file name in the created directory
  -> (FilePath -> m a)
  -- ^ Callback, given the path of the created directory
  -> m a
withShortTempDir nameTemplate headroom action =
  doesDirectoryExist shortRoot >>= \case
    False -> doFail
    True -> isDirectoryWritable shortRoot >>= \case
      False -> doFail
      True -> withTempDirectory shortRoot nameTemplate $ \dir ->
        if length dir + headroom <= maxUnixSocketLength
          then action dir
          else doFail
  where
    shortRoot :: FilePath
    shortRoot = "/tmp"

    -- Left without a signature so it shares the outer monad @m@ without
    -- requiring ScopedTypeVariables on the enclosing signature.
    doFail = expectationFailure "Couldn't create a short enough Unix socket path on this system."

-- | Whether @dir@ is reported writable by 'getPermissions'.
--
-- Any 'IOError' raised while querying permissions (for example, when the
-- path does not exist) is treated as "not writable" rather than propagated.
-- Only 'IOError' is caught; other exceptions still escape.
isDirectoryWritable :: MonadUnliftIO m => FilePath -> m Bool
isDirectoryWritable dir =
  either (\(_ :: IOError) -> False) writable <$> try (getPermissions dir)