aboutsummaryrefslogtreecommitdiff
path: root/lib/Server.hs
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Server.hs')
-rw-r--r--lib/Server.hs179
1 files changed, 112 insertions, 67 deletions
diff --git a/lib/Server.hs b/lib/Server.hs
index 3922a7b..73c55cb 100644
--- a/lib/Server.hs
+++ b/lib/Server.hs
@@ -16,12 +16,14 @@ import Control.Concurrent.STM (TQueue, TVar, atomically,
import Control.Monad (forever, unless, void,
when)
import Control.Monad.Catch (handle)
-import Control.Monad.Extra (ifM, maybeM, unlessM,
- whenJust, whenM)
+import Control.Monad.Extra (ifM, mapMaybeM, maybeM,
+ unlessM, whenJust, whenM)
import Control.Monad.IO.Class (MonadIO (liftIO))
-import Control.Monad.Logger (LoggingT, NoLoggingT,
+import Control.Monad.Logger (LoggingT, MonadLogger,
+ NoLoggingT, logInfoN,
logWarnN)
-import Control.Monad.Reader (ReaderT, forM)
+import Control.Monad.Reader (MonadReader, ReaderT,
+ forM)
import Control.Monad.Trans (lift)
import Data.Aeson ((.=))
import qualified Data.Aeson as A
@@ -33,7 +35,7 @@ import Data.Pool (Pool)
import Data.Proxy (Proxy (Proxy))
import Data.Swagger (toSchema)
import Data.Text (Text)
-import Data.Text.Encoding (decodeUtf8)
+import Data.Text.Encoding (decodeASCII, decodeUtf8)
import Data.Time (NominalDiffTime,
UTCTime (utctDay),
addUTCTime, diffUTCTime,
@@ -48,7 +50,8 @@ import Fmt ((+|), (|+))
import qualified Network.WebSockets as WS
import Servant (Application,
ServerError (errBody),
- err401, err404, serve,
+ err400, err401, err404,
+ serve,
serveDirectoryFileServer,
throwError)
import Servant.API (NoContent (..),
@@ -62,24 +65,33 @@ import Persist
import Server.ControlRoom
import Server.GTFS_RT (gtfsRealtimeServer)
import Server.Util (Service, ServiceM,
- runService, sendErrorMsg)
+ runService, sendErrorMsg,
+ utcToSeconds)
import Yesod (toWaiAppPlain)
-import Extrapolation (Extrapolator (..),
- LinearExtrapolator (..))
-import System.IO.Unsafe
-
import Conduit (ResourceT)
-import Config (ServerConfig (serverConfigAssets, serverConfigZoneinfoPath))
+import Config (LoggingConfig,
+ ServerConfig (..))
+import Control.Exception (throw)
import Data.ByteString (ByteString)
import Data.ByteString.Lazy (toStrict)
+import Data.Foldable (minimumBy)
+import Data.Function (on, (&))
+import Data.Maybe (fromMaybe)
import qualified Data.Text as T
import Data.Time.LocalTime.TimeZone.Olson (getTimeZoneSeriesFromOlsonFile)
import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries)
import Data.UUID (UUID)
+import qualified Data.UUID as UUID
+import Extrapolation (Extrapolator (..),
+ LinearExtrapolator (..),
+ euclid)
+import GTFS (Seconds (unSeconds),
+ seconds2Double)
import Prometheus
import Prometheus.Metric.GHC
import System.FilePath ((</>))
+import System.IO.Unsafe
application :: GTFS.GTFS -> Pool SqlBackend -> ServerConfig -> IO Application
application gtfs dbpool settings = do
@@ -94,65 +106,84 @@ application gtfs dbpool settings = do
(serverConfigZoneinfoPath settings </> T.unpack tzname)
subscribers <- newTVarIO mempty
- pure $ serve (Proxy @CompleteAPI) $ hoistServer (Proxy @CompleteAPI) runService
+ pure $ serve (Proxy @CompleteAPI)
+ $ hoistServer (Proxy @CompleteAPI) (runService (serverConfigLogging settings))
$ server gtfs getTzseries metrics subscribers dbpool settings
-- databaseMigration :: ConnectionString -> IO ()
-doMigration pool = runSql pool $ runMigration $ do
+doMigration pool = runSqlWithoutLog pool $ runMigration $ do
migrateEnableExtension "uuid-ossp"
migrateAll
server :: GTFS.GTFS -> (Text -> IO TimeZoneSeries) -> Metrics -> TVar (M.Map UUID [TQueue (Maybe TrainPing)]) -> Pool SqlBackend -> ServerConfig -> Service CompleteAPI
server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI
- :<|> (handleTimetable :<|> handleTimetableStops :<|> handleTrip
- :<|> handleRegister :<|> handleTrainPing (throwError err401) :<|> handleWS
+ :<|> (handleTrackerRegister :<|> handleTrainPing (throwError err401) :<|> handleWS
:<|> handleSubscribe :<|> handleDebugState :<|> handleDebugTrain
- :<|> handleDebugRegister :<|> pure (GTFS.gtfsFile gtfs) :<|> gtfsRealtimeServer gtfs dbpool)
+ :<|> pure (GTFS.gtfsFile gtfs) :<|> gtfsRealtimeServer gtfs dbpool)
:<|> metrics
:<|> serveDirectoryFileServer (serverConfigAssets settings)
:<|> pure (unsafePerformIO (toWaiAppPlain (ControlRoom gtfs dbpool settings)))
- where handleTimetable station maybeDay =
- M.filter isLastStop . GTFS.tripsOnDay gtfs <$> liftIO day
- where isLastStop = (==) station . GTFS.stationId . GTFS.stopStation . V.last . GTFS.tripStops
- day = maybeM (getCurrentTime <&> utctDay) pure (pure maybeDay)
- handleTimetableStops day =
- pure . A.toJSON . fmap mkJson . M.elems $ GTFS.tripsOnDay gtfs day
- where mkJson :: GTFS.Trip GTFS.Deep GTFS.Deep -> A.Value
- mkJson GTFS.Trip {..} = A.object
- [ "trip" .= tripTripId
- , "sequencelength" .= (GTFS.stopSequence . V.last) tripStops
- , "stops" .= fmap (\GTFS.Stop{..} -> A.object
- [ "departure" .= GTFS.toUTC stopDeparture (GTFS.tzseries gtfs) day
- , "arrival" .= GTFS.toUTC stopArrival (GTFS.tzseries gtfs) day
- , "station" .= GTFS.stationId stopStation
- , "lat" .= GTFS.stationLat stopStation
- , "lon" .= GTFS.stationLon stopStation
- ]) tripStops
- ]
- handleTrip trip = case M.lookup trip (GTFS.trips gtfs) of
- Just res -> pure res
- Nothing -> throwError err404
- handleRegister (ticketId :: UUID) RegisterJson{..} = do
+ where handleTrackerRegister RegisterJson{..} = do
today <- liftIO getCurrentTime <&> utctDay
expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod
runSql dbpool $ do
- TrackerKey tracker <- insert (Tracker expires False registerAgent)
- insert (TrackerTicket (TicketKey ticketId) (TrackerKey tracker))
- pure tracker
- handleDebugRegister (ticketId :: UUID) = do
- expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod
- runSql dbpool $ do
- TrackerKey tracker <- insert (Tracker expires False "debug key")
- insert (TrackerTicket (TicketKey ticketId) (TrackerKey tracker))
+ TrackerKey tracker <- insert (Tracker expires False registerAgent Nothing)
pure tracker
handleTrainPing onError ping@TrainPing{..} =
- let ticketId = trainPingTicket in
- isTokenValid dbpool trainPingToken ticketId >>= \case
- Nothing -> do
- onError
- pure Nothing
- Just (tracker@Tracker{..}, ticket@Ticket{..}) -> do
+ isTokenValid dbpool trainPingToken >>= \case
+ Nothing -> onError >> pure Nothing
+ Just tracker@Tracker{..} -> do
+
+ -- if the tracker is not associated with a ticket, it is probably
+ -- just starting out on a new trip, or has finished an old one.
+ maybeTicketId <- case trackerCurrentTicket of
+ Just ticketId -> pure (Just ticketId)
+ Nothing -> runSql dbpool $ do
+ now <- liftIO getCurrentTime
+ tickets <- selectList [ TicketDay ==. utctDay now, TicketCompleted ==. False ] []
+ ticketsWithFirstStation <- flip mapMaybeM tickets
+ (\ticket@(Entity ticketId _) -> do
+ selectFirst [StopTicket ==. ticketId] [Asc StopSequence] >>= \case
+ Nothing -> pure Nothing
+ Just (Entity _ stop) -> do
+ station <- getJust (stopStation stop)
+ tzseries <- liftIO $ getTzseries (GTFS.tzname (stopDeparture stop))
+ pure (Just (ticket, station, stop, tzseries)))
+
+ if null ticketsWithFirstStation then pure Nothing else do
+ let (closestTicket, _, _, _) = minimumBy
+ -- (compare `on` euclid trainPingGeopos . stationGeopos . snd)
+ (compare `on`
+ (\(Entity _ ticket, station, stop, tzseries) ->
+ let
+ runningDay = ticketDay ticket
+ spaceDistance = euclid trainPingGeopos (stationGeopos station)
+ timeDiff =
+ GTFS.toSeconds (stopDeparture stop) tzseries runningDay
+ - utcToSeconds now runningDay
+ in
+ euclid trainPingGeopos (stationGeopos station)
+ + abs (seconds2Double timeDiff / 3600)))
+ ticketsWithFirstStation
+ logInfoN
+ $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
+ " is now handling ticket "+|UUID.toString (coerce (entityKey closestTicket))|+
+ " (trip "+|ticketTripName (entityVal closestTicket)|+")."
+
+ update (coerce trainPingToken)
+ [TrackerCurrentTicket =. Just (entityKey closestTicket)]
+
+ pure (Just (entityKey closestTicket))
+
+ ticketId <- case maybeTicketId of
+ Just ticketId -> pure ticketId
+ Nothing -> do
+ logWarnN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
+ " sent a ping, but no trips are running today."
+ throwError err400
+
runSql dbpool $ do
+ ticket@Ticket{..} <- getJust ticketId
stations <- selectList [ StopTicket ==. ticketId ] [Asc StopArrival]
>>= mapM (\stop -> do
@@ -165,17 +196,31 @@ server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI
<&> (V.fromList . fmap entityVal)
let anchor = extrapolateAnchorFromPing LinearExtrapolator
- ticket stations shapePoints ping
+ ticketId ticket stations shapePoints ping
insert ping
- last <- selectFirst [TrainAnchorTicket ==. trainPingTicket] [Desc TrainAnchorWhen]
+ last <- selectFirst [TrainAnchorTicket ==. ticketId] [Desc TrainAnchorWhen]
-- only insert new estimates if they've actually changed anything
when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor))
$ void $ insert anchor
+ -- are we at the final stop? if so, mark this ticket as done
+ -- & the tracker as free
+ let maxSequence = V.last stations
+ & (\(stop, _, _) -> stopSequence stop)
+ & fromIntegral
+ when (trainAnchorSequence anchor + 0.1 >= maxSequence) $ do
+ update (coerce trainPingToken)
+ [TrackerCurrentTicket =. Nothing]
+ update ticketId
+ [TicketCompleted =. True]
+ logInfoN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
+ " has completed ticket "+|UUID.toString (coerce ticketId)|+
+ " (trip "+|ticketTripName|+")"
+
queues <- liftIO $ atomically $ do
- queues <- readTVar subscribers <&> M.lookup (coerce trainPingTicket)
+ queues <- readTVar subscribers <&> M.lookup (coerce ticketId)
whenJust queues $
mapM_ (\q -> writeTQueue q (Just ping))
pure queues
@@ -187,14 +232,14 @@ server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI
msg <- liftIO $ WS.receiveData conn
case A.eitherDecode msg of
Left err -> do
- logWarnN ("stray websocket message: "+|show msg|+" (could not decode: "+|err|+")")
+ logWarnN ("stray websocket message: "+|decodeASCII (toStrict msg)|+" (could not decode: "+|err|+")")
liftIO $ WS.sendClose conn (C8.pack err)
-- TODO: send a close msg (Nothing) to the subscribed queues? decGauge metricsWSGauge
- Right ping ->
+ Right ping -> do
-- if invalid token, send a "polite" close request. Note that the client may
-- ignore this and continue sending messages, which will continue to be handled.
- liftIO $ handleTrainPing (WS.sendClose conn ("" :: ByteString)) ping >>= \case
- Just anchor -> WS.sendTextData conn (A.encode anchor)
+ handleTrainPing (liftIO $ WS.sendClose conn ("" :: ByteString)) ping >>= \case
+ Just anchor -> liftIO $ WS.sendTextData conn (A.encode anchor)
Nothing -> pure ()
handleSubscribe (ticketId :: UUID) conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do
queue <- atomically $ do
@@ -204,7 +249,7 @@ server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI
$ M.insertWith (<>) ticketId [queue] qs
pure queue
-- send most recent ping, if any (so we won't have to wait for movement)
- lastPing <- runSql dbpool $ do
+ lastPing <- runSqlWithoutLog dbpool $ do
trackers <- getTicketTrackers ticketId
<&> fmap entityKey
selectFirst [TrainPingToken <-. trackers] [Desc TrainPingTimestamp]
@@ -239,21 +284,21 @@ server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI
handleDebugAPI = pure $ toSwagger (Proxy @API)
metrics = exportMetricsAsText <&> (decodeUtf8 . toStrict)
-getTicketTrackers :: UUID -> ReaderT SqlBackend (NoLoggingT (ResourceT IO)) [Entity Tracker]
+getTicketTrackers :: (MonadLogger (t (ResourceT IO)), MonadIO (t (ResourceT IO)))
+ => UUID -> ReaderT SqlBackend (t (ResourceT IO)) [Entity Tracker]
getTicketTrackers ticketId = do
joins <- selectList [TrackerTicketTicket ==. TicketKey ticketId] []
<&> fmap (trackerTicketTracker . entityVal)
- selectList [TrackerId <-. joins] []
+ selectList ([TrackerId <-. joins] ||. [TrackerCurrentTicket ==. Just (TicketKey ticketId)]) []
-- TODO: proper debug logging for expired tokens
-isTokenValid :: MonadIO m => Pool SqlBackend -> TrackerId -> TicketId -> m (Maybe (Tracker, Ticket))
-isTokenValid dbpool token ticketId = runSql dbpool $ get token >>= \case
+isTokenValid :: Pool SqlBackend -> TrackerId -> ServiceM (Maybe Tracker)
+isTokenValid dbpool token = runSql dbpool $ get token >>= \case
Just tracker | not (trackerBlocked tracker) -> do
ifM (hasExpired (trackerExpires tracker))
(pure Nothing)
- $ runSql dbpool $ get ticketId
- <&> (\case { Nothing -> Nothing; Just ticket -> Just (tracker, ticket) })
+ (pure (Just tracker))
_ -> pure Nothing
hasExpired :: MonadIO m => UTCTime -> m Bool