diff --git a/headers/promise.hpp/promise.hpp b/headers/promise.hpp/promise.hpp index 778729a..daff97e 100644 --- a/headers/promise.hpp/promise.hpp +++ b/headers/promise.hpp/promise.hpp @@ -86,6 +86,54 @@ namespace promise_hpp no_timeout, timeout }; + + // + // aggregate_exception + // + + class aggregate_exception final : public std::exception { + private: + using exceptions_t = std::vector; + using internal_state_t = std::shared_ptr; + public: + aggregate_exception() + : state_(std::make_shared()) {} + + explicit aggregate_exception(exceptions_t exceptions) + : state_(std::make_shared(std::move(exceptions))) {} + + aggregate_exception(const aggregate_exception& other) noexcept + : state_(other.state_) {} + + aggregate_exception& operator=(const aggregate_exception& other) noexcept { + if ( this != &other ) { + state_ = other.state_; + } + return *this; + } + + const char* what() const noexcept override { + return "Aggregate exception"; + } + + bool empty() const noexcept { + return (*state_).empty(); + } + + std::size_t size() const noexcept { + return (*state_).size(); + } + + std::exception_ptr at(std::size_t index) const { + return (*state_).at(index); + } + + std::exception_ptr operator[](std::size_t index) const noexcept { + return (*state_)[index]; + } + private: + internal_state_t state_; + }; } // ----------------------------------------------------------------------------- @@ -416,6 +464,21 @@ namespace promise_hpp [](auto&& v) { return std::forward(v); }, std::forward(on_reject)); } + + // + // finally + // + + template < typename FinallyF > + promise finally(FinallyF&& on_finally) { + return then([f = on_finally](auto&& v) { + std::invoke(std::move(f)); + return std::forward(v); + }, [f = on_finally](std::exception_ptr e) -> T { + std::invoke(std::move(f)); + std::rethrow_exception(e); + }); + } private: class state; std::shared_ptr state_; @@ -837,6 +900,20 @@ namespace promise_hpp [](){}, std::forward(on_reject)); } + + // + // finally + // + + template < typename FinallyF > + promise finally(FinallyF&& on_finally) { + return then([f = on_finally]() { + std::invoke(std::move(f)); + }, [f = on_finally](std::exception_ptr e) { + std::invoke(std::move(f)); + std::rethrow_exception(e); + }); + } private: class state; std::shared_ptr state_; @@ -1136,15 +1213,12 @@ namespace promise_hpp , results(count) {} }; - return make_promise([begin, end](auto&& resolver, auto&& rejector){ + return make_promise( + [begin, end](auto&& resolver, auto&& rejector){ std::size_t result_index = 0; auto context = std::make_shared(std::distance(begin, end)); for ( Iter iter = begin; iter != end; ++iter, ++result_index ) { - (*iter).then([ - context, - resolver, - result_index - ](auto&& v) mutable { + (*iter).then([context, resolver, result_index](auto&& v) mutable { context->results[result_index] = std::forward(v); if ( !--context->success_counter ) { std::vector results; @@ -1154,7 +1228,9 @@ namespace promise_hpp } resolver(std::move(results)); } - }).except(rejector); + }).except([rejector](std::exception_ptr e) mutable { + rejector(e); + }); } }); } @@ -1172,27 +1248,33 @@ namespace promise_hpp template < typename Iter , typename SubPromise = typename std::iterator_traits::value_type - , typename SubPromiseResult = typename SubPromise::value_type > - promise + , typename SubPromiseResult = typename SubPromise::value_type + , typename ResultPromiseValueType = SubPromiseResult > + promise make_any_promise(Iter begin, Iter end) { if ( begin == end ) { - throw std::logic_error("at least one input promise must be provided for make_any_promise"); + return make_rejected_promise(aggregate_exception()); } struct context_t { std::atomic_size_t failure_counter{0u}; + std::vector exceptions; context_t(std::size_t count) - : failure_counter(count) {} + : failure_counter(count) + , exceptions(count) {} }; - return make_promise([begin, end](auto&& resolver, auto&& rejector){ + return make_promise( + [begin, end](auto&& resolver, auto&& rejector){ + std::size_t exception_index = 0; auto context = std::make_shared(std::distance(begin, end)); - for ( Iter iter = begin; iter != end; ++iter ) { + for ( Iter iter = begin; iter != end; ++iter, ++exception_index ) { (*iter).then([resolver](auto&& v) mutable { resolver(std::forward(v)); - }).except([context, rejector](std::exception_ptr e) mutable { + }).except([context, rejector, exception_index](std::exception_ptr e) mutable { + context->exceptions[exception_index] = e; if ( !--context->failure_counter ) { - rejector(e); + rejector(aggregate_exception(std::move(context->exceptions))); } }); } @@ -1212,14 +1294,12 @@ namespace promise_hpp template < typename Iter , typename SubPromise = typename std::iterator_traits::value_type - , typename SubPromiseResult = typename SubPromise::value_type > - promise + , typename SubPromiseResult = typename SubPromise::value_type + , typename ResultPromiseValueType = SubPromiseResult > + promise make_race_promise(Iter begin, Iter end) { - if ( begin == end ) { - throw std::logic_error("at least one input promise must be provided for make_race_promise"); - } - - return make_promise([begin, end](auto&& resolver, auto&& rejector){ + return make_promise( + [begin, end](auto&& resolver, auto&& rejector){ for ( Iter iter = begin; iter != end; ++iter ) { (*iter) .then(resolver) diff --git a/untests/promise_tests.cpp b/untests/promise_tests.cpp index abd2fe7..0102267 100644 --- a/untests/promise_tests.cpp +++ b/untests/promise_tests.cpp @@ -30,6 +30,30 @@ namespace } } + bool check_empty_aggregate_exception(std::exception_ptr e) { + try { + std::rethrow_exception(e); + } catch (pr::aggregate_exception& ee) { + return ee.empty(); + } catch (...) { + return false; + } + } + + bool check_two_aggregate_exception(std::exception_ptr e) { + try { + std::rethrow_exception(e); + } catch (pr::aggregate_exception& ee) { + if ( ee.size() != 2 ) { + return false; + } + return check_hello_fail_exception(ee[0]) + && check_hello_fail_exception(ee[1]); + } catch (...) { + return false; + } + } + class auto_thread final { public: template < typename F, typename... Args > @@ -310,6 +334,132 @@ TEST_CASE("promise") { REQUIRE(call_fail_with_logic_error); } } + SECTION("finally") { + { + bool all_is_ok = false; + auto p = pr::promise(); + p.finally([&all_is_ok](){ + all_is_ok = true; + }); + REQUIRE_FALSE(all_is_ok); + p.resolve(1); + REQUIRE(all_is_ok); + } + { + bool all_is_ok = false; + auto p = pr::promise(); + p.finally([&all_is_ok](){ + all_is_ok = true; + }); + REQUIRE_FALSE(all_is_ok); + p.reject(std::make_exception_ptr(std::logic_error("hello fail"))); + REQUIRE(all_is_ok); + } + { + bool all_is_ok = false; + pr::make_resolved_promise(1) + .finally([&all_is_ok](){ + all_is_ok = true; + }); + REQUIRE(all_is_ok); + } + { + bool all_is_ok = false; + pr::make_rejected_promise(std::logic_error("hello fail")) + .finally([&all_is_ok](){ + all_is_ok = true; + }); + REQUIRE(all_is_ok); + } + } + SECTION("after_finally") { + { + int check_84_int = 0; + auto p = pr::promise<>(); + p.finally([&check_84_int](){ + check_84_int = 42; + return 100500; + }).then([&check_84_int](){ + check_84_int *= 2; + }); + REQUIRE(check_84_int == 0); + p.resolve(); + REQUIRE(check_84_int == 84); + } + { + int check_84_int = 0; + auto p = pr::promise<>(); + p.finally([&check_84_int](){ + check_84_int = 42; + return 100500; + }).except([&check_84_int](std::exception_ptr){ + check_84_int *= 2; + }); + REQUIRE(check_84_int == 0); + p.reject(std::make_exception_ptr(std::logic_error("hello fail"))); + REQUIRE(check_84_int == 84); + } + } + SECTION("failed_finally") { + { + int check_84_int = 0; + auto p = pr::promise<>(); + p.finally([&check_84_int](){ + check_84_int += 42; + throw std::logic_error("hello fail"); + }).except([&check_84_int](std::exception_ptr e){ + if ( check_hello_fail_exception(e) ) { + check_84_int += 42; + } + }); + p.resolve(); + REQUIRE(check_84_int == 84); + } + { + int check_84_int = 0; + auto p = pr::promise<>(); + p.finally([&check_84_int](){ + check_84_int += 42; + throw std::logic_error("hello fail"); + }).except([&check_84_int](std::exception_ptr e){ + if ( check_hello_fail_exception(e) ) { + check_84_int += 42; + } + }); + p.reject(std::make_exception_ptr(std::logic_error("hello"))); + REQUIRE(check_84_int == 84); + } + { + int check_84_int = 0; + auto p = pr::promise(); + p.finally([&check_84_int](){ + check_84_int += 42; + throw std::logic_error("hello fail"); + }).except([&check_84_int](std::exception_ptr e) -> int { + if ( check_hello_fail_exception(e) ) { + check_84_int += 42; + } + return 0; + }); + p.resolve(1); + REQUIRE(check_84_int == 84); + } + { + int check_84_int = 0; + auto p = pr::promise(); + p.finally([&check_84_int](){ + check_84_int += 42; + throw std::logic_error("hello fail"); + }).except([&check_84_int](std::exception_ptr e) -> int { + if ( check_hello_fail_exception(e) ) { + check_84_int += 42; + } + return 0; + }); + p.reject(std::make_exception_ptr(std::logic_error("hello"))); + REQUIRE(check_84_int == 84); + } + } SECTION("make_promise") { { int check_84_int = 0; @@ -764,9 +914,15 @@ TEST_CASE("promise") { } } SECTION("make_any_promise") { - REQUIRE_THROWS_AS( - pr::make_any_promise(std::vector>{}), - std::logic_error); + { + bool all_is_ok = false; + auto p = pr::make_any_promise(std::vector>{}); + p.except([&all_is_ok](std::exception_ptr e){ + all_is_ok = check_empty_aggregate_exception(e); + return 0; + }); + REQUIRE(all_is_ok); + } { auto p = pr::make_resolved_promise().then_any([](){ return std::vector>{ @@ -811,7 +967,7 @@ TEST_CASE("promise") { pr::make_rejected_promise(std::logic_error("hello fail")), pr::make_rejected_promise(std::logic_error("hello fail")) }).except([&all_is_ok](std::exception_ptr e){ - all_is_ok = true; + all_is_ok = check_two_aggregate_exception(e); return 0; }); REQUIRE(all_is_ok); @@ -1014,9 +1170,6 @@ TEST_CASE("promise") { } } SECTION("make_race_promise_fail") { - REQUIRE_THROWS_AS( - pr::make_race_promise(std::vector>{}), - std::logic_error); { bool call_fail_with_logic_error = false; bool not_call_then_on_reject = true;