import DivConN
import Eden
import System
import System.IO
import Data.Complex
import Data.List
import DivConRW
import Control.Parallel.Strategies(seqList)
#ifdef TEST
import qualified Reference
-- | Reference FFT used only by the TEST build.
--
-- The first argument (chunk size) is accepted so that the call site can use
-- the same argument list as the parallel versions; it is ignored.
-- NOTE(review): assumes Reference.fft takes (length, elements) and returns
-- (elements, length) -- confirm against the Reference module.
refFFT :: Int -> ([Complex Double],Int) -> ([Complex Double],Int)
refFFT _chunkSize (xs, len) = Reference.fft (len, xs)
#endif
-- | Benchmark driver: builds the input @1..n@ and runs one FFT variant.
--
-- Positional command-line arguments, all optional:
--
-- 1. size; a value @b <= 15@ is taken as an exponent and the size becomes
--    @4^b@ (default 1024)
-- 2. version (default 0): 0 = 'fft', 1 = 'fftTs', 2 = 'fftMW',
--    3 = 'fftDM', 4 = 'fftMW2'
-- 3. chunk size (default 512)
-- 4. depth for the workpool / direct-mapping versions (default 3)
--
-- When compiled with @TEST@ defined, the result of 'fft' is compared
-- against 'refFFT' instead and differing elements are printed.
main :: IO ()
main = do
  hSetBuffering stdout NoBuffering
  putStrLn "usage: FFT <size or base for power of 4(1024)> <version 0,1,2,3,4(0)> <chunksize(512)> <depth for wp version(3)>"
  args <- getArgs
  let b  = if null args then 1024 else read (head args)
      n  = if b > 15 then b else 4 ^ b
      l  = mkList n
      v  = if length args < 2 then 0 else read (args !! 1) :: Int
      cs = if length args < 3 then 512 else read (args !! 2)
      d  = if length args < 4 then 3 else read (args !! 3)
      r  = fft cs l
      r2 = fftTs cs l
      r3 = fftMW d cs l
      r4 = fftDM d cs l
      r5 = fftMW2 d cs l
#ifdef TEST
  rnf r `seq` putStrLn "Done fft with check"
  let ref = refFFT cs l
  -- NOTE(review): exact (/=) on Doubles; round-off differences between the
  -- two algorithms may be reported as mismatches.
  if r /= ref
    then outputDiff 10 (fst r) (fst ref)
    else putStrLn "CORRECT"
#else
  case v of
    0 -> rnf r  `seq` putStrLn "Done fft"
    1 -> rnf r2 `seq` putStrLn ("Done fft with more tickets")
    2 -> rnf r3 `seq` putStrLn ("Done fft with workpool")
    3 -> rnf r4 `seq` putStrLn ("Done fft with direct mapping")
    4 -> rnf r5 `seq` putStrLn ("Done fft with hierachical workpool")
    _ -> putStrLn ("no version " ++ show v ++ " defined")
#endif
-- | Test input of the given size: the numbers @1 .. size@ as complex values,
-- paired with their count.
mkList :: Int -> ([Complex Double],Int)
mkList size = (values, size)
  where
    values = [fromIntegral k | k <- [1 .. size]]
-- | Prints how many positions differ between two lists, then at most @n@ of
-- the differing entries as @(index, left, right, left - right)@.
-- Only the common prefix is compared, because the lists are zipped.
outputDiff :: (Show a,Eq a,Num a) => Int -> [a] -> [a] -> IO ()
outputDiff n l1 l2 = do
  putStrLn (show (length diffs) ++ " different elements")
  putStrLn (unlines (map show (take n diffs)))
  where
    diffs = [ (k, x, y, x - y) | (k, x, y) <- zip3 [(0 :: Int) ..] l1 l2, x /= y ]
-- | Makes complex numbers an instance of Eden's 'Trans' class (presumably so
-- the FFT data can be sent between processes); all methods use the class
-- defaults.
-- NOTE(review): orphan instance on Data.Complex.Complex.
instance (Trans a, RealFloat a) => Trans (Complex a) where
-- | Parallel radix-4 FFT using the ticket-based divide-and-conquer skeleton
-- 'dcNTickets', with one ticket for each PE number @2 .. noPe@.
--
-- The data is split into chunks of @cSize@ elements while it is handled by
-- the skeleton and concatenated again afterwards.
fft :: Int -> ([Complex Double],Int) -> ([Complex Double],Int)
fft cSize input =
  unchunk
    (dcNTickets 4 peTickets
       (less4Elem . fst . unchunk) fftSC unshuffleC combine4C seqDC
       (chunk cSize input))
  where
    -- base case: direct DFT of at most three elements
    fftSC = chunk cSize . fftS . unchunk
    -- divide: split into 4 interleaved sub-sequences
    unshuffleC = map (chunk cSize) . unshuffle' 4 . unchunk
    -- conquer: merge the 4 sub-transforms
    combine4C xs = chunk cSize (combine4 (map unchunk xs))
    peTickets = [2 .. noPe]
    -- purely sequential recursion handed to the skeleton
    -- (presumably used once no tickets remain -- confirm in DivConN)
    seqDC = chunk cSize . seqDC' . unchunk
    seqDC' inp
      | less4Elem (fst inp) = fftS inp
      | otherwise           = combine4 (map seqDC' (unshuffle' 4 inp))
-- | True when the list has at most three elements (the FFT base case).
-- Only the first four cons cells are inspected and the elements are never
-- forced, so it also works on infinite lists.
less4Elem :: [a] -> Bool
less4Elem = null . drop 3
-- | Variant of 'fft' that hands 'dcNTickets' more tickets.
--
-- The ticket list is the PE numbers @2 .. noPe@ once, followed by
-- @noPe, noPe-1 .. 1@ repeated cyclically, truncated to
-- @4 ^ logN 4 noPe - 1@ entries. @logN@ is presumably an integer logarithm
-- from an imported module -- confirm.
fftTs :: Int -> ([Complex Double],Int) -> ([Complex Double],Int)
fftTs cSize input =
  unchunk
    (dcNTickets 4 peTickets
       (less4Elem . fst . unchunk) fftSC unshuffleC combine4C seqDC
       (chunk cSize input))
  where
    -- base case: direct DFT of at most three elements
    fftSC = chunk cSize . fftS . unchunk
    -- divide: split into 4 interleaved sub-sequences
    unshuffleC = map (chunk cSize) . unshuffle' 4 . unchunk
    -- conquer: merge the 4 sub-transforms
    combine4C xs = chunk cSize (combine4 (map unchunk xs))
    pes = [2 .. noPe]
    list = pes ++ cycle (reverse (1 : pes))
    peTickets = take ((4 ^ logN 4 noPe) - 1) list
    -- purely sequential recursion handed to the skeleton
    seqDC = chunk cSize . seqDC' . unchunk
    seqDC' inp
      | less4Elem (fst inp) = fftS inp
      | otherwise           = combine4 (map seqDC' (unshuffle' 4 inp))
-- | Sequentially merges a list of sub-transform results, four at a time with
-- 'combine4', until a single result remains.
--
-- The list length must be a power of 4. Otherwise some group has fewer than
-- four entries, and 'combine4' raises its "wrong input" error instead of a
-- bare pattern-match failure.
combineTree :: [([Complex Double],Int)] -> ([Complex Double],Int)
combineTree []  = error "combineTree: empty"
combineTree [x] = x
combineTree xs  = combineTree (map combine4 (chunkL 4 xs))
-- | FFT via the 'divConRW' skeleton. @depth@ is passed straight to the
-- skeleton; data is handled in chunks of @cSize@ elements and concatenated
-- at the end.
fftMW :: Int -> Int -> ([Complex Double],Int) -> ([Complex Double],Int)
fftMW depth cSize input = unchunk chunkedResult
  where
    chunkedResult =
      divConRW depth isSmall solveSmall splitChunks mergeChunks
        (chunk cSize input)
    isSmall = less4Elem . fst . unchunk
    solveSmall = chunk cSize . fftS . unchunk
    splitChunks = map (chunk cSize) . unshuffle' 4 . unchunk
    -- the first argument supplied by divConRW is not needed here
    mergeChunks _ parts = chunk cSize (combine4 (map unchunk parts))
-- | FFT via the 'divConRW2' skeleton. Same structure as 'fftMW', only the
-- skeleton differs. @depth@ is passed straight to the skeleton; data is
-- handled in chunks of @cSize@ elements.
fftMW2 :: Int -> Int -> ([Complex Double],Int) -> ([Complex Double],Int)
fftMW2 depth cSize input = unchunk chunkedResult
  where
    chunkedResult =
      divConRW2 depth isSmall solveSmall splitChunks mergeChunks
        (chunk cSize input)
    isSmall = less4Elem . fst . unchunk
    solveSmall = chunk cSize . fftS . unchunk
    splitChunks = map (chunk cSize) . unshuffle' 4 . unchunk
    -- the first argument supplied by divConRW2 is not needed here
    mergeChunks _ parts = chunk cSize (combine4 (map unchunk parts))
-- | FFT via the direct-mapping skeleton 'divConDM', working on the plain
-- (unchunked) @(elements, length)@ pair. @depth@ is passed straight to the
-- skeleton; @cSize@ is accepted for a uniform signature but unused.
fftDM :: Int -> Int -> ([Complex Double],Int) -> ([Complex Double],Int)
fftDM depth cSize input =
  divConDM depth isBase fftS splitInto4 mergeParts input
  where
    isBase = less4Elem . fst
    splitInto4 = unshuffle' 4
    -- the first argument supplied by divConDM is ignored
    mergeParts _ parts = combine4 parts
-- | Splits a list into @n@ interleaved sublists: sublist @i@ (0-based) holds
-- the elements at positions @i, i+n, i+2n, ...@.
--
-- >>> unshuffle 2 [1..5]
-- [[1,3,5],[2,4]]
unshuffle :: Int -> [a] -> [[a]]
unshuffle n xs = [takeEach n (drop i xs) | i <- [0 .. n - 1]]
-- | Keeps every @n@-th element of a list, starting with the first.
-- For @n <= 1@ every element is kept.
--
-- >>> takeEach 3 "abcdefg"
-- "adg"
takeEach :: Int -> [a] -> [a]
takeEach _ []     = []
takeEach n (x:xs) = x : takeEach n (drop (n - 1) xs)
-- | Like 'unshuffle', but on @(elements, length)@ pairs.
--
-- Each of the @n@ interleaved sublists is paired with its length. The
-- lengths are computed arithmetically from @len@ without traversing the
-- list: the first @len `mod` n@ sublists get one extra element.
unshuffle' :: Int -> ([a],Int) -> [([a],Int)]
unshuffle' n (xs,len) = zip [takeEach n (drop i xs) | i <- [0 .. n - 1]] parts
  where
    parts = zipWith (+) (replicate (len `mod` n) 1 ++ repeat 0)
                        (replicate n (len `div` n))
-- | Splits the element list of an @(elements, length)@ pair into pieces of
-- @size@ elements, keeping the length tag. An empty element list always
-- yields length 0, whatever tag it carried.
chunk :: Int -> ([a],Int) -> ([[a]],Int)
chunk size (elems, len)
  | null elems = ([], 0)
  | otherwise  = (chunkL size elems, len)
-- | Cuts a list into consecutive pieces of @size@ elements; the last piece
-- may be shorter.
chunkL :: Int -> [a] -> [[a]]
chunkL size = unfoldr step
  where
    step [] = Nothing
    step xs = Just (splitAt size xs)
-- | Inverse of 'chunk': flattens the pieces back into one list, keeping the
-- length tag.
unchunk :: ([[a]],Int) -> ([a],Int)
unchunk (pieces, len) = ([x | piece <- pieces, x <- piece], len)
-- | Radix-4 decimation-in-time combine step. It merges the transforms of
-- the four interleaved sub-sequences into the transform of the whole
-- sequence; part @j@ holds the elements at positions @== j (mod 4)@, as
-- produced by @unshuffle' 4@.
--
-- All parts are expected to have the length @n@ of the first part; the
-- result has length @4*n@. Let @w = root_of_unity (-1) (4*n)@ (that is,
-- @exp(-2*pi*i/(4n))@). For position @k < n@ and @x_j = a_j * w^(j*k)@, the
-- four output quarters at position @k@ are:
--
-- > quarter 0: (x0 + x2) + (x1 + x3)
-- > quarter 1: (x0 - x2) - i*(x1 - x3)
-- > quarter 2: (x0 + x2) - (x1 + x3)
-- > quarter 3: (x0 - x2) + i*(x1 - x3)
--
-- Each butterfly tuple is forced with @`using` rnf@.
-- Any input that is not exactly four parts raises an 'error'.
combine4 :: (NFData a, RealFloat a) =>
            [([Complex a],Int)] -> ([Complex a],Int)
combine4 [(xs0,n),(xs1,_),(xs2,_),(xs3,_)] =
  (q0 ++ q1 ++ q2 ++ q3, 4*n)
  where
    w = root_of_unity (-1) (4*n)
    -- twiddle factors w^k, w^(2k), w^(3k) for k = 0 .. n-1
    twid1 = snd (powers n w)
    twid2 = snd (powers n (w*w))
    twid3 = snd (powers n (w*w*w))
    (q0,q1,q2,q3) =
      unzip4 (zipWith7 butterfly xs0 xs1 xs2 xs3 twid1 twid2 twid3)
    butterfly x0 a1 a2 a3 t1 t2 t3 =
      let x1     = a1*t1
          x2     = a2*t2
          x3     = a3*t3
          sum02  = x0 + x2
          sum13  = x1 + x3
          diff02 = x0 - x2
          rot13  = (0:+1) * (x1 - x3)
      in (sum02 + sum13, diff02 - rot13, sum02 - sum13, diff02 + rot13)
           `using` rnf
combine4 other =
  error ("combine4: expected 4 parts, got " ++ show (length other) ++
         " parts with lengths " ++ show (map snd other))
-- | The complex number @exp(2*pi*i*j/n)@, built as @cos theta :+ sin theta@.
-- A negative @j@ gives the conjugate root.
root_of_unity :: (Integral a, Integral b, Floating c) => a -> b -> Complex c
root_of_unity j n = cos theta :+ sin theta
  where
    theta = 2 * pi * fromIntegral j / fromIntegral n
-- | The first @count@ powers @1, w, w^2, ..@ of @w@, each obtained from the
-- previous one by one multiplication, paired with @count@. A non-positive
-- count yields no powers.
powers :: RealFloat a => Int -> Complex a -> (Int, [Complex a])
powers count w = (count, go count 1)
  where
    go k acc
      | k <= 0    = []
      | otherwise = acc : go (k - 1) (acc * w)
-- | Direct (unrolled) DFT for one to four elements, given as
-- @(elements, length)@. The length tag must equal the list length.
--
-- Uses the forward convention @X_k = sum_j x_j * exp(-2*pi*i*j*k/N)@.
-- @([],0)@ raises "Empty input!"; any other unsupported input raises an
-- 'error' that shows the input.
fftS :: (Show a, RealFloat a) => ([Complex a],Int) -> ([Complex a],Int)
fftS ([], 0) = error "Empty input!"
fftS ([x0], 1) = ([x0], 1)
fftS ([x0, x1], 2) = ([x0 + x1, x0 - x1], 2)
fftS ([x0, x1, x2], 3) = ([m0, s1 + m2, s1 - m2], 3)
  where
    i  = 0.0 :+ 1.0
    u  = 2.0 * pi / 3.0
    t1 = x1 + x2
    m0 = x0 + t1
    m1 = (cos u - 1.0) * t1
    m2 = i * sin u * (x2 - x1)
    s1 = m0 + m1
fftS ([x0, x1, x2, x3], 4) = ([m0, m2 + m3, m1, m2 - m3], 4)
  where
    i  = 0.0 :+ 1.0
    t1 = x0 + x2
    t2 = x1 + x3
    m0 = t1 + t2
    m1 = t1 - t2
    m2 = x0 - x2
    m3 = (x3 - x1) * i
fftS other = error ("fftS: unexpected input " ++ show other)