diff options
| author | maxim-yurchuk <[email protected]> | 2024-10-09 12:29:46 +0300 |
|---|---|---|
| committer | maxim-yurchuk <[email protected]> | 2024-10-09 13:14:22 +0300 |
| commit | 9731d8a4bb7ee2cc8554eaf133bb85498a4c7d80 (patch) | |
| tree | a8fb3181d5947c0d78cf402aa56e686130179049 /contrib/python/pythran/pythran/pythonic/include/numpy/dot.hpp | |
| parent | a44b779cd359f06c3ebbef4ec98c6b38609d9d85 (diff) | |
publishFullContrib: true for ydb
<HIDDEN_URL>
commit_hash:c82a80ac4594723cebf2c7387dec9c60217f603e
Diffstat (limited to 'contrib/python/pythran/pythran/pythonic/include/numpy/dot.hpp')
| -rw-r--r-- | contrib/python/pythran/pythran/pythonic/include/numpy/dot.hpp | 284 |
1 files changed, 284 insertions, 0 deletions
diff --git a/contrib/python/pythran/pythran/pythonic/include/numpy/dot.hpp b/contrib/python/pythran/pythran/pythonic/include/numpy/dot.hpp new file mode 100644 index 00000000000..a8dc8089a82 --- /dev/null +++ b/contrib/python/pythran/pythran/pythonic/include/numpy/dot.hpp @@ -0,0 +1,284 @@ +#ifndef PYTHONIC_INCLUDE_NUMPY_DOT_HPP +#define PYTHONIC_INCLUDE_NUMPY_DOT_HPP + +#include "pythonic/include/types/ndarray.hpp" +#include "pythonic/include/numpy/sum.hpp" +#include "pythonic/include/types/numpy_expr.hpp" +#include "pythonic/include/types/traits.hpp" + +template <class T> +struct is_blas_type : pythonic::types::is_complex<T> { +}; + +template <> +struct is_blas_type<float> : std::true_type { +}; + +template <> +struct is_blas_type<double> : std::true_type { +}; + +template <class E> +struct is_strided { + template <class T> + static decltype(T::is_strided, std::true_type{}) get(T *); + static std::false_type get(...); + static constexpr bool value = decltype(get((E *)nullptr))::value; +}; + +template <class E> +struct is_blas_array { + // FIXME: also support gexpr with stride? + static constexpr bool value = + pythonic::types::is_array<E>::value && + is_blas_type<pythonic::types::dtype_of<E>>::value && + !is_strided<E>::value; +}; + +PYTHONIC_NS_BEGIN + +namespace numpy +{ + template <class E, class F> + typename std::enable_if<types::is_dtype<E>::value && + types::is_dtype<F>::value, + decltype(std::declval<E>() * std::declval<F>())>::type + dot(E const &e, F const &f); + + /// Vector / Vector multiplication + template <class E, class F> + typename std::enable_if< + types::is_numexpr_arg<E>::value && types::is_numexpr_arg<F>::value && + E::value == 1 && F::value == 1 && + (!is_blas_array<E>::value || !is_blas_array<F>::value || + !std::is_same<typename E::dtype, typename F::dtype>::value), + typename __combined<typename E::dtype, typename F::dtype>::type>::type + dot(E const &e, F const &f); + + template <class E, class F> + typename std::enable_if<E::value == 1 && F::value == 1 && + std::is_same<typename E::dtype, float>::value && + std::is_same<typename F::dtype, float>::value && + is_blas_array<E>::value && + is_blas_array<F>::value, + float>::type + dot(E const &e, F const &f); + + template <class E, class F> + typename std::enable_if<E::value == 1 && F::value == 1 && + std::is_same<typename E::dtype, double>::value && + std::is_same<typename F::dtype, double>::value && + is_blas_array<E>::value && + is_blas_array<F>::value, + double>::type + dot(E const &e, F const &f); + + template <class E, class F> + typename std::enable_if< + E::value == 1 && F::value == 1 && + std::is_same<typename E::dtype, std::complex<float>>::value && + std::is_same<typename F::dtype, std::complex<float>>::value && + is_blas_array<E>::value && is_blas_array<F>::value, + std::complex<float>>::type + dot(E const &e, F const &f); + + template <class E, class F> + typename std::enable_if< + E::value == 1 && F::value == 1 && + std::is_same<typename E::dtype, std::complex<double>>::value && + std::is_same<typename F::dtype, std::complex<double>>::value && + is_blas_array<E>::value && is_blas_array<F>::value, + std::complex<double>>::type + dot(E const &e, F const &f); + + /// Matrix / Vector multiplication + + // We transpose the matrix to reflect our C order + template <class E, class pS0, class pS1> + typename std::enable_if<is_blas_type<E>::value && + std::tuple_size<pS0>::value == 2 && + std::tuple_size<pS1>::value == 1, + types::ndarray<E, types::pshape<long>>>::type + dot(types::ndarray<E, pS0> const &f, types::ndarray<E, pS1> const &e); + + template <class E, class pS0, class pS1> + typename std::enable_if<is_blas_type<E>::value && + std::tuple_size<pS0>::value == 2 && + std::tuple_size<pS1>::value == 1, + types::ndarray<E, types::pshape<long>>>::type + dot(types::numpy_texpr<types::ndarray<E, pS0>> const &f, + types::ndarray<E, pS1> const &e); + + // The trick is to not transpose the matrix so that MV become VM + template <class E, class pS0, class pS1> + typename std::enable_if<is_blas_type<E>::value && + std::tuple_size<pS0>::value == 1 && + std::tuple_size<pS1>::value == 2, + types::ndarray<E, types::pshape<long>>>::type + dot(types::ndarray<E, pS0> const &e, types::ndarray<E, pS1> const &f); + + template <class E, class pS0, class pS1> + typename std::enable_if<is_blas_type<E>::value && + std::tuple_size<pS0>::value == 1 && + std::tuple_size<pS1>::value == 2, + types::ndarray<E, types::pshape<long>>>::type + dot(types::ndarray<E, pS0> const &e, + types::numpy_texpr<types::ndarray<E, pS1>> const &f); + + // If arguments could be use with blas, we evaluate them as we need pointer + // on array for blas + template <class E, class F> + typename std::enable_if< + types::is_numexpr_arg<E>::value && + types::is_numexpr_arg<F>::value // It is an array_like + && (!(types::is_ndarray<E>::value && types::is_ndarray<F>::value) || + !std::is_same<typename E::dtype, typename F::dtype>::value) && + is_blas_type<typename E::dtype>::value && + is_blas_type<typename F::dtype>::value // With dtype compatible with + // blas + && + E::value == 2 && F::value == 1, // And it is matrix / vect + types::ndarray< + typename __combined<typename E::dtype, typename F::dtype>::type, + types::pshape<long>>>::type + dot(E const &e, F const &f); + + // If arguments could be use with blas, we evaluate them as we need pointer + // on array for blas + template <class E, class F> + typename std::enable_if< + types::is_numexpr_arg<E>::value && + types::is_numexpr_arg<F>::value // It is an array_like + && (!(types::is_ndarray<E>::value && types::is_ndarray<F>::value) || + !std::is_same<typename E::dtype, typename F::dtype>::value) && + is_blas_type<typename E::dtype>::value && + is_blas_type<typename F::dtype>::value // With dtype compatible with + // blas + && + E::value == 1 && F::value == 2, // And it is vect / matrix + types::ndarray< + typename __combined<typename E::dtype, typename F::dtype>::type, + types::pshape<long>>>::type + dot(E const &e, F const &f); + + // If one of the arg doesn't have a "blas compatible type", we use a slow + // matrix vector multiplication. + template <class E, class F> + typename std::enable_if< + (!is_blas_type<typename E::dtype>::value || + !is_blas_type<typename F::dtype>::value) && + E::value == 1 && F::value == 2, // And it is vect / matrix + types::ndarray< + typename __combined<typename E::dtype, typename F::dtype>::type, + types::pshape<long>>>::type + dot(E const &e, F const &f); + + // If one of the arg doesn't have a "blas compatible type", we use a slow + // matrix vector multiplication. + template <class E, class F> + typename std::enable_if< + (!is_blas_type<typename E::dtype>::value || + !is_blas_type<typename F::dtype>::value) && + E::value == 2 && F::value == 1, // And it is vect / matrix + types::ndarray< + typename __combined<typename E::dtype, typename F::dtype>::type, + types::pshape<long>>>::type + dot(E const &e, F const &f); + + /// Matrix / Matrix multiplication + + // The trick is to use the transpose arguments to reflect C order. + // We want to perform A * B in C order but blas order is F order. + // So we compute B'A' == (AB)'. As this equality is perform with F order + // We doesn't have to return a texpr because we want a C order matrice!! + template <class E, class pS0, class pS1> + typename std::enable_if<is_blas_type<E>::value && + std::tuple_size<pS0>::value == 2 && + std::tuple_size<pS1>::value == 2, + types::ndarray<E, types::array<long, 2>>>::type + dot(types::ndarray<E, pS0> const &a, types::ndarray<E, pS1> const &b); + + template <class E, class pS0, class pS1, class pS2> + typename std::enable_if< + is_blas_type<E>::value && std::tuple_size<pS0>::value == 2 && + std::tuple_size<pS1>::value == 2 && std::tuple_size<pS2>::value == 2, + types::ndarray<E, pS2>>::type & + dot(types::ndarray<E, pS0> const &a, types::ndarray<E, pS1> const &b, + types::ndarray<E, pS2> &c); + + // texpr variants: MT, TM, TT + template <class E, class pS0, class pS1> + typename std::enable_if<is_blas_type<E>::value && + std::tuple_size<pS0>::value == 2 && + std::tuple_size<pS1>::value == 2, + types::ndarray<E, types::array<long, 2>>>::type + dot(types::numpy_texpr<types::ndarray<E, pS0>> const &a, + types::ndarray<E, pS1> const &b); + template <class E, class pS0, class pS1> + typename std::enable_if<is_blas_type<E>::value && + std::tuple_size<pS0>::value == 2 && + std::tuple_size<pS1>::value == 2, + types::ndarray<E, types::array<long, 2>>>::type + dot(types::ndarray<E, pS0> const &a, + types::numpy_texpr<types::ndarray<E, pS1>> const &b); + template <class E, class pS0, class pS1> + typename std::enable_if<is_blas_type<E>::value && + std::tuple_size<pS0>::value == 2 && + std::tuple_size<pS1>::value == 2, + types::ndarray<E, types::array<long, 2>>>::type + dot(types::numpy_texpr<types::ndarray<E, pS0>> const &a, + types::numpy_texpr<types::ndarray<E, pS1>> const &b); + + // If arguments could be use with blas, we evaluate them as we need pointer + // on array for blas + template <class E, class F> + typename std::enable_if< + types::is_numexpr_arg<E>::value && + types::is_numexpr_arg<F>::value // It is an array_like + && (!(types::is_ndarray<E>::value && types::is_ndarray<F>::value) || + !std::is_same<typename E::dtype, typename F::dtype>::value) && + is_blas_type<typename E::dtype>::value && + is_blas_type<typename F::dtype>::value // With dtype compatible with + // blas + && + E::value == 2 && F::value == 2, // And both are matrix + types::ndarray< + typename __combined<typename E::dtype, typename F::dtype>::type, + types::array<long, 2>>>::type + dot(E const &e, F const &f); + + // If one of the arg doesn't have a "blas compatible type", we use a slow + // matrix multiplication. + template <class E, class F> + typename std::enable_if< + (!is_blas_type<typename E::dtype>::value || + !is_blas_type<typename F::dtype>::value) && + E::value == 2 && F::value == 2, // And it is matrix / matrix + types::ndarray< + typename __combined<typename E::dtype, typename F::dtype>::type, + types::array<long, 2>>>::type + dot(E const &e, F const &f); + + // N x M where N >= 3 and M == 1 + template <class E, class F> + typename std::enable_if< + (E::value >= 3 && F::value == 1), + types::ndarray< + typename __combined<typename E::dtype, typename F::dtype>::type, + types::array<long, E::value - 1>>>::type + dot(E const &e, F const &f); + + // N x M where N >= 3 and M >= 2 + template <class E, class F> + typename std::enable_if< + (E::value >= 3 && F::value >= 2), + types::ndarray< + typename __combined<typename E::dtype, typename F::dtype>::type, + types::array<long, E::value - 1>>>::type + dot(E const &e, F const &f); + + DEFINE_FUNCTOR(pythonic::numpy, dot); +} +PYTHONIC_NS_END + +#endif |
