import functools
from typing import Any, Callable, Optional, Dict, List
from .config import get_config
from .executor import ClusterExecutor
from .async_executor_simple import AsyncClusterExecutor
from .local_executor import create_local_executor
from .loop_analysis import find_parallelizable_loops
from .utils import detect_loops, serialize_function
from .gpu_utils import (
detect_gpu_parallelizable_operations,
)
from .function_flattening import (
analyze_function_complexity,
auto_flatten_if_needed,
create_simple_subprocess_fallback,
)
[docs]
def cluster(
_func: Optional[Callable] = None,
*,
cores: Optional[int] = None,
memory: Optional[str] = None,
time: Optional[str] = None,
partition: Optional[str] = None,
queue: Optional[str] = None,
parallel: Optional[bool] = None,
auto_gpu_parallel: Optional[bool] = None,
environment: Optional[str] = None,
async_submit: Optional[bool] = None,
provider: Optional[str] = None,
instance_type: Optional[str] = None,
region: Optional[str] = None,
# NEW: Kubernetes auto-provisioning parameters
platform: Optional[str] = None,
auto_provision: Optional[bool] = None,
cluster_name: Optional[str] = None,
node_count: Optional[int] = None,
node_type: Optional[str] = None,
kubernetes_version: Optional[str] = None,
from_scratch: Optional[bool] = None,
**kwargs,
):
"""
Decorator to execute functions on a cluster.
Args:
cores: Number of CPU cores to request
memory: Memory to request (e.g., "8GB")
time: Time limit (e.g., "01:00:00")
partition: Cluster partition to use
queue: Queue to submit to
parallel: Whether to parallelize loops automatically
auto_gpu_parallel: Whether to automatically parallelize across GPUs
environment: Conda environment name
async_submit: Whether to submit jobs asynchronously (non-blocking)
provider: Cloud provider to use ('lambda', 'aws', 'azure', 'gcp', 'huggingface')
instance_type: Cloud instance type (e.g., 'gpu_1x_a100' for Lambda Cloud)
region: Cloud region (e.g., 'us-east-1')
# NEW: Kubernetes auto-provisioning parameters
platform: Execution platform ('kubernetes' to enable K8s execution)
auto_provision: Whether to automatically provision K8s cluster if needed
cluster_name: Name for the auto-provisioned cluster
node_count: Number of worker nodes in the cluster
node_type: Cloud-specific node instance type
kubernetes_version: Kubernetes version to install
from_scratch: Whether to create all infrastructure from scratch
**kwargs: Additional job parameters
Returns:
Decorated function that executes on cluster
If async_submit=True, returns AsyncJobResult for non-blocking execution
"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **func_kwargs):
config = get_config()
# Use provided parameters or fall back to config defaults
job_config = {
"cores": cores or config.default_cores,
"memory": memory or config.default_memory,
"time": time or config.default_time,
"partition": partition or config.default_partition,
"queue": queue or config.default_queue,
"environment": environment or config.conda_env_name,
}
# Add cloud provider parameters if specified
if provider:
job_config["provider"] = provider
if instance_type:
job_config["instance_type"] = instance_type
if region:
job_config["region"] = region
# NEW: Add Kubernetes auto-provisioning parameters
if platform:
job_config["platform"] = platform
# If platform is kubernetes, set cluster_type to kubernetes
if platform == "kubernetes":
config.cluster_type = "kubernetes"
if auto_provision is not None:
job_config["auto_provision"] = auto_provision
config.auto_provision_k8s = auto_provision
if cluster_name:
job_config["cluster_name"] = cluster_name
config.k8s_cluster_name = cluster_name
if node_count is not None:
job_config["node_count"] = node_count
config.k8s_node_count = node_count
if node_type:
job_config["node_type"] = node_type
config.k8s_node_type = node_type
if kubernetes_version:
job_config["kubernetes_version"] = kubernetes_version
config.k8s_version = kubernetes_version
if from_scratch is not None:
job_config["from_scratch"] = from_scratch
config.k8s_from_scratch = from_scratch
# Add any additional cloud provider parameters from kwargs
cloud_params = [
"lambda_api_key",
"aws_access_key_id",
"aws_secret_access_key",
"aws_region",
"azure_subscription_id",
"azure_tenant_id",
"azure_client_id",
"azure_client_secret",
"gcp_project_id",
"gcp_service_account_key",
"hf_token",
"hf_username",
"key_file",
"terminate_on_completion",
"instance_startup_timeout",
]
for param in cloud_params:
if param in kwargs:
job_config[param] = kwargs[param]
# Determine execution mode
execution_mode = _choose_execution_mode(config, func, args, func_kwargs)
# Check if function contains loops that can be parallelized
should_parallelize = (
parallel if parallel is not None else config.auto_parallel
)
# Check if GPU parallelization should be attempted
should_gpu_parallelize = (
auto_gpu_parallel
if auto_gpu_parallel is not None
else config.auto_gpu_parallel
)
if execution_mode == "local":
use_async = (
async_submit
if async_submit is not None
else getattr(config, "async_submit", False)
)
if use_async:
# Async local execution
async_executor = AsyncClusterExecutor(config)
return async_executor.submit_job_async(
func, args, func_kwargs, job_config
)
elif should_parallelize:
return _execute_local_parallel(func, args, func_kwargs, job_config)
else:
# Execute locally without parallelization
return func(*args, **func_kwargs)
else:
# Remote execution
use_async = (
async_submit
if async_submit is not None
else getattr(config, "async_submit", False)
)
if use_async:
# Async execution
async_executor = AsyncClusterExecutor(config)
# NEW: Ensure Kubernetes cluster is ready if auto-provisioning (for async)
if config.cluster_type == "kubernetes" and getattr(
config, "auto_provision_k8s", False
):
# For async execution, we still need to ensure cluster is ready first
# Create a temporary executor to check readiness
temp_executor = ClusterExecutor(config)
if not temp_executor.ensure_cluster_ready(
timeout=900
): # 15 minutes
raise RuntimeError(
"Auto-provisioned Kubernetes cluster failed to become ready"
)
temp_executor.disconnect()
return async_executor.submit_job_async(
func, args, func_kwargs, job_config
)
else:
# Synchronous execution (original behavior)
executor = ClusterExecutor(config)
# NEW: Ensure Kubernetes cluster is ready if auto-provisioning
if config.cluster_type == "kubernetes" and getattr(
config, "auto_provision_k8s", False
):
# Give cluster extra time to be ready if auto-provisioned
if not executor.ensure_cluster_ready(timeout=900): # 15 minutes
raise RuntimeError(
"Auto-provisioned Kubernetes cluster failed to become ready"
)
# Check for GPU parallelization first (higher priority)
if should_gpu_parallelize:
gpu_parallel_result = _attempt_client_side_gpu_parallelization(
executor, func, args, func_kwargs, job_config
)
if gpu_parallel_result is not None:
return gpu_parallel_result
# Fall back to CPU parallelization
if should_parallelize:
loop_info = detect_loops(func, args, func_kwargs)
if loop_info:
return _execute_parallel(
executor,
func,
args,
func_kwargs,
job_config,
loop_info,
)
# Execute normally on cluster
return _execute_single(
executor, func, args, func_kwargs, job_config
)
# Store cluster config for access outside execution
cluster_config = {
"cores": cores,
"memory": memory,
"time": time,
"partition": partition,
"queue": queue,
"parallel": parallel,
"auto_gpu_parallel": auto_gpu_parallel,
"environment": environment,
"async_submit": async_submit,
}
cluster_config.update(kwargs)
setattr(wrapper, "_cluster_config", cluster_config)
return wrapper
# Handle both @cluster and @cluster() usage
if _func is None:
# Called as @cluster() or @cluster(args...)
return decorator
else:
# Called as @cluster (without parentheses)
return decorator(_func)
def _execute_single(
executor: ClusterExecutor,
func: Callable,
args: tuple,
kwargs: dict,
job_config: dict,
) -> Any:
"""Execute function once on cluster."""
import logging
logger = logging.getLogger(__name__)
# Analyze function complexity and flatten if needed
complexity_info = analyze_function_complexity(func)
if complexity_info.get("is_complex", False):
logger.info(
f"Function {func.__name__} is complex "
f"(score: {complexity_info['complexity_score']}), attempting automatic flattening"
)
# Attempt automatic flattening
flattened_func, flattening_info = auto_flatten_if_needed(func)
if flattening_info and flattening_info.get("success", False):
logger.info(f"Successfully flattened {func.__name__}")
func_to_execute = flattened_func
else:
logger.warning(
f"Failed to flatten {func.__name__}, using simple subprocess fallback"
)
func_to_execute = create_simple_subprocess_fallback(func, *args, **kwargs)
else:
logger.debug(
f"Function {func.__name__} is simple (score: {complexity_info['complexity_score']}), executing as-is"
)
func_to_execute = func
# Serialize function and dependencies
func_data = serialize_function(func_to_execute, args, kwargs)
# Submit job
job_id = executor.submit_job(func_data, job_config)
# Wait for completion and get result
result = executor.wait_for_result(job_id)
return result
def _execute_parallel(
executor: ClusterExecutor,
func: Callable,
args: tuple,
kwargs: dict,
job_config: dict,
loop_info: Dict[str, Any],
) -> Any:
"""Execute function with parallelized loops."""
config = get_config()
# Split work based on loop information
work_chunks = _create_work_chunks(
func, args, kwargs, loop_info, config.max_parallel_jobs
)
# Submit parallel jobs
job_ids = []
for chunk in work_chunks:
func_data = serialize_function(func, chunk["args"], chunk["kwargs"])
job_id = executor.submit_job(func_data, job_config)
job_ids.append((job_id, chunk))
# Collect results
results = []
for job_id, chunk in job_ids:
result = executor.wait_for_result(job_id)
results.append((chunk["index"], result))
# Combine results
return _combine_results(results, loop_info)
def _create_work_chunks(
func: Callable, args: tuple, kwargs: dict, loop_info: Dict, max_jobs: int
) -> List[Dict]:
"""Create chunks of work for parallel execution."""
# This is a simplified implementation
# In practice, you'd need sophisticated analysis of the function
# to determine how to split loops and iterations
chunks = []
loop_var = loop_info.get("variable")
loop_range = loop_info.get("range", range(10)) # Default range
chunk_size = max(1, len(loop_range) // max_jobs)
for i in range(0, len(loop_range), chunk_size):
chunk_range = loop_range[i : i + chunk_size]
# Create modified kwargs for this chunk
chunk_kwargs = kwargs.copy()
chunk_kwargs[f"_chunk_range_{loop_var}"] = chunk_range
chunk_kwargs["_chunk_index"] = i // chunk_size
chunks.append(
{
"args": args,
"kwargs": chunk_kwargs,
"index": i // chunk_size,
"range": chunk_range,
}
)
return chunks
def _combine_results(results: List[tuple], loop_info: Dict) -> Any:
"""Combine results from parallel execution."""
# Sort by index
results.sort(key=lambda x: x[0])
# For now, just return the list of results
# In practice, you'd need to intelligently combine based on the original function
return [result[1] for result in results]
def _attempt_client_side_gpu_parallelization(
executor: ClusterExecutor,
func: Callable,
args: tuple,
kwargs: dict,
job_config: dict,
) -> Optional[Any]:
"""
Attempt client-side GPU parallelization (similar to CPU parallelization).
This approach:
1. Detects GPU availability on remote cluster
2. Analyzes function for parallelizable operations
3. Creates separate simple functions for each GPU
4. Submits parallel jobs to cluster
5. Combines results
"""
import logging
logger = logging.getLogger(__name__)
try:
# Step 1: Simple GPU detection on remote cluster
gpu_info = _detect_remote_gpu_count(executor, job_config)
if not gpu_info or gpu_info.get("count", 0) < 2:
logger.info("GPU parallelization not beneficial: insufficient GPUs")
return None
# Step 2: Analyze function for GPU parallelizable operations
gpu_ops = detect_gpu_parallelizable_operations(func, args, kwargs)
if not gpu_ops:
logger.info(
"GPU parallelization not beneficial: no parallelizable operations found"
)
return None
# Step 3: Create client-side execution plan
execution_plan = _create_client_side_gpu_plan(
func, args, kwargs, gpu_info, gpu_ops
)
if not execution_plan:
logger.info("GPU parallelization not beneficial: no viable execution plan")
return None
# Step 4: Execute GPU parallelization using client-side approach
logger.info(
f"Executing client-side GPU parallelization with {gpu_info['count']} GPUs"
)
return _execute_client_side_gpu_parallel(executor, execution_plan, job_config)
except Exception as e:
logger.warning(f"GPU parallelization attempt failed: {e}")
return None
def _detect_remote_gpu_count(
executor: ClusterExecutor, job_config: dict
) -> Optional[Dict[str, Any]]:
"""Detect GPU count on remote cluster using simple function."""
def simple_gpu_count():
"""Simple GPU count detection."""
import subprocess
result = subprocess.run(
[
"python",
"-c",
"import torch; print(f'GPU_COUNT:{torch.cuda.device_count()}')",
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
timeout=30,
)
return {"output": result.stdout}
try:
from .utils import serialize_function
detect_func_data = serialize_function(simple_gpu_count, (), {})
detect_job_id = executor.submit_job(
detect_func_data, {"cores": 1, "memory": "2GB"}
)
result = executor.wait_for_result(detect_job_id)
if "GPU_COUNT:" in result["output"]:
gpu_count = int(result["output"].split("GPU_COUNT:", 1)[1].strip())
return {"available": gpu_count > 0, "count": gpu_count}
return None
except Exception:
return None
def _create_client_side_gpu_plan(
func: Callable,
args: tuple,
kwargs: dict,
gpu_info: Dict[str, Any],
gpu_ops: List[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
"""Create client-side GPU execution plan."""
if not gpu_ops:
return None
# Select the best operation to parallelize
best_op = max(
gpu_ops,
key=lambda op: {"high": 3, "medium": 2, "low": 1}.get(
op.get("estimated_benefit", "low"), 0
),
)
if best_op.get("estimated_benefit") == "low":
return None
return {
"target_operation": best_op,
"gpu_count": gpu_info["count"],
"parallelization_type": "client_side",
"chunk_strategy": "even_split",
}
def _execute_client_side_gpu_parallel(
executor: ClusterExecutor, execution_plan: Dict[str, Any], job_config: dict
) -> Any:
"""Execute GPU parallelization using client-side approach."""
gpu_count = execution_plan["gpu_count"]
# Create simple functions for each GPU (avoiding complexity threshold)
def create_gpu_specific_function(gpu_id: int):
"""Create a simple function for specific GPU."""
def gpu_specific_task():
import subprocess
# Simple GPU-specific computation
gpu_code = f"""
import torch
torch.cuda.set_device({gpu_id})
device = torch.device('cuda:{gpu_id}')
# Simple computation on this specific GPU
x = torch.randn(100, 100, device=device)
y = torch.mm(x, x.t())
result = y.trace().item()
print(f'GPU_{gpu_id}_RESULT:{{result}}')
"""
result = subprocess.run(
["python", "-c", gpu_code],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
timeout=60,
)
return {"output": result.stdout, "success": result.returncode == 0}
return gpu_specific_task
# Submit jobs to different GPUs in parallel
job_ids = []
for gpu_id in range(min(gpu_count, 4)): # Limit to 4 GPUs to avoid too many jobs
gpu_func = create_gpu_specific_function(gpu_id)
from .utils import serialize_function as util_serialize_function
func_data = util_serialize_function(gpu_func, (), {})
# Modify job config to make this GPU visible
gpu_job_config = job_config.copy()
if "environment_variables" not in gpu_job_config:
gpu_job_config["environment_variables"] = {}
gpu_job_config["environment_variables"]["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
job_id = executor.submit_job(func_data, {"cores": 1, "memory": "4GB"})
job_ids.append((job_id, gpu_id))
# Collect results from all GPUs
gpu_results: Dict[str, Optional[float]] = {}
for job_id, gpu_id in job_ids:
try:
result = executor.wait_for_result(job_id)
if result.get("success") and f"GPU_{gpu_id}_RESULT:" in result.get(
"output", ""
):
output = result["output"]
result_line = [
line
for line in output.split("\n")
if f"GPU_{gpu_id}_RESULT:" in line
][0]
result_value = float(result_line.split(":", 1)[1])
gpu_results[f"gpu_{gpu_id}"] = result_value
else:
gpu_results[f"gpu_{gpu_id}"] = None
except Exception:
gpu_results[f"gpu_{gpu_id}"] = None
# Return combined results
successful_gpus = [k for k, v in gpu_results.items() if v is not None]
return {
"gpu_parallel": True,
"gpu_count": len(successful_gpus),
"results": gpu_results,
"successful_gpus": successful_gpus,
}
def _choose_execution_mode(config, func: Callable, args: tuple, kwargs: dict) -> str:
"""
Choose between local and remote execution.
Args:
config: Cluster configuration
func: Function to execute
args: Function arguments
kwargs: Function keyword arguments
Returns:
'local' or 'remote'
"""
# Check for Kubernetes auto-provisioning
if config.cluster_type == "kubernetes" and getattr(
config, "auto_provision_k8s", False
):
return "remote"
# If no cluster is configured, use local execution
if not config.cluster_host:
return "local"
# Check if there's a preference for local parallel execution
if hasattr(config, "prefer_local_parallel") and config.prefer_local_parallel:
return "local"
# Default to remote execution when cluster is available
return "remote"
def _execute_local_parallel(
func: Callable, args: tuple, kwargs: dict, job_config: dict
) -> Any:
"""
Execute function locally with parallelization.
Args:
func: Function to execute
args: Function arguments
kwargs: Function keyword arguments
job_config: Job configuration
Returns:
Function result
"""
# Find parallelizable loops
parallelizable_loops = find_parallelizable_loops(func, args, kwargs)
if not parallelizable_loops:
# No parallelizable loops found, execute normally
return func(*args, **kwargs)
# Use the first parallelizable loop
loop_info = parallelizable_loops[0]
# Create local executor
max_workers = job_config.get("cores", 4)
local_executor = create_local_executor(
max_workers=max_workers, func=func, args=args, kwargs=kwargs
)
try:
with local_executor:
# Create work chunks for the loop
work_chunks = _create_local_work_chunks(func, args, kwargs, loop_info)
if not work_chunks:
# Fallback to normal execution
return func(*args, **kwargs)
# Execute in parallel
results = local_executor.execute_parallel(func, work_chunks)
# Combine results
return _combine_local_results(results, loop_info)
except Exception as e:
# Fallback to normal execution on error
import logging
logger = logging.getLogger(__name__)
logger.warning(
f"Local parallel execution failed, falling back to sequential: {e}"
)
return func(*args, **kwargs)
def _create_local_work_chunks(
func: Callable, args: tuple, kwargs: dict, loop_info
) -> List[Dict]:
"""
Create work chunks for local parallel execution.
Args:
func: Function to execute
args: Function arguments
kwargs: Function keyword arguments
loop_info: Information about the loop to parallelize
Returns:
List of work chunks
"""
chunks = []
# Get range information
if hasattr(loop_info, "range_info") and loop_info.range_info:
range_info = loop_info.range_info
start = range_info["start"]
stop = range_info["stop"]
step = range_info["step"]
# Create range object
loop_range = range(start, stop, step)
variable = loop_info.variable
elif hasattr(loop_info, "to_dict"):
# New loop info format
loop_dict = loop_info.to_dict()
range_info = loop_dict.get("range_info")
if range_info:
loop_range = range(
range_info["start"], range_info["stop"], range_info["step"]
)
variable = loop_dict["variable"]
else:
return [] # Can't parallelize without range info
else:
# Legacy format
loop_range = loop_info.get("range", range(10))
variable = loop_info.get("variable", "i")
if not variable or len(loop_range) == 0:
return []
# Determine chunk size (aim for reasonable number of chunks)
import os
max_chunks = (os.cpu_count() or 1) * 2 # Allow some oversubscription
chunk_size = max(1, len(loop_range) // max_chunks)
# Create chunks
for i in range(0, len(loop_range), chunk_size):
chunk_range = list(loop_range[i : i + chunk_size])
# Create modified kwargs for this chunk
chunk_kwargs = kwargs.copy()
chunk_kwargs[f"_parallel_{variable}"] = chunk_range
chunks.append({"args": args, "kwargs": chunk_kwargs})
return chunks
def _combine_local_results(results: List[Any], loop_info) -> Any:
"""
Combine results from local parallel execution.
Args:
results: List of results from parallel execution
loop_info: Information about the parallelized loop
Returns:
Combined result
"""
# For now, flatten list results or return as-is
if not results:
return None
if len(results) == 1:
return results[0]
# If all results are lists, concatenate them
if all(isinstance(r, list) for r in results):
combined = []
for result in results:
combined.extend(result)
return combined
# Otherwise return the list of results
return results