The core problem wasn’t a complex algorithm, but a simple, brutal throughput mismatch. We had a sentence-transformer model from Hugging Face, perfectly capable of churning out embeddings at high throughput, but only when fed data in large batches. Deploying it behind a standard REST API where requests arrive one by one resulted in abysmal GPU utilization, often below 10%, and unacceptable per-request latency. Each individual request, no matter how small, paid the full price of model invocation overhead. A naive API is a performance bottleneck waiting to happen.
Our initial Python service, built with FastAPI, looked something like this. It worked, but it was fundamentally inefficient for the hardware it was running on.
# WARNING: This is the inefficient, anti-pattern implementation.
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import uvicorn
import os

app = FastAPI()

# Load the model once at startup. A common and correct practice.
# The problem isn't model loading, it's the invocation pattern.
model_name = os.getenv("MODEL_NAME", "all-MiniLM-L6-v2")
model = SentenceTransformer(model_name)

class InferenceRequest(BaseModel):
    text: str

class InferenceResponse(BaseModel):
    embedding: list[float]

@app.post("/embed", response_model=InferenceResponse)
def create_embedding(request: InferenceRequest):
    # The core issue: The model's encode() method is called for EACH request.
    # This is horribly inefficient for GPU-based computation.
    embedding = model.encode([request.text])[0]
    return InferenceResponse(embedding=embedding.tolist())

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
Each call to /embed triggered a separate model.encode() operation. On a V100 GPU, this meant we were using a sledgehammer to crack a nut, over and over. The latency for a single request was around 50ms. Serving 100 requests sequentially would take 5 seconds. The real kicker is that processing a batch of 100 sentences at once might only take 150ms. The path forward was clear: we had to implement request batching.
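To make the mismatch concrete, a rough micro-benchmark along these lines reproduces the gap. This is a sketch: the exact numbers depend on your GPU, model, and driver stack, so treat them as illustrative rather than a reproduction of our measurements.

# Quick comparison: per-request encode() calls vs. one batched call.
# Assumes sentence-transformers is installed and the model used above is available.
import time
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
sentences = ["An example sentence that needs an embedding."] * 100

start = time.perf_counter()
for sentence in sentences:
    model.encode([sentence])  # one invocation per request, like the naive API
sequential_seconds = time.perf_counter() - start

start = time.perf_counter()
model.encode(sentences, batch_size=len(sentences))  # one invocation for the whole batch
batched_seconds = time.perf_counter() - start

print(f"sequential: {sequential_seconds:.2f}s, batched: {batched_seconds:.2f}s")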
The first thought was to do it in Python. An asyncio-based queue could collect requests. But this couples the batching logic with the inference code, creating a monolith. In a real-world project, the service handling concurrent network I/O should be decoupled from the service doing heavy, blocking computation. The former needs to be lightweight and highly responsive; the latter needs to be optimized for raw computational throughput.
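For comparison, the rejected in-process design would have looked roughly like the sketch below: a background task drains an asyncio queue, accumulates a batch, and resolves per-request futures. Names and parameters are illustrative, and it assumes the model object loaded at startup as in the service above. It also shows exactly the coupling problem: the batching policy and the blocking encode() call end up living in the same process as the request handling.

# Sketch of the rejected single-service design: batching inside the FastAPI process.
# The batching_loop would be started as a background task at application startup.
import asyncio

job_queue: asyncio.Queue = asyncio.Queue()
MAX_BATCH_SIZE = 64
MAX_WAIT_SECONDS = 0.1

async def batching_loop():
    loop = asyncio.get_running_loop()
    while True:
        # Block until the first job arrives, then open a short batching window.
        batch = [await job_queue.get()]
        deadline = loop.time() + MAX_WAIT_SECONDS
        while len(batch) < MAX_BATCH_SIZE:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            try:
                batch.append(await asyncio.wait_for(job_queue.get(), timeout=remaining))
            except asyncio.TimeoutError:
                break
        texts = [text for text, _ in batch]
        # The heavy, blocking encode() call must be pushed to a thread so it
        # doesn't stall the event loop; compute and I/O are now entangled.
        embeddings = await asyncio.to_thread(model.encode, texts)
        for (_, future), embedding in zip(batch, embeddings):
            future.set_result(embedding)

async def embed_one(text: str):
    # Called from a request handler: enqueue the text with a future to await.
    future = asyncio.get_running_loop().create_future()
    await job_queue.put((text, future))
    return await future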
This led to a two-service architecture:
- Inference Worker (Python/FastAPI): A lean service whose only job is to receive a batch of sentences and return a batch of embeddings. It assumes its caller is smart enough to batch.
- API Gateway (Go/Gin): A highly concurrent service facing the public. It accepts individual requests, holds them in a buffer for a very short period, assembles a batch, and fires it off to the Inference Worker.
For orchestration, Kubernetes felt like overkill. We had two stable, well-defined services. The operational overhead of setting up and maintaining a K8s cluster for this wasn’t justified. Docker Swarm provided exactly what we needed: simple service discovery, declarative deployments via a compose file, and placement constraints to ensure our Python worker landed on a GPU-enabled node. It’s a pragmatic choice that prioritizes simplicity and stability.
The Batch-Aware Python Inference Worker
First, we refactor the Python service. Its API contract must change. Instead of accepting one sentence, it must accept a list of them.
# file: worker/main.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
import uvicorn
import os
import logging
from typing import List

# Setup structured logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

app = FastAPI()

# A simple health check endpoint
@app.get("/health")
def health_check():
    return {"status": "ok"}

# The model loading is critical. It must happen only once.
# A common mistake is to load the model inside the request handler.
try:
    model_name = os.getenv("MODEL_NAME", "all-MiniLM-L6-v2")
    logging.info(f"Loading model: {model_name}")
    model = SentenceTransformer(model_name)
    logging.info("Model loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load model: {e}")
    # If the model fails to load, the service is useless.
    # In a real system, this should trigger alerts.
    model = None

class BatchInferenceRequest(BaseModel):
    texts: List[str] = Field(..., min_items=1, max_items=128)  # Set a reasonable max batch size

class BatchInferenceResponse(BaseModel):
    embeddings: List[List[float]]

@app.post("/batch_embed", response_model=BatchInferenceResponse)
def create_batch_embedding(request: BatchInferenceRequest):
    if model is None:
        raise HTTPException(status_code=503, detail="Model not available")
    try:
        # The core logic is now designed for batches.
        # This is where the GPU throughput is maximized.
        logging.info(f"Processing batch of size: {len(request.texts)}")
        embeddings = model.encode(request.texts, batch_size=len(request.texts))
        # Convert numpy arrays to lists for JSON serialization
        response_embeddings = [emb.tolist() for emb in embeddings]
        return BatchInferenceResponse(embeddings=response_embeddings)
    except Exception as e:
        logging.error(f"Inference failed for batch: {e}")
        # Propagate a server error if the model fails during inference.
        raise HTTPException(status_code=500, detail="Internal server error during inference")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
The corresponding Dockerfile is straightforward. A common pitfall here is not managing the model cache directory. By defining SENTENCE_TRANSFORMERS_HOME, we ensure the model is downloaded to a predictable location inside the container image, avoiding re-downloads on every container start.
# file: worker/Dockerfile
FROM python:3.10-slim
WORKDIR /app
# It's better to copy only the requirements file first to leverage Docker's layer caching.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Set the cache directory for models to be included in the image layer.
ENV SENTENCE_TRANSFORMERS_HOME=/app/models
# This line will download the model during the build process.
# This makes the image larger but the container startup much faster.
RUN python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')"
COPY main.py .
EXPOSE 8000
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
The requirements.txt would contain fastapi, uvicorn, sentence-transformers, and pydantic. This service is now dumb, in a good way. It does one thing well: batch inference. The intelligence will live in the Go service.
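Before wiring up the gateway, it is worth sanity-checking the new contract with a small client script. This is a sketch that assumes the worker is running locally on port 8000 and that the requests library is installed.

# Smoke test for the worker's batch endpoint.
import requests

payload = {"texts": ["first sentence", "second sentence", "third sentence"]}
resp = requests.post("http://localhost:8000/batch_embed", json=payload, timeout=30)
resp.raise_for_status()

embeddings = resp.json()["embeddings"]
assert len(embeddings) == len(payload["texts"])
print(f"{len(embeddings)} embeddings, dimension {len(embeddings[0])}")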
The Go Dynamic Batching Layer
This is the heart of the solution. We need a component that can:
- Accept individual requests concurrently from many clients.
- Queue these requests internally.
- Trigger a batch dispatch when either a maximum batch size is reached or a maximum wait time has elapsed.
- Send the assembled batch to the Python worker.
- Receive the batch response and correctly route the individual results back to the waiting clients.
Go’s concurrency primitives—goroutines and channels—are tailor-made for this.
Here is the core data structure for a single job entering our system. Crucially, it includes a ResponseChan, exported so the HTTP handler in the main package can set it. Each incoming HTTP request will create a job and wait on its own unique channel for the result. This is how we map asynchronous batch results back to synchronous HTTP requests.
// file: gateway/batcher/processor.go
package batcher

import (
    "bytes"
    "context"
    "encoding/json"
    "errors"
    "log"
    "net/http"
    "sync"
    "time"
)

// SingleJob represents one client request waiting to be batched.
type SingleJob struct {
    Text string
    // ResponseChan is the channel used to send the result back to the
    // waiting handler. It is exported so the handler package can set it.
    ResponseChan chan<- JobResult
}

// JobResult holds the outcome for a single job.
type JobResult struct {
    Embedding []float32
    Err       error
}

// BatchProcessor is the core component that collects and dispatches jobs.
type BatchProcessor struct {
    jobQueue       chan SingleJob
    workerEndpoint string
    maxBatchSize   int
    maxWaitTime    time.Duration
    httpClient     *http.Client
    shutdownSignal chan struct{}
    wg             sync.WaitGroup
}

// NewBatchProcessor initializes and starts the batch processor.
func NewBatchProcessor(endpoint string, maxBatchSize int, maxWaitTime time.Duration) *BatchProcessor {
    processor := &BatchProcessor{
        jobQueue:       make(chan SingleJob, maxBatchSize*2), // Buffered channel
        workerEndpoint: endpoint,
        maxBatchSize:   maxBatchSize,
        maxWaitTime:    maxWaitTime,
        httpClient: &http.Client{
            Timeout: 10 * time.Second, // A reasonable timeout for the downstream call.
        },
        shutdownSignal: make(chan struct{}),
    }
    processor.wg.Add(1)
    go processor.run() // Start the processing loop in a background goroutine.
    return processor
}
The run method is the engine. It uses a time.Ticker for the timeout-based dispatch and a slice to collect jobs. The select statement is key: it waits for either a new job to arrive or the ticker to fire. This dual condition is what makes the batching “dynamic.”
// file: gateway/batcher/processor.go (continued)

func (p *BatchProcessor) run() {
    defer p.wg.Done()
    // The batch currently being assembled.
    batch := make([]SingleJob, 0, p.maxBatchSize)
    // A timer that dictates the maximum time to wait before sending a batch.
    ticker := time.NewTicker(p.maxWaitTime)
    defer ticker.Stop()

    for {
        select {
        case <-p.shutdownSignal:
            // If a shutdown is signaled, process any remaining jobs before exiting.
            if len(batch) > 0 {
                p.dispatchBatch(batch)
            }
            log.Println("Batch processor shutting down.")
            return
        case job, ok := <-p.jobQueue:
            if !ok {
                // Channel closed, which is another way to signal shutdown.
                if len(batch) > 0 {
                    p.dispatchBatch(batch)
                }
                log.Println("Job queue closed, shutting down processor.")
                return
            }
            batch = append(batch, job)
            // Dispatch immediately if the batch is full.
            if len(batch) >= p.maxBatchSize {
                p.dispatchBatch(batch)
                batch = make([]SingleJob, 0, p.maxBatchSize) // Reset the batch
                ticker.Reset(p.maxWaitTime)                  // Reset the timer
            }
        case <-ticker.C:
            // Time is up, dispatch whatever we have.
            if len(batch) > 0 {
                p.dispatchBatch(batch)
                batch = make([]SingleJob, 0, p.maxBatchSize) // Reset the batch
            }
        }
    }
}

// SubmitJob is the public method for handlers to add a job to the queue.
func (p *BatchProcessor) SubmitJob(job SingleJob) {
    p.jobQueue <- job
}

// Shutdown gracefully stops the processor.
func (p *BatchProcessor) Shutdown() {
    close(p.shutdownSignal)
    p.wg.Wait()
}
The dispatchBatch method is where the communication with the Python worker happens. It constructs the JSON payload, makes the HTTP call, and, most importantly, distributes the results or errors back to the waiting goroutines via their individual ResponseChan.
// file: gateway/batcher/processor.go (continued)

type workerRequest struct {
    Texts []string `json:"texts"`
}

type workerResponse struct {
    Embeddings [][]float32 `json:"embeddings"`
}

func (p *BatchProcessor) dispatchBatch(batch []SingleJob) {
    log.Printf("Dispatching batch of size %d\n", len(batch))
    texts := make([]string, len(batch))
    for i, job := range batch {
        texts[i] = job.Text
    }

    reqPayload := workerRequest{Texts: texts}
    payloadBytes, err := json.Marshal(reqPayload)
    if err != nil {
        log.Printf("ERROR: Failed to marshal request payload: %v\n", err)
        p.distributeError(batch, errors.New("internal server error"))
        return
    }

    // A context for the downstream request to handle timeouts properly.
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()

    req, err := http.NewRequestWithContext(ctx, "POST", p.workerEndpoint, bytes.NewBuffer(payloadBytes))
    if err != nil {
        log.Printf("ERROR: Failed to create new request: %v\n", err)
        p.distributeError(batch, errors.New("internal server error"))
        return
    }
    req.Header.Set("Content-Type", "application/json")

    resp, err := p.httpClient.Do(req)
    if err != nil {
        log.Printf("ERROR: Failed to call worker endpoint: %v\n", err)
        p.distributeError(batch, errors.New("inference worker unavailable"))
        return
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        log.Printf("ERROR: Worker returned non-200 status: %d\n", resp.StatusCode)
        p.distributeError(batch, errors.New("inference worker failed"))
        return
    }

    var workerResp workerResponse
    if err := json.NewDecoder(resp.Body).Decode(&workerResp); err != nil {
        log.Printf("ERROR: Failed to decode worker response: %v\n", err)
        p.distributeError(batch, errors.New("invalid response from worker"))
        return
    }

    if len(workerResp.Embeddings) != len(batch) {
        log.Printf("ERROR: Mismatch in batch size. Sent %d, got %d\n", len(batch), len(workerResp.Embeddings))
        p.distributeError(batch, errors.New("worker returned mismatched batch size"))
        return
    }

    // Success case: distribute results back to each waiting goroutine.
    for i, job := range batch {
        job.ResponseChan <- JobResult{Embedding: workerResp.Embeddings[i], Err: nil}
    }
}

// distributeError sends an error back to all jobs in a failed batch.
func (p *BatchProcessor) distributeError(batch []SingleJob, err error) {
    for _, job := range batch {
        job.ResponseChan <- JobResult{Err: err}
    }
}
Finally, the Gin HTTP handler ties it all together. It’s lean: it creates the job, submits it, and then blocks waiting for the result. A timeout on that wait is critical to prevent a client from waiting indefinitely if the batching system gets stuck.
// file: gateway/main.go
package main

import (
    "log"
    "net/http"
    "os"
    "time"

    "gateway/batcher" // Assuming the batcher package is in this path

    "github.com/gin-gonic/gin"
)

type EmbedRequest struct {
    Text string `json:"text" binding:"required"`
}

func main() {
    workerEndpoint := os.Getenv("WORKER_ENDPOINT")
    if workerEndpoint == "" {
        log.Fatal("WORKER_ENDPOINT environment variable not set")
    }

    // Configuration for the batcher. These should be tunable.
    maxBatchSize := 64
    maxWaitTime := 100 * time.Millisecond
    processor := batcher.NewBatchProcessor(workerEndpoint, maxBatchSize, maxWaitTime)
    defer processor.Shutdown()

    r := gin.Default()

    r.POST("/embed", func(c *gin.Context) {
        var req EmbedRequest
        if err := c.ShouldBindJSON(&req); err != nil {
            c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
            return
        }

        // Each request gets its own response channel.
        responseChan := make(chan batcher.JobResult, 1)
        job := batcher.SingleJob{
            Text:         req.Text,
            ResponseChan: responseChan,
        }
        processor.SubmitJob(job)

        // Wait for the result or a timeout. This is critical for client-side responsiveness.
        select {
        case result := <-responseChan:
            if result.Err != nil {
                // Map internal errors to appropriate HTTP status codes.
                c.JSON(http.StatusServiceUnavailable, gin.H{"error": result.Err.Error()})
                return
            }
            c.JSON(http.StatusOK, gin.H{"embedding": result.Embedding})
        case <-time.After(5 * time.Second): // Client-facing timeout
            c.JSON(http.StatusGatewayTimeout, gin.H{"error": "request timed out"})
        }
    })

    r.GET("/health", func(c *gin.Context) {
        c.JSON(http.StatusOK, gin.H{"status": "ok"})
    })

    log.Println("Starting Go API Gateway on port 8080")
    if err := r.Run(":8080"); err != nil {
        log.Fatalf("Failed to start Gin server: %v", err)
    }
}
The Dockerfile for the Go service is also standard, using a multi-stage build to keep the final image small.
# file: gateway/Dockerfile
# Build Stage
FROM golang:1.21-alpine AS builder
WORKDIR /app
COPY go.mod ./
COPY go.sum ./
RUN go mod download
COPY . .
RUN CGO_ENABLED=0 GOOS=linux go build -o /gateway-app ./main.go

# Final Stage
FROM alpine:latest
WORKDIR /root/
# Copy the binary from the build stage.
COPY --from=builder /gateway-app .
EXPOSE 8080
CMD ["./gateway-app"]
Deployment on Docker Swarm
Now we tie it all together with a docker-compose.yml file for Swarm mode. The key features here are the overlay network allowing seamless communication between gateway and worker, and the placement constraint.
Before deploying, we need to label the node that has the GPU. Node labels are managed from a manager node, so run the following there, substituting the GPU node's hostname (the $(hostname) shorthand works when the manager itself is the GPU node):
docker node update --label-add gpu=true $(hostname)
This label is what the deploy.placement.constraints entry matches against.
# file: docker-compose.yml
version: "3.8"

services:
  gateway:
    image: my-registry/gateway:1.0 # Replace with your image name
    build:
      context: ./gateway
    ports:
      - "8080:8080"
    environment:
      # The service name 'worker' is automatically resolved by Swarm's DNS.
      - WORKER_ENDPOINT=http://worker:8000/batch_embed
    networks:
      - ml-net
    deploy:
      replicas: 2 # We can scale the stateless gateway easily.
      update_config:
        parallelism: 1
        delay: 10s
      restart_policy:
        condition: on-failure

  worker:
    image: my-registry/worker:1.0 # Replace with your image name
    build:
      context: ./worker
    environment:
      - MODEL_NAME=all-MiniLM-L6-v2
    networks:
      - ml-net
    deploy:
      replicas: 1 # Typically one worker per GPU.
      placement:
        constraints:
          # This is the critical line that ensures the worker lands on a GPU node.
          - node.labels.gpu == true
      restart_policy:
        condition: on-failure
      # This section is required for NVIDIA GPUs with the nvidia-container-toolkit
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

networks:
  ml-net:
    driver: overlay
To deploy this stack, the commands are simple:
- docker swarm init (on the manager node)
- docker stack deploy -c docker-compose.yml ml-stack
The result is a resilient system. Swarm’s ingress routing mesh spreads incoming traffic across the Go gateway replicas, which batch requests and send them to the single, stateful (because of the loaded model) Python worker, guaranteed to be running on the correct hardware. GPU utilization under load now consistently stays above 80%, and the amortized compute cost per request has dropped by an order of magnitude, from around 50ms to under 5ms; client-observed latency stays well within budget even accounting for the 100ms maximum batching delay.
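To verify those numbers on your own setup, a simple concurrent load generator is enough. The sketch below fires individual requests at the gateway and reports the latency distribution the clients actually observe; the endpoint, request count, and concurrency level are arbitrary choices you should adjust to your deployment.

# Concurrent load test against the gateway's single-request endpoint.
import statistics
import time
from concurrent.futures import ThreadPoolExecutor

import requests

GATEWAY_URL = "http://localhost:8080/embed"  # adjust for your deployment

def one_request(i: int) -> float:
    start = time.perf_counter()
    resp = requests.post(GATEWAY_URL, json={"text": f"sentence number {i}"}, timeout=10)
    resp.raise_for_status()
    return time.perf_counter() - start

with ThreadPoolExecutor(max_workers=100) as pool:
    latencies = list(pool.map(one_request, range(500)))

p50 = statistics.median(latencies) * 1000
p95 = statistics.quantiles(latencies, n=20)[18] * 1000  # 95th percentile cut point
print(f"p50={p50:.1f}ms p95={p95:.1f}ms")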
graph TD
    subgraph Client Traffic
        C1(Client 1) --> LB
        C2(Client 2) --> LB
        C3(Client 3) --> LB
        C4(Client 4) --> LB
    end
    subgraph Docker Swarm
        LB(Swarm Ingress Routing Mesh) --> G1
        LB --> G2
        subgraph "Node 1 (No GPU)"
            G1(Go Gateway Replica 1)
        end
        subgraph "Node 2 (No GPU)"
            G2(Go Gateway Replica 2)
        end
        subgraph "Node 3 (GPU-enabled, label: gpu=true)"
            W1(Python Worker)
        end
        G1 -- "Batch Request (HTTP)" --> W1
        G2 -- "Batch Request (HTTP)" --> W1
    end
    style W1 fill:#f9f,stroke:#333,stroke-width:2px
The current implementation of the Go batcher is simple and effective, but it’s a single point of coordination within each instance. If a gateway instance restarts, the in-memory batch is lost. For our use case, where requests are idempotent and clients can retry, this was an acceptable trade-off. A more robust solution might use an external queue like Redis or NATS to hold the jobs, allowing multiple gateway instances to pull from a shared queue. This, however, adds another piece of infrastructure to manage. The current design is a pragmatic optimum of performance gain versus operational complexity. Furthermore, scaling the Python workers isn’t trivial; adding a second worker replica would require a more sophisticated load balancing or sharding strategy at the gateway level to ensure batches are routed effectively, which Docker Swarm’s round-robin DNS doesn’t solve out of the box for this stateful workload. That remains a challenge for a future iteration.