diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/API.hs | 1 | ||||
-rw-r--r-- | lib/Server.hs | 43 |
2 files changed, 37 insertions, 7 deletions
@@ -61,6 +61,7 @@ type API = "stations" :> Get '[JSON] (Map StationID Station) -- TODO: perhaps a websocket instead? :<|> "train" :> "ping" :> ReqBody '[JSON] TrainPing :> Post '[JSON] (Maybe TrainAnchor) :<|> "train" :> "ping" :> "ws" :> WebSocket + :<|> "train" :> "subscribe" :> Capture "Trip ID" TripID :> WebSocket -- debug things :<|> "debug" :> "pings" :> Get '[JSON] (Map Token [TrainPing]) :<|> "debug" :> "pings" :> Capture "Trip ID" TripID :> Capture "day" Day :> Get '[JSON] [TrainPing] diff --git a/lib/Server.hs b/lib/Server.hs index 6ca9c14..84dc27e 100644 --- a/lib/Server.hs +++ b/lib/Server.hs @@ -11,9 +11,13 @@ -- Implementation of the API. This module is the main point of the program. module Server (application) where +import Control.Concurrent.STM (TQueue, TVar, atomically, + newTQueue, newTVar, readTQueue, + readTVar, writeTQueue, writeTVar) import Control.Monad (forever, unless, void, when) import Control.Monad.Catch (handle) -import Control.Monad.Extra (ifM, maybeM, unlessM, whenM) +import Control.Monad.Extra (ifM, maybeM, unlessM, whenJust, + whenM) import Control.Monad.IO.Class (MonadIO (liftIO)) import Control.Monad.Logger (LoggingT, logWarnN) import Control.Monad.Reader (forM) @@ -70,7 +74,8 @@ application gtfs dbpool = do metrics <- Metrics <$> register (gauge (Info "ws_connections" "Number of WS Connections")) register ghcMetrics - pure $ serve (Proxy @CompleteAPI) $ hoistServer (Proxy @CompleteAPI) runService $ server gtfs metrics dbpool + subscribers <- atomically $ newTVar mempty + pure $ serve (Proxy @CompleteAPI) $ hoistServer (Proxy @CompleteAPI) runService $ server gtfs metrics subscribers dbpool -- databaseMigration :: ConnectionString -> IO () doMigration pool = runSql pool $ @@ -79,12 +84,13 @@ doMigration pool = runSql pool $ -- returns an empty list runMigration migrateAll -server :: GTFS -> Metrics -> Pool SqlBackend -> Service CompleteAPI -server gtfs@GTFS{..} Metrics{..} dbpool = handleDebugAPI +server :: GTFS -> Metrics -> TVar (M.Map TripID ([TQueue (Maybe TrainPing)])) -> Pool SqlBackend -> Service CompleteAPI +server gtfs@GTFS{..} Metrics{..} subscribers dbpool = handleDebugAPI :<|> (handleStations :<|> handleTimetable :<|> handleTrip :<|> handleRegister :<|> handleTrainPing (throwError err401) :<|> handleWS - :<|> handleDebugState :<|> handleDebugTrain :<|> handleDebugRegister - :<|> gtfsRealtimeServer gtfs dbpool) :<|> metrics + :<|> handleSubscribe :<|> handleDebugState :<|> handleDebugTrain + :<|> handleDebugRegister :<|> gtfsRealtimeServer gtfs dbpool) + :<|> metrics :<|> pure (unsafePerformIO (toWaiAppPlain (ControlRoom gtfs dbpool))) where handleStations = pure stations handleTimetable station maybeDay = do @@ -123,6 +129,11 @@ server gtfs@GTFS{..} Metrics{..} dbpool = handleDebugAPI -- only insert new estimates if they've actually changed anything when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor)) $ void $ insert anchor + queues <- liftIO $ atomically $ do + queues <- readTVar subscribers <&> M.lookup runningTrip + whenJust queues $ + mapM_ (\q -> writeTQueue q (Just ping)) + pure queues pure (Just anchor) handleWS conn = do liftIO $ WS.forkPingThread conn 30 @@ -133,13 +144,31 @@ server gtfs@GTFS{..} Metrics{..} dbpool = handleDebugAPI Left err -> do logWarnN ("stray websocket message: "+|show msg|+" (could not decode: "+|err|+")") liftIO $ WS.sendClose conn (C8.pack err) - decGauge metricsWSGauge + -- TODO: send a close msg (Nothing) to the subscribed queues? decGauge metricsWSGauge Right ping -> -- if invalid token, send a "polite" close request. Note that the client may -- ignore this and continue sending messages, which will continue to be handled. liftIO $ handleTrainPing (WS.sendClose conn ("" :: ByteString)) ping >>= \case Just anchor -> WS.sendTextData conn (A.encode anchor) Nothing -> pure () + handleSubscribe tripId conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do + queue <- atomically $ do + queue <- newTQueue + qs <- readTVar subscribers + writeTVar subscribers + $ M.insertWith (<>) tripId [queue] qs + pure queue + handle (\(e :: WS.ConnectionException) -> removeSubscriber queue) $ forever $ do + res <- atomically $ readTQueue queue + case res of + Just ping -> WS.sendTextData conn (A.encode ping) + Nothing -> do + removeSubscriber queue + WS.sendClose conn (C8.pack "train ended") + where removeSubscriber queue = atomically $ do + qs <- readTVar subscribers + writeTVar subscribers + $ M.adjust (filter (/= queue)) tripId qs handleDebugState = do now <- liftIO getCurrentTime runSql dbpool $ do |