diff --git a/README.md b/README.md index 4ce8409..e6a1336 100644 --- a/README.md +++ b/README.md @@ -25,10 +25,11 @@ ## Installation -[flat.hpp][flat] is a single header library. All you need to do is copy the header file into your project and include this file: +[flat.hpp][flat] is a header only library. All you need to do is copy the header files (`flat_set.hpp` and `flat_map.hpp`) into your project and include them: ```cpp -#include "flat.hpp" +#include "flat_set.hpp" // for flat_set +#include "flat_map.hpp" // for flat_map ``` ## API diff --git a/flat_map.hpp b/flat_map.hpp index 8e03418..b117b2a 100644 --- a/flat_map.hpp +++ b/flat_map.hpp @@ -9,17 +9,40 @@ #include #include #include +#include +#include #include #include +#include namespace flat_hpp { template < typename Key , typename Value , typename Compare = std::less - , typename Allocator = std::allocator> > + , typename Allocator = std::allocator> > class flat_map final { - using data_type = std::vector, Allocator>; + using data_type = std::vector< + std::pair, + Allocator>; + + class uber_comparer_type : public Compare { + public: + uber_comparer_type() = default; + uber_comparer_type(const Compare& c) : Compare(c) {} + + bool operator()(const Key& l, const Key& r) const { + return Compare::operator()(l, r); + } + + bool operator()(const Key& l, typename data_type::const_reference r) const { + return Compare::operator()(l, r.first); + } + + bool operator()(typename data_type::const_reference l, const Key& r) const { + return Compare::operator()(l.first, r); + } + }; public: using key_type = Key; using mapped_type = Value; @@ -40,5 +63,273 @@ namespace flat_hpp using const_iterator = typename data_type::const_iterator; using reverse_iterator = typename data_type::reverse_iterator; using const_reverse_iterator = typename data_type::const_reverse_iterator; + + class value_compare final { + public: + bool operator()(const value_type& l, const value_type& r) const { + return compare_(l.first, r.first); + } + private: + friend class flat_map; + explicit value_compare(const key_compare& compare) + : compare_(compare) {} + private: + key_compare compare_; + }; + + static_assert( + std::is_same::value, + "Allocator::value_type must be same type as value_type"); + public: + explicit flat_map( + const Allocator& a) + : data_(a) {} + + explicit flat_map( + const Compare& c = Compare(), + const Allocator& a = Allocator()) + : data_(a) + , compare_(c) {} + + template < typename InputIter > + flat_map( + InputIter first, + InputIter last, + const Allocator& a) + : data_(a) { + insert(first, last); + } + + template < typename InputIter > + flat_map( + InputIter first, + InputIter last, + const Compare& c = Compare(), + const Allocator& a = Allocator()) + : data_(a) + , compare_(c) { + insert(first, last); + } + + flat_map( + std::initializer_list ilist, + const Allocator& a) + : data_(a) { + insert(ilist); + } + + flat_map( + std::initializer_list ilist, + const Compare& c = Compare(), + const Allocator& a = Allocator()) + : data_(a) + , compare_(c) { + insert(ilist); + } + + iterator begin() noexcept { return data_.begin(); } + const_iterator begin() const noexcept { return data_.begin(); } + const_iterator cbegin() const noexcept { return data_.cbegin(); } + + iterator end() noexcept { return data_.end(); } + const_iterator end() const noexcept { return data_.end(); } + const_iterator cend() const noexcept { return data_.cend(); } + + reverse_iterator rbegin() noexcept { return data_.rbegin(); } + const_reverse_iterator rbegin() const noexcept { return data_.rbegin(); } + const_reverse_iterator crbegin() const noexcept { return data_.crbegin(); } + + reverse_iterator rend() noexcept { return data_.rend(); } + const_reverse_iterator rend() const noexcept { return data_.rend(); } + const_reverse_iterator crend() const noexcept { return data_.crend(); } + + bool empty() const noexcept { + return data_.empty(); + } + + size_type size() const noexcept { + return data_.size(); + } + + size_type max_size() const noexcept { + return data_.max_size(); + } + + mapped_type& operator[](key_type&& key) { + return insert(value_type(std::move(key), mapped_type())).first->second; + } + + mapped_type& operator[](const key_type& key) { + return insert(value_type(key, mapped_type())).first->second; + } + + mapped_type& at(const key_type& key) { + const auto iter = find(key); + if ( iter != end() ) { + return iter->second; + } + throw std::out_of_range("flat_map::at: key not found"); + } + + const mapped_type& at(const key_type& key) const { + const auto iter = find(key); + if ( iter != end() ) { + return iter->second; + } + throw std::out_of_range("flat_map::at: key not found"); + } + + std::pair insert(const value_type& value) { + const iterator iter = lower_bound(value.first); + return iter == end() || compare_(value.first, iter->first) + ? std::make_pair(data_.insert(iter, value), true) + : std::make_pair(iter, false); + } + + iterator insert(const_iterator hint, const value_type& value) { + return (hint == begin() || compare_((hint - 1)->first, value.first)) + && (hint == end() || compare_(value.first, hint->first)) + ? data_.insert(hint, std::move(value)) + : insert(std::move(value)).first; + } + + template < typename InputIter > + void insert(InputIter first, InputIter last) { + while ( first != last ) { + insert(*first++); + } + } + + void insert(std::initializer_list ilist) { + insert(ilist.begin(), ilist.end()); + } + + template < typename... Args > + std::pair emplace(Args&&... args) { + return insert(value_type(std::forward(args)...)); + } + + template < typename... Args > + iterator emplace_hint(const_iterator hint, Args&&... args) { + return insert(hint, value_type(std::forward(args)...)); + } + + void clear() noexcept { + data_.clear(); + } + + iterator erase(const_iterator iter) { + return data_.erase(iter); + } + + iterator erase(const_iterator first, const_iterator last) { + return data_.erase(first, last); + } + + size_type erase(const key_type& key) { + const iterator iter = find(key); + return iter != end() + ? (erase(iter), 1) + : 0; + } + + void swap(flat_map& other) { + using std::swap; + swap(data_, other.data_); + swap(compare_, other.compare_); + } + + size_type count(const key_type& key) const { + const auto iter = find(key); + return iter != end() ? 1 : 0; + } + + iterator find(const key_type& key) { + const iterator iter = lower_bound(key); + return iter != end() && !compare_(key, iter->first) + ? iter + : end(); + } + + const_iterator find(const key_type& key) const { + const const_iterator iter = lower_bound(key); + return iter != end() && !compare_(key, iter->first) + ? iter + : end(); + } + + std::pair equal_range(const key_type& key) { + return std::equal_range(begin(), end(), key, compare_); + } + + std::pair equal_range(const key_type& key) const { + return std::equal_range(begin(), end(), key, compare_); + } + + iterator lower_bound(const key_type& key) { + return std::lower_bound(begin(), end(), key, compare_); + } + + const_iterator lower_bound(const key_type& key) const { + return std::lower_bound(begin(), end(), key, compare_); + } + + iterator upper_bound(const key_type& key) { + return std::upper_bound(begin(), end(), key, compare_); + } + + const_iterator upper_bound(const key_type& key) const { + return std::upper_bound(begin(), end(), key, compare_); + } + + key_compare key_comp() const { + return compare_; + } + + value_compare value_comp() const { + return value_compare(compare_); + } + private: + data_type data_; + uber_comparer_type compare_; }; } + +namespace flat_hpp +{ + template < typename K, typename V, typename C, typename A > + void swap(flat_map& l, flat_map& r) { + l.swap(r); + } + + template < typename K, typename V, typename C, typename A > + bool operator==(const flat_map& l, const flat_map& r) { + return l.size() == r.size() + && std::equal(l.begin(), l.end(), r.begin(), r.end()); + } + + template < typename K, typename V, typename C, typename A > + bool operator!=(const flat_map& l, const flat_map& r) { + return !(l == r); + } + + template < typename K, typename V, typename C, typename A > + bool operator<(const flat_map& l, const flat_map& r) { + return std::lexicographical_compare(l.begin(), l.end(), r.begin(), r.end()); + } + + template < typename K, typename V, typename C, typename A > + bool operator>(const flat_map& l, const flat_map& r) { + return r < l; + } + + template < typename K, typename V, typename C, typename A > + bool operator<=(const flat_map& l, const flat_map& r) { + return !(r < l); + } + + template < typename K, typename V, typename C, typename A > + bool operator>=(const flat_map& l, const flat_map& r) { + return !(l < r); + } +} diff --git a/flat_map_tests.cpp b/flat_map_tests.cpp index 2a51191..1957624 100644 --- a/flat_map_tests.cpp +++ b/flat_map_tests.cpp @@ -8,15 +8,50 @@ #include "catch.hpp" #include "flat_map.hpp" -namespace flat = flat_hpp; +using namespace flat_hpp; namespace { + template < typename T > + class dummy_allocator { + public: + using value_type = T; + + dummy_allocator() = default; + + template < typename U > + dummy_allocator(const dummy_allocator&) noexcept { + } + + T* allocate(std::size_t n) noexcept { + return static_cast(std::malloc(sizeof(T) * n)); + } + + void deallocate(T* p, std::size_t n) noexcept { + (void)n; + std::free(p); + } + }; + + template < typename T, typename U > + bool operator==(const dummy_allocator&, const dummy_allocator&) noexcept { + return true; + } + + template < typename T, typename U > + bool operator!=(const dummy_allocator& l, const dummy_allocator& r) noexcept { + return !(l == r); + } + + template < typename T > + constexpr std::add_const_t& my_as_const(T& t) noexcept { + return t; + } } TEST_CASE("flat_map") { - { - using map_t = flat::flat_map; + SECTION("types") { + using map_t = flat_map; static_assert( std::is_same::value, @@ -25,7 +60,7 @@ TEST_CASE("flat_map") { std::is_same::value, "unit test static error"); static_assert( - std::is_same>::value, + std::is_same>::value, "unit test static error"); static_assert( @@ -36,17 +71,263 @@ TEST_CASE("flat_map") { "unit test static error"); static_assert( - std::is_same&>::value, + std::is_same&>::value, "unit test static error"); static_assert( - std::is_same&>::value, + std::is_same&>::value, "unit test static error"); static_assert( - std::is_same*>::value, + std::is_same*>::value, "unit test static error"); static_assert( - std::is_same*>::value, + std::is_same*>::value, "unit test static error"); } + SECTION("ctors") { + using alloc_t = dummy_allocator< + std::pair>; + + using map_t = flat_map< + int, + unsigned, + std::less, + alloc_t>; + + using map2_t = flat_map< + int, + unsigned, + std::greater, + alloc_t>; + + { + auto s0 = map_t(); + auto s1 = map2_t(alloc_t()); + auto s2 = map_t(std::less()); + auto s3 = map2_t(std::greater(), alloc_t()); + } + + { + using vec_t = std::vector>; + + vec_t v{{1,30},{2,20},{3,10}}; + auto s0 = map_t(v.cbegin(), v.cend()); + auto s1 = map2_t(v.cbegin(), v.cend(), alloc_t()); + auto s2 = map_t(v.cbegin(), v.cend(), std::less()); + auto s3 = map2_t(v.cbegin(), v.cend(), std::greater(), alloc_t()); + + REQUIRE(vec_t(s0.begin(), s0.end()) == vec_t({{1,30},{2,20},{3,10}})); + REQUIRE(vec_t(s1.begin(), s1.end()) == vec_t({{3,10},{2,20},{1,30}})); + REQUIRE(vec_t(s2.begin(), s2.end()) == vec_t({{1,30},{2,20},{3,10}})); + REQUIRE(vec_t(s3.begin(), s3.end()) == vec_t({{3,10},{2,20},{1,30}})); + } + + { + auto s0 = map_t({{0,1}, {1,2}}); + auto s1 = map_t({{0,1}, {1,2}}, alloc_t()); + auto s2 = map_t({{0,1}, {1,2}}, std::less()); + auto s3 = map_t({{0,1}, {1,2}}, std::less(), alloc_t()); + } + } + SECTION("capacity") { + using map_t = flat_map; + map_t s0; + + REQUIRE(s0.empty()); + REQUIRE_FALSE(s0.size()); + REQUIRE(s0.max_size() == std::allocator>().max_size()); + + s0.insert({2,42}); + + REQUIRE_FALSE(s0.empty()); + REQUIRE(s0.size() == 1u); + REQUIRE(s0.max_size() == std::allocator>().max_size()); + + s0.insert({2,84}); + REQUIRE(s0.size() == 1u); + + s0.insert({3,84}); + REQUIRE(s0.size() == 2u); + + s0.clear(); + + REQUIRE(s0.empty()); + REQUIRE_FALSE(s0.size()); + REQUIRE(s0.max_size() == std::allocator>().max_size()); + } + SECTION("access") { + using map_t = flat_map; + map_t s0; + s0[1] = 42; + REQUIRE(s0 == map_t{{1,42}}); + s0[1] = 84; + REQUIRE(s0 == map_t{{1,84}}); + + REQUIRE(s0.at(1) == 84); + REQUIRE(my_as_const(s0).at(1) == 84); + REQUIRE_THROWS_AS(s0.at(0), std::out_of_range); + REQUIRE_THROWS_AS(my_as_const(s0).at(0), std::out_of_range); + } + SECTION("inserts") { + struct obj_t { + obj_t(int i) : i(i) {} + int i; + + bool operator<(const obj_t& o) const { + return i < o.i; + } + + bool operator==(const obj_t& o) const { + return i == o.i; + } + }; + + using map_t = flat_map; + + { + map_t s0; + + auto i0 = s0.insert(std::make_pair(1, 42)); + REQUIRE(s0 == map_t{{1,42}}); + REQUIRE(i0 == std::make_pair(s0.begin(), true)); + + auto i1 = s0.insert(std::make_pair(1, obj_t(42))); + REQUIRE(s0 == map_t{{1,42}}); + REQUIRE(i1 == std::make_pair(s0.begin(), false)); + + auto i2 = s0.insert(std::make_pair(2, obj_t(42))); + REQUIRE(s0 == map_t{{1,42},{2,42}}); + REQUIRE(i2 == std::make_pair(s0.begin() + 1, true)); + + auto i3 = s0.insert(s0.cend(), std::make_pair(3, 84)); + REQUIRE(i3 == s0.begin() + 2); + + s0.insert(s0.cend(), std::make_pair(4, obj_t(84))); + auto i4 = s0.insert(s0.cend(), std::make_pair(0, obj_t(21))); + REQUIRE(i4 == s0.begin()); + + auto i5 = s0.emplace(5, 100500); + REQUIRE(i5 == std::make_pair(s0.end() - 1, true)); + REQUIRE(s0 == map_t{{0,21},{1,42},{2,42},{3,84},{4,84},{5,100500}}); + + auto i6 = s0.emplace_hint(s0.cend(), 6, 100500); + REQUIRE(i6 == s0.end() - 1); + REQUIRE(s0 == map_t{{0,21},{1,42},{2,42},{3,84},{4,84},{5,100500},{6,100500}}); + } + } + SECTION("erasers") { + using map_t = flat_map; + { + map_t s0{{1,2},{2,3},{3,4}}; + s0.clear(); + REQUIRE(s0.empty()); + } + { + map_t s0{{1,2},{2,3},{3,4}}; + auto i = s0.erase(s0.find(2)); + REQUIRE(i == s0.begin() + 1); + REQUIRE(s0 == map_t{{1,2},{3,4}}); + } + { + map_t s0{{1,2},{2,3},{3,4}}; + auto i = s0.erase(s0.begin() + 1, s0.end()); + REQUIRE(i == s0.end()); + REQUIRE(s0 == map_t{{1,2}}); + } + { + map_t s0{{1,2},{2,3},{3,4}}; + REQUIRE(s0.erase(1) == 1); + REQUIRE(s0.erase(6) == 0); + REQUIRE(s0 == map_t{{2,3},{3,4}}); + } + { + map_t s0{{1,2},{2,3},{3,4}}; + map_t s1{{2,3},{3,4},{5,6}}; + s0.swap(s1); + REQUIRE(s0 == map_t{{2,3},{3,4},{5,6}}); + REQUIRE(s1 == map_t{{1,2},{2,3},{3,4}}); + swap(s1, s0); + REQUIRE(s0 == map_t{{1,2},{2,3},{3,4}}); + REQUIRE(s1 == map_t{{2,3},{3,4},{5,6}}); + } + } + SECTION("lookup") { + using map_t = flat_map; + { + map_t s0{{1,2},{2,3},{3,4},{4,5},{5,6}}; + REQUIRE(s0.count(3)); + REQUIRE_FALSE(s0.count(6)); + REQUIRE(my_as_const(s0).count(5)); + REQUIRE_FALSE(my_as_const(s0).count(0)); + } + { + map_t s0{{1,2},{2,3},{3,4},{4,5},{5,6}}; + REQUIRE(s0.find(2) == s0.begin() + 1); + REQUIRE(my_as_const(s0).find(3) == s0.cbegin() + 2); + REQUIRE(s0.find(6) == s0.end()); + REQUIRE(my_as_const(s0).find(0) == s0.cend()); + } + { + map_t s0{{1,2},{2,3},{3,4},{4,5},{5,6}}; + REQUIRE(s0.equal_range(3) == std::make_pair(s0.begin() + 2, s0.begin() + 3)); + REQUIRE(s0.equal_range(6) == std::make_pair(s0.end(), s0.end())); + REQUIRE(my_as_const(s0).equal_range(3) == std::make_pair(s0.cbegin() + 2, s0.cbegin() + 3)); + REQUIRE(my_as_const(s0).equal_range(0) == std::make_pair(s0.cbegin(), s0.cbegin())); + } + { + map_t s0{{0,1},{3,2},{6,3}}; + REQUIRE(s0.lower_bound(0) == s0.begin()); + REQUIRE(s0.lower_bound(1) == s0.begin() + 1); + REQUIRE(s0.lower_bound(10) == s0.end()); + REQUIRE(my_as_const(s0).lower_bound(-1) == s0.cbegin()); + REQUIRE(my_as_const(s0).lower_bound(7) == s0.cbegin() + 3); + } + } + SECTION("observers") { + struct my_less { + int i; + my_less(int i) : i(i) {} + bool operator()(int l, int r) const { + return l < r; + } + }; + using map_t = flat_map; + map_t s0(my_less(42)); + REQUIRE(my_as_const(s0).key_comp().i == 42); + REQUIRE(my_as_const(s0).value_comp()({2,50},{4,20})); + } + SECTION("operators") { + using map_t = flat_map; + + REQUIRE(map_t{{1,2},{3,4}} == map_t{{3,4},{1,2}}); + REQUIRE_FALSE(map_t{{1,2},{3,4}} == map_t{{2,4},{1,2}}); + REQUIRE_FALSE(map_t{{1,2},{3,4}} == map_t{{1,3},{1,2}}); + REQUIRE_FALSE(map_t{{1,2},{3,4}} == map_t{{3,4},{1,2},{0,0}}); + + REQUIRE_FALSE(map_t{{1,2},{3,4}} != map_t{{3,4},{1,2}}); + REQUIRE(map_t{{1,2},{3,4}} != map_t{{2,4},{1,2}}); + REQUIRE(map_t{{1,2},{3,4}} != map_t{{1,3},{1,2}}); + REQUIRE(map_t{{1,2},{3,4}} != map_t{{3,4},{1,2},{0,0}}); + + REQUIRE(map_t{{0,2},{3,4}} < map_t{{1,2},{3,4}}); + REQUIRE(map_t{{1,1},{3,4}} < map_t{{1,2},{3,4}}); + REQUIRE(map_t{{1,2},{3,4}} < map_t{{1,2},{3,4},{5,6}}); + + REQUIRE(map_t{{0,2},{3,4}} <= map_t{{1,2},{3,4}}); + REQUIRE(map_t{{1,1},{3,4}} <= map_t{{1,2},{3,4}}); + REQUIRE(map_t{{1,2},{3,4}} <= map_t{{1,2},{3,4},{5,6}}); + + REQUIRE(map_t{{1,2},{3,4}} > map_t{{0,2},{3,4}}); + REQUIRE(map_t{{1,2},{3,4}} > map_t{{1,1},{3,4}}); + REQUIRE(map_t{{1,2},{3,4},{5,6}} > map_t{{1,2},{3,4}}); + + REQUIRE(map_t{{1,2},{3,4}} >= map_t{{0,2},{3,4}}); + REQUIRE(map_t{{1,2},{3,4}} >= map_t{{1,1},{3,4}}); + REQUIRE(map_t{{1,2},{3,4},{5,6}} >= map_t{{1,2},{3,4}}); + + REQUIRE_FALSE(map_t{{1,2},{3,4}} < map_t{{1,2},{3,4}}); + REQUIRE(map_t{{1,2},{3,4}} <= map_t{{1,2},{3,4}}); + REQUIRE_FALSE(map_t{{1,2},{3,4}} > map_t{{1,2},{3,4}}); + REQUIRE(map_t{{1,2},{3,4}} >= map_t{{1,2},{3,4}}); + } } diff --git a/flat_set.hpp b/flat_set.hpp index f132a1a..f5652f0 100644 --- a/flat_set.hpp +++ b/flat_set.hpp @@ -9,8 +9,11 @@ #include #include #include +#include +#include #include #include +#include namespace flat_hpp { @@ -39,5 +42,250 @@ namespace flat_hpp using const_iterator = typename data_type::const_iterator; using reverse_iterator = typename data_type::reverse_iterator; using const_reverse_iterator = typename data_type::const_reverse_iterator; + + static_assert( + std::is_same::value, + "Allocator::value_type must be same type as value_type"); + public: + explicit flat_set( + const Allocator& a) + : data_(a) {} + + explicit flat_set( + const Compare& c = Compare(), + const Allocator& a = Allocator()) + : data_(a) + , compare_(c) {} + + template < typename InputIter > + flat_set( + InputIter first, + InputIter last, + const Allocator& a) + : data_(a) { + insert(first, last); + } + + template < typename InputIter > + flat_set( + InputIter first, + InputIter last, + const Compare& c = Compare(), + const Allocator& a = Allocator()) + : data_(a) + , compare_(c) { + insert(first, last); + } + + flat_set( + std::initializer_list ilist, + const Allocator& a) + : data_(a) { + insert(ilist); + } + + flat_set( + std::initializer_list ilist, + const Compare& c = Compare(), + const Allocator& a = Allocator()) + : data_(a) + , compare_(c) { + insert(ilist); + } + + iterator begin() noexcept { return data_.begin(); } + const_iterator begin() const noexcept { return data_.begin(); } + const_iterator cbegin() const noexcept { return data_.cbegin(); } + + iterator end() noexcept { return data_.end(); } + const_iterator end() const noexcept { return data_.end(); } + const_iterator cend() const noexcept { return data_.cend(); } + + reverse_iterator rbegin() noexcept { return data_.rbegin(); } + const_reverse_iterator rbegin() const noexcept { return data_.rbegin(); } + const_reverse_iterator crbegin() const noexcept { return data_.crbegin(); } + + reverse_iterator rend() noexcept { return data_.rend(); } + const_reverse_iterator rend() const noexcept { return data_.rend(); } + const_reverse_iterator crend() const noexcept { return data_.crend(); } + + bool empty() const noexcept { + return data_.empty(); + } + + size_type size() const noexcept { + return data_.size(); + } + + size_type max_size() const noexcept { + return data_.max_size(); + } + + std::pair insert(value_type&& value) { + const iterator iter = lower_bound(value); + return iter == end() || compare_(value, *iter) + ? std::make_pair(data_.insert(iter, std::move(value)), true) + : std::make_pair(iter, false); + } + + std::pair insert(const value_type& value) { + const iterator iter = lower_bound(value); + return iter == end() || compare_(value, *iter) + ? std::make_pair(data_.insert(iter, value), true) + : std::make_pair(iter, false); + } + + iterator insert(const_iterator hint, value_type&& value) { + return (hint == begin() || compare_(*(hint - 1), value)) + && (hint == end() || compare_(value, *hint)) + ? data_.insert(hint, std::move(value)) + : insert(std::move(value)).first; + } + + iterator insert(const_iterator hint, const value_type& value) { + return (hint == begin() || compare_(*(hint - 1), value)) + && (hint == end() || compare_(value, *hint)) + ? data_.insert(hint, value) + : insert(value).first; + } + + template < typename InputIter > + void insert(InputIter first, InputIter last) { + while ( first != last ) { + insert(*first++); + } + } + + void insert(std::initializer_list ilist) { + insert(ilist.begin(), ilist.end()); + } + + template < typename... Args > + std::pair emplace(Args&&... args) { + return insert(value_type(std::forward(args)...)); + } + + template < typename... Args > + iterator emplace_hint(const_iterator hint, Args&&... args) { + return insert(hint, value_type(std::forward(args)...)); + } + + void clear() noexcept { + data_.clear(); + } + + iterator erase(const_iterator iter) { + return data_.erase(iter); + } + + iterator erase(const_iterator first, const_iterator last) { + return data_.erase(first, last); + } + + size_type erase(const key_type& key) { + const iterator iter = find(key); + return iter != end() + ? (erase(iter), 1) + : 0; + } + + void swap(flat_set& other) { + using std::swap; + swap(data_, other.data_); + swap(compare_, other.compare_); + } + + size_type count(const key_type& key) const { + const const_iterator iter = find(key); + return iter != end() ? 1 : 0; + } + + iterator find(const key_type& key) { + const iterator iter = lower_bound(key); + return iter != end() && !compare_(key, *iter) + ? iter + : end(); + } + + const_iterator find(const key_type& key) const { + const const_iterator iter = lower_bound(key); + return iter != end() && !compare_(key, *iter) + ? iter + : end(); + } + + std::pair equal_range(const key_type& key) { + return std::equal_range(begin(), end(), key, compare_); + } + + std::pair equal_range(const key_type& key) const { + return std::equal_range(begin(), end(), key, compare_); + } + + iterator lower_bound(const key_type& key) { + return std::lower_bound(begin(), end(), key, compare_); + } + + const_iterator lower_bound(const key_type& key) const { + return std::lower_bound(begin(), end(), key, compare_); + } + + iterator upper_bound(const key_type& key) { + return std::upper_bound(begin(), end(), key, compare_); + } + + const_iterator upper_bound(const key_type& key) const { + return std::upper_bound(begin(), end(), key, compare_); + } + + key_compare key_comp() const { + return compare_; + } + + value_compare value_comp() const { + return value_compare(compare_); + } + private: + data_type data_; + key_compare compare_; }; } + +namespace flat_hpp +{ + template < typename K, typename C, typename A > + void swap(flat_set& l, flat_set& r) { + l.swap(r); + } + + template < typename K, typename C, typename A > + bool operator==(const flat_set& l, const flat_set& r) { + return l.size() == r.size() + && std::equal(l.begin(), l.end(), r.begin(), r.end()); + } + + template < typename K, typename C, typename A > + bool operator!=(const flat_set& l, const flat_set& r) { + return !(l == r); + } + + template < typename K, typename C, typename A > + bool operator<(const flat_set& l, const flat_set& r) { + return std::lexicographical_compare(l.begin(), l.end(), r.begin(), r.end()); + } + + template < typename K, typename C, typename A > + bool operator>(const flat_set& l, const flat_set& r) { + return r < l; + } + + template < typename K, typename C, typename A > + bool operator<=(const flat_set& l, const flat_set& r) { + return !(r < l); + } + + template < typename K, typename C, typename A > + bool operator>=(const flat_set& l, const flat_set& r) { + return !(l < r); + } +} diff --git a/flat_set_tests.cpp b/flat_set_tests.cpp index 284ee7a..d103f55 100644 --- a/flat_set_tests.cpp +++ b/flat_set_tests.cpp @@ -8,15 +8,50 @@ #include "catch.hpp" #include "flat_set.hpp" -namespace flat = flat_hpp; +using namespace flat_hpp; namespace { + template < typename T > + class dummy_allocator { + public: + using value_type = T; + + dummy_allocator() = default; + + template < typename U > + dummy_allocator(const dummy_allocator&) noexcept { + } + + T* allocate(std::size_t n) noexcept { + return static_cast(std::malloc(sizeof(T) * n)); + } + + void deallocate(T* p, std::size_t n) noexcept { + (void)n; + std::free(p); + } + }; + + template < typename T, typename U > + bool operator==(const dummy_allocator&, const dummy_allocator&) noexcept { + return true; + } + + template < typename T, typename U > + bool operator!=(const dummy_allocator& l, const dummy_allocator& r) noexcept { + return !(l == r); + } + + template < typename T > + constexpr std::add_const_t& my_as_const(T& t) noexcept { + return t; + } } TEST_CASE("flat_set") { - { - using set_t = flat::flat_set; + SECTION("types") { + using set_t = flat_set; static_assert( std::is_same::value, @@ -46,4 +81,251 @@ TEST_CASE("flat_set") { std::is_same::value, "unit test static error"); } + SECTION("ctors") { + using alloc_t = dummy_allocator; + using set_t = flat_set, alloc_t>; + using set2_t = flat_set, alloc_t>; + + { + auto s0 = set_t(); + auto s1 = set2_t(alloc_t()); + auto s2 = set_t(std::less()); + auto s3 = set2_t(std::greater(), alloc_t()); + } + + { + std::vector v{1,2,3}; + auto s0 = set_t(v.cbegin(), v.cend()); + auto s1 = set2_t(v.cbegin(), v.cend(), alloc_t()); + auto s2 = set_t(v.cbegin(), v.cend(), std::less()); + auto s3 = set2_t(v.cbegin(), v.cend(), std::greater(), alloc_t()); + + REQUIRE(std::vector(s0.begin(), s0.end()) == std::vector({1,2,3})); + REQUIRE(std::vector(s1.begin(), s1.end()) == std::vector({3,2,1})); + REQUIRE(std::vector(s2.begin(), s2.end()) == std::vector({1,2,3})); + REQUIRE(std::vector(s3.begin(), s3.end()) == std::vector({3,2,1})); + } + + { + auto s0 = set_t({0,1,2}); + auto s1 = set2_t({0,1,2}, alloc_t()); + auto s2 = set_t({0,1,2}, std::less()); + auto s3 = set2_t({0,1,2}, std::greater(), alloc_t()); + + REQUIRE(std::vector(s0.begin(), s0.end()) == std::vector({0,1,2})); + REQUIRE(std::vector(s1.begin(), s1.end()) == std::vector({2,1,0})); + REQUIRE(std::vector(s2.begin(), s2.end()) == std::vector({0,1,2})); + REQUIRE(std::vector(s3.begin(), s3.end()) == std::vector({2,1,0})); + } + } + SECTION("capacity") { + using set_t = flat_set; + set_t s0; + + REQUIRE(s0.empty()); + REQUIRE_FALSE(s0.size()); + REQUIRE(s0.max_size() == std::allocator().max_size()); + + s0.insert(42); + + REQUIRE_FALSE(s0.empty()); + REQUIRE(s0.size() == 1u); + REQUIRE(s0.max_size() == std::allocator().max_size()); + + s0.insert(42); + REQUIRE(s0.size() == 1u); + + s0.insert(84); + REQUIRE(s0.size() == 2u); + + s0.clear(); + + REQUIRE(s0.empty()); + REQUIRE_FALSE(s0.size()); + REQUIRE(s0.max_size() == std::allocator().max_size()); + } + SECTION("inserts") { + struct obj_t { + obj_t(int i) : i(i) {} + int i; + + bool operator<(const obj_t& o) const { + return i < o.i; + } + + bool operator==(const obj_t& o) const { + return i == o.i; + } + }; + + using set_t = flat_set; + + { + set_t s0; + + auto i0 = s0.insert(1); + REQUIRE(s0 == set_t{1}); + REQUIRE(i0 == std::make_pair(s0.begin(), true)); + + auto i1 = s0.insert(obj_t(1)); + REQUIRE(s0 == set_t{1}); + REQUIRE(i1 == std::make_pair(s0.begin(), false)); + + auto i2 = s0.insert(obj_t(2)); + REQUIRE(s0 == set_t{1,2}); + REQUIRE(i2 == std::make_pair(s0.begin() + 1, true)); + + auto o2 = obj_t(2); + auto i3 = s0.insert(o2); + REQUIRE(i3 == std::make_pair(s0.begin() + 1, false)); + + s0.insert(s0.cbegin(), 1); + s0.insert(s0.cbegin(), 2); + s0.insert(s0.cend(), 1); + s0.insert(s0.cend(), 2); + REQUIRE(s0 == set_t{1,2}); + + s0.insert(s0.cbegin(), 0); + REQUIRE(s0 == set_t{0,1,2}); + s0.insert(s0.cend(), 3); + REQUIRE(s0 == set_t{0,1,2,3}); + s0.insert(s0.cbegin(), 4); + s0.insert(s0.cend(), -1); + REQUIRE(s0 == set_t{-1,0,1,2,3,4}); + + s0.insert(s0.cbegin() + 2, obj_t(5)); + REQUIRE(s0 == set_t{-1,0,1,2,3,4,5}); + s0.insert(s0.cbegin(), obj_t(-2)); + REQUIRE(s0 == set_t{-2,-1,0,1,2,3,4,5}); + } + + { + set_t s0; + + auto e0 = s0.emplace(3); + REQUIRE(s0 == set_t{3}); + REQUIRE(e0 == std::make_pair(s0.begin(), true)); + + auto e1 = s0.emplace(obj_t(3)); + REQUIRE(e1 == std::make_pair(s0.begin(), false)); + + auto e2 = s0.emplace(4); + REQUIRE(s0 == set_t{3,4}); + REQUIRE(e2 == std::make_pair(s0.begin() + 1, true)); + + auto e3 = s0.emplace_hint(s0.cbegin(), 1); + REQUIRE(e3 == s0.begin()); + auto e4 = s0.emplace_hint(s0.cend(), 2); + REQUIRE(e4 == s0.begin() + 1); + s0.emplace_hint(s0.cbegin(), 5); + s0.emplace_hint(s0.cend(), 6); + REQUIRE(s0 == set_t{1,2,3,4,5,6}); + } + } + SECTION("erasers") { + using set_t = flat_set; + { + set_t s0{1,2,3,4,5}; + s0.clear(); + REQUIRE(s0.empty()); + } + { + set_t s0{1,2,3,4,5}; + auto i = s0.erase(s0.find(3)); + REQUIRE(i == s0.begin() + 2); + REQUIRE(s0 == set_t{1,2,4,5}); + } + { + set_t s0{1,2,3,4,5}; + auto i = s0.erase(s0.begin() + 2, s0.end()); + REQUIRE(i == s0.end()); + REQUIRE(s0 == set_t{1,2}); + } + { + set_t s0{1,2,3,4,5}; + REQUIRE(s0.erase(2) == 1); + REQUIRE(s0.erase(6) == 0); + REQUIRE(s0 == set_t{1,3,4,5}); + } + { + set_t s0{1,2,3}; + set_t s1{3,4,5}; + s0.swap(s1); + REQUIRE(s0 == set_t{3,4,5}); + REQUIRE(s1 == set_t{1,2,3}); + swap(s1, s0); + REQUIRE(s0 == set_t{1,2,3}); + REQUIRE(s1 == set_t{3,4,5}); + } + } + SECTION("lookup") { + using set_t = flat_set; + { + set_t s0{1,2,3,4,5}; + REQUIRE(s0.count(3)); + REQUIRE_FALSE(s0.count(6)); + REQUIRE(my_as_const(s0).count(5)); + REQUIRE_FALSE(my_as_const(s0).count(0)); + } + { + set_t s0{1,2,3,4,5}; + REQUIRE(s0.find(2) == s0.begin() + 1); + REQUIRE(my_as_const(s0).find(3) == s0.cbegin() + 2); + REQUIRE(s0.find(6) == s0.end()); + REQUIRE(my_as_const(s0).find(0) == s0.cend()); + } + { + set_t s0{1,2,3,4,5}; + REQUIRE(s0.equal_range(3) == std::make_pair(s0.begin() + 2, s0.begin() + 3)); + REQUIRE(s0.equal_range(6) == std::make_pair(s0.end(), s0.end())); + REQUIRE(my_as_const(s0).equal_range(3) == std::make_pair(s0.cbegin() + 2, s0.cbegin() + 3)); + REQUIRE(my_as_const(s0).equal_range(0) == std::make_pair(s0.cbegin(), s0.cbegin())); + } + { + set_t s0{0,3,6,9}; + REQUIRE(s0.lower_bound(0) == s0.begin()); + REQUIRE(s0.lower_bound(1) == s0.begin() + 1); + REQUIRE(s0.lower_bound(10) == s0.end()); + REQUIRE(my_as_const(s0).lower_bound(-1) == s0.cbegin()); + REQUIRE(my_as_const(s0).lower_bound(7) == s0.cbegin() + 3); + } + } + SECTION("observers") { + struct my_less { + int i; + my_less(int i) : i(i) {} + bool operator()(int l, int r) const { + return l < r; + } + }; + using set_t = flat_set; + set_t s0(my_less(42)); + REQUIRE(my_as_const(s0).key_comp().i == 42); + REQUIRE(my_as_const(s0).value_comp().i == 42); + } + SECTION("operators") { + using set_t = flat_set; + + REQUIRE(set_t{1,2,3} == set_t{3,2,1}); + REQUIRE_FALSE(set_t{1,2,3} == set_t{3,2,4}); + REQUIRE_FALSE(set_t{1,2,3} == set_t{1,2,3,4}); + + REQUIRE(set_t{1,2,3} != set_t{3,2,4}); + REQUIRE_FALSE(set_t{1,2,3} != set_t{3,2,1}); + + REQUIRE(set_t{2,3,4,6} < set_t{2,3,5}); + REQUIRE(set_t{2,3,4,6} <= set_t{2,3,5}); + REQUIRE_FALSE(set_t{2,3,5} < set_t{2,3,4,6}); + REQUIRE_FALSE(set_t{2,3,5} <= set_t{2,3,4,6}); + + REQUIRE_FALSE(set_t{2,3,4,6} > set_t{2,3,5}); + REQUIRE_FALSE(set_t{2,3,4,6} >= set_t{2,3,5}); + REQUIRE(set_t{2,3,5} > set_t{2,3,4,6}); + REQUIRE(set_t{2,3,5} >= set_t{2,3,4,6}); + + REQUIRE_FALSE(set_t{1,2,3} < set_t{1,2,3}); + REQUIRE(set_t{1,2,3} <= set_t{1,2,3}); + REQUIRE_FALSE(set_t{1,2,3} > set_t{1,2,3}); + REQUIRE(set_t{1,2,3} >= set_t{1,2,3}); + } }