2026-02-04 16:50:33 +02:00

184 lines
6.8 KiB
Python

import logging
import secrets
from datetime import datetime, timedelta, timezone

import requests
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from jose import jwt
from sqlalchemy.orm import Session

from app.auth import create_access_token, create_refresh_token, hash_password, verify_password, decode_token
from app.core.config import settings
from app.deps import get_db, get_current_user
from app.models import User, UserRole, LoginEvent
from app.schemas import TokenPair, UserCreate, UserOut, TokenRefresh
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router holding the authentication endpoints defined below; all of its
# routes are registered under the "/auth" prefix and tagged "auth".
router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/register", response_model=UserOut)
def register(payload: UserCreate, db: Session = Depends(get_db)):
    """Create a new account with the ``learner`` role.

    Args:
        payload: Registration data; ``email`` and ``password`` are read.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        The newly persisted ``User`` (serialized through ``UserOut``).

    Raises:
        HTTPException: 400 when a user with the same email already exists.
    """
    # Refuse to register the same email twice.
    if db.query(User).filter(User.email == payload.email).first():
        raise HTTPException(status_code=400, detail="Email already registered")

    # Only the hash of the password is stored on the row.
    new_user = User(
        email=payload.email,
        password_hash=hash_password(payload.password),
        role=UserRole.learner,
    )
    db.add(new_user)
    db.commit()
    # Reload the row after the commit before handing it back.
    db.refresh(new_user)
    return new_user
@router.post("/login", response_model=TokenPair)
def login(form: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Authenticate with email and password and issue a token pair.

    Args:
        form: OAuth2 password form; ``username`` is matched against
            ``User.email`` and ``password`` is checked against the stored hash.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        A ``TokenPair`` with a fresh access token and refresh token.

    Raises:
        HTTPException: 401 when no user matches or the password is wrong;
            403 when the matched user is not active.
    """
    account = db.query(User).filter(User.email == form.username).first()

    # One generic error for "unknown email" and "wrong password".
    if not (account and verify_password(form.password, account.password_hash)):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")

    if not account.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User inactive")

    # Record the successful login before issuing tokens.
    login_event = LoginEvent(user_id=account.id)
    db.add(login_event)
    db.commit()

    subject = str(account.id)
    return TokenPair(
        access_token=create_access_token(subject),
        refresh_token=create_refresh_token(subject),
    )
@router.post("/refresh", response_model=TokenPair)
def refresh(payload: TokenRefresh):
    """Exchange a refresh token for a new access/refresh token pair.

    Args:
        payload: Request body; its ``refresh_token`` is decoded with
            ``decode_token``.

    Returns:
        A ``TokenPair`` issued for the ``sub`` claim of the presented token.

    Raises:
        HTTPException: 401 when the decoded token's ``type`` is not
            ``"refresh"`` or when it carries no ``sub`` claim.
    """
    decoded = decode_token(payload.refresh_token, "refresh")
    if decoded.get("type") != "refresh":
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")

    user_id = decoded.get("sub")
    if not user_id:
        # Without this guard str(None) would be used as the subject and the
        # endpoint would mint tokens for a user called "None".
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")

    # NOTE(review): the subject is not looked up in the database here, so a
    # deleted or deactivated user can keep refreshing until the refresh token
    # itself stops decoding - confirm whether that is intended.
    subject = str(user_id)
    return TokenPair(
        access_token=create_access_token(subject),
        refresh_token=create_refresh_token(subject),
    )
@router.get("/me", response_model=UserOut)
def me(user: User = Depends(get_current_user)):
    """Return the user resolved by the ``get_current_user`` dependency.

    The object is returned unchanged and serialized through ``UserOut``.
    """
    return user
def build_google_oauth_url(state: str) -> str:
    """Build the Google authorization URL for the configured OAuth client.

    Args:
        state: Opaque value sent as the ``state`` query parameter.

    Returns:
        The Google authorization endpoint followed by a query string in which
        every value has been stringified and passed through
        ``requests.utils.quote``.
    """
    query_params = {
        "client_id": settings.google_client_id,
        "redirect_uri": settings.google_redirect_uri,
        "response_type": "code",
        "scope": "openid email profile",
        "access_type": "online",
        "include_granted_scopes": "true",
        "state": state,
        "prompt": "select_account",
    }

    # Keys are fixed identifiers and are emitted as-is; only values are quoted.
    encoded_pairs = []
    for key, value in query_params.items():
        encoded_pairs.append(f"{key}={requests.utils.quote(str(value))}")

    endpoint = "https://accounts.google.com/o/oauth2/v2/auth"
    return endpoint + "?" + "&".join(encoded_pairs)
@router.get("/google/login")
def google_login():
    """Start the Google OAuth flow by redirecting to Google's authorization URL.

    A signed ``state`` JWT is generated (random nonce, 10-minute expiry,
    ``type`` = ``"oauth_state"``) and embedded in the URL produced by
    ``build_google_oauth_url``.

    Returns:
        A 302 ``RedirectResponse`` pointing at the Google authorization URL.

    Raises:
        HTTPException: 400 when the Google client id or client secret is not
            configured.
    """
    if not settings.google_client_id or not settings.google_client_secret:
        raise HTTPException(status_code=400, detail="Google OAuth not configured")

    # Use an aware UTC datetime. Calling .timestamp() on the naive value from
    # datetime.utcnow() interprets it as *local* time, which shifts "exp" by
    # the server's UTC offset (on a UTC+2 host the state is born expired).
    expires_at = datetime.now(timezone.utc) + timedelta(minutes=10)
    state_payload = {
        "nonce": secrets.token_urlsafe(16),
        "exp": int(expires_at.timestamp()),
        "type": "oauth_state",
    }
    # NOTE(review): the nonce is only embedded in the signed state; nothing in
    # this file binds it to the browser (cookie/session) - confirm whether
    # that is sufficient CSRF protection for this flow.
    state = jwt.encode(state_payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
    auth_url = build_google_oauth_url(state)

    # Only a prefix of the URL is logged.
    logger.info("Redirecting to Google OAuth: %s...", auth_url[:100])
    return RedirectResponse(auth_url, status_code=302)
def _decode_oauth_state(state: str) -> dict:
    """Verify the signed OAuth ``state`` JWT and return its claims (HTTP 400 on failure)."""
    try:
        decoded = jwt.decode(state, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm])
    except Exception as exc:
        # Any decode failure (bad signature, expired, malformed) is a client error.
        logger.error(f"Failed to decode state: {exc}")
        raise HTTPException(status_code=400, detail="Invalid state") from exc
    if decoded.get("type") != "oauth_state":
        logger.error("State type mismatch")
        raise HTTPException(status_code=400, detail="Invalid state")
    return decoded


def _exchange_google_code(code: str) -> str:
    """Trade the authorization ``code`` for a Google access token."""
    try:
        token_res = requests.post(
            "https://oauth2.googleapis.com/token",
            data={
                "code": code,
                "client_id": settings.google_client_id,
                "client_secret": settings.google_client_secret,
                "redirect_uri": settings.google_redirect_uri,
                "grant_type": "authorization_code",
            },
            timeout=10,
        )
    except requests.RequestException as exc:
        # Timeouts / connection errors would otherwise surface as a bare 500.
        logger.error("Google token request failed: %s", exc)
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to exchange code") from exc
    if token_res.status_code != 200:
        raise HTTPException(status_code=400, detail="Failed to exchange code")
    try:
        token_data = token_res.json()
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to exchange code") from exc
    access_token = token_data.get("access_token") if isinstance(token_data, dict) else None
    if not access_token:
        raise HTTPException(status_code=400, detail="Missing access token")
    return access_token


def _fetch_google_profile(access_token: str) -> dict:
    """Fetch the Google userinfo document for ``access_token``."""
    try:
        userinfo_res = requests.get(
            "https://www.googleapis.com/oauth2/v3/userinfo",
            headers={"Authorization": f"Bearer {access_token}"},
            timeout=10,
        )
    except requests.RequestException as exc:
        logger.error("Google userinfo request failed: %s", exc)
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to fetch user info") from exc
    if userinfo_res.status_code != 200:
        raise HTTPException(status_code=400, detail="Failed to fetch user info")
    try:
        profile = userinfo_res.json()
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to fetch user info") from exc
    if not isinstance(profile, dict):
        raise HTTPException(status_code=400, detail="Failed to fetch user info")
    return profile


def _upsert_google_user(db: Session, email: str, sub: str | None, picture: str | None) -> User:
    """Create the user for ``email`` or fill in missing Google fields on the existing one."""
    user = db.query(User).filter(User.email == email).first()
    if not user:
        user = User(
            email=email,
            # OAuth-only accounts get a random, unknown password.
            password_hash=hash_password(secrets.token_urlsafe(24)),
            role=UserRole.learner,
            oauth_provider="google",
            oauth_subject=sub,
            picture_url=picture,
        )
        db.add(user)
    else:
        if user.oauth_provider != "google":
            user.oauth_provider = "google"
        # Existing subject / picture values are never overwritten.
        if not user.oauth_subject:
            user.oauth_subject = sub
        if picture and not user.picture_url:
            user.picture_url = picture
    db.commit()
    db.refresh(user)
    return user


@router.get("/google/callback")
def google_callback(code: str | None = None, state: str | None = None, db: Session = Depends(get_db)):
    """Complete the Google OAuth flow and redirect to the frontend with tokens.

    Steps: validate the signed ``state``, exchange ``code`` for a Google access
    token, fetch the Google profile, create or update the local user matched
    by email, record a ``LoginEvent`` and redirect to
    ``{settings.frontend_url}/auth/callback`` with the token pair in the query
    string.

    Args:
        code: Authorization code sent back by Google.
        state: The signed state value issued when the flow was started.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        A 302 ``RedirectResponse`` to the frontend callback URL.

    Raises:
        HTTPException: 400 for a missing/invalid ``code`` or ``state``, a
            rejected code exchange, a failed profile fetch, or a missing or
            unverified email; 403 when the matched user is inactive; 502 when
            Google cannot be reached or returns a non-JSON body.
    """
    logger.info("Google OAuth callback received")
    if not code or not state:
        logger.error(f"Missing code or state: code={bool(code)}, state={bool(state)}")
        raise HTTPException(status_code=400, detail="Missing code or state")

    _decode_oauth_state(state)
    access_token = _exchange_google_code(code)
    profile = _fetch_google_profile(access_token)

    email = profile.get("email")
    if not email:
        raise HTTPException(status_code=400, detail="Email not available")

    # Local accounts are matched purely by email, so an unverified address
    # must not be allowed to attach itself to an existing account.
    if str(profile.get("email_verified")).lower() != "true":
        raise HTTPException(status_code=400, detail="Email not verified")

    user = _upsert_google_user(db, email, profile.get("sub"), profile.get("picture"))

    # Same rule as the password login endpoint: inactive users get no tokens.
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User inactive")

    # Track login event
    db.add(LoginEvent(user_id=user.id))
    db.commit()

    token_pair = TokenPair(
        access_token=create_access_token(str(user.id)),
        refresh_token=create_refresh_token(str(user.id)),
    )
    # NOTE(review): tokens travel in the query string, where they can end up in
    # browser history, proxy logs and Referer headers. Moving them to the URL
    # fragment or a short-lived one-time code needs a matching frontend change.
    redirect_url = f"{settings.frontend_url}/auth/callback?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
    logger.info(f"OAuth successful for user {user.email}, redirecting to frontend")
    return RedirectResponse(redirect_url, status_code=302)