Move to pre-routing

This commit is contained in:
Darryl Nixon 2023-07-02 14:00:34 -07:00
parent 76c2b1bc07
commit 7b487db1b2
3 changed files with 23 additions and 21 deletions

View file

@ -46,7 +46,6 @@ class DrawBridge:
if packet.payload != original: if packet.payload != original:
packet.mangle() packet.mangle()
@staticmethod
def _delete_rules(self): def _delete_rules(self):
for queue in self.net_queues: for queue in self.net_queues:
try: try:
@ -60,6 +59,6 @@ class DrawBridge:
connection = fnfqueue.Connection() connection = fnfqueue.Connection()
listener = connection.bind(queue.queue) listener = connection.bind(queue.queue)
listener.set_mode(65535, fnfqueue.COPY_PACKET) listener.set_mode(65535, fnfqueue.COPY_PACKET)
task = asyncio.create_task(self._listen(listener, queue.callback)) task = asyncio.create_task(self._listen(connection, queue.callback))
tasks.append(task) tasks.append(task)
await asyncio.gather(*tasks) await asyncio.gather(*tasks)

View file

@ -8,14 +8,14 @@ from typing import Union
import iptc import iptc
from .utils.logger import logger from .utils.logger import logger
from .utils.lookup import Protocols from .utils.lookup import PROTOCOLS, TABLES
class NetQueue: class NetQueue:
def __init__( def __init__(
self, self,
callback: Callable,
queue: int, queue: int,
callback: Callable,
src_ip: Optional[str] = None, src_ip: Optional[str] = None,
dst_ip: Optional[str] = None, dst_ip: Optional[str] = None,
src_port: Optional[int] = None, src_port: Optional[int] = None,
@ -44,13 +44,13 @@ class NetQueue:
return rule return rule
def write_rule(self): def write_rule(self):
table = iptc.Table(iptc.Table.MANGLE) table = iptc.Table(iptc.Table.NAT)
chain = iptc.Chain(table, "INPUT") chain = iptc.Chain(table, "PREROUTING")
chain.insert_rule(self.rule) chain.insert_rule(self.rule)
def delete_rule(self): def delete_rule(self):
table = iptc.Table(iptc.Table.MANGLE) table = iptc.Table(iptc.Table.NAT)
chain = iptc.Chain(table, "INPUT") chain = iptc.Chain(table, "PREROUTING")
try: try:
chain.delete_rule(self.rule) chain.delete_rule(self.rule)
except iptc.ip4tc.IPTCError: except iptc.ip4tc.IPTCError:
@ -83,22 +83,22 @@ class NetQueue:
def validate_protocol(protocol: Optional[str]) -> Union[str, None]: def validate_protocol(protocol: Optional[str]) -> Union[str, None]:
if protocol: if protocol:
try: try:
Protocols[protocol] PROTOCOLS[protocol]
except KeyError: except KeyError:
raise KeyError(f"Invalid protocol: {protocol}") raise KeyError(f"Invalid protocol: {protocol}")
return protocol return protocol
@staticmethod @staticmethod
def _is_queue_taken(queue: int, override: bool) -> bool: def _is_queue_taken(queue: int, override: bool) -> bool:
table = iptc.Table(iptc.Table.FILTER) for table in TABLES:
for chain in table.chains: for chain in table.chains:
for rule in chain.rules: for rule in chain.rules:
if rule.target.name == "NFQUEUE" and rule.target.get_all_parameters()["queue-num"] == str(queue): if rule.target.name == "NFQUEUE" and rule.target.get_all_parameters()["queue-num"] == str(queue):
if override: if override:
logger.warning(f"Queue {queue} is already taken, clearing it") logger.warning(f"Queue {queue} is already taken, clearing it")
chain.delete_rule(rule) chain.delete_rule(rule)
return False return False
return True return True
return False return False
@staticmethod @staticmethod
@ -116,13 +116,13 @@ class NetQueue:
return ( return (
f"NetQueueFilter(" f"NetQueueFilter("
f"queue={self.queue}, " f"queue={self.queue}, "
f"callback={self.callback}, "
f"src_ip={self.src_ip}, " f"src_ip={self.src_ip}, "
f"dst_ip={self.dst_ip}, " f"dst_ip={self.dst_ip}, "
f"src_port={self.src_port}, " f"src_port={self.src_port}, "
f"dst_port={self.dst_port}, " f"dst_port={self.dst_port}, "
f"protocol={self.protocol}, " f"protocol={self.protocol}, "
f"callback={self.callback}, " f"rule={self.rule}, "
f"async_callback={self.async_callback}"
f")" f")"
) )

View file

@ -1,6 +1,7 @@
import socket import socket
import iptc
Protocols = { PROTOCOLS = {
"ah": socket.IPPROTO_AH, "ah": socket.IPPROTO_AH,
"dstopts": socket.IPPROTO_DSTOPTS, "dstopts": socket.IPPROTO_DSTOPTS,
"egp": socket.IPPROTO_EGP, "egp": socket.IPPROTO_EGP,
@ -26,3 +27,5 @@ Protocols = {
"tp": socket.IPPROTO_TP, "tp": socket.IPPROTO_TP,
"udp": socket.IPPROTO_UDP, "udp": socket.IPPROTO_UDP,
} }
TABLES = [iptc.Table(t) for t in iptc.Table.ALL]