summaryrefslogtreecommitdiffstats
path: root/contrib/python/pythran/pythran/pythonic/include/numpy/dot.hpp
diff options
context:
space:
mode:
author maxim-yurchuk <[email protected]> 2024-10-09 12:29:46 +0300
committer maxim-yurchuk <[email protected]> 2024-10-09 13:14:22 +0300
commit 9731d8a4bb7ee2cc8554eaf133bb85498a4c7d80 (patch)
tree a8fb3181d5947c0d78cf402aa56e686130179049 /contrib/python/pythran/pythran/pythonic/include/numpy/dot.hpp
parent a44b779cd359f06c3ebbef4ec98c6b38609d9d85 (diff)
publishFullContrib: true for ydb
<HIDDEN_URL> commit_hash:c82a80ac4594723cebf2c7387dec9c60217f603e
Diffstat (limited to 'contrib/python/pythran/pythran/pythonic/include/numpy/dot.hpp')
-rw-r--r-- contrib/python/pythran/pythran/pythonic/include/numpy/dot.hpp 284
1 files changed, 284 insertions, 0 deletions
diff --git a/contrib/python/pythran/pythran/pythonic/include/numpy/dot.hpp b/contrib/python/pythran/pythran/pythonic/include/numpy/dot.hpp
new file mode 100644
index 00000000000..a8dc8089a82
--- /dev/null
+++ b/contrib/python/pythran/pythran/pythonic/include/numpy/dot.hpp
@@ -0,0 +1,284 @@
+#ifndef PYTHONIC_INCLUDE_NUMPY_DOT_HPP
+#define PYTHONIC_INCLUDE_NUMPY_DOT_HPP
+
+#include "pythonic/include/types/ndarray.hpp"
+#include "pythonic/include/numpy/sum.hpp"
+#include "pythonic/include/types/numpy_expr.hpp"
+#include "pythonic/include/types/traits.hpp"
+
+// Trait used by the `dot` overloads below to decide whether an element type
+// `T` is compatible with BLAS. The primary template simply inherits from
+// `pythonic::types::is_complex<T>`, so its value is whatever that trait
+// reports for `T`.
+template <class T>
+struct is_blas_type : pythonic::types::is_complex<T> {
+};
+
+// `float` is explicitly marked as a BLAS-compatible element type.
+template <>
+struct is_blas_type<float> : std::true_type {
+};
+
+// `double` is explicitly marked as a BLAS-compatible element type.
+template <>
+struct is_blas_type<double> : std::true_type {
+};
+
+// Detection trait: `value` is true when the expression `E::is_strided` is
+// well-formed (that is, `E` exposes a member named `is_strided`) and false
+// otherwise. Only the existence of the member is tested; its value is never
+// read.
+template <class E>
+struct is_strided {
+ // Preferred candidate: removed from overload resolution by SFINAE when
+ // `T::is_strided` does not name anything.
+ template <class T>
+ static auto get(T *) -> decltype(T::is_strided, std::true_type{});
+ // Catch-all candidate, selected only when the one above is not viable.
+ static auto get(...) -> std::false_type;
+ // Both candidates are declaration-only; the call is resolved inside
+ // decltype and never evaluated.
+ static constexpr bool value =
+ decltype(get(static_cast<E *>(nullptr)))::value;
+};
+
+// `value` is true when `E` qualifies as a BLAS array operand for `dot`:
+// `E` has no `is_strided` member, `pythonic::types::is_array<E>` holds, and
+// its element type (`pythonic::types::dtype_of<E>`) is a BLAS type.
+template <class E>
+struct is_blas_array {
+ // FIXME: also support gexpr with stride?
+ static constexpr bool value =
+ !is_strided<E>::value && pythonic::types::is_array<E>::value &&
+ is_blas_type<pythonic::types::dtype_of<E>>::value;
+};
+
+PYTHONIC_NS_BEGIN
+
+namespace numpy
+{
+ // dot() for two operands that each satisfy `types::is_dtype` (presumably
+ // plain scalars -- confirm against `is_dtype`). The return type is the type
+ // of the expression `E * F`.
+ template <class E, class F>
+ typename std::enable_if<types::is_dtype<E>::value &&
+ types::is_dtype<F>::value,
+ decltype(std::declval<E>() * std::declval<F>())>::type
+ dot(E const &e, F const &f);
+
+ /// Vector / Vector multiplication
+ // Generic overload: both operands are numexpr arguments with `value == 1`,
+ // and at least one of them is not a BLAS array or their dtypes differ
+ // (i.e. the cases not covered by the BLAS-specific overloads that follow).
+ // The return type is the `__combined` type of the two dtypes.
+ template <class E, class F>
+ typename std::enable_if<
+ types::is_numexpr_arg<E>::value && types::is_numexpr_arg<F>::value &&
+ E::value == 1 && F::value == 1 &&
+ (!is_blas_array<E>::value || !is_blas_array<F>::value ||
+ !std::is_same<typename E::dtype, typename F::dtype>::value),
+ typename __combined<typename E::dtype, typename F::dtype>::type>::type
+ dot(E const &e, F const &f);
+
+ // Vector / vector overload for two BLAS arrays (see `is_blas_array`) with
+ // `value == 1` whose dtype is exactly `float`; returns `float`.
+ // NOTE(review): presumably forwards to the single-precision BLAS dot
+ // routine -- the definition is not part of this header.
+ template <class E, class F>
+ typename std::enable_if<E::value == 1 && F::value == 1 &&
+ std::is_same<typename E::dtype, float>::value &&
+ std::is_same<typename F::dtype, float>::value &&
+ is_blas_array<E>::value &&
+ is_blas_array<F>::value,
+ float>::type
+ dot(E const &e, F const &f);
+
+ // Vector / vector overload for two BLAS arrays (see `is_blas_array`) with
+ // `value == 1` whose dtype is exactly `double`; returns `double`.
+ // NOTE(review): presumably forwards to the double-precision BLAS dot
+ // routine -- the definition is not part of this header.
+ template <class E, class F>
+ typename std::enable_if<E::value == 1 && F::value == 1 &&
+ std::is_same<typename E::dtype, double>::value &&
+ std::is_same<typename F::dtype, double>::value &&
+ is_blas_array<E>::value &&
+ is_blas_array<F>::value,
+ double>::type
+ dot(E const &e, F const &f);
+
+ // Vector / vector overload for two BLAS arrays (see `is_blas_array`) with
+ // `value == 1` whose dtype is exactly `std::complex<float>`; returns
+ // `std::complex<float>`.
+ // NOTE(review): presumably forwards to the matching complex BLAS dot
+ // routine -- the definition is not part of this header.
+ template <class E, class F>
+ typename std::enable_if<
+ E::value == 1 && F::value == 1 &&
+ std::is_same<typename E::dtype, std::complex<float>>::value &&
+ std::is_same<typename F::dtype, std::complex<float>>::value &&
+ is_blas_array<E>::value && is_blas_array<F>::value,
+ std::complex<float>>::type
+ dot(E const &e, F const &f);
+
+ // Vector / vector overload for two BLAS arrays (see `is_blas_array`) with
+ // `value == 1` whose dtype is exactly `std::complex<double>`; returns
+ // `std::complex<double>`.
+ // NOTE(review): presumably forwards to the matching complex BLAS dot
+ // routine -- the definition is not part of this header.
+ template <class E, class F>
+ typename std::enable_if<
+ E::value == 1 && F::value == 1 &&
+ std::is_same<typename E::dtype, std::complex<double>>::value &&
+ std::is_same<typename F::dtype, std::complex<double>>::value &&
+ is_blas_array<E>::value && is_blas_array<F>::value,
+ std::complex<double>>::type
+ dot(E const &e, F const &f);
+
+ /// Matrix / Vector multiplication
+
+ // We transpose the matrix to reflect our C order.
+ // Overload for an ndarray `f` with a 2-entry shape times an ndarray `e`
+ // with a 1-entry shape, both holding the same BLAS-compatible element type
+ // `E`. Returns an `ndarray<E, pshape<long>>`.
+ template <class E, class pS0, class pS1>
+ typename std::enable_if<is_blas_type<E>::value &&
+ std::tuple_size<pS0>::value == 2 &&
+ std::tuple_size<pS1>::value == 1,
+ types::ndarray<E, types::pshape<long>>>::type
+ dot(types::ndarray<E, pS0> const &f, types::ndarray<E, pS1> const &e);
+
+ // Matrix / vector overload where the matrix operand `f` is a
+ // `numpy_texpr` wrapping a 2-entry-shape ndarray (presumably a transposed
+ // view -- confirm against `numpy_texpr`) and `e` is a 1-entry-shape ndarray
+ // of the same BLAS-compatible element type. Returns an
+ // `ndarray<E, pshape<long>>`.
+ template <class E, class pS0, class pS1>
+ typename std::enable_if<is_blas_type<E>::value &&
+ std::tuple_size<pS0>::value == 2 &&
+ std::tuple_size<pS1>::value == 1,
+ types::ndarray<E, types::pshape<long>>>::type
+ dot(types::numpy_texpr<types::ndarray<E, pS0>> const &f,
+ types::ndarray<E, pS1> const &e);
+
+ // The trick is to not transpose the matrix, so that MV becomes VM.
+ // Overload for an ndarray `e` with a 1-entry shape times an ndarray `f`
+ // with a 2-entry shape, both holding the same BLAS-compatible element type
+ // `E`. Returns an `ndarray<E, pshape<long>>`.
+ template <class E, class pS0, class pS1>
+ typename std::enable_if<is_blas_type<E>::value &&
+ std::tuple_size<pS0>::value == 1 &&
+ std::tuple_size<pS1>::value == 2,
+ types::ndarray<E, types::pshape<long>>>::type
+ dot(types::ndarray<E, pS0> const &e, types::ndarray<E, pS1> const &f);
+
+ // Vector / matrix overload where `e` is a 1-entry-shape ndarray and the
+ // matrix operand `f` is a `numpy_texpr` wrapping a 2-entry-shape ndarray
+ // (presumably a transposed view -- confirm against `numpy_texpr`), both of
+ // the same BLAS-compatible element type. Returns an
+ // `ndarray<E, pshape<long>>`.
+ template <class E, class pS0, class pS1>
+ typename std::enable_if<is_blas_type<E>::value &&
+ std::tuple_size<pS0>::value == 1 &&
+ std::tuple_size<pS1>::value == 2,
+ types::ndarray<E, types::pshape<long>>>::type
+ dot(types::ndarray<E, pS0> const &e,
+ types::numpy_texpr<types::ndarray<E, pS1>> const &f);
+
+ // If the arguments can be used with BLAS, we evaluate them first, because
+ // BLAS needs a pointer to an actual array.
+ // Matrix / vector overload enabled when both operands are numexpr
+ // arguments with BLAS-compatible dtypes, and either they are not both
+ // ndarrays or their dtypes differ. The result is an ndarray of the
+ // `__combined` dtype with shape `pshape<long>`.
+ template <class E, class F>
+ typename std::enable_if<
+ types::is_numexpr_arg<E>::value &&
+ types::is_numexpr_arg<F>::value // It is an array_like
+ && (!(types::is_ndarray<E>::value && types::is_ndarray<F>::value) ||
+ !std::is_same<typename E::dtype, typename F::dtype>::value) &&
+ is_blas_type<typename E::dtype>::value &&
+ is_blas_type<typename F::dtype>::value // With dtype compatible with
+ // blas
+ &&
+ E::value == 2 && F::value == 1, // And it is matrix / vect
+ types::ndarray<
+ typename __combined<typename E::dtype, typename F::dtype>::type,
+ types::pshape<long>>>::type
+ dot(E const &e, F const &f);
+
+ // If the arguments can be used with BLAS, we evaluate them first, because
+ // BLAS needs a pointer to an actual array.
+ // Vector / matrix overload enabled when both operands are numexpr
+ // arguments with BLAS-compatible dtypes, and either they are not both
+ // ndarrays or their dtypes differ. The result is an ndarray of the
+ // `__combined` dtype with shape `pshape<long>`.
+ template <class E, class F>
+ typename std::enable_if<
+ types::is_numexpr_arg<E>::value &&
+ types::is_numexpr_arg<F>::value // It is an array_like
+ && (!(types::is_ndarray<E>::value && types::is_ndarray<F>::value) ||
+ !std::is_same<typename E::dtype, typename F::dtype>::value) &&
+ is_blas_type<typename E::dtype>::value &&
+ is_blas_type<typename F::dtype>::value // With dtype compatible with
+ // blas
+ &&
+ E::value == 1 && F::value == 2, // And it is vect / matrix
+ types::ndarray<
+ typename __combined<typename E::dtype, typename F::dtype>::type,
+ types::pshape<long>>>::type
+ dot(E const &e, F const &f);
+
+ // If one of the arguments doesn't have a "blas compatible type", we use a
+ // slow vector / matrix multiplication.
+ // Enabled when at least one dtype fails `is_blas_type`; the result is an
+ // ndarray of the `__combined` dtype with shape `pshape<long>`.
+ template <class E, class F>
+ typename std::enable_if<
+ (!is_blas_type<typename E::dtype>::value ||
+ !is_blas_type<typename F::dtype>::value) &&
+ E::value == 1 && F::value == 2, // And it is vect / matrix
+ types::ndarray<
+ typename __combined<typename E::dtype, typename F::dtype>::type,
+ types::pshape<long>>>::type
+ dot(E const &e, F const &f);
+
+ // If one of the arguments doesn't have a "blas compatible type", we use a
+ // slow matrix / vector multiplication.
+ // Enabled when at least one dtype fails `is_blas_type`; the result is an
+ // ndarray of the `__combined` dtype with shape `pshape<long>`.
+ template <class E, class F>
+ typename std::enable_if<
+ (!is_blas_type<typename E::dtype>::value ||
+ !is_blas_type<typename F::dtype>::value) &&
+ E::value == 2 && F::value == 1, // And it is matrix / vect
+ types::ndarray<
+ typename __combined<typename E::dtype, typename F::dtype>::type,
+ types::pshape<long>>>::type
+ dot(E const &e, F const &f);
+
+ /// Matrix / Matrix multiplication
+
+ // The trick is to use the transposed arguments to reflect C order.
+ // We want to perform A * B in C order, but BLAS works in F order.
+ // So we compute B'A' == (AB)'. As this equality is performed in F order,
+ // we don't have to return a texpr, because we want a C order matrix.
+ //
+ // Overload for two ndarrays with 2-entry shapes sharing the same
+ // BLAS-compatible element type `E`; returns an
+ // `ndarray<E, array<long, 2>>`.
+ template <class E, class pS0, class pS1>
+ typename std::enable_if<is_blas_type<E>::value &&
+ std::tuple_size<pS0>::value == 2 &&
+ std::tuple_size<pS1>::value == 2,
+ types::ndarray<E, types::array<long, 2>>>::type
+ dot(types::ndarray<E, pS0> const &a, types::ndarray<E, pS1> const &b);
+
+ // Three-argument matrix / matrix overload: `a`, `b` and the extra
+ // caller-supplied ndarray `c` all have 2-entry shapes and the same
+ // BLAS-compatible element type `E`. Returns a reference to an
+ // `ndarray<E, pS2>`.
+ // NOTE(review): presumably `c` receives the product and is the object
+ // returned -- confirm in the definition.
+ template <class E, class pS0, class pS1, class pS2>
+ typename std::enable_if<
+ is_blas_type<E>::value && std::tuple_size<pS0>::value == 2 &&
+ std::tuple_size<pS1>::value == 2 && std::tuple_size<pS2>::value == 2,
+ types::ndarray<E, pS2>>::type &
+ dot(types::ndarray<E, pS0> const &a, types::ndarray<E, pS1> const &b,
+ types::ndarray<E, pS2> &c);
+
+ // texpr variants: MT, TM, TT
+ // This one takes a `numpy_texpr` of a 2-entry-shape ndarray as `a` and a
+ // plain 2-entry-shape ndarray as `b`, both with the same BLAS-compatible
+ // element type; returns an `ndarray<E, array<long, 2>>`.
+ template <class E, class pS0, class pS1>
+ typename std::enable_if<is_blas_type<E>::value &&
+ std::tuple_size<pS0>::value == 2 &&
+ std::tuple_size<pS1>::value == 2,
+ types::ndarray<E, types::array<long, 2>>>::type
+ dot(types::numpy_texpr<types::ndarray<E, pS0>> const &a,
+ types::ndarray<E, pS1> const &b);
+ // texpr variant taking a plain 2-entry-shape ndarray as `a` and a
+ // `numpy_texpr` of a 2-entry-shape ndarray as `b`, both with the same
+ // BLAS-compatible element type; returns an `ndarray<E, array<long, 2>>`.
+ template <class E, class pS0, class pS1>
+ typename std::enable_if<is_blas_type<E>::value &&
+ std::tuple_size<pS0>::value == 2 &&
+ std::tuple_size<pS1>::value == 2,
+ types::ndarray<E, types::array<long, 2>>>::type
+ dot(types::ndarray<E, pS0> const &a,
+ types::numpy_texpr<types::ndarray<E, pS1>> const &b);
+ // texpr variant where both `a` and `b` are `numpy_texpr` wrappers around
+ // 2-entry-shape ndarrays of the same BLAS-compatible element type; returns
+ // an `ndarray<E, array<long, 2>>`.
+ template <class E, class pS0, class pS1>
+ typename std::enable_if<is_blas_type<E>::value &&
+ std::tuple_size<pS0>::value == 2 &&
+ std::tuple_size<pS1>::value == 2,
+ types::ndarray<E, types::array<long, 2>>>::type
+ dot(types::numpy_texpr<types::ndarray<E, pS0>> const &a,
+ types::numpy_texpr<types::ndarray<E, pS1>> const &b);
+
+ // If the arguments can be used with BLAS, we evaluate them first, because
+ // BLAS needs a pointer to an actual array.
+ // Matrix / matrix overload enabled when both operands are numexpr
+ // arguments with BLAS-compatible dtypes, and either they are not both
+ // ndarrays or their dtypes differ. The result is an ndarray of the
+ // `__combined` dtype with shape `array<long, 2>`.
+ template <class E, class F>
+ typename std::enable_if<
+ types::is_numexpr_arg<E>::value &&
+ types::is_numexpr_arg<F>::value // It is an array_like
+ && (!(types::is_ndarray<E>::value && types::is_ndarray<F>::value) ||
+ !std::is_same<typename E::dtype, typename F::dtype>::value) &&
+ is_blas_type<typename E::dtype>::value &&
+ is_blas_type<typename F::dtype>::value // With dtype compatible with
+ // blas
+ &&
+ E::value == 2 && F::value == 2, // And both are matrix
+ types::ndarray<
+ typename __combined<typename E::dtype, typename F::dtype>::type,
+ types::array<long, 2>>>::type
+ dot(E const &e, F const &f);
+
+ // If one of the arguments doesn't have a "blas compatible type", we use a
+ // slow matrix multiplication.
+ // Enabled when at least one dtype fails `is_blas_type`; the result is an
+ // ndarray of the `__combined` dtype with shape `array<long, 2>`.
+ template <class E, class F>
+ typename std::enable_if<
+ (!is_blas_type<typename E::dtype>::value ||
+ !is_blas_type<typename F::dtype>::value) &&
+ E::value == 2 && F::value == 2, // And it is matrix / matrix
+ types::ndarray<
+ typename __combined<typename E::dtype, typename F::dtype>::type,
+ types::array<long, 2>>>::type
+ dot(E const &e, F const &f);
+
+ // N x M where N >= 3 and M == 1
+ // No dtype constraint is imposed here. The result is an ndarray of the
+ // `__combined` dtype whose shape has `E::value - 1` entries.
+ template <class E, class F>
+ typename std::enable_if<
+ (E::value >= 3 && F::value == 1),
+ types::ndarray<
+ typename __combined<typename E::dtype, typename F::dtype>::type,
+ types::array<long, E::value - 1>>>::type
+ dot(E const &e, F const &f);
+
+ // N x M where N >= 3 and M >= 2
+ // No dtype constraint is imposed here. The result is an ndarray of the
+ // `__combined` dtype whose shape has `E::value - 1` entries.
+ // NOTE(review): the declared result rank does not depend on `F::value`;
+ // numpy.dot of an N-D and an M-D array yields N + M - 2 dimensions, which
+ // equals `E::value - 1` only for M == 2 -- confirm the intended behaviour
+ // for M >= 3 against the definition.
+ template <class E, class F>
+ typename std::enable_if<
+ (E::value >= 3 && F::value >= 2),
+ types::ndarray<
+ typename __combined<typename E::dtype, typename F::dtype>::type,
+ types::array<long, E::value - 1>>>::type
+ dot(E const &e, F const &f);
+
+ // NOTE(review): `DEFINE_FUNCTOR` is defined outside this header; presumably
+ // it declares the functor type/object that exposes the `dot` overload set
+ // above -- confirm against the macro definition.
+ DEFINE_FUNCTOR(pythonic::numpy, dot);
+}
+PYTHONIC_NS_END
+
+#endif