aboutsummaryrefslogtreecommitdiff
path: root/lib/Server.hs
diff options
context:
space:
mode:
authorstuebinm2022-06-16 00:25:15 +0200
committerstuebinm2022-06-16 00:25:15 +0200
commit9e89c93b3b84b5c82c186cff62c33218a0a4d298 (patch)
treef810fc1eacc7b82e82543196257a2e93c5f21a9f /lib/Server.hs
parentd418ad82c98ab8dd3d540e910777fa530de350eb (diff)
actually use the database
(at least for a few simple things) Also, more modules!
Diffstat (limited to 'lib/Server.hs')
-rw-r--r--lib/Server.hs199
1 files changed, 103 insertions, 96 deletions
diff --git a/lib/Server.hs b/lib/Server.hs
index 1b79300..4a78735 100644
--- a/lib/Server.hs
+++ b/lib/Server.hs
@@ -6,11 +6,13 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
@@ -18,115 +20,120 @@
{-# LANGUAGE UndecidableInstances #-}
module Server where
-import Servant (Application,
- FromHttpApiData (parseUrlPiece),
- Server, err404, serve, throwError,
- type (:>))
-import Servant.API (Capture, FromHttpApiData, Get, JSON,
- Post, ReqBody, type (:<|>) ((:<|>)))
-import Servant.Docs (DocCapture (..), DocQueryParam (..),
- ParamKind (..), ToCapture (..),
- ToParam (..))
-
+import Conduit (MonadTrans (lift), ResourceT)
import Control.Concurrent.STM
-import Control.Monad.IO.Class (MonadIO (liftIO))
-import Data.Aeson (FromJSON (parseJSON), ToJSON (toJSON),
- ToJSONKey, genericParseJSON,
- genericToJSON)
-import qualified Data.Aeson as A
-import Data.Functor ((<&>))
-import Data.Map (Map)
-import qualified Data.Map as M
-import Data.Proxy (Proxy (Proxy))
-import Data.Swagger
-import Data.Text (Text)
-import Data.Time (UTCTime (utctDay), dayOfWeek,
- getCurrentTime)
-import Data.UUID (UUID)
-import qualified Data.UUID as UUID
-import qualified Data.UUID.V4 as UUID
-import Data.Vector (Vector)
+import Control.Monad (when)
+import Control.Monad.Extra (whenM)
+import Control.Monad.IO.Class (MonadIO (liftIO))
+import Control.Monad.Logger.CallStack (NoLoggingT)
+import Control.Monad.Reader (forM)
+import Control.Monad.Trans.Maybe (MaybeT (..))
+import Data.Aeson (FromJSON (parseJSON),
+ ToJSON (toJSON), ToJSONKey,
+ genericParseJSON,
+ genericToJSON)
+import qualified Data.Aeson as A
+import Data.Coerce (coerce)
+import Data.Functor ((<&>))
+import Data.Map (Map)
+import qualified Data.Map as M
+import Data.Pool (Pool)
+import Data.Proxy (Proxy (Proxy))
+import Data.Swagger hiding (get)
+import Data.Text (Text)
+import Data.Time (NominalDiffTime,
+ UTCTime (utctDay), addUTCTime,
+ dayOfWeek, diffUTCTime,
+ getCurrentTime, nominalDay)
+import Data.UUID (UUID)
+import qualified Data.UUID as UUID
+import qualified Data.UUID.V4 as UUID
+import Data.Vector (Vector)
import Database.Persist
-import Database.Persist.TH
-import GHC.Foreign (withCStringsLen)
-import GHC.Generics (Generic)
+import Database.Persist.Postgresql
+import GHC.Generics (Generic)
import GTFS
-import PersistOrphans
-import Servant.Server (Handler)
-import Servant.Swagger (toSwagger)
-
-newtype Token = Token UUID
- deriving newtype (Show, ToJSON, Eq, Ord, FromHttpApiData, ToJSONKey)
-instance ToSchema Token where
- declareNamedSchema _ = declareNamedSchema (Proxy @String)
-instance ToParamSchema Token where
- toParamSchema _ = toParamSchema (Proxy @String)
-
-share [mkPersist sqlSettings, mkMigrate "migrateAll"] [persistLowerCase|
-TripToken sql=tt_tracker_token
- Id UUID default=uuid_generate_v4()
- issued UTCTime
- blocked Bool
- tripNumber Text
- deriving Eq Show Generic
-
-TripPing json sql=tt_trip_ping
- token UUID
- latitude Double
- longitude Double
- delay Double
- timestamp UTCTime
- deriving Show Generic Eq
-
-|]
-
-
-instance ToSchema TripPing where
- declareNamedSchema = genericDeclareNamedSchema (swaggerOptions "ping")
-
-type KnownTrips = TVar (Map Token [TripPing])
-
-type API = "stations" :> Get '[JSON] (Map StationID Station)
- :<|> "timetable" :> Capture "Station ID" StationID :> Get '[JSON] (Map TripID (Trip Deep))
- :<|> "trip" :> Capture "Trip ID" TripID :> Get '[JSON] (Trip Deep)
- -- ingress API (put this behind BasicAuth?)
- -- TODO: perhaps require a first ping for registration?
- :<|> "trainregister" :> Capture "Trip ID" TripID :> Post '[JSON] Token
- -- TODO: perhaps a websocket instead?
- :<|> "trainping" :> Capture "Train Token" Token :> ReqBody '[JSON] TripPing :> Post '[JSON] ()
- -- debug things
- :<|> "debug" :> "state" :> Get '[JSON] (Map Token [TripPing])
-type CompleteAPI = "debug" :> "openapi" :> Get '[JSON] Swagger
- :<|> API
-
-
-
-
-
-server :: GTFS -> KnownTrips -> Server CompleteAPI
-server gtfs@GTFS{..} knownTrains = handleDebugAPI :<|> handleStations :<|> handleTimetable :<|> handleTrip
+import Servant (Application,
+ FromHttpApiData (parseUrlPiece),
+ Server, err401, err404, serve,
+ throwError, type (:>))
+import Servant.API (Capture, FromHttpApiData, Get,
+ JSON, Post, ReqBody,
+ type (:<|>) ((:<|>)))
+import Servant.Docs (DocCapture (..),
+ DocQueryParam (..),
+ ParamKind (..), ToCapture (..),
+ ToParam (..))
+import Servant.Server (Handler)
+import Servant.Swagger (toSwagger)
+import Web.PathPieces (PathPiece)
+
+import API
+import Persist
+
+application :: GTFS -> Pool SqlBackend -> IO Application
+application gtfs dbpool = do
+ doMigration dbpool
+ pure $ serve (Proxy @CompleteAPI) $ server gtfs dbpool
+
+
+
+-- databaseMigration :: ConnectionString -> IO ()
+doMigration pool = runSql pool $
+ -- TODO: before that, check if the uuid module is enabled
+ -- in sql: check if SELECT * FROM pg_extension WHERE extname = 'uuid-ossp';
+ -- returns an empty list
+ runMigration migrateAll
+
+server :: GTFS -> Pool SqlBackend -> Server CompleteAPI
+server gtfs@GTFS{..} dbpool = handleDebugAPI :<|> handleStations :<|> handleTimetable :<|> handleTrip
:<|> handleRegister :<|> handleTripPing :<|> handleDebugState
where handleStations = pure stations
handleTimetable station = do
+ -- TODO: resolve "overlay" trips (perhaps just additional CalendarDates?)
today <- liftIO getCurrentTime <&> utctDay
pure $ tripsOnDay gtfs today
handleTrip trip = case M.lookup trip trips of
Just res -> pure res
Nothing -> throwError err404
- handleRegister tripID = liftIO $ do
- token <- UUID.nextRandom <&> Token
- atomically $ modifyTVar knownTrains (M.insert token [])
- pure token
- handleTripPing token ping = liftIO $ atomically $ do
- modifyTVar knownTrains (M.update (\history -> Just (ping : history)) token)
- pure ()
- handleDebugState = liftIO $ readTVarIO knownTrains
+ handleRegister tripID = do
+ expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod
+ RunningTripKey uuid <- runSql dbpool $ insert (RunningTrip expires False tripID)
+ pure (Token uuid)
+ handleTripPing ping = do
+ checkTokenValid dbpool (tripPingToken ping)
+ -- TODO: are these always inserted in order?
+ runSql dbpool $ insert ping
+ pure ()
+ handleDebugState = do
+ now <- liftIO $ getCurrentTime
+ runSql dbpool $ do
+ running <- selectList [RunningTripBlocked ==. False, RunningTripExpires >=. now] []
+ pairs <- forM running $ \(Entity (RunningTripKey uuid) _) -> do
+ entities <- selectList [TripPingToken ==. Token uuid] []
+ pure (Token uuid, fmap entityVal entities)
+ pure (M.fromList pairs)
handleDebugAPI = pure $ toSwagger (Proxy @API)
-application :: GTFS -> IO Application
-application gtfs = do
- knownTrips <- newTVarIO mempty
- pure $ serve (Proxy @CompleteAPI) $ server gtfs knownTrips
+checkTokenValid :: Pool SqlBackend -> Token -> Handler ()
+checkTokenValid dbpool token = do
+ trip <- try $ runSql dbpool $ get (coerce token)
+ when (runningTripBlocked trip)
+ $ throwError err401
+ whenM (hasExpired (runningTripExpires trip))
+ $ throwError err401
+ where try m = m >>= \case
+ Just a -> pure a
+ Nothing -> throwError err404
+
+hasExpired :: MonadIO m => UTCTime -> m Bool
+hasExpired limit = do
+ now <- liftIO getCurrentTime
+ pure (now > limit)
+
+validityPeriod :: NominalDiffTime
+validityPeriod = nominalDay
+