module Diagrams.TwoD.Segment.Bernstein
  ( BernsteinPoly (..)
  , listToBernstein
  , evaluateBernstein
  , degreeElevate
  , bernsteinDeriv
  , evaluateBernsteinDerivs
  ) where
import           Data.List           (tails)
import           Diagrams.Core.V
import           Diagrams.Parametric
import           Linear.V1
-- | Binomial coefficients @[C(n,0), C(n,1), ..., C(n,n)]@, converted to any
-- 'Num' type. A negative @n@ yields @[1]@.
--
-- The running product is kept in 'Integer': the intermediate value
-- @C(n,k-1) * (n-k+1)@ overflows 'Int' for moderately large @n@ (e.g. 64).
binomials :: Num n => Int -> [n]
binomials n = map fromInteger $ scanl step 1 [1 .. n']
  where
    n' = toInteger n
    -- C(n,k) = C(n,k-1) * (n-k+1) / k; the division is always exact.
    step c k = c * (n' - k + 1) `quot` k
-- | A univariate polynomial in Bernstein form.
--
-- 'evaluateBernstein' gives it the value
-- @sum_k C(d,k) * t^k * (1-t)^(d-k) * c_k@ at parameter @t@, where @d@ is
-- 'bernsteinDegree' and @c_k@ are the 'bernsteinCoeffs'.
data BernsteinPoly n = BernsteinPoly
  { bernsteinDegree :: Int
    -- ^ Degree @d@. Expected to equal one less than the number of
    -- coefficients (as built by 'listToBernstein').
    -- NOTE(review): this is not enforced by the type.
  , bernsteinCoeffs :: [n]
    -- ^ Coefficients (control values), lowest index first.
  } deriving (Show, Functor)
-- A Bernstein polynomial is treated as a scalar-valued (one-dimensional)
-- parametric object. Its vector space and codomain are both 'V1', and its
-- scalar type is @n@.
type instance V        (BernsteinPoly n) = V1
type instance N        (BernsteinPoly n) = n
type instance Codomain (BernsteinPoly n) = V1
-- | Build a Bernstein polynomial from its coefficients, lowest index first.
--
-- A list of @k@ coefficients gives a polynomial of degree @k - 1@. The empty
-- list gives the constant zero polynomial.
listToBernstein :: Fractional n => [n] -> BernsteinPoly n
listToBernstein [] = 0
listToBernstein l  = BernsteinPoly (length l - 1) l
-- | @degreeElevate b k@ re-expresses @b@ as a Bernstein polynomial of degree
-- @bernsteinDegree b + k@. Only the basis changes, not the represented
-- polynomial.
--
-- Each single step from degree @d@ to @d + 1@ uses:
--
-- * @c'_0 = c_0@
-- * @c'_i = i/(d+1) * c_(i-1) + (1 - i/(d+1)) * c_i@
-- * @c'_(d+1) = c_d@
--
-- A non-positive @k@ returns @b@ unchanged; a negative count used to recurse
-- forever. An empty coefficient list, which evaluates to 0, is elevated to an
-- all-zero coefficient list of the target length.
degreeElevate :: Fractional n => BernsteinPoly n -> Int -> BernsteinPoly n
degreeElevate b times
  | times <= 0 = b
degreeElevate (BernsteinPoly lp []) times =
  BernsteinPoly (lp + times) (replicate (lp + times + 1) 0)
degreeElevate (BernsteinPoly lp p@(p0:_)) times =
  degreeElevate (BernsteinPoly (lp + 1) (p0 : inner p 1)) (times - 1)
  where
    n = fromIntegral lp
    -- Blends each adjacent pair (c_(i-1), c_i); the final coefficient is
    -- carried over unchanged.
    inner []         _ = [0]
    inner [a]        _ = [a]
    inner (a:b:rest) i = (i*a/(n+1) + b*(1 - i/(n+1))) : inner (b:rest) (i+1)
-- | Evaluate a Bernstein polynomial at parameter @t@:
-- @sum_k C(d,k) * t^k * (1-t)^(d-k) * c_k@, where @d = 'bernsteinDegree'@.
--
-- The evaluation makes a single Horner-style pass over the coefficients:
--
-- * the partial sum is multiplied by @(1-t)@ at each step;
-- * @t^k@ and @C(d,k)@ are updated incrementally.
--
-- An empty coefficient list evaluates to 0.
evaluateBernstein :: Fractional n => BernsteinPoly n -> n -> n
evaluateBernstein (BernsteinPoly _  [])          _ = 0
evaluateBernstein (BernsteinPoly _  [b])         _ = b
evaluateBernstein (BernsteinPoly lp (b0:b1:bs)) t = go t n (b0*u) 2 b1 bs
  where
    u = 1 - t
    n = fromIntegral lp
    -- Invariant for @go tk ck acc i bk rest@, where bk is coefficient k:
    --   tk = t^k, ck = C(d,k), i = k + 1,
    --   acc = sum_{j<k} C(d,j) * t^j * (1-t)^(k-j) * c_j.
    -- Carrying the current coefficient separately makes 'go' total.
    go tk ck acc _ bk []          = acc + tk*ck*bk
    go tk ck acc i bk (bk':rest) =
      go (tk*t) (ck*(n - i + 1)/i) ((acc + tk*ck*bk)*u) (i+1) bk' rest
-- | Values at @t@ of the polynomial and its successive derivatives:
-- @[b(t), b'(t), b''(t), ...]@.
--
-- The list stops at the first polynomial in the chain whose degree is 0. A
-- degree-0 input therefore gives just @[b(t)]@.
evaluateBernsteinDerivs :: Fractional n => BernsteinPoly n -> n -> [n]
evaluateBernsteinDerivs b t = [ evaluateBernstein p t | p <- upToConstant ]
  where
    (nonConstant, rest) = break ((== 0) . bernsteinDegree) (iterate bernsteinDeriv b)
    upToConstant        = nonConstant ++ take 1 rest
-- | Derivative with respect to the parameter.
--
-- For degree @d > 0@ the result has degree @d - 1@ and coefficients
-- @d * (c_(k+1) - c_k)@. A degree-0 (constant) polynomial differentiates to
-- the zero polynomial.
bernsteinDeriv :: Fractional n => BernsteinPoly n -> BernsteinPoly n
bernsteinDeriv (BernsteinPoly 0 _)  = 0
bernsteinDeriv (BernsteinPoly lp p) =
  -- 'drop 1' instead of 'tail': an (ill-formed) empty coefficient list then
  -- yields no coefficients instead of crashing.
  BernsteinPoly (lp - 1) $ zipWith (\a b -> (a - b) * fromIntegral lp) (drop 1 p) p
-- | Sampling a Bernstein polynomial at a parameter wraps its scalar value in
-- 'V1'.
instance Fractional n => Parametric (BernsteinPoly n) where
  atParam b t = V1 (evaluateBernstein b t)
-- Neither instance defines any methods, so every method comes from the class
-- defaults.
-- NOTE(review): presumably this means the parameter domain [0, 1] and end
-- values obtained via 'atParam' at those bounds — confirm against the class
-- definitions.
instance Num n        => DomainBounds (BernsteinPoly n)
instance Fractional n => EndValues    (BernsteinPoly n)
-- | Splitting delegates to 'bernsteinSplit'. Reversing the parameter
-- direction reverses the coefficient order and keeps the degree.
instance Fractional n => Sectionable (BernsteinPoly n) where
  splitAtParam b t = bernsteinSplit b t
  reverseDomain b  = b { bernsteinCoeffs = reverse (bernsteinCoeffs b) }
-- | Split a Bernstein polynomial at parameter @t@ using de Casteljau's
-- algorithm.
--
-- Neighbouring control values are repeatedly interpolated at @t@, producing
-- rows that each shrink by one element. The pieces are read off those rows:
--
-- * the first entry of every row gives the piece before @t@;
-- * the last entries, read in reverse, give the piece after @t@.
--
-- Both pieces keep the original degree. An empty coefficient list is
-- returned unchanged on both sides instead of crashing.
bernsteinSplit :: Num n => BernsteinPoly n -> n -> (BernsteinPoly n, BernsteinPoly n)
bernsteinSplit b@(BernsteinPoly _ []) _ = (b, b)
bernsteinSplit (BernsteinPoly lp p) t =
  (BernsteinPoly lp $ map head controls,
   BernsteinPoly lp $ reverse $ map last controls)
  where
    interp a c = (1 - t)*a + t*c
    -- Stop at the single-element row. Every row in 'controls' is non-empty
    -- (p is non-empty here), so 'head' and 'last' above are safe.
    terp l@(_:rest@(_:_)) = let ctrs = zipWith interp l rest
                            in  ctrs : terp ctrs
    terp _                = []
    controls = p : terp p
-- | Combine two Bernstein polynomials coefficient-wise with @f@.
--
-- The lower-degree operand is first degree-elevated so both share the same
-- basis. Degree elevation preserves the represented polynomial, so this is
-- correct for operations that are linear in the coefficients, such as '(+)'
-- and '(-)'.
zipBernsteinWith
  :: Fractional n
  => (n -> n -> n) -> BernsteinPoly n -> BernsteinPoly n -> BernsteinPoly n
zipBernsteinWith f ba@(BernsteinPoly la a) bb@(BernsteinPoly lb b)
  | la < lb   = BernsteinPoly lb $ zipWith f (bernsteinCoeffs $ degreeElevate ba (lb - la)) b
  | la > lb   = BernsteinPoly la $ zipWith f a (bernsteinCoeffs $ degreeElevate bb (la - lb))
  | otherwise = BernsteinPoly la $ zipWith f a b

-- | Arithmetic on Bernstein polynomials.
--
-- * Sums and differences degree-elevate the lower-degree operand.
-- * Products have degree @la + lb@.
-- * 'abs' acts coefficient-wise.
-- * 'signum' is a degree-0 polynomial holding the sign of the first
--   coefficient.
instance Fractional n => Num (BernsteinPoly n) where
  (+) = zipBernsteinWith (+)
  (-) = zipBernsteinWith (-)
  -- Cheaper than the class default @0 - x@, which degree-elevates a zero
  -- polynomial first.
  negate = fmap negate
  -- Product in Bernstein form:
  --   c_k = sum_{i+j=k} C(la,i) C(lb,j) a_i b_j / C(la+lb,k)
  -- Scale by binomials, convolve, then divide by the binomials of the
  -- combined degree. The convolution is built in two halves:
  -- * the first half gives k = 0..lb;
  -- * the second gives k = lb+1..la+lb, plus a trailing 0 from the empty
  --   suffix of a', which 'init' drops.
  (BernsteinPoly la a) * (BernsteinPoly lb b) =
    BernsteinPoly (la + lb) $
      zipWith (flip (/)) (binomials (la + lb)) $
        init $ map sum $
          map (zipWith (*) a') (down b') ++
          map (zipWith (*) (reverse b')) (tail $ tails a')
    where
      -- down [x0, x1, x2] == [[x0], [x1, x0], [x2, x1, x0]]
      down l = tail $ scanl (flip (:)) [] l
      a' = zipWith (*) a (binomials la)
      b' = zipWith (*) b (binomials lb)
  fromInteger a = BernsteinPoly 0 [fromInteger a]
  signum (BernsteinPoly _ [])    = 0
  signum (BernsteinPoly _ (a:_)) = BernsteinPoly 0 [signum a]
  abs = fmap abs