import functools
import inspect
import pickle
import asyncio
from typing import Any, Callable, Optional, Dict, List
from concurrent.futures import ThreadPoolExecutor, as_completed
from .config import get_config
from .executor import ClusterExecutor
from .local_executor import create_local_executor
from .loop_analysis import detect_loops_in_function, find_parallelizable_loops
from .utils import detect_loops, setup_environment, serialize_function
# [docs]  (Sphinx "viewcode" link residue from HTML extraction; not part of the module)
def cluster(
    _func: Optional[Callable] = None,
    *,
    cores: Optional[int] = None,
    memory: Optional[str] = None,
    time: Optional[str] = None,
    partition: Optional[str] = None,
    queue: Optional[str] = None,
    parallel: Optional[bool] = None,
    environment: Optional[str] = None,
    **kwargs,
):
    """
    Decorator that dispatches a function to a cluster or runs it locally.

    Works both bare (``@cluster``) and with options (``@cluster(cores=8)``).
    The execution mode is decided on every call by
    ``_choose_execution_mode``; resource options left unset (or falsy) fall
    back to the corresponding attribute of the active configuration.

    Args:
        cores: Number of CPU cores to request
        memory: Memory to request (e.g., "8GB")
        time: Time limit (e.g., "01:00:00")
        partition: Cluster partition to use
        queue: Queue to submit to
        parallel: Whether to parallelize loops automatically; ``None``
            defers to ``config.auto_parallel``
        environment: Conda environment name
        **kwargs: Additional parameters; they are recorded on the wrapper's
            ``_cluster_config`` attribute only

    Returns:
        The wrapped function when used bare, otherwise a decorator.
    """
    # (job_config key, explicit value, config attribute used as fallback)
    resource_fallbacks = (
        ("cores", cores, "default_cores"),
        ("memory", memory, "default_memory"),
        ("time", time, "default_time"),
        ("partition", partition, "default_partition"),
        ("queue", queue, "default_queue"),
        ("environment", environment, "conda_env_name"),
    )

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **func_kwargs):
            config = get_config()

            # Explicit decorator options win; the config attribute is only
            # read when the explicit value is falsy.
            job_config = {
                key: explicit or getattr(config, fallback_attr)
                for key, explicit, fallback_attr in resource_fallbacks
            }

            execution_mode = _choose_execution_mode(config, func, args, func_kwargs)

            if parallel is None:
                should_parallelize = config.auto_parallel
            else:
                should_parallelize = parallel

            if execution_mode == "local":
                if not should_parallelize:
                    # Plain in-process call
                    return func(*args, **func_kwargs)
                return _execute_local_parallel(func, args, func_kwargs, job_config)

            # Remote execution
            executor = ClusterExecutor(config)
            loop_info = (
                detect_loops(func, args, func_kwargs) if should_parallelize else None
            )
            if loop_info:
                return _execute_parallel(
                    executor, func, args, func_kwargs, job_config, loop_info
                )
            # No parallelizable loop (or parallelism disabled): one job
            return _execute_single(executor, func, args, func_kwargs, job_config)

        # Expose the raw decorator options (not the resolved job_config) so
        # they can be inspected without executing the function.
        wrapper._cluster_config = {
            "cores": cores,
            "memory": memory,
            "time": time,
            "partition": partition,
            "queue": queue,
            "parallel": parallel,
            "environment": environment,
            **kwargs,
        }
        return wrapper

    # ``@cluster`` passes the function directly; ``@cluster(...)`` does not.
    return decorator if _func is None else decorator(_func)
def _execute_single(
    executor: ClusterExecutor,
    func: Callable,
    args: tuple,
    kwargs: dict,
    job_config: dict,
) -> Any:
    """
    Run ``func`` as a single cluster job and return its result.

    The function and its call arguments are serialized, submitted through
    ``executor`` with ``job_config``, and the call waits until the
    executor hands back the result.
    """
    payload = serialize_function(func, args, kwargs)
    return executor.wait_for_result(executor.submit_job(payload, job_config))
def _execute_parallel(
    executor: ClusterExecutor,
    func: Callable,
    args: tuple,
    kwargs: dict,
    job_config: dict,
    loop_info: Dict[str, Any],
) -> Any:
    """
    Fan ``func`` out over several cluster jobs, one per work chunk.

    Chunks come from ``_create_work_chunks`` (bounded by the configured
    ``max_parallel_jobs``). Every chunk is submitted before any result is
    awaited; results are then gathered in submission order and merged by
    ``_combine_results``.
    """
    max_jobs = get_config().max_parallel_jobs

    # Phase 1: submit everything up front.
    submitted = [
        (
            executor.submit_job(
                serialize_function(func, chunk["args"], chunk["kwargs"]),
                job_config,
            ),
            chunk,
        )
        for chunk in _create_work_chunks(func, args, kwargs, loop_info, max_jobs)
    ]

    # Phase 2: wait for each job and tag its result with the chunk index.
    indexed_results = []
    for job_id, chunk in submitted:
        outcome = executor.wait_for_result(job_id)
        indexed_results.append((chunk["index"], outcome))

    return _combine_results(indexed_results, loop_info)
def _create_work_chunks(
func: Callable, args: tuple, kwargs: dict, loop_info: Dict, max_jobs: int
) -> List[Dict]:
"""Create chunks of work for parallel execution."""
# This is a simplified implementation
# In practice, you'd need sophisticated analysis of the function
# to determine how to split loops and iterations
chunks = []
loop_var = loop_info.get("variable")
loop_range = loop_info.get("range", range(10)) # Default range
chunk_size = max(1, len(loop_range) // max_jobs)
for i in range(0, len(loop_range), chunk_size):
chunk_range = loop_range[i : i + chunk_size]
# Create modified kwargs for this chunk
chunk_kwargs = kwargs.copy()
chunk_kwargs[f"_chunk_range_{loop_var}"] = chunk_range
chunk_kwargs["_chunk_index"] = i // chunk_size
chunks.append(
{
"args": args,
"kwargs": chunk_kwargs,
"index": i // chunk_size,
"range": chunk_range,
}
)
return chunks
def _combine_results(results: List[tuple], loop_info: Dict) -> Any:
"""Combine results from parallel execution."""
# Sort by index
results.sort(key=lambda x: x[0])
# For now, just return the list of results
# In practice, you'd need to intelligently combine based on the original function
return [result[1] for result in results]
def _choose_execution_mode(config, func: Callable, args: tuple, kwargs: dict) -> str:
"""
Choose between local and remote execution.
Args:
config: Cluster configuration
func: Function to execute
args: Function arguments
kwargs: Function keyword arguments
Returns:
'local' or 'remote'
"""
# If no cluster is configured, use local execution
if not config.cluster_host:
return "local"
# Check if there's a preference for local parallel execution
if hasattr(config, "prefer_local_parallel") and config.prefer_local_parallel:
return "local"
# Default to remote execution when cluster is available
return "remote"
def _execute_local_parallel(
    func: Callable, args: tuple, kwargs: dict, job_config: dict
) -> Any:
    """
    Execute function locally with parallelization.

    The first loop reported by ``find_parallelizable_loops`` is split into
    work chunks that are run through a local executor. The function is run
    sequentially instead when no parallelizable loop is found, when no work
    chunks can be built, or when the parallel attempt raises.

    Args:
        func: Function to execute
        args: Function arguments
        kwargs: Function keyword arguments
        job_config: Job configuration; ``"cores"`` sets the worker count
            (4 when the key is missing)

    Returns:
        The combined parallel result, or the plain return value of
        ``func(*args, **kwargs)`` when falling back to sequential execution.

    Raises:
        Whatever ``func`` raises during sequential execution. That call is
        deliberately made outside the ``try`` block so a failing function is
        not mistaken for a parallel-execution failure and run a second time.
    """
    import logging

    parallelizable_loops = find_parallelizable_loops(func, args, kwargs)
    if not parallelizable_loops:
        # No parallelizable loops found, execute normally
        return func(*args, **kwargs)

    # Use the first parallelizable loop
    loop_info = parallelizable_loops[0]

    local_executor = create_local_executor(
        max_workers=job_config.get("cores", 4), func=func, args=args, kwargs=kwargs
    )

    try:
        with local_executor:
            work_chunks = _create_local_work_chunks(func, args, kwargs, loop_info)
            if work_chunks:
                results = local_executor.execute_parallel(func, work_chunks)
                return _combine_local_results(results, loop_info)
    except Exception as e:
        # Best-effort: any failure in the parallel path degrades to a
        # sequential run rather than failing the call.
        logging.getLogger(__name__).warning(
            "Local parallel execution failed, falling back to sequential: %s", e
        )

    # Reached when there were no chunks or the parallel attempt failed.
    return func(*args, **kwargs)
def _create_local_work_chunks(
func: Callable, args: tuple, kwargs: dict, loop_info
) -> List[Dict]:
"""
Create work chunks for local parallel execution.
Args:
func: Function to execute
args: Function arguments
kwargs: Function keyword arguments
loop_info: Information about the loop to parallelize
Returns:
List of work chunks
"""
chunks = []
# Get range information
if hasattr(loop_info, "range_info") and loop_info.range_info:
range_info = loop_info.range_info
start = range_info["start"]
stop = range_info["stop"]
step = range_info["step"]
# Create range object
loop_range = range(start, stop, step)
variable = loop_info.variable
elif hasattr(loop_info, "to_dict"):
# New loop info format
loop_dict = loop_info.to_dict()
range_info = loop_dict.get("range_info")
if range_info:
loop_range = range(
range_info["start"], range_info["stop"], range_info["step"]
)
variable = loop_dict["variable"]
else:
return [] # Can't parallelize without range info
else:
# Legacy format
loop_range = loop_info.get("range", range(10))
variable = loop_info.get("variable", "i")
if not variable or len(loop_range) == 0:
return []
# Determine chunk size (aim for reasonable number of chunks)
import os
max_chunks = os.cpu_count() * 2 # Allow some oversubscription
chunk_size = max(1, len(loop_range) // max_chunks)
# Create chunks
for i in range(0, len(loop_range), chunk_size):
chunk_range = list(loop_range[i : i + chunk_size])
# Create modified kwargs for this chunk
chunk_kwargs = kwargs.copy()
chunk_kwargs[f"_parallel_{variable}"] = chunk_range
chunks.append({"args": args, "kwargs": chunk_kwargs})
return chunks
def _combine_local_results(results: List[Any], loop_info) -> Any:
"""
Combine results from local parallel execution.
Args:
results: List of results from parallel execution
loop_info: Information about the parallelized loop
Returns:
Combined result
"""
# For now, flatten list results or return as-is
if not results:
return None
if len(results) == 1:
return results[0]
# If all results are lists, concatenate them
if all(isinstance(r, list) for r in results):
combined = []
for result in results:
combined.extend(result)
return combined
# Otherwise return the list of results
return results