```{-# LANGUAGE NoMonomorphismRestriction #-}

-- | This module contains an implementation of the oracle and its
-- automatic lifting to quantum circuits using Template Haskell.
module Algorithms.QLS.TemplateOracle where

import Data.Complex
import Libraries.Auxiliary

import QuipperLib.Arith hiding (template_symb_plus_)
import Algorithms.QLS.Utils
import Algorithms.QLS.QDouble
import Algorithms.QLS.RealFunc
import Algorithms.QLS.QSignedInt
import Algorithms.QLS.CircLiftingImport

import Quipper
import Quipper.CircLifting

-- | Lifted version of @'any'@ (local version).
build_circuit
local_any :: (a -> Bool) -> [a] -> Bool
local_any f l = case l of
[]    -> False
(h:t) -> if (f h) then True else local_any f t

-- | Auxiliary function.
build_circuit
itoxy :: Int -> Int -> Int -> (Int, Int)
itoxy i nx ny =
if (i <= 0) then (0,0)
else if (i > (nx-1)*ny + nx*(ny-1)) then (0,0)
else ((mod (i - 1) (2*nx - 1)) + 1, ceiling ((fromIntegral i)/(2.0*(fromIntegral nx)-1.0)))

-- | The function sin /x/ \/ /x/.
build_circuit
sinc :: Double -> Double
sinc x = if (x /= 0.0) then (sin x) / x else 1.0

-- | Auxiliary function.
build_circuit
edgetoxy :: Int -> Int -> Int -> (Double, Double)
edgetoxy e nx ny =
let (ex,ey) = itoxy e nx ny in
if (ex < nx) then ((fromIntegral ex) + 0.5, fromIntegral ey)
else (fromIntegral (ex - nx + 1), (fromIntegral ey) + 0.5)

-- | Auxiliary function. The inputs are:
--
-- * /y1/ :: 'Int' - Global edge index of row index of desired matrix element;
--
-- * /y2/ :: 'Int' - Global edge index of column index of desired matrix element;
--
-- * /nx/ :: 'Int' - Number of vertices left to right;
--
-- * /ny/ :: 'Int' - Number of vertices top to bottom;
--
-- * /lx/ :: 'Double' - Length of horizontal edges (distance between vertices in /x/ direction);
--
-- * /ly/ :: 'Double' - Length of vertical edges (distance between vertices in /y/ direction);
--
-- * /k/ :: 'Double' - Plane wave wavenumber.
--
-- The output is the matrix element /A/(/y1/, /y2/).
build_circuit
calcmatrixelement :: --
-- Inputs
Int     -- int y1 - Global edge index of row index of desired matrix element
-> Int  -- int y2 - Global edge index of column index of desired matrix element
-> Int  -- int nx - Number of vertices left to right
-> Int  -- int ny - Number of vertices top to bottom
-> Double -- float lx - Length of horizontal edges (distance between vertices in x direction)
-> Double -- float ly - Length of vertical edges (distance between vertices in y direction)
-> Double -- float k - Plane wave wavenumber
-- Outputs
-> Complex Double -- float A - Matrix element A(y1,y2)
calcmatrixelement y1 y2 nx ny lx ly k =
let (xg1, yg1) = itoxy y1 nx ny in
let (xg2, yg2) = itoxy y2 nx ny in
let b = if ( (y1==y2) && (xg1 >= nx) ) then ly/lx - k*k*lx*ly/3.0
-- B11 and B22
else if ( (y1==y2) && (xg1<nx) ) then lx/ly - k*k*lx*ly/3.0
-- B12 and B21
else if ( (abs(yg1-yg2) == 1) && (abs(xg1-xg2) == 0) && (xg1<nx) ) then -lx/ly + k*k*lx*ly/12.0
-- B34 and B43
else if ( (abs(yg1-yg2)==0) && (abs(xg1-xg2) == 1) && (xg1>=nx) ) then -ly/lx + k*k*lx*ly/12.0
-- B13
else if ( (yg1==(yg2+1)) && (xg1==(xg2-nx+1)) && (xg2>=nx) ) then -1.0
-- B31
else if ( (yg2==(yg1+1)) && (xg2==(xg1-nx+1)) && (xg1>=nx) ) then -1.0
-- B14
else if ( (yg1==(yg2+1)) && (xg1==(xg2-nx)) && (xg1<nx) ) then 1.0
-- B41
else if ( (yg2==(yg1+1)) && (xg2==(xg1-nx)) && (xg2<nx) ) then 1.0
-- B42
else if ( (yg1==yg2) && (xg1==(xg2+nx)) && (xg1>=nx) ) then -1.0
-- B24
else if ( (yg2==yg1) && (xg2==(xg1+nx)) && (xg2>=nx) ) then -1.0
-- B32
else if ( (yg1==yg2) && (xg1==(xg2+nx-1)) && (xg2<nx) ) then 1.0
-- B23
else if ( (yg2==yg1) && (xg2==(xg1+nx-1)) && (xg1<nx) ) then 1.0
else -1.0
in
let c = if ( (y1==y2) && ( (xg1==nx) || (xg1==(2*nx-1)) ) ) then 0.0 :+ (k*ly)
else if  ( (y1==y2) && ( ((yg1==1) && (xg1<nx)) || ((yg1==ny) && (xg1<nx)) ) ) then 0.0 :+ (k*lx)
else 0.0 :+ 0.0
in (b :+ 0.0) + c

-- | Auxiliary function.
build_circuit
get_edges l = case l of
[] -> []
(x:tt) -> case tt of
[] -> []
(y:t) -> (x,y):(get_edges (y:t))

-- | Auxiliary function.
build_circuit
checkedge :: Int -> [(Double,Double)] -> Int -> Int -> Bool
checkedge e scatteringnodes nx ny =
let (xi,yi) = edgetoxy e nx ny in
let test_elt ((x1,y1),(x2, y2)) = xi >= x1 && yi >= y1 && xi <= x2 && yi <= y2 in
let half = take_half scatteringnodes in
let test_list = local_any test_elt (get_edges half) in (not test_list)

-- | Oracle /r/.
build_circuit
calcRweights :: Int -> Int -> Int -> Double -> Double -> Double -> Double -> Double -> Complex Double
calcRweights y nx ny lx ly k theta phi =
let (xc',yc') = edgetoxy y nx ny in
let xc = (xc'-1.0)*lx - ((fromIntegral nx)-1.0)*lx/2.0 in
let yc = (yc'-1.0)*ly - ((fromIntegral ny)-1.0)*ly/2.0 in
let (xg,yg) = itoxy y nx ny in

if (xg == nx) then

let i = (mkPolar ly (k*xc*(cos phi)))*
(mkPolar 1.0 (k*yc*(sin phi)))*
((sinc (k*ly*(sin phi)/2.0)) :+ 0.0) in

let r = ( cos(phi) :+ k*lx )*((cos (theta - phi))/lx :+ 0.0) in i * r

else if (xg==2*nx-1) then

let i = (mkPolar ly (k*xc*cos(phi)))*
(mkPolar 1.0 (k*yc*sin(phi)))*
((sinc (k*ly*sin(phi)/2.0)) :+ 0.0) in

let r = ( cos(phi) :+ (- k*lx))*((cos (theta - phi))/lx :+ 0.0) in i * r

else if ( (yg==1) && (xg<nx) ) then

let i = (mkPolar lx (k*yc*sin(phi)))*
(mkPolar 1.0 (k*xc*cos(phi)))*
((sinc (k*lx*(cos phi)/2.0)) :+ 0.0) in

let r = ( (- sin phi) :+ k*ly )*((cos(theta - phi))/ly :+ 0.0) in i * r

else if ( (yg==ny) && (xg<nx) ) then

let i = (mkPolar lx (k*yc*sin(phi)))*
(mkPolar 1.0 (k*xc*cos(phi)))*
((sinc (k*lx*(cos phi)/2.0)) :+ 0.0) in

let r = ( (- sin phi) :+ (- k*ly) )*((cos(theta - phi)/ly) :+ 0.0) in i * r

else 0.0 :+ 0.0

-- | Auxiliary function for oracle /A/.
build_circuit
convertband :: Int -> Int -> Int -> Int -> Int
convertband y b nx ny =
let nedges = (nx - 1)*ny + nx*(ny - 1) in
let (ex,ey) = itoxy y nx ny in
let x = if ( (ex < nx) && (ey /= 1) ) then
case b of
1 -> y-2*nx+1
2 -> y-nx
3 -> y-nx+1
5 -> y
7 -> y+nx-1
8 -> y+nx
9 -> y+2*nx-1
_ -> -1
else if ( (ex < nx) && (ey == 1) ) then
case b of
5 -> y
7 -> y+nx-1
8 -> y+nx
9 -> y+2*nx-1
_ -> -1
else if ( (ex >= nx) && (ex /= nx) && (ex /= 2*nx-1) ) then
case b of
2 -> y-nx
3 -> y-nx+1
4 -> y-1
5 -> y
6 -> y+1
7 -> y+nx-1
8 -> y+nx
_ -> -1
else if ( (ex >= nx) && (ex == nx) ) then
case b of
3 -> y-nx+1
5 -> y
6 -> y+1
8 -> y+nx
_ -> -1
else if ( (ex >= nx) && (ex == 2*nx-1) ) then
case b of
2 -> y-nx
4 -> y-1
5 -> y
7 -> y+nx-1
_ -> -1
else -1
in if ( (x < 1) || (x > nedges) ) then -1 else x

-- | Oracle /A/. It is equivalent to the Matlab function
-- 'getBandNodeValues'.
--
-- 'getNodeValuesMoreOutputs v b ...'  outputs the node of the edge
-- connected to vertex v in band b, and a real number parameterized by
-- the 'BoolParam' parameter: the magnitude (PFalse) or the phase
-- (PTrue) of the complex value at the corresponding place in the matrix A.
build_circuit
getNodeValuesMoreOutputs ::
Int -> Int -> Int -> Int -> [(Double,Double)] -> Double -> Double -> Double -> BoolParam
-> Int -> (Int, Double)
getNodeValuesMoreOutputs v' b' nx ny scatteringnodes lx ly k argflag maxConnectivity =
let maxC = getIntFromParam maxConnectivity in
let b = getIntFromParam b' in
let nedges = (nx - 1)*ny + nx*(ny - 1) in
let flag = v' <= nedges in
let v = (if flag then v' else (v' - nedges)) in
let nodeDefault = (if flag then v + nedges else v) in
let valueDefault = 0.0 in
let indicesDefault = (-1, -1) in
let isvalid = checkedge v scatteringnodes nx ny
in
if ( (not isvalid) && b == 5 ) then

let indices = (if flag then (v, v + nedges) else (v + nedges, v)) in
case argflag of
PTrue  -> (nodeDefault, valueDefault)
PFalse -> (nodeDefault, 1.0)

else if ( (not isvalid) && b /= 5) then (nodeDefault, valueDefault)

else if ((b > maxC + 2) || (b <= 0)) then (nodeDefault, valueDefault)

else

let x = convertband v b' nx ny in

if (x == -1) then (nodeDefault, valueDefault)
else

let isvalid = checkedge x scatteringnodes nx ny in

if isvalid then

let ax = calcmatrixelement v x nx ny lx ly k in
let (node, indices) = if flag then (x + nedges, (v, x + nedges))
else (x, (v+nedges,x)) in
let value = case argflag of
PTrue -> atan2  (imagPart ax) (realPart ax)
PFalse -> magnitude ax
in (node, value)

else
(nodeDefault, valueDefault)

-- | Auxiliary function for oracle /b/. The inputs are:
--
-- * /y/ :: 'Int' - Global edge index.  Note this is the unmarked /y/
-- coordinate, i.e. the coordinate without scattering regions removed;
--
-- * /nx/ :: 'Int' - Number of vertices left to right;
--
-- * /ny/ :: 'Int' - Number of vertices top to bottom;
--
-- * /lx/ :: 'Double' - Length of horizontal edges (distance between vertices in /x/ direction);
--
-- * /ly/ :: 'Double' - Length of vertical edges (distance between vertices in /y/ direction);
--
-- * /k/ :: 'Double' - Plane wave wavenumber;
--
-- * θ :: 'Double' - Direction of wave propagation;
--
-- * /E0/ :: 'Double' - Magnitude of incident plane wave.
--
-- The output is the magnitude of the electric field on edge /y/.
build_circuit
calcincidentfield :: --
-- Inputs
Int -- int y - Global edge index.  Note this is the unmarked y coordinate,
-- i.e. the coordinate without scattering regions removed.
-> Int -- int nx - Number of vertices left to right
-> Int --  int ny - Number of vertices top to bottom
-> Double -- float lx - Length of horizontal edges (distance between vertices in x direction)
-> Double -- float ly - Length of vertical edges (distance between vertices in y direction)
-> Double -- float k - Plane wave wavenumber
-> Double -- float theta - Direction of wave propagation
-> Double -- float E0 - Magnitude of incident plane wave
-- Outputs
-> Complex Double -- complex float e - Magnitude of electric field on edge y
calcincidentfield y nx ny lx ly k theta e0 =
let (xg, yg) = itoxy y nx ny in
--Determine whether edge is horizontal or vertical
let isvertical = xg >= nx in
let (xvalueTmp, yvalueTmp) = edgetoxy y nx ny in
let xvalue = xvalueTmp * lx in
let yvalue = yvalueTmp * ly in
-- Convert x and y edge coordinates to x and y values and caluculate field
if isvertical then
mkPolar (-cos(theta)*e0) ( -k*(xvalue*cos(theta)+yvalue*sin(theta)))
else
mkPolar (sin(theta)*e0) ( -k*(xvalue*cos(theta)+yvalue*sin(theta)))

-- | Auxiliary function for oracle /b/.
build_circuit
getconnection :: Int -> Int -> Int -> Int -> Int -> Int
getconnection y i' nx ny maxConnectivity =
let i = getIntFromParam i' in
let maxC = getIntFromParam maxConnectivity in
let (ex,ey) = itoxy y nx ny in
let x = if ( (ex < nx) && (ey /= 1) ) then
case i' of
1 -> y-2*nx+1
2 -> y-nx
3 -> y-nx+1
4 -> y
5 -> y+nx-1
6 -> y+nx
7 -> y+2*nx-1
_ -> -1
else if ( (ex < nx) && (ey == 1) ) then
case i' of
1 -> y
2 -> y+nx-1
3 -> y+nx
4 -> y+2*nx-1
_ -> -1
else if ( (ex >= nx) && (ex /= nx) && (ex /= 2*nx-1) ) then
case i' of
1 -> y-nx
2 -> y-nx+1
3 -> y-1
4 -> y
5 -> y+1
6 -> y+nx-1
7 -> y+nx
_ -> -1
else if ( (ex >= nx) && (ex == nx) ) then
case i' of
1 -> y-nx+1
2 -> y
3 -> y+1
4 -> y+nx
_ -> -1
else if ( (ex >= nx) && (ex == 2*nx-1) ) then
case i' of
1 -> y-nx
2 -> y-1
3 -> y
4 -> y+nx-1
_ -> -1
else -1
in
if (i > maxC) then -1
else if (x > nx*(ny-1)+ny*(nx-1)) then -1
else x

-- | Auxiliary function to @'template_paramZero'@.
build_circuit
local_loop_with_index_aux :: Int -> Int -> t -> (Int -> t -> t) -> t
local_loop_with_index_aux i n x f =
case paramMinus n i of
0 -> x
_ -> local_loop_with_index_aux (paramSucc i) n (f i x) f

-- | Local version of @'loop_with_index'@, for lifting.
build_circuit
local_loop_with_index :: Int -> t -> (Int -> t -> t) -> t
local_loop_with_index n x f = local_loop_with_index_aux paramZero n x f

-- | Oracle /b/.
build_circuit
getKnownWeights :: Int -> Int -> Int -> [(Double,Double)] -> Double -> Double -> Double -> Double -> Double -> Int -> Complex Double
getKnownWeights y nx ny scatteringnodes lx ly k theta e0 maxConnectivity =
let makeConnections i connections = let x = getconnection y (paramSucc i) nx ny maxConnectivity in
let t = not \$ checkedge x scatteringnodes nx ny in
(x,t):connections
in
let calcTang b (c,t) = if t
then let matElt = calcmatrixelement y c nx ny lx ly k in
let incField = calcincidentfield c nx ny lx ly k theta e0
in b - matElt * incField
else b
in
let connections = local_loop_with_index maxConnectivity [] makeConnections in
if (not \$ checkedge y scatteringnodes nx ny) then 0.0 :+ 0.0
else foldl calcTang (0.0 :+ 0.0) connections

----------------------------------------------------------------------
----------------------------------------------------------------------
-- Testing functions

test_template_sinc = do
f <- template_sinc
r <- qinit (0 :: FDouble)
f r
return ()

test_template_itoxy = do
f <- template_itoxy
x <- qinit (0 :: FSignedInt)
y <- qinit (0 :: FSignedInt)
z <- qinit (0 :: FSignedInt)
g <- f x
h <- g y
k <- h z
return ()

test_template_edgetoxy = do
f <- template_edgetoxy
x <- qinit (0 :: FSignedInt)
y <- qinit (0 :: FSignedInt)
z <- qinit (0 :: FSignedInt)
g <- f x
h <- g y
k <- h z
return ()

test_template_calcRweights = do
f <- template_calcRweights
n1 <- qinit (0 :: FSignedInt)
n2 <- qinit (0 :: FSignedInt)
n3 <- qinit (0 :: FSignedInt)
x1 <- qinit (0 :: FDouble)
x2 <- qinit (0 :: FDouble)
x3 <- qinit (0 :: FDouble)
x4 <- qinit (0 :: FDouble)
x5 <- qinit (0 :: FDouble)
f1 <- f n1
f2 <- f1 n2
f3 <- f2 n3
g1 <- f3 x1
g2 <- g1 x2
g3 <- g2 x3
g4 <- g3 x4
g5 <- g4 x5
return ()

test_template_calcincidentfield = do
y' <- qinit (0 :: FSignedInt)
nx' <- qinit (0 :: FSignedInt)
ny' <- qinit (0 :: FSignedInt)
lx' <- qinit (0 :: FDouble)
ly' <- qinit (0 :: FDouble)
k' <- qinit (0 :: FDouble)
theta' <- qinit (0 :: FDouble)
e0' <- qinit (0 :: FDouble)
f <- template_calcincidentfield
f1 <- f y'
f2 <- f1 nx'
f3 <- f2 ny'
f4 <- f3 lx'
f5 <- f4 ly'
f6 <- f5 k'
f7 <- f6 theta'
f8 <- f7 e0'
return ()

test_template_calcmatrixelement = do
y1 <- qinit (0 :: FSignedInt)
y2 <- qinit (0 :: FSignedInt)
nx <- qinit (0 :: FSignedInt)
ny <- qinit (0 :: FSignedInt)
lx <- qinit (0 :: FDouble)
ly <- qinit (0 :: FDouble)
k  <- qinit (0 :: FDouble)
f <- template_calcmatrixelement
f1 <- f y1
f2 <- f1 y2
f3 <- f2 nx
f4 <- f3 ny
f5 <- f4 lx
f6 <- f5 ly
f7 <- f6 k
return ()

test_template_getconnection = do
y <- qinit (0 :: FSignedInt)
nx <- qinit (0 :: FSignedInt)
ny <- qinit (0 :: FSignedInt)
f <- template_getconnection
f1 <- f y
f2 <- f1 6
f3 <- f2 nx
f4 <- f3 ny
f5 <- f4 7
return ()

test_template_checkedge = do
y <- qinit (0 :: FSignedInt)
nx <- qinit (0 :: FSignedInt)
ny <- qinit (0 :: FSignedInt)
s <- qinit [(0 :: FDouble,0 :: FDouble),(0 :: FDouble,0 :: FDouble)]
f <- template_checkedge
f1 <- f y
f2 <- f1 s
f3 <- f2 nx
f4 <- f3 ny
return ()

```