app = FastAPI()
# FastAPI's bearer-token security scheme; tokenUrl="token" declares the
# (relative) path clients are expected to call to obtain a token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


@app.get("/items/")
async def read_items(token: str = Depends(oauth2_scheme)):
    # The oauth2_scheme dependency supplies the token string; echo it back.
    body = {"token": token}
    return body
这是官方提供的一个使用示例。
我们重新来观察一下上面代码中的细节：
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
，可以发现，关键在于参数 `tokenUrl="token"`，它看起来像是提供了一个 path。让我们在交互式文档页面（/docs）上点击一下 Authorize 试试：
这个参数声明了获取 token 的认证 path，但我们还没有实现它的逻辑，所以会得到 404。
async def read_items(token: str = Depends(oauth2_scheme)):
，这里用到了依赖项 `Depends(oauth2_scheme)`。
它的返回值推测为认证结果，我们稍后会讲到。
官方还提供了另一个案例：
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Sub-dependency that resolves the bearer ``token`` to the current user.

    Sketch only: the token-validation step is left as a placeholder.

    Raises:
        NotImplementedError: always, until the validation logic is filled in.
    """
    # Token validation logic goes here: verify/decode ``token`` and look up
    # the matching user (typically via a database query), then
    # ``return user``.
    raise NotImplementedError("token validation logic")


@app.get("/users/me")
async def read_users_me(current_user: User = Depends(get_current_user)):
    # FastAPI resolves get_current_user (which itself depends on
    # oauth2_scheme) and injects its result here.
    return current_user
可以像这样使用子依赖对 token 进行处理并得到结果，通常会通过查询数据库返回用户的 model。
现在我们按照官方的完整案例修改一下：
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from pydantic import BaseModel
from starlette import status
# --- Application and security scheme -------------------------------------
app = FastAPI()
# Clients obtain tokens from the relative "token" path (the /token route).
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# --- JWT settings ---------------------------------------------------------
# NOTE(review): a hard-coded signing key is acceptable for a demo only;
# load it from configuration/environment in a real deployment.
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"  # HMAC with SHA-256
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # lifetime (minutes) of tokens issued by /token


# --- Models ---------------------------------------------------------------
# A demo account; the password is stored in plain text for simplicity.
class User(BaseModel):
    username: str
    password: str


# Same keys as the JSON body returned by the /token endpoint.
class Token(BaseModel):
    access_token: str
    token_type: str
# In-memory stand-in for a user table.
USER_LIST = [
    User(username="test", password="test_pw"),
]


def get_user(username: str) -> Optional[User]:
    """Look up a user by name in the fake in-memory database.

    Args:
        username: The login name to search for.

    Returns:
        The first ``User`` in ``USER_LIST`` whose ``username`` equals
        ``username``, or ``None`` when there is no such user (callers check
        for this, so the return type must be ``Optional``).
    """
    # Fake database query.
    return next((user for user in USER_LIST if user.username == username), None)
# 401 used by /token when the username is unknown or the password is wrong;
# the WWW-Authenticate header advertises the Bearer scheme.
form_exception = HTTPException(
    status_code=status.HTTP_401_UNAUTHORIZED,
    headers={"WWW-Authenticate": "Bearer"},
    detail="Could not validate credentials",
)
def create_token(user: User, expires_delta: Optional[timedelta] = None):
    """Return a signed JWT access token for ``user``.

    Args:
        user: The user the token is issued for; ``user.username`` becomes
            the ``sub`` claim.
        expires_delta: Token lifetime. When ``None``, a 15-minute lifetime
            is used.

    Returns:
        The token produced by ``jwt.encode`` with ``SECRET_KEY`` and
        ``ALGORITHM``.
    """
    # The original `now + expires_delta or default` parsed as
    # `(now + expires_delta) or default`, so a missing delta raised
    # TypeError. Test for None explicitly so a zero delta is also honoured.
    lifetime = timedelta(minutes=15) if expires_delta is None else expires_delta
    # Timezone-aware UTC "now" (datetime.utcnow() is deprecated since 3.12).
    expire = datetime.now(timezone.utc) + lifetime
    return jwt.encode(
        claims={"sub": user.username, "exp": expire},
        key=SECRET_KEY,
        algorithm=ALGORITHM,
    )
# Exchange a username/password form for a bearer access token.
# Raises a 401 (copied from form_exception) for an unknown user or a wrong
# password.
@app.post("/token", response_model=Token)
async def login_get_token(form_data: OAuth2PasswordRequestForm = Depends()):
    user = get_user(form_data.username)
    # compare_digest keeps the comparison time independent of how much of the
    # password matched. Encode first, because it rejects non-ASCII str input.
    if user is None or not secrets.compare_digest(
        user.password.encode("utf-8"), form_data.password.encode("utf-8")
    ):
        # Raise a fresh exception built from the shared template: re-raising
        # the same module-level instance keeps extending its __traceback__
        # (and the frames it references) on every failed attempt.
        raise HTTPException(
            status_code=form_exception.status_code,
            detail=form_exception.detail,
            headers=form_exception.headers,
        )
    access_token = create_token(
        user=user,
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": access_token, "token_type": "bearer"}
# 401 used by token_to_user when the bearer token cannot be validated or does
# not resolve to a known user; advertises the Bearer scheme.
credentials_exception = HTTPException(
    detail="Could not validate credentials",
    status_code=status.HTTP_401_UNAUTHORIZED,
    headers={"WWW-Authenticate": "Bearer"},
)
def token_to_user(token: str = Depends(oauth2_scheme)) -> User:
    """Dependency that resolves the bearer ``token`` to a ``User``.

    The JWT is decoded and verified with ``SECRET_KEY``/``ALGORITHM``. Its
    ``sub`` claim is then looked up via ``get_user``.

    Raises:
        HTTPException: 401 (copied from ``credentials_exception``) when the
            token cannot be decoded or verified, has no ``sub`` claim, or
            names an unknown user.
    """
    user = None
    try:
        # jwt.decode also enforces the "exp" claim: python-jose raises
        # ExpiredSignatureError (a JWTError subclass) for expired tokens,
        # so no manual expiry check is needed.
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        pass  # invalid token: fall through to the 401 below
    else:
        username = payload.get("sub")
        if username is not None:
            user = get_user(username)
    if user is None:
        # Fresh instance per failure: re-raising the shared module-level
        # exception keeps extending its __traceback__ on every request.
        raise HTTPException(
            status_code=credentials_exception.status_code,
            detail=credentials_exception.detail,
            headers=credentials_exception.headers,
        )
    return user
@app.get("/items/", response_model=User)
async def read_items(user: User = Depends(token_to_user)):
    # Only reached once token_to_user has authenticated the bearer token.
    # NOTE(review): response_model=User includes the `password` field, so this
    # echoes the stored password back to the client; consider a response model
    # without it.
    authenticated_user = user
    return authenticated_user
if __name__ == '__main__':
    # Serve the demo app locally when this file is executed as a script.
    import uvicorn

    uvicorn.run(app, port=8000, host="127.0.0.1")
以上代码是一个简单的模拟 demo，它包含了用户登录获取 token 与 token 认证两种功能，下面我们将逐一解释这些内容。
预定义内容
# Clients obtain tokens from the relative "token" path (the /token route).
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# JWT settings.
# NOTE(review): a hard-coded signing key is acceptable for a demo only;
# load it from configuration/environment in a real deployment.
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"  # HMAC with SHA-256
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # lifetime (minutes) of tokens issued by /token


# A demo account; the password is stored in plain text for simplicity.
class User(BaseModel):
    username: str
    password: str


# Same keys as the JSON body returned by the /token endpoint.
class Token(BaseModel):
    access_token: str
    token_type: str
这里包含了安全认证类、密钥、算法、过期时间的定义，以及两个验证 model。
模拟数据库
这里用一个列表模拟通过用户名从数据库中获取 user 的过程：
# In-memory stand-in for a user table.
USER_LIST = [
    User(username="test", password="test_pw"),
]


def get_user(username: str) -> Optional[User]:
    """Look up a user by name in the fake in-memory database.

    Args:
        username: The login name to search for.

    Returns:
        The first ``User`` in ``USER_LIST`` whose ``username`` equals
        ``username``, or ``None`` when there is no such user (callers check
        for this, so the return type must be ``Optional``).
    """
    # Fake database query.
    return next((user for user in USER_LIST if user.username == username), None)
登录获取token
# 401 used by /token when the username is unknown or the password is wrong;
# the WWW-Authenticate header advertises the Bearer scheme.
form_exception = HTTPException(
    status_code=status.HTTP_401_UNAUTHORIZED,
    headers={"WWW-Authenticate": "Bearer"},
    detail="Could not validate credentials",
)
def create_token(user: User, expires_delta: Optional[timedelta] = None):
    """Return a signed JWT access token for ``user``.

    Args:
        user: The user the token is issued for; ``user.username`` becomes
            the ``sub`` claim.
        expires_delta: Token lifetime. When ``None``, a 15-minute lifetime
            is used.

    Returns:
        The token produced by ``jwt.encode`` with ``SECRET_KEY`` and
        ``ALGORITHM``.
    """
    # The original `now + expires_delta or default` parsed as
    # `(now + expires_delta) or default`, so a missing delta raised
    # TypeError. Test for None explicitly so a zero delta is also honoured.
    lifetime = timedelta(minutes=15) if expires_delta is None else expires_delta
    # Timezone-aware UTC "now" (datetime.utcnow() is deprecated since 3.12).
    expire = datetime.now(timezone.utc) + lifetime
    return jwt.encode(
        claims={"sub": user.username, "exp": expire},
        key=SECRET_KEY,
        algorithm=ALGORITHM,
    )
# Exchange a username/password form for a bearer access token.
# Raises a 401 (copied from form_exception) for an unknown user or a wrong
# password.
@app.post("/token", response_model=Token)
async def login_get_token(form_data: OAuth2PasswordRequestForm = Depends()):
    user = get_user(form_data.username)
    # compare_digest keeps the comparison time independent of how much of the
    # password matched. Encode first, because it rejects non-ASCII str input.
    if user is None or not secrets.compare_digest(
        user.password.encode("utf-8"), form_data.password.encode("utf-8")
    ):
        # Raise a fresh exception built from the shared template: re-raising
        # the same module-level instance keeps extending its __traceback__
        # (and the frames it references) on every failed attempt.
        raise HTTPException(
            status_code=form_exception.status_code,
            detail=form_exception.detail,
            headers=form_exception.headers,
        )
    access_token = create_token(
        user=user,
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": access_token, "token_type": "bearer"}
认证token获取用户
# 401 used by token_to_user when the bearer token cannot be validated or does
# not resolve to a known user; advertises the Bearer scheme.
credentials_exception = HTTPException(
    detail="Could not validate credentials",
    status_code=status.HTTP_401_UNAUTHORIZED,
    headers={"WWW-Authenticate": "Bearer"},
)
def token_to_user(token: str = Depends(oauth2_scheme)) -> User:
    """Dependency that resolves the bearer ``token`` to a ``User``.

    The JWT is decoded and verified with ``SECRET_KEY``/``ALGORITHM``. Its
    ``sub`` claim is then looked up via ``get_user``.

    Raises:
        HTTPException: 401 (copied from ``credentials_exception``) when the
            token cannot be decoded or verified, has no ``sub`` claim, or
            names an unknown user.
    """
    user = None
    try:
        # jwt.decode also enforces the "exp" claim: python-jose raises
        # ExpiredSignatureError (a JWTError subclass) for expired tokens,
        # so no manual expiry check is needed.
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        pass  # invalid token: fall through to the 401 below
    else:
        username = payload.get("sub")
        if username is not None:
            user = get_user(username)
    if user is None:
        # Fresh instance per failure: re-raising the shared module-level
        # exception keeps extending its __traceback__ on every request.
        raise HTTPException(
            status_code=credentials_exception.status_code,
            detail=credentials_exception.detail,
            headers=credentials_exception.headers,
        )
    return user
@app.get("/items/", response_model=User)
async def read_items(user: User = Depends(token_to_user)):
    # Only reached once token_to_user has authenticated the bearer token.
    # NOTE(review): response_model=User includes the `password` field, so this
    # echoes the stored password back to the client; consider a response model
    # without it.
    authenticated_user = user
    return authenticated_user
实际效果
下一部分我们将通过源码解析来了解OAuth2.0的安全认证