aboutsummaryrefslogtreecommitdiff
path: root/lib/Server.hs
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Server.hs')
-rw-r--r--lib/Server.hs91
1 files changed, 50 insertions, 41 deletions
diff --git a/lib/Server.hs b/lib/Server.hs
index 016707b..c6d2d94 100644
--- a/lib/Server.hs
+++ b/lib/Server.hs
@@ -1,8 +1,9 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE ExplicitNamespaces #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE OverloadedLists #-}
-{-# LANGUAGE RecordWildCards #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ExplicitNamespaces #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE OverloadedLists #-}
+{-# LANGUAGE PartialTypeSignatures #-}
+{-# LANGUAGE RecordWildCards #-}
-- Implementation of the API. This module is the main point of the program.
@@ -16,8 +17,8 @@ import Control.Monad.Catch (handle)
import Control.Monad.Extra (ifM, maybeM, unlessM, whenJust,
whenM)
import Control.Monad.IO.Class (MonadIO (liftIO))
-import Control.Monad.Logger (LoggingT, logWarnN)
-import Control.Monad.Reader (forM)
+import Control.Monad.Logger (LoggingT, NoLoggingT, logWarnN)
+import Control.Monad.Reader (ReaderT, forM)
import Control.Monad.Trans (lift)
import Data.Aeson ((.=))
import qualified Data.Aeson as A
@@ -61,9 +62,11 @@ import Extrapolation (Extrapolator (..),
LinearExtrapolator (..))
import System.IO.Unsafe
+import Conduit (ResourceT)
import Config (ServerConfig (serverConfigAssets))
import Data.ByteString (ByteString)
import Data.ByteString.Lazy (toStrict)
+import Data.UUID (UUID)
import Prometheus
import Prometheus.Metric.GHC
@@ -83,7 +86,7 @@ doMigration pool = runSql pool $
-- returns an empty list
runMigration migrateAll
-server :: GTFS -> Metrics -> TVar (M.Map TripID [TQueue (Maybe TrainPing)]) -> Pool SqlBackend -> ServerConfig -> Service CompleteAPI
+server :: GTFS -> Metrics -> TVar (M.Map UUID [TQueue (Maybe TrainPing)]) -> Pool SqlBackend -> ServerConfig -> Service CompleteAPI
server gtfs@GTFS{..} Metrics{..} subscribers dbpool settings = handleDebugAPI
:<|> (handleStations :<|> handleTimetable :<|> handleTimetableStops :<|> handleTrip
:<|> handleRegister :<|> handleTrainPing (throwError err401) :<|> handleWS
@@ -101,7 +104,7 @@ server gtfs@GTFS{..} Metrics{..} subscribers dbpool settings = handleDebugAPI
pure . A.toJSON . fmap mkJson . M.elems $ tripsOnDay gtfs day
where mkJson :: Trip Deep Deep -> A.Value
mkJson Trip {..} = A.object
- [ "trip" .= tripTripID
+ [ "trip" .= tripTripId
, "sequencelength" .= (stopSequence . V.last) tripStops
, "stops" .= fmap (\Stop{..} -> A.object
[ "departure" .= toUTC stopDeparture tzseries day
@@ -114,34 +117,35 @@ server gtfs@GTFS{..} Metrics{..} subscribers dbpool settings = handleDebugAPI
handleTrip trip = case M.lookup trip trips of
Just res -> pure res
Nothing -> throwError err404
- handleRegister tripID RegisterJson{..} = do
+ handleRegister (ticketId :: UUID) RegisterJson{..} = do
today <- liftIO getCurrentTime <&> utctDay
- unless (runsOnDay gtfs tripID today)
- $ sendErrorMsg "this trip does not run today."
expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod
- RunningKey token <- runSql dbpool $ insert (Running expires False tripID today Nothing registerAgent)
- pure token
- handleDebugRegister tripID day = do
+ runSql dbpool $ do
+ TrackerKey tracker <- insert (Tracker expires False registerAgent)
+ insert (TrackerTicket (TicketKey ticketId) (TrackerKey tracker))
+ pure tracker
+ handleDebugRegister (ticketId :: UUID) = do
expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod
- RunningKey token <- runSql dbpool $ insert (Running expires False tripID day Nothing "debug key")
- pure token
- handleTrainPing onError ping = isTokenValid dbpool (coerce $ trainPingToken ping) >>= \case
+ runSql dbpool $ do
+ TrackerKey tracker <- insert (Tracker expires False "debug key")
+ insert (TrackerTicket (TicketKey ticketId) (TrackerKey tracker))
+ pure tracker
+ handleTrainPing onError ping@TrainPing{..} = isTokenValid dbpool trainPingToken trainPingTicket
+ >>= \case
Nothing -> do
onError
pure Nothing
- Just running@Running{..} -> do
- let anchor = extrapolateAnchorFromPing LinearExtrapolator gtfs running ping
+ Just (tracker@Tracker{..}, ticket@Ticket{..}) -> do
+ let anchor = extrapolateAnchorFromPing LinearExtrapolator gtfs ticket ping
-- TODO: are these always inserted in order?
runSql dbpool $ do
insert ping
- last <- selectFirst
- [TrainAnchorTrip ==. runningTrip, TrainAnchorDay ==. runningDay]
- [Desc TrainAnchorWhen]
+ last <- selectFirst [TrainAnchorTicket ==. trainPingTicket] [Desc TrainAnchorWhen]
-- only insert new estimates if they've actually changed anything
when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor))
$ void $ insert anchor
queues <- liftIO $ atomically $ do
- queues <- readTVar subscribers <&> M.lookup runningTrip
+ queues <- readTVar subscribers <&> M.lookup (coerce trainPingTicket)
whenJust queues $
mapM_ (\q -> writeTQueue q (Just ping))
pure queues
@@ -162,18 +166,18 @@ server gtfs@GTFS{..} Metrics{..} subscribers dbpool settings = handleDebugAPI
liftIO $ handleTrainPing (WS.sendClose conn ("" :: ByteString)) ping >>= \case
Just anchor -> WS.sendTextData conn (A.encode anchor)
Nothing -> pure ()
- handleSubscribe tripId day conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do
+ handleSubscribe (ticketId :: UUID) conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do
queue <- atomically $ do
queue <- newTQueue
qs <- readTVar subscribers
writeTVar subscribers
- $ M.insertWith (<>) tripId [queue] qs
+ $ M.insertWith (<>) ticketId [queue] qs
pure queue
-- send most recent ping, if any (so we won't have to wait for movement)
lastPing <- runSql dbpool $ do
- tokens <- selectList [RunningDay ==. day, RunningTrip ==. tripId] []
+ trackers <- getTicketTrackers ticketId
<&> fmap entityKey
- selectFirst [TrainPingToken <-. tokens] [Desc TrainPingTimestamp]
+ selectFirst [TrainPingToken <-. trackers] [Desc TrainPingTimestamp]
<&> fmap entityVal
whenJust lastPing $ \ping ->
WS.sendTextData conn (A.encode lastPing)
@@ -187,34 +191,39 @@ server gtfs@GTFS{..} Metrics{..} subscribers dbpool settings = handleDebugAPI
where removeSubscriber queue = atomically $ do
qs <- readTVar subscribers
writeTVar subscribers
- $ M.adjust (filter (/= queue)) tripId qs
+ $ M.adjust (filter (/= queue)) ticketId qs
handleDebugState = do
now <- liftIO getCurrentTime
runSql dbpool $ do
- running <- selectList [RunningBlocked ==. False, RunningExpires >=. now] []
- pairs <- forM running $ \(Entity token@(RunningKey uuid) _) -> do
+ tracker <- selectList [TrackerBlocked ==. False, TrackerExpires >=. now] []
+ pairs <- forM tracker $ \(Entity token@(TrackerKey uuid) _) -> do
entities <- selectList [TrainPingToken ==. token] []
pure (uuid, fmap entityVal entities)
pure (M.fromList pairs)
- handleDebugTrain tripId day = do
- unless (runsOnDay gtfs tripId day)
- $ sendErrorMsg ("this trip does not run on "+|day|+".")
+ handleDebugTrain ticketId = do
runSql dbpool $ do
- tokens <- selectList [RunningTrip ==. tripId, RunningDay ==. day] []
- pings <- forM tokens $ \(Entity token _) -> do
+ trackers <- getTicketTrackers ticketId
+ pings <- forM trackers $ \(Entity token _) -> do
selectList [TrainPingToken ==. token] [] <&> fmap entityVal
pure (concat pings)
handleDebugAPI = pure $ toSwagger (Proxy @API)
metrics = exportMetricsAsText <&> (decodeUtf8 . toStrict)
+getTicketTrackers :: UUID -> ReaderT SqlBackend (NoLoggingT (ResourceT IO)) [Entity Tracker]
+getTicketTrackers ticketId = do
+ joins <- selectList [TrackerTicketTicket ==. TicketKey ticketId] []
+ <&> fmap (trackerTicketTracker . entityVal)
+ selectList [TrackerId <-. joins] []
+
-- TODO: proper debug logging for expired tokens
-isTokenValid :: MonadIO m => Pool SqlBackend -> Token -> m (Maybe Running)
-isTokenValid dbpool token = runSql dbpool $ get (coerce token) >>= \case
- Just trip | not (runningBlocked trip) -> do
- ifM (hasExpired (runningExpires trip))
+isTokenValid :: MonadIO m => Pool SqlBackend -> TrackerId -> TicketId -> m (Maybe (Tracker, Ticket))
+isTokenValid dbpool token ticketId = runSql dbpool $ get token >>= \case
+ Just tracker | not (trackerBlocked tracker) -> do
+ ifM (hasExpired (trackerExpires tracker))
(pure Nothing)
- (pure (Just trip))
+ $ runSql dbpool $ get ticketId
+ <&> (\case { Nothing -> Nothing; Just ticket -> Just (tracker, ticket) })
_ -> pure Nothing
hasExpired :: MonadIO m => UTCTime -> m Bool