Implementing a Dynamic Model Routing Gateway for RAG Workloads with etcd and Neo4j


The core technical problem is managing a fleet of specialized, fine-tuned transformer models for various Retrieval-Augmented Generation (RAG) tasks. Each model serves a distinct business domain—finance, legal, healthcare—and relies on a corresponding, domain-specific knowledge graph stored in Neo4j. The operational requirement is to route incoming API requests to the appropriate model based on request metadata, with the ability to add, remove, or update model endpoints in real-time without service interruptions or redeployments. A simple load balancer with static routing rules is insufficient for this dynamic environment.

Architectural Contenders

Solution A: The Monolithic Orchestrator

The most straightforward approach is a single, monolithic application that contains all the routing logic and potentially even loads all models into its own process. Routing rules could be stored in a configuration file (config.yaml) or a relational database table.
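For illustration, the monolith's static routing rules might look like this hypothetical config.yaml sketch (the field names are illustrative, not from an actual system):

# config.yaml -- hypothetical static routing rules for the monolith.
# Any change here requires a restart or redeployment.
routes:
  finance:
    service_url: http://model-finance:8000/query
    neo4j_database: finance
  legal:
    service_url: http://model-legal:8000/query
    neo4j_database: legal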

  • Pros:

    • Simplicity: Initial development is fast. The entire logic is in one place, making it easy to reason about for a small number of models.
    • Low Infrastructure Overhead: A single service is easier to deploy and manage than a distributed system. No need for external coordination services.
  • Cons:

    • Tight Coupling: The routing logic is tightly coupled with the model serving code. To update a route for a legal model, the entire monolith must be redeployed, affecting finance and healthcare services. This violates the principle of independent deployability.
    • Resource Contention: If models are loaded in the same process, a memory-intensive model can starve others. A CPU-bound request for one model can increase latency for all others. Scaling is coarse-grained; you scale the entire monolith, even if only one model is under heavy load.
    • Configuration Rigidity: Updating routing rules via a configuration file requires a restart. Using a database adds complexity and introduces another point of failure, and polling the database for changes is inefficient and introduces lag.

In a production environment, the downtime and risk associated with redeploying a monolith to change a simple routing rule are unacceptable. This approach lacks the operational flexibility required.

Solution B: The Service Mesh Approach (e.g., Istio)

A service mesh provides a dedicated infrastructure layer for managing service-to-service communication. We could deploy each model as a separate microservice and use the service mesh’s control plane to manage routing rules based on HTTP headers.

  • Pros:

    • Decoupling: Routing logic is completely externalized from the application code into the service mesh sidecars. Application developers can focus solely on model serving logic.
    • Advanced Traffic Management: Service meshes offer sophisticated features out-of-the-box, such as canary deployments, traffic mirroring, fault injection, and automatic retries.
    • Observability: Provides consistent, platform-level metrics, logs, and traces for all services in the mesh.
  • Cons:

    • Operational Complexity: A service mesh is a powerful but complex distributed system in its own right. It requires specialized expertise to deploy, manage, and debug. The learning curve is steep.
    • Resource Overhead: Each application pod runs a sidecar proxy (e.g., Envoy), which consumes additional CPU and memory resources. For a large number of model services, this overhead can become significant.
    • Semantic Mismatch: A service mesh primarily operates at L4/L7 (network/request level). Our logic is more application-aware; it’s not just routing to a service host:port, but understanding a mapping from a business domain to a model_name, service_url, and its corresponding neo4j_database. While possible to implement, it can feel like fitting a square peg in a round hole.
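
For concreteness, header-based routing with an Istio VirtualService might look roughly like the sketch below (the host names and header key are illustrative, not a definitive manifest; Istio matches header names in lowercase):

apiVersion: networking.istio.io/v1beta1
kind: VirtualService
metadata:
  name: rag-model-routing
spec:
  hosts:
    - rag-gateway.internal        # illustrative entry host
  http:
    - match:
        - headers:
            x-model-domain:
              exact: finance
      route:
        - destination:
            host: model-finance.default.svc.cluster.local
    - match:
        - headers:
            x-model-domain:
              exact: legal
      route:
        - destination:
            host: model-legal.default.svc.cluster.local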

The service mesh is a viable but heavy-handed solution. For our specific need—dynamic, application-aware routing configuration—it introduces more complexity than necessary.

Final Choice: A Smart API Gateway with an etcd Control Plane

The chosen architecture is a middle ground that provides the required dynamism without the full complexity of a service mesh. It consists of three main components:

  1. Model-Serving Microservices: Each Hugging Face Transformer model is containerized as a standalone microservice. It exposes a simple HTTP endpoint and is only responsible for its specific RAG task.
  2. etcd Cluster: A distributed, reliable key-value store used as the control plane. It holds the canonical mapping of business domains to model service configurations.
  3. Smart API Gateway: A custom, lightweight service that acts as the single entry point. It holds no static routing rules. Instead, it connects to etcd, loads the routing configuration into memory, and subscribes to real-time updates.

The interaction between these components is sketched in the following Mermaid diagram:

graph TD
    subgraph Control Plane
        A[Operator/CI/CD] -- etcdctl put --> B(etcd Cluster);
    end

    subgraph Data Plane
        C(Client) -- "POST /api/v1/query\nHeader: X-Model-Domain: finance" --> D{API Gateway};
        D -- "1. Watch for updates" --> B;
        D -- "2. Route Lookup (In-Memory)" --> E[Finance Model Service];
        D -- "Proxy Request" --> E;
        E -- "Cypher Query" --> F(Neo4j: finance_db);
        F -- "Context" --> E;
        E -- "LLM Inference" --> G[Hugging Face Transformer];
        G -- "Answer" --> E;
        E -- "Response" --> D;
        D -- "Response" --> C;

        D -- "2. Route Lookup (In-Memory)" --> H[Legal Model Service];
    end

    style B fill:#f9f,stroke:#333,stroke-width:2px
    style D fill:#bbf,stroke:#333,stroke-width:2px

Rationale:

  • Balance of Power: This architecture separates the control plane (etcd) from the data plane (API Gateway and model services). This is a core principle of robust distributed systems.
  • Operational Simplicity: Updating a route is a simple etcdctl put command, as shown in the example after this list. This can be easily integrated into CI/CD pipelines. The change is propagated almost instantly to the gateway without any restarts.
  • Efficiency: The API Gateway caches the routing table in memory, making lookups extremely fast (nanoseconds). It avoids a network hop to etcd for every request. The watch mechanism is highly efficient, as etcd pushes notifications rather than the gateway polling for changes.
  • Custom Logic: The gateway is our own code, allowing us to implement custom logic beyond simple routing, such as request validation, authentication, or enriched logging, that is specific to our AI workload.
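
For example, registering or updating the finance route (using the key structure defined in the next section) is a single command:

etcdctl put /services/rag/models/finance \
  '{"service_url": "http://model-finance:8000/query", "model_name": "distilbert-base-cased-distilled-squad", "neo4j_uri": "bolt://neo4j:7687", "neo4j_database": "finance", "active": true}'

The gateway's watch on the key prefix picks this up immediately; no restart is involved.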

Core Implementation

The system will be orchestrated using Docker Compose for local development and testing.

1. etcd Configuration Structure

We define a clear key structure within etcd to store our routing and model configuration.

Key Prefix: /services/rag/models/

Example Keys:

  • /services/rag/models/finance
  • /services/rag/models/legal

The value for each key will be a JSON string containing the necessary metadata.

Example finance model configuration:

{
  "service_url": "http://model-finance:8000/query",
  "model_name": "distilbert-base-cased-distilled-squad",
  "neo4j_uri": "bolt://neo4j:7687",
  "neo4j_database": "finance",
  "active": true
}

Example legal model configuration:

{
  "service_url": "http://model-legal:8000/query",
  "model_name": "nlpaueb/legal-bert-base-uncased",
  "neo4j_uri": "bolt://neo4j:7687",
  "neo4j_database": "legal",
  "active": true
}

This structure is extensible. We could later add fields for version, canary_weight, etc.
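
For instance, a canary rollout could be expressed by adding fields the gateway does not yet interpret (version and canary_weight are hypothetical here):

{
  "service_url": "http://model-finance-v2:8000/query",
  "model_name": "distilbert-base-cased-distilled-squad",
  "neo4j_uri": "bolt://neo4j:7687",
  "neo4j_database": "finance",
  "active": true,
  "version": "v2",
  "canary_weight": 0.1
}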

2. The Model-Serving Microservice (Python, FastAPI)

This service is intentionally simple. It loads a specified model and connects to the correct Neo4j database based on environment variables.

model_service/main.py

import os
import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering
from neo4j import AsyncGraphDatabase, exceptions

# --- Basic Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Environment Variables ---
# These will be set in the Dockerfile/docker-compose.yml for each service instance
MODEL_NAME = os.environ.get("MODEL_NAME", "distilbert-base-cased-distilled-squad")
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "password")
NEO4J_DATABASE = os.environ.get("NEO4J_DATABASE", "neo4j")

# --- Global State ---
# Store model and DB driver in the app's state to be managed by lifespan
app_state = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    # --- Startup Logic ---
    logger.info(f"Initializing model service for '{MODEL_NAME}' on Neo4j DB '{NEO4J_DATABASE}'")
    
    # 1. Load Hugging Face Model
    try:
        logger.info(f"Loading tokenizer for {MODEL_NAME}...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        logger.info(f"Loading model for {MODEL_NAME}...")
        model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
        app_state["qa_pipeline"] = pipeline("question-answering", model=model, tokenizer=tokenizer)
        logger.info("QA pipeline loaded successfully.")
    except Exception as e:
        logger.error(f"Failed to load model '{MODEL_NAME}': {e}", exc_info=True)
        # In a real system, you might want to exit or enter a degraded state
        raise RuntimeError(f"Model loading failed for {MODEL_NAME}") from e

    # 2. Initialize Neo4j Driver
    try:
        logger.info(f"Connecting to Neo4j at {NEO4J_URI} for database '{NEO4J_DATABASE}'")
        app_state["neo4j_driver"] = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
        # Verify connection
        await app_state["neo4j_driver"].verify_connectivity()
        logger.info("Neo4j driver initialized and connection verified.")
    except exceptions.ServiceUnavailable as e:
        logger.error(f"Could not connect to Neo4j: {e}", exc_info=True)
        raise RuntimeError("Neo4j connection failed") from e
    
    yield
    
    # --- Shutdown Logic ---
    logger.info("Closing Neo4j driver...")
    if "neo4j_driver" in app_state and app_state["neo4j_driver"]:
        await app_state["neo4j_driver"].close()
    logger.info("Shutdown complete.")


app = FastAPI(lifespan=lifespan)

# --- Pydantic Models for Request/Response ---
class QueryRequest(BaseModel):
    question: str
    user_id: str # Example of another field that might be passed

class QueryResponse(BaseModel):
    answer: str
    score: float
    context: str

# --- Core RAG Logic ---
async def retrieve_context_from_neo4j(driver, db_name, question: str) -> str:
    """
    Retrieves context from Neo4j.
    This is a simplified example. A real implementation would involve more
    sophisticated entity extraction from the question and a more complex Cypher query.
    """
    # A common mistake is to not handle the session correctly.
    # Using async with session ensures it's properly closed.
    query = """
    // This is a placeholder for a real entity-driven graph query.
    // For example, one could use a full-text index to find relevant nodes.
    MATCH (c:Chunk)
    WHERE c.text IS NOT NULL
    RETURN c.text AS text
    ORDER BY rand()
    LIMIT 5
    """
    try:
        async with driver.session(database=db_name) as session:
            result = await session.run(query)
            records = await result.data()
            # Concatenate chunks to form a context
            context = " ".join([record['text'] for record in records if record.get('text')])
            if not context:
                logger.warning(f"No context found in Neo4j for question: '{question}'")
                return "No relevant information found."
            return context
    except exceptions.Neo4jError as e:
        logger.error(f"Neo4j query failed in database '{db_name}': {e}")
        raise HTTPException(status_code=503, detail="Knowledge graph query failed.")


@app.post("/query", response_model=QueryResponse)
async def process_query(request: QueryRequest):
    if "qa_pipeline" not in app_state or "neo4j_driver" not in app_state:
        logger.error("Service is not ready. Model or DB driver not initialized.")
        raise HTTPException(status_code=503, detail="Service Unavailable: Not Initialized")
    
    logger.info(f"Received query for user '{request.user_id}': '{request.question}'")
    
    # 1. Retrieve context
    context = await retrieve_context_from_neo4j(
        app_state["neo4j_driver"],
        NEO4J_DATABASE,
        request.question
    )

    # 2. Generate answer
    qa_pipeline = app_state["qa_pipeline"]
    try:
        # NOTE: the pipeline call is synchronous and CPU-bound, so it blocks the
        # event loop while inference runs. In production, consider offloading it
        # with loop.run_in_executor or a worker pool.
        result = qa_pipeline(question=request.question, context=context)
        logger.info(f"Generated answer with score {result['score']:.4f}")
        return QueryResponse(
            answer=result["answer"],
            score=result["score"],
            context=context
        )
    except Exception as e:
        # The transformer pipeline can fail for various reasons (e.g., input size).
        logger.error(f"Error during QA pipeline execution: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to generate answer.")
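
Assuming the service's port 8000 is published, a direct request (with an illustrative payload) looks like:

curl -X POST http://localhost:8000/query \
  -H 'Content-Type: application/json' \
  -d '{"question": "What was the Q3 revenue?", "user_id": "analyst-42"}'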

3. The Smart API Gateway (Python, aiohttp, python-etcd3)

This is the most critical component. It must manage its connection to etcd, handle atomic updates to its routing table, and proxy requests efficiently. We use aiohttp for its robust and performant asynchronous client/server capabilities.

api_gateway/main.py

import asyncio
import json
import logging
import os
import threading
import time
from typing import Dict, Any, Optional

import etcd3
from aiohttp import web, ClientSession, ClientTimeout

# --- Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

ETCD_HOST = os.environ.get("ETCD_HOST", "etcd")
ETCD_PORT = int(os.environ.get("ETCD_PORT", 2379))
ETCD_KEY_PREFIX = "/services/rag/models/"
GATEWAY_PORT = int(os.environ.get("GATEWAY_PORT", 8080))
REQUEST_TIMEOUT_SECONDS = 30

# --- State Management ---
# This routing table is the heart of the gateway. It's an in-memory cache.
# The `_lock` is crucial to prevent race conditions during updates from the etcd watcher.
class RouteCache:
    def __init__(self):
        self._routes: Dict[str, Dict[str, Any]] = {}
        self._lock = threading.Lock() # Use threading.Lock as etcd client is sync

    def get_route(self, domain: str) -> Optional[Dict[str, Any]]:
        with self._lock:
            return self._routes.get(domain)

    def update_routes(self, new_routes: Dict[str, Dict[str, Any]]):
        with self._lock:
            logger.info(f"Updating route cache. New route count: {len(new_routes)}. Old count: {len(self._routes)}")
            self._routes = new_routes
            logger.info("Route cache updated successfully.")
            logger.debug(f"Current routes: {json.dumps(self._routes, indent=2)}")

# Global instance of our cache
route_cache = RouteCache()

# --- Etcd Watcher Logic ---
def initialize_and_watch_etcd():
    """
    Runs in a separate thread to watch for etcd changes.
    Using a synchronous client in a thread is a pragmatic way to bridge
    the gap with libraries that may not be fully async.
    """
    while True:
        try:
            logger.info(f"Connecting to etcd at {ETCD_HOST}:{ETCD_PORT}")
            client = etcd3.client(host=ETCD_HOST, port=ETCD_PORT)
            
            # Initial load
            logger.info(f"Performing initial load from prefix '{ETCD_KEY_PREFIX}'")
            initial_routes = {}
            for value, metadata in client.get_prefix(ETCD_KEY_PREFIX):
                domain = metadata.key.decode().split('/')[-1]
                try:
                    config = json.loads(value.decode())
                    if config.get("active", False):
                        initial_routes[domain] = config
                        logger.info(f"Loaded initial route for domain '{domain}'")
                except json.JSONDecodeError:
                    logger.error(f"Failed to parse JSON for key {metadata['key'].decode()}")
            route_cache.update_routes(initial_routes)

            # Start watching for changes
            logger.info("Starting to watch for etcd changes...")
            events_iterator, cancel = client.watch_prefix(ETCD_KEY_PREFIX)
            for event in events_iterator:
                # On any change, we re-read the entire prefix. This is simpler and safer
                # than trying to process individual PUT/DELETE events, as it prevents
                # inconsistencies if an event is missed. This is a key design choice
                # for robustness over micro-optimization.
                logger.info(f"Detected change in etcd prefix '{ETCD_KEY_PREFIX}'. Re-syncing all routes.")
                current_routes = {}
                for value, metadata in client.get_prefix(ETCD_KEY_PREFIX):
                    domain = metadata.key.decode().split('/')[-1]
                    try:
                        config = json.loads(value.decode())
                        if config.get("active", False):
                            current_routes[domain] = config
                    except json.JSONDecodeError:
                        logger.error(f"Failed to parse JSON for key {metadata['key'].decode()} during re-sync")
                route_cache.update_routes(current_routes)

        except Exception as e:
            logger.error(f"Error in etcd watcher thread: {e}. Retrying in 5 seconds...", exc_info=True)
            time.sleep(5)


# --- API Handler ---
async def handle_query(request: web.Request) -> web.Response:
    domain = request.headers.get("X-Model-Domain")
    if not domain:
        return web.Response(status=400, text="Missing X-Model-Domain header.")

    route_info = route_cache.get_route(domain)
    if not route_info:
        logger.warning(f"No active route found for domain '{domain}'")
        return web.Response(status=404, text=f"No model configured for domain '{domain}'.")

    target_url = route_info.get("service_url")
    if not target_url:
        logger.error(f"Invalid route configuration for domain '{domain}': missing 'service_url'")
        return web.Response(status=500, text="Internal configuration error.")

    try:
        request_data = await request.json()
    except json.JSONDecodeError:
        return web.Response(status=400, text="Invalid JSON body.")

    # A common pitfall is creating a new ClientSession for every request.
    # For a gateway, one session created at startup and reused throughout is best.
    http_client: ClientSession = request.app['http_client']
    
    logger.info(f"Proxying request for domain '{domain}' to {target_url}")
    try:
        async with http_client.post(
            target_url,
            json=request_data,
            timeout=ClientTimeout(total=REQUEST_TIMEOUT_SECONDS)
        ) as response:
            body = await response.read()
            # Pass through the status and content type from the downstream service.
            return web.Response(
                body=body,
                status=response.status,
                content_type=response.content_type
            )
    except asyncio.TimeoutError:
        logger.error(f"Request to downstream service {target_url} timed out.")
        return web.Response(status=504, text="Gateway Timeout.")
    except Exception as e:
        logger.error(f"Error proxying request to {target_url}: {e}", exc_info=True)
        return web.Response(status=502, text="Bad Gateway.")


async def on_startup(app: web.Application):
    # Create a single, reusable ClientSession for the lifetime of the application.
    app['http_client'] = ClientSession()
    # Start the etcd watcher in a background thread.
    threading.Thread(target=initialize_and_watch_etcd, daemon=True).start()

async def on_shutdown(app: web.Application):
    await app['http_client'].close()


# --- Application Setup ---
app = web.Application()
app.router.add_post('/api/v1/query', handle_query)
app.on_startup.append(on_startup)
app.on_shutdown.append(on_shutdown)

if __name__ == '__main__':
    web.run_app(app, port=GATEWAY_PORT)

Note on the etcd client library: the code above uses python-etcd3 (imported as etcd3); the original etcd3-py library is not actively maintained. A fully asyncio-compatible etcd client would be ideal, but running the synchronous client in a dedicated background thread, as done here, is a common and robust pattern.
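
End to end, a client request through the gateway (which publishes port 8080 in the Compose file below) looks like this; the payload is illustrative:

curl -X POST http://localhost:8080/api/v1/query \
  -H 'Content-Type: application/json' \
  -H 'X-Model-Domain: finance' \
  -d '{"question": "What was the Q3 revenue?", "user_id": "analyst-42"}'

The gateway resolves finance against its in-memory route cache and proxies the body unchanged to the finance model service.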

4. Orchestration with docker-compose.yml

This file ties everything together, making the entire system runnable with a single command.

version: '3.8'

services:
  etcd:
    image: bitnami/etcd:3.5
    environment:
      - ALLOW_NONE_AUTHENTICATION=yes
      - ETCD_ADVERTISE_CLIENT_URLS=http://etcd:2379
    ports:
      - "2379:2379"

  neo4j:
    image: neo4j:5.12-enterprise
    environment:
      - NEO4J_AUTH=neo4j/password
      # Multiple user databases (finance, legal) require Enterprise Edition
      - NEO4J_ACCEPT_LICENSE_AGREEMENT=yes
      - NEO4J_PLUGINS=["apoc"]
      - NEO4J_initial_dbms_default__database=finance
    ports:
      - "7474:7474"
      - "7687:7687"
    volumes:
      - neo4j_data:/data

  # Helper service to populate initial data
  config-loader:
    build:
      context: ./config-loader
    depends_on:
      - etcd
      - neo4j
    environment:
      - ETCD_HOST=etcd
      - NEO4J_URI=bolt://neo4j:7687
      - NEO4J_USER=neo4j
      - NEO4J_PASSWORD=password

  api-gateway:
    build:
      context: ./api_gateway
    ports:
      - "8080:8080"
    environment:
      - ETCD_HOST=etcd
    depends_on:
      - etcd

  model-finance:
    build:
      context: ./model_service
    environment:
      - MODEL_NAME=distilbert-base-cased-distilled-squad
      - NEO4J_URI=bolt://neo4j:7687
      - NEO4J_USER=neo4j
      - NEO4J_PASSWORD=password
      - NEO4J_DATABASE=finance
    depends_on:
      - neo4j
    # For production, you'd need GPUs and proper resource limits.
    # deploy:
    #   resources:
    #     reservations:
    #       devices:
    #         - driver: nvidia
    #           count: 1
    #           capabilities: [gpu]

  model-legal:
    build:
      context: ./model_service
    environment:
      - MODEL_NAME=nlpaueb/legal-bert-base-uncased
      - NEO4J_URI=bolt://neo4j:7687
      - NEO4J_USER=neo4j
      - NEO4J_PASSWORD=password
      - NEO4J_DATABASE=legal
    depends_on:
      - neo4j

volumes:
  neo4j_data:

The config-loader is a one-shot container that runs a script to create the databases in Neo4j and populate the initial routing configuration in etcd. This is a critical piece for creating a reproducible test environment.
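
A minimal sketch of such a loader, assuming the same python-etcd3 and neo4j libraries used elsewhere in this article (the file name and the MODEL_CONFIGS structure are illustrative):

config-loader/load.py

import json
import os

import etcd3
from neo4j import GraphDatabase

ETCD_HOST = os.environ.get("ETCD_HOST", "etcd")
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://neo4j:7687")
NEO4J_AUTH = (os.environ.get("NEO4J_USER", "neo4j"),
              os.environ.get("NEO4J_PASSWORD", "password"))

# Illustrative routing payloads, matching the JSON examples from section 1.
MODEL_CONFIGS = {
    "finance": {
        "service_url": "http://model-finance:8000/query",
        "model_name": "distilbert-base-cased-distilled-squad",
        "neo4j_uri": "bolt://neo4j:7687",
        "neo4j_database": "finance",
        "active": True,
    },
    "legal": {
        "service_url": "http://model-legal:8000/query",
        "model_name": "nlpaueb/legal-bert-base-uncased",
        "neo4j_uri": "bolt://neo4j:7687",
        "neo4j_database": "legal",
        "active": True,
    },
}

def main():
    # Create the per-domain databases (requires Neo4j Enterprise Edition).
    driver = GraphDatabase.driver(NEO4J_URI, auth=NEO4J_AUTH)
    with driver.session(database="system") as session:
        for domain in MODEL_CONFIGS:
            session.run(f"CREATE DATABASE {domain} IF NOT EXISTS")
    driver.close()

    # Publish the routing configuration to etcd.
    client = etcd3.client(host=ETCD_HOST, port=2379)
    for domain, config in MODEL_CONFIGS.items():
        client.put(f"/services/rag/models/{domain}", json.dumps(config))

if __name__ == "__main__":
    main()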

Limitations and Future Iterations

This architecture, while robust, has clear boundaries. The API gateway is a potential single point of failure and a performance bottleneck. While it can be scaled horizontally, all traffic must pass through it. For extremely high-throughput scenarios, a more decentralized approach like a service mesh might become necessary.

The current implementation assumes model services are always running. A more advanced system could integrate with an orchestrator like Kubernetes to dynamically scale model services up from zero based on demand, which introduces challenges around model load times and cold starts. The gateway would need to be aware of the scaling state of downstream services, a feature for which etcd could also serve as the control plane.

Finally, the re-sync-on-change strategy in the etcd watcher is simple and safe but could be inefficient if the number of routes grows into the thousands. At that scale, a more granular event processing logic might be required, though this would significantly increase the complexity of maintaining cache consistency.
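
A sketch of what that granular handling could look like with python-etcd3's event types, omitting the revision tracking and resync-on-watch-loss logic a production version would need (apply_event is a hypothetical helper, not part of the gateway above):

import json

from etcd3.events import DeleteEvent, PutEvent

def apply_event(event, routes: dict) -> None:
    """Apply a single etcd watch event to the in-memory routing table."""
    domain = event.key.decode().split("/")[-1]
    if isinstance(event, PutEvent):
        config = json.loads(event.value.decode())
        if config.get("active", False):
            routes[domain] = config
        else:
            routes.pop(domain, None)  # deactivated routes are dropped
    elif isinstance(event, DeleteEvent):
        routes.pop(domain, None)

Even with granular updates, keeping a periodic full re-sync as a fallback guards against missed events.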

