mirror of
https://github.com/DarrylNixon/drawbridge
synced 2024-04-22 12:17:07 -07:00
Move to pre-routing
This commit is contained in:
parent
76c2b1bc07
commit
7b487db1b2
3 changed files with 23 additions and 21 deletions
|
@ -46,7 +46,6 @@ class DrawBridge:
|
||||||
if packet.payload != original:
|
if packet.payload != original:
|
||||||
packet.mangle()
|
packet.mangle()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _delete_rules(self):
|
def _delete_rules(self):
|
||||||
for queue in self.net_queues:
|
for queue in self.net_queues:
|
||||||
try:
|
try:
|
||||||
|
@ -60,6 +59,6 @@ class DrawBridge:
|
||||||
connection = fnfqueue.Connection()
|
connection = fnfqueue.Connection()
|
||||||
listener = connection.bind(queue.queue)
|
listener = connection.bind(queue.queue)
|
||||||
listener.set_mode(65535, fnfqueue.COPY_PACKET)
|
listener.set_mode(65535, fnfqueue.COPY_PACKET)
|
||||||
task = asyncio.create_task(self._listen(listener, queue.callback))
|
task = asyncio.create_task(self._listen(connection, queue.callback))
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
|
@ -8,14 +8,14 @@ from typing import Union
|
||||||
import iptc
|
import iptc
|
||||||
|
|
||||||
from .utils.logger import logger
|
from .utils.logger import logger
|
||||||
from .utils.lookup import Protocols
|
from .utils.lookup import PROTOCOLS, TABLES
|
||||||
|
|
||||||
|
|
||||||
class NetQueue:
|
class NetQueue:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
callback: Callable,
|
|
||||||
queue: int,
|
queue: int,
|
||||||
|
callback: Callable,
|
||||||
src_ip: Optional[str] = None,
|
src_ip: Optional[str] = None,
|
||||||
dst_ip: Optional[str] = None,
|
dst_ip: Optional[str] = None,
|
||||||
src_port: Optional[int] = None,
|
src_port: Optional[int] = None,
|
||||||
|
@ -44,13 +44,13 @@ class NetQueue:
|
||||||
return rule
|
return rule
|
||||||
|
|
||||||
def write_rule(self):
|
def write_rule(self):
|
||||||
table = iptc.Table(iptc.Table.MANGLE)
|
table = iptc.Table(iptc.Table.NAT)
|
||||||
chain = iptc.Chain(table, "INPUT")
|
chain = iptc.Chain(table, "PREROUTING")
|
||||||
chain.insert_rule(self.rule)
|
chain.insert_rule(self.rule)
|
||||||
|
|
||||||
def delete_rule(self):
|
def delete_rule(self):
|
||||||
table = iptc.Table(iptc.Table.MANGLE)
|
table = iptc.Table(iptc.Table.NAT)
|
||||||
chain = iptc.Chain(table, "INPUT")
|
chain = iptc.Chain(table, "PREROUTING")
|
||||||
try:
|
try:
|
||||||
chain.delete_rule(self.rule)
|
chain.delete_rule(self.rule)
|
||||||
except iptc.ip4tc.IPTCError:
|
except iptc.ip4tc.IPTCError:
|
||||||
|
@ -83,22 +83,22 @@ class NetQueue:
|
||||||
def validate_protocol(protocol: Optional[str]) -> Union[str, None]:
|
def validate_protocol(protocol: Optional[str]) -> Union[str, None]:
|
||||||
if protocol:
|
if protocol:
|
||||||
try:
|
try:
|
||||||
Protocols[protocol]
|
PROTOCOLS[protocol]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise KeyError(f"Invalid protocol: {protocol}")
|
raise KeyError(f"Invalid protocol: {protocol}")
|
||||||
return protocol
|
return protocol
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_queue_taken(queue: int, override: bool) -> bool:
|
def _is_queue_taken(queue: int, override: bool) -> bool:
|
||||||
table = iptc.Table(iptc.Table.FILTER)
|
for table in TABLES:
|
||||||
for chain in table.chains:
|
for chain in table.chains:
|
||||||
for rule in chain.rules:
|
for rule in chain.rules:
|
||||||
if rule.target.name == "NFQUEUE" and rule.target.get_all_parameters()["queue-num"] == str(queue):
|
if rule.target.name == "NFQUEUE" and rule.target.get_all_parameters()["queue-num"] == str(queue):
|
||||||
if override:
|
if override:
|
||||||
logger.warning(f"Queue {queue} is already taken, clearing it")
|
logger.warning(f"Queue {queue} is already taken, clearing it")
|
||||||
chain.delete_rule(rule)
|
chain.delete_rule(rule)
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -116,13 +116,13 @@ class NetQueue:
|
||||||
return (
|
return (
|
||||||
f"NetQueueFilter("
|
f"NetQueueFilter("
|
||||||
f"queue={self.queue}, "
|
f"queue={self.queue}, "
|
||||||
|
f"callback={self.callback}, "
|
||||||
f"src_ip={self.src_ip}, "
|
f"src_ip={self.src_ip}, "
|
||||||
f"dst_ip={self.dst_ip}, "
|
f"dst_ip={self.dst_ip}, "
|
||||||
f"src_port={self.src_port}, "
|
f"src_port={self.src_port}, "
|
||||||
f"dst_port={self.dst_port}, "
|
f"dst_port={self.dst_port}, "
|
||||||
f"protocol={self.protocol}, "
|
f"protocol={self.protocol}, "
|
||||||
f"callback={self.callback}, "
|
f"rule={self.rule}, "
|
||||||
f"async_callback={self.async_callback}"
|
|
||||||
f")"
|
f")"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import socket
|
import socket
|
||||||
|
import iptc
|
||||||
|
|
||||||
Protocols = {
|
PROTOCOLS = {
|
||||||
"ah": socket.IPPROTO_AH,
|
"ah": socket.IPPROTO_AH,
|
||||||
"dstopts": socket.IPPROTO_DSTOPTS,
|
"dstopts": socket.IPPROTO_DSTOPTS,
|
||||||
"egp": socket.IPPROTO_EGP,
|
"egp": socket.IPPROTO_EGP,
|
||||||
|
@ -26,3 +27,5 @@ Protocols = {
|
||||||
"tp": socket.IPPROTO_TP,
|
"tp": socket.IPPROTO_TP,
|
||||||
"udp": socket.IPPROTO_UDP,
|
"udp": socket.IPPROTO_UDP,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TABLES = [iptc.Table(t) for t in iptc.Table.ALL]
|
Loading…
Reference in a new issue