move to git

This commit is contained in:
hackish 2023-12-04 14:09:05 -08:00
commit 0af8465d1b
16 changed files with 1436 additions and 0 deletions

40
drawbridge/__init__.py Normal file
View file

@ -0,0 +1,40 @@
import os
import iptc
from .drawbridge import DrawBridge
from .net_queue import NetQueue
def is_root():
return os.geteuid() == 0
def check_nfqueue():
try:
import nfqueue
return True
except ImportError:
return False
def check_iptables():
try:
iptc.Table(iptc.Table.FILTER)
return True
except iptc.ip4tc.IPTCError:
return False
def check_requirements():
if not is_root():
raise RuntimeError("Must be run as root")
if not check_nfqueue():
raise RuntimeError("nfqueue is not installed or is not supported on your platform")
if not check_iptables():
raise RuntimeError("iptables not installed or is not supported on your platform")
if __name__ == "__main__":
check_requirements()

65
drawbridge/drawbridge.py Normal file
View file

@ -0,0 +1,65 @@
import asyncio
import atexit
from typing import Callable
from typing import Optional
import fnfqueue
from .net_queue import NetQueue
from .utils.logger import logger
class DrawBridge:
def __init__(self):
self.net_queues = []
atexit.register(self._delete_rules)
def add_queue(
self,
callback: Callable,
queue: int = 0,
src_ip: Optional[str] = None,
dst_ip: Optional[str] = None,
src_port: Optional[int] = None,
dst_port: Optional[int] = None,
protocol: Optional[str] = "tcp",
override: bool = False,
):
try:
new_queue = NetQueue(callback, queue, src_ip, dst_ip, src_port, dst_port, protocol, override)
new_queue.write_rule()
except Exception as e:
logger.error(f"Failed to initialize NetQueue: {e}")
raise e
self.net_queues.append(new_queue)
def run(self):
asyncio.run(self.raise_bridges())
async def _listen(self, connection, callback: Callable) -> None:
try:
for packet in connection:
if asyncio.iscoroutinefunction(callback):
packet.payload = await callback(packet.payload)
else:
packet.payload = callback(packet.payload)
packet.mangle()
except fnfqueue.BufferOverflowException:
logger.warning("Packets arriving too quickly")
def _delete_rules(self):
for queue in self.net_queues:
try:
queue.delete_rule()
except Exception as e:
logger.error(f"Failed to delete rule: {e}")
async def raise_bridges(self):
tasks = []
for queue in self.net_queues:
connection = fnfqueue.Connection()
listener = connection.bind(queue.queue)
listener.set_mode(65535, fnfqueue.COPY_PACKET)
task = asyncio.create_task(self._listen(connection, queue.callback))
tasks.append(task)
await asyncio.gather(*tasks)

134
drawbridge/net_queue.py Normal file
View file

@ -0,0 +1,134 @@
from ipaddress import AddressValueError
from ipaddress import ip_address
from ipaddress import IPv6Address
from typing import Callable
from typing import Optional
from typing import Union
import iptc
from .utils.logger import logger
from .utils.lookup import PROTOCOLS, ALL_TABLES, OUTGOING_MANGLE
class NetQueue:
def __init__(
self,
callback: Callable,
queue: int,
src_ip: Optional[str] = None,
dst_ip: Optional[str] = None,
src_port: Optional[int] = None,
dst_port: Optional[int] = None,
protocol: Optional[str] = "tcp",
override: bool = False,
):
self.callback = self.validate_callable(callback)
self.src_port = self.validate_port(src_port, "source")
self.dst_port = self.validate_port(dst_port, "destination")
self.src_ip = self.validate_ip(src_ip, "source")
self.dst_ip = self.validate_ip(dst_ip, "destination")
self.protocol = self.validate_protocol(protocol)
self.queue = self._validate_queue(queue, override)
self.rule = self._create_rule()
def _create_rule(self) -> iptc.Rule:
rule = iptc.Rule()
target = iptc.Target(rule, "NFQUEUE")
target.set_parameter("queue-num", str(self.queue))
rule.protocol = self.protocol
match = rule.create_match(self.protocol)
if self.dst_port:
match.dport = str(self.dst_port)
if self.src_port:
match.sport = str(self.src_port)
match = iptc.Match(rule, "iprange")
if self.src_ip:
match.src_range = str(self.src_ip)
if self.dst_ip:
match.dst_range = str(self.dst_ip)
rule.target = target
return rule
def write_rule(self):
OUTGOING_MANGLE.insert_rule(self.rule)
def delete_rule(self):
try:
OUTGOING_MANGLE.delete_rule(self.rule)
except iptc.ip4tc.IPTCError:
logger.warning("Failed to delete rule, it may have already been deleted")
@staticmethod
def validate_callable(callback: Callable) -> Callable:
if not callable(callback):
raise ValueError(f"Invalid callback: {callback}")
return callback
@staticmethod
def validate_ip(ip: Optional[str], description: str) -> Optional[str]:
if ip:
try:
if type(ip := ip_address(ip)) == IPv6Address:
raise NotImplementedError(f"IPv6 not supported: {ip}")
except (AddressValueError, ValueError):
raise AddressValueError(f"Invalid {description} IP address: {ip}")
return ip
@staticmethod
def validate_port(port: Optional[int], description: str) -> Union[int, None]:
if port:
if not 0 <= port <= 65535:
raise ValueError(f"Invalid {description} port: {port}")
return port
@staticmethod
def validate_protocol(protocol: Optional[str]) -> Union[str, None]:
if protocol:
try:
PROTOCOLS[protocol]
except KeyError:
raise KeyError(f"Invalid protocol: {protocol}")
return protocol
@staticmethod
def _is_queue_taken(queue: int, override: bool) -> bool:
for table in ALL_TABLES:
for chain in table.chains:
for rule in chain.rules:
if rule.target.name == "NFQUEUE" and rule.target.get_all_parameters()["queue-num"] == str(queue):
if override:
logger.warning(f"Queue {queue} is already taken, clearing it")
chain.delete_rule(rule)
return False
return True
return False
@staticmethod
def _validate_queue(queue: int, override: bool) -> int:
if not 0 <= queue <= 65535:
raise ValueError(f"Invalid queue number: {queue}")
if NetQueue._is_queue_taken(queue, override):
logger.warning(f"Queue {queue} is already taken, raising error")
raise ValueError(f"Queue {queue} is already taken")
return queue
def __repr__(self):
return (
f"NetQueueFilter("
f"queue={self.queue}, "
f"callback={self.callback}, "
f"src_ip={self.src_ip}, "
f"dst_ip={self.dst_ip}, "
f"src_port={self.src_port}, "
f"dst_port={self.dst_port}, "
f"protocol={self.protocol}, "
f"rule={self.rule}, "
f")"
)
def __str__(self):
return self.__repr__()

View file

View file

@ -0,0 +1,35 @@
from loguru import logger
logger.add(
"app.log",
format="<level><light-blue>{time:YYYY-MM-DD HH:mm:ss} | {message}</light-blue></level>",
level="INFO",
rotation="1 day",
retention="30 days",
)
logger.add(
"errors.log",
format="<level><yellow> {time:YYYY-MM-DD HH:mm:ss} | {message}</yellow></level>",
level="WARNING",
rotation="1 day",
retention="30 days",
)
logger.add(
"error.log",
format="<level><red>⛔️ {time:YYYY-MM-DD HH:mm:ss} | {message}</red></level>",
level="ERROR",
rotation="1 day",
retention="30 days",
)
logger.add(
"critical.log",
format="<level><magenta>🚨 {time:YYYY-MM-DD HH:mm:ss} | {message}</magenta></level>",
level="CRITICAL",
rotation="1 day",
retention="30 days",
)

View file

@ -0,0 +1,33 @@
import socket
import iptc
PROTOCOLS = {
"ah": socket.IPPROTO_AH,
"dstopts": socket.IPPROTO_DSTOPTS,
"egp": socket.IPPROTO_EGP,
"esp": socket.IPPROTO_ESP,
"fragment": socket.IPPROTO_FRAGMENT,
"gre": socket.IPPROTO_GRE,
"hopopts": socket.IPPROTO_HOPOPTS,
"icmp": socket.IPPROTO_ICMP,
"icmpv6": socket.IPPROTO_ICMPV6,
"idp": socket.IPPROTO_IDP,
"igmp": socket.IPPROTO_IGMP,
"ip": socket.IPPROTO_IP,
"ipip": socket.IPPROTO_IPIP,
"ipv6": socket.IPPROTO_IPV6,
"none": socket.IPPROTO_NONE,
"pim": socket.IPPROTO_PIM,
"pup": socket.IPPROTO_PUP,
"raw": socket.IPPROTO_RAW,
"routing": socket.IPPROTO_ROUTING,
"rsvp": socket.IPPROTO_RSVP,
"sctp": socket.IPPROTO_SCTP,
"tcp": socket.IPPROTO_TCP,
"tp": socket.IPPROTO_TP,
"udp": socket.IPPROTO_UDP,
}
ALL_TABLES = [iptc.Table(t) for t in iptc.Table.ALL]
PREROUTING_MANGLE = iptc.Chain(iptc.Table(iptc.Table.MANGLE), "PREROUTING")
OUTGOING_MANGLE = iptc.Chain(iptc.Table(iptc.Table.FILTER), "OUTPUT")