135 lines
4.5 KiB
Python
135 lines
4.5 KiB
Python
|
from ipaddress import AddressValueError
from ipaddress import ip_address
from ipaddress import IPv4Address
from ipaddress import IPv6Address
from typing import Callable
from typing import Optional
from typing import Union

import iptc

from .utils.logger import logger
from .utils.lookup import PROTOCOLS, ALL_TABLES, OUTGOING_MANGLE
|
||
|
|
||
|
|
||
|
class NetQueue:
    """Describe an iptables ``NFQUEUE`` rule together with a callback.

    Construction validates every argument, checks that no rule in
    ``ALL_TABLES`` already targets ``queue`` (deleting such rules first when
    ``override`` is true) and builds -- but does not install -- the rule.
    :meth:`write_rule` inserts it into the ``OUTGOING_MANGLE`` chain and
    :meth:`delete_rule` removes it again.
    """

    def __init__(
        self,
        callback: Callable,
        queue: int,
        src_ip: Optional[str] = None,
        dst_ip: Optional[str] = None,
        src_port: Optional[int] = None,
        dst_port: Optional[int] = None,
        protocol: Optional[str] = "tcp",
        override: bool = False,
    ):
        """Validate the filter arguments and build the NFQUEUE rule.

        Args:
            callback: Any callable; stored as ``self.callback``. (Presumably
                invoked per queued packet by whatever consumes this object --
                not shown in this class.)
            queue: NFQUEUE number, 0-65535.
            src_ip: Optional IPv4 source address to filter on.
            dst_ip: Optional IPv4 destination address to filter on.
            src_port: Optional source port, 0-65535. A falsy value (None or 0)
                means "no source-port filter".
            dst_port: Optional destination port, 0-65535. A falsy value means
                "no destination-port filter".
            protocol: A key of ``PROTOCOLS``; required when a port is given.
                A falsy value leaves the rule protocol-agnostic.
            override: If true, delete existing rules that already target
                ``queue`` instead of failing.

        Raises:
            ValueError: Invalid callback, port or queue number, a port filter
                without a protocol, or a queue that is already taken.
            AddressValueError: An IP address that does not parse.
            NotImplementedError: An IPv6 address.
            KeyError: A protocol that is not in ``PROTOCOLS``.
        """
        self.callback = self.validate_callable(callback)
        self.src_port = self.validate_port(src_port, "source")
        self.dst_port = self.validate_port(dst_port, "destination")
        self.src_ip = self.validate_ip(src_ip, "source")
        self.dst_ip = self.validate_ip(dst_ip, "destination")
        self.protocol = self.validate_protocol(protocol)

        # Port matching is done through the protocol's own match module
        # (see _create_rule), so ports without a protocol cannot be expressed.
        # Checked before _validate_queue so `override` never deletes existing
        # rules for a NetQueue that could not be built anyway.
        if (self.src_port or self.dst_port) and not self.protocol:
            raise ValueError("A protocol is required when filtering on ports")

        self.queue = self._validate_queue(queue, override)
        self.rule = self._create_rule()

    def _create_rule(self) -> iptc.Rule:
        """Build (without inserting) the NFQUEUE rule for this instance.

        Returns:
            An ``iptc.Rule`` that jumps to ``NFQUEUE --queue-num self.queue``,
            restricted by whichever protocol/port/IP filters are set.
        """
        rule = iptc.Rule()

        target = iptc.Target(rule, "NFQUEUE")
        target.set_parameter("queue-num", str(self.queue))

        if self.protocol:
            rule.protocol = self.protocol
            # create_match() builds the match AND attaches it to the rule.
            proto_match = rule.create_match(self.protocol)
            if self.dst_port:
                proto_match.dport = str(self.dst_port)
            if self.src_port:
                proto_match.sport = str(self.src_port)

        if self.src_ip or self.dst_ip:
            # A bare iptc.Match(rule, "iprange") is not part of the rule until
            # rule.add_match() is called; create_match() does both, so the IP
            # filters are actually applied. Skipped entirely when no IP is set.
            range_match = rule.create_match("iprange")
            if self.src_ip:
                range_match.src_range = str(self.src_ip)
            if self.dst_ip:
                range_match.dst_range = str(self.dst_ip)

        rule.target = target
        return rule

    def write_rule(self):
        """Insert this instance's rule into the ``OUTGOING_MANGLE`` chain."""
        OUTGOING_MANGLE.insert_rule(self.rule)

    def delete_rule(self):
        """Remove this instance's rule from ``OUTGOING_MANGLE``.

        An ``IPTCError`` from the deletion is logged as a warning rather than
        raised, so calling this for a rule that is already gone is harmless.
        """
        try:
            OUTGOING_MANGLE.delete_rule(self.rule)
        except iptc.ip4tc.IPTCError:
            logger.warning("Failed to delete rule, it may have already been deleted")

    @staticmethod
    def validate_callable(callback: Callable) -> Callable:
        """Return ``callback`` unchanged, or raise ValueError if not callable."""
        if not callable(callback):
            raise ValueError(f"Invalid callback: {callback}")
        return callback

    @staticmethod
    def validate_ip(
        ip: Optional[str], description: str
    ) -> Optional[Union[str, IPv4Address]]:
        """Parse and validate an optional IPv4 address.

        Args:
            ip: Address text, or a falsy value meaning "no filter".
            description: Word used in the error message ("source", ...).

        Returns:
            ``ip`` unchanged when it is falsy; otherwise the parsed
            ``IPv4Address`` object (not the original string).

        Raises:
            AddressValueError: ``ip`` is not a valid IP address.
            NotImplementedError: ``ip`` is an IPv6 address.
        """
        if not ip:
            return ip
        try:
            parsed = ip_address(ip)
        except ValueError as err:  # AddressValueError is a ValueError subclass
            raise AddressValueError(
                f"Invalid {description} IP address: {ip}"
            ) from err
        if isinstance(parsed, IPv6Address):
            raise NotImplementedError(f"IPv6 not supported: {parsed}")
        return parsed

    @staticmethod
    def validate_port(port: Optional[int], description: str) -> Union[int, None]:
        """Return ``port`` unchanged after checking it lies in 0-65535.

        Falsy values (None, 0) are returned without checking.

        Raises:
            ValueError: ``port`` is outside 0-65535.
        """
        if port and not 0 <= port <= 65535:
            raise ValueError(f"Invalid {description} port: {port}")
        return port

    @staticmethod
    def validate_protocol(protocol: Optional[str]) -> Union[str, None]:
        """Return ``protocol`` unchanged if it is falsy or a key of PROTOCOLS.

        Raises:
            KeyError: ``protocol`` is not found in ``PROTOCOLS``.
        """
        if protocol:
            try:
                PROTOCOLS[protocol]
            except KeyError:
                raise KeyError(f"Invalid protocol: {protocol}") from None
        return protocol

    @staticmethod
    def _targets_queue(rule: iptc.Rule, queue_num: str) -> bool:
        """Return True if ``rule`` jumps to NFQUEUE number ``queue_num``."""
        target = rule.target
        if target is None or target.name != "NFQUEUE":
            return False
        # .get(): NFQUEUE rules using --queue-balance have no "queue-num".
        value = target.get_all_parameters().get("queue-num")
        # NOTE(review): python-iptables appears to report parameter values as
        # lists of strings (e.g. ["3"]); accept a bare string as well.
        if isinstance(value, (list, tuple)):
            return queue_num in value
        return value == queue_num

    @staticmethod
    def _is_queue_taken(queue: int, override: bool) -> bool:
        """Report whether an existing rule in ``ALL_TABLES`` targets ``queue``.

        When ``override`` is true, *every* such rule is deleted (with a
        warning) and the queue is reported as free.
        """
        wanted = str(queue)

        # Collect first, delete afterwards: never mutate a chain while
        # iterating over its rules.
        conflicts = [
            (chain, rule)
            for table in ALL_TABLES
            for chain in table.chains
            for rule in chain.rules
            if NetQueue._targets_queue(rule, wanted)
        ]

        if not conflicts:
            return False
        if not override:
            return True

        for chain, rule in conflicts:
            logger.warning(f"Queue {queue} is already taken, clearing it")
            chain.delete_rule(rule)
        return False

    @staticmethod
    def _validate_queue(queue: int, override: bool) -> int:
        """Return ``queue`` after checking its range and that it is free.

        Raises:
            ValueError: ``queue`` is outside 0-65535, or already targeted by an
                existing rule and ``override`` is false.
        """
        if not 0 <= queue <= 65535:
            raise ValueError(f"Invalid queue number: {queue}")

        if NetQueue._is_queue_taken(queue, override):
            logger.warning(f"Queue {queue} is already taken, raising error")
            raise ValueError(f"Queue {queue} is already taken")

        return queue

    def __repr__(self):
        return (
            f"{type(self).__name__}("
            f"queue={self.queue}, "
            f"callback={self.callback}, "
            f"src_ip={self.src_ip}, "
            f"dst_ip={self.dst_ip}, "
            f"src_port={self.src_port}, "
            f"dst_port={self.dst_port}, "
            f"protocol={self.protocol}, "
            f"rule={self.rule}"
            f")"
        )

    def __str__(self):
        return self.__repr__()
|