66 lines
2 KiB
Python
66 lines
2 KiB
Python
|
import asyncio
import atexit
import threading
from typing import Callable
from typing import Optional

import fnfqueue

from .net_queue import NetQueue
from .utils.logger import logger
|
||
|
|
||
|
|
||
|
class DrawBridge:
    """Route netfilter queues to user callbacks that may rewrite packet payloads.

    Queues are registered with :meth:`add_queue` and serviced by :meth:`run`.
    The rule of every successfully registered queue is deleted at interpreter
    exit through an ``atexit`` hook.
    """

    def __init__(self):
        # NetQueue objects whose rules were written successfully, in registration order.
        self.net_queues = []
        atexit.register(self._delete_rules)

    def add_queue(
        self,
        callback: Callable,
        queue: int = 0,
        src_ip: Optional[str] = None,
        dst_ip: Optional[str] = None,
        src_port: Optional[int] = None,
        dst_port: Optional[int] = None,
        protocol: Optional[str] = "tcp",
        override: bool = False,
    ) -> None:
        """Create a NetQueue, write its rule and register it for :meth:`run`.

        Args:
            callback: Called with each packet's payload; its return value
                replaces the payload. May be a plain function or a coroutine
                function.
            queue: Queue number, forwarded to ``NetQueue``.
            src_ip, dst_ip, src_port, dst_port, protocol, override: Forwarded
                unchanged to ``NetQueue``.

        Raises:
            Exception: Whatever ``NetQueue(...)`` or ``write_rule()`` raises.
                It is logged and re-raised, and the queue is not registered.
        """
        try:
            new_queue = NetQueue(callback, queue, src_ip, dst_ip, src_port, dst_port, protocol, override)
            new_queue.write_rule()
        except Exception as e:
            logger.error(f"Failed to initialize NetQueue: {e}")
            raise  # bare raise re-raises the active exception as-is
        self.net_queues.append(new_queue)

    def run(self) -> None:
        """Block while servicing every registered queue (see :meth:`raise_bridges`)."""
        asyncio.run(self.raise_bridges())

    async def _listen(self, connection, callback: Callable) -> None:
        """Service one fnfqueue connection without blocking the event loop.

        Iterating an fnfqueue connection blocks while it waits for packets.
        Doing that directly inside a coroutine would stall the event loop and
        starve every other queue's listener. Receiving and verdicting
        therefore happen on a dedicated thread (:meth:`_pump`), while
        ``callback`` itself still runs on the event loop.

        Returns once the connection's iterator is exhausted. Re-raises any
        exception raised by the worker (including the callback's).
        """
        loop = asyncio.get_running_loop()
        finished = loop.create_future()

        def settle(error) -> None:
            # Runs on the loop thread; the awaiting task may already be cancelled.
            if finished.done():
                return
            if error is None:
                finished.set_result(None)
            else:
                finished.set_exception(error)

        def worker() -> None:
            error = None
            try:
                self._pump(connection, callback, loop)
            except BaseException as exc:  # forwarded to the awaiting coroutine
                error = exc
            try:
                loop.call_soon_threadsafe(settle, error)
            except RuntimeError:
                # The event loop is already closed (shutdown); nobody is waiting.
                pass

        # Daemon thread: a thread still blocked waiting for packets must not
        # keep the interpreter alive at exit.
        threading.Thread(target=worker, name="drawbridge-listener", daemon=True).start()
        await finished

    def _pump(self, connection, callback: Callable, loop: asyncio.AbstractEventLoop) -> None:
        """Receive, rewrite and verdict packets; runs on a worker thread.

        Every fnfqueue call for ``connection`` stays on this one thread. The
        callback is submitted to ``loop``, and its result is waited for here.
        """
        while True:
            try:
                for packet in connection:
                    pending = asyncio.run_coroutine_threadsafe(
                        self._invoke(callback, packet.payload), loop
                    )
                    packet.payload = pending.result()
                    packet.mangle()
                return
            except fnfqueue.BufferOverflowException:
                # An overflow ends the current iteration. Log it and start
                # iterating again instead of abandoning this queue for good.
                logger.warning("Packets arriving too quickly")

    @staticmethod
    async def _invoke(callback: Callable, payload):
        """Run ``callback`` on the event loop, awaiting it if it is a coroutine function."""
        if asyncio.iscoroutinefunction(callback):
            return await callback(payload)
        return callback(payload)

    def _delete_rules(self) -> None:
        """Delete each registered queue's rule; failures are logged, not raised."""
        for queue in self.net_queues:
            try:
                queue.delete_rule()
            except Exception as e:
                # Best effort: keep cleaning up the remaining rules.
                logger.error(f"Failed to delete rule: {e}")

    async def raise_bridges(self) -> None:
        """Bind one fnfqueue connection per registered queue and service them concurrently.

        Waits for all listeners to finish. If one listener fails, its
        exception propagates.
        """
        listeners = []
        for net_queue in self.net_queues:
            connection = fnfqueue.Connection()
            binding = connection.bind(net_queue.queue)
            # Copy range 65535 with COPY_PACKET mode, so callbacks receive packet payload data.
            binding.set_mode(65535, fnfqueue.COPY_PACKET)
            listeners.append(asyncio.create_task(self._listen(connection, net_queue.callback)))
        await asyncio.gather(*listeners)
|