aboutsummaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/API.hs1
-rw-r--r--lib/Server.hs43
2 files changed, 37 insertions, 7 deletions
diff --git a/lib/API.hs b/lib/API.hs
index 70971c3..3e29249 100644
--- a/lib/API.hs
+++ b/lib/API.hs
@@ -61,6 +61,7 @@ type API = "stations" :> Get '[JSON] (Map StationID Station)
-- TODO: perhaps a websocket instead?
:<|> "train" :> "ping" :> ReqBody '[JSON] TrainPing :> Post '[JSON] (Maybe TrainAnchor)
:<|> "train" :> "ping" :> "ws" :> WebSocket
+ :<|> "train" :> "subscribe" :> Capture "Trip ID" TripID :> WebSocket
-- debug things
:<|> "debug" :> "pings" :> Get '[JSON] (Map Token [TrainPing])
:<|> "debug" :> "pings" :> Capture "Trip ID" TripID :> Capture "day" Day :> Get '[JSON] [TrainPing]
diff --git a/lib/Server.hs b/lib/Server.hs
index 6ca9c14..84dc27e 100644
--- a/lib/Server.hs
+++ b/lib/Server.hs
@@ -11,9 +11,13 @@
-- Implementation of the API. This module is the main point of the program.
module Server (application) where
+import Control.Concurrent.STM (TQueue, TVar, atomically,
+ newTQueue, newTVar, readTQueue,
+ readTVar, writeTQueue, writeTVar)
import Control.Monad (forever, unless, void, when)
import Control.Monad.Catch (handle)
-import Control.Monad.Extra (ifM, maybeM, unlessM, whenM)
+import Control.Monad.Extra (ifM, maybeM, unlessM, whenJust,
+ whenM)
import Control.Monad.IO.Class (MonadIO (liftIO))
import Control.Monad.Logger (LoggingT, logWarnN)
import Control.Monad.Reader (forM)
@@ -70,7 +74,8 @@ application gtfs dbpool = do
metrics <- Metrics
<$> register (gauge (Info "ws_connections" "Number of WS Connections"))
register ghcMetrics
- pure $ serve (Proxy @CompleteAPI) $ hoistServer (Proxy @CompleteAPI) runService $ server gtfs metrics dbpool
+ subscribers <- atomically $ newTVar mempty
+ pure $ serve (Proxy @CompleteAPI) $ hoistServer (Proxy @CompleteAPI) runService $ server gtfs metrics subscribers dbpool
-- databaseMigration :: ConnectionString -> IO ()
doMigration pool = runSql pool $
@@ -79,12 +84,13 @@ doMigration pool = runSql pool $
-- returns an empty list
runMigration migrateAll
-server :: GTFS -> Metrics -> Pool SqlBackend -> Service CompleteAPI
-server gtfs@GTFS{..} Metrics{..} dbpool = handleDebugAPI
+server :: GTFS -> Metrics -> TVar (M.Map TripID ([TQueue (Maybe TrainPing)])) -> Pool SqlBackend -> Service CompleteAPI
+server gtfs@GTFS{..} Metrics{..} subscribers dbpool = handleDebugAPI
:<|> (handleStations :<|> handleTimetable :<|> handleTrip
:<|> handleRegister :<|> handleTrainPing (throwError err401) :<|> handleWS
- :<|> handleDebugState :<|> handleDebugTrain :<|> handleDebugRegister
- :<|> gtfsRealtimeServer gtfs dbpool) :<|> metrics
+ :<|> handleSubscribe :<|> handleDebugState :<|> handleDebugTrain
+ :<|> handleDebugRegister :<|> gtfsRealtimeServer gtfs dbpool)
+ :<|> metrics
:<|> pure (unsafePerformIO (toWaiAppPlain (ControlRoom gtfs dbpool)))
where handleStations = pure stations
handleTimetable station maybeDay = do
@@ -123,6 +129,11 @@ server gtfs@GTFS{..} Metrics{..} dbpool = handleDebugAPI
-- only insert new estimates if they've actually changed anything
when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor))
$ void $ insert anchor
+ queues <- liftIO $ atomically $ do
+ queues <- readTVar subscribers <&> M.lookup runningTrip
+ whenJust queues $
+ mapM_ (\q -> writeTQueue q (Just ping))
+ pure queues
pure (Just anchor)
handleWS conn = do
liftIO $ WS.forkPingThread conn 30
@@ -133,13 +144,31 @@ server gtfs@GTFS{..} Metrics{..} dbpool = handleDebugAPI
Left err -> do
logWarnN ("stray websocket message: "+|show msg|+" (could not decode: "+|err|+")")
liftIO $ WS.sendClose conn (C8.pack err)
- decGauge metricsWSGauge
+ -- TODO: send a close msg (Nothing) to the subscribed queues? decGauge metricsWSGauge
Right ping ->
-- if invalid token, send a "polite" close request. Note that the client may
-- ignore this and continue sending messages, which will continue to be handled.
liftIO $ handleTrainPing (WS.sendClose conn ("" :: ByteString)) ping >>= \case
Just anchor -> WS.sendTextData conn (A.encode anchor)
Nothing -> pure ()
+ handleSubscribe tripId conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do
+ queue <- atomically $ do
+ queue <- newTQueue
+ qs <- readTVar subscribers
+ writeTVar subscribers
+ $ M.insertWith (<>) tripId [queue] qs
+ pure queue
+ handle (\(e :: WS.ConnectionException) -> removeSubscriber queue) $ forever $ do
+ res <- atomically $ readTQueue queue
+ case res of
+ Just ping -> WS.sendTextData conn (A.encode ping)
+ Nothing -> do
+ removeSubscriber queue
+ WS.sendClose conn (C8.pack "train ended")
+ where removeSubscriber queue = atomically $ do
+ qs <- readTVar subscribers
+ writeTVar subscribers
+ $ M.adjust (filter (/= queue)) tripId qs
handleDebugState = do
now <- liftIO getCurrentTime
runSql dbpool $ do