ADK Callbacks: Build Production Observability for AI Agents

2026-04-08 · 29 min read
#google-adk #observability #ai-agents #opentelemetry #devops #beginner-tutorial #english

Learn to build a production-ready observability pipeline with ADK callbacks. Track token usage, costs, and execution time for your AI agents at scale.

Building a Production-Ready Observability Pipeline with ADK Callbacks: Tracking Every Dollar and Millisecond in Your AI Agents

Your AI agent processed 50,000 requests last month. You know this because your billing dashboard shows a $12,847 charge from OpenAI. What you don’t know: which user sessions consumed 40% of that budget, why Tuesday’s requests cost 3x more than Monday’s, or which tool invocations triggered expensive retry cascades. You’re flying blind, and it’s costing you money.

The Agent Development Kit (ADK) lets you register callbacks before and after every agent run, model call, and tool call. Many teams leave these hooks unused or fill them with print statements. In this article, you'll build a production-grade observability pipeline that records token usage, latency, and cost for every call, then pipes that data into dashboards where you can actually act on it.

Prerequisites

Before diving in, ensure you have:

  • Python 3.10+ installed
  • Google ADK (pip install "google-adk>=1.0.0")
  • Basic familiarity with ADK’s agent structure and callback system
  • One of these observability backends configured: OpenTelemetry Collector, Datadog Agent, or Prometheus with Pushgateway
  • PostgreSQL or similar for audit log persistence (optional but recommended)

You should understand how ADK agents execute—specifically that agents can nest sub-agents and invoke tools, creating call trees that need correlation.

Architecture and Key Concepts

The observability pipeline intercepts ADK execution at three levels: individual LLM calls, tool invocations, and aggregate agent steps. Each callback captures structured telemetry, enriches it with correlation IDs, and routes it to appropriate sinks.

flowchart TD
    subgraph ADK_Runtime["ADK Runtime"]
        A[User Request] --> B[Root Agent]
        B --> C{Decision Point}
        C -->|LLM Call| D[Model Invocation]
        C -->|Tool Use| E[Tool Execution]
        C -->|Delegate| F[Sub-Agent]
        F --> C
    end

    subgraph Callback_Layer["Callback Layer"]
        D --> G[before_model_callback / after_model_callback]
        E --> H[before_tool_callback / after_tool_callback]
        F --> I[before_agent_callback / after_agent_callback]
    end

    subgraph Telemetry_Pipeline["Telemetry Pipeline"]
        G --> J[TelemetryAggregator]
        H --> J
        I --> J
        J --> K[Cost Calculator]
        J --> L[Latency Tracker]
        J --> M[Audit Logger]
    end

    subgraph Observability_Stack["Observability Stack"]
        K --> N[Prometheus/Datadog]
        L --> N
        M --> O[PostgreSQL/S3]
        N --> P[Grafana Dashboard]
    end

Key concepts you need to understand:

  • Correlation ID: A unique identifier that links all callbacks within a single user request, even across nested agent calls
  • Span Context: Parent-child relationships between operations, essential for understanding where time and money go
  • Token Economics: Different models charge differently per input/output token—your pipeline must track both separately

💡 ADK runs callbacks inline in the agent's execution path, so any slow work inside a callback (network calls, database writes) delays the agent's response. In production, buffer telemetry in memory and flush it asynchronously.
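A minimal sketch of that buffering, assuming a thread-based design: the sink's `record_llm_call` (the method the callback handler calls on its sink) only enqueues, and a daemon thread exports whatever has accumulated, up to `batch_size` records per call. `BufferedTelemetrySink` and the `export` function are hypothetical; in practice `export` would push to your metrics backend or audit store.

```python
import queue
import sys
import threading
from typing import Any, Callable, List


class BufferedTelemetrySink:
    """Hypothetical sink: callbacks enqueue records, a background thread exports batches."""

    def __init__(self, export: Callable[[List[Any]], None],
                 batch_size: int = 100, poll_interval: float = 5.0):
        self._export = export
        self._batch_size = batch_size
        self._poll_interval = poll_interval
        self._queue: "queue.Queue[Any]" = queue.Queue()
        self._stop = threading.Event()
        self._worker = threading.Thread(target=self._run, daemon=True)
        self._worker.start()

    def record_llm_call(self, metrics: Any) -> None:
        # Called from the agent's callback: O(1), never waits on the network
        self._queue.put_nowait(metrics)

    def _drain(self, batch: List[Any]) -> List[Any]:
        # Top the batch up to batch_size with already-queued records, without blocking
        while len(batch) < self._batch_size:
            try:
                batch.append(self._queue.get_nowait())
            except queue.Empty:
                break
        return batch

    def _safe_export(self, batch: List[Any]) -> None:
        try:
            self._export(batch)
        except Exception as exc:  # telemetry failures must not crash the agent
            print(f"telemetry export failed: {exc}", file=sys.stderr)

    def _run(self) -> None:
        while not self._stop.is_set():
            try:
                # Wake at least every poll_interval to re-check the stop flag
                first = self._queue.get(timeout=self._poll_interval)
            except queue.Empty:
                continue
            self._safe_export(self._drain([first]))
        # Shutdown: export whatever is still queued
        batch = self._drain([])
        while batch:
            self._safe_export(batch)
            batch = self._drain([])

    def close(self, timeout: float = 10.0) -> None:
        self._stop.set()
        self._worker.join(timeout)
```

Call `close()` from your shutdown hook so records queued in the last few seconds are not lost when the process exits.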

Step-by-Step Implementation

Building the Core Callback Handler with Token and Cost Tracking

Start by creating a callback handler that captures the essential metrics at each LLM invocation. This handler will track token counts, model selection, and execution time.

# observability/callbacks.py
import hashlib
import time
import uuid
from contextvars import ContextVar
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Dict, Optional, Tuple

# ADK callbacks are plain callables passed to the agent; there is no handler
# base class to inherit from. These are the types ADK passes to model callbacks.
from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmRequest, LlmResponse

# Illustrative pricing in USD per 1K tokens. Verify against your provider's
# current price list; prices change and can depend on prompt length.
MODEL_PRICING = {
    "gemini-1.5-pro": {"input": 0.00125, "output": 0.00375},
    "gemini-1.5-flash": {"input": 0.000075, "output": 0.0003},
    "gpt-4-turbo": {"input": 0.01, "output": 0.03},
    "gpt-4o": {"input": 0.005, "output": 0.015},
    "claude-3-sonnet": {"input": 0.003, "output": 0.015},
}
# Fallback for models missing from MODEL_PRICING; kept high so gaps over-report
DEFAULT_PRICING = {"input": 0.01, "output": 0.03}

@dataclass
class LLMCallMetrics:
    """Captures all metrics for a single LLM invocation."""
    call_id: str
    correlation_id: str
    parent_span_id: Optional[str]
    model_name: str
    input_tokens: int
    output_tokens: int
    total_tokens: int
    cost_usd: float
    latency_ms: float
    timestamp: datetime
    is_retry: bool = False
    retry_count: int = 0
    error: Optional[str] = None
    
@dataclass
class SpanContext:
    """Tracks the current execution context for correlation."""
    correlation_id: str
    span_id: str
    parent_span_id: Optional[str] = None
    depth: int = 0

# ADK runs agents on asyncio, so concurrent requests can share one thread.
# A ContextVar (unlike threading.local) is isolated per asyncio task.
_current_span: ContextVar[Optional[SpanContext]] = ContextVar("current_span", default=None)

def set_span_context(ctx: SpanContext) -> None:
    """Set the span context for the current request/task."""
    _current_span.set(ctx)
    
class ObservabilityCallbackHandler:
    """
    Emits one LLMCallMetrics record per model call. Register the bound
    methods on each agent you want to observe, for example:

        LlmAgent(..., before_model_callback=handler.before_model,
                 after_model_callback=handler.after_model)
    """
    
    def __init__(self, telemetry_sink: "TelemetrySink"):
        self.sink = telemetry_sink
        # (invocation_id, agent_name) -> (call_id, start time, model, prompt hash)
        self._pending: Dict[Tuple[str, str], Tuple[str, float, str, str]] = {}
        # (invocation_id, prompt hash) -> times this exact prompt was sent
        self._prompt_counts: Dict[Tuple[str, str], int] = {}

    @staticmethod
    def _call_key(callback_context: CallbackContext) -> Tuple[str, str]:
        # An agent makes its model calls one at a time within an invocation
        return (callback_context.invocation_id, callback_context.agent_name)
        
    def _calculate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
        """Calculate USD cost based on model pricing."""
        pricing = MODEL_PRICING.get(model, DEFAULT_PRICING)
        input_cost = (input_tokens / 1000) * pricing["input"]
        output_cost = (output_tokens / 1000) * pricing["output"]
        return round(input_cost + output_cost, 6)
    
    def before_model(
        self,
        callback_context: CallbackContext,
        llm_request: LlmRequest,
    ) -> Optional[LlmResponse]:
        """Runs immediately before each model call. Returning None lets it proceed."""
        # A stable digest (built-in hash() is salted per process) identifies an
        # identical prompt re-sent within the same invocation, i.e. a retry
        prompt_hash = hashlib.sha256(str(llm_request.contents).encode()).hexdigest()
        count_key = (callback_context.invocation_id, prompt_hash)
        self._prompt_counts[count_key] = self._prompt_counts.get(count_key, 0) + 1

        # Remember start time and model so after_model can compute latency and cost
        self._pending[self._call_key(callback_context)] = (
            str(uuid.uuid4()),
            time.perf_counter(),
            llm_request.model or "unknown",
            prompt_hash,
        )
        return None
        
    def after_model(
        self,
        callback_context: CallbackContext,
        llm_response: LlmResponse,
    ) -> Optional[LlmResponse]:
        """Runs after each model response. Returning None keeps the response unchanged."""
        # When streaming, partial chunks arrive first; usage comes with the final one
        if llm_response.partial:
            return None
        pending = self._pending.pop(self._call_key(callback_context), None)
        if pending is None:
            return None
        call_id, start_time, model, prompt_hash = pending
        latency_ms = (time.perf_counter() - start_time) * 1000

        # usage_metadata may be missing (for example on errors), so default to zero
        usage = llm_response.usage_metadata
        input_tokens = (usage.prompt_token_count or 0) if usage else 0
        output_tokens = (usage.candidates_token_count or 0) if usage else 0

        sends = self._prompt_counts.get((callback_context.invocation_id, prompt_hash), 1)
        span = _current_span.get()
        metrics = LLMCallMetrics(
            call_id=call_id,
            # Fall back to ADK's invocation id, which is unique per user request
            correlation_id=span.correlation_id if span else callback_context.invocation_id,
            parent_span_id=span.span_id if span else None,
            model_name=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=input_tokens + output_tokens,
            cost_usd=self._calculate_cost(model, input_tokens, output_tokens),
            latency_ms=round(latency_ms, 2),
            timestamp=datetime.now(timezone.utc),
            is_retry=sends > 1,
            retry_count=sends - 1,
            error=llm_response.error_message,
        )
        
        # The sink should only enqueue; exporting happens off the request path
        self.sink.record_llm_call(metrics)
        return None

    def end_invocation(self, invocation_id: str) -> None:
        """Drop per-invocation state once a request finishes (bounds memory)."""
        # Also clears calls whose after_model never ran because the model raised
        for store in (self._pending, self._prompt_counts):
            for key in [k for k in store if k[0] == invocation_id]:
                del store[key]

⚠️ The MODEL_PRICING dictionary requires manual updates when providers change pricing. Consider fetching this from a configuration service in production.
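One way to act on that advice, sketched under assumptions: prices live in a JSON document (a hypothetical format; in production it would come from your configuration service), are parsed as `Decimal` so thousands of tiny per-call costs sum exactly, and unknown models raise instead of being billed at a fallback rate. The rates are copied from the dictionary above and are illustrative only; `load_pricing`, `calculate_cost`, and `UnknownModelError` are hypothetical names.

```python
import json
from decimal import Decimal
from typing import Dict

# Hypothetical pricing document, USD per 1K tokens. Rates are stored as
# strings so they parse to exact Decimals.
PRICING_JSON = """
{
  "gemini-1.5-pro":   {"input": "0.00125",  "output": "0.00375"},
  "gemini-1.5-flash": {"input": "0.000075", "output": "0.0003"}
}
"""


class UnknownModelError(KeyError):
    """Raised when a model has no configured price."""


def load_pricing(raw: str) -> Dict[str, Dict[str, Decimal]]:
    data = json.loads(raw)
    return {
        model: {kind: Decimal(rate) for kind, rate in rates.items()}
        for model, rates in data.items()
    }


def calculate_cost(pricing: Dict[str, Dict[str, Decimal]], model: str,
                   input_tokens: int, output_tokens: int) -> Decimal:
    rates = pricing.get(model)
    if rates is None:
        # Fail loudly rather than billing an unknown model at a guessed rate
        raise UnknownModelError(model)
    per_1k = Decimal(1000)
    return (
        Decimal(input_tokens) / per_1k * rates["input"]
        + Decimal(output_tokens) / per_1k * rates["output"]
    )


PRICING = load_pricing(PRICING_JSON)
print(calculate_cost(PRICING, "gemini-1.5-flash", 10_000, 2_000))
```

Raising trades a hard failure for accuracy. If you prefer the handler's fallback behavior, catch `UnknownModelError` and record the call with a flag so the pricing gap shows up on a dashboard.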

Aggregating Costs Across Nested Agent Calls and Tool Invocations

When agents delegate to sub-agents or invoke tools that make their own LLM calls, costs compound in ways that aren’t immediately visible. You need an aggregator that maintains the call tree and rolls up costs correctly.

# observability/aggregator.py
import threading
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from observability.callbacks import LLMCallMetrics

@dataclass
class AgentStepMetrics:
    """Aggregated metrics for a complete agent step (may include multiple LLM calls)."""
    step_id: str
    agent_name: str
    correlation_id: str
    total_llm_calls: int = 0
    total_tool_calls: int = 0
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cost_usd: float = 0.0
    total_latency_ms: float = 0.0
    child_agent_costs: Dict[str, float] = field(default_factory=dict)
    tool_costs: Dict[str, float] = field(default_factory=dict)
    
@dataclass  
class RequestMetrics:
    """Top-level metrics for an entire user request."""
    correlation_id: str
    user_id: Optional[str]
    session_id: Optional[str]
    root_agent: str
    total_cost_usd: float = 0.0
    total_latency_ms: float = 0.0
    total_llm_calls: int = 0
    agent_breakdown: Dict[str, float] = field(default_factory=dict)
    model_breakdown: Dict[str, float] = field(default_factory=dict)
    
class CostAggregator:
    """
    Maintains a call tree and aggregates costs from leaf nodes up to root.
    Thread-safe for concurrent request handling.
    """
    
    def __init__(self):
        self._lock = threading.Lock()
        # correlation_id -> list of LLMCallMetrics
        self._call_buffer: Dict[str, List[LLMCallMetrics]] = defaultdict(list)
        # correlation_id -> span_id -> AgentStepMetrics
        self._step_metrics: Dict[str, Dict[str, AgentStepMetrics]] = defaultdict(dict)
        # correlation_id -> RequestMetrics
        self._request_metrics: Dict[str, RequestMetrics] = {}
        # span_id -> parent_span_id for tree reconstruction
        self._span_parents: Dict[str, Optional[str]] = {}
        
    def start_request(
        self, 
        correlation_id: str,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
        root_agent: str = "unknown"
    ) -> None:
        """Initialize tracking for a new user request."""
        with self._lock:
            self._request_metrics[correlation_id] = RequestMetrics(
                correlation_id=correlation_id,
                user_id=user_id,
                session_id=session_id,
                root_agent=root_agent,
            )
            
    def record_span_parent(self, span_id: str, parent_span_id: Optional[str]) -> None:
        """Track parent-child relationship for span tree reconstruction."""
        with self._lock:
            self._span_parents[span_id] = parent_span_id
            
    def add_llm_call(self, metrics: LLMCallMetrics) -> None:
        """Add an LLM call to the buffer for aggregation."""
        with self._lock:
            self._call_buffer[metrics.correlation_id].append(metrics)
            
    def add_tool_cost(
        self, 
        correlation_id: str,
        span_id: str,
        tool_name: str,
        cost_usd: float,
        latency_ms: float
    ) -> None:
        """Record cost incurred by a tool invocation."""
        with self._lock:
            if span_id not in self._step_metrics[correlation_id]:
                self._step_metrics[correlation_id][span_id] = AgentStepMetrics(
                    step_id=span_id,
                    agent_name="unknown",
                    correlation_id=correlation_id,
                )
            step = self._step_metrics[correlation_id][span_id]
            step.tool_costs[tool_name] = step.tool_costs.get(tool_name, 0) + cost_usd
            step.total_tool_calls += 1
            step.total_cost_usd += cost_usd
            step.total_latency_ms += latency_ms
            
    def finalize_request(self, correlation_id: str) -> RequestMetrics:
        """
        Aggregate all buffered metrics and return final request-level metrics.
        Call this exactly once, when the root agent completes.
        """
        with self._lock:
            # Popping up front means a second call raises instead of double-counting
            request = self._request_metrics.pop(correlation_id, None)
            if request is None:
                raise ValueError(f"No request found for correlation_id: {correlation_id}")
                
            calls = self._call_buffer.pop(correlation_id, [])
            
            # Aggregate LLM calls
            for call in calls:
                request.total_cost_usd += call.cost_usd
                # Sum of call latencies; exceeds wall-clock time when calls overlap
                request.total_latency_ms += call.latency_ms
                request.total_llm_calls += 1
                
                # Track by model
                model = call.model_name
                request.model_breakdown[model] = (
                    request.model_breakdown.get(model, 0) + call.cost_usd
                )
                
            # Add tool costs (add_tool_cost is the only writer of step totals,
            # so agent_breakdown currently reflects tool spend per agent)
            steps = self._step_metrics.pop(correlation_id, {})
            for step in steps.values():
                request.total_cost_usd += step.total_cost_usd
                agent = step.agent_name
                request.agent_breakdown[agent] = (
                    request.agent_breakdown.get(agent, 0) + step.total_cost_usd
                )
            
            return request

    def get_cost_tree(self, correlation_id: str) -> Dict[str, Any]:
        """
        Build a hierarchical cost breakdown showing nested agent/tool costs.
        Each node reports its own cost and the rolled-up total of its subtree.
        Useful for debugging expensive request paths; call it before
        finalize_request(), which discards the buffers.
        """
        with self._lock:
            calls = list(self._call_buffer.get(correlation_id, []))
            steps = dict(self._step_metrics.get(correlation_id, {}))
            parents = dict(self._span_parents)

        nodes: Dict[str, Dict[str, Any]] = {}

        def node(span_id: str) -> Dict[str, Any]:
            if span_id not in nodes:
                nodes[span_id] = {"span_id": span_id, "own_cost": 0.0,
                                  "total_cost": 0.0, "calls": [], "children": []}
            return nodes[span_id]

        node("root")
        for call in calls:
            n = node(call.parent_span_id or "root")
            n["calls"].append({
                "model": call.model_name,
                "cost": call.cost_usd,
                "tokens": call.total_tokens,
            })
            n["own_cost"] += call.cost_usd
        for span_id, step in steps.items():
            node(span_id)["own_cost"] += step.total_cost_usd

        # Create every ancestor so intermediate agents appear in the tree
        for span_id in list(nodes):
            while span_id != "root":
                span_id = parents.get(span_id) or "root"
                node(span_id)

        # Link children to parents; spans with no recorded parent hang off root
        for span_id, n in nodes.items():
            if span_id != "root":
                nodes[parents.get(span_id) or "root"]["children"].append(n)

        def roll_up(n: Dict[str, Any]) -> float:
            n["total_cost"] = n["own_cost"] + sum(roll_up(c) for c in n["children"])
            return n["total_cost"]

        roll_up(nodes["root"])
        return nodes["root"]

📝 Call finalize_request exactly once per request. A second call raises ValueError because the first call discards the request's state; this is intentional, to surface bugs in your callback wiring. Never calling it leaks that request's buffered metrics.
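To guarantee the exactly-once call, wrap each request in a context manager that pairs `start_request` with `finalize_request` in a `try`/`finally`. This is a sketch: `tracked_request` and its `on_finalized` callback are hypothetical names, and any object exposing the two `CostAggregator` methods shown above will work.

```python
import uuid
from contextlib import contextmanager
from typing import Any, Callable, Iterator, Optional


@contextmanager
def tracked_request(
    aggregator: Any,
    on_finalized: Callable[[Any], None],
    *,
    user_id: Optional[str] = None,
    session_id: Optional[str] = None,
    root_agent: str = "unknown",
) -> Iterator[str]:
    """Yield a fresh correlation ID; finalize exactly once, even if the body raises."""
    correlation_id = str(uuid.uuid4())
    aggregator.start_request(
        correlation_id, user_id=user_id, session_id=session_id, root_agent=root_agent
    )
    try:
        yield correlation_id
    finally:
        # Runs on success and on exceptions, so buffered metrics never leak
        on_finalized(aggregator.finalize_request(correlation_id))
```

Typical use is `with tracked_request(aggregator, export_request_metrics, user_id=uid) as cid:` around the code that runs the root agent, where `export_request_metrics` is whatever you use to ship the final `RequestMetrics`.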

Designing the Structured Audit Log Schema

Compliance requirements and debugging both demand complete audit trails. The schema below records who made each request, which models and tools ran, what each step cost, and what failed, which is enough to reconstruct a request after the fact.

# observability/audit.py
from dataclasses import dataclass, asdict
from typing import Optional, List, Dict, Any
from datetime import datetime
from enum import Enum
import hashlib
import json
import sys
import threading

from observability.callbacks import LLMCallMetrics

class EventType(str, Enum):
    REQUEST_START = "request_start"
    REQUEST_END = "request_end"
    LLM_CALL = "llm_call"
    LLM_RETRY = "llm_retry"
    LLM_FALLBACK = "llm_fallback"
    TOOL_INVOCATION = "tool_invocation"
    AGENT_DELEGATION = "agent_delegation"
    ERROR = "error"

@dataclass
class PromptResponsePair:
    """
    Stores prompt and response for audit purposes.
    Includes content hashing for integrity verification.
    """
    prompt_messages: List[Dict[str, str]]
    response_content: str
    prompt_hash: str
    response_hash: str
    
    @classmethod
    def create(cls, messages: List[Dict], response: str) -> "PromptResponsePair":
        prompt_str = json.dumps(messages, sort_keys=True)
        return cls(
            prompt_messages=messages,
            response_content=response,
            prompt_hash=hashlib.sha256(prompt_str.encode()).hexdigest()[:16],
            response_hash=hashlib.sha256(response.encode()).hexdigest()[:16],
        )

@dataclass
class AuditLogEntry:
    """
    Complete audit log entry correlating all aspects of an operation.
    Designed for both compliance auditing and debugging.
    """
    # Identification
    entry_id: str
    correlation_id: str
    span_id: str
    parent_span_id: Optional[str]
    
    # Timing
    timestamp: datetime
    duration_ms: Optional[float]
    
    # Event classification
    event_type: EventType
    agent_name: str
    
    # User context (for compliance)
    user_id: Optional[str]
    session_id: Optional[str]
    client_ip: Optional[str]
    
    # LLM-specific fields
    model_name: Optional[str] = None
    prompt_response: Optional[PromptResponsePair] = None
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    cost_usd: Optional[float] = None
    
    # Retry/fallback tracking
    is_retry: bool = False
    retry_attempt: int = 0
    original_model: Optional[str] = None  # For fallback events
    fallback_reason: Optional[str] = None
    
    # Tool-specific fields
    tool_name: Optional[str] = None
    tool_input: Optional[Dict[str, Any]] = None
    tool_output: Optional[str] = None
    
    # Error tracking
    error_type: Optional[str] = None
    error_message: Optional[str] = None
    error_stack: Optional[str] = None
    
    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary, handling nested dataclasses."""
        result = asdict(self)
        result["timestamp"] = self.timestamp.isoformat()
        result["event_type"] = self.event_type.value
        return result
    
    def to_json(self) -> str:
        """Serialize to JSON string for storage."""
        return json.dumps(self.to_dict(), default=str)

class AuditLogger:
    """
    Writes structured audit logs to configured backends.
    Entries are batched in memory and written when the buffer fills or
    flush() is called. flush_interval_seconds is stored for a periodic
    flusher (for example a background thread), which is not shown here.
    """
    
    def __init__(
        self, 
        backends: List["AuditBackend"],
        buffer_size: int = 100,
        flush_interval_seconds: float = 5.0
    ):
        self.backends = backends
        self.buffer_size = buffer_size
        self.flush_interval = flush_interval_seconds
        self._buffer: List[AuditLogEntry] = []
        self._lock = threading.Lock()
        
    def log(self, entry: AuditLogEntry) -> None:
        """Add entry to buffer; write the batch once the buffer is full."""
        with self._lock:
            self._buffer.append(entry)
            if len(self._buffer) < self.buffer_size:
                return
            entries, self._buffer = self._buffer, []
        # Backend I/O happens outside the lock so other callers are not blocked
        self._write(entries)

    def flush(self) -> None:
        """Write any buffered entries now (call this on shutdown)."""
        with self._lock:
            entries, self._buffer = self._buffer, []
        self._write(entries)
                
    def _write(self, entries: List[AuditLogEntry]) -> None:
        """Write a batch to all backends."""
        if not entries:
            return
        for backend in self.backends:
            try:
                backend.write_batch(entries)
            except Exception as e:
                # Report to stderr; never let an audit failure crash the agent
                print(f"Audit backend {backend.__class__.__name__} failed: {e}", file=sys.stderr)

    def log_llm_call(
        self,
        correlation_id: str,
        span_id: str,
        parent_span_id: Optional[str],
        agent_name: str,
        model_name: str,
        messages: List[Dict],
        response: str,
        metrics: LLMCallMetrics,
        user_context: Dict[str, Any],
    ) -> None:
        """Convenience method for logging LLM calls with full context."""
        entry = AuditLogEntry(
            entry_id=metrics.call_id,
            correlation_id=correlation_id,
            span_id=span_id,
            parent_span_id=parent_span_id,
            timestamp=metrics.timestamp,
            duration_ms=metrics.latency_ms,
            event_type=EventType.LLM_RETRY if metrics.is_retry else EventType.LLM_CALL,
            agent_name=agent_name,
            user_id=user_context.get("user_id"),
            session_id=user_context.get("session_id"),
            client_ip=user_context.get("client_ip"),
            model_name=model_name,
            prompt_response=PromptResponsePair.create(messages, response),
            input_tokens=metrics.input_tokens,
            output_tokens=metrics.output_tokens,
            cost_usd=metrics.cost_usd,
            is_retry=metrics.is_retry,
            retry_attempt=metrics.retry_count,
        )
        self.log(entry)
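`AuditLogger` only assumes that each backend has a `write_batch(entries)` method; the `AuditBackend` type it names is never defined. A minimal sketch of that interface as a `typing.Protocol`, plus a hypothetical append-only JSON Lines backend that relies only on each entry's `to_json()` (which `AuditLogEntry` provides). It is useful for local development or as a last-resort fallback next to PostgreSQL.

```python
import os
import threading
from typing import Protocol, Sequence


class SupportsToJson(Protocol):
    def to_json(self) -> str: ...


class AuditBackend(Protocol):
    def write_batch(self, entries: Sequence[SupportsToJson]) -> None: ...


class JsonlFileBackend:
    """Hypothetical backend: appends one JSON document per line to a local file."""

    def __init__(self, path: str):
        self.path = path
        self._lock = threading.Lock()

    def write_batch(self, entries: Sequence[SupportsToJson]) -> None:
        # Serialize first so a bad entry fails before anything is written
        payload = "".join(entry.to_json() + "\n" for entry in entries)
        with self._lock, open(self.path, "a", encoding="utf-8") as f:
            f.write(payload)
            f.flush()
            # fsync so the batch reaches disk even if the machine crashes right after
            os.fsync(f.fileno())
```

Because `AuditBackend` is a structural protocol, a PostgreSQL backend only needs the same `write_batch` method; no inheritance is required.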

Now create the SQL schema for PostgreSQL persistence:

-- observability/schema.sql
-- Audit log tables with proper indexing for compliance queries

-- Prompt/response content lives in a separate table because it can be large.
-- It is created first because audit_logs references it.
CREATE TABLE IF NOT EXISTS prompt_responses (
    id UUID PRIMARY KEY,
    prompt_messages JSONB NOT NULL,
    response_content TEXT NOT NULL,
    prompt_hash CHAR(16) NOT NULL,
    response_hash CHAR(16) NOT NULL,
    created_at TIMESTAMPTZ DEFAULT NOW()
);

CREATE TABLE IF NOT EXISTS audit_logs (
    entry_id UUID PRIMARY KEY,
    -- Text rather than UUID so fallback values such as 'unknown' still insert
    correlation_id VARCHAR(64) NOT NULL,
    span_id VARCHAR(64) NOT NULL,
    parent_span_id VARCHAR(64),
    
    -- Timing with microsecond precision
    timestamp TIMESTAMPTZ NOT NULL,
    duration_ms DECIMAL(10, 2),
    
    -- Event classification
    event_type VARCHAR(50) NOT NULL,
    agent_name VARCHAR(255) NOT NULL,
    
    -- User context for compliance
    user_id VARCHAR(255),
    session_id VARCHAR(255),
    client_ip INET,
    
    -- LLM call details
    model_name VARCHAR(100),
    prompt_hash CHAR(16),
    response_hash CHAR(16),
    input_tokens INTEGER,
    output_tokens INTEGER,
    cost_usd DECIMAL(10, 6),
    
    -- Full prompt/response stored separately for space efficiency
    prompt_response_id UUID REFERENCES prompt_responses(id),
    
    -- Retry tracking
    is_retry BOOLEAN DEFAULT FALSE,
    retry_attempt INTEGER DEFAULT 0,
    original_model VARCHAR(100),
    fallback_reason TEXT,
    
    -- Tool tracking
    tool_name VARCHAR(255),
    tool_input JSONB,
    tool_output TEXT,
    
    -- Error tracking
    error_type VARCHAR(255),
    error_message TEXT,
    error_stack TEXT,
    
    -- Metadata
    created_at TIMESTAMPTZ DEFAULT NOW()
);

-- Indexes for common query patterns
CREATE INDEX IF NOT EXISTS idx_audit_correlation ON audit_logs(correlation_id);
CREATE INDEX IF NOT EXISTS idx_audit_user_time ON audit_logs(user_id, timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_audit_event_type ON audit_logs(event_type, timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_audit_model_cost ON audit_logs(model_name, cost_usd DESC);
CREATE INDEX IF NOT EXISTS idx_audit_errors ON audit_logs(error_type) WHERE error_type IS NOT NULL;

-- Partitioning for time-series data (PostgreSQL 12+)
-- Uncomment and modify retention period as needed. A primary key on a
-- partitioned table must include the partition key, so copy only column
-- defaults from audit_logs and declare the key explicitly.
/*
CREATE TABLE audit_logs_partitioned (
    LIKE audit_logs INCLUDING DEFAULTS,
    PRIMARY KEY (entry_id, timestamp)
) PARTITION BY RANGE (timestamp);

CREATE TABLE audit_logs_y2024m01 PARTITION OF audit_logs_partitioned
    FOR VALUES FROM ('2024-01-01') TO ('2024-02-01');
*/

-- View for cost analysis
CREATE OR REPLACE VIEW daily_cost_summary AS
SELECT 
    DATE_TRUNC('day', timestamp) as day,
    user_id,
    model_name,
    COUNT(*) as call_count,
    SUM(input_tokens) as total_input_tokens,
    SUM(output_tokens) as total_output_tokens,
    SUM(cost_usd) as total_cost,
    AVG(duration_ms) as avg_latency_ms,
    PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY duration_ms) as p95_latency_ms
FROM audit_logs
WHERE event_type IN ('llm_call', 'llm_retry')
GROUP BY DATE_TRUNC('day', timestamp), user_id, model_name;

💡 Consider partitioning audit logs by month. Time-bounded queries then scan only the relevant partitions, and you can archive or drop an old month by detaching its partition instead of running a large DELETE, which simplifies compliance retention.
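Monthly partitions have to exist before rows for that month arrive. A hypothetical helper that generates the `PARTITION OF` statements for upcoming months, following the naming pattern of the commented-out example in the schema (`audit_logs_y2024m01`); you would run its output from a scheduled job ahead of each month. It only builds SQL strings, so table names must come from trusted code, never from user input.

```python
from datetime import date
from typing import List


def month_partition_ddl(parent: str, prefix: str, year: int, month: int) -> str:
    """DDL for one monthly range partition; the upper bound is exclusive."""
    start = date(year, month, 1)
    # First day of the following month (rolls over into January)
    end = date(year + 1, 1, 1) if month == 12 else date(year, month + 1, 1)
    name = f"{prefix}_y{start.year}m{start.month:02d}"
    return (
        f"CREATE TABLE IF NOT EXISTS {name} PARTITION OF {parent}\n"
        f"    FOR VALUES FROM ('{start.isoformat()}') TO ('{end.isoformat()}');"
    )


def upcoming_partitions(parent: str, prefix: str, first: date, months: int) -> List[str]:
    """DDL for `months` consecutive monthly partitions, starting at first's month."""
    statements = []
    year, month = first.year, first.month
    for _ in range(months):
        statements.append(month_partition_ddl(parent, prefix, year, month))
        year, month = (year + 1, 1) if month == 12 else (year, month + 1)
    return statements


for stmt in upcoming_partitions("audit_logs_partitioned", "audit_logs", date(2024, 11, 1), 3):
    print(stmt)
```

`IF NOT EXISTS` makes the job safe to re-run, so scheduling it daily with a few months of look-ahead is harmless.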

Production Configuration

Production deployments require careful attention to configuration management, secrets handling, and environment-specific settings. Here’s a comprehensive configuration approach that scales across environments.

Environment-Based Configuration

# config/observability.yaml
# Production-ready configuration for ADK observability pipeline

environments:
  production:
    logging:
      level: INFO
      format: json
      include_request_body: false  # PII protection
      include_response_body: false
      sample_rate: 1.0  # Log everything in prod
      
    metrics:
      enabled: true
      export_interval_seconds: 10
      histogram_buckets: [10, 50, 100, 250, 500, 1000, 2500, 5000, 10000]
      
    tracing:
      enabled: true
      sample_rate: 0.1  # Sample 10% of traces
      propagation: w3c
      
    storage:
      type: postgresql
      connection_pool_size: 20
      max_overflow: 10
      pool_timeout: 30
      
    alerting:
      cost_threshold_daily_usd: 500
      latency_p95_threshold_ms: 3000
      error_rate_threshold_percent: 5
      
  staging:
    logging:
      level: DEBUG
      format: json
      include_request_body: true
      include_response_body: true
      sample_rate: 1.0
      
    metrics:
      enabled: true
      export_interval_seconds: 30
      
    tracing:
      enabled: true
      sample_rate: 1.0  # Trace everything in staging
      
    storage:
      type: postgresql
      connection_pool_size: 5
      
  development:
    logging:
      level: DEBUG
      format: pretty
      include_request_body: true
      include_response_body: true
      
    metrics:
      enabled: false
      
    tracing:
      enabled: true
      sample_rate: 1.0
      
    storage:
      type: sqlite
      path: ./observability.db

Configuration Loader with Validation

# src/config/loader.py
from pydantic import BaseModel, Field, validator
from typing import Literal, Optional
import yaml
import os

class LoggingConfig(BaseModel):
    level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"
    format: Literal["json", "pretty"] = "json"
    include_request_body: bool = False
    include_response_body: bool = False
    sample_rate: float = Field(ge=0.0, le=1.0, default=1.0)

class MetricsConfig(BaseModel):
    enabled: bool = True
    export_interval_seconds: int = Field(ge=1, le=300, default=10)
    histogram_buckets: list[int] = [10, 50, 100, 250, 500, 1000, 2500, 5000]

class TracingConfig(BaseModel):
    enabled: bool = True
    sample_rate: float = Field(ge=0.0, le=1.0, default=0.1)
    propagation: Literal["w3c", "b3", "jaeger"] = "w3c"

class StorageConfig(BaseModel):
    type: Literal["postgresql", "sqlite", "bigquery"] = "postgresql"
    connection_pool_size: int = Field(ge=1, le=100, default=20)
    max_overflow: int = Field(ge=0, le=50, default=10)
    pool_timeout: int = Field(ge=1, le=120, default=30)
    
    # Loaded from environment variables for security
    connection_string: Optional[str] = None
    
    @validator("connection_string", pre=True, always=True)
    def load_connection_string(cls, v):
        # Never store credentials in config files
        return os.environ.get("DATABASE_URL", v)

class AlertingConfig(BaseModel):
    cost_threshold_daily_usd: float = Field(ge=0, default=500)
    latency_p95_threshold_ms: int = Field(ge=0, default=3000)
    error_rate_threshold_percent: float = Field(ge=0, le=100, default=5)

class ObservabilityConfig(BaseModel):
    logging: LoggingConfig = LoggingConfig()
    metrics: MetricsConfig = MetricsConfig()
    tracing: TracingConfig = TracingConfig()
    storage: StorageConfig = StorageConfig()
    alerting: AlertingConfig = AlertingConfig()

def load_config(environment: str = None) -> ObservabilityConfig:
    """Load configuration for the specified environment."""
    env = environment or os.environ.get("ENVIRONMENT", "development")
    
    config_path = os.environ.get(
        "OBSERVABILITY_CONFIG_PATH",
        "config/observability.yaml"
    )
    
    with open(config_path) as f:
        raw_config = yaml.safe_load(f)
    
    env_config = raw_config.get("environments", {}).get(env, {})
    
    if not env_config:
        raise ValueError(f"No configuration found for environment: {env}")
    
    return ObservabilityConfig(**env_config)
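`load_config` expects the YAML file to hold one block per environment under a top-level `environments` key. The sketch below mirrors that lookup on an already-parsed dictionary (what `yaml.safe_load` would return) so the expected shape is easy to see. The environment names and values are illustrative assumptions, and `select_environment` is a hypothetical helper, not part of the loader above.

```python
import os
from typing import Optional

# Illustrative shape of config/observability.yaml after yaml.safe_load (values are assumptions)
RAW_CONFIG = {
    "environments": {
        "development": {
            "logging": {"level": "DEBUG", "format": "pretty"},
            "tracing": {"sample_rate": 1.0},
        },
        "production": {
            "logging": {"level": "INFO", "format": "json", "sample_rate": 0.25},
            "tracing": {"sample_rate": 0.05},
            "alerting": {"cost_threshold_daily_usd": 1000},
        },
    }
}


def select_environment(raw_config: dict, environment: Optional[str] = None) -> dict:
    """Hypothetical helper mirroring load_config's environment lookup."""
    env = environment or os.environ.get("ENVIRONMENT", "development")
    env_config = raw_config.get("environments", {}).get(env, {})
    if not env_config:
        raise ValueError(f"No configuration found for environment: {env}")
    return env_config


prod = select_environment(RAW_CONFIG, "production")
print(prod["logging"]["format"])  # json
```

Sections an environment block leaves out (such as `metrics` above) fall back to the model defaults once the dictionary is passed to `ObservabilityConfig(**env_config)`.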

Production Callback Manager Assembly

# src/observability/factory.py
import logging
import os
from typing import Optional

from google.adk.agents import Agent
from src.config.loader import load_config, ObservabilityConfig
from src.callbacks.cost_tracker import CostTrackingCallback
from src.callbacks.latency_monitor import LatencyMonitorCallback
from src.callbacks.audit_logger import AuditLogCallback
from src.storage.async_writer import AsyncBatchWriter
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
import structlog

class ObservabilityFactory:
    """Factory for assembling production observability components."""
    
    def __init__(self, config: Optional[ObservabilityConfig] = None):
        self.config = config or load_config()
        self._setup_tracing()
        self._setup_logging()
    
    def _setup_tracing(self):
        """Configure OpenTelemetry tracing."""
        if not self.config.tracing.enabled:
            return
        
        # Honor tracing.sample_rate; ParentBased keeps child spans consistent
        # with their parent's sampling decision
        provider = TracerProvider(
            sampler=ParentBasedTraceIdRatio(self.config.tracing.sample_rate)
        )
        
        # Export to your observability backend (Jaeger, Tempo, etc.)
        otlp_exporter = OTLPSpanExporter(
            endpoint=os.environ.get("OTLP_ENDPOINT", "localhost:4317"),
            insecure=os.environ.get("ENVIRONMENT") != "production"
        )
        
        provider.add_span_processor(BatchSpanProcessor(otlp_exporter))
        trace.set_tracer_provider(provider)
    
    def _setup_logging(self):
        """Configure structured logging."""
        processors = [
            structlog.contextvars.merge_contextvars,
            structlog.processors.add_log_level,
            structlog.processors.TimeStamper(fmt="iso"),
        ]
        
        if self.config.logging.format == "json":
            processors.append(structlog.processors.JSONRenderer())
        else:
            processors.append(structlog.dev.ConsoleRenderer())
        
        structlog.configure(
            processors=processors,
            # make_filtering_bound_logger expects a numeric stdlib level such as logging.INFO
            wrapper_class=structlog.make_filtering_bound_logger(
                getattr(logging, self.config.logging.level)
            ),
            context_class=dict,
            cache_logger_on_first_use=True,
        )
    
    def create_callbacks(self) -> list:
        """Create the full callback stack for production use."""
        callbacks = []
        
        # Async writer for non-blocking persistence
        writer = AsyncBatchWriter(
            connection_string=self.config.storage.connection_string,
            batch_size=100,
            flush_interval=5.0
        )
        
        # Cost tracking - always enabled
        cost_callback = CostTrackingCallback(
            alert_threshold_usd=self.config.alerting.cost_threshold_daily_usd,
            writer=writer
        )
        callbacks.append(cost_callback)
        
        # Latency monitoring
        latency_callback = LatencyMonitorCallback(
            p95_threshold_ms=self.config.alerting.latency_p95_threshold_ms,
            histogram_buckets=self.config.metrics.histogram_buckets
        )
        callbacks.append(latency_callback)
        
        # Audit logging with PII controls
        audit_callback = AuditLogCallback(
            include_request_body=self.config.logging.include_request_body,
            include_response_body=self.config.logging.include_response_body,
            sample_rate=self.config.logging.sample_rate,
            writer=writer
        )
        callbacks.append(audit_callback)
        
        return callbacks
    
    def instrument_agent(self, agent: Agent) -> Agent:
        """Add full observability instrumentation to an agent."""
        callbacks = self.create_callbacks()
        
        for callback in callbacks:
            # register_callback is assumed to be an adapter defined elsewhere in this
            # codebase; ADK agents themselves expose hooks as attributes such as
            # before_model_callback and after_model_callback.
            agent.register_callback(callback)
        
        return agent


# Usage in application startup
def create_production_agent():
    """Create a fully instrumented production agent."""
    factory = ObservabilityFactory()
    
    agent = Agent(
        name="assistant",  # ADK agents require a name
        model="gemini-2.0-flash",
        instruction="You are a helpful assistant.",
        tools=[...]
    )
    
    return factory.instrument_agent(agent)
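In ADK, hooks are attributes on the agent (for example `before_model_callback`), and a `before_model_callback` that returns a response object makes ADK skip the model call. If you attach one callable per hook, a small fan-out wrapper lets several observers share it without one failing observer affecting the others. The sketch below is hypothetical (`fan_out` is not an ADK function), handles synchronous callables only, and uses plain functions so it runs without ADK installed.

```python
import logging
from typing import Any, Callable, Optional

logger = logging.getLogger("observability.fanout")


def fan_out(*callbacks: Callable[..., Optional[Any]]) -> Callable[..., Optional[Any]]:
    """Hypothetical helper: run several observer callbacks behind one hook.

    Each callback runs inside its own try/except so a failing observer cannot
    break the others. The first non-None return value is handed back to the
    caller, mirroring hooks where a non-None result overrides default behavior.
    """
    def combined(*args: Any, **kwargs: Any) -> Optional[Any]:
        result = None
        for callback in callbacks:
            try:
                value = callback(*args, **kwargs)
            except Exception:  # observers must never take down the agent
                logger.exception("observer %s failed", getattr(callback, "__name__", callback))
                continue
            if result is None and value is not None:
                result = value
        return result

    return combined


calls = []
hook = fan_out(
    lambda **kw: calls.append(("cost", kw["model"])),
    lambda **kw: 1 / 0,  # a broken observer: logged, then skipped
    lambda **kw: calls.append(("latency", kw["model"])),
)
print(hook(model="gemini-2.0-flash"), calls)
# None [('cost', 'gemini-2.0-flash'), ('latency', 'gemini-2.0-flash')]
```

All observers still run after one returns an override value; if you want the override to stop the chain, return early instead.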

Common Mistakes and Troubleshooting

After deploying observability pipelines across dozens of production systems, I’ve catalogued the most frequent failure modes. Here’s how to avoid them.

Mistake 1: Synchronous Database Writes in Callbacks

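import threading
import time

# BaseCallback is the callback base class used throughout this article.

# ❌ WRONG: This blocks the agent and adds latency
class BadAuditCallback(BaseCallback):
    def on_llm_end(self, response, **kwargs):
        # This database write blocks the response until the INSERT completes!
        self.db.execute(
            "INSERT INTO audit_logs VALUES (...)",
            params
        )

# ✅ CORRECT: Buffer in memory and flush in batches from a background thread
class GoodAuditCallback(BaseCallback):
    def __init__(self, db, flush_interval: float = 5.0):
        self.db = db  # any client exposing execute_batch(rows)
        self._flush_interval = flush_interval
        self._buffer = []
        self._buffer_lock = threading.Lock()
        self._start_flush_thread()
    
    def on_llm_end(self, response, **kwargs):
        # Non-blocking append to buffer
        with self._buffer_lock:
            self._buffer.append(self._format_log(response))
    
    def _start_flush_thread(self):
        threading.Thread(target=self._flush_buffer, daemon=True).start()
    
    def _flush_buffer(self):
        """Background thread flushes buffer periodically."""
        while True:
            time.sleep(self._flush_interval)
            batch = []  # defined up front so an empty buffer can't raise UnboundLocalError
            with self._buffer_lock:
                if self._buffer:
                    batch = self._buffer.copy()
                    self._buffer.clear()
            
            if batch:
                # One bulk insert is much cheaper than many single-row inserts
                self.db.execute_batch(batch)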
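One gap in the daemon-thread approach: when the process exits, whatever is still in the buffer is lost. Here is a minimal sketch of a flusher that can be stopped and drains the buffer one last time. It waits on a `threading.Event` instead of `time.sleep`, so shutdown doesn't have to sit out a full interval. `BatchFlusher` and the `sink` callable are hypothetical names for illustration.

```python
import threading
from typing import Callable


class BatchFlusher:
    """Hypothetical buffered writer: append() is cheap, a background thread flushes."""

    def __init__(self, sink: Callable[[list], None], interval: float = 5.0):
        self._sink = sink
        self._interval = interval
        self._buffer: list = []
        self._lock = threading.Lock()
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def append(self, record) -> None:
        with self._lock:
            self._buffer.append(record)

    def _drain(self) -> None:
        # Swap the buffer under the lock, write outside it
        with self._lock:
            batch, self._buffer = self._buffer, []
        if batch:
            self._sink(batch)

    def _run(self) -> None:
        # Event.wait returns True as soon as stop() is called, ending the loop early
        while not self._stop.wait(self._interval):
            self._drain()
        self._drain()  # final flush so records buffered before stop() are not lost

    def stop(self, timeout: float = 10.0) -> None:
        self._stop.set()
        self._thread.join(timeout)


written = []
flusher = BatchFlusher(sink=written.extend, interval=60.0)
for i in range(3):
    flusher.append({"request": i})
flusher.stop()
print(written)  # [{'request': 0}, {'request': 1}, {'request': 2}]
```

In a service you would typically call `stop()` from your shutdown hook (or register it with `atexit`) so the last batch is written before the process exits.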

Mistake 2: Not Handling Callback Exceptions

import threading
import time

import structlog

# ❌ WRONG: Exception in callback crashes the agent
class FragileCallback(BaseCallback):
    def on_llm_end(self, response, **kwargs):
        # If the metrics server is down, this exception propagates into the agent run
        self.metrics_client.record(response.usage)

# ✅ CORRECT: Isolate callback failures
class RobustCallback(BaseCallback):
    def __init__(self, metrics_client):
        self.metrics_client = metrics_client
        self.logger = structlog.get_logger()
        self._consecutive_failures = 0
        self._circuit_open = False
    
    def on_llm_end(self, response, **kwargs):
        if self._circuit_open:
            return  # Skip when circuit breaker is open
        
        try:
            self.metrics_client.record(response.usage)
            self._consecutive_failures = 0
        except Exception as e:
            self._consecutive_failures += 1
            self.logger.warning(
                "callback_failed",
                error=str(e),
                consecutive_failures=self._consecutive_failures
            )
            
            # Open circuit after 5 consecutive failures
            if self._consecutive_failures >= 5:
                self._circuit_open = True
                self._schedule_circuit_reset()
    
    def _schedule_circuit_reset(self):
        """Reset circuit breaker after cooldown."""
        def reset():
            time.sleep(60)
            self._circuit_open = False
            self._consecutive_failures = 0
        
        threading.Thread(target=reset, daemon=True).start()

Mistake 3: Missing Context Propagation

sequenceDiagram
    participant Client
    participant Agent
    participant Tool
    participant LLM
    participant Observability
    
    Client->>Agent: Request (trace_id: abc123)
    
    Note over Agent,Observability: ❌ Without context propagation
    Agent->>Tool: Execute (no trace_id)
    Tool->>Observability: Log (orphaned)
    Agent->>LLM: Call (no trace_id)
    LLM->>Observability: Log (orphaned)
    
    Note over Agent,Observability: ✅ With context propagation
    Agent->>Tool: Execute (trace_id: abc123)
    Tool->>Observability: Log (trace_id: abc123)
    Agent->>LLM: Call (trace_id: abc123)
    LLM->>Observability: Log (trace_id: abc123)
    
    Observability-->>Client: Complete trace visible
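# ✅ CORRECT: Propagate context through the entire call chain
import contextvars
import uuid
from typing import Optional

from opentelemetry import trace

# Context variable for request-scoped data. The default is None rather than {}:
# a mutable default would be one dict shared by every request.
request_context: contextvars.ContextVar[Optional[dict]] = contextvars.ContextVar(
    'request_context',
    default=None
)


def _copy_request_context() -> dict:
    # Copy before mutating so each request works on its own dict
    return dict(request_context.get() or {})


class ContextAwareCallback(BaseCallback):
    def __init__(self):
        self.tracer = trace.get_tracer(__name__)
    
    def on_agent_start(self, agent_input, **kwargs):
        # Extract or create trace context
        ctx = _copy_request_context()
        parent_span = ctx.get('parent_span')
        
        span = self.tracer.start_span(
            "agent.execution",
            # context=None falls back to the current OpenTelemetry context
            context=trace.set_span_in_context(parent_span) if parent_span else None
        )
        
        # Store span for child operations
        ctx['current_span'] = span
        ctx['session_id'] = kwargs.get('session_id', str(uuid.uuid4()))
        request_context.set(ctx)
    
    def on_agent_end(self, agent_output, **kwargs):
        # Assumes BaseCallback exposes an end hook matching on_agent_start;
        # without it the agent span is never closed or exported
        span = (request_context.get() or {}).get('current_span')
        if span:
            span.end()
    
    def on_tool_start(self, tool_name, tool_input, **kwargs):
        ctx = _copy_request_context()
        parent_span = ctx.get('current_span')
        
        # Create child span linked to parent
        tool_span = self.tracer.start_span(
            f"tool.{tool_name}",
            context=trace.set_span_in_context(parent_span) if parent_span else None
        )
        tool_span.set_attribute("tool.input", str(tool_input)[:1000])
        
        ctx['tool_span'] = tool_span
        request_context.set(ctx)
    
    def on_tool_end(self, tool_output, **kwargs):
        tool_span = (request_context.get() or {}).get('tool_span')
        
        if tool_span:
            tool_span.set_attribute("tool.output_length", len(str(tool_output)))
            tool_span.end()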

⚠️ Warning: Without proper context propagation, you’ll end up with fragmented traces that are very hard to debug. A user complaint about slow responses becomes a needle-in-a-haystack search across millions of unrelated log entries.
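The `contextvars` pattern works because each asyncio task runs in a copy of the context that was current when the task was created, so a value set inside one request's task is invisible to concurrent requests. Here is a small stdlib-only demonstration; `trace_id_var`, `handle_request`, and `log_event` are hypothetical stand-ins for the callback machinery.

```python
import asyncio
import contextvars

trace_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("trace_id", default="none")


def log_event(event: str) -> tuple[str, str]:
    # Stands in for a callback: reads the trace ID without it being passed explicitly
    return (trace_id_var.get(), event)


async def handle_request(request_id: str) -> list[tuple[str, str]]:
    trace_id_var.set(f"trace-{request_id}")  # visible only inside this task's context
    events = [log_event("agent.start")]
    await asyncio.sleep(0)  # yield so the two requests interleave
    events.append(log_event("tool.call"))
    await asyncio.sleep(0)
    events.append(log_event("llm.call"))
    return events


async def main() -> list[list[tuple[str, str]]]:
    # gather wraps each coroutine in its own Task, each with a copied context
    return await asyncio.gather(handle_request("a"), handle_request("b"))


results = asyncio.run(main())
for events in results:
    print(events)
# [('trace-a', 'agent.start'), ('trace-a', 'tool.call'), ('trace-a', 'llm.call')]
# [('trace-b', 'agent.start'), ('trace-b', 'tool.call'), ('trace-b', 'llm.call')]
```

The isolation applies to the variable's binding, not to the object it points at: if every task mutates the same dict (for example a `ContextVar` created with `default={}`), the data is shared again, so copy before you mutate.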

Troubleshooting Checklist

Symptom | Likely cause | Solution
--- | --- | ---
Missing logs for some requests | Sample rate too low, or an exception inside a callback | Check the sample_rate config; add exception handling
High agent latency | Synchronous callback operations | Switch to async/batched writes
Incomplete traces | Context not propagated | Implement the context variables pattern
Cost metrics don’t match billing | Token counting mismatch | Use the provider’s token counts; account for retries
Database connection exhaustion | No connection pooling | Configure connection_pool_size and max_overflow
Memory growth over time | Buffer not flushing | Check that the flush thread is running; add buffer size limits
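The "account for retries" fix is easy to get wrong. Every attempt that reached the model is billed, including ones whose output you threw away. Here is a minimal sketch that assumes per-million-token pricing. The prices and the `Attempt` record are placeholders, not real rates for any model, and `Decimal` keeps the arithmetic exact.

```python
from dataclasses import dataclass
from decimal import Decimal

# Placeholder prices in USD per 1M tokens (assumptions, not real pricing)
PRICE_PER_M = {"input": Decimal("0.10"), "output": Decimal("0.40")}


@dataclass
class Attempt:
    input_tokens: int
    output_tokens: int
    succeeded: bool


def attempt_cost(a: Attempt) -> Decimal:
    return (
        Decimal(a.input_tokens) * PRICE_PER_M["input"]
        + Decimal(a.output_tokens) * PRICE_PER_M["output"]
    ) / Decimal(1_000_000)


def request_cost(attempts: list[Attempt]) -> Decimal:
    # Bill every attempt, not just the one whose output reached the user
    return sum((attempt_cost(a) for a in attempts), Decimal(0))


attempts = [
    Attempt(input_tokens=12_000, output_tokens=800, succeeded=False),  # timed out mid-stream
    Attempt(input_tokens=12_000, output_tokens=1_500, succeeded=True),
]
successful_only = sum((attempt_cost(a) for a in attempts if a.succeeded), Decimal(0))
print(request_cost(attempts))  # 0.00332
print(successful_only)         # 0.0018
```

Counting only the successful attempt would under-report this request by about 46%. That is exactly the kind of gap that shows up as "cost metrics don’t match billing".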

Performance and Scalability

Benchmarking Callback Overhead

Before deploying, measure the actual overhead your callbacks add:

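# scripts/benchmark_callbacks.py
import asyncio
import statistics
import time

from google.adk.agents import Agent
from google.adk.runners import InMemoryRunner
from google.genai import types

APP_NAME = "callback-benchmark"
USER_ID = "benchmark-user"


async def timed_run(runner: InMemoryRunner, prompt: str) -> float:
    """Run one request in a fresh session and return its latency in milliseconds."""
    # A fresh session keeps conversation history from growing across iterations
    session = await runner.session_service.create_session(
        app_name=APP_NAME, user_id=USER_ID
    )
    message = types.Content(role="user", parts=[types.Part(text=prompt)])

    start = time.perf_counter()
    async for _event in runner.run_async(
        user_id=USER_ID, session_id=session.id, new_message=message
    ):
        pass  # drain the event stream so the whole invocation is timed
    return (time.perf_counter() - start) * 1000


async def benchmark_callback_overhead(
    agent: Agent,
    num_requests: int = 100,
    warmup_requests: int = 10
):
    """Measure the latency overhead introduced by callbacks."""
    # Same model and instruction, but no callbacks attached
    agent_no_callbacks = Agent(
        name=f"{agent.name}_baseline",
        model=agent.model,
        instruction=agent.instruction
    )
    instrumented_runner = InMemoryRunner(agent=agent, app_name=APP_NAME)
    baseline_runner = InMemoryRunner(agent=agent_no_callbacks, app_name=APP_NAME)

    # Warm up both paths so connection setup doesn't skew either side
    for _ in range(warmup_requests):
        await timed_run(baseline_runner, "Say hello")
        await timed_run(instrumented_runner, "Say hello")

    # Interleave the two so drift in model latency affects both sides equally
    baseline_latencies = []
    instrumented_latencies = []
    for _ in range(num_requests):
        baseline_latencies.append(await timed_run(baseline_runner, "Say hello"))
        instrumented_latencies.append(await timed_run(instrumented_runner, "Say hello"))
    
    baseline_p50 = statistics.median(baseline_latencies)
    baseline_p99 = statistics.quantiles(baseline_latencies, n=100)[98]
    
    instrumented_p50 = statistics.median(instrumented_latencies)
    instrumented_p99 = statistics.quantiles(instrumented_latencies, n=100)[98]
    
    print(f"Baseline P50: {baseline_p50:.2f}ms, P99: {baseline_p99:.2f}ms")
    print(f"Instrumented P50: {instrumented_p50:.2f}ms, P99: {instrumented_p99:.2f}ms")
    print(f"Overhead P50: {instrumented_p50 - baseline_p50:.2f}ms")
    print(f"Overhead P99: {instrumented_p99 - baseline_p99:.2f}ms")
    
    # Fail if overhead exceeds threshold. Model latency is noisy, so with few
    # requests these budgets can be tripped (or hidden) by chance.
    assert instrumented_p50 - baseline_p50 < 5, "P50 overhead exceeds 5ms!"
    assert instrumented_p99 - baseline_p99 < 20, "P99 overhead exceeds 20ms!"


if __name__ == "__main__":
    from src.observability.factory import create_production_agent
    
    agent = create_production_agent()
    asyncio.run(benchmark_callback_overhead(agent))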

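📝 Note: Well-implemented callbacks should add less than 5ms of P50 latency. Because a single LLM round-trip can vary by far more than that, collect enough samples before trusting a small difference. If you’re seeing more, profile your callback code: the culprit is usually synchronous I/O or excessive JSON serialization.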
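Timing the callback on its own, with no model in the loop, separates its cost from model latency. Here is a minimal sketch using only the standard library. `record_metrics` is a hypothetical stand-in for your callback body, and `fake_response` stands in for the response object it receives.

```python
import json
import statistics
import time
from typing import Callable


def time_callback(callback: Callable[[dict], None], payload: dict, iterations: int = 10_000) -> dict:
    """Call `callback` repeatedly and return per-call p50/p99 latency in microseconds."""
    samples = []
    for _ in range(iterations):
        start = time.perf_counter_ns()
        callback(payload)
        samples.append((time.perf_counter_ns() - start) / 1_000)
    cuts = statistics.quantiles(samples, n=100)
    return {"p50_us": statistics.median(samples), "p99_us": cuts[98]}


buffer: list[str] = []


def record_metrics(response: dict) -> None:
    # Hypothetical callback body: serialize and buffer in memory, no I/O
    buffer.append(json.dumps({
        "model": response["model"],
        "tokens": response["usage"]["total_tokens"],
    }))


fake_response = {"model": "gemini-2.0-flash", "usage": {"total_tokens": 1234}}
stats = time_callback(record_metrics, fake_response)
print(f"p50={stats['p50_us']:.1f}us p99={stats['p99_us']:.1f}us")
```

A buffering callback like this one typically measures in microseconds. If you add a `time.sleep(0.005)` to simulate a blocking write, p50 jumps past 5,000 microseconds, which is the kind of regression the end-to-end budget is meant to catch.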

Scaling to High Throughput

For applications handling thousands of requests per second:

# src/observability/high_throughput.py
import asyncio
from collections import deque
from dataclasses import dataclass
from typing import Deque
import aiomultiprocess

@dataclass
class MetricsBatch:
    """Batch of metrics for bulk processing."""
    timestamps: list[float]
    latencies: list[float]
    costs: list[float]
    model_names: list[str]

class HighThroughputCollector:
    """
    Lock-free metrics collection for high-throughput scenarios.
    Uses ring buffers and background aggregation.
    """
    
    def __init__(
        self,
        buffer_size: int = 10000,
        aggregation_interval: float = 1.0,
        num_workers: int = 4
    ):
        # Ring buffer for lock-free writes
        self._buffer: Deque[dict] = deque(maxlen=buffer_size)
        self._aggregation_interval = aggregation_interval
        self._num_workers = num_workers
        self._running = False
    
    def record(self, metric: dict):
        """
        Record a metric. This is O(1) and never blocks.
        Old metrics are dropped if buffer is full.
        """
        self._buffer.append(metric)
    
    async def start(self):
        """Start background aggregation workers."""
        self._running = True
        
        async with aiomultiprocess.Pool(processes=self._num_workers) as pool:
            while self._running:
                await asyncio.sleep(self._aggregation_interval)
                
                # Drain buffer
                batch = []
                while self._buffer:
                    try:
                        batch.append(self._buffer.popleft())
                    except IndexError:
                        break
                
                if batch:
                    # Parallel aggregation across workers. Round the chunk size up and
                    # never let it reach 0 (range() rejects a zero step when the batch
                    # is smaller than the worker count).
                    chunk_size = max(1, -(-len(batch) // self._num_workers))
                    chunks = [
                        batch[i:i + chunk_size]
                        for i in range(0, len(batch), chunk_size)
                    ]
                    
                    results = await pool.map(self._aggregate_chunk, chunks)
                    combined = self._combine_aggregations(results)
                    await self._emit_metrics(combined)
    
    def stop(self):
        """Ask the aggregation loop to exit after its current iteration."""
        self._running = False
    
    # aiomultiprocess runs coroutine functions in its worker processes, so this is async
    @staticmethod
    async def _aggregate_chunk(metrics: list[dict]) -> dict:
        """Aggregate a chunk of metrics (runs in worker process)."""
        aggregated = {
            'count': len(metrics),
            'total_cost': sum(m.get('cost', 0) for m in metrics),
            'latencies': [m.get('latency', 0) for m in metrics],
            'by_model': {}
        }
        
        for metric in metrics:
            model = metric.get('model', 'unknown')
            if model not in aggregated['by_model']:
                aggregated['by_model'][model] = {'count': 0, 'cost': 0}
            aggregated['by_model'][model]['count'] += 1
            aggregated['by_model'][model]['cost'] += metric.get('cost', 0)
        
        return aggregated
    
    def _combine_aggregations(self, results: list[dict]) -> dict:
        """Combine aggregations from multiple workers."""
        combined = {
            'count': sum(r['count'] for r in results),
            'total_cost': sum(r['total_cost'] for r in results),
            'latencies': [],
            'by_model': {}
        }
        
        for result in results:
            combined['latencies'].extend(result['latencies'])
            for model, stats in result['by_model'].items():
                if model not in combined['by_model']:
                    combined['by_model'][model] = {'count': 0, 'cost': 0}
                combined['by_model'][model]['count'] += stats['count']
                combined['by_model'][model]['cost'] += stats['cost']
        
        return combined
    
    async def _emit_metrics(self, aggregated: dict):
        """Emit aggregated metrics to monitoring backend."""
        # Prometheus, DataDog, CloudWatch, etc.
        pass
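Sending each chunk to another process means pickling it there and back. That can cost more than summing a few thousand floats. If profiling shows that, a single-process variant is simpler. The sketch below is a swapped-in alternative, not the collector above. It drains the same kind of `deque` and computes totals, per-model cost, and a p95 latency with the standard library; `drain`, `p95`, and `aggregate` are hypothetical helper names.

```python
import statistics
from collections import deque
from typing import Optional


def drain(buffer: deque) -> list[dict]:
    """Pop everything currently in the buffer (deque.popleft is thread-safe in CPython)."""
    batch = []
    while True:
        try:
            batch.append(buffer.popleft())
        except IndexError:
            return batch


def p95(values: list[float]) -> Optional[float]:
    if not values:
        return None
    if len(values) == 1:
        return values[0]
    return statistics.quantiles(values, n=20)[18]  # last of 19 cut points = 95th percentile


def aggregate(metrics: list[dict]) -> dict:
    by_model: dict[str, dict] = {}
    for m in metrics:
        stats = by_model.setdefault(m.get("model", "unknown"), {"count": 0, "cost": 0.0})
        stats["count"] += 1
        stats["cost"] += m.get("cost", 0.0)
    return {
        "count": len(metrics),
        "total_cost": sum(m.get("cost", 0.0) for m in metrics),
        "latency_p95": p95([m.get("latency", 0.0) for m in metrics]),
        "by_model": by_model,
    }


buf: deque = deque(maxlen=10_000)
for i in range(100):
    buf.append({"model": "flash" if i % 2 else "pro", "cost": 0.01, "latency": float(i)})
result = aggregate(drain(buf))
print(result["count"], result["by_model"]["pro"]["count"], len(buf))  # 100 50 0
```

Only reach for worker processes once a single core can no longer keep up with the aggregation itself.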

Resource Limits and Backpressure

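# kubernetes/deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: ai-agent-service
spec:
  replicas: 3
  # apps/v1 Deployments require a selector that matches the pod template labels
  selector:
    matchLabels:
      app: ai-agent-service
  template:
    metadata:
      labels:
        app: ai-agent-service
    spec:
      containers:
      - name: agent
        image: registry.example.com/ai-agent-service:latest  # placeholder image
        resources:
          requests:
            memory: "512Mi"
            cpu: "500m"
          limits:
            memory: "2Gi"
            cpu: "2000m"
        env:
        # The OBSERVABILITY_* variables are read by the application, not by Kubernetes.
        # Limit callback buffer to prevent OOM
        - name: OBSERVABILITY_BUFFER_MAX_SIZE
          value: "50000"
        - name: OBSERVABILITY_BUFFER_MAX_MEMORY_MB
          value: "256"
        # Drop metrics under extreme load rather than crash
        - name: OBSERVABILITY_BACKPRESSURE_STRATEGY
          value: "drop_oldest"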
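Those environment variables only help if the application reads them. Here is a minimal sketch of how a service might honor `OBSERVABILITY_BUFFER_MAX_SIZE` and `OBSERVABILITY_BACKPRESSURE_STRATEGY`. The `BoundedBuffer` class and the `drop_newest` strategy are assumptions added for comparison. The memory-based limit (`OBSERVABILITY_BUFFER_MAX_MEMORY_MB`) is left out because it needs a per-record size estimate.

```python
import os
from collections import deque


class BoundedBuffer:
    """Hypothetical bounded buffer that applies a backpressure strategy when full."""

    def __init__(self, max_size: int, strategy: str = "drop_oldest"):
        if strategy not in ("drop_oldest", "drop_newest"):
            raise ValueError(f"unknown backpressure strategy: {strategy}")
        self._items: deque = deque()
        self._max_size = max_size
        self._strategy = strategy
        self.dropped = 0  # export this as a metric so data loss is visible

    @classmethod
    def from_env(cls) -> "BoundedBuffer":
        return cls(
            max_size=int(os.environ.get("OBSERVABILITY_BUFFER_MAX_SIZE", "50000")),
            strategy=os.environ.get("OBSERVABILITY_BACKPRESSURE_STRATEGY", "drop_oldest"),
        )

    def append(self, item) -> None:
        if len(self._items) >= self._max_size:
            self.dropped += 1
            if self._strategy == "drop_newest":
                return  # keep what we have, discard the incoming record
            self._items.popleft()  # drop_oldest: make room for the new record
        self._items.append(item)

    def __len__(self) -> int:
        return len(self._items)

    def snapshot(self) -> list:
        return list(self._items)


oldest = BoundedBuffer(max_size=3, strategy="drop_oldest")
for i in range(5):
    oldest.append(i)
print(oldest.snapshot(), oldest.dropped)  # [2, 3, 4] 2
```

`drop_oldest` keeps the most recent picture of the system, which is usually what you want during an incident. `drop_newest` preserves the start of an event sequence instead. Either way, exporting `dropped` tells you when the limit is actually being hit.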

Conclusion and Next Steps

You now have the foundation for a production-grade observability pipeline that tracks every LLM call, measures costs down to the penny, and provides the latency insights needed to optimize agent performance.

Key takeaways:

  1. Use async, batched writes — Never let observability code block your agent’s critical path
  2. Propagate context religiously — Trace IDs that connect user requests to individual LLM calls are invaluable during incidents
  3. Build for failure — Circuit breakers and graceful degradation ensure your observability layer never takes down your application
  4. Measure the observer — Benchmark your callbacks and set overhead budgets

Recommended next steps:

  1. Add anomaly detection — Use the cost and latency data you’re now collecting to build ML models that detect unusual patterns before users complain
  2. Implement cost attribution — Extend the audit logs to include feature flags and A/B test variants so you can calculate ROI per feature
  3. Build self-healing capabilities — When the pipeline detects repeated failures or cost spikes, automatically switch to fallback models or rate limit specific users
  4. Create executive dashboards — The data you’re collecting enables powerful business intelligence: cost per customer segment, LLM ROI by use case, capacity planning forecasts

The observability pipeline you’ve built isn’t just about monitoring — it’s the foundation for continuous improvement of your AI agents. Every millisecond saved and every dollar tracked compounds into significant competitive advantage over time.

Additional Resources

Common Mistakes and Troubleshooting

Mistake 1: Blocking the Agent Event Loop

The most critical error I see in production deployments is performing synchronous I/O inside callbacks. In an asyncio application this blocks the event loop, stalling not just the current agent run but every other request served by the same loop.

import asyncio
import time
from typing import Optional

import aiohttp
import requests  # used only by the blocking example

# BaseCallback and LLMResponse are this article's callback base class and response type.

# ❌ WRONG: Blocking the event loop with synchronous calls
class BlockingCallback(BaseCallback):
    def on_llm_end(self, response: LLMResponse, **kwargs):
        # This blocks the agent while waiting for the database
        self.db_connection.execute(
            "INSERT INTO metrics (tokens, cost) VALUES (?, ?)",
            (response.token_count, response.cost)
        )
        # Synchronous HTTP call - blocks until complete
        requests.post("https://metrics.example.com/ingest", json=response.to_dict())


# ✅ CORRECT: Non-blocking async operations with buffering
class NonBlockingCallback(BaseCallback):
    def __init__(self, flush_threshold: int = 100):
        self._buffer: list[dict] = []
        self._buffer_lock = asyncio.Lock()
        self._flush_task: Optional[asyncio.Task] = None
        self._flush_threshold = flush_threshold
    
    async def on_llm_end(self, response: LLMResponse, **kwargs):
        # Quick append to in-memory buffer
        async with self._buffer_lock:
            self._buffer.append({
                "tokens": response.token_count,
                "cost": response.cost,
                "timestamp": time.time()
            })
            should_flush = len(self._buffer) >= self._flush_threshold
        
        # Schedule a flush if the buffer is large enough and none is running.
        # Pair this with a periodic or shutdown flush so a small tail isn't stranded.
        if should_flush and self._flush_task is None:
            self._flush_task = asyncio.create_task(self._flush_buffer())
    
    async def _flush_buffer(self):
        try:
            async with self._buffer_lock:
                to_flush = self._buffer.copy()
                self._buffer.clear()
            
            # Async HTTP call - doesn't block the agent
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    "https://metrics.example.com/ingest",
                    json=to_flush
                ) as resp:
                    resp.raise_for_status()
        finally:
            # Always clear the handle; otherwise one failed POST would stop all future
            # flushes (the failed batch itself is dropped, so requeue it if you need delivery)
            self._flush_task = None

⚠️ Warning: A single blocking call in a callback can add hundreds of milliseconds to every LLM interaction. At scale, this compounds into significant latency degradation and increased costs.
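When a callback must use a synchronous client (a database driver without async support, for example), `asyncio.to_thread` (Python 3.9+) moves the call onto a worker thread so the event loop keeps serving other work. In this stdlib-only demonstration, `blocking_write` is a hypothetical stand-in for the sync client, and a heartbeat coroutine shows that the loop stays responsive while the write runs.

```python
import asyncio
import time


def blocking_write(record: dict) -> str:
    # Hypothetical stand-in for a synchronous DB driver or HTTP client call
    time.sleep(0.2)
    return f"wrote {record['tokens']} tokens"


async def heartbeat(ticks: list[float], stop: asyncio.Event) -> None:
    # Appends a timestamp roughly every 10ms whenever the event loop is free
    while not stop.is_set():
        ticks.append(time.perf_counter())
        await asyncio.sleep(0.01)


async def main() -> tuple[str, int]:
    ticks: list[float] = []
    stop = asyncio.Event()
    beat = asyncio.create_task(heartbeat(ticks, stop))
    # The blocking call runs on a worker thread; the loop keeps running heartbeat()
    result = await asyncio.to_thread(blocking_write, {"tokens": 512})
    stop.set()
    await beat
    return result, len(ticks)


result, tick_count = asyncio.run(main())
print(result)  # wrote 512 tokens
print(f"heartbeat ticks during the write: {tick_count}")
```

If `main()` called `blocking_write(...)` directly instead, the loop would be frozen for the full 200ms and the heartbeat would record almost no ticks. That is the same stall every concurrent agent request would experience.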

Mistake 2: Missing Error Boundaries

An exception that escapes a callback propagates into the agent run and can abort the whole request. Always wrap callback logic in error boundaries.

import logging
import time

logger = logging.getLogger(__name__)

# ❌ WRONG: Unhandled exceptions crash the agent
class FragileCallback(BaseCallback):
    def on_tool_start(self, tool_name: str, **kwargs):
        # If the metrics service is down, this raises and aborts the agent run
        self.metrics_client.increment(f"tool.{tool_name}.calls")


# ✅ CORRECT: Graceful degradation with error boundaries
class ResilientCallback(BaseCallback):
    def __init__(self, metrics_client):
        self.metrics_client = metrics_client
        self._error_count = 0
        self._circuit_open = False
        self._last_error_time = 0.0
    
    def on_tool_start(self, tool_name: str, **kwargs):
        # Circuit breaker pattern - skip if too many recent failures
        if self._circuit_open:
            # monotonic() is unaffected by wall-clock adjustments
            if time.monotonic() - self._last_error_time > 60:  # 60s cooldown
                self._circuit_open = False
                self._error_count = 0
            else:
                return  # Skip silently during circuit break
        
        try:
            self.metrics_client.increment(f"tool.{tool_name}.calls")
            self._error_count = 0  # Reset on success
        except Exception as e:
            self._error_count += 1
            self._last_error_time = time.monotonic()
            
            # Open circuit after 5 consecutive failures
            if self._error_count >= 5:
                self._circuit_open = True
                logger.warning(
                    f"Circuit breaker opened for metrics callback: {e}"
                )
            else:
                logger.debug(f"Metrics callback failed (attempt {self._error_count}): {e}")
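The try/except and circuit-breaker logic above tends to be repeated in every hook. One way to factor it out is a decorator. The sketch below is a hypothetical helper, not part of ADK or the article's `BaseCallback`. It logs and swallows exceptions and opens a time-based circuit after repeated failures. The injectable `clock` exists only to make the behavior easy to test.

```python
import functools
import logging
import time
from typing import Any, Callable, Optional

logger = logging.getLogger("observability.callbacks")


def error_boundary(max_failures: int = 5, cooldown_s: float = 60.0,
                   clock: Callable[[], float] = time.monotonic):
    """Decorator: never let an observability hook raise; skip it while the circuit is open."""
    def decorator(fn: Callable[..., Any]) -> Callable[..., Optional[Any]]:
        state = {"failures": 0, "opened_at": None}

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Optional[Any]:
            if state["opened_at"] is not None:
                if clock() - state["opened_at"] < cooldown_s:
                    return None  # circuit open: skip silently
                state["opened_at"], state["failures"] = None, 0  # cooldown over: try again
            try:
                result = fn(*args, **kwargs)
            except Exception:
                state["failures"] += 1
                logger.debug("hook %s failed (%d in a row)", fn.__name__,
                             state["failures"], exc_info=True)
                if state["failures"] >= max_failures:
                    state["opened_at"] = clock()
                    logger.warning("circuit opened for hook %s", fn.__name__)
                return None
            state["failures"] = 0  # reset on success
            return result
        return wrapper
    return decorator


now = [0.0]
calls = []


@error_boundary(max_failures=2, cooldown_s=60.0, clock=lambda: now[0])
def record(tool_name: str) -> None:
    calls.append(tool_name)
    raise ConnectionError("metrics backend down")


for _ in range(4):
    record("search")  # never raises
print(len(calls))  # 2: the circuit opened after two failures
now[0] = 61.0
record("search")
print(len(calls))  # 3: the cooldown elapsed, so one retry went through
```

The breaker state lives in the closure, so decorating a method shares one breaker across all instances of that class. That usually suits a shared metrics backend, but keep it in mind if instances talk to different services.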

Mistake 3: Incorrect Cost Calculation for Streaming Responses

Streaming responses require special handling because tokens arrive incrementally, and some providers report a running total on each chunk rather than a per-chunk count.

import time
from dataclasses import dataclass

# StreamChunk is the article's chunk type, assumed to carry a running (cumulative) token total.

# ❌ WRONG: Counting tokens per chunk leads to massive overcounting
class BrokenStreamingCallback(BaseCallback):
    def on_llm_stream_chunk(self, chunk: StreamChunk, **kwargs):
        # Each chunk reports the CUMULATIVE total so far, so adding it on every
        # chunk counts the early tokens again and again
        self.total_tokens += chunk.token_count  # Overcount grows with the number of chunks!


@dataclass
class StreamSession:
    start_time: float
    previous_token_count: int = 0
    total_tokens: int = 0


# ✅ CORRECT: Track incremental tokens in streaming responses
class AccurateStreamingCallback(BaseCallback):
    def __init__(self):
        self._stream_sessions: dict[str, StreamSession] = {}
    
    def on_llm_stream_start(self, stream_id: str, **kwargs):
        self._stream_sessions[stream_id] = StreamSession(start_time=time.monotonic())
    
    def on_llm_stream_chunk(self, stream_id: str, chunk: StreamChunk, **kwargs):
        session = self._stream_sessions.get(stream_id)
        if not session:
            return
        
        # Calculate INCREMENTAL tokens (current - previous)
        incremental_tokens = chunk.cumulative_tokens - session.previous_token_count
        session.previous_token_count = chunk.cumulative_tokens
        session.total_tokens += incremental_tokens
    
    def on_llm_stream_end(self, stream_id: str, **kwargs):
        session = self._stream_sessions.pop(stream_id, None)
        if session:
            # Now we have accurate token count for the entire stream
            self._record_metrics(
                tokens=session.total_tokens,
                duration=time.monotonic() - session.start_time
            )
    
    def _record_metrics(self, tokens: int, duration: float):
        """Send the per-stream totals to your metrics backend."""
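To see the size of the error, take a stream whose chunks report running totals of 3, 7, 12 and 20 tokens. Summing the reported values gives 42. Summing the differences (3 + 4 + 5 + 8) gives the true 20. Below is a stdlib sketch of the conversion. `increments` is a hypothetical helper that also tracks a high-water mark, so a repeated or out-of-order total never subtracts tokens.

```python
def increments(cumulative: list[int]) -> list[int]:
    """Turn running token totals into per-chunk token counts."""
    counts = []
    high_water = 0
    for total in cumulative:
        # A repeated or smaller total adds nothing instead of subtracting tokens
        counts.append(max(0, total - high_water))
        high_water = max(high_water, total)
    return counts


reported = [3, 7, 12, 20]
print(sum(reported))              # 42 (naive sum of running totals)
print(increments(reported))       # [3, 4, 5, 8]
print(sum(increments(reported)))  # 20 (equals the final running total)
```

If your provider attaches a usage summary to the final chunk (Gemini responses expose `usage_metadata`, for example), reading that one value at stream end is simpler still.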

Troubleshooting Decision Flow

When your observability pipeline behaves unexpectedly, follow this diagnostic flow:

flowchart TD
    A[Metrics Missing or Incorrect] --> B{Check Callback Registration}
    B -->|Not Registered| C[Add callback to agent config]
    B -->|Registered| D{Check Error Logs}
    
    D -->|Exceptions Found| E{Exception Type}
    E -->|Connection Error| F[Check network/credentials]
    E -->|Timeout| G[Increase timeout or add buffering]
    E -->|Serialization| H[Validate payload schema]
    
    D -->|No Exceptions| I{Check Callback Execution}
    I -->|Not Called| J[Verify event hooks match ADK version]
    I -->|Called| K{Check Metric Values}
    
    K -->|Zero Values| L[Inspect response object structure]
    K -->|Wrong Values| M[Debug token/cost calculation logic]
    K -->|Correct Values| N[Check metrics backend ingestion]
    
    N --> O{Data in Backend?}
    O -->|No| P[Check backend connection and auth]
    O -->|Yes| Q[Issue is in visualization/queries]

Mistake 4: Memory Leaks from Unbounded Context Storage

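import time
from collections import OrderedDict
from dataclasses import dataclass

# ❌ WRONG: Storing context indefinitely causes memory leaks
class LeakyCallback(BaseCallback):
    def __init__(self):
        self.request_contexts = {}  # Never cleaned up!
    
    def on_request_start(self, request_id: str, **kwargs):
        self.request_contexts[request_id] = {
            "start_time": time.time(),
            "tokens": 0
        }


@dataclass
class ContextEntry:
    start_time: float
    tokens: int = 0


# ✅ CORRECT: TTL-based cleanup with bounded storage
class BoundedCallback(BaseCallback):
    def __init__(self, max_contexts: int = 10000, ttl_seconds: int = 300):
        self._contexts: OrderedDict[str, ContextEntry] = OrderedDict()
        self._max_contexts = max_contexts
        self._ttl = ttl_seconds
    
    def on_request_start(self, request_id: str, **kwargs):
        self._cleanup_expired()
        # Re-insert a reused ID at the end so insertion order stays start-time order
        self._contexts.pop(request_id, None)
        
        # Evict oldest if at capacity
        while len(self._contexts) >= self._max_contexts:
            self._contexts.popitem(last=False)
        
        self._contexts[request_id] = ContextEntry(start_time=time.monotonic())
    
    def on_request_end(self, request_id: str, **kwargs):
        # Assumes a matching end hook; release the entry as soon as the request finishes
        self._contexts.pop(request_id, None)
    
    def _cleanup_expired(self):
        # Oldest entries come first, so stop at the first one that is still fresh
        now = time.monotonic()
        while self._contexts:
            oldest_id, oldest = next(iter(self._contexts.items()))
            if now - oldest.start_time <= self._ttl:
                break
            del self._contexts[oldest_id]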

💡 Tip: In production, monitor your callback’s memory footprint using tracemalloc or similar profiling tools. Set alerts for memory growth that exceeds expected bounds.
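Here is a minimal sketch of that tracemalloc check. Take a snapshot, drive the callback through many requests, take another snapshot, and look at where memory grew by source line. `LeakyStore` is a hypothetical stand-in for a callback that never releases per-request state.

```python
import tracemalloc


class LeakyStore:
    """Hypothetical callback that keeps per-request state and never releases it."""

    def __init__(self):
        self.contexts: dict[str, dict] = {}

    def on_request_start(self, request_id: str) -> None:
        self.contexts[request_id] = {"tokens": 0, "request_id": request_id}


tracemalloc.start()
store = LeakyStore()
before = tracemalloc.take_snapshot()
for i in range(10_000):
    store.on_request_start(f"req-{i}")
after = tracemalloc.take_snapshot()
tracemalloc.stop()

# Allocation growth between the snapshots, grouped by source line (largest first)
diff = after.compare_to(before, "lineno")
for stat in diff[:3]:
    print(stat)
growth_kib = sum(stat.size_diff for stat in diff) / 1024
print(f"net growth: {growth_kib:.0f} KiB")
```

The top entries point straight at the line inside `on_request_start`. In a long-running service you would compare `growth_kib` against a budget between periodic snapshots rather than inside one loop.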


Conclusion and Next Steps

Building production-ready observability for AI agents requires treating callbacks as first-class infrastructure components, not afterthoughts. Throughout this guide, we’ve covered the essential patterns:

  1. Cost Attribution: Track every token and compute dollar back to specific users, features, and requests using hierarchical metric labels.

  2. Latency Decomposition: Measure not just total response time, but time to first token (TTFT), time to last token (TTLT), and per-component durations to identify optimization opportunities.

  3. Error Classification: Categorize failures by type, recoverability, and cost impact to build effective alerting and auto-remediation.

  4. Non-blocking Architecture: Use buffering, async I/O, and circuit breakers to ensure observability never degrades agent performance.
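Latency decomposition relies on two timestamps per streamed response. Time to first token (TTFT) is when the first chunk arrives, and time to last token (TTLT) is when the stream completes. The asyncio sketch below wraps any async chunk iterator. `fake_stream` is a hypothetical stand-in for a model stream, and `measure_stream` is not an ADK API.

```python
import asyncio
import time
from typing import AsyncIterator, Optional


async def measure_stream(stream: AsyncIterator[str]) -> dict:
    """Consume a chunk stream and return TTFT/TTLT in milliseconds."""
    start = time.perf_counter()
    first: Optional[float] = None
    chunks = 0
    async for _chunk in stream:
        if first is None:
            first = time.perf_counter()  # first chunk observed
        chunks += 1
    end = time.perf_counter()
    return {
        "ttft_ms": (first - start) * 1000 if first is not None else None,
        "ttlt_ms": (end - start) * 1000,
        "chunks": chunks,
    }


async def fake_stream() -> AsyncIterator[str]:
    await asyncio.sleep(0.05)  # delay before the first token
    for word in ["Hello", ",", " world"]:
        yield word
        await asyncio.sleep(0.01)  # gap between chunks


metrics = asyncio.run(measure_stream(fake_stream()))
print(metrics["chunks"], metrics["ttft_ms"] < metrics["ttlt_ms"])  # 3 True
```

TTLT minus TTFT approximates generation time, which grows with output length, while TTFT mostly reflects queueing, network, and prompt processing. Tracking them separately tells you which of the two to optimize.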

The investment in proper observability infrastructure tends to pay off quickly. Teams with mature observability pipelines often report gains like these, though the size varies widely by workload:

  • 30-50% reduction in wasted LLM spend through identifying and eliminating inefficient prompts
  • 60% faster incident resolution through precise latency attribution
  • Proactive cost management with accurate forecasting and anomaly detection

Weeks 1-2: Foundation

  • Deploy the basic callback structure with cost and latency tracking
  • Configure Prometheus/Grafana or your preferred metrics backend
  • Set up initial dashboards for the four golden signals

Weeks 3-4: Enhancement

  • Add distributed tracing integration with OpenTelemetry
  • Implement budget alerting with multiple threshold tiers
  • Create per-user and per-feature cost attribution

Month 2: Optimization

  • Analyze collected data to identify prompt optimization opportunities
  • Build automated anomaly detection for cost spikes
  • Implement predictive scaling based on usage patterns

Ongoing: Refinement

  • Review and tune alert thresholds monthly
  • Add custom business metrics as requirements evolve
  • Share dashboards with stakeholders for cost transparency

📝 Note: The observability pipeline you build today becomes the foundation for advanced capabilities like A/B testing prompt variations, automated cost optimization, and ML-driven anomaly detection. Design for extensibility from the start.


Additional Resources