get, wait, wait_for, wait_until promise functions

This commit is contained in:
2018-12-11 22:45:53 +07:00
parent d845543433
commit bfd068d185
2 changed files with 294 additions and 25 deletions

View File

@@ -12,6 +12,7 @@
#include <new>
#include <mutex>
#include <atomic>
#include <chrono>
#include <memory>
#include <vector>
#include <utility>
@@ -20,6 +21,7 @@
#include <stdexcept>
#include <functional>
#include <type_traits>
#include <condition_variable>
//
// invoke.hpp
@@ -368,6 +370,15 @@ namespace promise_hpp
};
}
//
// promise_wait_status
//
enum class promise_wait_status {
no_timeout,
timeout
};
//
// promise<T>
//
@@ -376,12 +387,6 @@ namespace promise_hpp
class promise final {
public:
using value_type = T;
enum class status : std::uint8_t {
pending,
resolved,
rejected
};
public:
promise()
: state_(std::make_shared<state>()) {}
@@ -535,6 +540,28 @@ namespace promise_hpp
return state_->reject(
std::make_exception_ptr(std::forward<E>(e)));
}
const T& get() const {
return state_->get();
}
void wait() const {
state_->wait();
}
template < typename Rep, typename Period >
promise_wait_status wait_for(
const std::chrono::duration<Rep, Period>& timeout_duration) const
{
return state_->wait_for(timeout_duration);
}
template < typename Clock, typename Duration >
promise_wait_status wait_until(
const std::chrono::time_point<Clock, Duration>& timeout_time) const
{
return state_->wait_until(timeout_time);
}
private:
class state;
std::shared_ptr<state> state_;
@@ -552,6 +579,7 @@ namespace promise_hpp
storage_.set(std::forward<U>(value));
status_ = status::resolved;
invoke_resolve_handlers_();
cond_var_.notify_all();
return true;
}
@@ -563,9 +591,51 @@ namespace promise_hpp
exception_ = e;
status_ = status::rejected;
invoke_reject_handlers_();
cond_var_.notify_all();
return true;
}
const T& get() {
std::unique_lock<std::mutex> lock(mutex_);
cond_var_.wait(lock, [this](){
return status_ != status::pending;
});
if ( status_ == status::rejected ) {
std::rethrow_exception(exception_);
}
assert(status_ == status::resolved);
return storage_.value();
}
void wait() const {
std::unique_lock<std::mutex> lock(mutex_);
cond_var_.wait(lock, [this](){
return status_ != status::pending;
});
}
template < typename Rep, typename Period >
promise_wait_status wait_for(
const std::chrono::duration<Rep, Period>& timeout_duration) const
{
std::unique_lock<std::mutex> lock(mutex_);
return cond_var_.wait_for(lock, timeout_duration, [this](){
return status_ != status::pending;
}) ? promise_wait_status::no_timeout
: promise_wait_status::timeout;
}
template < typename Clock, typename Duration >
promise_wait_status wait_until(
const std::chrono::time_point<Clock, Duration>& timeout_time) const
{
std::unique_lock<std::mutex> lock(mutex_);
return cond_var_.wait_until(lock, timeout_time, [this](){
return status_ != status::pending;
}) ? promise_wait_status::no_timeout
: promise_wait_status::timeout;
}
template < typename U, typename ResolveF, typename RejectF >
std::enable_if_t<std::is_void<U>::value, void>
attach(promise<U>& next, ResolveF&& resolve, RejectF&& reject) {
@@ -674,11 +744,17 @@ namespace promise_hpp
handlers_.clear();
}
private:
detail::storage<T> storage_;
status status_ = status::pending;
std::exception_ptr exception_ = nullptr;
enum class status {
pending,
resolved,
rejected
};
std::mutex mutex_;
status status_{status::pending};
std::exception_ptr exception_{nullptr};
mutable std::mutex mutex_;
mutable std::condition_variable cond_var_;
struct handler {
using resolve_t = std::function<void(const T&)>;
@@ -694,6 +770,7 @@ namespace promise_hpp
};
std::vector<handler> handlers_;
detail::storage<T> storage_;
};
};
@@ -705,12 +782,6 @@ namespace promise_hpp
class promise<void> final {
public:
using value_type = void;
enum class status : std::uint8_t {
pending,
resolved,
rejected
};
public:
promise()
: state_(std::make_shared<state>()) {}
@@ -858,6 +929,28 @@ namespace promise_hpp
return state_->reject(
std::make_exception_ptr(std::forward<E>(e)));
}
void get() const {
state_->get();
}
void wait() const {
state_->wait();
}
template < typename Rep, typename Period >
promise_wait_status wait_for(
const std::chrono::duration<Rep, Period>& timeout_duration) const
{
return state_->wait_for(timeout_duration);
}
template < typename Clock, typename Duration >
promise_wait_status wait_until(
const std::chrono::time_point<Clock, Duration>& timeout_time) const
{
return state_->wait_until(timeout_time);
}
private:
class state;
std::shared_ptr<state> state_;
@@ -873,6 +966,7 @@ namespace promise_hpp
}
status_ = status::resolved;
invoke_resolve_handlers_();
cond_var_.notify_all();
return true;
}
@@ -884,9 +978,50 @@ namespace promise_hpp
exception_ = e;
status_ = status::rejected;
invoke_reject_handlers_();
cond_var_.notify_all();
return true;
}
void get() {
std::unique_lock<std::mutex> lock(mutex_);
cond_var_.wait(lock, [this](){
return status_ != status::pending;
});
if ( status_ == status::rejected ) {
std::rethrow_exception(exception_);
}
assert(status_ == status::resolved);
}
void wait() const {
std::unique_lock<std::mutex> lock(mutex_);
cond_var_.wait(lock, [this](){
return status_ != status::pending;
});
}
template < typename Rep, typename Period >
promise_wait_status wait_for(
const std::chrono::duration<Rep, Period>& timeout_duration) const
{
std::unique_lock<std::mutex> lock(mutex_);
return cond_var_.wait_for(lock, timeout_duration, [this](){
return status_ != status::pending;
}) ? promise_wait_status::no_timeout
: promise_wait_status::timeout;
}
template < typename Clock, typename Duration >
promise_wait_status wait_until(
const std::chrono::time_point<Clock, Duration>& timeout_time) const
{
std::unique_lock<std::mutex> lock(mutex_);
return cond_var_.wait_until(lock, timeout_time, [this](){
return status_ != status::pending;
}) ? promise_wait_status::no_timeout
: promise_wait_status::timeout;
}
template < typename U, typename ResolveF, typename RejectF >
std::enable_if_t<std::is_void<U>::value, void>
attach(promise<U>& next, ResolveF&& resolve, RejectF&& reject) {
@@ -991,10 +1126,17 @@ namespace promise_hpp
handlers_.clear();
}
private:
status status_ = status::pending;
std::exception_ptr exception_ = nullptr;
enum class status {
pending,
resolved,
rejected
};
std::mutex mutex_;
status status_{status::pending};
std::exception_ptr exception_{nullptr};
mutable std::mutex mutex_;
mutable std::condition_variable cond_var_;
struct handler {
using resolve_t = std::function<void()>;

139
tests.cpp
View File

@@ -10,6 +10,7 @@
#include "promise.hpp"
namespace pr = promise_hpp;
#include <thread>
#include <cstring>
namespace
@@ -36,6 +37,25 @@ namespace
return false;
}
}
class auto_thread final {
public:
template < typename F, typename... Args >
auto_thread(F&& f, Args&&... args)
: thread_(std::forward<F>(f), std::forward<Args>(args)...) {}
~auto_thread() noexcept {
if ( thread_.joinable() ) {
thread_.join();
}
}
void join() {
thread_.join();
}
private:
std::thread thread_;
};
}
TEST_CASE("is_promise") {
@@ -531,12 +551,10 @@ TEST_CASE("promise") {
{
bool call_fail_with_logic_error = false;
auto p1 = pr::make_resolved_promise(42);
auto p2 = pr::make_resolved_promise(84);
p1.then([&p2](int v){
p1.then([](int v) -> pr::promise<int> {
(void)v;
throw std::logic_error("hello fail");
return p2;
}).then([](int v2){
(void)v2;
}).except([&call_fail_with_logic_error](std::exception_ptr e){
@@ -599,11 +617,9 @@ TEST_CASE("promise") {
{
bool call_fail_with_logic_error = false;
auto p1 = pr::make_resolved_promise();
auto p2 = pr::make_resolved_promise();
p1.then([&p2](){
p1.then([]() -> pr::promise<void> {
throw std::logic_error("hello fail");
return p2;
}).then([](){
}).except([&call_fail_with_logic_error](std::exception_ptr e){
call_fail_with_logic_error = check_hello_fail_exception(e);
@@ -850,3 +866,114 @@ TEST_CASE("promise") {
}
}
}
TEST_CASE("get_and_wait") {
SECTION("get_void_promises") {
{
auto p = pr::make_resolved_promise();
REQUIRE_NOTHROW(p.get());
}
{
auto p = pr::make_rejected_promise<void>(std::logic_error("hello fail"));
REQUIRE_THROWS_AS(p.get(), std::logic_error);
}
{
auto p = pr::promise<void>();
auto_thread t{[p]() mutable {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
p.resolve();
}};
t.join();
REQUIRE_NOTHROW(p.get());
}
{
auto p = pr::promise<void>();
auto_thread t{[p]() mutable {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
p.resolve();
}};
REQUIRE_NOTHROW(p.get());
}
{
auto p = pr::promise<void>();
auto_thread t{[p]() mutable {
std::this_thread::sleep_for(std::chrono::milliseconds(30));
p.resolve();
}};
REQUIRE(p.wait_for(
std::chrono::milliseconds(1))
== pr::promise_wait_status::timeout);
REQUIRE(p.wait_for(
std::chrono::milliseconds(60))
== pr::promise_wait_status::no_timeout);
}
{
auto p = pr::promise<void>();
auto_thread t{[p]() mutable {
std::this_thread::sleep_for(std::chrono::milliseconds(30));
p.resolve();
}};
REQUIRE(p.wait_until(
std::chrono::high_resolution_clock::now() + std::chrono::milliseconds(1))
== pr::promise_wait_status::timeout);
REQUIRE(p.wait_until(
std::chrono::high_resolution_clock::now() + std::chrono::milliseconds(60))
== pr::promise_wait_status::no_timeout);
}
}
SECTION("get_typed_promises") {
{
auto p = pr::make_resolved_promise(42);
REQUIRE(p.get() == 42);
}
{
auto p = pr::make_rejected_promise<int>(std::logic_error("hello fail"));
REQUIRE_THROWS_AS(p.get(), std::logic_error);
}
{
auto p = pr::promise<int>();
auto_thread t{[p]() mutable {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
p.resolve(42);
}};
t.join();
REQUIRE(p.get() == 42);
}
{
auto p = pr::promise<int>();
auto_thread t{[p]() mutable {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
p.resolve(42);
}};
REQUIRE(p.get() == 42);
}
{
auto p = pr::promise<int>();
auto_thread t{[p]() mutable {
std::this_thread::sleep_for(std::chrono::milliseconds(30));
p.resolve(42);
}};
REQUIRE(p.wait_for(
std::chrono::milliseconds(1))
== pr::promise_wait_status::timeout);
REQUIRE(p.wait_for(
std::chrono::milliseconds(60))
== pr::promise_wait_status::no_timeout);
REQUIRE(p.get() == 42);
}
{
auto p = pr::promise<int>();
auto_thread t{[p]() mutable {
std::this_thread::sleep_for(std::chrono::milliseconds(30));
p.resolve(42);
}};
REQUIRE(p.wait_until(
std::chrono::high_resolution_clock::now() + std::chrono::milliseconds(1))
== pr::promise_wait_status::timeout);
REQUIRE(p.wait_until(
std::chrono::high_resolution_clock::now() + std::chrono::milliseconds(60))
== pr::promise_wait_status::no_timeout);
REQUIRE(p.get() == 42);
}
}
}