aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorstuebinm2022-06-16 00:25:15 +0200
committerstuebinm2022-06-16 00:25:15 +0200
commit9e89c93b3b84b5c82c186cff62c33218a0a4d298 (patch)
treef810fc1eacc7b82e82543196257a2e93c5f21a9f
parentd418ad82c98ab8dd3d540e910777fa530de350eb (diff)
actually use the database
(at least for a few simple things) Also, more modules!
-rw-r--r--app/Main.hs11
-rw-r--r--haskell-gtfs.cabal5
-rw-r--r--lib/API.hs35
-rw-r--r--lib/Persist.hs85
-rw-r--r--lib/Server.hs199
5 files changed, 236 insertions, 99 deletions
diff --git a/app/Main.hs b/app/Main.hs
index 0b3165e..7d3b5dc 100644
--- a/app/Main.hs
+++ b/app/Main.hs
@@ -26,16 +26,21 @@ import Network.Wai.Handler.Warp (run)
import Network.Wai.Middleware.RequestLogger (OutputFormat (..),
RequestLoggerSettings (..),
mkRequestLogger)
+import Database.Persist.Postgresql
+import Control.Monad.Logger (runStderrLoggingT)
+import Control.Monad.IO.Class (MonadIO (liftIO))
import GTFS
import Server
+connStr = "user=travelynx"
main :: IO ()
main = do
gtfs <- loadGtfs "./gtfs.zip"
- app <- application gtfs
loggerMiddleware <- mkRequestLogger
$ def { outputFormat = Detailed True }
- putStrLn "starting server …"
- run 4000 (loggerMiddleware app)
+ runStderrLoggingT $ withPostgresqlPool connStr 10 $ \pool -> liftIO $ do
+ app <- application gtfs pool
+ putStrLn "starting server …"
+ run 4000 (loggerMiddleware app)
diff --git a/haskell-gtfs.cabal b/haskell-gtfs.cabal
index a54a9ff..33700b0 100644
--- a/haskell-gtfs.cabal
+++ b/haskell-gtfs.cabal
@@ -32,6 +32,8 @@ executable haskell-gtfs
, wai-extra
, warp
, data-default-class >= 0.1.2
+ , persistent-postgresql
+ , monad-logger
hs-source-dirs: app
default-language: Haskell2010
default-extensions: OverloadedStrings
@@ -67,6 +69,9 @@ library
, conduit
, path-pieces
, either
+ , resource-pool
+ , transformers
+ , extra
hs-source-dirs: lib
exposed-modules: GTFS, Server, PersistOrphans
default-language: Haskell2010
diff --git a/lib/API.hs b/lib/API.hs
new file mode 100644
index 0000000..3fb4c3c
--- /dev/null
+++ b/lib/API.hs
@@ -0,0 +1,35 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+
+module API where
+
+import Data.Map (Map)
+import Data.Swagger (Swagger)
+import GTFS
+import Persist
+import Servant (Application, FromHttpApiData (parseUrlPiece),
+ Server, err401, err404, serve, throwError,
+ type (:>))
+import Servant.API (Capture, FromHttpApiData, Get, JSON, Post,
+ ReqBody, type (:<|>) ((:<|>)))
+
+type API = "stations" :> Get '[JSON] (Map StationID Station)
+ :<|> "timetable" :> Capture "Station ID" StationID :> Get '[JSON] (Map TripID (Trip Deep))
+ :<|> "trip" :> Capture "Trip ID" TripID :> Get '[JSON] (Trip Deep)
+ -- ingress API (put this behind BasicAuth?)
+ -- TODO: perhaps require a first ping for registration?
+ :<|> "trip" :> "register" :> Capture "Trip ID" TripID :> Post '[JSON] Token
+ -- TODO: perhaps a websocket instead?
+ :<|> "trip" :> "ping" :> ReqBody '[JSON] TripPing :> Post '[JSON] ()
+ -- debug things
+ :<|> "debug" :> "state" :> Get '[JSON] (Map Token [TripPing])
+type CompleteAPI = "debug" :> "openapi" :> Get '[JSON] Swagger
+ :<|> API
diff --git a/lib/Persist.hs b/lib/Persist.hs
new file mode 100644
index 0000000..b4df1fb
--- /dev/null
+++ b/lib/Persist.hs
@@ -0,0 +1,85 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveAnyClass #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE RecordWildCards #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+
+module Persist where
+
+import Data.Aeson (FromJSON, ToJSON, ToJSONKey)
+import Data.Swagger (ToParamSchema (..), ToSchema (..),
+ genericDeclareNamedSchema)
+import Data.Text (Text)
+import Data.UUID (UUID)
+import Database.Persist
+import Database.Persist.Sql (PersistFieldSql,
+ runSqlPersistMPool)
+import Database.Persist.TH
+import GTFS
+import PersistOrphans
+import Servant (FromHttpApiData)
+
+import Conduit (ResourceT)
+import Control.Monad.IO.Class (MonadIO (liftIO))
+import Control.Monad.Logger (NoLoggingT)
+import Control.Monad.Reader (ReaderT)
+import Data.Data (Proxy (..))
+import Data.Pool (Pool)
+import Data.Time (NominalDiffTime,
+ UTCTime (utctDay), addUTCTime,
+ dayOfWeek, diffUTCTime,
+ getCurrentTime, nominalDay)
+import Database.Persist.Postgresql (SqlBackend)
+import GHC.Generics (Generic)
+import Web.PathPieces (PathPiece)
+
+
+
+newtype Token = Token UUID
+ deriving newtype
+ ( Show, ToJSON, FromJSON, Eq, Ord, FromHttpApiData
+ , ToJSONKey, PersistField, PersistFieldSql, PathPiece)
+instance ToSchema Token where
+ declareNamedSchema _ = declareNamedSchema (Proxy @String)
+instance ToParamSchema Token where
+ toParamSchema _ = toParamSchema (Proxy @String)
+
+
+share [mkPersist sqlSettings, mkMigrate "migrateAll"] [persistLowerCase|
+RunningTrip sql=tt_tracker_token
+ Id UUID default=uuid_generate_v4()
+ expires UTCTime
+ blocked Bool
+ tripNumber Text
+ deriving Eq Show Generic
+
+TripPing json sql=tt_trip_ping
+ token Token
+ latitude Double
+ longitude Double
+ delay Double
+ timestamp UTCTime
+ deriving Show Generic Eq
+
+|]
+
+instance ToSchema TripPing where
+ declareNamedSchema = genericDeclareNamedSchema (swaggerOptions "ping")
+
+runSql :: MonadIO m => Pool SqlBackend -> (ReaderT SqlBackend (NoLoggingT (ResourceT IO)) a) -> m a
+runSql pool = liftIO . flip runSqlPersistMPool pool
diff --git a/lib/Server.hs b/lib/Server.hs
index 1b79300..4a78735 100644
--- a/lib/Server.hs
+++ b/lib/Server.hs
@@ -6,11 +6,13 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
@@ -18,115 +20,120 @@
{-# LANGUAGE UndecidableInstances #-}
module Server where
-import Servant (Application,
- FromHttpApiData (parseUrlPiece),
- Server, err404, serve, throwError,
- type (:>))
-import Servant.API (Capture, FromHttpApiData, Get, JSON,
- Post, ReqBody, type (:<|>) ((:<|>)))
-import Servant.Docs (DocCapture (..), DocQueryParam (..),
- ParamKind (..), ToCapture (..),
- ToParam (..))
-
+import Conduit (MonadTrans (lift), ResourceT)
import Control.Concurrent.STM
-import Control.Monad.IO.Class (MonadIO (liftIO))
-import Data.Aeson (FromJSON (parseJSON), ToJSON (toJSON),
- ToJSONKey, genericParseJSON,
- genericToJSON)
-import qualified Data.Aeson as A
-import Data.Functor ((<&>))
-import Data.Map (Map)
-import qualified Data.Map as M
-import Data.Proxy (Proxy (Proxy))
-import Data.Swagger
-import Data.Text (Text)
-import Data.Time (UTCTime (utctDay), dayOfWeek,
- getCurrentTime)
-import Data.UUID (UUID)
-import qualified Data.UUID as UUID
-import qualified Data.UUID.V4 as UUID
-import Data.Vector (Vector)
+import Control.Monad (when)
+import Control.Monad.Extra (whenM)
+import Control.Monad.IO.Class (MonadIO (liftIO))
+import Control.Monad.Logger.CallStack (NoLoggingT)
+import Control.Monad.Reader (forM)
+import Control.Monad.Trans.Maybe (MaybeT (..))
+import Data.Aeson (FromJSON (parseJSON),
+ ToJSON (toJSON), ToJSONKey,
+ genericParseJSON,
+ genericToJSON)
+import qualified Data.Aeson as A
+import Data.Coerce (coerce)
+import Data.Functor ((<&>))
+import Data.Map (Map)
+import qualified Data.Map as M
+import Data.Pool (Pool)
+import Data.Proxy (Proxy (Proxy))
+import Data.Swagger hiding (get)
+import Data.Text (Text)
+import Data.Time (NominalDiffTime,
+ UTCTime (utctDay), addUTCTime,
+ dayOfWeek, diffUTCTime,
+ getCurrentTime, nominalDay)
+import Data.UUID (UUID)
+import qualified Data.UUID as UUID
+import qualified Data.UUID.V4 as UUID
+import Data.Vector (Vector)
import Database.Persist
-import Database.Persist.TH
-import GHC.Foreign (withCStringsLen)
-import GHC.Generics (Generic)
+import Database.Persist.Postgresql
+import GHC.Generics (Generic)
import GTFS
-import PersistOrphans
-import Servant.Server (Handler)
-import Servant.Swagger (toSwagger)
-
-newtype Token = Token UUID
- deriving newtype (Show, ToJSON, Eq, Ord, FromHttpApiData, ToJSONKey)
-instance ToSchema Token where
- declareNamedSchema _ = declareNamedSchema (Proxy @String)
-instance ToParamSchema Token where
- toParamSchema _ = toParamSchema (Proxy @String)
-
-share [mkPersist sqlSettings, mkMigrate "migrateAll"] [persistLowerCase|
-TripToken sql=tt_tracker_token
- Id UUID default=uuid_generate_v4()
- issued UTCTime
- blocked Bool
- tripNumber Text
- deriving Eq Show Generic
-
-TripPing json sql=tt_trip_ping
- token UUID
- latitude Double
- longitude Double
- delay Double
- timestamp UTCTime
- deriving Show Generic Eq
-
-|]
-
-
-instance ToSchema TripPing where
- declareNamedSchema = genericDeclareNamedSchema (swaggerOptions "ping")
-
-type KnownTrips = TVar (Map Token [TripPing])
-
-type API = "stations" :> Get '[JSON] (Map StationID Station)
- :<|> "timetable" :> Capture "Station ID" StationID :> Get '[JSON] (Map TripID (Trip Deep))
- :<|> "trip" :> Capture "Trip ID" TripID :> Get '[JSON] (Trip Deep)
- -- ingress API (put this behind BasicAuth?)
- -- TODO: perhaps require a first ping for registration?
- :<|> "trainregister" :> Capture "Trip ID" TripID :> Post '[JSON] Token
- -- TODO: perhaps a websocket instead?
- :<|> "trainping" :> Capture "Train Token" Token :> ReqBody '[JSON] TripPing :> Post '[JSON] ()
- -- debug things
- :<|> "debug" :> "state" :> Get '[JSON] (Map Token [TripPing])
-type CompleteAPI = "debug" :> "openapi" :> Get '[JSON] Swagger
- :<|> API
-
-
-
-
-
-server :: GTFS -> KnownTrips -> Server CompleteAPI
-server gtfs@GTFS{..} knownTrains = handleDebugAPI :<|> handleStations :<|> handleTimetable :<|> handleTrip
+import Servant (Application,
+ FromHttpApiData (parseUrlPiece),
+ Server, err401, err404, serve,
+ throwError, type (:>))
+import Servant.API (Capture, FromHttpApiData, Get,
+ JSON, Post, ReqBody,
+ type (:<|>) ((:<|>)))
+import Servant.Docs (DocCapture (..),
+ DocQueryParam (..),
+ ParamKind (..), ToCapture (..),
+ ToParam (..))
+import Servant.Server (Handler)
+import Servant.Swagger (toSwagger)
+import Web.PathPieces (PathPiece)
+
+import API
+import Persist
+
+application :: GTFS -> Pool SqlBackend -> IO Application
+application gtfs dbpool = do
+ doMigration dbpool
+ pure $ serve (Proxy @CompleteAPI) $ server gtfs dbpool
+
+
+
+-- databaseMigration :: ConnectionString -> IO ()
+doMigration pool = runSql pool $
+ -- TODO: before that, check if the uuid module is enabled
+ -- in sql: check if SELECT * FROM pg_extension WHERE extname = 'uuid-ossp';
+ -- returns an empty list
+ runMigration migrateAll
+
+server :: GTFS -> Pool SqlBackend -> Server CompleteAPI
+server gtfs@GTFS{..} dbpool = handleDebugAPI :<|> handleStations :<|> handleTimetable :<|> handleTrip
:<|> handleRegister :<|> handleTripPing :<|> handleDebugState
where handleStations = pure stations
handleTimetable station = do
+ -- TODO: resolve "overlay" trips (perhaps just additional CalendarDates?)
today <- liftIO getCurrentTime <&> utctDay
pure $ tripsOnDay gtfs today
handleTrip trip = case M.lookup trip trips of
Just res -> pure res
Nothing -> throwError err404
- handleRegister tripID = liftIO $ do
- token <- UUID.nextRandom <&> Token
- atomically $ modifyTVar knownTrains (M.insert token [])
- pure token
- handleTripPing token ping = liftIO $ atomically $ do
- modifyTVar knownTrains (M.update (\history -> Just (ping : history)) token)
- pure ()
- handleDebugState = liftIO $ readTVarIO knownTrains
+ handleRegister tripID = do
+ expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod
+ RunningTripKey uuid <- runSql dbpool $ insert (RunningTrip expires False tripID)
+ pure (Token uuid)
+ handleTripPing ping = do
+ checkTokenValid dbpool (tripPingToken ping)
+ -- TODO: are these always inserted in order?
+ runSql dbpool $ insert ping
+ pure ()
+ handleDebugState = do
+ now <- liftIO $ getCurrentTime
+ runSql dbpool $ do
+ running <- selectList [RunningTripBlocked ==. False, RunningTripExpires >=. now] []
+ pairs <- forM running $ \(Entity (RunningTripKey uuid) _) -> do
+ entities <- selectList [TripPingToken ==. Token uuid] []
+ pure (Token uuid, fmap entityVal entities)
+ pure (M.fromList pairs)
handleDebugAPI = pure $ toSwagger (Proxy @API)
-application :: GTFS -> IO Application
-application gtfs = do
- knownTrips <- newTVarIO mempty
- pure $ serve (Proxy @CompleteAPI) $ server gtfs knownTrips
+checkTokenValid :: Pool SqlBackend -> Token -> Handler ()
+checkTokenValid dbpool token = do
+ trip <- try $ runSql dbpool $ get (coerce token)
+ when (runningTripBlocked trip)
+ $ throwError err401
+ whenM (hasExpired (runningTripExpires trip))
+ $ throwError err401
+ where try m = m >>= \case
+ Just a -> pure a
+ Nothing -> throwError err404
+
+hasExpired :: MonadIO m => UTCTime -> m Bool
+hasExpired limit = do
+ now <- liftIO getCurrentTime
+ pure (now > limit)
+
+validityPeriod :: NominalDiffTime
+validityPeriod = nominalDay
+