diff options
Diffstat (limited to '')
-rw-r--r-- | lib/Server.hs | 349 |
1 files changed, 73 insertions, 276 deletions
diff --git a/lib/Server.hs b/lib/Server.hs index 73c55cb..1833aa0 100644 --- a/lib/Server.hs +++ b/lib/Server.hs @@ -1,97 +1,54 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE ExplicitNamespaces #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE PartialTypeSignatures #-} {-# LANGUAGE RecordWildCards #-} -- Implementation of the API. This module is the main point of the program. module Server (application) where -import Control.Concurrent.STM (TQueue, TVar, atomically, - newTQueue, newTVar, - newTVarIO, readTQueue, - readTVar, writeTQueue, - writeTVar) -import Control.Monad (forever, unless, void, - when) -import Control.Monad.Catch (handle) -import Control.Monad.Extra (ifM, mapMaybeM, maybeM, - unlessM, whenJust, whenM) -import Control.Monad.IO.Class (MonadIO (liftIO)) -import Control.Monad.Logger (LoggingT, MonadLogger, - NoLoggingT, logInfoN, - logWarnN) -import Control.Monad.Reader (MonadReader, ReaderT, - forM) -import Control.Monad.Trans (lift) -import Data.Aeson ((.=)) -import qualified Data.Aeson as A -import qualified Data.ByteString.Char8 as C8 -import Data.Coerce (coerce) -import Data.Functor ((<&>)) -import qualified Data.Map as M -import Data.Pool (Pool) -import Data.Proxy (Proxy (Proxy)) -import Data.Swagger (toSchema) -import Data.Text (Text) -import Data.Text.Encoding (decodeASCII, decodeUtf8) -import Data.Time (NominalDiffTime, - UTCTime (utctDay), - addUTCTime, diffUTCTime, - getCurrentTime, - nominalDay) -import qualified Data.Vector as V -import Database.Persist -import Database.Persist.Postgresql (SqlBackend, - migrateEnableExtension, - runMigration) -import Fmt ((+|), (|+)) -import qualified Network.WebSockets as WS -import Servant (Application, - ServerError (errBody), - err400, err401, err404, - serve, - serveDirectoryFileServer, - throwError) -import Servant.API (NoContent (..), - (:<|>) (..)) -import Servant.Server (Handler, hoistServer) -import Servant.Swagger (toSwagger) - -import API +import API (API, CompleteAPI, Metrics (..)) +import Conduit (ResourceT) +import Config (LoggingConfig, ServerConfig (..)) +import Control.Concurrent.STM (newTVarIO) +import Control.Monad.Extra (forM) +import Control.Monad.IO.Class (MonadIO (liftIO)) +import Control.Monad.Logger (MonadLogger) +import Control.Monad.Reader (ReaderT) +import Data.ByteString.Lazy (toStrict) +import Data.Functor ((<&>)) +import qualified Data.Map as M +import Data.Pool (Pool) +import Data.Proxy (Proxy (Proxy)) +import Data.Text.Encoding (decodeUtf8) +import Data.Time (getCurrentTime) +import Data.UUID (UUID) +import Database.Persist (Entity (..), + PersistQueryRead (selectFirst), + SelectOpt (Desc), selectList, + (<-.), (==.), (>=.), (||.)) +import Database.Persist.Postgresql (SqlBackend, + migrateEnableExtension, + runMigration) +import Fmt ((+|), (|+)) import qualified GTFS import Persist -import Server.ControlRoom -import Server.GTFS_RT (gtfsRealtimeServer) -import Server.Util (Service, ServiceM, - runService, sendErrorMsg, - utcToSeconds) -import Yesod (toWaiAppPlain) +import Prometheus (Info (Info), exportMetricsAsText, + gauge, register) +import Prometheus.Metric.GHC (ghcMetrics) +import Servant (Application, err401, serve, + serveDirectoryFileServer, + throwError) +import Servant.API ((:<|>) (..)) +import Servant.Server (hoistServer) +import Servant.Swagger (toSwagger) +import Server.Base (ServerState) +import Server.ControlRoom (ControlRoom (ControlRoom)) +import Server.GTFS_RT (gtfsRealtimeServer) +import Server.Ingest (handleTrackerRegister, + handleTrainPing, handleWS) +import Server.Subscribe (handleSubscribe) +import Server.Util (Service, runService) +import System.IO.Unsafe (unsafePerformIO) +import Yesod (toWaiAppPlain) -import Conduit (ResourceT) -import Config (LoggingConfig, - ServerConfig (..)) -import Control.Exception (throw) -import Data.ByteString (ByteString) -import Data.ByteString.Lazy (toStrict) -import Data.Foldable (minimumBy) -import Data.Function (on, (&)) -import Data.Maybe (fromMaybe) -import qualified Data.Text as T -import Data.Time.LocalTime.TimeZone.Olson (getTimeZoneSeriesFromOlsonFile) -import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries) -import Data.UUID (UUID) -import qualified Data.UUID as UUID -import Extrapolation (Extrapolator (..), - LinearExtrapolator (..), - euclid) -import GTFS (Seconds (unSeconds), - seconds2Double) -import Prometheus -import Prometheus.Metric.GHC -import System.FilePath ((</>)) -import System.IO.Unsafe application :: GTFS.GTFS -> Pool SqlBackend -> ServerConfig -> IO Application application gtfs dbpool settings = do @@ -100,189 +57,48 @@ application gtfs dbpool settings = do <$> register (gauge (Info "ws_connections" "Number of WS Connections")) register ghcMetrics - -- TODO: maybe cache these in a TVar, we're not likely to ever need - -- more than one of these - let getTzseries tzname = getTimeZoneSeriesFromOlsonFile - (serverConfigZoneinfoPath settings </> T.unpack tzname) - subscribers <- newTVarIO mempty pure $ serve (Proxy @CompleteAPI) $ hoistServer (Proxy @CompleteAPI) (runService (serverConfigLogging settings)) - $ server gtfs getTzseries metrics subscribers dbpool settings + $ server gtfs metrics subscribers dbpool settings --- databaseMigration :: ConnectionString -> IO () doMigration pool = runSqlWithoutLog pool $ runMigration $ do migrateEnableExtension "uuid-ossp" migrateAll -server :: GTFS.GTFS -> (Text -> IO TimeZoneSeries) -> Metrics -> TVar (M.Map UUID [TQueue (Maybe TrainPing)]) -> Pool SqlBackend -> ServerConfig -> Service CompleteAPI -server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI - :<|> (handleTrackerRegister :<|> handleTrainPing (throwError err401) :<|> handleWS - :<|> handleSubscribe :<|> handleDebugState :<|> handleDebugTrain +server + :: GTFS.GTFS + -> Metrics + -> ServerState + -> Pool SqlBackend + -> ServerConfig + -> Service CompleteAPI +server gtfs metrics@Metrics{..} subscribers dbpool settings = handleDebugAPI + :<|> (handleTrackerRegister dbpool + :<|> handleTrainPing dbpool subscribers settings (throwError err401) + :<|> handleWS dbpool subscribers settings metrics + :<|> handleSubscribe dbpool subscribers + :<|> handleDebugState :<|> handleDebugTrain :<|> pure (GTFS.gtfsFile gtfs) :<|> gtfsRealtimeServer gtfs dbpool) - :<|> metrics + :<|> handleMetrics :<|> serveDirectoryFileServer (serverConfigAssets settings) :<|> pure (unsafePerformIO (toWaiAppPlain (ControlRoom gtfs dbpool settings))) - where handleTrackerRegister RegisterJson{..} = do - today <- liftIO getCurrentTime <&> utctDay - expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod - runSql dbpool $ do - TrackerKey tracker <- insert (Tracker expires False registerAgent Nothing) - pure tracker - handleTrainPing onError ping@TrainPing{..} = - isTokenValid dbpool trainPingToken >>= \case - Nothing -> onError >> pure Nothing - Just tracker@Tracker{..} -> do - - -- if the tracker is not associated with a ticket, it is probably - -- just starting out on a new trip, or has finished an old one. - maybeTicketId <- case trackerCurrentTicket of - Just ticketId -> pure (Just ticketId) - Nothing -> runSql dbpool $ do - now <- liftIO getCurrentTime - tickets <- selectList [ TicketDay ==. utctDay now, TicketCompleted ==. False ] [] - ticketsWithFirstStation <- flip mapMaybeM tickets - (\ticket@(Entity ticketId _) -> do - selectFirst [StopTicket ==. ticketId] [Asc StopSequence] >>= \case - Nothing -> pure Nothing - Just (Entity _ stop) -> do - station <- getJust (stopStation stop) - tzseries <- liftIO $ getTzseries (GTFS.tzname (stopDeparture stop)) - pure (Just (ticket, station, stop, tzseries))) - - if null ticketsWithFirstStation then pure Nothing else do - let (closestTicket, _, _, _) = minimumBy - -- (compare `on` euclid trainPingGeopos . stationGeopos . snd) - (compare `on` - (\(Entity _ ticket, station, stop, tzseries) -> - let - runningDay = ticketDay ticket - spaceDistance = euclid trainPingGeopos (stationGeopos station) - timeDiff = - GTFS.toSeconds (stopDeparture stop) tzseries runningDay - - utcToSeconds now runningDay - in - euclid trainPingGeopos (stationGeopos station) - + abs (seconds2Double timeDiff / 3600))) - ticketsWithFirstStation - logInfoN - $ "Tracker "+|UUID.toString (coerce trainPingToken)|+ - " is now handling ticket "+|UUID.toString (coerce (entityKey closestTicket))|+ - " (trip "+|ticketTripName (entityVal closestTicket)|+")." - - update (coerce trainPingToken) - [TrackerCurrentTicket =. Just (entityKey closestTicket)] - - pure (Just (entityKey closestTicket)) - - ticketId <- case maybeTicketId of - Just ticketId -> pure ticketId - Nothing -> do - logWarnN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+ - " sent a ping, but no trips are running today." - throwError err400 - - runSql dbpool $ do - ticket@Ticket{..} <- getJust ticketId - - stations <- selectList [ StopTicket ==. ticketId ] [Asc StopArrival] - >>= mapM (\stop -> do - station <- getJust (stopStation (entityVal stop)) - tzseries <- liftIO $ getTzseries (GTFS.tzname (stopArrival (entityVal stop))) - pure (entityVal stop, station, tzseries)) - <&> V.fromList - - shapePoints <- selectList [ShapePointShape ==. ticketShape] [Asc ShapePointIndex] - <&> (V.fromList . fmap entityVal) - - let anchor = extrapolateAnchorFromPing LinearExtrapolator - ticketId ticket stations shapePoints ping - - insert ping - - last <- selectFirst [TrainAnchorTicket ==. ticketId] [Desc TrainAnchorWhen] - -- only insert new estimates if they've actually changed anything - when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor)) - $ void $ insert anchor - - -- are we at the final stop? if so, mark this ticket as done - -- & the tracker as free - let maxSequence = V.last stations - & (\(stop, _, _) -> stopSequence stop) - & fromIntegral - when (trainAnchorSequence anchor + 0.1 >= maxSequence) $ do - update (coerce trainPingToken) - [TrackerCurrentTicket =. Nothing] - update ticketId - [TicketCompleted =. True] - logInfoN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+ - " has completed ticket "+|UUID.toString (coerce ticketId)|+ - " (trip "+|ticketTripName|+")" - - queues <- liftIO $ atomically $ do - queues <- readTVar subscribers <&> M.lookup (coerce ticketId) - whenJust queues $ - mapM_ (\q -> writeTQueue q (Just ping)) - pure queues - pure (Just anchor) - handleWS conn = do - liftIO $ WS.forkPingThread conn 30 - incGauge metricsWSGauge - handle (\(e :: WS.ConnectionException) -> decGauge metricsWSGauge) $ forever $ do - msg <- liftIO $ WS.receiveData conn - case A.eitherDecode msg of - Left err -> do - logWarnN ("stray websocket message: "+|decodeASCII (toStrict msg)|+" (could not decode: "+|err|+")") - liftIO $ WS.sendClose conn (C8.pack err) - -- TODO: send a close msg (Nothing) to the subscribed queues? decGauge metricsWSGauge - Right ping -> do - -- if invalid token, send a "polite" close request. Note that the client may - -- ignore this and continue sending messages, which will continue to be handled. - handleTrainPing (liftIO $ WS.sendClose conn ("" :: ByteString)) ping >>= \case - Just anchor -> liftIO $ WS.sendTextData conn (A.encode anchor) - Nothing -> pure () - handleSubscribe (ticketId :: UUID) conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do - queue <- atomically $ do - queue <- newTQueue - qs <- readTVar subscribers - writeTVar subscribers - $ M.insertWith (<>) ticketId [queue] qs - pure queue - -- send most recent ping, if any (so we won't have to wait for movement) - lastPing <- runSqlWithoutLog dbpool $ do - trackers <- getTicketTrackers ticketId - <&> fmap entityKey - selectFirst [TrainPingToken <-. trackers] [Desc TrainPingTimestamp] - <&> fmap entityVal - whenJust lastPing $ \ping -> - WS.sendTextData conn (A.encode lastPing) - handle (\(e :: WS.ConnectionException) -> removeSubscriber queue) $ forever $ do - res <- atomically $ readTQueue queue - case res of - Just ping -> WS.sendTextData conn (A.encode ping) - Nothing -> do - removeSubscriber queue - WS.sendClose conn (C8.pack "train ended") - where removeSubscriber queue = atomically $ do - qs <- readTVar subscribers - writeTVar subscribers - $ M.adjust (filter (/= queue)) ticketId qs - handleDebugState = do - now <- liftIO getCurrentTime - runSql dbpool $ do - tracker <- selectList [TrackerBlocked ==. False, TrackerExpires >=. now] [] - pairs <- forM tracker $ \(Entity token@(TrackerKey uuid) _) -> do - entities <- selectList [TrainPingToken ==. token] [] - pure (uuid, fmap entityVal entities) - pure (M.fromList pairs) - handleDebugTrain ticketId = do - runSql dbpool $ do - trackers <- getTicketTrackers ticketId - pings <- forM trackers $ \(Entity token _) -> do - selectList [TrainPingToken ==. token] [] <&> fmap entityVal - pure (concat pings) - handleDebugAPI = pure $ toSwagger (Proxy @API) - metrics = exportMetricsAsText <&> (decodeUtf8 . toStrict) + where + handleDebugState = do + now <- liftIO getCurrentTime + runSql dbpool $ do + tracker <- selectList [TrackerBlocked ==. False, TrackerExpires >=. now] [] + pairs <- forM tracker $ \(Entity token@(TrackerKey uuid) _) -> do + entities <- selectList [TrainPingToken ==. token] [] + pure (uuid, fmap entityVal entities) + pure (M.fromList pairs) + handleDebugTrain ticketId = runSql dbpool $ do + trackers <- getTicketTrackers ticketId + pings <- forM trackers $ \(Entity token _) -> do + selectList [TrainPingToken ==. token] [] <&> fmap entityVal + pure (concat pings) + handleDebugAPI = pure $ toSwagger (Proxy @API) + handleMetrics = exportMetricsAsText <&> (decodeUtf8 . toStrict) getTicketTrackers :: (MonadLogger (t (ResourceT IO)), MonadIO (t (ResourceT IO))) => UUID -> ReaderT SqlBackend (t (ResourceT IO)) [Entity Tracker] @@ -290,22 +106,3 @@ getTicketTrackers ticketId = do joins <- selectList [TrackerTicketTicket ==. TicketKey ticketId] [] <&> fmap (trackerTicketTracker . entityVal) selectList ([TrackerId <-. joins] ||. [TrackerCurrentTicket ==. Just (TicketKey ticketId)]) [] - - --- TODO: proper debug logging for expired tokens -isTokenValid :: Pool SqlBackend -> TrackerId -> ServiceM (Maybe Tracker) -isTokenValid dbpool token = runSql dbpool $ get token >>= \case - Just tracker | not (trackerBlocked tracker) -> do - ifM (hasExpired (trackerExpires tracker)) - (pure Nothing) - (pure (Just tracker)) - _ -> pure Nothing - -hasExpired :: MonadIO m => UTCTime -> m Bool -hasExpired limit = do - now <- liftIO getCurrentTime - pure (now > limit) - -validityPeriod :: NominalDiffTime -validityPeriod = nominalDay - |