Secure FastAPI: Implement OAuth2 with JWT for Bulletproof API Authentication

OAuth2 with JWT in FastAPI enhances API security. It involves token creation, user authentication, and protected endpoints. Advanced features include token refresh, revocation, and scopes. Proper implementation ensures robust API authentication and authorization.

Implementing OAuth2 authentication with JWT tokens in FastAPI is a game-changer for API security. Let’s dive into the nitty-gritty of setting this up and see how it can level up your FastAPI projects.

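First things first, we need to install the required dependencies. Open up your terminal and run the following (the quotes stop shells such as zsh from treating the square brackets as glob patterns):

pip install "fastapi[all]" "python-jose[cryptography]" "passlib[bcrypt]"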

Now that we’ve got our tools ready, let’s start building our secure API. We’ll create a simple user authentication system with JWT tokens.

Here’s the basic structure of our FastAPI app:

from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel
from datetime import datetime, timedelta

app = FastAPI()

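# Secret key used to sign our JWTs. Never hard-code it in production: load it from an
# environment variable or a secrets manager (a random value from `openssl rand -hex 32` works well)
SECRET_KEY = "your-secret-key"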
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# Fake user database for demonstration
fake_users_db = {
    "johndoe": {
        "username": "johndoe",
        "full_name": "John Doe",
        "email": "johndoe@example.com",
        "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW",  # bcrypt hash of "secret"
        "disabled": False,
    }
}

class Token(BaseModel):
    access_token: str
    token_type: str

class TokenData(BaseModel):
    username: str | None = None

class User(BaseModel):
    username: str
    email: str | None = None
    full_name: str | None = None
    disabled: bool | None = None

class UserInDB(User):
    hashed_password: str

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

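This sets up our basic app structure, including our fake user database and some Pydantic models for data validation (the str | None annotations require Python 3.10 or newer). CryptContext handles bcrypt password hashing. OAuth2PasswordBearer doesn’t create an endpoint itself: it tells FastAPI (and the interactive docs) that clients obtain tokens from the /token URL, and on each request it reads the token from the "Authorization: Bearer <token>" header, rejecting the request with a 401 if the header is missing.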
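To make that per-request step concrete, here’s a rough plain-Python approximation of the header parsing. The name extract_bearer_token is hypothetical and this is only an illustration; in the app you simply depend on oauth2_scheme and let FastAPI do the work.

```python
from typing import Optional


def extract_bearer_token(authorization: Optional[str]) -> Optional[str]:
    """Hypothetical helper approximating OAuth2PasswordBearer's header parsing.

    Returns the token from an "Authorization: Bearer <token>" header value,
    or None in the cases where FastAPI's real implementation responds with 401.
    """
    if not authorization:
        return None  # header missing or empty
    scheme, _, token = authorization.partition(" ")
    if scheme.lower() != "bearer" or not token:
        return None  # wrong scheme (e.g. "Basic") or nothing after "Bearer"
    return token
```

The scheme name is compared case-insensitively, so "bearer abc" and "Bearer abc" both yield the token "abc".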

Now, let’s add some helper functions to verify passwords and create JWT tokens:

from datetime import timezone

def verify_password(plain_password, hashed_password):
    return pwd_context.verify(plain_password, hashed_password)

def get_password_hash(password):
    return pwd_context.hash(password)

def get_user(db, username: str):
    if username in db:
        user_dict = db[username]
        return UserInDB(**user_dict)
    return None  # unknown username

def authenticate_user(fake_db, username: str, password: str):
    user = get_user(fake_db, username)
    if not user:
        return False
    if not verify_password(password, user.hashed_password):
        return False
    return user

def create_access_token(data: dict, expires_delta: timedelta | None = None):
    to_encode = data.copy()
    # Default lifetime is 15 minutes; use timezone-aware UTC (datetime.utcnow() is deprecated since Python 3.12)
    expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=15))
    to_encode.update({"exp": expire})  # python-jose converts the datetime to a Unix timestamp
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

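These functions handle password verification, user lookup, and authentication. The create_access_token function is the key piece: it copies the claims you pass in, adds an exp (expiration) claim, and signs the result with SECRET_KEY using HS256, so the server can later detect any tampering with the token.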
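To see what an HS256 token actually consists of, here’s a standard-library sketch of signing and verifying one. It’s for illustration only, under the assumption that you keep using python-jose in the app; the helper names sign_hs256 and verify_hs256 are hypothetical.

```python
import base64
import hashlib
import hmac
import json
import time
from typing import Optional


def b64url(data: bytes) -> str:
    # JWTs use URL-safe base64 with the "=" padding stripped
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def b64url_decode(segment: str) -> bytes:
    # Restore the stripped padding before decoding
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def sign_hs256(claims: dict, secret: str) -> str:
    header = {"alg": "HS256", "typ": "JWT"}
    signing_input = (
        b64url(json.dumps(header, separators=(",", ":")).encode())
        + "."
        + b64url(json.dumps(claims, separators=(",", ":")).encode())
    )
    signature = hmac.new(secret.encode(), signing_input.encode("ascii"), hashlib.sha256).digest()
    return signing_input + "." + b64url(signature)


def verify_hs256(token: str, secret: str, now: Optional[float] = None) -> dict:
    """Return the claims if the signature is valid and the token has not expired."""
    header_b64, payload_b64, signature_b64 = token.split(".")
    header = json.loads(b64url_decode(header_b64))
    # Only accept the algorithm we expect, like algorithms=[ALGORITHM] in jwt.decode
    if header.get("alg") != "HS256":
        raise ValueError("unexpected algorithm")
    signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
    expected = hmac.new(secret.encode(), signing_input, hashlib.sha256).digest()
    # Constant-time comparison avoids leaking how many signature bytes matched
    if not hmac.compare_digest(expected, b64url_decode(signature_b64)):
        raise ValueError("invalid signature")
    claims = json.loads(b64url_decode(payload_b64))
    current = time.time() if now is None else now
    if "exp" in claims and claims["exp"] <= current:
        raise ValueError("token expired")
    return claims
```

Notice that the payload is only base64url-encoded, not encrypted: anyone holding a token can read its claims, so never put secrets in them. The signature only guarantees that nobody without SECRET_KEY changed them.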

Now, let’s create our token endpoint:

@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    user = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}

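This endpoint reads the username and password from a form-encoded request body (the OAuth2 password flow uses form fields, not JSON) and returns a signed access token if the credentials are correct.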

Next, we need a way to get the current user from a token. Here’s how we can do that:

async def get_current_user(token: str = Depends(oauth2_scheme)):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
        token_data = TokenData(username=username)
    except JWTError:
        raise credentials_exception
    user = get_user(fake_users_db, username=token_data.username)
    if user is None:
        raise credentials_exception
    return user

async def get_current_active_user(current_user: User = Depends(get_current_user)):
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user

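These functions verify the token’s signature and expiry (jwt.decode raises a JWTError if either check fails), extract the username from the sub claim, and return the matching user. The get_current_active_user function adds an extra check that rejects disabled users.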

Finally, let’s create a protected endpoint that requires authentication:

@app.get("/users/me/", response_model=User)
async def read_users_me(current_user: User = Depends(get_current_active_user)):
    return current_user

This endpoint will only be accessible to authenticated users and will return their user information.
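Here’s a minimal client-side sketch of the whole flow using only the standard library. It assumes the app is served at http://localhost:8000 (for example by uvicorn), and the helper names are hypothetical. It shows the two details that most often trip people up: /token expects form fields, and protected routes expect an Authorization: Bearer header.

```python
import json
import urllib.parse
import urllib.request

BASE_URL = "http://localhost:8000"  # assumption: where the app is running


def build_token_request(username: str, password: str) -> urllib.request.Request:
    # The OAuth2 password flow sends credentials as form fields, not JSON
    body = urllib.parse.urlencode({"username": username, "password": password}).encode()
    return urllib.request.Request(
        f"{BASE_URL}/token",
        data=body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        method="POST",
    )


def build_me_request(access_token: str) -> urllib.request.Request:
    # Protected routes read the token from the Authorization header
    return urllib.request.Request(
        f"{BASE_URL}/users/me/",
        headers={"Authorization": f"Bearer {access_token}"},
    )


def login_and_fetch_profile(username: str, password: str) -> dict:
    # Makes two real HTTP calls, so the app must be running
    with urllib.request.urlopen(build_token_request(username, password)) as resp:
        token = json.load(resp)["access_token"]
    with urllib.request.urlopen(build_me_request(token)) as resp:
        return json.load(resp)
```

With the fake database above, login_and_fetch_profile("johndoe", "secret") should return John Doe’s user record; a wrong password makes urlopen raise an HTTPError with status 401.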

And there you have it! We’ve implemented OAuth2 authentication with JWT tokens in FastAPI. This setup provides a secure way to handle user authentication and protect your API endpoints.

But wait, there’s more! Let’s dive a bit deeper and explore some advanced features and best practices.

One thing to consider is token refresh. Our current implementation has no way to renew a token, which means users must log in again when it expires. Let’s add a simple renewal endpoint:

@app.post("/refresh-token", response_model=Token)
async def refresh_token(current_user: User = Depends(get_current_active_user)):
    # Issue a fresh token with the normal lifetime (without expires_delta it would only last 15 minutes)
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    new_token = create_access_token(
        data={"sub": current_user.username}, expires_delta=access_token_expires
    )
    return {"access_token": new_token, "token_type": "bearer"}

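This endpoint lets users get a new token without providing their credentials again, as long as their current token is still valid. Strictly speaking, this is token renewal rather than a separate refresh token: since the endpoint only needs a valid access token, anyone who steals one can keep renewing it indefinitely. A fuller design issues a long-lived refresh token alongside the short-lived access token, marks each with a claim saying what kind of token it is, and accepts only refresh tokens at this endpoint.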
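Here’s a minimal sketch of that access/refresh split, working on plain claim dictionaries so the logic is easy to follow. The "type" claim, the seven-day refresh lifetime, and the helper names are all assumptions for illustration; in the app you’d sign each dict with jwt.encode and check the decoded payload after jwt.decode.

```python
from datetime import datetime, timedelta, timezone
from typing import Dict

ACCESS_TOKEN_LIFETIME = timedelta(minutes=30)
REFRESH_TOKEN_LIFETIME = timedelta(days=7)  # assumption: choose what suits your app


def build_token_claims(username: str, now: datetime) -> Dict[str, dict]:
    """Claims for a short-lived access token and a long-lived refresh token.

    The "type" claim is a hypothetical convention that lets each endpoint
    reject the wrong kind of token.
    """
    def claims(kind: str, lifetime: timedelta) -> dict:
        return {"sub": username, "type": kind, "exp": int((now + lifetime).timestamp())}

    return {
        "access": claims("access", ACCESS_TOKEN_LIFETIME),
        "refresh": claims("refresh", REFRESH_TOKEN_LIFETIME),
    }


def check_claims(claims: dict, expected_type: str, now: datetime) -> str:
    """Return the username if (already signature-checked) claims are usable here."""
    if claims.get("type") != expected_type:
        raise ValueError(f"expected a {expected_type} token")
    if claims.get("exp", 0) <= now.timestamp():
        raise ValueError("token expired")
    username = claims.get("sub")
    if not username:
        raise ValueError("missing subject")
    return username
```

In this design, /refresh-token would require "refresh" and get_current_user would require "access", so a long-lived refresh token can’t be used to call ordinary endpoints, and a stolen access token stops working within 30 minutes.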

Another important consideration is token revocation. In some cases, you might want to invalidate a token before it expires. One way to handle this is by maintaining a blacklist of revoked tokens:

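revoked_tokens: set[str] = set()  # in-process only: lost on restart, not shared between workers

@app.post("/revoke-token")
async def revoke_token(
    token: str = Depends(oauth2_scheme),
    current_user: User = Depends(get_current_user),
):
    # oauth2_scheme has already pulled the raw token out of the "Authorization: Bearer ..." header
    revoked_tokens.add(token)
    return {"message": "Token revoked successfully"}

# Add this check at the top of the ORIGINAL get_current_user instead of redefining the function
# further down: Depends(get_current_user) captures the function object at the point where each
# route or dependency is declared, so a later redefinition would never be called.
async def get_current_user(token: str = Depends(oauth2_scheme)):
    if token in revoked_tokens:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has been revoked",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # ... rest of the function remains the same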

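This allows you to revoke tokens on demand, which is useful for scenarios like user logout or a suspected security breach. Keep in mind that the set lives in a single process’s memory: it’s emptied whenever the app restarts, it isn’t shared between multiple workers, and it grows forever unless you prune tokens that have expired anyway.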
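One way to keep the blacklist from growing without bound is sketched below. It assumes each token carries a unique jti (JWT ID) claim, for example str(uuid.uuid4()) added in create_access_token, which the code above doesn’t do yet. The class name TokenDenylist is hypothetical. Entries are dropped once the token’s own exp has passed, because jwt.decode rejects such a token regardless.

```python
import time
from typing import Dict, Optional


class TokenDenylist:
    """In-memory denylist keyed by each token's "jti" claim (hypothetical helper)."""

    def __init__(self) -> None:
        self._revoked: Dict[str, float] = {}  # jti -> exp (Unix seconds)

    def revoke(self, jti: str, exp: float) -> None:
        # Remember the revoked token only until its own expiry
        self._revoked[jti] = exp

    def is_revoked(self, jti: str, now: Optional[float] = None) -> bool:
        self.prune(now)
        return jti in self._revoked

    def prune(self, now: Optional[float] = None) -> None:
        current = time.time() if now is None else now
        # Keep entries until their token has definitely expired
        self._revoked = {j: exp for j, exp in self._revoked.items() if exp >= current}

    def __len__(self) -> int:
        return len(self._revoked)
```

The same idea carries over to Redis for multi-worker deployments: store each revoked jti as a key whose TTL runs until the token’s exp, and Redis does the pruning for you.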

Now, let’s talk about token storage. In a production environment, you wouldn’t use a fake user database or keep token state in your application process’s own memory. Instead, you’d typically use a shared store such as PostgreSQL or Redis. Here’s a quick example of how you might integrate Redis for token storage:

import redis
from datetime import datetime, timedelta, timezone
from fastapi import FastAPI, Depends, HTTPException, status
# ... other imports

app = FastAPI()
redis_client = redis.Redis(host='localhost', port=6379, db=0)

def create_access_token(data: dict, expires_delta: timedelta | None = None):
    to_encode = data.copy()
    # Use one lifetime for both the JWT's exp claim and the Redis key's TTL
    if expires_delta is None:
        expires_delta = timedelta(minutes=15)
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

    # Store the token in Redis; the key disappears when the TTL runs out
    # (ex=None would keep it forever, even after the JWT itself has expired)
    redis_client.set(f"token:{data['sub']}", encoded_jwt, ex=expires_delta)

    return encoded_jwt

async def get_current_user(token: str = Depends(oauth2_scheme)):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
        
        # Check if the token exists in Redis
        stored_token = redis_client.get(f"token:{username}")
        if stored_token is None or stored_token.decode() != token:
            raise credentials_exception
        
        token_data = TokenData(username=username)
    except JWTError:
        raise credentials_exception
    user = get_user(fake_users_db, username=token_data.username)
    if user is None:
        raise credentials_exception
    return user

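This setup stores each user’s current token in Redis and checks it on every request, so revoking a token is simply a matter of deleting its key. There are two trade-offs: every request now costs a Redis lookup, which gives up part of the statelessness that makes JWTs attractive, and because there is one key per user, logging in again (for example on a second device) replaces and invalidates the previous token.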

Let’s also consider rate limiting. It’s a good practice to limit how many requests a client can make in a given period to prevent abuse. Here’s a simple fixed-window rate limiting decorator using Redis:

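import time
from functools import wraps

from fastapi import Request
from fastapi.responses import JSONResponse

def rate_limit(limit: int, window: int):
    """Allow at most `limit` requests per client IP in each fixed window of `window` seconds."""
    def decorator(func):
        @wraps(func)  # keeps the endpoint's signature so FastAPI still resolves its parameters
        async def wrapper(*args, **kwargs):
            # FastAPI passes arguments by keyword, so the endpoint must declare `request: Request`
            request: Request = kwargs["request"]
            client_ip = request.client.host if request.client else "unknown"
            window_key = f"rate_limit:{client_ip}:{int(time.time()) // window}"

            # Increment the counter and set its expiry in a single round trip
            with redis_client.pipeline() as pipe:
                pipe.incr(window_key)
                pipe.expire(window_key, window)
                request_count, _ = pipe.execute()

            if request_count > limit:
                return JSONResponse(
                    status_code=429,
                    content={"error": "Too many requests"}
                )

            return await func(*args, **kwargs)
        return wrapper
    return decorator

# Use this in place of the earlier /token endpoint: if "/token" were registered twice,
# the first (unlimited) route would keep handling every request.
@app.post("/token", response_model=Token)
@rate_limit(limit=5, window=60)  # 5 requests per minute per IP
async def login_for_access_token(
    request: Request, form_data: OAuth2PasswordRequestForm = Depends()
):
    user = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}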

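This decorator caps login attempts per client IP address, which slows down brute-force password guessing. If the app runs behind a reverse proxy, request.client.host will be the proxy’s address unless you configure Uvicorn’s --forwarded-allow-ips so it trusts the proxy’s X-Forwarded-For header.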
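For a single process (tests, or a quick prototype without Redis), the same fixed-window counting can be sketched in memory. The class name FixedWindowLimiter is hypothetical; it uses the same window-index scheme as the decorator and forgets old windows itself, the job EXPIRE does in Redis.

```python
import time
from typing import Dict, Optional, Tuple


class FixedWindowLimiter:
    """In-memory fixed-window counter: at most `limit` hits per client per window."""

    def __init__(self, limit: int, window: int) -> None:
        self.limit = limit
        self.window = window
        self._counts: Dict[Tuple[str, int], int] = {}  # (client, window index) -> hits

    def allow(self, client_id: str, now: Optional[float] = None) -> bool:
        current = time.time() if now is None else now
        window_index = int(current) // self.window  # same bucketing as the Redis key
        # Drop counters from earlier windows (the Redis version relies on EXPIRE)
        self._counts = {k: v for k, v in self._counts.items() if k[1] == window_index}
        key = (client_id, window_index)
        self._counts[key] = self._counts.get(key, 0) + 1  # like Redis INCR
        return self._counts[key] <= self.limit
```

One property of fixed windows worth knowing: a client can send up to twice the limit in a short burst that straddles a window boundary (five requests at 0:59 and five more at 1:00). If that matters, a sliding-window or token-bucket limiter smooths it out at the cost of a bit more bookkeeping.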

Lastly, let’s talk about scopes. Scopes allow you to define different levels of access for your API. Here’s how you can implement scopes in your FastAPI app:

from fastapi.security import OAuth2PasswordBearer, SecurityScopes
from pydantic import ValidationError

oauth2_