Improve logic for getting hardlinks async

This commit is contained in:
Darryl Nixon 2023-07-16 10:55:05 -07:00
parent b96ad00679
commit 9ea7fdd8d1
3 changed files with 28 additions and 16 deletions

View file

@ -65,7 +65,7 @@ class ShredFile:
async def shred(self, hash: bool = False, dryrun: bool = False) -> Union[bool, bytes]: async def shred(self, hash: bool = False, dryrun: bool = False) -> Union[bool, bytes]:
"""Shred the file with a single file descriptor.""" """Shred the file with a single file descriptor."""
try: try:
logger.info(f"Shredding file: {self.absolute_path}") logger.info(f"Shredding: {self.absolute_path}")
async with aiofiles.open(self.absolute_path, "rb+") as file: async with aiofiles.open(self.absolute_path, "rb+") as file:
if hash: if hash:
@ -76,21 +76,21 @@ class ShredFile:
logger.info(f"Got hash {sha1.hexdigest()}") logger.info(f"Got hash {sha1.hexdigest()}")
# First pass: Overwrite with binary zeroes # First pass: Overwrite with binary zeroes
logger.info("Performing first pass: Overwriting with binary zeroes") logger.info(f"[1/4] Writing zeroes ({self.absolute_path.name})")
await file.seek(0) await file.seek(0)
if not dryrun: if not dryrun:
await file.write(b"\x00" * self.byte_size) await file.write(b"\x00" * self.byte_size)
await file.flush() await file.flush()
# Second pass: Overwrite with binary ones # Second pass: Overwrite with binary ones
logger.info("Performing second pass: Overwriting with binary ones") logger.info(f"[2/4] Writing ones ({self.absolute_path.name})")
await file.seek(0) await file.seek(0)
if not dryrun: if not dryrun:
await file.write(b"\xff" * self.byte_size) await file.write(b"\xff" * self.byte_size)
await file.flush() await file.flush()
# Third pass: Overwrite with random data # Third pass: Overwrite with random data
logger.info("Performing third pass: Overwriting with random data") logger.info(f"[3/4] Writing randoms ({self.absolute_path.name})")
await file.seek(0) await file.seek(0)
random_data = token_bytes(self.byte_size) random_data = token_bytes(self.byte_size)
if not dryrun: if not dryrun:
@ -98,14 +98,14 @@ class ShredFile:
await file.flush() await file.flush()
# Remove the file # Remove the file
logger.info(f"Removing file {self.absolute_path}") logger.info(f"[4/4] Unlinking {self.absolute_path}")
if not dryrun: if not dryrun:
file.unlink() file.unlink()
# Remove any hardlinks # Remove any hardlinks
if self.hardlinks: if self.hardlinks:
logger.info(f"Removing {len(self.hardlinks)} hardlinks") logger.info(f"[5/4] Unlinking {len(self.hardlinks)} hardlinks")
if not dryrun: if not dryrun:
for link in self.hardlinks: for link in self.hardlinks:
link.unlink() link.unlink()

View file

@ -1,13 +1,32 @@
import asyncio import asyncio
from pathlib import Path from pathlib import Path
from typing import List from typing import List
from typing import Set
from asyncstdlib.functools import lru_cache from asyncstdlib.functools import lru_cache
from .classes import ShredDir
from .classes import ShredFile
from .filesystems import FSHandlers from .filesystems import FSHandlers
from .logs import logger from .logs import logger
async def get_all_hardlinks(paths: Set[Path]) -> None:
for path in paths:
if isinstance(path, ShredFile):
logger.info("Getting hardlinks for {path}")
hardlink_count = 0
path.hardlinks = set()
async for link in path.fs_handler.get_hardlinks(path):
hardlink_count += 1
path.hardlinks.add(link)
logger.info(f"Found hardlink: {link}")
logger.info(f"Found {hardlink_count} hardlinks for {path.absolute_path}")
if isinstance(path, ShredDir):
await get_all_hardlinks(path)
return paths
def find_mount(path: Path) -> Path: def find_mount(path: Path) -> Path:
"""Find the mount point for a given path.""" """Find the mount point for a given path."""
path = path.absolute() path = path.absolute()

View file

@ -1,5 +1,6 @@
from .classes import ShredDir from .classes import ShredDir
from .classes import ShredFile from .classes import ShredFile
from .fileops import get_all_hardlinks
from .fileops import mount_to_fs_handler from .fileops import mount_to_fs_handler
from .logs import logger from .logs import logger
@ -10,6 +11,7 @@ async def main(job) -> bool:
It is called by the CLI and builds a job queue based on the arguments passed. It is called by the CLI and builds a job queue based on the arguments passed.
""" """
new_paths = set() new_paths = set()
logger.info(f"job type is {type(job)}")
# Expand all directories and files, and collect mount point information # Expand all directories and files, and collect mount point information
for path in job.paths: for path in job.paths:
@ -28,18 +30,9 @@ async def main(job) -> bool:
logger.info(f"Skipping directory: {path} (try -r/--recursive)") logger.info(f"Skipping directory: {path} (try -r/--recursive)")
else: else:
raise TypeError(f"Not a file or directory: {path}") raise TypeError(f"Not a file or directory: {path}")
job.paths = new_paths
# Get hardlinks to subsequently unlink for all files # Get hardlinks to subsequently unlink for all files
for path in job.paths: job.paths = await get_all_hardlinks(new_paths)
if isinstance(path, ShredFile):
logger.info("Getting hardlinks for {path}")
hardlink_count = 0
path.hardlinks = set()
async for link in path.fs_handler.get_hardlinks(path):
hardlink_count += 1
path.hardlinks.add(link)
logger.info(f"Found hardlink: {link}")
# Shred all physical files including hardlinks # Shred all physical files including hardlinks
for path in job.paths: for path in job.paths: