aboutsummaryrefslogtreecommitdiff
path: root/lib/Server.hs
diff options
context:
space:
mode:
authorstuebinm2024-05-08 22:42:35 +0200
committerstuebinm2024-05-08 22:43:05 +0200
commitad8a09cafa519a15a22cafbfd2fa289538edc73d (patch)
tree81f49d19669d5895115a1e8d39bd3557fc0c03d8 /lib/Server.hs
parent0febc9cd99e0d8b80b1385593e25e7670d5c842b (diff)
restructure: split up the server module
Diffstat (limited to 'lib/Server.hs')
-rw-r--r--lib/Server.hs349
1 files changed, 73 insertions, 276 deletions
diff --git a/lib/Server.hs b/lib/Server.hs
index 73c55cb..1833aa0 100644
--- a/lib/Server.hs
+++ b/lib/Server.hs
@@ -1,97 +1,54 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE ExplicitNamespaces #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
-- Implementation of the API. This module is the main point of the program.
module Server (application) where
-import Control.Concurrent.STM (TQueue, TVar, atomically,
- newTQueue, newTVar,
- newTVarIO, readTQueue,
- readTVar, writeTQueue,
- writeTVar)
-import Control.Monad (forever, unless, void,
- when)
-import Control.Monad.Catch (handle)
-import Control.Monad.Extra (ifM, mapMaybeM, maybeM,
- unlessM, whenJust, whenM)
-import Control.Monad.IO.Class (MonadIO (liftIO))
-import Control.Monad.Logger (LoggingT, MonadLogger,
- NoLoggingT, logInfoN,
- logWarnN)
-import Control.Monad.Reader (MonadReader, ReaderT,
- forM)
-import Control.Monad.Trans (lift)
-import Data.Aeson ((.=))
-import qualified Data.Aeson as A
-import qualified Data.ByteString.Char8 as C8
-import Data.Coerce (coerce)
-import Data.Functor ((<&>))
-import qualified Data.Map as M
-import Data.Pool (Pool)
-import Data.Proxy (Proxy (Proxy))
-import Data.Swagger (toSchema)
-import Data.Text (Text)
-import Data.Text.Encoding (decodeASCII, decodeUtf8)
-import Data.Time (NominalDiffTime,
- UTCTime (utctDay),
- addUTCTime, diffUTCTime,
- getCurrentTime,
- nominalDay)
-import qualified Data.Vector as V
-import Database.Persist
-import Database.Persist.Postgresql (SqlBackend,
- migrateEnableExtension,
- runMigration)
-import Fmt ((+|), (|+))
-import qualified Network.WebSockets as WS
-import Servant (Application,
- ServerError (errBody),
- err400, err401, err404,
- serve,
- serveDirectoryFileServer,
- throwError)
-import Servant.API (NoContent (..),
- (:<|>) (..))
-import Servant.Server (Handler, hoistServer)
-import Servant.Swagger (toSwagger)
-
-import API
+import API (API, CompleteAPI, Metrics (..))
+import Conduit (ResourceT)
+import Config (LoggingConfig, ServerConfig (..))
+import Control.Concurrent.STM (newTVarIO)
+import Control.Monad.Extra (forM)
+import Control.Monad.IO.Class (MonadIO (liftIO))
+import Control.Monad.Logger (MonadLogger)
+import Control.Monad.Reader (ReaderT)
+import Data.ByteString.Lazy (toStrict)
+import Data.Functor ((<&>))
+import qualified Data.Map as M
+import Data.Pool (Pool)
+import Data.Proxy (Proxy (Proxy))
+import Data.Text.Encoding (decodeUtf8)
+import Data.Time (getCurrentTime)
+import Data.UUID (UUID)
+import Database.Persist (Entity (..),
+ PersistQueryRead (selectFirst),
+ SelectOpt (Desc), selectList,
+ (<-.), (==.), (>=.), (||.))
+import Database.Persist.Postgresql (SqlBackend,
+ migrateEnableExtension,
+ runMigration)
+import Fmt ((+|), (|+))
import qualified GTFS
import Persist
-import Server.ControlRoom
-import Server.GTFS_RT (gtfsRealtimeServer)
-import Server.Util (Service, ServiceM,
- runService, sendErrorMsg,
- utcToSeconds)
-import Yesod (toWaiAppPlain)
+import Prometheus (Info (Info), exportMetricsAsText,
+ gauge, register)
+import Prometheus.Metric.GHC (ghcMetrics)
+import Servant (Application, err401, serve,
+ serveDirectoryFileServer,
+ throwError)
+import Servant.API ((:<|>) (..))
+import Servant.Server (hoistServer)
+import Servant.Swagger (toSwagger)
+import Server.Base (ServerState)
+import Server.ControlRoom (ControlRoom (ControlRoom))
+import Server.GTFS_RT (gtfsRealtimeServer)
+import Server.Ingest (handleTrackerRegister,
+ handleTrainPing, handleWS)
+import Server.Subscribe (handleSubscribe)
+import Server.Util (Service, runService)
+import System.IO.Unsafe (unsafePerformIO)
+import Yesod (toWaiAppPlain)
-import Conduit (ResourceT)
-import Config (LoggingConfig,
- ServerConfig (..))
-import Control.Exception (throw)
-import Data.ByteString (ByteString)
-import Data.ByteString.Lazy (toStrict)
-import Data.Foldable (minimumBy)
-import Data.Function (on, (&))
-import Data.Maybe (fromMaybe)
-import qualified Data.Text as T
-import Data.Time.LocalTime.TimeZone.Olson (getTimeZoneSeriesFromOlsonFile)
-import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries)
-import Data.UUID (UUID)
-import qualified Data.UUID as UUID
-import Extrapolation (Extrapolator (..),
- LinearExtrapolator (..),
- euclid)
-import GTFS (Seconds (unSeconds),
- seconds2Double)
-import Prometheus
-import Prometheus.Metric.GHC
-import System.FilePath ((</>))
-import System.IO.Unsafe
application :: GTFS.GTFS -> Pool SqlBackend -> ServerConfig -> IO Application
application gtfs dbpool settings = do
@@ -100,189 +57,48 @@ application gtfs dbpool settings = do
<$> register (gauge (Info "ws_connections" "Number of WS Connections"))
register ghcMetrics
- -- TODO: maybe cache these in a TVar, we're not likely to ever need
- -- more than one of these
- let getTzseries tzname = getTimeZoneSeriesFromOlsonFile
- (serverConfigZoneinfoPath settings </> T.unpack tzname)
-
subscribers <- newTVarIO mempty
pure $ serve (Proxy @CompleteAPI)
$ hoistServer (Proxy @CompleteAPI) (runService (serverConfigLogging settings))
- $ server gtfs getTzseries metrics subscribers dbpool settings
+ $ server gtfs metrics subscribers dbpool settings
--- databaseMigration :: ConnectionString -> IO ()
doMigration pool = runSqlWithoutLog pool $ runMigration $ do
migrateEnableExtension "uuid-ossp"
migrateAll
-server :: GTFS.GTFS -> (Text -> IO TimeZoneSeries) -> Metrics -> TVar (M.Map UUID [TQueue (Maybe TrainPing)]) -> Pool SqlBackend -> ServerConfig -> Service CompleteAPI
-server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI
- :<|> (handleTrackerRegister :<|> handleTrainPing (throwError err401) :<|> handleWS
- :<|> handleSubscribe :<|> handleDebugState :<|> handleDebugTrain
+server
+ :: GTFS.GTFS
+ -> Metrics
+ -> ServerState
+ -> Pool SqlBackend
+ -> ServerConfig
+ -> Service CompleteAPI
+server gtfs metrics@Metrics{..} subscribers dbpool settings = handleDebugAPI
+ :<|> (handleTrackerRegister dbpool
+ :<|> handleTrainPing dbpool subscribers settings (throwError err401)
+ :<|> handleWS dbpool subscribers settings metrics
+ :<|> handleSubscribe dbpool subscribers
+ :<|> handleDebugState :<|> handleDebugTrain
:<|> pure (GTFS.gtfsFile gtfs) :<|> gtfsRealtimeServer gtfs dbpool)
- :<|> metrics
+ :<|> handleMetrics
:<|> serveDirectoryFileServer (serverConfigAssets settings)
:<|> pure (unsafePerformIO (toWaiAppPlain (ControlRoom gtfs dbpool settings)))
- where handleTrackerRegister RegisterJson{..} = do
- today <- liftIO getCurrentTime <&> utctDay
- expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod
- runSql dbpool $ do
- TrackerKey tracker <- insert (Tracker expires False registerAgent Nothing)
- pure tracker
- handleTrainPing onError ping@TrainPing{..} =
- isTokenValid dbpool trainPingToken >>= \case
- Nothing -> onError >> pure Nothing
- Just tracker@Tracker{..} -> do
-
- -- if the tracker is not associated with a ticket, it is probably
- -- just starting out on a new trip, or has finished an old one.
- maybeTicketId <- case trackerCurrentTicket of
- Just ticketId -> pure (Just ticketId)
- Nothing -> runSql dbpool $ do
- now <- liftIO getCurrentTime
- tickets <- selectList [ TicketDay ==. utctDay now, TicketCompleted ==. False ] []
- ticketsWithFirstStation <- flip mapMaybeM tickets
- (\ticket@(Entity ticketId _) -> do
- selectFirst [StopTicket ==. ticketId] [Asc StopSequence] >>= \case
- Nothing -> pure Nothing
- Just (Entity _ stop) -> do
- station <- getJust (stopStation stop)
- tzseries <- liftIO $ getTzseries (GTFS.tzname (stopDeparture stop))
- pure (Just (ticket, station, stop, tzseries)))
-
- if null ticketsWithFirstStation then pure Nothing else do
- let (closestTicket, _, _, _) = minimumBy
- -- (compare `on` euclid trainPingGeopos . stationGeopos . snd)
- (compare `on`
- (\(Entity _ ticket, station, stop, tzseries) ->
- let
- runningDay = ticketDay ticket
- spaceDistance = euclid trainPingGeopos (stationGeopos station)
- timeDiff =
- GTFS.toSeconds (stopDeparture stop) tzseries runningDay
- - utcToSeconds now runningDay
- in
- euclid trainPingGeopos (stationGeopos station)
- + abs (seconds2Double timeDiff / 3600)))
- ticketsWithFirstStation
- logInfoN
- $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
- " is now handling ticket "+|UUID.toString (coerce (entityKey closestTicket))|+
- " (trip "+|ticketTripName (entityVal closestTicket)|+")."
-
- update (coerce trainPingToken)
- [TrackerCurrentTicket =. Just (entityKey closestTicket)]
-
- pure (Just (entityKey closestTicket))
-
- ticketId <- case maybeTicketId of
- Just ticketId -> pure ticketId
- Nothing -> do
- logWarnN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
- " sent a ping, but no trips are running today."
- throwError err400
-
- runSql dbpool $ do
- ticket@Ticket{..} <- getJust ticketId
-
- stations <- selectList [ StopTicket ==. ticketId ] [Asc StopArrival]
- >>= mapM (\stop -> do
- station <- getJust (stopStation (entityVal stop))
- tzseries <- liftIO $ getTzseries (GTFS.tzname (stopArrival (entityVal stop)))
- pure (entityVal stop, station, tzseries))
- <&> V.fromList
-
- shapePoints <- selectList [ShapePointShape ==. ticketShape] [Asc ShapePointIndex]
- <&> (V.fromList . fmap entityVal)
-
- let anchor = extrapolateAnchorFromPing LinearExtrapolator
- ticketId ticket stations shapePoints ping
-
- insert ping
-
- last <- selectFirst [TrainAnchorTicket ==. ticketId] [Desc TrainAnchorWhen]
- -- only insert new estimates if they've actually changed anything
- when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor))
- $ void $ insert anchor
-
- -- are we at the final stop? if so, mark this ticket as done
- -- & the tracker as free
- let maxSequence = V.last stations
- & (\(stop, _, _) -> stopSequence stop)
- & fromIntegral
- when (trainAnchorSequence anchor + 0.1 >= maxSequence) $ do
- update (coerce trainPingToken)
- [TrackerCurrentTicket =. Nothing]
- update ticketId
- [TicketCompleted =. True]
- logInfoN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
- " has completed ticket "+|UUID.toString (coerce ticketId)|+
- " (trip "+|ticketTripName|+")"
-
- queues <- liftIO $ atomically $ do
- queues <- readTVar subscribers <&> M.lookup (coerce ticketId)
- whenJust queues $
- mapM_ (\q -> writeTQueue q (Just ping))
- pure queues
- pure (Just anchor)
- handleWS conn = do
- liftIO $ WS.forkPingThread conn 30
- incGauge metricsWSGauge
- handle (\(e :: WS.ConnectionException) -> decGauge metricsWSGauge) $ forever $ do
- msg <- liftIO $ WS.receiveData conn
- case A.eitherDecode msg of
- Left err -> do
- logWarnN ("stray websocket message: "+|decodeASCII (toStrict msg)|+" (could not decode: "+|err|+")")
- liftIO $ WS.sendClose conn (C8.pack err)
- -- TODO: send a close msg (Nothing) to the subscribed queues? decGauge metricsWSGauge
- Right ping -> do
- -- if invalid token, send a "polite" close request. Note that the client may
- -- ignore this and continue sending messages, which will continue to be handled.
- handleTrainPing (liftIO $ WS.sendClose conn ("" :: ByteString)) ping >>= \case
- Just anchor -> liftIO $ WS.sendTextData conn (A.encode anchor)
- Nothing -> pure ()
- handleSubscribe (ticketId :: UUID) conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do
- queue <- atomically $ do
- queue <- newTQueue
- qs <- readTVar subscribers
- writeTVar subscribers
- $ M.insertWith (<>) ticketId [queue] qs
- pure queue
- -- send most recent ping, if any (so we won't have to wait for movement)
- lastPing <- runSqlWithoutLog dbpool $ do
- trackers <- getTicketTrackers ticketId
- <&> fmap entityKey
- selectFirst [TrainPingToken <-. trackers] [Desc TrainPingTimestamp]
- <&> fmap entityVal
- whenJust lastPing $ \ping ->
- WS.sendTextData conn (A.encode lastPing)
- handle (\(e :: WS.ConnectionException) -> removeSubscriber queue) $ forever $ do
- res <- atomically $ readTQueue queue
- case res of
- Just ping -> WS.sendTextData conn (A.encode ping)
- Nothing -> do
- removeSubscriber queue
- WS.sendClose conn (C8.pack "train ended")
- where removeSubscriber queue = atomically $ do
- qs <- readTVar subscribers
- writeTVar subscribers
- $ M.adjust (filter (/= queue)) ticketId qs
- handleDebugState = do
- now <- liftIO getCurrentTime
- runSql dbpool $ do
- tracker <- selectList [TrackerBlocked ==. False, TrackerExpires >=. now] []
- pairs <- forM tracker $ \(Entity token@(TrackerKey uuid) _) -> do
- entities <- selectList [TrainPingToken ==. token] []
- pure (uuid, fmap entityVal entities)
- pure (M.fromList pairs)
- handleDebugTrain ticketId = do
- runSql dbpool $ do
- trackers <- getTicketTrackers ticketId
- pings <- forM trackers $ \(Entity token _) -> do
- selectList [TrainPingToken ==. token] [] <&> fmap entityVal
- pure (concat pings)
- handleDebugAPI = pure $ toSwagger (Proxy @API)
- metrics = exportMetricsAsText <&> (decodeUtf8 . toStrict)
+ where
+ handleDebugState = do
+ now <- liftIO getCurrentTime
+ runSql dbpool $ do
+ tracker <- selectList [TrackerBlocked ==. False, TrackerExpires >=. now] []
+ pairs <- forM tracker $ \(Entity token@(TrackerKey uuid) _) -> do
+ entities <- selectList [TrainPingToken ==. token] []
+ pure (uuid, fmap entityVal entities)
+ pure (M.fromList pairs)
+ handleDebugTrain ticketId = runSql dbpool $ do
+ trackers <- getTicketTrackers ticketId
+ pings <- forM trackers $ \(Entity token _) -> do
+ selectList [TrainPingToken ==. token] [] <&> fmap entityVal
+ pure (concat pings)
+ handleDebugAPI = pure $ toSwagger (Proxy @API)
+ handleMetrics = exportMetricsAsText <&> (decodeUtf8 . toStrict)
getTicketTrackers :: (MonadLogger (t (ResourceT IO)), MonadIO (t (ResourceT IO)))
=> UUID -> ReaderT SqlBackend (t (ResourceT IO)) [Entity Tracker]
@@ -290,22 +106,3 @@ getTicketTrackers ticketId = do
joins <- selectList [TrackerTicketTicket ==. TicketKey ticketId] []
<&> fmap (trackerTicketTracker . entityVal)
selectList ([TrackerId <-. joins] ||. [TrackerCurrentTicket ==. Just (TicketKey ticketId)]) []
-
-
--- TODO: proper debug logging for expired tokens
-isTokenValid :: Pool SqlBackend -> TrackerId -> ServiceM (Maybe Tracker)
-isTokenValid dbpool token = runSql dbpool $ get token >>= \case
- Just tracker | not (trackerBlocked tracker) -> do
- ifM (hasExpired (trackerExpires tracker))
- (pure Nothing)
- (pure (Just tracker))
- _ -> pure Nothing
-
-hasExpired :: MonadIO m => UTCTime -> m Bool
-hasExpired limit = do
- now <- liftIO getCurrentTime
- pure (now > limit)
-
-validityPeriod :: NominalDiffTime
-validityPeriod = nominalDay
-