mirror of
https://github.com/DarrylNixon/drawbridge
synced 2024-04-22 12:17:07 -07:00
Move to pre-routing
This commit is contained in:
parent
76c2b1bc07
commit
7b487db1b2
3 changed files with 23 additions and 21 deletions
|
@ -46,7 +46,6 @@ class DrawBridge:
|
|||
if packet.payload != original:
|
||||
packet.mangle()
|
||||
|
||||
@staticmethod
|
||||
def _delete_rules(self):
|
||||
for queue in self.net_queues:
|
||||
try:
|
||||
|
@ -60,6 +59,6 @@ class DrawBridge:
|
|||
connection = fnfqueue.Connection()
|
||||
listener = connection.bind(queue.queue)
|
||||
listener.set_mode(65535, fnfqueue.COPY_PACKET)
|
||||
task = asyncio.create_task(self._listen(listener, queue.callback))
|
||||
task = asyncio.create_task(self._listen(connection, queue.callback))
|
||||
tasks.append(task)
|
||||
await asyncio.gather(*tasks)
|
||||
|
|
|
@ -8,14 +8,14 @@ from typing import Union
|
|||
import iptc
|
||||
|
||||
from .utils.logger import logger
|
||||
from .utils.lookup import Protocols
|
||||
from .utils.lookup import PROTOCOLS, TABLES
|
||||
|
||||
|
||||
class NetQueue:
|
||||
def __init__(
|
||||
self,
|
||||
callback: Callable,
|
||||
queue: int,
|
||||
callback: Callable,
|
||||
src_ip: Optional[str] = None,
|
||||
dst_ip: Optional[str] = None,
|
||||
src_port: Optional[int] = None,
|
||||
|
@ -44,13 +44,13 @@ class NetQueue:
|
|||
return rule
|
||||
|
||||
def write_rule(self):
|
||||
table = iptc.Table(iptc.Table.MANGLE)
|
||||
chain = iptc.Chain(table, "INPUT")
|
||||
table = iptc.Table(iptc.Table.NAT)
|
||||
chain = iptc.Chain(table, "PREROUTING")
|
||||
chain.insert_rule(self.rule)
|
||||
|
||||
def delete_rule(self):
|
||||
table = iptc.Table(iptc.Table.MANGLE)
|
||||
chain = iptc.Chain(table, "INPUT")
|
||||
table = iptc.Table(iptc.Table.NAT)
|
||||
chain = iptc.Chain(table, "PREROUTING")
|
||||
try:
|
||||
chain.delete_rule(self.rule)
|
||||
except iptc.ip4tc.IPTCError:
|
||||
|
@ -83,14 +83,14 @@ class NetQueue:
|
|||
def validate_protocol(protocol: Optional[str]) -> Union[str, None]:
|
||||
if protocol:
|
||||
try:
|
||||
Protocols[protocol]
|
||||
PROTOCOLS[protocol]
|
||||
except KeyError:
|
||||
raise KeyError(f"Invalid protocol: {protocol}")
|
||||
return protocol
|
||||
|
||||
@staticmethod
|
||||
def _is_queue_taken(queue: int, override: bool) -> bool:
|
||||
table = iptc.Table(iptc.Table.FILTER)
|
||||
for table in TABLES:
|
||||
for chain in table.chains:
|
||||
for rule in chain.rules:
|
||||
if rule.target.name == "NFQUEUE" and rule.target.get_all_parameters()["queue-num"] == str(queue):
|
||||
|
@ -116,13 +116,13 @@ class NetQueue:
|
|||
return (
|
||||
f"NetQueueFilter("
|
||||
f"queue={self.queue}, "
|
||||
f"callback={self.callback}, "
|
||||
f"src_ip={self.src_ip}, "
|
||||
f"dst_ip={self.dst_ip}, "
|
||||
f"src_port={self.src_port}, "
|
||||
f"dst_port={self.dst_port}, "
|
||||
f"protocol={self.protocol}, "
|
||||
f"callback={self.callback}, "
|
||||
f"async_callback={self.async_callback}"
|
||||
f"rule={self.rule}, "
|
||||
f")"
|
||||
)
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import socket
|
||||
import iptc
|
||||
|
||||
Protocols = {
|
||||
PROTOCOLS = {
|
||||
"ah": socket.IPPROTO_AH,
|
||||
"dstopts": socket.IPPROTO_DSTOPTS,
|
||||
"egp": socket.IPPROTO_EGP,
|
||||
|
@ -26,3 +27,5 @@ Protocols = {
|
|||
"tp": socket.IPPROTO_TP,
|
||||
"udp": socket.IPPROTO_UDP,
|
||||
}
|
||||
|
||||
TABLES = [iptc.Table(t) for t in iptc.Table.ALL]
|
Loading…
Reference in a new issue