"""
Dependency Analysis System for Clustrix
This module provides AST-based analysis to identify function dependencies,
including imports, local function calls, file references, and cluster filesystem operations.
"""
import ast
import inspect
import os
import types
import textwrap
from typing import List, Dict, Set, Optional, Callable, Any
from pathlib import Path
[docs]
class FilesystemCall:
"""Represents a call to cluster filesystem function."""
[docs]
def __init__(
self, function: str, args: List[str], lineno: int, context: Optional[str] = None
):
self.function = function
self.args = args
self.lineno = lineno
self.context = context
def __repr__(self):
return f"FilesystemCall(function='{self.function}', args={self.args}, lineno={self.lineno})"
[docs]
class ImportInfo:
"""Information about an import statement."""
[docs]
def __init__(
self,
module: str,
names: List[str],
alias: Optional[str] = None,
is_from_import: bool = False,
lineno: int = 0,
):
self.module = module
self.names = names
self.alias = alias
self.is_from_import = is_from_import
self.lineno = lineno
def __repr__(self):
return f"ImportInfo(module='{self.module}', names={self.names}, is_from_import={self.is_from_import})"
[docs]
class LocalFunctionCall:
"""Information about a call to a locally-defined function."""
[docs]
def __init__(
self,
function_name: str,
lineno: int,
defined_in_scope: bool = False,
source_file: Optional[str] = None,
):
self.function_name = function_name
self.lineno = lineno
self.defined_in_scope = defined_in_scope
self.source_file = source_file
def __repr__(self):
return f"LocalFunctionCall(function_name='{self.function_name}', lineno={self.lineno})"
[docs]
class FileReference:
"""Reference to a file in the code."""
[docs]
def __init__(
self, path: str, operation: str, lineno: int, is_relative: bool = True
):
self.path = path
self.operation = operation # 'read', 'write', 'open', etc.
self.lineno = lineno
self.is_relative = is_relative
def __repr__(self):
return f"FileReference(path='{self.path}', operation='{self.operation}', lineno={self.lineno})"
[docs]
class DependencyGraph:
"""Complete dependency graph for a function."""
[docs]
def __init__(self, function_name: str, source_code: str):
self.function_name = function_name
self.source_code = source_code
self.imports: List[ImportInfo] = []
self.local_function_calls: List[LocalFunctionCall] = []
self.file_references: List[FileReference] = []
self.filesystem_calls: List[FilesystemCall] = []
self.source_files: Set[str] = set()
self.local_modules: Set[str] = set()
self.data_files: Set[str] = set()
self.requires_cluster_filesystem: bool = False
[docs]
def add_imports(self, imports: List[ImportInfo]):
"""Add import dependencies."""
self.imports.extend(imports)
[docs]
def add_local_function_calls(self, calls: List[LocalFunctionCall]):
"""Add local function call dependencies."""
self.local_function_calls.extend(calls)
[docs]
def add_file_references(self, refs: List[FileReference]):
"""Add file reference dependencies."""
self.file_references.extend(refs)
for ref in refs:
if ref.is_relative:
self.data_files.add(ref.path)
[docs]
def add_filesystem_calls(self, calls: List[FilesystemCall]):
"""Add cluster filesystem call dependencies."""
self.filesystem_calls.extend(calls)
if calls:
self.requires_cluster_filesystem = True
[docs]
class DependencyAnalyzer:
"""Analyzes Python functions to identify all dependencies."""
[docs]
def __init__(self):
self.cluster_fs_functions = {
"cluster_ls",
"cluster_find",
"cluster_stat",
"cluster_exists",
"cluster_isdir",
"cluster_isfile",
"cluster_glob",
"cluster_du",
"cluster_count_files",
}
self.file_operations = {"open", "read", "write", "load", "dump", "save"}
[docs]
def analyze_function(self, func: Callable) -> DependencyGraph:
"""
Analyze a function for all dependencies.
Args:
func: The function to analyze
Returns:
DependencyGraph containing all identified dependencies
"""
try:
# Get function source code
source = inspect.getsource(func)
func_name = func.__name__
except (OSError, TypeError) as e:
raise ValueError(f"Cannot get source for function {func.__name__}: {e}")
# Parse the source into AST
try:
# Remove common leading whitespace to handle indented functions
dedented_source = textwrap.dedent(source)
tree = ast.parse(dedented_source)
except SyntaxError as e:
raise ValueError(f"Invalid syntax in function {func_name}: {e}")
# Initialize dependency graph
dependencies = DependencyGraph(
function_name=func_name, source_code=dedented_source
)
# Analyze different types of dependencies
self._analyze_imports(tree, dependencies)
self._analyze_function_calls(tree, dependencies, func)
self._analyze_file_references(tree, dependencies)
self._analyze_filesystem_calls(tree, dependencies)
# Identify source files and modules
self._identify_source_dependencies(func, dependencies)
return dependencies
def _analyze_imports(self, tree: ast.AST, dependencies: DependencyGraph):
"""Analyze import statements in the function."""
imports = []
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
imports.append(
ImportInfo(
module=alias.name,
names=[alias.name],
alias=alias.asname,
is_from_import=False,
lineno=node.lineno,
)
)
elif isinstance(node, ast.ImportFrom):
if node.module: # Skip relative imports without module
imports.append(
ImportInfo(
module=node.module,
names=[alias.name for alias in node.names],
is_from_import=True,
lineno=node.lineno,
)
)
dependencies.add_imports(imports)
def _analyze_function_calls(
self, tree: ast.AST, dependencies: DependencyGraph, func: Callable
):
"""Analyze function calls to identify local dependencies."""
local_calls = []
# Get the function's global namespace
func_globals = getattr(func, "__globals__", {})
for node in ast.walk(tree):
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
func_name = node.func.id
# Check if this is a locally-defined function
if func_name in func_globals:
obj = func_globals[func_name]
if isinstance(obj, types.FunctionType):
# This is a local function
source_file = None
try:
source_file = inspect.getfile(obj)
except (OSError, TypeError):
pass
local_calls.append(
LocalFunctionCall(
function_name=func_name,
lineno=node.lineno,
defined_in_scope=True,
source_file=source_file,
)
)
dependencies.add_local_function_calls(local_calls)
def _analyze_file_references(self, tree: ast.AST, dependencies: DependencyGraph):
"""Analyze file operations and path references."""
file_refs = []
for node in ast.walk(tree):
# Look for function calls that operate on files
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
func_name = node.func.id
if func_name in self.file_operations:
# Extract file path if it's a string literal
if node.args and isinstance(node.args[0], ast.Constant):
if isinstance(node.args[0].value, str):
path = node.args[0].value
file_refs.append(
FileReference(
path=path,
operation=func_name,
lineno=node.lineno,
is_relative=not os.path.isabs(path),
)
)
# Look for method calls on file-like objects
elif isinstance(node.func, ast.Attribute):
method_name = node.func.attr
if method_name in {"read", "write", "readline", "writelines"}:
# This is a file operation, but we can't easily get the path
# from method calls, so we'll note it exists
file_refs.append(
FileReference(
path="<unknown>",
operation=method_name,
lineno=node.lineno,
is_relative=True,
)
)
# Look for string literals that look like file paths
elif isinstance(node, ast.Constant):
if isinstance(node.value, str):
value = node.value
# Simple heuristic: contains path separators and common extensions
if ("/" in value or "\\" in value) and ("." in value):
# Check for common file extensions
extensions = {
".txt",
".csv",
".json",
".xml",
".yaml",
".yml",
".h5",
".hdf5",
".pickle",
".pkl",
".npy",
".npz",
".dat",
".log",
".conf",
".cfg",
".ini",
}
if any(value.lower().endswith(ext) for ext in extensions):
file_refs.append(
FileReference(
path=value,
operation="reference",
lineno=node.lineno,
is_relative=not os.path.isabs(value),
)
)
dependencies.add_file_references(file_refs)
def _analyze_filesystem_calls(self, tree: ast.AST, dependencies: DependencyGraph):
"""Analyze calls to cluster filesystem functions."""
fs_calls = []
for node in ast.walk(tree):
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
func_name = node.func.id
if func_name in self.cluster_fs_functions:
# Extract arguments as string representations
args = []
for arg in node.args:
try:
if isinstance(arg, ast.Constant):
args.append(repr(arg.value))
else:
# Use ast.unparse if available (Python 3.9+), otherwise use a fallback
if hasattr(ast, "unparse"):
args.append(ast.unparse(arg))
else:
args.append(str(arg))
except Exception:
args.append("<unparseable>")
fs_calls.append(
FilesystemCall(
function=func_name, args=args, lineno=node.lineno
)
)
dependencies.add_filesystem_calls(fs_calls)
def _identify_source_dependencies(
self, func: Callable, dependencies: DependencyGraph
):
"""Identify source files and modules that need to be packaged."""
# Add the function's own source file
try:
source_file = inspect.getfile(func)
if source_file and source_file != "<stdin>":
dependencies.source_files.add(source_file)
except (OSError, TypeError):
pass
# Add source files for local function calls
for call in dependencies.local_function_calls:
if call.source_file:
dependencies.source_files.add(call.source_file)
# Identify local modules based on imports
for import_info in dependencies.imports:
module_name = import_info.module
# Check if this is a local module (not in standard library or site-packages)
try:
module = __import__(module_name)
module_file = getattr(module, "__file__", None)
if module_file:
module_path = Path(module_file)
# Check if it's in the current working directory or subdirectories
cwd = Path.cwd()
try:
module_path.relative_to(cwd)
# It's a local module
dependencies.local_modules.add(str(module_path))
except ValueError:
# Not a local module
pass
except ImportError:
# Module not found - might be a local module that's not in path
pass
[docs]
class LoopAnalyzer:
"""Analyzes loops in functions to identify parallelization opportunities."""
[docs]
def __init__(self):
self.parallelizable_patterns = {
"for_loop_with_list_comprehension",
"for_loop_with_independent_iterations",
"map_like_operations",
}
[docs]
def analyze_loops(self, tree: ast.AST) -> List[Dict[str, Any]]:
"""
Analyze loops in the AST to identify parallelization opportunities.
Returns:
List of loop analysis results
"""
loops = []
for node in ast.walk(tree):
if isinstance(node, ast.For):
loop_info = self._analyze_for_loop(node)
loops.append(loop_info)
elif isinstance(node, ast.While):
loop_info = self._analyze_while_loop(node)
loops.append(loop_info)
return loops
def _analyze_for_loop(self, node: ast.For) -> Dict[str, Any]:
"""Analyze a for loop for parallelization potential."""
# Use ast.unparse if available (Python 3.9+), otherwise use a simple fallback
if hasattr(ast, "unparse"):
target_str = ast.unparse(node.target)
iter_str = ast.unparse(node.iter)
else:
target_str = getattr(node.target, "id", str(node.target))
# For Python < 3.9, handle common iterator patterns
if isinstance(node.iter, ast.Call):
if hasattr(node.iter.func, "id"):
# Handle range(), list(), etc.
func_name = node.iter.func.id
args = []
for arg in node.iter.args:
if isinstance(arg, ast.Constant):
args.append(str(arg.value))
elif isinstance(arg, ast.Num): # Python 3.8 compatibility
args.append(str(arg.n))
else:
args.append("...")
iter_str = f"{func_name}({', '.join(args)})"
else:
iter_str = str(node.iter)
elif isinstance(node.iter, ast.Name):
iter_str = node.iter.id
else:
iter_str = str(node.iter)
return {
"type": "for",
"lineno": node.lineno,
"target": target_str,
"iter": iter_str,
"is_parallelizable": self._is_loop_parallelizable(node),
"dependencies": self._find_loop_dependencies(node),
}
def _analyze_while_loop(self, node: ast.While) -> Dict[str, Any]:
"""Analyze a while loop for parallelization potential."""
# Use ast.unparse if available (Python 3.9+), otherwise use a simple fallback
if hasattr(ast, "unparse"):
test_str = ast.unparse(node.test)
else:
test_str = str(node.test)
return {
"type": "while",
"lineno": node.lineno,
"test": test_str,
"is_parallelizable": False, # While loops are generally not parallelizable
"dependencies": [],
}
def _is_loop_parallelizable(self, node: ast.For) -> bool:
"""
Determine if a for loop can be parallelized.
A loop is potentially parallelizable if:
1. Each iteration is independent
2. No shared mutable state
3. No break/continue statements that depend on previous iterations
"""
# Simple heuristic: check for common non-parallelizable patterns
for child in ast.walk(node):
# Break/continue make parallelization complex
if isinstance(child, (ast.Break, ast.Continue)):
return False
# Global variable modifications can create dependencies
if isinstance(child, ast.Global):
return False
# More sophisticated analysis would be needed for production use
return True
def _find_loop_dependencies(self, node: ast.For) -> List[str]:
"""Find variables that the loop depends on."""
dependencies = []
for child in ast.walk(node):
if isinstance(child, ast.Name) and isinstance(child.ctx, ast.Load):
dependencies.append(child.id)
return list(set(dependencies)) # Remove duplicates
[docs]
def analyze_function_dependencies(func: Callable) -> DependencyGraph:
"""
Convenience function to analyze a function's dependencies.
Args:
func: The function to analyze
Returns:
DependencyGraph containing all identified dependencies
"""
analyzer = DependencyAnalyzer()
return analyzer.analyze_function(func)
[docs]
def analyze_function_loops(func: Callable) -> List[Dict[str, Any]]:
"""
Convenience function to analyze loops in a function.
Args:
func: The function to analyze
Returns:
List of loop analysis results
"""
try:
source = inspect.getsource(func)
# Remove common leading whitespace to handle indented functions
dedented_source = textwrap.dedent(source)
tree = ast.parse(dedented_source)
analyzer = LoopAnalyzer()
return analyzer.analyze_loops(tree)
except (OSError, SyntaxError) as e:
raise ValueError(f"Cannot analyze function {func.__name__}: {e}")