diff --git a/headers/promise.hpp/promise.hpp b/headers/promise.hpp/promise.hpp index f22b737..778729a 100644 --- a/headers/promise.hpp/promise.hpp +++ b/headers/promise.hpp/promise.hpp @@ -331,6 +331,18 @@ namespace promise_hpp }); } + template < typename ResolveF > + auto then_any(ResolveF&& on_resolve) { + return then([ + f = std::forward(on_resolve) + ](auto&& v) mutable { + auto r = std::invoke( + std::forward(f), + std::forward(v)); + return make_any_promise(std::move(r)); + }); + } + template < typename ResolveF > auto then_race(ResolveF&& on_resolve) { return then([ @@ -743,6 +755,17 @@ namespace promise_hpp }); } + template < typename ResolveF > + auto then_any(ResolveF&& on_resolve) { + return then([ + f = std::forward(on_resolve) + ]() mutable { + auto r = std::invoke( + std::forward(f)); + return make_any_promise(std::move(r)); + }); + } + template < typename ResolveF > auto then_race(ResolveF&& on_resolve) { return then([ @@ -1095,34 +1118,6 @@ namespace promise_hpp // make_all_promise // - namespace impl - { - template < typename ResultType > - class all_promise_context_t final : private detail::noncopyable { - public: - all_promise_context_t(std::size_t count) - : results_(count) {} - - template < typename T > - bool apply_result(std::size_t index, T&& value) { - results_[index] = std::forward(value); - return ++counter_ == results_.size(); - } - - std::vector get_results() { - std::vector ret; - ret.reserve(results_.size()); - for ( auto&& v : results_ ) { - ret.push_back(std::move(*v)); - } - return ret; - } - private: - std::atomic_size_t counter_{0}; - std::vector> results_; - }; - } - template < typename Iter , typename SubPromise = typename std::iterator_traits::value_type , typename SubPromiseResult = typename SubPromise::value_type @@ -1132,18 +1127,32 @@ namespace promise_hpp if ( begin == end ) { return make_resolved_promise(ResultPromiseValueType()); } + + struct context_t { + std::atomic_size_t success_counter{0u}; + std::vector> results; + context_t(std::size_t count) + : success_counter(count) + , results(count) {} + }; + return make_promise([begin, end](auto&& resolver, auto&& rejector){ std::size_t result_index = 0; - auto context = std::make_shared>(std::distance(begin, end)); + auto context = std::make_shared(std::distance(begin, end)); for ( Iter iter = begin; iter != end; ++iter, ++result_index ) { (*iter).then([ context, resolver, result_index ](auto&& v) mutable { - if ( context->apply_result(result_index, std::forward(v)) ) { - resolver(context->get_results()); + context->results[result_index] = std::forward(v); + if ( !--context->success_counter ) { + std::vector results; + results.reserve(context->results.size()); + for ( auto&& r : context->results ) { + results.push_back(std::move(*r)); + } + resolver(std::move(results)); } }).except(rejector); } @@ -1157,6 +1166,46 @@ namespace promise_hpp std::end(container)); } + // + // make_any_promise + // + + template < typename Iter + , typename SubPromise = typename std::iterator_traits::value_type + , typename SubPromiseResult = typename SubPromise::value_type > + promise + make_any_promise(Iter begin, Iter end) { + if ( begin == end ) { + throw std::logic_error("at least one input promise must be provided for make_any_promise"); + } + + struct context_t { + std::atomic_size_t failure_counter{0u}; + context_t(std::size_t count) + : failure_counter(count) {} + }; + + return make_promise([begin, end](auto&& resolver, auto&& rejector){ + auto context = std::make_shared(std::distance(begin, end)); + for ( Iter iter = begin; iter != end; ++iter ) { + (*iter).then([resolver](auto&& v) mutable { + resolver(std::forward(v)); + }).except([context, rejector](std::exception_ptr e) mutable { + if ( !--context->failure_counter ) { + rejector(e); + } + }); + } + }); + } + + template < typename Container > + auto make_any_promise(Container&& container) { + return make_any_promise( + std::begin(container), + std::end(container)); + } + // // make_race_promise // @@ -1169,6 +1218,7 @@ namespace promise_hpp if ( begin == end ) { throw std::logic_error("at least one input promise must be provided for make_race_promise"); } + return make_promise([begin, end](auto&& resolver, auto&& rejector){ for ( Iter iter = begin; iter != end; ++iter ) { (*iter) diff --git a/untests/promise_tests.cpp b/untests/promise_tests.cpp index a112555..abd2fe7 100644 --- a/untests/promise_tests.cpp +++ b/untests/promise_tests.cpp @@ -692,9 +692,23 @@ TEST_CASE("promise") { } { bool all_is_ok = false; - auto p = pr::make_all_promise(std::vector>{ - pr::make_resolved_promise(32), - pr::make_resolved_promise(10) + auto p = pr::make_resolved_promise().then_all([](){ + return std::vector>{ + pr::make_resolved_promise(32), + pr::make_resolved_promise(10)}; + }).then([&all_is_ok](const std::vector& c){ + all_is_ok = (2 == c.size()) + && c[0] == 32 + && c[1] == 10; + }); + REQUIRE(all_is_ok); + } + { + bool all_is_ok = false; + auto p = pr::make_resolved_promise(1).then_all([](int){ + return std::vector>{ + pr::make_resolved_promise(32), + pr::make_resolved_promise(10)}; }).then([&all_is_ok](const std::vector& c){ all_is_ok = (2 == c.size()) && c[0] == 32 @@ -749,6 +763,60 @@ TEST_CASE("promise") { }); } } + SECTION("make_any_promise") { + REQUIRE_THROWS_AS( + pr::make_any_promise(std::vector>{}), + std::logic_error); + { + auto p = pr::make_resolved_promise().then_any([](){ + return std::vector>{ + pr::make_resolved_promise(32), + pr::make_resolved_promise(10)}; + }).then([](int i){ + return i; + }); + REQUIRE(p.get() == 32); + } + { + auto p = pr::make_resolved_promise(1).then_any([](int){ + return std::vector>{ + pr::make_resolved_promise(32), + pr::make_resolved_promise(10)}; + }).then([](int i){ + return i; + }); + REQUIRE(p.get() == 32); + } + { + auto p = pr::make_any_promise(std::vector>{ + pr::make_resolved_promise(32), + pr::make_rejected_promise(std::logic_error("hello fail")) + }).then([](int i){ + return i; + }); + REQUIRE(p.get() == 32); + } + { + auto p = pr::make_any_promise(std::vector>{ + pr::make_rejected_promise(std::logic_error("hello fail")), + pr::make_resolved_promise(32) + }).then([](int i){ + return i; + }); + REQUIRE(p.get() == 32); + } + { + bool all_is_ok = false; + auto p = pr::make_any_promise(std::vector>{ + pr::make_rejected_promise(std::logic_error("hello fail")), + pr::make_rejected_promise(std::logic_error("hello fail")) + }).except([&all_is_ok](std::exception_ptr e){ + all_is_ok = true; + return 0; + }); + REQUIRE(all_is_ok); + } + } SECTION("make_race_promise") { { auto p1 = pr::promise();