aboutsummaryrefslogtreecommitdiff
path: root/lib/Server
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Server')
-rw-r--r--lib/Server/Base.hs9
-rw-r--r--lib/Server/Ingest.hs211
-rw-r--r--lib/Server/Subscribe.hs63
-rw-r--r--lib/Server/Util.hs64
4 files changed, 321 insertions, 26 deletions
diff --git a/lib/Server/Base.hs b/lib/Server/Base.hs
new file mode 100644
index 0000000..14b77ca
--- /dev/null
+++ b/lib/Server/Base.hs
@@ -0,0 +1,9 @@
+
+module Server.Base (ServerState) where
+
+import Control.Concurrent.STM (TQueue, TVar)
+import qualified Data.Map as M
+import Data.UUID (UUID)
+import Persist (TrainPing)
+
+type ServerState = TVar (M.Map UUID [TQueue (Maybe TrainPing)])
diff --git a/lib/Server/Ingest.hs b/lib/Server/Ingest.hs
new file mode 100644
index 0000000..d774017
--- /dev/null
+++ b/lib/Server/Ingest.hs
@@ -0,0 +1,211 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RecordWildCards #-}
+
+module Server.Ingest (handleTrackerRegister, handleTrainPing, handleWS) where
+import API (Metrics (..),
+ RegisterJson (..))
+import Control.Concurrent.STM (atomically, readTVar,
+ writeTQueue)
+import Control.Monad (forever, void, when)
+import Control.Monad.Catch (handle)
+import Control.Monad.Extra (ifM, mapMaybeM, whenJust)
+import Control.Monad.IO.Class (MonadIO (liftIO))
+import Control.Monad.Logger (LoggingT, logInfoN,
+ logWarnN)
+import Control.Monad.Reader (ReaderT)
+import qualified Data.Aeson as A
+import qualified Data.ByteString.Char8 as C8
+import Data.Coerce (coerce)
+import Data.Functor ((<&>))
+import qualified Data.Map as M
+import Data.Pool (Pool)
+import Data.Text (Text)
+import Data.Text.Encoding (decodeASCII, decodeUtf8)
+import Data.Time (NominalDiffTime,
+ UTCTime (..), addUTCTime,
+ getCurrentTime,
+ nominalDay)
+import qualified Data.Vector as V
+import Database.Persist
+import Database.Persist.Postgresql (SqlBackend)
+import Fmt ((+|), (|+))
+import qualified GTFS
+import qualified Network.WebSockets as WS
+import Persist
+import Servant (err400, throwError)
+import Servant.Server (Handler)
+import Server.Util (ServiceM, getTzseries,
+ utcToSeconds)
+
+import Config (LoggingConfig,
+ ServerConfig)
+import Data.ByteString (ByteString)
+import Data.ByteString.Lazy (toStrict)
+import Data.Foldable (minimumBy)
+import Data.Function (on, (&))
+import qualified Data.Text as T
+import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries)
+import qualified Data.UUID as UUID
+import Extrapolation (Extrapolator (..),
+ LinearExtrapolator (..),
+ euclid)
+import GTFS (seconds2Double)
+import Prometheus (decGauge, incGauge)
+import Server.Base (ServerState)
+
+
+
+handleTrackerRegister
+ :: Pool SqlBackend
+ -> RegisterJson
+ -> ServiceM Token
+handleTrackerRegister dbpool RegisterJson{..} = do
+ today <- liftIO getCurrentTime <&> utctDay
+ expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod
+ runSql dbpool $ do
+ TrackerKey tracker <- insert (Tracker expires False registerAgent Nothing)
+ pure tracker
+ where
+ validityPeriod :: NominalDiffTime
+ validityPeriod = nominalDay
+
+handleTrainPing
+ :: Pool SqlBackend
+ -> ServerState
+ -> ServerConfig
+ -> LoggingT (ReaderT LoggingConfig Handler) a
+ -> TrainPing
+ -> LoggingT (ReaderT LoggingConfig Handler) (Maybe TrainAnchor)
+handleTrainPing dbpool subscribers cfg onError ping@TrainPing{..} =
+ isTokenValid dbpool trainPingToken >>= \case
+ Nothing -> onError >> pure Nothing
+ Just tracker@Tracker{..} -> do
+
+ -- if the tracker is not associated with a ticket, it is probably
+ -- just starting out on a new trip, or has finished an old one.
+ maybeTicketId <- case trackerCurrentTicket of
+ Just ticketId -> pure (Just ticketId)
+ Nothing -> runSql dbpool $ do
+ now <- liftIO getCurrentTime
+ tickets <- selectList [ TicketDay ==. utctDay now, TicketCompleted ==. False ] []
+ ticketsWithFirstStation <- flip mapMaybeM tickets
+ (\ticket@(Entity ticketId _) -> do
+ selectFirst [StopTicket ==. ticketId] [Asc StopSequence] >>= \case
+ Nothing -> pure Nothing
+ Just (Entity _ stop) -> do
+ station <- getJust (stopStation stop)
+ tzseries <- liftIO $ getTzseries cfg (GTFS.tzname (stopDeparture stop))
+ pure (Just (ticket, station, stop, tzseries)))
+
+ if null ticketsWithFirstStation then pure Nothing else do
+ let (closestTicket, _, _, _) = minimumBy
+ -- (compare `on` euclid trainPingGeopos . stationGeopos . snd)
+ (compare `on`
+ (\(Entity _ ticket, station, stop, tzseries) ->
+ let
+ runningDay = ticketDay ticket
+ spaceDistance = euclid trainPingGeopos (stationGeopos station)
+ timeDiff =
+ GTFS.toSeconds (stopDeparture stop) tzseries runningDay
+ - utcToSeconds now runningDay
+ in
+ euclid trainPingGeopos (stationGeopos station)
+ + abs (seconds2Double timeDiff / 3600)))
+ ticketsWithFirstStation
+ logInfoN
+ $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
+ " is now handling ticket "+|UUID.toString (coerce (entityKey closestTicket))|+
+ " (trip "+|ticketTripName (entityVal closestTicket)|+")."
+
+ update trainPingToken
+ [TrackerCurrentTicket =. Just (entityKey closestTicket)]
+
+ pure (Just (entityKey closestTicket))
+
+ ticketId <- case maybeTicketId of
+ Just ticketId -> pure ticketId
+ Nothing -> do
+ logWarnN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
+ " sent a ping, but no trips are running today."
+ throwError err400
+
+ runSql dbpool $ do
+ ticket@Ticket{..} <- getJust ticketId
+
+ stations <- selectList [ StopTicket ==. ticketId ] [Asc StopArrival]
+ >>= mapM (\stop -> do
+ station <- getJust (stopStation (entityVal stop))
+ tzseries <- liftIO $ getTzseries cfg (GTFS.tzname (stopArrival (entityVal stop)))
+ pure (entityVal stop, station, tzseries))
+ <&> V.fromList
+
+ shapePoints <- selectList [ShapePointShape ==. ticketShape] [Asc ShapePointIndex]
+ <&> (V.fromList . fmap entityVal)
+
+ let anchor = extrapolateAnchorFromPing LinearExtrapolator
+ ticketId ticket stations shapePoints ping
+
+ insert ping
+
+ last <- selectFirst [TrainAnchorTicket ==. ticketId] [Desc TrainAnchorWhen]
+ -- only insert new estimates if they've actually changed anything
+ when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor))
+ $ void $ insert anchor
+
+ -- are we at the final stop? if so, mark this ticket as done
+ -- & the tracker as free
+ let maxSequence = V.last stations
+ & (\(stop, _, _) -> stopSequence stop)
+ & fromIntegral
+ when (trainAnchorSequence anchor + 0.1 >= maxSequence) $ do
+ update trainPingToken
+ [TrackerCurrentTicket =. Nothing]
+ update ticketId
+ [TicketCompleted =. True]
+ logInfoN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
+ " has completed ticket "+|UUID.toString (coerce ticketId)|+
+ " (trip "+|ticketTripName|+")"
+
+ queues <- liftIO $ atomically $ do
+ queues <- readTVar subscribers <&> M.lookup (coerce ticketId)
+ whenJust queues $
+ mapM_ (\q -> writeTQueue q (Just ping))
+ pure queues
+ pure (Just anchor)
+
+handleWS
+ :: Pool SqlBackend
+ -> ServerState
+ -> ServerConfig
+ -> Metrics
+ -> WS.Connection -> ServiceM ()
+handleWS dbpool subscribers cfg Metrics{..} conn = do
+ liftIO $ WS.forkPingThread conn 30
+ incGauge metricsWSGauge
+ handle (\(e :: WS.ConnectionException) -> decGauge metricsWSGauge) $ forever $ do
+ msg <- liftIO $ WS.receiveData conn
+ case A.eitherDecode msg of
+ Left err -> do
+ logWarnN ("stray websocket message: "+|decodeASCII (toStrict msg)|+" (could not decode: "+|err|+")")
+ liftIO $ WS.sendClose conn (C8.pack err)
+ -- TODO: send a close msg (Nothing) to the subscribed queues? decGauge metricsWSGauge
+ Right ping -> do
+ -- if invalid token, send a "polite" close request. Note that the client may
+ -- ignore this and continue sending messages, which will continue to be handled.
+ handleTrainPing dbpool subscribers cfg (liftIO $ WS.sendClose conn ("" :: ByteString)) ping >>= \case
+ Just anchor -> liftIO $ WS.sendTextData conn (A.encode anchor)
+ Nothing -> pure ()
+
+-- TODO: proper debug logging for expired tokens
+isTokenValid :: Pool SqlBackend -> TrackerId -> ServiceM (Maybe Tracker)
+isTokenValid dbpool token = runSql dbpool $ get token >>= \case
+ Just tracker | not (trackerBlocked tracker) -> do
+ ifM (hasExpired (trackerExpires tracker))
+ (pure Nothing)
+ (pure (Just tracker))
+ _ -> pure Nothing
+
+hasExpired :: MonadIO m => UTCTime -> m Bool
+hasExpired limit = do
+ now <- liftIO getCurrentTime
+ pure (now > limit)
diff --git a/lib/Server/Subscribe.hs b/lib/Server/Subscribe.hs
new file mode 100644
index 0000000..fdc092b
--- /dev/null
+++ b/lib/Server/Subscribe.hs
@@ -0,0 +1,63 @@
+
+module Server.Subscribe where
+import Conduit (MonadIO (..))
+import Control.Concurrent.STM (atomically, newTQueue, readTQueue,
+ readTVar, writeTVar)
+import Control.Exception (handle)
+import Control.Monad.Extra (forever, whenJust)
+import qualified Data.Aeson as A
+import qualified Data.ByteString.Char8 as C8
+import Data.Functor ((<&>))
+import Data.Map (Map)
+import qualified Data.Map as M
+import Data.Pool
+import Data.UUID (UUID)
+import Database.Persist (Entity (entityKey), SelectOpt (Desc),
+ entityVal, selectFirst, selectList,
+ (<-.), (==.), (||.))
+import Database.Persist.Sql (SqlBackend)
+import qualified Network.WebSockets as WS
+import Persist
+import Server.Base (ServerState)
+import Server.Util (ServiceM)
+
+
+handleSubscribe
+ :: Pool SqlBackend
+ -> ServerState
+ -> UUID
+ -> WS.Connection
+ -> ServiceM ()
+handleSubscribe dbpool subscribers (ticketId :: UUID) conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do
+ queue <- atomically $ do
+ queue <- newTQueue
+ qs <- readTVar subscribers
+ writeTVar subscribers
+ $ M.insertWith (<>) ticketId [queue] qs
+ pure queue
+ -- send most recent ping, if any (so we won't have to wait for movement)
+ lastPing <- runSqlWithoutLog dbpool $ do
+ trackers <- getTicketTrackers ticketId
+ <&> fmap entityKey
+ selectFirst [TrainPingToken <-. trackers] [Desc TrainPingTimestamp]
+ <&> fmap entityVal
+ whenJust lastPing $ \ping ->
+ WS.sendTextData conn (A.encode lastPing)
+ handle (\(e :: WS.ConnectionException) -> removeSubscriber queue) $ forever $ do
+ res <- atomically $ readTQueue queue
+ case res of
+ Just ping -> WS.sendTextData conn (A.encode ping)
+ Nothing -> do
+ removeSubscriber queue
+ WS.sendClose conn (C8.pack "train ended")
+ where removeSubscriber queue = atomically $ do
+ qs <- readTVar subscribers
+ writeTVar subscribers
+ $ M.adjust (filter (/= queue)) ticketId qs
+
+-- getTicketTrackers :: (MonadLogger (t (ResourceT IO)), MonadIO (t (ResourceT IO)))
+-- => UUID -> ReaderT SqlBackend (t (ResourceT IO)) [Entity Tracker]
+getTicketTrackers ticketId = do
+ joins <- selectList [TrackerTicketTicket ==. TicketKey ticketId] []
+ <&> fmap (trackerTicketTracker . entityVal)
+ selectList ([TrackerId <-. joins] ||. [TrackerCurrentTicket ==. Just (TicketKey ticketId)]) []
diff --git a/lib/Server/Util.hs b/lib/Server/Util.hs
index 0106428..290b9c5 100644
--- a/lib/Server/Util.hs
+++ b/lib/Server/Util.hs
@@ -1,33 +1,41 @@
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE RecordWildCards #-}
-- | mostly the monad the service runs in
-module Server.Util (Service, ServiceM, runService, sendErrorMsg, secondsNow, utcToSeconds, runLogging) where
+module Server.Util (Service, ServiceM, runService, sendErrorMsg, secondsNow, utcToSeconds, runLogging, getTzseries) where
-import Config (LoggingConfig (..))
-import Control.Exception (handle, try)
-import Control.Monad.Extra (void, whenJust)
-import Control.Monad.IO.Class (MonadIO (liftIO))
-import Control.Monad.Logger (Loc, LogLevel (..), LogSource, LogStr,
- LoggingT (..), defaultOutput,
- fromLogStr, runStderrLoggingT)
-import Control.Monad.Reader (ReaderT (..))
-import qualified Data.Aeson as A
-import Data.ByteString (ByteString)
-import qualified Data.ByteString as C8
-import Data.Text (Text)
-import qualified Data.Text as T
-import Data.Text.Encoding (decodeUtf8Lenient)
-import Data.Time (Day, UTCTime (..), diffUTCTime,
- getCurrentTime,
- nominalDiffTimeToSeconds)
-import Fmt ((+|), (|+))
-import GHC.IO.Exception (IOException (IOError))
-import GTFS (Seconds (..))
-import Prometheus (MonadMonitor (doIO))
-import Servant (Handler, ServerError, ServerT, err404,
- errBody, errHeaders, throwError)
-import System.IO (stderr)
-import System.Process.Extra (callProcess)
+import Config (LoggingConfig (..),
+ ServerConfig (..))
+import Control.Exception (handle, try)
+import Control.Monad.Extra (void, whenJust)
+import Control.Monad.IO.Class (MonadIO (liftIO))
+import Control.Monad.Logger (Loc, LogLevel (..),
+ LogSource, LogStr,
+ LoggingT (..),
+ defaultOutput, fromLogStr,
+ runStderrLoggingT)
+import Control.Monad.Reader (ReaderT (..))
+import qualified Data.Aeson as A
+import Data.ByteString (ByteString)
+import qualified Data.ByteString as C8
+import Data.Text (Text)
+import qualified Data.Text as T
+import Data.Text.Encoding (decodeUtf8Lenient)
+import Data.Time (Day, UTCTime (..),
+ diffUTCTime,
+ getCurrentTime,
+ nominalDiffTimeToSeconds)
+import Data.Time.LocalTime.TimeZone.Olson (getTimeZoneSeriesFromOlsonFile)
+import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries)
+import Fmt ((+|), (|+))
+import GHC.IO.Exception (IOException (IOError))
+import GTFS (Seconds (..))
+import Prometheus (MonadMonitor (doIO))
+import Servant (Handler, ServerError,
+ ServerT, err404, errBody,
+ errHeaders, throwError)
+import System.FilePath ((</>))
+import System.IO (stderr)
+import System.Process.Extra (callProcess)
type ServiceM = LoggingT (ReaderT LoggingConfig Handler)
type Service api = ServerT api ServiceM
@@ -77,3 +85,7 @@ secondsNow runningDay = do
utcToSeconds :: UTCTime -> Day -> Seconds
utcToSeconds time day =
Seconds $ round $ nominalDiffTimeToSeconds $ diffUTCTime time (UTCTime day 0)
+
+getTzseries :: ServerConfig -> Text -> IO TimeZoneSeries
+getTzseries settings tzname = getTimeZoneSeriesFromOlsonFile
+ (serverConfigZoneinfoPath settings </> T.unpack tzname)