Source code for clustrix.filesystem

"""
Unified filesystem operations for local and remote clusters.

This module provides a consistent interface for filesystem operations that work
both locally and on remote clusters based on the ClusterConfig object.
"""

import glob as glob_module
import os
import posixpath
import shlex
import stat as stat_module
from pathlib import Path
from typing import List, Optional, Dict, Any, Tuple

import paramiko

from .config import ClusterConfig


class FileInfo:
    """File information structure.

    Attributes:
        size: File size in bytes.
        modified: Last-modification time as a Unix timestamp.
        is_dir: True when the entry is a directory.
        permissions: Permission bits as an octal string (e.g. ``"644"``).
        name: Base name of the entry (may be empty).
    """

    def __init__(
        self,
        size: int,
        modified: float,
        is_dir: bool,
        permissions: str,
        name: str = "",
    ):
        """Initialize FileInfo with file metadata."""
        self.size = size
        self.modified = modified  # Unix timestamp
        self.is_dir = is_dir
        self.permissions = permissions
        self.name = name

    @property
    def is_file(self) -> bool:
        """Check if this is a file (not a directory)."""
        return not self.is_dir

    @property
    def modified_datetime(self):
        """Get modified time as a (naive, local-time) datetime object."""
        from datetime import datetime

        return datetime.fromtimestamp(self.modified)

    def __repr__(self) -> str:
        """String representation of FileInfo."""
        return (
            f"FileInfo(name='{self.name}', size={self.size}, modified={self.modified}, "
            f"is_dir={self.is_dir}, permissions='{self.permissions}')"
        )

    # NOTE: defining __eq__ without __hash__ makes instances unhashable,
    # which is appropriate for this mutable record.
    def __eq__(self, other):
        """Check equality with another FileInfo object.

        Returns ``NotImplemented`` for foreign types so Python can try the
        reflected comparison; ``info == other_type`` still evaluates to False.
        """
        if not isinstance(other, FileInfo):
            return NotImplemented
        return (
            self.name == other.name
            and self.size == other.size
            and self.modified == other.modified
            and self.is_dir == other.is_dir
            and self.permissions == other.permissions
        )
class DiskUsage:
    """Disk usage information.

    Attributes:
        total_bytes: Total size in bytes.
        file_count: Number of files counted.
    """

    def __init__(self, total_bytes: int, file_count: int):
        """Initialize DiskUsage with usage statistics."""
        self.total_bytes = total_bytes
        self.file_count = file_count

    @property
    def total_mb(self) -> float:
        """Total size in megabytes (1 MB = 1024 * 1024 bytes)."""
        return self.total_bytes / (1024 * 1024)

    @property
    def total_gb(self) -> float:
        """Total size in gigabytes (1 GB = 1024 ** 3 bytes)."""
        return self.total_bytes / (1024 * 1024 * 1024)

    def __repr__(self) -> str:
        """String representation of DiskUsage."""
        return (
            f"DiskUsage(total_bytes={self.total_bytes}, file_count={self.file_count})"
        )

    # NOTE: defining __eq__ without __hash__ makes instances unhashable,
    # which is appropriate for this mutable record.
    def __eq__(self, other):
        """Check equality with another DiskUsage object.

        Returns ``NotImplemented`` for foreign types so Python can try the
        reflected comparison; ``usage == other_type`` still evaluates to False.
        """
        if not isinstance(other, DiskUsage):
            return NotImplemented
        return (
            self.total_bytes == other.total_bytes
            and self.file_count == other.file_count
        )
class ClusterFilesystem:
    """Unified filesystem operations for local and remote clusters.

    Every public operation dispatches on ``config.cluster_type``: the value
    ``"local"`` uses the local filesystem, anything else runs shell commands
    on ``config.cluster_host`` through an SSH connection (paramiko).
    """

    def __init__(self, config: "ClusterConfig"):
        """Initialize filesystem with cluster configuration.

        Note:
            Construction runs cluster auto-detection, which may set
            ``config.cluster_type`` to ``"local"`` on the object passed in.
        """
        self.config = config
        self._ssh_client: Optional["paramiko.SSHClient"] = None
        self._sftp_client: Optional["paramiko.SFTPClient"] = None

        # Auto-detect if we're running on the target cluster (for shared filesystems)
        self._auto_detect_cluster_location()

    def _auto_detect_cluster_location(self):
        """
        Auto-detect if we're already running on the target cluster.

        If we're running on the same cluster as the target, we should use local
        filesystem operations instead of SSH, since most HPC clusters have shared
        filesystems (NFS/Lustre) across head and compute nodes.

        On a match this mutates ``self.config.cluster_type`` to ``"local"``.
        Any exception raised during detection is reported and swallowed so the
        original cluster_type stays in effect.
        """
        # Only attempt detection if cluster_type is not already 'local'
        if self.config.cluster_type == "local":
            return

        # Skip detection if no cluster_host is configured
        target_host = getattr(self.config, "cluster_host", None)
        if not target_host:
            return

        try:
            import socket

            current_hostname = socket.gethostname()
            if not current_hostname:
                # An empty hostname is a substring of every target host and
                # would otherwise always be reported as a match.
                return

            # Check various hostname matching scenarios
            is_on_target_cluster = (
                # Exact match
                current_hostname == target_host
                # Current hostname contains the target host string
                or target_host in current_hostname
                # Target host string contains the current hostname
                or current_hostname in target_host
                # Same domain once the first label is dropped
                or self._same_domain(current_hostname, target_host)
                # Same last two labels
                # (e.g. s04.hpcc.dartmouth.edu and ndoli.dartmouth.edu)
                or self._same_institution_domain(current_hostname, target_host)
            )

            if is_on_target_cluster:
                # We're on the target cluster - use local filesystem operations
                original_cluster_type = self.config.cluster_type
                self.config.cluster_type = "local"

                # Log the detection for debugging
                print(
                    f"Cluster detection: Running on target cluster (hostname: {current_hostname})"
                )
                print(
                    f"Switched from {original_cluster_type} to local filesystem operations"
                )
        except Exception as e:
            # If detection fails, continue with original cluster_type
            print(f"Warning: Cluster detection failed: {e}")

    def _same_domain(self, host1: str, host2: str) -> bool:
        """Check if two hostnames are in the same domain.

        The first label of each hostname is ignored; the remaining labels must
        be identical and consist of at least two parts.
        """
        try:
            # Skip hostname, get domain
            domain1_parts = host1.split(".")[1:]
            domain2_parts = host2.split(".")[1:]
        except (IndexError, AttributeError):
            return False

        # Equal lists have equal length, so one length check is sufficient.
        return len(domain1_parts) >= 2 and domain1_parts == domain2_parts

    def _same_institution_domain(self, host1: str, host2: str) -> bool:
        """
        Check if two hostnames are from the same institution.

        This handles cases like:
        - s04.hpcc.dartmouth.edu (compute node)
        - ndoli.dartmouth.edu (head node)

        Both should be considered the same cluster. Two hostnames match when
        each has at least two labels and their last two labels are identical.
        """
        try:
            parts1 = host1.split(".")
            parts2 = host2.split(".")
        except (IndexError, AttributeError):
            return False

        if len(parts1) < 2 or len(parts2) < 2:
            return False

        # Compare the last 2 parts (institution.tld). A former extra check on
        # the last three labels compared exactly these two labels again and
        # could never change the outcome, so it was removed.
        return parts1[-2:] == parts2[-2:]

    def __del__(self):
        """Clean up SSH connections."""
        try:
            self._close_connections()
        except Exception:
            # A finalizer must never raise (it may run on a half-constructed
            # object or during interpreter shutdown); cleanup is best-effort.
            pass

    def _close_connections(self):
        """Close SSH and SFTP connections.

        Safe to call repeatedly and on instances whose ``__init__`` did not
        complete. The SSH client is closed even if closing SFTP fails.
        """
        sftp_client = getattr(self, "_sftp_client", None)
        ssh_client = getattr(self, "_ssh_client", None)
        self._sftp_client = None
        self._ssh_client = None

        try:
            if sftp_client:
                sftp_client.close()
        finally:
            if ssh_client:
                ssh_client.close()

    def _get_ssh_client(self) -> "paramiko.SSHClient":
        """Get or create SSH client connection.

        The client is cached only after ``connect`` succeeds, so a failed
        attempt does not leave an unusable client behind.
        """
        if self._ssh_client is None:
            client = paramiko.SSHClient()
            # SECURITY NOTE(review): AutoAddPolicy accepts unknown host keys
            # without verification. Kept for backward compatibility.
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

            # Connect based on authentication method
            connect_kwargs: Dict[str, Any] = {
                "hostname": self.config.cluster_host,
                "port": self.config.cluster_port,
                "username": self.config.username,
            }

            if self.config.key_file:
                connect_kwargs["key_filename"] = self.config.key_file
            elif self.config.password:
                connect_kwargs["password"] = self.config.password
            else:
                # Try default SSH key locations
                connect_kwargs["look_for_keys"] = True

            try:
                client.connect(**connect_kwargs)
            except Exception:
                client.close()
                raise
            self._ssh_client = client

        return self._ssh_client

    def _get_sftp_client(self) -> "paramiko.SFTPClient":
        """Get or create SFTP client."""
        if self._sftp_client is None:
            self._sftp_client = self._get_ssh_client().open_sftp()
        return self._sftp_client

    def _get_full_path(self, path: str) -> str:
        """Get full path based on working directory.

        Absolute paths are returned unchanged. Relative paths are joined to
        ``local_work_dir`` (or the current directory) for local clusters and to
        ``remote_work_dir`` otherwise. Remote paths always use POSIX rules,
        independent of the operating system this process runs on.
        """
        if self.config.cluster_type == "local":
            base_dir = self.config.local_work_dir or os.getcwd()
            path_module = os.path
        else:
            base_dir = self.config.remote_work_dir
            path_module = posixpath

        # Handle absolute paths
        if path_module.isabs(path):
            return path

        return path_module.join(base_dir, path)

    @staticmethod
    def _quote_remote_path(path: str) -> str:
        """Quote a path for a remote shell command, keeping a leading ``~/``.

        Paths made only of shell-safe characters are returned unchanged.
        """
        if path == "~":
            return path
        if path.startswith("~/"):
            # Leave the tilde prefix unquoted so the remote shell expands it.
            return "~/" + shlex.quote(path[2:])
        return shlex.quote(path)

    @staticmethod
    def _single_quote(text: str) -> str:
        """Wrap text in single quotes, escaping embedded single quotes."""
        return "'" + text.replace("'", "'\"'\"'") + "'"

    @staticmethod
    def _decode_mode(mode: int) -> Tuple[bool, str]:
        """Split a raw ``st_mode`` into (is_dir, three-digit octal permissions)."""
        # S_ISDIR compares the whole file-type field; testing the 0x4000 bit
        # alone also matches block devices (0x6000) and sockets (0xC000).
        return stat_module.S_ISDIR(mode), format(mode & 0o777, "03o")

    def _run_remote(self, cmd: str) -> str:
        """Run a shell command over SSH and return its stripped stdout text."""
        ssh_client = self._get_ssh_client()
        _stdin, stdout, _stderr = ssh_client.exec_command(cmd)
        return stdout.read().decode().strip()

    # ===== Core Operations =====

    def ls(self, path: str = ".") -> List[str]:
        """List directory contents."""
        if self.config.cluster_type == "local":
            return self._local_ls(path)
        return self._remote_ls(path)

    def find(self, pattern: str, path: str = ".") -> List[str]:
        """Find files matching pattern."""
        if self.config.cluster_type == "local":
            return self._local_find(pattern, path)
        return self._remote_find(pattern, path)

    def stat(self, path: str) -> "FileInfo":
        """Get file/directory information."""
        if self.config.cluster_type == "local":
            return self._local_stat(path)
        return self._remote_stat(path)

    def exists(self, path: str) -> bool:
        """Check if file/directory exists."""
        if self.config.cluster_type == "local":
            return self._local_exists(path)
        return self._remote_exists(path)

    def isdir(self, path: str) -> bool:
        """Check if path is a directory."""
        if self.config.cluster_type == "local":
            return self._local_isdir(path)
        return self._remote_isdir(path)

    def isfile(self, path: str) -> bool:
        """Check if path is a file."""
        if self.config.cluster_type == "local":
            return self._local_isfile(path)
        return self._remote_isfile(path)

    def glob(self, pattern: str, path: str = ".") -> List[str]:
        """Pattern matching for files."""
        if self.config.cluster_type == "local":
            return self._local_glob(pattern, path)
        return self._remote_glob(pattern, path)

    def du(self, path: str = ".") -> "DiskUsage":
        """Get directory usage information."""
        if self.config.cluster_type == "local":
            return self._local_du(path)
        return self._remote_du(path)

    def count_files(self, path: str = ".", pattern: str = "*") -> int:
        """Count files in directory matching pattern."""
        if self.config.cluster_type == "local":
            return self._local_count_files(path, pattern)
        return self._remote_count_files(path, pattern)

    # ===== Local Implementations =====

    def _local_ls(self, path: str) -> List[str]:
        """Local directory listing (sorted); returns [] if it cannot be read."""
        full_path = self._get_full_path(path)
        try:
            return sorted(os.listdir(full_path))
        except OSError:
            return []

    def _local_find(self, pattern: str, path: str) -> List[str]:
        """Local file finding.

        Returns sorted, forward-slash paths relative to the search directory.
        NOTE: ``rglob`` also yields matching directories, whereas the remote
        implementation restricts results with ``find -type f``.
        """
        base_path = Path(self._get_full_path(path))

        results = []
        for item in base_path.rglob(pattern):
            try:
                # Return relative paths from the search directory
                results.append(item.relative_to(base_path).as_posix())
            except ValueError:
                # If relative_to fails, use absolute path
                results.append(item.as_posix())

        return sorted(results)

    def _local_stat(self, path: str) -> "FileInfo":
        """Local file stat."""
        full_path = self._get_full_path(path)
        st = os.stat(full_path)

        return FileInfo(
            size=st.st_size,
            modified=st.st_mtime,
            is_dir=os.path.isdir(full_path),
            permissions=format(st.st_mode & 0o777, "03o"),
            name=os.path.basename(path),
        )

    def _local_exists(self, path: str) -> bool:
        """Check if local path exists."""
        return os.path.exists(self._get_full_path(path))

    def _local_isdir(self, path: str) -> bool:
        """Check if local path is directory."""
        return os.path.isdir(self._get_full_path(path))

    def _local_isfile(self, path: str) -> bool:
        """Check if local path is file."""
        return os.path.isfile(self._get_full_path(path))

    def _local_glob(self, pattern: str, path: str) -> List[str]:
        """Local glob pattern matching; results are relative to ``path``."""
        full_path = self._get_full_path(path)
        search_pattern = os.path.join(full_path, pattern)

        results = []
        for match in glob_module.glob(search_pattern):
            try:
                # Return relative paths from the search directory
                results.append(os.path.relpath(match, full_path))
            except ValueError:
                results.append(match)

        return sorted(results)

    def _local_du(self, path: str) -> "DiskUsage":
        """Local disk usage: total size and count of files under ``path``."""
        full_path = self._get_full_path(path)
        total_size = 0
        file_count = 0

        for dirpath, _dirnames, filenames in os.walk(full_path):
            for filename in filenames:
                filepath = os.path.join(dirpath, filename)
                try:
                    total_size += os.path.getsize(filepath)
                    file_count += 1
                except OSError:
                    # Skip files we can't access
                    pass

        return DiskUsage(total_bytes=total_size, file_count=file_count)

    def _local_count_files(self, path: str, pattern: str) -> int:
        """Count local files matching pattern."""
        if pattern == "*":
            # Optimize for counting all files
            full_path = self._get_full_path(path)
            return sum(len(filenames) for _, _, filenames in os.walk(full_path))

        # Use find for pattern matching
        return len(self._local_find(pattern, path))

    # ===== Remote Implementations =====

    def _remote_ls(self, path: str) -> List[str]:
        """Remote directory listing via SSH."""
        target = self._quote_remote_path(self._get_full_path(path))

        # Use ls -1 for one file per line
        output = self._run_remote(f"ls -1 {target} 2>/dev/null || true")
        return sorted(output.split("\n")) if output else []

    def _remote_find(self, pattern: str, path: str) -> List[str]:
        """Remote file finding via SSH."""
        target = self._quote_remote_path(self._get_full_path(path))
        name = self._single_quote(pattern)

        # Use find command with name pattern
        cmd = f"cd {target} && find . -name {name} -type f | sed 's|^\\./||' | sort"
        output = self._run_remote(cmd)
        return output.split("\n") if output else []

    def _remote_stat(self, path: str) -> "FileInfo":
        """Remote file stat via SSH.

        Raises:
            FileNotFoundError: If the remote ``stat`` command prints nothing.
        """
        target = self._quote_remote_path(self._get_full_path(path))

        # %s = size, %Y = modification time, %f = file type/mode in hex
        # NOTE(review): `stat -c` is GNU coreutils syntax - confirm the target
        # hosts provide it.
        output = self._run_remote(f"stat -c '%s %Y %f' {target} 2>/dev/null")
        if not output:
            raise FileNotFoundError(f"File not found: {path}")

        parts = output.split()
        size = int(parts[0])
        mtime = int(parts[1])
        is_dir, permissions = self._decode_mode(int(parts[2], 16))

        return FileInfo(
            size=size,
            modified=mtime,
            is_dir=is_dir,
            permissions=permissions,
            name=posixpath.basename(path),
        )

    def _remote_exists(self, path: str) -> bool:
        """Check if remote path exists."""
        target = self._quote_remote_path(self._get_full_path(path))
        cmd = f"test -e {target} && echo 'EXISTS' || echo 'NOT_EXISTS'"
        return self._run_remote(cmd) == "EXISTS"

    def _remote_isdir(self, path: str) -> bool:
        """Check if remote path is directory."""
        target = self._quote_remote_path(self._get_full_path(path))
        cmd = f"test -d {target} && echo 'DIR' || echo 'NOT_DIR'"
        return self._run_remote(cmd) == "DIR"

    def _remote_isfile(self, path: str) -> bool:
        """Check if remote path is file."""
        target = self._quote_remote_path(self._get_full_path(path))
        cmd = f"test -f {target} && echo 'FILE' || echo 'NOT_FILE'"
        return self._run_remote(cmd) == "FILE"

    def _remote_glob(self, pattern: str, path: str) -> List[str]:
        """Remote glob pattern matching via SSH."""
        target = self._quote_remote_path(self._get_full_path(path))

        # Use shell glob expansion with ls
        # The 2>/dev/null suppresses errors for no matches
        # SECURITY NOTE(review): the pattern is intentionally left unquoted so
        # the remote shell expands it; do not pass untrusted patterns here.
        cmd = f"cd {target} && ls -d {pattern} 2>/dev/null | sort || true"
        output = self._run_remote(cmd)
        return output.split("\n") if output else []

    def _remote_du(self, path: str) -> "DiskUsage":
        """Remote disk usage via SSH."""
        target = self._quote_remote_path(self._get_full_path(path))

        # Get total size in bytes
        size_output = self._run_remote(f"du -sb {target} 2>/dev/null | cut -f1")

        # Count files
        count_output = self._run_remote(f"find {target} -type f 2>/dev/null | wc -l")

        total_bytes = int(size_output) if size_output else 0
        file_count = int(count_output) if count_output else 0

        return DiskUsage(total_bytes=total_bytes, file_count=file_count)

    def _remote_count_files(self, path: str, pattern: str) -> int:
        """Remote file counting via SSH."""
        target = self._quote_remote_path(self._get_full_path(path))

        if pattern == "*":
            # Count all files
            cmd = f"find {target} -type f 2>/dev/null | wc -l"
        else:
            # Count files matching pattern
            name = self._single_quote(pattern)
            cmd = f"find {target} -name {name} -type f 2>/dev/null | wc -l"

        output = self._run_remote(cmd)
        return int(output) if output else 0
# ===== Convenience Functions =====
def cluster_ls(path: str = ".", config: Optional[ClusterConfig] = None) -> List[str]:
    """List directory contents locally or remotely based on config.

    Args:
        path: Directory to list.
        config: Cluster configuration; when None, the value returned by
            ``get_config()`` from the package's config module is used.

    Returns:
        The directory entries reported by ``ClusterFilesystem.ls``.
    """
    if config is None:
        from .config import get_config

        config = get_config()

    fs = ClusterFilesystem(config)
    try:
        return fs.ls(path)
    finally:
        # Release any SSH/SFTP connection now instead of waiting for __del__.
        fs._close_connections()
def cluster_find(
    pattern: str, path: str = ".", config: Optional[ClusterConfig] = None
) -> List[str]:
    """Find files matching pattern locally or remotely based on config.

    Args:
        pattern: File name pattern to search for.
        path: Directory where the search starts.
        config: Cluster configuration; when None, the value returned by
            ``get_config()`` from the package's config module is used.

    Returns:
        The matches reported by ``ClusterFilesystem.find``.
    """
    if config is None:
        from .config import get_config

        config = get_config()

    fs = ClusterFilesystem(config)
    try:
        return fs.find(pattern, path)
    finally:
        # Release any SSH/SFTP connection now instead of waiting for __del__.
        fs._close_connections()
def cluster_stat(path: str, config: Optional[ClusterConfig] = None) -> FileInfo:
    """Get file information locally or remotely based on config.

    Args:
        path: File or directory to inspect.
        config: Cluster configuration; when None, the value returned by
            ``get_config()`` from the package's config module is used.

    Returns:
        The ``FileInfo`` produced by ``ClusterFilesystem.stat``.
    """
    if config is None:
        from .config import get_config

        config = get_config()

    fs = ClusterFilesystem(config)
    try:
        return fs.stat(path)
    finally:
        # Release any SSH/SFTP connection now instead of waiting for __del__.
        fs._close_connections()
def cluster_exists(path: str, config: Optional[ClusterConfig] = None) -> bool:
    """Check if file/directory exists locally or remotely based on config.

    Args:
        path: Path to test.
        config: Cluster configuration; when None, the value returned by
            ``get_config()`` from the package's config module is used.

    Returns:
        The result of ``ClusterFilesystem.exists`` for ``path``.
    """
    if config is None:
        from .config import get_config

        config = get_config()

    fs = ClusterFilesystem(config)
    try:
        return fs.exists(path)
    finally:
        # Release any SSH/SFTP connection now instead of waiting for __del__.
        fs._close_connections()
def cluster_isdir(path: str, config: Optional[ClusterConfig] = None) -> bool:
    """Check if path is directory locally or remotely based on config.

    Args:
        path: Path to test.
        config: Cluster configuration; when None, the value returned by
            ``get_config()`` from the package's config module is used.

    Returns:
        The result of ``ClusterFilesystem.isdir`` for ``path``.
    """
    if config is None:
        from .config import get_config

        config = get_config()

    fs = ClusterFilesystem(config)
    try:
        return fs.isdir(path)
    finally:
        # Release any SSH/SFTP connection now instead of waiting for __del__.
        fs._close_connections()
def cluster_isfile(path: str, config: Optional[ClusterConfig] = None) -> bool:
    """Check if path is file locally or remotely based on config.

    Args:
        path: Path to test.
        config: Cluster configuration; when None, the value returned by
            ``get_config()`` from the package's config module is used.

    Returns:
        The result of ``ClusterFilesystem.isfile`` for ``path``.
    """
    if config is None:
        from .config import get_config

        config = get_config()

    fs = ClusterFilesystem(config)
    try:
        return fs.isfile(path)
    finally:
        # Release any SSH/SFTP connection now instead of waiting for __del__.
        fs._close_connections()
def cluster_glob(
    pattern: str, path: str = ".", config: Optional[ClusterConfig] = None
) -> List[str]:
    """Pattern matching for files locally or remotely based on config.

    Args:
        pattern: Glob pattern to expand.
        path: Directory the pattern is evaluated in.
        config: Cluster configuration; when None, the value returned by
            ``get_config()`` from the package's config module is used.

    Returns:
        The matches reported by ``ClusterFilesystem.glob``.
    """
    if config is None:
        from .config import get_config

        config = get_config()

    fs = ClusterFilesystem(config)
    try:
        return fs.glob(pattern, path)
    finally:
        # Release any SSH/SFTP connection now instead of waiting for __del__.
        fs._close_connections()
def cluster_du(path: str = ".", config: Optional[ClusterConfig] = None) -> DiskUsage:
    """Get directory usage locally or remotely based on config.

    Args:
        path: Directory to measure.
        config: Cluster configuration; when None, the value returned by
            ``get_config()`` from the package's config module is used.

    Returns:
        The ``DiskUsage`` produced by ``ClusterFilesystem.du``.
    """
    if config is None:
        from .config import get_config

        config = get_config()

    fs = ClusterFilesystem(config)
    try:
        return fs.du(path)
    finally:
        # Release any SSH/SFTP connection now instead of waiting for __del__.
        fs._close_connections()
def cluster_count_files(
    path: str = ".", pattern: str = "*", config: Optional[ClusterConfig] = None
) -> int:
    """Count files matching pattern locally or remotely based on config.

    Args:
        path: Directory to count files in.
        pattern: File name pattern; ``"*"`` counts every file.
        config: Cluster configuration; when None, the value returned by
            ``get_config()`` from the package's config module is used.

    Returns:
        The count reported by ``ClusterFilesystem.count_files``.
    """
    if config is None:
        from .config import get_config

        config = get_config()

    fs = ClusterFilesystem(config)
    try:
        return fs.count_files(path, pattern)
    finally:
        # Release any SSH/SFTP connection now instead of waiting for __del__.
        fs._close_connections()