From ad8a09cafa519a15a22cafbfd2fa289538edc73d Mon Sep 17 00:00:00 2001 From: stuebinm Date: Wed, 8 May 2024 22:42:35 +0200 Subject: restructure: split up the server module --- lib/Server.hs | 349 ++++++++++-------------------------------------- lib/Server/Base.hs | 9 ++ lib/Server/Ingest.hs | 211 +++++++++++++++++++++++++++++ lib/Server/Subscribe.hs | 63 +++++++++ lib/Server/Util.hs | 64 +++++---- 5 files changed, 394 insertions(+), 302 deletions(-) create mode 100644 lib/Server/Base.hs create mode 100644 lib/Server/Ingest.hs create mode 100644 lib/Server/Subscribe.hs diff --git a/lib/Server.hs b/lib/Server.hs index 73c55cb..1833aa0 100644 --- a/lib/Server.hs +++ b/lib/Server.hs @@ -1,97 +1,54 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE ExplicitNamespaces #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE PartialTypeSignatures #-} {-# LANGUAGE RecordWildCards #-} -- Implementation of the API. This module is the main point of the program. module Server (application) where -import Control.Concurrent.STM (TQueue, TVar, atomically, - newTQueue, newTVar, - newTVarIO, readTQueue, - readTVar, writeTQueue, - writeTVar) -import Control.Monad (forever, unless, void, - when) -import Control.Monad.Catch (handle) -import Control.Monad.Extra (ifM, mapMaybeM, maybeM, - unlessM, whenJust, whenM) -import Control.Monad.IO.Class (MonadIO (liftIO)) -import Control.Monad.Logger (LoggingT, MonadLogger, - NoLoggingT, logInfoN, - logWarnN) -import Control.Monad.Reader (MonadReader, ReaderT, - forM) -import Control.Monad.Trans (lift) -import Data.Aeson ((.=)) -import qualified Data.Aeson as A -import qualified Data.ByteString.Char8 as C8 -import Data.Coerce (coerce) -import Data.Functor ((<&>)) -import qualified Data.Map as M -import Data.Pool (Pool) -import Data.Proxy (Proxy (Proxy)) -import Data.Swagger (toSchema) -import Data.Text (Text) -import Data.Text.Encoding (decodeASCII, decodeUtf8) -import Data.Time (NominalDiffTime, - UTCTime (utctDay), - addUTCTime, diffUTCTime, - getCurrentTime, - nominalDay) -import qualified Data.Vector as V -import Database.Persist -import Database.Persist.Postgresql (SqlBackend, - migrateEnableExtension, - runMigration) -import Fmt ((+|), (|+)) -import qualified Network.WebSockets as WS -import Servant (Application, - ServerError (errBody), - err400, err401, err404, - serve, - serveDirectoryFileServer, - throwError) -import Servant.API (NoContent (..), - (:<|>) (..)) -import Servant.Server (Handler, hoistServer) -import Servant.Swagger (toSwagger) - -import API +import API (API, CompleteAPI, Metrics (..)) +import Conduit (ResourceT) +import Config (LoggingConfig, ServerConfig (..)) +import Control.Concurrent.STM (newTVarIO) +import Control.Monad.Extra (forM) +import Control.Monad.IO.Class (MonadIO (liftIO)) +import Control.Monad.Logger (MonadLogger) +import Control.Monad.Reader (ReaderT) +import Data.ByteString.Lazy (toStrict) +import Data.Functor ((<&>)) +import qualified Data.Map as M +import Data.Pool (Pool) +import Data.Proxy (Proxy (Proxy)) +import Data.Text.Encoding (decodeUtf8) +import Data.Time (getCurrentTime) +import Data.UUID (UUID) +import Database.Persist (Entity (..), + PersistQueryRead (selectFirst), + SelectOpt (Desc), selectList, + (<-.), (==.), (>=.), (||.)) +import Database.Persist.Postgresql (SqlBackend, + migrateEnableExtension, + runMigration) +import Fmt ((+|), (|+)) import qualified GTFS import Persist -import Server.ControlRoom -import Server.GTFS_RT (gtfsRealtimeServer) -import Server.Util (Service, ServiceM, - runService, sendErrorMsg, - utcToSeconds) -import Yesod (toWaiAppPlain) +import Prometheus (Info (Info), exportMetricsAsText, + gauge, register) +import Prometheus.Metric.GHC (ghcMetrics) +import Servant (Application, err401, serve, + serveDirectoryFileServer, + throwError) +import Servant.API ((:<|>) (..)) +import Servant.Server (hoistServer) +import Servant.Swagger (toSwagger) +import Server.Base (ServerState) +import Server.ControlRoom (ControlRoom (ControlRoom)) +import Server.GTFS_RT (gtfsRealtimeServer) +import Server.Ingest (handleTrackerRegister, + handleTrainPing, handleWS) +import Server.Subscribe (handleSubscribe) +import Server.Util (Service, runService) +import System.IO.Unsafe (unsafePerformIO) +import Yesod (toWaiAppPlain) -import Conduit (ResourceT) -import Config (LoggingConfig, - ServerConfig (..)) -import Control.Exception (throw) -import Data.ByteString (ByteString) -import Data.ByteString.Lazy (toStrict) -import Data.Foldable (minimumBy) -import Data.Function (on, (&)) -import Data.Maybe (fromMaybe) -import qualified Data.Text as T -import Data.Time.LocalTime.TimeZone.Olson (getTimeZoneSeriesFromOlsonFile) -import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries) -import Data.UUID (UUID) -import qualified Data.UUID as UUID -import Extrapolation (Extrapolator (..), - LinearExtrapolator (..), - euclid) -import GTFS (Seconds (unSeconds), - seconds2Double) -import Prometheus -import Prometheus.Metric.GHC -import System.FilePath (()) -import System.IO.Unsafe application :: GTFS.GTFS -> Pool SqlBackend -> ServerConfig -> IO Application application gtfs dbpool settings = do @@ -100,189 +57,48 @@ application gtfs dbpool settings = do <$> register (gauge (Info "ws_connections" "Number of WS Connections")) register ghcMetrics - -- TODO: maybe cache these in a TVar, we're not likely to ever need - -- more than one of these - let getTzseries tzname = getTimeZoneSeriesFromOlsonFile - (serverConfigZoneinfoPath settings T.unpack tzname) - subscribers <- newTVarIO mempty pure $ serve (Proxy @CompleteAPI) $ hoistServer (Proxy @CompleteAPI) (runService (serverConfigLogging settings)) - $ server gtfs getTzseries metrics subscribers dbpool settings + $ server gtfs metrics subscribers dbpool settings --- databaseMigration :: ConnectionString -> IO () doMigration pool = runSqlWithoutLog pool $ runMigration $ do migrateEnableExtension "uuid-ossp" migrateAll -server :: GTFS.GTFS -> (Text -> IO TimeZoneSeries) -> Metrics -> TVar (M.Map UUID [TQueue (Maybe TrainPing)]) -> Pool SqlBackend -> ServerConfig -> Service CompleteAPI -server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI - :<|> (handleTrackerRegister :<|> handleTrainPing (throwError err401) :<|> handleWS - :<|> handleSubscribe :<|> handleDebugState :<|> handleDebugTrain +server + :: GTFS.GTFS + -> Metrics + -> ServerState + -> Pool SqlBackend + -> ServerConfig + -> Service CompleteAPI +server gtfs metrics@Metrics{..} subscribers dbpool settings = handleDebugAPI + :<|> (handleTrackerRegister dbpool + :<|> handleTrainPing dbpool subscribers settings (throwError err401) + :<|> handleWS dbpool subscribers settings metrics + :<|> handleSubscribe dbpool subscribers + :<|> handleDebugState :<|> handleDebugTrain :<|> pure (GTFS.gtfsFile gtfs) :<|> gtfsRealtimeServer gtfs dbpool) - :<|> metrics + :<|> handleMetrics :<|> serveDirectoryFileServer (serverConfigAssets settings) :<|> pure (unsafePerformIO (toWaiAppPlain (ControlRoom gtfs dbpool settings))) - where handleTrackerRegister RegisterJson{..} = do - today <- liftIO getCurrentTime <&> utctDay - expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod - runSql dbpool $ do - TrackerKey tracker <- insert (Tracker expires False registerAgent Nothing) - pure tracker - handleTrainPing onError ping@TrainPing{..} = - isTokenValid dbpool trainPingToken >>= \case - Nothing -> onError >> pure Nothing - Just tracker@Tracker{..} -> do - - -- if the tracker is not associated with a ticket, it is probably - -- just starting out on a new trip, or has finished an old one. - maybeTicketId <- case trackerCurrentTicket of - Just ticketId -> pure (Just ticketId) - Nothing -> runSql dbpool $ do - now <- liftIO getCurrentTime - tickets <- selectList [ TicketDay ==. utctDay now, TicketCompleted ==. False ] [] - ticketsWithFirstStation <- flip mapMaybeM tickets - (\ticket@(Entity ticketId _) -> do - selectFirst [StopTicket ==. ticketId] [Asc StopSequence] >>= \case - Nothing -> pure Nothing - Just (Entity _ stop) -> do - station <- getJust (stopStation stop) - tzseries <- liftIO $ getTzseries (GTFS.tzname (stopDeparture stop)) - pure (Just (ticket, station, stop, tzseries))) - - if null ticketsWithFirstStation then pure Nothing else do - let (closestTicket, _, _, _) = minimumBy - -- (compare `on` euclid trainPingGeopos . stationGeopos . snd) - (compare `on` - (\(Entity _ ticket, station, stop, tzseries) -> - let - runningDay = ticketDay ticket - spaceDistance = euclid trainPingGeopos (stationGeopos station) - timeDiff = - GTFS.toSeconds (stopDeparture stop) tzseries runningDay - - utcToSeconds now runningDay - in - euclid trainPingGeopos (stationGeopos station) - + abs (seconds2Double timeDiff / 3600))) - ticketsWithFirstStation - logInfoN - $ "Tracker "+|UUID.toString (coerce trainPingToken)|+ - " is now handling ticket "+|UUID.toString (coerce (entityKey closestTicket))|+ - " (trip "+|ticketTripName (entityVal closestTicket)|+")." - - update (coerce trainPingToken) - [TrackerCurrentTicket =. Just (entityKey closestTicket)] - - pure (Just (entityKey closestTicket)) - - ticketId <- case maybeTicketId of - Just ticketId -> pure ticketId - Nothing -> do - logWarnN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+ - " sent a ping, but no trips are running today." - throwError err400 - - runSql dbpool $ do - ticket@Ticket{..} <- getJust ticketId - - stations <- selectList [ StopTicket ==. ticketId ] [Asc StopArrival] - >>= mapM (\stop -> do - station <- getJust (stopStation (entityVal stop)) - tzseries <- liftIO $ getTzseries (GTFS.tzname (stopArrival (entityVal stop))) - pure (entityVal stop, station, tzseries)) - <&> V.fromList - - shapePoints <- selectList [ShapePointShape ==. ticketShape] [Asc ShapePointIndex] - <&> (V.fromList . fmap entityVal) - - let anchor = extrapolateAnchorFromPing LinearExtrapolator - ticketId ticket stations shapePoints ping - - insert ping - - last <- selectFirst [TrainAnchorTicket ==. ticketId] [Desc TrainAnchorWhen] - -- only insert new estimates if they've actually changed anything - when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor)) - $ void $ insert anchor - - -- are we at the final stop? if so, mark this ticket as done - -- & the tracker as free - let maxSequence = V.last stations - & (\(stop, _, _) -> stopSequence stop) - & fromIntegral - when (trainAnchorSequence anchor + 0.1 >= maxSequence) $ do - update (coerce trainPingToken) - [TrackerCurrentTicket =. Nothing] - update ticketId - [TicketCompleted =. True] - logInfoN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+ - " has completed ticket "+|UUID.toString (coerce ticketId)|+ - " (trip "+|ticketTripName|+")" - - queues <- liftIO $ atomically $ do - queues <- readTVar subscribers <&> M.lookup (coerce ticketId) - whenJust queues $ - mapM_ (\q -> writeTQueue q (Just ping)) - pure queues - pure (Just anchor) - handleWS conn = do - liftIO $ WS.forkPingThread conn 30 - incGauge metricsWSGauge - handle (\(e :: WS.ConnectionException) -> decGauge metricsWSGauge) $ forever $ do - msg <- liftIO $ WS.receiveData conn - case A.eitherDecode msg of - Left err -> do - logWarnN ("stray websocket message: "+|decodeASCII (toStrict msg)|+" (could not decode: "+|err|+")") - liftIO $ WS.sendClose conn (C8.pack err) - -- TODO: send a close msg (Nothing) to the subscribed queues? decGauge metricsWSGauge - Right ping -> do - -- if invalid token, send a "polite" close request. Note that the client may - -- ignore this and continue sending messages, which will continue to be handled. - handleTrainPing (liftIO $ WS.sendClose conn ("" :: ByteString)) ping >>= \case - Just anchor -> liftIO $ WS.sendTextData conn (A.encode anchor) - Nothing -> pure () - handleSubscribe (ticketId :: UUID) conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do - queue <- atomically $ do - queue <- newTQueue - qs <- readTVar subscribers - writeTVar subscribers - $ M.insertWith (<>) ticketId [queue] qs - pure queue - -- send most recent ping, if any (so we won't have to wait for movement) - lastPing <- runSqlWithoutLog dbpool $ do - trackers <- getTicketTrackers ticketId - <&> fmap entityKey - selectFirst [TrainPingToken <-. trackers] [Desc TrainPingTimestamp] - <&> fmap entityVal - whenJust lastPing $ \ping -> - WS.sendTextData conn (A.encode lastPing) - handle (\(e :: WS.ConnectionException) -> removeSubscriber queue) $ forever $ do - res <- atomically $ readTQueue queue - case res of - Just ping -> WS.sendTextData conn (A.encode ping) - Nothing -> do - removeSubscriber queue - WS.sendClose conn (C8.pack "train ended") - where removeSubscriber queue = atomically $ do - qs <- readTVar subscribers - writeTVar subscribers - $ M.adjust (filter (/= queue)) ticketId qs - handleDebugState = do - now <- liftIO getCurrentTime - runSql dbpool $ do - tracker <- selectList [TrackerBlocked ==. False, TrackerExpires >=. now] [] - pairs <- forM tracker $ \(Entity token@(TrackerKey uuid) _) -> do - entities <- selectList [TrainPingToken ==. token] [] - pure (uuid, fmap entityVal entities) - pure (M.fromList pairs) - handleDebugTrain ticketId = do - runSql dbpool $ do - trackers <- getTicketTrackers ticketId - pings <- forM trackers $ \(Entity token _) -> do - selectList [TrainPingToken ==. token] [] <&> fmap entityVal - pure (concat pings) - handleDebugAPI = pure $ toSwagger (Proxy @API) - metrics = exportMetricsAsText <&> (decodeUtf8 . toStrict) + where + handleDebugState = do + now <- liftIO getCurrentTime + runSql dbpool $ do + tracker <- selectList [TrackerBlocked ==. False, TrackerExpires >=. now] [] + pairs <- forM tracker $ \(Entity token@(TrackerKey uuid) _) -> do + entities <- selectList [TrainPingToken ==. token] [] + pure (uuid, fmap entityVal entities) + pure (M.fromList pairs) + handleDebugTrain ticketId = runSql dbpool $ do + trackers <- getTicketTrackers ticketId + pings <- forM trackers $ \(Entity token _) -> do + selectList [TrainPingToken ==. token] [] <&> fmap entityVal + pure (concat pings) + handleDebugAPI = pure $ toSwagger (Proxy @API) + handleMetrics = exportMetricsAsText <&> (decodeUtf8 . toStrict) getTicketTrackers :: (MonadLogger (t (ResourceT IO)), MonadIO (t (ResourceT IO))) => UUID -> ReaderT SqlBackend (t (ResourceT IO)) [Entity Tracker] @@ -290,22 +106,3 @@ getTicketTrackers ticketId = do joins <- selectList [TrackerTicketTicket ==. TicketKey ticketId] [] <&> fmap (trackerTicketTracker . entityVal) selectList ([TrackerId <-. joins] ||. [TrackerCurrentTicket ==. Just (TicketKey ticketId)]) [] - - --- TODO: proper debug logging for expired tokens -isTokenValid :: Pool SqlBackend -> TrackerId -> ServiceM (Maybe Tracker) -isTokenValid dbpool token = runSql dbpool $ get token >>= \case - Just tracker | not (trackerBlocked tracker) -> do - ifM (hasExpired (trackerExpires tracker)) - (pure Nothing) - (pure (Just tracker)) - _ -> pure Nothing - -hasExpired :: MonadIO m => UTCTime -> m Bool -hasExpired limit = do - now <- liftIO getCurrentTime - pure (now > limit) - -validityPeriod :: NominalDiffTime -validityPeriod = nominalDay - diff --git a/lib/Server/Base.hs b/lib/Server/Base.hs new file mode 100644 index 0000000..14b77ca --- /dev/null +++ b/lib/Server/Base.hs @@ -0,0 +1,9 @@ + +module Server.Base (ServerState) where + +import Control.Concurrent.STM (TQueue, TVar) +import qualified Data.Map as M +import Data.UUID (UUID) +import Persist (TrainPing) + +type ServerState = TVar (M.Map UUID [TQueue (Maybe TrainPing)]) diff --git a/lib/Server/Ingest.hs b/lib/Server/Ingest.hs new file mode 100644 index 0000000..d774017 --- /dev/null +++ b/lib/Server/Ingest.hs @@ -0,0 +1,211 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RecordWildCards #-} + +module Server.Ingest (handleTrackerRegister, handleTrainPing, handleWS) where +import API (Metrics (..), + RegisterJson (..)) +import Control.Concurrent.STM (atomically, readTVar, + writeTQueue) +import Control.Monad (forever, void, when) +import Control.Monad.Catch (handle) +import Control.Monad.Extra (ifM, mapMaybeM, whenJust) +import Control.Monad.IO.Class (MonadIO (liftIO)) +import Control.Monad.Logger (LoggingT, logInfoN, + logWarnN) +import Control.Monad.Reader (ReaderT) +import qualified Data.Aeson as A +import qualified Data.ByteString.Char8 as C8 +import Data.Coerce (coerce) +import Data.Functor ((<&>)) +import qualified Data.Map as M +import Data.Pool (Pool) +import Data.Text (Text) +import Data.Text.Encoding (decodeASCII, decodeUtf8) +import Data.Time (NominalDiffTime, + UTCTime (..), addUTCTime, + getCurrentTime, + nominalDay) +import qualified Data.Vector as V +import Database.Persist +import Database.Persist.Postgresql (SqlBackend) +import Fmt ((+|), (|+)) +import qualified GTFS +import qualified Network.WebSockets as WS +import Persist +import Servant (err400, throwError) +import Servant.Server (Handler) +import Server.Util (ServiceM, getTzseries, + utcToSeconds) + +import Config (LoggingConfig, + ServerConfig) +import Data.ByteString (ByteString) +import Data.ByteString.Lazy (toStrict) +import Data.Foldable (minimumBy) +import Data.Function (on, (&)) +import qualified Data.Text as T +import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries) +import qualified Data.UUID as UUID +import Extrapolation (Extrapolator (..), + LinearExtrapolator (..), + euclid) +import GTFS (seconds2Double) +import Prometheus (decGauge, incGauge) +import Server.Base (ServerState) + + + +handleTrackerRegister + :: Pool SqlBackend + -> RegisterJson + -> ServiceM Token +handleTrackerRegister dbpool RegisterJson{..} = do + today <- liftIO getCurrentTime <&> utctDay + expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod + runSql dbpool $ do + TrackerKey tracker <- insert (Tracker expires False registerAgent Nothing) + pure tracker + where + validityPeriod :: NominalDiffTime + validityPeriod = nominalDay + +handleTrainPing + :: Pool SqlBackend + -> ServerState + -> ServerConfig + -> LoggingT (ReaderT LoggingConfig Handler) a + -> TrainPing + -> LoggingT (ReaderT LoggingConfig Handler) (Maybe TrainAnchor) +handleTrainPing dbpool subscribers cfg onError ping@TrainPing{..} = + isTokenValid dbpool trainPingToken >>= \case + Nothing -> onError >> pure Nothing + Just tracker@Tracker{..} -> do + + -- if the tracker is not associated with a ticket, it is probably + -- just starting out on a new trip, or has finished an old one. + maybeTicketId <- case trackerCurrentTicket of + Just ticketId -> pure (Just ticketId) + Nothing -> runSql dbpool $ do + now <- liftIO getCurrentTime + tickets <- selectList [ TicketDay ==. utctDay now, TicketCompleted ==. False ] [] + ticketsWithFirstStation <- flip mapMaybeM tickets + (\ticket@(Entity ticketId _) -> do + selectFirst [StopTicket ==. ticketId] [Asc StopSequence] >>= \case + Nothing -> pure Nothing + Just (Entity _ stop) -> do + station <- getJust (stopStation stop) + tzseries <- liftIO $ getTzseries cfg (GTFS.tzname (stopDeparture stop)) + pure (Just (ticket, station, stop, tzseries))) + + if null ticketsWithFirstStation then pure Nothing else do + let (closestTicket, _, _, _) = minimumBy + -- (compare `on` euclid trainPingGeopos . stationGeopos . snd) + (compare `on` + (\(Entity _ ticket, station, stop, tzseries) -> + let + runningDay = ticketDay ticket + spaceDistance = euclid trainPingGeopos (stationGeopos station) + timeDiff = + GTFS.toSeconds (stopDeparture stop) tzseries runningDay + - utcToSeconds now runningDay + in + euclid trainPingGeopos (stationGeopos station) + + abs (seconds2Double timeDiff / 3600))) + ticketsWithFirstStation + logInfoN + $ "Tracker "+|UUID.toString (coerce trainPingToken)|+ + " is now handling ticket "+|UUID.toString (coerce (entityKey closestTicket))|+ + " (trip "+|ticketTripName (entityVal closestTicket)|+")." + + update trainPingToken + [TrackerCurrentTicket =. Just (entityKey closestTicket)] + + pure (Just (entityKey closestTicket)) + + ticketId <- case maybeTicketId of + Just ticketId -> pure ticketId + Nothing -> do + logWarnN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+ + " sent a ping, but no trips are running today." + throwError err400 + + runSql dbpool $ do + ticket@Ticket{..} <- getJust ticketId + + stations <- selectList [ StopTicket ==. ticketId ] [Asc StopArrival] + >>= mapM (\stop -> do + station <- getJust (stopStation (entityVal stop)) + tzseries <- liftIO $ getTzseries cfg (GTFS.tzname (stopArrival (entityVal stop))) + pure (entityVal stop, station, tzseries)) + <&> V.fromList + + shapePoints <- selectList [ShapePointShape ==. ticketShape] [Asc ShapePointIndex] + <&> (V.fromList . fmap entityVal) + + let anchor = extrapolateAnchorFromPing LinearExtrapolator + ticketId ticket stations shapePoints ping + + insert ping + + last <- selectFirst [TrainAnchorTicket ==. ticketId] [Desc TrainAnchorWhen] + -- only insert new estimates if they've actually changed anything + when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor)) + $ void $ insert anchor + + -- are we at the final stop? if so, mark this ticket as done + -- & the tracker as free + let maxSequence = V.last stations + & (\(stop, _, _) -> stopSequence stop) + & fromIntegral + when (trainAnchorSequence anchor + 0.1 >= maxSequence) $ do + update trainPingToken + [TrackerCurrentTicket =. Nothing] + update ticketId + [TicketCompleted =. True] + logInfoN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+ + " has completed ticket "+|UUID.toString (coerce ticketId)|+ + " (trip "+|ticketTripName|+")" + + queues <- liftIO $ atomically $ do + queues <- readTVar subscribers <&> M.lookup (coerce ticketId) + whenJust queues $ + mapM_ (\q -> writeTQueue q (Just ping)) + pure queues + pure (Just anchor) + +handleWS + :: Pool SqlBackend + -> ServerState + -> ServerConfig + -> Metrics + -> WS.Connection -> ServiceM () +handleWS dbpool subscribers cfg Metrics{..} conn = do + liftIO $ WS.forkPingThread conn 30 + incGauge metricsWSGauge + handle (\(e :: WS.ConnectionException) -> decGauge metricsWSGauge) $ forever $ do + msg <- liftIO $ WS.receiveData conn + case A.eitherDecode msg of + Left err -> do + logWarnN ("stray websocket message: "+|decodeASCII (toStrict msg)|+" (could not decode: "+|err|+")") + liftIO $ WS.sendClose conn (C8.pack err) + -- TODO: send a close msg (Nothing) to the subscribed queues? decGauge metricsWSGauge + Right ping -> do + -- if invalid token, send a "polite" close request. Note that the client may + -- ignore this and continue sending messages, which will continue to be handled. + handleTrainPing dbpool subscribers cfg (liftIO $ WS.sendClose conn ("" :: ByteString)) ping >>= \case + Just anchor -> liftIO $ WS.sendTextData conn (A.encode anchor) + Nothing -> pure () + +-- TODO: proper debug logging for expired tokens +isTokenValid :: Pool SqlBackend -> TrackerId -> ServiceM (Maybe Tracker) +isTokenValid dbpool token = runSql dbpool $ get token >>= \case + Just tracker | not (trackerBlocked tracker) -> do + ifM (hasExpired (trackerExpires tracker)) + (pure Nothing) + (pure (Just tracker)) + _ -> pure Nothing + +hasExpired :: MonadIO m => UTCTime -> m Bool +hasExpired limit = do + now <- liftIO getCurrentTime + pure (now > limit) diff --git a/lib/Server/Subscribe.hs b/lib/Server/Subscribe.hs new file mode 100644 index 0000000..fdc092b --- /dev/null +++ b/lib/Server/Subscribe.hs @@ -0,0 +1,63 @@ + +module Server.Subscribe where +import Conduit (MonadIO (..)) +import Control.Concurrent.STM (atomically, newTQueue, readTQueue, + readTVar, writeTVar) +import Control.Exception (handle) +import Control.Monad.Extra (forever, whenJust) +import qualified Data.Aeson as A +import qualified Data.ByteString.Char8 as C8 +import Data.Functor ((<&>)) +import Data.Map (Map) +import qualified Data.Map as M +import Data.Pool +import Data.UUID (UUID) +import Database.Persist (Entity (entityKey), SelectOpt (Desc), + entityVal, selectFirst, selectList, + (<-.), (==.), (||.)) +import Database.Persist.Sql (SqlBackend) +import qualified Network.WebSockets as WS +import Persist +import Server.Base (ServerState) +import Server.Util (ServiceM) + + +handleSubscribe + :: Pool SqlBackend + -> ServerState + -> UUID + -> WS.Connection + -> ServiceM () +handleSubscribe dbpool subscribers (ticketId :: UUID) conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do + queue <- atomically $ do + queue <- newTQueue + qs <- readTVar subscribers + writeTVar subscribers + $ M.insertWith (<>) ticketId [queue] qs + pure queue + -- send most recent ping, if any (so we won't have to wait for movement) + lastPing <- runSqlWithoutLog dbpool $ do + trackers <- getTicketTrackers ticketId + <&> fmap entityKey + selectFirst [TrainPingToken <-. trackers] [Desc TrainPingTimestamp] + <&> fmap entityVal + whenJust lastPing $ \ping -> + WS.sendTextData conn (A.encode lastPing) + handle (\(e :: WS.ConnectionException) -> removeSubscriber queue) $ forever $ do + res <- atomically $ readTQueue queue + case res of + Just ping -> WS.sendTextData conn (A.encode ping) + Nothing -> do + removeSubscriber queue + WS.sendClose conn (C8.pack "train ended") + where removeSubscriber queue = atomically $ do + qs <- readTVar subscribers + writeTVar subscribers + $ M.adjust (filter (/= queue)) ticketId qs + +-- getTicketTrackers :: (MonadLogger (t (ResourceT IO)), MonadIO (t (ResourceT IO))) +-- => UUID -> ReaderT SqlBackend (t (ResourceT IO)) [Entity Tracker] +getTicketTrackers ticketId = do + joins <- selectList [TrackerTicketTicket ==. TicketKey ticketId] [] + <&> fmap (trackerTicketTracker . entityVal) + selectList ([TrackerId <-. joins] ||. [TrackerCurrentTicket ==. Just (TicketKey ticketId)]) [] diff --git a/lib/Server/Util.hs b/lib/Server/Util.hs index 0106428..290b9c5 100644 --- a/lib/Server/Util.hs +++ b/lib/Server/Util.hs @@ -1,33 +1,41 @@ {-# LANGUAGE BlockArguments #-} {-# LANGUAGE RecordWildCards #-} -- | mostly the monad the service runs in -module Server.Util (Service, ServiceM, runService, sendErrorMsg, secondsNow, utcToSeconds, runLogging) where +module Server.Util (Service, ServiceM, runService, sendErrorMsg, secondsNow, utcToSeconds, runLogging, getTzseries) where -import Config (LoggingConfig (..)) -import Control.Exception (handle, try) -import Control.Monad.Extra (void, whenJust) -import Control.Monad.IO.Class (MonadIO (liftIO)) -import Control.Monad.Logger (Loc, LogLevel (..), LogSource, LogStr, - LoggingT (..), defaultOutput, - fromLogStr, runStderrLoggingT) -import Control.Monad.Reader (ReaderT (..)) -import qualified Data.Aeson as A -import Data.ByteString (ByteString) -import qualified Data.ByteString as C8 -import Data.Text (Text) -import qualified Data.Text as T -import Data.Text.Encoding (decodeUtf8Lenient) -import Data.Time (Day, UTCTime (..), diffUTCTime, - getCurrentTime, - nominalDiffTimeToSeconds) -import Fmt ((+|), (|+)) -import GHC.IO.Exception (IOException (IOError)) -import GTFS (Seconds (..)) -import Prometheus (MonadMonitor (doIO)) -import Servant (Handler, ServerError, ServerT, err404, - errBody, errHeaders, throwError) -import System.IO (stderr) -import System.Process.Extra (callProcess) +import Config (LoggingConfig (..), + ServerConfig (..)) +import Control.Exception (handle, try) +import Control.Monad.Extra (void, whenJust) +import Control.Monad.IO.Class (MonadIO (liftIO)) +import Control.Monad.Logger (Loc, LogLevel (..), + LogSource, LogStr, + LoggingT (..), + defaultOutput, fromLogStr, + runStderrLoggingT) +import Control.Monad.Reader (ReaderT (..)) +import qualified Data.Aeson as A +import Data.ByteString (ByteString) +import qualified Data.ByteString as C8 +import Data.Text (Text) +import qualified Data.Text as T +import Data.Text.Encoding (decodeUtf8Lenient) +import Data.Time (Day, UTCTime (..), + diffUTCTime, + getCurrentTime, + nominalDiffTimeToSeconds) +import Data.Time.LocalTime.TimeZone.Olson (getTimeZoneSeriesFromOlsonFile) +import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries) +import Fmt ((+|), (|+)) +import GHC.IO.Exception (IOException (IOError)) +import GTFS (Seconds (..)) +import Prometheus (MonadMonitor (doIO)) +import Servant (Handler, ServerError, + ServerT, err404, errBody, + errHeaders, throwError) +import System.FilePath (()) +import System.IO (stderr) +import System.Process.Extra (callProcess) type ServiceM = LoggingT (ReaderT LoggingConfig Handler) type Service api = ServerT api ServiceM @@ -77,3 +85,7 @@ secondsNow runningDay = do utcToSeconds :: UTCTime -> Day -> Seconds utcToSeconds time day = Seconds $ round $ nominalDiffTimeToSeconds $ diffUTCTime time (UTCTime day 0) + +getTzseries :: ServerConfig -> Text -> IO TimeZoneSeries +getTzseries settings tzname = getTimeZoneSeriesFromOlsonFile + (serverConfigZoneinfoPath settings T.unpack tzname) -- cgit v1.2.3