Resolve... asynchronicities :D

This commit is contained in:
Darryl Nixon 2023-07-16 13:29:49 -07:00
parent 143acbe02e
commit b61e16130d
4 changed files with 52 additions and 49 deletions

View file

@ -7,7 +7,7 @@ from typing import Set
from typing import Union
import aiofiles
from aiopath import Path
from aiopath import AsyncPath
from .fileops import find_mount
from .fileops import mount_to_fs_handler
@ -27,33 +27,35 @@ class AsyncObject(object):
class ShredDir(AsyncObject):
"""Class for tracking each directory to be shredded, and its contents."""
async def __init__(self, path: Path) -> None:
self.absolute_path = await path.resolve()
self.mount_point = find_mount(self.absolute_path)
self.contents = await self._get_contents()
self.mount_points = set(m for m in self.get_mount_points())
self.mount_points.add(self.mount_point)
async def __init__(self, path: AsyncPath, recursive: bool) -> None:
self.absolute_path = await (await path.resolve()).absolute()
self.mount_point = await find_mount(self.absolute_path)
self.contents = await self._get_contents(recursive)
self.mount_points = set(m for m in self.enumerate_mount_points())
self.fs_handler = await mount_to_fs_handler(self.mount_point)
self.byte_size = sum(item.byte_size for item in self.contents)
stat = await path.stat()
self.inode = stat.st_ino
async def _get_contents(self) -> List:
contents = []
async def _get_contents(self, recursive: bool) -> List:
tasks = []
async for subpath in self.absolute_path.glob("*"):
if await subpath.is_dir():
if recursive:
if await subpath.is_symlink():
logger.warning(f"Symlink subdirectory found: {subpath}, skipping")
continue
contents.append(await ShredDir(subpath))
tasks.append(ShredDir(subpath, recursive))
else:
logger.warning(f"Subdirectory found: {subpath}, skipping (see -r/--recursive))")
elif subpath.is_file():
contents.append(await ShredFile(subpath))
return contents
tasks.append(await ShredFile(subpath))
return await asyncio.gather(*tasks)
def get_mount_points(self) -> Generator:
def enumerate_mount_points(self) -> Generator:
for item in self.contents:
if isinstance(item, ShredDir):
yield from item.get_mount_points()
yield from item.enumerate_mount_points()
yield self.mount_point
async def shred(self, hash: bool = False, dryrun: bool = False) -> bool:
@ -84,12 +86,12 @@ class ShredDir(AsyncObject):
class ShredFile(AsyncObject):
"""Class for tracking each file to be shredded."""
async def __init__(self, path: Path) -> None:
self.absolute_path = await path.resolve().absolute()
async def __init__(self, path: AsyncPath) -> None:
self.absolute_path = await (await path.resolve()).absolute()
stat = await path.stat()
self.byte_size = stat.st_size
self.inode = stat.st_ino
self.mount_point = find_mount(self.absolute_path)
self.mount_point = await find_mount(self.absolute_path)
self.fs_handler = await mount_to_fs_handler(self.mount_point)
self.hardlinks = None
@ -149,8 +151,12 @@ class ShredFile(AsyncObject):
if self.hardlinks:
log_buf = f"[5/4] Unlinking {len(self.hardlinks)} hardlinks"
if not dryrun:
for link in self.hardlinks:
await link.unlink()
tasks = [link.unlink() for link in self.hardlinks]
done, _ = await asyncio.wait(tasks)
for task in done:
e = task.exception()
if e:
logger.warning(f"Unable to unlink hardlink: {e}")
else:
log_buf = "DRY RUN (no changes made) " + log_buf
logger.info(log_buf)

View file

@ -1,33 +1,31 @@
import asyncio
from typing import List
from collections.abc import Generator
from aiopath import Path
import aiofiles
from aiopath import AsyncPath
from asyncstdlib.functools import lru_cache
from .filesystems import FSHandlers
from .logs import logger
def find_mount(path: Path) -> Path:
async def find_mount(path: AsyncPath) -> AsyncPath:
"""Find the mount point for a given path."""
path = path.absolute()
while not path.is_mount():
path = await path.absolute()
while not await path.is_mount():
path = path.parent
return path
def get_all_mounts() -> List:
async def get_all_mounts() -> Generator:
"""Get a list of all mounted filesystems."""
mounts = []
with open("/proc/mounts", "r") as f:
for line in f:
mount = line.split()[1]
mounts.append(mount)
return mounts
async with aiofiles.open("/proc/mounts", mode="r") as output:
async for line in output:
yield line.split()[1]
@lru_cache(maxsize=1024)
async def mount_to_fs_handler(path: Path) -> str:
async def mount_to_fs_handler(path: AsyncPath) -> str:
# TODO: This is a hacky way to get the filesystem type, but it works for now.
# Maybe with libblkid Python bindings?
logger.info(f"Getting filesystem for mount: {path}")

View file

@ -1,4 +1,5 @@
import argparse
import asyncio
from collections import defaultdict
from .classes import get_all_hardlinks
@ -12,21 +13,19 @@ async def main(job: argparse.Namespace) -> bool:
This is the main function for processing a shred request.
It is called by the CLI and builds a job queue based on the arguments passed.
"""
new_paths = set()
# Expand all directories and files, and collect mount point information
tasks = []
for path in job.paths:
if path.is_file():
if await path.is_file():
logger.info(f"Adding file: {path}")
new_paths.add(await ShredFile(path))
elif path.is_dir():
if job.recursive:
tasks.append(ShredFile(path))
elif await path.is_dir():
logger.info(f"Adding directory: {path}")
new_paths.add(await ShredDir(path))
else:
logger.info(f"Skipping directory: {path} (try -r/--recursive)")
tasks.append(ShredDir(path, recursive=job.recursive))
else:
raise TypeError(f"Not a file or directory: {path}")
new_paths = set(await asyncio.gather(*tasks))
# Try to delete hardlinks based on the filesystem type
job.paths = await get_all_hardlinks(new_paths)

View file

@ -1,22 +1,22 @@
import os
import platform
import sys
from pathlib import Path as PathSync
from pathlib import Path
from aiopath import Path
from aiopath import AsyncPath
def validate_file_folder(value: str) -> Path:
file_folder_path = PathSync(value)
file_folder_path = Path(value)
if not file_folder_path.exists():
raise FileNotFoundError(f"No such file or folder: {value}")
if not file_folder_path.is_file() and not file_folder_path.is_dir():
raise TypeError(f"Not a file or directory: {value}")
return Path(value)
return AsyncPath(value)
def validate_logfile(value: str) -> Path:
logfile_path = PathSync(value)
logfile_path = Path(value)
if logfile_path.exists():
confirm = input(f"The file {value} already exists. Do you want to overwrite it? ([y]es/[n]o): ")
if confirm.lower() not in ["yes", "y"]: