diff --git a/.flake8 b/.flake8
index dd0767d..42361f6 100644
--- a/.flake8
+++ b/.flake8
@@ -1,3 +1,5 @@
[flake8]
max-line-length = 160
exclude = docs/*, .git, __pycache__, build
+per-file-ignores =
+ __init__.py: F401
diff --git a/drawbridge/__init__.py b/drawbridge/__init__.py
index e69de29..c8b9f28 100644
--- a/drawbridge/__init__.py
+++ b/drawbridge/__init__.py
@@ -0,0 +1,37 @@
+import os
+
+import iptc
+
+
+def is_root():
+ return os.geteuid() == 0
+
+
+def check_nfqueue():
+ try:
+ import nfqueue
+
+ return True
+ except ImportError:
+ return False
+
+
+def check_iptables():
+ try:
+ iptc.Table(iptc.Table.FILTER)
+ return True
+ except iptc.ip4tc.IPTCError:
+ return False
+
+
+def check_requirements():
+ if not is_root():
+ raise RuntimeError("Must be run as root")
+ if not check_nfqueue():
+ raise RuntimeError("nfqueue is not installed or is not supported on your platform")
+ if not check_iptables():
+ raise RuntimeError("iptables not installed or is not supported on your platform")
+
+
+if __name__ == "__main__":
+ check_requirements()
diff --git a/drawbridge/drawbridge.py b/drawbridge/drawbridge.py
new file mode 100644
index 0000000..c68e585
--- /dev/null
+++ b/drawbridge/drawbridge.py
@@ -0,0 +1,65 @@
+import asyncio
+import atexit
+from typing import Callable
+from typing import Optional
+
+import fnfqueue
+
+from net_queue import NetQueue
+from utils.logger import logger
+
+
+class DrawBridge:
+ def __init__(self):
+ self.net_queues = []
+ atexit.register(self._delete_rules)
+
+ def add_queue(
+ self,
+ queue: int,
+ callback: Callable,
+ src_ip: Optional[str] = None,
+ dst_ip: Optional[str] = None,
+ src_port: Optional[int] = None,
+ dst_port: Optional[int] = None,
+ protocol: Optional[str] = "",
+ override: bool = False,
+ ):
+ try:
+ new_queue = NetQueue(queue, callback, src_ip, dst_ip, src_port, dst_port, protocol, override)
+ new_queue.write_rule()
+ except Exception as e:
+ logger.error(f"Failed to initialize NetQueue: {e}")
+ raise e
+ self.net_queues.append(new_queue)
+
+ def run(self):
+ asyncio.run(self.raise_bridges())
+
+ async def _listen(self, listener, callback: Callable) -> None:
+ for packet in listener:
+ original = packet.payload
+ if asyncio.iscoroutinefunction(callback):
+ packet.payload = await callback(packet.payload)
+ else:
+ packet.payload = callback(packet.payload)
+ if packet.payload != original:
+ packet.mangle()
+
+ @staticmethod
+ def _delete_rules(self):
+ for queue in self.net_queues:
+ try:
+ queue.delete_rule()
+ except Exception as e:
+ logger.error(f"Failed to delete rule: {e}")
+
+ async def raise_bridges(self):
+ tasks = []
+ for queue in self.net_queues:
+ connection = fnfqueue.Connection()
+ listener = connection.bind(queue.queue)
+ listener.set_mode(65535, fnfqueue.COPY_PACKET)
+ task = asyncio.create_task(self._listen(listener, queue.callback))
+ tasks.append(task)
+ await asyncio.gather(*tasks)
diff --git a/drawbridge/net_queue.py b/drawbridge/net_queue.py
new file mode 100644
index 0000000..605d7ae
--- /dev/null
+++ b/drawbridge/net_queue.py
@@ -0,0 +1,129 @@
+from ipaddress import AddressValueError
+from ipaddress import ip_address
+from ipaddress import IPv6Address
+from typing import Callable
+from typing import Optional
+from typing import Union
+
+import iptc
+
+from utils.logger import logger
+from utils.lookup import Protocols
+
+
+class NetQueue:
+ def __init__(
+ self,
+ callback: Callable,
+ queue: int,
+ src_ip: Optional[str] = None,
+ dst_ip: Optional[str] = None,
+ src_port: Optional[int] = None,
+ dst_port: Optional[int] = None,
+ protocol: Optional[str] = "",
+ override: bool = False,
+ ):
+ self.callback = self.validate_callable(callback)
+ self.src_port = self.validate_port(src_port, "source")
+ self.dst_port = self.validate_port(dst_port, "destination")
+ self.src_ip = self.validate_ip(src_ip, "source")
+ self.dst_ip = self.validate_ip(dst_ip, "destination")
+ self.protocol = self.validate_protocol(protocol)
+
+ self.queue = self.validate_queue(queue, override)
+ self.rule = self._create_rule()
+
+ def _create_rule(self) -> iptc.Rule:
+ rule = iptc.Rule()
+ match = iptc.Match(rule, self.protocol)
+ target = iptc.Target(rule, "NFQUEUE")
+ target.set_parameter("queue-num", str(self.queue))
+ rule.add_match(match)
+ rule.target = target
+ return rule
+
+ def write_rule(self):
+ table = iptc.Table(iptc.Table.FILTER)
+ chain = iptc.Chain(table, "INPUT")
+ chain.insert_rule(self.rule)
+
+ def delete_rule(self):
+ table = iptc.Table(iptc.Table.FILTER)
+ chain = iptc.Chain(table, "INPUT")
+ try:
+ chain.delete_rule(self.rule)
+ except iptc.ip4tc.IPTCError:
+ logger.warning("Failed to delete rule, it may have already been deleted")
+
+ @staticmethod
+ def validate_callable(callback: Callable) -> Callable:
+ if not callable(callback):
+ raise ValueError(f"Invalid callback: {callback}")
+ return callback
+
+ @staticmethod
+ def validate_ip(ip: Optional[str], description: str) -> Optional[str]:
+ if ip:
+ try:
+ if type(ip := ip_address(ip)) == IPv6Address:
+ raise NotImplementedError(f"IPv6 not supported: {ip}")
+ except (AddressValueError, ValueError):
+ raise AddressValueError(f"Invalid {description} IP address: {ip}")
+ return ip
+
+ @staticmethod
+ def validate_port(port: Optional[int], description: str) -> Union(int, None):
+ if port:
+ if not 0 <= port <= 65535:
+ raise ValueError(f"Invalid {description} port: {port}")
+ return port
+
+ @staticmethod
+ def validate_protocol(protocol: Optional[str]) -> Union(str, None):
+ if protocol:
+ try:
+ Protocols(protocol)
+ except KeyError:
+ raise KeyError(f"Invalid protocol: {protocol}")
+ return protocol
+
+ @staticmethod
+ def _is_queue_taken(queue: int, override: bool) -> bool:
+ table = iptc.Table(iptc.Table.FILTER)
+ for chain in table.chains:
+ for rule in chain.rules:
+ if rule.target.name == "NFQUEUE" and rule.target.get_all_parameters()["queue-num"] == str(queue):
+ if override:
+ logger.warning(f"Queue {queue} is already taken, clearing it")
+ chain.delete_rule(rule)
+ return False
+ return True
+ return False
+
+ @staticmethod
+ def _validate_queue(queue: int, override: bool) -> int:
+ if not 0 <= queue <= 65535:
+ raise ValueError(f"Invalid queue number: {queue}")
+
+ if NetQueue._is_queue_taken(queue, override):
+ logger.warning(f"Queue {queue} is already taken, raising error")
+ raise ValueError(f"Queue {queue} is already taken")
+
+ return queue
+
+ def __repr__(self):
+ return (
+ f"NetQueueFilter("
+ f"queue={self.queue}, "
+ f"src_ip={self.src_ip}, "
+ f"dst_ip={self.dst_ip}, "
+ f"src_port={self.src_port}, "
+ f"dst_port={self.dst_port}, "
+ f"protocol={self.protocol}, "
+ f"callback={self.callback}, "
+ f"async_callback={self.async_callback}"
+ f")"
+ )
+
+ def __str__(self):
+ return self.__repr__()
diff --git a/drawbridge/utils/__init__.py b/drawbridge/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/drawbridge/utils/logger.py b/drawbridge/utils/logger.py
new file mode 100644
index 0000000..b6006ad
--- /dev/null
+++ b/drawbridge/utils/logger.py
@@ -0,0 +1,35 @@
+from loguru import logger
+
+
+logger.add(
+ "app.log",
+ format="{time:YYYY-MM-DD HH:mm:ss} | {message}",
+ level="INFO",
+ rotation="1 day",
+ retention="30 days",
+)
+
+logger.add(
+ "errors.log",
+ format="âšī¸ {time:YYYY-MM-DD HH:mm:ss} | {message}",
+ level="WARNING",
+ rotation="1 day",
+ retention="30 days",
+)
+
+logger.add(
+ "error.log",
+ format="âī¸ {time:YYYY-MM-DD HH:mm:ss} | {message}",
+ level="ERROR",
+ rotation="1 day",
+ retention="30 days",
+)
+
+
+logger.add(
+ "critical.log",
+ format="đ¨ {time:YYYY-MM-DD HH:mm:ss} | {message}",
+ level="CRITICAL",
+ rotation="1 day",
+ retention="30 days",
+)
diff --git a/drawbridge/utils/lookup.py b/drawbridge/utils/lookup.py
new file mode 100644
index 0000000..77c4765
--- /dev/null
+++ b/drawbridge/utils/lookup.py
@@ -0,0 +1,28 @@
+import socket
+
+Protocols = {
+ "ah": socket.IPPROTO_AH,
+ "dstopts": socket.IPPROTO_DSTOPTS,
+ "egp": socket.IPPROTO_EGP,
+ "esp": socket.IPPROTO_ESP,
+ "fragment": socket.IPPROTO_FRAGMENT,
+ "gre": socket.IPPROTO_GRE,
+ "hopopts": socket.IPPROTO_HOPOPTS,
+ "icmp": socket.IPPROTO_ICMP,
+ "icmpv6": socket.IPPROTO_ICMPV6,
+ "idp": socket.IPPROTO_IDP,
+ "igmp": socket.IPPROTO_IGMP,
+ "ip": socket.IPPROTO_IP,
+ "ipip": socket.IPPROTO_IPIP,
+ "ipv6": socket.IPPROTO_IPV6,
+ "none": socket.IPPROTO_NONE,
+ "pim": socket.IPPROTO_PIM,
+ "pup": socket.IPPROTO_PUP,
+ "raw": socket.IPPROTO_RAW,
+ "routing": socket.IPPROTO_ROUTING,
+ "rsvp": socket.IPPROTO_RSVP,
+ "sctp": socket.IPPROTO_SCTP,
+ "tcp": socket.IPPROTO_TCP,
+ "tp": socket.IPPROTO_TP,
+ "udp": socket.IPPROTO_UDP,
+}