{-# OPTIONS -cpp -XParallelListComp #-}
-- module Main where

import DivConN

import Eden -- noetig?
import System
import System.IO
import Data.Complex
import Data.List

import DivConRW

import Control.Parallel.Strategies(seqList)
-- import Observe 
 
#ifdef TEST
import qualified Reference
-- | Reference FFT used by the TEST build to check the parallel result.
-- The first argument (the chunk size passed at the call site in main) is
-- not needed by the reference implementation and is ignored, so that the
-- equation agrees with the type signature.
-- NOTE(review): assumes Reference.fft takes (length, list) and returns
-- (list, length) -- confirm against the Reference module.
refFFT :: Int -> ([Complex Double],Int) -> ([Complex Double],Int)
refFFT _ (l,n) = Reference.fft (n,l)
#endif

-- main = runO program
-- program 
-- | Benchmark driver.  Positional command-line arguments, all optional:
--
--   1. problem size; values <= 15 are taken as an exponent, size = 4^arg
--      (default 1024)
--   2. version: 0 = tickets, 1 = more tickets, 2 = workpool,
--      3 = direct mapping, 4 = hierarchical workpool (default 0)
--   3. chunk size (default 512)
--   4. depth for the workpool / direct-mapping versions (default 3)
--
-- When compiled with TEST defined, version 0 is compared against a
-- reference FFT instead of selecting a version.
main = do
  hSetBuffering stdout NoBuffering
  putStrLn "usage: FFT <size or base for power of 4(1024)> <version 0,1,2,3,4(0)> <chunksize(512)> <depth for wp version(3)>"

  args <- getArgs
  let l  = mkList n
      -- small first arguments are exponents of 4, larger ones are sizes
      n  = if b > 15 then b else 4^b
      b  = if null args then 1024 else read (head args)
      cs = if length args < 3 then 512 else read (args!!2)
      v  = if length args < 2 then 0 else read (args!!1)
      d  = if length args < 4 then 3 else read (args!!3)
  let -- define versions here:
      r  = fft cs l
      r2 = fftTs cs l
      r3 = fftMW d cs l
      r4 = fftDM d cs l
      r5 = fftMW2 d cs l
#ifdef TEST
  rnf r `seq` putStrLn "Done fft with check"
  let ref = refFFT cs l
  -- outputDiff compares the element lists and prints at most 10 differences
  if r /= ref then outputDiff 10 (fst r) (fst ref) else putStrLn "CORRECT"
#else
  case v of
     0 -> rnf r `seq` putStrLn ("Done fft")
     1 -> rnf r2 `seq` putStrLn ("Done fft with more tickets")
     2 -> rnf r3 `seq` putStrLn ("Done fft with workpool")
     3 -> rnf r4 `seq` putStrLn ("Done fft with direct mapping")
     4 -> rnf r5 `seq` putStrLn ("Done fft with hierachical workpool")
     _ -> putStrLn ("no version " ++ show v ++ " defined")
#endif

-- | Test input of size @n@: the complex numbers 1, 2, ..., n (zero
-- imaginary part), paired with their count.
mkList :: Int -> ([Complex Double],Int)
mkList n = ([fromIntegral k | k <- [1 .. n]], n)

-- | Report the differences between two result lists.  Prints how many
-- positions differ, then at most @n@ of them as
-- @(index, left, right, left - right)@.  Only the common prefix can be
-- compared element-wise, so a difference in length is reported on its own
-- line instead of being silently ignored.
outputDiff :: (Show a,Eq a,Num a) => Int -> [a] -> [a] -> IO ()
outputDiff n l1 l2 = do
    putStrLn (show (length diffs) ++ " different elements")
    putStrLn (unlines (map show (take n diffs)))
    if len1 /= len2
      then putStrLn ("lengths differ: " ++ show len1 ++ " vs " ++ show len2)
      else return ()
  where
    len1  = length l1
    len2  = length l2
    diffs = [ (k, x, y, x - y) | (k, x, y) <- zip3 [0 :: Int ..] l1 l2, x /= y ]

	      
-- NOTE(review): Trans is presumably Eden's class of values that can be
-- sent between processes.  This empty instance relies entirely on the
-- class's default methods so that Complex values can be passed through the
-- parallel skeletons used below.  It is an orphan instance: if the Eden
-- library in use already provides Trans (Complex a), the two instances
-- will clash and this one should be removed.
instance (RealFloat a, Trans a) => Trans (Complex a)

--quadDCAtSep :: (Trans a, Trans b) => 
--           Int -> (a -> Bool) -> (a -> b) -> (a -> [a]) -> ([b] -> b) -> (a -> b) ->  a -> b


-- | 4-radix FFT on the ticket-based divide-and-conquer skeleton
-- 'dcNTickets' (degree 4).  The input is split into chunks of @cSize@
-- elements; each worker function flattens its argument, works on the plain
-- list and re-chunks its result.
fft :: Int -> ([Complex Double],Int) -> ([Complex Double],Int)
fft cSize input =
    unchunk (dcNTickets 4 peTickets            -- alternative skeleton: dc4WithSeq noPe
                        (less4Elem . fst . unchunk) fftSC unshuffleC combine4C seqDC
                        (chunk cSize input))
  where
    -- chunking encoded in worker functions:
    fftSC        = chunk cSize . fftS . unchunk
    unshuffleC   = map (chunk cSize) . unshuffle' 4 . unchunk
    combine4C xs = chunk cSize (combine4 (map unchunk xs))
    -- tickets: [2..noPe]
    peTickets = [2..noPe]
    -- sequential version handed to the skeleton; works on the flat list
    -- (no chunking inside the recursion)
    seqDC = chunk cSize . seqDC' . unchunk
    seqDC' inp
      | less4Elem (fst inp) = fftS inp
      | otherwise           = combine4 (map seqDC' (unshuffle' 4 inp))

-- | True for lists with fewer than four elements.  At most the first four
-- cells of the spine are inspected, so this also works on infinite lists.
less4Elem :: [a] -> Bool
less4Elem xs = null (drop 3 xs)

-- | Same as 'fft' (4-radix FFT on 'dcNTickets', chunks of @cSize@ elements
-- encoded in the worker functions), but with a longer, cyclic ticket list
-- ("multiple placement").
fftTs :: Int -> ([Complex Double],Int) -> ([Complex Double],Int)
fftTs cSize input =
    unchunk (dcNTickets 4 peTickets            -- alternative skeleton: dc4WithSeq noPe
                        (less4Elem . fst . unchunk) fftSC unshuffleC combine4C seqDC
                        (chunk cSize input))
  where
    -- chunking encoded in worker functions:
    fftSC        = chunk cSize . fftS . unchunk
    unshuffleC   = map (chunk cSize) . unshuffle' 4 . unchunk
    combine4C xs = chunk cSize (combine4 (map unchunk xs))
    -- tickets, multiple placement: 2..noPe followed by the repeating
    -- sequence noPe, ..., 2, 1, truncated to 4^(logN 4 noPe) - 1 entries
    pes       = [2..noPe]
    list      = pes ++ cycle (reverse (1:pes))
    peTickets = take (4^(logN 4 noPe) - 1) list
    -- sequential version handed to the skeleton; works on the flat list
    -- (no chunking inside the recursion)
    seqDC = chunk cSize . seqDC' . unchunk
    seqDC' inp
      | less4Elem (fst inp) = fftS inp
      | otherwise           = combine4 (map seqDC' (unshuffle' 4 inp))

-- | Sequentially reduce a list of partial FFT results to a single result,
-- combining groups of four with 'combine4', level by level.  The number of
-- parts must be a power of 4.  Two or three parts at any level raise an
-- explicit combineTree error; a trailing group of fewer than four parts is
-- rejected by 'combine4'.
combineTree :: [([Complex Double],Int)] -> ([Complex Double],Int)
combineTree []  = error "combineTree: empty"
combineTree [x] = x
combineTree parts@(_:_:_:_:_) = combineTree (map combine4 (chunkL 4 parts))
combineTree parts = error ("combineTree: expected a power of 4 parts, got "
                           ++ show (length parts))

-- | 4-radix FFT on the workpool skeleton 'divConRW'; @depth@ is handed to
-- the skeleton unchanged.  The skeleton works on chunked data (pieces of
-- @cSize@ elements): every worker function flattens its argument, works on
-- the plain list and re-chunks its result.
fftMW :: Int -> Int -> ([Complex Double],Int) -> ([Complex Double],Int)
fftMW depth cSize input = unchunk chunkedResult
  where
    chunkedResult = divConRW depth isBase solveBase divide4 conquer4 (chunk cSize input)
    isBase    = less4Elem . fst . unchunk
    solveBase = chunk cSize . fftS . unchunk
    divide4   = map (chunk cSize) . unshuffle' 4 . unchunk
    -- divConRW's combine function receives an extra first argument, which
    -- this FFT does not need
    conquer4 _ parts = chunk cSize (combine4 (map unchunk parts))

-- | 4-radix FFT on the alternative workpool skeleton 'divConRW2' (same
-- calling convention as 'divConRW'); @depth@ is handed to the skeleton
-- unchanged.  Data is chunked into pieces of @cSize@ elements, and the
-- worker functions flatten and re-chunk around the actual work.
fftMW2 :: Int -> Int -> ([Complex Double],Int) -> ([Complex Double],Int)
fftMW2 depth cSize input =
    unchunk $ divConRW2 depth baseCase leafFFT splitIn4 joinParts $ chunk cSize input
  where
    baseCase = less4Elem . fst . unchunk
    leafFFT  = chunk cSize . fftS . unchunk
    splitIn4 = map (chunk cSize) . unshuffle' 4 . unchunk
    -- the first argument supplied by the skeleton is not used here
    joinParts _ subResults = chunk cSize (combine4 (map unchunk subResults))

-- | 4-radix FFT on the direct-mapping skeleton 'divConDM'; @depth@ is
-- handed to the skeleton unchanged.  Unlike the other versions this one
-- does NOT chunk: the worker functions operate on the flat list, so the
-- chunk-size argument is unused.  It is kept so that all versions share
-- the same calling convention.
fftDM :: Int -> Int -> ([Complex Double],Int) -> ([Complex Double],Int)
fftDM depth _cSize input =
    divConDM depth (less4Elem . fst) fftS (unshuffle' 4) combine4NoC input
  where
    -- the skeleton's combine function receives an extra first argument,
    -- which this FFT does not need
    combine4NoC _ xs = combine4 xs


-- | Deal a list round-robin into @n@ sub-lists: sub-list @i@ holds the
-- elements at positions @i, i+n, i+2n, ...@.  Yields no sub-lists when
-- @n <= 0@.
unshuffle :: Int -> [a] -> [[a]]
unshuffle n xs = map strand [0 .. n - 1]
  where
    strand i = every (drop i xs)
    every []       = []
    every (y : ys) = y : every (drop (n - 1) ys)

-- | Every @n@-th element of a list, starting with the first.  For
-- @n <= 1@ the whole list is returned.  Lazy, so it also works on infinite
-- lists.
takeEach :: Int -> [a] -> [a]
takeEach n = unfoldr step
  where
    step []       = Nothing
    step (y : ys) = Just (y, drop (n - 1) ys)

-- | Length-annotated variant of 'unshuffle': deal the list round-robin into
-- @n@ sub-lists and attach each sub-list's length, computed from the given
-- total @len@.  The first @len `mod` n@ sub-lists receive one extra
-- element.
unshuffle' :: Int -> ([a],Int) -> [([a],Int)]
unshuffle' n (xs, len) = [ (strand i, size i) | i <- [0 .. n - 1] ]
  where
    (q, r)   = len `divMod` n
    size i   = if i < r then q + 1 else q
    strand i = every (drop i xs)
    every []       = []
    every (y : ys) = y : every (drop (n - 1) ys)


chunk :: Int -> ([a],Int) -> ([[a]],Int)
chunk _ ([],_) = ([],0)
chunk n (xs,l) = (chunkL n xs,l)

-- | Split a list into consecutive pieces of @size@ elements; the last piece
-- may be shorter.
chunkL :: Int -> [a] -> [[a]]
chunkL size = unfoldr step
  where
    step [] = Nothing
    step ys = Just (splitAt size ys)

-- | Inverse of 'chunk' on non-empty input: flatten the pieces, keeping the
-- length annotation.
unchunk :: ([[a]],Int) -> ([a],Int)
unchunk (pieces, len) = ([x | piece <- pieces, x <- piece], len)

-- | Radix-4 butterfly: merge the transforms of the four interleaved
-- sub-sequences (as produced by @unshuffle' 4@) into the transform of the
-- whole sequence.  The length annotation @n@ of the first part is used for
-- all four; the result is annotated with @4*n@.
--
-- With @w = root_of_unity (-1) (4n)@, @x0 = a0[k]@ and
-- @xj = w^(j*k) * aj[k]@ for j = 1..3, output k of each quarter is:
--
-- >  X[k]      = (x1 + x3) + (x0 + x2)
-- >  X[k + n]  = (x0 - x2) - i*(x1 - x3)
-- >  X[k + 2n] = (x0 + x2) - (x1 + x3)
-- >  X[k + 3n] = (x0 - x2) + i*(x1 - x3)
--
-- Any input that is not exactly four parts is an error.
combine4 :: (Num a,NFData a,RealFloat a) =>
            [([Complex a],Int)] -> ([Complex a],Int)
combine4 [(xs0,n),(xs1,_),(xs2,_),(xs3,_)]
    = (q0 ++ q1 ++ q2 ++ q3, 4*n)
  where
    w      = root_of_unity (-1) (4*n)
    -- twiddle factors w^k, w^(2k), w^(3k); only the first n of each are
    -- consumed by the parallel comprehension below
    roots1 = snd $ powers n w
    roots2 = snd $ powers n (w*w)
    roots3 = snd $ powers n (w*w*w)
    -- one butterfly per k, each tuple fully evaluated via `using` rnf
    butterflies =
      [ let x1 = a1*t1
            x2 = a2*t2
            x3 = a3*t3
            eP = x1+x3
            oP = x0+x2
            tW = (0:+1)*(x1-x3)
            oM = x0-x2
        in (eP + oP, oM - tW, oP - eP, oM + tW) `using` rnf
      | x0 <- xs0
      | a1 <- xs1
      | a2 <- xs2
      | a3 <- xs3
      | t1 <- roots1
      | t2 <- roots2
      | t3 <- roots3
      ]
    (q0,q1,q2,q3) = unzip4 butterflies
combine4 other = error ("combine4: expected exactly 4 parts, got " ++
                        show (length other) ++ " with sizes "
                        ++ show (map snd other))


-----------------------------------------------------------------
-- fft basics:

-- | @root_of_unity j n@ is @cos(2*pi*j/n) :+ sin(2*pi*j/n)@, i.e. the j-th
-- power of the principal n-th root of unity.  Accuracy is limited by the
-- mantissa length of the floating-point type.
-- Intended type: (RealFloat a, Integral b) => b -> b -> Complex a.  No
-- signature is given, so j and n may have different integral types.
root_of_unity j n = cos theta :+ sin theta
  where
    theta = 2 * pi * fromIntegral j / fromIntegral n
-- | @powers n w@ pairs @n@ with the first @n@ powers @[1, w, w^2, ...]@ of
-- @w@, each obtained from the previous one by one more multiplication
-- with @w@.
powers :: RealFloat a => Int -> Complex a -> (Int, [Complex a])
powers n w = (n, take n (scanl (*) 1 (repeat w)))

-- | Direct sequential DFT for the base cases of the divide-and-conquer FFT:
-- inputs of 1, 2, 3 or 4 elements, paired with their length.  The length
-- component must match the list; any other input is an error.  The sign
-- convention uses twiddle factors exp(-2*pi*i*k/N): in the size-4 case the
-- second output is x0 - i*x1 - x2 + i*x3.
fftS :: (Show a, RealFloat a) => ([Complex a],Int) -> ([Complex a],Int)
fftS ([],0) = error "Empty input!"
fftS ([x0],1) = ([x0],1)
fftS ([x0, x1],2) = ([x0+x1, x0-x1],2)
-- size 3: with u = 2*pi/3, X1 = x0 + e^(-iu)*x1 + e^(iu)*x2 and
-- X2 = its counterpart, computed from the shared terms m0, m1, m2.
-- Note: u is Complex-typed (inferred because cos(u) - 1.0 is multiplied by
-- the complex t1), so cos/sin below are the Complex versions.
fftS ([x0, x1, x2],3) = let  i = 0.0 :+ 1.0
                             u = 2.0*pi/3.0
                             t1 = x1+x2
                             m0 = x0+t1
                             m1 = (cos(u) - 1.0) * t1
                             m2 = i * sin(u) * (x2-x1)
                             s1 = m0 + m1
                         in ([m0, s1+m2, s1-m2],3)
-- size 4: plain radix-4 butterfly (twiddles 1, -i, -1, i)
fftS ([x0, x1, x2, x3], 4) = let i = 0.0 :+ 1.0
                                 t1 = x0 + x2
                                 t2 = x1 + x3
                                 m0 = t1 + t2
                                 m1 = t1 - t2
                                 m2 = x0 - x2
                                 m3 = (x3 - x1) * i
                             in ([m0, m2+m3, m1, m2-m3],4)
-- anything else (more than 4 elements, or a wrong length annotation)
fftS other = error ("fftS: unexpected input " ++ show other)

-- instance (Observable a, RealFloat a) =>  Observable (Complex a) 
--   where observer (x :+ y) = send "(:+)" (return (:+) << x << y)