aboutsummaryrefslogtreecommitdiff
path: root/lib/Server/Ingest.hs
diff options
context:
space:
mode:
authorstuebinm2024-05-08 22:42:35 +0200
committerstuebinm2024-05-08 22:43:05 +0200
commitad8a09cafa519a15a22cafbfd2fa289538edc73d (patch)
tree81f49d19669d5895115a1e8d39bd3557fc0c03d8 /lib/Server/Ingest.hs
parent0febc9cd99e0d8b80b1385593e25e7670d5c842b (diff)
restructure: split up the server module
Diffstat (limited to 'lib/Server/Ingest.hs')
-rw-r--r--lib/Server/Ingest.hs211
1 files changed, 211 insertions, 0 deletions
diff --git a/lib/Server/Ingest.hs b/lib/Server/Ingest.hs
new file mode 100644
index 0000000..d774017
--- /dev/null
+++ b/lib/Server/Ingest.hs
@@ -0,0 +1,211 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RecordWildCards #-}
+
+module Server.Ingest (handleTrackerRegister, handleTrainPing, handleWS) where
+import API (Metrics (..),
+ RegisterJson (..))
+import Control.Concurrent.STM (atomically, readTVar,
+ writeTQueue)
+import Control.Monad (forever, void, when)
+import Control.Monad.Catch (handle)
+import Control.Monad.Extra (ifM, mapMaybeM, whenJust)
+import Control.Monad.IO.Class (MonadIO (liftIO))
+import Control.Monad.Logger (LoggingT, logInfoN,
+ logWarnN)
+import Control.Monad.Reader (ReaderT)
+import qualified Data.Aeson as A
+import qualified Data.ByteString.Char8 as C8
+import Data.Coerce (coerce)
+import Data.Functor ((<&>))
+import qualified Data.Map as M
+import Data.Pool (Pool)
+import Data.Text (Text)
+import Data.Text.Encoding (decodeASCII, decodeUtf8)
+import Data.Time (NominalDiffTime,
+ UTCTime (..), addUTCTime,
+ getCurrentTime,
+ nominalDay)
+import qualified Data.Vector as V
+import Database.Persist
+import Database.Persist.Postgresql (SqlBackend)
+import Fmt ((+|), (|+))
+import qualified GTFS
+import qualified Network.WebSockets as WS
+import Persist
+import Servant (err400, throwError)
+import Servant.Server (Handler)
+import Server.Util (ServiceM, getTzseries,
+ utcToSeconds)
+
+import Config (LoggingConfig,
+ ServerConfig)
+import Data.ByteString (ByteString)
+import Data.ByteString.Lazy (toStrict)
+import Data.Foldable (minimumBy)
+import Data.Function (on, (&))
+import qualified Data.Text as T
+import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries)
+import qualified Data.UUID as UUID
+import Extrapolation (Extrapolator (..),
+ LinearExtrapolator (..),
+ euclid)
+import GTFS (seconds2Double)
+import Prometheus (decGauge, incGauge)
+import Server.Base (ServerState)
+
+
+
+handleTrackerRegister
+ :: Pool SqlBackend
+ -> RegisterJson
+ -> ServiceM Token
+handleTrackerRegister dbpool RegisterJson{..} = do
+ today <- liftIO getCurrentTime <&> utctDay
+ expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod
+ runSql dbpool $ do
+ TrackerKey tracker <- insert (Tracker expires False registerAgent Nothing)
+ pure tracker
+ where
+ validityPeriod :: NominalDiffTime
+ validityPeriod = nominalDay
+
+handleTrainPing
+ :: Pool SqlBackend
+ -> ServerState
+ -> ServerConfig
+ -> LoggingT (ReaderT LoggingConfig Handler) a
+ -> TrainPing
+ -> LoggingT (ReaderT LoggingConfig Handler) (Maybe TrainAnchor)
+handleTrainPing dbpool subscribers cfg onError ping@TrainPing{..} =
+ isTokenValid dbpool trainPingToken >>= \case
+ Nothing -> onError >> pure Nothing
+ Just tracker@Tracker{..} -> do
+
+ -- if the tracker is not associated with a ticket, it is probably
+ -- just starting out on a new trip, or has finished an old one.
+ maybeTicketId <- case trackerCurrentTicket of
+ Just ticketId -> pure (Just ticketId)
+ Nothing -> runSql dbpool $ do
+ now <- liftIO getCurrentTime
+ tickets <- selectList [ TicketDay ==. utctDay now, TicketCompleted ==. False ] []
+ ticketsWithFirstStation <- flip mapMaybeM tickets
+ (\ticket@(Entity ticketId _) -> do
+ selectFirst [StopTicket ==. ticketId] [Asc StopSequence] >>= \case
+ Nothing -> pure Nothing
+ Just (Entity _ stop) -> do
+ station <- getJust (stopStation stop)
+ tzseries <- liftIO $ getTzseries cfg (GTFS.tzname (stopDeparture stop))
+ pure (Just (ticket, station, stop, tzseries)))
+
+ if null ticketsWithFirstStation then pure Nothing else do
+ let (closestTicket, _, _, _) = minimumBy
+ -- (compare `on` euclid trainPingGeopos . stationGeopos . snd)
+ (compare `on`
+ (\(Entity _ ticket, station, stop, tzseries) ->
+ let
+ runningDay = ticketDay ticket
+ spaceDistance = euclid trainPingGeopos (stationGeopos station)
+ timeDiff =
+ GTFS.toSeconds (stopDeparture stop) tzseries runningDay
+ - utcToSeconds now runningDay
+ in
+ euclid trainPingGeopos (stationGeopos station)
+ + abs (seconds2Double timeDiff / 3600)))
+ ticketsWithFirstStation
+ logInfoN
+ $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
+ " is now handling ticket "+|UUID.toString (coerce (entityKey closestTicket))|+
+ " (trip "+|ticketTripName (entityVal closestTicket)|+")."
+
+ update trainPingToken
+ [TrackerCurrentTicket =. Just (entityKey closestTicket)]
+
+ pure (Just (entityKey closestTicket))
+
+ ticketId <- case maybeTicketId of
+ Just ticketId -> pure ticketId
+ Nothing -> do
+ logWarnN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
+ " sent a ping, but no trips are running today."
+ throwError err400
+
+ runSql dbpool $ do
+ ticket@Ticket{..} <- getJust ticketId
+
+ stations <- selectList [ StopTicket ==. ticketId ] [Asc StopArrival]
+ >>= mapM (\stop -> do
+ station <- getJust (stopStation (entityVal stop))
+ tzseries <- liftIO $ getTzseries cfg (GTFS.tzname (stopArrival (entityVal stop)))
+ pure (entityVal stop, station, tzseries))
+ <&> V.fromList
+
+ shapePoints <- selectList [ShapePointShape ==. ticketShape] [Asc ShapePointIndex]
+ <&> (V.fromList . fmap entityVal)
+
+ let anchor = extrapolateAnchorFromPing LinearExtrapolator
+ ticketId ticket stations shapePoints ping
+
+ insert ping
+
+ last <- selectFirst [TrainAnchorTicket ==. ticketId] [Desc TrainAnchorWhen]
+ -- only insert new estimates if they've actually changed anything
+ when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor))
+ $ void $ insert anchor
+
+ -- are we at the final stop? if so, mark this ticket as done
+ -- & the tracker as free
+ let maxSequence = V.last stations
+ & (\(stop, _, _) -> stopSequence stop)
+ & fromIntegral
+ when (trainAnchorSequence anchor + 0.1 >= maxSequence) $ do
+ update trainPingToken
+ [TrackerCurrentTicket =. Nothing]
+ update ticketId
+ [TicketCompleted =. True]
+ logInfoN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
+ " has completed ticket "+|UUID.toString (coerce ticketId)|+
+ " (trip "+|ticketTripName|+")"
+
+ queues <- liftIO $ atomically $ do
+ queues <- readTVar subscribers <&> M.lookup (coerce ticketId)
+ whenJust queues $
+ mapM_ (\q -> writeTQueue q (Just ping))
+ pure queues
+ pure (Just anchor)
+
+handleWS
+ :: Pool SqlBackend
+ -> ServerState
+ -> ServerConfig
+ -> Metrics
+ -> WS.Connection -> ServiceM ()
+handleWS dbpool subscribers cfg Metrics{..} conn = do
+ liftIO $ WS.forkPingThread conn 30
+ incGauge metricsWSGauge
+ handle (\(e :: WS.ConnectionException) -> decGauge metricsWSGauge) $ forever $ do
+ msg <- liftIO $ WS.receiveData conn
+ case A.eitherDecode msg of
+ Left err -> do
+ logWarnN ("stray websocket message: "+|decodeASCII (toStrict msg)|+" (could not decode: "+|err|+")")
+ liftIO $ WS.sendClose conn (C8.pack err)
+ -- TODO: send a close msg (Nothing) to the subscribed queues? decGauge metricsWSGauge
+ Right ping -> do
+ -- if invalid token, send a "polite" close request. Note that the client may
+ -- ignore this and continue sending messages, which will continue to be handled.
+ handleTrainPing dbpool subscribers cfg (liftIO $ WS.sendClose conn ("" :: ByteString)) ping >>= \case
+ Just anchor -> liftIO $ WS.sendTextData conn (A.encode anchor)
+ Nothing -> pure ()
+
+-- TODO: proper debug logging for expired tokens
+isTokenValid :: Pool SqlBackend -> TrackerId -> ServiceM (Maybe Tracker)
+isTokenValid dbpool token = runSql dbpool $ get token >>= \case
+ Just tracker | not (trackerBlocked tracker) -> do
+ ifM (hasExpired (trackerExpires tracker))
+ (pure Nothing)
+ (pure (Just tracker))
+ _ -> pure Nothing
+
+hasExpired :: MonadIO m => UTCTime -> m Bool
+hasExpired limit = do
+ now <- liftIO getCurrentTime
+ pure (now > limit)