Resolve... asynchronicities :D

This commit is contained in:
Darryl Nixon 2023-07-16 13:29:49 -07:00
parent 143acbe02e
commit b61e16130d
4 changed files with 52 additions and 49 deletions

View file

@ -7,7 +7,7 @@ from typing import Set
from typing import Union from typing import Union
import aiofiles import aiofiles
from aiopath import Path from aiopath import AsyncPath
from .fileops import find_mount from .fileops import find_mount
from .fileops import mount_to_fs_handler from .fileops import mount_to_fs_handler
@ -27,33 +27,35 @@ class AsyncObject(object):
class ShredDir(AsyncObject): class ShredDir(AsyncObject):
"""Class for tracking each directory to be shredded, and its contents.""" """Class for tracking each directory to be shredded, and its contents."""
async def __init__(self, path: Path) -> None: async def __init__(self, path: AsyncPath, recursive: bool) -> None:
self.absolute_path = await path.resolve() self.absolute_path = await (await path.resolve()).absolute()
self.mount_point = find_mount(self.absolute_path) self.mount_point = await find_mount(self.absolute_path)
self.contents = await self._get_contents() self.contents = await self._get_contents(recursive)
self.mount_points = set(m for m in self.get_mount_points()) self.mount_points = set(m for m in self.enumerate_mount_points())
self.mount_points.add(self.mount_point)
self.fs_handler = await mount_to_fs_handler(self.mount_point) self.fs_handler = await mount_to_fs_handler(self.mount_point)
self.byte_size = sum(item.byte_size for item in self.contents) self.byte_size = sum(item.byte_size for item in self.contents)
stat = await path.stat() stat = await path.stat()
self.inode = stat.st_ino self.inode = stat.st_ino
async def _get_contents(self) -> List: async def _get_contents(self, recursive: bool) -> List:
contents = [] tasks = []
async for subpath in self.absolute_path.glob("*"): async for subpath in self.absolute_path.glob("*"):
if await subpath.is_dir(): if await subpath.is_dir():
if await subpath.is_symlink(): if recursive:
logger.warning(f"Symlink subdirectory found: {subpath}, skipping") if await subpath.is_symlink():
continue logger.warning(f"Symlink subdirectory found: {subpath}, skipping")
contents.append(await ShredDir(subpath)) continue
tasks.append(ShredDir(subpath, recursive))
else:
logger.warning(f"Subdirectory found: {subpath}, skipping (see -r/--recursive))")
elif subpath.is_file(): elif subpath.is_file():
contents.append(await ShredFile(subpath)) tasks.append(await ShredFile(subpath))
return contents return await asyncio.gather(*tasks)
def get_mount_points(self) -> Generator: def enumerate_mount_points(self) -> Generator:
for item in self.contents: for item in self.contents:
if isinstance(item, ShredDir): if isinstance(item, ShredDir):
yield from item.get_mount_points() yield from item.enumerate_mount_points()
yield self.mount_point yield self.mount_point
async def shred(self, hash: bool = False, dryrun: bool = False) -> bool: async def shred(self, hash: bool = False, dryrun: bool = False) -> bool:
@ -84,12 +86,12 @@ class ShredDir(AsyncObject):
class ShredFile(AsyncObject): class ShredFile(AsyncObject):
"""Class for tracking each file to be shredded.""" """Class for tracking each file to be shredded."""
async def __init__(self, path: Path) -> None: async def __init__(self, path: AsyncPath) -> None:
self.absolute_path = await path.resolve().absolute() self.absolute_path = await (await path.resolve()).absolute()
stat = await path.stat() stat = await path.stat()
self.byte_size = stat.st_size self.byte_size = stat.st_size
self.inode = stat.st_ino self.inode = stat.st_ino
self.mount_point = find_mount(self.absolute_path) self.mount_point = await find_mount(self.absolute_path)
self.fs_handler = await mount_to_fs_handler(self.mount_point) self.fs_handler = await mount_to_fs_handler(self.mount_point)
self.hardlinks = None self.hardlinks = None
@ -149,8 +151,12 @@ class ShredFile(AsyncObject):
if self.hardlinks: if self.hardlinks:
log_buf = f"[5/4] Unlinking {len(self.hardlinks)} hardlinks" log_buf = f"[5/4] Unlinking {len(self.hardlinks)} hardlinks"
if not dryrun: if not dryrun:
for link in self.hardlinks: tasks = [link.unlink() for link in self.hardlinks]
await link.unlink() done, _ = await asyncio.wait(tasks)
for task in done:
e = task.exception()
if e:
logger.warning(f"Unable to unlink hardlink: {e}")
else: else:
log_buf = "DRY RUN (no changes made) " + log_buf log_buf = "DRY RUN (no changes made) " + log_buf
logger.info(log_buf) logger.info(log_buf)

View file

@ -1,33 +1,31 @@
import asyncio import asyncio
from typing import List from collections.abc import Generator
from aiopath import Path import aiofiles
from aiopath import AsyncPath
from asyncstdlib.functools import lru_cache from asyncstdlib.functools import lru_cache
from .filesystems import FSHandlers from .filesystems import FSHandlers
from .logs import logger from .logs import logger
def find_mount(path: Path) -> Path: async def find_mount(path: AsyncPath) -> AsyncPath:
"""Find the mount point for a given path.""" """Find the mount point for a given path."""
path = path.absolute() path = await path.absolute()
while not path.is_mount(): while not await path.is_mount():
path = path.parent path = path.parent
return path return path
def get_all_mounts() -> List: async def get_all_mounts() -> Generator:
"""Get a list of all mounted filesystems.""" """Get a list of all mounted filesystems."""
mounts = [] async with aiofiles.open("/proc/mounts", mode="r") as output:
with open("/proc/mounts", "r") as f: async for line in output:
for line in f: yield line.split()[1]
mount = line.split()[1]
mounts.append(mount)
return mounts
@lru_cache(maxsize=1024) @lru_cache(maxsize=1024)
async def mount_to_fs_handler(path: Path) -> str: async def mount_to_fs_handler(path: AsyncPath) -> str:
# TODO: This is a hacky way to get the filesystem type, but it works for now. # TODO: This is a hacky way to get the filesystem type, but it works for now.
# Maybe with libblkid Python bindings? # Maybe with libblkid Python bindings?
logger.info(f"Getting filesystem for mount: {path}") logger.info(f"Getting filesystem for mount: {path}")

View file

@ -1,4 +1,5 @@
import argparse import argparse
import asyncio
from collections import defaultdict from collections import defaultdict
from .classes import get_all_hardlinks from .classes import get_all_hardlinks
@ -12,21 +13,19 @@ async def main(job: argparse.Namespace) -> bool:
This is the main function for processing a shred request. This is the main function for processing a shred request.
It is called by the CLI and builds a job queue based on the arguments passed. It is called by the CLI and builds a job queue based on the arguments passed.
""" """
new_paths = set()
# Expand all directories and files, and collect mount point information # Expand all directories and files, and collect mount point information
tasks = []
for path in job.paths: for path in job.paths:
if path.is_file(): if await path.is_file():
logger.info(f"Adding file: {path}") logger.info(f"Adding file: {path}")
new_paths.add(await ShredFile(path)) tasks.append(ShredFile(path))
elif path.is_dir(): elif await path.is_dir():
if job.recursive: logger.info(f"Adding directory: {path}")
logger.info(f"Adding directory: {path}") tasks.append(ShredDir(path, recursive=job.recursive))
new_paths.add(await ShredDir(path))
else:
logger.info(f"Skipping directory: {path} (try -r/--recursive)")
else: else:
raise TypeError(f"Not a file or directory: {path}") raise TypeError(f"Not a file or directory: {path}")
new_paths = set(await asyncio.gather(*tasks))
# Try to delete hardlinks based on the filesystem type # Try to delete hardlinks based on the filesystem type
job.paths = await get_all_hardlinks(new_paths) job.paths = await get_all_hardlinks(new_paths)

View file

@ -1,22 +1,22 @@
import os import os
import platform import platform
import sys import sys
from pathlib import Path as PathSync from pathlib import Path
from aiopath import Path from aiopath import AsyncPath
def validate_file_folder(value: str) -> Path: def validate_file_folder(value: str) -> Path:
file_folder_path = PathSync(value) file_folder_path = Path(value)
if not file_folder_path.exists(): if not file_folder_path.exists():
raise FileNotFoundError(f"No such file or folder: {value}") raise FileNotFoundError(f"No such file or folder: {value}")
if not file_folder_path.is_file() and not file_folder_path.is_dir(): if not file_folder_path.is_file() and not file_folder_path.is_dir():
raise TypeError(f"Not a file or directory: {value}") raise TypeError(f"Not a file or directory: {value}")
return Path(value) return AsyncPath(value)
def validate_logfile(value: str) -> Path: def validate_logfile(value: str) -> Path:
logfile_path = PathSync(value) logfile_path = Path(value)
if logfile_path.exists(): if logfile_path.exists():
confirm = input(f"The file {value} already exists. Do you want to overwrite it? ([y]es/[n]o): ") confirm = input(f"The file {value} already exists. Do you want to overwrite it? ([y]es/[n]o): ")
if confirm.lower() not in ["yes", "y"]: if confirm.lower() not in ["yes", "y"]: