#!/usr/bin/env nix-shell
#!nix-shell --pure -i "runghc -- -i../" -p "haskellPackages.ghcWithPackages (pkgs: with pkgs; [ ])"
import Control.Monad (guard)
import Data.List (unfoldr)
import Text.Parsec
import Text.Parsec.Char
import Text.Parsec.Combinator
import Text.Parsec.String (Parser)
import Aoc (readAndParseStdin)
-- | Read the puzzle input from stdin, then print the part 1 answer followed
-- by the part 2 answer, one per line.
main :: IO ()
main =
  readAndParseStdin parseInput >>= \histories ->
    mapM_ print [part1 histories, part2 histories]
-- | Sum over all histories of the value extrapolated one step past the end,
-- each rounded to the nearest integer before summing.
part1 :: [[Int]] -> Int
part1 histories = sum [round (interpolateNext history) | history <- histories]
-- | Sum over all histories of the value extrapolated one step before the
-- start, each rounded to the nearest integer before summing.
part2 :: [[Int]] -> Int
part2 histories = sum [round (interpolatePrevious history) | history <- histories]
-- | Extrapolate the sequence one step forward: evaluate the fitted polynomial
-- at index @length values@ (the first element sits at index 0).
interpolateNext :: [Int] -> Double
interpolateNext values = interpolatePolynomial nextIndex values
  where
    nextIndex = length values
-- | Extrapolate the sequence one step backward: evaluate the fitted
-- polynomial at index -1 (the first element sits at index 0).
interpolatePrevious :: [Int] -> Double
interpolatePrevious values = interpolatePolynomial previousIndex values
  where
    previousIndex = negate 1
-- | Evaluate, at integer index @nth@, the polynomial fitted through the
-- points @(0, y0), (1, y1), ...@ given by @values@.
--
-- Uses Newton's forward-difference series
--
-- > f(x) = sum_k (Δ^k y0) * C(x, k)
--
-- where @Δ^k y0@ is the first entry of row @k@ of the difference table and
-- @C(x, k) = x (x-1) ... (x-k+1) / k!@ is the generalised binomial
-- coefficient, which is an integer for every integer @x@ (including
-- negative ones).
--
-- The whole sum is computed exactly in 'Integer'. Dividing by @k!@ in
-- floating point instead would let rounding error build up on long
-- sequences. Only the final result is converted to 'Double'.
--
-- An empty input has no difference-table entries, so the (empty) sum is 0.
interpolatePolynomial :: Int -> [Int] -> Double
interpolatePolynomial nth values =
  fromIntegral . sum $ zipWith (*) leading (binomials (toInteger nth))
  where
    -- First entry of every non-empty row of the difference table:
    -- Δ^0 y0, Δ^1 y0, Δ^2 y0, ...
    leading :: [Integer]
    leading = [toInteger x | (x : _) <- buildDifferenceTable values]

    -- Infinite list C(x, 0), C(x, 1), C(x, 2), ...  Each `div` is exact
    -- because C(x, k) * (x - k) == C(x, k + 1) * (k + 1).
    binomials :: Integer -> [Integer]
    binomials x = scanl (\c k -> c * (x - k) `div` (k + 1)) 1 [0 ..]
-- | Given a difference table (row @k@ holds the @k@-th differences), return
-- @[Δ^k y0 / k!]@: the first entry of each row divided by @k!@.
--
-- Only the leading run of non-empty rows contributes. An empty row has no
-- first entry, so the table is treated as ending there; for example, the
-- table @[[]]@ yields @[]@ instead of an element that crashes on 'head'.
dividedDifference :: [[Int]] -> [Double]
dividedDifference table =
  zipWith (\x f -> fromIntegral x / fromIntegral f) leading factorials
  where
    -- First element of each row, up to the first empty row.
    leading = [x | (x : _) <- takeWhile (not . null) table]

    -- 0!, 1!, 2!, ... in 'Integer' so large factorials cannot overflow.
    factorials :: [Integer]
    factorials = scanl (*) 1 [1 ..]
-- | Build the finite-difference table of a sequence.
--
-- The first row is the input itself (always present, even when empty). Each
-- following row holds the pairwise differences @b - a@ of adjacent elements
-- of the row above. Rows are generated until a difference row comes out
-- empty; that empty row is not included.
buildDifferenceTable :: [Int] -> [[Int]]
buildDifferenceTable input =
  input : takeWhile (not . null) (iterate differences (differences input))
  where
    -- Adjacent differences: [x1 - x0, x2 - x1, ...].
    differences xs = zipWith (-) (drop 1 xs) xs
-- | Parse newline-separated sequences of integers.
--
-- 'parseSequence' succeeds on zero numbers without consuming input. A
-- trailing newline, or a blank line, would therefore otherwise produce an
-- empty sequence in the result. Such empty sequences are dropped, because an
-- empty history has nothing to extrapolate from.
parseInput :: Parser [[Int]]
parseInput = filter (not . null) <$> (parseSequence `sepBy` char '\n')
-- | Parse one line of single-space-separated decimal integers, each with an
-- optional leading @-@, e.g. @0 3 -6@.
--
-- The sign and the digits are parsed structurally, so a malformed token such
-- as @1-2@ or a lone @-@ is rejected by the parser. The previous approach
-- accepted any run of digits and dashes and then let 'read' crash at
-- evaluation time. Zero numbers are accepted, and in that case no input is
-- consumed.
parseSequence :: Parser [Int]
parseSequence = signedInt `sepBy` char ' '
  where
    -- 'read' is safe here: it only ever sees a non-empty run of digits.
    signedInt = option id (negate <$ char '-') <*> (read <$> many1 digit)