"""NFQUEUE rule management built on python-iptables (``iptc``)."""
from ipaddress import AddressValueError
from ipaddress import ip_address
from ipaddress import IPv4Address
from ipaddress import IPv6Address
from typing import Callable
from typing import Optional
from typing import Union

import iptc

from .utils.logger import logger
from .utils.lookup import PROTOCOLS, ALL_TABLES, OUTGOING_MANGLE


class NetQueue:
    """Describe one NFQUEUE iptables rule together with its packet callback.

    The constructor validates every argument, checks that the requested
    queue number is not already used by another NFQUEUE rule, and builds
    the ``iptc.Rule``.  The rule is only installed when ``write_rule`` is
    called.
    """

    def __init__(
        self,
        callback: Callable,
        queue: int,
        src_ip: Optional[str] = None,
        dst_ip: Optional[str] = None,
        src_port: Optional[int] = None,
        dst_port: Optional[int] = None,
        protocol: Optional[str] = "tcp",
        override: bool = False,
    ):
        """Validate the filter parameters and build the rule.

        Args:
            callback: Callable associated with this queue.
            queue: NFQUEUE number, 0-65535.
            src_ip: Optional IPv4 source address to match.
            dst_ip: Optional IPv4 destination address to match.
            src_port: Optional source port to match.
            dst_port: Optional destination port to match.
            protocol: Protocol name; must be a key of ``PROTOCOLS``.
            override: If true, an existing rule bound to ``queue`` is
                deleted instead of causing an error.

        Raises:
            ValueError: Invalid callback, port or queue, or queue taken.
            AddressValueError: An IP address could not be parsed.
            NotImplementedError: An IPv6 address was supplied.
            KeyError: Unknown protocol.
        """
        self.callback = self.validate_callable(callback)
        self.src_port = self.validate_port(src_port, "source")
        self.dst_port = self.validate_port(dst_port, "destination")
        self.src_ip = self.validate_ip(src_ip, "source")
        self.dst_ip = self.validate_ip(dst_ip, "destination")
        self.protocol = self.validate_protocol(protocol)
        # Queue validation runs last among the checks because, with
        # ``override`` set, it may delete an existing rule.
        self.queue = self._validate_queue(queue, override)
        self.rule = self._create_rule()

    def _create_rule(self) -> iptc.Rule:
        """Build the ``iptc.Rule`` that sends matching packets to the queue.

        Raises:
            ValueError: A port filter was requested without a protocol.
        """
        rule = iptc.Rule()
        target = iptc.Target(rule, "NFQUEUE")
        target.set_parameter("queue-num", str(self.queue))

        if self.protocol:
            rule.protocol = self.protocol
            proto_match = rule.create_match(self.protocol)
            if self.dst_port:
                proto_match.dport = str(self.dst_port)
            if self.src_port:
                proto_match.sport = str(self.src_port)
        elif self.src_port or self.dst_port:
            # Port options live on the protocol match, so they cannot be
            # expressed without one.
            raise ValueError("A protocol is required when filtering by port")

        if self.src_ip or self.dst_ip:
            # A match built with iptc.Match() is not part of the rule until
            # add_match() is called; without it the IP filter is dropped.
            ip_match = iptc.Match(rule, "iprange")
            if self.src_ip:
                ip_match.src_range = str(self.src_ip)
            if self.dst_ip:
                ip_match.dst_range = str(self.dst_ip)
            rule.add_match(ip_match)

        rule.target = target
        return rule

    def write_rule(self):
        """Insert this rule into ``OUTGOING_MANGLE``."""
        OUTGOING_MANGLE.insert_rule(self.rule)

    def delete_rule(self):
        """Remove this rule from ``OUTGOING_MANGLE``; a failure is only logged."""
        try:
            OUTGOING_MANGLE.delete_rule(self.rule)
        except iptc.ip4tc.IPTCError:
            logger.warning("Failed to delete rule, it may have already been deleted")

    @staticmethod
    def validate_callable(callback: Callable) -> Callable:
        """Return ``callback`` unchanged, or raise ``ValueError`` if not callable."""
        if not callable(callback):
            raise ValueError(f"Invalid callback: {callback}")
        return callback

    @staticmethod
    def validate_ip(
        ip: Optional[str], description: str
    ) -> Union[IPv4Address, str, None]:
        """Parse and validate an optional IPv4 address.

        Args:
            ip: Address to validate; a falsy value is returned unchanged.
            description: Label used in the error message.

        Returns:
            The parsed address object, or the original falsy value.

        Raises:
            AddressValueError: ``ip`` is not a valid IP address.
            NotImplementedError: ``ip`` is an IPv6 address.
        """
        if not ip:
            return ip
        try:
            parsed = ip_address(ip)
        except ValueError as err:  # AddressValueError is a ValueError subclass
            raise AddressValueError(
                f"Invalid {description} IP address: {ip}"
            ) from err
        if isinstance(parsed, IPv6Address):
            raise NotImplementedError(f"IPv6 not supported: {parsed}")
        return parsed

    @staticmethod
    def validate_port(port: Optional[int], description: str) -> Union[int, None]:
        """Return ``port`` if it is falsy or within 0-65535.

        Raises:
            ValueError: ``port`` is outside the valid range.
        """
        if port and not 0 <= port <= 65535:
            raise ValueError(f"Invalid {description} port: {port}")
        return port

    @staticmethod
    def validate_protocol(protocol: Optional[str]) -> Union[str, None]:
        """Return ``protocol`` if it is falsy or a key of ``PROTOCOLS``.

        Raises:
            KeyError: ``protocol`` is not a known protocol.
        """
        if protocol:
            try:
                PROTOCOLS[protocol]
            except KeyError:
                raise KeyError(f"Invalid protocol: {protocol}") from None
        return protocol

    @staticmethod
    def _is_queue_taken(queue: int, override: bool) -> bool:
        """Report whether an existing NFQUEUE rule already uses ``queue``.

        Scans every rule of every chain in ``ALL_TABLES``.  On the first
        match: with ``override`` the rule is deleted and ``False`` is
        returned, otherwise ``True`` is returned.  Returns ``False`` when
        no rule matches.
        """
        wanted = str(queue)
        for table in ALL_TABLES:
            for chain in table.chains:
                for rule in chain.rules:
                    if rule.target.name != "NFQUEUE":
                        continue
                    # NFQUEUE rules need not carry "queue-num" at all, so a
                    # plain subscript could raise KeyError.
                    value = rule.target.get_all_parameters().get("queue-num")
                    # NOTE(review): python-iptables appears to report each
                    # parameter as a list of values; accept a bare string
                    # as well so either shape is handled.
                    if isinstance(value, str):
                        bound = value == wanted
                    else:
                        bound = value is not None and wanted in value
                    if not bound:
                        continue
                    if override:
                        logger.warning(f"Queue {queue} is already taken, clearing it")
                        chain.delete_rule(rule)
                        return False
                    return True
        return False

    @staticmethod
    def _validate_queue(queue: int, override: bool) -> int:
        """Return ``queue`` if it is in range and not already in use.

        Raises:
            ValueError: ``queue`` is outside 0-65535 or already taken.
        """
        if not 0 <= queue <= 65535:
            raise ValueError(f"Invalid queue number: {queue}")
        if NetQueue._is_queue_taken(queue, override):
            logger.warning(f"Queue {queue} is already taken, raising error")
            raise ValueError(f"Queue {queue} is already taken")
        return queue

    def __repr__(self):
        """Return a debug representation listing every configured field."""
        return (
            f"{type(self).__name__}("
            f"queue={self.queue}, "
            f"callback={self.callback}, "
            f"src_ip={self.src_ip}, "
            f"dst_ip={self.dst_ip}, "
            f"src_port={self.src_port}, "
            f"dst_port={self.dst_port}, "
            f"protocol={self.protocol}, "
            f"rule={self.rule}"
            f")"
        )

    def __str__(self):
        """Return the same text as ``__repr__``."""
        return self.__repr__()