The initial problem wasn’t a technical one, but an organizational one. We had dozens of data science teams producing Keras models, each deployed on its own pet EC2 instance with a hand-rolled Flask API. Access control was a chaotic mix of SSH keys, hardcoded basic auth, and VPN restrictions. From a security and cost perspective, it was untenable. The mandate was clear: consolidate all internal model inference behind our corporate Okta SAML Identity Provider (IdP) and migrate to a pay-per-use serverless infrastructure.
This created an immediate architectural conflict. SAML is an inherently stateful, redirect-based protocol designed for traditional web applications. Serverless functions, particularly on AWS Lambda, are stateless and ephemeral. Forcing these two paradigms together, while also managing the complex dependencies of ML models, became the central challenge. The initial concept was a portal where authenticated users could access a catalog of models and run on-demand inference.
Our technology stack was largely predetermined by enterprise standards and the nature of the problem itself.
- SAML 2.0: The non-negotiable enterprise authentication standard.
- React: The standard for internal front-end tooling.
- AWS Lambda: The chosen serverless compute platform for its event-driven model and cost efficiency.
- Docker: The only sane way to package Keras models with their specific TensorFlow versions, system libraries, and pre-trained weights into a reproducible artifact. Lambda’s support for container images was the key enabler here.
- Amazon API Gateway: The front door to our Lambda functions, tasked with routing and, crucially, authorization.
The core of the implementation revolved around decoupling the SAML authentication flow from the actual inference workload. We couldn’t perform the SAML “dance” on every single API call; the latency would be unacceptable. The solution was a two-part serverless backend: one function dedicated to handling the SAML assertion and issuing a short-lived JSON Web Token (JWT), and a second, container-backed function for running the Keras model, protected by a JWT authorizer.
The Authentication Flow Architecture
Before diving into code, it’s critical to visualize the sequence of events. The user never directly interacts with SAML. The React SPA orchestrates the flow through our serverless backend.
sequenceDiagram
    participant User
    participant ReactSPA as React SPA
    participant APIGW as API Gateway
    participant AuthLambda as Auth Lambda (Node.js)
    participant SamlIdP as SAML IdP (Okta)
    participant InferenceLambda as Inference Lambda (Python/Keras)

    User->>+ReactSPA: Loads application
    ReactSPA->>+APIGW: GET /api/models (No Auth Token)
    APIGW-->>ReactSPA: 401 Unauthorized
    ReactSPA->>User: Redirect to /api/auth/login
    User->>+APIGW: GET /api/auth/login
    APIGW->>+AuthLambda: Invoke login handler
    AuthLambda->>+SamlIdP: Generate SAML AuthnRequest, Redirect User
    SamlIdP->>User: Presents Login Page
    User->>SamlIdP: Submits Credentials
    SamlIdP-->>-User: POST SAML Assertion to ACS URL (/api/auth/callback)
    User->>+APIGW: POST /api/auth/callback with SAMLResponse
    APIGW->>+AuthLambda: Invoke ACS handler
    AuthLambda->>AuthLambda: Verify SAML Assertion Signature
    AuthLambda->>AuthLambda: Create JWT with user attributes (email, roles)
    AuthLambda-->>-User: Set secure, HttpOnly cookie with JWT & Redirect to SPA root
    User->>+ReactSPA: Reloads application with JWT cookie
    ReactSPA->>+APIGW: GET /api/models (With Auth Token from cookie)
    APIGW->>APIGW: Validate JWT via Authorizer
    Note over APIGW: Authorizer passes user claims to Lambda
    APIGW->>+InferenceLambda: Invoke with validated request
    InferenceLambda-->>-APIGW: Returns model list
    APIGW-->>-ReactSPA: 200 OK with data
This “backend-for-frontend” (BFF) approach for authentication keeps the React application clean. It doesn’t need to know anything about SAML. It only needs to handle a 401 response by redirecting to a login endpoint and then subsequently include credentials (the cookie) on future requests.
Implementing the SAML Authentication Lambda
The first piece of the puzzle is the Node.js Lambda function responsible for the SAML handshake. We used the passport-saml library for this, as it handles most of the XML parsing and validation complexity. A real-world project requires robust configuration management; for this example, we’ll store settings in environment variables.
Here’s the core serverless.yml definition for this part of the service:
# serverless.yml (Auth Service)
service: ml-inference-auth

provider:
  name: aws
  runtime: nodejs18.x
  region: us-east-1
  environment:
    SAML_ENTRY_POINT: ${ssm:/ml-app/saml/idp-entry-point}
    SAML_ISSUER: 'urn:ml-inference-platform:sp'
    SAML_CALLBACK_URL: 'https://<api-id>.execute-api.us-east-1.amazonaws.com/prod/auth/callback'
    SAML_CERT_PATH: 'config/idp_cert.pem'
    JWT_SECRET: ${ssm:/ml-app/jwt/secret}
    APP_URL: 'https://<your-frontend-domain>'

functions:
  login:
    handler: handler.login
    events:
      - http:
          path: auth/login
          method: get
          cors: true
  authCallback:
    handler: handler.authCallback
    events:
      - http:
          path: auth/callback
          method: post
          cors: true

# Note: In a production setup, you'd package the IdP certificate with your deployment artifact.
The handler code itself initializes Passport and defines the two routes. A common mistake is to re-initialize the SAML strategy on every invocation. This is inefficient. We define it outside the handler to take advantage of Lambda’s container reuse.
handler.js
const serverless = require('serverless-http');
const express = require('express');
const passport = require('passport');
const SamlStrategy = require('passport-saml').Strategy;
const jwt = require('jsonwebtoken');
const fs = require('fs');
const cookieParser = require('cookie-parser');
const app = express();
app.use(cookieParser());
app.use(passport.initialize());
// A real-world project would fetch this cert from a secure store, not the filesystem.
const idpCert = fs.readFileSync(process.env.SAML_CERT_PATH, 'utf-8');
const samlStrategy = new SamlStrategy(
  {
    path: '/auth/callback',
    entryPoint: process.env.SAML_ENTRY_POINT,
    issuer: process.env.SAML_ISSUER,
    callbackUrl: process.env.SAML_CALLBACK_URL,
    cert: idpCert,
    identifierFormat: 'urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress',
    signatureAlgorithm: 'sha256',
    acceptedClockSkewMs: 300 // Tolerate minor clock drift between IdP and SP
  },
  (profile, done) => {
    // The profile object contains attributes asserted by the IdP.
    // e.g., email, firstName, lastName, memberOf
    console.log(`SAML assertion successful for user: ${profile.email}`);
    return done(null, {
      email: profile.email,
      userId: profile.nameID,
      groups: profile.memberOf // Critical for role-based access control
    });
  }
);

passport.use(samlStrategy);

// Login route: Initiates the SAML flow
app.get('/auth/login', passport.authenticate('saml', { session: false }));

// Callback route: Handles the SAML assertion from the IdP
app.post('/auth/callback', passport.authenticate('saml', { session: false }), (req, res) => {
  if (!req.user) {
    console.error('SAML authentication failed, no user profile returned.');
    return res.status(401).send('Authentication Failed');
  }

  try {
    const userPayload = {
      sub: req.user.userId,
      email: req.user.email,
      groups: req.user.groups,
    };

    // The pitfall here is making the token expiry too long.
    // Short-lived tokens (e.g., 1 hour) are safer.
    const token = jwt.sign(userPayload, process.env.JWT_SECRET, { expiresIn: '1h' });

    // Set the JWT in a secure, HttpOnly cookie. The React app cannot access this
    // cookie via JavaScript, which mitigates XSS risks.
    res.cookie('authToken', token, {
      httpOnly: true,
      secure: true, // Only send over HTTPS
      sameSite: 'strict', // Mitigates CSRF
      maxAge: 3600 * 1000 // 1 hour
    });

    // Redirect back to the main application
    res.redirect(process.env.APP_URL);
  } catch (error) {
    console.error('JWT signing or redirect failed:', error);
    res.status(500).send('An internal error occurred during authentication.');
  }
});
module.exports.login = serverless(app);
module.exports.authCallback = serverless(app);
Containerizing the Keras Inference Service
With authentication handled, we turn to the core payload: the Keras model. Packaging ML models is notoriously difficult due to large binaries (TensorFlow/PyTorch), system-level dependencies (CUDA/cuDNN), and the model weights themselves. Docker is the perfect solution.
Let’s assume a simple MNIST digit classifier as our model. The inference code is wrapped in a lightweight web server like FastAPI for performance.
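The rest of this section treats model/mnist_model.h5 as a given artifact. For completeness, here is a minimal sketch of how such a file could be produced; the architecture and training settings are illustrative placeholders, not our production model.

# train_mnist.py (illustrative only: produce the artifact the inference service loads)
import os
import numpy as np
from tensorflow import keras

# Load and normalize MNIST, adding a channel dimension to match (28, 28, 1).
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.astype("float32")[..., np.newaxis] / 255.0
x_test = x_test.astype("float32")[..., np.newaxis] / 255.0

model = keras.Sequential([
    keras.layers.Input(shape=(28, 28, 1)),
    keras.layers.Conv2D(32, 3, activation="relu"),
    keras.layers.MaxPooling2D(),
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation="relu"),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(x_train, y_train, epochs=3, validation_data=(x_test, y_test))

os.makedirs("model", exist_ok=True)
model.save("model/mnist_model.h5")  # HDF5 file loaded by load_model() in main.py

With that artifact in place, the FastAPI wrapper looks like this: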
inference_app/main.py
import os
import logging

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import numpy as np
from tensorflow.keras.models import load_model

# Setup structured logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

app = FastAPI()

MODEL_PATH = os.getenv('MODEL_PATH', 'model/mnist_model.h5')
model = None


@app.on_event("startup")
def load_keras_model():
    """Load the Keras model into memory during application startup."""
    global model
    try:
        if os.path.exists(MODEL_PATH):
            model = load_model(MODEL_PATH)
            logger.info(f"Model loaded successfully from {MODEL_PATH}")
        else:
            logger.error(f"Model file not found at {MODEL_PATH}")
            # In a real system, this should trigger a health check failure.
    except Exception as e:
        logger.critical(f"Failed to load Keras model: {e}", exc_info=True)
        # Prevent the application from starting if the model can't be loaded.
        raise RuntimeError("Model loading failed") from e


class InferenceRequest(BaseModel):
    # Expecting a flattened 28x28 image (784 features), normalized to [0, 1]
    image_data: list[float]


class InferenceResponse(BaseModel):
    prediction: int
    confidence: float


@app.post("/predict", response_model=InferenceResponse)
def predict(request: InferenceRequest):
    """
    Handles inference requests. The API Gateway authorizer has already
    validated the JWT before this function is invoked.
    """
    if model is None:
        logger.error("Model is not loaded, cannot perform inference.")
        raise HTTPException(status_code=503, detail="Model not available")
    if len(request.image_data) != 784:
        raise HTTPException(status_code=400, detail="Input data must be a flattened 28x28 image (784 floats).")
    try:
        # Reshape for the model, add a batch dimension
        image_array = np.array(request.image_data).reshape(1, 28, 28, 1)
        # Perform inference
        predictions = model.predict(image_array)
        predicted_class = int(np.argmax(predictions[0]))
        confidence_score = float(np.max(predictions[0]))
        logger.info(f"Inference successful. Prediction: {predicted_class} with confidence {confidence_score:.4f}")
        return InferenceResponse(prediction=predicted_class, confidence=confidence_score)
    except Exception as e:
        logger.error(f"An error occurred during inference: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error during prediction.")
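Before containerizing, it’s worth a quick local sanity check of the endpoint. The snippet below is a hypothetical smoke test using FastAPI’s TestClient; it assumes model/mnist_model.h5 exists on the local filesystem and that the test dependencies (pytest, httpx) are installed, none of which ship in the image.

# test_predict.py (hypothetical local smoke test, not part of the deployment)
import numpy as np
from fastapi.testclient import TestClient

from inference_app.main import app


def test_predict_roundtrip():
    # Using the context manager runs the startup event, which loads the model.
    with TestClient(app) as client:
        fake_digit = np.random.rand(28, 28).astype("float32")
        payload = {"image_data": fake_digit.flatten().tolist()}  # 784 floats in [0, 1]
        response = client.post("/predict", json=payload)
        assert response.status_code == 200
        body = response.json()
        assert 0 <= body["prediction"] <= 9
        assert 0.0 <= body["confidence"] <= 1.0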
The Dockerfile is the most critical piece of this section. We use a multi-stage build to keep the final image size down. The build stage installs dependencies, while the final stage copies only the necessary runtime environment and application code. This is crucial for reducing Lambda cold start times.
Dockerfile
# Stage 1: Build stage with full build dependencies
FROM python:3.9-slim as builder

WORKDIR /app

# Install system dependencies that might be needed by ML libraries
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies into a virtual environment
COPY requirements.txt .
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
RUN pip install --no-cache-dir -r requirements.txt

# ---
# Stage 2: Final runtime stage
FROM python:3.9-slim

WORKDIR /app

# Copy the virtual environment from the builder stage
COPY --from=builder /opt/venv /opt/venv

# Copy application code and the trained model
COPY ./inference_app/ /app/inference_app/
COPY ./model/ /app/model/

# Set environment variables for the runtime
ENV PATH="/opt/venv/bin:$PATH"
ENV MODEL_PATH="/app/model/mnist_model.h5"
# Set Python path to find modules inside inference_app
ENV PYTHONPATH="${PYTHONPATH}:/app"

# Start the FastAPI app. Note that uvicorn alone cannot receive Lambda
# invocations: the image also needs a bridge between the Lambda runtime and
# HTTP, such as the AWS Lambda Web Adapter, or a handler wrapper (e.g. Mangum)
# used with the Runtime Interface Client.
CMD ["uvicorn", "inference_app.main:app", "--host", "0.0.0.0", "--port", "8080"]
Tying It Together with API Gateway and React
Now we define the inference service in its own serverless.yml. The key component here is the authorizer. We use a request-based authorizer that looks for the authToken cookie we set earlier.
serverless.yml (Inference Service)
service: ml-inference-api

provider:
  name: aws
  region: us-east-1
  ecr:
    images:
      inference-service:
        path: ./

functions:
  api:
    image:
      name: inference-service
    events:
      - http:
          method: any
          path: /{proxy+}
          authorizer:
            name: jwtCookieAuthorizer
            type: request
            identitySource: method.request.header.Cookie
            resultTtlInSeconds: 0 # Do not cache authorization results
  # This is the Lambda function that acts as the authorizer.
  # It needs the signing secret to verify the JWT from the cookie.
  jwtCookieAuthorizer:
    handler: authorizer.handler
    environment:
      JWT_SECRET: ${ssm:/ml-app/jwt/secret}
The authorizer function is a small piece of Node.js code that extracts the cookie, verifies the JWT, and returns an IAM policy. A common mistake is to return a bare Allow/Deny policy and nothing else. Attaching a context object to the policy lets you pass verified claims (like the user’s email) directly to the downstream Lambda, which is invaluable for logging and auditing.
authorizer.js
const jwt = require('jsonwebtoken');

// Helper to extract cookie value
const getCookie = (cookieString, cookieName) => {
  if (!cookieString) return null;
  const cookies = cookieString.split(';');
  for (const cookie of cookies) {
    const [name, value] = cookie.trim().split('=');
    if (name === cookieName) {
      return value;
    }
  }
  return null;
};

// Generates the IAM policy document
const generatePolicy = (principalId, effect, resource, context) => ({
  principalId,
  policyDocument: {
    Version: '2012-10-17',
    Statement: [{
      Action: 'execute-api:Invoke',
      Effect: effect,
      Resource: resource,
    }],
  },
  context, // Pass verified claims to the integration
});

module.exports.handler = async (event) => {
  console.log('Authorizer event:', JSON.stringify(event, null, 2));
  const tokenCookie = getCookie(event.headers?.Cookie, 'authToken');

  if (!tokenCookie) {
    console.log('No authToken cookie found');
    // Important: API Gateway requires this specific error string for 401
    throw new Error('Unauthorized');
  }

  try {
    const decoded = jwt.verify(tokenCookie, process.env.JWT_SECRET);
    console.log(`Successfully verified token for user: ${decoded.email}`);

    // The context object will be available in the downstream Lambda's event object
    const context = {
      userId: decoded.sub,
      email: decoded.email,
      groups: JSON.stringify(decoded.groups || [])
    };

    return generatePolicy(decoded.sub, 'Allow', event.methodArn, context);
  } catch (err) {
    console.error('Token verification failed:', err.message);
    throw new Error('Unauthorized');
  }
};
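To make the auditing point concrete: with a REST API Lambda proxy integration, whatever the authorizer places in context surfaces under requestContext.authorizer in the downstream event. The containerized FastAPI service sees it through whatever layer bridges Lambda and HTTP, but the shape is easiest to see with a bare Python handler. The sketch below is illustrative only and not part of the deployed service.

# hypothetical_consumer.py (illustration of reading the authorizer context)
import json
import logging

logger = logging.getLogger(__name__)


def handler(event, context):
    # Values placed in the authorizer's `context` object arrive here as strings.
    claims = event.get("requestContext", {}).get("authorizer", {}) or {}
    user_email = claims.get("email", "unknown")
    groups = json.loads(claims.get("groups", "[]"))
    logger.info("Inference request from %s (groups: %s)", user_email, groups)
    # ...enforce group-based access here, then run the model...
    return {"statusCode": 200, "body": json.dumps({"user": user_email})}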
Finally, the React application can be built to interact with this system. We use a library like Axios and set up an interceptor to handle 401 responses gracefully.
ReactApiService.js
import axios from 'axios';

const apiClient = axios.create({
  baseURL: '/api', // Proxied to the API Gateway
  withCredentials: true, // Crucial for sending cookies
});

// Interceptor to handle authentication logic
apiClient.interceptors.response.use(
  (response) => response,
  (error) => {
    if (error.response && error.response.status === 401) {
      // The session has expired or the user is not logged in.
      // Redirect to the SAML login flow.
      console.log('Unauthorized request. Redirecting to login...');
      window.location.href = '/api/auth/login';
    }
    return Promise.reject(error);
  }
);

export const runInference = async (imageData) => {
  try {
    const response = await apiClient.post('/predict', { image_data: imageData });
    return response.data;
  } catch (error) {
    console.error("Inference API call failed:", error);
    throw error;
  }
};
This architecture, while complex to set up, provides a robust, scalable, and secure platform. It successfully bridges the gap between legacy enterprise authentication and modern cloud-native application design.
The most significant remaining issue is the cold start latency for the containerized Keras model. An 800MB container image can take upwards of 10-15 seconds to initialize on its first invocation. For interactive use cases, this is unacceptable. The immediate mitigation is to use AWS Lambda Provisioned Concurrency, but this introduces a fixed cost, partially defeating the purpose of a pure serverless model. Future work will involve exploring model quantization, dependency pruning to reduce the image size, or potentially moving latency-sensitive models to a different compute platform like AWS Fargate, while still using this SAML-to-JWT pattern for authentication. The current JWT mechanism also lacks a refresh token flow, meaning users must re-authenticate every hour; implementing a secure token refresh mechanism within this stateless architecture is the next complex challenge.