From 15e2476a80361f931382cdaa0101de0be3d58374 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Wed, 4 Aug 2021 00:47:21 +0200 Subject: [PATCH 01/12] Create test_onnxrt_runtime_lightgbm_bug.py --- .../test_onnxrt_runtime_lightgbm_bug.py | 118 ++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 _unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py diff --git a/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py b/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py new file mode 100644 index 000000000..162a001b8 --- /dev/null +++ b/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py @@ -0,0 +1,118 @@ +""" +@brief test log(time=3s) +""" +import sys +import unittest +from logging import getLogger +import numpy +import pandas +from pyquickhelper.pycode import ExtTestCase, skipif_circleci, ignore_warnings +from skl2onnx.common.data_types import FloatTensorType +from mlprodict.onnxrt import OnnxInference +from mlprodict.onnx_conv import register_converters, to_onnx +from mlprodict.tools.asv_options_helper import get_ir_version_from_onnx + + +class TestOnnxrtRuntimeLightGbmBug(ExtTestCase): + + def setUp(self): + logger = getLogger('skl2onnx') + logger.disabled = True + register_converters() + + @unittest.skipIf(sys.platform == 'darwin', 'stuck') + def test_lightgbm_regressor(self): + from lightgbm import LGBMRegressor + from onnxmltools.convert import convert_lightgbm + + X = numpy.abs(numpy.random.randn(7, 227)).astype(numpy.float32) + y = X.sum(axis=1) + numpy.random.randn( + X.shape[0]).astype(numpy.float32) / 10 + model = LGBMRegressor( + max_depth=8, n_estimators=100, min_child_samples=1, + learning_rate=0.0000001) + model.fit(X, y) + expected = model.predict(X) + + model_onnx = to_onnx(model, X) + model_onnx2 = convert_lightgbm( + model, initial_types=[('X', FloatTensorType([None, 227]))]) + + for i, mo in enumerate([model_onnx, model_onnx2]): + for rt in ['python', 'onnxruntime1']: + with 
self.subTest(i=i, rt=rt): + oinf = OnnxInference(mo, runtime=rt) + got = oinf.run({'X': X})['variable'] + diff = numpy.abs(got.ravel() - expected.ravel()).max() + if __name__ == "__main__": + print("lgb", i, rt, diff) + self.assertLess(diff, 1e-3) + + @unittest.skipIf(sys.platform == 'darwin', 'stuck') + def test_lightgbm_regressor_double(self): + from lightgbm import LGBMRegressor + from onnxmltools.convert import convert_lightgbm + + X = numpy.abs(numpy.random.randn(7, 227)).astype(numpy.float32) + y = X.sum(axis=1) + numpy.random.randn( + X.shape[0]).astype(numpy.float32) / 10 + model = LGBMRegressor( + max_depth=8, n_estimators=100, min_child_samples=1, + learning_rate=0.0000001) + model.fit(X, y) + expected = model.predict(X) + + model_onnx = to_onnx(model, X, rewrite_ops=True) + model_onnx2 = to_onnx(model, X.astype(numpy.float64), + rewrite_ops=True) + + for i, mo in enumerate([model_onnx, model_onnx2]): + for rt in ['python', 'onnxruntime1']: + if "TreeEnsembleRegressorDouble" in str(mo): + x = X.astype(numpy.float64) + if rt == 'onnxruntime1': + continue + else: + x = X + with self.subTest(i=i, rt=rt): + oinf = OnnxInference(mo, runtime=rt) + got = oinf.run({'X': x})['variable'] + diff = numpy.abs(got.ravel() - expected.ravel()).max() + if __name__ == "__main__": + print("lgbd", i, rt, diff) + if i == 1 and rt == 'python': + self.assertLess(diff, 1e-5) + else: + self.assertLess(diff, 1e-3) + + @unittest.skipIf(sys.platform == 'darwin', 'stuck') + def test_xgboost_regressor(self): + from xgboost import XGBRegressor + from onnxmltools.convert import convert_xgboost + + X = numpy.abs(numpy.random.randn(7, 227)).astype(numpy.float32) + y = X.sum(axis=1) + numpy.random.randn( + X.shape[0]).astype(numpy.float32) / 10 + model = XGBRegressor( + max_depth=8, n_estimators=100, + learning_rate=0.000001) + model.fit(X, y) + expected = model.predict(X) + + model_onnx = to_onnx(model, X) + model_onnx2 = convert_xgboost( + model, initial_types=[('X', 
FloatTensorType([None, 227]))]) + + for i, mo in enumerate([model_onnx, model_onnx2]): + for rt in ['python', 'onnxruntime1']: + with self.subTest(i=i, rt=rt): + oinf = OnnxInference(mo, runtime=rt) + got = oinf.run({'X': X})['variable'] + diff = numpy.abs(got.ravel() - expected.ravel()).max() + if __name__ == "__main__": + print("xgb", i, rt, diff) + self.assertLess(diff, 1e-5) + + +if __name__ == "__main__": + unittest.main() From 5a1a9048295888763a533a2d0d601fa39b512952 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Wed, 4 Aug 2021 09:25:51 +0200 Subject: [PATCH 02/12] add onnxmltools --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 99ae0ff4b..943306bcb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,5 +46,6 @@ xgboost # onnx onnx==1.9.0 +onnxmltools onnxruntime>=1.8.0 skl2onnx>=1.9.0 From ae3fbabfb1a7a1feee6edfe2b9fecbebc677306a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Wed, 4 Aug 2021 10:50:30 +0200 Subject: [PATCH 03/12] issue --- .../test_onnxrt_runtime_lightgbm_bug.py | 29 ++++++++++++++----- requirements.txt | 1 - 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py b/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py index 162a001b8..6fbfa492a 100644 --- a/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py +++ b/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py @@ -23,7 +23,10 @@ def setUp(self): @unittest.skipIf(sys.platform == 'darwin', 'stuck') def test_lightgbm_regressor(self): from lightgbm import LGBMRegressor - from onnxmltools.convert import convert_lightgbm + try: + from onnxmltools.convert import convert_lightgbm + except ImportError: + convert_lightgbm = None X = numpy.abs(numpy.random.randn(7, 227)).astype(numpy.float32) y = X.sum(axis=1) + numpy.random.randn( @@ -35,10 +38,15 @@ def test_lightgbm_regressor(self): expected = 
model.predict(X) model_onnx = to_onnx(model, X) - model_onnx2 = convert_lightgbm( - model, initial_types=[('X', FloatTensorType([None, 227]))]) + if convert_lightgbm is not None: + model_onnx2 = convert_lightgbm( + model, initial_types=[('X', FloatTensorType([None, 227]))]) + else: + model_onnx2 = None for i, mo in enumerate([model_onnx, model_onnx2]): + if mo is None: + continue for rt in ['python', 'onnxruntime1']: with self.subTest(i=i, rt=rt): oinf = OnnxInference(mo, runtime=rt) @@ -51,7 +59,6 @@ def test_lightgbm_regressor(self): @unittest.skipIf(sys.platform == 'darwin', 'stuck') def test_lightgbm_regressor_double(self): from lightgbm import LGBMRegressor - from onnxmltools.convert import convert_lightgbm X = numpy.abs(numpy.random.randn(7, 227)).astype(numpy.float32) y = X.sum(axis=1) + numpy.random.randn( @@ -88,7 +95,10 @@ def test_lightgbm_regressor_double(self): @unittest.skipIf(sys.platform == 'darwin', 'stuck') def test_xgboost_regressor(self): from xgboost import XGBRegressor - from onnxmltools.convert import convert_xgboost + try: + from onnxmltools.convert import convert_xgboost + except ImportError: + convert_xgboost = None X = numpy.abs(numpy.random.randn(7, 227)).astype(numpy.float32) y = X.sum(axis=1) + numpy.random.randn( @@ -100,10 +110,15 @@ def test_xgboost_regressor(self): expected = model.predict(X) model_onnx = to_onnx(model, X) - model_onnx2 = convert_xgboost( - model, initial_types=[('X', FloatTensorType([None, 227]))]) + if convert_xgboost is not None: + model_onnx2 = convert_xgboost( + model, initial_types=[('X', FloatTensorType([None, 227]))]) + else: + model_onnx2 = None for i, mo in enumerate([model_onnx, model_onnx2]): + if mo is None: + continue for rt in ['python', 'onnxruntime1']: with self.subTest(i=i, rt=rt): oinf = OnnxInference(mo, runtime=rt) diff --git a/requirements.txt b/requirements.txt index 943306bcb..99ae0ff4b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,6 +46,5 @@ xgboost # onnx onnx==1.9.0 
-onnxmltools onnxruntime>=1.8.0 skl2onnx>=1.9.0 From febd911e063514240cebdcf830ef637d27271994 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Wed, 4 Aug 2021 16:52:56 +0200 Subject: [PATCH 04/12] lint --- _unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py b/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py index 6fbfa492a..3160007b2 100644 --- a/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py +++ b/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py @@ -5,12 +5,10 @@ import unittest from logging import getLogger import numpy -import pandas -from pyquickhelper.pycode import ExtTestCase, skipif_circleci, ignore_warnings +from pyquickhelper.pycode import ExtTestCase from skl2onnx.common.data_types import FloatTensorType from mlprodict.onnxrt import OnnxInference from mlprodict.onnx_conv import register_converters, to_onnx -from mlprodict.tools.asv_options_helper import get_ir_version_from_onnx class TestOnnxrtRuntimeLightGbmBug(ExtTestCase): From 4516393f9d1ef6efc273cd3aa3669c3d9a78411b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Thu, 5 Aug 2021 16:43:12 +0200 Subject: [PATCH 05/12] Add function matmul, extend notebook on fft --- _doc/notebooks/onnx_fft.ipynb | 3719 +++++++++++++---- .../blog/2021/2021-05-05_numpyapionnx1.rst | 2 +- _unittests/ut_npy/test_onnx_variable.py | 12 + mlprodict/npy/numpy_onnx_impl.py | 13 +- mlprodict/npy/numpy_onnx_pyrt.py | 7 + 5 files changed, 2929 insertions(+), 824 deletions(-) diff --git a/_doc/notebooks/onnx_fft.ipynb b/_doc/notebooks/onnx_fft.ipynb index 6efb475e2..da8073759 100644 --- a/_doc/notebooks/onnx_fft.ipynb +++ b/_doc/notebooks/onnx_fft.ipynb @@ -1,822 +1,2897 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "51bc89fc", - "metadata": {}, - "source": [ - "# ONNX and FFT\n", - "\n", - "ONNX does not fully 
support complex yet. It does not have any FFT operators either. What if we need them anyway?" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "7b2add97", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
run previous cell, wait for 2 seconds
\n", - "" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from jyquickhelper import add_notebook_menu\n", - "add_notebook_menu()" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "acfdc3b0", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext mlprodict" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "abb5fa88", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'1.21.0'" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import numpy\n", - "numpy.__version__" - ] - }, - { - "cell_type": "markdown", - "id": "2e4f68e4", - "metadata": {}, - "source": [ - "## Python implementation of RFFT\n", - "\n", - "We try to replicate [numpy.rfft](https://numpy.org/doc/stable/reference/generated/numpy.fft.rfft.html)." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "cb1cc910", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[ 0.92935219+0.j , 1.1166406 +0.18610885j,\n", - " 2.98881347-0.86137828j, 0.57062752-3.17075076j],\n", - " [-0.81071034+0.j , 4.04571912+1.34415298j,\n", - " -0.75316593+1.87375117j, -3.73972034+1.19963451j],\n", - " [ 0.49893169+0.j , -2.38853745+0.91784964j,\n", - " -2.3230939 +2.42467461j, 2.84973582+0.96874118j],\n", - " [-0.85518897+0.j , -1.07457921+2.14618057j,\n", - " 0.67522719-2.17320735j, 1.31480887+2.2782433j ],\n", - " [ 2.80867666+0.j , -2.79453396-2.22901834j,\n", - " 0.492986 +0.10661537j, 2.65317564+0.57651319j]])" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import numpy\n", - "\n", - "\n", - "def almost_equal(a, b, error=1e-5):\n", - " \"\"\"\n", - " The function compares two matrices, one may be complex. 
In that case,\n", - " this matrix is changed into a new matrix with a new first dimension,\n", - " [0,::] means real part, [1,::] means imaginary part.\n", - " \"\"\"\n", - " if a.dtype in (numpy.complex64, numpy.complex128):\n", - " dtype = numpy.float64 if a.dtype == numpy.complex128 else numpy.float32\n", - " new_a = numpy.empty((2,) + a.shape).astype(dtype)\n", - " new_a[0] = numpy.real(a)\n", - " new_a[1] = numpy.imag(a)\n", - " return almost_equal(new_a, b, error)\n", - " if b.dtype in (numpy.complex64, numpy.complex128):\n", - " return almost_equal(b, a, error)\n", - " if a.shape != b.shape:\n", - " raise AssertionError(\"Shape mismatch %r != %r.\" % (a.shape, b.shape))\n", - " diff = numpy.abs(a.ravel() - b.ravel()).max()\n", - " if diff > error:\n", - " raise AssertionError(\"Mismatch max diff=%r > %r.\" % (diff, error))\n", - "\n", - "\n", - "def dft_real_cst(N, fft_length):\n", - " n = numpy.arange(N)\n", - " k = n.reshape((N, 1)).astype(numpy.float64)\n", - " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n", - " both = numpy.empty((2,) + M.shape)\n", - " both[0, :, :] = numpy.real(M)\n", - " both[1, :, :] = numpy.imag(M)\n", - " return both\n", - "\n", - "\n", - "def dft_real(x, fft_length=None, transpose=True):\n", - " if len(x.shape) == 1:\n", - " x = x.reshape((1, -1))\n", - " N = 1\n", - " else:\n", - " N = x.shape[0] \n", - " C = x.shape[-1] if transpose else x.shape[-2]\n", - " if fft_length is None:\n", - " fft_length = x.shape[-1]\n", - " size = fft_length // 2 + 1\n", - "\n", - " cst = dft_real_cst(C, fft_length)\n", - " if transpose:\n", - " x = numpy.transpose(x, (1, 0))\n", - " res = numpy.dot(cst[:, :, :fft_length], x[:fft_length])[:, :size, :]\n", - " return numpy.transpose(res, (0, 2, 1))\n", - " else:\n", - " return numpy.dot(cst[:, :, :fft_length], x[:fft_length])\n", - "\n", - "\n", - "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n", - "fft_np = numpy.fft.rfft(rnd)\n", - "fft_cus = dft_real(rnd)\n", - "fft_np" - ] - }, - { 
- "cell_type": "markdown", - "id": "0c052ea1", - "metadata": {}, - "source": [ - "Function `almost_equal` verifies both functions return the same results." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "3ca040cb", - "metadata": {}, - "outputs": [], - "source": [ - "almost_equal(fft_np, fft_cus)" - ] - }, - { - "cell_type": "markdown", - "id": "7fe77440", - "metadata": {}, - "source": [ - "Let's do the same with `fft_length < shape[1]`." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "3a747a4a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[ 0.58212829+0.j , 1.91211772-1.78320393j],\n", - " [-0.3185378 +0.j , -0.20609781-1.18129868j],\n", - " [-0.81120646+0.j , -0.28543806+3.05769342j],\n", - " [-1.06384408+0.j , 0.74100591+0.43276681j],\n", - " [ 1.77509081+0.j , -0.13498855+1.82011058j]])" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fft_np3 = numpy.fft.rfft(rnd, n=3)\n", - "fft_cus3 = dft_real(rnd, fft_length=3)\n", - "fft_np3" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "0db6247b", - "metadata": {}, - "outputs": [], - "source": [ - "almost_equal(fft_np3, fft_cus3)" - ] - }, - { - "cell_type": "markdown", - "id": "31a6ac9c", - "metadata": {}, - "source": [ - "## RFFT in ONNX\n", - "\n", - "Let's assume first the number of column of the input matrix is fixed. The result of function `dft_real_cst` can be considered as constant." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "efb67b9b", - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[[ 0.92935216, 1.1166406 , 2.9888134 , 0.5706275 ],\n", - " [-0.81071043, 4.045719 , -0.7531659 , -3.7397203 ],\n", - " [ 0.4989317 , -2.3885374 , -2.323094 , 2.849736 ],\n", - " [-0.85518885, -1.0745792 , 0.6752271 , 1.3148088 ],\n", - " [ 2.8086765 , -2.794534 , 0.49298596, 2.6531756 ]],\n", - "\n", - " [[ 0. , 0.18610872, -0.8613782 , -3.1707506 ],\n", - " [ 0. , 1.3441529 , 1.8737512 , 1.1996344 ],\n", - " [ 0. , 0.9178499 , 2.4246747 , 0.96874106],\n", - " [ 0. , 2.1461806 , -2.1732073 , 2.2782433 ],\n", - " [ 0. , -2.2290184 , 0.10661539, 0.5765133 ]]],\n", - " dtype=float32)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from typing import Any\n", - "import mlprodict.npy.numpy_onnx_impl as npnx\n", - "from mlprodict.npy import onnxnumpy_np\n", - "from mlprodict.npy.onnx_numpy_annotation import NDArrayType\n", - "# from mlprodict.onnxrt import OnnxInference\n", - "\n", - "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", - "def onnx_rfft(x, fft_length=None):\n", - " if fft_length is None:\n", - " raise RuntimeError(\"fft_length must be specified.\")\n", - " \n", - " size = fft_length // 2 + 1\n", - " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", - " xt = npnx.transpose(x, (1, 0))\n", - " res = npnx.dot(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n", - " return npnx.transpose(res, (0, 2, 1))\n", - "\n", - "fft_onx = onnx_rfft(rnd, fft_length=rnd.shape[1])\n", - "fft_onx" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "c4b6b1a5", - "metadata": {}, - "outputs": [], - "source": [ - "almost_equal(fft_cus, fft_onx)" - ] - }, - { - "cell_type": "markdown", - "id": "a8c35327", - "metadata": {}, - "source": [ - "The corresponding ONNX graph is the 
following:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "4d1a85b0", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "key = list(onnx_rfft.signed_compiled)[0]\n", - "%onnxview onnx_rfft.signed_compiled[key].compiled.onnx_" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "6cf18aca", - "metadata": {}, - "outputs": [], - "source": [ - "fft_onx3 = onnx_rfft(rnd, fft_length=3)\n", - "almost_equal(fft_cus3, fft_onx3)" - ] - }, - { - "cell_type": "markdown", - "id": "6b466fd4", - "metadata": {}, - "source": [ - "## FFT 2D\n", - "\n", - "Below the code for complex features." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "e0020084", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[-2.036582 +0.j , -0.85992725+6.47780438j,\n", - " -3.99332006-3.11192536j, -1.32368431-2.48821071j],\n", - " [ 4.37345155-7.03173815j, -3.14890126+1.59632335j,\n", - " -3.75306979+0.66651699j, -1.56716114+4.75028368j],\n", - " [ 2.76767016+1.25297955j, 5.07926144-2.23393831j,\n", - " 2.41908275-8.55451105j, -8.84556476-1.29356088j],\n", - " [ 2.76767016-1.25297955j, 2.41782872-4.44962381j,\n", - " -3.6501426 +4.13120322j, 4.30875103+0.96179243j],\n", - " [ 4.37345155+7.03173815j, 1.77135529+4.4385736j ,\n", - " 2.40878105+5.40109054j, -1.65462983+0.2149866j ]])" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def _DFT_cst(N, fft_length, trunc=True):\n", - " n = numpy.arange(N)\n", - " k = n.reshape((N, 1)).astype(numpy.float64)\n", - " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n", - " return M[:fft_length // 2 + 1] if trunc else M\n", - "\n", - "def DFT(x, fft_length=None, axis=1):\n", - " if axis == 1:\n", - " x = x.T\n", - " if fft_length is None:\n", - " fft_length = x.shape[0]\n", - " cst = _DFT_cst(x.shape[0], fft_length, trunc=axis==1)\n", - " if axis == 1:\n", - " return numpy.dot(cst, x).T\n", - " return 
numpy.dot(cst, x)\n", - "\n", - "def fft2d_(mat, fft_length):\n", - " mat = mat[:fft_length[0], :fft_length[1]]\n", - " res = mat.copy()\n", - " res = DFT(res, fft_length[1], axis=1)\n", - " res = DFT(res, fft_length[0], axis=0)\n", - " return res[:fft_length[0], :fft_length[1]//2 + 1]\n", - "\n", - "\n", - "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n", - "fft2d_np_ = fft2d_(rnd, rnd.shape)\n", - "fft2d_np = numpy.fft.rfft2(rnd)\n", - "fft2d_np_" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "777d2775", - "metadata": {}, - "outputs": [], - "source": [ - "almost_equal(fft2d_np_, fft2d_np)" - ] - }, - { - "cell_type": "markdown", - "id": "cfbbe2fd", - "metadata": {}, - "source": [ - "It implies the computation of two FFT 1D along both axes. However, as ONNX does not support complex, it needs to be rewritten with only real numbers. The algorithm can be summarized into this formula $FFT(FFT(x, axis=1), axis=0)$. If *x* is real, $FFT(x, .)$ is complex. We still assume *x* is real, it then becomes (FFT is a linear operator, so $FFT(ix)=i FFT(x)$):\n", - "\n", - "* $y = FFT(x, axis=1)$\n", - "* $z_r = FFT(Real(y), axis=0)$, $z_i = FFT(Imag(y), axis=0)$\n", - "* $z = z_r + i z_i$\n", - "\n", - "*z* is the desired output. The following implementation is probably not the most efficient one. It avoids inplace computation as ONNX does like that." 
- ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "dd4fc711", - "metadata": {}, - "outputs": [], - "source": [ - "def fft2d(mat, fft_length):\n", - " mat = mat[:fft_length[0], :fft_length[1]]\n", - " res = mat.copy()\n", - " \n", - " # first FFT\n", - " res = dft_real(res, fft_length=fft_length[1], transpose=True)\n", - " \n", - " # second FFT decomposed on FFT on real part and imaginary part\n", - " res2_real = dft_real(res[0], fft_length=fft_length[0], transpose=False)\n", - " res2_imag = dft_real(res[1], fft_length=fft_length[0], transpose=False) \n", - " res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n", - " res = res2_real + res2_imag2\n", - " size = fft_length[1]//2 + 1\n", - " return res[:, :fft_length[0], :size]\n", - "\n", - "\n", - "fft2d_np = numpy.fft.rfft2(rnd)\n", - "fft2d_cus = fft2d(rnd, rnd.shape)\n", - "almost_equal(fft2d_np, fft2d_cus)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "bb8667e6", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[-2.036582 +0.j , -0.85992725+6.47780438j,\n", - " -3.99332006-3.11192536j, -1.32368431-2.48821071j],\n", - " [ 4.37345155-7.03173815j, -3.14890126+1.59632335j,\n", - " -3.75306979+0.66651699j, -1.56716114+4.75028368j],\n", - " [ 2.76767016+1.25297955j, 5.07926144-2.23393831j,\n", - " 2.41908275-8.55451105j, -8.84556476-1.29356088j],\n", - " [ 2.76767016-1.25297955j, 2.41782872-4.44962381j,\n", - " -3.6501426 +4.13120322j, 4.30875103+0.96179243j],\n", - " [ 4.37345155+7.03173815j, 1.77135529+4.4385736j ,\n", - " 2.40878105+5.40109054j, -1.65462983+0.2149866j ]])" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fft2d_np" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "56a94d97", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[[-2.036582 , -0.85992725, -3.99332006, -1.32368431],\n", - " [ 4.37345155, -3.14890126, 
-3.75306979, -1.56716114],\n", - " [ 2.76767016, 5.07926144, 2.41908275, -8.84556476],\n", - " [ 2.76767016, 2.41782872, -3.6501426 , 4.30875103],\n", - " [ 4.37345155, 1.77135529, 2.40878105, -1.65462983]],\n", - "\n", - " [[ 0. , 6.47780438, -3.11192536, -2.48821071],\n", - " [-7.03173815, 1.59632335, 0.66651699, 4.75028368],\n", - " [ 1.25297955, -2.23393831, -8.55451105, -1.29356088],\n", - " [-1.25297955, -4.44962381, 4.13120322, 0.96179243],\n", - " [ 7.03173815, 4.4385736 , 5.40109054, 0.2149866 ]]])" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fft2d_cus" - ] - }, - { - "cell_type": "markdown", - "id": "faa21909", - "metadata": {}, - "source": [ - "And with a different `fft_length`." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "bf98995f", - "metadata": {}, - "outputs": [], - "source": [ - "fft2d_np = numpy.fft.rfft2(rnd, (4, 6))\n", - "fft2d_cus = fft2d(rnd, (4, 6))\n", - "almost_equal(fft2d_np[:4, :], fft2d_cus)" - ] - }, - { - "cell_type": "markdown", - "id": "caee1f84", - "metadata": {}, - "source": [ - "## FFT 2D in ONNX\n", - "\n", - "We use again the numpy API for ONNX." 
- ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "ca641274", - "metadata": {}, - "outputs": [], - "source": [ - "def onnx_rfft_1d(x, fft_length=None, transpose=True):\n", - " if fft_length is None:\n", - " raise RuntimeError(\"fft_length must be specified.\")\n", - " \n", - " size = fft_length // 2 + 1\n", - " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", - " if transpose:\n", - " xt = npnx.transpose(x, (1, 0))\n", - " res = npnx.dot(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n", - " return npnx.transpose(res, (0, 2, 1))\n", - " else:\n", - " return npnx.dot(cst[:, :, :fft_length], x[:fft_length])\n", - "\n", - "\n", - "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", - "def onnx_rfft_2d(x, fft_length=None):\n", - " mat = x[:fft_length[0], :fft_length[1]]\n", - " \n", - " # first FFT\n", - " res = onnx_rfft_1d(mat, fft_length=fft_length[1], transpose=True)\n", - " \n", - " # second FFT decomposed on FFT on real part and imaginary part\n", - " res2_real = onnx_rfft_1d(res[0], fft_length=fft_length[0], transpose=False)\n", - " res2_imag = onnx_rfft_1d(res[1], fft_length=fft_length[0], transpose=False) \n", - " res2_imag2 = npnx.vstack(-res2_imag[1:2], res2_imag[:1])\n", - " res = res2_real + res2_imag2\n", - " size = fft_length[1]//2 + 1\n", - " return res[:, :fft_length[0], :size]\n", - "\n", - "\n", - "fft2d_cus = fft2d(rnd, rnd.shape)\n", - "fft2d_onx = onnx_rfft_2d(rnd, fft_length=rnd.shape)\n", - "almost_equal(fft2d_cus, fft2d_onx)" - ] - }, - { - "cell_type": "markdown", - "id": "20fcd8a9", - "metadata": {}, - "source": [ - "The corresponding ONNX graph." - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "b1379b06", - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "key = list(onnx_rfft_2d.signed_compiled)[0]\n", - "%onnxview onnx_rfft_2d.signed_compiled[key].compiled.onnx_" - ] - }, - { - "cell_type": "markdown", - "id": "3a747f0c", - "metadata": {}, - "source": [ - "With a different `fft_length`." - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "16732cbb", - "metadata": {}, - "outputs": [], - "source": [ - "fft2d_cus = fft2d(rnd, (4, 5))\n", - "fft2d_onx = onnx_rfft_2d(rnd, fft_length=(4, 5))\n", - "almost_equal(fft2d_cus, fft2d_onx)" - ] - }, - { - "cell_type": "markdown", - "id": "04924e7d", - "metadata": {}, - "source": [ - "This implementation of FFT in ONNX assumes shapes and fft lengths are constant. Otherwise, the matrix returned by function `dft_real_cst` must be converted as well. That's left as an exercise." - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "faeff9cd", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.5" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file +{ + "cells": [ + { + "cell_type": "markdown", + "id": "51bc89fc", + "metadata": {}, + "source": [ + "# ONNX and FFT\n", + "\n", + "ONNX does not fully support complex yet. It does not have any FFT operators either. What if we need them anyway?" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7b2add97", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
run previous cell, wait for 2 seconds
\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from jyquickhelper import add_notebook_menu\n", + "add_notebook_menu()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "acfdc3b0", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext mlprodict" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "abb5fa88", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'1.21.1'" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy\n", + "numpy.__version__" + ] + }, + { + "cell_type": "markdown", + "id": "2e4f68e4", + "metadata": {}, + "source": [ + "## Python implementation of RFFT\n", + "\n", + "We try to replicate [numpy.rfft](https://numpy.org/doc/stable/reference/generated/numpy.fft.rfft.html)." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "cb1cc910", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 0.53789312+0.j , 3.79404014-2.48643182j,\n", + " -1.17830152+1.41518659j, 0.03550683-0.74360516j],\n", + " [ 2.91933875+0.j , -0.80104596+1.95022724j,\n", + " 0.04362985+0.86098245j, -0.12878266-2.96095567j],\n", + " [ 4.03702923+0.j , 4.2701964 +0.97900965j,\n", + " 3.66512519-1.69493214j, 0.46502876-2.36643456j],\n", + " [-0.99284691+0.j , 2.275587 +1.04069498j,\n", + " -2.30580317-1.25203798j, 1.99527731+1.50659889j],\n", + " [-3.28264478+0.j , 3.55459652-2.06178787j,\n", + " -0.46036977-1.45184738j, 1.30470444+0.66049446j]])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy\n", + "\n", + "\n", + "def almost_equal(a, b, error=1e-5):\n", + " \"\"\"\n", + " The function compares two matrices, one may be complex. 
In that case,\n", + " this matrix is changed into a new matrix with a new first dimension,\n", + " [0,::] means real part, [1,::] means imaginary part.\n", + " \"\"\"\n", + " if a.dtype in (numpy.complex64, numpy.complex128):\n", + " dtype = numpy.float64 if a.dtype == numpy.complex128 else numpy.float32\n", + " new_a = numpy.empty((2,) + a.shape).astype(dtype)\n", + " new_a[0] = numpy.real(a)\n", + " new_a[1] = numpy.imag(a)\n", + " return almost_equal(new_a, b, error)\n", + " if b.dtype in (numpy.complex64, numpy.complex128):\n", + " return almost_equal(b, a, error)\n", + " if a.shape != b.shape:\n", + " raise AssertionError(\"Shape mismatch %r != %r.\" % (a.shape, b.shape))\n", + " diff = numpy.abs(a.ravel() - b.ravel()).max()\n", + " if diff > error:\n", + " raise AssertionError(\"Mismatch max diff=%r > %r.\" % (diff, error))\n", + "\n", + "\n", + "def dft_real_cst(N, fft_length):\n", + " n = numpy.arange(N)\n", + " k = n.reshape((N, 1)).astype(numpy.float64)\n", + " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n", + " both = numpy.empty((2,) + M.shape)\n", + " both[0, :, :] = numpy.real(M)\n", + " both[1, :, :] = numpy.imag(M)\n", + " return both\n", + "\n", + "\n", + "def dft_real(x, fft_length=None, transpose=True):\n", + " if len(x.shape) == 1:\n", + " x = x.reshape((1, -1))\n", + " N = 1\n", + " else:\n", + " N = x.shape[0] \n", + " C = x.shape[-1] if transpose else x.shape[-2]\n", + " if fft_length is None:\n", + " fft_length = x.shape[-1]\n", + " size = fft_length // 2 + 1\n", + "\n", + " cst = dft_real_cst(C, fft_length)\n", + " if transpose:\n", + " x = numpy.transpose(x, (1, 0))\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:fft_length]\n", + " res = numpy.matmul(a, b)\n", + " res = res[:, :size, :]\n", + " return numpy.transpose(res, (0, 2, 1))\n", + " else:\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:fft_length]\n", + " return numpy.matmul(a, b)\n", + "\n", + "\n", + "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n", + "fft_np 
= numpy.fft.rfft(rnd)\n", + "fft_cus = dft_real(rnd)\n", + "fft_np" + ] + }, + { + "cell_type": "markdown", + "id": "0c052ea1", + "metadata": {}, + "source": [ + "Function `almost_equal` verifies both functions return the same results." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3ca040cb", + "metadata": {}, + "outputs": [], + "source": [ + "almost_equal(fft_np, fft_cus)" + ] + }, + { + "cell_type": "markdown", + "id": "7fe77440", + "metadata": {}, + "source": [ + "Let's do the same with `fft_length < shape[1]`." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "3a747a4a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 2.75342363+0.j , -0.12520096-0.19458685j],\n", + " [-0.50395307+0.j , 0.49774965-0.57195525j],\n", + " [ 2.9414562 +0.j , 2.99449974-2.68322022j],\n", + " [ 0.79957253+0.j , 0.22962989+0.63038464j],\n", + " [ 1.62652442+0.j , 0.36857013-0.3835761j ]])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fft_np3 = numpy.fft.rfft(rnd, n=3)\n", + "fft_cus3 = dft_real(rnd, fft_length=3)\n", + "fft_np3" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0db6247b", + "metadata": {}, + "outputs": [], + "source": [ + "almost_equal(fft_np3, fft_cus3)" + ] + }, + { + "cell_type": "markdown", + "id": "31a6ac9c", + "metadata": {}, + "source": [ + "## RFFT in ONNX\n", + "\n", + "Let's assume first the number of columns of the input matrix is fixed. The result of function `dft_real_cst` can be considered as constant."
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "efb67b9b", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[ 0.53789306, 3.79404 , -1.1783015 , 0.03550687],\n", + " [ 2.9193387 , -0.80104595, 0.04362985, -0.12878267],\n", + " [ 4.037029 , 4.2701964 , 3.6651251 , 0.46502876],\n", + " [-0.99284697, 2.275587 , -2.305803 , 1.9952773 ],\n", + " [-3.2826447 , 3.5545964 , -0.4603698 , 1.3047044 ]],\n", + "\n", + " [[ 0. , -2.4864316 , 1.4151866 , -0.74360514],\n", + " [ 0. , 1.9502273 , 0.8609825 , -2.9609556 ],\n", + " [ 0. , 0.9790097 , -1.6949322 , -2.3664346 ],\n", + " [ 0. , 1.040695 , -1.2520379 , 1.5065988 ],\n", + " [ 0. , -2.061788 , -1.4518473 , 0.66049445]]],\n", + " dtype=float32)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from typing import Any\n", + "import mlprodict.npy.numpy_onnx_impl as npnx\n", + "from mlprodict.npy import onnxnumpy_np\n", + "from mlprodict.npy.onnx_numpy_annotation import NDArrayType\n", + "# from mlprodict.onnxrt import OnnxInference\n", + "\n", + "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", + "def onnx_rfft(x, fft_length=None):\n", + " if fft_length is None:\n", + " raise RuntimeError(\"fft_length must be specified.\")\n", + " \n", + " size = fft_length // 2 + 1\n", + " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", + " xt = npnx.transpose(x, (1, 0))\n", + " res = npnx.matmul(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n", + " return npnx.transpose(res, (0, 2, 1))\n", + "\n", + "fft_onx = onnx_rfft(rnd, fft_length=rnd.shape[1])\n", + "fft_onx" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c4b6b1a5", + "metadata": {}, + "outputs": [], + "source": [ + "almost_equal(fft_cus, fft_onx)" + ] + }, + { + "cell_type": "markdown", + "id": "a8c35327", + "metadata": {}, + "source": [ + "The corresponding ONNX graph is the 
following:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "4d1a85b0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "key = list(onnx_rfft.signed_compiled)[0]\n", + "%onnxview onnx_rfft.signed_compiled[key].compiled.onnx_" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "6cf18aca", + "metadata": {}, + "outputs": [], + "source": [ + "fft_onx3 = onnx_rfft(rnd, fft_length=3)\n", + "almost_equal(fft_cus3, fft_onx3)" + ] + }, + { + "cell_type": "markdown", + "id": "6b466fd4", + "metadata": {}, + "source": [ + "## FFT 2D\n", + "\n", + "Below the code for complex features." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "e0020084", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[-4.00478568+0.j , -8.1697212 -0.32937158j,\n", + " 10.53796531-6.04714298j, 6.21706082+1.12281921j],\n", + " [ 9.15418879-5.84728016j, 6.2117498 +3.03786463j,\n", + " 2.96815316-2.94192092j, -0.68635391-1.52268308j],\n", + " [ 0.52234925+2.67149571j, -0.54453905-1.55520678j,\n", + " 3.10942332-1.6485952j , -4.10661833+9.33122798j],\n", + " [ 0.52234925-2.67149571j, -0.17987319+2.84609599j,\n", + " -6.48443904+3.97163974j, 8.55488898-3.49370573j],\n", + " [ 9.15418879+5.84728016j, 2.24954303+6.3142503j ,\n", + " 10.0104149 +0.18248643j, 3.74038426+7.20880644j]])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def _DFT_cst(N, fft_length, trunc=True):\n", + " n = numpy.arange(N)\n", + " k = n.reshape((N, 1)).astype(numpy.float64)\n", + " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n", + " return M[:fft_length // 2 + 1] if trunc else M\n", + "\n", + "def DFT(x, fft_length=None, axis=1):\n", + " if axis == 1:\n", + " x = x.T\n", + " if fft_length is None:\n", + " fft_length = x.shape[0]\n", + " cst = _DFT_cst(x.shape[0], fft_length, trunc=axis==1)\n", + " if axis == 1:\n", + " return numpy.matmul(cst, x).T\n", + " return 
numpy.matmul(cst, x)\n", + "\n", + "def fft2d_(mat, fft_length):\n", + " mat = mat[:fft_length[0], :fft_length[1]]\n", + " res = mat.copy()\n", + " res = DFT(res, fft_length[1], axis=1)\n", + " res = DFT(res, fft_length[0], axis=0)\n", + " return res[:fft_length[0], :fft_length[1]//2 + 1]\n", + "\n", + "\n", + "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n", + "fft2d_np_ = fft2d_(rnd, rnd.shape)\n", + "fft2d_np = numpy.fft.rfft2(rnd)\n", + "fft2d_np_" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "777d2775", + "metadata": {}, + "outputs": [], + "source": [ + "almost_equal(fft2d_np_, fft2d_np)" + ] + }, + { + "cell_type": "markdown", + "id": "cfbbe2fd", + "metadata": {}, + "source": [ + "It implies the computation of two FFT 1D along both axes. However, as ONNX does not support complex, it needs to be rewritten with only real numbers. The algorithm can be summarized into this formula $FFT(FFT(x, axis=1), axis=0)$. If *x* is real, $FFT(x, .)$ is complex. We still assume *x* is real, it then becomes (FFT is a linear operator, so $FFT(ix)=i FFT(x)$):\n", + "\n", + "* $y = FFT(x, axis=1)$\n", + "* $z_r = FFT(Real(y), axis=0)$, $z_i = FFT(Imag(y), axis=0)$\n", + "* $z = z_r + i z_i$\n", + "\n", + "*z* is the desired output. The following implementation is probably not the most efficient one. It avoids inplace computation as ONNX does not like that."
+ ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "dd4fc711", + "metadata": {}, + "outputs": [], + "source": [ + "def fft2d(mat, fft_length):\n", + " mat = mat[:fft_length[0], :fft_length[1]]\n", + " res = mat.copy()\n", + " \n", + " # first FFT\n", + " res = dft_real(res, fft_length=fft_length[1], transpose=True)\n", + " \n", + " # second FFT decomposed on FFT on real part and imaginary part\n", + " res2_real = dft_real(res[0], fft_length=fft_length[0], transpose=False)\n", + " res2_imag = dft_real(res[1], fft_length=fft_length[0], transpose=False) \n", + " res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n", + " res = res2_real + res2_imag2\n", + " size = fft_length[1]//2 + 1\n", + " return res[:, :fft_length[0], :size]\n", + "\n", + "\n", + "fft2d_np = numpy.fft.rfft2(rnd)\n", + "fft2d_cus = fft2d(rnd, rnd.shape)\n", + "almost_equal(fft2d_np, fft2d_cus)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "bb8667e6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[-4.00478568+0.j , -8.1697212 -0.32937158j,\n", + " 10.53796531-6.04714298j, 6.21706082+1.12281921j],\n", + " [ 9.15418879-5.84728016j, 6.2117498 +3.03786463j,\n", + " 2.96815316-2.94192092j, -0.68635391-1.52268308j],\n", + " [ 0.52234925+2.67149571j, -0.54453905-1.55520678j,\n", + " 3.10942332-1.6485952j , -4.10661833+9.33122798j],\n", + " [ 0.52234925-2.67149571j, -0.17987319+2.84609599j,\n", + " -6.48443904+3.97163974j, 8.55488898-3.49370573j],\n", + " [ 9.15418879+5.84728016j, 2.24954303+6.3142503j ,\n", + " 10.0104149 +0.18248643j, 3.74038426+7.20880644j]])" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fft2d_np" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "56a94d97", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[-4.00478568, -8.1697212 , 10.53796531, 6.21706082],\n", + " [ 9.15418879, 6.2117498 , 
2.96815316, -0.68635391],\n", + " [ 0.52234925, -0.54453905, 3.10942332, -4.10661833],\n", + " [ 0.52234925, -0.17987319, -6.48443904, 8.55488898],\n", + " [ 9.15418879, 2.24954303, 10.0104149 , 3.74038426]],\n", + "\n", + " [[ 0. , -0.32937158, -6.04714298, 1.12281921],\n", + " [-5.84728016, 3.03786463, -2.94192092, -1.52268308],\n", + " [ 2.67149571, -1.55520678, -1.6485952 , 9.33122798],\n", + " [-2.67149571, 2.84609599, 3.97163974, -3.49370573],\n", + " [ 5.84728016, 6.3142503 , 0.18248643, 7.20880644]]])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fft2d_cus" + ] + }, + { + "cell_type": "markdown", + "id": "faa21909", + "metadata": {}, + "source": [ + "And with a different `fft_length`." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "bf98995f", + "metadata": {}, + "outputs": [], + "source": [ + "fft2d_np = numpy.fft.rfft2(rnd, (4, 6))\n", + "fft2d_cus = fft2d(rnd, (4, 6))\n", + "almost_equal(fft2d_np[:4, :], fft2d_cus)" + ] + }, + { + "cell_type": "markdown", + "id": "caee1f84", + "metadata": {}, + "source": [ + "## FFT 2D in ONNX\n", + "\n", + "We use again the numpy API for ONNX." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "ca641274", + "metadata": {}, + "outputs": [], + "source": [ + "def onnx_rfft_1d(x, fft_length=None, transpose=True):\n", + " if fft_length is None:\n", + " raise RuntimeError(\"fft_length must be specified.\")\n", + " \n", + " size = fft_length // 2 + 1\n", + " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", + " if transpose:\n", + " xt = npnx.transpose(x, (1, 0))\n", + " res = npnx.matmul(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n", + " return npnx.transpose(res, (0, 2, 1))\n", + " else:\n", + " return npnx.matmul(cst[:, :, :fft_length], x[:fft_length])\n", + "\n", + "\n", + "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", + "def onnx_rfft_2d(x, fft_length=None):\n", + " mat = x[:fft_length[0], :fft_length[1]]\n", + " \n", + " # first FFT\n", + " res = onnx_rfft_1d(mat, fft_length=fft_length[1], transpose=True)\n", + " \n", + " # second FFT decomposed on FFT on real part and imaginary part\n", + " res2_real = onnx_rfft_1d(res[0], fft_length=fft_length[0], transpose=False)\n", + " res2_imag = onnx_rfft_1d(res[1], fft_length=fft_length[0], transpose=False) \n", + " res2_imag2 = npnx.vstack(-res2_imag[1:2], res2_imag[:1])\n", + " res = res2_real + res2_imag2\n", + " size = fft_length[1]//2 + 1\n", + " return res[:, :fft_length[0], :size]\n", + "\n", + "\n", + "fft2d_cus = fft2d(rnd, rnd.shape)\n", + "fft2d_onx = onnx_rfft_2d(rnd, fft_length=rnd.shape)\n", + "almost_equal(fft2d_cus, fft2d_onx)" + ] + }, + { + "cell_type": "markdown", + "id": "20fcd8a9", + "metadata": {}, + "source": [ + "The corresponding ONNX graph." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "b1379b06", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "key = list(onnx_rfft_2d.signed_compiled)[0]\n", + "%onnxview onnx_rfft_2d.signed_compiled[key].compiled.onnx_" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "id": "4bae913b", + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"fft2d.onnx\", \"wb\") as f:\n", + " key = list(onnx_rfft_2d.signed_compiled)[0]\n", + " f.write(onnx_rfft_2d.signed_compiled[key].compiled.onnx_.SerializeToString())" + ] + }, + { + "cell_type": "markdown", + "id": "3a747f0c", + "metadata": {}, + "source": [ + "With a different `fft_length`." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "16732cbb", + "metadata": {}, + "outputs": [], + "source": [ + "fft2d_cus = fft2d(rnd, (4, 5))\n", + "fft2d_onx = onnx_rfft_2d(rnd, fft_length=(4, 5))\n", + "almost_equal(fft2d_cus, fft2d_onx)" + ] + }, + { + "cell_type": "markdown", + "id": "04924e7d", + "metadata": {}, + "source": [ + "This implementation of FFT in ONNX assumes shapes and fft lengths are constant. Otherwise, the matrix returned by function `dft_real_cst` must be converted as well. That's left as an exercise." + ] + }, + { + "cell_type": "markdown", + "id": "024db509", + "metadata": {}, + "source": [ + "## FFT2D with shape (3,1,4)\n", + "\n", + "Previous implementation expects the input matrix to have two dimensions. It fails with 3." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "1e4435be", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 1, 4)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "shape = (3, 1, 4)\n", + "fft_length = (1, 4)\n", + "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + "fft2d_numpy = numpy.fft.fft2(rnd, fft_length)\n", + "fft2d_numpy.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "d8ca3ba4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[ 0.70852087+0.j , -1.19732058-1.59361078j,\n", + " -2.03344773+0.j , -1.19732058+1.59361078j]],\n", + "\n", + " [[ 0.28815117+0.j , -1.35137615+2.02822042j,\n", + " 0.9619607 +0.j , -1.35137615-2.02822042j]],\n", + "\n", + " [[-2.07279903+0.j , 0.18918216-1.91049451j,\n", + " -2.78790277+0.j , 0.18918216+1.91049451j]]])" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fft2d_numpy" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "77d52f4f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "axes don't match array\n" + ] + } + ], + "source": [ + "try:\n", + " fft2d_cus = fft2d(rnd, fft_length)\n", + "except Exception as e:\n", + " print(e)\n", + "# fft2d_onx = onnx_rfft_2d(rnd, fft_length=fft_length)" + ] + }, + { + "cell_type": "markdown", + "id": "504d01de", + "metadata": {}, + "source": [ + "### numpy version\n", + "\n", + "Let's do it again with numpy first. [fft2](https://numpy.org/doc/stable/reference/generated/numpy.fft.fft2.html) performs `fft2` on the last two axes, as many times as the size of the first axis. The goal is still to have an implementation which works for any dimension."
+ ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "4157ce65", + "metadata": {}, + "outputs": [], + "source": [ + "conc = []\n", + "for i in range(rnd.shape[0]):\n", + " f2 = fft2d(rnd[i], fft_length)\n", + " conc.append(numpy.expand_dims(f2, 0))\n", + "res = numpy.vstack(conc).transpose(1, 0, 2, 3)\n", + "almost_equal(fft2d_numpy[:, :, :3], res)" + ] + }, + { + "cell_type": "markdown", + "id": "cdc3836a", + "metadata": {}, + "source": [ + "It works. And now a more efficient implementation. It is better to read [matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html) description before. To summarize, a third axis is equivalent to many matrix multiplications over the last two axes, as many as the dimension of the first axis: ``matmul(A[I,J,K], B[I,K,L]) --> C[I,J,L]``. Broadcasting also works... ``matmul(A[1,J,K], B[I,K,L]) --> C[I,J,L]``." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "bf76dd61", + "metadata": {}, + "outputs": [], + "source": [ + "def dft_real_d3(x, fft_length=None, transpose=True):\n", + " if len(x.shape) != 3:\n", + " raise RuntimeError(\"Not implemented for shape=%r.\" % x.shape)\n", + " N = x.shape[1]\n", + " C = x.shape[-1] if transpose else x.shape[-2]\n", + " if fft_length is None:\n", + " fft_length = x.shape[-1]\n", + " size = fft_length // 2 + 1\n", + "\n", + " cst = dft_real_cst(C, fft_length)\n", + " if transpose:\n", + " x = numpy.transpose(x, (0, 2, 1))\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:, :fft_length, :]\n", + " a = numpy.expand_dims(a, 0)\n", + " b = numpy.expand_dims(b, 1)\n", + " res = numpy.matmul(a, b)\n", + " res = res[:, :, :size, :]\n", + " return numpy.transpose(res, (1, 0, 3, 2))\n", + " else:\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:, :fft_length, :]\n", + " a = numpy.expand_dims(a, 0)\n", + " b = numpy.expand_dims(b, 1)\n", + " res = numpy.matmul(a, b)\n", + " return numpy.transpose(res, (1, 0, 2, 3))\n", + "\n", + "\n", + "def 
fft2d_d3(mat, fft_length):\n", + " mat = mat[:, :fft_length[-2], :fft_length[-1]]\n", + " res = mat.copy()\n", + " \n", + " # first FFT\n", + " res = dft_real_d3(res, fft_length=fft_length[-1], transpose=True)\n", + " \n", + " # second FFT decomposed on FFT on real part and imaginary part\n", + " res2_real = dft_real_d3(res[0], fft_length=fft_length[-2], transpose=False)\n", + " res2_imag = dft_real_d3(res[1], fft_length=fft_length[-2], transpose=False)\n", + " res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n", + " res = res2_real + res2_imag2\n", + " size = fft_length[-1]//2 + 1\n", + " return res[:, :, :fft_length[-2], :size]\n", + "\n", + "\n", + "def fft2d_any(mat, fft_length):\n", + " new_shape = (-1, ) + mat.shape[-2:]\n", + " mat2 = mat.reshape(new_shape)\n", + " f2 = fft2d_d3(mat2, fft_length)\n", + " new_shape = (2, ) + mat.shape[:-2] + f2.shape[-2:]\n", + " return f2.reshape(new_shape)\n", + "\n", + "\n", + "shape = (3, 1, 4)\n", + "fft_length = (1, 4)\n", + "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + "fft2d_numpy = numpy.fft.fft2(rnd, fft_length)\n", + "fft2d_cus = fft2d_any(rnd, fft_length)\n", + "almost_equal(fft2d_numpy[..., :3], fft2d_cus)" + ] + }, + { + "cell_type": "markdown", + "id": "8ddf5e67", + "metadata": {}, + "source": [ + "We check with more shapes to see if the implementation works for all of them." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "c2a2d068", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 2) output shape=(3, 1, 2) or (2, 3, 1, 2)\n", + "OK x.shape=(3, 1, 4) length=(1, 1) output shape=(3, 1, 1) or (2, 3, 1, 1)\n", + "OK x.shape=(5, 7) length=(5, 7) output shape=(5, 7) or (2, 5, 4)\n", + "OK x.shape=(5, 7) length=(1, 7) output shape=(1, 7) or (2, 1, 4)\n", + "OK x.shape=(5, 7) length=(2, 7) output shape=(2, 7) or (2, 2, 4)\n", + "OK x.shape=(5, 7) length=(5, 2) output shape=(5, 2) or (2, 5, 2)\n", + "OK x.shape=(5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n", + "OK x.shape=(3, 5, 7) length=(5, 7) output shape=(3, 5, 7) or (2, 3, 5, 4)\n", + "OK x.shape=(3, 5, 7) length=(1, 7) output shape=(3, 1, 7) or (2, 3, 1, 4)\n", + "OK x.shape=(3, 5, 7) length=(2, 7) output shape=(3, 2, 7) or (2, 3, 2, 4)\n", + "OK x.shape=(3, 5, 7) length=(5, 2) output shape=(3, 5, 2) or (2, 3, 5, 2)\n", + "OK x.shape=(3, 5, 7) length=(3, 4) output shape=(3, 3, 4) or (2, 3, 3, 3)\n", + "OK x.shape=(7, 5) length=(7, 5) output shape=(7, 5) or (2, 7, 3)\n", + "OK x.shape=(7, 5) length=(1, 5) output shape=(1, 5) or (2, 1, 3)\n", + "OK x.shape=(7, 5) length=(2, 5) output shape=(2, 5) or (2, 2, 3)\n", + "OK x.shape=(7, 5) length=(7, 2) output shape=(7, 2) or (2, 7, 2)\n", + "OK x.shape=(7, 5) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n" + ] + } + ], + "source": [ + "for shape in [(3, 1, 4), (5, 7), (3, 5, 7), (7, 5)]:\n", + " for fft_length in [shape[-2:], (1, shape[-1]),\n", + " (min(2, shape[-2]), shape[-1]),\n", + " (shape[-2], 2),\n", + " (min(3, shape[-2]), min(4, shape[-2]))]:\n", + " x = 
numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + " fnp = numpy.fft.fft2(x, fft_length)\n", + " if len(fnp.shape) == 2:\n", + " fn= numpy.expand_dims(fnp, 0)\n", + " try:\n", + " cus = fft2d_any(x, fft_length)\n", + " except IndexError as e:\n", + " print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", + " continue\n", + " try:\n", + " almost_equal(fnp[..., :cus.shape[-1]], cus)\n", + " except (AssertionError, IndexError) as e:\n", + " print(\"DIS x.shape=%r length=%r error=%r output shape=%r or %r\" % (\n", + " x.shape, fft_length, e, fnp.shape, cus.shape))\n", + " continue\n", + " print(\"OK x.shape=%r length=%r output shape=%r or %r\" % (\n", + " x.shape, fft_length, fnp.shape, cus.shape))" + ] + }, + { + "cell_type": "markdown", + "id": "01d727fa", + "metadata": {}, + "source": [ + "### ONNX version\n", + "\n", + "Let's look into the differences first." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "fe63d3ed", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext pyquickhelper" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "650c4849", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
populating...
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "/*\n", + "This is part of jsdifflib v1.0. \n", + "\n", + "Copyright 2007 - 2011 Chas Emerick . All rights reserved.\n", + "\n", + "Redistribution and use in source and binary forms, with or without modification, are\n", + "permitted provided that the following conditions are met:\n", + "\n", + " 1. Redistributions of source code must retain the above copyright notice, this list of\n", + " conditions and the following disclaimer.\n", + "\n", + " 2. Redistributions in binary form must reproduce the above copyright notice, this list\n", + " of conditions and the following disclaimer in the documentation and/or other materials\n", + " provided with the distribution.\n", + "\n", + "THIS SOFTWARE IS PROVIDED BY Chas Emerick ``AS IS'' AND ANY EXPRESS OR IMPLIED\n", + "WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND\n", + "FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL Chas Emerick OR\n", + "CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n", + "CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n", + "SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON\n", + "ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING\n", + "NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF\n", + "ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n", + "\n", + "The views and conclusions contained in the software and documentation are those of the\n", + "authors and should not be interpreted as representing official policies, either expressed\n", + "or implied, of Chas Emerick.\n", + "*/\n", + "var diffview = {\n", + "\t/**\n", + "\t * Builds and returns a visual diff view. 
The single parameter, `params', should contain\n", + "\t * the following values:\n", + "\t *\n", + "\t * - baseTextLines: the array of strings that was used as the base text input to SequenceMatcher\n", + "\t * - newTextLines: the array of strings that was used as the new text input to SequenceMatcher\n", + "\t * - opcodes: the array of arrays returned by SequenceMatcher.get_opcodes()\n", + "\t * - baseTextName: the title to be displayed above the base text listing in the diff view; defaults\n", + "\t *\t to \"Base Text\"\n", + "\t * - newTextName: the title to be displayed above the new text listing in the diff view; defaults\n", + "\t *\t to \"New Text\"\n", + "\t * - contextSize: the number of lines of context to show around differences; by default, all lines\n", + "\t *\t are shown\n", + "\t * - viewType: if 0, a side-by-side diff view is generated (default); if 1, an inline diff view is\n", + "\t *\t generated\n", + "\t */\n", + "\tbuildView: function (params) {\n", + "\t\tvar baseTextLines = params.baseTextLines;\n", + "\t\tvar newTextLines = params.newTextLines;\n", + "\t\tvar opcodes = params.opcodes;\n", + "\t\tvar baseTextName = params.baseTextName ? params.baseTextName : \"Base Text\";\n", + "\t\tvar newTextName = params.newTextName ? params.newTextName : \"New Text\";\n", + "\t\tvar contextSize = params.contextSize;\n", + "\t\tvar inline = (params.viewType == 0 || params.viewType == 1) ? 
params.viewType : 0;\n", + "\n", + "\t\tif (baseTextLines == null)\n", + "\t\t\tthrow \"Cannot build diff view; baseTextLines is not defined.\";\n", + "\t\tif (newTextLines == null)\n", + "\t\t\tthrow \"Cannot build diff view; newTextLines is not defined.\";\n", + "\t\tif (!opcodes)\n", + "\t\t\tthrow \"Cannot build diff view; opcodes is not defined.\";\n", + "\t\t\n", + "\t\tfunction celt (name, clazz) {\n", + "\t\t\tvar e = document.createElement(name);\n", + "\t\t\te.className = clazz;\n", + "\t\t\treturn e;\n", + "\t\t}\n", + "\t\t\n", + "\t\tfunction telt (name, text) {\n", + "\t\t\tvar e = document.createElement(name);\n", + "\t\t\te.appendChild(document.createTextNode(text));\n", + "\t\t\treturn e;\n", + "\t\t}\n", + "\t\t\n", + "\t\tfunction ctelt (name, clazz, text) {\n", + "\t\t\tvar e = document.createElement(name);\n", + "\t\t\te.className = clazz;\n", + "\t\t\te.appendChild(document.createTextNode(text));\n", + "\t\t\treturn e;\n", + "\t\t}\n", + "\t\n", + "\t\tvar tdata = document.createElement(\"thead\");\n", + "\t\tvar node = document.createElement(\"tr\");\n", + "\t\ttdata.appendChild(node);\n", + "\t\tif (inline) {\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", baseTextName + \" vs. 
\" + newTextName));\n", + "\t\t} else {\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", baseTextName));\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", newTextName));\n", + "\t\t}\n", + "\t\ttdata = [tdata];\n", + "\t\t\n", + "\t\tvar rows = [];\n", + "\t\tvar node2;\n", + "\t\t\n", + "\t\t/**\n", + "\t\t * Adds two cells to the given row; if the given row corresponds to a real\n", + "\t\t * line number (based on the line index tidx and the endpoint of the \n", + "\t\t * range in question tend), then the cells will contain the line number\n", + "\t\t * and the line of text from textLines at position tidx (with the class of\n", + "\t\t * the second cell set to the name of the change represented), and tidx + 1 will\n", + "\t\t * be returned.\t Otherwise, tidx is returned, and two empty cells are added\n", + "\t\t * to the given row.\n", + "\t\t */\n", + "\t\tfunction addCells (row, tidx, tend, textLines, change) {\n", + "\t\t\tif (tidx < tend) {\n", + "\t\t\t\trow.appendChild(telt(\"th\", (tidx + 1).toString()));\n", + "\t\t\t\trow.appendChild(ctelt(\"td\", change, textLines[tidx].replace(/\\t/g, \"\\u00a0\\u00a0\\u00a0\\u00a0\")));\n", + "\t\t\t\treturn tidx + 1;\n", + "\t\t\t} else {\n", + "\t\t\t\trow.appendChild(document.createElement(\"th\"));\n", + "\t\t\t\trow.appendChild(celt(\"td\", \"empty\"));\n", + "\t\t\t\treturn tidx;\n", + "\t\t\t}\n", + "\t\t}\n", + "\t\t\n", + "\t\tfunction addCellsInline (row, tidx, tidx2, textLines, change) {\n", + "\t\t\trow.appendChild(telt(\"th\", tidx == null ? \"\" : (tidx + 1).toString()));\n", + "\t\t\trow.appendChild(telt(\"th\", tidx2 == null ? \"\" : (tidx2 + 1).toString()));\n", + "\t\t\trow.appendChild(ctelt(\"td\", change, textLines[tidx != null ? 
tidx : tidx2].replace(/\\t/g, \"\\u00a0\\u00a0\\u00a0\\u00a0\")));\n", + "\t\t}\n", + "\t\t\n", + "\t\tfor (var idx = 0; idx < opcodes.length; idx++) {\n", + "\t\t\tvar code = opcodes[idx];\n", + "\t\t\tvar change = code[0];\n", + "\t\t\tvar b = code[1];\n", + "\t\t\tvar be = code[2];\n", + "\t\t\tvar n = code[3];\n", + "\t\t\tvar ne = code[4];\n", + "\t\t\tvar rowcnt = Math.max(be - b, ne - n);\n", + "\t\t\tvar toprows = [];\n", + "\t\t\tvar botrows = [];\n", + "\t\t\tfor (var i = 0; i < rowcnt; i++) {\n", + "\t\t\t\t// jump ahead if we've alredy provided leading context or if this is the first range\n", + "\t\t\t\tif (contextSize && opcodes.length > 1 && ((idx > 0 && i == contextSize) || (idx == 0 && i == 0)) && change==\"equal\") {\n", + "\t\t\t\t\tvar jump = rowcnt - ((idx == 0 ? 1 : 2) * contextSize);\n", + "\t\t\t\t\tif (jump > 1) {\n", + "\t\t\t\t\t\ttoprows.push(node = document.createElement(\"tr\"));\n", + "\t\t\t\t\t\t\n", + "\t\t\t\t\t\tb += jump;\n", + "\t\t\t\t\t\tn += jump;\n", + "\t\t\t\t\t\ti += jump - 1;\n", + "\t\t\t\t\t\tnode.appendChild(telt(\"th\", \"...\"));\n", + "\t\t\t\t\t\tif (!inline) node.appendChild(ctelt(\"td\", \"skip\", \"\"));\n", + "\t\t\t\t\t\tnode.appendChild(telt(\"th\", \"...\"));\n", + "\t\t\t\t\t\tnode.appendChild(ctelt(\"td\", \"skip\", \"\"));\n", + "\t\t\t\t\t\t\n", + "\t\t\t\t\t\t// skip last lines if they're all equal\n", + "\t\t\t\t\t\tif (idx + 1 == opcodes.length) {\n", + "\t\t\t\t\t\t\tbreak;\n", + "\t\t\t\t\t\t} else {\n", + "\t\t\t\t\t\t\tcontinue;\n", + "\t\t\t\t\t\t}\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\t\n", + "\t\t\t\ttoprows.push(node = document.createElement(\"tr\"));\n", + "\t\t\t\tif (inline) {\n", + "\t\t\t\t\tif (change == \"insert\") {\n", + "\t\t\t\t\t\taddCellsInline(node, null, n++, newTextLines, change);\n", + "\t\t\t\t\t} else if (change == \"replace\") {\n", + "\t\t\t\t\t\tbotrows.push(node2 = document.createElement(\"tr\"));\n", + "\t\t\t\t\t\tif (b < be) addCellsInline(node, b++, 
null, baseTextLines, \"delete\");\n", + "\t\t\t\t\t\tif (n < ne) addCellsInline(node2, null, n++, newTextLines, \"insert\");\n", + "\t\t\t\t\t} else if (change == \"delete\") {\n", + "\t\t\t\t\t\taddCellsInline(node, b++, null, baseTextLines, change);\n", + "\t\t\t\t\t} else {\n", + "\t\t\t\t\t\t// equal\n", + "\t\t\t\t\t\taddCellsInline(node, b++, n++, baseTextLines, change);\n", + "\t\t\t\t\t}\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tb = addCells(node, b, be, baseTextLines, change);\n", + "\t\t\t\t\tn = addCells(node, n, ne, newTextLines, change);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\n", + "\t\t\tfor (var i = 0; i < toprows.length; i++) rows.push(toprows[i]);\n", + "\t\t\tfor (var i = 0; i < botrows.length; i++) rows.push(botrows[i]);\n", + "\t\t}\n", + "\t\t\n", + "\t\trows.push(node = ctelt(\"th\", \"author\", \"diff view generated by \"));\n", + "\t\tnode.setAttribute(\"colspan\", inline ? 3 : 4);\n", + "\t\tnode.appendChild(node2 = telt(\"a\", \"jsdifflib\"));\n", + "\t\tnode2.setAttribute(\"href\", \"http://github.com/cemerick/jsdifflib\");\n", + "\t\t\n", + "\t\ttdata.push(node = document.createElement(\"tbody\"));\n", + "\t\tfor (var idx in rows) rows.hasOwnProperty(idx) && node.appendChild(rows[idx]);\n", + "\t\t\n", + "\t\tnode = celt(\"table\", \"diff\" + (inline ? \" inlinediff\" : \"\"));\n", + "\t\tfor (var idx in tdata) tdata.hasOwnProperty(idx) && node.appendChild(tdata[idx]);\n", + "\t\treturn node;\n", + "\t}\n", + "};\n", + "\n", + "\n", + "/***\n", + "This is part of jsdifflib v1.0. 
\n", + "\n", + "Copyright (c) 2007, Snowtide Informatics Systems, Inc.\n", + "All rights reserved.\n", + "\n", + "Redistribution and use in source and binary forms, with or without modification,\n", + "are permitted provided that the following conditions are met:\n", + "\n", + "\t* Redistributions of source code must retain the above copyright notice, this\n", + "\t\tlist of conditions and the following disclaimer.\n", + "\t* Redistributions in binary form must reproduce the above copyright notice,\n", + "\t\tthis list of conditions and the following disclaimer in the documentation\n", + "\t\tand/or other materials provided with the distribution.\n", + "\t* Neither the name of the Snowtide Informatics Systems nor the names of its\n", + "\t\tcontributors may be used to endorse or promote products derived from this\n", + "\t\tsoftware without specific prior written permission.\n", + "\n", + "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY\n", + "EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES\n", + "OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT\n", + "SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,\n", + "INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED\n", + "TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR\n", + "BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n", + "CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN\n", + "ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH\n", + "DAMAGE.\n", + "***/\n", + "/* Author: Chas Emerick */\n", + "var __whitespace = {\" \":true, \"\\t\":true, \"\\n\":true, \"\\f\":true, \"\\r\":true};\n", + "\n", + "var difflib = {\n", + "\tdefaultJunkFunction: function (c) {\n", + "\t\treturn __whitespace.hasOwnProperty(c);\n", + "\t},\n", + "\t\n", + "\tstripLinebreaks: function (str) { return str.replace(/^[\\n\\r]*|[\\n\\r]*$/g, \"\"); },\n", + "\t\n", + "\tstringAsLines: function (str) {\n", + "\t\tvar lfpos = str.indexOf(\"\\n\");\n", + "\t\tvar crpos = str.indexOf(\"\\r\");\n", + "\t\tvar linebreak = ((lfpos > -1 && crpos > -1) || crpos < 0) ? 
\"\\n\" : \"\\r\";\n", + "\t\t\n", + "\t\tvar lines = str.split(linebreak);\n", + "\t\tfor (var i = 0; i < lines.length; i++) {\n", + "\t\t\tlines[i] = difflib.stripLinebreaks(lines[i]);\n", + "\t\t}\n", + "\t\t\n", + "\t\treturn lines;\n", + "\t},\n", + "\t\n", + "\t// iteration-based reduce implementation\n", + "\t__reduce: function (func, list, initial) {\n", + "\t\tif (initial != null) {\n", + "\t\t\tvar value = initial;\n", + "\t\t\tvar idx = 0;\n", + "\t\t} else if (list) {\n", + "\t\t\tvar value = list[0];\n", + "\t\t\tvar idx = 1;\n", + "\t\t} else {\n", + "\t\t\treturn null;\n", + "\t\t}\n", + "\t\t\n", + "\t\tfor (; idx < list.length; idx++) {\n", + "\t\t\tvalue = func(value, list[idx]);\n", + "\t\t}\n", + "\t\t\n", + "\t\treturn value;\n", + "\t},\n", + "\t\n", + "\t// comparison function for sorting lists of numeric tuples\n", + "\t__ntuplecomp: function (a, b) {\n", + "\t\tvar mlen = Math.max(a.length, b.length);\n", + "\t\tfor (var i = 0; i < mlen; i++) {\n", + "\t\t\tif (a[i] < b[i]) return -1;\n", + "\t\t\tif (a[i] > b[i]) return 1;\n", + "\t\t}\n", + "\t\t\n", + "\t\treturn a.length == b.length ? 0 : (a.length < b.length ? -1 : 1);\n", + "\t},\n", + "\t\n", + "\t__calculate_ratio: function (matches, length) {\n", + "\t\treturn length ? 2.0 * matches / length : 1.0;\n", + "\t},\n", + "\t\n", + "\t// returns a function that returns true if a key passed to the returned function\n", + "\t// is in the dict (js object) provided to this function; replaces being able to\n", + "\t// carry around dict.has_key in python...\n", + "\t__isindict: function (dict) {\n", + "\t\treturn function (key) { return dict.hasOwnProperty(key); };\n", + "\t},\n", + "\t\n", + "\t// replacement for python's dict.get function -- need easy default values\n", + "\t__dictget: function (dict, key, defaultValue) {\n", + "\t\treturn dict.hasOwnProperty(key) ? 
dict[key] : defaultValue;\n", + "\t},\t\n", + "\t\n", + "\tSequenceMatcher: function (a, b, isjunk) {\n", + "\t\tthis.set_seqs = function (a, b) {\n", + "\t\t\tthis.set_seq1(a);\n", + "\t\t\tthis.set_seq2(b);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.set_seq1 = function (a) {\n", + "\t\t\tif (a == this.a) return;\n", + "\t\t\tthis.a = a;\n", + "\t\t\tthis.matching_blocks = this.opcodes = null;\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.set_seq2 = function (b) {\n", + "\t\t\tif (b == this.b) return;\n", + "\t\t\tthis.b = b;\n", + "\t\t\tthis.matching_blocks = this.opcodes = this.fullbcount = null;\n", + "\t\t\tthis.__chain_b();\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.__chain_b = function () {\n", + "\t\t\tvar b = this.b;\n", + "\t\t\tvar n = b.length;\n", + "\t\t\tvar b2j = this.b2j = {};\n", + "\t\t\tvar populardict = {};\n", + "\t\t\tfor (var i = 0; i < b.length; i++) {\n", + "\t\t\t\tvar elt = b[i];\n", + "\t\t\t\tif (b2j.hasOwnProperty(elt)) {\n", + "\t\t\t\t\tvar indices = b2j[elt];\n", + "\t\t\t\t\tif (n >= 200 && indices.length * 100 > n) {\n", + "\t\t\t\t\t\tpopulardict[elt] = 1;\n", + "\t\t\t\t\t\tdelete b2j[elt];\n", + "\t\t\t\t\t} else {\n", + "\t\t\t\t\t\tindices.push(i);\n", + "\t\t\t\t\t}\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tb2j[elt] = [i];\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\n", + "\t\t\tfor (var elt in populardict) {\n", + "\t\t\t\tif (populardict.hasOwnProperty(elt)) {\n", + "\t\t\t\t\tdelete b2j[elt];\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tvar isjunk = this.isjunk;\n", + "\t\t\tvar junkdict = {};\n", + "\t\t\tif (isjunk) {\n", + "\t\t\t\tfor (var elt in populardict) {\n", + "\t\t\t\t\tif (populardict.hasOwnProperty(elt) && isjunk(elt)) {\n", + "\t\t\t\t\t\tjunkdict[elt] = 1;\n", + "\t\t\t\t\t\tdelete populardict[elt];\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\tfor (var elt in b2j) {\n", + "\t\t\t\t\tif (b2j.hasOwnProperty(elt) && isjunk(elt)) {\n", + "\t\t\t\t\t\tjunkdict[elt] = 1;\n", + "\t\t\t\t\t\tdelete 
b2j[elt];\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\n", + "\t\t\tthis.isbjunk = difflib.__isindict(junkdict);\n", + "\t\t\tthis.isbpopular = difflib.__isindict(populardict);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.find_longest_match = function (alo, ahi, blo, bhi) {\n", + "\t\t\tvar a = this.a;\n", + "\t\t\tvar b = this.b;\n", + "\t\t\tvar b2j = this.b2j;\n", + "\t\t\tvar isbjunk = this.isbjunk;\n", + "\t\t\tvar besti = alo;\n", + "\t\t\tvar bestj = blo;\n", + "\t\t\tvar bestsize = 0;\n", + "\t\t\tvar j = null;\n", + "\t\t\tvar k;\n", + "\t\n", + "\t\t\tvar j2len = {};\n", + "\t\t\tvar nothing = [];\n", + "\t\t\tfor (var i = alo; i < ahi; i++) {\n", + "\t\t\t\tvar newj2len = {};\n", + "\t\t\t\tvar jdict = difflib.__dictget(b2j, a[i], nothing);\n", + "\t\t\t\tfor (var jkey in jdict) {\n", + "\t\t\t\t\tif (jdict.hasOwnProperty(jkey)) {\n", + "\t\t\t\t\t\tj = jdict[jkey];\n", + "\t\t\t\t\t\tif (j < blo) continue;\n", + "\t\t\t\t\t\tif (j >= bhi) break;\n", + "\t\t\t\t\t\tnewj2len[j] = k = difflib.__dictget(j2len, j - 1, 0) + 1;\n", + "\t\t\t\t\t\tif (k > bestsize) {\n", + "\t\t\t\t\t\t\tbesti = i - k + 1;\n", + "\t\t\t\t\t\t\tbestj = j - k + 1;\n", + "\t\t\t\t\t\t\tbestsize = k;\n", + "\t\t\t\t\t\t}\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\tj2len = newj2len;\n", + "\t\t\t}\n", + "\t\n", + "\t\t\twhile (besti > alo && bestj > blo && !isbjunk(b[bestj - 1]) && a[besti - 1] == b[bestj - 1]) {\n", + "\t\t\t\tbesti--;\n", + "\t\t\t\tbestj--;\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\t\t\t\n", + "\t\t\twhile (besti + bestsize < ahi && bestj + bestsize < bhi &&\n", + "\t\t\t\t\t!isbjunk(b[bestj + bestsize]) &&\n", + "\t\t\t\t\ta[besti + bestsize] == b[bestj + bestsize]) {\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\n", + "\t\t\twhile (besti > alo && bestj > blo && isbjunk(b[bestj - 1]) && a[besti - 1] == b[bestj - 1]) {\n", + "\t\t\t\tbesti--;\n", + "\t\t\t\tbestj--;\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\t\t\n", + 
"\t\t\twhile (besti + bestsize < ahi && bestj + bestsize < bhi && isbjunk(b[bestj + bestsize]) &&\n", + "\t\t\t\t\ta[besti + bestsize] == b[bestj + bestsize]) {\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\n", + "\t\t\treturn [besti, bestj, bestsize];\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.get_matching_blocks = function () {\n", + "\t\t\tif (this.matching_blocks != null) return this.matching_blocks;\n", + "\t\t\tvar la = this.a.length;\n", + "\t\t\tvar lb = this.b.length;\n", + "\t\n", + "\t\t\tvar queue = [[0, la, 0, lb]];\n", + "\t\t\tvar matching_blocks = [];\n", + "\t\t\tvar alo, ahi, blo, bhi, qi, i, j, k, x;\n", + "\t\t\twhile (queue.length) {\n", + "\t\t\t\tqi = queue.pop();\n", + "\t\t\t\talo = qi[0];\n", + "\t\t\t\tahi = qi[1];\n", + "\t\t\t\tblo = qi[2];\n", + "\t\t\t\tbhi = qi[3];\n", + "\t\t\t\tx = this.find_longest_match(alo, ahi, blo, bhi);\n", + "\t\t\t\ti = x[0];\n", + "\t\t\t\tj = x[1];\n", + "\t\t\t\tk = x[2];\n", + "\t\n", + "\t\t\t\tif (k) {\n", + "\t\t\t\t\tmatching_blocks.push(x);\n", + "\t\t\t\t\tif (alo < i && blo < j)\n", + "\t\t\t\t\t\tqueue.push([alo, i, blo, j]);\n", + "\t\t\t\t\tif (i+k < ahi && j+k < bhi)\n", + "\t\t\t\t\t\tqueue.push([i + k, ahi, j + k, bhi]);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tmatching_blocks.sort(difflib.__ntuplecomp);\n", + "\t\n", + "\t\t\tvar i1 = 0, j1 = 0, k1 = 0, block = 0;\n", + "\t\t\tvar i2, j2, k2;\n", + "\t\t\tvar non_adjacent = [];\n", + "\t\t\tfor (var idx in matching_blocks) {\n", + "\t\t\t\tif (matching_blocks.hasOwnProperty(idx)) {\n", + "\t\t\t\t\tblock = matching_blocks[idx];\n", + "\t\t\t\t\ti2 = block[0];\n", + "\t\t\t\t\tj2 = block[1];\n", + "\t\t\t\t\tk2 = block[2];\n", + "\t\t\t\t\tif (i1 + k1 == i2 && j1 + k1 == j2) {\n", + "\t\t\t\t\t\tk1 += k2;\n", + "\t\t\t\t\t} else {\n", + "\t\t\t\t\t\tif (k1) non_adjacent.push([i1, j1, k1]);\n", + "\t\t\t\t\t\ti1 = i2;\n", + "\t\t\t\t\t\tj1 = j2;\n", + "\t\t\t\t\t\tk1 = k2;\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + 
"\t\t\t}\n", + "\t\t\t\n", + "\t\t\tif (k1) non_adjacent.push([i1, j1, k1]);\n", + "\t\n", + "\t\t\tnon_adjacent.push([la, lb, 0]);\n", + "\t\t\tthis.matching_blocks = non_adjacent;\n", + "\t\t\treturn this.matching_blocks;\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.get_opcodes = function () {\n", + "\t\t\tif (this.opcodes != null) return this.opcodes;\n", + "\t\t\tvar i = 0;\n", + "\t\t\tvar j = 0;\n", + "\t\t\tvar answer = [];\n", + "\t\t\tthis.opcodes = answer;\n", + "\t\t\tvar block, ai, bj, size, tag;\n", + "\t\t\tvar blocks = this.get_matching_blocks();\n", + "\t\t\tfor (var idx in blocks) {\n", + "\t\t\t\tif (blocks.hasOwnProperty(idx)) {\n", + "\t\t\t\t\tblock = blocks[idx];\n", + "\t\t\t\t\tai = block[0];\n", + "\t\t\t\t\tbj = block[1];\n", + "\t\t\t\t\tsize = block[2];\n", + "\t\t\t\t\ttag = '';\n", + "\t\t\t\t\tif (i < ai && j < bj) {\n", + "\t\t\t\t\t\ttag = 'replace';\n", + "\t\t\t\t\t} else if (i < ai) {\n", + "\t\t\t\t\t\ttag = 'delete';\n", + "\t\t\t\t\t} else if (j < bj) {\n", + "\t\t\t\t\t\ttag = 'insert';\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\tif (tag) answer.push([tag, i, ai, j, bj]);\n", + "\t\t\t\t\ti = ai + size;\n", + "\t\t\t\t\tj = bj + size;\n", + "\t\t\t\t\t\n", + "\t\t\t\t\tif (size) answer.push(['equal', ai, i, bj, j]);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\treturn answer;\n", + "\t\t}\n", + "\t\t\n", + "\t\t// this is a generator function in the python lib, which of course is not supported in javascript\n", + "\t\t// the reimplementation builds up the grouped opcodes into a list in their entirety and returns that.\n", + "\t\tthis.get_grouped_opcodes = function (n) {\n", + "\t\t\tif (!n) n = 3;\n", + "\t\t\tvar codes = this.get_opcodes();\n", + "\t\t\tif (!codes) codes = [[\"equal\", 0, 1, 0, 1]];\n", + "\t\t\tvar code, tag, i1, i2, j1, j2;\n", + "\t\t\tif (codes[0][0] == 'equal') {\n", + "\t\t\t\tcode = codes[0];\n", + "\t\t\t\ttag = code[0];\n", + "\t\t\t\ti1 = code[1];\n", + "\t\t\t\ti2 = code[2];\n", + "\t\t\t\tj1 = 
code[3];\n", + "\t\t\t\tj2 = code[4];\n", + "\t\t\t\tcodes[0] = [tag, Math.max(i1, i2 - n), i2, Math.max(j1, j2 - n), j2];\n", + "\t\t\t}\n", + "\t\t\tif (codes[codes.length - 1][0] == 'equal') {\n", + "\t\t\t\tcode = codes[codes.length - 1];\n", + "\t\t\t\ttag = code[0];\n", + "\t\t\t\ti1 = code[1];\n", + "\t\t\t\ti2 = code[2];\n", + "\t\t\t\tj1 = code[3];\n", + "\t\t\t\tj2 = code[4];\n", + "\t\t\t\tcodes[codes.length - 1] = [tag, i1, Math.min(i2, i1 + n), j1, Math.min(j2, j1 + n)];\n", + "\t\t\t}\n", + "\t\n", + "\t\t\tvar nn = n + n;\n", + "\t\t\tvar group = [];\n", + "\t\t\tvar groups = [];\n", + "\t\t\tfor (var idx in codes) {\n", + "\t\t\t\tif (codes.hasOwnProperty(idx)) {\n", + "\t\t\t\t\tcode = codes[idx];\n", + "\t\t\t\t\ttag = code[0];\n", + "\t\t\t\t\ti1 = code[1];\n", + "\t\t\t\t\ti2 = code[2];\n", + "\t\t\t\t\tj1 = code[3];\n", + "\t\t\t\t\tj2 = code[4];\n", + "\t\t\t\t\tif (tag == 'equal' && i2 - i1 > nn) {\n", + "\t\t\t\t\t\tgroup.push([tag, i1, Math.min(i2, i1 + n), j1, Math.min(j2, j1 + n)]);\n", + "\t\t\t\t\t\tgroups.push(group);\n", + "\t\t\t\t\t\tgroup = [];\n", + "\t\t\t\t\t\ti1 = Math.max(i1, i2-n);\n", + "\t\t\t\t\t\tj1 = Math.max(j1, j2-n);\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\t\n", + "\t\t\t\t\tgroup.push([tag, i1, i2, j1, j2]);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tif (group && !(group.length == 1 && group[0][0] == 'equal')) groups.push(group)\n", + "\t\t\t\n", + "\t\t\treturn groups;\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.ratio = function () {\n", + "\t\t\tmatches = difflib.__reduce(\n", + "\t\t\t\t\t\t\tfunction (sum, triple) { return sum + triple[triple.length - 1]; },\n", + "\t\t\t\t\t\t\tthis.get_matching_blocks(), 0);\n", + "\t\t\treturn difflib.__calculate_ratio(matches, this.a.length + this.b.length);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.quick_ratio = function () {\n", + "\t\t\tvar fullbcount, elt;\n", + "\t\t\tif (this.fullbcount == null) {\n", + "\t\t\t\tthis.fullbcount = fullbcount = {};\n", + 
"\t\t\t\tfor (var i = 0; i < this.b.length; i++) {\n", + "\t\t\t\t\telt = this.b[i];\n", + "\t\t\t\t\tfullbcount[elt] = difflib.__dictget(fullbcount, elt, 0) + 1;\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\tfullbcount = this.fullbcount;\n", + "\t\n", + "\t\t\tvar avail = {};\n", + "\t\t\tvar availhas = difflib.__isindict(avail);\n", + "\t\t\tvar matches = numb = 0;\n", + "\t\t\tfor (var i = 0; i < this.a.length; i++) {\n", + "\t\t\t\telt = this.a[i];\n", + "\t\t\t\tif (availhas(elt)) {\n", + "\t\t\t\t\tnumb = avail[elt];\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tnumb = difflib.__dictget(fullbcount, elt, 0);\n", + "\t\t\t\t}\n", + "\t\t\t\tavail[elt] = numb - 1;\n", + "\t\t\t\tif (numb > 0) matches++;\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\treturn difflib.__calculate_ratio(matches, this.a.length + this.b.length);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.real_quick_ratio = function () {\n", + "\t\t\tvar la = this.a.length;\n", + "\t\t\tvar lb = this.b.length;\n", + "\t\t\treturn _calculate_ratio(Math.min(la, lb), la + lb);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.isjunk = isjunk ? 
isjunk : difflib.defaultJunkFunction;\n", + "\t\tthis.a = this.b = null;\n", + "\t\tthis.set_seqs(a, b);\n", + "\t}\n", + "};\n", + "\n", + "\n", + "\n", + "function diffUsingJS (viewType, contextSize, baseText, newText) {\n", + "\n", + " var byId = function (id) { return document.getElementById(id); },\n", + " base = difflib.stringAsLines(baseText),\n", + " newtxt = difflib.stringAsLines(newText),\n", + " sm = new difflib.SequenceMatcher(base, newtxt),\n", + " opcodes = sm.get_opcodes(),\n", + " diffoutputdiv = byId(\"diffid_2021-08-05_15_22_33_732030\");\n", + "\n", + " diffoutputdiv.innerHTML = \"\";\n", + " contextSize = contextSize || null;\n", + "\n", + " diffoutputdiv.appendChild(diffview.buildView({\n", + " baseTextLines: base,\n", + " newTextLines: newtxt,\n", + " opcodes: opcodes,\n", + " baseTextName: \"Base Text\",\n", + " newTextName: \"New Text\",\n", + " contextSize: contextSize,\n", + " viewType: viewType\n", + " }));\n", + "}\n", + "var tview=0;\n", + "var csize='';\n", + "var bt = 'def dft_real(x, fft_length=None, transpose=True):\\n if len(x.shape) == 1:\\n x = x.reshape((1, -1))\\n N = 1\\n else:\\n N = x.shape[0] \\n C = x.shape[-1] if transpose else x.shape[-2]\\n if fft_length is None:\\n fft_length = x.shape[-1]\\n size = fft_length // 2 + 1\\n\\n cst = dft_real_cst(C, fft_length)\\n if transpose:\\n x = numpy.transpose(x, (1, 0))\\n a = cst[:, :, :fft_length]\\n b = x[:fft_length]\\n res = numpy.matmul(a, b)\\n res = res[:, :size, :]\\n return numpy.transpose(res, (0, 2, 1))\\n else:\\n a = cst[:, :, :fft_length]\\n b = x[:fft_length]\\n return numpy.matmul(a, b)\\n';\n", + "var nt = 'def dft_real_d3(x, fft_length=None, transpose=True):\\n if len(x.shape) != 3:\\n raise RuntimeError(\"Not implemented for shape=%r.\" % x.shape)\\n N = x.shape[1]\\n C = x.shape[-1] if transpose else x.shape[-2]\\n if fft_length is None:\\n fft_length = x.shape[-1]\\n size = fft_length // 2 + 1\\n\\n cst = dft_real_cst(C, fft_length)\\n if transpose:\\n x = 
numpy.transpose(x, (0, 2, 1))\\n a = cst[:, :, :fft_length]\\n b = x[:, :fft_length, :]\\n a = numpy.expand_dims(a, 0)\\n b = numpy.expand_dims(b, 1)\\n res = numpy.matmul(a, b)\\n res = res[:, :, :size, :]\\n return numpy.transpose(res, (1, 0, 3, 2))\\n else:\\n a = cst[:, :, :fft_length]\\n b = x[:, :fft_length, :]\\n a = numpy.expand_dims(a, 0)\\n b = numpy.expand_dims(b, 1)\\n res = numpy.matmul(a, b)\\n return numpy.transpose(res, (1, 0, 2, 3))\\n';\n", + "diffUsingJS(tview, csize, bt, nt) ;\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import inspect\n", + "text1 = inspect.getsource(dft_real)\n", + "text2 = inspect.getsource(dft_real_d3)\n", + "%textdiff text1 text2" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "b815568f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
populating...
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "/*\n", + "This is part of jsdifflib v1.0. \n", + "\n", + "Copyright 2007 - 2011 Chas Emerick . All rights reserved.\n", + "\n", + "Redistribution and use in source and binary forms, with or without modification, are\n", + "permitted provided that the following conditions are met:\n", + "\n", + " 1. Redistributions of source code must retain the above copyright notice, this list of\n", + " conditions and the following disclaimer.\n", + "\n", + " 2. Redistributions in binary form must reproduce the above copyright notice, this list\n", + " of conditions and the following disclaimer in the documentation and/or other materials\n", + " provided with the distribution.\n", + "\n", + "THIS SOFTWARE IS PROVIDED BY Chas Emerick ``AS IS'' AND ANY EXPRESS OR IMPLIED\n", + "WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND\n", + "FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL Chas Emerick OR\n", + "CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n", + "CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n", + "SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON\n", + "ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING\n", + "NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF\n", + "ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n", + "\n", + "The views and conclusions contained in the software and documentation are those of the\n", + "authors and should not be interpreted as representing official policies, either expressed\n", + "or implied, of Chas Emerick.\n", + "*/\n", + "var diffview = {\n", + "\t/**\n", + "\t * Builds and returns a visual diff view. 
The single parameter, `params', should contain\n", + "\t * the following values:\n", + "\t *\n", + "\t * - baseTextLines: the array of strings that was used as the base text input to SequenceMatcher\n", + "\t * - newTextLines: the array of strings that was used as the new text input to SequenceMatcher\n", + "\t * - opcodes: the array of arrays returned by SequenceMatcher.get_opcodes()\n", + "\t * - baseTextName: the title to be displayed above the base text listing in the diff view; defaults\n", + "\t *\t to \"Base Text\"\n", + "\t * - newTextName: the title to be displayed above the new text listing in the diff view; defaults\n", + "\t *\t to \"New Text\"\n", + "\t * - contextSize: the number of lines of context to show around differences; by default, all lines\n", + "\t *\t are shown\n", + "\t * - viewType: if 0, a side-by-side diff view is generated (default); if 1, an inline diff view is\n", + "\t *\t generated\n", + "\t */\n", + "\tbuildView: function (params) {\n", + "\t\tvar baseTextLines = params.baseTextLines;\n", + "\t\tvar newTextLines = params.newTextLines;\n", + "\t\tvar opcodes = params.opcodes;\n", + "\t\tvar baseTextName = params.baseTextName ? params.baseTextName : \"Base Text\";\n", + "\t\tvar newTextName = params.newTextName ? params.newTextName : \"New Text\";\n", + "\t\tvar contextSize = params.contextSize;\n", + "\t\tvar inline = (params.viewType == 0 || params.viewType == 1) ? 
params.viewType : 0;\n", + "\n", + "\t\tif (baseTextLines == null)\n", + "\t\t\tthrow \"Cannot build diff view; baseTextLines is not defined.\";\n", + "\t\tif (newTextLines == null)\n", + "\t\t\tthrow \"Cannot build diff view; newTextLines is not defined.\";\n", + "\t\tif (!opcodes)\n", + "\t\t\tthrow \"Cannot build diff view; opcodes is not defined.\";\n", + "\t\t\n", + "\t\tfunction celt (name, clazz) {\n", + "\t\t\tvar e = document.createElement(name);\n", + "\t\t\te.className = clazz;\n", + "\t\t\treturn e;\n", + "\t\t}\n", + "\t\t\n", + "\t\tfunction telt (name, text) {\n", + "\t\t\tvar e = document.createElement(name);\n", + "\t\t\te.appendChild(document.createTextNode(text));\n", + "\t\t\treturn e;\n", + "\t\t}\n", + "\t\t\n", + "\t\tfunction ctelt (name, clazz, text) {\n", + "\t\t\tvar e = document.createElement(name);\n", + "\t\t\te.className = clazz;\n", + "\t\t\te.appendChild(document.createTextNode(text));\n", + "\t\t\treturn e;\n", + "\t\t}\n", + "\t\n", + "\t\tvar tdata = document.createElement(\"thead\");\n", + "\t\tvar node = document.createElement(\"tr\");\n", + "\t\ttdata.appendChild(node);\n", + "\t\tif (inline) {\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", baseTextName + \" vs. 
\" + newTextName));\n", + "\t\t} else {\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", baseTextName));\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", newTextName));\n", + "\t\t}\n", + "\t\ttdata = [tdata];\n", + "\t\t\n", + "\t\tvar rows = [];\n", + "\t\tvar node2;\n", + "\t\t\n", + "\t\t/**\n", + "\t\t * Adds two cells to the given row; if the given row corresponds to a real\n", + "\t\t * line number (based on the line index tidx and the endpoint of the \n", + "\t\t * range in question tend), then the cells will contain the line number\n", + "\t\t * and the line of text from textLines at position tidx (with the class of\n", + "\t\t * the second cell set to the name of the change represented), and tidx + 1 will\n", + "\t\t * be returned.\t Otherwise, tidx is returned, and two empty cells are added\n", + "\t\t * to the given row.\n", + "\t\t */\n", + "\t\tfunction addCells (row, tidx, tend, textLines, change) {\n", + "\t\t\tif (tidx < tend) {\n", + "\t\t\t\trow.appendChild(telt(\"th\", (tidx + 1).toString()));\n", + "\t\t\t\trow.appendChild(ctelt(\"td\", change, textLines[tidx].replace(/\\t/g, \"\\u00a0\\u00a0\\u00a0\\u00a0\")));\n", + "\t\t\t\treturn tidx + 1;\n", + "\t\t\t} else {\n", + "\t\t\t\trow.appendChild(document.createElement(\"th\"));\n", + "\t\t\t\trow.appendChild(celt(\"td\", \"empty\"));\n", + "\t\t\t\treturn tidx;\n", + "\t\t\t}\n", + "\t\t}\n", + "\t\t\n", + "\t\tfunction addCellsInline (row, tidx, tidx2, textLines, change) {\n", + "\t\t\trow.appendChild(telt(\"th\", tidx == null ? \"\" : (tidx + 1).toString()));\n", + "\t\t\trow.appendChild(telt(\"th\", tidx2 == null ? \"\" : (tidx2 + 1).toString()));\n", + "\t\t\trow.appendChild(ctelt(\"td\", change, textLines[tidx != null ? 
tidx : tidx2].replace(/\\t/g, \"\\u00a0\\u00a0\\u00a0\\u00a0\")));\n", + "\t\t}\n", + "\t\t\n", + "\t\tfor (var idx = 0; idx < opcodes.length; idx++) {\n", + "\t\t\tvar code = opcodes[idx];\n", + "\t\t\tvar change = code[0];\n", + "\t\t\tvar b = code[1];\n", + "\t\t\tvar be = code[2];\n", + "\t\t\tvar n = code[3];\n", + "\t\t\tvar ne = code[4];\n", + "\t\t\tvar rowcnt = Math.max(be - b, ne - n);\n", + "\t\t\tvar toprows = [];\n", + "\t\t\tvar botrows = [];\n", + "\t\t\tfor (var i = 0; i < rowcnt; i++) {\n", + "\t\t\t\t// jump ahead if we've alredy provided leading context or if this is the first range\n", + "\t\t\t\tif (contextSize && opcodes.length > 1 && ((idx > 0 && i == contextSize) || (idx == 0 && i == 0)) && change==\"equal\") {\n", + "\t\t\t\t\tvar jump = rowcnt - ((idx == 0 ? 1 : 2) * contextSize);\n", + "\t\t\t\t\tif (jump > 1) {\n", + "\t\t\t\t\t\ttoprows.push(node = document.createElement(\"tr\"));\n", + "\t\t\t\t\t\t\n", + "\t\t\t\t\t\tb += jump;\n", + "\t\t\t\t\t\tn += jump;\n", + "\t\t\t\t\t\ti += jump - 1;\n", + "\t\t\t\t\t\tnode.appendChild(telt(\"th\", \"...\"));\n", + "\t\t\t\t\t\tif (!inline) node.appendChild(ctelt(\"td\", \"skip\", \"\"));\n", + "\t\t\t\t\t\tnode.appendChild(telt(\"th\", \"...\"));\n", + "\t\t\t\t\t\tnode.appendChild(ctelt(\"td\", \"skip\", \"\"));\n", + "\t\t\t\t\t\t\n", + "\t\t\t\t\t\t// skip last lines if they're all equal\n", + "\t\t\t\t\t\tif (idx + 1 == opcodes.length) {\n", + "\t\t\t\t\t\t\tbreak;\n", + "\t\t\t\t\t\t} else {\n", + "\t\t\t\t\t\t\tcontinue;\n", + "\t\t\t\t\t\t}\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\t\n", + "\t\t\t\ttoprows.push(node = document.createElement(\"tr\"));\n", + "\t\t\t\tif (inline) {\n", + "\t\t\t\t\tif (change == \"insert\") {\n", + "\t\t\t\t\t\taddCellsInline(node, null, n++, newTextLines, change);\n", + "\t\t\t\t\t} else if (change == \"replace\") {\n", + "\t\t\t\t\t\tbotrows.push(node2 = document.createElement(\"tr\"));\n", + "\t\t\t\t\t\tif (b < be) addCellsInline(node, b++, 
null, baseTextLines, \"delete\");\n", + "\t\t\t\t\t\tif (n < ne) addCellsInline(node2, null, n++, newTextLines, \"insert\");\n", + "\t\t\t\t\t} else if (change == \"delete\") {\n", + "\t\t\t\t\t\taddCellsInline(node, b++, null, baseTextLines, change);\n", + "\t\t\t\t\t} else {\n", + "\t\t\t\t\t\t// equal\n", + "\t\t\t\t\t\taddCellsInline(node, b++, n++, baseTextLines, change);\n", + "\t\t\t\t\t}\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tb = addCells(node, b, be, baseTextLines, change);\n", + "\t\t\t\t\tn = addCells(node, n, ne, newTextLines, change);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\n", + "\t\t\tfor (var i = 0; i < toprows.length; i++) rows.push(toprows[i]);\n", + "\t\t\tfor (var i = 0; i < botrows.length; i++) rows.push(botrows[i]);\n", + "\t\t}\n", + "\t\t\n", + "\t\trows.push(node = ctelt(\"th\", \"author\", \"diff view generated by \"));\n", + "\t\tnode.setAttribute(\"colspan\", inline ? 3 : 4);\n", + "\t\tnode.appendChild(node2 = telt(\"a\", \"jsdifflib\"));\n", + "\t\tnode2.setAttribute(\"href\", \"http://github.com/cemerick/jsdifflib\");\n", + "\t\t\n", + "\t\ttdata.push(node = document.createElement(\"tbody\"));\n", + "\t\tfor (var idx in rows) rows.hasOwnProperty(idx) && node.appendChild(rows[idx]);\n", + "\t\t\n", + "\t\tnode = celt(\"table\", \"diff\" + (inline ? \" inlinediff\" : \"\"));\n", + "\t\tfor (var idx in tdata) tdata.hasOwnProperty(idx) && node.appendChild(tdata[idx]);\n", + "\t\treturn node;\n", + "\t}\n", + "};\n", + "\n", + "\n", + "/***\n", + "This is part of jsdifflib v1.0. 
\n", + "\n", + "Copyright (c) 2007, Snowtide Informatics Systems, Inc.\n", + "All rights reserved.\n", + "\n", + "Redistribution and use in source and binary forms, with or without modification,\n", + "are permitted provided that the following conditions are met:\n", + "\n", + "\t* Redistributions of source code must retain the above copyright notice, this\n", + "\t\tlist of conditions and the following disclaimer.\n", + "\t* Redistributions in binary form must reproduce the above copyright notice,\n", + "\t\tthis list of conditions and the following disclaimer in the documentation\n", + "\t\tand/or other materials provided with the distribution.\n", + "\t* Neither the name of the Snowtide Informatics Systems nor the names of its\n", + "\t\tcontributors may be used to endorse or promote products derived from this\n", + "\t\tsoftware without specific prior written permission.\n", + "\n", + "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY\n", + "EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES\n", + "OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT\n", + "SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,\n", + "INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED\n", + "TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR\n", + "BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n", + "CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN\n", + "ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH\n", + "DAMAGE.\n", + "***/\n", + "/* Author: Chas Emerick */\n", + "var __whitespace = {\" \":true, \"\\t\":true, \"\\n\":true, \"\\f\":true, \"\\r\":true};\n", + "\n", + "var difflib = {\n", + "\tdefaultJunkFunction: function (c) {\n", + "\t\treturn __whitespace.hasOwnProperty(c);\n", + "\t},\n", + "\t\n", + "\tstripLinebreaks: function (str) { return str.replace(/^[\\n\\r]*|[\\n\\r]*$/g, \"\"); },\n", + "\t\n", + "\tstringAsLines: function (str) {\n", + "\t\tvar lfpos = str.indexOf(\"\\n\");\n", + "\t\tvar crpos = str.indexOf(\"\\r\");\n", + "\t\tvar linebreak = ((lfpos > -1 && crpos > -1) || crpos < 0) ? 
\"\\n\" : \"\\r\";\n", + "\t\t\n", + "\t\tvar lines = str.split(linebreak);\n", + "\t\tfor (var i = 0; i < lines.length; i++) {\n", + "\t\t\tlines[i] = difflib.stripLinebreaks(lines[i]);\n", + "\t\t}\n", + "\t\t\n", + "\t\treturn lines;\n", + "\t},\n", + "\t\n", + "\t// iteration-based reduce implementation\n", + "\t__reduce: function (func, list, initial) {\n", + "\t\tif (initial != null) {\n", + "\t\t\tvar value = initial;\n", + "\t\t\tvar idx = 0;\n", + "\t\t} else if (list) {\n", + "\t\t\tvar value = list[0];\n", + "\t\t\tvar idx = 1;\n", + "\t\t} else {\n", + "\t\t\treturn null;\n", + "\t\t}\n", + "\t\t\n", + "\t\tfor (; idx < list.length; idx++) {\n", + "\t\t\tvalue = func(value, list[idx]);\n", + "\t\t}\n", + "\t\t\n", + "\t\treturn value;\n", + "\t},\n", + "\t\n", + "\t// comparison function for sorting lists of numeric tuples\n", + "\t__ntuplecomp: function (a, b) {\n", + "\t\tvar mlen = Math.max(a.length, b.length);\n", + "\t\tfor (var i = 0; i < mlen; i++) {\n", + "\t\t\tif (a[i] < b[i]) return -1;\n", + "\t\t\tif (a[i] > b[i]) return 1;\n", + "\t\t}\n", + "\t\t\n", + "\t\treturn a.length == b.length ? 0 : (a.length < b.length ? -1 : 1);\n", + "\t},\n", + "\t\n", + "\t__calculate_ratio: function (matches, length) {\n", + "\t\treturn length ? 2.0 * matches / length : 1.0;\n", + "\t},\n", + "\t\n", + "\t// returns a function that returns true if a key passed to the returned function\n", + "\t// is in the dict (js object) provided to this function; replaces being able to\n", + "\t// carry around dict.has_key in python...\n", + "\t__isindict: function (dict) {\n", + "\t\treturn function (key) { return dict.hasOwnProperty(key); };\n", + "\t},\n", + "\t\n", + "\t// replacement for python's dict.get function -- need easy default values\n", + "\t__dictget: function (dict, key, defaultValue) {\n", + "\t\treturn dict.hasOwnProperty(key) ? 
dict[key] : defaultValue;\n", + "\t},\t\n", + "\t\n", + "\tSequenceMatcher: function (a, b, isjunk) {\n", + "\t\tthis.set_seqs = function (a, b) {\n", + "\t\t\tthis.set_seq1(a);\n", + "\t\t\tthis.set_seq2(b);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.set_seq1 = function (a) {\n", + "\t\t\tif (a == this.a) return;\n", + "\t\t\tthis.a = a;\n", + "\t\t\tthis.matching_blocks = this.opcodes = null;\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.set_seq2 = function (b) {\n", + "\t\t\tif (b == this.b) return;\n", + "\t\t\tthis.b = b;\n", + "\t\t\tthis.matching_blocks = this.opcodes = this.fullbcount = null;\n", + "\t\t\tthis.__chain_b();\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.__chain_b = function () {\n", + "\t\t\tvar b = this.b;\n", + "\t\t\tvar n = b.length;\n", + "\t\t\tvar b2j = this.b2j = {};\n", + "\t\t\tvar populardict = {};\n", + "\t\t\tfor (var i = 0; i < b.length; i++) {\n", + "\t\t\t\tvar elt = b[i];\n", + "\t\t\t\tif (b2j.hasOwnProperty(elt)) {\n", + "\t\t\t\t\tvar indices = b2j[elt];\n", + "\t\t\t\t\tif (n >= 200 && indices.length * 100 > n) {\n", + "\t\t\t\t\t\tpopulardict[elt] = 1;\n", + "\t\t\t\t\t\tdelete b2j[elt];\n", + "\t\t\t\t\t} else {\n", + "\t\t\t\t\t\tindices.push(i);\n", + "\t\t\t\t\t}\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tb2j[elt] = [i];\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\n", + "\t\t\tfor (var elt in populardict) {\n", + "\t\t\t\tif (populardict.hasOwnProperty(elt)) {\n", + "\t\t\t\t\tdelete b2j[elt];\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tvar isjunk = this.isjunk;\n", + "\t\t\tvar junkdict = {};\n", + "\t\t\tif (isjunk) {\n", + "\t\t\t\tfor (var elt in populardict) {\n", + "\t\t\t\t\tif (populardict.hasOwnProperty(elt) && isjunk(elt)) {\n", + "\t\t\t\t\t\tjunkdict[elt] = 1;\n", + "\t\t\t\t\t\tdelete populardict[elt];\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\tfor (var elt in b2j) {\n", + "\t\t\t\t\tif (b2j.hasOwnProperty(elt) && isjunk(elt)) {\n", + "\t\t\t\t\t\tjunkdict[elt] = 1;\n", + "\t\t\t\t\t\tdelete 
b2j[elt];\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\n", + "\t\t\tthis.isbjunk = difflib.__isindict(junkdict);\n", + "\t\t\tthis.isbpopular = difflib.__isindict(populardict);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.find_longest_match = function (alo, ahi, blo, bhi) {\n", + "\t\t\tvar a = this.a;\n", + "\t\t\tvar b = this.b;\n", + "\t\t\tvar b2j = this.b2j;\n", + "\t\t\tvar isbjunk = this.isbjunk;\n", + "\t\t\tvar besti = alo;\n", + "\t\t\tvar bestj = blo;\n", + "\t\t\tvar bestsize = 0;\n", + "\t\t\tvar j = null;\n", + "\t\t\tvar k;\n", + "\t\n", + "\t\t\tvar j2len = {};\n", + "\t\t\tvar nothing = [];\n", + "\t\t\tfor (var i = alo; i < ahi; i++) {\n", + "\t\t\t\tvar newj2len = {};\n", + "\t\t\t\tvar jdict = difflib.__dictget(b2j, a[i], nothing);\n", + "\t\t\t\tfor (var jkey in jdict) {\n", + "\t\t\t\t\tif (jdict.hasOwnProperty(jkey)) {\n", + "\t\t\t\t\t\tj = jdict[jkey];\n", + "\t\t\t\t\t\tif (j < blo) continue;\n", + "\t\t\t\t\t\tif (j >= bhi) break;\n", + "\t\t\t\t\t\tnewj2len[j] = k = difflib.__dictget(j2len, j - 1, 0) + 1;\n", + "\t\t\t\t\t\tif (k > bestsize) {\n", + "\t\t\t\t\t\t\tbesti = i - k + 1;\n", + "\t\t\t\t\t\t\tbestj = j - k + 1;\n", + "\t\t\t\t\t\t\tbestsize = k;\n", + "\t\t\t\t\t\t}\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\tj2len = newj2len;\n", + "\t\t\t}\n", + "\t\n", + "\t\t\twhile (besti > alo && bestj > blo && !isbjunk(b[bestj - 1]) && a[besti - 1] == b[bestj - 1]) {\n", + "\t\t\t\tbesti--;\n", + "\t\t\t\tbestj--;\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\t\t\t\n", + "\t\t\twhile (besti + bestsize < ahi && bestj + bestsize < bhi &&\n", + "\t\t\t\t\t!isbjunk(b[bestj + bestsize]) &&\n", + "\t\t\t\t\ta[besti + bestsize] == b[bestj + bestsize]) {\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\n", + "\t\t\twhile (besti > alo && bestj > blo && isbjunk(b[bestj - 1]) && a[besti - 1] == b[bestj - 1]) {\n", + "\t\t\t\tbesti--;\n", + "\t\t\t\tbestj--;\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\t\t\n", + 
"\t\t\twhile (besti + bestsize < ahi && bestj + bestsize < bhi && isbjunk(b[bestj + bestsize]) &&\n", + "\t\t\t\t\ta[besti + bestsize] == b[bestj + bestsize]) {\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\n", + "\t\t\treturn [besti, bestj, bestsize];\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.get_matching_blocks = function () {\n", + "\t\t\tif (this.matching_blocks != null) return this.matching_blocks;\n", + "\t\t\tvar la = this.a.length;\n", + "\t\t\tvar lb = this.b.length;\n", + "\t\n", + "\t\t\tvar queue = [[0, la, 0, lb]];\n", + "\t\t\tvar matching_blocks = [];\n", + "\t\t\tvar alo, ahi, blo, bhi, qi, i, j, k, x;\n", + "\t\t\twhile (queue.length) {\n", + "\t\t\t\tqi = queue.pop();\n", + "\t\t\t\talo = qi[0];\n", + "\t\t\t\tahi = qi[1];\n", + "\t\t\t\tblo = qi[2];\n", + "\t\t\t\tbhi = qi[3];\n", + "\t\t\t\tx = this.find_longest_match(alo, ahi, blo, bhi);\n", + "\t\t\t\ti = x[0];\n", + "\t\t\t\tj = x[1];\n", + "\t\t\t\tk = x[2];\n", + "\t\n", + "\t\t\t\tif (k) {\n", + "\t\t\t\t\tmatching_blocks.push(x);\n", + "\t\t\t\t\tif (alo < i && blo < j)\n", + "\t\t\t\t\t\tqueue.push([alo, i, blo, j]);\n", + "\t\t\t\t\tif (i+k < ahi && j+k < bhi)\n", + "\t\t\t\t\t\tqueue.push([i + k, ahi, j + k, bhi]);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tmatching_blocks.sort(difflib.__ntuplecomp);\n", + "\t\n", + "\t\t\tvar i1 = 0, j1 = 0, k1 = 0, block = 0;\n", + "\t\t\tvar i2, j2, k2;\n", + "\t\t\tvar non_adjacent = [];\n", + "\t\t\tfor (var idx in matching_blocks) {\n", + "\t\t\t\tif (matching_blocks.hasOwnProperty(idx)) {\n", + "\t\t\t\t\tblock = matching_blocks[idx];\n", + "\t\t\t\t\ti2 = block[0];\n", + "\t\t\t\t\tj2 = block[1];\n", + "\t\t\t\t\tk2 = block[2];\n", + "\t\t\t\t\tif (i1 + k1 == i2 && j1 + k1 == j2) {\n", + "\t\t\t\t\t\tk1 += k2;\n", + "\t\t\t\t\t} else {\n", + "\t\t\t\t\t\tif (k1) non_adjacent.push([i1, j1, k1]);\n", + "\t\t\t\t\t\ti1 = i2;\n", + "\t\t\t\t\t\tj1 = j2;\n", + "\t\t\t\t\t\tk1 = k2;\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + 
"\t\t\t}\n", + "\t\t\t\n", + "\t\t\tif (k1) non_adjacent.push([i1, j1, k1]);\n", + "\t\n", + "\t\t\tnon_adjacent.push([la, lb, 0]);\n", + "\t\t\tthis.matching_blocks = non_adjacent;\n", + "\t\t\treturn this.matching_blocks;\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.get_opcodes = function () {\n", + "\t\t\tif (this.opcodes != null) return this.opcodes;\n", + "\t\t\tvar i = 0;\n", + "\t\t\tvar j = 0;\n", + "\t\t\tvar answer = [];\n", + "\t\t\tthis.opcodes = answer;\n", + "\t\t\tvar block, ai, bj, size, tag;\n", + "\t\t\tvar blocks = this.get_matching_blocks();\n", + "\t\t\tfor (var idx in blocks) {\n", + "\t\t\t\tif (blocks.hasOwnProperty(idx)) {\n", + "\t\t\t\t\tblock = blocks[idx];\n", + "\t\t\t\t\tai = block[0];\n", + "\t\t\t\t\tbj = block[1];\n", + "\t\t\t\t\tsize = block[2];\n", + "\t\t\t\t\ttag = '';\n", + "\t\t\t\t\tif (i < ai && j < bj) {\n", + "\t\t\t\t\t\ttag = 'replace';\n", + "\t\t\t\t\t} else if (i < ai) {\n", + "\t\t\t\t\t\ttag = 'delete';\n", + "\t\t\t\t\t} else if (j < bj) {\n", + "\t\t\t\t\t\ttag = 'insert';\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\tif (tag) answer.push([tag, i, ai, j, bj]);\n", + "\t\t\t\t\ti = ai + size;\n", + "\t\t\t\t\tj = bj + size;\n", + "\t\t\t\t\t\n", + "\t\t\t\t\tif (size) answer.push(['equal', ai, i, bj, j]);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\treturn answer;\n", + "\t\t}\n", + "\t\t\n", + "\t\t// this is a generator function in the python lib, which of course is not supported in javascript\n", + "\t\t// the reimplementation builds up the grouped opcodes into a list in their entirety and returns that.\n", + "\t\tthis.get_grouped_opcodes = function (n) {\n", + "\t\t\tif (!n) n = 3;\n", + "\t\t\tvar codes = this.get_opcodes();\n", + "\t\t\tif (!codes) codes = [[\"equal\", 0, 1, 0, 1]];\n", + "\t\t\tvar code, tag, i1, i2, j1, j2;\n", + "\t\t\tif (codes[0][0] == 'equal') {\n", + "\t\t\t\tcode = codes[0];\n", + "\t\t\t\ttag = code[0];\n", + "\t\t\t\ti1 = code[1];\n", + "\t\t\t\ti2 = code[2];\n", + "\t\t\t\tj1 = 
code[3];\n", + "\t\t\t\tj2 = code[4];\n", + "\t\t\t\tcodes[0] = [tag, Math.max(i1, i2 - n), i2, Math.max(j1, j2 - n), j2];\n", + "\t\t\t}\n", + "\t\t\tif (codes[codes.length - 1][0] == 'equal') {\n", + "\t\t\t\tcode = codes[codes.length - 1];\n", + "\t\t\t\ttag = code[0];\n", + "\t\t\t\ti1 = code[1];\n", + "\t\t\t\ti2 = code[2];\n", + "\t\t\t\tj1 = code[3];\n", + "\t\t\t\tj2 = code[4];\n", + "\t\t\t\tcodes[codes.length - 1] = [tag, i1, Math.min(i2, i1 + n), j1, Math.min(j2, j1 + n)];\n", + "\t\t\t}\n", + "\t\n", + "\t\t\tvar nn = n + n;\n", + "\t\t\tvar group = [];\n", + "\t\t\tvar groups = [];\n", + "\t\t\tfor (var idx in codes) {\n", + "\t\t\t\tif (codes.hasOwnProperty(idx)) {\n", + "\t\t\t\t\tcode = codes[idx];\n", + "\t\t\t\t\ttag = code[0];\n", + "\t\t\t\t\ti1 = code[1];\n", + "\t\t\t\t\ti2 = code[2];\n", + "\t\t\t\t\tj1 = code[3];\n", + "\t\t\t\t\tj2 = code[4];\n", + "\t\t\t\t\tif (tag == 'equal' && i2 - i1 > nn) {\n", + "\t\t\t\t\t\tgroup.push([tag, i1, Math.min(i2, i1 + n), j1, Math.min(j2, j1 + n)]);\n", + "\t\t\t\t\t\tgroups.push(group);\n", + "\t\t\t\t\t\tgroup = [];\n", + "\t\t\t\t\t\ti1 = Math.max(i1, i2-n);\n", + "\t\t\t\t\t\tj1 = Math.max(j1, j2-n);\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\t\n", + "\t\t\t\t\tgroup.push([tag, i1, i2, j1, j2]);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tif (group && !(group.length == 1 && group[0][0] == 'equal')) groups.push(group)\n", + "\t\t\t\n", + "\t\t\treturn groups;\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.ratio = function () {\n", + "\t\t\tmatches = difflib.__reduce(\n", + "\t\t\t\t\t\t\tfunction (sum, triple) { return sum + triple[triple.length - 1]; },\n", + "\t\t\t\t\t\t\tthis.get_matching_blocks(), 0);\n", + "\t\t\treturn difflib.__calculate_ratio(matches, this.a.length + this.b.length);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.quick_ratio = function () {\n", + "\t\t\tvar fullbcount, elt;\n", + "\t\t\tif (this.fullbcount == null) {\n", + "\t\t\t\tthis.fullbcount = fullbcount = {};\n", + 
"\t\t\t\tfor (var i = 0; i < this.b.length; i++) {\n", + "\t\t\t\t\telt = this.b[i];\n", + "\t\t\t\t\tfullbcount[elt] = difflib.__dictget(fullbcount, elt, 0) + 1;\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\tfullbcount = this.fullbcount;\n", + "\t\n", + "\t\t\tvar avail = {};\n", + "\t\t\tvar availhas = difflib.__isindict(avail);\n", + "\t\t\tvar matches = numb = 0;\n", + "\t\t\tfor (var i = 0; i < this.a.length; i++) {\n", + "\t\t\t\telt = this.a[i];\n", + "\t\t\t\tif (availhas(elt)) {\n", + "\t\t\t\t\tnumb = avail[elt];\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tnumb = difflib.__dictget(fullbcount, elt, 0);\n", + "\t\t\t\t}\n", + "\t\t\t\tavail[elt] = numb - 1;\n", + "\t\t\t\tif (numb > 0) matches++;\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\treturn difflib.__calculate_ratio(matches, this.a.length + this.b.length);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.real_quick_ratio = function () {\n", + "\t\t\tvar la = this.a.length;\n", + "\t\t\tvar lb = this.b.length;\n", + "\t\t\treturn _calculate_ratio(Math.min(la, lb), la + lb);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.isjunk = isjunk ? 
isjunk : difflib.defaultJunkFunction;\n", + "\t\tthis.a = this.b = null;\n", + "\t\tthis.set_seqs(a, b);\n", + "\t}\n", + "};\n", + "\n", + "\n", + "\n", + "function diffUsingJS (viewType, contextSize, baseText, newText) {\n", + "\n", + " var byId = function (id) { return document.getElementById(id); },\n", + " base = difflib.stringAsLines(baseText),\n", + " newtxt = difflib.stringAsLines(newText),\n", + " sm = new difflib.SequenceMatcher(base, newtxt),\n", + " opcodes = sm.get_opcodes(),\n", + " diffoutputdiv = byId(\"diffid_2021-08-05_15_22_33_799031\");\n", + "\n", + " diffoutputdiv.innerHTML = \"\";\n", + " contextSize = contextSize || null;\n", + "\n", + " diffoutputdiv.appendChild(diffview.buildView({\n", + " baseTextLines: base,\n", + " newTextLines: newtxt,\n", + " opcodes: opcodes,\n", + " baseTextName: \"Base Text\",\n", + " newTextName: \"New Text\",\n", + " contextSize: contextSize,\n", + " viewType: viewType\n", + " }));\n", + "}\n", + "var tview=0;\n", + "var csize='';\n", + "var bt = 'def fft2d(mat, fft_length):\\n mat = mat[:fft_length[0], :fft_length[1]]\\n res = mat.copy()\\n \\n # first FFT\\n res = dft_real(res, fft_length=fft_length[1], transpose=True)\\n \\n # second FFT decomposed on FFT on real part and imaginary part\\n res2_real = dft_real(res[0], fft_length=fft_length[0], transpose=False)\\n res2_imag = dft_real(res[1], fft_length=fft_length[0], transpose=False) \\n res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\\n res = res2_real + res2_imag2\\n size = fft_length[1]//2 + 1\\n return res[:, :fft_length[0], :size]\\n';\n", + "var nt = 'def fft2d_d3(mat, fft_length):\\n mat = mat[:, :fft_length[-2], :fft_length[-1]]\\n res = mat.copy()\\n \\n # first FFT\\n res = dft_real_d3(res, fft_length=fft_length[-1], transpose=True)\\n \\n # second FFT decomposed on FFT on real part and imaginary part\\n res2_real = dft_real_d3(res[0], fft_length=fft_length[-2], transpose=False)\\n res2_imag = dft_real_d3(res[1], 
fft_length=fft_length[-2], transpose=False)\\n res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\\n res = res2_real + res2_imag2\\n size = fft_length[-1]//2 + 1\\n return res[:, :, :fft_length[-2], :size]\\n';\n", + "diffUsingJS(tview, csize, bt, nt) ;\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text1 = inspect.getsource(fft2d)\n", + "text2 = inspect.getsource(fft2d_d3)\n", + "%textdiff text1 text2" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "6194fe21", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "def dft_real_d3(x, fft_length=None, transpose=True):\n", + " if len(x.shape) != 3:\n", + " raise RuntimeError(\"Not implemented for shape=%r.\" % x.shape)\n", + " N = x.shape[1]\n", + " C = x.shape[-1] if transpose else x.shape[-2]\n", + " if fft_length is None:\n", + " fft_length = x.shape[-1]\n", + " size = fft_length // 2 + 1\n", + "\n", + " cst = dft_real_cst(C, fft_length)\n", + " if transpose:\n", + " x = numpy.transpose(x, (0, 2, 1))\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:, :fft_length, :]\n", + " a = numpy.expand_dims(a, 0)\n", + " b = numpy.expand_dims(b, 1)\n", + " res = numpy.matmul(a, b)\n", + " res = res[:, :, :size, :]\n", + " return numpy.transpose(res, (1, 0, 3, 2))\n", + " else:\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:, :fft_length, :]\n", + " a = numpy.expand_dims(a, 0)\n", + " b = numpy.expand_dims(b, 1)\n", + " res = numpy.matmul(a, b)\n", + " return numpy.transpose(res, (1, 0, 2, 3))\n", + "\n", + "\n", + "def fft2d_d3(mat, fft_length):\n", + " mat = mat[:, :fft_length[-2], :fft_length[-1]]\n", + " res = mat.copy()\n", + " \n", + " # first FFT\n", + " #print(\"AAAAAA\", res.shape)\n", + " res = dft_real_d3(res, fft_length=fft_length[-1], transpose=True)\n", + " #print(\"BBBBBB\", res.shape)\n", + " \n", + " # second FFT decomposed on FFT on real part and imaginary 
part\n", + " res2_real = dft_real_d3(res[0], fft_length=fft_length[-2], transpose=False)\n", + " res2_imag = dft_real_d3(res[1], fft_length=fft_length[-2], transpose=False)\n", + " res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n", + " res = res2_real + res2_imag2\n", + " size = fft_length[-1]//2 + 1\n", + " return res[:, :, :fft_length[-2], :size]\n", + "\n", + "\n", + "def fft2d_any(mat, fft_length):\n", + " new_shape = (-1, ) + mat.shape[-2:]\n", + " mat2 = mat.reshape(new_shape)\n", + " f2 = fft2d_d3(mat2, fft_length)\n", + " new_shape = (2, ) + mat.shape[:-2] + f2.shape[-2:]\n", + " return f2.reshape(new_shape)\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "def onnx_rfft_3d_1d(x, fft_length=None, transpose=True):\n", + " if fft_length is None:\n", + " raise RuntimeError(\"fft_length must be specified.\")\n", + " \n", + " size = fft_length // 2 + 1\n", + " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", + " if transpose:\n", + " xt = npnx.transpose(x, (0, 2, 1))\n", + " a = cst[:, :, :fft_length]\n", + " b = xt[:, :fft_length, :]\n", + " a = npnx.expand_dims(a, 0)\n", + " b = npnx.expand_dims(b, 1)\n", + " res = npnx.matmul(a, b)\n", + " res2 = res[:, :size, :]\n", + " return npnx.transpose(res2, (1, 0, 3, 2))\n", + " else:\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:, :fft_length, :]\n", + " a = npnx.expand_dims(a, 0)\n", + " b = npnx.expand_dims(b, 1)\n", + " res = npnx.matmul(a, b)\n", + " return npnx.transpose(res, (1, 0, 2, 3)) \n", + " \n", + "\n", + "def onnx_rfft_3d_2d(x, fft_length=None):\n", + " mat = x[:, :fft_length[-2], :fft_length[-1]]\n", + " \n", + " # first FFT\n", + " res = onnx_rfft_3d_1d(mat, fft_length=fft_length[-1], transpose=True)\n", + " \n", + " # second FFT decomposed on FFT on real part and imaginary part\n", + " res2_real = onnx_rfft_3d_1d(res[0], fft_length=fft_length[0], transpose=False)\n", + " res2_imag = onnx_rfft_3d_1d(res[1], fft_length=fft_length[0], transpose=False) \n", + " 
res2_imag2 = npnx.vstack(-res2_imag[1:2], res2_imag[:1])\n", + " res = res2_real + res2_imag2\n", + " size = fft_length[1]//2 + 1\n", + " return res[:, :, :fft_length[-2], :size]\n", + "\n", + "\n", + "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", + "def onnx_rfft_2d_any(x, fft_length=None):\n", + " new_shape = npnx.concat(\n", + " numpy.array([-1], dtype=numpy.int64), x.shape[-2:], axis=0)\n", + " mat2 = x.reshape(new_shape)\n", + " f2 = onnx_rfft_3d_2d(mat2, fft_length)\n", + " new_shape = npnx.concat(\n", + " numpy.array([2], dtype=numpy.int64), x.shape[:-2], f2.shape[-2:])\n", + " return f2.reshape(new_shape)\n", + "\n", + "\n", + "shape = (3, 1, 4)\n", + "fft_length = (1, 4)\n", + "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + "fft2d_cus = fft2d_any(rnd, fft_length)\n", + "fft2d_onx = onnx_rfft_2d_any(rnd, fft_length=fft_length)\n", + "almost_equal(fft2d_cus, fft2d_onx)" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "85bccc7c", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "if False:\n", + " li = list(onnx_rfft_2d_any.signed_compiled)[0]\n", + " onx = onnx_rfft_2d_any.signed_compiled[li]\n", + " from mlprodict.onnxrt import OnnxInference\n", + " oinf = OnnxInference(onx.compiled.onnx_)\n", + " oinf.run({'x': rnd}, verbose=1, fLOG=print)" + ] + }, + { + "cell_type": "markdown", + "id": "b170099f", + "metadata": {}, + "source": [ + "Let's do the same comparison." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "b8ee6ac5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 2) output shape=(3, 4) or (2, 3, 1, 2)\n", + "DIS x.shape=(3, 1, 4) length=(1, 1) error=AssertionError('Mismatch max diff=1.5850220730146856e+22 > 1e-05.') output shape=(3, 4) or (2, 3, 1, 1)\n", + "OK x.shape=(5, 7) length=(5, 7) output shape=(3, 4) or (2, 5, 4)\n", + "OK x.shape=(5, 7) length=(1, 7) output shape=(3, 4) or (2, 1, 4)\n", + "OK x.shape=(5, 7) length=(2, 7) output shape=(3, 4) or (2, 2, 4)\n", + "OK x.shape=(5, 7) length=(5, 2) output shape=(3, 4) or (2, 5, 2)\n", + "OK x.shape=(5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n", + "OK x.shape=(3, 5, 7) length=(5, 7) output shape=(3, 4) or (2, 3, 5, 4)\n", + "OK x.shape=(3, 5, 7) length=(1, 7) output shape=(3, 4) or (2, 3, 1, 4)\n", + "OK x.shape=(3, 5, 7) length=(2, 7) output shape=(3, 4) or (2, 3, 2, 4)\n", + "OK x.shape=(3, 5, 7) length=(5, 2) output shape=(3, 4) or (2, 3, 5, 2)\n", + "OK x.shape=(3, 5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3, 3)\n", + "OK x.shape=(7, 5) length=(7, 5) output shape=(3, 4) or (2, 7, 3)\n", + "OK x.shape=(7, 5) length=(1, 5) output shape=(3, 4) or (2, 1, 3)\n", + "OK x.shape=(7, 5) length=(2, 5) output shape=(3, 4) or (2, 2, 3)\n", + "OK x.shape=(7, 5) length=(7, 2) output shape=(3, 4) or (2, 7, 2)\n", + "OK x.shape=(7, 5) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n" + ] + } + ], + "source": [ + "for shape in [(3, 1, 4), (5, 7), (3, 5, 7), (7, 5)]:\n", + " for fft_length in [shape[-2:], (1, shape[-1]),\n", + " (min(2, shape[-2]), shape[-1]),\n", + " (shape[-2], 2),\n", + " (min(3, shape[-2]), min(4, 
shape[-2]))]:\n", + " x = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + " if len(fnp.shape) == 2:\n", + " fn= numpy.expand_dims(fnp, 0)\n", + " try:\n", + " cus = fft2d_any(x, fft_length)\n", + " except IndexError as e:\n", + " print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", + " continue\n", + " try:\n", + " onx = onnx_rfft_2d_any(x, fft_length=fft_length)\n", + " except IndexError as e:\n", + " print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", + " continue\n", + " try:\n", + " almost_equal(onx, cus)\n", + " except (AssertionError, IndexError) as e:\n", + " print(\"DIS x.shape=%r length=%r error=%r output shape=%r or %r\" % (\n", + " x.shape, fft_length, e, fnp.shape, cus.shape))\n", + " continue\n", + " print(\"OK x.shape=%r length=%r output shape=%r or %r\" % (\n", + " x.shape, fft_length, fnp.shape, cus.shape))" + ] + }, + { + "cell_type": "markdown", + "id": "b6f689c6", + "metadata": {}, + "source": [ + "There is one issue with ``fft_length=(1, 1)`` but that case is out of scope." + ] + }, + { + "cell_type": "markdown", + "id": "6b92f755", + "metadata": {}, + "source": [ + "### ONNX graph" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "id": "847256de", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 67, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "key = list(onnx_rfft_2d_any.signed_compiled)[0]\n", + "%onnxview onnx_rfft_2d_any.signed_compiled[key].compiled.onnx_" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "id": "dda43539", + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"fft2d_any.onnx\", \"wb\") as f:\n", + " key = list(onnx_rfft_2d_any.signed_compiled)[0]\n", + " f.write(onnx_rfft_2d_any.signed_compiled[key].compiled.onnx_.SerializeToString())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "784c8f6e", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/_doc/sphinxdoc/source/blog/2021/2021-05-05_numpyapionnx1.rst b/_doc/sphinxdoc/source/blog/2021/2021-05-05_numpyapionnx1.rst index b188769df..90df2235b 100644 --- a/_doc/sphinxdoc/source/blog/2021/2021-05-05_numpyapionnx1.rst +++ b/_doc/sphinxdoc/source/blog/2021/2021-05-05_numpyapionnx1.rst @@ -136,7 +136,7 @@ pipe.fit(X_train, y_train) print(pipe.predict_proba(X_test[:2])) - onx = to_onnx(pipe, X_train[:1], rewrite_ops=True, + onx = to_onnx(pipe, X_train[:1], options={LogisticRegression: {'zipmap': False}}) oinf = OnnxInference(onx) print(oinf.run({'X': X_test[:2]})['probabilities']) diff --git a/_unittests/ut_npy/test_onnx_variable.py b/_unittests/ut_npy/test_onnx_variable.py index aa585af26..852f518ee 100644 --- a/_unittests/ut_npy/test_onnx_variable.py +++ b/_unittests/ut_npy/test_onnx_variable.py @@ -102,6 +102,13 @@ def 
test_abs_matmul(x: NDArray[Any, numpy.float32], return nxnp.abs(x) @ x +@onnxnumpy_default +def test_abs_matmul2(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy addition" + return nxnp.matmul(nxnp.abs(x), x) + + @onnxnumpy_default def test_abs_div(x: NDArray[Any, numpy.float32], ) -> NDArray[Any, numpy.float32]: @@ -515,6 +522,11 @@ def test_py_abs_matmul(self): y = test_abs_matmul(x) self.assertEqualArray(y, numpy.abs(x) @ x) + def test_py_abs_matmul2(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_matmul2(x) + self.assertEqualArray(y, numpy.abs(x) @ x) + def test_py_abs_div(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_div(x) diff --git a/mlprodict/npy/numpy_onnx_impl.py b/mlprodict/npy/numpy_onnx_impl.py index 7fba336e2..10a85ede2 100644 --- a/mlprodict/npy/numpy_onnx_impl.py +++ b/mlprodict/npy/numpy_onnx_impl.py @@ -3,7 +3,10 @@ @brief :epkg:`numpy` functions implemented with :epkg:`onnx`. .. versionadded:: 0.6 + +.. versionchanged:: 0.7 """ +import warnings import numpy from onnx import onnx_pb as onnx_proto # pylint: disable=E1101 from onnx.helper import make_tensor @@ -221,7 +224,15 @@ def det(x): def dot(a, b): - "See :epkg:`numpy:dot`." + "See :epkg:`numpy:dot`" + warnings.warn( + "npnx.dot is equivalent to npnx.matmul == numpy.matmul. " + "It only works in 2D.") + return OnnxVar(a, b, op=OnnxMatMul) + + +def matmul(a, b): + "See :epkg:`numpy:matmul`." 
return OnnxVar(a, b, op=OnnxMatMul) diff --git a/mlprodict/npy/numpy_onnx_pyrt.py b/mlprodict/npy/numpy_onnx_pyrt.py index 174bbb44b..723b20dfa 100644 --- a/mlprodict/npy/numpy_onnx_pyrt.py +++ b/mlprodict/npy/numpy_onnx_pyrt.py @@ -42,6 +42,7 @@ hstack as nx_hstack, isnan as nx_isnan, log as nx_log, + matmul as nx_matmul, mean as nx_mean, pad as nx_pad, prod as nx_prod, @@ -247,6 +248,12 @@ def log(x): return nx_log(x) +@onnxnumpy_np(signature=NDArrayType(("T:all", "T"))) +def matmul(a, b): + "matmul" + return nx_matmul(a, b) + + @onnxnumpy_np(signature=NDArrayType(("T:all", numpy.int64, 'T'), n_optional=1)) def pad(x, pads, constant_value=None, mode='constant'): "pad" From 06119cfe514782d3948511887a52ef6769a5f452 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Fri, 6 Aug 2021 10:25:23 +0200 Subject: [PATCH 06/12] unstuck circlecli --- _unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm.py | 1 + _unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py | 5 ++++- mlprodict/npy/numpy_onnx_impl.py | 4 ++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm.py b/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm.py index 22d23b7bb..c0db8a11f 100644 --- a/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm.py +++ b/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm.py @@ -273,6 +273,7 @@ def test_onnxrt_python_lightgbm_categorical_iris_dataframe(self): values = pandas.DataFrame(got['output_probability']).values self.assertEqualArray(exp, values[:, 1], decimal=5) + @skipif_circleci('stuck') @unittest.skipIf(sys.platform == 'darwin', 'stuck') def test_lightgbm_booster_classifier(self): from lightgbm import Dataset, train as lgb_train diff --git a/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py b/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py index 3160007b2..04110dcaa 100644 --- a/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py +++ 
b/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py @@ -5,7 +5,7 @@ import unittest from logging import getLogger import numpy -from pyquickhelper.pycode import ExtTestCase +from pyquickhelper.pycode import ExtTestCase, skipif_circleci from skl2onnx.common.data_types import FloatTensorType from mlprodict.onnxrt import OnnxInference from mlprodict.onnx_conv import register_converters, to_onnx @@ -18,6 +18,7 @@ def setUp(self): logger.disabled = True register_converters() + @skipif_circleci('stuck') @unittest.skipIf(sys.platform == 'darwin', 'stuck') def test_lightgbm_regressor(self): from lightgbm import LGBMRegressor @@ -54,6 +55,7 @@ def test_lightgbm_regressor(self): print("lgb", i, rt, diff) self.assertLess(diff, 1e-3) + @skipif_circleci('stuck') @unittest.skipIf(sys.platform == 'darwin', 'stuck') def test_lightgbm_regressor_double(self): from lightgbm import LGBMRegressor @@ -90,6 +92,7 @@ def test_lightgbm_regressor_double(self): else: self.assertLess(diff, 1e-3) + @skipif_circleci('stuck') @unittest.skipIf(sys.platform == 'darwin', 'stuck') def test_xgboost_regressor(self): from xgboost import XGBRegressor diff --git a/mlprodict/npy/numpy_onnx_impl.py b/mlprodict/npy/numpy_onnx_impl.py index 10a85ede2..cab5ec8b8 100644 --- a/mlprodict/npy/numpy_onnx_impl.py +++ b/mlprodict/npy/numpy_onnx_impl.py @@ -226,8 +226,8 @@ def det(x): def dot(a, b): "See :epkg:`numpy:dot`" warnings.warn( - "npnx.dot is equivalent to npnx.matmul == numpy.matmul. 
" - "It only works in 2D.") + "npnx.dot is equivalent to npnx.matmul == numpy.matmul " + "!= numpy.dot with arrays with more than 3D dimensions.") return OnnxVar(a, b, op=OnnxMatMul) From bd08bb0d4fd9a437e7ca2e1b432fe86507cef347 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Fri, 6 Aug 2021 16:05:41 +0200 Subject: [PATCH 07/12] add function to convert onnx file into onnx code --- _doc/notebooks/onnx_fft.ipynb | 320 +++++++++--------------- _unittests/ut_tools/data/fft2d_any.onnx | Bin 0 -> 3781 bytes _unittests/ut_tools/test_export_onnx.py | 149 +++++++++++ mlprodict/onnx_tools/onnx_export.py | 266 ++++++++++++++++++++ 4 files changed, 538 insertions(+), 197 deletions(-) create mode 100644 _unittests/ut_tools/data/fft2d_any.onnx create mode 100644 _unittests/ut_tools/test_export_onnx.py create mode 100644 mlprodict/onnx_tools/onnx_export.py diff --git a/_doc/notebooks/onnx_fft.ipynb b/_doc/notebooks/onnx_fft.ipynb index da8073759..7be676d90 100644 --- a/_doc/notebooks/onnx_fft.ipynb +++ b/_doc/notebooks/onnx_fft.ipynb @@ -205,16 +205,16 @@ { "data": { "text/plain": [ - "array([[ 0.53789312+0.j , 3.79404014-2.48643182j,\n", - " -1.17830152+1.41518659j, 0.03550683-0.74360516j],\n", - " [ 2.91933875+0.j , -0.80104596+1.95022724j,\n", - " 0.04362985+0.86098245j, -0.12878266-2.96095567j],\n", - " [ 4.03702923+0.j , 4.2701964 +0.97900965j,\n", - " 3.66512519-1.69493214j, 0.46502876-2.36643456j],\n", - " [-0.99284691+0.j , 2.275587 +1.04069498j,\n", - " -2.30580317-1.25203798j, 1.99527731+1.50659889j],\n", - " [-3.28264478+0.j , 3.55459652-2.06178787j,\n", - " -0.46036977-1.45184738j, 1.30470444+0.66049446j]])" + "array([[ 1.4850522 +0.j , 0.20190521-0.68916309j,\n", + " -4.02216577+2.22221485j, 2.9797031 +5.4719969j ],\n", + " [ 4.34131445+0.j , 0.59711262-2.9628275j ,\n", + " 2.24324474+1.82517205j, 2.00603187-2.84835823j],\n", + " [-1.12793672+0.j , 1.38239012-2.21121807j,\n", + " -2.35260183+1.48559871j, -1.50040395+1.3950099j ],\n", + " 
[-2.28634521+0.j , 0.18810962-0.47532163j,\n", + " 2.02343296-0.92789895j, -0.21166875-3.03973764j],\n", + " [-0.48667854+0.j , 1.17243003-1.72963149j,\n", + " -1.73139954+1.88922393j, -0.44110992-0.81126496j]])" ] }, "execution_count": 4, @@ -323,11 +323,11 @@ { "data": { "text/plain": [ - "array([[ 2.75342363+0.j , -0.12520096-0.19458685j],\n", - " [-0.50395307+0.j , 0.49774965-0.57195525j],\n", - " [ 2.9414562 +0.j , 2.99449974-2.68322022j],\n", - " [ 0.79957253+0.j , 0.22962989+0.63038464j],\n", - " [ 1.62652442+0.j , 0.36857013-0.3835761j ]])" + "array([[ 2.02068053+0.j , -1.0523537 +4.21051588j],\n", + " [ 3.35636759+0.j , 1.32912183+0.17609187j],\n", + " [ 0.86212105+0.j , -1.73159653+0.58274578j],\n", + " [-0.93982866+0.j , 0.837072 -1.67403309j],\n", + " [ 0.72620976+0.j , -0.89599861+0.37600383j]])" ] }, "execution_count": 6, @@ -372,17 +372,17 @@ { "data": { "text/plain": [ - "array([[[ 0.53789306, 3.79404 , -1.1783015 , 0.03550687],\n", - " [ 2.9193387 , -0.80104595, 0.04362985, -0.12878267],\n", - " [ 4.037029 , 4.2701964 , 3.6651251 , 0.46502876],\n", - " [-0.99284697, 2.275587 , -2.305803 , 1.9952773 ],\n", - " [-3.2826447 , 3.5545964 , -0.4603698 , 1.3047044 ]],\n", - "\n", - " [[ 0. , -2.4864316 , 1.4151866 , -0.74360514],\n", - " [ 0. , 1.9502273 , 0.8609825 , -2.9609556 ],\n", - " [ 0. , 0.9790097 , -1.6949322 , -2.3664346 ],\n", - " [ 0. , 1.040695 , -1.2520379 , 1.5065988 ],\n", - " [ 0. , -2.061788 , -1.4518473 , 0.66049445]]],\n", + "array([[[ 1.4850521 , 0.2019052 , -4.0221653 , 2.979703 ],\n", + " [ 4.3413143 , 0.5971126 , 2.2432446 , 2.006032 ],\n", + " [-1.1279367 , 1.38239 , -2.3526018 , -1.500404 ],\n", + " [-2.2863452 , 0.18810958, 2.0234327 , -0.21166855],\n", + " [-0.48667857, 1.17243 , -1.7313995 , -0.44110993]],\n", + "\n", + " [[ 0. , -0.689163 , 2.2222147 , 5.4719973 ],\n", + " [ 0. , -2.9628277 , 1.8251722 , -2.8483584 ],\n", + " [ 0. , -2.211218 , 1.4855987 , 1.39501 ],\n", + " [ 0. 
, -0.47532162, -0.9278989 , -3.0397377 ],\n", + " [ 0. , -1.7296315 , 1.8892239 , -0.811265 ]]],\n", " dtype=float32)" ] }, @@ -440,16 +440,16 @@ { "data": { "text/html": [ - "
\n", + "
\n", "" ], "text/plain": [ - "" + "" ] }, "execution_count": 10, @@ -492,16 +492,16 @@ { "data": { "text/plain": [ - "array([[-4.00478568+0.j , -8.1697212 -0.32937158j,\n", - " 10.53796531-6.04714298j, 6.21706082+1.12281921j],\n", - " [ 9.15418879-5.84728016j, 6.2117498 +3.03786463j,\n", - " 2.96815316-2.94192092j, -0.68635391-1.52268308j],\n", - " [ 0.52234925+2.67149571j, -0.54453905-1.55520678j,\n", - " 3.10942332-1.6485952j , -4.10661833+9.33122798j],\n", - " [ 0.52234925-2.67149571j, -0.17987319+2.84609599j,\n", - " -6.48443904+3.97163974j, 8.55488898-3.49370573j],\n", - " [ 9.15418879+5.84728016j, 2.24954303+6.3142503j ,\n", - " 10.0104149 +0.18248643j, 3.74038426+7.20880644j]])" + "array([[ 6.73214482 +0.j , 3.36817961 +2.73163668j,\n", + " 4.42688318-10.23641065j, -3.27162717 -0.54910943j],\n", + " [ 0.05934113 -2.02790296j, -4.71694176 -1.80039444j,\n", + " -0.16187544 -1.27214887j, -4.76195404 -8.23146595j],\n", + " [-2.43886644 +0.67253454j, 0.62177822 +1.71628605j,\n", + " -4.22144547 -0.24384973j, -1.96253444 -2.26942153j],\n", + " [-2.43886644 -0.67253454j, -4.94210355 +1.65439295j,\n", + " -6.75624015 -2.50966739j, -1.62599543 +7.41506091j],\n", + " [ 0.05934113 +2.02790296j, 1.56068457 -4.5734695j ,\n", + " -2.9809962 +2.90470743j, 4.42498542-10.45411745j]])" ] }, "execution_count": 12, @@ -601,16 +601,16 @@ { "data": { "text/plain": [ - "array([[-4.00478568+0.j , -8.1697212 -0.32937158j,\n", - " 10.53796531-6.04714298j, 6.21706082+1.12281921j],\n", - " [ 9.15418879-5.84728016j, 6.2117498 +3.03786463j,\n", - " 2.96815316-2.94192092j, -0.68635391-1.52268308j],\n", - " [ 0.52234925+2.67149571j, -0.54453905-1.55520678j,\n", - " 3.10942332-1.6485952j , -4.10661833+9.33122798j],\n", - " [ 0.52234925-2.67149571j, -0.17987319+2.84609599j,\n", - " -6.48443904+3.97163974j, 8.55488898-3.49370573j],\n", - " [ 9.15418879+5.84728016j, 2.24954303+6.3142503j ,\n", - " 10.0104149 +0.18248643j, 3.74038426+7.20880644j]])" + "array([[ 6.73214482 +0.j , 3.36817961 
+2.73163668j,\n", + " 4.42688318-10.23641065j, -3.27162717 -0.54910943j],\n", + " [ 0.05934113 -2.02790296j, -4.71694176 -1.80039444j,\n", + " -0.16187544 -1.27214887j, -4.76195404 -8.23146595j],\n", + " [-2.43886644 +0.67253454j, 0.62177822 +1.71628605j,\n", + " -4.22144547 -0.24384973j, -1.96253444 -2.26942153j],\n", + " [-2.43886644 -0.67253454j, -4.94210355 +1.65439295j,\n", + " -6.75624015 -2.50966739j, -1.62599543 +7.41506091j],\n", + " [ 0.05934113 +2.02790296j, 1.56068457 -4.5734695j ,\n", + " -2.9809962 +2.90470743j, 4.42498542-10.45411745j]])" ] }, "execution_count": 15, @@ -631,17 +631,17 @@ { "data": { "text/plain": [ - "array([[[-4.00478568, -8.1697212 , 10.53796531, 6.21706082],\n", - " [ 9.15418879, 6.2117498 , 2.96815316, -0.68635391],\n", - " [ 0.52234925, -0.54453905, 3.10942332, -4.10661833],\n", - " [ 0.52234925, -0.17987319, -6.48443904, 8.55488898],\n", - " [ 9.15418879, 2.24954303, 10.0104149 , 3.74038426]],\n", - "\n", - " [[ 0. , -0.32937158, -6.04714298, 1.12281921],\n", - " [-5.84728016, 3.03786463, -2.94192092, -1.52268308],\n", - " [ 2.67149571, -1.55520678, -1.6485952 , 9.33122798],\n", - " [-2.67149571, 2.84609599, 3.97163974, -3.49370573],\n", - " [ 5.84728016, 6.3142503 , 0.18248643, 7.20880644]]])" + "array([[[ 6.73214482, 3.36817961, 4.42688318, -3.27162717],\n", + " [ 0.05934113, -4.71694176, -0.16187544, -4.76195404],\n", + " [ -2.43886644, 0.62177822, -4.22144547, -1.96253444],\n", + " [ -2.43886644, -4.94210355, -6.75624015, -1.62599543],\n", + " [ 0.05934113, 1.56068457, -2.9809962 , 4.42498542]],\n", + "\n", + " [[ 0. , 2.73163668, -10.23641065, -0.54910943],\n", + " [ -2.02790296, -1.80039444, -1.27214887, -8.23146595],\n", + " [ 0.67253454, 1.71628605, -0.24384973, -2.26942153],\n", + " [ -0.67253454, 1.65439295, -2.50966739, 7.41506091],\n", + " [ 2.02790296, -4.5734695 , 2.90470743, -10.45411745]]])" ] }, "execution_count": 16, @@ -744,16 +744,16 @@ { "data": { "text/html": [ - "
\n", + "
\n", "" ], "text/plain": [ - "" + "" ] }, "execution_count": 19, @@ -768,8 +768,8 @@ }, { "cell_type": "code", - "execution_count": 70, - "id": "4bae913b", + "execution_count": 20, + "id": "3034da60", "metadata": {}, "outputs": [], "source": [ @@ -788,7 +788,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 21, "id": "16732cbb", "metadata": {}, "outputs": [], @@ -808,7 +808,7 @@ }, { "cell_type": "markdown", - "id": "024db509", + "id": "c9da88a0", "metadata": {}, "source": [ "## FFT2D with shape (3,1,4)\n", @@ -818,8 +818,8 @@ }, { "cell_type": "code", - "execution_count": 21, - "id": "1e4435be", + "execution_count": 22, + "id": "66ba70ee", "metadata": {}, "outputs": [ { @@ -828,7 +828,7 @@ "(3, 1, 4)" ] }, - "execution_count": 21, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -843,24 +843,24 @@ }, { "cell_type": "code", - "execution_count": 22, - "id": "d8ca3ba4", + "execution_count": 23, + "id": "a4d123e1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array([[[ 0.70852087+0.j , -1.19732058-1.59361078j,\n", - " -2.03344773+0.j , -1.19732058+1.59361078j]],\n", + "array([[[ 0.87908971+0.j , 0.85659337+2.10206711j,\n", + " 1.5270735 +0.j , 0.85659337-2.10206711j]],\n", "\n", - " [[ 0.28815117+0.j , -1.35137615+2.02822042j,\n", - " 0.9619607 +0.j , -1.35137615-2.02822042j]],\n", + " [[-5.01959181+0.j , -0.25658643+0.62102163j,\n", + " 2.18641639+0.j , -0.25658643-0.62102163j]],\n", "\n", - " [[-2.07279903+0.j , 0.18918216-1.91049451j,\n", - " -2.78790277+0.j , 0.18918216+1.91049451j]]])" + " [[ 0.60041136+0.j , -0.04546577-1.2931717j ,\n", + " 1.19486004+0.j , -0.04546577+1.2931717j ]]])" ] }, - "execution_count": 22, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -871,8 +871,8 @@ }, { "cell_type": "code", - "execution_count": 23, - "id": "77d52f4f", + "execution_count": 24, + "id": "4b1bd05b", "metadata": {}, "outputs": [ { @@ -893,7 +893,7 @@ }, { "cell_type": 
"markdown", - "id": "504d01de", + "id": "7bd79a00", "metadata": {}, "source": [ "### numpy version\n", @@ -903,8 +903,8 @@ }, { "cell_type": "code", - "execution_count": 24, - "id": "4157ce65", + "execution_count": 25, + "id": "3b618335", "metadata": {}, "outputs": [], "source": [ @@ -918,7 +918,7 @@ }, { "cell_type": "markdown", - "id": "cdc3836a", + "id": "7c837e7a", "metadata": {}, "source": [ "It works. And now a more efficient implementation. It is better to read [matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html) description before. To summarize, a third axis is equivalent to many matrix multiplications over the last two axes, as many as the dimension of the first axis: ``matmul(A[I,J,K], B[I,K,L]) --> C[I,J,L]``. Broadcasting also works... ``matmul(A[1,J,K], B[I,K,L]) --> C[I,J,L]``." @@ -926,8 +926,8 @@ }, { "cell_type": "code", - "execution_count": 25, - "id": "bf76dd61", + "execution_count": 26, + "id": "29055cb2", "metadata": {}, "outputs": [], "source": [ @@ -993,7 +993,7 @@ }, { "cell_type": "markdown", - "id": "8ddf5e67", + "id": "0128b3f2", "metadata": {}, "source": [ "We check with more shapes to see if the implementation works for all of them." @@ -1001,8 +1001,8 @@ }, { "cell_type": "code", - "execution_count": 26, - "id": "c2a2d068", + "execution_count": 27, + "id": "82f5fc78", "metadata": {}, "outputs": [ { @@ -1059,7 +1059,7 @@ }, { "cell_type": "markdown", - "id": "01d727fa", + "id": "c5f5229a", "metadata": {}, "source": [ "### ONNX version\n", @@ -1069,8 +1069,8 @@ }, { "cell_type": "code", - "execution_count": 27, - "id": "fe63d3ed", + "execution_count": 28, + "id": "025c2d88", "metadata": {}, "outputs": [], "source": [ @@ -1079,8 +1079,8 @@ }, { "cell_type": "code", - "execution_count": 28, - "id": "650c4849", + "execution_count": 29, + "id": "82664bc5", "metadata": {}, "outputs": [ { @@ -1169,7 +1169,7 @@ "\ttext-align:right;\n", "\tborder-top:1px solid #BBC;\n", "\tbackground:#EFEFEF\n", - "}
populating...
" + "}
populating...
" ] }, "metadata": {}, @@ -1799,7 +1799,7 @@ " newtxt = difflib.stringAsLines(newText),\n", " sm = new difflib.SequenceMatcher(base, newtxt),\n", " opcodes = sm.get_opcodes(),\n", - " diffoutputdiv = byId(\"diffid_2021-08-05_15_22_33_732030\");\n", + " diffoutputdiv = byId(\"diffid_2021-08-05_16_46_43_018480\");\n", "\n", " diffoutputdiv.innerHTML = \"\";\n", " contextSize = contextSize || null;\n", @@ -1824,7 +1824,7 @@ "" ] }, - "execution_count": 28, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -1838,8 +1838,8 @@ }, { "cell_type": "code", - "execution_count": 29, - "id": "b815568f", + "execution_count": 30, + "id": "cd7e14d4", "metadata": {}, "outputs": [ { @@ -1928,7 +1928,7 @@ "\ttext-align:right;\n", "\tborder-top:1px solid #BBC;\n", "\tbackground:#EFEFEF\n", - "}
populating...
" + "}
populating...
" ] }, "metadata": {}, @@ -2558,7 +2558,7 @@ " newtxt = difflib.stringAsLines(newText),\n", " sm = new difflib.SequenceMatcher(base, newtxt),\n", " opcodes = sm.get_opcodes(),\n", - " diffoutputdiv = byId(\"diffid_2021-08-05_15_22_33_799031\");\n", + " diffoutputdiv = byId(\"diffid_2021-08-05_16_46_43_079488\");\n", "\n", " diffoutputdiv.innerHTML = \"\";\n", " contextSize = contextSize || null;\n", @@ -2583,7 +2583,7 @@ "" ] }, - "execution_count": 29, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -2596,70 +2596,13 @@ }, { "cell_type": "code", - "execution_count": 61, - "id": "6194fe21", + "execution_count": 31, + "id": "51e7a4f7", "metadata": { "scrolled": false }, "outputs": [], "source": [ - "def dft_real_d3(x, fft_length=None, transpose=True):\n", - " if len(x.shape) != 3:\n", - " raise RuntimeError(\"Not implemented for shape=%r.\" % x.shape)\n", - " N = x.shape[1]\n", - " C = x.shape[-1] if transpose else x.shape[-2]\n", - " if fft_length is None:\n", - " fft_length = x.shape[-1]\n", - " size = fft_length // 2 + 1\n", - "\n", - " cst = dft_real_cst(C, fft_length)\n", - " if transpose:\n", - " x = numpy.transpose(x, (0, 2, 1))\n", - " a = cst[:, :, :fft_length]\n", - " b = x[:, :fft_length, :]\n", - " a = numpy.expand_dims(a, 0)\n", - " b = numpy.expand_dims(b, 1)\n", - " res = numpy.matmul(a, b)\n", - " res = res[:, :, :size, :]\n", - " return numpy.transpose(res, (1, 0, 3, 2))\n", - " else:\n", - " a = cst[:, :, :fft_length]\n", - " b = x[:, :fft_length, :]\n", - " a = numpy.expand_dims(a, 0)\n", - " b = numpy.expand_dims(b, 1)\n", - " res = numpy.matmul(a, b)\n", - " return numpy.transpose(res, (1, 0, 2, 3))\n", - "\n", - "\n", - "def fft2d_d3(mat, fft_length):\n", - " mat = mat[:, :fft_length[-2], :fft_length[-1]]\n", - " res = mat.copy()\n", - " \n", - " # first FFT\n", - " #print(\"AAAAAA\", res.shape)\n", - " res = dft_real_d3(res, fft_length=fft_length[-1], transpose=True)\n", - " #print(\"BBBBBB\", res.shape)\n", - " 
\n", - " # second FFT decomposed on FFT on real part and imaginary part\n", - " res2_real = dft_real_d3(res[0], fft_length=fft_length[-2], transpose=False)\n", - " res2_imag = dft_real_d3(res[1], fft_length=fft_length[-2], transpose=False)\n", - " res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n", - " res = res2_real + res2_imag2\n", - " size = fft_length[-1]//2 + 1\n", - " return res[:, :, :fft_length[-2], :size]\n", - "\n", - "\n", - "def fft2d_any(mat, fft_length):\n", - " new_shape = (-1, ) + mat.shape[-2:]\n", - " mat2 = mat.reshape(new_shape)\n", - " f2 = fft2d_d3(mat2, fft_length)\n", - " new_shape = (2, ) + mat.shape[:-2] + f2.shape[-2:]\n", - " return f2.reshape(new_shape)\n", - "\n", - "\n", - "\n", - "\n", - "\n", "def onnx_rfft_3d_1d(x, fft_length=None, transpose=True):\n", " if fft_length is None:\n", " raise RuntimeError(\"fft_length must be specified.\")\n", @@ -2718,26 +2661,9 @@ "almost_equal(fft2d_cus, fft2d_onx)" ] }, - { - "cell_type": "code", - "execution_count": 60, - "id": "85bccc7c", - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "if False:\n", - " li = list(onnx_rfft_2d_any.signed_compiled)[0]\n", - " onx = onnx_rfft_2d_any.signed_compiled[li]\n", - " from mlprodict.onnxrt import OnnxInference\n", - " oinf = OnnxInference(onx.compiled.onnx_)\n", - " oinf.run({'x': rnd}, verbose=1, fLOG=print)" - ] - }, { "cell_type": "markdown", - "id": "b170099f", + "id": "37c45ae7", "metadata": {}, "source": [ "Let's do the same comparison." 
@@ -2745,8 +2671,8 @@ }, { "cell_type": "code", - "execution_count": 66, - "id": "b8ee6ac5", + "execution_count": 32, + "id": "11c1e596", "metadata": {}, "outputs": [ { @@ -2757,7 +2683,7 @@ "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", "OK x.shape=(3, 1, 4) length=(1, 2) output shape=(3, 4) or (2, 3, 1, 2)\n", - "DIS x.shape=(3, 1, 4) length=(1, 1) error=AssertionError('Mismatch max diff=1.5850220730146856e+22 > 1e-05.') output shape=(3, 4) or (2, 3, 1, 1)\n", + "DIS x.shape=(3, 1, 4) length=(1, 1) error=AssertionError('Mismatch max diff=1.0 > 1e-05.') output shape=(3, 4) or (2, 3, 1, 1)\n", "OK x.shape=(5, 7) length=(5, 7) output shape=(3, 4) or (2, 5, 4)\n", "OK x.shape=(5, 7) length=(1, 7) output shape=(3, 4) or (2, 1, 4)\n", "OK x.shape=(5, 7) length=(2, 7) output shape=(3, 4) or (2, 2, 4)\n", @@ -2807,7 +2733,7 @@ }, { "cell_type": "markdown", - "id": "b6f689c6", + "id": "d197467f", "metadata": {}, "source": [ "There is one issue with ``fft_length=(1, 1)`` but that case is out of scope." @@ -2815,7 +2741,7 @@ }, { "cell_type": "markdown", - "id": "6b92f755", + "id": "33b5897e", "metadata": {}, "source": [ "### ONNX graph" @@ -2823,26 +2749,26 @@ }, { "cell_type": "code", - "execution_count": 67, - "id": "847256de", + "execution_count": 33, + "id": "d45e9a99", "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", "" ], "text/plain": [ - "" + "" ] }, - "execution_count": 67, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } @@ -2854,8 +2780,8 @@ }, { "cell_type": "code", - "execution_count": 71, - "id": "dda43539", + "execution_count": 34, + "id": "2ab7a3d0", "metadata": {}, "outputs": [], "source": [ @@ -2867,7 +2793,7 @@ { "cell_type": "code", "execution_count": null, - "id": "784c8f6e", + "id": "9e5507f7", "metadata": {}, "outputs": [], "source": [] diff --git a/_unittests/ut_tools/data/fft2d_any.onnx b/_unittests/ut_tools/data/fft2d_any.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0a868c71fdb9ab4451c7cc0edd9712971901d799 GIT binary patch literal 3781 zcmbVP-EJF26y9Ciu`{-tbogJXmDmD83yQq6>z_CmbsdpHAt}_Ypnx!&cts+NZEQzM zx$+C}1|ao@=kRrV%LQ)(oY|k5IlF5jVJX>}GiPSM?>lGC%vQ>@viNb#rjyB)S#tLF z_Z+M2U+nSz`x?`_zuh9Ak)0=#;giYY=gT1YB{*9w$!^Nv&>C2RtC>F;1*3+^dt!pM zEt*}8Jj_gH8H+>f>T<5hyJ8UsFZedmujuyRe7HFGXFtifN-;dT^LRd7&i%<^HeCQ9 zX4b&)I5Ak;qG|7t;w+f|)YE!KPk*kL0p|B4Jw2;E!q-B*wu;u3XqiNNq@u-K+jzTq zI1gQ%hX*(hUEKBP*0Dc)%6A7mkNxHG<=6rfe*(=fNj!WFyZz1Vm}KWlcE>6^mX3vE z*40=*0g=JD83#j$a1?q*6poX0IGj`=AhHmcydCX-Na6fK@C{KWT*JM+nL967aUOb- zJBlm-O=B8E=?%OKu<%gQ@`X2gg-H0% zv%=SfynM^5CD>J@XPcz|R;|JA|JS}PwQp|FJ{>t!u-jPs7S{gYzqD`1n+o^}@qw=p z|KoU9x#_%zovh|P?Br|Y_tWxwIGspcNiL_8Ra31)EvZ;bDt1TRsV!PM37}k5OMk8-Qx$4`8#;su%?&p?rTCs#CA@#aTs?9EGB?EOITh7?pQfLZ`Nr2x2_ z3IMuC6wN4O~FB}-FD7QmFBET;-T7l$I7HO(B%YIvB*0};m3X~_7T zqJ)9Xs?KSk9?MC3&{ib9w~5Drngclr?@|4kS(P_4__z>PGg^HJ9~p7izEd$OdPU#W zt%|0t?rQMsr$0OFx4@y?)wO;mU+v#@SfE%b!U8KjT7j=3|CsaZJJ4Uhd@=YVwy)J^ zY+rx%*X{TC+q1vbnmP4&uBP!etoiQC7k`3{|5kUS_G+EM$L!l7bn;K~)BEV2Y~PB; z|CZzYa+z#U`;~|{+6ngS6$8u$D5akHEV1iFWOs$#Kz7f5Es_QMy^l|8cx4zdqmK^``;7>?UO)rPO9qV8K=bQu{+U-KkR_+kd=@qPMWXHA_!mdb Bd+GoH literal 0 HcmV?d00001 diff --git a/_unittests/ut_tools/test_export_onnx.py b/_unittests/ut_tools/test_export_onnx.py new file mode 100644 index 000000000..0461ca00e --- /dev/null +++ b/_unittests/ut_tools/test_export_onnx.py 
@@ -0,0 +1,149 @@ +""" +@brief test log(time=3s) +""" +import os +import unittest +from io import StringIO +from contextlib import redirect_stdout, redirect_stderr +import numpy +from onnx import numpy_helper +from onnx.helper import ( + make_model, make_node, set_model_props, make_tensor, make_graph, + make_tensor_value_info) +from pyquickhelper.pycode import ExtTestCase +from mlprodict.onnx_tools.onnx_export import export2onnx, export2tf2onnx +from mlprodict.testing.verify_code import verify_code +from mlprodict.onnxrt import OnnxInference + + +class TestExportOnnx(ExtTestCase): + + def verify(self, content, existing_loc=None): + try: + left, __ = verify_code(content, exc=False) + except SyntaxError as e: + raise AssertionError( + "Unable to analyse a script due to %r. " + "\n--CODE--\n%s" + "" % (e, content)) from e + + # execution + try: + obj = compile(content, '', 'exec') + except SyntaxError as e: + raise AssertionError( + "Unable to compile a script due to %r. " + "\n--CODE--\n%s" + "" % (e, content)) from e + glo = globals().copy() + loc = {'numpy_helper': numpy_helper, + 'make_model': make_model, + 'make_node': make_node, + 'set_model_props': set_model_props, + 'make_tensor': make_tensor, + 'make_graph': make_graph, + 'make_tensor_value_info': make_tensor_value_info} + if existing_loc is not None: + loc.update(existing_loc) + glo.update(existing_loc) + out = StringIO() + err = StringIO() + self.assertLess(len(left), 5) + + with redirect_stdout(out): + with redirect_stderr(err): + try: + exec(obj, glo, loc) # pylint: disable=W0122 + except Exception as e: + raise AssertionError( + "Unable to execute a script due to %r. 
" + "\n--OUT--\n%s\n--ERR--\n%s\n--CODE--\n%s" + "" % (e, out.getvalue(), err.getvalue(), + content)) from e + return glo, loc + + def test_export_onnx(self): + this = os.path.dirname(__file__) + folder = os.path.join(this, "data") + names = ["fft2d_any.onnx"] + for name in names: + with self.subTest(name=name): + oinf0 = OnnxInference(os.path.join(folder, name)) + + x = numpy.random.randn(3, 1, 4).astype(numpy.float32) + y = oinf0.run({'x': x}) + + new_onnx = export2onnx(os.path.join(folder, name)) + glo, loc = self.verify(new_onnx) + model = loc['onnx_model'] + oinf = OnnxInference(model) + y1 = oinf0.run({'x': x}) + + new_onnx = export2onnx(os.path.join(folder, name), verbose=False) + glo, loc = self.verify(new_onnx) + model = loc['onnx_model'] + oinf = OnnxInference(model) + y2 = oinf0.run({'x': x}) + + self.assertEqualArray(y['y'], y1['y']) + self.assertEqualArray(y['y'], y2['y']) + + def verify_tf(self, content, existing_loc=None): + try: + left, __ = verify_code(content, exc=False) + except SyntaxError as e: + raise AssertionError( + "Unable to analyse a script due to %r. " + "\n--CODE--\n%s" + "" % (e, content)) from e + + # execution + try: + obj = compile(content, '', 'exec') + except SyntaxError as e: + raise AssertionError( + "Unable to compile a script due to %r. " + "\n--CODE--\n%s" + "" % (e, content)) from e + glo = globals().copy() + loc = {'numpy_helper': numpy_helper, + 'make_model': make_model, + 'make_node': make_node, + 'set_model_props': set_model_props, + 'make_tensor': make_tensor, + 'make_graph': make_graph, + 'make_tensor_value_info': make_tensor_value_info} + if existing_loc is not None: + loc.update(existing_loc) + glo.update(existing_loc) + out = StringIO() + err = StringIO() + self.assertLess(len(left), 5) + + with redirect_stdout(out): + with redirect_stderr(err): + try: + exec(obj, glo, loc) # pylint: disable=W0122 + except Exception as e: + raise AssertionError( + "Unable to execute a script due to %r. 
" + "\n--OUT--\n%s\n--ERR--\n%s\n--CODE--\n%s" + "" % (e, out.getvalue(), err.getvalue(), + content)) from e + return glo, loc + + def test_export2tf2onnx(self): + this = os.path.dirname(__file__) + folder = os.path.join(this, "data") + names = ["fft2d_any.onnx"] + for name in names: + with self.subTest(name=name): + new_onnx = export2tf2onnx(os.path.join(folder, name)) + print(new_onnx) + self.verify_tf(new_onnx) + + + + +if __name__ == "__main__": + unittest.main() diff --git a/mlprodict/onnx_tools/onnx_export.py b/mlprodict/onnx_tools/onnx_export.py new file mode 100644 index 000000000..ce966269f --- /dev/null +++ b/mlprodict/onnx_tools/onnx_export.py @@ -0,0 +1,266 @@ +""" +@file +@brief Exports an ONNX graph in a way it can we created again +with a python script. It relies on :epkg:`jinja2` and :epkg:`autopep8`. + +.. versionadded:: 0.7 +""" +from textwrap import dedent +import numpy +from jinja2 import Template +import autopep8 +import onnx +from onnx import numpy_helper +from .onnx2py_helper import _var_as_dict + + +_onnx_templates = dedent(""" + import numpy + from onnx import numpy_helper + from onnx.helper import ( + make_model, make_node, set_model_props, make_tensor, make_graph, + make_tensor_value_info) + + + def create_model(): + # containers + print('[containers]') # verbose + initializers = [] + nodes = [] + inputs = [] + outputs = [] + + # opsets + print('[opsets]') # verbose + opsets = {{ opsets }} + target_opset = {{ target_opset }} + + # initializers + print('[initializers]') # verbose + {% for name, value in initializers: %} + list_value = {{ value.ravel().tolist() }} + value = numpy.array(list_value, dtype=numpy.{{ value.dtype }}).reshape({{ value.shape }}) + tensor = numpy_helper.from_array(value, name='{{ name }}') + initializers.append(tensor) + {% endfor %} + + # inputs + print('[inputs]') # verbose + {% for name, type, shape in inputs: %} + value = make_tensor_value_info('{{ name }}', {{ type }}, {{ shape }}) + inputs.append(value) + {% 
endfor %} + + # inputs + print('[outputs]') # verbose + {% for name, type, shape in outputs: %} + value = make_tensor_value_info('{{ name }}', {{ type }}, {{ shape }}) + outputs.append(value) + {% endfor %} + + # nodes + print('[nodes]') # verbose + {% for node in nodes: %} + node = make_node( + '{{ node['op_type'] }}', + {{ node['inputs'] }}, + {{ node['outputs'] }}, + {% if node['name']: %}name='{{ node['name'] }}',{% endif %} + {%- for name, value in node['attributes']: -%} + {{ name }}={{ value }}, + {%- endfor -%} + domain='{{ node['domain'] }}') + nodes.append(node) + {% endfor %} + + # graph + print('[graph]') # verbose + graph = make_graph(nodes, '{{ name }}', inputs, outputs, initializers) + onnx_model = make_model(graph) + onnx_model.ir_version = {{ ir_version }} + onnx_model.producer_name = '{{ producer_name }}' + onnx_model.producer_version = '{{ producer_version }}' + onnx_model.domain = '{{ domain }}' + onnx_model.model_version = {{ model_version }} + onnx_model.doc_string = '{{ doc_string }}' + set_model_props(onnx_model, {{ metadata }}) + + # opsets + print('[graph]') # verbose + for dom, value in opsets.items(): + op_set = onnx_model.opset_import.add() + op_set.domain = dom + op_set.version = value + + return onnx_model + + onnx_model = create_model() +""") + + +_tf2onnx_templates = dedent(""" + @tf_op("MyOp") + class ConvertMyOp: + + supported_dtypes = [ + numpy.float32, + ] + + @classmethod + def any_version(cls, opset, ctx, node, **kwargs): + ''' + Documentation. 
+ ''' + input_name = node.input[0] + onnx_dtype = ctx.get_dtype(input_name) + utils.make_sure(onnx_dtype in ConvertOp.supported_dtypes, "Unsupported input type.") + shape = ctx.get_shape(input_name) + space_names = {} + + # initializers + print('[initializers]') # verbose + {% for name, value in initializers: %} + value = numpy.array(list_value, dtype=numpy.{{ value.dtype }}).reshape({{ value.shape }}) + r_{{ name }} = ctx.make_const(name=utils.make_name('init_{{ name }}'), np_val=value) + space_names['{{ name }}'] = r_{{ name }}.name + initializers.append(tensor) + {% endfor %} + + # nodes + print('[nodes]') # verbose + {% for node in nodes: %} + attr = dict( + {%- for name, value in node['attributes']: -%} + {{ name }}={{ value }}, + {%- endfor -%}) + inputs = [{% for name in node['inputs']: -%}space_names['{{ name }}'], {%- endfor %}] + node = ctx.make_node( + {{ node['op_type'] }}, inputs=inputs, attr=attr,{% if node['domain']: -%} domain='{{ node['domain'] }}', {% endif %} + name=utils.make_name('{{ node['name'] }}')) + {% for i, name in enumerate(node['outputs']): -%} + space_names['{{ name }}'] = node.output[{{ i }}] + {%- endfor %} + nodes.append(node) + {% endfor %} +""") + + + +def export_template(model_onnx, templates, opset=None, verbose=True): + """ + Exports an ONNX model to the onnx syntax. 
+ + :param model_onnx: string or ONNX graph + :param templates: exporting templates + :param opset: opset to export to + (None to select the one from the graph) + :param verbose: insert prints + :return: python code + """ + # containers + context = {} + + # opset + opsets = {} + for oimp in model_onnx.opset_import: + if oimp.domain == '' and opset is None: + opsets[oimp.domain] = oimp.version + opset = oimp.version + else: + opsets[oimp.domain] = opset + context['opsets'] = opsets + context['target_opset'] = opset + + # inits + initializers = [] + for init in model_onnx.graph.initializer: + value = numpy_helper.to_array(init) + initializers.append((init.name, value)) + context['initializers'] = initializers + + # inputs + inputs = [] + for inp in model_onnx.graph.input: + t = inp.type.tensor_type + dims = tuple(t.shape.dim) + inputs.append((inp.name, t.elem_type, dims)) + context['inputs'] = inputs + + # outputs + outputs = [] + for inp in model_onnx.graph.output: + t = inp.type.tensor_type + dims = tuple(t.shape.dim) + outputs.append((inp.name, t.elem_type, dims)) + context['outputs'] = outputs + + # node + nodes = [] + for node in model_onnx.graph.node: + attributes = [] + for at in node.attribute: + temp = _var_as_dict(at) + value = temp['value'] + if isinstance(value, str): + attributes.append((at.name, "%r" % value)) + else: + if isinstance(value, numpy.ndarray): + attributes.append((at.name, repr(value.tolist()))) + else: + attributes.append((at.name, repr(value))) + d = dict(name=node.name, op_type=node.op_type, + domain=node.domain, inputs=node.input, + outputs=node.output, attributes=attributes) + nodes.append(d) + context['nodes'] = nodes + + # graph + context['name'] = model_onnx.graph.name + context['ir_version'] = model_onnx.ir_version + context['producer_name'] = model_onnx.producer_name + context['domain'] = model_onnx.domain + context['model_version'] = model_onnx.model_version + context['doc_string'] = model_onnx.doc_string + context['metadata'] = 
{p.key: p.value for p in model_onnx.metadata_props} + + # final + template = Template(templates) + final = template.render(enumerate=enumerate, **context) + + if not verbose: + rows = final.split("\n") + final = "\n".join(_ for _ in rows if not _.endswith("# verbose")) + return autopep8.fix_code(final) + + +def export2onnx(model_onnx, opset=None, verbose=True): + """ + Exports an ONNX model to the :epkg:`onnx` syntax. + + :param model_onnx: string or ONNX graph + :param opset: opset to export to + (None to select the one from the graph) + :param verbose: inserts prints + :return: python code + """ + if isinstance(model_onnx, str): + model_onnx = onnx.load(model_onnx) + + return export_template(model_onnx, templates=_onnx_templates, opset=opset, verbose=verbose) + + + +def export2tf2onnx(model_onnx, opset=None, verbose=True): + """ + Exports an ONNX model to the e:pkg:`tensorflow-onnx` syntax. + + :param model_onnx: string or ONNX graph + :param opset: opset to export to + (None to select the one from the graph) + :param verbose: inserts prints + :return: python code + """ + if isinstance(model_onnx, str): + model_onnx = onnx.load(model_onnx) + + return export_template(model_onnx, templates=_tf2onnx_templates, opset=opset, verbose=verbose) From 3bbf6809fe22ac322ed866959562f05f221f28a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sat, 7 Aug 2021 01:07:19 +0200 Subject: [PATCH 08/12] Add code to check tf2onnx conversion --- _doc/notebooks/onnx_fft.ipynb | 5550 ++++++++--------- .../test_onnxrt_python_runtime_custom.py | 21 +- _unittests/ut_tools/test_export_onnx.py | 79 +- mlprodict/onnx_tools/onnx_export.py | 269 +- mlprodict/onnxrt/onnx_inference.py | 3 +- mlprodict/testing/verify_code.py | 9 +- 6 files changed, 3067 insertions(+), 2864 deletions(-) diff --git a/_doc/notebooks/onnx_fft.ipynb b/_doc/notebooks/onnx_fft.ipynb index 7be676d90..90ca1f87a 100644 --- a/_doc/notebooks/onnx_fft.ipynb +++ b/_doc/notebooks/onnx_fft.ipynb @@ -1,2823 
+1,2823 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "51bc89fc", - "metadata": {}, - "source": [ - "# ONNX and FFT\n", - "\n", - "ONNX does not fully support complex yet. It does not have any FFT operators either. What if we need them anyway?" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "7b2add97", - "metadata": {}, - "outputs": [ + "cells": [ + { + "cell_type": "markdown", + "id": "51bc89fc", + "metadata": {}, + "source": [ + "# ONNX and FFT\n", + "\n", + "ONNX does not fully support complex yet. It does not have any FFT operators either. What if we need them anyway?" + ] + }, { - "data": { - "text/html": [ - "
run previous cell, wait for 2 seconds
\n", - "" + "cell_type": "code", + "execution_count": 1, + "id": "7b2add97", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
run previous cell, wait for 2 seconds
\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } ], - "text/plain": [ - "" + "source": [ + "from jyquickhelper import add_notebook_menu\n", + "add_notebook_menu()" ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from jyquickhelper import add_notebook_menu\n", - "add_notebook_menu()" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "acfdc3b0", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext mlprodict" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "abb5fa88", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "'1.21.1'" + "cell_type": "code", + "execution_count": 2, + "id": "acfdc3b0", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext mlprodict" ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import numpy\n", - "numpy.__version__" - ] - }, - { - "cell_type": "markdown", - "id": "2e4f68e4", - "metadata": {}, - "source": [ - "## Python implementation of RFFT\n", - "\n", - "We try to replicate [numpy.rfft](https://numpy.org/doc/stable/reference/generated/numpy.fft.rfft.html)." 
- ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "cb1cc910", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "array([[ 1.4850522 +0.j , 0.20190521-0.68916309j,\n", - " -4.02216577+2.22221485j, 2.9797031 +5.4719969j ],\n", - " [ 4.34131445+0.j , 0.59711262-2.9628275j ,\n", - " 2.24324474+1.82517205j, 2.00603187-2.84835823j],\n", - " [-1.12793672+0.j , 1.38239012-2.21121807j,\n", - " -2.35260183+1.48559871j, -1.50040395+1.3950099j ],\n", - " [-2.28634521+0.j , 0.18810962-0.47532163j,\n", - " 2.02343296-0.92789895j, -0.21166875-3.03973764j],\n", - " [-0.48667854+0.j , 1.17243003-1.72963149j,\n", - " -1.73139954+1.88922393j, -0.44110992-0.81126496j]])" + "cell_type": "code", + "execution_count": 3, + "id": "abb5fa88", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'1.21.1'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy\n", + "numpy.__version__" ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import numpy\n", - "\n", - "\n", - "def almost_equal(a, b, error=1e-5):\n", - " \"\"\"\n", - " The function compares two matrices, one may be complex. 
In that case,\n", - " this matrix is changed into a new matrix with a new first dimension,\n", - " [0,::] means real part, [1,::] means imaginary part.\n", - " \"\"\"\n", - " if a.dtype in (numpy.complex64, numpy.complex128):\n", - " dtype = numpy.float64 if a.dtype == numpy.complex128 else numpy.float32\n", - " new_a = numpy.empty((2,) + a.shape).astype(dtype)\n", - " new_a[0] = numpy.real(a)\n", - " new_a[1] = numpy.imag(a)\n", - " return almost_equal(new_a, b, error)\n", - " if b.dtype in (numpy.complex64, numpy.complex128):\n", - " return almost_equal(b, a, error)\n", - " if a.shape != b.shape:\n", - " raise AssertionError(\"Shape mismatch %r != %r.\" % (a.shape, b.shape))\n", - " diff = numpy.abs(a.ravel() - b.ravel()).max()\n", - " if diff > error:\n", - " raise AssertionError(\"Mismatch max diff=%r > %r.\" % (diff, error))\n", - "\n", - "\n", - "def dft_real_cst(N, fft_length):\n", - " n = numpy.arange(N)\n", - " k = n.reshape((N, 1)).astype(numpy.float64)\n", - " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n", - " both = numpy.empty((2,) + M.shape)\n", - " both[0, :, :] = numpy.real(M)\n", - " both[1, :, :] = numpy.imag(M)\n", - " return both\n", - "\n", - "\n", - "def dft_real(x, fft_length=None, transpose=True):\n", - " if len(x.shape) == 1:\n", - " x = x.reshape((1, -1))\n", - " N = 1\n", - " else:\n", - " N = x.shape[0] \n", - " C = x.shape[-1] if transpose else x.shape[-2]\n", - " if fft_length is None:\n", - " fft_length = x.shape[-1]\n", - " size = fft_length // 2 + 1\n", - "\n", - " cst = dft_real_cst(C, fft_length)\n", - " if transpose:\n", - " x = numpy.transpose(x, (1, 0))\n", - " a = cst[:, :, :fft_length]\n", - " b = x[:fft_length]\n", - " res = numpy.matmul(a, b)\n", - " res = res[:, :size, :]\n", - " return numpy.transpose(res, (0, 2, 1))\n", - " else:\n", - " a = cst[:, :, :fft_length]\n", - " b = x[:fft_length]\n", - " return numpy.matmul(a, b)\n", - "\n", - "\n", - "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n", - "fft_np 
= numpy.fft.rfft(rnd)\n", - "fft_cus = dft_real(rnd)\n", - "fft_np" - ] - }, - { - "cell_type": "markdown", - "id": "0c052ea1", - "metadata": {}, - "source": [ - "Function `almost_equal` verifies both functions return the same results." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "3ca040cb", - "metadata": {}, - "outputs": [], - "source": [ - "almost_equal(fft_np, fft_cus)" - ] - }, - { - "cell_type": "markdown", - "id": "7fe77440", - "metadata": {}, - "source": [ - "Let's do the same with `fft_length < shape[1]`." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "3a747a4a", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "array([[ 2.02068053+0.j , -1.0523537 +4.21051588j],\n", - " [ 3.35636759+0.j , 1.32912183+0.17609187j],\n", - " [ 0.86212105+0.j , -1.73159653+0.58274578j],\n", - " [-0.93982866+0.j , 0.837072 -1.67403309j],\n", - " [ 0.72620976+0.j , -0.89599861+0.37600383j]])" + "cell_type": "markdown", + "id": "2e4f68e4", + "metadata": {}, + "source": [ + "## Python implementation of RFFT\n", + "\n", + "We try to replicate [numpy.rfft](https://numpy.org/doc/stable/reference/generated/numpy.fft.rfft.html)." ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fft_np3 = numpy.fft.rfft(rnd, n=3)\n", - "fft_cus3 = dft_real(rnd, fft_length=3)\n", - "fft_np3" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "0db6247b", - "metadata": {}, - "outputs": [], - "source": [ - "almost_equal(fft_np3, fft_cus3)" - ] - }, - { - "cell_type": "markdown", - "id": "31a6ac9c", - "metadata": {}, - "source": [ - "## RFFT in ONNX\n", - "\n", - "Let's assume first the number of column of the input matrix is fixed. The result of function `dft_real_cst` can be considered as constant." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "efb67b9b", - "metadata": { - "scrolled": false - }, - "outputs": [ + }, { - "data": { - "text/plain": [ - "array([[[ 1.4850521 , 0.2019052 , -4.0221653 , 2.979703 ],\n", - " [ 4.3413143 , 0.5971126 , 2.2432446 , 2.006032 ],\n", - " [-1.1279367 , 1.38239 , -2.3526018 , -1.500404 ],\n", - " [-2.2863452 , 0.18810958, 2.0234327 , -0.21166855],\n", - " [-0.48667857, 1.17243 , -1.7313995 , -0.44110993]],\n", - "\n", - " [[ 0. , -0.689163 , 2.2222147 , 5.4719973 ],\n", - " [ 0. , -2.9628277 , 1.8251722 , -2.8483584 ],\n", - " [ 0. , -2.211218 , 1.4855987 , 1.39501 ],\n", - " [ 0. , -0.47532162, -0.9278989 , -3.0397377 ],\n", - " [ 0. , -1.7296315 , 1.8892239 , -0.811265 ]]],\n", - " dtype=float32)" + "cell_type": "code", + "execution_count": 4, + "id": "cb1cc910", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 1.4850522 +0.j , 0.20190521-0.68916309j,\n", + " -4.02216577+2.22221485j, 2.9797031 +5.4719969j ],\n", + " [ 4.34131445+0.j , 0.59711262-2.9628275j ,\n", + " 2.24324474+1.82517205j, 2.00603187-2.84835823j],\n", + " [-1.12793672+0.j , 1.38239012-2.21121807j,\n", + " -2.35260183+1.48559871j, -1.50040395+1.3950099j ],\n", + " [-2.28634521+0.j , 0.18810962-0.47532163j,\n", + " 2.02343296-0.92789895j, -0.21166875-3.03973764j],\n", + " [-0.48667854+0.j , 1.17243003-1.72963149j,\n", + " -1.73139954+1.88922393j, -0.44110992-0.81126496j]])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy\n", + "\n", + "\n", + "def almost_equal(a, b, error=1e-5):\n", + " \"\"\"\n", + " The function compares two matrices, one may be complex. 
In that case,\n", + " this matrix is changed into a new matrix with a new first dimension,\n", + " [0,::] means real part, [1,::] means imaginary part.\n", + " \"\"\"\n", + " if a.dtype in (numpy.complex64, numpy.complex128):\n", + " dtype = numpy.float64 if a.dtype == numpy.complex128 else numpy.float32\n", + " new_a = numpy.empty((2,) + a.shape).astype(dtype)\n", + " new_a[0] = numpy.real(a)\n", + " new_a[1] = numpy.imag(a)\n", + " return almost_equal(new_a, b, error)\n", + " if b.dtype in (numpy.complex64, numpy.complex128):\n", + " return almost_equal(b, a, error)\n", + " if a.shape != b.shape:\n", + " raise AssertionError(\"Shape mismatch %r != %r.\" % (a.shape, b.shape))\n", + " diff = numpy.abs(a.ravel() - b.ravel()).max()\n", + " if diff > error:\n", + " raise AssertionError(\"Mismatch max diff=%r > %r.\" % (diff, error))\n", + "\n", + "\n", + "def dft_real_cst(N, fft_length):\n", + " n = numpy.arange(N)\n", + " k = n.reshape((N, 1)).astype(numpy.float64)\n", + " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n", + " both = numpy.empty((2,) + M.shape)\n", + " both[0, :, :] = numpy.real(M)\n", + " both[1, :, :] = numpy.imag(M)\n", + " return both\n", + "\n", + "\n", + "def dft_real(x, fft_length=None, transpose=True):\n", + " if len(x.shape) == 1:\n", + " x = x.reshape((1, -1))\n", + " N = 1\n", + " else:\n", + " N = x.shape[0] \n", + " C = x.shape[-1] if transpose else x.shape[-2]\n", + " if fft_length is None:\n", + " fft_length = x.shape[-1]\n", + " size = fft_length // 2 + 1\n", + "\n", + " cst = dft_real_cst(C, fft_length)\n", + " if transpose:\n", + " x = numpy.transpose(x, (1, 0))\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:fft_length]\n", + " res = numpy.matmul(a, b)\n", + " res = res[:, :size, :]\n", + " return numpy.transpose(res, (0, 2, 1))\n", + " else:\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:fft_length]\n", + " return numpy.matmul(a, b)\n", + "\n", + "\n", + "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n", + "fft_np 
= numpy.fft.rfft(rnd)\n", + "fft_cus = dft_real(rnd)\n", + "fft_np" ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from typing import Any\n", - "import mlprodict.npy.numpy_onnx_impl as npnx\n", - "from mlprodict.npy import onnxnumpy_np\n", - "from mlprodict.npy.onnx_numpy_annotation import NDArrayType\n", - "# from mlprodict.onnxrt import OnnxInference\n", - "\n", - "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", - "def onnx_rfft(x, fft_length=None):\n", - " if fft_length is None:\n", - " raise RuntimeError(\"fft_length must be specified.\")\n", - " \n", - " size = fft_length // 2 + 1\n", - " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", - " xt = npnx.transpose(x, (1, 0))\n", - " res = npnx.matmul(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n", - " return npnx.transpose(res, (0, 2, 1))\n", - "\n", - "fft_onx = onnx_rfft(rnd, fft_length=rnd.shape[1])\n", - "fft_onx" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "c4b6b1a5", - "metadata": {}, - "outputs": [], - "source": [ - "almost_equal(fft_cus, fft_onx)" - ] - }, - { - "cell_type": "markdown", - "id": "a8c35327", - "metadata": {}, - "source": [ - "The corresponding ONNX graph is the following:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "4d1a85b0", - "metadata": {}, - "outputs": [ + }, + { + "cell_type": "markdown", + "id": "0c052ea1", + "metadata": {}, + "source": [ + "Function `almost_equal` verifies both functions return the same results." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3ca040cb", + "metadata": {}, + "outputs": [], + "source": [ + "almost_equal(fft_np, fft_cus)" + ] + }, { - "data": { - "text/html": [ - "
\n", - "" + "cell_type": "markdown", + "id": "7fe77440", + "metadata": {}, + "source": [ + "Let's do the same with `fft_length < shape[1]`." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "3a747a4a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 2.02068053+0.j , -1.0523537 +4.21051588j],\n", + " [ 3.35636759+0.j , 1.32912183+0.17609187j],\n", + " [ 0.86212105+0.j , -1.73159653+0.58274578j],\n", + " [-0.93982866+0.j , 0.837072 -1.67403309j],\n", + " [ 0.72620976+0.j , -0.89599861+0.37600383j]])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } ], - "text/plain": [ - "" + "source": [ + "fft_np3 = numpy.fft.rfft(rnd, n=3)\n", + "fft_cus3 = dft_real(rnd, fft_length=3)\n", + "fft_np3" ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "key = list(onnx_rfft.signed_compiled)[0]\n", - "%onnxview onnx_rfft.signed_compiled[key].compiled.onnx_" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "6cf18aca", - "metadata": {}, - "outputs": [], - "source": [ - "fft_onx3 = onnx_rfft(rnd, fft_length=3)\n", - "almost_equal(fft_cus3, fft_onx3)" - ] - }, - { - "cell_type": "markdown", - "id": "6b466fd4", - "metadata": {}, - "source": [ - "## FFT 2D\n", - "\n", - "Below the code for complex features." 
- ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "e0020084", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "array([[ 6.73214482 +0.j , 3.36817961 +2.73163668j,\n", - " 4.42688318-10.23641065j, -3.27162717 -0.54910943j],\n", - " [ 0.05934113 -2.02790296j, -4.71694176 -1.80039444j,\n", - " -0.16187544 -1.27214887j, -4.76195404 -8.23146595j],\n", - " [-2.43886644 +0.67253454j, 0.62177822 +1.71628605j,\n", - " -4.22144547 -0.24384973j, -1.96253444 -2.26942153j],\n", - " [-2.43886644 -0.67253454j, -4.94210355 +1.65439295j,\n", - " -6.75624015 -2.50966739j, -1.62599543 +7.41506091j],\n", - " [ 0.05934113 +2.02790296j, 1.56068457 -4.5734695j ,\n", - " -2.9809962 +2.90470743j, 4.42498542-10.45411745j]])" + "cell_type": "code", + "execution_count": 7, + "id": "0db6247b", + "metadata": {}, + "outputs": [], + "source": [ + "almost_equal(fft_np3, fft_cus3)" ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def _DFT_cst(N, fft_length, trunc=True):\n", - " n = numpy.arange(N)\n", - " k = n.reshape((N, 1)).astype(numpy.float64)\n", - " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n", - " return M[:fft_length // 2 + 1] if trunc else M\n", - "\n", - "def DFT(x, fft_length=None, axis=1):\n", - " if axis == 1:\n", - " x = x.T\n", - " if fft_length is None:\n", - " fft_length = x.shape[0]\n", - " cst = _DFT_cst(x.shape[0], fft_length, trunc=axis==1)\n", - " if axis == 1:\n", - " return numpy.matmul(cst, x).T\n", - " return numpy.matmul(cst, x)\n", - "\n", - "def fft2d_(mat, fft_length):\n", - " mat = mat[:fft_length[0], :fft_length[1]]\n", - " res = mat.copy()\n", - " res = DFT(res, fft_length[1], axis=1)\n", - " res = DFT(res, fft_length[0], axis=0)\n", - " return res[:fft_length[0], :fft_length[1]//2 + 1]\n", - "\n", - "\n", - "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n", - "fft2d_np_ = fft2d_(rnd, rnd.shape)\n", - "fft2d_np = numpy.fft.rfft2(rnd)\n", - 
"fft2d_np_" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "777d2775", - "metadata": {}, - "outputs": [], - "source": [ - "almost_equal(fft2d_np_, fft2d_np)" - ] - }, - { - "cell_type": "markdown", - "id": "cfbbe2fd", - "metadata": {}, - "source": [ - "It implies the computation of two FFT 1D along both axes. However, as ONNX does not support complex, it needs to be rewritten with only real numbers. The algorithm can be summarized into this formula $FFT(FFT(x, axis=1), axis=0)$. If *x* is real, $FFT(x, .)$ is complex. We still assume *x* is real, it then becomes (FFT is a linear operator, so $FFT(ix)=i FFT(x)$):\n", - "\n", - "* $y = FFT(x, axis=1)$\n", - "* $z_r = FFT(Real(y), axis=0)$, $z_i = FFT(Imag(y), axis=0)$\n", - "* $z = z_r + i z_i$\n", - "\n", - "*z* is the desired output. The following implementation is probably not the most efficient one. It avoids inplace computation as ONNX does like that." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "dd4fc711", - "metadata": {}, - "outputs": [], - "source": [ - "def fft2d(mat, fft_length):\n", - " mat = mat[:fft_length[0], :fft_length[1]]\n", - " res = mat.copy()\n", - " \n", - " # first FFT\n", - " res = dft_real(res, fft_length=fft_length[1], transpose=True)\n", - " \n", - " # second FFT decomposed on FFT on real part and imaginary part\n", - " res2_real = dft_real(res[0], fft_length=fft_length[0], transpose=False)\n", - " res2_imag = dft_real(res[1], fft_length=fft_length[0], transpose=False) \n", - " res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n", - " res = res2_real + res2_imag2\n", - " size = fft_length[1]//2 + 1\n", - " return res[:, :fft_length[0], :size]\n", - "\n", - "\n", - "fft2d_np = numpy.fft.rfft2(rnd)\n", - "fft2d_cus = fft2d(rnd, rnd.shape)\n", - "almost_equal(fft2d_np, fft2d_cus)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "bb8667e6", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - 
"array([[ 6.73214482 +0.j , 3.36817961 +2.73163668j,\n", - " 4.42688318-10.23641065j, -3.27162717 -0.54910943j],\n", - " [ 0.05934113 -2.02790296j, -4.71694176 -1.80039444j,\n", - " -0.16187544 -1.27214887j, -4.76195404 -8.23146595j],\n", - " [-2.43886644 +0.67253454j, 0.62177822 +1.71628605j,\n", - " -4.22144547 -0.24384973j, -1.96253444 -2.26942153j],\n", - " [-2.43886644 -0.67253454j, -4.94210355 +1.65439295j,\n", - " -6.75624015 -2.50966739j, -1.62599543 +7.41506091j],\n", - " [ 0.05934113 +2.02790296j, 1.56068457 -4.5734695j ,\n", - " -2.9809962 +2.90470743j, 4.42498542-10.45411745j]])" + "cell_type": "markdown", + "id": "31a6ac9c", + "metadata": {}, + "source": [ + "## RFFT in ONNX\n", + "\n", + "Let's assume first the number of column of the input matrix is fixed. The result of function `dft_real_cst` can be considered as constant." ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fft2d_np" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "56a94d97", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "array([[[ 6.73214482, 3.36817961, 4.42688318, -3.27162717],\n", - " [ 0.05934113, -4.71694176, -0.16187544, -4.76195404],\n", - " [ -2.43886644, 0.62177822, -4.22144547, -1.96253444],\n", - " [ -2.43886644, -4.94210355, -6.75624015, -1.62599543],\n", - " [ 0.05934113, 1.56068457, -2.9809962 , 4.42498542]],\n", - "\n", - " [[ 0. 
, 2.73163668, -10.23641065, -0.54910943],\n", - " [ -2.02790296, -1.80039444, -1.27214887, -8.23146595],\n", - " [ 0.67253454, 1.71628605, -0.24384973, -2.26942153],\n", - " [ -0.67253454, 1.65439295, -2.50966739, 7.41506091],\n", - " [ 2.02790296, -4.5734695 , 2.90470743, -10.45411745]]])" + "cell_type": "code", + "execution_count": 8, + "id": "efb67b9b", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[ 1.4850521 , 0.2019052 , -4.0221653 , 2.979703 ],\n", + " [ 4.3413143 , 0.5971126 , 2.2432446 , 2.006032 ],\n", + " [-1.1279367 , 1.38239 , -2.3526018 , -1.500404 ],\n", + " [-2.2863452 , 0.18810958, 2.0234327 , -0.21166855],\n", + " [-0.48667857, 1.17243 , -1.7313995 , -0.44110993]],\n", + "\n", + " [[ 0. , -0.689163 , 2.2222147 , 5.4719973 ],\n", + " [ 0. , -2.9628277 , 1.8251722 , -2.8483584 ],\n", + " [ 0. , -2.211218 , 1.4855987 , 1.39501 ],\n", + " [ 0. , -0.47532162, -0.9278989 , -3.0397377 ],\n", + " [ 0. , -1.7296315 , 1.8892239 , -0.811265 ]]],\n", + " dtype=float32)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from typing import Any\n", + "import mlprodict.npy.numpy_onnx_impl as npnx\n", + "from mlprodict.npy import onnxnumpy_np\n", + "from mlprodict.npy.onnx_numpy_annotation import NDArrayType\n", + "# from mlprodict.onnxrt import OnnxInference\n", + "\n", + "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", + "def onnx_rfft(x, fft_length=None):\n", + " if fft_length is None:\n", + " raise RuntimeError(\"fft_length must be specified.\")\n", + " \n", + " size = fft_length // 2 + 1\n", + " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", + " xt = npnx.transpose(x, (1, 0))\n", + " res = npnx.matmul(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n", + " return npnx.transpose(res, (0, 2, 1))\n", + "\n", + "fft_onx = onnx_rfft(rnd, fft_length=rnd.shape[1])\n", + "fft_onx" ] - }, - 
"execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fft2d_cus" - ] - }, - { - "cell_type": "markdown", - "id": "faa21909", - "metadata": {}, - "source": [ - "And with a different `fft_length`." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "bf98995f", - "metadata": {}, - "outputs": [], - "source": [ - "fft2d_np = numpy.fft.rfft2(rnd, (4, 6))\n", - "fft2d_cus = fft2d(rnd, (4, 6))\n", - "almost_equal(fft2d_np[:4, :], fft2d_cus)" - ] - }, - { - "cell_type": "markdown", - "id": "caee1f84", - "metadata": {}, - "source": [ - "## FFT 2D in ONNX\n", - "\n", - "We use again the numpy API for ONNX." - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "ca641274", - "metadata": {}, - "outputs": [], - "source": [ - "def onnx_rfft_1d(x, fft_length=None, transpose=True):\n", - " if fft_length is None:\n", - " raise RuntimeError(\"fft_length must be specified.\")\n", - " \n", - " size = fft_length // 2 + 1\n", - " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", - " if transpose:\n", - " xt = npnx.transpose(x, (1, 0))\n", - " res = npnx.matmul(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n", - " return npnx.transpose(res, (0, 2, 1))\n", - " else:\n", - " return npnx.matmul(cst[:, :, :fft_length], x[:fft_length])\n", - "\n", - "\n", - "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", - "def onnx_rfft_2d(x, fft_length=None):\n", - " mat = x[:fft_length[0], :fft_length[1]]\n", - " \n", - " # first FFT\n", - " res = onnx_rfft_1d(mat, fft_length=fft_length[1], transpose=True)\n", - " \n", - " # second FFT decomposed on FFT on real part and imaginary part\n", - " res2_real = onnx_rfft_1d(res[0], fft_length=fft_length[0], transpose=False)\n", - " res2_imag = onnx_rfft_1d(res[1], fft_length=fft_length[0], transpose=False) \n", - " res2_imag2 = npnx.vstack(-res2_imag[1:2], res2_imag[:1])\n", - " res = res2_real + res2_imag2\n", - " size = 
fft_length[1]//2 + 1\n", - " return res[:, :fft_length[0], :size]\n", - "\n", - "\n", - "fft2d_cus = fft2d(rnd, rnd.shape)\n", - "fft2d_onx = onnx_rfft_2d(rnd, fft_length=rnd.shape)\n", - "almost_equal(fft2d_cus, fft2d_onx)" - ] - }, - { - "cell_type": "markdown", - "id": "20fcd8a9", - "metadata": {}, - "source": [ - "The corresponding ONNX graph." - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "b1379b06", - "metadata": { - "scrolled": false - }, - "outputs": [ + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c4b6b1a5", + "metadata": {}, + "outputs": [], + "source": [ + "almost_equal(fft_cus, fft_onx)" + ] + }, + { + "cell_type": "markdown", + "id": "a8c35327", + "metadata": {}, + "source": [ + "The corresponding ONNX graph is the following:" + ] + }, { - "data": { - "text/html": [ - "
\n", - "" + "cell_type": "code", + "execution_count": 10, + "id": "4d1a85b0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } ], - "text/plain": [ - "" + "source": [ + "key = list(onnx_rfft.signed_compiled)[0]\n", + "%onnxview onnx_rfft.signed_compiled[key].compiled.onnx_" ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "key = list(onnx_rfft_2d.signed_compiled)[0]\n", - "%onnxview onnx_rfft_2d.signed_compiled[key].compiled.onnx_" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "3034da60", - "metadata": {}, - "outputs": [], - "source": [ - "with open(\"fft2d.onnx\", \"wb\") as f:\n", - " key = list(onnx_rfft_2d.signed_compiled)[0]\n", - " f.write(onnx_rfft_2d.signed_compiled[key].compiled.onnx_.SerializeToString())" - ] - }, - { - "cell_type": "markdown", - "id": "3a747f0c", - "metadata": {}, - "source": [ - "With a different `fft_length`." - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "16732cbb", - "metadata": {}, - "outputs": [], - "source": [ - "fft2d_cus = fft2d(rnd, (4, 5))\n", - "fft2d_onx = onnx_rfft_2d(rnd, fft_length=(4, 5))\n", - "almost_equal(fft2d_cus, fft2d_onx)" - ] - }, - { - "cell_type": "markdown", - "id": "04924e7d", - "metadata": {}, - "source": [ - "This implementation of FFT in ONNX assumes shapes and fft lengths are constant. Otherwise, the matrix returned by function `dft_real_cst` must be converted as well. That's left as an exercise." - ] - }, - { - "cell_type": "markdown", - "id": "c9da88a0", - "metadata": {}, - "source": [ - "## FFT2D with shape (3,1,4)\n", - "\n", - "Previous implementation expects the input matrix to have two dimensions. It fails with 3." 
- ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "66ba70ee", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "(3, 1, 4)" + "cell_type": "code", + "execution_count": 11, + "id": "6cf18aca", + "metadata": {}, + "outputs": [], + "source": [ + "fft_onx3 = onnx_rfft(rnd, fft_length=3)\n", + "almost_equal(fft_cus3, fft_onx3)" ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "shape = (3, 1, 4)\n", - "fft_length = (1, 4)\n", - "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", - "fft2d_numpy = numpy.fft.fft2(rnd, fft_length)\n", - "fft2d_numpy.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "a4d123e1", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "array([[[ 0.87908971+0.j , 0.85659337+2.10206711j,\n", - " 1.5270735 +0.j , 0.85659337-2.10206711j]],\n", - "\n", - " [[-5.01959181+0.j , -0.25658643+0.62102163j,\n", - " 2.18641639+0.j , -0.25658643-0.62102163j]],\n", - "\n", - " [[ 0.60041136+0.j , -0.04546577-1.2931717j ,\n", - " 1.19486004+0.j , -0.04546577+1.2931717j ]]])" + "cell_type": "markdown", + "id": "6b466fd4", + "metadata": {}, + "source": [ + "## FFT 2D\n", + "\n", + "Below the code for complex features." ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fft2d_numpy" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "4b1bd05b", - "metadata": {}, - "outputs": [ + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "axes don't match array\n" - ] - } - ], - "source": [ - "try:\n", - " fft2d_cus = fft2d(rnd, fft_length)\n", - "except Exception as e:\n", - " print(e)\n", - "# fft2d_onx = onnx_rfft_2d(rnd, fft_length=fft_length)" - ] - }, - { - "cell_type": "markdown", - "id": "7bd79a00", - "metadata": {}, - "source": [ - "### numpy version\n", - "\n", - "Let's do it again with numpy first. 
[fft2](https://numpy.org/doc/stable/reference/generated/numpy.fft.fft2.html) performs `fft2` on the last two axis as many times as the first axis. The goal is still to have an implementation which works for any dimension." - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "3b618335", - "metadata": {}, - "outputs": [], - "source": [ - "conc = []\n", - "for i in range(rnd.shape[0]):\n", - " f2 = fft2d(rnd[i], fft_length)\n", - " conc.append(numpy.expand_dims(f2, 0))\n", - "res = numpy.vstack(conc).transpose(1, 0, 2, 3)\n", - "almost_equal(fft2d_numpy[:, :, :3], res)" - ] - }, - { - "cell_type": "markdown", - "id": "7c837e7a", - "metadata": {}, - "source": [ - "It works. And now a more efficient implementation. It is better to read [matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html) description before. To summarize, a third axis is equivalent to many matrix multiplications over the last two axes, as many as the dimension of the first axis: ``matmul(A[I,J,K], B[I,K,L]) --> C[I,J,L]``. Broadcasting also works... ``matmul(A[1,J,K], B[I,K,L]) --> C[I,J,L]``." 
- ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "29055cb2", - "metadata": {}, - "outputs": [], - "source": [ - "def dft_real_d3(x, fft_length=None, transpose=True):\n", - " if len(x.shape) != 3:\n", - " raise RuntimeError(\"Not implemented for shape=%r.\" % x.shape)\n", - " N = x.shape[1]\n", - " C = x.shape[-1] if transpose else x.shape[-2]\n", - " if fft_length is None:\n", - " fft_length = x.shape[-1]\n", - " size = fft_length // 2 + 1\n", - "\n", - " cst = dft_real_cst(C, fft_length)\n", - " if transpose:\n", - " x = numpy.transpose(x, (0, 2, 1))\n", - " a = cst[:, :, :fft_length]\n", - " b = x[:, :fft_length, :]\n", - " a = numpy.expand_dims(a, 0)\n", - " b = numpy.expand_dims(b, 1)\n", - " res = numpy.matmul(a, b)\n", - " res = res[:, :, :size, :]\n", - " return numpy.transpose(res, (1, 0, 3, 2))\n", - " else:\n", - " a = cst[:, :, :fft_length]\n", - " b = x[:, :fft_length, :]\n", - " a = numpy.expand_dims(a, 0)\n", - " b = numpy.expand_dims(b, 1)\n", - " res = numpy.matmul(a, b)\n", - " return numpy.transpose(res, (1, 0, 2, 3))\n", - "\n", - "\n", - "def fft2d_d3(mat, fft_length):\n", - " mat = mat[:, :fft_length[-2], :fft_length[-1]]\n", - " res = mat.copy()\n", - " \n", - " # first FFT\n", - " res = dft_real_d3(res, fft_length=fft_length[-1], transpose=True)\n", - " \n", - " # second FFT decomposed on FFT on real part and imaginary part\n", - " res2_real = dft_real_d3(res[0], fft_length=fft_length[-2], transpose=False)\n", - " res2_imag = dft_real_d3(res[1], fft_length=fft_length[-2], transpose=False)\n", - " res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n", - " res = res2_real + res2_imag2\n", - " size = fft_length[-1]//2 + 1\n", - " return res[:, :, :fft_length[-2], :size]\n", - "\n", - "\n", - "def fft2d_any(mat, fft_length):\n", - " new_shape = (-1, ) + mat.shape[-2:]\n", - " mat2 = mat.reshape(new_shape)\n", - " f2 = fft2d_d3(mat2, fft_length)\n", - " new_shape = (2, ) + mat.shape[:-2] + f2.shape[-2:]\n", - " return 
f2.reshape(new_shape)\n", - "\n", - "\n", - "shape = (3, 1, 4)\n", - "fft_length = (1, 4)\n", - "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", - "fft2d_numpy = numpy.fft.fft2(rnd, fft_length)\n", - "fft2d_cus = fft2d_any(rnd, fft_length)\n", - "almost_equal(fft2d_numpy[..., :3], fft2d_cus)" - ] - }, - { - "cell_type": "markdown", - "id": "0128b3f2", - "metadata": {}, - "source": [ - "We check with more shapes to see if the implementation works for all of them." - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "82f5fc78", - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 12, + "id": "e0020084", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 6.73214482 +0.j , 3.36817961 +2.73163668j,\n", + " 4.42688318-10.23641065j, -3.27162717 -0.54910943j],\n", + " [ 0.05934113 -2.02790296j, -4.71694176 -1.80039444j,\n", + " -0.16187544 -1.27214887j, -4.76195404 -8.23146595j],\n", + " [-2.43886644 +0.67253454j, 0.62177822 +1.71628605j,\n", + " -4.22144547 -0.24384973j, -1.96253444 -2.26942153j],\n", + " [-2.43886644 -0.67253454j, -4.94210355 +1.65439295j,\n", + " -6.75624015 -2.50966739j, -1.62599543 +7.41506091j],\n", + " [ 0.05934113 +2.02790296j, 1.56068457 -4.5734695j ,\n", + " -2.9809962 +2.90470743j, 4.42498542-10.45411745j]])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def _DFT_cst(N, fft_length, trunc=True):\n", + " n = numpy.arange(N)\n", + " k = n.reshape((N, 1)).astype(numpy.float64)\n", + " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n", + " return M[:fft_length // 2 + 1] if trunc else M\n", + "\n", + "def DFT(x, fft_length=None, axis=1):\n", + " if axis == 1:\n", + " x = x.T\n", + " if fft_length is None:\n", + " fft_length = x.shape[0]\n", + " cst = _DFT_cst(x.shape[0], fft_length, trunc=axis==1)\n", + " if axis == 1:\n", + " return numpy.matmul(cst, x).T\n", + " return numpy.matmul(cst, x)\n", + 
"\n", + "def fft2d_(mat, fft_length):\n", + " mat = mat[:fft_length[0], :fft_length[1]]\n", + " res = mat.copy()\n", + " res = DFT(res, fft_length[1], axis=1)\n", + " res = DFT(res, fft_length[0], axis=0)\n", + " return res[:fft_length[0], :fft_length[1]//2 + 1]\n", + "\n", + "\n", + "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n", + "fft2d_np_ = fft2d_(rnd, rnd.shape)\n", + "fft2d_np = numpy.fft.rfft2(rnd)\n", + "fft2d_np_" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", - "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", - "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", - "OK x.shape=(3, 1, 4) length=(1, 2) output shape=(3, 1, 2) or (2, 3, 1, 2)\n", - "OK x.shape=(3, 1, 4) length=(1, 1) output shape=(3, 1, 1) or (2, 3, 1, 1)\n", - "OK x.shape=(5, 7) length=(5, 7) output shape=(5, 7) or (2, 5, 4)\n", - "OK x.shape=(5, 7) length=(1, 7) output shape=(1, 7) or (2, 1, 4)\n", - "OK x.shape=(5, 7) length=(2, 7) output shape=(2, 7) or (2, 2, 4)\n", - "OK x.shape=(5, 7) length=(5, 2) output shape=(5, 2) or (2, 5, 2)\n", - "OK x.shape=(5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n", - "OK x.shape=(3, 5, 7) length=(5, 7) output shape=(3, 5, 7) or (2, 3, 5, 4)\n", - "OK x.shape=(3, 5, 7) length=(1, 7) output shape=(3, 1, 7) or (2, 3, 1, 4)\n", - "OK x.shape=(3, 5, 7) length=(2, 7) output shape=(3, 2, 7) or (2, 3, 2, 4)\n", - "OK x.shape=(3, 5, 7) length=(5, 2) output shape=(3, 5, 2) or (2, 3, 5, 2)\n", - "OK x.shape=(3, 5, 7) length=(3, 4) output shape=(3, 3, 4) or (2, 3, 3, 3)\n", - "OK x.shape=(7, 5) length=(7, 5) output shape=(7, 5) or (2, 7, 3)\n", - "OK x.shape=(7, 5) length=(1, 5) output shape=(1, 5) or (2, 1, 3)\n", - "OK x.shape=(7, 5) length=(2, 5) output shape=(2, 5) or (2, 2, 3)\n", - "OK x.shape=(7, 5) length=(7, 2) output shape=(7, 2) or (2, 7, 2)\n", - "OK x.shape=(7, 5) length=(3, 
4) output shape=(3, 4) or (2, 3, 3)\n" - ] - } - ], - "source": [ - "for shape in [(3, 1, 4), (5, 7), (3, 5, 7), (7, 5)]:\n", - " for fft_length in [shape[-2:], (1, shape[-1]),\n", - " (min(2, shape[-2]), shape[-1]),\n", - " (shape[-2], 2),\n", - " (min(3, shape[-2]), min(4, shape[-2]))]:\n", - " x = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", - " fnp = numpy.fft.fft2(x, fft_length)\n", - " if len(fnp.shape) == 2:\n", - " fn= numpy.expand_dims(fnp, 0)\n", - " try:\n", - " cus = fft2d_any(x, fft_length)\n", - " except IndexError as e:\n", - " print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", - " continue\n", - " try:\n", - " almost_equal(fnp[..., :cus.shape[-1]], cus)\n", - " except (AssertionError, IndexError) as e:\n", - " print(\"DIS x.shape=%r length=%r error=%r output shape=%r or %r\" % (\n", - " x.shape, fft_length, e, fnp.shape, cus.shape))\n", - " continue\n", - " print(\"OK x.shape=%r length=%r output shape=%r or %r\" % (\n", - " x.shape, fft_length, fnp.shape, cus.shape))" - ] - }, - { - "cell_type": "markdown", - "id": "c5f5229a", - "metadata": {}, - "source": [ - "### ONNX version\n", - "\n", - "Let's look into the differences first." - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "025c2d88", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext pyquickhelper" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "82664bc5", - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 13, + "id": "777d2775", + "metadata": {}, + "outputs": [], + "source": [ + "almost_equal(fft2d_np_, fft2d_np)" + ] + }, + { + "cell_type": "markdown", + "id": "cfbbe2fd", + "metadata": {}, + "source": [ + "It implies the computation of two FFT 1D along both axes. However, as ONNX does not support complex, it needs to be rewritten with only real numbers. The algorithm can be summarized into this formula $FFT(FFT(x, axis=1), axis=0)$. If *x* is real, $FFT(x, .)$ is complex. 
We still assume *x* is real, it then becomes (FFT is a linear operator, so $FFT(ix)=i FFT(x)$):\n", + "\n", + "* $y = FFT(x, axis=1)$\n", + "* $z_r = FFT(Real(y), axis=0)$, $z_i = FFT(Imag(y), axis=0)$\n", + "* $z = z_r + i z_i$\n", + "\n", + "*z* is the desired output. The following implementation is probably not the most efficient one. It avoids inplace computation as ONNX does not like that." + ] + }, { - "data": { - "text/html": [ - "
populating...
" + "cell_type": "code", + "execution_count": 14, + "id": "dd4fc711", + "metadata": {}, + "outputs": [], + "source": [ + "def fft2d(mat, fft_length):\n", + " mat = mat[:fft_length[0], :fft_length[1]]\n", + " res = mat.copy()\n", + " \n", + " # first FFT\n", + " res = dft_real(res, fft_length=fft_length[1], transpose=True)\n", + " \n", + " # second FFT decomposed on FFT on real part and imaginary part\n", + " res2_real = dft_real(res[0], fft_length=fft_length[0], transpose=False)\n", + " res2_imag = dft_real(res[1], fft_length=fft_length[0], transpose=False) \n", + " res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n", + " res = res2_real + res2_imag2\n", + " size = fft_length[1]//2 + 1\n", + " return res[:, :fft_length[0], :size]\n", + "\n", + "\n", + "fft2d_np = numpy.fft.rfft2(rnd)\n", + "fft2d_cus = fft2d(rnd, rnd.shape)\n", + "almost_equal(fft2d_np, fft2d_cus)" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "application/javascript": [ - "/*\n", - "This is part of jsdifflib v1.0. \n", - "\n", - "Copyright 2007 - 2011 Chas Emerick . All rights reserved.\n", - "\n", - "Redistribution and use in source and binary forms, with or without modification, are\n", - "permitted provided that the following conditions are met:\n", - "\n", - " 1. Redistributions of source code must retain the above copyright notice, this list of\n", - " conditions and the following disclaimer.\n", - "\n", - " 2. Redistributions in binary form must reproduce the above copyright notice, this list\n", - " of conditions and the following disclaimer in the documentation and/or other materials\n", - " provided with the distribution.\n", - "\n", - "THIS SOFTWARE IS PROVIDED BY Chas Emerick ``AS IS'' AND ANY EXPRESS OR IMPLIED\n", - "WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND\n", - "FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL Chas Emerick OR\n", - "CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n", - "CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n", - "SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON\n", - "ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING\n", - "NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF\n", - "ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n", - "\n", - "The views and conclusions contained in the software and documentation are those of the\n", - "authors and should not be interpreted as representing official policies, either expressed\n", - "or implied, of Chas Emerick.\n", - "*/\n", - "var diffview = {\n", - "\t/**\n", - "\t * Builds and returns a visual diff view. The single parameter, `params', should contain\n", - "\t * the following values:\n", - "\t *\n", - "\t * - baseTextLines: the array of strings that was used as the base text input to SequenceMatcher\n", - "\t * - newTextLines: the array of strings that was used as the new text input to SequenceMatcher\n", - "\t * - opcodes: the array of arrays returned by SequenceMatcher.get_opcodes()\n", - "\t * - baseTextName: the title to be displayed above the base text listing in the diff view; defaults\n", - "\t *\t to \"Base Text\"\n", - "\t * - newTextName: the title to be displayed above the new text listing in the diff view; defaults\n", - "\t *\t to \"New Text\"\n", - "\t * - contextSize: the number of lines of context to show around differences; by default, all lines\n", - "\t *\t are shown\n", - "\t * - viewType: if 0, a side-by-side diff view is generated (default); if 1, an inline diff view is\n", - "\t *\t generated\n", - "\t */\n", - "\tbuildView: function (params) {\n", - "\t\tvar baseTextLines = params.baseTextLines;\n", - "\t\tvar newTextLines = params.newTextLines;\n", - "\t\tvar opcodes = 
params.opcodes;\n", - "\t\tvar baseTextName = params.baseTextName ? params.baseTextName : \"Base Text\";\n", - "\t\tvar newTextName = params.newTextName ? params.newTextName : \"New Text\";\n", - "\t\tvar contextSize = params.contextSize;\n", - "\t\tvar inline = (params.viewType == 0 || params.viewType == 1) ? params.viewType : 0;\n", - "\n", - "\t\tif (baseTextLines == null)\n", - "\t\t\tthrow \"Cannot build diff view; baseTextLines is not defined.\";\n", - "\t\tif (newTextLines == null)\n", - "\t\t\tthrow \"Cannot build diff view; newTextLines is not defined.\";\n", - "\t\tif (!opcodes)\n", - "\t\t\tthrow \"Cannot build diff view; opcodes is not defined.\";\n", - "\t\t\n", - "\t\tfunction celt (name, clazz) {\n", - "\t\t\tvar e = document.createElement(name);\n", - "\t\t\te.className = clazz;\n", - "\t\t\treturn e;\n", - "\t\t}\n", - "\t\t\n", - "\t\tfunction telt (name, text) {\n", - "\t\t\tvar e = document.createElement(name);\n", - "\t\t\te.appendChild(document.createTextNode(text));\n", - "\t\t\treturn e;\n", - "\t\t}\n", - "\t\t\n", - "\t\tfunction ctelt (name, clazz, text) {\n", - "\t\t\tvar e = document.createElement(name);\n", - "\t\t\te.className = clazz;\n", - "\t\t\te.appendChild(document.createTextNode(text));\n", - "\t\t\treturn e;\n", - "\t\t}\n", - "\t\n", - "\t\tvar tdata = document.createElement(\"thead\");\n", - "\t\tvar node = document.createElement(\"tr\");\n", - "\t\ttdata.appendChild(node);\n", - "\t\tif (inline) {\n", - "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", - "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", - "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", baseTextName + \" vs. 
\" + newTextName));\n", - "\t\t} else {\n", - "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", - "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", baseTextName));\n", - "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", - "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", newTextName));\n", - "\t\t}\n", - "\t\ttdata = [tdata];\n", - "\t\t\n", - "\t\tvar rows = [];\n", - "\t\tvar node2;\n", - "\t\t\n", - "\t\t/**\n", - "\t\t * Adds two cells to the given row; if the given row corresponds to a real\n", - "\t\t * line number (based on the line index tidx and the endpoint of the \n", - "\t\t * range in question tend), then the cells will contain the line number\n", - "\t\t * and the line of text from textLines at position tidx (with the class of\n", - "\t\t * the second cell set to the name of the change represented), and tidx + 1 will\n", - "\t\t * be returned.\t Otherwise, tidx is returned, and two empty cells are added\n", - "\t\t * to the given row.\n", - "\t\t */\n", - "\t\tfunction addCells (row, tidx, tend, textLines, change) {\n", - "\t\t\tif (tidx < tend) {\n", - "\t\t\t\trow.appendChild(telt(\"th\", (tidx + 1).toString()));\n", - "\t\t\t\trow.appendChild(ctelt(\"td\", change, textLines[tidx].replace(/\\t/g, \"\\u00a0\\u00a0\\u00a0\\u00a0\")));\n", - "\t\t\t\treturn tidx + 1;\n", - "\t\t\t} else {\n", - "\t\t\t\trow.appendChild(document.createElement(\"th\"));\n", - "\t\t\t\trow.appendChild(celt(\"td\", \"empty\"));\n", - "\t\t\t\treturn tidx;\n", - "\t\t\t}\n", - "\t\t}\n", - "\t\t\n", - "\t\tfunction addCellsInline (row, tidx, tidx2, textLines, change) {\n", - "\t\t\trow.appendChild(telt(\"th\", tidx == null ? \"\" : (tidx + 1).toString()));\n", - "\t\t\trow.appendChild(telt(\"th\", tidx2 == null ? \"\" : (tidx2 + 1).toString()));\n", - "\t\t\trow.appendChild(ctelt(\"td\", change, textLines[tidx != null ? 
tidx : tidx2].replace(/\\t/g, \"\\u00a0\\u00a0\\u00a0\\u00a0\")));\n", - "\t\t}\n", - "\t\t\n", - "\t\tfor (var idx = 0; idx < opcodes.length; idx++) {\n", - "\t\t\tvar code = opcodes[idx];\n", - "\t\t\tvar change = code[0];\n", - "\t\t\tvar b = code[1];\n", - "\t\t\tvar be = code[2];\n", - "\t\t\tvar n = code[3];\n", - "\t\t\tvar ne = code[4];\n", - "\t\t\tvar rowcnt = Math.max(be - b, ne - n);\n", - "\t\t\tvar toprows = [];\n", - "\t\t\tvar botrows = [];\n", - "\t\t\tfor (var i = 0; i < rowcnt; i++) {\n", - "\t\t\t\t// jump ahead if we've alredy provided leading context or if this is the first range\n", - "\t\t\t\tif (contextSize && opcodes.length > 1 && ((idx > 0 && i == contextSize) || (idx == 0 && i == 0)) && change==\"equal\") {\n", - "\t\t\t\t\tvar jump = rowcnt - ((idx == 0 ? 1 : 2) * contextSize);\n", - "\t\t\t\t\tif (jump > 1) {\n", - "\t\t\t\t\t\ttoprows.push(node = document.createElement(\"tr\"));\n", - "\t\t\t\t\t\t\n", - "\t\t\t\t\t\tb += jump;\n", - "\t\t\t\t\t\tn += jump;\n", - "\t\t\t\t\t\ti += jump - 1;\n", - "\t\t\t\t\t\tnode.appendChild(telt(\"th\", \"...\"));\n", - "\t\t\t\t\t\tif (!inline) node.appendChild(ctelt(\"td\", \"skip\", \"\"));\n", - "\t\t\t\t\t\tnode.appendChild(telt(\"th\", \"...\"));\n", - "\t\t\t\t\t\tnode.appendChild(ctelt(\"td\", \"skip\", \"\"));\n", - "\t\t\t\t\t\t\n", - "\t\t\t\t\t\t// skip last lines if they're all equal\n", - "\t\t\t\t\t\tif (idx + 1 == opcodes.length) {\n", - "\t\t\t\t\t\t\tbreak;\n", - "\t\t\t\t\t\t} else {\n", - "\t\t\t\t\t\t\tcontinue;\n", - "\t\t\t\t\t\t}\n", - "\t\t\t\t\t}\n", - "\t\t\t\t}\n", - "\t\t\t\t\n", - "\t\t\t\ttoprows.push(node = document.createElement(\"tr\"));\n", - "\t\t\t\tif (inline) {\n", - "\t\t\t\t\tif (change == \"insert\") {\n", - "\t\t\t\t\t\taddCellsInline(node, null, n++, newTextLines, change);\n", - "\t\t\t\t\t} else if (change == \"replace\") {\n", - "\t\t\t\t\t\tbotrows.push(node2 = document.createElement(\"tr\"));\n", - "\t\t\t\t\t\tif (b < be) addCellsInline(node, b++, 
null, baseTextLines, \"delete\");\n", - "\t\t\t\t\t\tif (n < ne) addCellsInline(node2, null, n++, newTextLines, \"insert\");\n", - "\t\t\t\t\t} else if (change == \"delete\") {\n", - "\t\t\t\t\t\taddCellsInline(node, b++, null, baseTextLines, change);\n", - "\t\t\t\t\t} else {\n", - "\t\t\t\t\t\t// equal\n", - "\t\t\t\t\t\taddCellsInline(node, b++, n++, baseTextLines, change);\n", - "\t\t\t\t\t}\n", - "\t\t\t\t} else {\n", - "\t\t\t\t\tb = addCells(node, b, be, baseTextLines, change);\n", - "\t\t\t\t\tn = addCells(node, n, ne, newTextLines, change);\n", - "\t\t\t\t}\n", - "\t\t\t}\n", - "\n", - "\t\t\tfor (var i = 0; i < toprows.length; i++) rows.push(toprows[i]);\n", - "\t\t\tfor (var i = 0; i < botrows.length; i++) rows.push(botrows[i]);\n", - "\t\t}\n", - "\t\t\n", - "\t\trows.push(node = ctelt(\"th\", \"author\", \"diff view generated by \"));\n", - "\t\tnode.setAttribute(\"colspan\", inline ? 3 : 4);\n", - "\t\tnode.appendChild(node2 = telt(\"a\", \"jsdifflib\"));\n", - "\t\tnode2.setAttribute(\"href\", \"http://github.com/cemerick/jsdifflib\");\n", - "\t\t\n", - "\t\ttdata.push(node = document.createElement(\"tbody\"));\n", - "\t\tfor (var idx in rows) rows.hasOwnProperty(idx) && node.appendChild(rows[idx]);\n", - "\t\t\n", - "\t\tnode = celt(\"table\", \"diff\" + (inline ? \" inlinediff\" : \"\"));\n", - "\t\tfor (var idx in tdata) tdata.hasOwnProperty(idx) && node.appendChild(tdata[idx]);\n", - "\t\treturn node;\n", - "\t}\n", - "};\n", - "\n", - "\n", - "/***\n", - "This is part of jsdifflib v1.0. 
\n", - "\n", - "Copyright (c) 2007, Snowtide Informatics Systems, Inc.\n", - "All rights reserved.\n", - "\n", - "Redistribution and use in source and binary forms, with or without modification,\n", - "are permitted provided that the following conditions are met:\n", - "\n", - "\t* Redistributions of source code must retain the above copyright notice, this\n", - "\t\tlist of conditions and the following disclaimer.\n", - "\t* Redistributions in binary form must reproduce the above copyright notice,\n", - "\t\tthis list of conditions and the following disclaimer in the documentation\n", - "\t\tand/or other materials provided with the distribution.\n", - "\t* Neither the name of the Snowtide Informatics Systems nor the names of its\n", - "\t\tcontributors may be used to endorse or promote products derived from this\n", - "\t\tsoftware without specific prior written permission.\n", - "\n", - "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY\n", - "EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES\n", - "OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT\n", - "SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,\n", - "INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED\n", - "TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR\n", - "BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n", - "CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN\n", - "ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH\n", - "DAMAGE.\n", - "***/\n", - "/* Author: Chas Emerick */\n", - "var __whitespace = {\" \":true, \"\\t\":true, \"\\n\":true, \"\\f\":true, \"\\r\":true};\n", - "\n", - "var difflib = {\n", - "\tdefaultJunkFunction: function (c) {\n", - "\t\treturn __whitespace.hasOwnProperty(c);\n", - "\t},\n", - "\t\n", - "\tstripLinebreaks: function (str) { return str.replace(/^[\\n\\r]*|[\\n\\r]*$/g, \"\"); },\n", - "\t\n", - "\tstringAsLines: function (str) {\n", - "\t\tvar lfpos = str.indexOf(\"\\n\");\n", - "\t\tvar crpos = str.indexOf(\"\\r\");\n", - "\t\tvar linebreak = ((lfpos > -1 && crpos > -1) || crpos < 0) ? 
\"\\n\" : \"\\r\";\n", - "\t\t\n", - "\t\tvar lines = str.split(linebreak);\n", - "\t\tfor (var i = 0; i < lines.length; i++) {\n", - "\t\t\tlines[i] = difflib.stripLinebreaks(lines[i]);\n", - "\t\t}\n", - "\t\t\n", - "\t\treturn lines;\n", - "\t},\n", - "\t\n", - "\t// iteration-based reduce implementation\n", - "\t__reduce: function (func, list, initial) {\n", - "\t\tif (initial != null) {\n", - "\t\t\tvar value = initial;\n", - "\t\t\tvar idx = 0;\n", - "\t\t} else if (list) {\n", - "\t\t\tvar value = list[0];\n", - "\t\t\tvar idx = 1;\n", - "\t\t} else {\n", - "\t\t\treturn null;\n", - "\t\t}\n", - "\t\t\n", - "\t\tfor (; idx < list.length; idx++) {\n", - "\t\t\tvalue = func(value, list[idx]);\n", - "\t\t}\n", - "\t\t\n", - "\t\treturn value;\n", - "\t},\n", - "\t\n", - "\t// comparison function for sorting lists of numeric tuples\n", - "\t__ntuplecomp: function (a, b) {\n", - "\t\tvar mlen = Math.max(a.length, b.length);\n", - "\t\tfor (var i = 0; i < mlen; i++) {\n", - "\t\t\tif (a[i] < b[i]) return -1;\n", - "\t\t\tif (a[i] > b[i]) return 1;\n", - "\t\t}\n", - "\t\t\n", - "\t\treturn a.length == b.length ? 0 : (a.length < b.length ? -1 : 1);\n", - "\t},\n", - "\t\n", - "\t__calculate_ratio: function (matches, length) {\n", - "\t\treturn length ? 2.0 * matches / length : 1.0;\n", - "\t},\n", - "\t\n", - "\t// returns a function that returns true if a key passed to the returned function\n", - "\t// is in the dict (js object) provided to this function; replaces being able to\n", - "\t// carry around dict.has_key in python...\n", - "\t__isindict: function (dict) {\n", - "\t\treturn function (key) { return dict.hasOwnProperty(key); };\n", - "\t},\n", - "\t\n", - "\t// replacement for python's dict.get function -- need easy default values\n", - "\t__dictget: function (dict, key, defaultValue) {\n", - "\t\treturn dict.hasOwnProperty(key) ? 
dict[key] : defaultValue;\n", - "\t},\t\n", - "\t\n", - "\tSequenceMatcher: function (a, b, isjunk) {\n", - "\t\tthis.set_seqs = function (a, b) {\n", - "\t\t\tthis.set_seq1(a);\n", - "\t\t\tthis.set_seq2(b);\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.set_seq1 = function (a) {\n", - "\t\t\tif (a == this.a) return;\n", - "\t\t\tthis.a = a;\n", - "\t\t\tthis.matching_blocks = this.opcodes = null;\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.set_seq2 = function (b) {\n", - "\t\t\tif (b == this.b) return;\n", - "\t\t\tthis.b = b;\n", - "\t\t\tthis.matching_blocks = this.opcodes = this.fullbcount = null;\n", - "\t\t\tthis.__chain_b();\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.__chain_b = function () {\n", - "\t\t\tvar b = this.b;\n", - "\t\t\tvar n = b.length;\n", - "\t\t\tvar b2j = this.b2j = {};\n", - "\t\t\tvar populardict = {};\n", - "\t\t\tfor (var i = 0; i < b.length; i++) {\n", - "\t\t\t\tvar elt = b[i];\n", - "\t\t\t\tif (b2j.hasOwnProperty(elt)) {\n", - "\t\t\t\t\tvar indices = b2j[elt];\n", - "\t\t\t\t\tif (n >= 200 && indices.length * 100 > n) {\n", - "\t\t\t\t\t\tpopulardict[elt] = 1;\n", - "\t\t\t\t\t\tdelete b2j[elt];\n", - "\t\t\t\t\t} else {\n", - "\t\t\t\t\t\tindices.push(i);\n", - "\t\t\t\t\t}\n", - "\t\t\t\t} else {\n", - "\t\t\t\t\tb2j[elt] = [i];\n", - "\t\t\t\t}\n", - "\t\t\t}\n", - "\t\n", - "\t\t\tfor (var elt in populardict) {\n", - "\t\t\t\tif (populardict.hasOwnProperty(elt)) {\n", - "\t\t\t\t\tdelete b2j[elt];\n", - "\t\t\t\t}\n", - "\t\t\t}\n", - "\t\t\t\n", - "\t\t\tvar isjunk = this.isjunk;\n", - "\t\t\tvar junkdict = {};\n", - "\t\t\tif (isjunk) {\n", - "\t\t\t\tfor (var elt in populardict) {\n", - "\t\t\t\t\tif (populardict.hasOwnProperty(elt) && isjunk(elt)) {\n", - "\t\t\t\t\t\tjunkdict[elt] = 1;\n", - "\t\t\t\t\t\tdelete populardict[elt];\n", - "\t\t\t\t\t}\n", - "\t\t\t\t}\n", - "\t\t\t\tfor (var elt in b2j) {\n", - "\t\t\t\t\tif (b2j.hasOwnProperty(elt) && isjunk(elt)) {\n", - "\t\t\t\t\t\tjunkdict[elt] = 1;\n", - "\t\t\t\t\t\tdelete 
b2j[elt];\n", - "\t\t\t\t\t}\n", - "\t\t\t\t}\n", - "\t\t\t}\n", - "\t\n", - "\t\t\tthis.isbjunk = difflib.__isindict(junkdict);\n", - "\t\t\tthis.isbpopular = difflib.__isindict(populardict);\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.find_longest_match = function (alo, ahi, blo, bhi) {\n", - "\t\t\tvar a = this.a;\n", - "\t\t\tvar b = this.b;\n", - "\t\t\tvar b2j = this.b2j;\n", - "\t\t\tvar isbjunk = this.isbjunk;\n", - "\t\t\tvar besti = alo;\n", - "\t\t\tvar bestj = blo;\n", - "\t\t\tvar bestsize = 0;\n", - "\t\t\tvar j = null;\n", - "\t\t\tvar k;\n", - "\t\n", - "\t\t\tvar j2len = {};\n", - "\t\t\tvar nothing = [];\n", - "\t\t\tfor (var i = alo; i < ahi; i++) {\n", - "\t\t\t\tvar newj2len = {};\n", - "\t\t\t\tvar jdict = difflib.__dictget(b2j, a[i], nothing);\n", - "\t\t\t\tfor (var jkey in jdict) {\n", - "\t\t\t\t\tif (jdict.hasOwnProperty(jkey)) {\n", - "\t\t\t\t\t\tj = jdict[jkey];\n", - "\t\t\t\t\t\tif (j < blo) continue;\n", - "\t\t\t\t\t\tif (j >= bhi) break;\n", - "\t\t\t\t\t\tnewj2len[j] = k = difflib.__dictget(j2len, j - 1, 0) + 1;\n", - "\t\t\t\t\t\tif (k > bestsize) {\n", - "\t\t\t\t\t\t\tbesti = i - k + 1;\n", - "\t\t\t\t\t\t\tbestj = j - k + 1;\n", - "\t\t\t\t\t\t\tbestsize = k;\n", - "\t\t\t\t\t\t}\n", - "\t\t\t\t\t}\n", - "\t\t\t\t}\n", - "\t\t\t\tj2len = newj2len;\n", - "\t\t\t}\n", - "\t\n", - "\t\t\twhile (besti > alo && bestj > blo && !isbjunk(b[bestj - 1]) && a[besti - 1] == b[bestj - 1]) {\n", - "\t\t\t\tbesti--;\n", - "\t\t\t\tbestj--;\n", - "\t\t\t\tbestsize++;\n", - "\t\t\t}\n", - "\t\t\t\t\n", - "\t\t\twhile (besti + bestsize < ahi && bestj + bestsize < bhi &&\n", - "\t\t\t\t\t!isbjunk(b[bestj + bestsize]) &&\n", - "\t\t\t\t\ta[besti + bestsize] == b[bestj + bestsize]) {\n", - "\t\t\t\tbestsize++;\n", - "\t\t\t}\n", - "\t\n", - "\t\t\twhile (besti > alo && bestj > blo && isbjunk(b[bestj - 1]) && a[besti - 1] == b[bestj - 1]) {\n", - "\t\t\t\tbesti--;\n", - "\t\t\t\tbestj--;\n", - "\t\t\t\tbestsize++;\n", - "\t\t\t}\n", - "\t\t\t\n", - 
"\t\t\twhile (besti + bestsize < ahi && bestj + bestsize < bhi && isbjunk(b[bestj + bestsize]) &&\n", - "\t\t\t\t\ta[besti + bestsize] == b[bestj + bestsize]) {\n", - "\t\t\t\tbestsize++;\n", - "\t\t\t}\n", - "\t\n", - "\t\t\treturn [besti, bestj, bestsize];\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.get_matching_blocks = function () {\n", - "\t\t\tif (this.matching_blocks != null) return this.matching_blocks;\n", - "\t\t\tvar la = this.a.length;\n", - "\t\t\tvar lb = this.b.length;\n", - "\t\n", - "\t\t\tvar queue = [[0, la, 0, lb]];\n", - "\t\t\tvar matching_blocks = [];\n", - "\t\t\tvar alo, ahi, blo, bhi, qi, i, j, k, x;\n", - "\t\t\twhile (queue.length) {\n", - "\t\t\t\tqi = queue.pop();\n", - "\t\t\t\talo = qi[0];\n", - "\t\t\t\tahi = qi[1];\n", - "\t\t\t\tblo = qi[2];\n", - "\t\t\t\tbhi = qi[3];\n", - "\t\t\t\tx = this.find_longest_match(alo, ahi, blo, bhi);\n", - "\t\t\t\ti = x[0];\n", - "\t\t\t\tj = x[1];\n", - "\t\t\t\tk = x[2];\n", - "\t\n", - "\t\t\t\tif (k) {\n", - "\t\t\t\t\tmatching_blocks.push(x);\n", - "\t\t\t\t\tif (alo < i && blo < j)\n", - "\t\t\t\t\t\tqueue.push([alo, i, blo, j]);\n", - "\t\t\t\t\tif (i+k < ahi && j+k < bhi)\n", - "\t\t\t\t\t\tqueue.push([i + k, ahi, j + k, bhi]);\n", - "\t\t\t\t}\n", - "\t\t\t}\n", - "\t\t\t\n", - "\t\t\tmatching_blocks.sort(difflib.__ntuplecomp);\n", - "\t\n", - "\t\t\tvar i1 = 0, j1 = 0, k1 = 0, block = 0;\n", - "\t\t\tvar i2, j2, k2;\n", - "\t\t\tvar non_adjacent = [];\n", - "\t\t\tfor (var idx in matching_blocks) {\n", - "\t\t\t\tif (matching_blocks.hasOwnProperty(idx)) {\n", - "\t\t\t\t\tblock = matching_blocks[idx];\n", - "\t\t\t\t\ti2 = block[0];\n", - "\t\t\t\t\tj2 = block[1];\n", - "\t\t\t\t\tk2 = block[2];\n", - "\t\t\t\t\tif (i1 + k1 == i2 && j1 + k1 == j2) {\n", - "\t\t\t\t\t\tk1 += k2;\n", - "\t\t\t\t\t} else {\n", - "\t\t\t\t\t\tif (k1) non_adjacent.push([i1, j1, k1]);\n", - "\t\t\t\t\t\ti1 = i2;\n", - "\t\t\t\t\t\tj1 = j2;\n", - "\t\t\t\t\t\tk1 = k2;\n", - "\t\t\t\t\t}\n", - "\t\t\t\t}\n", - 
"\t\t\t}\n", - "\t\t\t\n", - "\t\t\tif (k1) non_adjacent.push([i1, j1, k1]);\n", - "\t\n", - "\t\t\tnon_adjacent.push([la, lb, 0]);\n", - "\t\t\tthis.matching_blocks = non_adjacent;\n", - "\t\t\treturn this.matching_blocks;\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.get_opcodes = function () {\n", - "\t\t\tif (this.opcodes != null) return this.opcodes;\n", - "\t\t\tvar i = 0;\n", - "\t\t\tvar j = 0;\n", - "\t\t\tvar answer = [];\n", - "\t\t\tthis.opcodes = answer;\n", - "\t\t\tvar block, ai, bj, size, tag;\n", - "\t\t\tvar blocks = this.get_matching_blocks();\n", - "\t\t\tfor (var idx in blocks) {\n", - "\t\t\t\tif (blocks.hasOwnProperty(idx)) {\n", - "\t\t\t\t\tblock = blocks[idx];\n", - "\t\t\t\t\tai = block[0];\n", - "\t\t\t\t\tbj = block[1];\n", - "\t\t\t\t\tsize = block[2];\n", - "\t\t\t\t\ttag = '';\n", - "\t\t\t\t\tif (i < ai && j < bj) {\n", - "\t\t\t\t\t\ttag = 'replace';\n", - "\t\t\t\t\t} else if (i < ai) {\n", - "\t\t\t\t\t\ttag = 'delete';\n", - "\t\t\t\t\t} else if (j < bj) {\n", - "\t\t\t\t\t\ttag = 'insert';\n", - "\t\t\t\t\t}\n", - "\t\t\t\t\tif (tag) answer.push([tag, i, ai, j, bj]);\n", - "\t\t\t\t\ti = ai + size;\n", - "\t\t\t\t\tj = bj + size;\n", - "\t\t\t\t\t\n", - "\t\t\t\t\tif (size) answer.push(['equal', ai, i, bj, j]);\n", - "\t\t\t\t}\n", - "\t\t\t}\n", - "\t\t\t\n", - "\t\t\treturn answer;\n", - "\t\t}\n", - "\t\t\n", - "\t\t// this is a generator function in the python lib, which of course is not supported in javascript\n", - "\t\t// the reimplementation builds up the grouped opcodes into a list in their entirety and returns that.\n", - "\t\tthis.get_grouped_opcodes = function (n) {\n", - "\t\t\tif (!n) n = 3;\n", - "\t\t\tvar codes = this.get_opcodes();\n", - "\t\t\tif (!codes) codes = [[\"equal\", 0, 1, 0, 1]];\n", - "\t\t\tvar code, tag, i1, i2, j1, j2;\n", - "\t\t\tif (codes[0][0] == 'equal') {\n", - "\t\t\t\tcode = codes[0];\n", - "\t\t\t\ttag = code[0];\n", - "\t\t\t\ti1 = code[1];\n", - "\t\t\t\ti2 = code[2];\n", - "\t\t\t\tj1 = 
code[3];\n", - "\t\t\t\tj2 = code[4];\n", - "\t\t\t\tcodes[0] = [tag, Math.max(i1, i2 - n), i2, Math.max(j1, j2 - n), j2];\n", - "\t\t\t}\n", - "\t\t\tif (codes[codes.length - 1][0] == 'equal') {\n", - "\t\t\t\tcode = codes[codes.length - 1];\n", - "\t\t\t\ttag = code[0];\n", - "\t\t\t\ti1 = code[1];\n", - "\t\t\t\ti2 = code[2];\n", - "\t\t\t\tj1 = code[3];\n", - "\t\t\t\tj2 = code[4];\n", - "\t\t\t\tcodes[codes.length - 1] = [tag, i1, Math.min(i2, i1 + n), j1, Math.min(j2, j1 + n)];\n", - "\t\t\t}\n", - "\t\n", - "\t\t\tvar nn = n + n;\n", - "\t\t\tvar group = [];\n", - "\t\t\tvar groups = [];\n", - "\t\t\tfor (var idx in codes) {\n", - "\t\t\t\tif (codes.hasOwnProperty(idx)) {\n", - "\t\t\t\t\tcode = codes[idx];\n", - "\t\t\t\t\ttag = code[0];\n", - "\t\t\t\t\ti1 = code[1];\n", - "\t\t\t\t\ti2 = code[2];\n", - "\t\t\t\t\tj1 = code[3];\n", - "\t\t\t\t\tj2 = code[4];\n", - "\t\t\t\t\tif (tag == 'equal' && i2 - i1 > nn) {\n", - "\t\t\t\t\t\tgroup.push([tag, i1, Math.min(i2, i1 + n), j1, Math.min(j2, j1 + n)]);\n", - "\t\t\t\t\t\tgroups.push(group);\n", - "\t\t\t\t\t\tgroup = [];\n", - "\t\t\t\t\t\ti1 = Math.max(i1, i2-n);\n", - "\t\t\t\t\t\tj1 = Math.max(j1, j2-n);\n", - "\t\t\t\t\t}\n", - "\t\t\t\t\t\n", - "\t\t\t\t\tgroup.push([tag, i1, i2, j1, j2]);\n", - "\t\t\t\t}\n", - "\t\t\t}\n", - "\t\t\t\n", - "\t\t\tif (group && !(group.length == 1 && group[0][0] == 'equal')) groups.push(group)\n", - "\t\t\t\n", - "\t\t\treturn groups;\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.ratio = function () {\n", - "\t\t\tmatches = difflib.__reduce(\n", - "\t\t\t\t\t\t\tfunction (sum, triple) { return sum + triple[triple.length - 1]; },\n", - "\t\t\t\t\t\t\tthis.get_matching_blocks(), 0);\n", - "\t\t\treturn difflib.__calculate_ratio(matches, this.a.length + this.b.length);\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.quick_ratio = function () {\n", - "\t\t\tvar fullbcount, elt;\n", - "\t\t\tif (this.fullbcount == null) {\n", - "\t\t\t\tthis.fullbcount = fullbcount = {};\n", - 
"\t\t\t\tfor (var i = 0; i < this.b.length; i++) {\n", - "\t\t\t\t\telt = this.b[i];\n", - "\t\t\t\t\tfullbcount[elt] = difflib.__dictget(fullbcount, elt, 0) + 1;\n", - "\t\t\t\t}\n", - "\t\t\t}\n", - "\t\t\tfullbcount = this.fullbcount;\n", - "\t\n", - "\t\t\tvar avail = {};\n", - "\t\t\tvar availhas = difflib.__isindict(avail);\n", - "\t\t\tvar matches = numb = 0;\n", - "\t\t\tfor (var i = 0; i < this.a.length; i++) {\n", - "\t\t\t\telt = this.a[i];\n", - "\t\t\t\tif (availhas(elt)) {\n", - "\t\t\t\t\tnumb = avail[elt];\n", - "\t\t\t\t} else {\n", - "\t\t\t\t\tnumb = difflib.__dictget(fullbcount, elt, 0);\n", - "\t\t\t\t}\n", - "\t\t\t\tavail[elt] = numb - 1;\n", - "\t\t\t\tif (numb > 0) matches++;\n", - "\t\t\t}\n", - "\t\t\t\n", - "\t\t\treturn difflib.__calculate_ratio(matches, this.a.length + this.b.length);\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.real_quick_ratio = function () {\n", - "\t\t\tvar la = this.a.length;\n", - "\t\t\tvar lb = this.b.length;\n", - "\t\t\treturn _calculate_ratio(Math.min(la, lb), la + lb);\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.isjunk = isjunk ? 
isjunk : difflib.defaultJunkFunction;\n", - "\t\tthis.a = this.b = null;\n", - "\t\tthis.set_seqs(a, b);\n", - "\t}\n", - "};\n", - "\n", - "\n", - "\n", - "function diffUsingJS (viewType, contextSize, baseText, newText) {\n", - "\n", - " var byId = function (id) { return document.getElementById(id); },\n", - " base = difflib.stringAsLines(baseText),\n", - " newtxt = difflib.stringAsLines(newText),\n", - " sm = new difflib.SequenceMatcher(base, newtxt),\n", - " opcodes = sm.get_opcodes(),\n", - " diffoutputdiv = byId(\"diffid_2021-08-05_16_46_43_018480\");\n", - "\n", - " diffoutputdiv.innerHTML = \"\";\n", - " contextSize = contextSize || null;\n", - "\n", - " diffoutputdiv.appendChild(diffview.buildView({\n", - " baseTextLines: base,\n", - " newTextLines: newtxt,\n", - " opcodes: opcodes,\n", - " baseTextName: \"Base Text\",\n", - " newTextName: \"New Text\",\n", - " contextSize: contextSize,\n", - " viewType: viewType\n", - " }));\n", - "}\n", - "var tview=0;\n", - "var csize='';\n", - "var bt = 'def dft_real(x, fft_length=None, transpose=True):\\n if len(x.shape) == 1:\\n x = x.reshape((1, -1))\\n N = 1\\n else:\\n N = x.shape[0] \\n C = x.shape[-1] if transpose else x.shape[-2]\\n if fft_length is None:\\n fft_length = x.shape[-1]\\n size = fft_length // 2 + 1\\n\\n cst = dft_real_cst(C, fft_length)\\n if transpose:\\n x = numpy.transpose(x, (1, 0))\\n a = cst[:, :, :fft_length]\\n b = x[:fft_length]\\n res = numpy.matmul(a, b)\\n res = res[:, :size, :]\\n return numpy.transpose(res, (0, 2, 1))\\n else:\\n a = cst[:, :, :fft_length]\\n b = x[:fft_length]\\n return numpy.matmul(a, b)\\n';\n", - "var nt = 'def dft_real_d3(x, fft_length=None, transpose=True):\\n if len(x.shape) != 3:\\n raise RuntimeError(\"Not implemented for shape=%r.\" % x.shape)\\n N = x.shape[1]\\n C = x.shape[-1] if transpose else x.shape[-2]\\n if fft_length is None:\\n fft_length = x.shape[-1]\\n size = fft_length // 2 + 1\\n\\n cst = dft_real_cst(C, fft_length)\\n if transpose:\\n x = 
numpy.transpose(x, (0, 2, 1))\\n a = cst[:, :, :fft_length]\\n b = x[:, :fft_length, :]\\n a = numpy.expand_dims(a, 0)\\n b = numpy.expand_dims(b, 1)\\n res = numpy.matmul(a, b)\\n res = res[:, :, :size, :]\\n return numpy.transpose(res, (1, 0, 3, 2))\\n else:\\n a = cst[:, :, :fft_length]\\n b = x[:, :fft_length, :]\\n a = numpy.expand_dims(a, 0)\\n b = numpy.expand_dims(b, 1)\\n res = numpy.matmul(a, b)\\n return numpy.transpose(res, (1, 0, 2, 3))\\n';\n", - "diffUsingJS(tview, csize, bt, nt) ;\n" + "cell_type": "code", + "execution_count": 15, + "id": "bb8667e6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 6.73214482 +0.j , 3.36817961 +2.73163668j,\n", + " 4.42688318-10.23641065j, -3.27162717 -0.54910943j],\n", + " [ 0.05934113 -2.02790296j, -4.71694176 -1.80039444j,\n", + " -0.16187544 -1.27214887j, -4.76195404 -8.23146595j],\n", + " [-2.43886644 +0.67253454j, 0.62177822 +1.71628605j,\n", + " -4.22144547 -0.24384973j, -1.96253444 -2.26942153j],\n", + " [-2.43886644 -0.67253454j, -4.94210355 +1.65439295j,\n", + " -6.75624015 -2.50966739j, -1.62599543 +7.41506091j],\n", + " [ 0.05934113 +2.02790296j, 1.56068457 -4.5734695j ,\n", + " -2.9809962 +2.90470743j, 4.42498542-10.45411745j]])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } ], - "text/plain": [ - "" + "source": [ + "fft2d_np" ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import inspect\n", - "text1 = inspect.getsource(dft_real)\n", - "text2 = inspect.getsource(dft_real_d3)\n", - "%textdiff text1 text2" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "cd7e14d4", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/html": [ - "
populating...
" + "cell_type": "code", + "execution_count": 16, + "id": "56a94d97", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[ 6.73214482, 3.36817961, 4.42688318, -3.27162717],\n", + " [ 0.05934113, -4.71694176, -0.16187544, -4.76195404],\n", + " [ -2.43886644, 0.62177822, -4.22144547, -1.96253444],\n", + " [ -2.43886644, -4.94210355, -6.75624015, -1.62599543],\n", + " [ 0.05934113, 1.56068457, -2.9809962 , 4.42498542]],\n", + "\n", + " [[ 0. , 2.73163668, -10.23641065, -0.54910943],\n", + " [ -2.02790296, -1.80039444, -1.27214887, -8.23146595],\n", + " [ 0.67253454, 1.71628605, -0.24384973, -2.26942153],\n", + " [ -0.67253454, 1.65439295, -2.50966739, 7.41506091],\n", + " [ 2.02790296, -4.5734695 , 2.90470743, -10.45411745]]])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fft2d_cus" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "application/javascript": [ - "/*\n", - "This is part of jsdifflib v1.0. \n", - "\n", - "Copyright 2007 - 2011 Chas Emerick . All rights reserved.\n", - "\n", - "Redistribution and use in source and binary forms, with or without modification, are\n", - "permitted provided that the following conditions are met:\n", - "\n", - " 1. Redistributions of source code must retain the above copyright notice, this list of\n", - " conditions and the following disclaimer.\n", - "\n", - " 2. Redistributions in binary form must reproduce the above copyright notice, this list\n", - " of conditions and the following disclaimer in the documentation and/or other materials\n", - " provided with the distribution.\n", - "\n", - "THIS SOFTWARE IS PROVIDED BY Chas Emerick ``AS IS'' AND ANY EXPRESS OR IMPLIED\n", - "WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND\n", - "FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL Chas Emerick OR\n", - "CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n", - "CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n", - "SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON\n", - "ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING\n", - "NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF\n", - "ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n", - "\n", - "The views and conclusions contained in the software and documentation are those of the\n", - "authors and should not be interpreted as representing official policies, either expressed\n", - "or implied, of Chas Emerick.\n", - "*/\n", - "var diffview = {\n", - "\t/**\n", - "\t * Builds and returns a visual diff view. The single parameter, `params', should contain\n", - "\t * the following values:\n", - "\t *\n", - "\t * - baseTextLines: the array of strings that was used as the base text input to SequenceMatcher\n", - "\t * - newTextLines: the array of strings that was used as the new text input to SequenceMatcher\n", - "\t * - opcodes: the array of arrays returned by SequenceMatcher.get_opcodes()\n", - "\t * - baseTextName: the title to be displayed above the base text listing in the diff view; defaults\n", - "\t *\t to \"Base Text\"\n", - "\t * - newTextName: the title to be displayed above the new text listing in the diff view; defaults\n", - "\t *\t to \"New Text\"\n", - "\t * - contextSize: the number of lines of context to show around differences; by default, all lines\n", - "\t *\t are shown\n", - "\t * - viewType: if 0, a side-by-side diff view is generated (default); if 1, an inline diff view is\n", - "\t *\t generated\n", - "\t */\n", - "\tbuildView: function (params) {\n", - "\t\tvar baseTextLines = params.baseTextLines;\n", - "\t\tvar newTextLines = params.newTextLines;\n", - "\t\tvar opcodes = 
params.opcodes;\n", - "\t\tvar baseTextName = params.baseTextName ? params.baseTextName : \"Base Text\";\n", - "\t\tvar newTextName = params.newTextName ? params.newTextName : \"New Text\";\n", - "\t\tvar contextSize = params.contextSize;\n", - "\t\tvar inline = (params.viewType == 0 || params.viewType == 1) ? params.viewType : 0;\n", - "\n", - "\t\tif (baseTextLines == null)\n", - "\t\t\tthrow \"Cannot build diff view; baseTextLines is not defined.\";\n", - "\t\tif (newTextLines == null)\n", - "\t\t\tthrow \"Cannot build diff view; newTextLines is not defined.\";\n", - "\t\tif (!opcodes)\n", - "\t\t\tthrow \"Cannot build diff view; opcodes is not defined.\";\n", - "\t\t\n", - "\t\tfunction celt (name, clazz) {\n", - "\t\t\tvar e = document.createElement(name);\n", - "\t\t\te.className = clazz;\n", - "\t\t\treturn e;\n", - "\t\t}\n", - "\t\t\n", - "\t\tfunction telt (name, text) {\n", - "\t\t\tvar e = document.createElement(name);\n", - "\t\t\te.appendChild(document.createTextNode(text));\n", - "\t\t\treturn e;\n", - "\t\t}\n", - "\t\t\n", - "\t\tfunction ctelt (name, clazz, text) {\n", - "\t\t\tvar e = document.createElement(name);\n", - "\t\t\te.className = clazz;\n", - "\t\t\te.appendChild(document.createTextNode(text));\n", - "\t\t\treturn e;\n", - "\t\t}\n", - "\t\n", - "\t\tvar tdata = document.createElement(\"thead\");\n", - "\t\tvar node = document.createElement(\"tr\");\n", - "\t\ttdata.appendChild(node);\n", - "\t\tif (inline) {\n", - "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", - "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", - "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", baseTextName + \" vs. 
\" + newTextName));\n", - "\t\t} else {\n", - "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", - "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", baseTextName));\n", - "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", - "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", newTextName));\n", - "\t\t}\n", - "\t\ttdata = [tdata];\n", - "\t\t\n", - "\t\tvar rows = [];\n", - "\t\tvar node2;\n", - "\t\t\n", - "\t\t/**\n", - "\t\t * Adds two cells to the given row; if the given row corresponds to a real\n", - "\t\t * line number (based on the line index tidx and the endpoint of the \n", - "\t\t * range in question tend), then the cells will contain the line number\n", - "\t\t * and the line of text from textLines at position tidx (with the class of\n", - "\t\t * the second cell set to the name of the change represented), and tidx + 1 will\n", - "\t\t * be returned.\t Otherwise, tidx is returned, and two empty cells are added\n", - "\t\t * to the given row.\n", - "\t\t */\n", - "\t\tfunction addCells (row, tidx, tend, textLines, change) {\n", - "\t\t\tif (tidx < tend) {\n", - "\t\t\t\trow.appendChild(telt(\"th\", (tidx + 1).toString()));\n", - "\t\t\t\trow.appendChild(ctelt(\"td\", change, textLines[tidx].replace(/\\t/g, \"\\u00a0\\u00a0\\u00a0\\u00a0\")));\n", - "\t\t\t\treturn tidx + 1;\n", - "\t\t\t} else {\n", - "\t\t\t\trow.appendChild(document.createElement(\"th\"));\n", - "\t\t\t\trow.appendChild(celt(\"td\", \"empty\"));\n", - "\t\t\t\treturn tidx;\n", - "\t\t\t}\n", - "\t\t}\n", - "\t\t\n", - "\t\tfunction addCellsInline (row, tidx, tidx2, textLines, change) {\n", - "\t\t\trow.appendChild(telt(\"th\", tidx == null ? \"\" : (tidx + 1).toString()));\n", - "\t\t\trow.appendChild(telt(\"th\", tidx2 == null ? \"\" : (tidx2 + 1).toString()));\n", - "\t\t\trow.appendChild(ctelt(\"td\", change, textLines[tidx != null ? 
tidx : tidx2].replace(/\\t/g, \"\\u00a0\\u00a0\\u00a0\\u00a0\")));\n", - "\t\t}\n", - "\t\t\n", - "\t\tfor (var idx = 0; idx < opcodes.length; idx++) {\n", - "\t\t\tvar code = opcodes[idx];\n", - "\t\t\tvar change = code[0];\n", - "\t\t\tvar b = code[1];\n", - "\t\t\tvar be = code[2];\n", - "\t\t\tvar n = code[3];\n", - "\t\t\tvar ne = code[4];\n", - "\t\t\tvar rowcnt = Math.max(be - b, ne - n);\n", - "\t\t\tvar toprows = [];\n", - "\t\t\tvar botrows = [];\n", - "\t\t\tfor (var i = 0; i < rowcnt; i++) {\n", - "\t\t\t\t// jump ahead if we've alredy provided leading context or if this is the first range\n", - "\t\t\t\tif (contextSize && opcodes.length > 1 && ((idx > 0 && i == contextSize) || (idx == 0 && i == 0)) && change==\"equal\") {\n", - "\t\t\t\t\tvar jump = rowcnt - ((idx == 0 ? 1 : 2) * contextSize);\n", - "\t\t\t\t\tif (jump > 1) {\n", - "\t\t\t\t\t\ttoprows.push(node = document.createElement(\"tr\"));\n", - "\t\t\t\t\t\t\n", - "\t\t\t\t\t\tb += jump;\n", - "\t\t\t\t\t\tn += jump;\n", - "\t\t\t\t\t\ti += jump - 1;\n", - "\t\t\t\t\t\tnode.appendChild(telt(\"th\", \"...\"));\n", - "\t\t\t\t\t\tif (!inline) node.appendChild(ctelt(\"td\", \"skip\", \"\"));\n", - "\t\t\t\t\t\tnode.appendChild(telt(\"th\", \"...\"));\n", - "\t\t\t\t\t\tnode.appendChild(ctelt(\"td\", \"skip\", \"\"));\n", - "\t\t\t\t\t\t\n", - "\t\t\t\t\t\t// skip last lines if they're all equal\n", - "\t\t\t\t\t\tif (idx + 1 == opcodes.length) {\n", - "\t\t\t\t\t\t\tbreak;\n", - "\t\t\t\t\t\t} else {\n", - "\t\t\t\t\t\t\tcontinue;\n", - "\t\t\t\t\t\t}\n", - "\t\t\t\t\t}\n", - "\t\t\t\t}\n", - "\t\t\t\t\n", - "\t\t\t\ttoprows.push(node = document.createElement(\"tr\"));\n", - "\t\t\t\tif (inline) {\n", - "\t\t\t\t\tif (change == \"insert\") {\n", - "\t\t\t\t\t\taddCellsInline(node, null, n++, newTextLines, change);\n", - "\t\t\t\t\t} else if (change == \"replace\") {\n", - "\t\t\t\t\t\tbotrows.push(node2 = document.createElement(\"tr\"));\n", - "\t\t\t\t\t\tif (b < be) addCellsInline(node, b++, 
null, baseTextLines, \"delete\");\n", - "\t\t\t\t\t\tif (n < ne) addCellsInline(node2, null, n++, newTextLines, \"insert\");\n", - "\t\t\t\t\t} else if (change == \"delete\") {\n", - "\t\t\t\t\t\taddCellsInline(node, b++, null, baseTextLines, change);\n", - "\t\t\t\t\t} else {\n", - "\t\t\t\t\t\t// equal\n", - "\t\t\t\t\t\taddCellsInline(node, b++, n++, baseTextLines, change);\n", - "\t\t\t\t\t}\n", - "\t\t\t\t} else {\n", - "\t\t\t\t\tb = addCells(node, b, be, baseTextLines, change);\n", - "\t\t\t\t\tn = addCells(node, n, ne, newTextLines, change);\n", - "\t\t\t\t}\n", - "\t\t\t}\n", - "\n", - "\t\t\tfor (var i = 0; i < toprows.length; i++) rows.push(toprows[i]);\n", - "\t\t\tfor (var i = 0; i < botrows.length; i++) rows.push(botrows[i]);\n", - "\t\t}\n", - "\t\t\n", - "\t\trows.push(node = ctelt(\"th\", \"author\", \"diff view generated by \"));\n", - "\t\tnode.setAttribute(\"colspan\", inline ? 3 : 4);\n", - "\t\tnode.appendChild(node2 = telt(\"a\", \"jsdifflib\"));\n", - "\t\tnode2.setAttribute(\"href\", \"http://github.com/cemerick/jsdifflib\");\n", - "\t\t\n", - "\t\ttdata.push(node = document.createElement(\"tbody\"));\n", - "\t\tfor (var idx in rows) rows.hasOwnProperty(idx) && node.appendChild(rows[idx]);\n", - "\t\t\n", - "\t\tnode = celt(\"table\", \"diff\" + (inline ? \" inlinediff\" : \"\"));\n", - "\t\tfor (var idx in tdata) tdata.hasOwnProperty(idx) && node.appendChild(tdata[idx]);\n", - "\t\treturn node;\n", - "\t}\n", - "};\n", - "\n", - "\n", - "/***\n", - "This is part of jsdifflib v1.0. 
\n", - "\n", - "Copyright (c) 2007, Snowtide Informatics Systems, Inc.\n", - "All rights reserved.\n", - "\n", - "Redistribution and use in source and binary forms, with or without modification,\n", - "are permitted provided that the following conditions are met:\n", - "\n", - "\t* Redistributions of source code must retain the above copyright notice, this\n", - "\t\tlist of conditions and the following disclaimer.\n", - "\t* Redistributions in binary form must reproduce the above copyright notice,\n", - "\t\tthis list of conditions and the following disclaimer in the documentation\n", - "\t\tand/or other materials provided with the distribution.\n", - "\t* Neither the name of the Snowtide Informatics Systems nor the names of its\n", - "\t\tcontributors may be used to endorse or promote products derived from this\n", - "\t\tsoftware without specific prior written permission.\n", - "\n", - "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY\n", - "EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES\n", - "OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT\n", - "SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,\n", - "INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED\n", - "TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR\n", - "BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n", - "CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN\n", - "ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH\n", - "DAMAGE.\n", - "***/\n", - "/* Author: Chas Emerick */\n", - "var __whitespace = {\" \":true, \"\\t\":true, \"\\n\":true, \"\\f\":true, \"\\r\":true};\n", - "\n", - "var difflib = {\n", - "\tdefaultJunkFunction: function (c) {\n", - "\t\treturn __whitespace.hasOwnProperty(c);\n", - "\t},\n", - "\t\n", - "\tstripLinebreaks: function (str) { return str.replace(/^[\\n\\r]*|[\\n\\r]*$/g, \"\"); },\n", - "\t\n", - "\tstringAsLines: function (str) {\n", - "\t\tvar lfpos = str.indexOf(\"\\n\");\n", - "\t\tvar crpos = str.indexOf(\"\\r\");\n", - "\t\tvar linebreak = ((lfpos > -1 && crpos > -1) || crpos < 0) ? 
\"\\n\" : \"\\r\";\n", - "\t\t\n", - "\t\tvar lines = str.split(linebreak);\n", - "\t\tfor (var i = 0; i < lines.length; i++) {\n", - "\t\t\tlines[i] = difflib.stripLinebreaks(lines[i]);\n", - "\t\t}\n", - "\t\t\n", - "\t\treturn lines;\n", - "\t},\n", - "\t\n", - "\t// iteration-based reduce implementation\n", - "\t__reduce: function (func, list, initial) {\n", - "\t\tif (initial != null) {\n", - "\t\t\tvar value = initial;\n", - "\t\t\tvar idx = 0;\n", - "\t\t} else if (list) {\n", - "\t\t\tvar value = list[0];\n", - "\t\t\tvar idx = 1;\n", - "\t\t} else {\n", - "\t\t\treturn null;\n", - "\t\t}\n", - "\t\t\n", - "\t\tfor (; idx < list.length; idx++) {\n", - "\t\t\tvalue = func(value, list[idx]);\n", - "\t\t}\n", - "\t\t\n", - "\t\treturn value;\n", - "\t},\n", - "\t\n", - "\t// comparison function for sorting lists of numeric tuples\n", - "\t__ntuplecomp: function (a, b) {\n", - "\t\tvar mlen = Math.max(a.length, b.length);\n", - "\t\tfor (var i = 0; i < mlen; i++) {\n", - "\t\t\tif (a[i] < b[i]) return -1;\n", - "\t\t\tif (a[i] > b[i]) return 1;\n", - "\t\t}\n", - "\t\t\n", - "\t\treturn a.length == b.length ? 0 : (a.length < b.length ? -1 : 1);\n", - "\t},\n", - "\t\n", - "\t__calculate_ratio: function (matches, length) {\n", - "\t\treturn length ? 2.0 * matches / length : 1.0;\n", - "\t},\n", - "\t\n", - "\t// returns a function that returns true if a key passed to the returned function\n", - "\t// is in the dict (js object) provided to this function; replaces being able to\n", - "\t// carry around dict.has_key in python...\n", - "\t__isindict: function (dict) {\n", - "\t\treturn function (key) { return dict.hasOwnProperty(key); };\n", - "\t},\n", - "\t\n", - "\t// replacement for python's dict.get function -- need easy default values\n", - "\t__dictget: function (dict, key, defaultValue) {\n", - "\t\treturn dict.hasOwnProperty(key) ? 
dict[key] : defaultValue;\n", - "\t},\t\n", - "\t\n", - "\tSequenceMatcher: function (a, b, isjunk) {\n", - "\t\tthis.set_seqs = function (a, b) {\n", - "\t\t\tthis.set_seq1(a);\n", - "\t\t\tthis.set_seq2(b);\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.set_seq1 = function (a) {\n", - "\t\t\tif (a == this.a) return;\n", - "\t\t\tthis.a = a;\n", - "\t\t\tthis.matching_blocks = this.opcodes = null;\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.set_seq2 = function (b) {\n", - "\t\t\tif (b == this.b) return;\n", - "\t\t\tthis.b = b;\n", - "\t\t\tthis.matching_blocks = this.opcodes = this.fullbcount = null;\n", - "\t\t\tthis.__chain_b();\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.__chain_b = function () {\n", - "\t\t\tvar b = this.b;\n", - "\t\t\tvar n = b.length;\n", - "\t\t\tvar b2j = this.b2j = {};\n", - "\t\t\tvar populardict = {};\n", - "\t\t\tfor (var i = 0; i < b.length; i++) {\n", - "\t\t\t\tvar elt = b[i];\n", - "\t\t\t\tif (b2j.hasOwnProperty(elt)) {\n", - "\t\t\t\t\tvar indices = b2j[elt];\n", - "\t\t\t\t\tif (n >= 200 && indices.length * 100 > n) {\n", - "\t\t\t\t\t\tpopulardict[elt] = 1;\n", - "\t\t\t\t\t\tdelete b2j[elt];\n", - "\t\t\t\t\t} else {\n", - "\t\t\t\t\t\tindices.push(i);\n", - "\t\t\t\t\t}\n", - "\t\t\t\t} else {\n", - "\t\t\t\t\tb2j[elt] = [i];\n", - "\t\t\t\t}\n", - "\t\t\t}\n", - "\t\n", - "\t\t\tfor (var elt in populardict) {\n", - "\t\t\t\tif (populardict.hasOwnProperty(elt)) {\n", - "\t\t\t\t\tdelete b2j[elt];\n", - "\t\t\t\t}\n", - "\t\t\t}\n", - "\t\t\t\n", - "\t\t\tvar isjunk = this.isjunk;\n", - "\t\t\tvar junkdict = {};\n", - "\t\t\tif (isjunk) {\n", - "\t\t\t\tfor (var elt in populardict) {\n", - "\t\t\t\t\tif (populardict.hasOwnProperty(elt) && isjunk(elt)) {\n", - "\t\t\t\t\t\tjunkdict[elt] = 1;\n", - "\t\t\t\t\t\tdelete populardict[elt];\n", - "\t\t\t\t\t}\n", - "\t\t\t\t}\n", - "\t\t\t\tfor (var elt in b2j) {\n", - "\t\t\t\t\tif (b2j.hasOwnProperty(elt) && isjunk(elt)) {\n", - "\t\t\t\t\t\tjunkdict[elt] = 1;\n", - "\t\t\t\t\t\tdelete 
b2j[elt];\n", - "\t\t\t\t\t}\n", - "\t\t\t\t}\n", - "\t\t\t}\n", - "\t\n", - "\t\t\tthis.isbjunk = difflib.__isindict(junkdict);\n", - "\t\t\tthis.isbpopular = difflib.__isindict(populardict);\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.find_longest_match = function (alo, ahi, blo, bhi) {\n", - "\t\t\tvar a = this.a;\n", - "\t\t\tvar b = this.b;\n", - "\t\t\tvar b2j = this.b2j;\n", - "\t\t\tvar isbjunk = this.isbjunk;\n", - "\t\t\tvar besti = alo;\n", - "\t\t\tvar bestj = blo;\n", - "\t\t\tvar bestsize = 0;\n", - "\t\t\tvar j = null;\n", - "\t\t\tvar k;\n", - "\t\n", - "\t\t\tvar j2len = {};\n", - "\t\t\tvar nothing = [];\n", - "\t\t\tfor (var i = alo; i < ahi; i++) {\n", - "\t\t\t\tvar newj2len = {};\n", - "\t\t\t\tvar jdict = difflib.__dictget(b2j, a[i], nothing);\n", - "\t\t\t\tfor (var jkey in jdict) {\n", - "\t\t\t\t\tif (jdict.hasOwnProperty(jkey)) {\n", - "\t\t\t\t\t\tj = jdict[jkey];\n", - "\t\t\t\t\t\tif (j < blo) continue;\n", - "\t\t\t\t\t\tif (j >= bhi) break;\n", - "\t\t\t\t\t\tnewj2len[j] = k = difflib.__dictget(j2len, j - 1, 0) + 1;\n", - "\t\t\t\t\t\tif (k > bestsize) {\n", - "\t\t\t\t\t\t\tbesti = i - k + 1;\n", - "\t\t\t\t\t\t\tbestj = j - k + 1;\n", - "\t\t\t\t\t\t\tbestsize = k;\n", - "\t\t\t\t\t\t}\n", - "\t\t\t\t\t}\n", - "\t\t\t\t}\n", - "\t\t\t\tj2len = newj2len;\n", - "\t\t\t}\n", - "\t\n", - "\t\t\twhile (besti > alo && bestj > blo && !isbjunk(b[bestj - 1]) && a[besti - 1] == b[bestj - 1]) {\n", - "\t\t\t\tbesti--;\n", - "\t\t\t\tbestj--;\n", - "\t\t\t\tbestsize++;\n", - "\t\t\t}\n", - "\t\t\t\t\n", - "\t\t\twhile (besti + bestsize < ahi && bestj + bestsize < bhi &&\n", - "\t\t\t\t\t!isbjunk(b[bestj + bestsize]) &&\n", - "\t\t\t\t\ta[besti + bestsize] == b[bestj + bestsize]) {\n", - "\t\t\t\tbestsize++;\n", - "\t\t\t}\n", - "\t\n", - "\t\t\twhile (besti > alo && bestj > blo && isbjunk(b[bestj - 1]) && a[besti - 1] == b[bestj - 1]) {\n", - "\t\t\t\tbesti--;\n", - "\t\t\t\tbestj--;\n", - "\t\t\t\tbestsize++;\n", - "\t\t\t}\n", - "\t\t\t\n", - 
"\t\t\twhile (besti + bestsize < ahi && bestj + bestsize < bhi && isbjunk(b[bestj + bestsize]) &&\n", - "\t\t\t\t\ta[besti + bestsize] == b[bestj + bestsize]) {\n", - "\t\t\t\tbestsize++;\n", - "\t\t\t}\n", - "\t\n", - "\t\t\treturn [besti, bestj, bestsize];\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.get_matching_blocks = function () {\n", - "\t\t\tif (this.matching_blocks != null) return this.matching_blocks;\n", - "\t\t\tvar la = this.a.length;\n", - "\t\t\tvar lb = this.b.length;\n", - "\t\n", - "\t\t\tvar queue = [[0, la, 0, lb]];\n", - "\t\t\tvar matching_blocks = [];\n", - "\t\t\tvar alo, ahi, blo, bhi, qi, i, j, k, x;\n", - "\t\t\twhile (queue.length) {\n", - "\t\t\t\tqi = queue.pop();\n", - "\t\t\t\talo = qi[0];\n", - "\t\t\t\tahi = qi[1];\n", - "\t\t\t\tblo = qi[2];\n", - "\t\t\t\tbhi = qi[3];\n", - "\t\t\t\tx = this.find_longest_match(alo, ahi, blo, bhi);\n", - "\t\t\t\ti = x[0];\n", - "\t\t\t\tj = x[1];\n", - "\t\t\t\tk = x[2];\n", - "\t\n", - "\t\t\t\tif (k) {\n", - "\t\t\t\t\tmatching_blocks.push(x);\n", - "\t\t\t\t\tif (alo < i && blo < j)\n", - "\t\t\t\t\t\tqueue.push([alo, i, blo, j]);\n", - "\t\t\t\t\tif (i+k < ahi && j+k < bhi)\n", - "\t\t\t\t\t\tqueue.push([i + k, ahi, j + k, bhi]);\n", - "\t\t\t\t}\n", - "\t\t\t}\n", - "\t\t\t\n", - "\t\t\tmatching_blocks.sort(difflib.__ntuplecomp);\n", - "\t\n", - "\t\t\tvar i1 = 0, j1 = 0, k1 = 0, block = 0;\n", - "\t\t\tvar i2, j2, k2;\n", - "\t\t\tvar non_adjacent = [];\n", - "\t\t\tfor (var idx in matching_blocks) {\n", - "\t\t\t\tif (matching_blocks.hasOwnProperty(idx)) {\n", - "\t\t\t\t\tblock = matching_blocks[idx];\n", - "\t\t\t\t\ti2 = block[0];\n", - "\t\t\t\t\tj2 = block[1];\n", - "\t\t\t\t\tk2 = block[2];\n", - "\t\t\t\t\tif (i1 + k1 == i2 && j1 + k1 == j2) {\n", - "\t\t\t\t\t\tk1 += k2;\n", - "\t\t\t\t\t} else {\n", - "\t\t\t\t\t\tif (k1) non_adjacent.push([i1, j1, k1]);\n", - "\t\t\t\t\t\ti1 = i2;\n", - "\t\t\t\t\t\tj1 = j2;\n", - "\t\t\t\t\t\tk1 = k2;\n", - "\t\t\t\t\t}\n", - "\t\t\t\t}\n", - 
"\t\t\t}\n", - "\t\t\t\n", - "\t\t\tif (k1) non_adjacent.push([i1, j1, k1]);\n", - "\t\n", - "\t\t\tnon_adjacent.push([la, lb, 0]);\n", - "\t\t\tthis.matching_blocks = non_adjacent;\n", - "\t\t\treturn this.matching_blocks;\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.get_opcodes = function () {\n", - "\t\t\tif (this.opcodes != null) return this.opcodes;\n", - "\t\t\tvar i = 0;\n", - "\t\t\tvar j = 0;\n", - "\t\t\tvar answer = [];\n", - "\t\t\tthis.opcodes = answer;\n", - "\t\t\tvar block, ai, bj, size, tag;\n", - "\t\t\tvar blocks = this.get_matching_blocks();\n", - "\t\t\tfor (var idx in blocks) {\n", - "\t\t\t\tif (blocks.hasOwnProperty(idx)) {\n", - "\t\t\t\t\tblock = blocks[idx];\n", - "\t\t\t\t\tai = block[0];\n", - "\t\t\t\t\tbj = block[1];\n", - "\t\t\t\t\tsize = block[2];\n", - "\t\t\t\t\ttag = '';\n", - "\t\t\t\t\tif (i < ai && j < bj) {\n", - "\t\t\t\t\t\ttag = 'replace';\n", - "\t\t\t\t\t} else if (i < ai) {\n", - "\t\t\t\t\t\ttag = 'delete';\n", - "\t\t\t\t\t} else if (j < bj) {\n", - "\t\t\t\t\t\ttag = 'insert';\n", - "\t\t\t\t\t}\n", - "\t\t\t\t\tif (tag) answer.push([tag, i, ai, j, bj]);\n", - "\t\t\t\t\ti = ai + size;\n", - "\t\t\t\t\tj = bj + size;\n", - "\t\t\t\t\t\n", - "\t\t\t\t\tif (size) answer.push(['equal', ai, i, bj, j]);\n", - "\t\t\t\t}\n", - "\t\t\t}\n", - "\t\t\t\n", - "\t\t\treturn answer;\n", - "\t\t}\n", - "\t\t\n", - "\t\t// this is a generator function in the python lib, which of course is not supported in javascript\n", - "\t\t// the reimplementation builds up the grouped opcodes into a list in their entirety and returns that.\n", - "\t\tthis.get_grouped_opcodes = function (n) {\n", - "\t\t\tif (!n) n = 3;\n", - "\t\t\tvar codes = this.get_opcodes();\n", - "\t\t\tif (!codes) codes = [[\"equal\", 0, 1, 0, 1]];\n", - "\t\t\tvar code, tag, i1, i2, j1, j2;\n", - "\t\t\tif (codes[0][0] == 'equal') {\n", - "\t\t\t\tcode = codes[0];\n", - "\t\t\t\ttag = code[0];\n", - "\t\t\t\ti1 = code[1];\n", - "\t\t\t\ti2 = code[2];\n", - "\t\t\t\tj1 = 
code[3];\n", - "\t\t\t\tj2 = code[4];\n", - "\t\t\t\tcodes[0] = [tag, Math.max(i1, i2 - n), i2, Math.max(j1, j2 - n), j2];\n", - "\t\t\t}\n", - "\t\t\tif (codes[codes.length - 1][0] == 'equal') {\n", - "\t\t\t\tcode = codes[codes.length - 1];\n", - "\t\t\t\ttag = code[0];\n", - "\t\t\t\ti1 = code[1];\n", - "\t\t\t\ti2 = code[2];\n", - "\t\t\t\tj1 = code[3];\n", - "\t\t\t\tj2 = code[4];\n", - "\t\t\t\tcodes[codes.length - 1] = [tag, i1, Math.min(i2, i1 + n), j1, Math.min(j2, j1 + n)];\n", - "\t\t\t}\n", - "\t\n", - "\t\t\tvar nn = n + n;\n", - "\t\t\tvar group = [];\n", - "\t\t\tvar groups = [];\n", - "\t\t\tfor (var idx in codes) {\n", - "\t\t\t\tif (codes.hasOwnProperty(idx)) {\n", - "\t\t\t\t\tcode = codes[idx];\n", - "\t\t\t\t\ttag = code[0];\n", - "\t\t\t\t\ti1 = code[1];\n", - "\t\t\t\t\ti2 = code[2];\n", - "\t\t\t\t\tj1 = code[3];\n", - "\t\t\t\t\tj2 = code[4];\n", - "\t\t\t\t\tif (tag == 'equal' && i2 - i1 > nn) {\n", - "\t\t\t\t\t\tgroup.push([tag, i1, Math.min(i2, i1 + n), j1, Math.min(j2, j1 + n)]);\n", - "\t\t\t\t\t\tgroups.push(group);\n", - "\t\t\t\t\t\tgroup = [];\n", - "\t\t\t\t\t\ti1 = Math.max(i1, i2-n);\n", - "\t\t\t\t\t\tj1 = Math.max(j1, j2-n);\n", - "\t\t\t\t\t}\n", - "\t\t\t\t\t\n", - "\t\t\t\t\tgroup.push([tag, i1, i2, j1, j2]);\n", - "\t\t\t\t}\n", - "\t\t\t}\n", - "\t\t\t\n", - "\t\t\tif (group && !(group.length == 1 && group[0][0] == 'equal')) groups.push(group)\n", - "\t\t\t\n", - "\t\t\treturn groups;\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.ratio = function () {\n", - "\t\t\tmatches = difflib.__reduce(\n", - "\t\t\t\t\t\t\tfunction (sum, triple) { return sum + triple[triple.length - 1]; },\n", - "\t\t\t\t\t\t\tthis.get_matching_blocks(), 0);\n", - "\t\t\treturn difflib.__calculate_ratio(matches, this.a.length + this.b.length);\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.quick_ratio = function () {\n", - "\t\t\tvar fullbcount, elt;\n", - "\t\t\tif (this.fullbcount == null) {\n", - "\t\t\t\tthis.fullbcount = fullbcount = {};\n", - 
"\t\t\t\tfor (var i = 0; i < this.b.length; i++) {\n", - "\t\t\t\t\telt = this.b[i];\n", - "\t\t\t\t\tfullbcount[elt] = difflib.__dictget(fullbcount, elt, 0) + 1;\n", - "\t\t\t\t}\n", - "\t\t\t}\n", - "\t\t\tfullbcount = this.fullbcount;\n", - "\t\n", - "\t\t\tvar avail = {};\n", - "\t\t\tvar availhas = difflib.__isindict(avail);\n", - "\t\t\tvar matches = numb = 0;\n", - "\t\t\tfor (var i = 0; i < this.a.length; i++) {\n", - "\t\t\t\telt = this.a[i];\n", - "\t\t\t\tif (availhas(elt)) {\n", - "\t\t\t\t\tnumb = avail[elt];\n", - "\t\t\t\t} else {\n", - "\t\t\t\t\tnumb = difflib.__dictget(fullbcount, elt, 0);\n", - "\t\t\t\t}\n", - "\t\t\t\tavail[elt] = numb - 1;\n", - "\t\t\t\tif (numb > 0) matches++;\n", - "\t\t\t}\n", - "\t\t\t\n", - "\t\t\treturn difflib.__calculate_ratio(matches, this.a.length + this.b.length);\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.real_quick_ratio = function () {\n", - "\t\t\tvar la = this.a.length;\n", - "\t\t\tvar lb = this.b.length;\n", - "\t\t\treturn _calculate_ratio(Math.min(la, lb), la + lb);\n", - "\t\t}\n", - "\t\t\n", - "\t\tthis.isjunk = isjunk ? 
isjunk : difflib.defaultJunkFunction;\n", - "\t\tthis.a = this.b = null;\n", - "\t\tthis.set_seqs(a, b);\n", - "\t}\n", - "};\n", - "\n", - "\n", - "\n", - "function diffUsingJS (viewType, contextSize, baseText, newText) {\n", - "\n", - " var byId = function (id) { return document.getElementById(id); },\n", - " base = difflib.stringAsLines(baseText),\n", - " newtxt = difflib.stringAsLines(newText),\n", - " sm = new difflib.SequenceMatcher(base, newtxt),\n", - " opcodes = sm.get_opcodes(),\n", - " diffoutputdiv = byId(\"diffid_2021-08-05_16_46_43_079488\");\n", - "\n", - " diffoutputdiv.innerHTML = \"\";\n", - " contextSize = contextSize || null;\n", - "\n", - " diffoutputdiv.appendChild(diffview.buildView({\n", - " baseTextLines: base,\n", - " newTextLines: newtxt,\n", - " opcodes: opcodes,\n", - " baseTextName: \"Base Text\",\n", - " newTextName: \"New Text\",\n", - " contextSize: contextSize,\n", - " viewType: viewType\n", - " }));\n", - "}\n", - "var tview=0;\n", - "var csize='';\n", - "var bt = 'def fft2d(mat, fft_length):\\n mat = mat[:fft_length[0], :fft_length[1]]\\n res = mat.copy()\\n \\n # first FFT\\n res = dft_real(res, fft_length=fft_length[1], transpose=True)\\n \\n # second FFT decomposed on FFT on real part and imaginary part\\n res2_real = dft_real(res[0], fft_length=fft_length[0], transpose=False)\\n res2_imag = dft_real(res[1], fft_length=fft_length[0], transpose=False) \\n res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\\n res = res2_real + res2_imag2\\n size = fft_length[1]//2 + 1\\n return res[:, :fft_length[0], :size]\\n';\n", - "var nt = 'def fft2d_d3(mat, fft_length):\\n mat = mat[:, :fft_length[-2], :fft_length[-1]]\\n res = mat.copy()\\n \\n # first FFT\\n res = dft_real_d3(res, fft_length=fft_length[-1], transpose=True)\\n \\n # second FFT decomposed on FFT on real part and imaginary part\\n res2_real = dft_real_d3(res[0], fft_length=fft_length[-2], transpose=False)\\n res2_imag = dft_real_d3(res[1], 
fft_length=fft_length[-2], transpose=False)\\n res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\\n res = res2_real + res2_imag2\\n size = fft_length[-1]//2 + 1\\n return res[:, :, :fft_length[-2], :size]\\n';\n", - "diffUsingJS(tview, csize, bt, nt) ;\n" + "cell_type": "markdown", + "id": "faa21909", + "metadata": {}, + "source": [ + "And with a different `fft_length`." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "bf98995f", + "metadata": {}, + "outputs": [], + "source": [ + "fft2d_np = numpy.fft.rfft2(rnd, (4, 6))\n", + "fft2d_cus = fft2d(rnd, (4, 6))\n", + "almost_equal(fft2d_np[:4, :], fft2d_cus)" + ] + }, + { + "cell_type": "markdown", + "id": "caee1f84", + "metadata": {}, + "source": [ + "## FFT 2D in ONNX\n", + "\n", + "We use again the numpy API for ONNX." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "ca641274", + "metadata": {}, + "outputs": [], + "source": [ + "def onnx_rfft_1d(x, fft_length=None, transpose=True):\n", + " if fft_length is None:\n", + " raise RuntimeError(\"fft_length must be specified.\")\n", + " \n", + " size = fft_length // 2 + 1\n", + " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", + " if transpose:\n", + " xt = npnx.transpose(x, (1, 0))\n", + " res = npnx.matmul(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n", + " return npnx.transpose(res, (0, 2, 1))\n", + " else:\n", + " return npnx.matmul(cst[:, :, :fft_length], x[:fft_length])\n", + "\n", + "\n", + "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", + "def onnx_rfft_2d(x, fft_length=None):\n", + " mat = x[:fft_length[0], :fft_length[1]]\n", + " \n", + " # first FFT\n", + " res = onnx_rfft_1d(mat, fft_length=fft_length[1], transpose=True)\n", + " \n", + " # second FFT decomposed on FFT on real part and imaginary part\n", + " res2_real = onnx_rfft_1d(res[0], fft_length=fft_length[0], transpose=False)\n", + " res2_imag = onnx_rfft_1d(res[1], fft_length=fft_length[0], 
transpose=False) \n", + " res2_imag2 = npnx.vstack(-res2_imag[1:2], res2_imag[:1])\n", + " res = res2_real + res2_imag2\n", + " size = fft_length[1]//2 + 1\n", + " return res[:, :fft_length[0], :size]\n", + "\n", + "\n", + "fft2d_cus = fft2d(rnd, rnd.shape)\n", + "fft2d_onx = onnx_rfft_2d(rnd, fft_length=rnd.shape)\n", + "almost_equal(fft2d_cus, fft2d_onx)" + ] + }, + { + "cell_type": "markdown", + "id": "20fcd8a9", + "metadata": {}, + "source": [ + "The corresponding ONNX graph." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "b1379b06", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } ], - "text/plain": [ - "" + "source": [ + "key = list(onnx_rfft_2d.signed_compiled)[0]\n", + "%onnxview onnx_rfft_2d.signed_compiled[key].compiled.onnx_" ] - }, - "execution_count": 30, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "text1 = inspect.getsource(fft2d)\n", - "text2 = inspect.getsource(fft2d_d3)\n", - "%textdiff text1 text2" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "51e7a4f7", - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "def onnx_rfft_3d_1d(x, fft_length=None, transpose=True):\n", - " if fft_length is None:\n", - " raise RuntimeError(\"fft_length must be specified.\")\n", - " \n", - " size = fft_length // 2 + 1\n", - " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", - " if transpose:\n", - " xt = npnx.transpose(x, (0, 2, 1))\n", - " a = cst[:, :, :fft_length]\n", - " b = xt[:, :fft_length, :]\n", - " a = npnx.expand_dims(a, 0)\n", - " b = npnx.expand_dims(b, 1)\n", - " res = npnx.matmul(a, b)\n", - " res2 = res[:, :size, :]\n", - " return npnx.transpose(res2, (1, 0, 3, 2))\n", - " else:\n", - " a = cst[:, :, :fft_length]\n", - " b = x[:, :fft_length, :]\n", - " a = npnx.expand_dims(a, 0)\n", - " b = npnx.expand_dims(b, 1)\n", - " res = npnx.matmul(a, b)\n", - " return npnx.transpose(res, (1, 0, 2, 3)) \n", - " \n", - "\n", - "def onnx_rfft_3d_2d(x, fft_length=None):\n", - " mat = x[:, :fft_length[-2], :fft_length[-1]]\n", - " \n", - " # first FFT\n", - " res = onnx_rfft_3d_1d(mat, fft_length=fft_length[-1], transpose=True)\n", - " \n", - " # second FFT decomposed on FFT on real part and imaginary part\n", - " res2_real = onnx_rfft_3d_1d(res[0], fft_length=fft_length[0], transpose=False)\n", - " res2_imag = onnx_rfft_3d_1d(res[1], fft_length=fft_length[0], transpose=False) \n", - " res2_imag2 = 
npnx.vstack(-res2_imag[1:2], res2_imag[:1])\n", - " res = res2_real + res2_imag2\n", - " size = fft_length[1]//2 + 1\n", - " return res[:, :, :fft_length[-2], :size]\n", - "\n", - "\n", - "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", - "def onnx_rfft_2d_any(x, fft_length=None):\n", - " new_shape = npnx.concat(\n", - " numpy.array([-1], dtype=numpy.int64), x.shape[-2:], axis=0)\n", - " mat2 = x.reshape(new_shape)\n", - " f2 = onnx_rfft_3d_2d(mat2, fft_length)\n", - " new_shape = npnx.concat(\n", - " numpy.array([2], dtype=numpy.int64), x.shape[:-2], f2.shape[-2:])\n", - " return f2.reshape(new_shape)\n", - "\n", - "\n", - "shape = (3, 1, 4)\n", - "fft_length = (1, 4)\n", - "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", - "fft2d_cus = fft2d_any(rnd, fft_length)\n", - "fft2d_onx = onnx_rfft_2d_any(rnd, fft_length=fft_length)\n", - "almost_equal(fft2d_cus, fft2d_onx)" - ] - }, - { - "cell_type": "markdown", - "id": "37c45ae7", - "metadata": {}, - "source": [ - "Let's do the same comparison." 
- ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "11c1e596", - "metadata": {}, - "outputs": [ + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", - "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", - "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", - "OK x.shape=(3, 1, 4) length=(1, 2) output shape=(3, 4) or (2, 3, 1, 2)\n", - "DIS x.shape=(3, 1, 4) length=(1, 1) error=AssertionError('Mismatch max diff=1.0 > 1e-05.') output shape=(3, 4) or (2, 3, 1, 1)\n", - "OK x.shape=(5, 7) length=(5, 7) output shape=(3, 4) or (2, 5, 4)\n", - "OK x.shape=(5, 7) length=(1, 7) output shape=(3, 4) or (2, 1, 4)\n", - "OK x.shape=(5, 7) length=(2, 7) output shape=(3, 4) or (2, 2, 4)\n", - "OK x.shape=(5, 7) length=(5, 2) output shape=(3, 4) or (2, 5, 2)\n", - "OK x.shape=(5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n", - "OK x.shape=(3, 5, 7) length=(5, 7) output shape=(3, 4) or (2, 3, 5, 4)\n", - "OK x.shape=(3, 5, 7) length=(1, 7) output shape=(3, 4) or (2, 3, 1, 4)\n", - "OK x.shape=(3, 5, 7) length=(2, 7) output shape=(3, 4) or (2, 3, 2, 4)\n", - "OK x.shape=(3, 5, 7) length=(5, 2) output shape=(3, 4) or (2, 3, 5, 2)\n", - "OK x.shape=(3, 5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3, 3)\n", - "OK x.shape=(7, 5) length=(7, 5) output shape=(3, 4) or (2, 7, 3)\n", - "OK x.shape=(7, 5) length=(1, 5) output shape=(3, 4) or (2, 1, 3)\n", - "OK x.shape=(7, 5) length=(2, 5) output shape=(3, 4) or (2, 2, 3)\n", - "OK x.shape=(7, 5) length=(7, 2) output shape=(3, 4) or (2, 7, 2)\n", - "OK x.shape=(7, 5) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n" - ] - } - ], - "source": [ - "for shape in [(3, 1, 4), (5, 7), (3, 5, 7), (7, 5)]:\n", - " for fft_length in [shape[-2:], (1, shape[-1]),\n", - " (min(2, shape[-2]), shape[-1]),\n", - " (shape[-2], 2),\n", - " (min(3, shape[-2]), min(4, shape[-2]))]:\n", - " x = 
numpy.random.randn(*list(shape)).astype(numpy.float32)\n", - " if len(fnp.shape) == 2:\n", - " fn= numpy.expand_dims(fnp, 0)\n", - " try:\n", - " cus = fft2d_any(x, fft_length)\n", - " except IndexError as e:\n", - " print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", - " continue\n", - " try:\n", - " onx = onnx_rfft_2d_any(x, fft_length=fft_length)\n", - " except IndexError as e:\n", - " print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", - " continue\n", - " try:\n", - " almost_equal(onx, cus)\n", - " except (AssertionError, IndexError) as e:\n", - " print(\"DIS x.shape=%r length=%r error=%r output shape=%r or %r\" % (\n", - " x.shape, fft_length, e, fnp.shape, cus.shape))\n", - " continue\n", - " print(\"OK x.shape=%r length=%r output shape=%r or %r\" % (\n", - " x.shape, fft_length, fnp.shape, cus.shape))" - ] - }, - { - "cell_type": "markdown", - "id": "d197467f", - "metadata": {}, - "source": [ - "There is one issue with ``fft_length=(1, 1)`` but that case is out of scope." - ] - }, - { - "cell_type": "markdown", - "id": "33b5897e", - "metadata": {}, - "source": [ - "### ONNX graph" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "d45e9a99", - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 20, + "id": "3034da60", + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"fft2d.onnx\", \"wb\") as f:\n", + " key = list(onnx_rfft_2d.signed_compiled)[0]\n", + " f.write(onnx_rfft_2d.signed_compiled[key].compiled.onnx_.SerializeToString())" + ] + }, + { + "cell_type": "markdown", + "id": "3a747f0c", + "metadata": {}, + "source": [ + "With a different `fft_length`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "16732cbb", + "metadata": {}, + "outputs": [], + "source": [ + "fft2d_cus = fft2d(rnd, (4, 5))\n", + "fft2d_onx = onnx_rfft_2d(rnd, fft_length=(4, 5))\n", + "almost_equal(fft2d_cus, fft2d_onx)" + ] + }, + { + "cell_type": "markdown", + "id": "04924e7d", + "metadata": {}, + "source": [ + "This implementation of FFT in ONNX assumes shapes and fft lengths are constant. Otherwise, the matrix returned by function `dft_real_cst` must be converted as well. That's left as an exercise." + ] + }, + { + "cell_type": "markdown", + "id": "c9da88a0", + "metadata": {}, + "source": [ + "## FFT2D with shape (3,1,4)\n", + "\n", + "Previous implementation expects the input matrix to have two dimensions. It fails with 3." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "66ba70ee", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 1, 4)" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "shape = (3, 1, 4)\n", + "fft_length = (1, 4)\n", + "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + "fft2d_numpy = numpy.fft.fft2(rnd, fft_length)\n", + "fft2d_numpy.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "a4d123e1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[ 0.87908971+0.j , 0.85659337+2.10206711j,\n", + " 1.5270735 +0.j , 0.85659337-2.10206711j]],\n", + "\n", + " [[-5.01959181+0.j , -0.25658643+0.62102163j,\n", + " 2.18641639+0.j , -0.25658643-0.62102163j]],\n", + "\n", + " [[ 0.60041136+0.j , -0.04546577-1.2931717j ,\n", + " 1.19486004+0.j , -0.04546577+1.2931717j ]]])" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fft2d_numpy" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "4b1bd05b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "axes don't match array\n" + ] + } + ], + "source": [ + "try:\n", + " fft2d_cus = fft2d(rnd, fft_length)\n", + "except Exception as e:\n", + " print(e)\n", + "# fft2d_onx = onnx_rfft_2d(rnd, fft_length=fft_length)" + ] + }, + { + "cell_type": "markdown", + "id": "7bd79a00", + "metadata": {}, + "source": [ + "### numpy version\n", + "\n", + "Let's do it again with numpy first. [fft2](https://numpy.org/doc/stable/reference/generated/numpy.fft.fft2.html) performs `fft2` on the last two axis as many times as the first axis. The goal is still to have an implementation which works for any dimension." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "3b618335", + "metadata": {}, + "outputs": [], + "source": [ + "conc = []\n", + "for i in range(rnd.shape[0]):\n", + " f2 = fft2d(rnd[i], fft_length)\n", + " conc.append(numpy.expand_dims(f2, 0))\n", + "res = numpy.vstack(conc).transpose(1, 0, 2, 3)\n", + "almost_equal(fft2d_numpy[:, :, :3], res)" + ] + }, + { + "cell_type": "markdown", + "id": "7c837e7a", + "metadata": {}, + "source": [ + "It works. And now a more efficient implementation. It is better to read [matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html) description before. To summarize, a third axis is equivalent to many matrix multiplications over the last two axes, as many as the dimension of the first axis: ``matmul(A[I,J,K], B[I,K,L]) --> C[I,J,L]``. Broadcasting also works... ``matmul(A[1,J,K], B[I,K,L]) --> C[I,J,L]``." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "29055cb2", + "metadata": {}, + "outputs": [], + "source": [ + "def dft_real_d3(x, fft_length=None, transpose=True):\n", + " if len(x.shape) != 3:\n", + " raise RuntimeError(\"Not implemented for shape=%r.\" % x.shape)\n", + " N = x.shape[1]\n", + " C = x.shape[-1] if transpose else x.shape[-2]\n", + " if fft_length is None:\n", + " fft_length = x.shape[-1]\n", + " size = fft_length // 2 + 1\n", + "\n", + " cst = dft_real_cst(C, fft_length)\n", + " if transpose:\n", + " x = numpy.transpose(x, (0, 2, 1))\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:, :fft_length, :]\n", + " a = numpy.expand_dims(a, 0)\n", + " b = numpy.expand_dims(b, 1)\n", + " res = numpy.matmul(a, b)\n", + " res = res[:, :, :size, :]\n", + " return numpy.transpose(res, (1, 0, 3, 2))\n", + " else:\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:, :fft_length, :]\n", + " a = numpy.expand_dims(a, 0)\n", + " b = numpy.expand_dims(b, 1)\n", + " res = numpy.matmul(a, b)\n", + " return numpy.transpose(res, (1, 0, 2, 3))\n", + "\n", + "\n", + "def fft2d_d3(mat, fft_length):\n", + " mat = mat[:, :fft_length[-2], :fft_length[-1]]\n", + " res = mat.copy()\n", + " \n", + " # first FFT\n", + " res = dft_real_d3(res, fft_length=fft_length[-1], transpose=True)\n", + " \n", + " # second FFT decomposed on FFT on real part and imaginary part\n", + " res2_real = dft_real_d3(res[0], fft_length=fft_length[-2], transpose=False)\n", + " res2_imag = dft_real_d3(res[1], fft_length=fft_length[-2], transpose=False)\n", + " res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n", + " res = res2_real + res2_imag2\n", + " size = fft_length[-1]//2 + 1\n", + " return res[:, :, :fft_length[-2], :size]\n", + "\n", + "\n", + "def fft2d_any(mat, fft_length):\n", + " new_shape = (-1, ) + mat.shape[-2:]\n", + " mat2 = mat.reshape(new_shape)\n", + " f2 = fft2d_d3(mat2, fft_length)\n", + " new_shape = (2, ) + mat.shape[:-2] + f2.shape[-2:]\n", + " return 
f2.reshape(new_shape)\n", + "\n", + "\n", + "shape = (3, 1, 4)\n", + "fft_length = (1, 4)\n", + "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + "fft2d_numpy = numpy.fft.fft2(rnd, fft_length)\n", + "fft2d_cus = fft2d_any(rnd, fft_length)\n", + "almost_equal(fft2d_numpy[..., :3], fft2d_cus)" + ] + }, { - "data": { - "text/html": [ - "
\n", - "" + "cell_type": "markdown", + "id": "0128b3f2", + "metadata": {}, + "source": [ + "We check with more shapes to see if the implementation works for all of them." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "82f5fc78", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 2) output shape=(3, 1, 2) or (2, 3, 1, 2)\n", + "OK x.shape=(3, 1, 4) length=(1, 1) output shape=(3, 1, 1) or (2, 3, 1, 1)\n", + "OK x.shape=(5, 7) length=(5, 7) output shape=(5, 7) or (2, 5, 4)\n", + "OK x.shape=(5, 7) length=(1, 7) output shape=(1, 7) or (2, 1, 4)\n", + "OK x.shape=(5, 7) length=(2, 7) output shape=(2, 7) or (2, 2, 4)\n", + "OK x.shape=(5, 7) length=(5, 2) output shape=(5, 2) or (2, 5, 2)\n", + "OK x.shape=(5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n", + "OK x.shape=(3, 5, 7) length=(5, 7) output shape=(3, 5, 7) or (2, 3, 5, 4)\n", + "OK x.shape=(3, 5, 7) length=(1, 7) output shape=(3, 1, 7) or (2, 3, 1, 4)\n", + "OK x.shape=(3, 5, 7) length=(2, 7) output shape=(3, 2, 7) or (2, 3, 2, 4)\n", + "OK x.shape=(3, 5, 7) length=(5, 2) output shape=(3, 5, 2) or (2, 3, 5, 2)\n", + "OK x.shape=(3, 5, 7) length=(3, 4) output shape=(3, 3, 4) or (2, 3, 3, 3)\n", + "OK x.shape=(7, 5) length=(7, 5) output shape=(7, 5) or (2, 7, 3)\n", + "OK x.shape=(7, 5) length=(1, 5) output shape=(1, 5) or (2, 1, 3)\n", + "OK x.shape=(7, 5) length=(2, 5) output shape=(2, 5) or (2, 2, 3)\n", + "OK x.shape=(7, 5) length=(7, 2) output shape=(7, 2) or (2, 7, 2)\n", + "OK x.shape=(7, 5) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n" + ] + } ], - "text/plain": [ - "" + "source": [ + "for shape in [(3, 1, 4), (5, 7), (3, 5, 7), (7, 5)]:\n", + " for 
fft_length in [shape[-2:], (1, shape[-1]),\n", + " (min(2, shape[-2]), shape[-1]),\n", + " (shape[-2], 2),\n", + " (min(3, shape[-2]), min(4, shape[-2]))]:\n", + " x = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + " fnp = numpy.fft.fft2(x, fft_length)\n", + " if len(fnp.shape) == 2:\n", + " fn= numpy.expand_dims(fnp, 0)\n", + " try:\n", + " cus = fft2d_any(x, fft_length)\n", + " except IndexError as e:\n", + " print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", + " continue\n", + " try:\n", + " almost_equal(fnp[..., :cus.shape[-1]], cus)\n", + " except (AssertionError, IndexError) as e:\n", + " print(\"DIS x.shape=%r length=%r error=%r output shape=%r or %r\" % (\n", + " x.shape, fft_length, e, fnp.shape, cus.shape))\n", + " continue\n", + " print(\"OK x.shape=%r length=%r output shape=%r or %r\" % (\n", + " x.shape, fft_length, fnp.shape, cus.shape))" + ] + }, + { + "cell_type": "markdown", + "id": "c5f5229a", + "metadata": {}, + "source": [ + "### ONNX version\n", + "\n", + "Let's look into the differences first." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "025c2d88", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext pyquickhelper" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "82664bc5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
populating...
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "/*\n", + "This is part of jsdifflib v1.0. \n", + "\n", + "Copyright 2007 - 2011 Chas Emerick . All rights reserved.\n", + "\n", + "Redistribution and use in source and binary forms, with or without modification, are\n", + "permitted provided that the following conditions are met:\n", + "\n", + " 1. Redistributions of source code must retain the above copyright notice, this list of\n", + " conditions and the following disclaimer.\n", + "\n", + " 2. Redistributions in binary form must reproduce the above copyright notice, this list\n", + " of conditions and the following disclaimer in the documentation and/or other materials\n", + " provided with the distribution.\n", + "\n", + "THIS SOFTWARE IS PROVIDED BY Chas Emerick ``AS IS'' AND ANY EXPRESS OR IMPLIED\n", + "WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND\n", + "FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL Chas Emerick OR\n", + "CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n", + "CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n", + "SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON\n", + "ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING\n", + "NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF\n", + "ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n", + "\n", + "The views and conclusions contained in the software and documentation are those of the\n", + "authors and should not be interpreted as representing official policies, either expressed\n", + "or implied, of Chas Emerick.\n", + "*/\n", + "var diffview = {\n", + "\t/**\n", + "\t * Builds and returns a visual diff view. 
The single parameter, `params', should contain\n", + "\t * the following values:\n", + "\t *\n", + "\t * - baseTextLines: the array of strings that was used as the base text input to SequenceMatcher\n", + "\t * - newTextLines: the array of strings that was used as the new text input to SequenceMatcher\n", + "\t * - opcodes: the array of arrays returned by SequenceMatcher.get_opcodes()\n", + "\t * - baseTextName: the title to be displayed above the base text listing in the diff view; defaults\n", + "\t *\t to \"Base Text\"\n", + "\t * - newTextName: the title to be displayed above the new text listing in the diff view; defaults\n", + "\t *\t to \"New Text\"\n", + "\t * - contextSize: the number of lines of context to show around differences; by default, all lines\n", + "\t *\t are shown\n", + "\t * - viewType: if 0, a side-by-side diff view is generated (default); if 1, an inline diff view is\n", + "\t *\t generated\n", + "\t */\n", + "\tbuildView: function (params) {\n", + "\t\tvar baseTextLines = params.baseTextLines;\n", + "\t\tvar newTextLines = params.newTextLines;\n", + "\t\tvar opcodes = params.opcodes;\n", + "\t\tvar baseTextName = params.baseTextName ? params.baseTextName : \"Base Text\";\n", + "\t\tvar newTextName = params.newTextName ? params.newTextName : \"New Text\";\n", + "\t\tvar contextSize = params.contextSize;\n", + "\t\tvar inline = (params.viewType == 0 || params.viewType == 1) ? 
params.viewType : 0;\n", + "\n", + "\t\tif (baseTextLines == null)\n", + "\t\t\tthrow \"Cannot build diff view; baseTextLines is not defined.\";\n", + "\t\tif (newTextLines == null)\n", + "\t\t\tthrow \"Cannot build diff view; newTextLines is not defined.\";\n", + "\t\tif (!opcodes)\n", + "\t\t\tthrow \"Cannot build diff view; opcodes is not defined.\";\n", + "\t\t\n", + "\t\tfunction celt (name, clazz) {\n", + "\t\t\tvar e = document.createElement(name);\n", + "\t\t\te.className = clazz;\n", + "\t\t\treturn e;\n", + "\t\t}\n", + "\t\t\n", + "\t\tfunction telt (name, text) {\n", + "\t\t\tvar e = document.createElement(name);\n", + "\t\t\te.appendChild(document.createTextNode(text));\n", + "\t\t\treturn e;\n", + "\t\t}\n", + "\t\t\n", + "\t\tfunction ctelt (name, clazz, text) {\n", + "\t\t\tvar e = document.createElement(name);\n", + "\t\t\te.className = clazz;\n", + "\t\t\te.appendChild(document.createTextNode(text));\n", + "\t\t\treturn e;\n", + "\t\t}\n", + "\t\n", + "\t\tvar tdata = document.createElement(\"thead\");\n", + "\t\tvar node = document.createElement(\"tr\");\n", + "\t\ttdata.appendChild(node);\n", + "\t\tif (inline) {\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", baseTextName + \" vs. 
\" + newTextName));\n", + "\t\t} else {\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", baseTextName));\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", newTextName));\n", + "\t\t}\n", + "\t\ttdata = [tdata];\n", + "\t\t\n", + "\t\tvar rows = [];\n", + "\t\tvar node2;\n", + "\t\t\n", + "\t\t/**\n", + "\t\t * Adds two cells to the given row; if the given row corresponds to a real\n", + "\t\t * line number (based on the line index tidx and the endpoint of the \n", + "\t\t * range in question tend), then the cells will contain the line number\n", + "\t\t * and the line of text from textLines at position tidx (with the class of\n", + "\t\t * the second cell set to the name of the change represented), and tidx + 1 will\n", + "\t\t * be returned.\t Otherwise, tidx is returned, and two empty cells are added\n", + "\t\t * to the given row.\n", + "\t\t */\n", + "\t\tfunction addCells (row, tidx, tend, textLines, change) {\n", + "\t\t\tif (tidx < tend) {\n", + "\t\t\t\trow.appendChild(telt(\"th\", (tidx + 1).toString()));\n", + "\t\t\t\trow.appendChild(ctelt(\"td\", change, textLines[tidx].replace(/\\t/g, \"\\u00a0\\u00a0\\u00a0\\u00a0\")));\n", + "\t\t\t\treturn tidx + 1;\n", + "\t\t\t} else {\n", + "\t\t\t\trow.appendChild(document.createElement(\"th\"));\n", + "\t\t\t\trow.appendChild(celt(\"td\", \"empty\"));\n", + "\t\t\t\treturn tidx;\n", + "\t\t\t}\n", + "\t\t}\n", + "\t\t\n", + "\t\tfunction addCellsInline (row, tidx, tidx2, textLines, change) {\n", + "\t\t\trow.appendChild(telt(\"th\", tidx == null ? \"\" : (tidx + 1).toString()));\n", + "\t\t\trow.appendChild(telt(\"th\", tidx2 == null ? \"\" : (tidx2 + 1).toString()));\n", + "\t\t\trow.appendChild(ctelt(\"td\", change, textLines[tidx != null ? 
tidx : tidx2].replace(/\\t/g, \"\\u00a0\\u00a0\\u00a0\\u00a0\")));\n", + "\t\t}\n", + "\t\t\n", + "\t\tfor (var idx = 0; idx < opcodes.length; idx++) {\n", + "\t\t\tvar code = opcodes[idx];\n", + "\t\t\tvar change = code[0];\n", + "\t\t\tvar b = code[1];\n", + "\t\t\tvar be = code[2];\n", + "\t\t\tvar n = code[3];\n", + "\t\t\tvar ne = code[4];\n", + "\t\t\tvar rowcnt = Math.max(be - b, ne - n);\n", + "\t\t\tvar toprows = [];\n", + "\t\t\tvar botrows = [];\n", + "\t\t\tfor (var i = 0; i < rowcnt; i++) {\n", + "\t\t\t\t// jump ahead if we've alredy provided leading context or if this is the first range\n", + "\t\t\t\tif (contextSize && opcodes.length > 1 && ((idx > 0 && i == contextSize) || (idx == 0 && i == 0)) && change==\"equal\") {\n", + "\t\t\t\t\tvar jump = rowcnt - ((idx == 0 ? 1 : 2) * contextSize);\n", + "\t\t\t\t\tif (jump > 1) {\n", + "\t\t\t\t\t\ttoprows.push(node = document.createElement(\"tr\"));\n", + "\t\t\t\t\t\t\n", + "\t\t\t\t\t\tb += jump;\n", + "\t\t\t\t\t\tn += jump;\n", + "\t\t\t\t\t\ti += jump - 1;\n", + "\t\t\t\t\t\tnode.appendChild(telt(\"th\", \"...\"));\n", + "\t\t\t\t\t\tif (!inline) node.appendChild(ctelt(\"td\", \"skip\", \"\"));\n", + "\t\t\t\t\t\tnode.appendChild(telt(\"th\", \"...\"));\n", + "\t\t\t\t\t\tnode.appendChild(ctelt(\"td\", \"skip\", \"\"));\n", + "\t\t\t\t\t\t\n", + "\t\t\t\t\t\t// skip last lines if they're all equal\n", + "\t\t\t\t\t\tif (idx + 1 == opcodes.length) {\n", + "\t\t\t\t\t\t\tbreak;\n", + "\t\t\t\t\t\t} else {\n", + "\t\t\t\t\t\t\tcontinue;\n", + "\t\t\t\t\t\t}\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\t\n", + "\t\t\t\ttoprows.push(node = document.createElement(\"tr\"));\n", + "\t\t\t\tif (inline) {\n", + "\t\t\t\t\tif (change == \"insert\") {\n", + "\t\t\t\t\t\taddCellsInline(node, null, n++, newTextLines, change);\n", + "\t\t\t\t\t} else if (change == \"replace\") {\n", + "\t\t\t\t\t\tbotrows.push(node2 = document.createElement(\"tr\"));\n", + "\t\t\t\t\t\tif (b < be) addCellsInline(node, b++, 
null, baseTextLines, \"delete\");\n", + "\t\t\t\t\t\tif (n < ne) addCellsInline(node2, null, n++, newTextLines, \"insert\");\n", + "\t\t\t\t\t} else if (change == \"delete\") {\n", + "\t\t\t\t\t\taddCellsInline(node, b++, null, baseTextLines, change);\n", + "\t\t\t\t\t} else {\n", + "\t\t\t\t\t\t// equal\n", + "\t\t\t\t\t\taddCellsInline(node, b++, n++, baseTextLines, change);\n", + "\t\t\t\t\t}\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tb = addCells(node, b, be, baseTextLines, change);\n", + "\t\t\t\t\tn = addCells(node, n, ne, newTextLines, change);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\n", + "\t\t\tfor (var i = 0; i < toprows.length; i++) rows.push(toprows[i]);\n", + "\t\t\tfor (var i = 0; i < botrows.length; i++) rows.push(botrows[i]);\n", + "\t\t}\n", + "\t\t\n", + "\t\trows.push(node = ctelt(\"th\", \"author\", \"diff view generated by \"));\n", + "\t\tnode.setAttribute(\"colspan\", inline ? 3 : 4);\n", + "\t\tnode.appendChild(node2 = telt(\"a\", \"jsdifflib\"));\n", + "\t\tnode2.setAttribute(\"href\", \"http://github.com/cemerick/jsdifflib\");\n", + "\t\t\n", + "\t\ttdata.push(node = document.createElement(\"tbody\"));\n", + "\t\tfor (var idx in rows) rows.hasOwnProperty(idx) && node.appendChild(rows[idx]);\n", + "\t\t\n", + "\t\tnode = celt(\"table\", \"diff\" + (inline ? \" inlinediff\" : \"\"));\n", + "\t\tfor (var idx in tdata) tdata.hasOwnProperty(idx) && node.appendChild(tdata[idx]);\n", + "\t\treturn node;\n", + "\t}\n", + "};\n", + "\n", + "\n", + "/***\n", + "This is part of jsdifflib v1.0. 
\n", + "\n", + "Copyright (c) 2007, Snowtide Informatics Systems, Inc.\n", + "All rights reserved.\n", + "\n", + "Redistribution and use in source and binary forms, with or without modification,\n", + "are permitted provided that the following conditions are met:\n", + "\n", + "\t* Redistributions of source code must retain the above copyright notice, this\n", + "\t\tlist of conditions and the following disclaimer.\n", + "\t* Redistributions in binary form must reproduce the above copyright notice,\n", + "\t\tthis list of conditions and the following disclaimer in the documentation\n", + "\t\tand/or other materials provided with the distribution.\n", + "\t* Neither the name of the Snowtide Informatics Systems nor the names of its\n", + "\t\tcontributors may be used to endorse or promote products derived from this\n", + "\t\tsoftware without specific prior written permission.\n", + "\n", + "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY\n", + "EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES\n", + "OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT\n", + "SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,\n", + "INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED\n", + "TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR\n", + "BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n", + "CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN\n", + "ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH\n", + "DAMAGE.\n", + "***/\n", + "/* Author: Chas Emerick */\n", + "var __whitespace = {\" \":true, \"\\t\":true, \"\\n\":true, \"\\f\":true, \"\\r\":true};\n", + "\n", + "var difflib = {\n", + "\tdefaultJunkFunction: function (c) {\n", + "\t\treturn __whitespace.hasOwnProperty(c);\n", + "\t},\n", + "\t\n", + "\tstripLinebreaks: function (str) { return str.replace(/^[\\n\\r]*|[\\n\\r]*$/g, \"\"); },\n", + "\t\n", + "\tstringAsLines: function (str) {\n", + "\t\tvar lfpos = str.indexOf(\"\\n\");\n", + "\t\tvar crpos = str.indexOf(\"\\r\");\n", + "\t\tvar linebreak = ((lfpos > -1 && crpos > -1) || crpos < 0) ? 
\"\\n\" : \"\\r\";\n", + "\t\t\n", + "\t\tvar lines = str.split(linebreak);\n", + "\t\tfor (var i = 0; i < lines.length; i++) {\n", + "\t\t\tlines[i] = difflib.stripLinebreaks(lines[i]);\n", + "\t\t}\n", + "\t\t\n", + "\t\treturn lines;\n", + "\t},\n", + "\t\n", + "\t// iteration-based reduce implementation\n", + "\t__reduce: function (func, list, initial) {\n", + "\t\tif (initial != null) {\n", + "\t\t\tvar value = initial;\n", + "\t\t\tvar idx = 0;\n", + "\t\t} else if (list) {\n", + "\t\t\tvar value = list[0];\n", + "\t\t\tvar idx = 1;\n", + "\t\t} else {\n", + "\t\t\treturn null;\n", + "\t\t}\n", + "\t\t\n", + "\t\tfor (; idx < list.length; idx++) {\n", + "\t\t\tvalue = func(value, list[idx]);\n", + "\t\t}\n", + "\t\t\n", + "\t\treturn value;\n", + "\t},\n", + "\t\n", + "\t// comparison function for sorting lists of numeric tuples\n", + "\t__ntuplecomp: function (a, b) {\n", + "\t\tvar mlen = Math.max(a.length, b.length);\n", + "\t\tfor (var i = 0; i < mlen; i++) {\n", + "\t\t\tif (a[i] < b[i]) return -1;\n", + "\t\t\tif (a[i] > b[i]) return 1;\n", + "\t\t}\n", + "\t\t\n", + "\t\treturn a.length == b.length ? 0 : (a.length < b.length ? -1 : 1);\n", + "\t},\n", + "\t\n", + "\t__calculate_ratio: function (matches, length) {\n", + "\t\treturn length ? 2.0 * matches / length : 1.0;\n", + "\t},\n", + "\t\n", + "\t// returns a function that returns true if a key passed to the returned function\n", + "\t// is in the dict (js object) provided to this function; replaces being able to\n", + "\t// carry around dict.has_key in python...\n", + "\t__isindict: function (dict) {\n", + "\t\treturn function (key) { return dict.hasOwnProperty(key); };\n", + "\t},\n", + "\t\n", + "\t// replacement for python's dict.get function -- need easy default values\n", + "\t__dictget: function (dict, key, defaultValue) {\n", + "\t\treturn dict.hasOwnProperty(key) ? 
dict[key] : defaultValue;\n", + "\t},\t\n", + "\t\n", + "\tSequenceMatcher: function (a, b, isjunk) {\n", + "\t\tthis.set_seqs = function (a, b) {\n", + "\t\t\tthis.set_seq1(a);\n", + "\t\t\tthis.set_seq2(b);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.set_seq1 = function (a) {\n", + "\t\t\tif (a == this.a) return;\n", + "\t\t\tthis.a = a;\n", + "\t\t\tthis.matching_blocks = this.opcodes = null;\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.set_seq2 = function (b) {\n", + "\t\t\tif (b == this.b) return;\n", + "\t\t\tthis.b = b;\n", + "\t\t\tthis.matching_blocks = this.opcodes = this.fullbcount = null;\n", + "\t\t\tthis.__chain_b();\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.__chain_b = function () {\n", + "\t\t\tvar b = this.b;\n", + "\t\t\tvar n = b.length;\n", + "\t\t\tvar b2j = this.b2j = {};\n", + "\t\t\tvar populardict = {};\n", + "\t\t\tfor (var i = 0; i < b.length; i++) {\n", + "\t\t\t\tvar elt = b[i];\n", + "\t\t\t\tif (b2j.hasOwnProperty(elt)) {\n", + "\t\t\t\t\tvar indices = b2j[elt];\n", + "\t\t\t\t\tif (n >= 200 && indices.length * 100 > n) {\n", + "\t\t\t\t\t\tpopulardict[elt] = 1;\n", + "\t\t\t\t\t\tdelete b2j[elt];\n", + "\t\t\t\t\t} else {\n", + "\t\t\t\t\t\tindices.push(i);\n", + "\t\t\t\t\t}\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tb2j[elt] = [i];\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\n", + "\t\t\tfor (var elt in populardict) {\n", + "\t\t\t\tif (populardict.hasOwnProperty(elt)) {\n", + "\t\t\t\t\tdelete b2j[elt];\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tvar isjunk = this.isjunk;\n", + "\t\t\tvar junkdict = {};\n", + "\t\t\tif (isjunk) {\n", + "\t\t\t\tfor (var elt in populardict) {\n", + "\t\t\t\t\tif (populardict.hasOwnProperty(elt) && isjunk(elt)) {\n", + "\t\t\t\t\t\tjunkdict[elt] = 1;\n", + "\t\t\t\t\t\tdelete populardict[elt];\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\tfor (var elt in b2j) {\n", + "\t\t\t\t\tif (b2j.hasOwnProperty(elt) && isjunk(elt)) {\n", + "\t\t\t\t\t\tjunkdict[elt] = 1;\n", + "\t\t\t\t\t\tdelete 
b2j[elt];\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\n", + "\t\t\tthis.isbjunk = difflib.__isindict(junkdict);\n", + "\t\t\tthis.isbpopular = difflib.__isindict(populardict);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.find_longest_match = function (alo, ahi, blo, bhi) {\n", + "\t\t\tvar a = this.a;\n", + "\t\t\tvar b = this.b;\n", + "\t\t\tvar b2j = this.b2j;\n", + "\t\t\tvar isbjunk = this.isbjunk;\n", + "\t\t\tvar besti = alo;\n", + "\t\t\tvar bestj = blo;\n", + "\t\t\tvar bestsize = 0;\n", + "\t\t\tvar j = null;\n", + "\t\t\tvar k;\n", + "\t\n", + "\t\t\tvar j2len = {};\n", + "\t\t\tvar nothing = [];\n", + "\t\t\tfor (var i = alo; i < ahi; i++) {\n", + "\t\t\t\tvar newj2len = {};\n", + "\t\t\t\tvar jdict = difflib.__dictget(b2j, a[i], nothing);\n", + "\t\t\t\tfor (var jkey in jdict) {\n", + "\t\t\t\t\tif (jdict.hasOwnProperty(jkey)) {\n", + "\t\t\t\t\t\tj = jdict[jkey];\n", + "\t\t\t\t\t\tif (j < blo) continue;\n", + "\t\t\t\t\t\tif (j >= bhi) break;\n", + "\t\t\t\t\t\tnewj2len[j] = k = difflib.__dictget(j2len, j - 1, 0) + 1;\n", + "\t\t\t\t\t\tif (k > bestsize) {\n", + "\t\t\t\t\t\t\tbesti = i - k + 1;\n", + "\t\t\t\t\t\t\tbestj = j - k + 1;\n", + "\t\t\t\t\t\t\tbestsize = k;\n", + "\t\t\t\t\t\t}\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\tj2len = newj2len;\n", + "\t\t\t}\n", + "\t\n", + "\t\t\twhile (besti > alo && bestj > blo && !isbjunk(b[bestj - 1]) && a[besti - 1] == b[bestj - 1]) {\n", + "\t\t\t\tbesti--;\n", + "\t\t\t\tbestj--;\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\t\t\t\n", + "\t\t\twhile (besti + bestsize < ahi && bestj + bestsize < bhi &&\n", + "\t\t\t\t\t!isbjunk(b[bestj + bestsize]) &&\n", + "\t\t\t\t\ta[besti + bestsize] == b[bestj + bestsize]) {\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\n", + "\t\t\twhile (besti > alo && bestj > blo && isbjunk(b[bestj - 1]) && a[besti - 1] == b[bestj - 1]) {\n", + "\t\t\t\tbesti--;\n", + "\t\t\t\tbestj--;\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\t\t\n", + 
"\t\t\twhile (besti + bestsize < ahi && bestj + bestsize < bhi && isbjunk(b[bestj + bestsize]) &&\n", + "\t\t\t\t\ta[besti + bestsize] == b[bestj + bestsize]) {\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\n", + "\t\t\treturn [besti, bestj, bestsize];\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.get_matching_blocks = function () {\n", + "\t\t\tif (this.matching_blocks != null) return this.matching_blocks;\n", + "\t\t\tvar la = this.a.length;\n", + "\t\t\tvar lb = this.b.length;\n", + "\t\n", + "\t\t\tvar queue = [[0, la, 0, lb]];\n", + "\t\t\tvar matching_blocks = [];\n", + "\t\t\tvar alo, ahi, blo, bhi, qi, i, j, k, x;\n", + "\t\t\twhile (queue.length) {\n", + "\t\t\t\tqi = queue.pop();\n", + "\t\t\t\talo = qi[0];\n", + "\t\t\t\tahi = qi[1];\n", + "\t\t\t\tblo = qi[2];\n", + "\t\t\t\tbhi = qi[3];\n", + "\t\t\t\tx = this.find_longest_match(alo, ahi, blo, bhi);\n", + "\t\t\t\ti = x[0];\n", + "\t\t\t\tj = x[1];\n", + "\t\t\t\tk = x[2];\n", + "\t\n", + "\t\t\t\tif (k) {\n", + "\t\t\t\t\tmatching_blocks.push(x);\n", + "\t\t\t\t\tif (alo < i && blo < j)\n", + "\t\t\t\t\t\tqueue.push([alo, i, blo, j]);\n", + "\t\t\t\t\tif (i+k < ahi && j+k < bhi)\n", + "\t\t\t\t\t\tqueue.push([i + k, ahi, j + k, bhi]);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tmatching_blocks.sort(difflib.__ntuplecomp);\n", + "\t\n", + "\t\t\tvar i1 = 0, j1 = 0, k1 = 0, block = 0;\n", + "\t\t\tvar i2, j2, k2;\n", + "\t\t\tvar non_adjacent = [];\n", + "\t\t\tfor (var idx in matching_blocks) {\n", + "\t\t\t\tif (matching_blocks.hasOwnProperty(idx)) {\n", + "\t\t\t\t\tblock = matching_blocks[idx];\n", + "\t\t\t\t\ti2 = block[0];\n", + "\t\t\t\t\tj2 = block[1];\n", + "\t\t\t\t\tk2 = block[2];\n", + "\t\t\t\t\tif (i1 + k1 == i2 && j1 + k1 == j2) {\n", + "\t\t\t\t\t\tk1 += k2;\n", + "\t\t\t\t\t} else {\n", + "\t\t\t\t\t\tif (k1) non_adjacent.push([i1, j1, k1]);\n", + "\t\t\t\t\t\ti1 = i2;\n", + "\t\t\t\t\t\tj1 = j2;\n", + "\t\t\t\t\t\tk1 = k2;\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + 
"\t\t\t}\n", + "\t\t\t\n", + "\t\t\tif (k1) non_adjacent.push([i1, j1, k1]);\n", + "\t\n", + "\t\t\tnon_adjacent.push([la, lb, 0]);\n", + "\t\t\tthis.matching_blocks = non_adjacent;\n", + "\t\t\treturn this.matching_blocks;\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.get_opcodes = function () {\n", + "\t\t\tif (this.opcodes != null) return this.opcodes;\n", + "\t\t\tvar i = 0;\n", + "\t\t\tvar j = 0;\n", + "\t\t\tvar answer = [];\n", + "\t\t\tthis.opcodes = answer;\n", + "\t\t\tvar block, ai, bj, size, tag;\n", + "\t\t\tvar blocks = this.get_matching_blocks();\n", + "\t\t\tfor (var idx in blocks) {\n", + "\t\t\t\tif (blocks.hasOwnProperty(idx)) {\n", + "\t\t\t\t\tblock = blocks[idx];\n", + "\t\t\t\t\tai = block[0];\n", + "\t\t\t\t\tbj = block[1];\n", + "\t\t\t\t\tsize = block[2];\n", + "\t\t\t\t\ttag = '';\n", + "\t\t\t\t\tif (i < ai && j < bj) {\n", + "\t\t\t\t\t\ttag = 'replace';\n", + "\t\t\t\t\t} else if (i < ai) {\n", + "\t\t\t\t\t\ttag = 'delete';\n", + "\t\t\t\t\t} else if (j < bj) {\n", + "\t\t\t\t\t\ttag = 'insert';\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\tif (tag) answer.push([tag, i, ai, j, bj]);\n", + "\t\t\t\t\ti = ai + size;\n", + "\t\t\t\t\tj = bj + size;\n", + "\t\t\t\t\t\n", + "\t\t\t\t\tif (size) answer.push(['equal', ai, i, bj, j]);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\treturn answer;\n", + "\t\t}\n", + "\t\t\n", + "\t\t// this is a generator function in the python lib, which of course is not supported in javascript\n", + "\t\t// the reimplementation builds up the grouped opcodes into a list in their entirety and returns that.\n", + "\t\tthis.get_grouped_opcodes = function (n) {\n", + "\t\t\tif (!n) n = 3;\n", + "\t\t\tvar codes = this.get_opcodes();\n", + "\t\t\tif (!codes) codes = [[\"equal\", 0, 1, 0, 1]];\n", + "\t\t\tvar code, tag, i1, i2, j1, j2;\n", + "\t\t\tif (codes[0][0] == 'equal') {\n", + "\t\t\t\tcode = codes[0];\n", + "\t\t\t\ttag = code[0];\n", + "\t\t\t\ti1 = code[1];\n", + "\t\t\t\ti2 = code[2];\n", + "\t\t\t\tj1 = 
code[3];\n", + "\t\t\t\tj2 = code[4];\n", + "\t\t\t\tcodes[0] = [tag, Math.max(i1, i2 - n), i2, Math.max(j1, j2 - n), j2];\n", + "\t\t\t}\n", + "\t\t\tif (codes[codes.length - 1][0] == 'equal') {\n", + "\t\t\t\tcode = codes[codes.length - 1];\n", + "\t\t\t\ttag = code[0];\n", + "\t\t\t\ti1 = code[1];\n", + "\t\t\t\ti2 = code[2];\n", + "\t\t\t\tj1 = code[3];\n", + "\t\t\t\tj2 = code[4];\n", + "\t\t\t\tcodes[codes.length - 1] = [tag, i1, Math.min(i2, i1 + n), j1, Math.min(j2, j1 + n)];\n", + "\t\t\t}\n", + "\t\n", + "\t\t\tvar nn = n + n;\n", + "\t\t\tvar group = [];\n", + "\t\t\tvar groups = [];\n", + "\t\t\tfor (var idx in codes) {\n", + "\t\t\t\tif (codes.hasOwnProperty(idx)) {\n", + "\t\t\t\t\tcode = codes[idx];\n", + "\t\t\t\t\ttag = code[0];\n", + "\t\t\t\t\ti1 = code[1];\n", + "\t\t\t\t\ti2 = code[2];\n", + "\t\t\t\t\tj1 = code[3];\n", + "\t\t\t\t\tj2 = code[4];\n", + "\t\t\t\t\tif (tag == 'equal' && i2 - i1 > nn) {\n", + "\t\t\t\t\t\tgroup.push([tag, i1, Math.min(i2, i1 + n), j1, Math.min(j2, j1 + n)]);\n", + "\t\t\t\t\t\tgroups.push(group);\n", + "\t\t\t\t\t\tgroup = [];\n", + "\t\t\t\t\t\ti1 = Math.max(i1, i2-n);\n", + "\t\t\t\t\t\tj1 = Math.max(j1, j2-n);\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\t\n", + "\t\t\t\t\tgroup.push([tag, i1, i2, j1, j2]);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tif (group && !(group.length == 1 && group[0][0] == 'equal')) groups.push(group)\n", + "\t\t\t\n", + "\t\t\treturn groups;\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.ratio = function () {\n", + "\t\t\tmatches = difflib.__reduce(\n", + "\t\t\t\t\t\t\tfunction (sum, triple) { return sum + triple[triple.length - 1]; },\n", + "\t\t\t\t\t\t\tthis.get_matching_blocks(), 0);\n", + "\t\t\treturn difflib.__calculate_ratio(matches, this.a.length + this.b.length);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.quick_ratio = function () {\n", + "\t\t\tvar fullbcount, elt;\n", + "\t\t\tif (this.fullbcount == null) {\n", + "\t\t\t\tthis.fullbcount = fullbcount = {};\n", + 
"\t\t\t\tfor (var i = 0; i < this.b.length; i++) {\n", + "\t\t\t\t\telt = this.b[i];\n", + "\t\t\t\t\tfullbcount[elt] = difflib.__dictget(fullbcount, elt, 0) + 1;\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\tfullbcount = this.fullbcount;\n", + "\t\n", + "\t\t\tvar avail = {};\n", + "\t\t\tvar availhas = difflib.__isindict(avail);\n", + "\t\t\tvar matches = numb = 0;\n", + "\t\t\tfor (var i = 0; i < this.a.length; i++) {\n", + "\t\t\t\telt = this.a[i];\n", + "\t\t\t\tif (availhas(elt)) {\n", + "\t\t\t\t\tnumb = avail[elt];\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tnumb = difflib.__dictget(fullbcount, elt, 0);\n", + "\t\t\t\t}\n", + "\t\t\t\tavail[elt] = numb - 1;\n", + "\t\t\t\tif (numb > 0) matches++;\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\treturn difflib.__calculate_ratio(matches, this.a.length + this.b.length);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.real_quick_ratio = function () {\n", + "\t\t\tvar la = this.a.length;\n", + "\t\t\tvar lb = this.b.length;\n", + "\t\t\treturn _calculate_ratio(Math.min(la, lb), la + lb);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.isjunk = isjunk ? 
isjunk : difflib.defaultJunkFunction;\n", + "\t\tthis.a = this.b = null;\n", + "\t\tthis.set_seqs(a, b);\n", + "\t}\n", + "};\n", + "\n", + "\n", + "\n", + "function diffUsingJS (viewType, contextSize, baseText, newText) {\n", + "\n", + " var byId = function (id) { return document.getElementById(id); },\n", + " base = difflib.stringAsLines(baseText),\n", + " newtxt = difflib.stringAsLines(newText),\n", + " sm = new difflib.SequenceMatcher(base, newtxt),\n", + " opcodes = sm.get_opcodes(),\n", + " diffoutputdiv = byId(\"diffid_2021-08-05_16_46_43_018480\");\n", + "\n", + " diffoutputdiv.innerHTML = \"\";\n", + " contextSize = contextSize || null;\n", + "\n", + " diffoutputdiv.appendChild(diffview.buildView({\n", + " baseTextLines: base,\n", + " newTextLines: newtxt,\n", + " opcodes: opcodes,\n", + " baseTextName: \"Base Text\",\n", + " newTextName: \"New Text\",\n", + " contextSize: contextSize,\n", + " viewType: viewType\n", + " }));\n", + "}\n", + "var tview=0;\n", + "var csize='';\n", + "var bt = 'def dft_real(x, fft_length=None, transpose=True):\\n if len(x.shape) == 1:\\n x = x.reshape((1, -1))\\n N = 1\\n else:\\n N = x.shape[0] \\n C = x.shape[-1] if transpose else x.shape[-2]\\n if fft_length is None:\\n fft_length = x.shape[-1]\\n size = fft_length // 2 + 1\\n\\n cst = dft_real_cst(C, fft_length)\\n if transpose:\\n x = numpy.transpose(x, (1, 0))\\n a = cst[:, :, :fft_length]\\n b = x[:fft_length]\\n res = numpy.matmul(a, b)\\n res = res[:, :size, :]\\n return numpy.transpose(res, (0, 2, 1))\\n else:\\n a = cst[:, :, :fft_length]\\n b = x[:fft_length]\\n return numpy.matmul(a, b)\\n';\n", + "var nt = 'def dft_real_d3(x, fft_length=None, transpose=True):\\n if len(x.shape) != 3:\\n raise RuntimeError(\"Not implemented for shape=%r.\" % x.shape)\\n N = x.shape[1]\\n C = x.shape[-1] if transpose else x.shape[-2]\\n if fft_length is None:\\n fft_length = x.shape[-1]\\n size = fft_length // 2 + 1\\n\\n cst = dft_real_cst(C, fft_length)\\n if transpose:\\n x = 
numpy.transpose(x, (0, 2, 1))\\n a = cst[:, :, :fft_length]\\n b = x[:, :fft_length, :]\\n a = numpy.expand_dims(a, 0)\\n b = numpy.expand_dims(b, 1)\\n res = numpy.matmul(a, b)\\n res = res[:, :, :size, :]\\n return numpy.transpose(res, (1, 0, 3, 2))\\n else:\\n a = cst[:, :, :fft_length]\\n b = x[:, :fft_length, :]\\n a = numpy.expand_dims(a, 0)\\n b = numpy.expand_dims(b, 1)\\n res = numpy.matmul(a, b)\\n return numpy.transpose(res, (1, 0, 2, 3))\\n';\n", + "diffUsingJS(tview, csize, bt, nt) ;\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import inspect\n", + "text1 = inspect.getsource(dft_real)\n", + "text2 = inspect.getsource(dft_real_d3)\n", + "%textdiff text1 text2" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "cd7e14d4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
populating...
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "/*\n", + "This is part of jsdifflib v1.0. \n", + "\n", + "Copyright 2007 - 2011 Chas Emerick . All rights reserved.\n", + "\n", + "Redistribution and use in source and binary forms, with or without modification, are\n", + "permitted provided that the following conditions are met:\n", + "\n", + " 1. Redistributions of source code must retain the above copyright notice, this list of\n", + " conditions and the following disclaimer.\n", + "\n", + " 2. Redistributions in binary form must reproduce the above copyright notice, this list\n", + " of conditions and the following disclaimer in the documentation and/or other materials\n", + " provided with the distribution.\n", + "\n", + "THIS SOFTWARE IS PROVIDED BY Chas Emerick ``AS IS'' AND ANY EXPRESS OR IMPLIED\n", + "WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND\n", + "FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL Chas Emerick OR\n", + "CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n", + "CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n", + "SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON\n", + "ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING\n", + "NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF\n", + "ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n", + "\n", + "The views and conclusions contained in the software and documentation are those of the\n", + "authors and should not be interpreted as representing official policies, either expressed\n", + "or implied, of Chas Emerick.\n", + "*/\n", + "var diffview = {\n", + "\t/**\n", + "\t * Builds and returns a visual diff view. 
The single parameter, `params', should contain\n", + "\t * the following values:\n", + "\t *\n", + "\t * - baseTextLines: the array of strings that was used as the base text input to SequenceMatcher\n", + "\t * - newTextLines: the array of strings that was used as the new text input to SequenceMatcher\n", + "\t * - opcodes: the array of arrays returned by SequenceMatcher.get_opcodes()\n", + "\t * - baseTextName: the title to be displayed above the base text listing in the diff view; defaults\n", + "\t *\t to \"Base Text\"\n", + "\t * - newTextName: the title to be displayed above the new text listing in the diff view; defaults\n", + "\t *\t to \"New Text\"\n", + "\t * - contextSize: the number of lines of context to show around differences; by default, all lines\n", + "\t *\t are shown\n", + "\t * - viewType: if 0, a side-by-side diff view is generated (default); if 1, an inline diff view is\n", + "\t *\t generated\n", + "\t */\n", + "\tbuildView: function (params) {\n", + "\t\tvar baseTextLines = params.baseTextLines;\n", + "\t\tvar newTextLines = params.newTextLines;\n", + "\t\tvar opcodes = params.opcodes;\n", + "\t\tvar baseTextName = params.baseTextName ? params.baseTextName : \"Base Text\";\n", + "\t\tvar newTextName = params.newTextName ? params.newTextName : \"New Text\";\n", + "\t\tvar contextSize = params.contextSize;\n", + "\t\tvar inline = (params.viewType == 0 || params.viewType == 1) ? 
params.viewType : 0;\n", + "\n", + "\t\tif (baseTextLines == null)\n", + "\t\t\tthrow \"Cannot build diff view; baseTextLines is not defined.\";\n", + "\t\tif (newTextLines == null)\n", + "\t\t\tthrow \"Cannot build diff view; newTextLines is not defined.\";\n", + "\t\tif (!opcodes)\n", + "\t\t\tthrow \"Cannot build diff view; opcodes is not defined.\";\n", + "\t\t\n", + "\t\tfunction celt (name, clazz) {\n", + "\t\t\tvar e = document.createElement(name);\n", + "\t\t\te.className = clazz;\n", + "\t\t\treturn e;\n", + "\t\t}\n", + "\t\t\n", + "\t\tfunction telt (name, text) {\n", + "\t\t\tvar e = document.createElement(name);\n", + "\t\t\te.appendChild(document.createTextNode(text));\n", + "\t\t\treturn e;\n", + "\t\t}\n", + "\t\t\n", + "\t\tfunction ctelt (name, clazz, text) {\n", + "\t\t\tvar e = document.createElement(name);\n", + "\t\t\te.className = clazz;\n", + "\t\t\te.appendChild(document.createTextNode(text));\n", + "\t\t\treturn e;\n", + "\t\t}\n", + "\t\n", + "\t\tvar tdata = document.createElement(\"thead\");\n", + "\t\tvar node = document.createElement(\"tr\");\n", + "\t\ttdata.appendChild(node);\n", + "\t\tif (inline) {\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", baseTextName + \" vs. 
\" + newTextName));\n", + "\t\t} else {\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", baseTextName));\n", + "\t\t\tnode.appendChild(document.createElement(\"th\"));\n", + "\t\t\tnode.appendChild(ctelt(\"th\", \"texttitle\", newTextName));\n", + "\t\t}\n", + "\t\ttdata = [tdata];\n", + "\t\t\n", + "\t\tvar rows = [];\n", + "\t\tvar node2;\n", + "\t\t\n", + "\t\t/**\n", + "\t\t * Adds two cells to the given row; if the given row corresponds to a real\n", + "\t\t * line number (based on the line index tidx and the endpoint of the \n", + "\t\t * range in question tend), then the cells will contain the line number\n", + "\t\t * and the line of text from textLines at position tidx (with the class of\n", + "\t\t * the second cell set to the name of the change represented), and tidx + 1 will\n", + "\t\t * be returned.\t Otherwise, tidx is returned, and two empty cells are added\n", + "\t\t * to the given row.\n", + "\t\t */\n", + "\t\tfunction addCells (row, tidx, tend, textLines, change) {\n", + "\t\t\tif (tidx < tend) {\n", + "\t\t\t\trow.appendChild(telt(\"th\", (tidx + 1).toString()));\n", + "\t\t\t\trow.appendChild(ctelt(\"td\", change, textLines[tidx].replace(/\\t/g, \"\\u00a0\\u00a0\\u00a0\\u00a0\")));\n", + "\t\t\t\treturn tidx + 1;\n", + "\t\t\t} else {\n", + "\t\t\t\trow.appendChild(document.createElement(\"th\"));\n", + "\t\t\t\trow.appendChild(celt(\"td\", \"empty\"));\n", + "\t\t\t\treturn tidx;\n", + "\t\t\t}\n", + "\t\t}\n", + "\t\t\n", + "\t\tfunction addCellsInline (row, tidx, tidx2, textLines, change) {\n", + "\t\t\trow.appendChild(telt(\"th\", tidx == null ? \"\" : (tidx + 1).toString()));\n", + "\t\t\trow.appendChild(telt(\"th\", tidx2 == null ? \"\" : (tidx2 + 1).toString()));\n", + "\t\t\trow.appendChild(ctelt(\"td\", change, textLines[tidx != null ? 
tidx : tidx2].replace(/\\t/g, \"\\u00a0\\u00a0\\u00a0\\u00a0\")));\n", + "\t\t}\n", + "\t\t\n", + "\t\tfor (var idx = 0; idx < opcodes.length; idx++) {\n", + "\t\t\tvar code = opcodes[idx];\n", + "\t\t\tvar change = code[0];\n", + "\t\t\tvar b = code[1];\n", + "\t\t\tvar be = code[2];\n", + "\t\t\tvar n = code[3];\n", + "\t\t\tvar ne = code[4];\n", + "\t\t\tvar rowcnt = Math.max(be - b, ne - n);\n", + "\t\t\tvar toprows = [];\n", + "\t\t\tvar botrows = [];\n", + "\t\t\tfor (var i = 0; i < rowcnt; i++) {\n", + "\t\t\t\t// jump ahead if we've alredy provided leading context or if this is the first range\n", + "\t\t\t\tif (contextSize && opcodes.length > 1 && ((idx > 0 && i == contextSize) || (idx == 0 && i == 0)) && change==\"equal\") {\n", + "\t\t\t\t\tvar jump = rowcnt - ((idx == 0 ? 1 : 2) * contextSize);\n", + "\t\t\t\t\tif (jump > 1) {\n", + "\t\t\t\t\t\ttoprows.push(node = document.createElement(\"tr\"));\n", + "\t\t\t\t\t\t\n", + "\t\t\t\t\t\tb += jump;\n", + "\t\t\t\t\t\tn += jump;\n", + "\t\t\t\t\t\ti += jump - 1;\n", + "\t\t\t\t\t\tnode.appendChild(telt(\"th\", \"...\"));\n", + "\t\t\t\t\t\tif (!inline) node.appendChild(ctelt(\"td\", \"skip\", \"\"));\n", + "\t\t\t\t\t\tnode.appendChild(telt(\"th\", \"...\"));\n", + "\t\t\t\t\t\tnode.appendChild(ctelt(\"td\", \"skip\", \"\"));\n", + "\t\t\t\t\t\t\n", + "\t\t\t\t\t\t// skip last lines if they're all equal\n", + "\t\t\t\t\t\tif (idx + 1 == opcodes.length) {\n", + "\t\t\t\t\t\t\tbreak;\n", + "\t\t\t\t\t\t} else {\n", + "\t\t\t\t\t\t\tcontinue;\n", + "\t\t\t\t\t\t}\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\t\n", + "\t\t\t\ttoprows.push(node = document.createElement(\"tr\"));\n", + "\t\t\t\tif (inline) {\n", + "\t\t\t\t\tif (change == \"insert\") {\n", + "\t\t\t\t\t\taddCellsInline(node, null, n++, newTextLines, change);\n", + "\t\t\t\t\t} else if (change == \"replace\") {\n", + "\t\t\t\t\t\tbotrows.push(node2 = document.createElement(\"tr\"));\n", + "\t\t\t\t\t\tif (b < be) addCellsInline(node, b++, 
null, baseTextLines, \"delete\");\n", + "\t\t\t\t\t\tif (n < ne) addCellsInline(node2, null, n++, newTextLines, \"insert\");\n", + "\t\t\t\t\t} else if (change == \"delete\") {\n", + "\t\t\t\t\t\taddCellsInline(node, b++, null, baseTextLines, change);\n", + "\t\t\t\t\t} else {\n", + "\t\t\t\t\t\t// equal\n", + "\t\t\t\t\t\taddCellsInline(node, b++, n++, baseTextLines, change);\n", + "\t\t\t\t\t}\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tb = addCells(node, b, be, baseTextLines, change);\n", + "\t\t\t\t\tn = addCells(node, n, ne, newTextLines, change);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\n", + "\t\t\tfor (var i = 0; i < toprows.length; i++) rows.push(toprows[i]);\n", + "\t\t\tfor (var i = 0; i < botrows.length; i++) rows.push(botrows[i]);\n", + "\t\t}\n", + "\t\t\n", + "\t\trows.push(node = ctelt(\"th\", \"author\", \"diff view generated by \"));\n", + "\t\tnode.setAttribute(\"colspan\", inline ? 3 : 4);\n", + "\t\tnode.appendChild(node2 = telt(\"a\", \"jsdifflib\"));\n", + "\t\tnode2.setAttribute(\"href\", \"http://github.com/cemerick/jsdifflib\");\n", + "\t\t\n", + "\t\ttdata.push(node = document.createElement(\"tbody\"));\n", + "\t\tfor (var idx in rows) rows.hasOwnProperty(idx) && node.appendChild(rows[idx]);\n", + "\t\t\n", + "\t\tnode = celt(\"table\", \"diff\" + (inline ? \" inlinediff\" : \"\"));\n", + "\t\tfor (var idx in tdata) tdata.hasOwnProperty(idx) && node.appendChild(tdata[idx]);\n", + "\t\treturn node;\n", + "\t}\n", + "};\n", + "\n", + "\n", + "/***\n", + "This is part of jsdifflib v1.0. 
\n", + "\n", + "Copyright (c) 2007, Snowtide Informatics Systems, Inc.\n", + "All rights reserved.\n", + "\n", + "Redistribution and use in source and binary forms, with or without modification,\n", + "are permitted provided that the following conditions are met:\n", + "\n", + "\t* Redistributions of source code must retain the above copyright notice, this\n", + "\t\tlist of conditions and the following disclaimer.\n", + "\t* Redistributions in binary form must reproduce the above copyright notice,\n", + "\t\tthis list of conditions and the following disclaimer in the documentation\n", + "\t\tand/or other materials provided with the distribution.\n", + "\t* Neither the name of the Snowtide Informatics Systems nor the names of its\n", + "\t\tcontributors may be used to endorse or promote products derived from this\n", + "\t\tsoftware without specific prior written permission.\n", + "\n", + "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY\n", + "EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES\n", + "OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT\n", + "SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,\n", + "INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED\n", + "TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR\n", + "BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n", + "CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN\n", + "ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH\n", + "DAMAGE.\n", + "***/\n", + "/* Author: Chas Emerick */\n", + "var __whitespace = {\" \":true, \"\\t\":true, \"\\n\":true, \"\\f\":true, \"\\r\":true};\n", + "\n", + "var difflib = {\n", + "\tdefaultJunkFunction: function (c) {\n", + "\t\treturn __whitespace.hasOwnProperty(c);\n", + "\t},\n", + "\t\n", + "\tstripLinebreaks: function (str) { return str.replace(/^[\\n\\r]*|[\\n\\r]*$/g, \"\"); },\n", + "\t\n", + "\tstringAsLines: function (str) {\n", + "\t\tvar lfpos = str.indexOf(\"\\n\");\n", + "\t\tvar crpos = str.indexOf(\"\\r\");\n", + "\t\tvar linebreak = ((lfpos > -1 && crpos > -1) || crpos < 0) ? 
\"\\n\" : \"\\r\";\n", + "\t\t\n", + "\t\tvar lines = str.split(linebreak);\n", + "\t\tfor (var i = 0; i < lines.length; i++) {\n", + "\t\t\tlines[i] = difflib.stripLinebreaks(lines[i]);\n", + "\t\t}\n", + "\t\t\n", + "\t\treturn lines;\n", + "\t},\n", + "\t\n", + "\t// iteration-based reduce implementation\n", + "\t__reduce: function (func, list, initial) {\n", + "\t\tif (initial != null) {\n", + "\t\t\tvar value = initial;\n", + "\t\t\tvar idx = 0;\n", + "\t\t} else if (list) {\n", + "\t\t\tvar value = list[0];\n", + "\t\t\tvar idx = 1;\n", + "\t\t} else {\n", + "\t\t\treturn null;\n", + "\t\t}\n", + "\t\t\n", + "\t\tfor (; idx < list.length; idx++) {\n", + "\t\t\tvalue = func(value, list[idx]);\n", + "\t\t}\n", + "\t\t\n", + "\t\treturn value;\n", + "\t},\n", + "\t\n", + "\t// comparison function for sorting lists of numeric tuples\n", + "\t__ntuplecomp: function (a, b) {\n", + "\t\tvar mlen = Math.max(a.length, b.length);\n", + "\t\tfor (var i = 0; i < mlen; i++) {\n", + "\t\t\tif (a[i] < b[i]) return -1;\n", + "\t\t\tif (a[i] > b[i]) return 1;\n", + "\t\t}\n", + "\t\t\n", + "\t\treturn a.length == b.length ? 0 : (a.length < b.length ? -1 : 1);\n", + "\t},\n", + "\t\n", + "\t__calculate_ratio: function (matches, length) {\n", + "\t\treturn length ? 2.0 * matches / length : 1.0;\n", + "\t},\n", + "\t\n", + "\t// returns a function that returns true if a key passed to the returned function\n", + "\t// is in the dict (js object) provided to this function; replaces being able to\n", + "\t// carry around dict.has_key in python...\n", + "\t__isindict: function (dict) {\n", + "\t\treturn function (key) { return dict.hasOwnProperty(key); };\n", + "\t},\n", + "\t\n", + "\t// replacement for python's dict.get function -- need easy default values\n", + "\t__dictget: function (dict, key, defaultValue) {\n", + "\t\treturn dict.hasOwnProperty(key) ? 
dict[key] : defaultValue;\n", + "\t},\t\n", + "\t\n", + "\tSequenceMatcher: function (a, b, isjunk) {\n", + "\t\tthis.set_seqs = function (a, b) {\n", + "\t\t\tthis.set_seq1(a);\n", + "\t\t\tthis.set_seq2(b);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.set_seq1 = function (a) {\n", + "\t\t\tif (a == this.a) return;\n", + "\t\t\tthis.a = a;\n", + "\t\t\tthis.matching_blocks = this.opcodes = null;\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.set_seq2 = function (b) {\n", + "\t\t\tif (b == this.b) return;\n", + "\t\t\tthis.b = b;\n", + "\t\t\tthis.matching_blocks = this.opcodes = this.fullbcount = null;\n", + "\t\t\tthis.__chain_b();\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.__chain_b = function () {\n", + "\t\t\tvar b = this.b;\n", + "\t\t\tvar n = b.length;\n", + "\t\t\tvar b2j = this.b2j = {};\n", + "\t\t\tvar populardict = {};\n", + "\t\t\tfor (var i = 0; i < b.length; i++) {\n", + "\t\t\t\tvar elt = b[i];\n", + "\t\t\t\tif (b2j.hasOwnProperty(elt)) {\n", + "\t\t\t\t\tvar indices = b2j[elt];\n", + "\t\t\t\t\tif (n >= 200 && indices.length * 100 > n) {\n", + "\t\t\t\t\t\tpopulardict[elt] = 1;\n", + "\t\t\t\t\t\tdelete b2j[elt];\n", + "\t\t\t\t\t} else {\n", + "\t\t\t\t\t\tindices.push(i);\n", + "\t\t\t\t\t}\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tb2j[elt] = [i];\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\n", + "\t\t\tfor (var elt in populardict) {\n", + "\t\t\t\tif (populardict.hasOwnProperty(elt)) {\n", + "\t\t\t\t\tdelete b2j[elt];\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tvar isjunk = this.isjunk;\n", + "\t\t\tvar junkdict = {};\n", + "\t\t\tif (isjunk) {\n", + "\t\t\t\tfor (var elt in populardict) {\n", + "\t\t\t\t\tif (populardict.hasOwnProperty(elt) && isjunk(elt)) {\n", + "\t\t\t\t\t\tjunkdict[elt] = 1;\n", + "\t\t\t\t\t\tdelete populardict[elt];\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\tfor (var elt in b2j) {\n", + "\t\t\t\t\tif (b2j.hasOwnProperty(elt) && isjunk(elt)) {\n", + "\t\t\t\t\t\tjunkdict[elt] = 1;\n", + "\t\t\t\t\t\tdelete 
b2j[elt];\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\n", + "\t\t\tthis.isbjunk = difflib.__isindict(junkdict);\n", + "\t\t\tthis.isbpopular = difflib.__isindict(populardict);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.find_longest_match = function (alo, ahi, blo, bhi) {\n", + "\t\t\tvar a = this.a;\n", + "\t\t\tvar b = this.b;\n", + "\t\t\tvar b2j = this.b2j;\n", + "\t\t\tvar isbjunk = this.isbjunk;\n", + "\t\t\tvar besti = alo;\n", + "\t\t\tvar bestj = blo;\n", + "\t\t\tvar bestsize = 0;\n", + "\t\t\tvar j = null;\n", + "\t\t\tvar k;\n", + "\t\n", + "\t\t\tvar j2len = {};\n", + "\t\t\tvar nothing = [];\n", + "\t\t\tfor (var i = alo; i < ahi; i++) {\n", + "\t\t\t\tvar newj2len = {};\n", + "\t\t\t\tvar jdict = difflib.__dictget(b2j, a[i], nothing);\n", + "\t\t\t\tfor (var jkey in jdict) {\n", + "\t\t\t\t\tif (jdict.hasOwnProperty(jkey)) {\n", + "\t\t\t\t\t\tj = jdict[jkey];\n", + "\t\t\t\t\t\tif (j < blo) continue;\n", + "\t\t\t\t\t\tif (j >= bhi) break;\n", + "\t\t\t\t\t\tnewj2len[j] = k = difflib.__dictget(j2len, j - 1, 0) + 1;\n", + "\t\t\t\t\t\tif (k > bestsize) {\n", + "\t\t\t\t\t\t\tbesti = i - k + 1;\n", + "\t\t\t\t\t\t\tbestj = j - k + 1;\n", + "\t\t\t\t\t\t\tbestsize = k;\n", + "\t\t\t\t\t\t}\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\tj2len = newj2len;\n", + "\t\t\t}\n", + "\t\n", + "\t\t\twhile (besti > alo && bestj > blo && !isbjunk(b[bestj - 1]) && a[besti - 1] == b[bestj - 1]) {\n", + "\t\t\t\tbesti--;\n", + "\t\t\t\tbestj--;\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\t\t\t\n", + "\t\t\twhile (besti + bestsize < ahi && bestj + bestsize < bhi &&\n", + "\t\t\t\t\t!isbjunk(b[bestj + bestsize]) &&\n", + "\t\t\t\t\ta[besti + bestsize] == b[bestj + bestsize]) {\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\n", + "\t\t\twhile (besti > alo && bestj > blo && isbjunk(b[bestj - 1]) && a[besti - 1] == b[bestj - 1]) {\n", + "\t\t\t\tbesti--;\n", + "\t\t\t\tbestj--;\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\t\t\n", + 
"\t\t\twhile (besti + bestsize < ahi && bestj + bestsize < bhi && isbjunk(b[bestj + bestsize]) &&\n", + "\t\t\t\t\ta[besti + bestsize] == b[bestj + bestsize]) {\n", + "\t\t\t\tbestsize++;\n", + "\t\t\t}\n", + "\t\n", + "\t\t\treturn [besti, bestj, bestsize];\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.get_matching_blocks = function () {\n", + "\t\t\tif (this.matching_blocks != null) return this.matching_blocks;\n", + "\t\t\tvar la = this.a.length;\n", + "\t\t\tvar lb = this.b.length;\n", + "\t\n", + "\t\t\tvar queue = [[0, la, 0, lb]];\n", + "\t\t\tvar matching_blocks = [];\n", + "\t\t\tvar alo, ahi, blo, bhi, qi, i, j, k, x;\n", + "\t\t\twhile (queue.length) {\n", + "\t\t\t\tqi = queue.pop();\n", + "\t\t\t\talo = qi[0];\n", + "\t\t\t\tahi = qi[1];\n", + "\t\t\t\tblo = qi[2];\n", + "\t\t\t\tbhi = qi[3];\n", + "\t\t\t\tx = this.find_longest_match(alo, ahi, blo, bhi);\n", + "\t\t\t\ti = x[0];\n", + "\t\t\t\tj = x[1];\n", + "\t\t\t\tk = x[2];\n", + "\t\n", + "\t\t\t\tif (k) {\n", + "\t\t\t\t\tmatching_blocks.push(x);\n", + "\t\t\t\t\tif (alo < i && blo < j)\n", + "\t\t\t\t\t\tqueue.push([alo, i, blo, j]);\n", + "\t\t\t\t\tif (i+k < ahi && j+k < bhi)\n", + "\t\t\t\t\t\tqueue.push([i + k, ahi, j + k, bhi]);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tmatching_blocks.sort(difflib.__ntuplecomp);\n", + "\t\n", + "\t\t\tvar i1 = 0, j1 = 0, k1 = 0, block = 0;\n", + "\t\t\tvar i2, j2, k2;\n", + "\t\t\tvar non_adjacent = [];\n", + "\t\t\tfor (var idx in matching_blocks) {\n", + "\t\t\t\tif (matching_blocks.hasOwnProperty(idx)) {\n", + "\t\t\t\t\tblock = matching_blocks[idx];\n", + "\t\t\t\t\ti2 = block[0];\n", + "\t\t\t\t\tj2 = block[1];\n", + "\t\t\t\t\tk2 = block[2];\n", + "\t\t\t\t\tif (i1 + k1 == i2 && j1 + k1 == j2) {\n", + "\t\t\t\t\t\tk1 += k2;\n", + "\t\t\t\t\t} else {\n", + "\t\t\t\t\t\tif (k1) non_adjacent.push([i1, j1, k1]);\n", + "\t\t\t\t\t\ti1 = i2;\n", + "\t\t\t\t\t\tj1 = j2;\n", + "\t\t\t\t\t\tk1 = k2;\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + 
"\t\t\t}\n", + "\t\t\t\n", + "\t\t\tif (k1) non_adjacent.push([i1, j1, k1]);\n", + "\t\n", + "\t\t\tnon_adjacent.push([la, lb, 0]);\n", + "\t\t\tthis.matching_blocks = non_adjacent;\n", + "\t\t\treturn this.matching_blocks;\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.get_opcodes = function () {\n", + "\t\t\tif (this.opcodes != null) return this.opcodes;\n", + "\t\t\tvar i = 0;\n", + "\t\t\tvar j = 0;\n", + "\t\t\tvar answer = [];\n", + "\t\t\tthis.opcodes = answer;\n", + "\t\t\tvar block, ai, bj, size, tag;\n", + "\t\t\tvar blocks = this.get_matching_blocks();\n", + "\t\t\tfor (var idx in blocks) {\n", + "\t\t\t\tif (blocks.hasOwnProperty(idx)) {\n", + "\t\t\t\t\tblock = blocks[idx];\n", + "\t\t\t\t\tai = block[0];\n", + "\t\t\t\t\tbj = block[1];\n", + "\t\t\t\t\tsize = block[2];\n", + "\t\t\t\t\ttag = '';\n", + "\t\t\t\t\tif (i < ai && j < bj) {\n", + "\t\t\t\t\t\ttag = 'replace';\n", + "\t\t\t\t\t} else if (i < ai) {\n", + "\t\t\t\t\t\ttag = 'delete';\n", + "\t\t\t\t\t} else if (j < bj) {\n", + "\t\t\t\t\t\ttag = 'insert';\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\tif (tag) answer.push([tag, i, ai, j, bj]);\n", + "\t\t\t\t\ti = ai + size;\n", + "\t\t\t\t\tj = bj + size;\n", + "\t\t\t\t\t\n", + "\t\t\t\t\tif (size) answer.push(['equal', ai, i, bj, j]);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\treturn answer;\n", + "\t\t}\n", + "\t\t\n", + "\t\t// this is a generator function in the python lib, which of course is not supported in javascript\n", + "\t\t// the reimplementation builds up the grouped opcodes into a list in their entirety and returns that.\n", + "\t\tthis.get_grouped_opcodes = function (n) {\n", + "\t\t\tif (!n) n = 3;\n", + "\t\t\tvar codes = this.get_opcodes();\n", + "\t\t\tif (!codes) codes = [[\"equal\", 0, 1, 0, 1]];\n", + "\t\t\tvar code, tag, i1, i2, j1, j2;\n", + "\t\t\tif (codes[0][0] == 'equal') {\n", + "\t\t\t\tcode = codes[0];\n", + "\t\t\t\ttag = code[0];\n", + "\t\t\t\ti1 = code[1];\n", + "\t\t\t\ti2 = code[2];\n", + "\t\t\t\tj1 = 
code[3];\n", + "\t\t\t\tj2 = code[4];\n", + "\t\t\t\tcodes[0] = [tag, Math.max(i1, i2 - n), i2, Math.max(j1, j2 - n), j2];\n", + "\t\t\t}\n", + "\t\t\tif (codes[codes.length - 1][0] == 'equal') {\n", + "\t\t\t\tcode = codes[codes.length - 1];\n", + "\t\t\t\ttag = code[0];\n", + "\t\t\t\ti1 = code[1];\n", + "\t\t\t\ti2 = code[2];\n", + "\t\t\t\tj1 = code[3];\n", + "\t\t\t\tj2 = code[4];\n", + "\t\t\t\tcodes[codes.length - 1] = [tag, i1, Math.min(i2, i1 + n), j1, Math.min(j2, j1 + n)];\n", + "\t\t\t}\n", + "\t\n", + "\t\t\tvar nn = n + n;\n", + "\t\t\tvar group = [];\n", + "\t\t\tvar groups = [];\n", + "\t\t\tfor (var idx in codes) {\n", + "\t\t\t\tif (codes.hasOwnProperty(idx)) {\n", + "\t\t\t\t\tcode = codes[idx];\n", + "\t\t\t\t\ttag = code[0];\n", + "\t\t\t\t\ti1 = code[1];\n", + "\t\t\t\t\ti2 = code[2];\n", + "\t\t\t\t\tj1 = code[3];\n", + "\t\t\t\t\tj2 = code[4];\n", + "\t\t\t\t\tif (tag == 'equal' && i2 - i1 > nn) {\n", + "\t\t\t\t\t\tgroup.push([tag, i1, Math.min(i2, i1 + n), j1, Math.min(j2, j1 + n)]);\n", + "\t\t\t\t\t\tgroups.push(group);\n", + "\t\t\t\t\t\tgroup = [];\n", + "\t\t\t\t\t\ti1 = Math.max(i1, i2-n);\n", + "\t\t\t\t\t\tj1 = Math.max(j1, j2-n);\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\t\n", + "\t\t\t\t\tgroup.push([tag, i1, i2, j1, j2]);\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\tif (group && !(group.length == 1 && group[0][0] == 'equal')) groups.push(group)\n", + "\t\t\t\n", + "\t\t\treturn groups;\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.ratio = function () {\n", + "\t\t\tmatches = difflib.__reduce(\n", + "\t\t\t\t\t\t\tfunction (sum, triple) { return sum + triple[triple.length - 1]; },\n", + "\t\t\t\t\t\t\tthis.get_matching_blocks(), 0);\n", + "\t\t\treturn difflib.__calculate_ratio(matches, this.a.length + this.b.length);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.quick_ratio = function () {\n", + "\t\t\tvar fullbcount, elt;\n", + "\t\t\tif (this.fullbcount == null) {\n", + "\t\t\t\tthis.fullbcount = fullbcount = {};\n", + 
"\t\t\t\tfor (var i = 0; i < this.b.length; i++) {\n", + "\t\t\t\t\telt = this.b[i];\n", + "\t\t\t\t\tfullbcount[elt] = difflib.__dictget(fullbcount, elt, 0) + 1;\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\tfullbcount = this.fullbcount;\n", + "\t\n", + "\t\t\tvar avail = {};\n", + "\t\t\tvar availhas = difflib.__isindict(avail);\n", + "\t\t\tvar matches = numb = 0;\n", + "\t\t\tfor (var i = 0; i < this.a.length; i++) {\n", + "\t\t\t\telt = this.a[i];\n", + "\t\t\t\tif (availhas(elt)) {\n", + "\t\t\t\t\tnumb = avail[elt];\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tnumb = difflib.__dictget(fullbcount, elt, 0);\n", + "\t\t\t\t}\n", + "\t\t\t\tavail[elt] = numb - 1;\n", + "\t\t\t\tif (numb > 0) matches++;\n", + "\t\t\t}\n", + "\t\t\t\n", + "\t\t\treturn difflib.__calculate_ratio(matches, this.a.length + this.b.length);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.real_quick_ratio = function () {\n", + "\t\t\tvar la = this.a.length;\n", + "\t\t\tvar lb = this.b.length;\n", + "\t\t\treturn _calculate_ratio(Math.min(la, lb), la + lb);\n", + "\t\t}\n", + "\t\t\n", + "\t\tthis.isjunk = isjunk ? 
isjunk : difflib.defaultJunkFunction;\n", + "\t\tthis.a = this.b = null;\n", + "\t\tthis.set_seqs(a, b);\n", + "\t}\n", + "};\n", + "\n", + "\n", + "\n", + "function diffUsingJS (viewType, contextSize, baseText, newText) {\n", + "\n", + " var byId = function (id) { return document.getElementById(id); },\n", + " base = difflib.stringAsLines(baseText),\n", + " newtxt = difflib.stringAsLines(newText),\n", + " sm = new difflib.SequenceMatcher(base, newtxt),\n", + " opcodes = sm.get_opcodes(),\n", + " diffoutputdiv = byId(\"diffid_2021-08-05_16_46_43_079488\");\n", + "\n", + " diffoutputdiv.innerHTML = \"\";\n", + " contextSize = contextSize || null;\n", + "\n", + " diffoutputdiv.appendChild(diffview.buildView({\n", + " baseTextLines: base,\n", + " newTextLines: newtxt,\n", + " opcodes: opcodes,\n", + " baseTextName: \"Base Text\",\n", + " newTextName: \"New Text\",\n", + " contextSize: contextSize,\n", + " viewType: viewType\n", + " }));\n", + "}\n", + "var tview=0;\n", + "var csize='';\n", + "var bt = 'def fft2d(mat, fft_length):\\n mat = mat[:fft_length[0], :fft_length[1]]\\n res = mat.copy()\\n \\n # first FFT\\n res = dft_real(res, fft_length=fft_length[1], transpose=True)\\n \\n # second FFT decomposed on FFT on real part and imaginary part\\n res2_real = dft_real(res[0], fft_length=fft_length[0], transpose=False)\\n res2_imag = dft_real(res[1], fft_length=fft_length[0], transpose=False) \\n res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\\n res = res2_real + res2_imag2\\n size = fft_length[1]//2 + 1\\n return res[:, :fft_length[0], :size]\\n';\n", + "var nt = 'def fft2d_d3(mat, fft_length):\\n mat = mat[:, :fft_length[-2], :fft_length[-1]]\\n res = mat.copy()\\n \\n # first FFT\\n res = dft_real_d3(res, fft_length=fft_length[-1], transpose=True)\\n \\n # second FFT decomposed on FFT on real part and imaginary part\\n res2_real = dft_real_d3(res[0], fft_length=fft_length[-2], transpose=False)\\n res2_imag = dft_real_d3(res[1], 
fft_length=fft_length[-2], transpose=False)\\n res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\\n res = res2_real + res2_imag2\\n size = fft_length[-1]//2 + 1\\n return res[:, :, :fft_length[-2], :size]\\n';\n", + "diffUsingJS(tview, csize, bt, nt) ;\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text1 = inspect.getsource(fft2d)\n", + "text2 = inspect.getsource(fft2d_d3)\n", + "%textdiff text1 text2" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "51e7a4f7", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "def onnx_rfft_3d_1d(x, fft_length=None, transpose=True):\n", + " if fft_length is None:\n", + " raise RuntimeError(\"fft_length must be specified.\")\n", + " \n", + " size = fft_length // 2 + 1\n", + " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", + " if transpose:\n", + " xt = npnx.transpose(x, (0, 2, 1))\n", + " a = cst[:, :, :fft_length]\n", + " b = xt[:, :fft_length, :]\n", + " a = npnx.expand_dims(a, 0)\n", + " b = npnx.expand_dims(b, 1)\n", + " res = npnx.matmul(a, b)\n", + " res2 = res[:, :size, :]\n", + " return npnx.transpose(res2, (1, 0, 3, 2))\n", + " else:\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:, :fft_length, :]\n", + " a = npnx.expand_dims(a, 0)\n", + " b = npnx.expand_dims(b, 1)\n", + " res = npnx.matmul(a, b)\n", + " return npnx.transpose(res, (1, 0, 2, 3)) \n", + " \n", + "\n", + "def onnx_rfft_3d_2d(x, fft_length=None):\n", + " mat = x[:, :fft_length[-2], :fft_length[-1]]\n", + " \n", + " # first FFT\n", + " res = onnx_rfft_3d_1d(mat, fft_length=fft_length[-1], transpose=True)\n", + " \n", + " # second FFT decomposed on FFT on real part and imaginary part\n", + " res2_real = onnx_rfft_3d_1d(res[0], fft_length=fft_length[0], transpose=False)\n", + " res2_imag = onnx_rfft_3d_1d(res[1], fft_length=fft_length[0], transpose=False) \n", + " res2_imag2 = 
npnx.vstack(-res2_imag[1:2], res2_imag[:1])\n", + " res = res2_real + res2_imag2\n", + " size = fft_length[1]//2 + 1\n", + " return res[:, :, :fft_length[-2], :size]\n", + "\n", + "\n", + "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", + "def onnx_rfft_2d_any(x, fft_length=None):\n", + " new_shape = npnx.concat(\n", + " numpy.array([-1], dtype=numpy.int64), x.shape[-2:], axis=0)\n", + " mat2 = x.reshape(new_shape)\n", + " f2 = onnx_rfft_3d_2d(mat2, fft_length)\n", + " new_shape = npnx.concat(\n", + " numpy.array([2], dtype=numpy.int64), x.shape[:-2], f2.shape[-2:])\n", + " return f2.reshape(new_shape)\n", + "\n", + "\n", + "shape = (3, 1, 4)\n", + "fft_length = (1, 4)\n", + "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + "fft2d_cus = fft2d_any(rnd, fft_length)\n", + "fft2d_onx = onnx_rfft_2d_any(rnd, fft_length=fft_length)\n", + "almost_equal(fft2d_cus, fft2d_onx)" ] - }, - "execution_count": 33, - "metadata": {}, - "output_type": "execute_result" + }, + { + "cell_type": "markdown", + "id": "37c45ae7", + "metadata": {}, + "source": [ + "Let's do the same comparison." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "11c1e596", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 2) output shape=(3, 4) or (2, 3, 1, 2)\n", + "DIS x.shape=(3, 1, 4) length=(1, 1) error=AssertionError('Mismatch max diff=1.0 > 1e-05.') output shape=(3, 4) or (2, 3, 1, 1)\n", + "OK x.shape=(5, 7) length=(5, 7) output shape=(3, 4) or (2, 5, 4)\n", + "OK x.shape=(5, 7) length=(1, 7) output shape=(3, 4) or (2, 1, 4)\n", + "OK x.shape=(5, 7) length=(2, 7) output shape=(3, 4) or (2, 2, 4)\n", + "OK x.shape=(5, 7) length=(5, 2) output shape=(3, 4) or (2, 5, 2)\n", + "OK x.shape=(5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n", + "OK x.shape=(3, 5, 7) length=(5, 7) output shape=(3, 4) or (2, 3, 5, 4)\n", + "OK x.shape=(3, 5, 7) length=(1, 7) output shape=(3, 4) or (2, 3, 1, 4)\n", + "OK x.shape=(3, 5, 7) length=(2, 7) output shape=(3, 4) or (2, 3, 2, 4)\n", + "OK x.shape=(3, 5, 7) length=(5, 2) output shape=(3, 4) or (2, 3, 5, 2)\n", + "OK x.shape=(3, 5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3, 3)\n", + "OK x.shape=(7, 5) length=(7, 5) output shape=(3, 4) or (2, 7, 3)\n", + "OK x.shape=(7, 5) length=(1, 5) output shape=(3, 4) or (2, 1, 3)\n", + "OK x.shape=(7, 5) length=(2, 5) output shape=(3, 4) or (2, 2, 3)\n", + "OK x.shape=(7, 5) length=(7, 2) output shape=(3, 4) or (2, 7, 2)\n", + "OK x.shape=(7, 5) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n" + ] + } + ], + "source": [ + "for shape in [(3, 1, 4), (5, 7), (3, 5, 7), (7, 5)]:\n", + " for fft_length in [shape[-2:], (1, shape[-1]),\n", + " (min(2, shape[-2]), shape[-1]),\n", + " (shape[-2], 2),\n", + " (min(3, shape[-2]), min(4, shape[-2]))]:\n", + " x = 
numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + " if len(fnp.shape) == 2:\n", + " fn= numpy.expand_dims(fnp, 0)\n", + " try:\n", + " cus = fft2d_any(x, fft_length)\n", + " except IndexError as e:\n", + " print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", + " continue\n", + " try:\n", + " onx = onnx_rfft_2d_any(x, fft_length=fft_length)\n", + " except IndexError as e:\n", + " print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", + " continue\n", + " try:\n", + " almost_equal(onx, cus)\n", + " except (AssertionError, IndexError) as e:\n", + " print(\"DIS x.shape=%r length=%r error=%r output shape=%r or %r\" % (\n", + " x.shape, fft_length, e, fnp.shape, cus.shape))\n", + " continue\n", + " print(\"OK x.shape=%r length=%r output shape=%r or %r\" % (\n", + " x.shape, fft_length, fnp.shape, cus.shape))" + ] + }, + { + "cell_type": "markdown", + "id": "d197467f", + "metadata": {}, + "source": [ + "There is one issue with ``fft_length=(1, 1)`` but that case is out of scope." + ] + }, + { + "cell_type": "markdown", + "id": "33b5897e", + "metadata": {}, + "source": [ + "### ONNX graph" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "d45e9a99", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "key = list(onnx_rfft_2d_any.signed_compiled)[0]\n", + "%onnxview onnx_rfft_2d_any.signed_compiled[key].compiled.onnx_" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "2ab7a3d0", + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"fft2d_any.onnx\", \"wb\") as f:\n", + " key = list(onnx_rfft_2d_any.signed_compiled)[0]\n", + " f.write(onnx_rfft_2d_any.signed_compiled[key].compiled.onnx_.SerializeToString())" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "9e5507f7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" } - ], - "source": [ - "key = list(onnx_rfft_2d_any.signed_compiled)[0]\n", - "%onnxview onnx_rfft_2d_any.signed_compiled[key].compiled.onnx_" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "2ab7a3d0", - "metadata": {}, - "outputs": [], - "source": [ - "with open(\"fft2d_any.onnx\", \"wb\") as f:\n", - " key = list(onnx_rfft_2d_any.signed_compiled)[0]\n", - " f.write(onnx_rfft_2d_any.signed_compiled[key].compiled.onnx_.SerializeToString())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9e5507f7", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - 
"nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.5" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_custom.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_custom.py index cf229f3ab..b5bb0b926 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_custom.py +++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_custom.py @@ -98,7 +98,8 @@ def test_onnxt_runtime_fft(self): if dim == 1: X = numpy.arange(16).astype(numpy.float32) elif dim == 2: - X = numpy.arange(48).astype(numpy.float32).reshape((3, -1)) + X = numpy.arange(48).astype( + numpy.float32).reshape((3, -1)) Y = numpy.fft.fft(X.astype(numpy.float32), axis=axis) onx = OnnxFFT('X', output_names=['Y'], @@ -125,7 +126,8 @@ def test_onnxt_runtime_fft(self): if dim == 1: X = numpy.arange(16).astype(numpy.float32) elif dim == 2: - X = numpy.arange(48).astype(numpy.float32).reshape((3, -1)) + X = numpy.arange(48).astype( + numpy.float32).reshape((3, -1)) Y = numpy.fft.fft(X.astype(numpy.float32), 8, axis=axis) onx = OnnxFFT('X', numpy.array([8], dtype=numpy.int64), @@ -155,7 +157,8 @@ def test_onnxt_runtime_rfft(self): if dim == 1: X = numpy.arange(16).astype(numpy.float32) elif dim == 2: - X = numpy.arange(48).astype(numpy.float32).reshape((3, -1)) + X = numpy.arange(48).astype( + numpy.float32).reshape((3, -1)) Y = numpy.fft.rfft(X.astype(numpy.float32), axis=axis) onx = OnnxRFFT('X', output_names=['Y'], @@ -182,7 +185,8 @@ def test_onnxt_runtime_rfft(self): if dim == 1: X = numpy.arange(16).astype(numpy.float32) elif dim == 2: - X = numpy.arange(48).astype(numpy.float32).reshape((3, -1)) + X = numpy.arange(48).astype( + numpy.float32).reshape((3, -1)) Y = numpy.fft.rfft(X.astype(numpy.float32), 8, axis=axis) onx = OnnxRFFT('X', numpy.array([8], dtype=numpy.int64), @@ -210,7 +214,8 @@ def test_onnxt_runtime_fft2d(self): if dim == 1: X = 
numpy.arange(16).astype(numpy.float32) elif dim == 2: - X = numpy.arange(48).astype(numpy.float32).reshape((3, -1)) + X = numpy.arange(48).astype( + numpy.float32).reshape((3, -1)) Y = numpy.fft.fft2(X.astype(numpy.float32), axes=axis) if axis is not None: @@ -239,8 +244,10 @@ def test_onnxt_runtime_fft2d(self): if dim == 1: X = numpy.arange(16).astype(numpy.float32) elif dim == 2: - X = numpy.arange(48).astype(numpy.float32).reshape((3, -1)) - Y = numpy.fft.fft2(X.astype(numpy.float32), (8, 8), axes=axis) + X = numpy.arange(48).astype( + numpy.float32).reshape((3, -1)) + Y = numpy.fft.fft2( + X.astype(numpy.float32), (8, 8), axes=axis) if axis is not None: onx = OnnxFFT2D('X', numpy.array([8, 8], dtype=numpy.int64), diff --git a/_unittests/ut_tools/test_export_onnx.py b/_unittests/ut_tools/test_export_onnx.py index 0461ca00e..5c80b9945 100644 --- a/_unittests/ut_tools/test_export_onnx.py +++ b/_unittests/ut_tools/test_export_onnx.py @@ -3,10 +3,12 @@ """ import os import unittest +import collections +import inspect from io import StringIO from contextlib import redirect_stdout, redirect_stderr import numpy -from onnx import numpy_helper +from onnx import numpy_helper, helper from onnx.helper import ( make_model, make_node, set_model_props, make_tensor, make_graph, make_tensor_value_info) @@ -17,8 +19,8 @@ class TestExportOnnx(ExtTestCase): - - def verify(self, content, existing_loc=None): + + def verify(self, content): try: left, __ = verify_code(content, exc=False) except SyntaxError as e: @@ -28,7 +30,7 @@ def verify(self, content, existing_loc=None): "" % (e, content)) from e # execution - try: + try: obj = compile(content, '', 'exec') except SyntaxError as e: raise AssertionError( @@ -39,21 +41,22 @@ def verify(self, content, existing_loc=None): loc = {'numpy_helper': numpy_helper, 'make_model': make_model, 'make_node': make_node, - 'set_model_props': set_model_props, + 'set_model_props': set_model_props, 'make_tensor': make_tensor, 'make_graph': make_graph, - 
'make_tensor_value_info': make_tensor_value_info} - if existing_loc is not None: - loc.update(existing_loc) - glo.update(existing_loc) + 'make_tensor_value_info': make_tensor_value_info, + 'print': print, 'sorted': sorted, + 'collections': collections, 'inspect': inspect} out = StringIO() err = StringIO() - self.assertLess(len(left), 5) + if len(left) >= 5: + raise AssertionError( + "Too many unknown symbols: %r." % left) with redirect_stdout(out): with redirect_stderr(err): try: - exec(obj, glo, loc) # pylint: disable=W0122 + exec(obj, glo, loc) # pylint: disable=W0122 except Exception as e: raise AssertionError( "Unable to execute a script due to %r. " @@ -72,23 +75,25 @@ def test_export_onnx(self): x = numpy.random.randn(3, 1, 4).astype(numpy.float32) y = oinf0.run({'x': x}) - - new_onnx = export2onnx(os.path.join(folder, name)) - glo, loc = self.verify(new_onnx) + + new_onnx = export2onnx( + os.path.join(folder, name), name="FFT2D") + _, loc = self.verify(new_onnx) model = loc['onnx_model'] oinf = OnnxInference(model) - y1 = oinf0.run({'x': x}) + y1 = oinf.run({'x': x}) - new_onnx = export2onnx(os.path.join(folder, name), verbose=False) - glo, loc = self.verify(new_onnx) + new_onnx = export2onnx( + os.path.join(folder, name), verbose=False) + _, loc = self.verify(new_onnx) model = loc['onnx_model'] oinf = OnnxInference(model) - y2 = oinf0.run({'x': x}) - + y2 = oinf.run({'x': x}) + self.assertEqualArray(y['y'], y1['y']) self.assertEqualArray(y['y'], y2['y']) - def verify_tf(self, content, existing_loc=None): + def verify_tf(self, content): try: left, __ = verify_code(content, exc=False) except SyntaxError as e: @@ -98,7 +103,7 @@ def verify_tf(self, content, existing_loc=None): "" % (e, content)) from e # execution - try: + try: obj = compile(content, '', 'exec') except SyntaxError as e: raise AssertionError( @@ -106,24 +111,21 @@ def verify_tf(self, content, existing_loc=None): "\n--CODE--\n%s" "" % (e, content)) from e glo = globals().copy() - loc = 
{'numpy_helper': numpy_helper, - 'make_model': make_model, - 'make_node': make_node, - 'set_model_props': set_model_props, - 'make_tensor': make_tensor, - 'make_graph': make_graph, - 'make_tensor_value_info': make_tensor_value_info} - if existing_loc is not None: - loc.update(existing_loc) - glo.update(existing_loc) + loc = {'numpy': numpy, 'print': print, + 'dict': dict, 'sorted': sorted, 'list': list, + 'print': print, 'sorted': sorted, + 'collections': collections, 'inspect': inspect, + 'helper': helper} out = StringIO() err = StringIO() - self.assertLess(len(left), 5) + if len(left) >= 14: + raise AssertionError( + "Too many unknown symbols: %r." % left) with redirect_stdout(out): with redirect_stderr(err): try: - exec(obj, glo, loc) # pylint: disable=W0122 + exec(obj, glo, loc) # pylint: disable=W0122 except Exception as e: raise AssertionError( "Unable to execute a script due to %r. " @@ -138,11 +140,12 @@ def test_export2tf2onnx(self): names = ["fft2d_any.onnx"] for name in names: with self.subTest(name=name): - new_onnx = export2tf2onnx(os.path.join(folder, name)) - print(new_onnx) - self.verify_tf(new_onnx) - - + new_onnx = export2tf2onnx( + os.path.join(folder, name), name="FFT2D") + _, loc = self.verify_tf(new_onnx) + model = loc['onnx_model'] + self.assertIn('op_type: "FFT2D"', str(model)) + # print(model) if __name__ == "__main__": diff --git a/mlprodict/onnx_tools/onnx_export.py b/mlprodict/onnx_tools/onnx_export.py index ce966269f..a9a94664e 100644 --- a/mlprodict/onnx_tools/onnx_export.py +++ b/mlprodict/onnx_tools/onnx_export.py @@ -1,7 +1,7 @@ """ @file @brief Exports an ONNX graph in a way it can we created again -with a python script. It relies on :epkg:`jinja2` and :epkg:`autopep8`. +with a python script. It relies on :epkg:`jinja2` and :epkg:`autopep8`. .. 
versionadded:: 0.7 """ @@ -29,21 +29,25 @@ def create_model(): nodes = [] inputs = [] outputs = [] - + # opsets print('[opsets]') # verbose opsets = {{ opsets }} target_opset = {{ target_opset }} - + # initializers print('[initializers]') # verbose {% for name, value in initializers: %} + {% if len(value.shape) == 0: %} + value = numpy.array({{ value }}, dtype=numpy.{{ value.dtype }}) + {% else %} list_value = {{ value.ravel().tolist() }} - value = numpy.array(list_value, dtype=numpy.{{ value.dtype }}).reshape({{ value.shape }}) + value = numpy.array(list_value, dtype=numpy.{{ value.dtype }}){% if len(value.shape) > 1: %}.reshape({{ value.shape }}){% endif %} + {% endif %} tensor = numpy_helper.from_array(value, name='{{ name }}') initializers.append(tensor) {% endfor %} - + # inputs print('[inputs]') # verbose {% for name, type, shape in inputs: %} @@ -51,13 +55,13 @@ def create_model(): inputs.append(value) {% endfor %} - # inputs + # outputs print('[outputs]') # verbose {% for name, type, shape in outputs: %} value = make_tensor_value_info('{{ name }}', {{ type }}, {{ shape }}) outputs.append(value) {% endfor %} - + # nodes print('[nodes]') # verbose {% for node in nodes: %} @@ -84,48 +88,112 @@ def create_model(): onnx_model.model_version = {{ model_version }} onnx_model.doc_string = '{{ doc_string }}' set_model_props(onnx_model, {{ metadata }}) - + # opsets - print('[graph]') # verbose + print('[opset]') # verbose + del onnx_model.opset_import[:] # pylint: disable=E1101 for dom, value in opsets.items(): op_set = onnx_model.opset_import.add() op_set.domain = dom op_set.version = value return onnx_model - + + onnx_model = create_model() """) _tf2onnx_templates = dedent(""" - @tf_op("MyOp") - class ConvertMyOp: - + import inspect + import collections + import numpy + from onnx import AttributeProto + from onnx.helper import ( + make_model, make_node, set_model_props, make_tensor, make_graph, + make_tensor_value_info) + try: + from utils import make_name + except 
ImportError: + + _make_name_id = 0 + + + def make_name(name): + global _make_name_id + name = "%s_%d" % (name, _make_name_id) + _make_name_id += 1 + return name + + + class tf_op: + _OPSETS = collections.OrderedDict() + + def __init__(self, name, domain='', **kwargs): + if not isinstance(name, list): + name = [name] + self.names = name + self.domain = domain + self.kwargs = kwargs + + def __call__(self, func): + for ke, va in inspect.getmembers(func, inspect.ismethod): + if ke.startswith("version_"): + version = int(ke.replace("version_", "")) + self.register_handler(va, version, self.names, self.domain, self.kwargs) + return func + + def register_handler(self, func, version, names, domain, kwargs): + opset = tf_op._OPSETS.get(domain) + if not opset: + opset = [] + tf_op._OPSETS[domain] = opset + while version >= len(opset): + opset.append({}) + opset_dict = opset[version] + for name in names: + opset_dict[name] = (func, kwargs) + + + @tf_op("{{ name }}") + class Convert{{ name }}Op: + supported_dtypes = [ numpy.float32, ] - + @classmethod def any_version(cls, opset, ctx, node, **kwargs): ''' - Documentation. + Converter for ``{{ name }}``. 
+ + * producer: {{ producer_name }} + * version: {{ model_version }} + * description: {{ doc_string }} + {%- for key, val in sorted(metadata.items()): -%} + * {{ key }}: {{ val }} + {%- endfor %} ''' + oldnode = node input_name = node.input[0] onnx_dtype = ctx.get_dtype(input_name) - utils.make_sure(onnx_dtype in ConvertOp.supported_dtypes, "Unsupported input type.") + utils.make_sure(onnx_dtype in Convert{{ name }}Op.supported_dtypes, "Unsupported input type.") shape = ctx.get_shape(input_name) - space_names = {} + vars = {} # initializers print('[initializers]') # verbose {% for name, value in initializers: %} - value = numpy.array(list_value, dtype=numpy.{{ value.dtype }}).reshape({{ value.shape }}) - r_{{ name }} = ctx.make_const(name=utils.make_name('init_{{ name }}'), np_val=value) - space_names['{{ name }}'] = r_{{ name }}.name - initializers.append(tensor) + {% if len(value.shape) == 0: %} + value = numpy.array({{ value }}, dtype=numpy.{{ value.dtype }}) + {% else %} + list_value = {{ value.ravel().tolist() }} + value = numpy.array(list_value, dtype=numpy.{{ value.dtype }}){% if len(value.shape) > 1: %}.reshape({{ value.shape }}){% endif %} + {% endif %} + r_{{ name }} = ctx.make_const(name=make_name('init_{{ name }}'), np_val=value) + vars['{{ name }}'] = r_{{ name }}.name {% endfor %} - + # nodes print('[nodes]') # verbose {% for node in nodes: %} @@ -133,35 +201,147 @@ def any_version(cls, opset, ctx, node, **kwargs): {%- for name, value in node['attributes']: -%} {{ name }}={{ value }}, {%- endfor -%}) - inputs = [{% for name in node['inputs']: -%}space_names['{{ name }}'], {%- endfor %}] + inputs = [{% for name in node['inputs']: -%}vars['{{ name }}'], {%- endfor %}] node = ctx.make_node( - {{ node['op_type'] }}, inputs=inputs, attr=attr,{% if node['domain']: -%} domain='{{ node['domain'] }}', {% endif %} - name=utils.make_name('{{ node['name'] }}')) + '{{ node['op_type'] }}', inputs=inputs, attr=attr,{% if node['domain']: -%} domain='{{ node['domain'] 
}}', {% endif %} + name=make_name('{{ node['name'] }}')) {% for i, name in enumerate(node['outputs']): -%} - space_names['{{ name }}'] = node.output[{{ i }}] + vars['{{ name }}'] = node.output[{{ i }}] {%- endfor %} - nodes.append(node) {% endfor %} -""") + # finalize + ctx.replace_all_inputs(oldnode.output[0], node.output[0]) + ctx.remove_node(oldnode.name) + + @classmethod + def version_13(cls, ctx, node, **kwargs): + return cls.any_version(13, ctx, node, **kwargs) -def export_template(model_onnx, templates, opset=None, verbose=True): + def create_model(): + inputs = [] + outputs = [] + + # inputs + print('[inputs]') # verbose + {% for name, type, shape in inputs: %} + value = make_tensor_value_info('{{ name }}', {{ type }}, {{ shape }}) + inputs.append(value) + {% endfor %} + + # outputs + print('[outputs]') # verbose + {% for name, type, shape in outputs: %} + value = make_tensor_value_info('{{ name }}', {{ type }}, {{ shape }}) + outputs.append(value) + {% endfor %} + + inames = [i.name for i in inputs] + onames = [i.name for i in outputs] + node = make_node('{{ name }}', inames, onames, name='{{ name }}') + + # graph + print('[graph]') # verbose + graph = make_graph([node], '{{ name }}', inputs, outputs) + onnx_model = make_model(graph) + onnx_model.ir_version = {{ ir_version }} + onnx_model.producer_name = '{{ producer_name }}' + onnx_model.producer_version = '{{ producer_version }}' + onnx_model.domain = '{{ domain }}' + onnx_model.model_version = {{ model_version }} + onnx_model.doc_string = '{{ doc_string }}' + set_model_props(onnx_model, {{ metadata }}) + + # opsets + print('[opset]') # verbose + opsets = {{ opsets }} + del onnx_model.opset_import[:] # pylint: disable=E1101 + for dom, value in opsets.items(): + op_set = onnx_model.opset_import.add() + op_set.domain = dom + op_set.version = value + + return onnx_model + + + class Rewrite: + + def __init__(self, onnx_model, tf_op): + self._onnx_model = onnx_model + self._nodes = list(onnx_model.graph.node) 
+ self._tf_op = tf_op + + def make_node(self, op_type, inputs, attr=None, outputs=None, + name=None, domain=''): + if attr is None: + attr = {} + if name is None: + name = make_name(op_type) + + if outputs is None: + outputs = [name + ":" + str(i) for i in range(output_count)] + + output_count = len(outputs) + raw_attr = {} + onnx_attrs = [] + for a, v in attr.items(): + if isinstance(v, AttributeProto): + onnx_attrs.append(v) + else: + raw_attr[a] = v + + n = self.get_node_by_name(name) + + for o in outputs: + n = self.get_node_by_output_in_current_graph(o) + + onnx_node = make_node( + op_type, inputs, outputs, name=name, domain=domain, **raw_attr) + + self._nodes.append(onnx_node) + return node + + def rewrite(self): + print('[rewrite]') # verbose + done = {} + modif = 1 + while modif > 0: + modif = 0 + for node in self._nodes: + if done.get(node.name, False): + continue + domain = node.domain + if domain not in self._tf_op._OPSETS: + continue + rews = self._tf_op._OPSETS[domain] + # look for an opset + # call the rewriter + + + + onnx_model = create_model() + onnx_rewritten = Rewrite(onnx_model, tf_op).rewrite() +""") + + +def export_template(model_onnx, templates, opset=None, verbose=True, name=None): """ Exports an ONNX model to the onnx syntax. 
- + :param model_onnx: string or ONNX graph :param templates: exporting templates :param opset: opset to export to (None to select the one from the graph) :param verbose: insert prints + :param name: to overwrite onnx name :return: python code """ # containers context = {} # opset - opsets = {} + opsets = {} for oimp in model_onnx.opset_import: if oimp.domain == '' and opset is None: opsets[oimp.domain] = oimp.version @@ -193,7 +373,7 @@ def export_template(model_onnx, templates, opset=None, verbose=True): dims = tuple(t.shape.dim) outputs.append((inp.name, t.elem_type, dims)) context['outputs'] = outputs - + # node nodes = [] for node in model_onnx.graph.node: @@ -215,7 +395,7 @@ def export_template(model_onnx, templates, opset=None, verbose=True): context['nodes'] = nodes # graph - context['name'] = model_onnx.graph.name + context['name'] = name or model_onnx.graph.name context['ir_version'] = model_onnx.ir_version context['producer_name'] = model_onnx.producer_name context['domain'] = model_onnx.domain @@ -225,42 +405,47 @@ def export_template(model_onnx, templates, opset=None, verbose=True): # final template = Template(templates) - final = template.render(enumerate=enumerate, **context) - + final = template.render( + enumerate=enumerate, sorted=sorted, len=len, + **context) + if not verbose: rows = final.split("\n") final = "\n".join(_ for _ in rows if not _.endswith("# verbose")) return autopep8.fix_code(final) -def export2onnx(model_onnx, opset=None, verbose=True): +def export2onnx(model_onnx, opset=None, verbose=True, name=None): """ Exports an ONNX model to the :epkg:`onnx` syntax. 
- + :param model_onnx: string or ONNX graph :param opset: opset to export to (None to select the one from the graph) :param verbose: inserts prints + :param name: to overwrite onnx name :return: python code """ if isinstance(model_onnx, str): model_onnx = onnx.load(model_onnx) - return export_template(model_onnx, templates=_onnx_templates, opset=opset, verbose=verbose) - + return export_template(model_onnx, templates=_onnx_templates, + opset=opset, verbose=verbose, name=name) -def export2tf2onnx(model_onnx, opset=None, verbose=True): +def export2tf2onnx(model_onnx, opset=None, verbose=True, name=None): """ Exports an ONNX model to the e:pkg:`tensorflow-onnx` syntax. - + :param model_onnx: string or ONNX graph :param opset: opset to export to (None to select the one from the graph) :param verbose: inserts prints + :param name: to overwrite onnx name :return: python code """ if isinstance(model_onnx, str): model_onnx = onnx.load(model_onnx) - return export_template(model_onnx, templates=_tf2onnx_templates, opset=opset, verbose=verbose) + return export_template(model_onnx, templates=_tf2onnx_templates, + opset=opset, verbose=verbose, name=name) diff --git a/mlprodict/onnxrt/onnx_inference.py b/mlprodict/onnxrt/onnx_inference.py index 9a9a7bd7e..128bbf10a 100644 --- a/mlprodict/onnxrt/onnx_inference.py +++ b/mlprodict/onnxrt/onnx_inference.py @@ -822,7 +822,8 @@ def dispsimple(arr): mini = numpy_min(values[k]) maxi = numpy_max(values[k]) fLOG("+kr{}'{}': {} (dtype={} min={} max={}{})".format( - "=" if len(values[k].shape) == 0 or min(values[k].shape) > 0 else "*", + "=" if len(values[k].shape) == 0 or min( + values[k].shape) > 0 else "*", name, values[k].shape, values[k].dtype, mini, maxi, ' sparse' if isinstance(values[k], coo_matrix) else '')) diff --git a/mlprodict/testing/verify_code.py b/mlprodict/testing/verify_code.py index 929fde3f5..6d2f18dbe 100644 --- a/mlprodict/testing/verify_code.py +++ b/mlprodict/testing/verify_code.py @@ -4,6 +4,9 @@ before finalizing 
the benchmark. """ import ast +import collections +import inspect +import numpy class ImperfectPythonCode(RuntimeError): @@ -29,7 +32,11 @@ def verify_code(source, exc=True): imports = v._imports names = v._names args = v._args - known = {'super': None, 'ImportError': None} + known = {'super': None, 'ImportError': None, 'print': print, + 'classmethod': classmethod, 'numpy': numpy, + 'dict': dict, 'list': list, 'sorted': sorted, 'len': len, + 'collections': collections, 'inspect': inspect, 'range': range, + 'int': int, 'str': str, 'isinstance': isinstance} for kn in imports: known[kn[0]] = kn for kn in assign: From 6fef055bc31b994453b47290a45adfd7263e6d79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sat, 7 Aug 2021 13:23:54 +0200 Subject: [PATCH 09/12] Add code to recompose a python script creating an onnx graph --- _unittests/ut_cli/test_cli_onnx_code.py | 49 ++ _unittests/ut_tools/test_export_onnx.py | 504 +++++++++++++++++- mlprodict/__main__.py | 5 +- mlprodict/cli/__init__.py | 1 + mlprodict/cli/onnx_code.py | 57 ++ mlprodict/onnx_tools/exports/__init__.py | 0 .../onnx_tools/exports/tf2onnx_helper.py | 364 +++++++++++++ mlprodict/onnx_tools/onnx2py_helper.py | 38 ++ mlprodict/onnx_tools/onnx_export.py | 167 ++---- mlprodict/onnx_tools/onnx_tools.py | 87 +++ mlprodict/onnxrt/ops_cpu/_op_helper.py | 32 +- 11 files changed, 1155 insertions(+), 149 deletions(-) create mode 100644 _unittests/ut_cli/test_cli_onnx_code.py create mode 100644 mlprodict/cli/onnx_code.py create mode 100644 mlprodict/onnx_tools/exports/__init__.py create mode 100644 mlprodict/onnx_tools/exports/tf2onnx_helper.py diff --git a/_unittests/ut_cli/test_cli_onnx_code.py b/_unittests/ut_cli/test_cli_onnx_code.py new file mode 100644 index 000000000..9a29c69e9 --- /dev/null +++ b/_unittests/ut_cli/test_cli_onnx_code.py @@ -0,0 +1,49 @@ +""" +@brief test tree node (time=10s) +""" +import os +import unittest +from pyquickhelper.loghelper import BufferedPrint +from 
pyquickhelper.pycode import ExtTestCase, get_temp_folder +from mlprodict.__main__ import main + + +class TestCliOnnxCode(ExtTestCase): + + def test_cli_onnx_code(self): + st = BufferedPrint() + main(args=["onnx_code", "--help"], fLOG=st.fprint) + res = str(st) + self.assertIn("verbose", res) + + def test_cli_onnx_code_onnx(self): + temp = get_temp_folder(__file__, "temp_cli_onnx_code_onnx") + name = os.path.join( + temp, "..", "..", "ut_tools", "data", "fft2d_any.onnx") + self.assertExists(name) + output = os.path.join(temp, "code_onnx.py") + st = BufferedPrint() + main(args=["onnx_code", "--filename", name, + "--output", output, "--verbose", "1"], fLOG=st.fprint) + self.assertExists(output) + with open(output, "r", encoding='utf-8') as f: + content = f.read() + self.assertIn("create_model()", content) + + def test_cli_onnx_code_tf2onnx(self): + temp = get_temp_folder(__file__, "temp_cli_onnx_code_tf2onnx") + name = os.path.join( + temp, "..", "..", "ut_tools", "data", "fft2d_any.onnx") + self.assertExists(name) + output = os.path.join(temp, "code_tf2onnx.py") + st = BufferedPrint() + main(args=["onnx_code", "--filename", name, '--format', 'tf2onnx', + "--output", output, "--verbose", "1"], fLOG=st.fprint) + self.assertExists(output) + with open(output, "r", encoding='utf-8') as f: + content = f.read() + self.assertIn("tf_op", content) + + +if __name__ == "__main__": + unittest.main() diff --git a/_unittests/ut_tools/test_export_onnx.py b/_unittests/ut_tools/test_export_onnx.py index 5c80b9945..0444eeb6d 100644 --- a/_unittests/ut_tools/test_export_onnx.py +++ b/_unittests/ut_tools/test_export_onnx.py @@ -1,5 +1,5 @@ """ -@brief test log(time=3s) +@brief test log(time=5s) """ import os import unittest @@ -16,6 +16,470 @@ from mlprodict.onnx_tools.onnx_export import export2onnx, export2tf2onnx from mlprodict.testing.verify_code import verify_code from mlprodict.onnxrt import OnnxInference +from mlprodict.onnx_tools.exports.tf2onnx_helper import make_sure, make_name 
class ConvertFFT2DOp:
    """
    Converter replacing a node ``FFT2D`` by a subgraph made of
    standard ONNX operators. The constants and the nodes are described
    in two tables and created in the order of these tables.
    """

    supported_dtypes = [
        numpy.float32,
    ]

    @classmethod
    def any_version(cls, opset, ctx, node, **kwargs):
        '''
        Converter for ``FFT2D``.

        * producer: skl2onnx
        * version: 0
        * description:

        :param opset: opset the converter is called for (not used)
        :param ctx: object exposing ``get_dtype``, ``make_const``,
            ``make_node``, ``replace_all_inputs``, ``remove_node``
        :param node: node to replace
        :param kwargs: additional parameters (not used)
        '''
        replaced = node
        dtype_in = ctx.get_dtype(node.input[0])
        make_sure(dtype_in in ConvertFFT2DOp.supported_dtypes,
                  "Unsupported input type.")

        # Maps a result name in the original graph to its name in the
        # graph being built; inputs keep their name.
        names = {}
        for input_name in node.input:
            names[input_name] = input_name

        f32, i64 = numpy.float32, numpy.int64

        # (name, data, dtype, shape), shape is None if no reshape is needed.
        constants = [
            ('Un_Unsqueezecst', [1.0, 0.0], f32, (2, 1, 1)),
            ('Un_Unsqueezecst1', [0], i64, None),
            ('Un_Unsqueezecst2',
             [1.0, 1.0, 1.0, 1.0, 1.0, 6.123234262925839e-17,
              -1.0, -1.8369701465288538e-16, 1.0, -1.0, 1.0, -1.0, 1.0,
              -1.8369701465288538e-16, -1.0, 5.510910704284357e-16, 0.0,
              0.0, 0.0, 0.0, 0.0, -1.0, -1.2246468525851679e-16, 1.0, 0.0,
              -1.2246468525851679e-16, 2.4492937051703357e-16,
              -3.6739402930577075e-16, 0.0, 1.0,
              -3.6739402930577075e-16, -1.0],
             f32, (2, 4, 4)),
            ('Co_Concatcst', [-1], i64, None),
            ('Sl_Slicecst', [-2], i64, None),
            ('Ga_Gathercst', 0, i64, None),
            ('Sl_Slicecst2', [0, 0], i64, None),
            ('Sl_Slicecst3', [1, 4], i64, None),
            ('Sl_Slicecst4', [1, 2], i64, None),
            ('Sl_Slicecst6', [4], i64, None),
            ('Sl_Slicecst7', [1], i64, None),
            ('Sl_Slicecst9', [3], i64, None),
            ('Ga_Gathercst2', 1, i64, None),
            ('Sl_Slicecst18', [2], i64, None),
            ('Sl_Slicecst24', [1, 3], i64, None),
            ('Sl_Slicecst25', [2, 3], i64, None),
        ]

        # (operator, inputs, attributes, node name, result name)
        steps = [
            ('Unsqueeze', ('Un_Unsqueezecst', 'Un_Unsqueezecst1'),
             {}, 'Un_Unsqueeze', 'Un_expanded0'),
            ('Unsqueeze', ('Un_Unsqueezecst2', 'Un_Unsqueezecst1'),
             {}, 'Un_Unsqueeze1', 'Un_expanded03'),
            ('Shape', ('x', ), {}, 'Sh_Shape', 'Sh_shape0'),
            ('Shape', ('Sh_shape0', ), {}, 'Sh_Shape1', 'Sh_shape01'),
            ('Gather', ('Sh_shape01', 'Ga_Gathercst'),
             {'axis': 0}, 'Ga_Gather', 'Ga_output01'),
            ('Unsqueeze', ('Ga_output01', 'Un_Unsqueezecst1'),
             {}, 'Un_Unsqueeze2', 'Un_expanded05'),
            ('Concat', ('Un_expanded05', ),
             {'axis': 0}, 'Co_Concat', 'Co_concat_result01'),
            ('Slice', ('Sh_shape0', 'Sl_Slicecst',
                       'Co_concat_result01', 'Un_Unsqueezecst1'),
             {}, 'Sl_Slice', 'Sl_output05'),
            ('Concat', ('Co_Concatcst', 'Sl_output05'),
             {'axis': 0}, 'Co_Concat1', 'Co_concat_result0'),
            ('Reshape', ('x', 'Co_concat_result0'),
             {}, 'Re_Reshape', 'Re_reshaped0'),
            ('Slice', ('Re_reshaped0', 'Sl_Slicecst2',
                       'Sl_Slicecst3', 'Sl_Slicecst4'),
             {}, 'Sl_Slice1', 'Sl_output04'),
            ('Transpose', ('Sl_output04', ),
             {'perm': [0, 2, 1]}, 'Tr_Transpose', 'Tr_transposed02'),
            ('Slice', ('Tr_transposed02', 'Un_Unsqueezecst1',
                       'Sl_Slicecst6', 'Sl_Slicecst7'),
             {}, 'Sl_Slice2', 'Sl_output03'),
            ('Unsqueeze', ('Sl_output03', 'Sl_Slicecst7'),
             {}, 'Un_Unsqueeze3', 'Un_expanded04'),
            ('MatMul', ('Un_expanded03', 'Un_expanded04'),
             {}, 'Ma_MatMul', 'Ma_Y01'),
            ('Slice', ('Ma_Y01', 'Un_Unsqueezecst1',
                       'Sl_Slicecst9', 'Sl_Slicecst7'),
             {}, 'Sl_Slice3', 'Sl_output02'),
            ('Transpose', ('Sl_output02', ),
             {'perm': [1, 0, 3, 2]}, 'Tr_Transpose1', 'Tr_transposed01'),
            ('Gather', ('Tr_transposed01', 'Ga_Gathercst'),
             {'axis': 0}, 'Ga_Gather1', 'Ga_output0'),
            ('Slice', ('Ga_output0', 'Un_Unsqueezecst1',
                       'Sl_Slicecst7', 'Sl_Slicecst7'),
             {}, 'Sl_Slice4', 'Sl_output01'),
            ('Unsqueeze', ('Sl_output01', 'Sl_Slicecst7'),
             {}, 'Un_Unsqueeze4', 'Un_expanded02'),
            ('MatMul', ('Un_expanded0', 'Un_expanded02'),
             {}, 'Ma_MatMul1', 'Ma_Y0'),
            ('Transpose', ('Ma_Y0', ),
             {'perm': [1, 0, 2, 3]}, 'Tr_Transpose2', 'Tr_transposed0'),
            ('Gather', ('Tr_transposed01', 'Ga_Gathercst2'),
             {'axis': 0}, 'Ga_Gather2', 'Ga_output03'),
            ('Slice', ('Ga_output03', 'Un_Unsqueezecst1',
                       'Sl_Slicecst7', 'Sl_Slicecst7'),
             {}, 'Sl_Slice5', 'Sl_output07'),
            ('Unsqueeze', ('Sl_output07', 'Sl_Slicecst7'),
             {}, 'Un_Unsqueeze6', 'Un_expanded07'),
            ('MatMul', ('Un_expanded0', 'Un_expanded07'),
             {}, 'Ma_MatMul2', 'Ma_Y03'),
            ('Transpose', ('Ma_Y03', ),
             {'perm': [1, 0, 2, 3]}, 'Tr_Transpose3', 'Tr_transposed04'),
            ('Slice', ('Tr_transposed04', 'Sl_Slicecst7',
                       'Sl_Slicecst18', 'Un_Unsqueezecst1'),
             {}, 'Sl_Slice6', 'Sl_output06'),
            ('Neg', ('Sl_output06', ), {}, 'Ne_Neg', 'Ne_Y0'),
            ('Slice', ('Tr_transposed04', 'Un_Unsqueezecst1',
                       'Sl_Slicecst7', 'Un_Unsqueezecst1'),
             {}, 'Sl_Slice7', 'Sl_output08'),
            ('Concat', ('Ne_Y0', 'Sl_output08'),
             {'axis': 0}, 'Co_Concat2', 'Co_concat_result03'),
            ('Add', ('Tr_transposed0', 'Co_concat_result03'),
             {}, 'Ad_Add', 'Ad_C0'),
            ('Slice', ('Ad_C0', 'Sl_Slicecst2',
                       'Sl_Slicecst24', 'Sl_Slicecst25'),
             {}, 'Sl_Slice8', 'Sl_output0'),
            ('Slice', ('Sh_shape0', 'Un_Unsqueezecst1',
                       'Sl_Slicecst', 'Un_Unsqueezecst1'),
             {}, 'Sl_Slice9', 'Sl_output010'),
            ('Shape', ('Sl_output0', ), {}, 'Sh_Shape3', 'Sh_shape03'),
            ('Shape', ('Sh_shape03', ), {}, 'Sh_Shape4', 'Sh_shape04'),
            ('Gather', ('Sh_shape04', 'Ga_Gathercst'),
             {'axis': 0}, 'Ga_Gather3', 'Ga_output04'),
            ('Unsqueeze', ('Ga_output04', 'Un_Unsqueezecst1'),
             {}, 'Un_Unsqueeze7', 'Un_expanded08'),
            ('Concat', ('Un_expanded08', ),
             {'axis': 0}, 'Co_Concat3', 'Co_concat_result05'),
            ('Slice', ('Sh_shape03', 'Sl_Slicecst',
                       'Co_concat_result05', 'Un_Unsqueezecst1'),
             {}, 'Sl_Slice10', 'Sl_output012'),
            ('Concat', ('Sl_Slicecst18', 'Sl_output010', 'Sl_output012'),
             {'axis': 0}, 'Co_Concat4', 'Co_concat_result04'),
            ('Reshape', ('Sl_output0', 'Co_concat_result04'),
             {}, 'Re_Reshape1', 'y'),
        ]

        # initializers
        if getattr(ctx, 'verbose', False):
            print('[initializers] %r' % cls)
        for cst_name, data, dtype, shape in constants:
            value = numpy.array(data, dtype=dtype)
            if shape is not None:
                value = value.reshape(shape)
            cst = ctx.make_const(
                name=make_name('init_' + cst_name), np_val=value)
            names[cst_name] = cst.name

        # nodes
        if getattr(ctx, 'verbose', False):
            print('[nodes] %r' % cls)
        last = None
        for op_type, keys, attr, node_name, result in steps:
            # Input names are resolved lazily, a result must have been
            # produced by a previous step.
            inputs = [names[key] for key in keys]
            last = ctx.make_node(
                op_type, inputs=inputs, attr=attr,
                name=make_name(node_name))
            names[result] = last.output[0]

        # finalize
        if getattr(ctx, 'verbose', False):
            print('[replace_all_inputs] %r' % cls)
        ctx.replace_all_inputs(replaced.output[0], last.output[0])
        ctx.remove_node(replaced.name)

    @classmethod
    def version_13(cls, ctx, node, **kwargs):
        "Converter registered for opset 13, delegates to *any_version*."
        return cls.any_version(13, ctx, node, **kwargs)
" "\n--OUT--\n%s\n--ERR--\n%s\n--CODE--\n%s" "" % (e, out.getvalue(), err.getvalue(), - content)) from e + print_code(content))) from e return glo, loc def test_export2tf2onnx(self): @@ -140,12 +604,32 @@ def test_export2tf2onnx(self): names = ["fft2d_any.onnx"] for name in names: with self.subTest(name=name): + oinf0 = OnnxInference(os.path.join(folder, name)) + + x = numpy.random.randn(3, 1, 4).astype(numpy.float32) + y = oinf0.run({'x': x}) + new_onnx = export2tf2onnx( os.path.join(folder, name), name="FFT2D") _, loc = self.verify_tf(new_onnx) - model = loc['onnx_model'] + model = loc['onnx_raw'] self.assertIn('op_type: "FFT2D"', str(model)) - # print(model) + model = loc['onnx_model'] + self.assertNotIn('op_type: "FFT2D"', str(model)) + + oinf = OnnxInference(model) + y1 = oinf.run({'x': x}) + + new_onnx = export2tf2onnx( + os.path.join(folder, name), name="FFT2D") + _, loc = self.verify_tf(new_onnx) + model = loc['onnx_model'] + self.assertNotIn('op_type: "FFT2D"', str(model)) + oinf = OnnxInference(model) + y2 = oinf.run({'x': x}) + + self.assertEqualArray(y['y'], y1['y']) + self.assertEqualArray(y['y'], y2['y']) if __name__ == "__main__": diff --git a/mlprodict/__main__.py b/mlprodict/__main__.py index b38a90994..76e264e3c 100644 --- a/mlprodict/__main__.py +++ b/mlprodict/__main__.py @@ -22,6 +22,7 @@ def main(args, fLOG=print): from .cli.asv2csv import asv2csv from .cli.replay import benchmark_replay from .cli.einsum import einsum_test + from .cli.onnx_code import onnx_code except ImportError: # pragma: no cover from mlprodict.cli.validate import validate_runtime from mlprodict.cli.convert_validate import convert_validate @@ -30,6 +31,7 @@ def main(args, fLOG=print): from mlprodict.cli.asv2csv import asv2csv from mlprodict.cli.replay import benchmark_replay from mlprodict.cli.einsum import einsum_test + from mlprodict.cli.onnx_code import onnx_code fcts = dict(validate_runtime=validate_runtime, convert_validate=convert_validate, @@ -38,7 +40,8 @@ def 
def onnx_code(filename, format="onnx", output=None, verbose=0, name=None,
              opset=None, fLOG=print):
    """
    Exports an ONNX graph into a python code creating
    the same graph.

    :param filename: onnx file
    :param format: format to export to (`onnx`, `tf2onnx`)
    :param output: output file to produce or None to print it on stdout
    :param verbose: verbosity level
    :param name: rewrite the graph name
    :param opset: overwrite the opset (may not work depending on the
        format), any value which cannot be converted into an integer
        is ignored
    :param fLOG: logging function
    :raises ValueError: if *format* is not one of the supported formats

    .. cmdref::
        :title: Exports an ONNX graph into a python code creating the same graph.
        :cmd: -m mlprodict onnx_code --help
        :lid: l-cmd-onnx_code

        The command produces a python script which creates
        the same ONNX graph as the one stored in the given file.

        Example::

            python -m mlprodict onnx_code --filename="something.onnx" --format=onnx
    """
    # Fails fast on a wrong format, before importing the export module.
    if format not in ('onnx', 'tf2onnx'):
        raise ValueError(  # pragma: no cover
            "Unknown format %r." % format)

    from ..onnx_tools.onnx_export import (  # pylint: disable=E0402
        export2onnx, export2tf2onnx)

    if name == '':
        name = None
    # The command line gives strings, '' or None means no opset.
    try:
        opset = int(opset)
    except (ValueError, TypeError):
        opset = None

    if format == 'onnx':
        code = export2onnx(filename, verbose=verbose, name=name, opset=opset)
    else:
        code = export2tf2onnx(filename, verbose=verbose,
                              name=name, opset=opset)

    if output not in ('', None):
        with open(output, "w", encoding="utf-8") as f:
            f.write(code)
    else:
        fLOG(code)
class Tf2OnnxConvert:
    """
    Applies the converters registered with @see cl tf_op
    on an ONNX graph. The class keeps a dictionary of the nodes
    and initializers of the graph indexed by their name, the converters
    add or remove elements through the methods of this class.

    :param onnx_model: ONNX graph
    :param _tf_op: class which registers the converters, the converters
        are read from its attribute `_OPSETS`, @see cl tf_op by default
    :param verbose: verbosity
    :param target_opset: targeted opsets, an integer (default domain),
        a dictionary `{domain: version}` or None to use the opsets
        declared in the model
    """

    def __init__(self, onnx_model, _tf_op=None, verbose=None,
                 target_opset=None):
        self._onnx_model = onnx_model
        self._tf_op = _tf_op or tf_op
        self.verbose = verbose
        if isinstance(target_opset, int):
            self.target_opsets = {'': target_opset}
        elif isinstance(target_opset, dict):
            self.target_opsets = target_opset
        elif target_opset is None:
            self.target_opsets = self._guess_target_opsets(onnx_model)
        else:
            raise ValueError(  # pragma: no cover
                "Unexepected value for target_opset=%r." % target_opset)
        self._names = {}
        for node in onnx_model.graph.node:
            self._names[node.name] = node
        for init in onnx_model.graph.initializer:
            self._names[init.name] = init
        # _forbidden_new_names contains current names and deleted names.
        self._forbidden_new_names = set(self._names)

    @staticmethod
    def _guess_target_opsets(onnx_model):
        """
        Builds the targeted opsets from the opsets the model declares.
        Every domain is mapped to the version of the default domain,
        whatever the order of the declarations is. A domain keeps its
        own version only if the model does not declare the default domain.

        :param onnx_model: ONNX graph
        :return: dictionary `{domain: version}`
        """
        declared = list(onnx_model.opset_import)
        main_opset = None
        for oimp in declared:
            if oimp.domain == '':
                main_opset = oimp.version
        opsets = {}
        for oimp in declared:
            if oimp.domain == '' or main_opset is None:
                opsets[oimp.domain] = oimp.version
            else:
                opsets[oimp.domain] = main_opset
        return opsets

    def get_node_by_name(self, name):
        """
        Retrieves a node or an initializer by its name.

        :param name: node name
        :return: node or initializer registered under this name
        :raises RuntimeError: if the name is unknown
        """
        if name not in self._names:
            raise RuntimeError(
                "Unable to find node name %r among %r." % (
                    name, ", ".join(sorted(self._names))))
        return self._names[name]

    def _add_node_name(self, obj):
        """
        Registers an object in the graph by its name.

        :param obj: node or initializer
        :raises RuntimeError: if the name is or was already used,
            names of deleted nodes cannot be reused
        """
        if obj.name in self._forbidden_new_names:
            raise RuntimeError(
                "Name %r is already registered." % obj.name)
        self._names[obj.name] = obj
        self._forbidden_new_names.add(obj.name)

    def make_node(self, op_type, inputs, attr=None, outputs=None,
                  name=None, domain='', output_count=1):
        """
        Adds a node to the list of nodes.

        :param op_type: operator type
        :param inputs: list of strings
        :param attr: dictionary of attributes, a value can be
            an `AttributeProto`, it is then added as is to the node
        :param outputs: None or list of strings
        :param output_count: used if outputs is None to guess
            the number of outputs of this node
        :param name: name of the node
        :param domain: domain
        :return: created node
        :raises RuntimeError: if the name is already used
        """
        if self.verbose:
            print("[Tf2OnnxConvert.make_node] op_type=%r inputs=%r" % (
                op_type, inputs))

        if attr is None:
            attr = {}
        if name is None:
            name = make_name(op_type)
        if name in self._names:
            raise RuntimeError(
                "Node name %r already exists in %r." % (
                    name, ", ".join(sorted(self._names))))

        if outputs is None:
            outputs = [(name + ":" + str(i)) for i in range(output_count)]

        # onnx.helper.make_node builds attributes from python values,
        # attributes which are already protos are appended afterwards.
        raw_attr = {}
        onnx_attrs = []
        for a, v in attr.items():
            if isinstance(v, AttributeProto):
                onnx_attrs.append(v)
            else:
                raw_attr[a] = v

        onnx_node = make_node(
            op_type, inputs, outputs, name=name, domain=domain, **raw_attr)
        if onnx_attrs:
            onnx_node.attribute.extend(onnx_attrs)

        self._add_node_name(onnx_node)
        return onnx_node

    def make_const(self, name, np_val, skip_conversion=False, raw=True):
        """
        Makes a new constant in the graph.

        :param name: const node name, must be unique.
        :param np_val: value of type numpy ndarray.
        :param skip_conversion:
            bool, indicate whether this created node would be mapped
            during conversion (not used by this implementation)
        :param raw: whether to store data at field of raw_data or the
            specific field according to its dtype
        :return: created initializer
        :raises RuntimeError: if the name is already used
        """
        if name in self._names:
            raise RuntimeError(
                "Initializer name %r already exists in %r." % (
                    name, ", ".join(sorted(self._names))))
        np_val_flat = np_val.flatten()
        # numpy.object_ and not numpy.object, the alias was removed
        # from recent versions of numpy.
        is_bytes = (np_val.dtype == numpy.object_ and
                    len(np_val_flat) > 0 and
                    isinstance(np_val_flat[0], bytes))
        if raw and not is_bytes:
            onnx_tensor = from_array(np_val, name)
        else:
            onnx_tensor = make_tensor(
                name, guess_proto_dtype(np_val.dtype),
                np_val.shape, np_val_flat, raw=False)

        self._add_node_name(onnx_tensor)
        return onnx_tensor

    def get_dtype(self, input_name):
        """
        Returns the type of one graph input or None if the name
        is not a graph input.

        :param input_name: result name
        :return: numpy dtype
        """
        inputs = self._onnx_model.graph.input
        names = [_.name for _ in inputs]
        if input_name not in names:
            return None  # pragma: no cover
        ind = names.index(input_name)
        return guess_dtype(inputs[ind].type.tensor_type.elem_type)

    def replace_all_inputs(self, old_name, new_name):
        """
        Every node taking *old_name* as input will take *new_name* instead.
        Looks in the outputs as well but in that case, it creates an identity
        node to avoid changing an output name.

        :param old_name: name to replace
        :param new_name: new name
        :return: list of impacted nodes
        """
        res = []
        for node in self._names.values():
            if not hasattr(node, 'input'):
                continue
            if old_name not in node.input:
                continue
            # node.input is a sequence of strings (result names).
            new_inputs = [new_name if i == old_name else i
                          for i in node.input]
            node.input[:] = new_inputs[:]
            res.append(node)
            if self.verbose:
                print("[Tf2OnnxConvert.replace_all_inputs] replace %r by %r in node %r" % (
                    old_name, new_name, node.name))
        for o in self._onnx_model.graph.output:
            if o.name != old_name:
                continue
            n = self.make_node("Identity", [new_name], outputs=[old_name],
                               name=make_name("IdOutputReplaced"))
            res.append(n)
            if self.verbose:
                print("[Tf2OnnxConvert.replace_all_inputs] add id node from %r to %r "
                      "with node %r." % (
                          old_name, new_name, n.name))
        return res

    def remove_node(self, name):
        """
        Removes a node name from the list. The name remains forbidden
        for any new node.

        :param name: name of the node to remove
        :raises RuntimeError: if the name is unknown
        """
        if name not in self._names:
            raise RuntimeError(
                "Unable to delete name %r because it does not exists." % name)
        del self._names[name]
        if self.verbose:
            print("[Tf2OnnxConvert.remove_node] delete name %r" % name)

    def get_shape(self, input_name):
        """
        Returns the shape of one graph input or None if the name
        is not a graph input.

        :param input_name: result name
        :return: tuple of the dimensions as stored in the model
            (field `dim` of the tensor shape)
        """
        inputs = self._onnx_model.graph.input
        names = [_.name for _ in inputs]
        if input_name not in names:
            return None  # pragma: no cover
        ind = names.index(input_name)
        dims = inputs[ind].type.tensor_type.shape.dim
        return tuple(dims)

    def run(self):
        """
        Calls the registered converters on the graph
        held by this instance. Returns the new onnx graph.

        :return: ONNX graph
        :raises RuntimeError: if no converter was registered
        """
        if len(self._tf_op._OPSETS) == 0:
            raise RuntimeError(  # pragma: no cover
                "No converter was registered.")
        if self.verbose:
            print("[Tf2OnnxConvert.run]")

        # Names of the nodes already converted. A converter is expected
        # to remove the node it converts, remembering the name guarantees
        # the loop ends even if it does not.
        done = set()
        modif = 1
        while modif > 0:
            modif = 0
            # The converter may alter the current list of nodes, we freeze it.
            current_values = list(self._names.values())
            for node in current_values:
                if not hasattr(node, 'domain'):
                    # initializer
                    continue
                if node.name in done:
                    continue
                domain = node.domain
                if domain not in self._tf_op._OPSETS:
                    continue

                # look for a converter, the most recent one is used
                rews = self._tf_op._OPSETS[domain]
                target = min(self.target_opsets[domain], len(rews))
                conv = None
                for i in range(len(rews) - 1, -1, -1):
                    if node.op_type in rews[i]:
                        conv = rews[i][node.op_type]
                        break
                if conv is None:
                    continue

                # applies the converter
                if self.verbose:
                    print("[Tf2OnnxConvert.run] convert node type=%r opset=%r name=%r"
                          "" % (node.op_type, target, node.name))
                fct, kwargs = conv
                fct(self, node, target_opset=target, **kwargs)
                done.add(node.name)
                modif += 1

        return self.make_model()

    def make_model(self):
        """
        Produces the new ONNX graph with the updated sets of nodes.
        Inputs, outputs and metadata are copied from the original model,
        the opsets are the targeted opsets.

        :return: ONNX graph
        """
        inputs = self._onnx_model.graph.input
        outputs = self._onnx_model.graph.output
        # Only nodes have an attribute domain, initializers do not.
        inits = [init[1] for init in sorted(self._names.items())
                 if not hasattr(init[1], 'domain')]
        nodes = [node[1] for node in sorted(self._names.items())
                 if hasattr(node[1], 'domain')]
        nodes = ensure_topological_order(inputs, inits, nodes)

        if self.verbose:
            print(
                "[Tf2OnnxConvert.make_model] %d nodes %d inputs %d "
                "outputs %d initializers"
                "" % (len(nodes), len(inputs), len(outputs), len(inits)))
        graph = make_graph(nodes, self._onnx_model.graph.name,
                           inputs, outputs, inits)
        onnx_model = make_model(graph)
        onnx_model.ir_version = self._onnx_model.ir_version
        onnx_model.producer_name = self._onnx_model.producer_name + "-mlprodict"
        onnx_model.producer_version = self._onnx_model.producer_version
        onnx_model.domain = self._onnx_model.domain
        onnx_model.model_version = self._onnx_model.model_version
        onnx_model.doc_string = self._onnx_model.doc_string
        metadata = {p.key: p.value for p in self._onnx_model.metadata_props}
        set_model_props(onnx_model, metadata)

        # opsets
        del onnx_model.opset_import[:]  # pylint: disable=E1101
        for dom, value in self.target_opsets.items():
            op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
            op_set.domain = dom
            op_set.version = value
        return onnx_model
+ + :param proto_type: example ``onnx.TensorProto.FLOAT`` + :return: :epkg:`numpy` dtype + """ + if proto_type == TensorProto.FLOAT: # pylint: disable=E1101 + return numpy.float32 + if proto_type == TensorProto.BOOL: # pylint: disable=E1101 + return numpy.bool_ + if proto_type == TensorProto.DOUBLE: # pylint: disable=E1101 + return numpy.float64 + if proto_type == TensorProto.STRING: # pylint: disable=E1101 + return numpy.str_ + if proto_type == TensorProto.INT64: # pylint: disable=E1101 + return numpy.int64 + if proto_type == TensorProto.INT32: # pylint: disable=E1101 + return numpy.int32 + if proto_type == TensorProto.INT8: # pylint: disable=E1101 + return numpy.int8 + if proto_type == TensorProto.INT16: # pylint: disable=E1101 + return numpy.int16 + if proto_type == TensorProto.UINT64: # pylint: disable=E1101 + return numpy.uint64 + if proto_type == TensorProto.UINT32: # pylint: disable=E1101 + return numpy.uint32 + if proto_type == TensorProto.UINT8: # pylint: disable=E1101 + return numpy.uint8 + if proto_type == TensorProto.UINT16: # pylint: disable=E1101 + return numpy.uint16 + if proto_type == TensorProto.FLOAT16: # pylint: disable=E1101 + return numpy.float16 + raise ValueError( + "Unable to convert proto_type {} to numpy type.".format( + proto_type)) diff --git a/mlprodict/onnx_tools/onnx_export.py b/mlprodict/onnx_tools/onnx_export.py index a9a94664e..7e06783db 100644 --- a/mlprodict/onnx_tools/onnx_export.py +++ b/mlprodict/onnx_tools/onnx_export.py @@ -112,47 +112,12 @@ def create_model(): from onnx.helper import ( make_model, make_node, set_model_props, make_tensor, make_graph, make_tensor_value_info) - try: - from utils import make_name - except ImportError: - - _make_name_id = 0 - - - def make_name(name): - global _make_name_id - name = "%s_%d" % (name, _make_name_id) - _make_name_id += 1 - return name - - - class tf_op: - _OPSETS = collections.OrderedDict() - - def __init__(self, name, domain='', **kwargs): - if not isinstance(name, list): - name = 
[name] - self.names = name - self.domain = domain - self.kwargs = kwargs - - def __call__(self, func): - for ke, va in inspect.getmembers(func, inspect.ismethod): - if ke.startswith("version_"): - version = int(ke.replace("version_", "")) - self.register_handler(va, version, self.names, self.domain, self.kwargs) - return func - - def register_handler(self, func, version, names, domain, kwargs): - opset = tf_op._OPSETS.get(domain) - if not opset: - opset = [] - tf_op._OPSETS[domain] = opset - while version >= len(opset): - opset.append({}) - opset_dict = opset[version] - for name in names: - opset_dict[name] = (func, kwargs) + # from utils import make_name, make_sure + from mlprodict.onnx_tools.exports.tf2onnx_helper import ( + make_name, make_sure) + # from tf2onnx.handler import tf_op + from mlprodict.onnx_tools.exports.tf2onnx_helper import tf_op + from mlprodict.onnx_tools.exports.tf2onnx_helper import Tf2OnnxConvert @tf_op("{{ name }}") @@ -177,12 +142,13 @@ def any_version(cls, opset, ctx, node, **kwargs): oldnode = node input_name = node.input[0] onnx_dtype = ctx.get_dtype(input_name) - utils.make_sure(onnx_dtype in Convert{{ name }}Op.supported_dtypes, "Unsupported input type.") + make_sure(onnx_dtype in Convert{{ name }}Op.supported_dtypes, "Unsupported input type.") shape = ctx.get_shape(input_name) - vars = {} + varx = {x: x for x in node.input} # initializers - print('[initializers]') # verbose + if getattr(ctx, 'verbose', False): + print('[initializers] %r' % cls) {% for name, value in initializers: %} {% if len(value.shape) == 0: %} value = numpy.array({{ value }}, dtype=numpy.{{ value.dtype }}) @@ -191,32 +157,35 @@ def any_version(cls, opset, ctx, node, **kwargs): value = numpy.array(list_value, dtype=numpy.{{ value.dtype }}){% if len(value.shape) > 1: %}.reshape({{ value.shape }}){% endif %} {% endif %} r_{{ name }} = ctx.make_const(name=make_name('init_{{ name }}'), np_val=value) - vars['{{ name }}'] = r_{{ name }}.name + varx['{{ name }}'] = r_{{ 
name }}.name {% endfor %} # nodes - print('[nodes]') # verbose + if getattr(ctx, 'verbose', False): + print('[nodes] %r' % cls) {% for node in nodes: %} attr = dict( {%- for name, value in node['attributes']: -%} {{ name }}={{ value }}, {%- endfor -%}) - inputs = [{% for name in node['inputs']: -%}vars['{{ name }}'], {%- endfor %}] + inputs = [{% for name in node['inputs']: -%}varx['{{ name }}'], {%- endfor %}] node = ctx.make_node( '{{ node['op_type'] }}', inputs=inputs, attr=attr,{% if node['domain']: -%} domain='{{ node['domain'] }}', {% endif %} name=make_name('{{ node['name'] }}')) {% for i, name in enumerate(node['outputs']): -%} - vars['{{ name }}'] = node.output[{{ i }}] + varx['{{ name }}'] = node.output[{{ i }}] {%- endfor %} {% endfor %} # finalize + if getattr(ctx, 'verbose', False): + print('[replace_all_inputs] %r' % cls) ctx.replace_all_inputs(oldnode.output[0], node.output[0]) ctx.remove_node(oldnode.name) - @classmethod - def version_13(cls, ctx, node, **kwargs): - return cls.any_version(13, ctx, node, **kwargs) + @classmethod + def version_13(cls, ctx, node, **kwargs): + return cls.any_version(13, ctx, node, **kwargs) def create_model(): @@ -265,63 +234,8 @@ def create_model(): return onnx_model - class Rewrite: - - def __init__(self, onnx_model, tf_op): - self._onnx_model = onnx_model - self._nodes = list(onnx_model.graph.node) - self._tf_op = tf_op - - def make_node(self, op_type, inputs, attr=None, outputs=None, - name=None, domain=''): - if attr is None: - attr = {} - if name is None: - name = make_name(op_type) - - if outputs is None: - outputs = [name + ":" + str(i) for i in range(output_count)] - - output_count = len(outputs) - raw_attr = {} - onnx_attrs = [] - for a, v in attr.items(): - if isinstance(v, AttributeProto): - onnx_attrs.append(v) - else: - raw_attr[a] = v - - n = self.get_node_by_name(name) - - for o in outputs: - n = self.get_node_by_output_in_current_graph(o) - - onnx_node = make_node( - op_type, inputs, outputs, name=name, 
domain=domain, **raw_attr) - - self._nodes.append(onnx_node) - return node - - def rewrite(self): - print('[rewrite]') # verbose - done = {} - modif = 1 - while modif > 0: - modif = 0 - for node in self._nodes: - if done.get(node.name, False): - continue - domain = node.domain - if domain not in self._tf_op._OPSETS: - continue - rews = self._tf_op._OPSETS[domain] - # look for an opset - # call the rewriter - - - - onnx_model = create_model() - onnx_rewritten = Rewrite(onnx_model, tf_op).rewrite() + onnx_raw = create_model() + onnx_model = Tf2OnnxConvert(onnx_raw, tf_op).run() """) @@ -425,6 +339,26 @@ def export2onnx(model_onnx, opset=None, verbose=True, name=None): :param verbose: inserts prints :param name: to overwrite onnx name :return: python code + + The following example shows what a python code creating a graph + implementing the KMeans would look like. + + .. runpython:: + :showcode: + + import numpy + from sklearn.cluster import KMeans + from skl2onnx import to_onnx + from mlprodict.onnx_tools.onnx_export import export2onnx + + X = numpy.arange(20).reshape(10, 2).astype(numpy.float32) + tr = KMeans(n_clusters=2) + tr.fit(X) + + onx = to_onnx(tr, X, target_opset=14) + code = export2onnx(onx) + + print(code) """ if isinstance(model_onnx, str): model_onnx = onnx.load(model_onnx) @@ -443,6 +377,23 @@ def export2tf2onnx(model_onnx, opset=None, verbose=True, name=None): :param verbose: inserts prints :param name: to overwrite onnx name :return: python code + + .. 
runpython:: + :showcode: + + import numpy + from sklearn.cluster import KMeans + from skl2onnx import to_onnx + from mlprodict.onnx_tools.onnx_export import export2tf2onnx + + X = numpy.arange(20).reshape(10, 2).astype(numpy.float32) + tr = KMeans(n_clusters=2) + tr.fit(X) + + onx = to_onnx(tr, X, target_opset=14) + code = export2tf2onnx(onx) + + print(code) """ if isinstance(model_onnx, str): model_onnx = onnx.load(model_onnx) diff --git a/mlprodict/onnx_tools/onnx_tools.py b/mlprodict/onnx_tools/onnx_tools.py index 0ce1b542c..7d5ca3b5b 100644 --- a/mlprodict/onnx_tools/onnx_tools.py +++ b/mlprodict/onnx_tools/onnx_tools.py @@ -76,6 +76,8 @@ def insert_node(model, op_type, node, input_index=0, new_name=None, **attrs): inode.input[input_index] = new_name keep_nodes = list(model.graph.node) keep_nodes.append(new_node) + keep_nodes = ensure_topological_order( + model.graph.input, model.graph.initializer, keep_nodes) graph = helper.make_graph( keep_nodes, model.graph.name, model.graph.input, @@ -102,3 +104,88 @@ def insert_node(model, op_type, node, input_index=0, new_name=None, **attrs): "Input mismatch {} != {}".format( len(onnx_model.input), len(model.input))) # pylint: disable=E1101 return onnx_model + + +def ensure_topological_order(inputs, initializers, nodes): + """ + Ensures and modifies the order of nodes to have + a topological order (every node in the list + can only be an input for a node later in this list). + The function raises an exception if a cycle is detected. 
+ + :param inputs: graph inputs + :param initializers: graph initializers + :param nodes: graph nodes + :return: list of ordered nodes + """ + order = {} + for inp in inputs: + name = inp.name + order[name] = 0 + for inp in initializers: + name = inp.name + order[name] = 0 + n_iter = 0 + while n_iter < len(nodes) * 2: + n_iter += 1 + missing_names = set() + missing_ops = [] + for node in nodes: + maxi = 0 + for name in node.input: + if name in order: + maxi = max(maxi, order[name]) + else: + maxi = None + missing_names.add(name) + break + if maxi is None: + missing_ops.append(node) + continue + key = id(node) + if key in order: + continue + maxi += 1 + order[key] = maxi + maxi += 1 + for name in node.output: + if name in order: + raise RuntimeError( + "Unable to sort a node (cycle). An output was " + "already ordered %r (iteration=%r)." % ( + name, n_iter)) + order[name] = maxi + if len(missing_names) == 0: + continue + + if len(missing_ops) > 0: + def nstr(name): + if name in order: + return "%s#%d" % (name, order[name]) + return name + rows = ["%s(%s) -> [%s]" % ( + n.name or n.op_type, + ', '.join(map(nstr, n.input)), + ', '.join(n.output)) + for n in missing_ops] + rows.insert(0, "") + rows.append("--") + rows.append("--all-nodes--") + rows.append("--") + rows.extend("%s(%s) -> [%s]" % ( + n.name or n.op_type, + ', '.join(map(nstr, n.input)), + ', '.join(n.output)) + for n in nodes) + raise RuntimeError( + "After %d iterations for %d nodes, still unable " + "to sort names %r. The graph may be disconnected. 
" + "List of operators: %s" % ( + n_iter, len(nodes), missing_names, + "\n".join(rows))) + + # Update order + topo = [(order[id(node)], str(id(node))) for node in nodes] + topo.sort() + map_nodes = {str(id(node)): node for node in nodes} + return [map_nodes[_[1]] for _ in topo] diff --git a/mlprodict/onnxrt/ops_cpu/_op_helper.py b/mlprodict/onnxrt/ops_cpu/_op_helper.py index 389ac914e..15d5920d9 100644 --- a/mlprodict/onnxrt/ops_cpu/_op_helper.py +++ b/mlprodict/onnxrt/ops_cpu/_op_helper.py @@ -3,7 +3,6 @@ @brief Runtime operator. """ import numpy -from onnx import TensorProto def _get_typed_class_attribute(self, k, atts): @@ -32,35 +31,8 @@ def proto2dtype(proto_type): :param proto_type: example ``onnx.TensorProto.FLOAT`` :return: :epkg:`numpy` dtype """ - if proto_type == TensorProto.FLOAT: # pylint: disable=E1101 - return numpy.float32 - if proto_type == TensorProto.BOOL: # pylint: disable=E1101 - return numpy.bool_ - if proto_type == TensorProto.DOUBLE: # pylint: disable=E1101 - return numpy.float64 - if proto_type == TensorProto.STRING: # pylint: disable=E1101 - return numpy.str_ - if proto_type == TensorProto.INT64: # pylint: disable=E1101 - return numpy.int64 - if proto_type == TensorProto.INT32: # pylint: disable=E1101 - return numpy.int32 - if proto_type == TensorProto.INT8: # pylint: disable=E1101 - return numpy.int8 - if proto_type == TensorProto.INT16: # pylint: disable=E1101 - return numpy.int16 - if proto_type == TensorProto.UINT64: # pylint: disable=E1101 - return numpy.uint64 - if proto_type == TensorProto.UINT32: # pylint: disable=E1101 - return numpy.uint32 - if proto_type == TensorProto.UINT8: # pylint: disable=E1101 - return numpy.uint8 - if proto_type == TensorProto.UINT16: # pylint: disable=E1101 - return numpy.uint16 - if proto_type == TensorProto.FLOAT16: # pylint: disable=E1101 - return numpy.float16 - raise ValueError( - "Unable to convert proto_type {} to numpy type.".format( - proto_type)) + from ...onnx_tools.onnx2py_helper import 
guess_dtype + return guess_dtype(proto_type) def dtype_name(dtype): From 4a2665a7b29c5720a0a4d2177ca02d177319b666 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sat, 7 Aug 2021 14:54:36 +0200 Subject: [PATCH 10/12] lint --- _unittests/ut_tools/test_export_onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_unittests/ut_tools/test_export_onnx.py b/_unittests/ut_tools/test_export_onnx.py index 0444eeb6d..331b2627a 100644 --- a/_unittests/ut_tools/test_export_onnx.py +++ b/_unittests/ut_tools/test_export_onnx.py @@ -27,7 +27,7 @@ class ConvertFFT2DOp: ] @classmethod - def any_version(cls, opset, ctx, node, **kwargs): #pylint: disable=R0915 + def any_version(cls, opset, ctx, node, **kwargs): # pylint: disable=R0915 ''' Converter for ``FFT2D``. From c5971c8ce092b972ca08ed43248bdea6fb845ce2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sat, 7 Aug 2021 15:12:43 +0200 Subject: [PATCH 11/12] documentation --- _doc/sphinxdoc/source/api/tools.rst | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/_doc/sphinxdoc/source/api/tools.rst b/_doc/sphinxdoc/source/api/tools.rst index de23a9573..cb631f6e4 100644 --- a/_doc/sphinxdoc/source/api/tools.rst +++ b/_doc/sphinxdoc/source/api/tools.rst @@ -18,6 +18,13 @@ Accessor .. autosignature:: mlprodict.onnx_tools.onnx_tools.insert_node +Export +++++++ + +.. autosignature:: mlprodict.onnx_tools.onnx_export.export2onnx + +.. autosignature:: mlprodict.onnx_tools.onnx_export.export2tf2onnx + Graphs ++++++ @@ -48,6 +55,8 @@ The following functions reduce the number of ONNX operators in a graph while keeping the same results. The optimized graph is left unchanged. +.. autosignature:: mlprodict.onnx_tools.onnx_tools.ensure_topological_order + .. autosignature:: mlprodict.onnx_tools.optim.onnx_optimisation.onnx_remove_node .. 
autosignature:: mlprodict.onnx_tools.optim.onnx_optimisation_identity.onnx_remove_node_identity @@ -68,15 +77,24 @@ Serialization .. autosignature:: mlprodict.onnx_tools.onnx2py_helper.to_bytes +Runtime +======= + +.. autosignature:: mlprodict.tools.onnx_micro_runtime.OnnxMicroRuntime + Validation ++++++++++ .. autosignature:: mlprodict.onnx_tools.model_checker.onnx_shaker -Runtime -======= +Visualization ++++++++++++++ -.. autosignature:: mlprodict.tools.onnx_micro_runtime.OnnxMicroRuntime +Many times I had to debug and I was thinking about a way to see +a graph in a text editor. That's the goal of this function with +the possibility later to only show a part of a graph. + +.. autosignature:: mlprodict.onnx_tools.graphs.onnx2bigraph Others ====== From 51ae786453a0ed71fc896417fc4ac4452704e72c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sat, 7 Aug 2021 19:06:05 +0200 Subject: [PATCH 12/12] lint --- mlprodict/onnx_tools/exports/tf2onnx_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlprodict/onnx_tools/exports/tf2onnx_helper.py b/mlprodict/onnx_tools/exports/tf2onnx_helper.py index 129477b69..d30023a11 100644 --- a/mlprodict/onnx_tools/exports/tf2onnx_helper.py +++ b/mlprodict/onnx_tools/exports/tf2onnx_helper.py @@ -248,7 +248,7 @@ def replace_all_inputs(self, old_name, new_name): if self.verbose: print("[Tf2OnnxConvert.replace_all_inputs] add id node from %r to %r " "with node %r." % ( - old_name, new_name, n.name)) + old_name, new_name, n.name)) # pylint: disable=E1101 return res def remove_node(self, name):