Move to pre-routing

This commit is contained in:
Darryl Nixon 2023-07-02 14:00:34 -07:00
parent 76c2b1bc07
commit 7b487db1b2
3 changed files with 23 additions and 21 deletions

View file

@ -46,7 +46,6 @@ class DrawBridge:
if packet.payload != original:
packet.mangle()
@staticmethod
def _delete_rules(self):
for queue in self.net_queues:
try:
@ -60,6 +59,6 @@ class DrawBridge:
connection = fnfqueue.Connection()
listener = connection.bind(queue.queue)
listener.set_mode(65535, fnfqueue.COPY_PACKET)
task = asyncio.create_task(self._listen(listener, queue.callback))
task = asyncio.create_task(self._listen(connection, queue.callback))
tasks.append(task)
await asyncio.gather(*tasks)

View file

@ -8,14 +8,14 @@ from typing import Union
import iptc
from .utils.logger import logger
from .utils.lookup import Protocols
from .utils.lookup import PROTOCOLS, TABLES
class NetQueue:
def __init__(
self,
callback: Callable,
queue: int,
callback: Callable,
src_ip: Optional[str] = None,
dst_ip: Optional[str] = None,
src_port: Optional[int] = None,
@ -44,13 +44,13 @@ class NetQueue:
return rule
def write_rule(self):
table = iptc.Table(iptc.Table.MANGLE)
chain = iptc.Chain(table, "INPUT")
table = iptc.Table(iptc.Table.NAT)
chain = iptc.Chain(table, "PREROUTING")
chain.insert_rule(self.rule)
def delete_rule(self):
table = iptc.Table(iptc.Table.MANGLE)
chain = iptc.Chain(table, "INPUT")
table = iptc.Table(iptc.Table.NAT)
chain = iptc.Chain(table, "PREROUTING")
try:
chain.delete_rule(self.rule)
except iptc.ip4tc.IPTCError:
@ -83,14 +83,14 @@ class NetQueue:
def validate_protocol(protocol: Optional[str]) -> Union[str, None]:
if protocol:
try:
Protocols[protocol]
PROTOCOLS[protocol]
except KeyError:
raise KeyError(f"Invalid protocol: {protocol}")
return protocol
@staticmethod
def _is_queue_taken(queue: int, override: bool) -> bool:
table = iptc.Table(iptc.Table.FILTER)
for table in TABLES:
for chain in table.chains:
for rule in chain.rules:
if rule.target.name == "NFQUEUE" and rule.target.get_all_parameters()["queue-num"] == str(queue):
@ -116,13 +116,13 @@ class NetQueue:
return (
f"NetQueueFilter("
f"queue={self.queue}, "
f"callback={self.callback}, "
f"src_ip={self.src_ip}, "
f"dst_ip={self.dst_ip}, "
f"src_port={self.src_port}, "
f"dst_port={self.dst_port}, "
f"protocol={self.protocol}, "
f"callback={self.callback}, "
f"async_callback={self.async_callback}"
f"rule={self.rule}, "
f")"
)

View file

@ -1,6 +1,7 @@
import socket
import iptc
Protocols = {
PROTOCOLS = {
"ah": socket.IPPROTO_AH,
"dstopts": socket.IPPROTO_DSTOPTS,
"egp": socket.IPPROTO_EGP,
@ -26,3 +27,5 @@ Protocols = {
"tp": socket.IPPROTO_TP,
"udp": socket.IPPROTO_UDP,
}
TABLES = [iptc.Table(t) for t in iptc.Table.ALL]