{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}

-- | Implementation of the API. This module is the main point of the program.
module Server (application) where

import Control.Concurrent.STM (TQueue, TVar, atomically, newTQueue, newTVar, newTVarIO, readTQueue, readTVar, writeTQueue, writeTVar)
import Control.Monad (forever, unless, void, when)
import Control.Monad.Catch (handle)
import Control.Monad.Extra (ifM, mapMaybeM, maybeM, unlessM, whenJust, whenM)
import Control.Monad.IO.Class (MonadIO (liftIO))
import Control.Monad.Logger (LoggingT, MonadLogger, NoLoggingT, logInfoN, logWarnN)
import Control.Monad.Reader (MonadReader, ReaderT, forM)
import Control.Monad.Trans (lift)
import Data.Aeson ((.=))
import qualified Data.Aeson as A
import qualified Data.ByteString.Char8 as C8
import Data.Coerce (coerce)
import Data.Functor ((<&>))
import qualified Data.Map as M
import Data.Pool (Pool)
import Data.Proxy (Proxy (Proxy))
import Data.Swagger (toSchema)
import Data.Text (Text)
import Data.Text.Encoding (decodeASCII, decodeUtf8)
import Data.Time (NominalDiffTime, UTCTime (utctDay), addUTCTime, diffUTCTime, getCurrentTime, nominalDay)
import qualified Data.Vector as V
import Database.Persist
import Database.Persist.Postgresql (SqlBackend, migrateEnableExtension, runMigration)
import Fmt ((+|), (|+))
import qualified Network.WebSockets as WS
import Servant (Application, ServerError (errBody), err400, err401, err404, serve, serveDirectoryFileServer, throwError)
import Servant.API (NoContent (..), (:<|>) (..))
import Servant.Server (Handler, hoistServer)
import Servant.Swagger (toSwagger)

import API
import qualified GTFS
import Persist
import Server.ControlRoom
import Server.GTFS_RT (gtfsRealtimeServer)
import Server.Util (Service, ServiceM, runService, sendErrorMsg, utcToSeconds)
import Yesod (toWaiAppPlain)

import Conduit (ResourceT)
import Config (LoggingConfig, ServerConfig (..))
import Control.Exception (throw)
import Data.ByteString (ByteString)
import Data.ByteString.Lazy (toStrict)
import Data.Foldable (minimumBy)
import Data.Function (on, (&))
import Data.Maybe (fromMaybe)
import qualified Data.Text as T
import Data.Time.LocalTime.TimeZone.Olson (getTimeZoneSeriesFromOlsonFile)
import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries)
import Data.UUID (UUID)
import qualified Data.UUID as UUID
import Extrapolation (Extrapolator (..), LinearExtrapolator (..), euclid)
import GTFS (Seconds (unSeconds), seconds2Double)
import Prometheus
import Prometheus.Metric.GHC
import System.FilePath ((</>))
import System.IO.Unsafe

-- | Build the WAI 'Application' for the whole service.
--
-- Runs the database migration, registers the websocket-connection gauge and
-- the GHC runtime metrics, creates the (initially empty) subscriber map and
-- then serves 'CompleteAPI' via 'server'.
application :: GTFS.GTFS -> Pool SqlBackend -> ServerConfig -> IO Application
application gtfs dbpool settings = do
  doMigration dbpool
  metrics <- Metrics
    <$> register (gauge (Info "ws_connections" "Number of WS Connections"))
  register ghcMetrics

  -- TODO: maybe cache these in a TVar, we're not likely to ever need
  -- more than one of these
  let getTzseries tzname = getTimeZoneSeriesFromOlsonFile
        (serverConfigZoneinfoPath settings </> T.unpack tzname)

  subscribers <- newTVarIO mempty
  pure $ serve (Proxy @CompleteAPI)
    $ hoistServer (Proxy @CompleteAPI) (runService (serverConfigLogging settings))
    $ server gtfs getTzseries metrics subscribers dbpool settings

-- | Run the schema migration ('migrateAll') on the given pool, after enabling
-- the @uuid-ossp@ extension.
-- databaseMigration :: ConnectionString -> IO ()
doMigration pool = runSqlWithoutLog pool $ runMigration $ do
  migrateEnableExtension "uuid-ossp"
  migrateAll

-- | Handlers for 'CompleteAPI'.
--
-- The 'TVar' maps a ticket's UUID to the queues of all websocket clients
-- currently subscribed to that ticket; 'handleTrainPing' writes each accepted
-- ping to those queues and 'handleSubscribe' reads from them.
server
  :: GTFS.GTFS
  -> (Text -> IO TimeZoneSeries)
  -> Metrics
  -> TVar (M.Map UUID [TQueue (Maybe TrainPing)])
  -> Pool SqlBackend
  -> ServerConfig
  -> Service CompleteAPI
server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI
  :<|> (handleTrackerRegister :<|> handleTrainPing (throwError err401) :<|> handleWS
        :<|> handleSubscribe :<|> handleDebugState :<|> handleDebugTrain
        :<|> pure (GTFS.gtfsFile gtfs) :<|> gtfsRealtimeServer gtfs dbpool)
  :<|> metrics
  :<|> serveDirectoryFileServer (serverConfigAssets settings)
  -- NOTE(review): unsafePerformIO is used here to obtain the Yesod WAI app in
  -- a non-IO position; presumably it should be built once in 'application'
  -- instead -- confirm before changing.
  :<|> pure (unsafePerformIO (toWaiAppPlain (ControlRoom gtfs dbpool settings)))
  where
    -- Register a new tracker which expires 'validityPeriod' from now, and
    -- return its key.
    handleTrackerRegister RegisterJson{..} = do
      expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod
      runSql dbpool $ do
        TrackerKey tracker <- insert (Tracker expires False registerAgent Nothing)
        pure tracker

    -- Process one ping. @onError@ is run (and 'Nothing' returned) if the
    -- ping's token is not valid; otherwise the ping is stored, a new anchor
    -- is extrapolated and returned, and the ping is forwarded to all
    -- subscribers of the ticket.
    handleTrainPing onError ping@TrainPing{..} =
      isTokenValid dbpool trainPingToken >>= \case
        Nothing -> onError >> pure Nothing
        Just Tracker{..} -> do
          -- if the tracker is not associated with a ticket, it is probably
          -- just starting out on a new trip, or has finished an old one.
          maybeTicketId <- case trackerCurrentTicket of
            Just ticketId -> pure (Just ticketId)
            Nothing -> runSql dbpool $ do
              now <- liftIO getCurrentTime
              tickets <- selectList [ TicketDay ==. utctDay now, TicketCompleted ==. False ] []
              ticketsWithFirstStation <- flip mapMaybeM tickets
                (\ticket@(Entity ticketId _) ->
                  selectFirst [StopTicket ==. ticketId] [Asc StopSequence] >>= \case
                    Nothing -> pure Nothing
                    Just (Entity _ stop) -> do
                      station <- getJust (stopStation stop)
                      tzseries <- liftIO $ getTzseries (GTFS.tzname (stopDeparture stop))
                      pure (Just (ticket, station, stop, tzseries)))

              if null ticketsWithFirstStation then pure Nothing else do
                -- Pick the ticket whose first stop is "closest" to the ping:
                -- euclidean distance to the station plus the absolute time
                -- difference to its departure, scaled by 1/3600.
                let score (Entity _ ticket, station, stop, tzseries) =
                      let runningDay = ticketDay ticket
                          spaceDistance = euclid trainPingGeopos (stationGeopos station)
                          timeDiff = GTFS.toSeconds (stopDeparture stop) tzseries runningDay
                            - utcToSeconds now runningDay
                      in spaceDistance + abs (seconds2Double timeDiff / 3600)
                    -- minimumBy is safe: the list was checked to be non-empty above
                    (closestTicket, _, _, _) =
                      minimumBy (compare `on` score) ticketsWithFirstStation
                logInfoN
                  $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
                    " is now handling ticket "+|UUID.toString (coerce (entityKey closestTicket))|+
                    " (trip "+|ticketTripName (entityVal closestTicket)|+")."

                update (coerce trainPingToken)
                  [TrackerCurrentTicket =. Just (entityKey closestTicket)]

                pure (Just (entityKey closestTicket))

          ticketId <- case maybeTicketId of
            Just ticketId -> pure ticketId
            Nothing -> do
              logWarnN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
                " sent a ping, but no trips are running today."
              throwError err400

          runSql dbpool $ do
            ticket@Ticket{..} <- getJust ticketId

            stations <- selectList [ StopTicket ==. ticketId ] [Asc StopArrival]
              >>= mapM (\stop -> do
                station <- getJust (stopStation (entityVal stop))
                tzseries <- liftIO $ getTzseries (GTFS.tzname (stopArrival (entityVal stop)))
                pure (entityVal stop, station, tzseries))
              <&> V.fromList

            shapePoints <- selectList [ShapePointShape ==. ticketShape] [Asc ShapePointIndex]
              <&> (V.fromList . fmap entityVal)

            let anchor = extrapolateAnchorFromPing LinearExtrapolator ticketId ticket stations shapePoints ping

            void $ insert ping

            lastAnchor <- selectFirst [TrainAnchorTicket ==. ticketId] [Desc TrainAnchorWhen]
            -- only insert new estimates if they've actually changed anything
            when (fmap (trainAnchorDelay . entityVal) lastAnchor /= Just (trainAnchorDelay anchor))
              $ void $ insert anchor

            -- are we at the final stop? if so, mark this ticket as done
            -- & the tracker as free.
            -- The last stop is looked up with a total index operation, so a
            -- ticket without any stops is simply never considered finished
            -- here (instead of throwing from V.last).
            let atFinalStop = case stations V.!? (V.length stations - 1) of
                  Nothing -> False
                  Just (stop, _, _) ->
                    trainAnchorSequence anchor + 0.1 >= fromIntegral (stopSequence stop)

            when atFinalStop $ do
              update (coerce trainPingToken)
                [TrackerCurrentTicket =. Nothing]
              update ticketId
                [TicketCompleted =. True]
              logInfoN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
                " has completed ticket "+|UUID.toString (coerce ticketId)|+
                " (trip "+|ticketTripName|+")"

            -- forward the ping to everyone subscribed to this ticket
            liftIO $ atomically $ do
              queues <- readTVar subscribers <&> M.lookup (coerce ticketId)
              whenJust queues $ mapM_ (\q -> writeTQueue q (Just ping))

            pure (Just anchor)

    -- Websocket variant of the ping endpoint: every received message is
    -- decoded as a ping and passed to 'handleTrainPing'; the resulting anchor
    -- (if any) is sent back to the client.
    handleWS conn = do
      liftIO $ WS.forkPingThread conn 30
      incGauge metricsWSGauge
      -- The gauge is decremented exactly once, by this handler, when the
      -- connection ends with a ConnectionException.
      handle (\(_ :: WS.ConnectionException) -> decGauge metricsWSGauge) $ forever $ do
        msg <- liftIO $ WS.receiveData conn
        case A.eitherDecode msg of
          Left err -> do
            logWarnN ("stray websocket message: "+|decodeASCII (toStrict msg)|+" (could not decode: "+|err|+")")
            -- Do not decrement the gauge here: the loop keeps running after
            -- the close request, and the handler above decrements it once the
            -- connection actually ends; decrementing in both places would
            -- count the same connection twice.
            liftIO $ WS.sendClose conn (C8.pack err)
            -- TODO: send a close msg (Nothing) to the subscribed queues?
          Right ping -> do
            -- if invalid token, send a "polite" close request. Note that the client may
            -- ignore this and continue sending messages, which will continue to be handled.
            handleTrainPing (liftIO $ WS.sendClose conn ("" :: ByteString)) ping >>= \case
              Just anchor -> liftIO $ WS.sendTextData conn (A.encode anchor)
              Nothing -> pure ()

    -- Websocket endpoint streaming the pings of one ticket: registers a fresh
    -- queue in 'subscribers', sends the most recent stored ping (if any) and
    -- then relays everything written to the queue.
    handleSubscribe (ticketId :: UUID) conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do
      queue <- atomically $ do
        queue <- newTQueue
        qs <- readTVar subscribers
        writeTVar subscribers
          $ M.insertWith (<>) ticketId [queue] qs
        pure queue
      lastPing <- runSqlWithoutLog dbpool $ do
        trackers <- getTicketTrackers ticketId
          <&> fmap entityKey
        selectFirst [TrainPingToken <-. trackers] [Desc TrainPingTimestamp]
          <&> fmap entityVal
      -- The initial send runs inside the handler as well, so that a
      -- connection which dies before the loop starts still has its queue
      -- removed from 'subscribers'.
      handle (\(_ :: WS.ConnectionException) -> removeSubscriber queue) $ do
        -- send most recent ping, if any (so we won't have to wait for movement)
        whenJust lastPing $ \ping ->
          WS.sendTextData conn (A.encode ping)
        forever $ do
          res <- atomically $ readTQueue queue
          case res of
            Just ping -> WS.sendTextData conn (A.encode ping)
            Nothing -> do
              removeSubscriber queue
              WS.sendClose conn (C8.pack "train ended")
      where
        -- drop the given queue from this ticket's subscriber list
        removeSubscriber queue = atomically $ do
          qs <- readTVar subscribers
          writeTVar subscribers
            $ M.adjust (filter (/= queue)) ticketId qs

    -- All pings of every tracker that is neither blocked nor expired, keyed
    -- by the tracker's UUID.
    handleDebugState = do
      now <- liftIO getCurrentTime
      runSql dbpool $ do
        tracker <- selectList [TrackerBlocked ==. False, TrackerExpires >=. now] []
        pairs <- forM tracker $ \(Entity token@(TrackerKey uuid) _) -> do
          entities <- selectList [TrainPingToken ==. token] []
          pure (uuid, fmap entityVal entities)
        pure (M.fromList pairs)

    -- All pings of every tracker associated with the given ticket.
    handleDebugTrain ticketId =
      runSql dbpool $ do
        trackers <- getTicketTrackers ticketId
        pings <- forM trackers $ \(Entity token _) ->
          selectList [TrainPingToken ==. token] [] <&> fmap entityVal
        pure (concat pings)

    -- Swagger description of 'API'.
    handleDebugAPI = pure $ toSwagger (Proxy @API)

    -- Prometheus metrics in text format.
    metrics = exportMetricsAsText <&> (decodeUtf8 . toStrict)

-- | All trackers associated with a ticket: those linked to it through a
-- 'TrackerTicket' row, plus those whose current ticket it is.
getTicketTrackers
  :: (MonadLogger (t (ResourceT IO)), MonadIO (t (ResourceT IO)))
  => UUID
  -> ReaderT SqlBackend (t (ResourceT IO)) [Entity Tracker]
getTicketTrackers ticketId = do
  joins <- selectList [TrackerTicketTicket ==. TicketKey ticketId] []
    <&> fmap (trackerTicketTracker . entityVal)
  selectList ([TrackerId <-. joins] ||. [TrackerCurrentTicket ==. Just (TicketKey ticketId)]) []

-- TODO: proper debug logging for expired tokens
-- | Look up a tracker token. Returns the tracker only if it exists, is not
-- blocked and has not expired; 'Nothing' otherwise.
isTokenValid :: Pool SqlBackend -> TrackerId -> ServiceM (Maybe Tracker)
isTokenValid dbpool token = runSql dbpool $ get token >>= \case
  Just tracker | not (trackerBlocked tracker) ->
    ifM (hasExpired (trackerExpires tracker))
      (pure Nothing)
      (pure (Just tracker))
  _ -> pure Nothing

-- | Whether the current time is strictly later than the given limit.
hasExpired :: MonadIO m => UTCTime -> m Bool
hasExpired limit = do
  now <- liftIO getCurrentTime
  pure (now > limit)

-- | How long a newly registered tracker stays valid (one day).
validityPeriod :: NominalDiffTime
validityPeriod = nominalDay