// Directory    : /usr/lib/python3/dist-packages/pythran/pythonic/include/numpy/
// Current file : /usr/lib/python3/dist-packages/pythran/pythonic/include/numpy/dot.hpp
#ifndef PYTHONIC_INCLUDE_NUMPY_DOT_HPP
#define PYTHONIC_INCLUDE_NUMPY_DOT_HPP
#include "pythonic/include/types/ndarray.hpp"
#include "pythonic/include/numpy/sum.hpp"
#include "pythonic/include/types/numpy_expr.hpp"
#include "pythonic/include/types/traits.hpp"
/// Trait telling whether `T` is an element type this header treats as
/// BLAS-compatible.
///
/// Primary template: defers to `pythonic::types::is_complex<T>`, so any type
/// that trait accepts is also accepted here.
template <typename T>
struct is_blas_type : pythonic::types::is_complex<T> {};

/// `float` is always accepted.
template <>
struct is_blas_type<float> : std::true_type {};

/// `double` is always accepted.
template <>
struct is_blas_type<double> : std::true_type {};
/// Compile-time probe: `value` is true when the expression `E::is_strided`
/// is well-formed (e.g. `E` exposes a static member or enumerator with that
/// name) and false otherwise.
///
/// Only the *presence* of the member is tested; its value is never read.
template <class E>
struct is_strided {
  // Preferred overload: discarded by SFINAE when `U::is_strided` is not a
  // valid expression.
  template <class U>
  static auto get(U *) -> decltype(U::is_strided, std::true_type{});

  // Fallback picked for every type lacking the member.
  static auto get(...) -> std::false_type;

  static constexpr bool value =
      decltype(get(static_cast<E *>(nullptr)))::value;
};
/// Trait selecting the operands that the BLAS-specific `dot` overloads accept.
///
/// `value` is true when all three conditions hold:
///   - `pythonic::types::is_array<E>` is true,
///   - the element type `pythonic::types::dtype_of<E>` is a BLAS type,
///   - `is_strided<E>` is false.
template <class E>
struct is_blas_array {
private:
  static constexpr bool array_like = pythonic::types::is_array<E>::value;
  static constexpr bool blas_dtype =
      is_blas_type<pythonic::types::dtype_of<E>>::value;
  static constexpr bool strided = is_strided<E>::value;

public:
  // FIXME: also support gexpr with stride?
  static constexpr bool value = array_like && blas_dtype && !strided;
};
PYTHONIC_NS_BEGIN
namespace numpy
{
/// dtype / dtype overload.
///
/// Participates in overload resolution only when both `E` and `F` satisfy
/// `types::is_dtype` (presumably plain scalar element types; confirm against
/// the traits header).
///
/// @param e left operand
/// @param f right operand
/// @return a value of the type yielded by the expression `E * F`
template <class E, class F>
typename std::enable_if<types::is_dtype<E>::value &&
types::is_dtype<F>::value,
decltype(std::declval<E>() * std::declval<F>())>::type
dot(E const &e, F const &f);
/// Vector / Vector multiplication
///
/// Generic overload: enabled when both operands satisfy
/// `types::is_numexpr_arg`, both have `value == 1`, and the pair is NOT
/// eligible for the dtype-specific overloads declared below, i.e. at least
/// one operand is not an `is_blas_array`, or the two dtypes differ.
///
/// @param e left vector
/// @param f right vector
/// @return a scalar of type `__combined<E::dtype, F::dtype>::type`
template <class E, class F>
typename std::enable_if<
types::is_numexpr_arg<E>::value && types::is_numexpr_arg<F>::value &&
E::value == 1 && F::value == 1 &&
(!is_blas_array<E>::value || !is_blas_array<F>::value ||
!std::is_same<typename E::dtype, typename F::dtype>::value),
typename __combined<typename E::dtype, typename F::dtype>::type>::type
dot(E const &e, F const &f);
/// Vector / Vector multiplication, single-precision real case.
///
/// Enabled when both operands have `value == 1`, both dtypes are exactly
/// `float`, and both satisfy `is_blas_array`. Presumably dispatches to a BLAS
/// routine in the definition, which is not visible in this header.
///
/// @param e left vector
/// @param f right vector
/// @return the product as a `float`
template <class E, class F>
typename std::enable_if<E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, float>::value &&
std::is_same<typename F::dtype, float>::value &&
is_blas_array<E>::value &&
is_blas_array<F>::value,
float>::type
dot(E const &e, F const &f);
/// Vector / Vector multiplication, double-precision real case.
///
/// Enabled when both operands have `value == 1`, both dtypes are exactly
/// `double`, and both satisfy `is_blas_array`. Presumably dispatches to a
/// BLAS routine in the definition, which is not visible in this header.
///
/// @param e left vector
/// @param f right vector
/// @return the product as a `double`
template <class E, class F>
typename std::enable_if<E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, double>::value &&
std::is_same<typename F::dtype, double>::value &&
is_blas_array<E>::value &&
is_blas_array<F>::value,
double>::type
dot(E const &e, F const &f);
/// Vector / Vector multiplication, single-precision complex case.
///
/// Enabled when both operands have `value == 1`, both dtypes are exactly
/// `std::complex<float>`, and both satisfy `is_blas_array`. Presumably
/// dispatches to a BLAS routine in the definition, which is not visible in
/// this header.
///
/// @param e left vector
/// @param f right vector
/// @return the product as a `std::complex<float>`
template <class E, class F>
typename std::enable_if<
E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, std::complex<float>>::value &&
std::is_same<typename F::dtype, std::complex<float>>::value &&
is_blas_array<E>::value && is_blas_array<F>::value,
std::complex<float>>::type
dot(E const &e, F const &f);
/// Vector / Vector multiplication, double-precision complex case.
///
/// Enabled when both operands have `value == 1`, both dtypes are exactly
/// `std::complex<double>`, and both satisfy `is_blas_array`. Presumably
/// dispatches to a BLAS routine in the definition, which is not visible in
/// this header.
///
/// @param e left vector
/// @param f right vector
/// @return the product as a `std::complex<double>`
template <class E, class F>
typename std::enable_if<
E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, std::complex<double>>::value &&
std::is_same<typename F::dtype, std::complex<double>>::value &&
is_blas_array<E>::value && is_blas_array<F>::value,
std::complex<double>>::type
dot(E const &e, F const &f);
/// Matrix / Vector multiplication
// We transpose the matrix to reflect our C order.
///
/// Overload for two `ndarray`s sharing the same BLAS element type `E`.
/// Enabled when `pS0` has 2 dimensions and `pS1` has 1.
///
/// @param f the matrix (shape type `pS0`); note the parameter is named `f`
///          although it is the first operand
/// @param e the vector (shape type `pS1`)
/// @return a 1-D `ndarray<E, pshape<long>>`
template <class E, class pS0, class pS1>
typename std::enable_if<is_blas_type<E>::value &&
std::tuple_size<pS0>::value == 2 &&
std::tuple_size<pS1>::value == 1,
types::ndarray<E, types::pshape<long>>>::type
dot(types::ndarray<E, pS0> const &f, types::ndarray<E, pS1> const &e);
/// Vector / Matrix multiplication
// The trick is to not transpose the matrix, so that MV becomes VM.
///
/// Overload for two `ndarray`s sharing the same BLAS element type `E`.
/// Enabled when `pS0` has 1 dimension and `pS1` has 2.
///
/// @param e the vector (shape type `pS0`)
/// @param f the matrix (shape type `pS1`)
/// @return a 1-D `ndarray<E, pshape<long>>`
template <class E, class pS0, class pS1>
typename std::enable_if<is_blas_type<E>::value &&
std::tuple_size<pS0>::value == 1 &&
std::tuple_size<pS1>::value == 2,
types::ndarray<E, types::pshape<long>>>::type
dot(types::ndarray<E, pS0> const &e, types::ndarray<E, pS1> const &f);
// If the arguments could be used with BLAS, we evaluate them first, because
// BLAS needs a pointer to an actual array.
///
/// Matrix / Vector overload for array-like operands with BLAS dtypes that do
/// not match the plain same-dtype `ndarray` overload: at least one operand is
/// not an `ndarray`, or the two dtypes differ.
///
/// @param e the matrix operand (`E::value == 2`)
/// @param f the vector operand (`F::value == 1`)
/// @return a 1-D `ndarray` whose dtype is `__combined<E::dtype, F::dtype>`
template <class E, class F>
typename std::enable_if<
types::is_numexpr_arg<E>::value &&
types::is_numexpr_arg<F>::value // Both are array-like
&& (!(types::is_ndarray<E>::value && types::is_ndarray<F>::value) ||
!std::is_same<typename E::dtype, typename F::dtype>::value) &&
is_blas_type<typename E::dtype>::value &&
is_blas_type<typename F::dtype>::value // With dtypes compatible with
// BLAS
&&
E::value == 2 && F::value == 1, // And it is matrix / vect
types::ndarray<
typename __combined<typename E::dtype, typename F::dtype>::type,
types::pshape<long>>>::type
dot(E const &e, F const &f);
// If the arguments could be used with BLAS, we evaluate them first, because
// BLAS needs a pointer to an actual array.
///
/// Vector / Matrix overload for array-like operands with BLAS dtypes that do
/// not match the plain same-dtype `ndarray` overload: at least one operand is
/// not an `ndarray`, or the two dtypes differ.
///
/// @param e the vector operand (`E::value == 1`)
/// @param f the matrix operand (`F::value == 2`)
/// @return a 1-D `ndarray` whose dtype is `__combined<E::dtype, F::dtype>`
template <class E, class F>
typename std::enable_if<
types::is_numexpr_arg<E>::value &&
types::is_numexpr_arg<F>::value // Both are array-like
&& (!(types::is_ndarray<E>::value && types::is_ndarray<F>::value) ||
!std::is_same<typename E::dtype, typename F::dtype>::value) &&
is_blas_type<typename E::dtype>::value &&
is_blas_type<typename F::dtype>::value // With dtypes compatible with
// BLAS
&&
E::value == 1 && F::value == 2, // And it is vect / matrix
types::ndarray<
typename __combined<typename E::dtype, typename F::dtype>::type,
types::pshape<long>>>::type
dot(E const &e, F const &f);
// If one of the arguments doesn't have a "BLAS-compatible type", we use a
// slow vector / matrix multiplication.
///
/// Enabled when at least one dtype is not an `is_blas_type`, `E::value == 1`
/// and `F::value == 2`.
///
/// @param e the vector operand
/// @param f the matrix operand
/// @return a 1-D `ndarray` whose dtype is `__combined<E::dtype, F::dtype>`
template <class E, class F>
typename std::enable_if<
(!is_blas_type<typename E::dtype>::value ||
!is_blas_type<typename F::dtype>::value) &&
E::value == 1 && F::value == 2, // And it is vect / matrix
types::ndarray<
typename __combined<typename E::dtype, typename F::dtype>::type,
types::pshape<long>>>::type
dot(E const &e, F const &f);
// If one of the arguments doesn't have a "BLAS-compatible type", we use a
// slow matrix / vector multiplication.
///
/// Enabled when at least one dtype is not an `is_blas_type`, `E::value == 2`
/// and `F::value == 1`.
///
/// @param e the matrix operand
/// @param f the vector operand
/// @return a 1-D `ndarray` whose dtype is `__combined<E::dtype, F::dtype>`
template <class E, class F>
typename std::enable_if<
(!is_blas_type<typename E::dtype>::value ||
!is_blas_type<typename F::dtype>::value) &&
E::value == 2 && F::value == 1, // And it is matrix / vect
types::ndarray<
typename __combined<typename E::dtype, typename F::dtype>::type,
types::pshape<long>>>::type
dot(E const &e, F const &f);
/// Matrix / Matrix multiplication
// The trick is to use the transposed arguments to reflect C order.
// We want to perform A * B in C order, but BLAS order is F order.
// So we compute B'A' == (AB)'. As this equality is performed in F order,
// we don't have to return a texpr, because we want a C-order matrix.
///
/// Overload for two 2-D `ndarray`s sharing the same BLAS element type `E`.
///
/// @param a left matrix
/// @param b right matrix
/// @return a 2-D `ndarray<E, array<long, 2>>`
template <class E, class pS0, class pS1>
typename std::enable_if<is_blas_type<E>::value &&
std::tuple_size<pS0>::value == 2 &&
std::tuple_size<pS1>::value == 2,
types::ndarray<E, types::array<long, 2>>>::type
dot(types::ndarray<E, pS0> const &a, types::ndarray<E, pS1> const &b);
/// Matrix / Matrix multiplication with a caller-supplied output array.
///
/// Enabled when `E` is a BLAS type and all three shape types are 2-D.
///
/// @param a left matrix
/// @param b right matrix
/// @param c output matrix, taken by non-const reference
/// @return a reference to an `ndarray<E, pS2>`, presumably `c` itself
///         (confirm against the definition)
template <class E, class pS0, class pS1, class pS2>
typename std::enable_if<
is_blas_type<E>::value && std::tuple_size<pS0>::value == 2 &&
std::tuple_size<pS1>::value == 2 && std::tuple_size<pS2>::value == 2,
types::ndarray<E, pS2>>::type &
dot(types::ndarray<E, pS0> const &a, types::ndarray<E, pS1> const &b,
types::ndarray<E, pS2> &c);
// texpr variants: MT, TM, TT
///
/// TM variant: the left operand is a `numpy_texpr` wrapping a 2-D `ndarray`,
/// the right operand is a plain 2-D `ndarray`; both share the BLAS element
/// type `E`.
///
/// @param a left operand, a `numpy_texpr` over an `ndarray<E, pS0>`
/// @param b right operand, a plain matrix
/// @return a 2-D `ndarray<E, array<long, 2>>`
template <class E, class pS0, class pS1>
typename std::enable_if<is_blas_type<E>::value &&
std::tuple_size<pS0>::value == 2 &&
std::tuple_size<pS1>::value == 2,
types::ndarray<E, types::array<long, 2>>>::type
dot(types::numpy_texpr<types::ndarray<E, pS0>> const &a,
types::ndarray<E, pS1> const &b);
/// MT variant: the left operand is a plain 2-D `ndarray`, the right operand
/// is a `numpy_texpr` wrapping a 2-D `ndarray`; both share the BLAS element
/// type `E`.
///
/// @param a left operand, a plain matrix
/// @param b right operand, a `numpy_texpr` over an `ndarray<E, pS1>`
/// @return a 2-D `ndarray<E, array<long, 2>>`
template <class E, class pS0, class pS1>
typename std::enable_if<is_blas_type<E>::value &&
std::tuple_size<pS0>::value == 2 &&
std::tuple_size<pS1>::value == 2,
types::ndarray<E, types::array<long, 2>>>::type
dot(types::ndarray<E, pS0> const &a,
types::numpy_texpr<types::ndarray<E, pS1>> const &b);
/// TT variant: both operands are `numpy_texpr`s wrapping 2-D `ndarray`s that
/// share the BLAS element type `E`.
///
/// @param a left operand, a `numpy_texpr` over an `ndarray<E, pS0>`
/// @param b right operand, a `numpy_texpr` over an `ndarray<E, pS1>`
/// @return a 2-D `ndarray<E, array<long, 2>>`
template <class E, class pS0, class pS1>
typename std::enable_if<is_blas_type<E>::value &&
std::tuple_size<pS0>::value == 2 &&
std::tuple_size<pS1>::value == 2,
types::ndarray<E, types::array<long, 2>>>::type
dot(types::numpy_texpr<types::ndarray<E, pS0>> const &a,
types::numpy_texpr<types::ndarray<E, pS1>> const &b);
// If the arguments could be used with BLAS, we evaluate them first, because
// BLAS needs a pointer to an actual array.
///
/// Matrix / Matrix overload for array-like operands with BLAS dtypes that do
/// not match the plain same-dtype `ndarray` overload: at least one operand is
/// not an `ndarray`, or the two dtypes differ.
///
/// @param e left matrix operand (`E::value == 2`)
/// @param f right matrix operand (`F::value == 2`)
/// @return a 2-D `ndarray` whose dtype is `__combined<E::dtype, F::dtype>`
template <class E, class F>
typename std::enable_if<
types::is_numexpr_arg<E>::value &&
types::is_numexpr_arg<F>::value // Both are array-like
&& (!(types::is_ndarray<E>::value && types::is_ndarray<F>::value) ||
!std::is_same<typename E::dtype, typename F::dtype>::value) &&
is_blas_type<typename E::dtype>::value &&
is_blas_type<typename F::dtype>::value // With dtypes compatible with
// BLAS
&&
E::value == 2 && F::value == 2, // And both are matrices
types::ndarray<
typename __combined<typename E::dtype, typename F::dtype>::type,
types::array<long, 2>>>::type
dot(E const &e, F const &f);
// If one of the arguments doesn't have a "BLAS-compatible type", we use a
// slow matrix multiplication.
///
/// Enabled when at least one dtype is not an `is_blas_type` and both operands
/// have `value == 2`.
///
/// @param e left matrix operand
/// @param f right matrix operand
/// @return a 2-D `ndarray` whose dtype is `__combined<E::dtype, F::dtype>`
template <class E, class F>
typename std::enable_if<
(!is_blas_type<typename E::dtype>::value ||
!is_blas_type<typename F::dtype>::value) &&
E::value == 2 && F::value == 2, // And it is matrix / matrix
types::ndarray<
typename __combined<typename E::dtype, typename F::dtype>::type,
types::array<long, 2>>>::type
dot(E const &e, F const &f);
// N x M where N >= 3 and M == 1
///
/// Higher-rank array times vector: enabled when `E::value >= 3` and
/// `F::value == 1`.
///
/// @param e left operand of rank `E::value`
/// @param f right operand, a vector
/// @return an `ndarray` of rank `E::value - 1` whose dtype is
///         `__combined<E::dtype, F::dtype>`
template <class E, class F>
typename std::enable_if<
(E::value >= 3 && F::value == 1),
types::ndarray<
typename __combined<typename E::dtype, typename F::dtype>::type,
types::array<long, E::value - 1>>>::type
dot(E const &e, F const &f);
// N x M where N >= 3 and M >= 2
///
/// Higher-rank array times array of rank >= 2: enabled when `E::value >= 3`
/// and `F::value >= 2`.
///
/// NOTE(review): the declared result rank is `E::value - 1`, independent of
/// `F::value`. NumPy's `dot` for an N-D by M-D product yields rank N + M - 2,
/// so check against the definition whether this case is actually implemented
/// or whether the return type is a placeholder.
///
/// @param e left operand of rank `E::value`
/// @param f right operand of rank `F::value`
/// @return an `ndarray` declared with rank `E::value - 1` whose dtype is
///         `__combined<E::dtype, F::dtype>`
template <class E, class F>
typename std::enable_if<
(E::value >= 3 && F::value >= 2),
types::ndarray<
typename __combined<typename E::dtype, typename F::dtype>::type,
types::array<long, E::value - 1>>>::type
dot(E const &e, F const &f);
// Invokes the project macro DEFINE_FUNCTOR for `dot` in `pythonic::numpy`.
// The macro is defined outside this file; presumably it declares a functor
// object wrapping the `dot` overload set above (confirm against the macro).
DEFINE_FUNCTOR(pythonic::numpy, dot);
}
PYTHONIC_NS_END
#endif