melamine/melamine/classes.py

148 lines
5.1 KiB
Python

import asyncio
import hashlib
import inspect
import os
from collections.abc import Generator
from pathlib import Path
from secrets import token_bytes
from typing import List
from typing import Set
from typing import Union

import aiofiles

from .fileops import find_mount
from .fileops import mount_to_fs_handler
from .logs import logger
async def get_all_hardlinks(paths: Set[Path]) -> Set[Path]:
    """Populate ``hardlinks`` on every ShredFile reachable from *paths*.

    Each ``ShredFile`` in *paths* gets a fresh ``hardlinks`` set filled from
    the asynchronous iterator ``fs_handler.get_hardlinks(file)``. Each
    ``ShredDir`` is descended into through its ``contents``. Items of any
    other type are ignored.

    Args:
        paths: Collection of items to inspect. The recursive call passes a
            directory's ``contents``, so any iterable is accepted in
            practice.

    Returns:
        The same *paths* object that was passed in, after its items have
        been updated in place.
    """
    for path in paths:
        if isinstance(path, ShredFile):
            logger.info(f"Getting hardlinks for {path}")
            hardlink_count = 0
            # Start from an empty set so a repeated scan never keeps stale links.
            path.hardlinks = set()
            async for link in path.fs_handler.get_hardlinks(path):
                hardlink_count += 1
                path.hardlinks.add(link)
                logger.info(f"Found hardlink: {link}")
            logger.info(f"Found {hardlink_count} hardlinks for {path.absolute_path}")
        elif isinstance(path, ShredDir):
            await get_all_hardlinks(path.contents)
    return paths
class AsyncObject:
    """Base class whose constructor is awaitable.

    Calling the class returns a coroutine; awaiting that coroutine yields
    the instance once its asynchronous ``__init__`` has finished, e.g.
    ``obj = await SomeSubclass(arg)``.
    """

    async def __new__(cls, *args, **kwargs):
        """Allocate the instance, await its ``__init__`` and return it."""
        self = super().__new__(cls)
        await self.__init__(*args, **kwargs)
        return self

    async def __init__(self):
        """Do nothing; subclasses override this with their async setup."""
class ShredDir(AsyncObject):
    """Class for tracking each directory to be shredded, and its contents.

    Instances are created with ``await ShredDir(path)``; construction walks
    the directory tree and builds a ``ShredDir``/``ShredFile`` object for
    every entry it keeps.
    """

    async def __init__(self, path: Path) -> None:
        """Resolve *path* and recursively collect the directory's contents.

        Args:
            path: Directory to track.
        """
        self.absolute_path = path.resolve()
        self.mount_point = find_mount(self.absolute_path)
        # The contents must be collected first: both the total size and the
        # set of mount points are derived from them.
        self.contents = await self._get_contents()
        self.byte_size = sum(item.byte_size for item in self.contents)
        # get_mount_points() also yields this directory's own mount point.
        self.mount_points = set(self.get_mount_points())
        # Awaited for consistency with ShredFile, which awaits the same helper.
        # NOTE(review): assumes mount_to_fs_handler is a coroutine function.
        self.fs_handler = await mount_to_fs_handler(self.mount_point)

    async def _get_contents(self) -> List:
        """Build shred objects for the direct children of this directory.

        Symlinked subdirectories are skipped with a warning. Entries that are
        neither a directory nor a file are ignored.

        Returns:
            A list of ``ShredDir`` and ``ShredFile`` instances.
        """
        contents = []
        for subpath in self.absolute_path.glob("*"):
            if subpath.is_dir():
                if subpath.is_symlink():
                    logger.warning(f"Symlink subdirectory found: {subpath}, skipping")
                    continue
                contents.append(await ShredDir(subpath))
            elif subpath.is_file():
                contents.append(await ShredFile(subpath))
        return contents

    def get_mount_points(self) -> Generator:
        """Yield the mount point of every nested directory, then this one's."""
        for item in self.contents:
            if isinstance(item, ShredDir):
                yield from item.get_mount_points()
        yield self.mount_point

    async def shred(self, hash: bool = False, dryrun: bool = False) -> bool:
        """Shred every item in this directory concurrently.

        Args:
            hash: Forwarded to each item's ``shred``.
            dryrun: Forwarded to each item's ``shred``.

        Returns:
            True when every item's ``shred`` returned a truthy value (also
            for an empty directory), otherwise False.
        """
        results = await asyncio.gather(
            *(item.shred(hash, dryrun) for item in self.contents)
        )
        return all(results)

    def __hash__(self) -> int:
        """Hash by resolved path."""
        return hash(self.absolute_path)
class ShredFile(AsyncObject):
    """Class for tracking each file to be shredded.

    Instances are created with ``await ShredFile(path)``.
    """

    # Upper bound on the amount of data held in memory per read or write, so
    # that shredding a large file does not allocate a buffer of the file's size.
    _CHUNK_SIZE = 1024 * 1024

    async def __init__(self, path: Path) -> None:
        """Record the file's resolved path, size, mount point and fs handler.

        Args:
            path: File to track.
        """
        self.absolute_path = path.resolve()
        self.byte_size = path.stat().st_size
        self.mount_point = find_mount(self.absolute_path)
        self.fs_handler = await mount_to_fs_handler(self.mount_point)
        # Left as None here; get_all_hardlinks() replaces it with a set.
        self.hardlinks = None

    async def shred(self, hash: bool = False, dryrun: bool = False) -> Union[bool, bytes]:
        """Shred the file with a single file descriptor.

        The file is overwritten with zero bytes, then 0xFF bytes, then random
        bytes, and finally unlinked together with any recorded hardlinks.

        Args:
            hash: When true, compute the SHA-1 of the current contents before
                overwriting and store the raw digest in ``self.sha1``.
            dryrun: When true, log every step but do not write or unlink.

        Returns:
            True on success. False if any step raised; the error is logged
            and not propagated.
        """
        try:
            logger.info(f"Shredding: {self.absolute_path}")
            async with aiofiles.open(self.absolute_path, "rb+") as file:
                if hash:
                    await self._hash_contents(file)

                # First pass: Overwrite with binary zeroes
                logger.info(f"[1/4] Writing zeroes ({self.absolute_path.name})")
                if not dryrun:
                    await self._overwrite(file, lambda size: b"\x00" * size)

                # Second pass: Overwrite with binary ones
                logger.info(f"[2/4] Writing ones ({self.absolute_path.name})")
                if not dryrun:
                    await self._overwrite(file, lambda size: b"\xff" * size)

                # Third pass: Overwrite with random data
                logger.info(f"[3/4] Writing randoms ({self.absolute_path.name})")
                if not dryrun:
                    await self._overwrite(file, token_bytes)

            # Remove the file. This happens after the handle is closed, and
            # through the path: the aiofiles handle has no unlink() method.
            logger.info(f"[4/4] Unlinking {self.absolute_path}")
            if not dryrun:
                await asyncio.to_thread(self.absolute_path.unlink)

            # Remove any hardlinks
            if self.hardlinks:
                logger.info(f"[5/4] Unlinking {len(self.hardlinks)} hardlinks")
                if not dryrun:
                    for link in self.hardlinks:
                        await self._unlink_hardlink(link)
            return True
        except Exception as e:
            # Best effort: report the failure to the caller instead of raising.
            logger.error(f"File wipe failed: {e}")
            return False

    async def _hash_contents(self, file) -> None:
        """Store the SHA-1 digest of *file*'s contents in ``self.sha1``."""
        sha1 = hashlib.sha1(usedforsecurity=False)
        await file.seek(0)
        while chunk := await file.read(self._CHUNK_SIZE):
            sha1.update(chunk)
        self.sha1 = sha1.digest()
        logger.info(f"Got hash {sha1.hexdigest()}")

    async def _overwrite(self, file, make_chunk) -> None:
        """Overwrite the first ``byte_size`` bytes of *file* in chunks.

        Args:
            file: Open aiofiles handle in a writable binary mode.
            make_chunk: Callable taking a byte count and returning that many
                bytes to write.
        """
        await file.seek(0)
        remaining = self.byte_size
        while remaining > 0:
            size = min(remaining, self._CHUNK_SIZE)
            await file.write(make_chunk(size))
            remaining -= size
        await file.flush()
        # flush() only empties Python's buffer; fsync asks the OS to commit
        # this pass to the device before the next pass overwrites it.
        await asyncio.to_thread(os.fsync, file.fileno())

    @staticmethod
    async def _unlink_hardlink(link) -> None:
        """Call ``link.unlink()``, awaiting the result only if it is awaitable.

        NOTE(review): the type yielded by ``fs_handler.get_hardlinks`` is not
        visible here. A ``pathlib.Path`` returns None from ``unlink()``, which
        must not be awaited; an async implementation returns an awaitable.
        """
        outcome = link.unlink()
        if inspect.isawaitable(outcome):
            await outcome

    def __hash__(self) -> int:
        """Hash by resolved path."""
        return hash(self.absolute_path)