diff options
Diffstat (limited to 'lib/Server/Ingest.hs')
-rw-r--r-- | lib/Server/Ingest.hs | 275 |
1 files changed, 275 insertions, 0 deletions
diff --git a/lib/Server/Ingest.hs b/lib/Server/Ingest.hs new file mode 100644 index 0000000..959a4c6 --- /dev/null +++ b/lib/Server/Ingest.hs @@ -0,0 +1,275 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RecordWildCards #-} + +module Server.Ingest (handleTrackerRegister, handleTrainPing, handleWS) where +import API (Metrics (..), + RegisterJson (..), + SentPing (..)) +import Control.Concurrent.STM (atomically, readTVar, + writeTQueue) +import Control.Monad (forM, forever, unless, + void, when) +import Control.Monad.Catch (handle) +import Control.Monad.Extra (ifM, mapMaybeM, whenJust, + whenJustM) +import Control.Monad.IO.Class (MonadIO (liftIO)) +import Control.Monad.Logger (LoggingT, logInfoN, + logWarnN) +import Control.Monad.Reader (ReaderT) +import qualified Data.Aeson as A +import qualified Data.ByteString.Char8 as C8 +import Data.Coerce (coerce) +import Data.Functor ((<&>)) +import qualified Data.Map as M +import Data.Pool (Pool) +import Data.Text (Text) +import Data.Text.Encoding (decodeASCII, decodeUtf8) +import Data.Time (NominalDiffTime, + UTCTime (..), addUTCTime, + diffUTCTime, + getCurrentTime, + nominalDay) +import qualified Data.Vector as V +import Database.Persist +import Database.Persist.Postgresql (SqlBackend) +import Fmt ((+|), (|+)) +import qualified GTFS +import qualified Network.WebSockets as WS +import Persist +import Servant (err400, throwError) +import Servant.Server (Handler) +import Server.Util (ServiceM, getTzseries, + utcToSeconds) + +import Config (LoggingConfig, + ServerConfig (..)) +import Control.Exception (throw) +import Control.Monad.Logger.CallStack (logErrorN) +import Data.ByteString (ByteString) +import Data.ByteString.Lazy (toStrict) +import Data.Foldable (find, minimumBy) +import Data.Function (on, (&)) +import qualified Data.Text as T +import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries) +import qualified Data.UUID as UUID +import Extrapolation (Extrapolator (..), + LinearExtrapolator (..), + euclid) +import GHC.Generics (Generic) +import GTFS (seconds2Double) +import Prometheus (decGauge, incGauge) +import Server.Base (ServerState) + + +handleTrackerRegister + :: Pool SqlBackend + -> RegisterJson + -> ServiceM Token +handleTrackerRegister dbpool RegisterJson{..} = do + today <- liftIO getCurrentTime <&> utctDay + expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod + runSql dbpool $ do + TrackerKey tracker <- insert (Tracker expires False registerAgent Nothing) + pure tracker + where + validityPeriod :: NominalDiffTime + validityPeriod = nominalDay + +handleTrainPing + :: Pool SqlBackend + -> ServerState + -> ServerConfig + -> LoggingT (ReaderT LoggingConfig Handler) a + -> SentPing + -> LoggingT (ReaderT LoggingConfig Handler) (Maybe TrainAnchor) +handleTrainPing dbpool subscribers cfg onError ping@SentPing{..} = + isTokenValid dbpool sentPingToken >>= \case + Nothing -> onError >> pure Nothing + Just tracker@Tracker{..} -> do + + -- unless (serverConfigDebugMode cfg) $ do + -- now <- liftIO getCurrentTime + -- let timeDiff = sentPingTimestamp `diffUTCTime` now + -- when (utctDay sentPingTimestamp /= utctDay now) $ do + -- logErrorN "received ping for wrong day" + -- throw err400 + -- when (timeDiff < 10) $ do + -- logWarnN "received ping more than 10 seconds out of date" + -- throw err400 + -- when (timeDiff > 10) $ do + -- logWarnN "received ping from more than 10 seconds in the future" + -- throw err400 + + ticketId <- case trackerCurrentTicket of + Just ticketId -> pure ticketId + -- if the tracker is not associated with a ticket, it is probably new + -- & should be auto-associated with the most fitting current ticket + Nothing -> runSql dbpool (guessTicketFromPing cfg ping) >>= \case + Just ticketId -> pure ticketId + Nothing -> do + logWarnN $ "Tracker "+|UUID.toString (coerce sentPingToken)|+ + " sent a ping, but no trips are running today." + throwError err400 + + + runSql dbpool $ insertSentPing subscribers cfg ping tracker ticketId + +insertSentPing + :: ServerState + -> ServerConfig + -> SentPing + -> Tracker + -> TicketId + -> InSql (Maybe TrainAnchor) +insertSentPing subscribers cfg ping@SentPing{..} tracker@Tracker{..} ticketId = do + ticket@Ticket{..} <- getJust ticketId + + stations <- selectList [ StopTicket ==. ticketId ] [Asc StopArrival] + >>= mapM (\stop -> do + station <- getJust (stopStation (entityVal stop)) + tzseries <- liftIO $ getTzseries cfg (GTFS.tzname (stopArrival (entityVal stop))) + pure (entityVal stop, station, tzseries)) + <&> V.fromList + + shapePoints <- selectList [ShapePointShape ==. ticketShape] [Asc ShapePointIndex] + <&> (V.fromList . fmap entityVal) + + + let anchor = extrapolateAnchorFromPing LinearExtrapolator + ticketId ticket stations shapePoints ping + + + maybeReassign <- selectFirst + [ TrainPingTicket ==. ticketId ] + [ Desc TrainPingTimestamp ] + <&> find (\ping -> trainPingSequence (entityVal ping) > trainAnchorSequence anchor) + >> guessTicketFromPing cfg ping + <&> find (/= ticketId) + + + -- mapM (\newTicketId -> if ticketId /= newTicketId then Just newTicketId else Nothing)) + -- >>= (\ping -> guessTicketFromPing cfg ping >>= \case + -- Just newTicketId | ticketId /= newTicketId -> pure (Just newTicketId) + -- _ -> pure Nothing) + + case maybeReassign of + Just newTicketId -> do + update sentPingToken + [TrackerCurrentTicket =. Just newTicketId ] + logInfoN $ "tracker "+|UUID.toText (coerce sentPingToken)|+ + "has switched direction, and was reassigned to ticket " + +|UUID.toText (coerce newTicketId)|+"." + insertSentPing subscribers cfg ping tracker newTicketId + Nothing -> do + let trackedPing = TrainPing + { trainPingToken = sentPingToken + , trainPingGeopos = sentPingGeopos + , trainPingTimestamp = sentPingTimestamp + , trainPingSequence = trainAnchorSequence anchor + , trainPingTicket = ticketId + } + + insert trackedPing + + last <- selectFirst [TrainAnchorTicket ==. ticketId] [Desc TrainAnchorWhen] + -- only insert new estimates if they've actually changed anything + when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor)) + $ void $ insert anchor + + -- are we at the final stop? if so, mark this ticket as done + -- & the tracker as free + let maxSequence = V.last stations + & (\(stop, _, _) -> stopSequence stop) + & fromIntegral + when (trainAnchorSequence anchor + 0.1 >= maxSequence) $ do + update sentPingToken + [TrackerCurrentTicket =. Nothing] + update ticketId + [TicketCompleted =. True] + logInfoN $ "Tracker "+|UUID.toString (coerce sentPingToken)|+ + " has completed ticket "+|UUID.toString (coerce ticketId)|+ + " (trip "+|ticketTripName|+")" + + queues <- liftIO $ atomically $ do + queues <- readTVar subscribers <&> M.lookup (coerce ticketId) + whenJust queues $ + mapM_ (\q -> writeTQueue q (Just trackedPing)) + pure queues + pure (Just anchor) + +handleWS + :: Pool SqlBackend + -> ServerState + -> ServerConfig + -> Metrics + -> WS.Connection -> ServiceM () +handleWS dbpool subscribers cfg Metrics{..} conn = do + liftIO $ WS.forkPingThread conn 30 + incGauge metricsWSGauge + handle (\(e :: WS.ConnectionException) -> decGauge metricsWSGauge) $ forever $ do + msg <- liftIO $ WS.receiveData conn + case A.eitherDecode msg of + Left err -> do + logWarnN ("stray websocket message: "+|decodeASCII (toStrict msg)|+" (could not decode: "+|err|+")") + liftIO $ WS.sendClose conn (C8.pack err) + -- TODO: send a close msg (Nothing) to the subscribed queues? decGauge metricsWSGauge + Right ping -> do + -- if invalid token, send a "polite" close request. Note that the client may + -- ignore this and continue sending messages, which will continue to be handled. + handleTrainPing dbpool subscribers cfg (liftIO $ WS.sendClose conn ("" :: ByteString)) ping >>= \case + Just anchor -> liftIO $ WS.sendTextData conn (A.encode anchor) + Nothing -> pure () + + +guessTicketFromPing :: ServerConfig -> SentPing -> InSql (Maybe (Key Ticket)) +guessTicketFromPing cfg SentPing{..} = do + tickets <- selectList [ TicketDay ==. utctDay sentPingTimestamp, TicketCompleted ==. False ] [] + + ticketsWithStation <- forM tickets (\ticket@(Entity ticketId _) -> do + stops <- selectList [StopTicket ==. ticketId] [Asc StopSequence] >>= mapM (\(Entity _ stop) -> do + station <- getJust (stopStation stop) + tzseries <- liftIO $ getTzseries cfg (GTFS.tzname (stopDeparture stop)) + pure (station, stop, tzseries)) + pure (ticket, stops)) + + if null ticketsWithStation then pure Nothing else do + let (closestTicket, _) = ticketsWithStation + & minimumBy (compare `on` (\(Entity _ ticket, stations) -> + let + runningDay = ticketDay ticket + smallestDistance = stations + <&> (\(station, stop, tzseries) -> spaceAndTimeDiff + (sentPingGeopos, utcToSeconds sentPingTimestamp runningDay) + (stationGeopos station, GTFS.toSeconds (stopDeparture stop) tzseries runningDay)) + & minimum + in smallestDistance)) + + logInfoN + $ "Tracker "+|UUID.toString (coerce sentPingToken)|+ + " is now handling ticket "+|UUID.toString (coerce (entityKey closestTicket))|+ + " (trip "+|ticketTripName (entityVal closestTicket)|+")." + + update sentPingToken + [TrackerCurrentTicket =. Just (entityKey closestTicket)] + + pure (Just (entityKey closestTicket)) + +spaceAndTimeDiff :: (Geopos, GTFS.Seconds) -> (Geopos, GTFS.Seconds) -> Double +spaceAndTimeDiff (pos1, time1) (pos2, time2) = + spaceDistance + abs (seconds2Double timeDiff / 3600) + where spaceDistance = euclid pos1 pos2 + timeDiff = time1 - time2 + +-- TODO: proper debug logging for expired tokens +isTokenValid :: Pool SqlBackend -> TrackerId -> ServiceM (Maybe Tracker) +isTokenValid dbpool token = runSql dbpool $ get token >>= \case + Just tracker | not (trackerBlocked tracker) -> do + ifM (hasExpired (trackerExpires tracker)) + (pure Nothing) + (pure (Just tracker)) + _ -> pure Nothing + +hasExpired :: MonadIO m => UTCTime -> m Bool +hasExpired limit = do + now <- liftIO getCurrentTime + pure (now > limit) |