module Camfort.Specification.Stencils.Model where
import Camfort.Specification.Stencils.Syntax
import Data.Set hiding (map,foldl',(\\))
import qualified Data.Set as Set
import Data.List
import qualified Data.List as DL
import qualified Data.Map as DM
import Debug.Trace
-- | The model of a spatial result, built with the result's own
-- dimensionality as the global dimensionality.
model :: Result Spatial -> Result (Multiset [Int])
model spec = mkModel spec
  where
    ?globalDimensionality = dimensionality spec
-- | Check a declared specification against an inferred one by comparing
-- their models. Both models are built at the larger of the two
-- dimensionalities, so their indices have the same length.
--
-- A declared lower bound holds when it is contained in what was inferred
-- (the inferred lower bound or exact model). A declared upper bound holds
-- when it contains what was inferred (the inferred upper bound or exact
-- model). Any other combination requires the two models to be equal.
--
-- Containment is 'DM.isSubmapOf': every index of the smaller model must
-- occur in the larger one with the same linearity flag. ('<=' on 'DM.Map'
-- is a lexicographic order, not containment.)
eqByModel :: Specification -> Specification -> Bool
eqByModel infered declared =
    let d1 = dimensionality infered
        d2 = dimensionality declared
    in let ?globalDimensionality = d1 `max` d2
       in let modelInf = mkModel infered
              modelDec = mkModel declared
          in case (modelInf, modelDec) of
               -- Declared lower bound must lie within the inferred lower bound.
               (Bound (Just mdlLI) Nothing, Bound (Just mdlLD) _)
                        -> mdlLD `DM.isSubmapOf` mdlLI
               -- Inferred upper bound must lie within the declared upper bound.
               (Bound Nothing (Just mdlUI), Bound _ (Just mdlUD))
                        -> mdlUI `DM.isSubmapOf` mdlUD
               (Bound (Just mdlLI) (Just _), Bound (Just mdlLD) Nothing)
                        -> mdlLD `DM.isSubmapOf` mdlLI
               (Bound (Just _ ) (Just mdlUI), Bound Nothing (Just mdlUD))
                        -> mdlUI `DM.isSubmapOf` mdlUD
               -- Both sides bounded above and below: check both bounds.
               (Bound (Just mdlLI) (Just mdlUI), Bound (Just mdlLD) (Just mdlUD))
                        -> (mdlLD `DM.isSubmapOf` mdlLI)
                           && (mdlUI `DM.isSubmapOf` mdlUD)
               -- An exact inferred model must respect the declared bounds.
               (Exact s, Bound Nothing (Just mdlUD))
                        -> s `DM.isSubmapOf` mdlUD
               (Exact s, Bound (Just mdlLD) Nothing)
                        -> mdlLD `DM.isSubmapOf` s
               (Exact s, Bound (Just mdlLD) (Just mdlUD))
                        -> (mdlLD `DM.isSubmapOf` s) && (s `DM.isSubmapOf` mdlUD)
               -- Every other combination requires identical models.
               (x, y) -> x == y
-- | Interpretation of (parts of) stencil specifications as a semantic
-- 'Domain', e.g. a set or multiset of relative indices (@[Int]@).
class Model spec where
   -- | The semantic domain a specification is interpreted into.
   type Domain spec

   -- | Build the model of a specification. The implicit parameter
   -- @?globalDimensionality@ determines the length of the generated
   -- indices, so that models of different parts are built consistently.
   mkModel :: (?globalDimensionality :: Int) => spec -> Domain spec

   -- | The largest dimension mentioned by the specification. The default
   -- is the maximum of 'dimensions', or 0 when no dimension is mentioned
   -- (plain 'maximum' would fail on an empty list).
   dimensionality :: spec -> Int
   dimensionality = maximum1 . dimensions

   -- | Every dimension mentioned by the specification.
   dimensions :: spec -> [Int]
-- | A multiset over @a@, kept as a map from each element to a 'Bool' flag.
-- The 'Spatial' instance of 'Model' sets the flag to 'False' for linear
-- specs and 'True' for non-linear ones. Presumably 'False' means "occurs
-- once" and 'True' means "occurs more than once" (NOTE(review): confirm).
type Multiset a = DM.Map a Bool
-- | Build a multiset from a list. Each distinct element maps to 'False'
-- when it occurs exactly once and to 'True' when it occurs more than once.
-- These are the same flags the spatial model uses for linear ('False') and
-- non-linear ('True') access.
mkMultiset :: Ord a => [a] -> DM.Map a Bool
mkMultiset xs =
  -- First occurrence is inserted as False; any repeat collapses to True.
  DM.fromListWith (\_ _ -> True) [ (x, False) | x <- xs ]
-- | Only spatial specifications have a model. Any other specification is
-- rejected by 'mkModel' and reports dimensionality 0 and dimensions @[0]@.
instance Model Specification where
   type Domain Specification = Result (Multiset [Int])
   mkModel spec = case spec of
     Specification (Left spatial) -> mkModel spatial
     _                            -> error "Only spatial specs are modelled"
   dimensionality spec = case spec of
     Specification (Left spatial) -> dimensionality spatial
     _                            -> 0
   dimensions spec = case spec of
     Specification (Left spatial) -> dimensions spatial
     _                            -> [0]
-- | A 'Result' is modelled component-wise. An exact result reports the
-- dimensions of its spec; a bounded result reports those of both bounds.
instance Model (Result Spatial) where
  type Domain (Result Spatial) = Result (Multiset [Int])
  mkModel result = fmap mkModel result
  dimensionality result = case result of
    Exact spec        -> dimensionality spec
    Bound lower upper -> max (dimensionality lower) (dimensionality upper)
  dimensions result = case result of
    Exact spec        -> dimensions spec
    Bound lower upper -> dimensions lower ++ dimensions upper
-- | Models lift through 'Maybe': a missing spec has a missing model and
-- reports the single dimension 0.
instance Model a => Model (Maybe a) where
  type Domain (Maybe a) = Maybe (Domain a)
  mkModel    = fmap mkModel
  dimensions = maybe [0] dimensions
-- | A spatial spec is modelled by the indices of its region. Each index is
-- flagged 'False' for a 'Linear' spec and 'True' for a 'NonLinear' one.
instance Model Spatial where
    type Domain Spatial = Multiset [Int]
    mkModel (Spatial linearity region) =
        case linearity of
          Linear    -> flagEvery False
          NonLinear -> flagEvery True
      where
        flagEvery flag = DM.fromSet (const flag) (mkModel region)
    dimensionality (Spatial _ region) = dimensionality region
    dimensions     (Spatial _ region) = dimensions region
-- | A sum of regions is modelled by the union of its summands' models.
instance Model RegionSum where
   type Domain RegionSum = Set [Int]
   mkModel (Sum regions) = Set.unions [mkModel r | r <- regions]
   dimensionality (Sum regions) =
     maximum1 [dimensionality r | r <- regions]
   dimensions (Sum regions) = regions >>= dimensions
-- | The model of a single region: the set of relative indices it covers.
-- Each index is built by 'mkSingleEntryNeg', so it carries the offset in
-- the region's own dimension and 'absoluteRep' in every other position.
--
-- * @Forward dep dim reflx@: offsets @1 .. dep@, or @0 .. dep@ when @reflx@.
-- * @Backward dep dim reflx@: offsets @-dep .. -1@, or @-dep .. 0@ when @reflx@.
-- * @Centered dep dim reflx@: offsets @-dep .. dep@, with @0@ removed unless
--   @reflx@.
instance Model Region where
   type Domain Region = Set [Int]

   mkModel (Forward dep dim reflx) = Set.fromList
     [mkSingleEntryNeg i dim ?globalDimensionality | i <- [i0..dep]]
       where i0 = if reflx then 0 else 1

   -- Backward offsets are non-positive and mirror Forward: the range runs
   -- from -dep up to 0 (with reflx) or -1.
   mkModel (Backward dep dim reflx) = Set.fromList
     [mkSingleEntryNeg i dim ?globalDimensionality | i <- [(-dep)..i0]]
       where i0 = if reflx then 0 else -1

   -- Centered offsets are symmetric around 0; the origin is dropped unless
   -- reflx is set.
   mkModel (Centered dep dim reflx) = Set.fromList
     [mkSingleEntryNeg i dim ?globalDimensionality | i <- [(-dep)..dep] \\ i0]
       where i0 = if reflx then [] else [0]

   dimensionality (Forward  _ d _) = d
   dimensionality (Backward _ d _) = d
   dimensionality (Centered _ d _) = d

   dimensions (Forward _ d _)  = [d]
   dimensions (Backward _ d _) = [d]
   dimensions (Centered _ d _) = [d]
-- | @mkSingleEntryNeg i d ds@ builds a relative index holding offset @i@ in
-- the (1-indexed) dimension @d@ and 'absoluteRep' in every other position.
-- The result has length @ds@, or @d@ if @d > ds@. A dimension below 1 is
-- rejected with an error; previously only 0 was caught, and negative
-- dimensions recursed forever.
mkSingleEntryNeg :: Int -> Int -> Int -> [Int]
mkSingleEntryNeg _ d _
  | d < 1 = error "Dimensions are 1-indexed"
mkSingleEntryNeg i 1 ds = i : replicate (ds - 1) absoluteRep
mkSingleEntryNeg i d ds = absoluteRep : mkSingleEntryNeg i (d - 1) (ds - 1)
-- | A product of regions is modelled by combining its factors' index sets
-- pointwise ('cprodVs'). Only the combined indices that hold a real offset,
-- rather than 'absoluteRep', in every dimension the product mentions are
-- kept. Empty and singleton products are handled directly.
instance Model RegionProd where
   type Domain RegionProd = Set [Int]
   mkModel (Product regions) =
       case regions of
         []       -> Set.empty
         [region] -> mkModel region
         _        -> Set.fromList (DL.filter coversAll combined)
     where
       combined  = cprodVs [Set.toList (mkModel r) | r <- regions]
       mentioned = dimensions (Product regions)
       coversAll idx = all hasOffset (zip [(1 :: Int) ..] idx)
       hasOffset (pos, entry) =
         not (pos `elem` mentioned && entry == absoluteRep)
   dimensionality (Product regions) =
      maximum1 [dimensionality r | r <- regions]
   dimensions (Product regions) =
      nub (regions >>= dimensions)
-- | @tensor n s t@ combines the index lists @s@ and @t@ pointwise with
-- 'cprodV'. It keeps only the combinations whose first @n@ positions all
-- hold a real offset rather than 'absoluteRep'.
tensor :: Int -> [[Int]] -> [[Int]] -> Set [Int]
tensor n s t = Set.fromList (DL.filter keep (cprodV s t))
  where
    -- Positions are 1-indexed, so "position in 1..n" is exactly the first
    -- n entries of the index. This avoids an O(n) membership test per entry.
    keep idx = all (/= absoluteRep) (DL.take n idx)
-- | N-ary version of 'cprodV': pointwise-combine a list of index lists.
-- The empty product is @[[]]@ (a single empty index), which is also the
-- right identity of 'cprodV'. This therefore agrees with @foldr1 cprodV@ on
-- non-empty input but no longer fails on @[]@.
cprodVs :: [[[Int]]] -> [[Int]]
cprodVs = foldr cprodV [[]]
-- | Pointwise combination of two lists of indices. Every pair drawn from
-- @xss@ and @yss@ (in that order) contributes all indices 'pairwisePerm'
-- produces for it.
cprodV :: [[Int]] -> [[Int]] -> [[Int]]
cprodV xss yss =
  [ combined | xs <- xss, ys <- yss, combined <- pairwisePerm xs ys ]
-- | Every list obtained by picking, position by position, either the
-- element of @x@ or the element of @y@. Where only one list has an element
-- at a position, that element is used.
pairwisePerm :: [Int] -> [Int] -> [[Int]]
pairwisePerm x y = sequenceA columns
  where
    columns = transpose [x, y]
-- | 'maximum' extended to the empty list, for which it returns 0.
maximum1 :: (Ord a, Num a) => [a] -> a
maximum1 [] = 0
maximum1 xs = maximum xs