Learn to build a production-ready observability pipeline with ADK callbacks. Track token usage, costs, and execution time for your AI agents at scale.
Building a Production-Ready Observability Pipeline with ADK Callbacks: Tracking Every Dollar and Millisecond in Your AI Agents
Your AI agent processed 50,000 requests last month. You know this because your billing dashboard shows a $12,847 charge from OpenAI. What you don’t know: which user sessions consumed 40% of that budget, why Tuesday’s requests cost 3x more than Monday’s, or which tool invocations triggered expensive retry cascades. You’re flying blind, and it’s costing you money.
The Agent Development Kit (ADK) provides callback hooks at every critical juncture: before and after each LLM call, tool invocation, and agent step. Most developers ignore them or settle for basic print statements. In this article, you’ll build a production-grade observability pipeline that captures token usage, execution time, and cost for every call with sub-millisecond timing, then pipes that data into dashboards where you can actually act on it.
Prerequisites
Before diving in, ensure you have:
- Python 3.10+ installed
- Google ADK (pip install "google-adk>=1.0.0")
- Basic familiarity with ADK’s agent structure and callback system
- One of these observability backends configured: OpenTelemetry Collector, Datadog Agent, or Prometheus with Pushgateway
- PostgreSQL or similar for audit log persistence (optional but recommended)
You should understand how ADK agents execute—specifically that agents can nest sub-agents and invoke tools, creating call trees that need correlation.
Architecture and Key Concepts
The observability pipeline intercepts ADK execution at three levels: individual LLM calls, tool invocations, and aggregate agent steps. Each callback captures structured telemetry, enriches it with correlation IDs, and routes it to appropriate sinks.
flowchart TD
subgraph ADK_Runtime["ADK Runtime"]
A[User Request] --> B[Root Agent]
B --> C{Decision Point}
C -->|LLM Call| D[Model Invocation]
C -->|Tool Use| E[Tool Execution]
C -->|Delegate| F[Sub-Agent]
F --> C
end
subgraph Callback_Layer["Callback Layer"]
D --> G[on_llm_start / on_llm_end]
E --> H[on_tool_start / on_tool_end]
F --> I[on_agent_start / on_agent_end]
end
subgraph Telemetry_Pipeline["Telemetry Pipeline"]
G --> J[TelemetryAggregator]
H --> J
I --> J
J --> K[Cost Calculator]
J --> L[Latency Tracker]
J --> M[Audit Logger]
end
subgraph Observability_Stack["Observability Stack"]
K --> N[Prometheus/Datadog]
L --> N
M --> O[PostgreSQL/S3]
N --> P[Grafana Dashboard]
end
Key concepts you need to understand:
- Correlation ID: A unique identifier that links all callbacks within a single user request, even across nested agent calls
- Span Context: Parent-child relationships between operations, essential for understanding where time and money go
- Token Economics: Different models charge differently per input/output token—your pipeline must track both separately
💡 ADK callbacks execute synchronously by default. For production, you’ll buffer telemetry and flush asynchronously to avoid adding latency to your agent responses.
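The correlation ID and span context described above can be carried with Python’s standard contextvars module, which follows asyncio tasks as well as threads. The sketch below is only an illustration: SpanContext, span, and current_span are hypothetical names, not ADK APIs, and it swaps in contextvars where the handler later in this article uses threading.local.

```python
import contextvars
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, Optional


@dataclass(frozen=True)
class SpanContext:
    correlation_id: str
    span_id: str
    parent_span_id: Optional[str] = None
    depth: int = 0


# Holds the span that is active in the current thread or asyncio task.
_current_span: contextvars.ContextVar[Optional[SpanContext]] = contextvars.ContextVar(
    "current_span", default=None
)


def current_span() -> Optional[SpanContext]:
    return _current_span.get()


@contextmanager
def span() -> Iterator[SpanContext]:
    """Open a span: a root span if none is active, otherwise a child."""
    parent = _current_span.get()
    ctx = SpanContext(
        correlation_id=parent.correlation_id if parent else str(uuid.uuid4()),
        span_id=str(uuid.uuid4()),
        parent_span_id=parent.span_id if parent else None,
        depth=parent.depth + 1 if parent else 0,
    )
    token = _current_span.set(ctx)
    try:
        yield ctx
    finally:
        # Restore whichever span was active before this one.
        _current_span.reset(token)


with span() as root:        # the user request
    with span() as child:   # a nested sub-agent or tool call
        same_request = child.correlation_id == root.correlation_id
        linked = child.parent_span_id == root.span_id
```

Every nested span inherits the root’s correlation ID and records its parent, which is the information the aggregator needs to rebuild the call tree.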
Step-by-Step Implementation
Building the Core Callback Handler with Token and Cost Tracking
Start by creating a callback handler that captures the essential metrics at each LLM invocation. This handler will track token counts, model selection, and execution time.
```python
# observability/callbacks.py
import time
import uuid
from dataclasses import dataclass
from typing import Optional, Dict, Any
from datetime import datetime, timezone
from threading import local

# NOTE: import paths and hook names are as presented in this article;
# confirm them against the ADK version you have installed.
from google.adk.agents import BaseCallbackHandler
from google.adk.models import LLMRequest, LLMResponse

# Thread-local storage for correlation context
_context = local()

# Pricing in USD per 1K tokens (update these based on your provider)
MODEL_PRICING = {
    "gemini-1.5-pro": {"input": 0.00125, "output": 0.00375},
    "gemini-1.5-flash": {"input": 0.000075, "output": 0.0003},
    "gpt-4-turbo": {"input": 0.01, "output": 0.03},
    "gpt-4o": {"input": 0.005, "output": 0.015},
    "claude-3-sonnet": {"input": 0.003, "output": 0.015},
}


@dataclass
class LLMCallMetrics:
    """Captures all metrics for a single LLM invocation."""
    call_id: str
    correlation_id: str
    parent_span_id: Optional[str]
    model_name: str
    input_tokens: int
    output_tokens: int
    total_tokens: int
    cost_usd: float
    latency_ms: float
    timestamp: datetime
    is_retry: bool = False
    retry_count: int = 0
    error: Optional[str] = None


@dataclass
class SpanContext:
    """Tracks the current execution context for correlation."""
    correlation_id: str
    span_id: str
    parent_span_id: Optional[str] = None
    depth: int = 0


class ObservabilityCallbackHandler(BaseCallbackHandler):
    """
    Production callback handler that captures comprehensive telemetry
    for every LLM call, tool invocation, and agent step.
    """

    def __init__(self, telemetry_sink: "TelemetrySink"):
        self.sink = telemetry_sink
        # call_id -> perf_counter() value at call start
        self._call_starts: Dict[str, float] = {}
        # prompt hash -> number of repeats seen so far
        self._retry_counts: Dict[int, int] = {}

    def _get_context(self) -> Optional[SpanContext]:
        """Retrieve current span context from thread-local storage."""
        return getattr(_context, "span", None)

    def _set_context(self, ctx: SpanContext) -> None:
        """Set span context in thread-local storage."""
        _context.span = ctx

    def _calculate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
        """Calculate USD cost based on model pricing."""
        # Unknown models fall back to the most expensive rate in the table
        pricing = MODEL_PRICING.get(model, {"input": 0.01, "output": 0.03})
        input_cost = (input_tokens / 1000) * pricing["input"]
        output_cost = (output_tokens / 1000) * pricing["output"]
        return round(input_cost + output_cost, 6)

    def on_llm_start(
        self,
        request: LLMRequest,
        metadata: Dict[str, Any]
    ) -> None:
        """Called immediately before each LLM API call."""
        call_id = str(uuid.uuid4())

        # Store start time for latency calculation
        self._call_starts[call_id] = time.perf_counter()

        # Track retries by hashing the prompt content
        prompt_hash = hash(str(request.messages))
        if prompt_hash in self._retry_counts:
            self._retry_counts[prompt_hash] += 1
        else:
            self._retry_counts[prompt_hash] = 0

        # Attach call_id to metadata for correlation in on_llm_end
        metadata["_obs_call_id"] = call_id
        metadata["_obs_prompt_hash"] = prompt_hash

    def on_llm_end(
        self,
        request: LLMRequest,
        response: LLMResponse,
        metadata: Dict[str, Any]
    ) -> None:
        """Called immediately after each LLM API call completes."""
        call_id = metadata.get("_obs_call_id")
        prompt_hash = metadata.get("_obs_prompt_hash")
        ctx = self._get_context()

        # Calculate latency
        start_time = self._call_starts.pop(call_id, time.perf_counter())
        latency_ms = (time.perf_counter() - start_time) * 1000

        # Extract token counts from response
        usage = response.usage
        input_tokens = usage.prompt_tokens
        output_tokens = usage.completion_tokens

        # Build metrics object
        metrics = LLMCallMetrics(
            call_id=call_id,
            correlation_id=ctx.correlation_id if ctx else "unknown",
            parent_span_id=ctx.span_id if ctx else None,
            model_name=response.model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=input_tokens + output_tokens,
            cost_usd=self._calculate_cost(response.model, input_tokens, output_tokens),
            latency_ms=round(latency_ms, 2),
            timestamp=datetime.now(timezone.utc),
            is_retry=self._retry_counts.get(prompt_hash, 0) > 0,
            retry_count=self._retry_counts.get(prompt_hash, 0),
        )

        # Send to telemetry sink (non-blocking)
        self.sink.record_llm_call(metrics)

    def on_llm_error(
        self,
        request: LLMRequest,
        error: Exception,
        metadata: Dict[str, Any]
    ) -> None:
        """Called when an LLM call fails."""
        call_id = metadata.get("_obs_call_id")
        ctx = self._get_context()
        start_time = self._call_starts.pop(call_id, time.perf_counter())
        latency_ms = (time.perf_counter() - start_time) * 1000

        # Failed calls are recorded with zero tokens and zero cost
        metrics = LLMCallMetrics(
            call_id=call_id,
            correlation_id=ctx.correlation_id if ctx else "unknown",
            parent_span_id=ctx.span_id if ctx else None,
            model_name=metadata.get("model", "unknown"),
            input_tokens=0,
            output_tokens=0,
            total_tokens=0,
            cost_usd=0.0,
            latency_ms=round(latency_ms, 2),
            timestamp=datetime.now(timezone.utc),
            error=str(error),
        )
        self.sink.record_llm_call(metrics)
```
⚠️ The MODEL_PRICING dictionary requires manual updates when providers change pricing. Consider fetching this from a configuration service in production.
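To make the per-1K-token arithmetic concrete, here is a small worked example. It is a sketch rather than the handler’s actual method: call_cost and PRICING are hypothetical names, the prices mirror two entries of the table above, and it uses Decimal and raises on unknown models instead of falling back to a default rate.

```python
from decimal import Decimal

# Hypothetical price table in USD per 1K tokens, mirroring the article's values.
PRICING = {
    "gpt-4o": {"input": Decimal("0.005"), "output": Decimal("0.015")},
    "gemini-1.5-flash": {"input": Decimal("0.000075"), "output": Decimal("0.0003")},
}


def call_cost(model: str, input_tokens: int, output_tokens: int) -> Decimal:
    """Return the USD cost of one call; raise if the model has no price."""
    try:
        price = PRICING[model]
    except KeyError:
        raise KeyError(f"no pricing configured for model {model!r}") from None
    return (
        Decimal(input_tokens) / 1000 * price["input"]
        + Decimal(output_tokens) / 1000 * price["output"]
    )


# 1,200 input tokens and 350 output tokens on gpt-4o:
#   1.2 * 0.005 + 0.35 * 0.015 = 0.006 + 0.00525 = 0.01125 USD
cost = call_cost("gpt-4o", 1200, 350)
print(cost)
```

Raising on an unknown model is a design choice: a silent default keeps the pipeline running but can hide a missing price entry behind plausible-looking numbers.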
When agents delegate to sub-agents or invoke tools that make their own LLM calls, costs compound in ways that aren’t immediately visible. You need an aggregator that maintains the call tree and rolls up costs correctly.
```python
# observability/aggregator.py
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from collections import defaultdict
import threading

from .callbacks import LLMCallMetrics


@dataclass
class AgentStepMetrics:
    """Aggregated metrics for a complete agent step (may include multiple LLM calls)."""
    step_id: str
    agent_name: str
    correlation_id: str
    total_llm_calls: int = 0
    total_tool_calls: int = 0
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cost_usd: float = 0.0
    total_latency_ms: float = 0.0
    child_agent_costs: Dict[str, float] = field(default_factory=dict)
    tool_costs: Dict[str, float] = field(default_factory=dict)


@dataclass
class RequestMetrics:
    """Top-level metrics for an entire user request."""
    correlation_id: str
    user_id: Optional[str]
    session_id: Optional[str]
    root_agent: str
    total_cost_usd: float = 0.0
    total_latency_ms: float = 0.0
    total_llm_calls: int = 0
    agent_breakdown: Dict[str, float] = field(default_factory=dict)
    model_breakdown: Dict[str, float] = field(default_factory=dict)


class CostAggregator:
    """
    Maintains a call tree and aggregates costs from leaf nodes up to root.
    Thread-safe for concurrent request handling.
    """

    def __init__(self):
        self._lock = threading.Lock()
        # correlation_id -> list of LLMCallMetrics
        self._call_buffer: Dict[str, List[LLMCallMetrics]] = defaultdict(list)
        # correlation_id -> span_id -> AgentStepMetrics
        self._step_metrics: Dict[str, Dict[str, AgentStepMetrics]] = defaultdict(dict)
        # correlation_id -> RequestMetrics
        self._request_metrics: Dict[str, RequestMetrics] = {}
        # span_id -> parent_span_id for tree reconstruction
        self._span_parents: Dict[str, Optional[str]] = {}

    def start_request(
        self,
        correlation_id: str,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
        root_agent: str = "unknown"
    ) -> None:
        """Initialize tracking for a new user request."""
        with self._lock:
            self._request_metrics[correlation_id] = RequestMetrics(
                correlation_id=correlation_id,
                user_id=user_id,
                session_id=session_id,
                root_agent=root_agent,
            )

    def record_span_parent(self, span_id: str, parent_span_id: Optional[str]) -> None:
        """Track parent-child relationship for span tree reconstruction."""
        with self._lock:
            self._span_parents[span_id] = parent_span_id

    def add_llm_call(self, metrics: LLMCallMetrics) -> None:
        """Add an LLM call to the buffer for aggregation."""
        with self._lock:
            self._call_buffer[metrics.correlation_id].append(metrics)

    def add_tool_cost(
        self,
        correlation_id: str,
        span_id: str,
        tool_name: str,
        cost_usd: float,
        latency_ms: float
    ) -> None:
        """Record cost incurred by a tool invocation."""
        with self._lock:
            if span_id not in self._step_metrics[correlation_id]:
                self._step_metrics[correlation_id][span_id] = AgentStepMetrics(
                    step_id=span_id,
                    agent_name="unknown",
                    correlation_id=correlation_id,
                )
            step = self._step_metrics[correlation_id][span_id]
            step.tool_costs[tool_name] = step.tool_costs.get(tool_name, 0) + cost_usd
            step.total_tool_calls += 1
            step.total_cost_usd += cost_usd
            step.total_latency_ms += latency_ms

    def finalize_request(self, correlation_id: str) -> RequestMetrics:
        """
        Aggregate all buffered metrics and return final request-level metrics.
        Call this when the root agent completes.
        """
        with self._lock:
            request = self._request_metrics.get(correlation_id)
            if not request:
                raise ValueError(f"No request found for correlation_id: {correlation_id}")

            calls = self._call_buffer.pop(correlation_id, [])

            # Aggregate LLM calls
            for call in calls:
                request.total_cost_usd += call.cost_usd
                # Sum of per-call latencies, not wall-clock request time
                request.total_latency_ms += call.latency_ms
                request.total_llm_calls += 1
                # Track by model
                model = call.model_name
                request.model_breakdown[model] = (
                    request.model_breakdown.get(model, 0) + call.cost_usd
                )

            # Add tool costs from step metrics
            steps = self._step_metrics.pop(correlation_id, {})
            for span_id, step in steps.items():
                agent = step.agent_name
                request.agent_breakdown[agent] = (
                    request.agent_breakdown.get(agent, 0) + step.total_cost_usd
                )
                # Step totals hold tool costs only, so add them to the
                # request total alongside the LLM costs summed above
                request.total_cost_usd += step.total_cost_usd

            # Cleanup
            self._request_metrics.pop(correlation_id, None)
            return request

    def get_cost_tree(self, correlation_id: str) -> Dict[str, Any]:
        """
        Build a hierarchical cost breakdown showing nested agent/tool costs.
        Useful for debugging expensive request paths.
        """
        with self._lock:
            calls = list(self._call_buffer.get(correlation_id, []))

        # Build tree structure keyed by parent span ID
        tree = {"root": {"children": {}, "cost": 0.0, "calls": []}}
        for call in calls:
            parent_id = call.parent_span_id or "root"
            if parent_id not in tree:
                tree[parent_id] = {"children": {}, "cost": 0.0, "calls": []}
            tree[parent_id]["calls"].append({
                "model": call.model_name,
                "cost": call.cost_usd,
                "tokens": call.total_tokens,
            })
            tree[parent_id]["cost"] += call.cost_usd
        return tree
```
📝 The finalize_request method must be called exactly once per request. Calling it multiple times will raise an error—this is intentional to catch bugs in your callback wiring.
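Rolling costs up the call tree, so that each agent’s total includes everything its sub-agents and tools spent, can be sketched as a walk from every span to the root. This is a standalone illustration, not part of CostAggregator: rollup_costs and the span names are hypothetical, and the input is a span-to-parent map like the one record_span_parent builds.

```python
from collections import defaultdict
from typing import Dict, Optional


def rollup_costs(
    own_cost: Dict[str, float],
    parents: Dict[str, Optional[str]],
) -> Dict[str, float]:
    """Return, for each span, its own cost plus the cost of all descendants."""
    totals: Dict[str, float] = defaultdict(float)
    for span_id, cost in own_cost.items():
        # Walk from the span up to the root, crediting every ancestor.
        node: Optional[str] = span_id
        seen = set()  # guards against a malformed, cyclic parent map
        while node is not None and node not in seen:
            seen.add(node)
            totals[node] += cost
            node = parents.get(node)
    return dict(totals)


# Hypothetical request: root agent -> research sub-agent -> search tool
parents = {"root": None, "research": "root", "search": "research"}
own_cost = {"root": 0.010, "research": 0.020, "search": 0.005}
totals = rollup_costs(own_cost, parents)
```

Here the root’s total is the sum of all three spans, while the research sub-agent’s total covers itself and the search tool it invoked.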
Designing the Structured Audit Log Schema
Compliance requirements and debugging both demand complete audit trails. This schema captures every detail you need to reconstruct what happened during any request.
```python
# observability/audit.py
from dataclasses import dataclass, asdict
from typing import Optional, List, Dict, Any, Protocol
from datetime import datetime
from enum import Enum
import json
import hashlib
import sys
import threading

from .callbacks import LLMCallMetrics


class EventType(str, Enum):
    REQUEST_START = "request_start"
    REQUEST_END = "request_end"
    LLM_CALL = "llm_call"
    LLM_RETRY = "llm_retry"
    LLM_FALLBACK = "llm_fallback"
    TOOL_INVOCATION = "tool_invocation"
    AGENT_DELEGATION = "agent_delegation"
    ERROR = "error"


@dataclass
class PromptResponsePair:
    """
    Stores prompt and response for audit purposes.
    Includes content hashing for integrity verification.
    """
    prompt_messages: List[Dict[str, str]]
    response_content: str
    prompt_hash: str
    response_hash: str

    @classmethod
    def create(cls, messages: List[Dict], response: str) -> "PromptResponsePair":
        prompt_str = json.dumps(messages, sort_keys=True)
        # Hashes are truncated to 16 hex characters to match the CHAR(16) columns
        return cls(
            prompt_messages=messages,
            response_content=response,
            prompt_hash=hashlib.sha256(prompt_str.encode()).hexdigest()[:16],
            response_hash=hashlib.sha256(response.encode()).hexdigest()[:16],
        )


@dataclass
class AuditLogEntry:
    """
    Complete audit log entry correlating all aspects of an operation.
    Designed for both compliance auditing and debugging.
    """
    # Identification
    entry_id: str
    correlation_id: str
    span_id: str
    parent_span_id: Optional[str]
    # Timing
    timestamp: datetime
    duration_ms: Optional[float]
    # Event classification
    event_type: EventType
    agent_name: str
    # User context (for compliance)
    user_id: Optional[str]
    session_id: Optional[str]
    client_ip: Optional[str]
    # LLM-specific fields
    model_name: Optional[str] = None
    prompt_response: Optional[PromptResponsePair] = None
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    cost_usd: Optional[float] = None
    # Retry/fallback tracking
    is_retry: bool = False
    retry_attempt: int = 0
    original_model: Optional[str] = None  # For fallback events
    fallback_reason: Optional[str] = None
    # Tool-specific fields
    tool_name: Optional[str] = None
    tool_input: Optional[Dict[str, Any]] = None
    tool_output: Optional[str] = None
    # Error tracking
    error_type: Optional[str] = None
    error_message: Optional[str] = None
    error_stack: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary, handling nested dataclasses."""
        result = asdict(self)
        result["timestamp"] = self.timestamp.isoformat()
        result["event_type"] = self.event_type.value
        return result

    def to_json(self) -> str:
        """Serialize to JSON string for storage."""
        return json.dumps(self.to_dict(), default=str)


class AuditBackend(Protocol):
    """Interface each storage backend is expected to implement."""

    def write_batch(self, entries: List[AuditLogEntry]) -> None: ...


class AuditLogger:
    """
    Writes structured audit logs to configured backends.
    Entries are buffered and written in batches once the buffer fills.
    """

    def __init__(
        self,
        backends: List[AuditBackend],
        buffer_size: int = 100,
        flush_interval_seconds: float = 5.0
    ):
        self.backends = backends
        self.buffer_size = buffer_size
        # Stored for a periodic flusher; this class does not start one itself
        self.flush_interval = flush_interval_seconds
        self._buffer: List[AuditLogEntry] = []
        self._lock = threading.Lock()

    def log(self, entry: AuditLogEntry) -> None:
        """Add entry to buffer, flush if buffer is full."""
        with self._lock:
            self._buffer.append(entry)
            should_flush = len(self._buffer) >= self.buffer_size
        if should_flush:
            self._flush()

    def _flush(self) -> None:
        """Write buffered entries to all backends."""
        # Swap the buffer out under the lock, then write without holding it
        with self._lock:
            if not self._buffer:
                return
            entries = self._buffer.copy()
            self._buffer.clear()
        for backend in self.backends:
            try:
                backend.write_batch(entries)
            except Exception as e:
                # Log to stderr, don't crash the agent
                print(
                    f"Audit backend {backend.__class__.__name__} failed: {e}",
                    file=sys.stderr,
                )

    def log_llm_call(
        self,
        correlation_id: str,
        span_id: str,
        parent_span_id: Optional[str],
        agent_name: str,
        model_name: str,
        messages: List[Dict],
        response: str,
        metrics: LLMCallMetrics,
        user_context: Dict[str, Any],
    ) -> None:
        """Convenience method for logging LLM calls with full context."""
        entry = AuditLogEntry(
            entry_id=metrics.call_id,
            correlation_id=correlation_id,
            span_id=span_id,
            parent_span_id=parent_span_id,
            timestamp=metrics.timestamp,
            duration_ms=metrics.latency_ms,
            event_type=EventType.LLM_RETRY if metrics.is_retry else EventType.LLM_CALL,
            agent_name=agent_name,
            user_id=user_context.get("user_id"),
            session_id=user_context.get("session_id"),
            client_ip=user_context.get("client_ip"),
            model_name=model_name,
            prompt_response=PromptResponsePair.create(messages, response),
            input_tokens=metrics.input_tokens,
            output_tokens=metrics.output_tokens,
            cost_usd=metrics.cost_usd,
            is_retry=metrics.is_retry,
            retry_attempt=metrics.retry_count,
        )
        self.log(entry)
```
Now create the SQL schema for PostgreSQL persistence:
```sql
-- observability/schema.sql

-- Separate table for prompt/response content (can be large).
-- Created first because audit_logs references it.
CREATE TABLE IF NOT EXISTS prompt_responses (
    id UUID PRIMARY KEY,
    prompt_messages JSONB NOT NULL,
    response_content TEXT NOT NULL,
    prompt_hash CHAR(16) NOT NULL,
    response_hash CHAR(16) NOT NULL,
    created_at TIMESTAMPTZ DEFAULT NOW()
);

-- Audit log table with proper indexing for compliance queries
CREATE TABLE IF NOT EXISTS audit_logs (
    entry_id UUID PRIMARY KEY,
    correlation_id UUID NOT NULL,
    span_id UUID NOT NULL,
    parent_span_id UUID,
    -- Timing with microsecond precision
    timestamp TIMESTAMPTZ NOT NULL,
    duration_ms DECIMAL(10, 2),
    -- Event classification
    event_type VARCHAR(50) NOT NULL,
    agent_name VARCHAR(255) NOT NULL,
    -- User context for compliance
    user_id VARCHAR(255),
    session_id VARCHAR(255),
    client_ip INET,
    -- LLM call details
    model_name VARCHAR(100),
    prompt_hash CHAR(16),
    response_hash CHAR(16),
    input_tokens INTEGER,
    output_tokens INTEGER,
    cost_usd DECIMAL(10, 6),
    -- Full prompt/response stored separately for space efficiency
    prompt_response_id UUID REFERENCES prompt_responses(id),
    -- Retry tracking
    is_retry BOOLEAN DEFAULT FALSE,
    retry_attempt INTEGER DEFAULT 0,
    original_model VARCHAR(100),
    fallback_reason TEXT,
    -- Tool tracking
    tool_name VARCHAR(255),
    tool_input JSONB,
    tool_output TEXT,
    -- Error tracking
    error_type VARCHAR(255),
    error_message TEXT,
    error_stack TEXT,
    -- Metadata
    created_at TIMESTAMPTZ DEFAULT NOW()
);

-- Indexes for common query patterns
CREATE INDEX IF NOT EXISTS idx_audit_correlation ON audit_logs(correlation_id);
CREATE INDEX IF NOT EXISTS idx_audit_user_time ON audit_logs(user_id, timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_audit_event_type ON audit_logs(event_type, timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_audit_model_cost ON audit_logs(model_name, cost_usd DESC);
CREATE INDEX IF NOT EXISTS idx_audit_errors ON audit_logs(error_type) WHERE error_type IS NOT NULL;

-- Partitioning for time-series data (PostgreSQL 12+)
-- Uncomment and modify retention period as needed.
-- Note: a primary key on a partitioned table must include the partition
-- key, so expect to change the key to (entry_id, timestamp) before using this.
/*
CREATE TABLE audit_logs_partitioned (
    LIKE audit_logs INCLUDING ALL
) PARTITION BY RANGE (timestamp);

CREATE TABLE audit_logs_y2024m01 PARTITION OF audit_logs_partitioned
    FOR VALUES FROM ('2024-01-01') TO ('2024-02-01');
*/

-- View for cost analysis
CREATE OR REPLACE VIEW daily_cost_summary AS
SELECT
    DATE_TRUNC('day', timestamp) AS day,
    user_id,
    model_name,
    COUNT(*) AS call_count,
    SUM(input_tokens) AS total_input_tokens,
    SUM(output_tokens) AS total_output_tokens,
    SUM(cost_usd) AS total_cost,
    AVG(duration_ms) AS avg_latency_ms,
    PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY duration_ms) AS p95_latency_ms
FROM audit_logs
WHERE event_type IN ('llm_call', 'llm_retry')
GROUP BY DATE_TRUNC('day', timestamp), user_id, model_name;
```
💡 Consider using table partitioning by month for audit logs. This dramatically improves query performance and makes it easy to archive old data for compliance retention requirements.
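Monthly partitions have to exist before rows for that month arrive, so they are usually created ahead of time by a scheduled job. The helper below is one possible way to generate the statement; month_partition_ddl is a hypothetical name, and the table name follows the commented-out example in the schema.

```python
from datetime import date


def month_partition_ddl(parent: str, year: int, month: int) -> str:
    """Build the DDL for one monthly range partition of `parent`."""
    start = date(year, month, 1)
    # First day of the following month; December rolls over to January.
    end = date(year + 1, 1, 1) if month == 12 else date(year, month + 1, 1)
    name = f"{parent}_y{start.year}m{start.month:02d}"
    return (
        f"CREATE TABLE IF NOT EXISTS {name} PARTITION OF {parent}\n"
        f"    FOR VALUES FROM ('{start.isoformat()}') TO ('{end.isoformat()}');"
    )


ddl = month_partition_ddl("audit_logs_partitioned", 2024, 12)
print(ddl)
```

The parent name is interpolated directly into the statement, so pass only trusted, hard-coded identifiers to a helper like this.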
Production Configuration
Production deployments require careful attention to configuration management, secrets handling, and environment-specific settings. Here’s a comprehensive configuration approach that scales across environments.
Environment-Based Configuration
```yaml
# config/observability.yaml
# Production-ready configuration for ADK observability pipeline
environments:
  production:
    logging:
      level: INFO
      format: json
      include_request_body: false  # PII protection
      include_response_body: false
      sample_rate: 1.0  # Log everything in prod
    metrics:
      enabled: true
      export_interval_seconds: 10
      histogram_buckets: [10, 50, 100, 250, 500, 1000, 2500, 5000, 10000]
    tracing:
      enabled: true
      sample_rate: 0.1  # Sample 10% of traces
      propagation: w3c
    storage:
      type: postgresql
      connection_pool_size: 20
      max_overflow: 10
      pool_timeout: 30
    alerting:
      cost_threshold_daily_usd: 500
      latency_p95_threshold_ms: 3000
      error_rate_threshold_percent: 5

  staging:
    logging:
      level: DEBUG
      format: json
      include_request_body: true
      include_response_body: true
      sample_rate: 1.0
    metrics:
      enabled: true
      export_interval_seconds: 30
    tracing:
      enabled: true
      sample_rate: 1.0  # Trace everything in staging
    storage:
      type: postgresql
      connection_pool_size: 5

  development:
    logging:
      level: DEBUG
      format: pretty
      include_request_body: true
      include_response_body: true
    metrics:
      enabled: false
    tracing:
      enabled: true
      sample_rate: 1.0
    storage:
      type: sqlite
      path: ./observability.db
```
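The configuration sets a tracing sample_rate of 0.1 in production but does not show how the decision is made. One common approach, sketched here under that assumption, is to derive the decision from a hash of the correlation ID so that every span of a request is kept or dropped together. should_sample is a hypothetical helper, not part of ADK or OpenTelemetry.

```python
import hashlib


def should_sample(correlation_id: str, sample_rate: float) -> bool:
    """Deterministically decide whether a request's trace is kept."""
    if sample_rate >= 1.0:
        return True
    if sample_rate <= 0.0:
        return False
    digest = hashlib.sha256(correlation_id.encode("utf-8")).digest()
    # Map the first 8 bytes of the hash to a number in [0, 1).
    bucket = int.from_bytes(digest[:8], "big") / 2**64
    return bucket < sample_rate


ids = [f"req-{i}" for i in range(10_000)]
kept = sum(should_sample(i, 0.1) for i in ids)
print(f"kept {kept} of {len(ids)} traces")
```

Because the decision depends only on the ID, separate processes handling parts of the same request reach the same answer without coordinating.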
Configuration Loader with Validation
```python
# src/config/loader.py
# Uses the pydantic v1-style @validator decorator, which pydantic v2
# still accepts but reports as deprecated.
from pydantic import BaseModel, Field, validator
from typing import Literal, Optional
import yaml
import os


class LoggingConfig(BaseModel):
    level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"
    format: Literal["json", "pretty"] = "json"
    include_request_body: bool = False
    include_response_body: bool = False
    sample_rate: float = Field(ge=0.0, le=1.0, default=1.0)


class MetricsConfig(BaseModel):
    enabled: bool = True
    export_interval_seconds: int = Field(ge=1, le=300, default=10)
    histogram_buckets: list[int] = [10, 50, 100, 250, 500, 1000, 2500, 5000]


class TracingConfig(BaseModel):
    enabled: bool = True
    sample_rate: float = Field(ge=0.0, le=1.0, default=0.1)
    propagation: Literal["w3c", "b3", "jaeger"] = "w3c"


class StorageConfig(BaseModel):
    type: Literal["postgresql", "sqlite", "bigquery"] = "postgresql"
    connection_pool_size: int = Field(ge=1, le=100, default=20)
    max_overflow: int = Field(ge=0, le=50, default=10)
    pool_timeout: int = Field(ge=1, le=120, default=30)
    # Loaded from environment variables for security
    connection_string: Optional[str] = None

    @validator("connection_string", pre=True, always=True)
    def load_connection_string(cls, v):
        # Never store credentials in config files;
        # DATABASE_URL takes precedence over any value in the YAML
        return os.environ.get("DATABASE_URL", v)


class AlertingConfig(BaseModel):
    cost_threshold_daily_usd: float = Field(ge=0, default=500)
    latency_p95_threshold_ms: int = Field(ge=0, default=3000)
    error_rate_threshold_percent: float = Field(ge=0, le=100, default=5)


class ObservabilityConfig(BaseModel):
    logging: LoggingConfig = LoggingConfig()
    metrics: MetricsConfig = MetricsConfig()
    tracing: TracingConfig = TracingConfig()
    storage: StorageConfig = StorageConfig()
    alerting: AlertingConfig = AlertingConfig()


def load_config(environment: Optional[str] = None) -> ObservabilityConfig:
    """Load configuration for the specified environment."""
    env = environment or os.environ.get("ENVIRONMENT", "development")
    config_path = os.environ.get(
        "OBSERVABILITY_CONFIG_PATH",
        "config/observability.yaml"
    )
    with open(config_path) as f:
        raw_config = yaml.safe_load(f)

    env_config = raw_config.get("environments", {}).get(env, {})
    if not env_config:
        raise ValueError(f"No configuration found for environment: {env}")
    return ObservabilityConfig(**env_config)
```
Production Callback Manager Assembly
```python
# src/observability/factory.py
import logging
import os
from typing import Optional

from google.adk.agents import Agent
from src.config.loader import load_config, ObservabilityConfig
from src.callbacks.cost_tracker import CostTrackingCallback
from src.callbacks.latency_monitor import LatencyMonitorCallback
from src.callbacks.audit_logger import AuditLogCallback
from src.storage.async_writer import AsyncBatchWriter
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
import structlog


class ObservabilityFactory:
    """Factory for assembling production observability components."""

    def __init__(self, config: Optional[ObservabilityConfig] = None):
        self.config = config or load_config()
        self._setup_tracing()
        self._setup_logging()

    def _setup_tracing(self):
        """Configure OpenTelemetry tracing."""
        if not self.config.tracing.enabled:
            return
        provider = TracerProvider()
        # Export to your observability backend (Jaeger, Tempo, etc.)
        otlp_exporter = OTLPSpanExporter(
            endpoint=os.environ.get("OTLP_ENDPOINT", "localhost:4317"),
            insecure=os.environ.get("ENVIRONMENT") != "production"
        )
        provider.add_span_processor(BatchSpanProcessor(otlp_exporter))
        trace.set_tracer_provider(provider)

    def _setup_logging(self):
        """Configure structured logging."""
        processors = [
            structlog.contextvars.merge_contextvars,
            structlog.processors.add_log_level,
            structlog.processors.TimeStamper(fmt="iso"),
        ]
        if self.config.logging.format == "json":
            processors.append(structlog.processors.JSONRenderer())
        else:
            processors.append(structlog.dev.ConsoleRenderer())
        structlog.configure(
            processors=processors,
            # Level names map to the stdlib logging constants (logging.INFO, ...)
            wrapper_class=structlog.make_filtering_bound_logger(
                getattr(logging, self.config.logging.level)
            ),
            context_class=dict,
            cache_logger_on_first_use=True,
        )

    def create_callbacks(self) -> list:
        """Create the full callback stack for production use."""
        callbacks = []

        # Async writer for non-blocking persistence
        writer = AsyncBatchWriter(
            connection_string=self.config.storage.connection_string,
            batch_size=100,
            flush_interval=5.0
        )

        # Cost tracking - always enabled
        cost_callback = CostTrackingCallback(
            alert_threshold_usd=self.config.alerting.cost_threshold_daily_usd,
            writer=writer
        )
        callbacks.append(cost_callback)

        # Latency monitoring
        latency_callback = LatencyMonitorCallback(
            p95_threshold_ms=self.config.alerting.latency_p95_threshold_ms,
            histogram_buckets=self.config.metrics.histogram_buckets
        )
        callbacks.append(latency_callback)

        # Audit logging with PII controls
        audit_callback = AuditLogCallback(
            include_request_body=self.config.logging.include_request_body,
            include_response_body=self.config.logging.include_response_body,
            sample_rate=self.config.logging.sample_rate,
            writer=writer
        )
        callbacks.append(audit_callback)
        return callbacks

    def instrument_agent(self, agent: Agent) -> Agent:
        """Add full observability instrumentation to an agent."""
        callbacks = self.create_callbacks()
        for callback in callbacks:
            agent.register_callback(callback)
        return agent


# Usage in application startup
def create_production_agent():
    """Create a fully instrumented production agent."""
    factory = ObservabilityFactory()
    agent = Agent(
        model="gemini-2.0-flash",
        system_instruction="You are a helpful assistant.",
        tools=[...]
    )
    return factory.instrument_agent(agent)
```
Common Mistakes and Troubleshooting
After deploying observability pipelines across dozens of production systems, I’ve catalogued the most frequent failure modes. Here’s how to avoid them.
Mistake 1: Synchronous Database Writes in Callbacks
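```python
import threading
import time

from google.adk.agents import BaseCallbackHandler  # base class as used earlier in this article


# ❌ WRONG: This blocks the agent and adds latency
class BadAuditCallback(BaseCallbackHandler):
    def on_llm_end(self, response, **kwargs):
        # This database write blocks the response!
        self.db.execute(
            "INSERT INTO audit_logs VALUES (...)",
            params
        )


# ✅ CORRECT: Use async writes with batching
class GoodAuditCallback(BaseCallbackHandler):
    def __init__(self, db):
        self.db = db
        self._buffer = []
        self._buffer_lock = threading.Lock()
        self._start_flush_thread()

    def _start_flush_thread(self):
        """Run _flush_buffer on a daemon thread."""
        threading.Thread(target=self._flush_buffer, daemon=True).start()

    def on_llm_end(self, response, **kwargs):
        # Non-blocking append to buffer
        with self._buffer_lock:
            self._buffer.append(self._format_log(response))

    def _format_log(self, response):
        """Placeholder: shape the row however your schema requires."""
        return {"logged_at": time.time(), "response": str(response)}

    def _flush_buffer(self):
        """Background thread flushes buffer periodically."""
        while True:
            time.sleep(5)
            batch = []  # stays empty when nothing was buffered
            with self._buffer_lock:
                if self._buffer:
                    batch = self._buffer.copy()
                    self._buffer.clear()
            if batch:
                # Bulk insert is much faster
                self.db.execute_batch(batch)
```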
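The flush loop above has no shutdown path, so records still buffered when the process exits may never be written. A minimal sketch of one way to close that gap, using only the standard library: BufferedWriter and its write_batch argument are hypothetical names, and the sink here is a plain list standing in for a database.

```python
import queue
import threading
from typing import Callable, List


class BufferedWriter:
    """Hand records to a background thread that writes them in batches."""

    _STOP = object()  # sentinel that tells the worker to exit

    def __init__(self, write_batch: Callable[[List[dict]], None],
                 batch_size: int = 100, flush_interval: float = 5.0) -> None:
        self._write_batch = write_batch
        self._batch_size = batch_size
        self._flush_interval = flush_interval
        self._queue: "queue.Queue[object]" = queue.Queue()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def submit(self, record: dict) -> None:
        """Called from callbacks; never waits on I/O."""
        self._queue.put(record)

    def close(self) -> None:
        """Flush everything still queued, then stop the worker."""
        self._queue.put(self._STOP)
        self._thread.join()

    def _run(self) -> None:
        batch: List[dict] = []
        while True:
            try:
                item = self._queue.get(timeout=self._flush_interval)
            except queue.Empty:
                item = None  # timeout: flush whatever has accumulated
            if item is not None and item is not self._STOP:
                batch.append(item)
            if batch and (item is None or item is self._STOP
                          or len(batch) >= self._batch_size):
                try:
                    self._write_batch(batch)
                except Exception as exc:  # a sink failure must not kill the worker
                    print(f"batch write failed: {exc}")
                batch = []
            if item is self._STOP:
                return


written: List[List[dict]] = []
writer = BufferedWriter(written.append, batch_size=2, flush_interval=0.1)
for i in range(5):
    writer.submit({"n": i})
writer.close()  # e.g. from an atexit hook or the server's shutdown handler
```

Calling close() during application shutdown drains the queue before the worker exits, so the final partial batch is written rather than dropped.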
Mistake 2: Not Handling Callback Exceptions
```python
import threading
import time

import structlog

# ❌ WRONG: Exception in callback crashes the agent
class FragileCallback(BaseCallback):
    def on_llm_end(self, response, **kwargs):
        # If metrics server is down, this crashes everything
        self.metrics_client.record(response.usage)

# ✅ CORRECT: Isolate callback failures
class RobustCallback(BaseCallback):
    def __init__(self):
        self.logger = structlog.get_logger()
        self._consecutive_failures = 0
        self._circuit_open = False

    def on_llm_end(self, response, **kwargs):
        if self._circuit_open:
            return  # Skip when circuit breaker is open
        try:
            self.metrics_client.record(response.usage)
            self._consecutive_failures = 0
        except Exception as e:
            self._consecutive_failures += 1
            self.logger.warning(
                "callback_failed",
                error=str(e),
                consecutive_failures=self._consecutive_failures
            )
            # Open circuit after 5 consecutive failures
            if self._consecutive_failures >= 5:
                self._circuit_open = True
                self._schedule_circuit_reset()

    def _schedule_circuit_reset(self):
        """Reset circuit breaker after cooldown."""
        def reset():
            time.sleep(60)
            self._circuit_open = False
            self._consecutive_failures = 0
        threading.Thread(target=reset, daemon=True).start()
```
Mistake 3: Missing Context Propagation
sequenceDiagram
participant Client
participant Agent
participant Tool
participant LLM
participant Observability
Client->>Agent: Request (trace_id: abc123)
Note over Agent,Observability: ❌ Without context propagation
Agent->>Tool: Execute (no trace_id)
Tool->>Observability: Log (orphaned)
Agent->>LLM: Call (no trace_id)
LLM->>Observability: Log (orphaned)
Note over Agent,Observability: ✅ With context propagation
Agent->>Tool: Execute (trace_id: abc123)
Tool->>Observability: Log (trace_id: abc123)
Agent->>LLM: Call (trace_id: abc123)
LLM->>Observability: Log (trace_id: abc123)
Observability-->>Client: Complete trace visible
```python
# ✅ CORRECT: Propagate context through the entire call chain
import contextvars
import uuid

from opentelemetry import trace

# Context variable for request-scoped data. No mutable default here:
# a shared default dict would leak state between requests.
request_context: contextvars.ContextVar[dict] = contextvars.ContextVar(
    'request_context'
)

class ContextAwareCallback(BaseCallback):
    def __init__(self):
        self.tracer = trace.get_tracer(__name__)

    def on_agent_start(self, agent_input, **kwargs):
        # Extract or create trace context (copy so each request owns its dict)
        ctx = dict(request_context.get({}))
        span = self.tracer.start_span(
            "agent.execution",
            context=trace.set_span_in_context(ctx.get('parent_span'))
        )
        # Store span for child operations
        ctx['current_span'] = span
        ctx['session_id'] = kwargs.get('session_id', str(uuid.uuid4()))
        request_context.set(ctx)

    def on_tool_start(self, tool_name, tool_input, **kwargs):
        ctx = dict(request_context.get({}))
        parent_span = ctx.get('current_span')
        # Create child span linked to parent
        tool_span = self.tracer.start_span(
            f"tool.{tool_name}",
            context=trace.set_span_in_context(parent_span)
        )
        tool_span.set_attribute("tool.input", str(tool_input)[:1000])
        ctx['tool_span'] = tool_span
        request_context.set(ctx)

    def on_tool_end(self, tool_output, **kwargs):
        ctx = request_context.get({})
        tool_span = ctx.get('tool_span')
        if tool_span:
            tool_span.set_attribute("tool.output_length", len(str(tool_output)))
            tool_span.end()
```
⚠️ Warning: Without proper context propagation, you’ll have fragmented traces that are impossible to debug. A user complaint about slow responses becomes a needle-in-haystack search across millions of unrelated log entries.
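To see why `contextvars` suits request-scoped trace IDs, here is a minimal standard-library sketch. The names `trace_id_var`, `log_event`, and `handle_request` are illustrative assumptions, not part of ADK or OpenTelemetry. Each asyncio task works on its own copy of the context, so two concurrent requests never see each other's trace ID.

```python
import asyncio
import contextvars

# Hypothetical request-scoped variable; not an ADK or OpenTelemetry name.
trace_id_var: contextvars.ContextVar[str] = contextvars.ContextVar(
    "trace_id", default="unset"
)

def log_event(events: list, message: str) -> None:
    # Deep in the call chain, read the trace ID without passing it as an argument.
    events.append((trace_id_var.get(), message))

async def handle_request(trace_id: str, events: list) -> None:
    trace_id_var.set(trace_id)   # visible only inside this task's context
    await asyncio.sleep(0)       # yield so the two requests interleave
    log_event(events, "tool.start")
    await asyncio.sleep(0)
    log_event(events, "llm.call")

async def run_demo() -> list:
    events: list = []
    # gather wraps each coroutine in a Task; each Task copies the current context
    await asyncio.gather(
        handle_request("abc123", events),
        handle_request("def456", events),
    )
    return events

events = asyncio.run(run_demo())
```

Every recorded event carries the trace ID of the request that produced it, even though the two handlers interleave on one event loop.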
Troubleshooting Checklist
| Symptom | Likely Cause | Solution |
|---|---|---|
| Missing logs for some requests | Sample rate too low or callback exception | Check sample_rate config; add exception handling |
| High agent latency | Synchronous callback operations | Switch to async/batched writes |
| Incomplete traces | Context not propagated | Implement context variables pattern |
| Cost metrics don’t match billing | Token counting mismatch | Use provider’s token counting, account for retries |
| Database connection exhaustion | No connection pooling | Configure pool_size and max_overflow |
| Memory growth over time | Buffer not flushing | Check flush thread is running; add buffer size limits |
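The billing-mismatch row deserves a concrete illustration. The sketch below assumes hypothetical per-1K-token prices and a hypothetical `Attempt` record; neither comes from ADK or any provider SDK. It shows how counting only the final successful attempt understates spend when an earlier attempt consumed tokens before being retried. Whether a failed attempt is billed depends on the provider and the failure mode, so treat this as a reconciliation aid, not a billing rule.

```python
from dataclasses import dataclass

# Hypothetical per-1K-token prices in USD; substitute your provider's current rates.
PRICE_PER_1K = {"input": 0.0005, "output": 0.0015}

@dataclass
class Attempt:
    """One LLM call attempt, using the token counts reported by the provider."""
    input_tokens: int
    output_tokens: int
    succeeded: bool

def attempt_cost(a: Attempt) -> float:
    return (
        (a.input_tokens / 1000) * PRICE_PER_1K["input"]
        + (a.output_tokens / 1000) * PRICE_PER_1K["output"]
    )

def request_cost(attempts: list[Attempt]) -> float:
    # Sum over every attempt: tokens consumed before a retry still count.
    return sum(attempt_cost(a) for a in attempts)

attempts = [
    Attempt(input_tokens=2000, output_tokens=0, succeeded=False),  # timed out
    Attempt(input_tokens=2000, output_tokens=1000, succeeded=True),
]
naive = attempt_cost(attempts[-1])  # counts only the successful attempt
full = request_cost(attempts)       # includes the retried attempt
```

Comparing `naive` against `full` per request gives a quick measure of how much retry traffic contributes to the gap between your metrics and the invoice.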
Benchmarking Callback Overhead
Before deploying, measure the actual overhead your callbacks add:
```python
# scripts/benchmark_callbacks.py
import asyncio
import time
import statistics

from google.adk.agents import Agent

async def benchmark_callback_overhead(
    agent: Agent,
    num_requests: int = 100,
    warmup_requests: int = 10
):
    """Measure the latency overhead introduced by callbacks."""
    # Warmup
    for _ in range(warmup_requests):
        await agent.run("Say hello")

    # Benchmark without callbacks
    agent_no_callbacks = Agent(
        model=agent.model,
        system_instruction=agent.system_instruction
    )
    baseline_latencies = []
    for _ in range(num_requests):
        start = time.perf_counter()
        await agent_no_callbacks.run("Say hello")
        baseline_latencies.append((time.perf_counter() - start) * 1000)

    # Benchmark with callbacks
    instrumented_latencies = []
    for _ in range(num_requests):
        start = time.perf_counter()
        await agent.run("Say hello")
        instrumented_latencies.append((time.perf_counter() - start) * 1000)

    baseline_p50 = statistics.median(baseline_latencies)
    baseline_p99 = statistics.quantiles(baseline_latencies, n=100)[98]
    instrumented_p50 = statistics.median(instrumented_latencies)
    instrumented_p99 = statistics.quantiles(instrumented_latencies, n=100)[98]

    print(f"Baseline P50: {baseline_p50:.2f}ms, P99: {baseline_p99:.2f}ms")
    print(f"Instrumented P50: {instrumented_p50:.2f}ms, P99: {instrumented_p99:.2f}ms")
    print(f"Overhead P50: {instrumented_p50 - baseline_p50:.2f}ms")
    print(f"Overhead P99: {instrumented_p99 - baseline_p99:.2f}ms")

    # Fail if overhead exceeds threshold
    assert instrumented_p50 - baseline_p50 < 5, "P50 overhead exceeds 5ms!"
    assert instrumented_p99 - baseline_p99 < 20, "P99 overhead exceeds 20ms!"

if __name__ == "__main__":
    from src.observability.factory import create_production_agent
    agent = create_production_agent()
    asyncio.run(benchmark_callback_overhead(agent))
```
📝 Note: Well-implemented callbacks should add less than 5ms of P50 latency. If you’re seeing more, profile your callback code — the culprit is usually synchronous I/O or excessive JSON serialization.
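When the end-to-end benchmark shows too much overhead, it helps to know which callback method is responsible. One way to find out, sketched here with only the standard library, is a timing decorator. The `timed` decorator, the `callback_timings` registry, and `ExampleCallback` are hypothetical names introduced for illustration.

```python
import functools
import time
from collections import defaultdict

# Hypothetical in-process registry: callback name -> list of durations in ms.
callback_timings: dict[str, list[float]] = defaultdict(list)

def timed(func):
    """Record how long each invocation of a callback method takes."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            # Runs even if the callback raises, so failures are timed too
            elapsed_ms = (time.perf_counter() - start) * 1000
            callback_timings[func.__qualname__].append(elapsed_ms)
    return wrapper

class ExampleCallback:
    @timed
    def on_llm_end(self, response, **kwargs):
        # Stand-in for real callback work
        return len(str(response))

cb = ExampleCallback()
for _ in range(3):
    cb.on_llm_end({"text": "hello"})
```

Feeding the recorded durations into `statistics.median` or `statistics.quantiles` gives per-callback percentiles you can compare against the 5ms budget.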
Scaling to High Throughput
For applications handling thousands of requests per second:
```python
# src/observability/high_throughput.py
import asyncio
from collections import deque
from dataclasses import dataclass
from typing import Deque

import aiomultiprocess

@dataclass
class MetricsBatch:
    """Batch of metrics for bulk processing."""
    timestamps: list[float]
    latencies: list[float]
    costs: list[float]
    model_names: list[str]

class HighThroughputCollector:
    """
    Lock-free metrics collection for high-throughput scenarios.
    Uses ring buffers and background aggregation.
    """
    def __init__(
        self,
        buffer_size: int = 10000,
        aggregation_interval: float = 1.0,
        num_workers: int = 4
    ):
        # Ring buffer for lock-free writes
        self._buffer: Deque[dict] = deque(maxlen=buffer_size)
        self._aggregation_interval = aggregation_interval
        self._num_workers = num_workers
        self._running = False

    def record(self, metric: dict):
        """
        Record a metric. This is O(1) and never blocks.
        Old metrics are dropped if buffer is full.
        """
        self._buffer.append(metric)

    async def start(self):
        """Start background aggregation workers."""
        self._running = True
        async with aiomultiprocess.Pool(processes=self._num_workers) as pool:
            while self._running:
                await asyncio.sleep(self._aggregation_interval)
                # Drain buffer
                batch = []
                while self._buffer:
                    try:
                        batch.append(self._buffer.popleft())
                    except IndexError:
                        break
                if batch:
                    # Parallel aggregation across workers.
                    # max(1, ...) avoids a zero step when the batch is
                    # smaller than the worker count.
                    chunk_size = max(1, len(batch) // self._num_workers)
                    chunks = [
                        batch[i:i + chunk_size]
                        for i in range(0, len(batch), chunk_size)
                    ]
                    results = await pool.map(self._aggregate_chunk, chunks)
                    combined = self._combine_aggregations(results)
                    await self._emit_metrics(combined)

    @staticmethod
    async def _aggregate_chunk(metrics: list[dict]) -> dict:
        """Aggregate a chunk of metrics (runs in worker process).

        Declared async because aiomultiprocess workers run coroutines.
        """
        aggregated = {
            'count': len(metrics),
            'total_cost': sum(m.get('cost', 0) for m in metrics),
            'latencies': [m.get('latency', 0) for m in metrics],
            'by_model': {}
        }
        for metric in metrics:
            model = metric.get('model', 'unknown')
            if model not in aggregated['by_model']:
                aggregated['by_model'][model] = {'count': 0, 'cost': 0}
            aggregated['by_model'][model]['count'] += 1
            aggregated['by_model'][model]['cost'] += metric.get('cost', 0)
        return aggregated

    def _combine_aggregations(self, results: list[dict]) -> dict:
        """Combine aggregations from multiple workers."""
        combined = {
            'count': sum(r['count'] for r in results),
            'total_cost': sum(r['total_cost'] for r in results),
            'latencies': [],
            'by_model': {}
        }
        for result in results:
            combined['latencies'].extend(result['latencies'])
            for model, stats in result['by_model'].items():
                if model not in combined['by_model']:
                    combined['by_model'][model] = {'count': 0, 'cost': 0}
                combined['by_model'][model]['count'] += stats['count']
                combined['by_model'][model]['cost'] += stats['cost']
        return combined

    async def _emit_metrics(self, aggregated: dict):
        """Emit aggregated metrics to monitoring backend."""
        # Prometheus, DataDog, CloudWatch, etc.
        pass
```
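The collector above leaves `_emit_metrics` as a stub. Before sending anything to a backend you will usually want to reduce the raw `latencies` list to a few percentiles. A minimal sketch of that reduction follows; `summarize_latencies` is a hypothetical helper, not part of the collector, and the choice of P50/P95/P99 is an assumption.

```python
import statistics

def summarize_latencies(latencies: list[float]) -> dict:
    """Reduce raw latencies (ms) to the percentiles a dashboard needs."""
    if not latencies:
        return {"count": 0}
    if len(latencies) == 1:
        # statistics.quantiles needs at least two data points
        only = latencies[0]
        return {"count": 1, "p50": only, "p95": only, "p99": only}
    cuts = statistics.quantiles(latencies, n=100)  # 99 cut points
    return {
        "count": len(latencies),
        "p50": statistics.median(latencies),
        "p95": cuts[94],
        "p99": cuts[98],
    }

summary = summarize_latencies([float(ms) for ms in range(1, 101)])
```

Emitting a handful of summary values per interval keeps the payload size constant no matter how many requests were aggregated.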
Resource Limits and Backpressure
```yaml
# kubernetes/deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: ai-agent-service
spec:
  replicas: 3
  template:
    spec:
      containers:
        - name: agent
          resources:
            requests:
              memory: "512Mi"
              cpu: "500m"
            limits:
              memory: "2Gi"
              cpu: "2000m"
          env:
            # Limit callback buffer to prevent OOM
            - name: OBSERVABILITY_BUFFER_MAX_SIZE
              value: "50000"
            - name: OBSERVABILITY_BUFFER_MAX_MEMORY_MB
              value: "256"
            # Drop metrics under extreme load rather than crash
            - name: OBSERVABILITY_BACKPRESSURE_STRATEGY
              value: "drop_oldest"
```
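The environment variables in the manifest only help if the application reads and enforces them. The following is one possible sketch of the consuming side. `BoundedMetricBuffer` and `buffer_from_env` are hypothetical names, the `drop_newest` strategy is an assumption added for contrast, and `OBSERVABILITY_BUFFER_MAX_MEMORY_MB` is not handled here.

```python
import os
from collections import deque

class BoundedMetricBuffer:
    """In-memory buffer that sheds load instead of growing without bound."""

    def __init__(self, max_size: int, strategy: str = "drop_oldest"):
        if strategy not in ("drop_oldest", "drop_newest"):
            raise ValueError(f"unknown backpressure strategy: {strategy}")
        self._items: deque = deque()
        self._max_size = max_size
        self._strategy = strategy
        self.dropped = 0  # expose this as a metric so shedding is visible

    def append(self, item: dict) -> None:
        if len(self._items) >= self._max_size:
            self.dropped += 1
            if self._strategy == "drop_newest":
                return               # discard the incoming item
            self._items.popleft()    # drop_oldest: evict the oldest item
        self._items.append(item)

    def drain(self) -> list:
        items = list(self._items)
        self._items.clear()
        return items

def buffer_from_env() -> BoundedMetricBuffer:
    # Variable names match the Deployment manifest; the defaults are assumptions.
    return BoundedMetricBuffer(
        max_size=int(os.environ.get("OBSERVABILITY_BUFFER_MAX_SIZE", "50000")),
        strategy=os.environ.get("OBSERVABILITY_BACKPRESSURE_STRATEGY", "drop_oldest"),
    )

buf = BoundedMetricBuffer(max_size=3)
for i in range(5):
    buf.append({"seq": i})
```

Tracking `dropped` matters as much as the cap itself: silent shedding would make dashboards look healthy exactly when the pipeline is overloaded.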
Conclusion and Next Steps
You now have the foundation for a production-grade observability pipeline that tracks every LLM call, measures costs down to the penny, and provides the latency insights needed to optimize agent performance.
Key takeaways:
- Use async, batched writes — Never let observability code block your agent’s critical path
- Propagate context religiously — Trace IDs that connect user requests to individual LLM calls are invaluable during incidents
- Build for failure — Circuit breakers and graceful degradation ensure your observability layer never takes down your application
- Measure the observer — Benchmark your callbacks and set overhead budgets
Recommended next steps:
- Add anomaly detection — Use the cost and latency data you’re now collecting to build ML models that detect unusual patterns before users complain
- Implement cost attribution — Extend the audit logs to include feature flags and A/B test variants so you can calculate ROI per feature
- Build self-healing capabilities — When the pipeline detects repeated failures or cost spikes, automatically switch to fallback models or rate limit specific users
- Create executive dashboards — The data you’re collecting enables powerful business intelligence: cost per customer segment, LLM ROI by use case, capacity planning forecasts
The observability pipeline you’ve built isn’t just about monitoring — it’s the foundation for continuous improvement of your AI agents. Every millisecond saved and every dollar tracked compounds into significant competitive advantage over time.
Common Mistakes and Troubleshooting
Mistake 1: Blocking the Agent Event Loop
The most critical error I see in production deployments is performing synchronous I/O operations inside callbacks, which blocks the entire agent execution.
```python
import asyncio
import time
from typing import Optional

import aiohttp
import requests

# ❌ WRONG: Blocking the event loop with synchronous calls
class BlockingCallback(BaseCallback):
    def on_llm_end(self, response: LLMResponse, **kwargs):
        # This blocks the agent while waiting for the database
        self.db_connection.execute(
            "INSERT INTO metrics (tokens, cost) VALUES (?, ?)",
            (response.token_count, response.cost)
        )
        # Synchronous HTTP call - blocks until complete
        requests.post("https://metrics.example.com/ingest", json=response.to_dict())

# ✅ CORRECT: Non-blocking async operations with buffering
class NonBlockingCallback(BaseCallback):
    def __init__(self):
        self._buffer: list[dict] = []
        self._buffer_lock = asyncio.Lock()
        self._flush_task: Optional[asyncio.Task] = None

    async def on_llm_end(self, response: LLMResponse, **kwargs):
        # Quick append to in-memory buffer
        async with self._buffer_lock:
            self._buffer.append({
                "tokens": response.token_count,
                "cost": response.cost,
                "timestamp": time.time()
            })
        # Schedule flush if buffer is large enough
        if len(self._buffer) >= 100 and not self._flush_task:
            self._flush_task = asyncio.create_task(self._flush_buffer())

    async def _flush_buffer(self):
        try:
            async with self._buffer_lock:
                to_flush = self._buffer.copy()
                self._buffer.clear()
            # Async HTTP call - doesn't block the agent
            async with aiohttp.ClientSession() as session:
                await session.post(
                    "https://metrics.example.com/ingest",
                    json=to_flush
                )
        finally:
            # Clear the handle even if the flush fails, so later flushes can run
            self._flush_task = None
```
⚠️ Warning: A single blocking call in a callback can add hundreds of milliseconds to every LLM interaction. At scale, this compounds into significant latency degradation and increased costs.
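Sometimes the only client available is synchronous. In that case, one option is to push the blocking call onto a worker thread with `asyncio.to_thread`, which is part of the standard library. The sketch below uses a hypothetical `blocking_write` stand-in and a heartbeat task to show that the event loop keeps making progress while the write is in flight.

```python
import asyncio
import time

def blocking_write(record: dict) -> str:
    """Stand-in for a synchronous client call (database driver, HTTP SDK)."""
    time.sleep(0.2)
    return f"stored:{record['id']}"

async def heartbeat(ticks: list, stop: asyncio.Event) -> None:
    # Simulates other agent work that needs the event loop
    while not stop.is_set():
        ticks.append(time.perf_counter())
        await asyncio.sleep(0.01)

async def run_demo() -> tuple:
    ticks: list = []
    stop = asyncio.Event()
    beat = asyncio.create_task(heartbeat(ticks, stop))
    # to_thread runs the blocking call in a worker thread; the loop keeps running
    result = await asyncio.to_thread(blocking_write, {"id": 42})
    stop.set()
    await beat
    return result, len(ticks)

result, tick_count = asyncio.run(run_demo())
```

Had `blocking_write` been called directly inside the coroutine, the heartbeat would have recorded a single tick before stalling for the full duration of the write.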
Mistake 2: Missing Error Boundaries
An exception raised inside a callback can propagate and abort the agent run. Always wrap callback logic in an error boundary so that a failing metrics backend degrades observability, not the agent.
```python
import logging
import time

logger = logging.getLogger(__name__)

# ❌ WRONG: Unhandled exceptions crash the agent
class FragileCallback(BaseCallback):
    def on_tool_start(self, tool_name: str, **kwargs):
        # If metrics service is down, this throws and kills the agent
        self.metrics_client.increment(f"tool.{tool_name}.calls")

# ✅ CORRECT: Graceful degradation with error boundaries
class ResilientCallback(BaseCallback):
    def __init__(self):
        self._error_count = 0
        self._circuit_open = False
        self._last_error_time = 0

    def on_tool_start(self, tool_name: str, **kwargs):
        # Circuit breaker pattern - skip if too many recent failures
        if self._circuit_open:
            if time.time() - self._last_error_time > 60:  # 60s cooldown
                self._circuit_open = False
                self._error_count = 0
            else:
                return  # Skip silently during circuit break
        try:
            self.metrics_client.increment(f"tool.{tool_name}.calls")
            self._error_count = 0  # Reset on success
        except Exception as e:
            self._error_count += 1
            self._last_error_time = time.time()
            # Open circuit after 5 consecutive failures
            if self._error_count >= 5:
                self._circuit_open = True
                logger.warning(
                    f"Circuit breaker opened for metrics callback: {e}"
                )
            else:
                logger.debug(f"Metrics callback failed (attempt {self._error_count}): {e}")
```
Mistake 3: Incorrect Cost Calculation for Streaming Responses
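Streaming responses require special handling because tokens arrive incrementally. If each chunk reports a cumulative token count, summing the per-chunk values counts the same tokens again with every chunk.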
```python
import time
from dataclasses import dataclass

@dataclass
class StreamSession:
    """Per-stream bookkeeping; fields follow how the callback below uses them."""
    start_time: float
    previous_token_count: int = 0
    total_tokens: int = 0

# ❌ WRONG: Counting tokens per chunk leads to massive overcounting
class BrokenStreamingCallback(BaseCallback):
    def on_llm_stream_chunk(self, chunk: StreamChunk, **kwargs):
        # This counts the CUMULATIVE tokens in each chunk, not incremental
        self.total_tokens += chunk.token_count  # Overcounts by 10-50x!

# ✅ CORRECT: Track incremental tokens in streaming responses
class AccurateStreamingCallback(BaseCallback):
    def __init__(self):
        self._stream_sessions: dict[str, StreamSession] = {}

    def on_llm_stream_start(self, stream_id: str, **kwargs):
        self._stream_sessions[stream_id] = StreamSession(
            start_time=time.time(),
            previous_token_count=0
        )

    def on_llm_stream_chunk(self, stream_id: str, chunk: StreamChunk, **kwargs):
        session = self._stream_sessions.get(stream_id)
        if not session:
            return
        # Calculate INCREMENTAL tokens (current - previous)
        incremental_tokens = chunk.cumulative_tokens - session.previous_token_count
        session.previous_token_count = chunk.cumulative_tokens
        session.total_tokens += incremental_tokens

    def on_llm_stream_end(self, stream_id: str, **kwargs):
        session = self._stream_sessions.pop(stream_id, None)
        if session:
            # Now we have accurate token count for the entire stream
            self._record_metrics(
                tokens=session.total_tokens,
                duration=time.time() - session.start_time
            )
```
Troubleshooting Decision Flow
When your observability pipeline behaves unexpectedly, follow this diagnostic flow:
flowchart TD
A[Metrics Missing or Incorrect] --> B{Check Callback Registration}
B -->|Not Registered| C[Add callback to agent config]
B -->|Registered| D{Check Error Logs}
D -->|Exceptions Found| E{Exception Type}
E -->|Connection Error| F[Check network/credentials]
E -->|Timeout| G[Increase timeout or add buffering]
E -->|Serialization| H[Validate payload schema]
D -->|No Exceptions| I{Check Callback Execution}
I -->|Not Called| J[Verify event hooks match ADK version]
I -->|Called| K{Check Metric Values}
K -->|Zero Values| L[Inspect response object structure]
K -->|Wrong Values| M[Debug token/cost calculation logic]
K -->|Correct Values| N[Check metrics backend ingestion]
N --> O{Data in Backend?}
O -->|No| P[Check backend connection and auth]
O -->|Yes| Q[Issue is in visualization/queries]
Mistake 4: Memory Leaks from Unbounded Context Storage
```python
import time
from collections import OrderedDict
from dataclasses import dataclass

@dataclass
class ContextEntry:
    """Per-request state; fields follow how the callback below uses them."""
    start_time: float
    tokens: int = 0

# ❌ WRONG: Storing context indefinitely causes memory leaks
class LeakyCallback(BaseCallback):
    def __init__(self):
        self.request_contexts = {}  # Never cleaned up!

    def on_request_start(self, request_id: str, **kwargs):
        self.request_contexts[request_id] = {
            "start_time": time.time(),
            "tokens": 0
        }

# ✅ CORRECT: TTL-based cleanup with bounded storage
class BoundedCallback(BaseCallback):
    def __init__(self, max_contexts: int = 10000, ttl_seconds: int = 300):
        self._contexts: OrderedDict[str, ContextEntry] = OrderedDict()
        self._max_contexts = max_contexts
        self._ttl = ttl_seconds

    def on_request_start(self, request_id: str, **kwargs):
        self._cleanup_expired()
        # Evict oldest if at capacity
        while len(self._contexts) >= self._max_contexts:
            self._contexts.popitem(last=False)
        self._contexts[request_id] = ContextEntry(
            start_time=time.time(),
            tokens=0
        )

    def _cleanup_expired(self):
        now = time.time()
        expired = [
            rid for rid, ctx in self._contexts.items()
            if now - ctx.start_time > self._ttl
        ]
        for rid in expired:
            del self._contexts[rid]
```
💡 Tip: In production, monitor your callback’s memory footprint using tracemalloc or similar profiling tools. Set alerts for memory growth that exceeds expected bounds.
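As a starting point for that kind of monitoring, here is a small sketch built on the standard `tracemalloc` module. The `measure_growth` helper and the `leaky_work` function are hypothetical; the latter imitates a callback that stores one context per request and never evicts.

```python
import tracemalloc

def measure_growth(work) -> int:
    """Return the net bytes of Python allocations retained after running `work`."""
    tracemalloc.start()
    before, _ = tracemalloc.get_traced_memory()
    retained = work()  # keep a reference so the allocations stay alive
    after, _ = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    del retained
    return after - before

def leaky_work() -> dict:
    # Simulates a callback that stores one context per request and never evicts
    return {f"req-{i}": {"start_time": float(i), "tokens": 0} for i in range(10_000)}

growth_bytes = measure_growth(leaky_work)
```

For line-level attribution, `tracemalloc.take_snapshot()` combined with `Snapshot.compare_to(old_snapshot, "lineno")` reports which source lines account for the growth between two points in time.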
Conclusion and Next Steps
Building production-ready observability for AI agents requires treating callbacks as first-class infrastructure components, not afterthoughts. Throughout this guide, we’ve covered the essential patterns:
Cost Attribution: Track every token and compute dollar back to specific users, features, and requests using hierarchical metric labels.
Latency Decomposition: Measure not just total response time but also time to first token (TTFT), time to last token (TTLT), and per-component durations to identify optimization opportunities.
Error Classification: Categorize failures by type, recoverability, and cost impact to build effective alerting and auto-remediation.
Non-blocking Architecture: Use buffering, async I/O, and circuit breakers to ensure observability never degrades agent performance.
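The latency decomposition pattern above is easy to express once you have a timestamp for the request and for each streamed chunk. A minimal sketch follows, assuming monotonic timestamps in seconds (for example from `time.perf_counter`); `LatencyBreakdown` and `decompose_latency` are hypothetical names introduced here.

```python
from dataclasses import dataclass

@dataclass
class LatencyBreakdown:
    ttft_ms: float        # request sent -> first token received
    ttlt_ms: float        # request sent -> last token received
    generation_ms: float  # first token -> last token

def decompose_latency(request_sent: float, chunk_times: list[float]) -> LatencyBreakdown:
    """Derive TTFT and TTLT from monotonic timestamps given in seconds."""
    if not chunk_times:
        raise ValueError("no chunks received")
    first, last = chunk_times[0], chunk_times[-1]
    return LatencyBreakdown(
        ttft_ms=(first - request_sent) * 1000,
        ttlt_ms=(last - request_sent) * 1000,
        generation_ms=(last - first) * 1000,
    )

breakdown = decompose_latency(10.0, [10.25, 10.5, 11.0])
```

A high TTFT with a short generation phase points at queueing or prompt processing, while the reverse points at output length, so the two numbers suggest different optimizations.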
The investment in proper observability infrastructure pays dividends immediately. Teams with mature observability pipelines consistently report:
- 30-50% reduction in wasted LLM spend through identifying and eliminating inefficient prompts
- 60% faster incident resolution through precise latency attribution
- Proactive cost management with accurate forecasting and anomaly detection
Recommended Next Steps
Week 1-2: Foundation
- Deploy the basic callback structure with cost and latency tracking
- Configure Prometheus/Grafana or your preferred metrics backend
- Set up initial dashboards for the four golden signals
Week 3-4: Enhancement
- Add distributed tracing integration with OpenTelemetry
- Implement budget alerting with multiple threshold tiers
- Create per-user and per-feature cost attribution
Month 2: Optimization
- Analyze collected data to identify prompt optimization opportunities
- Build automated anomaly detection for cost spikes
- Implement predictive scaling based on usage patterns
Ongoing: Refinement
- Review and tune alert thresholds monthly
- Add custom business metrics as requirements evolve
- Share dashboards with stakeholders for cost transparency
📝 Note: The observability pipeline you build today becomes the foundation for advanced capabilities like A/B testing prompt variations, automated cost optimization, and ML-driven anomaly detection. Design for extensibility from the start.
Additional Resources