Implementing a Dynamic Request Batching Layer in Go for a Docker Swarm-Hosted Transformer Model


The core problem wasn’t a complex algorithm, but a simple, brutal mismatch of throughput. We had a sentence-transformer model from Hugging Face, perfectly capable of churning through embeddings, but only when fed data in large batches. Deploying it behind a standard REST API where requests arrive one-by-one resulted in abysmal GPU utilization, often below 10%, and a per-request latency that was unacceptable. Each individual request, no matter how small, paid the full price of model invocation overhead. A naive API is a performance bottleneck waiting to happen.

Our initial Python service, built with FastAPI, looked something like this. It worked, but it was fundamentally inefficient for the hardware it was running on.

# WARNING: This is the inefficient, anti-pattern implementation.
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import uvicorn
import os

app = FastAPI()

# Load the model once at startup. A common and correct practice.
# The problem isn't model loading, it's the invocation pattern.
model_name = os.getenv("MODEL_NAME", "all-MiniLM-L6-v2")
model = SentenceTransformer(model_name)

class InferenceRequest(BaseModel):
    text: str

class InferenceResponse(BaseModel):
    embedding: list[float]

@app.post("/embed", response_model=InferenceResponse)
def create_embedding(request: InferenceRequest):
    # The core issue: The model's encode() method is called for EACH request.
    # This is horribly inefficient for GPU-based computation.
    embedding = model.encode([request.text])[0]
    return InferenceResponse(embedding=embedding.tolist())

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

Each call to /embed triggered a separate model.encode() operation. On a V100 GPU, this meant we were using a sledgehammer to crack a nut, over and over. The latency for a single request was around 50ms. Serving 100 requests sequentially would take 5 seconds. The real kicker is that processing a batch of 100 sentences at once might only take 150ms. The path forward was clear: we had to implement request batching.

The first thought was to do it in Python. An asyncio-based queue could collect requests. But this couples the batching logic with the inference code, creating a monolith. In a real-world project, the service handling concurrent network I/O should be decoupled from the service doing heavy, blocking computation. The former needs to be lightweight and highly responsive; the latter needs to be optimized for raw computational throughput.

This led to a two-service architecture:

  1. Inference Worker (Python/FastAPI): A lean service whose only job is to receive a batch of sentences and return a batch of embeddings. It assumes its caller is smart enough to batch.
  2. API Gateway (Go/Gin): A highly concurrent service facing the public. It accepts individual requests, holds them in a buffer for a very short period, assembles a batch, and fires it off to the Inference Worker.

For orchestration, Kubernetes felt like overkill. We had two stable, well-defined services. The operational overhead of setting up and maintaining a K8s cluster for this wasn’t justified. Docker Swarm provided exactly what we needed: simple service discovery, declarative deployments via a compose file, and placement constraints to ensure our Python worker landed on a GPU-enabled node. It’s a pragmatic choice that prioritizes simplicity and stability.

The Batch-Aware Python Inference Worker

First, we refactor the Python service. Its API contract must change. Instead of accepting one sentence, it must accept a list of them.

# file: worker/main.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
import uvicorn
import os
import logging
from typing import List

# Set up basic (plain-text) logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

app = FastAPI()

# A simple health check endpoint
@app.get("/health")
def health_check():
    return {"status": "ok"}

# The model loading is critical. It must happen only once.
# A common mistake is to load the model inside the request handler.
try:
    model_name = os.getenv("MODEL_NAME", "all-MiniLM-L6-v2")
    logging.info(f"Loading model: {model_name}")
    model = SentenceTransformer(model_name)
    logging.info("Model loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load model: {e}")
    # If the model fails to load, the service is useless.
    # In a real system, this should trigger alerts.
    model = None

class BatchInferenceRequest(BaseModel):
    texts: List[str] = Field(..., min_items=1, max_items=128) # Set a reasonable max batch size

class BatchInferenceResponse(BaseModel):
    embeddings: List[List[float]]

@app.post("/batch_embed", response_model=BatchInferenceResponse)
def create_batch_embedding(request: BatchInferenceRequest):
    if model is None:
        raise HTTPException(status_code=503, detail="Model not available")
    
    try:
        # The core logic is now designed for batches.
        # This is where the GPU throughput is maximized.
        logging.info(f"Processing batch of size: {len(request.texts)}")
        embeddings = model.encode(request.texts, batch_size=len(request.texts))
        
        # Convert numpy arrays to lists for JSON serialization
        response_embeddings = [emb.tolist() for emb in embeddings]
        
        return BatchInferenceResponse(embeddings=response_embeddings)
    except Exception as e:
        logging.error(f"Inference failed for batch: {e}")
        # Propagate a server error if the model fails during inference.
        raise HTTPException(status_code=500, detail="Internal server error during inference")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

The corresponding Dockerfile is straightforward. A common pitfall here is not managing the model cache directory. By defining SENTENCE_TRANSFORMERS_HOME, we ensure the model is downloaded to a predictable location inside the container image, avoiding re-downloads on every container start.

# file: worker/Dockerfile
FROM python:3.10-slim

WORKDIR /app

# It's better to copy only the requirements file first to leverage Docker's layer caching.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Set the cache directory for models to be included in the image layer.
ENV SENTENCE_TRANSFORMERS_HOME=/app/models

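# Download the model during the build so the container does not fetch it at
# startup. This makes the image larger but container startup much faster.
# The build arg keeps the baked-in model in sync with the MODEL_NAME the
# service reads at runtime (override with --build-arg MODEL_NAME=...).
ARG MODEL_NAME=all-MiniLM-L6-v2
ENV MODEL_NAME=${MODEL_NAME}
RUN python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('${MODEL_NAME}')"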

COPY main.py .

EXPOSE 8000

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

The requirements.txt would contain fastapi, uvicorn, sentence-transformers, and pydantic. This service is now dumb, in a good way. It does one thing well: batch inference. The intelligence will live in the Go service.
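Because the worker's request model caps a batch at 128 texts (max_items=128), FastAPI answers anything larger with a 422 validation error. The gateway's maxBatchSize of 64 stays well under that, but if it were ever configured higher, a guard like the hypothetical chunkTexts helper below would keep each downstream call within the limit while preserving order:

```go
package main

import "fmt"

// workerMaxItems mirrors max_items=128 on the worker's BatchInferenceRequest.
const workerMaxItems = 128

// chunkTexts splits texts into consecutive slices of at most size elements,
// preserving order so results can be concatenated back in the same order.
func chunkTexts(texts []string, size int) [][]string {
	if size <= 0 {
		panic("chunk size must be positive")
	}
	var chunks [][]string
	for start := 0; start < len(texts); start += size {
		end := start + size
		if end > len(texts) {
			end = len(texts)
		}
		chunks = append(chunks, texts[start:end])
	}
	return chunks
}

func main() {
	texts := make([]string, 300)
	for i := range texts {
		texts[i] = fmt.Sprintf("sentence %d", i)
	}
	// Prints 128, 128 and 44 on separate lines.
	for _, c := range chunkTexts(texts, workerMaxItems) {
		fmt.Println(len(c))
	}
}
```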

The Go Dynamic Batching Layer

This is the heart of the solution. We need a component that can:

  1. Accept individual requests concurrently from many clients.
  2. Queue these requests internally.
  3. Trigger a batch dispatch when either a maximum batch size is reached or a maximum wait time has elapsed.
  4. Send the assembled batch to the Python worker.
  5. Receive the batch response and correctly route the individual results back to the waiting clients.

Go’s concurrency primitives—goroutines and channels—are tailor-made for this.

Here is the core data structure for a single job entering our system. Crucially, it includes a responseChan. Each incoming HTTP request will create a job and wait on its own unique channel for the result. This is how we map asynchronous batch results back to synchronous HTTP requests.

// file: gateway/batcher/processor.go
package batcher

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"log"
	"net/http"
	"sync"
	"time"
)

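// SingleJob represents one client request waiting to be batched.
type SingleJob struct {
	Text         string
	responseChan chan<- JobResult // Channel to send the result back to the waiting handler.
}

// NewSingleJob builds a job for code outside this package (such as the HTTP
// handler in package main), which cannot set the unexported responseChan
// field directly. responseChan must be buffered (capacity 1) so the batcher
// never blocks on a handler that has already given up.
func NewSingleJob(text string, responseChan chan<- JobResult) SingleJob {
	return SingleJob{Text: text, responseChan: responseChan}
}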

// JobResult holds the outcome for a single job.
type JobResult struct {
	Embedding []float32
	Err       error
}

// BatchProcessor is the core component that collects and dispatches jobs.
type BatchProcessor struct {
	jobQueue         chan SingleJob
	workerEndpoint   string
	maxBatchSize     int
	maxWaitTime      time.Duration
	httpClient       *http.Client
	shutdownSignal   chan struct{}
	wg               sync.WaitGroup
}

// NewBatchProcessor initializes and starts the batch processor.
func NewBatchProcessor(endpoint string, maxBatchSize int, maxWaitTime time.Duration) *BatchProcessor {
	processor := &BatchProcessor{
		jobQueue:       make(chan SingleJob, maxBatchSize*2), // Buffered channel
		workerEndpoint: endpoint,
		maxBatchSize:   maxBatchSize,
		maxWaitTime:    maxWaitTime,
		httpClient: &http.Client{
			Timeout: 10 * time.Second, // A reasonable timeout for the downstream call.
		},
		shutdownSignal: make(chan struct{}),
	}

	processor.wg.Add(1)
	go processor.run() // Start the processing loop in a background goroutine.
	return processor
}
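The response-channel idea is easiest to see in isolation: several goroutines each block on their own buffered channel while a single dispatcher answers them by position in the batch. This standalone sketch uses its own job type and a fake "embedding" (the text's length) in place of the worker call; it illustrates the pattern and is not the gateway's code.

```go
package main

import (
	"fmt"
	"sync"
)

type job struct {
	text  string
	reply chan int // buffered (cap 1) so the dispatcher never blocks
}

// dispatch answers every job in the batch: result i goes to job i.
func dispatch(batch []job) {
	for _, j := range batch {
		j.reply <- len(j.text) // stand-in for the real embedding
	}
}

func main() {
	texts := []string{"a", "bb", "ccc"}
	queue := make(chan job, len(texts))
	results := make([]int, len(texts))

	var wg sync.WaitGroup
	for i, t := range texts {
		wg.Add(1)
		go func(i int, t string) {
			defer wg.Done()
			j := job{text: t, reply: make(chan int, 1)}
			queue <- j
			results[i] = <-j.reply // each caller waits only on its own channel
		}(i, t)
	}

	// Collect a full batch (in whatever order jobs arrive), then answer it at once.
	batch := make([]job, 0, len(texts))
	for len(batch) < len(texts) {
		batch = append(batch, <-queue)
	}
	dispatch(batch)
	wg.Wait()
	fmt.Println(results) // [1 2 3]
}
```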

The run method is the engine. It uses a time.Ticker for the timeout-based dispatch and a slice to collect jobs. The select statement is key: it waits for either a new job to arrive or the ticker to fire. This dual condition is what makes the batching “dynamic.”

// file: gateway/batcher/processor.go (continued)

func (p *BatchProcessor) run() {
	defer p.wg.Done()
	
	// The batch currently being assembled.
	batch := make([]SingleJob, 0, p.maxBatchSize)
	
	// A timer that dictates the maximum time to wait before sending a batch.
	ticker := time.NewTicker(p.maxWaitTime)
	defer ticker.Stop()

	for {
		select {
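		case <-p.shutdownSignal:
			// On shutdown, first drain jobs already sitting in the queue so their
			// handlers get an answer, then flush whatever remains and exit.
		drain:
			for {
				select {
				case job, ok := <-p.jobQueue:
					if !ok {
						break drain
					}
					batch = append(batch, job)
					if len(batch) >= p.maxBatchSize {
						p.dispatchBatch(batch)
						batch = make([]SingleJob, 0, p.maxBatchSize)
					}
				default:
					break drain
				}
			}
			if len(batch) > 0 {
				p.dispatchBatch(batch)
			}
			log.Println("Batch processor shutting down.")
			return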
		
		case job, ok := <-p.jobQueue:
			if !ok {
				// Channel closed, which is another way to signal shutdown.
				if len(batch) > 0 {
					p.dispatchBatch(batch)
				}
				log.Println("Job queue closed, shutting down processor.")
				return
			}
			
			batch = append(batch, job)
			// Dispatch immediately if the batch is full.
			if len(batch) >= p.maxBatchSize {
				p.dispatchBatch(batch)
				batch = make([]SingleJob, 0, p.maxBatchSize) // Reset the batch
				ticker.Reset(p.maxWaitTime) // Reset the timer
			}

		case <-ticker.C:
			// Time is up, dispatch whatever we have.
			if len(batch) > 0 {
				p.dispatchBatch(batch)
				batch = make([]SingleJob, 0, p.maxBatchSize) // Reset the batch
			}
		}
	}
}

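// SubmitJob is the public method for handlers to add a job to the queue.
// It blocks while the queue is full (backpressure). Once Shutdown has been
// called, the job is failed immediately instead of blocking on a queue that
// nothing reads any more. This send requires a buffered responseChan, because
// the caller is still inside SubmitJob. A job can still slip into the queue
// during the shutdown race, so callers should wait with a timeout as well.
func (p *BatchProcessor) SubmitJob(job SingleJob) {
	select {
	case p.jobQueue <- job:
	case <-p.shutdownSignal:
		job.responseChan <- JobResult{Err: errors.New("batch processor is shutting down")}
	}
}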

// Shutdown gracefully stops the processor.
func (p *BatchProcessor) Shutdown() {
	close(p.shutdownSignal)
	p.wg.Wait()
}

The dispatchBatch method is where the communication with the Python worker happens. It constructs the JSON payload, makes the HTTP call, and—most importantly—distributes the results or errors back to the waiting goroutines via their individual responseChan.

// file: gateway/batcher/processor.go (continued)

type workerRequest struct {
	Texts []string `json:"texts"`
}

type workerResponse struct {
	Embeddings [][]float32 `json:"embeddings"`
}

func (p *BatchProcessor) dispatchBatch(batch []SingleJob) {
	log.Printf("Dispatching batch of size %d\n", len(batch))

	texts := make([]string, len(batch))
	for i, job := range batch {
		texts[i] = job.Text
	}

	reqPayload := workerRequest{Texts: texts}
	payloadBytes, err := json.Marshal(reqPayload)
	if err != nil {
		log.Printf("ERROR: Failed to marshal request payload: %v\n", err)
		p.distributeError(batch, errors.New("internal server error"))
		return
	}
	
	// A context for the downstream request to handle timeouts properly.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, "POST", p.workerEndpoint, bytes.NewBuffer(payloadBytes))
	if err != nil {
		log.Printf("ERROR: Failed to create new request: %v\n", err)
		p.distributeError(batch, errors.New("internal server error"))
		return
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := p.httpClient.Do(req)
	if err != nil {
		log.Printf("ERROR: Failed to call worker endpoint: %v\n", err)
		p.distributeError(batch, errors.New("inference worker unavailable"))
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		log.Printf("ERROR: Worker returned non-200 status: %d\n", resp.StatusCode)
		p.distributeError(batch, errors.New("inference worker failed"))
		return
	}

	var workerResp workerResponse
	if err := json.NewDecoder(resp.Body).Decode(&workerResp); err != nil {
		log.Printf("ERROR: Failed to decode worker response: %v\n", err)
		p.distributeError(batch, errors.New("invalid response from worker"))
		return
	}

	if len(workerResp.Embeddings) != len(batch) {
		log.Printf("ERROR: Mismatch in batch size. Sent %d, got %d\n", len(batch), len(workerResp.Embeddings))
		p.distributeError(batch, errors.New("worker returned mismatched batch size"))
		return
	}

	// Success case: distribute results back to each waiting goroutine.
	for i, job := range batch {
		job.responseChan <- JobResult{Embedding: workerResp.Embeddings[i], Err: nil}
	}
}

// distributeError sends an error back to all jobs in a failed batch.
func (p *BatchProcessor) distributeError(batch []SingleJob, err error) {
	for _, job := range batch {
		job.responseChan <- JobResult{Err: err}
	}
}
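Because dispatchBatch depends only on an HTTP endpoint, the whole round trip can be exercised without a GPU by standing up a fake worker with net/http/httptest. The sketch below is self-contained; fakeWorker and postBatch are hypothetical stand-ins that follow the same JSON contract and the same one-result-per-text check as dispatchBatch.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
)

type workerRequest struct {
	Texts []string `json:"texts"`
}

type workerResponse struct {
	Embeddings [][]float32 `json:"embeddings"`
}

// fakeWorker returns a one-dimensional "embedding" (the text length) per text.
func fakeWorker() *httptest.Server {
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var req workerRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		resp := workerResponse{}
		for _, t := range req.Texts {
			resp.Embeddings = append(resp.Embeddings, []float32{float32(len(t))})
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	}))
}

// postBatch sends texts to url and enforces the one-result-per-text rule.
func postBatch(url string, texts []string) ([][]float32, error) {
	body, err := json.Marshal(workerRequest{Texts: texts})
	if err != nil {
		return nil, err
	}
	resp, err := http.Post(url, "application/json", bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("worker returned status %d", resp.StatusCode)
	}
	var out workerResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		return nil, err
	}
	if len(out.Embeddings) != len(texts) {
		return nil, fmt.Errorf("sent %d texts, got %d embeddings", len(texts), len(out.Embeddings))
	}
	return out.Embeddings, nil
}

func main() {
	srv := fakeWorker()
	defer srv.Close()
	embs, err := postBatch(srv.URL+"/batch_embed", []string{"hi", "there"})
	fmt.Println(embs, err) // [[2] [5]] <nil>
}
```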

Finally, the Gin HTTP handler ties it all together. It’s lean. It creates the job, submits it, and then blocks, waiting for the result. A context with a timeout is critical here to prevent a client from waiting indefinitely if the batching system gets stuck.

// file: gateway/main.go
package main

import (
	"context"
	"log"
	"net/http"
	"os"
	"time"

	"gateway/batcher" // Assuming the batcher package is in this path
	"github.com/gin-gonic/gin"
)

type EmbedRequest struct {
	Text string `json:"text" binding:"required"`
}

func main() {
	workerEndpoint := os.Getenv("WORKER_ENDPOINT")
	if workerEndpoint == "" {
		log.Fatal("WORKER_ENDPOINT environment variable not set")
	}

	// Configuration for the batcher. These should be tunable.
	maxBatchSize := 64
	maxWaitTime := 100 * time.Millisecond

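	processor := batcher.NewBatchProcessor(workerEndpoint, maxBatchSize, maxWaitTime)
	// Caveat: this defer only runs if main returns normally. r.Run blocks until
	// the server fails, log.Fatalf exits without running deferred calls, and a
	// SIGTERM from Swarm ends the process outright, so a truly graceful stop
	// needs http.Server.Shutdown plus signal handling.
	defer processor.Shutdown()

	r := gin.Default()

	r.POST("/embed", func(c *gin.Context) {
		var req EmbedRequest
		if err := c.ShouldBindJSON(&req); err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
			return
		}

		// Bound the whole request (queueing, batching and inference). Deriving
		// the context from the incoming request also stops the wait if the
		// client disconnects.
		ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
		defer cancel()

		// Each request gets its own buffered response channel, so the batcher can
		// always deliver a result, even to a handler that has already timed out.
		// NewSingleJob is used because responseChan is unexported in package batcher.
		responseChan := make(chan batcher.JobResult, 1)
		processor.SubmitJob(batcher.NewSingleJob(req.Text, responseChan))

		select {
		case result := <-responseChan:
			if result.Err != nil {
				// For simplicity, every batcher error is reported as 503.
				c.JSON(http.StatusServiceUnavailable, gin.H{"error": result.Err.Error()})
				return
			}
			c.JSON(http.StatusOK, gin.H{"embedding": result.Embedding})
		case <-ctx.Done():
			c.JSON(http.StatusGatewayTimeout, gin.H{"error": "request timed out"})
		}
	})

	r.GET("/health", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"status": "ok"})
	})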

	log.Println("Starting Go API Gateway on port 8080")
	if err := r.Run(":8080"); err != nil {
		log.Fatalf("Failed to start Gin server: %v", err)
	}
}

The Dockerfile for the Go service is also standard, using a multi-stage build to keep the final image small.

# file: gateway/Dockerfile
# Build Stage
FROM golang:1.21-alpine AS builder

WORKDIR /app

COPY go.mod ./
COPY go.sum ./
RUN go mod download

COPY . .
RUN CGO_ENABLED=0 GOOS=linux go build -o /gateway-app ./main.go

# Final Stage
FROM alpine:latest

WORKDIR /root/
COPY --from=builder /gateway-app .

EXPOSE 8080

CMD ["./gateway-app"]

Deployment on Docker Swarm

Now we tie it all together with a docker-compose.yml file for Swarm mode. The key features here are the overlay network allowing seamless communication between gateway and worker, and the placement constraint.

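Before deploying, we need to label the node that has a GPU. Node labels can only be changed from a manager, so run the following on a manager node. $(hostname) works when the GPU machine is itself that manager; otherwise substitute the GPU node's name as listed by docker node ls:
docker node update --label-add gpu=true $(hostname)

This label is what the placement constraint in the compose file (node.labels.gpu == true) matches against.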

# file: docker-compose.yml
version: "3.8"

services:
  gateway:
    image: my-registry/gateway:1.0 # Replace with your image name
    build:
      context: ./gateway
    ports:
      - "8080:8080"
    environment:
      # The service name 'worker' is automatically resolved by Swarm's DNS.
      - WORKER_ENDPOINT=http://worker:8000/batch_embed
    networks:
      - ml-net
    deploy:
      replicas: 2 # We can scale the stateless gateway easily.
      update_config:
        parallelism: 1
        delay: 10s
      restart_policy:
        condition: on-failure

  worker:
    image: my-registry/worker:1.0 # Replace with your image name
    build:
      context: ./worker
    environment:
      - MODEL_NAME=all-MiniLM-L6-v2
    networks:
      - ml-net
    deploy:
      replicas: 1 # Typically one worker per GPU.
      placement:
        constraints:
          # This is the critical line that ensures the worker lands on a GPU node.
          - node.labels.gpu == true
      restart_policy:
        condition: on-failure
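      # docker stack deploy does not support the Compose "devices" reservation,
      # so the GPU is reserved as a Swarm generic resource instead. This assumes
      # the GPU node's /etc/docker/daemon.json sets "default-runtime": "nvidia"
      # and advertises the GPU via "node-generic-resources"
      # (e.g. "NVIDIA-GPU=<gpu-uuid>"), with the NVIDIA container runtime
      # configured to honor that resource.
      resources:
        reservations:
          generic_resources:
            - discrete_resource_spec:
                kind: "NVIDIA-GPU"
                value: 1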

networks:
  ml-net:
    driver: overlay

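To deploy this stack, the commands are simple. Note that docker stack deploy ignores the build: keys, so both images must first be built and pushed to a registry that every node can reach:

  1. docker swarm init (on the manager node), then run the printed docker swarm join command on the other nodes
  2. docker compose build && docker compose push (builds and pushes both images under the image: names in the file)
  3. docker stack deploy -c docker-compose.yml ml-stack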

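The result is a resilient system. Swarm's ingress routing mesh spreads incoming traffic across the Go gateway replicas, which batch requests and send them to the single, stateful (because of the loaded model) Python worker; the placement constraint guarantees that the worker runs on the GPU node. Under load, GPU utilization now consistently stays above 80%, and the amortized inference cost per request has dropped by an order of magnitude, from 50ms to under 5ms. That amortized figure is model time divided across a batch, not end-to-end latency: an individual request can still spend up to the 100ms batching window waiting for its batch to be dispatched, so maxWaitTime is the knob that trades per-request latency for throughput.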

graph TD
    subgraph Client Traffic
        C1(Client 1) --> LB
        C2(Client 2) --> LB
        C3(Client 3) --> LB
        C4(Client 4) --> LB
    end
    
    subgraph Docker Swarm
        LB(Swarm Ingress Routing Mesh) --> G1
        LB --> G2

        subgraph "Node 1 (No GPU)"
            G1(Go Gateway Replica 1)
        end

        subgraph "Node 2 (No GPU)"
            G2(Go Gateway Replica 2)
        end
        
        subgraph "Node 3 (GPU-enabled, label: gpu=true)"
            W1(Python Worker)
        end

        G1 -- "Batch Request (HTTP)" --> W1
        G2 -- "Batch Request (HTTP)" --> W1
    end

    style W1 fill:#f9f,stroke:#333,stroke-width:2px

The current implementation of the Go batcher is simple and effective, but each gateway instance coordinates its own in-memory batch. If a gateway instance restarts, that batch is lost. For our use case, where requests are idempotent and clients can retry, this was an acceptable trade-off. A more robust solution might hold jobs in an external queue such as Redis or NATS, letting multiple gateway instances pull from a shared queue. That, however, adds another piece of infrastructure to manage. The current design is a pragmatic balance between performance gain and operational complexity.

Scaling the Python workers isn't trivial either. Adding a second worker replica would require a more sophisticated load balancing or sharding strategy at the gateway level to ensure batches are routed effectively. Swarm's built-in service load balancing (a virtual IP in front of the replicas) distributes connections rather than batches, and with HTTP keep-alive a gateway can stay pinned to one worker, so it doesn't solve this out of the box for this stateful workload. That remains a challenge for a future iteration.
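For reference, the simplest form of gateway-level routing would pick a worker per batch rather than per connection. The sketch below is hypothetical and untested against this stack: it assumes the worker replicas were exposed as separate, individually addressable services (worker-0 and worker-1 are illustrative names), and it ignores health checking and GPU load.

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// roundRobin hands out worker endpoints in rotation, one per batch.
// It is safe for concurrent use.
type roundRobin struct {
	endpoints []string
	next      atomic.Uint64
}

func newRoundRobin(endpoints ...string) *roundRobin {
	if len(endpoints) == 0 {
		panic("need at least one endpoint")
	}
	return &roundRobin{endpoints: endpoints}
}

// pick returns the endpoint for the next batch.
func (r *roundRobin) pick() string {
	n := r.next.Add(1) - 1
	return r.endpoints[n%uint64(len(r.endpoints))]
}

func main() {
	rr := newRoundRobin(
		"http://worker-0:8000/batch_embed",
		"http://worker-1:8000/batch_embed",
	)
	// Alternates worker-0, worker-1, worker-0, worker-1.
	for i := 0; i < 4; i++ {
		fmt.Println(rr.pick())
	}
}
```

In the gateway, dispatchBatch would call pick() once per batch instead of using the fixed workerEndpoint.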

