Deploying Quantized PyTorch Computer Vision Models for Real-Time Mobile Inference


The initial performance metrics from the prototype were a non-starter for production. A standard ResNet-18 model, trained for image classification, clocked in at 44.7MB on disk. When loaded onto a mid-range Android device, it consumed over 150MB of RAM at runtime and took an average of 180ms for a single inference pass. That latency made any real-time application, such as live camera-feed analysis, impossible, and the model size contributed significantly to app bundle bloat, a key factor in user churn during installation. The mandate was clear: get the model size under 15MB and the inference latency below 50ms on the same target hardware, with a maximum acceptable accuracy drop of 1-2%.

Our stack was already heavily invested in PyTorch for research and training, so introducing a different framework like TensorFlow Lite or ONNX Runtime for deployment was a path we wanted to avoid. The context-switching and potential for conversion errors between frameworks introduce significant maintenance overhead and risk. This led us directly to PyTorch Mobile. The promise of a direct path from a Python training environment to an optimized C++ runtime on mobile was the primary driver for this decision.

The core of the optimization strategy would be quantization. We evaluated two primary approaches: Post-Training Static Quantization (PTSQ) and Post-Training Dynamic Quantization (PTDQ). PTSQ typically yields better performance by pre-calculating scale factors using a calibration dataset, but it adds the complexity of maintaining and versioning this dataset. PTDQ, on the other hand, quantizes weights post-training and dynamically quantizes activations on-the-fly during inference. It’s simpler to implement as it requires no calibration data. Given our need for rapid iteration, we opted for PTDQ as the first line of attack. The trade-off was a slight performance penalty compared to PTSQ, but the implementation simplicity was a significant advantage for getting a production-viable solution out the door.
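
For reference, the PTSQ path we deferred looks roughly like the sketch below. This is a minimal illustration under stated assumptions, not production code: it assumes a calibration_loader yielding representative (image, label) batches, and it uses torchvision's quantization-ready ResNet-18 variant because that model ships with the QuantStub/DeQuantStub wrappers and fuse_model() helper that eager-mode static quantization needs.

# static_quantization_sketch.py -- illustrative sketch of the PTSQ path we deferred.
# `calibration_loader` is an assumed DataLoader over representative batches.
import torch
import torchvision

def static_quantize_resnet18(calibration_loader, num_batches: int = 32):
    model = torchvision.models.quantization.resnet18(pretrained=True, quantize=False)
    model.eval()
    model.fuse_model()  # fuse Conv+BN+ReLU blocks before quantization

    # qnnpack is the quantized backend used on ARM mobile CPUs.
    torch.backends.quantized.engine = 'qnnpack'
    model.qconfig = torch.quantization.get_default_qconfig('qnnpack')

    torch.quantization.prepare(model, inplace=True)  # insert activation observers

    # Calibration pass: observers record activation ranges used to compute scales.
    with torch.no_grad():
        for i, (images, _) in enumerate(calibration_loader):
            if i >= num_batches:
                break
            model(images)

    torch.quantization.convert(model, inplace=True)  # swap in int8 kernels
    return model

The calibration loop is exactly the maintenance cost described above: those batches have to stay representative of production traffic and become one more versioned artifact in CI, which is why PTDQ won the first round.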

The overall architecture for this process can be visualized as a multi-stage pipeline.

graph TD
    A[PyTorch Python Environment] -- 1. Pre-trained fp32 model --> B(Model Conversion Script);
    B -- 2. Apply Dynamic Quantization --> C{Quantized Model};
    C -- 3. JIT Script & Optimize --> D[Scripted .ptl Model];
    E[Model Hosting Server] -- 5. Serves model over API --> F(Android Application);
    D -- 4. Upload --> E;

    subgraph Android Device
        F -- 6. Downloads/Caches .ptl file --> G[File System];
        G -- 7. Load model path --> H(JNI Bridge);
        H -- 8. Invoke C++ Runtime --> I[LibTorch C++ Inference Engine];
        J[Camera/Image Input] -- 9. Preprocess --> H;
        I -- 10. Run Inference --> H;
        H -- 11. Return Results --> K[Application UI];
    end

    style B fill:#f9f,stroke:#333,stroke-width:2px
    style I fill:#ccf,stroke:#333,stroke-width:2px

Phase 1: Model Quantization and Export

The first step is to create a robust, repeatable script that takes a standard, floating-point PyTorch model and converts it into the quantized, scripted format required by the mobile runtime. In a real-world project, this script is a critical piece of CI/CD infrastructure.

A common mistake is to perform these steps manually in a Jupyter notebook. This is fine for experimentation, but for production, a command-line script ensures consistency and automation. The following script, export_quantized_model.py, encapsulates this logic.

# export_quantized_model.py
import torch
import torch.nn as nn
import torchvision
import argparse
import os
import logging
from torch.utils.mobile_optimizer import optimize_for_mobile

# Setup basic logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def get_model(model_name: str, pretrained: bool) -> nn.Module:
    """
    Loads a pretrained model from torchvision.
    In a real project, this would load your custom model architecture.
    """
    if model_name == 'resnet18':
        model = torchvision.models.resnet18(pretrained=pretrained)
        # In many CV tasks, we modify the final layer for a different number of classes.
        # For this example, we'll assume the standard 1000 ImageNet classes.
        model.eval() # IMPORTANT: Set model to evaluation mode
        return model
    else:
        raise ValueError(f"Model '{model_name}' not supported.")

def trace_and_optimize_model(model: nn.Module, model_path: str):
    """
    Traces the model, applies optimizations for mobile, and saves it.
    """
    logging.info("Starting model tracing and optimization...")
    try:
        # The example input tensor shape must match what the mobile app will provide.
        example_input = torch.rand(1, 3, 224, 224)
        
        # Tracing the model is the first step.
        # For models with data-dependent control flow, torch.jit.script may be needed;
        # for standard CNNs, tracing is usually sufficient and simpler.
        traced_model = torch.jit.trace(model, example_input)

        # The core optimization step for mobile deployment.
        # This performs several passes, like operator fusion.
        optimized_model = optimize_for_mobile(traced_model)

        logging.info(f"Saving optimized model to {model_path}")
        optimized_model._save_for_lite_interpreter(model_path)
        logging.info("Model optimization and saving complete.")

    except Exception as e:
        logging.error(f"An error occurred during model tracing/optimization: {e}")
        raise

def main():
    parser = argparse.ArgumentParser(description="PyTorch Model Quantization and Export for Mobile")
    parser.add_argument('--model-name', type=str, default='resnet18', help='Name of the model to export.')
    parser.add_argument('--input-weights', type=str, required=False, help='Path to custom fp32 weights file (.pth). If not provided, uses torchvision pretrained weights.')
    parser.add_argument('--output-path', type=str, default='./models_output', help='Directory to save the exported models.')
    
    args = parser.parse_args()

    # --- Setup ---
    os.makedirs(args.output_path, exist_ok=True)
    
    # --- Load Floating Point Model ---
    logging.info(f"Loading fp32 model: {args.model_name}")
    fp32_model = get_model(args.model_name, pretrained=(args.input_weights is None))
    if args.input_weights:
        logging.info(f"Loading custom weights from {args.input_weights}")
        fp32_model.load_state_dict(torch.load(args.input_weights, map_location='cpu'))
    
    fp32_model.eval() # Ensure model is in eval mode

    # --- Apply Dynamic Quantization ---
    # The pitfall here is trying to quantize modules that are not supported.
    # Note: quantize_dynamic's default module mappings cover nn.Linear (and recurrent
    # layers); nn.Conv2d is listed here for completeness but is generally left in fp32
    # by the dynamic path, so conv-heavy backbones may see smaller gains than expected.
    logging.info("Applying post-training dynamic quantization...")
    quantized_model = torch.quantization.quantize_dynamic(
        model=fp32_model,
        qconfig_spec={nn.Linear, nn.Conv2d}, # Specify which module types to quantize
        dtype=torch.qint8 # Quantize to 8-bit integer
    )
    logging.info("Quantization complete. Model architecture:")
    print(quantized_model)

    # --- Export Models ---
    fp32_model_filename = os.path.join(args.output_path, f"{args.model_name}_fp32.ptl")
    quantized_model_filename = os.path.join(args.output_path, f"{args.model_name}_quantized.ptl")
    
    # Export the original fp32 model for comparison
    trace_and_optimize_model(fp32_model, fp32_model_filename)
    
    # Export the quantized model
    trace_and_optimize_model(quantized_model, quantized_model_filename)

    # --- Final Verification ---
    fp32_size = os.path.getsize(fp32_model_filename) / (1024 * 1024)
    quantized_size = os.path.getsize(quantized_model_filename) / (1024 * 1024)
    logging.info(f"FP32 model size: {fp32_size:.2f} MB")
    logging.info(f"Quantized model size: {quantized_size:.2f} MB")
    logging.info(f"Size reduction: {(1 - quantized_size / fp32_size) * 100:.2f}%")

if __name__ == '__main__':
    main()

Running this script (python export_quantized_model.py) produces two files: resnet18_fp32.ptl and resnet18_quantized.ptl. The former is around 44.7MB, while the latter is approximately 11.5MB. This achieves our model size target in a single, automated step.
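
Before promoting an exported artifact, it is worth a quick sanity check that the quantized model still tracks the fp32 baseline, given the 1-2% accuracy budget. The sketch below is one way to do that with the in-memory models from the export script; val_loader is an assumed DataLoader over a held-out validation set, and the desktop timings are only a rough proxy that do not substitute for on-device measurements.

# sanity_check_quantization.py -- minimal parity/latency check (illustrative).
# Assumes fp32_model and quantized_model are the in-memory models built in
# export_quantized_model.py and val_loader is a torch.utils.data.DataLoader.
import time
import torch

def top1_agreement(fp32_model, quantized_model, val_loader, num_batches: int = 10):
    fp32_model.eval()
    quantized_model.eval()
    agree, total = 0, 0
    with torch.no_grad():
        for i, (images, _) in enumerate(val_loader):
            if i >= num_batches:
                break
            agree += (fp32_model(images).argmax(1) == quantized_model(images).argmax(1)).sum().item()
            total += images.size(0)
    return agree / total

def mean_latency_ms(model, runs: int = 50):
    model.eval()
    example = torch.rand(1, 3, 224, 224)
    with torch.no_grad():
        model(example)  # warm-up
        start = time.perf_counter()
        for _ in range(runs):
            model(example)
    return (time.perf_counter() - start) * 1000.0 / runs

High top-1 agreement is an early signal, not a substitute for the full validation run that produced the accuracy numbers reported later.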

Phase 2: Building the Native Inference Layer on Android

With the model artifact ready, the next stage is integrating the PyTorch Mobile C++ runtime into an Android application. This is done via the Android NDK (Native Development Kit) and requires careful configuration of the build system. The goal is to create a clean C++ abstraction layer that can be called from Java/Kotlin code through the Java Native Interface (JNI).

2.1. Project Configuration

The build.gradle (Module: app) file needs to be configured to enable CMake for building the native code and to include the PyTorch Mobile dependency.

// build.gradle (Module: app)

android {
    // ... standard config
    
    defaultConfig {
        // ...
        ndk {
            abiFilters 'arm64-v8a' // Focus on 64-bit for modern devices. Add 'armeabi-v7a' for older device support.
        }
        externalNativeBuild {
            cmake {
                cppFlags '-std=c++17' // Use a modern C++ standard
            }
        }
    }

    externalNativeBuild {
        cmake {
            path 'src/main/cpp/CMakeLists.txt'
            version '3.18.1' // Use a specific, stable version
        }
    }
}

dependencies {
    // ... other dependencies
    
    // PyTorch Mobile Lite runtime (prebuilt releases from Maven Central). A custom
    // LibTorch build that includes only the operators your model uses can shrink
    // the native binary further.
    implementation 'org.pytorch:pytorch_android_lite:1.13.1'
    implementation 'org.pytorch:pytorch_android_torchvision_lite:1.13.1'
}

The CMakeLists.txt file is the heart of the native build. It defines the source files, finds the PyTorch library provided by the Gradle dependency, and links everything together. A misconfiguration here is a common source of build failures.

# src/main/cpp/CMakeLists.txt

cmake_minimum_required(VERSION 3.18.1)

project("MobileCVInference")

# The PyTorch dependency automatically unpacks the headers and libraries.
# We need to find them within the build directory.
# This path is crucial and depends on the Gradle plugin version.
# A common pitfall is this path changing between Android Gradle Plugin versions.
find_package(Torch REQUIRED
    HINTS "${CMAKE_CURRENT_SOURCE_DIR}/../../../build/intermediates/stripped_native_libs/release/out/lib/${ANDROID_ABI}")

# Add our native library source files
add_library(
        mobile_cv_inference
        SHARED
        native-lib.cpp # The JNI bridge file
        InferenceEngine.cpp # Our C++ engine implementation
)

# Add Android logging library
find_library(log-lib log)

# Link our library against LibTorch and Android log
target_link_libraries(
        mobile_cv_inference
        ${log-lib}
        "${TORCH_LIBRARIES}"
)

2.2. C++ Inference Engine

Creating a dedicated InferenceEngine class in C++ encapsulates all the LibTorch-specific logic. This keeps the JNI glue code clean and makes the C++ side more testable. One caveat on loading: the .ptl artifact from Phase 1 targets the lite interpreter, which is normally loaded via torch::jit::_load_for_mobile() when linking against the pytorch_android_lite runtime; the torch::jit::load() call shown below requires the full TorchScript runtime, so make sure the loader matches the dependency you actually ship.

// src/main/cpp/InferenceEngine.h
#pragma once

#include <string>
#include <vector>
#include <torch/script.h>
#include <android/log.h>

class InferenceEngine {
public:
    InferenceEngine(const std::string& model_path);
    ~InferenceEngine();

    // The core inference function. Takes a pre-processed float array.
    std::vector<float> predict(const std::vector<float>& input_data, const std::vector<int64_t>& shape);

private:
    torch::jit::script::Module module_;
    c10::DeviceType device_type_ = c10::kCPU;
};

// src/main/cpp/InferenceEngine.cpp
#include "InferenceEngine.h"
#include <chrono>

#define LOG_TAG "InferenceEngine"
#define ALOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
#define ALOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)

InferenceEngine::InferenceEngine(const std::string& model_path) {
    ALOGI("Initializing InferenceEngine...");
    try {
        // A critical step: load the model.
        // This can fail if the file is corrupt or the format is wrong.
        module_ = torch::jit::load(model_path);
        module_.eval(); // Ensure module is in evaluation mode
        ALOGI("Model loaded successfully from %s", model_path.c_str());
    } catch (const c10::Error& e) {
        ALOGE("Error loading the model: %s", e.what());
        // Re-throw to signal failure to the JNI layer
        throw std::runtime_error("Failed to load TorchScript model.");
    }
}

InferenceEngine::~InferenceEngine() {
    ALOGI("Destroying InferenceEngine.");
}

std::vector<float> InferenceEngine::predict(const std::vector<float>& input_data, const std::vector<int64_t>& shape) {
    // The constructor throws if the model fails to load, so a live InferenceEngine
    // instance always holds a valid module at this point.
    
    auto start_time = std::chrono::high_resolution_clock::now();

    // Wrap the input vector in a torch::Tensor. from_blob does not copy; it references
    // input_data directly, so the vector must outlive the forward() call.
    torch::Tensor input_tensor = torch::from_blob((void*)input_data.data(), shape, torch::kFloat32);

    // The inference call itself.
    at::IValue output_ivalue = module_.forward({input_tensor});

    if (!output_ivalue.isTensor()) {
        ALOGE("Model output is not a tensor.");
        throw std::runtime_error("Model output is not a tensor.");
    }

    torch::Tensor output_tensor = output_ivalue.toTensor();
    auto output_size = output_tensor.numel();
    std::vector<float> result(output_size);
    memcpy(result.data(), output_tensor.data_ptr<float>(), output_size * sizeof(float));
    
    auto end_time = std::chrono::high_resolution_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time).count();
    ALOGI("Inference completed in %lld ms", duration);

    return result;
}

2.3. JNI Bridge

The native-lib.cpp file acts as the bridge between the Kotlin/Java world and the C++ InferenceEngine. It handles the translation of data types. This layer must be extremely careful about memory management and exception handling. An uncaught C++ exception will crash the entire application.

// src/main/cpp/native-lib.cpp
#include <jni.h>
#include <string>
#include <memory>
#include <android/log.h>
#include "InferenceEngine.h"

// The ALOG* macros are per translation unit; without this definition the ALOGE
// calls below would not compile.
#define LOG_TAG "native-lib"
#define ALOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)

// Using a smart pointer to manage the lifetime of the C++ engine instance.
static std::unique_ptr<InferenceEngine> engine_instance;

extern "C" JNIEXPORT jlong JNICALL
Java_com_example_mobilecv_InferenceJni_loadModel(
        JNIEnv* env,
        jobject /* this */,
        jstring model_path) {
    const char* c_model_path = env->GetStringUTFChars(model_path, nullptr);
    if (c_model_path == nullptr) return 0; // OutOfMemoryError already thrown

    try {
        engine_instance = std::make_unique<InferenceEngine>(std::string(c_model_path));
        env->ReleaseStringUTFChars(model_path, c_model_path);
        // Return a pointer to the object as a long to represent the handle.
        return reinterpret_cast<jlong>(engine_instance.get());
    } catch (const std::exception& e) {
        ALOGE("Caught exception in loadModel: %s", e.what());
        env->ReleaseStringUTFChars(model_path, c_model_path);
        // It's good practice to throw a Java exception back to the caller.
        env->ThrowNew(env->FindClass("java/io/IOException"), e.what());
        return 0;
    }
}

extern "C" JNIEXPORT jfloatArray JNICALL
Java_com_example_mobilecv_InferenceJni_predict(
        JNIEnv* env,
        jobject /* this */,
        jlong handle,
        jfloatArray input_data,
        jlongArray shape) {

    if (handle == 0 || engine_instance.get() != reinterpret_cast<InferenceEngine*>(handle)) {
        env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "Model not loaded or handle is invalid.");
        return nullptr;
    }
    
    // Convert Java arrays to C++ vectors
    jsize input_len = env->GetArrayLength(input_data);
    jfloat* input_body = env->GetFloatArrayElements(input_data, nullptr);
    std::vector<float> cpp_input(input_body, input_body + input_len);
    env->ReleaseFloatArrayElements(input_data, input_body, JNI_ABORT);

    jsize shape_len = env->GetArrayLength(shape);
    jlong* shape_body = env->GetLongArrayElements(shape, nullptr);
    std::vector<int64_t> cpp_shape(shape_body, shape_body + shape_len);
    env->ReleaseLongArrayElements(shape, shape_body, JNI_ABORT);

    try {
        std::vector<float> result = engine_instance->predict(cpp_input, cpp_shape);
        
        // Convert C++ result vector back to a Java float array
        jfloatArray result_array = env->NewFloatArray(result.size());
        env->SetFloatArrayRegion(result_array, 0, result.size(), result.data());
        return result_array;

    } catch(const std::exception& e) {
        ALOGE("Caught exception in predict: %s", e.what());
        env->ThrowNew(env->FindClass("java/lang/RuntimeException"), e.what());
        return nullptr;
    }
}

extern "C" JNIEXPORT void JNICALL
Java_com_example_mobilecv_InferenceJni_destroyModel(
        JNIEnv* env,
        jobject /* this */,
        jlong handle) {
    if (handle != 0 && engine_instance.get() == reinterpret_cast<InferenceEngine*>(handle)) {
        engine_instance.reset();
    }
}

Phase 3: Application Layer Integration

The final step is to use this native bridge from the main Android application code. This involves managing the model file (downloading, caching), preprocessing input images, and calling the JNI functions. All of this must be done off the main UI thread.

// src/main/java/com/example/mobilecv/InferenceJni.kt
package com.example.mobilecv

import java.io.IOException

// This object provides the Kotlin interface to our JNI functions.
object InferenceJni {
    init {
        // This must match the library name in CMakeLists.txt
        System.loadLibrary("mobile_cv_inference")
    }

    // Loads the model and returns a handle (pointer as a Long).
    @Throws(IOException::class)
    external fun loadModel(modelPath: String): Long
    
    // Runs inference using the model handle.
    @Throws(RuntimeException::class)
    external fun predict(handle: Long, inputData: FloatArray, shape: LongArray): FloatArray?
    
    // Releases the native resources.
    external fun destroyModel(handle: Long)
}


// src/main/java/com/example/mobilecv/ModelExecutor.kt
package com.example.mobilecv

import android.content.Context
import android.graphics.Bitmap
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import org.pytorch.torchvision.TensorImageUtils
import java.io.File
import java.io.FileOutputStream
import java.io.IOException

// A robust class to manage the model lifecycle and execution.
class ModelExecutor(
    private val context: Context,
    private val dispatcher: CoroutineDispatcher = Dispatchers.IO
) {
    private var modelHandle: Long = 0
    private var isInitialized = false

    // A common mistake is to do file operations on the main thread.
    // Using withContext ensures this runs on a background thread.
    suspend fun initialize(modelAssetName: String) = withContext(dispatcher) {
        if (isInitialized) {
            // Already initialized, do nothing.
            return@withContext
        }
        try {
            val modelPath = getAssetPath(modelAssetName)
            modelHandle = InferenceJni.loadModel(modelPath)
            if (modelHandle == 0L) {
                throw IOException("JNI loadModel returned a null handle.")
            }
            isInitialized = true
        } catch (e: Exception) {
            // Log the error and propagate it.
            // In a real app, this would be reported to a crashlytics service.
            isInitialized = false
            throw IOException("Failed to initialize ModelExecutor: ${e.message}", e)
        }
    }

    suspend fun execute(bitmap: Bitmap): FloatArray? = withContext(dispatcher) {
        if (!isInitialized) {
            throw IllegalStateException("ModelExecutor is not initialized. Call initialize() first.")
        }

        // Preprocessing: Convert Bitmap to float array.
        // The normalization values must match what was used during training.
        val mean = floatArrayOf(0.485f, 0.456f, 0.406f)
        val std = floatArrayOf(0.229f, 0.224f, 0.225f)
        
        // This conversion can be a bottleneck. PyTorch's utility helps but it's still CPU-bound.
        val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
            bitmap, mean, std
        )
        val shape = inputTensor.shape()
        val floatArray = inputTensor.dataAsFloatArray

        // The actual JNI call.
        return@withContext InferenceJni.predict(modelHandle, floatArray, shape)
    }

    suspend fun close() = withContext(dispatcher) {
        if (isInitialized) {
            InferenceJni.destroyModel(modelHandle)
            modelHandle = 0
            isInitialized = false
        }
    }
    
    // Utility to copy the model from assets to a readable file path for the native code.
    @Throws(IOException::class)
    private fun getAssetPath(assetName: String): String {
        val file = File(context.filesDir, assetName)
        if (!file.exists()) {
            context.assets.open(assetName).use { inputStream ->
                FileOutputStream(file).use { outputStream ->
                    inputStream.copyTo(outputStream)
                }
            }
        }
        return file.absolutePath
    }
}

With this complete pipeline, the test on the same mid-range Android device yielded an average inference time of 42ms. The on-disk model size was 11.5MB, and peak runtime memory usage was reduced to under 80MB. This met all production requirements. The accuracy on our validation set dropped from 92.1% (fp32) to 91.3% (quantized), which was an entirely acceptable trade-off for the vast improvements in performance and resource consumption.

The implemented solution, while effective, is not without its limitations. Dynamic quantization is low-hanging fruit; for models where every millisecond counts, investigating static quantization is the logical next step, though it introduces the overhead of managing a calibration dataset. The current on-device model management is basic: it loads from app assets. A production system would require a robust over-the-air update mechanism to push new model versions without a full app update. This would involve version checking, secure downloading, file integrity verification (e.g., SHA-256 checksums), and an atomic file swap to prevent the model from becoming corrupted during an update. Furthermore, this implementation is CPU-only. For high-end devices, leveraging GPU or NPU hardware via delegates (such as NNAPI on Android) could provide another significant performance boost, but at the cost of increased complexity and potential device-specific compatibility issues.
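
As a concrete example of the integrity-verification piece, the export pipeline could publish a small manifest next to the .ptl artifact for the client to check before swapping models. The sketch below is an illustrative server-/CI-side helper, not part of the implemented pipeline; the file names, version string, and manifest fields are assumptions.

# publish_model_manifest.py -- illustrative sketch of a checksum manifest for OTA updates.
import hashlib
import json
import os

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    # Stream the file so large artifacts do not have to fit in memory.
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()

def write_manifest(model_path: str, version: str, manifest_path: str = 'model_manifest.json'):
    manifest = {
        'model_file': os.path.basename(model_path),
        'version': version,
        'sha256': sha256_of(model_path),
        'size_bytes': os.path.getsize(model_path),
    }
    with open(manifest_path, 'w') as f:
        json.dump(manifest, f, indent=2)
    return manifest

if __name__ == '__main__':
    # Hypothetical paths/version for illustration only.
    print(write_manifest('./models_output/resnet18_quantized.ptl', version='2024.01'))

On the device side, the client would fetch the manifest first, verify the SHA-256 of the downloaded .ptl against it, and only then rename the temporary file over the active model path so a failed download can never corrupt the model in use.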

