Implementing rate limiting and request throttling in FastAPI is a crucial step for protecting your API from abuse and ensuring fair usage. Let’s dive into how we can achieve this using middleware.
First, we’ll need to install the required dependencies. We’ll use the fastapi-limiter package, which provides a simple way to add rate limiting to our FastAPI application. Open your terminal and run:
pip install fastapi-limiter
Now, let’s create a new FastAPI application and set up our rate limiting middleware. Here’s a basic example:
from fastapi import Depends, FastAPI
from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter
import redis.asyncio as redis

app = FastAPI()

@app.on_event("startup")
async def startup():
    # fastapi-limiter needs an async Redis client to store its counters
    redis_client = redis.from_url("redis://localhost", encoding="utf-8", decode_responses=True)
    await FastAPILimiter.init(redis_client)

@app.get("/", dependencies=[Depends(RateLimiter(times=2, seconds=5))])
async def root():
    return {"message": "Hello World"}
In this example, we’re using Redis as our backend for storing rate limit information. Make sure you have Redis installed and running on your local machine.
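If you don’t have Redis running yet, one quick way to start a local instance (assuming Docker is installed) is:

docker run -d -p 6379:6379 redis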
The @app.on_event("startup") decorator ensures that our rate limiter is initialized when the application starts. We’re connecting to a local Redis instance and initializing FastAPILimiter with this client.
For our root endpoint, we’ve added a dependency that limits requests to 2 per 5 seconds. Note that the RateLimiter instance is wrapped in Depends, which is how fastapi-limiter plugs into FastAPI’s dependency system. If a client exceeds this limit, they’ll receive a 429 Too Many Requests error.
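A quick way to see the limiter in action: the sketch below assumes the app is running on localhost:8000 and that the httpx package is installed.

import httpx

for i in range(4):
    r = httpx.get("http://localhost:8000/")
    print(i, r.status_code)
# With a limit of 2 per 5 seconds, the first two requests should print 200
# and the next two should print 429 until the window rolls over.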
But what if we want more fine-grained control over our rate limiting? Let’s create a custom middleware that allows us to set different limits for different endpoints:
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
import time

class RateLimitMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, limit=100, window=60):
        super().__init__(app)
        self.limit = limit    # max requests allowed per window
        self.window = window  # window length in seconds
        # Maps "client_ip:path" to a list of request timestamps.
        # Note: this dict lives in a single process, so it is not shared
        # across multiple workers or servers.
        self.requests = {}

    async def dispatch(self, request: Request, call_next):
        key = f"{request.client.host}:{request.url.path}"
        now = time.time()
        # Keep only the timestamps that still fall inside the window
        timestamps = [t for t in self.requests.get(key, []) if now - t < self.window]
        if len(timestamps) >= self.limit:
            return JSONResponse(status_code=429, content={"error": "Too many requests"})
        timestamps.append(now)
        self.requests[key] = timestamps
        return await call_next(request)

app = FastAPI()
app.add_middleware(RateLimitMiddleware, limit=5, window=10)

@app.get("/")
async def root():
    return {"message": "Hello World"}

@app.get("/premium")
async def premium():
    return {"message": "Welcome, premium user!"}
This custom middleware allows us to set a global rate limit for all endpoints. In this case, we’re limiting requests to 5 per 10 seconds. The middleware keeps track of requests from each IP address for each endpoint separately.
But what if we want different rate limits for different endpoints or user roles? We can modify our middleware to handle this:
from fastapi import FastAPI, Request, Depends
from fastapi.security import OAuth2PasswordBearer
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
import time

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

class RateLimitMiddleware(BaseHTTPMiddleware):
    def __init__(self, app):
        super().__init__(app)
        # Per-path limits; any path not listed falls back to a default below
        self.limits = {
            "/": {"limit": 5, "window": 10},
            "/premium": {"limit": 20, "window": 10}
        }
        self.requests = {}

    async def dispatch(self, request: Request, call_next):
        path = request.url.path
        key = f"{request.client.host}:{path}"
        now = time.time()
        limit = self.limits.get(path, {"limit": 2, "window": 10})
        timestamps = [t for t in self.requests.get(key, []) if now - t < limit["window"]]
        if len(timestamps) >= limit["limit"]:
            return JSONResponse(status_code=429, content={"error": "Too many requests"})
        timestamps.append(now)
        self.requests[key] = timestamps
        return await call_next(request)

app = FastAPI()
app.add_middleware(RateLimitMiddleware)

async def get_current_user(token: str = Depends(oauth2_scheme)):
    # In a real application, you'd validate the token here
    return {"username": "johndoe"}

@app.get("/")
async def root():
    return {"message": "Hello World"}

@app.get("/premium")
async def premium(current_user: dict = Depends(get_current_user)):
    return {"message": f"Welcome, {current_user['username']}!"}
In this example, we’ve defined different rate limits for different endpoints: the root endpoint is limited to 5 requests per 10 seconds, while the premium endpoint allows 20 per 10 seconds. We’ve also wired in a placeholder OAuth2PasswordBearer dependency, though it doesn’t influence the limits yet, since the middleware still keys on client IP and path rather than on the user. A sketch of per-role limits follows below.
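One wrinkle to be aware of: middleware runs before FastAPI resolves dependencies, so get_current_user isn’t available inside dispatch. To vary limits by user role, you’d have to inspect the token yourself. Here’s a minimal sketch of that idea; the lookup_tier helper and the hard-coded tier table are hypothetical stand-ins for real token validation:

import time
from typing import Optional
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

TIER_LIMITS = {
    "free": {"limit": 5, "window": 10},
    "premium": {"limit": 20, "window": 10},
}

def lookup_tier(auth_header: Optional[str]) -> str:
    # Hypothetical: map an Authorization header to a tier. Real code would
    # verify a JWT or look the token up in a database.
    return "premium" if auth_header == "Bearer premium-token" else "free"

class TierRateLimitMiddleware(BaseHTTPMiddleware):
    def __init__(self, app):
        super().__init__(app)
        self.requests = {}

    async def dispatch(self, request: Request, call_next):
        tier = lookup_tier(request.headers.get("Authorization"))
        limit = TIER_LIMITS[tier]
        key = f"{request.client.host}:{tier}"
        now = time.time()
        timestamps = [t for t in self.requests.get(key, []) if now - t < limit["window"]]
        if len(timestamps) >= limit["limit"]:
            return JSONResponse(status_code=429, content={"error": "Too many requests"})
        timestamps.append(now)
        self.requests[key] = timestamps
        return await call_next(request)

app = FastAPI()
app.add_middleware(TierRateLimitMiddleware)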
But what about distributed systems? If your API is running on multiple servers, you’ll need a centralized store for rate limit state. Let’s modify our middleware to use Redis for this, reusing the redis.asyncio client from the first example:
import time
import uuid
import redis.asyncio as redis
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

class RedisRateLimitMiddleware(BaseHTTPMiddleware):
    def __init__(self, app):
        super().__init__(app)
        self.redis = redis.from_url("redis://localhost")
        self.limits = {
            "/": {"limit": 5, "window": 10},
            "/premium": {"limit": 20, "window": 10}
        }

    async def dispatch(self, request: Request, call_next):
        path = request.url.path
        key = f"ratelimit:{request.client.host}:{path}"
        now = time.time()
        limit = self.limits.get(path, {"limit": 2, "window": 10})
        pipe = self.redis.pipeline()
        # Drop entries that have fallen out of the window...
        pipe.zremrangebyscore(key, 0, now - limit["window"])
        # ...count what's left...
        pipe.zcard(key)
        # ...and record this request with a unique member, so that several
        # requests arriving in the same instant don't collapse into one entry.
        pipe.zadd(key, {f"{now}:{uuid.uuid4()}": now})
        pipe.expire(key, limit["window"])
        results = await pipe.execute()
        # Note: every request is recorded, including rejected ones, so a
        # client that keeps hammering the API stays locked out.
        if results[1] >= limit["limit"]:
            return JSONResponse(status_code=429, content={"error": "Too many requests"})
        return await call_next(request)

app = FastAPI()
app.add_middleware(RedisRateLimitMiddleware)

@app.get("/")
async def root():
    return {"message": "Hello World"}

@app.get("/premium")
async def premium():
    return {"message": "Welcome, premium user!"}
This Redis-based middleware uses sorted sets to keep track of requests. It’s more scalable and can work across multiple instances of your API.
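If you want to play with the sorted-set pattern outside of a middleware, here is a minimal, self-contained sketch. It assumes a local Redis and the redis package; the allow helper and the key name are illustrative. Unlike the middleware above, this variant only records requests it admits, so rejected requests don’t count against the client:

import asyncio
import time
import uuid
import redis.asyncio as redis

async def allow(client, key, limit, window):
    now = time.time()
    pipe = client.pipeline()
    pipe.zremrangebyscore(key, 0, now - window)  # evict expired entries
    pipe.zcard(key)                              # count requests still in the window
    results = await pipe.execute()
    if results[1] >= limit:
        return False
    # Check-then-add is not atomic; fine for a demo, but a production
    # version would move this logic into a Lua script or transaction.
    await client.zadd(key, {str(uuid.uuid4()): now})
    await client.expire(key, window)
    return True

async def main():
    client = redis.from_url("redis://localhost")
    for i in range(7):
        print(i, await allow(client, "demo:sliding", limit=5, window=10))
    # Expected: the first five calls print True, the remaining two False.

asyncio.run(main())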
Now, let’s talk about request throttling. While rate limiting puts a hard cap on the number of requests, throttling delays requests instead of rejecting them once the allowed rate is exceeded. Here’s an example of how we might implement it:
import asyncio
import time
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware

class ThrottlingMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, rate=10, per=1):
        super().__init__(app)
        self.rate = rate                  # bucket capacity in tokens
        self.allow_per_sec = rate / per   # tokens replenished per second
        self.last_check = time.time()
        self.allowance = rate             # start with a full bucket

    async def dispatch(self, request: Request, call_next):
        current = time.time()
        time_passed = current - self.last_check
        self.last_check = current
        # Replenish tokens for the elapsed time, capped at the bucket capacity
        self.allowance = min(self.rate, self.allowance + time_passed * self.allow_per_sec)
        if self.allowance < 1:
            # Not enough tokens: wait until one has accumulated, then spend it
            await asyncio.sleep((1 - self.allowance) / self.allow_per_sec)
            self.allowance = 0
        else:
            self.allowance -= 1
        return await call_next(request)

app = FastAPI()
app.add_middleware(ThrottlingMiddleware, rate=10, per=1)

@app.get("/")
async def root():
    return {"message": "Hello World"}
This throttling middleware allows a certain number of requests per second, but instead of rejecting requests once the budget is spent, it delays them. In this example, we’re allowing 10 requests per second; extra requests are slowed down to maintain that rate. Note that this simple version keeps one global bucket shared by all clients and updates it without any locking, so under heavy concurrency its behavior is approximate. You can see the effect with a quick sequential test from the client side, as sketched below.
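The sketch assumes the app from the previous example is running on localhost:8000 and that the httpx package is installed; the numbers are illustrative.

import asyncio
import time
import httpx

async def main():
    async with httpx.AsyncClient() as client:
        start = time.time()
        for _ in range(20):
            await client.get("http://localhost:8000/")
        # The first ~10 requests pass on the initial token allowance; the
        # rest are spaced out by the middleware, so the total should be on
        # the order of a second rather than completing instantly.
        print(f"20 sequential requests took {time.time() - start:.2f}s")

asyncio.run(main())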
Remember, the choice between rate limiting and throttling (or using both) depends on your specific use case. Rate limiting is great for preventing abuse and ensuring fair usage, while throttling can help smooth out traffic spikes and prevent your server from becoming overwhelmed.
In real-world applications, you might want to combine these techniques with other strategies. For example, you could implement a token bucket algorithm for more flexible rate limiting (the throttling middleware above is in fact a simple token bucket; a standalone sketch follows below), or use machine learning to detect and block abusive patterns of requests.
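Factoring the token bucket out into its own class makes the algorithm easier to test and to reuse across consumers. Here’s a minimal sketch; the class name and parameters are illustrative:

import time

class TokenBucket:
    """A simple token bucket: holds up to capacity tokens, refilled at rate per second."""

    def __init__(self, capacity: float, rate: float):
        self.capacity = capacity
        self.rate = rate
        self.tokens = capacity
        self.updated = time.monotonic()

    def consume(self, tokens: float = 1.0) -> bool:
        now = time.monotonic()
        # Refill based on elapsed time, capped at capacity
        self.tokens = min(self.capacity, self.tokens + (now - self.updated) * self.rate)
        self.updated = now
        if self.tokens >= tokens:
            self.tokens -= tokens
            return True
        return False

# Example: a 5-token bucket refilled at 1 token/second. Bursts of up to 5
# requests pass immediately; sustained traffic is limited to 1 per second.
bucket = TokenBucket(capacity=5, rate=1)
print([bucket.consume() for _ in range(7)])  # first five True, then False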
You might also want to consider how to communicate rate limits to your API users. The X-RateLimit-* response headers aren’t part of the HTTP specification, but they are a widely adopted convention (and the subject of an IETF draft) for exactly this purpose:
import time
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()

@app.get("/")
async def root(request: Request):
    remaining = 5  # This would be calculated based on your rate limiting logic
    response = JSONResponse(content={"message": "Hello World"})
    response.headers["X-RateLimit-Limit"] = "5"
    response.headers["X-RateLimit-Remaining"] = str(remaining)
    response.headers["X-RateLimit-Reset"] = str(int(time.time()) + 60)
    return response
These headers inform the client about the rate limit, how many requests they have left, and when the limit will reset (here as a Unix timestamp). Rather than setting them by hand in every endpoint, you can attach them from middleware.
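Here’s a minimal sketch that folds the in-memory limiter from earlier into a middleware that stamps these headers on every response; the class name and the limit values are illustrative:

import time
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

class HeaderRateLimitMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, limit=5, window=60):
        super().__init__(app)
        self.limit = limit
        self.window = window
        self.requests = {}

    async def dispatch(self, request: Request, call_next):
        key = request.client.host
        now = time.time()
        hits = [t for t in self.requests.get(key, []) if now - t < self.window]
        if len(hits) >= self.limit:
            response = JSONResponse(status_code=429, content={"error": "Too many requests"})
            remaining = 0
        else:
            hits.append(now)
            response = await call_next(request)
            remaining = self.limit - len(hits)
        self.requests[key] = hits
        response.headers["X-RateLimit-Limit"] = str(self.limit)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        # Approximate reset: when the oldest tracked request leaves the window
        reset = int(hits[0] + self.window) if hits else int(now + self.window)
        response.headers["X-RateLimit-Reset"] = str(reset)
        return response

app = FastAPI()
app.add_middleware(HeaderRateLimitMiddleware, limit=5, window=60)

@app.get("/")
async def root():
    return {"message": "Hello World"}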
Implementing rate limiting and request throttling is an essential part of building robust and scalable APIs. It helps protect your services from abuse, ensures fair usage among your users, and can even improve the overall performance of your application by preventing server overload.
As you continue to develop your FastAPI applications, remember that these techniques are just the beginning. You might need to adjust and fine-tune your rate limiting and throttling strategies based on your specific use cases and the behavior of your users. Always monitor your API’s performance and be ready to adapt your approach as needed.
Happy coding, and may your APIs always run smoothly and efficiently!