Our initial MLOps inference pipeline was a black box. An HTTP request would trigger a prediction, but diagnosing failures or performance degradation was a painful exercise in archaeology. We had logs from the Node.js API gateway and separate, disconnected logs from the Python model-serving workers. Correlating a specific user-facing error with the exact inference task that caused it involved grepping through multiple log streams with timestamps and a prayer. This was not sustainable. The clear technical mandate was to build a unified, queryable view for the entire lifecycle of every single inference request, from the moment it hit our edge service to the moment a prediction was calculated.
The core concept was simple: a single, unique trace_id generated at the entry point must be propagated through every component of the system. This ID would be the primary key for correlating logs and metrics. For our stack, this meant bridging a language gap. The API gateway is an Express.js application running on Node.js, chosen for its performance and our team’s expertise in the JavaScript ecosystem. The heavy lifting of model inference is handled by Python workers using Celery, the de facto standard for distributed task processing in Python. For the observability backend, we bypassed generic logging platforms. Our goal was not just to store logs but to perform complex time-series analysis on performance metrics—p99 latency per model version, inference duration distributions, and failure rates over time. This requirement made TimescaleDB a perfect fit, offering the power of SQL with specialized optimizations for time-series data.
Finally, this intricate system of tracing could not be left to chance. A broken link in the traceability chain would render the entire system useless. Therefore, we made Jest a first-class citizen, not just for testing application logic, but for writing integration tests that rigorously enforce our observability contract, ensuring that every component correctly propagates and records the trace_id.
The Observability Schema in TimescaleDB
Before writing a single line of application code, we defined the data contract in the database. A single, unified hypertable would store events from all services. Using a single table simplifies querying immensely, as we never need to JOIN
across different tables to reconstruct a request’s journey.
In a real-world project, you’d manage this with a migration tool like node-pg-migrate or Alembic. For clarity, here is the raw SQL.
-- Represents an event in the lifecycle of an inference request.
-- This table is designed to be a TimescaleDB hypertable.
CREATE TABLE inference_events (
    "time" TIMESTAMPTZ NOT NULL,
    trace_id UUID NOT NULL,
    service_name VARCHAR(50) NOT NULL,
    event_name VARCHAR(100) NOT NULL, -- e.g., 'request_received', 'task_enqueued', 'inference_start'
    model_id VARCHAR(100), -- The identifier for the model being used
    model_version VARCHAR(20),
    duration_ms BIGINT, -- Duration of the specific event, if applicable
    status_code INT, -- e.g., HTTP status for API events, 0 for success/1 for failure for worker events
    metadata JSONB -- Flexible field for additional context, like request headers or error messages
);
-- Add comments for clarity on critical columns
COMMENT ON COLUMN inference_events.trace_id IS 'Unique identifier for an entire request lifecycle, generated at the API gateway.';
COMMENT ON COLUMN inference_events.service_name IS 'The name of the service emitting the event (e.g., express-api, celery-worker-vision).';
COMMENT ON COLUMN inference_events.event_name IS 'A specific, well-defined step within the service.';
COMMENT ON COLUMN inference_events.duration_ms IS 'Time taken for this specific event, not the cumulative time.';
COMMENT ON COLUMN inference_events.metadata IS 'Stores context like error stack traces, input feature dimensions, etc.';
-- Create a composite index for efficient lookups by trace_id and time.
-- This is crucial for reconstructing a single trace quickly.
CREATE INDEX idx_trace_id_time ON inference_events (trace_id, "time" DESC);
-- Create an index on service and event names for analytical queries.
-- Useful for calculating metrics per service/event type.
CREATE INDEX idx_service_event_time ON inference_events (service_name, event_name, "time" DESC);
-- Finally, convert the regular table into a TimescaleDB hypertable, partitioned by the time column.
-- This is the magic that makes time-series queries fast.
SELECT create_hypertable('inference_events', 'time');
The choice of a UUID for trace_id ensures global uniqueness without a central coordinator. The composite index on (trace_id, "time") is the most critical performance consideration here; it allows TimescaleDB to rapidly retrieve all events for a given trace in chronological order. The metadata JSONB column provides the flexibility to add rich context without altering the schema for every new piece of information we want to track.
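To make that concrete, reconstructing a single request’s journey is exactly the access pattern the composite index serves. A minimal sketch of that query (the trace_id value is a placeholder):
-- Reconstruct the full journey of one request, in chronological order.
-- This is the lookup the (trace_id, "time") composite index is built for.
SELECT
    "time",
    service_name,
    event_name,
    duration_ms,
    status_code,
    metadata
FROM
    inference_events
WHERE
    trace_id = '00000000-0000-0000-0000-000000000000' -- placeholder trace_id
ORDER BY
    "time" ASC;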
The Express.js Gateway: Generation and Propagation
The Express.js application serves as the system’s entry point. Its two primary responsibilities in our observability scheme are:
- Generate a unique trace_id for every incoming request.
- Pass this trace_id downstream into the message queue for the Celery workers.
We accomplish the first with a simple middleware. This ensures that the ID is available on the request
object for all subsequent handlers and loggers.
// src/middleware/trace.js
import { randomUUID } from 'crypto';
// This middleware injects a unique trace_id into every request.
// In a production environment, it should also check for an existing trace ID
// header (e.g., X-Request-ID, or W3C Trace Context headers) and use it if present.
export const traceMiddleware = (req, res, next) => {
  req.traceId = randomUUID();
  res.setHeader('X-Trace-ID', req.traceId);
  next();
};
The next piece is the structured logger. We use pino
for its performance and JSON output. A common mistake is to have logging be an afterthought. Here, we create a logger instance that is aware of our observability context.
// src/lib/logger.js
import pino from 'pino';
const logger = pino({
  level: process.env.LOG_LEVEL || 'info',
  formatters: {
    level: (label) => {
      return { level: label };
    },
  },
  // Base properties to include in every log message.
  base: {
    service: 'express-api',
    pid: process.pid,
  },
  timestamp: pino.stdTimeFunctions.isoTime,
});

// A helper to create a child logger with a specific traceId.
// This ensures all logs for a given request are tagged with the same ID.
export const getRequestLogger = (traceId) => {
  return logger.child({ traceId });
};
Now, let’s wire it all together in the main application file and an endpoint handler. We also need a mechanism to enqueue tasks. For this, we’ll use amqplib
to communicate directly with RabbitMQ, which Celery uses as a message broker.
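The two connection helpers imported below, getDbPool and getAmqpChannel, are thin wrappers whose implementations are not part of this walkthrough. A minimal sketch, with the connection URLs and the queue assertion options as assumptions, might look like this:
// src/lib/db.js (sketch)
import { Pool } from 'pg';

let pool;
export const getDbPool = () => {
  // Lazily create one shared pool per process.
  if (!pool) {
    pool = new Pool({ connectionString: process.env.DATABASE_URL });
  }
  return pool;
};

// src/lib/rabbitmq.js (sketch)
import amqp from 'amqplib';

let channel;
export const getAmqpChannel = async () => {
  // Lazily open a single connection/channel and make sure the queue exists.
  if (!channel) {
    const connection = await amqp.connect(process.env.AMQP_URL || 'amqp://guest:guest@localhost:5672');
    channel = await connection.createChannel();
    await channel.assertQueue('ml_inference', { durable: true });
  }
  return channel;
};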
// src/server.js
import express from 'express';
import { traceMiddleware } from './middleware/trace.js';
import { getRequestLogger } from './lib/logger.js';
import { getDbPool } from './lib/db.js';
import { getAmqpChannel } from './lib/rabbitmq.js';

const app = express();
app.use(express.json());
app.use(traceMiddleware);

const PORT = process.env.PORT || 3000;

// Centralized event logging function
async function logInferenceEvent(event) {
  const pool = getDbPool();
  try {
    await pool.query(
      `INSERT INTO inference_events (
        "time", trace_id, service_name, event_name, model_id, model_version,
        duration_ms, status_code, metadata
      ) VALUES (NOW(), $1, 'express-api', $2, $3, $4, $5, $6, $7)`,
      [
        event.traceId,
        event.eventName,
        event.modelId,
        event.modelVersion,
        event.durationMs,
        event.statusCode,
        event.metadata || {},
      ]
    );
  } catch (err) {
    // In production, this failure should be sent to a dead-letter queue
    // or an emergency logger. Swallowing it silently blinds us.
    console.error('Failed to log inference event to TimescaleDB', err);
  }
}

app.post('/predict/:modelId/:modelVersion', async (req, res) => {
  const startTime = process.hrtime.bigint();
  const logger = getRequestLogger(req.traceId);
  const { modelId, modelVersion } = req.params;
  const features = req.body.features;

  if (!features) {
    return res.status(400).json({ error: 'Missing features in request body' });
  }

  logger.info({ modelId, modelVersion }, 'Prediction request received');
  await logInferenceEvent({
    traceId: req.traceId,
    eventName: 'request_received',
    modelId,
    modelVersion,
    statusCode: 202, // Accepted
  });

  try {
    const channel = await getAmqpChannel();
    const celeryQueue = 'ml_inference';

    // This is the critical propagation step. The trace_id and other metadata
    // are packaged into the Celery message payload.
    const task = {
      id: req.traceId, // Using traceId as task id for easy correlation in Flower
      task: 'worker.tasks.predict',
      args: [features],
      kwargs: {
        model_id: modelId,
        model_version: modelVersion,
      },
    };

    // Celery messages need a specific structure and content-type.
    // The `headers` field is where Celery expects routing information.
    // We embed our tracing context directly inside the task payload itself, under a custom `context` key.
    const messageBody = {
      ...task,
      // Custom context field that our worker will be designed to look for.
      context: {
        trace_id: req.traceId,
        request_time: new Date().toISOString(),
      },
    };

    const message = Buffer.from(JSON.stringify(messageBody));
    channel.sendToQueue(celeryQueue, message, {
      contentType: 'application/json',
      contentEncoding: 'utf-8',
    });

    const durationMs = Number(process.hrtime.bigint() - startTime) / 1_000_000;
    await logInferenceEvent({
      traceId: req.traceId,
      eventName: 'task_enqueued',
      modelId,
      modelVersion,
      durationMs,
      statusCode: 202,
    });

    logger.info({ durationMs }, 'Task enqueued successfully');
    res.status(202).json({
      message: 'Prediction task accepted',
      traceId: req.traceId,
    });
  } catch (err) {
    const durationMs = Number(process.hrtime.bigint() - startTime) / 1_000_000;
    logger.error({ err }, 'Failed to enqueue prediction task');
    await logInferenceEvent({
      traceId: req.traceId,
      eventName: 'task_enqueue_failed',
      modelId,
      modelVersion,
      durationMs,
      statusCode: 500,
      metadata: { error: err.message, stack: err.stack },
    });
    res.status(500).json({ error: 'Internal Server Error' });
  }
});

app.listen(PORT, () => {
  console.log(`API Gateway listening on port ${PORT}`);
});
The pitfall here is message format compatibility. We are not using a Python Celery client to send the message, so we must manually construct a message body that Celery workers can understand and decode. The most robust way to pass our trace_id
is to embed it directly into the task payload, which our Python worker will be specifically designed to parse.
The Celery Worker: Receiving and Using Context
On the Python side, the challenge is to cleanly extract the trace_id
from the task message and make it accessible throughout the task’s execution. A common mistake is to pass the trace_id
as an explicit argument to every function within the task. This clutters the business logic. A better approach is to use a custom Celery Task base class that automatically manages this context.
# worker/celery_app.py
import os
from celery import Celery
BROKER_URL = os.environ.get('CELERY_BROKER_URL', 'amqp://guest:guest@localhost:5672//')
RESULT_BACKEND = os.environ.get('CELERY_RESULT_BACKEND', 'rpc://')
celery_app = Celery(
    'ml_worker',
    broker=BROKER_URL,
    backend=RESULT_BACKEND,
    include=['worker.tasks']
)

celery_app.conf.update(
    task_serializer='json',
    accept_content=['json'],
    result_serializer='json',
    timezone='UTC',
    enable_utc=True,
)
Now for the core logic: a custom Task class and the task itself. We will use Python’s contextvars to store the trace_id, which is the modern, async-friendly equivalent of thread-locals.
# worker/tasks.py
import contextvars
import json
import logging
import time

from celery import Task

from .celery_app import celery_app
from .db import get_db_connection

# Context variable to hold the trace_id for the duration of a task.
trace_id_var = contextvars.ContextVar('trace_id', default=None)

# Configure a structured logger
handler = logging.StreamHandler()
formatter = logging.Formatter('{"timestamp": "%(asctime)s", "level": "%(levelname)s", "service": "celery-worker", "trace_id": "%(trace_id)s", "message": "%(message)s"}')
handler.setFormatter(formatter)


# Custom filter to inject trace_id from contextvar into the log record
class TraceIdFilter(logging.Filter):
    def filter(self, record):
        record.trace_id = trace_id_var.get()
        return True


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.addFilter(TraceIdFilter())


def log_inference_event(event_data):
    """Logs an event to TimescaleDB from the worker."""
    conn = None
    try:
        conn = get_db_connection()
        cur = conn.cursor()
        cur.execute(
            """
            INSERT INTO inference_events (
                "time", trace_id, service_name, event_name, model_id, model_version,
                duration_ms, status_code, metadata
            ) VALUES (NOW(), %s, 'celery-worker', %s, %s, %s, %s, %s, %s::jsonb)
            """,
            (
                trace_id_var.get(),
                event_data.get("event_name"),
                event_data.get("model_id"),
                event_data.get("model_version"),
                event_data.get("duration_ms"),
                event_data.get("status_code"),
                event_data.get("metadata", '{}')
            )
        )
        conn.commit()
        cur.close()
    except Exception as e:
        logger.error(f"Database logging failed: {e}")
    finally:
        if conn:
            conn.close()


# A custom base class for our ML tasks. This is the key to clean context management.
class MLOpsTask(Task):
    def __call__(self, *args, **kwargs):
        # The raw message body from Node.js is in the request context.
        context = self.request.get('context', {})
        trace_id = context.get('trace_id')
        # Set the trace_id in the contextvar so it's accessible everywhere.
        token = trace_id_var.set(trace_id)
        logger.info("Starting task with context.")
        try:
            return super().__call__(*args, **kwargs)
        finally:
            # Important to reset the contextvar after the task is done.
            trace_id_var.reset(token)


def run_mock_inference(features):
    """A placeholder for a real ML model inference call."""
    logger.info("Running mock inference...")
    time.sleep(0.5)  # Simulate work
    # In a real model, this would return classification, regression output etc.
    return {"prediction": "class_A", "confidence": 0.95}


# The actual Celery task. Note it uses our custom base class.
@celery_app.task(base=MLOpsTask, bind=True)
def predict(self, features, model_id, model_version):
    start_time = time.monotonic()
    log_inference_event({
        "event_name": "inference_start",
        "model_id": model_id,
        "model_version": model_version,
        "status_code": 0  # 0 for in-progress/success
    })
    try:
        result = run_mock_inference(features)
        duration_ms = (time.monotonic() - start_time) * 1000
        log_inference_event({
            "event_name": "inference_success",
            "model_id": model_id,
            "model_version": model_version,
            "duration_ms": int(duration_ms),
            "status_code": 0,
            "metadata": json.dumps({"output": result})  # Ensure metadata is a valid JSON string
        })
        logger.info(f"Inference completed in {duration_ms:.2f}ms")
        return result
    except Exception as e:
        duration_ms = (time.monotonic() - start_time) * 1000
        logger.error(f"Inference failed: {e}", exc_info=True)
        log_inference_event({
            "event_name": "inference_failure",
            "model_id": model_id,
            "model_version": model_version,
            "duration_ms": int(duration_ms),
            "status_code": 1,  # 1 for failure
            "metadata": json.dumps({"error": str(e)})
        })
        # Re-raise the exception to mark the task as FAILED in Celery.
        raise
By overriding the __call__ method of the MLOpsTask class, we create a hook that runs before our task’s business logic. It inspects the incoming raw message for our custom context field, extracts the trace_id, and sets it in a contextvar. Now, any code executed within this task, including the logger and the log_inference_event function, can access the correct trace_id without it being passed around as a function parameter.
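One piece referenced in tasks.py but not shown is the get_db_connection helper in worker/db.py. A minimal psycopg2-based sketch (the DATABASE_URL variable and its default are assumptions) could look like this:
# worker/db.py (sketch)
import os

import psycopg2


def get_db_connection():
    """Open a new connection to TimescaleDB for each logging call."""
    dsn = os.environ.get(
        "DATABASE_URL",
        "postgresql://postgres:postgres@localhost:5432/observability",  # assumed default
    )
    return psycopg2.connect(dsn)
Opening a fresh connection per event keeps the sketch simple; a production worker would likely reuse connections via a pool such as psycopg2.pool.SimpleConnectionPool.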
Validating the Chain with Jest Integration Tests
The system is built, but we must prove it works. An integration test will simulate a real-world request and then query the database to verify that the entire chain of events was logged correctly with the same trace_id.
// tests/integration/tracing.test.js
import supertest from 'supertest';
import { Pool } from 'pg';
// Assume app is exported from server.js for testing purposes
import { app } from '../../src/server';

const request = supertest(app);

describe('MLOps Inference Tracing', () => {
  let dbPool;

  beforeAll(() => {
    // Connect to a dedicated test database.
    dbPool = new Pool({ connectionString: process.env.TEST_DATABASE_URL });
  });

  afterAll(async () => {
    await dbPool.end();
  });

  // Clean the events table before each test to ensure isolation.
  beforeEach(async () => {
    await dbPool.query('TRUNCATE TABLE inference_events');
  });

  test('should create a complete trace for a successful prediction request', async () => {
    const modelId = 'test-model';
    const modelVersion = '1.0.0';
    const features = [1, 2, 3];

    // Step 1: Make the API request
    const response = await request
      .post(`/predict/${modelId}/${modelVersion}`)
      .send({ features });

    expect(response.status).toBe(202);
    const traceId = response.body.traceId;
    expect(traceId).toBeDefined();

    // Step 2: Poll the database to verify events are logged.
    // In a real CI/CD pipeline, you need a robust polling mechanism.
    // For this test, a simple delay and retry logic works.
    let events = [];
    let attempts = 0;
    while (events.length < 3 && attempts < 10) {
      await new Promise(resolve => setTimeout(resolve, 500)); // Wait for worker to process
      const { rows } = await dbPool.query(
        'SELECT * FROM inference_events WHERE trace_id = $1 ORDER BY "time" ASC',
        [traceId]
      );
      events = rows;
      attempts++;
    }

    // Step 3: Assert the correctness of the trace
    expect(events.length).toBeGreaterThanOrEqual(3);

    const serviceNames = new Set(events.map(e => e.service_name));
    expect(serviceNames).toContain('express-api');
    expect(serviceNames).toContain('celery-worker');

    const eventNames = events.map(e => e.event_name);
    // Check for the expected sequence of events
    expect(eventNames).toContain('request_received');
    expect(eventNames).toContain('task_enqueued');
    // Depending on timing, this could be success or start, we check for start at minimum
    expect(eventNames).toContain('inference_start');

    // Verify that all events share the same model metadata
    events.forEach(event => {
      expect(event.model_id).toBe(modelId);
      expect(event.model_version).toBe(modelVersion);
    });
  }, 15000); // Increase Jest timeout for this async test
});
This test formalizes our observability requirement. It doesn’t just check if the API returns a 202 Accepted; it confirms the side-effects that are critical for our system’s health. If a code change in the Express app or Celery worker breaks trace propagation, this test will fail, preventing the regression from reaching production.
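The same contract can be enforced on the failure path. A sketch of such a test, assuming getAmqpChannel is mocked to reject and that the Babel-based Jest setup hoists jest.mock above the imports (module path and mock shape are assumptions):
// tests/integration/tracing-failure.test.js (sketch)
import supertest from 'supertest';
import { Pool } from 'pg';

// Force the enqueue step to fail so the gateway takes its error branch.
jest.mock('../../src/lib/rabbitmq.js', () => ({
  getAmqpChannel: jest.fn().mockRejectedValue(new Error('broker unavailable')),
}));

// Assume app is exported from server.js, as in the test above.
import { app } from '../../src/server';

const request = supertest(app);
const dbPool = new Pool({ connectionString: process.env.TEST_DATABASE_URL });

afterAll(async () => {
  await dbPool.end();
});

test('records a task_enqueue_failed event when the broker is unavailable', async () => {
  const response = await request
    .post('/predict/test-model/1.0.0')
    .send({ features: [1, 2, 3] });

  expect(response.status).toBe(500);

  // The gateway awaits both inserts before responding, so no polling is needed here.
  const { rows } = await dbPool.query(
    'SELECT event_name FROM inference_events WHERE trace_id = $1 ORDER BY "time" ASC',
    [response.headers['x-trace-id']]
  );
  const eventNames = rows.map((r) => r.event_name);
  expect(eventNames).toContain('request_received');
  expect(eventNames).toContain('task_enqueue_failed');
});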
sequenceDiagram
    participant Client
    participant ExpressAPI as Express.js API
    participant RabbitMQ
    participant CeleryWorker as Python Celery Worker
    participant TimescaleDB
    Client->>+ExpressAPI: POST /predict/model/v1
    ExpressAPI->>ExpressAPI: Generate trace_id
    ExpressAPI->>TimescaleDB: LOG(request_received, trace_id)
    ExpressAPI->>+RabbitMQ: Enqueue Task (with trace_id in payload)
    ExpressAPI->>TimescaleDB: LOG(task_enqueued, trace_id)
    ExpressAPI-->>-Client: 202 Accepted (with trace_id)
    RabbitMQ->>+CeleryWorker: Deliver Task
    CeleryWorker->>CeleryWorker: Extract trace_id from payload
    CeleryWorker->>TimescaleDB: LOG(inference_start, trace_id)
    CeleryWorker->>CeleryWorker: Run ML Model Inference
    CeleryWorker->>TimescaleDB: LOG(inference_success, trace_id)
    CeleryWorker-->>-RabbitMQ: Ack Task
With this system in place, we can now ask meaningful questions of our MLOps platform. A query like the one below, which was previously impossible, becomes trivial:
-- Calculate p90, p95, and p99 latency for the core inference step
-- of a specific model version over the last 7 days.
-- approx_percentile and percentile_agg are TimescaleDB Toolkit hyperfunctions.
SELECT
    time_bucket('1 hour', "time") AS hour,
    approx_percentile(0.90, percentile_agg(duration_ms)) AS p90_latency_ms,
    approx_percentile(0.95, percentile_agg(duration_ms)) AS p95_latency_ms,
    approx_percentile(0.99, percentile_agg(duration_ms)) AS p99_latency_ms
FROM
    inference_events
WHERE
    event_name = 'inference_success'
    AND model_id = 'prod-classifier'
    AND model_version = '2.3.1'
    AND "time" > NOW() - INTERVAL '7 days'
GROUP BY
    hour
ORDER BY
    hour;
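Failure rates over time, another of the goals stated at the outset, are just as straightforward to compute. One possible query, under the same schema:
-- Hourly failure rate for a specific model over the last 7 days.
SELECT
    time_bucket('1 hour', "time") AS hour,
    count(*) FILTER (WHERE event_name = 'inference_failure')::float
        / NULLIF(count(*), 0) AS failure_rate
FROM
    inference_events
WHERE
    event_name IN ('inference_success', 'inference_failure')
    AND model_id = 'prod-classifier'
    AND "time" > NOW() - INTERVAL '7 days'
GROUP BY
    hour
ORDER BY
    hour;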
This level of detailed, queryable insight is the true output of the MLOps pipeline—not just the predictions themselves, but the data to improve and maintain the system that generates them.
The current implementation, where each service writes directly to TimescaleDB, is simple and effective for moderate loads. However, this pattern introduces tight coupling to the database and can become a performance bottleneck under high throughput, as services might contend for database connections and write locks. A more scalable, production-hardened architecture would involve services emitting structured logs to stdout
or a local agent like Vector, which would then be responsible for batching, buffering, and reliably shipping this data to TimescaleDB. Furthermore, the context propagation is entirely custom. While effective for this two-component system, expanding to more microservices would warrant adopting a standardized framework like OpenTelemetry, which automates context propagation through standard headers and provides SDKs for many languages, reducing the amount of boilerplate code required for instrumentation.