(.module:
  [library
   [lux #*
    [abstract
     [monad (#+ do)]]
    [control
     [pipe (#+ if>)]
     ["." io (#+ IO)]
     ["." try (#+ Try)]
     ["." exception (#+ exception:)]]
    [data
     [text
      ["%" format (#+ format)]]
     [collection
      ["." queue (#+ Queue)]]]
    [math
     [number
      ["n" nat]
      ["i" int]]]
    [type
     abstract
     ["." refinement]]]]
  [//
   ["." atom (#+ Atom)]
   ["." async (#+ Async Resolver)]])

... Internal state of a semaphore.
... #open_positions is an Int because 'wait!' decrements it unconditionally,
... so it drops below zero while callers are queued in #waiting_list.
(type: State
  (Record
   {#max_positions Nat
    #open_positions Int
    #waiting_list (Queue (Resolver Any))}))

... A counting semaphore whose state lives in a single atom.
(abstract: .public Semaphore
  {}

  (Atom State)

  ... Upper bound for the number of positions: the top of the Int interval,
  ... so that the Nat maximum can always be converted with '.int' below.
  (def: most_positions_possible
    (.nat (\ i.interval top)))

  ... Creates a semaphore with the given number of open positions,
  ... clamped to 'most_positions_possible'.
  (def: .public (semaphore initial_open_positions)
    (-> Nat Semaphore)
    (let [max_positions (n.min initial_open_positions
                               ..most_positions_possible)]
      (:abstraction (atom.atom {#max_positions max_positions
                                #open_positions (.int max_positions)
                                #waiting_list queue.empty}))))

  ... Takes a position in the semaphore.
  ... Returns an Async that is resolved immediately when a position was open,
  ... or later (by 'signal!') when the caller had to get in line.
  (def: .public (wait! semaphore)
    (Ex (_ k) (-> Semaphore (Async Any)))
    (let [semaphore (:representation semaphore)
          [signal sink] (: [(Async Any) (Resolver Any)]
                           (async.async []))]
      (exec (io.run!
             ... After the decrement, a count above -1 means a position was available.
             (with_expansions [<had_open_position?> (as_is (value@ #open_positions) (i.> -1))]
               (do io.monad
                 ... The decrement and the (conditional) enqueueing of the resolver
                 ... happen inside the same atomic update.
                 [[_ state'] (atom.update! (|>> (revised@ #open_positions --)
                                                (if> [<had_open_position?>]
                                                     []
                                                     [(revised@ #waiting_list (queue.end sink))]))
                                           semaphore)]
                 (with_expansions [<go_ahead> (sink [])
                                   <get_in_line> (in false)]
                   (if (|> state' <had_open_position?>)
                     <go_ahead>
                     <get_in_line>)))))
        signal)))

  (exception: .public (semaphore_is_maxed_out {max_positions Nat})
    (exception.report
     ["Max Positions" (%.nat max_positions)]))

  ... Releases a position in the semaphore.
  ... Yields the new number of open positions on success, or the
  ... 'semaphore_is_maxed_out' exception when every position was already open.
  (def: .public (signal! semaphore)
    (Ex (_ k) (-> Semaphore (Async (Try Int))))
    (let [semaphore (:representation semaphore)]
      (async.future
       (do {! io.monad}
         [[pre post] (atom.update! (function (_ state)
                                     (if (i.= (.int (value@ #max_positions state))
                                              (value@ #open_positions state))
                                       ... Already at the maximum: leave the state untouched.
                                       state
                                       (|> state
                                           (revised@ #open_positions ++)
                                           (revised@ #waiting_list queue.next))))
                                   semaphore)]
         ... An unchanged state is how the update function reports "maxed out".
         (if (same? pre post)
           (in (exception.except ..semaphore_is_maxed_out [(value@ #max_positions pre)]))
           (do !
             ... Wake up the waiter that was at the front of the queue (if any).
             [_ (case (queue.front (value@ #waiting_list pre))
                  #.None
                  (in true)

                  (#.Some sink)
                  (sink []))]
             (in (#try.Success (value@ #open_positions post)))))))))
  )

... A mutual-exclusion lock, represented as a semaphore with a single position.
(abstract: .public Mutex
  {}

  Semaphore

  (def: .public (mutex _)
    (-> Any Mutex)
    (:abstraction (semaphore 1)))

  (def: acquire!
    (-> Mutex (Async Any))
    (|>> :representation ..wait!))

  (def: release!
    (-> Mutex (Async Any))
    (|>> :representation ..signal!))

  ... Acquires the mutex, runs the procedure, waits for its output,
  ... releases the mutex and then yields the output.
  (def: .public (synchronize! mutex procedure)
    (All (_ a) (-> Mutex (IO (Async a)) (Async a)))
    (do async.monad
      [_ (..acquire! mutex)
       output (io.run! procedure)
       _ (..release! mutex)]
      (in output)))
  )

... Refiner for barrier limits: only values greater than 0 are accepted.
(def: .public limit
  (refinement.refiner (n.> 0)))

(type: .public Limit
  (:~ (refinement.type limit)))

... A barrier built from a counter and two turnstiles (semaphores created
... with 0 positions, so every 'wait!' on them gets in line until signalled).
(abstract: .public Barrier
  {}

  (Record
   {#limit Limit
    #count (Atom Nat)
    #start_turnstile Semaphore
    #end_turnstile Semaphore})

  (def: .public (barrier limit)
    (-> Limit Barrier)
    (:abstraction {#limit limit
                   #count (atom.atom 0)
                   #start_turnstile (..semaphore 0)
                   #end_turnstile (..semaphore 0)}))

  ... Signals the turnstile 'times' times, one signal after the other.
  ... The outcome of each signal is ignored.
  (def: (un_block! times turnstile)
    (-> Nat Semaphore (Async Any))
    (loop [step 0]
      (if (n.< times step)
        (do async.monad
          [_ (..signal! turnstile)]
          (recur (++ step)))
        (\ async.monad in []))))

  ... Generates the two phases of the barrier.
  ... Each phase updates the counter; the caller that makes the counter hit
  ... the phase's goal signals the turnstile (limit - 1) times, and every
  ... other caller waits on that turnstile.
  (template [<phase> <update> <goal> <turnstile>]
    [(def: (<phase> (^:representation barrier))
       (-> Barrier (Async Any))
       (do async.monad
         [.let [limit (refinement.value (value@ #limit barrier))
                goal <goal>
                [_ count] (io.run! (atom.update! <update> (value@ #count barrier)))
                reached? (n.= goal count)]]
         (if reached?
           (..un_block! (-- limit) (value@ <turnstile> barrier))
           (..wait! (value@ <turnstile> barrier)))))]

    [start! ++ limit #start_turnstile]
    [end!   -- 0     #end_turnstile]
    )

  ... Goes through the start phase and then through the end phase of the barrier.
  (def: .public (block! barrier)
    (-> Barrier (Async Any))
    (do async.monad
      [_ (..start! barrier)]
      (..end! barrier)))
  )