import asyncio
import atexit
from typing import Callable
from typing import Optional

import fnfqueue

from .net_queue import NetQueue
from .utils.logger import logger


class DrawBridge:
    """Route packets from netfilter queues through user-supplied callbacks.

    Each queue registered with :meth:`add_queue` is backed by a ``NetQueue``
    whose rule is written immediately. :meth:`run` then binds one
    ``fnfqueue`` connection per registered queue. It passes every packet's
    payload to that queue's callback and replaces the payload with the
    callback's return value. ``_delete_rules`` is registered with
    :mod:`atexit` so written rules are removed at interpreter exit.
    """

    def __init__(self):
        # NetQueue objects whose rules have been written successfully.
        self.net_queues = []
        # Registered up front so every rule written by add_queue() is
        # cleaned up on normal interpreter exit.
        atexit.register(self._delete_rules)

    def add_queue(
        self,
        callback: Callable,
        queue: int = 0,
        src_ip: Optional[str] = None,
        dst_ip: Optional[str] = None,
        src_port: Optional[int] = None,
        dst_port: Optional[int] = None,
        protocol: Optional[str] = "tcp",
        override: bool = False,
    ):
        """Create a ``NetQueue``, write its rule, and register it for :meth:`run`.

        Args:
            callback: Called with each packet's payload; its return value
                becomes the packet's new payload. May be a plain function or
                an ``async def`` coroutine function.
            queue: Netfilter queue number that :meth:`raise_bridges` binds.
            src_ip, dst_ip, src_port, dst_port, protocol: Match criteria
                forwarded to ``NetQueue`` (presumably used to build the rule;
                see net_queue).
            override: Forwarded to ``NetQueue``; semantics are defined there.

        Raises:
            Exception: Whatever ``NetQueue(...)`` or ``write_rule()`` raises.
                It is logged and re-raised, and the queue is not registered.
        """
        try:
            new_queue = NetQueue(
                callback, queue, src_ip, dst_ip, src_port, dst_port, protocol, override
            )
            new_queue.write_rule()
        except Exception as e:
            logger.error(f"Failed to initialize NetQueue: {e}")
            # Bare raise re-raises the active exception with its original
            # traceback instead of recording this line as a new raise point.
            raise
        self.net_queues.append(new_queue)

    def run(self):
        """Run :meth:`raise_bridges` on a fresh event loop, blocking until it returns."""
        asyncio.run(self.raise_bridges())

    async def _listen(self, connection, callback: Callable) -> None:
        """Pass every packet read from ``connection`` through ``callback``.

        Each packet's payload is replaced with the callback's result, and
        ``mangle()`` is then called on the packet. A
        ``fnfqueue.BufferOverflowException`` is logged and reading resumes,
        rather than abandoning the queue. The connection is closed whenever
        the listener exits, for any reason.
        """
        # Hoisted out of the per-packet loop: the callback cannot change.
        is_async = asyncio.iscoroutinefunction(callback)
        try:
            while True:
                try:
                    # NOTE(review): iterating the connection appears to block
                    # this thread while waiting for packets. That would stall
                    # the event loop and starve other queues' listeners.
                    # Confirm against fnfqueue before relying on several
                    # queues being serviced concurrently.
                    for packet in connection:
                        if is_async:
                            packet.payload = await callback(packet.payload)
                        else:
                            packet.payload = callback(packet.payload)
                        packet.mangle()
                except fnfqueue.BufferOverflowException:
                    # Transient overflow: log it and keep reading. Previously
                    # this ended the listener for good.
                    logger.warning("Packets arriving too quickly")
                    continue
                # Iteration finished normally; nothing more to read.
                break
        finally:
            connection.close()

    def _delete_rules(self):
        """Call ``delete_rule()`` on every registered queue (best effort).

        Failures are logged and do not stop the remaining deletions. This
        runs from :mod:`atexit`.
        """
        for queue in self.net_queues:
            try:
                queue.delete_rule()
            except Exception as e:
                logger.error(f"Failed to delete rule: {e}")

    async def raise_bridges(self):
        """Bind one fnfqueue connection per registered queue and listen on all of them.

        Returns when every listener task has finished. The first exception
        raised by a listener propagates.
        """
        tasks = []
        for queue in self.net_queues:
            connection = fnfqueue.Connection()
            try:
                listener = connection.bind(queue.queue)
                # Copy up to 65535 bytes of each packet to userspace so the
                # callback receives the payload.
                listener.set_mode(65535, fnfqueue.COPY_PACKET)
            except Exception:
                # Don't leak the socket if binding or configuring fails.
                connection.close()
                raise
            tasks.append(asyncio.create_task(self._listen(connection, queue.callback)))
        await asyncio.gather(*tasks)