aboutsummaryrefslogtreecommitdiff
path: root/lib/Server/Ingest.hs
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Server/Ingest.hs')
-rw-r--r--lib/Server/Ingest.hs275
1 files changed, 275 insertions, 0 deletions
diff --git a/lib/Server/Ingest.hs b/lib/Server/Ingest.hs
new file mode 100644
index 0000000..959a4c6
--- /dev/null
+++ b/lib/Server/Ingest.hs
@@ -0,0 +1,275 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RecordWildCards #-}
+
+module Server.Ingest (handleTrackerRegister, handleTrainPing, handleWS) where
+import API (Metrics (..),
+ RegisterJson (..),
+ SentPing (..))
+import Control.Concurrent.STM (atomically, readTVar,
+ writeTQueue)
+import Control.Monad (forM, forever, unless,
+ void, when)
+import Control.Monad.Catch (handle)
+import Control.Monad.Extra (ifM, mapMaybeM, whenJust,
+ whenJustM)
+import Control.Monad.IO.Class (MonadIO (liftIO))
+import Control.Monad.Logger (LoggingT, logInfoN,
+ logWarnN)
+import Control.Monad.Reader (ReaderT)
+import qualified Data.Aeson as A
+import qualified Data.ByteString.Char8 as C8
+import Data.Coerce (coerce)
+import Data.Functor ((<&>))
+import qualified Data.Map as M
+import Data.Pool (Pool)
+import Data.Text (Text)
+import Data.Text.Encoding (decodeASCII, decodeUtf8)
+import Data.Time (NominalDiffTime,
+ UTCTime (..), addUTCTime,
+ diffUTCTime,
+ getCurrentTime,
+ nominalDay)
+import qualified Data.Vector as V
+import Database.Persist
+import Database.Persist.Postgresql (SqlBackend)
+import Fmt ((+|), (|+))
+import qualified GTFS
+import qualified Network.WebSockets as WS
+import Persist
+import Servant (err400, throwError)
+import Servant.Server (Handler)
+import Server.Util (ServiceM, getTzseries,
+ utcToSeconds)
+
+import Config (LoggingConfig,
+ ServerConfig (..))
+import Control.Exception (throw)
+import Control.Monad.Logger.CallStack (logErrorN)
+import Data.ByteString (ByteString)
+import Data.ByteString.Lazy (toStrict)
+import Data.Foldable (find, minimumBy)
+import Data.Function (on, (&))
+import qualified Data.Text as T
+import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries)
+import qualified Data.UUID as UUID
+import Extrapolation (Extrapolator (..),
+ LinearExtrapolator (..),
+ euclid)
+import GHC.Generics (Generic)
+import GTFS (seconds2Double)
+import Prometheus (decGauge, incGauge)
+import Server.Base (ServerState)
+
+
+handleTrackerRegister
+ :: Pool SqlBackend
+ -> RegisterJson
+ -> ServiceM Token
+handleTrackerRegister dbpool RegisterJson{..} = do
+ today <- liftIO getCurrentTime <&> utctDay
+ expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod
+ runSql dbpool $ do
+ TrackerKey tracker <- insert (Tracker expires False registerAgent Nothing)
+ pure tracker
+ where
+ validityPeriod :: NominalDiffTime
+ validityPeriod = nominalDay
+
+handleTrainPing
+ :: Pool SqlBackend
+ -> ServerState
+ -> ServerConfig
+ -> LoggingT (ReaderT LoggingConfig Handler) a
+ -> SentPing
+ -> LoggingT (ReaderT LoggingConfig Handler) (Maybe TrainAnchor)
+handleTrainPing dbpool subscribers cfg onError ping@SentPing{..} =
+ isTokenValid dbpool sentPingToken >>= \case
+ Nothing -> onError >> pure Nothing
+ Just tracker@Tracker{..} -> do
+
+ -- unless (serverConfigDebugMode cfg) $ do
+ -- now <- liftIO getCurrentTime
+ -- let timeDiff = sentPingTimestamp `diffUTCTime` now
+ -- when (utctDay sentPingTimestamp /= utctDay now) $ do
+ -- logErrorN "received ping for wrong day"
+ -- throw err400
+ -- when (timeDiff < 10) $ do
+ -- logWarnN "received ping more than 10 seconds out of date"
+ -- throw err400
+ -- when (timeDiff > 10) $ do
+ -- logWarnN "received ping from more than 10 seconds in the future"
+ -- throw err400
+
+ ticketId <- case trackerCurrentTicket of
+ Just ticketId -> pure ticketId
+ -- if the tracker is not associated with a ticket, it is probably new
+ -- & should be auto-associated with the most fitting current ticket
+ Nothing -> runSql dbpool (guessTicketFromPing cfg ping) >>= \case
+ Just ticketId -> pure ticketId
+ Nothing -> do
+ logWarnN $ "Tracker "+|UUID.toString (coerce sentPingToken)|+
+ " sent a ping, but no trips are running today."
+ throwError err400
+
+
+ runSql dbpool $ insertSentPing subscribers cfg ping tracker ticketId
+
+insertSentPing
+ :: ServerState
+ -> ServerConfig
+ -> SentPing
+ -> Tracker
+ -> TicketId
+ -> InSql (Maybe TrainAnchor)
+insertSentPing subscribers cfg ping@SentPing{..} tracker@Tracker{..} ticketId = do
+ ticket@Ticket{..} <- getJust ticketId
+
+ stations <- selectList [ StopTicket ==. ticketId ] [Asc StopArrival]
+ >>= mapM (\stop -> do
+ station <- getJust (stopStation (entityVal stop))
+ tzseries <- liftIO $ getTzseries cfg (GTFS.tzname (stopArrival (entityVal stop)))
+ pure (entityVal stop, station, tzseries))
+ <&> V.fromList
+
+ shapePoints <- selectList [ShapePointShape ==. ticketShape] [Asc ShapePointIndex]
+ <&> (V.fromList . fmap entityVal)
+
+
+ let anchor = extrapolateAnchorFromPing LinearExtrapolator
+ ticketId ticket stations shapePoints ping
+
+
+ maybeReassign <- selectFirst
+ [ TrainPingTicket ==. ticketId ]
+ [ Desc TrainPingTimestamp ]
+ <&> find (\ping -> trainPingSequence (entityVal ping) > trainAnchorSequence anchor)
+ >> guessTicketFromPing cfg ping
+ <&> find (/= ticketId)
+
+
+ -- mapM (\newTicketId -> if ticketId /= newTicketId then Just newTicketId else Nothing))
+ -- >>= (\ping -> guessTicketFromPing cfg ping >>= \case
+ -- Just newTicketId | ticketId /= newTicketId -> pure (Just newTicketId)
+ -- _ -> pure Nothing)
+
+ case maybeReassign of
+ Just newTicketId -> do
+ update sentPingToken
+ [TrackerCurrentTicket =. Just newTicketId ]
+ logInfoN $ "tracker "+|UUID.toText (coerce sentPingToken)|+
+ "has switched direction, and was reassigned to ticket "
+ +|UUID.toText (coerce newTicketId)|+"."
+ insertSentPing subscribers cfg ping tracker newTicketId
+ Nothing -> do
+ let trackedPing = TrainPing
+ { trainPingToken = sentPingToken
+ , trainPingGeopos = sentPingGeopos
+ , trainPingTimestamp = sentPingTimestamp
+ , trainPingSequence = trainAnchorSequence anchor
+ , trainPingTicket = ticketId
+ }
+
+ insert trackedPing
+
+ last <- selectFirst [TrainAnchorTicket ==. ticketId] [Desc TrainAnchorWhen]
+ -- only insert new estimates if they've actually changed anything
+ when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor))
+ $ void $ insert anchor
+
+ -- are we at the final stop? if so, mark this ticket as done
+ -- & the tracker as free
+ let maxSequence = V.last stations
+ & (\(stop, _, _) -> stopSequence stop)
+ & fromIntegral
+ when (trainAnchorSequence anchor + 0.1 >= maxSequence) $ do
+ update sentPingToken
+ [TrackerCurrentTicket =. Nothing]
+ update ticketId
+ [TicketCompleted =. True]
+ logInfoN $ "Tracker "+|UUID.toString (coerce sentPingToken)|+
+ " has completed ticket "+|UUID.toString (coerce ticketId)|+
+ " (trip "+|ticketTripName|+")"
+
+ queues <- liftIO $ atomically $ do
+ queues <- readTVar subscribers <&> M.lookup (coerce ticketId)
+ whenJust queues $
+ mapM_ (\q -> writeTQueue q (Just trackedPing))
+ pure queues
+ pure (Just anchor)
+
+handleWS
+ :: Pool SqlBackend
+ -> ServerState
+ -> ServerConfig
+ -> Metrics
+ -> WS.Connection -> ServiceM ()
+handleWS dbpool subscribers cfg Metrics{..} conn = do
+ liftIO $ WS.forkPingThread conn 30
+ incGauge metricsWSGauge
+ handle (\(e :: WS.ConnectionException) -> decGauge metricsWSGauge) $ forever $ do
+ msg <- liftIO $ WS.receiveData conn
+ case A.eitherDecode msg of
+ Left err -> do
+ logWarnN ("stray websocket message: "+|decodeASCII (toStrict msg)|+" (could not decode: "+|err|+")")
+ liftIO $ WS.sendClose conn (C8.pack err)
+ -- TODO: send a close msg (Nothing) to the subscribed queues? decGauge metricsWSGauge
+ Right ping -> do
+ -- if invalid token, send a "polite" close request. Note that the client may
+ -- ignore this and continue sending messages, which will continue to be handled.
+ handleTrainPing dbpool subscribers cfg (liftIO $ WS.sendClose conn ("" :: ByteString)) ping >>= \case
+ Just anchor -> liftIO $ WS.sendTextData conn (A.encode anchor)
+ Nothing -> pure ()
+
+
+guessTicketFromPing :: ServerConfig -> SentPing -> InSql (Maybe (Key Ticket))
+guessTicketFromPing cfg SentPing{..} = do
+ tickets <- selectList [ TicketDay ==. utctDay sentPingTimestamp, TicketCompleted ==. False ] []
+
+ ticketsWithStation <- forM tickets (\ticket@(Entity ticketId _) -> do
+ stops <- selectList [StopTicket ==. ticketId] [Asc StopSequence] >>= mapM (\(Entity _ stop) -> do
+ station <- getJust (stopStation stop)
+ tzseries <- liftIO $ getTzseries cfg (GTFS.tzname (stopDeparture stop))
+ pure (station, stop, tzseries))
+ pure (ticket, stops))
+
+ if null ticketsWithStation then pure Nothing else do
+ let (closestTicket, _) = ticketsWithStation
+ & minimumBy (compare `on` (\(Entity _ ticket, stations) ->
+ let
+ runningDay = ticketDay ticket
+ smallestDistance = stations
+ <&> (\(station, stop, tzseries) -> spaceAndTimeDiff
+ (sentPingGeopos, utcToSeconds sentPingTimestamp runningDay)
+ (stationGeopos station, GTFS.toSeconds (stopDeparture stop) tzseries runningDay))
+ & minimum
+ in smallestDistance))
+
+ logInfoN
+ $ "Tracker "+|UUID.toString (coerce sentPingToken)|+
+ " is now handling ticket "+|UUID.toString (coerce (entityKey closestTicket))|+
+ " (trip "+|ticketTripName (entityVal closestTicket)|+")."
+
+ update sentPingToken
+ [TrackerCurrentTicket =. Just (entityKey closestTicket)]
+
+ pure (Just (entityKey closestTicket))
+
+spaceAndTimeDiff :: (Geopos, GTFS.Seconds) -> (Geopos, GTFS.Seconds) -> Double
+spaceAndTimeDiff (pos1, time1) (pos2, time2) =
+ spaceDistance + abs (seconds2Double timeDiff / 3600)
+ where spaceDistance = euclid pos1 pos2
+ timeDiff = time1 - time2
+
+-- TODO: proper debug logging for expired tokens
+isTokenValid :: Pool SqlBackend -> TrackerId -> ServiceM (Maybe Tracker)
+isTokenValid dbpool token = runSql dbpool $ get token >>= \case
+ Just tracker | not (trackerBlocked tracker) -> do
+ ifM (hasExpired (trackerExpires tracker))
+ (pure Nothing)
+ (pure (Just tracker))
+ _ -> pure Nothing
+
+hasExpired :: MonadIO m => UTCTime -> m Bool
+hasExpired limit = do
+ now <- liftIO getCurrentTime
+ pure (now > limit)