-- | This module defines the state monad used in
-- 'Quipper.Utils.Template.Lifting' for Template Haskell 
-- term manipulation.
module Quipper.Utils.Template.LiftQ where

import qualified Language.Haskell.TH as TH
import qualified Data.Map as Map
import qualified Data.Set as Set
import qualified Data.List as List

import Language.Haskell.TH (Name)
import Control.Monad.State
import Data.Map (Map)
import Data.Set (Set)

import Control.Applicative (Applicative(..))
import Control.Monad (liftM, ap)

import qualified Quipper.Utils.Template.ErrorMsgQ as Err
import Quipper.Utils.Template.ErrorMsgQ (ErrMsgQ)

-- | The state threaded through the 'LiftQ' monad.
data LiftState = LiftState
  { boundVar :: Map Name Int
    -- ^ Binding counter for every name currently in scope.
    -- 'addToBoundVar' stores 0 for the first binding of a name and
    -- adds 1 for each further nested binding of the same name.
  , prefix :: Maybe String
    -- ^ The template prefix, or 'Nothing' while it has not been set.
  , monadName :: Maybe String
    -- ^ The name of the monad, or 'Nothing' while it has not been set.
  }

-- | The initial state: no bound variables, no template prefix and no
-- monad name.
emptyLiftState :: LiftState
emptyLiftState = LiftState Map.empty Nothing Nothing
  -- Positional arguments, in field order: boundVar, prefix, monadName.

-- | Shortcut to @StateT LiftState ErrMsgQ@: a state transformer that
-- threads a 'LiftState' through computations in the 'ErrMsgQ' monad.
-- This is the type wrapped by 'LiftQ'.
type LiftQState = StateT LiftState ErrMsgQ

-- | The monad: a thin wrapper around 'LiftQState'.
--
-- Declared as a @newtype@ because it has exactly one constructor with
-- exactly one field; the wrapper then has no runtime representation of
-- its own, while the constructor name and its type are unchanged.
newtype LiftQ a = LiftQ (LiftQState a)

-- The three instances below are written in canonical form: 'pure' is the
-- primitive and 'return' is defined as 'pure' (the reverse arrangement
-- triggers GHC's @-Wnoncanonical-monad-instances@ warning). 'fmap' and
-- '<*>' are derived from the monadic bind through 'liftM' and 'ap'.

instance Functor LiftQ where
  fmap = liftM

instance Applicative LiftQ where
  -- Wrap a pure value using the underlying 'LiftQState' monad.
  pure x = LiftQ $ return x
  (<*>) = ap

instance Monad LiftQ where
  return = pure
  -- Run the wrapped computation, feed its result to the continuation,
  -- then unwrap and run the computation the continuation produced.
  (>>=) (LiftQ x) f = LiftQ $ do
           x' <- x
           let (LiftQ y) = f x'
           y

-- | Read the current 'LiftState'. The state itself is left unchanged.
getState :: LiftQ LiftState
getState = LiftQ get

-- | Replace the current 'LiftState' with the given one, discarding the
-- previous state.
setState :: LiftState -> LiftQ ()
setState = LiftQ . put

-- * Various functions to go back and forth between monads.

-- | From 'ErrMsgQ' to 'LiftQ': run the given computation in the
-- underlying monad and return its result, leaving the 'LiftState'
-- untouched.
embedErrMsgQ :: ErrMsgQ a -> LiftQ a
embedErrMsgQ = LiftQ . lift

-- | From 'TH.Q' to 'LiftQ': the computation is first turned into an
-- 'ErrMsgQ' with 'Err.embedQ' and then lifted through the state layer,
-- leaving the 'LiftState' untouched.
embedQ :: TH.Q a -> LiftQ a
embedQ = LiftQ . lift . Err.embedQ

-- | Get 'TH.Q' out of 'LiftQ'. The computation is started from
-- 'emptyLiftState' and the result is handed to 'Err.extractQ' together
-- with the given string (presumably a label for error messages; see
-- 'Err.extractQ' for its exact role).
extractQ :: String -> LiftQ a -> TH.Q a
extractQ label (LiftQ stateful) = Err.extractQ label fromEmpty
  where
    -- Only the result is kept; the final state is dropped.
    fromEmpty = evalStateT stateful emptyLiftState

-- | Raise the given error message: 'Err.errorMsg' lifted into 'LiftQ'.
errorMsg :: String -> LiftQ a
errorMsg = embedErrMsgQ . Err.errorMsg


-- * Working with variable names.

-- | Record one more binding of a variable name. A name that is not yet
-- in the table is entered with counter 0; a name already present has its
-- counter increased by one.
addToBoundVar :: Name -> LiftQ ()
addToBoundVar n = do
  st <- getState
  setState st { boundVar = Map.alter bump n (boundVar st) }
  where
    bump :: Maybe Int -> Maybe Int
    bump Nothing = Just 0
    bump (Just depth) = Just (1 + depth)

-- | Undo one binding of a variable name. A counter of 0 means this was
-- the last binding, so the name is removed from the table; any other
-- counter is decreased by one. Raises an error if the name is not in the
-- table at all.
removeFromBoundVar :: Name -> LiftQ ()
removeFromBoundVar n = do
  st <- getState
  let bindings = boundVar st
  case Map.lookup n bindings of
    Nothing -> errorMsg (show n ++ " is not a bound value")
    Just 0 -> setState st { boundVar = Map.delete n bindings }
    Just depth -> setState st { boundVar = Map.insert n (depth - 1) bindings }

-- | Run a computation with a particular name being bound: the name is
-- registered with 'addToBoundVar' before the computation runs and
-- released with 'removeFromBoundVar' afterwards. The result is that of
-- the computation.
withBoundVar :: Name -> LiftQ a -> LiftQ a
withBoundVar n comp = addToBoundVar n *> comp <* removeFromBoundVar n

-- | Run a computation with a particular list of names being bound.
-- Each name wraps the computation built so far with 'withBoundVar', so
-- the first name of the list ends up innermost and the last name
-- outermost.
withBoundVars :: [Name] -> LiftQ a -> LiftQ a
withBoundVars [] comp = comp
withBoundVars (n : rest) comp = withBoundVars rest (withBoundVar n comp)

-- | Say whether a given name is bound, i.e. whether it currently has an
-- entry in the 'boundVar' table of the state.
isBoundVar :: Name -> LiftQ Bool
isBoundVar n = fmap (Map.member n . boundVar) getState


-- * Other operations on monad state.

-- | Set the template prefix. The prefix can be set only once: if one is
-- already stored in the state, an error is raised and the state is left
-- as it was.
setPrefix :: String -> LiftQ ()
setPrefix p = getState >>= update
  where
    -- Store the prefix when none is present, otherwise refuse.
    update s = maybe (setState s { prefix = Just p }) refuse (prefix s)
    refuse p' =
      errorMsg $ concat
        [ "cannot set the prefix to "
        , show p
        , ": prefix already defined as "
        , p'
        ]


-- | Get the template prefix. Raises an error if no prefix has been set.
getPrefix :: LiftQ String
getPrefix = getState >>= maybe (errorMsg "undefined prefix") return . prefix

-- | Set the monad name. The name can be set only once: if one is already
-- stored in the state, an error is raised and the state is left as it
-- was.
setMonadName :: String -> LiftQ ()
setMonadName m = getState >>= update
  where
    -- Store the name when none is present, otherwise refuse.
    update s = maybe (setState s { monadName = Just m }) refuse (monadName s)
    refuse m' =
      errorMsg $ concat
        [ "cannot set the monad to "
        , show m
        , ": monad already defined as "
        , m'
        ]

-- | Get the monad name. Raises an error if no monad name has been set.
getMonadName :: LiftQ String
getMonadName =
  getState >>= maybe (errorMsg "undefined monad") return . monadName




-- * Functions dealing with variable names.

-- | Make a name out of a string. This is Template Haskell's 'TH.mkName',
-- re-exposed under the same name by this module.
mkName :: String -> Name
mkName = TH.mkName

-- | Make a name out of a string, monadic-style: Template Haskell's
-- 'TH.newName' run inside 'LiftQ' via 'embedQ'.
newName :: String -> LiftQ Name
newName = embedQ . TH.newName




-- | Replace every ASCII punctuation character other than @.@ and @_@ by
-- the string @\"symb_\"@, followed by a spelled-out name for the
-- character, followed by @\"_\"@. For example, any occurrence of
-- @\"+\"@ is replaced with @\"symb_plus_\"@.
--
-- Every character that is not listed in the table below (letters,
-- digits, @.@, @_@, but also whitespace and non-ASCII characters) is
-- copied through unchanged, so the result is restricted to
-- @[0-9a-zA-Z_.]@ only when the input contains nothing but letters,
-- digits and the listed punctuation.
sanitizeString :: String -> String
sanitizeString = concatMap encodeChar
  where
    -- Translate one character: a spelled-out token for listed symbols,
    -- the character itself otherwise.
    encodeChar :: Char -> String
    encodeChar c = maybe [c] wrap (Map.lookup c symbolNames)

    wrap :: String -> String
    wrap name = "symb_" ++ name ++ "_"

    -- Built once and shared by all characters of the input.
    symbolNames :: Map.Map Char String
    symbolNames = Map.fromList
      [ ('!', "exclamation")
      , ('"', "doublequote")
      , ('#', "sharp")
      , ('$', "dollar")
      , ('%', "percent")
      , ('&', "ampersand")
      , ('\'', "quote")
      , ('(', "oparent")
      , (')', "cparent")
      , ('*', "star")
      , ('+', "plus")
      , (',', "comma")
      , ('-', "minus")
        -- we keep dots
      , ('/', "slash")
      , (':', "colon")
      , (';', "semicolon")
      , ('<', "oangle")
      , ('=', "equal")
      , ('>', "cangle")
      , ('?', "question")
      , ('@', "at")
      , ('[', "obracket")
      , ('\\', "backslash")
      , (']', "cbracket")
      , ('^', "caret")
        -- we keep _
      , ('`', "graveaccent")
      , ('{', "obrace")
      , ('|', "vbar")
      , ('}', "cbrace")
      , ('~', "tilde")
      ]


-- | Take a string and turn it into a template name: the string is passed
-- through 'sanitizeString' and the current template prefix (as returned
-- by 'getPrefix') is put in front of it. Raises the error of 'getPrefix'
-- when no prefix has been set.
templateString :: String -> LiftQ String
templateString s = fmap (++ sanitizeString s) getPrefix

-- | Look for the corresponding \"template\" name: build the template
-- string of the name's base with 'templateString' and look it up with
-- 'TH.lookupValueName'. The result is 'Nothing' when the lookup finds no
-- such value.
lookForTemplate :: Name -> LiftQ (Maybe Name)
lookForTemplate n =
  templateString (TH.nameBase n) >>= embedQ . TH.lookupValueName

-- | Make the template version of a given name: the template string of
-- the name's base (see 'templateString'), turned back into a name with
-- 'TH.mkName'.
makeTemplateName :: Name -> LiftQ Name
makeTemplateName = fmap TH.mkName . templateString . TH.nameBase


-- * Other functions.

-- | Print on the terminal a monadic, printable object. The computation
-- is extracted with 'extractQ' (using the label @\"prettyPrint: \"@),
-- run in 'IO' with 'TH.runQ', and its result is rendered with
-- 'TH.pprint'.
prettyPrint :: TH.Ppr a => LiftQ a -> IO ()
prettyPrint x = do
  value <- TH.runQ (extractQ "prettyPrint: " x)
  putStrLn (TH.pprint value)


-- | Project patterns out of a clause, ignoring its body and its local
-- declarations.
clauseGetPats :: TH.Clause -> [TH.Pat]
clauseGetPats clause = case clause of
  TH.Clause pats _body _decs -> pats


-- | Check that the list is a repetition of the same element, i.e. that
-- every element is equal to the first one.
--
-- Despite the name, the empty list is accepted and yields 'True';
-- callers that need a non-empty list must check for emptiness
-- themselves. The test stops at the first element that differs.
equalNEListElts :: Eq a => [a] -> Bool
equalNEListElts [] = True
equalNEListElts (h : rest) = all (== h) rest


-- | Return the number of patterns of the clauses in the list.
--
-- Raises the error @\"empty clause\"@ when the list is empty, and the
-- error @\"patterns in clause are not of equal size\"@ when the clauses
-- do not all have the same number of patterns.
clausesLengthPats :: [TH.Clause] -> LiftQ Int
clausesLengthPats clauses =
  -- The pattern counts are computed once; matching on the resulting list
  -- gives access to the first count without a partial 'head'.
  case map (length . clauseGetPats) clauses of
    [] -> errorMsg "empty clause"
    arities@(arity : _)
      | equalNEListElts arities -> return arity
      | otherwise -> errorMsg "patterns in clause are not of equal size"