Source code for absurd_client

"""Python client for the Absurd SQL-based durable execution workflow system.

This module provides a client interface to interact with the Absurd workflow engine,
which is built on PostgreSQL. It allows you to spawn tasks, claim and process them,
handle events, manage checkpoints, and track workflow runs.
"""

import json
import logging
import os
import re
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any

import psycopg
from psycopg import sql


# Module-level logger named after this module; every client method logs through it.
logger: logging.Logger = logging.getLogger(__name__)


[docs] class AbsurdSleepError(Exception): """Raised when a run enters SLEEPING state to wait for an event. This signals to the orchestrator that the worker thread should be freed to process other tasks while this run waits. """ def __init__( self, message: str, run_id: uuid.UUID | None = None, event_name: str | None = None, ) -> None: super().__init__(message) self.run_id = run_id self.event_name = event_name
class AbsurdClient:
    """Client for Absurd's SQL functions and its per-queue tables.

    Queue-scoped tables live in the ``absurd`` schema and are named
    ``<prefix>_<queue_name>``: ``t_`` (tasks), ``r_`` (runs), ``c_``
    (checkpoints), ``e_`` (events) and ``w_`` (wait registrations).

    No method commits or rolls back; transaction boundaries belong to the
    caller's connection.
    """

    # Key order for rows of ``absurd.t_<queue>`` read with ``SELECT *``.
    # NOTE(review): this positional mapping mirrors the table's physical column
    # order (tenant_id was appended last); re-check it whenever the schema changes.
    _TASK_COLUMNS: tuple[str, ...] = (
        "task_id",
        "task_name",
        "params",
        "headers",
        "retry_strategy",
        "max_attempts",
        "cancellation",
        "enqueue_at",
        "first_started_at",
        "state",
        "attempts",
        "last_attempt_run",
        "completed_payload",
        "cancelled_at",
        "tenant_id",
    )

    # Key order for rows of ``absurd.r_<queue>`` read with ``SELECT *``
    # (same caveat as _TASK_COLUMNS).
    _RUN_COLUMNS: tuple[str, ...] = (
        "run_id",
        "task_id",
        "attempt",
        "state",
        "claimed_by",
        "claim_expires_at",
        "available_at",
        "wake_event",
        "event_payload",
        "started_at",
        "completed_at",
        "failed_at",
        "result",
        "failure_reason",
        "created_at",
        "last_heartbeat",
        "tenant_id",
    )

    def __init__(
        self,
        queue_name: str | None = None,
        worker_id: str | None = None,
    ):
        """Create a client bound to a single queue.

        Args:
            queue_name: Queue to operate on; defaults to ``$ABSURD_DEFAULT_QUEUE``
                or ``"absurd_default"``.
            worker_id: Identifier used when claiming tasks; defaults to
                ``$ABSURD_WORKER_ID`` or ``"absurd_worker_1"``.

        Raises:
            ValueError: If the queue name is empty or is not a plain SQL identifier.
        """
        queue_name = queue_name or os.getenv("ABSURD_DEFAULT_QUEUE", "absurd_default")
        # CRITICAL SECURITY: queue_name becomes part of table names
        # (e.g. absurd.t_{queue_name}), so only plain identifiers are accepted.
        if not queue_name or not self._is_valid_identifier(queue_name):
            msg = f"Invalid queue_name '{queue_name}'. Must contain only letters, numbers, and underscores."
            raise ValueError(
                msg,
            )
        self.queue_name = queue_name
        self.worker_id = worker_id or os.getenv("ABSURD_WORKER_ID", "absurd_worker_1")

    @staticmethod
    def _is_valid_identifier(name: str) -> bool:
        """Return True if *name* is a safe SQL identifier.

        A safe identifier is a letter or underscore followed by letters, digits
        or underscores. ``re.fullmatch`` is required here: with ``re.match``
        and a ``$`` anchor, a trailing newline would also be accepted, because
        ``$`` matches just before a final newline.

        Args:
            name: String to validate.

        Returns:
            True if safe to use as an SQL identifier, False otherwise.
        """
        return re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", name) is not None

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _table(self, prefix: str) -> sql.Identifier:
        """Return the quoted, schema-qualified ``"absurd"."<prefix>_<queue_name>"``."""
        return sql.Identifier("absurd", f"{prefix}_{self.queue_name}")

    def _set_checkpoint_state(
        self,
        conn: psycopg.Connection,
        task_id: uuid.UUID,
        step_name: str,
        state: Any,
        owner_run: uuid.UUID,
        extend_claim_by: int | None,
    ) -> None:
        """Call ``absurd.set_task_checkpoint_state`` with *state* JSON-encoded."""
        with conn.cursor() as cur:
            cur.execute(
                "SELECT absurd.set_task_checkpoint_state(%(queue_name)s, %(task_id)s, %(step_name)s, %(state)s, %(owner_run)s, %(extend_claim_by)s)",
                {
                    "queue_name": self.queue_name,
                    "task_id": task_id,
                    "step_name": step_name,
                    "state": json.dumps(state),
                    "owner_run": owner_run,
                    "extend_claim_by": extend_claim_by,
                },
            )

    def _resolve_task_id(
        self,
        conn: psycopg.Connection,
        run_id: uuid.UUID,
        task_id: uuid.UUID | None,
    ) -> uuid.UUID:
        """Return *task_id*, or look up the task that owns *run_id* when it is None.

        If the lookup finds nothing, *run_id* is returned unchanged. This
        preserves the legacy behavior of callers that never passed a task id.
        """
        if task_id is not None:
            return task_id
        found = self._get_current_task_id(conn, run_id)
        if found is None:
            logger.warning(
                f"No task_id found for run {run_id}; falling back to run_id as task_id",
            )
            return run_id
        return found

    @staticmethod
    def _checkpoint_row_to_dict(row: Any) -> dict[str, Any]:
        """Map a checkpoint-state row to a dict.

        The row columns are (name, state, status, owner_run_id, updated_at).
        """
        return {
            "checkpoint_name": row[0],
            "state": row[1],
            "status": row[2],
            "owner_run_id": row[3],
            "updated_at": row[4],
        }

    def _run_cleanup(
        self,
        conn: psycopg.Connection,
        query: str,
        noun: str,
        ttl_seconds: int,
        limit: int,
    ) -> int:
        """Execute a cleanup function *query* and return the count it reports.

        Returns 0 when the function returns no row.
        """
        logger.info(
            f"Cleaning up {noun} older than {ttl_seconds} seconds (limit: {limit})",
        )
        with conn.cursor() as cur:
            cur.execute(
                query,
                {"queue_name": self.queue_name, "ttl_seconds": ttl_seconds, "limit": limit},
            )
            row = cur.fetchone()
        cleaned_count = row[0] if row else 0
        logger.info(f"Cleaned up {cleaned_count} {noun}")
        return cleaned_count

    # ------------------------------------------------------------------
    # Queue / task lifecycle
    # ------------------------------------------------------------------

    def create_queue(self, conn: psycopg.Connection) -> None:
        """Create the Absurd queue if it doesn't exist (via ``absurd.create_queue``)."""
        with conn.cursor() as cur:
            cur.execute(
                "SELECT absurd.create_queue(%(queue_name)s)",
                {"queue_name": self.queue_name},
            )

    def spawn_task(
        self,
        conn: psycopg.Connection,
        task_name: str,
        params: dict[str, Any],
        options: dict[str, Any] | None = None,
        headers: dict[str, Any] | None = None,
        retry_strategy: dict[str, Any] | None = None,
        max_attempts: int | None = None,
        cancellation: dict[str, Any] | None = None,
        workflow_run_id: uuid.UUID | None = None,
    ) -> tuple[uuid.UUID, uuid.UUID, uuid.UUID]:
        """Spawn a new task in the Absurd queue.

        Args:
            conn: Database connection (psycopg3).
            task_name: Name of the task.
            params: Task parameters (JSON-encoded before sending).
            options: Base options dict (legacy). It is copied, never mutated.
            headers: Task headers for metadata.
            retry_strategy: Retry strategy configuration.
            max_attempts: Maximum attempts; defaults to 1 (no retry).
            cancellation: Cancellation rules (max_delay, max_duration).
            workflow_run_id: Optional workflow run ID, stored in the headers.

        Returns:
            ``(task_id, run_id, workflow_run_id)``. ``workflow_run_id`` falls
            back to ``run_id`` when none was provided.

        Raises:
            RuntimeError: If ``absurd.spawn_task`` returns no row.
        """
        self.create_queue(conn)

        # Copy so the caller's ``options``/``headers`` dicts are never mutated.
        task_options: dict[str, Any] = dict(options) if options else {}
        task_headers: dict[str, Any] = dict(headers) if headers else {}
        if workflow_run_id:
            task_headers["workflow_run_id"] = str(workflow_run_id)
        if task_headers:
            task_options["headers"] = task_headers
        if retry_strategy:
            task_options["retry_strategy"] = retry_strategy
        # CRITICAL: always set max_attempts. Defaulting to 1 attempt (0 retries)
        # keeps non-idempotent tasks from running twice unless retries are requested.
        task_options["max_attempts"] = max_attempts if max_attempts is not None else 1
        if cancellation:
            task_options["cancellation"] = cancellation

        logger.info(f"Spawning task '{task_name}' with options: {task_options}")

        with conn.cursor() as cur:
            cur.execute(
                "SELECT * FROM absurd.spawn_task(%(queue_name)s, %(task_name)s, %(params)s, %(options)s)",
                {
                    "queue_name": self.queue_name,
                    "task_name": task_name,
                    "params": json.dumps(params),
                    "options": json.dumps(task_options),
                },
            )
            result = cur.fetchone()

        if result is None:
            msg = "Failed to spawn task"
            raise RuntimeError(msg)

        task_id, run_id = result[0], result[1]
        return task_id, run_id, workflow_run_id or run_id

    def claim_task(
        self,
        conn: psycopg.Connection,
        worker_id: str | None = None,
        claim_timeout: int = 30,
        qty: int = 1,
    ) -> list[tuple[Any, ...]]:
        """Claim up to *qty* tasks from the queue via ``absurd.claim_task``.

        Args:
            conn: Database connection.
            worker_id: Worker identifier (defaults to the instance worker_id).
            claim_timeout: Claim timeout in seconds (0 for no timeout).
            qty: Number of tasks to claim in one batch.

        Returns:
            One tuple per claimed row, with the columns returned by
            ``absurd.claim_task``. These are (run_id, task_id, attempt,
            task_name, params, retry_strategy, max_attempts, headers,
            wake_event, event_payload).
        """
        worker_id = worker_id or self.worker_id
        logger.info(
            f"Claiming {qty} task(s) from queue '{self.queue_name}' as worker '{worker_id}' with timeout {claim_timeout}s",
        )
        with conn.cursor() as cur:
            cur.execute(
                "SELECT * FROM absurd.claim_task(%(queue_name)s, %(worker_id)s, %(claim_timeout)s, %(qty)s)",
                {
                    "queue_name": self.queue_name,
                    "worker_id": worker_id,
                    "claim_timeout": claim_timeout,
                    "qty": qty,
                },
            )
            result = cur.fetchall()
        logger.info(f"Claimed {len(result)} task(s)")
        return [tuple(row) for row in result]

    def complete_task(
        self,
        conn: psycopg.Connection,
        run_id: uuid.UUID,
        result: dict[str, Any] | None = None,
    ) -> None:
        """Mark a run as completed via ``absurd.complete_run``.

        Any database error (e.g. a state-validation failure) is logged and re-raised.
        """
        logger.info(f"Completing task run {run_id}")
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT absurd.complete_run(%(queue_name)s, %(run_id)s, %(result)s)",
                    {
                        "queue_name": self.queue_name,
                        "run_id": run_id,
                        "result": json.dumps(result or {}),
                    },
                )
            logger.info(f"Successfully completed task run {run_id}")
        except Exception as e:
            logger.error(f"Failed to complete task run {run_id}: {e}")
            raise

    def fail_task(
        self,
        conn: psycopg.Connection,
        run_id: uuid.UUID,
        reason: str | dict[str, Any],
        retry_at: datetime | None = None,
    ) -> None:
        """Mark a run as failed via ``absurd.fail_run``.

        Args:
            conn: Database connection.
            run_id: Task run ID.
            reason: Failure reason. A plain string is wrapped as
                ``{"name": "TaskExecutionError", "message": ..., "timestamp": ...}``.
            retry_at: Optional retry timestamp, sent as an ISO-8601 string.
        """
        if isinstance(reason, str):
            failure_reason = {
                "name": "TaskExecutionError",
                "message": reason,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            }
        else:
            failure_reason = reason

        logger.warning(f"Failing task run {run_id}: {failure_reason}")

        retry_at_str = retry_at.isoformat() if retry_at else None

        with conn.cursor() as cur:
            cur.execute(
                "SELECT absurd.fail_run(%(queue_name)s, %(run_id)s, %(reason)s, %(retry_at)s)",
                {
                    "queue_name": self.queue_name,
                    "run_id": run_id,
                    "reason": json.dumps(failure_reason),
                    "retry_at": retry_at_str,
                },
            )

    def cancel_task(
        self,
        conn: psycopg.Connection,
        run_id: uuid.UUID,
    ) -> bool:
        """Cancel a task via ``absurd.cancel_task``.

        Args:
            conn: Database connection.
            run_id: Task run ID to cancel.

        Returns:
            The boolean returned by ``absurd.cancel_task``, or False if it
            returned no row.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        logger.info(f"Cancelling task run {run_id}")
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT absurd.cancel_task(%(queue_name)s, %(run_id)s)",
                    {
                        "queue_name": self.queue_name,
                        "run_id": run_id,
                    },
                )
                result = cur.fetchone()
            cancelled = result[0] if result else False
            if cancelled:
                logger.info(f"Successfully cancelled task run {run_id}")
            else:
                logger.warning(f"Could not cancel task run {run_id} (may already be running/completed)")
            return cancelled
        except Exception as e:
            logger.error(f"Failed to cancel task run {run_id}: {e}")
            raise

    def extend_claim(
        self,
        conn: psycopg.Connection,
        run_id: uuid.UUID,
        extend_by_seconds: int,
        task_id: uuid.UUID | None = None,
    ) -> None:
        """Extend the claim timeout of a long-running run.

        The extension is done by writing a ``claim_extension`` checkpoint with
        ``extend_claim_by`` set.

        Args:
            conn: Database connection.
            run_id: Run whose claim is extended (also the checkpoint owner).
            extend_by_seconds: Seconds to extend the claim by.
            task_id: Task that owns the run. When omitted it is looked up from
                the runs table; run_id is used only if that lookup fails.
        """
        logger.info(f"Extending claim for run {run_id} by {extend_by_seconds} seconds")
        # set_task_checkpoint_state takes the *task* id. spawn_task shows task and
        # run ids are distinct values, so passing run_id here targets the wrong task.
        resolved_task_id = self._resolve_task_id(conn, run_id, task_id)
        self._set_checkpoint_state(
            conn,
            resolved_task_id,
            "claim_extension",
            {
                "extended_at": datetime.now(timezone.utc).isoformat(),
                "extend_by": extend_by_seconds,
            },
            run_id,
            extend_by_seconds,
        )

    # ------------------------------------------------------------------
    # Checkpoints
    # ------------------------------------------------------------------

    def set_checkpoint(
        self,
        conn: psycopg.Connection,
        task_id: uuid.UUID,
        step_name: str,
        state: dict[str, Any],
        owner_run: uuid.UUID,
        extend_claim_by: int | None = None,
    ) -> None:
        """Set a checkpoint for a task, optionally extending the owner run's claim.

        Args:
            conn: Database connection.
            task_id: Task ID.
            step_name: Checkpoint step name.
            state: Checkpoint state data (JSON-encoded).
            owner_run: Run ID that owns this checkpoint.
            extend_claim_by: Optional claim extension in seconds.
        """
        logger.info(f"Setting checkpoint '{step_name}' for task {task_id}")
        self._set_checkpoint_state(conn, task_id, step_name, state, owner_run, extend_claim_by)

    def get_checkpoint(
        self,
        conn: psycopg.Connection,
        task_id: uuid.UUID,
        step_name: str,
        include_pending: bool = False,
    ) -> dict[str, Any] | None:
        """Return one checkpoint of a task as a dict, or None if there is none."""
        with conn.cursor() as cur:
            cur.execute(
                "SELECT * FROM absurd.get_task_checkpoint_state(%(queue_name)s, %(task_id)s, %(step_name)s, %(include_pending)s)",
                {
                    "queue_name": self.queue_name,
                    "task_id": task_id,
                    "step_name": step_name,
                    "include_pending": include_pending,
                },
            )
            result = cur.fetchone()
        return self._checkpoint_row_to_dict(result) if result else None

    def get_all_checkpoints(
        self,
        conn: psycopg.Connection,
        task_id: uuid.UUID,
        run_id: uuid.UUID,
    ) -> list[dict[str, Any]]:
        """Return every checkpoint of a task as a list of dicts."""
        with conn.cursor() as cur:
            cur.execute(
                "SELECT * FROM absurd.get_task_checkpoint_states(%(queue_name)s, %(task_id)s, %(run_id)s)",
                {"queue_name": self.queue_name, "task_id": task_id, "run_id": run_id},
            )
            results = cur.fetchall()
        return [self._checkpoint_row_to_dict(row) for row in results]

    def await_event(
        self,
        conn: psycopg.Connection,
        task_id: uuid.UUID,
        run_id: uuid.UUID,
        step_name: str,
        event_name: str,
        timeout: int | None = None,
    ) -> tuple[bool, dict[str, Any] | None]:
        """Wait for an event via ``absurd.await_event``.

        Args:
            conn: Database connection.
            task_id: Task ID.
            run_id: Run ID.
            step_name: Step name.
            event_name: Event name to wait for.
            timeout: Timeout in seconds (None for no timeout).

        Returns:
            ``(should_suspend, payload)``. If the function returns no row, this
            is ``(True, None)``.
        """
        logger.info(
            f"Task {task_id} awaiting event '{event_name}' with timeout {timeout}s",
        )
        with conn.cursor() as cur:
            cur.execute(
                "SELECT * FROM absurd.await_event(%(queue_name)s, %(task_id)s, %(run_id)s, %(step_name)s, %(event_name)s, %(timeout)s)",
                {
                    "queue_name": self.queue_name,
                    "task_id": task_id,
                    "run_id": run_id,
                    "step_name": step_name,
                    "event_name": event_name,
                    "timeout": timeout,
                },
            )
            result = cur.fetchone()

        should_suspend = result[0] if result else True
        payload = result[1] if result and len(result) > 1 else None

        if should_suspend:
            logger.info(f"Task {task_id} suspended waiting for event '{event_name}'")
        else:
            logger.info(
                f"Task {task_id} received event '{event_name}' with payload: {payload}",
            )

        return should_suspend, payload

    def schedule_task(self, conn: psycopg.Connection, run_id: uuid.UUID, wake_at: datetime) -> None:
        """Schedule a run to wake at *wake_at* via ``absurd.schedule_run``."""
        logger.info(f"Scheduling task run {run_id} to wake at {wake_at}")
        with conn.cursor() as cur:
            cur.execute(
                "SELECT absurd.schedule_run(%(queue_name)s, %(run_id)s, %(wake_at)s)",
                {
                    "queue_name": self.queue_name,
                    "run_id": run_id,
                    "wake_at": wake_at.isoformat(),
                },
            )

    def cleanup_tasks(self, conn: psycopg.Connection, ttl_seconds: int, limit: int = 1000) -> int:
        """Delete old tasks via ``absurd.cleanup_tasks`` and return how many were removed."""
        return self._run_cleanup(
            conn,
            "SELECT absurd.cleanup_tasks(%(queue_name)s, %(ttl_seconds)s, %(limit)s)",
            "tasks",
            ttl_seconds,
            limit,
        )

    def cleanup_events(self, conn: psycopg.Connection, ttl_seconds: int, limit: int = 1000) -> int:
        """Delete old events via ``absurd.cleanup_events`` and return how many were removed."""
        return self._run_cleanup(
            conn,
            "SELECT absurd.cleanup_events(%(queue_name)s, %(ttl_seconds)s, %(limit)s)",
            "events",
            ttl_seconds,
            limit,
        )

    # ------------------------------------------------------------------
    # Direct table reads (no stored procedure available)
    # ------------------------------------------------------------------

    def get_task_status(
        self,
        conn: psycopg.Connection,
        task_id: uuid.UUID,
    ) -> dict[str, Any] | None:
        """Return the task row as a dict, or None if the task does not exist.

        Absurd has no stored procedure for this, so the tasks table is read
        directly. ``FOR SHARE`` locks the row against concurrent
        UPDATE/DELETE until the current transaction ends.
        """
        with conn.cursor() as cur:
            cur.execute(
                sql.SQL("SELECT * FROM {} WHERE task_id = %(task_id)s FOR SHARE").format(
                    self._table("t"),
                ),
                {"task_id": task_id},
            )
            row = cur.fetchone()
        if not row:
            return None
        return {key: row[index] for index, key in enumerate(self._TASK_COLUMNS)}

    def get_run_status(
        self,
        conn: psycopg.Connection,
        run_id: uuid.UUID,
    ) -> dict[str, Any] | None:
        """Return the run row as a dict, or None if the run does not exist.

        Absurd has no stored procedure for this, so the runs table is read
        directly under a ``FOR SHARE`` row lock.
        """
        with conn.cursor() as cur:
            cur.execute(
                sql.SQL("SELECT * FROM {} WHERE run_id = %(run_id)s FOR SHARE").format(
                    self._table("r"),
                ),
                {"run_id": run_id},
            )
            row = cur.fetchone()
        if not row:
            return None
        return {key: row[index] for index, key in enumerate(self._RUN_COLUMNS)}

    def get_checkpoints_for_run(
        self, conn: psycopg.Connection, run_id: uuid.UUID
    ) -> dict[str, Any]:
        """Return ``{checkpoint_name: state}`` for every checkpoint owned by *run_id*.

        ``absurd.get_task_checkpoint_states`` needs a task_id, while callers of
        this method only have the run id. The checkpoints table is therefore
        queried directly by ``owner_run_id``, under a ``FOR SHARE`` lock.
        """
        with conn.cursor() as cur:
            cur.execute(
                sql.SQL(
                    "SELECT checkpoint_name, state FROM {} WHERE owner_run_id = %(run_id)s FOR SHARE"
                ).format(self._table("c")),
                {"run_id": run_id},
            )
            rows = cur.fetchall()
        return {name: state for name, state in rows}

    def save_checkpoint_for_run(
        self,
        conn: psycopg.Connection,
        run_id: uuid.UUID,
        step_name: str,
        data: Any,
        task_id: uuid.UUID | None = None,
    ) -> None:
        """Save *data* as checkpoint *step_name* owned by *run_id*.

        If *task_id* is omitted it is looked up from the runs table. When no
        task is found, a warning is logged and nothing is written.
        """
        if task_id is None:
            task_id = self._get_current_task_id(conn, run_id)

        if not task_id:
            # Without a task id there is nothing to attach the checkpoint to.
            logger.warning(
                f"No task_id found for run {run_id}, cannot save checkpoint for step {step_name}",
            )
            return

        logger.info(f"Setting checkpoint '{step_name}' for task {task_id}")
        self._set_checkpoint_state(conn, task_id, step_name, data, run_id, None)

    def _get_current_task_id(
        self,
        conn: psycopg.Connection,
        run_id: uuid.UUID,
    ) -> uuid.UUID | None:
        """Return the task_id of *run_id* from the runs table, or None if not found.

        Absurd has no stored procedure for this lookup. The row is read under
        a ``FOR SHARE`` lock.
        """
        with conn.cursor() as cur:
            cur.execute(
                sql.SQL(
                    "SELECT task_id FROM {} WHERE run_id = %(run_id)s ORDER BY started_at DESC LIMIT 1 FOR SHARE"
                ).format(self._table("r")),
                {"run_id": run_id},
            )
            result = cur.fetchone()
        if result:
            return result[0]  # type: ignore[no-any-return]
        return None

    def sleep(
        self,
        conn: psycopg.Connection,
        run_id: uuid.UUID,
        duration_seconds: int,
        task_id: uuid.UUID | None = None,
    ) -> None:
        """Record a ``sleep_<duration>`` checkpoint with the computed wake time.

        NOTE: This only writes a checkpoint and logs. It does not suspend the
        run or schedule a wake-up.

        Args:
            conn: Database connection.
            run_id: Run that owns the checkpoint.
            duration_seconds: Sleep length in seconds.
            task_id: Task that owns the run. When omitted it is looked up from
                the runs table; run_id is used only if that lookup fails.
        """
        now = datetime.now(timezone.utc)
        wake_time = now + timedelta(seconds=duration_seconds)
        resolved_task_id = self._resolve_task_id(conn, run_id, task_id)

        self.set_checkpoint(
            conn,
            resolved_task_id,
            f"sleep_{duration_seconds}",
            {
                "type": "sleep",
                "duration": duration_seconds,
                "wake_time": wake_time.isoformat(),
                "started_at": now.isoformat(),
            },
            run_id,
        )

        logger.info(
            f"Sleep scheduled for {duration_seconds} seconds, waking at {wake_time}",
        )

    # ------------------------------------------------------------------
    # Events
    # ------------------------------------------------------------------

    def wait_for_event(
        self,
        conn: psycopg.Connection,
        run_id: uuid.UUID,
        event_name: str,
        timeout_seconds: int | None = None,
        task_id: uuid.UUID | None = None,
        step_name: str | None = None,
    ) -> Any:
        """Return an event's payload, or register a wait and signal the run to sleep.

        Steps:
            1. If the event already exists in the events table, return its payload.
            2. If this run was woken with a payload for *this* event, return it.
            3. Otherwise insert a wait registration and raise AbsurdSleepError so
               the orchestrator can mark the run SLEEPING (it holds the run row
               lock, so the update cannot happen here).

        Args:
            conn: Database connection.
            run_id: The Absurd run ID (NOT the workflow run ID).
            event_name: Name of the event to wait for.
            timeout_seconds: Required. Used to compute the registration's
                ``timeout_at``; there is deliberately no default.
            task_id: The Absurd task ID (required to register a wait).
            step_name: The step name (required to register a wait).

        Returns:
            The event payload when it is already available.

        Raises:
            ValueError: If timeout_seconds is None, or if a wait must be
                registered and task_id/step_name are missing.
            AbsurdSleepError: After registering the wait.
        """
        # CRITICAL: no default timeout. Forcing callers to be explicit prevents hung workflows.
        if timeout_seconds is None:
            msg = (
                f"timeout_seconds is required for wait_for_event (event: {event_name}). "
                "DO NOT rely on defaults - specify an explicit timeout to prevent hung workflows."
            )
            raise ValueError(
                msg,
            )

        logger.debug(f"Waiting for event '{event_name}' for run {run_id}")

        # Fast path: the event has already been emitted.
        with conn.cursor() as cur:
            cur.execute(
                sql.SQL(
                    "SELECT payload FROM {} WHERE event_name = %(event_name)s LIMIT 1"
                ).format(self._table("e")),
                {"event_name": event_name},
            )
            result = cur.fetchone()
        if result:
            return result[0]

        # The run may already have been woken with a payload.
        with conn.cursor() as cur:
            cur.execute(
                sql.SQL(
                    "SELECT event_payload, state, wake_event FROM {} WHERE run_id = %(run_id)s"
                ).format(self._table("r")),
                {"run_id": run_id},
            )
            run_result = cur.fetchone()

        if run_result:
            event_payload, _run_state, wake_event = run_result
            # Only consume the payload if it belongs to the event we are waiting for.
            if event_payload and wake_event == event_name:
                logger.debug(f"Run {run_id} woken with event '{event_name}'")
                return event_payload
            if event_payload and wake_event != event_name:
                logger.debug(
                    f"Ignoring stale payload for '{wake_event}' while waiting for '{event_name}'",
                )

        # Event not available yet: register the wait and signal the orchestrator.
        if not task_id or not step_name:
            msg = "task_id and step_name required for registering event wait"
            raise ValueError(
                msg,
            )

        timeout_at = datetime.now(timezone.utc) + timedelta(seconds=timeout_seconds)

        with conn.cursor() as cur:
            cur.execute(
                sql.SQL(
                    "INSERT INTO {} "
                    "(task_id, run_id, step_name, event_name, timeout_at) "
                    "VALUES (%(task_id)s, %(run_id)s, %(step_name)s, %(event_name)s, %(timeout_at)s)"
                ).format(self._table("w")),
                {
                    "task_id": task_id,
                    "run_id": run_id,
                    "step_name": step_name,
                    "event_name": event_name,
                    "timeout_at": timeout_at,
                },
            )

        msg = f"Run {run_id} sleeping, waiting for event '{event_name}'"
        raise AbsurdSleepError(
            msg,
            run_id=run_id,
            event_name=event_name,
        )

    def set_run_sleeping(
        self, conn: psycopg.Connection, run_id: uuid.UUID, event_name: str
    ) -> None:
        """Mark a run as ``sleeping`` and record the event it waits for.

        Intended to be called by the orchestrator after catching AbsurdSleepError.
        """
        with conn.cursor() as cur:
            cur.execute(
                sql.SQL(
                    "UPDATE {} "
                    "SET state = 'sleeping', wake_event = %(event_name)s "
                    "WHERE run_id = %(run_id)s"
                ).format(self._table("r")),
                {"event_name": event_name, "run_id": run_id},
            )
        logger.debug(f"Run {run_id} marked as SLEEPING, waiting for '{event_name}'")

    def emit_event(
        self,
        conn: psycopg.Connection,
        event_name: str,
        payload: dict[str, Any] | None = None,
    ) -> None:
        """Emit an event and wake the runs waiting for it.

        Steps:
            1. Insert the event. ``ON CONFLICT DO NOTHING`` makes the first
               emission win.
            2. Atomically delete the event's wait registrations and collect
               their run ids (``DELETE ... RETURNING``). A registration
               committed after that statement starts is left in place instead
               of being dropped unwoken.
            3. Set those runs (if ``sleeping``) back to ``pending``, with the
               *stored* event payload.
        """
        payload_dict = payload or {}
        now = datetime.now(timezone.utc)

        with conn.cursor() as cur:
            cur.execute(
                sql.SQL(
                    "INSERT INTO {} (event_name, payload, emitted_at) "
                    "VALUES (%(event_name)s, CAST(%(payload)s AS jsonb), %(emitted_at)s) "
                    "ON CONFLICT (event_name) DO NOTHING"
                ).format(self._table("e")),
                {
                    "event_name": event_name,
                    "payload": json.dumps(payload_dict),
                    "emitted_at": now,
                },
            )
            # Re-read the stored payload. On a duplicate emit it is the first
            # emission's payload, which is also what wait_for_event's fast path returns.
            cur.execute(
                sql.SQL("SELECT payload FROM {} WHERE event_name = %(event_name)s").format(
                    self._table("e"),
                ),
                {"event_name": event_name},
            )
            stored = cur.fetchone()

        logger.debug(f"Event '{event_name}' emitted")
        wake_payload = stored[0] if stored is not None else payload_dict

        with conn.cursor() as cur:
            cur.execute(
                sql.SQL(
                    "DELETE FROM {} WHERE event_name = %(event_name)s RETURNING run_id"
                ).format(self._table("w")),
                {"event_name": event_name},
            )
            waiting_run_ids = list({row[0] for row in cur.fetchall()})

        if not waiting_run_ids:
            return

        logger.debug(f"Waking {len(waiting_run_ids)} run(s) for event '{event_name}'")
        # NOTE(review): runs that registered a wait but are not yet 'sleeping'
        # are not updated here (same as before); confirm the orchestrator
        # handles that window.
        with conn.cursor() as cur:
            cur.execute(
                sql.SQL(
                    "UPDATE {} "
                    "SET state = 'pending', "
                    "    available_at = %(now)s, "
                    "    wake_event = %(event_name)s, "
                    "    event_payload = CAST(%(payload)s AS jsonb) "
                    "WHERE run_id = ANY(%(run_ids)s) AND state = 'sleeping'"
                ).format(self._table("r")),
                {
                    "now": now,
                    "event_name": event_name,
                    "payload": json.dumps(wake_payload),
                    "run_ids": waiting_run_ids,
                },
            )

    def get_run_checkpoint(
        self, conn: psycopg.Connection, run_id: uuid.UUID, step_name: str
    ) -> Any:
        """Return the state of checkpoint *step_name* owned by *run_id*, or None."""
        return self.get_checkpoints_for_run(conn, run_id).get(step_name)

    # =========================================================================
    # Workflow Run Tracking Methods (Phase 10)
    # =========================================================================

    def create_workflow_run(
        self,
        conn: psycopg.Connection,
        workflow_name: str,
        workflow_version: str,
        inputs: dict[str, Any] | None = None,
        absurd_run_id: uuid.UUID | None = None,
        created_by: str | None = None,
        tags: dict[str, Any] | None = None,
        workflow_hash: str | None = None,
    ) -> uuid.UUID:
        """Insert a new ``workflow_run`` row with status ``'pending'``.

        Args:
            conn: Database connection (transaction-aware).
            workflow_name: Logical workflow name (must match ``^[a-z][a-z0-9_]*$``, no '__').
            workflow_version: Workflow version (must match ``^[a-zA-Z0-9._-]+$``, no '__').
            inputs: Workflow input parameters (empty/None is stored as NULL).
            absurd_run_id: Optional root Absurd run_id.
            created_by: Optional user/system identifier.
            tags: Optional key-value tags for filtering (empty/None stored as NULL).
            workflow_hash: Optional SHA-256 hash of the workflow definition.

        Returns:
            The freshly generated workflow_run_id.

        Raises:
            ValueError: If name or version contains the reserved '__' separator.
        """
        # Only the '__' rule is enforced here.
        # NOTE(review): the regex patterns above are presumably enforced by the DB.
        if "__" in workflow_name or "__" in workflow_version:
            msg = "Workflow name/version cannot contain '__' (reserved separator)"
            raise ValueError(msg)

        workflow_run_id = uuid.uuid4()
        logger.info(
            f"Creating workflow_run: {workflow_name} v{workflow_version} (id={workflow_run_id})"
        )

        with conn.cursor() as cur:
            cur.execute(
                """
                INSERT INTO workflow_run (
                    workflow_run_id, workflow_name, workflow_version, workflow_hash,
                    status, absurd_run_id, absurd_queue, inputs, created_by, tags
                ) VALUES (
                    %(workflow_run_id)s, %(workflow_name)s, %(workflow_version)s, %(workflow_hash)s,
                    'pending', %(absurd_run_id)s, %(queue_name)s, %(inputs)s, %(created_by)s, %(tags)s
                )
                """,
                {
                    "workflow_run_id": workflow_run_id,
                    "workflow_name": workflow_name,
                    "workflow_version": workflow_version,
                    "workflow_hash": workflow_hash,
                    "absurd_run_id": absurd_run_id,
                    "queue_name": self.queue_name,
                    "inputs": json.dumps(inputs) if inputs else None,
                    "created_by": created_by,
                    "tags": json.dumps(tags) if tags else None,
                },
            )

        logger.debug(f"Created workflow_run: {workflow_run_id}")
        return workflow_run_id

    def update_workflow_run_status(
        self,
        conn: psycopg.Connection,
        workflow_run_id: uuid.UUID,
        status: str,
        result: dict[str, Any] | None = None,
        error: dict[str, Any] | None = None,
        started_at: datetime | None = None,
        completed_at: datetime | None = None,
        task_count: int | None = None,
    ) -> None:
        """Update a workflow_run's status and any of the optional fields given.

        Args:
            conn: Database connection.
            workflow_run_id: UUID of the workflow_run to update.
            status: New status (pending, running, completed, failed, cancelled).
            result: Optional final result (JSON-encoded).
            error: Optional error details (JSON-encoded).
            started_at: Optional start timestamp.
            completed_at: Optional completion timestamp.
            task_count: Optional task count.

        A warning is logged when no row matches *workflow_run_id*.
        """
        logger.debug(f"Updating workflow_run {workflow_run_id} to '{status}'")

        update_fields = ["status = %(status)s"]
        params: dict[str, Any] = {"workflow_run_id": workflow_run_id, "status": status}

        if result is not None:
            update_fields.append("result = %(result)s")
            params["result"] = json.dumps(result)
        if error is not None:
            update_fields.append("error = %(error)s")
            params["error"] = json.dumps(error)
        if started_at is not None:
            update_fields.append("started_at = %(started_at)s")
            params["started_at"] = started_at
        if completed_at is not None:
            update_fields.append("completed_at = %(completed_at)s")
            params["completed_at"] = completed_at
        if task_count is not None:
            update_fields.append("task_count = %(task_count)s")
            params["task_count"] = task_count

        with conn.cursor() as cur:
            # update_fields holds only the fixed fragments built above (never
            # caller input), so joining them into the statement text is safe.
            cur.execute(
                f"UPDATE workflow_run SET {', '.join(update_fields)} "
                f"WHERE workflow_run_id = %(workflow_run_id)s",
                params,
            )
            updated_rows = cur.rowcount

        if updated_rows == 0:
            logger.warning(f"No workflow_run found: {workflow_run_id}")
        else:
            logger.info(f"Updated workflow_run {workflow_run_id} to '{status}'")
# Convenience functions for common patterns
def spawn_retry_task(
    client: AbsurdClient,
    conn: psycopg.Connection,
    task_name: str,
    params: dict[str, Any],
    max_attempts: int = 3,
    retry_kind: str = "exponential",
    base_seconds: int = 30,
    factor: float = 2.0,
    max_seconds: int | None = None,
) -> tuple[uuid.UUID, uuid.UUID, uuid.UUID]:
    """Spawn *task_name* with a retry strategy assembled from the keyword arguments.

    The strategy always carries ``kind``, ``base_seconds`` and ``factor``.
    ``max_seconds`` is added only when it is truthy.

    Returns:
        Whatever ``client.spawn_task`` returns: ``(task_id, run_id, workflow_run_id)``.
    """
    strategy: dict[str, Any] = {
        "kind": retry_kind,
        "base_seconds": base_seconds,
        "factor": factor,
        **({"max_seconds": max_seconds} if max_seconds else {}),
    }
    return client.spawn_task(
        conn=conn,
        task_name=task_name,
        params=params,
        max_attempts=max_attempts,
        retry_strategy=strategy,
    )
def spawn_cancellable_task(
    client: AbsurdClient,
    conn: psycopg.Connection,
    task_name: str,
    params: dict[str, Any],
    max_delay_seconds: int | None = None,
    max_duration_seconds: int | None = None,
) -> tuple[uuid.UUID, uuid.UUID, uuid.UUID]:
    """Spawn *task_name* with cancellation limits.

    Only truthy limits are included: ``max_delay`` from *max_delay_seconds* and
    ``max_duration`` from *max_duration_seconds*. An empty dict is passed when
    neither limit is set.

    Returns:
        Whatever ``client.spawn_task`` returns: ``(task_id, run_id, workflow_run_id)``.
    """
    candidate_limits = (
        ("max_delay", max_delay_seconds),
        ("max_duration", max_duration_seconds),
    )
    limits = {key: value for key, value in candidate_limits if value}
    return client.spawn_task(
        conn=conn,
        task_name=task_name,
        params=params,
        cancellation=limits,
    )
# Singleton instance _absurd_client_singleton: AbsurdClient | None = None
def get_absurd_client(
    queue_name: str | None = None,
    worker_id: str | None = None,
) -> AbsurdClient:
    """Return the process-wide AbsurdClient, creating it on first use.

    Args:
        queue_name: Queue name for the client; honoured only when the
            singleton is first created.
        worker_id: Worker ID for the client; honoured only when the singleton
            is first created.

    Returns:
        The shared AbsurdClient instance.
    """
    global _absurd_client_singleton  # noqa: PLW0603
    if _absurd_client_singleton is not None:
        return _absurd_client_singleton

    client = AbsurdClient(queue_name=queue_name, worker_id=worker_id)
    _absurd_client_singleton = client
    logger.info(
        f"Created AbsurdClient singleton with queue '{client.queue_name}' "
        f"and worker '{client.worker_id}'",
    )
    return client