diff --git a/include/xtensor-python/pyarray.hpp b/include/xtensor-python/pyarray.hpp index 648c80d..b7fc162 100644 --- a/include/xtensor-python/pyarray.hpp +++ b/include/xtensor-python/pyarray.hpp @@ -42,8 +42,15 @@ namespace pybind11 { using type = xt::pyarray; - bool load(handle src, bool) + bool load(handle src, bool convert) { + if (!convert) { + if (!PyArray_Check(src.ptr())) + return false; + int type_num = xt::detail::numpy_traits::type_num; + if (PyArray_TYPE(reinterpret_cast(src.ptr())) != type_num) + return false; + } value = type::ensure(src); return static_cast(value); } diff --git a/include/xtensor-python/pytensor.hpp b/include/xtensor-python/pytensor.hpp index 69f2916..97b265c 100644 --- a/include/xtensor-python/pytensor.hpp +++ b/include/xtensor-python/pytensor.hpp @@ -43,9 +43,16 @@ namespace pybind11 { using type = xt::pytensor; - bool load(handle src, bool) + bool load(handle src, bool convert) { - value = type::ensure(src); + if (!convert) { + if (!PyArray_Check(src.ptr())) + return false; + int type_num = xt::detail::numpy_traits::type_num; + if (PyArray_TYPE(reinterpret_cast(src.ptr())) != type_num) + return false; + } + value = type::ensure(src); return static_cast(value); }