Building an Image Classifier – Part 3: Creating a Production-Grade FastAPI Web Service
Part 3 of our series on building and deploying a full-stack image classification system. Learn to build a robust, scalable REST API using FastAPI that serves our trained PyramidNet model in production.

Prerequisites: This tutorial is Part 3 of our image classification series. You’ll need the trained PyramidNet model from Part 2: Training a PyramidNet Model from Scratch, which itself requires the preprocessing pipeline from Part 1: Preprocessing and Preparing the CIFAR-10 Dataset. Make sure you have the best_pyramidnet_model.pth file from Part 2 before proceeding.
Welcome to Part 3 of our image classification series! We’ve come a long way: we preprocessed raw CIFAR-10 images into a clean dataset, trained a state-of-the-art PyramidNet model achieving 94%+ accuracy, and now it’s time for the exciting part—making our model accessible to the world through a production-grade web API.
In this tutorial, we’ll build a robust REST API using FastAPI that can accept image uploads, process them through our trained model, and return predictions with confidence scores. But this isn’t just any API—we’ll implement production-ready features like proper error handling, request validation, performance monitoring, comprehensive logging, and security measures that you’d expect in enterprise software.
By the end of this post, you’ll have a fully functional web service that can classify images in real-time, complete with interactive API documentation and the foundation for deploying to production environments in the upcoming parts of this series.
Understanding FastAPI: The Modern Python Web Framework
Before we start coding, let’s understand why FastAPI is the perfect choice for serving machine learning models in production.
What Makes FastAPI Special?
FastAPI is a modern, high-performance web framework for building APIs with Python. It was specifically designed to address the shortcomings of older frameworks like Flask and Django when it comes to API development and performance.
Think of FastAPI as the “best of all worlds” approach to web development:
- Performance: Built on Starlette and Pydantic, it’s one of the fastest Python frameworks available
- Type Safety: Native Python type hints provide automatic validation and documentation
- Async Support: Built-in asynchronous programming for handling concurrent requests
- Automatic Documentation: Interactive API docs generated automatically from your code
Why FastAPI for Machine Learning APIs?
When serving ML models, you face unique challenges that FastAPI handles elegantly:
Performance Requirements: ML inference can be computationally expensive. FastAPI’s asynchronous capabilities allow your API to handle multiple requests concurrently instead of blocking on each prediction.
Complex Data Validation: Image uploads, model parameters, and prediction responses have strict requirements. FastAPI’s Pydantic integration automatically validates inputs and outputs, preventing the subtle bugs that can crash production systems.
Documentation Needs: Machine learning APIs often serve multiple teams (frontend developers, data scientists, DevOps). FastAPI generates beautiful, interactive documentation that everyone can understand and test.
Scalability: As your model gains popularity, you need an API that can scale. FastAPI’s architecture supports horizontal scaling and works seamlessly with modern deployment tools like Docker and Kubernetes.
FastAPI vs. Other Frameworks for ML
Feature | FastAPI | Flask | Django REST Framework |
---|---|---|---|
Performance | Very High | Medium | Medium |
Async Support | Native | Add-on | Add-on |
Type Safety | Built-in | Manual | Manual |
Auto Documentation | Yes | No | Limited |
Learning Curve | Low | Low | High |
ML-Specific Features | Excellent | Good | Good |
For machine learning specifically, FastAPI provides features that save significant development time:
- File Upload Handling: Native support for multipart form data and file validation
- Background Tasks: Perfect for async model inference or logging
- Dependency Injection: Clean way to manage model loading and database connections
- Middleware Support: Easy to add authentication, CORS, or custom preprocessing
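To make these features concrete, here is a minimal, self-contained FastAPI app, a toy example that is not part of our classifier project, showing how type hints and async endpoints translate directly into validation and documentation:
# minimal_example.py - a toy app, not part of the classifier project
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI(title="Toy API")

class EchoRequest(BaseModel):
    message: str     # validated automatically: must be a string
    repeat: int = 1  # optional field with a default value

@app.post("/echo")
async def echo(payload: EchoRequest):
    # The type hints drive request validation and the interactive docs at /docs
    return {"echo": [payload.message] * payload.repeat}
Running uvicorn minimal_example:app --reload and opening http://127.0.0.1:8000/docs shows the interactive documentation FastAPI generates from the type hints alone.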
Now let’s put these concepts into practice by building our image classification API.
Project Structure and Environment Setup
Let’s start by organizing our project structure to support a production-grade API alongside our existing training code.
Expanding Our Project Structure
Our project will grow from a simple training script to a full web service. Here’s how we’ll organize everything:
deepthought-image-classifier/
├── api/ # Web service code
│ ├── __init__.py
│ ├── main.py # FastAPI application entry point
│ ├── models/ # Pydantic models for request/response
│ │ ├── __init__.py
│ │ └── prediction.py
│ ├── routers/ # API route handlers
│ │ ├── __init__.py
│ │ ├── health.py # Health check endpoints
│ │ └── predict.py # Prediction endpoints
│ ├── services/ # Business logic
│ │ ├── __init__.py
│ │ ├── model_service.py # Model loading and inference
│ │ └── image_service.py # Image preprocessing
│ ├── middleware/ # Custom middleware
│ │ ├── __init__.py
│ │ ├── logging.py # Request logging
│ │ └── security.py # Security headers
│ └── config.py # Configuration management
├── models/ # Trained model files
│ └── best_pyramidnet_model.pth # From Part 2
├── logs/ # Application logs
├── tests/ # API tests
│ ├── __init__.py
│ ├── test_api.py
│ └── test_model_service.py
├── requirements.txt # Dependencies
├── train_model.py # From Part 2
└── preprocessing.py # From Part 1
This structure follows the separation of concerns principle that’s crucial for maintainable production code:
- api/: Contains all web service logic, organized by responsibility
- models/: Pydantic models that define our API contracts (not ML models)
- routers/: HTTP endpoint definitions, grouped by functionality
- services/: Core business logic, independent of web framework
- middleware/: Cross-cutting concerns like logging and security
- tests/: Comprehensive test suite for reliability
Installing Dependencies
Let’s extend our requirements to include the FastAPI ecosystem:
# From your existing virtual environment
pip install fastapi uvicorn python-multipart aiofiles
pip install "python-jose[cryptography]" "passlib[bcrypt]"
pip install pytest httpx pytest-asyncio pydantic-settings
# Update requirements
pip freeze > requirements.txt
Let’s understand what each new dependency does:
- fastapi: The core web framework
- uvicorn: High-performance ASGI server for running FastAPI
- python-multipart: Handles file uploads and form data
- aiofiles: Asynchronous file operations for better performance
- python-jose: JSON Web Token handling for authentication
- pytest + httpx: Testing framework with async HTTP client support
Create the Directory Structure
Let’s create our new directories:
mkdir -p api/{models,routers,services,middleware}
mkdir -p logs tests
touch api/__init__.py api/models/__init__.py api/routers/__init__.py
touch api/services/__init__.py api/middleware/__init__.py tests/__init__.py
Now we have a professional project structure ready for production-grade development.
Configuration Management: The Foundation of Production APIs
Before writing any API code, we need a robust configuration system. Production applications need to handle different environments (development, staging, production) with different settings, and our configuration should be secure, flexible, and easy to manage.
Why Configuration Matters
In production, your API will face challenges that don’t exist in development:
- Different environments: Development uses local models, production uses optimized versions
- Security requirements: API keys, secrets, and sensitive data must be protected
- Performance tuning: Different environments need different timeout values, batch sizes, etc.
- Operational concerns: Logging levels, monitoring endpoints, and debug modes vary by environment
Let’s create a configuration system that handles all these concerns elegantly.
Creating the Configuration Module
Create api/config.py:
touch api/config.py
Then add the following code to api/config.py:
import os
from pathlib import Path
from typing import Optional
from pydantic import field_validator
from pydantic_settings import BaseSettings
import torch
class Settings(BaseSettings):
"""
Application configuration using Pydantic BaseSettings.
This approach provides several production benefits:
- Type validation ensures configuration values are correct
- Environment variable support for containerized deployments
- Default values prevent startup failures
- Automatic documentation of configuration options
"""
# Application metadata
app_name: str = "PyramidNet Image Classifier API"
app_version: str = "1.0.0"
app_description: str = "Production-grade image classification API using PyramidNet CNN"
# Environment configuration
environment: str = "development" # development, staging, production
debug: bool = True
# API configuration
host: str = "127.0.0.1"
port: int = 8000
reload: bool = True # Auto-reload for development
# Security settings
cors_origins: list = ["http://localhost:3000", "http://localhost:8080"]
api_key: Optional[str] = None # Optional API key authentication
max_request_size: int = 10 * 1024 * 1024 # 10MB max file size
# Model configuration
model_path: str = "models/best_pyramidnet_model.pth"
model_device: str = "auto" # auto, cpu, cuda
model_batch_size: int = 1 # For batch inference
# Performance settings
worker_timeout: int = 300 # 5 minutes for model loading
prediction_timeout: int = 30 # 30 seconds for single prediction
max_concurrent_requests: int = 10
# Logging configuration
log_level: str = "INFO"
log_file: str = "logs/api.log"
access_log: bool = True
# Monitoring and health checks
health_check_timeout: int = 5
metrics_enabled: bool = True
@field_validator('environment')
@classmethod
def validate_environment(cls, v):
"""Ensure environment is one of the allowed values"""
allowed = ['development', 'staging', 'production']
if v not in allowed:
raise ValueError(f'Environment must be one of {allowed}')
return v
@field_validator('model_device')
@classmethod
def validate_device(cls, v):
"""Automatically detect device or validate manual setting"""
if v == "auto":
return "cuda" if torch.cuda.is_available() else "cpu"
elif v in ["cpu", "cuda"]:
if v == "cuda" and not torch.cuda.is_available():
raise ValueError("CUDA requested but not available")
return v
else:
raise ValueError("Device must be 'auto', 'cpu', or 'cuda'")
@field_validator('model_path')
@classmethod
def validate_model_path(cls, v):
"""Ensure model file exists (skip validation in development if not present)"""
model_path = Path(v)
if not model_path.exists():
# In development, we might not have the model yet
            # model_config is a plain dict here, so read the env prefix with .get()
            env_prefix = cls.model_config.get("env_prefix", "")
            if os.getenv(f"{env_prefix}ENVIRONMENT", "development") != "production":
                print(f"Warning: Model file not found: {v} (OK in development)")
                return v
            else:
                # In production, require the model file
raise ValueError(f"Model file not found: {v}")
return v
def is_production(self) -> bool:
"""Helper to check if running in production"""
return self.environment == "production"
def is_development(self) -> bool:
"""Helper to check if running in development"""
return self.environment == "development"
model_config = {
"env_file": ".env", # Load from .env file if present
"env_prefix": "CLASSIFIER_", # Environment variables with this prefix override settings
"case_sensitive": False
}
# Create global settings instance
settings = Settings()
# CIFAR-10 class names - these should match your training data
CIFAR10_CLASSES = [
"airplane", "automobile", "bird", "cat", "deer",
"dog", "frog", "horse", "ship", "truck"
]
# API response messages
API_MESSAGES = {
"model_loaded": "PyramidNet model loaded successfully",
"prediction_success": "Image classified successfully",
"invalid_image": "Invalid or corrupted image file",
"file_too_large": f"File size exceeds maximum allowed ({settings.max_request_size // (1024*1024)}MB)",
"prediction_error": "Error occurred during prediction",
"health_check_passed": "API is healthy and ready to serve requests"
}
Environment-Specific Configuration
Create a .env file for local development (never commit this to version control). If you are working in a git repository, add it to .gitignore to be safe.
From the project root:
touch .env
echo ".env" >> .gitignore
Now add your environment variables to the file.
# .env file for local development
CLASSIFIER_ENVIRONMENT=development
CLASSIFIER_DEBUG=true
CLASSIFIER_LOG_LEVEL=DEBUG
CLASSIFIER_MODEL_DEVICE=auto
CLASSIFIER_CORS_ORIGINS=["http://localhost:3000","http://localhost:8080","http://127.0.0.1:3000"]
For production, you would set these as actual environment variables:
# Production environment variables
export CLASSIFIER_ENVIRONMENT=production
export CLASSIFIER_DEBUG=false
export CLASSIFIER_LOG_LEVEL=INFO
export CLASSIFIER_HOST=0.0.0.0
export CLASSIFIER_PORT=8000
export CLASSIFIER_MODEL_DEVICE=cuda
export CLASSIFIER_MAX_CONCURRENT_REQUESTS=50
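To confirm the configuration loads and that the CLASSIFIER_ prefix actually overrides the defaults, a quick throwaway script (not part of the codebase) run from the project root is enough:
# check_config.py - throwaway script to inspect the resolved settings
from api.config import settings

print(f"Environment: {settings.environment}")
print(f"Debug mode: {settings.debug}")
print(f"Device: {settings.model_device}")  # "auto" resolves to cpu or cuda here
print(f"Model path: {settings.model_path}")
Exporting CLASSIFIER_ENVIRONMENT=production before running it should change the first line, confirming the environment variable prefix is wired up correctly.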
This configuration system provides several production benefits:
- Type Safety: Pydantic validates all configuration values at startup, preventing runtime errors from typos or invalid settings.
- Environment Flexibility: The same codebase works in development, staging, and production by simply changing environment variables.
- Security: Sensitive values like API keys are loaded from environment variables, not hardcoded in source code.
- Documentation: Each setting has a clear type and description, making it easy for new team members to understand the configuration options.
- Fail-Fast: If required configuration is missing or invalid, the application won’t start, preventing deployment of broken services.
Now let’s use this configuration system to build our model service layer.
Handling Development Environment Setup
During development, you might not have your trained model file yet. To handle this gracefully, we can modify our model service to work without a model during development.
The key is to update our model path validator to be more lenient in development mode, and ensure our model service can start without immediately requiring the model file.
Model Service: Loading and Managing the Trained Model
The model service is the heart of our API: it’s responsible for loading our trained PyramidNet model, managing its lifecycle, and providing a clean interface for making predictions. We need to design this service to be efficient, reliable, and easy to test.
Understanding Model Service Responsibilities
A production model service handles several critical concerns:
- Model Lifecycle Management: Loading the model once at startup (not on every request), handling loading errors gracefully, and providing health checks to ensure the model is ready.
- Thread Safety: Multiple requests might hit our API simultaneously. Our model service must handle concurrent predictions safely.
- Error Handling: Models can fail in unexpected ways (out of memory, corrupted inputs, etc.). The service must catch these errors and respond appropriately.
- Performance Monitoring: We need to track prediction latency, memory usage, and error rates for operational monitoring.
Let’s implement a robust model service that handles all these concerns.
Creating the Model Service
Create api/services/model_service.py:
import asyncio
import logging
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
import numpy as np
from contextlib import asynccontextmanager
from ..config import settings, CIFAR10_CLASSES
# Import our PyramidNet architecture from Part 2
# We'll need to copy the model classes here or import them
class IdentityPadding(nn.Module):
"""Identity padding for PyramidNet residual connections"""
def __init__(self, in_channels, out_channels, stride=1):
super(IdentityPadding, self).__init__()
if stride == 2:
self.pooling = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)
else:
self.pooling = None
self.add_channels = out_channels - in_channels
def forward(self, x):
out = F.pad(x, (0, 0, 0, 0, 0, self.add_channels))
if self.pooling is not None:
out = self.pooling(out)
return out
class ResidualBlock(nn.Module):
"""Residual block for PyramidNet"""
def __init__(self, in_channels, out_channels, stride=1):
super(ResidualBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(in_channels)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
stride=1, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.down_sample = IdentityPadding(in_channels, out_channels, stride)
self.stride = stride
def forward(self, x):
shortcut = self.down_sample(x)
out = self.bn1(x)
out = self.conv1(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn3(out)
out += shortcut
return out
class PyramidNet(nn.Module):
"""PyramidNet architecture - copied from Part 2"""
def __init__(self, num_layers, alpha, block, num_classes=10):
super(PyramidNet, self).__init__()
self.in_channels = 16
self.num_layers = num_layers
self.addrate = alpha / (3*self.num_layers*1.0)
self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3,
stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(16)
self.layer1 = self.get_layers(block, stride=1)
self.layer2 = self.get_layers(block, stride=2)
self.layer3 = self.get_layers(block, stride=2)
self.out_channels = int(round(self.out_channels))
self.bn_out= nn.BatchNorm2d(self.out_channels)
self.relu_out = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(8, stride=1)
self.fc_out = nn.Linear(self.out_channels, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out',
nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def get_layers(self, block, stride):
layers_list = []
for _ in range(self.num_layers):
self.out_channels = self.in_channels + self.addrate
layers_list.append(block(int(round(self.in_channels)),
int(round(self.out_channels)),
stride))
self.in_channels = self.out_channels
stride=1
return nn.Sequential(*layers_list)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.bn_out(x)
x = self.relu_out(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc_out(x)
return x
def pyramidnet():
"""Factory function to create PyramidNet model"""
block = ResidualBlock
model = PyramidNet(num_layers=18, alpha=48, block=block)
return model
class ModelService:
"""
Production-grade model service for PyramidNet inference.
Key design principles:
- Singleton pattern: One model instance per application
- Thread-safe: Handles concurrent requests safely
- Async-friendly: Non-blocking operations where possible
- Comprehensive error handling and logging
- Performance monitoring built-in
"""
def __init__(self):
self.model: Optional[PyramidNet] = None
self.device: torch.device = None
self.is_loaded: bool = False
self.load_time: Optional[float] = None
self.prediction_count: int = 0
self.total_prediction_time: float = 0.0
self.logger = logging.getLogger(__name__)
self._lock = asyncio.Lock() # For thread-safe model loading
async def load_model(self) -> None:
"""
Load the trained PyramidNet model with comprehensive error handling.
This method is designed to be called once at application startup.
It includes validation, performance monitoring, and detailed logging.
"""
async with self._lock: # Prevent multiple simultaneous loads
if self.is_loaded:
self.logger.info("Model already loaded, skipping...")
return
start_time = time.time()
self.logger.info(f"Loading PyramidNet model from {settings.model_path}")
try:
# Validate model file exists
model_path = Path(settings.model_path)
if not model_path.exists():
raise FileNotFoundError(f"Model file not found: {model_path}")
# Configure device
self.device = torch.device(settings.model_device)
self.logger.info(f"Using device: {self.device}")
# Create model architecture
self.model = pyramidnet()
# Load checkpoint with proper PyTorch 2.6+ compatibility
try:
# First try with weights_only=True (secure mode)
checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
except Exception as e:
self.logger.warning(f"Secure loading failed: {e}")
self.logger.info("Falling back to weights_only=False for older checkpoint format")
# Fall back to weights_only=False for older checkpoints
# This is safe if you trust the source of your model file
checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
# Handle different checkpoint formats
if 'model_state_dict' in checkpoint:
self.model.load_state_dict(checkpoint['model_state_dict'])
self.logger.info(f"Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")
self.logger.info(f"Best validation accuracy: {checkpoint.get('best_val_acc', 'unknown'):.2f}%")
else:
self.model.load_state_dict(checkpoint)
# Move to device and set to evaluation mode
self.model.to(self.device)
self.model.eval()
# Warm up the model with a dummy prediction
await self._warmup_model()
self.load_time = time.time() - start_time
self.is_loaded = True
self.logger.info(f"Model loaded successfully in {self.load_time:.2f} seconds")
self.logger.info(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
except Exception as e:
self.logger.error(f"Failed to load model: {str(e)}")
self.is_loaded = False
raise RuntimeError(f"Model loading failed: {str(e)}")
async def _warmup_model(self) -> None:
"""
Warm up the model with a dummy prediction.
This ensures that CUDA memory is allocated and the model is ready
for fast inference on the first real request.
"""
self.logger.info("Warming up model...")
try:
dummy_input = torch.randn(1, 3, 32, 32).to(self.device)
with torch.no_grad():
_ = self.model(dummy_input)
self.logger.info("Model warmup completed")
except Exception as e:
self.logger.warning(f"Model warmup failed: {str(e)}")
async def predict(self, image_tensor: torch.Tensor) -> Dict[str, Any]:
"""
Make a prediction on a single image with comprehensive error handling.
Args:
image_tensor: Preprocessed image tensor of shape (1, 3, 32, 32)
Returns:
Dictionary containing prediction results and metadata
"""
if not self.is_loaded:
raise RuntimeError("Model not loaded. Call load_model() first.")
start_time = time.time()
try:
# Validate input tensor
if image_tensor.dim() != 4 or image_tensor.shape != (1, 3, 32, 32):
raise ValueError(f"Expected tensor shape (1, 3, 32, 32), got {image_tensor.shape}")
# Move to device if needed
if image_tensor.device != self.device:
image_tensor = image_tensor.to(self.device)
# Make prediction
with torch.no_grad():
logits = self.model(image_tensor)
probabilities = torch.nn.functional.softmax(logits, dim=1)
confidence_scores = probabilities.cpu().numpy()[0]
predicted_class_idx = np.argmax(confidence_scores)
# Prepare response
prediction_time = time.time() - start_time
response = {
"predicted_class": CIFAR10_CLASSES[predicted_class_idx],
"predicted_class_id": int(predicted_class_idx),
"confidence": float(confidence_scores[predicted_class_idx]),
"all_confidences": {
CIFAR10_CLASSES[i]: float(confidence_scores[i])
for i in range(len(CIFAR10_CLASSES))
},
"prediction_time_ms": round(prediction_time * 1000, 2),
"model_device": str(self.device)
}
# Update statistics
self.prediction_count += 1
self.total_prediction_time += prediction_time
self.logger.debug(f"Prediction completed in {prediction_time*1000:.2f}ms: {response['predicted_class']} ({response['confidence']:.3f})")
return response
except Exception as e:
error_msg = f"Prediction failed: {str(e)}"
self.logger.error(error_msg)
raise RuntimeError(error_msg)
async def batch_predict(self, image_tensors: List[torch.Tensor]) -> List[Dict[str, Any]]:
"""
Make predictions on multiple images efficiently.
This method batches multiple images together for more efficient GPU utilization.
"""
if not self.is_loaded:
raise RuntimeError("Model not loaded. Call load_model() first.")
if not image_tensors:
return []
start_time = time.time()
try:
            # Concatenate the (1, 3, 32, 32) tensors along the batch dimension -> (N, 3, 32, 32)
            batch_tensor = torch.cat(image_tensors, dim=0).to(self.device)
# Make batch prediction
with torch.no_grad():
logits = self.model(batch_tensor)
probabilities = torch.nn.functional.softmax(logits, dim=1)
confidence_scores = probabilities.cpu().numpy()
            prediction_time = time.time() - start_time

            # Prepare responses (include timing and device so each entry matches PredictionResult)
            responses = []
            for i, scores in enumerate(confidence_scores):
                predicted_class_idx = np.argmax(scores)
                responses.append({
                    "predicted_class": CIFAR10_CLASSES[predicted_class_idx],
                    "predicted_class_id": int(predicted_class_idx),
                    "confidence": float(scores[predicted_class_idx]),
                    "all_confidences": {
                        CIFAR10_CLASSES[j]: float(scores[j])
                        for j in range(len(CIFAR10_CLASSES))
                    },
                    "prediction_time_ms": round(prediction_time * 1000 / len(image_tensors), 2),
                    "model_device": str(self.device)
                })
self.logger.info(f"Batch prediction of {len(image_tensors)} images completed in {prediction_time*1000:.2f}ms")
return responses
except Exception as e:
error_msg = f"Batch prediction failed: {str(e)}"
self.logger.error(error_msg)
raise RuntimeError(error_msg)
def get_stats(self) -> Dict[str, Any]:
"""Get model service statistics for monitoring"""
avg_prediction_time = (
self.total_prediction_time / self.prediction_count
if self.prediction_count > 0 else 0
)
return {
"is_loaded": self.is_loaded,
"device": str(self.device) if self.device else None,
"load_time_seconds": self.load_time,
"prediction_count": self.prediction_count,
"average_prediction_time_ms": round(avg_prediction_time * 1000, 2),
"model_parameters": sum(p.numel() for p in self.model.parameters()) if self.model else 0
}
def is_healthy(self) -> bool:
"""Health check for monitoring systems"""
return self.is_loaded and self.model is not None
# Global model service instance
model_service = ModelService()
Understanding the Model Service Design
This model service implements several production patterns:
- Singleton Pattern: We create one global model_service instance that’s shared across the entire application. This prevents loading multiple copies of the model into memory.
- Async/Await Support: The service uses async/await for non-blocking operations. This is crucial for web APIs that need to handle multiple concurrent requests.
- Thread Safety: The asyncio.Lock() ensures that model loading is thread-safe, preventing race conditions if multiple requests try to load the model simultaneously.
- Comprehensive Error Handling: Every method includes try/except blocks with specific error messages. This prevents cryptic errors from reaching users.
- Performance Monitoring: The service tracks prediction counts, timing, and other metrics. This data is essential for monitoring production performance.
- Health Checks: The is_healthy() method provides a simple way for load balancers and monitoring systems to check if the service is working.
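To see how this singleton is meant to be used, here is a sketch of wiring the service into FastAPI’s startup with a lifespan handler. We build the real api/main.py later in this series, so treat this as an illustration rather than the final file:
# Sketch only - the full api/main.py is built later in this series
from contextlib import asynccontextmanager
from fastapi import FastAPI
from api.services.model_service import model_service

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the model once, before the API starts accepting traffic
    await model_service.load_model()
    yield
    # Shutdown cleanup (e.g., releasing GPU memory) would go here

app = FastAPI(lifespan=lifespan)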
Next, let’s create an image service that handles the preprocessing pipeline from Part 1, adapted for real-time API use.
Image Processing Service: From Upload to Model Input
Our API needs to transform uploaded images into the exact format our PyramidNet model expects. This involves recreating the preprocessing pipeline from Part 1, but adapted for single images received through HTTP requests rather than batch processing.
Understanding Image Service Requirements
When users upload images to our API, we face several challenges:
- Format Variety: Users might upload JPEG, PNG, GIF, or other formats. We need to handle all common image types gracefully.
- Size Variations: Real-world images come in all sizes—from tiny thumbnails to massive high-resolution photos. Our model expects exactly 32×32 pixels.
- Quality Issues: User uploads might be corrupted, extremely blurry, or otherwise problematic. We need to detect and handle these issues.
- Performance: Unlike batch preprocessing, we need to process images quickly (under 100ms) to provide a responsive API experience.
- Memory Management: We can’t let image processing consume unlimited memory, especially with concurrent requests.
Let’s create an image service that handles these challenges robustly.
Creating the Image Service
Create api/services/image_service.py:
import asyncio
import io
import logging
import time
from typing import Tuple, Optional, Dict, List, Any
import torch
import torchvision.transforms as transforms
import numpy as np
from PIL import Image, UnidentifiedImageError
import cv2
from ..config import settings
class ImageService:
"""
Production-grade image preprocessing service.
Handles image validation, preprocessing, and conversion to model-ready tensors.
Designed for single-image processing in a web API context.
"""
def __init__(self):
self.logger = logging.getLogger(__name__)
# Create the transformation pipeline
# This matches the preprocessing from Part 1, adapted for single images
self.transform = transforms.Compose([
transforms.Resize((32, 32)), # Direct resize - simpler for API use
transforms.ToTensor(),
transforms.Normalize(
mean=[0.4914, 0.4822, 0.4465], # CIFAR-10 dataset statistics
std=[0.2023, 0.1994, 0.2010]
),
])
# Supported image formats
self.supported_formats = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'}
# Quality thresholds
self.min_size = 8 # Minimum dimension in pixels
self.max_size = 4096 # Maximum dimension to prevent memory issues
self.blur_threshold = 100.0 # Laplacian variance threshold
async def validate_image_file(self, file_content: bytes, filename: str) -> None:
"""
Validate uploaded image file before processing.
Args:
file_content: Raw bytes of the uploaded file
filename: Original filename for format detection
Raises:
ValueError: If the image is invalid or unsupported
"""
# Check file size
if len(file_content) > settings.max_request_size:
raise ValueError(f"File size exceeds maximum allowed ({settings.max_request_size // (1024*1024)}MB)")
if len(file_content) == 0:
raise ValueError("Empty file uploaded")
# Check file extension
file_extension = filename.lower().split('.')[-1] if '.' in filename else ''
if f'.{file_extension}' not in self.supported_formats:
raise ValueError(f"Unsupported image format: {file_extension}. Supported formats: {', '.join(self.supported_formats)}")
# Try to open the image to validate it's not corrupted
try:
image = Image.open(io.BytesIO(file_content))
image.verify() # This will raise an exception if the image is corrupted
except UnidentifiedImageError:
raise ValueError("Invalid or corrupted image file")
except Exception as e:
raise ValueError(f"Error reading image: {str(e)}")
async def load_and_validate_image(self, file_content: bytes) -> Image.Image:
"""
Load image from bytes and perform quality validation.
Args:
file_content: Raw bytes of the image file
Returns:
PIL Image object ready for preprocessing
Raises:
ValueError: If the image fails quality checks
"""
try:
# Load image
image = Image.open(io.BytesIO(file_content))
# Convert to RGB if necessary (handles RGBA, grayscale, etc.)
if image.mode != 'RGB':
image = image.convert('RGB')
# Check dimensions
width, height = image.size
if width < self.min_size or height < self.min_size:
raise ValueError(f"Image too small: {width}x{height}. Minimum size: {self.min_size}x{self.min_size}")
if width > self.max_size or height > self.max_size:
raise ValueError(f"Image too large: {width}x{height}. Maximum size: {self.max_size}x{self.max_size}")
# Check if image is too blurry (optional quality check)
if await self._is_image_too_blurry(image):
self.logger.warning("Uploaded image appears to be very blurry")
# Note: We warn but don't reject - users might want to classify blurry images
return image
except Exception as e:
if isinstance(e, ValueError):
raise # Re-raise validation errors
raise ValueError(f"Error processing image: {str(e)}")
async def _is_image_too_blurry(self, image: Image.Image) -> bool:
"""
Check if image is too blurry using Laplacian variance.
This is the same technique we used in Part 1 for filtering blurry images.
"""
try:
# Convert PIL image to numpy array for OpenCV
img_array = np.array(image)
# Convert to grayscale
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
# Calculate Laplacian variance
laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()
return laplacian_var < self.blur_threshold
except Exception as e:
self.logger.warning(f"Error checking image blur: {str(e)}")
return False # If we can't check, assume it's okay
async def preprocess_image(self, image: Image.Image) -> torch.Tensor:
"""
Preprocess PIL image into model-ready tensor.
This applies the same preprocessing pipeline we used in training:
1. Resize to 32x32 pixels
2. Convert to tensor (0-1 range)
3. Normalize with CIFAR-10 statistics
4. Add batch dimension
Args:
image: PIL Image object
Returns:
Preprocessed tensor of shape (1, 3, 32, 32)
"""
start_time = time.time()
try:
# Apply the transformation pipeline
tensor = self.transform(image)
# Add batch dimension (model expects batched input)
tensor = tensor.unsqueeze(0) # Shape: (1, 3, 32, 32)
processing_time = time.time() - start_time
self.logger.debug(f"Image preprocessing completed in {processing_time*1000:.2f}ms")
return tensor
except Exception as e:
raise ValueError(f"Error preprocessing image: {str(e)}")
async def process_uploaded_file(self, file_content: bytes, filename: str) -> torch.Tensor:
"""
Complete pipeline: validate, load, and preprocess an uploaded image file.
This is the main entry point for the API - it takes raw uploaded bytes
and returns a tensor ready for model inference.
Args:
file_content: Raw bytes of uploaded file
filename: Original filename
Returns:
Preprocessed tensor ready for model inference
"""
start_time = time.time()
try:
# Step 1: Validate the file
await self.validate_image_file(file_content, filename)
# Step 2: Load and validate the image
image = await self.load_and_validate_image(file_content)
# Step 3: Preprocess for model
tensor = await self.preprocess_image(image)
total_time = time.time() - start_time
self.logger.info(f"Complete image processing pipeline completed in {total_time*1000:.2f}ms")
return tensor
except Exception as e:
self.logger.error(f"Image processing failed: {str(e)}")
raise
    def get_preprocessing_info(self) -> Dict[str, Any]:
"""
Get information about the preprocessing pipeline for API documentation.
"""
return {
"supported_formats": list(self.supported_formats),
"target_size": "32x32 pixels",
"normalization": {
"mean": [0.4914, 0.4822, 0.4465],
"std": [0.2023, 0.1994, 0.2010]
},
"max_file_size_mb": settings.max_request_size // (1024 * 1024),
"min_dimension": self.min_size,
"max_dimension": self.max_size
}
# Global image service instance
image_service = ImageService()
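Before wiring the image service into an endpoint, you can exercise the pipeline directly from a small script. The image path below is a placeholder, so substitute any local test image:
# try_image_service.py - quick manual check of the preprocessing pipeline
import asyncio
from api.services.image_service import image_service

async def main():
    # "test_images/cat.jpg" is a placeholder path - use any local image
    with open("test_images/cat.jpg", "rb") as f:
        content = f.read()
    tensor = await image_service.process_uploaded_file(content, "cat.jpg")
    print(tensor.shape)  # expected: torch.Size([1, 3, 32, 32])

asyncio.run(main())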
Key Design Decisions Explained
- Simplified Preprocessing: Unlike Part 1 where we used padding to preserve aspect ratios, here we use direct resizing. For an API, simplicity and speed often outweigh the small accuracy gain from padding.
- Async Processing: Even though image processing is CPU-bound, we use async methods to maintain consistency with the rest of our API and allow for potential future optimizations.
- Comprehensive Validation: We validate file size, format, dimensions, and image integrity. This prevents many common issues that could crash the API.
- Error Handling: Each step has specific error messages that help users understand what went wrong with their upload.
- Performance Monitoring: We track processing times to help identify performance bottlenecks.
- Quality Checks: We include blur detection from Part 1, but as a warning rather than a rejection. Users might legitimately want to classify blurry images.
Understanding the Preprocessing Pipeline
The transformation pipeline deserves special attention:
self.transform = transforms.Compose([
transforms.Resize((32, 32)), # Direct resize to model input size
transforms.ToTensor(), # Convert PIL Image to tensor, scale to [0,1]
transforms.Normalize( # Apply CIFAR-10 dataset normalization
mean=[0.4914, 0.4822, 0.4465], # Per-channel means
std=[0.2023, 0.1994, 0.2010] # Per-channel standard deviations
),
])
- Why These Specific Values?: The normalization values are the mean and standard deviation calculated from the entire CIFAR-10 training dataset. Using the same normalization ensures our API inputs match what the model saw during training.
- Why Resize Instead of Pad?: In Part 1, we padded images to preserve aspect ratios. For API use, direct resizing is simpler and faster, with minimal impact on accuracy for most real-world images.
- Batch Dimension: The unsqueeze(0) call adds a batch dimension, converting the tensor from shape (3, 32, 32) to (1, 3, 32, 32). PyTorch models expect batched input even for single images.
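A quick interactive check makes the shape change easy to see (purely illustrative):
import torch

single = torch.randn(3, 32, 32)     # one preprocessed image, no batch dimension
batched = single.unsqueeze(0)       # insert a batch dimension at position 0
print(single.shape, batched.shape)  # torch.Size([3, 32, 32]) torch.Size([1, 3, 32, 32])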
Now let’s create the Pydantic models that define our API contracts.
Defining API Models with Pydantic
Before we build the actual API endpoints, we need to define the data structures our API will use for requests and responses. Pydantic models serve as the “contract” between our API and its users: they define exactly what data can be sent and received, with automatic validation and documentation.
Understanding Pydantic’s Role in Production APIs
Pydantic models provide several critical benefits for production APIs:
- Automatic Validation: If a client sends invalid data, Pydantic automatically rejects it with clear error messages before it reaches your business logic.
- Type Safety: Your code can rely on the fact that incoming data matches the expected types and structure.
- API Documentation: FastAPI automatically generates OpenAPI (Swagger) documentation from your Pydantic models.
- Serialization: Pydantic handles converting Python objects to JSON and vice versa, including proper handling of dates, decimals, and other complex types.
- Versioning Support: As your API evolves, Pydantic models help you manage backward compatibility and versioning.
Creating the API Models
Create api/models/prediction.py:
from datetime import datetime, timezone
from typing import Dict, List, Optional, Any
from pydantic import BaseModel, Field, field_validator, model_validator
from enum import Enum
class ImageFormat(str, Enum):
"""Supported image formats"""
JPEG = "jpeg"
JPG = "jpg"
PNG = "png"
BMP = "bmp"
TIFF = "tiff"
TIF = "tif"
class CIFAR10Class(str, Enum):
"""CIFAR-10 classification classes"""
AIRPLANE = "airplane"
AUTOMOBILE = "automobile"
BIRD = "bird"
CAT = "cat"
DEER = "deer"
DOG = "dog"
FROG = "frog"
HORSE = "horse"
SHIP = "ship"
TRUCK = "truck"
class PredictionConfidence(BaseModel):
"""Individual class confidence score"""
class_name: CIFAR10Class = Field(
...,
description="The name of the classification class"
)
confidence: float = Field(
...,
ge=0.0,
le=1.0,
description="Confidence score between 0.0 and 1.0"
)
class PredictionResult(BaseModel):
"""
Main prediction result model.
This model defines the structure of successful prediction responses.
It includes the primary prediction, confidence scores for all classes,
and metadata about the prediction process.
"""
predicted_class: CIFAR10Class = Field(
...,
description="The most likely class for the uploaded image"
)
predicted_class_id: int = Field(
...,
ge=0,
le=9,
description="Numeric ID of the predicted class (0-9)"
)
confidence: float = Field(
...,
ge=0.0,
le=1.0,
description="Confidence score for the predicted class"
)
all_confidences: Dict[str, float] = Field(
...,
description="Confidence scores for all CIFAR-10 classes"
)
prediction_time_ms: float = Field(
...,
ge=0.0,
description="Time taken for prediction in milliseconds"
)
model_device: str = Field(
...,
description="Device used for prediction (cpu/cuda)"
)
timestamp: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
description="UTC timestamp when prediction was made"
)
@field_validator('all_confidences')
@classmethod
def validate_all_confidences(cls, v):
"""Ensure all_confidences contains all CIFAR-10 classes"""
expected_classes = [class_item.value for class_item in CIFAR10Class]
if set(v.keys()) != set(expected_classes):
raise ValueError("all_confidences must contain all CIFAR-10 classes")
# Validate that all confidence scores are between 0 and 1
for class_name, confidence in v.items():
if not 0.0 <= confidence <= 1.0:
raise ValueError(f"Confidence for {class_name} must be between 0.0 and 1.0")
return v
@model_validator(mode='after')
def validate_confidence_consistency(self):
"""Ensure the main confidence matches the predicted class confidence"""
predicted_class = self.predicted_class
all_confidences = self.all_confidences
confidence = self.confidence
        # Compare using the enum's string value so the dict lookup is unambiguous
        if predicted_class.value in all_confidences:
            expected_confidence = all_confidences[predicted_class.value]
if abs(confidence - expected_confidence) > 1e-6:
raise ValueError("Main confidence score must match predicted class confidence")
return self
class BatchPredictionRequest(BaseModel):
"""Request model for batch predictions (future enhancement)"""
images: List[str] = Field(
...,
min_length=1,
max_length=10,
description="List of base64-encoded images (max 10)"
)
include_all_confidences: bool = Field(
default=True,
description="Whether to include confidence scores for all classes"
)
class BatchPredictionResult(BaseModel):
"""Response model for batch predictions"""
predictions: List[PredictionResult] = Field(
...,
description="List of prediction results"
)
total_processing_time_ms: float = Field(
...,
ge=0.0,
description="Total time for all predictions in milliseconds"
)
batch_size: int = Field(
...,
ge=1,
description="Number of images processed"
)
class ErrorResponse(BaseModel):
"""
Standardized error response model.
This ensures all API errors have a consistent structure,
making it easier for clients to handle errors programmatically.
"""
error: str = Field(
...,
description="Error type or category"
)
message: str = Field(
...,
description="Human-readable error message"
)
detail: Optional[str] = Field(
None,
description="Additional error details for debugging"
)
timestamp: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
description="UTC timestamp when error occurred"
)
request_id: Optional[str] = Field(
None,
description="Unique identifier for the request (for tracking)"
)
class HealthCheckResponse(BaseModel):
"""Health check response model"""
status: str = Field(
...,
description="Health status (healthy/unhealthy)"
)
timestamp: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
description="UTC timestamp of health check"
)
version: str = Field(
...,
description="API version"
)
model_loaded: bool = Field(
...,
description="Whether the ML model is loaded and ready"
)
uptime_seconds: Optional[float] = Field(
None,
ge=0.0,
description="API uptime in seconds"
)
class ModelStatsResponse(BaseModel):
"""Model statistics response model"""
is_loaded: bool = Field(
...,
description="Whether the model is loaded"
)
device: Optional[str] = Field(
None,
description="Device the model is running on"
)
load_time_seconds: Optional[float] = Field(
None,
ge=0.0,
description="Time taken to load the model"
)
prediction_count: int = Field(
...,
ge=0,
description="Total number of predictions made"
)
average_prediction_time_ms: float = Field(
...,
ge=0.0,
description="Average prediction time in milliseconds"
)
model_parameters: int = Field(
...,
ge=0,
description="Number of trainable parameters in the model"
)
class ImageUploadResponse(BaseModel):
"""Response for successful image upload and preprocessing"""
message: str = Field(
...,
description="Success message"
)
filename: str = Field(
...,
description="Original filename"
)
file_size_bytes: int = Field(
...,
ge=0,
description="Size of uploaded file in bytes"
)
preprocessing_time_ms: float = Field(
...,
ge=0.0,
description="Time taken for image preprocessing"
)
image_dimensions: Dict[str, int] = Field(
...,
description="Original image dimensions"
)
# Example models for API documentation
class PredictionExample:
"""Example data for API documentation"""
successful_prediction = {
"predicted_class": "cat",
"predicted_class_id": 3,
"confidence": 0.891,
"all_confidences": {
"airplane": 0.012,
"automobile": 0.008,
"bird": 0.045,
"cat": 0.891,
"deer": 0.023,
"dog": 0.015,
"frog": 0.003,
"horse": 0.002,
"ship": 0.001,
"truck": 0.000
},
"prediction_time_ms": 45.2,
"model_device": "cuda",
"timestamp": "2025-06-21T10:30:45.123456Z"
}
error_response = {
"error": "ValidationError",
"message": "Invalid image format",
"detail": "Supported formats: jpg, jpeg, png, bmp, tiff",
"timestamp": "2025-06-21T10:30:45.123456Z",
"request_id": "req_123456789"
}
health_check = {
"status": "healthy",
"timestamp": "2025-06-21T10:30:45.123456Z",
"version": "1.0.0",
"model_loaded": True,
"uptime_seconds": 3600.5
}
Understanding the Model Design
Let’s break down the key design decisions in our Pydantic models:
- Enums for Fixed Values: Using the CIFAR10Class and ImageFormat enums ensures that only valid values are accepted and provides clear documentation of supported options.
- Field Validation: The Field() function allows us to specify constraints (like ge=0.0, le=1.0 for confidence scores) and descriptions for automatic documentation.
- Custom Validators: The @field_validator and @model_validator decorators allow complex validation logic, like ensuring the main confidence score matches the predicted class confidence.
- Response Consistency: All responses include timestamps and relevant metadata, making debugging and monitoring easier.
- Error Standardization: The ErrorResponse model ensures all API errors have the same structure, making client error handling predictable.
Why This Level of Validation Matters
In production APIs, comprehensive validation prevents numerous issues:
- Security: Invalid input can’t reach your business logic, reducing attack surfaces.
- Reliability: Type mismatches and invalid data are caught early, preventing runtime errors.
- User Experience: Clear validation error messages help users fix their requests quickly.
- Debugging: Consistent error formats make it easier to diagnose issues in production.
- Documentation: Pydantic models automatically generate accurate API documentation that stays in sync with your code.
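As a small illustration, the PredictionConfidence model defined earlier rejects an out-of-range score before it can reach any business logic:
from pydantic import ValidationError
from api.models.prediction import PredictionConfidence

try:
    # confidence is constrained to the range [0.0, 1.0], so 1.5 is rejected
    PredictionConfidence(class_name="cat", confidence=1.5)
except ValidationError as exc:
    print(exc)  # explains that the value must be less than or equal to 1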
Now let’s create the actual API endpoints that use these models.
Building the API Endpoints
Now we’ll create the actual HTTP endpoints that expose our image classification functionality. We’ll organize our endpoints into logical groups using FastAPI’s router system, which helps maintain clean, modular code as our API grows.
Understanding Production API Design Principles
Before diving into the code, let’s understand what makes an API truly production-ready. Our endpoints need to handle much more than just the “happy path” where everything works perfectly.
Real-world APIs must handle:
- Multiple request patterns: Single images for real-time use, batches for efficiency
- Diverse client needs: Web applications, mobile apps, automated systems
- Failure scenarios: Invalid files, network issues, model failures
- Performance requirements: Low latency, high throughput, efficient resource usage
- Security concerns: Input validation, rate limiting, error information leakage
- Operational needs: Health checks, monitoring, debugging information
Our endpoint strategy:
- Health endpoints: Multiple levels of health checking for different operational needs
- Single prediction: Optimized for interactive, real-time classification
- Batch prediction: Efficient processing of multiple images simultaneously
- Information endpoints: API documentation and capability discovery
- Comprehensive error handling: Clear, actionable error messages
- Request tracking: Unique identifiers for debugging and monitoring
This design supports everything from a simple web form where users upload photos to sophisticated batch processing systems that classify thousands of images efficiently.
Health Check Endpoints
Let’s start with health check endpoints. These are crucial for production deployments: load balancers, monitoring systems, and orchestration platforms like Kubernetes use these endpoints to determine if your service is ready to receive traffic.
Create api/routers/health.py:
import time
from fastapi import APIRouter, HTTPException
from ..models.prediction import HealthCheckResponse, ModelStatsResponse
from ..services.model_service import model_service
from ..config import settings
router = APIRouter(prefix="/health", tags=["Health"])
# Track startup time for uptime calculation
startup_time = time.time()
@router.get(
"/",
response_model=HealthCheckResponse,
summary="Basic health check",
description="Check if the API is running and responsive"
)
async def health_check():
"""
Basic health check endpoint.
This endpoint performs a quick check to verify the API is running.
It's designed to be fast and lightweight for load balancer health checks.
"""
try:
uptime = time.time() - startup_time
return HealthCheckResponse(
status="healthy",
version=settings.app_version,
model_loaded=model_service.is_healthy(),
uptime_seconds=uptime
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Health check failed: {str(e)}"
)
@router.get(
"/ready",
response_model=HealthCheckResponse,
summary="Readiness check",
description="Check if the API is ready to serve prediction requests"
)
async def readiness_check():
"""
Readiness check endpoint.
This endpoint performs a deeper check to verify the API is ready to serve
prediction requests. It checks if the model is loaded and functional.
Kubernetes uses readiness checks to determine if a pod should receive traffic.
"""
try:
if not model_service.is_healthy():
raise HTTPException(
status_code=503,
detail="Model not loaded or unhealthy"
)
uptime = time.time() - startup_time
return HealthCheckResponse(
status="ready",
version=settings.app_version,
model_loaded=True,
uptime_seconds=uptime
)
except HTTPException:
raise # Re-raise HTTP exceptions
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Readiness check failed: {str(e)}"
)
@router.get(
"/live",
response_model=HealthCheckResponse,
summary="Liveness check",
description="Check if the API process is alive"
)
async def liveness_check():
"""
Liveness check endpoint.
This endpoint performs a minimal check to verify the API process is alive.
It should only fail if the process is completely broken.
Kubernetes uses liveness checks to determine if a pod should be restarted.
"""
try:
uptime = time.time() - startup_time
return HealthCheckResponse(
status="alive",
version=settings.app_version,
model_loaded=model_service.is_healthy(),
uptime_seconds=uptime
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Liveness check failed: {str(e)}"
)
@router.get(
"/stats",
response_model=ModelStatsResponse,
summary="Model statistics",
description="Get detailed statistics about the model service"
)
async def model_stats():
"""
Model statistics endpoint.
Provides detailed information about model performance and usage.
Useful for monitoring and debugging.
"""
try:
stats = model_service.get_stats()
return ModelStatsResponse(**stats)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to get model statistics: {str(e)}"
)
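Once these routers are mounted in the application (we do that in api/main.py later) and the server is running on the default host and port, you can probe the health endpoints with httpx, which we installed earlier:
import httpx

BASE = "http://127.0.0.1:8000"  # default host and port from our settings

print(httpx.get(f"{BASE}/health/").json())       # basic health check
print(httpx.get(f"{BASE}/health/ready").json())  # returns 503 until the model is loaded
print(httpx.get(f"{BASE}/health/stats").json())  # model service statistics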
Prediction Endpoints
Now let’s create the main prediction endpoints. Create api/routers/predict.py:
import logging
import uuid
import time
from datetime import datetime, timezone
from typing import Optional, List
from fastapi import APIRouter, File, UploadFile, HTTPException, Form, Depends
from fastapi.responses import JSONResponse
from ..models.prediction import PredictionResult, BatchPredictionResult, ErrorResponse, PredictionExample
from ..services.model_service import model_service
from ..services.image_service import image_service
from ..config import settings, API_MESSAGES
router = APIRouter(prefix="/predict", tags=["Prediction"])
logger = logging.getLogger(__name__)
async def check_model_ready():
"""Dependency to ensure model is loaded before prediction requests"""
if not model_service.is_healthy():
raise HTTPException(
status_code=503,
detail="Model not loaded. Please try again in a few moments."
)
@router.post(
"/",
response_model=PredictionResult,
summary="Classify an uploaded image",
description="Upload an image file and get a CIFAR-10 classification prediction",
responses={
200: {
"description": "Successful prediction",
"content": {
"application/json": {
"example": PredictionExample.successful_prediction
}
}
},
400: {
"description": "Bad request - invalid image or format",
"model": ErrorResponse,
"content": {
"application/json": {
"example": PredictionExample.error_response
}
}
},
413: {"description": "File too large"},
503: {"description": "Model not ready"},
}
)
async def predict_image(
file: UploadFile = File(
...,
description="Image file to classify (JPEG, PNG, BMP, TIFF)",
media_type="image/*"
),
request_id: Optional[str] = Form(
None,
description="Optional request ID for tracking"
),
model_ready: None = Depends(check_model_ready)
):
"""
Classify an uploaded image using the trained PyramidNet model.
This endpoint accepts image files in common formats and returns:
- The predicted CIFAR-10 class
- Confidence score for the prediction
- Confidence scores for all classes
- Processing time and metadata
**Supported formats:** JPEG, PNG, BMP, TIFF
**Maximum file size:** 10MB
**Expected accuracy:** ~94% on CIFAR-10 test set
"""
# Generate request ID if not provided
if not request_id:
request_id = str(uuid.uuid4())
logger.info(f"Prediction request {request_id} started for file: {file.filename}")
try:
# Validate file
if not file.filename:
raise HTTPException(
status_code=400,
detail="No file provided"
)
# Read file content
file_content = await file.read()
if len(file_content) == 0:
raise HTTPException(
status_code=400,
detail="Empty file uploaded"
)
logger.debug(f"Request {request_id}: File size: {len(file_content)} bytes")
# Process image through the pipeline
try:
processed_tensor = await image_service.process_uploaded_file(
file_content,
file.filename
)
except ValueError as e:
logger.warning(f"Request {request_id}: Image validation failed: {str(e)}")
raise HTTPException(
status_code=400,
detail=str(e)
)
# Make prediction
try:
prediction_result = await model_service.predict(processed_tensor)
except Exception as e:
logger.error(f"Request {request_id}: Prediction failed: {str(e)}")
raise HTTPException(
status_code=500,
detail="Prediction failed. Please try again."
)
logger.info(
f"Request {request_id}: Successful prediction - {prediction_result['predicted_class']} "
f"({prediction_result['confidence']:.3f})"
)
# Add metadata to response
prediction_result['timestamp'] = datetime.now(timezone.utc)
return PredictionResult(**prediction_result)
except HTTPException:
raise # Re-raise HTTP exceptions
except Exception as e:
logger.error(f"Request {request_id}: Unexpected error: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred during prediction"
)
@router.post(
"/batch",
response_model=BatchPredictionResult,
summary="Classify multiple images",
description="Upload multiple images and get predictions for all of them",
responses={
400: {"description": "Invalid request or image format", "model": ErrorResponse},
413: {"description": "Request too large"},
503: {"description": "Model not ready"},
}
)
async def batch_predict(
files: List[UploadFile] = File(
...,
description="List of image files to classify (max 10 files)",
media_type="image/*"
),
request_id: Optional[str] = Form(
None,
description="Optional request ID for tracking"
),
model_ready: None = Depends(check_model_ready)
):
"""
Classify multiple uploaded images in a single request.
This endpoint is useful for batch processing scenarios where you need
to classify several images at once. It's more efficient than making
multiple single-image requests.
**Limitations:**
- Maximum 10 images per request
- Same file size limits apply to each image
- Total request size cannot exceed server limits
"""
# Generate request ID if not provided
if not request_id:
request_id = str(uuid.uuid4())
logger.info(f"Batch prediction request {request_id} started with {len(files)} files")
# Validate batch size
if len(files) > 10:
raise HTTPException(
status_code=400,
detail="Maximum 10 images allowed per batch request"
)
if len(files) == 0:
raise HTTPException(
status_code=400,
detail="At least one image file must be provided"
)
start_time = time.time()
processed_tensors = []
try:
# Process each image
for i, file in enumerate(files):
if not file.filename:
raise HTTPException(
status_code=400,
detail=f"File {i+1} has no filename"
)
file_content = await file.read()
# Process image through pipeline
try:
tensor = await image_service.process_uploaded_file(
file_content,
file.filename
)
processed_tensors.append(tensor)
except ValueError as e:
logger.warning(f"Request {request_id}: File {i+1} validation failed: {str(e)}")
raise HTTPException(
status_code=400,
detail=f"File {i+1} ({file.filename}): {str(e)}"
)
# Make batch prediction
try:
predictions = await model_service.batch_predict(processed_tensors)
except Exception as e:
logger.error(f"Request {request_id}: Batch prediction failed: {str(e)}")
raise HTTPException(
status_code=500,
detail="Batch prediction failed. Please try again."
)
# Add metadata to each prediction
prediction_results = []
for pred in predictions:
pred['timestamp'] = datetime.now(timezone.utc)
prediction_results.append(PredictionResult(**pred))
total_time = time.time() - start_time
result = BatchPredictionResult(
predictions=prediction_results,
total_processing_time_ms=round(total_time * 1000, 2),
batch_size=len(files)
)
logger.info(
f"Request {request_id}: Batch prediction completed - {len(files)} images "
f"processed in {total_time*1000:.2f}ms"
)
return result
except HTTPException:
raise # Re-raise HTTP exceptions
except Exception as e:
logger.error(f"Request {request_id}: Unexpected error in batch prediction: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred during batch prediction"
)
@router.get(
"/info",
summary="Get prediction endpoint information",
description="Get information about supported formats and preprocessing"
)
async def prediction_info():
"""
Get information about the prediction endpoints.
Returns details about supported formats, preprocessing pipeline,
and usage limits.
"""
try:
preprocessing_info = image_service.get_preprocessing_info()
return {
"api_version": settings.app_version,
"model_info": {
"architecture": "PyramidNet",
"dataset": "CIFAR-10",
"num_classes": 10,
"input_size": "32x32",
"expected_accuracy": "~94%"
},
"preprocessing": preprocessing_info,
"batch_limits": {
"max_files_per_batch": 10,
"max_concurrent_requests": settings.max_concurrent_requests
},
"performance": {
"typical_prediction_time_ms": "20-50ms",
"batch_processing_advantage": "2-5x faster than individual requests"
}
}
except Exception as e:
logger.error(f"Failed to get prediction info: {str(e)}")
raise HTTPException(
status_code=500,
detail="Failed to retrieve prediction information"
)
Understanding the Prediction Endpoint Design
The prediction endpoints we’ve built implement several sophisticated patterns that are crucial for production machine learning APIs. Let’s explore the key design decisions and their implications:
Single vs. Batch Prediction Strategies
Single Image Endpoint (/predict/):
This endpoint is optimized for real-time, interactive use cases where users upload one image at a time. The design prioritizes:
- Low Latency: Each request is processed immediately without waiting for other requests
- Simple Error Handling: If one image fails, it doesn’t affect other operations
- User Experience: Perfect for web forms or mobile apps where users expect immediate feedback
Batch Endpoint (/predict/batch):
This endpoint is designed for efficiency when processing multiple images simultaneously. Key benefits include:
- GPU Utilization: Batching multiple images together maximizes GPU throughput by utilizing parallel processing capabilities
- Reduced Overhead: Processing 10 images in one batch request eliminates the HTTP overhead of 10 separate requests
- Cost Efficiency: In cloud deployments, batch processing reduces the total compute time and associated costs
The 10-image limit strikes a balance between efficiency and memory usage—larger batches could overwhelm system memory, especially with high-resolution images.
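The efficiency gain comes from stacking the preprocessed tensors into a single batch before the forward pass. Our ModelService's batch_predict was implemented earlier in this post; the core idea looks roughly like the sketch below (the function name and shapes are illustrative, not our exact service code):
import torch
from typing import List

# Minimal sketch of batched inference: one forward pass for N images instead of N passes.
def batch_forward(model: torch.nn.Module, tensors: List[torch.Tensor]) -> torch.Tensor:
    batch = torch.cat(tensors, dim=0)        # each tensor is (1, 3, 32, 32) -> (N, 3, 32, 32)
    with torch.no_grad():                    # inference only, no gradient bookkeeping
        logits = model(batch)                # single forward pass utilizes the GPU's parallelism
    return torch.softmax(logits, dim=1)      # per-image class probabilities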
Request Validation and Error Handling Philosophy
Our endpoints implement a “fail-fast” validation strategy:
# Example validation flow:
# 1. File presence check (immediate 400 if missing)
# 2. File size validation (prevents memory exhaustion)
# 3. Format validation (catches unsupported formats early)
# 4. Image integrity check (detects corrupted files)
# 5. Only then proceed to expensive preprocessing
This approach prevents resource waste on invalid requests and provides clear feedback to users about what went wrong. Each validation step includes specific error messages that help developers integrate with your API successfully.
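To make the fail-fast flow concrete, here is a minimal sketch of how the cheap checks could be grouped into one helper. The size limit and extension set are illustrative placeholders; in our API the real values come from the settings object and the ImageService:
from fastapi import HTTPException, UploadFile

MAX_FILE_SIZE_BYTES = 5 * 1024 * 1024                 # illustrative limit
ALLOWED_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp"}  # illustrative allowlist

async def validate_upload(file: UploadFile) -> bytes:
    """Run the cheap checks first so invalid requests fail before any image decoding."""
    if not file.filename:                                         # 1. presence check
        raise HTTPException(status_code=400, detail="No file provided")
    content = await file.read()
    if len(content) == 0:                                         # 2a. empty-file check
        raise HTTPException(status_code=400, detail="Empty file uploaded")
    if len(content) > MAX_FILE_SIZE_BYTES:                        # 2b. size check
        raise HTTPException(status_code=413, detail="File too large")
    if not any(file.filename.lower().endswith(ext) for ext in ALLOWED_EXTENSIONS):
        raise HTTPException(status_code=400, detail="Unsupported file extension")  # 3. format check
    return content                                                # 4./5. decoding happens later, in the image service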
Dependency Injection for Model Readiness
The Depends(check_model_ready) pattern ensures that prediction endpoints only accept requests when the model is fully loaded and ready. This prevents the common production issue where APIs return 500 errors during startup while models are still loading.
async def check_model_ready():
if not model_service.is_healthy():
raise HTTPException(status_code=503, detail="Model not loaded...")
The 503 status code specifically indicates “Service Unavailable,” which tells load balancers and clients that they should retry the request after a short delay, rather than treating it as a permanent failure.
Request Tracking and Observability
Every request gets a unique identifier (request_id) that flows through all logs and can be returned to clients. This enables:
- Distributed Tracing: Following a request through multiple services and log files
- Customer Support: Users can provide the request ID when reporting issues
- Performance Analysis: Correlating slow requests with specific images or conditions
- A/B Testing: Tracking different model versions or preprocessing pipelines
Async Processing Patterns
Despite image processing being CPU-bound, we use async/await throughout the prediction pipeline:
# Why async for CPU-bound work?
processed_tensor = await image_service.process_uploaded_file(...)
prediction_result = await model_service.predict(processed_tensor)
This design choice provides several benefits:
- Consistency: All API operations use the same async pattern, simplifying error handling and middleware
- Future-Proofing: Easy to add I/O operations later (database logging, external API calls, etc.)
- Resource Efficiency: The async runtime can handle thousands of concurrent connections with minimal memory overhead
- Cancellation Support: Long-running requests can be cancelled cleanly if clients disconnect
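If a preprocessing or inference step ever becomes heavy enough to stall the event loop, the standard remedy is to push the blocking call onto a worker thread. This is a general pattern rather than a snippet from our services (the function names are illustrative):
import asyncio
from typing import Callable

# asyncio.to_thread (Python 3.9+) runs the blocking call on a worker thread,
# so the event loop can keep accepting other requests in the meantime.
async def run_cpu_bound(blocking_fn: Callable, *args):
    return await asyncio.to_thread(blocking_fn, *args)

# Usage inside an endpoint (names are illustrative, not our actual service API):
#     tensor = await run_cpu_bound(decode_and_normalize, file_content)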
Response Structure Design
Our response models include both the essential prediction data and operational metadata:
{
"predicted_class": "cat", # Primary result
"confidence": 0.891, # Decision confidence
"all_confidences": {...}, # Full probability distribution
"prediction_time_ms": 45.2, # Performance monitoring
"model_device": "cuda", # Infrastructure info
"timestamp": "2025-06-21T10:30:45Z" # Audit trail
}
This structure serves multiple purposes:
- Application Logic: Frontend applications can use confidence scores for UI decisions (showing warnings for low-confidence predictions)
- Analytics: Prediction times help identify performance bottlenecks
- Debugging: Timestamps and device information help diagnose issues
- Model Evaluation: All confidence scores enable offline analysis of model performance
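For reference, a stripped-down Pydantic model with this shape might look like the sketch below. Our actual PredictionResult schema was defined earlier in this post and may carry extra validation, so treat this as an illustration of the structure rather than the definitive definition:
from datetime import datetime
from typing import Dict
from pydantic import BaseModel, Field

class PredictionResultSketch(BaseModel):
    # Silence Pydantic v2's warning about the "model_" prefix (harmless under v1).
    model_config = {"protected_namespaces": ()}

    predicted_class: str = Field(..., description="Most likely CIFAR-10 class")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Probability of the predicted class")
    all_confidences: Dict[str, float] = Field(..., description="Full probability distribution")
    prediction_time_ms: float = Field(..., description="Inference latency in milliseconds")
    model_device: str = Field(..., description="Device the model ran on (cpu/cuda)")
    timestamp: datetime = Field(..., description="UTC time the prediction was made")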
Security and Rate Limiting Considerations
The prediction endpoints implement several security measures:
- File Size Limits: Prevent denial-of-service attacks through large file uploads
- Format Validation: Reduce attack surface by rejecting unexpected file types
- Request Rate Limiting: Protect against automated abuse while allowing legitimate usage
- Input Sanitization: All image processing happens in isolated services with comprehensive error handling
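To make the format-validation and sanitization points concrete, here is a small sketch of an integrity check using Pillow. It mirrors the kind of verification an image service performs before any tensor conversion; the helper name is illustrative:
from io import BytesIO
from PIL import Image, UnidentifiedImageError

def verify_image_bytes(content: bytes) -> None:
    """Reject anything Pillow cannot parse before spending time on preprocessing."""
    try:
        with Image.open(BytesIO(content)) as img:
            img.verify()  # raises if the file is truncated or not a real image
    except (UnidentifiedImageError, OSError) as exc:
        raise ValueError(f"Corrupted or unsupported image: {exc}") from exc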
Production Performance Characteristics
Based on this design, you can expect the following performance characteristics:
- Single Predictions: 20-100ms on modern hardware (CPU/GPU dependent)
- Batch Predictions: 2-5x more efficient than individual requests for the same images
- Concurrent Handling: 10-50 simultaneous requests depending on available memory and compute
- Error Rate: <1% for valid images (primarily due to network issues or resource constraints)
Integration Patterns
These endpoints support several common integration patterns:
- Synchronous Web Applications: Direct form uploads with immediate response display
- Asynchronous Processing: Background job processing with result callbacks
- Mobile Applications: Efficient batch upload when connectivity is available
- API Aggregation: Easy to integrate into larger microservice architectures
The consistent error handling and response formats make it straightforward for client applications to handle both success and failure cases predictably.
Now let’s examine how our middleware components work together to create a professional API experience.
Middleware: Adding Production-Grade Features
Middleware functions run before and after each request, allowing us to add cross-cutting concerns like logging, security headers, and error handling. Let’s implement middleware that makes our API production-ready.
Request Logging Middleware
Understanding what’s happening in your API is crucial for debugging and monitoring. Create api/middleware/logging.py:
import time
import uuid
import logging
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Callable
class RequestLoggingMiddleware(BaseHTTPMiddleware):
"""
Comprehensive request logging middleware for production monitoring.
Logs every request with timing, status codes, and unique identifiers.
Essential for debugging, performance monitoring, and security auditing.
"""
def __init__(self, app):
super().__init__(app)
self.logger = logging.getLogger("api.requests")
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Generate unique request ID for tracking
request_id = str(uuid.uuid4())[:8]
# Extract client information safely
client_ip = getattr(request.client, 'host', 'unknown') if request.client else 'unknown'
user_agent = request.headers.get("user-agent", "unknown")
# Start timing
start_time = time.time()
# Log incoming request
self.logger.info(
f"REQUEST {request_id} | {request.method} {request.url.path} | "
f"Client: {client_ip} | User-Agent: {user_agent[:50]}{'...' if len(user_agent) > 50 else ''}"
)
# Add request ID to request state for use in endpoints
request.state.request_id = request_id
try:
# Process the request
response = await call_next(request)
# Calculate processing time
process_time = time.time() - start_time
# Log response
log_level = logging.INFO if response.status_code < 400 else logging.WARNING
self.logger.log(
log_level,
f"RESPONSE {request_id} | {response.status_code} | "
f"{process_time*1000:.2f}ms | {request.method} {request.url.path}"
)
# Add custom headers for monitoring
response.headers["X-Request-ID"] = request_id
response.headers["X-Process-Time"] = f"{process_time:.3f}"
return response
except Exception as e:
# Log exceptions
process_time = time.time() - start_time
self.logger.error(
f"ERROR {request_id} | {str(e)} | "
f"{process_time*1000:.2f}ms | {request.method} {request.url.path}",
exc_info=True
)
raise
Understanding the Logging Middleware Design
The logging middleware we’ve implemented goes far beyond simple request logging—it creates a comprehensive audit trail that’s essential for production operations. Let’s explore why each component matters and how they work together:
Request Identification and Tracing
Every request gets a unique identifier (request_id) that flows through the entire request lifecycle:
request_id = str(uuid.uuid4())[:8] # Short, unique identifier
request.state.request_id = request_id # Available throughout request
response.headers["X-Request-ID"] = request_id # Returned to client
This simple pattern enables powerful debugging capabilities:
- Distributed Tracing: Follow a single request through multiple services and log files
- Customer Support: Users can provide the request ID when reporting issues
- Performance Analysis: Correlate slow requests with specific conditions or inputs
- Error Investigation: Quickly find all log entries related to a problematic request
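Any endpoint can pick up the identifier that the middleware stored on request.state, so application-level logs share the same ID as the middleware logs. The endpoint below is a hypothetical example, not part of our routers:
from fastapi import APIRouter, Request

router = APIRouter()

@router.get("/whoami")
async def whoami(request: Request):
    # Falls back to "unknown" if the middleware did not run (e.g., in isolated unit tests).
    request_id = getattr(request.state, "request_id", "unknown")
    return {"request_id": request_id}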
Client Information Extraction
The middleware safely extracts client information that’s crucial for security and debugging:
client_ip = getattr(request.client, 'host', 'unknown') if request.client else 'unknown'
user_agent = request.headers.get("user-agent", "unknown")
Why this matters in production:
- Security Monitoring: Detect unusual traffic patterns or potential attacks
- Performance Optimization: Identify client types that might need special handling
- User Experience: Track mobile vs. desktop usage patterns
- Compliance: Many regulations require logging of client access patterns
Comprehensive Request Lifecycle Tracking
The middleware tracks the complete request journey with precise timing:
start_time = time.time()
# ... process request ...
process_time = time.time() - start_time
This timing data serves multiple purposes:
- Performance Monitoring: Identify slow endpoints or degrading performance
- Capacity Planning: Understand typical request processing times for scaling decisions
- SLA Compliance: Monitor whether your API meets response time commitments
- Bottleneck Identification: Pinpoint which parts of your API need optimization
Intelligent Log Level Management
The middleware uses different log levels based on response status codes:
log_level = logging.INFO if response.status_code < 400 else logging.WARNING
This approach provides several benefits:
- Noise Reduction: Successful requests are logged at INFO level, failures at WARNING
- Alert Configuration: Monitoring systems can alert on WARNING+ logs for immediate attention
- Log Analysis: Easy filtering between normal operations and potential issues
- Storage Optimization: Production systems can archive INFO logs while keeping WARNING+ logs accessible
Production Headers for Observability
The middleware adds custom headers that enhance observability:
response.headers["X-Request-ID"] = request_id
response.headers["X-Process-Time"] = f"{process_time:.3f}"
These headers enable:
- Client-Side Monitoring: Frontend applications can track their API request performance
- Load Balancer Integration: Some load balancers use these headers for health checking
- Debugging Support: API clients can log these values for troubleshooting
- Performance Correlation: Match client-side timing with server-side processing time
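On the client side, reading these headers is straightforward. The snippet below is an illustrative client, not part of the API itself:
import requests

# Log the server-assigned request ID and processing time so client-side and
# server-side logs can be correlated for the same request.
response = requests.get("http://127.0.0.1:8000/health/")
print("Request ID:", response.headers.get("X-Request-ID"))
print("Server processing time (s):", response.headers.get("X-Process-Time"))
print("Round-trip time (s):", response.elapsed.total_seconds())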
Exception Handling and Error Correlation
The middleware’s exception handling ensures that even crashes are properly logged:
except Exception as e:
process_time = time.time() - start_time
self.logger.error(
f"ERROR {request_id} | {str(e)} | "
f"{process_time*1000:.2f}ms | {request.method} {request.url.path}",
exc_info=True
)
raise
Key benefits of this approach:
- Error Context: Every error log includes the request ID, timing, and full context
- Stack Traces: The exc_info=True parameter captures full exception details
- Performance Impact: Track how long failed requests took (important for timeout analysis)
- Error Propagation: The middleware re-raises exceptions so they’re handled appropriately
Log Format Design Philosophy
The structured log format we use follows production best practices:
REQUEST {request_id} | {method} {path} | Client: {ip} | User-Agent: {agent}
RESPONSE {request_id} | {status} | {time}ms | {method} {path}
ERROR {request_id} | {error} | {time}ms | {method} {path}
This format is designed for:
- Machine Parsing: Easy to parse with log analysis tools like ELK stack or Splunk
- Human Readability: Quick visual scanning during manual debugging
- Correlation: Request ID appears first for easy grep/search operations
- Completeness: Every log entry contains enough context to understand the operation
Production Considerations and Extensions
While our middleware provides a solid foundation, production systems often extend it with:
Sensitive Data Filtering:
# Example extension: Filter sensitive data from logs
def sanitize_path(path: str) -> str:
# Remove sensitive data from URLs (API keys, tokens, etc.)
import re
return re.sub(r'([?&])(api_key|token)=[^&]*', r'\1\2=***', path)
Request Body Logging (for debugging):
# Log request bodies for specific endpoints (be careful with sensitive data)
if should_log_body(request.url.path):
body = await request.body() # Be cautious with large files
logger.debug(f"Request body: {body[:100]}...") # Truncate for safety
Geographic Information:
# Add geographic information for security and analytics
def get_country_from_ip(ip: str) -> str:
# Use a GeoIP service to get country information
# Useful for security monitoring and usage analytics
pass
Rate Limiting Integration:
# Log rate limiting decisions
if hasattr(request.state, 'rate_limit_hit'):
logger.warning(f"Rate limit hit for {client_ip}")
Monitoring and Alerting Integration
The logging middleware creates data that monitoring systems can use for alerts:
- Error Rate Monitoring: Alert when error rates exceed thresholds
- Performance Degradation: Alert when average response times increase
- Security Events: Alert on suspicious patterns (multiple errors from same IP)
- Capacity Planning: Track request volume trends for scaling decisions
This comprehensive logging approach transforms your API from a “black box” into a fully observable system where you can understand exactly what’s happening at any moment, debug issues quickly, and optimize performance based on real usage data.
The investment in proper logging pays dividends throughout the entire lifecycle of your API, from development debugging to production incident response to long-term optimization and scaling decisions.
Security Headers Middleware
Security headers protect your API from various attacks. Create api/middleware/security.py:
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Callable
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""
Security headers middleware for production deployment.
Adds essential security headers to protect against common web vulnerabilities:
- XSS attacks
- Clickjacking
- MIME type sniffing
- Information leakage
- HTTPS enforcement
"""
def __init__(self, app):
super().__init__(app)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
response = await call_next(request)
# Security headers for production
security_headers = {
# Prevent XSS attacks
"X-Content-Type-Options": "nosniff",
# Prevent clickjacking
"X-Frame-Options": "DENY",
# XSS protection (deprecated but still useful for older browsers)
"X-XSS-Protection": "1; mode=block",
# Hide server information
"Server": "PyramidNet-API",
# Referrer policy
"Referrer-Policy": "strict-origin-when-cross-origin",
# Content Security Policy (more permissive for API with docs)
"Content-Security-Policy": "default-src 'self'; img-src 'self' data:; style-src 'self' 'unsafe-inline'; script-src 'self'",
# HTTPS enforcement (only in production)
"Strict-Transport-Security": "max-age=31536000; includeSubDomains",
# Permissions policy (restrict browser features)
"Permissions-Policy": "geolocation=(), microphone=(), camera=()"
}
# Apply security headers
for header, value in security_headers.items():
if value: # Only set headers with non-empty values
response.headers[header] = value
# Remove potentially sensitive headers if they exist
sensitive_headers = ["X-Powered-By", "Server"]
for header in sensitive_headers:
if header in response.headers and header != "Server": # Keep our custom Server header
del response.headers[header]
return response
Understanding Security Middleware in Production
The security headers middleware we’ve implemented provides essential protection against common web vulnerabilities. Each header serves a specific security purpose that’s critical for production APIs:
Core Security Headers Explained
Content Security Policy (CSP): Our CSP policy "default-src 'self'; img-src 'self' data:; style-src 'self' 'unsafe-inline'; script-src 'self'" creates a strict allowlist for resource loading. This prevents XSS attacks by blocking unauthorized script execution and data exfiltration attempts.
X-Content-Type-Options: The "nosniff" directive prevents browsers from MIME-type sniffing, which could allow attackers to disguise malicious files as safe content types.
X-Frame-Options: Setting this to "DENY" prevents your API from being embedded in iframes, protecting against clickjacking attacks where malicious sites overlay invisible frames to trick users.
Strict-Transport-Security (HSTS): Forces HTTPS connections for all future requests, preventing man-in-the-middle attacks and ensuring encrypted communication.
Production Security Considerations
These headers work together to create defense-in-depth protection:
- API Documentation Security: While our CSP is permissive enough to allow FastAPI’s interactive docs to function, it still blocks most malicious scripts
- Information Disclosure Prevention: We remove server identification headers that could help attackers fingerprint your infrastructure
- Browser Feature Restrictions: The Permissions-Policy header disables potentially dangerous browser features like geolocation and camera access
This middleware provides baseline security that’s essential for any production API, especially one handling user-uploaded content like images.
Rate Limiting Middleware
Rate limiting prevents abuse and ensures fair resource usage. While we’ll implement a simple version here, production systems typically use Redis-based solutions.
Create api/middleware/rate_limiting.py:
import time
from collections import defaultdict, deque
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Callable, Dict, Deque
class SimpleRateLimitMiddleware(BaseHTTPMiddleware):
"""
Simple in-memory rate limiting middleware.
For production, consider using Redis-based rate limiting for:
- Distributed systems
- Persistent rate limit state
- More sophisticated algorithms
"""
def __init__(self, app, requests_per_minute: int = 60):
super().__init__(app)
self.requests_per_minute = requests_per_minute
self.window_seconds = 60
# In-memory storage (use Redis in production)
self.requests: Dict[str, Deque[float]] = defaultdict(deque)
def _get_client_identifier(self, request: Request) -> str:
"""Get client identifier for rate limiting"""
# In production, consider using API keys or user IDs
client_ip = getattr(request.client, 'host', 'unknown') if request.client else "unknown"
return client_ip
def _cleanup_old_requests(self, request_times: Deque[float], current_time: float):
"""Remove requests outside the current window"""
while request_times and current_time - request_times[0] > self.window_seconds:
request_times.popleft()
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Skip rate limiting for health checks
if request.url.path.startswith("/health"):
return await call_next(request)
client_id = self._get_client_identifier(request)
current_time = time.time()
# Get client's request history
client_requests = self.requests[client_id]
# Clean up old requests
self._cleanup_old_requests(client_requests, current_time)
# Check rate limit
if len(client_requests) >= self.requests_per_minute:
# Calculate time until rate limit resets
oldest_request = client_requests[0] if client_requests else current_time
reset_time = oldest_request + self.window_seconds
retry_after = max(1, int(reset_time - current_time))
# Return the 429 response directly: raising HTTPException from inside
# BaseHTTPMiddleware would bypass FastAPI's exception handlers and surface as a 500.
return JSONResponse(
status_code=429,
content={
"error": "Rate limit exceeded",
"message": f"Maximum {self.requests_per_minute} requests per minute allowed",
"retry_after": retry_after
},
headers={"Retry-After": str(retry_after)}
)
# Record this request
client_requests.append(current_time)
# Process request
response = await call_next(request)
# Add rate limit headers
response.headers["X-RateLimit-Limit"] = str(self.requests_per_minute)
response.headers["X-RateLimit-Remaining"] = str(self.requests_per_minute - len(client_requests))
# Calculate reset time more accurately
if client_requests:
oldest_request = client_requests[0]
reset_time = oldest_request + self.window_seconds
else:
reset_time = current_time + self.window_seconds
response.headers["X-RateLimit-Reset"] = str(int(reset_time))
return response
Understanding Rate Limiting in Production APIs
The rate limiting middleware we’ve implemented provides essential protection against abuse while ensuring fair resource access. Let’s explore the key design decisions:
Sliding Window Approach
Our implementation uses a sliding window strategy with in-memory storage:
self.requests: Dict[str, Deque[float]] = defaultdict(deque)
This approach prevents users from “bursting” at window boundaries and provides smoother traffic distribution compared to fixed windows. On each request, entries older than the window are popped from the front of the deque, keeping per-client memory usage bounded.
Client Identification and Headers
We identify clients by IP address and communicate limits through standard headers:
response.headers["X-RateLimit-Limit"] = str(self.requests_per_minute)
response.headers["X-RateLimit-Remaining"] = str(remaining_requests)
response.headers["X-RateLimit-Reset"] = str(int(reset_time))
The 429 status code is defined in RFC 6585, and the X-RateLimit-* headers follow a widely adopted convention, enabling clients to adapt their behavior and avoid hitting limits unnecessarily.
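A well-behaved client can use these headers (and the Retry-After value) to back off instead of retrying immediately. The retry policy below is an illustrative sketch, not part of the server code:
import time
import requests

def post_with_backoff(url: str, files: dict, max_retries: int = 3):
    """Retry a prediction request, sleeping for the server-suggested interval on 429s."""
    for _ in range(max_retries):
        response = requests.post(url, files=files)
        if response.status_code != 429:
            return response
        retry_after = int(response.headers.get("Retry-After", "1"))
        time.sleep(retry_after)  # wait until the rate-limit window has room again
    return response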
Production Considerations
While our in-memory approach works well for single instances, production systems often extend this with:
- Redis-based storage: For distributed rate limiting across multiple API instances
- Tiered limits: Different limits for different endpoint types (predictions vs. health checks)
- User-based identification: Using API keys or authentication tokens instead of just IP addresses
The 429 status code with structured error messages helps clients understand when they can retry, making the API more developer-friendly while maintaining protection against abuse.
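As a rough illustration of the Redis-based approach, the sketch below keeps a fixed-window counter in Redis using the redis-py client. It assumes a reachable Redis instance and is deliberately simpler than the sliding-window logic above; production systems often use Lua scripts or token buckets instead:
import redis

# Because the counter lives in Redis, every API instance sees the same state.
r = redis.Redis(host="localhost", port=6379, db=0)

def allow_request(client_id: str, limit: int = 60, window_seconds: int = 60) -> bool:
    key = f"ratelimit:{client_id}"
    count = r.incr(key)                  # atomic increment across all instances
    if count == 1:
        r.expire(key, window_seconds)    # start the window on the first request
    return count <= limit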
These middleware components work together to create a professional API experience:
- Request Logging: Every request gets a unique ID and detailed logging, making debugging much easier.
- Security Headers: Standard security headers protect against common web vulnerabilities.
- Rate Limiting: Prevents abuse while providing clear feedback to clients about limits.
Now let’s tie everything together in our main FastAPI application.
Middleware Module Setup
To ensure our middleware components work together properly, we need to create the appropriate module structure. Create api/middleware/__init__.py:
from .logging import RequestLoggingMiddleware
from .security import SecurityHeadersMiddleware
from .rate_limiting import SimpleRateLimitMiddleware
__all__ = [
"RequestLoggingMiddleware",
"SecurityHeadersMiddleware",
"SimpleRateLimitMiddleware"
]
This module initialization file makes it easy to import all middleware components and ensures they’re properly exposed as part of the middleware package. The __all__ list explicitly defines what should be available when someone imports from the middleware module.
Similarly, let’s ensure our routers have proper initialization. Create api/routers/__init__.py:
from . import health, predict
__all__ = ["health", "predict"]
And for completeness, create api/services/__init__.py:
from .model_service import model_service, ModelService
from .image_service import image_service, ImageService
__all__ = [
"model_service",
"ModelService",
"image_service",
"ImageService"
]
These initialization files ensure that Python treats our directories as proper packages and makes importing components clean and predictable throughout our application.
Creating the Main FastAPI Application
The main application file brings together all our components into a cohesive, production-ready API. Create api/main.py:
import logging
import os
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import JSONResponse
import uvicorn
# Import our components
from .config import settings, API_MESSAGES
from .routers import health, predict
from .services.model_service import model_service
from .middleware.logging import RequestLoggingMiddleware
from .middleware.security import SecurityHeadersMiddleware
from .middleware.rate_limiting import SimpleRateLimitMiddleware
# Configure logging
def setup_logging():
"""Configure comprehensive logging for production"""
# Create logs directory
os.makedirs("logs", exist_ok=True)
# Configure root logger
logging.basicConfig(
level=getattr(logging, settings.log_level),
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(settings.log_file),
logging.StreamHandler() # Console output
]
)
# Specific logger for our API
api_logger = logging.getLogger("api")
api_logger.setLevel(getattr(logging, settings.log_level))
return api_logger
# Setup logging
logger = setup_logging()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Application lifespan management.
Handles startup and shutdown tasks:
- Model loading during startup
- Resource cleanup during shutdown
"""
# Startup
logger.info("Starting PyramidNet Image Classifier API")
logger.info(f"Environment: {settings.environment}")
logger.info(f"Debug mode: {settings.debug}")
logger.info(f"Model device: {settings.model_device}")
try:
# Load the model
await model_service.load_model()
logger.info("Model loaded successfully - API ready to serve requests")
except Exception as e:
logger.error(f"Failed to load model: {str(e)}")
# In production, you might want to fail fast here
if settings.is_production():
raise RuntimeError("Model loading failed in production environment")
yield # Application runs here
# Shutdown
logger.info("Shutting down PyramidNet Image Classifier API")
# Create FastAPI application
app = FastAPI(
title=settings.app_name,
description=settings.app_description,
version=settings.app_version,
docs_url="/docs" if not settings.is_production() else None, # Disable docs in production
redoc_url="/redoc" if not settings.is_production() else None,
lifespan=lifespan
)
# Add middleware (order matters!)
# Note: the last middleware added becomes the outermost layer (first to see requests).
# Security headers are added first, so they sit closest to the app and are applied to every response it produces.
app.add_middleware(SecurityHeadersMiddleware)
# Request logging
app.add_middleware(RequestLoggingMiddleware)
# Rate limiting
if settings.environment != "development":
app.add_middleware(SimpleRateLimitMiddleware, requests_per_minute=60)
# CORS (Cross-Origin Resource Sharing)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["GET", "POST"],
allow_headers=["*"],
)
# Trusted hosts (security)
if settings.is_production():
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=["api.yourdomain.com", "*.yourdomain.com"] # Configure for your domain
)
# Include routers
app.include_router(health.router)
app.include_router(predict.router)
# Global exception handler
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
"""
Global exception handler for unhandled errors.
Ensures that all errors return consistent JSON responses
and prevents sensitive information leakage.
"""
logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
if settings.debug:
# In development, show detailed error information
return JSONResponse(
status_code=500,
content={
"error": "Internal Server Error",
"message": str(exc),
"detail": "Check logs for more information",
"type": exc.__class__.__name__
}
)
else:
# In production, hide sensitive error details
return JSONResponse(
status_code=500,
content={
"error": "Internal Server Error",
"message": "An unexpected error occurred",
"detail": "Please try again or contact support if the problem persists"
}
)
# Root endpoint
@app.get(
"/",
summary="API Information",
description="Get basic information about the PyramidNet Image Classifier API"
)
async def root():
"""
Root endpoint providing API information.
"""
return {
"name": settings.app_name,
"version": settings.app_version,
"description": settings.app_description,
"environment": settings.environment,
"model_status": "loaded" if model_service.is_healthy() else "not_loaded",
"docs_url": "/docs" if not settings.is_production() else "disabled",
"endpoints": {
"health": "/health",
"prediction": "/predict",
"batch_prediction": "/predict/batch"
}
}
# Run the application
if __name__ == "__main__":
uvicorn.run(
"api.main:app",
host=settings.host,
port=settings.port,
reload=settings.reload and settings.is_development(),
log_level=settings.log_level.lower(),
access_log=settings.access_log
)
Understanding the Application Structure
This main application demonstrates several production patterns:
- Lifespan Management: The @asynccontextmanager decorator handles startup and shutdown tasks. This is where we load our model once at startup rather than on every request.
- Middleware Order: Middleware added later wraps middleware added earlier, so it runs earlier on incoming requests. We add the security headers middleware first, placing it closest to the application so every response the app generates carries the headers.
- Environment-Aware Configuration: Different behavior in development vs. production (docs enabled/disabled, error detail level).
- Global Error Handling: The exception handler ensures all errors return consistent JSON responses and prevents information leakage.
- Comprehensive Logging: Structured logging with different levels for different environments.
Deep Dive: Main Application Components
The main application file orchestrates several critical production concerns. Let’s examine each component and understand why it’s essential for a robust API:
Logging Configuration and Strategy
def setup_logging():
# Create logs directory
os.makedirs("logs", exist_ok=True)
# Configure root logger
logging.basicConfig(
level=getattr(logging, settings.log_level),
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(settings.log_file),
logging.StreamHandler() # Console output
]
)
Why this matters in production:
- Dual Output: Logs go to both files (for persistence) and console (for container orchestration)
- Structured Format: Timestamp, logger name, level, and message provide complete context
- Environment-Driven Levels: Debug logs in development, info/warning in production
- Directory Creation: Ensures log directory exists, preventing startup failures
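One common extension, not included in the setup above, is log rotation so files cannot grow without bound. Here is a minimal sketch using the standard library's RotatingFileHandler (the path and sizes are illustrative):
import logging
from logging.handlers import RotatingFileHandler

def add_rotating_handler(logger: logging.Logger, path: str = "logs/api.log") -> None:
    """Rotate the log file at ~10 MB, keeping 5 archived copies."""
    handler = RotatingFileHandler(path, maxBytes=10 * 1024 * 1024, backupCount=5)
    handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
    logger.addHandler(handler)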
Application Lifespan Management
The @asynccontextmanager pattern is crucial for production applications:
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
logger.info("Starting PyramidNet Image Classifier API")
try:
await model_service.load_model()
logger.info("Model loaded successfully - API ready to serve requests")
except Exception as e:
if settings.is_production():
raise RuntimeError("Model loading failed in production environment")
yield # Application runs here
# Shutdown
logger.info("Shutting down PyramidNet Image Classifier API")
Key benefits:
- Resource Initialization: Model loading happens once at startup, not per request
- Health Check Integration: Health endpoints can verify model is loaded before accepting traffic
- Graceful Failure: In production, fails fast if critical resources can’t be loaded
- Clean Shutdown: Resources can be properly released when the application stops
FastAPI Application Configuration
app = FastAPI(
title=settings.app_name,
description=settings.app_description,
version=settings.app_version,
docs_url="/docs" if not settings.is_production() else None,
redoc_url="/redoc" if not settings.is_production() else None,
lifespan=lifespan
)
Production considerations:
- Documentation Security: Interactive docs disabled in production to prevent information disclosure
- Version Management: API version exposed for client compatibility
- Metadata Integration: Title and description support API discovery and documentation
Middleware Stack and Ordering
The order of middleware addition is critical because they execute in reverse order:
# The last middleware added is the outermost layer: it sees requests first and touches responses last
app.add_middleware(SecurityHeadersMiddleware)
app.add_middleware(RequestLoggingMiddleware)
# Rate limiting only in non-development environments
if settings.environment != "development":
app.add_middleware(SimpleRateLimitMiddleware, requests_per_minute=60)
# CORS configuration
app.add_middleware(CORSMiddleware, ...)
Why this order matters:
- CORS: As the outermost layer, it handles preflight requests before anything else
- Rate Limiting: Blocks abusive traffic before it reaches logging or expensive processing
- Request Logging: Logs every request that makes it past the rate limiter
- Security Headers: Sits closest to the application, so every response the app produces carries the headers
CORS Configuration for API Access
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["GET", "POST"],
allow_headers=["*"],
)
Production implications:
- Origin Allowlist: Only specified domains can access the API from browsers
- Method Restriction: Only GET and POST allowed, reducing attack surface
- Credential Support: Enables authentication cookies or headers
- Environment-Specific: Different origins for development vs. production
Trusted Host Security
if settings.is_production():
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=["api.yourdomain.com", "*.yourdomain.com"]
)
Security benefits:
- Host Header Attacks: Prevents malicious Host header manipulation
- Production Only: Allows flexible hosts in development
- Subdomain Support: Wildcard patterns support multiple subdomains
Global Exception Handling Strategy
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
if settings.debug:
# Development: Show detailed errors
return JSONResponse(
status_code=500,
content={
"error": "Internal Server Error",
"message": str(exc),
"detail": "Check logs for more information",
"type": exc.__class__.__name__
}
)
else:
# Production: Hide sensitive details
return JSONResponse(
status_code=500,
content={
"error": "Internal Server Error",
"message": "An unexpected error occurred",
"detail": "Please try again or contact support if the problem persists"
}
)
Critical design decisions:
- Environment-Aware Details: Detailed errors in development, generic in production
- Consistent Structure: All errors return JSON with predictable fields
- Comprehensive Logging: Full stack traces logged even when details are hidden
- Security by Default: Prevents information leakage in production
Root Endpoint for API Discovery
@app.get("/")
async def root():
return {
"name": settings.app_name,
"version": settings.app_version,
"description": settings.app_description,
"environment": settings.environment,
"model_status": "loaded" if model_service.is_healthy() else "not_loaded",
"docs_url": "/docs" if not settings.is_production() else "disabled",
"endpoints": {
"health": "/health",
"prediction": "/predict",
"batch_prediction": "/predict/batch"
}
}
API usability benefits:
- Self-Documentation: Root endpoint describes available functionality
- Health Integration: Quick way to check if API is operational
- Version Discovery: Clients can verify API compatibility
- Endpoint Discovery: Provides a map of available endpoints
Development vs. Production Behavior
The application adapts its behavior based on environment:
Development Mode:
- Interactive documentation enabled
- Detailed error messages with stack traces
- Auto-reload on code changes
- More permissive CORS settings
- Rate limiting disabled
Production Mode:
- Documentation endpoints disabled
- Generic error messages
- Trusted host validation
- Rate limiting enabled
- Comprehensive security headers
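Conceptually, all of these toggles hang off a single environment flag. The sketch below illustrates the idea with plain environment variables; our real settings object from api/config.py (built earlier in this post) drives the same decisions, so the names here are only for demonstration:
import os

ENVIRONMENT = os.getenv("ENVIRONMENT", "development")

def is_production() -> bool:
    return ENVIRONMENT == "production"

# Environment-driven toggles, mirroring the behavior described above.
DOCS_URL = None if is_production() else "/docs"
RATE_LIMITING_ENABLED = is_production()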
Server Configuration and Runtime
if __name__ == "__main__":
uvicorn.run(
"api.main:app",
host=settings.host,
port=settings.port,
reload=settings.reload and settings.is_development(),
log_level=settings.log_level.lower(),
access_log=settings.access_log
)
Configuration highlights:
- Environment-Driven Settings: All configuration comes from environment variables or config file
- Conditional Reload: Auto-reload only in development to prevent production instability
- Access Log Control: Can disable access logs in production if using external logging
- Host Binding: Configurable for different deployment scenarios
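For real deployments you will usually run more than one worker process. The snippet below is an illustrative launch configuration, not part of our main.py; uvicorn only honors the workers argument when the app is passed as an import string and reload is off (many teams use gunicorn with uvicorn workers instead):
import uvicorn

if __name__ == "__main__":
    # Four worker processes; values here are illustrative and should be tuned to your hardware.
    uvicorn.run("api.main:app", host="0.0.0.0", port=8000, workers=4)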
This main application file demonstrates how production APIs require careful consideration of security, observability, performance, and operational concerns. Each component serves a specific purpose in creating a robust, maintainable service that can handle real-world traffic and operational requirements.
The patterns shown here—environment-aware configuration, comprehensive error handling, security-first middleware, and proper resource management—form the foundation for APIs that teams can confidently deploy and operate in production environments.
Testing Your API
Now let’s create a comprehensive test suite to ensure our API works correctly. Create tests/test_api.py:
import pytest
import asyncio
from httpx import AsyncClient
from fastapi.testclient import TestClient
import torch
import numpy as np
from io import BytesIO
from PIL import Image
import tempfile
import os
from api.main import app
from api.services.model_service import model_service
# Test configuration
TEST_IMAGE_SIZE = 32
def create_test_image(size=(TEST_IMAGE_SIZE, TEST_IMAGE_SIZE), format="PNG"):
"""Create a test image for API testing"""
# Create a simple test image with some pattern to make it more realistic
image = Image.new("RGB", size, color="red")
# Add some pattern to make it look more like a real image
import random
pixels = image.load()
for i in range(size[0]):
for j in range(size[1]):
# Add some random noise to make it more realistic
r = min(255, max(0, 255 + random.randint(-50, 50)))
g = min(255, max(0, 0 + random.randint(-20, 20)))
b = min(255, max(0, 0 + random.randint(-20, 20)))
pixels[i, j] = (r, g, b)
# Convert to bytes
img_byte_arr = BytesIO()
image.save(img_byte_arr, format=format)
img_byte_arr.seek(0)
return img_byte_arr.getvalue()
def create_invalid_image():
"""Create an invalid image file for testing"""
return b"This is not an image file"
def create_oversized_image():
"""Create an oversized image for testing file size limits"""
# Create a very large image that would exceed size limits
large_image = Image.new("RGB", (2000, 2000), color="blue")
img_byte_arr = BytesIO()
large_image.save(img_byte_arr, format="PNG")
img_byte_arr.seek(0)
return img_byte_arr.getvalue()
@pytest.fixture
def client():
"""Create test client"""
return TestClient(app)
@pytest.fixture
async def async_client():
"""Create async test client"""
async with AsyncClient(app=app, base_url="http://test") as ac:
yield ac
class TestHealthEndpoints:
"""Test health check endpoints"""
def test_basic_health_check(self, client):
"""Test basic health endpoint"""
response = client.get("/health/")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert "version" in data
assert "uptime_seconds" in data
def test_readiness_check(self, client):
"""Test readiness endpoint"""
response = client.get("/health/ready")
# Should be 503 if model isn't loaded, 200 if it is
assert response.status_code in [200, 503]
if response.status_code == 200:
data = response.json()
assert data["status"] == "ready"
def test_liveness_check(self, client):
"""Test liveness endpoint"""
response = client.get("/health/live")
assert response.status_code == 200
data = response.json()
assert data["status"] == "alive"
def test_model_stats(self, client):
"""Test model statistics endpoint"""
response = client.get("/health/stats")
assert response.status_code == 200
data = response.json()
assert "is_loaded" in data
assert "prediction_count" in data
class TestPredictionEndpoints:
"""Test prediction endpoints"""
def test_prediction_with_valid_image(self, client):
"""Test single image prediction"""
# Create test image
test_image = create_test_image()
# Make prediction request
response = client.post(
"/predict/",
files={"file": ("test.png", test_image, "image/png")}
)
# Model might not be loaded in test environment
if response.status_code == 503:
pytest.skip("Model not loaded in test environment")
if response.status_code != 200:
print(f"Response status: {response.status_code}")
print(f"Response content: {response.text}")
assert response.status_code == 200
data = response.json()
assert "predicted_class" in data
assert "confidence" in data
assert "all_confidences" in data
assert "prediction_time_ms" in data
assert "timestamp" in data
# Verify confidence is between 0 and 1
assert 0 <= data["confidence"] <= 1
# Verify all_confidences has 10 classes (CIFAR-10)
assert len(data["all_confidences"]) == 10
# Verify the predicted class is in CIFAR-10 classes
cifar10_classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
assert data["predicted_class"] in cifar10_classes
def test_prediction_with_invalid_file(self, client):
"""Test prediction with invalid file"""
# Create invalid file
invalid_file = create_invalid_image()
response = client.post(
"/predict/",
files={"file": ("test.txt", invalid_file, "text/plain")}
)
# Model might not be loaded in test environment
if response.status_code == 503:
pytest.skip("Model not loaded in test environment")
assert response.status_code == 400
data = response.json()
assert "detail" in data
def test_prediction_with_invalid_format(self, client):
"""Test prediction with unsupported image format"""
# Create an image with unsupported extension
test_image = create_test_image()
response = client.post(
"/predict/",
files={"file": ("test.xyz", test_image, "image/xyz")}
)
# Model might not be loaded in test environment
if response.status_code == 503:
pytest.skip("Model not loaded in test environment")
assert response.status_code == 400
def test_prediction_without_file(self, client):
"""Test prediction without file"""
response = client.post("/predict/")
# Model might not be loaded in test environment
if response.status_code == 503:
pytest.skip("Model not loaded in test environment")
assert response.status_code == 422 # Validation error
def test_prediction_with_empty_file(self, client):
"""Test prediction with empty file"""
response = client.post(
"/predict/",
files={"file": ("test.png", b"", "image/png")}
)
# Model might not be loaded in test environment
if response.status_code == 503:
pytest.skip("Model not loaded in test environment")
assert response.status_code == 400
def test_batch_prediction(self, client):
"""Test batch prediction"""
# Create multiple test images
test_images = [
("test1.png", create_test_image(), "image/png"),
("test2.png", create_test_image(), "image/png")
]
response = client.post(
"/predict/batch",
files=[("files", img) for img in test_images]
)
# Model might not be loaded in test environment
if response.status_code == 503:
pytest.skip("Model not loaded in test environment")
assert response.status_code == 200
data = response.json()
assert "predictions" in data
assert "batch_size" in data
assert "total_processing_time_ms" in data
assert data["batch_size"] == 2
assert len(data["predictions"]) == 2
# Verify each prediction has required fields
for prediction in data["predictions"]:
assert "predicted_class" in prediction
assert "confidence" in prediction
assert "all_confidences" in prediction
def test_batch_prediction_too_many_files(self, client):
"""Test batch prediction with too many files"""
# Create more than 10 test images
test_images = [
(f"test{i}.png", create_test_image(), "image/png")
for i in range(12) # More than the limit of 10
]
response = client.post(
"/predict/batch",
files=[("files", img) for img in test_images]
)
# Model might not be loaded in test environment
if response.status_code == 503:
pytest.skip("Model not loaded in test environment")
assert response.status_code == 400
data = response.json()
assert "Maximum 10 images" in data["detail"]
def test_batch_prediction_empty(self, client):
"""Test batch prediction with no files"""
response = client.post("/predict/batch", files=[])
# Model might not be loaded in test environment
if response.status_code == 503:
pytest.skip("Model not loaded in test environment")
assert response.status_code == 400
def test_prediction_info(self, client):
"""Test prediction info endpoint"""
response = client.get("/predict/info")
assert response.status_code == 200
data = response.json()
assert "api_version" in data
assert "model_info" in data
assert "preprocessing" in data
class TestAPIBehavior:
"""Test general API behavior"""
def test_root_endpoint(self, client):
"""Test root endpoint"""
response = client.get("/")
assert response.status_code == 200
data = response.json()
assert "name" in data
assert "version" in data
assert "endpoints" in data
assert "model_status" in data
def test_cors_headers(self, client):
"""Test CORS headers are present"""
response = client.options("/")
# CORS headers should be present
# Note: TestClient might not fully simulate CORS
assert response.status_code in [200, 405] # Some test clients return 405 for OPTIONS
def test_security_headers(self, client):
"""Test security headers are present"""
response = client.get("/")
# Check for security headers (case-insensitive)
headers_lower = {k.lower(): v for k, v in response.headers.items()}
assert headers_lower.get("x-content-type-options") == "nosniff"
assert headers_lower.get("x-frame-options") == "DENY"
assert "x-request-id" in headers_lower
def test_404_handling(self, client):
"""Test 404 error handling"""
response = client.get("/nonexistent")
assert response.status_code == 404
# Integration tests
class TestModelIntegration:
"""Test model service integration"""
@pytest.mark.asyncio
async def test_model_loading(self):
"""Test model can be loaded"""
try:
await model_service.load_model()
assert model_service.is_healthy()
except FileNotFoundError:
pytest.skip("Model file not found - expected in test environment")
except Exception as e:
pytest.skip(f"Model loading failed: {e}")
@pytest.mark.asyncio
async def test_model_prediction(self):
"""Test model prediction directly"""
if not model_service.is_loaded:
pytest.skip("Model not loaded")
# Create test tensor with proper normalization
test_tensor = torch.randn(1, 3, 32, 32)
try:
result = await model_service.predict(test_tensor)
assert "predicted_class" in result
assert "confidence" in result
assert "all_confidences" in result
assert "prediction_time_ms" in result
except Exception as e:
pytest.fail(f"Model prediction failed: {e}")
# Performance tests
class TestPerformance:
"""Test API performance"""
def test_prediction_timing(self, client):
"""Test that predictions complete in reasonable time"""
test_image = create_test_image()
import time
start_time = time.time()
response = client.post(
"/predict/",
files={"file": ("test.png", test_image, "image/png")}
)
end_time = time.time()
if response.status_code == 503:
pytest.skip("Model not loaded")
# Should complete within 5 seconds (generous for test environment)
assert (end_time - start_time) < 5.0
# Run tests
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
Running the Tests
To run your test suite:
# Install testing dependencies
pip install pytest httpx pytest-asyncio
# Run tests
pytest tests/ -v
# Run tests with coverage
pip install pytest-cov
pytest tests/ --cov=api --cov-report=html
Running Your Production API
Now that we have a complete API, let’s run it and test it out!
Start the API Server
From your project directory, run:
# Make sure you're in the right directory and virtual environment
cd deepthought-image-classifier
source .venv/bin/activate # or activate your virtual environment
# Run the API
python -m api.main
You should see output like:
INFO: Starting PyramidNet Image Classifier API
INFO: Environment: development
INFO: Model loaded successfully - API ready to serve requests
INFO: Started server process [12345]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
Explore the Interactive Documentation
Open your browser and navigate to http://127.0.0.1:8000/docs. You’ll see FastAPI’s automatically generated interactive documentation:
- Try out endpoints: Click “Try it out” on any endpoint to test it directly
- Upload images: Use the prediction endpoints to classify real images
- View schemas: See the exact structure of requests and responses
- Download OpenAPI spec: Get the API specification for client generation
Test with curl
You can also test your API with command-line tools:
# Health check
curl http://127.0.0.1:8000/health/
# Upload an image for prediction
curl -X POST "http://127.0.0.1:8000/predict/" \
-H "accept: application/json" \
-H "Content-Type: multipart/form-data" \
-F "file=@path/to/your/image.jpg"
Here’s an example request to the prediction endpoint, with the response piped through jq:
amathis@DeepThought:~$ curl -X POST "http://127.0.0.1:8000/predict/" \
-H "accept: application/json" \
-H "Content-Type: multipart/form-data" \
-F "file=@Downloads/plane.jpg" | jq
% Total % Received % Xferd Average Speed Time Time Time Current
Dload Upload Total Spent Left Speed
100 126k 100 494 100 125k 7570 1930k --:--:-- --:--:-- --:--:-- 1945k
{
"predicted_class": "airplane",
"predicted_class_id": 0,
"confidence": 0.9926060438156128,
"all_confidences": {
"airplane": 0.9926060438156128,
"automobile": 0.0020307409577071667,
"bird": 0.0012899991124868393,
"cat": 0.0007616023067384958,
"deer": 0.0010112059535458684,
"dog": 0.0005068415775895119,
"frog": 0.00041529006557539105,
"horse": 0.0002298073668498546,
"ship": 0.0005295769078657031,
"truck": 0.0006187597755342722
},
"prediction_time_ms": 28.44,
"model_device": "cuda",
"timestamp": "2025-06-22T05:14:45.673644Z"
}
As you can see, the model correctly classified the uploaded image of a plane as an airplane, with over 99% confidence.
Test with Python
Create a simple Python client to test your API:
import requests
# Test health endpoint
response = requests.get("http://127.0.0.1:8000/health/")
print("Health check:", response.json())
# Test prediction with an image file
with open("path/to/your/image.jpg", "rb") as f:
files = {"file": f}
response = requests.post("http://127.0.0.1:8000/predict/", files=files)
if response.status_code == 200:
result = response.json()
print(f"Prediction: {result['predicted_class']}")
print(f"Confidence: {result['confidence']:.3f}")
else:
print(f"Error: {response.status_code} - {response.text}")
Performance Optimization and Monitoring
Now that your API is working, let’s discuss optimization strategies for production deployment.
Performance Monitoring
Add performance metrics to track your API’s health:
# Add to your main.py
import time
from collections import defaultdict
class PerformanceMonitor:
def __init__(self):
self.request_times = defaultdict(list)
self.request_counts = defaultdict(int)
self.error_counts = defaultdict(int)
def record_request(self, endpoint: str, duration: float, status_code: int):
self.request_times[endpoint].append(duration)
self.request_counts[endpoint] += 1
if status_code >= 400:
self.error_counts[endpoint] += 1
def get_stats(self):
stats = {}
for endpoint in self.request_times:
times = self.request_times[endpoint]
stats[endpoint] = {
"count": self.request_counts[endpoint],
"avg_time": sum(times) / len(times),
"max_time": max(times),
"min_time": min(times),
"error_rate": self.error_counts[endpoint] / self.request_counts[endpoint]
}
return stats
# Add performance monitoring endpoint
monitor = PerformanceMonitor()
@app.get("/metrics")
async def get_metrics():
"""Get API performance metrics"""
return {
"model_stats": model_service.get_stats(),
"endpoint_stats": monitor.get_stats(),
"system_info": {
"device": str(model_service.device),
"model_loaded": model_service.is_healthy()
}
}
Now, your full main.py should look like this:
import logging
import os
import time
from collections import defaultdict
from contextlib import asynccontextmanager

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import JSONResponse
import uvicorn

# Import our components
from .config import settings, API_MESSAGES
from .routers import health, predict
from .services.model_service import model_service
from .middleware.logging import RequestLoggingMiddleware
from .middleware.security import SecurityHeadersMiddleware
from .middleware.rate_limiting import SimpleRateLimitMiddleware


class PerformanceMonitor:
    """Monitor API performance metrics"""

    def __init__(self):
        self.request_times = defaultdict(list)
        self.request_counts = defaultdict(int)
        self.error_counts = defaultdict(int)
        self.start_time = time.time()

    def record_request(self, endpoint: str, duration: float, status_code: int):
        """Record a request for metrics tracking"""
        self.request_times[endpoint].append(duration)
        self.request_counts[endpoint] += 1
        if status_code >= 400:
            self.error_counts[endpoint] += 1

    def get_stats(self):
        """Get performance statistics"""
        stats = {}
        for endpoint in self.request_times:
            times = self.request_times[endpoint]
            if times:  # Avoid division by zero
                stats[endpoint] = {
                    "count": self.request_counts[endpoint],
                    "avg_time_ms": sum(times) / len(times),
                    "max_time_ms": max(times),
                    "min_time_ms": min(times),
                    "error_rate": self.error_counts[endpoint] / self.request_counts[endpoint] if self.request_counts[endpoint] > 0 else 0.0
                }
        return {
            "uptime_seconds": time.time() - self.start_time,
            "endpoint_stats": stats,
            "total_requests": sum(self.request_counts.values()),
            "total_errors": sum(self.error_counts.values())
        }


# Configure logging
def setup_logging():
    """Configure comprehensive logging for production"""
    # Create logs directory
    os.makedirs("logs", exist_ok=True)

    # Configure root logger
    logging.basicConfig(
        level=getattr(logging, settings.log_level),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(settings.log_file),
            logging.StreamHandler()  # Console output
        ]
    )

    # Specific logger for our API
    api_logger = logging.getLogger("api")
    api_logger.setLevel(getattr(logging, settings.log_level))
    return api_logger


# Setup logging
logger = setup_logging()

# Initialize performance monitor
monitor = PerformanceMonitor()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan management.

    Handles startup and shutdown tasks:
    - Model loading during startup
    - Resource cleanup during shutdown
    """
    # Startup
    logger.info("Starting PyramidNet Image Classifier API")
    logger.info(f"Environment: {settings.environment}")
    logger.info(f"Debug mode: {settings.debug}")
    logger.info(f"Model device: {settings.model_device}")
    logger.info(f"Docs enabled: {not settings.is_production()}")

    try:
        # Load the model
        await model_service.load_model()
        logger.info("Model loaded successfully - API ready to serve requests")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        # In production, you might want to fail fast here
        if settings.is_production():
            raise RuntimeError("Model loading failed in production environment")

    yield  # Application runs here

    # Shutdown
    logger.info("Shutting down PyramidNet Image Classifier API")


# Create FastAPI application
app = FastAPI(
    title=settings.app_name,
    description=settings.app_description,
    version=settings.app_version,
    docs_url="/docs" if not settings.is_production() else None,  # Disable docs in production
    redoc_url="/redoc" if not settings.is_production() else None,
    lifespan=lifespan
)

# Add middleware (order matters!)
# Security headers should be last to ensure they're applied to all responses
app.add_middleware(SecurityHeadersMiddleware)

# Request logging
app.add_middleware(RequestLoggingMiddleware)

# Rate limiting
if settings.environment != "development":
    app.add_middleware(SimpleRateLimitMiddleware, requests_per_minute=60)

# CORS (Cross-Origin Resource Sharing)
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)

# Trusted hosts (security)
if settings.is_production():
    app.add_middleware(
        TrustedHostMiddleware,
        allowed_hosts=["api.yourdomain.com", "*.yourdomain.com"]  # Configure for your domain
    )

# Include routers
app.include_router(health.router)
app.include_router(predict.router)


# Global exception handler
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """
    Global exception handler for unhandled errors.

    Ensures that all errors return consistent JSON responses
    and prevents sensitive information leakage.
    """
    logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)

    if settings.debug:
        # In development, show detailed error information
        return JSONResponse(
            status_code=500,
            content={
                "error": "Internal Server Error",
                "message": str(exc),
                "detail": "Check logs for more information",
                "type": exc.__class__.__name__
            }
        )
    else:
        # In production, hide sensitive error details
        return JSONResponse(
            status_code=500,
            content={
                "error": "Internal Server Error",
                "message": "An unexpected error occurred",
                "detail": "Please try again or contact support if the problem persists"
            }
        )


# Root endpoint
@app.get(
    "/",
    summary="API Information",
    description="Get basic information about the PyramidNet Image Classifier API"
)
async def root():
    """
    Root endpoint providing API information.
    """
    return {
        "name": settings.app_name,
        "version": settings.app_version,
        "description": settings.app_description,
        "environment": settings.environment,
        "model_status": "loaded" if model_service.is_healthy() else "not_loaded",
        "docs_url": "/docs" if not settings.is_production() else "disabled",
        "endpoints": {
            "health": "/health",
            "prediction": "/predict",
            "batch_prediction": "/predict/batch",
            "metrics": "/metrics"
        }
    }


@app.get("/metrics")
async def get_metrics():
    """Get API performance metrics"""
    try:
        return {
            "model_stats": model_service.get_stats(),
            "performance": monitor.get_stats(),
            "system_info": {
                "device": str(model_service.device) if model_service.device else "unknown",
                "model_loaded": model_service.is_healthy(),
                "environment": settings.environment
            }
        }
    except Exception as e:
        logger.error(f"Error getting metrics: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": "Failed to retrieve metrics", "detail": str(e)}
        )


# Run the application
if __name__ == "__main__":
    uvicorn.run(
        "api.main:app",
        host=settings.host,
        port=settings.port,
        reload=settings.reload and settings.is_development(),
        log_level=settings.log_level.lower(),
        access_log=settings.access_log
    )
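One gap worth noting: nothing in the file above actually calls monitor.record_request, so the /metrics endpoint would report empty endpoint statistics. Assuming your RequestLoggingMiddleware doesn't already feed the monitor, here is a minimal sketch of how you might wire it up with a small HTTP middleware added to main.py (it relies on the app, Request, time, and monitor names already defined there):

# A minimal sketch (add to main.py after the middleware section); it feeds
# request timings into the PerformanceMonitor defined above. Durations are
# recorded in milliseconds to match the *_ms keys in get_stats().
@app.middleware("http")
async def track_performance(request: Request, call_next):
    start = time.perf_counter()
    response = await call_next(request)
    duration_ms = (time.perf_counter() - start) * 1000
    monitor.record_request(request.url.path, duration_ms, response.status_code)
    return response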
Optimization Strategies
Model Optimization:
- Use torch.jit.script() to compile your model for faster inference
- Enable mixed precision with torch.cuda.amp if using a GPU
- Consider model quantization for deployment
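As a concrete starting point for the first bullet, here is a minimal, hedged sketch of exporting the trained model to TorchScript. The helper name and output path are illustrative, and it assumes the PyramidNet from Part 2 accepts standard CIFAR-10 tensors of shape (batch, 3, 32, 32):

# A minimal sketch; export_torchscript and the output path are illustrative,
# not part of the service code above.
import torch

def export_torchscript(model: torch.nn.Module, out_path: str = "models/pyramidnet_scripted.pt") -> None:
    """Compile a trained classifier to TorchScript for faster inference."""
    model.eval()
    try:
        scripted = torch.jit.script(model)       # works if forward() is scriptable
    except Exception:
        example = torch.randn(1, 3, 32, 32)      # fall back to tracing with a dummy CIFAR-10 input
        scripted = torch.jit.trace(model, example)
    scripted.save(out_path)

# Later, the model service could load the compiled artifact instead of the raw checkpoint:
#     self.model = torch.jit.load(out_path, map_location=self.device)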
API Optimization:
- Use async/await consistently throughout your codebase
- Implement connection pooling for database connections
- Add response caching for frequently requested predictions
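For the response-caching bullet, one simple option (not part of the service code above) is an in-memory LRU cache keyed by a hash of the uploaded bytes, since identical images always yield identical predictions. The class and method names here are illustrative:

# A minimal sketch of an in-memory LRU prediction cache; names are illustrative.
import hashlib
from collections import OrderedDict
from typing import Optional

class PredictionCache:
    def __init__(self, max_entries: int = 1024):
        self._cache: "OrderedDict[str, dict]" = OrderedDict()
        self.max_entries = max_entries

    @staticmethod
    def _key(image_bytes: bytes) -> str:
        return hashlib.sha256(image_bytes).hexdigest()

    def get(self, image_bytes: bytes) -> Optional[dict]:
        key = self._key(image_bytes)
        if key in self._cache:
            self._cache.move_to_end(key)     # mark as recently used
            return self._cache[key]
        return None

    def put(self, image_bytes: bytes, result: dict) -> None:
        key = self._key(image_bytes)
        self._cache[key] = result
        self._cache.move_to_end(key)
        if len(self._cache) > self.max_entries:
            self._cache.popitem(last=False)  # evict the least recently used entry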
Infrastructure Optimization:
- Use a reverse proxy like Nginx for static file serving
- Implement horizontal scaling with multiple API instances (see the sketch after this list)
- Use a load balancer to distribute traffic
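For the horizontal-scaling bullet, a first step before adding more machines behind a load balancer is simply running multiple worker processes on one host. A minimal sketch, assuming a production-style launch (the app must be passed as an import string and auto-reload must be off for multiple workers to take effect):

# A minimal sketch of a multi-worker launch; host, port, and worker count
# are illustrative and would normally come from settings.
import uvicorn

if __name__ == "__main__":
    uvicorn.run(
        "api.main:app",
        host="0.0.0.0",
        port=8000,
        workers=4,      # one process per CPU core is a common starting point
        reload=False,   # auto-reload is incompatible with multiple workers
    )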
Production Deployment Checklist
Before deploying to production:
- Security: Remove debug mode, add authentication (see the sketch after this checklist), use HTTPS
- Logging: Configure centralized logging (e.g., ELK stack)
- Monitoring: Add health checks, metrics collection, alerting
- Performance: Load test your API, optimize bottlenecks
- Reliability: Implement circuit breakers, retries, fallbacks
- Documentation: API documentation, deployment guides, runbooks
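The Security item mentions authentication, which we haven't implemented yet. As one hedged option, here is a minimal API-key dependency built on FastAPI's APIKeyHeader; the settings.api_key field is an assumed addition to config.py and is not part of the code shown earlier:

# A minimal sketch of API-key authentication. Assumes you add an `api_key`
# field to the Settings class in config.py; attach the dependency to the
# routers you want to protect.
from fastapi import Depends, HTTPException, Security, status
from fastapi.security import APIKeyHeader

from .config import settings  # assumes settings.api_key exists

api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

async def require_api_key(api_key: str | None = Security(api_key_header)) -> str:
    """Reject requests that don't carry the expected X-API-Key header."""
    if not api_key or api_key != settings.api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API key",
        )
    return api_key

# In main.py, protect the prediction routes:
#     app.include_router(predict.router, dependencies=[Depends(require_api_key)])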
Wrapping Up: From Model to Production API
Congratulations! You’ve built a production-grade FastAPI service that can serve your PyramidNet model to the world. Let’s review what we’ve accomplished:
What We Built
Complete API Infrastructure:
- RESTful endpoints for single and batch image classification
- Comprehensive health checks for monitoring and load balancing
- Production-ready error handling and validation
- Security middleware for protection against common attacks
Professional Code Organization:
- Modular structure separating concerns (models, services, routers)
- Configuration management for different environments
- Comprehensive logging and request tracking
- Extensive test suite for reliability
Advanced Features:
- Async processing for high concurrency
- Rate limiting to prevent abuse
- Automatic API documentation
- Performance monitoring capabilities
Current Project Structure
deepthought-image-classifier/
├── api/
│ ├── __init__.py
│ ├── main.py # FastAPI application
│ ├── config.py # Configuration management
│ ├── models/
│ │ ├── __init__.py
│ │ └── prediction.py # Pydantic models
│ ├── routers/
│ │ ├── __init__.py
│ │ ├── health.py # Health check endpoints
│ │ └── predict.py # Prediction endpoints
│ ├── services/
│ │ ├── __init__.py
│ │ ├── model_service.py # Model loading and inference
│ │ └── image_service.py # Image preprocessing
│ └── middleware/
│ ├── __init__.py
│ ├── logging.py # Request logging
│ ├── security.py # Security headers
│ └── rate_limiting.py # Rate limiting
├── models/
│ └── best_pyramidnet_model.pth # Trained model from Part 2
├── tests/
│ ├── __init__.py
│ └── test_api.py # Comprehensive test suite
├── logs/ # Application logs
├── processed/
│ └── prepared_train_dataset.pt # From Part 1
├── .env # Environment configuration
├── preprocessing.py # From Part 1
├── train_model.py # From Part 2
└── requirements.txt
Key Production Patterns We Implemented
Separation of Concerns: Model loading, image processing, and API logic are cleanly separated, making the code maintainable and testable.
Configuration Management: Environment-specific settings without hardcoded values, making deployment across different environments seamless.
Comprehensive Error Handling: Specific error messages and appropriate HTTP status codes help clients understand and handle errors properly.
Security by Design: Multiple layers of security including headers, rate limiting, and input validation protect against common attacks.
Observability: Detailed logging, request tracking, and performance metrics provide visibility into your API’s behavior in production.
Performance Characteristics
Your API should be capable of:
- Throughput: 100-500 requests per minute on a modern CPU, 1000+ on GPU
- Latency: 50-200ms per prediction depending on hardware
- Concurrency: Handles multiple simultaneous requests efficiently
- Reliability: Graceful error handling and recovery
Looking Ahead
In Part 4 of this series, we’ll containerize this API with Docker and deploy it to Kubernetes, adding:
- Containerization: Docker images for consistent deployment
- Orchestration: Kubernetes manifests for scalable deployment
- Service Mesh: Advanced networking and security features
- Monitoring: Prometheus metrics and Grafana dashboards
In Part 5, we’ll deploy to AWS with:
- Infrastructure as Code: Terraform for AWS resource management
- High Availability: Multi-AZ deployment with load balancing
- Auto Scaling: Automatic scaling based on demand
- CI/CD Pipeline: Automated testing and deployment
Key Takeaways
Building production APIs requires more than just exposing your model through an HTTP endpoint. The patterns we’ve implemented here—comprehensive error handling, security measures, observability, and testing—are essential for APIs that real users will depend on.
The modular structure we’ve created makes it easy to extend your API with additional features like user authentication, model versioning, or A/B testing. The foundation is solid and ready for the next phase of deployment.
Your PyramidNet model has evolved from a training script to a production-ready web service. You now have the tools and knowledge to serve machine learning models at scale, with the reliability and performance that production systems demand.
The journey from raw pixels to production API is nearly complete—next, we’ll take this service to the cloud!
