Creating a REST API for Models using FastAPI
FastAPI is a modern, high-performance Python web framework that makes it straightforward to wrap a trained model in a type-checked REST API with automatically generated interactive documentation.
Building the API
The core pattern is: load the model at startup, define Pydantic schemas for request/response, and expose a /predict endpoint that runs inference and returns structured results.
Complete FastAPI App
<pre><code class="language-python"># app.py
from contextlib import asynccontextmanager

import joblib
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel


# --- Schemas ---
class PredictRequest(BaseModel):
    features: list[float]  # e.g., [5.1, 3.5, 1.4, 0.2]


class PredictResponse(BaseModel):
    prediction: int
    probability: list[float]


# --- App lifecycle ---
ml_model = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Runs once at startup: load the model before serving requests.
    ml_model["clf"] = joblib.load("model.joblib")
    yield
    # Runs once at shutdown: release the model.
    ml_model.clear()


app = FastAPI(title="ML Model API", version="1.0.0", lifespan=lifespan)


# --- Endpoints ---
@app.get("/health")
def health():
    return {"status": "ok"}


@app.post("/predict", response_model=PredictResponse)
def predict(request: PredictRequest):
    clf = ml_model.get("clf")
    if clf is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # scikit-learn style estimators expect a 2D array: (n_samples, n_features).
    X = np.array(request.features).reshape(1, -1)
    pred = int(clf.predict(X)[0])
    proba = clf.predict_proba(X)[0].tolist()
    return PredictResponse(prediction=pred, probability=proba)</code></pre>
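The app above expects a serialized classifier at model.joblib. If you do not have one yet, the following minimal sketch produces a compatible file. It assumes scikit-learn is installed and uses the Iris dataset with LogisticRegression, chosen here only because the example features [5.1, 3.5, 1.4, 0.2] resemble an Iris sample. The script name train_model.py and the function train_and_save are hypothetical, not part of the original app.

```python
# train_model.py (hypothetical helper script, not part of app.py)
from pathlib import Path

import joblib
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression


def train_and_save(path: str = "model.joblib") -> Path:
    """Fit a small classifier on Iris and serialize it with joblib."""
    # Assumption: Iris stands in for your real training data.
    X, y = load_iris(return_X_y=True)  # X has 4 features per sample
    clf = LogisticRegression(max_iter=1000)
    clf.fit(X, y)

    out = Path(path)
    joblib.dump(clf, out)  # app.py reads this file back with joblib.load
    return out
```

Calling train_and_save() once in the directory where you start the server writes the model.joblib file that the lifespan handler loads. Any estimator that exposes predict and predict_proba should work the same way.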
Running the Server
<pre><code class="language-python"># Start locally:
#   uvicorn app:app --reload --port 8000

# Test with curl:
#   curl -X POST http://localhost:8000/predict \
#     -H "Content-Type: application/json" \
#     -d '{"features": [5.1, 3.5, 1.4, 0.2]}'

# Auto-generated interactive docs:
#   http://localhost:8000/docs   (Swagger UI)
#   http://localhost:8000/redoc  (ReDoc)</code></pre>
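The same request can be sent from Python. The sketch below uses only the standard library and mirrors the curl command; the helper names build_predict_request and predict are hypothetical, and the base URL assumes the server is running locally on port 8000 as started above.

```python
import json
import urllib.request


def build_predict_request(features, base_url="http://localhost:8000"):
    """Build the same POST request as the curl command, without sending it."""
    body = json.dumps({"features": list(features)}).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url}/predict",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict(features, base_url="http://localhost:8000", timeout=5.0):
    """Send the request and return the decoded JSON response as a dict."""
    req = build_predict_request(features, base_url)
    # Raises urllib.error.HTTPError on 4xx/5xx responses (e.g., 422 or 503).
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.load(resp)
```

With the server running, predict([5.1, 3.5, 1.4, 0.2]) should return a dict with the prediction and probability keys defined by PredictResponse. Splitting request construction from sending keeps the payload logic testable without a live server.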
Input Validation with Pydantic
Pydantic models provide automatic type checking, constraint enforcement, and clear error messages, so FastAPI rejects malformed inputs with a 422 response before they ever reach your model.
Adding Validation Constraints
<pre><code class="language-python">import math

from pydantic import BaseModel, Field, field_validator


class PredictRequest(BaseModel):
    features: list[float] = Field(
        ...,
        min_length=4,
        max_length=4,
        description="Exactly 4 numeric features required",
    )

    @field_validator("features")
    @classmethod
    def check_finite(cls, v):
        # Reject NaN and infinite values before they reach the model.
        if any(math.isinf(x) or math.isnan(x) for x in v):
            raise ValueError("Features must be finite numbers")
        return v</code></pre>
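To see what these constraints reject, the schema can be exercised directly, without starting the server. The sketch below repeats the schema so it runs on its own and assumes Pydantic v2 (which field_validator implies); the helper error_types is hypothetical and exists only for illustration.

```python
import math

from pydantic import BaseModel, Field, ValidationError, field_validator


class PredictRequest(BaseModel):
    # Same schema as above, repeated so the snippet is self-contained.
    features: list[float] = Field(..., min_length=4, max_length=4)

    @field_validator("features")
    @classmethod
    def check_finite(cls, v):
        if any(math.isinf(x) or math.isnan(x) for x in v):
            raise ValueError("Features must be finite numbers")
        return v


def error_types(payload: dict) -> list[str]:
    """Return Pydantic's error type codes for a payload ([] if it is valid)."""
    try:
        PredictRequest.model_validate(payload)
    except ValidationError as exc:
        return [err["type"] for err in exc.errors()]
    return []


valid = error_types({"features": [5.1, 3.5, 1.4, 0.2]})
too_few = error_types({"features": [5.1, 3.5, 1.4]})
not_finite = error_types({"features": [5.1, 3.5, 1.4, float("nan")]})
not_numeric = error_types({"features": [5.1, "abc", 1.4, 0.2]})
```

The valid payload produces no errors, while the other three fail on the length constraint, the custom finiteness check, and float parsing respectively. Inside FastAPI, the same failures surface to the client as a 422 response listing each error.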
Batch Prediction Endpoint
For higher throughput, expose a batch endpoint that accepts multiple samples and returns predictions for all in a single round-trip.
Batch Endpoint
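<pre><code class="language-python"># Add to app.py (reuses app, ml_model, np, BaseModel, and HTTPException)
class BatchRequest(BaseModel):
    instances: list[list[float]]  # one inner list per sample


class BatchResponse(BaseModel):
    predictions: list[int]
    probabilities: list[list[float]]


@app.post("/predict/batch", response_model=BatchResponse)
def predict_batch(request: BatchRequest):
    clf = ml_model.get("clf")
    if clf is None:
        # Same guard as /predict; avoids a KeyError if the model is missing.
        raise HTTPException(status_code=503, detail="Model not loaded")
    X = np.array(request.instances)  # shape: (n_samples, n_features)
    preds = clf.predict(X).tolist()
    probas = clf.predict_proba(X).tolist()
    return BatchResponse(predictions=preds, probabilities=probas)</code></pre>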
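As written, the batch schema accepts any nested list, so an empty batch or rows of unequal length would likely fail inside NumPy or the model rather than at validation time. One possible way to tighten it, sketched under the assumption of Pydantic v2 and the same four-feature model as the single-sample endpoint, is shown below. N_FEATURES, MAX_BATCH, Instance, and is_valid_batch are hypothetical names, and the cap of 256 is an arbitrary placeholder.

```python
from typing import Annotated

from pydantic import BaseModel, Field, ValidationError

N_FEATURES = 4   # assumption: matches the single-sample schema
MAX_BATCH = 256  # assumption: arbitrary cap, tune it for your model

# Each row must contain exactly N_FEATURES floats.
Instance = Annotated[
    list[float], Field(min_length=N_FEATURES, max_length=N_FEATURES)
]


class BatchRequest(BaseModel):
    # At least one row, and no more than MAX_BATCH rows per request.
    instances: list[Instance] = Field(..., min_length=1, max_length=MAX_BATCH)


def is_valid_batch(payload: dict) -> bool:
    """Return True if the payload passes the batch schema."""
    try:
        BatchRequest.model_validate(payload)
    except ValidationError:
        return False
    return True
```

With this schema, ragged or empty batches are rejected up front with a validation error, and the upper bound keeps a single request from monopolizing the worker.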