diff --git a/.flake8 b/.flake8 index dd0767d..42361f6 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,5 @@ [flake8] max-line-length = 160 exclude = docs/*, .git, __pycache__, build +per-file-ignores = + __init__.py: F401 diff --git a/drawbridge/__init__.py b/drawbridge/__init__.py index e69de29..c8b9f28 100644 --- a/drawbridge/__init__.py +++ b/drawbridge/__init__.py @@ -0,0 +1,37 @@ +import os + +import iptc + + +def is_root(): + return os.geteuid() == 0 + + +def check_nfqueue(): + try: + import nfqueue + + return True + except ImportError: + return False + + +def check_iptables(): + try: + iptc.Table(iptc.Table.FILTER) + return True + except iptc.ip4tc.IPTCError: + return False + + +def check_requirements(): + if not is_root(): + raise RuntimeError("Must be run as root") + if not check_nfqueue(): + raise RuntimeError("nfqueue is not installed or is not supported on your platform") + if not check_iptables(): + raise RuntimeError("iptables not installed or is not supported on your platform") + + +if __name__ == "__main__": + check_requirements() diff --git a/drawbridge/drawbridge.py b/drawbridge/drawbridge.py new file mode 100644 index 0000000..c68e585 --- /dev/null +++ b/drawbridge/drawbridge.py @@ -0,0 +1,65 @@ +import asyncio +import atexit +from typing import Callable +from typing import Optional + +import fnfqueue + +from net_queue import NetQueue +from utils.logger import logger + + +class DrawBridge: + def __init__(self): + self.net_queues = [] + atexit.register(self._delete_rules) + + def add_queue( + self, + queue: int, + callback: Callable, + src_ip: Optional[str] = None, + dst_ip: Optional[str] = None, + src_port: Optional[int] = None, + dst_port: Optional[int] = None, + protocol: Optional[str] = "", + override: bool = False, + ): + try: + new_queue = NetQueue(queue, callback, src_ip, dst_ip, src_port, dst_port, protocol, override) + new_queue.write_rule() + except Exception as e: + logger.error(f"Failed to initialize NetQueue: {e}") + raise e + self.net_queues.append(new_queue) + + def run(self): + asyncio.run(self.raise_bridges()) + + async def _listen(self, listener, callback: Callable) -> None: + for packet in listener: + original = packet.payload + if asyncio.iscoroutinefunction(callback): + packet.payload = await callback(packet.payload) + else: + packet.payload = callback(packet.payload) + if packet.payload != original: + packet.mangle() + + @staticmethod + def _delete_rules(self): + for queue in self.net_queues: + try: + queue.delete_rule() + except Exception as e: + logger.error(f"Failed to delete rule: {e}") + + async def raise_bridges(self): + tasks = [] + for queue in self.net_queues: + connection = fnfqueue.Connection() + listener = connection.bind(queue.queue) + listener.set_mode(65535, fnfqueue.COPY_PACKET) + task = asyncio.create_task(self._listen(listener, queue.callback)) + tasks.append(task) + await asyncio.gather(*tasks) diff --git a/drawbridge/net_queue.py b/drawbridge/net_queue.py new file mode 100644 index 0000000..605d7ae --- /dev/null +++ b/drawbridge/net_queue.py @@ -0,0 +1,129 @@ +from ipaddress import AddressValueError +from ipaddress import ip_address +from ipaddress import IPv6Address +from typing import Callable +from typing import Optional +from typing import Union + +import iptc + +from utils.logger import logger +from utils.lookup import Protocols + + +class NetQueue: + def __init__( + self, + callback: Callable, + queue: int, + src_ip: Optional[str] = None, + dst_ip: Optional[str] = None, + src_port: Optional[int] = None, + dst_port: Optional[int] = None, + protocol: Optional[str] = "", + override: bool = False, + ): + self.callback = self.validate_callable(callback) + self.src_port = self.validate_port(src_port, "source") + self.dst_port = self.validate_port(dst_port, "destination") + self.src_ip = self.validate_ip(src_ip, "source") + self.dst_ip = self.validate_ip(dst_ip, "destination") + self.protocol = self.validate_protocol(protocol) + + self.queue = self.validate_queue(queue, override) + self.rule = self._create_rule() + + def _create_rule(self) -> iptc.Rule: + rule = iptc.Rule() + match = iptc.Match(rule, self.protocol) + target = iptc.Target(rule, "NFQUEUE") + target.set_parameter("queue-num", str(self.queue)) + rule.add_match(match) + rule.target = target + return rule + + def write_rule(self): + table = iptc.Table(iptc.Table.FILTER) + chain = iptc.Chain(table, "INPUT") + chain.insert_rule(self.rule) + + def delete_rule(self): + table = iptc.Table(iptc.Table.FILTER) + chain = iptc.Chain(table, "INPUT") + try: + chain.delete_rule(self.rule) + except iptc.ip4tc.IPTCError: + logger.warning("Failed to delete rule, it may have already been deleted") + + @staticmethod + def validate_callable(callback: Callable) -> Callable: + if not callable(callback): + raise ValueError(f"Invalid callback: {callback}") + return callback + + @staticmethod + def validate_ip(ip: Optional[str], description: str) -> Optional[str]: + if ip: + try: + if type(ip := ip_address(ip)) == IPv6Address: + raise NotImplementedError(f"IPv6 not supported: {ip}") + except (AddressValueError, ValueError): + raise AddressValueError(f"Invalid {description} IP address: {ip}") + return ip + + @staticmethod + def validate_port(port: Optional[int], description: str) -> Union(int, None): + if port: + if not 0 <= port <= 65535: + raise ValueError(f"Invalid {description} port: {port}") + return port + + @staticmethod + def validate_protocol(protocol: Optional[str]) -> Union(str, None): + if protocol: + try: + Protocols(protocol) + except KeyError: + raise KeyError(f"Invalid protocol: {protocol}") + return protocol + + @staticmethod + def _is_queue_taken(queue: int, override: bool) -> bool: + table = iptc.Table(iptc.Table.FILTER) + for chain in table.chains: + for rule in chain.rules: + if rule.target.name == "NFQUEUE" and rule.target.get_all_parameters()["queue-num"] == str(queue): + if override: + logger.warning(f"Queue {queue} is already taken, clearing it") + chain.delete_rule(rule) + return False + return True + return False + + @staticmethod + def _validate_queue(queue: int, override: bool) -> int: + if not 0 <= queue <= 65535: + raise ValueError(f"Invalid queue number: {queue}") + + if NetQueue._is_queue_taken(queue, override): + logger.warning(f"Queue {queue} is already taken, raising error") + raise ValueError(f"Queue {queue} is already taken") + + return queue + + def __repr__(self): + return ( + f"NetQueueFilter(" + f"queue={self.queue}, " + f"src_ip={self.src_ip}, " + f"dst_ip={self.dst_ip}, " + f"src_port={self.src_port}, " + f"dst_port={self.dst_port}, " + f"protocol={self.protocol}, " + f"callback={self.callback}, " + f"async_callback={self.async_callback}" + f")" + ) + + def __str__(self): + return self.__repr__() diff --git a/drawbridge/utils/__init__.py b/drawbridge/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/drawbridge/utils/logger.py b/drawbridge/utils/logger.py new file mode 100644 index 0000000..b6006ad --- /dev/null +++ b/drawbridge/utils/logger.py @@ -0,0 +1,35 @@ +from loguru import logger + + +logger.add( + "app.log", + format="{time:YYYY-MM-DD HH:mm:ss} | {message}", + level="INFO", + rotation="1 day", + retention="30 days", +) + +logger.add( + "errors.log", + format="ℹī¸ {time:YYYY-MM-DD HH:mm:ss} | {message}", + level="WARNING", + rotation="1 day", + retention="30 days", +) + +logger.add( + "error.log", + format="⛔ī¸ {time:YYYY-MM-DD HH:mm:ss} | {message}", + level="ERROR", + rotation="1 day", + retention="30 days", +) + + +logger.add( + "critical.log", + format="🚨 {time:YYYY-MM-DD HH:mm:ss} | {message}", + level="CRITICAL", + rotation="1 day", + retention="30 days", +) diff --git a/drawbridge/utils/lookup.py b/drawbridge/utils/lookup.py new file mode 100644 index 0000000..77c4765 --- /dev/null +++ b/drawbridge/utils/lookup.py @@ -0,0 +1,28 @@ +import socket + +Protocols = { + "ah": socket.IPPROTO_AH, + "dstopts": socket.IPPROTO_DSTOPTS, + "egp": socket.IPPROTO_EGP, + "esp": socket.IPPROTO_ESP, + "fragment": socket.IPPROTO_FRAGMENT, + "gre": socket.IPPROTO_GRE, + "hopopts": socket.IPPROTO_HOPOPTS, + "icmp": socket.IPPROTO_ICMP, + "icmpv6": socket.IPPROTO_ICMPV6, + "idp": socket.IPPROTO_IDP, + "igmp": socket.IPPROTO_IGMP, + "ip": socket.IPPROTO_IP, + "ipip": socket.IPPROTO_IPIP, + "ipv6": socket.IPPROTO_IPV6, + "none": socket.IPPROTO_NONE, + "pim": socket.IPPROTO_PIM, + "pup": socket.IPPROTO_PUP, + "raw": socket.IPPROTO_RAW, + "routing": socket.IPPROTO_ROUTING, + "rsvp": socket.IPPROTO_RSVP, + "sctp": socket.IPPROTO_SCTP, + "tcp": socket.IPPROTO_TCP, + "tp": socket.IPPROTO_TP, + "udp": socket.IPPROTO_UDP, +}