diff --git a/README.md b/README.md index 53a0e31..633a92d 100644 --- a/README.md +++ b/README.md @@ -1296,8 +1296,9 @@ vec rsqrt(const vec& xs); #### Scalar ```cpp -template < arithmetic T > -T dot(T x, T y); +template < arithmetic T, arithmetic U + , arithmetic V = decltype(declval() * declval()) > +V dot(T x, U y); template < arithmetic T > T length(T x); @@ -1333,8 +1334,9 @@ T refract(T i, T n, T eta); #### Vector ```cpp -template < typename T, size_t Size > -T dot(const vec& xs, const vec& ys); +template < typename T, typename U, size_t Size + , typename V = decltype(declval() * declval()) > +V dot(const vec& xs, const vec& ys); template < typename T, size_t Size > T length(const vec& xs); @@ -1354,11 +1356,13 @@ T distance(const vec& xs, const vec& ys); template < typename T, size_t Size > T distance2(const vec& xs, const vec& ys); -template < typename T > -T cross(const vec& xs, const vec& ys); +template < typename T, typename U + , typename V = decltype(declval() * declval()) > +V cross(const vec& xs, const vec& ys); -template < typename T > -vec cross(const vec& xs, const vec& ys); +template < typename T, typename U + , typename V = decltype(declval() * declval()) > +vec cross(const vec& xs, const vec& ys); template < typename T, size_t Size > vec normalize(const vec& xs); @@ -1376,8 +1380,11 @@ vec refract(const vec& i, const vec& n, T eta); #### Quaternion ```cpp -template < typename T > -T dot(const qua& xs, const qua& ys); +template < typename T, typename U + , typename V = decltype(dot( + declval>(), + declval>())) > +V dot(const qua& xs, const qua& ys); template < typename T > T length(const qua& xs); diff --git a/headers/vmath.hpp/vmath_fun.hpp b/headers/vmath.hpp/vmath_fun.hpp index 8da26fc..f155291 100644 --- a/headers/vmath.hpp/vmath_fun.hpp +++ b/headers/vmath.hpp/vmath_fun.hpp @@ -305,10 +305,11 @@ namespace vmath_hpp namespace vmath_hpp { - template < typename T > - [[nodiscard]] std::enable_if_t, T> - constexpr dot(T x, T y) noexcept { - return x * y; + template < typename T, typename U + , typename V = decltype(std::declval() * std::declval()) > + [[nodiscard]] std::enable_if_t, V> + constexpr dot(T x, U y) noexcept { + return { x * y }; } template < typename T > diff --git a/headers/vmath.hpp/vmath_qua_fun.hpp b/headers/vmath.hpp/vmath_qua_fun.hpp index efb03ea..3b3ef2b 100644 --- a/headers/vmath.hpp/vmath_qua_fun.hpp +++ b/headers/vmath.hpp/vmath_qua_fun.hpp @@ -249,8 +249,11 @@ namespace vmath_hpp namespace vmath_hpp { - template < typename T > - [[nodiscard]] constexpr T dot(const qua& xs, const qua& ys) { + template < typename T, typename U + , typename V = decltype(dot( + std::declval>(), + std::declval>())) > + [[nodiscard]] constexpr V dot(const qua& xs, const qua& ys) { return dot(vec{xs}, vec{ys}); } diff --git a/headers/vmath.hpp/vmath_vec_fun.hpp b/headers/vmath.hpp/vmath_vec_fun.hpp index 69c3c6e..6a194f5 100644 --- a/headers/vmath.hpp/vmath_vec_fun.hpp +++ b/headers/vmath.hpp/vmath_vec_fun.hpp @@ -775,11 +775,12 @@ namespace vmath_hpp namespace vmath_hpp { - template < typename T, std::size_t Size > - [[nodiscard]] constexpr T dot(const vec& xs, const vec& ys) { - return fold_join([](T acc, T x, T y){ + template < typename T, typename U, std::size_t Size + , typename V = decltype(std::declval() * std::declval()) > + [[nodiscard]] constexpr V dot(const vec& xs, const vec& ys) { + return fold_join([](V acc, T x, U y){ return acc + (x * y); - }, T{0}, xs, ys); + }, V{0}, xs, ys); } template < typename T, std::size_t Size > @@ -812,13 +813,15 @@ namespace vmath_hpp return length2(ys - xs); } - template < typename T > - [[nodiscard]] constexpr T cross(const vec& xs, const vec& ys) { - return xs.x * ys.y - xs.y * ys.x; + template < typename T, typename U + , typename V = decltype(std::declval() * std::declval()) > + [[nodiscard]] constexpr V cross(const vec& xs, const vec& ys) { + return { xs.x * ys.y - xs.y * ys.x }; } - template < typename T > - [[nodiscard]] constexpr vec cross(const vec& xs, const vec& ys) { + template < typename T, typename U + , typename V = decltype(std::declval() * std::declval()) > + [[nodiscard]] constexpr vec cross(const vec& xs, const vec& ys) { return { xs.y * ys.z - xs.z * ys.y, xs.z * ys.x - xs.x * ys.z, diff --git a/untests/vmath_fun_tests.cpp b/untests/vmath_fun_tests.cpp index 9903781..4bb37e9 100644 --- a/untests/vmath_fun_tests.cpp +++ b/untests/vmath_fun_tests.cpp @@ -126,7 +126,9 @@ TEST_CASE("vmath/fun") { STATIC_CHECK(distance2(5.f, 10.f) == uapprox(25.f)); STATIC_CHECK(distance2(-5.f, -10.f) == uapprox(25.f)); - STATIC_CHECK(dot(2.f, 5.f) == uapprox(10.f)); + STATIC_CHECK(dot(2, 5) == uapprox(10)); + STATIC_CHECK(dot(2, 5.f) == uapprox(10.f)); + STATIC_CHECK(normalize(0.5f) == uapprox(1.f)); STATIC_CHECK(faceforward(1.f, 2.f, 3.f) == uapprox(-1.f)); diff --git a/untests/vmath_qua_fun_tests.cpp b/untests/vmath_qua_fun_tests.cpp index 80bb33e..b65d8e9 100644 --- a/untests/vmath_qua_fun_tests.cpp +++ b/untests/vmath_qua_fun_tests.cpp @@ -132,6 +132,7 @@ TEST_CASE("vmath/qua_fun") { SUBCASE("Geometric Functions") { STATIC_CHECK(dot(qua(1,2,3,4),qua(3,4,5,6)) == 50); + STATIC_CHECK(dot(qfloat(1,2,3,4),qdouble(3,4,5,6)) == uapprox(50.0)); CHECK(length(qfloat(10.f,0.f,0.f,0.f)) == uapprox(10.f)); CHECK(length(qfloat(-10.f,0.f,0.f,0.f)) == uapprox(10.f)); diff --git a/untests/vmath_vec_fun_tests.cpp b/untests/vmath_vec_fun_tests.cpp index 262f1ec..53df1de 100644 --- a/untests/vmath_vec_fun_tests.cpp +++ b/untests/vmath_vec_fun_tests.cpp @@ -271,8 +271,17 @@ TEST_CASE("vmath/vec_fun") { STATIC_CHECK(distance2(float2(-5.f,0.f), float2(-10.f,0.f)) == uapprox(25.f)); STATIC_CHECK(dot(int2(1,2),int2(3,4)) == 11); + STATIC_CHECK(dot(int2(1,2),float2(3,4)) == uapprox(11.f)); + STATIC_CHECK(dot(float2(3,4),int2(1,2)) == uapprox(11.f)); + STATIC_CHECK(cross(int2(1,0),int2(0,1)) == 1); + STATIC_CHECK(cross(int2(1,0),float2(0,1)) == uapprox(1.f)); + STATIC_CHECK(cross(float2(0,1),int2(1,0)) == uapprox(-1.f)); + STATIC_CHECK(cross(int3(1,0,0),int3(0,1,0)) == int3(0,0,1)); + STATIC_CHECK(cross(int3(1,0,0),float3(0,1,0)) == uapprox3(0.f,0.f,1.f)); + STATIC_CHECK(cross(float3(0,1,0),int3(1,0,0)) == uapprox3(0.f,0.f,-1.f)); + CHECK(normalize(float2(0.5f,0.f)).x == uapprox(1.f)); STATIC_CHECK(faceforward(float2(1.f), float2(2.f), float2(3.f)).x == uapprox(-1.f));