aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorstuebinm2024-05-08 22:42:35 +0200
committerstuebinm2024-05-08 22:43:05 +0200
commitad8a09cafa519a15a22cafbfd2fa289538edc73d (patch)
tree81f49d19669d5895115a1e8d39bd3557fc0c03d8
parent0febc9cd99e0d8b80b1385593e25e7670d5c842b (diff)
restructure: split up the server module
-rw-r--r--lib/Server.hs349
-rw-r--r--lib/Server/Base.hs9
-rw-r--r--lib/Server/Ingest.hs211
-rw-r--r--lib/Server/Subscribe.hs63
-rw-r--r--lib/Server/Util.hs64
5 files changed, 394 insertions, 302 deletions
diff --git a/lib/Server.hs b/lib/Server.hs
index 73c55cb..1833aa0 100644
--- a/lib/Server.hs
+++ b/lib/Server.hs
@@ -1,97 +1,54 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE ExplicitNamespaces #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
-- Implementation of the API. This module is the main point of the program.
module Server (application) where
-import Control.Concurrent.STM (TQueue, TVar, atomically,
- newTQueue, newTVar,
- newTVarIO, readTQueue,
- readTVar, writeTQueue,
- writeTVar)
-import Control.Monad (forever, unless, void,
- when)
-import Control.Monad.Catch (handle)
-import Control.Monad.Extra (ifM, mapMaybeM, maybeM,
- unlessM, whenJust, whenM)
-import Control.Monad.IO.Class (MonadIO (liftIO))
-import Control.Monad.Logger (LoggingT, MonadLogger,
- NoLoggingT, logInfoN,
- logWarnN)
-import Control.Monad.Reader (MonadReader, ReaderT,
- forM)
-import Control.Monad.Trans (lift)
-import Data.Aeson ((.=))
-import qualified Data.Aeson as A
-import qualified Data.ByteString.Char8 as C8
-import Data.Coerce (coerce)
-import Data.Functor ((<&>))
-import qualified Data.Map as M
-import Data.Pool (Pool)
-import Data.Proxy (Proxy (Proxy))
-import Data.Swagger (toSchema)
-import Data.Text (Text)
-import Data.Text.Encoding (decodeASCII, decodeUtf8)
-import Data.Time (NominalDiffTime,
- UTCTime (utctDay),
- addUTCTime, diffUTCTime,
- getCurrentTime,
- nominalDay)
-import qualified Data.Vector as V
-import Database.Persist
-import Database.Persist.Postgresql (SqlBackend,
- migrateEnableExtension,
- runMigration)
-import Fmt ((+|), (|+))
-import qualified Network.WebSockets as WS
-import Servant (Application,
- ServerError (errBody),
- err400, err401, err404,
- serve,
- serveDirectoryFileServer,
- throwError)
-import Servant.API (NoContent (..),
- (:<|>) (..))
-import Servant.Server (Handler, hoistServer)
-import Servant.Swagger (toSwagger)
-
-import API
+import API (API, CompleteAPI, Metrics (..))
+import Conduit (ResourceT)
+import Config (LoggingConfig, ServerConfig (..))
+import Control.Concurrent.STM (newTVarIO)
+import Control.Monad.Extra (forM)
+import Control.Monad.IO.Class (MonadIO (liftIO))
+import Control.Monad.Logger (MonadLogger)
+import Control.Monad.Reader (ReaderT)
+import Data.ByteString.Lazy (toStrict)
+import Data.Functor ((<&>))
+import qualified Data.Map as M
+import Data.Pool (Pool)
+import Data.Proxy (Proxy (Proxy))
+import Data.Text.Encoding (decodeUtf8)
+import Data.Time (getCurrentTime)
+import Data.UUID (UUID)
+import Database.Persist (Entity (..),
+ PersistQueryRead (selectFirst),
+ SelectOpt (Desc), selectList,
+ (<-.), (==.), (>=.), (||.))
+import Database.Persist.Postgresql (SqlBackend,
+ migrateEnableExtension,
+ runMigration)
+import Fmt ((+|), (|+))
import qualified GTFS
import Persist
-import Server.ControlRoom
-import Server.GTFS_RT (gtfsRealtimeServer)
-import Server.Util (Service, ServiceM,
- runService, sendErrorMsg,
- utcToSeconds)
-import Yesod (toWaiAppPlain)
+import Prometheus (Info (Info), exportMetricsAsText,
+ gauge, register)
+import Prometheus.Metric.GHC (ghcMetrics)
+import Servant (Application, err401, serve,
+ serveDirectoryFileServer,
+ throwError)
+import Servant.API ((:<|>) (..))
+import Servant.Server (hoistServer)
+import Servant.Swagger (toSwagger)
+import Server.Base (ServerState)
+import Server.ControlRoom (ControlRoom (ControlRoom))
+import Server.GTFS_RT (gtfsRealtimeServer)
+import Server.Ingest (handleTrackerRegister,
+ handleTrainPing, handleWS)
+import Server.Subscribe (handleSubscribe)
+import Server.Util (Service, runService)
+import System.IO.Unsafe (unsafePerformIO)
+import Yesod (toWaiAppPlain)
-import Conduit (ResourceT)
-import Config (LoggingConfig,
- ServerConfig (..))
-import Control.Exception (throw)
-import Data.ByteString (ByteString)
-import Data.ByteString.Lazy (toStrict)
-import Data.Foldable (minimumBy)
-import Data.Function (on, (&))
-import Data.Maybe (fromMaybe)
-import qualified Data.Text as T
-import Data.Time.LocalTime.TimeZone.Olson (getTimeZoneSeriesFromOlsonFile)
-import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries)
-import Data.UUID (UUID)
-import qualified Data.UUID as UUID
-import Extrapolation (Extrapolator (..),
- LinearExtrapolator (..),
- euclid)
-import GTFS (Seconds (unSeconds),
- seconds2Double)
-import Prometheus
-import Prometheus.Metric.GHC
-import System.FilePath ((</>))
-import System.IO.Unsafe
application :: GTFS.GTFS -> Pool SqlBackend -> ServerConfig -> IO Application
application gtfs dbpool settings = do
@@ -100,189 +57,48 @@ application gtfs dbpool settings = do
<$> register (gauge (Info "ws_connections" "Number of WS Connections"))
register ghcMetrics
- -- TODO: maybe cache these in a TVar, we're not likely to ever need
- -- more than one of these
- let getTzseries tzname = getTimeZoneSeriesFromOlsonFile
- (serverConfigZoneinfoPath settings </> T.unpack tzname)
-
subscribers <- newTVarIO mempty
pure $ serve (Proxy @CompleteAPI)
$ hoistServer (Proxy @CompleteAPI) (runService (serverConfigLogging settings))
- $ server gtfs getTzseries metrics subscribers dbpool settings
+ $ server gtfs metrics subscribers dbpool settings
--- databaseMigration :: ConnectionString -> IO ()
doMigration pool = runSqlWithoutLog pool $ runMigration $ do
migrateEnableExtension "uuid-ossp"
migrateAll
-server :: GTFS.GTFS -> (Text -> IO TimeZoneSeries) -> Metrics -> TVar (M.Map UUID [TQueue (Maybe TrainPing)]) -> Pool SqlBackend -> ServerConfig -> Service CompleteAPI
-server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI
- :<|> (handleTrackerRegister :<|> handleTrainPing (throwError err401) :<|> handleWS
- :<|> handleSubscribe :<|> handleDebugState :<|> handleDebugTrain
+server
+ :: GTFS.GTFS
+ -> Metrics
+ -> ServerState
+ -> Pool SqlBackend
+ -> ServerConfig
+ -> Service CompleteAPI
+server gtfs metrics@Metrics{..} subscribers dbpool settings = handleDebugAPI
+ :<|> (handleTrackerRegister dbpool
+ :<|> handleTrainPing dbpool subscribers settings (throwError err401)
+ :<|> handleWS dbpool subscribers settings metrics
+ :<|> handleSubscribe dbpool subscribers
+ :<|> handleDebugState :<|> handleDebugTrain
:<|> pure (GTFS.gtfsFile gtfs) :<|> gtfsRealtimeServer gtfs dbpool)
- :<|> metrics
+ :<|> handleMetrics
:<|> serveDirectoryFileServer (serverConfigAssets settings)
:<|> pure (unsafePerformIO (toWaiAppPlain (ControlRoom gtfs dbpool settings)))
- where handleTrackerRegister RegisterJson{..} = do
- today <- liftIO getCurrentTime <&> utctDay
- expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod
- runSql dbpool $ do
- TrackerKey tracker <- insert (Tracker expires False registerAgent Nothing)
- pure tracker
- handleTrainPing onError ping@TrainPing{..} =
- isTokenValid dbpool trainPingToken >>= \case
- Nothing -> onError >> pure Nothing
- Just tracker@Tracker{..} -> do
-
- -- if the tracker is not associated with a ticket, it is probably
- -- just starting out on a new trip, or has finished an old one.
- maybeTicketId <- case trackerCurrentTicket of
- Just ticketId -> pure (Just ticketId)
- Nothing -> runSql dbpool $ do
- now <- liftIO getCurrentTime
- tickets <- selectList [ TicketDay ==. utctDay now, TicketCompleted ==. False ] []
- ticketsWithFirstStation <- flip mapMaybeM tickets
- (\ticket@(Entity ticketId _) -> do
- selectFirst [StopTicket ==. ticketId] [Asc StopSequence] >>= \case
- Nothing -> pure Nothing
- Just (Entity _ stop) -> do
- station <- getJust (stopStation stop)
- tzseries <- liftIO $ getTzseries (GTFS.tzname (stopDeparture stop))
- pure (Just (ticket, station, stop, tzseries)))
-
- if null ticketsWithFirstStation then pure Nothing else do
- let (closestTicket, _, _, _) = minimumBy
- -- (compare `on` euclid trainPingGeopos . stationGeopos . snd)
- (compare `on`
- (\(Entity _ ticket, station, stop, tzseries) ->
- let
- runningDay = ticketDay ticket
- spaceDistance = euclid trainPingGeopos (stationGeopos station)
- timeDiff =
- GTFS.toSeconds (stopDeparture stop) tzseries runningDay
- - utcToSeconds now runningDay
- in
- euclid trainPingGeopos (stationGeopos station)
- + abs (seconds2Double timeDiff / 3600)))
- ticketsWithFirstStation
- logInfoN
- $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
- " is now handling ticket "+|UUID.toString (coerce (entityKey closestTicket))|+
- " (trip "+|ticketTripName (entityVal closestTicket)|+")."
-
- update (coerce trainPingToken)
- [TrackerCurrentTicket =. Just (entityKey closestTicket)]
-
- pure (Just (entityKey closestTicket))
-
- ticketId <- case maybeTicketId of
- Just ticketId -> pure ticketId
- Nothing -> do
- logWarnN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
- " sent a ping, but no trips are running today."
- throwError err400
-
- runSql dbpool $ do
- ticket@Ticket{..} <- getJust ticketId
-
- stations <- selectList [ StopTicket ==. ticketId ] [Asc StopArrival]
- >>= mapM (\stop -> do
- station <- getJust (stopStation (entityVal stop))
- tzseries <- liftIO $ getTzseries (GTFS.tzname (stopArrival (entityVal stop)))
- pure (entityVal stop, station, tzseries))
- <&> V.fromList
-
- shapePoints <- selectList [ShapePointShape ==. ticketShape] [Asc ShapePointIndex]
- <&> (V.fromList . fmap entityVal)
-
- let anchor = extrapolateAnchorFromPing LinearExtrapolator
- ticketId ticket stations shapePoints ping
-
- insert ping
-
- last <- selectFirst [TrainAnchorTicket ==. ticketId] [Desc TrainAnchorWhen]
- -- only insert new estimates if they've actually changed anything
- when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor))
- $ void $ insert anchor
-
- -- are we at the final stop? if so, mark this ticket as done
- -- & the tracker as free
- let maxSequence = V.last stations
- & (\(stop, _, _) -> stopSequence stop)
- & fromIntegral
- when (trainAnchorSequence anchor + 0.1 >= maxSequence) $ do
- update (coerce trainPingToken)
- [TrackerCurrentTicket =. Nothing]
- update ticketId
- [TicketCompleted =. True]
- logInfoN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
- " has completed ticket "+|UUID.toString (coerce ticketId)|+
- " (trip "+|ticketTripName|+")"
-
- queues <- liftIO $ atomically $ do
- queues <- readTVar subscribers <&> M.lookup (coerce ticketId)
- whenJust queues $
- mapM_ (\q -> writeTQueue q (Just ping))
- pure queues
- pure (Just anchor)
- handleWS conn = do
- liftIO $ WS.forkPingThread conn 30
- incGauge metricsWSGauge
- handle (\(e :: WS.ConnectionException) -> decGauge metricsWSGauge) $ forever $ do
- msg <- liftIO $ WS.receiveData conn
- case A.eitherDecode msg of
- Left err -> do
- logWarnN ("stray websocket message: "+|decodeASCII (toStrict msg)|+" (could not decode: "+|err|+")")
- liftIO $ WS.sendClose conn (C8.pack err)
- -- TODO: send a close msg (Nothing) to the subscribed queues? decGauge metricsWSGauge
- Right ping -> do
- -- if invalid token, send a "polite" close request. Note that the client may
- -- ignore this and continue sending messages, which will continue to be handled.
- handleTrainPing (liftIO $ WS.sendClose conn ("" :: ByteString)) ping >>= \case
- Just anchor -> liftIO $ WS.sendTextData conn (A.encode anchor)
- Nothing -> pure ()
- handleSubscribe (ticketId :: UUID) conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do
- queue <- atomically $ do
- queue <- newTQueue
- qs <- readTVar subscribers
- writeTVar subscribers
- $ M.insertWith (<>) ticketId [queue] qs
- pure queue
- -- send most recent ping, if any (so we won't have to wait for movement)
- lastPing <- runSqlWithoutLog dbpool $ do
- trackers <- getTicketTrackers ticketId
- <&> fmap entityKey
- selectFirst [TrainPingToken <-. trackers] [Desc TrainPingTimestamp]
- <&> fmap entityVal
- whenJust lastPing $ \ping ->
- WS.sendTextData conn (A.encode lastPing)
- handle (\(e :: WS.ConnectionException) -> removeSubscriber queue) $ forever $ do
- res <- atomically $ readTQueue queue
- case res of
- Just ping -> WS.sendTextData conn (A.encode ping)
- Nothing -> do
- removeSubscriber queue
- WS.sendClose conn (C8.pack "train ended")
- where removeSubscriber queue = atomically $ do
- qs <- readTVar subscribers
- writeTVar subscribers
- $ M.adjust (filter (/= queue)) ticketId qs
- handleDebugState = do
- now <- liftIO getCurrentTime
- runSql dbpool $ do
- tracker <- selectList [TrackerBlocked ==. False, TrackerExpires >=. now] []
- pairs <- forM tracker $ \(Entity token@(TrackerKey uuid) _) -> do
- entities <- selectList [TrainPingToken ==. token] []
- pure (uuid, fmap entityVal entities)
- pure (M.fromList pairs)
- handleDebugTrain ticketId = do
- runSql dbpool $ do
- trackers <- getTicketTrackers ticketId
- pings <- forM trackers $ \(Entity token _) -> do
- selectList [TrainPingToken ==. token] [] <&> fmap entityVal
- pure (concat pings)
- handleDebugAPI = pure $ toSwagger (Proxy @API)
- metrics = exportMetricsAsText <&> (decodeUtf8 . toStrict)
+ where
+ handleDebugState = do
+ now <- liftIO getCurrentTime
+ runSql dbpool $ do
+ tracker <- selectList [TrackerBlocked ==. False, TrackerExpires >=. now] []
+ pairs <- forM tracker $ \(Entity token@(TrackerKey uuid) _) -> do
+ entities <- selectList [TrainPingToken ==. token] []
+ pure (uuid, fmap entityVal entities)
+ pure (M.fromList pairs)
+ handleDebugTrain ticketId = runSql dbpool $ do
+ trackers <- getTicketTrackers ticketId
+ pings <- forM trackers $ \(Entity token _) -> do
+ selectList [TrainPingToken ==. token] [] <&> fmap entityVal
+ pure (concat pings)
+ handleDebugAPI = pure $ toSwagger (Proxy @API)
+ handleMetrics = exportMetricsAsText <&> (decodeUtf8 . toStrict)
getTicketTrackers :: (MonadLogger (t (ResourceT IO)), MonadIO (t (ResourceT IO)))
=> UUID -> ReaderT SqlBackend (t (ResourceT IO)) [Entity Tracker]
@@ -290,22 +106,3 @@ getTicketTrackers ticketId = do
joins <- selectList [TrackerTicketTicket ==. TicketKey ticketId] []
<&> fmap (trackerTicketTracker . entityVal)
selectList ([TrackerId <-. joins] ||. [TrackerCurrentTicket ==. Just (TicketKey ticketId)]) []
-
-
--- TODO: proper debug logging for expired tokens
-isTokenValid :: Pool SqlBackend -> TrackerId -> ServiceM (Maybe Tracker)
-isTokenValid dbpool token = runSql dbpool $ get token >>= \case
- Just tracker | not (trackerBlocked tracker) -> do
- ifM (hasExpired (trackerExpires tracker))
- (pure Nothing)
- (pure (Just tracker))
- _ -> pure Nothing
-
-hasExpired :: MonadIO m => UTCTime -> m Bool
-hasExpired limit = do
- now <- liftIO getCurrentTime
- pure (now > limit)
-
-validityPeriod :: NominalDiffTime
-validityPeriod = nominalDay
-
diff --git a/lib/Server/Base.hs b/lib/Server/Base.hs
new file mode 100644
index 0000000..14b77ca
--- /dev/null
+++ b/lib/Server/Base.hs
@@ -0,0 +1,9 @@
+
+module Server.Base (ServerState) where
+
+import Control.Concurrent.STM (TQueue, TVar)
+import qualified Data.Map as M
+import Data.UUID (UUID)
+import Persist (TrainPing)
+
+type ServerState = TVar (M.Map UUID [TQueue (Maybe TrainPing)])
diff --git a/lib/Server/Ingest.hs b/lib/Server/Ingest.hs
new file mode 100644
index 0000000..d774017
--- /dev/null
+++ b/lib/Server/Ingest.hs
@@ -0,0 +1,211 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RecordWildCards #-}
+
+module Server.Ingest (handleTrackerRegister, handleTrainPing, handleWS) where
+import API (Metrics (..),
+ RegisterJson (..))
+import Control.Concurrent.STM (atomically, readTVar,
+ writeTQueue)
+import Control.Monad (forever, void, when)
+import Control.Monad.Catch (handle)
+import Control.Monad.Extra (ifM, mapMaybeM, whenJust)
+import Control.Monad.IO.Class (MonadIO (liftIO))
+import Control.Monad.Logger (LoggingT, logInfoN,
+ logWarnN)
+import Control.Monad.Reader (ReaderT)
+import qualified Data.Aeson as A
+import qualified Data.ByteString.Char8 as C8
+import Data.Coerce (coerce)
+import Data.Functor ((<&>))
+import qualified Data.Map as M
+import Data.Pool (Pool)
+import Data.Text (Text)
+import Data.Text.Encoding (decodeASCII, decodeUtf8)
+import Data.Time (NominalDiffTime,
+ UTCTime (..), addUTCTime,
+ getCurrentTime,
+ nominalDay)
+import qualified Data.Vector as V
+import Database.Persist
+import Database.Persist.Postgresql (SqlBackend)
+import Fmt ((+|), (|+))
+import qualified GTFS
+import qualified Network.WebSockets as WS
+import Persist
+import Servant (err400, throwError)
+import Servant.Server (Handler)
+import Server.Util (ServiceM, getTzseries,
+ utcToSeconds)
+
+import Config (LoggingConfig,
+ ServerConfig)
+import Data.ByteString (ByteString)
+import Data.ByteString.Lazy (toStrict)
+import Data.Foldable (minimumBy)
+import Data.Function (on, (&))
+import qualified Data.Text as T
+import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries)
+import qualified Data.UUID as UUID
+import Extrapolation (Extrapolator (..),
+ LinearExtrapolator (..),
+ euclid)
+import GTFS (seconds2Double)
+import Prometheus (decGauge, incGauge)
+import Server.Base (ServerState)
+
+
+
+handleTrackerRegister
+ :: Pool SqlBackend
+ -> RegisterJson
+ -> ServiceM Token
+handleTrackerRegister dbpool RegisterJson{..} = do
+ today <- liftIO getCurrentTime <&> utctDay
+ expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod
+ runSql dbpool $ do
+ TrackerKey tracker <- insert (Tracker expires False registerAgent Nothing)
+ pure tracker
+ where
+ validityPeriod :: NominalDiffTime
+ validityPeriod = nominalDay
+
+handleTrainPing
+ :: Pool SqlBackend
+ -> ServerState
+ -> ServerConfig
+ -> LoggingT (ReaderT LoggingConfig Handler) a
+ -> TrainPing
+ -> LoggingT (ReaderT LoggingConfig Handler) (Maybe TrainAnchor)
+handleTrainPing dbpool subscribers cfg onError ping@TrainPing{..} =
+ isTokenValid dbpool trainPingToken >>= \case
+ Nothing -> onError >> pure Nothing
+ Just tracker@Tracker{..} -> do
+
+ -- if the tracker is not associated with a ticket, it is probably
+ -- just starting out on a new trip, or has finished an old one.
+ maybeTicketId <- case trackerCurrentTicket of
+ Just ticketId -> pure (Just ticketId)
+ Nothing -> runSql dbpool $ do
+ now <- liftIO getCurrentTime
+ tickets <- selectList [ TicketDay ==. utctDay now, TicketCompleted ==. False ] []
+ ticketsWithFirstStation <- flip mapMaybeM tickets
+ (\ticket@(Entity ticketId _) -> do
+ selectFirst [StopTicket ==. ticketId] [Asc StopSequence] >>= \case
+ Nothing -> pure Nothing
+ Just (Entity _ stop) -> do
+ station <- getJust (stopStation stop)
+ tzseries <- liftIO $ getTzseries cfg (GTFS.tzname (stopDeparture stop))
+ pure (Just (ticket, station, stop, tzseries)))
+
+ if null ticketsWithFirstStation then pure Nothing else do
+ let (closestTicket, _, _, _) = minimumBy
+ -- (compare `on` euclid trainPingGeopos . stationGeopos . snd)
+ (compare `on`
+ (\(Entity _ ticket, station, stop, tzseries) ->
+ let
+ runningDay = ticketDay ticket
+ spaceDistance = euclid trainPingGeopos (stationGeopos station)
+ timeDiff =
+ GTFS.toSeconds (stopDeparture stop) tzseries runningDay
+ - utcToSeconds now runningDay
+ in
+ euclid trainPingGeopos (stationGeopos station)
+ + abs (seconds2Double timeDiff / 3600)))
+ ticketsWithFirstStation
+ logInfoN
+ $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
+ " is now handling ticket "+|UUID.toString (coerce (entityKey closestTicket))|+
+ " (trip "+|ticketTripName (entityVal closestTicket)|+")."
+
+ update trainPingToken
+ [TrackerCurrentTicket =. Just (entityKey closestTicket)]
+
+ pure (Just (entityKey closestTicket))
+
+ ticketId <- case maybeTicketId of
+ Just ticketId -> pure ticketId
+ Nothing -> do
+ logWarnN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
+ " sent a ping, but no trips are running today."
+ throwError err400
+
+ runSql dbpool $ do
+ ticket@Ticket{..} <- getJust ticketId
+
+ stations <- selectList [ StopTicket ==. ticketId ] [Asc StopArrival]
+ >>= mapM (\stop -> do
+ station <- getJust (stopStation (entityVal stop))
+ tzseries <- liftIO $ getTzseries cfg (GTFS.tzname (stopArrival (entityVal stop)))
+ pure (entityVal stop, station, tzseries))
+ <&> V.fromList
+
+ shapePoints <- selectList [ShapePointShape ==. ticketShape] [Asc ShapePointIndex]
+ <&> (V.fromList . fmap entityVal)
+
+ let anchor = extrapolateAnchorFromPing LinearExtrapolator
+ ticketId ticket stations shapePoints ping
+
+ insert ping
+
+ last <- selectFirst [TrainAnchorTicket ==. ticketId] [Desc TrainAnchorWhen]
+ -- only insert new estimates if they've actually changed anything
+ when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor))
+ $ void $ insert anchor
+
+ -- are we at the final stop? if so, mark this ticket as done
+ -- & the tracker as free
+ let maxSequence = V.last stations
+ & (\(stop, _, _) -> stopSequence stop)
+ & fromIntegral
+ when (trainAnchorSequence anchor + 0.1 >= maxSequence) $ do
+ update trainPingToken
+ [TrackerCurrentTicket =. Nothing]
+ update ticketId
+ [TicketCompleted =. True]
+ logInfoN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
+ " has completed ticket "+|UUID.toString (coerce ticketId)|+
+ " (trip "+|ticketTripName|+")"
+
+ queues <- liftIO $ atomically $ do
+ queues <- readTVar subscribers <&> M.lookup (coerce ticketId)
+ whenJust queues $
+ mapM_ (\q -> writeTQueue q (Just ping))
+ pure queues
+ pure (Just anchor)
+
+handleWS
+ :: Pool SqlBackend
+ -> ServerState
+ -> ServerConfig
+ -> Metrics
+ -> WS.Connection -> ServiceM ()
+handleWS dbpool subscribers cfg Metrics{..} conn = do
+ liftIO $ WS.forkPingThread conn 30
+ incGauge metricsWSGauge
+ handle (\(e :: WS.ConnectionException) -> decGauge metricsWSGauge) $ forever $ do
+ msg <- liftIO $ WS.receiveData conn
+ case A.eitherDecode msg of
+ Left err -> do
+ logWarnN ("stray websocket message: "+|decodeASCII (toStrict msg)|+" (could not decode: "+|err|+")")
+ liftIO $ WS.sendClose conn (C8.pack err)
+ -- TODO: send a close msg (Nothing) to the subscribed queues? decGauge metricsWSGauge
+ Right ping -> do
+ -- if invalid token, send a "polite" close request. Note that the client may
+ -- ignore this and continue sending messages, which will continue to be handled.
+ handleTrainPing dbpool subscribers cfg (liftIO $ WS.sendClose conn ("" :: ByteString)) ping >>= \case
+ Just anchor -> liftIO $ WS.sendTextData conn (A.encode anchor)
+ Nothing -> pure ()
+
+-- TODO: proper debug logging for expired tokens
+isTokenValid :: Pool SqlBackend -> TrackerId -> ServiceM (Maybe Tracker)
+isTokenValid dbpool token = runSql dbpool $ get token >>= \case
+ Just tracker | not (trackerBlocked tracker) -> do
+ ifM (hasExpired (trackerExpires tracker))
+ (pure Nothing)
+ (pure (Just tracker))
+ _ -> pure Nothing
+
+hasExpired :: MonadIO m => UTCTime -> m Bool
+hasExpired limit = do
+ now <- liftIO getCurrentTime
+ pure (now > limit)
diff --git a/lib/Server/Subscribe.hs b/lib/Server/Subscribe.hs
new file mode 100644
index 0000000..fdc092b
--- /dev/null
+++ b/lib/Server/Subscribe.hs
@@ -0,0 +1,63 @@
+
+module Server.Subscribe where
+import Conduit (MonadIO (..))
+import Control.Concurrent.STM (atomically, newTQueue, readTQueue,
+ readTVar, writeTVar)
+import Control.Exception (handle)
+import Control.Monad.Extra (forever, whenJust)
+import qualified Data.Aeson as A
+import qualified Data.ByteString.Char8 as C8
+import Data.Functor ((<&>))
+import Data.Map (Map)
+import qualified Data.Map as M
+import Data.Pool
+import Data.UUID (UUID)
+import Database.Persist (Entity (entityKey), SelectOpt (Desc),
+ entityVal, selectFirst, selectList,
+ (<-.), (==.), (||.))
+import Database.Persist.Sql (SqlBackend)
+import qualified Network.WebSockets as WS
+import Persist
+import Server.Base (ServerState)
+import Server.Util (ServiceM)
+
+
+handleSubscribe
+ :: Pool SqlBackend
+ -> ServerState
+ -> UUID
+ -> WS.Connection
+ -> ServiceM ()
+handleSubscribe dbpool subscribers (ticketId :: UUID) conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do
+ queue <- atomically $ do
+ queue <- newTQueue
+ qs <- readTVar subscribers
+ writeTVar subscribers
+ $ M.insertWith (<>) ticketId [queue] qs
+ pure queue
+ -- send most recent ping, if any (so we won't have to wait for movement)
+ lastPing <- runSqlWithoutLog dbpool $ do
+ trackers <- getTicketTrackers ticketId
+ <&> fmap entityKey
+ selectFirst [TrainPingToken <-. trackers] [Desc TrainPingTimestamp]
+ <&> fmap entityVal
+ whenJust lastPing $ \ping ->
+ WS.sendTextData conn (A.encode lastPing)
+ handle (\(e :: WS.ConnectionException) -> removeSubscriber queue) $ forever $ do
+ res <- atomically $ readTQueue queue
+ case res of
+ Just ping -> WS.sendTextData conn (A.encode ping)
+ Nothing -> do
+ removeSubscriber queue
+ WS.sendClose conn (C8.pack "train ended")
+ where removeSubscriber queue = atomically $ do
+ qs <- readTVar subscribers
+ writeTVar subscribers
+ $ M.adjust (filter (/= queue)) ticketId qs
+
+-- getTicketTrackers :: (MonadLogger (t (ResourceT IO)), MonadIO (t (ResourceT IO)))
+-- => UUID -> ReaderT SqlBackend (t (ResourceT IO)) [Entity Tracker]
+getTicketTrackers ticketId = do
+ joins <- selectList [TrackerTicketTicket ==. TicketKey ticketId] []
+ <&> fmap (trackerTicketTracker . entityVal)
+ selectList ([TrackerId <-. joins] ||. [TrackerCurrentTicket ==. Just (TicketKey ticketId)]) []
diff --git a/lib/Server/Util.hs b/lib/Server/Util.hs
index 0106428..290b9c5 100644
--- a/lib/Server/Util.hs
+++ b/lib/Server/Util.hs
@@ -1,33 +1,41 @@
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE RecordWildCards #-}
-- | mostly the monad the service runs in
-module Server.Util (Service, ServiceM, runService, sendErrorMsg, secondsNow, utcToSeconds, runLogging) where
+module Server.Util (Service, ServiceM, runService, sendErrorMsg, secondsNow, utcToSeconds, runLogging, getTzseries) where
-import Config (LoggingConfig (..))
-import Control.Exception (handle, try)
-import Control.Monad.Extra (void, whenJust)
-import Control.Monad.IO.Class (MonadIO (liftIO))
-import Control.Monad.Logger (Loc, LogLevel (..), LogSource, LogStr,
- LoggingT (..), defaultOutput,
- fromLogStr, runStderrLoggingT)
-import Control.Monad.Reader (ReaderT (..))
-import qualified Data.Aeson as A
-import Data.ByteString (ByteString)
-import qualified Data.ByteString as C8
-import Data.Text (Text)
-import qualified Data.Text as T
-import Data.Text.Encoding (decodeUtf8Lenient)
-import Data.Time (Day, UTCTime (..), diffUTCTime,
- getCurrentTime,
- nominalDiffTimeToSeconds)
-import Fmt ((+|), (|+))
-import GHC.IO.Exception (IOException (IOError))
-import GTFS (Seconds (..))
-import Prometheus (MonadMonitor (doIO))
-import Servant (Handler, ServerError, ServerT, err404,
- errBody, errHeaders, throwError)
-import System.IO (stderr)
-import System.Process.Extra (callProcess)
+import Config (LoggingConfig (..),
+ ServerConfig (..))
+import Control.Exception (handle, try)
+import Control.Monad.Extra (void, whenJust)
+import Control.Monad.IO.Class (MonadIO (liftIO))
+import Control.Monad.Logger (Loc, LogLevel (..),
+ LogSource, LogStr,
+ LoggingT (..),
+ defaultOutput, fromLogStr,
+ runStderrLoggingT)
+import Control.Monad.Reader (ReaderT (..))
+import qualified Data.Aeson as A
+import Data.ByteString (ByteString)
+import qualified Data.ByteString as C8
+import Data.Text (Text)
+import qualified Data.Text as T
+import Data.Text.Encoding (decodeUtf8Lenient)
+import Data.Time (Day, UTCTime (..),
+ diffUTCTime,
+ getCurrentTime,
+ nominalDiffTimeToSeconds)
+import Data.Time.LocalTime.TimeZone.Olson (getTimeZoneSeriesFromOlsonFile)
+import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries)
+import Fmt ((+|), (|+))
+import GHC.IO.Exception (IOException (IOError))
+import GTFS (Seconds (..))
+import Prometheus (MonadMonitor (doIO))
+import Servant (Handler, ServerError,
+ ServerT, err404, errBody,
+ errHeaders, throwError)
+import System.FilePath ((</>))
+import System.IO (stderr)
+import System.Process.Extra (callProcess)
type ServiceM = LoggingT (ReaderT LoggingConfig Handler)
type Service api = ServerT api ServiceM
@@ -77,3 +85,7 @@ secondsNow runningDay = do
utcToSeconds :: UTCTime -> Day -> Seconds
utcToSeconds time day =
Seconds $ round $ nominalDiffTimeToSeconds $ diffUTCTime time (UTCTime day 0)
+
+getTzseries :: ServerConfig -> Text -> IO TimeZoneSeries
+getTzseries settings tzname = getTimeZoneSeriesFromOlsonFile
+ (serverConfigZoneinfoPath settings </> T.unpack tzname)