diff --git a/_doc/notebooks/onnx_fft.ipynb b/_doc/notebooks/onnx_fft.ipynb index 6efb475e2..90ca1f87a 100644 --- a/_doc/notebooks/onnx_fft.ipynb +++ b/_doc/notebooks/onnx_fft.ipynb @@ -1,822 +1,2823 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "51bc89fc", - "metadata": {}, - "source": [ - "# ONNX and FFT\n", - "\n", - "ONNX does not fully support complex yet. It does not have any FFT operators either. What if we need them anyway?" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "7b2add97", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
run previous cell, wait for 2 seconds
\n", - "" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from jyquickhelper import add_notebook_menu\n", - "add_notebook_menu()" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "acfdc3b0", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext mlprodict" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "abb5fa88", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'1.21.0'" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import numpy\n", - "numpy.__version__" - ] - }, - { - "cell_type": "markdown", - "id": "2e4f68e4", - "metadata": {}, - "source": [ - "## Python implementation of RFFT\n", - "\n", - "We try to replicate [numpy.rfft](https://numpy.org/doc/stable/reference/generated/numpy.fft.rfft.html)." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "cb1cc910", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[ 0.92935219+0.j , 1.1166406 +0.18610885j,\n", - " 2.98881347-0.86137828j, 0.57062752-3.17075076j],\n", - " [-0.81071034+0.j , 4.04571912+1.34415298j,\n", - " -0.75316593+1.87375117j, -3.73972034+1.19963451j],\n", - " [ 0.49893169+0.j , -2.38853745+0.91784964j,\n", - " -2.3230939 +2.42467461j, 2.84973582+0.96874118j],\n", - " [-0.85518897+0.j , -1.07457921+2.14618057j,\n", - " 0.67522719-2.17320735j, 1.31480887+2.2782433j ],\n", - " [ 2.80867666+0.j , -2.79453396-2.22901834j,\n", - " 0.492986 +0.10661537j, 2.65317564+0.57651319j]])" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import numpy\n", - "\n", - "\n", - "def almost_equal(a, b, error=1e-5):\n", - " \"\"\"\n", - " The function compares two matrices, one may be complex. In that case,\n", - " this matrix is changed into a new matrix with a new first dimension,\n", - " [0,::] means real part, [1,::] means imaginary part.\n", - " \"\"\"\n", - " if a.dtype in (numpy.complex64, numpy.complex128):\n", - " dtype = numpy.float64 if a.dtype == numpy.complex128 else numpy.float32\n", - " new_a = numpy.empty((2,) + a.shape).astype(dtype)\n", - " new_a[0] = numpy.real(a)\n", - " new_a[1] = numpy.imag(a)\n", - " return almost_equal(new_a, b, error)\n", - " if b.dtype in (numpy.complex64, numpy.complex128):\n", - " return almost_equal(b, a, error)\n", - " if a.shape != b.shape:\n", - " raise AssertionError(\"Shape mismatch %r != %r.\" % (a.shape, b.shape))\n", - " diff = numpy.abs(a.ravel() - b.ravel()).max()\n", - " if diff > error:\n", - " raise AssertionError(\"Mismatch max diff=%r > %r.\" % (diff, error))\n", - "\n", - "\n", - "def dft_real_cst(N, fft_length):\n", - " n = numpy.arange(N)\n", - " k = n.reshape((N, 1)).astype(numpy.float64)\n", - " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n", - " both = numpy.empty((2,) + M.shape)\n", - " both[0, :, :] = numpy.real(M)\n", - " both[1, :, :] = numpy.imag(M)\n", - " return both\n", - "\n", - "\n", - "def dft_real(x, fft_length=None, transpose=True):\n", - " if len(x.shape) == 1:\n", - " x = x.reshape((1, -1))\n", - " N = 1\n", - " else:\n", - " N = x.shape[0] \n", - " C = x.shape[-1] if transpose else x.shape[-2]\n", - " if fft_length is None:\n", - " fft_length = x.shape[-1]\n", - " size = fft_length // 2 + 1\n", - "\n", - " cst = dft_real_cst(C, fft_length)\n", - " if transpose:\n", - " x = numpy.transpose(x, (1, 0))\n", - " res = numpy.dot(cst[:, :, :fft_length], x[:fft_length])[:, :size, :]\n", - " return numpy.transpose(res, (0, 2, 1))\n", - " else:\n", - " return numpy.dot(cst[:, :, :fft_length], x[:fft_length])\n", - "\n", - "\n", - "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n", - "fft_np = numpy.fft.rfft(rnd)\n", - "fft_cus = dft_real(rnd)\n", - "fft_np" - ] - }, - { - "cell_type": "markdown", - "id": "0c052ea1", - "metadata": {}, - "source": [ - "Function `almost_equal` verifies both functions return the same results." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "3ca040cb", - "metadata": {}, - "outputs": [], - "source": [ - "almost_equal(fft_np, fft_cus)" - ] - }, - { - "cell_type": "markdown", - "id": "7fe77440", - "metadata": {}, - "source": [ - "Let's do the same with `fft_length < shape[1]`." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "3a747a4a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[ 0.58212829+0.j , 1.91211772-1.78320393j],\n", - " [-0.3185378 +0.j , -0.20609781-1.18129868j],\n", - " [-0.81120646+0.j , -0.28543806+3.05769342j],\n", - " [-1.06384408+0.j , 0.74100591+0.43276681j],\n", - " [ 1.77509081+0.j , -0.13498855+1.82011058j]])" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fft_np3 = numpy.fft.rfft(rnd, n=3)\n", - "fft_cus3 = dft_real(rnd, fft_length=3)\n", - "fft_np3" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "0db6247b", - "metadata": {}, - "outputs": [], - "source": [ - "almost_equal(fft_np3, fft_cus3)" - ] - }, - { - "cell_type": "markdown", - "id": "31a6ac9c", - "metadata": {}, - "source": [ - "## RFFT in ONNX\n", - "\n", - "Let's assume first the number of column of the input matrix is fixed. The result of function `dft_real_cst` can be considered as constant." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "efb67b9b", - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[[ 0.92935216, 1.1166406 , 2.9888134 , 0.5706275 ],\n", - " [-0.81071043, 4.045719 , -0.7531659 , -3.7397203 ],\n", - " [ 0.4989317 , -2.3885374 , -2.323094 , 2.849736 ],\n", - " [-0.85518885, -1.0745792 , 0.6752271 , 1.3148088 ],\n", - " [ 2.8086765 , -2.794534 , 0.49298596, 2.6531756 ]],\n", - "\n", - " [[ 0. , 0.18610872, -0.8613782 , -3.1707506 ],\n", - " [ 0. , 1.3441529 , 1.8737512 , 1.1996344 ],\n", - " [ 0. , 0.9178499 , 2.4246747 , 0.96874106],\n", - " [ 0. , 2.1461806 , -2.1732073 , 2.2782433 ],\n", - " [ 0. , -2.2290184 , 0.10661539, 0.5765133 ]]],\n", - " dtype=float32)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from typing import Any\n", - "import mlprodict.npy.numpy_onnx_impl as npnx\n", - "from mlprodict.npy import onnxnumpy_np\n", - "from mlprodict.npy.onnx_numpy_annotation import NDArrayType\n", - "# from mlprodict.onnxrt import OnnxInference\n", - "\n", - "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", - "def onnx_rfft(x, fft_length=None):\n", - " if fft_length is None:\n", - " raise RuntimeError(\"fft_length must be specified.\")\n", - " \n", - " size = fft_length // 2 + 1\n", - " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", - " xt = npnx.transpose(x, (1, 0))\n", - " res = npnx.dot(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n", - " return npnx.transpose(res, (0, 2, 1))\n", - "\n", - "fft_onx = onnx_rfft(rnd, fft_length=rnd.shape[1])\n", - "fft_onx" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "c4b6b1a5", - "metadata": {}, - "outputs": [], - "source": [ - "almost_equal(fft_cus, fft_onx)" - ] - }, - { - "cell_type": "markdown", - "id": "a8c35327", - "metadata": {}, - "source": [ - "The corresponding ONNX graph is the following:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "4d1a85b0", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "key = list(onnx_rfft.signed_compiled)[0]\n", - "%onnxview onnx_rfft.signed_compiled[key].compiled.onnx_" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "6cf18aca", - "metadata": {}, - "outputs": [], - "source": [ - "fft_onx3 = onnx_rfft(rnd, fft_length=3)\n", - "almost_equal(fft_cus3, fft_onx3)" - ] - }, - { - "cell_type": "markdown", - "id": "6b466fd4", - "metadata": {}, - "source": [ - "## FFT 2D\n", - "\n", - "Below the code for complex features." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "e0020084", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[-2.036582 +0.j , -0.85992725+6.47780438j,\n", - " -3.99332006-3.11192536j, -1.32368431-2.48821071j],\n", - " [ 4.37345155-7.03173815j, -3.14890126+1.59632335j,\n", - " -3.75306979+0.66651699j, -1.56716114+4.75028368j],\n", - " [ 2.76767016+1.25297955j, 5.07926144-2.23393831j,\n", - " 2.41908275-8.55451105j, -8.84556476-1.29356088j],\n", - " [ 2.76767016-1.25297955j, 2.41782872-4.44962381j,\n", - " -3.6501426 +4.13120322j, 4.30875103+0.96179243j],\n", - " [ 4.37345155+7.03173815j, 1.77135529+4.4385736j ,\n", - " 2.40878105+5.40109054j, -1.65462983+0.2149866j ]])" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def _DFT_cst(N, fft_length, trunc=True):\n", - " n = numpy.arange(N)\n", - " k = n.reshape((N, 1)).astype(numpy.float64)\n", - " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n", - " return M[:fft_length // 2 + 1] if trunc else M\n", - "\n", - "def DFT(x, fft_length=None, axis=1):\n", - " if axis == 1:\n", - " x = x.T\n", - " if fft_length is None:\n", - " fft_length = x.shape[0]\n", - " cst = _DFT_cst(x.shape[0], fft_length, trunc=axis==1)\n", - " if axis == 1:\n", - " return numpy.dot(cst, x).T\n", - " return numpy.dot(cst, x)\n", - "\n", - "def fft2d_(mat, fft_length):\n", - " mat = mat[:fft_length[0], :fft_length[1]]\n", - " res = mat.copy()\n", - " res = DFT(res, fft_length[1], axis=1)\n", - " res = DFT(res, fft_length[0], axis=0)\n", - " return res[:fft_length[0], :fft_length[1]//2 + 1]\n", - "\n", - "\n", - "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n", - "fft2d_np_ = fft2d_(rnd, rnd.shape)\n", - "fft2d_np = numpy.fft.rfft2(rnd)\n", - "fft2d_np_" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "777d2775", - "metadata": {}, - "outputs": [], - "source": [ - "almost_equal(fft2d_np_, fft2d_np)" - ] - }, - { - "cell_type": "markdown", - "id": "cfbbe2fd", - "metadata": {}, - "source": [ - "It implies the computation of two FFT 1D along both axes. However, as ONNX does not support complex, it needs to be rewritten with only real numbers. The algorithm can be summarized into this formula $FFT(FFT(x, axis=1), axis=0)$. If *x* is real, $FFT(x, .)$ is complex. We still assume *x* is real, it then becomes (FFT is a linear operator, so $FFT(ix)=i FFT(x)$):\n", - "\n", - "* $y = FFT(x, axis=1)$\n", - "* $z_r = FFT(Real(y), axis=0)$, $z_i = FFT(Imag(y), axis=0)$\n", - "* $z = z_r + i z_i$\n", - "\n", - "*z* is the desired output. The following implementation is probably not the most efficient one. It avoids inplace computation as ONNX does like that." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "dd4fc711", - "metadata": {}, - "outputs": [], - "source": [ - "def fft2d(mat, fft_length):\n", - " mat = mat[:fft_length[0], :fft_length[1]]\n", - " res = mat.copy()\n", - " \n", - " # first FFT\n", - " res = dft_real(res, fft_length=fft_length[1], transpose=True)\n", - " \n", - " # second FFT decomposed on FFT on real part and imaginary part\n", - " res2_real = dft_real(res[0], fft_length=fft_length[0], transpose=False)\n", - " res2_imag = dft_real(res[1], fft_length=fft_length[0], transpose=False) \n", - " res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n", - " res = res2_real + res2_imag2\n", - " size = fft_length[1]//2 + 1\n", - " return res[:, :fft_length[0], :size]\n", - "\n", - "\n", - "fft2d_np = numpy.fft.rfft2(rnd)\n", - "fft2d_cus = fft2d(rnd, rnd.shape)\n", - "almost_equal(fft2d_np, fft2d_cus)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "bb8667e6", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[-2.036582 +0.j , -0.85992725+6.47780438j,\n", - " -3.99332006-3.11192536j, -1.32368431-2.48821071j],\n", - " [ 4.37345155-7.03173815j, -3.14890126+1.59632335j,\n", - " -3.75306979+0.66651699j, -1.56716114+4.75028368j],\n", - " [ 2.76767016+1.25297955j, 5.07926144-2.23393831j,\n", - " 2.41908275-8.55451105j, -8.84556476-1.29356088j],\n", - " [ 2.76767016-1.25297955j, 2.41782872-4.44962381j,\n", - " -3.6501426 +4.13120322j, 4.30875103+0.96179243j],\n", - " [ 4.37345155+7.03173815j, 1.77135529+4.4385736j ,\n", - " 2.40878105+5.40109054j, -1.65462983+0.2149866j ]])" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fft2d_np" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "56a94d97", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[[-2.036582 , -0.85992725, -3.99332006, -1.32368431],\n", - " [ 4.37345155, -3.14890126, -3.75306979, -1.56716114],\n", - " [ 2.76767016, 5.07926144, 2.41908275, -8.84556476],\n", - " [ 2.76767016, 2.41782872, -3.6501426 , 4.30875103],\n", - " [ 4.37345155, 1.77135529, 2.40878105, -1.65462983]],\n", - "\n", - " [[ 0. , 6.47780438, -3.11192536, -2.48821071],\n", - " [-7.03173815, 1.59632335, 0.66651699, 4.75028368],\n", - " [ 1.25297955, -2.23393831, -8.55451105, -1.29356088],\n", - " [-1.25297955, -4.44962381, 4.13120322, 0.96179243],\n", - " [ 7.03173815, 4.4385736 , 5.40109054, 0.2149866 ]]])" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fft2d_cus" - ] - }, - { - "cell_type": "markdown", - "id": "faa21909", - "metadata": {}, - "source": [ - "And with a different `fft_length`." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "bf98995f", - "metadata": {}, - "outputs": [], - "source": [ - "fft2d_np = numpy.fft.rfft2(rnd, (4, 6))\n", - "fft2d_cus = fft2d(rnd, (4, 6))\n", - "almost_equal(fft2d_np[:4, :], fft2d_cus)" - ] - }, - { - "cell_type": "markdown", - "id": "caee1f84", - "metadata": {}, - "source": [ - "## FFT 2D in ONNX\n", - "\n", - "We use again the numpy API for ONNX." - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "ca641274", - "metadata": {}, - "outputs": [], - "source": [ - "def onnx_rfft_1d(x, fft_length=None, transpose=True):\n", - " if fft_length is None:\n", - " raise RuntimeError(\"fft_length must be specified.\")\n", - " \n", - " size = fft_length // 2 + 1\n", - " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", - " if transpose:\n", - " xt = npnx.transpose(x, (1, 0))\n", - " res = npnx.dot(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n", - " return npnx.transpose(res, (0, 2, 1))\n", - " else:\n", - " return npnx.dot(cst[:, :, :fft_length], x[:fft_length])\n", - "\n", - "\n", - "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", - "def onnx_rfft_2d(x, fft_length=None):\n", - " mat = x[:fft_length[0], :fft_length[1]]\n", - " \n", - " # first FFT\n", - " res = onnx_rfft_1d(mat, fft_length=fft_length[1], transpose=True)\n", - " \n", - " # second FFT decomposed on FFT on real part and imaginary part\n", - " res2_real = onnx_rfft_1d(res[0], fft_length=fft_length[0], transpose=False)\n", - " res2_imag = onnx_rfft_1d(res[1], fft_length=fft_length[0], transpose=False) \n", - " res2_imag2 = npnx.vstack(-res2_imag[1:2], res2_imag[:1])\n", - " res = res2_real + res2_imag2\n", - " size = fft_length[1]//2 + 1\n", - " return res[:, :fft_length[0], :size]\n", - "\n", - "\n", - "fft2d_cus = fft2d(rnd, rnd.shape)\n", - "fft2d_onx = onnx_rfft_2d(rnd, fft_length=rnd.shape)\n", - "almost_equal(fft2d_cus, fft2d_onx)" - ] - }, - { - "cell_type": "markdown", - "id": "20fcd8a9", - "metadata": {}, - "source": [ - "The corresponding ONNX graph." - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "b1379b06", - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "key = list(onnx_rfft_2d.signed_compiled)[0]\n", - "%onnxview onnx_rfft_2d.signed_compiled[key].compiled.onnx_" - ] - }, - { - "cell_type": "markdown", - "id": "3a747f0c", - "metadata": {}, - "source": [ - "With a different `fft_length`." - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "16732cbb", - "metadata": {}, - "outputs": [], - "source": [ - "fft2d_cus = fft2d(rnd, (4, 5))\n", - "fft2d_onx = onnx_rfft_2d(rnd, fft_length=(4, 5))\n", - "almost_equal(fft2d_cus, fft2d_onx)" - ] - }, - { - "cell_type": "markdown", - "id": "04924e7d", - "metadata": {}, - "source": [ - "This implementation of FFT in ONNX assumes shapes and fft lengths are constant. Otherwise, the matrix returned by function `dft_real_cst` must be converted as well. That's left as an exercise." - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "faeff9cd", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.5" - } - }, - "nbformat": 4, - "nbformat_minor": 5 +{ + "cells": [ + { + "cell_type": "markdown", + "id": "51bc89fc", + "metadata": {}, + "source": [ + "# ONNX and FFT\n", + "\n", + "ONNX does not fully support complex yet. It does not have any FFT operators either. What if we need them anyway?" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7b2add97", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
run previous cell, wait for 2 seconds
\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from jyquickhelper import add_notebook_menu\n", + "add_notebook_menu()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "acfdc3b0", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext mlprodict" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "abb5fa88", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'1.21.1'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy\n", + "numpy.__version__" + ] + }, + { + "cell_type": "markdown", + "id": "2e4f68e4", + "metadata": {}, + "source": [ + "## Python implementation of RFFT\n", + "\n", + "We try to replicate [numpy.rfft](https://numpy.org/doc/stable/reference/generated/numpy.fft.rfft.html)." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "cb1cc910", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 1.4850522 +0.j , 0.20190521-0.68916309j,\n", + " -4.02216577+2.22221485j, 2.9797031 +5.4719969j ],\n", + " [ 4.34131445+0.j , 0.59711262-2.9628275j ,\n", + " 2.24324474+1.82517205j, 2.00603187-2.84835823j],\n", + " [-1.12793672+0.j , 1.38239012-2.21121807j,\n", + " -2.35260183+1.48559871j, -1.50040395+1.3950099j ],\n", + " [-2.28634521+0.j , 0.18810962-0.47532163j,\n", + " 2.02343296-0.92789895j, -0.21166875-3.03973764j],\n", + " [-0.48667854+0.j , 1.17243003-1.72963149j,\n", + " -1.73139954+1.88922393j, -0.44110992-0.81126496j]])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy\n", + "\n", + "\n", + "def almost_equal(a, b, error=1e-5):\n", + " \"\"\"\n", + " The function compares two matrices, one may be complex. In that case,\n", + " this matrix is changed into a new matrix with a new first dimension,\n", + " [0,::] means real part, [1,::] means imaginary part.\n", + " \"\"\"\n", + " if a.dtype in (numpy.complex64, numpy.complex128):\n", + " dtype = numpy.float64 if a.dtype == numpy.complex128 else numpy.float32\n", + " new_a = numpy.empty((2,) + a.shape).astype(dtype)\n", + " new_a[0] = numpy.real(a)\n", + " new_a[1] = numpy.imag(a)\n", + " return almost_equal(new_a, b, error)\n", + " if b.dtype in (numpy.complex64, numpy.complex128):\n", + " return almost_equal(b, a, error)\n", + " if a.shape != b.shape:\n", + " raise AssertionError(\"Shape mismatch %r != %r.\" % (a.shape, b.shape))\n", + " diff = numpy.abs(a.ravel() - b.ravel()).max()\n", + " if diff > error:\n", + " raise AssertionError(\"Mismatch max diff=%r > %r.\" % (diff, error))\n", + "\n", + "\n", + "def dft_real_cst(N, fft_length):\n", + " n = numpy.arange(N)\n", + " k = n.reshape((N, 1)).astype(numpy.float64)\n", + " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n", + " both = numpy.empty((2,) + M.shape)\n", + " both[0, :, :] = numpy.real(M)\n", + " both[1, :, :] = numpy.imag(M)\n", + " return both\n", + "\n", + "\n", + "def dft_real(x, fft_length=None, transpose=True):\n", + " if len(x.shape) == 1:\n", + " x = x.reshape((1, -1))\n", + " N = 1\n", + " else:\n", + " N = x.shape[0] \n", + " C = x.shape[-1] if transpose else x.shape[-2]\n", + " if fft_length is None:\n", + " fft_length = x.shape[-1]\n", + " size = fft_length // 2 + 1\n", + "\n", + " cst = dft_real_cst(C, fft_length)\n", + " if transpose:\n", + " x = numpy.transpose(x, (1, 0))\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:fft_length]\n", + " res = numpy.matmul(a, b)\n", + " res = res[:, :size, :]\n", + " return numpy.transpose(res, (0, 2, 1))\n", + " else:\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:fft_length]\n", + " return numpy.matmul(a, b)\n", + "\n", + "\n", + "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n", + "fft_np = numpy.fft.rfft(rnd)\n", + "fft_cus = dft_real(rnd)\n", + "fft_np" + ] + }, + { + "cell_type": "markdown", + "id": "0c052ea1", + "metadata": {}, + "source": [ + "Function `almost_equal` verifies both functions return the same results." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3ca040cb", + "metadata": {}, + "outputs": [], + "source": [ + "almost_equal(fft_np, fft_cus)" + ] + }, + { + "cell_type": "markdown", + "id": "7fe77440", + "metadata": {}, + "source": [ + "Let's do the same with `fft_length < shape[1]`." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "3a747a4a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 2.02068053+0.j , -1.0523537 +4.21051588j],\n", + " [ 3.35636759+0.j , 1.32912183+0.17609187j],\n", + " [ 0.86212105+0.j , -1.73159653+0.58274578j],\n", + " [-0.93982866+0.j , 0.837072 -1.67403309j],\n", + " [ 0.72620976+0.j , -0.89599861+0.37600383j]])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fft_np3 = numpy.fft.rfft(rnd, n=3)\n", + "fft_cus3 = dft_real(rnd, fft_length=3)\n", + "fft_np3" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0db6247b", + "metadata": {}, + "outputs": [], + "source": [ + "almost_equal(fft_np3, fft_cus3)" + ] + }, + { + "cell_type": "markdown", + "id": "31a6ac9c", + "metadata": {}, + "source": [ + "## RFFT in ONNX\n", + "\n", + "Let's assume first the number of column of the input matrix is fixed. The result of function `dft_real_cst` can be considered as constant." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "efb67b9b", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[ 1.4850521 , 0.2019052 , -4.0221653 , 2.979703 ],\n", + " [ 4.3413143 , 0.5971126 , 2.2432446 , 2.006032 ],\n", + " [-1.1279367 , 1.38239 , -2.3526018 , -1.500404 ],\n", + " [-2.2863452 , 0.18810958, 2.0234327 , -0.21166855],\n", + " [-0.48667857, 1.17243 , -1.7313995 , -0.44110993]],\n", + "\n", + " [[ 0. , -0.689163 , 2.2222147 , 5.4719973 ],\n", + " [ 0. , -2.9628277 , 1.8251722 , -2.8483584 ],\n", + " [ 0. , -2.211218 , 1.4855987 , 1.39501 ],\n", + " [ 0. , -0.47532162, -0.9278989 , -3.0397377 ],\n", + " [ 0. , -1.7296315 , 1.8892239 , -0.811265 ]]],\n", + " dtype=float32)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from typing import Any\n", + "import mlprodict.npy.numpy_onnx_impl as npnx\n", + "from mlprodict.npy import onnxnumpy_np\n", + "from mlprodict.npy.onnx_numpy_annotation import NDArrayType\n", + "# from mlprodict.onnxrt import OnnxInference\n", + "\n", + "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", + "def onnx_rfft(x, fft_length=None):\n", + " if fft_length is None:\n", + " raise RuntimeError(\"fft_length must be specified.\")\n", + " \n", + " size = fft_length // 2 + 1\n", + " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", + " xt = npnx.transpose(x, (1, 0))\n", + " res = npnx.matmul(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n", + " return npnx.transpose(res, (0, 2, 1))\n", + "\n", + "fft_onx = onnx_rfft(rnd, fft_length=rnd.shape[1])\n", + "fft_onx" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c4b6b1a5", + "metadata": {}, + "outputs": [], + "source": [ + "almost_equal(fft_cus, fft_onx)" + ] + }, + { + "cell_type": "markdown", + "id": "a8c35327", + "metadata": {}, + "source": [ + "The corresponding ONNX graph is the following:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "4d1a85b0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "key = list(onnx_rfft.signed_compiled)[0]\n", + "%onnxview onnx_rfft.signed_compiled[key].compiled.onnx_" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "6cf18aca", + "metadata": {}, + "outputs": [], + "source": [ + "fft_onx3 = onnx_rfft(rnd, fft_length=3)\n", + "almost_equal(fft_cus3, fft_onx3)" + ] + }, + { + "cell_type": "markdown", + "id": "6b466fd4", + "metadata": {}, + "source": [ + "## FFT 2D\n", + "\n", + "Below the code for complex features." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "e0020084", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 6.73214482 +0.j , 3.36817961 +2.73163668j,\n", + " 4.42688318-10.23641065j, -3.27162717 -0.54910943j],\n", + " [ 0.05934113 -2.02790296j, -4.71694176 -1.80039444j,\n", + " -0.16187544 -1.27214887j, -4.76195404 -8.23146595j],\n", + " [-2.43886644 +0.67253454j, 0.62177822 +1.71628605j,\n", + " -4.22144547 -0.24384973j, -1.96253444 -2.26942153j],\n", + " [-2.43886644 -0.67253454j, -4.94210355 +1.65439295j,\n", + " -6.75624015 -2.50966739j, -1.62599543 +7.41506091j],\n", + " [ 0.05934113 +2.02790296j, 1.56068457 -4.5734695j ,\n", + " -2.9809962 +2.90470743j, 4.42498542-10.45411745j]])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def _DFT_cst(N, fft_length, trunc=True):\n", + " n = numpy.arange(N)\n", + " k = n.reshape((N, 1)).astype(numpy.float64)\n", + " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n", + " return M[:fft_length // 2 + 1] if trunc else M\n", + "\n", + "def DFT(x, fft_length=None, axis=1):\n", + " if axis == 1:\n", + " x = x.T\n", + " if fft_length is None:\n", + " fft_length = x.shape[0]\n", + " cst = _DFT_cst(x.shape[0], fft_length, trunc=axis==1)\n", + " if axis == 1:\n", + " return numpy.matmul(cst, x).T\n", + " return numpy.matmul(cst, x)\n", + "\n", + "def fft2d_(mat, fft_length):\n", + " mat = mat[:fft_length[0], :fft_length[1]]\n", + " res = mat.copy()\n", + " res = DFT(res, fft_length[1], axis=1)\n", + " res = DFT(res, fft_length[0], axis=0)\n", + " return res[:fft_length[0], :fft_length[1]//2 + 1]\n", + "\n", + "\n", + "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n", + "fft2d_np_ = fft2d_(rnd, rnd.shape)\n", + "fft2d_np = numpy.fft.rfft2(rnd)\n", + "fft2d_np_" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "777d2775", + "metadata": {}, + "outputs": [], + "source": [ + "almost_equal(fft2d_np_, fft2d_np)" + ] + }, + { + "cell_type": "markdown", + "id": "cfbbe2fd", + "metadata": {}, + "source": [ + "It implies the computation of two FFT 1D along both axes. However, as ONNX does not support complex, it needs to be rewritten with only real numbers. The algorithm can be summarized into this formula $FFT(FFT(x, axis=1), axis=0)$. If *x* is real, $FFT(x, .)$ is complex. We still assume *x* is real, it then becomes (FFT is a linear operator, so $FFT(ix)=i FFT(x)$):\n", + "\n", + "* $y = FFT(x, axis=1)$\n", + "* $z_r = FFT(Real(y), axis=0)$, $z_i = FFT(Imag(y), axis=0)$\n", + "* $z = z_r + i z_i$\n", + "\n", + "*z* is the desired output. The following implementation is probably not the most efficient one. It avoids inplace computation as ONNX does like that." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "dd4fc711", + "metadata": {}, + "outputs": [], + "source": [ + "def fft2d(mat, fft_length):\n", + " mat = mat[:fft_length[0], :fft_length[1]]\n", + " res = mat.copy()\n", + " \n", + " # first FFT\n", + " res = dft_real(res, fft_length=fft_length[1], transpose=True)\n", + " \n", + " # second FFT decomposed on FFT on real part and imaginary part\n", + " res2_real = dft_real(res[0], fft_length=fft_length[0], transpose=False)\n", + " res2_imag = dft_real(res[1], fft_length=fft_length[0], transpose=False) \n", + " res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n", + " res = res2_real + res2_imag2\n", + " size = fft_length[1]//2 + 1\n", + " return res[:, :fft_length[0], :size]\n", + "\n", + "\n", + "fft2d_np = numpy.fft.rfft2(rnd)\n", + "fft2d_cus = fft2d(rnd, rnd.shape)\n", + "almost_equal(fft2d_np, fft2d_cus)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "bb8667e6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 6.73214482 +0.j , 3.36817961 +2.73163668j,\n", + " 4.42688318-10.23641065j, -3.27162717 -0.54910943j],\n", + " [ 0.05934113 -2.02790296j, -4.71694176 -1.80039444j,\n", + " -0.16187544 -1.27214887j, -4.76195404 -8.23146595j],\n", + " [-2.43886644 +0.67253454j, 0.62177822 +1.71628605j,\n", + " -4.22144547 -0.24384973j, -1.96253444 -2.26942153j],\n", + " [-2.43886644 -0.67253454j, -4.94210355 +1.65439295j,\n", + " -6.75624015 -2.50966739j, -1.62599543 +7.41506091j],\n", + " [ 0.05934113 +2.02790296j, 1.56068457 -4.5734695j ,\n", + " -2.9809962 +2.90470743j, 4.42498542-10.45411745j]])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fft2d_np" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "56a94d97", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[ 6.73214482, 3.36817961, 4.42688318, -3.27162717],\n", + " [ 0.05934113, -4.71694176, -0.16187544, -4.76195404],\n", + " [ -2.43886644, 0.62177822, -4.22144547, -1.96253444],\n", + " [ -2.43886644, -4.94210355, -6.75624015, -1.62599543],\n", + " [ 0.05934113, 1.56068457, -2.9809962 , 4.42498542]],\n", + "\n", + " [[ 0. , 2.73163668, -10.23641065, -0.54910943],\n", + " [ -2.02790296, -1.80039444, -1.27214887, -8.23146595],\n", + " [ 0.67253454, 1.71628605, -0.24384973, -2.26942153],\n", + " [ -0.67253454, 1.65439295, -2.50966739, 7.41506091],\n", + " [ 2.02790296, -4.5734695 , 2.90470743, -10.45411745]]])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fft2d_cus" + ] + }, + { + "cell_type": "markdown", + "id": "faa21909", + "metadata": {}, + "source": [ + "And with a different `fft_length`." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "bf98995f", + "metadata": {}, + "outputs": [], + "source": [ + "fft2d_np = numpy.fft.rfft2(rnd, (4, 6))\n", + "fft2d_cus = fft2d(rnd, (4, 6))\n", + "almost_equal(fft2d_np[:4, :], fft2d_cus)" + ] + }, + { + "cell_type": "markdown", + "id": "caee1f84", + "metadata": {}, + "source": [ + "## FFT 2D in ONNX\n", + "\n", + "We use again the numpy API for ONNX." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "ca641274", + "metadata": {}, + "outputs": [], + "source": [ + "def onnx_rfft_1d(x, fft_length=None, transpose=True):\n", + " if fft_length is None:\n", + " raise RuntimeError(\"fft_length must be specified.\")\n", + " \n", + " size = fft_length // 2 + 1\n", + " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", + " if transpose:\n", + " xt = npnx.transpose(x, (1, 0))\n", + " res = npnx.matmul(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n", + " return npnx.transpose(res, (0, 2, 1))\n", + " else:\n", + " return npnx.matmul(cst[:, :, :fft_length], x[:fft_length])\n", + "\n", + "\n", + "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", + "def onnx_rfft_2d(x, fft_length=None):\n", + " mat = x[:fft_length[0], :fft_length[1]]\n", + " \n", + " # first FFT\n", + " res = onnx_rfft_1d(mat, fft_length=fft_length[1], transpose=True)\n", + " \n", + " # second FFT decomposed on FFT on real part and imaginary part\n", + " res2_real = onnx_rfft_1d(res[0], fft_length=fft_length[0], transpose=False)\n", + " res2_imag = onnx_rfft_1d(res[1], fft_length=fft_length[0], transpose=False) \n", + " res2_imag2 = npnx.vstack(-res2_imag[1:2], res2_imag[:1])\n", + " res = res2_real + res2_imag2\n", + " size = fft_length[1]//2 + 1\n", + " return res[:, :fft_length[0], :size]\n", + "\n", + "\n", + "fft2d_cus = fft2d(rnd, rnd.shape)\n", + "fft2d_onx = onnx_rfft_2d(rnd, fft_length=rnd.shape)\n", + "almost_equal(fft2d_cus, fft2d_onx)" + ] + }, + { + "cell_type": "markdown", + "id": "20fcd8a9", + "metadata": {}, + "source": [ + "The corresponding ONNX graph." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "b1379b06", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "key = list(onnx_rfft_2d.signed_compiled)[0]\n", + "%onnxview onnx_rfft_2d.signed_compiled[key].compiled.onnx_" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "3034da60", + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"fft2d.onnx\", \"wb\") as f:\n", + " key = list(onnx_rfft_2d.signed_compiled)[0]\n", + " f.write(onnx_rfft_2d.signed_compiled[key].compiled.onnx_.SerializeToString())" + ] + }, + { + "cell_type": "markdown", + "id": "3a747f0c", + "metadata": {}, + "source": [ + "With a different `fft_length`." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "16732cbb", + "metadata": {}, + "outputs": [], + "source": [ + "fft2d_cus = fft2d(rnd, (4, 5))\n", + "fft2d_onx = onnx_rfft_2d(rnd, fft_length=(4, 5))\n", + "almost_equal(fft2d_cus, fft2d_onx)" + ] + }, + { + "cell_type": "markdown", + "id": "04924e7d", + "metadata": {}, + "source": [ + "This implementation of FFT in ONNX assumes shapes and fft lengths are constant. Otherwise, the matrix returned by function `dft_real_cst` must be converted as well. That's left as an exercise." + ] + }, + { + "cell_type": "markdown", + "id": "c9da88a0", + "metadata": {}, + "source": [ + "## FFT2D with shape (3,1,4)\n", + "\n", + "Previous implementation expects the input matrix to have two dimensions. It fails with 3." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "66ba70ee", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 1, 4)" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "shape = (3, 1, 4)\n", + "fft_length = (1, 4)\n", + "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + "fft2d_numpy = numpy.fft.fft2(rnd, fft_length)\n", + "fft2d_numpy.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "a4d123e1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[ 0.87908971+0.j , 0.85659337+2.10206711j,\n", + " 1.5270735 +0.j , 0.85659337-2.10206711j]],\n", + "\n", + " [[-5.01959181+0.j , -0.25658643+0.62102163j,\n", + " 2.18641639+0.j , -0.25658643-0.62102163j]],\n", + "\n", + " [[ 0.60041136+0.j , -0.04546577-1.2931717j ,\n", + " 1.19486004+0.j , -0.04546577+1.2931717j ]]])" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fft2d_numpy" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "4b1bd05b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "axes don't match array\n" + ] + } + ], + "source": [ + "try:\n", + " fft2d_cus = fft2d(rnd, fft_length)\n", + "except Exception as e:\n", + " print(e)\n", + "# fft2d_onx = onnx_rfft_2d(rnd, fft_length=fft_length)" + ] + }, + { + "cell_type": "markdown", + "id": "7bd79a00", + "metadata": {}, + "source": [ + "### numpy version\n", + "\n", + "Let's do it again with numpy first. [fft2](https://numpy.org/doc/stable/reference/generated/numpy.fft.fft2.html) performs `fft2` on the last two axis as many times as the first axis. The goal is still to have an implementation which works for any dimension." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "3b618335", + "metadata": {}, + "outputs": [], + "source": [ + "conc = []\n", + "for i in range(rnd.shape[0]):\n", + " f2 = fft2d(rnd[i], fft_length)\n", + " conc.append(numpy.expand_dims(f2, 0))\n", + "res = numpy.vstack(conc).transpose(1, 0, 2, 3)\n", + "almost_equal(fft2d_numpy[:, :, :3], res)" + ] + }, + { + "cell_type": "markdown", + "id": "7c837e7a", + "metadata": {}, + "source": [ + "It works. And now a more efficient implementation. It is better to read [matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html) description before. To summarize, a third axis is equivalent to many matrix multiplications over the last two axes, as many as the dimension of the first axis: ``matmul(A[I,J,K], B[I,K,L]) --> C[I,J,L]``. Broadcasting also works... ``matmul(A[1,J,K], B[I,K,L]) --> C[I,J,L]``." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "29055cb2", + "metadata": {}, + "outputs": [], + "source": [ + "def dft_real_d3(x, fft_length=None, transpose=True):\n", + " if len(x.shape) != 3:\n", + " raise RuntimeError(\"Not implemented for shape=%r.\" % x.shape)\n", + " N = x.shape[1]\n", + " C = x.shape[-1] if transpose else x.shape[-2]\n", + " if fft_length is None:\n", + " fft_length = x.shape[-1]\n", + " size = fft_length // 2 + 1\n", + "\n", + " cst = dft_real_cst(C, fft_length)\n", + " if transpose:\n", + " x = numpy.transpose(x, (0, 2, 1))\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:, :fft_length, :]\n", + " a = numpy.expand_dims(a, 0)\n", + " b = numpy.expand_dims(b, 1)\n", + " res = numpy.matmul(a, b)\n", + " res = res[:, :, :size, :]\n", + " return numpy.transpose(res, (1, 0, 3, 2))\n", + " else:\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:, :fft_length, :]\n", + " a = numpy.expand_dims(a, 0)\n", + " b = numpy.expand_dims(b, 1)\n", + " res = numpy.matmul(a, b)\n", + " return numpy.transpose(res, (1, 0, 2, 3))\n", + "\n", + "\n", + "def fft2d_d3(mat, fft_length):\n", + " mat = mat[:, :fft_length[-2], :fft_length[-1]]\n", + " res = mat.copy()\n", + " \n", + " # first FFT\n", + " res = dft_real_d3(res, fft_length=fft_length[-1], transpose=True)\n", + " \n", + " # second FFT decomposed on FFT on real part and imaginary part\n", + " res2_real = dft_real_d3(res[0], fft_length=fft_length[-2], transpose=False)\n", + " res2_imag = dft_real_d3(res[1], fft_length=fft_length[-2], transpose=False)\n", + " res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n", + " res = res2_real + res2_imag2\n", + " size = fft_length[-1]//2 + 1\n", + " return res[:, :, :fft_length[-2], :size]\n", + "\n", + "\n", + "def fft2d_any(mat, fft_length):\n", + " new_shape = (-1, ) + mat.shape[-2:]\n", + " mat2 = mat.reshape(new_shape)\n", + " f2 = fft2d_d3(mat2, fft_length)\n", + " new_shape = (2, ) + mat.shape[:-2] + f2.shape[-2:]\n", + " return f2.reshape(new_shape)\n", + "\n", + "\n", + "shape = (3, 1, 4)\n", + "fft_length = (1, 4)\n", + "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + "fft2d_numpy = numpy.fft.fft2(rnd, fft_length)\n", + "fft2d_cus = fft2d_any(rnd, fft_length)\n", + "almost_equal(fft2d_numpy[..., :3], fft2d_cus)" + ] + }, + { + "cell_type": "markdown", + "id": "0128b3f2", + "metadata": {}, + "source": [ + "We check with more shapes to see if the implementation works for all of them." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "82f5fc78", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 2) output shape=(3, 1, 2) or (2, 3, 1, 2)\n", + "OK x.shape=(3, 1, 4) length=(1, 1) output shape=(3, 1, 1) or (2, 3, 1, 1)\n", + "OK x.shape=(5, 7) length=(5, 7) output shape=(5, 7) or (2, 5, 4)\n", + "OK x.shape=(5, 7) length=(1, 7) output shape=(1, 7) or (2, 1, 4)\n", + "OK x.shape=(5, 7) length=(2, 7) output shape=(2, 7) or (2, 2, 4)\n", + "OK x.shape=(5, 7) length=(5, 2) output shape=(5, 2) or (2, 5, 2)\n", + "OK x.shape=(5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n", + "OK x.shape=(3, 5, 7) length=(5, 7) output shape=(3, 5, 7) or (2, 3, 5, 4)\n", + "OK x.shape=(3, 5, 7) length=(1, 7) output shape=(3, 1, 7) or (2, 3, 1, 4)\n", + "OK x.shape=(3, 5, 7) length=(2, 7) output shape=(3, 2, 7) or (2, 3, 2, 4)\n", + "OK x.shape=(3, 5, 7) length=(5, 2) output shape=(3, 5, 2) or (2, 3, 5, 2)\n", + "OK x.shape=(3, 5, 7) length=(3, 4) output shape=(3, 3, 4) or (2, 3, 3, 3)\n", + "OK x.shape=(7, 5) length=(7, 5) output shape=(7, 5) or (2, 7, 3)\n", + "OK x.shape=(7, 5) length=(1, 5) output shape=(1, 5) or (2, 1, 3)\n", + "OK x.shape=(7, 5) length=(2, 5) output shape=(2, 5) or (2, 2, 3)\n", + "OK x.shape=(7, 5) length=(7, 2) output shape=(7, 2) or (2, 7, 2)\n", + "OK x.shape=(7, 5) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n" + ] + } + ], + "source": [ + "for shape in [(3, 1, 4), (5, 7), (3, 5, 7), (7, 5)]:\n", + " for fft_length in [shape[-2:], (1, shape[-1]),\n", + " (min(2, shape[-2]), shape[-1]),\n", + " (shape[-2], 2),\n", + " (min(3, shape[-2]), min(4, shape[-2]))]:\n", + " x = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + " fnp = numpy.fft.fft2(x, fft_length)\n", + " if len(fnp.shape) == 2:\n", + " fn= numpy.expand_dims(fnp, 0)\n", + " try:\n", + " cus = fft2d_any(x, fft_length)\n", + " except IndexError as e:\n", + " print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", + " continue\n", + " try:\n", + " almost_equal(fnp[..., :cus.shape[-1]], cus)\n", + " except (AssertionError, IndexError) as e:\n", + " print(\"DIS x.shape=%r length=%r error=%r output shape=%r or %r\" % (\n", + " x.shape, fft_length, e, fnp.shape, cus.shape))\n", + " continue\n", + " print(\"OK x.shape=%r length=%r output shape=%r or %r\" % (\n", + " x.shape, fft_length, fnp.shape, cus.shape))" + ] + }, + { + "cell_type": "markdown", + "id": "c5f5229a", + "metadata": {}, + "source": [ + "### ONNX version\n", + "\n", + "Let's look into the differences first." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "025c2d88", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext pyquickhelper" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "82664bc5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
populating...
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "/*\n", + "This is part of jsdifflib v1.0. \n", + "\n", + "Copyright 2007 - 2011 Chas Emerick . All rights reserved.\n", + "\n", + "Redistribution and use in source and binary forms, with or without modification, are\n", + "permitted provided that the following conditions are met:\n", + "\n", + " 1. Redistributions of source code must retain the above copyright notice, this list of\n", + " conditions and the following disclaimer.\n", + "\n", + " 2. Redistributions in binary form must reproduce the above copyright notice, this list\n", + " of conditions and the following disclaimer in the documentation and/or other materials\n", + " provided with the distribution.\n", + "\n", + "THIS SOFTWARE IS PROVIDED BY Chas Emerick ``AS IS'' AND ANY EXPRESS OR IMPLIED\n", + "WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND\n", + "FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL Chas Emerick OR\n", + "CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n", + "CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n", + "SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON\n", + "ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING\n", + "NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF\n", + "ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n", + "\n", + "The views and conclusions contained in the software and documentation are those of the\n", + "authors and should not be interpreted as representing official policies, either expressed\n", + "or implied, of Chas Emerick.\n", + "*/\n", + "var diffview = {\n", + "\t/**\n", + "\t * Builds and returns a visual diff view. The single parameter, `params', should contain\n", + "\t * the following values:\n", + "\t *\n", + "\t * - baseTextLines: the array of strings that was used as the base text input to SequenceMatcher\n", + "\t * - newTextLines: the array of strings that was used as the new text input to SequenceMatcher\n", + "\t * - opcodes: the array of arrays returned by SequenceMatcher.get_opcodes()\n", + "\t * - baseTextName: the title to be displayed above the base text listing in the diff view; defaults\n", + "\t *\t to \"Base Text\"\n", + "\t * - newTextName: the title to be displayed above the new text listing in the diff view; defaults\n", + "\t *\t to \"New Text\"\n", + "\t * - contextSize: the number of lines of context to show around differences; by default, all lines\n", + "\t *\t are shown\n", + "\t * - viewType: if 0, a side-by-side diff view is generated (default); if 1, an inline diff view is\n", + "\t *\t generated\n", + "\t */\n", + "\tbuildView: function (params) {\n", + "\t\tvar baseTextLines = params.baseTextLines;\n", + "\t\tvar newTextLines = params.newTextLines;\n", + "\t\tvar opcodes = params.opcodes;\n", + "\t\tvar baseTextName = params.baseTextName ? params.baseTextName : \"Base Text\";\n", + "\t\tvar newTextName = params.newTextName ? params.newTextName : \"New Text\";\n", + "\t\tvar contextSize = params.contextSize;\n", + "\t\tvar inline = (params.viewType == 0 || params.viewType == 1) ? params.viewType : 0;\n", + "\n", + "\t\tif (baseTextLines == null)\n", + "\t\t\tthrow \"Cannot build diff view; baseTextLines is not defined.\";\n", + "\t\tif (newTextLines == null)\n", + "\t\t\tthrow \"Cannot build diff view; newTextLines is not defined.\";\n", + "\t\tif (!opcodes)\n", + "\t\t\tthrow \"Cannot build diff view; opcodes is not defined.\";\n", + "\t\t\n", + "\t\tfunction celt (name, clazz) {\n", + "\t\t\tvar e = document.createElement(name);\n", + "\t\t\te.className = clazz;\n", + "\t\t\treturn e;\n", + "\t\t}\n", + "\t\t\n", + "\t\tfunction telt (name, text) {\n", + "\t\t\tvar e = document.createElement(name);\n", + "\t\t\te.appendChild(document.createTextNode(text));\n", + "\t\t\treturn e;\n", + "\t\t}\n", + "\t\t\n", + "\t\tfunction ctelt (name, clazz, text) {\n", + "\t\t\tvar e = document.createElement(name);\n", + "\t\t\te.className = clazz;\n", + "\t\t\te.appendChild(document.createTextNode(text));\n", + "\t\t\treturn e;\n", + "\t\t}\n", + "\t\n", + "\t\tvar tdata = document.createElement(\"thead\");\n", + "\t\tvar node = document.createElement(\"tr\");\n", + "\t\ttdata.appendChild(node);\n", + "\t\tif (inline) {\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", baseTextName + \" vs. \" + newTextName));\n", + "\t\t} else {\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", baseTextName));\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", newTextName));\n", + "\t\t}\n", + "\t\ttdata = [tdata];\n", + "\t\t\n", + "\t\tvar rows = [];\n", + "\t\tvar node2;\n", + "\t\t\n", + "\t\t/**\n", + "\t\t * Adds two cells to the given row; if the given row corresponds to a real\n", + "\t\t * line number (based on the line index tidx and the endpoint of the \n", + "\t\t * range in question tend), then the cells will contain the line number\n", + "\t\t * and the line of text from textLines at position tidx (with the class of\n", + "\t\t * the second cell set to the name of the change represented), and tidx + 1 will\n", + "\t\t * be returned.\t Otherwise, tidx is returned, and two empty cells are added\n", + "\t\t * to the given row.\n", + "\t\t */\n", + "\t\tfunction addCells (row, tidx, tend, textLines, change) {\n", + "\t\t\tif (tidx < tend) {\n", + "\t\t\t\trow.appendChild(telt(\"th\", (tidx + 1).toString()));\n", + "\t\t\t\trow.appendChild(ctelt(\"td\", change, textLines[tidx].replace(/\\t/g, \"\\u00a0\\u00a0\\u00a0\\u00a0\")));\n", + "\t\t\t\treturn tidx + 1;\n", + "\t\t\t} else {\n", + "\t\t\t\trow.appendChild(document.createElement(\"th\"));\n", + "\t\t\t\trow.appendChild(celt(\"td\", \"empty\"));\n", + "\t\t\t\treturn tidx;\n", + "\t\t\t}\n", + "\t\t}\n", + "\t\t\n", + "\t\tfunction addCellsInline (row, tidx, tidx2, textLines, change) {\n", + "\t\t\trow.appendChild(telt(\"th\", tidx == null ? \"\" : (tidx + 1).toString()));\n", + "\t\t\trow.appendChild(telt(\"th\", tidx2 == null ? \"\" : (tidx2 + 1).toString()));\n", + "\t\t\trow.appendChild(ctelt(\"td\", change, textLines[tidx != null ? tidx : tidx2].replace(/\\t/g, \"\\u00a0\\u00a0\\u00a0\\u00a0\")));\n", + "\t\t}\n", + "\t\t\n", + "\t\tfor (var idx = 0; idx < opcodes.length; idx++) {\n", + "\t\t\tvar code = opcodes[idx];\n", + "\t\t\tvar change = code[0];\n", + "\t\t\tvar b = code[1];\n", + "\t\t\tvar be = code[2];\n", + "\t\t\tvar n = code[3];\n", + "\t\t\tvar ne = code[4];\n", + "\t\t\tvar rowcnt = Math.max(be - b, ne - n);\n", + "\t\t\tvar toprows = [];\n", + "\t\t\tvar botrows = [];\n", + "\t\t\tfor (var i = 0; i < rowcnt; i++) {\n", + "\t\t\t\t// jump ahead if we've alredy provided leading context or if this is the first range\n", + "\t\t\t\tif (contextSize && opcodes.length > 1 && ((idx > 0 && i == contextSize) || (idx == 0 && i == 0)) && change==\"equal\") {\n", + "\t\t\t\t\tvar jump = rowcnt - ((idx == 0 ? 1 : 2) * contextSize);\n", + "\t\t\t\t\tif (jump > 1) {\n", + "\t\t\t\t\t\ttoprows.push(node = document.createElement(\"tr\"));\n", + "\t\t\t\t\t\t\n", + "\t\t\t\t\t\tb += jump;\n", + "\t\t\t\t\t\tn += jump;\n", + "\t\t\t\t\t\ti += jump - 1;\n", + "\t\t\t\t\t\tnode.appendChild(telt(\"th\", \"...\"));\n", + "\t\t\t\t\t\tif (!inline) node.appendChild(ctelt(\"td\", \"skip\", \"\"));\n", + "\t\t\t\t\t\tnode.appendChild(telt(\"th\", \"...\"));\n", + "\t\t\t\t\t\tnode.appendChild(ctelt(\"td\", \"skip\", \"\"));\n", + "\t\t\t\t\t\t\n", + "\t\t\t\t\t\t// skip last lines if they're all equal\n", + "\t\t\t\t\t\tif (idx + 1 == opcodes.length) {\n", + "\t\t\t\t\t\t\tbreak;\n", + "\t\t\t\t\t\t} else {\n", + "\t\t\t\t\t\t\tcontinue;\n", + "\t\t\t\t\t\t}\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\t\n", + "\t\t\t\ttoprows.push(node = document.createElement(\"tr\"));\n", + "\t\t\t\tif (inline) {\n", + "\t\t\t\t\tif (change == \"insert\") {\n", + "\t\t\t\t\t\taddCellsInline(node, null, n++, newTextLines, change);\n", + "\t\t\t\t\t} else if (change == \"replace\") {\n", + "\t\t\t\t\t\tbotrows.push(node2 = document.createElement(\"tr\"));\n", + "\t\t\t\t\t\tif (b < be) addCellsInline(node, b++, null, baseTextLines, \"delete\");\n", + "\t\t\t\t\t\tif (n < ne) addCellsInline(node2, null, n++, newTextLines, \"insert\");\n", + "\t\t\t\t\t} else if (change == \"delete\") {\n", + "\t\t\t\t\t\taddCellsInline(node, b++, null, baseTextLines, change);\n", + "\t\t\t\t\t} else {\n", + "\t\t\t\t\t\t// equal\n", + "\t\t\t\t\t\taddCellsInline(node, b++, n++, baseTextLines, change);\n", + "\t\t\t\t\t}\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tb = addCells(node, b, be, baseTextLines, change);\n", + "\t\t\t\t\tn = addCells(node, n, ne, newTextLines, change);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\n", + "\t\t\tfor (var i = 0; i < toprows.length; i++) rows.push(toprows[i]);\n", + "\t\t\tfor (var i = 0; i < botrows.length; i++) rows.push(botrows[i]);\n", + "\t\t}\n", + "\t\t\n", + "\t\trows.push(node = ctelt(\"th\", \"author\", \"diff view generated by \"));\n", + "\t\tnode.setAttribute(\"colspan\", inline ? 3 : 4);\n", + "\t\tnode.appendChild(node2 = telt(\"a\", \"jsdifflib\"));\n", + "\t\tnode2.setAttribute(\"href\", \"http://github.com/cemerick/jsdifflib\");\n", + "\t\t\n", + "\t\ttdata.push(node = document.createElement(\"tbody\"));\n", + "\t\tfor (var idx in rows) rows.hasOwnProperty(idx) && node.appendChild(rows[idx]);\n", + "\t\t\n", + "\t\tnode = celt(\"table\", \"diff\" + (inline ? \" inlinediff\" : \"\"));\n", + "\t\tfor (var idx in tdata) tdata.hasOwnProperty(idx) && node.appendChild(tdata[idx]);\n", + "\t\treturn node;\n", + "\t}\n", + "};\n", + "\n", + "\n", + "/***\n", + "This is part of jsdifflib v1.0. \n", + "\n", + "Copyright (c) 2007, Snowtide Informatics Systems, Inc.\n", + "All rights reserved.\n", + "\n", + "Redistribution and use in source and binary forms, with or without modification,\n", + "are permitted provided that the following conditions are met:\n", + "\n", + "\t* Redistributions of source code must retain the above copyright notice, this\n", + "\t\tlist of conditions and the following disclaimer.\n", + "\t* Redistributions in binary form must reproduce the above copyright notice,\n", + "\t\tthis list of conditions and the following disclaimer in the documentation\n", + "\t\tand/or other materials provided with the distribution.\n", + "\t* Neither the name of the Snowtide Informatics Systems nor the names of its\n", + "\t\tcontributors may be used to endorse or promote products derived from this\n", + "\t\tsoftware without specific prior written permission.\n", + "\n", + "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY\n", + "EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES\n", + "OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT\n", + "SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,\n", + "INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED\n", + "TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR\n", + "BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n", + "CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN\n", + "ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH\n", + "DAMAGE.\n", + "***/\n", + "/* Author: Chas Emerick */\n", + "var __whitespace = {\" \":true, \"\\t\":true, \"\\n\":true, \"\\f\":true, \"\\r\":true};\n", + "\n", + "var difflib = {\n", + "\tdefaultJunkFunction: function (c) {\n", + "\t\treturn __whitespace.hasOwnProperty(c);\n", + "\t},\n", + "\t\n", + "\tstripLinebreaks: function (str) { return str.replace(/^[\\n\\r]*|[\\n\\r]*$/g, \"\"); },\n", + "\t\n", + "\tstringAsLines: function (str) {\n", + "\t\tvar lfpos = str.indexOf(\"\\n\");\n", + "\t\tvar crpos = str.indexOf(\"\\r\");\n", + "\t\tvar linebreak = ((lfpos > -1 && crpos > -1) || crpos < 0) ? \"\\n\" : \"\\r\";\n", + "\t\t\n", + "\t\tvar lines = str.split(linebreak);\n", + "\t\tfor (var i = 0; i < lines.length; i++) {\n", + "\t\t\tlines[i] = difflib.stripLinebreaks(lines[i]);\n", + "\t\t}\n", + "\t\t\n", + "\t\treturn lines;\n", + "\t},\n", + "\t\n", + "\t// iteration-based reduce implementation\n", + "\t__reduce: function (func, list, initial) {\n", + "\t\tif (initial != null) {\n", + "\t\t\tvar value = initial;\n", + "\t\t\tvar idx = 0;\n", + "\t\t} else if (list) {\n", + "\t\t\tvar value = list[0];\n", + "\t\t\tvar idx = 1;\n", + "\t\t} else {\n", + "\t\t\treturn null;\n", + "\t\t}\n", + "\t\t\n", + "\t\tfor (; idx < list.length; idx++) {\n", + "\t\t\tvalue = func(value, list[idx]);\n", + "\t\t}\n", + "\t\t\n", + "\t\treturn value;\n", + "\t},\n", + "\t\n", + "\t// comparison function for sorting lists of numeric tuples\n", + "\t__ntuplecomp: function (a, b) {\n", + "\t\tvar mlen = Math.max(a.length, b.length);\n", + "\t\tfor (var i = 0; i < mlen; i++) {\n", + "\t\t\tif (a[i] < b[i]) return -1;\n", + "\t\t\tif (a[i] > b[i]) return 1;\n", + "\t\t}\n", + "\t\t\n", + "\t\treturn a.length == b.length ? 0 : (a.length < b.length ? -1 : 1);\n", + "\t},\n", + "\t\n", + "\t__calculate_ratio: function (matches, length) {\n", + "\t\treturn length ? 2.0 * matches / length : 1.0;\n", + "\t},\n", + "\t\n", + "\t// returns a function that returns true if a key passed to the returned function\n", + "\t// is in the dict (js object) provided to this function; replaces being able to\n", + "\t// carry around dict.has_key in python...\n", + "\t__isindict: function (dict) {\n", + "\t\treturn function (key) { return dict.hasOwnProperty(key); };\n", + "\t},\n", + "\t\n", + "\t// replacement for python's dict.get function -- need easy default values\n", + "\t__dictget: function (dict, key, defaultValue) {\n", + "\t\treturn dict.hasOwnProperty(key) ? dict[key] : defaultValue;\n", + "\t},\t\n", + "\t\n", + "\tSequenceMatcher: function (a, b, isjunk) {\n", + "\t\tthis.set_seqs = function (a, b) {\n", + "\t\t\tthis.set_seq1(a);\n", + "\t\t\tthis.set_seq2(b);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.set_seq1 = function (a) {\n", + "\t\t\tif (a == this.a) return;\n", + "\t\t\tthis.a = a;\n", + "\t\t\tthis.matching_blocks = this.opcodes = null;\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.set_seq2 = function (b) {\n", + "\t\t\tif (b == this.b) return;\n", + "\t\t\tthis.b = b;\n", + "\t\t\tthis.matching_blocks = this.opcodes = this.fullbcount = null;\n", + "\t\t\tthis.__chain_b();\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.__chain_b = function () {\n", + "\t\t\tvar b = this.b;\n", + "\t\t\tvar n = b.length;\n", + "\t\t\tvar b2j = this.b2j = {};\n", + "\t\t\tvar populardict = {};\n", + "\t\t\tfor (var i = 0; i < b.length; i++) {\n", + "\t\t\t\tvar elt = b[i];\n", + "\t\t\t\tif (b2j.hasOwnProperty(elt)) {\n", + "\t\t\t\t\tvar indices = b2j[elt];\n", + "\t\t\t\t\tif (n >= 200 && indices.length * 100 > n) {\n", + "\t\t\t\t\t\tpopulardict[elt] = 1;\n", + "\t\t\t\t\t\tdelete b2j[elt];\n", + "\t\t\t\t\t} else {\n", + "\t\t\t\t\t\tindices.push(i);\n", + "\t\t\t\t\t}\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tb2j[elt] = [i];\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\n", + "\t\t\tfor (var elt in populardict) {\n", + "\t\t\t\tif (populardict.hasOwnProperty(elt)) {\n", + "\t\t\t\t\tdelete b2j[elt];\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tvar isjunk = this.isjunk;\n", + "\t\t\tvar junkdict = {};\n", + "\t\t\tif (isjunk) {\n", + "\t\t\t\tfor (var elt in populardict) {\n", + "\t\t\t\t\tif (populardict.hasOwnProperty(elt) && isjunk(elt)) {\n", + "\t\t\t\t\t\tjunkdict[elt] = 1;\n", + "\t\t\t\t\t\tdelete populardict[elt];\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\tfor (var elt in b2j) {\n", + "\t\t\t\t\tif (b2j.hasOwnProperty(elt) && isjunk(elt)) {\n", + "\t\t\t\t\t\tjunkdict[elt] = 1;\n", + "\t\t\t\t\t\tdelete b2j[elt];\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\n", + "\t\t\tthis.isbjunk = difflib.__isindict(junkdict);\n", + "\t\t\tthis.isbpopular = difflib.__isindict(populardict);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.find_longest_match = function (alo, ahi, blo, bhi) {\n", + "\t\t\tvar a = this.a;\n", + "\t\t\tvar b = this.b;\n", + "\t\t\tvar b2j = this.b2j;\n", + "\t\t\tvar isbjunk = this.isbjunk;\n", + "\t\t\tvar besti = alo;\n", + "\t\t\tvar bestj = blo;\n", + "\t\t\tvar bestsize = 0;\n", + "\t\t\tvar j = null;\n", + "\t\t\tvar k;\n", + "\t\n", + "\t\t\tvar j2len = {};\n", + "\t\t\tvar nothing = [];\n", + "\t\t\tfor (var i = alo; i < ahi; i++) {\n", + "\t\t\t\tvar newj2len = {};\n", + "\t\t\t\tvar jdict = difflib.__dictget(b2j, a[i], nothing);\n", + "\t\t\t\tfor (var jkey in jdict) {\n", + "\t\t\t\t\tif (jdict.hasOwnProperty(jkey)) {\n", + "\t\t\t\t\t\tj = jdict[jkey];\n", + "\t\t\t\t\t\tif (j < blo) continue;\n", + "\t\t\t\t\t\tif (j >= bhi) break;\n", + "\t\t\t\t\t\tnewj2len[j] = k = difflib.__dictget(j2len, j - 1, 0) + 1;\n", + "\t\t\t\t\t\tif (k > bestsize) {\n", + "\t\t\t\t\t\t\tbesti = i - k + 1;\n", + "\t\t\t\t\t\t\tbestj = j - k + 1;\n", + "\t\t\t\t\t\t\tbestsize = k;\n", + "\t\t\t\t\t\t}\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\tj2len = newj2len;\n", + "\t\t\t}\n", + "\t\n", + "\t\t\twhile (besti > alo && bestj > blo && !isbjunk(b[bestj - 1]) && a[besti - 1] == b[bestj - 1]) {\n", + "\t\t\t\tbesti--;\n", + "\t\t\t\tbestj--;\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\t\t\t\n", + "\t\t\twhile (besti + bestsize < ahi && bestj + bestsize < bhi &&\n", + "\t\t\t\t\t!isbjunk(b[bestj + bestsize]) &&\n", + "\t\t\t\t\ta[besti + bestsize] == b[bestj + bestsize]) {\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\n", + "\t\t\twhile (besti > alo && bestj > blo && isbjunk(b[bestj - 1]) && a[besti - 1] == b[bestj - 1]) {\n", + "\t\t\t\tbesti--;\n", + "\t\t\t\tbestj--;\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\twhile (besti + bestsize < ahi && bestj + bestsize < bhi && isbjunk(b[bestj + bestsize]) &&\n", + "\t\t\t\t\ta[besti + bestsize] == b[bestj + bestsize]) {\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\n", + "\t\t\treturn [besti, bestj, bestsize];\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.get_matching_blocks = function () {\n", + "\t\t\tif (this.matching_blocks != null) return this.matching_blocks;\n", + "\t\t\tvar la = this.a.length;\n", + "\t\t\tvar lb = this.b.length;\n", + "\t\n", + "\t\t\tvar queue = [[0, la, 0, lb]];\n", + "\t\t\tvar matching_blocks = [];\n", + "\t\t\tvar alo, ahi, blo, bhi, qi, i, j, k, x;\n", + "\t\t\twhile (queue.length) {\n", + "\t\t\t\tqi = queue.pop();\n", + "\t\t\t\talo = qi[0];\n", + "\t\t\t\tahi = qi[1];\n", + "\t\t\t\tblo = qi[2];\n", + "\t\t\t\tbhi = qi[3];\n", + "\t\t\t\tx = this.find_longest_match(alo, ahi, blo, bhi);\n", + "\t\t\t\ti = x[0];\n", + "\t\t\t\tj = x[1];\n", + "\t\t\t\tk = x[2];\n", + "\t\n", + "\t\t\t\tif (k) {\n", + "\t\t\t\t\tmatching_blocks.push(x);\n", + "\t\t\t\t\tif (alo < i && blo < j)\n", + "\t\t\t\t\t\tqueue.push([alo, i, blo, j]);\n", + "\t\t\t\t\tif (i+k < ahi && j+k < bhi)\n", + "\t\t\t\t\t\tqueue.push([i + k, ahi, j + k, bhi]);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tmatching_blocks.sort(difflib.__ntuplecomp);\n", + "\t\n", + "\t\t\tvar i1 = 0, j1 = 0, k1 = 0, block = 0;\n", + "\t\t\tvar i2, j2, k2;\n", + "\t\t\tvar non_adjacent = [];\n", + "\t\t\tfor (var idx in matching_blocks) {\n", + "\t\t\t\tif (matching_blocks.hasOwnProperty(idx)) {\n", + "\t\t\t\t\tblock = matching_blocks[idx];\n", + "\t\t\t\t\ti2 = block[0];\n", + "\t\t\t\t\tj2 = block[1];\n", + "\t\t\t\t\tk2 = block[2];\n", + "\t\t\t\t\tif (i1 + k1 == i2 && j1 + k1 == j2) {\n", + "\t\t\t\t\t\tk1 += k2;\n", + "\t\t\t\t\t} else {\n", + "\t\t\t\t\t\tif (k1) non_adjacent.push([i1, j1, k1]);\n", + "\t\t\t\t\t\ti1 = i2;\n", + "\t\t\t\t\t\tj1 = j2;\n", + "\t\t\t\t\t\tk1 = k2;\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tif (k1) non_adjacent.push([i1, j1, k1]);\n", + "\t\n", + "\t\t\tnon_adjacent.push([la, lb, 0]);\n", + "\t\t\tthis.matching_blocks = non_adjacent;\n", + "\t\t\treturn this.matching_blocks;\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.get_opcodes = function () {\n", + "\t\t\tif (this.opcodes != null) return this.opcodes;\n", + "\t\t\tvar i = 0;\n", + "\t\t\tvar j = 0;\n", + "\t\t\tvar answer = [];\n", + "\t\t\tthis.opcodes = answer;\n", + "\t\t\tvar block, ai, bj, size, tag;\n", + "\t\t\tvar blocks = this.get_matching_blocks();\n", + "\t\t\tfor (var idx in blocks) {\n", + "\t\t\t\tif (blocks.hasOwnProperty(idx)) {\n", + "\t\t\t\t\tblock = blocks[idx];\n", + "\t\t\t\t\tai = block[0];\n", + "\t\t\t\t\tbj = block[1];\n", + "\t\t\t\t\tsize = block[2];\n", + "\t\t\t\t\ttag = '';\n", + "\t\t\t\t\tif (i < ai && j < bj) {\n", + "\t\t\t\t\t\ttag = 'replace';\n", + "\t\t\t\t\t} else if (i < ai) {\n", + "\t\t\t\t\t\ttag = 'delete';\n", + "\t\t\t\t\t} else if (j < bj) {\n", + "\t\t\t\t\t\ttag = 'insert';\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\tif (tag) answer.push([tag, i, ai, j, bj]);\n", + "\t\t\t\t\ti = ai + size;\n", + "\t\t\t\t\tj = bj + size;\n", + "\t\t\t\t\t\n", + "\t\t\t\t\tif (size) answer.push(['equal', ai, i, bj, j]);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\treturn answer;\n", + "\t\t}\n", + "\t\t\n", + "\t\t// this is a generator function in the python lib, which of course is not supported in javascript\n", + "\t\t// the reimplementation builds up the grouped opcodes into a list in their entirety and returns that.\n", + "\t\tthis.get_grouped_opcodes = function (n) {\n", + "\t\t\tif (!n) n = 3;\n", + "\t\t\tvar codes = this.get_opcodes();\n", + "\t\t\tif (!codes) codes = [[\"equal\", 0, 1, 0, 1]];\n", + "\t\t\tvar code, tag, i1, i2, j1, j2;\n", + "\t\t\tif (codes[0][0] == 'equal') {\n", + "\t\t\t\tcode = codes[0];\n", + "\t\t\t\ttag = code[0];\n", + "\t\t\t\ti1 = code[1];\n", + "\t\t\t\ti2 = code[2];\n", + "\t\t\t\tj1 = code[3];\n", + "\t\t\t\tj2 = code[4];\n", + "\t\t\t\tcodes[0] = [tag, Math.max(i1, i2 - n), i2, Math.max(j1, j2 - n), j2];\n", + "\t\t\t}\n", + "\t\t\tif (codes[codes.length - 1][0] == 'equal') {\n", + "\t\t\t\tcode = codes[codes.length - 1];\n", + "\t\t\t\ttag = code[0];\n", + "\t\t\t\ti1 = code[1];\n", + "\t\t\t\ti2 = code[2];\n", + "\t\t\t\tj1 = code[3];\n", + "\t\t\t\tj2 = code[4];\n", + "\t\t\t\tcodes[codes.length - 1] = [tag, i1, Math.min(i2, i1 + n), j1, Math.min(j2, j1 + n)];\n", + "\t\t\t}\n", + "\t\n", + "\t\t\tvar nn = n + n;\n", + "\t\t\tvar group = [];\n", + "\t\t\tvar groups = [];\n", + "\t\t\tfor (var idx in codes) {\n", + "\t\t\t\tif (codes.hasOwnProperty(idx)) {\n", + "\t\t\t\t\tcode = codes[idx];\n", + "\t\t\t\t\ttag = code[0];\n", + "\t\t\t\t\ti1 = code[1];\n", + "\t\t\t\t\ti2 = code[2];\n", + "\t\t\t\t\tj1 = code[3];\n", + "\t\t\t\t\tj2 = code[4];\n", + "\t\t\t\t\tif (tag == 'equal' && i2 - i1 > nn) {\n", + "\t\t\t\t\t\tgroup.push([tag, i1, Math.min(i2, i1 + n), j1, Math.min(j2, j1 + n)]);\n", + "\t\t\t\t\t\tgroups.push(group);\n", + "\t\t\t\t\t\tgroup = [];\n", + "\t\t\t\t\t\ti1 = Math.max(i1, i2-n);\n", + "\t\t\t\t\t\tj1 = Math.max(j1, j2-n);\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\t\n", + "\t\t\t\t\tgroup.push([tag, i1, i2, j1, j2]);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tif (group && !(group.length == 1 && group[0][0] == 'equal')) groups.push(group)\n", + "\t\t\t\n", + "\t\t\treturn groups;\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.ratio = function () {\n", + "\t\t\tmatches = difflib.__reduce(\n", + "\t\t\t\t\t\t\tfunction (sum, triple) { return sum + triple[triple.length - 1]; },\n", + "\t\t\t\t\t\t\tthis.get_matching_blocks(), 0);\n", + "\t\t\treturn difflib.__calculate_ratio(matches, this.a.length + this.b.length);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.quick_ratio = function () {\n", + "\t\t\tvar fullbcount, elt;\n", + "\t\t\tif (this.fullbcount == null) {\n", + "\t\t\t\tthis.fullbcount = fullbcount = {};\n", + "\t\t\t\tfor (var i = 0; i < this.b.length; i++) {\n", + "\t\t\t\t\telt = this.b[i];\n", + "\t\t\t\t\tfullbcount[elt] = difflib.__dictget(fullbcount, elt, 0) + 1;\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\tfullbcount = this.fullbcount;\n", + "\t\n", + "\t\t\tvar avail = {};\n", + "\t\t\tvar availhas = difflib.__isindict(avail);\n", + "\t\t\tvar matches = numb = 0;\n", + "\t\t\tfor (var i = 0; i < this.a.length; i++) {\n", + "\t\t\t\telt = this.a[i];\n", + "\t\t\t\tif (availhas(elt)) {\n", + "\t\t\t\t\tnumb = avail[elt];\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tnumb = difflib.__dictget(fullbcount, elt, 0);\n", + "\t\t\t\t}\n", + "\t\t\t\tavail[elt] = numb - 1;\n", + "\t\t\t\tif (numb > 0) matches++;\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\treturn difflib.__calculate_ratio(matches, this.a.length + this.b.length);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.real_quick_ratio = function () {\n", + "\t\t\tvar la = this.a.length;\n", + "\t\t\tvar lb = this.b.length;\n", + "\t\t\treturn _calculate_ratio(Math.min(la, lb), la + lb);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.isjunk = isjunk ? isjunk : difflib.defaultJunkFunction;\n", + "\t\tthis.a = this.b = null;\n", + "\t\tthis.set_seqs(a, b);\n", + "\t}\n", + "};\n", + "\n", + "\n", + "\n", + "function diffUsingJS (viewType, contextSize, baseText, newText) {\n", + "\n", + " var byId = function (id) { return document.getElementById(id); },\n", + " base = difflib.stringAsLines(baseText),\n", + " newtxt = difflib.stringAsLines(newText),\n", + " sm = new difflib.SequenceMatcher(base, newtxt),\n", + " opcodes = sm.get_opcodes(),\n", + " diffoutputdiv = byId(\"diffid_2021-08-05_16_46_43_018480\");\n", + "\n", + " diffoutputdiv.innerHTML = \"\";\n", + " contextSize = contextSize || null;\n", + "\n", + " diffoutputdiv.appendChild(diffview.buildView({\n", + " baseTextLines: base,\n", + " newTextLines: newtxt,\n", + " opcodes: opcodes,\n", + " baseTextName: \"Base Text\",\n", + " newTextName: \"New Text\",\n", + " contextSize: contextSize,\n", + " viewType: viewType\n", + " }));\n", + "}\n", + "var tview=0;\n", + "var csize='';\n", + "var bt = 'def dft_real(x, fft_length=None, transpose=True):\\n if len(x.shape) == 1:\\n x = x.reshape((1, -1))\\n N = 1\\n else:\\n N = x.shape[0] \\n C = x.shape[-1] if transpose else x.shape[-2]\\n if fft_length is None:\\n fft_length = x.shape[-1]\\n size = fft_length // 2 + 1\\n\\n cst = dft_real_cst(C, fft_length)\\n if transpose:\\n x = numpy.transpose(x, (1, 0))\\n a = cst[:, :, :fft_length]\\n b = x[:fft_length]\\n res = numpy.matmul(a, b)\\n res = res[:, :size, :]\\n return numpy.transpose(res, (0, 2, 1))\\n else:\\n a = cst[:, :, :fft_length]\\n b = x[:fft_length]\\n return numpy.matmul(a, b)\\n';\n", + "var nt = 'def dft_real_d3(x, fft_length=None, transpose=True):\\n if len(x.shape) != 3:\\n raise RuntimeError(\"Not implemented for shape=%r.\" % x.shape)\\n N = x.shape[1]\\n C = x.shape[-1] if transpose else x.shape[-2]\\n if fft_length is None:\\n fft_length = x.shape[-1]\\n size = fft_length // 2 + 1\\n\\n cst = dft_real_cst(C, fft_length)\\n if transpose:\\n x = numpy.transpose(x, (0, 2, 1))\\n a = cst[:, :, :fft_length]\\n b = x[:, :fft_length, :]\\n a = numpy.expand_dims(a, 0)\\n b = numpy.expand_dims(b, 1)\\n res = numpy.matmul(a, b)\\n res = res[:, :, :size, :]\\n return numpy.transpose(res, (1, 0, 3, 2))\\n else:\\n a = cst[:, :, :fft_length]\\n b = x[:, :fft_length, :]\\n a = numpy.expand_dims(a, 0)\\n b = numpy.expand_dims(b, 1)\\n res = numpy.matmul(a, b)\\n return numpy.transpose(res, (1, 0, 2, 3))\\n';\n", + "diffUsingJS(tview, csize, bt, nt) ;\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import inspect\n", + "text1 = inspect.getsource(dft_real)\n", + "text2 = inspect.getsource(dft_real_d3)\n", + "%textdiff text1 text2" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "cd7e14d4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
populating...
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "/*\n", + "This is part of jsdifflib v1.0. \n", + "\n", + "Copyright 2007 - 2011 Chas Emerick . All rights reserved.\n", + "\n", + "Redistribution and use in source and binary forms, with or without modification, are\n", + "permitted provided that the following conditions are met:\n", + "\n", + " 1. Redistributions of source code must retain the above copyright notice, this list of\n", + " conditions and the following disclaimer.\n", + "\n", + " 2. Redistributions in binary form must reproduce the above copyright notice, this list\n", + " of conditions and the following disclaimer in the documentation and/or other materials\n", + " provided with the distribution.\n", + "\n", + "THIS SOFTWARE IS PROVIDED BY Chas Emerick ``AS IS'' AND ANY EXPRESS OR IMPLIED\n", + "WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND\n", + "FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL Chas Emerick OR\n", + "CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n", + "CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n", + "SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON\n", + "ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING\n", + "NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF\n", + "ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n", + "\n", + "The views and conclusions contained in the software and documentation are those of the\n", + "authors and should not be interpreted as representing official policies, either expressed\n", + "or implied, of Chas Emerick.\n", + "*/\n", + "var diffview = {\n", + "\t/**\n", + "\t * Builds and returns a visual diff view. The single parameter, `params', should contain\n", + "\t * the following values:\n", + "\t *\n", + "\t * - baseTextLines: the array of strings that was used as the base text input to SequenceMatcher\n", + "\t * - newTextLines: the array of strings that was used as the new text input to SequenceMatcher\n", + "\t * - opcodes: the array of arrays returned by SequenceMatcher.get_opcodes()\n", + "\t * - baseTextName: the title to be displayed above the base text listing in the diff view; defaults\n", + "\t *\t to \"Base Text\"\n", + "\t * - newTextName: the title to be displayed above the new text listing in the diff view; defaults\n", + "\t *\t to \"New Text\"\n", + "\t * - contextSize: the number of lines of context to show around differences; by default, all lines\n", + "\t *\t are shown\n", + "\t * - viewType: if 0, a side-by-side diff view is generated (default); if 1, an inline diff view is\n", + "\t *\t generated\n", + "\t */\n", + "\tbuildView: function (params) {\n", + "\t\tvar baseTextLines = params.baseTextLines;\n", + "\t\tvar newTextLines = params.newTextLines;\n", + "\t\tvar opcodes = params.opcodes;\n", + "\t\tvar baseTextName = params.baseTextName ? params.baseTextName : \"Base Text\";\n", + "\t\tvar newTextName = params.newTextName ? params.newTextName : \"New Text\";\n", + "\t\tvar contextSize = params.contextSize;\n", + "\t\tvar inline = (params.viewType == 0 || params.viewType == 1) ? params.viewType : 0;\n", + "\n", + "\t\tif (baseTextLines == null)\n", + "\t\t\tthrow \"Cannot build diff view; baseTextLines is not defined.\";\n", + "\t\tif (newTextLines == null)\n", + "\t\t\tthrow \"Cannot build diff view; newTextLines is not defined.\";\n", + "\t\tif (!opcodes)\n", + "\t\t\tthrow \"Cannot build diff view; opcodes is not defined.\";\n", + "\t\t\n", + "\t\tfunction celt (name, clazz) {\n", + "\t\t\tvar e = document.createElement(name);\n", + "\t\t\te.className = clazz;\n", + "\t\t\treturn e;\n", + "\t\t}\n", + "\t\t\n", + "\t\tfunction telt (name, text) {\n", + "\t\t\tvar e = document.createElement(name);\n", + "\t\t\te.appendChild(document.createTextNode(text));\n", + "\t\t\treturn e;\n", + "\t\t}\n", + "\t\t\n", + "\t\tfunction ctelt (name, clazz, text) {\n", + "\t\t\tvar e = document.createElement(name);\n", + "\t\t\te.className = clazz;\n", + "\t\t\te.appendChild(document.createTextNode(text));\n", + "\t\t\treturn e;\n", + "\t\t}\n", + "\t\n", + "\t\tvar tdata = document.createElement(\"thead\");\n", + "\t\tvar node = document.createElement(\"tr\");\n", + "\t\ttdata.appendChild(node);\n", + "\t\tif (inline) {\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", baseTextName + \" vs. \" + newTextName));\n", + "\t\t} else {\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", baseTextName));\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", newTextName));\n", + "\t\t}\n", + "\t\ttdata = [tdata];\n", + "\t\t\n", + "\t\tvar rows = [];\n", + "\t\tvar node2;\n", + "\t\t\n", + "\t\t/**\n", + "\t\t * Adds two cells to the given row; if the given row corresponds to a real\n", + "\t\t * line number (based on the line index tidx and the endpoint of the \n", + "\t\t * range in question tend), then the cells will contain the line number\n", + "\t\t * and the line of text from textLines at position tidx (with the class of\n", + "\t\t * the second cell set to the name of the change represented), and tidx + 1 will\n", + "\t\t * be returned.\t Otherwise, tidx is returned, and two empty cells are added\n", + "\t\t * to the given row.\n", + "\t\t */\n", + "\t\tfunction addCells (row, tidx, tend, textLines, change) {\n", + "\t\t\tif (tidx < tend) {\n", + "\t\t\t\trow.appendChild(telt(\"th\", (tidx + 1).toString()));\n", + "\t\t\t\trow.appendChild(ctelt(\"td\", change, textLines[tidx].replace(/\\t/g, \"\\u00a0\\u00a0\\u00a0\\u00a0\")));\n", + "\t\t\t\treturn tidx + 1;\n", + "\t\t\t} else {\n", + "\t\t\t\trow.appendChild(document.createElement(\"th\"));\n", + "\t\t\t\trow.appendChild(celt(\"td\", \"empty\"));\n", + "\t\t\t\treturn tidx;\n", + "\t\t\t}\n", + "\t\t}\n", + "\t\t\n", + "\t\tfunction addCellsInline (row, tidx, tidx2, textLines, change) {\n", + "\t\t\trow.appendChild(telt(\"th\", tidx == null ? \"\" : (tidx + 1).toString()));\n", + "\t\t\trow.appendChild(telt(\"th\", tidx2 == null ? \"\" : (tidx2 + 1).toString()));\n", + "\t\t\trow.appendChild(ctelt(\"td\", change, textLines[tidx != null ? tidx : tidx2].replace(/\\t/g, \"\\u00a0\\u00a0\\u00a0\\u00a0\")));\n", + "\t\t}\n", + "\t\t\n", + "\t\tfor (var idx = 0; idx < opcodes.length; idx++) {\n", + "\t\t\tvar code = opcodes[idx];\n", + "\t\t\tvar change = code[0];\n", + "\t\t\tvar b = code[1];\n", + "\t\t\tvar be = code[2];\n", + "\t\t\tvar n = code[3];\n", + "\t\t\tvar ne = code[4];\n", + "\t\t\tvar rowcnt = Math.max(be - b, ne - n);\n", + "\t\t\tvar toprows = [];\n", + "\t\t\tvar botrows = [];\n", + "\t\t\tfor (var i = 0; i < rowcnt; i++) {\n", + "\t\t\t\t// jump ahead if we've alredy provided leading context or if this is the first range\n", + "\t\t\t\tif (contextSize && opcodes.length > 1 && ((idx > 0 && i == contextSize) || (idx == 0 && i == 0)) && change==\"equal\") {\n", + "\t\t\t\t\tvar jump = rowcnt - ((idx == 0 ? 1 : 2) * contextSize);\n", + "\t\t\t\t\tif (jump > 1) {\n", + "\t\t\t\t\t\ttoprows.push(node = document.createElement(\"tr\"));\n", + "\t\t\t\t\t\t\n", + "\t\t\t\t\t\tb += jump;\n", + "\t\t\t\t\t\tn += jump;\n", + "\t\t\t\t\t\ti += jump - 1;\n", + "\t\t\t\t\t\tnode.appendChild(telt(\"th\", \"...\"));\n", + "\t\t\t\t\t\tif (!inline) node.appendChild(ctelt(\"td\", \"skip\", \"\"));\n", + "\t\t\t\t\t\tnode.appendChild(telt(\"th\", \"...\"));\n", + "\t\t\t\t\t\tnode.appendChild(ctelt(\"td\", \"skip\", \"\"));\n", + "\t\t\t\t\t\t\n", + "\t\t\t\t\t\t// skip last lines if they're all equal\n", + "\t\t\t\t\t\tif (idx + 1 == opcodes.length) {\n", + "\t\t\t\t\t\t\tbreak;\n", + "\t\t\t\t\t\t} else {\n", + "\t\t\t\t\t\t\tcontinue;\n", + "\t\t\t\t\t\t}\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\t\n", + "\t\t\t\ttoprows.push(node = document.createElement(\"tr\"));\n", + "\t\t\t\tif (inline) {\n", + "\t\t\t\t\tif (change == \"insert\") {\n", + "\t\t\t\t\t\taddCellsInline(node, null, n++, newTextLines, change);\n", + "\t\t\t\t\t} else if (change == \"replace\") {\n", + "\t\t\t\t\t\tbotrows.push(node2 = document.createElement(\"tr\"));\n", + "\t\t\t\t\t\tif (b < be) addCellsInline(node, b++, null, baseTextLines, \"delete\");\n", + "\t\t\t\t\t\tif (n < ne) addCellsInline(node2, null, n++, newTextLines, \"insert\");\n", + "\t\t\t\t\t} else if (change == \"delete\") {\n", + "\t\t\t\t\t\taddCellsInline(node, b++, null, baseTextLines, change);\n", + "\t\t\t\t\t} else {\n", + "\t\t\t\t\t\t// equal\n", + "\t\t\t\t\t\taddCellsInline(node, b++, n++, baseTextLines, change);\n", + "\t\t\t\t\t}\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tb = addCells(node, b, be, baseTextLines, change);\n", + "\t\t\t\t\tn = addCells(node, n, ne, newTextLines, change);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\n", + "\t\t\tfor (var i = 0; i < toprows.length; i++) rows.push(toprows[i]);\n", + "\t\t\tfor (var i = 0; i < botrows.length; i++) rows.push(botrows[i]);\n", + "\t\t}\n", + "\t\t\n", + "\t\trows.push(node = ctelt(\"th\", \"author\", \"diff view generated by \"));\n", + "\t\tnode.setAttribute(\"colspan\", inline ? 3 : 4);\n", + "\t\tnode.appendChild(node2 = telt(\"a\", \"jsdifflib\"));\n", + "\t\tnode2.setAttribute(\"href\", \"http://github.com/cemerick/jsdifflib\");\n", + "\t\t\n", + "\t\ttdata.push(node = document.createElement(\"tbody\"));\n", + "\t\tfor (var idx in rows) rows.hasOwnProperty(idx) && node.appendChild(rows[idx]);\n", + "\t\t\n", + "\t\tnode = celt(\"table\", \"diff\" + (inline ? \" inlinediff\" : \"\"));\n", + "\t\tfor (var idx in tdata) tdata.hasOwnProperty(idx) && node.appendChild(tdata[idx]);\n", + "\t\treturn node;\n", + "\t}\n", + "};\n", + "\n", + "\n", + "/***\n", + "This is part of jsdifflib v1.0. \n", + "\n", + "Copyright (c) 2007, Snowtide Informatics Systems, Inc.\n", + "All rights reserved.\n", + "\n", + "Redistribution and use in source and binary forms, with or without modification,\n", + "are permitted provided that the following conditions are met:\n", + "\n", + "\t* Redistributions of source code must retain the above copyright notice, this\n", + "\t\tlist of conditions and the following disclaimer.\n", + "\t* Redistributions in binary form must reproduce the above copyright notice,\n", + "\t\tthis list of conditions and the following disclaimer in the documentation\n", + "\t\tand/or other materials provided with the distribution.\n", + "\t* Neither the name of the Snowtide Informatics Systems nor the names of its\n", + "\t\tcontributors may be used to endorse or promote products derived from this\n", + "\t\tsoftware without specific prior written permission.\n", + "\n", + "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY\n", + "EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES\n", + "OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT\n", + "SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,\n", + "INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED\n", + "TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR\n", + "BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n", + "CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN\n", + "ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH\n", + "DAMAGE.\n", + "***/\n", + "/* Author: Chas Emerick */\n", + "var __whitespace = {\" \":true, \"\\t\":true, \"\\n\":true, \"\\f\":true, \"\\r\":true};\n", + "\n", + "var difflib = {\n", + "\tdefaultJunkFunction: function (c) {\n", + "\t\treturn __whitespace.hasOwnProperty(c);\n", + "\t},\n", + "\t\n", + "\tstripLinebreaks: function (str) { return str.replace(/^[\\n\\r]*|[\\n\\r]*$/g, \"\"); },\n", + "\t\n", + "\tstringAsLines: function (str) {\n", + "\t\tvar lfpos = str.indexOf(\"\\n\");\n", + "\t\tvar crpos = str.indexOf(\"\\r\");\n", + "\t\tvar linebreak = ((lfpos > -1 && crpos > -1) || crpos < 0) ? \"\\n\" : \"\\r\";\n", + "\t\t\n", + "\t\tvar lines = str.split(linebreak);\n", + "\t\tfor (var i = 0; i < lines.length; i++) {\n", + "\t\t\tlines[i] = difflib.stripLinebreaks(lines[i]);\n", + "\t\t}\n", + "\t\t\n", + "\t\treturn lines;\n", + "\t},\n", + "\t\n", + "\t// iteration-based reduce implementation\n", + "\t__reduce: function (func, list, initial) {\n", + "\t\tif (initial != null) {\n", + "\t\t\tvar value = initial;\n", + "\t\t\tvar idx = 0;\n", + "\t\t} else if (list) {\n", + "\t\t\tvar value = list[0];\n", + "\t\t\tvar idx = 1;\n", + "\t\t} else {\n", + "\t\t\treturn null;\n", + "\t\t}\n", + "\t\t\n", + "\t\tfor (; idx < list.length; idx++) {\n", + "\t\t\tvalue = func(value, list[idx]);\n", + "\t\t}\n", + "\t\t\n", + "\t\treturn value;\n", + "\t},\n", + "\t\n", + "\t// comparison function for sorting lists of numeric tuples\n", + "\t__ntuplecomp: function (a, b) {\n", + "\t\tvar mlen = Math.max(a.length, b.length);\n", + "\t\tfor (var i = 0; i < mlen; i++) {\n", + "\t\t\tif (a[i] < b[i]) return -1;\n", + "\t\t\tif (a[i] > b[i]) return 1;\n", + "\t\t}\n", + "\t\t\n", + "\t\treturn a.length == b.length ? 0 : (a.length < b.length ? -1 : 1);\n", + "\t},\n", + "\t\n", + "\t__calculate_ratio: function (matches, length) {\n", + "\t\treturn length ? 2.0 * matches / length : 1.0;\n", + "\t},\n", + "\t\n", + "\t// returns a function that returns true if a key passed to the returned function\n", + "\t// is in the dict (js object) provided to this function; replaces being able to\n", + "\t// carry around dict.has_key in python...\n", + "\t__isindict: function (dict) {\n", + "\t\treturn function (key) { return dict.hasOwnProperty(key); };\n", + "\t},\n", + "\t\n", + "\t// replacement for python's dict.get function -- need easy default values\n", + "\t__dictget: function (dict, key, defaultValue) {\n", + "\t\treturn dict.hasOwnProperty(key) ? dict[key] : defaultValue;\n", + "\t},\t\n", + "\t\n", + "\tSequenceMatcher: function (a, b, isjunk) {\n", + "\t\tthis.set_seqs = function (a, b) {\n", + "\t\t\tthis.set_seq1(a);\n", + "\t\t\tthis.set_seq2(b);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.set_seq1 = function (a) {\n", + "\t\t\tif (a == this.a) return;\n", + "\t\t\tthis.a = a;\n", + "\t\t\tthis.matching_blocks = this.opcodes = null;\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.set_seq2 = function (b) {\n", + "\t\t\tif (b == this.b) return;\n", + "\t\t\tthis.b = b;\n", + "\t\t\tthis.matching_blocks = this.opcodes = this.fullbcount = null;\n", + "\t\t\tthis.__chain_b();\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.__chain_b = function () {\n", + "\t\t\tvar b = this.b;\n", + "\t\t\tvar n = b.length;\n", + "\t\t\tvar b2j = this.b2j = {};\n", + "\t\t\tvar populardict = {};\n", + "\t\t\tfor (var i = 0; i < b.length; i++) {\n", + "\t\t\t\tvar elt = b[i];\n", + "\t\t\t\tif (b2j.hasOwnProperty(elt)) {\n", + "\t\t\t\t\tvar indices = b2j[elt];\n", + "\t\t\t\t\tif (n >= 200 && indices.length * 100 > n) {\n", + "\t\t\t\t\t\tpopulardict[elt] = 1;\n", + "\t\t\t\t\t\tdelete b2j[elt];\n", + "\t\t\t\t\t} else {\n", + "\t\t\t\t\t\tindices.push(i);\n", + "\t\t\t\t\t}\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tb2j[elt] = [i];\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\n", + "\t\t\tfor (var elt in populardict) {\n", + "\t\t\t\tif (populardict.hasOwnProperty(elt)) {\n", + "\t\t\t\t\tdelete b2j[elt];\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tvar isjunk = this.isjunk;\n", + "\t\t\tvar junkdict = {};\n", + "\t\t\tif (isjunk) {\n", + "\t\t\t\tfor (var elt in populardict) {\n", + "\t\t\t\t\tif (populardict.hasOwnProperty(elt) && isjunk(elt)) {\n", + "\t\t\t\t\t\tjunkdict[elt] = 1;\n", + "\t\t\t\t\t\tdelete populardict[elt];\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\tfor (var elt in b2j) {\n", + "\t\t\t\t\tif (b2j.hasOwnProperty(elt) && isjunk(elt)) {\n", + "\t\t\t\t\t\tjunkdict[elt] = 1;\n", + "\t\t\t\t\t\tdelete b2j[elt];\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\n", + "\t\t\tthis.isbjunk = difflib.__isindict(junkdict);\n", + "\t\t\tthis.isbpopular = difflib.__isindict(populardict);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.find_longest_match = function (alo, ahi, blo, bhi) {\n", + "\t\t\tvar a = this.a;\n", + "\t\t\tvar b = this.b;\n", + "\t\t\tvar b2j = this.b2j;\n", + "\t\t\tvar isbjunk = this.isbjunk;\n", + "\t\t\tvar besti = alo;\n", + "\t\t\tvar bestj = blo;\n", + "\t\t\tvar bestsize = 0;\n", + "\t\t\tvar j = null;\n", + "\t\t\tvar k;\n", + "\t\n", + "\t\t\tvar j2len = {};\n", + "\t\t\tvar nothing = [];\n", + "\t\t\tfor (var i = alo; i < ahi; i++) {\n", + "\t\t\t\tvar newj2len = {};\n", + "\t\t\t\tvar jdict = difflib.__dictget(b2j, a[i], nothing);\n", + "\t\t\t\tfor (var jkey in jdict) {\n", + "\t\t\t\t\tif (jdict.hasOwnProperty(jkey)) {\n", + "\t\t\t\t\t\tj = jdict[jkey];\n", + "\t\t\t\t\t\tif (j < blo) continue;\n", + "\t\t\t\t\t\tif (j >= bhi) break;\n", + "\t\t\t\t\t\tnewj2len[j] = k = difflib.__dictget(j2len, j - 1, 0) + 1;\n", + "\t\t\t\t\t\tif (k > bestsize) {\n", + "\t\t\t\t\t\t\tbesti = i - k + 1;\n", + "\t\t\t\t\t\t\tbestj = j - k + 1;\n", + "\t\t\t\t\t\t\tbestsize = k;\n", + "\t\t\t\t\t\t}\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\tj2len = newj2len;\n", + "\t\t\t}\n", + "\t\n", + "\t\t\twhile (besti > alo && bestj > blo && !isbjunk(b[bestj - 1]) && a[besti - 1] == b[bestj - 1]) {\n", + "\t\t\t\tbesti--;\n", + "\t\t\t\tbestj--;\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\t\t\t\n", + "\t\t\twhile (besti + bestsize < ahi && bestj + bestsize < bhi &&\n", + "\t\t\t\t\t!isbjunk(b[bestj + bestsize]) &&\n", + "\t\t\t\t\ta[besti + bestsize] == b[bestj + bestsize]) {\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\n", + "\t\t\twhile (besti > alo && bestj > blo && isbjunk(b[bestj - 1]) && a[besti - 1] == b[bestj - 1]) {\n", + "\t\t\t\tbesti--;\n", + "\t\t\t\tbestj--;\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\twhile (besti + bestsize < ahi && bestj + bestsize < bhi && isbjunk(b[bestj + bestsize]) &&\n", + "\t\t\t\t\ta[besti + bestsize] == b[bestj + bestsize]) {\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\n", + "\t\t\treturn [besti, bestj, bestsize];\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.get_matching_blocks = function () {\n", + "\t\t\tif (this.matching_blocks != null) return this.matching_blocks;\n", + "\t\t\tvar la = this.a.length;\n", + "\t\t\tvar lb = this.b.length;\n", + "\t\n", + "\t\t\tvar queue = [[0, la, 0, lb]];\n", + "\t\t\tvar matching_blocks = [];\n", + "\t\t\tvar alo, ahi, blo, bhi, qi, i, j, k, x;\n", + "\t\t\twhile (queue.length) {\n", + "\t\t\t\tqi = queue.pop();\n", + "\t\t\t\talo = qi[0];\n", + "\t\t\t\tahi = qi[1];\n", + "\t\t\t\tblo = qi[2];\n", + "\t\t\t\tbhi = qi[3];\n", + "\t\t\t\tx = this.find_longest_match(alo, ahi, blo, bhi);\n", + "\t\t\t\ti = x[0];\n", + "\t\t\t\tj = x[1];\n", + "\t\t\t\tk = x[2];\n", + "\t\n", + "\t\t\t\tif (k) {\n", + "\t\t\t\t\tmatching_blocks.push(x);\n", + "\t\t\t\t\tif (alo < i && blo < j)\n", + "\t\t\t\t\t\tqueue.push([alo, i, blo, j]);\n", + "\t\t\t\t\tif (i+k < ahi && j+k < bhi)\n", + "\t\t\t\t\t\tqueue.push([i + k, ahi, j + k, bhi]);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tmatching_blocks.sort(difflib.__ntuplecomp);\n", + "\t\n", + "\t\t\tvar i1 = 0, j1 = 0, k1 = 0, block = 0;\n", + "\t\t\tvar i2, j2, k2;\n", + "\t\t\tvar non_adjacent = [];\n", + "\t\t\tfor (var idx in matching_blocks) {\n", + "\t\t\t\tif (matching_blocks.hasOwnProperty(idx)) {\n", + "\t\t\t\t\tblock = matching_blocks[idx];\n", + "\t\t\t\t\ti2 = block[0];\n", + "\t\t\t\t\tj2 = block[1];\n", + "\t\t\t\t\tk2 = block[2];\n", + "\t\t\t\t\tif (i1 + k1 == i2 && j1 + k1 == j2) {\n", + "\t\t\t\t\t\tk1 += k2;\n", + "\t\t\t\t\t} else {\n", + "\t\t\t\t\t\tif (k1) non_adjacent.push([i1, j1, k1]);\n", + "\t\t\t\t\t\ti1 = i2;\n", + "\t\t\t\t\t\tj1 = j2;\n", + "\t\t\t\t\t\tk1 = k2;\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tif (k1) non_adjacent.push([i1, j1, k1]);\n", + "\t\n", + "\t\t\tnon_adjacent.push([la, lb, 0]);\n", + "\t\t\tthis.matching_blocks = non_adjacent;\n", + "\t\t\treturn this.matching_blocks;\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.get_opcodes = function () {\n", + "\t\t\tif (this.opcodes != null) return this.opcodes;\n", + "\t\t\tvar i = 0;\n", + "\t\t\tvar j = 0;\n", + "\t\t\tvar answer = [];\n", + "\t\t\tthis.opcodes = answer;\n", + "\t\t\tvar block, ai, bj, size, tag;\n", + "\t\t\tvar blocks = this.get_matching_blocks();\n", + "\t\t\tfor (var idx in blocks) {\n", + "\t\t\t\tif (blocks.hasOwnProperty(idx)) {\n", + "\t\t\t\t\tblock = blocks[idx];\n", + "\t\t\t\t\tai = block[0];\n", + "\t\t\t\t\tbj = block[1];\n", + "\t\t\t\t\tsize = block[2];\n", + "\t\t\t\t\ttag = '';\n", + "\t\t\t\t\tif (i < ai && j < bj) {\n", + "\t\t\t\t\t\ttag = 'replace';\n", + "\t\t\t\t\t} else if (i < ai) {\n", + "\t\t\t\t\t\ttag = 'delete';\n", + "\t\t\t\t\t} else if (j < bj) {\n", + "\t\t\t\t\t\ttag = 'insert';\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\tif (tag) answer.push([tag, i, ai, j, bj]);\n", + "\t\t\t\t\ti = ai + size;\n", + "\t\t\t\t\tj = bj + size;\n", + "\t\t\t\t\t\n", + "\t\t\t\t\tif (size) answer.push(['equal', ai, i, bj, j]);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\treturn answer;\n", + "\t\t}\n", + "\t\t\n", + "\t\t// this is a generator function in the python lib, which of course is not supported in javascript\n", + "\t\t// the reimplementation builds up the grouped opcodes into a list in their entirety and returns that.\n", + "\t\tthis.get_grouped_opcodes = function (n) {\n", + "\t\t\tif (!n) n = 3;\n", + "\t\t\tvar codes = this.get_opcodes();\n", + "\t\t\tif (!codes) codes = [[\"equal\", 0, 1, 0, 1]];\n", + "\t\t\tvar code, tag, i1, i2, j1, j2;\n", + "\t\t\tif (codes[0][0] == 'equal') {\n", + "\t\t\t\tcode = codes[0];\n", + "\t\t\t\ttag = code[0];\n", + "\t\t\t\ti1 = code[1];\n", + "\t\t\t\ti2 = code[2];\n", + "\t\t\t\tj1 = code[3];\n", + "\t\t\t\tj2 = code[4];\n", + "\t\t\t\tcodes[0] = [tag, Math.max(i1, i2 - n), i2, Math.max(j1, j2 - n), j2];\n", + "\t\t\t}\n", + "\t\t\tif (codes[codes.length - 1][0] == 'equal') {\n", + "\t\t\t\tcode = codes[codes.length - 1];\n", + "\t\t\t\ttag = code[0];\n", + "\t\t\t\ti1 = code[1];\n", + "\t\t\t\ti2 = code[2];\n", + "\t\t\t\tj1 = code[3];\n", + "\t\t\t\tj2 = code[4];\n", + "\t\t\t\tcodes[codes.length - 1] = [tag, i1, Math.min(i2, i1 + n), j1, Math.min(j2, j1 + n)];\n", + "\t\t\t}\n", + "\t\n", + "\t\t\tvar nn = n + n;\n", + "\t\t\tvar group = [];\n", + "\t\t\tvar groups = [];\n", + "\t\t\tfor (var idx in codes) {\n", + "\t\t\t\tif (codes.hasOwnProperty(idx)) {\n", + "\t\t\t\t\tcode = codes[idx];\n", + "\t\t\t\t\ttag = code[0];\n", + "\t\t\t\t\ti1 = code[1];\n", + "\t\t\t\t\ti2 = code[2];\n", + "\t\t\t\t\tj1 = code[3];\n", + "\t\t\t\t\tj2 = code[4];\n", + "\t\t\t\t\tif (tag == 'equal' && i2 - i1 > nn) {\n", + "\t\t\t\t\t\tgroup.push([tag, i1, Math.min(i2, i1 + n), j1, Math.min(j2, j1 + n)]);\n", + "\t\t\t\t\t\tgroups.push(group);\n", + "\t\t\t\t\t\tgroup = [];\n", + "\t\t\t\t\t\ti1 = Math.max(i1, i2-n);\n", + "\t\t\t\t\t\tj1 = Math.max(j1, j2-n);\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\t\n", + "\t\t\t\t\tgroup.push([tag, i1, i2, j1, j2]);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tif (group && !(group.length == 1 && group[0][0] == 'equal')) groups.push(group)\n", + "\t\t\t\n", + "\t\t\treturn groups;\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.ratio = function () {\n", + "\t\t\tmatches = difflib.__reduce(\n", + "\t\t\t\t\t\t\tfunction (sum, triple) { return sum + triple[triple.length - 1]; },\n", + "\t\t\t\t\t\t\tthis.get_matching_blocks(), 0);\n", + "\t\t\treturn difflib.__calculate_ratio(matches, this.a.length + this.b.length);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.quick_ratio = function () {\n", + "\t\t\tvar fullbcount, elt;\n", + "\t\t\tif (this.fullbcount == null) {\n", + "\t\t\t\tthis.fullbcount = fullbcount = {};\n", + "\t\t\t\tfor (var i = 0; i < this.b.length; i++) {\n", + "\t\t\t\t\telt = this.b[i];\n", + "\t\t\t\t\tfullbcount[elt] = difflib.__dictget(fullbcount, elt, 0) + 1;\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\tfullbcount = this.fullbcount;\n", + "\t\n", + "\t\t\tvar avail = {};\n", + "\t\t\tvar availhas = difflib.__isindict(avail);\n", + "\t\t\tvar matches = numb = 0;\n", + "\t\t\tfor (var i = 0; i < this.a.length; i++) {\n", + "\t\t\t\telt = this.a[i];\n", + "\t\t\t\tif (availhas(elt)) {\n", + "\t\t\t\t\tnumb = avail[elt];\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tnumb = difflib.__dictget(fullbcount, elt, 0);\n", + "\t\t\t\t}\n", + "\t\t\t\tavail[elt] = numb - 1;\n", + "\t\t\t\tif (numb > 0) matches++;\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\treturn difflib.__calculate_ratio(matches, this.a.length + this.b.length);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.real_quick_ratio = function () {\n", + "\t\t\tvar la = this.a.length;\n", + "\t\t\tvar lb = this.b.length;\n", + "\t\t\treturn _calculate_ratio(Math.min(la, lb), la + lb);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.isjunk = isjunk ? isjunk : difflib.defaultJunkFunction;\n", + "\t\tthis.a = this.b = null;\n", + "\t\tthis.set_seqs(a, b);\n", + "\t}\n", + "};\n", + "\n", + "\n", + "\n", + "function diffUsingJS (viewType, contextSize, baseText, newText) {\n", + "\n", + " var byId = function (id) { return document.getElementById(id); },\n", + " base = difflib.stringAsLines(baseText),\n", + " newtxt = difflib.stringAsLines(newText),\n", + " sm = new difflib.SequenceMatcher(base, newtxt),\n", + " opcodes = sm.get_opcodes(),\n", + " diffoutputdiv = byId(\"diffid_2021-08-05_16_46_43_079488\");\n", + "\n", + " diffoutputdiv.innerHTML = \"\";\n", + " contextSize = contextSize || null;\n", + "\n", + " diffoutputdiv.appendChild(diffview.buildView({\n", + " baseTextLines: base,\n", + " newTextLines: newtxt,\n", + " opcodes: opcodes,\n", + " baseTextName: \"Base Text\",\n", + " newTextName: \"New Text\",\n", + " contextSize: contextSize,\n", + " viewType: viewType\n", + " }));\n", + "}\n", + "var tview=0;\n", + "var csize='';\n", + "var bt = 'def fft2d(mat, fft_length):\\n mat = mat[:fft_length[0], :fft_length[1]]\\n res = mat.copy()\\n \\n # first FFT\\n res = dft_real(res, fft_length=fft_length[1], transpose=True)\\n \\n # second FFT decomposed on FFT on real part and imaginary part\\n res2_real = dft_real(res[0], fft_length=fft_length[0], transpose=False)\\n res2_imag = dft_real(res[1], fft_length=fft_length[0], transpose=False) \\n res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\\n res = res2_real + res2_imag2\\n size = fft_length[1]//2 + 1\\n return res[:, :fft_length[0], :size]\\n';\n", + "var nt = 'def fft2d_d3(mat, fft_length):\\n mat = mat[:, :fft_length[-2], :fft_length[-1]]\\n res = mat.copy()\\n \\n # first FFT\\n res = dft_real_d3(res, fft_length=fft_length[-1], transpose=True)\\n \\n # second FFT decomposed on FFT on real part and imaginary part\\n res2_real = dft_real_d3(res[0], fft_length=fft_length[-2], transpose=False)\\n res2_imag = dft_real_d3(res[1], fft_length=fft_length[-2], transpose=False)\\n res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\\n res = res2_real + res2_imag2\\n size = fft_length[-1]//2 + 1\\n return res[:, :, :fft_length[-2], :size]\\n';\n", + "diffUsingJS(tview, csize, bt, nt) ;\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text1 = inspect.getsource(fft2d)\n", + "text2 = inspect.getsource(fft2d_d3)\n", + "%textdiff text1 text2" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "51e7a4f7", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "def onnx_rfft_3d_1d(x, fft_length=None, transpose=True):\n", + " if fft_length is None:\n", + " raise RuntimeError(\"fft_length must be specified.\")\n", + " \n", + " size = fft_length // 2 + 1\n", + " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", + " if transpose:\n", + " xt = npnx.transpose(x, (0, 2, 1))\n", + " a = cst[:, :, :fft_length]\n", + " b = xt[:, :fft_length, :]\n", + " a = npnx.expand_dims(a, 0)\n", + " b = npnx.expand_dims(b, 1)\n", + " res = npnx.matmul(a, b)\n", + " res2 = res[:, :size, :]\n", + " return npnx.transpose(res2, (1, 0, 3, 2))\n", + " else:\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:, :fft_length, :]\n", + " a = npnx.expand_dims(a, 0)\n", + " b = npnx.expand_dims(b, 1)\n", + " res = npnx.matmul(a, b)\n", + " return npnx.transpose(res, (1, 0, 2, 3)) \n", + " \n", + "\n", + "def onnx_rfft_3d_2d(x, fft_length=None):\n", + " mat = x[:, :fft_length[-2], :fft_length[-1]]\n", + " \n", + " # first FFT\n", + " res = onnx_rfft_3d_1d(mat, fft_length=fft_length[-1], transpose=True)\n", + " \n", + " # second FFT decomposed on FFT on real part and imaginary part\n", + " res2_real = onnx_rfft_3d_1d(res[0], fft_length=fft_length[0], transpose=False)\n", + " res2_imag = onnx_rfft_3d_1d(res[1], fft_length=fft_length[0], transpose=False) \n", + " res2_imag2 = npnx.vstack(-res2_imag[1:2], res2_imag[:1])\n", + " res = res2_real + res2_imag2\n", + " size = fft_length[1]//2 + 1\n", + " return res[:, :, :fft_length[-2], :size]\n", + "\n", + "\n", + "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", + "def onnx_rfft_2d_any(x, fft_length=None):\n", + " new_shape = npnx.concat(\n", + " numpy.array([-1], dtype=numpy.int64), x.shape[-2:], axis=0)\n", + " mat2 = x.reshape(new_shape)\n", + " f2 = onnx_rfft_3d_2d(mat2, fft_length)\n", + " new_shape = npnx.concat(\n", + " numpy.array([2], dtype=numpy.int64), x.shape[:-2], f2.shape[-2:])\n", + " return f2.reshape(new_shape)\n", + "\n", + "\n", + "shape = (3, 1, 4)\n", + "fft_length = (1, 4)\n", + "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + "fft2d_cus = fft2d_any(rnd, fft_length)\n", + "fft2d_onx = onnx_rfft_2d_any(rnd, fft_length=fft_length)\n", + "almost_equal(fft2d_cus, fft2d_onx)" + ] + }, + { + "cell_type": "markdown", + "id": "37c45ae7", + "metadata": {}, + "source": [ + "Let's do the same comparison." + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "11c1e596", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 2) output shape=(3, 4) or (2, 3, 1, 2)\n", + "DIS x.shape=(3, 1, 4) length=(1, 1) error=AssertionError('Mismatch max diff=1.0 > 1e-05.') output shape=(3, 4) or (2, 3, 1, 1)\n", + "OK x.shape=(5, 7) length=(5, 7) output shape=(3, 4) or (2, 5, 4)\n", + "OK x.shape=(5, 7) length=(1, 7) output shape=(3, 4) or (2, 1, 4)\n", + "OK x.shape=(5, 7) length=(2, 7) output shape=(3, 4) or (2, 2, 4)\n", + "OK x.shape=(5, 7) length=(5, 2) output shape=(3, 4) or (2, 5, 2)\n", + "OK x.shape=(5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n", + "OK x.shape=(3, 5, 7) length=(5, 7) output shape=(3, 4) or (2, 3, 5, 4)\n", + "OK x.shape=(3, 5, 7) length=(1, 7) output shape=(3, 4) or (2, 3, 1, 4)\n", + "OK x.shape=(3, 5, 7) length=(2, 7) output shape=(3, 4) or (2, 3, 2, 4)\n", + "OK x.shape=(3, 5, 7) length=(5, 2) output shape=(3, 4) or (2, 3, 5, 2)\n", + "OK x.shape=(3, 5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3, 3)\n", + "OK x.shape=(7, 5) length=(7, 5) output shape=(3, 4) or (2, 7, 3)\n", + "OK x.shape=(7, 5) length=(1, 5) output shape=(3, 4) or (2, 1, 3)\n", + "OK x.shape=(7, 5) length=(2, 5) output shape=(3, 4) or (2, 2, 3)\n", + "OK x.shape=(7, 5) length=(7, 2) output shape=(3, 4) or (2, 7, 2)\n", + "OK x.shape=(7, 5) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n" + ] + } + ], + "source": [ + "for shape in [(3, 1, 4), (5, 7), (3, 5, 7), (7, 5)]:\n", + " for fft_length in [shape[-2:], (1, shape[-1]),\n", + " (min(2, shape[-2]), shape[-1]),\n", + " (shape[-2], 2),\n", + " (min(3, shape[-2]), min(4, shape[-2]))]:\n", + " x = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + " if len(fnp.shape) == 2:\n", + " fn= numpy.expand_dims(fnp, 0)\n", + " try:\n", + " cus = fft2d_any(x, fft_length)\n", + " except IndexError as e:\n", + " print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", + " continue\n", + " try:\n", + " onx = onnx_rfft_2d_any(x, fft_length=fft_length)\n", + " except IndexError as e:\n", + " print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", + " continue\n", + " try:\n", + " almost_equal(onx, cus)\n", + " except (AssertionError, IndexError) as e:\n", + " print(\"DIS x.shape=%r length=%r error=%r output shape=%r or %r\" % (\n", + " x.shape, fft_length, e, fnp.shape, cus.shape))\n", + " continue\n", + " print(\"OK x.shape=%r length=%r output shape=%r or %r\" % (\n", + " x.shape, fft_length, fnp.shape, cus.shape))" + ] + }, + { + "cell_type": "markdown", + "id": "d197467f", + "metadata": {}, + "source": [ + "There is one issue with ``fft_length=(1, 1)`` but that case is out of scope." + ] + }, + { + "cell_type": "markdown", + "id": "33b5897e", + "metadata": {}, + "source": [ + "### ONNX graph" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "d45e9a99", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "key = list(onnx_rfft_2d_any.signed_compiled)[0]\n", + "%onnxview onnx_rfft_2d_any.signed_compiled[key].compiled.onnx_" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "2ab7a3d0", + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"fft2d_any.onnx\", \"wb\") as f:\n", + " key = list(onnx_rfft_2d_any.signed_compiled)[0]\n", + " f.write(onnx_rfft_2d_any.signed_compiled[key].compiled.onnx_.SerializeToString())" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "9e5507f7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 } \ No newline at end of file diff --git a/_doc/sphinxdoc/source/api/tools.rst b/_doc/sphinxdoc/source/api/tools.rst index de23a9573..cb631f6e4 100644 --- a/_doc/sphinxdoc/source/api/tools.rst +++ b/_doc/sphinxdoc/source/api/tools.rst @@ -18,6 +18,13 @@ Accessor .. autosignature:: mlprodict.onnx_tools.onnx_tools.insert_node +Export +++++++ + +.. autosignature:: mlprodict.onnx_tools.onnx_export.export2onnx + +.. autosignature:: mlprodict.onnx_tools.onnx_export.export2tf2onnx + Graphs ++++++ @@ -48,6 +55,8 @@ The following functions reduce the number of ONNX operators in a graph while keeping the same results. The optimized graph is left unchanged. +.. autosignature:: mlprodict.onnx_tools.onnx_tools.ensure_topological_order + .. autosignature:: mlprodict.onnx_tools.optim.onnx_optimisation.onnx_remove_node .. autosignature:: mlprodict.onnx_tools.optim.onnx_optimisation_identity.onnx_remove_node_identity @@ -68,15 +77,24 @@ Serialization .. autosignature:: mlprodict.onnx_tools.onnx2py_helper.to_bytes +Runtime +======= + +.. autosignature:: mlprodict.tools.onnx_micro_runtime.OnnxMicroRuntime + Validation ++++++++++ .. autosignature:: mlprodict.onnx_tools.model_checker.onnx_shaker -Runtime -======= +Visualization ++++++++++++++ -.. autosignature:: mlprodict.tools.onnx_micro_runtime.OnnxMicroRuntime +Many times I had to debug and I was thinking about a way to see +a graph in a text editor. That's the goal of this function with +the possibility later to only show a part of a graph. + +.. autosignature:: mlprodict.onnx_tools.graphs.onnx2bigraph Others ====== diff --git a/_doc/sphinxdoc/source/blog/2021/2021-05-05_numpyapionnx1.rst b/_doc/sphinxdoc/source/blog/2021/2021-05-05_numpyapionnx1.rst index b188769df..90df2235b 100644 --- a/_doc/sphinxdoc/source/blog/2021/2021-05-05_numpyapionnx1.rst +++ b/_doc/sphinxdoc/source/blog/2021/2021-05-05_numpyapionnx1.rst @@ -136,7 +136,7 @@ pipe.fit(X_train, y_train) print(pipe.predict_proba(X_test[:2])) - onx = to_onnx(pipe, X_train[:1], rewrite_ops=True, + onx = to_onnx(pipe, X_train[:1], options={LogisticRegression: {'zipmap': False}}) oinf = OnnxInference(onx) print(oinf.run({'X': X_test[:2]})['probabilities']) diff --git a/_unittests/ut_cli/test_cli_onnx_code.py b/_unittests/ut_cli/test_cli_onnx_code.py new file mode 100644 index 000000000..9a29c69e9 --- /dev/null +++ b/_unittests/ut_cli/test_cli_onnx_code.py @@ -0,0 +1,49 @@ +""" +@brief test tree node (time=10s) +""" +import os +import unittest +from pyquickhelper.loghelper import BufferedPrint +from pyquickhelper.pycode import ExtTestCase, get_temp_folder +from mlprodict.__main__ import main + + +class TestCliOnnxCode(ExtTestCase): + + def test_cli_onnx_code(self): + st = BufferedPrint() + main(args=["onnx_code", "--help"], fLOG=st.fprint) + res = str(st) + self.assertIn("verbose", res) + + def test_cli_onnx_code_onnx(self): + temp = get_temp_folder(__file__, "temp_cli_onnx_code_onnx") + name = os.path.join( + temp, "..", "..", "ut_tools", "data", "fft2d_any.onnx") + self.assertExists(name) + output = os.path.join(temp, "code_onnx.py") + st = BufferedPrint() + main(args=["onnx_code", "--filename", name, + "--output", output, "--verbose", "1"], fLOG=st.fprint) + self.assertExists(output) + with open(output, "r", encoding='utf-8') as f: + content = f.read() + self.assertIn("create_model()", content) + + def test_cli_onnx_code_tf2onnx(self): + temp = get_temp_folder(__file__, "temp_cli_onnx_code_tf2onnx") + name = os.path.join( + temp, "..", "..", "ut_tools", "data", "fft2d_any.onnx") + self.assertExists(name) + output = os.path.join(temp, "code_tf2onnx.py") + st = BufferedPrint() + main(args=["onnx_code", "--filename", name, '--format', 'tf2onnx', + "--output", output, "--verbose", "1"], fLOG=st.fprint) + self.assertExists(output) + with open(output, "r", encoding='utf-8') as f: + content = f.read() + self.assertIn("tf_op", content) + + +if __name__ == "__main__": + unittest.main() diff --git a/_unittests/ut_npy/test_onnx_variable.py b/_unittests/ut_npy/test_onnx_variable.py index aa585af26..852f518ee 100644 --- a/_unittests/ut_npy/test_onnx_variable.py +++ b/_unittests/ut_npy/test_onnx_variable.py @@ -102,6 +102,13 @@ def test_abs_matmul(x: NDArray[Any, numpy.float32], return nxnp.abs(x) @ x +@onnxnumpy_default +def test_abs_matmul2(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy addition" + return nxnp.matmul(nxnp.abs(x), x) + + @onnxnumpy_default def test_abs_div(x: NDArray[Any, numpy.float32], ) -> NDArray[Any, numpy.float32]: @@ -515,6 +522,11 @@ def test_py_abs_matmul(self): y = test_abs_matmul(x) self.assertEqualArray(y, numpy.abs(x) @ x) + def test_py_abs_matmul2(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_matmul2(x) + self.assertEqualArray(y, numpy.abs(x) @ x) + def test_py_abs_div(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_div(x) diff --git a/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm.py b/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm.py index 22d23b7bb..c0db8a11f 100644 --- a/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm.py +++ b/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm.py @@ -273,6 +273,7 @@ def test_onnxrt_python_lightgbm_categorical_iris_dataframe(self): values = pandas.DataFrame(got['output_probability']).values self.assertEqualArray(exp, values[:, 1], decimal=5) + @skipif_circleci('stuck') @unittest.skipIf(sys.platform == 'darwin', 'stuck') def test_lightgbm_booster_classifier(self): from lightgbm import Dataset, train as lgb_train diff --git a/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py b/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py new file mode 100644 index 000000000..04110dcaa --- /dev/null +++ b/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py @@ -0,0 +1,134 @@ +""" +@brief test log(time=3s) +""" +import sys +import unittest +from logging import getLogger +import numpy +from pyquickhelper.pycode import ExtTestCase, skipif_circleci +from skl2onnx.common.data_types import FloatTensorType +from mlprodict.onnxrt import OnnxInference +from mlprodict.onnx_conv import register_converters, to_onnx + + +class TestOnnxrtRuntimeLightGbmBug(ExtTestCase): + + def setUp(self): + logger = getLogger('skl2onnx') + logger.disabled = True + register_converters() + + @skipif_circleci('stuck') + @unittest.skipIf(sys.platform == 'darwin', 'stuck') + def test_lightgbm_regressor(self): + from lightgbm import LGBMRegressor + try: + from onnxmltools.convert import convert_lightgbm + except ImportError: + convert_lightgbm = None + + X = numpy.abs(numpy.random.randn(7, 227)).astype(numpy.float32) + y = X.sum(axis=1) + numpy.random.randn( + X.shape[0]).astype(numpy.float32) / 10 + model = LGBMRegressor( + max_depth=8, n_estimators=100, min_child_samples=1, + learning_rate=0.0000001) + model.fit(X, y) + expected = model.predict(X) + + model_onnx = to_onnx(model, X) + if convert_lightgbm is not None: + model_onnx2 = convert_lightgbm( + model, initial_types=[('X', FloatTensorType([None, 227]))]) + else: + model_onnx2 = None + + for i, mo in enumerate([model_onnx, model_onnx2]): + if mo is None: + continue + for rt in ['python', 'onnxruntime1']: + with self.subTest(i=i, rt=rt): + oinf = OnnxInference(mo, runtime=rt) + got = oinf.run({'X': X})['variable'] + diff = numpy.abs(got.ravel() - expected.ravel()).max() + if __name__ == "__main__": + print("lgb", i, rt, diff) + self.assertLess(diff, 1e-3) + + @skipif_circleci('stuck') + @unittest.skipIf(sys.platform == 'darwin', 'stuck') + def test_lightgbm_regressor_double(self): + from lightgbm import LGBMRegressor + + X = numpy.abs(numpy.random.randn(7, 227)).astype(numpy.float32) + y = X.sum(axis=1) + numpy.random.randn( + X.shape[0]).astype(numpy.float32) / 10 + model = LGBMRegressor( + max_depth=8, n_estimators=100, min_child_samples=1, + learning_rate=0.0000001) + model.fit(X, y) + expected = model.predict(X) + + model_onnx = to_onnx(model, X, rewrite_ops=True) + model_onnx2 = to_onnx(model, X.astype(numpy.float64), + rewrite_ops=True) + + for i, mo in enumerate([model_onnx, model_onnx2]): + for rt in ['python', 'onnxruntime1']: + if "TreeEnsembleRegressorDouble" in str(mo): + x = X.astype(numpy.float64) + if rt == 'onnxruntime1': + continue + else: + x = X + with self.subTest(i=i, rt=rt): + oinf = OnnxInference(mo, runtime=rt) + got = oinf.run({'X': x})['variable'] + diff = numpy.abs(got.ravel() - expected.ravel()).max() + if __name__ == "__main__": + print("lgbd", i, rt, diff) + if i == 1 and rt == 'python': + self.assertLess(diff, 1e-5) + else: + self.assertLess(diff, 1e-3) + + @skipif_circleci('stuck') + @unittest.skipIf(sys.platform == 'darwin', 'stuck') + def test_xgboost_regressor(self): + from xgboost import XGBRegressor + try: + from onnxmltools.convert import convert_xgboost + except ImportError: + convert_xgboost = None + + X = numpy.abs(numpy.random.randn(7, 227)).astype(numpy.float32) + y = X.sum(axis=1) + numpy.random.randn( + X.shape[0]).astype(numpy.float32) / 10 + model = XGBRegressor( + max_depth=8, n_estimators=100, + learning_rate=0.000001) + model.fit(X, y) + expected = model.predict(X) + + model_onnx = to_onnx(model, X) + if convert_xgboost is not None: + model_onnx2 = convert_xgboost( + model, initial_types=[('X', FloatTensorType([None, 227]))]) + else: + model_onnx2 = None + + for i, mo in enumerate([model_onnx, model_onnx2]): + if mo is None: + continue + for rt in ['python', 'onnxruntime1']: + with self.subTest(i=i, rt=rt): + oinf = OnnxInference(mo, runtime=rt) + got = oinf.run({'X': X})['variable'] + diff = numpy.abs(got.ravel() - expected.ravel()).max() + if __name__ == "__main__": + print("xgb", i, rt, diff) + self.assertLess(diff, 1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_custom.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_custom.py index cf229f3ab..b5bb0b926 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_custom.py +++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_custom.py @@ -98,7 +98,8 @@ def test_onnxt_runtime_fft(self): if dim == 1: X = numpy.arange(16).astype(numpy.float32) elif dim == 2: - X = numpy.arange(48).astype(numpy.float32).reshape((3, -1)) + X = numpy.arange(48).astype( + numpy.float32).reshape((3, -1)) Y = numpy.fft.fft(X.astype(numpy.float32), axis=axis) onx = OnnxFFT('X', output_names=['Y'], @@ -125,7 +126,8 @@ def test_onnxt_runtime_fft(self): if dim == 1: X = numpy.arange(16).astype(numpy.float32) elif dim == 2: - X = numpy.arange(48).astype(numpy.float32).reshape((3, -1)) + X = numpy.arange(48).astype( + numpy.float32).reshape((3, -1)) Y = numpy.fft.fft(X.astype(numpy.float32), 8, axis=axis) onx = OnnxFFT('X', numpy.array([8], dtype=numpy.int64), @@ -155,7 +157,8 @@ def test_onnxt_runtime_rfft(self): if dim == 1: X = numpy.arange(16).astype(numpy.float32) elif dim == 2: - X = numpy.arange(48).astype(numpy.float32).reshape((3, -1)) + X = numpy.arange(48).astype( + numpy.float32).reshape((3, -1)) Y = numpy.fft.rfft(X.astype(numpy.float32), axis=axis) onx = OnnxRFFT('X', output_names=['Y'], @@ -182,7 +185,8 @@ def test_onnxt_runtime_rfft(self): if dim == 1: X = numpy.arange(16).astype(numpy.float32) elif dim == 2: - X = numpy.arange(48).astype(numpy.float32).reshape((3, -1)) + X = numpy.arange(48).astype( + numpy.float32).reshape((3, -1)) Y = numpy.fft.rfft(X.astype(numpy.float32), 8, axis=axis) onx = OnnxRFFT('X', numpy.array([8], dtype=numpy.int64), @@ -210,7 +214,8 @@ def test_onnxt_runtime_fft2d(self): if dim == 1: X = numpy.arange(16).astype(numpy.float32) elif dim == 2: - X = numpy.arange(48).astype(numpy.float32).reshape((3, -1)) + X = numpy.arange(48).astype( + numpy.float32).reshape((3, -1)) Y = numpy.fft.fft2(X.astype(numpy.float32), axes=axis) if axis is not None: @@ -239,8 +244,10 @@ def test_onnxt_runtime_fft2d(self): if dim == 1: X = numpy.arange(16).astype(numpy.float32) elif dim == 2: - X = numpy.arange(48).astype(numpy.float32).reshape((3, -1)) - Y = numpy.fft.fft2(X.astype(numpy.float32), (8, 8), axes=axis) + X = numpy.arange(48).astype( + numpy.float32).reshape((3, -1)) + Y = numpy.fft.fft2( + X.astype(numpy.float32), (8, 8), axes=axis) if axis is not None: onx = OnnxFFT2D('X', numpy.array([8, 8], dtype=numpy.int64), diff --git a/_unittests/ut_tools/data/fft2d_any.onnx b/_unittests/ut_tools/data/fft2d_any.onnx new file mode 100644 index 000000000..0a868c71f Binary files /dev/null and b/_unittests/ut_tools/data/fft2d_any.onnx differ diff --git a/_unittests/ut_tools/test_export_onnx.py b/_unittests/ut_tools/test_export_onnx.py new file mode 100644 index 000000000..331b2627a --- /dev/null +++ b/_unittests/ut_tools/test_export_onnx.py @@ -0,0 +1,636 @@ +""" +@brief test log(time=5s) +""" +import os +import unittest +import collections +import inspect +from io import StringIO +from contextlib import redirect_stdout, redirect_stderr +import numpy +from onnx import numpy_helper, helper +from onnx.helper import ( + make_model, make_node, set_model_props, make_tensor, make_graph, + make_tensor_value_info) +from pyquickhelper.pycode import ExtTestCase +from mlprodict.onnx_tools.onnx_export import export2onnx, export2tf2onnx +from mlprodict.testing.verify_code import verify_code +from mlprodict.onnxrt import OnnxInference +from mlprodict.onnx_tools.exports.tf2onnx_helper import make_sure, make_name +from mlprodict.tools.code_helper import print_code + + +class ConvertFFT2DOp: + + supported_dtypes = [ + numpy.float32, + ] + + @classmethod + def any_version(cls, opset, ctx, node, **kwargs): # pylint: disable=R0915 + ''' + Converter for ``FFT2D``. + + * producer: skl2onnx + * version: 0 + * description: + ''' + oldnode = node + input_name = node.input[0] + onnx_dtype = ctx.get_dtype(input_name) + make_sure(onnx_dtype in ConvertFFT2DOp.supported_dtypes, + "Unsupported input type.") + vars = {x: x for x in node.input} # pylint: disable=W0622 + + # initializers + if getattr(ctx, 'verbose', False): + print('[initializers] %r' % cls) + + list_value = [1.0, 0.0] + value = numpy.array(list_value, dtype=numpy.float32).reshape((2, 1, 1)) + + r_Un_Unsqueezecst = ctx.make_const( + name=make_name('init_Un_Unsqueezecst'), np_val=value) + vars['Un_Unsqueezecst'] = r_Un_Unsqueezecst.name + + list_value = [0] + value = numpy.array(list_value, dtype=numpy.int64) + + r_Un_Unsqueezecst1 = ctx.make_const( + name=make_name('init_Un_Unsqueezecst1'), np_val=value) + vars['Un_Unsqueezecst1'] = r_Un_Unsqueezecst1.name + + list_value = [1.0, 1.0, 1.0, 1.0, 1.0, 6.123234262925839e-17, + -1.0, -1.8369701465288538e-16, 1.0, -1.0, 1.0, -1.0, 1.0, + -1.8369701465288538e-16, -1.0, 5.510910704284357e-16, 0.0, + 0.0, 0.0, 0.0, 0.0, -1.0, -1.2246468525851679e-16, 1.0, 0.0, + -1.2246468525851679e-16, 2.4492937051703357e-16, + -3.6739402930577075e-16, 0.0, 1.0, -3.6739402930577075e-16, -1.0] + value = numpy.array(list_value, dtype=numpy.float32).reshape((2, 4, 4)) + + r_Un_Unsqueezecst2 = ctx.make_const( + name=make_name('init_Un_Unsqueezecst2'), np_val=value) + vars['Un_Unsqueezecst2'] = r_Un_Unsqueezecst2.name + + list_value = [-1] + value = numpy.array(list_value, dtype=numpy.int64) + + r_Co_Concatcst = ctx.make_const( + name=make_name('init_Co_Concatcst'), np_val=value) + vars['Co_Concatcst'] = r_Co_Concatcst.name + + list_value = [-2] + value = numpy.array(list_value, dtype=numpy.int64) + + r_Sl_Slicecst = ctx.make_const( + name=make_name('init_Sl_Slicecst'), np_val=value) + vars['Sl_Slicecst'] = r_Sl_Slicecst.name + + value = numpy.array(0, dtype=numpy.int64) + + r_Ga_Gathercst = ctx.make_const( + name=make_name('init_Ga_Gathercst'), np_val=value) + vars['Ga_Gathercst'] = r_Ga_Gathercst.name + + list_value = [0, 0] + value = numpy.array(list_value, dtype=numpy.int64) + + r_Sl_Slicecst2 = ctx.make_const( + name=make_name('init_Sl_Slicecst2'), np_val=value) + vars['Sl_Slicecst2'] = r_Sl_Slicecst2.name + + list_value = [1, 4] + value = numpy.array(list_value, dtype=numpy.int64) + + r_Sl_Slicecst3 = ctx.make_const( + name=make_name('init_Sl_Slicecst3'), np_val=value) + vars['Sl_Slicecst3'] = r_Sl_Slicecst3.name + + list_value = [1, 2] + value = numpy.array(list_value, dtype=numpy.int64) + + r_Sl_Slicecst4 = ctx.make_const( + name=make_name('init_Sl_Slicecst4'), np_val=value) + vars['Sl_Slicecst4'] = r_Sl_Slicecst4.name + + list_value = [4] + value = numpy.array(list_value, dtype=numpy.int64) + + r_Sl_Slicecst6 = ctx.make_const( + name=make_name('init_Sl_Slicecst6'), np_val=value) + vars['Sl_Slicecst6'] = r_Sl_Slicecst6.name + + list_value = [1] + value = numpy.array(list_value, dtype=numpy.int64) + + r_Sl_Slicecst7 = ctx.make_const( + name=make_name('init_Sl_Slicecst7'), np_val=value) + vars['Sl_Slicecst7'] = r_Sl_Slicecst7.name + + list_value = [3] + value = numpy.array(list_value, dtype=numpy.int64) + + r_Sl_Slicecst9 = ctx.make_const( + name=make_name('init_Sl_Slicecst9'), np_val=value) + vars['Sl_Slicecst9'] = r_Sl_Slicecst9.name + + value = numpy.array(1, dtype=numpy.int64) + + r_Ga_Gathercst2 = ctx.make_const( + name=make_name('init_Ga_Gathercst2'), np_val=value) + vars['Ga_Gathercst2'] = r_Ga_Gathercst2.name + + list_value = [2] + value = numpy.array(list_value, dtype=numpy.int64) + + r_Sl_Slicecst18 = ctx.make_const( + name=make_name('init_Sl_Slicecst18'), np_val=value) + vars['Sl_Slicecst18'] = r_Sl_Slicecst18.name + + list_value = [1, 3] + value = numpy.array(list_value, dtype=numpy.int64) + + r_Sl_Slicecst24 = ctx.make_const( + name=make_name('init_Sl_Slicecst24'), np_val=value) + vars['Sl_Slicecst24'] = r_Sl_Slicecst24.name + + list_value = [2, 3] + value = numpy.array(list_value, dtype=numpy.int64) + + r_Sl_Slicecst25 = ctx.make_const( + name=make_name('init_Sl_Slicecst25'), np_val=value) + vars['Sl_Slicecst25'] = r_Sl_Slicecst25.name + + # nodes + if getattr(ctx, 'verbose', False): + print('[nodes] %r' % cls) + + attr = dict() + inputs = [vars['Un_Unsqueezecst'], vars['Un_Unsqueezecst1'], ] + node = ctx.make_node( + 'Unsqueeze', inputs=inputs, attr=attr, + name=make_name('Un_Unsqueeze')) + vars['Un_expanded0'] = node.output[0] + + attr = dict() + inputs = [vars['Un_Unsqueezecst2'], vars['Un_Unsqueezecst1'], ] + node = ctx.make_node( + 'Unsqueeze', inputs=inputs, attr=attr, + name=make_name('Un_Unsqueeze1')) + vars['Un_expanded03'] = node.output[0] + + attr = dict() + inputs = [vars['x'], ] + node = ctx.make_node( + 'Shape', inputs=inputs, attr=attr, + name=make_name('Sh_Shape')) + vars['Sh_shape0'] = node.output[0] + + attr = dict() + inputs = [vars['Sh_shape0'], ] + node = ctx.make_node( + 'Shape', inputs=inputs, attr=attr, + name=make_name('Sh_Shape1')) + vars['Sh_shape01'] = node.output[0] + + attr = dict(axis=0,) + inputs = [vars['Sh_shape01'], vars['Ga_Gathercst'], ] + node = ctx.make_node( + 'Gather', inputs=inputs, attr=attr, + name=make_name('Ga_Gather')) + vars['Ga_output01'] = node.output[0] + + attr = dict() + inputs = [vars['Ga_output01'], vars['Un_Unsqueezecst1'], ] + node = ctx.make_node( + 'Unsqueeze', inputs=inputs, attr=attr, + name=make_name('Un_Unsqueeze2')) + vars['Un_expanded05'] = node.output[0] + + attr = dict(axis=0,) + inputs = [vars['Un_expanded05'], ] + node = ctx.make_node( + 'Concat', inputs=inputs, attr=attr, + name=make_name('Co_Concat')) + vars['Co_concat_result01'] = node.output[0] + + attr = dict() + inputs = [vars['Sh_shape0'], vars['Sl_Slicecst'], + vars['Co_concat_result01'], vars['Un_Unsqueezecst1'], ] + node = ctx.make_node( + 'Slice', inputs=inputs, attr=attr, + name=make_name('Sl_Slice')) + vars['Sl_output05'] = node.output[0] + + attr = dict(axis=0,) + inputs = [vars['Co_Concatcst'], vars['Sl_output05'], ] + node = ctx.make_node( + 'Concat', inputs=inputs, attr=attr, + name=make_name('Co_Concat1')) + vars['Co_concat_result0'] = node.output[0] + + attr = dict() + inputs = [vars['x'], vars['Co_concat_result0'], ] + node = ctx.make_node( + 'Reshape', inputs=inputs, attr=attr, + name=make_name('Re_Reshape')) + vars['Re_reshaped0'] = node.output[0] + + attr = dict() + inputs = [vars['Re_reshaped0'], vars['Sl_Slicecst2'], + vars['Sl_Slicecst3'], vars['Sl_Slicecst4'], ] + node = ctx.make_node( + 'Slice', inputs=inputs, attr=attr, + name=make_name('Sl_Slice1')) + vars['Sl_output04'] = node.output[0] + + attr = dict(perm=[0, 2, 1],) + inputs = [vars['Sl_output04'], ] + node = ctx.make_node( + 'Transpose', inputs=inputs, attr=attr, + name=make_name('Tr_Transpose')) + vars['Tr_transposed02'] = node.output[0] + + attr = dict() + inputs = [vars['Tr_transposed02'], vars['Un_Unsqueezecst1'], + vars['Sl_Slicecst6'], vars['Sl_Slicecst7'], ] + node = ctx.make_node( + 'Slice', inputs=inputs, attr=attr, + name=make_name('Sl_Slice2')) + vars['Sl_output03'] = node.output[0] + + attr = dict() + inputs = [vars['Sl_output03'], vars['Sl_Slicecst7'], ] + node = ctx.make_node( + 'Unsqueeze', inputs=inputs, attr=attr, + name=make_name('Un_Unsqueeze3')) + vars['Un_expanded04'] = node.output[0] + + attr = dict() + inputs = [vars['Un_expanded03'], vars['Un_expanded04'], ] + node = ctx.make_node( + 'MatMul', inputs=inputs, attr=attr, + name=make_name('Ma_MatMul')) + vars['Ma_Y01'] = node.output[0] + + attr = dict() + inputs = [vars['Ma_Y01'], vars['Un_Unsqueezecst1'], + vars['Sl_Slicecst9'], vars['Sl_Slicecst7'], ] + node = ctx.make_node( + 'Slice', inputs=inputs, attr=attr, + name=make_name('Sl_Slice3')) + vars['Sl_output02'] = node.output[0] + + attr = dict(perm=[1, 0, 3, 2],) + inputs = [vars['Sl_output02'], ] + node = ctx.make_node( + 'Transpose', inputs=inputs, attr=attr, + name=make_name('Tr_Transpose1')) + vars['Tr_transposed01'] = node.output[0] + + attr = dict(axis=0,) + inputs = [vars['Tr_transposed01'], vars['Ga_Gathercst'], ] + node = ctx.make_node( + 'Gather', inputs=inputs, attr=attr, + name=make_name('Ga_Gather1')) + vars['Ga_output0'] = node.output[0] + + attr = dict() + inputs = [vars['Ga_output0'], vars['Un_Unsqueezecst1'], + vars['Sl_Slicecst7'], vars['Sl_Slicecst7'], ] + node = ctx.make_node( + 'Slice', inputs=inputs, attr=attr, + name=make_name('Sl_Slice4')) + vars['Sl_output01'] = node.output[0] + + attr = dict() + inputs = [vars['Sl_output01'], vars['Sl_Slicecst7'], ] + node = ctx.make_node( + 'Unsqueeze', inputs=inputs, attr=attr, + name=make_name('Un_Unsqueeze4')) + vars['Un_expanded02'] = node.output[0] + + attr = dict() + inputs = [vars['Un_expanded0'], vars['Un_expanded02'], ] + node = ctx.make_node( + 'MatMul', inputs=inputs, attr=attr, + name=make_name('Ma_MatMul1')) + vars['Ma_Y0'] = node.output[0] + + attr = dict(perm=[1, 0, 2, 3],) + inputs = [vars['Ma_Y0'], ] + node = ctx.make_node( + 'Transpose', inputs=inputs, attr=attr, + name=make_name('Tr_Transpose2')) + vars['Tr_transposed0'] = node.output[0] + + attr = dict(axis=0,) + inputs = [vars['Tr_transposed01'], vars['Ga_Gathercst2'], ] + node = ctx.make_node( + 'Gather', inputs=inputs, attr=attr, + name=make_name('Ga_Gather2')) + vars['Ga_output03'] = node.output[0] + + attr = dict() + inputs = [vars['Ga_output03'], vars['Un_Unsqueezecst1'], + vars['Sl_Slicecst7'], vars['Sl_Slicecst7'], ] + node = ctx.make_node( + 'Slice', inputs=inputs, attr=attr, + name=make_name('Sl_Slice5')) + vars['Sl_output07'] = node.output[0] + + attr = dict() + inputs = [vars['Sl_output07'], vars['Sl_Slicecst7'], ] + node = ctx.make_node( + 'Unsqueeze', inputs=inputs, attr=attr, + name=make_name('Un_Unsqueeze6')) + vars['Un_expanded07'] = node.output[0] + + attr = dict() + inputs = [vars['Un_expanded0'], vars['Un_expanded07'], ] + node = ctx.make_node( + 'MatMul', inputs=inputs, attr=attr, + name=make_name('Ma_MatMul2')) + vars['Ma_Y03'] = node.output[0] + + attr = dict(perm=[1, 0, 2, 3],) + inputs = [vars['Ma_Y03'], ] + node = ctx.make_node( + 'Transpose', inputs=inputs, attr=attr, + name=make_name('Tr_Transpose3')) + vars['Tr_transposed04'] = node.output[0] + + attr = dict() + inputs = [vars['Tr_transposed04'], vars['Sl_Slicecst7'], + vars['Sl_Slicecst18'], vars['Un_Unsqueezecst1'], ] + node = ctx.make_node( + 'Slice', inputs=inputs, attr=attr, + name=make_name('Sl_Slice6')) + vars['Sl_output06'] = node.output[0] + + attr = dict() + inputs = [vars['Sl_output06'], ] + node = ctx.make_node( + 'Neg', inputs=inputs, attr=attr, + name=make_name('Ne_Neg')) + vars['Ne_Y0'] = node.output[0] + + attr = dict() + inputs = [vars['Tr_transposed04'], vars['Un_Unsqueezecst1'], + vars['Sl_Slicecst7'], vars['Un_Unsqueezecst1'], ] + node = ctx.make_node( + 'Slice', inputs=inputs, attr=attr, + name=make_name('Sl_Slice7')) + vars['Sl_output08'] = node.output[0] + + attr = dict(axis=0,) + inputs = [vars['Ne_Y0'], vars['Sl_output08'], ] + node = ctx.make_node( + 'Concat', inputs=inputs, attr=attr, + name=make_name('Co_Concat2')) + vars['Co_concat_result03'] = node.output[0] + + attr = dict() + inputs = [vars['Tr_transposed0'], vars['Co_concat_result03'], ] + node = ctx.make_node( + 'Add', inputs=inputs, attr=attr, + name=make_name('Ad_Add')) + vars['Ad_C0'] = node.output[0] + + attr = dict() + inputs = [vars['Ad_C0'], vars['Sl_Slicecst2'], + vars['Sl_Slicecst24'], vars['Sl_Slicecst25'], ] + node = ctx.make_node( + 'Slice', inputs=inputs, attr=attr, + name=make_name('Sl_Slice8')) + vars['Sl_output0'] = node.output[0] + + attr = dict() + inputs = [vars['Sh_shape0'], vars['Un_Unsqueezecst1'], + vars['Sl_Slicecst'], vars['Un_Unsqueezecst1'], ] + node = ctx.make_node( + 'Slice', inputs=inputs, attr=attr, + name=make_name('Sl_Slice9')) + vars['Sl_output010'] = node.output[0] + + attr = dict() + inputs = [vars['Sl_output0'], ] + node = ctx.make_node( + 'Shape', inputs=inputs, attr=attr, + name=make_name('Sh_Shape3')) + vars['Sh_shape03'] = node.output[0] + + attr = dict() + inputs = [vars['Sh_shape03'], ] + node = ctx.make_node( + 'Shape', inputs=inputs, attr=attr, + name=make_name('Sh_Shape4')) + vars['Sh_shape04'] = node.output[0] + + attr = dict(axis=0,) + inputs = [vars['Sh_shape04'], vars['Ga_Gathercst'], ] + node = ctx.make_node( + 'Gather', inputs=inputs, attr=attr, + name=make_name('Ga_Gather3')) + vars['Ga_output04'] = node.output[0] + + attr = dict() + inputs = [vars['Ga_output04'], vars['Un_Unsqueezecst1'], ] + node = ctx.make_node( + 'Unsqueeze', inputs=inputs, attr=attr, + name=make_name('Un_Unsqueeze7')) + vars['Un_expanded08'] = node.output[0] + + attr = dict(axis=0,) + inputs = [vars['Un_expanded08'], ] + node = ctx.make_node( + 'Concat', inputs=inputs, attr=attr, + name=make_name('Co_Concat3')) + vars['Co_concat_result05'] = node.output[0] + + attr = dict() + inputs = [vars['Sh_shape03'], vars['Sl_Slicecst'], + vars['Co_concat_result05'], vars['Un_Unsqueezecst1'], ] + node = ctx.make_node( + 'Slice', inputs=inputs, attr=attr, + name=make_name('Sl_Slice10')) + vars['Sl_output012'] = node.output[0] + + attr = dict(axis=0,) + inputs = [vars['Sl_Slicecst18'], + vars['Sl_output010'], vars['Sl_output012'], ] + node = ctx.make_node( + 'Concat', inputs=inputs, attr=attr, + name=make_name('Co_Concat4')) + vars['Co_concat_result04'] = node.output[0] + + attr = dict() + inputs = [vars['Sl_output0'], vars['Co_concat_result04'], ] + node = ctx.make_node( + 'Reshape', inputs=inputs, attr=attr, + name=make_name('Re_Reshape1')) + vars['y'] = node.output[0] + + # finalize + if getattr(ctx, 'verbose', False): + print('[replace_all_inputs] %r' % cls) + ctx.replace_all_inputs(oldnode.output[0], node.output[0]) + ctx.remove_node(oldnode.name) + + @classmethod + def version_13(cls, ctx, node, **kwargs): + return cls.any_version(13, ctx, node, **kwargs) + + +class TestExportOnnx(ExtTestCase): + + def verify(self, content): + try: + left, __ = verify_code(content, exc=False) + except SyntaxError as e: + raise AssertionError( + "Unable to analyse a script due to %r. " + "\n--CODE--\n%s" + "" % (e, content)) from e + + # execution + try: + obj = compile(content, '', 'exec') + except SyntaxError as e: + raise AssertionError( + "Unable to compile a script due to %r. " + "\n--CODE--\n%s" + "" % (e, print_code(content))) from e + glo = globals().copy() + loc = {'numpy_helper': numpy_helper, + 'make_model': make_model, + 'make_node': make_node, + 'set_model_props': set_model_props, + 'make_tensor': make_tensor, + 'make_graph': make_graph, + 'make_tensor_value_info': make_tensor_value_info, + 'print': print, 'sorted': sorted, + 'collections': collections, 'inspect': inspect} + out = StringIO() + err = StringIO() + if len(left) >= 5: + raise AssertionError( + "Too many unknown symbols: %r." % left) + + with redirect_stdout(out): + with redirect_stderr(err): + try: + exec(obj, glo, loc) # pylint: disable=W0122 + except Exception as e: + raise AssertionError( + "Unable to execute a script due to %r. " + "\n--OUT--\n%s\n--ERR--\n%s\n--CODE--\n%s" + "" % (e, out.getvalue(), err.getvalue(), + print_code(content))) from e + return glo, loc + + def test_export_onnx(self): + this = os.path.dirname(__file__) + folder = os.path.join(this, "data") + names = ["fft2d_any.onnx"] + for name in names: + with self.subTest(name=name): + oinf0 = OnnxInference(os.path.join(folder, name)) + + x = numpy.random.randn(3, 1, 4).astype(numpy.float32) + y = oinf0.run({'x': x}) + + new_onnx = export2onnx( + os.path.join(folder, name), name="FFT2D") + _, loc = self.verify(new_onnx) + model = loc['onnx_model'] + oinf = OnnxInference(model) + y1 = oinf.run({'x': x}) + + new_onnx = export2onnx( + os.path.join(folder, name), verbose=False) + _, loc = self.verify(new_onnx) + model = loc['onnx_model'] + oinf = OnnxInference(model) + y2 = oinf.run({'x': x}) + + self.assertEqualArray(y['y'], y1['y']) + self.assertEqualArray(y['y'], y2['y']) + + def verify_tf(self, content): + try: + left, __ = verify_code(content, exc=False) + except SyntaxError as e: + raise AssertionError( + "Unable to analyse a script due to %r. " + "\n--CODE--\n%s" + "" % (e, content)) from e + + # execution + try: + obj = compile(content, '', 'exec') + except SyntaxError as e: + raise AssertionError( + "Unable to compile a script due to %r. " + "\n--CODE--\n%s" + "" % (e, print_code(content))) from e + glo = globals().copy() + loc = {'numpy': numpy, 'dict': dict, 'list': list, + 'print': print, 'sorted': sorted, + 'collections': collections, 'inspect': inspect, + 'helper': helper, "make_sure": make_sure, + 'ConvertFFT2DOp': ConvertFFT2DOp, "make_name": make_name} + out = StringIO() + err = StringIO() + if len(left) >= 14: + raise AssertionError( + "Too many unknown symbols: %r." % left) + + with redirect_stdout(out): + with redirect_stderr(err): + try: + exec(obj, glo, loc) # pylint: disable=W0122 + except Exception as e: + raise AssertionError( + "Unable to execute a script due to %r. " + "\n--OUT--\n%s\n--ERR--\n%s\n--CODE--\n%s" + "" % (e, out.getvalue(), err.getvalue(), + print_code(content))) from e + return glo, loc + + def test_export2tf2onnx(self): + this = os.path.dirname(__file__) + folder = os.path.join(this, "data") + names = ["fft2d_any.onnx"] + for name in names: + with self.subTest(name=name): + oinf0 = OnnxInference(os.path.join(folder, name)) + + x = numpy.random.randn(3, 1, 4).astype(numpy.float32) + y = oinf0.run({'x': x}) + + new_onnx = export2tf2onnx( + os.path.join(folder, name), name="FFT2D") + _, loc = self.verify_tf(new_onnx) + model = loc['onnx_raw'] + self.assertIn('op_type: "FFT2D"', str(model)) + model = loc['onnx_model'] + self.assertNotIn('op_type: "FFT2D"', str(model)) + + oinf = OnnxInference(model) + y1 = oinf.run({'x': x}) + + new_onnx = export2tf2onnx( + os.path.join(folder, name), name="FFT2D") + _, loc = self.verify_tf(new_onnx) + model = loc['onnx_model'] + self.assertNotIn('op_type: "FFT2D"', str(model)) + oinf = OnnxInference(model) + y2 = oinf.run({'x': x}) + + self.assertEqualArray(y['y'], y1['y']) + self.assertEqualArray(y['y'], y2['y']) + + +if __name__ == "__main__": + unittest.main() diff --git a/mlprodict/__main__.py b/mlprodict/__main__.py index b38a90994..76e264e3c 100644 --- a/mlprodict/__main__.py +++ b/mlprodict/__main__.py @@ -22,6 +22,7 @@ def main(args, fLOG=print): from .cli.asv2csv import asv2csv from .cli.replay import benchmark_replay from .cli.einsum import einsum_test + from .cli.onnx_code import onnx_code except ImportError: # pragma: no cover from mlprodict.cli.validate import validate_runtime from mlprodict.cli.convert_validate import convert_validate @@ -30,6 +31,7 @@ def main(args, fLOG=print): from mlprodict.cli.asv2csv import asv2csv from mlprodict.cli.replay import benchmark_replay from mlprodict.cli.einsum import einsum_test + from mlprodict.cli.onnx_code import onnx_code fcts = dict(validate_runtime=validate_runtime, convert_validate=convert_validate, @@ -38,7 +40,8 @@ def main(args, fLOG=print): asv_bench=asv_bench, asv2csv=asv2csv, benchmark_replay=benchmark_replay, - einsum_test=einsum_test) + einsum_test=einsum_test, + onnx_code=onnx_code) try: from pyquickhelper.cli import cli_main_helper except ImportError: # pragma: no cover diff --git a/mlprodict/cli/__init__.py b/mlprodict/cli/__init__.py index 43fa491ba..5596858f9 100644 --- a/mlprodict/cli/__init__.py +++ b/mlprodict/cli/__init__.py @@ -4,5 +4,6 @@ """ from .convert_validate import convert_validate from .einsum import einsum_test +from .onnx_code import onnx_code from .optimize import onnx_optim from .validate import validate_runtime diff --git a/mlprodict/cli/onnx_code.py b/mlprodict/cli/onnx_code.py new file mode 100644 index 000000000..8832b97fa --- /dev/null +++ b/mlprodict/cli/onnx_code.py @@ -0,0 +1,57 @@ +""" +@file +@brief Command line to check einsum scenarios. +""" + + +def onnx_code(filename, format="onnx", output=None, verbose=0, name=None, + opset=None, fLOG=print): + """ + Exports an ONNX graph into a python code creating + the same graph. + + :param filename: onnx file + :param format: format to export too (`onnx`, `tf2onnx`) + :param output: output file to produce or None to print it on stdout + :param verbose: verbosity level + :param name: rewrite the graph name + :param opset: overwrite the opset (may not works depending on the format) + :param fLOG: logging function + + .. cmdref:: + :title: Exports an ONNX graph into a python code creating the same graph. + :cmd: -m mlprodict onnx_code --help + :lid: l-cmd-onnx_code + + The command pr + + Example:: + + python -m mlprodict onnx_code --filename="something.onnx" --format=onnx + """ + from ..onnx_tools.onnx_export import export2onnx, export2tf2onnx # pylint: disable=E0402 + + if name == '': + name = None + if opset == '': + opset = None + try: + v = int(opset) + opset = v + except (ValueError, TypeError): + opset = None + + if format == 'onnx': + code = export2onnx(filename, verbose=verbose, name=name, opset=opset) + elif format == 'tf2onnx': + code = export2tf2onnx(filename, verbose=verbose, + name=name, opset=opset) + else: + raise ValueError( # pragma: no cover + "Unknown format %r." % format) + + if output not in ('', None): + with open(output, "w", encoding="utf-8") as f: + f.write(code) + else: + fLOG(code) diff --git a/mlprodict/npy/numpy_onnx_impl.py b/mlprodict/npy/numpy_onnx_impl.py index 7fba336e2..cab5ec8b8 100644 --- a/mlprodict/npy/numpy_onnx_impl.py +++ b/mlprodict/npy/numpy_onnx_impl.py @@ -3,7 +3,10 @@ @brief :epkg:`numpy` functions implemented with :epkg:`onnx`. .. versionadded:: 0.6 + +.. versionchanged:: 0.7 """ +import warnings import numpy from onnx import onnx_pb as onnx_proto # pylint: disable=E1101 from onnx.helper import make_tensor @@ -221,7 +224,15 @@ def det(x): def dot(a, b): - "See :epkg:`numpy:dot`." + "See :epkg:`numpy:dot`" + warnings.warn( + "npnx.dot is equivalent to npnx.matmul == numpy.matmul " + "!= numpy.dot with arrays with more than 3D dimensions.") + return OnnxVar(a, b, op=OnnxMatMul) + + +def matmul(a, b): + "See :epkg:`numpy:matmul`." return OnnxVar(a, b, op=OnnxMatMul) diff --git a/mlprodict/npy/numpy_onnx_pyrt.py b/mlprodict/npy/numpy_onnx_pyrt.py index 174bbb44b..723b20dfa 100644 --- a/mlprodict/npy/numpy_onnx_pyrt.py +++ b/mlprodict/npy/numpy_onnx_pyrt.py @@ -42,6 +42,7 @@ hstack as nx_hstack, isnan as nx_isnan, log as nx_log, + matmul as nx_matmul, mean as nx_mean, pad as nx_pad, prod as nx_prod, @@ -247,6 +248,12 @@ def log(x): return nx_log(x) +@onnxnumpy_np(signature=NDArrayType(("T:all", "T"))) +def matmul(a, b): + "matmul" + return nx_matmul(a, b) + + @onnxnumpy_np(signature=NDArrayType(("T:all", numpy.int64, 'T'), n_optional=1)) def pad(x, pads, constant_value=None, mode='constant'): "pad" diff --git a/mlprodict/onnx_tools/exports/__init__.py b/mlprodict/onnx_tools/exports/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mlprodict/onnx_tools/exports/tf2onnx_helper.py b/mlprodict/onnx_tools/exports/tf2onnx_helper.py new file mode 100644 index 000000000..d30023a11 --- /dev/null +++ b/mlprodict/onnx_tools/exports/tf2onnx_helper.py @@ -0,0 +1,364 @@ +""" +@file +@brief Helpers to run examples created with function +@see fn export2tf2onnx. +""" +import collections +import inspect +import numpy +from onnx.numpy_helper import from_array +from onnx.helper import ( + make_node, make_graph, make_model, set_model_props, make_tensor) +from onnx import AttributeProto +from ..onnx2py_helper import guess_dtype, guess_proto_dtype +from ..onnx_tools import ensure_topological_order + + +_make_name_id = 0 + + +def make_name(name): + "Creates a unique name." + global _make_name_id # pylint: disable=W0603 + name = "%s_%d" % (name, _make_name_id) + _make_name_id += 1 + return name + + +def make_sure(cond, msg, *args): + "Raises an exception if cond is not verified." + if not cond: + raise RuntimeError(msg % tuple(args)) + + +class tf_op: + """ + Decorator to register any new converter. + :param name: type of the operator to rewrite + :param domain: domain + """ + _OPSETS = collections.OrderedDict() + + def __init__(self, name, domain='', **kwargs): + if not isinstance(name, list): + name = [name] + self.names = name + self.domain = domain + self.kwargs = kwargs + + def __call__(self, func): + for ke, va in inspect.getmembers(func, inspect.ismethod): + if ke.startswith("version_"): + version = int(ke.replace("version_", "")) + self._register_handler( + va, version, self.names, self.domain, self.kwargs) + return func + + def _register_handler(self, func, version, names, domain, kwargs): + opset = tf_op._OPSETS.get(domain) + if not opset: + opset = [] + tf_op._OPSETS[domain] = opset + while version >= len(opset): + opset.append({}) + opset_dict = opset[version] + for name in names: + opset_dict[name] = (func, kwargs) + + +class Tf2OnnxConvert: + """ + Applies the converter on an ONNX graph. + + :param onnx_model: ONNX graph + :param tf_op: class which register + :param verbose: verbosity + :param target_opset: targetted opsets + """ + + def __init__(self, onnx_model, _tf_op=None, verbose=None, + target_opset=None): + self._onnx_model = onnx_model + self._tf_op = _tf_op or tf_op + self.verbose = verbose + if isinstance(target_opset, int): + self.target_opsets = {'': target_opset} + elif isinstance(target_opset, dict): + self.target_opsets = target_opset + elif target_opset is None: + opsets = {} + for oimp in onnx_model.opset_import: + if oimp.domain == '': + opsets[oimp.domain] = oimp.version + opset = oimp.version + else: + opsets[oimp.domain] = opset + self.target_opsets = opsets + else: + raise ValueError( # pragma: no cover + "Unexepected value for target_opset=%r." % target_opset) + self._names = {} + for node in onnx_model.graph.node: + self._names[node.name] = node + for init in onnx_model.graph.initializer: + self._names[init.name] = init + # _forbidden_new_names contains current names and deleted names. + self._forbidden_new_names = set(self._names) + + def get_node_by_name(self, name): + """ + Retrieves a node by its name. + + :param name: node name + :return: node name + """ + if name not in self._names: + raise RuntimeError( + "Unable to find node name %r among %r." % ( + name, ", ".join(sorted(self._names)))) + return self._names[name] + + def _add_node_name(self, obj): + """ + Registers an object in in the graph by its name. + :param name: node or initializer + """ + if obj.name in self._forbidden_new_names: + raise RuntimeError( + "Name %r is already registered." % obj.name) + self._names[obj.name] = obj + self._forbidden_new_names.add(obj.name) + + def make_node(self, op_type, inputs, attr=None, outputs=None, + name=None, domain='', output_count=1): + """ + Adds a node to the list of nodes. + + :param op_type: operator type + :param inputs: list of strings + :param attr: dictionary of attributes + :param outputs: None or list of strings + :param output_count: used if outputs is None to guess + the number of outputs of this node + :param name: name of the node + :param domain: domain + :return: created node + """ + if self.verbose: + print("[Tf2OnnxConvert.make_node] op_type=%r inputs=%r" % ( + op_type, inputs)) + + if attr is None: + attr = {} + if name is None: + name = make_name(op_type) + if name in self._names: + raise RuntimeError( + "Node name %r already exists in %r." % ( + name, ", ".join(sorted(self._names)))) + + if outputs is None: + outputs = [(name + ":" + str(i)) for i in range(output_count)] + + output_count = len(outputs) + raw_attr = {} + onnx_attrs = [] + for a, v in attr.items(): + if isinstance(v, AttributeProto): + onnx_attrs.append(v) + else: + raw_attr[a] = v + + onnx_node = make_node( + op_type, inputs, outputs, name=name, domain=domain, **raw_attr) + + self._add_node_name(onnx_node) + return onnx_node + + def make_const(self, name, np_val, skip_conversion=False, raw=True): + """ + Make a new constants in the graph. + :param name: const node name, must be unique. + :param np_val: value of type numpy ndarray. + :param skip_conversion: + bool, indicate whether this created node would be mapped + during conversion + :param raw: whether to store data at field of raw_data or the + specific field according to its dtype + :return: create initializer + """ + if name in self._names: + raise RuntimeError( + "Initializer name %r already exists in %r." % ( + name, ", ".join(sorted(self._names)))) + np_val_flat = np_val.flatten() + is_bytes = (np_val.dtype == numpy.object and len(np_val_flat) > 0 and + isinstance(np_val_flat[0], bytes)) + if raw and not is_bytes: + onnx_tensor = from_array(np_val, name) + else: + onnx_tensor = make_tensor( + name, guess_proto_dtype(np_val.dtype), + np_val.shape, np_val_flat, raw=False) + + self._add_node_name(onnx_tensor) + return onnx_tensor + + def get_dtype(self, input_name): + """ + Returns the type of one node or None if unknown. + :param input_name: result name + :return: numpy dtype + """ + inputs = self._onnx_model.graph.input + names = [_.name for _ in inputs] + if input_name not in names: + return None # pragma: no cover + ind = names.index(input_name) + return guess_dtype(inputs[ind].type.tensor_type.elem_type) + + def replace_all_inputs(self, old_name, new_name): + """ + Every taking *old_name* as inputs will take *new_name* instead. + Looks in the output as well but in that case, it creates an identity + node to avoid changing an output name. + :param old_name: name to replace + :param new_name: new name + :return: list of impacted nodes + """ + res = [] + for node in self._names.values(): + if not hasattr(node, 'input'): + continue + if old_name not in node.input: + continue + new_inputs = [new_name if i.name == old_name else i.name + for i in node.input] + node.input[:] = new_inputs[:] + res.append(node) + if self.verbose: + print("[Tf2OnnxConvert.replace_all_inputs] replace %r by %r in node %r" % ( + old_name, new_name, node.name)) + for o in self._onnx_model.graph.output: + if o.name != old_name: + continue + n = self.make_node("Identity", [new_name], outputs=[old_name], + name=make_name("IdOutputReplaced")) + res.append(n) + if self.verbose: + print("[Tf2OnnxConvert.replace_all_inputs] add id node from %r to %r " + "with node %r." % ( + old_name, new_name, n.name)) # pylint: disable=E1101 + return res + + def remove_node(self, name): + """ + Removes a node name from the list. + """ + if name not in self._names: + raise RuntimeError( + "Unable to delete name %r because it does not exists." % name) + del self._names[name] + if self.verbose: + print("[Tf2OnnxConvert.remove_node] delete name %r" % name) + + def get_shape(self, input_name): + """ + Returns the type of one node or None if unknown. + :param input_name: result name + :return: numpy dtype + """ + inputs = self._onnx_model.graph.input + names = [_.name for _ in inputs] + if input_name not in names: + return None # pragma: no cover + ind = names.index(input_name) + dims = inputs[ind].type.tensor_type.shape.dim + return tuple(dims) + + def run(self): + """ + Calls the registered converters on the graph + held by this instance. Returns the new onnx graph. + + :return: ONNX graph + """ + if len(self._tf_op._OPSETS) == 0: + raise RuntimeError( # pragma: no cover + "No converter was registered.") + if self.verbose: + print("[Tf2OnnxConvert.run]") + + done = {} + modif = 1 + while modif > 0: + modif = 0 + # The converter may alter the current list of nodes, we freeze it. + current_values = list(self._names.values()) + for node in current_values: + if not hasattr(node, 'domain'): + # initializer + continue + if done.get(node.name, False): + continue + domain = node.domain + if domain not in self._tf_op._OPSETS: + continue + + # look for a converter + rews = self._tf_op._OPSETS[domain] + target = min(self.target_opsets[domain], len(rews)) + conv = None + for i in range(len(rews) - 1, -1, -1): + if node.op_type in rews[i]: + conv = rews[i][node.op_type] + break + if conv is None: + continue + + # applies the converter + if self.verbose: + print("[Tf2OnnxConvert.run] convert node type=%r opset=%r name=%r" + "" % (node.op_type, target, node.name)) + fct, kwargs = conv + fct(self, node, target_opset=target, **kwargs) + modif += 1 + + return self.make_model() + + def make_model(self): + """ + Produces the new ONNX graph with the updated sets of nodes. + """ + inputs = self._onnx_model.graph.input + outputs = self._onnx_model.graph.output + inits = [init[1] for init in sorted(self._names.items()) + if not hasattr(init[1], 'domain')] + nodes = [node[1] for node in sorted(self._names.items()) + if hasattr(node[1], 'domain')] + nodes = ensure_topological_order(inputs, inits, nodes) + + if self.verbose: + print( + "[Tf2OnnxConvert.make_node] %d nodes %d inputs %d " + "outputs %d initializers" + "" % (len(nodes), len(inputs), len(outputs), len(inits))) + graph = make_graph(nodes, self._onnx_model.graph.name, + inputs, outputs, inits) + onnx_model = make_model(graph) + onnx_model.ir_version = self._onnx_model.ir_version + onnx_model.producer_name = self._onnx_model.producer_name + "-mlprodict" + onnx_model.producer_version = self._onnx_model.producer_version + onnx_model.domain = self._onnx_model.domain + onnx_model.model_version = self._onnx_model.model_version + onnx_model.doc_string = self._onnx_model.doc_string + metadata = {p.key: p.value for p in self._onnx_model.metadata_props} + set_model_props(onnx_model, metadata) + + # opsets + del onnx_model.opset_import[:] # pylint: disable=E1101 + for dom, value in self.target_opsets.items(): + op_set = onnx_model.opset_import.add() # pylint: disable=E1101 + op_set.domain = dom + op_set.version = value + return onnx_model diff --git a/mlprodict/onnx_tools/onnx2py_helper.py b/mlprodict/onnx_tools/onnx2py_helper.py index ae2d3b727..966f39f62 100644 --- a/mlprodict/onnx_tools/onnx2py_helper.py +++ b/mlprodict/onnx_tools/onnx2py_helper.py @@ -485,3 +485,41 @@ def guess_proto_dtype(dtype): return TensorProto.STRING # pylint: disable=E1101 raise RuntimeError( "Unable to guess type for dtype={}.".format(dtype)) # pragma: no cover + + +def guess_dtype(proto_type): + """ + Converts a proto type into a :epkg:`numpy` type. + + :param proto_type: example ``onnx.TensorProto.FLOAT`` + :return: :epkg:`numpy` dtype + """ + if proto_type == TensorProto.FLOAT: # pylint: disable=E1101 + return numpy.float32 + if proto_type == TensorProto.BOOL: # pylint: disable=E1101 + return numpy.bool_ + if proto_type == TensorProto.DOUBLE: # pylint: disable=E1101 + return numpy.float64 + if proto_type == TensorProto.STRING: # pylint: disable=E1101 + return numpy.str_ + if proto_type == TensorProto.INT64: # pylint: disable=E1101 + return numpy.int64 + if proto_type == TensorProto.INT32: # pylint: disable=E1101 + return numpy.int32 + if proto_type == TensorProto.INT8: # pylint: disable=E1101 + return numpy.int8 + if proto_type == TensorProto.INT16: # pylint: disable=E1101 + return numpy.int16 + if proto_type == TensorProto.UINT64: # pylint: disable=E1101 + return numpy.uint64 + if proto_type == TensorProto.UINT32: # pylint: disable=E1101 + return numpy.uint32 + if proto_type == TensorProto.UINT8: # pylint: disable=E1101 + return numpy.uint8 + if proto_type == TensorProto.UINT16: # pylint: disable=E1101 + return numpy.uint16 + if proto_type == TensorProto.FLOAT16: # pylint: disable=E1101 + return numpy.float16 + raise ValueError( + "Unable to convert proto_type {} to numpy type.".format( + proto_type)) diff --git a/mlprodict/onnx_tools/onnx_export.py b/mlprodict/onnx_tools/onnx_export.py new file mode 100644 index 000000000..7e06783db --- /dev/null +++ b/mlprodict/onnx_tools/onnx_export.py @@ -0,0 +1,402 @@ +""" +@file +@brief Exports an ONNX graph in a way it can we created again +with a python script. It relies on :epkg:`jinja2` and :epkg:`autopep8`. + +.. versionadded:: 0.7 +""" +from textwrap import dedent +import numpy +from jinja2 import Template +import autopep8 +import onnx +from onnx import numpy_helper +from .onnx2py_helper import _var_as_dict + + +_onnx_templates = dedent(""" + import numpy + from onnx import numpy_helper + from onnx.helper import ( + make_model, make_node, set_model_props, make_tensor, make_graph, + make_tensor_value_info) + + + def create_model(): + # containers + print('[containers]') # verbose + initializers = [] + nodes = [] + inputs = [] + outputs = [] + + # opsets + print('[opsets]') # verbose + opsets = {{ opsets }} + target_opset = {{ target_opset }} + + # initializers + print('[initializers]') # verbose + {% for name, value in initializers: %} + {% if len(value.shape) == 0: %} + value = numpy.array({{ value }}, dtype=numpy.{{ value.dtype }}) + {% else %} + list_value = {{ value.ravel().tolist() }} + value = numpy.array(list_value, dtype=numpy.{{ value.dtype }}){% if len(value.shape) > 1: %}.reshape({{ value.shape }}){% endif %} + {% endif %} + tensor = numpy_helper.from_array(value, name='{{ name }}') + initializers.append(tensor) + {% endfor %} + + # inputs + print('[inputs]') # verbose + {% for name, type, shape in inputs: %} + value = make_tensor_value_info('{{ name }}', {{ type }}, {{ shape }}) + inputs.append(value) + {% endfor %} + + # outputs + print('[outputs]') # verbose + {% for name, type, shape in outputs: %} + value = make_tensor_value_info('{{ name }}', {{ type }}, {{ shape }}) + outputs.append(value) + {% endfor %} + + # nodes + print('[nodes]') # verbose + {% for node in nodes: %} + node = make_node( + '{{ node['op_type'] }}', + {{ node['inputs'] }}, + {{ node['outputs'] }}, + {% if node['name']: %}name='{{ node['name'] }}',{% endif %} + {%- for name, value in node['attributes']: -%} + {{ name }}={{ value }}, + {%- endfor -%} + domain='{{ node['domain'] }}') + nodes.append(node) + {% endfor %} + + # graph + print('[graph]') # verbose + graph = make_graph(nodes, '{{ name }}', inputs, outputs, initializers) + onnx_model = make_model(graph) + onnx_model.ir_version = {{ ir_version }} + onnx_model.producer_name = '{{ producer_name }}' + onnx_model.producer_version = '{{ producer_version }}' + onnx_model.domain = '{{ domain }}' + onnx_model.model_version = {{ model_version }} + onnx_model.doc_string = '{{ doc_string }}' + set_model_props(onnx_model, {{ metadata }}) + + # opsets + print('[opset]') # verbose + del onnx_model.opset_import[:] # pylint: disable=E1101 + for dom, value in opsets.items(): + op_set = onnx_model.opset_import.add() + op_set.domain = dom + op_set.version = value + + return onnx_model + + + onnx_model = create_model() +""") + + +_tf2onnx_templates = dedent(""" + import inspect + import collections + import numpy + from onnx import AttributeProto + from onnx.helper import ( + make_model, make_node, set_model_props, make_tensor, make_graph, + make_tensor_value_info) + # from utils import make_name, make_sure + from mlprodict.onnx_tools.exports.tf2onnx_helper import ( + make_name, make_sure) + # from tf2onnx.handler import tf_op + from mlprodict.onnx_tools.exports.tf2onnx_helper import tf_op + from mlprodict.onnx_tools.exports.tf2onnx_helper import Tf2OnnxConvert + + + @tf_op("{{ name }}") + class Convert{{ name }}Op: + + supported_dtypes = [ + numpy.float32, + ] + + @classmethod + def any_version(cls, opset, ctx, node, **kwargs): + ''' + Converter for ``{{ name }}``. + + * producer: {{ producer_name }} + * version: {{ model_version }} + * description: {{ doc_string }} + {%- for key, val in sorted(metadata.items()): -%} + * {{ key }}: {{ val }} + {%- endfor %} + ''' + oldnode = node + input_name = node.input[0] + onnx_dtype = ctx.get_dtype(input_name) + make_sure(onnx_dtype in Convert{{ name }}Op.supported_dtypes, "Unsupported input type.") + shape = ctx.get_shape(input_name) + varx = {x: x for x in node.input} + + # initializers + if getattr(ctx, 'verbose', False): + print('[initializers] %r' % cls) + {% for name, value in initializers: %} + {% if len(value.shape) == 0: %} + value = numpy.array({{ value }}, dtype=numpy.{{ value.dtype }}) + {% else %} + list_value = {{ value.ravel().tolist() }} + value = numpy.array(list_value, dtype=numpy.{{ value.dtype }}){% if len(value.shape) > 1: %}.reshape({{ value.shape }}){% endif %} + {% endif %} + r_{{ name }} = ctx.make_const(name=make_name('init_{{ name }}'), np_val=value) + varx['{{ name }}'] = r_{{ name }}.name + {% endfor %} + + # nodes + if getattr(ctx, 'verbose', False): + print('[nodes] %r' % cls) + {% for node in nodes: %} + attr = dict( + {%- for name, value in node['attributes']: -%} + {{ name }}={{ value }}, + {%- endfor -%}) + inputs = [{% for name in node['inputs']: -%}varx['{{ name }}'], {%- endfor %}] + node = ctx.make_node( + '{{ node['op_type'] }}', inputs=inputs, attr=attr,{% if node['domain']: -%} domain='{{ node['domain'] }}', {% endif %} + name=make_name('{{ node['name'] }}')) + {% for i, name in enumerate(node['outputs']): -%} + varx['{{ name }}'] = node.output[{{ i }}] + {%- endfor %} + {% endfor %} + + # finalize + if getattr(ctx, 'verbose', False): + print('[replace_all_inputs] %r' % cls) + ctx.replace_all_inputs(oldnode.output[0], node.output[0]) + ctx.remove_node(oldnode.name) + + @classmethod + def version_13(cls, ctx, node, **kwargs): + return cls.any_version(13, ctx, node, **kwargs) + + + def create_model(): + inputs = [] + outputs = [] + + # inputs + print('[inputs]') # verbose + {% for name, type, shape in inputs: %} + value = make_tensor_value_info('{{ name }}', {{ type }}, {{ shape }}) + inputs.append(value) + {% endfor %} + + # outputs + print('[outputs]') # verbose + {% for name, type, shape in outputs: %} + value = make_tensor_value_info('{{ name }}', {{ type }}, {{ shape }}) + outputs.append(value) + {% endfor %} + + inames = [i.name for i in inputs] + onames = [i.name for i in outputs] + node = make_node('{{ name }}', inames, onames, name='{{ name }}') + + # graph + print('[graph]') # verbose + graph = make_graph([node], '{{ name }}', inputs, outputs) + onnx_model = make_model(graph) + onnx_model.ir_version = {{ ir_version }} + onnx_model.producer_name = '{{ producer_name }}' + onnx_model.producer_version = '{{ producer_version }}' + onnx_model.domain = '{{ domain }}' + onnx_model.model_version = {{ model_version }} + onnx_model.doc_string = '{{ doc_string }}' + set_model_props(onnx_model, {{ metadata }}) + + # opsets + print('[opset]') # verbose + opsets = {{ opsets }} + del onnx_model.opset_import[:] # pylint: disable=E1101 + for dom, value in opsets.items(): + op_set = onnx_model.opset_import.add() + op_set.domain = dom + op_set.version = value + + return onnx_model + + + onnx_raw = create_model() + onnx_model = Tf2OnnxConvert(onnx_raw, tf_op).run() +""") + + +def export_template(model_onnx, templates, opset=None, verbose=True, name=None): + """ + Exports an ONNX model to the onnx syntax. + + :param model_onnx: string or ONNX graph + :param templates: exporting templates + :param opset: opset to export to + (None to select the one from the graph) + :param verbose: insert prints + :param name: to overwrite onnx name + :return: python code + """ + # containers + context = {} + + # opset + opsets = {} + for oimp in model_onnx.opset_import: + if oimp.domain == '' and opset is None: + opsets[oimp.domain] = oimp.version + opset = oimp.version + else: + opsets[oimp.domain] = opset + context['opsets'] = opsets + context['target_opset'] = opset + + # inits + initializers = [] + for init in model_onnx.graph.initializer: + value = numpy_helper.to_array(init) + initializers.append((init.name, value)) + context['initializers'] = initializers + + # inputs + inputs = [] + for inp in model_onnx.graph.input: + t = inp.type.tensor_type + dims = tuple(t.shape.dim) + inputs.append((inp.name, t.elem_type, dims)) + context['inputs'] = inputs + + # outputs + outputs = [] + for inp in model_onnx.graph.output: + t = inp.type.tensor_type + dims = tuple(t.shape.dim) + outputs.append((inp.name, t.elem_type, dims)) + context['outputs'] = outputs + + # node + nodes = [] + for node in model_onnx.graph.node: + attributes = [] + for at in node.attribute: + temp = _var_as_dict(at) + value = temp['value'] + if isinstance(value, str): + attributes.append((at.name, "%r" % value)) + else: + if isinstance(value, numpy.ndarray): + attributes.append((at.name, repr(value.tolist()))) + else: + attributes.append((at.name, repr(value))) + d = dict(name=node.name, op_type=node.op_type, + domain=node.domain, inputs=node.input, + outputs=node.output, attributes=attributes) + nodes.append(d) + context['nodes'] = nodes + + # graph + context['name'] = name or model_onnx.graph.name + context['ir_version'] = model_onnx.ir_version + context['producer_name'] = model_onnx.producer_name + context['domain'] = model_onnx.domain + context['model_version'] = model_onnx.model_version + context['doc_string'] = model_onnx.doc_string + context['metadata'] = {p.key: p.value for p in model_onnx.metadata_props} + + # final + template = Template(templates) + final = template.render( + enumerate=enumerate, sorted=sorted, len=len, + **context) + + if not verbose: + rows = final.split("\n") + final = "\n".join(_ for _ in rows if not _.endswith("# verbose")) + return autopep8.fix_code(final) + + +def export2onnx(model_onnx, opset=None, verbose=True, name=None): + """ + Exports an ONNX model to the :epkg:`onnx` syntax. + + :param model_onnx: string or ONNX graph + :param opset: opset to export to + (None to select the one from the graph) + :param verbose: inserts prints + :param name: to overwrite onnx name + :return: python code + + The following example shows what a python code creating a graph + implementing the KMeans would look like. + + .. runpython:: + :showcode: + + import numpy + from sklearn.cluster import KMeans + from skl2onnx import to_onnx + from mlprodict.onnx_tools.onnx_export import export2onnx + + X = numpy.arange(20).reshape(10, 2).astype(numpy.float32) + tr = KMeans(n_clusters=2) + tr.fit(X) + + onx = to_onnx(tr, X, target_opset=14) + code = export2onnx(onx) + + print(code) + """ + if isinstance(model_onnx, str): + model_onnx = onnx.load(model_onnx) + + return export_template(model_onnx, templates=_onnx_templates, + opset=opset, verbose=verbose, name=name) + + +def export2tf2onnx(model_onnx, opset=None, verbose=True, name=None): + """ + Exports an ONNX model to the e:pkg:`tensorflow-onnx` syntax. + + :param model_onnx: string or ONNX graph + :param opset: opset to export to + (None to select the one from the graph) + :param verbose: inserts prints + :param name: to overwrite onnx name + :return: python code + + .. runpython:: + :showcode: + + import numpy + from sklearn.cluster import KMeans + from skl2onnx import to_onnx + from mlprodict.onnx_tools.onnx_export import export2tf2onnx + + X = numpy.arange(20).reshape(10, 2).astype(numpy.float32) + tr = KMeans(n_clusters=2) + tr.fit(X) + + onx = to_onnx(tr, X, target_opset=14) + code = export2tf2onnx(onx) + + print(code) + """ + if isinstance(model_onnx, str): + model_onnx = onnx.load(model_onnx) + + return export_template(model_onnx, templates=_tf2onnx_templates, + opset=opset, verbose=verbose, name=name) diff --git a/mlprodict/onnx_tools/onnx_tools.py b/mlprodict/onnx_tools/onnx_tools.py index 0ce1b542c..7d5ca3b5b 100644 --- a/mlprodict/onnx_tools/onnx_tools.py +++ b/mlprodict/onnx_tools/onnx_tools.py @@ -76,6 +76,8 @@ def insert_node(model, op_type, node, input_index=0, new_name=None, **attrs): inode.input[input_index] = new_name keep_nodes = list(model.graph.node) keep_nodes.append(new_node) + keep_nodes = ensure_topological_order( + model.graph.input, model.graph.initializer, keep_nodes) graph = helper.make_graph( keep_nodes, model.graph.name, model.graph.input, @@ -102,3 +104,88 @@ def insert_node(model, op_type, node, input_index=0, new_name=None, **attrs): "Input mismatch {} != {}".format( len(onnx_model.input), len(model.input))) # pylint: disable=E1101 return onnx_model + + +def ensure_topological_order(inputs, initializers, nodes): + """ + Ensures and modifies the order of nodes to have + a topological order (every node in the list + can only be an input for a node later in this list). + The function raises an exception if a cycle is detected. + + :param inputs: graph inputs: + :param initializers: graph initializers + :param nodes: graph nodes + :return: list ordered nodes + """ + order = {} + for inp in inputs: + name = inp.name + order[name] = 0 + for inp in initializers: + name = inp.name + order[name] = 0 + n_iter = 0 + while n_iter < len(nodes) * 2: + n_iter += 1 + missing_names = set() + missing_ops = [] + for node in nodes: + maxi = 0 + for name in node.input: + if name in order: + maxi = max(maxi, order[name]) + else: + maxi = None + missing_names.add(name) + break + if maxi is None: + missing_ops.append(node) + continue + key = id(node) + if key in order: + continue + maxi += 1 + order[key] = maxi + maxi += 1 + for name in node.output: + if name in order: + raise RuntimeError( + "Unable to sort a node (cycle). An output was " + "already ordered %r (iteration=%r)." % ( + name, n_iter)) + order[name] = maxi + if len(missing_names) == 0: + continue + + if len(missing_ops) > 0: + def nstr(name): + if name in order: + return "%s#%d" % (name, order[name]) + return name + rows = ["%s(%s) -> [%s]" % ( + n.name or n.op_type, + ', '.join(map(nstr, n.input)), + ', '.join(n.output)) + for n in missing_ops] + rows.insert(0, "") + rows.append("--") + rows.append("--all-nodes--") + rows.append("--") + rows.extend("%s(%s) -> [%s]" % ( + n.name or n.op_type, + ', '.join(map(nstr, n.input)), + ', '.join(n.output)) + for n in nodes) + raise RuntimeError( + "After %d iterations for %d nodes, still unable " + "to sort names %r. The graph may be disconnected. " + "List of operators: %s" % ( + n_iter, len(nodes), missing_names, + "\n".join(rows))) + + # Update order + topo = [(order[id(node)], str(id(node))) for node in nodes] + topo.sort() + map_nodes = {str(id(node)): node for node in nodes} + return [map_nodes[_[1]] for _ in topo] diff --git a/mlprodict/onnxrt/onnx_inference.py b/mlprodict/onnxrt/onnx_inference.py index 9a9a7bd7e..128bbf10a 100644 --- a/mlprodict/onnxrt/onnx_inference.py +++ b/mlprodict/onnxrt/onnx_inference.py @@ -822,7 +822,8 @@ def dispsimple(arr): mini = numpy_min(values[k]) maxi = numpy_max(values[k]) fLOG("+kr{}'{}': {} (dtype={} min={} max={}{})".format( - "=" if len(values[k].shape) == 0 or min(values[k].shape) > 0 else "*", + "=" if len(values[k].shape) == 0 or min( + values[k].shape) > 0 else "*", name, values[k].shape, values[k].dtype, mini, maxi, ' sparse' if isinstance(values[k], coo_matrix) else '')) diff --git a/mlprodict/onnxrt/ops_cpu/_op_helper.py b/mlprodict/onnxrt/ops_cpu/_op_helper.py index 389ac914e..15d5920d9 100644 --- a/mlprodict/onnxrt/ops_cpu/_op_helper.py +++ b/mlprodict/onnxrt/ops_cpu/_op_helper.py @@ -3,7 +3,6 @@ @brief Runtime operator. """ import numpy -from onnx import TensorProto def _get_typed_class_attribute(self, k, atts): @@ -32,35 +31,8 @@ def proto2dtype(proto_type): :param proto_type: example ``onnx.TensorProto.FLOAT`` :return: :epkg:`numpy` dtype """ - if proto_type == TensorProto.FLOAT: # pylint: disable=E1101 - return numpy.float32 - if proto_type == TensorProto.BOOL: # pylint: disable=E1101 - return numpy.bool_ - if proto_type == TensorProto.DOUBLE: # pylint: disable=E1101 - return numpy.float64 - if proto_type == TensorProto.STRING: # pylint: disable=E1101 - return numpy.str_ - if proto_type == TensorProto.INT64: # pylint: disable=E1101 - return numpy.int64 - if proto_type == TensorProto.INT32: # pylint: disable=E1101 - return numpy.int32 - if proto_type == TensorProto.INT8: # pylint: disable=E1101 - return numpy.int8 - if proto_type == TensorProto.INT16: # pylint: disable=E1101 - return numpy.int16 - if proto_type == TensorProto.UINT64: # pylint: disable=E1101 - return numpy.uint64 - if proto_type == TensorProto.UINT32: # pylint: disable=E1101 - return numpy.uint32 - if proto_type == TensorProto.UINT8: # pylint: disable=E1101 - return numpy.uint8 - if proto_type == TensorProto.UINT16: # pylint: disable=E1101 - return numpy.uint16 - if proto_type == TensorProto.FLOAT16: # pylint: disable=E1101 - return numpy.float16 - raise ValueError( - "Unable to convert proto_type {} to numpy type.".format( - proto_type)) + from ...onnx_tools.onnx2py_helper import guess_dtype + return guess_dtype(proto_type) def dtype_name(dtype): diff --git a/mlprodict/testing/verify_code.py b/mlprodict/testing/verify_code.py index 929fde3f5..6d2f18dbe 100644 --- a/mlprodict/testing/verify_code.py +++ b/mlprodict/testing/verify_code.py @@ -4,6 +4,9 @@ before finalizing the benchmark. """ import ast +import collections +import inspect +import numpy class ImperfectPythonCode(RuntimeError): @@ -29,7 +32,11 @@ def verify_code(source, exc=True): imports = v._imports names = v._names args = v._args - known = {'super': None, 'ImportError': None} + known = {'super': None, 'ImportError': None, 'print': print, + 'classmethod': classmethod, 'numpy': numpy, + 'dict': dict, 'list': list, 'sorted': sorted, 'len': len, + 'collections': collections, 'inspect': inspect, 'range': range, + 'int': int, 'str': str, 'isinstance': isinstance} for kn in imports: known[kn[0]] = kn for kn in assign: