diff --git a/headers/flat.hpp/flat_map.hpp b/headers/flat.hpp/flat_map.hpp index c2c3f66..a6ddb38 100644 --- a/headers/flat.hpp/flat_map.hpp +++ b/headers/flat.hpp/flat_map.hpp @@ -265,6 +265,30 @@ namespace flat_hpp throw std::out_of_range("flat_map::at: key not found"); } + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + mapped_type&> + at(const K& key) { + const iterator iter = find(key); + if ( iter != end() ) { + return iter->second; + } + throw std::out_of_range("flat_map::at: key not found"); + } + + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + const mapped_type&> + at(const K& key) const { + const const_iterator iter = find(key); + if ( iter != end() ) { + return iter->second; + } + throw std::out_of_range("flat_map::at: key not found"); + } + std::pair insert(value_type&& value) { const iterator iter = lower_bound(value.first); return iter == end() || this->operator()(value, *iter) @@ -328,7 +352,18 @@ namespace flat_hpp } size_type erase(const key_type& key) { - const iterator iter = find(key); + const const_iterator iter = find(key); + return iter != end() + ? (erase(iter), 1) + : 0; + } + + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + size_type> + erase(const K& key) { + const const_iterator iter = find(key); return iter != end() ? (erase(iter), 1) : 0; @@ -350,6 +385,15 @@ namespace flat_hpp return iter != end() ? 1 : 0; } + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + size_type> + count(const K& key) const { + const const_iterator iter = find(key); + return iter != end() ? 1 : 0; + } + iterator find(const key_type& key) { const iterator iter = lower_bound(key); return iter != end() && !this->operator()(key, *iter) @@ -396,6 +440,24 @@ namespace flat_hpp return std::equal_range(begin(), end(), key, comp); } + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + std::pair> + equal_range(const K& key) { + const base_type& comp = *this; + return std::equal_range(begin(), end(), key, comp); + } + + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + std::pair> + equal_range(const K& key) const { + const base_type& comp = *this; + return std::equal_range(begin(), end(), key, comp); + } + iterator lower_bound(const key_type& key) { const base_type& comp = *this; return std::lower_bound(begin(), end(), key, comp); diff --git a/headers/flat.hpp/flat_multimap.hpp b/headers/flat.hpp/flat_multimap.hpp index 3cd2d5a..aeccbbf 100644 --- a/headers/flat.hpp/flat_multimap.hpp +++ b/headers/flat.hpp/flat_multimap.hpp @@ -265,6 +265,30 @@ namespace flat_hpp throw std::out_of_range("flat_multimap::at: key not found"); } + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + mapped_type&> + at(const K& key) { + const iterator iter = find(key); + if ( iter != end() ) { + return iter->second; + } + throw std::out_of_range("flat_multimap::at: key not found"); + } + + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + const mapped_type&> + at(const K& key) const { + const const_iterator iter = find(key); + if ( iter != end() ) { + return iter->second; + } + throw std::out_of_range("flat_multimap::at: key not found"); + } + iterator insert(value_type&& value) { const iterator iter = upper_bound(value.first); return data_.insert(iter, std::move(value)); @@ -330,6 +354,17 @@ namespace flat_hpp return r; } + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + size_type> + erase(const K& key) { + const auto p = equal_range(key); + size_type r = std::distance(p.first, p.second); + erase(p.first, p.second); + return r; + } + void swap(flat_multimap& other) noexcept(std::is_nothrow_swappable_v && std::is_nothrow_swappable_v) @@ -346,6 +381,15 @@ namespace flat_hpp return std::distance(p.first, p.second); } + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + size_type> + count(const K& key) const { + const auto p = equal_range(key); + return std::distance(p.first, p.second); + } + iterator find(const key_type& key) { const iterator iter = lower_bound(key); return iter != end() && !this->operator()(key, *iter) @@ -392,6 +436,24 @@ namespace flat_hpp return std::equal_range(begin(), end(), key, comp); } + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + std::pair> + equal_range(const K& key) { + const base_type& comp = *this; + return std::equal_range(begin(), end(), key, comp); + } + + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + std::pair> + equal_range(const K& key) const { + const base_type& comp = *this; + return std::equal_range(begin(), end(), key, comp); + } + iterator lower_bound(const key_type& key) { const base_type& comp = *this; return std::lower_bound(begin(), end(), key, comp); diff --git a/headers/flat.hpp/flat_multiset.hpp b/headers/flat.hpp/flat_multiset.hpp index d361844..6fdd395 100644 --- a/headers/flat.hpp/flat_multiset.hpp +++ b/headers/flat.hpp/flat_multiset.hpp @@ -281,6 +281,17 @@ namespace flat_hpp return r; } + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + size_type> + erase(const K& key) { + const auto p = equal_range(key); + size_type r = std::distance(p.first, p.second); + erase(p.first, p.second); + return r; + } + void swap(flat_multiset& other) noexcept(std::is_nothrow_swappable_v && std::is_nothrow_swappable_v) @@ -297,6 +308,15 @@ namespace flat_hpp return std::distance(p.first, p.second); } + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + size_type> + count(const K& key) const { + const auto p = equal_range(key); + return std::distance(p.first, p.second); + } + iterator find(const key_type& key) { const iterator iter = lower_bound(key); return iter != end() && !this->operator()(key, *iter) @@ -341,6 +361,22 @@ namespace flat_hpp return std::equal_range(begin(), end(), key, key_comp()); } + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + std::pair> + equal_range(const K& key) { + return std::equal_range(begin(), end(), key, key_comp()); + } + + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + std::pair> + equal_range(const K& key) const { + return std::equal_range(begin(), end(), key, key_comp()); + } + iterator lower_bound(const key_type& key) { return std::lower_bound(begin(), end(), key, key_comp()); } diff --git a/headers/flat.hpp/flat_set.hpp b/headers/flat.hpp/flat_set.hpp index 3cdb994..21e06ab 100644 --- a/headers/flat.hpp/flat_set.hpp +++ b/headers/flat.hpp/flat_set.hpp @@ -279,7 +279,18 @@ namespace flat_hpp } size_type erase(const key_type& key) { - const iterator iter = find(key); + const const_iterator iter = find(key); + return iter != end() + ? (erase(iter), 1) + : 0; + } + + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + size_type> + erase(const K& key) { + const const_iterator iter = find(key); return iter != end() ? (erase(iter), 1) : 0; @@ -301,6 +312,15 @@ namespace flat_hpp return iter != end() ? 1 : 0; } + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + size_type> + count(const K& key) const { + const const_iterator iter = find(key); + return iter != end() ? 1 : 0; + } + iterator find(const key_type& key) { const iterator iter = lower_bound(key); return iter != end() && !this->operator()(key, *iter) @@ -345,6 +365,22 @@ namespace flat_hpp return std::equal_range(begin(), end(), key, key_comp()); } + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + std::pair> + equal_range(const K& key) { + return std::equal_range(begin(), end(), key, key_comp()); + } + + template < typename K > + std::enable_if_t< + detail::is_transparent_v, + std::pair> + equal_range(const K& key) const { + return std::equal_range(begin(), end(), key, key_comp()); + } + iterator lower_bound(const key_type& key) { return std::lower_bound(begin(), end(), key, key_comp()); } diff --git a/untests/flat_map_tests.cpp b/untests/flat_map_tests.cpp index 935b90b..45fbd38 100644 --- a/untests/flat_map_tests.cpp +++ b/untests/flat_map_tests.cpp @@ -410,13 +410,23 @@ TEST_CASE("flat_map") { REQUIRE(my_as_const(s0).lower_bound(-1) == s0.cbegin()); REQUIRE(my_as_const(s0).lower_bound(7) == s0.cbegin() + 3); } - { - flat_map> s0{{"hello", 42}, {"world", 84}}; - REQUIRE(s0.find(std::string_view("hello")) == s0.begin()); - REQUIRE(my_as_const(s0).find(std::string_view("world")) == s0.begin() + 1); - REQUIRE(s0.find(std::string_view("42")) == s0.end()); - REQUIRE(my_as_const(s0).find(std::string_view("42")) == s0.cend()); - } + } + SECTION("heterogeneous") { + flat_map> s0{{"hello", 42}, {"world", 84}}; + REQUIRE(s0.find(std::string_view("hello")) == s0.begin()); + REQUIRE(my_as_const(s0).find(std::string_view("world")) == s0.begin() + 1); + REQUIRE(s0.find(std::string_view("42")) == s0.end()); + REQUIRE(my_as_const(s0).find(std::string_view("42")) == s0.cend()); + + REQUIRE(my_as_const(s0).count(std::string_view("hello")) == 1); + REQUIRE(my_as_const(s0).count(std::string_view("hello_42")) == 0); + + REQUIRE(s0.upper_bound(std::string_view("hello")) == s0.begin() + 1); + REQUIRE(my_as_const(s0).upper_bound(std::string_view("hello")) == s0.begin() + 1); + + REQUIRE(s0.erase(std::string_view("hello")) == 1); + REQUIRE(s0.at(std::string_view("world")) == 84); + REQUIRE(my_as_const(s0).at(std::string_view("world")) == 84); } SECTION("observers") { struct my_less { diff --git a/untests/flat_multimap_tests.cpp b/untests/flat_multimap_tests.cpp index cd97dfa..0c9953a 100644 --- a/untests/flat_multimap_tests.cpp +++ b/untests/flat_multimap_tests.cpp @@ -412,13 +412,23 @@ TEST_CASE("flat_multimap") { REQUIRE(my_as_const(s0).lower_bound(-1) == s0.cbegin()); REQUIRE(my_as_const(s0).lower_bound(7) == s0.cbegin() + 4); } - { - flat_multimap> s0{{"hello", 42}, {"world", 84}}; - REQUIRE(s0.find(std::string_view("hello")) == s0.begin()); - REQUIRE(my_as_const(s0).find(std::string_view("world")) == s0.begin() + 1); - REQUIRE(s0.find(std::string_view("42")) == s0.end()); - REQUIRE(my_as_const(s0).find(std::string_view("42")) == s0.cend()); - } + } + SECTION("heterogeneous") { + flat_multimap> s0{{"hello", 42}, {"world", 84}}; + REQUIRE(s0.find(std::string_view("hello")) == s0.begin()); + REQUIRE(my_as_const(s0).find(std::string_view("world")) == s0.begin() + 1); + REQUIRE(s0.find(std::string_view("42")) == s0.end()); + REQUIRE(my_as_const(s0).find(std::string_view("42")) == s0.cend()); + + REQUIRE(my_as_const(s0).count(std::string_view("hello")) == 1); + REQUIRE(my_as_const(s0).count(std::string_view("hello_42")) == 0); + + REQUIRE(s0.upper_bound(std::string_view("hello")) == s0.begin() + 1); + REQUIRE(my_as_const(s0).upper_bound(std::string_view("hello")) == s0.begin() + 1); + + REQUIRE(s0.erase(std::string_view("hello")) == 1); + REQUIRE(s0.at(std::string_view("world")) == 84); + REQUIRE(my_as_const(s0).at(std::string_view("world")) == 84); } SECTION("observers") { struct my_less { diff --git a/untests/flat_multiset_tests.cpp b/untests/flat_multiset_tests.cpp index a5ba1fb..9e7833f 100644 --- a/untests/flat_multiset_tests.cpp +++ b/untests/flat_multiset_tests.cpp @@ -388,13 +388,21 @@ TEST_CASE("flat_multiset") { REQUIRE(my_as_const(s0).lower_bound(-1) == s0.cbegin()); REQUIRE(my_as_const(s0).lower_bound(7) == s0.cbegin() + 4); } - { - flat_multiset> s0{"hello", "world"}; - REQUIRE(s0.find(std::string_view("hello")) == s0.begin()); - REQUIRE(my_as_const(s0).find(std::string_view("world")) == s0.begin() + 1); - REQUIRE(s0.find(std::string_view("42")) == s0.end()); - REQUIRE(my_as_const(s0).find(std::string_view("42")) == s0.cend()); - } + } + SECTION("heterogeneous") { + flat_multiset> s0{"hello", "world"}; + REQUIRE(s0.find(std::string_view("hello")) == s0.begin()); + REQUIRE(my_as_const(s0).find(std::string_view("world")) == s0.begin() + 1); + REQUIRE(s0.find(std::string_view("42")) == s0.end()); + REQUIRE(my_as_const(s0).find(std::string_view("42")) == s0.cend()); + + REQUIRE(my_as_const(s0).count(std::string_view("hello")) == 1); + REQUIRE(my_as_const(s0).count(std::string_view("hello_42")) == 0); + + REQUIRE(s0.upper_bound(std::string_view("hello")) == s0.begin() + 1); + REQUIRE(my_as_const(s0).upper_bound(std::string_view("hello")) == s0.begin() + 1); + + REQUIRE(s0.erase(std::string_view("hello")) == 1); } SECTION("observers") { struct my_less { diff --git a/untests/flat_set_tests.cpp b/untests/flat_set_tests.cpp index 60f2c52..83dbd23 100644 --- a/untests/flat_set_tests.cpp +++ b/untests/flat_set_tests.cpp @@ -386,13 +386,21 @@ TEST_CASE("flat_set") { REQUIRE(my_as_const(s0).lower_bound(-1) == s0.cbegin()); REQUIRE(my_as_const(s0).lower_bound(7) == s0.cbegin() + 3); } - { - flat_set> s0{"hello", "world"}; - REQUIRE(s0.find(std::string_view("hello")) == s0.begin()); - REQUIRE(my_as_const(s0).find(std::string_view("world")) == s0.begin() + 1); - REQUIRE(s0.find(std::string_view("42")) == s0.end()); - REQUIRE(my_as_const(s0).find(std::string_view("42")) == s0.cend()); - } + } + SECTION("heterogeneous") { + flat_set> s0{"hello", "world"}; + REQUIRE(s0.find(std::string_view("hello")) == s0.begin()); + REQUIRE(my_as_const(s0).find(std::string_view("world")) == s0.begin() + 1); + REQUIRE(s0.find(std::string_view("42")) == s0.end()); + REQUIRE(my_as_const(s0).find(std::string_view("42")) == s0.cend()); + + REQUIRE(my_as_const(s0).count(std::string_view("hello")) == 1); + REQUIRE(my_as_const(s0).count(std::string_view("hello_42")) == 0); + + REQUIRE(s0.upper_bound(std::string_view("hello")) == s0.begin() + 1); + REQUIRE(my_as_const(s0).upper_bound(std::string_view("hello")) == s0.begin() + 1); + + REQUIRE(s0.erase(std::string_view("hello")) == 1); } SECTION("observers") { struct my_less {