diff --git a/melamine/classes.py b/melamine/classes.py index e0f3984..9e51cb0 100644 --- a/melamine/classes.py +++ b/melamine/classes.py @@ -7,7 +7,7 @@ from typing import Set from typing import Union import aiofiles -from aiopath import Path +from aiopath import AsyncPath from .fileops import find_mount from .fileops import mount_to_fs_handler @@ -27,33 +27,35 @@ class AsyncObject(object): class ShredDir(AsyncObject): """Class for tracking each directory to be shredded, and its contents.""" - async def __init__(self, path: Path) -> None: - self.absolute_path = await path.resolve() - self.mount_point = find_mount(self.absolute_path) - self.contents = await self._get_contents() - self.mount_points = set(m for m in self.get_mount_points()) - self.mount_points.add(self.mount_point) + async def __init__(self, path: AsyncPath, recursive: bool) -> None: + self.absolute_path = await (await path.resolve()).absolute() + self.mount_point = await find_mount(self.absolute_path) + self.contents = await self._get_contents(recursive) + self.mount_points = set(m for m in self.enumerate_mount_points()) self.fs_handler = await mount_to_fs_handler(self.mount_point) self.byte_size = sum(item.byte_size for item in self.contents) stat = await path.stat() self.inode = stat.st_ino - async def _get_contents(self) -> List: - contents = [] + async def _get_contents(self, recursive: bool) -> List: + tasks = [] async for subpath in self.absolute_path.glob("*"): if await subpath.is_dir(): - if await subpath.is_symlink(): - logger.warning(f"Symlink subdirectory found: {subpath}, skipping") - continue - contents.append(await ShredDir(subpath)) + if recursive: + if await subpath.is_symlink(): + logger.warning(f"Symlink subdirectory found: {subpath}, skipping") + continue + tasks.append(ShredDir(subpath, recursive)) + else: + logger.warning(f"Subdirectory found: {subpath}, skipping (see -r/--recursive))") elif subpath.is_file(): - contents.append(await ShredFile(subpath)) - return contents + tasks.append(await ShredFile(subpath)) + return await asyncio.gather(*tasks) - def get_mount_points(self) -> Generator: + def enumerate_mount_points(self) -> Generator: for item in self.contents: if isinstance(item, ShredDir): - yield from item.get_mount_points() + yield from item.enumerate_mount_points() yield self.mount_point async def shred(self, hash: bool = False, dryrun: bool = False) -> bool: @@ -84,12 +86,12 @@ class ShredDir(AsyncObject): class ShredFile(AsyncObject): """Class for tracking each file to be shredded.""" - async def __init__(self, path: Path) -> None: - self.absolute_path = await path.resolve().absolute() + async def __init__(self, path: AsyncPath) -> None: + self.absolute_path = await (await path.resolve()).absolute() stat = await path.stat() self.byte_size = stat.st_size self.inode = stat.st_ino - self.mount_point = find_mount(self.absolute_path) + self.mount_point = await find_mount(self.absolute_path) self.fs_handler = await mount_to_fs_handler(self.mount_point) self.hardlinks = None @@ -149,8 +151,12 @@ class ShredFile(AsyncObject): if self.hardlinks: log_buf = f"[5/4] Unlinking {len(self.hardlinks)} hardlinks" if not dryrun: - for link in self.hardlinks: - await link.unlink() + tasks = [link.unlink() for link in self.hardlinks] + done, _ = await asyncio.wait(tasks) + for task in done: + e = task.exception() + if e: + logger.warning(f"Unable to unlink hardlink: {e}") else: log_buf = "DRY RUN (no changes made) " + log_buf logger.info(log_buf) diff --git a/melamine/fileops.py b/melamine/fileops.py index 6e0ad69..5b58aca 100644 --- a/melamine/fileops.py +++ b/melamine/fileops.py @@ -1,33 +1,31 @@ import asyncio -from typing import List +from collections.abc import Generator -from aiopath import Path +import aiofiles +from aiopath import AsyncPath from asyncstdlib.functools import lru_cache from .filesystems import FSHandlers from .logs import logger -def find_mount(path: Path) -> Path: +async def find_mount(path: AsyncPath) -> AsyncPath: """Find the mount point for a given path.""" - path = path.absolute() - while not path.is_mount(): + path = await path.absolute() + while not await path.is_mount(): path = path.parent return path -def get_all_mounts() -> List: +async def get_all_mounts() -> Generator: """Get a list of all mounted filesystems.""" - mounts = [] - with open("/proc/mounts", "r") as f: - for line in f: - mount = line.split()[1] - mounts.append(mount) - return mounts + async with aiofiles.open("/proc/mounts", mode="r") as output: + async for line in output: + yield line.split()[1] @lru_cache(maxsize=1024) -async def mount_to_fs_handler(path: Path) -> str: +async def mount_to_fs_handler(path: AsyncPath) -> str: # TODO: This is a hacky way to get the filesystem type, but it works for now. # Maybe with libblkid Python bindings? logger.info(f"Getting filesystem for mount: {path}") diff --git a/melamine/shred.py b/melamine/shred.py index 6471e3d..a0bf80d 100644 --- a/melamine/shred.py +++ b/melamine/shred.py @@ -1,4 +1,5 @@ import argparse +import asyncio from collections import defaultdict from .classes import get_all_hardlinks @@ -12,21 +13,19 @@ async def main(job: argparse.Namespace) -> bool: This is the main function for processing a shred request. It is called by the CLI and builds a job queue based on the arguments passed. """ - new_paths = set() # Expand all directories and files, and collect mount point information + tasks = [] for path in job.paths: - if path.is_file(): + if await path.is_file(): logger.info(f"Adding file: {path}") - new_paths.add(await ShredFile(path)) - elif path.is_dir(): - if job.recursive: - logger.info(f"Adding directory: {path}") - new_paths.add(await ShredDir(path)) - else: - logger.info(f"Skipping directory: {path} (try -r/--recursive)") + tasks.append(ShredFile(path)) + elif await path.is_dir(): + logger.info(f"Adding directory: {path}") + tasks.append(ShredDir(path, recursive=job.recursive)) else: raise TypeError(f"Not a file or directory: {path}") + new_paths = set(await asyncio.gather(*tasks)) # Try to delete hardlinks based on the filesystem type job.paths = await get_all_hardlinks(new_paths) diff --git a/melamine/validators.py b/melamine/validators.py index 226d24d..880e43c 100644 --- a/melamine/validators.py +++ b/melamine/validators.py @@ -1,22 +1,22 @@ import os import platform import sys -from pathlib import Path as PathSync +from pathlib import Path -from aiopath import Path +from aiopath import AsyncPath def validate_file_folder(value: str) -> Path: - file_folder_path = PathSync(value) + file_folder_path = Path(value) if not file_folder_path.exists(): raise FileNotFoundError(f"No such file or folder: {value}") if not file_folder_path.is_file() and not file_folder_path.is_dir(): raise TypeError(f"Not a file or directory: {value}") - return Path(value) + return AsyncPath(value) def validate_logfile(value: str) -> Path: - logfile_path = PathSync(value) + logfile_path = Path(value) if logfile_path.exists(): confirm = input(f"The file {value} already exists. Do you want to overwrite it? ([y]es/[n]o): ") if confirm.lower() not in ["yes", "y"]: