From dea83d31e91ac29de35a73abd5de06a6f4fef9e5 Mon Sep 17 00:00:00 2001 From: Benjamin Perret Date: Sun, 22 Nov 2020 13:43:10 +0100 Subject: [PATCH] add pybind casters for strided_views, array_adaptor, and tensor_adaptor --- include/xtensor-python/pyarray.hpp | 6 +-- include/xtensor-python/pynative_casters.hpp | 52 +++++++++++++++++++ include/xtensor-python/pytensor.hpp | 6 +-- .../xtensor_type_caster_base.hpp | 2 +- test_python/main.cpp | 38 ++++++++++++++ test_python/test_pyarray.py | 19 +++++++ 6 files changed, 112 insertions(+), 11 deletions(-) create mode 100644 include/xtensor-python/pynative_casters.hpp diff --git a/include/xtensor-python/pyarray.hpp b/include/xtensor-python/pyarray.hpp index 0873399..ead92cc 100644 --- a/include/xtensor-python/pyarray.hpp +++ b/include/xtensor-python/pyarray.hpp @@ -21,6 +21,7 @@ #include "pyarray_backstrides.hpp" #include "pycontainer.hpp" #include "pystrides_adaptor.hpp" +#include "pynative_casters.hpp" #include "xtensor_type_caster_base.hpp" namespace xt @@ -91,11 +92,6 @@ namespace pybind11 } }; - // Type caster for casting xarray to ndarray - template - struct type_caster> : xtensor_type_caster_base> - { - }; } } diff --git a/include/xtensor-python/pynative_casters.hpp b/include/xtensor-python/pynative_casters.hpp new file mode 100644 index 0000000..09c9fff --- /dev/null +++ b/include/xtensor-python/pynative_casters.hpp @@ -0,0 +1,52 @@ +/*************************************************************************** +* Copyright (c) Wolf Vollprecht, Johan Mabille and Sylvain Corlay * +* Copyright (c) QuantStack * +* * +* Distributed under the terms of the BSD 3-Clause License. * +* * +* The full license is in the file LICENSE, distributed with this software. * +****************************************************************************/ + +#ifndef PYNATIVE_CASTER_HPP +#define PYNATIVE_CASTER_HPP + +#include "xtensor_type_caster_base.hpp" + + +namespace pybind11 +{ + namespace detail + { + // Type caster for casting xarray to ndarray + template + struct type_caster> : xtensor_type_caster_base> + { + }; + + // Type caster for casting xt::xtensor to ndarray + template + struct type_caster> : xtensor_type_caster_base> + { + }; + + // Type caster for casting xt::xstrided_view to ndarray + template + struct type_caster> : xtensor_type_caster_base> + { + }; + + // Type caster for casting xt::xarray_adaptor to ndarray + template + struct type_caster> : xtensor_type_caster_base> + { + }; + + // Type caster for casting xt::xtensor_adaptor to ndarray + template + struct type_caster> : xtensor_type_caster_base> + { + }; + } +} + +#endif diff --git a/include/xtensor-python/pytensor.hpp b/include/xtensor-python/pytensor.hpp index 906bfff..736b4e1 100644 --- a/include/xtensor-python/pytensor.hpp +++ b/include/xtensor-python/pytensor.hpp @@ -21,6 +21,7 @@ #include "pycontainer.hpp" #include "pystrides_adaptor.hpp" +#include "pynative_casters.hpp" #include "xtensor_type_caster_base.hpp" namespace xt @@ -99,11 +100,6 @@ namespace pybind11 } }; - // Type caster for casting xt::xtensor to ndarray - template - struct type_caster> : xtensor_type_caster_base> - { - }; } } diff --git a/include/xtensor-python/xtensor_type_caster_base.hpp b/include/xtensor-python/xtensor_type_caster_base.hpp index 4435dd8..adb7e2b 100644 --- a/include/xtensor-python/xtensor_type_caster_base.hpp +++ b/include/xtensor-python/xtensor_type_caster_base.hpp @@ -39,7 +39,7 @@ namespace pybind11 std::vector python_shape(src.shape().size()); std::copy(src.shape().begin(), src.shape().end(), python_shape.begin()); - array a(python_shape, python_strides, src.begin(), base); + array a(python_shape, python_strides, src.data(), base); if (!writeable) { diff --git a/test_python/main.cpp b/test_python/main.cpp index 6145e76..32b348c 100644 --- a/test_python/main.cpp +++ b/test_python/main.cpp @@ -15,6 +15,8 @@ #include "xtensor-python/pyarray.hpp" #include "xtensor-python/pytensor.hpp" #include "xtensor-python/pyvectorize.hpp" +#include "xtensor/xadapt.hpp" +#include "xtensor/xstrided_view.hpp" namespace py = pybind11; using complex_t = std::complex; @@ -133,6 +135,34 @@ class C array_type m_array; }; +struct test_native_casters +{ + using array_type = xt::xarray; + array_type a{{0, 1, 2},{3, 4, 5}}; + + const auto & get_array(){ + return a; + } + + auto get_strided_view(){ + return xt::strided_view(a, {xt::range(0, 1), xt::range(0, 3, 2)}); + } + + auto get_array_adapter(){ + using shape_type = std::vector; + shape_type shape = {2, 2}; + shape_type stride = {3, 2}; + return xt::adapt(a.data(), 4, xt::no_ownership(), shape, stride); + } + + auto get_tensor_adapter(){ + using shape_type = std::array; + shape_type shape = {2, 2}; + shape_type stride = {3, 2}; + return xt::adapt(a.data(), 4, xt::no_ownership(), shape, stride); + } +}; + xt::pyarray dtype_to_python() { A a1{123, 321, 'a', {1, 2, 3}}; @@ -257,4 +287,12 @@ PYBIND11_MODULE(xtensor_python_test, m) m.def("diff_shape_overload", [](xt::pytensor a) { return 1; }); m.def("diff_shape_overload", [](xt::pytensor a) { return 2; }); + + py::class_(m, "test_native_casters") + .def(py::init<>()) + .def("get_array", &test_native_casters::get_array, py::return_value_policy::reference_internal) + .def("get_strided_view", &test_native_casters::get_strided_view) + .def("get_array_adapter", &test_native_casters::get_array_adapter) + .def("get_tensor_adapter", &test_native_casters::get_tensor_adapter); + } diff --git a/test_python/test_pyarray.py b/test_python/test_pyarray.py index cdc215c..4f1d26d 100644 --- a/test_python/test_pyarray.py +++ b/test_python/test_pyarray.py @@ -166,6 +166,25 @@ def test_diff_shape_overload(self): # FIXME: the TypeError information is not informative xt.diff_shape_overload(np.ones((2, 2, 2))) + def test_native_casters(self): + obj = xt.test_native_casters() + arr = obj.get_array() + + strided_view = obj.get_strided_view() + strided_view[0, 1] = -1 + self.assertEqual(strided_view.shape, (1, 2)) + self.assertEqual(arr[0, 2], -1) + + adapter = obj.get_array_adapter() + self.assertEqual(adapter.shape, (2, 2)) + adapter[1, 1] = -2 + self.assertEqual(arr[1, 2], -2) + + adapter = obj.get_tensor_adapter() + self.assertEqual(adapter.shape, (2, 2)) + adapter[1, 1] = -3 + self.assertEqual(arr[1, 2], -3) + class AttributeTest(TestCase):