From ad8a09cafa519a15a22cafbfd2fa289538edc73d Mon Sep 17 00:00:00 2001 From: stuebinm Date: Wed, 8 May 2024 22:42:35 +0200 Subject: restructure: split up the server module --- lib/Server/Ingest.hs | 211 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 211 insertions(+) create mode 100644 lib/Server/Ingest.hs (limited to 'lib/Server/Ingest.hs') diff --git a/lib/Server/Ingest.hs b/lib/Server/Ingest.hs new file mode 100644 index 0000000..d774017 --- /dev/null +++ b/lib/Server/Ingest.hs @@ -0,0 +1,211 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RecordWildCards #-} + +module Server.Ingest (handleTrackerRegister, handleTrainPing, handleWS) where +import API (Metrics (..), + RegisterJson (..)) +import Control.Concurrent.STM (atomically, readTVar, + writeTQueue) +import Control.Monad (forever, void, when) +import Control.Monad.Catch (handle) +import Control.Monad.Extra (ifM, mapMaybeM, whenJust) +import Control.Monad.IO.Class (MonadIO (liftIO)) +import Control.Monad.Logger (LoggingT, logInfoN, + logWarnN) +import Control.Monad.Reader (ReaderT) +import qualified Data.Aeson as A +import qualified Data.ByteString.Char8 as C8 +import Data.Coerce (coerce) +import Data.Functor ((<&>)) +import qualified Data.Map as M +import Data.Pool (Pool) +import Data.Text (Text) +import Data.Text.Encoding (decodeASCII, decodeUtf8) +import Data.Time (NominalDiffTime, + UTCTime (..), addUTCTime, + getCurrentTime, + nominalDay) +import qualified Data.Vector as V +import Database.Persist +import Database.Persist.Postgresql (SqlBackend) +import Fmt ((+|), (|+)) +import qualified GTFS +import qualified Network.WebSockets as WS +import Persist +import Servant (err400, throwError) +import Servant.Server (Handler) +import Server.Util (ServiceM, getTzseries, + utcToSeconds) + +import Config (LoggingConfig, + ServerConfig) +import Data.ByteString (ByteString) +import Data.ByteString.Lazy (toStrict) +import Data.Foldable (minimumBy) +import Data.Function (on, (&)) +import qualified Data.Text as T +import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries) +import qualified Data.UUID as UUID +import Extrapolation (Extrapolator (..), + LinearExtrapolator (..), + euclid) +import GTFS (seconds2Double) +import Prometheus (decGauge, incGauge) +import Server.Base (ServerState) + + + +handleTrackerRegister + :: Pool SqlBackend + -> RegisterJson + -> ServiceM Token +handleTrackerRegister dbpool RegisterJson{..} = do + today <- liftIO getCurrentTime <&> utctDay + expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod + runSql dbpool $ do + TrackerKey tracker <- insert (Tracker expires False registerAgent Nothing) + pure tracker + where + validityPeriod :: NominalDiffTime + validityPeriod = nominalDay + +handleTrainPing + :: Pool SqlBackend + -> ServerState + -> ServerConfig + -> LoggingT (ReaderT LoggingConfig Handler) a + -> TrainPing + -> LoggingT (ReaderT LoggingConfig Handler) (Maybe TrainAnchor) +handleTrainPing dbpool subscribers cfg onError ping@TrainPing{..} = + isTokenValid dbpool trainPingToken >>= \case + Nothing -> onError >> pure Nothing + Just tracker@Tracker{..} -> do + + -- if the tracker is not associated with a ticket, it is probably + -- just starting out on a new trip, or has finished an old one. + maybeTicketId <- case trackerCurrentTicket of + Just ticketId -> pure (Just ticketId) + Nothing -> runSql dbpool $ do + now <- liftIO getCurrentTime + tickets <- selectList [ TicketDay ==. utctDay now, TicketCompleted ==. False ] [] + ticketsWithFirstStation <- flip mapMaybeM tickets + (\ticket@(Entity ticketId _) -> do + selectFirst [StopTicket ==. ticketId] [Asc StopSequence] >>= \case + Nothing -> pure Nothing + Just (Entity _ stop) -> do + station <- getJust (stopStation stop) + tzseries <- liftIO $ getTzseries cfg (GTFS.tzname (stopDeparture stop)) + pure (Just (ticket, station, stop, tzseries))) + + if null ticketsWithFirstStation then pure Nothing else do + let (closestTicket, _, _, _) = minimumBy + -- (compare `on` euclid trainPingGeopos . stationGeopos . snd) + (compare `on` + (\(Entity _ ticket, station, stop, tzseries) -> + let + runningDay = ticketDay ticket + spaceDistance = euclid trainPingGeopos (stationGeopos station) + timeDiff = + GTFS.toSeconds (stopDeparture stop) tzseries runningDay + - utcToSeconds now runningDay + in + euclid trainPingGeopos (stationGeopos station) + + abs (seconds2Double timeDiff / 3600))) + ticketsWithFirstStation + logInfoN + $ "Tracker "+|UUID.toString (coerce trainPingToken)|+ + " is now handling ticket "+|UUID.toString (coerce (entityKey closestTicket))|+ + " (trip "+|ticketTripName (entityVal closestTicket)|+")." + + update trainPingToken + [TrackerCurrentTicket =. Just (entityKey closestTicket)] + + pure (Just (entityKey closestTicket)) + + ticketId <- case maybeTicketId of + Just ticketId -> pure ticketId + Nothing -> do + logWarnN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+ + " sent a ping, but no trips are running today." + throwError err400 + + runSql dbpool $ do + ticket@Ticket{..} <- getJust ticketId + + stations <- selectList [ StopTicket ==. ticketId ] [Asc StopArrival] + >>= mapM (\stop -> do + station <- getJust (stopStation (entityVal stop)) + tzseries <- liftIO $ getTzseries cfg (GTFS.tzname (stopArrival (entityVal stop))) + pure (entityVal stop, station, tzseries)) + <&> V.fromList + + shapePoints <- selectList [ShapePointShape ==. ticketShape] [Asc ShapePointIndex] + <&> (V.fromList . fmap entityVal) + + let anchor = extrapolateAnchorFromPing LinearExtrapolator + ticketId ticket stations shapePoints ping + + insert ping + + last <- selectFirst [TrainAnchorTicket ==. ticketId] [Desc TrainAnchorWhen] + -- only insert new estimates if they've actually changed anything + when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor)) + $ void $ insert anchor + + -- are we at the final stop? if so, mark this ticket as done + -- & the tracker as free + let maxSequence = V.last stations + & (\(stop, _, _) -> stopSequence stop) + & fromIntegral + when (trainAnchorSequence anchor + 0.1 >= maxSequence) $ do + update trainPingToken + [TrackerCurrentTicket =. Nothing] + update ticketId + [TicketCompleted =. True] + logInfoN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+ + " has completed ticket "+|UUID.toString (coerce ticketId)|+ + " (trip "+|ticketTripName|+")" + + queues <- liftIO $ atomically $ do + queues <- readTVar subscribers <&> M.lookup (coerce ticketId) + whenJust queues $ + mapM_ (\q -> writeTQueue q (Just ping)) + pure queues + pure (Just anchor) + +handleWS + :: Pool SqlBackend + -> ServerState + -> ServerConfig + -> Metrics + -> WS.Connection -> ServiceM () +handleWS dbpool subscribers cfg Metrics{..} conn = do + liftIO $ WS.forkPingThread conn 30 + incGauge metricsWSGauge + handle (\(e :: WS.ConnectionException) -> decGauge metricsWSGauge) $ forever $ do + msg <- liftIO $ WS.receiveData conn + case A.eitherDecode msg of + Left err -> do + logWarnN ("stray websocket message: "+|decodeASCII (toStrict msg)|+" (could not decode: "+|err|+")") + liftIO $ WS.sendClose conn (C8.pack err) + -- TODO: send a close msg (Nothing) to the subscribed queues? decGauge metricsWSGauge + Right ping -> do + -- if invalid token, send a "polite" close request. Note that the client may + -- ignore this and continue sending messages, which will continue to be handled. + handleTrainPing dbpool subscribers cfg (liftIO $ WS.sendClose conn ("" :: ByteString)) ping >>= \case + Just anchor -> liftIO $ WS.sendTextData conn (A.encode anchor) + Nothing -> pure () + +-- TODO: proper debug logging for expired tokens +isTokenValid :: Pool SqlBackend -> TrackerId -> ServiceM (Maybe Tracker) +isTokenValid dbpool token = runSql dbpool $ get token >>= \case + Just tracker | not (trackerBlocked tracker) -> do + ifM (hasExpired (trackerExpires tracker)) + (pure Nothing) + (pure (Just tracker)) + _ -> pure Nothing + +hasExpired :: MonadIO m => UTCTime -> m Bool +hasExpired limit = do + now <- liftIO getCurrentTime + pure (now > limit) -- cgit v1.2.3