diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml new file mode 100644 index 0000000..efc2167 --- /dev/null +++ b/.github/workflows/build.yaml @@ -0,0 +1,85 @@ +name: Build & Test +on: + push: + branches: + - main + tags: + - "v*" + pull_request: + +jobs: + build-and-test: + runs-on: ubuntu-20.04 + strategy: + matrix: + python-version: + - "3.7" + - "3.8" + - "3.9" + - "3.10" + - "3.11" + steps: + - name: Check out code + uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Build package + run: | + python -m pip install --upgrade build twine + python -m build + twine check --strict dist/* + - name: Install coveralls + run: sudo pip install coveralls + - name: Run tests + run: sudo PATH=$PATH coverage run setup.py test + + release: + runs-on: ubuntu-20.04 + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') + needs: + - build-and-test + steps: + - name: Check out code + uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.x" + - name: Build package + run: | + python -m pip install --upgrade build twine + python -m build + twine check --strict dist/* + rm -f dist/*.whl + - name: Publish package + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} + - name: Create GitHub release + id: create_release + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ github.ref }} + release_name: ${{ github.ref }} + draft: false + prerelease: false + - name: Set asset name + run: | + export PKG=$(ls dist/ | grep tar) + set -- $PKG + echo "name=$1" >> $GITHUB_ENV + - name: Upload release asset to GitHub + id: upload-release-asset + uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ steps.create_release.outputs.upload_url }} + asset_path: dist/${{ env.name }} + asset_name: ${{ env.name }} + asset_content_type: application/zip diff --git a/.gitignore b/.gitignore index 3a98128..ecf313f 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,33 @@ *.bak /build - /doc/_build - /.project +/MANIFEST +/dist +/scripts +/Vagrantfile +/.vagrant +/tags +*/__pycache__/ +*.pyc +/build/ +/.pybuild/ +*.egg-info + +# Debian build related files +/debian/python-iptables/ +/debian/python3-iptables/ +/debian/python-iptables-dbg/ +/debian/python3-iptables-dbg/ +/debian/python-iptables-doc/ +/debian/files +/debian/outfile +/debian/*.debhelper +/debian/*.log +/debian/*-stamp +/debian/*.substvars + + +# Added exclusion for PyCharm files +.idea/* diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 67d1392..0000000 --- a/.travis.yml +++ /dev/null @@ -1,9 +0,0 @@ -language: python -python: - - "2.6" - - "2.7" -install: - - python setup.py build - - python setup.py install -script: - - echo "y"|sudo PATH=$PATH ./test.py diff --git a/NOTICE b/NOTICE index 5cdb2d3..1746fec 100644 --- a/NOTICE +++ b/NOTICE @@ -1,5 +1,5 @@ - Copyright (c) 2010-, Nilvec nilvec-(at)-nilvec.com + Copyright (c) 2010-, Vilmos Nebehaj and others Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/README.md b/README.md index 5ec49f6..bd966b4 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,607 @@ +Introduction +============ + +About python-iptables +--------------------- + +**Iptables** is the tool that is used to manage **netfilter**, the +standard packet filtering and manipulation framework under Linux. As the +iptables manpage puts it: + +> Iptables is used to set up, maintain, and inspect the tables of IPv4 +> packet filter rules in the Linux kernel. Several different tables may +> be defined. +> +> Each table contains a number of built-in chains and may also contain +> user- defined chains. +> +> Each chain is a list of rules which can match a set of packets. Each +> rule specifies what to do with a packet that matches. This is called a +> target, which may be a jump to a user-defined chain in the same table. + +`Python-iptables` provides a pythonesque wrapper via python bindings to +iptables under Linux. Interoperability with iptables is achieved via +using the iptables C libraries (`libiptc`, `libxtables`, and the +iptables extensions), not calling the iptables binary and parsing its +output. It is meant primarily for dynamic and/or complex routers and +firewalls, where rules are often updated or changed, or Python programs +wish to interface with the Linux iptables framework.. + +If you are looking for `ebtables` python bindings, check out +[python-ebtables](https://github.com/ldx/python-ebtables/). + +`Python-iptables` supports Python 2.6, 2.7 and 3.4. + +[![Flattr](http://api.flattr.com/button/flattr-badge-large.png)](https://flattr.com/submit/auto?user_id=ldx&url=https%3A%2F%2Fgithub.com%2Fldx%2Fpython-iptables) + +[![Latest Release](https://img.shields.io/pypi/v/python-iptables.svg)](https://pypi.python.org/pypi/python-iptables) + [![Build Status](https://travis-ci.org/ldx/python-iptables.png?branch=master)](https://travis-ci.org/ldx/python-iptables) -Python-iptables is a pythonesque wrapper around the Linux iptables/ip6tables facility. It is meant primarily for dynamic and/or complex firewalls, where rules are often updated or changed. Python-iptables makes it possible to use Python to parse or change rules without the need to spawn processes to execute an iptables command. +[![Coverage Status](https://coveralls.io/repos/ldx/python-iptables/badge.svg?branch=codecoverage)](https://coveralls.io/r/ldx/python-iptables?branch=codecoverage) + +[![Code Health](https://landscape.io/github/ldx/python-iptables/codecoverage/landscape.svg)](https://landscape.io/github/ldx/python-iptables/codecoverage) + +[![Number of Downloads](https://img.shields.io/pypi/dm/python-iptables.svg)](https://pypi.python.org/pypi/python-iptables) + +[![License](https://img.shields.io/pypi/l/python-iptables.svg)](https://pypi.python.org/pypi/python-iptables) + +Installing via pip +------------------ + +The usual way: + + pip install --upgrade python-iptables + +Compiling from source +--------------------- + +First make sure you have iptables installed (most Linux distributions +install it by default). `Python-iptables` needs the shared libraries +`libiptc.so` and `libxtables.so` coming with iptables, they are +installed in `/lib` on Ubuntu. + +You can compile `python-iptables` in the usual distutils way: + + % cd python-iptables + % python setup.py build + +If you like, `python-iptables` can also be installed into a +`virtualenv`: + + % mkvirtualenv python-iptables + % python setup.py install + +If you install `python-iptables` as a system package, make sure the +directory where `distutils` installs shared libraries is in the dynamic +linker's search path (it's in `/etc/ld.so.conf` or in one of the files +in the folder `/etc/ld.co.conf.d`). Under Ubuntu `distutils` by default +installs into `/usr/local/lib`. + +Now you can run the tests: + + % sudo PATH=$PATH python setup.py test + WARNING: this test will manipulate iptables rules. + Don't do this on a production machine. + Would you like to continue? y/n y + [...] + +The `PATH=$PATH` part is necessary after `sudo` if you have installed +into a `virtualenv`, since `sudo` will reset your environment to a +system setting otherwise.. + +Once everything is in place you can fire up python to check whether the +package can be imported: + + % sudo PATH=$PATH python + >>> import iptc + >>> + +Of course you need to be root to be able to use iptables. + +Using a custom iptables install +------------------------------- + +If you are stuck on a system with an old version of `iptables`, you can +install a more up to date version to a custom location, and ask +`python-iptables` to use libraries at that location. + +To install `iptables` to `/tmp/iptables`: + + % git clone git://git.netfilter.org/iptables && cd iptables + % ./autogen.sh + % ./configure --prefix=/tmp/iptables + % make + % make install + +Make sure the dependencies `iptables` needs are installed. + +Now you can point `python-iptables` to this install path via: + + % sudo PATH=$PATH IPTABLES_LIBDIR=/tmp/iptables/lib XTABLES_LIBDIR=/tmp/iptables/lib/xtables python + >>> import iptc + >>> + +What is supported +----------------- + +The basic iptables framework and all the match/target extensions are +supported by `python-iptables`, including IPv4 and IPv6 ones. All IPv4 +and IPv6 tables are supported as well. + +Full documentation with API reference is available +[here](http://ldx.github.com/python-iptables/). + +Examples +======== + +High level abstractions +----------------------- + +``python-iptables`` implements a low-level interface that tries to closely +match the underlying C libraries. The module ``iptc.easy`` improves the +usability of the library by providing a rich set of high-level functions +designed to simplify the interaction with the library, for example: + + >>> import iptc + >>> iptc.easy.dump_table('nat', ipv6=False) + {'INPUT': [], 'OUTPUT': [], 'POSTROUTING': [], 'PREROUTING': []} + >>> iptc.easy.dump_chain('filter', 'OUTPUT', ipv6=False) + [{'comment': {'comment': 'DNS traffic to Google'}, + 'counters': (1, 56), + 'dst': '8.8.8.8/32', + 'protocol': 'udp', + 'target': 'ACCEPT', + 'udp': {'dport': '53'}}] + >>> iptc.easy.add_chain('filter', 'TestChain') + True + >>> rule_d = {'protocol': 'tcp', 'target': 'ACCEPT', 'tcp': {'dport': '22'}} + >>> iptc.easy.insert_rule('filter', 'TestChain', rule_d) + >>> iptc.easy.dump_chain('filter', 'TestChain') + [{'protocol': 'tcp', 'target': 'ACCEPT', 'tcp': {'dport': '22'}}] + >>> iptc.easy.delete_chain('filter', 'TestChain', flush=True) + + >>> # Example of goto rule // iptables -A FORWARD -p gre -g TestChainGoto + >>> iptc.easy.add_chain('filter', 'TestChainGoto') + >>> rule_goto_d = {'protocol': 'gre', 'target': {'goto': 'TestChainGoto'}} + >>> iptc.easy.insert_rule('filter', 'FORWARD', rule_goto_d) + +Rules +----- + +In `python-iptables`, you usually first create a rule, and set any +source/destination address, in/out interface and protocol specifiers, +for example: + + >>> import iptc + >>> rule = iptc.Rule() + >>> rule.in_interface = "eth0" + >>> rule.src = "192.168.1.0/255.255.255.0" + >>> rule.protocol = "tcp" + +This creates a rule that will match TCP packets coming in on eth0, with +a source IP address of 192.168.1.0/255.255.255.0. + +A rule may contain matches and a target. A match is like a filter +matching certain packet attributes, while a target tells what to do with +the packet (drop it, accept it, transform it somehow, etc). One can +create a match or target via a Rule: + + >>> rule = iptc.Rule() + >>> m = rule.create_match("tcp") + >>> t = rule.create_target("DROP") + +Match and target parameters can be changed after creating them. It is +also perfectly valid to create a match or target via instantiating them +with their constructor, but you still need a rule and you have to add +the matches and the target to their rule manually: + + >>> rule = iptc.Rule() + >>> match = iptc.Match(rule, "tcp") + >>> target = iptc.Target(rule, "DROP") + >>> rule.add_match(match) + >>> rule.target = target + +Any parameters a match or target might take can be set via the +attributes of the object. To set the destination port for a TCP match: + + >>> rule = iptc.Rule() + >>> rule.protocol = "tcp" + >>> match = rule.create_match("tcp") + >>> match.dport = "80" + +To set up a rule that matches packets marked with 0xff: + + >>> rule = iptc.Rule() + >>> rule.protocol = "tcp" + >>> match = rule.create_match("mark") + >>> match.mark = "0xff" + +Parameters are always strings. You can supply any string as the +parameter value, but note that most extensions validate their +parameters. For example this: + + >>> rule = iptc.Rule() + >>> rule.protocol = "tcp" + >>> rule.target = iptc.Target(rule, "ACCEPT") + >>> match = iptc.Match(rule, "state") + >>> chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), "INPUT") + >>> match.state = "RELATED,ESTABLISHED" + >>> rule.add_match(match) + >>> chain.insert_rule(rule) + +will work. However, if you change the state parameter: + + >>> rule = iptc.Rule() + >>> rule.protocol = "tcp" + >>> rule.target = iptc.Target(rule, "ACCEPT") + >>> match = iptc.Match(rule, "state") + >>> chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), "INPUT") + >>> match.state = "RELATED,ESTABLISHED,FOOBAR" + >>> rule.add_match(match) + >>> chain.insert_rule(rule) + +`python-iptables` will throw an exception: + + Traceback (most recent call last): + File "state.py", line 7, in + match.state = "RELATED,ESTABLISHED,FOOBAR" + File "/home/user/Projects/python-iptables/iptc/ip4tc.py", line 369, in __setattr__ + self.parse(name.replace("_", "-"), value) + File "/home/user/Projects/python-iptables/iptc/ip4tc.py", line 286, in parse + self._parse(argv, inv, entry) + File "/home/user/Projects/python-iptables/iptc/ip4tc.py", line 516, in _parse + ct.cast(self._ptrptr, ct.POINTER(ct.c_void_p))) + File "/home/user/Projects/python-iptables/iptc/xtables.py", line 736, in new + ret = fn(*args) + File "/home/user/Projects/python-iptables/iptc/xtables.py", line 1031, in parse_match + argv[1])) + iptc.xtables.XTablesError: state: parameter error -2 (RELATED,ESTABLISHED,FOOBAR) + +Certain parameters take a string that optionally consists of multiple +words. The comment match is a good example: + + >>> rule = iptc.Rule() + >>> rule.src = "127.0.0.1" + >>> rule.protocol = "udp" + >>> rule.target = rule.create_target("ACCEPT") + >>> match = rule.create_match("comment") + >>> match.comment = "this is a test comment" + >>> chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), "INPUT") + >>> chain.insert_rule(rule) + +Note that this is still just one parameter value. + +However, when a match or a target takes multiple parameter values, that +needs to be passed in as a list. Let's assume you have created and set +up an `ipset` called `blacklist` via the `ipset` command. To create a +rule with a match for this set: + + >>> rule = iptc.Rule() + >>> m = rule.create_match("set") + >>> m.match_set = ['blacklist', 'src'] + +Note how this time a list was used for the parameter value, since the +`set` match `match_set` parameter expects two values. See the `iptables` +manpages to find out what the extensions you use expect. See +[ipset](http://ipset.netfilter.org/) for more information. + +When you are ready constructing your rule, add them to the chain you +want it to show up in: + + >>> chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), "INPUT") + >>> chain.insert_rule(rule) + +This will put your rule into the INPUT chain in the filter table. + +Chains and tables +----------------- + +You can of course also check what a rule's source/destination address, +in/out inteface etc is. To print out all rules in the FILTER table: + + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> for chain in table.chains: + >>> print "=======================" + >>> print "Chain ", chain.name + >>> for rule in chain.rules: + >>> print "Rule", "proto:", rule.protocol, "src:", rule.src, "dst:", \ + >>> rule.dst, "in:", rule.in_interface, "out:", rule.out_interface, + >>> print "Matches:", + >>> for match in rule.matches: + >>> print match.name, + >>> print "Target:", + >>> print rule.target.name + >>> print "=======================" + +As you see in the code snippet above, rules are organized into chains, +and chains are in tables. You have a fixed set of tables; for IPv4: + +- `FILTER`, +- `NAT`, +- `MANGLE` and +- `RAW`. + +For IPv6 the tables are: + +- `FILTER`, +- `MANGLE`, +- `RAW` and +- `SECURITY`. + +To access a table: + + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> print table.name + filter + +To create a new chain in the FILTER table: + + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> chain = table.create_chain("testchain") + + $ sudo iptables -L -n + [...] + Chain testchain (0 references) + target prot opt source destination + +To access an existing chain: + + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> chain = iptc.Chain(table, "INPUT") + >>> chain.name + 'INPUT' + >>> len(chain.rules) + 10 + >>> + +More about matches and targets +------------------------------ + +There are basic targets, such as `DROP` and `ACCEPT`. E.g. to reject +packets with source address `127.0.0.1/255.0.0.0` coming in on any of +the `eth` interfaces: + + >>> import iptc + >>> chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), "INPUT") + >>> rule = iptc.Rule() + >>> rule.in_interface = "eth+" + >>> rule.src = "127.0.0.1/255.0.0.0" + >>> target = iptc.Target(rule, "DROP") + >>> rule.target = target + >>> chain.insert_rule(rule) + +To instantiate a target or match, we can either create an object like +above, or use the `rule.create_target(target_name)` and +`rule.create_match(match_name)` methods. For example, in the code above +target could have been created as: + + >>> target = rule.create_target("DROP") + +instead of: + + >>> target = iptc.Target(rule, "DROP") + >>> rule.target = target + +The former also adds the match or target to the rule, saving a call. + +Another example, using a target which takes parameters. Let's mark +packets going to `192.168.1.2` UDP port `1234` with `0xffff`: + + >>> import iptc + >>> chain = iptc.Chain(iptc.Table(iptc.Table.MANGLE), "PREROUTING") + >>> rule = iptc.Rule() + >>> rule.dst = "192.168.1.2" + >>> rule.protocol = "udp" + >>> match = iptc.Match(rule, "udp") + >>> match.dport = "1234" + >>> rule.add_match(match) + >>> target = iptc.Target(rule, "MARK") + >>> target.set_mark = "0xffff" + >>> rule.target = target + >>> chain.insert_rule(rule) + +Matches are optional (specifying a target is mandatory). E.g. to insert +a rule to NAT TCP packets going out via `eth0`: + + >>> import iptc + >>> chain = iptc.Chain(iptc.Table(iptc.Table.NAT), "POSTROUTING") + >>> rule = iptc.Rule() + >>> rule.protocol = "tcp" + >>> rule.out_interface = "eth0" + >>> target = iptc.Target(rule, "MASQUERADE") + >>> target.to_ports = "1234" + >>> rule.target = target + >>> chain.insert_rule(rule) + +Here only the properties of the rule decide whether the rule will be +applied to a packet. + +Matches are optional, but we can add multiple matches to a rule. In the +following example we will do that, using the `iprange` and the `tcp` +matches: + + >>> import iptc + >>> rule = iptc.Rule() + >>> rule.protocol = "tcp" + >>> match = iptc.Match(rule, "tcp") + >>> match.dport = "22" + >>> rule.add_match(match) + >>> match = iptc.Match(rule, "iprange") + >>> match.src_range = "192.168.1.100-192.168.1.200" + >>> match.dst_range = "172.22.33.106" + >>> rule.add_match(match) + >>> rule.target = iptc.Target(rule, "DROP") + >>> chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), "INPUT") + >>> chain.insert_rule(rule) + +This is the `python-iptables` equivalent of the following iptables +command: + + # iptables -A INPUT -p tcp –destination-port 22 -m iprange –src-range 192.168.1.100-192.168.1.200 –dst-range 172.22.33.106 -j DROP + +You can of course negate matches, just like when you use `!` in front of +a match with iptables. For example: + + >>> import iptc + >>> rule = iptc.Rule() + >>> match = iptc.Match(rule, "mac") + >>> match.mac_source = "!00:11:22:33:44:55" + >>> rule.add_match(match) + >>> rule.target = iptc.Target(rule, "ACCEPT") + >>> chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), "INPUT") + >>> chain.insert_rule(rule) + +This results in: + + $ sudo iptables -L -n + Chain INPUT (policy ACCEPT) + target prot opt source destination + ACCEPT all -- 0.0.0.0/0 0.0.0.0/0 MAC ! 00:11:22:33:44:55 + + Chain FORWARD (policy ACCEPT) + target prot opt source destination + + Chain OUTPUT (policy ACCEPT) + target prot opt source destination + +Counters +-------- + +You can query rule and chain counters, e.g.: + + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> chain = iptc.Chain(table, 'OUTPUT') + >>> for rule in chain.rules: + >>> (packets, bytes) = rule.get_counters() + >>> print packets, bytes + +However, the counters are only refreshed when the underlying low-level +iptables connection is refreshed in `Table` via `table.refresh()`. For +example: + + >>> import time, sys + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> chain = iptc.Chain(table, 'OUTPUT') + >>> for rule in chain.rules: + >>> (packets, bytes) = rule.get_counters() + >>> print packets, bytes + >>> print "Please send some traffic" + >>> sys.stdout.flush() + >>> time.sleep(3) + >>> for rule in chain.rules: + >>> # Here you will get back the same counter values as above + >>> (packets, bytes) = rule.get_counters() + >>> print packets, bytes + +This will show you the same counter values even if there was traffic +hitting your rules. You have to refresh your table to get update your +counters: + + >>> import time, sys + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> chain = iptc.Chain(table, 'OUTPUT') + >>> for rule in chain.rules: + >>> (packets, bytes) = rule.get_counters() + >>> print packets, bytes + >>> print "Please send some traffic" + >>> sys.stdout.flush() + >>> time.sleep(3) + >>> table.refresh() # Here: refresh table to update rule counters + >>> for rule in chain.rules: + >>> (packets, bytes) = rule.get_counters() + >>> print packets, bytes + +What is more, if you add: + + iptables -A OUTPUT -p tcp --sport 80 + iptables -A OUTPUT -p tcp --sport 22 + +you can query rule and chain counters together with the protocol and +sport(or dport), e.g.: + + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> chain = iptc.Chain(table, 'OUTPUT') + >>> for rule in chain.rules: + >>> for match in rule.matches: + >>> (packets, bytes) = rule.get_counters() + >>> print packets, bytes, match.name, match.sport + +Autocommit +---------- + +`Python-iptables` by default automatically performs an iptables commit +after each operation. That is, after you add a rule in +`python-iptables`, that will take effect immediately. + +It may happen that you want to batch together certain operations. A +typical use case is traversing a chain and removing rules matching a +specific criteria. If you do this with autocommit enabled, after the +first delete operation, your chain's state will change and you have to +restart the traversal. You can do something like this: + + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> removed = True + >>> chain = iptc.Chain(table, "FORWARD") + >>> while removed == True: + >>> removed = False + >>> for rule in chain.rules: + >>> if rule.out_interface and "eth0" in rule.out_interface: + >>> chain.delete_rule(rule) + >>> removed = True + >>> break + +This is clearly not ideal and the code is not very readable. An +alternative is to disable autocommits, traverse the chain, removing one +or more rules, than commit it: + + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> table.autocommit = False + >>> chain = iptc.Chain(table, "FORWARD") + >>> for rule in chain.rules: + >>> if rule.out_interface and "eth0" in rule.out_interface: + >>> chain.delete_rule(rule) + >>> table.commit() + >>> table.autocommit = True + +The drawback is that Table is a singleton, and if you disable +autocommit, it will be disabled for all instances of that Table. + +Easy rules with dictionaries +---------------------------- +To simplify operations with ``python-iptables`` rules we have included support to define and convert Rules object into python dictionaries. + + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> chain = iptc.Chain(table, "INPUT") + >>> # Create an iptc.Rule object from dictionary + >>> rule_d = {'comment': {'comment': 'Match tcp.22'}, 'protocol': 'tcp', 'target': 'ACCEPT', 'tcp': {'dport': '22'}} + >>> rule = iptc.easy.encode_iptc_rule(rule_d) + >>> # Obtain a dictionary representation from the iptc.Rule + >>> iptc.easy.decode_iptc_rule(rule) + {'tcp': {'dport': '22'}, 'protocol': 'tcp', 'comment': {'comment': 'Match tcp.22'}, 'target': 'ACCEPT'} + + +Known Issues +============ + +These issues are mainly caused by complex interaction with upstream's +Netfilter implementation, and will require quite significant effort to +fix. Workarounds are available. + +- The `hashlimit` match requires explicitly setting + `hashlimit_htable_expire`. See [Issue + \#201](https://github.com/ldx/python-iptables/issues/201). +- The `NOTRACK` target is problematic; use `CT --notrack` instead. See + [Issue \#204](https://github.com/ldx/python-iptables/issues/204). -See [http://ldx.github.com/python-iptables/](http://ldx.github.com/python-iptables/) for documentation. diff --git a/debian/changelog b/debian/changelog index 9a73782..e8bab22 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,18 @@ +python-iptables (0.12.0) xenial; urgency=low + + * update debian changelog + + -- Dan Fuhry Fri, 11 Nov 2016 14:03:32 -0500 + +python-iptables (0.5-git-20140925) precise; urgency=low + + * update debian/ + remove cdbs + build python3 packages + build debug packages + + -- Markus Kötter Thu, 25 Sep 2014 10:48:41 +0200 + python-iptables (0.1.1) unstable; urgency=low * Initial Release. diff --git a/debian/control b/debian/control index b3582ab..2f82047 100644 --- a/debian/control +++ b/debian/control @@ -2,17 +2,40 @@ Source: python-iptables Section: net Priority: extra Maintainer: Juliano Martinez -Build-Depends: cdbs (>= 0.4.49), debhelper (>= 7), python (>= 2.4), python-support, python-all-dev (>= 2.3.5-11) -XS-Python-Version: >= 2.6 +Build-depends: python-all-dev (>= 2.7), python-all-dbg (>= 2.7), python3-all-dev (>= 3.2), python3-all-dbg, debhelper (>= 7) +X-Python-Version: >= 2.7 +X-Python3-Version: >= 3.2 Standards-Version: 3.8.4 Homepage: https://github.com/ldx/python-iptables Package: python-iptables Architecture: any -Depends: ${shlibs:Depends}, ${misc:Depends} +Depends: ${python:Depends}, ${shlibs:Depends}, ${misc:Depends} +Provides: ${python:Provides} Description: Python bindings for iptables Python-iptables is a Python project that provides bindings to the iptables C libraries in Linux. Interoperability with iptables is achieved using the iptables C libraries (libiptc, libxtables, and iptables extensions), not calling the iptables executable and parsing its output as most other iptables wrapper libraries do; this makes python-iptables faster and not prone to parsing errors, at the same time leveraging all available iptables match and target extensions without further work. + +Package: python-iptables-dbg +Section: debug +Priority: extra +Architecture: any +Depends: ${python:Depends}, ${shlibs:Depends}, ${misc:Depends}, python-iptables (= ${binary:Version}) +Provides: ${python:Provides} +Description: Python bindings for iptables + +Package: python3-iptables +Architecture: any +Depends: ${python3:Depends}, ${shlibs:Depends}, ${misc:Depends} +Description: Python3 bindings for iptables + +Package: python3-iptables-dbg +Section: debug +Priority: extra +Architecture: any +Depends: ${python3:Depends}, ${shlibs:Depends}, ${misc:Depends}, python3-iptables (= ${binary:Version}) +Description: Python3 bindings for iptables + diff --git a/debian/copyright b/debian/copyright index 9ab11ee..34a2451 100644 --- a/debian/copyright +++ b/debian/copyright @@ -8,11 +8,11 @@ It was downloaded from: Upstream Author(s): - ldx -> https://github.com/ldx + Vilmos Nebehaj -> https://github.com/ldx Copyright: - Nilvec -> http://nilvec.com + Vilmos Nebehaj License: diff --git a/debian/python3-iptables.postinst b/debian/python3-iptables.postinst new file mode 100644 index 0000000..7b46270 --- /dev/null +++ b/debian/python3-iptables.postinst @@ -0,0 +1,3 @@ +#!/bin/sh + +/sbin/ldconfig diff --git a/debian/rules b/debian/rules index 4cd3ea9..4fd181a 100755 --- a/debian/rules +++ b/debian/rules @@ -1,15 +1,50 @@ #!/usr/bin/make -f -# -*- makefile -*- -# Sample debian/rules that uses debhelper. -# This file was originally written by Joey Hess and Craig Small. -# As a special exception, when this file is copied by dh-make into a -# dh-make output file, you may use that output file without restriction. -# This special exception was added by Craig Small in version 0.37 of dh-make. -# Uncomment this to turn on verbose mode. -export DH_VERBOSE=1 +# This file was automatically generated by stdeb 0.6.0+git at +# Thu, 14 Nov 2013 16:04:50 +0100 -DEB_PYTHON_SYSTEM=pysupport +PY3VERS := $(shell py3versions -r) +PY2VERS := $(shell pyversions -r) + +%: + dh $@ --with python2,python3 --buildsystem=python_distutils + +.PHONY: override_dh_clean +override_dh_clean: + rm -rf build/* + dh_clean + +.PHONY: override_dh_auto_install +override_dh_auto_install: + set -e; \ + for py in $(PY3VERS); do \ + $$py -B setup.py install --root debian/python3-iptables --install-layout deb; \ + $$py-dbg -B setup.py install --root debian/python3-iptables-dbg --install-layout deb; \ + done + + # On Ubuntu 14.04 and possibly other versions, /usr/lib/python3/dist-packages is not + # included in the system library search path, so we add it here. + set -e; \ + mkdir -p $(CURDIR)/debian/python3-iptables/etc/ld.so.conf.d; \ + echo "/usr/lib/python3/dist-packages" > $(CURDIR)/debian/python3-iptables/etc/ld.so.conf.d/python3-dist-packages.conf + + set -e; \ + for py in $(PY2VERS); do \ + $$py -B setup.py install --root debian/python-iptables --install-layout deb; \ + $$py-dbg -B setup.py install --root debian/python-iptables-dbg --install-layout deb; \ + done + +.PHONY: override_dh_strip +override_dh_strip: + set -e; \ + dh_strip -ppython-ev --dbg-package=python-iptables-dbg; \ + dh_strip -ppython3-ev --dbg-package=python3-iptables-dbg; + +override_dh_python2: + dh_python2 -ppython-iptables + dh_python2 -ppython-iptables-dbg + +override_dh_python3: + dh_python3 -ppython3-iptables + dh_python3 -ppython3-iptables-dbg -include /usr/share/cdbs/1/rules/debhelper.mk -include /usr/share/cdbs/1/class/python-distutils.mk diff --git a/doc/Makefile b/doc/Makefile index 9ead21b..bd87a6b 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -99,3 +99,6 @@ gh-pages: git reset --hard git clean -f .. git checkout master + +markdown: intro.rst examples.rst + pandoc -o ../README.md -f rst -t markdown intro.rst examples.rst diff --git a/doc/conf.py b/doc/conf.py index 29e4d76..ea780c0 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -39,16 +39,16 @@ # General information about the project. project = u'python-iptables' -copyright = u'2010-2013, Nilvec' +copyright = u'2010-2014, Vilmos Nebehaj' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = '0.2.0' +version = '0.4.0' # The full version, including alpha/beta/rc tags. -release = '0.2.0-dev' +release = '0.4.0-dev' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -174,7 +174,7 @@ # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ ('index', 'python-iptables.tex', u'python-iptables Documentation', - u'Nilvec', 'manual'), + u'Vilmos Nebehaj', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of @@ -196,6 +196,6 @@ # Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'http://docs.python.org/': None} +intersphinx_mapping = {'python': ('http://docs.python.org/', None)} autoclass_content="both" diff --git a/doc/examples.rst b/doc/examples.rst index 4e7d302..98c6378 100644 --- a/doc/examples.rst +++ b/doc/examples.rst @@ -1,12 +1,43 @@ Examples ======== -Introduction ------------- +High level abstractions +----------------------- + +``python-iptables`` implements a low-level interface that tries to closely +match the underlying C libraries. The module ``iptc.easy`` improves the +usability of the library by providing a rich set of high-level functions +designed to simplify the interaction with the library, for example: + + >>> import iptc + >>> iptc.easy.dump_table('nat', ipv6=False) + {'INPUT': [], 'OUTPUT': [], 'POSTROUTING': [], 'PREROUTING': []} + >>> iptc.easy.dump_chain('filter', 'OUTPUT', ipv6=False) + [{'comment': {'comment': 'DNS traffic to Google'}, + 'counters': (1, 56), + 'dst': '8.8.8.8/32', + 'protocol': 'udp', + 'target': 'ACCEPT', + 'udp': {'dport': '53'}}] + >>> iptc.easy.add_chain('filter', 'TestChain') + True + >>> rule_d = {'protocol': 'tcp', 'target': 'ACCEPT', 'tcp': {'dport': '22'}} + >>> iptc.easy.insert_rule('filter', 'TestChain', rule_d) + >>> iptc.easy.dump_chain('filter', 'TestChain') + [{'protocol': 'tcp', 'target': 'ACCEPT', 'tcp': {'dport': '22'}}] + >>> iptc.easy.delete_chain('filter', 'TestChain', flush=True) + + >>> # Example of goto rule // iptables -A FORWARD -p gre -g TestChainGoto + >>> iptc.easy.add_chain('filter', 'TestChainGoto') + >>> rule_goto_d = {'protocol': 'gre', 'target': {'goto': 'TestChainGoto'}} + >>> iptc.easy.insert_rule('filter', 'FORWARD', rule_goto_d) + +Rules +----- In ``python-iptables``, you usually first create a rule, and set any source/destination address, in/out interface and protocol specifiers, for -example: +example:: >>> import iptc >>> rule = iptc.Rule() @@ -20,7 +51,7 @@ source IP address of 192.168.1.0/255.255.255.0. A rule may contain matches and a target. A match is like a filter matching certain packet attributes, while a target tells what to do with the packet (drop it, accept it, transform it somehow, etc). One can create a match or -target via a Rule: +target via a Rule:: >>> rule = iptc.Rule() >>> m = rule.create_match("tcp") @@ -29,7 +60,7 @@ target via a Rule: Match and target parameters can be changed after creating them. It is also perfectly valid to create a match or target via instantiating them with their constructor, but you still need a rule and you have to add the matches -and the target to their rule manually: +and the target to their rule manually:: >>> rule = iptc.Rule() >>> match = iptc.Match(rule, "tcp") @@ -38,35 +69,170 @@ and the target to their rule manually: >>> rule.target = target Any parameters a match or target might take can be set via the attributes of -the object. To set the destination port for a TCP match: +the object. To set the destination port for a TCP match:: >>> rule = iptc.Rule() >>> rule.protocol = "tcp" >>> match = rule.create_match("tcp") >>> match.dport = "80" -To set up a rule that matches packets marked with 0xff: +To set up a rule that matches packets marked with 0xff:: >>> rule = iptc.Rule() >>> rule.protocol = "tcp" >>> match = rule.create_match("mark") >>> match.mark = "0xff" -Parameters are always strings. +Parameters are always strings. You can supply any string as the parameter +value, but note that most extensions validate their parameters. For example +this:: + + >>> rule = iptc.Rule() + >>> rule.protocol = "tcp" + >>> rule.target = iptc.Target(rule, "ACCEPT") + >>> match = iptc.Match(rule, "state") + >>> chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), "INPUT") + >>> match.state = "RELATED,ESTABLISHED" + >>> rule.add_match(match) + >>> chain.insert_rule(rule) + +will work. However, if you change the `state` parameter:: + + >>> rule = iptc.Rule() + >>> rule.protocol = "tcp" + >>> rule.target = iptc.Target(rule, "ACCEPT") + >>> match = iptc.Match(rule, "state") + >>> chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), "INPUT") + >>> match.state = "RELATED,ESTABLISHED,FOOBAR" + >>> rule.add_match(match) + >>> chain.insert_rule(rule) + +``python-iptables`` will throw an exception:: + + Traceback (most recent call last): + File "state.py", line 7, in + match.state = "RELATED,ESTABLISHED,FOOBAR" + File "/home/user/Projects/python-iptables/iptc/ip4tc.py", line 369, in __setattr__ + self.parse(name.replace("_", "-"), value) + File "/home/user/Projects/python-iptables/iptc/ip4tc.py", line 286, in parse + self._parse(argv, inv, entry) + File "/home/user/Projects/python-iptables/iptc/ip4tc.py", line 516, in _parse + ct.cast(self._ptrptr, ct.POINTER(ct.c_void_p))) + File "/home/user/Projects/python-iptables/iptc/xtables.py", line 736, in new + ret = fn(*args) + File "/home/user/Projects/python-iptables/iptc/xtables.py", line 1031, in parse_match + argv[1])) + iptc.xtables.XTablesError: state: parameter error -2 (RELATED,ESTABLISHED,FOOBAR) + +Certain parameters take a string that optionally consists of multiple words. +The comment match is a good example:: + + >>> rule = iptc.Rule() + >>> rule.src = "127.0.0.1" + >>> rule.protocol = "udp" + >>> rule.target = rule.create_target("ACCEPT") + >>> match = rule.create_match("comment") + >>> match.comment = "this is a test comment" + >>> chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), "INPUT") + >>> chain.insert_rule(rule) + +Note that this is still just one parameter value. + +However, when a match or a target takes multiple parameter values, that needs +to be passed in as a list. Let's assume you have created and set up an +``ipset`` called ``blacklist`` via the ``ipset`` command. To create a rule +with a match for this set:: + + >>> rule = iptc.Rule() + >>> m = rule.create_match("set") + >>> m.match_set = ['blacklist', 'src'] + +Note how this time a list was used for the parameter value, since the ``set`` +match ``match_set`` parameter expects two values. See the ``iptables`` +manpages to find out what the extensions you use expect. See ipset_ for more +information. + +.. _ipset: http://ipset.netfilter.org/ When you are ready constructing your rule, add them to the chain you want it -to show up in: +to show up in:: >>> chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), "INPUT") >>> chain.insert_rule(rule) This will put your rule into the INPUT chain in the filter table. -Simple rule with standard target --------------------------------- +Chains and tables +----------------- + +You can of course also check what a rule's source/destination address, +in/out inteface etc is. To print out all rules in the FILTER table:: + + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> for chain in table.chains: + >>> print "=======================" + >>> print "Chain ", chain.name + >>> for rule in chain.rules: + >>> print "Rule", "proto:", rule.protocol, "src:", rule.src, "dst:", \ + >>> rule.dst, "in:", rule.in_interface, "out:", rule.out_interface, + >>> print "Matches:", + >>> for match in rule.matches: + >>> print match.name, + >>> print "Target:", + >>> print rule.target.name + >>> print "=======================" + +As you see in the code snippet above, rules are organized into chains, and +chains are in tables. You have a fixed set of tables; for IPv4: + +* ``FILTER``, +* ``NAT``, +* ``MANGLE`` and +* ``RAW``. + +For IPv6 the tables are: + +* ``FILTER``, +* ``MANGLE``, +* ``RAW`` and +* ``SECURITY``. + +To access a table:: + + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> print table.name + filter + +To create a new chain in the FILTER table:: -Reject packets with source address ``127.0.0.1/255.0.0.0`` coming in on any of -the eth interfaces: + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> chain = table.create_chain("testchain") + + $ sudo iptables -L -n + [...] + Chain testchain (0 references) + target prot opt source destination + +To access an existing chain:: + + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> chain = iptc.Chain(table, "INPUT") + >>> chain.name + 'INPUT' + >>> len(chain.rules) + 10 + >>> + +More about matches and targets +------------------------------ + +There are basic targets, such as ``DROP`` and ``ACCEPT``. E.g. to reject +packets with source address ``127.0.0.1/255.0.0.0`` coming in on any of the +``eth`` interfaces:: >>> import iptc >>> chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), "INPUT") @@ -77,25 +243,22 @@ the eth interfaces: >>> rule.target = target >>> chain.insert_rule(rule) -Simple rule not using any match extensions ------------------------------------------- +To instantiate a target or match, we can either create an object like above, +or use the ``rule.create_target(target_name)`` and +``rule.create_match(match_name)`` methods. For example, in the code above +target could have been created as:: -Inserting a rule to NAT TCP packets going out via ``eth0``: + >>> target = rule.create_target("DROP") - >>> import iptc - >>> chain = iptc.Chain(iptc.Table(iptc.Table.NAT), "POSTROUTING") - >>> rule = iptc.Rule() - >>> rule.protocol = "tcp" - >>> rule.out_interface = "eth0" - >>> target = iptc.Target(rule, "MASQUERADE") - >>> target.to_ports = "1234" +instead of:: + + >>> target = iptc.Target(rule, "DROP") >>> rule.target = target - >>> chain.insert_rule(rule) -Rule using the udp match extension ----------------------------------- +The former also adds the match or target to the rule, saving a call. -Mark packets going to ``192.168.1.2`` UDP port ``1234`` with ``0xffff``: +Another example, using a target which takes parameters. Let's mark packets +going to ``192.168.1.2`` UDP port ``1234`` with ``0xffff``:: >>> import iptc >>> chain = iptc.Chain(iptc.Table(iptc.Table.MANGLE), "PREROUTING") @@ -110,13 +273,25 @@ Mark packets going to ``192.168.1.2`` UDP port ``1234`` with ``0xffff``: >>> rule.target = target >>> chain.insert_rule(rule) -Multiple matches with iprange ------------------------------ +Matches are optional (specifying a target is mandatory). E.g. to insert a rule +to NAT TCP packets going out via ``eth0``:: + + >>> import iptc + >>> chain = iptc.Chain(iptc.Table(iptc.Table.NAT), "POSTROUTING") + >>> rule = iptc.Rule() + >>> rule.protocol = "tcp" + >>> rule.out_interface = "eth0" + >>> target = iptc.Target(rule, "MASQUERADE") + >>> target.to_ports = "1234" + >>> rule.target = target + >>> chain.insert_rule(rule) -Now we will add multiple matches to a rule. This one is the -``python-iptables`` equivalent of the following iptables command: +Here only the properties of the rule decide whether the rule will be applied +to a packet. -# iptables -A INPUT -p tcp –destination-port 22 -m iprange –src-range 192.168.1.100-192.168.1.200 –dst-range 172.22.33.106 -j DROP +Matches are optional, but we can add multiple matches to a rule. In the +following example we will do that, using the ``iprange`` and the ``tcp`` +matches:: >>> import iptc >>> rule = iptc.Rule() @@ -129,5 +304,164 @@ Now we will add multiple matches to a rule. This one is the >>> match.dst_range = "172.22.33.106" >>> rule.add_match(match) >>> rule.target = iptc.Target(rule, "DROP") - >>> chain = iptc.Chain(iptc.Table.(iptc.Table.FILTER), "INPUT") + >>> chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), "INPUT") >>> chain.insert_rule(rule) + +This is the ``python-iptables`` equivalent of the following iptables command:: + + # iptables -A INPUT -p tcp –destination-port 22 -m iprange –src-range 192.168.1.100-192.168.1.200 –dst-range 172.22.33.106 -j DROP + +You can of course negate matches, just like when you use ``!`` in front of a +match with iptables. For example:: + + >>> import iptc + >>> rule = iptc.Rule() + >>> match = iptc.Match(rule, "mac") + >>> match.mac_source = "!00:11:22:33:44:55" + >>> rule.add_match(match) + >>> rule.target = iptc.Target(rule, "ACCEPT") + >>> chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), "INPUT") + >>> chain.insert_rule(rule) + +This results in:: + + $ sudo iptables -L -n + Chain INPUT (policy ACCEPT) + target prot opt source destination + ACCEPT all -- 0.0.0.0/0 0.0.0.0/0 MAC ! 00:11:22:33:44:55 + + Chain FORWARD (policy ACCEPT) + target prot opt source destination + + Chain OUTPUT (policy ACCEPT) + target prot opt source destination + +Counters +-------- +You can query rule and chain counters, e.g.:: + + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> chain = iptc.Chain(table, 'OUTPUT') + >>> for rule in chain.rules: + >>> (packets, bytes) = rule.get_counters() + >>> print packets, bytes + +However, the counters are only refreshed when the underlying low-level +iptables connection is refreshed in ``Table`` via ``table.refresh()``. For +example:: + + >>> import time, sys + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> chain = iptc.Chain(table, 'OUTPUT') + >>> for rule in chain.rules: + >>> (packets, bytes) = rule.get_counters() + >>> print packets, bytes + >>> print "Please send some traffic" + >>> sys.stdout.flush() + >>> time.sleep(3) + >>> for rule in chain.rules: + >>> # Here you will get back the same counter values as above + >>> (packets, bytes) = rule.get_counters() + >>> print packets, bytes + +This will show you the same counter values even if there was traffic hitting +your rules. You have to refresh your table to get update your counters:: + + >>> import time, sys + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> chain = iptc.Chain(table, 'OUTPUT') + >>> for rule in chain.rules: + >>> (packets, bytes) = rule.get_counters() + >>> print packets, bytes + >>> print "Please send some traffic" + >>> sys.stdout.flush() + >>> time.sleep(3) + >>> table.refresh() # Here: refresh table to update rule counters + >>> for rule in chain.rules: + >>> (packets, bytes) = rule.get_counters() + >>> print packets, bytes + +What is more, if you add:: + + iptables -A OUTPUT -p tcp --sport 80 + iptables -A OUTPUT -p tcp --sport 22 + +you can query rule and chain counters together with the protocol and sport(or +dport), e.g.:: + + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> chain = iptc.Chain(table, 'OUTPUT') + >>> for rule in chain.rules: + >>> for match in rule.matches: + >>> (packets, bytes) = rule.get_counters() + >>> print packets, bytes, match.name, match.sport + +Autocommit +---------- +``Python-iptables`` by default automatically performs an iptables commit after +each operation. That is, after you add a rule in ``python-iptables``, that +will take effect immediately. + +It may happen that you want to batch together certain operations. A typical +use case is traversing a chain and removing rules matching a specific +criteria. If you do this with autocommit enabled, after the first delete +operation, your chain's state will change and you have to restart the +traversal. You can do something like this:: + + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> removed = True + >>> chain = iptc.Chain(table, "FORWARD") + >>> while removed == True: + >>> removed = False + >>> for rule in chain.rules: + >>> if rule.out_interface and "eth0" in rule.out_interface: + >>> chain.delete_rule(rule) + >>> removed = True + >>> break + +This is clearly not ideal and the code is not very readable. An alternative is +to disable autocommits, traverse the chain, removing one or more rules, than +commit it:: + + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> table.autocommit = False + >>> chain = iptc.Chain(table, "FORWARD") + >>> for rule in chain.rules: + >>> if rule.out_interface and "eth0" in rule.out_interface: + >>> chain.delete_rule(rule) + >>> table.commit() + >>> table.autocommit = True + +The drawback is that `Table` is a singleton, and if you disable autocommit, it +will be disabled for all instances of that `Table`. + +Easy rules with dictionaries +---------------------------- +To simplify operations with ``python-iptables`` rules we have included support to define and convert Rules object into python dictionaries. + + >>> import iptc + >>> table = iptc.Table(iptc.Table.FILTER) + >>> chain = iptc.Chain(table, "INPUT") + >>> # Create an iptc.Rule object from dictionary + >>> rule_d = {'comment': {'comment': 'Match tcp.22'}, 'protocol': 'tcp', 'target': 'ACCEPT', 'tcp': {'dport': '22'}} + >>> rule = iptc.easy.encode_iptc_rule(rule_d) + >>> # Obtain a dictionary representation from the iptc.Rule + >>> iptc.easy.decode_iptc_rule(rule) + {'tcp': {'dport': '22'}, 'protocol': 'tcp', 'comment': {'comment': 'Match tcp.22'}, 'target': 'ACCEPT'} + + +Known Issues +============ + +These issues are mainly caused by complex interaction with upstream's +Netfilter implementation, and will require quite significant effort to +fix. Workarounds are available. + +- The ``hashlimit`` match requires explicitly setting ``hashlimit_htable_expire``. See `Issue #201 `_. +- The ``NOTRACK`` target is problematic; use ``CT --notrack`` instead. See `Issue #204 `_. diff --git a/doc/intro.rst b/doc/intro.rst index 63446f1..97ec5a0 100644 --- a/doc/intro.rst +++ b/doc/intro.rst @@ -19,13 +19,56 @@ manpage puts it: rule specifies what to do with a packet that matches. This is called a `target`, which may be a jump to a user-defined chain in the same table. -``Python-iptables`` provides python bindings to iptables under Linux. -Interoperability with iptables is achieved via using the iptables C libraries -(``libiptc``, ``libxtables``, and the iptables extensions), not calling the -iptables binary and parsing its output. +``Python-iptables`` provides a pythonesque wrapper via python bindings to +iptables under Linux. Interoperability with iptables is achieved via using +the iptables C libraries (``libiptc``, ``libxtables``, and the iptables +extensions), not calling the iptables binary and parsing its output. It is +meant primarily for dynamic and/or complex routers and firewalls, where rules +are often updated or changed, or Python programs wish to interface with the +Linux iptables framework.. -Compiling and installing ------------------------- +If you are looking for ``ebtables`` python bindings, check out +`python-ebtables `_. + +``Python-iptables`` supports Python 2.6, 2.7 and 3.4. + +.. image:: http://api.flattr.com/button/flattr-badge-large.png + :target: https://flattr.com/submit/auto?user_id=ldx&url=https%3A%2F%2Fgithub.com%2Fldx%2Fpython-iptables + :alt: Flattr + +.. image:: https://pypip.in/v/python-iptables/badge.png + :target: https://pypi.python.org/pypi/python-iptables + :alt: Latest Release + +.. image:: https://travis-ci.org/ldx/python-iptables.png?branch=master + :target: https://travis-ci.org/ldx/python-iptables + :alt: Build Status + +.. image:: https://coveralls.io/repos/ldx/python-iptables/badge.svg?branch=codecoverage + :target: https://coveralls.io/r/ldx/python-iptables?branch=codecoverage + :alt: Coverage Status + +.. image:: https://landscape.io/github/ldx/python-iptables/codecoverage/landscape.svg + :target: https://landscape.io/github/ldx/python-iptables/codecoverage + :alt: Code Health + +.. image:: https://pypip.in/d/python-iptables/badge.png + :target: https://pypi.python.org/pypi/python-iptables + :alt: Number of Downloads + +.. image:: https://pypip.in/license/python-iptables/badge.png + :target: https://pypi.python.org/pypi/python-iptables + :alt: License + +Installing via pip +------------------ + +The usual way:: + + pip install --upgrade python-iptables + +Compiling from source +---------------------- First make sure you have iptables installed (most Linux distributions install it by default). ``Python-iptables`` needs the shared libraries ``libiptc.so`` @@ -50,71 +93,11 @@ installs into ``/usr/local/lib``. Now you can run the tests:: - % sudo PATH=$PATH ./test.py + % sudo PATH=$PATH python setup.py test WARNING: this test will manipulate iptables rules. Don't do this on a production machine. Would you like to continue? y/n y - test_table6 (iptc.test.test_iptc.TestTable6) ... ok - test_refresh (iptc.test.test_iptc.TestTable) ... ok - test_table (iptc.test.test_iptc.TestTable) ... ok - test_builtin_chain (iptc.test.test_iptc.TestChain) ... ok - test_chain (iptc.test.test_iptc.TestChain) ... ok - test_chain_counters (iptc.test.test_iptc.TestChain) ... ok - test_chain_policy (iptc.test.test_iptc.TestChain) ... ok - test_chains (iptc.test.test_iptc.TestChain) ... ok - test_create_chain (iptc.test.test_iptc.TestChain) ... ok - test_is_chain (iptc.test.test_iptc.TestChain) ... ok - test_rule_address (iptc.test.test_iptc.TestRule6) ... ok - test_rule_compare (iptc.test.test_iptc.TestRule6) ... ok - test_rule_interface (iptc.test.test_iptc.TestRule6) ... ok - test_rule_iterate (iptc.test.test_iptc.TestRule6) ... ok - test_rule_protocol (iptc.test.test_iptc.TestRule6) ... ok - test_rule_standard_target (iptc.test.test_iptc.TestRule6) ... ok - test_rule_address (iptc.test.test_iptc.TestRule) ... ok - test_rule_compare (iptc.test.test_iptc.TestRule) ... ok - test_rule_fragment (iptc.test.test_iptc.TestRule) ... ok - test_rule_interface (iptc.test.test_iptc.TestRule) ... ok - test_rule_iterate (iptc.test.test_iptc.TestRule) ... ok - test_rule_protocol (iptc.test.test_iptc.TestRule) ... ok - test_rule_standard_target (iptc.test.test_iptc.TestRule) ... ok - - ---------------------------------------------------------------------- - Ran 23 tests in 0.013s - - OK - test_match_compare (iptc.test.test_matches.TestMatch) ... ok - test_match_create (iptc.test.test_matches.TestMatch) ... ok - test_match_parameters (iptc.test.test_matches.TestMatch) ... ok - test_udp_insert (iptc.test.test_matches.TestXTUdpMatch) ... ok - test_udp_port (iptc.test.test_matches.TestXTUdpMatch) ... ok - test_mark (iptc.test.test_matches.TestXTMarkMatch) ... ok - test_mark_insert (iptc.test.test_matches.TestXTMarkMatch) ... ok - test_limit (iptc.test.test_matches.TestXTLimitMatch) ... ok - test_limit_insert (iptc.test.test_matches.TestXTLimitMatch) ... ok - test_comment (iptc.test.test_matches.TestCommentMatch) ... ok - test_iprange (iptc.test.test_matches.TestIprangeMatch) ... ok - test_iprange_tcpdport (iptc.test.test_matches.TestIprangeMatch) ... ok - - ---------------------------------------------------------------------- - Ran 12 tests in 0.024s - - OK - test_target_compare (iptc.test.test_targets.TestTarget) ... ok - test_target_create (iptc.test.test_targets.TestTarget) ... ok - test_target_parameters (iptc.test.test_targets.TestTarget) ... ok - test_insert (iptc.test.test_targets.TestXTClusteripTarget) ... ok - test_mode (iptc.test.test_targets.TestXTClusteripTarget) ... ok - test_insert (iptc.test.test_targets.TestIPTRedirectTarget) ... ok - test_mode (iptc.test.test_targets.TestIPTRedirectTarget) ... ok - test_insert (iptc.test.test_targets.TestXTTosTarget) ... ok - test_mode (iptc.test.test_targets.TestXTTosTarget) ... ok - test_insert (iptc.test.test_targets.TestIPTMasqueradeTarget) ... ok - test_mode (iptc.test.test_targets.TestIPTMasqueradeTarget) ... ok - - ---------------------------------------------------------------------- - Ran 11 tests in 0.015s - - OK + [...] The ``PATH=$PATH`` part is necessary after ``sudo`` if you have installed into a ``virtualenv``, since ``sudo`` will reset your environment to a system @@ -129,6 +112,29 @@ package can be imported:: Of course you need to be root to be able to use iptables. +Using a custom iptables install +------------------------------- + +If you are stuck on a system with an old version of ``iptables``, you can +install a more up to date version to a custom location, and ask +``python-iptables`` to use libraries at that location. + +To install ``iptables`` to ``/tmp/iptables``:: + + % git clone git://git.netfilter.org/iptables && cd iptables + % ./autogen.sh + % ./configure --prefix=/tmp/iptables + % make + % make install + +Make sure the dependencies ``iptables`` needs are installed. + +Now you can point ``python-iptables`` to this install path via:: + + % sudo PATH=$PATH IPTABLES_LIBDIR=/tmp/iptables/lib XTABLES_LIBDIR=/tmp/iptables/lib/xtables python + >>> import iptc + >>> + What is supported ----------------- @@ -136,9 +142,6 @@ The basic iptables framework and all the match/target extensions are supported by ``python-iptables``, including IPv4 and IPv6 ones. All IPv4 and IPv6 tables are supported as well. -Contact -------- - -ldx (at) nilvec.com +Full documentation with API reference is available here_. -http://nilvec.com +.. _here: http://ldx.github.com/python-iptables/ diff --git a/doc/usage.rst b/doc/usage.rst index 7534bcd..40c1584 100644 --- a/doc/usage.rst +++ b/doc/usage.rst @@ -76,6 +76,13 @@ Target :members: :inherited-members: +Rule +---- + +.. autoclass:: Rule + :members: + :inherited-members: + Rule6 ----- diff --git a/iptc/__init__.py b/iptc/__init__.py index 8544917..68c19bc 100644 --- a/iptc/__init__.py +++ b/iptc/__init__.py @@ -4,12 +4,13 @@ .. module:: iptc :synopsis: Python bindings for libiptc. -.. moduleauthor:: Nilvec +.. moduleauthor:: Vilmos Nebehaj """ -from ip4tc import (is_table_available, Table, Chain, Rule, Match, Target, - Policy, IPTCError) -from ip6tc import is_table6_available, Table6, Rule6 -from xtables import XTablesError +from iptc.ip4tc import (is_table_available, Table, Chain, Rule, Match, Target, Policy, IPTCError) +from iptc.ip6tc import is_table6_available, Table6, Rule6 +from iptc.errors import * +import iptc.easy + __all__ = [] diff --git a/iptc/easy.py b/iptc/easy.py new file mode 100644 index 0000000..d4df85e --- /dev/null +++ b/iptc/easy.py @@ -0,0 +1,461 @@ +# -*- coding: utf-8 -*- + +# TODO: +# - Add documentation +# - Add HowToUse examples + +from .ip4tc import Rule, Table, Chain, IPTCError +from .ip6tc import Rule6, Table6 + +_BATCH_MODE = False + +def flush_all(ipv6=False): + """ Flush all available tables """ + for table in get_tables(ipv6): + flush_table(table, ipv6) + +def flush_table(table, ipv6=False, raise_exc=True): + """ Flush a table """ + try: + iptc_table = _iptc_gettable(table, ipv6) + iptc_table.flush() + except Exception as e: + if raise_exc: raise + +def flush_chain(table, chain, ipv6=False, raise_exc=True): + """ Flush a chain in table """ + try: + iptc_chain = _iptc_getchain(table, chain, ipv6) + iptc_chain.flush() + except Exception as e: + if raise_exc: raise + +def zero_all(table, ipv6=False): + """ Zero all tables """ + for table in get_tables(ipv6): + zero_table(table, ipv6) + +def zero_table(table, ipv6=False): + """ Zero a table """ + iptc_table = _iptc_gettable(table, ipv6) + iptc_table.zero_entries() + +def zero_chain(table, chain, ipv6=False): + """ Zero a chain in table """ + iptc_chain = _iptc_getchain(table, chain, ipv6) + iptc_chain.zero_counters() + +def has_chain(table, chain, ipv6=False): + """ Return True if chain exists in table False otherwise """ + return _iptc_gettable(table, ipv6).is_chain(chain) + +def has_rule(table, chain, rule_d, ipv6=False): + """ Return True if rule exists in chain False otherwise """ + iptc_chain = _iptc_getchain(table, chain, ipv6) + iptc_rule = encode_iptc_rule(rule_d, ipv6) + return iptc_rule in iptc_chain.rules + +def add_chain(table, chain, ipv6=False, raise_exc=True): + """ Return True if chain was added successfully to a table, raise Exception otherwise """ + try: + iptc_table = _iptc_gettable(table, ipv6) + iptc_table.create_chain(chain) + return True + except Exception as e: + if raise_exc: raise + return False + +def add_rule(table, chain, rule_d, position=0, ipv6=False): + """ Add a rule to a chain in a given position. 0=append, 1=first, n=nth position """ + iptc_chain = _iptc_getchain(table, chain, ipv6) + iptc_rule = encode_iptc_rule(rule_d, ipv6) + if position == 0: + # Insert rule in last position -> append + iptc_chain.append_rule(iptc_rule) + elif position > 0: + # Insert rule in given position -> adjusted as iptables CLI + iptc_chain.insert_rule(iptc_rule, position - 1) + elif position < 0: + # Insert rule in given position starting from bottom -> not available in iptables CLI + nof_rules = len(iptc_chain.rules) + _position = position + nof_rules + # Insert at the top if the position has looped over + if _position <= 0: + _position = 0 + iptc_chain.insert_rule(iptc_rule, _position) + +def insert_rule(table, chain, rule_d, ipv6=False): + """ Add a rule to a chain in the 1st position """ + add_rule(table, chain, rule_d, position=1, ipv6=ipv6) + +def delete_chain(table, chain, ipv6=False, flush=False, raise_exc=True): + """ Delete a chain """ + try: + if flush: + flush_chain(table, chain, ipv6, raise_exc) + iptc_table = _iptc_gettable(table, ipv6) + iptc_table.delete_chain(chain) + except Exception as e: + if raise_exc: raise + +def delete_rule(table, chain, rule_d, ipv6=False, raise_exc=True): + """ Delete a rule from a chain """ + try: + iptc_chain = _iptc_getchain(table, chain, ipv6) + iptc_rule = encode_iptc_rule(rule_d, ipv6) + iptc_chain.delete_rule(iptc_rule) + except Exception as e: + if raise_exc: raise + +def get_tables(ipv6=False): + """ Get all tables """ + iptc_tables = _iptc_gettables(ipv6) + return [t.name for t in iptc_tables] + +def get_chains(table, ipv6=False): + """ Return the existing chains of a table """ + iptc_table = _iptc_gettable(table, ipv6) + return [iptc_chain.name for iptc_chain in iptc_table.chains] + +def get_rule(table, chain, position=0, ipv6=False, raise_exc=True): + """ Get a rule from a chain in a given position. 0=all rules, 1=first, n=nth position """ + try: + if position == 0: + # Return all rules + return dump_chain(table, chain, ipv6) + elif position > 0: + # Return specific rule by position + iptc_chain = _iptc_getchain(table, chain, ipv6) + iptc_rule = iptc_chain.rules[position - 1] + return decode_iptc_rule(iptc_rule, ipv6) + elif position < 0: + # Return last rule -> not available in iptables CLI + iptc_chain = _iptc_getchain(table, chain, ipv6) + iptc_rule = iptc_chain.rules[position] + return decode_iptc_rule(iptc_rule, ipv6) + except Exception as e: + if raise_exc: raise + +def replace_rule(table, chain, old_rule_d, new_rule_d, ipv6=False): + """ Replaces an existing rule of a chain """ + iptc_chain = _iptc_getchain(table, chain, ipv6) + iptc_old_rule = encode_iptc_rule(old_rule_d, ipv6) + iptc_new_rule = encode_iptc_rule(new_rule_d, ipv6) + iptc_chain.replace_rule(iptc_new_rule, iptc_chain.rules.index(iptc_old_rule)) + +def get_rule_counters(table, chain, rule_d, ipv6=False): + """ Return a tuple with the rule counters (numberOfBytes, numberOfPackets) """ + if not has_rule(table, chain, rule_d, ipv6): + raise AttributeError('Chain <{}@{}> has no rule <{}>'.format(chain, table, rule_d)) + iptc_chain = _iptc_getchain(table, chain, ipv6) + iptc_rule = encode_iptc_rule(rule_d, ipv6) + iptc_rule_index = iptc_chain.rules.index(iptc_rule) + return iptc_chain.rules[iptc_rule_index].get_counters() + +def get_rule_position(table, chain, rule_d, ipv6=False): + """ Return the position of a rule within a chain """ + if not has_rule(table, chain, rule_d): + raise AttributeError('Chain <{}@{}> has no rule <{}>'.format(chain, table, rule_d)) + iptc_chain = _iptc_getchain(table, chain, ipv6) + iptc_rule = encode_iptc_rule(rule_d, ipv6) + return iptc_chain.rules.index(iptc_rule) + + +def test_rule(rule_d, ipv6=False): + """ Return True if the rule is a well-formed dictionary, False otherwise """ + try: + encode_iptc_rule(rule_d, ipv6) + return True + except: + return False + +def test_match(name, value, ipv6=False): + """ Return True if the match is valid, False otherwise """ + try: + iptc_rule = Rule6() if ipv6 else Rule() + _iptc_setmatch(iptc_rule, name, value) + return True + except: + return False + +def test_target(name, value, ipv6=False): + """ Return True if the target is valid, False otherwise """ + try: + iptc_rule = Rule6() if ipv6 else Rule() + _iptc_settarget(iptc_rule, {name:value}) + return True + except: + return False + + +def get_policy(table, chain, ipv6=False): + """ Return the default policy of chain in a table """ + iptc_chain = _iptc_getchain(table, chain, ipv6) + return iptc_chain.get_policy().name + +def set_policy(table, chain, policy='ACCEPT', ipv6=False): + """ Set the default policy of chain in a table """ + iptc_chain = _iptc_getchain(table, chain, ipv6) + iptc_chain.set_policy(policy) + + +def dump_all(ipv6=False): + """ Return a dictionary representation of all tables """ + return {table: dump_table(table, ipv6) for table in get_tables(ipv6)} + +def dump_table(table, ipv6=False): + """ Return a dictionary representation of a table """ + return {chain: dump_chain(table, chain, ipv6) for chain in get_chains(table, ipv6)} + +def dump_chain(table, chain, ipv6=False): + """ Return a list with the dictionary representation of the rules of a table """ + iptc_chain = _iptc_getchain(table, chain, ipv6) + return [decode_iptc_rule(iptc_rule, ipv6) for iptc_rule in iptc_chain.rules] + + +def batch_begin(table = None, ipv6=False): + """ Disable autocommit on a table """ + _BATCH_MODE = True + if table: + tables = (table, ) + else: + tables = get_tables(ipv6) + for table in tables: + iptc_table = _iptc_gettable(table, ipv6) + iptc_table.autocommit = False + +def batch_end(table = None, ipv6=False): + """ Enable autocommit on table and commit changes """ + _BATCH_MODE = False + if table: + tables = (table, ) + else: + tables = get_tables(ipv6) + for table in tables: + iptc_table = _iptc_gettable(table, ipv6) + iptc_table.autocommit = True + +def batch_add_chains(table, chains, ipv6=False, flush=True): + """ Add multiple chains to a table """ + iptc_table = _batch_begin_table(table, ipv6) + for chain in chains: + if iptc_table.is_chain(chain): + iptc_chain = Chain(iptc_table, chain) + else: + iptc_chain = iptc_table.create_chain(chain) + if flush: + iptc_chain.flush() + _batch_end_table(table, ipv6) + +def batch_delete_chains(table, chains, ipv6=False): + """ Delete multiple chains of a table """ + iptc_table = _batch_begin_table(table, ipv6) + for chain in chains: + if iptc_table.is_chain(chain): + iptc_chain = Chain(iptc_table, chain) + iptc_chain.flush() + iptc_table.delete_chain(chain) + _batch_end_table(table, ipv6) + +def batch_add_rules(table, batch_rules, ipv6=False): + """ Add multiple rules to a table with format (chain, rule_d, position) """ + iptc_table = _batch_begin_table(table, ipv6) + for (chain, rule_d, position) in batch_rules: + iptc_chain = Chain(iptc_table, chain) + iptc_rule = encode_iptc_rule(rule_d, ipv6) + if position == 0: + # Insert rule in last position -> append + iptc_chain.append_rule(iptc_rule) + elif position > 0: + # Insert rule in given position -> adjusted as iptables CLI + iptc_chain.insert_rule(iptc_rule, position-1) + elif position < 0: + # Insert rule in given position starting from bottom -> not available in iptables CLI + nof_rules = len(iptc_chain.rules) + iptc_chain.insert_rule(iptc_rule, position + nof_rules) + _batch_end_table(table, ipv6) + +def batch_delete_rules(table, batch_rules, ipv6=False, raise_exc=True): + """ Delete multiple rules from table with format (chain, rule_d) """ + try: + iptc_table = _batch_begin_table(table, ipv6) + for (chain, rule_d) in batch_rules: + iptc_chain = Chain(iptc_table, chain) + iptc_rule = encode_iptc_rule(rule_d, ipv6) + iptc_chain.delete_rule(iptc_rule) + _batch_end_table(table, ipv6) + except Exception as e: + if raise_exc: raise + + +def encode_iptc_rule(rule_d, ipv6=False): + """ Return a Rule(6) object from the input dictionary """ + # Sanity check + assert(isinstance(rule_d, dict)) + # Basic rule attributes + rule_attr = ('src', 'dst', 'protocol', 'in-interface', 'out-interface', 'fragment') + iptc_rule = Rule6() if ipv6 else Rule() + # Set default target + rule_d.setdefault('target', '') + # Avoid issues with matches that require basic parameters to be configured first + for name in rule_attr: + if name in rule_d: + setattr(iptc_rule, name.replace('-', '_'), rule_d[name]) + for name, value in rule_d.items(): + try: + if name in rule_attr: + continue + elif name == 'counters': + _iptc_setcounters(iptc_rule, value) + elif name == 'target': + _iptc_settarget(iptc_rule, value) + else: + _iptc_setmatch(iptc_rule, name, value) + except Exception as e: + #print('Ignoring unsupported field <{}:{}>'.format(name, value)) + continue + return iptc_rule + +def decode_iptc_rule(iptc_rule, ipv6=False): + """ Return a dictionary representation of the Rule(6) object + Note: host IP addresses are appended their corresponding CIDR """ + d = {} + if ipv6==False and iptc_rule.src != '0.0.0.0/0.0.0.0': + _ip, _netmask = iptc_rule.src.split('/') + _netmask = _netmask_v4_to_cidr(_netmask) + d['src'] = '{}/{}'.format(_ip, _netmask) + elif ipv6==True and iptc_rule.src != '::/0': + d['src'] = iptc_rule.src + if ipv6==False and iptc_rule.dst != '0.0.0.0/0.0.0.0': + _ip, _netmask = iptc_rule.dst.split('/') + _netmask = _netmask_v4_to_cidr(_netmask) + d['dst'] = '{}/{}'.format(_ip, _netmask) + elif ipv6==True and iptc_rule.dst != '::/0': + d['dst'] = iptc_rule.dst + if iptc_rule.protocol != 'ip': + d['protocol'] = iptc_rule.protocol + if iptc_rule.in_interface is not None: + d['in-interface'] = iptc_rule.in_interface + if iptc_rule.out_interface is not None: + d['out-interface'] = iptc_rule.out_interface + if ipv6 == False and iptc_rule.fragment: + d['fragment'] = iptc_rule.fragment + for m in iptc_rule.matches: + if m.name not in d: + d[m.name] = m.get_all_parameters() + elif isinstance(d[m.name], list): + d[m.name].append(m.get_all_parameters()) + else: + d[m.name] = [d[m.name], m.get_all_parameters()] + if iptc_rule.target and iptc_rule.target.name and len(iptc_rule.target.get_all_parameters()): + name = iptc_rule.target.name.replace('-', '_') + d['target'] = {name:iptc_rule.target.get_all_parameters()} + elif iptc_rule.target and iptc_rule.target.name: + if iptc_rule.target.goto: + d['target'] = {'goto':iptc_rule.target.name} + else: + d['target'] = iptc_rule.target.name + # Get counters + d['counters'] = iptc_rule.counters + # Return a filtered dictionary + return _filter_empty_field(d) + +### INTERNAL FUNCTIONS ### +def _iptc_table_available(table, ipv6=False): + """ Return True if the table is available, False otherwise """ + try: + iptc_table = Table6(table) if ipv6 else Table(table) + return True + except: + return False + +def _iptc_gettables(ipv6=False): + """ Return an updated view of all available iptc_table """ + iptc_cls = Table6 if ipv6 else Table + return [_iptc_gettable(t, ipv6) for t in iptc_cls.ALL if _iptc_table_available(t, ipv6)] + +def _iptc_gettable(table, ipv6=False): + """ Return an updated view of an iptc_table """ + iptc_table = Table6(table) if ipv6 else Table(table) + if _BATCH_MODE is False: + iptc_table.commit() + iptc_table.refresh() + return iptc_table + +def _iptc_getchain(table, chain, ipv6=False, raise_exc=True): + """ Return an iptc_chain of an updated table """ + try: + iptc_table = _iptc_gettable(table, ipv6) + if not iptc_table.is_chain(chain): + raise AttributeError('Table <{}> has no chain <{}>'.format(table, chain)) + return Chain(iptc_table, chain) + except Exception as e: + if raise_exc: raise + +def _iptc_setcounters(iptc_rule, value): + # Value is a tuple (numberOfBytes, numberOfPackets) + iptc_rule.counters = value + +def _iptc_setmatch(iptc_rule, name, value): + # Iterate list/tuple recursively + if isinstance(value, list) or isinstance(value, tuple): + for inner_value in value: + _iptc_setmatch(iptc_rule, name, inner_value) + # Assign dictionary value + elif isinstance(value, dict): + iptc_match = iptc_rule.create_match(name) + [iptc_match.set_parameter(k, v) for k, v in value.items()] + # Assign value directly + else: + iptc_match = iptc_rule.create_match(name) + iptc_match.set_parameter(name, value) + +def _iptc_settarget(iptc_rule, value): + # Target is dictionary - Use only 1st pair key/value + if isinstance(value, dict): + t_name, t_value = next(iter(value.items())) + if t_name == 'goto': + iptc_target = iptc_rule.create_target(t_value, goto=True) + else: + iptc_target = iptc_rule.create_target(t_name) + [iptc_target.set_parameter(k, v) for k, v in t_value.items()] + # Simple target + else: + iptc_target = iptc_rule.create_target(value) + +def _batch_begin_table(table, ipv6=False): + """ Disable autocommit on a table """ + iptc_table = _iptc_gettable(table, ipv6) + iptc_table.autocommit = False + return iptc_table + +def _batch_end_table(table, ipv6=False): + """ Enable autocommit on table and commit changes """ + iptc_table = _iptc_gettable(table, ipv6) + iptc_table.autocommit = True + return iptc_table + +def _filter_empty_field(data_d): + """ + Remove empty lists from dictionary values + Before: {'target': {'CHECKSUM': {'checksum-fill': []}}} + After: {'target': {'CHECKSUM': {'checksum-fill': ''}}} + Before: {'tcp': {'dport': ['22']}}} + After: {'tcp': {'dport': '22'}}} + """ + for k, v in data_d.items(): + if isinstance(v, dict): + data_d[k] = _filter_empty_field(v) + elif isinstance(v, list) and len(v) != 0: + v = [_filter_empty_field(_v) if isinstance(_v, dict) else _v for _v in v ] + if isinstance(v, list) and len(v) == 1: + data_d[k] = v.pop() + elif isinstance(v, list) and len(v) == 0: + data_d[k] = '' + return data_d + +def _netmask_v4_to_cidr(netmask_addr): + # Implement Subnet Mask conversion without dependencies + return sum([bin(int(x)).count('1') for x in netmask_addr.split('.')]) + +### /INTERNAL FUNCTIONS ### diff --git a/iptc/errors.py b/iptc/errors.py new file mode 100644 index 0000000..fd25d9a --- /dev/null +++ b/iptc/errors.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- + + +class XTablesError(Exception): + """Raised when an xtables call fails for some reason.""" + + +__all__ = ['XTablesError'] diff --git a/iptc/ip4tc.py b/iptc/ip4tc.py index bba4ec3..90f5ee7 100644 --- a/iptc/ip4tc.py +++ b/iptc/ip4tc.py @@ -2,21 +2,31 @@ import os import re -import ctypes as ct import shlex +import sys +import ctypes as ct import socket import struct import weakref -from util import find_library -from xtables import (XT_INV_PROTO, NFPROTO_IPV4, XTablesError, xtables, - xt_align, xt_counters, xt_entry_target, xt_entry_match) +from .util import find_library, load_kernel, find_libc +from .xtables import (XT_INV_PROTO, NFPROTO_IPV4, XTablesError, xtables, + xt_align, xt_counters, xt_entry_target, xt_entry_match) __all__ = ["Table", "Chain", "Rule", "Match", "Target", "Policy", "IPTCError"] +try: + load_kernel("ip_tables") +except: + pass + +# Add IPPROTO_SCTP to socket module if not available +if not hasattr(socket, 'IPPROTO_SCTP'): + setattr(socket, 'IPPROTO_SCTP', 132) + _IFNAMSIZ = 16 -_libc = ct.CDLL("libc.so.6") +_libc = find_libc() _get_errno_loc = _libc.__errno_location _get_errno_loc.restype = ct.POINTER(ct.c_int) _malloc = _libc.malloc @@ -26,13 +36,20 @@ _free.restype = None _free.argtypes = [ct.POINTER(ct.c_ubyte)] +# Make sure xt_params is set up. +xtables(NFPROTO_IPV4) + def is_table_available(name): try: + if name in Table.existing_table_names: + return Table.existing_table_names[name] Table(name) + Table.existing_table_names[name] = True return True except IPTCError: pass + Table.existing_table_names[name] = False return False @@ -89,7 +106,7 @@ class ipt_entry(ct.Structure): class IPTCError(Exception): """This exception is raised when a low-level libiptc error occurs. - It contains a short description about the error that occured while + It contains a short description about the error that occurred while executing an iptables operation. """ @@ -100,7 +117,7 @@ class IPTCError(Exception): class iptc(object): """This class contains all libiptc API calls.""" iptc_init = _libiptc.iptc_init - iptc_init.restype = ct.c_void_p + iptc_init.restype = ct.POINTER(ct.c_int) iptc_init.argstype = [ct.c_char_p] iptc_free = _libiptc.iptc_free @@ -117,11 +134,11 @@ class iptc(object): iptc_first_chain = _libiptc.iptc_first_chain iptc_first_chain.restype = ct.c_char_p - iptc_first_chain.argstype = [ct.c_char_p, ct.c_void_p] + iptc_first_chain.argstype = [ct.c_void_p] iptc_next_chain = _libiptc.iptc_next_chain iptc_next_chain.restype = ct.c_char_p - iptc_next_chain.argstype = [ct.c_char_p, ct.c_void_p] + iptc_next_chain.argstype = [ct.c_void_p] iptc_is_chain = _libiptc.iptc_is_chain iptc_is_chain.restype = ct.c_int @@ -211,9 +228,9 @@ class iptc(object): # Check the packet `e' on chain `chain'. Returns the verdict, or # NULL and sets errno. - #iptc_check_packet = _libiptc.iptc_check_packet - #iptc_check_packet.restype = ct.c_char_p - #iptc_check_packet.argstype = [ct.c_char_p, ct.POINTER(ipt), ct.c_void_p] + # iptc_check_packet = _libiptc.iptc_check_packet + # iptc_check_packet.restype = ct.c_char_p + # iptc_check_packet.argstype = [ct.c_char_p, ct.POINTER(ipt), ct.c_void_p] # Get the number of references to this chain iptc_get_references = _libiptc.iptc_get_references @@ -245,7 +262,7 @@ class iptc(object): class IPTCModule(object): """Superclass for Match and Target.""" pattern = re.compile( - '\s*(!)?\s*--([-\w]+)\s+(!)?\s*("?[^"]*?"?)(?=\s*(?:!?\s*--|$))') + '\s*(!)?\s*--([-\w]+)\s+(!)?\s*"?([^"]*?)"?(?=\s*(?:!?\s*--|$))') def __init__(self): self._name = None @@ -256,26 +273,62 @@ def __init__(self): self._ptrptr = None raise NotImplementedError() + def set_parameter(self, parameter, value=None): + """ + Set a parameter for target or match extension, with an optional value. + + @param parameter: name of the parameter to set + @type parameter: C{str} + + @param value: optional value of the parameter, defaults to C{None} + @type value: C{str} or a C{list} of C{str} + """ + if value is None: + value = "" + + return self.parse(parameter.replace("_", "-"), value) + def parse(self, parameter, value): + # Parameter name must always be a string. + parameter = parameter.encode() + + # Check if we are dealing with an inverted parameter value. + inv = ct.c_int(0) + if len(value) > 0 and value[0] == "!": + inv = ct.c_int(1) + value = value[1:] + + # Value can be either a string, or a list of strings, e.g. "8888", + # "!0:65535" or ["!", "example_set", "dst"]. + args = [] + + is_str = isinstance(value, str) + try: + if not is_str: + is_str = isinstance(value, unicode) + except: + pass + + if is_str: + args = [value.encode()] + else: + try: + args = [val.encode() for val in value] + except: + raise TypeError("Invalid parameter value: " + "must be string or list of strings") + if not self._module.extra_opts and not self._module.x6_options: raise AttributeError("%s: invalid parameter %s" % (self._module.name, parameter)) - parameter = parameter.rstrip().lstrip() - value = value.rstrip().lstrip() - if "!" in value: - inv = ct.c_int(1) - value = value.replace("!", "") - else: - inv = ct.c_int(0) + parameter = parameter.strip() - args = shlex.split(value) - if not args: - args = [value] N = len(args) + argv = (ct.c_char_p * (N + 1))() argv[0] = parameter - for i in xrange(N): + for i in range(N): argv[i + 1] = args[i] entry = self._rule.entry and ct.pointer(self._rule.entry) or None @@ -287,32 +340,56 @@ def _parse(self, argv, inv, entry): def final_check(self): if self._module: + self._update_parameters() self._final_check() # subclasses override this def _final_check(self): raise NotImplementedError() - def save(self, name): - return self._save(name, self.rule.get_ip()) + def _get_saved_buf(self, ip): + if not self._module or not self._module.save: + return None - def _save(self, name, ip): - if self._module and self._module.save: - # redirect C stdout to a pipe and read back the output of m->save + # redirect C stdout to a pipe and read back the output of m->save + + # Flush stdout to avoid getting buffered results + sys.stdout.flush() + # Save the current C stdout. + stdout = os.dup(1) + try: + # Create a pipe and use the write end to replace the original C + # stdout. pipes = os.pipe() - saved_out = os.dup(1) os.dup2(pipes[1], 1) self._xt.save(self._module, ip, self._ptr) + + # Use the read end to read whatever was written. buf = os.read(pipes[0], 1024) - os.dup2(saved_out, 1) + + # Clean up the pipe. os.close(pipes[0]) os.close(pipes[1]) - os.close(saved_out) - if name: - return self._get_value(buf, name) - else: - return self._get_all_values(buf) - else: + return buf + finally: + # Put the original C stdout back in place. + os.dup2(stdout, 1) + + # Clean up the copy we made. + os.close(stdout) + + def save(self, name): + return self._save(name, self.rule.get_ip()) + + def _save(self, name, ip): + buf = self._get_saved_buf(ip).decode() + if buf is None: + return None + if not self._module or not self._module.save: return None + if name: + return self._get_value(buf, name) + else: + return self._get_all_values(buf) def _get_all_values(self, buf): table = {} # variable -> (value, inverted) @@ -337,27 +414,50 @@ def _get_value(self, buf, name): def get_all_parameters(self): params = {} ip = self.rule.get_ip() - if self._module and self._module.save: - # redirect C stdout to a pipe and read back the output of m->save - pipes = os.pipe() - saved_out = os.dup(1) - os.dup2(pipes[1], 1) - self._xt.save(self._module, ip, self._ptr) - buf = os.read(pipes[0], 1024) - os.dup2(saved_out, 1) - os.close(pipes[0]) - os.close(pipes[1]) - os.close(saved_out) - - res = re.findall(IPTCModule.pattern, buf) - for x in res: - params[x[1]] = "%s%s" % ((x[0] or x[2]) and "!" or "", x[3]) - + buf = self._get_saved_buf(ip) + if buf is None: + return params + if type(buf) != str: + # In Python3, string and bytes are different types. + buf = buf.decode() + res = shlex.split(buf) + res.reverse() + inv = False + key = None + while len(res) > 0: + x = res.pop() + if x == '!': + # Next parameter is negated. + inv = True + continue + if x.startswith('--'): # This is a parameter name. + key = x[2:] + if inv: + params[key] = ['!'] + else: + params[key] = [] + inv = False + continue + # At this point key should be set, unless the output from save is + # not formatted right. Let's be defensive, since some users + # reported that problem. + if key is not None: + params[key].append(x) # This is a parameter value. return params def _update_parameters(self): - for k, v in self.get_all_parameters().iteritems(): - self.__setattr__(k, v) + params = self.get_all_parameters().items() + self.reset() + for k, v in params: + self.set_parameter(k, v) + + def _get_alias_name(self): + if not self._module or not self._ptr: + return None + alias = getattr(self._module, 'alias', None) + if not alias: + return None + return self._module.alias(self._ptr).decode() def __setattr__(self, name, value): if not name.startswith('_') and name not in dir(self): @@ -377,7 +477,8 @@ def _get_parameters(self): contain those set by the module by default too.""" def _get_name(self): - return self._name + alias = self._get_alias_name() + return alias and alias or self._name name = property(_get_name) """Name of this target or match.""" @@ -390,6 +491,20 @@ def _set_rule(self, rule): """The rule this target or match belong to.""" +class _Buffer(object): + def __init__(self, size=0): + if size > 0: + self.buffer = _malloc(size) + if self.buffer is None: + raise Exception("Can't allocate buffer") + else: + self.buffer = None + + def __del__(self): + if self.buffer is not None: + _free(self.buffer) + + class Match(IPTCModule): """Matches are extensions which can match for special header fields or other attributes of a packet. @@ -421,15 +536,25 @@ def __init__(self, rule, name=None, match=None, revision=None): if not name and not match: raise ValueError("can't create match based on nothing") if not name: - name = match.u.user.name + name = match.u.user.name.decode() self._name = name self._rule = rule + self._orig_parse = None + self._orig_options = None self._xt = xtables(rule.nfproto) module = self._xt.find_match(name) + real_name = module and getattr(module[0], 'real_name', None) or None + if real_name: + # Alias name, look up real module. + self._name = real_name.decode() + self._orig_parse = getattr(module[0], 'x6_parse', None) + self._orig_options = getattr(module[0], 'x6_options', None) + module = self._xt.find_match(real_name) if not module: raise XTablesError("can't find match %s" % (name)) + self._module = module[0] self._module.mflags = 0 if revision is not None: @@ -441,29 +566,45 @@ def __init__(self, rule, name=None, match=None, revision=None): if match: ct.memmove(ct.byref(self._match_buf), ct.byref(match), self.size) self._update_pointers() - self._update_parameters() + self._check_alias() else: self.reset() + def _check_alias(self): + name = self._get_alias_name() + if name is None: + return + alias_module = self._xt.find_match(name) + if alias_module is None: + return + self._alias_module = alias_module[0] + self._orig_parse = getattr(self._alias_module, 'x6_parse', None) + self._orig_options = getattr(self._alias_module, 'x6_options', None) + def __eq__(self, match): basesz = ct.sizeof(xt_entry_match) - if (self.match.u.match_size == match.match.u.match_size and - self.match.u.user.name == match.match.u.user.name and - self.match.u.user.revision == match.match.u.user.revision and + if (self.name == match.name and self.match_buf[basesz:self.usersize] == - match.match_buf[basesz:match.usersize]): + match.match_buf[basesz:match.usersize]): return True return False - def __ne__(self, rule): - return not self.__eq__(rule) + def __hash__(self): + return (hash(self.match.u.match_size) ^ + hash(self.match.u.user.name) ^ + hash(self.match.u.user.revision) ^ + hash(bytes(self.match_buf))) + + def __ne__(self, match): + return not self.__eq__(match) def _final_check(self): self._xt.final_check_match(self._module) def _parse(self, argv, inv, entry): self._xt.parse_match(argv, inv, self._module, entry, - ct.cast(self._ptrptr, ct.POINTER(ct.c_void_p))) + ct.cast(self._ptrptr, ct.POINTER(ct.c_void_p)), + self._orig_parse, self._orig_options) def _get_size(self): return xt_align(self._module.size + ct.sizeof(xt_entry_match)) @@ -482,6 +623,11 @@ def _update_pointers(self): self._ptrptr = ct.cast(ct.pointer(self._ptr), ct.POINTER(ct.POINTER(xt_entry_match))) self._module.m = self._ptr + self._update_name() + + def _update_name(self): + m = self._ptr[0] + m.u.user.name = self.name.encode() def reset(self): """Reset the match. @@ -490,12 +636,15 @@ def reset(self): ct.memset(ct.byref(self._match_buf), 0, self.size) self._update_pointers() m = self._ptr[0] - m.u.user.name = self.name m.u.match_size = self.size m.u.user.revision = self._revision if self._module.init: self._module.init(self._ptr) self._module.mflags = 0 + udata_size = getattr(self._module, 'udata_size', 0) + if udata_size > 0: + udata_buf = (ct.c_ubyte * udata_size)() + self._module.udata = ct.cast(ct.byref(udata_buf), ct.c_void_p) def _get_match(self): return ct.cast(ct.byref(self.match_buf), ct.POINTER(xt_entry_match))[0] @@ -524,7 +673,12 @@ class Target(IPTCModule): does not take any value in the iptables extension, an empty string i.e. "" should be used. """ - def __init__(self, rule, name=None, target=None, revision=None): + + STANDARD_TARGETS = ["", "ACCEPT", "DROP", "REJECT", "RETURN", "REDIRECT", "SNAT", "DNAT", \ + "MASQUERADE", "MIRROR", "TOS", "MARK", "QUEUE", "LOG"] + """This is the constant for all standard targets.""" + + def __init__(self, rule, name=None, target=None, revision=None, goto=None): """ *rule* is the Rule object this match belongs to; it can be changed later via *set_rule()*. *name* is the name of the iptables target @@ -534,21 +688,54 @@ def __init__(self, rule, name=None, target=None, revision=None): should be used; different revisions use different structures in C and they usually only work with certain kernel versions. Python-iptables by default will use the latest revision available. + If goto is True, then it converts '-j' to '-g'. """ if name is None and target is None: raise ValueError("can't create target based on nothing") if name is None: - name = target.u.user.name + name = target.u.user.name.decode() self._name = name self._rule = rule + self._orig_parse = None + self._orig_options = None + + # NOTE: + # get_ip() returns the 'ip' structure that contains (1)the 'flags' field, and + # (2)the value for the GOTO flag. + # We *must* use get_ip() because the actual name of the field containing the + # structure apparently differs between implementation + ipstruct = rule.get_ip() + f_goto_attrs = [a for a in dir(ipstruct) if a.endswith('_F_GOTO')] + if len(f_goto_attrs) == 0: + raise RuntimeError('What kind of struct is this? It does not have "*_F_GOTO" constant!') + _F_GOTO = getattr(ipstruct, f_goto_attrs[0]) + + if target is not None or goto is None: + # We are 'decoding' existing Target + self._goto = bool(ipstruct.flags & _F_GOTO) + if goto is not None: + assert isinstance(goto, bool) + self._goto = goto + if goto: + ipstruct.flags |= _F_GOTO + else: + ipstruct.flags &= ~_F_GOTO self._xt = xtables(rule.nfproto) module = (self._is_standard_target() and - self._xt.find_target('standard') or + self._xt.find_target('') or self._xt.find_target(name)) + real_name = module and getattr(module[0], 'real_name', None) or None + if real_name: + # Alias name, look up real module. + self._name = real_name.decode() + self._orig_parse = getattr(module[0], 'x6_parse', None) + self._orig_options = getattr(module[0], 'x6_options', None) + module = self._xt.find_target(real_name) if not module: raise XTablesError("can't find target %s" % (name)) + self._module = module[0] self._module.tflags = 0 if revision is not None: @@ -556,48 +743,58 @@ def __init__(self, rule, name=None, target=None, revision=None): else: self._revision = self._module.revision - self._allocate_buffer(target) + self._create_buffer(target) if self._is_standard_target(): self.standard_target = name + elif target: + self._check_alias() - def __del__(self): - if getattr(self, "_target_buf", None) and self._target_buf is not None: - _free(self._target_buf) + def _check_alias(self): + name = self._get_alias_name() + if name is None: + return + alias_module = self._xt.find_target(name) + if alias_module is None: + return + self._alias_module = alias_module[0] + self._orig_parse = getattr(self._alias_module, 'x6_parse', None) + self._orig_options = getattr(self._alias_module, 'x6_options', None) def __eq__(self, targ): basesz = ct.sizeof(xt_entry_target) if (self.target.u.target_size != targ.target.u.target_size or - self.target.u.user.name != targ.target.u.user.name or - self.target.u.user.revision != targ.target.u.user.revision): + self.target.u.user.name != targ.target.u.user.name or + self.target.u.user.revision != targ.target.u.user.revision): return False - if (self.target.u.user.name == "" or - self.target.u.user.name == "standard" or - self.target.u.user.name == "ACCEPT" or - self.target.u.user.name == "DROP" or - self.target.u.user.name == "RETURN" or - self.target.u.user.name == "ERROR"): + if (self.target.u.user.name == b"" or + self.target.u.user.name == b"standard" or + self.target.u.user.name == b"ACCEPT" or + self.target.u.user.name == b"DROP" or + self.target.u.user.name == b"RETURN" or + self.target.u.user.name == b"ERROR" or + self._is_standard_target()): return True if (self._target_buf[basesz:self.usersize] == - targ._target_buf[basesz:targ.usersize]): + targ._target_buf[basesz:targ.usersize]): return True return False - def __ne__(self, rule): - return not self.__eq__(rule) + def __ne__(self, target): + return not self.__eq__(target) - def _allocate_buffer(self, target): - self._target_buf = _malloc(self.size) - if self._target_buf is None: - raise Exception("Can't allocate target buffer") + def _create_buffer(self, target): + self._buffer = _Buffer(self.size) + self._target_buf = self._buffer.buffer if target: ct.memmove(self._target_buf, ct.byref(target), self.size) self._update_pointers() - self._update_parameters() else: self.reset() def _is_standard_target(self): + if self._name in Target.STANDARD_TARGETS: + return False for t in self._rule.tables: if t.is_chain(self._name): return True @@ -608,8 +805,11 @@ def _final_check(self): def _parse(self, argv, inv, entry): self._xt.parse_target(argv, inv, self._module, entry, - ct.cast(self._ptrptr, ct.POINTER(ct.c_void_p))) + ct.cast(self._ptrptr, ct.POINTER(ct.c_void_p)), + self._orig_parse, self._orig_options) self._target_buf = ct.cast(self._module.t, ct.POINTER(ct.c_ubyte)) + if self._buffer.buffer != self._target_buf: + self._buffer.buffer = self._target_buf self._update_pointers() def _get_size(self): @@ -625,11 +825,15 @@ def _get_user_size(self): def _get_standard_target(self): t = self._ptr[0] - return t.u.user.name + return t.u.user.name.decode() def _set_standard_target(self, name): t = self._ptr[0] + if isinstance(name, str): + name = name.encode() t.u.user.name = name + if isinstance(name, bytes): + name = name.decode() self._name = name standard_target = property(_get_standard_target, _set_standard_target) """This attribute is used for standard targets. It can be set to @@ -641,6 +845,11 @@ def _update_pointers(self): self._ptrptr = ct.cast(ct.pointer(self._ptr), ct.POINTER(ct.POINTER(xt_entry_target))) self._module.t = self._ptr + self._update_name() + + def _update_name(self): + m = self._ptr[0] + m.u.user.name = self.name.encode() def reset(self): """Reset the target. Parameters are set to their default values, any @@ -648,18 +857,25 @@ def reset(self): ct.memset(self._target_buf, 0, self.size) self._update_pointers() t = self._ptr[0] - t.u.user.name = self.name t.u.target_size = self.size t.u.user.revision = self._revision if self._module.init: self._module.init(self._ptr) self._module.tflags = 0 + udata_size = getattr(self._module, 'udata_size', 0) + if udata_size > 0: + udata_buf = (ct.c_ubyte * udata_size)() + self._module.udata = ct.cast(ct.byref(udata_buf), ct.c_void_p) def _get_target(self): return self._ptr[0] target = property(_get_target) """This is the C structure used by the extension.""" + def _get_goto(self): + return self._goto + goto = property(_get_goto) + class Policy(object): """ @@ -712,12 +928,33 @@ class Rule(object): * One target. This determines what happens with the packet if it is matched. """ + protocols = {0: "all", + socket.IPPROTO_AH: "ah", + socket.IPPROTO_DSTOPTS: "dstopts", + socket.IPPROTO_EGP: "egp", + socket.IPPROTO_ESP: "esp", + socket.IPPROTO_FRAGMENT: "fragment", + socket.IPPROTO_GRE: "gre", + socket.IPPROTO_HOPOPTS: "hopopts", + socket.IPPROTO_ICMP: "icmp", + socket.IPPROTO_ICMPV6: "icmpv6", + socket.IPPROTO_IDP: "idp", + socket.IPPROTO_IGMP: "igmp", + socket.IPPROTO_IP: "ip", + socket.IPPROTO_IPIP: "ipip", + socket.IPPROTO_IPV6: "ipv6", + socket.IPPROTO_NONE: "none", + socket.IPPROTO_PIM: "pim", + socket.IPPROTO_PUP: "pup", + socket.IPPROTO_RAW: "raw", + socket.IPPROTO_ROUTING: "routing", + socket.IPPROTO_RSVP: "rsvp", + socket.IPPROTO_SCTP: "sctp", socket.IPPROTO_TCP: "tcp", + socket.IPPROTO_TP: "tp", socket.IPPROTO_UDP: "udp", - socket.IPPROTO_ICMP: "icmp", - socket.IPPROTO_ESP: "esp", - socket.IPPROTO_AH: "ah"} + } def __init__(self, entry=None, chain=None): """ @@ -739,10 +976,10 @@ def __eq__(self, rule): self._matches]): return False if (self.src == rule.src and self.dst == rule.dst and - self.protocol == rule.protocol and - self.fragment == rule.fragment and - self.in_interface == rule.in_interface and - self.out_interface == rule.out_interface): + self.protocol == rule.protocol and + self.fragment == rule.fragment and + self.in_interface == rule.in_interface and + self.out_interface == rule.out_interface): return True return False @@ -769,11 +1006,11 @@ def create_match(self, name, revision=None): self.add_match(match) return match - def create_target(self, name, revision=None): + def create_target(self, name, revision=None, goto=False): """Create a new *target*, and set it as this rule's target. *name* is the name of the target extension, *revision* is the revision to - use.""" - target = Target(self, name=name, revision=revision) + use. *goto* determines if target uses '-j' (default) or '-g'.""" + target = Target(self, name=name, revision=revision, goto=goto) self.target = target return target @@ -841,17 +1078,24 @@ def set_src(self, src): saddr = _a_to_i(socket.inet_pton(socket.AF_INET, addr)) except socket.error: raise ValueError("invalid address %s" % (addr)) - ina = in_addr() - ina.s_addr = ct.c_uint32(saddr) - self.entry.ip.src = ina - try: - nmask = _a_to_i(socket.inet_pton(socket.AF_INET, netm)) - except socket.error: - raise ValueError("invalid netmask %s" % (netm)) + if not netm.isdigit(): + try: + nmask = _a_to_i(socket.inet_pton(socket.AF_INET, netm)) + except socket.error: + raise ValueError("invalid netmask %s" % (netm)) + else: + imask = int(netm) + if imask > 32 or imask < 0: + raise ValueError("invalid netmask %s" % (netm)) + nmask = socket.htonl((2 ** imask - 1) << (32 - imask)) neta = in_addr() neta.s_addr = ct.c_uint32(nmask) self.entry.ip.smsk = neta + # Apply subnet mask to IP address + ina = in_addr() + ina.s_addr = ct.c_uint32(saddr & nmask) + self.entry.ip.src = ina src = property(get_src, set_src) """This is the source network address with an optional network mask in @@ -895,17 +1139,24 @@ def set_dst(self, dst): daddr = _a_to_i(socket.inet_pton(socket.AF_INET, addr)) except socket.error: raise ValueError("invalid address %s" % (addr)) - ina = in_addr() - ina.s_addr = ct.c_uint32(daddr) - self.entry.ip.dst = ina - try: - nmask = _a_to_i(socket.inet_pton(socket.AF_INET, netm)) - except socket.error: - raise ValueError("invalid netmask %s" % (netm)) + if not netm.isdigit(): + try: + nmask = _a_to_i(socket.inet_pton(socket.AF_INET, netm)) + except socket.error: + raise ValueError("invalid netmask %s" % (netm)) + else: + imask = int(netm) + if imask > 32 or imask < 0: + raise ValueError("invalid netmask %s" % (netm)) + nmask = socket.htonl((2 ** imask - 1) << (32 - imask)) neta = in_addr() neta.s_addr = ct.c_uint32(nmask) self.entry.ip.dmsk = neta + # Apply subnet mask to IP address + ina = in_addr() + ina.s_addr = ct.c_uint32(daddr & nmask) + self.entry.ip.dst = ina dst = property(get_dst, set_dst) """This is the destination network address with an optional network mask @@ -914,22 +1165,19 @@ def set_dst(self, dst): def get_in_interface(self): intf = "" if self.entry.ip.invflags & ipt_ip.IPT_INV_VIA_IN: - intf = "".join(["!", intf]) - iface = bytearray(_IFNAMSIZ) - iface[:len(self.entry.ip.iniface)] = self.entry.ip.iniface - mask = bytearray(_IFNAMSIZ) - mask[:len(self.entry.ip.iniface_mask)] = self.entry.ip.iniface_mask - if mask[0] == 0: + intf = "!" + + iface = self.entry.ip.iniface.decode() + mask = self.entry.ip.iniface_mask + + if len(mask) == 0: return None - for i in xrange(_IFNAMSIZ): - if mask[i] != 0: - intf = "".join([intf, chr(iface[i])]) - else: - if iface[i - 1] != 0: - intf = "".join([intf, "+"]) - else: - intf = intf[:-1] - break + + intf += iface + if len(iface) == len(mask): + intf += '+' + intf = intf[:_IFNAMSIZ] + return intf def set_in_interface(self, intf): @@ -946,10 +1194,11 @@ def set_in_interface(self, intf): intf = intf[:-1] masklen -= 2 - self.entry.ip.iniface = "".join([intf, '\x00' * (_IFNAMSIZ - - len(intf))]) - self.entry.ip.iniface_mask = "".join(['\x01' * masklen, '\x00' * - (_IFNAMSIZ - masklen)]) + self.entry.ip.iniface = b"".join([intf.encode(), + b'\x00' * (_IFNAMSIZ - len(intf))]) + self.entry.ip.iniface_mask = b"".join([b'\xff' * masklen, + b'\x00' * (_IFNAMSIZ - + masklen)]) in_interface = property(get_in_interface, set_in_interface) """This is the input network interface e.g. *eth0*. A wildcard match can @@ -958,22 +1207,19 @@ def set_in_interface(self, intf): def get_out_interface(self): intf = "" if self.entry.ip.invflags & ipt_ip.IPT_INV_VIA_OUT: - intf = "".join(["!", intf]) - iface = bytearray(_IFNAMSIZ) - iface[:len(self.entry.ip.outiface)] = self.entry.ip.outiface - mask = bytearray(_IFNAMSIZ) - mask[:len(self.entry.ip.outiface_mask)] = self.entry.ip.outiface_mask - if mask[0] == 0: + intf = "!" + + iface = self.entry.ip.outiface.decode() + mask = self.entry.ip.outiface_mask + + if len(mask) == 0: return None - for i in xrange(_IFNAMSIZ): - if mask[i] != 0: - intf = "".join([intf, chr(iface[i])]) - else: - if iface[i - 1] != 0: - intf = "".join([intf, "+"]) - else: - intf = intf[:-1] - break + + intf += iface + if len(iface) == len(mask): + intf += '+' + intf = intf[:_IFNAMSIZ] + return intf def set_out_interface(self, intf): @@ -990,10 +1236,11 @@ def set_out_interface(self, intf): intf = intf[:-1] masklen -= 2 - self.entry.ip.outiface = "".join([intf, '\x00' * (_IFNAMSIZ - - len(intf))]) - self.entry.ip.outiface_mask = "".join(['\x01' * masklen, '\x00' * - (_IFNAMSIZ - masklen)]) + self.entry.ip.outiface = b"".join([intf.encode(), + b'\x00' * (_IFNAMSIZ - len(intf))]) + self.entry.ip.outiface_mask = b"".join([b'\xff' * masklen, + b'\x00' * (_IFNAMSIZ - + masklen)]) out_interface = property(get_out_interface, set_out_interface) """This is the output network interface e.g. *eth0*. A wildcard match can @@ -1007,7 +1254,10 @@ def get_fragment(self): def set_fragment(self, frag): self.entry.ip.invflags &= ~ipt_ip.IPT_INV_FRAG & ipt_ip.IPT_INV_MASK - self.entry.ip.flags = int(bool(frag)) + if frag: + self.entry.ip.flags |= ipt_ip.IPT_F_FRAG + else: + self.entry.ip.flags &= ~ipt_ip.IPT_F_FRAG fragment = property(get_fragment, set_fragment) """This means that the rule refers to the second and further fragments of @@ -1018,16 +1268,20 @@ def get_protocol(self): proto = "!" else: proto = "" - proto = "".join([proto, self.protocols[self.entry.ip.proto]]) + proto = "".join([proto, self.protocols.get(self.entry.ip.proto, str(self.entry.ip.proto))]) return proto def set_protocol(self, proto): + proto = str(proto) if proto[0] == "!": self.entry.ip.invflags |= ipt_ip.IPT_INV_PROTO proto = proto[1:] else: self.entry.ip.invflags &= (~ipt_ip.IPT_INV_PROTO & ipt_ip.IPT_INV_MASK) + if proto.isdigit(): + self.entry.ip.proto = int(proto) + return for p in self.protocols.items(): if proto.lower() == p[1]: self.entry.ip.proto = p[0] @@ -1043,6 +1297,15 @@ def get_counters(self): counters = self.entry.counters return counters.pcnt, counters.bcnt + def set_counters(self, counters): + """This method set a tuple pair of the packet and byte counters of + the rule.""" + self.entry.counters.pcnt = counters[0] + self.entry.counters.bcnt = counters[1] + + counters = property(get_counters, set_counters) + """This is the packet and byte counters of the rule.""" + # override the following three for the IPv6 subclass def _entry_size(self): return xt_align(ct.sizeof(ipt_entry)) @@ -1095,11 +1358,12 @@ def _set_rule(self, entry): ct.POINTER(self._entry_type()))[0] if not isinstance(entry, self._entry_type()): - raise TypeError() + raise TypeError("Invalid rule type %s; expected %s" % + (entry, self._entry_type())) entrysz = self._entry_size() matchsz = entry.target_offset - entrysz - #targetsz = entry.next_offset - entry.target_offset + # targetsz = entry.next_offset - entry.target_offset # iterate over matches to create blob if matchsz: @@ -1134,14 +1398,14 @@ def _get_mask(self): # fill it out pos = 0 - for i in xrange(pos, pos + entrysz): + for i in range(pos, pos + entrysz): mask[i] = 0xff pos += entrysz for m in self._matches: - for i in xrange(pos, pos + m.usersize): + for i in range(pos, pos + m.usersize): mask[i] = 0xff pos += m.size - for i in xrange(pos, pos + self._target.usersize): + for i in range(pos, pos + self._target.usersize): mask[i] = 0xff return mask @@ -1160,10 +1424,11 @@ class Chain(object): _cache = weakref.WeakValueDictionary() def __new__(cls, table, name): - obj = Chain._cache.get(table.name + "." + name, None) + table_name = type(table).__name__ + "." + table.name + obj = Chain._cache.get(table_name + "." + name, None) if not obj: obj = object.__new__(cls) - Chain._cache[table.name + "." + name] = obj + Chain._cache[table_name + "." + name] = obj return obj def __init__(self, table, name): @@ -1231,6 +1496,14 @@ def insert_rule(self, rule, position=0): raise ValueError("invalid rule") self.table.insert_entry(self.name, rbuf, position) + def replace_rule(self, rule, position=0): + """Replace existing rule in the chain at *position* with given + *rule*""" + rbuf = rule.rule + if not rbuf: + raise ValueError("invalid rule") + self.table.replace_entry(self.name, rbuf, position) + def delete_rule(self, rule): """Removes *rule* from the chain.""" rule.final_check() @@ -1256,7 +1529,12 @@ def _get_rules(self): return [self.table.create_rule(e, self) for e in entries] rules = property(_get_rules) - """This is the list of rules currently in the chain.""" + """This is the list of rules currently in the chain. + + The indexes of the Rule items produced from this list *should* correspond + to the IPTables --line-numbers value minus one. Keeping in mind that + iptables rules are 1-indexed whereas the Python list is 0-indexed + """ def autocommit(fn): @@ -1299,10 +1577,14 @@ class Table(object): """This is the constant for the raw table.""" NAT = "nat" """This is the constant for the nat table.""" - ALL = ["filter", "mangle", "raw", "nat"] + SECURITY = "security" + """This is the constant for the security table.""" + ALL = ["filter", "mangle", "raw", "nat", "security"] """This is the constant for all tables.""" - _cache = weakref.WeakValueDictionary() + _cache = dict() + existing_table_names = dict() + """Dictionary to check faster if a table is available.""" def __new__(cls, name, autocommit=None): obj = Table._cache.get(name, None) @@ -1346,8 +1628,9 @@ def _free(self, ignore_exc=True): if self._handle is None: raise IPTCError("table is not initialized") try: - self.commit() - except IPTCError, e: + if self.autocommit: + self.commit() + except IPTCError as e: if not ignore_exc: raise e finally: @@ -1359,7 +1642,7 @@ def refresh(self): if self._handle: self._free() - handle = self._iptc.iptc_init(self.name) + handle = self._iptc.iptc_init(self.name.encode()) if not handle: raise IPTCError("can't initialize %s: %s" % (self.name, self.strerror())) @@ -1369,7 +1652,7 @@ def is_chain(self, chain): """Returns *True* if *chain* exists as a chain.""" if isinstance(chain, Chain): chain = chain.name - if self._iptc.iptc_is_chain(chain, self._handle): + if self._iptc.iptc_is_chain(chain.encode(), self._handle): return True else: return False @@ -1378,7 +1661,7 @@ def builtin_chain(self, chain): """Returns *True* if *chain* is a built-in chain.""" if isinstance(chain, Chain): chain = chain.name - if self._iptc.iptc_builtin(chain, self._handle): + if self._iptc.iptc_builtin(chain.encode(), self._handle): return True else: return False @@ -1395,7 +1678,7 @@ def create_chain(self, chain): """Create a new chain *chain*.""" if isinstance(chain, Chain): chain = chain.name - rv = self._iptc.iptc_create_chain(chain, self._handle) + rv = self._iptc.iptc_create_chain(chain.encode(), self._handle) if rv != 1: raise IPTCError("can't create chain %s: %s" % (chain, self.strerror())) @@ -1406,7 +1689,7 @@ def delete_chain(self, chain): """Delete chain *chain* from the table.""" if isinstance(chain, Chain): chain = chain.name - rv = self._iptc.iptc_delete_chain(chain, self._handle) + rv = self._iptc.iptc_delete_chain(chain.encode(), self._handle) if rv != 1: raise IPTCError("can't delete chain %s: %s" % (chain, self.strerror())) @@ -1416,7 +1699,8 @@ def rename_chain(self, chain, new_name): """Rename chain *chain* to *new_name*.""" if isinstance(chain, Chain): chain = chain.name - rv = self._iptc.iptc_rename_chain(chain, new_name, self._handle) + rv = self._iptc.iptc_rename_chain(chain.encode(), new_name.encode(), + self._handle) if rv != 1: raise IPTCError("can't rename chain %s: %s" % (chain, self.strerror())) @@ -1426,7 +1710,7 @@ def flush_entries(self, chain): """Flush all rules from *chain*.""" if isinstance(chain, Chain): chain = chain.name - rv = self._iptc.iptc_flush_entries(chain, self._handle) + rv = self._iptc.iptc_flush_entries(chain.encode(), self._handle) if rv != 1: raise IPTCError("can't flush chain %s: %s" % (chain, self.strerror())) @@ -1436,7 +1720,7 @@ def zero_entries(self, chain): """Zero the packet and byte counters of *chain*.""" if isinstance(chain, Chain): chain = chain.name - rv = self._iptc.iptc_zero_entries(chain, self._handle) + rv = self._iptc.iptc_zero_entries(chain.encode(), self._handle) if rv != 1: raise IPTCError("can't zero chain %s counters: %s" % (chain, self.strerror())) @@ -1456,9 +1740,10 @@ def set_policy(self, chain, policy, counters=None): cntrs = ct.pointer(cntrs) else: cntrs = None - rv = self._iptc.iptc_set_policy(chain, policy, cntrs, self._handle) + rv = self._iptc.iptc_set_policy(chain.encode(), policy.encode(), + cntrs, self._handle) if rv != 1: - raise IPTCError("can't set policy %s on chain %s: %s)" % + raise IPTCError("can't set policy %s on chain %s: %s" % (policy, chain, self.strerror())) @autocommit @@ -1469,8 +1754,8 @@ def get_policy(self, chain): if not self.builtin_chain(chain): return None, None cntrs = xt_counters() - pol = self._iptc.iptc_get_policy(chain, ct.pointer(cntrs), - self._handle) + pol = self._iptc.iptc_get_policy(chain.encode(), ct.pointer(cntrs), + self._handle).decode() if not pol: raise IPTCError("can't get policy on chain %s: %s" % (chain, self.strerror())) @@ -1479,33 +1764,46 @@ def get_policy(self, chain): @autocommit def append_entry(self, chain, entry): """Appends rule *entry* to *chain*.""" - rv = self._iptc.iptc_append_entry(chain, ct.cast(entry, ct.c_void_p), + rv = self._iptc.iptc_append_entry(chain.encode(), + ct.cast(entry, ct.c_void_p), self._handle) if rv != 1: - raise IPTCError("can't append entry to chain %s: %s)" % + raise IPTCError("can't append entry to chain %s: %s" % (chain, self.strerror())) @autocommit def insert_entry(self, chain, entry, position): """Inserts rule *entry* into *chain* at position *position*.""" - rv = self._iptc.iptc_insert_entry(chain, ct.cast(entry, ct.c_void_p), + rv = self._iptc.iptc_insert_entry(chain.encode(), + ct.cast(entry, ct.c_void_p), position, self._handle) if rv != 1: - raise IPTCError("can't insert entry into chain %s: %s)" % + raise IPTCError("can't insert entry into chain %s: %s" % + (chain, self.strerror())) + + @autocommit + def replace_entry(self, chain, entry, position): + """Replace existing rule in *chain* at *position* with given *rule*.""" + rv = self._iptc.iptc_replace_entry(chain.encode(), + ct.cast(entry, ct.c_void_p), + position, self._handle) + if rv != 1: + raise IPTCError("can't replace entry in chain %s: %s" % (chain, self.strerror())) @autocommit def delete_entry(self, chain, entry, mask): """Removes rule *entry* with *mask* from *chain*.""" - rv = self._iptc.iptc_delete_entry(chain, ct.cast(entry, ct.c_void_p), + rv = self._iptc.iptc_delete_entry(chain.encode(), + ct.cast(entry, ct.c_void_p), mask, self._handle) if rv != 1: - raise IPTCError("can't delete entry from chain %s: %s)" % + raise IPTCError("can't delete entry from chain %s: %s" % (chain, self.strerror())) def first_rule(self, chain): """Returns the first rule in *chain* or *None* if it is empty.""" - rule = self._iptc.iptc_first_rule(chain, self._handle) + rule = self._iptc.iptc_first_rule(chain.encode(), self._handle) if rule: return rule[0] else: @@ -1529,6 +1827,7 @@ def _get_chains(self): chains = [] chain = self._iptc.iptc_first_chain(self._handle) while chain: + chain = chain.decode() chains.append(Chain(self, chain)) chain = self._iptc.iptc_next_chain(self._handle) return chains @@ -1538,10 +1837,11 @@ def _get_chains(self): def flush(self): """Flush and delete all non-builtin chains the table.""" + for chain in self.chains: + chain.flush() for chain in self.chains: if not self.builtin_chain(chain): - chain.flush() - chain.delete() + self.delete_chain(chain) def create_rule(self, entry=None, chain=None): return Rule(entry, chain) diff --git a/iptc/ip6tc.py b/iptc/ip6tc.py index 6e16756..e0d8787 100644 --- a/iptc/ip6tc.py +++ b/iptc/ip6tc.py @@ -2,25 +2,31 @@ import ctypes as ct import socket -import weakref -from ip4tc import Rule, Table, IPTCError -from util import find_library, load_kernel -from xtables import (XT_INV_PROTO, NFPROTO_IPV6, xt_align, xt_counters) +from .ip4tc import Rule, Table, IPTCError +from .util import find_library, load_kernel +from .xtables import (XT_INV_PROTO, NFPROTO_IPV6, xt_align, xt_counters) __all__ = ["Table6", "Rule6"] -load_kernel("ip6_tables") +try: + load_kernel("ip6_tables") +except: + pass _IFNAMSIZ = 16 def is_table6_available(name): try: + if name in Table6.existing_table_names: + return Table6.existing_table_names[name] Table6(name) + Table6.existing_table_names[name] = True return True except IPTCError: pass + Table6.existing_table_names[name] = False return False @@ -82,7 +88,7 @@ class ip6t_entry(ct.Structure): class ip6tc(object): """This class contains all libip6tc API calls.""" iptc_init = _libiptc.ip6tc_init - iptc_init.restype = ct.c_void_p + iptc_init.restype = ct.POINTER(ct.c_int) iptc_init.argstype = [ct.c_char_p] iptc_free = _libiptc.ip6tc_free @@ -99,11 +105,11 @@ class ip6tc(object): iptc_first_chain = _libiptc.ip6tc_first_chain iptc_first_chain.restype = ct.c_char_p - iptc_first_chain.argstype = [ct.c_char_p, ct.c_void_p] + iptc_first_chain.argstype = [ct.c_void_p] iptc_next_chain = _libiptc.ip6tc_next_chain iptc_next_chain.restype = ct.c_char_p - iptc_next_chain.argstype = [ct.c_char_p, ct.c_void_p] + iptc_next_chain.argstype = [ct.c_void_p] iptc_is_chain = _libiptc.ip6tc_is_chain iptc_is_chain.restype = ct.c_int @@ -193,9 +199,9 @@ class ip6tc(object): # Check the packet `e' on chain `chain'. Returns the verdict, or # NULL and sets errno. - #iptc_check_packet = _libiptc.ip6tc_check_packet - #iptc_check_packet.restype = ct.c_char_p - #iptc_check_packet.argstype = [ct.c_char_p, ct.POINTER(ipt), ct.c_void_p] + # iptc_check_packet = _libiptc.ip6tc_check_packet + # iptc_check_packet.restype = ct.c_char_p + # iptc_check_packet.argstype = [ct.c_char_p, ct.POINTER(ipt), ct.c_void_p] # Get the number of references to this chain iptc_get_references = _libiptc.ip6tc_get_references @@ -243,9 +249,9 @@ def __eq__(self, rule): if x in self._matches]): return False if (self.src == rule.src and self.dst == rule.dst and - self.protocol == rule.protocol and - self.in_interface == rule.in_interface and - self.out_interface == rule.out_interface): + self.protocol == rule.protocol and + self.in_interface == rule.in_interface and + self.out_interface == rule.out_interface): return True return False @@ -266,16 +272,16 @@ def _count_bits(self, n): return bits def _create_mask(self, plen): - mask = [0 for x in xrange(16)] - i = 0 - while plen > 0: + mask = [] + for i in range(16): if plen >= 8: - mask[i] = 0xff + mask.append(0xff) + elif plen > 0: + mask.append(0xff>>(8-plen)<<(8-plen)) else: - mask[i] = 2 ** plen - 1 - i += 1 + mask.append(0x00) plen -= 8 - return "".join([chr(x) for x in mask]) + return mask def get_src(self): src = "" @@ -336,8 +342,7 @@ def set_src(self, src): plen = int(netm) if plen < 0 or plen > 128: raise ValueError("invalid prefix length %d" % (plen)) - self.entry.ipv6.smsk.s6_addr = arr.from_buffer_copy( - self._create_mask(plen)) + self.entry.ipv6.smsk.s6_addr = arr(*self._create_mask(plen)) return # nope, we got an IPv6 address-style prefix @@ -392,8 +397,7 @@ def set_dst(self, dst): plen = int(netm) if plen < 0 or plen > 128: raise ValueError("invalid prefix length %d" % (plen)) - self.entry.ipv6.dmsk.s6_addr = arr.from_buffer_copy( - self._create_mask(plen)) + self.entry.ipv6.dmsk.s6_addr = arr(*self._create_mask(plen)) return # nope, we got an IPv6 address-style prefix @@ -412,22 +416,19 @@ def set_dst(self, dst): def get_in_interface(self): intf = "" if self.entry.ipv6.invflags & ip6t_ip6.IP6T_INV_VIA_IN: - intf = "".join(["!", intf]) - iface = bytearray(_IFNAMSIZ) - iface[:len(self.entry.ipv6.iniface)] = self.entry.ipv6.iniface - mask = bytearray(_IFNAMSIZ) - mask[:len(self.entry.ipv6.iniface_mask)] = self.entry.ipv6.iniface_mask - if mask[0] == 0: + intf = "!" + + iface = self.entry.ipv6.iniface.decode() + mask = self.entry.ipv6.iniface_mask + + if len(mask) == 0: return None - for i in xrange(_IFNAMSIZ): - if mask[i] != 0: - intf = "".join([intf, chr(iface[i])]) - else: - if iface[i - 1] != 0: - intf = "".join([intf, "+"]) - else: - intf = intf[:-1] - break + + intf += iface + if len(iface) == len(mask): + intf += '+' + intf = intf[:_IFNAMSIZ] + return intf def set_in_interface(self, intf): @@ -444,10 +445,10 @@ def set_in_interface(self, intf): intf = intf[:-1] masklen -= 2 - self.entry.ipv6.iniface = "".join([intf, '\x00' * (_IFNAMSIZ - - len(intf))]) - self.entry.ipv6.iniface_mask = "".join(['\x01' * masklen, '\x00' * - (_IFNAMSIZ - masklen)]) + self.entry.ipv6.iniface = ("".join( + [intf, '\x00' * (_IFNAMSIZ - len(intf))])).encode() + self.entry.ipv6.iniface_mask = ("".join( + ['\x01' * masklen, '\x00' * (_IFNAMSIZ - masklen)])).encode() in_interface = property(get_in_interface, set_in_interface) """This is the input network interface e.g. *eth0*. A wildcard match can @@ -456,23 +457,19 @@ def set_in_interface(self, intf): def get_out_interface(self): intf = "" if self.entry.ipv6.invflags & ip6t_ip6.IP6T_INV_VIA_OUT: - intf = "".join(["!", intf]) - iface = bytearray(_IFNAMSIZ) - iface[:len(self.entry.ipv6.outiface)] = self.entry.ipv6.outiface - mask = bytearray(_IFNAMSIZ) - mask[:len(self.entry.ipv6.outiface_mask)] = \ - self.entry.ipv6.outiface_mask - if mask[0] == 0: + intf = "!" + + iface = self.entry.ipv6.outiface.decode() + mask = self.entry.ipv6.outiface_mask + + if len(mask) == 0: return None - for i in xrange(_IFNAMSIZ): - if mask[i] != 0: - intf = "".join([intf, chr(iface[i])]) - else: - if iface[i - 1] != 0: - intf = "".join([intf, "+"]) - else: - intf = intf[:-1] - break + + intf += iface + if len(iface) == len(mask): + intf += '+' + intf = intf[:_IFNAMSIZ] + return intf def set_out_interface(self, intf): @@ -489,10 +486,10 @@ def set_out_interface(self, intf): intf = intf[:-1] masklen -= 2 - self.entry.ipv6.outiface = "".join([intf, '\x00' * (_IFNAMSIZ - - len(intf))]) - self.entry.ipv6.outiface_mask = "".join(['\x01' * masklen, '\x00' * - (_IFNAMSIZ - masklen)]) + self.entry.ipv6.outiface = ("".join( + [intf, '\x00' * (_IFNAMSIZ - len(intf))])).encode() + self.entry.ipv6.outiface_mask = ("".join( + ['\x01' * masklen, '\x00' * (_IFNAMSIZ - masklen)])).encode() out_interface = property(get_out_interface, set_out_interface) """This is the output network interface e.g. *eth0*. A wildcard match can @@ -503,16 +500,23 @@ def get_protocol(self): proto = "!" else: proto = "" - proto = "".join([proto, self.protocols[self.entry.ipv6.proto]]) + proto = "".join([proto, self.protocols.get(self.entry.ipv6.proto, str(self.entry.ipv6.proto))]) return proto def set_protocol(self, proto): + proto = str(proto) if proto[0] == "!": self.entry.ipv6.invflags |= ip6t_ip6.IP6T_INV_PROTO + self.entry.ipv6.flags &= (~ip6t_ip6.IP6T_F_PROTO & + ip6t_ip6.IP6T_F_MASK) proto = proto[1:] else: self.entry.ipv6.invflags &= (~ip6t_ip6.IP6T_INV_PROTO & ip6t_ip6.IP6T_INV_MASK) + self.entry.ipv6.flags |= ip6t_ip6.IP6T_F_PROTO + if proto.isdigit(): + self.entry.ipv6.proto = int(proto) + return for p in self.protocols.items(): if proto.lower() == p[1]: self.entry.ipv6.proto = p[0] @@ -564,12 +568,16 @@ class Table6(Table): """This is the constant for the mangle table.""" RAW = "raw" """This is the constant for the raw table.""" + NAT = "nat" + """This is the constant for the nat table.""" SECURITY = "security" """This is the constant for the security table.""" - ALL = ["filter", "mangle", "raw", "security"] + ALL = ["filter", "mangle", "raw", "nat", "security"] """This is the constant for all tables.""" - _cache = weakref.WeakValueDictionary() + _cache = dict() + existing_table_names = dict() + """Dictionary to check faster if a table is available.""" def __new__(cls, name, autocommit=None): obj = Table6._cache.get(name, None) diff --git a/iptc/test/test_matches.py b/iptc/test/test_matches.py deleted file mode 100755 index f565d56..0000000 --- a/iptc/test/test_matches.py +++ /dev/null @@ -1,287 +0,0 @@ -# -*- coding: utf-8 -*- - -import unittest -import iptc - - -class TestMatch(unittest.TestCase): - def setUp(self): - pass - - def tearDown(self): - pass - - def test_match_create(self): - rule = iptc.Rule() - match = rule.create_match("udp") - - for m in rule.matches: - self.failUnless(m == match) - - # check that we can change match parameters after creation - match.sport = "12345:55555" - match.dport = "!33333" - - m = iptc.Match(iptc.Rule(), "udp") - m.sport = "12345:55555" - m.dport = "!33333" - - self.failUnless(m == match) - - def test_match_compare(self): - m1 = iptc.Match(iptc.Rule(), "udp") - m1.sport = "12345:55555" - m1.dport = "!33333" - - m2 = iptc.Match(iptc.Rule(), "udp") - m2.sport = "12345:55555" - m2.dport = "!33333" - - self.failUnless(m1 == m2) - - m2.reset() - m2.sport = "12345:55555" - m2.dport = "33333" - self.failIf(m1 == m2) - - def test_match_parameters(self): - m = iptc.Match(iptc.Rule(), "udp") - m.sport = "12345:55555" - m.dport = "!33333" - - self.failUnless(len(m.parameters) == 2) - - for p in m.parameters: - self.failUnless(p == "sport" or p == "dport") - - self.failUnless(m.parameters["sport"] == "12345:55555") - self.failUnless(m.parameters["dport"] == "!33333") - - m.reset() - self.failUnless(len(m.parameters) == 0) - - -class TestXTUdpMatch(unittest.TestCase): - def setUp(self): - self.rule = iptc.Rule() - self.rule.src = "127.0.0.1" - self.rule.protocol = "udp" - self.rule.target = iptc.Target(self.rule, "ACCEPT") - - self.match = iptc.Match(self.rule, "udp") - self.chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), "iptc_test_udp") - iptc.Table(iptc.Table.FILTER).create_chain(self.chain) - - def tearDown(self): - self.chain.flush() - self.chain.delete() - - def test_udp_port(self): - for port in ["12345", "12345:65535", "!12345", "12345:12346", - "!12345:12346", "0:1234", "! 1234", "!0:12345", - "!1234:65535"]: - self.match.sport = port - self.assertEquals(self.match.sport, port.replace(" ", "")) - self.match.dport = port - self.assertEquals(self.match.dport, port.replace(" ", "")) - self.match.reset() - for port in ["-1", "asdf", "!asdf"]: - try: - self.match.sport = port - except Exception: - pass - else: - self.fail("udp accepted invalid source port %s" % (port)) - try: - self.match.dport = port - except Exception: - pass - else: - self.fail("udp accepted invalid destination port %s" % (port)) - self.match.reset() - - def test_udp_insert(self): - self.match.reset() - self.match.dport = "12345" - self.rule.add_match(self.match) - - self.chain.insert_rule(self.rule) - - for r in self.chain.rules: - if r != self.rule: - self.fail("inserted rule does not match original") - - -class TestXTMarkMatch(unittest.TestCase): - def setUp(self): - self.rule = iptc.Rule() - self.rule.src = "127.0.0.1" - self.rule.protocol = "tcp" - self.rule.target = iptc.Target(self.rule, "ACCEPT") - - self.match = iptc.Match(self.rule, "mark") - - self.chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), - "iptc_test_mark") - iptc.Table(iptc.Table.FILTER).create_chain(self.chain) - - def tearDown(self): - self.chain.flush() - self.chain.delete() - - def test_mark(self): - for mark in ["0x7b", "! 0x7b", "0x7b/0xfffefffe", "!0x7b/0xff00ff00"]: - self.match.mark = mark - self.assertEquals(self.match.mark, mark.replace(" ", "")) - self.match.reset() - for mark in ["0xffffffffff", "123/0xffffffff1", "!asdf", "1234:1233"]: - try: - self.match.mark = mark - except Exception: - pass - else: - self.fail("mark accepted invalid value %s" % (mark)) - self.match.reset() - - def test_mark_insert(self): - self.match.reset() - self.match.mark = "0x123" - self.rule.add_match(self.match) - - self.chain.insert_rule(self.rule) - - for r in self.chain.rules: - if r != self.rule: - self.fail("inserted rule does not match original") - - -class TestXTLimitMatch(unittest.TestCase): - def setUp(self): - self.rule = iptc.Rule() - self.rule.src = "127.0.0.1" - self.rule.protocol = "tcp" - self.rule.target = iptc.Target(self.rule, "ACCEPT") - - self.match = iptc.Match(self.rule, "limit") - self.chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), - "iptc_test_limit") - iptc.Table(iptc.Table.FILTER).create_chain(self.chain) - - def tearDown(self): - self.chain.flush() - self.chain.delete() - - def test_limit(self): - for limit in ["1/sec", "5/min", "3/hour"]: - self.match.limit = limit - self.assertEquals(self.match.limit, limit) - self.match.reset() - for limit in ["asdf", "123/1", "!1", "!1/second"]: - try: - self.match.limit = limit - except Exception: - pass - else: - self.fail("limit accepted invalid value %s" % (limit)) - self.match.reset() - - def test_limit_insert(self): - self.match.reset() - self.match.limit = "1/min" - self.rule.add_match(self.match) - - self.chain.insert_rule(self.rule) - - for r in self.chain.rules: - if r != self.rule: - self.fail("inserted rule does not match original") - - -class TestCommentMatch(unittest.TestCase): - def setUp(self): - self.rule = iptc.Rule() - self.rule.src = "127.0.0.1" - self.rule.protocol = "udp" - self.rule.target = iptc.Target(self.rule, "ACCEPT") - - self.match = iptc.Match(self.rule, "comment") - self.chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), - "iptc_test_comment") - iptc.Table(iptc.Table.FILTER).create_chain(self.chain) - - def tearDown(self): - self.chain.flush() - self.chain.delete() - - def test_comment(self): - comment = "comment test" - self.match.reset() - self.match.comment = "\"%s\"" % (comment) - self.chain.insert_rule(self.rule) - self.assertEquals(self.match.comment.replace('"', ''), comment) - - -class TestIprangeMatch(unittest.TestCase): - def setUp(self): - self.rule = iptc.Rule() - self.rule.protocol = "tcp" - self.rule.target = iptc.Target(self.rule, "ACCEPT") - - self.match = iptc.Match(self.rule, "iprange") - - self.chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), - "iptc_test_iprange") - iptc.Table(iptc.Table.FILTER).create_chain(self.chain) - - def tearDown(self): - self.chain.flush() - self.chain.delete() - - def test_iprange(self): - self.match.src_range = "192.168.1.100-192.168.1.200" - self.match.dst_range = "172.22.33.106" - self.rule.add_match(self.match) - - self.chain.insert_rule(self.rule) - - for r in self.chain.rules: - if r != self.rule: - self.fail("inserted rule does not match original") - - def test_iprange_tcpdport(self): - self.match.src_range = "192.168.1.100-192.168.1.200" - self.match.dst_range = "172.22.33.106" - self.rule.add_match(self.match) - - match = iptc.Match(self.rule, "tcp") - match.dport = "22" - self.rule.add_match(match) - - self.chain.insert_rule(self.rule) - - for r in self.chain.rules: - if r != self.rule: - self.fail("inserted rule does not match original") - - -def suite(): - suite_match = unittest.TestLoader().loadTestsFromTestCase(TestMatch) - suite_udp = unittest.TestLoader().loadTestsFromTestCase(TestXTUdpMatch) - suite_mark = unittest.TestLoader().loadTestsFromTestCase(TestXTMarkMatch) - suite_limit = unittest.TestLoader().loadTestsFromTestCase(TestXTLimitMatch) - suite_comment = unittest.TestLoader().loadTestsFromTestCase( - TestCommentMatch) - suite_iprange = unittest.TestLoader().loadTestsFromTestCase( - TestIprangeMatch) - return unittest.TestSuite([suite_match, suite_udp, suite_mark, - suite_limit, suite_comment, suite_iprange]) - - -def run_tests(): - result = unittest.TextTestRunner(verbosity=2).run(suite()) - if result.errors or result.failures: - return 1 - return 0 - -if __name__ == "__main__": - unittest.main() diff --git a/iptc/util.py b/iptc/util.py index 4b4215c..6c1592c 100644 --- a/iptc/util.py +++ b/iptc/util.py @@ -1,7 +1,24 @@ import re +import os +import sys import ctypes import ctypes.util +from itertools import product from subprocess import Popen, PIPE +from sys import version_info +try: + from sysconfig import get_config_var +except ImportError: + def get_config_var(name): + if name == 'SO': + return '.so' + raise Exception('Not implemented') +try: + from distutils.sysconfig import get_python_lib +except ModuleNotFoundError: + import sysconfig + def get_python_lib(): + return sysconfig.get_path("purelib") def _insert_ko(modprobe, modname): @@ -11,9 +28,17 @@ def _insert_ko(modprobe, modname): def _load_ko(modname): + # only try to load modules on kernels that support them + if not os.path.exists("/proc/modules"): + return (0, None) + # this will return the full path for the modprobe binary - proc = open("/proc/sys/kernel/modprobe") - modprobe = proc.read(1024) + modprobe = "/sbin/modprobe" + try: + proc = open("/proc/sys/kernel/modprobe") + modprobe = proc.read(1024) + except: + pass if modprobe[-1] == '\n': modprobe = modprobe[:-1] return _insert_ko(modprobe, modname) @@ -32,14 +57,17 @@ def load_kernel(name, exc_if_failed=False): def _do_find_library(name): + if '/' in name: + try: + return ctypes.CDLL(name, mode=ctypes.RTLD_GLOBAL) + except Exception: + return None p = ctypes.util.find_library(name) if p: lib = ctypes.CDLL(p, mode=ctypes.RTLD_GLOBAL) return lib # probably we have been installed in a virtualenv - import os - from distutils.sysconfig import get_python_lib try: lib = ctypes.CDLL(os.path.join(get_python_lib(), name), mode=ctypes.RTLD_GLOBAL) @@ -47,7 +75,6 @@ def _do_find_library(name): except: pass - import sys for p in sys.path: try: lib = ctypes.CDLL(os.path.join(p, name), mode=ctypes.RTLD_GLOBAL) @@ -58,8 +85,29 @@ def _do_find_library(name): def _find_library(*names): + exts = [] + if version_info >= (3, 3): + exts.append(get_config_var("EXT_SUFFIX")) + else: + exts.append(get_config_var('SO')) + + if version_info >= (3, 5): + exts.append('.so') + for name in names: - for n in (name, "lib" + name, name + ".so", "lib" + name + ".so"): + libnames = [name, "lib" + name] + for ext in exts: + libnames += [name + ext, "lib" + name + ext] + libdir = os.environ.get('IPTABLES_LIBDIR', None) + if libdir is not None: + libdirs = libdir.split(':') + libs = [os.path.join(*p) for p in product(libdirs, libnames)] + libs.extend(libnames) + else: + libs = libnames + for n in libs: + while os.path.islink(n): + n = os.path.realpath(n) lib = _do_find_library(n) if lib is not None: yield lib @@ -68,8 +116,24 @@ def _find_library(*names): def find_library(*names): for lib in _find_library(*names): major = 0 - m = re.search(r"\.so\.(\d+)", lib._name) + m = re.search(r"\.so\.(\d+).?", lib._name) if m: major = int(m.group(1)) return lib, major return None, None + + +def find_libc(): + lib = ctypes.util.find_library('c') + if lib is not None: + return ctypes.CDLL(lib, mode=ctypes.RTLD_GLOBAL) + + libnames = ['libc.so.6', 'libc.so.0', 'libc.so'] + for name in libnames: + try: + lib = ctypes.CDLL(name, mode=ctypes.RTLD_GLOBAL) + return lib + except: + pass + + return None diff --git a/iptc/version.py b/iptc/version.py index 954537d..f6a5157 100644 --- a/iptc/version.py +++ b/iptc/version.py @@ -1,4 +1,4 @@ # -*- coding: utf-8 -*- __pkgname__ = "python-iptables" -__version__ = "0.2.0-dev" +__version__ = "1.2.0" diff --git a/iptc/xtables.py b/iptc/xtables.py index 11a96dc..024779c 100644 --- a/iptc/xtables.py +++ b/iptc/xtables.py @@ -4,9 +4,10 @@ import os import sys import weakref -import version -from util import find_library +from . import version +from .util import find_library, find_libc +from .errors import * XT_INV_PROTO = 0x40 # invert the sense of PROTO @@ -23,6 +24,7 @@ XTF_TRY_LOAD = 0x02 XTF_LOAD_MUST_SUCCEED = 0x03 + XTOPT_INVERT = 1 << 0 XTOPT_MAND = 1 << 1 XTOPT_MULTI = 1 << 2 @@ -83,7 +85,9 @@ class xtables_globals(ct.Structure): ("program_version", ct.c_char_p), ("orig_opts", ct.c_void_p), ("opts", ct.c_void_p), - ("exit_err", ct.CFUNCTYPE(None, ct.c_int, ct.c_char_p))] + ("exit_err", ct.CFUNCTYPE(None, ct.c_int, ct.c_char_p)), + ("compat_rev", ct.CFUNCTYPE(ct.c_int, ct.c_char_p, ct.c_uint8, + ct.c_int))] # struct used by getopt() @@ -386,7 +390,61 @@ class _xtables_match_v10(ct.Structure): ("save", ct.CFUNCTYPE(None, ct.c_void_p, ct.POINTER(xt_entry_match))), # Print match name or alias - ("alias", ct.CFUNCTYPE(ct.c_char_p, ct.c_void_p, + ("alias", ct.CFUNCTYPE(ct.c_char_p, + ct.POINTER(xt_entry_match))), + # pointer to list of extra command-line options + ("extra_opts", ct.POINTER(option)), + + # introduced with the new iptables API + ("x6_parse", ct.CFUNCTYPE(None, ct.POINTER(xt_option_call))), + ("x6_fcheck", ct.CFUNCTYPE(None, ct.POINTER(xt_fcheck_call))), + ("x6_options", ct.POINTER(xt_option_entry)), + + # size of per-extension instance extra "global" scratch space + ("udata_size", ct.c_size_t), + + # ignore these men behind the curtain: + ("udata", ct.c_void_p), + ("option_offset", ct.c_uint), + ("m", ct.POINTER(xt_entry_match)), + ("mflags", ct.c_uint), + ("loaded", ct.c_uint)] + + +_xtables_match_v11 = _xtables_match_v10 + + +class _xtables_match_v12(ct.Structure): + _fields_ = [("version", ct.c_char_p), + ("next", ct.c_void_p), + ("name", ct.c_char_p), + ("real_name", ct.c_char_p), + ("revision", ct.c_uint8), + ("ext_flags", ct.c_uint8), + ("family", ct.c_uint16), + ("size", ct.c_size_t), + ("userspacesize", ct.c_size_t), + ("help", ct.CFUNCTYPE(None)), + ("init", ct.CFUNCTYPE(None, ct.POINTER(xt_entry_match))), + # fourth parameter entry is struct ipt_entry for example + # int (*parse)(int c, char **argv, int invert, unsigned int + # *flags, const void *entry, struct xt_entry_match **match) + ("parse", ct.CFUNCTYPE(ct.c_int, ct.c_int, + ct.POINTER(ct.c_char_p), ct.c_int, + ct.POINTER(ct.c_uint), ct.c_void_p, + ct.POINTER(ct.POINTER( + xt_entry_match)))), + ("final_check", ct.CFUNCTYPE(None, ct.c_uint)), + # prints out the match iff non-NULL: put space at end + # first parameter ip is struct ipt_ip * for example + ("print", ct.CFUNCTYPE(None, ct.c_void_p, + ct.POINTER(xt_entry_match), ct.c_int)), + # saves the match info in parsable form to stdout. + # first parameter ip is struct ipt_ip * for example + ("save", ct.CFUNCTYPE(None, ct.c_void_p, + ct.POINTER(xt_entry_match))), + # Print match name or alias + ("alias", ct.CFUNCTYPE(ct.c_char_p, ct.POINTER(xt_entry_match))), # pointer to list of extra command-line options ("extra_opts", ct.POINTER(option)), @@ -396,6 +454,8 @@ class _xtables_match_v10(ct.Structure): ("x6_fcheck", ct.CFUNCTYPE(None, ct.POINTER(xt_fcheck_call))), ("x6_options", ct.POINTER(xt_option_entry)), + ('xt_xlate', ct.c_int), + # size of per-extension instance extra "global" scratch space ("udata_size", ct.c_size_t), @@ -417,7 +477,9 @@ class xtables_match(ct.Union): ("v7", _xtables_match_v7), # Apparently v8 was skipped ("v9", _xtables_match_v9), - ("v10", _xtables_match_v10)] + ("v10", _xtables_match_v10), + ("v11", _xtables_match_v11), + ("v12", _xtables_match_v12)] class _xtables_target_v1(ct.Structure): @@ -636,7 +698,7 @@ class _xtables_target_v10(ct.Structure): ("save", ct.CFUNCTYPE(None, ct.c_void_p, ct.POINTER(xt_entry_target))), # Print target name or alias - ("alias", ct.CFUNCTYPE(ct.c_char_p, ct.c_void_p, + ("alias", ct.CFUNCTYPE(ct.c_char_p, ct.POINTER(xt_entry_target))), # pointer to list of extra command-line options ("extra_opts", ct.POINTER(option)), @@ -658,6 +720,63 @@ class _xtables_target_v10(ct.Structure): ("loaded", ct.c_uint)] +_xtables_target_v11 = _xtables_target_v10 + +class _xtables_target_v12(ct.Structure): + _fields_ = [("version", ct.c_char_p), + ("next", ct.c_void_p), + ("name", ct.c_char_p), + ("real_name", ct.c_char_p), + ("revision", ct.c_uint8), + ("ext_flags", ct.c_uint8), + ("family", ct.c_uint16), + ("size", ct.c_size_t), + ("userspacesize", ct.c_size_t), + ("help", ct.CFUNCTYPE(None)), + ("init", ct.CFUNCTYPE(None, ct.POINTER(xt_entry_target))), + # fourth parameter entry is struct ipt_entry for example + # int (*parse)(int c, char **argv, int invert, + # unsigned int *flags, const void *entry, + # struct xt_entry_target **target) + ("parse", ct.CFUNCTYPE(ct.c_int, + ct.POINTER(ct.c_char_p), ct.c_int, + ct.POINTER(ct.c_uint), ct.c_void_p, + ct.POINTER(ct.POINTER( + xt_entry_target)))), + ("final_check", ct.CFUNCTYPE(None, ct.c_uint)), + # prints out the target iff non-NULL: put space at end + # first parameter ip is struct ipt_ip * for example + ("print", ct.CFUNCTYPE(None, ct.c_void_p, + ct.POINTER(xt_entry_target), ct.c_int)), + # saves the target info in parsable form to stdout. + # first parameter ip is struct ipt_ip * for example + ("save", ct.CFUNCTYPE(None, ct.c_void_p, + ct.POINTER(xt_entry_target))), + # Print target name or alias + ("alias", ct.CFUNCTYPE(ct.c_char_p, + ct.POINTER(xt_entry_target))), + # pointer to list of extra command-line options + ("extra_opts", ct.POINTER(option)), + + # introduced with the new iptables API + ("x6_parse", ct.CFUNCTYPE(None, ct.POINTER(xt_option_call))), + ("x6_fcheck", ct.CFUNCTYPE(None, ct.POINTER(xt_fcheck_call))), + ("x6_options", ct.POINTER(xt_option_entry)), + + ('xt_xlate', ct.c_int), + + # size of per-extension instance extra "global" scratch space + ("udata_size", ct.c_size_t), + + # ignore these men behind the curtain: + ("udata", ct.c_void_p), + ("option_offset", ct.c_uint), + ("t", ct.POINTER(xt_entry_target)), + ("tflags", ct.c_uint), + ("used", ct.c_uint), + ("loaded", ct.c_uint)] + + class xtables_target(ct.Union): _fields_ = [("v1", _xtables_target_v1), ("v2", _xtables_target_v2), @@ -668,26 +787,40 @@ class xtables_target(ct.Union): ("v7", _xtables_target_v7), # Apparently v8 was skipped ("v9", _xtables_target_v9), - ("v10", _xtables_target_v10)] - - -class XTablesError(Exception): - """Raised when an xtables call fails for some reason.""" + ("v10", _xtables_target_v10), + ("v11", _xtables_target_v11), + ("v12", _xtables_target_v12)] -_libc, _ = find_library("c") +_libc = find_libc() _optind = ct.c_long.in_dll(_libc, "optind") _optarg = ct.c_char_p.in_dll(_libc, "optarg") -_lib_xtables, _xtables_version = find_library("xtables") +xtables_version = os.getenv("PYTHON_IPTABLES_XTABLES_VERSION") +if xtables_version: + _searchlib = "libxtables.so.%s" % (xtables_version,) +else: + _searchlib = "xtables" +_lib_xtables, xtables_version = find_library(_searchlib) _xtables_libdir = os.getenv("XTABLES_LIBDIR") if _xtables_libdir is None: - import os.path - for xtdir in ["/lib/xtables", "/usr/lib/xtables", - "/usr/local/lib/xtables"]: - if os.path.isdir(xtdir): - _xtables_libdir = xtdir - break + import re + ldconfig_path_regex = re.compile('^(/.*):($| \(.*\)$)') + import subprocess + ldconfig = subprocess.Popen( + ('/sbin/ldconfig', '-N', '-v'), + stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True + ) + ldconfig_out, ldconfig_err = ldconfig.communicate() + if ldconfig.returncode != 0: + raise XTablesError("ldconfig failed, please set XTABLES_LIBDIR") + for ldconfig_out_line in ldconfig_out.splitlines(): + ldconfig_path_regex_match = ldconfig_path_regex.match(ldconfig_out_line) + if ldconfig_path_regex_match is not None: + ldconfig_path = os.path.join(ldconfig_path_regex_match.group(1), 'xtables') + if os.path.isdir(ldconfig_path): + _xtables_libdir = ldconfig_path + break if _xtables_libdir is None: raise XTablesError("can't find directory with extensions; " "please set XTABLES_LIBDIR") @@ -729,20 +862,31 @@ def _xt_exit(status, *args): _xt_exit = _EXIT_FN(_xt_exit) -def preserve_globals(fn): +def set_nfproto(fn): def new(*args): - obj = args[0] - obj._restore_globals() - try: - ret = fn(*args) - except Exception: - obj._save_globals() - raise - obj._save_globals() - return ret + xtobj = args[0] + xtables._xtables_set_nfproto(xtobj.proto) + return fn(*args) return new +_xt_globals = xtables_globals() +_xt_globals.option_offset = 0 +_xt_globals.program_name = version.__pkgname__.encode() +_xt_globals.program_version = version.__version__.encode() +_xt_globals.orig_opts = None +_xt_globals.opts = None +_xt_globals.exit_err = _xt_exit + +if xtables_version > 10: + _COMPAT_REV_FN = ct.CFUNCTYPE(ct.c_int, ct.c_char_p, ct.c_uint8, ct.c_int) + _xt_compat_rev = _COMPAT_REV_FN(_lib_xtables.xtables_compatible_revision) + _xt_globals.compat_rev = _xt_compat_rev + + +_loaded_exts = {} + + class xtables(object): _xtables_init_all = _lib_xtables.xtables_init_all _xtables_init_all.restype = ct.c_int @@ -763,17 +907,22 @@ class xtables(object): _xtables_xt_params = ct.c_void_p.in_dll(_lib_xtables, "xt_params") _xtables_matches = (ct.c_void_p.in_dll(_lib_xtables, "xtables_matches")) try: - _xtables_pending_matches = (ct.c_void_p.in_dll(_lib_xtables, - "xtables_pending_matches")) + _xtables_pending_matches = (ct.c_void_p.in_dll( + _lib_xtables, "xtables_pending_matches")) except ValueError: _xtables_pending_matches = ct.POINTER(None) _xtables_targets = (ct.c_void_p.in_dll(_lib_xtables, "xtables_targets")) try: - _xtables_pending_targets = (ct.c_void_p.in_dll(_lib_xtables, - "xtables_pending_targets")) + _xtables_pending_targets = (ct.c_void_p.in_dll( + _lib_xtables, "xtables_pending_targets")) except ValueError: _xtables_pending_targets = ct.POINTER(None) + _real_name = { + 'state': 'conntrack', + 'NOTRACK': 'CT' + } + _cache = weakref.WeakValueDictionary() def __new__(cls, proto): @@ -784,81 +933,34 @@ def __new__(cls, proto): obj._xtinit(proto) return obj - def _xtinit(self, proto): + def _xtinit(self, proto, no_alias_check=False): self.proto = proto - self._xt_globals = xtables_globals() - self._xt_globals.option_offset = 0 - self._xt_globals.program_name = version.__pkgname__ - self._xt_globals.program_version = version.__version__ - self._xt_globals.orig_opts = None - self._xt_globals.opts = None - self._xt_globals.exit_err = _xt_exit + self.no_alias_check = no_alias_check thismodule = sys.modules[__name__] - matchname = "_xtables_match_v%d" % (_xtables_version) - targetname = "_xtables_target_v%d" % (_xtables_version) + matchname = "_xtables_match_v%d" % (xtables_version) + targetname = "_xtables_target_v%d" % (xtables_version) try: self._match_struct = getattr(thismodule, matchname) self._target_struct = getattr(thismodule, targetname) except: raise XTablesError("unknown xtables version %d" % - (_xtables_version)) + (xtables_version)) - self._loaded_exts = [] - - # make sure we're initializing with clean state - self._xt_params = ct.c_void_p(None).value - self._matches = ct.c_void_p(None).value - self._pending_matches = ct.c_void_p(None).value - self._targets = ct.c_void_p(None).value - self._pending_targets = ct.c_void_p(None).value - - rv = xtables._xtables_init_all(ct.pointer(self._xt_globals), proto) + rv = xtables._xtables_init_all(ct.pointer(_xt_globals), proto) if rv: raise XTablesError("xtables_init_all() failed: %d" % (rv)) - self._save_globals() def __repr__(self): return "XTables for protocol %d" % (self.proto) - def _save_globals(self): - # Save our per-protocol libxtables global variables, and set them to - # NULL so that we don't interfere with other protocols. - null = ct.c_void_p(None) - self._xt_params = xtables._xtables_xt_params.value - xtables._xtables_xt_params.value = null.value - self._matches = xtables._xtables_matches.value - xtables._xtables_matches.value = null.value - self._pending_matches = xtables._xtables_pending_matches.value - xtables._xtables_pending_matches.value = null.value - self._targets = xtables._xtables_targets.value - xtables._xtables_targets.value = null.value - self._pending_targets = xtables._xtables_pending_targets.value - xtables._xtables_pending_targets.value = null.value - - def _restore_globals(self): - # Restore per-protocol libxtables global variables saved in - # _save_globals(). - xtables._xtables_set_nfproto(self.proto) - xtables._xtables_xt_params.value = self._xt_params - xtables._xtables_matches.value = self._matches - xtables._xtables_pending_matches.value = self._pending_matches - xtables._xtables_targets.value = self._targets - xtables._xtables_pending_targets.value = self._pending_targets - def _check_extname(self, name): - if name in ["", "ACCEPT", "DROP", "QUEUE", "RETURN"]: - name = "standard" + if name in [b"", b"ACCEPT", b"DROP", b"QUEUE", b"RETURN"]: + name = b"standard" return name - def _loaded(self, name): - self._loaded_exts.append(name) - - def _is_loaded(self, name): - if name in self._loaded_exts: - return True - else: - return False + def _loaded(self, name, ext): + _loaded_exts['%s___%s' % (self.proto, name)] = ext def _get_initfn_from_lib(self, name, lib): try: @@ -866,6 +968,10 @@ def _get_initfn_from_lib(self, name, lib): except AttributeError: prefix = self._get_prefix() initfn = getattr(lib, "%s%s_init" % (prefix, name), None) + if initfn is None and not self.no_alias_check: + if name in xtables._real_name: + name = xtables._real_name[name] + initfn = self._get_initfn_from_lib(name, lib) return initfn def _try_extinit(self, name, lib): @@ -889,6 +995,8 @@ def _get_prefix(self): raise XTablesError("Unknown protocol %d" % (self.proto)) def _try_register(self, name): + if isinstance(name, bytes): + name = name.decode() if self._try_extinit(name, _lib_xtables): return prefix = self._get_prefix() @@ -898,33 +1006,53 @@ def _try_register(self, name): if self._try_extinit(name, lib): return - @preserve_globals + def _get_loaded_ext(self, name): + ext = _loaded_exts.get('%s___%s' % (self.proto, name), None) + return ext + + @set_nfproto def find_match(self, name): + if isinstance(name, str): + name = name.encode() name = self._check_extname(name) + + ext = self._get_loaded_ext(name) + if ext is not None: + return ext + match = xtables._xtables_find_match(name, XTF_TRY_LOAD, None) if not match: self._try_register(name) - match = xtables._xtables_find_match(name, XTF_TRY_LOAD, None) + match = xtables._xtables_find_match(name, XTF_DONT_LOAD, None) if not match: return match - self._loaded(name) - return ct.cast(match, ct.POINTER(self._match_struct)) + m = ct.cast(match, ct.POINTER(self._match_struct)) + self._loaded(m[0].name, m) + return m - @preserve_globals + @set_nfproto def find_target(self, name): + if isinstance(name, str): + name = name.encode() name = self._check_extname(name) + + ext = self._get_loaded_ext(name) + if ext is not None: + return ext + target = xtables._xtables_find_target(name, XTF_TRY_LOAD) if not target: self._try_register(name) - target = xtables._xtables_find_target(name, XTF_TRY_LOAD) + target = xtables._xtables_find_target(name, XTF_DONT_LOAD) if not target: return target - self._loaded(name) - return ct.cast(target, ct.POINTER(self._target_struct)) + t = ct.cast(target, ct.POINTER(self._target_struct)) + self._loaded(t[0].name, t) + return t - @preserve_globals + @set_nfproto def save(self, module, ip, ptr): _wrap_save(module.save, ct.cast(ct.pointer(ip), ct.c_void_p), ptr) @@ -950,23 +1078,23 @@ def _parse(self, module, argv, inv, flags, entry, ptr): # Dispatch arguments to the appropriate parse function, based upon the # extension's choice of API. - @preserve_globals - def parse_target(self, argv, invert, t, fw, ptr): - _optarg.value = argv[1] - _optind.value = 2 + @set_nfproto + def parse_target(self, argv, invert, t, fw, ptr, x6_parse, x6_options): + _optarg.value = len(argv) > 1 and argv[1] or None + _optind.value = len(argv) - 1 - x6_options = None - x6_parse = None try: # new API? - x6_options = t.x6_options - x6_parse = t.x6_parse + if x6_options is None: + x6_options = t.x6_options + if x6_parse is None: + x6_parse = t.x6_parse except AttributeError: pass if x6_options and x6_parse: # new API - entry = self._option_lookup(t.x6_options, argv[0]) + entry = self._option_lookup(x6_options, argv[0]) if not entry: raise XTablesError("%s: no such parameter %s" % (t.name, argv[0])) @@ -981,7 +1109,7 @@ def parse_target(self, argv, invert, t, fw, ptr): cb.target = ct.pointer(t.t) cb.xt_entry = ct.cast(fw, ct.c_void_p) cb.udata = t.udata - rv = _wrap_x6fn(t.x6_parse, ct.pointer(cb)) + rv = _wrap_x6fn(x6_parse, ct.pointer(cb)) if rv != 0: raise XTablesError("%s: parameter error %d (%s)" % (t.name, rv, argv[1])) @@ -995,23 +1123,23 @@ def parse_target(self, argv, invert, t, fw, ptr): # Dispatch arguments to the appropriate parse function, based upon the # extension's choice of API. - @preserve_globals - def parse_match(self, argv, invert, m, fw, ptr): - _optarg.value = argv[1] - _optind.value = 2 + @set_nfproto + def parse_match(self, argv, invert, m, fw, ptr, x6_parse, x6_options): + _optarg.value = len(argv) > 1 and argv[1] or None + _optind.value = len(argv) - 1 - x6_options = None - x6_parse = None try: # new API? - x6_options = m.x6_options - x6_parse = m.x6_parse + if x6_options is None: + x6_options = m.x6_options + if x6_parse is None: + x6_parse = m.x6_parse except AttributeError: pass if x6_options and x6_parse: # new API - entry = self._option_lookup(m.x6_options, argv[0]) + entry = self._option_lookup(x6_options, argv[0]) if not entry: raise XTablesError("%s: no such parameter %s" % (m.name, argv[0])) @@ -1026,10 +1154,10 @@ def parse_match(self, argv, invert, m, fw, ptr): cb.match = ct.pointer(m.m) cb.xt_entry = ct.cast(fw, ct.c_void_p) cb.udata = m.udata - rv = _wrap_x6fn(m.x6_parse, ct.pointer(cb)) + rv = _wrap_x6fn(x6_parse, ct.pointer(cb)) if rv != 0: - raise XTablesError("%s: parameter error %d (%s)" % (m.name, rv, - argv[1])) + raise XTablesError("%s: parameter '%s' error %d" % ( + m.name, len(argv) > 1 and argv[1] or "", rv)) m.mflags |= cb.xflags return @@ -1076,7 +1204,7 @@ def _fcheck_target_new(self, target): # Dispatch arguments to the appropriate final_check function, based upon # the extension's choice of API. - @preserve_globals + @set_nfproto def final_check_target(self, target): x6_fcheck = None try: @@ -1116,7 +1244,7 @@ def _fcheck_match_new(self, match): # Dispatch arguments to the appropriate final_check function, based upon # the extension's choice of API. - @preserve_globals + @set_nfproto def final_check_match(self, match): x6_fcheck = None try: diff --git a/setup.py b/setup.py index e1caaaa..1929f93 100644 --- a/setup.py +++ b/setup.py @@ -2,31 +2,45 @@ """python-iptables setup script""" -from distutils.core import setup, Extension +from setuptools import setup, Extension # make pyflakes happy __pkgname__ = None __version__ = None -execfile("iptc/version.py") +exec(open("iptc/version.py").read()) # build/install python-iptables setup( name=__pkgname__, version=__version__, description="Python bindings for iptables", - author="Nilvec", - author_email="nilvec@nilvec.com", - url="http://nilvec.com/", + long_description="Python bindings for classic iptables", + long_description_content_type="text/x-rst", + author="Vilmos Nebehaj", + author_email="v.nebehaj@gmail.com", + url="https://github.com/ldx/python-iptables", packages=["iptc"], package_dir={"iptc": "iptc"}, ext_modules=[Extension("libxtwrapper", ["libxtwrapper/wrapper.c"])], + test_suite="tests", classifiers=[ - "Development Status :: 4 - Beta", + "Development Status :: 5 - Production/Stable", + "Environment :: Console", "Intended Audience :: Developers", - "License :: OSI Approved :: Apache License, Version 2.0", + "Intended Audience :: Information Technology", + "Intended Audience :: System Administrators", + "Intended Audience :: Telecommunications Industry", + "License :: OSI Approved :: Apache Software License", "Natural Language :: English", - "Topic :: Networking", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python", + "Topic :: Software Development :: Libraries", + "Topic :: System :: Networking :: Firewalls", + "Topic :: System :: Systems Administration", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: Implementation :: CPython", ], license="Apache License, Version 2.0", ) diff --git a/test.py b/test.py deleted file mode 100755 index f426612..0000000 --- a/test.py +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import sys - -print "WARNING: this test will manipulate iptables rules." -print "Don't do this on a production machine." -while True: - print "Would you like to continue? y/n", - answer = raw_input() - if answer in "yYnN" and len(answer) == 1: - break -if answer in "nN": - sys.exit(0) - -from iptc.test import test_iptc, test_matches, test_targets - -results = [rv for rv in [test_iptc.run_tests(), test_matches.run_tests(), - test_targets.run_tests()]] -for res in results: - if res: - sys.exit(1) diff --git a/iptc/test/__init__.py b/tests/__init__.py similarity index 100% rename from iptc/test/__init__.py rename to tests/__init__.py diff --git a/iptc/test/test_iptc.py b/tests/test_iptc.py similarity index 71% rename from iptc/test/test_iptc.py rename to tests/test_iptc.py index 7441fc4..816df86 100755 --- a/iptc/test/test_iptc.py +++ b/tests/test_iptc.py @@ -18,10 +18,10 @@ def _check_chains(testcase, *chains): class TestTable6(unittest.TestCase): def setUp(self): - pass + self.autocommit = iptc.Table(iptc.Table.FILTER).autocommit def tearDown(self): - pass + iptc.Table(iptc.Table.FILTER, self.autocommit) def test_table6(self): filt = None @@ -65,10 +65,11 @@ class TestTable(unittest.TestCase): def setUp(self): self.chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), "iptc_test_chain") + iptc.Table(iptc.Table.FILTER).create_chain(self.chain) def tearDown(self): - iptc.Table(iptc.Table.FILTER).delete_chain(self.chain) + iptc.Table(iptc.Table.FILTER).flush() def test_table(self): filt = None @@ -107,6 +108,48 @@ def test_refresh(self): self.chain.insert_rule(rule) self.chain.delete_rule(rule) + def test_flush_user_chains(self): + + chain1 = iptc.Chain(iptc.Table(iptc.Table.FILTER), + "iptc_test_flush_chain1") + chain2 = iptc.Chain(iptc.Table(iptc.Table.FILTER), + "iptc_test_flush_chain2") + iptc.Table(iptc.Table.FILTER).create_chain(chain1) + iptc.Table(iptc.Table.FILTER).create_chain(chain2) + + rule = iptc.Rule() + rule.target = iptc.Target(rule, chain2.name) + chain1.append_rule(rule) + + rule = iptc.Rule() + rule.target = iptc.Target(rule, chain1.name) + chain2.append_rule(rule) + + self.assertEquals(len(chain1.rules), 1) + self.assertEquals(len(chain2.rules), 1) + + filter_table = iptc.Table(iptc.Table.FILTER) + filter_table.flush() + + self.assertTrue(not filter_table.is_chain(chain1.name)) + self.assertTrue(not filter_table.is_chain(chain2.name)) + + def test_flush_builtin(self): + filter_table = iptc.Table(iptc.Table.FILTER) + output_rule_count = len(iptc.Chain(filter_table, "OUTPUT").rules) + + rule = iptc.Rule() + rule.target = iptc.Target(rule, "ACCEPT") + + iptc.Chain(filter_table, "OUTPUT").append_rule(rule) + + self.assertEquals(len(iptc.Chain(filter_table, "OUTPUT").rules), + output_rule_count + 1) + + filter_table.flush() + + self.assertEquals(len(iptc.Chain(filter_table, "OUTPUT").rules), 0) + class TestChain(unittest.TestCase): def setUp(self): @@ -246,7 +289,7 @@ def test_chain_counters(self): for chain in (chain for table in tables for chain in table.chains): counters = chain.get_counters() fails = 0 - for x in xrange(3): # try 3 times + for x in range(3): # try 3 times chain.zero_counters() counters = chain.get_counters() if counters: # only built-in chains @@ -331,6 +374,21 @@ def tearDown(self): self.chain.flush() self.chain.delete() + def test_create_mask(self): + rule = iptc.Rule6() + + # Mask /10 should return \xff\xc0\x00... + mask = rule._create_mask(10) + self.assertEquals(mask[0], 0xff) + self.assertEquals(mask[1], 0xc0) + self.assertEquals(mask[2:], [0x00]*14) + + # Mask /27 should return \xff\xff\xff\xe0... + mask = rule._create_mask(27) + self.assertEquals(mask[:3], [0xff, 0xff, 0xff]) + self.assertEquals(mask[3], 0xe0) + self.assertEquals(mask[4:], [0x00]*12) + def test_rule_address(self): # valid addresses rule = iptc.Rule6() @@ -365,7 +423,9 @@ def test_rule_address(self): def test_rule_interface(self): # valid interfaces rule = iptc.Rule6() - for intf in ["eth0", "eth+", "ip6tnl1", "ip6tnl+", "!ppp0", "!ppp+"]: + + max_length_valid_interface_name = "0123456789abcde" + for intf in ["eth0", "eth+", "ip6tnl1", "ip6tnl+", "!ppp0", "!ppp+", max_length_valid_interface_name]: rule.in_interface = intf self.assertEquals(intf, rule.in_interface) rule.out_interface = intf @@ -389,7 +449,7 @@ def test_rule_interface(self): def test_rule_protocol(self): rule = iptc.Rule6() for proto in ["tcp", "udp", "icmp", "AH", "ESP", "!TCP", "!UDP", - "!ICMP", "!ah", "!esp"]: + "!ICMP", "!ah", "!esp", "sctp", "!SCTP"]: rule.protocol = proto self.assertEquals(proto.lower(), rule.protocol) for proto in ["", "asdf", "!"]: @@ -402,6 +462,13 @@ def test_rule_protocol(self): else: self.fail("rule accepted invalid protocol %s" % (proto)) + def test_rule_protocol_numeric(self): + rule = iptc.Rule6() + rule.protocol = 33 + self.assertEquals(rule.protocol, '33') + rule.protocol = '!33' + self.assertEquals(rule.protocol, '!33') + def test_rule_compare(self): r1 = iptc.Rule6() r1.src = "::1/128" @@ -509,25 +576,65 @@ def test_rule_insert(self): self.failUnless(rule in crules) crules.remove(rule) + def test_rule_to_dict(self): + rule = iptc.Rule6() + rule.protocol = "tcp" + rule.src = "::1/128" + target = iptc.Target(rule, "ACCEPT") + rule.target = target + rule_d = iptc.easy.decode_iptc_rule(rule, ipv6=True) + # Remove counters when comparing rules + rule_d.pop('counters', None) + self.assertEqual(rule_d, {"protocol": "tcp", "src": "::1/128", "target": "ACCEPT"}) + + def test_rule_from_dict(self): + rule = iptc.Rule6() + rule.protocol = "tcp" + rule.src = "::1/128" + target = iptc.Target(rule, "ACCEPT") + rule.target = target + rule2 = iptc.easy.encode_iptc_rule({"protocol": "tcp", "src": "::1/128", "target": "ACCEPT"}, ipv6=True) + self.assertEqual(rule, rule2) class TestRule(unittest.TestCase): def setUp(self): - self.chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), - "iptc_test_chain") - iptc.Table(iptc.Table.FILTER).create_chain(self.chain) + self.table = iptc.Table(iptc.Table.FILTER) + self.chain = iptc.Chain(self.table, "iptc_test_chain") + try: + self.table.create_chain(self.chain) + except: + self.chain.flush() + if is_table_available(iptc.Table.NAT): + self.table_nat = iptc.Table(iptc.Table.NAT) + self.chain_nat = iptc.Chain(self.table_nat, "iptc_test_nat_chain") + try: + self.table_nat.create_chain(self.chain_nat) + except: + self.chain_nat.flush() def tearDown(self): + self.table.autocommit = True self.chain.flush() self.chain.delete() + if is_table_available(iptc.Table.NAT): + self.table_nat.autocommit = True + self.chain_nat.flush() + self.chain_nat.delete() def test_rule_address(self): # valid addresses rule = iptc.Rule() - for addr in ["127.0.0.1/255.255.255.0", "!127.0.0.1/255.255.255.0"]: - rule.src = addr - self.assertEquals(rule.src, addr) - rule.dst = addr - self.assertEquals(rule.dst, addr) + for addr in [("127.0.0.1/255.255.255.0", "127.0.0.0/255.255.255.0"), + ("!127.0.0.1/255.255.255.0", "!127.0.0.0/255.255.255.0"), + ("127.0.0.1/255.255.128.0", "127.0.0.0/255.255.128.0"), + ("127.0.0.1/16", "127.0.0.0/255.255.0.0"), + ("127.0.0.1/24", "127.0.0.0/255.255.255.0"), + ("127.0.0.1/17", "127.0.0.0/255.255.128.0"), + ("!127.0.0.1/17", "!127.0.0.0/255.255.128.0")]: + rule.src = addr[0] + self.assertEquals(rule.src, addr[1]) + rule.dst = addr[0] + self.assertEquals(rule.dst, addr[1]) addr = "127.0.0.1" rule.src = addr self.assertEquals("127.0.0.1/255.255.255.255", rule.src) @@ -536,7 +643,8 @@ def test_rule_address(self): # invalid addresses for addr in ["127.256.0.1/255.255.255.0", "127.0.1/255.255.255.0", - "127.0.0.1/255.255.255.", "127.0.0.1 255.255.255.0"]: + "127.0.0.1/255.255.255.", "127.0.0.1 255.255.255.0", + "127.0.0.1/33", "127.0.0.1/-5", "127.0.0.1/255.5"]: try: rule.src = addr except ValueError: @@ -558,6 +666,12 @@ def test_rule_interface(self): self.assertEquals(intf, rule.in_interface) rule.out_interface = intf self.assertEquals(intf, rule.out_interface) + rule.create_target("ACCEPT") + self.chain.insert_rule(rule) + r = self.chain.rules[0] + eq = r == rule + self.chain.flush() + self.assertTrue(eq) # invalid interfaces for intf in ["itsaverylonginterfacename"]: @@ -597,6 +711,13 @@ def test_rule_protocol(self): else: self.fail("rule accepted invalid protocol %s" % (proto)) + def test_rule_protocol_numeric(self): + rule = iptc.Rule() + rule.protocol = 33 + self.assertEquals(rule.protocol, '33') + rule.protocol = '!33' + self.assertEquals(rule.protocol, '!33') + def test_rule_compare(self): r1 = iptc.Rule() r1.src = "127.0.0.2/255.255.255.0" @@ -665,6 +786,37 @@ def test_rule_iterate_mangle(self): for rule in chain.rules if rule): pass + def test_rule_iterate_rulenum(self): + """Ensure rule numbers are always returned in order""" + insert_rule_count = 3 + append_rule_count = 3 + for rule_num in range(insert_rule_count, 0, -1): + rule = iptc.Rule() + match = rule.create_match("comment") + match.comment = "rule{rule_num}".format(rule_num=rule_num) + rule.create_target("ACCEPT") + self.chain.insert_rule(rule) + + append_rulenum_start = insert_rule_count + 1 + append_rulenum_end = append_rulenum_start + 3 + for rule_num in range(append_rulenum_start, append_rulenum_end): + rule = iptc.Rule() + match = rule.create_match("comment") + match.comment = "rule{rule_num}".format(rule_num=rule_num) + rule.create_target("ACCEPT") + self.chain.append_rule(rule) + + rules = self.chain.rules + assert len(rules) == (insert_rule_count + append_rule_count) + for rule_num, rule in enumerate(rules, start=1): + assert len(rule.matches) == 1 + assert rule.matches[0].comment == "rule{rule_num}".format( + rule_num=rule_num), \ + "rule[{left_num}] is not new {right_num}".format( + left_num=rule_num, + right_num=rule.matches[0].comment + ) + def test_rule_insert(self): rules = [] @@ -699,6 +851,110 @@ def test_rule_insert(self): self.failUnless(rule in crules) crules.remove(rule) + def test_rule_replace(self): + rule = iptc.Rule() + rule.protocol = "tcp" + rule.src = "127.0.0.1" + target = iptc.Target(rule, "ACCEPT") + rule.target = target + self.chain.insert_rule(rule, 0) + + rule = iptc.Rule() + rule.protocol = "udp" + rule.src = "127.0.0.1" + target = iptc.Target(rule, "ACCEPT") + rule.target = target + + self.chain.replace_rule(rule, 0) + self.failUnless(self.chain.rules[0] == rule) + + def test_rule_multiple_parameters(self): + self.table.autocommit = False + self.table.refresh() + rule = iptc.Rule() + rule.dst = "127.0.0.1" + rule.protocol = "tcp" + match = rule.create_match('tcp') + match.sport = "1234" + match.dport = "8080" + target = rule.create_target("REJECT") + target.reject_with = "icmp-host-unreachable" + self.chain.insert_rule(rule) + self.table.commit() + self.table.refresh() + self.assertEquals(len(self.chain.rules), 1) + r = self.chain.rules[0] + self.assertEquals(r.src, '0.0.0.0/0.0.0.0') + self.assertEquals(r.dst, '127.0.0.1/255.255.255.255') + self.assertEquals(r.protocol, 'tcp') + self.assertEquals(len(r.matches), 1) + m = r.matches[0] + self.assertEquals(m.name, 'tcp') + self.assertEquals(m.sport, '1234') + self.assertEquals(m.dport, '8080') + + def test_rule_delete(self): + self.table.autocommit = False + self.table.refresh() + for p in ['8001', '8002', '8003']: + rule = iptc.Rule() + rule.dst = "127.0.0.1" + rule.protocol = "tcp" + rule.dport = "8080" + target = rule.create_target("REJECT") + target.reject_with = "icmp-host-unreachable" + self.chain.insert_rule(rule) + self.table.commit() + self.table.refresh() + + rules = self.chain.rules + for rule in rules: + self.chain.delete_rule(rule) + self.table.commit() + self.table.refresh() + + def test_rule_delete_nat(self): + if not is_table_available(iptc.Table.NAT): + return + + self.table_nat.autocommit = False + self.table_nat.refresh() + for p in ['8001', '8002', '8003']: + rule = iptc.Rule() + rule.dst = "127.0.0.1" + rule.protocol = "udp" + rule.dport = "8080" + target = rule.create_target("DNAT") + target.to_destination = '127.0.0.0:' + p + self.chain_nat.insert_rule(rule) + self.table_nat.commit() + self.table_nat.refresh() + + rules = self.chain_nat.rules + for rule in rules: + self.chain_nat.delete_rule(rule) + self.table_nat.commit() + self.table_nat.refresh() + + def test_rule_to_dict(self): + rule = iptc.Rule() + rule.protocol = "tcp" + rule.src = "127.0.0.1/32" + target = iptc.Target(rule, "ACCEPT") + rule.target = target + rule_d = iptc.easy.decode_iptc_rule(rule) + # Remove counters when comparing rules + rule_d.pop('counters', None) + self.assertEqual(rule_d, {"protocol": "tcp", "src": "127.0.0.1/32", "target": "ACCEPT"}) + + def test_rule_from_dict(self): + rule = iptc.Rule() + rule.protocol = "tcp" + rule.src = "127.0.0.1/32" + target = iptc.Target(rule, "ACCEPT") + rule.target = target + rule2 = iptc.easy.encode_iptc_rule({"protocol": "tcp", "src": "127.0.0.1/32", "target": "ACCEPT"}) + self.assertEqual(rule, rule2) def suite(): suite_table6 = unittest.TestLoader().loadTestsFromTestCase(TestTable6) diff --git a/tests/test_matches.py b/tests/test_matches.py new file mode 100755 index 0000000..d7c1995 --- /dev/null +++ b/tests/test_matches.py @@ -0,0 +1,542 @@ +# -*- coding: utf-8 -*- + +import unittest +import iptc + + +is_table6_available = iptc.is_table6_available + + +class TestMatch(unittest.TestCase): + def setUp(self): + pass + + def tearDown(self): + pass + + def test_match_create(self): + rule = iptc.Rule() + match = rule.create_match("udp") + + for m in rule.matches: + self.assertEqual(m, match) + + # check that we can change match parameters after creation + match.sport = "12345:55555" + match.dport = "!33333" + + m = iptc.Match(iptc.Rule(), "udp") + m.sport = "12345:55555" + m.dport = "!33333" + + self.assertEqual(m, match) + + def test_match_compare(self): + m1 = iptc.Match(iptc.Rule(), "udp") + m1.sport = "12345:55555" + m1.dport = "!33333" + + m2 = iptc.Match(iptc.Rule(), "udp") + m2.sport = "12345:55555" + m2.dport = "!33333" + + self.assertEqual(m1, m2) + + m2.reset() + m2.sport = "12345:55555" + m2.dport = "33333" + self.assertNotEqual(m1, m2) + + def test_match_parameters(self): + m = iptc.Match(iptc.Rule(), "udp") + m.sport = "12345:55555" + m.dport = "!33333" + + self.assertEqual(len(m.parameters), 2) + + for p in m.parameters: + self.assertTrue(p == "sport" or p == "dport") + + self.assertEqual(m.parameters["sport"], "12345:55555") + self.assertEqual(m.parameters["dport"], "!33333") + + m.reset() + self.assertEqual(len(m.parameters), 0) + + def test_get_all_parameters(self): + m = iptc.Match(iptc.Rule(), "udp") + m.sport = "12345:55555" + m.dport = "!33333" + + params = m.get_all_parameters() + self.assertEqual(set(params['sport']), set(['12345:55555'])) + self.assertEqual(set(params['dport']), set(['!', '33333'])) + + +class TestMultiportMatch(unittest.TestCase): + def setUp(self): + self.rule = iptc.Rule() + self.rule.src = "127.0.0.1" + self.rule.protocol = "udp" + self.rule.create_target("ACCEPT") + + self.match = self.rule.create_match("multiport") + + table = iptc.Table(iptc.Table.FILTER) + self.chain = iptc.Chain(table, "iptc_test_udp") + try: + self.chain.flush() + self.chain.delete() + except: + pass + + iptc.Table(iptc.Table.FILTER).create_chain(self.chain) + + def tearDown(self): + for r in self.chain.rules: + self.chain.delete_rule(r) + self.chain.flush() + self.chain.delete() + + def test_multiport(self): + self.match.dports = '1111,2222' + self.chain.insert_rule(self.rule) + rule = self.chain.rules[0] + match = rule.matches[0] + self.assertEqual(match.dports, '1111,2222') + + def test_unicode_multiport(self): + self.match.dports = u'1111,2222' + self.chain.insert_rule(self.rule) + rule = self.chain.rules[0] + match = rule.matches[0] + self.assertEqual(match.dports, '1111,2222') + + +class TestXTUdpMatch(unittest.TestCase): + def setUp(self): + self.rule = iptc.Rule() + self.rule.src = "127.0.0.1" + self.rule.protocol = "udp" + self.rule.target = iptc.Target(self.rule, "ACCEPT") + + self.match = iptc.Match(self.rule, "udp") + self.chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), "iptc_test_udp") + iptc.Table(iptc.Table.FILTER).create_chain(self.chain) + + def tearDown(self): + for r in self.chain.rules: + self.chain.delete_rule(r) + self.chain.flush() + self.chain.delete() + + def test_udp_port(self): + for port in ["12345", "12345:65535", "!12345", "12345:12346", + "!12345:12346", "0:1234", "! 1234", "!0:12345", + "!1234:65535"]: + self.match.sport = port + self.assertEqual(self.match.sport, port.replace(" ", "")) + self.match.dport = port + self.assertEqual(self.match.dport, port.replace(" ", "")) + self.match.reset() + for port in ["-1", "asdf", "!asdf"]: + try: + self.match.sport = port + except Exception: + pass + else: + self.fail("udp accepted invalid source port %s" % (port)) + try: + self.match.dport = port + except Exception: + pass + else: + self.fail("udp accepted invalid destination port %s" % (port)) + self.match.reset() + + def test_udp_insert(self): + self.match.reset() + self.match.dport = "12345" + self.rule.add_match(self.match) + + self.chain.insert_rule(self.rule) + + for r in self.chain.rules: + if r != self.rule: + self.fail("inserted rule does not match original") + + +class TestXTMarkMatch(unittest.TestCase): + def setUp(self): + self.rule = iptc.Rule() + self.rule.src = "127.0.0.1" + self.rule.protocol = "tcp" + self.rule.target = iptc.Target(self.rule, "ACCEPT") + + self.match = iptc.Match(self.rule, "mark") + + self.chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), + "iptc_test_mark") + iptc.Table(iptc.Table.FILTER).create_chain(self.chain) + + def tearDown(self): + for r in self.chain.rules: + self.chain.delete_rule(r) + self.chain.flush() + self.chain.delete() + + def test_mark(self): + for mark in ["0x7b", "! 0x7b", "0x7b/0xfffefffe", "!0x7b/0xff00ff00"]: + self.match.mark = mark + self.assertEqual(self.match.mark, mark.replace(" ", "")) + self.match.reset() + for mark in ["0xffffffffff", "123/0xffffffff1", "!asdf", "1234:1233"]: + try: + self.match.mark = mark + except Exception: + pass + else: + self.fail("mark accepted invalid value %s" % (mark)) + self.match.reset() + + def test_mark_insert(self): + self.match.reset() + self.match.mark = "0x123" + self.rule.add_match(self.match) + + self.chain.insert_rule(self.rule) + + for r in self.chain.rules: + if r != self.rule: + self.fail("inserted rule does not match original") + + +class TestXTLimitMatch(unittest.TestCase): + def setUp(self): + self.rule = iptc.Rule() + self.rule.src = "127.0.0.1" + self.rule.protocol = "tcp" + self.rule.target = iptc.Target(self.rule, "ACCEPT") + + self.match = iptc.Match(self.rule, "limit") + self.chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), + "iptc_test_limit") + iptc.Table(iptc.Table.FILTER).create_chain(self.chain) + + def tearDown(self): + for r in self.chain.rules: + self.chain.delete_rule(r) + self.chain.flush() + self.chain.delete() + + def test_limit(self): + for limit in ["1/sec", "5/min", "3/hour"]: + self.match.limit = limit + self.assertEqual(self.match.limit, limit) + self.match.reset() + for limit in ["asdf", "123/1", "!1", "!1/second"]: + try: + self.match.limit = limit + except Exception: + pass + else: + self.fail("limit accepted invalid value %s" % (limit)) + self.match.reset() + + def test_limit_insert(self): + self.match.reset() + self.match.limit = "1/min" + self.rule.add_match(self.match) + + self.chain.insert_rule(self.rule) + + for r in self.chain.rules: + if r != self.rule: + self.fail("inserted rule does not match original") + + +class TestIcmpv6Match(unittest.TestCase): + def setUp(self): + self.rule = iptc.Rule6() + self.rule.protocol = "icmpv6" + self.rule.in_interface = "eth0" + + self.target = self.rule.create_target("ACCEPT") + + self.match = self.rule.create_match("icmp6") + self.match.icmpv6_type = "echo-request" + + self.table = iptc.Table6(iptc.Table6.FILTER) + + self.chain = iptc.Chain(self.table, "ip6tc_test_icmpv6") + try: + self.table.delete_chain(self.chain) + except: + pass + self.table.create_chain(self.chain) + + def tearDown(self): + for r in self.chain.rules: + self.chain.delete_rule(r) + self.chain.flush() + self.chain.delete() + + def test_icmpv6(self): + self.chain.insert_rule(self.rule) + rule = self.chain.rules[0] + self.assertEqual(self.rule, rule) + + +class TestCommentMatch(unittest.TestCase): + def setUp(self): + self.rule = iptc.Rule() + self.rule.src = "127.0.0.1" + self.rule.protocol = "udp" + self.rule.target = iptc.Target(self.rule, "ACCEPT") + + self.match = iptc.Match(self.rule, "comment") + self.chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), + "iptc_test_comment") + iptc.Table(iptc.Table.FILTER).create_chain(self.chain) + + def tearDown(self): + for r in self.chain.rules: + self.chain.delete_rule(r) + self.chain.flush() + self.chain.delete() + + def test_comment(self): + comment = "comment test" + self.match.reset() + self.match.comment = comment + self.chain.insert_rule(self.rule) + self.assertEqual(self.match.comment, comment) + + +class TestIprangeMatch(unittest.TestCase): + def setUp(self): + self.rule = iptc.Rule() + self.rule.protocol = "tcp" + self.rule.target = iptc.Target(self.rule, "ACCEPT") + + self.match = iptc.Match(self.rule, "iprange") + + self.chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), + "iptc_test_iprange") + iptc.Table(iptc.Table.FILTER).create_chain(self.chain) + + def tearDown(self): + for r in self.chain.rules: + self.chain.delete_rule(r) + self.chain.flush() + self.chain.delete() + + def test_iprange(self): + self.match.src_range = "192.168.1.100-192.168.1.200" + self.match.dst_range = "172.22.33.106" + self.rule.add_match(self.match) + + self.chain.insert_rule(self.rule) + + for r in self.chain.rules: + if r != self.rule: + self.fail("inserted rule does not match original") + + def test_iprange_tcpdport(self): + self.match.src_range = "192.168.1.100-192.168.1.200" + self.match.dst_range = "172.22.33.106" + self.rule.add_match(self.match) + + match = iptc.Match(self.rule, "tcp") + match.dport = "22" + self.rule.add_match(match) + + self.chain.insert_rule(self.rule) + + for r in self.chain.rules: + if r != self.rule: + self.fail("inserted rule does not match original") + + +class TestXTStateMatch(unittest.TestCase): + def setUp(self): + self.rule = iptc.Rule() + self.rule.src = "127.0.0.1" + self.rule.protocol = "tcp" + self.rule.target = iptc.Target(self.rule, "ACCEPT") + + self.match = iptc.Match(self.rule, "state") + + self.chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), + "iptc_test_state") + self.table = iptc.Table(iptc.Table.FILTER) + try: + self.chain.flush() + self.chain.delete() + except: + pass + self.table.create_chain(self.chain) + + def tearDown(self): + for r in self.chain.rules: + self.chain.delete_rule(r) + self.chain.flush() + self.chain.delete() + + def test_state(self): + self.match.state = "RELATED,ESTABLISHED" + self.rule.add_match(self.match) + self.chain.insert_rule(self.rule) + rule = self.chain.rules[0] + m = rule.matches[0] + self.assertEqual(m.name, "state") + self.assertEqual(m.state, "RELATED,ESTABLISHED") + self.assertEqual(rule.matches[0].name, self.rule.matches[0].name) + self.assertEqual(rule, self.rule) + + +class TestXTConntrackMatch(unittest.TestCase): + def setUp(self): + self.rule = iptc.Rule() + self.rule.src = "127.0.0.1" + self.rule.protocol = "tcp" + self.rule.target = iptc.Target(self.rule, "ACCEPT") + + self.match = iptc.Match(self.rule, "conntrack") + + self.chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), + "iptc_test_conntrack") + self.table = iptc.Table(iptc.Table.FILTER) + try: + self.chain.flush() + self.chain.delete() + except: + pass + self.table.create_chain(self.chain) + + def tearDown(self): + for r in self.chain.rules: + self.chain.delete_rule(r) + self.chain.flush() + self.chain.delete() + + def test_state(self): + self.match.ctstate = "NEW,RELATED" + self.rule.add_match(self.match) + self.chain.insert_rule(self.rule) + rule = self.chain.rules[0] + m = rule.matches[0] + self.assertTrue(m.name, ["conntrack"]) + self.assertEqual(m.ctstate, "NEW,RELATED") + + +class TestHashlimitMatch(unittest.TestCase): + def setUp(self): + self.rule = iptc.Rule() + self.rule.src = "127.0.0.1" + self.rule.protocol = "udp" + self.rule.target = iptc.Target(self.rule, "ACCEPT") + + self.match = iptc.Match(self.rule, "hashlimit") + + self.chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), + "iptc_test_hashlimit") + self.table = iptc.Table(iptc.Table.FILTER) + try: + self.chain.flush() + self.chain.delete() + except: + pass + self.table.create_chain(self.chain) + + def tearDown(self): + for r in self.chain.rules: + self.chain.delete_rule(r) + self.chain.flush() + self.chain.delete() + + def test_hashlimit(self): + self.match.hashlimit_name = 'foo' + self.match.hashlimit_mode = 'srcip' + self.match.hashlimit_upto = '200/sec' + self.match.hashlimit = '200' + self.match.hashlimit_htable_expire = '100' + self.rule.add_match(self.match) + self.chain.insert_rule(self.rule) + rule = self.chain.rules[0] + m = rule.matches[0] + self.assertTrue(m.name, ["hashlimit"]) + self.assertEqual(m.hashlimit_name, "foo") + self.assertEqual(m.hashlimit_mode, "srcip") + self.assertEqual(m.hashlimit_upto, "200/sec") + self.assertEqual(m.hashlimit_burst, "5") + +class TestRecentMatch(unittest.TestCase): + def setUp(self): + self.table = 'filter' + self.chain = 'iptc_test_recent' + iptc.easy.delete_chain(self.table, self.chain, ipv6=False, flush=True, raise_exc=False) + iptc.easy.add_chain(self.table, self.chain, ipv6=False, raise_exc=True) + + def tearDown(self): + iptc.easy.delete_chain(self.table, self.chain, ipv6=False, flush=True, raise_exc=False) + + def test_recent(self): + rule_d = { + 'protocol': 'udp', + 'recent': { + 'mask': '255.255.255.255', + 'update': '', + 'seconds': '60', + 'rsource': '', + 'name': 'UDP-PORTSCAN', + }, + 'target': { + 'REJECT':{ + 'reject-with': 'icmp-port-unreachable' + } + } + } + iptc.easy.add_rule(self.table, self.chain, rule_d) + rule2_d = iptc.easy.get_rule(self.table, self.chain, -1) + # Remove counters when comparing rules + rule2_d.pop('counters', None) + self.assertEqual(rule_d, rule2_d) + +def suite(): + suite_match = unittest.TestLoader().loadTestsFromTestCase(TestMatch) + suite_udp = unittest.TestLoader().loadTestsFromTestCase(TestXTUdpMatch) + suite_mark = unittest.TestLoader().loadTestsFromTestCase(TestXTMarkMatch) + suite_limit = unittest.TestLoader().loadTestsFromTestCase(TestXTLimitMatch) + suite_mport = unittest.TestLoader().loadTestsFromTestCase( + TestMultiportMatch) + suite_comment = unittest.TestLoader().loadTestsFromTestCase( + TestCommentMatch) + suite_iprange = unittest.TestLoader().loadTestsFromTestCase( + TestIprangeMatch) + suite_state = unittest.TestLoader().loadTestsFromTestCase(TestXTStateMatch) + suite_conntrack = unittest.TestLoader().loadTestsFromTestCase( + TestXTConntrackMatch) + suite_hashlimit = unittest.TestLoader().loadTestsFromTestCase( + TestHashlimitMatch) + suite_recent = unittest.TestLoader().loadTestsFromTestCase( + TestRecentMatch) + extra_suites = [] + if is_table6_available(iptc.Table6.FILTER): + extra_suites += unittest.TestLoader().loadTestsFromTestCase( + TestIcmpv6Match) + + return unittest.TestSuite([suite_match, suite_udp, suite_mark, + suite_limit, suite_mport, suite_comment, + suite_iprange, suite_state, suite_conntrack, + suite_hashlimit, suite_recent] + extra_suites) + + +def run_tests(): + result = unittest.TextTestRunner(verbosity=2).run(suite()) + if result.errors or result.failures: + return 1 + return 0 + +if __name__ == "__main__": + unittest.main() diff --git a/iptc/test/test_targets.py b/tests/test_targets.py similarity index 78% rename from iptc/test/test_targets.py rename to tests/test_targets.py index 2ecaf61..3587acc 100755 --- a/iptc/test/test_targets.py +++ b/tests/test_targets.py @@ -2,6 +2,7 @@ import unittest import iptc +from iptc.xtables import xtables_version is_table_available = iptc.is_table_available @@ -78,6 +79,8 @@ def setUp(self): iptc.Table(iptc.Table.FILTER).create_chain(self.chain) def tearDown(self): + for r in self.chain.rules: + self.chain.delete_rule(r) self.chain.flush() self.chain.delete() @@ -132,6 +135,8 @@ def setUp(self): iptc.Table(iptc.Table.NAT).create_chain(self.chain) def tearDown(self): + for r in self.chain.rules: + self.chain.delete_rule(r) self.chain.flush() self.chain.delete() @@ -180,6 +185,8 @@ def setUp(self): iptc.Table(iptc.Table.MANGLE).create_chain(self.chain) def tearDown(self): + for r in self.chain.rules: + self.chain.delete_rule(r) self.chain.flush() self.chain.delete() @@ -261,6 +268,8 @@ def setUp(self): iptc.Table(iptc.Table.MANGLE).create_chain(self.chain) def tearDown(self): + for r in self.chain.rules: + self.chain.delete_rule(r) self.chain.flush() self.chain.delete() @@ -306,6 +315,8 @@ def setUp(self): iptc.Table(iptc.Table.NAT).create_chain(self.chain) def tearDown(self): + for r in self.chain.rules: + self.chain.delete_rule(r) self.chain.flush() self.chain.delete() @@ -342,22 +353,91 @@ def test_insert(self): self.fail("inserted rule does not match original") +class TestXTNotrackTarget(unittest.TestCase): + def setUp(self): + self.rule = iptc.Rule() + self.rule.dst = "127.0.0.2" + self.rule.protocol = "tcp" + self.rule.out_interface = "eth0" + + self.target = iptc.Target(self.rule, "NOTRACK") + self.rule.target = self.target + + self.chain = iptc.Chain(iptc.Table(iptc.Table.RAW), + "iptc_test_notrack") + try: + self.chain.flush() + self.chain.delete() + except: + pass + iptc.Table(iptc.Table.RAW).create_chain(self.chain) + + def tearDown(self): + for r in self.chain.rules: + self.chain.delete_rule(r) + self.chain.flush() + self.chain.delete() + + def test_notrack(self): + self.chain.insert_rule(self.rule) + t = self.chain.rules[0].target + self.assertTrue(t.name in ["NOTRACK", "CT"]) + + +class TestXTCtTarget(unittest.TestCase): + def setUp(self): + self.rule = iptc.Rule() + self.rule.dst = "127.0.0.2" + self.rule.protocol = "tcp" + self.rule.out_interface = "eth0" + + self.target = iptc.Target(self.rule, "CT") + self.target.notrack = "true" + self.rule.target = self.target + + self.chain = iptc.Chain(iptc.Table(iptc.Table.RAW), + "iptc_test_ct") + try: + self.chain.flush() + self.chain.delete() + except: + pass + iptc.Table(iptc.Table.RAW).create_chain(self.chain) + + def tearDown(self): + for r in self.chain.rules: + self.chain.delete_rule(r) + self.chain.flush() + self.chain.delete() + + def test_ct(self): + self.chain.insert_rule(self.rule) + t = self.chain.rules[0].target + self.assertEquals(t.name, "CT") + self.assertTrue(t.notrack is not None) + + def suite(): + suites = [] suite_target = unittest.TestLoader().loadTestsFromTestCase(TestTarget) suite_tos = unittest.TestLoader().loadTestsFromTestCase(TestXTTosTarget) suite_cluster = unittest.TestLoader().loadTestsFromTestCase( TestXTClusteripTarget) + suite_redir = unittest.TestLoader().loadTestsFromTestCase( + TestIPTRedirectTarget) + suite_masq = unittest.TestLoader().loadTestsFromTestCase( + TestIPTMasqueradeTarget) + suite_dnat = unittest.TestLoader().loadTestsFromTestCase( + TestDnatTarget) + suite_notrack = unittest.TestLoader().loadTestsFromTestCase( + TestXTNotrackTarget) + suite_ct = unittest.TestLoader().loadTestsFromTestCase(TestXTCtTarget) + suites.extend([suite_target, suite_cluster, suite_tos]) if is_table_available(iptc.Table.NAT): - suite_redir = unittest.TestLoader().loadTestsFromTestCase( - TestIPTRedirectTarget) - suite_masq = unittest.TestLoader().loadTestsFromTestCase( - TestIPTMasqueradeTarget) - suite_dnat = unittest.TestLoader().loadTestsFromTestCase( - TestDnatTarget) - return unittest.TestSuite([suite_target, suite_cluster, suite_redir, - suite_tos, suite_masq, suite_dnat]) - else: - return unittest.TestSuite([suite_target, suite_cluster, suite_tos]) + suites.extend([suite_redir, suite_masq, suite_dnat]) + if is_table_available(iptc.Table.RAW) and xtables_version >= 10: + suites.extend([suite_notrack, suite_ct]) + return unittest.TestSuite(suites) def run_tests():