42 lines
1.2 KiB
Python
42 lines
1.2 KiB
Python
from fastapi import Depends, HTTPException, status
|
|
from fastapi.security import HTTPBearer
|
|
from app.auth.utils import decode_access_token
|
|
from app.db import get_db_connection
|
|
from typing import Optional
|
|
|
|
# Shared HTTPBearer scheme instance. get_current_user receives it through
# Depends(security) and reads the raw token from the returned object's
# `.credentials` attribute.
security = HTTPBearer()
|
|
|
|
def get_current_user(credentials = Depends(security)) -> dict:
    """Extract and validate the current user from a bearer JWT.

    Intended as a FastAPI dependency: ``credentials`` is supplied by the
    module-level ``security`` scheme, and its ``.credentials`` attribute is
    decoded with ``decode_access_token``.

    Returns:
        A dict with ``user_id`` (the token's ``sub`` claim converted to
        ``int``) and ``email`` (the ``email`` claim, or ``None`` if absent).

    Raises:
        HTTPException: 401 if the token cannot be decoded, has no ``sub``
            claim, or its ``sub`` claim is not an integer.
    """
    # RFC 6750 asks servers to send this header with 401 responses to
    # bearer-token requests.
    auth_headers = {"WWW-Authenticate": "Bearer"}

    token = credentials.credentials
    payload = decode_access_token(token)

    # A None payload is treated as an invalid or expired token.
    if payload is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers=auth_headers,
        )

    user_id = payload.get("sub")
    email = payload.get("email")

    if user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers=auth_headers,
        )

    # A non-numeric "sub" would otherwise make int() raise and surface as a
    # 500 server error; reject it as an invalid token instead.
    try:
        numeric_user_id = int(user_id)
    except (TypeError, ValueError):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers=auth_headers,
        ) from None

    return {"user_id": numeric_user_id, "email": email}
|
|
|
|
def get_user_from_db(user_id: int) -> Optional[dict]:
    """Fetch a user's id and email from the ``users`` table.

    Args:
        user_id: Value matched against ``users.id``.

    Returns:
        ``{"id": ..., "email": ...}`` built from the first matching row, or
        ``None`` when the query returns no row.
    """
    from contextlib import closing

    # closing() guarantees cursor.close() runs even if execute() or
    # fetchone() raises, instead of leaving the cursor for the GC.
    with get_db_connection() as conn, closing(conn.cursor()) as cur:
        # Parameterized query: user_id is passed to the driver separately,
        # never interpolated into the SQL text.
        cur.execute("SELECT id, email FROM users WHERE id = %s", (user_id,))
        row = cur.fetchone()
        if row:
            return {"id": row[0], "email": row[1]}
        return None
|
|
|
|
from typing import Optional
|