diff --git a/promise.hpp b/promise.hpp index febbf57..00eb8c8 100644 --- a/promise.hpp +++ b/promise.hpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -20,6 +21,7 @@ #include #include #include +#include // // invoke.hpp @@ -368,6 +370,15 @@ namespace promise_hpp }; } + // + // promise_wait_status + // + + enum class promise_wait_status { + no_timeout, + timeout + }; + // // promise // @@ -376,12 +387,6 @@ namespace promise_hpp class promise final { public: using value_type = T; - - enum class status : std::uint8_t { - pending, - resolved, - rejected - }; public: promise() : state_(std::make_shared()) {} @@ -535,6 +540,28 @@ namespace promise_hpp return state_->reject( std::make_exception_ptr(std::forward(e))); } + + const T& get() const { + return state_->get(); + } + + void wait() const { + state_->wait(); + } + + template < typename Rep, typename Period > + promise_wait_status wait_for( + const std::chrono::duration& timeout_duration) const + { + return state_->wait_for(timeout_duration); + } + + template < typename Clock, typename Duration > + promise_wait_status wait_until( + const std::chrono::time_point& timeout_time) const + { + return state_->wait_until(timeout_time); + } private: class state; std::shared_ptr state_; @@ -552,6 +579,7 @@ namespace promise_hpp storage_.set(std::forward(value)); status_ = status::resolved; invoke_resolve_handlers_(); + cond_var_.notify_all(); return true; } @@ -563,9 +591,51 @@ namespace promise_hpp exception_ = e; status_ = status::rejected; invoke_reject_handlers_(); + cond_var_.notify_all(); return true; } + const T& get() { + std::unique_lock lock(mutex_); + cond_var_.wait(lock, [this](){ + return status_ != status::pending; + }); + if ( status_ == status::rejected ) { + std::rethrow_exception(exception_); + } + assert(status_ == status::resolved); + return storage_.value(); + } + + void wait() const { + std::unique_lock lock(mutex_); + cond_var_.wait(lock, [this](){ + return status_ != status::pending; + }); + } + + template < typename Rep, typename Period > + promise_wait_status wait_for( + const std::chrono::duration& timeout_duration) const + { + std::unique_lock lock(mutex_); + return cond_var_.wait_for(lock, timeout_duration, [this](){ + return status_ != status::pending; + }) ? promise_wait_status::no_timeout + : promise_wait_status::timeout; + } + + template < typename Clock, typename Duration > + promise_wait_status wait_until( + const std::chrono::time_point& timeout_time) const + { + std::unique_lock lock(mutex_); + return cond_var_.wait_until(lock, timeout_time, [this](){ + return status_ != status::pending; + }) ? promise_wait_status::no_timeout + : promise_wait_status::timeout; + } + template < typename U, typename ResolveF, typename RejectF > std::enable_if_t::value, void> attach(promise& next, ResolveF&& resolve, RejectF&& reject) { @@ -674,11 +744,17 @@ namespace promise_hpp handlers_.clear(); } private: - detail::storage storage_; - status status_ = status::pending; - std::exception_ptr exception_ = nullptr; + enum class status { + pending, + resolved, + rejected + }; - std::mutex mutex_; + status status_{status::pending}; + std::exception_ptr exception_{nullptr}; + + mutable std::mutex mutex_; + mutable std::condition_variable cond_var_; struct handler { using resolve_t = std::function; @@ -694,6 +770,7 @@ namespace promise_hpp }; std::vector handlers_; + detail::storage storage_; }; }; @@ -705,12 +782,6 @@ namespace promise_hpp class promise final { public: using value_type = void; - - enum class status : std::uint8_t { - pending, - resolved, - rejected - }; public: promise() : state_(std::make_shared()) {} @@ -858,6 +929,28 @@ namespace promise_hpp return state_->reject( std::make_exception_ptr(std::forward(e))); } + + void get() const { + state_->get(); + } + + void wait() const { + state_->wait(); + } + + template < typename Rep, typename Period > + promise_wait_status wait_for( + const std::chrono::duration& timeout_duration) const + { + return state_->wait_for(timeout_duration); + } + + template < typename Clock, typename Duration > + promise_wait_status wait_until( + const std::chrono::time_point& timeout_time) const + { + return state_->wait_until(timeout_time); + } private: class state; std::shared_ptr state_; @@ -873,6 +966,7 @@ namespace promise_hpp } status_ = status::resolved; invoke_resolve_handlers_(); + cond_var_.notify_all(); return true; } @@ -884,9 +978,50 @@ namespace promise_hpp exception_ = e; status_ = status::rejected; invoke_reject_handlers_(); + cond_var_.notify_all(); return true; } + void get() { + std::unique_lock lock(mutex_); + cond_var_.wait(lock, [this](){ + return status_ != status::pending; + }); + if ( status_ == status::rejected ) { + std::rethrow_exception(exception_); + } + assert(status_ == status::resolved); + } + + void wait() const { + std::unique_lock lock(mutex_); + cond_var_.wait(lock, [this](){ + return status_ != status::pending; + }); + } + + template < typename Rep, typename Period > + promise_wait_status wait_for( + const std::chrono::duration& timeout_duration) const + { + std::unique_lock lock(mutex_); + return cond_var_.wait_for(lock, timeout_duration, [this](){ + return status_ != status::pending; + }) ? promise_wait_status::no_timeout + : promise_wait_status::timeout; + } + + template < typename Clock, typename Duration > + promise_wait_status wait_until( + const std::chrono::time_point& timeout_time) const + { + std::unique_lock lock(mutex_); + return cond_var_.wait_until(lock, timeout_time, [this](){ + return status_ != status::pending; + }) ? promise_wait_status::no_timeout + : promise_wait_status::timeout; + } + template < typename U, typename ResolveF, typename RejectF > std::enable_if_t::value, void> attach(promise& next, ResolveF&& resolve, RejectF&& reject) { @@ -991,10 +1126,17 @@ namespace promise_hpp handlers_.clear(); } private: - status status_ = status::pending; - std::exception_ptr exception_ = nullptr; + enum class status { + pending, + resolved, + rejected + }; - std::mutex mutex_; + status status_{status::pending}; + std::exception_ptr exception_{nullptr}; + + mutable std::mutex mutex_; + mutable std::condition_variable cond_var_; struct handler { using resolve_t = std::function; diff --git a/tests.cpp b/tests.cpp index f366468..503dd12 100644 --- a/tests.cpp +++ b/tests.cpp @@ -10,6 +10,7 @@ #include "promise.hpp" namespace pr = promise_hpp; +#include #include namespace @@ -36,6 +37,25 @@ namespace return false; } } + + class auto_thread final { + public: + template < typename F, typename... Args > + auto_thread(F&& f, Args&&... args) + : thread_(std::forward(f), std::forward(args)...) {} + + ~auto_thread() noexcept { + if ( thread_.joinable() ) { + thread_.join(); + } + } + + void join() { + thread_.join(); + } + private: + std::thread thread_; + }; } TEST_CASE("is_promise") { @@ -531,12 +551,10 @@ TEST_CASE("promise") { { bool call_fail_with_logic_error = false; auto p1 = pr::make_resolved_promise(42); - auto p2 = pr::make_resolved_promise(84); - p1.then([&p2](int v){ + p1.then([](int v) -> pr::promise { (void)v; throw std::logic_error("hello fail"); - return p2; }).then([](int v2){ (void)v2; }).except([&call_fail_with_logic_error](std::exception_ptr e){ @@ -599,11 +617,9 @@ TEST_CASE("promise") { { bool call_fail_with_logic_error = false; auto p1 = pr::make_resolved_promise(); - auto p2 = pr::make_resolved_promise(); - p1.then([&p2](){ + p1.then([]() -> pr::promise { throw std::logic_error("hello fail"); - return p2; }).then([](){ }).except([&call_fail_with_logic_error](std::exception_ptr e){ call_fail_with_logic_error = check_hello_fail_exception(e); @@ -850,3 +866,114 @@ TEST_CASE("promise") { } } } + +TEST_CASE("get_and_wait") { + SECTION("get_void_promises") { + { + auto p = pr::make_resolved_promise(); + REQUIRE_NOTHROW(p.get()); + } + { + auto p = pr::make_rejected_promise(std::logic_error("hello fail")); + REQUIRE_THROWS_AS(p.get(), std::logic_error); + } + { + auto p = pr::promise(); + auto_thread t{[p]() mutable { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + p.resolve(); + }}; + t.join(); + REQUIRE_NOTHROW(p.get()); + } + { + auto p = pr::promise(); + auto_thread t{[p]() mutable { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + p.resolve(); + }}; + REQUIRE_NOTHROW(p.get()); + } + { + auto p = pr::promise(); + auto_thread t{[p]() mutable { + std::this_thread::sleep_for(std::chrono::milliseconds(30)); + p.resolve(); + }}; + REQUIRE(p.wait_for( + std::chrono::milliseconds(1)) + == pr::promise_wait_status::timeout); + REQUIRE(p.wait_for( + std::chrono::milliseconds(60)) + == pr::promise_wait_status::no_timeout); + } + { + auto p = pr::promise(); + auto_thread t{[p]() mutable { + std::this_thread::sleep_for(std::chrono::milliseconds(30)); + p.resolve(); + }}; + REQUIRE(p.wait_until( + std::chrono::high_resolution_clock::now() + std::chrono::milliseconds(1)) + == pr::promise_wait_status::timeout); + REQUIRE(p.wait_until( + std::chrono::high_resolution_clock::now() + std::chrono::milliseconds(60)) + == pr::promise_wait_status::no_timeout); + } + } + SECTION("get_typed_promises") { + { + auto p = pr::make_resolved_promise(42); + REQUIRE(p.get() == 42); + } + { + auto p = pr::make_rejected_promise(std::logic_error("hello fail")); + REQUIRE_THROWS_AS(p.get(), std::logic_error); + } + { + auto p = pr::promise(); + auto_thread t{[p]() mutable { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + p.resolve(42); + }}; + t.join(); + REQUIRE(p.get() == 42); + } + { + auto p = pr::promise(); + auto_thread t{[p]() mutable { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + p.resolve(42); + }}; + REQUIRE(p.get() == 42); + } + { + auto p = pr::promise(); + auto_thread t{[p]() mutable { + std::this_thread::sleep_for(std::chrono::milliseconds(30)); + p.resolve(42); + }}; + REQUIRE(p.wait_for( + std::chrono::milliseconds(1)) + == pr::promise_wait_status::timeout); + REQUIRE(p.wait_for( + std::chrono::milliseconds(60)) + == pr::promise_wait_status::no_timeout); + REQUIRE(p.get() == 42); + } + { + auto p = pr::promise(); + auto_thread t{[p]() mutable { + std::this_thread::sleep_for(std::chrono::milliseconds(30)); + p.resolve(42); + }}; + REQUIRE(p.wait_until( + std::chrono::high_resolution_clock::now() + std::chrono::milliseconds(1)) + == pr::promise_wait_status::timeout); + REQUIRE(p.wait_until( + std::chrono::high_resolution_clock::now() + std::chrono::milliseconds(60)) + == pr::promise_wait_status::no_timeout); + REQUIRE(p.get() == 42); + } + } +}